diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs
index e38f00c937f..be2a19d323b 100644
--- a/.git-blame-ignore-revs
+++ b/.git-blame-ignore-revs
@@ -8,3 +8,11 @@
 4fccbe2d18c6d2f4059036d61489467c780bbc0e
 # Delete `FastIndexedSeq`
 fcc9ffab3cd68c4cfb26a1553d65118797c59d6e
+# blackens hail directory
+da2790242a40ec425a53a02707d261c893b264f7
+# scalafmt
+422edf6386616711ca70f87c455f76781ac925d4
+# replaces black formatting with ruff
+fa2ef0f2c76654d0c037ff6db60ccb8842fb8539
+# ruff lint python imports
+01a6a6a107faf204d4f5c20f8ae510d2c35518e9
diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml
deleted file mode 100644
index 2ab495924c4..00000000000
--- a/.github/workflows/codeql-analysis.yml
+++ /dev/null
@@ -1,34 +0,0 @@
-name: "CodeQL"
-on:
-  push:
-    branches: [ "main" ]
-  pull_request:
-    branches: [ "main" ]
-  schedule:
-    - cron: '24 19 * * 0'
-jobs:
-  analyze:
-    name: Analyze
-    runs-on: ubuntu-latest
-    permissions:
-      actions: read
-      contents: read
-      security-events: write
-    strategy:
-      fail-fast: false
-      matrix:
-        language: [ 'cpp', 'java', 'javascript', 'python' ]
-    steps:
-    - name: Checkout repository
-      uses: actions/checkout@v3
-    - name: Initialize CodeQL
-      uses: github/codeql-action/init@v2
-      with:
-        languages: ${{ matrix.language }}
-    - run: |
-        sudo apt-get update && sudo apt-get install liblz4-dev
-        make -C hail shadowJar HAIL_COMPILE_NATIVES=1
-    - name: Perform CodeQL Analysis
-      uses: github/codeql-action/analyze@v2
-      with:
-        category: "/language:${{matrix.language}}"
diff --git a/.gitignore b/.gitignore
index 50a2a16a17c..5bb09fbcca5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,3 +50,4 @@ hail/python/hail/backend/extra_classpath
 hail/python/hail/backend/hail.jar
 hail/install-editable
 _/
+.helix
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 88d5a94de29..5df84fce692 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,12 +9,19 @@ repos:
         language: system
         types: [python]
         require_serial: true
+      - id: ruff-format
+        name: ruff-format
+        entry: ruff format
+        language: system
+        types: [python]
+        require_serial: true
       - id: pyright
         name: pyright
         entry: pyright
         language: system
         types: [python]
         require_serial: true
+        exclude: hail/python/(hail|test)
         stages:
           - pre-push
   - repo: https://github.com/pre-commit/pre-commit-hooks
@@ -22,11 +29,6 @@ repos:
     hooks:
       - id: end-of-file-fixer
       - id: trailing-whitespace
-  - repo: https://github.com/psf/black
-    rev: 22.3.0
-    hooks:
-      - id: black
-        language_version: python3
   - repo: https://github.com/thibaudcolas/curlylint
     rev: v0.13.1
     hooks:
diff --git a/Makefile b/Makefile
index cce0ecc5922..7a7d23a1a6a 100644
--- a/Makefile
+++ b/Makefile
@@ -11,7 +11,7 @@ CHECK_SERVICES_MODULES := $(patsubst %, check-%, $(SERVICES_MODULES))
 SPECIAL_IMAGES := hail-ubuntu batch-worker letsencrypt
 
 HAILGENETICS_IMAGES = $(foreach img,hail vep-grch37-85 vep-grch38-95,hailgenetics-$(img))
-CI_IMAGES = ci-utils ci-buildkit base hail-run
+CI_IMAGES = ci-utils hail-buildkit base hail-run
 PRIVATE_REGISTRY_IMAGES = $(patsubst %, pushed-private-%-image, $(SPECIAL_IMAGES) $(SERVICES_PLUS_ADMIN_POD) $(CI_IMAGES) $(HAILGENETICS_IMAGES))
 
 HAILTOP_VERSION := hail/python/hailtop/hail_version
@@ -39,10 +39,13 @@ check-all: check-hail check-services
 
 .PHONY: check-hail-fast
 check-hail-fast:
-	ruff check hail/python/hail
-	ruff check hail/python/hailtop
+	ruff check hail
+	ruff format hail --diff
 	$(PYTHON) -m pyright hail/python/hailtop
+	ruff check hail/python/test/hailtop/batch
+	$(PYTHON) -m pyright hail/python/test/hailtop/batch
+
 .PHONY: pylint-hailtop
 pylint-hailtop:
 	# pylint on hail is still a work in progress
@@ -50,6 +53,7 @@ pylint-hailtop:
 
 .PHONY: check-hail
 check-hail: check-hail-fast pylint-hailtop
+	cd hail && sh millw __.checkFormat + __.fix --check
 
 .PHONY: check-services
 check-services: $(CHECK_SERVICES_MODULES)
@@ -61,8 +65,8 @@ pylint-%:
 
 .PHONY: check-%-fast
 check-%-fast:
 	ruff check $*
+	ruff format $* --diff
 	$(PYTHON) -m pyright $*
-	$(PYTHON) -m black $* --check --diff
 	curlylint $*
 	cd $* && bash ../check-sql.sh
@@ -82,7 +86,6 @@ check-pip-requirements:
 		hail/python/dev \
 		gear \
 		web_common \
-		auth \
 		batch \
 		ci
@@ -94,7 +97,6 @@ check-linux-pip-requirements:
 		hail/python/dev \
 		gear \
 		web_common \
-		auth \
 		batch \
 		ci
@@ -173,7 +175,8 @@ hail-0.1-docs-5a6778710097.tar.gz:
 	gcloud storage cp gs://hail-common/builds/0.1/docs/$@ .
 
 hail/build/www: hail-0.1-docs-5a6778710097.tar.gz $(shell git ls-files hail)
-	$(MAKE) -C hail hail-docs-no-test batch-docs
+	@echo !!! This target does not render the notebooks because it takes a long time !!!
+	$(MAKE) -C hail hail-docs-do-not-render-notebooks batch-docs
 	mkdir -p hail/build/www/docs/0.1
 	tar -xvf hail-0.1-docs-5a6778710097.tar.gz -C hail/build/www/docs/0.1 --strip-components 2
 	touch $@ # Copying into the dir does not necessarily touch it
diff --git a/auth/auth/auth.py b/auth/auth/auth.py
index e41bcec51dc..c50572db3d5 100644
--- a/auth/auth/auth.py
+++ b/auth/auth/auth.py
@@ -11,7 +11,6 @@
 import kubernetes_asyncio.client
 import kubernetes_asyncio.client.rest
 import kubernetes_asyncio.config
-import uvloop
 from aiohttp import web
 from prometheus_async.aio.web import server_stats  # type: ignore
 
@@ -33,11 +32,10 @@
 from gear.auth import AIOHTTPHandler, get_session_id
 from gear.cloud_config import get_global_config
 from gear.profiling import install_profiler_if_requested
-from hailtop import httpx
+from hailtop import httpx, uvloopx
 from hailtop.auth import AzureFlow, Flow, GoogleFlow, IdentityProvider
 from hailtop.config import get_deploy_config
 from hailtop.hail_logging import AccessLogger
-from hailtop.tls import internal_server_ssl_context
 from hailtop.utils import secret_alnum_string
 from web_common import render_template, set_message, setup_aiohttp_jinja2, setup_common_static_routes
 
@@ -56,8 +54,6 @@
 
 log = logging.getLogger('auth')
 
-uvloop.install()
-
 CLOUD = get_global_config()['cloud']
 DEFAULT_NAMESPACE = os.environ['HAIL_DEFAULT_NAMESPACE']
 
@@ -164,10 +160,10 @@ async def _insert(tx):
             return False
 
         await tx.execute_insertone(
-            '''
+            """
 INSERT INTO users (state, username, login_id, is_developer, is_service_account, hail_identity, hail_credentials_secret_name)
 VALUES (%s, %s, %s, %s, %s, %s, %s);
-''',
+""",
             (
                 'creating',
                 username,
@@ -482,9 +478,11 @@ async def rest_login(request: web.Request) -> web.Response:
     flow_data['callback_uri'] = callback_uri
 
     # keeping authorization_url and state for backwards compatibility
-    return json_response(
-        {'flow': flow_data, 'authorization_url': flow_data['authorization_url'], 'state': flow_data['state']}
-    )
+    return json_response({
+        'flow': flow_data,
+        'authorization_url': flow_data['authorization_url'],
+        'state': flow_data['state'],
+    })
 
 
 @routes.get('/api/v1alpha/oauth2-client')
@@ -511,10 +509,10 @@ async def post_create_role(request: web.Request, _) -> NoReturn:
     name = str(post['name'])
 
     role_id = await db.execute_insertone(
-        '''
+        """
 INSERT INTO `roles` (`name`)
 VALUES (%s);
-''',
+""",
         (name),
     )
 
@@ -564,10 +562,10 @@ async def rest_get_users(request: web.Request, userdata: 
UserData) -> web.Respon raise web.HTTPUnauthorized() db = request.app[AppKeys.DB] - _query = ''' + _query = """ SELECT id, username, login_id, state, is_developer, is_service_account, hail_identity FROM users; -''' +""" users = [x async for x in db.select_and_fetchall(_query)] return json_response(users) @@ -579,10 +577,10 @@ async def rest_get_user(request: web.Request, _) -> web.Response: username = request.match_info['user'] user = await db.select_and_fetchone( - ''' + """ SELECT id, username, login_id, state, is_developer, is_service_account, hail_identity FROM users WHERE username = %s; -''', +""", (username,), ) if user is None: @@ -599,11 +597,11 @@ async def _delete_user(db: Database, username: str, id: Optional[str]): where_args.append(id) n_rows = await db.execute_update( - f''' + f""" UPDATE users SET state = 'deleting' WHERE {' AND '.join(where_conditions)}; -''', +""", where_args, ) @@ -743,11 +741,11 @@ async def get_userinfo_from_login_id_or_hail_identity_id( users = [ x async for x in db.select_and_fetchall( - ''' + """ SELECT users.* FROM users WHERE (users.login_id = %s OR users.hail_identity_uid = %s) AND users.state = 'active' -''', +""", (login_id_or_hail_idenity_uid, login_id_or_hail_idenity_uid), ) ] @@ -767,12 +765,12 @@ async def get_userinfo_from_hail_session_id(request: web.Request, session_id: st users = [ x async for x in db.select_and_fetchall( - ''' + """ SELECT users.* FROM users INNER JOIN sessions ON users.id = sessions.user_id WHERE users.state = 'active' AND sessions.session_id = %s AND (ISNULL(sessions.max_age_secs) OR (NOW() < TIMESTAMPADD(SECOND, sessions.max_age_secs, sessions.created))); -''', +""", session_id, 'get_userinfo', ) @@ -812,13 +810,20 @@ class AppKeys: HAILCTL_CLIENT_CONFIG = web.AppKey('hailctl_client_config', dict) K8S_CLIENT = web.AppKey('k8s_client', kubernetes_asyncio.client.CoreV1Api) K8S_CACHE = web.AppKey('k8s_cache', K8sCache) + EXIT_STACK = web.AppKey('exit_stack', AsyncExitStack) async def on_startup(app): + exit_stack = AsyncExitStack() + app[AppKeys.EXIT_STACK] = exit_stack + db = Database() await db.async_init(maxsize=50) + exit_stack.push_async_callback(db.async_close) app[AppKeys.DB] = db + app[AppKeys.CLIENT_SESSION] = httpx.client_session() + exit_stack.push_async_callback(app[AppKeys.CLIENT_SESSION].close) credentials_file = '/auth-oauth2-client-secret/client_secret.json' if CLOUD == 'gcp': @@ -832,14 +837,13 @@ async def on_startup(app): kubernetes_asyncio.config.load_incluster_config() app[AppKeys.K8S_CLIENT] = kubernetes_asyncio.client.CoreV1Api() + exit_stack.push_async_callback(app[AppKeys.K8S_CLIENT].api_client.rest_client.pool_manager.close) + app[AppKeys.K8S_CACHE] = K8sCache(app[AppKeys.K8S_CLIENT]) async def on_cleanup(app): - async with AsyncExitStack() as cleanup: - cleanup.push_async_callback(app[AppKeys.K8S_CLIENT].api_client.rest_client.pool_manager.close) - cleanup.push_async_callback(app[AppKeys.DB].async_close) - cleanup.push_async_callback(app[AppKeys.CLIENT_SESSION].close) + await app[AppKeys.EXIT_STACK].aclose() class AuthAccessLogger(AccessLogger): @@ -879,6 +883,8 @@ async def auth_check_csrf_token(request: web.Request, handler: AIOHTTPHandler): def run(): + uvloopx.install() + install_profiler_if_requested('auth') app = web.Application(middlewares=[auth_check_csrf_token, monitor_endpoints_middleware]) @@ -898,5 +904,5 @@ def run(): host='0.0.0.0', port=443, access_log_class=AuthAccessLogger, - ssl_context=internal_server_ssl_context(), + ssl_context=deploy_config.server_ssl_context(), ) diff 
--git a/auth/auth/driver/driver.py b/auth/auth/driver/driver.py index b61c7ba6ec2..5388b657d8a 100644 --- a/auth/auth/driver/driver.py +++ b/auth/auth/driver/driver.py @@ -94,10 +94,10 @@ async def delete(self): return await self.db.just_execute( - ''' + """ DELETE FROM sessions WHERE session_id = %s; -''', +""", (self.session_id,), ) self.session_id = None @@ -430,11 +430,11 @@ async def _create_user(app, user, skip_trial_bp, cleanup): updates['trial_bp_name'] = billing_project_name n_rows = await db.execute_update( - f''' + f""" UPDATE users SET {', '.join([f'{k} = %({k})s' for k in updates])} WHERE id = %(id)s AND state = 'creating'; -''', +""", {'id': user['id'], **updates}, ) if n_rows != 1: @@ -504,10 +504,10 @@ async def delete_user(app, user): await bp.delete() await db.just_execute( - ''' + """ DELETE FROM sessions WHERE user_id = %s; UPDATE users SET state = 'deleted' WHERE id = %s; -''', +""", (user['id'], user['id']), ) @@ -525,11 +525,11 @@ async def resolve_identity_uid(app, hail_identity): hail_identity_uid = await sp.get_service_principal_object_id() await db.just_execute( - ''' + """ UPDATE users SET hail_identity_uid = %s WHERE hail_identity = %s -''', +""", (hail_identity_uid, hail_identity), ) diff --git a/auth/auth/exceptions.py b/auth/auth/exceptions.py index 18e734d02ae..769404ae2e9 100644 --- a/auth/auth/exceptions.py +++ b/auth/auth/exceptions.py @@ -7,7 +7,7 @@ def __init__(self, message, severity): self.message = message self.ui_error_type = severity - def http_response(self): + def http_response(self) -> web.HTTPError: return web.HTTPBadRequest(reason=self.message) diff --git a/auth/deployment.yaml b/auth/deployment.yaml index 7d49afc40c1..8dd1679b8b7 100644 --- a/auth/deployment.yaml +++ b/auth/deployment.yaml @@ -45,11 +45,6 @@ spec: value: "{{ default_ns.name }}" - name: HAIL_DEPLOY_CONFIG_FILE value: /deploy-config/deploy-config.json - - name: HAIL_DOMAIN - valueFrom: - secretKeyRef: - name: global-config - key: domain - name: CLOUD valueFrom: secretKeyRef: @@ -179,11 +174,6 @@ spec: value: "{{ default_ns.name }}" - name: HAIL_DEPLOY_CONFIG_FILE value: /deploy-config/deploy-config.json - - name: HAIL_DOMAIN - valueFrom: - secretKeyRef: - name: global-config - key: domain - name: HAIL_ORGANIZATION_DOMAIN valueFrom: secretKeyRef: diff --git a/batch/Dockerfile.worker b/batch/Dockerfile.worker index 7e0454a729a..5fa4599b07e 100644 --- a/batch/Dockerfile.worker +++ b/batch/Dockerfile.worker @@ -51,7 +51,7 @@ RUN hail-pip-install \ -r hailtop-requirements.txt \ -r gear-requirements.txt \ -r batch-requirements.txt \ - pyspark==3.3.0 + pyspark==3.3.2 ENV SPARK_HOME /usr/local/lib/python3.9/dist-packages/pyspark ENV PATH "$PATH:$SPARK_HOME/sbin:$SPARK_HOME/bin" diff --git a/batch/batch/batch.py b/batch/batch/batch.py index ea6a45b7ae4..a2da5ba685b 100644 --- a/batch/batch/batch.py +++ b/batch/batch/batch.py @@ -1,18 +1,25 @@ import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, cast from gear import transaction -from hailtop.batch_client.types import CostBreakdownEntry, JobListEntryV1Alpha +from hailtop.batch_client.globals import ROOT_JOB_GROUP_ID +from hailtop.batch_client.types import CostBreakdownEntry, GetJobGroupResponseV1Alpha, JobListEntryV1Alpha from hailtop.utils import humanize_timedelta_msecs, time_msecs_str from .batch_format_version import BatchFormatVersion -from .exceptions import NonExistentBatchError, OpenBatchError +from .exceptions import NonExistentJobGroupError from .utils import 
coalesce log = logging.getLogger('batch') +def _maybe_time_msecs_str(t: Optional[int]) -> Optional[str]: + if t is not None: + return time_msecs_str(t) + return None + + def cost_breakdown_to_dict(cost_breakdown: Dict[str, float]) -> List[CostBreakdownEntry]: return [{'resource': resource, 'cost': cost} for resource, cost in cost_breakdown.items()] @@ -30,14 +37,9 @@ def batch_record_to_dict(record: Dict[str, Any]) -> Dict[str, Any]: else: state = 'running' - def _time_msecs_str(t): - if t: - return time_msecs_str(t) - return None - - time_created = _time_msecs_str(record['time_created']) - time_closed = _time_msecs_str(record['time_closed']) - time_completed = _time_msecs_str(record['time_completed']) + time_created = _maybe_time_msecs_str(record['time_created']) + time_closed = _maybe_time_msecs_str(record['time_closed']) + time_completed = _maybe_time_msecs_str(record['time_completed']) if record['time_created'] and record['time_completed']: duration_ms = record['time_completed'] - record['time_created'] @@ -50,7 +52,7 @@ def _time_msecs_str(t): if cost_breakdown is not None: cost_breakdown = cost_breakdown_to_dict(json.loads(cost_breakdown)) - d = { + batch_response = { 'id': record['id'], 'user': record['user'], 'billing_project': record['billing_project'], @@ -75,9 +77,55 @@ def _time_msecs_str(t): attributes = json.loads(record['attributes']) if attributes: - d['attributes'] = attributes + batch_response['attributes'] = attributes + + return batch_response + + +def job_group_record_to_dict(record: Dict[str, Any]) -> GetJobGroupResponseV1Alpha: + if record['n_failed'] > 0: + state = 'failure' + elif record['cancelled'] or record['n_cancelled'] > 0: + state = 'cancelled' + elif record['state'] == 'complete': + assert record['n_succeeded'] == record['n_jobs'] + state = 'success' + else: + state = 'running' + + time_created = _maybe_time_msecs_str(record['time_created']) + time_completed = _maybe_time_msecs_str(record['time_completed']) - return d + if record['time_created'] and record['time_completed']: + duration_ms = record['time_completed'] - record['time_created'] + else: + duration_ms = None + + if record['cost_breakdown'] is not None: + record['cost_breakdown'] = cost_breakdown_to_dict(json.loads(record['cost_breakdown'])) + + job_group_response = { + 'batch_id': record['batch_id'], + 'job_group_id': record['job_group_id'], + 'state': state, + 'complete': record['state'] == 'complete', + 'n_jobs': record['n_jobs'], + 'n_completed': record['n_completed'], + 'n_succeeded': record['n_succeeded'], + 'n_failed': record['n_failed'], + 'n_cancelled': record['n_cancelled'], + 'time_created': time_created, + 'time_completed': time_completed, + 'duration': duration_ms, + 'cost': coalesce(record['cost'], 0), + 'cost_breakdown': record['cost_breakdown'], + } + + attributes = json.loads(record['attributes']) + if attributes: + job_group_response['attributes'] = attributes + + return cast(GetJobGroupResponseV1Alpha, job_group_response) def job_record_to_dict(record: Dict[str, Any], name: Optional[str]) -> JobListEntryV1Alpha: @@ -95,38 +143,44 @@ def job_record_to_dict(record: Dict[str, Any], name: Optional[str]) -> JobListEn if cost_breakdown is not None: cost_breakdown = cost_breakdown_to_dict(json.loads(cost_breakdown)) - return { - 'batch_id': record['batch_id'], - 'job_id': record['job_id'], - 'name': name, - 'user': record['user'], - 'billing_project': record['billing_project'], - 'state': record['state'], - 'exit_code': exit_code, - 'duration': duration, - 'cost': 
coalesce(record.get('cost'), 0), - 'msec_mcpu': record['msec_mcpu'], - 'cost_breakdown': cost_breakdown, - } - - -async def cancel_batch_in_db(db, batch_id): + return cast( + JobListEntryV1Alpha, + { + 'batch_id': record['batch_id'], + 'job_id': record['job_id'], + 'name': name, + 'user': record['user'], + 'billing_project': record['billing_project'], + 'state': record['state'], + 'exit_code': exit_code, + 'duration': duration, + 'cost': coalesce(record.get('cost'), 0), + 'msec_mcpu': record['msec_mcpu'], + 'cost_breakdown': cost_breakdown, + 'always_run': bool(record['always_run']), + 'display_state': None, + }, + ) + + +async def cancel_job_group_in_db(db, batch_id, job_group_id): @transaction(db) async def cancel(tx): record = await tx.execute_and_fetchone( - ''' -SELECT `state` FROM batches -WHERE id = %s AND NOT deleted + """ +SELECT 1 +FROM job_groups +LEFT JOIN batches ON batches.id = job_groups.batch_id +LEFT JOIN batch_updates ON job_groups.batch_id = batch_updates.batch_id AND + job_groups.update_id = batch_updates.update_id +WHERE job_groups.batch_id = %s AND job_groups.job_group_id = %s AND NOT deleted AND (batch_updates.committed OR job_groups.job_group_id = %s) FOR UPDATE; -''', - (batch_id,), +""", + (batch_id, job_group_id, ROOT_JOB_GROUP_ID), ) if not record: - raise NonExistentBatchError(batch_id) - - if record['state'] == 'open': - raise OpenBatchError(batch_id) + raise NonExistentJobGroupError(batch_id, job_group_id) - await tx.just_execute('CALL cancel_batch(%s);', (batch_id,)) + await tx.just_execute('CALL cancel_job_group(%s, %s);', (batch_id, job_group_id)) await cancel() diff --git a/batch/batch/cloud/azure/driver/create_instance.py b/batch/batch/cloud/azure/driver/create_instance.py index 6602c8d8074..122386a6b72 100644 --- a/batch/batch/cloud/azure/driver/create_instance.py +++ b/batch/batch/cloud/azure/driver/create_instance.py @@ -92,7 +92,7 @@ def create_vm_config( jvm_touch_command = '\n'.join(touch_commands) - startup_script = r'''#cloud-config + startup_script = r"""#cloud-config mounts: - [ ephemeral0, null ] @@ -123,10 +123,10 @@ def create_vm_config( runcmd: - sh /startup.sh -''' +""" startup_script = base64.b64encode(startup_script.encode('utf-8')).decode('utf-8') - run_script = f''' + run_script = f""" #!/bin/bash set -x @@ -284,7 +284,6 @@ def create_vm_config( -v /sys/fs/cgroup:/sys/fs/cgroup \ --mount type=bind,source=/mnt/disks/$WORKER_DATA_DISK_NAME,target=/host \ --mount type=bind,source=/dev,target=/dev,bind-propagation=rshared \ --p 5000:5000 \ --device /dev/fuse \ --device $XFS_DEVICE \ --device /dev \ @@ -302,7 +301,7 @@ def create_vm_config( az vm delete -g $RESOURCE_GROUP -n $NAME --yes sleep 1 done -''' +""" user_data = { 'run_script': run_script, diff --git a/batch/batch/cloud/azure/driver/driver.py b/batch/batch/cloud/azure/driver/driver.py index 58ffb1f3570..b1d802fca37 100644 --- a/batch/batch/cloud/azure/driver/driver.py +++ b/batch/batch/cloud/azure/driver/driver.py @@ -37,10 +37,10 @@ async def create( region_args = [(r,) for r in regions] await db.execute_many( - ''' + """ INSERT INTO regions (region) VALUES (%s) ON DUPLICATE KEY UPDATE region = region; -''', +""", region_args, ) diff --git a/batch/batch/cloud/azure/instance_config.py b/batch/batch/cloud/azure/instance_config.py index e14123a9bb3..fbde785cc98 100644 --- a/batch/batch/cloud/azure/instance_config.py +++ b/batch/batch/cloud/azure/instance_config.py @@ -35,16 +35,14 @@ def create( else: data_disk_resource = AzureStaticSizedDiskResource.create(product_versions, 'P', 
data_disk_size_gb, location) - resources: List[AzureResource] = filter_none( - [ - AzureVMResource.create(product_versions, machine_type, preemptible, location), - AzureStaticSizedDiskResource.create(product_versions, 'E', boot_disk_size_gb, location), - data_disk_resource, - AzureDynamicSizedDiskResource.create(product_versions, 'P', location), - AzureIPFeeResource.create(product_versions, 1024), - AzureServiceFeeResource.create(product_versions), - ] - ) + resources: List[AzureResource] = filter_none([ + AzureVMResource.create(product_versions, machine_type, preemptible, location), + AzureStaticSizedDiskResource.create(product_versions, 'E', boot_disk_size_gb, location), + data_disk_resource, + AzureDynamicSizedDiskResource.create(product_versions, 'P', location), + AzureIPFeeResource.create(product_versions, 1024), + AzureServiceFeeResource.create(product_versions), + ]) return AzureSlimInstanceConfig( machine_type=machine_type, diff --git a/batch/batch/cloud/azure/worker/credentials.py b/batch/batch/cloud/azure/worker/credentials.py deleted file mode 100644 index 0488b1537f1..00000000000 --- a/batch/batch/cloud/azure/worker/credentials.py +++ /dev/null @@ -1,44 +0,0 @@ -import base64 -import json -from typing import Dict - -from hailtop.auth.auth import IdentityProvider - -from ....worker.credentials import CloudUserCredentials - - -class AzureUserCredentials(CloudUserCredentials): - def __init__(self, data: Dict[str, str]): - self._data = data - self._credentials = json.loads(base64.b64decode(data['key.json']).decode()) - - @property - def cloud_env_name(self) -> str: - return 'AZURE_APPLICATION_CREDENTIALS' - - @property - def username(self): - return self._credentials['appId'] - - @property - def password(self): - return self._credentials['password'] - - @property - def mount_path(self): - return '/azure-credentials/key.json' - - @property - def identity_provider_json(self): - return {'idp': IdentityProvider.MICROSOFT.value} - - def blobfuse_credentials(self, account: str, container: str) -> str: - # https://github.com/Azure/azure-storage-fuse - return f''' -accountName {account} -authType SPN -servicePrincipalClientId {self.username} -servicePrincipalClientSecret {self.password} -servicePrincipalTenantId {self._credentials['tenant']} -containerName {container} -''' diff --git a/batch/batch/cloud/azure/worker/worker_api.py b/batch/batch/cloud/azure/worker/worker_api.py index a8cf68cdbb9..779bc13fc8d 100644 --- a/batch/batch/cloud/azure/worker/worker_api.py +++ b/batch/batch/cloud/azure/worker/worker_api.py @@ -1,21 +1,24 @@ import abc +import base64 import os import tempfile from typing import Dict, List, Optional, Tuple import aiohttp +import orjson +from aiohttp import web from hailtop import httpx from hailtop.aiocloud import aioazure +from hailtop.auth.auth import IdentityProvider from hailtop.utils import check_exec_output, retry_transient_errors, time_msecs from ....worker.worker_api import CloudWorkerAPI, ContainerRegistryCredentials from ..instance_config import AzureSlimInstanceConfig -from .credentials import AzureUserCredentials from .disk import AzureDisk -class AzureWorkerAPI(CloudWorkerAPI[AzureUserCredentials]): +class AzureWorkerAPI(CloudWorkerAPI): nameserver_ip = '168.63.129.16' @staticmethod @@ -37,17 +40,16 @@ def __init__(self, subscription_id: str, resource_group: str, acr_url: str, hail @property def cloud_specific_env_vars_for_user_jobs(self) -> List[str]: - return [f'HAIL_AZURE_OAUTH_SCOPE={self.hail_oauth_scope}'] + idp_json = orjson.dumps({'idp': 
IdentityProvider.MICROSOFT.value}).decode('utf-8') + return [ + f'HAIL_AZURE_OAUTH_SCOPE={self.hail_oauth_scope}', + 'AZURE_APPLICATION_CREDENTIALS=/azure-credentials/key.json', + f'HAIL_IDENTITY_PROVIDER_JSON={idp_json}', + ] def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount_path: str) -> AzureDisk: return AzureDisk(disk_name, instance_name, size_in_gb, mount_path) - def get_cloud_async_fs(self) -> aioazure.AzureAsyncFS: - return aioazure.AzureAsyncFS(credentials=self.azure_credentials) - - def user_credentials(self, credentials: Dict[str, str]) -> AzureUserCredentials: - return AzureUserCredentials(credentials) - async def worker_container_registry_credentials(self, session: httpx.ClientSession) -> ContainerRegistryCredentials: # https://docs.microsoft.com/en-us/azure/container-registry/container-registry-authentication?tabs=azure-cli#az-acr-login-with---expose-token return { @@ -55,33 +57,44 @@ async def worker_container_registry_credentials(self, session: httpx.ClientSessi 'password': await self.acr_refresh_token.token(session), } - async def user_container_registry_credentials( - self, user_credentials: AzureUserCredentials - ) -> ContainerRegistryCredentials: - return { - 'username': user_credentials.username, - 'password': user_credentials.password, - } + async def user_container_registry_credentials(self, credentials: Dict[str, str]) -> ContainerRegistryCredentials: + credentials = orjson.loads(base64.b64decode(credentials['key.json']).decode()) + return {'username': credentials['appId'], 'password': credentials['password']} + + def create_metadata_server_app(self, credentials: Dict[str, str]) -> web.Application: + raise NotImplementedError def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> AzureSlimInstanceConfig: return AzureSlimInstanceConfig.from_dict(config_dict) + def _blobfuse_credentials(self, credentials: Dict[str, str], account: str, container: str) -> str: + credentials = orjson.loads(base64.b64decode(credentials['key.json']).decode()) + # https://github.com/Azure/azure-storage-fuse + return f""" +accountName {account} +authType SPN +servicePrincipalClientId {credentials["appId"]} +servicePrincipalClientSecret {credentials["password"]} +servicePrincipalTenantId {credentials["tenant"]} +containerName {container} +""" + def _write_blobfuse_credentials( self, - credentials: AzureUserCredentials, + credentials: Dict[str, str], account: str, container: str, mount_base_path_data: str, ) -> str: if mount_base_path_data not in self._blobfuse_credential_files: with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False) as credsfile: - credsfile.write(credentials.blobfuse_credentials(account, container)) + credsfile.write(self._blobfuse_credentials(credentials, account, container)) self._blobfuse_credential_files[mount_base_path_data] = credsfile.name return self._blobfuse_credential_files[mount_base_path_data] async def _mount_cloudfuse( self, - credentials: AzureUserCredentials, + credentials: Dict[str, str], mount_base_path_data: str, mount_base_path_tmp: str, config: dict, diff --git a/batch/batch/cloud/driver.py b/batch/batch/cloud/driver.py index 0be00d9a749..db14db95b0c 100644 --- a/batch/batch/cloud/driver.py +++ b/batch/batch/cloud/driver.py @@ -1,3 +1,5 @@ +import os + from gear import Database from gear.cloud_config import get_global_config @@ -5,6 +7,7 @@ from ..inst_coll_config import InstanceCollectionConfigs from .azure.driver.driver import AzureDriver from .gcp.driver.driver import GCPDriver 
+from .terra.azure.driver.driver import TerraAzureDriver async def get_cloud_driver( @@ -17,6 +20,8 @@ async def get_cloud_driver( cloud = get_global_config()['cloud'] if cloud == 'azure': + if os.environ.get('HAIL_TERRA'): + return await TerraAzureDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs) return await AzureDriver.create(app, db, machine_name_prefix, namespace, inst_coll_configs) assert cloud == 'gcp', cloud diff --git a/batch/batch/cloud/gcp/driver/activity_logs.py b/batch/batch/cloud/gcp/driver/activity_logs.py index 0e6dd3199b3..cc075516ea2 100644 --- a/batch/batch/cloud/gcp/driver/activity_logs.py +++ b/batch/batch/cloud/gcp/driver/activity_logs.py @@ -95,14 +95,14 @@ async def process_activity_log_events_since( project: str, mark: str, ) -> str: - filter = f''' + filter = f""" (logName="projects/{project}/logs/cloudaudit.googleapis.com%2Factivity" OR logName="projects/{project}/logs/cloudaudit.googleapis.com%2Fsystem_event" ) AND resource.type=gce_instance AND protoPayload.resourceName:"{machine_name_prefix}" AND timestamp >= "{mark}" -''' +""" body = { 'resourceNames': [f'projects/{project}'], diff --git a/batch/batch/cloud/gcp/driver/create_instance.py b/batch/batch/cloud/gcp/driver/create_instance.py index db1e452025b..da0d514d7e8 100644 --- a/batch/batch/cloud/gcp/driver/create_instance.py +++ b/batch/batch/cloud/gcp/driver/create_instance.py @@ -85,12 +85,10 @@ def scheduling() -> dict: } if preemptible: - result.update( - { - 'provisioningModel': 'SPOT', - 'instanceTerminationAction': 'DELETE', - } - ) + result.update({ + 'provisioningModel': 'SPOT', + 'instanceTerminationAction': 'DELETE', + }) return result @@ -129,7 +127,7 @@ def scheduling() -> dict: 'items': [ { 'key': 'startup-script', - 'value': ''' + 'value': """ #!/bin/bash set -x @@ -150,11 +148,11 @@ def scheduling() -> dict: curl -s -H "Metadata-Flavor: Google" "http://metadata.google.internal/computeMetadata/v1/instance/attributes/run_script" >./run.sh nohup /bin/bash run.sh >run.log 2>&1 & - ''', + """, }, { 'key': 'run_script', - 'value': rf''' + 'value': rf""" #!/bin/bash set -x @@ -232,6 +230,8 @@ def scheduling() -> dict: - /batch/jvm-container-logs/jvm-*.log record_log_file_path: true processors: + parse_message: + type: parse_json labels: type: modify_fields fields: @@ -239,11 +239,13 @@ def scheduling() -> dict: static_value: $NAMESPACE labels.instance_id: static_value: $INSTANCE_ID + severity: + move_from: jsonPayload.severity service: log_level: error pipelines: default_pipeline: - processors: [labels] + processors: [parse_message, labels] receivers: [runlog, workerlog, jvmlog] metrics: @@ -264,9 +266,9 @@ def scheduling() -> dict: iptables --table nat --append POSTROUTING --source 172.20.0.0/15 --jump MASQUERADE # [public] -# Block public traffic to the metadata server -iptables --append FORWARD --source 172.21.0.0/16 --destination 169.254.169.254 --jump DROP -# But allow the internal gateway +# Send public jobs' metadata server requests to the batch worker itself +iptables --table nat --append PREROUTING --source 172.21.0.0/16 --destination 169.254.169.254 -p tcp -j REDIRECT --to-ports 5555 +# Allow the internal gateway iptables --append FORWARD --destination $INTERNAL_GATEWAY_IP --jump ACCEPT # And this worker iptables --append FORWARD --destination $IP_ADDRESS --jump ACCEPT @@ -328,7 +330,6 @@ def scheduling() -> dict: -v /sys/fs/cgroup:/sys/fs/cgroup \ --mount type=bind,source=/mnt/disks/$WORKER_DATA_DISK_NAME,target=/host \ --mount 
type=bind,source=/dev,target=/dev,bind-propagation=rshared \ --p 5000:5000 \ --device /dev/fuse \ --device $XFS_DEVICE \ --device /dev \ @@ -347,18 +348,18 @@ def scheduling() -> dict: gcloud -q compute instances delete $NAME --zone=$ZONE sleep 1 done -''', +""", }, { 'key': 'shutdown-script', - 'value': ''' + 'value': """ set -x INSTANCE_ID=$(curl -s -H "Metadata-Flavor: Google" "http://metadata.google.internal/computeMetadata/v1/instance/attributes/instance_id") NAME=$(curl -s http://metadata.google.internal/computeMetadata/v1/instance/name -H 'Metadata-Flavor: Google') journalctl -u docker.service > dockerd.log -''', +""", }, {'key': 'activation_token', 'value': activation_token}, {'key': 'batch_worker_image', 'value': BATCH_WORKER_IMAGE}, diff --git a/batch/batch/cloud/gcp/driver/driver.py b/batch/batch/cloud/gcp/driver/driver.py index 4000b650469..cc5914f18b3 100644 --- a/batch/batch/cloud/gcp/driver/driver.py +++ b/batch/batch/cloud/gcp/driver/driver.py @@ -34,10 +34,10 @@ async def create( region_args = [(region,) for region in regions] await db.execute_many( - ''' + """ INSERT INTO regions (region) VALUES (%s) ON DUPLICATE KEY UPDATE region = region; -''', +""", region_args, ) @@ -92,7 +92,7 @@ async def create( inst_coll_configs.jpim_config, task_manager, ), - *create_pools_coros + *create_pools_coros, ) driver = GCPDriver( diff --git a/batch/batch/cloud/gcp/instance_config.py b/batch/batch/cloud/gcp/instance_config.py index 02e96662c11..2789ff6adcc 100644 --- a/batch/batch/cloud/gcp/instance_config.py +++ b/batch/batch/cloud/gcp/instance_config.py @@ -57,7 +57,7 @@ def create( GCPStaticSizedDiskResource.create(product_versions, 'pd-ssd', boot_disk_size_gb, region), data_disk_resource, GCPDynamicSizedDiskResource.create(product_versions, 'pd-ssd', region), - GCPIPFeeResource.create(product_versions, 1024), + GCPIPFeeResource.create(product_versions, 1024, preemptible), GCPServiceFeeResource.create(product_versions), GCPSupportLogsSpecsAndFirewallFees.create(product_versions), ] diff --git a/batch/batch/cloud/gcp/resources.py b/batch/batch/cloud/gcp/resources.py index 5fe098fcebf..e058d9d0cf5 100644 --- a/batch/batch/cloud/gcp/resources.py +++ b/batch/batch/cloud/gcp/resources.py @@ -270,8 +270,9 @@ class GCPIPFeeResource(IPFeeResourceMixin, GCPResource): TYPE = 'gcp_ip_fee' @staticmethod - def product_name(base: int) -> str: - return f'ip-fee/{base}' + def product_name(base: int, preemptible: bool) -> str: + preemptible_str = 'preemptible' if preemptible else 'nonpreemptible' + return f'ip-fee/{preemptible_str}/{base}' @staticmethod def from_dict(data: Dict[str, Any]) -> 'GCPIPFeeResource': @@ -279,8 +280,8 @@ def from_dict(data: Dict[str, Any]) -> 'GCPIPFeeResource': return GCPIPFeeResource(data['name']) @staticmethod - def create(product_versions: ProductVersions, base: int) -> 'GCPIPFeeResource': - product = GCPIPFeeResource.product_name(base) + def create(product_versions: ProductVersions, base: int, preemptible: bool) -> 'GCPIPFeeResource': + product = GCPIPFeeResource.product_name(base, preemptible) name = product_versions.resource_name(product) assert name, product return GCPIPFeeResource(name) diff --git a/batch/batch/cloud/gcp/worker/credentials.py b/batch/batch/cloud/gcp/worker/credentials.py deleted file mode 100644 index b637ef4e951..00000000000 --- a/batch/batch/cloud/gcp/worker/credentials.py +++ /dev/null @@ -1,28 +0,0 @@ -import base64 -from typing import Dict - -from hailtop.auth.auth import IdentityProvider - -from ....worker.credentials import 
CloudUserCredentials - - -class GCPUserCredentials(CloudUserCredentials): - def __init__(self, data: Dict[str, str]): - self._data = data - self._key = base64.b64decode(self._data['key.json']).decode() - - @property - def cloud_env_name(self) -> str: - return 'GOOGLE_APPLICATION_CREDENTIALS' - - @property - def mount_path(self): - return '/gsa-key/key.json' - - @property - def key(self): - return self._key - - @property - def identity_provider_json(self): - return {'idp': IdentityProvider.GOOGLE.value} diff --git a/batch/batch/cloud/gcp/worker/metadata_server.py b/batch/batch/cloud/gcp/worker/metadata_server.py new file mode 100644 index 00000000000..5475c9982a9 --- /dev/null +++ b/batch/batch/cloud/gcp/worker/metadata_server.py @@ -0,0 +1,109 @@ +from aiohttp import web + +from hailtop.aiocloud import aiogoogle + +from ....globals import HTTP_CLIENT_MAX_SIZE + + +class AppKeys: + USER_CREDENTIALS = web.AppKey('credentials', aiogoogle.GoogleServiceAccountCredentials) + GCE_METADATA_SERVER_CLIENT = web.AppKey('ms_client', aiogoogle.GoogleMetadataServerClient) + + +async def root(_): + return web.Response(text='computeMetadata/\n') + + +async def project_id(request: web.Request): + metadata_server_client = request.app[AppKeys.GCE_METADATA_SERVER_CLIENT] + return web.Response(text=await metadata_server_client.project()) + + +async def numeric_project_id(request: web.Request): + metadata_server_client = request.app[AppKeys.GCE_METADATA_SERVER_CLIENT] + return web.Response(text=await metadata_server_client.numeric_project_id()) + + +async def service_accounts(request: web.Request): + gsa_email = request.app[AppKeys.USER_CREDENTIALS].email + return web.Response(text=f'default\n{gsa_email}\n') + + +async def user_service_account(request: web.Request): + gsa_email = request.app[AppKeys.USER_CREDENTIALS].email + recursive = request.query.get('recursive') + # https://cloud.google.com/compute/docs/metadata/querying-metadata + # token is not included in the recursive version, presumably as that + # is not simple metadata but requires requesting an access token + if recursive == 'true': + return web.json_response( + { + 'aliases': ['default'], + 'email': gsa_email, + 'scopes': ['https://www.googleapis.com/auth/cloud-platform'], + }, + ) + return web.Response(text='aliases\nemail\nscopes\ntoken\n') + + +async def user_email(request: web.Request): + return web.Response(text=request.app[AppKeys.USER_CREDENTIALS].email) + + +async def user_token(request: web.Request): + access_token = await request.app[AppKeys.USER_CREDENTIALS]._get_access_token() + return web.json_response({ + 'access_token': access_token.token, + 'expires_in': access_token.expires_in, + 'token_type': 'Bearer', + }) + + +@web.middleware +async def middleware(request: web.Request, handler): + credentials = request.app[AppKeys.USER_CREDENTIALS] + gsa = request.match_info.get('gsa') + if gsa and gsa not in (credentials.email, 'default'): + raise web.HTTPBadRequest() + + response = await handler(request) + response.enable_compression() + + # `gcloud` does not properly respect `charset`, which aiohttp automatically + # sets so we have to explicitly erase it + # See https://github.com/googleapis/google-auth-library-python/blob/b935298aaf4ea5867b5778bcbfc42408ba4ec02c/google/auth/compute_engine/_metadata.py#L170 + if 'application/json' in response.headers['Content-Type']: + response.headers['Content-Type'] = 'application/json' + response.headers['Metadata-Flavor'] = 'Google' + response.headers['Server'] = 'Metadata Server for VM' + 
response.headers['X-XSS-Protection'] = '0' + response.headers['X-Frame-Options'] = 'SAMEORIGIN' + return response + + +def create_app( + credentials: aiogoogle.GoogleServiceAccountCredentials, + metadata_server_client: aiogoogle.GoogleMetadataServerClient, +) -> web.Application: + app = web.Application( + client_max_size=HTTP_CLIENT_MAX_SIZE, + middlewares=[middleware], + ) + app[AppKeys.USER_CREDENTIALS] = credentials + app[AppKeys.GCE_METADATA_SERVER_CLIENT] = metadata_server_client + + app.add_routes([ + web.get('/', root), + web.get('/computeMetadata/v1/project/project-id', project_id), + web.get('/computeMetadata/v1/project/numeric-project-id', numeric_project_id), + web.get('/computeMetadata/v1/instance/service-accounts/', service_accounts), + web.get('/computeMetadata/v1/instance/service-accounts/{gsa}/', user_service_account), + web.get('/computeMetadata/v1/instance/service-accounts/{gsa}/email', user_email), + web.get('/computeMetadata/v1/instance/service-accounts/{gsa}/token', user_token), + ]) + + async def close_credentials(_): + await credentials.close() + + app.on_cleanup.append(close_credentials) + return app diff --git a/batch/batch/cloud/gcp/worker/worker_api.py b/batch/batch/cloud/gcp/worker/worker_api.py index 19c431f6c06..4e248fdf329 100644 --- a/batch/batch/cloud/gcp/worker/worker_api.py +++ b/batch/batch/cloud/gcp/worker/worker_api.py @@ -1,40 +1,63 @@ +import base64 import os import tempfile +from contextlib import AsyncExitStack from typing import Dict, List -import aiohttp +import orjson +from aiohttp import web from hailtop import httpx from hailtop.aiocloud import aiogoogle -from hailtop.utils import check_exec_output, retry_transient_errors +from hailtop.auth.auth import IdentityProvider +from hailtop.utils import check_exec_output from ....worker.worker_api import CloudWorkerAPI, ContainerRegistryCredentials from ..instance_config import GCPSlimInstanceConfig -from .credentials import GCPUserCredentials from .disk import GCPDisk +from .metadata_server import create_app -class GCPWorkerAPI(CloudWorkerAPI[GCPUserCredentials]): +class GCPWorkerAPI(CloudWorkerAPI): nameserver_ip = '169.254.169.254' - # async because GoogleSession must be created inside a running event loop + # async because ClientSession must be created inside a running event loop @staticmethod async def from_env() -> 'GCPWorkerAPI': project = os.environ['PROJECT'] zone = os.environ['ZONE'].rsplit('/', 1)[1] - session = aiogoogle.GoogleSession() - return GCPWorkerAPI(project, zone, session) + worker_credentials = aiogoogle.GoogleInstanceMetadataCredentials() + http_session = httpx.client_session() + return GCPWorkerAPI(project, zone, worker_credentials, http_session) - def __init__(self, project: str, zone: str, session: aiogoogle.GoogleSession): + def __init__( + self, + project: str, + zone: str, + worker_credentials: aiogoogle.GoogleInstanceMetadataCredentials, + http_session: httpx.ClientSession, + ): self.project = project self.zone = zone - self._google_session = session - self._compute_client = aiogoogle.GoogleComputeClient(project, session=session) + + self._exit_stack = AsyncExitStack() + self._http_session = http_session + self._exit_stack.push_async_callback(self._http_session.close) + + self._compute_client = aiogoogle.GoogleComputeClient(project) + self._exit_stack.push_async_callback(self._compute_client.close) + + self._metadata_server_client = aiogoogle.GoogleMetadataServerClient(http_session) self._gcsfuse_credential_files: Dict[str, str] = {} + self._worker_credentials = 
worker_credentials @property def cloud_specific_env_vars_for_user_jobs(self) -> List[str]: - return [] + idp_json = orjson.dumps({'idp': IdentityProvider.GOOGLE.value}).decode('utf-8') + return [ + 'GOOGLE_APPLICATION_CREDENTIALS=/gsa-key/key.json', + f'HAIL_IDENTITY_PROVIDER_JSON={idp_json}', + ] def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount_path: str) -> GCPDisk: return GCPDisk( @@ -47,45 +70,37 @@ def create_disk(self, instance_name: str, disk_name: str, size_in_gb: int, mount compute_client=self._compute_client, ) - def get_cloud_async_fs(self) -> aiogoogle.GoogleStorageAsyncFS: - return aiogoogle.GoogleStorageAsyncFS(session=self._google_session) - - def user_credentials(self, credentials: Dict[str, str]) -> GCPUserCredentials: - return GCPUserCredentials(credentials) - async def worker_container_registry_credentials(self, session: httpx.ClientSession) -> ContainerRegistryCredentials: - token_dict = await retry_transient_errors( - session.post_read_json, - 'http://169.254.169.254/computeMetadata/v1/instance/service-accounts/default/token', - headers={'Metadata-Flavor': 'Google'}, - timeout=aiohttp.ClientTimeout(total=60), # type: ignore - ) - access_token = token_dict['access_token'] + access_token = await self._worker_credentials.access_token() return {'username': 'oauth2accesstoken', 'password': access_token} - async def user_container_registry_credentials( - self, user_credentials: GCPUserCredentials - ) -> ContainerRegistryCredentials: - return {'username': '_json_key', 'password': user_credentials.key} + async def user_container_registry_credentials(self, credentials: Dict[str, str]) -> ContainerRegistryCredentials: + key = orjson.loads(base64.b64decode(credentials['key.json']).decode()) + async with aiogoogle.GoogleServiceAccountCredentials(key) as sa_credentials: + access_token = await sa_credentials.access_token() + return {'username': 'oauth2accesstoken', 'password': access_token} + + def create_metadata_server_app(self, credentials: Dict[str, str]) -> web.Application: + key = orjson.loads(base64.b64decode(credentials['key.json']).decode()) + return create_app(aiogoogle.GoogleServiceAccountCredentials(key), self._metadata_server_client) def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> GCPSlimInstanceConfig: return GCPSlimInstanceConfig.from_dict(config_dict) - def _write_gcsfuse_credentials(self, credentials: GCPUserCredentials, mount_base_path_data: str) -> str: + def _write_gcsfuse_credentials(self, credentials: Dict[str, str], mount_base_path_data: str) -> str: if mount_base_path_data not in self._gcsfuse_credential_files: with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False) as credsfile: - credsfile.write(credentials.key) + credsfile.write(base64.b64decode(credentials['key.json']).decode()) self._gcsfuse_credential_files[mount_base_path_data] = credsfile.name return self._gcsfuse_credential_files[mount_base_path_data] async def _mount_cloudfuse( self, - credentials: GCPUserCredentials, + credentials: Dict[str, str], mount_base_path_data: str, mount_base_path_tmp: str, config: dict, ): # pylint: disable=unused-argument - fuse_credentials_path = self._write_gcsfuse_credentials(credentials, mount_base_path_data) bucket = config['bucket'] @@ -124,7 +139,7 @@ async def unmount_cloudfuse(self, mount_base_path_data: str): del self._gcsfuse_credential_files[mount_base_path_data] async def close(self): - await self._compute_client.close() + await self._exit_stack.aclose() def __str__(self): return 
f'project={self.project} zone={self.zone}' diff --git a/batch/batch/cloud/terra/__init__.py b/batch/batch/cloud/terra/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/batch/batch/cloud/terra/azure/__init__.py b/batch/batch/cloud/terra/azure/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/batch/batch/cloud/terra/azure/driver/__init__.py b/batch/batch/cloud/terra/azure/driver/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/batch/batch/cloud/terra/azure/driver/driver.py b/batch/batch/cloud/terra/azure/driver/driver.py new file mode 100644 index 00000000000..fb683129b75 --- /dev/null +++ b/batch/batch/cloud/terra/azure/driver/driver.py @@ -0,0 +1,529 @@ +import asyncio +import base64 +import json +import logging +import os +import uuid +from shlex import quote as shq +from typing import List, Tuple + +import aiohttp + +from gear import Database +from gear.cloud_config import get_azure_config +from hailtop import aiotools +from hailtop.aiocloud.aioazure import AzurePricingClient +from hailtop.aiocloud.aioterra.azure import TerraClient +from hailtop.config import get_deploy_config +from hailtop.config.deploy_config import TerraDeployConfig +from hailtop.utils import parse_timestamp_msecs, periodically_call, secret_alnum_string + +from .....batch_configuration import DOCKER_PREFIX, INTERNAL_GATEWAY_IP +from .....driver.driver import CloudDriver +from .....driver.instance import Instance +from .....driver.instance_collection import InstanceCollectionManager, JobPrivateInstanceManager, Pool +from .....driver.location import CloudLocationMonitor +from .....driver.resource_manager import ( + CloudResourceManager, + UnknownVMState, + VMDoesNotExist, + VMState, + VMStateCreating, + VMStateRunning, + VMStateTerminated, +) +from .....file_store import FileStore +from .....inst_coll_config import InstanceCollectionConfigs +from .....instance_config import InstanceConfig, QuantifiedResource +from ....azure.driver.billing_manager import AzureBillingManager +from ....azure.resource_utils import ( + azure_machine_type_to_worker_type_and_cores, + azure_worker_memory_per_core_mib, + azure_worker_properties_to_machine_type, +) +from ....utils import ACCEPTABLE_QUERY_JAR_URL_PREFIX +from ..instance_config import TerraAzureSlimInstanceConfig + +log = logging.getLogger('driver') + +TERRA_AZURE_INSTANCE_CONFIG_VERSION = 1 + +deploy_config = get_deploy_config() + + +class SingleRegionMonitor(CloudLocationMonitor): + @staticmethod + async def create(default_region: str) -> 'SingleRegionMonitor': + return SingleRegionMonitor(default_region) + + def __init__(self, default_region: str): + self._default_region = default_region + + def default_location(self) -> str: + return self._default_region + + def choose_location( + self, + cores: int, + local_ssd_data_disk: bool, + data_disk_size_gb: int, + preemptible: bool, + regions: List[str], + machine_type: str, + ) -> str: + return self._default_region + + +def create_vm_config( + file_store: FileStore, + location: str, + machine_name: str, + machine_type: str, + activation_token: str, + max_idle_time_msecs: int, + instance_config: InstanceConfig, +): + BATCH_WORKER_IMAGE = os.environ['HAIL_BATCH_WORKER_IMAGE'] + TERRA_STORAGE_ACCOUNT = os.environ['TERRA_STORAGE_ACCOUNT'] + WORKSPACE_STORAGE_CONTAINER_ID = os.environ['WORKSPACE_STORAGE_CONTAINER_ID'] + WORKSPACE_STORAGE_CONTAINER_URL = os.environ['WORKSPACE_STORAGE_CONTAINER_URL'] + WORKSPACE_MANAGER_URL = os.environ['WORKSPACE_MANAGER_URL'] 
+ WORKSPACE_ID = os.environ['WORKSPACE_ID'] + + instance_config_base64 = base64.b64encode(json.dumps(instance_config.to_dict()).encode()).decode() + + assert isinstance(deploy_config, TerraDeployConfig) + assert isinstance(instance_config, TerraAzureSlimInstanceConfig) + + startup_script = rf"""#cloud-config + +mounts: + - [ ephemeral0, null ] + - [ ephemeral0.1, null ] + +write_files: + - owner: batch-worker:batch-worker + path: /startup.sh + content: | + #!/bin/bash + + set -ex + + function cleanup() {{ + set +x + sleep 1000 + token=$(az account get-access-token --query accessToken --output tsv) + + VM_RESOURCE_ID={ shq(instance_config._resource_id) } + curl -X POST "{ shq(WORKSPACE_MANAGER_URL) }/api/workspaces/v1/$WORKSPACE_ID/resources/controlled/azure/vm/$VM_RESOURCE_ID" \ + -H "accept: */*" \ + -H "Authorization: Bearer $token" \ + -H "Content-Type: application/json" \ + -d "{{\"jobControl\":{{\"id\":\"$VM_RESOURCE_ID\"}}}}" + }} + + trap cleanup EXIT + + apt-get update + apt-get -o DPkg::Lock::Timeout=60 install -y \ + apt-transport-https \ + ca-certificates \ + curl \ + gnupg \ + jq \ + lsb-release \ + software-properties-common + + curl --connect-timeout 5 \ + --max-time 10 \ + --retry 5 \ + --retry-max-time 40 \ + --location \ + --fail \ + --silent \ + --show-error \ + https://download.docker.com/linux/ubuntu/gpg | apt-key add - + + add-apt-repository \ + "deb [arch=amd64] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) \ + stable" + + apt-get install -y docker-ce + + curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash + + az login --identity --allow-no-subscription + + # avoid "unable to get current user home directory: os/user lookup failed" + export HOME=/root + + UNRESERVED_WORKER_DATA_DISK_SIZE_GB=50 + ACCEPTABLE_QUERY_JAR_URL_PREFIX={ shq(ACCEPTABLE_QUERY_JAR_URL_PREFIX) } + + sudo mkdir -p /host/batch + sudo mkdir -p /host/logs + sudo mkdir -p /host/cloudfuse + + sudo mkdir -p /etc/netns + + sudo mkdir /deploy-config + sudo cat >/deploy-config/deploy-config.json <<'EOF' + { json.dumps(get_deploy_config().with_location('external').get_config()) } + EOF + + + SUBSCRIPTION_ID=$(curl -s -H Metadata:true --noproxy "*" "http://169.254.169.254/metadata/instance/compute/subscriptionId?api-version=2021-02-01&format=text") + RESOURCE_GROUP=$(curl -s -H Metadata:true --noproxy "*" "http://169.254.169.254/metadata/instance/compute/resourceGroupName?api-version=2021-02-01&format=text") + LOCATION=$(curl -s -H Metadata:true --noproxy "*" "http://169.254.169.254/metadata/instance/compute/location?api-version=2021-02-01&format=text") + NAME=$(curl -s -H Metadata:true --noproxy "*" "http://169.254.169.254/metadata/instance/compute/name?api-version=2021-02-01&format=text") + IP_ADDRESS=$(curl -s -H Metadata:true --noproxy "*" "http://169.254.169.254/metadata/instance/network/interface/0/ipv4/ipAddress/0/privateIpAddress?api-version=2021-02-01&format=text") + + CORES=$(nproc) + NAMESPACE=default + ACTIVATION_TOKEN={ shq(activation_token) } + BATCH_LOGS_STORAGE_URI={ shq(file_store.batch_logs_storage_uri) } + INSTANCE_ID={ shq(file_store.instance_id) } + INSTANCE_CONFIG="{ instance_config_base64 }" + MAX_IDLE_TIME_MSECS={ max_idle_time_msecs } + BATCH_WORKER_IMAGE={ shq(BATCH_WORKER_IMAGE) } + INTERNET_INTERFACE=eth0 + WORKSPACE_STORAGE_CONTAINER_ID={ shq(WORKSPACE_STORAGE_CONTAINER_ID) } + TERRA_STORAGE_ACCOUNT={ shq(TERRA_STORAGE_ACCOUNT) } + WORKSPACE_STORAGE_CONTAINER_URL={ shq(WORKSPACE_STORAGE_CONTAINER_URL) } + WORKSPACE_MANAGER_URL={ shq(WORKSPACE_MANAGER_URL) } + 
WORKSPACE_ID={ shq(WORKSPACE_ID) } + REGION={ shq(instance_config.region_for(location)) } + INTERNAL_GATEWAY_IP={ shq(INTERNAL_GATEWAY_IP) } + DOCKER_PREFIX={ shq(DOCKER_PREFIX) } + + # private job network = 172.20.0.0/16 + # public job network = 172.21.0.0/16 + # [all networks] Rewrite traffic coming from containers to masquerade as the host + iptables --table nat --append POSTROUTING --source 172.20.0.0/15 --jump MASQUERADE + + # retry once + docker pull $BATCH_WORKER_IMAGE || \ + (echo 'pull failed, retrying' && sleep 15 && docker pull $BATCH_WORKER_IMAGE) + + BATCH_WORKER_IMAGE_ID=$(docker inspect $BATCH_WORKER_IMAGE --format='{{{{.Id}}}}' | cut -d':' -f2) + + # So here I go it's my shot. + docker run \ + -e CLOUD=azure \ + -e CORES=$CORES \ + -e NAME=$NAME \ + -e NAMESPACE=$NAMESPACE \ + -e ACTIVATION_TOKEN=$ACTIVATION_TOKEN \ + -e IP_ADDRESS=$IP_ADDRESS \ + -e BATCH_LOGS_STORAGE_URI=$BATCH_LOGS_STORAGE_URI \ + -e INSTANCE_ID=$INSTANCE_ID \ + -e SUBSCRIPTION_ID=$SUBSCRIPTION_ID \ + -e RESOURCE_GROUP=$RESOURCE_GROUP \ + -e LOCATION=$LOCATION \ + -e INSTANCE_CONFIG=$INSTANCE_CONFIG \ + -e MAX_IDLE_TIME_MSECS=$MAX_IDLE_TIME_MSECS \ + -e BATCH_WORKER_IMAGE=$BATCH_WORKER_IMAGE \ + -e BATCH_WORKER_IMAGE_ID=$BATCH_WORKER_IMAGE_ID \ + -e INTERNET_INTERFACE=$INTERNET_INTERFACE \ + -e INTERNAL_GATEWAY_IP=$INTERNAL_GATEWAY_IP \ + -e DOCKER_PREFIX=$DOCKER_PREFIX \ + -e HAIL_TERRA=true \ + -e WORKSPACE_STORAGE_CONTAINER_ID=$WORKSPACE_STORAGE_CONTAINER_ID \ + -e WORKSPACE_STORAGE_CONTAINER_URL=$WORKSPACE_STORAGE_CONTAINER_URL \ + -e TERRA_STORAGE_ACCOUNT=$TERRA_STORAGE_ACCOUNT \ + -e WORKSPACE_MANAGER_URL=$WORKSPACE_MANAGER_URL \ + -e WORKSPACE_ID=$WORKSPACE_ID \ + -e UNRESERVED_WORKER_DATA_DISK_SIZE_GB=$UNRESERVED_WORKER_DATA_DISK_SIZE_GB \ + -e ACCEPTABLE_QUERY_JAR_URL_PREFIX=$ACCEPTABLE_QUERY_JAR_URL_PREFIX \ + -e REGION=$REGION \ + -v /var/run/docker.sock:/var/run/docker.sock \ + -v /var/run/netns:/var/run/netns:shared \ + -v /usr/bin/docker:/usr/bin/docker \ + -v /usr/sbin/xfs_quota:/usr/sbin/xfs_quota \ + -v /batch:/batch:shared \ + -v /logs:/logs \ + -v /global-config:/global-config \ + -v /cloudfuse:/cloudfuse:shared \ + -v /etc/netns:/etc/netns \ + -v /sys/fs/cgroup:/sys/fs/cgroup \ + --mount type=bind,source=/host,target=/host \ + --mount type=bind,source=/dev,target=/dev,bind-propagation=rshared \ + -p 5000:5000 \ + --device /dev/fuse \ + --device /dev \ + --privileged \ + --cap-add SYS_ADMIN \ + --security-opt apparmor:unconfined \ + --network host \ + $BATCH_WORKER_IMAGE \ + python3 -u -m batch.worker.worker + + +runcmd: + - nohup bash /startup.sh 2>&1 >worker.log & + """ + + encoded_startup_script = base64.b64encode(startup_script.encode()).decode() + + config = { + 'common': { + 'name': machine_name, + 'description': machine_name, + 'cloningInstructions': 'COPY_NOTHING', + 'accessScope': 'PRIVATE_ACCESS', + 'managedBy': 'USER', + 'resourceId': instance_config._resource_id, + 'properties': [], + }, + 'azureVm': { + 'name': machine_name, + 'vmSize': machine_type, + 'vmImage': { + 'publisher': 'Canonical', + 'offer': '0001-com-ubuntu-server-focal', + 'sku': '20_04-lts-gen2', + 'version': '20.04.202305150', + }, + 'vmUser': { + 'name': 'hail-admin', + 'password': secret_alnum_string(), + }, + 'ephemeralOSDisk': 'NONE', + 'customData': encoded_startup_script, + }, + 'jobControl': { + 'id': machine_name[32:], + }, + } + + return config + + +class TerraAzureResourceManager(CloudResourceManager): + def __init__( + self, + billing_manager, + ): + self.terra_client = TerraClient() + 
self.billing_manager = billing_manager + + async def delete_vm(self, instance: Instance): + config = instance.instance_config + assert isinstance(config, TerraAzureSlimInstanceConfig) + terra_vm_resource_id = config._resource_id + + try: + await self.terra_client.post( + f'/vm/{terra_vm_resource_id}', + json={ + 'jobControl': {'id': str(uuid.uuid4())}, + }, + ) + except aiohttp.ClientResponseError as e: + if e.status == 404: + raise VMDoesNotExist() from e + raise + + async def get_vm_state(self, instance: Instance) -> VMState: + try: + spec = await self.terra_client.get(f'/vm/create-result/{instance.name[32:]}') + state = spec['metadata']['state'] + if state == 'CREATING': + return VMStateCreating(spec, instance.time_created) + if state == 'READY': + last_start_timestamp_msecs = parse_timestamp_msecs(spec['metadata'].get('lastUpdatedDate')) + assert last_start_timestamp_msecs is not None + return VMStateRunning(spec, last_start_timestamp_msecs) + if state == 'DELETING': + return VMStateTerminated(spec) + return UnknownVMState(spec) + except aiohttp.ClientResponseError as e: + if e.status == 404: + raise VMDoesNotExist() from e + raise + + def machine_type(self, cores: int, worker_type: str, local_ssd: bool) -> str: + return azure_worker_properties_to_machine_type(worker_type, cores, local_ssd) + + def worker_type_and_cores(self, machine_type: str) -> Tuple[str, int]: + return azure_machine_type_to_worker_type_and_cores(machine_type) + + def instance_config( + self, + machine_type: str, + preemptible: bool, + local_ssd_data_disk: bool, + data_disk_size_gb: int, + boot_disk_size_gb: int, + job_private: bool, + location: str, + ) -> TerraAzureSlimInstanceConfig: + return TerraAzureSlimInstanceConfig.create( + self.billing_manager.product_versions, + machine_type, + preemptible, + local_ssd_data_disk, + data_disk_size_gb, + boot_disk_size_gb, + job_private, + location, + ) + + def instance_config_from_dict(self, data: dict) -> TerraAzureSlimInstanceConfig: + return TerraAzureSlimInstanceConfig.from_dict(data) + + async def create_vm( + self, + file_store: FileStore, + machine_name: str, + activation_token: str, + max_idle_time_msecs: int, + local_ssd_data_disk: bool, + data_disk_size_gb: int, + boot_disk_size_gb: int, + preemptible: bool, + job_private: bool, + location: str, + machine_type: str, + instance_config: InstanceConfig, + ) -> List[QuantifiedResource]: + assert isinstance(instance_config, TerraAzureSlimInstanceConfig) + worker_type, cores = self.worker_type_and_cores(machine_type) + + memory_mib = azure_worker_memory_per_core_mib(worker_type) * cores + memory_in_bytes = memory_mib << 20 + cores_mcpu = cores * 1000 + total_resources_on_instance = instance_config.quantified_resources( + cpu_in_mcpu=cores_mcpu, memory_in_bytes=memory_in_bytes, extra_storage_in_gib=0 + ) + + if not local_ssd_data_disk: + raise ValueError('VMs without a local ssd data disk are not yet supported') + + vm_config = create_vm_config( + file_store, + location, + machine_name, + machine_type, + activation_token, + max_idle_time_msecs, + instance_config, + ) + + try: + res = await self.terra_client.post('/vm', json=vm_config) + log.info(f'Terra response creating machine {machine_name}: {res}') + except Exception: + log.exception(f'error while creating machine {machine_name}') + return total_resources_on_instance + + +class TerraAzureDriver(CloudDriver): + @staticmethod + async def create( + app, + db: Database, # BORROWED + machine_name_prefix: str, + namespace: str, + inst_coll_configs: 
InstanceCollectionConfigs, + ) -> 'TerraAzureDriver': + azure_config = get_azure_config() + region = azure_config.region + regions = [region] + + region_args = [(r,) for r in regions] + await db.execute_many( + """ +INSERT INTO regions (region) VALUES (%s) +ON DUPLICATE KEY UPDATE region = region; +""", + region_args, + ) + + db_regions = { + record['region']: record['region_id'] + async for record in db.select_and_fetchall('SELECT region_id, region from regions') + } + assert max(db_regions.values()) < 64, str(db_regions) + app['regions'] = db_regions + + region_monitor = await SingleRegionMonitor.create(region) + inst_coll_manager = InstanceCollectionManager(db, machine_name_prefix, region_monitor, region, regions) + pricing_client = AzurePricingClient() + billing_manager = await AzureBillingManager.create(db, pricing_client, regions) + resource_manager = TerraAzureResourceManager(billing_manager) + task_manager = aiotools.BackgroundTaskManager() + task_manager.ensure_future(periodically_call(300, billing_manager.refresh_resources_from_retail_prices)) + + create_pools_coros = [ + Pool.create( + app, + db, + inst_coll_manager, + resource_manager, + machine_name_prefix, + config, + app['async_worker_pool'], + task_manager, + ) + for config in inst_coll_configs.name_pool_config.values() + ] + + jpim, *_ = await asyncio.gather( + JobPrivateInstanceManager.create( + app, + db, + inst_coll_manager, + resource_manager, + machine_name_prefix, + inst_coll_configs.jpim_config, + task_manager, + ), + *create_pools_coros, + ) + + return TerraAzureDriver( + db, + machine_name_prefix, + namespace, + region_monitor, + inst_coll_manager, + jpim, + billing_manager, + task_manager, + ) + + def __init__( + self, + db: Database, + machine_name_prefix: str, + namespace: str, + region_monitor: SingleRegionMonitor, + inst_coll_manager: InstanceCollectionManager, + job_private_inst_manager: JobPrivateInstanceManager, + billing_manager: AzureBillingManager, + task_manager: aiotools.BackgroundTaskManager, + ): + self.db = db + self.machine_name_prefix = machine_name_prefix + self.namespace = namespace + self.region_monitor = region_monitor + self.job_private_inst_manager = job_private_inst_manager + self._inst_coll_manager = inst_coll_manager + self._billing_manager = billing_manager + self._task_manager = task_manager + + @property + def billing_manager(self) -> AzureBillingManager: + return self._billing_manager + + @property + def inst_coll_manager(self) -> InstanceCollectionManager: + return self._inst_coll_manager + + async def shutdown(self) -> None: + await self._task_manager.shutdown_and_wait() + + def get_quotas(self): + raise NotImplementedError diff --git a/batch/batch/cloud/terra/azure/instance_config.py b/batch/batch/cloud/terra/azure/instance_config.py new file mode 100644 index 00000000000..f16d527a8f4 --- /dev/null +++ b/batch/batch/cloud/terra/azure/instance_config.py @@ -0,0 +1,83 @@ +import uuid +from typing import List + +from hailtop.utils import filter_none + +from ....driver.billing_manager import ProductVersions +from ...azure.instance_config import AzureSlimInstanceConfig +from ...azure.resources import AzureResource, AzureStaticSizedDiskResource, AzureVMResource, azure_resource_from_dict + + +class TerraAzureSlimInstanceConfig(AzureSlimInstanceConfig): + @staticmethod + def create( + product_versions: ProductVersions, + machine_type: str, + preemptible: bool, + local_ssd_data_disk: bool, + data_disk_size_gb: int, + boot_disk_size_gb: int, + job_private: bool, + location: str, + ) -> 
'TerraAzureSlimInstanceConfig': + resource_id = str(uuid.uuid4()) + + resources: List[AzureResource] = filter_none([ + AzureVMResource.create(product_versions, machine_type, preemptible, location), + AzureStaticSizedDiskResource.create(product_versions, 'E', boot_disk_size_gb, location), + ]) + + return TerraAzureSlimInstanceConfig( + machine_type=machine_type, + preemptible=preemptible, + local_ssd_data_disk=local_ssd_data_disk, + data_disk_size_gb=data_disk_size_gb, + boot_disk_size_gb=boot_disk_size_gb, + job_private=job_private, + resources=resources, + resource_id=resource_id, + ) + + def __init__( + self, + machine_type: str, + preemptible: bool, + local_ssd_data_disk: bool, + data_disk_size_gb: int, + boot_disk_size_gb: int, + job_private: bool, + resources: List[AzureResource], + resource_id: str, + ): + super().__init__( + machine_type=machine_type, + preemptible=preemptible, + local_ssd_data_disk=local_ssd_data_disk, + data_disk_size_gb=data_disk_size_gb, + boot_disk_size_gb=boot_disk_size_gb, + job_private=job_private, + resources=resources, + ) + self._resource_id = resource_id + + @staticmethod + def from_dict(data: dict) -> 'TerraAzureSlimInstanceConfig': + resources = data.get('resources', []) + resources = [azure_resource_from_dict(resource) for resource in resources] + return TerraAzureSlimInstanceConfig( + data['machine_type'], + data['preemptible'], + data['local_ssd_data_disk'], + data['data_disk_size_gb'], + data['boot_disk_size_gb'], + data['job_private'], + resources, + data['resource_id'], + ) + + def to_dict(self) -> dict: + azure_dict = super().to_dict() + azure_dict.update({ + 'resource_id': self._resource_id, + }) + return azure_dict diff --git a/batch/batch/cloud/terra/azure/worker/__init__.py b/batch/batch/cloud/terra/azure/worker/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/batch/batch/cloud/terra/azure/worker/worker_api.py b/batch/batch/cloud/terra/azure/worker/worker_api.py new file mode 100644 index 00000000000..d0e55fe26b2 --- /dev/null +++ b/batch/batch/cloud/terra/azure/worker/worker_api.py @@ -0,0 +1,85 @@ +import os +from typing import Dict, List + +import orjson +from aiohttp import web + +from hailtop import httpx +from hailtop.aiocloud.aioazure import AzureCredentials +from hailtop.aiocloud.aioterra.azure import TerraAzureAsyncFS +from hailtop.aiotools.fs import AsyncFS +from hailtop.auth.auth import IdentityProvider + +from .....worker.disk import CloudDisk +from .....worker.worker_api import CloudWorkerAPI, ContainerRegistryCredentials +from ....terra.azure.instance_config import TerraAzureSlimInstanceConfig + + +class TerraAzureWorkerAPI(CloudWorkerAPI): + nameserver_ip = '168.63.129.16' + + @staticmethod + def from_env() -> 'TerraAzureWorkerAPI': + return TerraAzureWorkerAPI( + os.environ['WORKSPACE_STORAGE_CONTAINER_ID'], + os.environ['WORKSPACE_STORAGE_CONTAINER_URL'], + os.environ['WORKSPACE_ID'], + os.environ['WORKSPACE_MANAGER_URL'], + ) + + def __init__( + self, + workspace_storage_container_id: str, + workspace_storage_container_url: str, + workspace_id: str, + workspace_manager_url: str, + ): + self.workspace_storage_container_id = workspace_storage_container_id + self.workspace_storage_container_url = workspace_storage_container_url + self.workspace_id = workspace_id + self.workspace_manager_url = workspace_manager_url + self._managed_identity_credentials = AzureCredentials.default_credentials() + + @property + def cloud_specific_env_vars_for_user_jobs(self) -> List[str]: + idp_json = 
orjson.dumps({'idp': IdentityProvider.MICROSOFT.value}).decode('utf-8') + return [ + 'HAIL_TERRA=1', + 'HAIL_LOCATION=external', # There is no internal gateway, jobs must communicate over the internet + f'HAIL_IDENTITY_PROVIDER_JSON={idp_json}', + f'WORKSPACE_STORAGE_CONTAINER_ID={self.workspace_storage_container_id}', + f'WORKSPACE_STORAGE_CONTAINER_URL={self.workspace_storage_container_url}', + f'WORKSPACE_ID={self.workspace_id}', + f'WORKSPACE_MANAGER_URL={self.workspace_manager_url}', + ] + + def create_disk(self, *_) -> CloudDisk: + raise NotImplementedError + + def get_cloud_async_fs(self) -> AsyncFS: + return TerraAzureAsyncFS() + + async def worker_container_registry_credentials(self, session: httpx.ClientSession) -> ContainerRegistryCredentials: + return {} + + async def user_container_registry_credentials(self, credentials: Dict[str, str]) -> ContainerRegistryCredentials: + return {} + + def create_metadata_server_app(self, credentials: Dict[str, str]) -> web.Application: + raise NotImplementedError + + def instance_config_from_config_dict(self, config_dict: Dict[str, str]) -> TerraAzureSlimInstanceConfig: + return TerraAzureSlimInstanceConfig.from_dict(config_dict) + + async def extra_hail_headers(self) -> Dict[str, str]: + token = await self._managed_identity_credentials.access_token() + return {'Authorization': f'Bearer {token}'} + + async def _mount_cloudfuse(self, *_): + raise NotImplementedError + + async def unmount_cloudfuse(self, mount_base_path_data: str): + raise NotImplementedError + + async def close(self): + pass diff --git a/batch/batch/cloud/utils.py b/batch/batch/cloud/utils.py index c2c1a1aea04..2ed1cd91191 100644 --- a/batch/batch/cloud/utils.py +++ b/batch/batch/cloud/utils.py @@ -8,11 +8,14 @@ from ..instance_config import InstanceConfig from .azure.instance_config import AzureSlimInstanceConfig from .gcp.instance_config import GCPSlimInstanceConfig +from .terra.azure.instance_config import TerraAzureSlimInstanceConfig def instance_config_from_config_dict(config: Dict[str, Any]) -> InstanceConfig: cloud = config.get('cloud', 'gcp') if cloud == 'azure': + if os.environ.get('HAIL_TERRA'): + return TerraAzureSlimInstanceConfig.from_dict(config) return AzureSlimInstanceConfig.from_dict(config) assert cloud == 'gcp' return GCPSlimInstanceConfig.from_dict(config) diff --git a/batch/batch/constants.py b/batch/batch/constants.py deleted file mode 100644 index 76800e53aee..00000000000 --- a/batch/batch/constants.py +++ /dev/null @@ -1 +0,0 @@ -ROOT_JOB_GROUP_ID = 0 diff --git a/batch/batch/driver/billing_manager.py b/batch/batch/driver/billing_manager.py index b901272d73b..d322fe3a418 100644 --- a/batch/batch/driver/billing_manager.py +++ b/batch/batch/driver/billing_manager.py @@ -151,39 +151,37 @@ async def _refresh_resources_from_retail_prices(self, prices: List[Price]): @transaction(self.db) async def insert_or_update(tx): if resource_updates: - last_resource_id = await tx.execute_and_fetchone( - ''' + last_resource_id = await tx.execute_and_fetchone(""" SELECT COALESCE(MAX(resource_id), 0) AS last_resource_id FROM resources FOR UPDATE -''' - ) +""") last_resource_id = last_resource_id['last_resource_id'] await tx.execute_many( - ''' + """ INSERT INTO `resources` (resource, rate) VALUES (%s, %s) -''', +""", resource_updates, ) await tx.execute_update( - ''' + """ UPDATE resources SET deduped_resource_id = resource_id WHERE resource_id > %s AND deduped_resource_id IS NULL -''', +""", (last_resource_id,), ) if product_version_updates: await tx.execute_many( - ''' + 
""" INSERT INTO `latest_product_versions` (product, version, sku) VALUES (%s, %s, %s) ON DUPLICATE KEY UPDATE version = VALUES(version) -''', +""", product_version_updates, ) diff --git a/batch/batch/driver/canceller.py b/batch/batch/driver/canceller.py index 4ee7f0e51c1..3b65b9a5cd6 100644 --- a/batch/batch/driver/canceller.py +++ b/batch/batch/driver/canceller.py @@ -75,12 +75,12 @@ async def shutdown_and_wait(self): async def cancel_cancelled_ready_jobs_loop_body(self): records = self.db.select_and_fetchall( - ''' + """ SELECT user, CAST(COALESCE(SUM(n_cancelled_ready_jobs), 0) AS SIGNED) AS n_cancelled_ready_jobs FROM user_inst_coll_resources GROUP BY user HAVING n_cancelled_ready_jobs > 0; -''', +""", ) user_n_cancelled_ready_jobs = {record['user']: record['n_cancelled_ready_jobs'] async for record in records} @@ -94,39 +94,44 @@ async def cancel_cancelled_ready_jobs_loop_body(self): } async def user_cancelled_ready_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]: - async for batch in self.db.select_and_fetchall( - ''' -SELECT batches.id, job_groups_cancelled.id IS NOT NULL AS cancelled -FROM batches -LEFT JOIN job_groups_cancelled - ON batches.id = job_groups_cancelled.id + async for job_group in self.db.select_and_fetchall( + """ +SELECT job_groups.batch_id, job_groups.job_group_id, t.cancelled IS NOT NULL AS cancelled +FROM job_groups +LEFT JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors + INNER JOIN job_groups_cancelled + ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id + WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND + job_groups.job_group_id = job_group_self_and_ancestors.job_group_id +) AS t ON TRUE WHERE user = %s AND `state` = 'running'; -''', +""", (user,), ): - if batch['cancelled']: + if job_group['cancelled']: async for record in self.db.select_and_fetchall( - ''' -SELECT jobs.job_id + """ +SELECT jobs.batch_id, jobs.job_id, jobs.job_group_id FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) -WHERE batch_id = %s AND state = 'Ready' AND always_run = 0 +WHERE batch_id = %s AND job_group_id = %s AND state = 'Ready' AND always_run = 0 LIMIT %s; -''', - (batch['id'], remaining.value), +""", + (job_group['batch_id'], job_group['job_group_id'], remaining.value), ): - record['batch_id'] = batch['id'] yield record else: async for record in self.db.select_and_fetchall( - ''' -SELECT jobs.job_id + """ +SELECT jobs.batch_id, jobs.job_id, jobs.job_group_id FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) -WHERE batch_id = %s AND state = 'Ready' AND always_run = 0 AND cancelled = 1 +WHERE batch_id = %s AND job_group_id = %s AND state = 'Ready' AND always_run = 0 AND cancelled = 1 LIMIT %s; -''', - (batch['id'], remaining.value), +""", + (job_group['batch_id'], job_group['job_group_id'], remaining.value), ): - record['batch_id'] = batch['id'] yield record waitable_pool = WaitableSharedPool(self.async_worker_pool) @@ -137,18 +142,30 @@ async def user_cancelled_ready_jobs(user, remaining) -> AsyncIterator[Dict[str, async for record in user_cancelled_ready_jobs(user, remaining): batch_id = record['batch_id'] job_id = record['job_id'] + job_group_id = record['job_group_id'] id = (batch_id, job_id) log.info(f'cancelling job {id}') - async def cancel_with_error_handling(app, batch_id, job_id, id): + async def cancel_with_error_handling(app, batch_id, job_id, job_group_id, id): try: await mark_job_complete( - 
app, batch_id, job_id, None, None, 'Cancelled', None, None, None, 'cancelled', [] + app, + batch_id, + job_id, + None, + job_group_id, + None, + 'Cancelled', + None, + None, + None, + 'cancelled', + [], ) except Exception: log.info(f'error while cancelling job {id}', exc_info=True) - await waitable_pool.call(cancel_with_error_handling, self.app, batch_id, job_id, id) + await waitable_pool.call(cancel_with_error_handling, self.app, batch_id, job_id, job_group_id, id) remaining.value -= 1 if remaining.value <= 0: @@ -161,12 +178,12 @@ async def cancel_with_error_handling(app, batch_id, job_id, id): async def cancel_cancelled_creating_jobs_loop_body(self): records = self.db.select_and_fetchall( - ''' + """ SELECT user, CAST(COALESCE(SUM(n_cancelled_creating_jobs), 0) AS SIGNED) AS n_cancelled_creating_jobs FROM user_inst_coll_resources GROUP BY user HAVING n_cancelled_creating_jobs > 0; -''', +""", ) user_n_cancelled_creating_jobs = { record['user']: record['n_cancelled_creating_jobs'] async for record in records @@ -182,28 +199,34 @@ async def cancel_cancelled_creating_jobs_loop_body(self): } async def user_cancelled_creating_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]: - async for batch in self.db.select_and_fetchall( - ''' -SELECT batches.id -FROM batches -INNER JOIN job_groups_cancelled - ON batches.id = job_groups_cancelled.id + async for job_group in self.db.select_and_fetchall( + """ +SELECT job_groups.batch_id, job_groups.job_group_id +FROM job_groups +INNER JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors + INNER JOIN job_groups_cancelled + ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id + WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND + job_groups.job_group_id = job_group_self_and_ancestors.job_group_id +) AS t ON TRUE WHERE user = %s AND `state` = 'running'; -''', +""", (user,), ): async for record in self.db.select_and_fetchall( - ''' -SELECT jobs.job_id, attempts.attempt_id, attempts.instance_name + """ +SELECT jobs.batch_id, jobs.job_id, attempts.attempt_id, attempts.instance_name, jobs.job_group_id FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) STRAIGHT_JOIN attempts ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id -WHERE jobs.batch_id = %s AND state = 'Creating' AND always_run = 0 AND cancelled = 0 +WHERE jobs.batch_id = %s AND jobs.job_group_id = %s AND state = 'Creating' AND always_run = 0 AND cancelled = 0 LIMIT %s; -''', - (batch['id'], remaining.value), +""", + (job_group['batch_id'], job_group['job_group_id'], remaining.value), ): - record['batch_id'] = batch['id'] yield record waitable_pool = WaitableSharedPool(self.async_worker_pool) @@ -215,10 +238,13 @@ async def user_cancelled_creating_jobs(user, remaining) -> AsyncIterator[Dict[st batch_id = record['batch_id'] job_id = record['job_id'] attempt_id = record['attempt_id'] + job_group_id = record['job_group_id'] instance_name = record['instance_name'] id = (batch_id, job_id) - async def cancel_with_error_handling(app, batch_id, job_id, attempt_id, instance_name, id): + async def cancel_with_error_handling( + app, batch_id, job_id, attempt_id, job_group_id, instance_name, id + ): try: end_time = time_msecs() await mark_job_complete( @@ -226,6 +252,7 @@ async def cancel_with_error_handling(app, batch_id, job_id, attempt_id, instance batch_id, job_id, attempt_id, + job_group_id, instance_name, 'Cancelled', None, @@ -246,7 
+273,7 @@ async def cancel_with_error_handling(app, batch_id, job_id, attempt_id, instance log.info(f'cancelling creating job {id} on instance {instance_name}', exc_info=True) await waitable_pool.call( - cancel_with_error_handling, self.app, batch_id, job_id, attempt_id, instance_name, id + cancel_with_error_handling, self.app, batch_id, job_id, attempt_id, job_group_id, instance_name, id ) remaining.value -= 1 @@ -260,12 +287,12 @@ async def cancel_with_error_handling(app, batch_id, job_id, attempt_id, instance async def cancel_cancelled_running_jobs_loop_body(self): records = self.db.select_and_fetchall( - ''' + """ SELECT user, CAST(COALESCE(SUM(n_cancelled_running_jobs), 0) AS SIGNED) AS n_cancelled_running_jobs FROM user_inst_coll_resources GROUP BY user HAVING n_cancelled_running_jobs > 0; -''', +""", ) user_n_cancelled_running_jobs = {record['user']: record['n_cancelled_running_jobs'] async for record in records} @@ -279,28 +306,34 @@ async def cancel_cancelled_running_jobs_loop_body(self): } async def user_cancelled_running_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]: - async for batch in self.db.select_and_fetchall( - ''' -SELECT batches.id -FROM batches -INNER JOIN job_groups_cancelled - ON batches.id = job_groups_cancelled.id + async for job_group in self.db.select_and_fetchall( + """ +SELECT job_groups.batch_id, job_groups.job_group_id +FROM job_groups +INNER JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors + INNER JOIN job_groups_cancelled + ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id + WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND + job_groups.job_group_id = job_group_self_and_ancestors.job_group_id +) AS t ON TRUE WHERE user = %s AND `state` = 'running'; -''', +""", (user,), ): async for record in self.db.select_and_fetchall( - ''' -SELECT jobs.job_id, attempts.attempt_id, attempts.instance_name + """ +SELECT jobs.batch_id, jobs.job_id, attempts.attempt_id, attempts.instance_name FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) STRAIGHT_JOIN attempts ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id -WHERE jobs.batch_id = %s AND state = 'Running' AND always_run = 0 AND cancelled = 0 +WHERE jobs.batch_id = %s AND jobs.job_group_id = %s AND state = 'Running' AND always_run = 0 AND cancelled = 0 LIMIT %s; -''', - (batch['id'], remaining.value), +""", + (job_group['batch_id'], job_group['job_group_id'], remaining.value), ): - record['batch_id'] = batch['id'] yield record waitable_pool = WaitableSharedPool(self.async_worker_pool) @@ -336,7 +369,7 @@ async def cancel_orphaned_attempts_loop_body(self): n_unscheduled = 0 async for record in self.db.select_and_fetchall( - ''' + """ SELECT attempts.* FROM attempts INNER JOIN jobs ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id @@ -347,7 +380,7 @@ async def cancel_orphaned_attempts_loop_body(self): AND instances.`state` = 'active' ORDER BY attempts.start_time ASC LIMIT 300; -''', +""", ): batch_id = record['batch_id'] job_id = record['job_id'] diff --git a/batch/batch/driver/instance.py b/batch/batch/driver/instance.py index ab30e6484ad..5211e3143b6 100644 --- a/batch/batch/driver/instance.py +++ b/batch/batch/driver/instance.py @@ -6,8 +6,7 @@ import aiohttp -from gear import Database, transaction -from hailtop import httpx +from gear import CommonAiohttpAppKeys, Database, transaction from hailtop.humanizex 
import naturaldelta_msec from hailtop.utils import retry_transient_errors, time_msecs, time_msecs_str @@ -62,11 +61,11 @@ async def create( @transaction(db) async def insert(tx): await tx.just_execute( - ''' + """ INSERT INTO instances (name, state, activation_token, token, cores_mcpu, time_created, last_updated, version, location, inst_coll, machine_type, preemptible, instance_config) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); -''', +""", ( name, state, @@ -84,10 +83,10 @@ async def insert(tx): ), ) await tx.just_execute( - ''' + """ INSERT INTO instances_free_cores_mcpu (name, free_cores_mcpu) VALUES (%s, %s); -''', +""", ( name, worker_cores_mcpu, @@ -133,7 +132,7 @@ def __init__( instance_config: InstanceConfig, ): self.db: Database = app['db'] - self.client_session: httpx.ClientSession = app['client_session'] + self.client_session = app[CommonAiohttpAppKeys.CLIENT_SESSION] self.inst_coll = inst_coll # pending, active, inactive, deleted self._state = state @@ -311,22 +310,22 @@ async def mark_healthy(self): self.inst_coll.adjust_for_add_instance(self) await self.db.execute_update( - ''' + """ UPDATE instances SET last_updated = %s, failed_request_count = 0 WHERE name = %s; -''', +""", (now, self.name), 'mark_healthy', ) async def incr_failed_request_count(self): await self.db.execute_update( - ''' + """ UPDATE instances SET failed_request_count = failed_request_count + 1 WHERE name = %s; -''', +""", (self.name,), ) diff --git a/batch/batch/driver/instance_collection/job_private.py b/batch/batch/driver/instance_collection/job_private.py index d4800402cbc..0d4d336c92b 100644 --- a/batch/batch/driver/instance_collection/job_private.py +++ b/batch/batch/driver/instance_collection/job_private.py @@ -51,13 +51,13 @@ async def create( log.info(f'initializing {jpim}') async for record in db.select_and_fetchall( - ''' + """ SELECT instances.*, instances_free_cores_mcpu.free_cores_mcpu FROM instances INNER JOIN instances_free_cores_mcpu ON instances.name = instances_free_cores_mcpu.name WHERE removed = 0 AND inst_coll = %s; -''', +""", (jpim.name,), ): jpim.add_instance(Instance.from_record(app, jpim, record)) @@ -135,7 +135,7 @@ async def configure( worker_max_idle_time_secs, ): await self.db.just_execute( - ''' + """ UPDATE inst_colls SET boot_disk_size_gb = %s, max_instances = %s, @@ -144,7 +144,7 @@ async def configure( autoscaler_loop_period_secs = %s, worker_max_idle_time_secs = %s WHERE name = %s; -''', +""", ( boot_disk_size_gb, max_instances, @@ -177,21 +177,22 @@ async def schedule_jobs_loop_body(self): max_records = 300 async for record in self.db.select_and_fetchall( - ''' + """ SELECT jobs.*, batches.format_version, batches.userdata, batches.user, attempts.instance_name, time_ready -FROM batches -INNER JOIN jobs ON batches.id = jobs.batch_id +FROM job_groups +LEFT JOIN batches ON batches.id = job_groups.batch_id +LEFT JOIN jobs ON job_groups.batch_id = jobs.batch_id AND job_groups.job_group_id = jobs.job_group_id LEFT JOIN jobs_telemetry ON jobs.batch_id = jobs_telemetry.batch_id AND jobs.job_id = jobs_telemetry.job_id LEFT JOIN attempts ON jobs.batch_id = attempts.batch_id AND jobs.job_id = attempts.job_id LEFT JOIN instances ON attempts.instance_name = instances.name -WHERE batches.state = 'running' +WHERE job_groups.state = 'running' AND jobs.state = 'Creating' AND (jobs.always_run OR NOT jobs.cancelled) AND jobs.inst_coll = %s AND instances.`state` = 'active' ORDER BY instances.time_activated ASC LIMIT %s; -''', +""", (self.name, max_records), ): batch_id = 
record['batch_id'] @@ -241,7 +242,7 @@ async def compute_fair_share(self): allocating_users_by_total_jobs = sortedcontainers.SortedSet(key=lambda user: user_total_jobs[user]) records = self.db.execute_and_fetchall( - ''' + """ SELECT user, CAST(COALESCE(SUM(n_ready_jobs), 0) AS SIGNED) AS n_ready_jobs, CAST(COALESCE(SUM(n_creating_jobs), 0) AS SIGNED) AS n_creating_jobs, @@ -250,7 +251,7 @@ async def compute_fair_share(self): WHERE inst_coll = %s GROUP BY user HAVING n_ready_jobs + n_creating_jobs + n_running_jobs > 0; -''', +""", (self.name,), ) @@ -349,54 +350,62 @@ async def create_instances_loop_body(self): } async def user_runnable_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]: - async for batch in self.db.select_and_fetchall( - ''' -SELECT batches.id, job_groups_cancelled.id IS NOT NULL AS cancelled, userdata, user, format_version -FROM batches -LEFT JOIN job_groups_cancelled - ON batches.id = job_groups_cancelled.id -WHERE user = %s AND `state` = 'running'; -''', + async for job_group in self.db.select_and_fetchall( + """ +SELECT job_groups.batch_id, job_groups.job_group_id, t.cancelled IS NOT NULL AS cancelled, userdata, job_groups.user, format_version +FROM job_groups +LEFT JOIN batches ON batches.id = job_groups.batch_id +LEFT JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors + INNER JOIN job_groups_cancelled + ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id + WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND + job_groups.job_group_id = job_group_self_and_ancestors.job_group_id +) AS t ON TRUE +WHERE job_groups.user = %s AND job_groups.`state` = 'running'; +""", (user,), ): async for record in self.db.select_and_fetchall( - ''' + """ SELECT jobs.batch_id, jobs.job_id, jobs.spec, jobs.cores_mcpu, regions_bits_rep, COALESCE(SUM(instances.state IS NOT NULL AND - (instances.state = 'pending' OR instances.state = 'active')), 0) as live_attempts -FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_inst_coll_cancelled) + (instances.state = 'pending' OR instances.state = 'active')), 0) as live_attempts, jobs.job_group_id +FROM jobs FORCE INDEX(jobs_batch_id_ic_state_ar_n_regions_bits_rep_job_group_id) LEFT JOIN attempts ON jobs.batch_id = attempts.batch_id AND jobs.job_id = attempts.job_id LEFT JOIN instances ON attempts.instance_name = instances.name -WHERE jobs.batch_id = %s AND jobs.state = 'Ready' AND always_run = 1 AND jobs.inst_coll = %s +WHERE jobs.batch_id = %s AND jobs.job_group_id = %s AND jobs.state = 'Ready' AND always_run = 1 AND jobs.inst_coll = %s GROUP BY jobs.job_id, jobs.spec, jobs.cores_mcpu HAVING live_attempts = 0 LIMIT %s; -''', - (batch['id'], self.name, remaining.value), +""", + (job_group['batch_id'], job_group['job_group_id'], self.name, remaining.value), ): - record['batch_id'] = batch['id'] - record['userdata'] = batch['userdata'] - record['user'] = batch['user'] - record['format_version'] = batch['format_version'] + record['batch_id'] = job_group['batch_id'] + record['userdata'] = job_group['userdata'] + record['user'] = job_group['user'] + record['format_version'] = job_group['format_version'] yield record - if not batch['cancelled']: + if not job_group['cancelled']: async for record in self.db.select_and_fetchall( - ''' + """ SELECT jobs.batch_id, jobs.job_id, jobs.spec, jobs.cores_mcpu, regions_bits_rep, COALESCE(SUM(instances.state IS NOT NULL AND - (instances.state = 'pending' OR instances.state = 
'active')), 0) as live_attempts -FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) + (instances.state = 'pending' OR instances.state = 'active')), 0) as live_attempts, jobs.job_group_id +FROM jobs FORCE INDEX(jobs_batch_id_ic_state_ar_n_regions_bits_rep_job_group_id) LEFT JOIN attempts ON jobs.batch_id = attempts.batch_id AND jobs.job_id = attempts.job_id LEFT JOIN instances ON attempts.instance_name = instances.name -WHERE jobs.batch_id = %s AND jobs.state = 'Ready' AND always_run = 0 AND jobs.inst_coll = %s AND cancelled = 0 +WHERE jobs.batch_id = %s AND jobs.job_group_id = %s AND jobs.state = 'Ready' AND always_run = 0 AND jobs.inst_coll = %s AND cancelled = 0 GROUP BY jobs.job_id, jobs.spec, jobs.cores_mcpu HAVING live_attempts = 0 LIMIT %s -''', - (batch['id'], self.name, remaining.value), +""", + (job_group['batch_id'], job_group['job_group_id'], self.name, remaining.value), ): - record['batch_id'] = batch['id'] - record['userdata'] = batch['userdata'] - record['user'] = batch['user'] - record['format_version'] = batch['format_version'] + record['batch_id'] = job_group['batch_id'] + record['userdata'] = job_group['userdata'] + record['user'] = job_group['user'] + record['format_version'] = job_group['format_version'] yield record waitable_pool = WaitableSharedPool(self.async_worker_pool) @@ -420,6 +429,7 @@ async def user_runnable_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]: id = (batch_id, job_id) attempt_id = secret_alnum_string(6) record['attempt_id'] = attempt_id + job_group_id = record['job_group_id'] if n_user_instances_created >= n_allocated_instances: if random.random() > self.exceeded_shares_counter.rate(): @@ -435,7 +445,7 @@ async def user_runnable_jobs(user, remaining) -> AsyncIterator[Dict[str, Any]]: log.info(f'creating job private instance for job {id}') async def create_instance_with_error_handling( - batch_id: int, job_id: int, attempt_id: str, record: dict, id: Tuple[int, int] + batch_id: int, job_id: int, attempt_id: str, job_group_id: int, record: dict, id: Tuple[int, int] ): try: batch_format_version = BatchFormatVersion(record['format_version']) @@ -458,6 +468,7 @@ async def create_instance_with_error_handling( await mark_job_errored( self.app, batch_id, + job_group_id, job_id, attempt_id, record['user'], @@ -467,7 +478,9 @@ async def create_instance_with_error_handling( except Exception: log.exception(f'while creating job private instance for job {id}', exc_info=True) - await waitable_pool.call(create_instance_with_error_handling, batch_id, job_id, attempt_id, record, id) + await waitable_pool.call( + create_instance_with_error_handling, batch_id, job_id, attempt_id, job_group_id, record, id + ) remaining.value -= 1 if remaining.value <= 0: diff --git a/batch/batch/driver/instance_collection/pool.py b/batch/batch/driver/instance_collection/pool.py index 24adbcf9da8..42e2469e882 100644 --- a/batch/batch/driver/instance_collection/pool.py +++ b/batch/batch/driver/instance_collection/pool.py @@ -79,13 +79,13 @@ async def create( log.info(f'initializing {pool}') async for record in db.select_and_fetchall( - ''' + """ SELECT instances.*, instances_free_cores_mcpu.free_cores_mcpu FROM instances INNER JOIN instances_free_cores_mcpu ON instances.name = instances_free_cores_mcpu.name WHERE removed = 0 AND inst_coll = %s; -''', +""", (pool.name,), ): pool.add_instance(Instance.from_record(app, pool, record)) @@ -269,17 +269,15 @@ async def _create_instances( if n_instances > 0: log.info(f'creating {n_instances} new instances') # parallelism will 
be bounded by thread pool - await asyncio.gather( - *[ - self.create_instance( - cores=cores, - data_disk_size_gb=data_disk_size_gb, - regions=regions, - max_idle_time_msecs=max_idle_time_msecs, - ) - for _ in range(n_instances) - ] - ) + await asyncio.gather(*[ + self.create_instance( + cores=cores, + data_disk_size_gb=data_disk_size_gb, + regions=regions, + max_idle_time_msecs=max_idle_time_msecs, + ) + for _ in range(n_instances) + ]) async def create_instances_from_ready_cores( self, ready_cores_mcpu: int, regions: List[str], remaining_max_new_instances_per_autoscaler_loop: int @@ -323,46 +321,56 @@ async def regions_to_ready_cores_mcpu_from_estimated_job_queue(self) -> List[Tup jobs_query_args = [] for user_idx, (user, share) in enumerate(user_share.items(), start=1): - user_job_query = f''' + # job_group_id must be part of the ordering when selecting records + # because the scheduler selects records by job group in order + user_job_query = f""" ( SELECT scheduling_iteration, user_idx, n_regions, regions_bits_rep, CAST(COALESCE(SUM(cores_mcpu), 0) AS SIGNED) AS ready_cores_mcpu FROM ( SELECT {user_idx} AS user_idx, batch_id, job_id, cores_mcpu, always_run, n_regions, regions_bits_rep, - ROW_NUMBER() OVER (ORDER BY batch_id, always_run DESC, -n_regions DESC, regions_bits_rep, job_id ASC) DIV {share} AS scheduling_iteration + ROW_NUMBER() OVER (ORDER BY batch_id, job_group_id, always_run DESC, -n_regions DESC, regions_bits_rep, job_id ASC) DIV {share} AS scheduling_iteration FROM ( ( - SELECT jobs.batch_id, jobs.job_id, cores_mcpu, always_run, n_regions, regions_bits_rep - FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) + SELECT jobs.batch_id, jobs.job_id, jobs.job_group_id, cores_mcpu, always_run, n_regions, regions_bits_rep + FROM jobs FORCE INDEX(jobs_batch_id_ic_state_ar_n_regions_bits_rep_job_group_id) LEFT JOIN batches ON jobs.batch_id = batches.id WHERE user = %s AND batches.`state` = 'running' AND jobs.state = 'Ready' AND always_run AND inst_coll = %s - ORDER BY jobs.batch_id ASC, jobs.job_id ASC + ORDER BY jobs.batch_id ASC, jobs.job_group_id ASC, jobs.job_id ASC LIMIT {share * self.job_queue_scheduling_window_secs} ) UNION ( - SELECT jobs.batch_id, jobs.job_id, cores_mcpu, always_run, n_regions, regions_bits_rep - FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) + SELECT jobs.batch_id, jobs.job_id, jobs.job_group_id, cores_mcpu, always_run, n_regions, regions_bits_rep + FROM jobs FORCE INDEX(jobs_batch_id_ic_state_ar_n_regions_bits_rep_job_group_id) LEFT JOIN batches ON jobs.batch_id = batches.id - LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id - WHERE user = %s AND batches.`state` = 'running' AND jobs.state = 'Ready' AND NOT always_run AND job_groups_cancelled.id IS NULL AND inst_coll = %s - ORDER BY jobs.batch_id ASC, jobs.job_id ASC + LEFT JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors + INNER JOIN job_groups_cancelled + ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id + WHERE jobs.batch_id = job_group_self_and_ancestors.batch_id AND + jobs.job_group_id = job_group_self_and_ancestors.job_group_id + ) AS t ON TRUE + WHERE user = %s AND batches.`state` = 'running' AND jobs.state = 'Ready' AND NOT always_run AND t.cancelled IS NULL AND inst_coll = %s + ORDER BY jobs.batch_id ASC, jobs.job_group_id ASC, jobs.job_id ASC LIMIT {share * self.job_queue_scheduling_window_secs} ) ) AS t1 - ORDER BY 
batch_id, always_run DESC, -n_regions DESC, regions_bits_rep, job_id ASC + ORDER BY batch_id, job_group_id, always_run DESC, -n_regions DESC, regions_bits_rep, job_id ASC LIMIT {share * self.job_queue_scheduling_window_secs} ) AS t2 GROUP BY scheduling_iteration, user_idx, regions_bits_rep, n_regions HAVING ready_cores_mcpu > 0 LIMIT {self.max_new_instances_per_autoscaler_loop * self.worker_cores} ) -''' +""" jobs_query.append(user_job_query) jobs_query_args += [user, self.name, user, self.name] result = self.db.select_and_fetchall( - f''' + f""" WITH ready_cores_by_scheduling_iteration_regions AS ( {" UNION ".join(jobs_query)} ) @@ -370,7 +378,7 @@ async def regions_to_ready_cores_mcpu_from_estimated_job_queue(self) -> List[Tup FROM ready_cores_by_scheduling_iteration_regions ORDER BY scheduling_iteration, user_idx, -n_regions DESC, regions_bits_rep LIMIT {self.max_new_instances_per_autoscaler_loop * self.worker_cores}; -''', +""", jobs_query_args, query_name='get_job_queue_head', ) @@ -384,13 +392,13 @@ def extract_regions(regions_bits_rep: int): async def ready_cores_mcpu_per_user(self): ready_cores_mcpu_per_user = self.db.select_and_fetchall( - ''' + """ SELECT user, CAST(COALESCE(SUM(ready_cores_mcpu), 0) AS SIGNED) AS ready_cores_mcpu FROM user_inst_coll_resources WHERE inst_coll = %s GROUP BY user; -''', +""", (self.name,), ) @@ -510,7 +518,7 @@ async def _compute_fair_share(self, free_cores_mcpu): allocating_users_by_total_cores = sortedcontainers.SortedSet(key=lambda user: user_total_cores_mcpu[user]) records = self.db.execute_and_fetchall( - ''' + """ SELECT user, CAST(COALESCE(SUM(n_ready_jobs), 0) AS SIGNED) AS n_ready_jobs, CAST(COALESCE(SUM(ready_cores_mcpu), 0) AS SIGNED) AS ready_cores_mcpu, @@ -520,7 +528,7 @@ async def _compute_fair_share(self, free_cores_mcpu): WHERE inst_coll = %s GROUP BY user HAVING n_ready_jobs + n_running_jobs > 0; -''', +""", (self.pool.name,), "compute_fair_share", ) @@ -608,51 +616,62 @@ async def schedule_loop_body(self): } async def user_runnable_jobs(user): - async for batch in self.db.select_and_fetchall( - ''' -SELECT batches.id, job_groups_cancelled.id IS NOT NULL AS cancelled, userdata, user, format_version -FROM batches -LEFT JOIN job_groups_cancelled - ON batches.id = job_groups_cancelled.id -WHERE user = %s AND `state` = 'running'; -''', + async for job_group in self.db.select_and_fetchall( + """ +SELECT job_groups.batch_id, job_groups.job_group_id, t.cancelled IS NOT NULL AS cancelled, userdata, job_groups.user, format_version +FROM job_groups +LEFT JOIN batches ON job_groups.batch_id = batches.id +LEFT JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors + INNER JOIN job_groups_cancelled + ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id + WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND + job_groups.job_group_id = job_group_self_and_ancestors.job_group_id +) AS t ON TRUE +WHERE job_groups.user = %s AND job_groups.`state` = 'running' +ORDER BY job_groups.batch_id, job_groups.job_group_id; +""", (user,), "user_runnable_jobs__select_running_batches", ): async for record in self.db.select_and_fetchall( - ''' -SELECT jobs.job_id, spec, cores_mcpu, regions_bits_rep, time_ready -FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_inst_coll_cancelled) + """ +SELECT jobs.job_id, spec, cores_mcpu, regions_bits_rep, time_ready, job_group_id +FROM jobs FORCE 
INDEX(jobs_batch_id_ic_state_ar_n_regions_bits_rep_job_group_id) LEFT JOIN jobs_telemetry ON jobs.batch_id = jobs_telemetry.batch_id AND jobs.job_id = jobs_telemetry.job_id -WHERE jobs.batch_id = %s AND inst_coll = %s AND jobs.state = 'Ready' AND always_run = 1 -ORDER BY jobs.batch_id, inst_coll, state, always_run, -n_regions DESC, regions_bits_rep, jobs.job_id +WHERE jobs.batch_id = %s AND job_group_id = %s AND inst_coll = %s AND jobs.state = 'Ready' AND always_run = 1 +ORDER BY jobs.batch_id, jobs.job_group_id, inst_coll, state, always_run, -n_regions DESC, regions_bits_rep, jobs.job_id LIMIT 300; -''', - (batch['id'], self.pool.name), +""", + (job_group['batch_id'], job_group['job_group_id'], self.pool.name), "user_runnable_jobs__select_ready_always_run_jobs", ): - record['batch_id'] = batch['id'] - record['userdata'] = batch['userdata'] - record['user'] = batch['user'] - record['format_version'] = batch['format_version'] + record['batch_id'] = job_group['batch_id'] + record['job_group_id'] = job_group['job_group_id'] + record['userdata'] = job_group['userdata'] + record['user'] = job_group['user'] + record['format_version'] = job_group['format_version'] yield record - if not batch['cancelled']: + if not job_group['cancelled']: async for record in self.db.select_and_fetchall( - ''' -SELECT jobs.job_id, spec, cores_mcpu, regions_bits_rep, time_ready -FROM jobs FORCE INDEX(jobs_batch_id_state_always_run_cancelled) + """ +SELECT jobs.job_id, spec, cores_mcpu, regions_bits_rep, time_ready, job_group_id +FROM jobs FORCE INDEX(jobs_batch_id_ic_state_ar_n_regions_bits_rep_job_group_id) LEFT JOIN jobs_telemetry ON jobs.batch_id = jobs_telemetry.batch_id AND jobs.job_id = jobs_telemetry.job_id -WHERE jobs.batch_id = %s AND inst_coll = %s AND jobs.state = 'Ready' AND always_run = 0 AND cancelled = 0 -ORDER BY jobs.batch_id, inst_coll, state, always_run, -n_regions DESC, regions_bits_rep, jobs.job_id +WHERE jobs.batch_id = %s AND job_group_id = %s AND inst_coll = %s AND jobs.state = 'Ready' AND always_run = 0 AND cancelled = 0 +ORDER BY jobs.batch_id, jobs.job_group_id, inst_coll, state, always_run, -n_regions DESC, regions_bits_rep, jobs.job_id LIMIT 300; -''', - (batch['id'], self.pool.name), +""", + (job_group['batch_id'], job_group['job_group_id'], self.pool.name), "user_runnable_jobs__select_ready_jobs_batch_not_cancelled", ): - record['batch_id'] = batch['id'] - record['userdata'] = batch['userdata'] - record['user'] = batch['user'] - record['format_version'] = batch['format_version'] + record['batch_id'] = job_group['batch_id'] + record['job_group_id'] = job_group['job_group_id'] + record['userdata'] = job_group['userdata'] + record['user'] = job_group['user'] + record['format_version'] = job_group['format_version'] yield record waitable_pool = WaitableSharedPool(self.async_worker_pool) @@ -682,6 +701,7 @@ async def user_runnable_jobs(user): await mark_job_errored( self.app, record['batch_id'], + record['job_group_id'], record['job_id'], attempt_id, record['user'], diff --git a/batch/batch/driver/job.py b/batch/batch/driver/job.py index a4b54705e3e..c6d129eef18 100644 --- a/batch/batch/driver/job.py +++ b/batch/batch/driver/job.py @@ -3,17 +3,19 @@ import collections import json import logging +import os import traceback from typing import TYPE_CHECKING, Dict, List import aiohttp -from gear import Database, K8sCache +from gear import CommonAiohttpAppKeys, Database, K8sCache from hailtop import httpx from hailtop.aiotools import BackgroundTaskManager +from hailtop.batch_client.globals import 
ROOT_JOB_GROUP_ID from hailtop.utils import Notice, retry_transient_errors, time_msecs -from ..batch import batch_record_to_dict +from ..batch import batch_record_to_dict, job_group_record_to_dict from ..batch_configuration import KUBERNETES_SERVER_URL from ..batch_format_version import BatchFormatVersion from ..file_store import FileStore @@ -28,9 +30,9 @@ log = logging.getLogger('job') -async def notify_batch_job_complete(db: Database, client_session: httpx.ClientSession, batch_id): +async def notify_batch_job_complete(db: Database, client_session: httpx.ClientSession, batch_id: int): record = await db.select_and_fetchone( - ''' + """ SELECT batches.*, cost_t.cost, cost_t.cost_breakdown, @@ -57,7 +59,7 @@ async def notify_batch_job_complete(db: Database, client_session: httpx.ClientSe ON batches.id = job_groups_cancelled.id WHERE batches.id = %s AND NOT deleted AND callback IS NOT NULL AND batches.`state` = 'complete'; -''', +""", (batch_id,), 'notify_batch_job_complete', ) @@ -85,6 +87,81 @@ async def request(session): log.info(f'callback for batch {batch_id} failed, will not retry.') +async def notify_job_group_on_job_complete( + db: Database, client_session: httpx.ClientSession, batch_id: int, job_group_id: int +): + records = db.select_and_fetchall( + """ +SELECT job_groups.*, + ancestor_id, + cost_t.cost, + cost_t.cost_breakdown, + t.cancelled IS NOT NULL AS cancelled, + job_groups_n_jobs_in_complete_states.n_completed, + job_groups_n_jobs_in_complete_states.n_succeeded, + job_groups_n_jobs_in_complete_states.n_failed, + job_groups_n_jobs_in_complete_states.n_cancelled +FROM job_group_self_and_ancestors +LEFT JOIN job_groups ON job_groups.batch_id = job_group_self_and_ancestors.batch_id AND + job_groups.job_group_id = job_group_self_and_ancestors.ancestor_id +LEFT JOIN batches ON job_group_self_and_ancestors.batch_id = batches.id +LEFT JOIN job_groups_n_jobs_in_complete_states + ON job_group_self_and_ancestors.batch_id = job_groups_n_jobs_in_complete_states.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_n_jobs_in_complete_states.job_group_id +LEFT JOIN LATERAL ( + SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown + FROM ( + SELECT resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` + FROM aggregated_job_group_resources_v3 + WHERE job_group_self_and_ancestors.batch_id = aggregated_job_group_resources_v3.batch_id AND + job_group_self_and_ancestors.ancestor_id = aggregated_job_group_resources_v3.job_group_id + GROUP BY resource_id + ) AS usage_t + LEFT JOIN resources ON usage_t.resource_id = resources.resource_id +) AS cost_t ON TRUE +LEFT JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors AS self_and_ancestors + INNER JOIN job_groups_cancelled + ON self_and_ancestors.batch_id = job_groups_cancelled.id AND + self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id + WHERE self_and_ancestors.batch_id = job_group_self_and_ancestors.batch_id AND + self_and_ancestors.job_group_id = job_group_self_and_ancestors.ancestor_id +) AS t ON TRUE +WHERE job_group_self_and_ancestors.batch_id = %s AND + job_group_self_and_ancestors.job_group_id = %s AND + job_group_self_and_ancestors.ancestor_id != %s AND + NOT deleted AND + job_groups.callback IS NOT NULL AND + job_groups.`state` = 'complete'; +""", + (batch_id, job_group_id, ROOT_JOB_GROUP_ID), + 'notify_job_group_on_job_complete', + ) + + async for record in records: + ancestor_job_group_id = 
record['ancestor_id'] + callback = record['callback'] + + log.info(f'making callback for batch {batch_id} job group {ancestor_job_group_id}: {callback}') + + async def request(session, record, callback, batch_id, ancestor_job_group_id): + await session.post(callback, json=job_group_record_to_dict(record)) + log.info(f'callback for batch {batch_id} job group {ancestor_job_group_id} successful') + + try: + if record['user'] == 'ci': + # only jobs from CI may use batch's TLS identity + await request(client_session, record, callback, batch_id, ancestor_job_group_id) + else: + async with httpx.client_session() as session: + await request(session, record, callback, batch_id, ancestor_job_group_id) + except asyncio.CancelledError: + raise + except Exception: + log.info(f'callback for batch {batch_id} job group {ancestor_job_group_id} failed, will not retry.') + + async def add_attempt_resources(app, db, batch_id, job_id, attempt_id, resources: List[QuantifiedResource]): resource_name_to_id = app['resource_name_to_id'] if attempt_id and len(resources) > 0: @@ -109,11 +186,11 @@ async def add_attempt_resources(app, db, batch_id, job_id, attempt_id, resources ] await db.execute_many( - ''' + """ INSERT INTO `attempt_resources` (batch_id, job_id, attempt_id, resource_id, deduped_resource_id, quantity) VALUES (%s, %s, %s, %s, %s, %s) ON DUPLICATE KEY UPDATE quantity = quantity; -''', +""", resource_args, 'add_attempt_resources', ) @@ -127,6 +204,7 @@ async def mark_job_complete( batch_id, job_id, attempt_id, + job_group_id, instance_name, new_state, status, @@ -140,7 +218,7 @@ async def mark_job_complete( scheduler_state_changed: Notice = app['scheduler_state_changed'] cancel_ready_state_changed: asyncio.Event = app['cancel_ready_state_changed'] db: Database = app['db'] - client_session: httpx.ClientSession = app['client_session'] + client_session = app[CommonAiohttpAppKeys.CLIENT_SESSION] inst_coll_manager: 'InstanceCollectionManager' = app['driver'].inst_coll_manager task_manager: BackgroundTaskManager = app['task_manager'] @@ -200,6 +278,7 @@ async def mark_job_complete( return await notify_batch_job_complete(db, client_session, batch_id) + await notify_job_group_on_job_complete(db, client_session, batch_id, job_group_id) if instance and not instance.inst_coll.is_pool and instance.state == 'active': task_manager.ensure_future(instance.kill()) @@ -214,9 +293,9 @@ async def mark_job_started(app, batch_id, job_id, attempt_id, instance, start_ti try: rv = await db.execute_and_fetchone( - ''' + """ CALL mark_job_started(%s, %s, %s, %s, %s); -''', +""", (batch_id, job_id, attempt_id, instance.name, start_time), 'mark_job_started', ) @@ -247,9 +326,9 @@ async def mark_job_creating( try: rv = await db.execute_and_fetchone( - ''' + """ CALL mark_job_creating(%s, %s, %s, %s, %s); -''', +""", (batch_id, job_id, attempt_id, instance.name, start_time), 'mark_job_creating', ) @@ -268,7 +347,7 @@ async def unschedule_job(app, record): cancel_ready_state_changed: asyncio.Event = app['cancel_ready_state_changed'] scheduler_state_changed: Notice = app['scheduler_state_changed'] db: Database = app['db'] - client_session: httpx.ClientSession = app['client_session'] + client_session = app[CommonAiohttpAppKeys.CLIENT_SESSION] inst_coll_manager = app['driver'].inst_coll_manager batch_id = record['batch_id'] @@ -333,13 +412,15 @@ async def make_request(): log.info(f'unschedule job {id}, attempt {attempt_id}: called delete job') -async def job_config(app, record, attempt_id): +async def job_config(app, record): k8s_cache: 
K8sCache = app['k8s_cache'] db: Database = app['db'] format_version = BatchFormatVersion(record['format_version']) batch_id = record['batch_id'] job_id = record['job_id'] + attempt_id = record['attempt_id'] + job_group_id = record['job_group_id'] db_spec = json.loads(record['spec']) @@ -352,15 +433,15 @@ async def job_config(app, record, attempt_id): job_spec = db_spec job_spec['attempt_id'] = attempt_id + job_spec['job_group_id'] = job_group_id userdata = json.loads(record['userdata']) - secrets = job_spec.get('secrets', []) - k8s_secrets = await asyncio.gather( - *[k8s_cache.read_secret(secret['name'], secret['namespace']) for secret in secrets] - ) - gsa_key = None + secrets = job_spec.get('secrets', []) + k8s_secrets = await asyncio.gather(*[ + k8s_cache.read_secret(secret['name'], secret['namespace']) for secret in secrets + ]) # backwards compatibility gsa_key_secret_name = userdata.get('hail_credentials_secret_name') or userdata['gsa_key_secret_name'] @@ -370,7 +451,10 @@ async def job_config(app, record, attempt_id): gsa_key = k8s_secret.data secret['data'] = k8s_secret.data - assert gsa_key + if os.environ.get('HAIL_TERRA'): + assert not gsa_key + else: + assert gsa_key service_account = job_spec.get('service_account') if service_account: @@ -391,7 +475,7 @@ async def job_config(app, record, attempt_id): user_token = base64.b64decode(secret.data['token']).decode() cert = secret.data['ca.crt'] - kube_config = f''' + kube_config = f""" apiVersion: v1 clusters: - cluster: @@ -411,15 +495,13 @@ async def job_config(app, record, attempt_id): - name: {namespace}-{name} user: token: {user_token} -''' - - job_spec['secrets'].append( - { - 'name': 'kube-config', - 'mount_path': '/.kube', - 'data': {'config': base64.b64encode(kube_config.encode()).decode(), 'ca.crt': cert}, - } - ) +""" + + job_spec['secrets'].append({ + 'name': 'kube-config', + 'mount_path': '/.kube', + 'data': {'config': base64.b64encode(kube_config.encode()).decode(), 'ca.crt': cert}, + }) env = job_spec.get('env') if not env: @@ -446,7 +528,7 @@ async def job_config(app, record, attempt_id): } -async def mark_job_errored(app, batch_id, job_id, attempt_id, user, format_version, error_msg): +async def mark_job_errored(app, batch_id, job_group_id, job_id, attempt_id, user, format_version, error_msg): file_store: FileStore = app['file_store'] status = { @@ -454,6 +536,7 @@ async def mark_job_errored(app, batch_id, job_id, attempt_id, user, format_versi 'worker': None, 'batch_id': batch_id, 'job_id': job_id, + 'job_group_id': job_group_id, 'attempt_id': attempt_id, 'user': user, 'state': 'error', @@ -466,29 +549,32 @@ async def mark_job_errored(app, batch_id, job_id, attempt_id, user, format_versi db_status = format_version.db_status(status) - await mark_job_complete(app, batch_id, job_id, attempt_id, None, 'Error', db_status, None, None, 'error', []) + await mark_job_complete( + app, batch_id, job_id, attempt_id, job_group_id, None, 'Error', db_status, None, None, 'error', [] + ) async def schedule_job(app, record, instance): assert instance.state == 'active' db: Database = app['db'] - client_session: httpx.ClientSession = app['client_session'] + client_session = app[CommonAiohttpAppKeys.CLIENT_SESSION] batch_id = record['batch_id'] job_id = record['job_id'] attempt_id = record['attempt_id'] + job_group_id = record['job_group_id'] format_version = BatchFormatVersion(record['format_version']) id = (batch_id, job_id) try: - body = await job_config(app, record, attempt_id) + body = await job_config(app, record) except 
Exception: log.exception(f'while making job config for job {id} with attempt id {attempt_id}') await mark_job_errored( - app, batch_id, job_id, attempt_id, record['user'], format_version, traceback.format_exc() + app, batch_id, job_group_id, job_id, attempt_id, record['user'], format_version, traceback.format_exc() ) raise @@ -512,9 +598,9 @@ async def schedule_job(app, record, instance): try: rv = await db.execute_and_fetchone( - ''' + """ CALL schedule_job(%s, %s, %s, %s); -''', +""", (batch_id, job_id, attempt_id, instance.name), 'schedule_job', ) diff --git a/batch/batch/driver/main.py b/batch/batch/driver/main.py index b6e3f0c0bb6..9380a4cd149 100644 --- a/batch/batch/driver/main.py +++ b/batch/batch/driver/main.py @@ -5,6 +5,7 @@ import os import re import signal +import traceback import warnings from collections import defaultdict, namedtuple from contextlib import AsyncExitStack @@ -19,17 +20,17 @@ import plotly import plotly.graph_objects as go import prometheus_client as pc # type: ignore -import uvloop from aiohttp import web from plotly.subplots import make_subplots from prometheus_async.aio.web import server_stats from gear import ( - AuthServiceAuthenticator, + CommonAiohttpAppKeys, Database, K8sCache, Transaction, check_csrf_token, + get_authenticator, json_request, json_response, monitor_endpoints_middleware, @@ -39,7 +40,8 @@ from gear.auth import AIOHTTPHandler, UserData from gear.clients import get_cloud_async_fs from gear.profiling import install_profiler_if_requested -from hailtop import aiotools, httpx +from hailtop import aiotools, httpx, uvloopx +from hailtop.batch_client.globals import ROOT_JOB_GROUP_ID from hailtop.config import get_deploy_config from hailtop.hail_logging import AccessLogger from hailtop.utils import ( @@ -52,7 +54,7 @@ ) from web_common import render_template, set_message, setup_aiohttp_jinja2, setup_common_static_routes -from ..batch import cancel_batch_in_db +from ..batch import cancel_job_group_in_db from ..batch_configuration import ( BATCH_STORAGE_URI, CLOUD, @@ -78,7 +80,7 @@ from .instance_collection import InstanceCollectionManager, JobPrivateInstanceManager, Pool from .job import mark_job_complete, mark_job_started -uvloop.install() +uvloopx.install() log = logging.getLogger('batch') @@ -88,7 +90,7 @@ deploy_config = get_deploy_config() -auth = AuthServiceAuthenticator() +auth = get_authenticator() warnings.filterwarnings( 'ignore', @@ -202,12 +204,18 @@ async def get_check_invariants(request: web.Request, _) -> web.Response: incremental_result, resource_agg_result = await asyncio.gather( check_incremental(db), check_resource_aggregation(db), return_exceptions=True ) - return json_response( - { - 'check_incremental_error': incremental_result, - 'check_resource_aggregation_error': resource_agg_result, - } - ) + return json_response({ + 'check_incremental_error': '\n'.join( + traceback.format_exception(None, incremental_result, incremental_result.__traceback__) + ) + if incremental_result + else None, + 'check_resource_aggregation_error': '\n'.join( + traceback.format_exception(None, resource_agg_result, resource_agg_result.__traceback__) + ) + if resource_agg_result + else None, + }) @routes.patch('/api/v1alpha/batches/{user}/{batch_id}/update') @@ -220,9 +228,9 @@ async def update_batch(request): batch_id = int(request.match_info['batch_id']) record = await db.select_and_fetchone( - ''' + """ SELECT state FROM batches WHERE user = %s AND id = %s; -''', +""", (user, batch_id), ) if not record: @@ -317,6 +325,7 @@ async def 
job_complete_1(request, instance): batch_id = job_status['batch_id'] job_id = job_status['job_id'] attempt_id = job_status['attempt_id'] + job_group_id = job_status.get('job_group_id', ROOT_JOB_GROUP_ID) request['batch_telemetry']['batch_id'] = str(batch_id) request['batch_telemetry']['job_id'] = str(job_id) @@ -340,6 +349,7 @@ async def job_complete_1(request, instance): batch_id, job_id, attempt_id, + job_group_id, instance.name, new_state, status, @@ -407,11 +417,11 @@ async def billing_update_1(request, instance): where_args = [update_timestamp, *flatten(where_attempt_args)] await db.execute_update( - f''' + f""" UPDATE attempts SET rollup_time = %s {where_query}; -''', +""", where_args, ) @@ -436,12 +446,10 @@ async def get_index(request, userdata): inst_coll_manager: InstanceCollectionManager = app['driver'].inst_coll_manager jpim: JobPrivateInstanceManager = app['driver'].job_private_inst_manager - ready_cores = await db.select_and_fetchone( - ''' + ready_cores = await db.select_and_fetchone(""" SELECT CAST(COALESCE(SUM(ready_cores_mcpu), 0) AS SIGNED) AS ready_cores_mcpu FROM user_inst_coll_resources; -''' - ) +""") ready_cores_mcpu = ready_cores['ready_cores_mcpu'] page_context = { @@ -565,9 +573,9 @@ async def configure_feature_flags(request: web.Request, _) -> NoReturn: oms_agent = 'oms_agent' in post await db.execute_update( - ''' + """ UPDATE feature_flags SET compact_billing_tables = %s, oms_agent = %s; -''', +""", (compact_billing_tables, oms_agent), ) @@ -940,11 +948,9 @@ async def freeze_batch(request: web.Request, _) -> NoReturn: set_message(session, 'Batch is already frozen.', 'info') raise web.HTTPFound(deploy_config.external_url('batch-driver', '/')) - await db.execute_update( - ''' + await db.execute_update(""" UPDATE globals SET frozen = 1; -''' - ) +""") app['frozen'] = True @@ -964,11 +970,9 @@ async def unfreeze_batch(request: web.Request, _) -> NoReturn: set_message(session, 'Batch is already unfrozen.', 'info') raise web.HTTPFound(deploy_config.external_url('batch-driver', '/')) - await db.execute_update( - ''' + await db.execute_update(""" UPDATE globals SET frozen = 0; -''' - ) +""") app['frozen'] = False @@ -983,8 +987,7 @@ async def get_user_resources(request, userdata): app = request.app db: Database = app['db'] - records = db.execute_and_fetchall( - ''' + records = db.execute_and_fetchall(""" SELECT user, CAST(COALESCE(SUM(n_ready_jobs), 0) AS SIGNED) AS n_ready_jobs, CAST(COALESCE(SUM(ready_cores_mcpu), 0) AS SIGNED) AS ready_cores_mcpu, @@ -993,8 +996,7 @@ async def get_user_resources(request, userdata): FROM user_inst_coll_resources GROUP BY user HAVING n_ready_jobs + n_running_jobs > 0; -''' - ) +""") user_resources = sorted( [record async for record in records], @@ -1009,8 +1011,7 @@ async def get_user_resources(request, userdata): async def check_incremental(db): @transaction(db, read_only=True) async def check(tx): - user_inst_coll_with_broken_resources = tx.execute_and_fetchall( - ''' + user_inst_coll_with_broken_resources = tx.execute_and_fetchall(""" SELECT t.*, u.* @@ -1027,13 +1028,21 @@ async def check(tx): CAST(COALESCE(SUM(state = 'Creating' AND cancelled), 0) AS SIGNED) AS actual_n_cancelled_creating_jobs FROM ( - SELECT batches.user, jobs.state, jobs.cores_mcpu, jobs.inst_coll, - (jobs.always_run OR NOT (jobs.cancelled OR job_groups_cancelled.id IS NOT NULL)) AS runnable, - (NOT jobs.always_run AND (jobs.cancelled OR job_groups_cancelled.id IS NOT NULL)) AS cancelled - FROM batches - INNER JOIN jobs ON batches.id = jobs.batch_id - LEFT 
JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id - WHERE batches.`state` = 'running' + SELECT job_groups.user, jobs.state, jobs.cores_mcpu, jobs.inst_coll, + (jobs.always_run OR NOT (jobs.cancelled OR t.cancelled IS NOT NULL)) AS runnable, + (NOT jobs.always_run AND (jobs.cancelled OR t.cancelled IS NOT NULL)) AS cancelled + FROM job_groups + LEFT JOIN jobs ON job_groups.batch_id = jobs.batch_id AND job_groups.job_group_id = jobs.job_group_id + LEFT JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors + INNER JOIN job_groups_cancelled + ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id + WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND + job_groups.job_group_id = job_group_self_and_ancestors.job_group_id + ) AS t ON TRUE + WHERE job_groups.`state` = 'running' ) as v GROUP BY user, inst_coll ) as t @@ -1069,8 +1078,7 @@ async def check(tx): OR expected_n_cancelled_running_jobs != 0 OR expected_n_cancelled_creating_jobs != 0 LOCK IN SHARE MODE; -''' - ) +""") failures = [record async for record in user_inst_coll_with_broken_resources] if len(failures) > 0: @@ -1116,8 +1124,7 @@ def fold(d, key_f): @transaction(db, read_only=True) async def check(tx): - attempt_resources = tx.execute_and_fetchall( - ''' + attempt_resources = tx.execute_and_fetchall(""" SELECT attempt_resources.batch_id, attempt_resources.job_id, attempt_resources.attempt_id, JSON_OBJECTAGG(resources.resource, quantity * GREATEST(COALESCE(rollup_time - start_time, 0), 0)) as resources FROM attempt_resources @@ -1129,35 +1136,68 @@ async def check(tx): WHERE GREATEST(COALESCE(rollup_time - start_time, 0), 0) != 0 GROUP BY batch_id, job_id, attempt_id LOCK IN SHARE MODE; -''' - ) +""") - agg_job_resources = tx.execute_and_fetchall( - ''' + attempt_by_job_group_resources = tx.execute_and_fetchall(""" +SELECT batch_id, ancestor_id, JSON_OBJECTAGG(resource, `usage`) as resources +FROM ( + SELECT job_group_self_and_ancestors.batch_id, job_group_self_and_ancestors.ancestor_id, resource, + CAST(COALESCE(SUM(quantity * GREATEST(COALESCE(rollup_time - start_time, 0), 0)), 0) AS SIGNED) as `usage` + FROM attempt_resources + INNER JOIN attempts + ON attempts.batch_id = attempt_resources.batch_id AND + attempts.job_id = attempt_resources.job_id AND + attempts.attempt_id = attempt_resources.attempt_id + LEFT JOIN resources ON attempt_resources.resource_id = resources.resource_id + LEFT JOIN jobs ON attempts.batch_id = jobs.batch_id AND attempts.job_id = jobs.job_id + LEFT JOIN job_group_self_and_ancestors ON jobs.batch_id = job_group_self_and_ancestors.batch_id AND + jobs.job_group_id = job_group_self_and_ancestors.job_group_id + WHERE GREATEST(COALESCE(rollup_time - start_time, 0), 0) != 0 + GROUP BY job_group_self_and_ancestors.batch_id, job_group_self_and_ancestors.ancestor_id, resource + LOCK IN SHARE MODE +) AS t +GROUP BY t.batch_id, t.ancestor_id; +""") + + agg_job_resources = tx.execute_and_fetchall(""" SELECT batch_id, job_id, JSON_OBJECTAGG(resource, `usage`) as resources -FROM aggregated_job_resources_v3 -LEFT JOIN resources ON aggregated_job_resources_v3.resource_id = resources.resource_id +FROM ( + SELECT batch_id, job_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` + FROM aggregated_job_resources_v3 + GROUP BY batch_id, job_id, resource_id +) AS t +LEFT JOIN resources ON t.resource_id = resources.resource_id GROUP BY batch_id, job_id LOCK IN SHARE 
MODE; -''' - ) +""") + + agg_job_group_resources = tx.execute_and_fetchall(""" +SELECT batch_id, job_group_id, JSON_OBJECTAGG(resource, `usage`) as resources +FROM ( + SELECT batch_id, job_group_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` + FROM aggregated_job_group_resources_v3 + GROUP BY batch_id, job_group_id, resource_id +) AS t +LEFT JOIN resources ON t.resource_id = resources.resource_id +GROUP BY batch_id, job_group_id +LOCK IN SHARE MODE; +""") - agg_batch_resources = tx.execute_and_fetchall( - ''' + agg_batch_resources = tx.execute_and_fetchall(""" SELECT batch_id, billing_project, JSON_OBJECTAGG(resource, `usage`) as resources FROM ( SELECT batch_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_job_group_resources_v3 - GROUP BY batch_id, resource_id) AS t + WHERE job_group_id = 0 + GROUP BY batch_id, resource_id +) AS t LEFT JOIN resources ON t.resource_id = resources.resource_id JOIN batches ON batches.id = t.batch_id GROUP BY t.batch_id, billing_project LOCK IN SHARE MODE; -''' - ) +""") - agg_billing_project_resources = tx.execute_and_fetchall( - ''' + agg_billing_project_resources = tx.execute_and_fetchall(""" SELECT billing_project, JSON_OBJECTAGG(resource, `usage`) as resources FROM ( SELECT billing_project, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` @@ -1166,19 +1206,28 @@ async def check(tx): LEFT JOIN resources ON t.resource_id = resources.resource_id GROUP BY t.billing_project LOCK IN SHARE MODE; -''' - ) +""") attempt_resources = { (record['batch_id'], record['job_id'], record['attempt_id']): json_to_value(record['resources']) async for record in attempt_resources } + attempt_by_job_group_resources = { + (record['batch_id'], record['ancestor_id']): json_to_value(record['resources']) + async for record in attempt_by_job_group_resources + } + agg_job_resources = { (record['batch_id'], record['job_id']): json_to_value(record['resources']) async for record in agg_job_resources } + agg_job_group_resources = { + (record['batch_id'], record['job_group_id']): json_to_value(record['resources']) + async for record in agg_job_group_resources + } + agg_batch_resources = { (record['batch_id'], record['billing_project']): json_to_value(record['resources']) async for record in agg_batch_resources @@ -1191,6 +1240,7 @@ async def check(tx): attempt_by_batch_resources = fold(attempt_resources, lambda k: k[0]) attempt_by_job_resources = fold(attempt_resources, lambda k: (k[0], k[1])) + attempt_by_job_group_resources = fold(attempt_by_job_group_resources, lambda k: (k[0], k[1])) job_by_batch_resources = fold(agg_job_resources, lambda k: k[0]) batch_by_billing_project_resources = fold(agg_batch_resources, lambda k: k[1]) @@ -1217,14 +1267,20 @@ async def check(tx): agg_billing_project_resources, ) + assert attempt_by_job_group_resources == agg_job_group_resources, ( + dictdiffer.diff(attempt_by_job_group_resources, agg_job_group_resources), + attempt_by_job_group_resources, + agg_job_group_resources, + ) + await check() -async def _cancel_batch(app, batch_id): +async def _cancel_job_group(app, batch_id, job_group_id): try: - await cancel_batch_in_db(app['db'], batch_id) + await cancel_job_group_in_db(app['db'], batch_id, job_group_id) except BatchUserError as exc: - log.info(f'cannot cancel batch because {exc.message}') + log.info(f'cannot cancel job group because {exc.message}') return set_cancel_state_changed(app) @@ -1238,31 +1294,40 @@ async def monitor_billing_limits(app): accrued_cost = 
record['accrued_cost'] if limit is not None and accrued_cost >= limit: running_batches = db.execute_and_fetchall( - ''' + """ SELECT id FROM batches WHERE billing_project = %s AND state = 'running'; -''', +""", (record['billing_project'],), ) async for batch in running_batches: - await _cancel_batch(app, batch['id']) + await _cancel_job_group(app, batch['id'], ROOT_JOB_GROUP_ID) -async def cancel_fast_failing_batches(app): +async def cancel_fast_failing_job_groups(app): db: Database = app['db'] - records = db.select_and_fetchall( - ''' -SELECT batches.id, job_groups_n_jobs_in_complete_states.n_failed -FROM batches + """ +SELECT job_groups.batch_id, job_groups.job_group_id, job_groups_n_jobs_in_complete_states.n_failed +FROM job_groups +LEFT JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors + INNER JOIN job_groups_cancelled + ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id + WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND + job_groups.job_group_id = job_group_self_and_ancestors.job_group_id +) AS t_cancelled ON TRUE LEFT JOIN job_groups_n_jobs_in_complete_states - ON batches.id = job_groups_n_jobs_in_complete_states.id -WHERE state = 'running' AND cancel_after_n_failures IS NOT NULL AND n_failed >= cancel_after_n_failures -''' + ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND + job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id +WHERE t_cancelled.cancelled IS NULL AND state = 'running' AND cancel_after_n_failures IS NOT NULL AND n_failed >= cancel_after_n_failures; +""", ) - async for batch in records: - await _cancel_batch(app, batch['id']) + async for job_group in records: + await _cancel_job_group(app, job_group['batch_id'], job_group['job_group_id']) USER_CORES = pc.Gauge('batch_user_cores', 'Batch user cores (i.e. 
total in-use cores)', ['state', 'user', 'inst_coll']) @@ -1295,8 +1360,7 @@ async def monitor_user_resources(app): global ACTIVE_USER_INST_COLL_PAIRS db: Database = app['db'] - records = db.select_and_fetchall( - ''' + records = db.select_and_fetchall(""" SELECT user, inst_coll, CAST(COALESCE(SUM(ready_cores_mcpu), 0) AS SIGNED) AS ready_cores_mcpu, CAST(COALESCE(SUM(running_cores_mcpu), 0) AS SIGNED) AS running_cores_mcpu, @@ -1305,8 +1369,7 @@ async def monitor_user_resources(app): CAST(COALESCE(SUM(n_creating_jobs), 0) AS SIGNED) AS n_creating_jobs FROM user_inst_coll_resources GROUP BY user, inst_coll; -''' - ) +""") current_user_inst_coll_pairs: Set[Tuple[str, str]] = set() @@ -1372,6 +1435,66 @@ async def monitor_system(app): monitor_instances(app) +async def delete_committed_job_groups_inst_coll_staging_records(db: Database): + targets = db.execute_and_fetchall( + """ +SELECT job_groups_inst_coll_staging.batch_id, + job_groups_inst_coll_staging.update_id, + job_groups_inst_coll_staging.job_group_id +FROM job_groups_inst_coll_staging +INNER JOIN batch_updates ON batch_updates.batch_id = job_groups_inst_coll_staging.batch_id AND + batch_updates.update_id = job_groups_inst_coll_staging.update_id +WHERE committed +GROUP BY job_groups_inst_coll_staging.batch_id, job_groups_inst_coll_staging.update_id, job_groups_inst_coll_staging.job_group_id +LIMIT 1000; +""", + query_name='find_staging_records_to_delete', + ) + + async for target in targets: + await db.just_execute( + """ +DELETE FROM job_groups_inst_coll_staging +WHERE batch_id = %s AND update_id = %s AND job_group_id = %s; +""", + (target['batch_id'], target['update_id'], target['job_group_id']), + ) + + +async def delete_prev_cancelled_job_group_cancellable_resources_records(db: Database): + targets = db.execute_and_fetchall( + """ +SELECT DISTINCT + group_resources.batch_id, + group_resources.update_id, + group_resources.job_group_id +FROM job_group_inst_coll_cancellable_resources AS group_resources +INNER JOIN LATERAL ( + SELECT + 1 + FROM job_group_self_and_ancestors AS descendant + INNER JOIN job_groups_cancelled AS cancelled + ON descendant.batch_id = cancelled.id + AND descendant.ancestor_id = cancelled.job_group_id + WHERE descendant.batch_id = group_resources.batch_id + AND descendant.job_group_id = group_resources.job_group_id +) AS t ON TRUE +ORDER BY group_resources.batch_id desc, group_resources.update_id desc, group_resources.job_group_id desc +LIMIT 1000; +""", + query_name='find_cancelled_cancellable_resources_records_to_delete', + ) + + async for target in targets: + await db.just_execute( + """ +DELETE FROM job_group_inst_coll_cancellable_resources +WHERE batch_id = %s AND update_id = %s AND job_group_id = %s; +""", + (target['batch_id'], target['update_id'], target['job_group_id']), + ) + + async def compact_agg_billing_project_users_table(app, db: Database): if not app['feature_flags']['compact_billing_tables']: return @@ -1379,28 +1502,28 @@ async def compact_agg_billing_project_users_table(app, db: Database): @transaction(db) async def compact(tx: Transaction, target: dict): original_usage = await tx.execute_and_fetchone( - ''' + """ SELECT CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_billing_project_user_resources_v3 WHERE billing_project = %s AND `user` = %s AND resource_id = %s FOR UPDATE; -''', +""", (target['billing_project'], target['user'], target['resource_id']), ) await tx.just_execute( - ''' + """ DELETE FROM aggregated_billing_project_user_resources_v3 WHERE billing_project = 
%s AND `user` = %s AND resource_id = %s; -''', +""", (target['billing_project'], target['user'], target['resource_id']), ) await tx.execute_update( - ''' + """ INSERT INTO aggregated_billing_project_user_resources_v3 (billing_project, `user`, resource_id, token, `usage`) VALUES (%s, %s, %s, %s, %s); -''', +""", ( target['billing_project'], target['user'], @@ -1411,12 +1534,12 @@ async def compact(tx: Transaction, target: dict): ) new_usage = await tx.execute_and_fetchone( - ''' + """ SELECT CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_billing_project_user_resources_v3 WHERE billing_project = %s AND `user` = %s AND resource_id = %s GROUP BY billing_project, `user`, resource_id; -''', +""", (target['billing_project'], target['user'], target['resource_id']), ) @@ -1426,14 +1549,14 @@ async def compact(tx: Transaction, target: dict): ) targets = db.execute_and_fetchall( - ''' + """ SELECT billing_project, `user`, resource_id, COUNT(*) AS n_tokens FROM aggregated_billing_project_user_resources_v3 WHERE token != 0 GROUP BY billing_project, `user`, resource_id ORDER BY n_tokens DESC LIMIT 10000; -''', +""", query_name='find_agg_billing_project_user_resource_to_compact', ) @@ -1450,28 +1573,28 @@ async def compact_agg_billing_project_users_by_date_table(app, db: Database): @transaction(db) async def compact(tx: Transaction, target: dict): original_usage = await tx.execute_and_fetchone( - ''' + """ SELECT CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_billing_project_user_resources_by_date_v3 WHERE billing_date = %s AND billing_project = %s AND `user` = %s AND resource_id = %s FOR UPDATE; -''', +""", (target['billing_date'], target['billing_project'], target['user'], target['resource_id']), ) await tx.just_execute( - ''' + """ DELETE FROM aggregated_billing_project_user_resources_by_date_v3 WHERE billing_date = %s AND billing_project = %s AND `user` = %s AND resource_id = %s; -''', +""", (target['billing_date'], target['billing_project'], target['user'], target['resource_id']), ) await tx.execute_update( - ''' + """ INSERT INTO aggregated_billing_project_user_resources_by_date_v3 (billing_date, billing_project, `user`, resource_id, token, `usage`) VALUES (%s, %s, %s, %s, %s, %s); -''', +""", ( target['billing_date'], target['billing_project'], @@ -1483,12 +1606,12 @@ async def compact(tx: Transaction, target: dict): ) new_usage = await tx.execute_and_fetchone( - ''' + """ SELECT CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_billing_project_user_resources_by_date_v3 WHERE billing_date = %s AND billing_project = %s AND `user` = %s AND resource_id = %s GROUP BY billing_date, billing_project, `user`, resource_id; -''', +""", (target['billing_date'], target['billing_project'], target['user'], target['resource_id']), ) @@ -1498,14 +1621,14 @@ async def compact(tx: Transaction, target: dict): ) targets = db.execute_and_fetchall( - ''' + """ SELECT billing_date, billing_project, `user`, resource_id, COUNT(*) AS n_tokens FROM aggregated_billing_project_user_resources_by_date_v3 WHERE token != 0 GROUP BY billing_date, billing_project, `user`, resource_id ORDER BY n_tokens DESC LIMIT 10000; -''', +""", query_name='find_agg_billing_project_user_resource_by_date_to_compact', ) @@ -1529,11 +1652,9 @@ async def scheduling_cancelling_bump(app): async def refresh_globals_from_db(app, db): resource_ids = { record['resource']: Resource(record['resource_id'], record['deduped_resource_id']) - async for record in db.select_and_fetchall( - ''' + async 
for record in db.select_and_fetchall(""" SELECT resource, resource_id, deduped_resource_id FROM resources; -''' - ) +""") } app['resource_name_to_id'] = resource_ids @@ -1587,11 +1708,9 @@ async def close_and_wait(): app['db'] = db exit_stack.push_async_callback(app['db'].async_close) - row = await db.select_and_fetchone( - ''' + row = await db.select_and_fetchone(""" SELECT instance_id, frozen FROM globals; -''' - ) +""") instance_id = row['instance_id'] log.info(f'instance_id {instance_id}') app['instance_id'] = instance_id @@ -1616,8 +1735,8 @@ async def close_and_wait(): inst_coll_configs = await InstanceCollectionConfigs.create(db) - app['client_session'] = httpx.client_session() - exit_stack.push_async_callback(app['client_session'].close) + app[CommonAiohttpAppKeys.CLIENT_SESSION] = httpx.client_session() + exit_stack.push_async_callback(app[CommonAiohttpAppKeys.CLIENT_SESSION].close) app['driver'] = await get_cloud_driver(app, db, MACHINE_NAME_PREFIX, DEFAULT_NAMESPACE, inst_coll_configs) exit_stack.push_async_callback(app['driver'].shutdown) @@ -1630,12 +1749,14 @@ async def close_and_wait(): exit_stack.push_async_callback(app['task_manager'].shutdown_and_wait) task_manager.ensure_future(periodically_call(10, monitor_billing_limits, app)) - task_manager.ensure_future(periodically_call(10, cancel_fast_failing_batches, app)) + task_manager.ensure_future(periodically_call(10, cancel_fast_failing_job_groups, app)) task_manager.ensure_future(periodically_call(60, scheduling_cancelling_bump, app)) task_manager.ensure_future(periodically_call(15, monitor_system, app)) task_manager.ensure_future(periodically_call(5, refresh_globals_from_db, app, db)) task_manager.ensure_future(periodically_call(60, compact_agg_billing_project_users_table, app, db)) task_manager.ensure_future(periodically_call(60, compact_agg_billing_project_users_by_date_table, app, db)) + task_manager.ensure_future(periodically_call(60, delete_committed_job_groups_inst_coll_staging_records, db)) + task_manager.ensure_future(periodically_call(60, delete_prev_cancelled_job_group_cancellable_resources_records, db)) async def on_cleanup(app): diff --git a/batch/batch/exceptions.py b/batch/batch/exceptions.py index 4ce2ff78aab..2a8cab3715e 100644 --- a/batch/batch/exceptions.py +++ b/batch/batch/exceptions.py @@ -1,18 +1,20 @@ +from typing import Union + from aiohttp import web class BatchUserError(Exception): - def __init__(self, message, severity): + def __init__(self, message: str, severity: str): super().__init__(message) self.message = message self.ui_error_type = severity - def http_response(self): + def http_response(self) -> web.HTTPError: return web.HTTPForbidden(reason=self.message) class NonExistentBillingProjectError(BatchUserError): - def __init__(self, billing_project): + def __init__(self, billing_project: str): super().__init__(f'Billing project {billing_project} does not exist.', 'error') def http_response(self): @@ -20,12 +22,12 @@ def http_response(self): class ClosedBillingProjectError(BatchUserError): - def __init__(self, billing_project): + def __init__(self, billing_project: str): super().__init__(f'Billing project {billing_project} is closed and cannot be modified.', 'error') class InvalidBillingLimitError(BatchUserError): - def __init__(self, billing_limit): + def __init__(self, billing_limit: Union[str, float, int]): super().__init__(f'Invalid billing_limit {billing_limit}.', 'error') def http_response(self): @@ -33,29 +35,34 @@ def http_response(self): class NonExistentBatchError(BatchUserError): - 
def __init__(self, batch_id): + def __init__(self, batch_id: int): super().__init__(f'Batch {batch_id} does not exist.', 'error') +class NonExistentJobGroupError(BatchUserError): + def __init__(self, batch_id: int, job_group_id: int): + super().__init__(f'Job Group ({batch_id}, {job_group_id}) does not exist.', 'error') + + class NonExistentUserError(BatchUserError): - def __init__(self, user): + def __init__(self, user: str): super().__init__(f'User {user} does not exist.', 'error') class OpenBatchError(BatchUserError): - def __init__(self, batch_id): + def __init__(self, batch_id: int): super().__init__(f'Batch {batch_id} is open.', 'error') class BatchOperationAlreadyCompletedError(Exception): - def __init__(self, message, severity): + def __init__(self, message: str, severity: str): super().__init__(message) self.message = message self.ui_error_type = severity class QueryError(BatchUserError): - def __init__(self, message): + def __init__(self, message: str): super().__init__(message, 'error') self.message = message diff --git a/batch/batch/front_end/front_end.py b/batch/batch/front_end/front_end.py index 65fad6bc5d3..307be7e38c8 100644 --- a/batch/batch/front_end/front_end.py +++ b/batch/batch/front_end/front_end.py @@ -23,18 +23,18 @@ import plotly.express as px import plotly.graph_objects as go import pymysql -import uvloop from aiohttp import web from plotly.subplots import make_subplots from prometheus_async.aio.web import server_stats # type: ignore from typing_extensions import ParamSpec from gear import ( - AuthServiceAuthenticator, + CommonAiohttpAppKeys, Database, Transaction, UserData, check_csrf_token, + get_authenticator, json_request, json_response, monitor_endpoints_middleware, @@ -45,13 +45,18 @@ from gear.clients import get_cloud_async_fs from gear.database import CallError from gear.profiling import install_profiler_if_requested -from hailtop import aiotools, dictfix, httpx, version +from hailtop import aiotools, dictfix, httpx, uvloopx, version from hailtop.auth import hail_credentials +from hailtop.batch_client.globals import MAX_JOB_GROUPS_DEPTH, ROOT_JOB_GROUP_ID from hailtop.batch_client.parse import parse_cpu_in_mcpu, parse_memory_in_bytes, parse_storage_in_bytes -from hailtop.batch_client.types import GetJobResponseV1Alpha, GetJobsResponseV1Alpha, JobListEntryV1Alpha +from hailtop.batch_client.types import ( + GetJobGroupResponseV1Alpha, + GetJobResponseV1Alpha, + GetJobsResponseV1Alpha, + JobListEntryV1Alpha, +) from hailtop.config import get_deploy_config from hailtop.hail_logging import AccessLogger -from hailtop.tls import internal_server_ssl_context from hailtop.utils import ( cost_str, dump_all_stacktraces, @@ -65,7 +70,7 @@ ) from web_common import render_template, set_message, setup_aiohttp_jinja2, setup_common_static_routes -from ..batch import batch_record_to_dict, cancel_batch_in_db, job_record_to_dict +from ..batch import batch_record_to_dict, cancel_job_group_in_db, job_group_record_to_dict, job_record_to_dict from ..batch_configuration import BATCH_STORAGE_URI, CLOUD, DEFAULT_NAMESPACE, SCOPE from ..batch_format_version import BatchFormatVersion from ..cloud.resource_utils import ( @@ -75,13 +80,13 @@ valid_machine_types, ) from ..cloud.utils import ACCEPTABLE_QUERY_JAR_URL_PREFIX -from ..constants import ROOT_JOB_GROUP_ID from ..exceptions import ( BatchOperationAlreadyCompletedError, BatchUserError, ClosedBillingProjectError, InvalidBillingLimitError, NonExistentBillingProjectError, + NonExistentJobGroupError, NonExistentUserError, QueryError, ) 
@@ -104,14 +109,21 @@ ) from .query import ( CURRENT_QUERY_VERSION, - parse_batch_jobs_query_v1, - parse_batch_jobs_query_v2, + parse_job_group_jobs_query_v1, + parse_job_group_jobs_query_v2, parse_list_batches_query_v1, parse_list_batches_query_v2, + parse_list_job_groups_query_v1, +) +from .validate import ( + ValidationError, + validate_and_clean_jobs, + validate_batch, + validate_batch_update, + validate_job_groups, ) -from .validate import ValidationError, validate_and_clean_jobs, validate_batch, validate_batch_update -uvloop.install() +uvloopx.install() log = logging.getLogger('batch.front_end') @@ -119,7 +131,7 @@ deploy_config = get_deploy_config() -auth = AuthServiceAuthenticator() +auth = get_authenticator() BATCH_JOB_DEFAULT_CPU = os.environ.get('HAIL_BATCH_JOB_DEFAULT_CPU', '1') BATCH_JOB_DEFAULT_MEMORY = os.environ.get('HAIL_BATCH_JOB_DEFAULT_MEMORY', 'standard') @@ -162,12 +174,12 @@ async def wrapped(request, userdata, *args, **kwargs): async def _user_can_access(db: Database, batch_id: int, user: str): record = await db.select_and_fetchone( - ''' + """ SELECT id FROM batches LEFT JOIN billing_project_users ON batches.billing_project = billing_project_users.billing_project WHERE id = %s AND billing_project_users.`user_cs` = %s; -''', +""", (batch_id, user), ) @@ -198,6 +210,13 @@ def cast_query_param_to_int(param: Optional[str]) -> Optional[int]: return None +def cast_query_param_to_bool(param: Optional[str]) -> bool: + if param is None or param in ('False', 'false', '0'): + return False + assert param in ('True', 'true', '1') + return True + + @routes.get('/healthcheck') async def get_healthcheck(_) -> web.Response: return web.Response() @@ -337,15 +356,22 @@ async def _query_batch_jobs_for_billing(request, batch_id): return jobs, last_job_id -async def _query_batch_jobs( - request: web.Request, batch_id: int, version: int, q: str, last_job_id: Optional[int] +async def _query_job_group_jobs( + request: web.Request, + batch_id: int, + job_group_id: int, + version: int, + q: str, + last_job_id: Optional[int], + recursive: bool, ) -> Tuple[List[JobListEntryV1Alpha], Optional[int]]: db: Database = request.app['db'] + if version == 1: - sql, sql_args = parse_batch_jobs_query_v1(batch_id, q, last_job_id) + sql, sql_args = parse_job_group_jobs_query_v1(batch_id, job_group_id, q, last_job_id, recursive) else: assert version == 2, version - sql, sql_args = parse_batch_jobs_query_v2(batch_id, q, last_job_id) + sql, sql_args = parse_job_group_jobs_query_v2(batch_id, job_group_id, q, last_job_id, recursive) jobs = [job_record_to_dict(record, record['name']) async for record in db.select_and_fetchall(sql, sql_args)] @@ -365,7 +391,7 @@ async def get_completed_batches_ordered_by_completed_time(request, userdata): wheres = [ 'billing_project_users.`user` = %s', 'billing_project_users.billing_project = batches.billing_project', - 'time_completed IS NOT NULL', + 'batches.time_completed IS NOT NULL', 'NOT deleted', ] @@ -384,32 +410,32 @@ async def get_completed_batches_ordered_by_completed_time(request, userdata): sql = f""" SELECT batches.*, - job_groups_cancelled.id IS NOT NULL AS cancelled, + cancelled_t.cancelled IS NOT NULL AS cancelled, job_groups_n_jobs_in_complete_states.n_completed, job_groups_n_jobs_in_complete_states.n_succeeded, job_groups_n_jobs_in_complete_states.n_failed, job_groups_n_jobs_in_complete_states.n_cancelled, cost_t.* -FROM batches +FROM job_groups +LEFT JOIN batches ON batches.id = job_groups.batch_id LEFT JOIN billing_projects ON batches.billing_project = 
billing_projects.name LEFT JOIN job_groups_n_jobs_in_complete_states - ON batches.id = job_groups_n_jobs_in_complete_states.id + ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id LEFT JOIN job_groups_cancelled - ON batches.id = job_groups_cancelled.id + ON batches.id = job_groups_cancelled.id AND job_groups_cancelled.job_group_id = %s STRAIGHT_JOIN billing_project_users ON batches.billing_project = billing_project_users.billing_project LEFT JOIN LATERAL ( - SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown - FROM ( - SELECT batch_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` - FROM aggregated_job_group_resources_v3 - WHERE batches.id = aggregated_job_group_resources_v3.batch_id - GROUP BY batch_id, resource_id - ) AS usage_t - LEFT JOIN resources ON usage_t.resource_id = resources.resource_id - GROUP BY batch_id - ) AS cost_t ON TRUE + SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown + FROM ( + SELECT resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` + FROM aggregated_job_group_resources_v3 + WHERE job_groups.batch_id = aggregated_job_group_resources_v3.batch_id AND job_groups.job_group_id = aggregated_job_group_resources_v3.job_group_id + GROUP BY resource_id + ) AS usage_t + LEFT JOIN resources ON usage_t.resource_id = resources.resource_id +) AS cost_t ON TRUE WHERE {' AND '.join(wheres)} ORDER BY time_completed DESC @@ -417,7 +443,7 @@ async def get_completed_batches_ordered_by_completed_time(request, userdata): """ records = [ - batch async for batch in db.select_and_fetchall(sql, (*where_args, limit), query_name='get_completed_batches') + batch async for batch in db.select_and_fetchall(sql, (ROOT_JOB_GROUP_ID, *where_args, limit), query_name='get_completed_batches') ] # this comes out as a timestamp (rather than a formed date) last_completed_timestamp = records[-1]['time_completed'] @@ -428,22 +454,37 @@ async def get_completed_batches_ordered_by_completed_time(request, userdata): return web.json_response(body) -async def _get_jobs( - request: web.Request, batch_id: int, version: int, q: str, last_job_id: Optional[int] +async def _get_job_group_jobs( + request: web.Request, + batch_id: int, + job_group_id: int, + version: int, + q: str, + last_job_id: Optional[int], + recursive: bool, ) -> GetJobsResponseV1Alpha: db = request.app['db'] + is_root_job_group = job_group_id == ROOT_JOB_GROUP_ID + record = await db.select_and_fetchone( - ''' -SELECT * FROM batches -WHERE id = %s AND NOT deleted; -''', - (batch_id,), + """ +SELECT * FROM job_groups +LEFT JOIN batches ON batches.id = job_groups.batch_id +LEFT JOIN batch_updates + ON job_groups.batch_id = batch_updates.batch_id AND + job_groups.update_id = batch_updates.update_id +WHERE job_groups.batch_id = %s AND + job_groups.job_group_id = %s AND + NOT deleted AND + (batch_updates.committed OR %s); +""", + (batch_id, job_group_id, is_root_job_group), ) if not record: raise web.HTTPNotFound() - jobs, last_job_id = await _query_batch_jobs(request, batch_id, version, q, last_job_id) + jobs, last_job_id = await _query_job_group_jobs(request, batch_id, job_group_id, version, q, last_job_id, recursive) if last_job_id is not None: return {'jobs': jobs, 'last_job_id': last_job_id} @@ -453,21 +494,40 @@ async def _get_jobs( 
@routes.get('/api/v1alpha/batches/{batch_id}/jobs') @billing_project_users_only() @add_metadata_to_request -async def get_jobs_v1(request: web.Request, _, batch_id: int) -> web.Response: - q = request.query.get('q', '') - last_job_id = cast_query_param_to_int(request.query.get('last_job_id')) - resp = await _handle_api_error(_get_jobs, request, batch_id, 1, q, last_job_id) - assert resp is not None - return json_response(resp) +async def get_batch_jobs_v1(request: web.Request, _, batch_id: int) -> web.Response: + return await _api_get_job_group_jobs(request, batch_id, ROOT_JOB_GROUP_ID, 1) @routes.get('/api/v2alpha/batches/{batch_id}/jobs') @billing_project_users_only() @add_metadata_to_request -async def get_jobs_v2(request: web.Request, _, batch_id: int) -> web.Response: +async def get_batch_jobs_v2(request: web.Request, _, batch_id: int) -> web.Response: + return await _api_get_job_group_jobs(request, batch_id, ROOT_JOB_GROUP_ID, 2) + + +@routes.get('/api/v1alpha/batches/{batch_id}/job-groups/{job_group_id}/jobs') +@billing_project_users_only() +@add_metadata_to_request +async def get_job_group_jobs_v1(request: web.Request, _, batch_id: int) -> web.Response: + job_group_id = int(request.match_info['job_group_id']) + return await _api_get_job_group_jobs(request, batch_id, job_group_id, 1) + + +@routes.get('/api/v2alpha/batches/{batch_id}/job-groups/{job_group_id}/jobs') +@billing_project_users_only() +@add_metadata_to_request +async def get_job_group_jobs_v2(request: web.Request, _, batch_id: int) -> web.Response: + job_group_id = int(request.match_info['job_group_id']) + return await _api_get_job_group_jobs(request, batch_id, job_group_id, 2) + + +async def _api_get_job_group_jobs(request, batch_id: int, job_group_id: int, version: int): q = request.query.get('q', '') + recursive = cast_query_param_to_bool(request.query.get('recursive')) last_job_id = cast_query_param_to_int(request.query.get('last_job_id')) - resp = await _handle_api_error(_get_jobs, request, batch_id, 2, q, last_job_id) + resp = await _handle_api_error( + _get_job_group_jobs, request, batch_id, job_group_id, version, q, last_job_id, recursive + ) assert resp is not None return json_response(resp) @@ -578,7 +638,7 @@ async def _get_job_record(app, batch_id, job_id): db: Database = app['db'] record = await db.select_and_fetchone( - ''' + """ SELECT jobs.state, jobs.spec, ip_address, format_version, jobs.attempt_id, t.attempt_id AS last_cancelled_attempt_id FROM jobs INNER JOIN batches @@ -596,7 +656,7 @@ async def _get_job_record(app, batch_id, job_id): ) AS t ON jobs.batch_id = t.batch_id AND jobs.job_id = t.job_id WHERE jobs.batch_id = %s AND NOT deleted AND jobs.job_id = %s; -''', +""", (batch_id, job_id, batch_id, job_id), ) if not record: @@ -667,7 +727,7 @@ async def _get_job_container_log(app, batch_id, job_id, container, job_record) - state = job_record['state'] if state == 'Running': return await _get_job_container_log_from_worker( - app['client_session'], batch_id, job_id, container, job_record['ip_address'] + app[CommonAiohttpAppKeys.CLIENT_SESSION], batch_id, job_id, container, job_record['ip_address'] ) attempt_id = attempt_id_from_spec(job_record) @@ -698,7 +758,7 @@ async def _get_job_resource_usage_from_record( app, record, batch_id: int, job_id: int ) -> Optional[Dict[str, Optional[pd.DataFrame]]]: - client_session: httpx.ClientSession = app['client_session'] + client_session = app[CommonAiohttpAppKeys.CLIENT_SESSION] file_store: FileStore = app['file_store'] batch_format_version = 
BatchFormatVersion(record['format_version']) @@ -774,11 +834,11 @@ async def _get_attributes(app, record): return spec.get('attributes') records = db.select_and_fetchall( - ''' + """ SELECT `key`, `value` FROM job_attributes WHERE batch_id = %s AND job_id = %s; -''', +""", (batch_id, job_id), query_name='get_attributes', ) @@ -808,7 +868,7 @@ async def _get_full_job_spec(app, record): async def _get_full_job_status(app, record): - client_session: httpx.ClientSession = app['client_session'] + client_session = app[CommonAiohttpAppKeys.CLIENT_SESSION] file_store: FileStore = app['file_store'] batch_id = record['batch_id'] @@ -940,6 +1000,91 @@ async def get_batches_v2(request, userdata): # pylint: disable=unused-argument return json_response({'batches': batches}) +async def _query_job_groups( + request, batch_id: int, job_group_id: int, last_child_job_group_id: Optional[int] +) -> Tuple[List[GetJobGroupResponseV1Alpha], Optional[int]]: + db: Database = request.app['db'] + + @transaction(db) + async def _query(tx): + is_root_job_group = job_group_id == ROOT_JOB_GROUP_ID + record = await tx.execute_and_fetchone( + """ +SELECT 1 +FROM job_groups +LEFT JOIN batches ON batches.id = job_groups.batch_id +LEFT JOIN batch_updates + ON job_groups.batch_id = batch_updates.batch_id AND job_groups.update_id = batch_updates.update_id +WHERE job_groups.batch_id = %s AND job_groups.job_group_id = %s AND NOT deleted AND (batch_updates.committed OR %s); +""", + (batch_id, job_group_id, is_root_job_group), + ) + if not record: + raise NonExistentJobGroupError(batch_id, job_group_id) + + sql, sql_args = parse_list_job_groups_query_v1(batch_id, job_group_id, last_child_job_group_id) + job_groups = [job_group_record_to_dict(record) async for record in tx.execute_and_fetchall(sql, sql_args)] + + if len(job_groups) == 51: + job_groups.pop() + new_last_child_job_group_id = job_groups[-1]['job_group_id'] + else: + new_last_child_job_group_id = None + + return (job_groups, new_last_child_job_group_id) + + return await _query() + + +async def _api_get_job_groups_v1(request: web.Request, batch_id: int, job_group_id: int): + last_child_job_group_id = cast_query_param_to_int(request.query.get('last_job_group_id')) + result = await _handle_api_error(_query_job_groups, request, batch_id, job_group_id, last_child_job_group_id) + assert result is not None + job_groups, last_child_job_group_id = result + if last_child_job_group_id is not None: + return json_response({'job_groups': job_groups, 'last_job_group_id': last_child_job_group_id}) + return json_response({'job_groups': job_groups}) + + +@routes.get('/api/v1alpha/batches/{batch_id}/job-groups') +@billing_project_users_only() +@add_metadata_to_request +async def get_root_job_groups_v1(request: web.Request, _, batch_id: int): # pylint: disable=unused-argument + return await _api_get_job_groups_v1(request, batch_id, ROOT_JOB_GROUP_ID) + + +@routes.get('/api/v1alpha/batches/{batch_id}/job-groups/{job_group_id}/job-groups') +@billing_project_users_only() +@add_metadata_to_request +async def get_job_groups_v1(request: web.Request, _, batch_id: int): # pylint: disable=unused-argument + job_group_id = int(request.match_info['job_group_id']) + return await _api_get_job_groups_v1(request, batch_id, job_group_id) + + +@routes.post('/api/v1alpha/batches/{batch_id}/updates/{update_id}/job-groups/create') +@auth.authenticated_users_only() +@add_metadata_to_request +async def create_job_groups(request: web.Request, userdata: UserData) -> web.Response: + app = request.app + db: Database 
= app['db'] + user = userdata['username'] + + if app['frozen']: + log.info('ignoring batch job group create request; batch is frozen') + raise web.HTTPServiceUnavailable() + + batch_id = int(request.match_info['batch_id']) + update_id = int(request.match_info['update_id']) + job_group_specs = await json_request(request) + try: + validate_job_groups(job_group_specs) + except ValidationError as e: + raise web.HTTPBadRequest(reason=e.reason) + + await _create_job_groups(db, batch_id, update_id, user, job_group_specs) + return web.Response() + + def check_service_account_permissions(user, sa): if sa is None: return @@ -960,6 +1105,11 @@ async def create_jobs(request: web.Request, userdata: UserData) -> web.Response: app = request.app batch_id = int(request.match_info['batch_id']) job_specs = await json_request(request) + try: + validate_and_clean_jobs(job_specs) + except ValidationError as e: + raise web.HTTPBadRequest(reason=e.reason) + return await _create_jobs(userdata, job_specs, batch_id, 1, app) @@ -976,6 +1126,11 @@ async def create_jobs_for_update(request: web.Request, userdata: UserData) -> we batch_id = int(request.match_info['batch_id']) update_id = int(request.match_info['update_id']) job_specs = await json_request(request) + try: + validate_and_clean_jobs(job_specs) + except ValidationError as e: + raise web.HTTPBadRequest(reason=e.reason) + return await _create_jobs(userdata, job_specs, batch_id, update_id, app) @@ -987,9 +1142,180 @@ def assert_is_sha_1_hex_string(revision: str): raise web.HTTPBadRequest(reason=f'revision must be 40 character hexadecimal encoded SHA-1, got: {revision}') +async def _create_job_group( + tx: Transaction, + *, + batch_id: int, + job_group_id: int, + update_id: Optional[int], + user: str, + attributes: Optional[Dict[str, str]], + cancel_after_n_failures: Optional[int], + callback: Optional[str], + timestamp: int, + parent_job_group_id: int, +): + cancelled_parent = await tx.execute_and_fetchone( + """ +SELECT 1 AS cancelled +FROM job_group_self_and_ancestors +INNER JOIN job_groups_cancelled + ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id +WHERE job_group_self_and_ancestors.batch_id = %s AND job_group_self_and_ancestors.job_group_id = %s; +""", + (batch_id, parent_job_group_id), + ) + if cancelled_parent is not None: + raise web.HTTPBadRequest(reason='job group parent has already been cancelled') + + await tx.execute_insertone( + """ +INSERT INTO job_groups (batch_id, job_group_id, `user`, attributes, cancel_after_n_failures, state, n_jobs, time_created, time_completed, callback, update_id) +VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); +""", + ( + batch_id, + job_group_id, + user, + json.dumps(attributes), + cancel_after_n_failures, + 'complete', + 0, + timestamp, + timestamp, + callback, + update_id, + ), + query_name='insert_job_group', + ) + + if job_group_id != ROOT_JOB_GROUP_ID: + assert parent_job_group_id < job_group_id + + n_rows_inserted = await tx.execute_update( + """ +INSERT INTO job_group_self_and_ancestors (batch_id, job_group_id, ancestor_id, level) +SELECT batch_id, %s, ancestor_id, ancestors.level + 1 +FROM job_group_self_and_ancestors ancestors +WHERE batch_id = %s AND job_group_id = %s; +""", + (job_group_id, batch_id, parent_job_group_id), + query_name='insert_job_group_ancestors', + ) + + if n_rows_inserted > MAX_JOB_GROUPS_DEPTH: + raise web.HTTPBadRequest(reason='job group exceeded the maximum level of nesting') + + 
await tx.execute_insertone( + """ +INSERT INTO job_group_self_and_ancestors (batch_id, job_group_id, ancestor_id, level) +VALUES (%s, %s, %s, %s); +""", + (batch_id, job_group_id, job_group_id, 0), + query_name='insert_job_group_self', + ) + + await tx.execute_insertone( + """ +INSERT INTO job_groups_n_jobs_in_complete_states (id, job_group_id) +VALUES (%s, %s); +""", + (batch_id, job_group_id), + query_name='insert_job_groups_n_jobs_in_complete_states', + ) + + if attributes: + await tx.execute_many( + """ +INSERT INTO job_group_attributes (batch_id, job_group_id, `key`, `value`) +VALUES (%s, %s, %s, %s); +""", + [(batch_id, job_group_id, k, v) for k, v in attributes.items()], + query_name='insert_job_group_attributes', + ) + + +async def _create_job_groups(db: Database, batch_id: int, update_id: int, user: str, job_group_specs: List[dict]): + assert len(job_group_specs) > 0 + + @transaction(db) + async def insert(tx): + record = await tx.execute_and_fetchone( + """ +SELECT `state`, format_version, `committed`, start_job_group_id +FROM batch_updates +INNER JOIN batches ON batch_updates.batch_id = batches.id +WHERE batch_updates.batch_id = %s AND batch_updates.update_id = %s AND `user` = %s AND NOT deleted +LOCK IN SHARE MODE; +""", + (batch_id, update_id, user), + ) + + if not record: + raise web.HTTPNotFound() + if record['committed']: + raise web.HTTPBadRequest(reason=f'update {update_id} is already committed') + + start_job_group_id = record['start_job_group_id'] + + last_inserted_job_group_id = await tx.execute_and_fetchone( + """ +SELECT job_group_id +FROM job_groups +WHERE batch_id = %s +ORDER BY job_group_id DESC +LIMIT 1 +FOR UPDATE; +""", + (batch_id,), + ) + + next_job_group_id = start_job_group_id + job_group_specs[0]['job_group_id'] - 1 + if next_job_group_id != last_inserted_job_group_id['job_group_id'] + 1: + raise web.HTTPBadRequest(reason='job group specs were not submitted in order') + + now = time_msecs() + + for spec in job_group_specs: + job_group_id = start_job_group_id + spec['job_group_id'] - 1 + + if 'absolute_parent_id' in spec: + parent_job_group_id = spec['absolute_parent_id'] + else: + assert 'in_update_parent_id' in spec + parent_job_group_id = start_job_group_id + spec['in_update_parent_id'] - 1 + + try: + await _create_job_group( + tx, + batch_id=batch_id, + job_group_id=job_group_id, + update_id=update_id, + user=user, + attributes=spec.get('attributes'), + cancel_after_n_failures=spec.get('cancel_after_n_failures'), + callback=spec.get('callback'), + timestamp=now, + parent_job_group_id=parent_job_group_id, + ) + except asyncio.CancelledError: + raise + except Exception as e: + raise web.HTTPBadRequest( + reason=f'error while inserting job group {spec["job_group_id"]} into batch {batch_id}: {e}' + ) + + await insert() + + return web.Response() + + async def _create_jobs( userdata, job_specs: List[Dict[str, Any]], batch_id: int, update_id: int, app: web.Application ) -> web.Response: + assert len(job_specs) > 0 + db: Database = app['db'] file_store: FileStore = app['file_store'] user = userdata['username'] @@ -1002,12 +1328,12 @@ async def _create_jobs( } record = await db.select_and_fetchone( - ''' -SELECT `state`, format_version, `committed`, start_job_id + """ +SELECT `state`, format_version, `committed`, start_job_id, start_job_group_id FROM batch_updates INNER JOIN batches ON batch_updates.batch_id = batches.id WHERE batch_updates.batch_id = %s AND batch_updates.update_id = %s AND user = %s AND NOT deleted; -''', +""", (batch_id, update_id, user), ) @@ 
-1015,13 +1341,10 @@ async def _create_jobs( raise web.HTTPNotFound() if record['committed']: raise web.HTTPBadRequest(reason=f'update {update_id} is already committed') + batch_format_version = BatchFormatVersion(record['format_version']) update_start_job_id = int(record['start_job_id']) - - try: - validate_and_clean_jobs(job_specs) - except ValidationError as e: - raise web.HTTPBadRequest(reason=e.reason) + update_start_job_group_id = int(record['start_job_group_id']) spec_writer = SpecWriter(file_store, batch_id) @@ -1030,7 +1353,7 @@ async def _create_jobs( job_attributes_args = [] jobs_telemetry_args = [] - inst_coll_resources: Dict[str, Dict[str, int]] = collections.defaultdict( + inst_coll_resources: Dict[Tuple[int, str], Dict[str, int]] = collections.defaultdict( lambda: { 'n_jobs': 0, 'n_ready_jobs': 0, @@ -1040,7 +1363,6 @@ async def _create_jobs( } ) - prev_job_idx = None bunch_start_job_id = None for spec in job_specs: @@ -1051,6 +1373,15 @@ async def _create_jobs( in_update_parent_ids = spec.pop('in_update_parent_ids', []) parent_ids = absolute_parent_ids + [update_start_job_id + parent_id - 1 for parent_id in in_update_parent_ids] + absolute_job_group_id = spec.pop('absolute_job_group_id', None) + in_update_job_group_id = spec.pop('in_update_job_group_id', None) + if absolute_job_group_id is not None: + job_group_id = absolute_job_group_id + else: + assert in_update_job_group_id is not None + job_group_id = update_start_job_group_id + in_update_job_group_id - 1 + spec['job_group_id'] = job_group_id + always_run = spec.pop('always_run', False) cloud = spec.get('cloud', CLOUD) @@ -1065,11 +1396,6 @@ async def _create_jobs( if bunch_start_job_id is None: bunch_start_job_id = job_id - if batch_format_version.has_full_spec_in_cloud() and prev_job_idx: - if job_id != prev_job_idx + 1: - raise web.HTTPBadRequest(reason=f'noncontiguous job ids found in the spec: {prev_job_idx} -> {job_id}') - prev_job_idx = job_id - resources = spec.get('resources') if not resources: resources = {} @@ -1214,26 +1540,24 @@ async def _create_jobs( spec['secrets'] = secrets - secrets.append( - { - 'namespace': DEFAULT_NAMESPACE, - 'name': userdata['hail_credentials_secret_name'], - 'mount_path': '/gsa-key', - 'mount_in_copy': True, - } - ) - env = spec.get('env') if not env: env = [] spec['env'] = env assert isinstance(spec['env'], list) - if cloud == 'gcp' and all(envvar['name'] != 'GOOGLE_APPLICATION_CREDENTIALS' for envvar in spec['env']): - spec['env'].append({'name': 'GOOGLE_APPLICATION_CREDENTIALS', 'value': '/gsa-key/key.json'}) + if not os.environ.get('HAIL_TERRA'): + secrets.append({ + 'namespace': DEFAULT_NAMESPACE, + 'name': userdata['hail_credentials_secret_name'], + 'mount_path': '/gsa-key', + 'mount_in_copy': True, + }) + if cloud == 'gcp' and all(envvar['name'] != 'GOOGLE_APPLICATION_CREDENTIALS' for envvar in spec['env']): + spec['env'].append({'name': 'GOOGLE_APPLICATION_CREDENTIALS', 'value': '/gsa-key/key.json'}) - if cloud == 'azure' and all(envvar['name'] != 'AZURE_APPLICATION_CREDENTIALS' for envvar in spec['env']): - spec['env'].append({'name': 'AZURE_APPLICATION_CREDENTIALS', 'value': '/gsa-key/key.json'}) + if cloud == 'azure' and all(envvar['name'] != 'AZURE_APPLICATION_CREDENTIALS' for envvar in spec['env']): + spec['env'].append({'name': 'AZURE_APPLICATION_CREDENTIALS', 'value': '/gsa-key/key.json'}) cloudfuse = spec.get('gcsfuse') or spec.get('cloudfuse') if cloudfuse: @@ -1246,27 +1570,25 @@ async def _create_jobs( ) if spec.get('mount_tokens', False): - secrets.append( - { 
- 'namespace': DEFAULT_NAMESPACE, - 'name': userdata['tokens_secret_name'], - 'mount_path': '/user-tokens', - 'mount_in_copy': False, - } - ) - secrets.append( - { - 'namespace': DEFAULT_NAMESPACE, - 'name': 'ssl-config-batch-user-code', - 'mount_path': '/ssl-config', - 'mount_in_copy': False, - } - ) + # Clients stopped using `mount_tokens` prior to the introduction of terra deployments + assert not os.environ.get('HAIL_TERRA', False) + secrets.append({ + 'namespace': DEFAULT_NAMESPACE, + 'name': userdata['tokens_secret_name'], + 'mount_path': '/user-tokens', + 'mount_in_copy': False, + }) + secrets.append({ + 'namespace': DEFAULT_NAMESPACE, + 'name': 'ssl-config-batch-user-code', + 'mount_path': '/ssl-config', + 'mount_in_copy': False, + }) sa = spec.get('service_account') check_service_account_permissions(user, sa) - icr = inst_coll_resources[inst_coll_name] + icr = inst_coll_resources[(job_group_id, inst_coll_name)] icr['n_jobs'] += 1 # jobs in non-initial updates of a batch always start out as pending @@ -1296,22 +1618,20 @@ async def _create_jobs( spec_writer.add(json.dumps(spec)) db_spec = batch_format_version.db_spec(spec) - jobs_args.append( - ( - batch_id, - job_id, - update_id, - ROOT_JOB_GROUP_ID, - state, - json.dumps(db_spec), - always_run, - cores_mcpu, - len(parent_ids), - inst_coll_name, - n_regions, - regions_bits_rep, - ) - ) + jobs_args.append(( + batch_id, + job_id, + update_id, + job_group_id, + state, + json.dumps(db_spec), + always_run, + cores_mcpu, + len(parent_ids), + inst_coll_name, + n_regions, + regions_bits_rep, + )) jobs_telemetry_args.append((batch_id, job_id, time_ready)) @@ -1330,123 +1650,127 @@ async def write_spec_to_cloud(): async def insert_jobs_into_db(tx): try: - try: - await tx.execute_many( - ''' + await tx.execute_many( + """ INSERT INTO jobs (batch_id, job_id, update_id, job_group_id, state, spec, always_run, cores_mcpu, n_pending_parents, inst_coll, n_regions, regions_bits_rep) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); -''', - jobs_args, - query_name='insert_jobs', +""", + jobs_args, + query_name='insert_jobs', + ) + except pymysql.err.IntegrityError as err: + # 1062 ER_DUP_ENTRY https://dev.mysql.com/doc/refman/5.7/en/server-error-reference.html#error_er_dup_entry + if err.args[0] == 1062: + log.info(f'bunch containing job {(batch_id, jobs_args[0][1])} already inserted') + return + raise + except pymysql.err.OperationalError as err: + if err.args[0] == 1644 and err.args[1] == 'job group has already been cancelled': + raise web.HTTPBadRequest( + text=f'bunch contains job where the job group has already been cancelled ({(batch_id, jobs_args[0][1])})' ) - except pymysql.err.IntegrityError as err: - # 1062 ER_DUP_ENTRY https://dev.mysql.com/doc/refman/5.7/en/server-error-reference.html#error_er_dup_entry - if err.args[0] == 1062: - log.info(f'bunch containing job {(batch_id, jobs_args[0][1])} already inserted') - return - raise - try: - await tx.execute_many( - ''' + raise + + try: + await tx.execute_many( + """ INSERT INTO `job_parents` (batch_id, job_id, parent_id) VALUES (%s, %s, %s); -''', - job_parents_args, - query_name='insert_job_parents', - ) - except pymysql.err.IntegrityError as err: - # 1062 ER_DUP_ENTRY https://dev.mysql.com/doc/refman/5.7/en/server-error-reference.html#error_er_dup_entry - if err.args[0] == 1062: - raise web.HTTPBadRequest(text=f'bunch contains job with duplicated parents ({job_parents_args})') - raise +""", + job_parents_args, + query_name='insert_job_parents', + ) + except pymysql.err.IntegrityError 
as err: + # 1062 ER_DUP_ENTRY https://dev.mysql.com/doc/refman/5.7/en/server-error-reference.html#error_er_dup_entry + if err.args[0] == 1062: + raise web.HTTPBadRequest(text=f'bunch contains job with duplicated parents ({job_parents_args})') + raise - await tx.execute_many( - ''' + await tx.execute_many( + """ INSERT INTO `job_attributes` (batch_id, job_id, `key`, `value`) VALUES (%s, %s, %s, %s); -''', - job_attributes_args, - query_name='insert_job_attributes', - ) +""", + job_attributes_args, + query_name='insert_job_attributes', + ) - await tx.execute_many( - ''' + await tx.execute_many( + """ INSERT INTO jobs_telemetry (batch_id, job_id, time_ready) VALUES (%s, %s, %s); -''', - jobs_telemetry_args, - query_name='insert_jobs_telemetry', - ) +""", + jobs_telemetry_args, + query_name='insert_jobs_telemetry', + ) - job_groups_inst_coll_staging_args = [ - ( - batch_id, - update_id, - ROOT_JOB_GROUP_ID, - inst_coll, - rand_token, - resources['n_jobs'], - resources['n_ready_jobs'], - resources['ready_cores_mcpu'], - ) - for inst_coll, resources in inst_coll_resources.items() - ] - await tx.execute_many( - ''' + job_groups_inst_coll_staging_args = [ + ( + batch_id, + update_id, + inst_coll, + rand_token, + resources['n_jobs'], + resources['n_ready_jobs'], + resources['ready_cores_mcpu'], + batch_id, + icr_job_group_id, + ) + for (icr_job_group_id, inst_coll), resources in inst_coll_resources.items() + ] + # job_groups_inst_coll_staging tracks the num of resources recursively for all children job groups + await tx.execute_many( + """ INSERT INTO job_groups_inst_coll_staging (batch_id, update_id, job_group_id, inst_coll, token, n_jobs, n_ready_jobs, ready_cores_mcpu) -VALUES (%s, %s, %s, %s, %s, %s, %s, %s) +SELECT %s, %s, ancestor_id, %s, %s, %s, %s, %s +FROM job_group_self_and_ancestors +WHERE batch_id = %s AND job_group_id = %s ON DUPLICATE KEY UPDATE - n_jobs = n_jobs + VALUES(n_jobs), - n_ready_jobs = n_ready_jobs + VALUES(n_ready_jobs), - ready_cores_mcpu = ready_cores_mcpu + VALUES(ready_cores_mcpu); -''', - job_groups_inst_coll_staging_args, - query_name='insert_job_groups_inst_coll_staging', - ) +n_jobs = n_jobs + VALUES(n_jobs), +n_ready_jobs = n_ready_jobs + VALUES(n_ready_jobs), +ready_cores_mcpu = ready_cores_mcpu + VALUES(ready_cores_mcpu); +""", + job_groups_inst_coll_staging_args, + query_name='insert_job_groups_inst_coll_staging', + ) - job_group_inst_coll_cancellable_resources_args = [ - ( - batch_id, - update_id, - ROOT_JOB_GROUP_ID, - inst_coll, - rand_token, - resources['n_ready_cancellable_jobs'], - resources['ready_cancellable_cores_mcpu'], - ) - for inst_coll, resources in inst_coll_resources.items() - ] - await tx.execute_many( - ''' + job_group_inst_coll_cancellable_resources_args = [ + ( + batch_id, + update_id, + inst_coll, + rand_token, + resources['n_ready_cancellable_jobs'], + resources['ready_cancellable_cores_mcpu'], + batch_id, + icr_job_group_id, + ) + for (icr_job_group_id, inst_coll), resources in inst_coll_resources.items() + ] + # job_group_inst_coll_cancellable_resources tracks the num of resources recursively for all children job groups + await tx.execute_many( + """ INSERT INTO job_group_inst_coll_cancellable_resources (batch_id, update_id, job_group_id, inst_coll, token, n_ready_cancellable_jobs, ready_cancellable_cores_mcpu) -VALUES (%s, %s, %s, %s, %s, %s, %s) +SELECT %s, %s, ancestor_id, %s, %s, %s, %s +FROM job_group_self_and_ancestors +WHERE batch_id = %s AND job_group_id = %s ON DUPLICATE KEY UPDATE - n_ready_cancellable_jobs = 
n_ready_cancellable_jobs + VALUES(n_ready_cancellable_jobs), - ready_cancellable_cores_mcpu = ready_cancellable_cores_mcpu + VALUES(ready_cancellable_cores_mcpu); -''', - job_group_inst_coll_cancellable_resources_args, - query_name='insert_inst_coll_cancellable_resources', - ) +n_ready_cancellable_jobs = n_ready_cancellable_jobs + VALUES(n_ready_cancellable_jobs), +ready_cancellable_cores_mcpu = ready_cancellable_cores_mcpu + VALUES(ready_cancellable_cores_mcpu); +""", + job_group_inst_coll_cancellable_resources_args, + query_name='insert_inst_coll_cancellable_resources', + ) - if batch_format_version.has_full_spec_in_cloud(): - await tx.execute_update( - ''' + if batch_format_version.has_full_spec_in_cloud(): + await tx.execute_update( + """ INSERT INTO batch_bunches (batch_id, token, start_job_id) VALUES (%s, %s, %s); -''', - (batch_id, spec_writer.token, bunch_start_job_id), - query_name='insert_batch_bunches', - ) - except asyncio.CancelledError: - raise - except web.HTTPException: - raise - except Exception as err: - raise ValueError( - f'encountered exception while inserting a bunch' - f'jobs_args={json.dumps(jobs_args)}' - f'job_parents_args={json.dumps(job_parents_args)}' - ) from err +""", + (batch_id, spec_writer.token, bunch_start_job_id), + query_name='insert_batch_bunches', + ) @transaction(db) async def write_and_insert(tx): @@ -1454,7 +1778,18 @@ async def write_and_insert(tx): # must rollback. See https://github.com/hail-is/hail-production-issues/issues/9 await asyncio.gather(write_spec_to_cloud(), insert_jobs_into_db(tx)) - await write_and_insert() + try: + await write_and_insert() + except asyncio.CancelledError: + raise + except web.HTTPException: + raise + except Exception as err: + raise ValueError( + f'encountered exception while inserting a bunch' + f'jobs_args={json.dumps(jobs_args)}' + f'job_parents_args={json.dumps(job_parents_args)}' + ) from err return web.Response() @@ -1469,18 +1804,42 @@ async def create_batch_fast(request, userdata): user = userdata['username'] batch_and_bunch = await json_request(request) batch_spec = batch_and_bunch['batch'] - bunch = batch_and_bunch['bunch'] - batch_id = await _create_batch(batch_spec, userdata, db) - update_id, _ = await _create_batch_update(batch_id, batch_spec['token'], batch_spec['n_jobs'], user, db) + jobs = batch_and_bunch['bunch'] + job_groups = batch_and_bunch.get('job_groups', []) + try: - await _create_jobs(userdata, bunch, batch_id, update_id, app) - except web.HTTPBadRequest as e: - if f'update {update_id} is already committed' == e.reason: - return json_response({'id': batch_id}) - raise + validate_batch(batch_spec) + validate_and_clean_jobs(jobs) + validate_job_groups(job_groups) + except ValidationError as e: + raise web.HTTPBadRequest(reason=e.reason) + + batch_id = await _create_batch(batch_spec, userdata, db) + + update_id, start_job_group_id, start_job_id = await _create_batch_update( + batch_id, batch_spec['token'], batch_spec['n_jobs'], batch_spec.get('n_job_groups', 0), user, db + ) + + if len(job_groups) > 0: + try: + await _create_job_groups(db, batch_id, update_id, user, job_groups) + except web.HTTPBadRequest as e: + if f'update {update_id} is already committed' == e.reason: + return json_response({'id': batch_id}) + raise + + if len(jobs) > 0: + try: + await _create_jobs(userdata, jobs, batch_id, update_id, app) + except web.HTTPBadRequest as e: + if f'update {update_id} is already committed' == e.reason: + return json_response({'id': batch_id}) + raise + await _commit_update(app, batch_id, 
update_id, user, db) + request['batch_telemetry']['batch_id'] = str(batch_id) - return json_response({'id': batch_id}) + return json_response({'id': batch_id, 'start_job_group_id': start_job_group_id, 'start_job_id': start_job_id}) @routes.post('/api/v1alpha/batches/create') @@ -1491,24 +1850,33 @@ async def create_batch(request, userdata): db: Database = app['db'] batch_spec = await json_request(request) + try: + validate_batch(batch_spec) + except ValidationError as e: + raise web.HTTPBadRequest(reason=e.reason) + id = await _create_batch(batch_spec, userdata, db) n_jobs = batch_spec['n_jobs'] - if n_jobs > 0: - update_id, _ = await _create_batch_update( - id, batch_spec['token'], batch_spec['n_jobs'], userdata['username'], db + n_job_groups = batch_spec.get('n_job_groups', 0) + if n_jobs > 0 or n_job_groups > 0: + update_id, start_job_group_id, start_job_id = await _create_batch_update( + id, batch_spec['token'], n_jobs, n_job_groups, userdata['username'], db ) else: update_id = None + start_job_group_id = None + start_job_id = None + request['batch_telemetry']['batch_id'] = str(id) - return json_response({'id': id, 'update_id': update_id}) + return json_response({ + 'id': id, + 'update_id': update_id, + 'start_job_group_id': start_job_group_id, + 'start_job_id': start_job_id, + }) async def _create_batch(batch_spec: dict, userdata, db: Database) -> int: - try: - validate_batch(batch_spec) - except ValidationError as e: - raise web.HTTPBadRequest(reason=e.reason) - user = userdata['username'] # restrict to what's necessary; in particular, drop the session @@ -1527,13 +1895,13 @@ async def _create_batch(batch_spec: dict, userdata, db: Database) -> int: @transaction(db) async def insert(tx): bp = await tx.execute_and_fetchone( - ''' + """ SELECT billing_projects.status, billing_projects.limit FROM billing_project_users INNER JOIN billing_projects ON billing_projects.name = billing_project_users.billing_project WHERE billing_projects.name_cs = %s AND user_cs = %s -LOCK IN SHARE MODE''', +LOCK IN SHARE MODE""", (billing_project, user), ) @@ -1543,7 +1911,7 @@ async def insert(tx): raise web.HTTPForbidden(reason=f'Billing project {billing_project} is closed or deleted.') bp_cost_record = await tx.execute_and_fetchone( - ''' + """ SELECT COALESCE(SUM(t.`usage` * rate), 0) AS cost FROM ( SELECT resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` @@ -1552,7 +1920,7 @@ async def insert(tx): GROUP BY resource_id ) AS t LEFT JOIN resources on resources.resource_id = t.resource_id; -''', +""", (billing_project,), ) limit = bp['limit'] @@ -1563,10 +1931,10 @@ async def insert(tx): ) maybe_batch = await tx.execute_and_fetchone( - ''' + """ SELECT * FROM batches WHERE token = %s AND user = %s FOR UPDATE; -''', +""", (token, user), ) @@ -1575,10 +1943,10 @@ async def insert(tx): now = time_msecs() id = await tx.execute_insertone( - ''' + """ INSERT INTO batches (userdata, user, billing_project, attributes, callback, n_jobs, time_created, time_completed, token, state, format_version, cancel_after_n_failures, migrated_batch) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s); -''', +""", ( json.dumps(userdata), user, @@ -1597,57 +1965,19 @@ async def insert(tx): query_name='insert_batches', ) - await tx.execute_insertone( - ''' -INSERT INTO job_groups (batch_id, job_group_id, `user`, attributes, cancel_after_n_failures, state, n_jobs, time_created, time_completed, callback) -VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s); -''', - ( - id, - ROOT_JOB_GROUP_ID, - user, - 
json.dumps(attributes), - batch_spec.get('cancel_after_n_failures'), - 'complete', - 0, - now, - now, - batch_spec.get('callback'), - ), - query_name='insert_job_group', + await _create_job_group( + tx, + batch_id=id, + job_group_id=ROOT_JOB_GROUP_ID, + update_id=None, + user=user, + attributes=attributes, + cancel_after_n_failures=batch_spec.get('cancel_after_n_failures'), + callback=batch_spec.get('callback'), + timestamp=now, + parent_job_group_id=ROOT_JOB_GROUP_ID, ) - await tx.execute_insertone( - ''' -INSERT INTO job_group_self_and_ancestors (batch_id, job_group_id, ancestor_id, level) -VALUES (%s, %s, %s, %s); -''', - ( - id, - ROOT_JOB_GROUP_ID, - ROOT_JOB_GROUP_ID, - 0, - ), - query_name='insert_job_group_parent', - ) - - await tx.execute_insertone( - ''' -INSERT INTO job_groups_n_jobs_in_complete_states (id, job_group_id) VALUES (%s, %s); -''', - (id, ROOT_JOB_GROUP_ID), - query_name='insert_job_groups_n_jobs_in_complete_states', - ) - - if attributes: - await tx.execute_many( - ''' -INSERT INTO `job_group_attributes` (batch_id, job_group_id, `key`, `value`) -VALUES (%s, %s, %s, %s) -''', - [(id, ROOT_JOB_GROUP_ID, k, v) for k, v in attributes.items()], - query_name='insert_job_group_attributes', - ) return id return await insert() @@ -1664,26 +1994,53 @@ async def update_batch_fast(request, userdata): user = userdata['username'] update_and_bunch = await json_request(request) update_spec = update_and_bunch['update'] - bunch = update_and_bunch['bunch'] + jobs = update_and_bunch['bunch'] + job_groups = update_and_bunch.get('job_groups', []) try: validate_batch_update(update_spec) + validate_and_clean_jobs(jobs) + validate_job_groups(job_groups) except ValidationError as e: raise web.HTTPBadRequest(reason=e.reason) - update_id, start_job_id = await _create_batch_update( - batch_id, update_spec['token'], update_spec['n_jobs'], user, db + update_id, start_job_group_id, start_job_id = await _create_batch_update( + batch_id, update_spec['token'], update_spec['n_jobs'], update_spec.get('n_job_groups', 0), user, db ) - try: - await _create_jobs(userdata, bunch, batch_id, update_id, app) - except web.HTTPBadRequest as e: - if f'update {update_id} is already committed' == e.reason: - return json_response({'update_id': update_id, 'start_job_id': start_job_id}) - raise + if len(job_groups) > 0: + try: + await _create_job_groups(db, batch_id, update_id, user, job_groups) + except web.HTTPBadRequest as e: + if f'update {update_id} is already committed' == e.reason: + return json_response({ + 'update_id': update_id, + 'start_job_group_id': start_job_group_id, + 'start_job_id': start_job_id, + }) + raise + + if len(jobs) > 0: + try: + await _create_jobs(userdata, jobs, batch_id, update_id, app) + except web.HTTPBadRequest as e: + if f'update {update_id} is already committed' == e.reason: + return json_response({ + 'update_id': update_id, + 'start_job_id': start_job_id, + 'start_job_group_id': start_job_group_id, + }) + raise + await _commit_update(app, batch_id, update_id, user, db) + request['batch_telemetry']['batch_id'] = str(batch_id) - return json_response({'update_id': update_id, 'start_job_id': start_job_id}) + + return json_response({ + 'update_id': update_id, + 'start_job_id': start_job_id, + 'start_job_group_id': start_job_group_id, + }) @routes.post('/api/v1alpha/batches/{batch_id}/updates/create') @@ -1706,75 +2063,105 @@ async def create_update(request, userdata): except ValidationError as e: raise web.HTTPBadRequest(reason=e.reason) - update_id, _ = await 
_create_batch_update(batch_id, update_spec['token'], update_spec['n_jobs'], user, db) - return json_response({'update_id': update_id}) + n_jobs = update_spec['n_jobs'] + n_job_groups = update_spec.get('n_job_groups', 0) + + update_id, start_job_group_id, start_job_id = await _create_batch_update( + batch_id, update_spec['token'], n_jobs, n_job_groups, user, db + ) + return json_response({ + 'update_id': update_id, + 'start_job_group_id': start_job_group_id, + 'start_job_id': start_job_id, + }) async def _create_batch_update( - batch_id: int, update_token: str, n_jobs: int, user: str, db: Database -) -> Tuple[int, int]: + batch_id: int, update_token: str, n_jobs: int, n_job_groups: int, user: str, db: Database +) -> Tuple[int, int, int]: @transaction(db) async def update(tx: Transaction): - assert n_jobs > 0 + assert n_jobs > 0 or n_job_groups > 0 record = await tx.execute_and_fetchone( - ''' -SELECT update_id, start_job_id FROM batch_updates -WHERE batch_id = %s AND token = %s; -''', + """ +SELECT update_id, start_job_id, start_job_group_id +FROM batch_updates +WHERE batch_id = %s AND token = %s +FOR UPDATE; +""", (batch_id, update_token), ) if record: - return record['update_id'], record['start_job_id'] + return (record['update_id'], record['start_job_id'], record['start_job_group_id']) # We use FOR UPDATE so that we serialize batch update insertions - # This is necessary to reserve job id ranges. + # This is necessary to reserve job id and job group id ranges. # We don't allow updates to batches that have been cancelled # but do allow updates to batches with jobs that have been cancelled. record = await tx.execute_and_fetchone( - ''' -SELECT job_groups_cancelled.id IS NOT NULL AS cancelled + """ +SELECT cancelled_t.cancelled IS NOT NULL AS cancelled FROM batches -LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id -WHERE batches.id = %s AND user = %s AND NOT deleted +LEFT JOIN ( + SELECT id, 1 AS cancelled + FROM job_groups_cancelled + WHERE id = %s AND job_group_id = %s +) AS cancelled_t ON batches.id = cancelled_t.id +WHERE batches.id = %s AND batches.user = %s AND NOT deleted FOR UPDATE; -''', - (batch_id, user), +""", + (batch_id, ROOT_JOB_GROUP_ID, batch_id, user), ) if not record: raise web.HTTPNotFound() if record['cancelled']: - raise web.HTTPBadRequest(reason='Cannot submit new jobs to a cancelled batch') + raise web.HTTPBadRequest(reason='Cannot submit new jobs or job groups to a cancelled batch') now = time_msecs() record = await tx.execute_and_fetchone( - ''' -SELECT update_id, start_job_id, n_jobs FROM batch_updates + """ +SELECT update_id, start_job_id, n_jobs, start_job_group_id, n_job_groups +FROM batch_updates WHERE batch_id = %s ORDER BY update_id DESC -LIMIT 1; -''', +LIMIT 1 +FOR UPDATE; +""", (batch_id,), ) - if record: + + if record is not None: update_id = int(record['update_id']) + 1 + update_start_job_group_id = int(record['start_job_group_id']) + int(record['n_job_groups']) update_start_job_id = int(record['start_job_id']) + int(record['n_jobs']) else: update_id = 1 + update_start_job_group_id = 1 update_start_job_id = 1 await tx.execute_insertone( - ''' + """ INSERT INTO batch_updates -(batch_id, update_id, token, start_job_id, n_jobs, committed, time_created) -VALUES (%s, %s, %s, %s, %s, %s, %s); -''', - (batch_id, update_id, update_token, update_start_job_id, n_jobs, False, now), +(batch_id, update_id, token, start_job_group_id, n_job_groups, start_job_id, n_jobs, committed, time_created) +VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s); +""", 
+ ( + batch_id, + update_id, + update_token, + update_start_job_group_id, + n_job_groups, + update_start_job_id, + n_jobs, + False, + now, + ), query_name='insert_batch_update', ) - return (update_id, update_start_job_id) + return (update_id, update_start_job_group_id, update_start_job_id) return await update() @@ -1783,33 +2170,36 @@ async def _get_batch(app, batch_id): db: Database = app['db'] record = await db.select_and_fetchone( - ''' + """ SELECT batches.*, - job_groups_cancelled.id IS NOT NULL AS cancelled, + cancelled_t.cancelled IS NOT NULL AS cancelled, job_groups_n_jobs_in_complete_states.n_completed, job_groups_n_jobs_in_complete_states.n_succeeded, job_groups_n_jobs_in_complete_states.n_failed, job_groups_n_jobs_in_complete_states.n_cancelled, cost_t.* -FROM batches +FROM job_groups +LEFT JOIN batches ON batches.id = job_groups.batch_id LEFT JOIN job_groups_n_jobs_in_complete_states - ON batches.id = job_groups_n_jobs_in_complete_states.id -LEFT JOIN job_groups_cancelled - ON batches.id = job_groups_cancelled.id + ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id +LEFT JOIN ( + SELECT id, 1 AS cancelled + FROM job_groups_cancelled + WHERE id = %s AND job_group_id = %s +) AS cancelled_t ON batches.id = cancelled_t.id LEFT JOIN LATERAL ( SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown FROM ( - SELECT batch_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` + SELECT resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_job_group_resources_v3 - WHERE batches.id = aggregated_job_group_resources_v3.batch_id - GROUP BY batch_id, resource_id + WHERE job_groups.batch_id = aggregated_job_group_resources_v3.batch_id AND job_groups.job_group_id = aggregated_job_group_resources_v3.job_group_id + GROUP BY resource_id ) AS usage_t LEFT JOIN resources ON usage_t.resource_id = resources.resource_id - GROUP BY batch_id ) AS cost_t ON TRUE -WHERE batches.id = %s AND NOT deleted; -''', - (batch_id,), +WHERE job_groups.batch_id = %s AND job_groups.job_group_id = %s AND NOT deleted; +""", + (batch_id, ROOT_JOB_GROUP_ID, batch_id, ROOT_JOB_GROUP_ID), ) if not record: raise web.HTTPNotFound() @@ -1817,8 +2207,57 @@ async def _get_batch(app, batch_id): return batch_record_to_dict(record) -async def _cancel_batch(app, batch_id): - await cancel_batch_in_db(app['db'], batch_id) +async def _get_job_group(app, batch_id: int, job_group_id: int) -> GetJobGroupResponseV1Alpha: + db: Database = app['db'] + + is_root_job_group = job_group_id == ROOT_JOB_GROUP_ID + + record = await db.select_and_fetchone( + """ +SELECT job_groups.*, + cancelled_t.cancelled IS NOT NULL AS cancelled, + job_groups_n_jobs_in_complete_states.n_completed, + job_groups_n_jobs_in_complete_states.n_succeeded, + job_groups_n_jobs_in_complete_states.n_failed, + job_groups_n_jobs_in_complete_states.n_cancelled, + cost_t.* +FROM job_groups +LEFT JOIN batches ON batches.id = job_groups.batch_id +LEFT JOIN batch_updates + ON job_groups.batch_id = batch_updates.batch_id AND job_groups.update_id = batch_updates.update_id +LEFT JOIN job_groups_n_jobs_in_complete_states + ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id +LEFT JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors + INNER JOIN job_groups_cancelled 
+ ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id + WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND + job_groups.job_group_id = job_group_self_and_ancestors.job_group_id +) AS cancelled_t ON TRUE +LEFT JOIN LATERAL ( + SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown + FROM ( + SELECT resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` + FROM aggregated_job_group_resources_v3 + WHERE job_groups.batch_id = aggregated_job_group_resources_v3.batch_id AND job_groups.job_group_id = aggregated_job_group_resources_v3.job_group_id + GROUP BY resource_id + ) AS usage_t + LEFT JOIN resources ON usage_t.resource_id = resources.resource_id +) AS cost_t ON TRUE +WHERE job_groups.batch_id = %s AND job_groups.job_group_id = %s AND NOT deleted AND (batch_updates.committed OR %s); +""", + (batch_id, job_group_id, is_root_job_group), + ) + if not record: + raise web.HTTPNotFound() + + return job_group_record_to_dict(record) + + +async def _cancel_job_group(app, batch_id, job_group_id): + await cancel_job_group_in_db(app['db'], batch_id, job_group_id) app['cancel_batch_state_changed'].set() return web.Response() @@ -1827,16 +2266,16 @@ async def _delete_batch(app, batch_id): db: Database = app['db'] record = await db.select_and_fetchone( - ''' + """ SELECT `state` FROM batches WHERE id = %s AND NOT deleted; -''', +""", (batch_id,), ) if not record: raise web.HTTPNotFound() - await db.just_execute('CALL cancel_batch(%s);', (batch_id,)) + await db.just_execute('CALL cancel_job_group(%s, %s);', (batch_id, ROOT_JOB_GROUP_ID)) await db.execute_update('UPDATE batches SET deleted = 1 WHERE id = %s;', (batch_id,)) if record['state'] == 'running': @@ -1854,7 +2293,24 @@ async def get_batch(request: web.Request, _, batch_id: int) -> web.Response: @billing_project_users_only() @add_metadata_to_request async def cancel_batch(request: web.Request, _, batch_id: int) -> web.Response: - await _handle_api_error(_cancel_batch, request.app, batch_id) + await _handle_api_error(_cancel_job_group, request.app, batch_id, ROOT_JOB_GROUP_ID) + return web.Response() + + +@routes.get('/api/v1alpha/batches/{batch_id}/job-groups/{job_group_id}') +@billing_project_users_only() +@add_metadata_to_request +async def get_job_group(request: web.Request, _, batch_id: int) -> web.Response: + job_group_id = int(request.match_info['job_group_id']) + return json_response(await _get_job_group(request.app, batch_id, job_group_id)) + + +@routes.patch('/api/v1alpha/batches/{batch_id}/job-groups/{job_group_id}/cancel') +@billing_project_users_only() +@add_metadata_to_request +async def cancel_job_group(request: web.Request, _, batch_id: int) -> web.Response: + job_group_id = int(request.match_info['job_group_id']) + await _handle_api_error(_cancel_job_group, request.app, batch_id, job_group_id) return web.Response() @@ -1870,13 +2326,21 @@ async def close_batch(request, userdata): db: Database = app['db'] record = await db.select_and_fetchone( - ''' -SELECT job_groups_cancelled.id IS NOT NULL AS cancelled -FROM batches -LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id -WHERE user = %s AND batches.id = %s AND NOT deleted; -''', - (user, batch_id), + """ +SELECT cancelled_t.cancelled IS NOT NULL AS cancelled +FROM job_groups +LEFT JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors + INNER JOIN 
job_groups_cancelled + ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id + WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND + job_groups.job_group_id = job_group_self_and_ancestors.job_group_id +) AS cancelled_t ON TRUE +WHERE user = %s AND job_groups.batch_id = %s AND job_groups.job_group_id = %s AND NOT deleted; +""", + (user, batch_id, ROOT_JOB_GROUP_ID), ) if not record: raise web.HTTPNotFound() @@ -1884,10 +2348,10 @@ async def close_batch(request, userdata): raise web.HTTPBadRequest(reason='Cannot close a previously cancelled batch.') record = await db.select_and_fetchone( - ''' + """ SELECT 1 FROM batch_updates WHERE batch_id = %s AND update_id = 1; -''', +""", (batch_id,), ) if record: @@ -1907,14 +2371,18 @@ async def commit_update(request: web.Request, userdata): update_id = int(request.match_info['update_id']) record = await db.select_and_fetchone( - ''' -SELECT start_job_id, job_groups_cancelled.id IS NOT NULL AS cancelled + """ +SELECT start_job_id, start_job_group_id, cancelled_t.cancelled IS NOT NULL AS cancelled FROM batches LEFT JOIN batch_updates ON batches.id = batch_updates.batch_id -LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id -WHERE user = %s AND batches.id = %s AND batch_updates.update_id = %s AND NOT deleted; -''', - (user, batch_id, update_id), +LEFT JOIN ( + SELECT id, 1 AS cancelled + FROM job_groups_cancelled + WHERE id = %s AND job_group_id = %s +) AS cancelled_t ON batches.id = cancelled_t.id +WHERE batches.user = %s AND batches.id = %s AND batch_updates.update_id = %s AND NOT deleted; +""", + (batch_id, ROOT_JOB_GROUP_ID, user, batch_id, update_id), ) if not record: raise web.HTTPNotFound() @@ -1922,11 +2390,11 @@ async def commit_update(request: web.Request, userdata): raise web.HTTPBadRequest(reason='Cannot commit an update to a cancelled batch') await _commit_update(app, batch_id, update_id, user, db) - return json_response({'start_job_id': record['start_job_id']}) + return json_response({'start_job_id': record['start_job_id'], 'start_job_group_id': record['start_job_group_id']}) async def _commit_update(app: web.Application, batch_id: int, update_id: int, user: str, db: Database): - client_session: httpx.ClientSession = app['client_session'] + client_session = app[CommonAiohttpAppKeys.CLIENT_SESSION] try: now = time_msecs() @@ -1969,7 +2437,9 @@ async def ui_batch(request, userdata, batch_id): last_job_id = cast_query_param_to_int(request.query.get('last_job_id')) try: - jobs, last_job_id = await _query_batch_jobs(request, batch_id, CURRENT_QUERY_VERSION, q, last_job_id) + jobs, last_job_id = await _query_job_group_jobs( + request, batch_id, ROOT_JOB_GROUP_ID, CURRENT_QUERY_VERSION, q, last_job_id, recursive=True + ) except QueryError as e: session = await aiohttp_session.get_session(request) set_message(session, e.message, 'error') @@ -1979,6 +2449,11 @@ async def ui_batch(request, userdata, batch_id): for j in jobs: j['duration'] = humanize_timedelta_msecs(j['duration']) j['cost'] = cost_str(j['cost']) + j['display_state'] = ( + f"{j['state']} (always run)" + if j['always_run'] and j['state'] not in {'Success', 'Failed', 'Error'} + else j['state'] + ) batch['jobs'] = jobs batch['cost'] = cost_str(batch['cost']) @@ -2007,7 +2482,7 @@ async def ui_cancel_batch(request: web.Request, _, batch_id: int) -> NoReturn: params['q'] = str(q) session = await aiohttp_session.get_session(request) try: - await 
_handle_ui_error(session, _cancel_batch, request.app, batch_id) + await _handle_ui_error(session, _cancel_job_group, request.app, batch_id, ROOT_JOB_GROUP_ID) set_message(session, f'Batch {batch_id} cancelled.', 'info') finally: location = request.app.router['batches'].url_for().with_query(params) @@ -2058,7 +2533,7 @@ async def _get_job(app, batch_id, job_id) -> GetJobResponseV1Alpha: db: Database = app['db'] record = await db.select_and_fetchone( - ''' + """ WITH base_t AS ( SELECT jobs.*, user, billing_project, ip_address, format_version, t.attempt_id AS last_cancelled_attempt_id FROM jobs @@ -2083,15 +2558,14 @@ async def _get_job(app, batch_id, job_id) -> GetJobResponseV1Alpha: FROM base_t LEFT JOIN LATERAL ( SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown -FROM (SELECT aggregated_job_resources_v3.batch_id, aggregated_job_resources_v3.job_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` +FROM (SELECT resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_job_resources_v3 WHERE aggregated_job_resources_v3.batch_id = base_t.batch_id AND aggregated_job_resources_v3.job_id = base_t.job_id - GROUP BY aggregated_job_resources_v3.batch_id, aggregated_job_resources_v3.job_id, aggregated_job_resources_v3.resource_id + GROUP BY aggregated_job_resources_v3.resource_id ) AS usage_t LEFT JOIN resources ON usage_t.resource_id = resources.resource_id -GROUP BY usage_t.batch_id, usage_t.job_id ) AS cost_t ON TRUE; -''', +""", (batch_id, job_id, batch_id, job_id), ) if not record: @@ -2115,13 +2589,13 @@ async def _get_attempts(app, batch_id, job_id): db: Database = app['db'] attempts = db.select_and_fetchall( - ''' + """ SELECT attempts.* FROM jobs INNER JOIN batches ON jobs.batch_id = batches.id LEFT JOIN attempts ON jobs.batch_id = attempts.batch_id and jobs.job_id = attempts.job_id WHERE jobs.batch_id = %s AND NOT deleted AND jobs.job_id = %s; -''', +""", (batch_id, job_id), query_name='get_attempts', ) @@ -2401,20 +2875,18 @@ async def ui_get_job(request, userdata, batch_id): job['cost_breakdown'].sort(key=lambda record: record['resource']) job_status = job['status'] - container_status_spec = dictfix.NoneOr( - { - 'name': str, - 'timing': { - 'pulling': dictfix.NoneOr({'duration': dictfix.NoneOr(Number)}), - 'running': dictfix.NoneOr({'duration': dictfix.NoneOr(Number)}), - 'uploading_resource_usage': dictfix.NoneOr({'duration': dictfix.NoneOr(Number)}), - }, - 'short_error': dictfix.NoneOr(str), - 'error': dictfix.NoneOr(str), - 'container_status': {'out_of_memory': dictfix.NoneOr(bool)}, - 'state': str, - } - ) + container_status_spec = dictfix.NoneOr({ + 'name': str, + 'timing': { + 'pulling': dictfix.NoneOr({'duration': dictfix.NoneOr(Number)}), + 'running': dictfix.NoneOr({'duration': dictfix.NoneOr(Number)}), + 'uploading_resource_usage': dictfix.NoneOr({'duration': dictfix.NoneOr(Number)}), + }, + 'short_error': dictfix.NoneOr(str), + 'error': dictfix.NoneOr(str), + 'container_status': {'out_of_memory': dictfix.NoneOr(bool)}, + 'state': str, + }) job_status_spec = { 'container_statuses': { 'input': container_status_spec, @@ -2550,13 +3022,13 @@ async def _edit_billing_limit(db, billing_project, limit): @transaction(db) async def insert(tx): row = await tx.execute_and_fetchone( - ''' + """ SELECT billing_projects.name as billing_project, billing_projects.`status` as `status` FROM billing_projects WHERE billing_projects.name_cs = %s AND billing_projects.`status` 
!= 'deleted' FOR UPDATE; - ''', + """, (billing_project,), ) if row is None: @@ -2566,9 +3038,9 @@ async def insert(tx): raise ClosedBillingProjectError(billing_project) await tx.execute_update( - ''' + """ UPDATE billing_projects SET `limit` = %s WHERE name_cs = %s; -''', +""", (limit, billing_project), ) @@ -2649,7 +3121,7 @@ async def parse_error(msg: str) -> Tuple[list, str, None]: where_conditions.append("`user` = %s") where_args.append(user) - sql = f''' + sql = f""" SELECT billing_project, `user`, @@ -2663,7 +3135,7 @@ async def parse_error(msg: str) -> Tuple[list, str, None]: ) AS t LEFT JOIN resources ON resources.resource_id = t.resource_id GROUP BY billing_project, `user`; -''' +""" sql_args = where_args @@ -2770,7 +3242,7 @@ async def _remove_user_from_billing_project(db, billing_project, user): @transaction(db) async def delete(tx): row = await tx.execute_and_fetchone( - ''' + """ SELECT billing_projects.name_cs as billing_project, billing_projects.`status` as `status`, `user` @@ -2782,7 +3254,7 @@ async def delete(tx): FOR UPDATE ) AS t ON billing_projects.name = t.billing_project WHERE billing_projects.name_cs = %s; -''', +""", (billing_project, user, billing_project), ) if not row: @@ -2800,11 +3272,11 @@ async def delete(tx): ) await tx.just_execute( - ''' + """ DELETE billing_project_users FROM billing_project_users LEFT JOIN billing_projects ON billing_projects.name = billing_project_users.billing_project WHERE billing_projects.name_cs = %s AND user_cs = %s; -''', +""", (billing_project, user), ) @@ -2842,7 +3314,7 @@ async def _add_user_to_billing_project(request: web.Request, db: Database, billi session_id = await get_session_id(request) assert session_id is not None url = deploy_config.url('auth', f'/api/v1alpha/users/{user}') - await impersonate_user(session_id, request.app['client_session'], url) + await impersonate_user(session_id, request.app[CommonAiohttpAppKeys.CLIENT_SESSION], url) except aiohttp.ClientResponseError as e: if e.status == 404: raise NonExistentUserError(user) from e @@ -2852,7 +3324,7 @@ async def _add_user_to_billing_project(request: web.Request, db: Database, billi async def insert(tx): # we want to be case-insensitive here to avoid duplicates with existing records row = await tx.execute_and_fetchone( - ''' + """ SELECT billing_projects.name as billing_project, billing_projects.`status` as `status`, user @@ -2866,7 +3338,7 @@ async def insert(tx): ) AS t ON billing_projects.name = t.billing_project WHERE billing_projects.name_cs = %s AND billing_projects.`status` != 'deleted' LOCK IN SHARE MODE; - ''', + """, (billing_project, user, billing_project), ) if row is None: @@ -2881,10 +3353,10 @@ async def insert(tx): ) await tx.execute_insertone( - ''' + """ INSERT INTO billing_project_users(billing_project, user, user_cs) VALUES (%s, %s, %s); - ''', + """, (billing_project, user, user), ) @@ -2925,12 +3397,12 @@ async def _create_billing_project(db, billing_project): async def insert(tx): # we want to avoid having billing projects with different cases but the same name row = await tx.execute_and_fetchone( - ''' + """ SELECT name_cs, `status` FROM billing_projects WHERE name = %s FOR UPDATE; -''', +""", (billing_project), ) if row is not None: @@ -2938,10 +3410,10 @@ async def insert(tx): raise BatchOperationAlreadyCompletedError(f'Billing project {billing_project_cs} already exists.', 'info') await tx.execute_insertone( - ''' + """ INSERT INTO billing_projects(name, name_cs) VALUES (%s, %s); -''', +""", (billing_project, billing_project), ) @@ 
-2977,7 +3449,7 @@ async def _close_billing_project(db, billing_project): @transaction(db) async def close_project(tx): row = await tx.execute_and_fetchone( - ''' + """ SELECT name_cs, `status`, batches.id as batch_id FROM billing_projects LEFT JOIN batches @@ -2988,7 +3460,7 @@ async def close_project(tx): WHERE name_cs = %s LIMIT 1 FOR UPDATE; - ''', + """, (billing_project,), ) if not row: @@ -3111,11 +3583,9 @@ async def _refresh(app): db: Database = app['db'] inst_coll_configs: InstanceCollectionConfigs = app['inst_coll_configs'] await inst_coll_configs.refresh(db) - row = await db.select_and_fetchone( - ''' + row = await db.select_and_fetchone(""" SELECT frozen FROM globals; -''' - ) +""") app['frozen'] = row['frozen'] regions = { @@ -3135,7 +3605,7 @@ async def index(request: web.Request, _) -> NoReturn: async def cancel_batch_loop_body(app): - client_session: httpx.ClientSession = app['client_session'] + client_session = app[CommonAiohttpAppKeys.CLIENT_SESSION] await retry_transient_errors( client_session.post, deploy_config.url('batch-driver', '/api/v1alpha/batches/cancel'), @@ -3147,7 +3617,7 @@ async def cancel_batch_loop_body(app): async def delete_batch_loop_body(app): - client_session: httpx.ClientSession = app['client_session'] + client_session = app[CommonAiohttpAppKeys.CLIENT_SESSION] await retry_transient_errors( client_session.post, deploy_config.url('batch-driver', '/api/v1alpha/batches/delete'), @@ -3185,19 +3655,17 @@ async def on_startup(app): exit_stack = AsyncExitStack() app['exit_stack'] = exit_stack - app['client_session'] = httpx.client_session() - exit_stack.push_async_callback(app['client_session'].close) + app[CommonAiohttpAppKeys.CLIENT_SESSION] = httpx.client_session() + exit_stack.push_async_callback(app[CommonAiohttpAppKeys.CLIENT_SESSION].close) db = Database() await db.async_init() app['db'] = db exit_stack.push_async_callback(app['db'].async_close) - row = await db.select_and_fetchone( - ''' + row = await db.select_and_fetchone(""" SELECT instance_id, n_tokens, frozen FROM globals; -''' - ) +""") app['n_tokens'] = row['n_tokens'] @@ -3274,5 +3742,5 @@ def run(): host='0.0.0.0', port=int(os.environ['PORT']), access_log_class=BatchFrontEndAccessLogger, - ssl_context=internal_server_ssl_context(), + ssl_context=deploy_config.server_ssl_context(), ) diff --git a/batch/batch/front_end/query/__init__.py b/batch/batch/front_end/query/__init__.py index 5f1e45f7f82..7567bc1f6bd 100644 --- a/batch/batch/front_end/query/__init__.py +++ b/batch/batch/front_end/query/__init__.py @@ -1,12 +1,13 @@ -from .query_v1 import parse_batch_jobs_query_v1, parse_list_batches_query_v1 -from .query_v2 import parse_batch_jobs_query_v2, parse_list_batches_query_v2 +from .query_v1 import parse_job_group_jobs_query_v1, parse_list_batches_query_v1, parse_list_job_groups_query_v1 +from .query_v2 import parse_job_group_jobs_query_v2, parse_list_batches_query_v2 CURRENT_QUERY_VERSION = 2 __all__ = [ 'CURRENT_QUERY_VERSION', - 'parse_batch_jobs_query_v1', - 'parse_batch_jobs_query_v2', + 'parse_job_group_jobs_query_v1', + 'parse_job_group_jobs_query_v2', 'parse_list_batches_query_v1', 'parse_list_batches_query_v2', + 'parse_list_job_groups_query_v1', ] diff --git a/batch/batch/front_end/query/query.py b/batch/batch/front_end/query/query.py index c28d804e9f0..5534eecf974 100644 --- a/batch/batch/front_end/query/query.py +++ b/batch/batch/front_end/query/query.py @@ -130,11 +130,11 @@ def query(self) -> Tuple[str, List[str]]: op = self.operator.to_sql() if isinstance(self.operator, 
PartialMatchOperator): self.instance = f'%{self.instance}%' - sql = f''' + sql = f""" ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM attempts WHERE instance_name {op} %s)) -''' +""" return (sql, [self.instance]) @@ -171,14 +171,14 @@ def __init__(self, term: str): self.term = term def query(self) -> Tuple[str, List[str]]: - sql = ''' + sql = """ (((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM job_attributes WHERE `key` = %s OR `value` = %s)) OR ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM attempts WHERE instance_name = %s))) -''' +""" return (sql, [self.term, self.term, self.term]) @@ -197,14 +197,14 @@ def __init__(self, term: str): self.term = term def query(self) -> Tuple[str, List[str]]: - sql = ''' + sql = """ (((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM job_attributes WHERE `key` LIKE %s OR `value` LIKE %s)) OR ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM attempts WHERE instance_name LIKE %s))) -''' +""" escaped_term = f'%{self.term}%' return (sql, [escaped_term, escaped_term, escaped_term]) @@ -227,11 +227,11 @@ def query(self) -> Tuple[str, List[str]]: value = self.value if isinstance(self.operator, PartialMatchOperator): value = f'%{value}%' - sql = f''' + sql = f""" ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM job_attributes WHERE `key` = %s AND `value` {op} %s)) - ''' + """ return (sql, [self.key, value]) @@ -250,11 +250,11 @@ def __init__(self, operator: ComparisonOperator, time_msecs: int): def query(self) -> Tuple[str, List[int]]: op = self.operator.to_sql() - sql = f''' + sql = f""" ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM attempts WHERE start_time {op} %s)) -''' +""" return (sql, [self.time_msecs]) @@ -273,11 +273,11 @@ def __init__(self, operator: ComparisonOperator, time_msecs: int): def query(self) -> Tuple[str, List[int]]: op = self.operator.to_sql() - sql = f''' + sql = f""" ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM attempts WHERE end_time {op} %s)) -''' +""" return (sql, [self.time_msecs]) @@ -296,11 +296,11 @@ def __init__(self, operator: ComparisonOperator, time_msecs: int): def query(self) -> Tuple[str, List[int]]: op = self.operator.to_sql() - sql = f''' + sql = f""" ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM attempts WHERE end_time - start_time {op} %s)) -''' +""" return (sql, [self.time_msecs]) @@ -361,19 +361,19 @@ def __init__(self, state: BatchState, operator: ExactMatchOperator): def query(self) -> Tuple[str, List[Any]]: args: List[Any] if self.state == BatchState.OPEN: - condition = "(`state` = 'open')" + condition = "(batches.`state` = 'open')" args = [] elif self.state == BatchState.CLOSED: - condition = "(`state` != 'open')" + condition = "(batches.`state` != 'open')" args = [] elif self.state == BatchState.COMPLETE: - condition = "(`state` = 'complete')" + condition = "(batches.`state` = 'complete')" args = [] elif self.state == BatchState.RUNNING: - condition = "(`state` = 'running')" + condition = "(batches.`state` = 'running')" args = [] elif self.state == BatchState.CANCELLED: - condition = '(job_groups_cancelled.id IS NOT NULL)' + condition = '(cancelled_t.cancelled IS NOT NULL)' args = [] elif self.state == BatchState.FAILURE: condition = '(n_failed > 0)' @@ -381,7 +381,7 @@ def query(self) -> Tuple[str, List[Any]]: else: assert self.state == BatchState.SUCCESS # need complete because there might be no jobs - condition = "(`state` = 'complete' AND n_succeeded = n_jobs)" + condition 
= "(batches.`state` = 'complete' AND n_succeeded = batches.n_jobs)" args = [] if isinstance(self.operator, NotEqualExactMatchOperator): @@ -442,58 +442,58 @@ def query(self) -> Tuple[str, List[str]]: return (f'(batches.billing_project {op} %s)', [self.billing_project]) -class BatchQuotedExactMatchQuery(Query): +class JobGroupQuotedExactMatchQuery(Query): @staticmethod - def parse(term: str) -> 'BatchQuotedExactMatchQuery': + def parse(term: str) -> 'JobGroupQuotedExactMatchQuery': if len(term) < 3: raise QueryError(f'expected a string of minimum length 3. Found {term}') if term[-1] != '"': raise QueryError("expected the last character of the string to be '\"'") - return BatchQuotedExactMatchQuery(term[1:-1]) + return JobGroupQuotedExactMatchQuery(term[1:-1]) def __init__(self, term: str): self.term = term def query(self) -> Tuple[str, List[str]]: - sql = ''' -((batches.id) IN - (SELECT batch_id FROM job_group_attributes + sql = """ +((job_groups.batch_id, job_groups.job_group_id) IN + (SELECT batch_id, job_group_id FROM job_group_attributes WHERE `key` = %s OR `value` = %s)) -''' +""" return (sql, [self.term, self.term]) -class BatchUnquotedPartialMatchQuery(Query): +class JobGroupUnquotedPartialMatchQuery(Query): @staticmethod - def parse(term: str) -> 'BatchUnquotedPartialMatchQuery': + def parse(term: str) -> 'JobGroupUnquotedPartialMatchQuery': if len(term) < 1: raise QueryError(f'expected a string of minimum length 1. Found {term}') if term[0] == '"': raise QueryError("expected the first character of the string to not be '\"'") if term[-1] == '"': raise QueryError("expected the last character of the string to not be '\"'") - return BatchUnquotedPartialMatchQuery(term) + return JobGroupUnquotedPartialMatchQuery(term) def __init__(self, term: str): self.term = term def query(self) -> Tuple[str, List[str]]: - sql = ''' -((batches.id) IN - (SELECT batch_id FROM job_group_attributes + sql = """ +((job_groups.batch_id, job_groups.job_group_id) IN + (SELECT batch_id, job_group_id FROM job_group_attributes WHERE `key` LIKE %s OR `value` LIKE %s)) -''' +""" escaped_term = f'%{self.term}%' return (sql, [escaped_term, escaped_term]) -class BatchKeywordQuery(Query): +class JobGroupKeywordQuery(Query): @staticmethod - def parse(op: str, key: str, value: str) -> 'BatchKeywordQuery': + def parse(op: str, key: str, value: str) -> 'JobGroupKeywordQuery': operator = get_operator(op) if not isinstance(operator, MatchOperator): raise QueryError(f'unexpected operator "{op}" expected one of {MatchOperator.symbols}') - return BatchKeywordQuery(operator, key, value) + return JobGroupKeywordQuery(operator, key, value) def __init__(self, operator: MatchOperator, key: str, value: str): self.operator = operator @@ -505,22 +505,22 @@ def query(self) -> Tuple[str, List[str]]: value = self.value if isinstance(self.operator, PartialMatchOperator): value = f'%{value}%' - sql = f''' -((batches.id) IN - (SELECT batch_id FROM job_group_attributes + sql = f""" +((job_groups.batch_id, job_groups.job_group_id) IN + (SELECT batch_id, job_group_id FROM job_group_attributes WHERE `key` = %s AND `value` {op} %s)) - ''' +""" return (sql, [self.key, value]) -class BatchStartTimeQuery(Query): +class JobGroupStartTimeQuery(Query): @staticmethod - def parse(op: str, time: str) -> 'BatchStartTimeQuery': + def parse(op: str, time: str) -> 'JobGroupStartTimeQuery': operator = get_operator(op) if not isinstance(operator, ComparisonOperator): raise QueryError(f'unexpected operator "{op}" expected one of {ComparisonOperator.symbols}') 
time_msecs = parse_date(time) - return BatchStartTimeQuery(operator, time_msecs) + return JobGroupStartTimeQuery(operator, time_msecs) def __init__(self, operator: ComparisonOperator, time_msecs: int): self.operator = operator @@ -528,18 +528,18 @@ def __init__(self, operator: ComparisonOperator, time_msecs: int): def query(self) -> Tuple[str, List[int]]: op = self.operator.to_sql() - sql = f'(batches.time_created {op} %s)' + sql = f'(job_groups.time_created {op} %s)' return (sql, [self.time_msecs]) -class BatchEndTimeQuery(Query): +class JobGroupEndTimeQuery(Query): @staticmethod - def parse(op: str, time: str) -> 'BatchEndTimeQuery': + def parse(op: str, time: str) -> 'JobGroupEndTimeQuery': operator = get_operator(op) if not isinstance(operator, ComparisonOperator): raise QueryError(f'unexpected operator "{op}" expected one of {ComparisonOperator.symbols}') time_msecs = parse_date(time) - return BatchEndTimeQuery(operator, time_msecs) + return JobGroupEndTimeQuery(operator, time_msecs) def __init__(self, operator: ComparisonOperator, time_msecs: int): self.operator = operator @@ -547,18 +547,18 @@ def __init__(self, operator: ComparisonOperator, time_msecs: int): def query(self) -> Tuple[str, List[int]]: op = self.operator.to_sql() - sql = f'(batches.time_completed {op} %s)' + sql = f'(job_groups.time_completed {op} %s)' return (sql, [self.time_msecs]) -class BatchDurationQuery(Query): +class JobGroupDurationQuery(Query): @staticmethod - def parse(op: str, time: str) -> 'BatchDurationQuery': + def parse(op: str, time: str) -> 'JobGroupDurationQuery': operator = get_operator(op) if not isinstance(operator, ComparisonOperator): raise QueryError(f'unexpected operator "{op}" expected one of {ComparisonOperator.symbols}') time_msecs = int(parse_float(time) * 1000) - return BatchDurationQuery(operator, time_msecs) + return JobGroupDurationQuery(operator, time_msecs) def __init__(self, operator: ComparisonOperator, time_msecs: int): self.operator = operator @@ -566,18 +566,18 @@ def __init__(self, operator: ComparisonOperator, time_msecs: int): def query(self) -> Tuple[str, List[int]]: op = self.operator.to_sql() - sql = f'((batches.time_completed - batches.time_created) {op} %s)' + sql = f'((job_groups.time_completed - job_groups.time_created) {op} %s)' return (sql, [self.time_msecs]) -class BatchCostQuery(Query): +class JobGroupCostQuery(Query): @staticmethod - def parse(op: str, cost_str: str) -> 'BatchCostQuery': + def parse(op: str, cost_str: str) -> 'JobGroupCostQuery': operator = get_operator(op) if not isinstance(operator, ComparisonOperator): raise QueryError(f'unexpected operator "{op}" expected one of {ComparisonOperator.symbols}') cost = parse_cost(cost_str) - return BatchCostQuery(operator, cost) + return JobGroupCostQuery(operator, cost) def __init__(self, operator: ComparisonOperator, cost: float): self.operator = operator diff --git a/batch/batch/front_end/query/query_v1.py b/batch/batch/front_end/query/query_v1.py index a52b1cf2c25..184339472ec 100644 --- a/batch/batch/front_end/query/query_v1.py +++ b/batch/batch/front_end/query/query_v1.py @@ -1,5 +1,7 @@ from typing import Any, List, Optional, Tuple +from hailtop.batch_client.globals import ROOT_JOB_GROUP_ID + from ...exceptions import QueryError from .query import job_state_search_term_to_states @@ -8,11 +10,12 @@ def parse_list_batches_query_v1(user: str, q: str, last_batch_id: Optional[int]) where_conditions = [ '(billing_project_users.`user` = %s AND billing_project_users.billing_project = batches.billing_project)', 
'NOT deleted', + 'job_groups.job_group_id = %s', ] - where_args: List[Any] = [user] + where_args: List[Any] = [user, ROOT_JOB_GROUP_ID] if last_batch_id is not None: - where_conditions.append('(batches.id < %s)') + where_conditions.append('(job_groups.batch_id < %s)') where_args.append(last_batch_id) terms = q.split() @@ -26,53 +29,53 @@ def parse_list_batches_query_v1(user: str, q: str, last_batch_id: Optional[int]) if '=' in t: k, v = t.split('=', 1) - condition = ''' -((batches.id) IN - (SELECT batch_id FROM job_group_attributes + condition = """ +((job_groups.batch_id, job_groups.job_group_id) IN + (SELECT batch_id, job_group_id FROM job_group_attributes WHERE `key` = %s AND `value` = %s)) -''' +""" args = [k, v] elif t.startswith('has:'): k = t[4:] - condition = ''' -((batches.id) IN - (SELECT batch_id FROM job_group_attributes + condition = """ +((job_groups.batch_id, job_groups.job_group_id) IN + (SELECT batch_id, job_group_id FROM job_group_attributes WHERE `key` = %s)) -''' +""" args = [k] elif t.startswith('user:'): k = t[5:] - condition = ''' + condition = """ (batches.`user` = %s) -''' +""" args = [k] elif t.startswith('billing_project:'): k = t[16:] - condition = ''' + condition = """ (billing_projects.name_cs = %s) -''' +""" args = [k] elif t == 'open': - condition = "(`state` = 'open')" + condition = "(batches.`state` = 'open')" args = [] elif t == 'closed': - condition = "(`state` != 'open')" + condition = "(batches.`state` != 'open')" args = [] elif t == 'complete': - condition = "(`state` = 'complete')" + condition = "(batches.`state` = 'complete')" args = [] elif t == 'running': - condition = "(`state` = 'running')" + condition = "(batches.`state` = 'running')" args = [] elif t == 'cancelled': - condition = '(job_groups_cancelled.id IS NOT NULL)' + condition = '(cancelled_t.cancelled IS NOT NULL)' args = [] elif t == 'failure': condition = '(n_failed > 0)' args = [] elif t == 'success': # need complete because there might be no jobs - condition = "(`state` = 'complete' AND n_succeeded = n_jobs)" + condition = "(batches.`state` = 'complete' AND n_succeeded = batches.n_jobs)" args = [] else: raise QueryError(f'Invalid search term: {t}.') @@ -83,23 +86,34 @@ def parse_list_batches_query_v1(user: str, q: str, last_batch_id: Optional[int]) where_conditions.append(condition) where_args.extend(args) - sql = f''' + sql = f""" WITH base_t AS ( - SELECT batches.*, - job_groups_cancelled.id IS NOT NULL AS cancelled, + SELECT + batches.*, + job_groups.batch_id as batch_id, + job_groups.job_group_id as job_group_id, + cancelled_t.cancelled IS NOT NULL AS cancelled, job_groups_n_jobs_in_complete_states.n_completed, job_groups_n_jobs_in_complete_states.n_succeeded, job_groups_n_jobs_in_complete_states.n_failed, job_groups_n_jobs_in_complete_states.n_cancelled - FROM batches + FROM job_groups + LEFT JOIN batches ON batches.id = job_groups.batch_id LEFT JOIN billing_projects ON batches.billing_project = billing_projects.name LEFT JOIN job_groups_n_jobs_in_complete_states - ON batches.id = job_groups_n_jobs_in_complete_states.id - LEFT JOIN job_groups_cancelled - ON batches.id = job_groups_cancelled.id + ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id + LEFT JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors + INNER JOIN job_groups_cancelled + ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = 
job_groups_cancelled.job_group_id + WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND + job_groups.job_group_id = job_group_self_and_ancestors.job_group_id + ) AS cancelled_t ON TRUE STRAIGHT_JOIN billing_project_users ON batches.billing_project = billing_project_users.billing_project WHERE {' AND '.join(where_conditions)} - ORDER BY id DESC + ORDER BY job_groups.batch_id DESC LIMIT 51 ) SELECT base_t.*, cost_t.cost, cost_t.cost_breakdown @@ -107,25 +121,100 @@ def parse_list_batches_query_v1(user: str, q: str, last_batch_id: Optional[int]) LEFT JOIN LATERAL ( SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown FROM ( - SELECT batch_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` + SELECT resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_job_group_resources_v3 - WHERE base_t.id = aggregated_job_group_resources_v3.batch_id - GROUP BY batch_id, resource_id + WHERE base_t.id = aggregated_job_group_resources_v3.batch_id AND base_t.job_group_id = aggregated_job_group_resources_v3.job_group_id + GROUP BY resource_id ) AS usage_t LEFT JOIN resources ON usage_t.resource_id = resources.resource_id - GROUP BY batch_id ) AS cost_t ON TRUE -ORDER BY id DESC; -''' +ORDER BY batch_id DESC +""" return (sql, where_args) -def parse_batch_jobs_query_v1(batch_id: int, q: str, last_job_id: Optional[int]) -> Tuple[str, List[Any]]: +def parse_list_job_groups_query_v1( + batch_id: int, job_group_id: int, last_child_job_group_id: Optional[int] +) -> Tuple[str, List[Any]]: + where_conds = [ + '(job_groups.batch_id = %s)', + '(NOT deleted)', + '(job_group_self_and_ancestors.ancestor_id = %s AND job_group_self_and_ancestors.level = 1)', + '(batch_updates.committed OR job_groups.job_group_id = %s)', + ] + sql_args = [batch_id, job_group_id, ROOT_JOB_GROUP_ID] + + if last_child_job_group_id is not None: + where_conds.append('(job_groups.job_group_id > %s)') + sql_args.append(last_child_job_group_id) + + sql = f""" +SELECT job_groups.*, + cancelled_t.cancelled IS NOT NULL AS cancelled, + job_groups_n_jobs_in_complete_states.n_completed, + job_groups_n_jobs_in_complete_states.n_succeeded, + job_groups_n_jobs_in_complete_states.n_failed, + job_groups_n_jobs_in_complete_states.n_cancelled, + cost_t.cost, cost_t.cost_breakdown +FROM job_group_self_and_ancestors +LEFT JOIN batches ON batches.id = job_group_self_and_ancestors.batch_id +LEFT JOIN job_groups + ON job_group_self_and_ancestors.batch_id = job_groups.batch_id AND + job_group_self_and_ancestors.job_group_id = job_groups.job_group_id +LEFT JOIN batch_updates ON batch_updates.batch_id = job_groups.batch_id AND + batch_updates.update_id = job_groups.update_id +LEFT JOIN job_groups_n_jobs_in_complete_states + ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND + job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id +LEFT JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors + INNER JOIN job_groups_cancelled + ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id + WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND + job_groups.job_group_id = job_group_self_and_ancestors.job_group_id +) AS cancelled_t ON TRUE +LEFT JOIN LATERAL ( + SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS 
cost_breakdown + FROM ( + SELECT resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` + FROM aggregated_job_group_resources_v3 + WHERE job_groups.batch_id = aggregated_job_group_resources_v3.batch_id AND + job_groups.job_group_id = aggregated_job_group_resources_v3.job_group_id + GROUP BY resource_id + ) AS usage_t + LEFT JOIN resources ON usage_t.resource_id = resources.resource_id +) AS cost_t ON TRUE +WHERE {' AND '.join(where_conds)} +ORDER BY job_group_id ASC +LIMIT 51; +""" + + return (sql, sql_args) + + +def parse_job_group_jobs_query_v1( + batch_id: int, job_group_id: int, q: str, last_job_id: Optional[int], recursive: bool +) -> Tuple[str, List[Any]]: # batch has already been validated where_conditions = ['(jobs.batch_id = %s AND batch_updates.committed)'] where_args: List[Any] = [batch_id] + if recursive: + jg_cond = """ +((jobs.batch_id, jobs.job_group_id) IN + (SELECT batch_id, job_group_id FROM job_group_self_and_ancestors + WHERE batch_id = %s AND ancestor_id = %s)) +""" + where_args.extend([batch_id, job_group_id]) + else: + jg_cond = '(jobs.job_group_id = %s)' + where_args.append(job_group_id) + + where_conditions.append(jg_cond) + if last_job_id is not None: where_conditions.append('(jobs.job_id > %s)') where_args.append(last_job_id) @@ -147,19 +236,19 @@ def parse_batch_jobs_query_v1(batch_id: int, q: str, last_job_id: Optional[int]) condition = '(jobs.job_id = %s)' args = [v] else: - condition = ''' + condition = """ ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM job_attributes WHERE `key` = %s AND `value` = %s)) -''' +""" args = [k, v] elif t.startswith('has:'): k = t[4:] - condition = ''' + condition = """ ((jobs.batch_id, jobs.job_id) IN (SELECT batch_id, job_id FROM job_attributes WHERE `key` = %s)) -''' +""" args = [k] elif t in job_state_search_term_to_states: values = job_state_search_term_to_states[t] @@ -175,7 +264,7 @@ def parse_batch_jobs_query_v1(batch_id: int, q: str, last_job_id: Optional[int]) where_conditions.append(condition) where_args.extend(args) - sql = f''' + sql = f""" WITH base_t AS ( SELECT jobs.*, batches.user, batches.billing_project, batches.format_version, @@ -194,14 +283,13 @@ def parse_batch_jobs_query_v1(batch_id: int, q: str, last_job_id: Optional[int]) FROM base_t LEFT JOIN LATERAL ( SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown -FROM (SELECT aggregated_job_resources_v3.batch_id, aggregated_job_resources_v3.job_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` +FROM (SELECT resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_job_resources_v3 WHERE aggregated_job_resources_v3.batch_id = base_t.batch_id AND aggregated_job_resources_v3.job_id = base_t.job_id - GROUP BY aggregated_job_resources_v3.batch_id, aggregated_job_resources_v3.job_id, aggregated_job_resources_v3.resource_id + GROUP BY aggregated_job_resources_v3.resource_id ) AS usage_t LEFT JOIN resources ON usage_t.resource_id = resources.resource_id -GROUP BY usage_t.batch_id, usage_t.job_id ) AS cost_t ON TRUE; -''' +""" return (sql, where_args) diff --git a/batch/batch/front_end/query/query_v2.py b/batch/batch/front_end/query/query_v2.py index ad2df661ff8..a818eedd64c 100644 --- a/batch/batch/front_end/query/query_v2.py +++ b/batch/batch/front_end/query/query_v2.py @@ -1,5 +1,7 @@ from typing import Any, List, Optional, Tuple +from hailtop.batch_client.globals import ROOT_JOB_GROUP_ID + from ...exceptions import 
QueryError from .operators import ( GreaterThanEqualOperator, @@ -10,19 +12,19 @@ ) from .query import ( BatchBillingProjectQuery, - BatchCostQuery, - BatchDurationQuery, - BatchEndTimeQuery, BatchIdQuery, - BatchKeywordQuery, - BatchQuotedExactMatchQuery, - BatchStartTimeQuery, BatchStateQuery, - BatchUnquotedPartialMatchQuery, BatchUserQuery, JobCostQuery, JobDurationQuery, JobEndTimeQuery, + JobGroupCostQuery, + JobGroupDurationQuery, + JobGroupEndTimeQuery, + JobGroupKeywordQuery, + JobGroupQuotedExactMatchQuery, + JobGroupStartTimeQuery, + JobGroupUnquotedPartialMatchQuery, JobIdQuery, JobInstanceCollectionQuery, JobInstanceQuery, @@ -58,8 +60,8 @@ def parse_list_batches_query_v2(user: str, q: str, last_batch_id: Optional[int]) queries: List[Query] = [] # logic to make time interval queries fast - min_start_gt_query: Optional[BatchStartTimeQuery] = None - max_end_lt_query: Optional[BatchEndTimeQuery] = None + min_start_gt_query: Optional[JobGroupStartTimeQuery] = None + max_end_lt_query: Optional[JobGroupEndTimeQuery] = None if q: terms = q.rstrip().lstrip().split('\n') @@ -69,9 +71,9 @@ def parse_list_batches_query_v2(user: str, q: str, last_batch_id: Optional[int]) if len(statement) == 1: word = statement[0] if word[0] == '"': - queries.append(BatchQuotedExactMatchQuery.parse(word)) + queries.append(JobGroupQuotedExactMatchQuery.parse(word)) else: - queries.append(BatchUnquotedPartialMatchQuery.parse(word)) + queries.append(JobGroupUnquotedPartialMatchQuery.parse(word)) elif len(statement) == 3: left, op, right = statement if left == 'batch_id': @@ -83,42 +85,39 @@ def parse_list_batches_query_v2(user: str, q: str, last_batch_id: Optional[int]) elif left == 'state': queries.append(BatchStateQuery.parse(op, right)) elif left == 'start_time': - st_query = BatchStartTimeQuery.parse(op, right) + st_query = JobGroupStartTimeQuery.parse(op, right) queries.append(st_query) if (type(st_query.operator) in [GreaterThanOperator, GreaterThanEqualOperator]) and ( min_start_gt_query is None or min_start_gt_query.time_msecs >= st_query.time_msecs ): min_start_gt_query = st_query elif left == 'end_time': - et_query = BatchEndTimeQuery.parse(op, right) + et_query = JobGroupEndTimeQuery.parse(op, right) queries.append(et_query) if (type(et_query.operator) in [LessThanOperator, LessThanEqualOperator]) and ( max_end_lt_query is None or max_end_lt_query.time_msecs <= et_query.time_msecs ): max_end_lt_query = et_query elif left == 'duration': - queries.append(BatchDurationQuery.parse(op, right)) + queries.append(JobGroupDurationQuery.parse(op, right)) elif left == 'cost': - queries.append(BatchCostQuery.parse(op, right)) + queries.append(JobGroupCostQuery.parse(op, right)) else: - queries.append(BatchKeywordQuery.parse(op, left, right)) + queries.append(JobGroupKeywordQuery.parse(op, left, right)) else: raise QueryError(f'could not parse term "{_term}"') # this is to make time interval queries fast by using the bounds on both indices if min_start_gt_query and max_end_lt_query and min_start_gt_query.time_msecs <= max_end_lt_query.time_msecs: - queries.append(BatchStartTimeQuery(max_end_lt_query.operator, max_end_lt_query.time_msecs)) - queries.append(BatchEndTimeQuery(min_start_gt_query.operator, min_start_gt_query.time_msecs)) + queries.append(JobGroupStartTimeQuery(max_end_lt_query.operator, max_end_lt_query.time_msecs)) + queries.append(JobGroupEndTimeQuery(min_start_gt_query.operator, min_start_gt_query.time_msecs)) # batch has already been validated - where_conditions = [ - 
'(billing_project_users.`user` = %s)', - 'NOT deleted', - ] - where_args: List[Any] = [user] + where_conditions = ['(billing_project_users.`user` = %s)', 'NOT deleted', 'job_groups.job_group_id = %s'] + where_args: List[Any] = [user, ROOT_JOB_GROUP_ID] if last_batch_id is not None: - where_conditions.append('(batches.id < %s)') + where_conditions.append('(job_groups.batch_id < %s)') where_args.append(last_batch_id) for query in queries: @@ -126,34 +125,42 @@ def parse_list_batches_query_v2(user: str, q: str, last_batch_id: Optional[int]) where_conditions.append(f'({cond})') where_args += args - sql = f''' + sql = f""" SELECT batches.*, - job_groups_cancelled.id IS NOT NULL AS cancelled, + cancelled_t.cancelled IS NOT NULL AS cancelled, job_groups_n_jobs_in_complete_states.n_completed, job_groups_n_jobs_in_complete_states.n_succeeded, job_groups_n_jobs_in_complete_states.n_failed, job_groups_n_jobs_in_complete_states.n_cancelled, cost_t.cost, cost_t.cost_breakdown -FROM batches +FROM job_groups +LEFT JOIN batches ON batches.id = job_groups.batch_id LEFT JOIN billing_projects ON batches.billing_project = billing_projects.name -LEFT JOIN job_groups_n_jobs_in_complete_states ON batches.id = job_groups_n_jobs_in_complete_states.id -LEFT JOIN job_groups_cancelled ON batches.id = job_groups_cancelled.id +LEFT JOIN job_groups_n_jobs_in_complete_states ON job_groups.batch_id = job_groups_n_jobs_in_complete_states.id AND job_groups.job_group_id = job_groups_n_jobs_in_complete_states.job_group_id +LEFT JOIN LATERAL ( + SELECT 1 AS cancelled + FROM job_group_self_and_ancestors + INNER JOIN job_groups_cancelled + ON job_group_self_and_ancestors.batch_id = job_groups_cancelled.id AND + job_group_self_and_ancestors.ancestor_id = job_groups_cancelled.job_group_id + WHERE job_groups.batch_id = job_group_self_and_ancestors.batch_id AND + job_groups.job_group_id = job_group_self_and_ancestors.job_group_id +) AS cancelled_t ON TRUE STRAIGHT_JOIN billing_project_users ON batches.billing_project = billing_project_users.billing_project LEFT JOIN LATERAL ( SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown FROM ( - SELECT batch_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` + SELECT resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_job_group_resources_v3 - WHERE batches.id = aggregated_job_group_resources_v3.batch_id - GROUP BY batch_id, resource_id + WHERE job_groups.batch_id = aggregated_job_group_resources_v3.batch_id AND job_groups.job_group_id = aggregated_job_group_resources_v3.job_group_id + GROUP BY resource_id ) AS usage_t LEFT JOIN resources ON usage_t.resource_id = resources.resource_id - GROUP BY batch_id ) AS cost_t ON TRUE WHERE {' AND '.join(where_conditions)} -ORDER BY id DESC +ORDER BY job_groups.batch_id DESC LIMIT 51; -''' +""" return (sql, where_args) @@ -178,7 +185,9 @@ def parse_list_batches_query_v2(user: str, q: str, last_batch_id: Optional[int]) # ::= -def parse_batch_jobs_query_v2(batch_id: int, q: str, last_job_id: Optional[int]) -> Tuple[str, List[Any]]: +def parse_job_group_jobs_query_v2( + batch_id: int, job_group_id: int, q: str, last_job_id: Optional[int], recursive: bool +) -> Tuple[str, List[Any]]: queries: List[Query] = [] # logic to make time interval queries fast @@ -238,6 +247,19 @@ def parse_batch_jobs_query_v2(batch_id: int, q: str, last_job_id: Optional[int]) where_conditions = ['(jobs.batch_id = %s AND batch_updates.committed)'] where_args 
= [batch_id] + if recursive: + jg_cond = """ +((jobs.batch_id, jobs.job_group_id) IN + (SELECT batch_id, job_group_id FROM job_group_self_and_ancestors + WHERE batch_id = %s AND ancestor_id = %s)) +""" + where_args.extend([batch_id, job_group_id]) + else: + jg_cond = '(jobs.job_group_id = %s)' + where_args.append(job_group_id) + + where_conditions.append(jg_cond) + if last_job_id is not None: where_conditions.append('(jobs.job_id > %s)') where_args.append(last_job_id) @@ -268,7 +290,7 @@ def parse_batch_jobs_query_v2(batch_id: int, q: str, last_job_id: Optional[int]) else: attempts_table_join_str = '' - sql = f''' + sql = f""" SELECT jobs.*, batches.user, batches.billing_project, batches.format_version, job_attributes.value AS name, cost_t.cost, cost_t.cost_breakdown FROM jobs @@ -281,16 +303,15 @@ def parse_batch_jobs_query_v2(batch_id: int, q: str, last_job_id: Optional[int]) {attempts_table_join_str} LEFT JOIN LATERAL ( SELECT COALESCE(SUM(`usage` * rate), 0) AS cost, JSON_OBJECTAGG(resources.resource, COALESCE(`usage` * rate, 0)) AS cost_breakdown -FROM (SELECT aggregated_job_resources_v3.batch_id, aggregated_job_resources_v3.job_id, resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` +FROM (SELECT resource_id, CAST(COALESCE(SUM(`usage`), 0) AS SIGNED) AS `usage` FROM aggregated_job_resources_v3 WHERE aggregated_job_resources_v3.batch_id = jobs.batch_id AND aggregated_job_resources_v3.job_id = jobs.job_id - GROUP BY aggregated_job_resources_v3.batch_id, aggregated_job_resources_v3.job_id, aggregated_job_resources_v3.resource_id + GROUP BY aggregated_job_resources_v3.resource_id ) AS usage_t LEFT JOIN resources ON usage_t.resource_id = resources.resource_id -GROUP BY usage_t.batch_id, usage_t.job_id ) AS cost_t ON TRUE WHERE {" AND ".join(where_conditions)} LIMIT 50; -''' +""" return (sql, where_args) diff --git a/batch/batch/front_end/templates/batch.html b/batch/batch/front_end/templates/batch.html index 109ffb5c12e..f1c89215c86 100644 --- a/batch/batch/front_end/templates/batch.html +++ b/batch/batch/front_end/templates/batch.html @@ -91,7 +91,7 @@

 Jobs
 {{ job['name'] }}
 {% endif %}
-{{ job['state'] }}
+{{ job['display_state'] }}
 {% if 'exit_code' in job and job['exit_code'] is not none %}
 {{ job['exit_code'] }}
diff --git a/batch/batch/front_end/templates/billing.html b/batch/batch/front_end/templates/billing.html
index 027c6993414..ffcdb4066e3 100644
--- a/batch/batch/front_end/templates/billing.html
+++ b/batch/batch/front_end/templates/billing.html
@@ -34,19 +34,19 @@
 Billing
-Total Cost
+Total Spend
 {% if is_developer %}
-Cost by Billing Project
+Spend by Billing Project
 Cost by Billing Project
 Billing Project
-Cost
+Spend
@@ -60,13 +60,13 @@
-Cost by User
+Spend by User
 Cost by User
 User
-Cost
+Spend
@@ -81,14 +81,14 @@
 {% endif %}
-Cost by Billing Project and User
+Spend by Billing Project and User
 Billing Project
 User
-Cost
+Spend
UserCostSpend
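An aside on the substantive change earlier in this diff: batch and job listing now join through job_groups, and the renamed parse_job_group_jobs_query_v2 takes a recursive flag that widens a jobs query from one job group to that group plus all of its descendants via the job_group_self_and_ancestors table. The sketch below restates that scoping condition as a standalone helper; job_group_scope is an illustrative name rather than a function in the diff, and the SQL fragments are copied from the hunk above.

from typing import Any, List, Tuple

def job_group_scope(batch_id: int, job_group_id: int, recursive: bool) -> Tuple[str, List[Any]]:
    # Recursive listing: match any job whose job group has the requested group
    # as an ancestor (or is that group), via job_group_self_and_ancestors.
    if recursive:
        cond = (
            '((jobs.batch_id, jobs.job_group_id) IN '
            '(SELECT batch_id, job_group_id FROM job_group_self_and_ancestors '
            'WHERE batch_id = %s AND ancestor_id = %s))'
        )
        return cond, [batch_id, job_group_id]
    # Non-recursive listing: match the requested job group only.
    return '(jobs.job_group_id = %s)', [job_group_id]

Listing the root job group with recursive=True should therefore still cover every job in a batch, which is presumably how the old parse_batch_jobs_query_v2 behaviour carries over.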
- + diff --git a/batch/batch/front_end/templates/job.html b/batch/batch/front_end/templates/job.html index a528a5dd1d7..d7cee213e6c 100644 --- a/batch/batch/front_end/templates/job.html +++ b/batch/batch/front_end/templates/job.html @@ -13,6 +13,7 @@

 Properties
   • Exit Code: {% if 'exit_code' in job and job['exit_code'] is not none %}{{ job['exit_code'] }}{% endif %}
   • Duration: {% if 'duration' in job and job['duration'] is not none %}{{ job['duration'] }}{% endif %}
   • Cost: {% if 'cost' in job and job['cost'] is not none %}{{ job['cost'] }}{% endif %}
+  • Always Run: {% if 'always_run' in job and job['always_run'] is not none %}{{ job['always_run'] }}{% endif %}
   • Attributes
@@ -166,7 +167,7 @@
 Logs
 {% if 'input' in job_log or 'input' in step_errors %}
 Input
 {% if 'input' in job_log %}
-Log
+Log download
 {{ job_log['input'] }}
 {% endif %}
 {% if 'input' in step_errors and step_errors['input'] is not none %}
@@ -178,7 +179,7 @@
 Error
 {% if 'main' in job_log or 'main' in step_errors %}
 Main
 {% if 'main' in job_log %}
-Log
+Log download
 {{ job_log['main'] }}
 {% endif %}
 {% if 'main' in step_errors and step_errors['main'] is not none %}
@@ -190,7 +191,7 @@
 Error
 {% if 'output' in job_log or 'output' in step_errors %}
 Output
 {% if 'output' in job_log %}
-Log
+Log download
 {{ job_log['output'] }}
 {% endif %}
 {% if 'output' in step_errors and step_errors['output'] is not none %}
diff --git a/batch/batch/front_end/templates/table_search.html b/batch/batch/front_end/templates/table_search.html
index 00c3736f231..859a1b87da2 100644
--- a/batch/batch/front_end/templates/table_search.html
+++ b/batch/batch/front_end/templates/table_search.html
@@ -53,8 +53,7 @@
-…
-…
+help
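The hail/python hunks from here on (base_expression.py and the expression modules after it) are almost entirely ruff reformatting; among the reflowed strings are the type-imputation errors raised when a Python literal cannot be mapped to a single Hail type. A small illustration of those errors, not part of the diff:

import hail as hl

# int32 and float64 unify, so the list gets a single element type.
print(hl.literal([1, 2.5]).dtype)  # array<float64>

# int and str do not unify; this raises the 'heterogeneous arrays' error
# whose wording is reflowed in the impute_type hunk below.
try:
    hl.literal([1, 'a'])
except Exception as e:
    print(e)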
diff --git a/hail/python/hail/expr/expressions/base_expression.py b/hail/python/hail/expr/expressions/base_expression.py
--- a/hail/python/hail/expr/expressions/base_expression.py
+++ b/hail/python/hail/expr/expressions/base_expression.py
@@ ... @@
         summary += '…'
         summary += f'…{html.escape(name)}…{html.escape(self.format(v))}…'
         for name, field in self.nested.items():
+            _name = name
             if prefix is not None:
-                name = f'{prefix}{name}'
-            summary += '…' + field._html_string(prefix=name) + '…'
+                _name = f'{prefix}{name}'
+            summary += '…' + field._html_string(prefix=_name) + '…'
         summary += '…'
         return summary
@@ -107,6 +129,7 @@ def __repr__(self):
     def _repr_html_(self):
         import html
+
         s = self.summary._html_string(prefix=self.name)
         if self.header:
             s = f'…{html.escape(self.header)}…
    ' + s @@ -132,15 +155,14 @@ def impute_type(x, partial_type=None): def _impute_type(x, partial_type): - from hail.genetics import Locus, Call + from hail.genetics import Call, Locus from hail.utils import Interval, Struct def refine(t, refined): if t is None: return refined if not isinstance(t, type(refined)): - raise ExpressionException( - "Incompatible partial_type, {}, for value {}".format(partial_type, x)) + raise ExpressionException("Incompatible partial_type, {}, for value {}".format(partial_type, x)) return t if isinstance(x, Expression): @@ -170,8 +192,10 @@ def refine(t, refined): return t elif isinstance(x, tuple): partial_type = refine(partial_type, hl.ttuple()) - return ttuple(*[_impute_type(element, partial_type[index] if index < len(partial_type) else None) - for index, element in enumerate(x)]) + return ttuple(*[ + _impute_type(element, partial_type[index] if index < len(partial_type) else None) + for index, element in enumerate(x) + ]) elif isinstance(x, list): partial_type = refine(partial_type, hl.tarray(None)) if len(x) == 0: @@ -179,8 +203,9 @@ def refine(t, refined): ts = {_impute_type(element, partial_type.element_type) for element in x} unified_type = super_unify_types(*ts) if unified_type is None: - raise ExpressionException("Hail does not support heterogeneous arrays: " - "found list with elements of types {} ".format(list(ts))) + raise ExpressionException( + "Hail does not support heterogeneous arrays: " "found list with elements of types {} ".format(list(ts)) + ) return tarray(unified_type) elif is_setlike(x): @@ -190,8 +215,9 @@ def refine(t, refined): ts = {_impute_type(element, partial_type.element_type) for element in x} unified_type = super_unify_types(*ts) if not unified_type: - raise ExpressionException("Hail does not support heterogeneous sets: " - "found set with elements of types {} ".format(list(ts))) + raise ExpressionException( + "Hail does not support heterogeneous sets: " "found set with elements of types {} ".format(list(ts)) + ) return tset(unified_type) elif isinstance(x, Mapping): @@ -204,14 +230,18 @@ def refine(t, refined): unified_key_type = super_unify_types(*kts) unified_value_type = super_unify_types(*vts) if unified_key_type is None: - raise ExpressionException("Hail does not support heterogeneous dicts: " - "found dict with keys {} of types {} ".format(list(x.keys()), list(kts))) + raise ExpressionException( + "Hail does not support heterogeneous dicts: " "found dict with keys {} of types {} ".format( + list(x.keys()), list(kts) + ) + ) if not unified_value_type: if unified_key_type == hl.tstr and user_partial_type is None: return tstruct(**{k: _impute_type(x[k], None) for k in x}) - raise ExpressionException("Hail does not support heterogeneous dicts: " - "found dict with values of types {} ".format(list(vts))) + raise ExpressionException( + "Hail does not support heterogeneous dicts: " "found dict with values of types {} ".format(list(vts)) + ) return tdict(unified_key_type, unified_value_type) elif isinstance(x, np.generic): return from_numpy(x.dtype) @@ -221,8 +251,9 @@ def refine(t, refined): elif x is None or pd.isna(x): return partial_type elif isinstance(x, (hl.expr.builders.CaseBuilder, hl.expr.builders.SwitchBuilder)): - raise ExpressionException("'switch' and 'case' expressions must end with a call to either" - "'default' or 'or_missing'") + raise ExpressionException( + "'switch' and 'case' expressions must end with a call to either" "'default' or 'or_missing'" + ) else: raise ExpressionException("Hail cannot 
automatically impute type of {}: {}".format(type(x), x)) @@ -316,11 +347,13 @@ def _to_expr(e, dtype): if not found_expr: return e else: - exprs = [new_fields[i] if isinstance(new_fields[i], Expression) - else hl.literal(new_fields[i], dtype[i]) - for i in range(len(new_fields))] + exprs = [ + new_fields[i] if isinstance(new_fields[i], Expression) else hl.literal(new_fields[i], dtype[i]) + for i in range(len(new_fields)) + ] fields = {name: expr for name, expr in zip(dtype.keys(), exprs)} from .typed_expressions import StructExpression + return StructExpression._from_fields(fields) elif isinstance(dtype, tarray): @@ -334,9 +367,10 @@ def _to_expr(e, dtype): return e else: assert len(elements) > 0 - exprs = [element if isinstance(element, Expression) - else hl.literal(element, dtype.element_type) - for element in elements] + exprs = [ + element if isinstance(element, Expression) else hl.literal(element, dtype.element_type) + for element in elements + ] indices, aggregations = unify_all(*exprs) x = ir.MakeArray([e._ir for e in exprs], None) return expressions.construct_expr(x, dtype, indices, aggregations) @@ -351,9 +385,10 @@ def _to_expr(e, dtype): return e else: assert len(elements) > 0 - exprs = [element if isinstance(element, Expression) - else hl.literal(element, dtype.element_type) - for element in elements] + exprs = [ + element if isinstance(element, Expression) else hl.literal(element, dtype.element_type) + for element in elements + ] indices, aggregations = unify_all(*exprs) x = ir.ToSet(ir.toStream(ir.MakeArray([e._ir for e in exprs], None))) return expressions.construct_expr(x, dtype, indices, aggregations) @@ -368,9 +403,10 @@ def _to_expr(e, dtype): if not found_expr: return e else: - exprs = [elements[i] if isinstance(elements[i], Expression) - else hl.literal(elements[i], dtype.types[i]) - for i in range(len(elements))] + exprs = [ + elements[i] if isinstance(elements[i], Expression) else hl.literal(elements[i], dtype.types[i]) + for i in range(len(elements)) + ] indices, aggregations = unify_all(*exprs) x = ir.MakeTuple([expr._ir for expr in exprs]) return expressions.construct_expr(x, dtype, indices, aggregations) @@ -409,17 +445,19 @@ def unify_all(*exprs) -> Tuple[Indices, LinkedList]: except ExpressionException: # source mismatch from collections import defaultdict + sources = defaultdict(lambda: []) for e in exprs: from .expression_utils import get_refs + for name, inds in get_refs(e, *[e for a in e._aggregations for e in a.exprs]).items(): sources[inds.source].append(str(name)) raise ExpressionException( "Cannot combine expressions from different source objects." 
"\n Found fields from {n} objects:{fields}".format( - n=len(sources), - fields=''.join("\n {}: {}".format(src, fds) for src, fds in sources.items()) - )) from None + n=len(sources), fields=''.join("\n {}: {}".format(src, fds) for src, fds in sources.items()) + ) + ) from None first, rest = exprs[0], exprs[1:] aggregations = first._aggregations for e in rest: @@ -483,8 +521,7 @@ def super_unify_types(*ts): return tdict(kt, vt) if isinstance(t0, tstruct): keys = [k for t in ts for k in t.fields] - kvs = {k: super_unify_types(*[t.get(k, None) for t in ts]) - for k in keys} + kvs = {k: super_unify_types(*[t.get(k, None) for t in ts]) for k in keys} return tstruct(**kvs) if all(t0 == t for t in ts): return t0 @@ -498,15 +535,15 @@ def unify_exprs(*exprs: 'Expression') -> Tuple: # all types are the same if len(types) == 1: - return exprs + (True,) + return (*exprs, True) for t in types: c = expressions.coercer_from_dtype(t) if all(c.can_coerce(e.dtype) for e in exprs): - return tuple([c.coerce(e) for e in exprs]) + (True,) + return (*tuple([c.coerce(e) for e in exprs]), True) # cannot coerce all types to the same type - return exprs + (False,) + return (*exprs, False) class Expression(object): @@ -515,12 +552,9 @@ class Expression(object): __array_ufunc__ = None # disable NumPy coercions, so Hail coercions take priority @typecheck_method(x=ir.IR, type=nullable(HailType), indices=Indices, aggregations=linked_list(Aggregation)) - def __init__(self, - x: ir.IR, - type: HailType, - indices: Indices = Indices(), - aggregations: LinkedList = LinkedList(Aggregation)): - + def __init__( + self, x: ir.IR, type: HailType, indices: Indices = Indices(), aggregations: LinkedList = LinkedList(Aggregation) + ): self._ir: ir.IR = x self._type = type self._indices = indices @@ -534,28 +568,34 @@ def describe(self, handler=print): for a in self._aggregations: agg_indices = agg_indices.union(a.indices.axes) agg_tag = ' (aggregated)' - agg_str = f'Includes aggregation with index {list(agg_indices)}\n' \ - f' (Aggregation index may be promoted based on context)' + agg_str = ( + f'Includes aggregation with index {list(agg_indices)}\n' + f' (Aggregation index may be promoted based on context)' + ) else: agg_tag = '' agg_str = '' bar = '--------------------------------------------------------' - s = '{bar}\n' \ - 'Type:\n' \ - ' {t}\n' \ - '{bar}\n' \ - 'Source:\n' \ - ' {src}\n' \ - 'Index:\n' \ - ' {inds}{agg_tag}{maybe_bar}{agg}\n' \ - '{bar}'.format(bar=bar, - t=self.dtype.pretty(indent=4), - src=self._indices.source, - inds=list(self._indices.axes), - maybe_bar='\n' + bar + '\n' if agg_str else '', - agg_tag=agg_tag, - agg=agg_str) + s = ( + '{bar}\n' + 'Type:\n' + ' {t}\n' + '{bar}\n' + 'Source:\n' + ' {src}\n' + 'Index:\n' + ' {inds}{agg_tag}{maybe_bar}{agg}\n' + '{bar}'.format( + bar=bar, + t=self.dtype.pretty(indent=4), + src=self._indices.source, + inds=list(self._indices.axes), + maybe_bar='\n' + bar + '\n' if agg_str else '', + agg_tag=agg_tag, + agg=agg_str, + ) + ) handler(s) def __lt__(self, other): @@ -575,17 +615,19 @@ def __nonzero__(self): "The truth value of an expression is undefined\n" " Hint: instead of 'if x', use 'hl.if_else(x, ...)'\n" " Hint: instead of 'x and y' or 'x or y', use 'x & y' or 'x | y'\n" - " Hint: instead of 'not x', use '~x'") + " Hint: instead of 'not x', use '~x'" + ) def __iter__(self): - raise ExpressionException(f"{repr(self)} object is not iterable") + raise ExpressionException(f"{self!r} object is not iterable") def _compare_op(self, op, other): other = to_expr(other) 
left, right, success = unify_exprs(self, other) if not success: - raise TypeError(f"Invalid '{op}' comparison, cannot compare expressions " - f"of type '{self.dtype}' and '{other.dtype}'") + raise TypeError( + f"Invalid '{op}' comparison, cannot compare expressions " f"of type '{self.dtype}' and '{other.dtype}'" + ) res = left._bin_op(op, right, hl.tbool) return res @@ -615,7 +657,7 @@ def _promote_numeric(self, typ): @staticmethod def _div_ret_type_f(t): assert is_numeric(t) - if t == tint32 or t == tint64: + if t in {tint32, tint64}: return tfloat64 else: # Float64 or Float32 @@ -669,7 +711,9 @@ def _bin_op_numeric_reverse(self, name, other, ret_type_f=None): return to_expr(other)._bin_op_numeric(name, self, ret_type_f) def _unary_op(self, name): - return expressions.construct_expr(ir.ApplyUnaryPrimOp(name, self._ir), self._type, self._indices, self._aggregations) + return expressions.construct_expr( + ir.ApplyUnaryPrimOp(name, self._ir), self._type, self._indices, self._aggregations + ) def _bin_op(self, name, other, ret_type): other = to_expr(other) @@ -679,10 +723,7 @@ def _bin_op(self, name, other, ret_type): elif name in {"==", "!=", "<", "<=", ">", ">="}: op = ir.ApplyComparisonOp(name, self._ir, other._ir) else: - d = { - '+': 'add', '-': 'sub', '*': 'mul', '/': 'div', '//': 'floordiv', - '%': 'mod', '**': 'pow' - } + d = {'+': 'add', '-': 'sub', '*': 'mul', '/': 'div', '//': 'floordiv', '%': 'mod', '**': 'pow'} op = ir.Apply(d.get(name, name), ret_type, self._ir, other._ir) return expressions.construct_expr(op, ret_type, indices, aggregations) @@ -703,7 +744,8 @@ def _ir_lambda_method(self, irf, f, input_type, ret_type_f, *args): args = (to_expr(arg)._ir for arg in args) new_id = Env.get_uid() lambda_result = to_expr( - f(expressions.construct_variable(new_id, input_type, self._indices, self._aggregations))) + f(expressions.construct_variable(new_id, input_type, self._indices, self._aggregations)) + ) indices, aggregations = unify_all(self, lambda_result) x = irf(self._ir, new_id, lambda_result._ir, *args) @@ -714,8 +756,11 @@ def _ir_lambda_method2(self, other, irf, f, input_type1, input_type2, ret_type_f new_id1 = Env.get_uid() new_id2 = Env.get_uid() lambda_result = to_expr( - f(expressions.construct_variable(new_id1, input_type1, self._indices, self._aggregations), - expressions.construct_variable(new_id2, input_type2, other._indices, other._aggregations))) + f( + expressions.construct_variable(new_id1, input_type1, self._indices, self._aggregations), + expressions.construct_variable(new_id2, input_type2, other._indices, other._aggregations), + ) + ) indices, aggregations = unify_all(self, other, lambda_result) x = irf(self._ir, other._ir, new_id1, new_id2, lambda_result._ir, *args) return expressions.construct_expr(x, ret_type_f(lambda_result._type), indices, aggregations) @@ -732,7 +777,9 @@ def dtype(self) -> HailType: return self._type def __bool__(self): - raise TypeError("'Expression' objects cannot be converted to a 'bool'. Use 'hl.if_else' instead of Python if statements.") + raise TypeError( + "'Expression' objects cannot be converted to a 'bool'. Use 'hl.if_else' instead of Python if statements." 
+ ) def __len__(self): raise TypeError("'Expression' objects have no static length: use 'hl.len' for the length of collections") @@ -859,21 +906,16 @@ def _to_relational(self, fallback_name): ds = source.select_entries(**named_self).select_globals().select_cols().select_rows() return name, ds - @typecheck_method(n=nullable(int), - width=nullable(int), - truncate=nullable(int), - types=bool, - handler=nullable(anyfunc), - n_rows=nullable(int), - n_cols=nullable(int)) - def show(self, - n=None, - width=None, - truncate=None, - types=True, - handler=None, - n_rows=None, - n_cols=None): + @typecheck_method( + n=nullable(int), + width=nullable(int), + truncate=nullable(int), + types=bool, + handler=nullable(anyfunc), + n_rows=nullable(int), + n_cols=nullable(int), + ) + def show(self, n=None, width=None, truncate=None, types=True, handler=None, n_rows=None, n_cols=None): """Print the first few records of the expression to the console. If the expression refers to a value on a keyed axis of a table or matrix @@ -923,8 +965,14 @@ def show(self, Print an extra header line with the type of each field. """ kwargs = { - 'n': n, 'width': width, 'truncate': truncate, 'types': types, - 'handler': handler, 'n_rows': n_rows, 'n_cols': n_cols} + 'n': n, + 'width': width, + 'truncate': truncate, + 'types': types, + 'handler': handler, + 'n_rows': n_rows, + 'n_cols': n_cols, + } if kwargs.get('n_rows') is None: kwargs['n_rows'] = kwargs['n'] del kwargs['n'] @@ -962,31 +1010,30 @@ def export(self, path, delimiter='\t', missing='NA', header=True): >>> with open('output/gt.tsv', 'r') as f: ... for line in f: ... print(line, end='') - locus alleles 0 1 2 3 - 1:1 ["A","C"] 0/1 0/1 1/1 0/0 - 1:2 ["A","C"] 1/1 1/1 0/1 1/1 - 1:3 ["A","C"] 0/0 0/0 1/1 0/1 - 1:4 ["A","C"] 0/0 0/0 0/0 1/1 + locus alleles 0 1 2 3 + 1:1 ["A","C"] 0/1 0/0 0/1 0/0 + 1:2 ["A","C"] 1/1 0/1 0/1 0/1 + 1:3 ["A","C"] 0/0 0/1 0/0 0/0 + 1:4 ["A","C"] 0/1 1/1 0/1 0/1 >>> small_mt.GT.export('output/gt-no-header.tsv', header=False) >>> with open('output/gt-no-header.tsv', 'r') as f: ... for line in f: ... print(line, end='') - 1:1 ["A","C"] 0/1 0/1 1/1 0/0 - 1:2 ["A","C"] 1/1 1/1 0/1 1/1 - 1:3 ["A","C"] 0/0 0/0 1/1 0/1 - 1:4 ["A","C"] 0/0 0/0 0/0 1/1 + 1:1 ["A","C"] 0/1 0/0 0/1 0/0 + 1:2 ["A","C"] 1/1 0/1 0/1 0/1 + 1:3 ["A","C"] 0/0 0/1 0/0 0/0 + 1:4 ["A","C"] 0/1 1/1 0/1 0/1 >>> small_mt.pop.export('output/pops.tsv') >>> with open('output/pops.tsv', 'r') as f: ... for line in f: ... print(line, end='') - sample_idx pop - 0 0 - 1 0 - 2 2 - 3 0 - + sample_idx pop + 0 1 + 1 2 + 2 2 + 3 2 >>> small_mt.ancestral_af.export('output/ancestral_af.tsv') >>> with open('output/ancestral_af.tsv', 'r') as f: @@ -1021,12 +1068,11 @@ def export(self, path, delimiter='\t', missing='NA', header=True): >>> with open('output/gt-no-header.tsv', 'r') as f: ... for line in f: ... 
print(line, end='') - locus alleles {"s":0,"family":"fam1"} {"s":1,"family":"fam1"} {"s":2,"family":"fam1"} {"s":3,"family":"fam1"} - 1:1 ["A","C"] 0/1 0/1 1/1 0/0 - 1:2 ["A","C"] 1/1 1/1 0/1 1/1 - 1:3 ["A","C"] 0/0 0/0 1/1 0/1 - 1:4 ["A","C"] 0/0 0/0 0/0 1/1 - + locus alleles {"s":0,"family":"fam1"} {"s":1,"family":"fam1"} {"s":2,"family":"fam1"} {"s":3,"family":"fam1"} + 1:1 ["A","C"] 0/1 0/0 0/1 0/0 + 1:2 ["A","C"] 1/1 0/1 0/1 0/1 + 1:3 ["A","C"] 0/0 0/1 0/0 0/0 + 1:4 ["A","C"] 0/1 1/1 0/1 0/1 Parameters @@ -1053,19 +1099,20 @@ def export(self, path, delimiter='\t', missing='NA', header=True): entry_array = t[entries] if self_name: entry_array = hl.map(lambda x: x[self_name], entry_array) - entry_array = hl.map(lambda x: hl.if_else(hl.is_missing(x), missing, hl.str(x)), - entry_array) + entry_array = hl.map(lambda x: hl.if_else(hl.is_missing(x), missing, hl.str(x)), entry_array) file_contents = t.select( - **{k: hl.str(t[k]) for k in ds.row_key}, - **{output_col_name: hl.delimit(entry_array, delimiter)}) + **{k: hl.str(t[k]) for k in ds.row_key}, **{output_col_name: hl.delimit(entry_array, delimiter)} + ) if header: col_key = t[cols] if len(ds.col_key) == 1: col_key = hl.map(lambda x: x[0], col_key) column_names = hl.map(hl.str, col_key).collect(_localize=False)[0] - header_table = hl.utils.range_table(1).key_by().select( - **{k: k for k in ds.row_key}, - **{output_col_name: hl.delimit(column_names, delimiter)}) + header_table = ( + hl.utils.range_table(1) + .key_by() + .select(**{k: k for k in ds.row_key}, **{output_col_name: hl.delimit(column_names, delimiter)}) + ) file_contents = header_table.union(file_contents) file_contents.export(path, delimiter=delimiter, header=False) @@ -1102,11 +1149,11 @@ def take(self, n, _localize=True): return e @overload - def collect(self) -> List[Any]: - ... + def collect(self) -> List[Any]: ... + @overload - def collect(self, _localize=False) -> 'Expression': - ... + def collect(self, _localize=False) -> 'Expression': ... + @typecheck_method(_localize=bool) def collect(self, _localize=True): """Collect all records of an expression into a local list. 
@@ -1151,8 +1198,11 @@ def _summary_fields(self, agg_result, top): defined_value_str = str(n_defined) if n_defined == 0 else f'{n_defined} ({(n_defined / tot) * 100:.2f}%)' if n_defined == 0: return {'Non-missing': defined_value_str, 'Missing': missing_value_str}, {} - return {'Non-missing': defined_value_str, 'Missing': missing_value_str, - **self._extra_summary_fields(agg_result[2])}, self._nested_summary(agg_result[2], top) + return { + 'Non-missing': defined_value_str, + 'Missing': missing_value_str, + **self._extra_summary_fields(agg_result[2]), + }, self._nested_summary(agg_result[2], top) def _nested_summary(self, agg_result, top): return {} @@ -1164,7 +1214,8 @@ def _all_summary_aggs(self): return hl.tuple(( hl.agg.filter(hl.is_missing(self), hl.agg.count()), hl.agg.filter(hl.is_defined(self), hl.agg.count()), - self._summary_aggs())) + self._summary_aggs(), + )) def _summarize(self, agg_res=None, *, name=None, header=None, top=False): src = self._indices.source @@ -1191,11 +1242,10 @@ def summarize(self, handler=None): if self in src._fields: field_name = src._fields_inverse[self] prefix = field_name + elif self._ir.is_nested_field: + prefix = self._ir.name else: - if self._ir.is_nested_field: - prefix = self._ir.name - else: - prefix = '' + prefix = '' if handler is None: handler = hl.utils.default_handler() diff --git a/hail/python/hail/expr/expressions/expression_typecheck.py b/hail/python/hail/expr/expressions/expression_typecheck.py index b155814a88a..9c5f87e57f7 100644 --- a/hail/python/hail/expr/expressions/expression_typecheck.py +++ b/hail/python/hail/expr/expressions/expression_typecheck.py @@ -1,12 +1,28 @@ import abc -from typing import Optional, Dict, Any, TypeVar, List +from typing import Any, Dict, List, Optional, TypeVar import hail as hl from hail.expr.expressions import Expression, ExpressionException, to_expr -from hail.expr.types import HailType, tint32, tint64, tfloat32, tfloat64, \ - tstr, tbool, tarray, tstream, tndarray, tset, tdict, tstruct, tunion, \ - ttuple, tinterval, tlocus, tcall - +from hail.expr.types import ( + HailType, + tarray, + tbool, + tcall, + tdict, + tfloat32, + tfloat64, + tint32, + tint64, + tinterval, + tlocus, + tndarray, + tset, + tstr, + tstream, + tstruct, + ttuple, + tunion, +) from hail.typecheck import TypeChecker, TypecheckFailure from hail.utils.java import escape_parsable @@ -41,8 +57,7 @@ class ExprCoercer(TypeChecker): @property @abc.abstractmethod - def str_t(self) -> str: - ... + def str_t(self) -> str: ... def requires_conversion(self, t: HailType) -> bool: assert self.can_coerce(t), t @@ -54,8 +69,7 @@ def _requires_conversion(self, t: HailType) -> bool: ... @abc.abstractmethod - def can_coerce(self, t: HailType) -> bool: - ... + def can_coerce(self, t: HailType) -> bool: ... 
def coerce(self, x) -> Expression: x = to_expr(x) @@ -231,10 +245,12 @@ def can_coerce(self, t: HailType) -> bool: def _coerce(self, x): assert isinstance(x, hl.expr.IntervalExpression) - return hl.interval(self.point_type.coerce(x.start), - self.point_type.coerce(x.end), - includes_start=x.includes_start, - includes_end=x.includes_end) + return hl.interval( + self.point_type.coerce(x.start), + self.point_type.coerce(x.end), + includes_start=x.includes_start, + includes_end=x.includes_end, + ) class ArrayCoercer(ExprCoercer): @@ -253,8 +269,7 @@ def _requires_conversion(self, t: HailType) -> bool: return self.ec._requires_conversion(t.element_type) def can_coerce(self, t: HailType) -> bool: - return ((isinstance(t, tndarray) and t.ndim == 1 or isinstance(t, tarray)) - and self.ec.can_coerce(t.element_type)) + return (isinstance(t, tndarray) and t.ndim == 1 or isinstance(t, tarray)) and self.ec.can_coerce(t.element_type) def _coerce(self, x: Expression): if isinstance(x, hl.expr.NDArrayExpression): @@ -306,7 +321,6 @@ def _coerce(self, x: Expression): class SetCoercer(ExprCoercer): - def __init__(self, ec: ExprCoercer = AnyCoercer()): super(SetCoercer, self).__init__() self.ec = ec @@ -350,8 +364,7 @@ def _coerce(self, x: Expression): # fast path return x.map_values(self.vc.coerce) else: - return hl.dict(hl.map(lambda e: (self.kc.coerce(e[0]), self.vc.coerce(e[1])), - hl.array(x))) + return hl.dict(hl.map(lambda e: (self.kc.coerce(e[0]), self.vc.coerce(e[1])), hl.array(x))) class TupleCoercer(ExprCoercer): @@ -378,9 +391,11 @@ def can_coerce(self, t: HailType): if self.elements is None: return isinstance(t, ttuple) else: - return (isinstance(t, ttuple) - and len(t.types) == len(self.elements) - and all(c.can_coerce(t_) for c, t_ in zip(self.elements, t.types))) + return ( + isinstance(t, ttuple) + and len(t.types) == len(self.elements) + and all(c.can_coerce(t_) for c, t_ in zip(self.elements, t.types)) + ) def _coerce(self, x: Expression): assert isinstance(x, hl.expr.TupleExpression) @@ -411,10 +426,14 @@ def can_coerce(self, t: HailType): if self.fields is None: return isinstance(t, tstruct) else: - return (isinstance(t, tstruct) - and len(t) == len(self.fields) - and all(expected[0] == actual[0] and expected[1].can_coerce(actual[1]) - for expected, actual in zip(self.fields.items(), t.items()))) + return ( + isinstance(t, tstruct) + and len(t) == len(self.fields) + and all( + expected[0] == actual[0] and expected[1].can_coerce(actual[1]) + for expected, actual in zip(self.fields.items(), t.items()) + ) + ) def _coerce(self, x: Expression): assert isinstance(x, hl.expr.StructExpression) @@ -446,10 +465,14 @@ def can_coerce(self, t: HailType): if self.cases is None: return isinstance(t, tunion) else: - return (isinstance(t, tunion) - and len(t) == len(self.cases) - and all(expected[0] == actual[0] and expected[1].can_coerce(actual[1]) - for expected, actual in zip(self.cases.items(), t.items()))) + return ( + isinstance(t, tunion) + and len(t) == len(self.cases) + and all( + expected[0] == actual[0] and expected[1].can_coerce(actual[1]) + for expected, actual in zip(self.cases.items(), t.items()) + ) + ) def _coerce(self, x: Expression): assert isinstance(x, hl.expr.StructExpression) @@ -511,7 +534,7 @@ def _coerce(self, x: Expression) -> Expression: tfloat64: expr_float64, tbool: expr_bool, tcall: expr_call, - tstr: expr_str + tstr: expr_str, } @@ -531,8 +554,7 @@ def coercer_from_dtype(t: HailType) -> ExprCoercer: elif isinstance(t, tset): return 
expr_set(coercer_from_dtype(t.element_type)) elif isinstance(t, tdict): - return expr_dict(coercer_from_dtype(t.key_type), - coercer_from_dtype(t.value_type)) + return expr_dict(coercer_from_dtype(t.key_type), coercer_from_dtype(t.value_type)) elif isinstance(t, ttuple): return expr_tuple([coercer_from_dtype(t_) for t_ in t.types]) elif isinstance(t, tstruct): diff --git a/hail/python/hail/expr/expressions/expression_utils.py b/hail/python/hail/expr/expressions/expression_utils.py index c2ab0ed7133..0a95c4d4a87 100644 --- a/hail/python/hail/expr/expressions/expression_utils.py +++ b/hail/python/hail/expr/expressions/expression_utils.py @@ -1,22 +1,15 @@ -from typing import Set, Dict -from hail.typecheck import typecheck, setof +from typing import Dict, Set + +from hail.typecheck import setof, typecheck -from .indices import Indices, Aggregation -from ..expressions import Expression, ExpressionException, expr_any from ...ir import MakeTuple +from ..expressions import Expression, ExpressionException, expr_any +from .indices import Aggregation, Indices -@typecheck(caller=str, - expr=Expression, - expected_indices=Indices, - aggregation_axes=setof(str), - broadcast=bool) -def analyze(caller: str, - expr: Expression, - expected_indices: Indices, - aggregation_axes: Set = set(), - broadcast=True): - from hail.utils import warning, error +@typecheck(caller=str, expr=Expression, expected_indices=Indices, aggregation_axes=setof(str), broadcast=bool) +def analyze(caller: str, expr: Expression, expected_indices: Indices, aggregation_axes: Set = set(), broadcast=True): + from hail.utils import error, warning indices = expr._indices source = indices.source @@ -35,19 +28,20 @@ def analyze(caller: str, if inds.source is not expected_source: bad_refs.append(name) errors.append( - ExpressionException("'{caller}': source mismatch\n" - " Expected an expression from source {expected}\n" - " Found expression derived from source {actual}\n" - " Problematic field(s): {bad_refs}\n\n" - " This error is commonly caused by chaining methods together:\n" - " >>> ht.distinct().select(ht.x)\n\n" - " Correct usage:\n" - " >>> ht = ht.distinct()\n" - " >>> ht = ht.select(ht.x)".format( - caller=caller, - expected=expected_source, - actual=source, - bad_refs=list(bad_refs)))) + ExpressionException( + "'{caller}': source mismatch\n" + " Expected an expression from source {expected}\n" + " Found expression derived from source {actual}\n" + " Problematic field(s): {bad_refs}\n\n" + " This error is commonly caused by chaining methods together:\n" + " >>> ht.distinct().select(ht.x)\n\n" + " Correct usage:\n" + " >>> ht = ht.distinct()\n" + " >>> ht = ht.select(ht.x)".format( + caller=caller, expected=expected_source, actual=source, bad_refs=list(bad_refs) + ) + ) + ) # check for stray indices by subtracting expected axes from observed if broadcast: @@ -66,28 +60,32 @@ def analyze(caller: str, bad_axes = inds.axes.intersection(unexpected_axes) if bad_axes: bad_refs.append((name, inds)) - else: - if inds.axes != expected_axes: - bad_refs.append((name, inds)) + elif inds.axes != expected_axes: + bad_refs.append((name, inds)) assert len(bad_refs) > 0 - errors.append(ExpressionException( - "scope violation: '{caller}' expects an expression {strictness}indexed by {expected}" - "\n Found indices {axes}, with unexpected indices {stray}. 
Invalid fields:{fields}{agg}".format( - caller=caller, - strictness=strictness, - expected=list(expected_axes), - axes=list(indices.axes), - stray=list(unexpected_axes), - fields=''.join("\n '{}' (indices {})".format(name, list(inds.axes)) for name, inds in bad_refs), - agg='' if (unexpected_axes - aggregation_axes) else - "\n '{}' supports aggregation over axes {}, " - "so these fields may appear inside an aggregator function.".format(caller, list(aggregation_axes)) - ))) + errors.append( + ExpressionException( + "scope violation: '{caller}' expects an expression {strictness}indexed by {expected}" + "\n Found indices {axes}, with unexpected indices {stray}. Invalid fields:{fields}{agg}".format( + caller=caller, + strictness=strictness, + expected=list(expected_axes), + axes=list(indices.axes), + stray=list(unexpected_axes), + fields=''.join( + "\n '{}' (indices {})".format(name, list(inds.axes)) for name, inds in bad_refs + ), + agg='' + if (unexpected_axes - aggregation_axes) + else "\n '{}' supports aggregation over axes {}, " + "so these fields may appear inside an aggregator function.".format(caller, list(aggregation_axes)), + ) + ) + ) if aggregations: if aggregation_axes: - # the expected axes of aggregated expressions are the expected axes + axes aggregated over expected_agg_axes = expected_axes.union(aggregation_axes) @@ -108,17 +106,21 @@ def analyze(caller: str, assert len(bad_refs) > 0 - errors.append(ExpressionException( - "scope violation: '{caller}' supports aggregation over indices {expected}" - "\n Found indices {axes}, with unexpected indices {stray}. Invalid fields:{fields}".format( - caller=caller, - expected=list(aggregation_axes), - axes=list(agg_axes), - stray=list(unexpected_agg_axes), - fields=''.join("\n '{}' (indices {})".format( - name, list(inds.axes)) for name, inds in bad_refs) + errors.append( + ExpressionException( + "scope violation: '{caller}' supports aggregation over indices {expected}" + "\n Found indices {axes}, with unexpected indices {stray}. 
Invalid fields:{fields}".format( + caller=caller, + expected=list(aggregation_axes), + axes=list(agg_axes), + stray=list(unexpected_agg_axes), + fields=''.join( + "\n '{}' (indices {})".format(name, list(inds.axes)) + for name, inds in bad_refs + ), + ) ) - )) + ) else: errors.append(ExpressionException("'{}' does not support aggregation".format(caller))) @@ -147,6 +149,7 @@ def eval_timed(expression): """ from hail.utils.java import Env + analyze('eval', expression, Indices(expression._indices.source)) if expression._indices.source is None: ir_type = expression._ir.typ @@ -227,9 +230,8 @@ def _get_refs(expr: Expression, builder: Dict[str, Indices]) -> None: from hail.ir import GetField, TopLevelReference for ir in expr._ir.search( - lambda a: (isinstance(a, GetField) - and not a.name.startswith('__uid') - and isinstance(a.o, TopLevelReference))): + lambda a: (isinstance(a, GetField) and not a.name.startswith('__uid') and isinstance(a.o, TopLevelReference)) + ): src = expr._indices.source builder[ir.name] = src._indices_from_ref[ir.o.name] @@ -261,63 +263,59 @@ def get_refs(*exprs: Expression) -> Dict[str, Indices]: return builder -@typecheck(caller=str, - expr=Expression) +@typecheck(caller=str, expr=Expression) def matrix_table_source(caller, expr): from hail import MatrixTable + source = expr._indices.source if not isinstance(source, MatrixTable): raise ValueError( "{}: Expect an expression of 'MatrixTable', found {}".format( - caller, - "expression of '{}'".format(source.__class__) if source is not None else 'scalar expression')) + caller, "expression of '{}'".format(source.__class__) if source is not None else 'scalar expression' + ) + ) return source -@typecheck(caller=str, - expr=Expression) +@typecheck(caller=str, expr=Expression) def table_source(caller, expr): from hail import Table + source = expr._indices.source if not isinstance(source, Table): raise ValueError( "{}: Expect an expression of 'Table', found {}".format( - caller, - "expression of '{}'".format(source.__class__) if source is not None else 'scalar expression')) + caller, "expression of '{}'".format(source.__class__) if source is not None else 'scalar expression' + ) + ) return source @typecheck(caller=str, expr=Expression) def raise_unless_entry_indexed(caller, expr): if expr._indices.source is None: - raise ExpressionException(f"{caller}: expression must be entry-indexed" - f", found no indices (no source)" - ) + raise ExpressionException(f"{caller}: expression must be entry-indexed" f", found no indices (no source)") if expr._indices != expr._indices.source._entry_indices: - raise ExpressionException(f"{caller}: expression must be entry-indexed" - f", found indices {list(expr._indices.axes)}." - ) + raise ExpressionException( + f"{caller}: expression must be entry-indexed" f", found indices {list(expr._indices.axes)}." + ) @typecheck(caller=str, expr=Expression) def raise_unless_row_indexed(caller, expr): if expr._indices.source is None: - raise ExpressionException(f"{caller}: expression must be row-indexed" - f", found no indices (no source)." - ) + raise ExpressionException(f"{caller}: expression must be row-indexed" f", found no indices (no source).") if expr._indices != expr._indices.source._row_indices: - raise ExpressionException(f"{caller}: expression must be row-indexed" - f", found indices {list(expr._indices.axes)}." - ) + raise ExpressionException( + f"{caller}: expression must be row-indexed" f", found indices {list(expr._indices.axes)}." 
+ ) @typecheck(caller=str, expr=Expression) def raise_unless_column_indexed(caller, expr): if expr._indices.source is None: - raise ExpressionException(f"{caller}: expression must be column-indexed" - f", found no indices (no source)." - ) + raise ExpressionException(f"{caller}: expression must be column-indexed" f", found no indices (no source).") if expr._indices != expr._indices.source._col_indices: - raise ExpressionException(f"{caller}: expression must be column-indexed" - f", found indices ({list(expr._indices.axes)})." - ) + raise ExpressionException( + f"{caller}: expression must be column-indexed" f", found indices ({list(expr._indices.axes)})." + ) diff --git a/hail/python/hail/expr/expressions/indices.py b/hail/python/hail/expr/expressions/indices.py index cf95618a406..a54d1619d5a 100644 --- a/hail/python/hail/expr/expressions/indices.py +++ b/hail/python/hail/expr/expressions/indices.py @@ -1,8 +1,8 @@ -from hail.typecheck import typecheck_method, anytype, setof -import hail as hl - from typing import List +import hail as hl +from hail.typecheck import anytype, setof, typecheck_method + class Indices(object): @typecheck_method(source=anytype, axes=setof(str)) @@ -27,10 +27,10 @@ def unify(*indices): for ind in indices: if src is None: src = ind.source - else: - if ind.source is not None and ind.source is not src: - from . import ExpressionException - raise ExpressionException() + elif ind.source is not None and ind.source is not src: + from . import ExpressionException + + raise ExpressionException() axes = axes.union(ind.axes) @@ -72,6 +72,7 @@ class Aggregation(object): def __init__(self, *exprs): self.exprs = exprs from ..expressions import unify_all + indices, agg = unify_all(*exprs) self.nested = agg self.indices = indices diff --git a/hail/python/hail/expr/expressions/typed_expressions.py b/hail/python/hail/expr/expressions/typed_expressions.py index 9b9a9351e2f..280cc7cd1f3 100644 --- a/hail/python/hail/expr/expressions/typed_expressions.py +++ b/hail/python/hail/expr/expressions/typed_expressions.py @@ -1,26 +1,64 @@ -from typing import Mapping, Dict, Sequence, Union +from typing import Dict, Mapping, Sequence, Union +import numpy as np from deprecated import deprecated import hail as hl -from .indices import Indices, Aggregation -from .base_expression import Expression, ExpressionException, to_expr, \ - unify_all, unify_types -from .expression_typecheck import coercer_from_dtype, \ - expr_any, expr_array, expr_set, expr_bool, expr_numeric, expr_int32, \ - expr_int64, expr_str, expr_dict, expr_interval, expr_tuple, expr_oneof, \ - expr_ndarray -from hail.expr.types import HailType, tint32, tint64, tfloat32, \ - tfloat64, tbool, tcall, tset, tarray, tstream, tstruct, tdict, ttuple,\ - tstr, tndarray, tlocus, tinterval, is_numeric -import hail.ir as ir -from hail.typecheck import typecheck, typecheck_method, func_spec, oneof, \ - identity, nullable, tupleof, sliceof, dictof, anyfunc +from hail import ir +from hail.expr.types import ( + HailType, + is_numeric, + tarray, + tbool, + tcall, + tdict, + tfloat32, + tfloat64, + tint32, + tint64, + tinterval, + tlocus, + tndarray, + tset, + tstr, + tstream, + tstruct, + ttuple, +) +from hail.typecheck import ( + anyfunc, + dictof, + func_spec, + identity, + nullable, + oneof, + sliceof, + tupleof, + typecheck, + typecheck_method, +) from hail.utils.java import Env, warning from hail.utils.linkedlist import LinkedList -from hail.utils.misc import wrap_to_list, wrap_to_tuple, get_nice_field_error, get_nice_attr_error - -import 
numpy as np +from hail.utils.misc import get_nice_attr_error, get_nice_field_error, wrap_to_list, wrap_to_tuple + +from .base_expression import Expression, ExpressionException, to_expr, unify_all, unify_types +from .expression_typecheck import ( + coercer_from_dtype, + expr_any, + expr_array, + expr_bool, + expr_dict, + expr_int32, + expr_int64, + expr_interval, + expr_ndarray, + expr_numeric, + expr_oneof, + expr_set, + expr_str, + expr_tuple, +) +from .indices import Aggregation, Indices class CollectionExpression(Expression): @@ -145,9 +183,9 @@ def find(self, f): """ # FIXME this should short-circuit - return self.fold(lambda accum, x: - hl.if_else(hl.is_missing(accum) & f(x), x, accum), - hl.missing(self._type.element_type)) + return self.fold( + lambda accum, x: hl.if_else(hl.is_missing(accum) & f(x), x, accum), hl.missing(self._type.element_type) + ) @typecheck_method(f=func_spec(1, expr_any)) def flatmap(self, f): @@ -177,7 +215,9 @@ def flatmap(self, f): value_type = f(construct_variable(Env.get_uid(), self.dtype.element_type)).dtype if not isinstance(value_type, expected_type): - raise TypeError("'flatmap' expects 'f' to return an expression of type '{}', found '{}'".format(s, value_type)) + raise TypeError( + "'flatmap' expects 'f' to return an expression of type '{}', found '{}'".format(s, value_type) + ) def f2(x): return hl.array(f(x)) if isinstance(value_type, tset) else f(x) @@ -274,7 +314,12 @@ def group_by(self, f): keyed = hl.array(self).map(lambda x: hl.tuple([f(x), x])) types = keyed.dtype.element_type.types - return construct_expr(ir.GroupByKey(ir.toStream(keyed._ir)), tdict(types[0], tarray(types[1])), keyed._indices, keyed._aggregations) + return construct_expr( + ir.GroupByKey(ir.toStream(keyed._ir)), + tdict(types[0], tarray(types[1])), + keyed._indices, + keyed._aggregations, + ) @typecheck_method(f=func_spec(1, expr_any)) def map(self, f): @@ -306,7 +351,9 @@ def transform_ir(array, name, body): a = ir.ToSet(ir.toStream(a)) return a - array_map = hl.array(self)._ir_lambda_method(transform_ir, f, self._type.element_type, lambda t: self._type.__class__(t)) + array_map = hl.array(self)._ir_lambda_method( + transform_ir, f, self._type.element_type, lambda t: self._type.__class__(t) + ) if isinstance(self._type, tset): return hl.set(array_map) @@ -391,12 +438,15 @@ def _summary_aggs(self): hl.agg.min(length), hl.agg.max(length), hl.agg.mean(length), - hl.agg.explode(lambda elt: elt._all_summary_aggs(), self))) + hl.agg.explode(lambda elt: elt._all_summary_aggs(), self), + )) def __contains__(self, element): class_name = type(self).__name__ - raise TypeError(f"Cannot use `in` operator on hail `{class_name}`s. Use the `contains` method instead." - "`names.contains('Charlie')` instead of `'Charlie' in names`") + raise TypeError( + f"Cannot use `in` operator on hail `{class_name}`s. Use the `contains` method instead." 
+ "`names.contains('Charlie')` instead of `'Charlie' in names`" + ) class ArrayExpression(CollectionExpression): @@ -442,16 +492,16 @@ def __getitem__(self, item): return self._slice(item.start, item.stop, item.step) item = to_expr(item) if not item.dtype == tint32: - raise TypeError("array expects key to be type 'slice' or expression of type 'int32', " - "found expression of type '{}'".format(item._type)) + raise TypeError( + "array expects key to be type 'slice' or expression of type 'int32', " + "found expression of type '{}'".format(item._type) + ) else: return self._method("indexArray", self.dtype.element_type, item) @typecheck_method(start=nullable(expr_int32), stop=nullable(expr_int32), step=nullable(expr_int32)) def _slice(self, start=None, stop=None, step=None): - indices, aggregations = unify_all( - self, - *(x for x in (start, stop, step) if x is not None)) + indices, aggregations = unify_all(self, *(x for x in (start, stop, step) if x is not None)) if step is None: step = hl.int(1) if start is None: @@ -481,7 +531,6 @@ def aggregate(self, f): """ return hl.agg._aggregate_local_array(self, f) - @typecheck_method(item=expr_any) def contains(self, item): """Returns a boolean indicating whether `item` is found in the array. @@ -604,11 +653,15 @@ def index(self, x): None """ if callable(x): + def f(elt, x): return x(elt) + else: + def f(elt, x): return elt == x + return hl.bind(lambda a: hl.range(0, a.length()).filter(lambda i: f(a[i], x)).first(), self) @typecheck_method(item=expr_any) @@ -636,9 +689,11 @@ def append(self, item): :class:`.ArrayExpression` """ if item._type != self._type.element_type: - raise TypeError("'ArrayExpression.append' expects 'item' to be the same type as its elements\n" - " array element type: '{}'\n" - " type of arg 'item': '{}'".format(self._type._element_type, item._type)) + raise TypeError( + "'ArrayExpression.append' expects 'item' to be the same type as its elements\n" + " array element type: '{}'\n" + " type of arg 'item': '{}'".format(self._type._element_type, item._type) + ) return self._method("append", self._type, item) @typecheck_method(a=expr_array()) @@ -661,9 +716,11 @@ def extend(self, a): :class:`.ArrayExpression` """ if not a._type == self._type: - raise TypeError("'ArrayExpression.extend' expects 'a' to be the same type as the caller\n" - " caller type: '{}'\n" - " type of 'a': '{}'".format(self._type, a._type)) + raise TypeError( + "'ArrayExpression.extend' expects 'a' to be the same type as the caller\n" + " caller type: '{}'\n" + " type of 'a': '{}'".format(self._type, a._type) + ) return self._method("extend", self._type, a) @typecheck_method(f=func_spec(2, expr_any), zero=expr_any) @@ -714,11 +771,15 @@ def grouped(self, group_size): indices, aggregations = unify_all(self, group_size) stream_ir = ir.StreamGrouped(ir.toStream(self._ir), group_size._ir) mapping_identifier = Env.get_uid("stream_grouped_map_to_arrays") - mapped_to_arrays = ir.StreamMap(stream_ir, mapping_identifier, ir.toArray(ir.Ref(mapping_identifier, tstream(self._type.element_type)))) + mapped_to_arrays = ir.StreamMap( + stream_ir, mapping_identifier, ir.toArray(ir.Ref(mapping_identifier, tstream(self._type.element_type))) + ) return construct_expr(ir.toArray(mapped_to_arrays), tarray(self._type), indices, aggregations) def _to_stream(self): - return construct_expr(ir.toStream(self._ir), tstream(self.dtype.element_type), self._indices, self._aggregations) + return construct_expr( + ir.toStream(self._ir), tstream(self.dtype.element_type), self._indices, 
self._aggregations + ) class ArrayStructExpression(ArrayExpression): @@ -1042,9 +1103,11 @@ def add(self, item): Set with `item` added. """ if not self._ec.can_coerce(item.dtype): - raise TypeError("'SetExpression.add' expects 'item' to be the same type as its elements\n" - " set element type: '{}'\n" - " type of arg 'item': '{}'".format(self.dtype.element_type, item.dtype)) + raise TypeError( + "'SetExpression.add' expects 'item' to be the same type as its elements\n" + " set element type: '{}'\n" + " type of arg 'item': '{}'".format(self.dtype.element_type, item.dtype) + ) return self._method("add", self.dtype, self._ec.coerce(item)) @typecheck_method(item=expr_any) @@ -1068,9 +1131,11 @@ def remove(self, item): Set with `item` removed. """ if not self._ec.can_coerce(item.dtype): - raise TypeError("'SetExpression.remove' expects 'item' to be the same type as its elements\n" - " set element type: '{}'\n" - " type of arg 'item': '{}'".format(self.dtype.element_type, item.dtype)) + raise TypeError( + "'SetExpression.remove' expects 'item' to be the same type as its elements\n" + " set element type: '{}'\n" + " type of arg 'item': '{}'".format(self.dtype.element_type, item.dtype) + ) return self._method("remove", self._type, self._ec.coerce(item)) @typecheck_method(item=expr_any) @@ -1097,9 +1162,11 @@ def contains(self, item): ``True`` if `item` is in the set. """ if not self._ec.can_coerce(item.dtype): - raise TypeError("'SetExpression.contains' expects 'item' to be the same type as its elements\n" - " set element type: '{}'\n" - " type of arg 'item': '{}'".format(self.dtype.element_type, item.dtype)) + raise TypeError( + "'SetExpression.contains' expects 'item' to be the same type as its elements\n" + " set element type: '{}'\n" + " type of arg 'item': '{}'".format(self.dtype.element_type, item.dtype) + ) return self._method("contains", tbool, self._ec.coerce(item)) @typecheck_method(s=expr_set()) @@ -1126,9 +1193,11 @@ def difference(self, s): Set of elements not in `s`. """ if not s._type.element_type == self._type.element_type: - raise TypeError("'SetExpression.difference' expects 's' to be the same type\n" - " set type: '{}'\n" - " type of 's': '{}'".format(self._type, s._type)) + raise TypeError( + "'SetExpression.difference' expects 's' to be the same type\n" + " set type: '{}'\n" + " type of 's': '{}'".format(self._type, s._type) + ) return self._method("difference", self._type, s) @typecheck_method(s=expr_set()) @@ -1152,9 +1221,11 @@ def intersection(self, s): Set of elements present in `s`. """ if not s._type.element_type == self._type.element_type: - raise TypeError("'SetExpression.intersection' expects 's' to be the same type\n" - " set type: '{}'\n" - " type of 's': '{}'".format(self._type, s._type)) + raise TypeError( + "'SetExpression.intersection' expects 's' to be the same type\n" + " set type: '{}'\n" + " type of 's': '{}'".format(self._type, s._type) + ) return self._method("intersection", self._type, s) @typecheck_method(s=expr_set()) @@ -1181,9 +1252,11 @@ def is_subset(self, s): ``True`` if every element is contained in set `s`. 
""" if not s._type.element_type == self._type.element_type: - raise TypeError("'SetExpression.is_subset' expects 's' to be the same type\n" - " set type: '{}'\n" - " type of 's': '{}'".format(self._type, s._type)) + raise TypeError( + "'SetExpression.is_subset' expects 's' to be the same type\n" + " set type: '{}'\n" + " type of 's': '{}'".format(self._type, s._type) + ) return self._method("isSubset", tbool, s) @typecheck_method(s=expr_set()) @@ -1207,9 +1280,11 @@ def union(self, s): Set of elements present in either set. """ if not s._type.element_type == self._type.element_type: - raise TypeError("'SetExpression.union' expects 's' to be the same type\n" - " set type: '{}'\n" - " type of 's': '{}'".format(self._type, s._type)) + raise TypeError( + "'SetExpression.union' expects 's' to be the same type\n" + " set type: '{}'\n" + " type of 's': '{}'".format(self._type, s._type) + ) return self._method("union", self._type, s) def __le__(self, other): @@ -1509,9 +1584,11 @@ def __getitem__(self, item): Value associated with key `item`. """ if not self._kc.can_coerce(item.dtype): - raise TypeError("dict encountered an invalid key type\n" - " dict key type: '{}'\n" - " type of 'item': '{}'".format(self.dtype.key_type, item.dtype)) + raise TypeError( + "dict encountered an invalid key type\n" " dict key type: '{}'\n" " type of 'item': '{}'".format( + self.dtype.key_type, item.dtype + ) + ) return self._index(self.dtype.value_type, self._kc.coerce(item)) @typecheck_method(item=expr_any) @@ -1538,9 +1615,11 @@ def contains(self, item): ``True`` if `item` is a key of the dictionary, ``False`` otherwise. """ if not self._kc.can_coerce(item.dtype): - raise TypeError("'DictExpression.contains' encountered an invalid key type\n" - " dict key type: '{}'\n" - " type of 'item': '{}'".format(self._type.key_type, item.dtype)) + raise TypeError( + "'DictExpression.contains' encountered an invalid key type\n" + " dict key type: '{}'\n" + " type of 'item': '{}'".format(self._type.key_type, item.dtype) + ) return self._method("contains", tbool, self._kc.coerce(item)) @typecheck_method(item=expr_any, default=nullable(expr_any)) @@ -1572,16 +1651,21 @@ def get(self, item, default=None): The value associated with `item`, or `default`. 
""" if not self._kc.can_coerce(item.dtype): - raise TypeError("'DictExpression.get' encountered an invalid key type\n" - " dict key type: '{}'\n" - " type of 'item': '{}'".format(self.dtype.key_type, item.dtype)) + raise TypeError( + "'DictExpression.get' encountered an invalid key type\n" + " dict key type: '{}'\n" + " type of 'item': '{}'".format(self.dtype.key_type, item.dtype) + ) key = self._kc.coerce(item) if default is not None: if not self._vc.can_coerce(default.dtype): - raise TypeError("'get' expects parameter 'default' to have the same type " - "as the dictionary value type, expected '{}' and found '{}'" - .format(self.dtype.value_type, default.dtype)) + raise TypeError( + "'get' expects parameter 'default' to have the same type " + "as the dictionary value type, expected '{}' and found '{}'".format( + self.dtype.value_type, default.dtype + ) + ) return self._method("get", self.dtype.value_type, key, self._vc.coerce(default)) else: return self._method("get", self.dtype.value_type, key) @@ -1701,7 +1785,6 @@ def _nested_summary(self, agg_result, top): return { '[]': k._summarize(agg_result[3][0]), '[]': v._summarize(agg_result[3][1]), - } def _summary_aggs(self): @@ -1710,7 +1793,10 @@ def _summary_aggs(self): hl.agg.min(length), hl.agg.max(length), hl.agg.mean(length), - hl.agg.explode(lambda elt: hl.tuple((elt[0]._all_summary_aggs(), elt[1]._all_summary_aggs())), hl.array(self)))) + hl.agg.explode( + lambda elt: hl.tuple((elt[0]._all_summary_aggs(), elt[1]._all_summary_aggs())), hl.array(self) + ), + )) class StructExpression(Mapping[Union[str, int], Expression], Expression): @@ -1758,25 +1844,15 @@ def __init__(self, x, type, indices=Indices(), aggregations=LinkedList(Aggregati for i, (f, t) in enumerate(self.dtype.items()): if isinstance(self._ir, ir.MakeStruct): - expr = construct_expr(self._ir.fields[i][1], - t, - self._indices, - self._aggregations) + expr = construct_expr(self._ir.fields[i][1], t, self._indices, self._aggregations) elif isinstance(self._ir, ir.SelectedTopLevelReference): - expr = construct_expr(ir.ProjectedTopLevelReference(self._ir.ref.name, f, t), - t, - self._indices, - self._aggregations) + expr = construct_expr( + ir.ProjectedTopLevelReference(self._ir.ref.name, f, t), t, self._indices, self._aggregations + ) elif isinstance(self._ir, ir.SelectFields): - expr = construct_expr(ir.GetField(self._ir.old, f), - t, - self._indices, - self._aggregations) + expr = construct_expr(ir.GetField(self._ir.old, f), t, self._indices, self._aggregations) else: - expr = construct_expr(ir.GetField(self._ir, f), - t, - self._indices, - self._aggregations) + expr = construct_expr(ir.GetField(self._ir, f), t, self._indices, self._aggregations) self._set_field(f, expr) def _set_field(self, key, value): @@ -1797,8 +1873,10 @@ def _get_field(self, item): def __getattribute__(self, item): if item in super().__getattribute__('_warn_on_shadowed_name'): - warning(f'Field {item} is shadowed by another method or attribute. ' - f'Use ["{item}"] syntax to access the field.') + warning( + f'Field {item} is shadowed by another method or attribute. ' + f'Use ["{item}"] syntax to access the field.' 
+ ) self._warn_on_shadowed_name.remove(item) return super().__getattribute__(item) @@ -1842,8 +1920,7 @@ def __getitem__(self, item): assert item.start is None or isinstance(item.start, int) assert item.stop is None or isinstance(item.stop, int) assert item.step is None or isinstance(item.step, int) - return self.select( - *self.dtype.fields[item.start:item.stop:item.step]) + return self.select(*self.dtype.fields[item.start : item.stop : item.step]) def __iter__(self): return iter(self._fields) @@ -1879,11 +1956,14 @@ def get_type(field): new_type = hl.tstruct(**{f: get_type(f) for f in field_order}) indices, aggregations = unify_all(self, *insertions_dict.values()) - return construct_expr(ir.InsertFields.construct_with_deduplication( - self._ir, [(field, expr._ir) for field, expr in insertions_dict.items()], field_order), + return construct_expr( + ir.InsertFields.construct_with_deduplication( + self._ir, [(field, expr._ir) for field, expr in insertions_dict.items()], field_order + ), new_type, indices, - aggregations) + aggregations, + ) @typecheck_method(named_exprs=expr_any) def annotate(self, **named_exprs): @@ -1919,9 +1999,14 @@ def annotate(self, **named_exprs): result_type = tstruct(**new_types) indices, aggregations = unify_all(self, *[x for (f, x) in named_exprs.items()]) - return construct_expr(ir.InsertFields.construct_with_deduplication( - self._ir, list(map(lambda x: (x[0], x[1]._ir), named_exprs.items())), None), - result_type, indices, aggregations) + return construct_expr( + ir.InsertFields.construct_with_deduplication( + self._ir, list(map(lambda x: (x[0], x[1]._ir), named_exprs.items())), None + ), + result_type, + indices, + aggregations, + ) @typecheck_method(fields=str, named_exprs=expr_any) def select(self, *fields, **named_exprs): @@ -1957,18 +2042,25 @@ def select(self, *fields, **named_exprs): name_set = set() for a in fields: if a not in self._fields: - raise KeyError("Struct has no field '{}'\n" - " Fields: [ {} ]".format(a, ', '.join("'{}'".format(x) for x in self._fields))) + raise KeyError( + "Struct has no field '{}'\n" " Fields: [ {} ]".format( + a, ', '.join("'{}'".format(x) for x in self._fields) + ) + ) if a in name_set: - raise ExpressionException("'StructExpression.select' does not support duplicate identifiers.\n" - " Identifier '{}' appeared more than once".format(a)) + raise ExpressionException( + "'StructExpression.select' does not support duplicate identifiers.\n" + " Identifier '{}' appeared more than once".format(a) + ) name_set.add(a) - for (n, _) in named_exprs.items(): + for n, _ in named_exprs.items(): if n in name_set: raise ExpressionException("Cannot select and assign '{}' in the same 'select' call".format(n)) selected_type = tstruct(**{f: self.dtype[f] for f in fields}) - selected_expr = construct_expr(ir.SelectFields(self._ir, fields), selected_type, self._indices, self._aggregations) + selected_expr = construct_expr( + ir.SelectFields(self._ir, fields), selected_type, self._indices, self._aggregations + ) if len(named_exprs) == 0: return selected_expr @@ -2012,15 +2104,13 @@ def rename(self, mapping): if old not in old_fields: raise ValueError(f'{old} is not a field of this struct: {self.dtype}.') if new in old_fields and new not in mapping: - raise ValueError(f'{old} is renamed to {new} but {new} is already in the ' - f'struct: {self.dtype}.') + raise ValueError(f'{old} is renamed to {new} but {new} is already in the ' f'struct: {self.dtype}.') if new in new_to_old: raise ValueError(f'{new} is the new name of both {old} and 
{new_to_old[new]}.') new_to_old[new] = old return self.select( - *list(set(self._fields) - set(mapping)), - **{new: self._get_field(old) for old, new in mapping.items()} + *list(set(self._fields) - set(mapping)), **{new: self._get_field(old) for old, new in mapping.items()} ) @typecheck_method(fields=str) @@ -2046,8 +2136,11 @@ def drop(self, *fields): to_drop = set() for a in fields: if a not in self._fields: - raise KeyError("Struct has no field '{}'\n" - " Fields: [ {} ]".format(a, ', '.join("'{}'".format(x) for x in self._fields))) + raise KeyError( + "Struct has no field '{}'\n" " Fields: [ {} ]".format( + a, ', '.join("'{}'".format(x) for x in self._fields) + ) + ) if a in to_drop: warning("Found duplicate field name in 'StructExpression.drop': '{}'".format(a)) to_drop.add(a) @@ -2057,11 +2150,13 @@ def drop(self, *fields): def flatten(self): """Recursively eliminate struct fields by adding their fields to this struct.""" + def _flatten(prefix, s): if isinstance(s, StructExpression): return [(k, v) for (f, e) in s.items() for (k, v) in _flatten(prefix + '.' + f, e)] else: return [(prefix, s)] + return self.select(**{k: v for (f, e) in self.items() for (k, v) in _flatten(f, e)}) def _nested_summary(self, agg_result, top): @@ -2117,13 +2212,11 @@ def __getitem__(self, item): assert item.start is None or isinstance(item.start, int) assert item.stop is None or isinstance(item.stop, int) assert item.step is None or isinstance(item.step, int) - return hl.or_missing(hl.is_defined(self), - hl.tuple([ - self[i] - for i in range(len(self))[item.start:item.stop:item.step]])) + return hl.or_missing( + hl.is_defined(self), hl.tuple([self[i] for i in range(len(self))[item.start : item.stop : item.step]]) + ) if not 0 <= item < len(self): - raise IndexError("Out of bounds index, {}. Tuple length is {}.".format( - item, len(self))) + raise IndexError("Out of bounds index, {}. 
Tuple length is {}.".format(item, len(self))) return construct_expr(ir.GetTupleElement(self._ir, item), self.dtype.types[item], self._indices) def __len__(self): @@ -2618,7 +2711,7 @@ def _extra_summary_fields(self, agg_result): 'Minimum': agg_result['min'], 'Maximum': agg_result['max'], 'Mean': agg_result['mean'], - 'Std Dev': agg_result['stdev'] + 'Std Dev': agg_result['stdev'], } def _summary_aggs(self): @@ -2633,7 +2726,7 @@ def _extra_summary_fields(self, agg_result): 'Minimum': agg_result['min'], 'Maximum': agg_result['max'], 'Mean': agg_result['mean'], - 'Std Dev': agg_result['stdev'] + 'Std Dev': agg_result['stdev'], } def _summary_aggs(self): @@ -2642,12 +2735,13 @@ def _summary_aggs(self): class Int32Expression(NumericExpression): """Expression of type :py:data:`.tint32`.""" + def _extra_summary_fields(self, agg_result): return { 'Minimum': int(agg_result['min']), 'Maximum': int(agg_result['max']), 'Mean': agg_result['mean'], - 'Std Dev': agg_result['stdev'] + 'Std Dev': agg_result['stdev'], } def _summary_aggs(self): @@ -2670,12 +2764,13 @@ def __rmul__(self, other): class Int64Expression(NumericExpression): """Expression of type :py:data:`.tint64`.""" + def _extra_summary_fields(self, agg_result): return { 'Minimum': int(agg_result['min']), 'Maximum': int(agg_result['max']), 'Mean': agg_result['mean'], - 'Std Dev': agg_result['stdev'] + 'Std Dev': agg_result['stdev'], } def _summary_aggs(self): @@ -2715,13 +2810,17 @@ def __getitem__(self, item): else: item = to_expr(item) if not item.dtype == tint32: - raise TypeError("String expects index to be type 'slice' or expression of type 'int32', " - "found expression of type '{}'".format(item.dtype)) + raise TypeError( + "String expects index to be type 'slice' or expression of type 'int32', " + "found expression of type '{}'".format(item.dtype) + ) return self._index(tstr, item) def __contains__(self, item): - raise TypeError("Cannot use `in` operator on hail `StringExpression`s. Use the `contains` method instead." - "`my_string.contains('cat')` instead of `'cat' in my_string`") + raise TypeError( + "Cannot use `in` operator on hail `StringExpression`s. Use the `contains` method instead." + "`my_string.contains('cat')` instead of `'cat' in my_string`" + ) def __add__(self, other): """Concatenate strings. @@ -2774,12 +2873,11 @@ def _slice(self, start=None, stop=None, step=None): return self._method('slice', tstr, start, stop) else: return self._method('sliceRight', tstr, start) + elif stop is not None: + stop = to_expr(stop) + return self._method('sliceLeft', tstr, stop) else: - if stop is not None: - stop = to_expr(stop) - return self._method('sliceLeft', tstr, stop) - else: - return self + return self def length(self): """Returns the length of the string. 
@@ -3203,7 +3301,8 @@ def _summary_aggs(self): hl.agg.min(length), hl.agg.max(length), hl.agg.mean(length), - hl.agg.filter(hl.is_defined(self), hl.agg.take(self, 5)))) + hl.agg.filter(hl.is_defined(self), hl.agg.take(self, 5)), + )) class CallExpression(Expression): @@ -3240,8 +3339,10 @@ def __getitem__(self, item): else: item = to_expr(item) if not item.dtype == tint32: - raise TypeError("Call expects allele index to be an expression of type 'int32', " - "found expression of type '{}'".format(item.dtype)) + raise TypeError( + "Call expects allele index to be an expression of type 'int32', " + "found expression of type '{}'".format(item.dtype) + ) return self._index(tint32, item) def unphase(self): @@ -3532,7 +3633,7 @@ def _extra_summary_fields(self, agg_result): 'Heterozygous': agg_result[1], 'Homozygous Variant': agg_result[2], 'Ploidy': agg_result[3], - 'Phased': agg_result[4] + 'Phased': agg_result[4], } def _summary_aggs(self): @@ -3541,7 +3642,8 @@ def _summary_aggs(self): hl.agg.count_where(self.is_het()), hl.agg.count_where(self.is_hom_var()), hl.agg.filter(hl.is_defined(self), hl.agg.counter(self.ploidy)), - hl.agg.filter(hl.is_defined(self), hl.agg.counter(self.phased)))) + hl.agg.filter(hl.is_defined(self), hl.agg.counter(self.phased)), + )) class LocusExpression(Expression): @@ -3810,7 +3912,11 @@ def sequence_context(self, before=0, after=0): rg = self.dtype.reference_genome if not rg.has_sequence(): - raise TypeError("Reference genome '{}' does not have a sequence loaded. Use 'add_sequence' to load the sequence from a FASTA file.".format(rg.name)) + raise TypeError( + "Reference genome '{}' does not have a sequence loaded. Use 'add_sequence' to load the sequence from a FASTA file.".format( + rg.name + ) + ) return hl.get_sequence(self.contig, self.position, before, after, rg) @typecheck_method(before=expr_int32, after=expr_int32) @@ -3846,10 +3952,12 @@ def window(self, before, after): start_pos = hl.max(1, self.position - before) rg = self.dtype.reference_genome end_pos = hl.min(hl.contig_length(self.contig, rg), self.position + after) - return hl.interval(start=hl.locus(self.contig, start_pos, reference_genome=rg), - end=hl.locus(self.contig, end_pos, reference_genome=rg), - includes_start=True, - includes_end=True) + return hl.interval( + start=hl.locus(self.contig, start_pos, reference_genome=rg), + end=hl.locus(self.contig, end_pos, reference_genome=rg), + includes_start=True, + includes_end=True, + ) def _extra_summary_fields(self, agg_result): return {'Contig Counts': agg_result} @@ -4054,8 +4162,10 @@ def transpose(self, axes=None): axes = list(reversed(range(self.ndim))) else: if len(axes) != self.ndim: - raise ValueError(f'Must specify a complete permutation of the dimensions. ' - f'Expected {self.ndim} axes, got {len(axes)}') + raise ValueError( + f'Must specify a complete permutation of the dimensions. 
' + f'Expected {self.ndim} axes, got {len(axes)}' + ) if len(set(axes)) != len(axes): raise ValueError(f'Axes cannot contain duplicates: {axes}') @@ -4087,14 +4197,20 @@ def shape(self): _opt_long_slice = sliceof(nullable(expr_int64), nullable(expr_int64), nullable(expr_int64)) - @typecheck_method(item=nullable(oneof(expr_int64, type(...), _opt_long_slice, tupleof(nullable(oneof(expr_int64, type(...), _opt_long_slice)))))) + @typecheck_method( + item=nullable( + oneof( + expr_int64, type(...), _opt_long_slice, tupleof(nullable(oneof(expr_int64, type(...), _opt_long_slice))) + ) + ) + ) def __getitem__(self, item): if not isinstance(item, tuple): item = (item,) num_ellipses = len([e for e in item if isinstance(e, type(...))]) if num_ellipses > 1: - raise IndexError("an index can only have a single ellipsis (\'...\')") + raise IndexError("an index can only have a single ellipsis ('...')") num_nones = len([x for x in item if x is None]) list_item = list(item) @@ -4103,7 +4219,9 @@ def __getitem__(self, item): list_types = [type(e) for e in list_item] ellipsis_location = list_types.index(type(...)) num_slices_to_add = self.ndim - (len(item) - num_nones) + 1 - no_ellipses = list_item[:ellipsis_location] + [slice(None)] * num_slices_to_add + list_item[ellipsis_location + 1:] + no_ellipses = ( + list_item[:ellipsis_location] + [slice(None)] * num_slices_to_add + list_item[ellipsis_location + 1 :] + ) else: no_ellipses = list_item @@ -4112,8 +4230,9 @@ def __getitem__(self, item): formatted_item = [x for x in no_ellipses if x is not None] if len(formatted_item) > self.ndim: - raise IndexError(f'too many indices for array: array is ' - f'{self.ndim}-dimensional, but {len(item)} were indexed') + raise IndexError( + f'too many indices for array: array is ' f'{self.ndim}-dimensional, but {len(item)} were indexed' + ) if len(formatted_item) < self.ndim: formatted_item += [slice(None, None, None)] * (self.ndim - len(formatted_item)) @@ -4124,10 +4243,8 @@ def __getitem__(self, item): for i, s in enumerate(formatted_item): dlen = self.shape[i] if isinstance(s, slice): - if s.step is not None: - step = hl.case().when(s.step != 0, s.step) \ - .or_error("Slice step cannot be zero") + step = hl.case().when(s.step != 0, s.step).or_error("Slice step cannot be zero") else: step = to_expr(1, tint64) @@ -4137,37 +4254,50 @@ def __getitem__(self, item): if s.start is not None: # python treats start < -dlen as None when step < 0: [0,1][-3:0:-1] # and 0 otherwise: [0,1][-3::1] == [0,1][0::1] - start = hl.case() \ - .when(s.start >= dlen, max_bound) \ - .when(s.start >= 0, s.start) \ - .when((s.start + dlen) >= 0, dlen + s.start) \ + start = ( + hl.case() + .when(s.start >= dlen, max_bound) + .when(s.start >= 0, s.start) + .when((s.start + dlen) >= 0, dlen + s.start) .default(min_bound) + ) else: start = hl.if_else(step >= 0, to_expr(0, tint64), dlen - 1) if s.stop is not None: # python treats stop < -dlen as None when step < 0: [0,1][0:-3:-1] == [0,1][0::-1] # and 0 otherwise: [0,1][:-3:1] == [0,1][:0:1] - stop = hl.case() \ - .when(s.stop >= dlen, max_bound) \ - .when(s.stop >= 0, s.stop) \ - .when((s.stop + dlen) >= 0, dlen + s.stop) \ + stop = ( + hl.case() + .when(s.stop >= dlen, max_bound) + .when(s.stop >= 0, s.stop) + .when((s.stop + dlen) >= 0, dlen + s.stop) .default(min_bound) + ) else: stop = hl.if_else(step > 0, dlen, to_expr(-1, tint64)) slices.append(hl.tuple((start, stop, step))) else: adjusted_index = hl.if_else(s < 0, s + dlen, s) - checked_int = hl.case().when((adjusted_index < dlen) & 
(adjusted_index >= 0), adjusted_index).or_error( - hl.str("Index ") + hl.str(s) + hl.str(f" is out of bounds for axis {i} with size ") + hl.str(dlen) + checked_int = ( + hl.case() + .when((adjusted_index < dlen) & (adjusted_index >= 0), adjusted_index) + .or_error( + hl.str("Index ") + + hl.str(s) + + hl.str(f" is out of bounds for axis {i} with size ") + + hl.str(dlen) + ) ) slices.append(checked_int) indices, aggregations = unify_all(self, *slices) - product = construct_expr(ir.NDArraySlice(self._ir, hl.tuple(slices)._ir), - tndarray(self._type.element_type, n_sliced_dims), - indices, - aggregations) + product = construct_expr( + ir.NDArraySlice(self._ir, hl.tuple(slices)._ir), + tndarray(self._type.element_type, n_sliced_dims), + indices, + aggregations, + ) if len(indices_nones) > 0: reshape_arg = [] @@ -4182,10 +4312,12 @@ def __getitem__(self, item): else: indices, aggregations = unify_all(self, *formatted_item) - product = construct_expr(ir.NDArrayRef(self._ir, [idx._ir for idx in formatted_item]), - self._type.element_type, - indices, - aggregations) + product = construct_expr( + ir.NDArrayRef(self._ir, [idx._ir for idx in formatted_item]), + self._type.element_type, + indices, + aggregations, + ) if len(indices_nones) > 0: reshape_arg = [] @@ -4239,10 +4371,9 @@ def reshape(self, *shape): ndim = len(wrapped_shape) shape_ir = hl.tuple(wrapped_shape)._ir - return construct_expr(ir.NDArrayReshape(self._ir, shape_ir), - tndarray(self._type.element_type, ndim), - indices, - aggregations) + return construct_expr( + ir.NDArrayReshape(self._ir, shape_ir), tndarray(self._type.element_type, ndim), indices, aggregations + ) @typecheck_method(f=func_spec(1, expr_any)) def map(self, f): @@ -4283,15 +4414,16 @@ def map2(self, other, f): Element-wise result of applying `f` to each index in NDArrays. 
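`map2` (reformatted just below) applies an element-wise function across two ndarrays; per the `isinstance(other, (list, np.ndarray))` change, a plain list or NumPy array operand is first wrapped with `hl.nd.array` and then broadcast. A small usage sketch with illustrative values:

    import hail as hl
    import numpy as np

    nd = hl.nd.array([[1, 2], [3, 4]])

    hl.eval(nd + [[10, 20], [30, 40]])                                 # list operand is wrapped via hl.nd.array
    hl.eval(nd.map2(np.array([[5, 5], [5, 5]]), lambda a, b: a * b))   # element-wise product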
""" - if isinstance(other, list) or isinstance(other, np.ndarray): + if isinstance(other, (list, np.ndarray)): other = hl.nd.array(other) self_broadcast, other_broadcast = self._broadcast_to_same_ndim(other) element_type1 = self_broadcast._type.element_type element_type2 = other_broadcast._type.element_type - ndarray_map2 = self_broadcast._ir_lambda_method2(other_broadcast, ir.NDArrayMap2, f, element_type1, - element_type2, lambda t: tndarray(t, self_broadcast.ndim)) + ndarray_map2 = self_broadcast._ir_lambda_method2( + other_broadcast, ir.NDArrayMap2, f, element_type1, element_type2, lambda t: tndarray(t, self_broadcast.ndim) + ) assert isinstance(self._type, tndarray) return ndarray_map2 @@ -4317,9 +4449,12 @@ def _broadcast(self, n_output_dims): new_dims = range(self.ndim, n_output_dims) idx_mapping = list(reversed(new_dims)) + list(old_dims) - return construct_expr(ir.NDArrayReindex(self._ir, idx_mapping), - tndarray(self._type.element_type, n_output_dims), - self._indices, self._aggregations) + return construct_expr( + ir.NDArrayReindex(self._ir, idx_mapping), + tndarray(self._type.element_type, n_output_dims), + self._indices, + self._aggregations, + ) class NDArrayNumericExpression(NDArrayExpression): @@ -4335,18 +4470,20 @@ class NDArrayNumericExpression(NDArrayExpression): """ def _bin_op_numeric(self, name, other, ret_type_f=None): - if isinstance(other, list) or isinstance(other, np.ndarray): + if isinstance(other, (list, np.ndarray)): other = hl.nd.array(other) self_broadcast, other_broadcast = self._broadcast_to_same_ndim(other) return super(NDArrayNumericExpression, self_broadcast)._bin_op_numeric(name, other_broadcast, ret_type_f) def _bin_op_numeric_reverse(self, name, other, ret_type_f=None): - if isinstance(other, list) or isinstance(other, np.ndarray): + if isinstance(other, (list, np.ndarray)): other = hl.nd.array(other) self_broadcast, other_broadcast = self._broadcast_to_same_ndim(other) - return super(NDArrayNumericExpression, self_broadcast)._bin_op_numeric_reverse(name, other_broadcast, ret_type_f) + return super(NDArrayNumericExpression, self_broadcast)._bin_op_numeric_reverse( + name, other_broadcast, ret_type_f + ) def __neg__(self): """Negate elements of the ndarray. 
@@ -4489,6 +4626,7 @@ def __matmul__(self, other): left, right = self, other from hail.linalg.utils.misc import _ndarray_matmul_ndim + result_ndim = _ndarray_matmul_ndim(left.ndim, right.ndim) elem_type = unify_types(self._type.element_type, other._type.element_type) ret_type = tndarray(elem_type, result_ndim) @@ -4532,7 +4670,9 @@ def sum(self, axis=None): num_axes_deleted = len(axes_set) result_ndim = self.ndim - num_axes_deleted - result = construct_expr(res_ir, tndarray(self._type.element_type, result_ndim), self._indices, self._aggregations) + result = construct_expr( + res_ir, tndarray(self._type.element_type, result_ndim), self._indices, self._aggregations + ) if result_ndim == 0: return result[()] @@ -4603,10 +4743,12 @@ def fold(self, f, zero): zero = zero_coerced if body.dtype != zero.dtype: - raise ExpressionException("'StreamExpression.fold' must take function returning " - "same expression type as zero value: \n" - " zero.dtype: {}\n" - " f.dtype: {}".format(zero.dtype, body.dtype)) + raise ExpressionException( + "'StreamExpression.fold' must take function returning " + "same expression type as zero value: \n" + " zero.dtype: {}\n" + " f.dtype: {}".format(zero.dtype, body.dtype) + ) x = ir.StreamFold(self._ir, zero._ir, accum_name, elt_name, body._ir) @@ -4638,10 +4780,12 @@ def scan(self, f, zero): zero = zero_coerced if body.dtype != zero.dtype: - raise ExpressionException("'StreamExpression.scan' must take function returning " - "same expression type as zero value: \n" - " zero.dtype: {}\n" - " f.dtype: {}".format(zero.dtype, body.dtype)) + raise ExpressionException( + "'StreamExpression.scan' must take function returning " + "same expression type as zero value: \n" + " zero.dtype: {}\n" + " f.dtype: {}".format(zero.dtype, body.dtype) + ) x = ir.StreamScan(self._ir, zero._ir, accum_name, elt_name, body._ir) @@ -4666,15 +4810,10 @@ def zip_with_index(self, start, index_first=True): else: tuple = ir.MakeTuple([ir.Ref(elt, elt_type), ir.Ref(idx, tint32)]) return construct_expr( - ir.StreamZip( - [self._ir, ir.StreamIota(start._ir, ir.I32(1))], - [elt, idx], - tuple, - 'TakeMinLength' - ), + ir.StreamZip([self._ir, ir.StreamIota(start._ir, ir.I32(1))], [elt, idx], tuple, 'TakeMinLength'), hl.tstream(hl.ttuple(hl.tint32, elt_type) if index_first else hl.ttuple(elt_type, hl.tint32)), indices, - aggs + aggs, ) @typecheck_method(group_size=expr_int32) @@ -4682,17 +4821,21 @@ def grouped(self, group_size): indices, aggregations = unify_all(self, group_size) stream_ir = ir.StreamGrouped(self._ir, group_size._ir) mapping_identifier = Env.get_uid("stream_grouped_map_to_arrays") - mapped_to_arrays = ir.StreamMap(stream_ir, mapping_identifier, ir.toArray(ir.Ref(mapping_identifier, tstream(self._type.element_type)))) + mapped_to_arrays = ir.StreamMap( + stream_ir, mapping_identifier, ir.toArray(ir.Ref(mapping_identifier, tstream(self._type.element_type))) + ) return construct_expr(mapped_to_arrays, tstream(tarray(self._type.element_type)), indices, aggregations) -scalars = {tbool: BooleanExpression, - tint32: Int32Expression, - tint64: Int64Expression, - tfloat32: Float32Expression, - tfloat64: Float64Expression, - tstr: StringExpression, - tcall: CallExpression} +scalars = { + tbool: BooleanExpression, + tint32: Int32Expression, + tint64: Int64Expression, + tfloat32: Float32Expression, + tfloat64: Float64Expression, + tstr: StringExpression, + tcall: CallExpression, +} typ_to_expr = { tlocus: LocusExpression, @@ -4704,7 +4847,7 @@ def grouped(self, group_size): tset: SetExpression, 
tstruct: StructExpression, ttuple: TupleExpression, - tndarray: NDArrayExpression + tndarray: NDArrayExpression, } @@ -4715,10 +4858,9 @@ def apply_expr(f, result_type, *args): @typecheck(x=ir.IR, type=nullable(HailType), indices=Indices, aggregations=LinkedList) -def construct_expr(x: ir.IR, - type: HailType, - indices: Indices = Indices(), - aggregations: LinkedList = LinkedList(Aggregation)): +def construct_expr( + x: ir.IR, type: HailType, indices: Indices = Indices(), aggregations: LinkedList = LinkedList(Aggregation) +): if type is None: return Expression(x, None, indices, aggregations) x.assign_type(type) @@ -4760,7 +4902,5 @@ def construct_reference(name, type, indices): @typecheck(name=str, type=HailType, indices=Indices, aggregations=LinkedList) -def construct_variable(name, type, - indices: Indices = Indices(), - aggregations: LinkedList = LinkedList(Aggregation)): +def construct_variable(name, type, indices: Indices = Indices(), aggregations: LinkedList = LinkedList(Aggregation)): return construct_expr(ir.Ref(name, type), type, indices, aggregations) diff --git a/hail/python/hail/expr/functions.py b/hail/python/hail/expr/functions.py index 97d08feb2df..6d6d67dcf7d 100644 --- a/hail/python/hail/expr/functions.py +++ b/hail/python/hail/expr/functions.py @@ -1,46 +1,119 @@ -import operator import builtins import functools -from typing import Union, Optional, Any, Callable, Iterable, TypeVar -import pandas as pd +import operator +from typing import Any, Callable, Iterable, Optional, TypeVar, Union +import numpy as np +import pandas as pd from deprecated import deprecated import hail import hail as hl -from hail.expr.expressions import (Expression, ArrayExpression, StreamExpression, SetExpression, - Int32Expression, Int64Expression, Float32Expression, Float64Expression, - DictExpression, StructExpression, LocusExpression, StringExpression, - IntervalExpression, ArrayNumericExpression, BooleanExpression, - CallExpression, TupleExpression, ExpressionException, NumericExpression, - unify_all, construct_expr, to_expr, unify_exprs, impute_type, - construct_variable, apply_expr, coercer_from_dtype, unify_types_limited, - expr_array, expr_any, expr_struct, expr_int32, expr_int64, expr_float32, - expr_float64, expr_oneof, expr_bool, expr_tuple, expr_dict, expr_str, expr_stream, - expr_set, expr_call, expr_locus, expr_interval, expr_ndarray, expr_numeric, - cast_expr) -from hail.expr.types import (HailType, hail_type, tint32, tint64, tfloat32, - tfloat64, tstr, tbool, tarray, tstream, tset, tdict, - tstruct, tlocus, tinterval, tcall, ttuple, - tndarray, trngstate, is_primitive, is_numeric, - is_int32, is_int64, is_float32, is_float64) -from hail.genetics.reference_genome import reference_genome_type, ReferenceGenome -import hail.ir as ir -from hail.typecheck import (typecheck, nullable, anytype, enumeration, tupleof, - func_spec, oneof, arg_check, args_check, anyfunc, - sequenceof) +from hail import ir +from hail.expr.expressions import ( + ArrayExpression, + ArrayNumericExpression, + BooleanExpression, + CallExpression, + DictExpression, + Expression, + ExpressionException, + Float32Expression, + Float64Expression, + Int32Expression, + Int64Expression, + IntervalExpression, + LocusExpression, + NumericExpression, + SetExpression, + StreamExpression, + StringExpression, + StructExpression, + TupleExpression, + apply_expr, + cast_expr, + coercer_from_dtype, + construct_expr, + construct_variable, + expr_any, + expr_array, + expr_bool, + expr_call, + expr_dict, + expr_float32, + expr_float64, 
+ expr_int32, + expr_int64, + expr_interval, + expr_locus, + expr_ndarray, + expr_numeric, + expr_oneof, + expr_set, + expr_str, + expr_stream, + expr_struct, + expr_tuple, + impute_type, + to_expr, + unify_all, + unify_exprs, + unify_types_limited, +) +from hail.expr.types import ( + HailType, + hail_type, + is_float32, + is_float64, + is_int32, + is_int64, + is_numeric, + is_primitive, + tarray, + tbool, + tcall, + tdict, + tfloat32, + tfloat64, + tint32, + tint64, + tinterval, + tlocus, + tndarray, + trngstate, + tset, + tstr, + tstream, + tstruct, + ttuple, +) +from hail.genetics.allele_type import AlleleType +from hail.genetics.reference_genome import ReferenceGenome, reference_genome_type +from hail.typecheck import ( + anyfunc, + anytype, + arg_check, + args_check, + enumeration, + func_spec, + nullable, + oneof, + sequenceof, + tupleof, + typecheck, +) from hail.utils.java import Env, warning from hail.utils.misc import plural -import numpy as np - -Coll_T = TypeVar('Collection_T', ArrayExpression, SetExpression) -Num_T = TypeVar('Numeric_T', Int32Expression, Int64Expression, Float32Expression, Float64Expression) +Coll_T = TypeVar('Coll_T', ArrayExpression, SetExpression) +Num_T = TypeVar('Num_T', Int32Expression, Int64Expression, Float32Expression, Float64Expression) def _func(name, ret_type, *args, type_args=()): indices, aggregations = unify_all(*args) - return construct_expr(ir.Apply(name, ret_type, *(a._ir for a in args), type_args=type_args), ret_type, indices, aggregations) + return construct_expr( + ir.Apply(name, ret_type, *(a._ir for a in args), type_args=type_args), ret_type, indices, aggregations + ) def _seeded_func(name, ret_type, seed, *args): @@ -48,13 +121,20 @@ def _seeded_func(name, ret_type, seed, *args): static_rng_uid = Env.next_static_rng_uid() else: if Env._hc is None or not Env._hc._user_specified_rng_nonce: - warning('To ensure reproducible randomness across Hail sessions, ' - 'you must set the "global_seed" parameter in hl.init(), in ' - 'addition to the local seed in each random function.') + warning( + 'To ensure reproducible randomness across Hail sessions, ' + 'you must set the "global_seed" parameter in hl.init(), in ' + 'addition to the local seed in each random function.' 
+ ) static_rng_uid = -seed - 1 indices, aggregations = unify_all(*args) rng_state = ir.Ref('__rng_state', trngstate) - return construct_expr(ir.ApplySeeded(name, static_rng_uid, rng_state, ret_type, *(a._ir for a in args)), ret_type, indices, aggregations) + return construct_expr( + ir.ApplySeeded(name, static_rng_uid, rng_state, ret_type, *(a._ir for a in args)), + ret_type, + indices, + aggregations, + ) def ndarray_broadcasting(func): @@ -63,6 +143,7 @@ def broadcast_or_not(x): return x.map(func) else: return func(x) + return broadcast_or_not @@ -80,20 +161,24 @@ def compute(cdf): n = cdf.ranks[cdf.ranks.length() - 1] pos = hl.int64(q * n) + 1 idx = hl.max(0, hl.min(cdf['values'].length() - 1, _lower_bound(cdf.ranks, pos) - 1)) - res = hl.if_else(n == 0, - hl.missing(cdf['values'].dtype.element_type), - cdf['values'][idx]) + res = hl.if_else(n == 0, hl.missing(cdf['values'].dtype.element_type), cdf['values'][idx]) return res + return hl.rbind(cdf, compute) @typecheck(raw_cdf=expr_struct()) def _result_from_raw_cdf(raw_cdf): levels = raw_cdf.levels - item_weights = hl._stream_range(hl.len(levels) - 1) \ - .flatmap(lambda l: hl._stream_range(levels[l], levels[l+1]) - .map(lambda i: hl.struct(level=l, value=raw_cdf['items'][i]))) \ + item_weights = ( + hl._stream_range(hl.len(levels) - 1) + .flatmap( + lambda l: hl._stream_range(levels[l], levels[l + 1]).map( + lambda i: hl.struct(level=l, value=raw_cdf['items'][i]) + ) + ) .aggregate(lambda x: hl.agg.group_by(x.value, hl.agg.sum(hl.bit_lshift(1, x.level)))) + ) weights = item_weights.values() ranks = weights.scan(lambda acc, weight: acc + weight, 0) values = item_weights.keys() @@ -125,8 +210,11 @@ def _error_from_cdf(cdf, failure_prob, all_quantiles=False): :class:`.NumericExpression` Upper bound on error of quantile estimates. 
""" + def compute_sum(cdf): - s = hl.sum(hl.range(0, hl.len(cdf._compaction_counts)).map(lambda i: cdf._compaction_counts[i] * (2 ** (2 * i)))) + s = hl.sum( + hl.range(0, hl.len(cdf._compaction_counts)).map(lambda i: cdf._compaction_counts[i] * (2 ** (2 * i))) + ) return s / (cdf.ranks[-1] ** 2) def update_grid_size(p, s): @@ -305,23 +393,26 @@ def typecheck_expr(t, x): if isinstance(x, Expression): wrapper['has_expr'] = True wrapper['has_free_vars'] |= ( - builtins.len(x._ir.free_vars) > 0 or - builtins.len(x._ir.free_agg_vars) > 0 or - builtins.len(x._ir.free_scan_vars) > 0 + builtins.len(x._ir.free_vars) > 0 + or builtins.len(x._ir.free_agg_vars) > 0 + or builtins.len(x._ir.free_scan_vars) > 0 ) if x.dtype != t: raise TypeError(f"'literal': type mismatch: expected '{t}', found '{x.dtype}'") elif x._indices.source is not None: if x._indices.axes: - raise ExpressionException(f"'literal' can only accept scalar or global expression arguments," - f" found indices {x._indices.axes}") + raise ExpressionException( + f"'literal' can only accept scalar or global expression arguments," + f" found indices {x._indices.axes}" + ) return False elif x is None or x is pd.NA: return False else: t._typecheck_one_level(x) return True + if dtype is None: dtype = impute_type(x) @@ -334,8 +425,7 @@ def typecheck_expr(t, x): try: dtype._traverse(x, typecheck_expr) except TypeError as e: - raise TypeError("'literal': object did not match the passed type '{}'" - .format(dtype)) from e + raise TypeError("'literal': object did not match the passed type '{}'".format(dtype)) from e if wrapper['has_free_vars']: raise ValueError( @@ -377,10 +467,7 @@ def typecheck_expr(t, x): @deprecated(version="0.2.59", reason="Replaced by hl.if_else") @typecheck(condition=expr_bool, consequent=expr_any, alternate=expr_any, missing_false=bool) -def cond(condition, - consequent, - alternate, - missing_false: bool = False): +def cond(condition, consequent, alternate, missing_false: bool = False): """Deprecated in favor of :func:`.if_else`. Expression for an if/else statement; tests a condition and returns one of two options based on the result. @@ -431,10 +518,7 @@ def cond(condition, @typecheck(condition=expr_bool, consequent=expr_any, alternate=expr_any, missing_false=bool) -def if_else(condition, - consequent, - alternate, - missing_false: bool = False): +def if_else(condition, consequent, alternate, missing_false: bool = False): """Expression for an if/else statement; tests a condition and returns one of two options based on the result. Examples @@ -480,19 +564,19 @@ def if_else(condition, One of `consequent`, `alternate`, or missing, based on `condition`. 
""" if missing_false: - condition = hl.bind(lambda x: hl.is_defined(x) & x, - condition) + condition = hl.bind(lambda x: hl.is_defined(x) & x, condition) indices, aggregations = unify_all(condition, consequent, alternate) consequent, alternate, success = unify_exprs(consequent, alternate) if not success: - raise TypeError(f"'if_else' and 'cond' require the 'consequent' and 'alternate' arguments to have the same type\n" - f" consequent: type '{consequent.dtype}'\n" - f" alternate: type '{alternate.dtype}'") + raise TypeError( + f"'if_else' and 'cond' require the 'consequent' and 'alternate' arguments to have the same type\n" + f" consequent: type '{consequent.dtype}'\n" + f" alternate: type '{alternate.dtype}'" + ) assert consequent.dtype == alternate.dtype - return construct_expr(ir.If(condition._ir, consequent._ir, alternate._ir), - consequent.dtype, indices, aggregations) + return construct_expr(ir.If(condition._ir, consequent._ir, alternate._ir), consequent.dtype, indices, aggregations) def case(missing_false: bool = False) -> 'hail.expr.builders.CaseBuilder': @@ -524,6 +608,7 @@ def case(missing_false: bool = False) -> 'hail.expr.builders.CaseBuilder': :class:`.CaseBuilder`. """ from .builders import CaseBuilder + return CaseBuilder(missing_false=missing_false) @@ -560,6 +645,7 @@ def switch(expr) -> 'hail.expr.builders.SwitchBuilder': :class:`.SwitchBuilder` """ from .builders import SwitchBuilder + return SwitchBuilder(expr) @@ -607,7 +693,7 @@ def bind(f: Callable, *exprs, _ctx=None): indices, aggregations = unify_all(*exprs, lambda_result) res_ir = lambda_result._ir - for (uid, value_ir) in builtins.zip(uids, irs): + for uid, value_ir in builtins.zip(uids, irs): if _ctx == 'agg': res_ir = ir.AggLet(uid, value_ir, res_ir, is_scan=False) elif _ctx == 'scan': @@ -648,8 +734,7 @@ def rbind(*exprs, _ctx=None): """ *args, f = exprs - args = [expr_any.check(arg, 'rbind', f'argument {index}') - for index, arg in builtins.enumerate(args)] + args = [expr_any.check(arg, 'rbind', f'argument {index}') for index, arg in builtins.enumerate(args)] return hl.bind(f, *args, _ctx=_ctx) @@ -738,9 +823,89 @@ def contingency_table_test(c1, c2, c3, c4, min_cell_count) -> StructExpression: return _func("contingency_table_test", ret_type, c1, c2, c3, c4, min_cell_count) -@typecheck(collection=expr_oneof(expr_dict(), - expr_set(expr_tuple([expr_any, expr_any])), - expr_array(expr_tuple([expr_any, expr_any])))) +# We use 64-bit integers. +# It is relatively easy to encounter an integer overflow bug with 32-bit integers. +@typecheck(a=expr_array(expr_int64), b=expr_array(expr_int64), c=expr_array(expr_int64), d=expr_array(expr_int64)) +def cochran_mantel_haenszel_test( + a: Union[tarray, list], b: Union[tarray, list], c: Union[tarray, list], d: Union[tarray, list] +) -> StructExpression: + """Perform the Cochran-Mantel-Haenszel test for association. + + Examples + -------- + >>> a = [56, 61, 73, 71] + >>> b = [69, 257, 65, 48] + >>> c = [40, 57, 71, 55] + >>> d = [77, 301, 79, 48] + >>> hl.eval(hl.cochran_mantel_haenszel_test(a, b, c, d)) + Struct(test_statistic=5.0496881823306765, p_value=0.024630370456863417) + + >>> mt = ds.filter_rows(mt.locus == hl.Locus(20, 10633237)) + >>> mt.count_rows() + 1 + >>> a, b, c, d = mt.aggregate_entries( + ... hl.tuple([ + ... hl.array([hl.agg.count_where(mt.GT.is_non_ref() & mt.pheno.is_case & mt.pheno.is_female), hl.agg.count_where(mt.GT.is_non_ref() & mt.pheno.is_case & ~mt.pheno.is_female)]), + ... 
hl.array([hl.agg.count_where(mt.GT.is_non_ref() & ~mt.pheno.is_case & mt.pheno.is_female), hl.agg.count_where(mt.GT.is_non_ref() & ~mt.pheno.is_case & ~mt.pheno.is_female)]), + ... hl.array([hl.agg.count_where(~mt.GT.is_non_ref() & mt.pheno.is_case & mt.pheno.is_female), hl.agg.count_where(~mt.GT.is_non_ref() & mt.pheno.is_case & ~mt.pheno.is_female)]), + ... hl.array([hl.agg.count_where(~mt.GT.is_non_ref() & ~mt.pheno.is_case & mt.pheno.is_female), hl.agg.count_where(~mt.GT.is_non_ref() & ~mt.pheno.is_case & ~mt.pheno.is_female)]) + ... ]) + ... ) + >>> hl.eval(hl.cochran_mantel_haenszel_test(a, b, c, d)) + Struct(test_statistic=0.2188830334629822, p_value=0.6398923118508772) + + Notes + ----- + See the `Wikipedia article `_ + for more details. + + Parameters + ---------- + a : :class:`.ArrayExpression` of type :py:data:`.tint64` + Values for the upper-left cell in the contingency tables. + b : :class:`.ArrayExpression` of type :py:data:`.tint64` + Values for the upper-right cell in the contingency tables. + c : :class:`.ArrayExpression` of type :py:data:`.tint64` + Values for the lower-left cell in the contingency tables. + d : :class:`.ArrayExpression` of type :py:data:`.tint64` + Values for the lower-right cell in the contingency tables. + + Returns + ------- + :class:`.StructExpression` + A :class:`.tstruct` expression with two fields, `test_statistic` + (:py:data:`.tfloat64`) and `p_value` (:py:data:`.tfloat64`). + """ + # The variable names below correspond to the notation used in the Wikipedia article. + # https://en.m.wikipedia.org/wiki/Cochran%E2%80%93Mantel%E2%80%93Haenszel_statistics + n1 = hl.zip(a, b).map(lambda ab: ab[0] + ab[1]) + n2 = hl.zip(c, d).map(lambda cd: cd[0] + cd[1]) + m1 = hl.zip(a, c).map(lambda ac: ac[0] + ac[1]) + m2 = hl.zip(b, d).map(lambda bd: bd[0] + bd[1]) + t = hl.zip(n1, n2).map(lambda nn: nn[0] + nn[1]) + + def numerator_term(a, n1, m1, t): + return a - n1 * m1 / t + + # The numerator comes from the link below, not from the Wikipedia article. + # https://www.biostathandbook.com/cmh.html + numerator = (hl.abs(hl.sum(hl.zip(a, n1, m1, t).map(lambda tup: numerator_term(*tup)))) - 0.5) ** 2 + + def denominator_term(n1, n2, m1, m2, t): + return n1 * n2 * m1 * m2 / (t**3 - t**2) + + denominator = hl.sum(hl.zip(n1, n2, m1, m2, t).map(lambda tup: denominator_term(*tup))) + + test_statistic = numerator / denominator + p_value = pchisqtail(test_statistic, 1) + return struct(test_statistic=test_statistic, p_value=p_value) + + +@typecheck( + collection=expr_oneof( + expr_dict(), expr_set(expr_tuple([expr_any, expr_any])), expr_array(expr_tuple([expr_any, expr_any])) + ) +) def dict(collection) -> DictExpression: """Creates a dictionary. @@ -764,7 +929,7 @@ def dict(collection) -> DictExpression: ------- :class:`.DictExpression` """ - if isinstance(collection.dtype, tarray) or isinstance(collection.dtype, tset): + if isinstance(collection.dtype, (tarray, tset)): key_type, value_type = collection.dtype.element_type.types return _func('dict', tdict(key_type, value_type), collection) else: @@ -972,10 +1137,7 @@ def fisher_exact_test(c1, c2, c3, c4) -> StructExpression: `ci_95_lower (:py:data:`.tfloat64`), and `ci_95_upper` (:py:data:`.tfloat64`). 
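The new `cochran_mantel_haenszel_test` above implements the continuity-corrected CMH statistic (numerator from biostathandbook.com, as noted in the code) with a 1-df chi-squared p-value. A plain-NumPy sketch of the same arithmetic, using the counts from the docstring example, reproduces the quoted result (SciPy appears here only for the p-value cross-check):

    import numpy as np
    from scipy.stats import chi2

    a = np.array([56, 61, 73, 71]); b = np.array([69, 257, 65, 48])
    c = np.array([40, 57, 71, 55]); d = np.array([77, 301, 79, 48])

    n1, n2, m1, m2 = a + b, c + d, a + c, b + d
    t = n1 + n2

    numerator = (np.abs(np.sum(a - n1 * m1 / t)) - 0.5) ** 2    # continuity correction
    denominator = np.sum(n1 * n2 * m1 * m2 / (t**3 - t**2))

    statistic = numerator / denominator    # ~5.0497, matching the docstring
    p_value = chi2.sf(statistic, df=1)     # ~0.0246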
""" - ret_type = tstruct(p_value=tfloat64, - odds_ratio=tfloat64, - ci_95_lower=tfloat64, - ci_95_upper=tfloat64) + ret_type = tstruct(p_value=tfloat64, odds_ratio=tfloat64, ci_95_lower=tfloat64, ci_95_upper=tfloat64) return _func("fisher_exact_test", ret_type, c1, c2, c3, c4) @@ -1072,13 +1234,11 @@ def hardy_weinberg_test(n_hom_ref, n_het, n_hom_var, one_sided=False) -> StructE A struct expression with two fields, `het_freq_hwe` (:py:data:`.tfloat64`) and `p_value` (:py:data:`.tfloat64`). """ - ret_type = tstruct(het_freq_hwe=tfloat64, - p_value=tfloat64) + ret_type = tstruct(het_freq_hwe=tfloat64, p_value=tfloat64) return _func("hardy_weinberg_test", ret_type, n_hom_ref, n_het, n_hom_var, one_sided) -@typecheck(contig=expr_str, pos=expr_int32, - reference_genome=reference_genome_type) +@typecheck(contig=expr_str, pos=expr_int32, reference_genome=reference_genome_type) def locus(contig, pos, reference_genome: Union[str, ReferenceGenome] = 'default') -> LocusExpression: """Construct a locus expression from a chromosome and position. @@ -1104,10 +1264,10 @@ def locus(contig, pos, reference_genome: Union[str, ReferenceGenome] = 'default' return _func('Locus', tlocus(reference_genome), contig, pos) -@typecheck(global_pos=expr_int64, - reference_genome=reference_genome_type) -def locus_from_global_position(global_pos, - reference_genome: Union[str, ReferenceGenome] = 'default') -> LocusExpression: +@typecheck(global_pos=expr_int64, reference_genome=reference_genome_type) +def locus_from_global_position( + global_pos, reference_genome: Union[str, ReferenceGenome] = 'default' +) -> LocusExpression: """Constructs a locus expression from a global position and a reference genome. The inverse of :meth:`.LocusExpression.global_position`. @@ -1136,8 +1296,7 @@ def locus_from_global_position(global_pos, return _func('globalPosToLocus', tlocus(reference_genome), global_pos) -@typecheck(s=expr_str, - reference_genome=reference_genome_type) +@typecheck(s=expr_str, reference_genome=reference_genome_type) def parse_locus(s, reference_genome: Union[str, ReferenceGenome] = 'default') -> LocusExpression: """Construct a locus expression by parsing a string or string expression. @@ -1166,8 +1325,7 @@ def parse_locus(s, reference_genome: Union[str, ReferenceGenome] = 'default') -> return _func('Locus', tlocus(reference_genome), s) -@typecheck(s=expr_str, - reference_genome=reference_genome_type) +@typecheck(s=expr_str, reference_genome=reference_genome_type) def parse_variant(s, reference_genome: Union[str, ReferenceGenome] = 'default') -> StructExpression: """Construct a struct with a locus and alleles by parsing a string. @@ -1197,8 +1355,7 @@ def parse_variant(s, reference_genome: Union[str, ReferenceGenome] = 'default') :class:`.StructExpression` Struct with fields `locus` and `alleles`. 
""" - t = tstruct(locus=tlocus(reference_genome), - alleles=tarray(tstr)) + t = tstruct(locus=tlocus(reference_genome), alleles=tarray(tstr)) return _func('LocusAlleles', t, s) @@ -1235,19 +1392,23 @@ def variant_str(*args) -> 'StringExpression': args = [to_expr(arg) for arg in args] def type_error(): - raise ValueError(f"'variant_str' expects arguments of the following types:\n" - f" Option 1: 1 argument of type 'struct{{locus: locus, alleles: array}}\n" - f" Option 2: 2 arguments of type 'locus', 'array'\n" - f" Found: {builtins.len(args)} {plural('argument', builtins.len(args))} " - f"of type {', '.join(builtins.str(x.dtype) for x in args)}") + raise ValueError( + f"'variant_str' expects arguments of the following types:\n" + f" Option 1: 1 argument of type 'struct{{locus: locus, alleles: array}}\n" + f" Option 2: 2 arguments of type 'locus', 'array'\n" + f" Found: {builtins.len(args)} {plural('argument', builtins.len(args))} " + f"of type {', '.join(builtins.str(x.dtype) for x in args)}" + ) if builtins.len(args) == 1: [s] = args t = s.dtype - if not isinstance(t, tstruct) \ - or not builtins.len(t) == 2 \ - or not isinstance(t[0], tlocus) \ - or not t[1] == tarray(tstr): + if ( + not isinstance(t, tstruct) + or not builtins.len(t) == 2 + or not isinstance(t[0], tlocus) + or not t[1] == tarray(tstr) + ): type_error() return hl.rbind(s, lambda x: hl.str(x[0]) + ":" + x[1][0] + ":" + hl.delimit(x[1][1:])) elif builtins.len(args) == 2: @@ -1330,37 +1491,33 @@ def pl_dosage(pl) -> Float64Expression: @typecheck(pl=expr_array(expr_int32), _cache_size=int) def pl_to_gp(pl, _cache_size=2048) -> ArrayNumericExpression: """ - Return the linear-scaled genotype probabilities from an array of Phred-scaled genotype likelihoods. + Return the linear-scaled genotype probabilities from an array of Phred-scaled genotype likelihoods. - Examples - -------- - >>> hl.eval(hl.pl_to_gp([0, 10, 100])) - [0.9090909090082644, 0.09090909090082644, 9.090909090082645e-11] + Examples + -------- + >>> hl.eval(hl.pl_to_gp([0, 10, 100])) + [0.9090909090082644, 0.09090909090082644, 9.090909090082645e-11] - Notes - ----- - This function assumes a uniform prior on the possible genotypes. + Notes + ----- + This function assumes a uniform prior on the possible genotypes. - Parameters - ---------- - pl : :class:`.ArrayNumericExpression` of type :py:data:`.tint32` - Array of Phred-scaled genotype likelihoods. + Parameters + ---------- + pl : :class:`.ArrayNumericExpression` of type :py:data:`.tint32` + Array of Phred-scaled genotype likelihoods. - Returns - ------- - :class:`.ArrayNumericExpression` of type :py:data:`.tfloat64` + Returns + ------- + :class:`.ArrayNumericExpression` of type :py:data:`.tfloat64` """ phred_table = hl.literal([10 ** (-x / 10.0) for x in builtins.range(_cache_size)]) gp = hl.bind(lambda pls: pls.map(lambda x: hl.if_else(x >= _cache_size, 10 ** (-x / 10.0), phred_table[x])), pl) return hl.bind(lambda gp: gp / hl.sum(gp), gp) -@typecheck(start=expr_any, end=expr_any, - includes_start=expr_bool, includes_end=expr_bool) -def interval(start, - end, - includes_start=True, - includes_end=False) -> IntervalExpression: +@typecheck(start=expr_any, end=expr_any, includes_start=expr_bool, includes_end=expr_bool) +def interval(start, end, includes_start=True, includes_end=False) -> IntervalExpression: """Construct an interval expression. 
Examples @@ -1400,17 +1557,24 @@ def interval(start, return _func('Interval', tinterval(start.dtype), start, end, includes_start, includes_end) -@typecheck(contig=expr_str, start=expr_int32, - end=expr_int32, includes_start=expr_bool, - includes_end=expr_bool, reference_genome=reference_genome_type, - invalid_missing=expr_bool) -def locus_interval(contig, - start, - end, - includes_start=True, - includes_end=False, - reference_genome: Union[str, ReferenceGenome] = 'default', - invalid_missing=False) -> IntervalExpression: +@typecheck( + contig=expr_str, + start=expr_int32, + end=expr_int32, + includes_start=expr_bool, + includes_end=expr_bool, + reference_genome=reference_genome_type, + invalid_missing=expr_bool, +) +def locus_interval( + contig, + start, + end, + includes_start=True, + includes_end=False, + reference_genome: Union[str, ReferenceGenome] = 'default', + invalid_missing=False, +) -> IntervalExpression: """Construct a locus interval expression. Examples @@ -1443,13 +1607,22 @@ def locus_interval(contig, ------- :class:`.IntervalExpression` """ - return _func('LocusInterval', tinterval(tlocus(reference_genome)), contig, start, end, includes_start, includes_end, invalid_missing) + return _func( + 'LocusInterval', + tinterval(tlocus(reference_genome)), + contig, + start, + end, + includes_start, + includes_end, + invalid_missing, + ) -@typecheck(s=expr_str, - reference_genome=reference_genome_type, - invalid_missing=expr_bool) -def parse_locus_interval(s, reference_genome: Union[str, ReferenceGenome] = 'default', invalid_missing=False) -> IntervalExpression: +@typecheck(s=expr_str, reference_genome=reference_genome_type, invalid_missing=expr_bool) +def parse_locus_interval( + s, reference_genome: Union[str, ReferenceGenome] = 'default', invalid_missing=False +) -> IntervalExpression: """Construct a locus interval expression by parsing a string or string expression. @@ -1519,12 +1692,10 @@ def parse_locus_interval(s, reference_genome: Union[str, ReferenceGenome] = 'def ------- :class:`.IntervalExpression` """ - return _func('LocusInterval', - tinterval(tlocus(reference_genome)), s, invalid_missing) + return _func('LocusInterval', tinterval(tlocus(reference_genome)), s, invalid_missing) -@typecheck(alleles=expr_int32, - phased=expr_bool) +@typecheck(alleles=expr_int32, phased=expr_bool) def call(*alleles, phased=False) -> CallExpression: """Construct a call expression. 
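The `call` constructor whose decorator is collapsed above builds a genotype call from allele indices; a quick illustration:

    import hail as hl

    c = hl.call(0, 1, phased=False)
    hl.eval(c)            # Call(alleles=[0, 1], phased=False)
    hl.eval(c.is_het())   # True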
@@ -1856,6 +2027,7 @@ def log(x, base=None) -> Float64Expression: ------- :class:`.Expression` of type :py:data:`.tfloat64` """ + def scalar_log(x): if base is not None: return _func("log", tfloat64, x, to_expr(base)) @@ -1978,8 +2150,7 @@ def coalesce(*args): *exprs, success = unify_exprs(*args) if not success: arg_types = ''.join([f"\n argument {i}: type '{arg.dtype}'" for i, arg in builtins.enumerate(exprs)]) - raise TypeError(f"'coalesce' requires all arguments to have the same type or compatible types" - f"{arg_types}") + raise TypeError(f"'coalesce' requires all arguments to have the same type or compatible types" f"{arg_types}") indices, aggregations = unify_all(*exprs) return construct_expr(ir.Coalesce(*(e._ir for e in exprs)), exprs[0].dtype, indices, aggregations) @@ -2012,9 +2183,11 @@ def or_else(a, b): """ a, b, success = unify_exprs(a, b) if not success: - raise TypeError(f"'or_else' requires the 'a' and 'b' arguments to have the same type\n" - f" a: type '{a.dtype}'\n" - f" b: type '{b.dtype}'") + raise TypeError( + f"'or_else' requires the 'a' and 'b' arguments to have the same type\n" + f" a: type '{a.dtype}'\n" + f" b: type '{b.dtype}'" + ) return coalesce(a, b) @@ -2046,8 +2219,9 @@ def or_missing(predicate, value): return hl.if_else(predicate, value, hl.missing(value.dtype)) -@typecheck(x=expr_int32, n=expr_int32, p=expr_float64, - alternative=enumeration("two.sided", "two-sided", "greater", "less")) +@typecheck( + x=expr_int32, n=expr_int32, p=expr_float64, alternative=enumeration("two.sided", "two-sided", "greater", "less") +) def binom_test(x, n, p, alternative: str) -> Float64Expression: """Performs a binomial test on `p` given `x` successes in `n` trials. @@ -2102,9 +2276,11 @@ def binom_test(x, n, p, alternative: str) -> Float64Expression: """ if alternative == 'two.sided': - warning('"two.sided" is a deprecated and will be removed in a future ' - 'release, please use "two-sided" for the `alternative` parameter ' - 'to hl.binom_test') + warning( + '"two.sided" is a deprecated and will be removed in a future ' + 'release, please use "two-sided" for the `alternative` parameter ' + 'to hl.binom_test' + ) alternative = 'two-sided' alt_enum = {"two-sided": 0, "less": 1, "greater": 2}[alternative] @@ -2158,14 +2334,16 @@ def pchisqtail(x, df, ncp=None, lower_tail=False, log_p=False) -> Float64Express PGENCHISQ_RETURN_TYPE = tstruct(value=tfloat64, n_iterations=tint32, converged=tbool, fault=tint32) -@typecheck(x=expr_float64, - w=expr_array(expr_float64), - k=expr_array(expr_int32), - lam=expr_array(expr_float64), - mu=expr_float64, - sigma=expr_float64, - max_iterations=nullable(expr_int32), - min_accuracy=nullable(expr_float64)) +@typecheck( + x=expr_float64, + w=expr_array(expr_float64), + k=expr_array(expr_int32), + lam=expr_array(expr_float64), + mu=expr_float64, + sigma=expr_float64, + max_iterations=nullable(expr_int32), + min_accuracy=nullable(expr_float64), +) def pgenchisq(x, w, k, lam, mu, sigma, *, max_iterations=None, min_accuracy=None) -> Float64Expression: r"""The cumulative probability function of a `generalized chi-squared distribution `__. 
@@ -2662,7 +2840,9 @@ def range(start, stop=None, step=1) -> ArrayNumericExpression: if stop is None: stop = start start = hl.literal(0) - return apply_expr(lambda sta, sto, ste: ir.toArray(ir.StreamRange(sta, sto, ste)), tarray(tint32), start, stop, step) + return apply_expr( + lambda sta, sto, ste: ir.toArray(ir.StreamRange(sta, sto, ste)), tarray(tint32), start, stop, step + ) @typecheck(start=expr_int32, stop=nullable(expr_int32), step=expr_int32) @@ -2810,12 +2990,13 @@ def f(mean, cov): s22 = cov[2] x = hl.range(0, 2).map(lambda i: rand_norm(seed=seed)) - return hl.rbind(hl.sqrt(s11), - lambda root_s11: - hl.array([ - m1 + root_s11 * x[0], - m2 + (s12 / root_s11) * x[0] - + hl.sqrt(s22 - s12 * s12 / s11) * x[1]])) + return hl.rbind( + hl.sqrt(s11), + lambda root_s11: hl.array([ + m1 + root_s11 * x[0], + m2 + (s12 / root_s11) * x[0] + hl.sqrt(s22 - s12 * s12 / s11) * x[1], + ]), + ) return hl.rbind(mean, cov, f) @@ -2970,11 +3151,9 @@ def rand_int64(a=None, b=None, *, seed=None) -> Int64Expression: return _seeded_func("rand_int64", tint64, seed, b - a) + a -@typecheck(a=expr_float64, - b=expr_float64, - lower=nullable(expr_float64), - upper=nullable(expr_float64), - seed=nullable(int)) +@typecheck( + a=expr_float64, b=expr_float64, lower=nullable(expr_float64), upper=nullable(expr_float64), seed=nullable(int) +) def rand_beta(a, b, lower=None, upper=None, seed=None) -> Float64Expression: """Samples from a `beta distribution `__ with parameters `a` @@ -3023,9 +3202,7 @@ def rand_beta(a, b, lower=None, upper=None, seed=None) -> Float64Expression: return _seeded_func("rand_beta", tfloat64, seed, a, b, lower, upper) -@typecheck(shape=expr_float64, - scale=expr_float64, - seed=nullable(int)) +@typecheck(shape=expr_float64, scale=expr_float64, seed=nullable(int)) def rand_gamma(shape, scale, seed=None) -> Float64Expression: """Samples from a `gamma distribution `__ @@ -3055,8 +3232,7 @@ def rand_gamma(shape, scale, seed=None) -> Float64Expression: return _seeded_func("rand_gamma", tfloat64, seed, shape, scale) -@typecheck(prob=expr_array(expr_float64), - seed=nullable(int)) +@typecheck(prob=expr_array(expr_float64), seed=nullable(int)) def rand_cat(prob, seed=None) -> Int32Expression: """Samples from a `categorical distribution `__. @@ -3094,8 +3270,7 @@ def rand_cat(prob, seed=None) -> Int32Expression: return _seeded_func("rand_cat", tint32, seed, prob) -@typecheck(a=expr_array(expr_float64), - seed=nullable(int)) +@typecheck(a=expr_array(expr_float64), seed=nullable(int)) def rand_dirichlet(a, seed=None) -> ArrayExpression: """Samples from a `Dirichlet distribution `__. 
@@ -3121,11 +3296,7 @@ def rand_dirichlet(a, seed=None) -> ArrayExpression: ------- :class:`.Float64Expression` """ - return hl.bind(lambda x: x / hl.sum(x), - a.map(lambda p: - hl.if_else(p == 0.0, - 0.0, - hl.rand_gamma(p, 1, seed=seed)))) + return hl.bind(lambda x: x / hl.sum(x), a.map(lambda p: hl.if_else(p == 0.0, 0.0, hl.rand_gamma(p, 1, seed=seed)))) @typecheck(x=oneof(expr_float64, expr_ndarray(expr_float64))) @@ -3185,36 +3356,62 @@ def corr(x, y) -> Float64Expression: return _func("corr", tfloat64, x, y) -_base_regex = "^([ACGTNM])+$" -_symbolic_regex = r"(^\.)|(\.$)|(^<)|(>$)|(\[)|(\])" -_allele_types = ["Unknown", "SNP", "MNP", "Insertion", "Deletion", "Complex", "Star", "Symbolic"] -_allele_enum = {i: v for i, v in builtins.enumerate(_allele_types)} -_allele_ints = {v: k for k, v in _allele_enum.items()} +@typecheck(ref=expr_str, alt=expr_str) +@ir.udf(tstr, tstr) +def numeric_allele_type(ref, alt) -> Int32Expression: + """Returns the type of the polymorphism as an integer. The value returned + is the integer value of :class:`.AlleleType` representing that kind of + polymorphism. + + Examples + -------- + + >>> hl.eval(hl.numeric_allele_type('A', 'T')) == AlleleType.SNP + True + Notes + ----- + The values of :class:`.AlleleType` are not stable and thus should not be + relied upon across hail versions. + """ + _base_regex = "^([ACGTNM])+$" + _symbolic_regex = r"(^\.)|(\.$)|(^<)|(>$)|(\[)|(\])" + return hl.bind( + lambda r, a: hl.if_else( + r.matches(_base_regex), + hl.case() + .when( + a.matches(_base_regex), + hl.case() + .when( + r.length() == a.length(), + hl.if_else( + r.length() == 1, + hl.if_else(r != a, AlleleType.SNP, AlleleType.UNKNOWN), + hl.if_else(hamming(r, a) == 1, AlleleType.SNP, AlleleType.MNP), + ), + ) + .when((r.length() < a.length()) & (r[0] == a[0]) & a.endswith(r[1:]), AlleleType.INSERTION) + .when((r[0] == a[0]) & r.endswith(a[1:]), AlleleType.DELETION) + .default(AlleleType.COMPLEX), + ) + .when(a == '*', AlleleType.STAR) + .when(a.matches(_symbolic_regex), AlleleType.SYMBOLIC) + .default(AlleleType.UNKNOWN), + AlleleType.UNKNOWN, + ), + ref, + alt, + ) + +@deprecated(version='0.2.129', reason="Replaced by the public numeric_allele_type") @typecheck(ref=expr_str, alt=expr_str) -@ir.udf(tstr, tstr) def _num_allele_type(ref, alt) -> Int32Expression: - return hl.bind(lambda r, a: - hl.if_else(r.matches(_base_regex), - hl.case() - .when(a.matches(_base_regex), hl.case() - .when(r.length() == a.length(), - hl.if_else(r.length() == 1, - hl.if_else(r != a, _allele_ints['SNP'], _allele_ints['Unknown']), - hl.if_else(hamming(r, a) == 1, - _allele_ints['SNP'], - _allele_ints['MNP']))) - .when((r.length() < a.length()) & (r[0] == a[0]) & a.endswith(r[1:]), - _allele_ints["Insertion"]) - .when((r[0] == a[0]) & r.endswith(a[1:]), - _allele_ints["Deletion"]) - .default(_allele_ints['Complex'])) - .when(a == '*', _allele_ints['Star']) - .when(a.matches(_symbolic_regex), _allele_ints['Symbolic']) - .default(_allele_ints['Unknown']), - _allele_ints['Unknown']), - ref, alt) + """Provided for backwards compatibility, don't use it in new code, or + within the hail library itself + """ + return numeric_allele_type(ref, alt) @typecheck(ref=expr_str, alt=expr_str) @@ -3238,7 +3435,7 @@ def is_snp(ref, alt) -> BooleanExpression: ------- :class:`.BooleanExpression` """ - return _num_allele_type(ref, alt) == _allele_ints["SNP"] + return numeric_allele_type(ref, alt) == AlleleType.SNP @typecheck(ref=expr_str, alt=expr_str) @@ -3262,7 +3459,7 @@ def is_mnp(ref, alt) -> 
BooleanExpression: ------- :class:`.BooleanExpression` """ - return _num_allele_type(ref, alt) == _allele_ints["MNP"] + return numeric_allele_type(ref, alt) == AlleleType.MNP @typecheck(ref=expr_str, alt=expr_str) @@ -3323,10 +3520,18 @@ def is_transversion(ref, alt) -> BooleanExpression: @ir.udf(tstr, tstr) def _is_snp_transition(ref, alt) -> BooleanExpression: indices = hl.range(0, ref.length()) - return hl.any(lambda i: ((ref[i] != alt[i]) & (((ref[i] == 'A') & (alt[i] == 'G')) - | ((ref[i] == 'G') & (alt[i] == 'A')) - | ((ref[i] == 'C') & (alt[i] == 'T')) - | ((ref[i] == 'T') & (alt[i] == 'C')))), indices) + return hl.any( + lambda i: ( + (ref[i] != alt[i]) + & ( + ((ref[i] == 'A') & (alt[i] == 'G')) + | ((ref[i] == 'G') & (alt[i] == 'A')) + | ((ref[i] == 'C') & (alt[i] == 'T')) + | ((ref[i] == 'T') & (alt[i] == 'C')) + ) + ), + indices, + ) @typecheck(ref=expr_str, alt=expr_str) @@ -3350,7 +3555,7 @@ def is_insertion(ref, alt) -> BooleanExpression: ------- :class:`.BooleanExpression` """ - return _num_allele_type(ref, alt) == _allele_ints["Insertion"] + return numeric_allele_type(ref, alt) == AlleleType.INSERTION @typecheck(ref=expr_str, alt=expr_str) @@ -3374,7 +3579,7 @@ def is_deletion(ref, alt) -> BooleanExpression: ------- :class:`.BooleanExpression` """ - return _num_allele_type(ref, alt) == _allele_ints["Deletion"] + return numeric_allele_type(ref, alt) == AlleleType.DELETION @typecheck(ref=expr_str, alt=expr_str) @@ -3398,9 +3603,7 @@ def is_indel(ref, alt) -> BooleanExpression: ------- :class:`.BooleanExpression` """ - return hl.bind(lambda t: (t == _allele_ints["Insertion"]) - | (t == _allele_ints["Deletion"]), - _num_allele_type(ref, alt)) + return hl.bind(lambda t: (t == AlleleType.INSERTION) | (t == AlleleType.DELETION), numeric_allele_type(ref, alt)) @typecheck(ref=expr_str, alt=expr_str) @@ -3424,7 +3627,7 @@ def is_star(ref, alt) -> BooleanExpression: ------- :class:`.BooleanExpression` """ - return _num_allele_type(ref, alt) == _allele_ints["Star"] + return numeric_allele_type(ref, alt) == AlleleType.STAR @typecheck(ref=expr_str, alt=expr_str) @@ -3448,7 +3651,7 @@ def is_complex(ref, alt) -> BooleanExpression: ------- :class:`.BooleanExpression` """ - return _num_allele_type(ref, alt) == _allele_ints["Complex"] + return numeric_allele_type(ref, alt) == AlleleType.COMPLEX @typecheck(ref=expr_str, alt=expr_str) @@ -3516,7 +3719,7 @@ def allele_type(ref, alt) -> StringExpression: ------- :class:`.StringExpression` """ - return hl.literal(_allele_types)[_num_allele_type(ref, alt)] + return hl.literal(AlleleType.strings())[numeric_allele_type(ref, alt)] @typecheck(s1=expr_str, s2=expr_str) @@ -3696,8 +3899,7 @@ def triangle(n) -> Int32Expression: return _func("triangle", tint32, n) -@typecheck(f=func_spec(1, expr_bool), - collection=expr_oneof(expr_set(), expr_array())) +@typecheck(f=func_spec(1, expr_bool), collection=expr_oneof(expr_set(), expr_array())) def filter(f: Callable, collection): """Returns a new collection containing elements where `f` returns ``True``. 
@@ -3795,8 +3997,7 @@ def any(*args) -> BooleanExpression: if builtins.len(args) == 0: return base if builtins.len(args) == 1: - arg = arg_check(args[0], 'any', 'collection', - oneof(collection_type, expr_bool)) + arg = arg_check(args[0], 'any', 'collection', oneof(collection_type, expr_bool)) if arg.dtype == hl.tbool: return arg return arg.any(lambda x: x) @@ -3806,8 +4007,7 @@ def any(*args) -> BooleanExpression: collection = arg_check(args[1], 'any', 'collection', collection_type) return collection.any(f) n_args = builtins.len(args) - args = [args_check(x, 'any', 'exprs', i, n_args, expr_bool) - for i, x in builtins.enumerate(args)] + args = [args_check(x, 'any', 'exprs', i, n_args, expr_bool) for i, x in builtins.enumerate(args)] return functools.reduce(operator.ior, args, base) @@ -3870,8 +4070,7 @@ def all(*args) -> BooleanExpression: if builtins.len(args) == 0: return base if builtins.len(args) == 1: - arg = arg_check(args[0], 'any', 'collection', - oneof(collection_type, expr_bool)) + arg = arg_check(args[0], 'any', 'collection', oneof(collection_type, expr_bool)) if arg.dtype == hl.tbool: return arg return arg.all(lambda x: x) @@ -3881,13 +4080,11 @@ def all(*args) -> BooleanExpression: collection = arg_check(args[1], 'all', 'collection', collection_type) return collection.all(f) n_args = builtins.len(args) - args = [args_check(x, 'all', 'exprs', i, n_args, expr_bool) - for i, x in builtins.enumerate(args)] + args = [args_check(x, 'all', 'exprs', i, n_args, expr_bool) for i, x in builtins.enumerate(args)] return functools.reduce(operator.iand, args, base) -@typecheck(f=func_spec(1, expr_bool), - collection=expr_oneof(expr_set(), expr_array())) +@typecheck(f=func_spec(1, expr_bool), collection=expr_oneof(expr_set(), expr_array())) def find(f: Callable, collection): """Returns the first element where `f` returns ``True``. @@ -3926,8 +4123,7 @@ def find(f: Callable, collection): return collection.find(f) -@typecheck(f=func_spec(1, expr_any), - collection=expr_oneof(expr_set(), expr_array())) +@typecheck(f=func_spec(1, expr_any), collection=expr_oneof(expr_set(), expr_array())) def flatmap(f: Callable, collection): """Map each element of the collection to a new collection, and flatten the results. @@ -3962,8 +4158,7 @@ def unify_ret(t): return collection.flatmap(f) -@typecheck(f=func_spec(1, expr_any), - collection=expr_oneof(expr_set(), expr_array())) +@typecheck(f=func_spec(1, expr_any), collection=expr_oneof(expr_set(), expr_array())) def group_by(f: Callable, collection) -> DictExpression: """Group collection elements into a dict according to a lambda function. @@ -3990,9 +4185,7 @@ def group_by(f: Callable, collection) -> DictExpression: return collection.group_by(f) -@typecheck(f=func_spec(2, expr_any), - zero=expr_any, - collection=expr_oneof(expr_set(), expr_array())) +@typecheck(f=func_spec(2, expr_any), zero=expr_any, collection=expr_oneof(expr_set(), expr_array())) def fold(f: Callable, zero, collection) -> Expression: """Reduces a collection with the given function `f`, provided the initial value `zero`. @@ -4019,9 +4212,7 @@ def fold(f: Callable, zero, collection) -> Expression: return collection.fold(lambda x, y: f(x, y), zero) -@typecheck(f=func_spec(2, expr_any), - zero=expr_any, - a=expr_array()) +@typecheck(f=func_spec(2, expr_any), zero=expr_any, a=expr_array()) def array_scan(f: Callable, zero, a) -> ArrayExpression: """Map each element of `a` to cumulative value of function `f`, with initial value `zero`. 
@@ -4056,10 +4247,12 @@ def _zip_streams(*streams, fill_missing: bool = False) -> StreamExpression: body_ir = ir.MakeTuple([ir.Ref(uid, type) for uid, type in builtins.zip(uids, types)]) indices, aggregations = unify_all(*streams) behavior = 'ExtendNA' if fill_missing else 'TakeMinLength' - return construct_expr(ir.StreamZip([s._ir for s in streams], uids, body_ir, behavior), - tstream(ttuple(*(s.dtype.element_type for s in streams))), - indices, - aggregations) + return construct_expr( + ir.StreamZip([s._ir for s in streams], uids, body_ir, behavior), + tstream(ttuple(*(s.dtype.element_type for s in streams))), + indices, + aggregations, + ) @typecheck(arrays=expr_array(), fill_missing=bool) @@ -4105,8 +4298,10 @@ def zip(*arrays, fill_missing: bool = False) -> ArrayExpression: def _zip_func(*arrays, fill_missing=False, f): n_arrays = builtins.len(arrays) uids = [Env.get_uid() for _ in builtins.range(n_arrays)] - refs = [construct_expr(ir.Ref(uid, a.dtype.element_type), a.dtype.element_type, a._indices, a._aggregations) for uid, a in - builtins.zip(uids, arrays)] + refs = [ + construct_expr(ir.Ref(uid, a.dtype.element_type), a.dtype.element_type, a._indices, a._aggregations) + for uid, a in builtins.zip(uids, arrays) + ] body_result = f(*refs) indices, aggregations = unify_all(*arrays, body_result) behavior = 'ExtendNA' if fill_missing else 'TakeMinLength' @@ -4114,7 +4309,8 @@ def _zip_func(*arrays, fill_missing=False, f): ir.toArray(ir.StreamZip([ir.toStream(a._ir) for a in arrays], uids, body_result._ir, behavior)), tarray(body_result.dtype), indices, - aggregations) + aggregations, + ) @typecheck(a=expr_array(), start=expr_int32, index_first=bool) @@ -4183,8 +4379,7 @@ def zip_with_index(a, index_first=True): return enumerate(a, index_first=index_first) -@typecheck(f=anyfunc, - collections=expr_oneof(expr_set(), expr_array(), expr_ndarray())) +@typecheck(f=anyfunc, collections=expr_oneof(expr_set(), expr_array(), expr_ndarray())) def map(f: Callable, *collections): r"""Transform each element of a collection. @@ -4220,10 +4415,7 @@ def map(f: Callable, *collections): @typecheck(expr=oneof(expr_any, func_spec(0, expr_any)), n=expr_int32) -def repeat( - expr: 'Union[hl.Expression, Callable[[], hl.Expression]]', - n: 'hl.tint32' -) -> 'hl.ArrayExpression': +def repeat(expr: 'Union[hl.Expression, Callable[[], hl.Expression]]', n: 'hl.tint32') -> 'hl.ArrayExpression': """Return array of `n` elements initialized by `expr`. Examples @@ -4251,13 +4443,10 @@ def repeat( Array where each element has been initialized by `expr` """ mkarray = lambda x: hl.range(n).map(lambda _: x) - return hl.rbind(expr, mkarray) \ - if isinstance(expr, hl.Expression) \ - else mkarray(expr()) + return hl.rbind(expr, mkarray) if isinstance(expr, hl.Expression) else mkarray(expr()) -@typecheck(f=anyfunc, - collection=expr_oneof(expr_set(), expr_array(), expr_ndarray())) +@typecheck(f=anyfunc, collection=expr_oneof(expr_set(), expr_array(), expr_ndarray())) def starmap(f: Callable, collection): r"""Transform each element of a collection of tuples. 
@@ -4312,7 +4501,7 @@ def len(x) -> Int32Expression: ------- :class:`.Expression` of type :py:data:`.tint32` """ - if isinstance(x.dtype, ttuple) or isinstance(x.dtype, tstruct): + if isinstance(x.dtype, (ttuple, tstruct)): return hl.int32(builtins.len(x)) elif x.dtype == tstr: return apply_expr(lambda x: ir.Apply("length", tint32, x), tint32, x) @@ -4347,16 +4536,15 @@ def reversed(x): return x -@typecheck(name=builtins.str, - exprs=tupleof(Expression), - filter_missing=builtins.bool, - filter_nan=builtins.bool) +@typecheck(name=builtins.str, exprs=tupleof(Expression), filter_missing=builtins.bool, filter_nan=builtins.bool) def _comparison_func(name, exprs, filter_missing, filter_nan): if builtins.len(exprs) < 1: raise ValueError(f"{name:!r} requires at least one argument") - if (builtins.len(exprs) == 1 - and (isinstance(exprs[0].dtype, (tarray, tset))) - and is_numeric(exprs[0].dtype.element_type)): + if ( + builtins.len(exprs) == 1 + and (isinstance(exprs[0].dtype, (tarray, tset))) + and is_numeric(exprs[0].dtype.element_type) + ): [e] = exprs if filter_nan and e.dtype.element_type in (tfloat32, tfloat64): name = 'nan' + name @@ -4364,8 +4552,10 @@ def _comparison_func(name, exprs, filter_missing, filter_nan): else: if not builtins.all(is_numeric(e.dtype) for e in exprs): expr_types = ', '.join("'{}'".format(e.dtype) for e in exprs) - raise TypeError(f"{name!r} expects a single numeric array expression or multiple numeric expressions\n" - f" Found {builtins.len(exprs)} arguments with types {expr_types}") + raise TypeError( + f"{name!r} expects a single numeric array expression or multiple numeric expressions\n" + f" Found {builtins.len(exprs)} arguments with types {expr_types}" + ) unified_typ = unify_types_limited(*(e.dtype for e in exprs)) ec = coercer_from_dtype(unified_typ) indices, aggs = unify_all(*exprs) @@ -4375,14 +4565,17 @@ def _comparison_func(name, exprs, filter_missing, filter_nan): func_name += '_ignore_missing' if filter_nan and unified_typ in (tfloat32, tfloat64): func_name = 'nan' + func_name - return construct_expr(functools.reduce(lambda l, r: ir.Apply(func_name, unified_typ, l, r), [ec.coerce(e)._ir for e in exprs]), - unified_typ, - indices, - aggs) + return construct_expr( + functools.reduce(lambda l, r: ir.Apply(func_name, unified_typ, l, r), [ec.coerce(e)._ir for e in exprs]), + unified_typ, + indices, + aggs, + ) -@typecheck(exprs=expr_oneof(expr_numeric, expr_set(expr_numeric), expr_array(expr_numeric)), - filter_missing=builtins.bool) +@typecheck( + exprs=expr_oneof(expr_numeric, expr_set(expr_numeric), expr_array(expr_numeric)), filter_missing=builtins.bool +) def nanmax(*exprs, filter_missing: builtins.bool = True) -> NumericExpression: """Returns the maximum value of a collection or of given arguments, excluding NaN. @@ -4433,8 +4626,9 @@ def nanmax(*exprs, filter_missing: builtins.bool = True) -> NumericExpression: return _comparison_func('max', exprs, filter_missing, filter_nan=True) -@typecheck(exprs=expr_oneof(expr_numeric, expr_set(expr_numeric), expr_array(expr_numeric)), - filter_missing=builtins.bool) +@typecheck( + exprs=expr_oneof(expr_numeric, expr_set(expr_numeric), expr_array(expr_numeric)), filter_missing=builtins.bool +) def max(*exprs, filter_missing: builtins.bool = True) -> NumericExpression: """Returns the maximum element of a collection or of given numeric expressions. 
@@ -4483,8 +4677,9 @@ def max(*exprs, filter_missing: builtins.bool = True) -> NumericExpression: return _comparison_func('max', exprs, filter_missing, filter_nan=False) -@typecheck(exprs=expr_oneof(expr_numeric, expr_set(expr_numeric), expr_array(expr_numeric)), - filter_missing=builtins.bool) +@typecheck( + exprs=expr_oneof(expr_numeric, expr_set(expr_numeric), expr_array(expr_numeric)), filter_missing=builtins.bool +) def nanmin(*exprs, filter_missing: builtins.bool = True) -> NumericExpression: """Returns the minimum value of a collection or of given arguments, excluding NaN. @@ -4535,8 +4730,9 @@ def nanmin(*exprs, filter_missing: builtins.bool = True) -> NumericExpression: return _comparison_func('min', exprs, filter_missing, filter_nan=True) -@typecheck(exprs=expr_oneof(expr_numeric, expr_set(expr_numeric), expr_array(expr_numeric)), - filter_missing=builtins.bool) +@typecheck( + exprs=expr_oneof(expr_numeric, expr_set(expr_numeric), expr_array(expr_numeric)), filter_missing=builtins.bool +) def min(*exprs, filter_missing: builtins.bool = True) -> NumericExpression: """Returns the minimum element of a collection or of given numeric expressions. @@ -4606,7 +4802,7 @@ def abs(x): ------- :class:`.NumericExpression`, :class:`.ArrayNumericExpression` or :class:`.NDArrayNumericExpression`. """ - if isinstance(x.dtype, tarray) or isinstance(x.dtype, tndarray): + if isinstance(x.dtype, (tarray, tndarray)): return map(abs, x) else: return x._method('abs', x.dtype) @@ -4643,14 +4839,13 @@ def sign(x): ------- :class:`.NumericExpression`, :class:`.ArrayNumericExpression` or :class:`.NDArrayNumericExpression`. """ - if isinstance(x.dtype, tarray) or isinstance(x.dtype, tndarray): + if isinstance(x.dtype, (tarray, tndarray)): return map(sign, x) else: return x._method('sign', x.dtype) -@typecheck(collection=expr_oneof(expr_set(expr_numeric), expr_array(expr_numeric)), - filter_missing=bool) +@typecheck(collection=expr_oneof(expr_set(expr_numeric), expr_array(expr_numeric)), filter_missing=bool) def mean(collection, filter_missing: bool = True) -> Float64Expression: """Returns the mean of all values in the collection. @@ -4709,8 +4904,7 @@ def median(collection) -> NumericExpression: return collection._method("median", collection.dtype.element_type) -@typecheck(collection=expr_oneof(expr_set(expr_numeric), expr_array(expr_numeric)), - filter_missing=bool) +@typecheck(collection=expr_oneof(expr_set(expr_numeric), expr_array(expr_numeric)), filter_missing=bool) def product(collection, filter_missing: bool = True) -> NumericExpression: """Returns the product of values in the collection. @@ -4741,8 +4935,7 @@ def product(collection, filter_missing: bool = True) -> NumericExpression: return array(collection)._filter_missing_method(filter_missing, "product", collection.dtype.element_type) -@typecheck(collection=expr_oneof(expr_set(expr_numeric), expr_array(expr_numeric)), - filter_missing=bool) +@typecheck(collection=expr_oneof(expr_set(expr_numeric), expr_array(expr_numeric)), filter_missing=bool) def sum(collection, filter_missing: bool = True) -> NumericExpression: """Returns the sum of values in the collection. 
@@ -4772,8 +4965,7 @@ def sum(collection, filter_missing: bool = True) -> NumericExpression: return array(collection)._filter_missing_method(filter_missing, "sum", collection.dtype.element_type) -@typecheck(a=expr_array(expr_numeric), - filter_missing=bool) +@typecheck(a=expr_array(expr_numeric), filter_missing=bool) def cumulative_sum(a, filter_missing: bool = True) -> ArrayNumericExpression: """Returns an array of the cumulative sum of values in the array. @@ -4963,8 +5155,9 @@ def _ndarray(collection, row_major=None, dtype=None): ------- :class:`.NDArrayExpression` """ + def list_shape(x): - if isinstance(x, list) or isinstance(x, builtins.tuple): + if isinstance(x, (list, builtins.tuple)): dim_len = builtins.len(x) if dim_len != 0: first, rest = x[0], x[1:] @@ -4973,7 +5166,7 @@ def list_shape(x): other_inner_shape = list_shape(e) if inner_shape != other_inner_shape: raise ValueError(f'inner dimensions do not match: {inner_shape}, {other_inner_shape}') - return [dim_len] + inner_shape + return [dim_len, *inner_shape] else: return [dim_len] else: @@ -4982,7 +5175,7 @@ def list_shape(x): def deep_flatten(es): result = [] for e in es: - if isinstance(e, list) or isinstance(e, builtins.tuple): + if isinstance(e, (list, builtins.tuple)): result.extend(deep_flatten(e)) else: result.append(e) @@ -4990,11 +5183,13 @@ def deep_flatten(es): return result def check_arrays_uniform(nested_arr, shape_list, ndim): - current_level_correct = (hl.len(nested_arr) == shape_list[-ndim]) + current_level_correct = hl.len(nested_arr) == shape_list[-ndim] if ndim == 1: return current_level_correct else: - return current_level_correct & (hl.all(lambda inner: check_arrays_uniform(inner, shape_list, ndim - 1), nested_arr)) + return current_level_correct & ( + hl.all(lambda inner: check_arrays_uniform(inner, shape_list, ndim - 1), nested_arr) + ) if isinstance(collection, Expression): if isinstance(collection, ArrayNumericExpression): @@ -5008,7 +5203,7 @@ def check_arrays_uniform(nested_arr, shape_list, ndim): elif isinstance(collection, ArrayExpression): recursive_type = collection.dtype ndim = 0 - while isinstance(recursive_type, tarray) or isinstance(recursive_type, tndarray): + while isinstance(recursive_type, (tarray, tndarray)): recursive_type = recursive_type._element_type ndim += 1 @@ -5022,8 +5217,11 @@ def check_arrays_uniform(nested_arr, shape_list, ndim): shape_list.append(hl.int64(hl.len(nested_collection))) nested_collection = nested_collection[0] - shape_expr = (hl.case().when(check_arrays_uniform(collection, shape_list, ndim), hl.tuple(shape_list)) - .or_error("inner dimensions do not match")) + shape_expr = ( + hl.case() + .when(check_arrays_uniform(collection, shape_list, ndim), hl.tuple(shape_list)) + .or_error("inner dimensions do not match") + ) else: raise ValueError(f"{collection} cannot be converted into an ndarray") @@ -5031,7 +5229,7 @@ def check_arrays_uniform(nested_arr, shape_list, ndim): else: if isinstance(collection, np.ndarray): return hl.literal(collection) - elif isinstance(collection, list) or isinstance(collection, builtins.tuple): + elif isinstance(collection, (list, builtins.tuple)): shape = list_shape(collection) data = deep_flatten(collection) else: @@ -5110,9 +5308,11 @@ def _union_intersection_base(name, arrays, key, join_f, result_f): raise ValueError(f"{name}: key field {k!r} not in element type {t}") for i, a in builtins.enumerate(arrays): if a.dtype.element_type != t: - raise ValueError(f"{name}: input {i} has a different element type than input 0:" - f"\n input 0: 
{t}" - f"\n input {i}: {a.dtype.element_type}") + raise ValueError( + f"{name}: input {i} has a different element type than input 0:" + f"\n input 0: {t}" + f"\n input {i}: {a.dtype.element_type}" + ) key_typ = hl.tstruct(**{k: t[k] for k in key}) vals_typ = hl.tarray(t) @@ -5157,8 +5357,7 @@ def _zip_join_producers(contexts, stream_f, key, join_f): vals_var = construct_variable(vals_uid, vals_typ) join_ir = join_f(key_var, vals_var) - zj = ir.ToArray( - ir.StreamZipJoinProducers(contexts._ir, ctx_uid, make_prod_ir, key, key_uid, vals_uid, join_ir._ir)) + zj = ir.ToArray(ir.StreamZipJoinProducers(contexts._ir, ctx_uid, make_prod_ir, key, key_uid, vals_uid, join_ir._ir)) indices, aggs = unify_all(contexts, stream_req, join_ir) return construct_expr(zj, zj.typ, indices, aggs) @@ -5187,9 +5386,10 @@ def keyed_intersection(*arrays, key): arrays, key, lambda key_var, vals_var: hl.tuple((key_var, vals_var)), - lambda res: res - .filter(lambda x: hl.fold(lambda acc, elt: acc & hl.is_defined(elt), True, x[1])) - .map(lambda x: x[1].first())) + lambda res: res.filter(lambda x: hl.fold(lambda acc, elt: acc & hl.is_defined(elt), True, x[1])).map( + lambda x: x[1].first() + ), + ) @typecheck(arrays=expr_oneof(expr_stream(expr_any), expr_array(expr_any)), key=sequenceof(builtins.str)) @@ -5215,13 +5415,14 @@ def keyed_union(*arrays, key): 'keyed_union', arrays, key, - lambda keys_var, vals_var: hl.fold(lambda acc, elt: hl.coalesce(acc, elt), - hl.missing(vals_var.dtype.element_type), vals_var), - lambda res: res) + lambda keys_var, vals_var: hl.fold( + lambda acc, elt: hl.coalesce(acc, elt), hl.missing(vals_var.dtype.element_type), vals_var + ), + lambda res: res, + ) -@typecheck(collection=expr_oneof(expr_array(), expr_set()), - delimiter=expr_str) +@typecheck(collection=expr_oneof(expr_array(), expr_set()), delimiter=expr_str) def delimit(collection, delimiter=',') -> StringExpression: """Joins elements of `collection` into single string delimited by `delimiter`. 
@@ -5259,13 +5460,14 @@ def delimit(collection, delimiter=',') -> StringExpression: @typecheck(left=expr_any, right=expr_any) def _compare(left, right): if left.dtype != right.dtype: - raise TypeError(f"'compare' expected 'left' and 'right' to have the same type: found {left.dtype} vs {right.dtype}") + raise TypeError( + f"'compare' expected 'left' and 'right' to have the same type: found {left.dtype} vs {right.dtype}" + ) indices, aggregations = unify_all(left, right) return construct_expr(ir.ApplyComparisonOp("Compare", left._ir, right._ir), tint32, indices, aggregations) -@typecheck(collection=expr_array(), - less_than=nullable(func_spec(2, expr_bool))) +@typecheck(collection=expr_array(), less_than=nullable(func_spec(2, expr_bool))) def _sort_by(collection, less_than): left_id = Env.get_uid() right_id = Env.get_uid() @@ -5276,15 +5478,16 @@ def _sort_by(collection, less_than): ir.ArraySort(ir.toStream(collection._ir), left_id, right_id, less_than(left, right)._ir), collection.dtype, collection._indices, - collection._aggregations) + collection._aggregations, + ) -@typecheck(collection=expr_oneof(expr_array(), expr_dict(), expr_set()), - key=nullable(func_spec(1, expr_any)), - reverse=expr_bool) -def sorted(collection, - key: Optional[Callable] = None, - reverse=False) -> ArrayExpression: +@typecheck( + collection=expr_oneof(expr_array(), expr_dict(), expr_set()), + key=nullable(func_spec(1, expr_any)), + reverse=expr_bool, +) +def sorted(collection, key: Optional[Callable] = None, reverse=False) -> ArrayExpression: """Returns a sorted array. Examples @@ -5324,11 +5527,13 @@ def sorted(collection, collection = hl.array(collection) def comp(left, right): - return (hl.case() - .when(hl.is_missing(left), False) - .when(hl.is_missing(right), True) - .when(reverse, hl._compare(right, left) < 0) - .default(hl._compare(left, right) < 0)) + return ( + hl.case() + .when(hl.is_missing(left), False) + .when(hl.is_missing(right), True) + .when(reverse, hl._compare(right, left) < 0) + .default(hl._compare(left, right) < 0) + ) if key is None: return _sort_by(collection, comp) @@ -5829,8 +6034,7 @@ def bool(x) -> BooleanExpression: return x._method("toBoolean", tbool) -@typecheck(s=expr_str, - rna=builtins.bool) +@typecheck(s=expr_str, rna=builtins.bool) def reverse_complement(s, rna=False): """Reverses the string and translates base pairs into their complements Examples @@ -5865,11 +6069,9 @@ def reverse_complement(s, rna=False): return s.translate(d) -@typecheck(contig=expr_str, - position=expr_int32, - before=expr_int32, - after=expr_int32, - reference_genome=reference_genome_type) +@typecheck( + contig=expr_str, position=expr_int32, before=expr_int32, after=expr_int32, reference_genome=reference_genome_type +) def get_sequence(contig, position, before=0, after=0, reference_genome='default') -> StringExpression: """Return the reference sequence at a given locus. @@ -5911,13 +6113,16 @@ def get_sequence(contig, position, before=0, after=0, reference_genome='default' """ if not reference_genome.has_sequence(): - raise TypeError("Reference genome '{}' does not have a sequence loaded. Use 'add_sequence' to load the sequence from a FASTA file.".format(reference_genome.name)) + raise TypeError( + "Reference genome '{}' does not have a sequence loaded. 
Use 'add_sequence' to load the sequence from a FASTA file.".format( + reference_genome.name + ) + ) - return _func("getReferenceSequence", tstr, contig, position, before, after, type_args=(tlocus(reference_genome), )) + return _func("getReferenceSequence", tstr, contig, position, before, after, type_args=(tlocus(reference_genome),)) -@typecheck(contig=expr_str, - reference_genome=reference_genome_type) +@typecheck(contig=expr_str, reference_genome=reference_genome_type) def is_valid_contig(contig, reference_genome='default') -> BooleanExpression: """Returns ``True`` if `contig` is a valid contig name in `reference_genome`. @@ -5939,11 +6144,10 @@ def is_valid_contig(contig, reference_genome='default') -> BooleanExpression: ------- :class:`.BooleanExpression` """ - return _func("isValidContig", tbool, contig, type_args=(tlocus(reference_genome), )) + return _func("isValidContig", tbool, contig, type_args=(tlocus(reference_genome),)) -@typecheck(contig=expr_str, - reference_genome=reference_genome_type) +@typecheck(contig=expr_str, reference_genome=reference_genome_type) def contig_length(contig, reference_genome='default') -> Int32Expression: """Returns the length of `contig` in `reference_genome`. @@ -5962,12 +6166,10 @@ def contig_length(contig, reference_genome='default') -> Int32Expression: ------- :class:`.Int32Expression` """ - return _func("contigLength", tint32, contig, type_args=(tlocus(reference_genome), )) + return _func("contigLength", tint32, contig, type_args=(tlocus(reference_genome),)) -@typecheck(contig=expr_str, - position=expr_int32, - reference_genome=reference_genome_type) +@typecheck(contig=expr_str, position=expr_int32, reference_genome=reference_genome_type) def is_valid_locus(contig, position, reference_genome='default') -> BooleanExpression: """Returns ``True`` if `contig` and `position` is a valid site in `reference_genome`. 
@@ -5990,7 +6192,7 @@ def is_valid_locus(contig, position, reference_genome='default') -> BooleanExpre ------- :class:`.BooleanExpression` """ - return _func("isValidLocus", tbool, contig, position, type_args=(tlocus(reference_genome), )) + return _func("isValidLocus", tbool, contig, position, type_args=(tlocus(reference_genome),)) @typecheck(locus=expr_locus(), is_female=expr_bool, father=expr_call, mother=expr_call, child=expr_call) @@ -6074,35 +6276,40 @@ def mendel_error_code(locus, is_female, father, mother, child): mother_n = mother.n_alt_alleles() child_n = child.n_alt_alleles() - auto_cond = (hl.case(missing_false=True) - .when((father_n == 2) & (mother_n == 2) & (child_n == 1), 1) - .when((father_n == 0) & (mother_n == 0) & (child_n == 1), 2) - .when((father_n == 0) & (mother_n == 0) & (child_n == 2), 5) - .when((father_n == 2) & (mother_n == 2) & (child_n == 0), 8) - .when((father_n == 0) & (child_n == 2), 3) - .when((mother_n == 0) & (child_n == 2), 4) - .when((father_n == 2) & (child_n == 0), 6) - .when((mother_n == 2) & (child_n == 0), 7) - .or_missing() - ) - - hemi_x_cond = (hl.case(missing_false=True) - .when((mother_n == 2) & (child_n == 0), 9) - .when((mother_n == 0) & (child_n > 0), 10) - .or_missing() - ) - - hemi_y_cond = (hl.case(missing_false=True) - .when((father_n > 0) & (child_n == 0), 11) - .when((father_n == 0) & (child_n > 0), 12) - .or_missing() - ) - - return (hl.case() - .when(locus.in_autosome_or_par() | is_female, auto_cond) - .when(locus.in_x_nonpar() & (~is_female), hemi_x_cond) - .when(locus.in_y_nonpar() & (~is_female), hemi_y_cond) - .or_missing()) + auto_cond = ( + hl.case(missing_false=True) + .when((father_n == 2) & (mother_n == 2) & (child_n == 1), 1) + .when((father_n == 0) & (mother_n == 0) & (child_n == 1), 2) + .when((father_n == 0) & (mother_n == 0) & (child_n == 2), 5) + .when((father_n == 2) & (mother_n == 2) & (child_n == 0), 8) + .when((father_n == 0) & (child_n == 2), 3) + .when((mother_n == 0) & (child_n == 2), 4) + .when((father_n == 2) & (child_n == 0), 6) + .when((mother_n == 2) & (child_n == 0), 7) + .or_missing() + ) + + hemi_x_cond = ( + hl.case(missing_false=True) + .when((mother_n == 2) & (child_n == 0), 9) + .when((mother_n == 0) & (child_n > 0), 10) + .or_missing() + ) + + hemi_y_cond = ( + hl.case(missing_false=True) + .when((father_n > 0) & (child_n == 0), 11) + .when((father_n == 0) & (child_n > 0), 12) + .or_missing() + ) + + return ( + hl.case() + .when(locus.in_autosome_or_par() | is_female, auto_cond) + .when(locus.in_x_nonpar() & (~is_female), hemi_x_cond) + .when(locus.in_y_nonpar() & (~is_female), hemi_y_cond) + .or_missing() + ) @typecheck(locus=expr_locus(), alleles=expr_array(expr_str)) @@ -6139,10 +6346,12 @@ def min_rep(locus, alleles): return _func('min_rep', ret_type, locus, alleles) -@typecheck(x=oneof(expr_locus(), expr_interval(expr_locus())), - dest_reference_genome=reference_genome_type, - min_match=builtins.float, - include_strand=builtins.bool) +@typecheck( + x=oneof(expr_locus(), expr_interval(expr_locus())), + dest_reference_genome=reference_genome_type, + min_match=builtins.float, + include_strand=builtins.bool, +) def liftover(x, dest_reference_genome, min_match=0.95, include_strand=False): """Lift over coordinates to a different reference genome. 
@@ -6213,8 +6422,10 @@ def liftover(x, dest_reference_genome, min_match=0.95, include_strand=False): rtype = tstruct(result=tinterval(tlocus(dest_reference_genome)), is_negative_strand=tbool) if not rg.has_liftover(dest_reference_genome.name): - raise TypeError("""Reference genome '{}' does not have liftover to '{}'. - Use 'add_liftover' to load a liftover chain file.""".format(rg.name, dest_reference_genome.name)) + raise TypeError( + """Reference genome '{}' does not have liftover to '{}'. + Use 'add_liftover' to load a liftover chain file.""".format(rg.name, dest_reference_genome.name) + ) expr = _func(method_name, rtype, x, to_expr(min_match, tfloat64)) if not include_strand: @@ -6222,12 +6433,14 @@ def liftover(x, dest_reference_genome, min_match=0.95, include_strand=False): return expr -@typecheck(f=func_spec(1, expr_float64), - min=expr_float64, - max=expr_float64, - max_iter=builtins.int, - epsilon=builtins.float, - tolerance=builtins.float) +@typecheck( + f=func_spec(1, expr_float64), + min=expr_float64, + max=expr_float64, + max_iter=builtins.int, + epsilon=builtins.float, + tolerance=builtins.float, +) def uniroot(f: Callable, min, max, *, max_iter=1000, epsilon=2.2204460492503131e-16, tolerance=1.220703e-4): """Finds a root of the function `f` within the interval `[min, max]`. @@ -6273,9 +6486,8 @@ def uniroot(f: Callable, min, max, *, max_iter=1000, epsilon=2.2204460492503131e def error_if_missing(x): res = f(x) - return (case() - .when(is_defined(res), res) - .or_error(format("'uniroot': value of f(x) is missing for x = %.1e", x))) + return case().when(is_defined(res), res).or_error(format("'uniroot': value of f(x) is missing for x = %.1e", x)) + wrapped_f = hl.experimental.define_function(error_if_missing, 'float') def uniroot(recur, a, b, c, fa, fb, fc, prev, iterations_remaining): @@ -6287,17 +6499,18 @@ def uniroot(recur, a, b, c, fa, fb, fc, prev, iterations_remaining): pq = if_else( a == c, (cb * t1) / (t1 - 1.0), # linear - -t2 * (cb * q1 * (q1 - t1) - (b - a) * (t1 - 1.0)) - / ((q1 - 1.0) * (t1 - 1.0) * (t2 - 1.0))) # quadratic - - interpolated = if_else((sign(pq) == sign(cb)) - & (.75 * abs(cb) > abs(pq) + tol / 2) # b + pq within [b, c] - & (abs(pq) < abs(prev / 2)), # pq not too large - pq, cb / 2) + -t2 * (cb * q1 * (q1 - t1) - (b - a) * (t1 - 1.0)) / ((q1 - 1.0) * (t1 - 1.0) * (t2 - 1.0)), + ) # quadratic + + interpolated = if_else( + (sign(pq) == sign(cb)) + & (0.75 * abs(cb) > abs(pq) + tol / 2) # b + pq within [b, c] + & (abs(pq) < abs(prev / 2)), # pq not too large + pq, + cb / 2, + ) - new_step = if_else( - (abs(prev) >= tol) & (abs(fa) > abs(fb)), # try interpolation - interpolated, cb / 2) + new_step = if_else((abs(prev) >= tol) & (abs(fa) > abs(fb)), interpolated, cb / 2) # try interpolation new_b = b + if_else(new_step < 0, hl.min(new_step, -tol), hl.max(new_step, tol)) new_fb = wrapped_f(new_b) @@ -6305,28 +6518,49 @@ def uniroot(recur, a, b, c, fa, fb, fc, prev, iterations_remaining): return if_else( iterations_remaining == 0, missing('float'), - if_else(abs(fc) < abs(fb), - recur(b, c, b, fb, fc, fb, prev, iterations_remaining), - if_else((abs(cb / 2) <= tol) | (fb == 0), - b, # acceptable approximation found - if_else(sign(new_fb) == sign(fc), # use c = b for next iteration if signs match - recur(b, new_b, b, fb, new_fb, fb, new_step, iterations_remaining - 1), - recur(b, new_b, c, fb, new_fb, fc, new_step, iterations_remaining - 1) - )))) + if_else( + abs(fc) < abs(fb), + recur(b, c, b, fb, fc, fb, prev, iterations_remaining), + if_else( + (abs(cb / 
2) <= tol) | (fb == 0), + b, # acceptable approximation found + if_else( + sign(new_fb) == sign(fc), # use c = b for next iteration if signs match + recur(b, new_b, b, fb, new_fb, fb, new_step, iterations_remaining - 1), + recur(b, new_b, c, fb, new_fb, fc, new_step, iterations_remaining - 1), + ), + ), + ), + ) fmin = wrapped_f(min) fmax = wrapped_f(max) run_loop = hl.experimental.define_function( - lambda min, max, fmin, fmax: - hl.experimental.loop(uniroot, 'float', - min, max, min, fmin, fmax, fmin, max - min, max_iter), - 'float', 'float', 'float', 'float') + lambda min, max, fmin, fmax: hl.experimental.loop( + uniroot, 'float', min, max, min, fmin, fmax, fmin, max - min, max_iter + ), + 'float', + 'float', + 'float', + 'float', + ) - return (case() - .when(min < max, case() - .when(fmin * fmax <= 0, run_loop(min, max, fmin, fmax)) - .or_error(format("'uniroot': sign of endpoints must have opposite signs, got: f(min) = %.1e, f(max) = %.1e", fmin, fmax))) - .or_error(format("'uniroot': min must be less than max in call to uniroot, got: min %.1e, max %.1e", min, max))) + return ( + case() + .when( + min < max, + case() + .when(fmin * fmax <= 0, run_loop(min, max, fmin, fmax)) + .or_error( + format( + "'uniroot': sign of endpoints must have opposite signs, got: f(min) = %.1e, f(max) = %.1e", + fmin, + fmax, + ) + ), + ) + .or_error(format("'uniroot': min must be less than max in call to uniroot, got: min %.1e, max %.1e", min, max)) + ) @typecheck(f=expr_str, args=expr_any) @@ -6411,11 +6645,16 @@ def _shift_op(x, y, op): zero = hl.int32(0) indices, aggregations = unify_all(x, y) - return hl.bind(lambda x, y: ( - hl.case() - .when(y >= word_size, hl.sign(x) if op == '>>' else zero) - .when(y >= 0, construct_expr(ir.ApplyBinaryPrimOp(op, x._ir, y._ir), t, indices, aggregations)) - .or_error('cannot shift by a negative value: ' + hl.str(x) + f" {op} " + hl.str(y))), x, y) + return hl.bind( + lambda x, y: ( + hl.case() + .when(y >= word_size, hl.sign(x) if op == '>>' else zero) + .when(y >= 0, construct_expr(ir.ApplyBinaryPrimOp(op, x._ir, y._ir), t, indices, aggregations)) + .or_error('cannot shift by a negative value: ' + hl.str(x) + f" {op} " + hl.str(y)) + ), + x, + y, + ) def _bit_op(x, y, op): @@ -6692,7 +6931,9 @@ def binary_search(array, elem) -> Int32Expression: """ c = coercer_from_dtype(array.dtype.element_type) if not c.can_coerce(elem.dtype): - raise TypeError(f"'binary_search': cannot search an array of type {array.dtype} for a value of type {elem.dtype}") + raise TypeError( + f"'binary_search': cannot search an array of type {array.dtype} for a value of type {elem.dtype}" + ) elem = c.coerce(elem) return hl.switch(elem).when_missing(hl.missing(hl.tint32)).default(_lower_bound(array, elem)) @@ -6705,8 +6946,9 @@ def _escape_string(s): @typecheck(left=expr_any, right=expr_any, tolerance=expr_float64, absolute=expr_bool) def _values_similar(left, right, tolerance=1e-6, absolute=False): assert left.dtype == right.dtype - return ((is_missing(left) & is_missing(right)) - | ((is_defined(left) & is_defined(right)) & _func("valuesSimilar", hl.tbool, left, right, tolerance, absolute))) + return (is_missing(left) & is_missing(right)) | ( + (is_defined(left) & is_defined(right)) & _func("valuesSimilar", hl.tbool, left, right, tolerance, absolute) + ) @typecheck(coords=expr_array(expr_array(expr_float64)), radius=expr_float64) @@ -6715,9 +6957,8 @@ def _locus_windows_per_contig(coords, radius): return _func("locus_windows_per_contig", rt, coords, radius) -@typecheck(a=expr_array(), - 
seed=nullable(builtins.int)) -def shuffle(a, seed: builtins.int = None) -> ArrayExpression: +@typecheck(a=expr_array(), seed=nullable(builtins.int)) +def shuffle(a, seed: Optional[builtins.int] = None) -> ArrayExpression: """Randomly permute an array Example @@ -6781,32 +7022,41 @@ def coerce_endpoint(point): ts = point.dtype if isinstance(ts, tstruct): i = 0 - while (i < len(ts)): + while i < len(ts): if i >= len(key_typ): raise ValueError( - f"query_table: queried with {len(ts)} key field(s), but table only has {len(key_typ)} key field(s)") + f"query_table: queried with {len(ts)} key field(s), but table only has {len(key_typ)} key field(s)" + ) if key_typ[i] != ts[i]: raise ValueError( - f"query_table: key mismatch at key field {i} ({list(ts.keys())[i]!r}): query type is {ts[i]}, table key type is {key_typ[i]}") + f"query_table: key mismatch at key field {i} ({list(ts.keys())[i]!r}): query type is {ts[i]}, table key type is {key_typ[i]}" + ) i += 1 if i == 0: raise ValueError("query_table: cannot query with empty key") point_size = builtins.len(point.dtype) - return hl.tuple( - [hl.struct(**{key_names[i]: (point[i] if i < point_size else hl.missing(key_typ[i])) - for i in builtins.range(builtins.len(key_typ))}), hl.int32(point_size)]) + return hl.tuple([ + hl.struct(**{ + key_names[i]: (point[i] if i < point_size else hl.missing(key_typ[i])) + for i in builtins.range(builtins.len(key_typ)) + }), + hl.int32(point_size), + ]) else: raise ValueError( f"query_table: key mismatch: cannot query a table with key " - f"({', '.join(builtins.str(x) for x in key_typ.values())}) with query point type {point.dtype}") + f"({', '.join(builtins.str(x) for x in key_typ.values())}) with query point type {point.dtype}" + ) if point_or_interval.dtype != key_typ[0] and isinstance(point_or_interval.dtype, hl.tinterval): - partition_interval = hl.interval(start=coerce_endpoint(point_or_interval.start), - end=coerce_endpoint(point_or_interval.end), - includes_start=point_or_interval.includes_start, - includes_end=point_or_interval.includes_end) + partition_interval = hl.interval( + start=coerce_endpoint(point_or_interval.start), + end=coerce_endpoint(point_or_interval.end), + includes_start=point_or_interval.includes_start, + includes_end=point_or_interval.includes_end, + ) else: point = coerce_endpoint(point_or_interval) partition_interval = hl.interval(start=point, end=point, includes_start=True, includes_end=True) @@ -6814,7 +7064,7 @@ def coerce_endpoint(point): ir.ToArray(ir.ReadPartition(partition_interval._ir, reader=ir.PartitionNativeIntervalReader(path, row_typ))), type=hl.tarray(row_typ), indices=partition_interval._indices, - aggregations=partition_interval._aggregations + aggregations=partition_interval._aggregations, ) diff --git a/hail/python/hail/expr/matrix_type.py b/hail/python/hail/expr/matrix_type.py index 520d65f7540..a50e7b6fdd4 100644 --- a/hail/python/hail/expr/matrix_type.py +++ b/hail/python/hail/expr/matrix_type.py @@ -1,9 +1,8 @@ import pprint -from hail.typecheck import typecheck_method, sequenceof -from hail.utils.java import escape_parsable from hail.expr.types import dtype, tstruct -from hail.utils.java import jiterable_to_list +from hail.typecheck import sequenceof, typecheck_method +from hail.utils.java import escape_parsable, jiterable_to_list class tmatrix(object): @@ -17,7 +16,8 @@ def _from_java(jtt): jiterable_to_list(jtt.colKey()), dtype(jtt.rowType().toString()), jiterable_to_list(jtt.rowKey()), - dtype(jtt.entryType().toString())) + dtype(jtt.entryType().toString()), + ) 
@staticmethod def _from_json(json): @@ -27,12 +27,17 @@ def _from_json(json): col_key=json['col_key'], row_type=dtype(json['row_type']), row_key=json['row_key'], - entry_type=dtype(json['entry_type'])) - - @typecheck_method(global_type=tstruct, - col_type=tstruct, col_key=sequenceof(str), - row_type=tstruct, row_key=sequenceof(str), - entry_type=tstruct) + entry_type=dtype(json['entry_type']), + ) + + @typecheck_method( + global_type=tstruct, + col_type=tstruct, + col_key=sequenceof(str), + row_type=tstruct, + row_key=sequenceof(str), + entry_type=tstruct, + ) def __init__(self, global_type, col_type, col_key, row_type, row_key, entry_type): self.global_type = global_type self.col_type = col_type @@ -42,21 +47,25 @@ def __init__(self, global_type, col_type, col_key, row_type, row_key, entry_type self.entry_type = entry_type def to_dict(self): - return dict(global_type=str(self.global_type), - col_type=str(self.col_type), - col_key=self.col_key, - row_type=str(self.row_type), - row_key=self.row_key, - entry_type=self.entry_type) + return dict( + global_type=str(self.global_type), + col_type=str(self.col_type), + col_key=self.col_key, + row_type=str(self.row_type), + row_key=self.row_key, + entry_type=self.entry_type, + ) def __eq__(self, other): - return (isinstance(other, tmatrix) - and self.global_type == other.global_type - and self.col_type == other.col_type - and self.col_key == other.col_key - and self.row_type == other.row_type - and self.row_key == other.row_key - and self.entry_type == other.entry_type) + return ( + isinstance(other, tmatrix) + and self.global_type == other.global_type + and self.col_type == other.col_type + and self.col_key == other.col_key + and self.row_type == other.row_type + and self.row_key == other.row_key + and self.entry_type == other.entry_type + ) def __hash__(self): return 43 + hash(str(self)) @@ -128,12 +137,14 @@ def row_value_type(self): return self.row_type._drop_fields(set(self.row_key)) def _rename(self, global_map, col_map, row_map, entry_map): - return tmatrix(self.global_type._rename(global_map), - self.col_type._rename(col_map), - [col_map.get(k, k) for k in self.col_key], - self.row_type._rename(row_map), - [row_map.get(k, k) for k in self.row_key], - self.entry_type._rename(entry_map)) + return tmatrix( + self.global_type._rename(global_map), + self.col_type._rename(col_map), + [col_map.get(k, k) for k in self.col_key], + self.row_type._rename(row_map), + [row_map.get(k, k) for k in self.row_key], + self.entry_type._rename(entry_map), + ) def global_env(self, default_value=None): if default_value is None: @@ -143,31 +154,21 @@ def global_env(self, default_value=None): def row_env(self, default_value=None): if default_value is None: - return {'global': self.global_type, - 'va': self.row_type} + return {'global': self.global_type, 'va': self.row_type} else: - return {'global': default_value, - 'va': default_value} + return {'global': default_value, 'va': default_value} def col_env(self, default_value=None): if default_value is None: - return {'global': self.global_type, - 'sa': self.col_type} + return {'global': self.global_type, 'sa': self.col_type} else: - return {'global': default_value, - 'sa': default_value} + return {'global': default_value, 'sa': default_value} def entry_env(self, default_value=None): if default_value is None: - return {'global': self.global_type, - 'va': self.row_type, - 'sa': self.col_type, - 'g': self.entry_type} + return {'global': self.global_type, 'va': self.row_type, 'sa': self.col_type, 'g': self.entry_type} 
else: - return {'global': default_value, - 'va': default_value, - 'sa': default_value, - 'g': default_value} + return {'global': default_value, 'va': default_value, 'sa': default_value, 'g': default_value} _old_printer = pprint.PrettyPrinter diff --git a/hail/python/hail/expr/table_type.py b/hail/python/hail/expr/table_type.py index 70528397591..72384f5307f 100644 --- a/hail/python/hail/expr/table_type.py +++ b/hail/python/hail/expr/table_type.py @@ -1,8 +1,8 @@ import pprint -from hail.typecheck import typecheck_method, sequenceof -from hail.utils.java import escape_parsable + from hail.expr.types import dtype, tstruct -from hail.utils.java import jiterable_to_list +from hail.typecheck import sequenceof, typecheck_method +from hail.utils.java import escape_parsable, jiterable_to_list class ttable(object): @@ -10,17 +10,11 @@ class ttable(object): @staticmethod def _from_java(jtt): - return ttable( - dtype(jtt.globalType().toString()), - dtype(jtt.rowType().toString()), - jiterable_to_list(jtt.key())) + return ttable(dtype(jtt.globalType().toString()), dtype(jtt.rowType().toString()), jiterable_to_list(jtt.key())) @staticmethod def _from_json(json): - return ttable( - global_type=dtype(json['global_type']), - row_type=dtype(json['row_type']), - row_key=json['row_key']) + return ttable(global_type=dtype(json['global_type']), row_type=dtype(json['row_type']), row_key=json['row_key']) @typecheck_method(global_type=tstruct, row_type=tstruct, row_key=sequenceof(str)) def __init__(self, global_type, row_type, row_key): @@ -29,15 +23,15 @@ def __init__(self, global_type, row_type, row_key): self.row_key = row_key def to_dict(self): - return dict(global_type=str(self.global_type), - row_type=str(self.row_type), - row_key=self.row_key) + return dict(global_type=str(self.global_type), row_type=str(self.row_type), row_key=self.row_key) def __eq__(self, other): - return (isinstance(other, ttable) - and self.global_type == other.global_type - and self.row_type == other.row_type - and self.row_key == other.row_key) + return ( + isinstance(other, ttable) + and self.global_type == other.global_type + and self.row_type == other.row_type + and self.row_key == other.row_key + ) def __hash__(self): return 43 + hash(str(self)) @@ -85,9 +79,11 @@ def value_type(self): return self.row_type._drop_fields(set(self.row_key)) def _rename(self, global_map, row_map): - return ttable(self.global_type._rename(global_map), - self.row_type._rename(row_map), - [row_map.get(k, k) for k in self.row_key]) + return ttable( + self.global_type._rename(global_map), + self.row_type._rename(row_map), + [row_map.get(k, k) for k in self.row_key], + ) def row_env(self, default_value=None): if default_value is None: diff --git a/hail/python/hail/expr/type_parsing.py b/hail/python/hail/expr/type_parsing.py index 884f524938b..f5b22e36f37 100644 --- a/hail/python/hail/expr/type_parsing.py +++ b/hail/python/hail/expr/type_parsing.py @@ -1,10 +1,11 @@ from parsimonious import Grammar, NodeVisitor + from hail.expr.nat import NatVariable -from . import types from hail.utils.java import unescape_parsable -type_grammar = Grammar( - r""" +from . import types + +type_grammar = Grammar(r""" type = _ (array / ndarray / set / dict / struct / union / tuple / interval / int64 / int32 / float32 / float64 / bool / str / call / str / locus / void / variable) _ variable = "?" simple_identifier (":" simple_identifier)? 
void = "void" / "tvoid" @@ -154,8 +155,7 @@ def visit_nat_variable(self, node, visited_children): type_node_visitor = TypeConstructor() -vcf_type_grammar = Grammar( - r""" +vcf_type_grammar = Grammar(r""" type = _ (array / set / int32 / int64 / float32 / float64 / str / bool / call / struct) _ int64 = "Int64" int32 = "Int32" diff --git a/hail/python/hail/expr/types.py b/hail/python/hail/expr/types.py index 754beab11c5..af74251e6ca 100644 --- a/hail/python/hail/expr/types.py +++ b/hail/python/hail/expr/types.py @@ -1,10 +1,10 @@ -from typing import Union import abc +import builtins import json import math -from collections.abc import Mapping, Sequence import pprint -import builtins +from collections.abc import Mapping, Sequence +from typing import ClassVar, Union import numpy as np import pandas as pd @@ -13,16 +13,15 @@ from hailtop.frozendict import frozendict from hailtop.hail_frozenlist import frozenlist -from .nat import NatBase, NatLiteral -from .type_parsing import type_grammar, type_node_visitor from .. import genetics -from ..typecheck import typecheck, typecheck_method, oneof, transformed, nullable -from ..utils.struct import Struct +from ..genetics.reference_genome import reference_genome_type +from ..typecheck import nullable, oneof, transformed, typecheck, typecheck_method from ..utils.byte_reader import ByteReader, ByteWriter -from ..utils.misc import lookup_bit from ..utils.java import escape_parsable -from ..genetics.reference_genome import reference_genome_type - +from ..utils.misc import lookup_bit +from ..utils.struct import Struct +from .nat import NatBase, NatLiteral +from .type_parsing import type_grammar, type_node_visitor __all__ = [ 'dtype', @@ -153,10 +152,7 @@ def is_empty(self): def _to_json_context(self): if self._json is None: - self._json = { - 'reference_genomes': - {r: hl.get_reference(r)._config for r in self.references} - } + self._json = {'reference_genomes': {r: hl.get_reference(r)._config for r in self.references}} return self._json @classmethod @@ -188,14 +184,14 @@ def __repr__(self): @abc.abstractmethod def _eq(self, other): - return + raise NotImplementedError def __eq__(self, other): return isinstance(other, HailType) and self._eq(other) @abc.abstractmethod def __str__(self): - return + raise NotImplementedError def __hash__(self): # FIXME this is a bit weird @@ -223,7 +219,7 @@ def _pretty(self, b, indent, increment): @abc.abstractmethod def _parsable_string(self) -> str: - pass + raise NotImplementedError def typecheck(self, value): """Check that `value` matches a type. 
@@ -237,14 +233,16 @@ def typecheck(self, value): ------ :obj:`TypeError` """ + def check(t, obj): t._typecheck_one_level(obj) return True + self._traverse(value, check) @abc.abstractmethod def _typecheck_one_level(self, annotation): - pass + raise NotImplementedError def _to_json(self, x): converted = self._convert_to_json_na(x) @@ -375,8 +373,10 @@ def _typecheck_one_level(self, annotation): if not is_int32(annotation): raise TypeError("type 'tint32' expected Python 'int', but found type '%s'" % type(annotation)) elif not self.min_value <= annotation <= self.max_value: - raise TypeError(f"Value out of range for 32-bit integer: " - f"expected [{self.min_value}, {self.max_value}], found {annotation}") + raise TypeError( + f"Value out of range for 32-bit integer: " + f"expected [{self.min_value}, {self.max_value}], found {annotation}" + ) def __str__(self): return "int32" @@ -433,8 +433,10 @@ def _typecheck_one_level(self, annotation): if not is_int64(annotation): raise TypeError("type 'int64' expected Python 'int', but found type '%s'" % type(annotation)) if not self.min_value <= annotation <= self.max_value: - raise TypeError(f"Value out of range for 64-bit integer: " - f"expected [{self.min_value}, {self.max_value}], found {annotation}") + raise TypeError( + f"Value out of range for 64-bit integer: " + f"expected [{self.min_value}, {self.max_value}], found {annotation}" + ) def __str__(self): return "int64" @@ -667,8 +669,8 @@ def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> b def _convert_to_encoding(self, byte_writer: ByteWriter, value): byte_writer.write_bool(value) -class _trngstate(HailType): +class _trngstate(HailType): def __init__(self): super(_trngstate, self).__init__() @@ -758,9 +760,7 @@ def __str__(self): return "ndarray<{}, {}>".format(self.element_type, self.ndim) def _eq(self, other): - return (isinstance(other, tndarray) - and self.element_type == other.element_type - and self.ndim == other.ndim) + return isinstance(other, tndarray) and self.element_type == other.element_type and self.ndim == other.ndim def _pretty(self, b, indent, increment): b.append('ndarray<') @@ -786,12 +786,9 @@ def _convert_to_json(self, x): axis_one_step_byte_size = x.itemsize for dimension_size in x.shape: strides.append(axis_one_step_byte_size) - axis_one_step_byte_size *= (dimension_size if dimension_size > 0 else 1) + axis_one_step_byte_size *= dimension_size if dimension_size > 0 else 1 - json_dict = { - "shape": x.shape, - "data": data - } + json_dict = {"shape": x.shape, "data": data} return json_dict def clear(self): @@ -799,9 +796,7 @@ def clear(self): self._ndim.clear() def unify(self, t): - return isinstance(t, tndarray) and \ - self._element_type.unify(t._element_type) and \ - self._ndim.unify(t._ndim) + return isinstance(t, tndarray) and self._element_type.unify(t._element_type) and self._ndim.unify(t._ndim) def subst(self): return tndarray(self._element_type.subst(), self._ndim.subst()) @@ -819,7 +814,9 @@ def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> n buffer = byte_reader.read_bytes_view(bytes_to_read) return np.frombuffer(buffer, self.element_type.to_numpy, count=total_num_elements).reshape(shape) else: - elements = [self.element_type._convert_from_encoding(byte_reader, _should_freeze) for i in range(total_num_elements)] + elements = [ + self.element_type._convert_from_encoding(byte_reader, _should_freeze) for i in range(total_num_elements) + ] np_type = self.element_type.to_numpy() return np.ndarray(shape=shape, 
buffer=np.array(elements, dtype=np_type), dtype=np_type, order="F") @@ -944,7 +941,6 @@ def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> U return frozenlist(decoded) return decoded - def _convert_to_encoding(self, byte_writer: ByteWriter, value): length = len(value) byte_writer.write_int32(length) @@ -1020,7 +1016,7 @@ def _get_context(self): def is_setlike(maybe_setlike): - return isinstance(maybe_setlike, set) or isinstance(maybe_setlike, frozenset) + return isinstance(maybe_setlike, (set, frozenset)) class tset(HailType): @@ -1216,15 +1212,21 @@ def _parsable_string(self): return "Dict[{},{}]".format(self.key_type._parsable_string(), self.value_type._parsable_string()) def _convert_from_json(self, x, _should_freeze: bool = False) -> Union[dict, frozendict]: - d = {self.key_type._convert_from_json_na(elt['key'], _should_freeze=True): - self.value_type._convert_from_json_na(elt['value'], _should_freeze=_should_freeze) for elt in x} + d = { + self.key_type._convert_from_json_na(elt['key'], _should_freeze=True): self.value_type._convert_from_json_na( + elt['value'], _should_freeze=_should_freeze + ) + for elt in x + } if _should_freeze: return frozendict(d) return d def _convert_to_json(self, x): - return [{'key': self.key_type._convert_to_json(k), - 'value': self.value_type._convert_to_json(v)} for k, v in x.items()] + return [ + {'key': self.key_type._convert_to_json(k), 'value': self.value_type._convert_to_json(v)} + for k, v in x.items() + ] def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False) -> Union[dict, frozendict]: # NB: We ensure the key is always frozen with a wrapper on the key_type in the _array_repr. @@ -1249,9 +1251,7 @@ def _propagate_jtypes(self, jtype): self._value_type._add_jtype(jtype.valueType()) def unify(self, t): - return (isinstance(t, tdict) - and self.key_type.unify(t.key_type) - and self.value_type.unify(t.value_type)) + return isinstance(t, tdict) and self.key_type.unify(t.key_type) and self.value_type.unify(t.value_type) def subst(self): return tdict(self._key_type.subst(), self._value_type.subst()) @@ -1344,11 +1344,15 @@ def _typecheck_one_level(self, annotation): s = set(self) for f in annotation: if f not in s: - raise TypeError("type '%s' expected fields '%s', but found fields '%s'" % - (self, list(self), list(annotation))) + raise TypeError( + "type '%s' expected fields '%s', but found fields '%s'" + % (self, list(self), list(annotation)) + ) else: - raise TypeError("type 'struct' expected type Mapping (e.g. dict or hail.utils.Struct), but found '%s'" % - type(annotation)) + raise TypeError( + "type 'struct' expected type Mapping (e.g. 
dict or hail.utils.Struct), but found '%s'" + % type(annotation) + ) @typecheck_method(item=oneof(int, str)) def __getitem__(self, item): @@ -1363,16 +1367,17 @@ def __len__(self): return len(self._fields) def __str__(self): - return "struct{{{}}}".format( - ', '.join('{}: {}'.format(escape_parsable(f), str(t)) for f, t in self.items())) + return "struct{{{}}}".format(', '.join('{}: {}'.format(escape_parsable(f), str(t)) for f, t in self.items())) def items(self): return self._field_types.items() def _eq(self, other): - return (isinstance(other, tstruct) - and self._fields == other._fields - and all(self[f] == other[f] for f in self._fields)) + return ( + isinstance(other, tstruct) + and self._fields == other._fields + and all(self[f] == other[f] for f in self._fields) + ) def _pretty(self, b, indent, increment): if not self._fields: @@ -1395,7 +1400,8 @@ def _pretty(self, b, indent, increment): def _parsable_string(self): return "Struct{{{}}}".format( - ','.join('{}:{}'.format(escape_parsable(f), t._parsable_string()) for f, t in self.items())) + ','.join('{}:{}'.format(escape_parsable(f), t._parsable_string()) for f, t in self.items()) + ) def _convert_from_json(self, x, _should_freeze: bool = False) -> Struct: return Struct(**{f: t._convert_from_json_na(x.get(f), _should_freeze) for f, t in self._field_types.items()}) @@ -1440,9 +1446,11 @@ def _convert_to_encoding(self, byte_writer: ByteWriter, value): t._convert_to_encoding(byte_writer, value[f]) def _is_prefix_of(self, other): - return (isinstance(other, tstruct) - and len(self._fields) <= len(other._fields) - and all(x == y for x, y in zip(self._field_types.values(), other._field_types.values()))) + return ( + isinstance(other, tstruct) + and len(self._fields) <= len(other._fields) + and all(x == y for x, y in zip(self._field_types.values(), other._field_types.values())) + ) def _concat(self, other): new_field_types = {} @@ -1490,7 +1498,9 @@ def _rename(self, map): if f in seen: raise ValueError( "Cannot rename two fields to the same name: attempted to rename {} and {} both to {}".format( - repr(seen[f]), repr(f0), repr(f))) + repr(seen[f]), repr(f0), repr(f) + ) + ) else: seen[f] = f0 new_field_types[f] = t @@ -1535,7 +1545,6 @@ def __init__(self, **case_types): @property def cases(self): - """Return union case names. 
Returns @@ -1558,13 +1567,12 @@ def __len__(self): return len(self._cases) def __str__(self): - return "union{{{}}}".format( - ', '.join('{}: {}'.format(escape_parsable(f), str(t)) for f, t in self.items())) + return "union{{{}}}".format(', '.join('{}: {}'.format(escape_parsable(f), str(t)) for f, t in self.items())) def _eq(self, other): - return (isinstance(other, tunion) - and self._cases == other._cases - and all(self[c] == other[c] for c in self._cases)) + return ( + isinstance(other, tunion) and self._cases == other._cases and all(self[c] == other[c] for c in self._cases) + ) def _pretty(self, b, indent, increment): if not self._cases: @@ -1587,7 +1595,8 @@ def _pretty(self, b, indent, increment): def _parsable_string(self): return "Union{{{}}}".format( - ','.join('{}:{}'.format(escape_parsable(f), t._parsable_string()) for f, t in self.items())) + ','.join('{}:{}'.format(escape_parsable(f), t._parsable_string()) for f, t in self.items()) + ) def unify(self, t): if not (isinstance(t, tunion) and len(self) == len(t)): @@ -1646,11 +1655,9 @@ def _traverse(self, obj, f): def _typecheck_one_level(self, annotation): if annotation: if not isinstance(annotation, tuple): - raise TypeError("type 'tuple' expected Python tuple, but found '%s'" % - type(annotation)) + raise TypeError("type 'tuple' expected Python tuple, but found '%s'" % type(annotation)) if len(annotation) != len(self.types): - raise TypeError("%s expected tuple of size '%i', but found '%s'" % - (self, len(self.types), annotation)) + raise TypeError("%s expected tuple of size '%i', but found '%s'" % (self, len(self.types), annotation)) @typecheck_method(item=int) def __getitem__(self, item): @@ -1668,8 +1675,10 @@ def __str__(self): def _eq(self, other): from operator import eq - return isinstance(other, ttuple) and len(self.types) == len(other.types) and all( - map(eq, self.types, other.types)) + + return ( + isinstance(other, ttuple) and len(self.types) == len(other.types) and all(map(eq, self.types, other.types)) + ) def _pretty(self, b, indent, increment): pre_indent = indent @@ -1747,13 +1756,13 @@ def _get_context(self): def allele_pair(j: int, k: int): - assert j >= 0 and j <= 0xffff - assert k >= 0 and k <= 0xffff + assert j >= 0 and j <= 0xFFFF + assert k >= 0 and k <= 0xFFFF return j | (k << 16) def allele_pair_sqrt(i): - k = int(math.sqrt(8 * float(i) + 1) / 2 - .5) + k = int(math.sqrt(8 * float(i) + 1) / 2 - 0.5) assert k * (k + 1) // 2 <= i j = i - k * (k + 1) // 2 # TODO another assert @@ -1761,13 +1770,42 @@ def allele_pair_sqrt(i): small_allele_pair = [ - allele_pair(0, 0), allele_pair(0, 1), allele_pair(1, 1), - allele_pair(0, 2), allele_pair(1, 2), allele_pair(2, 2), - allele_pair(0, 3), allele_pair(1, 3), allele_pair(2, 3), allele_pair(3, 3), - allele_pair(0, 4), allele_pair(1, 4), allele_pair(2, 4), allele_pair(3, 4), allele_pair(4, 4), - allele_pair(0, 5), allele_pair(1, 5), allele_pair(2, 5), allele_pair(3, 5), allele_pair(4, 5), allele_pair(5, 5), - allele_pair(0, 6), allele_pair(1, 6), allele_pair(2, 6), allele_pair(3, 6), allele_pair(4, 6), allele_pair(5, 6), allele_pair(6, 6), - allele_pair(0, 7), allele_pair(1, 7), allele_pair(2, 7), allele_pair(3, 7), allele_pair(4, 7), allele_pair(5, 7), allele_pair(6, 7), allele_pair(7, 7) + allele_pair(0, 0), + allele_pair(0, 1), + allele_pair(1, 1), + allele_pair(0, 2), + allele_pair(1, 2), + allele_pair(2, 2), + allele_pair(0, 3), + allele_pair(1, 3), + allele_pair(2, 3), + allele_pair(3, 3), + allele_pair(0, 4), + allele_pair(1, 4), + allele_pair(2, 4), + 
allele_pair(3, 4), + allele_pair(4, 4), + allele_pair(0, 5), + allele_pair(1, 5), + allele_pair(2, 5), + allele_pair(3, 5), + allele_pair(4, 5), + allele_pair(5, 5), + allele_pair(0, 6), + allele_pair(1, 6), + allele_pair(2, 6), + allele_pair(3, 6), + allele_pair(4, 6), + allele_pair(5, 6), + allele_pair(6, 6), + allele_pair(0, 7), + allele_pair(1, 7), + allele_pair(2, 7), + allele_pair(3, 7), + allele_pair(4, 7), + allele_pair(5, 7), + allele_pair(6, 7), + allele_pair(7, 7), ] @@ -1782,8 +1820,7 @@ def __init__(self): def _typecheck_one_level(self, annotation): if annotation is not None and not isinstance(annotation, genetics.Call): - raise TypeError("type 'call' expected Python hail.genetics.Call, but found %s'" % - type(annotation)) + raise TypeError("type 'call' expected Python hail.genetics.Call, but found %s'" % type(annotation)) def __str__(self): return "call" @@ -1813,7 +1850,7 @@ def _convert_from_json(self, x, _should_freeze: bool = False) -> genetics.Call: if i == n: return genetics.Call([int(x)]) - return genetics.Call([int(x[0:i]), int(x[i + 1:])], phased=(c == '|')) + return genetics.Call([int(x[0:i]), int(x[i + 1 :])], phased=(c == '|')) def _convert_to_json(self, x): return str(x) @@ -1828,10 +1865,10 @@ def allele_repr(c): return c >> 3 def ap_j(p): - return p & 0xffff + return p & 0xFFFF def ap_k(p): - return (p >> 16) & 0xffff + return (p >> 16) & 0xFFFF def gt_allele_pair(i): if i < len(small_allele_pair): @@ -1931,11 +1968,15 @@ def __init__(self, reference_genome='default'): def _typecheck_one_level(self, annotation): if annotation is not None: if not isinstance(annotation, genetics.Locus): - raise TypeError("type '{}' expected Python hail.genetics.Locus, but found '{}'" - .format(self, type(annotation))) + raise TypeError( + "type '{}' expected Python hail.genetics.Locus, but found '{}'".format(self, type(annotation)) + ) if not self.reference_genome == annotation.reference_genome: - raise TypeError("type '{}' encountered Locus with reference genome {}" - .format(self, repr(annotation.reference_genome))) + raise TypeError( + "type '{}' encountered Locus with reference genome {}".format( + self, repr(annotation.reference_genome) + ) + ) def __str__(self): return "locus<{}>".format(escape_parsable(str(self.reference_genome))) @@ -2028,13 +2069,16 @@ def _traverse(self, obj, f): def _typecheck_one_level(self, annotation): from hail.utils import Interval + if annotation is not None: if not isinstance(annotation, Interval): - raise TypeError("type '{}' expected Python hail.utils.Interval, but found {}" - .format(self, type(annotation))) + raise TypeError( + "type '{}' expected Python hail.utils.Interval, but found {}".format(self, type(annotation)) + ) if annotation.point_type != self.point_type: - raise TypeError("type '{}' encountered Interval with point type {}" - .format(self, repr(annotation.point_type))) + raise TypeError( + "type '{}' encountered Interval with point type {}".format(self, repr(annotation.point_type)) + ) def __str__(self): return "interval<{}>".format(str(self.point_type)) @@ -2052,25 +2096,32 @@ def _parsable_string(self): def _convert_from_json(self, x, _should_freeze: bool = False): from hail.utils import Interval - return Interval(self.point_type._convert_from_json_na(x['start'], _should_freeze), - self.point_type._convert_from_json_na(x['end'], _should_freeze), - x['includeStart'], - x['includeEnd'], - point_type=self.point_type) + + return Interval( + self.point_type._convert_from_json_na(x['start'], _should_freeze), + 
self.point_type._convert_from_json_na(x['end'], _should_freeze), + x['includeStart'], + x['includeEnd'], + point_type=self.point_type, + ) def _convert_to_json(self, x): - return {'start': self.point_type._convert_to_json_na(x.start), - 'end': self.point_type._convert_to_json_na(x.end), - 'includeStart': x.includes_start, - 'includeEnd': x.includes_end} + return { + 'start': self.point_type._convert_to_json_na(x.start), + 'end': self.point_type._convert_to_json_na(x.end), + 'includeStart': x.includes_start, + 'includeEnd': x.includes_end, + } def _convert_from_encoding(self, byte_reader, _should_freeze: bool = False): interval_as_struct = self._struct_repr._convert_from_encoding(byte_reader, _should_freeze) - return hl.Interval(interval_as_struct.start, - interval_as_struct.end, - interval_as_struct.includes_start, - interval_as_struct.includes_end, - point_type=self.point_type) + return hl.Interval( + interval_as_struct.start, + interval_as_struct.end, + interval_as_struct.includes_start, + interval_as_struct.includes_end, + point_type=self.point_type, + ) def _convert_to_encoding(self, byte_writer, value): interval_dict = { @@ -2095,7 +2146,7 @@ def _get_context(self): class Box(object): - named_boxes = {} + named_boxes: ClassVar = {} @staticmethod def from_name(name): @@ -2230,23 +2281,16 @@ def is_primitive(t) -> bool: @typecheck(t=HailType) def is_container(t) -> bool: - return (isinstance(t, tarray) - or isinstance(t, tset) - or isinstance(t, tdict)) + return isinstance(t, (tarray, tset, tdict)) @typecheck(t=HailType) def is_compound(t) -> bool: - return (is_container(t) - or isinstance(t, tstruct) - or isinstance(t, tunion) - or isinstance(t, ttuple) - or isinstance(t, tndarray)) + return is_container(t) or isinstance(t, (tstruct, tunion, ttuple, tndarray)) def types_match(left, right) -> bool: - return (len(left) == len(right) - and all(map(lambda lr: lr[0].dtype == lr[1].dtype, zip(left, right)))) + return len(left) == len(right) and all(map(lambda lr: lr[0].dtype == lr[1].dtype, zip(left, right))) def is_int32(x): @@ -2302,7 +2346,7 @@ def dtypes_from_pandas(pd_dtype): class tvariable(HailType): - _cond_map = { + _cond_map: ClassVar = { 'numeric': is_numeric, 'int32': lambda x: x == tint32, 'int64': lambda x: x == tint64, @@ -2311,7 +2355,7 @@ class tvariable(HailType): 'locus': lambda x: isinstance(x, tlocus), 'struct': lambda x: isinstance(x, tstruct), 'union': lambda x: isinstance(x, tunion), - 'tuple': lambda x: isinstance(x, ttuple) + 'tuple': lambda x: isinstance(x, ttuple), } def __init__(self, name, cond): diff --git a/hail/python/hail/fs/hadoop_fs.py b/hail/python/hail/fs/hadoop_fs.py index 9927a6b7abb..b90e8662d07 100644 --- a/hail/python/hail/fs/hadoop_fs.py +++ b/hail/python/hail/fs/hadoop_fs.py @@ -1,12 +1,20 @@ import io import json import time -from typing import Dict, List, Union, Any +from typing import Any, Dict, List, Union import dateutil.parser from hailtop.fs.fs import FS -from hailtop.fs.stat_result import FileType, FileListEntry +from hailtop.fs.stat_result import FileListEntry, FileStatus, FileType + + +def _file_status_scala_to_python(file_status: Dict[str, Any]) -> FileStatus: + dt = dateutil.parser.isoparse(file_status['modification_time']) + mtime = time.mktime(dt.timetuple()) + return FileStatus( + path=file_status['path'], owner=file_status['owner'], size=file_status['size'], modification_time=mtime + ) def _file_list_entry_scala_to_python(file_list_entry: Dict[str, Any]) -> FileListEntry: @@ -18,11 +26,13 @@ def 
_file_list_entry_scala_to_python(file_list_entry: Dict[str, Any]) -> FileLis typ = FileType.SYMLINK else: typ = FileType.FILE - return FileListEntry(path=file_list_entry['path'], - owner=file_list_entry['owner'], - size=file_list_entry['size'], - typ=typ, - modification_time=mtime) + return FileListEntry( + path=file_list_entry['path'], + owner=file_list_entry['owner'], + size=file_list_entry['size'], + typ=typ, + modification_time=mtime, + ) class HadoopFS(FS): @@ -40,11 +50,15 @@ def legacy_open(self, path: str, mode: str = 'r', buffer_size: int = 8192): def _open(self, path: str, mode: str = 'r', buffer_size: int = 8192, use_codec: bool = False): handle: Union[io.BufferedReader, io.BufferedWriter] if 'r' in mode: - handle = io.BufferedReader(HadoopReader(self, path, buffer_size, use_codec=use_codec), buffer_size=buffer_size) + handle = io.BufferedReader( + HadoopReader(self, path, buffer_size, use_codec=use_codec), buffer_size=buffer_size + ) elif 'w' in mode: handle = io.BufferedWriter(HadoopWriter(self, path, use_codec=use_codec), buffer_size=buffer_size) elif 'x' in mode: - handle = io.BufferedWriter(HadoopWriter(self, path, exclusive=True, use_codec=use_codec), buffer_size=buffer_size) + handle = io.BufferedWriter( + HadoopWriter(self, path, exclusive=True, use_codec=use_codec), buffer_size=buffer_size + ) if 'b' in mode: return handle @@ -63,13 +77,26 @@ def is_file(self, path: str) -> bool: def is_dir(self, path: str) -> bool: return self._jfs.isDir(path) + def fast_stat(self, path: str) -> FileStatus: + """Get information about a path other than its file/directory status. + + In the cloud, determining if a given path is a file, a directory, or both is expensive. This + method simply returns file metadata if there is a file at this path. If there is no file at + this path, this operation will fail. The presence or absence of a directory at this path + does not affect the behaviors of this method. 
+ + """ + file_status_dict = json.loads(self._utils_package_object.fileStatus(self._jfs, path)) + return _file_status_scala_to_python(file_status_dict) + def stat(self, path: str) -> FileListEntry: - stat_dict = json.loads(self._utils_package_object.fileListEntry(self._jfs, path)) - return _file_list_entry_scala_to_python(stat_dict) + file_list_entry_dict = json.loads(self._utils_package_object.fileListEntry(self._jfs, path)) + return _file_list_entry_scala_to_python(file_list_entry_dict) def ls(self, path: str) -> List[FileListEntry]: - return [_file_list_entry_scala_to_python(st) - for st in json.loads(self._utils_package_object.ls(self._jfs, path))] + return [ + _file_list_entry_scala_to_python(st) for st in json.loads(self._utils_package_object.ls(self._jfs, path)) + ] def mkdir(self, path: str) -> None: return self._jfs.mkDir(path) diff --git a/hail/python/hail/genetics/__init__.py b/hail/python/hail/genetics/__init__.py index 1ca527cf16a..116c31f1a80 100644 --- a/hail/python/hail/genetics/__init__.py +++ b/hail/python/hail/genetics/__init__.py @@ -1,10 +1,7 @@ +from .allele_type import AlleleType from .call import Call -from .reference_genome import ReferenceGenome -from .pedigree import Pedigree, Trio from .locus import Locus +from .pedigree import Pedigree, Trio +from .reference_genome import ReferenceGenome -__all__ = ['Locus', - 'Call', - 'Pedigree', - 'Trio', - 'ReferenceGenome'] +__all__ = ['AlleleType', 'Locus', 'Call', 'Pedigree', 'Trio', 'ReferenceGenome'] diff --git a/hail/python/hail/genetics/allele_type.py b/hail/python/hail/genetics/allele_type.py new file mode 100644 index 00000000000..e777c6658a8 --- /dev/null +++ b/hail/python/hail/genetics/allele_type.py @@ -0,0 +1,95 @@ +from enum import IntEnum, auto + +_ALLELE_STRS = ( + "Unknown", + "SNP", + "MNP", + "Insertion", + "Deletion", + "Complex", + "Star", + "Symbolic", + "Transition", + "Transversion", +) + + +class AlleleType(IntEnum): + """An enumeration for allele type. + + Notes + ----- + The precise values of the enumeration constants are not guarenteed + to be stable and must not be relied upon. + """ + + UNKNOWN = 0 + """Unknown Allele Type""" + SNP = auto() + """Single-nucleotide Polymorphism (SNP)""" + MNP = auto() + """Multi-nucleotide Polymorphism (MNP)""" + INSERTION = auto() + """Insertion""" + DELETION = auto() + """Deletion""" + COMPLEX = auto() + """Complex Polymorphism""" + STAR = auto() + """Star Allele (``alt=*``)""" + SYMBOLIC = auto() + """Symbolic Allele + + e.g. ``alt=`` + """ + TRANSITION = auto() + """Transition SNP + + e.g. ``ref=A alt=G`` + + Note + ---- + This is only really used internally in :func:`hail.vds.sample_qc` and + :func:`hail.methods.sample_qc`. + """ + TRANSVERSION = auto() + """Transversion SNP + + e.g. ``ref=A alt=C`` + + Note + ---- + This is only really used internally in :func:`hail.vds.sample_qc` and + :func:`hail.methods.sample_qc`. 
+ """ + + def __str__(self): + return str(self.value) + + @property + def pretty_name(self): + """A formatted (as opposed to uppercase) version of the member's name, + to match :func:`~hail.expr.functions.allele_type` + + Examples + -------- + >>> AlleleType.INSERTION.pretty_name + 'Insertion' + >>> at = AlleleType(hl.eval(hl.numeric_allele_type('a', 'att'))) + >>> at.pretty_name == hl.eval(hl.allele_type('a', 'att')) + True + """ + return _ALLELE_STRS[self] + + @classmethod + def _missing_(cls, value): + if not isinstance(value, str): + return None + return cls.__members__.get(value.upper()) + + @staticmethod + def strings(): + """Returns the names of the allele types, for use with + :func:`~hail.expr.functions.literal` + """ + return list(_ALLELE_STRS) diff --git a/hail/python/hail/genetics/call.py b/hail/python/hail/genetics/call.py index 79eb5bd1289..caf61049328 100644 --- a/hail/python/hail/genetics/call.py +++ b/hail/python/hail/genetics/call.py @@ -1,7 +1,6 @@ from collections.abc import Sequence from hail.typecheck import typecheck_method -from hail.utils import FatalError class Call(object): @@ -70,9 +69,11 @@ def __repr__(self): return 'Call(alleles=%s, phased=%s)' % (self._alleles, self._phased) def __eq__(self, other): - return ( self._phased == other._phased and - self._alleles == other._alleles - ) if isinstance(other, Call) else NotImplemented + return ( + (self._phased == other._phased and self._alleles == other._alleles) + if isinstance(other, Call) + else NotImplemented + ) def __hash__(self): return hash(self._phased) ^ hash(tuple(self._alleles)) @@ -261,10 +262,12 @@ def unphased_diploid_gt_index(self): ------- :obj:`int` """ + from hail.utils import FatalError if self.ploidy != 2 or self.phased: raise FatalError( - "'unphased_diploid_gt_index' is only valid for unphased, diploid calls. Found {}.".format(repr(self))) + "'unphased_diploid_gt_index' is only valid for unphased, diploid calls. Found {}.".format(repr(self)) + ) a0 = self._alleles[0] a1 = self._alleles[1] assert a0 <= a1 diff --git a/hail/python/hail/genetics/locus.py b/hail/python/hail/genetics/locus.py index 97029f60a6a..4fe4833f8b0 100644 --- a/hail/python/hail/genetics/locus.py +++ b/hail/python/hail/genetics/locus.py @@ -1,7 +1,7 @@ from typing import Union import hail as hl -from hail.genetics.reference_genome import reference_genome_type, ReferenceGenome +from hail.genetics.reference_genome import ReferenceGenome, reference_genome_type from hail.typecheck import typecheck_method @@ -51,17 +51,17 @@ def __repr__(self): return 'Locus(contig=%s, position=%s, reference_genome=%s)' % (self.contig, self.position, self._rg) def __eq__(self, other): - return ( self._contig == other._contig and - self._position == other._position and - self._rg == other._rg - ) if isinstance(other, Locus) else NotImplemented + return ( + (self._contig == other._contig and self._position == other._position and self._rg == other._rg) + if isinstance(other, Locus) + else NotImplemented + ) def __hash__(self): return hash(self._contig) ^ hash(self._position) ^ hash(self._rg) @classmethod - @typecheck_method(string=str, - reference_genome=reference_genome_type) + @typecheck_method(string=str, reference_genome=reference_genome_type) def parse(cls, string, reference_genome='default'): """Parses a locus object from a CHR:POS string. 
diff --git a/hail/python/hail/genetics/pedigree.py b/hail/python/hail/genetics/pedigree.py index 6563e23d4f0..f9e8f64593e 100644 --- a/hail/python/hail/genetics/pedigree.py +++ b/hail/python/hail/genetics/pedigree.py @@ -1,7 +1,7 @@ import re from collections import Counter -from hail.typecheck import typecheck_method, nullable, sequenceof +from hail.typecheck import nullable, sequenceof, typecheck_method from hail.utils.java import Env, FatalError, warning @@ -23,13 +23,8 @@ class Trio(object): :type is_female: bool or None """ - @typecheck_method(s=str, - fam_id=nullable(str), - pat_id=nullable(str), - mat_id=nullable(str), - is_female=nullable(bool)) + @typecheck_method(s=str, fam_id=nullable(str), pat_id=nullable(str), mat_id=nullable(str), is_female=nullable(bool)) def __init__(self, s, fam_id=None, pat_id=None, mat_id=None, is_female=None): - self._fam_id = fam_id self._s = s self._pat_id = pat_id @@ -38,21 +33,31 @@ def __init__(self, s, fam_id=None, pat_id=None, mat_id=None, is_female=None): def __repr__(self): return 'Trio(s=%s, fam_id=%s, pat_id=%s, mat_id=%s, is_female=%s)' % ( - repr(self.s), repr(self.fam_id), repr(self.pat_id), - repr(self.mat_id), repr(self.is_female)) + repr(self.s), + repr(self.fam_id), + repr(self.pat_id), + repr(self.mat_id), + repr(self.is_female), + ) def __str__(self): return 'Trio(s=%s, fam_id=%s, pat_id=%s, mat_id=%s, is_female=%s)' % ( - str(self.s), str(self.fam_id), str(self.pat_id), - str(self.mat_id), str(self.is_female)) + str(self.s), + str(self.fam_id), + str(self.pat_id), + str(self.mat_id), + str(self.is_female), + ) def __eq__(self, other): - return (isinstance(other, Trio) - and self._s == other._s - and self._mat_id == other._mat_id - and self._pat_id == other._pat_id - and self._fam_id == other._fam_id - and self._is_female == other._is_female) + return ( + isinstance(other, Trio) + and self._s == other._s + and self._mat_id == other._mat_id + and self._pat_id == other._pat_id + and self._fam_id == other._fam_id + and self._is_female == other._is_female + ) def __hash__(self): return hash((self._s, self._pat_id, self._mat_id, self._fam_id, self._is_female)) @@ -135,11 +140,13 @@ def _restrict_to(self, ids): if self._s not in ids: return None - return Trio(self._s, - self._fam_id, - self._pat_id if self._pat_id in ids else None, - self._mat_id if self._mat_id in ids else None, - self._is_female) + return Trio( + self._s, + self._fam_id, + self._pat_id if self._pat_id in ids else None, + self._mat_id if self._mat_id in ids else None, + self._is_female, + ) def _sex_as_numeric_string(self): if self._is_female is None: @@ -151,12 +158,15 @@ def sample_id_or_else_zero(sample_id): if sample_id is None: return "0" return sample_id - line_list = [sample_id_or_else_zero(self._fam_id), - self._s, - sample_id_or_else_zero(self._pat_id), - sample_id_or_else_zero(self._mat_id), - self._sex_as_numeric_string(), - "0"] + + line_list = [ + sample_id_or_else_zero(self._fam_id), + self._s, + sample_id_or_else_zero(self._pat_id), + sample_id_or_else_zero(self._mat_id), + self._sex_as_numeric_string(), + "0", + ] return "\t".join(line_list) @@ -181,8 +191,7 @@ def __iter__(self): return self._trios.__iter__() @classmethod - @typecheck_method(fam_path=str, - delimiter=str) + @typecheck_method(fam_path=str, delimiter=str) def read(cls, fam_path, delimiter='\\s+') -> 'Pedigree': """Read a PLINK .fam file and return a pedigree object. 
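Similarly, a minimal sketch of the Trio and Pedigree.read API touched in this file (not part of the diff; the .fam path is hypothetical and an initialized Hail context is assumed for file access):

from hail.genetics import Pedigree, Trio

# Trio fields mirror one line of a PLINK .fam file: sample, family, father, mother, sex.
trio = Trio(s='kid1', fam_id='fam1', pat_id='dad1', mat_id='mom1', is_female=True)

# Pedigree.read parses a whitespace-delimited .fam file; '0' parent/family IDs become
# None, and sex codes other than '1'/'2' trigger the missing-sex warning in the hunk below.
ped = Pedigree.read('data/trios.fam')  # hypothetical path
assert all(isinstance(t, Trio) for t in ped)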
@@ -211,20 +220,24 @@ def read(cls, fam_path, delimiter='\\s+') -> 'Pedigree': split_line = re.split(delimiter, line.strip()) num_fields = len(split_line) if num_fields != 6: - raise FatalError("Require 6 fields per line in .fam, but this line has {}: {}".format(num_fields, line)) + raise FatalError( + "Require 6 fields per line in .fam, but this line has {}: {}".format(num_fields, line) + ) (fam, kid, dad, mom, sex, _) = tuple(split_line) # 1 is male, 2 is female, 0 is unknown. - is_female = sex == "2" if sex == "1" or sex == "2" else None + is_female = sex == "2" if sex in {'1', '2'} else None if is_female is None: missing_sex_count += 1 missing_sex_values.add(kid) - trio = Trio(kid, - fam if fam != "0" else None, - dad if dad != "0" else None, - mom if mom != "0" else None, - is_female) + trio = Trio( + kid, + fam if fam != "0" else None, + dad if dad != "0" else None, + mom if mom != "0" else None, + is_female, + ) trios.append(trio) only_ids = [trio.s for trio in trios] @@ -233,7 +246,11 @@ def read(cls, fam_path, delimiter='\\s+') -> 'Pedigree': raise FatalError("Invalid pedigree: found duplicate proband IDs\n{}".format(duplicate_ids)) if missing_sex_count > 0: - warning("Found {} samples with missing sex information (not 1 or 2).\n Missing samples: [{}]".format(missing_sex_count, missing_sex_values)) + warning( + "Found {} samples with missing sex information (not 1 or 2).\n Missing samples: [{}]".format( + missing_sex_count, missing_sex_values + ) + ) return Pedigree(trios) diff --git a/hail/python/hail/genetics/reference_genome.py b/hail/python/hail/genetics/reference_genome.py index e60853799f1..172a44a2e59 100644 --- a/hail/python/hail/genetics/reference_genome.py +++ b/hail/python/hail/genetics/reference_genome.py @@ -1,11 +1,11 @@ -from bisect import bisect_right import json import re -from hail.typecheck import typecheck_method, sequenceof, dictof, oneof, \ - sized_tupleof, nullable, transformed, lazy -from hail.utils.misc import wrap_to_list -from hail.utils.java import Env +from bisect import bisect_right + import hail as hl +from hail.typecheck import dictof, lazy, nullable, oneof, sequenceof, sized_tupleof, transformed, typecheck_method +from hail.utils.java import Env +from hail.utils.misc import wrap_to_list rg_type = lazy() reference_genome_type = oneof(transformed((str, lambda x: hl.get_reference(x))), rg_type) @@ -83,26 +83,30 @@ def _from_config(cls, config, _builtin=False): def par_tuple(p): assert p['start']['contig'] == p['end']['contig'] return (p['start']['contig'], p['start']['position'], p['end']['position']) + contigs = config['contigs'] - return ReferenceGenome(config['name'], - [c['name'] for c in contigs], - {c['name']: c['length'] for c in contigs}, - config['xContigs'], - config['yContigs'], - config['mtContigs'], - [par_tuple(p) for p in config['par']], - _builtin) - - @typecheck_method(name=str, - contigs=sequenceof(str), - lengths=dictof(str, int), - x_contigs=oneof(str, sequenceof(str)), - y_contigs=oneof(str, sequenceof(str)), - mt_contigs=oneof(str, sequenceof(str)), - par=sequenceof(sized_tupleof(str, int, int)), - _builtin=bool) + return ReferenceGenome( + config['name'], + [c['name'] for c in contigs], + {c['name']: c['length'] for c in contigs}, + config['xContigs'], + config['yContigs'], + config['mtContigs'], + [par_tuple(p) for p in config['par']], + _builtin, + ) + + @typecheck_method( + name=str, + contigs=sequenceof(str), + lengths=dictof(str, int), + x_contigs=oneof(str, sequenceof(str)), + y_contigs=oneof(str, sequenceof(str)), + 
mt_contigs=oneof(str, sequenceof(str)), + par=sequenceof(sized_tupleof(str, int, int)), + _builtin=bool, + ) def __init__(self, name, contigs, lengths, x_contigs=[], y_contigs=[], mt_contigs=[], par=[], _builtin=False): - contigs = wrap_to_list(contigs) x_contigs = wrap_to_list(x_contigs) y_contigs = wrap_to_list(y_contigs) @@ -114,7 +118,7 @@ def __init__(self, name, contigs, lengths, x_contigs=[], y_contigs=[], mt_contig 'xContigs': x_contigs, 'yContigs': y_contigs, 'mtContigs': mt_contigs, - 'par': [{'start': {'contig': c, 'position': s}, 'end': {'contig': c, 'position': e}} for (c, s, e) in par] + 'par': [{'start': {'contig': c, 'position': s}, 'end': {'contig': c, 'position': e}} for (c, s, e) in par], } self._contigs = contigs @@ -134,8 +138,15 @@ def __str__(self): return self._config['name'] def __repr__(self): - return 'ReferenceGenome(name=%s, contigs=%s, lengths=%s, x_contigs=%s, y_contigs=%s, mt_contigs=%s, par=%s)' % \ - (self.name, self.contigs, self.lengths, self.x_contigs, self.y_contigs, self.mt_contigs, self._par_tuple) + return 'ReferenceGenome(name=%s, contigs=%s, lengths=%s, x_contigs=%s, y_contigs=%s, mt_contigs=%s, par=%s)' % ( + self.name, + self.contigs, + self.lengths, + self.x_contigs, + self.y_contigs, + self.mt_contigs, + self._par_tuple, + ) def __eq__(self, other): return isinstance(other, ReferenceGenome) and self._config == other._config @@ -304,7 +315,7 @@ def read(cls, path): @typecheck_method(output=str) def write(self, output): - """"Write this reference genome to a file in JSON format. + """ "Write this reference genome to a file in JSON format. Examples -------- @@ -326,8 +337,7 @@ def write(self, output): with hl.utils.hadoop_open(output, 'w') as f: json.dump(self._config, f) - @typecheck_method(fasta_file=str, - index_file=nullable(str)) + @typecheck_method(fasta_file=str, index_file=nullable(str)) def add_sequence(self, fasta_file, index_file=None): """Load the reference sequence from a FASTA file. @@ -396,15 +406,16 @@ def remove_sequence(self): Env.backend().remove_sequence(self.name) @classmethod - @typecheck_method(name=str, - fasta_file=str, - index_file=str, - x_contigs=oneof(str, sequenceof(str)), - y_contigs=oneof(str, sequenceof(str)), - mt_contigs=oneof(str, sequenceof(str)), - par=sequenceof(sized_tupleof(str, int, int))) - def from_fasta_file(cls, name, fasta_file, index_file, - x_contigs=[], y_contigs=[], mt_contigs=[], par=[]): + @typecheck_method( + name=str, + fasta_file=str, + index_file=str, + x_contigs=oneof(str, sequenceof(str)), + y_contigs=oneof(str, sequenceof(str)), + mt_contigs=oneof(str, sequenceof(str)), + par=sequenceof(sized_tupleof(str, int, int)), + ) + def from_fasta_file(cls, name, fasta_file, index_file, x_contigs=[], y_contigs=[], mt_contigs=[], par=[]): """Create reference genome from a FASTA file. 
Parameters @@ -429,7 +440,9 @@ def from_fasta_file(cls, name, fasta_file, index_file, :class:`.ReferenceGenome` """ par_strings = ["{}:{}-{}".format(contig, start, end) for (contig, start, end) in par] - config = Env.backend().from_fasta_file(name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par_strings) + config = Env.backend().from_fasta_file( + name, fasta_file, index_file, x_contigs, y_contigs, mt_contigs, par_strings + ) rg = ReferenceGenome._from_config(config) rg.add_sequence(fasta_file, index_file) @@ -462,8 +475,7 @@ def remove_liftover(self, dest_reference_genome): del self._liftovers[dest_reference_genome.name] Env.backend().remove_liftover(self.name, dest_reference_genome.name) - @typecheck_method(chain_file=str, - dest_reference_genome=reference_genome_type) + @typecheck_method(chain_file=str, dest_reference_genome=reference_genome_type) def add_liftover(self, chain_file, dest_reference_genome): """Register a chain file for liftover. @@ -515,7 +527,7 @@ def add_liftover(self, chain_file, dest_reference_genome): @typecheck_method(global_pos=int) def locus_from_global_position(self, global_pos: int) -> 'hl.Locus': - """" + """ " Constructs a locus from a global position in reference genome. The inverse of :meth:`.Locus.position`. @@ -553,9 +565,7 @@ def locus_from_global_position(self, global_pos: int) -> 'hl.Locus': contig_pos = self.global_positions_dict[contig] if global_pos >= contig_pos + self.lengths[contig]: - raise ValueError( - f"global_pos {global_pos} exceeds length of reference genome {self}." - ) + raise ValueError(f"global_pos {global_pos} exceeds length of reference genome {self}.") return hl.Locus(contig, global_pos - contig_pos + 1, self) diff --git a/hail/python/hail/ggplot/__init__.py b/hail/python/hail/ggplot/__init__.py index 5ff75af7041..6ca37ca9635 100644 --- a/hail/python/hail/ggplot/__init__.py +++ b/hail/python/hail/ggplot/__init__.py @@ -2,20 +2,53 @@ if is_notebook(): from plotly.io import renderers - renderers.default='iframe' + renderers.default = 'iframe' + +from .aes import Aesthetic, aes # noqa F401 from .coord_cartesian import coord_cartesian -from .ggplot import ggplot, GGPlot # noqa F401 -from .aes import aes, Aesthetic # noqa F401 -from .geoms import FigureAttribute, geom_line, geom_point, geom_text, geom_bar,\ - geom_histogram, geom_density, geom_func, geom_hline, geom_vline, geom_tile,\ - geom_col, geom_area, geom_ribbon # noqa F401 -from .labels import ggtitle, xlab, ylab, labs -from .scale import scale_x_continuous, scale_y_continuous, scale_x_discrete, scale_y_discrete, scale_x_genomic, \ - scale_x_log10, scale_y_log10, scale_x_reverse, scale_y_reverse, scale_color_discrete, scale_color_hue, scale_color_identity,\ - scale_color_manual, scale_color_continuous, scale_fill_discrete, scale_fill_hue, scale_fill_identity, scale_fill_continuous,\ - scale_fill_manual, scale_shape_manual, scale_shape_auto -from .facets import vars, facet_wrap +from .facets import facet_wrap, vars +from .geoms import ( + FigureAttribute, # noqa F401 + geom_area, + geom_bar, + geom_col, + geom_density, + geom_func, + geom_histogram, + geom_hline, + geom_line, + geom_point, + geom_ribbon, + geom_text, + geom_tile, + geom_vline, +) +from .ggplot import GGPlot, ggplot # noqa F401 +from .labels import ggtitle, labs, xlab, ylab +from .scale import ( + scale_color_continuous, + scale_color_discrete, + scale_color_hue, + scale_color_identity, + scale_color_manual, + scale_fill_continuous, + scale_fill_discrete, + scale_fill_hue, + scale_fill_identity, + 
scale_fill_manual, + scale_shape_auto, + scale_shape_manual, + scale_x_continuous, + scale_x_discrete, + scale_x_genomic, + scale_x_log10, + scale_x_reverse, + scale_y_continuous, + scale_y_discrete, + scale_y_log10, + scale_y_reverse, +) __all__ = [ "aes", @@ -60,5 +93,5 @@ "scale_shape_manual", "scale_shape_auto", "facet_wrap", - "vars" + "vars", ] diff --git a/hail/python/hail/ggplot/aes.py b/hail/python/hail/ggplot/aes.py index 5497f28d4d2..a78c9cfd087 100644 --- a/hail/python/hail/ggplot/aes.py +++ b/hail/python/hail/ggplot/aes.py @@ -1,10 +1,9 @@ from collections.abc import Mapping -from hail.expr import Expression -from hail import literal +from hail.expr import Expression, literal -class Aesthetic(Mapping): +class Aesthetic(Mapping): def __init__(self, properties): self.properties = properties @@ -44,7 +43,8 @@ def aes(**kwargs): hail_field_properties = {} for k, v in kwargs.items(): + _v = v if not isinstance(v, Expression): - v = literal(v) - hail_field_properties[k] = v + _v = literal(v) + hail_field_properties[k] = _v return Aesthetic(hail_field_properties) diff --git a/hail/python/hail/ggplot/facets.py b/hail/python/hail/ggplot/facets.py index 32cf43d61ed..9ce9274515d 100644 --- a/hail/python/hail/ggplot/facets.py +++ b/hail/python/hail/ggplot/facets.py @@ -1,14 +1,13 @@ import abc import math +from typing import ClassVar, Dict, Optional, Tuple -from typing import Dict, Tuple +import hail as hl +from hail.expr import Expression, StructExpression from .geoms import FigureAttribute from .utils import n_partitions -import hail as hl -from hail import Expression, StructExpression - def vars(*args: Expression) -> StructExpression: """ @@ -27,7 +26,9 @@ def vars(*args: Expression) -> StructExpression: return hl.struct(**{f"var_{i}": arg for i, arg in enumerate(args)}) -def facet_wrap(facets: StructExpression, *, nrow: int = None, ncol: int = None, scales: str = "fixed") -> "FacetWrap": +def facet_wrap( + facets: StructExpression, *, nrow: Optional[int] = None, ncol: Optional[int] = None, scales: str = "fixed" +) -> "FacetWrap": """Introduce a one dimensional faceting on specified fields. Parameters @@ -51,19 +52,18 @@ def facet_wrap(facets: StructExpression, *, nrow: int = None, ncol: int = None, class Faceter(FigureAttribute): - @abc.abstractmethod def get_expr_to_group_by(self) -> StructExpression: pass class FacetWrap(Faceter): - _base_scale_mappings = { + _base_scale_mappings: ClassVar = { "shared_xaxes": "all", "shared_yaxes": "all", } - _scale_mappings = { + _scale_mappings: ClassVar = { "fixed": _base_scale_mappings, "free_x": { **_base_scale_mappings, @@ -76,15 +76,14 @@ class FacetWrap(Faceter): "free": { "shared_xaxes": False, "shared_yaxes": False, - } + }, } - def __init__(self, facets: StructExpression, nrow: int = None, ncol: int = None, scales: str = "fixed"): + def __init__( + self, facets: StructExpression, nrow: Optional[int] = None, ncol: Optional[int] = None, scales: str = "fixed" + ): if nrow is not None and ncol is not None: - raise ValueError( - "Both `nrow` and `ncol` were specified. " - "Please specify only one of these values." - ) + raise ValueError("Both `nrow` and `ncol` were specified. " "Please specify only one of these values.") if scales not in self._scale_mappings: raise ValueError( f"An unsupported value ({scales}) was provided for `scales`. 
" diff --git a/hail/python/hail/ggplot/geoms.py b/hail/python/hail/ggplot/geoms.py index 585bd39507c..65c5babdced 100644 --- a/hail/python/hail/ggplot/geoms.py +++ b/hail/python/hail/ggplot/geoms.py @@ -1,10 +1,11 @@ -from typing import Dict, Any, Optional import abc +from typing import Any, ClassVar, Dict, Optional + import numpy as np import plotly.graph_objects as go from .aes import aes -from .stats import StatCount, StatIdentity, StatBin, StatNone, StatFunction, StatCDF +from .stats import StatBin, StatCDF, StatCount, StatFunction, StatIdentity, StatNone from .utils import bar_position_plotly_to_gg, linetype_plotly_to_gg @@ -13,12 +14,13 @@ class FigureAttribute(abc.ABC): class Geom(FigureAttribute): - def __init__(self, aes): self.aes = aes @abc.abstractmethod - def apply_to_fig(self, agg_result, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool): + def apply_to_fig( + self, agg_result, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): """Add this geometry to the figure and indicate if this geometry demands a static figure.""" pass @@ -48,27 +50,22 @@ def _update_legend_trace_args(self, trace_args, legend_cache): class GeomLineBasic(Geom): - aes_to_arg = { + aes_to_arg: ClassVar = { "color": ("line_color", "black"), "size": ("marker_size", None), "tooltip": ("hovertext", None), - "color_legend": ("name", None) + "color_legend": ("name", None), } def __init__(self, aes, color): super().__init__(aes) self.color = color - def apply_to_fig(self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool): - + def apply_to_fig( + self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): def plot_group(df): - trace_args = { - "x": df.x, - "y": df.y, - "mode": "lines", - "row": facet_row, - "col": facet_col - } + trace_args = {"x": df.x, "y": df.y, "mode": "lines", "row": facet_row, "col": facet_col} self._add_aesthetics_to_trace_args(trace_args, df) self._update_legend_trace_args(trace_args, legend_cache) @@ -84,8 +81,7 @@ def get_stat(self): class GeomPoint(Geom): - - aes_to_plotly = { + aes_to_plotly: ClassVar = { "color": "marker_color", "size": "marker_size", "tooltip": "hovertext", @@ -93,12 +89,12 @@ class GeomPoint(Geom): "shape": "marker_symbol", } - aes_defaults = { + aes_defaults: ClassVar = { "color": "black", "shape": "circle", } - aes_legend_groups = { + aes_legend_groups: ClassVar = { "color", "shape", } @@ -134,41 +130,35 @@ def _get_aes_values(self, df): return values def _add_trace(self, fig_so_far: go.Figure, df, facet_row, facet_col, values, legend: Optional[str] = None): - fig_so_far.add_scatter( + fig_so_far.add_scatter(**{ **{ - **{ - "x": df.x, - "y": df.y, - "mode": "markers", - "row": facet_row, - "col": facet_col, - **( - {"showlegend": False} - if legend is None else - {"name": legend, "showlegend": True} - ) - }, - **self._map_to_plotly(values) - } - ) + "x": df.x, + "y": df.y, + "mode": "markers", + "row": facet_row, + "col": facet_col, + **({"showlegend": False} if legend is None else {"name": legend, "showlegend": True}), + }, + **self._map_to_plotly(values), + }) def _add_legend(self, fig_so_far: go.Figure, aes_name, category, value): - fig_so_far.add_scatter( + fig_so_far.add_scatter(**{ **{ - **{ - "x": [None], - "y": [None], - "mode": "markers", - "name": category, - "showlegend": True, - "legendgroup": aes_name, - "legendgrouptitle_text": aes_name, - }, - 
**self._map_to_plotly({**self.aes_defaults, aes_name: value}) - } - ) - - def apply_to_fig(self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool): + "x": [None], + "y": [None], + "mode": "markers", + "name": category, + "showlegend": True, + "legendgroup": aes_name, + "legendgrouptitle_text": aes_name, + }, + **self._map_to_plotly({**self.aes_defaults, aes_name: value}), + }) + + def apply_to_fig( + self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): traces = [] legends = {} for df in grouped_data: @@ -178,15 +168,13 @@ def apply_to_fig(self, grouped_data, fig_so_far: go.Figure, precomputed, facet_r category = self._get_aes_value(df, f"{aes_name}_legend") if category is not None: trace_categories.append(category) - legends[aes_name] = ({ - **legends.get(aes_name, {}), - category: values[aes_name] - }) + legends[aes_name] = {**legends.get(aes_name, {}), category: values[aes_name]} traces.append([fig_so_far, df, facet_row, facet_col, values, trace_categories]) non_empty_legend_groups = [ - legend_group for legend_group in legends.values() - if len(legend_group) > 1 or (len(legend_group) == 1 and list(legend_group.keys())[0] is not None) + legend_group + for legend_group in legends.values() + if len(legend_group) > 1 or (len(legend_group) == 1 and next(iter(legend_group.keys())) is not None) ] dummy_legend = is_faceted or len(non_empty_legend_groups) >= 2 @@ -204,7 +192,7 @@ def apply_to_fig(self, grouped_data, fig_so_far: go.Figure, precomputed, facet_r for trace in traces: trace_categories = trace[-1] if main_categories is not None: - trace[-1] = [category for category in trace_categories if category in main_categories][0] + trace[-1] = next(category for category in trace_categories if category in main_categories) elif len(trace_categories) == 1: trace[-1] = [trace_categories][0] else: @@ -230,12 +218,13 @@ def geom_point(mapping=aes(), *, color=None, size=None, alpha=None, shape=None): class GeomLine(GeomLineBasic): - def __init__(self, aes, color=None): super().__init__(aes, color) self.color = color - def apply_to_fig(self, agg_result, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool): + def apply_to_fig( + self, agg_result, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): return super().apply_to_fig(agg_result, fig_so_far, precomputed, facet_row, facet_col, legend_cache, is_faceted) def get_stat(self): @@ -256,12 +245,12 @@ def geom_line(mapping=aes(), *, color=None, size=None, alpha=None): class GeomText(Geom): - aes_to_arg = { + aes_to_arg: ClassVar = { "color": ("textfont_color", "black"), "size": ("marker_size", None), "tooltip": ("hovertext", None), "color_legend": ("name", None), - "alpha": ("marker_opacity", None) + "alpha": ("marker_opacity", None), } def __init__(self, aes, color=None, size=None, alpha=None): @@ -270,16 +259,11 @@ def __init__(self, aes, color=None, size=None, alpha=None): self.size = size self.alpha = alpha - def apply_to_fig(self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool): + def apply_to_fig( + self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): def plot_group(df): - trace_args = { - "x": df.x, - "y": df.y, - "text": df.label, - "mode": "text", - "row": facet_row, - "col": facet_col - } + trace_args = {"x": df.x, "y": df.y, "text": 
df.label, "mode": "text", "row": facet_row, "col": facet_col} self._add_aesthetics_to_trace_args(trace_args, df) self._update_legend_trace_args(trace_args, legend_cache) @@ -307,13 +291,12 @@ def geom_text(mapping=aes(), *, color=None, size=None, alpha=None): class GeomBar(Geom): - - aes_to_arg = { + aes_to_arg: ClassVar = { "fill": ("marker_color", "black"), "color": ("marker_line_color", None), "tooltip": ("hovertext", None), "fill_legend": ("name", None), - "alpha": ("marker_opacity", None) + "alpha": ("marker_opacity", None), } def __init__(self, aes, fill=None, color=None, alpha=None, position="stack", size=None, stat=None): @@ -328,14 +311,11 @@ def __init__(self, aes, fill=None, color=None, alpha=None, position="stack", siz stat = StatCount() self.stat = stat - def apply_to_fig(self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool): + def apply_to_fig( + self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): def plot_group(df): - trace_args = { - "x": df.x, - "y": df.y, - "row": facet_row, - "col": facet_col - } + trace_args = {"x": df.x, "y": df.y, "row": facet_row, "col": facet_col} self._add_aesthetics_to_trace_args(trace_args, df) self._update_legend_trace_args(trace_args, legend_cache) @@ -379,15 +359,17 @@ def geom_col(mapping=aes(), *, fill=None, color=None, alpha=None, position="stac class GeomHistogram(Geom): - aes_to_arg = { + aes_to_arg: ClassVar = { "fill": ("marker_color", "black"), "color": ("marker_line_color", None), "tooltip": ("hovertext", None), "fill_legend": ("name", None), - "alpha": ("marker_opacity", None) + "alpha": ("marker_opacity", None), } - def __init__(self, aes, min_val=None, max_val=None, bins=None, fill=None, color=None, alpha=None, position='stack', size=None): + def __init__( + self, aes, min_val=None, max_val=None, bins=None, fill=None, color=None, alpha=None, position='stack', size=None + ): super().__init__(aes) self.min_val = min_val self.max_val = max_val @@ -398,7 +380,9 @@ def __init__(self, aes, min_val=None, max_val=None, bins=None, fill=None, color= self.position = position self.size = size - def apply_to_fig(self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool): + def apply_to_fig( + self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): min_val = self.min_val if self.min_val is not None else precomputed.min_val max_val = self.max_val if self.max_val is not None else precomputed.max_val # This assumes it doesn't really make sense to use another stat for geom_histogram @@ -429,10 +413,9 @@ def plot_group(df, idx): "col": facet_col, "customdata": list(zip(left_xs, right_xs)), "width": bar_width, - "hovertemplate": - "Range: [%{customdata[0]:.3f}-%{customdata[1]:.3f})
    " - "Count: %{y}
    " - "", + "hovertemplate": "Range: [%{customdata[0]:.3f}-%{customdata[1]:.3f})
    " + "Count: %{y}
    " + "", } self._add_aesthetics_to_trace_args(trace_args, df) @@ -449,8 +432,18 @@ def get_stat(self): return StatBin(self.min_val, self.max_val, self.bins) -def geom_histogram(mapping=aes(), *, min_val=None, max_val=None, bins=None, fill=None, color=None, alpha=None, position='stack', - size=None): +def geom_histogram( + mapping=aes(), + *, + min_val=None, + max_val=None, + bins=None, + fill=None, + color=None, + alpha=None, + position='stack', + size=None, +): """Creates a histogram. Note: this function currently does not support same interface as R's ggplot. @@ -481,7 +474,18 @@ def geom_histogram(mapping=aes(), *, min_val=None, max_val=None, bins=None, fill :class:`FigureAttribute` The geom to be applied. """ - return GeomHistogram(mapping, min_val=min_val, max_val=max_val, bins=bins, fill=fill, color=color, alpha=alpha, position=position, size=size) + return GeomHistogram( + mapping, + min_val=min_val, + max_val=max_val, + bins=bins, + fill=fill, + color=color, + alpha=alpha, + position=position, + size=size, + ) + # Computes the maximum entropy distribution whose cdf is within +- e of the # staircase-shaped cdf encoded by min_x, max_x, x, y. @@ -517,7 +521,7 @@ def point_on_bound(i, upper): if i == len(x): return max_x, 1 else: - yi = y[i] + e if upper else y[i+1] - e + yi = y[i] + e if upper else y[i + 1] - e return x[i], yi # Result variables: @@ -595,12 +599,12 @@ def fix_point_on_result(i, upper): class GeomDensity(Geom): - aes_to_arg = { + aes_to_arg: ClassVar = { "fill": ("marker_color", "black"), "color": ("marker_line_color", None), "tooltip": ("hovertext", None), "fill_legend": ("name", None), - "alpha": ("marker_opacity", None) + "alpha": ("marker_opacity", None), } def __init__(self, aes, k=1000, smoothing=0.5, fill=None, color=None, alpha=None, smoothed=False): @@ -612,8 +616,11 @@ def __init__(self, aes, k=1000, smoothing=0.5, fill=None, color=None, alpha=None self.alpha = alpha self.smoothed = smoothed - def apply_to_fig(self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool): + def apply_to_fig( + self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): from hail.expr.functions import _error_from_cdf_python + def plot_group(df, idx): data = df.attrs['data'] @@ -628,7 +635,9 @@ def plot_group(df, idx): def f(x, prev): inv_scale = (np.sqrt(n * slope) / self.smoothing) * np.sqrt(prev / weights) diff = x[:, np.newaxis] - values - grid = (3 / (4 * n)) * weights * np.maximum(0, inv_scale - np.power(diff, 2) * np.power(inv_scale, 3)) + grid = ( + (3 / (4 * n)) * weights * np.maximum(0, inv_scale - np.power(diff, 2) * np.power(inv_scale, 3)) + ) return np.sum(grid, axis=1) round1 = f(values, np.full(len(values), slope)) @@ -641,7 +650,7 @@ def f(x, prev): "mode": "lines", "fill": "tozeroy", "row": facet_row, - "col": facet_col + "col": facet_col, } self._add_aesthetics_to_trace_args(trace_args, df) @@ -670,7 +679,7 @@ def f(x, prev): "row": facet_row, "col": facet_col, "width": widths, - "offset": 0 + "offset": 0, } self._add_aesthetics_to_trace_args(trace_args, df) @@ -726,18 +735,16 @@ def geom_density(mapping=aes(), *, k=1000, smoothing=0.5, fill=None, color=None, class GeomHLine(Geom): - def __init__(self, yintercept, linetype="solid", color=None): self.yintercept = yintercept self.aes = aes() self.linetype = linetype self.color = color - def apply_to_fig(self, agg_result, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: 
bool): - line_attributes = { - "y": self.yintercept, - "line_dash": linetype_plotly_to_gg(self.linetype) - } + def apply_to_fig( + self, agg_result, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): + line_attributes = {"y": self.yintercept, "line_dash": linetype_plotly_to_gg(self.linetype)} if self.color is not None: line_attributes["line_color"] = self.color @@ -769,18 +776,16 @@ def geom_hline(yintercept, *, linetype="solid", color=None): class GeomVLine(Geom): - def __init__(self, xintercept, linetype="solid", color=None): self.xintercept = xintercept self.aes = aes() self.linetype = linetype self.color = color - def apply_to_fig(self, agg_result, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool): - line_attributes = { - "x": self.xintercept, - "line_dash": linetype_plotly_to_gg(self.linetype) - } + def apply_to_fig( + self, agg_result, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): + line_attributes = {"x": self.xintercept, "line_dash": linetype_plotly_to_gg(self.linetype)} if self.color is not None: line_attributes["line_color"] = self.color @@ -812,13 +817,13 @@ def geom_vline(xintercept, *, linetype="solid", color=None): class GeomTile(Geom): - def __init__(self, aes): self.aes = aes - def apply_to_fig(self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool): + def apply_to_fig( + self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): def plot_group(df): - for idx, row in df.iterrows(): x_center = row['x'] y_center = row['y'] @@ -832,7 +837,7 @@ def plot_group(df): "y1": y_center + height / 2, "row": facet_row, "col": facet_col, - "opacity": row.get('alpha', 1.0) + "opacity": row.get('alpha', 1.0), } if "fill" in df.attrs: shape_args["fillcolor"] = df.attrs["fill"] @@ -858,7 +863,9 @@ def __init__(self, aes, fun, color): super().__init__(aes, color) self.fun = fun - def apply_to_fig(self, agg_result, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool): + def apply_to_fig( + self, agg_result, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): return super().apply_to_fig(agg_result, fig_so_far, precomputed, facet_row, facet_col, legend_cache, is_faceted) def get_stat(self): @@ -870,11 +877,11 @@ def geom_func(mapping=aes(), fun=None, color=None): class GeomArea(Geom): - aes_to_arg = { + aes_to_arg: ClassVar = { "fill": ("fillcolor", "black"), "color": ("line_color", "rgba(0, 0, 0, 0)"), "tooltip": ("hovertext", None), - "fill_legend": ("name", None) + "fill_legend": ("name", None), } def __init__(self, aes, fill, color): @@ -882,15 +889,11 @@ def __init__(self, aes, fill, color): self.fill = fill self.color = color - def apply_to_fig(self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool): + def apply_to_fig( + self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): def plot_group(df): - trace_args = { - "x": df.x, - "y": df.y, - "row": facet_row, - "col": facet_col, - "fill": 'tozeroy' - } + trace_args = {"x": df.x, "y": df.y, "row": facet_row, "col": facet_col, "fill": 'tozeroy'} self._add_aesthetics_to_trace_args(trace_args, df) self._update_legend_trace_args(trace_args, legend_cache) @@ -927,11 +930,11 @@ def geom_area(mapping=aes(), fill=None, 
color=None): class GeomRibbon(Geom): - aes_to_arg = { + aes_to_arg: ClassVar = { "fill": ("fillcolor", "black"), "color": ("line_color", "rgba(0, 0, 0, 0)"), "tooltip": ("hovertext", None), - "fill_legend": ("name", None) + "fill_legend": ("name", None), } def __init__(self, aes, fill, color): @@ -939,16 +942,17 @@ def __init__(self, aes, fill, color): self.fill = fill self.color = color - def apply_to_fig(self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool): + def apply_to_fig( + self, grouped_data, fig_so_far: go.Figure, precomputed, facet_row, facet_col, legend_cache, is_faceted: bool + ): def plot_group(df): - trace_args_bottom = { "x": df.x, "y": df.ymin, "row": facet_row, "col": facet_col, "mode": "lines", - "showlegend": False + "showlegend": False, } self._add_aesthetics_to_trace_args(trace_args_bottom, df) self._update_legend_trace_args(trace_args_bottom, legend_cache) @@ -959,7 +963,7 @@ def plot_group(df): "row": facet_row, "col": facet_col, "mode": "lines", - "fill": 'tonexty' + "fill": 'tonexty', } self._add_aesthetics_to_trace_args(trace_args_top, df) self._update_legend_trace_args(trace_args_top, legend_cache) diff --git a/hail/python/hail/ggplot/ggplot.py b/hail/python/hail/ggplot/ggplot.py index 559c1c26616..5cef19a9484 100644 --- a/hail/python/hail/ggplot/ggplot.py +++ b/hail/python/hail/ggplot/ggplot.py @@ -1,19 +1,31 @@ -from plotly.subplots import make_subplots - -from pprint import pprint import itertools +from pprint import pprint + +from plotly.subplots import make_subplots import hail as hl -from .coord_cartesian import CoordCartesian -from .geoms import Geom, FigureAttribute -from .labels import Labels -from .scale import Scale, ScaleContinuous, ScaleDiscrete, scale_x_continuous, scale_x_genomic, scale_y_continuous, \ - scale_x_discrete, scale_y_discrete, scale_color_discrete, scale_color_continuous, scale_fill_discrete, \ - scale_fill_continuous, scale_shape_auto from .aes import Aesthetic, aes +from .coord_cartesian import CoordCartesian from .facets import Faceter -from .utils import is_continuous_type, is_genomic_type, check_scale_continuity +from .geoms import FigureAttribute, Geom +from .labels import Labels +from .scale import ( + Scale, + ScaleContinuous, + ScaleDiscrete, + scale_color_continuous, + scale_color_discrete, + scale_fill_continuous, + scale_fill_discrete, + scale_shape_auto, + scale_x_continuous, + scale_x_discrete, + scale_x_genomic, + scale_y_continuous, + scale_y_discrete, +) +from .utils import check_scale_continuity, is_continuous_type, is_genomic_type class GGPlot: @@ -41,7 +53,7 @@ def __init__(self, ht, aes, geoms=[], labels=Labels(), coord_cartesian=None, sca self.add_default_scales(aes) def __add__(self, other): - assert(isinstance(other, FigureAttribute) or isinstance(other, Aesthetic)) + assert isinstance(other, (FigureAttribute, Aesthetic)) copied = self.copy() if isinstance(other, Geom): @@ -63,7 +75,6 @@ def __add__(self, other): return copied def add_default_scales(self, aesthetic): - for aesthetic_str, mapped_expr in aesthetic.items(): dtype = mapped_expr.dtype if aesthetic_str not in self.scales: @@ -98,11 +109,10 @@ def add_default_scales(self, aesthetic): "The 'shape' aesthetic does not support continuous " "types. Specify values of a discrete type instead." 
) + elif is_continuous: + self.scales[aesthetic_str] = ScaleContinuous(aesthetic_str) else: - if is_continuous: - self.scales[aesthetic_str] = ScaleContinuous(aesthetic_str) - else: - self.scales[aesthetic_str] = ScaleDiscrete(aesthetic_str) + self.scales[aesthetic_str] = ScaleDiscrete(aesthetic_str) def copy(self): return GGPlot(self.ht, self.aes, self.geoms[:], self.labels, self.coord_cartesian, self.scales, self.facet) @@ -148,7 +158,9 @@ def collect_mappings_and_precomputed(selected): for key in combined_mapping: if key in self.scales: - combined_mapping = combined_mapping.annotate(**{key: self.scales[key].transform_data(combined_mapping[key])}) + combined_mapping = combined_mapping.annotate(**{ + key: self.scales[key].transform_data(combined_mapping[key]) + }) mapping_per_geom.append(combined_mapping) precomputes[geom_label] = geom.get_stat().get_precomputes(combined_mapping) @@ -170,7 +182,9 @@ def get_aggregation_result(selected, mapping_per_geom, precomputed): stat = self.geoms[geom_idx].get_stat() geom_label = make_geom_label(geom_idx) if use_faceting: - agg = hl.agg.group_by(selected.facet, stat.make_agg(combined_mapping, precomputed[geom_label], self.scales)) + agg = hl.agg.group_by( + selected.facet, stat.make_agg(combined_mapping, precomputed[geom_label], self.scales) + ) else: agg = stat.make_agg(combined_mapping, precomputed[geom_label], self.scales) aggregators[geom_label] = agg @@ -181,10 +195,15 @@ def get_aggregation_result(selected, mapping_per_geom, precomputed): if use_faceting: facet_list = list(set(itertools.chain(*[list(x.keys()) for x in all_agg_results.values()]))) facet_to_idx = {facet: idx for idx, facet in enumerate(facet_list)} - facet_idx_to_agg_result = {geom_label: {facet_to_idx[facet]: agg_result for facet, agg_result in facet_to_agg_result.items()} for geom_label, facet_to_agg_result in all_agg_results.items()} + facet_idx_to_agg_result = { + geom_label: {facet_to_idx[facet]: agg_result for facet, agg_result in facet_to_agg_result.items()} + for geom_label, facet_to_agg_result in all_agg_results.items() + } num_facets = len(facet_list) else: - facet_idx_to_agg_result = {geom_label: {0: agg_result} for geom_label, agg_result in all_agg_results.items()} + facet_idx_to_agg_result = { + geom_label: {0: agg_result} for geom_label, agg_result in all_agg_results.items() + } num_facets = 1 facet_list = None @@ -193,17 +212,26 @@ def get_aggregation_result(selected, mapping_per_geom, precomputed): self.verify_scales() selected = select_table() mapping_per_geom, precomputed = collect_mappings_and_precomputed(selected) - labels_to_stats, aggregated, num_facets, facet_list = get_aggregation_result(selected, mapping_per_geom, precomputed) + labels_to_stats, aggregated, num_facets, facet_list = get_aggregation_result( + selected, mapping_per_geom, precomputed + ) geoms_and_grouped_dfs_by_facet_idx = [] for geom, (geom_label, agg_result_by_facet) in zip(self.geoms, aggregated.items()): - dfs_by_facet_idx = {facet_idx: labels_to_stats[geom_label].listify(agg_result) for facet_idx, agg_result in agg_result_by_facet.items()} + dfs_by_facet_idx = { + facet_idx: labels_to_stats[geom_label].listify(agg_result) + for facet_idx, agg_result in agg_result_by_facet.items() + } geoms_and_grouped_dfs_by_facet_idx.append((geom, geom_label, dfs_by_facet_idx)) # Create scaling functions based on all the data: transformers = {} for scale in self.scales.values(): - all_dfs = list(itertools.chain(*[facet_to_dfs_dict.values() for _, _, facet_to_dfs_dict in 
geoms_and_grouped_dfs_by_facet_idx])) + all_dfs = list( + itertools.chain(*[ + facet_to_dfs_dict.values() for _, _, facet_to_dfs_dict in geoms_and_grouped_dfs_by_facet_idx + ]) + ) transformers[scale.aesthetic_name] = scale.create_local_transformer(all_dfs) is_faceted = self.facet is not None @@ -212,8 +240,10 @@ def get_aggregation_result(selected, mapping_per_geom, precomputed): subplot_args = { "rows": n_facet_rows, "cols": n_facet_cols, - "subplot_titles": [", ".join([str(fs_value) for fs_value in facet_struct.values()]) for facet_struct in facet_list], - **self.facet.get_shared_axis_kwargs() + "subplot_titles": [ + ", ".join([str(fs_value) for fs_value in facet_struct.values()]) for facet_struct in facet_list + ], + **self.facet.get_shared_axis_kwargs(), } else: n_facet_rows = 1 @@ -240,7 +270,9 @@ def get_aggregation_result(selected, mapping_per_geom, precomputed): facet_row = facet_idx // n_facet_cols + 1 facet_col = facet_idx % n_facet_cols + 1 - geom.apply_to_fig(scaled_grouped_dfs, fig, precomputed[geom_label], facet_row, facet_col, legend_cache, is_faceted) + geom.apply_to_fig( + scaled_grouped_dfs, fig, precomputed[geom_label], facet_row, facet_col, legend_cache, is_faceted + ) # Important to update axes after labels, axes names take precedence. self.labels.apply_to_fig(fig) @@ -270,8 +302,7 @@ def get_aggregation_result(selected, mapping_per_geom, precomputed): return fig def show(self): - """Render and show the plot, either in a browser or notebook. - """ + """Render and show the plot, either in a browser or notebook.""" self.to_plotly().show() def write_image(self, path): diff --git a/hail/python/hail/ggplot/scale.py b/hail/python/hail/ggplot/scale.py index 331c53a21a9..62c6ae530d2 100644 --- a/hail/python/hail/ggplot/scale.py +++ b/hail/python/hail/ggplot/scale.py @@ -1,12 +1,11 @@ import abc - from collections.abc import Mapping -import plotly.express as px import plotly +import plotly.express as px from hail.context import get_reference -from hail import tstr +from hail.expr.types import tstr from .geoms import FigureAttribute from .utils import continuous_nums_to_colors, is_continuous_type, is_discrete_type @@ -88,7 +87,6 @@ def is_continuous(self): class PositionScaleContinuous(PositionScale): - def __init__(self, axis=None, name=None, breaks=None, labels=None, transformation="identity"): super().__init__(axis, name, breaks, labels) self.transformation = transformation @@ -190,8 +188,7 @@ def create_local_transformer(self, groups_of_dfs): mapping = dict(zip(categories, values)) else: raise TypeError( - "Expected scale values to be a Mapping or list, but received a(n) " - f"{type(values)}: {values}." + "Expected scale values to be a Mapping or list, but received a(n) " f"{type(values)}: {values}." 
) def transform(df): @@ -212,7 +209,6 @@ def get_values(self, categories): class ScaleColorContinuous(ScaleContinuous): - def create_local_transformer(self, groups_of_dfs): overall_min = None overall_max = None diff --git a/hail/python/hail/ggplot/stats.py b/hail/python/hail/ggplot/stats.py index 1bcd7367069..1307eb2ba0c 100644 --- a/hail/python/hail/ggplot/stats.py +++ b/hail/python/hail/ggplot/stats.py @@ -1,13 +1,12 @@ -from typing import List - import abc - -from pandas import DataFrame +from typing import List import pandas as pd +from pandas import DataFrame import hail as hl from hail.utils.java import warning + from .utils import should_use_scale_for_grouping @@ -27,9 +26,12 @@ def get_precomputes(self, mapping): class StatIdentity(Stat): def make_agg(self, mapping, precomputed, scales): - grouping_variables = {aes_key: mapping[aes_key] for aes_key in mapping.keys() - if should_use_scale_for_grouping(scales[aes_key])} - non_grouping_variables = {aes_key: mapping[aes_key] for aes_key in mapping.keys() if aes_key not in grouping_variables} + grouping_variables = { + aes_key: mapping[aes_key] for aes_key in mapping.keys() if should_use_scale_for_grouping(scales[aes_key]) + } + non_grouping_variables = { + aes_key: mapping[aes_key] for aes_key in mapping.keys() if aes_key not in grouping_variables + } return hl.agg.group_by(hl.struct(**grouping_variables), hl.agg.collect(hl.struct(**non_grouping_variables))) def listify(self, agg_result) -> List[DataFrame]: @@ -67,10 +69,13 @@ def listify(self, agg_result) -> List[DataFrame]: class StatCount(Stat): def make_agg(self, mapping, precomputed, scales): - grouping_variables = {aes_key: mapping[aes_key] for aes_key in mapping.keys() - if should_use_scale_for_grouping(scales[aes_key])} + grouping_variables = { + aes_key: mapping[aes_key] for aes_key in mapping.keys() if should_use_scale_for_grouping(scales[aes_key]) + } if "weight" in mapping: - return hl.agg.group_by(hl.struct(**grouping_variables), hl.agg.counter(mapping["x"], weight=mapping["weight"])) + return hl.agg.group_by( + hl.struct(**grouping_variables), hl.agg.counter(mapping["x"], weight=mapping["weight"]) + ) return hl.agg.group_by(hl.struct(**grouping_variables), hl.agg.group_by(mapping["x"], hl.agg.count())) def listify(self, agg_result) -> List[DataFrame]: @@ -97,7 +102,6 @@ def __init__(self, min_val, max_val, bins): self.bins = bins def get_precomputes(self, mapping): - precomputes = {} if self.min_val is None: precomputes["min_val"] = hl.agg.min(mapping.x) @@ -106,8 +110,9 @@ def get_precomputes(self, mapping): return hl.struct(**precomputes) def make_agg(self, mapping, precomputed, scales): - grouping_variables = {aes_key: mapping[aes_key] for aes_key in mapping.keys() - if should_use_scale_for_grouping(scales[aes_key])} + grouping_variables = { + aes_key: mapping[aes_key] for aes_key in mapping.keys() if should_use_scale_for_grouping(scales[aes_key]) + } start = self.min_val if self.min_val is not None else precomputed.min_val end = self.max_val if self.max_val is not None else precomputed.max_val @@ -128,7 +133,7 @@ def listify(self, agg_result) -> List[DataFrame]: for grouped_struct, hist in items: data_rows = [] y_values = hist.bin_freq - for i, x in enumerate(x_edges[:num_edges - 1]): + for i, x in enumerate(x_edges[: num_edges - 1]): data_rows.append({"x": x, "y": y_values[i]}) df = pd.DataFrame.from_records(data_rows) df.attrs.update(**grouped_struct) @@ -141,8 +146,9 @@ def __init__(self, k): self.k = k def make_agg(self, mapping, precomputed, scales): - 
grouping_variables = {aes_key: mapping[aes_key] for aes_key in mapping.keys() - if should_use_scale_for_grouping(scales[aes_key])} + grouping_variables = { + aes_key: mapping[aes_key] for aes_key in mapping.keys() if should_use_scale_for_grouping(scales[aes_key]) + } return hl.agg.group_by(hl.struct(**grouping_variables), hl.agg.approx_cdf(mapping["x"], self.k)) def listify(self, agg_result) -> List[DataFrame]: diff --git a/hail/python/hail/ggplot/utils.py b/hail/python/hail/ggplot/utils.py index 93c85db9646..c5d040d8215 100644 --- a/hail/python/hail/ggplot/utils.py +++ b/hail/python/hail/ggplot/utils.py @@ -1,4 +1,5 @@ import plotly + import hail as hl @@ -40,6 +41,7 @@ def adjust_color(input_color): def transform_color(input_color): return plotly.colors.sample_colorscale(continuous_color_scale, adjust_color(input_color))[0] + return transform_color @@ -49,11 +51,5 @@ def bar_position_plotly_to_gg(plotly_pos): def linetype_plotly_to_gg(plotly_linetype): - linetype_dict = { - "solid": "solid", - "dashed": "dash", - "dotted": "dot", - "longdash": "longdash", - "dotdash": "dashdot" - } + linetype_dict = {"solid": "solid", "dashed": "dash", "dotted": "dot", "longdash": "longdash", "dotdash": "dashdot"} return linetype_dict[plotly_linetype] diff --git a/hail/python/hail/ir/__init__.py b/hail/python/hail/ir/__init__.py index b853f6b7948..2325598a9f0 100644 --- a/hail/python/hail/ir/__init__.py +++ b/hail/python/hail/ir/__init__.py @@ -1,64 +1,272 @@ +from .base_ir import IR, BaseIR, BlockMatrixIR, MatrixIR, TableIR +from .blockmatrix_ir import ( + BandSparsifier, + BlockMatrixAgg, + BlockMatrixBroadcast, + BlockMatrixDensify, + BlockMatrixDot, + BlockMatrixFilter, + BlockMatrixMap, + BlockMatrixMap2, + BlockMatrixRandom, + BlockMatrixRead, + BlockMatrixSlice, + BlockMatrixSparsifier, + BlockMatrixSparsify, + PerBlockSparsifier, + RectangleSparsifier, + RowIntervalSparsifier, + ValueToBlockMatrix, + tensor_shape_to_matrix_shape, +) +from .blockmatrix_reader import ( + BlockMatrixBinaryReader, + BlockMatrixNativeReader, + BlockMatrixPersistReader, + BlockMatrixReader, +) +from .blockmatrix_writer import ( + BlockMatrixBinaryMultiWriter, + BlockMatrixBinaryWriter, + BlockMatrixMultiWriter, + BlockMatrixNativeMultiWriter, + BlockMatrixNativeWriter, + BlockMatrixPersistWriter, + BlockMatrixRectanglesWriter, + BlockMatrixTextMultiWriter, + BlockMatrixWriter, +) from .export_type import ExportType -from .base_ir import BaseIR, IR, TableIR, MatrixIR, BlockMatrixIR -from .ir import MatrixWrite, MatrixMultiWrite, BlockMatrixWrite, \ - BlockMatrixMultiWrite, TableToValueApply, \ - MatrixToValueApply, BlockMatrixToValueApply, BlockMatrixCollect, \ - Literal, EncodedLiteral, LiftMeOut, Join, JavaIR, I32, I64, F32, F64, Str, FalseIR, TrueIR, \ - Void, Cast, NA, IsNA, If, Coalesce, Let, AggLet, Ref, TopLevelReference, ProjectedTopLevelReference, SelectedTopLevelReference, \ - TailLoop, Recur, ApplyBinaryPrimOp, ApplyUnaryPrimOp, ApplyComparisonOp, \ - MakeArray, ArrayRef, ArraySlice, ArrayLen, ArrayZeros, StreamIota, StreamRange, StreamGrouped, MakeNDArray, \ - NDArrayShape, NDArrayReshape, NDArrayMap, NDArrayMap2, NDArrayRef, NDArraySlice, NDArraySVD, NDArrayEigh, \ - NDArrayReindex, NDArrayAgg, NDArrayMatMul, NDArrayQR, NDArrayInv, NDArrayConcat, NDArrayWrite, \ - ArraySort, ArrayMaximalIndependentSet, ToSet, ToDict, toArray, ToArray, CastToArray, \ - ToStream, toStream, LowerBoundOnOrderedCollection, GroupByKey, StreamMap, StreamZip, StreamTake, \ - StreamFilter, StreamFlatMap, StreamFold, StreamScan, 
StreamJoinRightDistinct, StreamFor, StreamWhiten, \ - AggFilter, AggExplode, AggGroupBy, AggArrayPerElement, BaseApplyAggOp, ApplyAggOp, ApplyScanOp, \ - AggFold, Begin, MakeStruct, SelectFields, InsertFields, GetField, MakeTuple, \ - GetTupleElement, Die, ConsoleLog, Apply, ApplySeeded, RNGStateLiteral, RNGSplit,\ - TableCount, TableGetGlobals, TableCollect, TableAggregate, MatrixCount, \ - MatrixAggregate, TableWrite, udf, subst, clear_session_functions, ReadPartition, \ - PartitionNativeIntervalReader, StreamMultiMerge, StreamZipJoin, StreamAgg, StreamZipJoinProducers, \ - GVCFPartitionReader -from .register_functions import register_functions +from .ir import ( + F32, + F64, + I32, + I64, + NA, + AggArrayPerElement, + AggExplode, + AggFilter, + AggFold, + AggGroupBy, + AggLet, + Apply, + ApplyAggOp, + ApplyBinaryPrimOp, + ApplyComparisonOp, + ApplyScanOp, + ApplySeeded, + ApplyUnaryPrimOp, + ArrayLen, + ArrayMaximalIndependentSet, + ArrayRef, + ArraySlice, + ArraySort, + ArrayZeros, + BaseApplyAggOp, + Begin, + BlockMatrixCollect, + BlockMatrixMultiWrite, + BlockMatrixToValueApply, + BlockMatrixWrite, + Cast, + CastToArray, + Coalesce, + ConsoleLog, + Die, + EncodedLiteral, + FalseIR, + GetField, + GetTupleElement, + GroupByKey, + GVCFPartitionReader, + If, + InsertFields, + IsNA, + JavaIR, + Join, + Let, + LiftMeOut, + Literal, + LowerBoundOnOrderedCollection, + MakeArray, + MakeNDArray, + MakeStruct, + MakeTuple, + MatrixAggregate, + MatrixCount, + MatrixMultiWrite, + MatrixToValueApply, + MatrixWrite, + NDArrayAgg, + NDArrayConcat, + NDArrayEigh, + NDArrayInv, + NDArrayMap, + NDArrayMap2, + NDArrayMatMul, + NDArrayQR, + NDArrayRef, + NDArrayReindex, + NDArrayReshape, + NDArrayShape, + NDArraySlice, + NDArraySVD, + NDArrayWrite, + PartitionNativeIntervalReader, + ProjectedTopLevelReference, + ReadPartition, + Recur, + Ref, + RNGSplit, + RNGStateLiteral, + SelectedTopLevelReference, + SelectFields, + Str, + StreamAgg, + StreamFilter, + StreamFlatMap, + StreamFold, + StreamFor, + StreamGrouped, + StreamIota, + StreamJoinRightDistinct, + StreamMap, + StreamMultiMerge, + StreamRange, + StreamScan, + StreamTake, + StreamWhiten, + StreamZip, + StreamZipJoin, + StreamZipJoinProducers, + TableAggregate, + TableCollect, + TableCount, + TableGetGlobals, + TableToValueApply, + TableWrite, + TailLoop, + ToArray, + ToDict, + TopLevelReference, + ToSet, + ToStream, + TrueIR, + Void, + clear_session_functions, + subst, + toArray, + toStream, + udf, +) +from .matrix_ir import ( + CastTableToMatrix, + MatrixAggregateColsByKey, + MatrixAggregateRowsByKey, + MatrixAnnotateColsTable, + MatrixAnnotateRowsTable, + MatrixChooseCols, + MatrixCollectColsByKey, + MatrixColsHead, + MatrixColsTail, + MatrixDistinctByRow, + MatrixExplodeCols, + MatrixExplodeRows, + MatrixFilterCols, + MatrixFilterEntries, + MatrixFilterIntervals, + MatrixFilterRows, + MatrixKeyRowsBy, + MatrixMapCols, + MatrixMapEntries, + MatrixMapGlobals, + MatrixMapRows, + MatrixRead, + MatrixRename, + MatrixRepartition, + MatrixRowsHead, + MatrixRowsTail, + MatrixToMatrixApply, + MatrixUnionCols, + MatrixUnionRows, +) +from .matrix_reader import ( + MatrixBGENReader, + MatrixNativeReader, + MatrixPLINKReader, + MatrixRangeReader, + MatrixReader, + MatrixVCFReader, +) +from .matrix_writer import ( + MatrixBGENWriter, + MatrixBlockMatrixWriter, + MatrixGENWriter, + MatrixNativeMultiWriter, + MatrixNativeWriter, + MatrixPLINKWriter, + MatrixVCFWriter, + MatrixWriter, +) from .register_aggregators import register_aggregators -from .table_ir 
import (MatrixRowsTable, TableJoin, TableLeftJoinRightDistinct, TableIntervalJoin, - TableUnion, TableRange, TableMapGlobals, TableExplode, TableKeyBy, TableMapRows, TableRead, - MatrixEntriesTable, TableFilter, TableKeyByAndAggregate, TableAggregateByKey, MatrixColsTable, - TableParallelize, TableHead, TableTail, TableOrderBy, TableDistinct, RepartitionStrategy, - TableRepartition, CastMatrixToTable, TableRename, TableMultiWayZipJoin, TableFilterIntervals, - TableToTableApply, MatrixToTableApply, BlockMatrixToTableApply, BlockMatrixToTable, JavaTable, - TableMapPartitions, TableGen, Partitioner) -from .matrix_ir import MatrixAggregateRowsByKey, MatrixRead, MatrixFilterRows, \ - MatrixChooseCols, MatrixMapCols, MatrixUnionCols, MatrixMapEntries, \ - MatrixFilterEntries, MatrixKeyRowsBy, MatrixMapRows, MatrixMapGlobals, \ - MatrixFilterCols, MatrixCollectColsByKey, MatrixAggregateColsByKey, \ - MatrixExplodeRows, MatrixRepartition, MatrixUnionRows, MatrixDistinctByRow, \ - MatrixRowsHead, MatrixColsHead, MatrixRowsTail, MatrixColsTail, \ - MatrixExplodeCols, CastTableToMatrix, MatrixAnnotateRowsTable, \ - MatrixAnnotateColsTable, MatrixToMatrixApply, MatrixRename, \ - MatrixFilterIntervals -from .blockmatrix_ir import BlockMatrixRead, BlockMatrixMap, BlockMatrixMap2, \ - BlockMatrixDot, BlockMatrixBroadcast, BlockMatrixAgg, BlockMatrixFilter, \ - BlockMatrixDensify, BlockMatrixSparsifier, BandSparsifier, \ - RowIntervalSparsifier, RectangleSparsifier, PerBlockSparsifier, BlockMatrixSparsify, \ - BlockMatrixSlice, ValueToBlockMatrix, BlockMatrixRandom, \ - tensor_shape_to_matrix_shape -from .utils import filter_predicate_with_keep, make_filter_and_replace, finalize_randomness -from .matrix_reader import MatrixReader, MatrixNativeReader, MatrixRangeReader, \ - MatrixVCFReader, MatrixBGENReader, MatrixPLINKReader -from .table_reader import AvroTableReader, TableReader, TableNativeReader, \ - TextTableReader, TableFromBlockMatrixNativeReader, StringTableReader -from .blockmatrix_reader import BlockMatrixReader, BlockMatrixNativeReader, \ - BlockMatrixBinaryReader, BlockMatrixPersistReader -from .matrix_writer import MatrixWriter, MatrixNativeWriter, MatrixVCFWriter, \ - MatrixGENWriter, MatrixBGENWriter, MatrixPLINKWriter, MatrixNativeMultiWriter, MatrixBlockMatrixWriter -from .table_writer import (TableWriter, TableNativeWriter, TableTextWriter, TableNativeFanoutWriter) -from .blockmatrix_writer import BlockMatrixWriter, BlockMatrixNativeWriter, \ - BlockMatrixBinaryWriter, BlockMatrixRectanglesWriter, \ - BlockMatrixMultiWriter, BlockMatrixBinaryMultiWriter, \ - BlockMatrixTextMultiWriter, BlockMatrixPersistWriter, BlockMatrixNativeMultiWriter -from .renderer import Renderable, RenderableStr, ParensRenderer, \ - RenderableQueue, RQStack, Renderer, PlainRenderer, CSERenderer +from .register_functions import register_functions +from .renderer import ( + CSERenderer, + ParensRenderer, + PlainRenderer, + Renderable, + RenderableQueue, + RenderableStr, + Renderer, + RQStack, +) +from .table_ir import ( + BlockMatrixToTable, + BlockMatrixToTableApply, + CastMatrixToTable, + JavaTable, + MatrixColsTable, + MatrixEntriesTable, + MatrixRowsTable, + MatrixToTableApply, + Partitioner, + RepartitionStrategy, + TableAggregateByKey, + TableDistinct, + TableExplode, + TableFilter, + TableFilterIntervals, + TableGen, + TableHead, + TableIntervalJoin, + TableJoin, + TableKeyBy, + TableKeyByAndAggregate, + TableLeftJoinRightDistinct, + TableMapGlobals, + TableMapPartitions, + TableMapRows, + 
TableMultiWayZipJoin, + TableOrderBy, + TableParallelize, + TableRange, + TableRead, + TableRename, + TableRepartition, + TableTail, + TableToTableApply, + TableUnion, +) +from .table_reader import ( + AvroTableReader, + StringTableReader, + TableFromBlockMatrixNativeReader, + TableNativeReader, + TableReader, + TextTableReader, +) +from .table_writer import TableNativeFanoutWriter, TableNativeWriter, TableTextWriter, TableWriter +from .utils import filter_predicate_with_keep, finalize_randomness, make_filter_and_replace __all__ = [ 'ExportType', @@ -321,5 +529,5 @@ 'PartitionNativeIntervalReader', 'GVCFPartitionReader', 'TableGen', - 'Partitioner' + 'Partitioner', ] diff --git a/hail/python/hail/ir/base_ir.py b/hail/python/hail/ir/base_ir.py index 2ca2a33d833..5ee3e5e773d 100644 --- a/hail/python/hail/ir/base_ir.py +++ b/hail/python/hail/ir/base_ir.py @@ -1,8 +1,9 @@ import abc +from typing import ClassVar from hail.expr.types import tstream -from .renderer import Renderer, PlainRenderer, Renderable +from .renderer import PlainRenderer, Renderable, Renderer counter = 0 @@ -204,13 +205,14 @@ def free_scan_vars(self): def base_search(self, criteria): others = [node for child in self.children if isinstance(child, BaseIR) for node in child.base_search(criteria)] if criteria(self): - return others + [self] + return [*others, self] return others def save_error_info(self): self._error_id = get_next_int() import traceback + stack = traceback.format_stack() i = len(stack) while i > 0: @@ -226,11 +228,10 @@ def save_error_info(self): 'typecheck/check', 'interactiveshell.py', 'expressions.construct_variable', - 'traceback.format_stack()' + 'traceback.format_stack()', ] filt_stack = [ - candidate for candidate in stack[i:] - if not any(phrase in candidate for phrase in forbidden_phrases) + candidate for candidate in stack[i:] if not any(phrase in candidate for phrase in forbidden_phrases) ] self._stack_trace = '\n'.join(filt_stack) @@ -259,7 +260,7 @@ def is_nested_field(self): def search(self, criteria): others = [node for child in self.children if isinstance(child, IR) for node in child.search(criteria)] if criteria(self): - return others + [self] + return [*others, self] return others def map_ir(self, f): @@ -274,7 +275,11 @@ def map_ir(self, f): @property def uses_randomness(self) -> bool: - return '__rng_state' in self.free_vars or '__rng_state' in self.free_agg_vars or '__rng_state' in self.free_scan_vars + return ( + '__rng_state' in self.free_vars + or '__rng_state' in self.free_agg_vars + or '__rng_state' in self.free_scan_vars + ) @property def uses_value_randomness(self): @@ -302,7 +307,7 @@ def renderable_new_block(self, i: int) -> bool: def compute_type(self, env, agg_env, deep_typecheck): if deep_typecheck or self._type is None: computed = self._compute_type(env, agg_env, deep_typecheck) - assert(computed is not None) + assert computed is not None if self._type is not None: assert self._type == computed self._type = computed @@ -341,7 +346,7 @@ def handle_randomness(self, create_uids): The uid may be an int64, or arbitrary tuple of int64s. The only requirement is that all stream elements contain distinct uid values. 
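The docstring above pins down the contract for stream `handle_randomness`: every element of the returned stream must carry a distinct uid (an int64 or a tuple of int64s). Several `_handle_randomness` implementations later in this diff satisfy that by zipping the stream with `StreamIota` and casting the running index to int64. Below is a minimal standalone sketch of that pattern, not the library's own code; `with_element_uids` and its `stream` argument are hypothetical names introduced only for illustration.

# Sketch, assuming the IR constructors imported here behave as shown in this diff:
# attach a distinct int64 uid to each stream element by zipping with an iota.
from hail.expr.types import tint32, tint64
from hail.ir import I32, Cast, MakeTuple, Ref, StreamIota, StreamZip
from hail.utils.java import Env

def with_element_uids(stream):
    elt = Env.get_uid()
    uid = Env.get_uid()
    return StreamZip(
        [stream, StreamIota(I32(0), I32(1))],
        [elt, uid],
        # Each element becomes (uid, original element); uids are distinct because
        # the iota index is strictly increasing.
        MakeTuple([Cast(Ref(uid, tint32), tint64), Ref(elt, stream.typ.element_type)]),
        'TakeMinLength',
    )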
""" - assert(self.is_stream) + assert self.is_stream if (create_uids == self.has_uids) and not self.needs_randomness_handling: return self new = self._handle_randomness(create_uids) @@ -353,17 +358,15 @@ def handle_randomness(self, create_uids): def free_vars(self): def vars_from_child(i): if self.uses_agg_context(i): - assert(len(self.children[i].free_agg_vars) == 0) + assert len(self.children[i].free_agg_vars) == 0 return set() if self.uses_scan_context(i): - assert(len(self.children[i].free_scan_vars) == 0) + assert len(self.children[i].free_scan_vars) == 0 return set() return self.children[i].free_vars.difference(self.bindings(i, 0).keys()) if self._free_vars is None: - self._free_vars = { - var for i in range(len(self.children)) - for var in vars_from_child(i)} + self._free_vars = {var for i in range(len(self.children)) for var in vars_from_child(i)} if self.uses_agg_capability(): self._free_vars.add(BaseIR.agg_capability) return self._free_vars @@ -376,9 +379,7 @@ def vars_from_child(i): return self.children[i].free_agg_vars.difference(self.agg_bindings(i, 0).keys()) if self._free_agg_vars is None: - self._free_agg_vars = { - var for i in range(len(self.children)) - for var in vars_from_child(i)} + self._free_agg_vars = {var for i in range(len(self.children)) for var in vars_from_child(i)} return self._free_agg_vars @property @@ -389,9 +390,7 @@ def vars_from_child(i): return self.children[i].free_scan_vars.difference(self.scan_bindings(i, 0).keys()) if self._free_scan_vars is None: - self._free_scan_vars = { - var for i in range(len(self.children)) - for var in vars_from_child(i)} + self._free_scan_vars = {var for i in range(len(self.children)) for var in vars_from_child(i)} return self._free_scan_vars @@ -401,8 +400,7 @@ def __init__(self, *children): self._children_use_randomness = any(child.uses_randomness for child in children) @abc.abstractmethod - def _compute_type(self, deep_typecheck): - ... + def _compute_type(self, deep_typecheck): ... def compute_type(self, deep_typecheck): if deep_typecheck or self._type is None: @@ -437,13 +435,16 @@ def handle_randomness(self, uid_field_name): """ if uid_field_name is None and not self.uses_randomness: return self - return self._handle_randomness(uid_field_name) + + new_self = self._handle_randomness(uid_field_name) + assert uid_field_name is None or uid_field_name in new_self.typ.row_type.fields + return new_self def renderable_new_block(self, i: int) -> bool: return True - global_env = {'global'} - row_env = {'global', 'row'} + global_env: ClassVar = {'global'} + row_env: ClassVar = {'global', 'row'} class MatrixIR(BaseIR): @@ -477,8 +478,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): pass @abc.abstractmethod - def _compute_type(self, deep_typecheck): - ... + def _compute_type(self, deep_typecheck): ... def compute_type(self, deep_typecheck): if deep_typecheck or self._type is None: @@ -497,10 +497,10 @@ def typ(self): def renderable_new_block(self, i: int) -> bool: return True - global_env = {'global'} - row_env = {'global', 'va'} - col_env = {'global', 'sa'} - entry_env = {'global', 'sa', 'va', 'g'} + global_env: ClassVar = {'global'} + row_env: ClassVar = {'global', 'va'} + col_env: ClassVar = {'global', 'sa'} + entry_env: ClassVar = {'global', 'sa', 'va', 'g'} class BlockMatrixIR(BaseIR): @@ -513,8 +513,7 @@ def uses_randomness(self) -> bool: return self._children_use_randomness @abc.abstractmethod - def _compute_type(self, deep_typecheck): - ... + def _compute_type(self, deep_typecheck): ... 
def compute_type(self, deep_typecheck): if deep_typecheck or self._type is None: diff --git a/hail/python/hail/ir/blockmatrix_ir.py b/hail/python/hail/ir/blockmatrix_ir.py index a6278c65fa9..3342ccf9eca 100644 --- a/hail/python/hail/ir/blockmatrix_ir.py +++ b/hail/python/hail/ir/blockmatrix_ir.py @@ -1,12 +1,12 @@ import hail as hl from hail.expr.blockmatrix_type import tblockmatrix from hail.expr.types import tarray -from .blockmatrix_reader import BlockMatrixReader -from .base_ir import BlockMatrixIR, IR -from hail.typecheck import typecheck_method, sequenceof, nullable +from hail.typecheck import nullable, sequenceof, typecheck_method +from hail.utils.java import Env from hail.utils.misc import escape_id -from hail.utils.java import Env +from .base_ir import IR, BlockMatrixIR +from .blockmatrix_reader import BlockMatrixReader class BlockMatrixRead(BlockMatrixIR): @@ -32,7 +32,7 @@ def _compute_type(self, deep_typecheck): class BlockMatrixMap(BlockMatrixIR): @typecheck_method(child=BlockMatrixIR, name=str, f=IR, needs_dense=bool) def __init__(self, child, name, f, needs_dense): - assert(not f.uses_randomness) + assert not f.uses_randomness super().__init__(child, f) self.child = child self.name = name @@ -43,10 +43,7 @@ def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) self.f.compute_type(self.bindings(1), None, deep_typecheck) child_type = self.child.typ - return tblockmatrix(self.f.typ, - child_type.shape, - child_type.is_row_vector, - child_type.block_size) + return tblockmatrix(self.f.typ, child_type.shape, child_type.is_row_vector, child_type.block_size) def head_str(self): return escape_id(self.name) + " " + str(self.needs_dense) @@ -63,9 +60,11 @@ def binds(self, i): class BlockMatrixMap2(BlockMatrixIR): - @typecheck_method(left=BlockMatrixIR, right=BlockMatrixIR, left_name=str, right_name=str, f=IR, sparsity_strategy=str) + @typecheck_method( + left=BlockMatrixIR, right=BlockMatrixIR, left_name=str, right_name=str, f=IR, sparsity_strategy=str + ) def __init__(self, left, right, left_name, right_name, f, sparsity_strategy): - assert(not f.uses_randomness) + assert not f.uses_randomness super().__init__(left, right, f) self.left = left self.right = right @@ -79,10 +78,7 @@ def _compute_type(self, deep_typecheck): self.right.compute_type(deep_typecheck) self.f.compute_type(self.bindings(2), None, deep_typecheck) left_type = self.left.typ - return tblockmatrix(self.f.typ, - left_type.shape, - left_type.is_row_vector, - left_type.block_size) + return tblockmatrix(self.f.typ, left_type.shape, left_type.is_row_vector, left_type.block_size) def head_str(self): return escape_id(self.left_name) + " " + escape_id(self.right_name) + " " + self.sparsity_strategy @@ -117,17 +113,11 @@ def _compute_type(self, deep_typecheck): assert l_cols == r_rows tensor_shape, is_row_vector = _matrix_shape_to_tensor_shape(l_rows, r_cols) - return tblockmatrix(self.left.typ.element_type, - tensor_shape, - is_row_vector, - self.left.typ.block_size) + return tblockmatrix(self.left.typ.element_type, tensor_shape, is_row_vector, self.left.typ.block_size) class BlockMatrixBroadcast(BlockMatrixIR): - @typecheck_method(child=BlockMatrixIR, - in_index_expr=sequenceof(int), - shape=sequenceof(int), - block_size=int) + @typecheck_method(child=BlockMatrixIR, in_index_expr=sequenceof(int), shape=sequenceof(int), block_size=int) def __init__(self, child, in_index_expr, shape, block_size): super().__init__(child) self.child = child @@ -136,28 +126,24 @@ def __init__(self, child, 
in_index_expr, shape, block_size): self.block_size = block_size def head_str(self): - return '{} {} {}'.format(_serialize_list(self.in_index_expr), - _serialize_list(self.shape), - self.block_size) + return '{} {} {}'.format(_serialize_list(self.in_index_expr), _serialize_list(self.shape), self.block_size) def _eq(self, other): - return self.in_index_expr == other.in_index_expr and \ - self.shape == other.shape and \ - self.block_size == other.block_size + return ( + self.in_index_expr == other.in_index_expr + and self.shape == other.shape + and self.block_size == other.block_size + ) def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) assert len(self.shape) == 2, self.shape tensor_shape, is_row_vector = _matrix_shape_to_tensor_shape(self.shape[0], self.shape[1]) - return tblockmatrix(self.child.typ.element_type, - tensor_shape, - is_row_vector, - self.block_size) + return tblockmatrix(self.child.typ.element_type, tensor_shape, is_row_vector, self.block_size) class BlockMatrixAgg(BlockMatrixIR): - @typecheck_method(child=BlockMatrixIR, - out_index_expr=sequenceof(int)) + @typecheck_method(child=BlockMatrixIR, out_index_expr=sequenceof(int)) def __init__(self, child, out_index_expr): super().__init__(child) self.child = child @@ -184,10 +170,7 @@ def _compute_type(self, deep_typecheck): else: raise ValueError("Invalid out_index_expr") - return tblockmatrix(self.child.typ.element_type, - shape, - is_row_vector, - self.child.typ.block_size) + return tblockmatrix(self.child.typ.element_type, shape, is_row_vector, self.child.typ.block_size) class BlockMatrixFilter(BlockMatrixIR): @@ -217,14 +200,12 @@ def _compute_type(self, deep_typecheck): else: child_matrix_shape = child_tensor_shape - matrix_shape = [len(idxs) if len(idxs) != 0 else child_matrix_shape[i] for i, idxs in - enumerate(self.indices_to_keep)] + matrix_shape = [ + len(idxs) if len(idxs) != 0 else child_matrix_shape[i] for i, idxs in enumerate(self.indices_to_keep) + ] tensor_shape, is_row_vector = _matrix_shape_to_tensor_shape(matrix_shape[0], matrix_shape[1]) - return tblockmatrix(self.child.typ.element_type, - tensor_shape, - is_row_vector, - self.child.typ.block_size) + return tblockmatrix(self.child.typ.element_type, tensor_shape, is_row_vector, self.child.typ.block_size) class BlockMatrixDensify(BlockMatrixIR): @@ -301,7 +282,7 @@ def __repr__(self): class BlockMatrixSparsify(BlockMatrixIR): @typecheck_method(child=BlockMatrixIR, value=IR, sparsifier=BlockMatrixSparsifier) def __init__(self, child, value, sparsifier): - assert(not value.uses_randomness) + assert not value.uses_randomness super().__init__(value, child) self.child = child self.value = value @@ -337,18 +318,14 @@ def _compute_type(self, deep_typecheck): assert len(self.slices) == 2 matrix_shape = [1 + (s.stop - s.start - 1) // s.step for s in self.slices] tensor_shape, is_row_vector = _matrix_shape_to_tensor_shape(matrix_shape[0], matrix_shape[1]) - return tblockmatrix(self.child.typ.element_type, - tensor_shape, - is_row_vector, - self.child.typ.block_size) + return tblockmatrix(self.child.typ.element_type, tensor_shape, is_row_vector, self.child.typ.block_size) class ValueToBlockMatrix(BlockMatrixIR): - @typecheck_method(child=IR, - shape=sequenceof(int), - block_size=int) + @typecheck_method(child=IR, shape=sequenceof(int), block_size=int) def __init__(self, child, shape, block_size): from .ir import Let, RNGStateLiteral + if child.uses_randomness: child = Let('__rng_state', RNGStateLiteral(), child) super().__init__(child) @@ 
-357,12 +334,10 @@ def __init__(self, child, shape, block_size): self.block_size = block_size def head_str(self): - return '{} {}'.format(_serialize_list(self.shape), - self.block_size) + return '{} {}'.format(_serialize_list(self.shape), self.block_size) def _eq(self, other): - return self.shape == other.shape and \ - self.block_size == other.block_size + return self.shape == other.shape and self.block_size == other.block_size def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck, {}, None) @@ -378,10 +353,7 @@ def _compute_type(self, deep_typecheck): class BlockMatrixRandom(BlockMatrixIR): - @typecheck_method(static_rng_uid=int, - gaussian=bool, - shape=sequenceof(int), - block_size=int) + @typecheck_method(static_rng_uid=int, gaussian=bool, shape=sequenceof(int), block_size=int) def __init__(self, static_rng_uid, gaussian, shape, block_size): super().__init__() self.static_rng_uid = static_rng_uid @@ -390,16 +362,15 @@ def __init__(self, static_rng_uid, gaussian, shape, block_size): self.block_size = block_size def head_str(self): - return '{} {} {} {}'.format(self.static_rng_uid, - self.gaussian, - _serialize_list(self.shape), - self.block_size) + return '{} {} {} {}'.format(self.static_rng_uid, self.gaussian, _serialize_list(self.shape), self.block_size) def _eq(self, other): - return self.static_rng_uid == other.static_rng_uid and \ - self.gaussian == other.gaussian and \ - self.shape == other.shape and \ - self.block_size == other.block_size + return ( + self.static_rng_uid == other.static_rng_uid + and self.gaussian == other.gaussian + and self.shape == other.shape + and self.block_size == other.block_size + ) def _compute_type(self, deep_typecheck): assert len(self.shape) == 2, self.shape diff --git a/hail/python/hail/ir/blockmatrix_reader.py b/hail/python/hail/ir/blockmatrix_reader.py index 61e94d56594..e3c7f3d8541 100644 --- a/hail/python/hail/ir/blockmatrix_reader.py +++ b/hail/python/hail/ir/blockmatrix_reader.py @@ -1,7 +1,7 @@ import abc import json -from ..typecheck import typecheck_method, sequenceof +from ..typecheck import sequenceof, typecheck_method from ..utils.misc import escape_str @@ -21,13 +21,11 @@ def __init__(self, path): self.path = path def render(self): - reader = {'name': 'BlockMatrixNativeReader', - 'path': self.path} + reader = {'name': 'BlockMatrixNativeReader', 'path': self.path} return escape_str(json.dumps(reader)) def __eq__(self, other): - return isinstance(other, BlockMatrixNativeReader) and \ - self.path == other.path + return isinstance(other, BlockMatrixNativeReader) and self.path == other.path class BlockMatrixBinaryReader(BlockMatrixReader): @@ -38,17 +36,21 @@ def __init__(self, path, shape, block_size): self.block_size = block_size def render(self): - reader = {'name': 'BlockMatrixBinaryReader', - 'path': self.path, - 'shape': self.shape, - 'blockSize': self.block_size} + reader = { + 'name': 'BlockMatrixBinaryReader', + 'path': self.path, + 'shape': self.shape, + 'blockSize': self.block_size, + } return escape_str(json.dumps(reader)) def __eq__(self, other): - return isinstance(other, BlockMatrixBinaryReader) and \ - self.path == other.path and \ - self.shape == other.shape and \ - self.block_size == other.block_size + return ( + isinstance(other, BlockMatrixBinaryReader) + and self.path == other.path + and self.shape == other.shape + and self.block_size == other.block_size + ) class BlockMatrixPersistReader(BlockMatrixReader): @@ -57,13 +59,11 @@ def __init__(self, id, original): self.original = original def 
render(self): - reader = {'name': 'BlockMatrixPersistReader', - 'id': self.id} + reader = {'name': 'BlockMatrixPersistReader', 'id': self.id} return escape_str(json.dumps(reader)) def __eq__(self, other): - return isinstance(other, BlockMatrixPersistReader) and \ - self.id == other.id + return isinstance(other, BlockMatrixPersistReader) and self.id == other.id def unpersisted(self): return self.original diff --git a/hail/python/hail/ir/blockmatrix_writer.py b/hail/python/hail/ir/blockmatrix_writer.py index 9f41f17f593..c2f7654deff 100644 --- a/hail/python/hail/ir/blockmatrix_writer.py +++ b/hail/python/hail/ir/blockmatrix_writer.py @@ -1,8 +1,8 @@ import abc import json -from ..typecheck import typecheck_method, sequenceof, nullable, enumeration -from ..expr.types import tvoid, tstr +from ..expr.types import tstr, tvoid +from ..typecheck import enumeration, nullable, sequenceof, typecheck_method from ..utils.misc import escape_str @@ -29,22 +29,26 @@ def __init__(self, path, overwrite, force_row_major, stage_locally): self.stage_locally = stage_locally def render(self): - writer = {'name': 'BlockMatrixNativeWriter', - 'path': self.path, - 'overwrite': self.overwrite, - 'forceRowMajor': self.force_row_major, - 'stageLocally': self.stage_locally} + writer = { + 'name': 'BlockMatrixNativeWriter', + 'path': self.path, + 'overwrite': self.overwrite, + 'forceRowMajor': self.force_row_major, + 'stageLocally': self.stage_locally, + } return escape_str(json.dumps(writer)) def _type(self): return tvoid def __eq__(self, other): - return isinstance(other, BlockMatrixNativeWriter) and \ - self.path == other.path and \ - self.overwrite == other.overwrite and \ - self.force_row_major == other.force_row_major and \ - self.stage_locally == other.stage_locally + return ( + isinstance(other, BlockMatrixNativeWriter) + and self.path == other.path + and self.overwrite == other.overwrite + and self.force_row_major == other.force_row_major + and self.stage_locally == other.stage_locally + ) class BlockMatrixBinaryWriter(BlockMatrixWriter): @@ -53,23 +57,18 @@ def __init__(self, path): self.path = path def render(self): - writer = {'name': 'BlockMatrixBinaryWriter', - 'path': self.path} + writer = {'name': 'BlockMatrixBinaryWriter', 'path': self.path} return escape_str(json.dumps(writer)) def _type(self): return tstr def __eq__(self, other): - return isinstance(other, BlockMatrixBinaryWriter) and \ - self.path == other.path + return isinstance(other, BlockMatrixBinaryWriter) and self.path == other.path class BlockMatrixRectanglesWriter(BlockMatrixWriter): - @typecheck_method(path=str, - rectangles=sequenceof(sequenceof(int)), - delimiter=str, - binary=bool) + @typecheck_method(path=str, rectangles=sequenceof(sequenceof(int)), delimiter=str, binary=bool) def __init__(self, path, rectangles, delimiter, binary): self.path = path self.rectangles = rectangles @@ -77,22 +76,26 @@ def __init__(self, path, rectangles, delimiter, binary): self.binary = binary def render(self): - writer = {'name': 'BlockMatrixRectanglesWriter', - 'path': self.path, - 'rectangles': self.rectangles, - 'delimiter': self.delimiter, - 'binary': self.binary} + writer = { + 'name': 'BlockMatrixRectanglesWriter', + 'path': self.path, + 'rectangles': self.rectangles, + 'delimiter': self.delimiter, + 'binary': self.binary, + } return escape_str(json.dumps(writer)) def _type(self): return tvoid def __eq__(self, other): - return isinstance(other, BlockMatrixRectanglesWriter) and \ - self.path == other.path and \ - self.rectangles == other.rectangles 
and \ - self.delimiter == other.delimiter and \ - self.binary == other.binary + return ( + isinstance(other, BlockMatrixRectanglesWriter) + and self.path == other.path + and self.rectangles == other.rectangles + and self.delimiter == other.delimiter + and self.binary == other.binary + ) class BlockMatrixMultiWriter(object): @@ -115,23 +118,30 @@ def __init__(self, prefix, overwrite): self.overwrite = overwrite def render(self): - writer = {'name': 'BlockMatrixBinaryMultiWriter', - 'prefix': self.prefix, - 'overwrite': self.overwrite} + writer = {'name': 'BlockMatrixBinaryMultiWriter', 'prefix': self.prefix, 'overwrite': self.overwrite} return escape_str(json.dumps(writer)) def _type(self): return tvoid def __eq__(self, other): - return isinstance(other, BlockMatrixBinaryMultiWriter) and \ - self.prefix == other.prefix and \ - self.overwrite == other.overwrite + return ( + isinstance(other, BlockMatrixBinaryMultiWriter) + and self.prefix == other.prefix + and self.overwrite == other.overwrite + ) class BlockMatrixTextMultiWriter(BlockMatrixMultiWriter): - @typecheck_method(prefix=str, overwrite=bool, delimiter=str, header=nullable(str), add_index=bool, - compression=nullable(enumeration('gz', 'bgz')), custom_filenames=nullable(sequenceof(str))) + @typecheck_method( + prefix=str, + overwrite=bool, + delimiter=str, + header=nullable(str), + add_index=bool, + compression=nullable(enumeration('gz', 'bgz')), + custom_filenames=nullable(sequenceof(str)), + ) def __init__(self, prefix, overwrite, delimiter, header, add_index, compression, custom_filenames): self.prefix = prefix self.overwrite = overwrite @@ -142,28 +152,32 @@ def __init__(self, prefix, overwrite, delimiter, header, add_index, compression, self.custom_filenames = custom_filenames def render(self): - writer = {'name': 'BlockMatrixTextMultiWriter', - 'prefix': self.prefix, - 'overwrite': self.overwrite, - 'delimiter': self.delimiter, - 'header': self.header, - 'addIndex': self.add_index, - 'compression': self.compression, - 'customFilenames': self.custom_filenames} + writer = { + 'name': 'BlockMatrixTextMultiWriter', + 'prefix': self.prefix, + 'overwrite': self.overwrite, + 'delimiter': self.delimiter, + 'header': self.header, + 'addIndex': self.add_index, + 'compression': self.compression, + 'customFilenames': self.custom_filenames, + } return escape_str(json.dumps(writer)) def _type(self): return tvoid def __eq__(self, other): - return isinstance(other, BlockMatrixTextMultiWriter) and \ - self.prefix == other.prefix and \ - self.overwrite == other.overwrite and \ - self.delimiter == other.overwrite and \ - self.header == other.header and \ - self.add_index == other.add_index and \ - self.compression == other.compression and \ - self.custom_filenames == other.custom_filenames + return ( + isinstance(other, BlockMatrixTextMultiWriter) + and self.prefix == other.prefix + and self.overwrite == other.overwrite + and self.delimiter == other.overwrite + and self.header == other.header + and self.add_index == other.add_index + and self.compression == other.compression + and self.custom_filenames == other.custom_filenames + ) class BlockMatrixPersistWriter(BlockMatrixWriter): @@ -173,18 +187,18 @@ def __init__(self, id, storage_level): self.storage_level = storage_level def render(self): - writer = {'name': 'BlockMatrixPersistWriter', - 'id': self.id, - 'storageLevel': self.storage_level} + writer = {'name': 'BlockMatrixPersistWriter', 'id': self.id, 'storageLevel': self.storage_level} return escape_str(json.dumps(writer)) def _type(self): 
return tvoid def __eq__(self, other): - return isinstance(other, BlockMatrixPersistWriter) and \ - self.id == other.id and \ - self.storage_level == other.storage_level + return ( + isinstance(other, BlockMatrixPersistWriter) + and self.id == other.id + and self.storage_level == other.storage_level + ) class BlockMatrixNativeMultiWriter(BlockMatrixMultiWriter): @@ -196,19 +210,23 @@ def __init__(self, prefix, overwrite, force_row_major, stage_locally): self.stage_locally = stage_locally def render(self): - writer = {'name': 'BlockMatrixNativeMultiWriter', - 'prefix': self.prefix, - 'overwrite': self.overwrite, - 'forceRowMajor': self.force_row_major, - 'stageLocally': self.stage_locally} + writer = { + 'name': 'BlockMatrixNativeMultiWriter', + 'prefix': self.prefix, + 'overwrite': self.overwrite, + 'forceRowMajor': self.force_row_major, + 'stageLocally': self.stage_locally, + } return escape_str(json.dumps(writer)) def _type(self): return tvoid def __eq__(self, other): - return isinstance(other, BlockMatrixNativeMultiWriter) and \ - self.prefix == other.prefix and \ - self.overwrite == other.overwrite and \ - self.force_row_major == other.force_row_major and \ - self.stage_locally == other.stage_locally + return ( + isinstance(other, BlockMatrixNativeMultiWriter) + and self.prefix == other.prefix + and self.overwrite == other.overwrite + and self.force_row_major == other.force_row_major + and self.stage_locally == other.stage_locally + ) diff --git a/hail/python/hail/ir/export_type.py b/hail/python/hail/ir/export_type.py index 3a9e99ec419..9c54179d971 100644 --- a/hail/python/hail/ir/export_type.py +++ b/hail/python/hail/ir/export_type.py @@ -7,8 +7,7 @@ class ExportType: PARALLEL_HEADER_IN_SHARD = "header_per_shard" PARALLEL_COMPOSABLE = "composable" - checker = enumeration("concatenated", "separate_header", "header_per_shard", - "composable") + checker = enumeration("concatenated", "separate_header", "header_per_shard", "composable") @staticmethod def default(export_type): diff --git a/hail/python/hail/ir/ir.py b/hail/python/hail/ir/ir.py index 5b54f1dd482..de3002de415 100644 --- a/hail/python/hail/ir/ir.py +++ b/hail/python/hail/ir/ir.py @@ -1,27 +1,56 @@ -from typing import Callable, Optional, TypeVar, cast -from typing_extensions import ParamSpec import base64 import copy import json from collections import defaultdict +from typing import Callable, Optional, TypeVar, cast -from hailtop.hail_decorator import decorator +from typing_extensions import ParamSpec import hail -from hail.expr.types import HailType, hail_type, tint32, tint64, \ - tfloat32, tfloat64, tstr, tbool, tarray, tstream, tndarray, tset, tdict, \ - tstruct, ttuple, tinterval, tvoid, trngstate, tlocus, tcall -from hail.ir.blockmatrix_writer import BlockMatrixWriter, BlockMatrixMultiWriter -from hail.typecheck import typecheck, typecheck_method, sequenceof, numeric, \ - sized_tupleof, nullable, tupleof, anytype, func_spec +from hail.expr.types import ( + HailType, + hail_type, + tarray, + tbool, + tcall, + tdict, + tfloat32, + tfloat64, + tint32, + tint64, + tinterval, + tlocus, + tndarray, + trngstate, + tset, + tstr, + tstream, + tstruct, + ttuple, + tvoid, +) +from hail.ir.blockmatrix_writer import BlockMatrixMultiWriter, BlockMatrixWriter +from hail.typecheck import ( + anytype, + func_spec, + nullable, + numeric, + sequenceof, + sized_tupleof, + tupleof, + typecheck, + typecheck_method, +) from hail.utils.java import Env, HailUserError from hail.utils.jsonx import dump_json -from hail.utils.misc import escape_str, 
parsable_strings, escape_id -from .base_ir import BaseIR, IR, TableIR, MatrixIR, BlockMatrixIR, _env_bind -from .matrix_writer import MatrixWriter, MatrixNativeMultiWriter -from .renderer import Renderer, Renderable, ParensRenderer +from hail.utils.misc import escape_id, escape_str, parsable_strings +from hailtop.hail_decorator import decorator + +from .base_ir import IR, BaseIR, BlockMatrixIR, MatrixIR, TableIR, _env_bind +from .matrix_writer import MatrixNativeMultiWriter, MatrixWriter +from .renderer import ParensRenderer, Renderable, Renderer from .table_writer import TableWriter -from .utils import default_row_uid, default_col_uid, unpack_row_uid, unpack_col_uid +from .utils import default_col_uid, default_row_uid, unpack_col_uid, unpack_row_uid class I32(IR): @@ -240,9 +269,7 @@ def __init__(self, cond, cnsq, altr): self.needs_randomness_handling = cnsq.needs_randomness_handling or altr.needs_randomness_handling def _handle_randomness(self, create_uids): - return If(self.cond, - self.cnsq.handle_randomness(create_uids), - self.altr.handle_randomness(create_uids)) + return If(self.cond, self.cnsq.handle_randomness(create_uids), self.altr.handle_randomness(create_uids)) @typecheck_method(cond=IR, cnsq=IR, altr=IR) def copy(self, cond, cnsq, altr): @@ -252,11 +279,11 @@ def _compute_type(self, env, agg_env, deep_typecheck): self.cond.compute_type(env, agg_env, deep_typecheck) self.cnsq.compute_type(env, agg_env, deep_typecheck) self.altr.compute_type(env, agg_env, deep_typecheck) - assert (self.cnsq.typ == self.altr.typ) + assert self.cnsq.typ == self.altr.typ return self.cnsq.typ def renderable_new_block(self, i): - return i == 1 or i == 2 + return i in {1, 2} class Coalesce(IR): @@ -301,7 +328,7 @@ def copy(self, value, body): return Let(self.name, value, body) def head_str(self): - return escape_id(self.name) + return f'eval {escape_id(self.name)}' @property def bound_variables(self): @@ -395,10 +422,12 @@ def _handle_randomness(self, create_uids): if create_uids: elt = Env.get_uid() uid = Env.get_uid() - return StreamZip([self, StreamIota(I32(0), I32(1))], - [elt, uid], - pack_uid(Cast(Ref(uid, tint32), tint64), Ref(elt, self.typ.element_type)), - 'TakeMinLength') + return StreamZip( + [self, StreamIota(I32(0), I32(1))], + [elt, uid], + pack_uid(Cast(Ref(uid, tint32), tint64), Ref(elt, self.typ.element_type)), + 'TakeMinLength', + ) else: tuple, uid, elt = unpack_uid(self.typ) return StreamMap(self, tuple, elt) @@ -754,8 +783,9 @@ def _handle_randomness(self, create_uids): @typecheck_method(start=IR, step=IR) def copy(self, start, step): - return StreamIota(start, step, - requires_memory_management_per_element=self.requires_memory_management_per_element) + return StreamIota( + start, step, requires_memory_management_per_element=self.requires_memory_management_per_element + ) def head_str(self): return f'{self.requires_memory_management_per_element}' @@ -767,10 +797,17 @@ def _compute_type(self, env, agg_env, deep_typecheck): class StreamRange(IR): - @typecheck_method(start=IR, stop=IR, step=IR, requires_memory_management_per_element=bool, - error_id=nullable(int), stack_trace=nullable(str)) - def __init__(self, start, stop, step, requires_memory_management_per_element=False, - error_id=None, stack_trace=None): + @typecheck_method( + start=IR, + stop=IR, + step=IR, + requires_memory_management_per_element=bool, + error_id=nullable(int), + stack_trace=nullable(str), + ) + def __init__( + self, start, stop, step, requires_memory_management_per_element=False, error_id=None, 
stack_trace=None + ): super().__init__(start, stop, step) self.start = start self.stop = stop @@ -809,8 +846,8 @@ def __init__(self, stream, group_size): self.needs_randomness_handling = stream.needs_randomness_handling def _handle_randomness(self, create_uids): - assert(not create_uids) - assert(self.stream.needs_randomness_handling) + assert not create_uids + assert self.stream.needs_randomness_handling self.stream.handle_randomness(False) @typecheck_method(stream=IR, group_size=IR) @@ -890,7 +927,7 @@ def _compute_type(self, env, agg_env, deep_typecheck): class NDArrayMap(IR): @typecheck_method(nd=IR, name=str, body=IR) def __init__(self, nd, name, body): - assert(not body.uses_randomness) + assert not body.uses_randomness super().__init__(nd, body) self.nd = nd self.name = name @@ -927,7 +964,9 @@ def renderable_bindings(self, i, default_value=None): class NDArrayMap2(IR): - @typecheck_method(left=IR, right=IR, lname=str, rname=str, body=IR, error_id=nullable(int), stack_trace=nullable(str)) + @typecheck_method( + left=IR, right=IR, lname=str, rname=str, body=IR, error_id=nullable(int), stack_trace=nullable(str) + ) def __init__(self, left, right, lname, rname, body, error_id=None, stack_trace=None): super().__init__(left, right, body) self.right = right @@ -948,8 +987,7 @@ def head_str(self): return f'{self._error_id} {escape_id(self.lname)} {escape_id(self.rname)}' def _eq(self, other): - return self.lname == other.lname and \ - self.rname == other.rname + return self.lname == other.lname and self.rname == other.rname @property def bound_variables(self): @@ -1008,8 +1046,7 @@ def _compute_type(self, env, agg_env, deep_typecheck): self.nd.compute_type(env, agg_env, deep_typecheck) self.slices.compute_type(env, agg_env, deep_typecheck) - return tndarray(self.nd.typ.element_type, - len([t for t in self.slices.typ.types if isinstance(t, ttuple)])) + return tndarray(self.nd.typ.element_type, len([t for t in self.slices.typ.types if isinstance(t, ttuple)])) class NDArrayReindex(IR): @@ -1083,8 +1120,8 @@ def _compute_type(self, env, agg_env, deep_typecheck): ndim = hail.linalg.utils.misc._ndarray_matmul_ndim(self.left.typ.ndim, self.right.typ.ndim) from hail.expr.expressions import unify_types - return tndarray(unify_types(self.left.typ.element_type, - self.right.typ.element_type), ndim) + + return tndarray(unify_types(self.left.typ.element_type, self.right.typ.element_type), ndim) class NDArrayQR(IR): @@ -1203,8 +1240,7 @@ def head_str(self): return self.axis def _eq(self, other): - return other.nds == self.nds and \ - other.axis == self.axis + return other.nds == self.nds and other.axis == self.axis def _compute_type(self, env, agg_env, deep_typecheck): self.nds.compute_type(env, agg_env, deep_typecheck) @@ -1361,7 +1397,12 @@ def _handle_randomness(self, create_uids): uid = Env.get_uid() elt = Env.get_uid() iota = StreamIota(I32(0), I32(1)) - return StreamZip([self, iota], [elt, uid], MakeTuple([Cast(Ref(uid, tint32), tint64), Ref(elt, self.typ.element_type)]), 'TakeMinLength') + return StreamZip( + [self, iota], + [elt, uid], + MakeTuple([Cast(Ref(uid, tint32), tint64), Ref(elt, self.typ.element_type)]), + 'TakeMinLength', + ) @typecheck_method(a=IR) def copy(self, a): @@ -1376,13 +1417,9 @@ def _compute_type(self, env, agg_env, deep_typecheck): class StreamZipJoinProducers(IR): - @typecheck_method(contexts=IR, - ctx_name=str, - make_producer=IR, - key=sequenceof(str), - cur_key=str, - cur_vals=str, - join_f=IR) + @typecheck_method( + contexts=IR, ctx_name=str, make_producer=IR, 
key=sequenceof(str), cur_key=str, cur_vals=str, join_f=IR + ) def __init__(self, contexts, ctx_name, make_producer, key, cur_key, cur_vals, join_f): super().__init__(contexts, make_producer, join_f) self.contexts = contexts @@ -1400,12 +1437,14 @@ def _handle_randomness(self, create_uids): @typecheck_method(new_ir=IR) def copy(self, *new_irs): assert len(new_irs) == 3 - return StreamZipJoinProducers(new_irs[0], self.ctx_name, new_irs[1], - self.key, self.cur_key, self.cur_vals, new_irs[2]) + return StreamZipJoinProducers( + new_irs[0], self.ctx_name, new_irs[1], self.key, self.cur_key, self.cur_vals, new_irs[2] + ) def head_str(self): - return '({}) {} {} {}'.format(' '.join([escape_id(x) for x in self.key]), self.ctx_name, - self.cur_key, self.cur_vals) + return '({}) {} {} {}'.format( + ' '.join([escape_id(x) for x in self.key]), self.ctx_name, self.cur_key, self.cur_vals + ) def _compute_type(self, env, agg_env, deep_typecheck): self.contexts.compute_type(env, agg_env, deep_typecheck) @@ -1459,7 +1498,9 @@ def copy(self, *new_irs): return StreamZipJoin(new_irs[:-1], self.key, self.cur_key, self.cur_vals, new_irs[-1]) def head_str(self): - return '{} ({}) {} {}'.format(len(self.streams), ' '.join([escape_id(x) for x in self.key]), self.cur_key, self.cur_vals) + return '{} ({}) {} {}'.format( + len(self.streams), ' '.join([escape_id(x) for x in self.key]), self.cur_key, self.cur_vals + ) def _compute_type(self, env, agg_env, deep_typecheck): for stream in self.streams: @@ -1544,8 +1585,7 @@ def copy(self, collection): def _compute_type(self, env, agg_env, deep_typecheck): self.collection.compute_type(env, agg_env, deep_typecheck) - return tdict(self.collection.typ.element_type.types[0], - tarray(self.collection.typ.element_type.types[1])) + return tdict(self.collection.typ.element_type.types[0], tarray(self.collection.typ.element_type.types[1])) uid_field_name = '__uid' @@ -1572,7 +1612,7 @@ def pad_uid(uid, type, tag=None): padding = padded - size if tag is not None: padding -= 1 - assert(padding >= 0) + assert padding >= 0 if size == 1: fields = (uid,) else: @@ -1605,11 +1645,11 @@ def unpack_uid(stream_type, name=None): tuple_type = stream_type.element_type tuple = Ref(name or Env.get_uid(), tuple_type) if isinstance(tuple_type, tstruct): - return \ - tuple.name, \ - GetField(tuple, uid_field_name), \ - SelectFields(tuple, [field for field in tuple_type.fields if - not field == uid_field_name]) + return ( + tuple.name, + GetField(tuple, uid_field_name), + SelectFields(tuple, [field for field in tuple_type.fields if not field == uid_field_name]), + ) else: return tuple.name, GetTupleElement(tuple, 0), GetTupleElement(tuple, 1) @@ -1623,9 +1663,9 @@ def pack_to_structs(stream): return stream uid = Env.get_uid() elt = Ref(uid, stream.typ.element_type) - return StreamMap(stream, uid, InsertFields(GetTupleElement(elt, 1), - [(uid_field_name, GetTupleElement(elt, 0))], - None)) + return StreamMap( + stream, uid, InsertFields(GetTupleElement(elt, 1), [(uid_field_name, GetTupleElement(elt, 0))], None) + ) def with_split_rng_state(ir, split, is_scan=None) -> 'BaseIR': @@ -1674,7 +1714,7 @@ def _handle_randomness(self, create_uids): return StreamMap(a, self.name, self.body) if isinstance(self.typ.element_type, tstream): - assert(self.body.uses_randomness and not create_uids) + assert self.body.uses_randomness and not create_uids a = self.a.handle_randomness(False) uid = Env.get_uid() elt = Env.get_uid() @@ -1685,7 +1725,7 @@ def _handle_randomness(self, create_uids): # There are occations when 
handle_randomness is called twice on a # `StreamMap`: once with `create_uids=False` and the second time # with `True`. In these cases, we only need to propagate the uid. - assert(create_uids) + assert create_uids _, uid, _ = unpack_uid(self.a.typ, self.name) new_body = pack_uid(uid, self.body) return StreamMap(self.a, self.name, new_body) @@ -1731,8 +1771,14 @@ def renderable_bindings(self, i, default_value=None): class StreamZip(IR): - @typecheck_method(streams=sequenceof(IR), names=sequenceof(str), body=IR, behavior=str, - error_id=nullable(int), stack_trace=nullable(str)) + @typecheck_method( + streams=sequenceof(IR), + names=sequenceof(str), + body=IR, + behavior=str, + error_id=nullable(int), + stack_trace=nullable(str), + ) def __init__(self, streams, names, body, behavior, error_id=None, stack_trace=None): super().__init__(*streams, body) self.streams = streams @@ -1743,7 +1789,9 @@ def __init__(self, streams, names, body, behavior, error_id=None, stack_trace=No self._stack_trace = stack_trace if error_id is None or stack_trace is None: self.save_error_info() - self.needs_randomness_handling = any(stream.needs_randomness_handling for stream in streams) or body.uses_randomness + self.needs_randomness_handling = ( + any(stream.needs_randomness_handling for stream in streams) or body.uses_randomness + ) def _handle_randomness(self, create_uids): if not self.body.uses_randomness and not create_uids: @@ -1764,14 +1812,19 @@ def _handle_randomness(self, create_uids): new_body = pack_uid(uid, new_body) return StreamZip(new_streams, tuples, new_body, self.behavior, self._error_id, self._stack_trace) - new_streams = [self.streams[0].handle_randomness(True), *(stream.handle_randomness(False) for stream in self.streams[1:])] + new_streams = [ + self.streams[0].handle_randomness(True), + *(stream.handle_randomness(False) for stream in self.streams[1:]), + ] tuple, uid, elt = unpack_uid(new_streams[0].typ) new_body = Let(self.names[0], elt, self.body) if self.body.uses_randomness: new_body = with_split_rng_state(new_body, uid) if create_uids: new_body = pack_uid(uid, new_body) - return StreamZip(new_streams, [tuple, *self.names[1:]], new_body, self.behavior, self._error_id, self._stack_trace) + return StreamZip( + new_streams, [tuple, *self.names[1:]], new_body, self.behavior, self._error_id, self._stack_trace + ) @typecheck_method(children=IR) def copy(self, *children): @@ -1795,7 +1848,10 @@ def _compute_type(self, env, agg_env, deep_typecheck): def renderable_bindings(self, i, default_value=None): if i == len(self.names): - return {name: default_value if default_value is not None else a.typ.element_type for name, a in zip(self.names, self.streams)} + return { + name: default_value if default_value is not None else a.typ.element_type + for name, a in zip(self.names, self.streams) + } else: return {} @@ -1934,8 +1990,7 @@ def head_str(self): return f'{escape_id(self.accum_name)} {escape_id(self.value_name)}' def _eq(self, other): - return other.accum_name == self.accum_name and \ - other.value_name == self.value_name + return other.accum_name == self.accum_name and other.value_name == self.value_name @property def bound_variables(self): @@ -1990,8 +2045,7 @@ def head_str(self): return f'{escape_id(self.accum_name)} {escape_id(self.value_name)}' def _eq(self, other): - return other.accum_name == self.accum_name and \ - other.value_name == self.value_name + return other.accum_name == self.accum_name and other.value_name == self.value_name @property def bound_variables(self): @@ -2014,8 +2068,19 
@@ def renderable_bindings(self, i, default_value=None): class StreamWhiten(IR): - @typecheck_method(stream=IR, new_chunk=str, prev_window=str, vec_size=int, window_size=int, chunk_size=int, block_size=int, normalize_after_whiten=bool) - def __init__(self, stream, new_chunk, prev_window, vec_size, window_size, chunk_size, block_size, normalize_after_whiten): + @typecheck_method( + stream=IR, + new_chunk=str, + prev_window=str, + vec_size=int, + window_size=int, + chunk_size=int, + block_size=int, + normalize_after_whiten=bool, + ) + def __init__( + self, stream, new_chunk, prev_window, vec_size, window_size, chunk_size, block_size, normalize_after_whiten + ): super().__init__(stream) self.stream = stream self.new_chunk = new_chunk @@ -2028,19 +2093,30 @@ def __init__(self, stream, new_chunk, prev_window, vec_size, window_size, chunk_ @typecheck_method(stream=IR) def copy(self, stream): - return StreamWhiten(stream, self.new_chunk, self.prev_window, self.vec_size, self.window_size, self.chunk_size, self.block_size, self.normalize_after_whiten) + return StreamWhiten( + stream, + self.new_chunk, + self.prev_window, + self.vec_size, + self.window_size, + self.chunk_size, + self.block_size, + self.normalize_after_whiten, + ) def head_str(self): return f'{escape_id(self.new_chunk)} {escape_id(self.prev_window)} {self.vec_size} {self.window_size} {self.chunk_size} {self.block_size} {self.normalize_after_whiten}' def _eq(self, other): - return other.new_chunk == self.new_chunk and \ - other.prev_window == self.prev_window and \ - other.vec_size == self.vec_size and \ - other.window_size == self.window_size and \ - other.chunk_size == self.chunk_size and \ - other.block_size == self.block_size and \ - other.normalize_after_whiten == self.normalize_after_whiten + return ( + other.new_chunk == self.new_chunk + and other.prev_window == self.prev_window + and other.vec_size == self.vec_size + and other.window_size == self.window_size + and other.chunk_size == self.chunk_size + and other.block_size == self.block_size + and other.normalize_after_whiten == self.normalize_after_whiten + ) def _compute_type(self, env, agg_env, deep_typecheck): self.stream._compute_type(env, agg_env, deep_typecheck) @@ -2048,7 +2124,9 @@ def _compute_type(self, env, agg_env, deep_typecheck): class StreamJoinRightDistinct(IR): - @typecheck_method(left=IR, right=IR, l_key=sequenceof(str), r_key=sequenceof(str), l_name=str, r_name=str, join=IR, join_type=str) + @typecheck_method( + left=IR, right=IR, l_key=sequenceof(str), r_key=sequenceof(str), l_name=str, r_name=str, join=IR, join_type=str + ) def __init__(self, left, right, l_key, r_key, l_name, r_name, join, join_type): super().__init__(left, right, join) self.left = left @@ -2059,15 +2137,19 @@ def __init__(self, left, right, l_key, r_key, l_name, r_name, join, join_type): self.r_name = r_name self.join = join self.join_type = join_type - self.needs_randomness_handling = left.needs_randomness_handling or right.needs_randomness_handling or join.uses_randomness + self.needs_randomness_handling = ( + left.needs_randomness_handling or right.needs_randomness_handling or join.uses_randomness + ) def _handle_randomness(self, create_uids): if not self.join.uses_randomness and not create_uids: left = self.left.handle_randomness(False) right = self.right.handle_randomness(False) - return StreamJoinRightDistinct(left, right, self.l_key, self.r_key, self.l_name, self.r_name, self.join, self.join_type) + return StreamJoinRightDistinct( + left, right, self.l_key, self.r_key, 
self.l_name, self.r_name, self.join, self.join_type + ) - if self.join_type == 'left' or self.join_type == 'inner': + if self.join_type in {'left', 'inner'}: left = pack_to_structs(self.left.handle_randomness(True)) right = self.right.handle_randomness(False) r_name = self.r_name @@ -2094,7 +2176,9 @@ def _handle_randomness(self, create_uids): @typecheck_method(left=IR, right=IR, join=IR) def copy(self, left, right, join): - return StreamJoinRightDistinct(left, right, self.l_key, self.r_key, self.l_name, self.r_name, join, self.join_type) + return StreamJoinRightDistinct( + left, right, self.l_key, self.r_key, self.l_name, self.r_name, join, self.join_type + ) def head_str(self): return '({}) ({}) {} {} {}'.format( @@ -2102,12 +2186,11 @@ def head_str(self): ' '.join([escape_id(x) for x in self.r_key]), self.l_name, self.r_name, - self.join_type) + self.join_type, + ) def _eq(self, other): - return other.l_name == self.l_name and \ - other.r_name == self.r_name and \ - other.join_type == self.join_type + return other.l_name == self.l_name and other.r_name == self.r_name and other.join_type == self.join_type @property def bound_variables(self): @@ -2116,11 +2199,9 @@ def bound_variables(self): def renderable_bindings(self, i, default_value=None): if i == 2: if default_value is None: - return {self.l_name: self.left.typ.element_type, - self.r_name: self.right.typ.element_type} + return {self.l_name: self.left.typ.element_type, self.r_name: self.right.typ.element_type} else: - return {self.l_name: default_value, - self.r_name: default_value} + return {self.l_name: default_value, self.r_name: default_value} else: return {} @@ -2404,12 +2485,17 @@ def head_str(self): return f'{escape_id(self.element_name)} {escape_id(self.index_name)} {self.is_scan} False' def _eq(self, other): - return self.element_name == other.element_name and self.index_name == other.index_name and self.is_scan == other.is_scan + return ( + self.element_name == other.element_name + and self.index_name == other.index_name + and self.is_scan == other.is_scan + ) def _compute_type(self, env, agg_env, deep_typecheck): self.array.compute_type(agg_env, None, deep_typecheck) - self.agg_ir.compute_type(_env_bind(env, self.bindings(1)), - _env_bind(agg_env, self.agg_bindings(1)), deep_typecheck) + self.agg_ir.compute_type( + _env_bind(env, self.bindings(1)), _env_bind(agg_env, self.agg_bindings(1)), deep_typecheck + ) return tarray(self.agg_ir.typ) @property @@ -2472,16 +2558,15 @@ def lookup_aggregator_return_type(name, init_args, seq_args): p.clear() for p in seq_params: p.clear() - if (all(p.unify(a) for p, a in zip(init_params, init_args)) - and all(p.unify(a) for p, a in zip(seq_params, seq_args))): + if all(p.unify(a) for p, a in zip(init_params, init_args)) and all( + p.unify(a) for p, a in zip(seq_params, seq_args) + ): return ret_type.subst() raise KeyError(f'aggregator {name}({ ",".join([str(t) for t in seq_args]) }) not found') class BaseApplyAggOp(IR): - @typecheck_method(agg_op=str, - init_op_args=sequenceof(IR), - seq_op_args=sequenceof(IR)) + @typecheck_method(agg_op=str, init_op_args=sequenceof(IR), seq_op_args=sequenceof(IR)) def __init__(self, agg_op, init_op_args, seq_op_args): super().__init__(*init_op_args, *seq_op_args) self.agg_op = agg_op @@ -2491,7 +2576,7 @@ def __init__(self, agg_op, init_op_args, seq_op_args): def copy(self, *args): new_instance = self.__class__ n_seq_op_args = len(self.seq_op_args) - init_op_args = args[:len(self.init_op_args)] + init_op_args = args[: len(self.init_op_args)] seq_op_args 
= args[-n_seq_op_args:] return new_instance(self.agg_op, init_op_args, seq_op_args) @@ -2503,10 +2588,7 @@ def render_head(self, r): return f'({self._ir_name()} {self.agg_op} ' def render_children(self, r): - return [ - ParensRenderer(self.init_op_args), - ParensRenderer(self.seq_op_args) - ] + return [ParensRenderer(self.init_op_args), ParensRenderer(self.seq_op_args)] @property def aggregations(self): @@ -2514,15 +2596,15 @@ def aggregations(self): return [self] def __eq__(self, other): - return isinstance(other, self.__class__) and \ - other.agg_op == self.agg_op and \ - other.init_op_args == self.init_op_args and \ - other.seq_op_args == self.seq_op_args + return ( + isinstance(other, self.__class__) + and other.agg_op == self.agg_op + and other.init_op_args == self.init_op_args + and other.seq_op_args == self.seq_op_args + ) def __hash__(self): - return hash(tuple([self.agg_op, - tuple(self.init_op_args), - tuple(self.seq_op_args)])) + return hash(tuple([self.agg_op, tuple(self.init_op_args), tuple(self.seq_op_args)])) def _compute_type(self, env, agg_env, deep_typecheck): for a in self.init_op_args: @@ -2531,9 +2613,8 @@ def _compute_type(self, env, agg_env, deep_typecheck): a.compute_type(agg_env, None, deep_typecheck) return lookup_aggregator_return_type( - self.agg_op, - [a.typ for a in self.init_op_args], - [a.typ for a in self.seq_op_args]) + self.agg_op, [a.typ for a in self.init_op_args], [a.typ for a in self.seq_op_args] + ) def renderable_new_block(self, i: int) -> bool: return i == 0 @@ -2549,9 +2630,7 @@ def uses_agg_capability(cls) -> bool: class ApplyAggOp(BaseApplyAggOp): - @typecheck_method(agg_op=str, - init_op_args=sequenceof(IR), - seq_op_args=sequenceof(IR)) + @typecheck_method(agg_op=str, init_op_args=sequenceof(IR), seq_op_args=sequenceof(IR)) def __init__(self, agg_op, init_op_args, seq_op_args): super().__init__(agg_op, init_op_args, seq_op_args) @@ -2560,9 +2639,7 @@ def renderable_uses_agg_context(self, i: int): class ApplyScanOp(BaseApplyAggOp): - @typecheck_method(agg_op=str, - init_op_args=sequenceof(IR), - seq_op_args=sequenceof(IR)) + @typecheck_method(agg_op=str, init_op_args=sequenceof(IR), seq_op_args=sequenceof(IR)) def __init__(self, agg_op, init_op_args, seq_op_args): super().__init__(agg_op, init_op_args, seq_op_args) @@ -2602,7 +2679,7 @@ def _compute_type(self, env, agg_env, deep_typecheck): def renderable_bindings(self, i: int, default_value=None): dict_so_far = {} - if i == 1 or i == 2: + if i in {1, 2}: if default_value is None: dict_so_far[self.accum_name] = self.zero.typ else: @@ -2624,10 +2701,10 @@ def bound_variables(self): return {self.accum_name, self.other_accum_name} | super().bound_variables def renderable_uses_agg_context(self, i: int) -> bool: - return (i == 1 or i == 2) and not self.is_scan + return (i in {1, 2}) and not self.is_scan def renderable_uses_scan_context(self, i: int) -> bool: - return (i == 1 or i == 2) and self.is_scan + return (i in {1, 2}) and self.is_scan class Begin(IR): @@ -2659,8 +2736,7 @@ def render_children(self, r): return [InsertFields.IFRenderField(escape_id(f), x) for f, x in self.fields] def __eq__(self, other): - return isinstance(other, MakeStruct) \ - and other.fields == self.fields + return isinstance(other, MakeStruct) and other.fields == self.fields def __hash__(self): return hash(tuple(self.fields)) @@ -2717,7 +2793,7 @@ def _ir_name(self): def _compute_type(self, env, agg_env, deep_typecheck): if deep_typecheck: self.ref.compute_type(env, agg_env, deep_typecheck) - 
assert(self.ref.typ._select_fields(self._typ.fields) == self._typ) + assert self.ref.typ._select_fields(self._typ.fields) == self._typ return self._typ @@ -2780,16 +2856,17 @@ def copy(self, *args): def render_children(self, r): return [ self.old, - hail.ir.RenderableStr( - 'None' if self.field_order is None else parsable_strings(self.field_order)), - *(InsertFields.IFRenderField(escape_id(f), x) for f, x in self.fields) + hail.ir.RenderableStr('None' if self.field_order is None else parsable_strings(self.field_order)), + *(InsertFields.IFRenderField(escape_id(f), x) for f, x in self.fields), ] def __eq__(self, other): - return isinstance(other, InsertFields) and \ - other.old == self.old and \ - other.fields == self.fields and \ - other.field_order == self.field_order + return ( + isinstance(other, InsertFields) + and other.old == self.old + and other.fields == self.fields + and other.field_order == self.field_order + ) def __hash__(self): return hash((self.old, tuple(self.fields), tuple(self.field_order) if self.field_order else None)) @@ -2855,7 +2932,7 @@ def _ir_name(self): def _compute_type(self, env, agg_env, deep_typecheck): if deep_typecheck: self.ref.compute_type(env, agg_env, deep_typecheck) - assert(self.ref.typ[self.field] == self._typ) + assert self.ref.typ[self.field] == self._typ return self._typ @@ -2984,7 +3061,6 @@ def register_seeded_function(name, param_types, ret_type): def udf(*param_types: HailType) -> Callable[[Callable[P, T]], Callable[P, T]]: - uid = Env.get_uid() @decorator @@ -3001,9 +3077,23 @@ def wrapper(__original_func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) class Apply(IR): - @typecheck_method(function=str, return_type=hail_type, args=IR, - error_id=nullable(int), stack_trace=nullable(str), type_args=tupleof(hail_type)) - def __init__(self, function, return_type, *args, type_args=(), error_id=None, stack_trace=None,): + @typecheck_method( + function=str, + return_type=hail_type, + args=IR, + error_id=nullable(int), + stack_trace=nullable(str), + type_args=tupleof(hail_type), + ) + def __init__( + self, + function, + return_type, + *args, + type_args=(), + error_id=None, + stack_trace=None, + ): super().__init__(*args) self.function = function self.return_type = return_type @@ -3015,16 +3105,25 @@ def __init__(self, function, return_type, *args, type_args=(), error_id=None, st self.save_error_info() def copy(self, *args): - return Apply(self.function, self.return_type, *args, type_args=self.type_args, error_id=self._error_id, stack_trace=self._stack_trace,) + return Apply( + self.function, + self.return_type, + *args, + type_args=self.type_args, + error_id=self._error_id, + stack_trace=self._stack_trace, + ) def head_str(self): type_args = "(" + " ".join([a._parsable_string() for a in self.type_args]) + ")" return f'{self._error_id} {escape_id(self.function)} {type_args} {self.return_type._parsable_string()}' def _eq(self, other): - return other.function == self.function and \ - other.type_args == self.type_args and \ - other.return_type == self.return_type + return ( + other.function == self.function + and other.type_args == self.type_args + and other.return_type == self.return_type + ) def _compute_type(self, env, agg_env, deep_typecheck): for arg in self.args: @@ -3050,9 +3149,11 @@ def head_str(self): return f'{escape_id(self.function)} {self.static_rng_uid} {self.return_type._parsable_string()}' def _eq(self, other): - return other.function == self.function and \ - other.static_rng_uid == self.static_rng_uid and \ - other.return_type == 
self.return_type + return ( + other.function == self.function + and other.static_rng_uid == self.static_rng_uid + and other.return_type == self.return_type + ) def _compute_type(self, env, agg_env, deep_typecheck): for arg in self.args: @@ -3134,8 +3235,7 @@ def copy(self, child): def _compute_type(self, env, agg_env, deep_typecheck): self.child.compute_type(deep_typecheck) - return tstruct(**{'rows': tarray(self.child.typ.row_type), - 'global': self.child.typ.global_type}) + return tstruct(**{'rows': tarray(self.child.typ.row_type), 'global': self.child.typ.global_type}) class TableAggregate(IR): @@ -3254,10 +3354,22 @@ def row_type(self): class GVCFPartitionReader(PartitionReader): - entries_field_name = '__entries' - def __init__(self, header, call_fields, entry_float_type, array_elements_required, rg, contig_recoding, - skip_invalid_loci, filter, find, replace, uid_field): + + def __init__( + self, + header, + call_fields, + entry_float_type, + array_elements_required, + rg, + contig_recoding, + skip_invalid_loci, + filter, + find, + replace, + uid_field, + ): self.header = header self.call_fields = call_fields self.entry_float_type = entry_float_type @@ -3271,49 +3383,57 @@ def __init__(self, header, call_fields, entry_float_type, array_elements_require self.uid_field = uid_field def with_uid_field(self, uid_field): - return GVCFPartitionReader(self.header, - self.call_fields, - self.entry_float_type, - self.array_elements_required, - self.rg, - self.contig_recoding, - self.skip_invalid_loci, - self.filter, - self.find, - self.replace, - uid_field) + return GVCFPartitionReader( + self.header, + self.call_fields, + self.entry_float_type, + self.array_elements_required, + self.rg, + self.contig_recoding, + self.skip_invalid_loci, + self.filter, + self.find, + self.replace, + uid_field, + ) def render(self): - return escape_str(json.dumps({"name": "GVCFPartitionReader", - "header": {"name": "VCFHeaderInfo", **self.header}, - "callFields": list(self.call_fields), - "entryFloatType": "Float64" if self.entry_float_type == tfloat64 else "Float32", - "arrayElementsRequired": self.array_elements_required, - "rg": self.rg.name if self.rg is not None else None, - "contigRecoding": self.contig_recoding, - "filterAndReplace": { - "name": "TextInputFilterAndReplace", - "filter": self.filter, - "find": self.find, - "replace": self.replace, - }, - "skipInvalidLoci": self.skip_invalid_loci, - "entriesFieldName": GVCFPartitionReader.entries_field_name, - "uidFieldName": self.uid_field if self.uid_field is not None else '__dummy'})) + return escape_str( + json.dumps({ + "name": "GVCFPartitionReader", + "header": {"name": "VCFHeaderInfo", **self.header}, + "callFields": list(self.call_fields), + "entryFloatType": "Float64" if self.entry_float_type == tfloat64 else "Float32", + "arrayElementsRequired": self.array_elements_required, + "rg": self.rg.name if self.rg is not None else None, + "contigRecoding": self.contig_recoding, + "filterAndReplace": { + "name": "TextInputFilterAndReplace", + "filter": self.filter, + "find": self.find, + "replace": self.replace, + }, + "skipInvalidLoci": self.skip_invalid_loci, + "entriesFieldName": GVCFPartitionReader.entries_field_name, + "uidFieldName": self.uid_field if self.uid_field is not None else '__dummy', + }) + ) def _eq(self, other): - return isinstance(other, GVCFPartitionReader) \ - and self.header == other.header \ - and self.call_fields == other.call_fields \ - and self.entry_float_type == other.entry_float_type \ - and self.array_elements_required == 
other.array_elements_required \ - and self.rg == other.rg \ - and self.contig_recoding == other.contig_recoding \ - and self.skip_invalid_loci == other.skip_invalid_loci \ - and self.filter == other.filter \ - and self.find == other.find \ - and self.replace == other.replace \ - and self.uid_field == other.uid_field + return ( + isinstance(other, GVCFPartitionReader) + and self.header == other.header + and self.call_fields == other.call_fields + and self.entry_float_type == other.entry_float_type + and self.array_elements_required == other.array_elements_required + and self.rg == other.rg + and self.contig_recoding == other.contig_recoding + and self.skip_invalid_loci == other.skip_invalid_loci + and self.filter == other.filter + and self.find == other.find + and self.replace == other.replace + and self.uid_field == other.uid_field + ) def row_type(self): if self.uid_field is not None: @@ -3336,15 +3456,20 @@ def subst_format(name, t): return tarray(subst_format(name, t.element_type)) return t - return tstruct(locus=tstruct(contig=tstr, position=tint32) if self.rg is None else tlocus(self.rg), - alleles=tarray(tstr), - rsid=tstr, - qual=tfloat64, - filters=tset(tstr), - info=tstruct(**{k: parse_type(v) for k, v in self.header['infoFields']}), - **{GVCFPartitionReader.entries_field_name: tarray( - tstruct(**{k: subst_format(k, parse_type(v)) for k, v in self.header['formatFields']}))}, - **uid_fd) + return tstruct( + locus=tstruct(contig=tstr, position=tint32) if self.rg is None else tlocus(self.rg), + alleles=tarray(tstr), + rsid=tstr, + qual=tfloat64, + filters=tset(tstr), + info=tstruct(**{k: parse_type(v) for k, v in self.header['infoFields']}), + **{ + GVCFPartitionReader.entries_field_name: tarray( + tstruct(**{k: subst_format(k, parse_type(v)) for k, v in self.header['formatFields']}) + ) + }, + **uid_fd, + ) class PartitionNativeIntervalReader(PartitionReader): @@ -3357,13 +3482,20 @@ def with_uid_field(self, uid_field): return PartitionNativeIntervalReader(self.path, self.table_row_type, uid_field) def render(self): - return escape_str(json.dumps({"name": "PartitionNativeIntervalReader", - "path": self.path, - "uidFieldName": self.uid_field if self.uid_field is not None else '__dummy'})) + return escape_str( + json.dumps({ + "name": "PartitionNativeIntervalReader", + "path": self.path, + "uidFieldName": self.uid_field if self.uid_field is not None else '__dummy', + }) + ) def _eq(self, other): - return isinstance(other, - PartitionNativeIntervalReader) and self.path == other.path and self.uid_field == other.uid_field + return ( + isinstance(other, PartitionNativeIntervalReader) + and self.path == other.path + and self.uid_field == other.uid_field + ) def row_type(self): if self.uid_field is None: @@ -3628,8 +3760,7 @@ def _compute_type(self, env, agg_env, deep_typecheck): class Literal(IR): - @typecheck_method(typ=hail_type, - value=anytype) + @typecheck_method(typ=hail_type, value=anytype) def __init__(self, typ, value): super(Literal, self).__init__() self._typ: HailType = typ @@ -3642,8 +3773,7 @@ def head_str(self): return f'{self._typ._parsable_string()} {dump_json(self._typ._convert_to_json_na(self.value))}' def _eq(self, other): - return other._typ == self._typ and \ - other.value == self.value + return other._typ == self._typ and other.value == self.value def _compute_type(self, env, agg_env, deep_typecheck): return self._typ @@ -3651,7 +3781,7 @@ def _compute_type(self, env, agg_env, deep_typecheck): class EncodedLiteral(IR): @typecheck_method(typ=hail_type, value=anytype, 
encoded_value=nullable(str)) - def __init__(self, typ, value, *, encoded_value = None): + def __init__(self, typ, value, *, encoded_value=None): super(EncodedLiteral, self).__init__() self._typ: HailType = typ self._value = value @@ -3670,8 +3800,7 @@ def head_str(self): return f'{self._typ._parsable_string()} "{self.encoded_value}"' def _eq(self, other): - return other._typ == self._typ and \ - other.encoded_value == self.encoded_value + return other._typ == self._typ and other.encoded_value == self.encoded_value def _compute_type(self, env, agg_env, deep_typecheck): return self._typ @@ -3694,10 +3823,9 @@ def _compute_type(self, env, agg_env, deep_typecheck): class Join(IR): _idx = 0 - @typecheck_method(virtual_ir=IR, - temp_vars=sequenceof(str), - join_exprs=sequenceof(anytype), - join_func=func_spec(1, anytype)) + @typecheck_method( + virtual_ir=IR, temp_vars=sequenceof(str), join_exprs=sequenceof(anytype), join_func=func_spec(1, anytype) + ) def __init__(self, virtual_ir, temp_vars, join_exprs, join_func): super(Join, self).__init__(virtual_ir) self.virtual_ir = virtual_ir @@ -3710,10 +3838,7 @@ def __init__(self, virtual_ir, temp_vars, join_exprs, join_func): def copy(self, virtual_ir): # FIXME: This is pretty fucked, Joins should probably be tracked on Expression? new_instance = self.__class__ - new_instance = new_instance(virtual_ir, - self.temp_vars, - self.join_exprs, - self.join_func) + new_instance = new_instance(virtual_ir, self.temp_vars, self.join_exprs, self.join_func) new_instance.idx = self.idx return new_instance @@ -3744,6 +3869,7 @@ def __init__(self, ir_id): def __del__(self): from hail.backend.py4j_backend import Py4JBackend + if Env._hc: backend = Env.backend() assert isinstance(backend, Py4JBackend) @@ -3783,71 +3909,59 @@ def delete(env, name): if isinstance(ir, Ref): return env.get(ir.name, ir) elif isinstance(ir, Let): - return Let(ir.name, - _subst(ir.value), - _subst(ir.body, delete(env, ir.name))) + return Let(ir.name, _subst(ir.value), _subst(ir.body, delete(env, ir.name))) elif isinstance(ir, AggLet): - return AggLet(ir.name, - _subst(ir.value, agg_env, {}), - _subst(ir.body, delete(env, ir.name)), - ir.is_scan) + return AggLet(ir.name, _subst(ir.value, agg_env, {}), _subst(ir.body, delete(env, ir.name)), ir.is_scan) elif isinstance(ir, StreamMap): - return StreamMap(_subst(ir.a), - ir.name, - _subst(ir.body, delete(env, ir.name))) + return StreamMap(_subst(ir.a), ir.name, _subst(ir.body, delete(env, ir.name))) elif isinstance(ir, StreamFilter): - return StreamFilter(_subst(ir.a), - ir.name, - _subst(ir.body, delete(env, ir.name))) + return StreamFilter(_subst(ir.a), ir.name, _subst(ir.body, delete(env, ir.name))) elif isinstance(ir, StreamFlatMap): - return StreamFlatMap(_subst(ir.a), - ir.name, - _subst(ir.body, delete(env, ir.name))) + return StreamFlatMap(_subst(ir.a), ir.name, _subst(ir.body, delete(env, ir.name))) elif isinstance(ir, StreamFold): - return StreamFold(_subst(ir.a), - _subst(ir.zero), - ir.accum_name, - ir.value_name, - _subst(ir.body, delete(delete(env, ir.accum_name), ir.value_name))) + return StreamFold( + _subst(ir.a), + _subst(ir.zero), + ir.accum_name, + ir.value_name, + _subst(ir.body, delete(delete(env, ir.accum_name), ir.value_name)), + ) elif isinstance(ir, StreamScan): - return StreamScan(_subst(ir.a), - _subst(ir.zero), - ir.accum_name, - ir.value_name, - _subst(ir.body, delete(delete(env, ir.accum_name), ir.value_name))) + return StreamScan( + _subst(ir.a), + _subst(ir.zero), + ir.accum_name, + ir.value_name, + 
_subst(ir.body, delete(delete(env, ir.accum_name), ir.value_name)), + ) elif isinstance(ir, StreamFor): - return StreamFor(_subst(ir.a), - ir.value_name, - _subst(ir.body, delete(env, ir.value_name))) + return StreamFor(_subst(ir.a), ir.value_name, _subst(ir.body, delete(env, ir.value_name))) elif isinstance(ir, AggFilter): - return AggFilter(_subst(ir.cond, agg_env), - _subst(ir.agg_ir, agg_env), - ir.is_scan) + return AggFilter(_subst(ir.cond, agg_env), _subst(ir.agg_ir, agg_env), ir.is_scan) elif isinstance(ir, AggExplode): - return AggExplode(_subst(ir.s, agg_env), - ir.name, - _subst(ir.agg_body, delete(agg_env, ir.name), delete(agg_env, ir.name)), - ir.is_scan) + return AggExplode( + _subst(ir.s, agg_env), + ir.name, + _subst(ir.agg_body, delete(agg_env, ir.name), delete(agg_env, ir.name)), + ir.is_scan, + ) elif isinstance(ir, AggGroupBy): - return AggGroupBy(_subst(ir.key, agg_env), - _subst(ir.agg_ir, agg_env), - ir.is_scan) + return AggGroupBy(_subst(ir.key, agg_env), _subst(ir.agg_ir, agg_env), ir.is_scan) elif isinstance(ir, ApplyAggOp): subst_init_op_args = [x.map_ir(lambda x: _subst(x)) for x in ir.init_op_args] subst_seq_op_args = [subst(x, agg_env, {}) for x in ir.seq_op_args] - return ApplyAggOp(ir.agg_op, - subst_init_op_args, - subst_seq_op_args) + return ApplyAggOp(ir.agg_op, subst_init_op_args, subst_seq_op_args) elif isinstance(ir, AggFold): subst_seq_op = subst(ir.seq_op, agg_env, {}) return AggFold(ir.zero, subst_seq_op, ir.comb_op, ir.accum_name, ir.other_accum_name, ir.is_scan) elif isinstance(ir, AggArrayPerElement): - return AggArrayPerElement(_subst(ir.array, agg_env), - ir.element_name, - ir.index_name, - _subst(ir.agg_ir, delete(env, ir.index_name), - delete(agg_env, ir.element_name)), - ir.is_scan) + return AggArrayPerElement( + _subst(ir.array, agg_env), + ir.element_name, + ir.index_name, + _subst(ir.agg_ir, delete(env, ir.index_name), delete(agg_env, ir.element_name)), + ir.is_scan, + ) else: assert isinstance(ir, IR) return ir.map_ir(lambda x: _subst(x)) diff --git a/hail/python/hail/ir/matrix_ir.py b/hail/python/hail/ir/matrix_ir.py index 06644337b19..006ea2e6dc9 100644 --- a/hail/python/hail/ir/matrix_ir.py +++ b/hail/python/hail/ir/matrix_ir.py @@ -1,12 +1,21 @@ from typing import Optional + import hail as hl from hail.expr.types import HailType, tint64 +from hail.ir import ir from hail.ir.base_ir import BaseIR, MatrixIR -from hail.ir.utils import modify_deep_field, zip_with_index, zip_with_index_field, default_row_uid, default_col_uid, unpack_row_uid, unpack_col_uid -import hail.ir.ir as ir -from hail.utils.misc import escape_str, parsable_strings, escape_id -from hail.utils.jsonx import dump_json +from hail.ir.utils import ( + default_col_uid, + default_row_uid, + modify_deep_field, + unpack_col_uid, + unpack_row_uid, + zip_with_index, + zip_with_index_field, +) from hail.utils.java import Env +from hail.utils.jsonx import dump_json +from hail.utils.misc import escape_id, escape_str, parsable_strings class MatrixAggregateRowsByKey(MatrixIR): @@ -42,23 +51,20 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): entry_expr = ir.AggLet('sa', old_col, entry_expr, is_scan=False) entry_expr = ir.Let('sa', old_col, entry_expr) if self.entry_expr.uses_value_randomness: - entry_expr = ir.Let('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(first_row_uid, col_uid)), - entry_expr) + entry_expr = ir.Let( + '__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(first_row_uid, col_uid)), entry_expr + ) if 
self.entry_expr.uses_agg_randomness(is_scan=False): - entry_expr = ir.AggLet('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(row_uid, col_uid)), - entry_expr, - is_scan=False) + entry_expr = ir.AggLet( + '__rng_state', + ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(row_uid, col_uid)), + entry_expr, + is_scan=False, + ) if self.row_expr.uses_value_randomness: - row_expr = ir.Let('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(), first_row_uid), - row_expr) + row_expr = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), first_row_uid), row_expr) if self.row_expr.uses_agg_randomness(is_scan=False): - row_expr = ir.AggLet('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(), row_uid), - row_expr, - is_scan=False) + row_expr = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), row_uid), row_expr, is_scan=False) result = MatrixAggregateRowsByKey(child, entry_expr, row_expr) if drop_row_uid: @@ -80,7 +86,8 @@ def _compute_type(self, deep_typecheck): child_typ.col_key, child_typ.row_key_type._concat(self.row_expr.typ), child_typ.row_key, - self.entry_expr.typ) + self.entry_expr.typ, + ) def renderable_bindings(self, i, default_value=None): if i == 1: @@ -104,14 +111,16 @@ def renderable_agg_bindings(self, i, default_value=None): class MatrixRead(MatrixIR): - def __init__(self, - reader, - drop_cols: bool = False, - drop_rows: bool = False, - drop_row_uids: bool = True, - drop_col_uids: bool = True, - *, - _assert_type: Optional[HailType] = None): + def __init__( + self, + reader, + drop_cols: bool = False, + drop_rows: bool = False, + drop_row_uids: bool = True, + drop_col_uids: bool = True, + *, + _assert_type: Optional[HailType] = None, + ): super().__init__() self.reader = reader self.drop_cols = drop_cols @@ -145,8 +154,8 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): else: row = ir.Ref('va', self.typ.row_type) result = MatrixMapRows( - result, - ir.InsertFields(row, [(row_uid_field_name, ir.GetField(row, default_row_uid))], None)) + result, ir.InsertFields(row, [(row_uid_field_name, ir.GetField(row, default_row_uid))], None) + ) if rename_col_uid: if self.drop_col_uids: rename = True @@ -156,7 +165,8 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): result = MatrixMapCols( result, ir.InsertFields(col, [(col_uid_field_name, ir.GetField(col, default_col_uid))], None), - None) + None, + ) if rename: result = MatrixRename(result, {}, col_map, row_map, {}) return result @@ -173,11 +183,13 @@ def render_head(self, r): return f'(MatrixRead {reqType} {self.drop_cols} {self.drop_rows} "{self.reader.render(r)}"' def _eq(self, other): - return (self.reader == other.reader - and self.drop_cols == other.drop_cols - and self.drop_rows == other.drop_rows - and self.drop_row_uids == other.drop_row_uids - and self.drop_col_uids == other.drop_col_uids) + return ( + self.reader == other.reader + and self.drop_cols == other.drop_cols + and self.drop_rows == other.drop_rows + and self.drop_row_uids == other.drop_row_uids + and self.drop_col_uids == other.drop_col_uids + ) def _compute_type(self, deep_typecheck): if self._type is None: @@ -294,7 +306,8 @@ def _compute_type(self, deep_typecheck): self.new_key if self.new_key is not None else child_typ.col_key, child_typ.row_type, child_typ.row_key, - child_typ.entry_type) + child_typ.entry_type, + ) def renderable_bindings(self, i, default_value=None): if i == 1: @@ -331,14 +344,20 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): (left_uid, _) = 
unpack_col_uid(left.typ.col_type, col_uid_field_name) (right_uid, _) = unpack_col_uid(right.typ.col_type, col_uid_field_name) uid_type = ir.unify_uid_types((left_uid.typ, right_uid.typ), tag=True) - left = MatrixMapCols(left, - ir.InsertFields(ir.Ref('sa', left.typ.col_type), - [(col_uid_field_name, ir.pad_uid(left_uid, uid_type, 0))], None), - new_key=None) - right = MatrixMapCols(right, - ir.InsertFields(ir.Ref('sa', right.typ.col_type), - [(col_uid_field_name, ir.pad_uid(right_uid, uid_type, 1))], None), - new_key=None) + left = MatrixMapCols( + left, + ir.InsertFields( + ir.Ref('sa', left.typ.col_type), [(col_uid_field_name, ir.pad_uid(left_uid, uid_type, 0))], None + ), + new_key=None, + ) + right = MatrixMapCols( + right, + ir.InsertFields( + ir.Ref('sa', right.typ.col_type), [(col_uid_field_name, ir.pad_uid(right_uid, uid_type, 1))], None + ), + new_key=None, + ) result = MatrixUnionCols(left, right, self.join_type) @@ -370,7 +389,8 @@ def _compute_type(self, deep_typecheck): col_key=left_typ.col_key, row_type=left_typ.row_type._concat(right_typ.row_value_type), row_key=left_typ.row_key, - entry_type=left_typ.entry_type) + entry_type=left_typ.entry_type, + ) class MatrixMapEntries(MatrixIR): @@ -399,7 +419,9 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): col_uid, old_col = unpack_col_uid(child.typ.col_type, col_uid_field_name) new_entry = ir.Let('sa', old_col, new_entry) if self.new_entry.uses_value_randomness: - new_entry = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(row_uid, col_uid)), new_entry) + new_entry = ir.Let( + '__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(row_uid, col_uid)), new_entry + ) result = MatrixMapEntries(child, new_entry) if drop_row_uid: _, old_row = unpack_row_uid(result.typ.row_type, row_uid_field_name) @@ -419,7 +441,8 @@ def _compute_type(self, deep_typecheck): child_typ.col_key, child_typ.row_type, child_typ.row_key, - self.new_entry.typ) + self.new_entry.typ, + ) def renderable_bindings(self, i, default_value=None): return self.child.typ.entry_env(default_value) if i == 1 else {} @@ -482,9 +505,7 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): return MatrixKeyRowsBy(child, self.keys, self.is_sorted) def head_str(self): - return '({}) {}'.format( - ' '.join([escape_id(x) for x in self.keys]), - self.is_sorted) + return '({}) {}'.format(' '.join([escape_id(x) for x in self.keys]), self.is_sorted) def _eq(self, other): return self.keys == other.keys and self.is_sorted == other.is_sorted @@ -498,7 +519,8 @@ def _compute_type(self, deep_typecheck): child_typ.col_key, child_typ.row_type, self.keys, - child_typ.entry_type) + child_typ.entry_type, + ) class MatrixMapRows(MatrixIR): @@ -520,8 +542,9 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): if row_uid_field_name is None: row_uid_field_name = default_row_uid child = self.child.handle_randomness(row_uid_field_name, col_uid_field_name) - row_uid, old_row = unpack_row_uid(child.typ.row_type, row_uid_field_name, - drop_uid=row_uid_field_name not in self.child.typ.row_type) + row_uid, old_row = unpack_row_uid( + child.typ.row_type, row_uid_field_name, drop_uid=row_uid_field_name not in self.child.typ.row_type + ) new_row = ir.Let('va', old_row, self.new_row) if col_uid_field_name is not None: col_uid, old_col = unpack_col_uid(child.typ.col_type, col_uid_field_name) @@ -549,7 +572,8 @@ def _compute_type(self, deep_typecheck): child_typ.col_key, self.new_row.typ, child_typ.row_key, - 
child_typ.entry_type) + child_typ.entry_type, + ) def renderable_bindings(self, i, default_value=None): if i == 1: @@ -590,7 +614,8 @@ def _compute_type(self, deep_typecheck): child_typ.col_key, child_typ.row_type, child_typ.row_key, - child_typ.entry_type) + child_typ.entry_type, + ) def renderable_bindings(self, i, default_value=None): return self.child.typ.global_env(default_value) if i == 1 else {} @@ -651,11 +676,13 @@ def _compute_type(self, deep_typecheck): return hl.tmatrix( child_typ.global_type, child_typ.col_key_type._concat( - hl.tstruct(**{f: hl.tarray(t) for f, t in child_typ.col_value_type.items()})), + hl.tstruct(**{f: hl.tarray(t) for f, t in child_typ.col_value_type.items()}) + ), child_typ.col_key, child_typ.row_type, child_typ.row_key, - hl.tstruct(**{f: hl.tarray(t) for f, t in child_typ.entry_type.items()})) + hl.tstruct(**{f: hl.tarray(t) for f, t in child_typ.entry_type.items()}), + ) class MatrixAggregateColsByKey(MatrixIR): @@ -691,25 +718,20 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): col_expr = ir.AggLet('sa', old_col, col_expr, is_scan=False) col_expr = ir.InsertFields(col_expr, [(col_uid_field_name, first_col_uid)], None) if self.entry_expr.uses_value_randomness: - entry_expr = ir.Let('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(), - ir.concat_uids(row_uid, first_col_uid)), - entry_expr) + entry_expr = ir.Let( + '__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(row_uid, first_col_uid)), entry_expr + ) if self.entry_expr.uses_agg_randomness(is_scan=False): - entry_expr = ir.AggLet('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(), - ir.concat_uids(row_uid, col_uid)), - entry_expr, - is_scan=False) + entry_expr = ir.AggLet( + '__rng_state', + ir.RNGSplit(ir.RNGStateLiteral(), ir.concat_uids(row_uid, col_uid)), + entry_expr, + is_scan=False, + ) if self.col_expr.uses_value_randomness: - col_expr = ir.Let('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(), first_col_uid), - col_expr) + col_expr = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), first_col_uid), col_expr) if self.col_expr.uses_agg_randomness(is_scan=False): - col_expr = ir.AggLet('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(), col_uid), - col_expr, - is_scan=False) + col_expr = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), col_uid), col_expr, is_scan=False) result = MatrixAggregateColsByKey(child, entry_expr, col_expr) if drop_row_uid: @@ -731,7 +753,8 @@ def _compute_type(self, deep_typecheck): child_typ.col_key, child_typ.row_type, child_typ.row_key, - self.entry_expr.typ) + self.entry_expr.typ, + ) def renderable_bindings(self, i, default_value=None): if i == 1: @@ -777,7 +800,19 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): ir.Ref('va', new_explode.typ.row_type), self.path, lambda tuple: ir.GetTupleElement(tuple, 0), - lambda row, tuple: ir.InsertFields(row, [(row_uid_field_name, ir.concat_uids(ir.GetField(row, row_uid_field_name), ir.Cast(ir.GetTupleElement(tuple, 1), tint64)))], None)) + lambda row, tuple: ir.InsertFields( + row, + [ + ( + row_uid_field_name, + ir.concat_uids( + ir.GetField(row, row_uid_field_name), ir.Cast(ir.GetTupleElement(tuple, 1), tint64) + ), + ) + ], + None, + ), + ) return MatrixMapRows(new_explode, new_row) def head_str(self): @@ -797,7 +832,8 @@ def _compute_type(self, deep_typecheck): child_typ.col_key, new_row_type, child_typ.row_key, - child_typ.entry_type) + child_typ.entry_type, + ) class MatrixRepartition(MatrixIR): @@ -808,7 +844,9 @@ def __init__(self, child, n, 
strategy): self.strategy = strategy def _handle_randomness(self, row_uid_field_name, col_uid_field_name): - return MatrixRepartition(self.child.handle_randomness(row_uid_field_name, col_uid_field_name), self.n, self.strategy) + return MatrixRepartition( + self.child.handle_randomness(row_uid_field_name, col_uid_field_name), self.n, self.strategy + ) def head_str(self): return f'{self.n} {self.strategy}' @@ -827,17 +865,23 @@ def __init__(self, *children): self.children = children def _handle_randomness(self, row_uid_field_name, col_uid_field_name): - children = [self.children[0].handle_randomness(row_uid_field_name, col_uid_field_name), - *[child.handle_randomness(row_uid_field_name, None) for child in self.children[1:]]] + children = [ + self.children[0].handle_randomness(row_uid_field_name, col_uid_field_name), + *[child.handle_randomness(row_uid_field_name, None) for child in self.children[1:]], + ] if row_uid_field_name is not None: uids = [uid for uid, _ in (unpack_row_uid(child.typ.row_type, row_uid_field_name) for child in children)] uid_type = ir.unify_uid_types((uid.typ for uid in uids), tag=True) - children = [MatrixMapRows(child, - ir.InsertFields(ir.Ref('va', child.typ.row_type), - [(row_uid_field_name, ir.pad_uid(uid, uid_type, i))], - None)) - for i, (child, uid) in enumerate(zip(children, uids))] + children = [ + MatrixMapRows( + child, + ir.InsertFields( + ir.Ref('va', child.typ.row_type), [(row_uid_field_name, ir.pad_uid(uid, uid_type, i))], None + ), + ) + for i, (child, uid) in enumerate(zip(children, uids)) + ] return MatrixUnionRows(*children) @@ -963,7 +1007,19 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): ir.Ref('sa', new_explode.typ.col_type), self.path, lambda tuple: ir.GetTupleElement(tuple, 0), - lambda col, tuple: ir.InsertFields(col, [(col_uid_field_name, ir.concat_uids(ir.GetField(col, col_uid_field_name), ir.Cast(ir.GetTupleElement(tuple, 1), tint64)))], None)) + lambda col, tuple: ir.InsertFields( + col, + [ + ( + col_uid_field_name, + ir.concat_uids( + ir.GetField(col, col_uid_field_name), ir.Cast(ir.GetTupleElement(tuple, 1), tint64) + ), + ) + ], + None, + ), + ) return MatrixMapCols(new_explode, new_col, None) def head_str(self): @@ -983,7 +1039,8 @@ def _compute_type(self, deep_typecheck): child_typ.col_key, child_typ.row_type, child_typ.row_key, - child_typ.entry_type) + child_typ.entry_type, + ) class CastTableToMatrix(MatrixIR): @@ -996,29 +1053,33 @@ def __init__(self, child, entries_field_name, cols_field_name, col_key): def _handle_randomness(self, row_uid_field_name, col_uid_field_name): from hail.ir.table_ir import TableMapGlobals + child = self.child if col_uid_field_name is not None: new_globals = modify_deep_field( ir.Ref('global', child.typ.global_type), [self.cols_field_name], - lambda g: zip_with_index_field(g, col_uid_field_name)) + lambda g: zip_with_index_field(g, col_uid_field_name), + ) child = TableMapGlobals(child, new_globals) - return CastTableToMatrix(child.handle_randomness(row_uid_field_name), - self.entries_field_name, - self.cols_field_name, - self.col_key) + return CastTableToMatrix( + child.handle_randomness(row_uid_field_name), self.entries_field_name, self.cols_field_name, self.col_key + ) def head_str(self): return '{} {} ({})'.format( escape_str(self.entries_field_name), escape_str(self.cols_field_name), - ' '.join([escape_id(id) for id in self.col_key])) + ' '.join([escape_id(id) for id in self.col_key]), + ) def _eq(self, other): - return self.entries_field_name == other.entries_field_name 
and \ - self.cols_field_name == other.cols_field_name and \ - self.col_key == other.col_key + return ( + self.entries_field_name == other.entries_field_name + and self.cols_field_name == other.cols_field_name + and self.col_key == other.col_key + ) def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) @@ -1029,7 +1090,8 @@ def _compute_type(self, deep_typecheck): self.col_key, child_typ.row_type._drop_fields([self.entries_field_name]), child_typ.row_key, - child_typ.row_type[self.entries_field_name].element_type) + child_typ.row_type[self.entries_field_name].element_type, + ) class MatrixAnnotateRowsTable(MatrixIR): @@ -1045,7 +1107,8 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): self.child.handle_randomness(row_uid_field_name, col_uid_field_name), self.table.handle_randomness(None), self.root, - self.product) + self.product, + ) def head_str(self): return f'"{escape_str(self.root)}" {self.product}' @@ -1067,7 +1130,8 @@ def _compute_type(self, deep_typecheck): child_typ.col_key, child_typ.row_type._insert_field(self.root, value_type), child_typ.row_key, - child_typ.entry_type) + child_typ.entry_type, + ) class MatrixAnnotateColsTable(MatrixIR): @@ -1081,7 +1145,8 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): return MatrixAnnotateColsTable( self.child.handle_randomness(row_uid_field_name, col_uid_field_name), self.table.handle_randomness(None), - self.root) + self.root, + ) def head_str(self): return f'"{escape_str(self.root)}"' @@ -1099,7 +1164,8 @@ def _compute_type(self, deep_typecheck): child_typ.col_key, child_typ.row_type, child_typ.row_key, - child_typ.entry_type) + child_typ.entry_type, + ) class MatrixToMatrixApply(MatrixIR): @@ -1150,20 +1216,24 @@ def _handle_randomness(self, row_uid_field_name, col_uid_field_name): return MatrixRename(child, self.global_map, self.col_map, self.row_map, self.entry_map) def head_str(self): - return f'{parsable_strings(self.global_map.keys())} ' \ - f'{parsable_strings(self.global_map.values())} ' \ - f'{parsable_strings(self.col_map.keys())} ' \ - f'{parsable_strings(self.col_map.values())} ' \ - f'{parsable_strings(self.row_map.keys())} ' \ - f'{parsable_strings(self.row_map.values())} ' \ - f'{parsable_strings(self.entry_map.keys())} ' \ - f'{parsable_strings(self.entry_map.values())} ' + return ( + f'{parsable_strings(self.global_map.keys())} ' + f'{parsable_strings(self.global_map.values())} ' + f'{parsable_strings(self.col_map.keys())} ' + f'{parsable_strings(self.col_map.values())} ' + f'{parsable_strings(self.row_map.keys())} ' + f'{parsable_strings(self.row_map.values())} ' + f'{parsable_strings(self.entry_map.keys())} ' + f'{parsable_strings(self.entry_map.values())} ' + ) def _eq(self, other): - return self.global_map == other.global_map and \ - self.col_map == other.col_map and \ - self.row_map == other.row_map and \ - self.entry_map == other.entry_map + return ( + self.global_map == other.global_map + and self.col_map == other.col_map + and self.row_map == other.row_map + and self.entry_map == other.entry_map + ) def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) diff --git a/hail/python/hail/ir/matrix_reader.py b/hail/python/hail/ir/matrix_reader.py index 08386b94484..b88cbc16c44 100644 --- a/hail/python/hail/ir/matrix_reader.py +++ b/hail/python/hail/ir/matrix_reader.py @@ -1,13 +1,12 @@ import abc import json -from .utils import make_filter_and_replace, impute_type_of_partition_interval_array from ..expr.types import 
HailType, tfloat32, tfloat64 from ..genetics.reference_genome import reference_genome_type -from ..typecheck import (typecheck_method, sequenceof, nullable, enumeration, anytype, oneof, - dictof, sized_tupleof) +from ..typecheck import anytype, dictof, enumeration, nullable, oneof, sequenceof, sized_tupleof, typecheck_method from ..utils import wrap_to_list from ..utils.misc import escape_str +from .utils import impute_type_of_partition_interval_array, make_filter_and_replace class MatrixReader(object): @@ -21,17 +20,14 @@ def __eq__(self, other): class MatrixNativeReader(MatrixReader): - @typecheck_method(path=str, - intervals=nullable(sequenceof(anytype)), - filter_intervals=bool) + @typecheck_method(path=str, intervals=nullable(sequenceof(anytype)), filter_intervals=bool) def __init__(self, path, intervals, filter_intervals): self.path = path self.filter_intervals = filter_intervals self.intervals, self._interval_type = impute_type_of_partition_interval_array(intervals) def render(self, r): - reader = {'name': 'MatrixNativeReader', - 'path': self.path} + reader = {'name': 'MatrixNativeReader', 'path': self.path} if self.intervals is not None: assert self._interval_type is not None reader['options'] = { @@ -43,74 +39,82 @@ def render(self, r): return escape_str(json.dumps(reader)) def __eq__(self, other): - return isinstance(other, MatrixNativeReader) and \ - other.path == self.path and \ - other.intervals == self.intervals and \ - other.filter_intervals == self.filter_intervals + return ( + isinstance(other, MatrixNativeReader) + and other.path == self.path + and other.intervals == self.intervals + and other.filter_intervals == self.filter_intervals + ) class MatrixRangeReader(MatrixReader): - @typecheck_method(n_rows=int, - n_cols=int, - n_partitions=nullable(int)) + @typecheck_method(n_rows=int, n_cols=int, n_partitions=nullable(int)) def __init__(self, n_rows, n_cols, n_partitions): self.n_rows = n_rows self.n_cols = n_cols self.n_partitions = n_partitions def render(self, r): - reader = {'name': 'MatrixRangeReader', - 'nRows': self.n_rows, - 'nCols': self.n_cols, - 'nPartitions': self.n_partitions} + reader = { + 'name': 'MatrixRangeReader', + 'nRows': self.n_rows, + 'nCols': self.n_cols, + 'nPartitions': self.n_partitions, + } return escape_str(json.dumps(reader)) def __eq__(self, other): - return isinstance(other, MatrixRangeReader) and \ - other.n_rows == self.n_rows and \ - other.n_cols == self.n_cols and \ - other.n_partitions == self.n_partitions + return ( + isinstance(other, MatrixRangeReader) + and other.n_rows == self.n_rows + and other.n_cols == self.n_cols + and other.n_partitions == self.n_partitions + ) class MatrixVCFReader(MatrixReader): - @typecheck_method(path=oneof(str, sequenceof(str)), - call_fields=oneof(str, sequenceof(str)), - entry_float_type=enumeration(tfloat32, tfloat64), - header_file=nullable(str), - n_partitions=nullable(int), - block_size=nullable(int), - min_partitions=nullable(int), - reference_genome=nullable(reference_genome_type), - contig_recoding=nullable(dictof(str, str)), - array_elements_required=bool, - skip_invalid_loci=bool, - force_bgz=bool, - force_gz=bool, - filter=nullable(str), - find_replace=nullable(sized_tupleof(str, str)), - _sample_ids=nullable(sequenceof(str)), - _partitions_json=nullable(str), - _partitions_type=nullable(HailType)) - def __init__(self, - path, - call_fields, - entry_float_type, - header_file, - n_partitions, - block_size, - min_partitions, - reference_genome, - contig_recoding, - array_elements_required, - 
skip_invalid_loci, - force_bgz, - force_gz, - filter, - find_replace, - *, - _sample_ids=None, - _partitions_json=None, - _partitions_type=None): + @typecheck_method( + path=oneof(str, sequenceof(str)), + call_fields=oneof(str, sequenceof(str)), + entry_float_type=enumeration(tfloat32, tfloat64), + header_file=nullable(str), + n_partitions=nullable(int), + block_size=nullable(int), + min_partitions=nullable(int), + reference_genome=nullable(reference_genome_type), + contig_recoding=nullable(dictof(str, str)), + array_elements_required=bool, + skip_invalid_loci=bool, + force_bgz=bool, + force_gz=bool, + filter=nullable(str), + find_replace=nullable(sized_tupleof(str, str)), + _sample_ids=nullable(sequenceof(str)), + _partitions_json=nullable(str), + _partitions_type=nullable(HailType), + ) + def __init__( + self, + path, + call_fields, + entry_float_type, + header_file, + n_partitions, + block_size, + min_partitions, + reference_genome, + contig_recoding, + array_elements_required, + skip_invalid_loci, + force_bgz, + force_gz, + filter, + find_replace, + *, + _sample_ids=None, + _partitions_json=None, + _partitions_type=None, + ): self.path = wrap_to_list(path) self.header_file = header_file self.n_partitions = n_partitions @@ -131,53 +135,61 @@ def __init__(self, self._partitions_type = _partitions_type def render(self, r): - reader = {'name': 'MatrixVCFReader', - 'files': self.path, - 'callFields': self.call_fields, - 'entryFloatTypeName': self.entry_float_type, - 'headerFile': self.header_file, - 'nPartitions': self.n_partitions, - 'blockSizeInMB': self.block_size, - 'minPartitions': self.min_partitions, - 'rg': self.reference_genome.name if self.reference_genome else None, - 'contigRecoding': self.contig_recoding if self.contig_recoding else {}, - 'arrayElementsRequired': self.array_elements_required, - 'skipInvalidLoci': self.skip_invalid_loci, - 'gzAsBGZ': self.force_bgz, - 'forceGZ': self.force_gz, - 'filterAndReplace': make_filter_and_replace(self.filter, self.find_replace), - 'sampleIDs': self._sample_ids, - 'partitionsTypeStr': self._partitions_type._parsable_string() if self._partitions_type is not None else None, - 'partitionsJSON': self._partitions_json} + reader = { + 'name': 'MatrixVCFReader', + 'files': self.path, + 'callFields': self.call_fields, + 'entryFloatTypeName': self.entry_float_type, + 'headerFile': self.header_file, + 'nPartitions': self.n_partitions, + 'blockSizeInMB': self.block_size, + 'minPartitions': self.min_partitions, + 'rg': self.reference_genome.name if self.reference_genome else None, + 'contigRecoding': self.contig_recoding if self.contig_recoding else {}, + 'arrayElementsRequired': self.array_elements_required, + 'skipInvalidLoci': self.skip_invalid_loci, + 'gzAsBGZ': self.force_bgz, + 'forceGZ': self.force_gz, + 'filterAndReplace': make_filter_and_replace(self.filter, self.find_replace), + 'sampleIDs': self._sample_ids, + 'partitionsTypeStr': self._partitions_type._parsable_string() + if self._partitions_type is not None + else None, + 'partitionsJSON': self._partitions_json, + } return escape_str(json.dumps(reader)) def __eq__(self, other): - return isinstance(other, MatrixVCFReader) and \ - other.path == self.path and \ - other.call_fields == self.call_fields and \ - other.entry_float_type == self.entry_float_type and \ - other.header_file == self.header_file and \ - other.min_partitions == self.min_partitions and \ - other.reference_genome == self.reference_genome and \ - other.contig_recoding == self.contig_recoding and \ - 
other.array_elements_required == self.array_elements_required and \ - other.skip_invalid_loci == self.skip_invalid_loci and \ - other.force_bgz == self.force_bgz and \ - other.force_gz == self.force_gz and \ - other.filter == self.filter and \ - other.find_replace == self.find_replace and \ - other._partitions_json == self._partitions_json and \ - other._partitions_type == self._partitions_type and \ - other._sample_ids == self._sample_ids + return ( + isinstance(other, MatrixVCFReader) + and other.path == self.path + and other.call_fields == self.call_fields + and other.entry_float_type == self.entry_float_type + and other.header_file == self.header_file + and other.min_partitions == self.min_partitions + and other.reference_genome == self.reference_genome + and other.contig_recoding == self.contig_recoding + and other.array_elements_required == self.array_elements_required + and other.skip_invalid_loci == self.skip_invalid_loci + and other.force_bgz == self.force_bgz + and other.force_gz == self.force_gz + and other.filter == self.filter + and other.find_replace == self.find_replace + and other._partitions_json == self._partitions_json + and other._partitions_type == self._partitions_type + and other._sample_ids == self._sample_ids + ) class MatrixBGENReader(MatrixReader): - @typecheck_method(path=oneof(str, sequenceof(str)), - sample_file=nullable(str), - index_file_map=nullable(dictof(str, str)), - n_partitions=nullable(int), - block_size=nullable(int), - included_variants=nullable(str)) + @typecheck_method( + path=oneof(str, sequenceof(str)), + sample_file=nullable(str), + index_file_map=nullable(dictof(str, str)), + n_partitions=nullable(int), + block_size=nullable(int), + included_variants=nullable(str), + ) def __init__(self, path, sample_file, index_file_map, n_partitions, block_size, included_variants): self.path = wrap_to_list(path) self.sample_file = sample_file @@ -187,34 +199,60 @@ def __init__(self, path, sample_file, index_file_map, n_partitions, block_size, self.included_variants = included_variants def render(self, r): - reader = {'name': 'MatrixBGENReader', - 'files': self.path, - 'sampleFile': self.sample_file, - 'indexFileMap': self.index_file_map, - 'nPartitions': self.n_partitions, - 'blockSizeInMB': self.block_size, - 'includedVariants': self.included_variants - } + reader = { + 'name': 'MatrixBGENReader', + 'files': self.path, + 'sampleFile': self.sample_file, + 'indexFileMap': self.index_file_map, + 'nPartitions': self.n_partitions, + 'blockSizeInMB': self.block_size, + 'includedVariants': self.included_variants, + } return escape_str(json.dumps(reader)) def __eq__(self, other): - return isinstance(other, MatrixBGENReader) and \ - other.path == self.path and \ - other.sample_file == self.sample_file and \ - other.index_file_map == self.index_file_map and \ - other.block_size == self.block_size and \ - other.included_variants == self.included_variants + return ( + isinstance(other, MatrixBGENReader) + and other.path == self.path + and other.sample_file == self.sample_file + and other.index_file_map == self.index_file_map + and other.block_size == self.block_size + and other.included_variants == self.included_variants + ) class MatrixPLINKReader(MatrixReader): - @typecheck_method(bed=str, bim=str, fam=str, - n_partitions=nullable(int), block_size=nullable(int), min_partitions=nullable(int), - missing=str, delimiter=str, quant_pheno=bool, - a2_reference=bool, reference_genome=nullable(reference_genome_type), - contig_recoding=nullable(dictof(str, str)), 
skip_invalid_loci=bool) - def __init__(self, bed, bim, fam, n_partitions, block_size, min_partitions, - missing, delimiter, quant_pheno, a2_reference, reference_genome, - contig_recoding, skip_invalid_loci): + @typecheck_method( + bed=str, + bim=str, + fam=str, + n_partitions=nullable(int), + block_size=nullable(int), + min_partitions=nullable(int), + missing=str, + delimiter=str, + quant_pheno=bool, + a2_reference=bool, + reference_genome=nullable(reference_genome_type), + contig_recoding=nullable(dictof(str, str)), + skip_invalid_loci=bool, + ) + def __init__( + self, + bed, + bim, + fam, + n_partitions, + block_size, + min_partitions, + missing, + delimiter, + quant_pheno, + a2_reference, + reference_genome, + contig_recoding, + skip_invalid_loci, + ): self.bed = bed self.bim = bim self.fam = fam @@ -230,32 +268,36 @@ def __init__(self, bed, bim, fam, n_partitions, block_size, min_partitions, self.skip_invalid_loci = skip_invalid_loci def render(self, r): - reader = {'name': 'MatrixPLINKReader', - 'bed': self.bed, - 'bim': self.bim, - 'fam': self.fam, - 'nPartitions': self.n_partitions, - 'blockSizeInMB': self.block_size, - 'minPartitions': self.min_partitions, - 'missing': self.missing, - 'delimiter': self.delimiter, - 'quantPheno': self.quant_pheno, - 'a2Reference': self.a2_reference, - 'rg': self.reference_genome.name if self.reference_genome else None, - 'contigRecoding': self.contig_recoding if self.contig_recoding else {}, - 'skipInvalidLoci': self.skip_invalid_loci} + reader = { + 'name': 'MatrixPLINKReader', + 'bed': self.bed, + 'bim': self.bim, + 'fam': self.fam, + 'nPartitions': self.n_partitions, + 'blockSizeInMB': self.block_size, + 'minPartitions': self.min_partitions, + 'missing': self.missing, + 'delimiter': self.delimiter, + 'quantPheno': self.quant_pheno, + 'a2Reference': self.a2_reference, + 'rg': self.reference_genome.name if self.reference_genome else None, + 'contigRecoding': self.contig_recoding if self.contig_recoding else {}, + 'skipInvalidLoci': self.skip_invalid_loci, + } return escape_str(json.dumps(reader)) def __eq__(self, other): - return isinstance(other, MatrixPLINKReader) and \ - other.bed == self.bed and \ - other.bim == self.bim and \ - other.fam == self.fam and \ - other.min_partitions == self.min_partitions and \ - other.missing == self.missing and \ - other.delimiter == self.delimiter and \ - other.quant_pheno == self.quant_pheno and \ - other.a2_reference == self.a2_reference and \ - other.reference_genome == self.reference_genome and \ - other.contig_recoding == self.contig_recoding and \ - other.skip_invalid_loci == self.skip_invalid_loci + return ( + isinstance(other, MatrixPLINKReader) + and other.bed == self.bed + and other.bim == self.bim + and other.fam == self.fam + and other.min_partitions == self.min_partitions + and other.missing == self.missing + and other.delimiter == self.delimiter + and other.quant_pheno == self.quant_pheno + and other.a2_reference == self.a2_reference + and other.reference_genome == self.reference_genome + and other.contig_recoding == self.contig_recoding + and other.skip_invalid_loci == self.skip_invalid_loci + ) diff --git a/hail/python/hail/ir/matrix_writer.py b/hail/python/hail/ir/matrix_writer.py index 1000071f719..6bec21148aa 100644 --- a/hail/python/hail/ir/matrix_writer.py +++ b/hail/python/hail/ir/matrix_writer.py @@ -1,7 +1,9 @@ import abc import json + from hail.expr.types import hail_type -from ..typecheck import typecheck_method, nullable, dictof, sequenceof + +from ..typecheck import dictof, nullable, 
sequenceof, typecheck_method from ..utils.misc import escape_str from .export_type import ExportType @@ -17,12 +19,14 @@ def __eq__(self, other): class MatrixNativeWriter(MatrixWriter): - @typecheck_method(path=str, - overwrite=bool, - stage_locally=bool, - codec_spec=nullable(str), - partitions=nullable(str), - partitions_type=nullable(hail_type)) + @typecheck_method( + path=str, + overwrite=bool, + stage_locally=bool, + codec_spec=nullable(str), + partitions=nullable(str), + partitions_type=nullable(hail_type), + ) def __init__(self, path, overwrite, stage_locally, codec_spec, partitions, partitions_type): self.path = path self.overwrite = overwrite @@ -32,32 +36,37 @@ def __init__(self, path, overwrite, stage_locally, codec_spec, partitions, parti self.partitions_type = partitions_type def render(self): - writer = {'name': 'MatrixNativeWriter', - 'path': self.path, - 'overwrite': self.overwrite, - 'stageLocally': self.stage_locally, - 'codecSpecJSONStr': self.codec_spec, - 'partitions': self.partitions, - 'partitionsTypeStr': self.partitions_type._parsable_string() if self.partitions_type is not None else None - } + writer = { + 'name': 'MatrixNativeWriter', + 'path': self.path, + 'overwrite': self.overwrite, + 'stageLocally': self.stage_locally, + 'codecSpecJSONStr': self.codec_spec, + 'partitions': self.partitions, + 'partitionsTypeStr': self.partitions_type._parsable_string() if self.partitions_type is not None else None, + } return escape_str(json.dumps(writer)) def __eq__(self, other): - return isinstance(other, MatrixNativeWriter) and \ - other.path == self.path and \ - other.overwrite == self.overwrite and \ - other.stage_locally == self.stage_locally and \ - other.codec_spec == self.codec_spec and \ - other.partitions == self.partitions and \ - other.partitions_type == self.partitions_type + return ( + isinstance(other, MatrixNativeWriter) + and other.path == self.path + and other.overwrite == self.overwrite + and other.stage_locally == self.stage_locally + and other.codec_spec == self.codec_spec + and other.partitions == self.partitions + and other.partitions_type == self.partitions_type + ) class MatrixVCFWriter(MatrixWriter): - @typecheck_method(path=str, - append=nullable(str), - export_type=ExportType.checker, - metadata=nullable(dictof(str, dictof(str, dictof(str, str)))), - tabix=bool) + @typecheck_method( + path=str, + append=nullable(str), + export_type=ExportType.checker, + metadata=nullable(dictof(str, dictof(str, dictof(str, str)))), + tabix=bool, + ) def __init__(self, path, append, export_type, metadata, tabix): self.path = path self.append = append @@ -66,40 +75,39 @@ def __init__(self, path, append, export_type, metadata, tabix): self.tabix = tabix def render(self): - writer = {'name': 'MatrixVCFWriter', - 'path': self.path, - 'append': self.append, - 'exportType': self.export_type, - 'metadata': self.metadata, - 'tabix': self.tabix} + writer = { + 'name': 'MatrixVCFWriter', + 'path': self.path, + 'append': self.append, + 'exportType': self.export_type, + 'metadata': self.metadata, + 'tabix': self.tabix, + } return escape_str(json.dumps(writer)) def __eq__(self, other): - return isinstance(other, MatrixVCFWriter) and \ - other.path == self.path and \ - other.append == self.append and \ - other.export_type == self.export_type and \ - other.metadata == self.metadata and \ - other.tabix == self.tabix + return ( + isinstance(other, MatrixVCFWriter) + and other.path == self.path + and other.append == self.append + and other.export_type == self.export_type + and 
other.metadata == self.metadata + and other.tabix == self.tabix + ) class MatrixGENWriter(MatrixWriter): - @typecheck_method(path=str, - precision=int) + @typecheck_method(path=str, precision=int) def __init__(self, path, precision): self.path = path self.precision = precision def render(self): - writer = {'name': 'MatrixGENWriter', - 'path': self.path, - 'precision': self.precision} + writer = {'name': 'MatrixGENWriter', 'path': self.path, 'precision': self.precision} return escape_str(json.dumps(writer)) def __eq__(self, other): - return isinstance(other, MatrixGENWriter) and \ - other.path == self.path and \ - other.precision == self.precision + return isinstance(other, MatrixGENWriter) and other.path == self.path and other.precision == self.precision class MatrixBGENWriter(MatrixWriter): @@ -110,17 +118,21 @@ def __init__(self, path, export_type, compression_codec): self.compression_codec = compression_codec def render(self): - writer = {'name': 'MatrixBGENWriter', - 'path': self.path, - 'exportType': self.export_type, - 'compressionCodec': self.compression_codec} + writer = { + 'name': 'MatrixBGENWriter', + 'path': self.path, + 'exportType': self.export_type, + 'compressionCodec': self.compression_codec, + } return escape_str(json.dumps(writer)) def __eq__(self, other): - return isinstance(other, MatrixBGENWriter) and \ - other.path == self.path and \ - other.export_type == self.export_type and \ - other.compression_codec == self.compression_codec + return ( + isinstance(other, MatrixBGENWriter) + and other.path == self.path + and other.export_type == self.export_type + and other.compression_codec == self.compression_codec + ) class MatrixPLINKWriter(MatrixWriter): @@ -129,13 +141,11 @@ def __init__(self, path): self.path = path def render(self): - writer = {'name': 'MatrixPLINKWriter', - 'path': self.path} + writer = {'name': 'MatrixPLINKWriter', 'path': self.path} return escape_str(json.dumps(writer)) def __eq__(self, other): - return isinstance(other, MatrixPLINKWriter) and \ - other.path == self.path + return isinstance(other, MatrixPLINKWriter) and other.path == self.path class MatrixBlockMatrixWriter(MatrixWriter): @@ -147,24 +157,27 @@ def __init__(self, path, overwrite, entry_field, block_size): self.block_size = block_size def render(self): - writer = {'name': 'MatrixBlockMatrixWriter', - 'path': self.path, - 'overwrite': self.overwrite, - 'entryField': self.entry_field, - 'blockSize': self.block_size} + writer = { + 'name': 'MatrixBlockMatrixWriter', + 'path': self.path, + 'overwrite': self.overwrite, + 'entryField': self.entry_field, + 'blockSize': self.block_size, + } return escape_str(json.dumps(writer)) def __eq__(self, other): - return isinstance(other, MatrixBlockMatrixWriter) and \ - other.path == self.path and other.overwrite == self.overwrite and \ - other.entry_field == self.entry_field and other.block_size == self.block_size + return ( + isinstance(other, MatrixBlockMatrixWriter) + and other.path == self.path + and other.overwrite == self.overwrite + and other.entry_field == self.entry_field + and other.block_size == self.block_size + ) class MatrixNativeMultiWriter(object): - @typecheck_method(paths=sequenceof(str), - overwrite=bool, - stage_locally=bool, - codec_spec=nullable(str)) + @typecheck_method(paths=sequenceof(str), overwrite=bool, stage_locally=bool, codec_spec=nullable(str)) def __init__(self, paths, overwrite, stage_locally, codec_spec): self.paths = paths self.overwrite = overwrite @@ -172,16 +185,20 @@ def __init__(self, paths, overwrite, 
stage_locally, codec_spec):
         self.codec_spec = codec_spec
     def render(self):
-        writer = {'name': 'MatrixNativeMultiWriter',
-                  'paths': list(self.paths),
-                  'overwrite': self.overwrite,
-                  'stageLocally': self.stage_locally,
-                  'codecSpecJSONStr': self.codec_spec}
+        writer = {
+            'name': 'MatrixNativeMultiWriter',
+            'paths': list(self.paths),
+            'overwrite': self.overwrite,
+            'stageLocally': self.stage_locally,
+            'codecSpecJSONStr': self.codec_spec,
+        }
         return escape_str(json.dumps(writer))
     def __eq__(self, other):
-        return isinstance(other, MatrixNativeMultiWriter) and \
-            other.paths == self.paths and \
-            other.overwrite == self.overwrite and \
-            other.stage_locally == self.stage_locally and \
-            other.codec_spec == self.codec_spec
+        return (
+            isinstance(other, MatrixNativeMultiWriter)
+            and other.paths == self.paths
+            and other.overwrite == self.overwrite
+            and other.stage_locally == self.stage_locally
+            and other.codec_spec == self.codec_spec
+        )
diff --git a/hail/python/hail/ir/register_aggregators.py b/hail/python/hail/ir/register_aggregators.py
index e71b7a83ea7..df94be89b3e 100644
--- a/hail/python/hail/ir/register_aggregators.py
+++ b/hail/python/hail/ir/register_aggregators.py
@@ -4,14 +4,30 @@ def register_aggregators():
     from hail.expr.types import dtype
-    register_aggregator('ApproxCDF', (dtype('int32'),), (dtype('int32'),),
-                        dtype('struct{levels:array<int32>,items:array<int32>,_compaction_counts:array<int32>}'))
-    register_aggregator('ApproxCDF', (dtype('int32'),), (dtype('int64'),),
-                        dtype('struct{levels:array<int32>,items:array<int64>,_compaction_counts:array<int32>}'))
-    register_aggregator('ApproxCDF', (dtype('int32'),), (dtype('float32'),),
-                        dtype('struct{levels:array<int32>,items:array<float32>,_compaction_counts:array<int32>}'))
-    register_aggregator('ApproxCDF', (dtype('int32'),), (dtype('float64'),),
-                        dtype('struct{levels:array<int32>,items:array<float64>,_compaction_counts:array<int32>}'))
+    register_aggregator(
+        'ApproxCDF',
+        (dtype('int32'),),
+        (dtype('int32'),),
+        dtype('struct{levels:array<int32>,items:array<int32>,_compaction_counts:array<int32>}'),
+    )
+    register_aggregator(
+        'ApproxCDF',
+        (dtype('int32'),),
+        (dtype('int64'),),
+        dtype('struct{levels:array<int32>,items:array<int64>,_compaction_counts:array<int32>}'),
+    )
+    register_aggregator(
+        'ApproxCDF',
+        (dtype('int32'),),
+        (dtype('float32'),),
+        dtype('struct{levels:array<int32>,items:array<float32>,_compaction_counts:array<int32>}'),
+    )
+    register_aggregator(
+        'ApproxCDF',
+        (dtype('int32'),),
+        (dtype('float64'),),
+        dtype('struct{levels:array<int32>,items:array<float64>,_compaction_counts:array<int32>}'),
+    )
     register_aggregator('Collect', (), (dtype("?in"),), dtype('array<?in>'))
     register_aggregator('Densify', (dtype('int32'),), (dtype("?in"),), dtype('?in'))
@@ -53,29 +69,77 @@ def register_aggregators():
     register_aggregator('ReservoirSample', (dtype('int32'),), (dtype('?in'),), dtype('array<?in>'))
-    register_aggregator('TakeBy', (dtype('int32'),), (dtype('?in'), dtype('?key'),), dtype('array<?in>'))
+    register_aggregator(
+        'TakeBy',
+        (dtype('int32'),),
+        (
+            dtype('?in'),
+            dtype('?key'),
+        ),
+        dtype('array<?in>'),
+    )
     downsample_aggregator_type = dtype('array<tuple(float64, float64, array<str>)>')
-    register_aggregator('Downsample', (dtype('int32'),), (dtype('float64'), dtype('float64'), dtype('array<str>'),), downsample_aggregator_type)
-
-    call_stats_aggregator_type = dtype('struct{AC: array<int32>,AF:array<float64>,AN:int32,homozygote_count:array<int32>}')
+    register_aggregator(
+        'Downsample',
+        (dtype('int32'),),
+        (
+            dtype('float64'),
+            dtype('float64'),
+            dtype('array<str>'),
+        ),
+        downsample_aggregator_type,
+    )
+
+    call_stats_aggregator_type = dtype(
+        'struct{AC: array<int32>,AF:array<float64>,AN:int32,homozygote_count:array<int32>}'
+    )
register_aggregator('CallStats', (dtype('int32'),), (dtype('call'),), call_stats_aggregator_type) - inbreeding_aggregator_type = dtype('struct{f_stat:float64,n_called:int64,expected_homs:float64,observed_homs:int64}') - register_aggregator('Inbreeding', (), (dtype('call'), dtype('float64'),), inbreeding_aggregator_type) - - linreg_aggregator_type = dtype('struct{xty:array,beta:array,diag_inv:array,beta0:array}') - register_aggregator('LinearRegression', (dtype('int32'), dtype('int32'),), (dtype('float64'), dtype('array'),), linreg_aggregator_type) + inbreeding_aggregator_type = dtype( + 'struct{f_stat:float64,n_called:int64,expected_homs:float64,observed_homs:int64}' + ) + register_aggregator( + 'Inbreeding', + (), + ( + dtype('call'), + dtype('float64'), + ), + inbreeding_aggregator_type, + ) + + linreg_aggregator_type = dtype( + 'struct{xty:array,beta:array,diag_inv:array,beta0:array}' + ) + register_aggregator( + 'LinearRegression', + ( + dtype('int32'), + dtype('int32'), + ), + ( + dtype('float64'), + dtype('array'), + ), + linreg_aggregator_type, + ) register_aggregator('PrevNonnull', (), (dtype('?in'),), dtype('?in')) - register_aggregator('ImputeType', (), (dtype('str'),), - dtype('struct{anyNonMissing: bool,' - 'allDefined: bool,' - 'supportsBool: bool,' - 'supportsInt32: bool,' - 'supportsInt64: bool,' - 'supportsFloat64: bool}')) + register_aggregator( + 'ImputeType', + (), + (dtype('str'),), + dtype( + 'struct{anyNonMissing: bool,' + 'allDefined: bool,' + 'supportsBool: bool,' + 'supportsInt32: bool,' + 'supportsInt64: bool,' + 'supportsFloat64: bool}' + ), + ) numeric_ndarray_type = dtype("ndarray") register_aggregator('NDArraySum', (), (numeric_ndarray_type,), numeric_ndarray_type) diff --git a/hail/python/hail/ir/register_functions.py b/hail/python/hail/ir/register_functions.py index 0c75f67cb19..874d7c8f4e0 100644 --- a/hail/python/hail/ir/register_functions.py +++ b/hail/python/hail/ir/register_functions.py @@ -1,31 +1,67 @@ from hail.expr.nat import NatVariable -from hail.expr.types import dtype, tvariable, tarray, \ - tint32, tint64, tfloat32, tfloat64, tndarray +from hail.expr.types import dtype, tarray, tfloat32, tfloat64, tint32, tint64, tndarray, tvariable from .ir import register_function, register_seeded_function -vcf_header_type_str = "struct{sampleIDs: array, " \ - "infoFields: array, " \ - "formatFields: array, " \ - "filterAttrs: dict>, " \ - "infoAttrs: dict>, " \ - "formatAttrs: dict>, " \ - "infoFlagFields: array}" +vcf_header_type_str = ( + "struct{sampleIDs: array, " + "infoFields: array, " + "formatFields: array, " + "filterAttrs: dict>, " + "infoAttrs: dict>, " + "formatAttrs: dict>, " + "infoFlagFields: array}" +) def register_functions(): locusVar = tvariable("R", "locus") register_function("isValidContig", (dtype("str"),), dtype("bool"), (locusVar,)) - register_function("isValidLocus", (dtype("str"), dtype("int32"),), dtype("bool"), (locusVar,)) + register_function( + "isValidLocus", + ( + dtype("str"), + dtype("int32"), + ), + dtype("bool"), + (locusVar,), + ) register_function("contigLength", (dtype("str"),), dtype("int32"), (locusVar,)) - register_function("getReferenceSequenceFromValidLocus", (dtype("str"), dtype("int32"), dtype("int32"), dtype("int32"),), dtype("str"), (locusVar,)) - register_function("getReferenceSequence", (dtype("str"), dtype("int32"), dtype("int32"), dtype("int32"),), dtype("str"), (locusVar,)) + register_function( + "getReferenceSequenceFromValidLocus", + ( + dtype("str"), + dtype("int32"), + dtype("int32"), + dtype("int32"), 
+ ), + dtype("str"), + (locusVar,), + ) + register_function( + "getReferenceSequence", + ( + dtype("str"), + dtype("int32"), + dtype("int32"), + dtype("int32"), + ), + dtype("str"), + (locusVar,), + ) register_function("parse_json", (dtype("str"),), dtype("tuple(?T)"), (dtype("?T"),)) register_function("flatten", (dtype("array>"),), dtype("array")) - register_function("difference", (dtype("set"), dtype("set"),), dtype("set")) + register_function( + "difference", + ( + dtype("set"), + dtype("set"), + ), + dtype("set"), + ) register_function("median", (dtype("set"),), dtype("?T")) register_function("median", (dtype("array"),), dtype("?T")) register_function("uniqueMinIndex", (dtype("array"),), dtype("int32")) @@ -35,273 +71,1283 @@ def register_functions(): register_function("toSet", (dtype("array"),), dtype("set")) def array_floating_point_divide(arg_type, ret_type): - register_function("div", (arg_type, tarray(arg_type),), tarray(ret_type)) + register_function( + "div", + ( + arg_type, + tarray(arg_type), + ), + tarray(ret_type), + ) register_function("div", (tarray(arg_type), arg_type), tarray(ret_type)) register_function("div", (tarray(arg_type), tarray(arg_type)), tarray(ret_type)) + array_floating_point_divide(tint32, tfloat32) array_floating_point_divide(tint64, tfloat32) array_floating_point_divide(tfloat32, tfloat32) array_floating_point_divide(tfloat64, tfloat64) def ndarray_floating_point_divide(arg_type, ret_type): - register_function("div", (arg_type, tndarray(arg_type, NatVariable()),), tndarray(ret_type, NatVariable())) + register_function( + "div", + ( + arg_type, + tndarray(arg_type, NatVariable()), + ), + tndarray(ret_type, NatVariable()), + ) register_function("div", (tndarray(arg_type, NatVariable()), arg_type), tndarray(ret_type, NatVariable())) - register_function("div", (tndarray(arg_type, NatVariable()), - tndarray(arg_type, NatVariable())), tndarray(ret_type, NatVariable())) + register_function( + "div", + (tndarray(arg_type, NatVariable()), tndarray(arg_type, NatVariable())), + tndarray(ret_type, NatVariable()), + ) + ndarray_floating_point_divide(tint32, tfloat32) ndarray_floating_point_divide(tint64, tfloat32) ndarray_floating_point_divide(tfloat32, tfloat32) ndarray_floating_point_divide(tfloat64, tfloat64) register_function("values", (dtype("dict"),), dtype("array")) - register_function("sliceRight", (dtype("str"), dtype("int32"),), dtype("str")) - register_function("get", (dtype("dict"), dtype("?key"),), dtype("?value")) - register_function("get", (dtype("dict"), dtype("?key"), dtype("?value"),), dtype("?value")) + register_function( + "sliceRight", + ( + dtype("str"), + dtype("int32"), + ), + dtype("str"), + ) + register_function( + "get", + ( + dtype("dict"), + dtype("?key"), + ), + dtype("?value"), + ) + register_function( + "get", + ( + dtype("dict"), + dtype("?key"), + dtype("?value"), + ), + dtype("?value"), + ) register_function("max", (dtype("array"),), dtype("?T")) register_function("nanmax", (dtype("array"),), dtype("?T")) - register_function("max", (dtype("?T"), dtype("?T"),), dtype("?T")) - register_function("nanmax", (dtype("?T"), dtype("?T"),), dtype("?T")) - register_function("max_ignore_missing", (dtype("?T"), dtype("?T"),), dtype("?T")) - register_function("nanmax_ignore_missing", (dtype("?T"), dtype("?T"),), dtype("?T")) + register_function( + "max", + ( + dtype("?T"), + dtype("?T"), + ), + dtype("?T"), + ) + register_function( + "nanmax", + ( + dtype("?T"), + dtype("?T"), + ), + dtype("?T"), + ) + register_function( + "max_ignore_missing", + 
( + dtype("?T"), + dtype("?T"), + ), + dtype("?T"), + ) + register_function( + "nanmax_ignore_missing", + ( + dtype("?T"), + dtype("?T"), + ), + dtype("?T"), + ) register_function("product", (dtype("array"),), dtype("?T")) register_function("toInt32", (dtype("?T:numeric"),), dtype("int32")) - register_function("extend", (dtype("array"), dtype("array"),), dtype("array")) + register_function( + "extend", + ( + dtype("array"), + dtype("array"), + ), + dtype("array"), + ) register_function("argmin", (dtype("array"),), dtype("int32")) register_function("toFloat64", (dtype("?T:numeric"),), dtype("float64")) register_function("sort", (dtype("array"),), dtype("array")) - register_function("sort", (dtype("array"), dtype("bool"),), dtype("array")) - register_function("isSubset", (dtype("set"), dtype("set"),), dtype("bool")) - register_function("slice", (dtype("str"), dtype("int32"), dtype("int32"),), dtype("str")) - register_function("add", (dtype("array"), dtype("array"),), dtype("array")) - register_function("add", (dtype("array"), dtype("?T"),), dtype("array")) - register_function("add", (dtype("?T:numeric"), dtype("array"),), dtype("array")) - register_function("add", (dtype("ndarray"), dtype("ndarray"),), dtype("ndarray")) - register_function("add", (dtype("ndarray"), dtype("?T"),), dtype("ndarray")) - register_function("add", (dtype("?T:numeric"), dtype("ndarray"),), dtype("ndarray")) - register_function("pow", (dtype("array"), dtype("array"),), dtype("array")) - register_function("pow", (dtype("array"), dtype("?T"),), dtype("array")) - register_function("pow", (dtype("?T:numeric"), dtype("array"),), dtype("array")) - register_function("pow", (dtype("ndarray"), dtype("ndarray"),), dtype("ndarray")) - register_function("pow", (dtype("ndarray"), dtype("?T"),), dtype("ndarray")) - register_function("pow", (dtype("?T:numeric"), dtype("ndarray"),), dtype("ndarray")) - register_function("append", (dtype("array"), dtype("?T"),), dtype("array")) - register_function("sliceLeft", (dtype("str"), dtype("int32"),), dtype("str")) - register_function("remove", (dtype("set"), dtype("?T"),), dtype("set")) - register_function("index", (dtype("str"), dtype("int32"),), dtype("str")) + register_function( + "sort", + ( + dtype("array"), + dtype("bool"), + ), + dtype("array"), + ) + register_function( + "isSubset", + ( + dtype("set"), + dtype("set"), + ), + dtype("bool"), + ) + register_function( + "slice", + ( + dtype("str"), + dtype("int32"), + dtype("int32"), + ), + dtype("str"), + ) + register_function( + "add", + ( + dtype("array"), + dtype("array"), + ), + dtype("array"), + ) + register_function( + "add", + ( + dtype("array"), + dtype("?T"), + ), + dtype("array"), + ) + register_function( + "add", + ( + dtype("?T:numeric"), + dtype("array"), + ), + dtype("array"), + ) + register_function( + "add", + ( + dtype("ndarray"), + dtype("ndarray"), + ), + dtype("ndarray"), + ) + register_function( + "add", + ( + dtype("ndarray"), + dtype("?T"), + ), + dtype("ndarray"), + ) + register_function( + "add", + ( + dtype("?T:numeric"), + dtype("ndarray"), + ), + dtype("ndarray"), + ) + register_function( + "pow", + ( + dtype("array"), + dtype("array"), + ), + dtype("array"), + ) + register_function( + "pow", + ( + dtype("array"), + dtype("?T"), + ), + dtype("array"), + ) + register_function( + "pow", + ( + dtype("?T:numeric"), + dtype("array"), + ), + dtype("array"), + ) + register_function( + "pow", + ( + dtype("ndarray"), + dtype("ndarray"), + ), + dtype("ndarray"), + ) + register_function( + "pow", + ( + dtype("ndarray"), 
+ dtype("?T"), + ), + dtype("ndarray"), + ) + register_function( + "pow", + ( + dtype("?T:numeric"), + dtype("ndarray"), + ), + dtype("ndarray"), + ) + register_function( + "append", + ( + dtype("array"), + dtype("?T"), + ), + dtype("array"), + ) + register_function( + "sliceLeft", + ( + dtype("str"), + dtype("int32"), + ), + dtype("str"), + ) + register_function( + "remove", + ( + dtype("set"), + dtype("?T"), + ), + dtype("set"), + ) + register_function( + "index", + ( + dtype("str"), + dtype("int32"), + ), + dtype("str"), + ) register_function("indexArray", (dtype("array"), dtype("int32"), dtype("str")), dtype("?T")) - register_function("index", (dtype("dict"), dtype("?key"),), dtype("?value")) + register_function( + "index", + ( + dtype("dict"), + dtype("?key"), + ), + dtype("?value"), + ) register_function("dictToArray", (dtype("dict"),), dtype("array")) - register_function("mod", (dtype("array"), dtype("array"),), dtype("array")) - register_function("mod", (dtype("array"), dtype("?T"),), dtype("array")) - register_function("mod", (dtype("?T:numeric"), dtype("array"),), dtype("array")) - register_function("mod", (dtype("ndarray"), dtype("ndarray"),), dtype("ndarray")) - register_function("mod", (dtype("ndarray"), dtype("?T"),), dtype("ndarray")) - register_function("mod", (dtype("?T:numeric"), dtype("ndarray"),), dtype("ndarray")) + register_function( + "mod", + ( + dtype("array"), + dtype("array"), + ), + dtype("array"), + ) + register_function( + "mod", + ( + dtype("array"), + dtype("?T"), + ), + dtype("array"), + ) + register_function( + "mod", + ( + dtype("?T:numeric"), + dtype("array"), + ), + dtype("array"), + ) + register_function( + "mod", + ( + dtype("ndarray"), + dtype("ndarray"), + ), + dtype("ndarray"), + ) + register_function( + "mod", + ( + dtype("ndarray"), + dtype("?T"), + ), + dtype("ndarray"), + ) + register_function( + "mod", + ( + dtype("?T:numeric"), + dtype("ndarray"), + ), + dtype("ndarray"), + ) register_function("dict", (dtype("array"),), dtype("dict")) register_function("dict", (dtype("set"),), dtype("dict")) register_function("keys", (dtype("dict"),), dtype("array")) register_function("min", (dtype("array"),), dtype("?T")) register_function("nanmin", (dtype("array"),), dtype("?T")) - register_function("min", (dtype("?T"), dtype("?T"),), dtype("?T")) - register_function("nanmin", (dtype("?T"), dtype("?T"),), dtype("?T")) - register_function("min_ignore_missing", (dtype("?T"), dtype("?T"),), dtype("?T")) - register_function("nanmin_ignore_missing", (dtype("?T"), dtype("?T"),), dtype("?T")) + register_function( + "min", + ( + dtype("?T"), + dtype("?T"), + ), + dtype("?T"), + ) + register_function( + "nanmin", + ( + dtype("?T"), + dtype("?T"), + ), + dtype("?T"), + ) + register_function( + "min_ignore_missing", + ( + dtype("?T"), + dtype("?T"), + ), + dtype("?T"), + ) + register_function( + "nanmin_ignore_missing", + ( + dtype("?T"), + dtype("?T"), + ), + dtype("?T"), + ) register_function("sum", (dtype("array"),), dtype("?T")) register_function("toInt64", (dtype("?T:numeric"),), dtype("int64")) - register_function("contains", (dtype("dict"), dtype("?key"),), dtype("bool")) - register_function("contains", (dtype("array"), dtype("?T"),), dtype("bool")) - register_function("contains", (dtype("set"), dtype("?T"),), dtype("bool")) - register_function("-", (dtype("array"), dtype("?T"),), dtype("array")) - register_function("-", (dtype("array"), dtype("array"),), dtype("array")) - register_function("-", (dtype("?T:numeric"), dtype("array"),), dtype("array")) - 
register_function("-", (dtype("ndarray"), dtype("ndarray"),), dtype("ndarray")) - register_function("-", (dtype("ndarray"), dtype("?T"),), dtype("ndarray")) - register_function("-", (dtype("?T:numeric"), dtype("ndarray"),), dtype("ndarray")) + register_function( + "contains", + ( + dtype("dict"), + dtype("?key"), + ), + dtype("bool"), + ) + register_function( + "contains", + ( + dtype("array"), + dtype("?T"), + ), + dtype("bool"), + ) + register_function( + "contains", + ( + dtype("set"), + dtype("?T"), + ), + dtype("bool"), + ) + register_function( + "-", + ( + dtype("array"), + dtype("?T"), + ), + dtype("array"), + ) + register_function( + "-", + ( + dtype("array"), + dtype("array"), + ), + dtype("array"), + ) + register_function( + "-", + ( + dtype("?T:numeric"), + dtype("array"), + ), + dtype("array"), + ) + register_function( + "-", + ( + dtype("ndarray"), + dtype("ndarray"), + ), + dtype("ndarray"), + ) + register_function( + "-", + ( + dtype("ndarray"), + dtype("?T"), + ), + dtype("ndarray"), + ) + register_function( + "-", + ( + dtype("?T:numeric"), + dtype("ndarray"), + ), + dtype("ndarray"), + ) register_function("addone", (dtype("int32"),), dtype("int32")) register_function("isEmpty", (dtype("dict"),), dtype("bool")) register_function("isEmpty", (dtype("array"),), dtype("bool")) register_function("isEmpty", (dtype("set"),), dtype("bool")) - register_function("union", (dtype("set"), dtype("set"),), dtype("set")) - register_function("mul", (dtype("array"), dtype("array"),), dtype("array")) - register_function("mul", (dtype("array"), dtype("?T"),), dtype("array")) - register_function("mul", (dtype("?T:numeric"), dtype("array"),), dtype("array")) - register_function("mul", (dtype("ndarray"), dtype("ndarray"),), dtype("ndarray")) - register_function("mul", (dtype("ndarray"), dtype("?T"),), dtype("ndarray")) - register_function("mul", (dtype("?T:numeric"), dtype("ndarray"),), dtype("ndarray")) - register_function("intersection", (dtype("set"), dtype("set"),), dtype("set")) - register_function("add", (dtype("set"), dtype("?T"),), dtype("set")) + register_function( + "union", + ( + dtype("set"), + dtype("set"), + ), + dtype("set"), + ) + register_function( + "mul", + ( + dtype("array"), + dtype("array"), + ), + dtype("array"), + ) + register_function( + "mul", + ( + dtype("array"), + dtype("?T"), + ), + dtype("array"), + ) + register_function( + "mul", + ( + dtype("?T:numeric"), + dtype("array"), + ), + dtype("array"), + ) + register_function( + "mul", + ( + dtype("ndarray"), + dtype("ndarray"), + ), + dtype("ndarray"), + ) + register_function( + "mul", + ( + dtype("ndarray"), + dtype("?T"), + ), + dtype("ndarray"), + ) + register_function( + "mul", + ( + dtype("?T:numeric"), + dtype("ndarray"), + ), + dtype("ndarray"), + ) + register_function( + "intersection", + ( + dtype("set"), + dtype("set"), + ), + dtype("set"), + ) + register_function( + "add", + ( + dtype("set"), + dtype("?T"), + ), + dtype("set"), + ) register_function("argmax", (dtype("array"),), dtype("int32")) - register_function("floordiv", (dtype("array"), dtype("array"),), dtype("array")) - register_function("floordiv", (dtype("array"), dtype("?T"),), dtype("array")) - register_function("floordiv", (dtype("?T:numeric"), dtype("array"),), dtype("array")) - register_function("floordiv", (dtype("ndarray"), dtype("ndarray"),), dtype("ndarray")) - register_function("floordiv", (dtype("ndarray"), dtype("?T"),), dtype("ndarray")) - register_function("floordiv", (dtype("?T:numeric"), dtype("ndarray"),), dtype("ndarray")) + 
register_function( + "floordiv", + ( + dtype("array"), + dtype("array"), + ), + dtype("array"), + ) + register_function( + "floordiv", + ( + dtype("array"), + dtype("?T"), + ), + dtype("array"), + ) + register_function( + "floordiv", + ( + dtype("?T:numeric"), + dtype("array"), + ), + dtype("array"), + ) + register_function( + "floordiv", + ( + dtype("ndarray"), + dtype("ndarray"), + ), + dtype("ndarray"), + ) + register_function( + "floordiv", + ( + dtype("ndarray"), + dtype("?T"), + ), + dtype("ndarray"), + ) + register_function( + "floordiv", + ( + dtype("?T:numeric"), + dtype("ndarray"), + ), + dtype("ndarray"), + ) register_function("keySet", (dtype("dict"),), dtype("set")) register_function("qnorm", (dtype("float64"),), dtype("float64")) - register_function("oneHotAlleles", (dtype("call"), dtype("int32"),), dtype("array")) - register_function("dpois", (dtype("float64"), dtype("float64"), dtype("bool"),), dtype("float64")) - register_function("dpois", (dtype("float64"), dtype("float64"),), dtype("float64")) + register_function( + "oneHotAlleles", + ( + dtype("call"), + dtype("int32"), + ), + dtype("array"), + ) + register_function( + "dpois", + ( + dtype("float64"), + dtype("float64"), + dtype("bool"), + ), + dtype("float64"), + ) + register_function( + "dpois", + ( + dtype("float64"), + dtype("float64"), + ), + dtype("float64"), + ) register_function("ploidy", (dtype("call"),), dtype("int32")) - register_function("lor", (dtype("bool"), dtype("bool"),), dtype("bool")) - register_function("ppois", (dtype("float64"), dtype("float64"), dtype("bool"), dtype("bool"),), dtype("float64")) - register_function("ppois", (dtype("float64"), dtype("float64"),), dtype("float64")) + register_function( + "lor", + ( + dtype("bool"), + dtype("bool"), + ), + dtype("bool"), + ) + register_function( + "ppois", + ( + dtype("float64"), + dtype("float64"), + dtype("bool"), + dtype("bool"), + ), + dtype("float64"), + ) + register_function( + "ppois", + ( + dtype("float64"), + dtype("float64"), + ), + dtype("float64"), + ) register_function("log10", (dtype("float64"),), dtype("float64")) register_function("isHet", (dtype("call"),), dtype("bool")) register_function("add_on_contig", (dtype("?T:locus"), dtype("int32")), dtype("?T:locus")) register_function("contig_idx", (dtype("?T:locus")), dtype("int32")) register_function("isAutosomalOrPseudoAutosomal", (dtype("?T:locus"),), dtype("bool")) - register_function("testCodeUnification", (dtype("?x:numeric"), dtype("?x:int32"),), dtype("?x")) + register_function( + "testCodeUnification", + ( + dtype("?x:numeric"), + dtype("?x:int32"), + ), + dtype("?x"), + ) register_seeded_function("rand_pois", (dtype("float64"),), dtype("float64")) - register_seeded_function("rand_pois", (dtype("int32"), dtype("float64"),), dtype("array")) + register_seeded_function( + "rand_pois", + ( + dtype("int32"), + dtype("float64"), + ), + dtype("array"), + ) register_function("toFloat32", (dtype("str"),), dtype("float32")) register_function("toFloat32", (dtype("bool"),), dtype("float32")) register_function("isAutosomal", (dtype("?T:locus"),), dtype("bool")) register_function("isPhased", (dtype("call"),), dtype("bool")) register_function("isHomVar", (dtype("call"),), dtype("bool")) - register_function("corr", (dtype("array"), dtype("array"),), dtype("float64")) - register_function("log", (dtype("float64"), dtype("float64"),), dtype("float64")) + register_function( + "corr", + ( + dtype("array"), + dtype("array"), + ), + dtype("float64"), + ) + register_function( + "log", + ( + 
dtype("float64"), + dtype("float64"), + ), + dtype("float64"), + ) register_function("log", (dtype("float64"),), dtype("float64")) register_function("foobar2", (), dtype("int32")) - register_function("approxEqual", (dtype("float64"), dtype("float64"), dtype("float64"), dtype("bool"), dtype("bool"),), dtype("bool")) + register_function( + "approxEqual", + ( + dtype("float64"), + dtype("float64"), + dtype("float64"), + dtype("bool"), + dtype("bool"), + ), + dtype("bool"), + ) register_function("includesEnd", (dtype("interval"),), dtype("bool")) register_function("position", (dtype("?T:locus"),), dtype("int32")) - register_seeded_function("rand_unif", (dtype("float64"), dtype("float64"),), dtype("float64")) + register_seeded_function( + "rand_unif", + ( + dtype("float64"), + dtype("float64"), + ), + dtype("float64"), + ) register_seeded_function("rand_int32", (dtype("int32")), dtype("int32")) register_seeded_function("rand_int64", (dtype("int64"),), dtype("int64")) register_function("showStr", (dtype("?T"), dtype("int32")), dtype("str")) register_function("str", (dtype("?T"),), dtype("str")) - register_function("valuesSimilar", (dtype("?T"), dtype("?T"), dtype('float64'), dtype('bool'),), dtype("bool")) - register_function("replace", (dtype("str"), dtype("str"), dtype("str"),), dtype("str")) + register_function( + "valuesSimilar", + ( + dtype("?T"), + dtype("?T"), + dtype('float64'), + dtype('bool'), + ), + dtype("bool"), + ) + register_function( + "replace", + ( + dtype("str"), + dtype("str"), + dtype("str"), + ), + dtype("str"), + ) register_function("exp", (dtype("float64"),), dtype("float64")) - register_function("land", (dtype("bool"), dtype("bool"),), dtype("bool")) - register_function("compare", (dtype("int32"), dtype("int32"),), dtype("int32")) + register_function( + "land", + ( + dtype("bool"), + dtype("bool"), + ), + dtype("bool"), + ) + register_function( + "compare", + ( + dtype("int32"), + dtype("int32"), + ), + dtype("int32"), + ) register_function("triangle", (dtype("int32"),), dtype("int32")) - register_function("Interval", (dtype("?T"), dtype("?T"), dtype("bool"), dtype("bool"),), dtype("interval")) + register_function( + "Interval", + ( + dtype("?T"), + dtype("?T"), + dtype("bool"), + dtype("bool"), + ), + dtype("interval"), + ) register_function("contig", (dtype("?T:locus"),), dtype("str")) register_function("Call", (dtype("bool"),), dtype("call")) register_function("Call", (dtype("str"),), dtype("call")) - register_function("Call", (dtype("int32"), dtype("bool"),), dtype("call")) - register_function("Call", (dtype("int32"), dtype("int32"), dtype("bool"),), dtype("call")) - register_function("Call", (dtype("array"), dtype("bool"),), dtype("call")) - register_function("qchisqtail", (dtype("float64"), dtype("float64"),), dtype("float64")) - register_function("binomTest", (dtype("int32"), dtype("int32"), dtype("float64"), dtype("int32"),), dtype("float64")) - register_function("qpois", (dtype("float64"), dtype("float64"),), dtype("int32")) - register_function("qpois", (dtype("float64"), dtype("float64"), dtype("bool"), dtype("bool"),), dtype("int32")) + register_function( + "Call", + ( + dtype("int32"), + dtype("bool"), + ), + dtype("call"), + ) + register_function( + "Call", + ( + dtype("int32"), + dtype("int32"), + dtype("bool"), + ), + dtype("call"), + ) + register_function( + "Call", + ( + dtype("array"), + dtype("bool"), + ), + dtype("call"), + ) + register_function( + "qchisqtail", + ( + dtype("float64"), + dtype("float64"), + ), + dtype("float64"), + ) + 
register_function( + "binomTest", + ( + dtype("int32"), + dtype("int32"), + dtype("float64"), + dtype("int32"), + ), + dtype("float64"), + ) + register_function( + "qpois", + ( + dtype("float64"), + dtype("float64"), + ), + dtype("int32"), + ) + register_function( + "qpois", + ( + dtype("float64"), + dtype("float64"), + dtype("bool"), + dtype("bool"), + ), + dtype("int32"), + ) register_function("is_finite", (dtype("float32"),), dtype("bool")) register_function("is_finite", (dtype("float64"),), dtype("bool")) register_function("inYPar", (dtype("?T:locus"),), dtype("bool")) - register_function("contingency_table_test", (dtype("int32"), dtype("int32"), dtype("int32"), dtype("int32"), dtype("int32"),), dtype("struct{p_value: float64, odds_ratio: float64}")) + register_function( + "contingency_table_test", + ( + dtype("int32"), + dtype("int32"), + dtype("int32"), + dtype("int32"), + dtype("int32"), + ), + dtype("struct{p_value: float64, odds_ratio: float64}"), + ) register_function("toInt32", (dtype("bool"),), dtype("int32")) register_function("toInt32", (dtype("str"),), dtype("int32")) register_function("foobar1", (), dtype("int32")) register_function("toFloat64", (dtype("str"),), dtype("float64")) register_function("toFloat64", (dtype("bool"),), dtype("float64")) - register_function("dbeta", (dtype("float64"), dtype("float64"), dtype("float64"),), dtype("float64")) + register_function( + "dbeta", + ( + dtype("float64"), + dtype("float64"), + dtype("float64"), + ), + dtype("float64"), + ) register_function("Locus", (dtype("str"),), dtype("?T:locus")) - register_function("Locus", (dtype("str"), dtype("int32"),), dtype("?T:locus")) + register_function( + "Locus", + ( + dtype("str"), + dtype("int32"), + ), + dtype("?T:locus"), + ) register_function("LocusAlleles", (dtype("str"),), dtype("struct{locus: ?T, alleles: array}")) - register_function("LocusInterval", (dtype("str"), dtype("bool"),), dtype("interval")) - register_function("LocusInterval", (dtype("str"), dtype("int32"), dtype("int32"), dtype("bool"), dtype("bool"), dtype("bool"),), dtype("interval")) + register_function( + "LocusInterval", + ( + dtype("str"), + dtype("bool"), + ), + dtype("interval"), + ) + register_function( + "LocusInterval", + ( + dtype("str"), + dtype("int32"), + dtype("int32"), + dtype("bool"), + dtype("bool"), + dtype("bool"), + ), + dtype("interval"), + ) register_function("globalPosToLocus", (dtype("int64"),), dtype("?T:locus")) register_function("locusToGlobalPos", (dtype("?T:locus"),), dtype("int64")) - register_function("liftoverLocus", (dtype("?T:locus"), dtype('float64'),), dtype("struct{result:?U:locus,is_negative_strand:bool}")) - register_function("liftoverLocusInterval", (dtype("interval"), dtype('float64'),), dtype("struct{result:interval,is_negative_strand:bool}")) - register_function("min_rep", (dtype("?T:locus"), dtype("array"),), dtype("struct{locus: ?T, alleles: array}")) - register_function("locus_windows_per_contig", (dtype("array>"), dtype("float64"),), dtype("tuple(array, array)")) + register_function( + "liftoverLocus", + ( + dtype("?T:locus"), + dtype('float64'), + ), + dtype("struct{result:?U:locus,is_negative_strand:bool}"), + ) + register_function( + "liftoverLocusInterval", + ( + dtype("interval"), + dtype('float64'), + ), + dtype("struct{result:interval,is_negative_strand:bool}"), + ) + register_function( + "min_rep", + ( + dtype("?T:locus"), + dtype("array"), + ), + dtype("struct{locus: ?T, alleles: array}"), + ) + register_function( + "locus_windows_per_contig", + ( + dtype("array>"), + 
dtype("float64"), + ), + dtype("tuple(array, array)"), + ) register_function("toBoolean", (dtype("str"),), dtype("bool")) register_seeded_function("rand_bool", (dtype("float64"),), dtype("bool")) - register_function("pchisqtail", (dtype("float64"), dtype("float64"),), dtype("float64")) + register_function( + "pchisqtail", + ( + dtype("float64"), + dtype("float64"), + ), + dtype("float64"), + ) register_seeded_function("rand_cat", (dtype("array"),), dtype("int32")) register_function("inYNonPar", (dtype("?T:locus"),), dtype("bool")) - register_function("concat", (dtype("str"), dtype("str"),), dtype("str")) - register_function("pow", (dtype("float32"), dtype("float32"),), dtype("float64")) - register_function("pow", (dtype("int32"), dtype("int32"),), dtype("float64")) - register_function("pow", (dtype("int64"), dtype("int64"),), dtype("float64")) - register_function("pow", (dtype("float64"), dtype("float64"),), dtype("float64")) + register_function( + "concat", + ( + dtype("str"), + dtype("str"), + ), + dtype("str"), + ) + register_function( + "pow", + ( + dtype("float32"), + dtype("float32"), + ), + dtype("float64"), + ) + register_function( + "pow", + ( + dtype("int32"), + dtype("int32"), + ), + dtype("float64"), + ) + register_function( + "pow", + ( + dtype("int64"), + dtype("int64"), + ), + dtype("float64"), + ) + register_function( + "pow", + ( + dtype("float64"), + dtype("float64"), + ), + dtype("float64"), + ) register_function("length", (dtype("str"),), dtype("int32")) - register_function("slice", (dtype("str"), dtype("int32"), dtype("int32"),), dtype("str")) - register_function("split", (dtype("str"), dtype("str"), dtype("int32"),), dtype("array")) - register_function("split", (dtype("str"), dtype("str"),), dtype("array")) - register_function("splitQuotedChar", (dtype("str"), dtype("str"), dtype("array"), dtype("str"),), - dtype("array")) - register_function("splitQuotedRegex", (dtype("str"), dtype("str"), dtype("array"), dtype("str"),), - dtype("array")) - register_function("splitChar", (dtype("str"), dtype("str"), dtype("array"),), dtype("array")) - register_function("splitRegex", (dtype("str"), dtype("str"), dtype("array"),), dtype("array")) - register_seeded_function("rand_gamma", (dtype("float64"), dtype("float64"),), dtype("float64")) + register_function( + "slice", + ( + dtype("str"), + dtype("int32"), + dtype("int32"), + ), + dtype("str"), + ) + register_function( + "split", + ( + dtype("str"), + dtype("str"), + dtype("int32"), + ), + dtype("array"), + ) + register_function( + "split", + ( + dtype("str"), + dtype("str"), + ), + dtype("array"), + ) + register_function( + "splitQuotedChar", + ( + dtype("str"), + dtype("str"), + dtype("array"), + dtype("str"), + ), + dtype("array"), + ) + register_function( + "splitQuotedRegex", + ( + dtype("str"), + dtype("str"), + dtype("array"), + dtype("str"), + ), + dtype("array"), + ) + register_function( + "splitChar", + ( + dtype("str"), + dtype("str"), + dtype("array"), + ), + dtype("array"), + ) + register_function( + "splitRegex", + ( + dtype("str"), + dtype("str"), + dtype("array"), + ), + dtype("array"), + ) + register_seeded_function( + "rand_gamma", + ( + dtype("float64"), + dtype("float64"), + ), + dtype("float64"), + ) register_function("UnphasedDiploidGtIndexCall", (dtype("int32"),), dtype("call")) - register_function("lgt_to_gt", (dtype("call"), dtype("array"),), dtype("call")) - register_function("local_to_global_g", (dtype("array"), dtype("array"), dtype("int32"), dtype("?T")), dtype("array")) - 
register_function("local_to_global_a_r", (dtype("array"), dtype("array"), dtype("int32"), dtype("?T"), dtype("bool")), dtype("array")) - register_function("index", (dtype("call"), dtype("int32"),), dtype("int32")) + register_function( + "lgt_to_gt", + ( + dtype("call"), + dtype("array"), + ), + dtype("call"), + ) + register_function( + "local_to_global_g", + (dtype("array"), dtype("array"), dtype("int32"), dtype("?T")), + dtype("array"), + ) + register_function( + "local_to_global_a_r", + (dtype("array"), dtype("array"), dtype("int32"), dtype("?T"), dtype("bool")), + dtype("array"), + ) + register_function( + "index", + ( + dtype("call"), + dtype("int32"), + ), + dtype("int32"), + ) register_function("sign", (dtype("int64"),), dtype("int64")) register_function("sign", (dtype("float64"),), dtype("float64")) register_function("sign", (dtype("float32"),), dtype("float32")) register_function("sign", (dtype("int32"),), dtype("int32")) register_function("unphasedDiploidGtIndex", (dtype("call"),), dtype("int32")) register_function("gamma", (dtype("float64"),), dtype("float64")) - register_function("mod", (dtype("float64"), dtype("float64"),), dtype("float64")) - register_function("mod", (dtype("int64"), dtype("int64"),), dtype("int64")) - register_function("mod", (dtype("float32"), dtype("float32"),), dtype("float32")) - register_function("mod", (dtype("int32"), dtype("int32"),), dtype("int32")) - register_function("fisher_exact_test", (dtype("int32"), dtype("int32"), dtype("int32"), dtype("int32"),), dtype("struct{p_value: float64, odds_ratio: float64, ci_95_lower: float64, ci_95_upper: float64}")) + register_function( + "mod", + ( + dtype("float64"), + dtype("float64"), + ), + dtype("float64"), + ) + register_function( + "mod", + ( + dtype("int64"), + dtype("int64"), + ), + dtype("int64"), + ) + register_function( + "mod", + ( + dtype("float32"), + dtype("float32"), + ), + dtype("float32"), + ) + register_function( + "mod", + ( + dtype("int32"), + dtype("int32"), + ), + dtype("int32"), + ) + register_function( + "fisher_exact_test", + ( + dtype("int32"), + dtype("int32"), + dtype("int32"), + dtype("int32"), + ), + dtype("struct{p_value: float64, odds_ratio: float64, ci_95_lower: float64, ci_95_upper: float64}"), + ) register_function("floor", (dtype("float64"),), dtype("float64")) register_function("floor", (dtype("float32"),), dtype("float32")) register_function("isNonRef", (dtype("call"),), dtype("bool")) register_function("includesStart", (dtype("interval"),), dtype("bool")) register_function("isHetNonRef", (dtype("call"),), dtype("bool")) - register_function("hardy_weinberg_test", (dtype("int32"), dtype("int32"), dtype("int32"), dtype("bool")), dtype("struct{het_freq_hwe: float64, p_value: float64}")) + register_function( + "hardy_weinberg_test", + (dtype("int32"), dtype("int32"), dtype("int32"), dtype("bool")), + dtype("struct{het_freq_hwe: float64, p_value: float64}"), + ) register_function("haplotype_freq_em", (dtype("array"),), dtype("array")) register_function("nNonRefAlleles", (dtype("call"),), dtype("int32")) register_function("abs", (dtype("float64"),), dtype("float64")) register_function("abs", (dtype("float32"),), dtype("float32")) register_function("abs", (dtype("int64"),), dtype("int64")) register_function("abs", (dtype("int32"),), dtype("int32")) - register_function("endswith", (dtype("str"), dtype("str"),), dtype("bool")) + register_function( + "endswith", + ( + dtype("str"), + dtype("str"), + ), + dtype("bool"), + ) register_function("sqrt", (dtype("float64"),), 
dtype("float64")) register_function("isnan", (dtype("float32"),), dtype("bool")) register_function("isnan", (dtype("float64"),), dtype("bool")) register_function("lower", (dtype("str"),), dtype("str")) - register_seeded_function("rand_beta", (dtype("float64"), dtype("float64"),), dtype("float64")) - register_seeded_function("rand_beta", (dtype("float64"), dtype("float64"), dtype("float64"), dtype("float64"),), dtype("float64")) + register_seeded_function( + "rand_beta", + ( + dtype("float64"), + dtype("float64"), + ), + dtype("float64"), + ) + register_seeded_function( + "rand_beta", + ( + dtype("float64"), + dtype("float64"), + dtype("float64"), + dtype("float64"), + ), + dtype("float64"), + ) register_function("toInt64", (dtype("bool"),), dtype("int64")) register_function("toInt64", (dtype("str"),), dtype("int64")) register_function("testCodeUnification2", (dtype("?x"),), dtype("?x")) - register_function("contains", (dtype("str"), dtype("str"),), dtype("bool")) - register_function("contains", (dtype("interval"), dtype("?T"),), dtype("bool")) + register_function( + "contains", + ( + dtype("str"), + dtype("str"), + ), + dtype("bool"), + ) + register_function( + "contains", + ( + dtype("interval"), + dtype("?T"), + ), + dtype("bool"), + ) register_function("entropy", (dtype("str"),), dtype("float64")) - register_function("filtering_allele_frequency", (dtype("int32"), dtype("int32"), dtype("float64"),), dtype("float64")) + register_function( + "filtering_allele_frequency", + ( + dtype("int32"), + dtype("int32"), + dtype("float64"), + ), + dtype("float64"), + ) register_function("gqFromPL", (dtype("array"),), dtype("int32")) - register_function("startswith", (dtype("str"), dtype("str"),), dtype("bool")) + register_function( + "startswith", + ( + dtype("str"), + dtype("str"), + ), + dtype("bool"), + ) register_function("ceil", (dtype("float32"),), dtype("float32")) register_function("ceil", (dtype("float64"),), dtype("float64")) register_function("json", (dtype("?T"),), dtype("str")) register_function("strip", (dtype("str"),), dtype("str")) - register_function("firstMatchIn", (dtype("str"), dtype("str"),), dtype("array")) + register_function( + "firstMatchIn", + ( + dtype("str"), + dtype("str"), + ), + dtype("array"), + ) register_function("isEmpty", (dtype("interval"),), dtype("bool")) - register_function("~", (dtype("str"), dtype("str"),), dtype("bool")) - register_function("mkString", (dtype("set"), dtype("str"),), dtype("str")) - register_function("mkString", (dtype("array"), dtype("str"),), dtype("str")) + register_function( + "~", + ( + dtype("str"), + dtype("str"), + ), + dtype("bool"), + ) + register_function( + "mkString", + ( + dtype("set"), + dtype("str"), + ), + dtype("str"), + ) + register_function( + "mkString", + ( + dtype("array"), + dtype("str"), + ), + dtype("str"), + ) register_function("dosage", (dtype("array"),), dtype("float64")) register_function("upper", (dtype("str"),), dtype("str")) - register_function("overlaps", (dtype("interval"), dtype("interval"),), dtype("bool")) - register_function("downcode", (dtype("call"), dtype("int32"),), dtype("call")) + register_function( + "overlaps", + ( + dtype("interval"), + dtype("interval"), + ), + dtype("bool"), + ) + register_function( + "downcode", + ( + dtype("call"), + dtype("int32"), + ), + dtype("call"), + ) register_function("inXPar", (dtype("?T:locus"),), dtype("bool")) - register_function("format", (dtype("str"), dtype("?T:tuple"),), dtype("str")) + register_function( + "format", + ( + dtype("str"), + dtype("?T:tuple"), + 
), + dtype("str"), + ) register_function("pnorm", (dtype("float64"),), dtype("float64")) register_function("is_infinite", (dtype("float32"),), dtype("bool")) register_function("is_infinite", (dtype("float64"),), dtype("bool")) register_function("isHetRef", (dtype("call"),), dtype("bool")) register_function("isMitochondrial", (dtype("?T:locus"),), dtype("bool")) - register_function("hamming", (dtype("str"), dtype("str"),), dtype("int32")) + register_function( + "hamming", + ( + dtype("str"), + dtype("str"), + ), + dtype("int32"), + ) register_function("end", (dtype("interval"),), dtype("?T")) register_function("start", (dtype("interval"),), dtype("?T")) register_function("inXNonPar", (dtype("?T:locus"),), dtype("bool")) register_function("escapeString", (dtype("str"),), dtype("str")) register_function("isHomRef", (dtype("call"),), dtype("bool")) - register_seeded_function("rand_norm", (dtype("float64"), dtype("float64"),), dtype("float64")) - register_function("chi_squared_test", (dtype("int32"), dtype("int32"), dtype("int32"), dtype("int32"),), dtype("struct{p_value: float64, odds_ratio: float64}")) + register_seeded_function( + "rand_norm", + ( + dtype("float64"), + dtype("float64"), + ), + dtype("float64"), + ) + register_function( + "chi_squared_test", + ( + dtype("int32"), + dtype("int32"), + dtype("int32"), + dtype("int32"), + ), + dtype("struct{p_value: float64, odds_ratio: float64}"), + ) register_function("strftime", (dtype("str"), dtype("int64"), dtype("str")), dtype("str")) register_function("strptime", (dtype("str"), dtype("str"), dtype("str")), dtype("int64")) - register_function("index_bgen", (dtype("str"), dtype("str"), dtype("dict"), dtype('bool'), dtype("int32")), dtype("int64"), (dtype("?T"),)) - register_function("getVCFHeader", (dtype("str"), dtype("str"), dtype("str"), dtype("str")), dtype(vcf_header_type_str),) + register_function( + "index_bgen", + (dtype("str"), dtype("str"), dtype("dict"), dtype('bool'), dtype("int32")), + dtype("int64"), + (dtype("?T"),), + ) + register_function( + "getVCFHeader", + (dtype("str"), dtype("str"), dtype("str"), dtype("str")), + dtype(vcf_header_type_str), + ) diff --git a/hail/python/hail/ir/renderer.py b/hail/python/hail/ir/renderer.py index 9a1d946ef22..d60ae357285 100644 --- a/hail/python/hail/ir/renderer.py +++ b/hail/python/hail/ir/renderer.py @@ -1,21 +1,19 @@ -from hail import ir import abc -from typing import Sequence, MutableSequence, List, Set, Dict, Optional from collections import namedtuple +from typing import Dict, List, MutableSequence, Optional, Sequence, Set + +from hail import ir class Renderable(object): @abc.abstractmethod - def render_head(self, r: 'Renderer') -> str: - ... + def render_head(self, r: 'Renderer') -> str: ... @abc.abstractmethod - def render_tail(self, r: 'Renderer') -> str: - ... + def render_tail(self, r: 'Renderer') -> str: ... @abc.abstractmethod - def render_children(self, r: 'Renderer') -> Sequence['Renderable']: - ... + def render_children(self, r: 'Renderer') -> Sequence['Renderable']: ... 
class RenderableStr(Renderable): @@ -126,9 +124,7 @@ def __call__(self, x: 'Renderable'): Context = (Vars, Vars, Vars) -BindingSite = namedtuple( - 'BindingSite', - 'depth lifted_lets agg_lifted_lets scan_lifted_lets') +BindingSite = namedtuple('BindingSite', 'depth lifted_lets agg_lifted_lets scan_lifted_lets') class CSERenderer(Renderer): @@ -162,9 +158,7 @@ def uid(self) -> str: # where for each descendant 'x' which will be bound above 'node', # 'lifted_lets' maps 'id(x)' to the unique id 'x' will be bound to. def __call__(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: - root_frame = self.StackFrame(0, 0, False, - ({var: 0 for var in root.free_vars}, {}, {}), - root) + root_frame = self.StackFrame(0, 0, False, ({var: 0 for var in root.free_vars}, {}, {}), root) stack = [root_frame] binding_sites = {} @@ -212,9 +206,8 @@ def __call__(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: elif child_frame.scan_scope: if id(child) in bind_frame.scan_visited: lets = bind_frame.scan_lifted_lets - else: - if id(child) in bind_frame.agg_visited: - lets = bind_frame.agg_lifted_lets + elif id(child) in bind_frame.agg_visited: + lets = bind_frame.agg_lifted_lets # 'lets' is either assigned before one of the 'br/has if lets is not None: @@ -235,13 +228,29 @@ def __call__(self, root: 'ir.BaseIR') -> Dict[int, BindingSite]: return binding_sites class StackFrame: - __slots__ = ['min_binding_depth', 'min_value_binding_depth', 'scan_scope', - 'context', 'node', 'visited', 'agg_visited', - 'scan_visited', 'lifted_lets', 'agg_lifted_lets', - 'scan_lifted_lets', 'child_idx'] - - def __init__(self, min_binding_depth: int, min_value_binding_depth: int, - scan_scope: bool, context: Context, x: 'ir.BaseIR'): + __slots__ = [ + 'min_binding_depth', + 'min_value_binding_depth', + 'scan_scope', + 'context', + 'node', + 'visited', + 'agg_visited', + 'scan_visited', + 'lifted_lets', + 'agg_lifted_lets', + 'scan_lifted_lets', + 'child_idx', + ] + + def __init__( + self, + min_binding_depth: int, + min_value_binding_depth: int, + scan_scope: bool, + context: Context, + x: 'ir.BaseIR', + ): # Immutable: # The node corresponding to this stack frame. 
@@ -290,17 +299,18 @@ def make_binding_site(self, depth): lifted_lets=self.lifted_lets, agg_lifted_lets=self.agg_lifted_lets, scan_lifted_lets=self.scan_lifted_lets, - depth=depth) + depth=depth, + ) # compute depth at which we might bind this node def bind_depth(self) -> int: bind_depth = self.min_binding_depth if len(self.node.free_vars) > 0: - bind_depth = max(bind_depth, max(self.context[0][var] for var in self.node.free_vars)) + bind_depth = max(bind_depth, *(self.context[0][var] for var in self.node.free_vars)) if len(self.node.free_agg_vars) > 0: - bind_depth = max(bind_depth, max(self.context[1][var] for var in self.node.free_agg_vars)) + bind_depth = max(bind_depth, *(self.context[1][var] for var in self.node.free_agg_vars)) if len(self.node.free_scan_vars) > 0: - bind_depth = max(bind_depth, max(self.context[2][var] for var in self.node.free_scan_vars)) + bind_depth = max(bind_depth, *(self.context[2][var] for var in self.node.free_scan_vars)) return bind_depth def make_child_frame(self, depth: int): @@ -322,7 +332,9 @@ def make_child_frame(self, depth: int): child_context = x.child_context(i, self.context, depth) - return CSEAnalysisPass.StackFrame(child_min_binding_depth, child_min_value_binding_depth, child_scan_scope, child_context, child) + return CSEAnalysisPass.StackFrame( + child_min_binding_depth, child_min_value_binding_depth, child_scan_scope, child_context, child + ) class CSEPrintPass: @@ -347,8 +359,7 @@ def __call__(self, root: 'ir.BaseIR', binding_sites: Dict[int, BindingSite]): if id(root) in memo: return ''.join(memo[id(root)]) root_ctx = ({var: 0 for var in root.free_vars}, {}, {}) - stack = [self.StackFrame.make(root, self.renderer, binding_sites, - bindings_stack, 0, 0, False, root_ctx, 0)] + stack = [self.StackFrame.make(root, self.renderer, binding_sites, bindings_stack, 0, 0, False, root_ctx, 0)] stack[0].set_builder(root_builder, self.renderer) while True: @@ -359,7 +370,7 @@ def __call__(self, root: 'ir.BaseIR', binding_sites: Dict[int, BindingSite]): if child_idx >= len(frame.children): if frame.lift_to_frame is not None: - assert(not frame.insert_lets) + assert not frame.insert_lets if id(node) in memo: frame.builder.append(memo[id(node)]) else: @@ -423,7 +434,7 @@ def __call__(self, root: 'ir.BaseIR', binding_sites: Dict[int, BindingSite]): continue if lift_type == 'value': - child_builder = [f'(Let {name} '] + child_builder = [f'(Let eval {name} '] elif lift_type == 'agg': child_builder = [f'(AggLet {name} False '] else: @@ -453,24 +464,36 @@ def __call__(self, root: 'ir.BaseIR', binding_sites: Dict[int, BindingSite]): BindingsStackFrame = namedtuple( 'BindingsStackFrame', - 'depth lifted_lets agg_lifted_lets scan_lifted_lets visited agg_visited' - ' scan_visited let_bodies') + 'depth lifted_lets agg_lifted_lets scan_lifted_lets visited agg_visited' ' scan_visited let_bodies', + ) class StackFrame: - __slots__ = ['node', 'children', 'min_binding_depth', 'context', - 'min_value_binding_depth', 'scan_scope', 'depth', - 'lift_to_frame', 'insert_lets', 'builder', 'child_idx'] - - def __init__(self, - node: Renderable, - children: Sequence[Renderable], - min_binding_depth: int, - min_value_binding_depth: int, - scan_scope: bool, - context: Context, - depth: int, - insert_lets: bool, - lift_to_frame: 'Optional[CSEPrintPass.BindingsStackFrame]' = None): + __slots__ = [ + 'node', + 'children', + 'min_binding_depth', + 'context', + 'min_value_binding_depth', + 'scan_scope', + 'depth', + 'lift_to_frame', + 'insert_lets', + 'builder', + 'child_idx', + ] + + def 
__init__( + self, + node: Renderable, + children: Sequence[Renderable], + min_binding_depth: int, + min_value_binding_depth: int, + scan_scope: bool, + context: Context, + depth: int, + insert_lets: bool, + lift_to_frame: 'Optional[CSEPrintPass.BindingsStackFrame]' = None, + ): # Immutable # The 'Renderable' node corresponding to this stack frame. @@ -519,11 +542,11 @@ def __init__(self, def bind_depth(self) -> int: bind_depth = self.min_binding_depth if len(self.node.free_vars) > 0: - bind_depth = max(bind_depth, max(self.context[0][var] for var in self.node.free_vars)) + bind_depth = max(bind_depth, *(self.context[0][var] for var in self.node.free_vars)) if len(self.node.free_agg_vars) > 0: - bind_depth = max(bind_depth, max(self.context[1][var] for var in self.node.free_agg_vars)) + bind_depth = max(bind_depth, *(self.context[1][var] for var in self.node.free_agg_vars)) if len(self.node.free_scan_vars) > 0: - bind_depth = max(bind_depth, max(self.context[2][var] for var in self.node.free_scan_vars)) + bind_depth = max(bind_depth, *(self.context[2][var] for var in self.node.free_scan_vars)) return bind_depth def add_lets(self, let_bodies: Sequence[str], out_builder: MutableSequence[str]): @@ -534,10 +557,12 @@ def add_lets(self, let_bodies: Sequence[str], out_builder: MutableSequence[str]) for _ in range(num_lets): out_builder.append(')') - def make_child_frame(self, - renderer: 'CSERenderer', - binding_sites: Dict[int, BindingSite], - bindings_stack: 'Dict[int, CSEPrintPass.BindingsStackFrame]'): + def make_child_frame( + self, + renderer: 'CSERenderer', + binding_sites: Dict[int, BindingSite], + bindings_stack: 'Dict[int, CSEPrintPass.BindingsStackFrame]', + ): child_min_binding_depth = self.min_binding_depth child_min_value_binding_depth = self.min_value_binding_depth child_scan_scope = self.scan_scope @@ -562,33 +587,52 @@ def make_child_frame(self, child_context = self.node.renderable_child_context(self.child_idx, self.context, child_depth) else: child_context = self.context - return self.make(child, renderer, binding_sites, bindings_stack, - child_min_binding_depth, child_min_value_binding_depth, - child_scan_scope, child_context, child_depth) + return self.make( + child, + renderer, + binding_sites, + bindings_stack, + child_min_binding_depth, + child_min_value_binding_depth, + child_scan_scope, + child_context, + child_depth, + ) @staticmethod - def make(node: Renderable, - renderer: 'CSERenderer', - binding_sites: Dict[int, BindingSite], - bindings_stack: 'Dict[int, CSEPrintPass.BindingsStackFrame]', - min_binding_depth: int, - min_value_binding_depth: int, - scan_scope: bool, - context: Context, - depth: int): - insert_lets = (id(node) in binding_sites - and depth == binding_sites[id(node)].depth - and (len(binding_sites[id(node)].lifted_lets) > 0 - or len(binding_sites[id(node)].agg_lifted_lets) > 0 - or len(binding_sites[id(node)].scan_lifted_lets) > 0)) - state = CSEPrintPass.StackFrame(node, - node.render_children(renderer), - min_binding_depth, - min_value_binding_depth, - scan_scope, context, depth, insert_lets) + def make( + node: Renderable, + renderer: 'CSERenderer', + binding_sites: Dict[int, BindingSite], + bindings_stack: 'Dict[int, CSEPrintPass.BindingsStackFrame]', + min_binding_depth: int, + min_value_binding_depth: int, + scan_scope: bool, + context: Context, + depth: int, + ): + insert_lets = ( + id(node) in binding_sites + and depth == binding_sites[id(node)].depth + and ( + len(binding_sites[id(node)].lifted_lets) > 0 + or 
len(binding_sites[id(node)].agg_lifted_lets) > 0 + or len(binding_sites[id(node)].scan_lifted_lets) > 0 + ) + ) + state = CSEPrintPass.StackFrame( + node, + node.render_children(renderer), + min_binding_depth, + min_value_binding_depth, + scan_scope, + context, + depth, + insert_lets, + ) if insert_lets: bind_site = binding_sites[id(node)] - assert(bind_site.depth == depth) + assert bind_site.depth == depth bindings_stack[depth] = CSEPrintPass.StackFrame.make_bindings_stack_frame(bind_site) return state @@ -610,4 +654,5 @@ def make_bindings_stack_frame(site: BindingSite): visited={}, agg_visited={}, scan_visited={}, - let_bodies=[]) + let_bodies=[], + ) diff --git a/hail/python/hail/ir/table_ir.py b/hail/python/hail/ir/table_ir.py index 22c0530c35a..b540ad0b33b 100644 --- a/hail/python/hail/ir/table_ir.py +++ b/hail/python/hail/ir/table_ir.py @@ -1,26 +1,24 @@ from typing import Optional + import hail as hl from hail.expr.types import dtype, tint32, tint64, tstruct -from hail.ir.base_ir import BaseIR, IR, TableIR -import hail.ir.ir as ir -from hail.ir.utils import modify_deep_field, zip_with_index, default_row_uid, default_col_uid -from hail.ir.ir import unify_uid_types, pad_uid, concat_uids -from hail.typecheck import typecheck_method, nullable, sequenceof +from hail.ir import ir +from hail.ir.base_ir import IR, BaseIR, TableIR +from hail.ir.ir import concat_uids, pad_uid, unify_uid_types +from hail.ir.utils import default_col_uid, default_row_uid, modify_deep_field, zip_with_index +from hail.typecheck import nullable, sequenceof, typecheck_method from hail.utils import FatalError from hail.utils.interval import Interval from hail.utils.java import Env -from hail.utils.misc import escape_str, parsable_strings, escape_id from hail.utils.jsonx import dump_json +from hail.utils.misc import escape_id, escape_str, parsable_strings def unpack_uid(new_row_type, uid_field_name): new_row = ir.Ref('row', new_row_type) - if uid_field_name in new_row_type.fields: - uid = ir.GetField(new_row, uid_field_name) - else: - uid = ir.NA(tint64) - return uid, \ - ir.SelectFields(new_row, [field for field in new_row_type.fields if not field == uid_field_name]) + assert uid_field_name in new_row_type.fields + uid = ir.GetField(new_row, uid_field_name) + return uid, ir.SelectFields(new_row, [field for field in new_row_type.fields if not field == uid_field_name]) class MatrixRowsTable(TableIR): @@ -33,9 +31,7 @@ def _handle_randomness(self, uid_field_name): def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) - return hl.ttable(self.child.typ.global_type, - self.child.typ.row_type, - self.child.typ.row_key) + return hl.ttable(self.child.typ.global_type, self.child.typ.row_type, self.child.typ.row_key) class TableJoin(TableIR): @@ -48,9 +44,9 @@ def __init__(self, left, right, join_type, join_key): def _handle_randomness(self, uid_field_name): if uid_field_name is None: - return TableJoin(self.left.handle_randomness(None), - self.right.handle_randomness(None), - self.join_type, self.join_key) + return TableJoin( + self.left.handle_randomness(None), self.right.handle_randomness(None), self.join_type, self.join_key + ) left = self.left.handle_randomness('__left_uid') right = self.right.handle_randomness('__right_uid') @@ -60,8 +56,8 @@ def _handle_randomness(self, uid_field_name): old_joined_row = ir.SelectFields(row, [field for field in self.typ.row_type]) left_uid = ir.GetField(row, '__left_uid') right_uid = ir.GetField(row, '__right_uid') - handle_missing_left = self.join_type == 
'right' or self.join_type == 'outer' - handle_missing_right = self.join_type == 'left' or self.join_type == 'outer' + handle_missing_left = self.join_type in {'right', 'outer'} + handle_missing_right = self.join_type in {'left', 'outer'} uid = concat_uids(left_uid, right_uid, handle_missing_left, handle_missing_right) return TableMapRows(joined, ir.InsertFields(old_joined_row, [(uid_field_name, uid)], None)) @@ -70,17 +66,18 @@ def head_str(self): return f'{escape_id(self.join_type)} {self.join_key}' def _eq(self, other): - return self.join_key == other.join_key and \ - self.join_type == other.join_type + return self.join_key == other.join_key and self.join_type == other.join_type def _compute_type(self, deep_typecheck): self.left.compute_type(deep_typecheck) self.right.compute_type(deep_typecheck) left_typ = self.left.typ right_typ = self.right.typ - return hl.ttable(left_typ.global_type._concat(right_typ.global_type), - left_typ.key_type._concat(left_typ.value_type)._concat(right_typ.value_type), - left_typ.row_key + right_typ.row_key[self.join_key:]) + return hl.ttable( + left_typ.global_type._concat(right_typ.global_type), + left_typ.key_type._concat(left_typ.value_type)._concat(right_typ.value_type), + left_typ.row_key + right_typ.row_key[self.join_key :], + ) class TableLeftJoinRightDistinct(TableIR): @@ -107,9 +104,8 @@ def _compute_type(self, deep_typecheck): left_typ = self.left.typ right_typ = self.right.typ return hl.ttable( - left_typ.global_type, - left_typ.row_type._insert_field(self.root, right_typ.value_type), - left_typ.row_key) + left_typ.global_type, left_typ.row_type._insert_field(self.root, right_typ.value_type), left_typ.row_key + ) class TableIntervalJoin(TableIR): @@ -140,10 +136,7 @@ def _compute_type(self, deep_typecheck): right_val_typ = left_typ.row_type._insert_field(self.root, hl.tarray(right_typ.value_type)) else: right_val_typ = left_typ.row_type._insert_field(self.root, right_typ.value_type) - return hl.ttable( - left_typ.global_type, - right_val_typ, - left_typ.row_key) + return hl.ttable(left_typ.global_type, right_val_typ, left_typ.row_key) class TableUnion(TableIR): @@ -164,10 +157,13 @@ def _handle_randomness(self, uid_field_name): uids = [uid for uid, _ in (unpack_uid(child.typ.row_type, uid_field_name) for child in new_children)] uid_type = unify_uid_types((uid.typ for uid in uids), tag=True) - new_children = [TableMapRows(child, - ir.InsertFields(ir.Ref('row', child.typ.row_type), - [(uid_field_name, pad_uid(uid, uid_type, i))], None)) - for i, (child, uid) in enumerate(zip(new_children, uids))] + new_children = [ + TableMapRows( + child, + ir.InsertFields(ir.Ref('row', child.typ.row_type), [(uid_field_name, pad_uid(uid, uid_type, i))], None), + ) + for i, (child, uid) in enumerate(zip(new_children, uids)) + ] return TableUnion(new_children) def _compute_type(self, deep_typecheck): @@ -183,8 +179,12 @@ def __init__(self, n, n_partitions): self.n_partitions = n_partitions def _handle_randomness(self, uid_field_name): - assert(uid_field_name is not None) - new_row = ir.InsertFields(ir.Ref('row', self.typ.row_type), [(uid_field_name, ir.Cast(ir.GetField(ir.Ref('row', self.typ.row_type), 'idx'), tint64))], None) + assert uid_field_name is not None + new_row = ir.InsertFields( + ir.Ref('row', self.typ.row_type), + [(uid_field_name, ir.Cast(ir.GetField(ir.Ref('row', self.typ.row_type), 'idx'), tint64))], + None, + ) return TableMapRows(self, new_row) def head_str(self): @@ -194,9 +194,7 @@ def _eq(self, other): return self.n == other.n and 
self.n_partitions == other.n_partitions def _compute_type(self, deep_typecheck): - return hl.ttable(hl.tstruct(), - hl.tstruct(idx=hl.tint32), - ['idx']) + return hl.ttable(hl.tstruct(), hl.tstruct(idx=hl.tint32), ['idx']) class TableMapGlobals(TableIR): @@ -210,14 +208,11 @@ def _handle_randomness(self, uid_field_name): if new_globals.uses_randomness: new_globals = ir.Let('__rng_state', ir.RNGStateLiteral(), new_globals) - return TableMapGlobals(self.child.handle_randomness(uid_field_name), - new_globals) + return TableMapGlobals(self.child.handle_randomness(uid_field_name), new_globals) def _compute_type(self, deep_typecheck): self.new_globals.compute_type(self.child.typ.global_env(), None, deep_typecheck) - return hl.ttable(self.new_globals.typ, - self.child.typ.row_type, - self.child.typ.row_key) + return hl.ttable(self.new_globals.typ, self.child.typ.row_type, self.child.typ.row_key) def renderable_bindings(self, i, default_value=None): return self.child.typ.global_env(default_value) if i == 1 else {} @@ -243,7 +238,17 @@ def _handle_randomness(self, uid_field_name): ir.Ref('row', new_explode.typ.row_type), self.path, lambda tuple: ir.GetTupleElement(tuple, 0), - lambda row, tuple: ir.InsertFields(row, [(uid_field_name, concat_uids(ir.GetField(row, uid_field_name), ir.Cast(ir.GetTupleElement(tuple, 1), tint64)))], None)) + lambda row, tuple: ir.InsertFields( + row, + [ + ( + uid_field_name, + concat_uids(ir.GetField(row, uid_field_name), ir.Cast(ir.GetTupleElement(tuple, 1), tint64)), + ) + ], + None, + ), + ) return TableMapRows(new_explode, new_row) def head_str(self): @@ -255,9 +260,11 @@ def _eq(self, other): def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) atyp = self.child.typ.row_type._index_path(self.path) - return hl.ttable(self.child.typ.global_type, - self.child.typ.row_type._insert(self.path, atyp.element_type), - self.child.typ.row_key) + return hl.ttable( + self.child.typ.global_type, + self.child.typ.row_type._insert(self.path, atyp.element_type), + self.child.typ.row_key, + ) class TableKeyBy(TableIR): @@ -278,9 +285,7 @@ def _eq(self, other): def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) - return hl.ttable(self.child.typ.global_type, - self.child.typ.row_type, - self.keys) + return hl.ttable(self.child.typ.global_type, self.child.typ.row_type, self.keys) class TableMapRows(TableIR): @@ -309,10 +314,7 @@ def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) # agg_env for scans self.new_row.compute_type(self.child.typ.row_env(), self.child.typ.row_env(), deep_typecheck) - return hl.ttable( - self.child.typ.global_type, - self.new_row.typ, - self.child.typ.row_key) + return hl.ttable(self.child.typ.global_type, self.new_row.typ, self.child.typ.row_key) def renderable_bindings(self, i, default_value=None): if i == 1: @@ -339,27 +341,39 @@ def __init__(self, child, global_name, partition_stream_name, body, requested_ke def _handle_randomness(self, uid_field_name): if uid_field_name is not None: raise FatalError('TableMapPartitions does not support randomness, in its body or in consumers') - return TableMapPartitions(self.child.handle_randomness(None), self.global_name, self.partition_stream_name, self.body, self.requested_key, self.allowed_overlap) + return TableMapPartitions( + self.child.handle_randomness(None), + self.global_name, + self.partition_stream_name, + self.body, + self.requested_key, + self.allowed_overlap, + ) def _compute_type(self, deep_typecheck): 
self.child.compute_type(deep_typecheck) - self.body.compute_type({self.global_name: self.child.typ.global_type, - self.partition_stream_name: hl.tstream(self.child.typ.row_type)}, - {}, - deep_typecheck) + self.body.compute_type( + { + self.global_name: self.child.typ.global_type, + self.partition_stream_name: hl.tstream(self.child.typ.row_type), + }, + {}, + deep_typecheck, + ) assert isinstance(self.body.typ, hl.tstream) and isinstance(self.body.typ.element_type, hl.tstruct) new_row_type = self.body.typ.element_type for k in self.child.typ.row_key: assert k in new_row_type - return hl.ttable(self.child.typ.global_type, - new_row_type, - self.child.typ.row_key) + return hl.ttable(self.child.typ.global_type, new_row_type, self.child.typ.row_key) def renderable_bindings(self, i, default_value=None): if i == 1: - return {self.global_name: self.child.typ.global_type if default_value is None else default_value, - self.partition_stream_name: hl.tstream( - self.child.typ.row_type) if default_value is None else default_value} + return { + self.global_name: self.child.typ.global_type if default_value is None else default_value, + self.partition_stream_name: hl.tstream(self.child.typ.row_type) + if default_value is None + else default_value, + } else: return {} @@ -367,18 +381,17 @@ def head_str(self): return f'{escape_id(self.global_name)} {escape_id(self.partition_stream_name)} {self.requested_key} {self.allowed_overlap}' def _eq(self, other): - return (self.global_name == other.global_name - and self.partition_stream_name == other.partition_stream_name - and self.allowed_overlap == other.allowed_overlap) + return ( + self.global_name == other.global_name + and self.partition_stream_name == other.partition_stream_name + and self.allowed_overlap == other.allowed_overlap + ) class TableRead(TableIR): - def __init__(self, - reader, - drop_rows: bool = False, - drop_row_uids: bool = True, - *, - _assert_type: Optional['hl.ttable'] = None): + def __init__( + self, reader, drop_rows: bool = False, drop_row_uids: bool = True, *, _assert_type: Optional['hl.ttable'] = None + ): super().__init__() self.reader = reader self.drop_rows = drop_rows @@ -399,8 +412,8 @@ def _handle_randomness(self, uid_field_name): else: row = ir.Ref('row', self.typ.row_type) result = TableMapRows( - result, - ir.InsertFields(row, [(uid_field_name, ir.GetField(row, default_row_uid))], None)) + result, ir.InsertFields(row, [(uid_field_name, ir.GetField(row, default_row_uid))], None) + ) return result def head_str(self): @@ -411,9 +424,11 @@ def head_str(self): return f'{reqType} {self.drop_rows} "{self.reader.render()}"' def _eq(self, other): - return (self.reader == other.reader - and self.drop_rows == other.drop_rows - and self.drop_row_uids == other.drop_row_uids) + return ( + self.reader == other.reader + and self.drop_rows == other.drop_rows + and self.drop_row_uids == other.drop_row_uids + ) def _compute_type(self, deep_typecheck): if self._type is not None: @@ -429,6 +444,7 @@ def __init__(self, child): def _handle_randomness(self, uid_field_name): from hail.ir.matrix_ir import MatrixMapEntries, MatrixMapRows + if uid_field_name is None: return MatrixEntriesTable(self.child.handle_randomness(None, None)) @@ -437,18 +453,25 @@ def _handle_randomness(self, uid_field_name): entry = ir.Ref('g', child.typ.entry_type) row_uid = ir.GetField(ir.Ref('va', child.typ.row_type), temp_row_uid) col_uid = ir.GetField(ir.Ref('sa', child.typ.col_type), default_col_uid) - child = MatrixMapEntries(child, ir.InsertFields(entry, 
[('__entry_uid', ir.concat_uids(row_uid, col_uid))], None)) - child = MatrixMapRows(child, ir.SelectFields(ir.Ref('va', child.typ.row_type), [field for field in child.typ.row_type if field != temp_row_uid])) - return TableRename(MatrixEntriesTable(child), {'__entry_uid': default_row_uid}, {}) + child = MatrixMapEntries( + child, ir.InsertFields(entry, [('__entry_uid', ir.concat_uids(row_uid, col_uid))], None) + ) + child = MatrixMapRows( + child, + ir.SelectFields( + ir.Ref('va', child.typ.row_type), [field for field in child.typ.row_type if field != temp_row_uid] + ), + ) + return TableRename(MatrixEntriesTable(child), {'__entry_uid': uid_field_name}, {}) def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) child_typ = self.child.typ - return hl.ttable(child_typ.global_type, - child_typ.row_type - ._concat(child_typ.col_type) - ._concat(child_typ.entry_type), - child_typ.row_key + child_typ.col_key) + return hl.ttable( + child_typ.global_type, + child_typ.row_type._concat(child_typ.col_type)._concat(child_typ.entry_type), + child_typ.row_key + child_typ.col_key, + ) class TableFilter(TableIR): @@ -505,23 +528,15 @@ def _handle_randomness(self, uid_field_name): if expr.uses_randomness or uid_field_name is not None: first_uid = ir.Ref(Env.get_uid(), uid.typ) if expr.uses_randomness: - expr = ir.Let( - '__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(), first_uid), - expr) + expr = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), first_uid), expr) if expr.uses_agg_randomness(is_scan=False): - expr = ir.AggLet('__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(), uid), - expr, is_scan=False) + expr = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), uid), expr, is_scan=False) if uid_field_name is not None: expr = ir.InsertFields(expr, [(uid_field_name, first_uid)], None) expr = ir.Let(first_uid.name, ir.ArrayRef(ir.ApplyAggOp('Take', [ir.I32(1)], [uid]), ir.I32(0)), expr) new_key = self.new_key if new_key.uses_randomness: - expr = ir.Let( - '__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(), uid), - new_key) + expr = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), uid), new_key) return TableKeyByAndAggregate(child, expr, new_key, self.n_partitions, self.buffer_size) def head_str(self): @@ -534,9 +549,7 @@ def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) self.expr.compute_type(self.child.typ.global_env(), self.child.typ.row_env(), deep_typecheck) self.new_key.compute_type(self.child.typ.row_env(), None, deep_typecheck) - return hl.ttable(self.child.typ.global_type, - self.new_key.typ._concat(self.expr.typ), - list(self.new_key.typ)) + return hl.ttable(self.child.typ.global_type, self.new_key.typ._concat(self.expr.typ), list(self.new_key.typ)) def renderable_bindings(self, i, default_value=None): if i == 1: @@ -569,16 +582,9 @@ def _handle_randomness(self, uid_field_name): expr = ir.AggLet('va', old_row, self.expr, is_scan=False) first_uid = ir.Ref(Env.get_uid(), uid.typ) if expr.uses_value_randomness: - expr = ir.Let( - '__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(), first_uid), - expr) + expr = ir.Let('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), first_uid), expr) if expr.uses_agg_randomness(is_scan=False): - expr = ir.AggLet( - '__rng_state', - ir.RNGSplit(ir.RNGStateLiteral(), uid), - expr, - is_scan=False) + expr = ir.AggLet('__rng_state', ir.RNGSplit(ir.RNGStateLiteral(), uid), expr, is_scan=False) if uid_field_name is not None: expr = ir.InsertFields(expr, [(uid_field_name, first_uid)], None) expr = 
ir.Let(first_uid.name, ir.ArrayRef(ir.ApplyAggOp('Take', [ir.I32(1)], [uid]), ir.I32(0)), expr) @@ -588,9 +594,7 @@ def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) child_typ = self.child.typ self.expr.compute_type(child_typ.global_env(), child_typ.row_env(), deep_typecheck) - return hl.ttable(child_typ.global_type, - child_typ.key_type._concat(self.expr.typ), - child_typ.row_key) + return hl.ttable(child_typ.global_type, child_typ.key_type._concat(self.expr.typ), child_typ.row_key) def renderable_bindings(self, i, default_value=None): if i == 1: @@ -614,9 +618,7 @@ def _handle_randomness(self, uid_field_name): def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) - return hl.ttable(self.child.typ.global_type, - self.child.typ.col_type, - self.child.typ.col_key) + return hl.ttable(self.child.typ.global_type, self.child.typ.col_type, self.child.typ.col_key) class TableParallelize(TableIR): @@ -628,29 +630,30 @@ def __init__(self, rows_and_global, n_partitions): def _handle_randomness(self, uid_field_name): rows_and_global = self.rows_and_global if rows_and_global.uses_randomness: - rows_and_global = ir.Let( - '__rng_state', - ir.RNGStateLiteral(), - rows_and_global) + rows_and_global = ir.Let('__rng_state', ir.RNGStateLiteral(), rows_and_global) if uid_field_name is not None: rows_and_global_ref = ir.Ref(Env.get_uid(), rows_and_global.typ) row = Env.get_uid() uid = Env.get_uid() iota = ir.StreamIota(ir.I32(0), ir.I32(1)) rows = ir.ToStream(ir.GetField(rows_and_global_ref, 'rows')) - new_rows = ir.ToArray(ir.StreamZip( - [rows, iota], - [row, uid], - ir.InsertFields( - ir.Ref(row, rows.typ.element_type), - [(uid_field_name, ir.Cast(ir.Ref(uid, tint32), tint64))], - None), - 'TakeMinLength')) - rows_and_global = \ - ir.Let(rows_and_global_ref.name, rows_and_global, - ir.InsertFields( - rows_and_global_ref, - [('rows', new_rows)], None)) + new_rows = ir.ToArray( + ir.StreamZip( + [rows, iota], + [row, uid], + ir.InsertFields( + ir.Ref(row, rows.typ.element_type), + [(uid_field_name, ir.Cast(ir.Ref(uid, tint32), tint64))], + None, + ), + 'TakeMinLength', + ) + ) + rows_and_global = ir.Let( + rows_and_global_ref.name, + rows_and_global, + ir.InsertFields(rows_and_global_ref, [('rows', new_rows)], None), + ) return TableParallelize(rows_and_global, self.n_partitions) def head_str(self): @@ -661,9 +664,7 @@ def _eq(self, other): def _compute_type(self, deep_typecheck): self.rows_and_global.compute_type({}, None, deep_typecheck) - return hl.ttable(self.rows_and_global.typ['global'], - self.rows_and_global.typ['rows'].element_type, - []) + return hl.ttable(self.rows_and_global.typ['global'], self.rows_and_global.typ['rows'].element_type, []) class TableHead(TableIR): @@ -723,9 +724,7 @@ def _eq(self, other): def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) - return hl.ttable(self.child.typ.global_type, - self.child.typ.row_type, - []) + return hl.ttable(self.child.typ.global_type, self.child.typ.row_type, []) class TableDistinct(TableIR): @@ -776,7 +775,9 @@ def __init__(self, child, entries_field_name, cols_field_name): self.cols_field_name = cols_field_name def _handle_randomness(self, uid_field_name): - return CastMatrixToTable(self.child.handle_randomness(uid_field_name, None), self.entries_field_name, self.cols_field_name) + return CastMatrixToTable( + self.child.handle_randomness(uid_field_name, None), self.entries_field_name, self.cols_field_name + ) def head_str(self): return 
f'"{escape_str(self.entries_field_name)}" "{escape_str(self.cols_field_name)}"' @@ -787,10 +788,11 @@ def _eq(self, other): def _compute_type(self, deep_typecheck): self.child.compute_type(deep_typecheck) child_typ = self.child.typ - return hl.ttable(child_typ.global_type._insert_field(self.cols_field_name, hl.tarray(child_typ.col_type)), - child_typ.row_type._insert_field(self.entries_field_name, - hl.tarray(child_typ.entry_type)), - child_typ.row_key) + return hl.ttable( + child_typ.global_type._insert_field(self.cols_field_name, hl.tarray(child_typ.col_type)), + child_typ.row_type._insert_field(self.entries_field_name, hl.tarray(child_typ.entry_type)), + child_typ.row_key, + ) class TableRename(TableIR): @@ -804,10 +806,12 @@ def _handle_randomness(self, uid_field_name): return TableRename(self.child.handle_randomness(uid_field_name), self.row_map, self.global_map) def head_str(self): - return f'{parsable_strings(self.row_map.keys())} ' \ - f'{parsable_strings(self.row_map.values())} ' \ - f'{parsable_strings(self.global_map.keys())} ' \ - f'{parsable_strings(self.global_map.values())} ' + return ( + f'{parsable_strings(self.row_map.keys())} ' + f'{parsable_strings(self.row_map.values())} ' + f'{parsable_strings(self.global_map.keys())} ' + f'{parsable_strings(self.global_map.values())} ' + ) def _eq(self, other): return self.row_map == other.row_map and self.global_map == other.global_map @@ -835,22 +839,22 @@ def _handle_randomness(self, uid_field_name): new_children = [ TableMapRows( child, - ir.InsertFields(ir.Ref('row', child.typ.row_type), - [(uid_field_name, - pad_uid(uid, uid_type, i))], None)) - for i, (child, uid) in enumerate(zip(new_children, uids))] + ir.InsertFields(ir.Ref('row', child.typ.row_type), [(uid_field_name, pad_uid(uid, uid_type, i))], None), + ) + for i, (child, uid) in enumerate(zip(new_children, uids)) + ] joined = TableMultiWayZipJoin(new_children, self.data_name, self.global_name) accum = ir.Ref(Env.get_uid(), uid_type) elt = Env.get_uid() row = ir.Ref('row', joined.typ.row_type) data = ir.GetField(row, self.data_name) uid = ir.StreamFold( - ir.toStream(data), ir.NA(uid_type), accum.name, elt, - ir.If(ir.IsNA(accum), - ir.GetField( - ir.Ref(elt, data.typ.element_type), - uid_field_name), - accum)) + ir.toStream(data), + ir.NA(uid_type), + accum.name, + elt, + ir.If(ir.IsNA(accum), ir.GetField(ir.Ref(elt, data.typ.element_type), uid_field_name), accum), + ) return TableMapRows(joined, ir.InsertFields(row, [(uid_field_name, uid)], None)) def head_str(self): @@ -866,7 +870,8 @@ def _compute_type(self, deep_typecheck): return hl.ttable( hl.tstruct(**{self.global_name: hl.tarray(child_typ.global_type)}), child_typ.key_type._insert_field(self.data_name, hl.tarray(child_typ.value_type)), - child_typ.row_key) + child_typ.row_key, + ) class TableFilterIntervals(TableIR): @@ -878,7 +883,9 @@ def __init__(self, child, intervals, point_type, keep): self.keep = keep def _handle_randomness(self, uid_field_name): - return TableFilterIntervals(self.child.handle_randomness(uid_field_name), self.intervals, self.point_type, self.keep) + return TableFilterIntervals( + self.child.handle_randomness(uid_field_name), self.intervals, self.point_type, self.keep + ) def head_str(self): return f'{self.child.typ.key_type._parsable_string()} {dump_json(hl.tarray(hl.tinterval(self.point_type))._convert_to_json(self.intervals))} {self.keep}' @@ -927,7 +934,8 @@ def regression_test_type(test): glm_fit_schema = dtype('struct{n_iterations:int32,converged:bool,exploded:bool}') if test == 
'wald': return dtype( - f'struct{{beta:float64,standard_error:float64,z_stat:float64,p_value:float64,fit:{glm_fit_schema}}}') + f'struct{{beta:float64,standard_error:float64,z_stat:float64,p_value:float64,fit:{glm_fit_schema}}}' + ) elif test == 'lrt': return dtype(f'struct{{beta:float64,chi_sq_stat:float64,p_value:float64,fit:{glm_fit_schema}}}') elif test == 'score': @@ -962,73 +970,82 @@ def _compute_type(self, deep_typecheck): if name == 'LinearRegressionRowsChained': pass_through = self.config['passThrough'] chained_schema = hl.dtype( - 'struct{n:array<int32>,sum_x:array<float64>,y_transpose_x:array<array<float64>>,beta:array<array<float64>>,standard_error:array<array<float64>>,t_stat:array<array<float64>>,p_value:array<array<float64>>}') + 'struct{n:array<int32>,sum_x:array<float64>,y_transpose_x:array<array<float64>>,beta:array<array<float64>>,standard_error:array<array<float64>>,t_stat:array<array<float64>>,p_value:array<array<float64>>}' + ) return hl.ttable( child_typ.global_type, - (child_typ.row_key_type - ._insert_fields(**{f: child_typ.row_type[f] for f in pass_through}) - ._concat(chained_schema)), - child_typ.row_key) + ( + child_typ.row_key_type._insert_fields(**{f: child_typ.row_type[f] for f in pass_through})._concat( + chained_schema + ) + ), + child_typ.row_key, + ) elif name == 'LinearRegressionRowsSingle': pass_through = self.config['passThrough'] chained_schema = hl.dtype( - 'struct{n:int32,sum_x:float64,y_transpose_x:array<float64>,beta:array<float64>,standard_error:array<float64>,t_stat:array<float64>,p_value:array<float64>}') + 'struct{n:int32,sum_x:float64,y_transpose_x:array<float64>,beta:array<float64>,standard_error:array<float64>,t_stat:array<float64>,p_value:array<float64>}' + ) return hl.ttable( child_typ.global_type, - (child_typ.row_key_type - ._insert_fields(**{f: child_typ.row_type[f] for f in pass_through}) - ._concat(chained_schema)), - child_typ.row_key) + ( + child_typ.row_key_type._insert_fields(**{f: child_typ.row_type[f] for f in pass_through})._concat( + chained_schema + ) + ), + child_typ.row_key, + ) elif name == 'LogisticRegression': pass_through = self.config['passThrough'] logreg_type = hl.tstruct(logistic_regression=hl.tarray(regression_test_type(self.config['test']))) return hl.ttable( child_typ.global_type, - (child_typ.row_key_type - ._insert_fields(**{f: child_typ.row_type[f] for f in pass_through}) - ._concat(logreg_type)), - child_typ.row_key) + ( + child_typ.row_key_type._insert_fields(**{f: child_typ.row_type[f] for f in pass_through})._concat( + logreg_type + ) + ), + child_typ.row_key, + ) elif name == 'PoissonRegression': pass_through = self.config['passThrough'] poisreg_type = regression_test_type(self.config['test']) return hl.ttable( child_typ.global_type, - (child_typ.row_key_type - ._insert_fields(**{f: child_typ.row_type[f] for f in pass_through}) - ._concat(poisreg_type)), - child_typ.row_key) + ( + child_typ.row_key_type._insert_fields(**{f: child_typ.row_type[f] for f in pass_through})._concat( + poisreg_type + ) + ), + child_typ.row_key, + ) elif name == 'Skat': key_field = self.config['keyField'] key_type = child_typ.row_type[key_field] skat_type = hl.dtype(f'struct{{id:{key_type},size:int32,q_stat:float64,p_value:float64,fault:int32}}') - return hl.ttable( - hl.tstruct(), - skat_type, - ['id']) + return hl.ttable(hl.tstruct(), skat_type, ['id']) elif name == 'PCA': return hl.ttable( - hl.tstruct(eigenvalues=hl.tarray(hl.tfloat64), - scores=hl.tarray(child_typ.col_key_type._insert_field('scores', hl.tarray(hl.tfloat64)))), + hl.tstruct( + eigenvalues=hl.tarray(hl.tfloat64), + scores=hl.tarray(child_typ.col_key_type._insert_field('scores', hl.tarray(hl.tfloat64))), + ), child_typ.row_key_type._insert_field('loadings', dtype('array<float64>')), - child_typ.row_key) + child_typ.row_key, + ) elif
name == 'IBD': ibd_info_type = hl.tstruct(Z0=hl.tfloat64, Z1=hl.tfloat64, Z2=hl.tfloat64, PI_HAT=hl.tfloat64) - ibd_type = hl.tstruct(i=hl.tstr, - j=hl.tstr, - ibd=ibd_info_type, - ibs0=hl.tint64, - ibs1=hl.tint64, - ibs2=hl.tint64) - return hl.ttable( - hl.tstruct(), - ibd_type, - ['i', 'j']) + ibd_type = hl.tstruct( + i=hl.tstr, j=hl.tstr, ibd=ibd_info_type, ibs0=hl.tint64, ibs1=hl.tint64, ibs2=hl.tint64 + ) + return hl.ttable(hl.tstruct(), ibd_type, ['i', 'j']) else: assert name == 'LocalLDPrune', name return hl.ttable( hl.tstruct(), child_typ.row_key_type._insert_fields(mean=hl.tfloat64, centered_length_rec=hl.tfloat64), - list(child_typ.row_key)) + list(child_typ.row_key), + ) class BlockMatrixToTableApply(TableIR): @@ -1054,12 +1071,9 @@ def _compute_type(self, deep_typecheck): assert name == 'PCRelate', name return hl.ttable( hl.tstruct(), - hl.tstruct(i=hl.tint32, j=hl.tint32, - kin=hl.tfloat64, - ibd0=hl.tfloat64, - ibd1=hl.tfloat64, - ibd2=hl.tfloat64), - ['i', 'j']) + hl.tstruct(i=hl.tint32, j=hl.tint32, kin=hl.tfloat64, ibd0=hl.tfloat64, ibd1=hl.tfloat64, ibd2=hl.tfloat64), + ['i', 'j'], + ) class BlockMatrixToTable(TableIR): @@ -1071,8 +1085,9 @@ def _handle_randomness(self, uid_field_name): result = self if uid_field_name is not None: row = ir.Ref('row', result.typ.row_type) - new_row = ir.InsertFields(row, [(uid_field_name, - ir.MakeTuple([ir.GetField(row, 'i'), ir.GetField(row, 'j')]))], None) + new_row = ir.InsertFields( + row, [(uid_field_name, ir.MakeTuple([ir.GetField(row, 'i'), ir.GetField(row, 'j')]))], None + ) result = TableMapRows(result, new_row) return result @@ -1082,10 +1097,7 @@ def _compute_type(self, deep_typecheck): class Partitioner(object): - @typecheck_method( - key_type=tstruct, - range_bounds=sequenceof(Interval) - ) + @typecheck_method(key_type=tstruct, range_bounds=sequenceof(Interval)) def __init__(self, key_type, range_bounds): assert all(map(lambda interval: interval.point_type == key_type, range_bounds)) self._key_type = key_type @@ -1101,10 +1113,8 @@ def range_bounds(self): return self._range_bounds def _parsable_string(self): - return ( - f'Partitioner {self.key_type._parsable_string()} ' + dump_json( - self._serialized_type._convert_to_json(self.range_bounds) - ) + return f'Partitioner {self.key_type._parsable_string()} ' + dump_json( + self._serialized_type._convert_to_json(self.range_bounds) ) def __str__(self): @@ -1113,13 +1123,7 @@ def __str__(self): class TableGen(TableIR): @typecheck_method( - contexts=IR, - globals=IR, - cname=str, - gname=str, - body=IR, - partitioner=Partitioner, - error_id=nullable(int) + contexts=IR, globals=IR, cname=str, gname=str, body=IR, partitioner=Partitioner, error_id=nullable(int) ) def __init__(self, contexts, globals, cname, gname, body, partitioner, error_id=None): super().__init__(contexts, globals, body) @@ -1136,22 +1140,21 @@ def __init__(self, contexts, globals, cname, gname, body, partitioner, error_id= def _compute_type(self, deep_typecheck): self.contexts.compute_type({}, None, deep_typecheck) self.globals.compute_type({}, None, deep_typecheck) - bodyenv = { - self.cname: self.contexts.typ.element_type, - self.gname: self.globals.typ - } + bodyenv = {self.cname: self.contexts.typ.element_type, self.gname: self.globals.typ} self.body.compute_type(bodyenv, None, deep_typecheck) return hl.ttable( - global_type=self.globals.typ, - row_type=self.body.typ.element_type, - row_key=self.partitioner.key_type.fields + global_type=self.globals.typ, row_type=self.body.typ.element_type, 
row_key=self.partitioner.key_type.fields ) def renderable_bindings(self, i, default_value=None): - return {} if i != 2 else { - self.cname: self.contexts.type.element_type if default_value is None else default_value, - self.gname: self.globals.type if default_value is None else default_value - } + return ( + {} + if i != 2 + else { + self.cname: self.contexts.type.element_type if default_value is None else default_value, + self.gname: self.globals.type if default_value is None else default_value, + } + ) def _eq(self, other): return ( @@ -1162,12 +1165,7 @@ def _eq(self, other): ) def head_str(self): - return ' '.join([ - self.cname, - self.gname, - '(' + self.partitioner._parsable_string() + ')', - str(self._error_id) - ]) + return ' '.join([self.cname, self.gname, '(' + self.partitioner._parsable_string() + ')', str(self._error_id)]) def _handle_randomness(self, uid_field_name): globals = self.globals @@ -1183,30 +1181,17 @@ def _handle_randomness(self, uid_field_name): body = ir.Let(self.cname, old_context, self.body) if body.uses_randomness: - body = ir.Let('__rng_state', ir.RNGStateLiteral(), - ir.with_split_rng_state(body, random_uid)) + body = ir.Let('__rng_state', ir.RNGStateLiteral(), ir.with_split_rng_state(body, random_uid)) if uid_field_name is not None: idx = ir.Ref(Env.get_uid(), ir.tint32) elem = ir.Ref(Env.get_uid(), body.typ.element_type) - insert = ir.InsertFields( - elem, - [(uid_field_name, concat_uids(random_uid, ir.Cast(idx, ir.tint64)))], - None - ) + insert = ir.InsertFields(elem, [(uid_field_name, concat_uids(random_uid, ir.Cast(idx, ir.tint64)))], None) iota = ir.StreamIota(ir.I32(0), ir.I32(1)) body = ir.StreamZip([iota, body], [idx.name, elem.name], insert, 'TakeMinLength') - return TableGen( - contexts, - globals, - cname, - self.gname, - body, - self.partitioner, - self._error_id - ) + return TableGen(contexts, globals, cname, self.gname, body, self.partitioner, self._error_id) class JavaTable(TableIR): @@ -1226,6 +1211,7 @@ def _compute_type(self, deep_typecheck): def __del__(self): from hail.backend.py4j_backend import Py4JBackend + if Env._hc: backend = Env.backend() assert isinstance(backend, Py4JBackend) diff --git a/hail/python/hail/ir/table_reader.py b/hail/python/hail/ir/table_reader.py index 84ec6530ace..e3dd8c69b01 100644 --- a/hail/python/hail/ir/table_reader.py +++ b/hail/python/hail/ir/table_reader.py @@ -4,9 +4,8 @@ import avro.schema import hail as hl - -from hail.ir.utils import make_filter_and_replace, default_row_uid -from hail.typecheck import typecheck_method, sequenceof, nullable, anytype, oneof +from hail.ir.utils import default_row_uid, make_filter_and_replace +from hail.typecheck import anytype, nullable, oneof, sequenceof, typecheck_method from hail.utils.misc import escape_str from .utils import impute_type_of_partition_interval_array @@ -23,39 +22,51 @@ def __eq__(self, other): class TableNativeReader(TableReader): - @typecheck_method(path=str, - intervals=nullable(sequenceof(anytype)), - filter_intervals=bool) + @typecheck_method(path=str, intervals=nullable(sequenceof(anytype)), filter_intervals=bool) def __init__(self, path, intervals, filter_intervals): self.path = path self.filter_intervals = filter_intervals self.intervals, self._interval_type = impute_type_of_partition_interval_array(intervals) def render(self): - reader = {'name': 'TableNativeReader', - 'path': self.path} + reader = {'name': 'TableNativeReader', 'path': self.path} if self.intervals is not None: assert self._interval_type is not None reader['options'] = { 
'name': 'NativeReaderOptions', 'intervals': self._interval_type._convert_to_json(self.intervals), 'intervalPointType': self._interval_type.element_type.point_type._parsable_string(), - 'filterIntervals': self.filter_intervals + 'filterIntervals': self.filter_intervals, } return escape_str(json.dumps(reader)) def __eq__(self, other): - return isinstance(other, TableNativeReader) and \ - other.path == self.path and \ - other.intervals == self.intervals and \ - other.filter_intervals == self.filter_intervals + return ( + isinstance(other, TableNativeReader) + and other.path == self.path + and other.intervals == self.intervals + and other.filter_intervals == self.filter_intervals + ) class TextTableReader(TableReader): - def __init__(self, paths, min_partitions, types, comment, - delimiter, missing, no_header, quote, - skip_blank_lines, force_bgz, filter, find_replace, - force_gz, source_file_field): + def __init__( + self, + paths, + min_partitions, + types, + comment, + delimiter, + missing, + no_header, + quote, + skip_blank_lines, + force_bgz, + filter, + find_replace, + force_gz, + source_file_field, + ): self.config = { 'files': paths, 'typeMapStr': {f: t._parsable_string() for f, t in types.items()}, @@ -69,7 +80,7 @@ def __init__(self, paths, min_partitions, types, comment, 'forceBGZ': force_bgz, 'filterAndReplace': make_filter_and_replace(filter, find_replace), 'forceGZ': force_gz, - 'sourceFileField': source_file_field + 'sourceFileField': source_file_field, } def render(self): @@ -78,13 +89,17 @@ def render(self): return escape_str(json.dumps(reader)) def __eq__(self, other): - return isinstance(other, TextTableReader) and \ - other.config == self.config + return isinstance(other, TextTableReader) and other.config == self.config class StringTableReader(TableReader): - @typecheck_method(paths=oneof(str, sequenceof(str)), min_partitions=nullable(int), force_bgz=bool, - force=bool, file_per_partition=bool) + @typecheck_method( + paths=oneof(str, sequenceof(str)), + min_partitions=nullable(int), + force_bgz=bool, + force=bool, + file_per_partition=bool, + ) def __init__(self, paths, min_partitions, force_bgz, force, file_per_partition): self.paths = paths self.min_partitions = min_partitions @@ -93,22 +108,26 @@ def __init__(self, paths, min_partitions, force_bgz, force, file_per_partition): self.file_per_partition = file_per_partition def render(self): - reader = {'name': 'StringTableReader', - 'files': self.paths, - 'minPartitions': self.min_partitions, - 'forceBGZ': self.force_bgz, - 'forceGZ': self.force, - 'filePerPartition': self.file_per_partition} + reader = { + 'name': 'StringTableReader', + 'files': self.paths, + 'minPartitions': self.min_partitions, + 'forceBGZ': self.force_bgz, + 'forceGZ': self.force, + 'filePerPartition': self.file_per_partition, + } return escape_str(json.dumps(reader)) def __eq__(self, other): - return isinstance(other, StringTableReader) and \ - other.path == self.path and \ - other.min_partitions == self.min_partitions and \ - other.force_bgz == self.force_bgz and \ - other.force == self.force and \ - other.file_per_partition == self.file_per_partition + return ( + isinstance(other, StringTableReader) + and other.path == self.path + and other.min_partitions == self.min_partitions + and other.force_bgz == self.force_bgz + and other.force == self.force + and other.file_per_partition == self.file_per_partition + ) class TableFromBlockMatrixNativeReader(TableReader): @@ -119,24 +138,30 @@ def __init__(self, path, n_partitions, 
maximum_cache_memory_in_bytes): self.maximum_cache_memory_in_bytes = maximum_cache_memory_in_bytes def render(self): - reader = {'name': 'TableFromBlockMatrixNativeReader', - 'path': self.path, - 'nPartitions': self.n_partitions, - 'maximumCacheMemoryInBytes': self.maximum_cache_memory_in_bytes} + reader = { + 'name': 'TableFromBlockMatrixNativeReader', + 'path': self.path, + 'nPartitions': self.n_partitions, + 'maximumCacheMemoryInBytes': self.maximum_cache_memory_in_bytes, + } return escape_str(json.dumps(reader)) def __eq__(self, other): - return isinstance(other, TableFromBlockMatrixNativeReader) and \ - other.path == self.path and \ - other.n_partitions == self.n_partitions and \ - other.maximum_cache_memory_in_bytes == self.maximum_cache_memory_in_bytes + return ( + isinstance(other, TableFromBlockMatrixNativeReader) + and other.path == self.path + and other.n_partitions == self.n_partitions + and other.maximum_cache_memory_in_bytes == self.maximum_cache_memory_in_bytes + ) class AvroTableReader(TableReader): - @typecheck_method(schema=avro.schema.Schema, - paths=sequenceof(str), - key=nullable(sequenceof(str)), - intervals=nullable(sequenceof(anytype))) + @typecheck_method( + schema=avro.schema.Schema, + paths=sequenceof(str), + key=nullable(sequenceof(str)), + intervals=nullable(sequenceof(anytype)), + ) def __init__(self, schema, paths, key, intervals): assert (key is None) == (intervals is None) self.schema = schema @@ -154,19 +179,23 @@ def __init__(self, schema, paths, key, intervals): self._interval_type = hl.tarray(hl.tinterval(hl.tstruct(__point=pt))) if intervals is not None and t != self._interval_type: - self.intervals = [hl.Interval(hl.Struct(__point=i.start), - hl.Struct(__point=i.end), - i.includes_start, - i.includes_end) for i in intervals] + self.intervals = [ + hl.Interval(hl.Struct(__point=i.start), hl.Struct(__point=i.end), i.includes_start, i.includes_end) + for i in intervals + ] else: self.intervals = intervals def render(self): - reader = {'name': 'AvroTableReader', - 'partitionReader': {'name': 'AvroPartitionReader', - 'schema': self.schema.to_json(), - 'uidFieldName': default_row_uid}, - 'paths': self.paths} + reader = { + 'name': 'AvroTableReader', + 'partitionReader': { + 'name': 'AvroPartitionReader', + 'schema': self.schema.to_json(), + 'uidFieldName': default_row_uid, + }, + 'paths': self.paths, + } if self.key is not None: assert self.intervals is not None assert self._interval_type is not None @@ -179,8 +208,10 @@ def render(self): return escape_str(json.dumps(reader)) def __eq__(self, other): - return isinstance(other, AvroTableReader) and \ - other.schema == self.schema and \ - other.paths == self.paths and \ - other.key == self.key and \ - other.intervals == self.intervals + return ( + isinstance(other, AvroTableReader) + and other.schema == self.schema + and other.paths == self.paths + and other.key == self.key + and other.intervals == self.intervals + ) diff --git a/hail/python/hail/ir/table_writer.py b/hail/python/hail/ir/table_writer.py index 891b0401ff1..87ff642e3dd 100644 --- a/hail/python/hail/ir/table_writer.py +++ b/hail/python/hail/ir/table_writer.py @@ -1,6 +1,7 @@ import abc import json -from ..typecheck import typecheck_method, nullable, sequenceof + +from ..typecheck import nullable, sequenceof, typecheck_method from ..utils.misc import escape_str from .export_type import ExportType @@ -16,10 +17,7 @@ def __eq__(self, other): class TableNativeWriter(TableWriter): - @typecheck_method(path=str, - overwrite=bool, - stage_locally=bool, - 
codec_spec=nullable(str)) + @typecheck_method(path=str, overwrite=bool, stage_locally=bool, codec_spec=nullable(str)) def __init__(self, path, overwrite, stage_locally, codec_spec): super(TableNativeWriter, self).__init__() self.path = path @@ -28,27 +26,27 @@ def __init__(self, path, overwrite, stage_locally, codec_spec): self.codec_spec = codec_spec def render(self): - writer = {'name': 'TableNativeWriter', - 'path': self.path, - 'overwrite': self.overwrite, - 'stageLocally': self.stage_locally, - 'codecSpecJSONStr': self.codec_spec} + writer = { + 'name': 'TableNativeWriter', + 'path': self.path, + 'overwrite': self.overwrite, + 'stageLocally': self.stage_locally, + 'codecSpecJSONStr': self.codec_spec, + } return escape_str(json.dumps(writer)) def __eq__(self, other): - return isinstance(other, TableNativeWriter) and \ - other.path == self.path and \ - other.overwrite == self.overwrite and \ - other.stage_locally == self.stage_locally and \ - other.codec_spec == self.codec_spec + return ( + isinstance(other, TableNativeWriter) + and other.path == self.path + and other.overwrite == self.overwrite + and other.stage_locally == self.stage_locally + and other.codec_spec == self.codec_spec + ) class TableTextWriter(TableWriter): - @typecheck_method(path=str, - types_file=nullable(str), - header=bool, - export_type=ExportType.checker, - delimiter=str) + @typecheck_method(path=str, types_file=nullable(str), header=bool, export_type=ExportType.checker, delimiter=str) def __init__(self, path, types_file, header, export_type, delimiter): super(TableTextWriter, self).__init__() self.path = path @@ -58,29 +56,29 @@ def __init__(self, path, types_file, header, export_type, delimiter): self.delimiter = delimiter def render(self): - writer = {'name': 'TableTextWriter', - 'path': self.path, - 'typesFile': self.types_file, - 'header': self.header, - 'exportType': self.export_type, - 'delimiter': self.delimiter} + writer = { + 'name': 'TableTextWriter', + 'path': self.path, + 'typesFile': self.types_file, + 'header': self.header, + 'exportType': self.export_type, + 'delimiter': self.delimiter, + } return escape_str(json.dumps(writer)) def __eq__(self, other): - return isinstance(other, TableTextWriter) and \ - other.path == self.path and \ - other.types_file == self.types_file and \ - other.header == self.header and \ - other.export_type == self.export_type and \ - other.delimiter == self.delimiter + return ( + isinstance(other, TableTextWriter) + and other.path == self.path + and other.types_file == self.types_file + and other.header == self.header + and other.export_type == self.export_type + and other.delimiter == self.delimiter + ) class TableNativeFanoutWriter(TableWriter): - @typecheck_method(path=str, - fields=sequenceof(str), - overwrite=bool, - stage_locally=bool, - codec_spec=nullable(str)) + @typecheck_method(path=str, fields=sequenceof(str), overwrite=bool, stage_locally=bool, codec_spec=nullable(str)) def __init__(self, path, fields, overwrite, stage_locally, codec_spec): super(TableNativeFanoutWriter, self).__init__() self.path = path @@ -90,18 +88,22 @@ def __init__(self, path, fields, overwrite, stage_locally, codec_spec): self.codec_spec = codec_spec def render(self): - writer = {'name': 'TableNativeFanoutWriter', - 'path': self.path, - 'fields': self.fields, - 'overwrite': self.overwrite, - 'stageLocally': self.stage_locally, - 'codecSpecJSONStr': self.codec_spec} + writer = { + 'name': 'TableNativeFanoutWriter', + 'path': self.path, + 'fields': self.fields, + 'overwrite': 
self.overwrite, + 'stageLocally': self.stage_locally, + 'codecSpecJSONStr': self.codec_spec, + } return escape_str(json.dumps(writer)) def __eq__(self, other): - return isinstance(other, TableNativeWriter) and \ - other.path == self.path and \ - other.fields == self.fields and \ - other.overwrite == self.overwrite and \ - other.stage_locally == self.stage_locally and \ - other.codec_spec == self.codec_spec + return ( + isinstance(other, TableNativeWriter) + and other.path == self.path + and other.fields == self.fields + and other.overwrite == self.overwrite + and other.stage_locally == self.stage_locally + and other.codec_spec == self.codec_spec + ) diff --git a/hail/python/hail/ir/utils.py b/hail/python/hail/ir/utils.py index 34c2cb00e80..950e05b4d2c 100644 --- a/hail/python/hail/ir/utils.py +++ b/hail/python/hail/ir/utils.py @@ -1,11 +1,13 @@ -from typing import Optional, List, Any, Tuple +from typing import Any, List, Optional, Tuple + import hail as hl -from hail.utils.java import Env from hail.expr.types import tint32, tint64 +from hail.utils.java import Env def finalize_randomness(x): - import hail.ir.ir as ir + from hail.ir import ir + if isinstance(x, ir.IR): x = ir.Let('__rng_state', ir.RNGStateLiteral(), x) elif isinstance(x, ir.TableIR): @@ -20,7 +22,8 @@ def finalize_randomness(x): def unpack_row_uid(new_row_type, uid_field_name, drop_uid=True): - import hail.ir.ir as ir + from hail.ir import ir + new_row = ir.Ref('va', new_row_type) if uid_field_name in new_row_type.fields: uid = ir.GetField(new_row, uid_field_name) @@ -32,18 +35,19 @@ def unpack_row_uid(new_row_type, uid_field_name, drop_uid=True): def unpack_col_uid(new_col_type, uid_field_name): - import hail.ir.ir as ir + from hail.ir import ir + new_row = ir.Ref('sa', new_col_type) if uid_field_name in new_col_type.fields: uid = ir.GetField(new_row, uid_field_name) else: uid = ir.NA(tint64) - return uid, \ - ir.SelectFields(new_row, [field for field in new_col_type.fields if not field == uid_field_name]) + return uid, ir.SelectFields(new_row, [field for field in new_col_type.fields if not field == uid_field_name]) def modify_deep_field(struct, path, new_deep_field, new_struct=None): - import hail.ir.ir as ir + from hail.ir import ir + refs = [struct] for i in range(len(path)): refs.append(ir.Ref(Env.get_uid(), refs[i].typ[path[i]])) @@ -59,32 +63,42 @@ def modify_deep_field(struct, path, new_deep_field, new_struct=None): def zip_with_index(array): - import hail.ir.ir as ir + from hail.ir import ir + elt = Env.get_uid() inner_row_uid = Env.get_uid() iota = ir.StreamIota(ir.I32(0), ir.I32(1)) - return ir.toArray(ir.StreamZip( - [ir.toStream(array), iota], - [elt, inner_row_uid], - ir.MakeTuple((ir.Ref(elt, array.typ.element_type), ir.Ref(inner_row_uid, tint32))), - 'TakeMinLength')) + return ir.toArray( + ir.StreamZip( + [ir.toStream(array), iota], + [elt, inner_row_uid], + ir.MakeTuple((ir.Ref(elt, array.typ.element_type), ir.Ref(inner_row_uid, tint32))), + 'TakeMinLength', + ) + ) def zip_with_index_field(array, idx_field_name): - import hail.ir.ir as ir + from hail.ir import ir + elt = Env.get_uid() inner_row_uid = Env.get_uid() iota = ir.StreamIota(ir.I32(0), ir.I32(1)) - return ir.toArray(ir.StreamZip( - [ir.toStream(array), iota], - [elt, inner_row_uid], - ir.InsertFields(ir.Ref(elt, array.typ.element_type), [(idx_field_name, ir.Cast(ir.Ref(inner_row_uid, tint32), tint64))], None), - 'TakeMinLength')) - - -def impute_type_of_partition_interval_array( - intervals: Optional[List[Any]] -) -> Tuple[Optional[List[Any]], 
Any]: + return ir.toArray( + ir.StreamZip( + [ir.toStream(array), iota], + [elt, inner_row_uid], + ir.InsertFields( + ir.Ref(elt, array.typ.element_type), + [(idx_field_name, ir.Cast(ir.Ref(inner_row_uid, tint32), tint64))], + None, + ), + 'TakeMinLength', + ) + ) + + +def impute_type_of_partition_interval_array(intervals: Optional[List[Any]]) -> Tuple[Optional[List[Any]], Any]: if intervals is None: return None, None if len(intervals) == 0: @@ -99,10 +113,7 @@ def impute_type_of_partition_interval_array( return intervals, t struct_intervals = [ - hl.Interval(hl.Struct(__point=i.start), - hl.Struct(__point=i.end), - i.includes_start, - i.includes_end) + hl.Interval(hl.Struct(__point=i.start), hl.Struct(__point=i.end), i.includes_start, i.includes_end) for i in intervals ] struct_intervals_type = hl.tarray(hl.tinterval(hl.tstruct(__point=pt))) @@ -110,7 +121,8 @@ def impute_type_of_partition_interval_array( def filter_predicate_with_keep(ir_pred, keep): - import hail.ir.ir as ir + from hail.ir import ir + return ir.Coalesce(ir_pred if keep else ir.ApplyUnaryPrimOp('!', ir_pred), ir.FalseIR()) @@ -120,11 +132,7 @@ def make_filter_and_replace(filter, find_replace): replace = None else: find, replace = find_replace - return { - 'filterPattern': filter, - 'findPattern': find, - 'replacePattern': replace - } + return {'filterPattern': filter, 'findPattern': find, 'replacePattern': replace} def parse_type(string_expr, ttype): diff --git a/hail/python/hail/linalg/__init__.py b/hail/python/hail/linalg/__init__.py index e0343121ba2..1971aab0faa 100644 --- a/hail/python/hail/linalg/__init__.py +++ b/hail/python/hail/linalg/__init__.py @@ -1,9 +1,4 @@ -from .blockmatrix import BlockMatrix, _jarray_from_ndarray, _breeze_from_ndarray, _svd, _eigh -from . import utils as utils +from . 
import utils +from .blockmatrix import BlockMatrix, _breeze_from_ndarray, _eigh, _jarray_from_ndarray, _svd -__all__ = ['BlockMatrix', - 'utils', - '_jarray_from_ndarray', - '_breeze_from_ndarray', - '_svd', - '_eigh'] +__all__ = ['BlockMatrix', 'utils', '_jarray_from_ndarray', '_breeze_from_ndarray', '_svd', '_eigh'] diff --git a/hail/python/hail/linalg/blockmatrix.py b/hail/python/hail/linalg/blockmatrix.py index a43e6a46666..0936435173f 100644 --- a/hail/python/hail/linalg/blockmatrix.py +++ b/hail/python/hail/linalg/blockmatrix.py @@ -1,8 +1,8 @@ -import os - import itertools import math +import os import re + import numpy as np import scipy.linalg as spla @@ -10,26 +10,63 @@ import hail.expr.aggregators as agg from hail.expr import construct_expr, construct_variable from hail.expr.blockmatrix_type import tblockmatrix -from hail.expr.expressions import (expr_float64, matrix_table_source, expr_ndarray, - raise_unless_entry_indexed, expr_tuple, expr_array, expr_int32, expr_int64) -from hail.ir import (BlockMatrixWrite, BlockMatrixMap2, ApplyBinaryPrimOp, F64, - BlockMatrixBroadcast, ValueToBlockMatrix, BlockMatrixRead, - BlockMatrixMap, ApplyUnaryPrimOp, BlockMatrixDot, BlockMatrixCollect, - tensor_shape_to_matrix_shape, BlockMatrixAgg, BlockMatrixRandom, - BlockMatrixToValueApply, BlockMatrixToTable, BlockMatrixFilter, - TableFromBlockMatrixNativeReader, TableRead, BlockMatrixSlice, - BlockMatrixSparsify, BlockMatrixDensify, RectangleSparsifier, - RowIntervalSparsifier, BandSparsifier, PerBlockSparsifier) -from hail.ir.blockmatrix_reader import BlockMatrixNativeReader, BlockMatrixBinaryReader -from hail.ir.blockmatrix_writer import (BlockMatrixBinaryWriter, - BlockMatrixNativeWriter, BlockMatrixRectanglesWriter) -from hail.ir import ExportType +from hail.expr.expressions import ( + expr_array, + expr_float64, + expr_int32, + expr_int64, + expr_ndarray, + expr_tuple, + matrix_table_source, + raise_unless_entry_indexed, +) +from hail.ir import ( + F64, + ApplyBinaryPrimOp, + ApplyUnaryPrimOp, + BandSparsifier, + BlockMatrixAgg, + BlockMatrixBroadcast, + BlockMatrixCollect, + BlockMatrixDensify, + BlockMatrixDot, + BlockMatrixFilter, + BlockMatrixMap, + BlockMatrixMap2, + BlockMatrixRandom, + BlockMatrixRead, + BlockMatrixSlice, + BlockMatrixSparsify, + BlockMatrixToTable, + BlockMatrixToValueApply, + BlockMatrixWrite, + ExportType, + PerBlockSparsifier, + RectangleSparsifier, + RowIntervalSparsifier, + TableFromBlockMatrixNativeReader, + TableRead, + ValueToBlockMatrix, + tensor_shape_to_matrix_shape, +) +from hail.ir.blockmatrix_reader import BlockMatrixBinaryReader, BlockMatrixNativeReader +from hail.ir.blockmatrix_writer import BlockMatrixBinaryWriter, BlockMatrixNativeWriter, BlockMatrixRectanglesWriter from hail.table import Table -from hail.typecheck import (typecheck, typecheck_method, nullable, oneof, - sliceof, sequenceof, lazy, enumeration, numeric, tupleof, func_spec, - sized_tupleof) -from hail.utils import (new_temp_file, local_path_uri, storage_level, with_local_temp_file, - new_local_temp_file) +from hail.typecheck import ( + enumeration, + func_spec, + lazy, + nullable, + numeric, + oneof, + sequenceof, + sized_tupleof, + sliceof, + tupleof, + typecheck, + typecheck_method, +) +from hail.utils import local_path_uri, new_local_temp_file, new_temp_file, storage_level, with_local_temp_file from hail.utils.java import Env block_matrix_type = lazy() @@ -243,11 +280,7 @@ def read(cls, path, *, _assert_type=None): return cls(BlockMatrixRead(BlockMatrixNativeReader(path), 
_assert_type=_assert_type)) @classmethod - @typecheck_method(uri=str, - n_rows=int, - n_cols=int, - block_size=nullable(int), - _assert_type=nullable(tblockmatrix)) + @typecheck_method(uri=str, n_rows=int, n_cols=int, block_size=nullable(int), _assert_type=nullable(tblockmatrix)) def fromfile(cls, uri, n_rows, n_cols, block_size=None, *, _assert_type=None): """Creates a block matrix from a binary file. @@ -302,11 +335,12 @@ def fromfile(cls, uri, n_rows, n_cols, block_size=None, *, _assert_type=None): if not block_size: block_size = BlockMatrix.default_block_size() - return cls(BlockMatrixRead(BlockMatrixBinaryReader(uri, [n_rows, n_cols], block_size), _assert_type=_assert_type)) + return cls( + BlockMatrixRead(BlockMatrixBinaryReader(uri, [n_rows, n_cols], block_size), _assert_type=_assert_type) + ) @classmethod - @typecheck_method(ndarray=np.ndarray, - block_size=nullable(int)) + @typecheck_method(ndarray=np.ndarray, block_size=nullable(int)) def from_numpy(cls, ndarray, block_size=None): """Distributes a `NumPy ndarray `__ @@ -359,13 +393,17 @@ def from_numpy(cls, ndarray, block_size=None): return cls.fromfile(uri, n_rows, n_cols, block_size) @classmethod - @typecheck_method(entry_expr=expr_float64, - mean_impute=bool, - center=bool, - normalize=bool, - axis=nullable(enumeration('rows', 'cols')), - block_size=nullable(int)) - def from_entry_expr(cls, entry_expr, mean_impute=False, center=False, normalize=False, axis='rows', block_size=None): + @typecheck_method( + entry_expr=expr_float64, + mean_impute=bool, + center=bool, + normalize=bool, + axis=nullable(enumeration('rows', 'cols')), + block_size=nullable(int), + ) + def from_entry_expr( + cls, entry_expr, mean_impute=False, center=False, normalize=False, axis='rows', block_size=None + ): """Creates a block matrix using a matrix table entry expression. Examples @@ -416,16 +454,20 @@ def from_entry_expr(cls, entry_expr, mean_impute=False, center=False, normalize= Block size. Default given by :meth:`.BlockMatrix.default_block_size`. """ path = new_temp_file() - cls.write_from_entry_expr(entry_expr, path, overwrite=False, mean_impute=mean_impute, - center=center, normalize=normalize, axis=axis, block_size=block_size) + cls.write_from_entry_expr( + entry_expr, + path, + overwrite=False, + mean_impute=mean_impute, + center=center, + normalize=normalize, + axis=axis, + block_size=block_size, + ) return cls.read(path) @classmethod - @typecheck_method(n_rows=int, - n_cols=int, - block_size=nullable(int), - seed=nullable(int), - gaussian=bool) + @typecheck_method(n_rows=int, n_cols=int, block_size=nullable(int), seed=nullable(int), gaussian=bool) def random(cls, n_rows, n_cols, block_size=None, seed=None, gaussian=True) -> 'BlockMatrix': """Creates a block matrix with standard normal or uniform random entries. @@ -463,10 +505,7 @@ def random(cls, n_rows, n_cols, block_size=None, seed=None, gaussian=True) -> 'B return BlockMatrix(rand) @classmethod - @typecheck_method(n_rows=int, - n_cols=int, - value=numeric, - block_size=nullable(int)) + @typecheck_method(n_rows=int, n_cols=int, value=numeric, block_size=nullable(int)) def fill(cls, n_rows, n_cols, value, block_size=None): """Creates a block matrix with all elements the same value. 
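The hunks above only reflow the @typecheck decorators and call sites for the BlockMatrix constructors (fromfile, from_numpy, from_entry_expr, random, fill); behaviour is unchanged. As a quick orientation for the constructors they touch, a minimal usage sketch, assuming a local Hail session and NumPy (shapes and values are arbitrary):

    import numpy as np
    import hail as hl
    from hail.linalg import BlockMatrix

    hl.init()  # local backend; assumed for this sketch

    # fill: every entry takes the same value
    ones = BlockMatrix.fill(4, 3, 1.0)

    # from_numpy: distribute an existing float64 ndarray
    nd = np.arange(12, dtype=np.float64).reshape(4, 3)
    bm = BlockMatrix.from_numpy(nd, block_size=2)

    assert bm.shape == (4, 3)
    assert np.allclose(bm.to_numpy(), nd)
    assert np.allclose(ones.to_numpy(), np.ones((4, 3)))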
@@ -494,16 +533,11 @@ def fill(cls, n_rows, n_cols, value, block_size=None): if not block_size: block_size = BlockMatrix.default_block_size() - bmir = BlockMatrixBroadcast(_to_bmir(value, block_size), - [], [n_rows, n_cols], - block_size) + bmir = BlockMatrixBroadcast(_to_bmir(value, block_size), [], [n_rows, n_cols], block_size) return BlockMatrix(bmir) @classmethod - @typecheck_method(n_rows=int, - n_cols=int, - data=sequenceof(float), - block_size=nullable(int)) + @typecheck_method(n_rows=int, n_cols=int, data=sequenceof(float), block_size=nullable(int)) def _create(cls, n_rows, n_cols, data, block_size=None): """Private method for creating small test matrices.""" @@ -596,10 +630,7 @@ def _last_row_block_height(self): remainder = self.n_rows % self.block_size return remainder if remainder != 0 else self.block_size - @typecheck_method(path=str, - overwrite=bool, - force_row_major=bool, - stage_locally=bool) + @typecheck_method(path=str, overwrite=bool, force_row_major=bool, stage_locally=bool) def write(self, path, overwrite=False, force_row_major=False, stage_locally=False): """Writes the block matrix. @@ -624,10 +655,7 @@ def write(self, path, overwrite=False, force_row_major=False, stage_locally=Fals writer = BlockMatrixNativeWriter(path, overwrite, force_row_major, stage_locally) Env.backend().execute(BlockMatrixWrite(self._bmir, writer)) - @typecheck_method(path=str, - overwrite=bool, - force_row_major=bool, - stage_locally=bool) + @typecheck_method(path=str, overwrite=bool, force_row_major=bool, stage_locally=bool) def checkpoint(self, path, overwrite=False, force_row_major=False, stage_locally=False): """Checkpoint the block matrix. @@ -652,16 +680,26 @@ def checkpoint(self, path, overwrite=False, force_row_major=False, stage_locally return BlockMatrix.read(path, _assert_type=self._bmir._type) @staticmethod - @typecheck(entry_expr=expr_float64, - path=str, - overwrite=bool, - mean_impute=bool, - center=bool, - normalize=bool, - axis=nullable(enumeration('rows', 'cols')), - block_size=nullable(int)) - def write_from_entry_expr(entry_expr, path, overwrite=False, mean_impute=False, - center=False, normalize=False, axis='rows', block_size=None): + @typecheck( + entry_expr=expr_float64, + path=str, + overwrite=bool, + mean_impute=bool, + center=bool, + normalize=bool, + axis=nullable(enumeration('rows', 'cols')), + block_size=nullable(int), + ) + def write_from_entry_expr( + entry_expr, + path, + overwrite=False, + mean_impute=False, + center=False, + normalize=False, + axis='rows', + block_size=None, + ): """Writes a block matrix from a matrix table entry expression. 
Examples @@ -749,7 +787,7 @@ def write_from_entry_expr(entry_expr, path, overwrite=False, mean_impute=False, compute = { '__count': agg.count_where(hl.is_defined(mt['__x'])), '__sum': agg.sum(mt['__x']), - '__sum_sq': agg.sum(mt['__x'] * mt['__x']) + '__sum_sq': agg.sum(mt['__x'] * mt['__x']), } if axis == 'rows': n_elements = mt.count_cols() @@ -759,11 +797,10 @@ def write_from_entry_expr(entry_expr, path, overwrite=False, mean_impute=False, mt = mt.select_cols(**compute) compute = { '__mean': mt['__sum'] / mt['__count'], - '__centered_length': hl.sqrt(mt['__sum_sq'] - - (mt['__sum'] ** 2) / mt['__count']), - '__length': hl.sqrt(mt['__sum_sq'] - + (n_elements - mt['__count']) - * ((mt['__sum'] / mt['__count']) ** 2)) + '__centered_length': hl.sqrt(mt['__sum_sq'] - (mt['__sum'] ** 2) / mt['__count']), + '__length': hl.sqrt( + mt['__sum_sq'] + (n_elements - mt['__count']) * ((mt['__sum'] / mt['__count']) ** 2) + ), } if axis == 'rows': mt = mt.select_rows(**compute) @@ -779,14 +816,12 @@ def write_from_entry_expr(entry_expr, path, overwrite=False, mean_impute=False, if mean_impute: expr = hl.or_else(expr, mt['__mean']) expr = expr / mt['__length'] - else: - if center: - expr = expr - mt['__mean'] - if mean_impute: - expr = hl.or_else(expr, 0.0) - else: - if mean_impute: - expr = hl.or_else(expr, mt['__mean']) + elif center: + expr = expr - mt['__mean'] + if mean_impute: + expr = hl.or_else(expr, 0.0) + elif mean_impute: + expr = hl.or_else(expr, mt['__mean']) field = Env.get_uid() mt.select_entries(**{field: expr})._write_block_matrix(path, overwrite, field, block_size) @@ -834,8 +869,7 @@ def filter_cols(self, cols_to_keep): BlockMatrix._check_indices(cols_to_keep, self.n_cols) return BlockMatrix(BlockMatrixFilter(self._bmir, [[], cols_to_keep])) - @typecheck_method(rows_to_keep=sequenceof(int), - cols_to_keep=sequenceof(int)) + @typecheck_method(rows_to_keep=sequenceof(int), cols_to_keep=sequenceof(int)) def filter(self, rows_to_keep, cols_to_keep): """Filters matrix rows and columns. 
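The write_from_entry_expr and filter hunks above are formatting- and lint-driven reflows (the nested if/else becomes elif); behaviour is unchanged. For reference, a short sketch of the positional row/column filtering API they touch; indices are 0-based and ascending, and the data below is an arbitrary assumption:

    import numpy as np
    from hail.linalg import BlockMatrix

    nd = np.arange(16, dtype=np.float64).reshape(4, 4)
    bm = BlockMatrix.from_numpy(nd)

    sub = bm.filter([0, 2], [1, 3])     # keep rows 0 and 2, columns 1 and 3
    rows_only = bm.filter_rows([1, 2])  # keep rows 1 and 2, all columns

    assert np.allclose(sub.to_numpy(), nd[[0, 2]][:, [1, 3]])
    assert np.allclose(rows_only.to_numpy(), nd[[1, 2], :])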
@@ -899,8 +933,7 @@ def __getitem__(self, indices): i = BlockMatrix._pos_index(row_idx, self.n_rows, 'row index') j = BlockMatrix._pos_index(col_idx, self.n_cols, 'col index') - return Env.backend().execute(BlockMatrixToValueApply(self._bmir, - {'name': 'GetElement', 'index': [i, j]})) + return Env.backend().execute(BlockMatrixToValueApply(self._bmir, {'name': 'GetElement', 'index': [i, j]})) rows_to_keep = BlockMatrix._range_to_keep(row_idx, self.n_rows) cols_to_keep = BlockMatrix._range_to_keep(col_idx, self.n_cols) @@ -1046,22 +1079,17 @@ def sparsify_triangle(self, lower=False, blocks_only=False): return self.sparsify_band(lower_band, upper_band, blocks_only) - @typecheck_method(intervals=expr_tuple([expr_array(expr_int64), expr_array(expr_int64)]), - blocks_only=bool) + @typecheck_method(intervals=expr_tuple([expr_array(expr_int64), expr_array(expr_int64)]), blocks_only=bool) def _sparsify_row_intervals_expr(self, intervals, blocks_only=False): - return BlockMatrix( - BlockMatrixSparsify(self._bmir, intervals._ir, - RowIntervalSparsifier(blocks_only))) + return BlockMatrix(BlockMatrixSparsify(self._bmir, intervals._ir, RowIntervalSparsifier(blocks_only))) @typecheck_method(indices=expr_array(expr_int32)) def _sparsify_blocks(self, indices): - return BlockMatrix( - BlockMatrixSparsify(self._bmir, indices._ir, - PerBlockSparsifier())) + return BlockMatrix(BlockMatrixSparsify(self._bmir, indices._ir, PerBlockSparsifier())) - @typecheck_method(starts=oneof(sequenceof(int), np.ndarray), - stops=oneof(sequenceof(int), np.ndarray), - blocks_only=bool) + @typecheck_method( + starts=oneof(sequenceof(int), np.ndarray), stops=oneof(sequenceof(int), np.ndarray), blocks_only=bool + ) def sparsify_row_intervals(self, starts, stops, blocks_only=False): """Creates a block-sparse matrix by filtering to an interval for each row. @@ -1129,11 +1157,11 @@ def sparsify_row_intervals(self, starts, stops, blocks_only=False): Sparse block matrix. 
""" if isinstance(starts, np.ndarray): - if not (starts.dtype == np.int32 or starts.dtype == np.int64): + if starts.dtype not in (np.int32, np.int64): raise ValueError("sparsify_row_intervals: starts ndarray must have dtype 'int32' or 'int64'") starts = [int(s) for s in starts] if isinstance(stops, np.ndarray): - if not (stops.dtype == np.int32 or stops.dtype == np.int64): + if stops.dtype not in (np.int32, np.int64): raise ValueError("sparsify_row_intervals: stops ndarray must have dtype 'int32' or 'int64'") stops = [int(s) for s in stops] @@ -1228,9 +1256,10 @@ def to_numpy(self, _force_blocking=False): if isinstance(hl.current_backend(), ServiceBackend): with hl.TemporaryFilename() as path: self.tofile(path) - return np.frombuffer( - hl.current_backend().fs.open(path, mode='rb').read() - ).reshape((self.n_rows, self.n_cols)) + return np.frombuffer(hl.current_backend().fs.open(path, mode='rb').read()).reshape(( + self.n_rows, + self.n_cols, + )) with with_local_temp_file() as path: uri = local_path_uri(path) @@ -1377,10 +1406,12 @@ def _apply_map(self, f, needs_dense): bmir = BlockMatrixDensify(bmir) return BlockMatrix(BlockMatrixMap(bmir, uid, f(construct_variable(uid, hl.tfloat64))._ir, needs_dense)) - @typecheck_method(f=func_spec(2, expr_float64), - other=oneof(numeric, np.ndarray, block_matrix_type), - sparsity_strategy=str, - reverse=bool) + @typecheck_method( + f=func_spec(2, expr_float64), + other=oneof(numeric, np.ndarray, block_matrix_type), + sparsity_strategy=str, + reverse=bool, + ) def _apply_map2(self, f, other, sparsity_strategy, reverse=False): if not isinstance(other, BlockMatrix): other = BlockMatrix(_to_bmir(other, self.block_size)) @@ -1489,10 +1520,14 @@ def _select_blocks(self, block_row_range, block_col_range): start_bcol, stop_bcol = block_col_range start_row = start_brow * self.block_size - stop_row = (stop_brow - 1) * self.block_size + (self._last_row_block_height if stop_brow == self._n_block_rows else self.block_size) + stop_row = (stop_brow - 1) * self.block_size + ( + self._last_row_block_height if stop_brow == self._n_block_rows else self.block_size + ) start_col = start_bcol * self.block_size - stop_col = (stop_bcol - 1) * self.block_size + (self._last_col_block_width if stop_bcol == self._n_block_cols else self.block_size) + stop_col = (stop_bcol - 1) * self.block_size + ( + self._last_col_block_width if stop_bcol == self._n_block_cols else self.block_size + ) return self[start_row:stop_row, start_col:stop_col] @@ -1547,15 +1582,22 @@ def tree_matmul(self, b, *, splits, path_prefix=None): if splits != 1: inner_brange_size = int(math.ceil(self._n_block_cols / splits)) - split_points = list(range(0, self._n_block_cols, inner_brange_size)) + [self._n_block_cols] + split_points = [*list(range(0, self._n_block_cols, inner_brange_size)), self._n_block_cols] inner_ranges = list(zip(split_points[:-1], split_points[1:])) - blocks_to_multiply = [(self._select_blocks((0, self._n_block_rows), (start, stop)), - b._select_blocks((start, stop), (0, b._n_block_cols))) for start, stop in inner_ranges] + blocks_to_multiply = [ + ( + self._select_blocks((0, self._n_block_rows), (start, stop)), + b._select_blocks((start, stop), (0, b._n_block_cols)), + ) + for start, stop in inner_ranges + ] intermediate_multiply_exprs = [b1 @ b2 for b1, b2 in blocks_to_multiply] hl.experimental.write_block_matrices(intermediate_multiply_exprs, path_prefix) - read_intermediates = [BlockMatrix.read(f"{path_prefix}_{i}") for i in range(0, len(intermediate_multiply_exprs))] + 
read_intermediates = [ + BlockMatrix.read(f"{path_prefix}_{i}") for i in range(0, len(intermediate_multiply_exprs)) + ] return sum(read_intermediates) @@ -1574,7 +1616,7 @@ def __pow__(self, x): ------- :class:`.BlockMatrix` """ - return self._apply_map(lambda i: i ** x, needs_dense=False) + return self._apply_map(lambda i: i**x, needs_dense=False) def _map_dense(self, func): return self._apply_map(func, True) @@ -1634,10 +1676,7 @@ def diagonal(self): ------- :class:`.BlockMatrix` """ - diag_bmir = BlockMatrixBroadcast(self._bmir, - [0, 0], - [1, min(self.n_rows, self.n_cols)], - self.block_size) + diag_bmir = BlockMatrixBroadcast(self._bmir, [0, 0], [1, min(self.n_rows, self.n_cols)], self.block_size) return BlockMatrix(diag_bmir) @typecheck_method(axis=nullable(int)) @@ -1678,7 +1717,7 @@ def sum(self, axis=None): if axis is None: bmir = BlockMatrixAgg(self._bmir, [0, 1]) return BlockMatrix(bmir)[0, 0] - elif axis == 0 or axis == 1: + elif axis in {0, 1}: out_index_expr = [axis] bmir = BlockMatrixAgg(self._bmir, out_index_expr) @@ -1781,7 +1820,8 @@ def to_table_row_major(self, n_partitions=None, maximum_cache_memory_in_bytes=No path = new_temp_file() if maximum_cache_memory_in_bytes and maximum_cache_memory_in_bytes > (1 << 31) - 1: raise ValueError( - f'maximum_cache_memory_in_bytes must be less than 2^31 -1, was: {maximum_cache_memory_in_bytes}') + f'maximum_cache_memory_in_bytes must be less than 2^31 -1, was: {maximum_cache_memory_in_bytes}' + ) self.write(path, overwrite=True, force_row_major=True) reader = TableFromBlockMatrixNativeReader(path, n_partitions, maximum_cache_memory_in_bytes) @@ -1820,16 +1860,26 @@ def to_matrix_table_row_major(self, n_partitions=None, maximum_cache_memory_in_b return t._unlocalize_entries('entries', 'cols', ['col_idx']) @staticmethod - @typecheck(path_in=str, - path_out=str, - delimiter=str, - header=nullable(str), - add_index=bool, - parallel=nullable(ExportType.checker), - partition_size=nullable(int), - entries=enumeration('full', 'lower', 'strict_lower', 'upper', 'strict_upper')) - def export(path_in, path_out, delimiter='\t', header=None, add_index=False, parallel=None, - partition_size=None, entries='full'): + @typecheck( + path_in=str, + path_out=str, + delimiter=str, + header=nullable(str), + add_index=bool, + parallel=nullable(ExportType.checker), + partition_size=nullable(int), + entries=enumeration('full', 'lower', 'strict_lower', 'upper', 'strict_upper'), + ) + def export( + path_in, + path_out, + delimiter='\t', + header=None, + add_index=False, + parallel=None, + partition_size=None, + entries='full', + ): """Exports a stored block matrix as a delimited text file. 
Examples @@ -1980,7 +2030,8 @@ def export(path_in, path_out, delimiter='\t', header=None, add_index=False, para export_type = ExportType.default(parallel) Env.spark_backend('BlockMatrix.export')._jbackend.pyExportBlockMatrix( - path_in, path_out, delimiter, header, add_index, export_type, partition_size, entries) + path_in, path_out, delimiter, header, add_index, export_type, partition_size, entries + ) @typecheck_method(rectangles=sequenceof(sequenceof(int))) def sparsify_rectangles(self, rectangles): @@ -2044,17 +2095,14 @@ def sparsify_rectangles(self, rectangles): if len(r) != 4: raise ValueError(f'rectangle {r} does not have length 4') if not (0 <= r[0] <= r[1] <= n_rows and 0 <= r[2] <= r[3] <= n_cols): - raise ValueError(f'rectangle {r} does not satisfy ' - f'0 <= r[0] <= r[1] <= n_rows and 0 <= r[2] <= r[3] <= n_cols') + raise ValueError( + f'rectangle {r} does not satisfy ' f'0 <= r[0] <= r[1] <= n_rows and 0 <= r[2] <= r[3] <= n_cols' + ) rectangles = hl.literal(list(itertools.chain(*rectangles)), hl.tarray(hl.tint64)) - return BlockMatrix( - BlockMatrixSparsify(self._bmir, rectangles._ir, RectangleSparsifier)) + return BlockMatrix(BlockMatrixSparsify(self._bmir, rectangles._ir, RectangleSparsifier)) - @typecheck_method(path_out=str, - rectangles=sequenceof(sequenceof(int)), - delimiter=str, - binary=bool) + @typecheck_method(path_out=str, rectangles=sequenceof(sequenceof(int)), delimiter=str, binary=bool) def export_rectangles(self, path_out, rectangles, delimiter='\t', binary=False): """Export rectangular regions from a block matrix to delimited text or binary files. @@ -2152,8 +2200,9 @@ def export_rectangles(self, path_out, rectangles, delimiter='\t', binary=False): if len(r) != 4: raise ValueError(f'rectangle {r} does not have length 4') if not (0 <= r[0] <= r[1] <= self.n_rows and 0 <= r[2] <= r[3] <= self.n_cols): - raise ValueError(f'rectangle {r} does not satisfy ' - f'0 <= r[0] <= r[1] <= n_rows and 0 <= r[2] <= r[3] <= n_cols') + raise ValueError( + f'rectangle {r} does not satisfy ' f'0 <= r[0] <= r[1] <= n_rows and 0 <= r[2] <= r[3] <= n_cols' + ) writer = BlockMatrixRectanglesWriter(path_out, rectangles, delimiter, binary) Env.backend().execute(BlockMatrixWrite(self._bmir, writer)) @@ -2222,6 +2271,7 @@ def export_blocks(self, path_out, delimiter='\t', binary=False): binary: :obj:`bool` If true, export elements as raw bytes in row major order. """ + def rows_in_block(block_row): if block_row == self._n_block_rows - 1: return self.n_rows - block_row * self.block_size @@ -2294,6 +2344,7 @@ def rectangles_to_numpy(path, binary=False): ------- :class:`numpy.ndarray` """ + def parse_rects(fname): rect_idx_and_bounds = [int(i) for i in re.findall(r'\d+', fname)] if len(rect_idx_and_bounds) != 5: @@ -2315,11 +2366,10 @@ def parse_rects(fname): rect_data = np.reshape(np.fromfile(f), (rect[2] - rect[1], rect[4] - rect[3])) else: rect_data = np.loadtxt(f, ndmin=2) - nd[rect[1]:rect[2], rect[3]:rect[4]] = rect_data + nd[rect[1] : rect[2], rect[3] : rect[4]] = rect_data return nd - @typecheck_method(compute_uv=bool, - complexity_bound=int) + @typecheck_method(compute_uv=bool, complexity_bound=int) def svd(self, compute_uv=True, complexity_bound=8192): r"""Computes the reduced singular value decomposition. 
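For orientation on the svd hunk that follows: the reformatted guard n * m * min(n, m) <= complexity_bound**3 selects a local LAPACK SVD for small matrices and falls back to the distributed Gramian path (_svd_gramian) otherwise. A minimal sketch of the small-matrix case; the shape, seed, and block size are arbitrary assumptions:

    import numpy as np
    from hail.linalg import BlockMatrix

    nd = np.random.default_rng(0).normal(size=(50, 10))
    bm = BlockMatrix.from_numpy(nd, block_size=16)

    # 50 * 10 * 10 is far below 8192**3, so this takes the local LAPACK path
    u, s, vt = bm.svd()
    assert np.allclose(u @ np.diag(s) @ vt, nd)

    s_only = bm.svd(compute_uv=False)  # singular values only

On the large-matrix path one of the factors may come back as a block matrix rather than an ndarray; see the _svd_gramian hunk below.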
@@ -2453,7 +2503,7 @@ def svd(self, compute_uv=True, complexity_bound=8192): """ n, m = self.shape - if n * m * min(n, m) <= complexity_bound ** 3: + if n * m * min(n, m) <= complexity_bound**3: return _svd(self.to_numpy(), full_matrices=False, compute_uv=compute_uv, overwrite_a=True) else: return self._svd_gramian(compute_uv) @@ -2467,9 +2517,7 @@ def _svd_gramian(self, compute_uv): raise ValueError(f'svd: dimensions {n} and {m} both exceed 46300') left_gramian = n <= m - a = ((x @ x.T if left_gramian else x.T @ x) - .sparsify_triangle(lower=True, blocks_only=True) - .to_numpy()) + a = (x @ x.T if left_gramian else x.T @ x).sparsify_triangle(lower=True, blocks_only=True).to_numpy() if compute_uv: e, w = _eigh(a) @@ -2499,7 +2547,7 @@ def _svd_gramian(self, compute_uv): def _is_scalar(x): - return isinstance(x, float) or isinstance(x, int) + return isinstance(x, (int, float)) def _shape_after_broadcast(left, right): @@ -2508,8 +2556,9 @@ def _shape_after_broadcast(left, right): compare corresponding dimensions. See: https://docs.scipy.org/doc/numpy-1.15.0/user/basics.broadcasting.html#general-broadcasting-rules """ + def join_dim(l_size, r_size): - if not (l_size == r_size or l_size == 1 or r_size == 1): + if not (l_size == r_size or l_size == 1 or r_size == 1): # noqa: PLR1714 raise ValueError(f'Incompatible shapes for broadcasting: {left}, {right}') return max(l_size, r_size) @@ -2574,7 +2623,9 @@ def _jarray_from_ndarray(nd): with with_local_temp_file() as path: uri = local_path_uri(path) nd.tofile(path) - return Env.hail().utils.richUtils.RichArray.importFromDoubles(Env.spark_backend('_jarray_from_ndarray').fs._jfs, uri, nd.size) + return Env.hail().utils.richUtils.RichArray.importFromDoubles( + Env.spark_backend('_jarray_from_ndarray').fs._jfs, uri, nd.size + ) def _ndarray_from_jarray(ja): @@ -2587,7 +2638,9 @@ def _ndarray_from_jarray(ja): def _breeze_fromfile(uri, n_rows, n_cols): _check_entries_size(n_rows, n_cols) - return Env.hail().utils.richUtils.RichDenseMatrixDouble.importFromDoubles(Env.spark_backend('_breeze_fromfile').fs._jfs, uri, n_rows, n_cols, True) + return Env.hail().utils.richUtils.RichDenseMatrixDouble.importFromDoubles( + Env.spark_backend('_breeze_fromfile').fs._jfs, uri, n_rows, n_cols, True + ) def _check_entries_size(n_rows, n_cols): @@ -2618,12 +2671,24 @@ def _svd(a, full_matrices=True, compute_uv=True, overwrite_a=False, check_finite DC (gesdd) is faster but uses O(elements) memory; lwork may overflow int32 """ try: - return spla.svd(a, full_matrices=full_matrices, compute_uv=compute_uv, overwrite_a=overwrite_a, - check_finite=check_finite, lapack_driver='gesdd') + return spla.svd( + a, + full_matrices=full_matrices, + compute_uv=compute_uv, + overwrite_a=overwrite_a, + check_finite=check_finite, + lapack_driver='gesdd', + ) except ValueError as e: if 'Too large work array required' in str(e): - return spla.svd(a, full_matrices=full_matrices, compute_uv=compute_uv, overwrite_a=overwrite_a, - check_finite=check_finite, lapack_driver='gesvd') + return spla.svd( + a, + full_matrices=full_matrices, + compute_uv=compute_uv, + overwrite_a=overwrite_a, + check_finite=check_finite, + lapack_driver='gesvd', + ) else: raise diff --git a/hail/python/hail/linalg/utils/__init__.py b/hail/python/hail/linalg/utils/__init__.py index 82fb11b7637..b4f02816161 100644 --- a/hail/python/hail/linalg/utils/__init__.py +++ b/hail/python/hail/linalg/utils/__init__.py @@ -1,5 +1,3 @@ -from .misc import array_windows, locus_windows, _check_dims +from .misc import _check_dims, 
array_windows, locus_windows -__all__ = ['array_windows', - 'locus_windows', - '_check_dims'] +__all__ = ['array_windows', 'locus_windows', '_check_dims'] diff --git a/hail/python/hail/linalg/utils/misc.py b/hail/python/hail/linalg/utils/misc.py index 3afc343dfd9..c46d40ee0d7 100644 --- a/hail/python/hail/linalg/utils/misc.py +++ b/hail/python/hail/linalg/utils/misc.py @@ -1,13 +1,12 @@ import numpy as np import hail as hl -from hail.typecheck import typecheck, oneof, nullable -from hail.expr.expressions import expr_locus, expr_float64, raise_unless_row_indexed +from hail.expr.expressions import expr_float64, expr_locus, raise_unless_row_indexed +from hail.typecheck import nullable, oneof, typecheck from hail.utils.java import Env -@typecheck(a=np.ndarray, - radius=oneof(int, float)) +@typecheck(a=np.ndarray, radius=oneof(int, float)) def array_windows(a, radius): """Returns start and stop indices for window around each array value. @@ -48,8 +47,9 @@ def array_windows(a, radius): if a.ndim != 1: raise ValueError("array_windows: 'a' must be 1-dimensional") if not (np.issubdtype(a.dtype, np.signedinteger) or np.issubdtype(a.dtype, np.floating)): - raise ValueError(f"array_windows: 'a' must be an ndarray of signed integer or float values, " - f"found dtype {str(a.dtype)}") + raise ValueError( + f"array_windows: 'a' must be an ndarray of signed integer or float values, " f"found dtype {a.dtype!s}" + ) size = a.size if size == 0: @@ -78,10 +78,7 @@ def array_windows(a, radius): return starts, stops -@typecheck(locus_expr=expr_locus(), - radius=oneof(int, float), - coord_expr=nullable(expr_float64), - _localize=bool) +@typecheck(locus_expr=expr_locus(), radius=oneof(int, float), coord_expr=nullable(expr_float64), _localize=bool) def locus_windows(locus_expr, radius, coord_expr=None, _localize=True): """Returns start and stop indices for window around each locus. @@ -198,16 +195,27 @@ def locus_windows(locus_expr, radius, coord_expr=None, _localize=True): contig_group_expr = hl.agg.group_by(hl.locus(locus_expr.contig, 1, reference_genome=rg), hl.agg.collect(coord_expr)) # check loci are in sorted order - last_pos = hl.fold(lambda a, elt: (hl.case() - .when(a <= elt, elt) - .or_error(hl.str("locus_windows: 'locus_expr' global position must be in ascending order. ") + hl.str(a) + hl.str(" was not less then or equal to ") + hl.str(elt))), - -1, - hl.agg.collect(hl.case() - .when(hl.is_defined(locus_expr), locus_expr.global_position()) - .or_error("locus_windows: missing value for 'locus_expr'."))) - checked_contig_groups = (hl.case() - .when(last_pos >= 0, contig_group_expr) - .or_error("locus_windows: 'locus_expr' has length 0")) + last_pos = hl.fold( + lambda a, elt: ( + hl.case() + .when(a <= elt, elt) + .or_error( + hl.str("locus_windows: 'locus_expr' global position must be in ascending order. 
") + + hl.str(a) + + hl.str(" was not less then or equal to ") + + hl.str(elt) + ) + ), + -1, + hl.agg.collect( + hl.case() + .when(hl.is_defined(locus_expr), locus_expr.global_position()) + .or_error("locus_windows: missing value for 'locus_expr'.") + ), + ) + checked_contig_groups = ( + hl.case().when(last_pos >= 0, contig_group_expr).or_error("locus_windows: 'locus_expr' has length 0") + ) contig_groups = locus_expr._aggregation_method()(checked_contig_groups, _localize=False) @@ -223,12 +231,10 @@ def locus_windows(locus_expr, radius, coord_expr=None, _localize=True): def _check_dims(a, name, ndim, min_size=1): if len(a.shape) != ndim: - raise ValueError(f'{name} must be {ndim}-dimensional, ' - f'found {a.ndim}') + raise ValueError(f'{name} must be {ndim}-dimensional, ' f'found {a.ndim}') for i in range(ndim): if a.shape[i] < min_size: - raise ValueError(f'{name}.shape[{i}] must be at least ' - f'{min_size}, found {a.shape[i]}') + raise ValueError(f'{name}.shape[{i}] must be at least ' f'{min_size}, found {a.shape[i]}') def _ndarray_matmul_ndim(left, right): diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index 9cf2455b072..d04e52e4ba5 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -1,38 +1,63 @@ import itertools -from typing import Iterable, Optional, Dict, Tuple, Any, List +import warnings from collections import Counter +from typing import Any, Dict, Iterable, List, Optional, Tuple + +from deprecated import deprecated + import hail as hl -from hail.expr.expressions import Expression, StructExpression, \ - expr_struct, expr_any, expr_bool, analyze, Indices, \ - construct_reference, construct_expr, extract_refs_by_indices, \ - ExpressionException, TupleExpression, unify_all -from hail.expr.types import types_match, tarray, tset +from hail import ir +from hail.expr.expressions import ( + Expression, + ExpressionException, + Indices, + StructExpression, + TupleExpression, + analyze, + construct_expr, + construct_reference, + expr_any, + expr_bool, + expr_struct, + extract_refs_by_indices, + unify_all, +) from hail.expr.matrix_type import tmatrix -import hail.ir as ir -from hail.table import Table, ExprContainer, TableIndexKeyError -from hail.typecheck import typecheck, typecheck_method, dictof, anytype, \ - anyfunc, nullable, sequenceof, oneof, numeric, lazy, enumeration -from hail.utils import storage_level, default_handler, deduplicate -from hail.utils.java import warning, Env, info -from hail.utils.misc import wrap_to_tuple, \ - get_key_by_exprs, \ - get_select_exprs, check_annotate_exprs, process_joins -import warnings +from hail.expr.types import tarray, tset, types_match +from hail.table import ExprContainer, Table, TableIndexKeyError +from hail.typecheck import ( + anyfunc, + anytype, + dictof, + enumeration, + lazy, + nullable, + numeric, + oneof, + sequenceof, + typecheck, + typecheck_method, +) +from hail.utils import deduplicate, default_handler, storage_level +from hail.utils.java import Env, info, warning +from hail.utils.misc import check_annotate_exprs, get_key_by_exprs, get_select_exprs, process_joins, wrap_to_tuple class GroupedMatrixTable(ExprContainer): """Matrix table grouped by row or column that can be aggregated into a new matrix table.""" - def __init__(self, - parent: 'MatrixTable', - row_keys=None, - computed_row_key=None, - col_keys=None, - computed_col_key=None, - entry_fields=None, - row_fields=None, - col_fields=None, - partitions=None): + def __init__( + self, + parent: 'MatrixTable', + 
row_keys=None, + computed_row_key=None, + col_keys=None, + computed_col_key=None, + entry_fields=None, + row_fields=None, + col_fields=None, + partitions=None, + ): super(GroupedMatrixTable, self).__init__() self._parent = parent self._copy_fields_from(parent) @@ -45,15 +70,18 @@ def __init__(self, self._col_fields = col_fields self._partitions = partitions - def _copy(self, *, - row_keys=None, - computed_row_key=None, - col_keys=None, - computed_col_key=None, - entry_fields=None, - row_fields=None, - col_fields=None, - partitions=None): + def _copy( + self, + *, + row_keys=None, + computed_row_key=None, + col_keys=None, + computed_col_key=None, + entry_fields=None, + row_fields=None, + col_fields=None, + partitions=None, + ): return GroupedMatrixTable( parent=self._parent, row_keys=row_keys if row_keys is not None else self._row_keys, @@ -63,7 +91,7 @@ def _copy(self, *, entry_fields=entry_fields if entry_fields is not None else self._entry_fields, row_fields=row_fields if row_fields is not None else self._row_fields, col_fields=col_fields if col_fields is not None else self._col_fields, - partitions=partitions if partitions is not None else self._partitions + partitions=partitions if partitions is not None else self._partitions, ) def _fixed_indices(self): @@ -92,16 +120,17 @@ def describe(self, handler=print): else: colstr = "\nColumns: \n" + "\n ".join(["{}: {}".format(k, v) for k, v in self._col_keys.items()]) - s = (f'----------------------------------------\n' - f'GroupedMatrixTable grouped by {rowstr}{colstr}\n' - f'----------------------------------------\n' - f'Parent MatrixTable:\n') + s = ( + f'----------------------------------------\n' + f'GroupedMatrixTable grouped by {rowstr}{colstr}\n' + f'----------------------------------------\n' + f'Parent MatrixTable:\n' + ) handler(s) self._parent.describe(handler) - @typecheck_method(exprs=oneof(str, Expression), - named_exprs=expr_any) + @typecheck_method(exprs=oneof(str, Expression), named_exprs=expr_any) def group_rows_by(self, *exprs, **named_exprs) -> 'GroupedMatrixTable': """Group rows. @@ -135,18 +164,18 @@ def group_rows_by(self, *exprs, **named_exprs) -> 'GroupedMatrixTable': raise NotImplementedError("GroupedMatrixTable is already grouped by cols; cannot also group by rows.") caller = 'group_rows_by' - row_key, computed_key = get_key_by_exprs(caller, - exprs, - named_exprs, - self._parent._row_indices, - override_protected_indices={self._parent._global_indices, - self._parent._col_indices}) + row_key, computed_key = get_key_by_exprs( + caller, + exprs, + named_exprs, + self._parent._row_indices, + override_protected_indices={self._parent._global_indices, self._parent._col_indices}, + ) self._check_bindings(caller, computed_key, self._parent._row_indices) return self._copy(row_keys=row_key, computed_row_key=computed_key) - @typecheck_method(exprs=oneof(str, Expression), - named_exprs=expr_any) + @typecheck_method(exprs=oneof(str, Expression), named_exprs=expr_any) def group_cols_by(self, *exprs, **named_exprs) -> 'GroupedMatrixTable': """Group columns. 
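The GroupedMatrixTable hunks above only reflow constructor argument lists, the describe f-string, and the get_key_by_exprs call; behaviour is unchanged. As context for group_rows_by/group_cols_by, a toy end-to-end sketch; the field names bucket, x, and mean_x are illustrative assumptions, not part of this diff:

    import hail as hl

    mt = hl.utils.range_matrix_table(n_rows=10, n_cols=4)
    mt = mt.annotate_rows(bucket=mt.row_idx % 2)
    mt = mt.annotate_entries(x=hl.float64(mt.row_idx * mt.col_idx))

    # group rows on a row field and aggregate entries per (bucket, column)
    result = mt.group_rows_by(mt.bucket).aggregate(mean_x=hl.agg.mean(mt.x))
    result.entries().show()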
@@ -180,12 +209,13 @@ def group_cols_by(self, *exprs, **named_exprs) -> 'GroupedMatrixTable': raise NotImplementedError("GroupedMatrixTable is already grouped by cols.") caller = 'group_cols_by' - col_key, computed_key = get_key_by_exprs(caller, - exprs, - named_exprs, - self._parent._col_indices, - override_protected_indices={self._parent._global_indices, - self._parent._row_indices}) + col_key, computed_key = get_key_by_exprs( + caller, + exprs, + named_exprs, + self._parent._col_indices, + override_protected_indices={self._parent._global_indices, self._parent._row_indices}, + ) self._check_bindings(caller, computed_key, self._parent._col_indices) return self._copy(col_keys=col_key, computed_col_key=computed_key) @@ -202,13 +232,16 @@ def iter_option(o): assert indices == self._parent._col_indices fixed_fields = [*self._parent.globals, *self._parent.row] - bound_fields = set(itertools.chain( - iter_option(self._row_keys), - iter_option(self._col_keys), - iter_option(self._col_fields), - iter_option(self._row_fields), - iter_option(self._entry_fields), - fixed_fields)) + bound_fields = set( + itertools.chain( + iter_option(self._row_keys), + iter_option(self._col_keys), + iter_option(self._col_fields), + iter_option(self._row_fields), + iter_option(self._entry_fields), + fixed_fields, + ) + ) for k in new_bindings: if k in bound_fields: @@ -364,10 +397,18 @@ def aggregate_entries(self, **named_exprs) -> 'GroupedMatrixTable': base = self._entry_fields if self._entry_fields is not None else hl.struct() for k, e in named_exprs.items(): - analyze('GroupedMatrixTable.aggregate_entries', e, self._fixed_indices(), {self._parent._row_axis, self._parent._col_axis}) - - self._check_bindings('aggregate_entries', named_exprs, - self._parent._col_indices if self._col_keys is not None else self._parent._row_indices) + analyze( + 'GroupedMatrixTable.aggregate_entries', + e, + self._fixed_indices(), + {self._parent._row_axis, self._parent._col_axis}, + ) + + self._check_bindings( + 'aggregate_entries', + named_exprs, + self._parent._col_indices if self._col_keys is not None else self._parent._row_indices, + ) return self._copy(entry_fields=base.annotate(**named_exprs)) def result(self) -> 'MatrixTable': @@ -414,6 +455,7 @@ def result(self) -> 'MatrixTable': def promote_none(e): return hl.struct() if e is None else e + entry_exprs = promote_none(self._entry_fields) if len(entry_exprs) == 0: warning("'GroupedMatrixTable.result': No entry fields were defined.") @@ -424,27 +466,36 @@ def promote_none(e): cck = self._computed_col_key or {} computed_key_uids = {k: Env.get_uid() for k in cck} modified_keys = [computed_key_uids.get(k, k) for k in self._col_keys] - mt = MatrixTable(ir.MatrixAggregateColsByKey( - ir.MatrixMapCols( - base._mir, - self._parent.col.annotate(**{computed_key_uids[k]: v for k, v in cck.items()})._ir, - modified_keys), - entry_exprs._ir, - promote_none(self._col_fields)._ir)) + mt = MatrixTable( + ir.MatrixAggregateColsByKey( + ir.MatrixMapCols( + base._mir, + self._parent.col.annotate(**{computed_key_uids[k]: v for k, v in cck.items()})._ir, + modified_keys, + ), + entry_exprs._ir, + promote_none(self._col_fields)._ir, + ) + ) if cck: mt = mt.rename({v: k for k, v in computed_key_uids.items()}) else: cck = self._computed_row_key or {} computed_key_uids = {k: Env.get_uid() for k in cck} modified_keys = [computed_key_uids.get(k, k) for k in self._row_keys] - mt = MatrixTable(ir.MatrixAggregateRowsByKey( - ir.MatrixKeyRowsBy( - ir.MatrixMapRows( - ir.MatrixKeyRowsBy(base._mir, []), - 
self._parent._rvrow.annotate(**{computed_key_uids[k]: v for k, v in cck.items()})._ir), - modified_keys), - entry_exprs._ir, - promote_none(self._row_fields)._ir)) + mt = MatrixTable( + ir.MatrixAggregateRowsByKey( + ir.MatrixKeyRowsBy( + ir.MatrixMapRows( + ir.MatrixKeyRowsBy(base._mir, []), + self._parent._rvrow.annotate(**{computed_key_uids[k]: v for k, v in cck.items()})._ir, + ), + modified_keys, + ), + entry_exprs._ir, + promote_none(self._row_fields)._ir, + ) + ) if cck: mt = mt.rename({v: k for k, v in computed_key_uids.items()}) @@ -547,7 +598,7 @@ def from_parts( globals: Optional[Dict[str, Any]] = None, rows: Optional[Dict[str, Iterable[Any]]] = None, cols: Optional[Dict[str, Iterable[Any]]] = None, - entries: Optional[Dict[str, Iterable[Iterable[Any]]]] = None + entries: Optional[Dict[str, Iterable[Iterable[Any]]]] = None, ) -> 'MatrixTable': """Create a `MatrixTable` from its component parts. @@ -634,6 +685,7 @@ def from_parts( A MatrixTable assembled from inputs whose rows are keyed by `row_idx` and columns are keyed by `col_idx`. """ + # General idea: build a `Table` representation matching that returned by # `MatrixTable.localize_entries` and then call `_unlocalize_entries`. In # this form, the column table is bundled with the globals and the entries @@ -670,14 +722,12 @@ def anyval(kvs): globals[cols_field_name] = cols rows = transpose(rows) if rows else [{} for _ in anyval(entries)] - entries = [transpose(e) for e in transpose(entries) - ] if entries else [[{} for _ in cols] for _ in rows] + entries = [transpose(e) for e in transpose(entries)] if entries else [[{} for _ in cols] for _ in rows] if len(rows) != len(entries) or len(cols) != len(entries[0]): - raise ValueError(( - "mismatched matrix dimensions: " - "number of rows and cols does not match entry dimensions." 
- )) + raise ValueError( + ("mismatched matrix dimensions: " "number of rows and cols does not match entry dimensions.") + ) entries_field_name = Env.get_uid() for i, (row_props, entry_props) in enumerate(zip(rows, entries)): @@ -710,57 +760,54 @@ def __init__(self, mir): self._row_type = self._type.row_type self._entry_type = self._type.entry_type - self._globals = construct_reference('global', self._global_type, - indices=self._global_indices) - self._rvrow = construct_reference('va', - self._type.row_type, - indices=self._row_indices) + self._globals = construct_reference('global', self._global_type, indices=self._global_indices) + self._rvrow = construct_reference('va', self._type.row_type, indices=self._row_indices) self._row = hl.struct(**{k: self._rvrow[k] for k in self._row_type.keys()}) - self._col = construct_reference('sa', self._col_type, - indices=self._col_indices) - self._entry = construct_reference('g', self._entry_type, - indices=self._entry_indices) - - self._indices_from_ref = {'global': self._global_indices, - 'va': self._row_indices, - 'sa': self._col_indices, - 'g': self._entry_indices} - - self._row_key = hl.struct( - **{k: self._row[k] for k in self._type.row_key}) + self._col = construct_reference('sa', self._col_type, indices=self._col_indices) + self._entry = construct_reference('g', self._entry_type, indices=self._entry_indices) + + self._indices_from_ref = { + 'global': self._global_indices, + 'va': self._row_indices, + 'sa': self._col_indices, + 'g': self._entry_indices, + } + + self._row_key = hl.struct(**{k: self._row[k] for k in self._type.row_key}) self._partition_key = self._row_key - self._col_key = hl.struct( - **{k: self._col[k] for k in self._type.col_key}) + self._col_key = hl.struct(**{k: self._col[k] for k in self._type.col_key}) self._num_samples = None - for k, v in itertools.chain(self._globals.items(), - self._row.items(), - self._col.items(), - self._entry.items()): + for k, v in itertools.chain(self._globals.items(), self._row.items(), self._col.items(), self._entry.items()): self._set_field(k, v) @property def _schema(self) -> tmatrix: return tmatrix( self._global_type, - self._col_type, list(self._col_key), - self._row_type, list(self._row_key), - self._entry_type) + self._col_type, + list(self._col_key), + self._row_type, + list(self._row_key), + self._entry_type, + ) def __getitem__(self, item): - invalid_usage = TypeError("MatrixTable.__getitem__: invalid index argument(s)\n" - " Usage 1: field selection: mt['field']\n" - " Usage 2: Entry joining: mt[mt2.row_key, mt2.col_key]\n\n" - " To join row or column fields, use one of the following:\n" - " rows:\n" - " mt.index_rows(mt2.row_key)\n" - " mt.rows().index(mt2.row_key)\n" - " mt.rows()[mt2.row_key]\n" - " cols:\n" - " mt.index_cols(mt2.col_key)\n" - " mt.cols().index(mt2.col_key)\n" - " mt.cols()[mt2.col_key]") + invalid_usage = TypeError( + "MatrixTable.__getitem__: invalid index argument(s)\n" + " Usage 1: field selection: mt['field']\n" + " Usage 2: Entry joining: mt[mt2.row_key, mt2.col_key]\n\n" + " To join row or column fields, use one of the following:\n" + " rows:\n" + " mt.index_rows(mt2.row_key)\n" + " mt.rows().index(mt2.row_key)\n" + " mt.rows()[mt2.row_key]\n" + " cols:\n" + " mt.index_cols(mt2.col_key)\n" + " mt.cols().index(mt2.col_key)\n" + " mt.cols()[mt2.col_key]" + ) if isinstance(item, str): return self._get_field(item) @@ -921,8 +968,7 @@ def entry(self) -> 'StructExpression': """ return self._entry - @typecheck_method(keys=oneof(str, Expression), - 
named_keys=expr_any) + @typecheck_method(keys=oneof(str, Expression), named_keys=expr_any) def key_cols_by(self, *keys, **named_keys) -> 'MatrixTable': """Key columns by a new set of fields. @@ -946,30 +992,25 @@ def key_cols_by(self, *keys, **named_keys) -> 'MatrixTable': new_col = self.col.annotate(**computed_keys) base, cleanup = self._process_joins(new_col) - return cleanup(MatrixTable( - ir.MatrixMapCols( - base._mir, - new_col._ir, - key_fields - ))) + return cleanup(MatrixTable(ir.MatrixMapCols(base._mir, new_col._ir, key_fields))) @typecheck_method(new_key=str) def _key_rows_by_assert_sorted(self, *new_key): rk_names = list(self.row_key) i = 0 - while (i < min(len(new_key), len(rk_names))): + while i < min(len(new_key), len(rk_names)): if new_key[i] != rk_names[i]: break i += 1 if i < 1: raise ValueError( - f'cannot implement an unsafe sort with no shared key:\n new key: {new_key}\n old key: {rk_names}') + f'cannot implement an unsafe sort with no shared key:\n new key: {new_key}\n old key: {rk_names}' + ) return MatrixTable(ir.MatrixKeyRowsBy(self._mir, list(new_key), is_sorted=True)) - @typecheck_method(keys=oneof(str, Expression), - named_keys=expr_any) + @typecheck_method(keys=oneof(str, Expression), named_keys=expr_any) def key_rows_by(self, *keys, **named_keys) -> 'MatrixTable': """Key rows by a new set of fields. @@ -1013,12 +1054,13 @@ def key_rows_by(self, *keys, **named_keys) -> 'MatrixTable': new_row = self._rvrow.annotate(**computed_keys) base, cleanup = self._process_joins(new_row) - return cleanup(MatrixTable( - ir.MatrixKeyRowsBy( - ir.MatrixMapRows( - ir.MatrixKeyRowsBy(base._mir, []), - new_row._ir), - list(key_fields)))) + return cleanup( + MatrixTable( + ir.MatrixKeyRowsBy( + ir.MatrixMapRows(ir.MatrixKeyRowsBy(base._mir, []), new_row._ir), list(key_fields) + ) + ) + ) @typecheck_method(named_exprs=expr_any) def annotate_globals(self, **named_exprs) -> 'MatrixTable': @@ -1254,11 +1296,7 @@ def select_globals(self, *exprs, **named_exprs) -> 'MatrixTable': """ caller = 'MatrixTable.select_globals' - new_global = get_select_exprs(caller, - exprs, - named_exprs, - self._global_indices, - self._globals) + new_global = get_select_exprs(caller, exprs, named_exprs, self._global_indices, self._globals) return self._select_globals(caller, new_global) def select_rows(self, *exprs, **named_exprs) -> 'MatrixTable': @@ -1309,11 +1347,7 @@ def select_rows(self, *exprs, **named_exprs) -> 'MatrixTable': MatrixTable with specified row fields. """ caller = 'MatrixTable.select_rows' - new_row = get_select_exprs(caller, - exprs, - named_exprs, - self._row_indices, - self._rvrow) + new_row = get_select_exprs(caller, exprs, named_exprs, self._row_indices, self._rvrow) return self._select_rows(caller, new_row) def select_cols(self, *exprs, **named_exprs) -> 'MatrixTable': @@ -1359,11 +1393,7 @@ def select_cols(self, *exprs, **named_exprs) -> 'MatrixTable': MatrixTable with specified column fields. """ caller = 'MatrixTable.select_cols' - new_col = get_select_exprs(caller, - exprs, - named_exprs, - self._col_indices, - self._col) + new_col = get_select_exprs(caller, exprs, named_exprs, self._col_indices, self._col) return self._select_cols(caller, new_col) def select_entries(self, *exprs, **named_exprs) -> 'MatrixTable': @@ -1402,11 +1432,7 @@ def select_entries(self, *exprs, **named_exprs) -> 'MatrixTable': MatrixTable with specified entry fields. 
""" caller = 'MatrixTable.select_entries' - new_entry = get_select_exprs(caller, - exprs, - named_exprs, - self._entry_indices, - self._entry) + new_entry = get_select_exprs(caller, exprs, named_exprs, self._entry_indices, self._entry) return self._select_entries(caller, new_entry) @typecheck_method(exprs=oneof(str, Expression)) @@ -1470,8 +1496,10 @@ def check_key(name, keys): if e in all_field_exprs: fields_to_drop.add(all_field_exprs[e]) else: - raise ExpressionException("Method 'drop' expects string field names or top-level field expressions" - " (e.g. 'foo', matrix.foo, or matrix['foo'])") + raise ExpressionException( + "Method 'drop' expects string field names or top-level field expressions" + " (e.g. 'foo', matrix.foo, or matrix['foo'])" + ) else: assert isinstance(e, str) if e not in self._fields: @@ -1483,11 +1511,19 @@ def check_key(name, keys): if global_fields: m = m._select_globals("MatrixTable.drop", m.globals.drop(*global_fields)) - row_fields = [check_key(field, list(self.row_key)) for field in fields_to_drop if self._fields[field]._indices == self._row_indices] + row_fields = [ + check_key(field, list(self.row_key)) + for field in fields_to_drop + if self._fields[field]._indices == self._row_indices + ] if row_fields: m = m._select_rows("MatrixTable.drop", row=m.row.drop(*row_fields)) - col_fields = [check_key(field, list(self.col_key)) for field in fields_to_drop if self._fields[field]._indices == self._col_indices] + col_fields = [ + check_key(field, list(self.col_key)) + for field in fields_to_drop + if self._fields[field]._indices == self._col_indices + ] if col_fields: m = m._select_cols("MatrixTable.drop", m.col.drop(*col_fields)) @@ -1536,10 +1572,14 @@ def semi_join_rows(self, other: 'Table') -> 'MatrixTable': """ if len(other.key) == 0: raise ValueError('semi_join_rows: cannot join with a table with no key') - if len(other.key) > len(self.row_key) or any(t[0].dtype != t[1].dtype for t in zip(self.row_key.values(), other.key.values())): - raise ValueError('semi_join_rows: cannot join: table must have a key of the same type(s) and be the same length or shorter:' - f'\n MatrixTable row key: {", ".join(str(x.dtype) for x in self.row_key.values())}' - f'\n Table key: {", ".join(str(x.dtype) for x in other.key.values())}') + if len(other.key) > len(self.row_key) or any( + t[0].dtype != t[1].dtype for t in zip(self.row_key.values(), other.key.values()) + ): + raise ValueError( + 'semi_join_rows: cannot join: table must have a key of the same type(s) and be the same length or shorter:' + f'\n MatrixTable row key: {", ".join(str(x.dtype) for x in self.row_key.values())}' + f'\n Table key: {", ".join(str(x.dtype) for x in other.key.values())}' + ) return self.filter_rows(hl.is_defined(other.index(*(self.row_key[i] for i in range(len(other.key)))))) @typecheck_method(other=Table) @@ -1581,10 +1621,14 @@ def anti_join_rows(self, other: 'Table') -> 'MatrixTable': """ if len(other.key) == 0: raise ValueError('anti_join_rows: cannot join with a table with no key') - if len(other.key) > len(self.row_key) or any(t[0].dtype != t[1].dtype for t in zip(self.row_key.values(), other.key.values())): - raise ValueError('anti_join_rows: cannot join: table must have a key of the same type(s) and be the same length or shorter:' - f'\n MatrixTable row key: {", ".join(str(x.dtype) for x in self.row_key.values())}' - f'\n Table key: {", ".join(str(x.dtype) for x in other.key.values())}') + if len(other.key) > len(self.row_key) or any( + t[0].dtype != t[1].dtype for t in 
zip(self.row_key.values(), other.key.values()) + ): + raise ValueError( + 'anti_join_rows: cannot join: table must have a key of the same type(s) and be the same length or shorter:' + f'\n MatrixTable row key: {", ".join(str(x.dtype) for x in self.row_key.values())}' + f'\n Table key: {", ".join(str(x.dtype) for x in other.key.values())}' + ) return self.filter_rows(hl.is_missing(other.index(*(self.row_key[i] for i in range(len(other.key)))))) @typecheck_method(other=Table) @@ -1626,10 +1670,14 @@ def semi_join_cols(self, other: 'Table') -> 'MatrixTable': """ if len(other.key) == 0: raise ValueError('semi_join_cols: cannot join with a table with no key') - if len(other.key) > len(self.col_key) or any(t[0].dtype != t[1].dtype for t in zip(self.col_key.values(), other.key.values())): - raise ValueError('semi_join_cols: cannot join: table must have a key of the same type(s) and be the same length or shorter:' - f'\n MatrixTable col key: {", ".join(str(x.dtype) for x in self.col_key.values())}' - f'\n Table key: {", ".join(str(x.dtype) for x in other.key.values())}') + if len(other.key) > len(self.col_key) or any( + t[0].dtype != t[1].dtype for t in zip(self.col_key.values(), other.key.values()) + ): + raise ValueError( + 'semi_join_cols: cannot join: table must have a key of the same type(s) and be the same length or shorter:' + f'\n MatrixTable col key: {", ".join(str(x.dtype) for x in self.col_key.values())}' + f'\n Table key: {", ".join(str(x.dtype) for x in other.key.values())}' + ) return self.filter_cols(hl.is_defined(other.index(*(self.col_key[i] for i in range(len(other.key)))))) @@ -1672,10 +1720,14 @@ def anti_join_cols(self, other: 'Table') -> 'MatrixTable': """ if len(other.key) == 0: raise ValueError('anti_join_cols: cannot join with a table with no key') - if len(other.key) > len(self.col_key) or any(t[0].dtype != t[1].dtype for t in zip(self.col_key.values(), other.key.values())): - raise ValueError('anti_join_cols: cannot join: table must have a key of the same type(s) and be the same length or shorter:' - f'\n MatrixTable col key: {", ".join(str(x.dtype) for x in self.col_key.values())}' - f'\n Table key: {", ".join(str(x.dtype) for x in other.key.values())}') + if len(other.key) > len(self.col_key) or any( + t[0].dtype != t[1].dtype for t in zip(self.col_key.values(), other.key.values()) + ): + raise ValueError( + 'anti_join_cols: cannot join: table must have a key of the same type(s) and be the same length or shorter:' + f'\n MatrixTable col key: {", ".join(str(x.dtype) for x in self.col_key.values())}' + f'\n Table key: {", ".join(str(x.dtype) for x in other.key.values())}' + ) return self.filter_cols(hl.is_missing(other.index(*(self.col_key[i] for i in range(len(other.key)))))) @@ -1936,9 +1988,8 @@ def unfilter_entries(self): :meth:`filter_entries`, :meth:`compute_entry_filter_stats` """ entry_ir = hl.if_else( - hl.is_defined(self.entry), - self.entry, - hl.struct(**{k: hl.missing(v.dtype) for k, v in self.entry.items()}))._ir + hl.is_defined(self.entry), self.entry, hl.struct(**{k: hl.missing(v.dtype) for k, v in self.entry.items()}) + )._ir return MatrixTable(ir.MatrixMapEntries(self._mir, entry_ir)) @typecheck_method(row_field=str, col_field=str) @@ -1974,12 +2025,16 @@ def compute_entry_filter_stats(self, row_field='entry_stats_row', col_field='ent -------- :meth:`filter_entries`, :meth:`unfilter_entries` """ + def result(count): - return hl.rbind(count, - hl.agg.count(), - lambda n_tot, n_def: hl.struct(n_filtered=n_tot - n_def, - n_remaining=n_def, - 
fraction_filtered=(n_tot - n_def) / n_tot)) + return hl.rbind( + count, + hl.agg.count(), + lambda n_tot, n_def: hl.struct( + n_filtered=n_tot - n_def, n_remaining=n_def, fraction_filtered=(n_tot - n_def) / n_tot + ), + ) + mt = self mt = mt.annotate_cols(**{col_field: result(mt.count_rows(_localize=False))}) mt = mt.annotate_rows(**{row_field: result(mt.count_cols(_localize=False))}) @@ -2012,9 +2067,10 @@ def transmute_globals(self, **named_exprs) -> 'MatrixTable': """ caller = 'MatrixTable.transmute_globals' check_annotate_exprs(caller, named_exprs, self._global_indices, set()) - fields_referenced = extract_refs_by_indices(named_exprs.values(), self._global_indices) - set(named_exprs.keys()) - return self._select_globals(caller, - self.globals.annotate(**named_exprs).drop(*fields_referenced)) + fields_referenced = extract_refs_by_indices(named_exprs.values(), self._global_indices) - set( + named_exprs.keys() + ) + return self._select_globals(caller, self.globals.annotate(**named_exprs).drop(*fields_referenced)) @typecheck_method(named_exprs=expr_any) def transmute_rows(self, **named_exprs) -> 'MatrixTable': @@ -2094,8 +2150,7 @@ def transmute_cols(self, **named_exprs) -> 'MatrixTable': fields_referenced = extract_refs_by_indices(named_exprs.values(), self._col_indices) - set(named_exprs.keys()) fields_referenced -= set(self.col_key) - return self._select_cols(caller, - self.col.annotate(**named_exprs).drop(*fields_referenced)) + return self._select_cols(caller, self.col.annotate(**named_exprs).drop(*fields_referenced)) @typecheck_method(named_exprs=expr_any) def transmute_entries(self, **named_exprs) -> 'MatrixTable': @@ -2126,8 +2181,7 @@ def transmute_entries(self, **named_exprs) -> 'MatrixTable': check_annotate_exprs(caller, named_exprs, self._entry_indices, set()) fields_referenced = extract_refs_by_indices(named_exprs.values(), self._entry_indices) - set(named_exprs.keys()) - return self._select_entries(caller, - self.entry.annotate(**named_exprs).drop(*fields_referenced)) + return self._select_entries(caller, self.entry.annotate(**named_exprs).drop(*fields_referenced)) @typecheck_method(expr=expr_any, _localize=bool) def aggregate_rows(self, expr, _localize=True) -> Any: @@ -2227,16 +2281,15 @@ def aggregate_cols(self, expr, _localize=True) -> Any: cols = globals[cols_field] else: if Env.hc()._warn_cols_order: - warning("aggregate_cols(): Aggregates over cols ordered by 'col_key'." - "\n To preserve matrix table column order, " - "first unkey columns with 'key_cols_by()'") + warning( + "aggregate_cols(): Aggregates over cols ordered by 'col_key'." 
+ "\n To preserve matrix table column order, " + "first unkey columns with 'key_cols_by()'" + ) Env.hc()._warn_cols_order = False cols = hl.sorted(globals[cols_field], key=lambda x: x.select(*self._col_key.keys())) - agg_ir = ir.Let( - 'global', - globals.drop(cols_field)._ir, - ir.StreamAgg(ir.ToStream(cols._ir), 'sa', expr._ir)) + agg_ir = ir.Let('global', globals.drop(cols_field)._ir, ir.StreamAgg(ir.ToStream(cols._ir), 'sa', expr._ir)) if _localize: return Env.backend().execute(ir.MakeTuple([agg_ir]))[0] @@ -2322,15 +2375,19 @@ def explode_rows(self, field_expr) -> 'MatrixTable': if field_expr not in self._fields: raise KeyError("MatrixTable has no field '{}'".format(field_expr)) elif self._fields[field_expr]._indices != self._row_indices: - raise ExpressionException("Method 'explode_rows' expects a field indexed by row, found axes '{}'" - .format(self._fields[field_expr]._indices.axes)) + raise ExpressionException( + "Method 'explode_rows' expects a field indexed by row, found axes '{}'".format( + self._fields[field_expr]._indices.axes + ) + ) root = [field_expr] field_expr = self._fields[field_expr] else: analyze('MatrixTable.explode_rows', field_expr, self._row_indices, set(self._fields.keys())) if not field_expr._ir.is_nested_field: raise ExpressionException( - "method 'explode_rows' requires a field or subfield, not a complex expression") + "method 'explode_rows' requires a field or subfield, not a complex expression" + ) nested = field_expr._ir root = [] while isinstance(nested, ir.GetField): @@ -2384,15 +2441,19 @@ def explode_cols(self, field_expr) -> 'MatrixTable': if field_expr not in self._fields: raise KeyError("MatrixTable has no field '{}'".format(field_expr)) elif self._fields[field_expr]._indices != self._col_indices: - raise ExpressionException("Method 'explode_cols' expects a field indexed by col, found axes '{}'" - .format(self._fields[field_expr]._indices.axes)) + raise ExpressionException( + "Method 'explode_cols' expects a field indexed by col, found axes '{}'".format( + self._fields[field_expr]._indices.axes + ) + ) root = [field_expr] field_expr = self._fields[field_expr] else: analyze('MatrixTable.explode_cols', field_expr, self._col_indices) if not field_expr._ir.is_nested_field: raise ExpressionException( - "method 'explode_cols' requires a field or subfield, not a complex expression") + "method 'explode_cols' requires a field or subfield, not a complex expression" + ) root = [] nested = field_expr._ir while isinstance(nested, ir.GetField): @@ -2630,18 +2691,29 @@ def count(self) -> Tuple[int, int]: count_ir = ir.MatrixCount(self._mir) return Env.backend().execute(count_ir) - @typecheck_method(output=str, - overwrite=bool, - stage_locally=bool, - _codec_spec=nullable(str), - _read_if_exists=bool, - _intervals=nullable(sequenceof(anytype)), - _filter_intervals=bool, - _drop_cols=bool, - _drop_rows=bool) - def checkpoint(self, output: str, overwrite: bool = False, stage_locally: bool = False, - _codec_spec: Optional[str] = None, _read_if_exists: bool = False, - _intervals=None, _filter_intervals=False, _drop_cols=False, _drop_rows=False) -> 'MatrixTable': + @typecheck_method( + output=str, + overwrite=bool, + stage_locally=bool, + _codec_spec=nullable(str), + _read_if_exists=bool, + _intervals=nullable(sequenceof(anytype)), + _filter_intervals=bool, + _drop_cols=bool, + _drop_rows=bool, + ) + def checkpoint( + self, + output: str, + overwrite: bool = False, + stage_locally: bool = False, + _codec_spec: Optional[str] = None, + _read_if_exists: bool = False, + 
_intervals=None, + _filter_intervals=False, + _drop_cols=False, + _drop_rows=False, + ) -> 'MatrixTable': """Checkpoint the matrix table to disk by writing and reading using a fast, but less space-efficient codec. Parameters @@ -2689,16 +2761,20 @@ def checkpoint(self, output: str, overwrite: bool = False, stage_locally: bool = _drop_cols=_drop_cols, _drop_rows=_drop_rows, _assert_type=_assert_type, - _load_refs=_load_refs + _load_refs=_load_refs, ) - @typecheck_method(output=str, - overwrite=bool, - stage_locally=bool, - _codec_spec=nullable(str), - _partitions=nullable(expr_any)) - def write(self, output: str, overwrite: bool = False, stage_locally: bool = False, - _codec_spec: Optional[str] = None, _partitions=None): + @typecheck_method( + output=str, overwrite=bool, stage_locally=bool, _codec_spec=nullable(str), _partitions=nullable(expr_any) + ) + def write( + self, + output: str, + overwrite: bool = False, + stage_locally: bool = False, + _codec_spec: Optional[str] = None, + _partitions=None, + ): """Write to disk. Examples @@ -2756,21 +2832,18 @@ def _repr_html_(self): s += '
    \n' return s - @typecheck_method(n_rows=nullable(int), - n_cols=nullable(int), - include_row_fields=bool, - width=nullable(int), - truncate=nullable(int), - types=bool, - handler=nullable(anyfunc)) - def show(self, - n_rows=None, - n_cols=None, - include_row_fields=False, - width=None, - truncate=None, - types=True, - handler=None): + @typecheck_method( + n_rows=nullable(int), + n_cols=nullable(int), + include_row_fields=bool, + width=nullable(int), + truncate=nullable(int), + types=bool, + handler=nullable(anyfunc), + ) + def show( + self, n_rows=None, n_cols=None, include_row_fields=False, width=None, truncate=None, types=True, handler=None + ): """Print the first few rows of the matrix table to the console. .. include:: _templates/experimental.rst @@ -2799,11 +2872,11 @@ def show(self, """ def estimate_size(struct_expression): - return sum(max(len(f), len(str(x.dtype))) + 3 - for f, x in struct_expression.flatten().items()) + return sum(max(len(f), len(str(x.dtype))) + 3 for f, x in struct_expression.flatten().items()) if n_cols is None: import shutil + (characters, _) = shutil.get_terminal_size((80, 10)) characters -= 6 # borders key_characters = estimate_size(self.row_key) @@ -2826,12 +2899,10 @@ def estimate_size(struct_expression): if len(set(cols)) == len(cols): col_headers = [repr(c) for c in cols] - entries = {col_headers[i]: t.entries[i] - for i in range(0, displayed_n_cols)} + entries = {col_headers[i]: t.entries[i] for i in range(0, displayed_n_cols)} t = t.select( - **{f: t[f] for f in self.row_key}, - **{f: t[f] for f in self.row_value if include_row_fields}, - **entries) + **{f: t[f] for f in self.row_key}, **{f: t[f] for f in self.row_value if include_row_fields}, **entries + ) if handler is None: handler = default_handler() return handler(MatrixTable._Show(t, n_rows, actual_n_cols, displayed_n_cols, width, truncate, types)) @@ -2850,8 +2921,7 @@ def globals_table(self) -> Table: :class:`.Table` Table with the globals from the matrix, with a single row. """ - return Table.parallelize( - [hl.eval(self.globals)], self._global_type) + return Table.parallelize([hl.eval(self.globals)], self._global_type) def rows(self) -> Table: """Returns a table with all row fields in the matrix. @@ -2895,9 +2965,11 @@ def cols(self) -> Table: """ if len(self.col_key) != 0 and Env.hc()._warn_cols_order: - warning("cols(): Resulting column table is sorted by 'col_key'." - "\n To preserve matrix table column order, " - "first unkey columns with 'key_cols_by()'") + warning( + "cols(): Resulting column table is sorted by 'col_key'." + "\n To preserve matrix table column order, " + "first unkey columns with 'key_cols_by()'" + ) Env.hc()._warn_cols_order = False return Table(ir.MatrixColsTable(self._mir)) @@ -2955,9 +3027,11 @@ def entries(self) -> Table: Table with all non-global fields from the matrix, with **one row per entry of the matrix**. """ if Env.hc()._warn_entries_order and len(self.col_key) > 0: - warning("entries(): Resulting entries table is sorted by '(row_key, col_key)'." - "\n To preserve row-major matrix table order, " - "first unkey columns with 'key_cols_by()'") + warning( + "entries(): Resulting entries table is sorted by '(row_key, col_key)'." 
+ "\n To preserve row-major matrix table order, " + "first unkey columns with 'key_cols_by()'" + ) Env.hc()._warn_entries_order = False return Table(ir.MatrixEntriesTable(self._mir)) @@ -3015,7 +3089,8 @@ def index_rows(self, *exprs, all_matches=False) -> 'Expression': raise ExpressionException( f"Key type mismatch: cannot index matrix table with given expressions:\n" f" MatrixTable row key: {', '.join(str(t) for t in err.key_type.values()) or '<<>>'}\n" - f" Index expressions: {', '.join(str(e.dtype) for e in err.index_expressions)}") + f" Index expressions: {', '.join(str(e.dtype) for e in err.index_expressions)}" + ) def index_cols(self, *exprs, all_matches=False) -> 'Expression': """Expose the column values as if looked up in a dictionary, indexing @@ -3054,7 +3129,8 @@ def index_cols(self, *exprs, all_matches=False) -> 'Expression': raise ExpressionException( f"Key type mismatch: cannot index matrix table with given expressions:\n" f" MatrixTable col key: {', '.join(str(t) for t in err.key_type.values()) or '<<>>'}\n" - f" Index expressions: {', '.join(str(e.dtype) for e in err.index_expressions)}") + f" Index expressions: {', '.join(str(e.dtype) for e in err.index_expressions)}" + ) def index_entries(self, row_exprs, col_exprs): """Expose the entries as if looked up in a dictionary, indexing @@ -3104,40 +3180,54 @@ def index_entries(self, row_exprs, col_exprs): raise TypeError(f"'MatrixTable.index_entries': col_exprs expects expressions, found {col_non_exprs}") if not types_match(self.row_key.values(), row_exprs): - if (len(row_exprs) == 1 - and isinstance(row_exprs[0], TupleExpression) - and types_match(self.row_key.values(), row_exprs[0])): + if ( + len(row_exprs) == 1 + and isinstance(row_exprs[0], TupleExpression) + and types_match(self.row_key.values(), row_exprs[0]) + ): return self.index_entries(tuple(row_exprs[0]), col_exprs) - elif (len(row_exprs) == 1 - and isinstance(row_exprs[0], StructExpression) - and types_match(self.row_key.values(), row_exprs[0].values())): + elif ( + len(row_exprs) == 1 + and isinstance(row_exprs[0], StructExpression) + and types_match(self.row_key.values(), row_exprs[0].values()) + ): return self.index_entries(tuple(row_exprs[0].values()), col_exprs) elif len(row_exprs) != len(self.row_key): - raise ExpressionException(f'Key mismatch: matrix table has {len(self.row_key)} row key fields, ' - f'found {len(row_exprs)} index expressions') + raise ExpressionException( + f'Key mismatch: matrix table has {len(self.row_key)} row key fields, ' + f'found {len(row_exprs)} index expressions' + ) else: raise ExpressionException( f"Key type mismatch: Cannot index matrix table with given expressions\n" f" MatrixTable row key: {', '.join(str(t) for t in self.row_key.dtype.values())}\n" - f" Row index expressions: {', '.join(str(e.dtype) for e in row_exprs)}") + f" Row index expressions: {', '.join(str(e.dtype) for e in row_exprs)}" + ) if not types_match(self.col_key.values(), col_exprs): - if (len(col_exprs) == 1 - and isinstance(col_exprs[0], TupleExpression) - and types_match(self.col_key.values(), col_exprs[0])): + if ( + len(col_exprs) == 1 + and isinstance(col_exprs[0], TupleExpression) + and types_match(self.col_key.values(), col_exprs[0]) + ): return self.index_entries(row_exprs, tuple(col_exprs[0])) - elif (len(col_exprs) == 1 - and isinstance(col_exprs[0], StructExpression) - and types_match(self.col_key.values(), col_exprs[0].values())): + elif ( + len(col_exprs) == 1 + and isinstance(col_exprs[0], StructExpression) + and 
types_match(self.col_key.values(), col_exprs[0].values()) + ): return self.index_entries(row_exprs, tuple(col_exprs[0].values())) elif len(col_exprs) != len(self.col_key): - raise ExpressionException(f'Key mismatch: matrix table has {len(self.col_key)} col key fields, ' - f'found {len(col_exprs)} index expressions.') + raise ExpressionException( + f'Key mismatch: matrix table has {len(self.col_key)} col key fields, ' + f'found {len(col_exprs)} index expressions.' + ) else: raise ExpressionException( f"Key type mismatch: cannot index matrix table with given expressions:\n" f" MatrixTable col key: {', '.join(str(t) for t in self.col_key.dtype.values())}\n" - f" Col index expressions: {', '.join(str(e.dtype) for e in col_exprs)}") + f" Col index expressions: {', '.join(str(e.dtype) for e in col_exprs)}" + ) indices, aggregations = unify_all(*(row_exprs + col_exprs)) src = indices.source @@ -3161,28 +3251,25 @@ def joiner(left: MatrixTable): localized = self._localize_entries(row_uid, col_uid) src_cols_indexed = self.add_col_index(col_uid).cols() src_cols_indexed = src_cols_indexed.annotate(**{col_uid: hl.int32(src_cols_indexed[col_uid])}) - left = left._annotate_all(row_exprs={row_uid: localized.index(*row_exprs)[row_uid]}, - col_exprs={col_uid: src_cols_indexed.index(*col_exprs)[col_uid]}) + left = left._annotate_all( + row_exprs={row_uid: localized.index(*row_exprs)[row_uid]}, + col_exprs={col_uid: src_cols_indexed.index(*col_exprs)[col_uid]}, + ) return left.annotate_entries(**{uid: left[row_uid][left[col_uid]]}) - join_ir = ir.Join(ir.ProjectedTopLevelReference('g', uid, self.entry.dtype), - uids, - [*row_exprs, *col_exprs], - joiner) + join_ir = ir.Join( + ir.ProjectedTopLevelReference('g', uid, self.entry.dtype), uids, [*row_exprs, *col_exprs], joiner + ) return construct_expr(join_ir, self.entry.dtype, indices, aggregations) @typecheck_method(entries_field_name=str, cols_field_name=str) def _localize_entries(self, entries_field_name, cols_field_name) -> 'Table': assert entries_field_name not in self.row assert cols_field_name not in self.globals - return Table(ir.CastMatrixToTable( - self._mir, entries_field_name, cols_field_name)) - - @typecheck_method(entries_array_field_name=nullable(str), - columns_array_field_name=nullable(str)) - def localize_entries(self, - entries_array_field_name=None, - columns_array_field_name=None) -> 'Table': + return Table(ir.CastMatrixToTable(self._mir, entries_field_name, cols_field_name)) + + @typecheck_method(entries_array_field_name=nullable(str), columns_array_field_name=nullable(str)) + def localize_entries(self, entries_array_field_name=None, columns_array_field_name=None) -> 'Table': """Convert the matrix table to a table with entries localized as an array of structs. 
Examples @@ -3253,10 +3340,12 @@ def localize_entries(self, cols = columns_array_field_name or Env.get_uid() if entries in self.row: raise ValueError( - f"'localize_entries': cannot localize entries to field {entries!r}, which is already a row field") + f"'localize_entries': cannot localize entries to field {entries!r}, which is already a row field" + ) if cols in self.globals: raise ValueError( - f"'localize_entries': cannot localize columns to field {cols!r}, which is already a global field") + f"'localize_entries': cannot localize columns to field {cols!r}, which is already a global field" + ) t = self._localize_entries(entries, cols) if entries_array_field_name is None: @@ -3265,75 +3354,81 @@ def localize_entries(self, t = t.drop(cols) return t - @typecheck_method(row_exprs=dictof(str, expr_any), - col_exprs=dictof(str, expr_any), - entry_exprs=dictof(str, expr_any), - global_exprs=dictof(str, expr_any)) - def _annotate_all(self, - row_exprs={}, - col_exprs={}, - entry_exprs={}, - global_exprs={}, - ) -> 'MatrixTable': - all_exprs = list(itertools.chain(row_exprs.values(), - col_exprs.values(), - entry_exprs.values(), - global_exprs.values())) - - for field_name in list(itertools.chain(row_exprs.keys(), - col_exprs.keys(), - entry_exprs.keys(), - global_exprs.keys())): + @typecheck_method( + row_exprs=dictof(str, expr_any), + col_exprs=dictof(str, expr_any), + entry_exprs=dictof(str, expr_any), + global_exprs=dictof(str, expr_any), + ) + def _annotate_all( + self, + row_exprs={}, + col_exprs={}, + entry_exprs={}, + global_exprs={}, + ) -> 'MatrixTable': + all_exprs = list( + itertools.chain(row_exprs.values(), col_exprs.values(), entry_exprs.values(), global_exprs.values()) + ) + + for field_name in list( + itertools.chain(row_exprs.keys(), col_exprs.keys(), entry_exprs.keys(), global_exprs.keys()) + ): if field_name in self._fields: - raise RuntimeError(f'field {repr(field_name)} already in matrix table, cannot use _annotate_all') + raise RuntimeError(f'field {field_name!r} already in matrix table, cannot use _annotate_all') base, cleanup = self._process_joins(*all_exprs) mir = base._mir if row_exprs: row_struct = ir.InsertFields.construct_with_deduplication( - base.row._ir, [(n, e._ir) for (n, e) in row_exprs.items()], None) + base.row._ir, [(n, e._ir) for (n, e) in row_exprs.items()], None + ) mir = ir.MatrixMapRows(mir, row_struct) if col_exprs: col_struct = ir.InsertFields.construct_with_deduplication( - base.col._ir, [(n, e._ir) for (n, e) in col_exprs.items()], None) + base.col._ir, [(n, e._ir) for (n, e) in col_exprs.items()], None + ) mir = ir.MatrixMapCols(mir, col_struct, None) if entry_exprs: entry_struct = ir.InsertFields.construct_with_deduplication( - base.entry._ir, [(n, e._ir) for (n, e) in entry_exprs.items()], None) + base.entry._ir, [(n, e._ir) for (n, e) in entry_exprs.items()], None + ) mir = ir.MatrixMapEntries(mir, entry_struct) if global_exprs: globals_struct = ir.InsertFields.construct_with_deduplication( - base.globals._ir, [(n, e._ir) for (n, e) in global_exprs.items()], None) + base.globals._ir, [(n, e._ir) for (n, e) in global_exprs.items()], None + ) mir = ir.MatrixMapGlobals(mir, globals_struct) return cleanup(MatrixTable(mir)) - @typecheck_method(row_exprs=dictof(str, expr_any), - row_key=nullable(sequenceof(str)), - col_exprs=dictof(str, expr_any), - col_key=nullable(sequenceof(str)), - entry_exprs=dictof(str, expr_any), - global_exprs=dictof(str, expr_any)) - def _select_all(self, - row_exprs={}, - row_key=None, - col_exprs={}, - col_key=None, - 
entry_exprs={}, - global_exprs={}, - ) -> 'MatrixTable': - - all_names = list(itertools.chain(row_exprs.keys(), - col_exprs.keys(), - entry_exprs.keys(), - global_exprs.keys())) + @typecheck_method( + row_exprs=dictof(str, expr_any), + row_key=nullable(sequenceof(str)), + col_exprs=dictof(str, expr_any), + col_key=nullable(sequenceof(str)), + entry_exprs=dictof(str, expr_any), + global_exprs=dictof(str, expr_any), + ) + def _select_all( + self, + row_exprs={}, + row_key=None, + col_exprs={}, + col_key=None, + entry_exprs={}, + global_exprs={}, + ) -> 'MatrixTable': + all_names = list(itertools.chain(row_exprs.keys(), col_exprs.keys(), entry_exprs.keys(), global_exprs.keys())) uids = {k: Env.get_uid() for k in all_names} - mt = self._annotate_all({uids[k]: v for k, v in row_exprs.items()}, - {uids[k]: v for k, v in col_exprs.items()}, - {uids[k]: v for k, v in entry_exprs.items()}, - {uids[k]: v for k, v in global_exprs.items()}) + mt = self._annotate_all( + {uids[k]: v for k, v in row_exprs.items()}, + {uids[k]: v for k, v in col_exprs.items()}, + {uids[k]: v for k, v in entry_exprs.items()}, + {uids[k]: v for k, v in global_exprs.items()}, + ) keep = set() if row_key is not None: @@ -3349,8 +3444,9 @@ def _select_all(self, keep = keep.union(set(mt.col_key)) keep = keep.union(uids.values()) - return (mt.drop(*(f for f in mt._fields if f not in keep)) - .rename({uid: original for original, uid in uids.items()})) + return mt.drop(*(f for f in mt._fields if f not in keep)).rename({ + uid: original for original, uid in uids.items() + }) def _process_joins(self, *exprs) -> 'MatrixTable': return process_joins(self, exprs) @@ -3371,6 +3467,7 @@ def describe(self, handler=print, *, widget=False): """ if widget: from hail.experimental.interact import interact + return interact(self) def format_type(typ): @@ -3379,50 +3476,51 @@ def format_type(typ): if len(self.globals) == 0: global_fields = '\n None' else: - global_fields = ''.join("\n '{name}': {type}".format( - name=f, type=format_type(t)) for f, t in self.globals.dtype.items()) + global_fields = ''.join( + "\n '{name}': {type}".format(name=f, type=format_type(t)) for f, t in self.globals.dtype.items() + ) if len(self.row) == 0: row_fields = '\n None' else: - row_fields = ''.join("\n '{name}': {type}".format( - name=f, type=format_type(t)) for f, t in self.row.dtype.items()) + row_fields = ''.join( + "\n '{name}': {type}".format(name=f, type=format_type(t)) for f, t in self.row.dtype.items() + ) - row_key = '[' + ', '.join("'{name}'".format(name=f) for f in self.row_key) + ']' \ - if self.row_key else None + row_key = '[' + ', '.join("'{name}'".format(name=f) for f in self.row_key) + ']' if self.row_key else None if len(self.col) == 0: col_fields = '\n None' else: - col_fields = ''.join("\n '{name}': {type}".format( - name=f, type=format_type(t)) for f, t in self.col.dtype.items()) + col_fields = ''.join( + "\n '{name}': {type}".format(name=f, type=format_type(t)) for f, t in self.col.dtype.items() + ) - col_key = '[' + ', '.join("'{name}'".format(name=f) for f in self.col_key) + ']' \ - if self.col_key else None + col_key = '[' + ', '.join("'{name}'".format(name=f) for f in self.col_key) + ']' if self.col_key else None if len(self.entry) == 0: entry_fields = '\n None' else: - entry_fields = ''.join("\n '{name}': {type}".format( - name=f, type=format_type(t)) for f, t in self.entry.dtype.items()) - - s = '----------------------------------------\n' \ - 'Global fields:{g}\n' \ - '----------------------------------------\n' \ - 'Column 
fields:{c}\n' \ - '----------------------------------------\n' \ - 'Row fields:{r}\n' \ - '----------------------------------------\n' \ - 'Entry fields:{e}\n' \ - '----------------------------------------\n' \ - 'Column key: {ck}\n' \ - 'Row key: {rk}\n' \ - '----------------------------------------'.format(g=global_fields, - rk=row_key, - r=row_fields, - ck=col_key, - c=col_fields, - e=entry_fields) + entry_fields = ''.join( + "\n '{name}': {type}".format(name=f, type=format_type(t)) for f, t in self.entry.dtype.items() + ) + + s = ( + '----------------------------------------\n' + 'Global fields:{g}\n' + '----------------------------------------\n' + 'Column fields:{c}\n' + '----------------------------------------\n' + 'Row fields:{r}\n' + '----------------------------------------\n' + 'Entry fields:{e}\n' + '----------------------------------------\n' + 'Column key: {ck}\n' + 'Row key: {rk}\n' + '----------------------------------------'.format( + g=global_fields, rk=row_key, r=row_fields, ck=col_key, c=col_fields, e=entry_fields + ) + ) handler(s) @typecheck_method(indices=sequenceof(int)) @@ -3478,8 +3576,7 @@ def n_partitions(self) -> int: """ return Env.backend().execute(ir.MatrixToValueApply(self._mir, {'name': 'NPartitionsMatrixTable'})) - @typecheck_method(n_partitions=int, - shuffle=bool) + @typecheck_method(n_partitions=int, shuffle=bool) def repartition(self, n_partitions: int, shuffle: bool = True) -> 'MatrixTable': """Change the number of partitions. @@ -3543,9 +3640,11 @@ def repartition(self, n_partitions: int, shuffle: bool = True) -> 'MatrixTable': self.checkpoint(tmp) return hl.read_matrix_table(tmp, _n_partitions=n_partitions) - return MatrixTable(ir.MatrixRepartition( - self._mir, n_partitions, - ir.RepartitionStrategy.SHUFFLE if shuffle else ir.RepartitionStrategy.COALESCE)) + return MatrixTable( + ir.MatrixRepartition( + self._mir, n_partitions, ir.RepartitionStrategy.SHUFFLE if shuffle else ir.RepartitionStrategy.COALESCE + ) + ) @typecheck_method(max_partitions=int) def naive_coalesce(self, max_partitions: int) -> 'MatrixTable': @@ -3576,8 +3675,7 @@ def naive_coalesce(self, max_partitions: int) -> 'MatrixTable': :class:`.MatrixTable` Matrix table with at most `max_partitions` partitions. """ - return MatrixTable(ir.MatrixRepartition( - self._mir, max_partitions, ir.RepartitionStrategy.NAIVE_COALESCE)) + return MatrixTable(ir.MatrixRepartition(self._mir, max_partitions, ir.RepartitionStrategy.NAIVE_COALESCE)) def cache(self) -> 'MatrixTable': """Persist the dataset in memory. 
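A usage sketch for the repartition and naive_coalesce calls reformatted in the hunk above; the dataset path is hypothetical and not part of this change:

    import hail as hl

    mt = hl.read_matrix_table('gs://my-bucket/dataset.mt')  # hypothetical path
    print(mt.n_partitions())

    # Full shuffle to exactly 100 partitions: more expensive, but evenly balanced.
    mt_shuffled = mt.repartition(100, shuffle=True)

    # Merge adjacent partitions down to at most 10 without shuffling.
    mt_coalesced = mt.naive_coalesce(10)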
@@ -3711,10 +3809,7 @@ def add_col_index(self, name: str = 'col_idx') -> 'MatrixTable': """ return self.annotate_cols(**{name: hl.scan.count()}) - @typecheck_method(other=matrix_table_type, - tolerance=numeric, - absolute=bool, - reorder_fields=bool) + @typecheck_method(other=matrix_table_type, tolerance=numeric, absolute=bool, reorder_fields=bool) def _same(self, other, tolerance=1e-6, absolute=False, reorder_fields=False) -> bool: entries_name = Env.get_uid('entries_') cols_name = Env.get_uid('columns_') @@ -3756,7 +3851,8 @@ def _same(self, other, tolerance=1e-6, absolute=False, reorder_fields=False) -> return False return self._localize_entries(entries_name, cols_name)._same( - other._localize_entries(entries_name, cols_name), tolerance, absolute) + other._localize_entries(entries_name, cols_name), tolerance, absolute + ) @typecheck_method(caller=str, s=expr_struct()) def _select_entries(self, caller, s) -> 'MatrixTable': @@ -3764,16 +3860,13 @@ def _select_entries(self, caller, s) -> 'MatrixTable': analyze(caller, s, self._entry_indices) return cleanup(MatrixTable(ir.MatrixMapEntries(base._mir, s._ir))) - @typecheck_method(caller=str, - row=expr_struct()) + @typecheck_method(caller=str, row=expr_struct()) def _select_rows(self, caller, row) -> 'MatrixTable': analyze(caller, row, self._row_indices, {self._col_axis}) base, cleanup = self._process_joins(row) return cleanup(MatrixTable(ir.MatrixMapRows(base._mir, row._ir))) - @typecheck_method(caller=str, - col=expr_struct(), - new_key=nullable(sequenceof(str))) + @typecheck_method(caller=str, col=expr_struct(), new_key=nullable(sequenceof(str))) def _select_cols(self, caller, col, new_key=None) -> 'MatrixTable': analyze(caller, col, self._col_indices, {self._row_axis}) base, cleanup = self._process_joins(col) @@ -3850,34 +3943,39 @@ def union_rows(*datasets: 'MatrixTable', _check_cols=True) -> 'MatrixTable': first = datasets[0] for i, next in enumerate(datasets[1:]): if first.row_key.keys() != next.row_key.keys(): - raise ValueError(error_msg.format( - "row keys", 0, first.row_key.keys(), i + 1, next.row_key.keys() - )) + raise ValueError(error_msg.format("row keys", 0, first.row_key.keys(), i + 1, next.row_key.keys())) if first.row.dtype != next.row.dtype: - raise ValueError(error_msg.format( - "row types", 0, first.row.dtype, i + 1, next.row.dtype - )) + raise ValueError(error_msg.format("row types", 0, first.row.dtype, i + 1, next.row.dtype)) if first.entry.dtype != next.entry.dtype: - raise ValueError(error_msg.format( - "entry field types", 0, first.entry.dtype, i + 1, next.entry.dtype - )) + raise ValueError( + error_msg.format("entry field types", 0, first.entry.dtype, i + 1, next.entry.dtype) + ) if first.col_key.dtype != next.col_key.dtype: - raise ValueError(error_msg.format( - "col key types", 0, first.col_key.dtype, i + 1, next.col_key.dtype - )) + raise ValueError( + error_msg.format("col key types", 0, first.col_key.dtype, i + 1, next.col_key.dtype) + ) if _check_cols: - wrong_keys = hl.eval(hl.rbind(first.col_key.collect(_localize=False), lambda first_keys: ( - hl.enumerate([mt.col_key.collect(_localize=False) for mt in datasets[1:]]) - .find(lambda x: ~(x[1] == first_keys))[0]))) + wrong_keys = hl.eval( + hl.rbind( + first.col_key.collect(_localize=False), + lambda first_keys: ( + hl.enumerate([mt.col_key.collect(_localize=False) for mt in datasets[1:]]).find( + lambda x: ~(x[1] == first_keys) + )[0] + ), + ) + ) if wrong_keys is not None: - raise ValueError(f"'MatrixTable.union_rows' expects all datasets to have the same 
columns. " - f"Datasets 0 and {wrong_keys + 1} have different columns (or possibly different order).") + raise ValueError( + f"'MatrixTable.union_rows' expects all datasets to have the same columns. " + f"Datasets 0 and {wrong_keys + 1} have different columns (or possibly different order)." + ) return MatrixTable(ir.MatrixUnionRows(*[d._mir for d in datasets])) - @typecheck_method(other=matrix_table_type, - row_join_type=enumeration('inner', 'outer'), - drop_right_row_fields=bool) - def union_cols(self, other: 'MatrixTable', row_join_type: str = 'inner', drop_right_row_fields: bool = True) -> 'MatrixTable': + @typecheck_method(other=matrix_table_type, row_join_type=enumeration('inner', 'outer'), drop_right_row_fields=bool) + def union_cols( + self, other: 'MatrixTable', row_join_type: str = 'inner', drop_right_row_fields: bool = True + ) -> 'MatrixTable': """Take the union of dataset columns. Warning @@ -3942,35 +4040,38 @@ def union_cols(self, other: 'MatrixTable', row_join_type: str = 'inner', drop_ri Dataset with columns from both datasets. """ if self.entry.dtype != other.entry.dtype: - raise ValueError(f'entry types differ:\n' - f' left: {self.entry.dtype}\n' - f' right: {other.entry.dtype}') + raise ValueError( + f'entry types differ:\n' f' left: {self.entry.dtype}\n' f' right: {other.entry.dtype}' + ) if self.col.dtype != other.col.dtype: - raise ValueError(f'column types differ:\n' - f' left: {self.col.dtype}\n' - f' right: {other.col.dtype}') + raise ValueError(f'column types differ:\n' f' left: {self.col.dtype}\n' f' right: {other.col.dtype}') if self.col_key.keys() != other.col_key.keys(): - raise ValueError(f'column key fields differ:\n' - f' left: {", ".join(self.col_key.keys())}\n' - f' right: {", ".join(other.col_key.keys())}') + raise ValueError( + f'column key fields differ:\n' + f' left: {", ".join(self.col_key.keys())}\n' + f' right: {", ".join(other.col_key.keys())}' + ) if list(self.row_key.dtype.values()) != list(other.row_key.dtype.values()): - raise ValueError(f'row key types differ:\n' - f' left: {", ".join(self.row_key.dtype.values())}\n' - f' right: {", ".join(other.row_key.dtype.values())}') + raise ValueError( + f'row key types differ:\n' + f' left: {", ".join(self.row_key.dtype.values())}\n' + f' right: {", ".join(other.row_key.dtype.values())}' + ) if drop_right_row_fields: other = other.select_rows() else: left_fields = set(self.row_value) other_fields = set(other.row_value) - set(other.row_key) - renames, _ = deduplicate( - other_fields, max_attempts=100, already_used=left_fields) + renames, _ = deduplicate(other_fields, max_attempts=100, already_used=left_fields) if renames: renames = dict(renames) other = other.rename(renames) - info('Table.union_cols: renamed the following fields on the right to avoid name conflicts:' - + ''.join(f'\n {repr(k)} -> {repr(v)}' for k, v in renames.items())) + info( + 'Table.union_cols: renamed the following fields on the right to avoid name conflicts:' + + ''.join(f'\n {k!r} -> {v!r}' for k, v in renames.items()) + ) return MatrixTable(ir.MatrixUnionCols(self._mir, other._mir, row_join_type)) @@ -4038,7 +4139,9 @@ def head(self, n_rows: Optional[int], n_cols: Optional[int] = None, *, n: Option mt = self if n_rows is not None: if n_rows < 0: - raise ValueError(f"MatrixTable.head: expect '{n_rows_name}' to be non-negative or None, found '{n_rows}'") + raise ValueError( + f"MatrixTable.head: expect '{n_rows_name}' to be non-negative or None, found '{n_rows}'" + ) mt = MatrixTable(ir.MatrixRowsHead(mt._mir, n_rows)) if n_cols 
is not None: if n_cols < 0: @@ -4112,7 +4215,9 @@ def tail(self, n_rows: Optional[int], n_cols: Optional[int] = None, *, n: Option mt = self if n_rows is not None: if n_rows < 0: - raise ValueError(f"MatrixTable.tail: expect '{n_rows_name}' to be non-negative or None, found '{n_rows}'") + raise ValueError( + f"MatrixTable.tail: expect '{n_rows_name}' to be non-negative or None, found '{n_rows}'" + ) mt = MatrixTable(ir.MatrixRowsTail(mt._mir, n_rows)) if n_cols is not None: if n_cols < 0: @@ -4122,7 +4227,9 @@ def tail(self, n_rows: Optional[int], n_cols: Optional[int] = None, *, n: Option @typecheck_method(parts=sequenceof(int), keep=bool) def _filter_partitions(self, parts, keep=True) -> 'MatrixTable': - return MatrixTable(ir.MatrixToMatrixApply(self._mir, {'name': 'MatrixFilterPartitions', 'parts': parts, 'keep': keep})) + return MatrixTable( + ir.MatrixToMatrixApply(self._mir, {'name': 'MatrixFilterPartitions', 'parts': parts, 'keep': keep}) + ) @classmethod @typecheck_method(table=Table) @@ -4155,12 +4262,13 @@ def from_rows_table(cls, table: Table) -> 'MatrixTable': """ col_values_uid = Env.get_uid() entries_uid = Env.get_uid() - return (table.annotate_globals(**{col_values_uid: hl.empty_array(hl.tstruct())}) - .annotate(**{entries_uid: hl.empty_array(hl.tstruct())}) - ._unlocalize_entries(entries_uid, col_values_uid, [])) + return ( + table.annotate_globals(**{col_values_uid: hl.empty_array(hl.tstruct())}) + .annotate(**{entries_uid: hl.empty_array(hl.tstruct())}) + ._unlocalize_entries(entries_uid, col_values_uid, []) + ) - @typecheck_method(p=numeric, - seed=nullable(int)) + @typecheck_method(p=numeric, seed=nullable(int)) def sample_rows(self, p: float, seed=None) -> 'MatrixTable': """Downsample the matrix table by keeping each row with probability ``p``. @@ -4193,8 +4301,7 @@ def sample_rows(self, p: float, seed=None) -> 'MatrixTable': return self.filter_rows(hl.rand_bool(p, seed)) - @typecheck_method(p=numeric, - seed=nullable(int)) + @typecheck_method(p=numeric, seed=nullable(int)) def sample_cols(self, p: float, seed=None) -> 'MatrixTable': """Downsample the matrix table by keeping each column with probability ``p``. @@ -4261,7 +4368,9 @@ def rename(self, fields: Dict[str, str]) -> 'MatrixTable': if v in seen: raise ValueError( "Cannot rename two fields to the same name: attempted to rename {} and {} both to {}".format( - repr(seen[v]), repr(k), repr(v))) + repr(seen[v]), repr(k), repr(v) + ) + ) if v in self._fields and v not in fields: raise ValueError("Cannot rename {} to {}: field already exists.".format(repr(k), repr(v))) seen[v] = k @@ -4299,19 +4408,46 @@ def distinct_by_col(self) -> 'MatrixTable': t = t.add_index(index_uid) unique_cols = t.aggregate( - hl.agg.group_by( - hl.struct(**{f: t[f] for f in col_key_fields}), hl.agg.take(t[index_uid], 1))) + hl.agg.group_by(hl.struct(**{f: t[f] for f in col_key_fields}), hl.agg.take(t[index_uid], 1)) + ) unique_cols = sorted([v[0] for _, v in unique_cols.items()]) return self.choose_cols(unique_cols) + @deprecated(version="0.2.129") @typecheck_method(separator=str) def make_table(self, separator='.') -> Table: """Make a table from a matrix table with one field per sample. - Examples + .. deprecated:: 0.2.129 + use :meth:`.localize_entries` instead because it supports more + columns + + Parameters + ---------- + separator : :class:`str` + Separator between sample IDs and entry field names. 
+ + Returns + ------- + :class:`.Table` + + See Also -------- + :meth:`.localize_entries` + Notes + ----- + The table has one row for each row of the input matrix. The + per sample and entry fields are formed by concatenating the + sample ID with the entry field name using `separator`. If the + entry field name is empty, the separator is omitted. + + The table inherits the globals from the matrix table. + + + Examples + -------- Consider a matrix table with the following schema: .. code-block:: text @@ -4356,26 +4492,6 @@ def make_table(self, separator='.') -> Table: Key: 'locus': locus 'alleles': array - - Notes - ----- - - The table has one row for each row of the input matrix. The - per sample and entry fields are formed by concatenating the - sample ID with the entry field name using `separator`. If the - entry field name is empty, the separator is omitted. - - The table inherits the globals from the matrix table. - - Parameters - ---------- - separator : :class:`str` - Separator between sample IDs and entry field names. - - Returns - ------- - :class:`.Table` - """ if not (len(self.col_key) == 1 and self.col_key[0].dtype == hl.tstr): raise ValueError("column key must be a single field of type str") @@ -4384,9 +4500,11 @@ def make_table(self, separator='.') -> Table: counts = Counter(col_keys) if counts[None] > 0: - raise ValueError("'make_table' encountered a missing column key; ensure all identifiers are defined.\n" - " To fill in key index, run:\n" - " mt = mt.key_cols_by(ck = hl.coalesce(mt.COL_KEY_NAME, 'missing_' + hl.str(hl.scan.count())))") + raise ValueError( + "'make_table' encountered a missing column key; ensure all identifiers are defined.\n" + " To fill in key index, run:\n" + " mt = mt.key_cols_by(ck = hl.coalesce(mt.COL_KEY_NAME, 'missing_' + hl.str(hl.scan.count())))" + ) duplicates = [k for k, count in counts.items() if count > 1] if duplicates: @@ -4405,9 +4523,7 @@ def fmt(f, col_key): return col_key t = t.annotate(**{ - fmt(f, col_keys[i]): t[entries_uid][i][j] - for i in range(len(col_keys)) - for j, f in enumerate(self.entry) + fmt(f, col_keys[i]): t[entries_uid][i][j] for i in range(len(col_keys)) for j, f in enumerate(self.entry) }) t = t.drop(cols_uid, entries_uid) @@ -4449,10 +4565,9 @@ def _calculate_new_partitions(self, n_partitions): """returns a set of range bounds that can be passed to write""" ht = self.rows() ht = ht.select().select_globals() - return Env.backend().execute(ir.TableToValueApply( - ht._tir, - {'name': 'TableCalculateNewPartitions', - 'nPartitions': n_partitions})) + return Env.backend().execute( + ir.TableToValueApply(ht._tir, {'name': 'TableCalculateNewPartitions', 'nPartitions': n_partitions}) + ) matrix_table_type.set(MatrixTable) diff --git a/hail/python/hail/methods/__init__.py b/hail/python/hail/methods/__init__.py index f73ce0f2307..6c8165f9ce5 100644 --- a/hail/python/hail/methods/__init__.py +++ b/hail/python/hail/methods/__init__.py @@ -1,96 +1,150 @@ -from .family_methods import (trio_matrix, mendel_errors, - transmission_disequilibrium_test, de_novo) -from .impex import (export_elasticsearch, export_gen, export_bgen, export_plink, export_vcf, - import_locus_intervals, import_bed, import_fam, grep, import_bgen, import_gen, - import_table, import_csv, import_plink, read_matrix_table, read_table, - get_vcf_metadata, import_vcf, index_bgen, import_matrix_table, import_lines, - import_avro, get_vcf_header_info, import_gvcf_interval) -from .statgen import (skat, impute_sex, genetic_relatedness_matrix, realized_relationship_matrix, 
- pca, hwe_normalized_pca, _blanczos_pca, _hwe_normalized_blanczos, - _spectral_moments, _pca_and_moments, split_multi, filter_alleles, - filter_alleles_hts, split_multi_hts, balding_nichols_model, - ld_prune, row_correlation, ld_matrix, linear_mixed_model, - linear_regression_rows, _linear_regression_rows_nd, logistic_regression_rows, - _logistic_regression_rows_nd, poisson_regression_rows, - linear_mixed_regression_rows, lambda_gc, _linear_skat, _logistic_skat) -from .qc import (VEPConfig, VEPConfigGRCh37Version85, VEPConfigGRCh38Version95, sample_qc, variant_qc, vep, - concordance, nirvana, summarize_variants, compute_charr, vep_json_typ) -from .misc import rename_duplicates, maximal_independent_set, segment_intervals, filter_intervals +from .family_methods import de_novo, mendel_errors, transmission_disequilibrium_test, trio_matrix +from .impex import ( + export_bgen, + export_elasticsearch, + export_gen, + export_plink, + export_vcf, + get_vcf_header_info, + get_vcf_metadata, + grep, + import_avro, + import_bed, + import_bgen, + import_csv, + import_fam, + import_gen, + import_gvcf_interval, + import_lines, + import_locus_intervals, + import_matrix_table, + import_plink, + import_table, + import_vcf, + index_bgen, + read_matrix_table, + read_table, +) +from .misc import filter_intervals, maximal_independent_set, rename_duplicates, segment_intervals +from .qc import ( + VEPConfig, + VEPConfigGRCh37Version85, + VEPConfigGRCh38Version95, + compute_charr, + concordance, + nirvana, + sample_qc, + summarize_variants, + variant_qc, + vep, + vep_json_typ, +) from .relatedness import identity_by_descent, king, pc_relate, simulate_random_mating +from .statgen import ( + _blanczos_pca, + _hwe_normalized_blanczos, + _linear_regression_rows_nd, + _linear_skat, + _logistic_regression_rows_nd, + _logistic_skat, + _pca_and_moments, + _spectral_moments, + balding_nichols_model, + filter_alleles, + filter_alleles_hts, + genetic_relatedness_matrix, + hwe_normalized_pca, + impute_sex, + lambda_gc, + ld_matrix, + ld_prune, + linear_mixed_model, + linear_mixed_regression_rows, + linear_regression_rows, + logistic_regression_rows, + pca, + poisson_regression_rows, + realized_relationship_matrix, + row_correlation, + skat, + split_multi, + split_multi_hts, +) -__all__ = ['trio_matrix', - 'linear_mixed_model', - 'skat', - 'identity_by_descent', - 'impute_sex', - 'linear_regression_rows', - '_linear_regression_rows_nd', - 'logistic_regression_rows', - '_logistic_regression_rows_nd', - 'poisson_regression_rows', - 'linear_mixed_regression_rows', - 'lambda_gc', - '_linear_skat', - '_logistic_skat', - 'sample_qc', - 'variant_qc', - 'genetic_relatedness_matrix', - 'realized_relationship_matrix', - 'pca', - 'hwe_normalized_pca', - '_blanczos_pca', - '_hwe_normalized_blanczos', - '_spectral_moments', - '_pca_and_moments', - 'pc_relate', - 'simulate_random_mating', - 'rename_duplicates', - 'split_multi', - 'split_multi_hts', - 'mendel_errors', - 'export_elasticsearch', - 'export_gen', - 'export_bgen', - 'export_plink', - 'export_vcf', - 'vep', - 'concordance', - 'maximal_independent_set', - 'import_locus_intervals', - 'import_bed', - 'import_fam', - 'import_matrix_table', - 'nirvana', - 'transmission_disequilibrium_test', - 'grep', - 'import_avro', - 'import_bgen', - 'import_gen', - 'import_table', - 'import_csv', - 'import_lines', - 'import_plink', - 'read_matrix_table', - 'read_table', - 'get_vcf_metadata', - 'import_vcf', - 'import_gvcf_interval', - 'index_bgen', - 'balding_nichols_model', - 'ld_prune', - 
'filter_intervals', - 'segment_intervals', - 'de_novo', - 'filter_alleles', - 'filter_alleles_hts', - 'summarize_variants', - 'compute_charr', - 'row_correlation', - 'ld_matrix', - 'king', - 'VEPConfig', - 'VEPConfigGRCh37Version85', - 'VEPConfigGRCh38Version95', - 'vep_json_typ', - 'get_vcf_header_info', - ] +__all__ = [ + 'trio_matrix', + 'linear_mixed_model', + 'skat', + 'identity_by_descent', + 'impute_sex', + 'linear_regression_rows', + '_linear_regression_rows_nd', + 'logistic_regression_rows', + '_logistic_regression_rows_nd', + 'poisson_regression_rows', + 'linear_mixed_regression_rows', + 'lambda_gc', + '_linear_skat', + '_logistic_skat', + 'sample_qc', + 'variant_qc', + 'genetic_relatedness_matrix', + 'realized_relationship_matrix', + 'pca', + 'hwe_normalized_pca', + '_blanczos_pca', + '_hwe_normalized_blanczos', + '_spectral_moments', + '_pca_and_moments', + 'pc_relate', + 'simulate_random_mating', + 'rename_duplicates', + 'split_multi', + 'split_multi_hts', + 'mendel_errors', + 'export_elasticsearch', + 'export_gen', + 'export_bgen', + 'export_plink', + 'export_vcf', + 'vep', + 'concordance', + 'maximal_independent_set', + 'import_locus_intervals', + 'import_bed', + 'import_fam', + 'import_matrix_table', + 'nirvana', + 'transmission_disequilibrium_test', + 'grep', + 'import_avro', + 'import_bgen', + 'import_gen', + 'import_table', + 'import_csv', + 'import_lines', + 'import_plink', + 'read_matrix_table', + 'read_table', + 'get_vcf_metadata', + 'import_vcf', + 'import_gvcf_interval', + 'index_bgen', + 'balding_nichols_model', + 'ld_prune', + 'filter_intervals', + 'segment_intervals', + 'de_novo', + 'filter_alleles', + 'filter_alleles_hts', + 'summarize_variants', + 'compute_charr', + 'row_correlation', + 'ld_matrix', + 'king', + 'VEPConfig', + 'VEPConfigGRCh37Version85', + 'VEPConfigGRCh38Version95', + 'vep_json_typ', + 'get_vcf_header_info', +] diff --git a/hail/python/hail/methods/family_methods.py b/hail/python/hail/methods/family_methods.py index 136e268987b..25be87a6fac 100644 --- a/hail/python/hail/methods/family_methods.py +++ b/hail/python/hail/methods/family_methods.py @@ -1,18 +1,18 @@ from typing import Tuple + import hail as hl import hail.expr.aggregators as agg +from hail.expr import expr_call, expr_float64 from hail.genetics.pedigree import Pedigree from hail.matrixtable import MatrixTable -from hail.expr import expr_call, expr_float64 from hail.table import Table -from hail.typecheck import typecheck, numeric +from hail.typecheck import numeric, typecheck from hail.utils.java import Env + from .misc import require_biallelic, require_col_key_str -@typecheck(dataset=MatrixTable, - pedigree=Pedigree, - complete_trios=bool) +@typecheck(dataset=MatrixTable, pedigree=Pedigree, complete_trios=bool) def trio_matrix(dataset, pedigree, complete_trios=False) -> MatrixTable: """Builds and returns a matrix where columns correspond to trios and entries contain genotypes for the trio. 
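A minimal sketch of calling trio_matrix, whose docstring begins above; both input paths are hypothetical:

    import hail as hl

    mt = hl.read_matrix_table('genotypes.mt')   # hypothetical dataset
    ped = hl.Pedigree.read('trios.fam')         # hypothetical PLINK-style .fam file
    trio_mt = hl.trio_matrix(mt, ped, complete_trios=True)

    # Columns now correspond to trios; entries carry proband_entry,
    # father_entry, and mother_entry structs.
    trio_mt.describe()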
@@ -74,12 +74,16 @@ def trio_matrix(dataset, pedigree, complete_trios=False) -> MatrixTable: for i, s in enumerate(samples): sample_idx[s] = i - trios = [hl.Struct( - id=sample_idx[t.s], - pat_id=None if t.pat_id is None else sample_idx[t.pat_id], - mat_id=None if t.mat_id is None else sample_idx[t.mat_id], - is_female=t.is_female, - fam_id=t.fam_id) for t in trios] + trios = [ + hl.Struct( + id=sample_idx[t.s], + pat_id=None if t.pat_id is None else sample_idx[t.pat_id], + mat_id=None if t.mat_id is None else sample_idx[t.mat_id], + is_female=t.is_female, + fam_id=t.fam_id, + ) + for t in trios + ] trios_type = hl.dtype('array') trios_sym = Env.get_uid() @@ -89,29 +93,40 @@ def trio_matrix(dataset, pedigree, complete_trios=False) -> MatrixTable: mt = mt.annotate_globals(**{trios_sym: hl.literal(trios, trios_type)}) mt = mt._localize_entries(entries_sym, cols_sym) mt = mt.annotate_globals(**{ - cols_sym: hl.map(lambda i: - hl.bind(lambda t: hl.struct(id=mt[cols_sym][t.id][k], - proband=mt[cols_sym][t.id], - father=mt[cols_sym][t.pat_id], - mother=mt[cols_sym][t.mat_id], - is_female=t.is_female, - fam_id=t.fam_id), - mt[trios_sym][i]), - hl.range(0, n_trios))}) + cols_sym: hl.map( + lambda i: hl.bind( + lambda t: hl.struct( + id=mt[cols_sym][t.id][k], + proband=mt[cols_sym][t.id], + father=mt[cols_sym][t.pat_id], + mother=mt[cols_sym][t.mat_id], + is_female=t.is_female, + fam_id=t.fam_id, + ), + mt[trios_sym][i], + ), + hl.range(0, n_trios), + ) + }) mt = mt.annotate(**{ - entries_sym: hl.map(lambda i: - hl.bind(lambda t: hl.struct(proband_entry=mt[entries_sym][t.id], - father_entry=mt[entries_sym][t.pat_id], - mother_entry=mt[entries_sym][t.mat_id]), - mt[trios_sym][i]), - hl.range(0, n_trios))}) + entries_sym: hl.map( + lambda i: hl.bind( + lambda t: hl.struct( + proband_entry=mt[entries_sym][t.id], + father_entry=mt[entries_sym][t.pat_id], + mother_entry=mt[entries_sym][t.mat_id], + ), + mt[trios_sym][i], + ), + hl.range(0, n_trios), + ) + }) mt = mt.drop(trios_sym) return mt._unlocalize_entries(entries_sym, cols_sym, ['id']) -@typecheck(call=expr_call, - pedigree=Pedigree) +@typecheck(call=expr_call, pedigree=Pedigree) def mendel_errors(call, pedigree) -> Tuple[Table, Table, Table, Table]: r"""Find Mendel errors; count per variant, individual and nuclear family. 
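mendel_errors, whose docstring begins above, returns four tables: all errors, errors per nuclear family, errors per individual, and errors per variant. A sketch of a typical call, reusing the hypothetical mt and ped from the trio_matrix sketch:

    # mt['GT'] is assumed to be the genotype call field.
    all_errors, per_fam, per_sample, per_variant = hl.mendel_errors(mt['GT'], ped)
    per_fam.show()
    variants_with_errors = per_variant.filter(per_variant.errors > 0)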
@@ -250,19 +265,20 @@ def mendel_errors(call, pedigree) -> Tuple[Table, Table, Table, Table]: """ source = call._indices.source if not isinstance(source, MatrixTable): - raise ValueError("'mendel_errors': expected 'call' to be an expression of 'MatrixTable', found {}".format( - "expression of '{}'".format(source.__class__) if source is not None else 'scalar expression')) + raise ValueError( + "'mendel_errors': expected 'call' to be an expression of 'MatrixTable', found {}".format( + "expression of '{}'".format(source.__class__) if source is not None else 'scalar expression' + ) + ) source = source.select_entries(__GT=call) dataset = require_biallelic(source, 'mendel_errors') tm = trio_matrix(dataset, pedigree, complete_trios=True) - tm = tm.select_entries(mendel_code=hl.mendel_error_code( - tm.locus, - tm.is_female, - tm.father_entry['__GT'], - tm.mother_entry['__GT'], - tm.proband_entry['__GT'] - )) + tm = tm.select_entries( + mendel_code=hl.mendel_error_code( + tm.locus, tm.is_female, tm.father_entry['__GT'], tm.mother_entry['__GT'], tm.proband_entry['__GT'] + ) + ) ck_name = next(iter(source.col_key)) tm = tm.filter_entries(hl.is_defined(tm.mendel_code)) tm = tm.rename({'id': ck_name}) @@ -271,74 +287,93 @@ def mendel_errors(call, pedigree) -> Tuple[Table, Table, Table, Table]: table1 = entries.select('fam_id', 'mendel_code') - t2 = tm.annotate_cols( - errors=hl.agg.count(), - snp_errors=hl.agg.count_where(hl.is_snp(tm.alleles[0], tm.alleles[1]))) + t2 = tm.annotate_cols(errors=hl.agg.count(), snp_errors=hl.agg.count_where(hl.is_snp(tm.alleles[0], tm.alleles[1]))) table2 = t2.key_cols_by().cols() - table2 = table2.select(pat_id=table2.father[ck_name], - mat_id=table2.mother[ck_name], - fam_id=table2.fam_id, - errors=table2.errors, - snp_errors=table2.snp_errors) + table2 = table2.select( + pat_id=table2.father[ck_name], + mat_id=table2.mother[ck_name], + fam_id=table2.fam_id, + errors=table2.errors, + snp_errors=table2.snp_errors, + ) table2 = table2.group_by('pat_id', 'mat_id').aggregate( fam_id=hl.agg.take(table2.fam_id, 1)[0], children=hl.int32(hl.agg.count()), errors=hl.agg.sum(table2.errors), - snp_errors=hl.agg.sum(table2.snp_errors)) - table2 = table2.annotate(errors=hl.or_else(table2.errors, hl.int64(0)), - snp_errors=hl.or_else(table2.snp_errors, hl.int64(0))) + snp_errors=hl.agg.sum(table2.snp_errors), + ) + table2 = table2.annotate( + errors=hl.or_else(table2.errors, hl.int64(0)), snp_errors=hl.or_else(table2.snp_errors, hl.int64(0)) + ) # in implicated, idx 0 is dad, idx 1 is mom, idx 2 is child - implicated = hl.literal([ - [0, 0, 0], # dummy - [1, 1, 1], - [1, 1, 1], - [1, 0, 1], - [0, 1, 1], - [0, 0, 1], - [1, 0, 1], - [0, 1, 1], - [0, 0, 1], - [0, 1, 1], - [0, 1, 1], - [1, 0, 1], - [1, 0, 1], - ], dtype=hl.tarray(hl.tarray(hl.tint64))) - - table3 = tm.annotate_cols(all_errors=hl.or_else(hl.agg.array_sum(implicated[tm.mendel_code]), [0, 0, 0]), - snp_errors=hl.or_else( - hl.agg.filter(hl.is_snp(tm.alleles[0], tm.alleles[1]), - hl.agg.array_sum(implicated[tm.mendel_code])), - [0, 0, 0])).key_cols_by().cols() - - table3 = table3.select(xs=[ - hl.struct(**{ck_name: table3.father[ck_name], - 'fam_id': table3.fam_id, - 'errors': table3.all_errors[0], - 'snp_errors': table3.snp_errors[0]}), - hl.struct(**{ck_name: table3.mother[ck_name], - 'fam_id': table3.fam_id, - 'errors': table3.all_errors[1], - 'snp_errors': table3.snp_errors[1]}), - hl.struct(**{ck_name: table3.proband[ck_name], - 'fam_id': table3.fam_id, - 'errors': table3.all_errors[2], - 'snp_errors': 
table3.snp_errors[2]}), - ]) + implicated = hl.literal( + [ + [0, 0, 0], # dummy + [1, 1, 1], + [1, 1, 1], + [1, 0, 1], + [0, 1, 1], + [0, 0, 1], + [1, 0, 1], + [0, 1, 1], + [0, 0, 1], + [0, 1, 1], + [0, 1, 1], + [1, 0, 1], + [1, 0, 1], + ], + dtype=hl.tarray(hl.tarray(hl.tint64)), + ) + + table3 = ( + tm.annotate_cols( + all_errors=hl.or_else(hl.agg.array_sum(implicated[tm.mendel_code]), [0, 0, 0]), + snp_errors=hl.or_else( + hl.agg.filter(hl.is_snp(tm.alleles[0], tm.alleles[1]), hl.agg.array_sum(implicated[tm.mendel_code])), + [0, 0, 0], + ), + ) + .key_cols_by() + .cols() + ) + + table3 = table3.select( + xs=[ + hl.struct(**{ + ck_name: table3.father[ck_name], + 'fam_id': table3.fam_id, + 'errors': table3.all_errors[0], + 'snp_errors': table3.snp_errors[0], + }), + hl.struct(**{ + ck_name: table3.mother[ck_name], + 'fam_id': table3.fam_id, + 'errors': table3.all_errors[1], + 'snp_errors': table3.snp_errors[1], + }), + hl.struct(**{ + ck_name: table3.proband[ck_name], + 'fam_id': table3.fam_id, + 'errors': table3.all_errors[2], + 'snp_errors': table3.snp_errors[2], + }), + ] + ) table3 = table3.explode('xs') table3 = table3.select(**table3.xs) - table3 = (table3.group_by(ck_name, 'fam_id') - .aggregate(errors=hl.agg.sum(table3.errors), - snp_errors=hl.agg.sum(table3.snp_errors)) - .key_by(ck_name)) + table3 = ( + table3.group_by(ck_name, 'fam_id') + .aggregate(errors=hl.agg.sum(table3.errors), snp_errors=hl.agg.sum(table3.snp_errors)) + .key_by(ck_name) + ) table4 = tm.select_rows(errors=hl.agg.count_where(hl.is_defined(tm.mendel_code))).rows() return table1, table2, table3, table4 -@typecheck(dataset=MatrixTable, - pedigree=Pedigree) +@typecheck(dataset=MatrixTable, pedigree=Pedigree) def transmission_disequilibrium_test(dataset, pedigree) -> Table: r"""Performs the transmission disequilibrium test on trios. @@ -469,21 +504,23 @@ def transmission_disequilibrium_test(dataset, pedigree) -> Table: hemi_x = 1 # kid, dad, mom, copy, t, u - config_counts = [(hom_ref, het, het, auto, 0, 2), - (hom_ref, hom_ref, het, auto, 0, 1), - (hom_ref, het, hom_ref, auto, 0, 1), - ( het, het, het, auto, 1, 1), - ( het, hom_ref, het, auto, 1, 0), - ( het, het, hom_ref, auto, 1, 0), - ( het, hom_var, het, auto, 0, 1), - ( het, het, hom_var, auto, 0, 1), - (hom_var, het, het, auto, 2, 0), - (hom_var, het, hom_var, auto, 1, 0), - (hom_var, hom_var, het, auto, 1, 0), - (hom_ref, hom_ref, het, hemi_x, 0, 1), - (hom_ref, hom_var, het, hemi_x, 0, 1), - (hom_var, hom_ref, het, hemi_x, 1, 0), - (hom_var, hom_var, het, hemi_x, 1, 0)] + config_counts = [ + (hom_ref, het, het, auto, 0, 2), + (hom_ref, hom_ref, het, auto, 0, 1), + (hom_ref, het, hom_ref, auto, 0, 1), + (het, het, het, auto, 1, 1), + (het, hom_ref, het, auto, 1, 0), + (het, het, hom_ref, auto, 1, 0), + (het, hom_var, het, auto, 0, 1), + (het, het, hom_var, auto, 0, 1), + (hom_var, het, het, auto, 2, 0), + (hom_var, het, hom_var, auto, 1, 0), + (hom_var, hom_var, het, auto, 1, 0), + (hom_ref, hom_ref, het, hemi_x, 0, 1), + (hom_ref, hom_var, het, hemi_x, 0, 1), + (hom_var, hom_ref, het, hemi_x, 1, 0), + (hom_var, hom_var, het, hemi_x, 1, 0), + ] count_map = hl.literal({(c[0], c[1], c[2], c[3]): [c[4], c[5]] for c in config_counts}) @@ -492,15 +529,16 @@ def transmission_disequilibrium_test(dataset, pedigree) -> Table: # this filter removes mendel error of het father in x_nonpar. 
It also avoids # building and looking up config in common case that neither parent is het father_is_het = tri.father_entry.GT.is_het() - parent_is_valid_het = ((father_is_het & tri.auto_or_x_par) - | (tri.mother_entry.GT.is_het() & ~father_is_het)) + parent_is_valid_het = (father_is_het & tri.auto_or_x_par) | (tri.mother_entry.GT.is_het() & ~father_is_het) copy_state = hl.if_else(tri.auto_or_x_par | tri.is_female, 2, 1) - config = (tri.proband_entry.GT.n_alt_alleles(), - tri.father_entry.GT.n_alt_alleles(), - tri.mother_entry.GT.n_alt_alleles(), - copy_state) + config = ( + tri.proband_entry.GT.n_alt_alleles(), + tri.father_entry.GT.n_alt_alleles(), + tri.mother_entry.GT.n_alt_alleles(), + copy_state, + ) tri = tri.annotate_rows(counts=agg.filter(parent_is_valid_het, agg.array_sum(count_map.get(config)))) @@ -512,25 +550,29 @@ def transmission_disequilibrium_test(dataset, pedigree) -> Table: return tab.cache() -@typecheck(mt=MatrixTable, - pedigree=Pedigree, - pop_frequency_prior=expr_float64, - min_gq=int, - min_p=numeric, - max_parent_ab=numeric, - min_child_ab=numeric, - min_dp_ratio=numeric, - ignore_in_sample_allele_frequency=bool) -def de_novo(mt: MatrixTable, - pedigree: Pedigree, - pop_frequency_prior, - *, - min_gq: int = 20, - min_p: float = 0.05, - max_parent_ab: float = 0.05, - min_child_ab: float = 0.20, - min_dp_ratio: float = 0.10, - ignore_in_sample_allele_frequency: bool = False) -> Table: +@typecheck( + mt=MatrixTable, + pedigree=Pedigree, + pop_frequency_prior=expr_float64, + min_gq=int, + min_p=numeric, + max_parent_ab=numeric, + min_child_ab=numeric, + min_dp_ratio=numeric, + ignore_in_sample_allele_frequency=bool, +) +def de_novo( + mt: MatrixTable, + pedigree: Pedigree, + pop_frequency_prior, + *, + min_gq: int = 20, + min_p: float = 0.05, + max_parent_ab: float = 0.05, + min_child_ab: float = 0.20, + min_dp_ratio: float = 0.10, + ignore_in_sample_allele_frequency: bool = False, +) -> Table: r"""Call putative *de novo* events from trio data. .. include:: ../_templates/req_tstring.rst @@ -749,30 +791,35 @@ def de_novo(mt: MatrixTable, required_entry_fields = {'GT', 'AD', 'DP', 'GQ', 'PL'} missing_fields = required_entry_fields - set(mt.entry) if missing_fields: - raise ValueError(f"'de_novo': expected 'MatrixTable' to have at least {required_entry_fields}, " - f"missing {missing_fields}") + raise ValueError( + f"'de_novo': expected 'MatrixTable' to have at least {required_entry_fields}, " f"missing {missing_fields}" + ) - pop_frequency_prior = hl.case() \ - .when((pop_frequency_prior >= 0) & (pop_frequency_prior <= 1), pop_frequency_prior) \ + pop_frequency_prior = ( + hl.case() + .when((pop_frequency_prior >= 0) & (pop_frequency_prior <= 1), pop_frequency_prior) .or_error(hl.str("de_novo: expect 0 <= pop_frequency_prior <= 1, found " + hl.str(pop_frequency_prior))) + ) if ignore_in_sample_allele_frequency: # this mode is used when families larger than a single trio are observed, in which # an allele might be de novo in a parent and transmitted to a child in the dataset. # The original model does not handle this case correctly, and so this experimental # mode can be used to treat each trio as if it were the only one in the dataset. 
- mt = mt.annotate_rows(__prior=pop_frequency_prior, - __alt_alleles=hl.int64(1), - __site_freq=hl.max(pop_frequency_prior, MIN_POP_PRIOR)) + mt = mt.annotate_rows( + __prior=pop_frequency_prior, + __alt_alleles=hl.int64(1), + __site_freq=hl.max(pop_frequency_prior, MIN_POP_PRIOR), + ) else: n_alt_alleles = hl.agg.sum(mt.GT.n_alt_alleles()) total_alleles = 2 * hl.agg.sum(hl.is_defined(mt.GT)) # subtract 1 from __alt_alleles to correct for the observed genotype - mt = mt.annotate_rows(__prior=pop_frequency_prior, - __alt_alleles=n_alt_alleles, - __site_freq=hl.max((n_alt_alleles - 1) / total_alleles, - pop_frequency_prior, - MIN_POP_PRIOR)) + mt = mt.annotate_rows( + __prior=pop_frequency_prior, + __alt_alleles=n_alt_alleles, + __site_freq=hl.max((n_alt_alleles - 1) / total_alleles, pop_frequency_prior, MIN_POP_PRIOR), + ) mt = require_biallelic(mt, 'de_novo') @@ -817,30 +864,42 @@ def solve(p_de_novo): return ( hl.case() .when(kid.GQ < min_gq, failure) - .when((kid.DP / (dad.DP + mom.DP) < min_dp_ratio) - | ~(kid_ad_ratio >= min_child_ab), failure) + .when((kid.DP / (dad.DP + mom.DP) < min_dp_ratio) | ~(kid_ad_ratio >= min_child_ab), failure) .when((hl.sum(mom.AD) == 0) | (hl.sum(dad.AD) == 0), failure) - .when((mom.AD[1] / hl.sum(mom.AD) > max_parent_ab) - | (dad.AD[1] / hl.sum(dad.AD) > max_parent_ab), failure) + .when( + (mom.AD[1] / hl.sum(mom.AD) > max_parent_ab) | (dad.AD[1] / hl.sum(dad.AD) > max_parent_ab), failure + ) .when(p_de_novo < min_p, failure) - .when(~is_snp, hl.case() - .when((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1), - hl.struct(p_de_novo=p_de_novo, confidence='HIGH')) - .when((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles <= 5), - hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM')) - .when(kid_ad_ratio > 0.2, - hl.struct(p_de_novo=p_de_novo, confidence='LOW')) - .or_missing()) - .default(hl.case() - .when(((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (dp_ratio > 0.2)) - | ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1)) - | ((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles < 10) & (kid.DP > 10)), - hl.struct(p_de_novo=p_de_novo, confidence='HIGH')) - .when((p_de_novo > 0.5) & ((kid_ad_ratio > 0.3) | (n_alt_alleles == 1)), - hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM')) - .when(kid_ad_ratio > 0.2, - hl.struct(p_de_novo=p_de_novo, confidence='LOW')) - .or_missing())) + .when( + ~is_snp, + hl.case() + .when( + (p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1), + hl.struct(p_de_novo=p_de_novo, confidence='HIGH'), + ) + .when( + (p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles <= 5), + hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'), + ) + .when(kid_ad_ratio > 0.2, hl.struct(p_de_novo=p_de_novo, confidence='LOW')) + .or_missing(), + ) + .default( + hl.case() + .when( + ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (dp_ratio > 0.2)) + | ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1)) + | ((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles < 10) & (kid.DP > 10)), + hl.struct(p_de_novo=p_de_novo, confidence='HIGH'), + ) + .when( + (p_de_novo > 0.5) & ((kid_ad_ratio > 0.3) | (n_alt_alleles == 1)), + hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'), + ) + .when(kid_ad_ratio > 0.2, hl.struct(p_de_novo=p_de_novo, confidence='LOW')) + .or_missing() + ) + ) return hl.bind(solve, p_de_novo) @@ -854,29 +913,40 @@ def solve(p_de_novo): return ( hl.case() .when(kid.GQ < min_gq, failure) - .when((kid.DP / (parent.DP) < min_dp_ratio) - | (kid_ad_ratio < 
min_child_ab), failure) + .when((kid.DP / (parent.DP) < min_dp_ratio) | (kid_ad_ratio < min_child_ab), failure) .when((hl.sum(parent.AD) == 0), failure) .when(parent.AD[1] / hl.sum(parent.AD) > max_parent_ab, failure) .when(p_de_novo < min_p, failure) - .when(~is_snp, hl.case() - .when((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1), - hl.struct(p_de_novo=p_de_novo, confidence='HIGH')) - .when((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles <= 5), - hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM')) - .when(kid_ad_ratio > 0.3, - hl.struct(p_de_novo=p_de_novo, confidence='LOW')) - .or_missing()) - .default(hl.case() - .when(((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (dp_ratio > 0.2)) - | ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1)) - | ((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles < 10) & (kid.DP > 10)), - hl.struct(p_de_novo=p_de_novo, confidence='HIGH')) - .when((p_de_novo > 0.5) & ((kid_ad_ratio > 0.3) | (n_alt_alleles == 1)), - hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM')) - .when(kid_ad_ratio > 0.2, - hl.struct(p_de_novo=p_de_novo, confidence='LOW')) - .or_missing())) + .when( + ~is_snp, + hl.case() + .when( + (p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1), + hl.struct(p_de_novo=p_de_novo, confidence='HIGH'), + ) + .when( + (p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles <= 5), + hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'), + ) + .when(kid_ad_ratio > 0.3, hl.struct(p_de_novo=p_de_novo, confidence='LOW')) + .or_missing(), + ) + .default( + hl.case() + .when( + ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (dp_ratio > 0.2)) + | ((p_de_novo > 0.99) & (kid_ad_ratio > 0.3) & (n_alt_alleles == 1)) + | ((p_de_novo > 0.5) & (kid_ad_ratio > 0.3) & (n_alt_alleles < 10) & (kid.DP > 10)), + hl.struct(p_de_novo=p_de_novo, confidence='HIGH'), + ) + .when( + (p_de_novo > 0.5) & ((kid_ad_ratio > 0.3) | (n_alt_alleles == 1)), + hl.struct(p_de_novo=p_de_novo, confidence='MEDIUM'), + ) + .when(kid_ad_ratio > 0.2, hl.struct(p_de_novo=p_de_novo, confidence='LOW')) + .or_missing() + ) + ) return hl.bind(solve, p_de_novo) @@ -886,18 +956,20 @@ def solve(p_de_novo): .when(autosomal, hl.bind(call_auto, kid_pp, dad_pp, mom_pp, kid_ad_ratio)) .when(hemi_x | hemi_mt, hl.bind(call_hemi, kid_pp, mom, mom_pp, kid_ad_ratio)) .when(hemi_y, hl.bind(call_hemi, kid_pp, dad, dad_pp, kid_ad_ratio)) - .or_missing()) + .or_missing() + ) tm = tm.annotate_entries(__call=de_novo_call) tm = tm.filter_entries(hl.is_defined(tm.__call)) entries = tm.entries() - return (entries.select('__site_freq', - 'proband', - 'father', - 'mother', - 'proband_entry', - 'father_entry', - 'mother_entry', - 'is_female', - **entries.__call) - .rename({'__site_freq': 'prior'})) + return entries.select( + '__site_freq', + 'proband', + 'father', + 'mother', + 'proband_entry', + 'father_entry', + 'mother_entry', + 'is_female', + **entries.__call, + ).rename({'__site_freq': 'prior'}) diff --git a/hail/python/hail/methods/impex.py b/hail/python/hail/methods/impex.py index 94a59585464..72541b693e5 100644 --- a/hail/python/hail/methods/impex.py +++ b/hail/python/hail/methods/impex.py @@ -8,40 +8,63 @@ import hail as hl from hail import ir -from hail.expr import StructExpression, LocusExpression, \ - expr_array, expr_float64, expr_str, expr_numeric, expr_call, expr_bool, \ - expr_int32, to_expr, analyze -from hail.expr.types import hail_type, tarray, tfloat64, tstr, tint32, tstruct, \ - tcall, tbool, tint64, tfloat32 +from hail.expr import ( + 
LocusExpression, + StructExpression, + analyze, + expr_array, + expr_bool, + expr_call, + expr_float64, + expr_int32, + expr_numeric, + expr_str, + to_expr, +) +from hail.expr.matrix_type import tmatrix +from hail.expr.table_type import ttable +from hail.expr.types import hail_type, tarray, tbool, tcall, tfloat32, tfloat64, tint32, tint64, tstr, tstruct from hail.genetics.reference_genome import reference_genome_type from hail.ir.utils import parse_type from hail.matrixtable import MatrixTable -from hail.methods.misc import require_biallelic, require_row_key_variant, require_col_key_str +from hail.methods.misc import require_biallelic, require_col_key_str, require_row_key_variant from hail.table import Table -from hail.typecheck import typecheck, nullable, oneof, dictof, anytype, \ - sequenceof, enumeration, sized_tupleof, numeric, table_key_type, char +from hail.typecheck import ( + anytype, + char, + dictof, + enumeration, + nullable, + numeric, + oneof, + sequenceof, + sized_tupleof, + table_key_type, + typecheck, +) from hail.utils import new_temp_file from hail.utils.deduplicate import deduplicate -from hail.utils.java import Env, FatalError, jindexed_seq_args, warning -from hail.utils.java import info -from hail.utils.misc import wrap_to_list, plural -from .import_lines_helpers import split_lines, should_remove_line +from hail.utils.java import Env, FatalError, info, jindexed_seq_args, warning +from hail.utils.misc import plural, wrap_to_list +from .import_lines_helpers import should_remove_line, split_lines -def locus_interval_expr(contig, start, end, includes_start, includes_end, - reference_genome, skip_invalid_intervals): + +def locus_interval_expr(contig, start, end, includes_start, includes_end, reference_genome, skip_invalid_intervals): includes_start = hl.bool(includes_start) includes_end = hl.bool(includes_end) if reference_genome: - return hl.locus_interval(contig, start, end, includes_start, - includes_end, reference_genome, - skip_invalid_intervals) + return hl.locus_interval( + contig, start, end, includes_start, includes_end, reference_genome, skip_invalid_intervals + ) else: - return hl.interval(hl.struct(contig=contig, position=start), - hl.struct(contig=contig, position=end), - includes_start, - includes_end) + return hl.interval( + hl.struct(contig=contig, position=start), + hl.struct(contig=contig, position=end), + includes_start, + includes_end, + ) def expr_or_else(expr, default, f=lambda x: x): @@ -51,17 +74,18 @@ def expr_or_else(expr, default, f=lambda x: x): return to_expr(default) -@typecheck(dataset=MatrixTable, - output=str, - precision=int, - gp=nullable(expr_array(expr_float64)), - id1=nullable(expr_str), - id2=nullable(expr_str), - missing=nullable(expr_numeric), - varid=nullable(expr_str), - rsid=nullable(expr_str)) -def export_gen(dataset, output, precision=4, gp=None, id1=None, id2=None, - missing=None, varid=None, rsid=None): +@typecheck( + dataset=MatrixTable, + output=str, + precision=int, + gp=nullable(expr_array(expr_float64)), + id1=nullable(expr_str), + id2=nullable(expr_str), + missing=nullable(expr_numeric), + varid=nullable(expr_str), + rsid=nullable(expr_str), +) +def export_gen(dataset, output, precision=4, gp=None, id1=None, id2=None, missing=None, varid=None, rsid=None): """Export a :class:`.MatrixTable` as GEN and SAMPLE files. .. 
include:: ../_templates/req_tvariant.rst @@ -127,9 +151,11 @@ def export_gen(dataset, output, precision=4, gp=None, id1=None, id2=None, if 'GP' in dataset.entry and dataset.GP.dtype == tarray(tfloat64): entry_exprs = {'GP': dataset.GP} else: - raise ValueError('exporting to GEN requires a GP (genotype probability) array field in the entry' - '\n of the matrix table. If you only have hard calls (GT), BGEN is probably not the' - '\n right format.') + raise ValueError( + 'exporting to GEN requires a GP (genotype probability) array field in the entry' + '\n of the matrix table. If you only have hard calls (GT), BGEN is probably not the' + '\n right format.' + ) else: entry_exprs = {'GP': gp} @@ -157,31 +183,34 @@ def export_gen(dataset, output, precision=4, gp=None, id1=None, id2=None, locus = dataset.locus a = dataset.alleles - gen_exprs = {'varid': expr_or_else(varid, hl.delimit([locus.contig, hl.str(locus.position), a[0], a[1]], ':')), - 'rsid': expr_or_else(rsid, ".")} + gen_exprs = { + 'varid': expr_or_else(varid, hl.delimit([locus.contig, hl.str(locus.position), a[0], a[1]], ':')), + 'rsid': expr_or_else(rsid, "."), + } - for exprs, axis in [(sample_exprs, dataset._col_indices), - (gen_exprs, dataset._row_indices), - (entry_exprs, dataset._entry_indices)]: + for exprs, axis in [ + (sample_exprs, dataset._col_indices), + (gen_exprs, dataset._row_indices), + (entry_exprs, dataset._entry_indices), + ]: for name, expr in exprs.items(): analyze('export_gen/{}'.format(name), expr, axis) - dataset = dataset._select_all(col_exprs=sample_exprs, - col_key=[], - row_exprs=gen_exprs, - entry_exprs=entry_exprs) + dataset = dataset._select_all(col_exprs=sample_exprs, col_key=[], row_exprs=gen_exprs, entry_exprs=entry_exprs) writer = ir.MatrixGENWriter(output, precision) Env.backend().execute(ir.MatrixWrite(dataset._mir, writer)) -@typecheck(mt=MatrixTable, - output=str, - gp=nullable(expr_array(expr_float64)), - varid=nullable(expr_str), - rsid=nullable(expr_str), - parallel=nullable(ir.ExportType.checker), - compression_codec=enumeration('zlib', 'zstd')) +@typecheck( + mt=MatrixTable, + output=str, + gp=nullable(expr_array(expr_float64)), + varid=nullable(expr_str), + rsid=nullable(expr_str), + parallel=nullable(ir.ExportType.checker), + compression_codec=enumeration('zlib', 'zstd'), +) def export_bgen(mt, output, gp=None, varid=None, rsid=None, parallel=None, compression_codec='zlib'): """Export MatrixTable as :class:`.MatrixTable` as BGEN 1.2 file with 8 bits of per probability. Also writes SAMPLE file. @@ -244,9 +273,11 @@ def export_bgen(mt, output, gp=None, varid=None, rsid=None, parallel=None, compr if 'GP' in mt.entry and mt.GP.dtype == tarray(tfloat64): entry_exprs = {'GP': mt.GP} else: - raise ValueError('exporting to BGEN requires a GP (genotype probability) array field in the entry' - '\n of the matrix table. If you only have hard calls (GT), BGEN is probably not the' - '\n right format.') + raise ValueError( + 'exporting to BGEN requires a GP (genotype probability) array field in the entry' + '\n of the matrix table. If you only have hard calls (GT), BGEN is probably not the' + '\n right format.' 
+ ) else: entry_exprs = {'GP': gp} @@ -262,38 +293,46 @@ def export_bgen(mt, output, gp=None, varid=None, rsid=None, parallel=None, compr locus = mt.locus a = mt.alleles - gen_exprs = {'varid': expr_or_else(varid, hl.delimit([locus.contig, hl.str(locus.position), a[0], a[1]], ':')), - 'rsid': expr_or_else(rsid, ".")} + gen_exprs = { + 'varid': expr_or_else(varid, hl.delimit([locus.contig, hl.str(locus.position), a[0], a[1]], ':')), + 'rsid': expr_or_else(rsid, "."), + } - for exprs, axis in [(gen_exprs, mt._row_indices), - (entry_exprs, mt._entry_indices)]: + for exprs, axis in [(gen_exprs, mt._row_indices), (entry_exprs, mt._entry_indices)]: for name, expr in exprs.items(): analyze('export_bgen/{}'.format(name), expr, axis) - mt = mt._select_all(col_exprs={}, - row_exprs=gen_exprs, - entry_exprs=entry_exprs) - - Env.backend().execute(ir.MatrixWrite(mt._mir, ir.MatrixBGENWriter( - output, - parallel, - compression_codec))) - - -@typecheck(dataset=MatrixTable, - output=str, - call=nullable(expr_call), - fam_id=nullable(expr_str), - ind_id=nullable(expr_str), - pat_id=nullable(expr_str), - mat_id=nullable(expr_str), - is_female=nullable(expr_bool), - pheno=oneof(nullable(expr_bool), nullable(expr_numeric)), - varid=nullable(expr_str), - cm_position=nullable(expr_float64)) -def export_plink(dataset, output, call=None, fam_id=None, ind_id=None, pat_id=None, - mat_id=None, is_female=None, pheno=None, varid=None, - cm_position=None): + mt = mt._select_all(col_exprs={}, row_exprs=gen_exprs, entry_exprs=entry_exprs) + + Env.backend().execute(ir.MatrixWrite(mt._mir, ir.MatrixBGENWriter(output, parallel, compression_codec))) + + +@typecheck( + dataset=MatrixTable, + output=str, + call=nullable(expr_call), + fam_id=nullable(expr_str), + ind_id=nullable(expr_str), + pat_id=nullable(expr_str), + mat_id=nullable(expr_str), + is_female=nullable(expr_bool), + pheno=oneof(nullable(expr_bool), nullable(expr_numeric)), + varid=nullable(expr_str), + cm_position=nullable(expr_float64), +) +def export_plink( + dataset, + output, + call=None, + fam_id=None, + ind_id=None, + pat_id=None, + mat_id=None, + is_female=None, + pheno=None, + varid=None, + cm_position=None, +): """Export a :class:`.MatrixTable` as `PLINK2 `__ BED, BIM and FAM files. 
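Editor's note: the reformatted error messages above spell out that `export_gen` and `export_bgen` require a `GP` (genotype probability) array entry field. A minimal usage sketch, with hypothetical paths and assuming the dataset already carries `GP`:

```python
import hail as hl

# Hypothetical dataset; any matrix table with a GP entry field (array<float64>) works here.
mt = hl.read_matrix_table('data/example.mt')

# Writes example.bgen plus a SAMPLE file; 'zstd' is the alternative codec allowed above.
hl.export_bgen(mt, 'output/example', gp=mt.GP, compression_codec='zstd')
```

If only hard calls (`GT`) are present, BGEN/GEN is, as the error text says, probably not the right export format.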
@@ -378,31 +417,32 @@ def export_plink(dataset, output, call=None, fam_id=None, ind_id=None, pat_id=No else: entry_exprs = {'GT': call} - fam_exprs = {'fam_id': expr_or_else(fam_id, '0'), - 'ind_id': hl.or_else(ind_id, '0'), - 'pat_id': expr_or_else(pat_id, '0'), - 'mat_id': expr_or_else(mat_id, '0'), - 'is_female': expr_or_else(is_female, '0', - lambda x: hl.if_else(x, '2', '1')), - 'pheno': expr_or_else(pheno, 'NA', - lambda x: hl.if_else(x, '2', '1') if x.dtype == tbool else hl.str(x))} + fam_exprs = { + 'fam_id': expr_or_else(fam_id, '0'), + 'ind_id': hl.or_else(ind_id, '0'), + 'pat_id': expr_or_else(pat_id, '0'), + 'mat_id': expr_or_else(mat_id, '0'), + 'is_female': expr_or_else(is_female, '0', lambda x: hl.if_else(x, '2', '1')), + 'pheno': expr_or_else(pheno, 'NA', lambda x: hl.if_else(x, '2', '1') if x.dtype == tbool else hl.str(x)), + } locus = dataset.locus a = dataset.alleles - bim_exprs = {'varid': expr_or_else(varid, hl.delimit([locus.contig, hl.str(locus.position), a[0], a[1]], ':')), - 'cm_position': expr_or_else(cm_position, 0.0)} + bim_exprs = { + 'varid': expr_or_else(varid, hl.delimit([locus.contig, hl.str(locus.position), a[0], a[1]], ':')), + 'cm_position': expr_or_else(cm_position, 0.0), + } - for exprs, axis in [(fam_exprs, dataset._col_indices), - (bim_exprs, dataset._row_indices), - (entry_exprs, dataset._entry_indices)]: + for exprs, axis in [ + (fam_exprs, dataset._col_indices), + (bim_exprs, dataset._row_indices), + (entry_exprs, dataset._entry_indices), + ]: for name, expr in exprs.items(): analyze('export_plink/{}'.format(name), expr, axis) - dataset = dataset._select_all(col_exprs=fam_exprs, - col_key=[], - row_exprs=bim_exprs, - entry_exprs=entry_exprs) + dataset = dataset._select_all(col_exprs=fam_exprs, col_key=[], row_exprs=bim_exprs, entry_exprs=entry_exprs) # check FAM ids for white space t_cols = dataset.cols() @@ -422,12 +462,14 @@ def export_plink(dataset, output, call=None, fam_id=None, ind_id=None, pat_id=No Env.backend().execute(ir.MatrixWrite(dataset._mir, writer)) -@typecheck(dataset=oneof(MatrixTable, Table), - output=str, - append_to_header=nullable(str), - parallel=nullable(ir.ExportType.checker), - metadata=nullable(dictof(str, dictof(str, dictof(str, str)))), - tabix=bool) +@typecheck( + dataset=oneof(MatrixTable, Table), + output=str, + append_to_header=nullable(str), + parallel=nullable(ir.ExportType.checker), + metadata=nullable(dictof(str, dictof(str, dictof(str, str)))), + tabix=bool, +) def export_vcf(dataset, output, append_to_header=None, parallel=None, metadata=None, *, tabix=False): """Export a :class:`.MatrixTable` or :class:`.Table` as a VCF file. @@ -543,10 +585,12 @@ def export_vcf(dataset, output, append_to_header=None, parallel=None, metadata=N _, ext = os.path.splitext(output) if ext == '.gz': - warning('VCF export with standard gzip compression requested. This is almost *never* desired and will ' - 'cause issues with other tools that consume VCF files. The compression format used for VCF ' - 'files is traditionally *block* gzip compression. To use block gzip compression with hail VCF ' - 'export, use a path ending in `.bgz`.') + warning( + 'VCF export with standard gzip compression requested. This is almost *never* desired and will ' + 'cause issues with other tools that consume VCF files. The compression format used for VCF ' + 'files is traditionally *block* gzip compression. To use block gzip compression with hail VCF ' + 'export, use a path ending in `.bgz`.' 
+ ) if isinstance(dataset, Table): mt = MatrixTable.from_rows_table(dataset) @@ -556,15 +600,19 @@ def export_vcf(dataset, output, append_to_header=None, parallel=None, metadata=N require_row_key_variant(dataset, 'export_vcf') if 'filters' in dataset.row and dataset.filters.dtype != hl.tset(hl.tstr): - raise ValueError(f"'export_vcf': expect the 'filters' field to be set, found {dataset.filters.dtype}" - f"\n Either transform this field to set to export as VCF FILTERS field, or drop it from the dataset.") + raise ValueError( + f"'export_vcf': expect the 'filters' field to be set, found {dataset.filters.dtype}" + f"\n Either transform this field to set to export as VCF FILTERS field, or drop it from the dataset." + ) info_fields = list(dataset.info) if "info" in dataset.row else [] invalid_info_fields = [f for f in info_fields if not re.fullmatch(r"^([A-Za-z_][0-9A-Za-z_.]*|1000G)", f)] if invalid_info_fields: invalid_info_str = ''.join(f'\n {f!r}' for f in invalid_info_fields) warning( - 'export_vcf: the following info field names are invalid in VCF 4.3 and may not work with some tools: ' + invalid_info_str) + 'export_vcf: the following info field names are invalid in VCF 4.3 and may not work with some tools: ' + + invalid_info_str + ) row_fields_used = {'rsid', 'info', 'filters', 'qual'} @@ -584,24 +632,20 @@ def export_vcf(dataset, output, append_to_header=None, parallel=None, metadata=N parallel = ir.ExportType.default(parallel) - writer = ir.MatrixVCFWriter(output, - append_to_header, - parallel, - metadata, - tabix) + writer = ir.MatrixVCFWriter(output, append_to_header, parallel, metadata, tabix) Env.backend().execute(ir.MatrixWrite(dataset._mir, writer)) -@typecheck(path=str, - reference_genome=nullable(reference_genome_type), - skip_invalid_intervals=bool, - contig_recoding=nullable(dictof(str, str)), - kwargs=anytype) -def import_locus_intervals(path, - reference_genome='default', - skip_invalid_intervals=False, - contig_recoding=None, - **kwargs) -> Table: +@typecheck( + path=str, + reference_genome=nullable(reference_genome_type), + skip_invalid_intervals=bool, + contig_recoding=nullable(dictof(str, str)), + kwargs=anytype, +) +def import_locus_intervals( + path, reference_genome='default', skip_invalid_intervals=False, contig_recoding=None, **kwargs +) -> Table: """Import a locus interval list as a :class:`.Table`. 
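Editor's note: to make the block-gzip warning reformatted above concrete, a small sketch (paths hypothetical) of exporting VCF with a `.bgz` extension so downstream tools see block-gzipped output:

```python
import hail as hl

# Hypothetical dataset keyed by locus/alleles.
mt = hl.read_matrix_table('data/example.mt')

# A '.bgz' suffix selects block gzip rather than plain gzip, avoiding the warning above;
# tabix=True additionally writes a tabix index next to the output.
hl.export_vcf(mt, 'output/example.vcf.bgz', tabix=True)
```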
Examples @@ -683,60 +727,68 @@ def recode_contig(x): return x return contig_recoding.get(x, x) - t = import_table(path, comment="@", impute=False, no_header=True, - types={'f0': tstr, 'f1': tint32, 'f2': tint32, - 'f3': tstr, 'f4': tstr}, - **kwargs) + t = import_table( + path, + comment="@", + impute=False, + no_header=True, + types={'f0': tstr, 'f1': tint32, 'f2': tint32, 'f3': tstr, 'f4': tstr}, + **kwargs, + ) if t.row.dtype == tstruct(f0=tstr): if reference_genome: - t = t.select(interval=hl.parse_locus_interval(t['f0'], - reference_genome)) + t = t.select(interval=hl.parse_locus_interval(t['f0'], reference_genome)) else: interval_regex = r"([^:]*):(\d+)\-(\d+)" def checked_match_interval_expr(match): - return hl.or_missing(hl.len(match) == 3, - locus_interval_expr(recode_contig(match[0]), - hl.int32(match[1]), - hl.int32(match[2]), - True, - True, - reference_genome, - skip_invalid_intervals)) - - expr = ( - hl.bind(t['f0'].first_match_in(interval_regex), - lambda match: hl.if_else(hl.bool(skip_invalid_intervals), - checked_match_interval_expr(match), - locus_interval_expr(recode_contig(match[0]), - hl.int32(match[1]), - hl.int32(match[2]), - True, - True, - reference_genome, - skip_invalid_intervals)))) + return hl.or_missing( + hl.len(match) == 3, + locus_interval_expr( + recode_contig(match[0]), + hl.int32(match[1]), + hl.int32(match[2]), + True, + True, + reference_genome, + skip_invalid_intervals, + ), + ) + + expr = hl.bind( + t['f0'].first_match_in(interval_regex), + lambda match: hl.if_else( + hl.bool(skip_invalid_intervals), + checked_match_interval_expr(match), + locus_interval_expr( + recode_contig(match[0]), + hl.int32(match[1]), + hl.int32(match[2]), + True, + True, + reference_genome, + skip_invalid_intervals, + ), + ), + ) t = t.select(interval=expr) elif t.row.dtype == tstruct(f0=tstr, f1=tint32, f2=tint32): - t = t.select(interval=locus_interval_expr(recode_contig(t['f0']), - t['f1'], - t['f2'], - True, - True, - reference_genome, - skip_invalid_intervals)) + t = t.select( + interval=locus_interval_expr( + recode_contig(t['f0']), t['f1'], t['f2'], True, True, reference_genome, skip_invalid_intervals + ) + ) elif t.row.dtype == tstruct(f0=tstr, f1=tint32, f2=tint32, f3=tstr, f4=tstr): - t = t.select(interval=locus_interval_expr(recode_contig(t['f0']), - t['f1'], - t['f2'], - True, - True, - reference_genome, - skip_invalid_intervals), - target=t['f4']) + t = t.select( + interval=locus_interval_expr( + recode_contig(t['f0']), t['f1'], t['f2'], True, True, reference_genome, skip_invalid_intervals + ), + target=t['f4'], + ) else: raise FatalError("""invalid interval format. Acceptable formats: @@ -750,16 +802,14 @@ def checked_match_interval_expr(match): return t.key_by('interval') -@typecheck(path=str, - reference_genome=nullable(reference_genome_type), - skip_invalid_intervals=bool, - contig_recoding=nullable(dictof(str, str)), - kwargs=anytype) -def import_bed(path, - reference_genome='default', - skip_invalid_intervals=False, - contig_recoding=None, - **kwargs) -> Table: +@typecheck( + path=str, + reference_genome=nullable(reference_genome_type), + skip_invalid_intervals=bool, + contig_recoding=nullable(dictof(str, str)), + kwargs=anytype, +) +def import_bed(path, reference_genome='default', skip_invalid_intervals=False, contig_recoding=None, **kwargs) -> Table: """Import a UCSC BED file as a :class:`.Table`. 
Examples @@ -849,13 +899,16 @@ def import_bed(path, # UCSC BED spec defined here: https://genome.ucsc.edu/FAQ/FAQformat.html#format1 - t = import_table(path, no_header=True, delimiter=r"\s+", impute=False, - skip_blank_lines=True, types={'f0': tstr, 'f1': tint32, - 'f2': tint32, 'f3': tstr, - 'f4': tstr}, - comment=["""^browser.*""", """^track.*""", - r"""^\w+=("[\w\d ]+"|\d+).*"""], - **kwargs) + t = import_table( + path, + no_header=True, + delimiter=r"\s+", + impute=False, + skip_blank_lines=True, + types={'f0': tstr, 'f1': tint32, 'f2': tint32, 'f3': tstr, 'f4': tstr}, + comment=["""^browser.*""", """^track.*""", r"""^\w+=("[\w\d ]+"|\d+).*"""], + **kwargs, + ) if contig_recoding is not None: contig_recoding = hl.literal(contig_recoding) @@ -866,24 +919,21 @@ def recode_contig(x): return contig_recoding.get(x, x) if t.row.dtype == tstruct(f0=tstr, f1=tint32, f2=tint32): - t = t.select(interval=locus_interval_expr(recode_contig(t['f0']), - t['f1'] + 1, - t['f2'] + 1, - True, - False, - reference_genome, - skip_invalid_intervals)) + t = t.select( + interval=locus_interval_expr( + recode_contig(t['f0']), t['f1'] + 1, t['f2'] + 1, True, False, reference_genome, skip_invalid_intervals + ) + ) elif len(t.row) >= 4 and tstruct(**dict([(n, typ) for n, typ in t.row.dtype._field_types.items()][:4])) == tstruct( - f0=tstr, f1=tint32, f2=tint32, f3=tstr): - t = t.select(interval=locus_interval_expr(recode_contig(t['f0']), - t['f1'] + 1, - t['f2'] + 1, - True, - False, - reference_genome, - skip_invalid_intervals), - target=t['f3']) + f0=tstr, f1=tint32, f2=tint32, f3=tstr + ): + t = t.select( + interval=locus_interval_expr( + recode_contig(t['f0']), t['f1'] + 1, t['f2'] + 1, True, False, reference_genome, skip_invalid_intervals + ), + target=t['f3'], + ) else: raise FatalError("too few fields for BED file: expected 3 or more, but found {}".format(len(t.row))) @@ -894,10 +944,7 @@ def recode_contig(x): return t.key_by('interval') -@typecheck(path=str, - quant_pheno=bool, - delimiter=str, - missing=str) +@typecheck(path=str, quant_pheno=bool, delimiter=str, missing=str) def import_fam(path, quant_pheno=False, delimiter=r'\\s+', missing='NA') -> Table: """Import a PLINK FAM file into a :class:`.Table`. @@ -961,16 +1008,10 @@ def import_fam(path, quant_pheno=False, delimiter=r'\\s+', missing='NA') -> Tabl """ type_and_data = Env.backend().import_fam(path, quant_pheno, delimiter, missing) typ = hl.dtype(type_and_data['type']) - return hl.Table.parallelize( - hl.tarray(typ)._convert_from_json_na(type_and_data['data']), typ, key=['id']) + return hl.Table.parallelize(hl.tarray(typ)._convert_from_json_na(type_and_data['data']), typ, key=['id']) -@typecheck(regex=str, - path=oneof(str, sequenceof(str)), - max_count=int, - show=bool, - force=bool, - force_bgz=bool) +@typecheck(regex=str, path=oneof(str, sequenceof(str)), max_count=int, show=bool, force=bool, force_bgz=bool) def grep(regex, path, max_count=100, *, show: bool = True, force: bool = False, force_bgz: bool = False): r"""Searches given paths for all lines containing regex matches. 
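Editor's note: the `import_bed` coordinate handling above (`t['f1'] + 1, t['f2'] + 1, True, False`) reflects that BED rows are 0-based and half-open while Hail loci are 1-based. A sketch with a hypothetical file:

```python
import hail as hl

# A BED row "20  999  2000" covers 1-based positions 1000..2000, so import_bed builds
# the interval with start+1 inclusive and end+1 exclusive, exactly as in the call above.
t = hl.import_bed('data/regions.bed', reference_genome='GRCh37')  # hypothetical path
t.show()
```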
@@ -1046,23 +1087,28 @@ def grep(regex, path, max_count=100, *, show: bool = True, force: bool = False, return results -@typecheck(path=oneof(str, sequenceof(str)), - sample_file=nullable(str), - entry_fields=sequenceof(enumeration('GT', 'GP', 'dosage')), - n_partitions=nullable(int), - block_size=nullable(int), - index_file_map=nullable(dictof(str, str)), - variants=nullable(oneof(sequenceof(hl.utils.Struct), sequenceof(hl.genetics.Locus), - StructExpression, LocusExpression, Table)), - _row_fields=sequenceof(enumeration('varid', 'rsid'))) -def import_bgen(path, - entry_fields, - sample_file=None, - n_partitions=None, - block_size=None, - index_file_map=None, - variants=None, - _row_fields=['varid', 'rsid']) -> MatrixTable: +@typecheck( + path=oneof(str, sequenceof(str)), + sample_file=nullable(str), + entry_fields=sequenceof(enumeration('GT', 'GP', 'dosage')), + n_partitions=nullable(int), + block_size=nullable(int), + index_file_map=nullable(dictof(str, str)), + variants=nullable( + oneof(sequenceof(hl.utils.Struct), sequenceof(hl.genetics.Locus), StructExpression, LocusExpression, Table) + ), + _row_fields=sequenceof(enumeration('varid', 'rsid')), +) +def import_bgen( + path, + entry_fields, + sample_file=None, + n_partitions=None, + block_size=None, + index_file_map=None, + variants=None, + _row_fields=['varid', 'rsid'], +) -> MatrixTable: """Import BGEN file(s) as a :class:`.MatrixTable`. Examples @@ -1220,66 +1266,67 @@ def import_bgen(path, if variants is not None: mt_type = Env.backend().matrix_type( - ir.MatrixRead(ir.MatrixBGENReader(path, sample_file, index_file_map, n_partitions, block_size, None))) + ir.MatrixRead(ir.MatrixBGENReader(path, sample_file, index_file_map, n_partitions, block_size, None)) + ) lt = mt_type.row_type['locus'] expected_vtype = tstruct(locus=lt, alleles=tarray(tstr)) - if isinstance(variants, StructExpression) or isinstance(variants, LocusExpression): + if isinstance(variants, (StructExpression, LocusExpression)): if isinstance(variants, LocusExpression): variants = hl.struct(locus=variants) if len(variants.dtype) == 0 or not variants.dtype._is_prefix_of(expected_vtype): raise TypeError( "'import_bgen' requires the expression type for 'variants' is a non-empty prefix of the BGEN key type: \n" - + f"\tFound: {repr(variants.dtype)}\n" - + f"\tExpected: {repr(expected_vtype)}\n") + + f"\tFound: {variants.dtype!r}\n" + + f"\tExpected: {expected_vtype!r}\n" + ) uid = Env.get_uid() fnames = list(variants.dtype) name, variants = variants._to_table( - uid) # This will add back the other key fields of the source, which we don't want + uid + ) # This will add back the other key fields of the source, which we don't want variants = variants.key_by(**{fname: variants[name][fname] for fname in fnames}) variants = variants.select() elif isinstance(variants, Table): if len(variants.key) == 0 or not variants.key.dtype._is_prefix_of(expected_vtype): raise TypeError( "'import_bgen' requires the row key type for 'variants' is a non-empty prefix of the BGEN key type: \n" - + f"\tFound: {repr(variants.key.dtype)}\n" - + f"\tExpected: {repr(expected_vtype)}\n") + + f"\tFound: {variants.key.dtype!r}\n" + + f"\tExpected: {expected_vtype!r}\n" + ) variants = variants.select() else: assert isinstance(variants, list) try: if len(variants) == 0: - variants = hl.Table.parallelize(variants, - schema=expected_vtype, - key=['locus', 'alleles']) + variants = hl.Table.parallelize(variants, schema=expected_vtype, key=['locus', 'alleles']) else: first_v = variants[0] if 
isinstance(first_v, hl.Locus): - variants = hl.Table.parallelize([hl.Struct(locus=v) for v in variants], - schema=hl.tstruct(locus=lt), - key='locus') + variants = hl.Table.parallelize( + [hl.Struct(locus=v) for v in variants], schema=hl.tstruct(locus=lt), key='locus' + ) else: assert isinstance(first_v, hl.utils.Struct) if len(first_v) == 1: - variants = hl.Table.parallelize(variants, - schema=hl.tstruct(locus=lt), - key='locus') + variants = hl.Table.parallelize(variants, schema=hl.tstruct(locus=lt), key='locus') else: - variants = hl.Table.parallelize(variants, - schema=expected_vtype, - key=['locus', 'alleles']) + variants = hl.Table.parallelize(variants, schema=expected_vtype, key=['locus', 'alleles']) except Exception: raise TypeError( - f"'import_bgen' requires all elements in 'variants' are a non-empty prefix of the BGEN key type: {repr(expected_vtype)}") + f"'import_bgen' requires all elements in 'variants' are a non-empty prefix of the BGEN key type: {expected_vtype!r}" + ) vir = variants._tir - if isinstance(vir, ir.TableRead) \ - and isinstance(vir.reader, ir.TableNativeReader) \ - and vir.reader.intervals is None \ - and variants.count() == variants.distinct().count(): + if ( + isinstance(vir, ir.TableRead) + and isinstance(vir.reader, ir.TableNativeReader) + and vir.reader.intervals is None + and variants.count() == variants.distinct().count() + ): variants_path = vir.reader.path else: variants_path = new_temp_file(prefix='bgen_included_vars', extension='ht') @@ -1289,29 +1336,34 @@ def import_bgen(path, reader = ir.MatrixBGENReader(path, sample_file, index_file_map, n_partitions, block_size, variants_path) - mt = (MatrixTable(ir.MatrixRead(reader)) - .drop(*[fd for fd in ['GT', 'GP', 'dosage'] if fd not in entry_set], - *[fd for fd in ['rsid', 'varid', 'offset', 'file_idx'] if fd not in row_set])) + mt = MatrixTable(ir.MatrixRead(reader)).drop( + *[fd for fd in ['GT', 'GP', 'dosage'] if fd not in entry_set], + *[fd for fd in ['rsid', 'varid', 'offset', 'file_idx'] if fd not in row_set], + ) return mt -@typecheck(path=oneof(str, sequenceof(str)), - sample_file=nullable(str), - tolerance=numeric, - min_partitions=nullable(int), - chromosome=nullable(str), - reference_genome=nullable(reference_genome_type), - contig_recoding=nullable(dictof(str, str)), - skip_invalid_loci=bool) -def import_gen(path, - sample_file=None, - tolerance=0.2, - min_partitions=None, - chromosome=None, - reference_genome='default', - contig_recoding=None, - skip_invalid_loci=False) -> MatrixTable: +@typecheck( + path=oneof(str, sequenceof(str)), + sample_file=nullable(str), + tolerance=numeric, + min_partitions=nullable(int), + chromosome=nullable(str), + reference_genome=nullable(reference_genome_type), + contig_recoding=nullable(dictof(str, str)), + skip_invalid_loci=bool, +) +def import_gen( + path, + sample_file=None, + tolerance=0.2, + min_partitions=None, + chromosome=None, + reference_genome='default', + contig_recoding=None, + skip_invalid_loci=False, +) -> MatrixTable: """ Import GEN file(s) as a :class:`.MatrixTable`. 
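Editor's note: a minimal sketch of the `variants` filtering path handled above (paths and field names hypothetical). The supplied table's key must be a non-empty prefix of the BGEN key `(locus, alleles)`, which is what the type checks enforce:

```python
import hail as hl

# Hypothetical variant list: one 'locus' string per line, e.g. "1:10177".
variants = hl.import_table('data/keep-variants.tsv')
variants = variants.key_by(locus=hl.parse_locus(variants.locus, reference_genome='GRCh37'))
variants = variants.select()  # keep only the key, a prefix of (locus, alleles)

# Assumes the BGEN has already been indexed with hl.index_bgen.
mt = hl.import_bgen('data/example.bgen',
                    entry_fields=['GP'],
                    sample_file='data/example.sample',
                    variants=variants)
```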
@@ -1398,7 +1450,6 @@ def import_gen(path, gen_table = import_lines(path, min_partitions) sample_table = import_lines(sample_file) rg = reference_genome.name if reference_genome else None - contig_recoding = contig_recoding if contig_recoding is None: contig_recoding = hl.empty_dict(hl.tstr, hl.tstr) else: @@ -1421,17 +1472,22 @@ def import_gen(path, varid = gen_table.data[last_rowf_idx - 4] if rg is None: locus = hl.struct(contig=contig_holder, position=position) + elif skip_invalid_loci: + locus = hl.if_else( + hl.is_valid_locus(contig_holder, position, rg), + hl.locus(contig_holder, position, rg), + hl.missing(hl.tlocus(rg)), + ) else: - if skip_invalid_loci: - locus = hl.if_else(hl.is_valid_locus(contig_holder, position, rg), - hl.locus(contig_holder, position, rg), - hl.missing(hl.tlocus(rg))) - else: - locus = hl.locus(contig_holder, position, rg) + locus = hl.locus(contig_holder, position, rg) gen_table = gen_table.annotate(locus=locus, alleles=alleles, rsid=rsid, varid=varid) - gen_table = gen_table.annotate(entries=gen_table.data[last_rowf_idx + 1:].map(lambda x: hl.float64(x)) - .grouped(3).map(lambda x: hl.struct(GP=x))) + gen_table = gen_table.annotate( + entries=gen_table.data[last_rowf_idx + 1 :] + .map(lambda x: hl.float64(x)) + .grouped(3) + .map(lambda x: hl.struct(GP=x)) + ) if skip_invalid_loci: gen_table = gen_table.filter(hl.is_defined(gen_table.locus)) @@ -1444,18 +1500,28 @@ def import_gen(path, sample_table = sample_table.key_by(sample_table.idx) mt = mt.annotate_cols(s=sample_table[hl.int64(mt.col_idx)].s) - mt = mt.annotate_entries(GP=hl.rbind(hl.sum(mt.GP), lambda gp_sum: hl.if_else(hl.abs(1.0 - gp_sum) > tolerance, - hl.missing(hl.tarray(hl.tfloat64)), - hl.abs((1 / gp_sum) * mt.GP)))) - mt = mt.annotate_entries(GT=hl.rbind(hl.argmax(mt.GP), - lambda max_idx: hl.if_else( - hl.len(mt.GP.filter(lambda y: y == mt.GP[max_idx])) == 1, - hl.switch(max_idx) - .when(0, hl.call(0, 0)) - .when(1, hl.call(0, 1)) - .when(2, hl.call(1, 1)) - .or_error("error creating gt field."), - hl.missing(hl.tcall)))) + mt = mt.annotate_entries( + GP=hl.rbind( + hl.sum(mt.GP), + lambda gp_sum: hl.if_else( + hl.abs(1.0 - gp_sum) > tolerance, hl.missing(hl.tarray(hl.tfloat64)), hl.abs((1 / gp_sum) * mt.GP) + ), + ) + ) + mt = mt.annotate_entries( + GT=hl.rbind( + hl.argmax(mt.GP), + lambda max_idx: hl.if_else( + hl.len(mt.GP.filter(lambda y: y == mt.GP[max_idx])) == 1, + hl.switch(max_idx) + .when(0, hl.call(0, 0)) + .when(1, hl.call(0, 1)) + .when(2, hl.call(1, 1)) + .or_error("error creating gt field."), + hl.missing(hl.tcall), + ), + ) + ) mt = mt.filter_entries(hl.is_defined(mt.GP)) mt = mt.key_cols_by('s').drop('col_idx', 'file', 'data') @@ -1463,38 +1529,42 @@ def import_gen(path, return mt -@typecheck(paths=oneof(str, sequenceof(str)), - key=table_key_type, - min_partitions=nullable(int), - impute=bool, - no_header=bool, - comment=oneof(str, sequenceof(str)), - delimiter=str, - missing=oneof(str, sequenceof(str)), - types=dictof(str, hail_type), - quote=nullable(char), - skip_blank_lines=bool, - force_bgz=bool, - filter=nullable(str), - find_replace=nullable(sized_tupleof(str, str)), - force=bool, - source_file_field=nullable(str)) -def import_table(paths, - key=None, - min_partitions=None, - impute=False, - no_header=False, - comment=(), - delimiter="\t", - missing="NA", - types={}, - quote=None, - skip_blank_lines=False, - force_bgz=False, - filter=None, - find_replace=None, - force=False, - source_file_field=None) -> Table: +@typecheck( + paths=oneof(str, sequenceof(str)), + 
key=table_key_type, + min_partitions=nullable(int), + impute=bool, + no_header=bool, + comment=oneof(str, sequenceof(str)), + delimiter=str, + missing=oneof(str, sequenceof(str)), + types=dictof(str, hail_type), + quote=nullable(char), + skip_blank_lines=bool, + force_bgz=bool, + filter=nullable(str), + find_replace=nullable(sized_tupleof(str, str)), + force=bool, + source_file_field=nullable(str), +) +def import_table( + paths, + key=None, + min_partitions=None, + impute=False, + no_header=False, + comment=(), + delimiter="\t", + missing="NA", + types={}, + quote=None, + skip_blank_lines=False, + force_bgz=False, + filter=None, + find_replace=None, + force=False, + source_file_field=None, +) -> Table: """Import delimited text file (text table) as :class:`.Table`. The resulting :class:`.Table` will have no key fields. Use @@ -1717,7 +1787,8 @@ def import_table(paths, first_rows = first_row_ht.annotate( header=first_row_ht.text._split_line( - delimiter, missing=hl.empty_array(hl.tstr), quote=quote, regex=len(delimiter) > 1) + delimiter, missing=hl.empty_array(hl.tstr), quote=quote, regex=len(delimiter) > 1 + ) ).collect() except FatalError as err: if '_filter_partitions: no partition with index 0' in err.args[0]: @@ -1734,20 +1805,19 @@ def import_table(paths, else: maybe_duplicated_fields = first_row.header renamings, fields = deduplicate(maybe_duplicated_fields) - ht = ht.filter(ht.text == first_row.text, keep=False) # FIXME: seems wrong. Could easily fix with partition index and row_within_partition_index. + ht = ht.filter( + ht.text == first_row.text, keep=False + ) # FIXME: seems wrong. Could easily fix with partition index and row_within_partition_index. if renamings: hl.utils.warning( f'import_table: renamed the following {plural("field", len(renamings))} to avoid name conflicts:' - + ''.join(f'\n {repr(k)} -> {repr(v)}' for k, v in renamings) + + ''.join(f'\n {k!r} -> {v!r}' for k, v in renamings) ) ht = ht.annotate( split_text=( hl.case() - .when( - hl.len(ht.text) > 0, - split_lines(ht, fields, delimiter=delimiter, missing=missing, quote=quote) - ) + .when(hl.len(ht.text) > 0, split_lines(ht, fields, delimiter=delimiter, missing=missing, quote=quote)) .or_error(hl.str("Blank line found in file ") + ht.file) ) ) @@ -1764,8 +1834,9 @@ def import_table(paths, fields_to_guess.append(field) hl.utils.info('Reading table to impute column types') - guessed = ht.aggregate(hl.agg.array_agg(lambda x: hl.agg._impute_type(x), - [ht.split_text[i] for i in fields_to_impute_idx])) + guessed = ht.aggregate( + hl.agg.array_agg(lambda x: hl.agg._impute_type(x), [ht.split_text[i] for i in fields_to_impute_idx]) + ) reasons = {f: 'user-supplied type' for f in types} imputed_types = dict() @@ -1811,6 +1882,7 @@ def import_table(paths, hl.utils.info('\n'.join(strs)) else: from collections import Counter + strs2 = [f'Loading {ht.row} fields. Counts by type:'] for name, count in Counter(ht[f].dtype for f in fields).most_common(): strs2.append(f' {name}: {count}') @@ -1822,8 +1894,9 @@ def import_table(paths, return ht -@typecheck(paths=oneof(str, sequenceof(str)), min_partitions=nullable(int), force_bgz=bool, - force=bool, file_per_partition=bool) +@typecheck( + paths=oneof(str, sequenceof(str)), min_partitions=nullable(int), force_bgz=bool, force=bool, file_per_partition=bool +) def import_lines(paths, min_partitions=None, force_bgz=False, force=False, file_per_partition=False) -> Table: """Import lines of file(s) as a :class:`.Table` of strings. 
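Editor's note: a small usage sketch of `import_lines` as summarized above (path hypothetical); each input line becomes a row with `file` and `text` string fields, matching the row type constructed in this hunk:

```python
import hail as hl

# Hypothetical plain-text input; force_bgz and file_per_partition keep their defaults.
ht = hl.import_lines('data/example.txt', min_partitions=8)

# Rows carry the source path and the raw line, so ad-hoc filtering is straightforward.
ht = ht.filter(ht.text.startswith('#'), keep=False)
ht.show()
```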
@@ -1876,42 +1949,43 @@ def import_lines(paths, min_partitions=None, force_bgz=False, force=False, file_ if file_per_partition and min_partitions is not None: if min_partitions > len(paths): - raise FatalError(f'file_per_partition is True while min partitions is {min_partitions} ,which is greater' - f' than the number of files, {len(paths)}') + raise FatalError( + f'file_per_partition is True while min partitions is {min_partitions} ,which is greater' + f' than the number of files, {len(paths)}' + ) st_reader = ir.StringTableReader(paths, min_partitions, force_bgz, force, file_per_partition) - table_type = hl.ttable( - global_type=hl.tstruct(), - row_type=hl.tstruct(file=hl.tstr, text=hl.tstr), - row_key=[] - ) + table_type = hl.ttable(global_type=hl.tstruct(), row_type=hl.tstruct(file=hl.tstr, text=hl.tstr), row_key=[]) string_table = Table(ir.TableRead(st_reader, _assert_type=table_type)) return string_table -@typecheck(paths=oneof(str, sequenceof(str)), - row_fields=dictof(str, hail_type), - row_key=oneof(str, sequenceof(str)), - entry_type=enumeration(tint32, tint64, tfloat32, tfloat64, tstr), - missing=str, - min_partitions=nullable(int), - no_header=bool, - force_bgz=bool, - sep=nullable(str), - delimiter=nullable(str), - comment=oneof(str, sequenceof(str)), - ) -def import_matrix_table(paths, - row_fields={}, - row_key=[], - entry_type=tint32, - missing="NA", - min_partitions=None, - no_header=False, - force_bgz=False, - sep=None, - delimiter=None, - comment=()) -> MatrixTable: +@typecheck( + paths=oneof(str, sequenceof(str)), + row_fields=dictof(str, hail_type), + row_key=oneof(str, sequenceof(str)), + entry_type=enumeration(tint32, tint64, tfloat32, tfloat64, tstr), + missing=str, + min_partitions=nullable(int), + no_header=bool, + force_bgz=bool, + sep=nullable(str), + delimiter=nullable(str), + comment=oneof(str, sequenceof(str)), +) +def import_matrix_table( + paths, + row_fields={}, + row_key=[], + entry_type=tint32, + missing="NA", + min_partitions=None, + no_header=False, + force_bgz=False, + sep=None, + delimiter=None, + comment=(), +) -> MatrixTable: """Import tab-delimited file(s) as a :class:`.MatrixTable`. 
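Editor's note: a usage sketch for `import_matrix_table` as summarized above (file name and field names hypothetical); leading columns become row fields and the remaining columns become entries of the given type:

```python
import hail as hl

# Hypothetical tab-delimited counts matrix: first column 'gene', remaining columns samples.
mt = hl.import_matrix_table('data/counts.tsv',
                            row_fields={'gene': hl.tstr},
                            row_key='gene',
                            entry_type=hl.tint32)
mt.describe()
```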
Examples @@ -2072,11 +2146,16 @@ def import_matrix_table(paths, missing_list = wrap_to_list(missing) def comment_filter(table): - return hl.rbind(hl.array(comment), - lambda hl_comment: hl_comment.any(lambda com: hl.if_else(hl.len(com) == 1, - table.text.startswith(com), - table.text.matches(com, False)))) \ - if len(comment) > 0 else False + return ( + hl.rbind( + hl.array(comment), + lambda hl_comment: hl_comment.any( + lambda com: hl.if_else(hl.len(com) == 1, table.text.startswith(com), table.text.matches(com, False)) + ), + ) + if len(comment) > 0 + else False + ) def truncate(string_array, delim=", "): if len(string_array) > 10: @@ -2088,12 +2167,20 @@ def truncate(string_array, delim=", "): def format_file(file_name, hl_value=False): if hl_value: - return hl.rbind(file_name.split('/'), lambda split_file: - hl.if_else(hl.len(split_file) <= 4, hl.str("/").join(file_name.split('/')[-4:]), - hl.str("/") + hl.str("/").join(file_name.split('/')[-4:]))) + return hl.rbind( + file_name.split('/'), + lambda split_file: hl.if_else( + hl.len(split_file) <= 4, + hl.str("/").join(file_name.split('/')[-4:]), + hl.str("/") + hl.str("/").join(file_name.split('/')[-4:]), + ), + ) else: - return "/".join(file_name.split('/')[-3:]) if len(file_name) <= 4 else \ - "/" + "/".join(file_name.split('/')[-3:]) + return ( + "/".join(file_name.split('/')[-3:]) + if len(file_name) <= 4 + else "/" + "/".join(file_name.split('/')[-3:]) + ) file_start_array = None @@ -2103,9 +2190,11 @@ def get_file_start(row): collect_expr = first_lines_table.collect(_localize=False).map(lambda line: (line.file, line.idx)) file_start_array = hl.literal(hl.eval(collect_expr), dtype=collect_expr.dtype) return hl.coalesce( - file_start_array.filter(lambda line_tuple: line_tuple[0] == row.file).map( - lambda line_tuple: line_tuple[1]).first(), - 0) + file_start_array.filter(lambda line_tuple: line_tuple[0] == row.file) + .map(lambda line_tuple: line_tuple[1]) + .first(), + 0, + ) def validate_row_fields(): unique_fields = {} @@ -2115,12 +2204,18 @@ def validate_row_fields(): rowf_type = row_fields.get(header_rowf) if rowf_type is None: import itertools as it - row_fields_string = '\n'.join(list(it.starmap( - lambda row_field, row_type: f" '{row_field}': {str(row_type)}", row_fields.items()))) + + row_fields_string = '\n'.join( + list( + it.starmap(lambda row_field, row_type: f" '{row_field}': {row_type!s}", row_fields.items()) + ) + ) header_fields_string = "\n ".join(map(lambda field: f"'{field}'", header_dict['row_fields'])) - raise FatalError(f"in file {format_file(header_dict['path'])} found row field '{header_rowf}' that's" - f" not in 'row fields'\nrow fields found in file:\n {header_fields_string}" - f"\n'row fields':\n{row_fields_string}") + raise FatalError( + f"in file {format_file(header_dict['path'])} found row field '{header_rowf}' that's" + f" not in 'row fields'\nrow fields found in file:\n {header_fields_string}" + f"\n'row fields':\n{row_fields_string}" + ) if header_rowf in unique_fields: duplicates.append(header_rowf) else: @@ -2131,17 +2226,24 @@ def validate_row_fields(): def parse_entries(row): return hl.range(num_of_row_fields, len(header_dict['column_ids']) + num_of_row_fields).map( - lambda entry_idx: parse_type_or_error(entry_type, row, entry_idx, not_entries=False)) + lambda entry_idx: parse_type_or_error(entry_type, row, entry_idx, not_entries=False) + ) def parse_rows(row): rows_list = list(row_fields.items()) - return {rows_list[idx][0]: - parse_type_or_error(rows_list[idx][1], row, idx) for idx in 
range(num_of_row_fields)} + return {rows_list[idx][0]: parse_type_or_error(rows_list[idx][1], row, idx) for idx in range(num_of_row_fields)} def error_msg(row, idx, msg): - return (hl.str("in file ") + hl.str(format_file(row.file, True)) - + hl.str(" on line ") + hl.str(row.row_id - get_file_start(row) + 1) - + hl.str(" at value '") + hl.str(row.split_array[idx]) + hl.str("':\n") + hl.str(msg)) + return ( + hl.str("in file ") + + hl.str(format_file(row.file, True)) + + hl.str(" on line ") + + hl.str(row.row_id - get_file_start(row) + 1) + + hl.str(" at value '") + + hl.str(row.split_array[idx]) + + hl.str("':\n") + + hl.str(msg) + ) def parse_type_or_error(hail_type, row, idx, not_entries=True): value = row.split_array[idx] @@ -2159,13 +2261,19 @@ def parse_type_or_error(hail_type, row, idx, not_entries=True): if not_entries: error_clarify_msg = hl.str(" at row field '") + hl.str(hl_row_fields[idx]) + hl.str("'") else: - error_clarify_msg = (hl.str(" at column id '") + hl.str(hl_columns[idx - num_of_row_fields]) - + hl.str("' for entry field 'x' ")) + error_clarify_msg = ( + hl.str(" at column id '") + + hl.str(hl_columns[idx - num_of_row_fields]) + + hl.str("' for entry field 'x' ") + ) - return hl.if_else(hl.is_missing(value), hl.missing(hail_type), - hl.case().when(~hl.is_missing(parsed_type), parsed_type) - .or_error( - error_msg(row, idx, f"error parsing value into {str(hail_type)}" + error_clarify_msg))) + return hl.if_else( + hl.is_missing(value), + hl.missing(hail_type), + hl.case() + .when(~hl.is_missing(parsed_type), parsed_type) + .or_error(error_msg(row, idx, f"error parsing value into {hail_type!s}" + error_clarify_msg)), + ) num_of_row_fields = len(row_fields.keys()) add_row_id = False @@ -2175,9 +2283,7 @@ def parse_type_or_error(hail_type, row, idx, not_entries=True): if sep is not None: if delimiter is not None: - raise ValueError( - f'expecting either sep or delimiter but received both: ' - f'{sep}, {delimiter}') + raise ValueError(f'expecting either sep or delimiter but received both: ' f'{sep}, {delimiter}') delimiter = sep del sep @@ -2189,18 +2295,21 @@ def parse_type_or_error(hail_type, row, idx, not_entries=True): if add_row_id: if 'row_id' in row_fields: raise FatalError( - "import_matrix_table reserves the field name 'row_id' for" - 'its own use, please use a different name') + "import_matrix_table reserves the field name 'row_id' for" 'its own use, please use a different name' + ) for k, v in row_fields.items(): if v not in {tint32, tint64, tfloat32, tfloat64, tstr}: raise FatalError( f'import_matrix_table expects field types to be one of:' - f"'int32', 'int64', 'float32', 'float64', 'str': field {repr(k)} had type '{v}'") + f"'int32', 'int64', 'float32', 'float64', 'str': field {k!r} had type '{v}'" + ) if entry_type not in {tint32, tint64, tfloat32, tfloat64, tstr}: - raise FatalError("""import_matrix_table expects entry types to be one of: - 'int32', 'int64', 'float32', 'float64', 'str': found '{}'""".format(entry_type)) + raise FatalError( + """import_matrix_table expects entry types to be one of: + 'int32', 'int64', 'float32', 'float64', 'str': found '{}'""".format(entry_type) + ) if missing in delimiter: raise FatalError(f"Missing value {missing} contains delimiter {delimiter}") @@ -2208,12 +2317,14 @@ def parse_type_or_error(hail_type, row, idx, not_entries=True): ht = import_lines(paths, min_partitions, force_bgz=force_bgz).add_index(name='row_id') # for checking every header matches file_per_partition = import_lines(paths, force_bgz=force_bgz, 
file_per_partition=True) - file_per_partition = file_per_partition.filter(hl.bool(hl.len(file_per_partition.text) == 0) - | comment_filter(file_per_partition), False) + file_per_partition = file_per_partition.filter( + hl.bool(hl.len(file_per_partition.text) == 0) | comment_filter(file_per_partition), False + ) first_lines_table = file_per_partition._map_partitions(lambda rows: rows.take(1)) first_lines_table = first_lines_table.annotate(split_array=first_lines_table.text.split(delimiter)).add_index() if not no_header: + def validate_header_get_info_dict(): two_first_lines = file_per_partition.head(2) two_first_lines = two_first_lines.annotate(split_array=two_first_lines.text.split(delimiter)).collect() @@ -2226,9 +2337,11 @@ def validate_header_get_info_dict(): elif not first_data_line or first_data_line.file != header_line.file: hl.utils.warning(f"File {format_file(header_line.file)} contains a header, but no lines of data") if num_of_header_values < num_of_data_line_values: - raise ValueError(f"File {format_file(header_line.file)} contains one line assumed to be the header." - f"The header had a length of {num_of_header_values} while the number" - f"of row fields is {num_of_row_fields}") + raise ValueError( + f"File {format_file(header_line.file)} contains one line assumed to be the header." + f"The header had a length of {num_of_header_values} while the number" + f"of row fields is {num_of_row_fields}" + ) user_row_fields = header_line.split_array[:num_of_row_fields] column_ids = header_line.split_array[num_of_row_fields:] elif num_of_data_line_values != num_of_header_values: @@ -2242,12 +2355,18 @@ def validate_header_get_info_dict(): f" colId0 colId1 ...\nInstead the first two lines were:\nInstead the first two lin" f"es were:\n{header_line.text}\n{first_data_line.text}\nThe first line contained" f" {num_of_header_values} separated values and the second line" - f" contained {num_of_data_line_values}") + f" contained {num_of_data_line_values}" + ) else: user_row_fields = header_line.split_array[:num_of_row_fields] column_ids = header_line.split_array[num_of_row_fields:] - return {'text': header_line.text, 'header_values': header_line.split_array, 'path': header_line.file, - 'row_fields': user_row_fields, 'column_ids': column_ids} + return { + 'text': header_line.text, + 'header_values': header_line.split_array, + 'path': header_line.file, + 'row_fields': user_row_fields, + 'column_ids': column_ids, + } def warn_if_duplicate_col_ids(): time_col_id_encountered_dict = {} @@ -2261,15 +2380,24 @@ def warn_if_duplicate_col_ids(): return import itertools as it + duplicates_to_print = sorted( - [('"' + dup_field + '"', '(' + str(time_col_id_encountered_dict[dup_field]) + ')') - for dup_field in duplicate_cols], key=lambda dup_values: dup_values[1]) + [ + ('"' + dup_field + '"', '(' + str(time_col_id_encountered_dict[dup_field]) + ')') + for dup_field in duplicate_cols + ], + key=lambda dup_values: dup_values[1], + ) duplicates_to_print = truncate(duplicates_to_print) - duplicates_to_print_formatted = it.starmap(lambda dup, time_found: time_found - + " " + dup, duplicates_to_print) - ht.utils.warning(f"Found {len(duplicate_cols)} duplicate column id" - + f"{'s' if len(duplicate_cols) > 1 else ''}\n" + '\n'.join(duplicates_to_print_formatted)) + duplicates_to_print_formatted = it.starmap( + lambda dup, time_found: time_found + " " + dup, duplicates_to_print + ) + ht.utils.warning( + f"Found {len(duplicate_cols)} duplicate column id" + + f"{'s' if len(duplicate_cols) > 1 else ''}\n" + + 
'\n'.join(duplicates_to_print_formatted) + ) def validate_all_headers(): all_headers = first_lines_table.collect() @@ -2281,17 +2409,21 @@ def validate_all_headers(): main_header_value = header_values[0] error_header_value = header_values[1] if main_header_value != error_header_value: - raise ValueError("invalid header: expected elements to be identical for all input paths" - f". Found different elements at position {header_idx + 1}" - f"\n in file {format_file(header.file)} with value " - f"'{error_header_value}' when expecting value '{main_header_value}'") + raise ValueError( + "invalid header: expected elements to be identical for all input paths" + f". Found different elements at position {header_idx + 1}" + f"\n in file {format_file(header.file)} with value " + f"'{error_header_value}' when expecting value '{main_header_value}'" + ) else: - raise ValueError(f"invalid header: lengths of headers differ. \n" - f"{len(header_dict['header_values'])} elements in " - f"{format_file(header_dict['path'])}:\n" - + truncate(["'{}'".format(value) for value in header_dict['header_values']]) - + f" {len(header.split_array)} elements in {format_file(header.file)}:\n" - + truncate(["'{}'".format(value) for value in header.split_array])) + raise ValueError( + f"invalid header: lengths of headers differ. \n" + f"{len(header_dict['header_values'])} elements in " + f"{format_file(header_dict['path'])}:\n" + + truncate(["'{}'".format(value) for value in header_dict['header_values']]) + + f" {len(header.split_array)} elements in {format_file(header.file)}:\n" + + truncate(["'{}'".format(value) for value in header.split_array]) + ) header_dict = validate_header_get_info_dict() warn_if_duplicate_col_ids() @@ -2300,19 +2432,19 @@ def validate_all_headers(): else: first_line = first_lines_table.head(1).collect() if not first_line or path_to_index[first_line[0].file] != 0: - hl.utils.warning( - f"File {format_file(paths[0])} is empty and has no header, so we assume no columns") - header_dict = {'header_values': [], - 'row_fields': ["f" + str(f_idx) for f_idx in list(range(0, num_of_row_fields))], - 'column_ids': [] - } + hl.utils.warning(f"File {format_file(paths[0])} is empty and has no header, so we assume no columns") + header_dict = { + 'header_values': [], + 'row_fields': ["f" + str(f_idx) for f_idx in list(range(0, num_of_row_fields))], + 'column_ids': [], + } else: first_line = first_line[0] - header_dict = {'header_values': [], - 'row_fields': ["f" + str(f_idx) for f_idx in list(range(0, num_of_row_fields))], - 'column_ids': - [col_id for col_id in list(range(0, len(first_line.split_array) - num_of_row_fields))] - } + header_dict = { + 'header_values': [], + 'row_fields': ["f" + str(f_idx) for f_idx in list(range(0, num_of_row_fields))], + 'column_ids': [col_id for col_id in list(range(0, len(first_line.split_array) - num_of_row_fields))], + } validate_row_fields() header_filter = ht.text == header_dict['text'] if not no_header else False @@ -2320,33 +2452,34 @@ def validate_all_headers(): ht = ht.filter(hl.bool(hl.len(ht.text) == 0) | comment_filter(ht) | header_filter, False) hl_columns = hl.array(header_dict['column_ids']) if len(header_dict['column_ids']) > 0 else hl.empty_array(hl.tstr) - hl_row_fields = hl.array(header_dict['row_fields']) if len(header_dict['row_fields']) > 0 \ - else hl.empty_array(hl.tstr) + hl_row_fields = ( + hl.array(header_dict['row_fields']) if len(header_dict['row_fields']) > 0 else hl.empty_array(hl.tstr) + ) ht = ht.annotate(split_array=ht.text._split_line(delimiter, 
missing_list, quote=None, regex=False)).add_index( - 'row_id') + 'row_id' + ) - ht = ht.annotate(split_array=hl.case().when(hl.len(ht.split_array) >= num_of_row_fields, ht.split_array) - .or_error(error_msg(ht, hl.len(ht.split_array) - 1, - " unexpected end of line while reading row field"))) + ht = ht.annotate( + split_array=hl.case() + .when(hl.len(ht.split_array) >= num_of_row_fields, ht.split_array) + .or_error(error_msg(ht, hl.len(ht.split_array) - 1, " unexpected end of line while reading row field")) + ) n_column_ids = len(header_dict['column_ids']) - n_in_split_array = hl.len(ht.split_array[num_of_row_fields:(num_of_row_fields + n_column_ids)]) - ht = ht.annotate(split_array=hl.case().when( - n_column_ids <= n_in_split_array, - ht.split_array - ).or_error( - error_msg( - ht, - hl.len(ht.split_array) - 1, - " unexpected end of line while reading entries" - ) - )) + n_in_split_array = hl.len(ht.split_array[num_of_row_fields : (num_of_row_fields + n_column_ids)]) + ht = ht.annotate( + split_array=hl.case() + .when(n_column_ids <= n_in_split_array, ht.split_array) + .or_error(error_msg(ht, hl.len(ht.split_array) - 1, " unexpected end of line while reading entries")) + ) - ht = ht.annotate(**parse_rows(ht), entries=parse_entries(ht).map(lambda entry: hl.struct(x=entry)))\ - .drop('text', 'split_array', 'file') + ht = ht.annotate(**parse_rows(ht), entries=parse_entries(ht).map(lambda entry: hl.struct(x=entry))).drop( + 'text', 'split_array', 'file' + ) - ht = ht.annotate_globals(cols=hl.range(0, len(header_dict['column_ids'])) - .map(lambda col_idx: hl.struct(col_id=hl_columns[col_idx]))) + ht = ht.annotate_globals( + cols=hl.range(0, len(header_dict['column_ids'])).map(lambda col_idx: hl.struct(col_id=hl_columns[col_idx])) + ) if not add_row_id: ht = ht.drop('row_id') @@ -2356,30 +2489,36 @@ def validate_all_headers(): return mt -@typecheck(bed=str, - bim=str, - fam=str, - min_partitions=nullable(int), - delimiter=str, - missing=str, - quant_pheno=bool, - a2_reference=bool, - reference_genome=nullable(reference_genome_type), - contig_recoding=nullable(dictof(str, str)), - skip_invalid_loci=bool, - n_partitions=nullable(int), - block_size=nullable(int)) -def import_plink(bed, bim, fam, - min_partitions=None, - delimiter='\\\\s+', - missing='NA', - quant_pheno=False, - a2_reference=True, - reference_genome='default', - contig_recoding=None, - skip_invalid_loci=False, - n_partitions=None, - block_size=None) -> MatrixTable: +@typecheck( + bed=str, + bim=str, + fam=str, + min_partitions=nullable(int), + delimiter=str, + missing=str, + quant_pheno=bool, + a2_reference=bool, + reference_genome=nullable(reference_genome_type), + contig_recoding=nullable(dictof(str, str)), + skip_invalid_loci=bool, + n_partitions=nullable(int), + block_size=nullable(int), +) +def import_plink( + bed, + bim, + fam, + min_partitions=None, + delimiter='\\\\s+', + missing='NA', + quant_pheno=False, + a2_reference=True, + reference_genome='default', + contig_recoding=None, + skip_invalid_loci=False, + n_partitions=None, + block_size=None, +) -> MatrixTable: """Import a PLINK dataset (BED, BIM, FAM) as a :class:`.MatrixTable`. 
Examples @@ -2508,31 +2647,56 @@ def import_plink(bed, bim, fam, elif reference_genome.name == "GRCh37": contig_recoding = {'23': 'X', '24': 'Y', '25': 'X', '26': 'MT'} elif reference_genome.name == "GRCh38": - contig_recoding = {**{str(i): f'chr{i}' for i in range(1, 23)}, - **{'23': 'chrX', '24': 'chrY', '25': 'chrX', '26': 'chrM'}} + contig_recoding = { + **{str(i): f'chr{i}' for i in range(1, 23)}, + **{'23': 'chrX', '24': 'chrY', '25': 'chrX', '26': 'chrM'}, + } else: contig_recoding = {} - reader = ir.MatrixPLINKReader(bed, bim, fam, - n_partitions, block_size, min_partitions, - missing, delimiter, quant_pheno, a2_reference, reference_genome, - contig_recoding, skip_invalid_loci) + reader = ir.MatrixPLINKReader( + bed, + bim, + fam, + n_partitions, + block_size, + min_partitions, + missing, + delimiter, + quant_pheno, + a2_reference, + reference_genome, + contig_recoding, + skip_invalid_loci, + ) return MatrixTable(ir.MatrixRead(reader, drop_cols=False, drop_rows=False)) -@typecheck(path=str, - _intervals=nullable(sequenceof(anytype)), - _filter_intervals=bool, - _drop_cols=bool, - _drop_rows=bool, - _create_row_uids=bool, - _create_col_uids=bool, - _n_partitions=nullable(int), - _assert_type=nullable(hl.tmatrix), - _load_refs=bool) -def read_matrix_table(path, *, _intervals=None, _filter_intervals=False, _drop_cols=False, - _drop_rows=False, _create_row_uids=False, _create_col_uids=False, - _n_partitions=None, _assert_type=None, _load_refs=True) -> MatrixTable: +@typecheck( + path=str, + _intervals=nullable(sequenceof(anytype)), + _filter_intervals=bool, + _drop_cols=bool, + _drop_rows=bool, + _create_row_uids=bool, + _create_col_uids=bool, + _n_partitions=nullable(int), + _assert_type=nullable(tmatrix), + _load_refs=bool, +) +def read_matrix_table( + path, + *, + _intervals=None, + _filter_intervals=False, + _drop_cols=False, + _drop_rows=False, + _create_row_uids=False, + _create_col_uids=False, + _n_partitions=None, + _assert_type=None, + _load_refs=True, +) -> MatrixTable: """Read in a :class:`.MatrixTable` written with :meth:`.MatrixTable.write`. 
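Editor's note: stepping back to the `import_plink` contig-recoding defaults reformatted earlier in this hunk, a usage sketch with hypothetical paths; for GRCh38 the default recoding maps `'1'..'22'` to `'chr1'..'chr22'` and `'23','24','25','26'` to `'chrX','chrY','chrX','chrM'`:

```python
import hail as hl

# Hypothetical PLINK fileset; reference_genome selects the recoding table shown above.
mt = hl.import_plink(bed='data/example.bed',
                     bim='data/example.bim',
                     fam='data/example.fam',
                     reference_genome='GRCh38')
mt.rows().show(5)
```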
Parameters @@ -2551,12 +2715,16 @@ def read_matrix_table(path, *, _intervals=None, _filter_intervals=False, _drop_c if _intervals is not None and _n_partitions is not None: raise ValueError("'read_matrix_table' does not support both _intervals and _n_partitions") - mt = MatrixTable(ir.MatrixRead(ir.MatrixNativeReader(path, _intervals, _filter_intervals), - _drop_cols, - _drop_rows, - drop_row_uids=not _create_row_uids, - drop_col_uids=not _create_col_uids, - _assert_type=_assert_type)) + mt = MatrixTable( + ir.MatrixRead( + ir.MatrixNativeReader(path, _intervals, _filter_intervals), + _drop_cols, + _drop_rows, + drop_row_uids=not _create_row_uids, + drop_col_uids=not _create_col_uids, + _assert_type=_assert_type, + ) + ) if _n_partitions: intervals = mt._calculate_new_partitions(_n_partitions) return read_matrix_table( @@ -2565,7 +2733,7 @@ def read_matrix_table(path, *, _intervals=None, _filter_intervals=False, _drop_c _drop_cols=_drop_cols, _intervals=intervals, _assert_type=_assert_type, - _load_refs=_load_refs + _load_refs=_load_refs, ) return mt @@ -2628,42 +2796,46 @@ def get_vcf_metadata(path): return Env.backend().parse_vcf_metadata(path) -@typecheck(path=oneof(str, sequenceof(str)), - force=bool, - force_bgz=bool, - header_file=nullable(str), - min_partitions=nullable(int), - drop_samples=bool, - call_fields=oneof(str, sequenceof(str)), - reference_genome=nullable(reference_genome_type), - contig_recoding=nullable(dictof(str, str)), - array_elements_required=bool, - skip_invalid_loci=bool, - entry_float_type=enumeration(tfloat32, tfloat64), - filter=nullable(str), - find_replace=nullable(sized_tupleof(str, str)), - n_partitions=nullable(int), - block_size=nullable(int), - _create_row_uids=bool, - _create_col_uids=bool) -def import_vcf(path, - force=False, - force_bgz=False, - header_file=None, - min_partitions=None, - drop_samples=False, - call_fields=['PGT'], - reference_genome='default', - contig_recoding=None, - array_elements_required=True, - skip_invalid_loci=False, - entry_float_type=tfloat64, - filter=None, - find_replace=None, - n_partitions=None, - block_size=None, - _create_row_uids=False, _create_col_uids=False, - ) -> MatrixTable: +@typecheck( + path=oneof(str, sequenceof(str)), + force=bool, + force_bgz=bool, + header_file=nullable(str), + min_partitions=nullable(int), + drop_samples=bool, + call_fields=oneof(str, sequenceof(str)), + reference_genome=nullable(reference_genome_type), + contig_recoding=nullable(dictof(str, str)), + array_elements_required=bool, + skip_invalid_loci=bool, + entry_float_type=enumeration(tfloat32, tfloat64), + filter=nullable(str), + find_replace=nullable(sized_tupleof(str, str)), + n_partitions=nullable(int), + block_size=nullable(int), + _create_row_uids=bool, + _create_col_uids=bool, +) +def import_vcf( + path, + force=False, + force_bgz=False, + header_file=None, + min_partitions=None, + drop_samples=False, + call_fields=['PGT'], + reference_genome='default', + contig_recoding=None, + array_elements_required=True, + skip_invalid_loci=False, + entry_float_type=tfloat64, + filter=None, + find_replace=None, + n_partitions=None, + block_size=None, + _create_row_uids=False, + _create_col_uids=False, +) -> MatrixTable: """Import VCF file(s) as a :class:`.MatrixTable`. 
Examples @@ -2692,7 +2864,49 @@ def import_vcf(path, Import a bgzipped VCF which uses the "gz" extension rather than the "bgz" extension: - >>> ds = hl.import_vcf('data/samplepart*.vcf.gz', force_bgz=True) + >>> ds = hl.import_vcf('data/sample.vcf.gz', force_bgz=True) + + Import a VCF which has missing values (".") inside INFO or FORMAT array fields: + + >>> print(open('data/missing-values-in-array-fields.vcf').read()) + ##fileformat=VCFv4.1 + ##FORMAT=<ID=GT,Number=1,Type=String> + ##FORMAT=<ID=X,Number=.,Type=Integer> + ##FORMAT=<ID=Y,Number=.,Type=Integer> + ##FORMAT=<ID=Z,Number=.,Type=Integer> + ##INFO=<ID=A,Number=.,Type=Integer> + ##INFO=<ID=B,Number=.,Type=Float> + ##INFO=<ID=C,Number=.,Type=Float> + ##INFO=<ID=D,Number=.,Type=Float> + #CHROM POS ID REF ALT QUAL FILTER INFO FORMAT SAMPLE1 + 1 123456 . A C . . A=1,.;B=.,2,.;C=. GT:X:Y:Z 0/0:1,.,1:. + + >>> ds = hl.import_vcf('data/missing-values-in-array-fields.vcf', array_elements_required=False) + >>> ds.show(n_rows=1, n_cols=1, include_row_fields=True) + +---------------+------------+------+-----------+----------+--------------+ + | locus | alleles | rsid | qual | filters | info.A | + +---------------+------------+------+-----------+----------+--------------+ + | locus<GRCh37> | array<str> | str | float64 | set<str> | array<int32> | + +---------------+------------+------+-----------+----------+--------------+ + | 1:123456 | ["A","C"] | NA | -1.00e+01 | NA | [1,NA] | + +---------------+------------+------+-----------+----------+--------------+ + + +------------------+----------------+----------------+--------------+ + | info.B | info.C | info.D | 'SAMPLE1'.GT | + +------------------+----------------+----------------+--------------+ + | array<float64> | array<float64> | array<float64> | call | + +------------------+----------------+----------------+--------------+ + | [NA,2.00e+00,NA] | NA | NA | 0/0 | + +------------------+----------------+----------------+--------------+ + + +--------------+--------------+--------------+ + | 'SAMPLE1'.X | 'SAMPLE1'.Y | 'SAMPLE1'.Z | + +--------------+--------------+--------------+ + | array<int32> | array<int32> | array<int32> | + +--------------+--------------+--------------+ + | [1,NA,1] | NA | NA | + +--------------+--------------+--------------+ + Notes ----- @@ -2835,54 +3049,101 @@ def import_vcf(path, 'force_bgz=True instead.'
) - reader = ir.MatrixVCFReader(path, call_fields, entry_float_type, header_file, - n_partitions, block_size, min_partitions, - reference_genome, contig_recoding, array_elements_required, - skip_invalid_loci, force_bgz, force, filter, find_replace) - return MatrixTable(ir.MatrixRead(reader, drop_cols=drop_samples, drop_row_uids=not _create_row_uids, drop_col_uids=not _create_col_uids)) - - -@typecheck(path=expr_str, - file_num=expr_int32, - contig=expr_str, - start=expr_int32, - end=expr_int32, - header_info=anytype, - call_fields=sequenceof(str), - entry_float_type=hail_type, - array_elements_required=bool, - reference_genome=reference_genome_type, - contig_recoding=nullable(dictof(str, str)), - skip_invalid_loci=bool, - filter=nullable(str), - find=nullable(str), - replace=nullable(str)) -def import_gvcf_interval(path, file_num, contig, start, end, header_info, call_fields=['PGT'], entry_float_type='float64', - array_elements_required=True, reference_genome='default', contig_recoding=None, - skip_invalid_loci=False, filter=None, find=None, replace=None): + reader = ir.MatrixVCFReader( + path, + call_fields, + entry_float_type, + header_file, + n_partitions, + block_size, + min_partitions, + reference_genome, + contig_recoding, + array_elements_required, + skip_invalid_loci, + force_bgz, + force, + filter, + find_replace, + ) + return MatrixTable( + ir.MatrixRead( + reader, drop_cols=drop_samples, drop_row_uids=not _create_row_uids, drop_col_uids=not _create_col_uids + ) + ) + + +@typecheck( + path=expr_str, + file_num=expr_int32, + contig=expr_str, + start=expr_int32, + end=expr_int32, + header_info=anytype, + call_fields=sequenceof(str), + entry_float_type=hail_type, + array_elements_required=bool, + reference_genome=reference_genome_type, + contig_recoding=nullable(dictof(str, str)), + skip_invalid_loci=bool, + filter=nullable(str), + find=nullable(str), + replace=nullable(str), +) +def import_gvcf_interval( + path, + file_num, + contig, + start, + end, + header_info, + call_fields=['PGT'], + entry_float_type='float64', + array_elements_required=True, + reference_genome='default', + contig_recoding=None, + skip_invalid_loci=False, + filter=None, + find=None, + replace=None, +): indices, aggs = hl.expr.unify_all(path, file_num, contig, start, end) - stream_ir = ir.ReadPartition(hl.struct(fileNum=file_num, path=path, contig=contig, start=start, end=end)._ir, - ir.GVCFPartitionReader(header_info, call_fields, entry_float_type, - array_elements_required, - reference_genome, contig_recoding or {}, - skip_invalid_loci, filter, find, replace, - None)) + stream_ir = ir.ReadPartition( + hl.struct(fileNum=file_num, path=path, contig=contig, start=start, end=end)._ir, + ir.GVCFPartitionReader( + header_info, + call_fields, + entry_float_type, + array_elements_required, + reference_genome, + contig_recoding or {}, + skip_invalid_loci, + filter, + find, + replace, + None, + ), + ) arr = ir.ToArray(stream_ir) return hl.expr.construct_expr(arr, arr.typ, indices, aggs) -@typecheck(path=oneof(str, sequenceof(str)), - index_file_map=nullable(dictof(str, str)), - reference_genome=nullable(reference_genome_type), - contig_recoding=nullable(dictof(str, str)), - skip_invalid_loci=bool, - _buffer_size=int) -def index_bgen(path, - index_file_map=None, - reference_genome='default', - contig_recoding=None, - skip_invalid_loci=False, - _buffer_size=16_000_000): +@typecheck( + path=oneof(str, sequenceof(str)), + index_file_map=nullable(dictof(str, str)), + reference_genome=nullable(reference_genome_type), + 
contig_recoding=nullable(dictof(str, str)), + skip_invalid_loci=bool, + _buffer_size=int, +) +def index_bgen( + path, + index_file_map=None, + reference_genome='default', + contig_recoding=None, + skip_invalid_loci=False, + _buffer_size=16_000_000, +): """Index BGEN files as required by :func:`.import_bgen`. If `index_file_map` is unspecified, then, for each BGEN file, the index file is written in the @@ -2941,8 +3202,10 @@ def index_bgen(path, if not fs.is_dir(p): raise ValueError(f'index_bgen: no file or directory at {p}') for stat_result in fs.ls(p): - if re.match(r"^.*part-[0-9]+(-[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})?$", - os.path.basename(stat_result.path)): + if re.match( + r"^.*part-[0-9]+(-[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})?$", + os.path.basename(stat_result.path), + ): paths.append(stat_result.path) paths_lit = hl.literal(paths, hl.tarray(hl.tstr)) @@ -2953,15 +3216,18 @@ def index_bgen(path, contig_recoding_lit = hl.literal(contig_recoding, hl.tdict(hl.tstr, hl.tstr)) ht = hl.utils.range_table(len(paths), len(paths)) path_fd = paths_lit[ht.idx] - ht = ht.annotate(n_indexed=hl.expr.functions._func( - "index_bgen", - hl.tint64, - path_fd, - index_file_map_lit.get(path_fd, path_fd + ".idx2"), - contig_recoding_lit, - hl.bool(skip_invalid_loci), - hl.int32(_buffer_size), - type_args=(rg_t,))) + ht = ht.annotate( + n_indexed=hl.expr.functions._func( + "index_bgen", + hl.tint64, + path_fd, + index_file_map_lit.get(path_fd, path_fd + ".idx2"), + contig_recoding_lit, + hl.bool(skip_invalid_loci), + hl.int32(_buffer_size), + type_args=(rg_t,), + ) + ) for r in ht.collect(): idx = r.idx @@ -2974,30 +3240,36 @@ def index_bgen(path, @typecheck(path=expr_str, filter=nullable(expr_str), find=nullable(expr_str), replace=nullable(expr_str)) def get_vcf_header_info(path, filter=None, find=None, replace=None): from hail.ir.register_functions import vcf_header_type_str + return hl.expr.functions._func( "getVCFHeader", hl.dtype(vcf_header_type_str), path, hl.missing('str') if filter is None else filter, hl.missing('str') if find is None else find, - hl.missing('str') if replace is None else replace) - - -@typecheck(path=str, - _intervals=nullable(sequenceof(anytype)), - _filter_intervals=bool, - _n_partitions=nullable(int), - _assert_type=nullable(hl.ttable), - _load_refs=bool, - _create_row_uids=bool) -def read_table(path, - *, - _intervals=None, - _filter_intervals=False, - _n_partitions=None, - _assert_type=None, - _load_refs=True, - _create_row_uids=False) -> Table: + hl.missing('str') if replace is None else replace, + ) + + +@typecheck( + path=str, + _intervals=nullable(sequenceof(anytype)), + _filter_intervals=bool, + _n_partitions=nullable(int), + _assert_type=nullable(ttable), + _load_refs=bool, + _create_row_uids=bool, +) +def read_table( + path, + *, + _intervals=None, + _filter_intervals=False, + _n_partitions=None, + _assert_type=None, + _load_refs=True, + _create_row_uids=False, +) -> Table: """Read in a :class:`.Table` written with :meth:`.Table.write`. 
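To illustrate the reflowed `index_bgen` signature above, a small sketch; the BGEN path and recoding are hypothetical, and the entry fields requested afterwards are just an example:

    import hail as hl

    # Writes an .idx2 index next to each BGEN file unless index_file_map overrides it.
    hl.index_bgen(
        'data/example.bgen',           # hypothetical path
        reference_genome='GRCh37',
        contig_recoding={'01': '1'},   # example recoding only
        skip_invalid_loci=True,
    )
    # Once indexed, the file can be imported.
    mt = hl.import_bgen('data/example.bgen', entry_fields=['GT', 'GP'])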
Parameters @@ -3020,18 +3292,26 @@ def read_table(path, if _n_partitions: intervals = ht._calculate_new_partitions(_n_partitions) - return read_table(path, _intervals=intervals, _assert_type=_assert_type, _load_refs=_load_refs, _create_row_uids=_create_row_uids) + return read_table( + path, + _intervals=intervals, + _assert_type=_assert_type, + _load_refs=_load_refs, + _create_row_uids=_create_row_uids, + ) return ht -@typecheck(t=Table, - host=str, - port=int, - index=str, - index_type=str, - block_size=int, - config=nullable(dictof(str, str)), - verbose=bool) +@typecheck( + t=Table, + host=str, + port=int, + index=str, + index_type=str, + block_size=int, + config=nullable(dictof(str, str)), + verbose=bool, +) def export_elasticsearch(t, host, port, index, index_type, block_size, config=None, verbose=True): """Export a :class:`.Table` to Elasticsearch. @@ -3060,7 +3340,6 @@ def import_avro(paths, *, key=None, intervals=None): raise ValueError('key and intervals must either be both defined or both undefined') with hl.current_backend().fs.open(paths[0], 'rb') as avro_file: - # monkey patch DataFileReader.determine_file_length to account for bug in Google HadoopFS def patched_determine_file_length(self) -> int: @@ -3084,36 +3363,41 @@ def patched_determine_file_length(self) -> int: return Table(ir.TableRead(tr)) -@typecheck(paths=oneof(str, sequenceof(str)), - key=table_key_type, - min_partitions=nullable(int), - impute=bool, - no_header=bool, - comment=oneof(str, sequenceof(str)), - missing=oneof(str, sequenceof(str)), - types=dictof(str, hail_type), - skip_blank_lines=bool, - force_bgz=bool, - filter=nullable(str), - find_replace=nullable(sized_tupleof(str, str)), - force=bool, - source_file_field=nullable(str)) -def import_csv(paths, - *, - key=None, - min_partitions=None, - impute=False, - no_header=False, - comment=(), - missing="NA", - types={}, - quote='"', - skip_blank_lines=False, - force_bgz=False, - filter=None, - find_replace=None, - force=False, - source_file_field=None) -> Table: +@typecheck( + paths=oneof(str, sequenceof(str)), + key=table_key_type, + min_partitions=nullable(int), + impute=bool, + no_header=bool, + comment=oneof(str, sequenceof(str)), + missing=oneof(str, sequenceof(str)), + types=dictof(str, hail_type), + quote=nullable(char), + skip_blank_lines=bool, + force_bgz=bool, + filter=nullable(str), + find_replace=nullable(sized_tupleof(str, str)), + force=bool, + source_file_field=nullable(str), +) +def import_csv( + paths, + *, + key=None, + min_partitions=None, + impute=False, + no_header=False, + comment=(), + missing="NA", + types={}, + quote='"', + skip_blank_lines=False, + force_bgz=False, + filter=None, + find_replace=None, + force=False, + source_file_field=None, +) -> Table: """Import a csv file as a :class:`.Table`. 
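The `import_csv` wrapper whose signature is reflowed above ultimately delegates to `hl.import_table` with a comma delimiter, as its body below shows. A rough equivalent of that call, with a hypothetical file path:

    import hail as hl

    # Roughly what import_csv(paths, impute=True) builds: import_table with
    # delimiter="," and the default quote character.
    ht = hl.import_table(
        'data/example.csv',   # hypothetical path
        impute=True,
        missing='NA',
        delimiter=',',
        quote='"',
    )
    ht.show(5)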
Examples @@ -3246,20 +3530,22 @@ def import_csv(paths, :class:`.Table` """ - ht = hl.import_table(paths, - key=key, - min_partitions=min_partitions, - impute=impute, - no_header=no_header, - comment=comment, - missing=missing, - types=types, - skip_blank_lines=skip_blank_lines, - force_bgz=force_bgz, - filter=filter, - find_replace=find_replace, - force=force, - source_file_field=source_file_field, - delimiter=",", - quote=quote) + ht = hl.import_table( + paths, + key=key, + min_partitions=min_partitions, + impute=impute, + no_header=no_header, + comment=comment, + missing=missing, + types=types, + skip_blank_lines=skip_blank_lines, + force_bgz=force_bgz, + filter=filter, + find_replace=find_replace, + force=force, + source_file_field=source_file_field, + delimiter=",", + quote=quote, + ) return ht diff --git a/hail/python/hail/methods/import_lines_helpers.py b/hail/python/hail/methods/import_lines_helpers.py index 51737887e76..555bffbe083 100644 --- a/hail/python/hail/methods/import_lines_helpers.py +++ b/hail/python/hail/methods/import_lines_helpers.py @@ -1,44 +1,40 @@ from typing import List, Optional + import hail as hl +from hail.expr import ArrayExpression, BooleanExpression, Expression, StringExpression, StructExpression from hail.utils.misc import hl_plural, plural def split_lines( - row: hl.StructExpression, - fields: List[str], - *, - delimiter: str, - missing: str, - quote: str -) -> hl.ArrayExpression: + row: StructExpression, fields: List[str], *, delimiter: str, missing: str, quote: str +) -> ArrayExpression: split_array = row.text._split_line(delimiter, missing=missing, quote=quote, regex=len(delimiter) > 1) return ( hl.case() .when(hl.len(split_array) == len(fields), split_array) .or_error( hl.format( - f'''error in number of fields found: in file %s + f"""error in number of fields found: in file %s Expected {len(fields)} {plural("field", len(fields))}, found %d %s on line: -%s''', - row.file, hl.len(split_array), hl_plural("field", hl.len(split_array)), row.text +%s""", + row.file, + hl.len(split_array), + hl_plural("field", hl.len(split_array)), + row.text, ) ) ) -def match_comment(comment: str, line: hl.StringExpression) -> hl.Expression: +def match_comment(comment: str, line: StringExpression) -> Expression: if len(comment) == 1: return line.startswith(comment) return line.matches(comment, True) def should_remove_line( - line: hl.StringExpression, - *, - filter: str, - comment: List[str], - skip_blank_lines: bool -) -> Optional[hl.BooleanExpression]: + line: StringExpression, *, filter: str, comment: List[str], skip_blank_lines: bool +) -> Optional[BooleanExpression]: condition = None if filter is not None: condition = line.matches(filter) diff --git a/hail/python/hail/methods/misc.py b/hail/python/hail/methods/misc.py index 819d335bdde..218aed64f88 100644 --- a/hail/python/hail/methods/misc.py +++ b/hail/python/hail/methods/misc.py @@ -1,24 +1,18 @@ from typing import Union import hail as hl -from hail.expr import Expression, \ - expr_numeric, expr_array, expr_interval, expr_any, \ - construct_expr, construct_variable -from hail.expr.types import tlocus, tarray, tstr, tstruct, ttuple +from hail import ir +from hail.expr import Expression, construct_expr, construct_variable, expr_any, expr_array, expr_interval, expr_numeric +from hail.expr.types import tarray, tlocus, tstr, tstruct, ttuple from hail.matrixtable import MatrixTable from hail.table import Table -from hail.typecheck import typecheck, nullable, func_spec, oneof -from hail.utils import Interval, Struct, 
new_temp_file, deduplicate -from hail.utils.misc import plural +from hail.typecheck import func_spec, nullable, oneof, typecheck +from hail.utils import Interval, Struct, deduplicate, new_temp_file from hail.utils.java import Env, info -from hail import ir +from hail.utils.misc import plural -@typecheck(i=Expression, - j=Expression, - keep=bool, - tie_breaker=nullable(func_spec(2, expr_numeric)), - keyed=bool) +@typecheck(i=Expression, j=Expression, keep=bool, tie_breaker=nullable(func_spec(2, expr_numeric)), keyed=bool) def maximal_independent_set(i, j, keep=True, tie_breaker=None, keyed=True) -> Table: """Return a table containing the vertices in a near `maximal independent set <https://en.wikipedia.org/wiki/Maximal_independent_set>`_ @@ -119,19 +113,25 @@ def maximal_independent_set(i, j, keep=True, tie_breaker=None, keyed=True) -> Ta """ if i.dtype != j.dtype: - raise ValueError("'maximal_independent_set' expects arguments `i` and `j` to have same type. " - "Found {} and {}.".format(i.dtype, j.dtype)) + raise ValueError( + "'maximal_independent_set' expects arguments `i` and `j` to have same type. " "Found {} and {}.".format( + i.dtype, j.dtype + ) + ) source = i._indices.source if not isinstance(source, Table): - raise ValueError("'maximal_independent_set' expects an expression of 'Table'. Found {}".format( - "expression of '{}'".format( - source.__class__) if source is not None else 'scalar expression')) + raise ValueError( + "'maximal_independent_set' expects an expression of 'Table'. Found {}".format( + "expression of '{}'".format(source.__class__) if source is not None else 'scalar expression' + ) + ) if i._indices.source != j._indices.source: raise ValueError( "'maximal_independent_set' expects arguments `i` and `j` to be expressions of the same Table. " - "Found\n{}\n{}".format(i, j)) + "Found\n{}\n{}".format(i, j) + ) node_t = i.dtype @@ -151,9 +151,12 @@ def maximal_independent_set(i, j, keep=True, tie_breaker=None, keyed=True) -> Ta edges = t.select(__i=i, __j=j).key_by().select('__i', '__j') edges = edges.checkpoint(new_temp_file()) - mis_nodes = hl.set(construct_expr( - ir.ArrayMaximalIndependentSet(edges.collect(_localize=False)._ir, left_id, right_id, tie_breaker_ir), - hl.tarray(node_t))) + mis_nodes = hl.set( + construct_expr( + ir.ArrayMaximalIndependentSet(edges.collect(_localize=False)._ir, left_id, right_id, tie_breaker_ir), + hl.tarray(node_t), + ) + ) nodes = edges.select(node=[edges.__i, edges.__j]) nodes = nodes.explode(nodes.node) @@ -167,18 +170,23 @@ def maximal_independent_set(i, j, keep=True, tie_breaker=None, keyed=True) -> Ta def require_col_key_str(dataset: MatrixTable, method: str): if not len(dataset.col_key) == 1 or dataset[next(iter(dataset.col_key))].dtype != hl.tstr: - raise ValueError(f"Method '{method}' requires column key to be one field of type 'str', found " - f"{list(str(x.dtype) for x in dataset.col_key.values())}") + raise ValueError( + f"Method '{method}' requires column key to be one field of type 'str', found " + f"{list(str(x.dtype) for x in dataset.col_key.values())}" + ) def require_table_key_variant(ht, method): - if (list(ht.key) != ['locus', 'alleles'] - or not isinstance(ht['locus'].dtype, tlocus) - or not ht['alleles'].dtype == tarray(tstr)): - raise ValueError("Method '{}' requires key to be two fields 'locus' (type 'locus<any>') and " - "'alleles' (type 'array<str>')\n" - " Found:{}".format(method, ''.join( - "\n '{}': {}".format(k, str(ht[k].dtype)) for k in ht.key))) + if ( + list(ht.key) != ['locus', 'alleles'] + or not isinstance(ht['locus'].dtype, tlocus) + or not ht['alleles'].dtype ==
tarray(tstr) + ): + raise ValueError( + "Method '{}' requires key to be two fields 'locus' (type 'locus<any>') and " + "'alleles' (type 'array<str>')\n" + " Found:{}".format(method, ''.join("\n '{}': {}".format(k, str(ht[k].dtype)) for k in ht.key)) + ) def require_row_key_variant(dataset, method): @@ -187,36 +195,46 @@ def require_row_key_variant(dataset, method): else: assert isinstance(dataset, MatrixTable) key = dataset.row_key - if (list(key) != ['locus', 'alleles'] - or not isinstance(dataset['locus'].dtype, tlocus) - or not dataset['alleles'].dtype == tarray(tstr)): - raise ValueError("Method '{}' requires row key to be two fields 'locus' (type 'locus<any>') and " - "'alleles' (type 'array<str>')\n" - " Found:{}".format(method, ''.join( - "\n '{}': {}".format(k, str(dataset[k].dtype)) for k in key))) + if ( + list(key) != ['locus', 'alleles'] + or not isinstance(dataset['locus'].dtype, tlocus) + or not dataset['alleles'].dtype == tarray(tstr) + ): + raise ValueError( + "Method '{}' requires row key to be two fields 'locus' (type 'locus<any>') and " + "'alleles' (type 'array<str>')\n" + " Found:{}".format(method, ''.join("\n '{}': {}".format(k, str(dataset[k].dtype)) for k in key)) + ) def require_alleles_field(dataset, method): if 'alleles' not in dataset.row: - raise ValueError( - f"Method '{method}' requires a field 'alleles' (type 'array<str>')\n") + raise ValueError(f"Method '{method}' requires a field 'alleles' (type 'array<str>')\n") if dataset.alleles.dtype != tarray(tstr): raise ValueError( f"Method '{method}' requires a field 'alleles' (type 'array<str>')\n" f" Found:\n" - f" 'alleles': {dataset.alleles.dtype}") + f" 'alleles': {dataset.alleles.dtype}" + ) def require_row_key_variant_w_struct_locus(dataset, method): - if (list(dataset.row_key) != ['locus', 'alleles'] + if ( + list(dataset.row_key) != ['locus', 'alleles'] or not dataset['alleles'].dtype == tarray(tstr) - or (not isinstance(dataset['locus'].dtype, tlocus) - and dataset['locus'].dtype != hl.dtype('struct{contig: str, position: int32}'))): - raise ValueError("Method '{}' requires row key to be two fields 'locus'" - " (type 'locus<any>' or 'struct{{contig: str, position: int32}}') and " - "'alleles' (type 'array<str>')\n" - " Found:{}".format(method, ''.join( - "\n '{}': {}".format(k, str(dataset[k].dtype)) for k in dataset.row_key))) + or ( + not isinstance(dataset['locus'].dtype, tlocus) + and dataset['locus'].dtype != hl.dtype('struct{contig: str, position: int32}') + ) + ): + raise ValueError( + "Method '{}' requires row key to be two fields 'locus'" + " (type 'locus<any>' or 'struct{{contig: str, position: int32}}') and " + "'alleles' (type 'array<str>')\n" + " Found:{}".format( + method, ''.join("\n '{}': {}".format(k, str(dataset[k].dtype)) for k in dataset.row_key) + ) + ) def require_first_key_field_locus(dataset, method): @@ -225,11 +243,12 @@ def require_first_key_field_locus(dataset, method): else: assert isinstance(dataset, MatrixTable) key = dataset.row_key - if (len(key) == 0 - or not isinstance(key[0].dtype, tlocus)): - raise ValueError("Method '{}' requires first key field of type 'locus<any>'.\n" - " Found:{}".format(method, ''.join( - "\n '{}': {}".format(k, str(dataset[k].dtype)) for k in key))) + if len(key) == 0 or not isinstance(key[0].dtype, tlocus): + raise ValueError( + "Method '{}' requires first key field of type 'locus<any>'.\n" " Found:{}".format( + method, ''.join("\n '{}': {}".format(k, str(dataset[k].dtype)) for k in key) + ) + ) @typecheck(table=Table, method=str) @@ -244,11 +263,17 @@ def require_biallelic(dataset, method, tolerate_generic_locus: bool
= False) -> require_row_key_variant_w_struct_locus(dataset, method) else: require_row_key_variant(dataset, method) - return dataset._select_rows(method, - hl.case() - .when(dataset.alleles.length() == 2, dataset._rvrow) - .or_error(f"'{method}' expects biallelic variants ('alleles' field of length 2), found " - + hl.str(dataset.locus) + ", " + hl.str(dataset.alleles))) + return dataset._select_rows( + method, + hl.case() + .when(dataset.alleles.length() == 2, dataset._rvrow) + .or_error( + f"'{method}' expects biallelic variants ('alleles' field of length 2), found " + + hl.str(dataset.locus) + + ", " + + hl.str(dataset.alleles) + ), + ) @typecheck(dataset=MatrixTable, name=str) @@ -292,16 +317,16 @@ def rename_duplicates(dataset, name='unique_id') -> MatrixTable: mapping, new_ids = deduplicate(ids) if mapping: - info(f'Renamed {len(mapping)} duplicate {plural("sample ID", len(mapping))}. Mangled IDs as follows:' - + ''.join(f'\n "{pre}" => "{post}"' for pre, post in mapping)) + info( + f'Renamed {len(mapping)} duplicate {plural("sample ID", len(mapping))}. Mangled IDs as follows:' + + ''.join(f'\n "{pre}" => "{post}"' for pre, post in mapping) + ) else: info('No duplicate sample IDs found.') return dataset.annotate_cols(**{name: hl.literal(new_ids)[hl.int(hl.scan.count())]}) -@typecheck(ds=oneof(Table, MatrixTable), - intervals=expr_array(expr_interval(expr_any)), - keep=bool) +@typecheck(ds=oneof(Table, MatrixTable), intervals=expr_array(expr_interval(expr_any)), keep=bool) def filter_intervals(ds, intervals, keep=True) -> Union[Table, MatrixTable]: """Filter rows with a list of intervals. @@ -354,7 +379,7 @@ def filter_intervals(ds, intervals, keep=True) -> Union[Table, MatrixTable]: point_type = intervals.dtype.element_type.point_type def is_struct_prefix(partial, full): - if list(partial) != list(full)[:len(partial)]: + if list(partial) != list(full)[: len(partial)]: return False for k, v in partial.items(): if full[k] != v: @@ -369,17 +394,21 @@ def is_struct_prefix(partial, full): needs_wrapper = False else: raise TypeError( - "The point type is incompatible with key type of the dataset ('{}', '{}')".format(repr(point_type), - repr(k_type))) + "The point type is incompatible with key type of the dataset ('{}', '{}')".format( + repr(point_type), repr(k_type) + ) + ) def wrap_input(interval): if interval is None: raise TypeError("'filter_intervals' does not allow missing values in 'intervals'.") elif needs_wrapper: - return Interval(Struct(**{k_name: interval.start}), - Struct(**{k_name: interval.end}), - interval.includes_start, - interval.includes_end) + return Interval( + Struct(**{k_name: interval.start}), + Struct(**{k_name: interval.end}), + interval.includes_start, + interval.includes_end, + ) else: return interval @@ -392,8 +421,7 @@ def wrap_input(interval): return Table(ir.TableFilterIntervals(ds._tir, intervals, point_type, keep)) -@typecheck(ht=Table, - points=oneof(Table, expr_array(expr_any))) +@typecheck(ht=Table, points=oneof(Table, expr_array(expr_any))) def segment_intervals(ht, points): """Segment the interval keys of `ht` at a given set of points. 
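For the `filter_intervals` hunk above, a hedged usage sketch on a toy dataset; the interval and dataset sizes are illustrative only:

    import hail as hl

    # Toy dataset keyed by locus/alleles; keep only rows falling in one interval.
    mt = hl.balding_nichols_model(n_populations=3, n_samples=25, n_variants=50)
    intervals = [hl.parse_locus_interval('1:1-100000', reference_genome='GRCh37')]
    mt = hl.filter_intervals(mt, intervals, keep=True)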
@@ -413,14 +441,18 @@ def segment_intervals(ht, points): point_type = ht.key[0].dtype.point_type if isinstance(points, Table): if len(points.key) != 1 or points.key[0].dtype != point_type: - raise ValueError("'segment_intervals' expects points to be a table with a single" - " key of the same type as the intervals in 'ht', or an array of those points:" - f"\n expect {point_type}, found {list(points.key.dtype.values())}") + raise ValueError( + "'segment_intervals' expects points to be a table with a single" + " key of the same type as the intervals in 'ht', or an array of those points:" + f"\n expect {point_type}, found {list(points.key.dtype.values())}" + ) points = hl.array(hl.set(points.collect(_localize=False))) if points.dtype.element_type != point_type: - raise ValueError(f"'segment_intervals' expects points to be a table with a single" - f" key of the same type as the intervals in 'ht', or an array of those points:" - f"\n expect {point_type}, found {points.dtype.element_type}") + raise ValueError( + f"'segment_intervals' expects points to be a table with a single" + f" key of the same type as the intervals in 'ht', or an array of those points:" + f"\n expect {point_type}, found {points.dtype.element_type}" + ) points = hl._sort_by(points, lambda l, r: hl._compare(l, r) < 0) @@ -433,18 +465,28 @@ def segment_intervals(ht, points): n_points = hl.len(points) lower = hl.if_else((lower < n_points) & (points[lower] == interval.start), lower + 1, lower) higher = hl.if_else((higher < n_points) & (points[higher] == interval.end), higher - 1, higher) - interval_results = hl.rbind(lower, higher, - lambda lower, higher: hl.if_else( - lower >= higher, - [interval], - hl.flatten([ - [hl.interval(interval.start, points[lower], - includes_start=interval.includes_start, includes_end=False)], - hl.range(lower, higher - 1).map( - lambda x: hl.interval(points[x], points[x + 1], includes_start=True, - includes_end=False)), - [hl.interval(points[higher - 1], interval.end, includes_start=True, - includes_end=interval.includes_end)], - ]))) + interval_results = hl.rbind( + lower, + higher, + lambda lower, higher: hl.if_else( + lower >= higher, + [interval], + hl.flatten([ + [ + hl.interval( + interval.start, points[lower], includes_start=interval.includes_start, includes_end=False + ) + ], + hl.range(lower, higher - 1).map( + lambda x: hl.interval(points[x], points[x + 1], includes_start=True, includes_end=False) + ), + [ + hl.interval( + points[higher - 1], interval.end, includes_start=True, includes_end=interval.includes_end + ) + ], + ]), + ), + ) ht = ht.annotate(__new_intervals=interval_results, lower=lower, higher=higher).explode('__new_intervals') - return ht.key_by(**{list(ht.key)[0]: ht.__new_intervals}).drop('__new_intervals') + return ht.key_by(**{next(iter(ht.key)): ht.__new_intervals}).drop('__new_intervals') diff --git a/hail/python/hail/methods/pca.py b/hail/python/hail/methods/pca.py index b14eee8d079..dcb991957b2 100644 --- a/hail/python/hail/methods/pca.py +++ b/hail/python/hail/methods/pca.py @@ -2,22 +2,20 @@ import hail as hl import hail.expr.aggregators as agg -from hail.expr.expressions import construct_expr -from hail.expr import (expr_float64, expr_call, raise_unless_entry_indexed, - matrix_table_source) from hail import ir +from hail.experimental import mt_to_table_of_ndarray +from hail.expr import expr_call, expr_float64, matrix_table_source, raise_unless_entry_indexed +from hail.expr.expressions import construct_expr from hail.table import Table -from hail.typecheck import 
(typecheck, oneof, nullable) +from hail.typecheck import nullable, oneof, typecheck from hail.utils import FatalError from hail.utils.java import Env, info -from hail.experimental import mt_to_table_of_ndarray def hwe_normalize(call_expr): mt = matrix_table_source('hwe_normalize/call_expr', call_expr) mt = mt.select_entries(__gt=call_expr.n_alt_alleles()) - mt = mt.annotate_rows(__AC=agg.sum(mt.__gt), - __n_called=agg.count_where(hl.is_defined(mt.__gt))) + mt = mt.annotate_rows(__AC=agg.sum(mt.__gt), __n_called=agg.count_where(hl.is_defined(mt.__gt))) mt = mt.filter_rows((mt.__AC > 0) & (mt.__AC < 2 * mt.__n_called)) n_variants = mt.count_rows() @@ -26,17 +24,14 @@ def hwe_normalize(call_expr): info(f"hwe_normalize: found {n_variants} variants after filtering out monomorphic sites.") mt = mt.annotate_rows(__mean_gt=mt.__AC / mt.__n_called) - mt = mt.annotate_rows( - __hwe_scaled_std_dev=hl.sqrt(mt.__mean_gt * (2 - mt.__mean_gt) * n_variants / 2)) + mt = mt.annotate_rows(__hwe_scaled_std_dev=hl.sqrt(mt.__mean_gt * (2 - mt.__mean_gt) * n_variants / 2)) mt = mt.unfilter_entries() normalized_gt = hl.or_else((mt.__gt - mt.__mean_gt) / mt.__hwe_scaled_std_dev, 0.0) return normalized_gt -@typecheck(call_expr=expr_call, - k=int, - compute_loadings=bool) +@typecheck(call_expr=expr_call, k=int, compute_loadings=bool) def hwe_normalized_pca(call_expr, k=10, compute_loadings=False) -> Tuple[List[float], Table, Table]: r"""Run principal component analysis (PCA) on the Hardy-Weinberg-normalized genotype call matrix. @@ -98,14 +93,10 @@ def hwe_normalized_pca(call_expr, k=10, compute_loadings=False) -> Tuple[List[fl if isinstance(hl.current_backend(), ServiceBackend): return _hwe_normalized_blanczos(call_expr, k, compute_loadings) - return pca(hwe_normalize(call_expr), - k, - compute_loadings) + return pca(hwe_normalize(call_expr), k, compute_loadings) -@typecheck(entry_expr=expr_float64, - k=int, - compute_loadings=bool) +@typecheck(entry_expr=expr_float64, k=int, compute_loadings=bool) def pca(entry_expr, k=10, compute_loadings=False) -> Tuple[List[float], Table, Table]: r"""Run principal component analysis (PCA) on numeric columns derived from a matrix table. 
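To ground the PCA hunks above, a minimal sketch of the public `hwe_normalized_pca` wrapper on a toy dataset (sizes and the number of components are arbitrary):

    import hail as hl

    # Toy genotypes; scores is a Table keyed by the column key of mt.
    mt = hl.balding_nichols_model(n_populations=3, n_samples=100, n_variants=1000)
    eigenvalues, scores, loadings = hl.hwe_normalized_pca(mt.GT, k=5, compute_loadings=False)
    # Join the per-sample scores back onto the columns; loadings is None here.
    mt = mt.annotate_cols(pc_scores=scores[mt.col_key].scores)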
@@ -203,12 +194,11 @@ def pca(entry_expr, k=10, compute_loadings=False) -> Tuple[List[float], Table, T mt = mt.select_entries(**{field: entry_expr}) mt = mt.select_cols().select_rows().select_globals() - t = (Table(ir.MatrixToTableApply(mt._mir, { - 'name': 'PCA', - 'entryField': field, - 'k': k, - 'computeLoadings': compute_loadings - })).persist()) + t = Table( + ir.MatrixToTableApply( + mt._mir, {'name': 'PCA', 'entryField': field, 'k': k, 'computeLoadings': compute_loadings} + ) + ).persist() g = t.index_globals() scores = hl.Table.parallelize(g.scores, key=list(mt.col_key)) @@ -227,7 +217,15 @@ def __init__(self, block_table, block_expr, source_table, col_key): self.source_table = source_table -def _make_tsm(entry_expr, block_size, *, partition_size=None, whiten_window_size=None, whiten_block_size=64, normalize_after_whiten=False): +def _make_tsm( + entry_expr, + block_size, + *, + partition_size=None, + whiten_window_size=None, + whiten_block_size=64, + normalize_after_whiten=False, +): mt = matrix_table_source('_make_tsm/entry_expr', entry_expr) if whiten_window_size is None: @@ -236,10 +234,12 @@ def _make_tsm(entry_expr, block_size, *, partition_size=None, whiten_window_size else: # FIXME: don't whiten across chromosome boundaries A, trailing_blocks_ht, ht = mt_to_table_of_ndarray( - entry_expr, block_size, + entry_expr, + block_size, return_checkpointed_table_also=True, partition_size=partition_size, - window_size=whiten_window_size) + window_size=whiten_window_size, + ) A = A.annotate(ndarray=A.ndarray.T) vec_size = hl.eval(ht.take(1, _localize=False)[0].xs.length()) @@ -256,20 +256,31 @@ def whiten_map_body(part_stream): whiten_window_size, block_size, whiten_block_size, - normalize_after_whiten) + normalize_after_whiten, + ) return construct_expr(stream_ir, part_stream.dtype) + whitened = joined._map_partitions(whiten_map_body) whitened = whitened.annotate(ndarray=whitened.ndarray.T).persist() return TallSkinnyMatrix(whitened, whitened.ndarray, ht, list(mt.col_key)) -def _make_tsm_from_call(call_expr, block_size, *, mean_center=False, hwe_normalize=False, partition_size=None, whiten_window_size=None, whiten_block_size=64, normalize_after_whiten=False): +def _make_tsm_from_call( + call_expr, + block_size, + *, + mean_center=False, + hwe_normalize=False, + partition_size=None, + whiten_window_size=None, + whiten_block_size=64, + normalize_after_whiten=False, +): mt = matrix_table_source('_make_tsm/entry_expr', call_expr) mt = mt.select_entries(__gt=call_expr.n_alt_alleles()) if mean_center or hwe_normalize: - mt = mt.annotate_rows(__AC=agg.sum(mt.__gt), - __n_called=agg.count_where(hl.is_defined(mt.__gt))) + mt = mt.annotate_rows(__AC=agg.sum(mt.__gt), __n_called=agg.count_where(hl.is_defined(mt.__gt))) mt = mt.filter_rows((mt.__AC > 0) & (mt.__AC < 2 * mt.__n_called)) n_variants = mt.count_rows() @@ -283,17 +294,19 @@ def _make_tsm_from_call(call_expr, block_size, *, mean_center=False, hwe_normali mt = mt.select_entries(__x=hl.or_else(mt.__gt - mt.__mean_gt, 0.0)) if hwe_normalize: - mt = mt.annotate_rows( - __hwe_scaled_std_dev=hl.sqrt(mt.__mean_gt * (2 - mt.__mean_gt) / 2)) + mt = mt.annotate_rows(__hwe_scaled_std_dev=hl.sqrt(mt.__mean_gt * (2 - mt.__mean_gt) / 2)) mt = mt.select_entries(__x=mt.__x / mt.__hwe_scaled_std_dev) else: mt = mt.select_entries(__x=mt.__gt) - return _make_tsm(mt.__x, block_size, - partition_size=partition_size, - whiten_window_size=whiten_window_size, - whiten_block_size=whiten_block_size, - normalize_after_whiten=normalize_after_whiten) + return 
_make_tsm( + mt.__x, + block_size, + partition_size=partition_size, + whiten_window_size=whiten_window_size, + whiten_block_size=whiten_block_size, + normalize_after_whiten=normalize_after_whiten, + ) class KrylovFactorization: @@ -321,8 +334,8 @@ def reduced_svd(self, k): return U, S, V def spectral_moments(self, num_moments, R): - eigval_powers = hl.nd.vstack([self.S.map(lambda x: x**(2 * i)) for i in range(1, num_moments + 1)]) - moments = eigval_powers @ (self.V1t[:, :self.k] @ R).map(lambda x: x**2) + eigval_powers = hl.nd.vstack([self.S.map(lambda x: x ** (2 * i)) for i in range(1, num_moments + 1)]) + moments = eigval_powers @ (self.V1t[:, : self.k] @ R).map(lambda x: x**2) means = moments.sum(1) / self.k variances = (moments - means.reshape(-1, 1)).map(lambda x: x**2).sum(1) / (self.k - 1) stdevs = variances.map(lambda x: hl.sqrt(x)) @@ -395,7 +408,7 @@ def _reduced_svd(A: TallSkinnyMatrix, k=10, compute_U=False, iterations=2, itera L = k + 2 else: L = iteration_size - assert((q + 1) * L >= k) + assert (q + 1) * L >= k n = A.ncols # Generate random matrix G @@ -407,11 +420,9 @@ def _reduced_svd(A: TallSkinnyMatrix, k=10, compute_U=False, iterations=2, itera return fact.reduced_svd(k) -@typecheck(A=oneof(expr_float64, TallSkinnyMatrix), - num_moments=int, - p=nullable(int), - moment_samples=int, - block_size=int) +@typecheck( + A=oneof(expr_float64, TallSkinnyMatrix), num_moments=int, p=nullable(int), moment_samples=int, block_size=int +) def _spectral_moments(A, num_moments, p=None, moment_samples=500, block_size=128): if not isinstance(A, TallSkinnyMatrix): raise_unless_entry_indexed('_spectral_moments/entry_expr', A) @@ -434,15 +445,26 @@ def _spectral_moments(A, num_moments, p=None, moment_samples=500, block_size=128 return moments, stdevs -@typecheck(A=oneof(expr_float64, TallSkinnyMatrix), - k=int, - num_moments=int, - compute_loadings=bool, - q_iterations=int, - oversampling_param=nullable(int), - block_size=int, - moment_samples=int) -def _pca_and_moments(A, k=10, num_moments=5, compute_loadings=False, q_iterations=10, oversampling_param=None, block_size=128, moment_samples=100): +@typecheck( + A=oneof(expr_float64, TallSkinnyMatrix), + k=int, + num_moments=int, + compute_loadings=bool, + q_iterations=int, + oversampling_param=nullable(int), + block_size=int, + moment_samples=int, +) +def _pca_and_moments( + A, + k=10, + num_moments=5, + compute_loadings=False, + q_iterations=10, + oversampling_param=None, + block_size=128, + moment_samples=100, +): if not isinstance(A, TallSkinnyMatrix): raise_unless_entry_indexed('_spectral_moments/entry_expr', A) A = _make_tsm(A, block_size) @@ -473,7 +495,9 @@ def _pca_and_moments(A, k=10, num_moments=5, compute_loadings=False, q_iteration fact2 = _krylov_factorization(A, Q1, p, compute_U=False) moments_and_stdevs = fact2.spectral_moments(num_moments, R1) # Add back exact moments - moments = moments_and_stdevs.moments + hl.nd.array([fact.S.map(lambda x: x**(2 * i)).sum() for i in range(1, num_moments + 1)]) + moments = moments_and_stdevs.moments + hl.nd.array([ + fact.S.map(lambda x: x ** (2 * i)).sum() for i in range(1, num_moments + 1) + ]) moments_and_stdevs = hl.eval(hl.struct(moments=moments, stdevs=moments_and_stdevs.stdevs)) moments = moments_and_stdevs.moments stdevs = moments_and_stdevs.stdevs @@ -483,7 +507,9 @@ def _pca_and_moments(A, k=10, num_moments=5, compute_loadings=False, q_iteration info("blanczos_pca: SVD Complete. 
Computing conversion to PCs.") hail_array_scores = scores._data_array() - cols_and_scores = hl.zip(A.source_table.index_globals().cols, hail_array_scores).map(lambda tup: tup[0].annotate(scores=tup[1])) + cols_and_scores = hl.zip(A.source_table.index_globals().cols, hail_array_scores).map( + lambda tup: tup[0].annotate(scores=tup[1]) + ) st = hl.Table.parallelize(cols_and_scores, key=A.col_key) if compute_loadings: @@ -499,15 +525,26 @@ def _pca_and_moments(A, k=10, num_moments=5, compute_loadings=False, q_iteration return eigens, st, lt, moments, stdevs -@typecheck(A=oneof(expr_float64, TallSkinnyMatrix), - k=int, - compute_loadings=bool, - q_iterations=int, - oversampling_param=nullable(int), - block_size=int, - compute_scores=bool, - transpose=bool) -def _blanczos_pca(A, k=10, compute_loadings=False, q_iterations=10, oversampling_param=None, block_size=128, compute_scores=True, transpose=False): +@typecheck( + A=oneof(expr_float64, TallSkinnyMatrix), + k=int, + compute_loadings=bool, + q_iterations=int, + oversampling_param=nullable(int), + block_size=int, + compute_scores=bool, + transpose=bool, +) +def _blanczos_pca( + A, + k=10, + compute_loadings=False, + q_iterations=10, + oversampling_param=None, + block_size=128, + compute_scores=True, + transpose=False, +): r"""Run randomized principal component analysis approximation (PCA) on numeric columns derived from a matrix table. @@ -616,7 +653,9 @@ def numpy_to_rows_table(X, field_name): def numpy_to_cols_table(X, field_name): hail_array = X._data_array() - cols_and_X = hl.zip(A.source_table.index_globals().cols, hail_array).map(lambda tup: tup[0].annotate(**{field_name: tup[1]})) + cols_and_X = hl.zip(A.source_table.index_globals().cols, hail_array).map( + lambda tup: tup[0].annotate(**{field_name: tup[1]}) + ) t = hl.Table.parallelize(cols_and_X, key=A.col_key) return t @@ -637,13 +676,17 @@ def numpy_to_cols_table(X, field_name): return eigens, st, lt -@typecheck(call_expr=expr_call, - k=int, - compute_loadings=bool, - q_iterations=int, - oversampling_param=nullable(int), - block_size=int) -def _hwe_normalized_blanczos(call_expr, k=10, compute_loadings=False, q_iterations=10, oversampling_param=None, block_size=128): +@typecheck( + call_expr=expr_call, + k=int, + compute_loadings=bool, + q_iterations=int, + oversampling_param=nullable(int), + block_size=int, +) +def _hwe_normalized_blanczos( + call_expr, k=10, compute_loadings=False, q_iterations=10, oversampling_param=None, block_size=128 +): r"""Run randomized principal component analysis approximation (PCA) on the Hardy-Weinberg-normalized genotype call matrix. 
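The `_blanczos_pca` and `_pca_and_moments` routines above compute a randomized low-rank SVD. As a conceptual aid only, and not the Hail implementation, here is a toy NumPy sketch of the randomized range-finder with power iterations that this family of methods is built on:

    import numpy as np

    # Stand-in for a normalized genotype-like matrix; sizes are arbitrary.
    rng = np.random.default_rng(0)
    A = rng.standard_normal((5000, 200))

    k, oversample, power_iters = 10, 10, 2
    G = rng.standard_normal((A.shape[1], k + oversample))   # random test matrix
    Q, _ = np.linalg.qr(A @ G)                               # initial range estimate
    for _ in range(power_iters):                             # power iterations sharpen the top spectrum
        Q, _ = np.linalg.qr(A.T @ Q)
        Q, _ = np.linalg.qr(A @ Q)
    B = Q.T @ A                                              # small projected matrix
    U_small, S, Vt = np.linalg.svd(B, full_matrices=False)
    U = Q @ U_small                                          # approximate top singular vectors of A
    U_k, S_k, Vt_k = U[:, :k], S[:k], Vt[:k]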
@@ -679,5 +722,11 @@ def _hwe_normalized_blanczos(call_expr, k=10, compute_loadings=False, q_iteratio raise_unless_entry_indexed('_blanczos_pca/entry_expr', call_expr) A = _make_tsm_from_call(call_expr, block_size, hwe_normalize=True) - return _blanczos_pca(A, k, compute_loadings=compute_loadings, q_iterations=q_iterations, - oversampling_param=oversampling_param, block_size=block_size) + return _blanczos_pca( + A, + k, + compute_loadings=compute_loadings, + q_iterations=q_iterations, + oversampling_param=oversampling_param, + block_size=block_size, + ) diff --git a/hail/python/hail/methods/qc.py b/hail/python/hail/methods/qc.py index 8ac04e0e363..2efd7b37944 100644 --- a/hail/python/hail/methods/qc.py +++ b/hail/python/hail/methods/qc.py @@ -1,35 +1,57 @@ import abc import logging - -import hail as hl -from collections import Counter import os +from collections import Counter from shlex import quote as shq +from typing import Dict, List, Optional, Tuple, Union -from typing import Dict, Tuple, List, Optional, Union - -from hailtop import pip_version -from hailtop.utils import async_to_blocking +import hail as hl import hailtop.batch_client as bc -from hailtop.config import get_deploy_config -from hailtop import yamlx - from hail.backend.service_backend import ServiceBackend -from hail.typecheck import typecheck, oneof, anytype, nullable, numeric +from hail.expr import Float64Expression from hail.expr.expressions.expression_typecheck import expr_float64 +from hail.expr.functions import numeric_allele_type +from hail.expr.types import tarray, tfloat, tint32, tstr, tstruct +from hail.genetics.allele_type import AlleleType +from hail.ir import TableToTableApply +from hail.matrixtable import MatrixTable +from hail.table import Table +from hail.typecheck import anytype, nullable, numeric, oneof, typecheck from hail.utils import FatalError from hail.utils.java import Env, info, warning from hail.utils.misc import divide_null, guess_cloud_spark_provider, new_temp_file -from hail.matrixtable import MatrixTable -from hail.table import Table -from hail.ir import TableToTableApply -from .misc import require_biallelic, require_row_key_variant, require_col_key_str, require_table_key_variant, require_alleles_field +from hailtop import pip_version, yamlx +from hailtop.config import get_deploy_config +from hailtop.utils import async_to_blocking + +from .misc import ( + require_alleles_field, + require_biallelic, + require_col_key_str, + require_row_key_variant, + require_table_key_variant, +) log = logging.getLogger('methods.qc') -HAIL_GENETICS_VEP_GRCH37_85_IMAGE = os.environ.get('HAIL_GENETICS_VEP_GRCH37_85_IMAGE', f'hailgenetics/vep-grch37-85:{pip_version()}') -HAIL_GENETICS_VEP_GRCH38_95_IMAGE = os.environ.get('HAIL_GENETICS_VEP_GRCH38_95_IMAGE', f'hailgenetics/vep-grch38-95:{pip_version()}') +HAIL_GENETICS_VEP_GRCH37_85_IMAGE = os.environ.get( + 'HAIL_GENETICS_VEP_GRCH37_85_IMAGE', f'hailgenetics/vep-grch37-85:{pip_version()}' +) +HAIL_GENETICS_VEP_GRCH38_95_IMAGE = os.environ.get( + 'HAIL_GENETICS_VEP_GRCH38_95_IMAGE', f'hailgenetics/vep-grch38-95:{pip_version()}' +) + + +def _qc_allele_type(ref, alt): + return hl.bind( + lambda at: hl.if_else( + at == AlleleType.SNP, + hl.if_else(hl.is_transition(ref, alt), AlleleType.TRANSITION, AlleleType.TRANSVERSION), + at, + ), + numeric_allele_type(ref, alt), + ) @typecheck(mt=MatrixTable, name=str) @@ -105,25 +127,12 @@ def sample_qc(mt, name='sample_qc') -> MatrixTable: require_row_key_variant(mt, 'sample_qc') - from hail.expr.functions import 
_num_allele_type, _allele_types - - allele_types = _allele_types[:] - allele_types.extend(['Transition', 'Transversion']) - allele_enum = {i: v for i, v in enumerate(allele_types)} - allele_ints = {v: k for k, v in allele_enum.items()} - - def allele_type(ref, alt): - return hl.bind(lambda at: hl.if_else(at == allele_ints['SNP'], - hl.if_else(hl.is_transition(ref, alt), - allele_ints['Transition'], - allele_ints['Transversion']), - at), - _num_allele_type(ref, alt)) - variant_ac = Env.get_uid() variant_atypes = Env.get_uid() - mt = mt.annotate_rows(**{variant_ac: hl.agg.call_stats(mt.GT, mt.alleles).AC, - variant_atypes: mt.alleles[1:].map(lambda alt: allele_type(mt.alleles[0], alt))}) + mt = mt.annotate_rows(**{ + variant_ac: hl.agg.call_stats(mt.GT, mt.alleles).AC, + variant_atypes: mt.alleles[1:].map(lambda alt: _qc_allele_type(mt.alleles[0], alt)), + }) bound_exprs = {} gq_dp_exprs = {} @@ -143,22 +152,31 @@ def has_field_of_type(name, dtype): bound_exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT'])) bound_exprs['n_not_called'] = hl.agg.count_where(hl.is_missing(mt['GT'])) - n_rows_ref = hl.expr.construct_expr(hl.ir.Ref('n_rows', hl.tint64), hl.tint64, mt._col_indices, - hl.utils.LinkedList(hl.expr.expressions.Aggregation)) + n_rows_ref = hl.expr.construct_expr( + hl.ir.Ref('n_rows', hl.tint64), hl.tint64, mt._col_indices, hl.utils.LinkedList(hl.expr.expressions.Aggregation) + ) bound_exprs['n_filtered'] = n_rows_ref - hl.agg.count() bound_exprs['n_hom_ref'] = hl.agg.count_where(mt['GT'].is_hom_ref()) bound_exprs['n_het'] = hl.agg.count_where(mt['GT'].is_het()) - bound_exprs['n_singleton'] = hl.agg.sum(hl.rbind(mt['GT'], lambda gt: hl.sum( - hl.range(0, gt.ploidy).map(lambda i: hl.rbind(gt[i], lambda gti: (gti != 0) & (mt[variant_ac][gti] == 1)))))) + bound_exprs['n_singleton'] = hl.agg.sum( + hl.rbind( + mt['GT'], + lambda gt: hl.sum( + hl.range(0, gt.ploidy).map( + lambda i: hl.rbind(gt[i], lambda gti: (gti != 0) & (mt[variant_ac][gti] == 1)) + ) + ), + ) + ) bound_exprs['allele_type_counts'] = hl.agg.explode( - lambda allele_type: hl.tuple( - hl.agg.count_where(allele_type == i) for i in range(len(allele_ints)) + lambda allele_type: hl.tuple(hl.agg.count_where(allele_type == i) for i in range(len(AlleleType))), + ( + hl.range(0, mt['GT'].ploidy) + .map(lambda i: mt['GT'][i]) + .filter(lambda allele_idx: allele_idx > 0) + .map(lambda allele_idx: mt[variant_atypes][allele_idx - 1]) ), - (hl.range(0, mt['GT'].ploidy) - .map(lambda i: mt['GT'][i]) - .filter(lambda allele_idx: allele_idx > 0) - .map(lambda allele_idx: mt[variant_atypes][allele_idx - 1])) ) result_struct = hl.rbind( @@ -175,19 +193,20 @@ def has_field_of_type(name, dtype): 'n_hom_var': x.n_called - x.n_hom_ref - x.n_het, 'n_non_ref': x.n_called - x.n_hom_ref, 'n_singleton': x.n_singleton, - 'n_snp': (x.allele_type_counts[allele_ints["Transition"]] - + x.allele_type_counts[allele_ints["Transversion"]]), - 'n_insertion': x.allele_type_counts[allele_ints["Insertion"]], - 'n_deletion': x.allele_type_counts[allele_ints["Deletion"]], - 'n_transition': x.allele_type_counts[allele_ints["Transition"]], - 'n_transversion': x.allele_type_counts[allele_ints["Transversion"]], - 'n_star': x.allele_type_counts[allele_ints["Star"]], + 'n_snp': x.allele_type_counts[AlleleType.TRANSITION] + x.allele_type_counts[AlleleType.TRANSVERSION], + 'n_insertion': x.allele_type_counts[AlleleType.INSERTION], + 'n_deletion': x.allele_type_counts[AlleleType.DELETION], + 'n_transition': x.allele_type_counts[AlleleType.TRANSITION], + 
'n_transversion': x.allele_type_counts[AlleleType.TRANSVERSION], + 'n_star': x.allele_type_counts[AlleleType.STAR], }), lambda s: s.annotate( r_ti_tv=divide_null(hl.float64(s.n_transition), s.n_transversion), r_het_hom_var=divide_null(hl.float64(s.n_het), s.n_hom_var), - r_insertion_deletion=divide_null(hl.float64(s.n_insertion), s.n_deletion) - ))) + r_insertion_deletion=divide_null(hl.float64(s.n_insertion), s.n_deletion), + ), + ), + ) mt = mt.annotate_cols(**{name: result_struct}) mt = mt.drop(variant_ac, variant_atypes) @@ -290,44 +309,53 @@ def has_field_of_type(name, dtype): bound_exprs['n_called'] = hl.agg.count_where(hl.is_defined(mt['GT'])) bound_exprs['n_not_called'] = hl.agg.count_where(hl.is_missing(mt['GT'])) - n_cols_ref = hl.expr.construct_expr(hl.ir.Ref('n_cols', hl.tint32), hl.tint32, - mt._row_indices, hl.utils.LinkedList(hl.expr.expressions.Aggregation)) + n_cols_ref = hl.expr.construct_expr( + hl.ir.Ref('n_cols', hl.tint32), hl.tint32, mt._row_indices, hl.utils.LinkedList(hl.expr.expressions.Aggregation) + ) bound_exprs['n_filtered'] = hl.int64(n_cols_ref) - hl.agg.count() bound_exprs['call_stats'] = hl.agg.call_stats(mt.GT, mt.alleles) - result = hl.rbind(hl.struct(**bound_exprs), - lambda e1: hl.rbind( - hl.case().when( - hl.len(mt.alleles) == 2, - (hl.hardy_weinberg_test(e1.call_stats.homozygote_count[0], - e1.call_stats.AC[1] - 2 - * e1.call_stats.homozygote_count[1], - e1.call_stats.homozygote_count[1]), - hl.hardy_weinberg_test(e1.call_stats.homozygote_count[0], - e1.call_stats.AC[1] - 2 - * e1.call_stats.homozygote_count[1], - e1.call_stats.homozygote_count[1], - one_sided=True)) - ).or_missing(), - lambda hwe: hl.struct(**{ - **gq_dp_exprs, - **e1.call_stats, - 'call_rate': hl.float(e1.n_called) / (e1.n_called + e1.n_not_called + e1.n_filtered), - 'n_called': e1.n_called, - 'n_not_called': e1.n_not_called, - 'n_filtered': e1.n_filtered, - 'n_het': e1.n_called - hl.sum(e1.call_stats.homozygote_count), - 'n_non_ref': e1.n_called - e1.call_stats.homozygote_count[0], - 'het_freq_hwe': hwe[0].het_freq_hwe, - 'p_value_hwe': hwe[0].p_value, - 'p_value_excess_het': hwe[1].p_value}))) + result = hl.rbind( + hl.struct(**bound_exprs), + lambda e1: hl.rbind( + hl.case() + .when( + hl.len(mt.alleles) == 2, + ( + hl.hardy_weinberg_test( + e1.call_stats.homozygote_count[0], + e1.call_stats.AC[1] - 2 * e1.call_stats.homozygote_count[1], + e1.call_stats.homozygote_count[1], + ), + hl.hardy_weinberg_test( + e1.call_stats.homozygote_count[0], + e1.call_stats.AC[1] - 2 * e1.call_stats.homozygote_count[1], + e1.call_stats.homozygote_count[1], + one_sided=True, + ), + ), + ) + .or_missing(), + lambda hwe: hl.struct(**{ + **gq_dp_exprs, + **e1.call_stats, + 'call_rate': hl.float(e1.n_called) / (e1.n_called + e1.n_not_called + e1.n_filtered), + 'n_called': e1.n_called, + 'n_not_called': e1.n_not_called, + 'n_filtered': e1.n_filtered, + 'n_het': e1.n_called - hl.sum(e1.call_stats.homozygote_count), + 'n_non_ref': e1.n_called - e1.call_stats.homozygote_count[0], + 'het_freq_hwe': hwe[0].het_freq_hwe, + 'p_value_hwe': hwe[0].p_value, + 'p_value_excess_het': hwe[1].p_value, + }), + ), + ) return mt.annotate_rows(**{name: result}) -@typecheck(left=MatrixTable, - right=MatrixTable, - _localize_global_statistics=bool) +@typecheck(left=MatrixTable, right=MatrixTable, _localize_global_statistics=bool) def concordance(left, right, *, _localize_global_statistics=True) -> Tuple[List[List[int]], Table, Table]: """Calculate call concordance with another dataset. 
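For the `concordance` hunk above, a hedged usage sketch on toy data; `concordance` expects biallelic variants and string sample IDs, so the toy datasets are re-keyed first:

    import hail as hl

    # Two toy call sets with string sample IDs; summary is a 5x5 list of joint
    # genotype counts, plus per-sample and per-variant concordance tables.
    left = hl.balding_nichols_model(n_populations=1, n_samples=20, n_variants=100)
    left = left.key_cols_by(s=hl.str(left.sample_idx))
    right = hl.balding_nichols_model(n_populations=1, n_samples=20, n_variants=100)
    right = right.key_cols_by(s=hl.str(right.sample_idx))
    summary, per_sample, per_variant = hl.concordance(left, right)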
@@ -450,14 +478,16 @@ def concordance(left, right, *, _localize_global_statistics=True) -> Tuple[List[ left_bad = [f'{k!r}: {v}' for k, v in left_sample_counter.items() if v > 1] right_bad = [f'{k!r}: {v}' for k, v in right_sample_counter.items() if v > 1] if left_bad or right_bad: - raise ValueError(f"Found duplicate sample IDs:\n" - f" left: {', '.join(left_bad)}\n" - f" right: {', '.join(right_bad)}") + raise ValueError( + f"Found duplicate sample IDs:\n" f" left: {', '.join(left_bad)}\n" f" right: {', '.join(right_bad)}" + ) included = set(left_sample_counter.keys()).intersection(set(right_sample_counter.keys())) - info(f"concordance: including {len(included)} shared samples " - f"({len(left_sample_counter)} total on left, {len(right_sample_counter)} total on right)") + info( + f"concordance: including {len(included)} shared samples " + f"({len(left_sample_counter)} total on left, {len(right_sample_counter)} total on right)" + ) left = require_biallelic(left, 'concordance, left') right = require_biallelic(right, 'concordance, right') @@ -472,10 +502,7 @@ def concordance(left, right, *, _localize_global_statistics=True) -> Tuple[List[ joined = hl.experimental.full_outer_join_mt(left, right) def get_idx(struct): - return hl.if_else( - hl.is_missing(struct), - 0, - hl.coalesce(2 + struct.GT.n_alt_alleles(), 1)) + return hl.if_else(hl.is_missing(struct), 0, hl.coalesce(2 + struct.GT.n_alt_alleles(), 1)) aggr = hl.agg.counter(get_idx(joined.left_entry) + 5 * get_idx(joined.right_entry)) @@ -492,7 +519,8 @@ def n_discordant(counter): return hl.sum( hl.array(counter) .filter(lambda tup: hl.literal(discordant_indices).contains(tup[0])) - .map(lambda tup: tup[1])) + .map(lambda tup: tup[1]) + ) glob = joined.aggregate_entries(concordance_array(aggr), _localize=_localize_global_statistics) if _localize_global_statistics: @@ -503,138 +531,159 @@ def n_discordant(counter): info(f"concordance: total concordance {pct:.2f}%") per_variant = joined.annotate_rows(concordance=aggr) - per_variant = per_variant.select_rows(concordance=concordance_array(per_variant.concordance), - n_discordant=n_discordant(per_variant.concordance)) + per_variant = per_variant.select_rows( + concordance=concordance_array(per_variant.concordance), n_discordant=n_discordant(per_variant.concordance) + ) per_sample = joined.annotate_cols(concordance=aggr) - per_sample = per_sample.select_cols(concordance=concordance_array(per_sample.concordance), - n_discordant=n_discordant(per_sample.concordance)) + per_sample = per_sample.select_cols( + concordance=concordance_array(per_sample.concordance), n_discordant=n_discordant(per_sample.concordance) + ) return glob, per_sample.cols(), per_variant.rows() -vep_json_typ = hl.tstruct( - assembly_name=hl.tstr, - allele_string=hl.tstr, - ancestral=hl.tstr, - colocated_variants=hl.tarray(hl.tstruct( - aa_allele=hl.tstr, - aa_maf=hl.tfloat, - afr_allele=hl.tstr, - afr_maf=hl.tfloat, - allele_string=hl.tstr, - amr_allele=hl.tstr, - amr_maf=hl.tfloat, - clin_sig=hl.tarray(hl.tstr), - end=hl.tint32, - eas_allele=hl.tstr, - eas_maf=hl.tfloat, - ea_allele=hl.tstr, - ea_maf=hl.tfloat, - eur_allele=hl.tstr, - eur_maf=hl.tfloat, - exac_adj_allele=hl.tstr, - exac_adj_maf=hl.tfloat, - exac_allele=hl.tstr, - exac_afr_allele=hl.tstr, - exac_afr_maf=hl.tfloat, - exac_amr_allele=hl.tstr, - exac_amr_maf=hl.tfloat, - exac_eas_allele=hl.tstr, - exac_eas_maf=hl.tfloat, - exac_fin_allele=hl.tstr, - exac_fin_maf=hl.tfloat, - exac_maf=hl.tfloat, - exac_nfe_allele=hl.tstr, - exac_nfe_maf=hl.tfloat, - 
exac_oth_allele=hl.tstr, - exac_oth_maf=hl.tfloat, - exac_sas_allele=hl.tstr, - exac_sas_maf=hl.tfloat, - id=hl.tstr, - minor_allele=hl.tstr, - minor_allele_freq=hl.tfloat, - phenotype_or_disease=hl.tint32, - pubmed=hl.tarray(hl.tint32), - sas_allele=hl.tstr, - sas_maf=hl.tfloat, - somatic=hl.tint32, - start=hl.tint32, - strand=hl.tint32)), - context=hl.tstr, - end=hl.tint32, - id=hl.tstr, - input=hl.tstr, - intergenic_consequences=hl.tarray(hl.tstruct(allele_num=hl.tint32, - consequence_terms=hl.tarray(hl.tstr), - impact=hl.tstr, - minimised=hl.tint32, - variant_allele=hl.tstr)), - most_severe_consequence=hl.tstr, - motif_feature_consequences=hl.tarray(hl.tstruct(allele_num=hl.tint32, - consequence_terms=hl.tarray(hl.tstr), - high_inf_pos=hl.tstr, - impact=hl.tstr, - minimised=hl.tint32, - motif_feature_id=hl.tstr, - motif_name=hl.tstr, - motif_pos=hl.tint32, - motif_score_change=hl.tfloat, - strand=hl.tint32, - variant_allele=hl.tstr)), - regulatory_feature_consequences=hl.tarray(hl.tstruct(allele_num=hl.tint32, - biotype=hl.tstr, - consequence_terms=hl.tarray(hl.tstr), - impact=hl.tstr, - minimised=hl.tint32, - regulatory_feature_id=hl.tstr, - variant_allele=hl.tstr)), - seq_region_name=hl.tstr, - start=hl.tint32, - strand=hl.tint32, - transcript_consequences=hl.tarray(hl.tstruct(allele_num=hl.tint32, - amino_acids=hl.tstr, - biotype=hl.tstr, - canonical=hl.tint32, - ccds=hl.tstr, - cdna_start=hl.tint32, - cdna_end=hl.tint32, - cds_end=hl.tint32, - cds_start=hl.tint32, - codons=hl.tstr, - consequence_terms=hl.tarray(hl.tstr), - distance=hl.tint32, - domains=hl.tarray(hl.tstruct(db=hl.tstr, - name=hl.tstr)), - exon=hl.tstr, - gene_id=hl.tstr, - gene_pheno=hl.tint32, - gene_symbol=hl.tstr, - gene_symbol_source=hl.tstr, - hgnc_id=hl.tstr, - hgvsc=hl.tstr, - hgvsp=hl.tstr, - hgvs_offset=hl.tint32, - impact=hl.tstr, - intron=hl.tstr, - lof=hl.tstr, - lof_flags=hl.tstr, - lof_filter=hl.tstr, - lof_info=hl.tstr, - minimised=hl.tint32, - polyphen_prediction=hl.tstr, - polyphen_score=hl.tfloat, - protein_end=hl.tint32, - protein_start=hl.tint32, - protein_id=hl.tstr, - sift_prediction=hl.tstr, - sift_score=hl.tfloat, - strand=hl.tint32, - swissprot=hl.tstr, - transcript_id=hl.tstr, - trembl=hl.tstr, - uniparc=hl.tstr, - variant_allele=hl.tstr)), - variant_class=hl.tstr) +vep_json_typ = tstruct( + assembly_name=tstr, + allele_string=tstr, + ancestral=tstr, + colocated_variants=tarray( + tstruct( + aa_allele=tstr, + aa_maf=tfloat, + afr_allele=tstr, + afr_maf=tfloat, + allele_string=tstr, + amr_allele=tstr, + amr_maf=tfloat, + clin_sig=tarray(tstr), + end=tint32, + eas_allele=tstr, + eas_maf=tfloat, + ea_allele=tstr, + ea_maf=tfloat, + eur_allele=tstr, + eur_maf=tfloat, + exac_adj_allele=tstr, + exac_adj_maf=tfloat, + exac_allele=tstr, + exac_afr_allele=tstr, + exac_afr_maf=tfloat, + exac_amr_allele=tstr, + exac_amr_maf=tfloat, + exac_eas_allele=tstr, + exac_eas_maf=tfloat, + exac_fin_allele=tstr, + exac_fin_maf=tfloat, + exac_maf=tfloat, + exac_nfe_allele=tstr, + exac_nfe_maf=tfloat, + exac_oth_allele=tstr, + exac_oth_maf=tfloat, + exac_sas_allele=tstr, + exac_sas_maf=tfloat, + id=tstr, + minor_allele=tstr, + minor_allele_freq=tfloat, + phenotype_or_disease=tint32, + pubmed=tarray(tint32), + sas_allele=tstr, + sas_maf=tfloat, + somatic=tint32, + start=tint32, + strand=tint32, + ) + ), + context=tstr, + end=tint32, + id=tstr, + input=tstr, + intergenic_consequences=tarray( + tstruct( + allele_num=tint32, + consequence_terms=tarray(tstr), + impact=tstr, + minimised=tint32, + variant_allele=tstr, + 
) + ), + most_severe_consequence=tstr, + motif_feature_consequences=tarray( + tstruct( + allele_num=tint32, + consequence_terms=tarray(tstr), + high_inf_pos=tstr, + impact=tstr, + minimised=tint32, + motif_feature_id=tstr, + motif_name=tstr, + motif_pos=tint32, + motif_score_change=tfloat, + strand=tint32, + variant_allele=tstr, + ) + ), + regulatory_feature_consequences=tarray( + tstruct( + allele_num=tint32, + biotype=tstr, + consequence_terms=tarray(tstr), + impact=tstr, + minimised=tint32, + regulatory_feature_id=tstr, + variant_allele=tstr, + ) + ), + seq_region_name=tstr, + start=tint32, + strand=tint32, + transcript_consequences=tarray( + tstruct( + allele_num=tint32, + amino_acids=tstr, + biotype=tstr, + canonical=tint32, + ccds=tstr, + cdna_start=tint32, + cdna_end=tint32, + cds_end=tint32, + cds_start=tint32, + codons=tstr, + consequence_terms=tarray(tstr), + distance=tint32, + domains=tarray(tstruct(db=tstr, name=tstr)), + exon=tstr, + gene_id=tstr, + gene_pheno=tint32, + gene_symbol=tstr, + gene_symbol_source=tstr, + hgnc_id=tstr, + hgvsc=tstr, + hgvsp=tstr, + hgvs_offset=tint32, + impact=tstr, + intron=tstr, + lof=tstr, + lof_flags=tstr, + lof_filter=tstr, + lof_info=tstr, + minimised=tint32, + polyphen_prediction=tstr, + polyphen_score=tfloat, + protein_end=tint32, + protein_start=tint32, + protein_id=tstr, + sift_prediction=tstr, + sift_score=tfloat, + strand=tint32, + swissprot=tstr, + transcript_id=tstr, + trembl=tstr, + uniparc=tstr, + variant_allele=tstr, + ) + ), + variant_class=tstr, +) class VEPConfig(abc.ABC): @@ -710,12 +759,9 @@ def command(self, batch_run_csq_header_command: List[str] @abc.abstractmethod - def command(self, - consequence: bool, - tolerate_parse_error: bool, - part_id: int, - input_file: Optional[str], - output_file: str) -> List[str]: + def command( + self, consequence: bool, tolerate_parse_error: bool, part_id: int, input_file: Optional[str], output_file: str + ) -> List[str]: raise NotImplementedError @@ -734,14 +780,16 @@ class VEPConfigGRCh37Version85(VEPConfig): """ - def __init__(self, - *, - data_bucket: str, - data_mount: str, - image: str, - regions: List[str], - cloud: str, - data_bucket_is_requester_pays: bool): + def __init__( + self, + *, + data_bucket: str, + data_mount: str, + image: str, + regions: List[str], + cloud: str, + data_bucket_is_requester_pays: bool, + ): self.data_bucket = data_bucket self.data_mount = data_mount self.image = image @@ -753,16 +801,18 @@ def __init__(self, self.batch_run_csq_header_command = ['python3', '/hail-vep/vep.py', 'csq_header'] self.json_typ = vep_json_typ - def command(self, - *, - consequence: bool, - tolerate_parse_error: bool, - part_id: int, - input_file: Optional[str], - output_file: str) -> str: + def command( + self, + *, + consequence: bool, + tolerate_parse_error: bool, + part_id: int, + input_file: Optional[str], + output_file: str, + ) -> str: vcf_or_json = '--vcf' if consequence else '--json' input_file = f'--input_file {input_file}' if input_file else '' - return f'''/vep/vep {input_file} \ + return f"""/vep/vep {input_file} \ --format vcf \ {vcf_or_json} \ --everything \ @@ -775,7 +825,7 @@ def command(self, --dir={self.data_mount} \ --plugin LoF,human_ancestor_fa:{self.data_mount}/loftee_data/human_ancestor.fa.gz,filter_position:0.05,min_intron_size:15,conservation_file:{self.data_mount}/loftee_data/phylocsf_gerp.sql,gerp_file:{self.data_mount}/loftee_data/GERP_scores.final.sorted.txt.gz \ -o STDOUT -''' +""" class VEPConfigGRCh38Version95(VEPConfig): @@ -793,14 +843,16 @@ class 
VEPConfigGRCh38Version95(VEPConfig): """ - def __init__(self, - *, - data_bucket: str, - data_mount: str, - image: str, - regions: List[str], - cloud: str, - data_bucket_is_requester_pays: bool): + def __init__( + self, + *, + data_bucket: str, + data_mount: str, + image: str, + regions: List[str], + cloud: str, + data_bucket_is_requester_pays: bool, + ): self.data_bucket = data_bucket self.data_mount = data_mount self.image = image @@ -810,23 +862,28 @@ def __init__(self, self.cloud = cloud self.batch_run_command = ['python3', '/hail-vep/vep.py', 'vep'] self.batch_run_csq_header_command = ['python3', '/hail-vep/vep.py', 'csq_header'] - self.json_typ = vep_json_typ._insert_field('transcript_consequences', hl.tarray( - vep_json_typ['transcript_consequences'].element_type._insert_fields( - appris=hl.tstr, - tsl=hl.tint32, - ) - )) - - def command(self, - *, - consequence: bool, - tolerate_parse_error: bool, - part_id: int, - input_file: Optional[str], - output_file: str) -> str: + self.json_typ = vep_json_typ._insert_field( + 'transcript_consequences', + tarray( + vep_json_typ['transcript_consequences'].element_type._insert_fields( + appris=tstr, + tsl=tint32, + ) + ), + ) + + def command( + self, + *, + consequence: bool, + tolerate_parse_error: bool, + part_id: int, + input_file: Optional[str], + output_file: str, + ) -> str: vcf_or_json = '--vcf' if consequence else '--json' input_file = f'--input_file {input_file}' if input_file else '' - return f'''/vep/vep {input_file} \ + return f"""/vep/vep {input_file} \ --format vcf \ {vcf_or_json} \ --everything \ @@ -836,12 +893,12 @@ def command(self, --offline \ --minimal \ --assembly GRCh38 \ ---fasta {self.data_mount}/homo_sapiens/95_GRCh38/Homo_sapiens.GRCh38.dna.toplevel.fa.gz \ +--fasta {self.data_mount}homo_sapiens/95_GRCh38/Homo_sapiens.GRCh38.dna.toplevel.fa.gz \ --plugin "LoF,loftee_path:/vep/ensembl-vep/Plugins/,gerp_bigwig:{self.data_mount}/gerp_conservation_scores.homo_sapiens.GRCh38.bw,human_ancestor_fa:{self.data_mount}/human_ancestor.fa.gz,conservation_file:{self.data_mount}/loftee.sql" \ --dir_plugins /vep/ensembl-vep/Plugins/ \ --dir_cache {self.data_mount} \ -o STDOUT -''' +""" supported_vep_configs = { @@ -872,20 +929,24 @@ def _supported_vep_config(cloud: str, reference_genome: str, *, regions: List[st if config_params in supported_vep_configs: return supported_vep_configs[config_params] - raise ValueError(f'could not find a supported vep configuration for reference genome {reference_genome}, ' - f'cloud {cloud}, regions {regions}, and domain {domain}') + raise ValueError( + f'could not find a supported vep configuration for reference genome {reference_genome}, ' + f'cloud {cloud}, regions {regions}, and domain {domain}' + ) -def _service_vep(backend: ServiceBackend, - ht: Table, - config: Optional[VEPConfig], - block_size: int, - csq: bool, - tolerate_parse_error: bool, - temp_input_directory: str, - temp_output_directory: str) -> hl.Table: +def _service_vep( + backend: ServiceBackend, + ht: Table, + config: Optional[VEPConfig], + block_size: int, + csq: bool, + tolerate_parse_error: bool, + temp_input_directory: str, + temp_output_directory: str, +) -> Table: reference_genome = ht.locus.dtype.reference_genome.name - cloud = backend.bc.cloud() + cloud = async_to_blocking(backend._batch_client.cloud()) regions = backend.regions if config is not None: @@ -895,9 +956,11 @@ def _service_vep(backend: ServiceBackend, requester_pays_project = backend.flags.get('gcs_requester_pays_project') if requester_pays_project is None and 
vep_config.data_bucket_is_requester_pays and vep_config.cloud == 'gcp': - raise ValueError("No requester pays project has been set. " - "Use hl.init(gcs_requester_pays_configuration='MY_PROJECT') " - "to set the requester pays project to use.") + raise ValueError( + "No requester pays project has been set. " + "Use hl.init(gcs_requester_pays_configuration='MY_PROJECT') " + "to set the requester pays project to use." + ) if csq: vep_typ = hl.tarray(hl.tstr) @@ -908,7 +971,11 @@ def build_vep_batch(b: bc.aioclient.Batch, vep_input_path: str, vep_output_path: if csq: local_output_file = '/io/output' vep_command = vep_config.command( - consequence=csq, part_id=-1, input_file=None, output_file=local_output_file, tolerate_parse_error=tolerate_parse_error + consequence=csq, + part_id=-1, + input_file=None, + output_file=local_output_file, + tolerate_parse_error=tolerate_parse_error, ) env = { 'VEP_BLOCK_SIZE': str(block_size), @@ -917,18 +984,20 @@ def build_vep_batch(b: bc.aioclient.Batch, vep_input_path: str, vep_output_path: 'VEP_TOLERATE_PARSE_ERROR': str(int(tolerate_parse_error)), 'VEP_PART_ID': str(-1), 'VEP_OUTPUT_FILE': local_output_file, - 'VEP_COMMAND': vep_command + 'VEP_COMMAND': vep_command, } env.update(vep_config.env) - b.create_job(vep_config.image, - vep_config.batch_run_csq_header_command, - attributes={'name': 'csq-header'}, - resources={'cpu': '1', 'memory': 'standard'}, - cloudfuse=[(vep_config.data_bucket, vep_config.data_mount, True)], - output_files=[(local_output_file, f'{vep_output_path}/csq-header')], - regions=vep_config.regions, - requester_pays_project=requester_pays_project, - env=env) + b.create_job( + vep_config.image, + vep_config.batch_run_csq_header_command, + attributes={'name': 'csq-header'}, + resources={'cpu': '1', 'memory': 'standard'}, + cloudfuse=[(vep_config.data_bucket, vep_config.data_mount, True)], + output_files=[(local_output_file, f'{vep_output_path}/csq-header')], + regions=vep_config.regions, + requester_pays_project=requester_pays_project, + env=env, + ) for f in hl.hadoop_ls(vep_input_path): path = f['path'] @@ -960,16 +1029,18 @@ def build_vep_batch(b: bc.aioclient.Batch, vep_input_path: str, vep_output_path: } env.update(vep_config.env) - b.create_job(vep_config.image, - vep_config.batch_run_command, - attributes={'name': f'vep-{part_id}'}, - resources={'cpu': '1', 'memory': 'standard'}, - input_files=[(path, local_input_file)], - output_files=[(local_output_file, f'{vep_output_path}/annotations/{part_name}.tsv.gz')], - cloudfuse=[(vep_config.data_bucket, vep_config.data_mount, True)], - regions=vep_config.regions, - requester_pays_project=requester_pays_project, - env=env) + b.create_job( + vep_config.image, + vep_config.batch_run_command, + attributes={'name': f'vep-{part_id}'}, + resources={'cpu': '1', 'memory': 'standard'}, + input_files=[(path, local_input_file)], + output_files=[(local_output_file, f'{vep_output_path}/annotations/{part_name}.tsv.gz')], + cloudfuse=[(vep_config.data_bucket, vep_config.data_mount, True)], + regions=vep_config.regions, + requester_pays_project=requester_pays_project, + env=env, + ) hl.export_vcf(ht, temp_input_directory, parallel='header_per_shard') @@ -981,10 +1052,12 @@ def build_vep_batch(b: bc.aioclient.Batch, vep_input_path: str, vep_output_path: b.submit(disable_progress_bar=True) try: - status = b.wait(description='vep(...)', - disable_progress_bar=backend.disable_progress_bar, - progress=None, - starting_job=starting_job_id) + status = b.wait( + description='vep(...)', + 
disable_progress_bar=backend.disable_progress_bar, + progress=None, + starting_job=starting_job_id, + ) except BaseException as e: if isinstance(e, KeyboardInterrupt): print("Received a keyboard interrupt, cancelling the batch...") @@ -993,26 +1066,20 @@ def build_vep_batch(b: bc.aioclient.Batch, vep_input_path: str, vep_output_path: raise if status['n_succeeded'] != status['n_jobs']: - failing_job = [job for job in b.jobs('!success')][0] + failing_job = next(iter(b.jobs('!success'))) failing_job = b.get_job(failing_job['job_id']) - message = { - 'batch_status': status, - 'job_status': failing_job.status(), - 'log': failing_job.log() - } + message = {'batch_status': status, 'job_status': failing_job.status(), 'log': failing_job.log()} raise FatalError(yamlx.dump(message)) - annotations = hl.import_table(f'{temp_output_directory}/annotations/*', - types={'variant': hl.tstr, - 'vep': vep_typ, - 'part_id': hl.tint, - 'block_id': hl.tint}, - force=True) + annotations = hl.import_table( + f'{temp_output_directory}/annotations/*', + types={'variant': hl.tstr, 'vep': vep_typ, 'part_id': hl.tint, 'block_id': hl.tint}, + force=True, + ) - annotations = annotations.annotate(vep_proc_id=hl.struct( - part_id=annotations.part_id, - block_id=annotations.block_id - )) + annotations = annotations.annotate( + vep_proc_id=hl.struct(part_id=annotations.part_id, block_id=annotations.block_id) + ) annotations = annotations.drop('part_id', 'block_id') annotations = annotations.key_by(**hl.parse_variant(annotations.variant, reference_genome=reference_genome)) annotations = annotations.drop('variant') @@ -1025,18 +1092,22 @@ def build_vep_batch(b: bc.aioclient.Batch, vep_input_path: str, vep_output_path: return annotations -@typecheck(dataset=oneof(Table, MatrixTable), - config=nullable(oneof(str, VEPConfig)), - block_size=int, - name=str, - csq=bool, - tolerate_parse_error=bool) -def vep(dataset: Union[Table, MatrixTable], - config: Optional[Union[str, VEPConfig]] = None, - block_size: int = 1000, - name: str = 'vep', - csq: bool = False, - tolerate_parse_error: bool = False): +@typecheck( + dataset=oneof(Table, MatrixTable), + config=nullable(oneof(str, VEPConfig)), + block_size=int, + name=str, + csq=bool, + tolerate_parse_error=bool, +) +def vep( + dataset: Union[Table, MatrixTable], + config: Optional[Union[str, VEPConfig]] = None, + block_size: int = 1000, + name: str = 'vep', + csq: bool = False, + tolerate_parse_error: bool = False, +): """Annotate variants with VEP. .. include:: ../_templates/req_tvariant.rst @@ -1103,8 +1174,8 @@ def vep(dataset: Union[Table, MatrixTable], The configuration files used by``hailctl dataproc`` can be found at the following locations: - - ``GRCh37``: ``gs://hail-us-vep/vep85-loftee-gcloud.json`` - - ``GRCh38``: ``gs://hail-us-vep/vep95-GRCh38-loftee-gcloud.json`` + - ``GRCh37``: ``gs://hail-us-central1-vep/vep85-loftee-gcloud.json`` + - ``GRCh38``: ``gs://hail-us-central1-vep/vep95-GRCh38-loftee-gcloud.json`` If no config file is specified, this function will check to see if environment variable `VEP_CONFIG_URI` is set with a path to a config file. 
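For readers skimming this hunk, a minimal usage sketch of the API being documented here (the bucket path and input table are hypothetical placeholders; the config URI is the dataproc GRCh37 file quoted above):

import hail as hl

mt = hl.read_matrix_table('gs://my-bucket/dataset.mt')  # hypothetical input
annotated = hl.vep(mt, config='gs://hail-us-central1-vep/vep85-loftee-gcloud.json')
annotated.vep.most_severe_consequence.show()  # the 'vep' row field follows vep_json_typ defined earlier in this file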
@@ -1158,7 +1229,9 @@ def vep(dataset: Union[Table, MatrixTable], if isinstance(backend, ServiceBackend): with hl.TemporaryDirectory(prefix='qob/vep/inputs/') as vep_input_path: with hl.TemporaryDirectory(prefix='qob/vep/outputs/') as vep_output_path: - annotations = _service_vep(backend, ht, config, block_size, csq, tolerate_parse_error, vep_input_path, vep_output_path) + annotations = _service_vep( + backend, ht, config, block_size, csq, tolerate_parse_error, vep_input_path, vep_output_path + ) annotations = annotations.checkpoint(new_temp_file()) else: if config is None: @@ -1168,21 +1241,27 @@ def vep(dataset: Union[Table, MatrixTable], config = maybe_config elif maybe_cloud_spark_provider == 'hdinsight': warning( - 'Assuming you are in a hailctl hdinsight cluster. If not, specify the config parameter to `hl.vep`.') + 'Assuming you are in a hailctl hdinsight cluster. If not, specify the config parameter to `hl.vep`.' + ) config = 'file:/vep_data/vep-azure.json' else: raise ValueError("No config set and VEP_CONFIG_URI was not set.") - annotations = Table(TableToTableApply(ht._tir, - {'name': 'VEP', - 'config': config, - 'csq': csq, - 'blockSize': block_size, - 'tolerateParseError': tolerate_parse_error})).persist() + annotations = Table( + TableToTableApply( + ht._tir, + { + 'name': 'VEP', + 'config': config, + 'csq': csq, + 'blockSize': block_size, + 'tolerateParseError': tolerate_parse_error, + }, + ) + ).persist() if csq: - dataset = dataset.annotate_globals( - **{name + '_csq_header': annotations.index_globals()['vep_csq_header']}) + dataset = dataset.annotate_globals(**{name + '_csq_header': annotations.index_globals()['vep_csq_header']}) if isinstance(dataset, MatrixTable): vep = annotations[dataset.row_key] @@ -1192,10 +1271,7 @@ def vep(dataset: Union[Table, MatrixTable], return dataset.annotate(**{name: vep.vep, name + '_proc_id': vep.vep_proc_id}) -@typecheck(dataset=oneof(Table, MatrixTable), - config=str, - block_size=int, - name=str) +@typecheck(dataset=oneof(Table, MatrixTable), config=str, block_size=int, name=str) def nirvana(dataset: Union[MatrixTable, Table], config, block_size=500000, name='nirvana'): """Annotate variants using `Nirvana `_. @@ -1523,11 +1599,9 @@ def nirvana(dataset: Union[MatrixTable, Table], config, block_size=500000, name= require_table_key_variant(dataset, 'nirvana') ht = dataset.select() - annotations = Table(TableToTableApply(ht._tir, - {'name': 'Nirvana', - 'config': config, - 'blockSize': block_size} - )).persist() + annotations = Table( + TableToTableApply(ht._tir, {'name': 'Nirvana', 'config': config, 'blockSize': block_size}) + ).persist() if isinstance(dataset, MatrixTable): return dataset.annotate_rows(**{name: annotations[dataset.row_key].nirvana}) @@ -1594,6 +1668,7 @@ def _html_string(self): contig_idx = {contig: i for i, contig in enumerate(self.rg.contigs)} import html + builder = [] builder.append('
Variant summary:') builder.append('
      ') @@ -1706,15 +1781,18 @@ def summarize_variants(mt: Union[MatrixTable, MatrixTable], show=True, *, handle def explode_result(alleles): ref, alt = alleles - return (hl.agg.counter(hl.allele_type(ref, alt)), - hl.agg.count_where(hl.is_transition(ref, alt)), - hl.agg.count_where(hl.is_transversion(ref, alt))) - - (allele_types, nti, ntv), contigs, allele_counts, n_variants = ht.aggregate( - (hl.agg.explode(explode_result, allele_pairs), - hl.agg.counter(ht.locus.contig), - hl.agg.counter(hl.len(ht.alleles)), - hl.agg.count())) + return ( + hl.agg.counter(hl.allele_type(ref, alt)), + hl.agg.count_where(hl.is_transition(ref, alt)), + hl.agg.count_where(hl.is_transversion(ref, alt)), + ) + + (allele_types, nti, ntv), contigs, allele_counts, n_variants = ht.aggregate(( + hl.agg.explode(explode_result, allele_pairs), + hl.agg.counter(ht.locus.contig), + hl.agg.counter(hl.len(ht.alleles)), + hl.agg.count(), + )) rg = ht.locus.dtype.reference_genome if show: summary = _VariantSummary(rg, n_variants, allele_counts, contigs, allele_types, nti, ntv) @@ -1722,28 +1800,32 @@ def explode_result(alleles): handler = hl.utils.default_handler() handler(summary) else: - return hl.Struct(allele_types=allele_types, - contigs=contigs, - allele_counts=allele_counts, - n_variants=n_variants, - r_ti_tv=nti / ntv) - - -@typecheck(ds=oneof(hl.MatrixTable, lambda: hl.vds.VariantDataset), - min_af=numeric, - max_af=numeric, - min_dp=int, - max_dp=int, - min_gq=int, - ref_AF=nullable(expr_float64)) + return hl.Struct( + allele_types=allele_types, + contigs=contigs, + allele_counts=allele_counts, + n_variants=n_variants, + r_ti_tv=nti / ntv, + ) + + +@typecheck( + ds=oneof(MatrixTable, lambda: hl.vds.VariantDataset), + min_af=numeric, + max_af=numeric, + min_dp=int, + max_dp=int, + min_gq=int, + ref_AF=nullable(expr_float64), +) def compute_charr( - ds: Union[hl.MatrixTable, 'hl.vds.VariantDataset'], - min_af: float = 0.05, - max_af: float = 0.95, - min_dp: int = 10, - max_dp: int = 100, - min_gq: int = 20, - ref_AF: Optional[hl.Float64Expression] = None + ds: Union[MatrixTable, 'hl.vds.VariantDataset'], + min_af: float = 0.05, + max_af: float = 0.95, + min_dp: int = 10, + max_dp: int = 100, + min_gq: int = 20, + ref_AF: Optional[Float64Expression] = None, ): """Compute CHARR, the DNA sample contamination estimator. @@ -1802,19 +1884,21 @@ def compute_charr( ad_field = 'AD' gt_field = 'GT' else: - raise ValueError(f"'compute_charr': require a VDS or MatrixTable with fields LAD/LAD/LGT/GQ/DP or AD/GT/GQ/DP," - f" found entry fields {list(mt.entry)}") + raise ValueError( + f"'compute_charr': require a VDS or MatrixTable with fields LAD/LAD/LGT/GQ/DP or AD/GT/GQ/DP," + f" found entry fields {list(mt.entry)}" + ) # Annotate reference allele frequency when it is not defined in the original data, and name it 'ref_AF'. ref_af_field = '__ref_af' if ref_AF is None: n_samples = mt.count_cols() if n_samples < 10000: - raise ValueError("'compute_charr': with fewer than 10,000 samples, require a reference AF in 'reference_data_source'.") + raise ValueError( + "'compute_charr': with fewer than 10,000 samples, require a reference AF in 'reference_data_source'." 
+ ) n_alleles = 2 * n_samples - mt = mt.annotate_rows( - **{ref_af_field: 1 - hl.agg.sum(mt[gt_field].n_alt_alleles()) / n_alleles} - ) + mt = mt.annotate_rows(**{ref_af_field: 1 - hl.agg.sum(mt[gt_field].n_alt_alleles()) / n_alleles}) else: mt = mt.annotate_rows(**{ref_af_field: ref_AF}) @@ -1836,14 +1920,10 @@ def compute_charr( # Filter to variant calls with GQ above min_gq and DP within the range (min_dp, max_dp) ad_dp = mt['DP'] if 'DP' in mt.entry else hl.sum(mt[ad_field]) - mt = mt.filter_entries( - mt[gt_field].is_hom_var() & (mt.GQ >= min_gq) & (ad_dp >= min_dp) & (ad_dp <= max_dp) - ) + mt = mt.filter_entries(mt[gt_field].is_hom_var() & (mt.GQ >= min_gq) & (ad_dp >= min_dp) & (ad_dp <= max_dp)) # Compute CHARR - mt = mt.select_cols( - charr=hl.agg.mean((mt[ad_field][0] / (mt[ad_field][0] + mt[ad_field][1])) / mt[ref_af_field]) - ) + mt = mt.select_cols(charr=hl.agg.mean((mt[ad_field][0] / (mt[ad_field][0] + mt[ad_field][1])) / mt[ref_af_field])) mt = mt.select_globals( af_min=min_af, diff --git a/hail/python/hail/methods/relatedness/__init__.py b/hail/python/hail/methods/relatedness/__init__.py index 6aeaec63142..84fce5c99b8 100644 --- a/hail/python/hail/methods/relatedness/__init__.py +++ b/hail/python/hail/methods/relatedness/__init__.py @@ -1,7 +1,7 @@ from .identity_by_descent import identity_by_descent from .king import king -from .pc_relate import pc_relate from .mating_simulation import simulate_random_mating +from .pc_relate import pc_relate __all__ = [ 'identity_by_descent', diff --git a/hail/python/hail/methods/relatedness/identity_by_descent.py b/hail/python/hail/methods/relatedness/identity_by_descent.py index ecce456eaca..68460389e60 100644 --- a/hail/python/hail/methods/relatedness/identity_by_descent.py +++ b/hail/python/hail/methods/relatedness/identity_by_descent.py @@ -1,21 +1,17 @@ import hail as hl +from hail import ir from hail.backend.spark_backend import SparkBackend from hail.expr import analyze from hail.expr.expressions import expr_float64 -import hail.ir as ir -from hail.table import Table +from hail.linalg import BlockMatrix from hail.matrixtable import MatrixTable from hail.methods.misc import require_biallelic, require_col_key_str -from hail.typecheck import typecheck, nullable, numeric -from hail.linalg import BlockMatrix +from hail.table import Table +from hail.typecheck import nullable, numeric, typecheck from hail.utils.java import Env -@typecheck(dataset=MatrixTable, - maf=nullable(expr_float64), - bounded=bool, - min=nullable(numeric), - max=nullable(numeric)) +@typecheck(dataset=MatrixTable, maf=nullable(expr_float64), bounded=bool, min=nullable(numeric), max=nullable(numeric)) def identity_by_descent(dataset, maf=None, bounded=True, min=None, max=None) -> Table: """Compute matrix of identity-by-descent estimates. 
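A hedged usage sketch of the method documented here (the input path and thresholds are illustrative assumptions, not part of this change):

import hail as hl

mt = hl.read_matrix_table('gs://my-bucket/genotypes.mt')  # hypothetical input
ibd = hl.identity_by_descent(mt, min=0.125)               # keep pairs with PI_HAT >= 0.125
related_pairs = ibd.filter(ibd.ibd.PI_HAT > 0.25)         # e.g. second degree or closer
related_pairs.show()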
@@ -105,13 +101,18 @@ def identity_by_descent(dataset, maf=None, bounded=True, min=None, max=None) -> dataset = require_biallelic(dataset, 'ibd') if isinstance(Env.backend(), SparkBackend): - return Table(ir.MatrixToTableApply(dataset._mir, { - 'name': 'IBD', - 'mafFieldName': '__maf' if maf is not None else None, - 'bounded': bounded, - 'min': min, - 'max': max, - })).persist() + return Table( + ir.MatrixToTableApply( + dataset._mir, + { + 'name': 'IBD', + 'mafFieldName': '__maf' if maf is not None else None, + 'bounded': bounded, + 'min': min, + 'max': max, + }, + ) + ).persist() min = min or 0 max = max or 1 @@ -143,31 +144,41 @@ def identity_by_descent(dataset, maf=None, bounded=True, min=None, max=None) -> q = 1 - p dataset = dataset.annotate_rows( - _e00=(2 * (p ** 2) * (q ** 2) * ((X - 1) / X) * ((Y - 1) / Y) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3))), - _e10=(4 * (p ** 3) * q * ((X - 1) / X) * ((X - 2) / X) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3)) - + 4 * p * (q ** 3) * ((Y - 1) / X) * ((Y - 2) / X) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3))), - _e20=((p ** 4) * ((X - 1) / X) * ((X - 2) / X) * ((X - 3) / X) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3)) - + (q ** 4) * ((Y - 1) / Y) * ((Y - 2) / Y) * ((Y - 3) / Y) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3)) - + 4 * (p ** 2) * (q ** 2) * ((X - 1) / X) * ((Y - 1) / Y) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3))), - _e11=(2 * (p ** 2) * q * ((X - 1) / X) * (T / (T - 1)) * (T / (T - 2)) - + 2 * p * (q ** 2) * ((Y - 1) / Y) * (T / (T - 1)) * (T / (T - 2))), - _e21=((p ** 3) * ((X - 1) / X) * ((X - 2) / X) * (T / (T - 1)) * (T / (T - 2)) - + (q ** 3) * ((Y - 1) / Y) * ((Y - 2) / Y) * (T / (T - 1)) * (T / (T - 2)) - + (p ** 2) * q * ((X - 1) / X) * (T / (T - 1)) * (T / (T - 2)) - + p * (q ** 2) * ((Y - 1) / Y) * (T / (T - 1)) * (T / (T - 2))), - _e22=(T / 2) + _e00=(2 * (p**2) * (q**2) * ((X - 1) / X) * ((Y - 1) / Y) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3))), + _e10=( + 4 * (p**3) * q * ((X - 1) / X) * ((X - 2) / X) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3)) + + 4 * p * (q**3) * ((Y - 1) / Y) * ((Y - 2) / Y) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3)) + ), + _e20=( + (p**4) * ((X - 1) / X) * ((X - 2) / X) * ((X - 3) / X) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3)) + + (q**4) * ((Y - 1) / Y) * ((Y - 2) / Y) * ((Y - 3) / Y) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3)) + + 4 * (p**2) * (q**2) * ((X - 1) / X) * ((Y - 1) / Y) * (T / (T - 1)) * (T / (T - 2)) * (T / (T - 3)) + ), + _e11=( + 2 * (p**2) * q * ((X - 1) / X) * (T / (T - 1)) * (T / (T - 2)) + + 2 * p * (q**2) * ((Y - 1) / Y) * (T / (T - 1)) * (T / (T - 2)) + ), + _e21=( + (p**3) * ((X - 1) / X) * ((X - 2) / X) * (T / (T - 1)) * (T / (T - 2)) + + (q**3) * ((Y - 1) / Y) * ((Y - 2) / Y) * (T / (T - 1)) * (T / (T - 2)) + + (p**2) * q * ((X - 1) / X) * (T / (T - 1)) * (T / (T - 2)) + + p * (q**2) * ((Y - 1) / Y) * (T / (T - 1)) * (T / (T - 2)) + ), + _e22=1, ) dataset = dataset.checkpoint(hl.utils.new_temp_file()) - expectations = dataset.aggregate_rows(hl.struct( - e00=hl.agg.sum(dataset._e00), - e10=hl.agg.sum(dataset._e10), - e20=hl.agg.sum(dataset._e20), - e11=hl.agg.sum(dataset._e11), - e21=hl.agg.sum(dataset._e21), - e22=hl.agg.sum(dataset._e22) - )) + expectations = dataset.aggregate_rows( + hl.struct( + e00=hl.agg.sum(dataset._e00), + e10=hl.agg.sum(dataset._e10), + e20=hl.agg.sum(dataset._e20), + e11=hl.agg.sum(dataset._e11), + e21=hl.agg.sum(dataset._e21), + e22=hl.agg.sum(dataset._e22), + ) + ) IS_HOM_REF = 
BlockMatrix.from_entry_expr(dataset.is_hom_ref).checkpoint(hl.utils.new_temp_file()) IS_HET = BlockMatrix.from_entry_expr(dataset.is_het).checkpoint(hl.utils.new_temp_file()) @@ -204,15 +215,25 @@ def convert_to_table(bm, annotation_name): result = z0.join(z1.join(z2).join(ibs0).join(ibs1).join(ibs2)) def bound_result(_ibd): - return (hl.case() - .when(_ibd.Z0 > 1, hl.struct(Z0=hl.float(1), Z1=hl.float(0), Z2=hl.float(0))) - .when(_ibd.Z1 > 1, hl.struct(Z0=hl.float(0), Z1=hl.float(1), Z2=hl.float(0))) - .when(_ibd.Z2 > 1, hl.struct(Z0=hl.float(0), Z1=hl.float(0), Z2=hl.float(1))) - .when(_ibd.Z0 < 0, hl.struct(Z0=hl.float(0), Z1=_ibd.Z1 / (_ibd.Z1 + _ibd.Z2), Z2=_ibd.Z2 / (_ibd.Z1 + _ibd.Z2))) - .when(_ibd.Z1 < 0, hl.struct(Z0=_ibd.Z0 / (_ibd.Z0 + _ibd.Z2), Z1=hl.float(0), Z2=_ibd.Z2 / (_ibd.Z0 + _ibd.Z2))) - .when(_ibd.Z2 < 0, hl.struct(Z0=_ibd.Z0 / (_ibd.Z0 + _ibd.Z1), Z1=_ibd.Z1 / (_ibd.Z0 + _ibd.Z1), Z2=hl.float(0))) - .default(_ibd) - ) + return ( + hl.case() + .when(_ibd.Z0 > 1, hl.struct(Z0=hl.float(1), Z1=hl.float(0), Z2=hl.float(0))) + .when(_ibd.Z1 > 1, hl.struct(Z0=hl.float(0), Z1=hl.float(1), Z2=hl.float(0))) + .when(_ibd.Z2 > 1, hl.struct(Z0=hl.float(0), Z1=hl.float(0), Z2=hl.float(1))) + .when( + _ibd.Z0 < 0, + hl.struct(Z0=hl.float(0), Z1=_ibd.Z1 / (_ibd.Z1 + _ibd.Z2), Z2=_ibd.Z2 / (_ibd.Z1 + _ibd.Z2)), + ) + .when( + _ibd.Z1 < 0, + hl.struct(Z0=_ibd.Z0 / (_ibd.Z0 + _ibd.Z2), Z1=hl.float(0), Z2=_ibd.Z2 / (_ibd.Z0 + _ibd.Z2)), + ) + .when( + _ibd.Z2 < 0, + hl.struct(Z0=_ibd.Z0 / (_ibd.Z0 + _ibd.Z1), Z1=_ibd.Z1 / (_ibd.Z0 + _ibd.Z1), Z2=hl.float(0)), + ) + .default(_ibd) + ) result = result.annotate(ibd=hl.struct(Z0=result.Z0, Z1=result.Z1, Z2=result.Z2)) result = result.drop('Z0', 'Z1', 'Z2') @@ -222,9 +243,6 @@ def bound_result(_ibd): result = result.filter((result.i < result.j) & (min <= result.ibd.PI_HAT) & (result.ibd.PI_HAT <= max)) samples = hl.literal(dataset.s.collect()) - result = result.key_by( - i=samples[hl.int32(result.i)], - j=samples[hl.int32(result.j)] - ) + result = result.key_by(i=samples[hl.int32(result.i)], j=samples[hl.int32(result.j)]) return result.persist() diff --git a/hail/python/hail/methods/relatedness/king.py b/hail/python/hail/methods/relatedness/king.py index 111d0fadf8e..75ed605cee1 100644 --- a/hail/python/hail/methods/relatedness/king.py +++ b/hail/python/hail/methods/relatedness/king.py @@ -1,8 +1,6 @@ import hail as hl - -from hail.expr.expressions import expr_call -from hail.expr.expressions import matrix_table_source -from hail.typecheck import typecheck, nullable +from hail.expr.expressions import expr_call, matrix_table_source +from hail.typecheck import nullable, typecheck from hail.utils import deduplicate from hail.utils.java import Env @@ -232,7 +230,7 @@ def king(call_expr, *, block_size=None): is_hom_ref: hl.float(hl.or_else(mt[call].is_hom_ref(), 0)), is_het: hl.float(hl.or_else(mt[call].is_het(), 0)), is_hom_var: hl.float(hl.or_else(mt[call].is_hom_var(), 0)), - is_defined: hl.float(hl.is_defined(mt[call])) + is_defined: hl.float(hl.is_defined(mt[call])), }) ref = hl.linalg.BlockMatrix.from_entry_expr(mt[is_hom_ref], block_size=block_size) het = hl.linalg.BlockMatrix.from_entry_expr(mt[is_het], block_size=block_size) @@ -257,16 +255,14 @@ def king(call_expr, *, block_size=None): kinship_between = het_hom_balance.rename({'element': 'het_hom_balance'}) kinship_between = kinship_between.annotate_entries( n_hets_row=n_hets_for_rows[kinship_between.row_key, kinship_between.col_key].element, - 
n_hets_col=n_hets_for_cols[kinship_between.row_key, kinship_between.col_key].element + n_hets_col=n_hets_for_cols[kinship_between.row_key, kinship_between.col_key].element, ) col_index_field = Env.get_uid() col_key = mt.col_key cols = mt.add_col_index(col_index_field).key_cols_by(col_index_field).cols() - kinship_between = kinship_between.key_cols_by( - **cols[kinship_between.col_idx].select(*col_key) - ) + kinship_between = kinship_between.key_cols_by(**cols[kinship_between.col_idx].select(*col_key)) renaming, _ = deduplicate(list(col_key), already_used=set(col_key)) assert len(renaming) == len(col_key) @@ -276,19 +272,17 @@ def king(call_expr, *, block_size=None): ) kinship_between = kinship_between.annotate_entries( - min_n_hets=hl.min(kinship_between.n_hets_row, - kinship_between.n_hets_col) + min_n_hets=hl.min(kinship_between.n_hets_row, kinship_between.n_hets_col) ) - return kinship_between.select_entries( - phi=( - 0.5 - ) + ( - ( - 2 * kinship_between.het_hom_balance + - - kinship_between.n_hets_row - - kinship_between.n_hets_col - ) / ( - 4 * kinship_between.min_n_hets + return ( + kinship_between.select_entries( + phi=(0.5) + + ( + (2 * kinship_between.het_hom_balance + -kinship_between.n_hets_row - kinship_between.n_hets_col) + / (4 * kinship_between.min_n_hets) ) ) - ).select_rows().select_cols().select_globals() + .select_rows() + .select_cols() + .select_globals() + ) diff --git a/hail/python/hail/methods/relatedness/mating_simulation.py b/hail/python/hail/methods/relatedness/mating_simulation.py index 4ade8f7d69c..e36e26ef03b 100644 --- a/hail/python/hail/methods/relatedness/mating_simulation.py +++ b/hail/python/hail/methods/relatedness/mating_simulation.py @@ -1,8 +1,9 @@ import hail as hl -from hail.typecheck import typecheck, numeric +from hail.matrixtable import MatrixTable +from hail.typecheck import numeric, typecheck -@typecheck(mt=hl.MatrixTable, n_rounds=int, generation_size_multiplier=numeric, keep_founders=bool) +@typecheck(mt=MatrixTable, n_rounds=int, generation_size_multiplier=numeric, keep_founders=bool) def simulate_random_mating(mt, n_rounds=1, generation_size_multiplier=1.0, keep_founders=True): """Simulate random diploid mating to produce new individuals. 
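A brief sketch of how this simulator can be driven (balding_nichols_model is used purely as a stand-in for real founder genotypes; the import path mirrors the package layout shown in this diff):

import hail as hl
from hail.methods.relatedness import simulate_random_mating

founders = hl.balding_nichols_model(n_populations=1, n_samples=100, n_variants=1000)
expanded = simulate_random_mating(founders, n_rounds=2, generation_size_multiplier=2.0, keep_founders=True)
expanded.cols().show()  # founders plus two simulated generations, with 'mother'/'father' parent indices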
@@ -23,48 +24,68 @@ def simulate_random_mating(mt, n_rounds=1, generation_size_multiplier=1.0, keep_ """ if generation_size_multiplier <= 0: raise ValueError( - f"simulate_random_mating: 'generation_size_multiplier' must be greater than zero: got {generation_size_multiplier}") + f"simulate_random_mating: 'generation_size_multiplier' must be greater than zero: got {generation_size_multiplier}" + ) if n_rounds < 1: raise ValueError(f"simulate_random_mating: 'n_rounds' must be positive: got {n_rounds}") - ck = list(mt.col_key)[0] + ck = next(iter(mt.col_key)) mt = mt.select_entries('GT') ht = mt.localize_entries('__entries', '__cols') ht = ht.annotate_globals( - generation_0=hl.range(hl.len(ht.__cols)).map(lambda i: hl.struct(s=hl.str('generation_0_idx_') + hl.str(i), - original=hl.str(ht.__cols[i][ck]), - mother=hl.missing('int32'), - father=hl.missing('int32')))) + generation_0=hl.range(hl.len(ht.__cols)).map( + lambda i: hl.struct( + s=hl.str('generation_0_idx_') + hl.str(i), + original=hl.str(ht.__cols[i][ck]), + mother=hl.missing('int32'), + father=hl.missing('int32'), + ) + ) + ) def make_new_generation(prev_generation_tup, idx): prev_size = prev_generation_tup[1] n_new = hl.int32(hl.floor(prev_size * generation_size_multiplier)) new_generation = hl.range(n_new).map( - lambda i: hl.struct(s=hl.str('generation_') + hl.str(idx + 1) + hl.str('_idx_') + hl.str(i), - original=hl.missing('str'), - mother=hl.rand_int32(0, prev_size), - father=hl.rand_int32(0, prev_size))) + lambda i: hl.struct( + s=hl.str('generation_') + hl.str(idx + 1) + hl.str('_idx_') + hl.str(i), + original=hl.missing('str'), + mother=hl.rand_int32(0, prev_size), + father=hl.rand_int32(0, prev_size), + ) + ) return (new_generation, (prev_size + n_new) if keep_founders else n_new) - ht = ht.annotate_globals(generations=hl.range(n_rounds).scan(lambda prev, idx: make_new_generation(prev, idx), - (ht.generation_0, hl.len(ht.generation_0)))) + ht = ht.annotate_globals( + generations=hl.range(n_rounds).scan( + lambda prev, idx: make_new_generation(prev, idx), (ht.generation_0, hl.len(ht.generation_0)) + ) + ) def simulate_mating_calls(prev_generation_calls, new_generation): - new_samples = new_generation.map(lambda samp: hl.call(prev_generation_calls[samp.mother][hl.rand_int32(0, 2)], - prev_generation_calls[samp.father][hl.rand_int32(0, 2)])) + new_samples = new_generation.map( + lambda samp: hl.call( + prev_generation_calls[samp.mother][hl.rand_int32(0, 2)], + prev_generation_calls[samp.father][hl.rand_int32(0, 2)], + ) + ) if keep_founders: return prev_generation_calls.extend(new_samples) else: return new_samples - ht = ht.annotate(__new_entries=hl.fold( - lambda prev_calls, generation_metadata: simulate_mating_calls(prev_calls, generation_metadata[0]), - ht.__entries.GT, - ht.generations[1:]).map(lambda gt: hl.struct(GT=gt))) + ht = ht.annotate( + __new_entries=hl.fold( + lambda prev_calls, generation_metadata: simulate_mating_calls(prev_calls, generation_metadata[0]), + ht.__entries.GT, + ht.generations[1:], + ).map(lambda gt: hl.struct(GT=gt)) + ) ht = ht.annotate_globals( - __new_cols=ht.generations.flatmap(lambda x: x[0]) if keep_founders else ht.generations[-1][0]) + __new_cols=ht.generations.flatmap(lambda x: x[0]) if keep_founders else ht.generations[-1][0] + ) ht = ht.drop('__entries', '__cols', 'generation_0', 'generations') return ht._unlocalize_entries('__new_entries', '__new_cols', list('s')) diff --git a/hail/python/hail/methods/relatedness/pc_relate.py b/hail/python/hail/methods/relatedness/pc_relate.py index 
37be5b04f6e..8b141ddf8c8 100644 --- a/hail/python/hail/methods/relatedness/pc_relate.py +++ b/hail/python/hail/methods/relatedness/pc_relate.py @@ -4,35 +4,48 @@ import hail.expr.aggregators as agg from hail import ir from hail.backend.spark_backend import SparkBackend -from hail.expr import (ArrayNumericExpression, BooleanExpression, CallExpression, - Float64Expression, analyze, expr_array, expr_call, - expr_float64, matrix_table_source) +from hail.expr import ( + ArrayNumericExpression, + BooleanExpression, + CallExpression, + Float64Expression, + analyze, + expr_array, + expr_call, + expr_float64, + matrix_table_source, +) from hail.expr.types import tarray from hail.linalg import BlockMatrix from hail.table import Table from hail.typecheck import enumeration, nullable, numeric, typecheck from hail.utils import new_temp_file from hail.utils.java import Env + from ..pca import _hwe_normalized_blanczos, hwe_normalized_pca -@typecheck(call_expr=expr_call, - min_individual_maf=numeric, - k=nullable(int), - scores_expr=nullable(expr_array(expr_float64)), - min_kinship=nullable(numeric), - statistics=enumeration('kin', 'kin2', 'kin20', 'all'), - block_size=nullable(int), - include_self_kinship=bool) -def pc_relate(call_expr: CallExpression, - min_individual_maf: float, - *, - k: Optional[int] = None, - scores_expr: Optional[ArrayNumericExpression] = None, - min_kinship: Optional[float] = None, - statistics: str = 'all', - block_size: Optional[int] = None, - include_self_kinship: bool = False) -> Table: +@typecheck( + call_expr=expr_call, + min_individual_maf=numeric, + k=nullable(int), + scores_expr=nullable(expr_array(expr_float64)), + min_kinship=nullable(numeric), + statistics=enumeration('kin', 'kin2', 'kin20', 'all'), + block_size=nullable(int), + include_self_kinship=bool, +) +def pc_relate( + call_expr: CallExpression, + min_individual_maf: float, + *, + k: Optional[int] = None, + scores_expr: Optional[ArrayNumericExpression] = None, + min_kinship: Optional[float] = None, + statistics: str = 'all', + block_size: Optional[int] = None, + include_self_kinship: bool = False, +) -> Table: r"""Compute relatedness estimates between individuals using a variant of the PC-Relate method. @@ -301,14 +314,16 @@ def pc_relate(call_expr: CallExpression, A :class:`.Table` mapping pairs of samples to their pair-wise statistics. 
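As a concrete illustration of the interface described above, a hedged sketch (input table and cutoffs are hypothetical; 2**-3.5 is a conventional second-degree kinship threshold, not something mandated by this code):

import hail as hl

mt = hl.read_matrix_table('gs://my-bucket/genotypes.mt')  # hypothetical input
rel = hl.pc_relate(mt.GT, min_individual_maf=0.01, k=10, statistics='kin', block_size=4096)
second_degree_or_closer = rel.filter(rel.kin > 2 ** -3.5)
second_degree_or_closer.show()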
""" if not isinstance(Env.backend(), SparkBackend): - return _pc_relate_bm(call_expr, - min_individual_maf, - k=k, - scores_expr=scores_expr, - min_kinship=min_kinship, - statistics=statistics, - block_size=block_size, - include_self_kinship=include_self_kinship) + return _pc_relate_bm( + call_expr, + min_individual_maf, + k=k, + scores_expr=scores_expr, + min_kinship=min_kinship, + statistics=statistics, + block_size=block_size, + include_self_kinship=include_self_kinship, + ) mt = matrix_table_source('pc_relate/call_expr', call_expr) @@ -322,8 +337,7 @@ def pc_relate(call_expr: CallExpression, else: raise ValueError("pc_relate: exactly one of 'k' and 'scores_expr' must be set, found neither") - scores_table = mt.select_cols(__scores=scores_expr) \ - .key_cols_by().select_cols('__scores').cols() + scores_table = mt.select_cols(__scores=scores_expr).key_cols_by().select_cols('__scores').cols() n_missing = scores_table.aggregate(agg.count_where(hl.is_missing(scores_table.__scores))) if n_missing > 0: @@ -336,17 +350,23 @@ def pc_relate(call_expr: CallExpression, if not block_size: block_size = BlockMatrix.default_block_size() - g = BlockMatrix.from_entry_expr(mean_imputed_gt, - block_size=block_size) + g = BlockMatrix.from_entry_expr(mean_imputed_gt, block_size=block_size) pcs = scores_table.collect(_localize=False).map(lambda x: x.__scores) - ht = Table(ir.BlockMatrixToTableApply(g._bmir, pcs._ir, { - 'name': 'PCRelate', - 'maf': min_individual_maf, - 'blockSize': block_size, - 'minKinship': min_kinship, - 'statistics': {'kin': 0, 'kin2': 1, 'kin20': 2, 'all': 3}[statistics]})).persist() + ht = Table( + ir.BlockMatrixToTableApply( + g._bmir, + pcs._ir, + { + 'name': 'PCRelate', + 'maf': min_individual_maf, + 'blockSize': block_size, + 'minKinship': min_kinship, + 'statistics': {'kin': 0, 'kin2': 1, 'kin20': 2, 'all': 3}[statistics], + }, + ) + ).persist() if statistics == 'kin': ht = ht.drop('ibd0', 'ibd1', 'ibd2') @@ -413,12 +433,14 @@ def _dominance_encoding(g: Float64Expression, mu: Float64Expression) -> Float64E gd : :class:`.Float64Expression` Dominance-coded entry for dominance-coded genotype matrix. """ - gd = hl.case() \ - .when(hl.is_nan(mu), 0.0) \ - .when(g == 0.0, mu) \ - .when(g == 1.0, 0.0) \ - .when(g == 2.0, 1 - mu) \ + gd = ( + hl.case() + .when(hl.is_nan(mu), 0.0) + .when(g == 0.0, mu) + .when(g == 1.0, 0.0) + .when(g == 2.0, 1 - mu) .or_error('entries in genotype matrix must be 0.0, 1.0, or 2.0') + ) return gd @@ -455,26 +477,31 @@ def _replace_nan(M: BlockMatrix, value: float) -> BlockMatrix: return M._map_dense(lambda x: hl.if_else(hl.is_nan(x), value, x)) -@typecheck(call_expr=expr_call, - min_individual_maf=numeric, - k=nullable(int), - scores_expr=nullable(expr_array(expr_float64)), - min_kinship=nullable(numeric), - statistics=enumeration('kin', 'kin2', 'kin20', 'all'), - block_size=nullable(int), - include_self_kinship=bool) -def _pc_relate_bm(call_expr: CallExpression, - min_individual_maf: float, - *, - k: Optional[int] = None, - scores_expr: Optional[ArrayNumericExpression] = None, - min_kinship: Optional[float] = None, - statistics: str = "all", - block_size: Optional[int] = None, - include_self_kinship: bool = False) -> Table: - assert (0.0 <= min_individual_maf <= 1.0), \ - f'invalid argument: min_individual_maf={min_individual_maf}. 
' \ +@typecheck( + call_expr=expr_call, + min_individual_maf=numeric, + k=nullable(int), + scores_expr=nullable(expr_array(expr_float64)), + min_kinship=nullable(numeric), + statistics=enumeration('kin', 'kin2', 'kin20', 'all'), + block_size=nullable(int), + include_self_kinship=bool, +) +def _pc_relate_bm( + call_expr: CallExpression, + min_individual_maf: float, + *, + k: Optional[int] = None, + scores_expr: Optional[ArrayNumericExpression] = None, + min_kinship: Optional[float] = None, + statistics: str = "all", + block_size: Optional[int] = None, + include_self_kinship: bool = False, +) -> Table: + assert 0.0 <= min_individual_maf <= 1.0, ( + f'invalid argument: min_individual_maf={min_individual_maf}. ' f'Must have min_individual_maf on interval [0.0, 1.0].' + ) mt = matrix_table_source('pc_relate_bm/call_expr', call_expr) if k and scores_expr is None: eigens, scores, _ = _hwe_normalized_blanczos(call_expr, k, compute_loadings=False, q_iterations=10) @@ -486,11 +513,9 @@ def _pc_relate_bm(call_expr: CallExpression, scores_table = mt.select_cols(__scores=scores_expr).key_cols_by().select_cols('__scores').cols() compute_S0 = True elif k and scores_expr is not None: - raise ValueError("pc_relate_bm: exactly one of 'k' and 'scores_expr' " - "must be set, found both") + raise ValueError("pc_relate_bm: exactly one of 'k' and 'scores_expr' " "must be set, found both") else: - raise ValueError("pc_relate_bm: exactly one of 'k' and 'scores_expr' " - "must be set, found neither") + raise ValueError("pc_relate_bm: exactly one of 'k' and 'scores_expr' " "must be set, found neither") n_missing = scores_table.aggregate(agg.count_where(hl.is_missing(scores_table.__scores))) if n_missing > 0: @@ -529,9 +554,11 @@ def _pc_relate_bm(call_expr: CallExpression, # Compute matrix of individual-specific AF estimates (mu), shape (m, n) mu = 0.5 * (BlockMatrix.from_ndarray(V * S, block_size=block_size) @ beta).T # Replace entries in mu with NaN if invalid or if corresponding GT is missing (no contribution from that variant) - mu = mu._apply_map2(lambda _mu, _g: hl.if_else(_bad_mu(_mu, min_individual_maf) | hl.is_nan(_g), nan, _mu), - g, - sparsity_strategy='NeedsDense') + mu = mu._apply_map2( + lambda _mu, _g: hl.if_else(_bad_mu(_mu, min_individual_maf) | hl.is_nan(_g), nan, _mu), + g, + sparsity_strategy='NeedsDense', + ) mu = mu.checkpoint(new_temp_file('pc_relate_bm/mu', 'bm')) # Compute kinship matrix (phi), shape (n, n) @@ -541,9 +568,7 @@ def _pc_relate_bm(call_expr: CallExpression, phi = _gram(centered_af) / (4.0 * _gram(variance.sqrt())) phi = phi.checkpoint(new_temp_file('pc_relate_bm/phi', 'bm')) ht = phi.entries().rename({'entry': 'kin'}) - ht = ht.annotate(k0=hl.missing(hl.tfloat64), - k1=hl.missing(hl.tfloat64), - k2=hl.missing(hl.tfloat64)) + ht = ht.annotate(k0=hl.missing(hl.tfloat64), k1=hl.missing(hl.tfloat64), k2=hl.missing(hl.tfloat64)) if statistics in ['kin2', 'kin20', 'all']: # Compute inbreeding coefficient and dominance encoding of GT matrix @@ -557,16 +582,16 @@ def _pc_relate_bm(call_expr: CallExpression, if statistics in ['kin20', 'all']: # Get the numerator used in IBD0 (k0) computation (IBS0), compute indicator matrices for homozygotes - hom_alt = g._apply_map2(lambda _g, _mu: hl.if_else((_g != 2.0) | hl.is_nan(_mu), 0.0, 1.0), - mu, - sparsity_strategy='NeedsDense') - hom_ref = g._apply_map2(lambda _g, _mu: hl.if_else((_g != 0.0) | hl.is_nan(_mu), 0.0, 1.0), - mu, - sparsity_strategy='NeedsDense') + hom_alt = g._apply_map2( + lambda _g, _mu: hl.if_else((_g != 2.0) | 
hl.is_nan(_mu), 0.0, 1.0), mu, sparsity_strategy='NeedsDense' + ) + hom_ref = g._apply_map2( + lambda _g, _mu: hl.if_else((_g != 0.0) | hl.is_nan(_mu), 0.0, 1.0), mu, sparsity_strategy='NeedsDense' + ) ibs0 = _AtB_plus_BtA(hom_alt, hom_ref) # Get the denominator used in IBD0 (k0) computation - mu2 = _replace_nan(mu ** 2.0, 0.0) + mu2 = _replace_nan(mu**2.0, 0.0) one_minus_mu2 = _replace_nan((1.0 - mu) ** 2.0, 0.0) k0_denom = _AtB_plus_BtA(mu2, one_minus_mu2) @@ -588,15 +613,11 @@ def _pc_relate_bm(call_expr: CallExpression, ht = ht.filter(ht.kin >= min_kinship) if statistics != 'all': - fields_to_drop = { - 'kin': ['ibd0', 'ibd1', 'ibd2'], - 'kin2': ['ibd0', 'ibd1'], - 'kin20': ['ibd1']} + fields_to_drop = {'kin': ['ibd0', 'ibd1', 'ibd2'], 'kin2': ['ibd0', 'ibd1'], 'kin20': ['ibd1']} ht = ht.drop(*fields_to_drop[statistics]) if not include_self_kinship: ht = ht.filter(ht.i == ht.j, keep=False) - col_keys = hl.literal(mt.select_cols().key_cols_by().cols().collect(), - dtype=hl.tarray(mt.col_key.dtype)) + col_keys = hl.literal(mt.select_cols().key_cols_by().cols().collect(), dtype=hl.tarray(mt.col_key.dtype)) return ht.key_by(i=col_keys[hl.int32(ht.i)], j=col_keys[hl.int32(ht.j)]) diff --git a/hail/python/hail/methods/statgen.py b/hail/python/hail/methods/statgen.py index 89f87890ce3..26ab2e74da1 100644 --- a/hail/python/hail/methods/statgen.py +++ b/hail/python/hail/methods/statgen.py @@ -1,30 +1,42 @@ import builtins import itertools import math -from typing import Dict, Callable, Optional, Union, Tuple, List +from typing import Callable, Dict, List, Optional, Tuple, Union -import hail import hail as hl import hail.expr.aggregators as agg from hail import ir -from hail.expr import (Expression, ExpressionException, expr_float64, expr_call, - expr_any, expr_numeric, expr_locus, analyze, raise_unless_entry_indexed, - raise_unless_row_indexed, matrix_table_source, table_source, - raise_unless_column_indexed) -from hail.expr.types import tbool, tarray, tfloat64, tint32 +from hail.expr import ( + Expression, + ExpressionException, + NDArrayNumericExpression, + StructExpression, + analyze, + expr_any, + expr_call, + expr_float64, + expr_locus, + expr_numeric, + matrix_table_source, + raise_unless_column_indexed, + raise_unless_entry_indexed, + raise_unless_row_indexed, + table_source, +) +from hail.expr.functions import expit +from hail.expr.types import tarray, tbool, tfloat64, tint32, tndarray, tstruct from hail.genetics.reference_genome import reference_genome_type from hail.linalg import BlockMatrix from hail.matrixtable import MatrixTable from hail.methods.misc import require_biallelic, require_row_key_variant from hail.stats import LinearMixedModel from hail.table import Table -from hail.typecheck import (typecheck, nullable, numeric, oneof, sized_tupleof, - sequenceof, enumeration, anytype) -from hail.utils import wrap_to_list, new_temp_file, FatalError +from hail.typecheck import anytype, enumeration, nullable, numeric, oneof, sequenceof, sized_tupleof, typecheck +from hail.utils import FatalError, new_temp_file, wrap_to_list from hail.utils.java import Env, info, warning -from . import pca -from . import relatedness + from ..backend.spark_backend import SparkBackend +from . 
import pca, relatedness pc_relate = relatedness.pc_relate identity_by_descent = relatedness.identity_by_descent @@ -36,25 +48,28 @@ pca = pca.pca -tvector64 = hl.tndarray(hl.tfloat64, 1) -tmatrix64 = hl.tndarray(hl.tfloat64, 2) -numerical_regression_fit_dtype = hl.tstruct( +tvector64 = tndarray(tfloat64, 1) +tmatrix64 = tndarray(tfloat64, 2) +numerical_regression_fit_dtype = tstruct( b=tvector64, score=tvector64, fisher=tmatrix64, mu=tvector64, - n_iterations=hl.tint32, - log_lkhd=hl.tfloat64, - converged=hl.tbool, - exploded=hl.tbool) - - -@typecheck(call=expr_call, - aaf_threshold=numeric, - include_par=bool, - female_threshold=numeric, - male_threshold=numeric, - aaf=nullable(str)) + n_iterations=tint32, + log_lkhd=tfloat64, + converged=tbool, + exploded=tbool, +) + + +@typecheck( + call=expr_call, + aaf_threshold=numeric, + include_par=bool, + female_threshold=numeric, + male_threshold=numeric, + aaf=nullable(str), +) def impute_sex(call, aaf_threshold=0.0, include_par=False, female_threshold=0.2, male_threshold=0.8, aaf=None) -> Table: r"""Impute sex of samples by calculating inbreeding coefficient on the X chromosome. @@ -153,71 +168,71 @@ def impute_sex(call, aaf_threshold=0.0, include_par=False, female_threshold=0.2, mt, _ = mt._process_joins(call) mt = mt.annotate_entries(call=call) mt = require_biallelic(mt, 'impute_sex') - if (aaf is None): + if aaf is None: mt = mt.annotate_rows(aaf=agg.call_stats(mt.call, mt.alleles).AF[1]) aaf = 'aaf' rg = mt.locus.dtype.reference_genome - mt = hl.filter_intervals(mt, - hl.map(lambda x_contig: hl.parse_locus_interval(x_contig, rg), rg.x_contigs), - keep=True) + mt = hl.filter_intervals( + mt, hl.map(lambda x_contig: hl.parse_locus_interval(x_contig, rg), rg.x_contigs), keep=True + ) if not include_par: interval_type = hl.tarray(hl.tinterval(hl.tlocus(rg))) - mt = hl.filter_intervals(mt, - hl.literal(rg.par, interval_type), - keep=False) + mt = hl.filter_intervals(mt, hl.literal(rg.par, interval_type), keep=False) mt = mt.filter_rows((mt[aaf] > aaf_threshold) & (mt[aaf] < (1 - aaf_threshold))) mt = mt.annotate_cols(ib=agg.inbreeding(mt.call, mt[aaf])) kt = mt.select_cols( - is_female=hl.if_else(mt.ib.f_stat < female_threshold, - True, - hl.if_else(mt.ib.f_stat > male_threshold, - False, - hl.missing(tbool))), - **mt.ib).cols() + is_female=hl.if_else( + mt.ib.f_stat < female_threshold, True, hl.if_else(mt.ib.f_stat > male_threshold, False, hl.missing(tbool)) + ), + **mt.ib, + ).cols() return kt def _get_regression_row_fields(mt, pass_through, method) -> Dict[str, str]: - row_fields = dict(zip(mt.row_key.keys(), mt.row_key.keys())) for f in pass_through: if isinstance(f, str): if f not in mt.row: - raise ValueError(f"'{method}/pass_through': MatrixTable has no row field {repr(f)}") + raise ValueError(f"'{method}/pass_through': MatrixTable has no row field {f!r}") if f in row_fields: # allow silent pass through of key fields if f in mt.row_key: pass else: - raise ValueError(f"'{method}/pass_through': found duplicated field {repr(f)}") + raise ValueError(f"'{method}/pass_through': found duplicated field {f!r}") row_fields[f] = mt[f] else: assert isinstance(f, Expression) if not f._ir.is_nested_field: raise ValueError(f"'{method}/pass_through': expect fields or nested fields, not complex expressions") if not f._indices == mt._row_indices: - raise ExpressionException(f"'{method}/pass_through': require row-indexed fields, found indices {f._indices.axes}") + raise ExpressionException( + f"'{method}/pass_through': require row-indexed fields, found 
indices {f._indices.axes}" + ) name = f._ir.name if name in row_fields: # allow silent pass through of key fields if not (name in mt.row_key and f._ir == mt[name]._ir): - raise ValueError(f"'{method}/pass_through': found duplicated field {repr(name)}") + raise ValueError(f"'{method}/pass_through': found duplicated field {name!r}") row_fields[name] = f for k in mt.row_key: del row_fields[k] return row_fields -@typecheck(y=oneof(expr_float64, sequenceof(expr_float64), sequenceof(sequenceof(expr_float64))), - x=expr_float64, - covariates=sequenceof(expr_float64), - block_size=int, - pass_through=sequenceof(oneof(str, Expression)), - weights=nullable(oneof(expr_float64, sequenceof(expr_float64)))) -def linear_regression_rows(y, x, covariates, block_size=16, pass_through=(), *, weights=None) -> hail.Table: +@typecheck( + y=oneof(expr_float64, sequenceof(expr_float64), sequenceof(sequenceof(expr_float64))), + x=expr_float64, + covariates=sequenceof(expr_float64), + block_size=int, + pass_through=sequenceof(oneof(str, Expression)), + weights=nullable(oneof(expr_float64, sequenceof(expr_float64))), +) +def linear_regression_rows(y, x, covariates, block_size=16, pass_through=(), *, weights=None) -> Table: r"""For each row, test an input variable for association with response variables using linear regression. @@ -339,12 +354,13 @@ def linear_regression_rows(y, x, covariates, block_size=16, pass_through=(), *, if is_chained and any(len(lst) == 0 for lst in y): raise ValueError("'linear_regression_rows': found empty inner list for 'y'") - y = [raise_unless_column_indexed('linear_regression_rows_nd/y', y) or ys - for ys in wrap_to_list(y) - for y in (ys if is_chained else [ys]) - ] + y = [ + raise_unless_column_indexed('linear_regression_rows_nd/y', y) or ys + for ys in wrap_to_list(y) + for y in (ys if is_chained else [ys]) + ] - for e in (itertools.chain.from_iterable(y) if is_chained else y): + for e in itertools.chain.from_iterable(y) if is_chained else y: analyze('linear_regression_rows/y', e, mt._col_indices) for e in covariates: @@ -368,11 +384,12 @@ def linear_regression_rows(y, x, covariates, block_size=16, pass_through=(), *, row_fields = _get_regression_row_fields(mt, pass_through, 'linear_regression_rows') # FIXME: selecting an existing entry field should be emitted as a SelectFields - mt = mt._select_all(col_exprs=dict(**y_dict, - **dict(zip(cov_field_names, covariates))), - row_exprs=row_fields, - col_key=[], - entry_exprs={x_field_name: x}) + mt = mt._select_all( + col_exprs=dict(**y_dict, **dict(zip(cov_field_names, covariates))), + row_exprs=row_fields, + col_key=[], + entry_exprs={x_field_name: x}, + ) config = { 'name': func, @@ -380,7 +397,7 @@ def linear_regression_rows(y, x, covariates, block_size=16, pass_through=(), *, 'xField': x_field_name, 'covFields': cov_field_names, 'rowBlockSize': block_size, - 'passThrough': [x for x in row_fields if x not in mt.row_key] + 'passThrough': [x for x in row_fields if x not in mt.row_key], } ht_result = Table(ir.MatrixToTableApply(mt._mir, config)) @@ -391,13 +408,15 @@ def linear_regression_rows(y, x, covariates, block_size=16, pass_through=(), *, return ht_result.persist() -@typecheck(y=oneof(expr_float64, sequenceof(expr_float64), sequenceof(sequenceof(expr_float64))), - x=expr_float64, - covariates=sequenceof(expr_float64), - block_size=int, - weights=nullable(oneof(expr_float64, sequenceof(expr_float64))), - pass_through=sequenceof(oneof(str, Expression))) -def _linear_regression_rows_nd(y, x, covariates, block_size=16, weights=None, 
pass_through=()) -> hail.Table: +@typecheck( + y=oneof(expr_float64, sequenceof(expr_float64), sequenceof(sequenceof(expr_float64))), + x=expr_float64, + covariates=sequenceof(expr_float64), + block_size=int, + weights=nullable(oneof(expr_float64, sequenceof(expr_float64))), + pass_through=sequenceof(oneof(str, Expression)), +) +def _linear_regression_rows_nd(y, x, covariates, block_size=16, weights=None, pass_through=()) -> Table: mt = matrix_table_source('linear_regression_rows_nd/x', x) raise_unless_entry_indexed('linear_regression_rows_nd/x', x) @@ -409,10 +428,11 @@ def _linear_regression_rows_nd(y, x, covariates, block_size=16, weights=None, pa if is_chained and any(len(lst) == 0 for lst in y): raise ValueError("'linear_regression_rows': found empty inner list for 'y'") - y = [raise_unless_column_indexed('linear_regression_rows_nd/y', y) or ys - for ys in wrap_to_list(y) - for y in (ys if is_chained else [ys]) - ] + y = [ + raise_unless_column_indexed('linear_regression_rows_nd/y', y) or ys + for ys in wrap_to_list(y) + for y in (ys if is_chained else [ys]) + ] if weights is not None: if y_is_list and is_chained and not isinstance(weights, list): @@ -424,7 +444,7 @@ def _linear_regression_rows_nd(y, x, covariates, block_size=16, weights=None, pa weights = wrap_to_list(weights) if weights is not None else None - for e in (itertools.chain.from_iterable(y) if is_chained else y): + for e in itertools.chain.from_iterable(y) if is_chained else y: analyze('linear_regression_rows_nd/y', e, mt._col_indices) for e in covariates: @@ -453,12 +473,12 @@ def _linear_regression_rows_nd(y, x, covariates, block_size=16, weights=None, pa row_field_names = _get_regression_row_fields(mt, pass_through, 'linear_regression_rows_nd') # FIXME: selecting an existing entry field should be emitted as a SelectFields - mt = mt._select_all(col_exprs=dict(**y_dict, - **weight_dict, - **dict(zip(cov_field_names, covariates))), - row_exprs=row_field_names, - col_key=[], - entry_exprs={x_field_name: x}) + mt = mt._select_all( + col_exprs=dict(**y_dict, **weight_dict, **dict(zip(cov_field_names, covariates))), + row_exprs=row_field_names, + col_key=[], + entry_exprs={x_field_name: x}, + ) entries_field_name = 'ent' sample_field_name = "by_sample" @@ -487,58 +507,110 @@ def no_missing(hail_array): def setup_globals(ht): # cov_arrays is per sample, then per cov. 
if covariates: - ht = ht.annotate_globals(cov_arrays=ht[sample_field_name].map(lambda sample_struct: [sample_struct[cov_name] for cov_name in cov_field_names])) + ht = ht.annotate_globals( + cov_arrays=ht[sample_field_name].map( + lambda sample_struct: [sample_struct[cov_name] for cov_name in cov_field_names] + ) + ) else: - ht = ht.annotate_globals(cov_arrays=ht[sample_field_name].map(lambda sample_struct: hl.empty_array(hl.tfloat64))) + ht = ht.annotate_globals( + cov_arrays=ht[sample_field_name].map(lambda sample_struct: hl.empty_array(hl.tfloat64)) + ) - y_arrays_per_group = [ht[sample_field_name].map(lambda sample_struct: [sample_struct[y_name] for y_name in one_y_field_name_set]) for one_y_field_name_set in y_field_name_groups] + y_arrays_per_group = [ + ht[sample_field_name].map(lambda sample_struct: [sample_struct[y_name] for y_name in one_y_field_name_set]) + for one_y_field_name_set in y_field_name_groups + ] if weight_field_names: - weight_arrays = ht[sample_field_name].map(lambda sample_struct: [sample_struct[weight_name] for weight_name in weight_field_names]) + weight_arrays = ht[sample_field_name].map( + lambda sample_struct: [sample_struct[weight_name] for weight_name in weight_field_names] + ) else: weight_arrays = ht[sample_field_name].map(lambda sample_struct: hl.empty_array(hl.tfloat64)) - ht = ht.annotate_globals( - y_arrays_per_group=y_arrays_per_group, - weight_arrays=weight_arrays - ) + ht = ht.annotate_globals(y_arrays_per_group=y_arrays_per_group, weight_arrays=weight_arrays) ht = ht.annotate_globals(all_covs_defined=ht.cov_arrays.map(lambda sample_covs: no_missing(sample_covs))) def get_kept_samples(group_idx, sample_ys): # sample_ys is an array of samples, with each element being an array of the y_values - return hl.enumerate(sample_ys).filter( - lambda idx_and_y_values: ht.all_covs_defined[idx_and_y_values[0]] & no_missing(idx_and_y_values[1]) & (hl.is_defined(ht.weight_arrays[idx_and_y_values[0]][group_idx]) if weights else True) - ).map(lambda idx_and_y_values: idx_and_y_values[0]) + return ( + hl.enumerate(sample_ys) + .filter( + lambda idx_and_y_values: ht.all_covs_defined[idx_and_y_values[0]] + & no_missing(idx_and_y_values[1]) + & (hl.is_defined(ht.weight_arrays[idx_and_y_values[0]][group_idx]) if weights else True) + ) + .map(lambda idx_and_y_values: idx_and_y_values[0]) + ) ht = ht.annotate_globals(kept_samples=hl.enumerate(ht.y_arrays_per_group).starmap(get_kept_samples)) - ht = ht.annotate_globals(y_nds=hl.zip(ht.kept_samples, ht.y_arrays_per_group).starmap( - lambda sample_indices, y_arrays: hl.nd.array(sample_indices.map(lambda idx: y_arrays[idx])))) - ht = ht.annotate_globals(cov_nds=ht.kept_samples.map(lambda group: hl.nd.array(group.map(lambda idx: ht.cov_arrays[idx])))) + ht = ht.annotate_globals( + y_nds=hl.zip(ht.kept_samples, ht.y_arrays_per_group).starmap( + lambda sample_indices, y_arrays: hl.nd.array(sample_indices.map(lambda idx: y_arrays[idx])) + ) + ) + ht = ht.annotate_globals( + cov_nds=ht.kept_samples.map(lambda group: hl.nd.array(group.map(lambda idx: ht.cov_arrays[idx]))) + ) if weights is None: ht = ht.annotate_globals(sqrt_weights=hl.missing(hl.tarray(hl.tndarray(hl.tfloat64, 2)))) ht = ht.annotate_globals(scaled_y_nds=ht.y_nds) ht = ht.annotate_globals(scaled_cov_nds=ht.cov_nds) else: - ht = ht.annotate_globals(weight_nds=hl.enumerate(ht.kept_samples).starmap( - lambda group_idx, group_sample_indices: hl.nd.array(group_sample_indices.map(lambda group_sample_idx: ht.weight_arrays[group_sample_idx][group_idx])))) - ht = 
ht.annotate_globals(sqrt_weights=ht.weight_nds.map(lambda weight_nd: weight_nd.map(lambda e: hl.sqrt(e)))) - ht = ht.annotate_globals(scaled_y_nds=hl.zip(ht.y_nds, ht.sqrt_weights).starmap(lambda y, sqrt_weight: y * sqrt_weight.reshape(-1, 1))) - ht = ht.annotate_globals(scaled_cov_nds=hl.zip(ht.cov_nds, ht.sqrt_weights).starmap(lambda cov, sqrt_weight: cov * sqrt_weight.reshape(-1, 1))) + ht = ht.annotate_globals( + weight_nds=hl.enumerate(ht.kept_samples).starmap( + lambda group_idx, group_sample_indices: hl.nd.array( + group_sample_indices.map(lambda group_sample_idx: ht.weight_arrays[group_sample_idx][group_idx]) + ) + ) + ) + ht = ht.annotate_globals( + sqrt_weights=ht.weight_nds.map(lambda weight_nd: weight_nd.map(lambda e: hl.sqrt(e))) + ) + ht = ht.annotate_globals( + scaled_y_nds=hl.zip(ht.y_nds, ht.sqrt_weights).starmap( + lambda y, sqrt_weight: y * sqrt_weight.reshape(-1, 1) + ) + ) + ht = ht.annotate_globals( + scaled_cov_nds=hl.zip(ht.cov_nds, ht.sqrt_weights).starmap( + lambda cov, sqrt_weight: cov * sqrt_weight.reshape(-1, 1) + ) + ) k = builtins.len(covariates) ht = ht.annotate_globals(ns=ht.kept_samples.map(lambda one_sample_set: hl.len(one_sample_set))) def log_message(i): if is_chained: - return "linear regression_rows[" + hl.str(i) + "] running on " + hl.str(ht.ns[i]) + " samples for " + hl.str(ht.scaled_y_nds[i].shape[1]) + f" response variables y, with input variables x, and {len(covariates)} additional covariates..." + return ( + "linear regression_rows[" + + hl.str(i) + + "] running on " + + hl.str(ht.ns[i]) + + " samples for " + + hl.str(ht.scaled_y_nds[i].shape[1]) + + f" response variables y, with input variables x, and {len(covariates)} additional covariates..." + ) else: - return "linear_regression_rows running on " + hl.str(ht.ns[0]) + " samples for " + hl.str(ht.scaled_y_nds[i].shape[1]) + f" response variables y, with input variables x, and {len(covariates)} additional covariates..." + return ( + "linear_regression_rows running on " + + hl.str(ht.ns[0]) + + " samples for " + + hl.str(ht.scaled_y_nds[i].shape[1]) + + f" response variables y, with input variables x, and {len(covariates)} additional covariates..." + ) ht = ht.annotate_globals(ns=hl.range(num_y_lists).map(lambda i: hl._console_log(log_message(i), ht.ns[i]))) - ht = ht.annotate_globals(cov_Qts=hl.if_else(k > 0, - ht.scaled_cov_nds.map(lambda one_cov_nd: hl.nd.qr(one_cov_nd)[0].T), - ht.ns.map(lambda n: hl.nd.zeros((0, n))))) + ht = ht.annotate_globals( + cov_Qts=hl.if_else( + k > 0, + ht.scaled_cov_nds.map(lambda one_cov_nd: hl.nd.qr(one_cov_nd)[0].T), + ht.ns.map(lambda n: hl.nd.zeros((0, n))), + ) + ) ht = ht.annotate_globals(Qtys=hl.zip(ht.cov_Qts, ht.scaled_y_nds).starmap(lambda cov_qt, y: cov_qt @ y)) return ht.select_globals( @@ -549,7 +621,10 @@ def log_message(i): ds=ht.ns.map(lambda n: n - k - 1), __cov_Qts=ht.cov_Qts, __Qtys=ht.Qtys, - __yyps=hl.range(num_y_lists).map(lambda i: dot_rows_with_themselves(ht.scaled_y_nds[i].T) - dot_rows_with_themselves(ht.Qtys[i].T))) + __yyps=hl.range(num_y_lists).map( + lambda i: dot_rows_with_themselves(ht.scaled_y_nds[i].T) - dot_rows_with_themselves(ht.Qtys[i].T) + ), + ) ht = setup_globals(ht) @@ -559,9 +634,20 @@ def process_block(block): # Processes one block group based on given idx. Returns a single struct. 
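# Editor's sketch (illustrative, not part of the patch): the block below projects the
# covariates out of both the phenotype and each variant's genotype vector using the
# precomputed orthonormal basis Q, then recovers beta, its standard error, and the
# t-statistic from a handful of inner products. A hypothetical NumPy equivalent for a
# single phenotype and a single variant:
import numpy as np

def linreg_one_variant(x, y, Q, d):
    # Q: (n, k) orthonormal covariate basis; d = n - k - 1 residual degrees of freedom
    qtx, qty = Q.T @ x, Q.T @ y
    xxp = x @ x - qtx @ qtx      # x'x with covariates projected out
    xyp = x @ y - qtx @ qty      # x'y with covariates projected out
    yyp = y @ y - qty @ qty      # y'y with covariates projected out
    beta = xyp / xxp
    se = np.sqrt((yyp / xxp - beta * beta) / d)
    return beta, beta / se       # t-statistic; p-value comes from Student's t with d dof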
def process_y_group(idx): if weights is not None: - X = (hl.nd.array(block[entries_field_name].map(lambda row: mean_impute(select_array_indices(row, ht.kept_samples[idx])))) * ht.__sqrt_weight_nds[idx]).T + X = ( + hl.nd.array( + block[entries_field_name].map( + lambda row: mean_impute(select_array_indices(row, ht.kept_samples[idx])) + ) + ) + * ht.__sqrt_weight_nds[idx] + ).T else: - X = hl.nd.array(block[entries_field_name].map(lambda row: mean_impute(select_array_indices(row, ht.kept_samples[idx])))).T + X = hl.nd.array( + block[entries_field_name].map( + lambda row: mean_impute(select_array_indices(row, ht.kept_samples[idx])) + ) + ).T n = ht.ns[idx] sum_x = X.sum(0) Qtx = ht.__cov_Qts[idx] @ X @@ -569,15 +655,28 @@ def process_y_group(idx): xyp = ytx - (ht.__Qtys[idx].T @ Qtx) xxpRec = (dot_rows_with_themselves(X.T) - dot_rows_with_themselves(Qtx.T)).map(lambda entry: 1 / entry) b = xyp * xxpRec - se = ((1.0 / ht.ds[idx]) * (ht.__yyps[idx].reshape((-1, 1)) @ xxpRec.reshape((1, -1)) - (b * b))).map(lambda entry: hl.sqrt(entry)) + se = ((1.0 / ht.ds[idx]) * (ht.__yyps[idx].reshape((-1, 1)) @ xxpRec.reshape((1, -1)) - (b * b))).map( + lambda entry: hl.sqrt(entry) + ) t = b / se - return hl.rbind(t, lambda t: - hl.rbind(ht.ds[idx], lambda d: - hl.rbind(t.map(lambda entry: 2 * hl.expr.functions.pT(-hl.abs(entry), d, True, False)), lambda p: - hl.struct(n=hl.range(rows_in_block).map(lambda i: n), sum_x=sum_x._data_array(), - y_transpose_x=ytx.T._data_array(), beta=b.T._data_array(), - standard_error=se.T._data_array(), t_stat=t.T._data_array(), - p_value=p.T._data_array())))) + return hl.rbind( + t, + lambda t: hl.rbind( + ht.ds[idx], + lambda d: hl.rbind( + t.map(lambda entry: 2 * hl.expr.functions.pT(-hl.abs(entry), d, True, False)), + lambda p: hl.struct( + n=hl.range(rows_in_block).map(lambda i: n), + sum_x=sum_x._data_array(), + y_transpose_x=ytx.T._data_array(), + beta=b.T._data_array(), + standard_error=se.T._data_array(), + t_stat=t.T._data_array(), + p_value=p.T._data_array(), + ), + ), + ), + ) per_y_list = hl.range(num_y_lists).map(lambda i: process_y_group(i)) @@ -588,11 +687,10 @@ def build_row(row_idx): idxth_keys = {field_name: block[field_name][row_idx] for field_name in key_field_names} computed_row_field_names = ['n', 'sum_x', 'y_transpose_x', 'beta', 'standard_error', 't_stat', 'p_value'] computed_row_fields = { - field_name: per_y_list.map(lambda one_y: one_y[field_name][row_idx]) for field_name in computed_row_field_names - } - pass_through_rows = { - field_name: block[field_name][row_idx] for field_name in row_field_names + field_name: per_y_list.map(lambda one_y: one_y[field_name][row_idx]) + for field_name in computed_row_field_names } + pass_through_rows = {field_name: block[field_name][row_idx] for field_name in row_field_names} if not is_chained: computed_row_fields = {key: value[0] for key, value in computed_row_fields.items()} @@ -621,21 +719,18 @@ def process_partition(part): return res -@typecheck(test=enumeration('wald', 'lrt', 'score', 'firth'), - y=oneof(expr_float64, sequenceof(expr_float64)), - x=expr_float64, - covariates=sequenceof(expr_float64), - pass_through=sequenceof(oneof(str, Expression)), - max_iterations=nullable(int), - tolerance=nullable(float)) -def logistic_regression_rows(test, - y, - x, - covariates, - pass_through=(), - *, - max_iterations: Optional[int] = None, - tolerance: Optional[float] = None) -> hail.Table: +@typecheck( + test=enumeration('wald', 'lrt', 'score', 'firth'), + y=oneof(expr_float64, sequenceof(expr_float64)), + 
x=expr_float64, + covariates=sequenceof(expr_float64), + pass_through=sequenceof(oneof(str, Expression)), + max_iterations=nullable(int), + tolerance=nullable(float), +) +def logistic_regression_rows( + test, y, x, covariates, pass_through=(), *, max_iterations: Optional[int] = None, tolerance: Optional[float] = None +) -> Table: r"""For each row, test an input variable for association with a binary response variable using logistic regression. @@ -872,7 +967,8 @@ def logistic_regression_rows(test, if hl.current_backend().requires_lowering: return _logistic_regression_rows_nd( - test, y, x, covariates, pass_through, max_iterations=max_iterations, tolerance=tolerance) + test, y, x, covariates, pass_through, max_iterations=max_iterations, tolerance=tolerance + ) if tolerance is None: tolerance = 1e-6 @@ -887,9 +983,7 @@ def logistic_regression_rows(test, y_is_list = isinstance(y, list) if y_is_list and len(y) == 0: raise ValueError("'logistic_regression_rows': found no values for 'y'") - y = [raise_unless_column_indexed('logistic_regression_rows/y', y) or y - for y in wrap_to_list(y) - ] + y = [raise_unless_column_indexed('logistic_regression_rows/y', y) or y for y in wrap_to_list(y)] for e in covariates: analyze('logistic_regression_rows/covariates', e, mt._col_indices) @@ -905,11 +999,12 @@ def logistic_regression_rows(test, row_fields = _get_regression_row_fields(mt, pass_through, 'logistic_regression_rows') # FIXME: selecting an existing entry field should be emitted as a SelectFields - mt = mt._select_all(col_exprs=dict(**y_dict, - **dict(zip(cov_field_names, covariates))), - row_exprs=row_fields, - col_key=[], - entry_exprs={x_field_name: x}) + mt = mt._select_all( + col_exprs=dict(**y_dict, **dict(zip(cov_field_names, covariates))), + row_exprs=row_fields, + col_key=[], + entry_exprs={x_field_name: x}, + ) config = { 'name': 'LogisticRegression', @@ -919,7 +1014,7 @@ def logistic_regression_rows(test, 'covFields': cov_field_names, 'passThrough': [x for x in row_fields if x not in mt.row_key], 'maxIterations': max_iterations, - 'tolerance': tolerance + 'tolerance': tolerance, } result = Table(ir.MatrixToTableApply(mt._mir, config)) @@ -936,19 +1031,20 @@ def mean_impute(hl_array): return hl_array.map(lambda entry: hl.coalesce(entry, non_missing_mean)) -sigmoid = hl.expit +sigmoid = expit def nd_max(hl_nd): return hl.max(hl.array(hl_nd.reshape(-1))) -def logreg_fit(X: hl.NDArrayNumericExpression, # (K,) - y: hl.NDArrayNumericExpression, # (N, K) - null_fit: Optional[hl.StructExpression], - max_iterations: int, - tolerance: float - ) -> hl.StructExpression: +def logreg_fit( + X: NDArrayNumericExpression, # (K,) + y: NDArrayNumericExpression, # (N, K) + null_fit: Optional[StructExpression], + max_iterations: int, + tolerance: float, +) -> StructExpression: """Iteratively reweighted least squares to fit the model y ~ Bernoulli(logit(X \beta)) When fitting the null model, K=n_covariates, otherwise K=n_covariates + 1. 
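logreg_fit expresses this update with Hail ndarray expressions in the hunks below; as a point of reference, here is a minimal NumPy sketch of one Newton/IRLS iteration for the same Bernoulli model (the function and variable names are illustrative, not part of the patch):

import numpy as np

def logreg_newton_step(X, y, b):
    # mean, score (gradient), and expected Fisher information for logistic regression
    mu = 1.0 / (1.0 + np.exp(-(X @ b)))
    score = X.T @ (y - mu)
    fisher = X.T @ (X * (mu * (1.0 - mu))[:, None])
    # the fit iterates this step until max|delta| < tolerance or max_iterations is reached
    delta = np.linalg.solve(fisher, score)
    return b + delta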
@@ -986,10 +1082,7 @@ def logreg_fit(X: hl.NDArrayNumericExpression, # (K,) fisher10 = fisher01.T fisher11 = X1.T @ (X1 * (mu * (1 - mu)).reshape(-1, 1)) - fisher = hl.nd.vstack([ - hl.nd.hstack([fisher00, fisher01]), - hl.nd.hstack([fisher10, fisher11]) - ]) + fisher = hl.nd.vstack([hl.nd.hstack([fisher00, fisher01]), hl.nd.hstack([fisher10, fisher11])]) dtype = numerical_regression_fit_dtype blank_struct = hl.struct(**{k: hl.missing(dtype[k]) for k in dtype}) @@ -1003,14 +1096,31 @@ def cont(exploded, delta_b, max_delta_b): next_score = X.T @ (y - next_mu) next_fisher = X.T @ (X * (next_mu * (1 - next_mu)).reshape(-1, 1)) - return (hl.case() - .when(exploded | hl.is_nan(delta_b[0]), - blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=True)) - .when(max_delta_b < tolerance, - hl.struct(b=b, score=score, fisher=fisher, mu=mu, n_iterations=iteration, log_lkhd=log_lkhd, converged=True, exploded=False)) - .when(iteration == max_iterations, - blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=False)) - .default(recur(iteration + 1, next_b, next_mu, next_score, next_fisher))) + return ( + hl.case() + .when( + exploded | hl.is_nan(delta_b[0]), + blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=True), + ) + .when( + max_delta_b < tolerance, + hl.struct( + b=b, + score=score, + fisher=fisher, + mu=mu, + n_iterations=iteration, + log_lkhd=log_lkhd, + converged=True, + exploded=False, + ), + ) + .when( + iteration == max_iterations, + blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=False), + ) + .default(recur(iteration + 1, next_b, next_mu, next_score, next_fisher)) + ) delta_b_struct = hl.nd.solve(fisher, score, no_crash=True) exploded = delta_b_struct.failed @@ -1032,7 +1142,8 @@ def wald_test(X, fit): standard_error=se[X.shape[1] - 1], z_stat=z[X.shape[1] - 1], p_value=p[X.shape[1] - 1], - fit=fit.select('n_iterations', 'converged', 'exploded')) + fit=fit.select('n_iterations', 'converged', 'exploded'), + ) def lrt_test(X, null_fit, fit): @@ -1043,7 +1154,8 @@ def lrt_test(X, null_fit, fit): beta=fit.b[X.shape[1] - 1], chi_sq_stat=chi_sq, p_value=p, - fit=fit.select('n_iterations', 'converged', 'exploded')) + fit=fit.select('n_iterations', 'converged', 'exploded'), + ) def logistic_score_test(X, y, null_fit): @@ -1065,29 +1177,24 @@ def logistic_score_test(X, y, null_fit): fisher10 = fisher01.T fisher11 = X1.T @ (X1 * (mu * (1 - mu)).reshape(-1, 1)) - fisher = hl.nd.vstack([ - hl.nd.hstack([fisher00, fisher01]), - hl.nd.hstack([fisher10, fisher11]) - ]) + fisher = hl.nd.vstack([hl.nd.hstack([fisher00, fisher01]), hl.nd.hstack([fisher10, fisher11])]) solve_attempt = hl.nd.solve(fisher, score, no_crash=True) - chi_sq = hl.or_missing( - ~solve_attempt.failed, - (score * solve_attempt.solution).sum() - ) + chi_sq = hl.or_missing(~solve_attempt.failed, (score * solve_attempt.solution).sum()) p = hl.pchisqtail(chi_sq, m - m0) return hl.struct(chi_sq_stat=chi_sq, p_value=p) -def _firth_fit(b: hl.NDArrayNumericExpression, # (K,) - X: hl.NDArrayNumericExpression, # (N, K) - y: hl.NDArrayNumericExpression, # (N,) - max_iterations: int, - tolerance: float - ) -> hl.StructExpression: +def _firth_fit( + b: NDArrayNumericExpression, # (K,) + X: NDArrayNumericExpression, # (N, K) + y: NDArrayNumericExpression, # (N,) + max_iterations: int, + tolerance: float, +) -> StructExpression: """Iteratively reweighted least squares using Firth's regression 
to fit the model y ~ Bernoulli(logit(X \beta)) When fitting the null model, K=n_covariates, otherwise K=n_covariates + 1. @@ -1099,7 +1206,7 @@ def _firth_fit(b: hl.NDArrayNumericExpression, # (K,) dtype = numerical_regression_fit_dtype._drop_fields(['score', 'fisher']) blank_struct = hl.struct(**{k: hl.missing(dtype[k]) for k in dtype}) - X_bslice = X[:, :b.shape[0]] + X_bslice = X[:, : b.shape[0]] def fit(recur, iteration, b): def cont(exploded, delta_b, max_delta_b): @@ -1109,14 +1216,22 @@ def cont(exploded, delta_b, max_delta_b): next_b = b + delta_b - return (hl.case() - .when(exploded | hl.is_nan(delta_b[0]), - blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=True)) - .when(max_delta_b < tolerance, - hl.struct(b=b, mu=mu, n_iterations=iteration, log_lkhd=log_lkhd, converged=True, exploded=False)) - .when(iteration == max_iterations, - blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=False)) - .default(recur(iteration + 1, next_b))) + return ( + hl.case() + .when( + exploded | hl.is_nan(delta_b[0]), + blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=True), + ) + .when( + max_delta_b < tolerance, + hl.struct(b=b, mu=mu, n_iterations=iteration, log_lkhd=log_lkhd, converged=True, exploded=False), + ) + .when( + iteration == max_iterations, + blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=False), + ) + .default(recur(iteration + 1, next_b)) + ) m = b.shape[0] # n_covariates or n_covariates + 1, depending on improved null fit vs full fit mu = sigmoid(X_bslice @ b) @@ -1139,7 +1254,7 @@ def cont(exploded, delta_b, max_delta_b): return hl.experimental.loop(fit, dtype, 1, b) -def _firth_test(null_fit, X, y, max_iterations, tolerance) -> hl.StructExpression: +def _firth_test(null_fit, X, y, max_iterations, tolerance) -> StructExpression: firth_improved_null_fit = _firth_fit(null_fit.b, X, y, max_iterations=max_iterations, tolerance=tolerance) dof = 1 # 1 variant @@ -1156,45 +1271,45 @@ def cont2(firth_fit): chi_sq_stat=hl.missing(hl.tfloat64), p_value=hl.missing(hl.tfloat64), firth_null_fit=hl.missing(firth_improved_null_fit.dtype), - fit=hl.missing(firth_fit.dtype) + fit=hl.missing(firth_fit.dtype), ) - return (hl.case() - .when(firth_improved_null_fit.converged, - hl.case() - .when(firth_fit.converged, - hl.struct( - beta=firth_fit.b[firth_fit.b.shape[0] - 1], - chi_sq_stat=firth_chi_sq, - p_value=firth_p, - firth_null_fit=firth_improved_null_fit, - fit=firth_fit - )) - .default(blank_struct.annotate( - firth_null_fit=firth_improved_null_fit, - fit=firth_fit - ))) - .default(blank_struct.annotate( - firth_null_fit=firth_improved_null_fit - ))) + return ( + hl.case() + .when( + firth_improved_null_fit.converged, + hl.case() + .when( + firth_fit.converged, + hl.struct( + beta=firth_fit.b[firth_fit.b.shape[0] - 1], + chi_sq_stat=firth_chi_sq, + p_value=firth_p, + firth_null_fit=firth_improved_null_fit, + fit=firth_fit, + ), + ) + .default(blank_struct.annotate(firth_null_fit=firth_improved_null_fit, fit=firth_fit)), + ) + .default(blank_struct.annotate(firth_null_fit=firth_improved_null_fit)) + ) + return hl.bind(cont2, firth_fit) + return hl.bind(cont, firth_improved_null_fit) -@typecheck(test=enumeration('wald', 'lrt', 'score', 'firth'), - y=oneof(expr_float64, sequenceof(expr_float64)), - x=expr_float64, - covariates=sequenceof(expr_float64), - pass_through=sequenceof(oneof(str, Expression)), - max_iterations=nullable(int), 
- tolerance=nullable(float)) -def _logistic_regression_rows_nd(test, - y, - x, - covariates, - pass_through=(), - *, - max_iterations: Optional[int] = None, - tolerance: Optional[float] = None) -> hail.Table: +@typecheck( + test=enumeration('wald', 'lrt', 'score', 'firth'), + y=oneof(expr_float64, sequenceof(expr_float64)), + x=expr_float64, + covariates=sequenceof(expr_float64), + pass_through=sequenceof(oneof(str, Expression)), + max_iterations=nullable(int), + tolerance=nullable(float), +) +def _logistic_regression_rows_nd( + test, y, x, covariates, pass_through=(), *, max_iterations: Optional[int] = None, tolerance: Optional[float] = None +) -> Table: r"""For each row, test an input variable for association with a binary response variable using logistic regression. @@ -1427,9 +1542,7 @@ def _logistic_regression_rows_nd(test, if y_is_list and len(y) == 0: raise ValueError("'logistic_regression_rows': found no values for 'y'") - y = [raise_unless_column_indexed('logistic_regression_rows/y', y) or y - for y in wrap_to_list(y) - ] + y = [raise_unless_column_indexed('logistic_regression_rows/y', y) or y for y in wrap_to_list(y)] for e in covariates: analyze('logistic_regression_rows/covariates', e, mt._col_indices) @@ -1448,16 +1561,19 @@ def _logistic_regression_rows_nd(test, mt = mt.filter_cols(hl.array(y + covariates).all(hl.is_defined)) # FIXME: selecting an existing entry field should be emitted as a SelectFields - mt = mt._select_all(col_exprs=dict(**y_dict, - **dict(zip(cov_field_names, covariates))), - row_exprs=row_fields, - col_key=[], - entry_exprs={x_field_name: x}) + mt = mt._select_all( + col_exprs=dict(**y_dict, **dict(zip(cov_field_names, covariates))), + row_exprs=row_fields, + col_key=[], + entry_exprs={x_field_name: x}, + ) ht = mt._localize_entries('entries', 'samples') # covmat rows are samples, columns are the different covariates - ht = ht.annotate_globals(covmat=hl.nd.array(ht.samples.map(lambda s: [s[cov_name] for cov_name in cov_field_names]))) + ht = ht.annotate_globals( + covmat=hl.nd.array(ht.samples.map(lambda s: [s[cov_name] for cov_name in cov_field_names])) + ) # yvecs is a list of sample-length vectors, one for each dependent variable. 
ht = ht.annotate_globals(yvecs=[hl.nd.array(ht.samples[y_name]) for y_name in y_field_names]) @@ -1467,16 +1583,29 @@ def fit_null(yvec): def error_if_not_converged(null_fit): return ( hl.case() - .when(~null_fit.exploded, - (hl.case() - .when(null_fit.converged, null_fit) - .or_error("Failed to fit logistic regression null model (standard MLE with covariates only): " - "Newton iteration failed to converge"))) - .or_error(hl.format("Failed to fit logistic regression null model (standard MLE with covariates only): " - "exploded at Newton iteration %d", null_fit.n_iterations))) + .when( + ~null_fit.exploded, + ( + hl.case() + .when(null_fit.converged, null_fit) + .or_error( + "Failed to fit logistic regression null model (standard MLE with covariates only): " + "Newton iteration failed to converge" + ) + ), + ) + .or_error( + hl.format( + "Failed to fit logistic regression null model (standard MLE with covariates only): " + "exploded at Newton iteration %d", + null_fit.n_iterations, + ) + ) + ) null_fit = logreg_fit(ht.covmat, yvec, None, max_iterations=max_iterations, tolerance=tolerance) return hl.bind(error_if_not_converged, null_fit) + ht = ht.annotate_globals(null_fits=ht.yvecs.map(fit_null)) ht = ht.transmute(x=hl.nd.array(mean_impute(ht.entries[x_field_name]))) @@ -1493,9 +1622,9 @@ def run_test(yvec, null_fit): return wald_test(ht.covs_and_x, test_fit) assert test == 'lrt', test return lrt_test(ht.covs_and_x, null_fit, test_fit) + ht = ht.select( - logistic_regression=hl.starmap(run_test, hl.zip(ht.yvecs, ht.null_fits)), - **{f: ht[f] for f in row_fields} + logistic_regression=hl.starmap(run_test, hl.zip(ht.yvecs, ht.null_fits)), **{f: ht[f] for f in row_fields} ) assert 'null_fits' not in row_fields assert 'logistic_regression' not in row_fields @@ -1509,21 +1638,18 @@ def run_test(yvec, null_fit): return ht -@typecheck(test=enumeration('wald', 'lrt', 'score'), - y=expr_float64, - x=expr_float64, - covariates=sequenceof(expr_float64), - pass_through=sequenceof(oneof(str, Expression)), - max_iterations=int, - tolerance=nullable(float)) -def poisson_regression_rows(test, - y, - x, - covariates, - pass_through=(), - *, - max_iterations: int = 25, - tolerance: Optional[float] = None) -> Table: +@typecheck( + test=enumeration('wald', 'lrt', 'score'), + y=expr_float64, + x=expr_float64, + covariates=sequenceof(expr_float64), + pass_through=sequenceof(oneof(str, Expression)), + max_iterations=int, + tolerance=nullable(float), +) +def poisson_regression_rows( + test, y, x, covariates, pass_through=(), *, max_iterations: int = 25, tolerance: Optional[float] = None +) -> Table: r"""For each row, test an input variable for association with a count response variable using `Poisson regression `__. 
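For context, a hypothetical call to this function (the MatrixTable mt and its fields GT, pheno, and cov1 are assumptions for illustration; the literal 1.0 supplies the intercept covariate):

import hail as hl

results = hl.poisson_regression_rows(
    test='wald',
    y=mt.pheno,                  # count phenotype, one value per sample (mt is assumed)
    x=mt.GT.n_alt_alleles(),     # per-entry input variable
    covariates=[1.0, mt.cov1],   # include 1.0 so the model has an intercept
)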
@@ -1559,7 +1685,9 @@ def poisson_regression_rows(test, """ if hl.current_backend().requires_lowering: - return _lowered_poisson_regression_rows(test, y, x, covariates, pass_through, max_iterations=max_iterations, tolerance=tolerance) + return _lowered_poisson_regression_rows( + test, y, x, covariates, pass_through, max_iterations=max_iterations, tolerance=tolerance + ) if tolerance is None: tolerance = 1e-6 @@ -1586,11 +1714,12 @@ def poisson_regression_rows(test, row_fields = _get_regression_row_fields(mt, pass_through, 'poisson_regression_rows') # FIXME: selecting an existing entry field should be emitted as a SelectFields - mt = mt._select_all(col_exprs=dict(**{y_field_name: y}, - **dict(zip(cov_field_names, covariates))), - row_exprs=row_fields, - col_key=[], - entry_exprs={x_field_name: x}) + mt = mt._select_all( + col_exprs=dict(**{y_field_name: y}, **dict(zip(cov_field_names, covariates))), + row_exprs=row_fields, + col_key=[], + entry_exprs={x_field_name: x}, + ) config = { 'name': 'PoissonRegression', @@ -1600,27 +1729,24 @@ def poisson_regression_rows(test, 'covFields': cov_field_names, 'passThrough': [x for x in row_fields if x not in mt.row_key], 'maxIterations': max_iterations, - 'tolerance': tolerance + 'tolerance': tolerance, } return Table(ir.MatrixToTableApply(mt._mir, config)).persist() -@typecheck(test=enumeration('wald', 'lrt', 'score'), - y=expr_float64, - x=expr_float64, - covariates=sequenceof(expr_float64), - pass_through=sequenceof(oneof(str, Expression)), - max_iterations=int, - tolerance=nullable(float)) -def _lowered_poisson_regression_rows(test, - y, - x, - covariates, - pass_through=(), - *, - max_iterations: int = 25, - tolerance: Optional[float] = None): +@typecheck( + test=enumeration('wald', 'lrt', 'score'), + y=expr_float64, + x=expr_float64, + covariates=sequenceof(expr_float64), + pass_through=sequenceof(oneof(str, Expression)), + max_iterations=int, + tolerance=nullable(float), +) +def _lowered_poisson_regression_rows( + test, y, x, covariates, pass_through=(), *, max_iterations: int = 25, tolerance: Optional[float] = None +): assert max_iterations > 0 if tolerance is None: @@ -1637,37 +1763,35 @@ def _lowered_poisson_regression_rows(test, row_exprs = _get_regression_row_fields(mt, pass_through, '_lowered_poisson_regression_rows') mt = mt._select_all( - row_exprs=dict( - pass_through=hl.struct(**row_exprs) - ), - col_exprs=dict( - y=y, - covariates=covariates - ), - entry_exprs=dict( - x=x - ) + row_exprs=dict(pass_through=hl.struct(**row_exprs)), + col_exprs=dict(y=y, covariates=covariates), + entry_exprs=dict(x=x), ) # FIXME: the order of the columns is irrelevant to regression mt = mt.key_cols_by() - mt = mt.filter_cols( - hl.all(hl.is_defined(mt.y), *[hl.is_defined(mt.covariates[i]) for i in range(k)]) - ) + mt = mt.filter_cols(hl.all(hl.is_defined(mt.y), *[hl.is_defined(mt.covariates[i]) for i in range(k)])) - mt = mt.annotate_globals(**mt.aggregate_cols(hl.struct( - yvec=hl.agg.collect(hl.float(mt.y)), - covmat=hl.agg.collect(mt.covariates.map(hl.float)), - n=hl.agg.count() - ), _localize=False)) mt = mt.annotate_globals( - yvec=(hl.case() - .when(mt.n - k - 1 >= 1, hl.nd.array(mt.yvec)) - .or_error(hl.format( - "_lowered_poisson_regression_rows: insufficient degrees of freedom: n=%s, k=%s", - mt.n, k))), + **mt.aggregate_cols( + hl.struct( + yvec=hl.agg.collect(hl.float(mt.y)), + covmat=hl.agg.collect(mt.covariates.map(hl.float)), + n=hl.agg.count(), + ), + _localize=False, + ) + ) + mt = mt.annotate_globals( + yvec=( + hl.case() + .when(mt.n - 
k - 1 >= 1, hl.nd.array(mt.yvec)) + .or_error( + hl.format("_lowered_poisson_regression_rows: insufficient degrees of freedom: n=%s, k=%s", mt.n, k) + ) + ), covmat=hl.nd.array(mt.covmat), - n_complete_samples=mt.n + n_complete_samples=mt.n, ) covmat = mt.covmat yvec = mt.yvec @@ -1681,9 +1805,14 @@ def _lowered_poisson_regression_rows(test, fisher = (mu * covmat.T) @ covmat mt = mt.annotate_globals(null_fit=_poisson_fit(covmat, yvec, b, mu, score, fisher, max_iterations, tolerance)) mt = mt.annotate_globals( - null_fit=hl.case().when(mt.null_fit.converged, mt.null_fit).or_error( - hl.format('_lowered_poisson_regression_rows: null model did not converge: %s', - mt.null_fit.select('n_iterations', 'log_lkhd', 'converged', 'exploded'))) + null_fit=hl.case() + .when(mt.null_fit.converged, mt.null_fit) + .or_error( + hl.format( + '_lowered_poisson_regression_rows: null model did not converge: %s', + mt.null_fit.select('n_iterations', 'log_lkhd', 'converged', 'exploded'), + ) + ) ) mt = mt.annotate_rows(mean_x=hl.agg.mean(mt.x)) mt = mt.annotate_rows(xvec=hl.nd.array(hl.agg.collect(hl.coalesce(mt.x, mt.mean_x)))) @@ -1697,11 +1826,7 @@ def _lowered_poisson_regression_rows(test, if test == 'score': chi_sq, p = _poisson_score_test(null_fit, covmat, yvec, xvec) - return ht.select( - chi_sq_stat=chi_sq, - p_value=p, - **ht.pass_through - ).select_globals('null_fit') + return ht.select(chi_sq_stat=chi_sq, p_value=p, **ht.pass_through).select_globals('null_fit') X = hl.nd.hstack([covmat, xvec.T.reshape(-1, 1)]) b = hl.nd.hstack([null_fit.b, hl.nd.array([0.0])]) @@ -1713,35 +1838,27 @@ def _lowered_poisson_regression_rows(test, fisher01 = ((covmat.T * mu) @ xvec).reshape((-1, 1)) fisher10 = fisher01.T fisher11 = hl.nd.array([[(mu * xvec.T) @ xvec]]) - fisher = hl.nd.vstack([ - hl.nd.hstack([fisher00, fisher01]), - hl.nd.hstack([fisher10, fisher11]) - ]) + fisher = hl.nd.vstack([hl.nd.hstack([fisher00, fisher01]), hl.nd.hstack([fisher10, fisher11])]) test_fit = _poisson_fit(X, yvec, b, mu, score, fisher, max_iterations, tolerance) if test == 'lrt': - return ht.select( - test_fit=test_fit, - **lrt_test(X, null_fit, test_fit), - **ht.pass_through - ).select_globals('null_fit') + return ht.select(test_fit=test_fit, **lrt_test(X, null_fit, test_fit), **ht.pass_through).select_globals( + 'null_fit' + ) assert test == 'wald' - return ht.select( - test_fit=test_fit, - **wald_test(X, test_fit), - **ht.pass_through - ).select_globals('null_fit') - - -def _poisson_fit(X: hl.NDArrayNumericExpression, # (N, K) - y: hl.NDArrayNumericExpression, # (N,) - b: hl.NDArrayNumericExpression, # (K,) - mu: hl.NDArrayNumericExpression, # (N,) - score: hl.NDArrayNumericExpression, # (K,) - fisher: hl.NDArrayNumericExpression, # (K, K) - max_iterations: int, - tolerance: float - ) -> hl.StructExpression: + return ht.select(test_fit=test_fit, **wald_test(X, test_fit), **ht.pass_through).select_globals('null_fit') + + +def _poisson_fit( + X: NDArrayNumericExpression, # (N, K) + y: NDArrayNumericExpression, # (N,) + b: NDArrayNumericExpression, # (K,) + mu: NDArrayNumericExpression, # (N,) + score: NDArrayNumericExpression, # (K,) + fisher: NDArrayNumericExpression, # (K, K) + max_iterations: int, + tolerance: float, +) -> StructExpression: """Iteratively reweighted least squares to fit the model y ~ Poisson(exp(X \beta)) When fitting the null model, K=n_covariates, otherwise K=n_covariates + 1. 
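The Poisson fit mirrors the logistic IRLS sketch above; only the mean function and the information matrix change (compare the hunk below). A minimal NumPy version of one iteration, with illustrative names that are not part of the patch:

import numpy as np

def poisson_newton_step(X, y, b):
    mu = np.exp(X @ b)           # Poisson mean under the log link
    score = X.T @ (y - mu)
    fisher = (mu * X.T) @ X      # expected information
    return b + np.linalg.solve(fisher, score)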
@@ -1766,14 +1883,31 @@ def cont(exploded, delta_b, max_delta_b): next_score = X.T @ (y - next_mu) next_fisher = (next_mu * X.T) @ X - return (hl.case() - .when(exploded | hl.is_nan(delta_b[0]), - blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=True)) - .when(max_delta_b < tolerance, - hl.struct(b=b, score=score, fisher=fisher, mu=mu, n_iterations=iteration, log_lkhd=log_lkhd, converged=True, exploded=False)) - .when(iteration == max_iterations, - blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=False)) - .default(recur(iteration + 1, next_b, next_mu, next_score, next_fisher))) + return ( + hl.case() + .when( + exploded | hl.is_nan(delta_b[0]), + blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=True), + ) + .when( + max_delta_b < tolerance, + hl.struct( + b=b, + score=score, + fisher=fisher, + mu=mu, + n_iterations=iteration, + log_lkhd=log_lkhd, + converged=True, + exploded=False, + ), + ) + .when( + iteration == max_iterations, + blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=False), + ) + .default(recur(iteration + 1, next_b, next_mu, next_score, next_fisher)) + ) delta_b_struct = hl.nd.solve(fisher, score, no_crash=True) @@ -1799,26 +1933,15 @@ def _poisson_score_test(null_fit, covmat, y, xvec): fisher01 = ((mu * covmat.T) @ xvec).reshape((-1, 1)) fisher10 = fisher01.T fisher11 = hl.nd.array([[(mu * xvec.T) @ xvec]]) - fisher = hl.nd.vstack([ - hl.nd.hstack([fisher00, fisher01]), - hl.nd.hstack([fisher10, fisher11]) - ]) + fisher = hl.nd.vstack([hl.nd.hstack([fisher00, fisher01]), hl.nd.hstack([fisher10, fisher11])]) fisher_div_score = hl.nd.solve(fisher, score, no_crash=True) - chi_sq = hl.or_missing(~fisher_div_score.failed, - score @ fisher_div_score.solution) + chi_sq = hl.or_missing(~fisher_div_score.failed, score @ fisher_div_score.solution) p = hl.pchisqtail(chi_sq, dof) return chi_sq, p -def linear_mixed_model(y, - x, - z_t=None, - k=None, - p_path=None, - overwrite=False, - standardize=True, - mean_impute=True): +def linear_mixed_model(y, x, z_t=None, k=None, p_path=None, overwrite=False, standardize=True, mean_impute=True): r"""Initialize a linear mixed model from a matrix table. .. warning:: @@ -1828,20 +1951,18 @@ def linear_mixed_model(y, raise NotImplementedError("linear_mixed_model is no longer implemented/supported as of Hail 0.2.94") -@typecheck(entry_expr=expr_float64, - model=LinearMixedModel, - pa_t_path=nullable(str), - a_t_path=nullable(str), - mean_impute=bool, - partition_size=nullable(int), - pass_through=sequenceof(oneof(str, Expression))) -def linear_mixed_regression_rows(entry_expr, - model, - pa_t_path=None, - a_t_path=None, - mean_impute=True, - partition_size=None, - pass_through=()): +@typecheck( + entry_expr=expr_float64, + model=LinearMixedModel, + pa_t_path=nullable(str), + a_t_path=nullable(str), + mean_impute=bool, + partition_size=nullable(int), + pass_through=sequenceof(oneof(str, Expression)), +) +def linear_mixed_regression_rows( + entry_expr, model, pa_t_path=None, a_t_path=None, mean_impute=True, partition_size=None, pass_through=() +): """For each row, test an input variable for association using a linear mixed model. 
@@ -1852,23 +1973,20 @@ def linear_mixed_regression_rows(entry_expr, raise NotImplementedError("linear_mixed_model is no longer implemented/supported as of Hail 0.2.94") -@typecheck(group=expr_any, - weight=expr_float64, - y=expr_float64, - x=expr_float64, - covariates=sequenceof(expr_float64), - max_size=int, - accuracy=numeric, - iterations=int) -def _linear_skat(group, - weight, - y, - x, - covariates, - max_size: int = 46340, - accuracy: float = 1e-6, - iterations: int = 10000): - r'''The linear sequence kernel association test (SKAT). +@typecheck( + group=expr_any, + weight=expr_float64, + y=expr_float64, + x=expr_float64, + covariates=sequenceof(expr_float64), + max_size=int, + accuracy=numeric, + iterations=int, +) +def _linear_skat( + group, weight, y, x, covariates, max_size: int = 46340, accuracy: float = 1e-6, iterations: int = 10000 +): + r"""The linear sequence kernel association test (SKAT). Linear SKAT tests if the phenotype, `y`, is significantly associated with the genotype, `x`. For :math:`N` samples, in a group of :math:`M` variants, with :math:`K` covariates, the model is @@ -2126,38 +2244,20 @@ def _linear_skat(group, - s2 : :obj:`.tfloat64`, the variance of the residuals, :math:`\sigma^2` in the paper. - ''' + """ mt = matrix_table_source('skat/x', x) k = len(covariates) if k == 0: raise ValueError('_linear_skat: at least one covariate is required.') _warn_if_no_intercept('_linear_skat', covariates) mt = mt._select_all( - row_exprs=dict( - group=group, - weight=weight - ), - col_exprs=dict( - y=y, - covariates=covariates - ), - entry_exprs=dict( - x=x - ) + row_exprs=dict(group=group, weight=weight), col_exprs=dict(y=y, covariates=covariates), entry_exprs=dict(x=x) ) - mt = mt.filter_cols( - hl.all(hl.is_defined(mt.y), *[hl.is_defined(mt.covariates[i]) for i in range(k)]) - ) - yvec, covmat, n = mt.aggregate_cols(( - hl.agg.collect(hl.float(mt.y)), - hl.agg.collect(mt.covariates.map(hl.float)), - hl.agg.count() - ), _localize=False) - mt = mt.annotate_globals( - yvec=hl.nd.array(yvec), - covmat=hl.nd.array(covmat), - n_complete_samples=n + mt = mt.filter_cols(hl.all(hl.is_defined(mt.y), *[hl.is_defined(mt.covariates[i]) for i in range(k)])) + yvec, covmat, n = mt.aggregate_cols( + (hl.agg.collect(hl.float(mt.y)), hl.agg.collect(mt.covariates.map(hl.float)), hl.agg.count()), _localize=False ) + mt = mt.annotate_globals(yvec=hl.nd.array(yvec), covmat=hl.nd.array(covmat), n_complete_samples=n) # Instead of finding the best-fit beta, we go directly to the best-predicted value using the # reduced QR decomposition: # @@ -2175,37 +2275,24 @@ def _linear_skat(group, # = Q Q^T y # covmat_Q, _ = hl.nd.qr(mt.covmat) - mt = mt.annotate_globals( - covmat_Q=covmat_Q - ) + mt = mt.annotate_globals(covmat_Q=covmat_Q) null_mu = mt.covmat_Q @ (mt.covmat_Q.T @ mt.yvec) y_residual = mt.yvec - null_mu - mt = mt.annotate_globals( - y_residual=y_residual, - s2=y_residual @ y_residual.T / (n - k) - ) - mt = mt.annotate_rows( - G_row_mean=hl.agg.mean(mt.x) - ) - mt = mt.annotate_rows( - G_row=hl.agg.collect(hl.coalesce(mt.x, mt.G_row_mean)) - ) + mt = mt.annotate_globals(y_residual=y_residual, s2=y_residual @ y_residual.T / (n - k)) + mt = mt.annotate_rows(G_row_mean=hl.agg.mean(mt.x)) + mt = mt.annotate_rows(G_row=hl.agg.collect(hl.coalesce(mt.x, mt.G_row_mean))) ht = mt.rows() ht = ht.filter(hl.all(hl.is_defined(ht.group), hl.is_defined(ht.weight))) - ht = ht.group_by( - 'group' - ).aggregate( + ht = ht.group_by('group').aggregate( weight_take=hl.agg.take(ht.weight, n=max_size + 1), 
G_take=hl.agg.take(ht.G_row, n=max_size + 1), - size=hl.agg.count() + size=hl.agg.count(), ) ht = ht.annotate( weight=hl.nd.array(hl.or_missing(hl.len(ht.weight_take) <= max_size, ht.weight_take)), - G=hl.nd.array(hl.or_missing(hl.len(ht.G_take) <= max_size, ht.G_take)).T - ) - ht = ht.annotate( - Q=((ht.y_residual @ ht.G).map(lambda x: x**2) * ht.weight).sum(0) + G=hl.nd.array(hl.or_missing(hl.len(ht.G_take) <= max_size, ht.G_take)).T, ) + ht = ht.annotate(Q=((ht.y_residual @ ht.G).map(lambda x: x**2) * ht.weight).sum(0)) # Null model: # @@ -2286,11 +2373,20 @@ def _linear_skat(group, # = W S^2 W weights_arr = hl.array(ht.weight) - A = hl.case().when( - hl.all(weights_arr.map(lambda x: x >= 0)), - (ht.G - ht.covmat_Q @ (ht.covmat_Q.T @ ht.G)) * hl.sqrt(ht.weight) - ).or_error(hl.format('hl._linear_skat: every weight must be positive, in group %s, the weights were: %s', - ht.group, weights_arr)) + A = ( + hl.case() + .when( + hl.all(weights_arr.map(lambda x: x >= 0)), + (ht.G - ht.covmat_Q @ (ht.covmat_Q.T @ ht.G)) * hl.sqrt(ht.weight), + ) + .or_error( + hl.format( + 'hl._linear_skat: every weight must be positive, in group %s, the weights were: %s', + ht.group, + weights_arr, + ) + ) + ) singular_values = hl.nd.svd(A, compute_uv=False) # SVD(M) = U S V. U and V are unitary, therefore SVD(k M) = U (k S) V. @@ -2309,7 +2405,7 @@ def _linear_skat(group, mu=0, sigma=0, min_accuracy=accuracy, - max_iterations=iterations + max_iterations=iterations, ) ht = ht.select( 'size', @@ -2324,32 +2420,36 @@ def _linear_skat(group, # # Ergo, we want to check the right-tail of the distribution. p_value=1.0 - genchisq_data.value, - fault=genchisq_data.fault + fault=genchisq_data.fault, ) return ht.select_globals('y_residual', 's2', 'n_complete_samples') -@typecheck(group=expr_any, - weight=expr_float64, - y=expr_float64, - x=expr_float64, - covariates=sequenceof(expr_float64), - max_size=int, - null_max_iterations=int, - null_tolerance=float, - accuracy=numeric, - iterations=int) -def _logistic_skat(group, - weight, - y, - x, - covariates, - max_size: int = 46340, - null_max_iterations: int = 25, - null_tolerance: float = 1e-6, - accuracy: float = 1e-6, - iterations: int = 10000): - r'''The logistic sequence kernel association test (SKAT). +@typecheck( + group=expr_any, + weight=expr_float64, + y=expr_float64, + x=expr_float64, + covariates=sequenceof(expr_float64), + max_size=int, + null_max_iterations=int, + null_tolerance=float, + accuracy=numeric, + iterations=int, +) +def _logistic_skat( + group, + weight, + y, + x, + covariates, + max_size: int = 46340, + null_max_iterations: int = 25, + null_tolerance: float = 1e-6, + accuracy: float = 1e-6, + iterations: int = 10000, +): + r"""The logistic sequence kernel association test (SKAT). Logistic SKAT tests if the phenotype, `y`, is significantly associated with the genotype, `x`. For :math:`N` samples, in a group of :math:`M` variants, with :math:`K` covariates, the @@ -2649,74 +2749,54 @@ def _logistic_skat(group, - exploded : :obj:`.tbool` True if the null model failed to converge due to numerical explosion. 
- ''' + """ mt = matrix_table_source('skat/x', x) k = len(covariates) if k == 0: raise ValueError('_logistic_skat: at least one covariate is required.') _warn_if_no_intercept('_logistic_skat', covariates) mt = mt._select_all( - row_exprs=dict( - group=group, - weight=weight - ), - col_exprs=dict( - y=y, - covariates=covariates - ), - entry_exprs=dict( - x=x - ) - ) - mt = mt.filter_cols( - hl.all(hl.is_defined(mt.y), *[hl.is_defined(mt.covariates[i]) for i in range(k)]) + row_exprs=dict(group=group, weight=weight), col_exprs=dict(y=y, covariates=covariates), entry_exprs=dict(x=x) ) + mt = mt.filter_cols(hl.all(hl.is_defined(mt.y), *[hl.is_defined(mt.covariates[i]) for i in range(k)])) if mt.y.dtype != hl.tbool: mt = mt.annotate_cols( - y=(hl.case() - .when(hl.any(mt.y == 0, mt.y == 1), hl.bool(mt.y)) - .or_error(hl.format( - f'hl._logistic_skat: phenotypes must either be True, False, 0, or 1, found: %s of type {mt.y.dtype}', mt.y))) + y=( + hl.case() + .when(hl.any(mt.y == 0, mt.y == 1), hl.bool(mt.y)) + .or_error( + hl.format( + f'hl._logistic_skat: phenotypes must either be True, False, 0, or 1, found: %s of type {mt.y.dtype}', + mt.y, + ) + ) + ) ) - yvec, covmat, n = mt.aggregate_cols(( - hl.agg.collect(hl.float(mt.y)), - hl.agg.collect(mt.covariates.map(hl.float)), - hl.agg.count() - ), _localize=False) - mt = mt.annotate_globals( - yvec=hl.nd.array(yvec), - covmat=hl.nd.array(covmat), - n_complete_samples=n + yvec, covmat, n = mt.aggregate_cols( + (hl.agg.collect(hl.float(mt.y)), hl.agg.collect(mt.covariates.map(hl.float)), hl.agg.count()), _localize=False ) + mt = mt.annotate_globals(yvec=hl.nd.array(yvec), covmat=hl.nd.array(covmat), n_complete_samples=n) null_fit = logreg_fit(mt.covmat, mt.yvec, None, max_iterations=null_max_iterations, tolerance=null_tolerance) mt = mt.annotate_globals( - null_fit=hl.case().when(null_fit.converged, null_fit).or_error( - hl.format('hl._logistic_skat: null model did not converge: %s', null_fit)) + null_fit=hl.case() + .when(null_fit.converged, null_fit) + .or_error(hl.format('hl._logistic_skat: null model did not converge: %s', null_fit)) ) null_mu = mt.null_fit.mu y_residual = mt.yvec - null_mu - mt = mt.annotate_globals( - y_residual=y_residual, - s2=null_mu * (1 - null_mu) - ) - mt = mt.annotate_rows( - G_row_mean=hl.agg.mean(mt.x) - ) - mt = mt.annotate_rows( - G_row=hl.agg.collect(hl.coalesce(mt.x, mt.G_row_mean)) - ) + mt = mt.annotate_globals(y_residual=y_residual, s2=null_mu * (1 - null_mu)) + mt = mt.annotate_rows(G_row_mean=hl.agg.mean(mt.x)) + mt = mt.annotate_rows(G_row=hl.agg.collect(hl.coalesce(mt.x, mt.G_row_mean))) ht = mt.rows() ht = ht.filter(hl.all(hl.is_defined(ht.group), hl.is_defined(ht.weight))) - ht = ht.group_by( - 'group' - ).aggregate( + ht = ht.group_by('group').aggregate( weight_take=hl.agg.take(ht.weight, n=max_size + 1), G_take=hl.agg.take(ht.G_row, n=max_size + 1), - size=hl.agg.count() + size=hl.agg.count(), ) ht = ht.annotate( weight=hl.nd.array(hl.or_missing(hl.len(ht.weight_take) <= max_size, ht.weight_take)), - G=hl.nd.array(hl.or_missing(hl.len(ht.G_take) <= max_size, ht.G_take)).T + G=hl.nd.array(hl.or_missing(hl.len(ht.G_take) <= max_size, ht.G_take)).T, ) ht = ht.annotate( # Q=ht.y_residual @ (ht.G * ht.weight) @ ht.G.T @ ht.y_residual.T @@ -2729,11 +2809,17 @@ def _logistic_skat(group, Q, _ = hl.nd.qr(ht.covmat * sqrtv.reshape(-1, 1)) weights_arr = hl.array(ht.weight) G_scaled = ht.G * sqrtv.reshape(-1, 1) - A = hl.case().when( - hl.all(weights_arr.map(lambda x: x >= 0)), - (G_scaled - Q @ (Q.T @ G_scaled)) 
* hl.sqrt(ht.weight) - ).or_error(hl.format('hl._logistic_skat: every weight must be positive, in group %s, the weights were: %s', - ht.group, weights_arr)) + A = ( + hl.case() + .when(hl.all(weights_arr.map(lambda x: x >= 0)), (G_scaled - Q @ (Q.T @ G_scaled)) * hl.sqrt(ht.weight)) + .or_error( + hl.format( + 'hl._logistic_skat: every weight must be positive, in group %s, the weights were: %s', + ht.group, + weights_arr, + ) + ) + ) singular_values = hl.nd.svd(A, compute_uv=False) eigenvalues = singular_values.map(lambda x: x**2) @@ -2750,7 +2836,7 @@ def _logistic_skat(group, mu=0, sigma=0, min_accuracy=accuracy, - max_iterations=iterations + max_iterations=iterations, ) ht = ht.select( 'size', @@ -2765,29 +2851,33 @@ def _logistic_skat(group, # # Ergo, we want to check the right-tail of the distribution. p_value=1.0 - genchisq_data.value, - fault=genchisq_data.fault + fault=genchisq_data.fault, ) return ht.select_globals('y_residual', 's2', 'n_complete_samples', 'null_fit') -@typecheck(key_expr=expr_any, - weight_expr=expr_float64, - y=expr_float64, - x=expr_float64, - covariates=sequenceof(expr_float64), - logistic=oneof(bool, sized_tupleof(nullable(int), nullable(float))), - max_size=int, - accuracy=numeric, - iterations=int) -def skat(key_expr, - weight_expr, - y, - x, - covariates, - logistic: Union[bool, Tuple[int, float]] = False, - max_size: int = 46340, - accuracy: float = 1e-6, - iterations: int = 10000) -> Table: +@typecheck( + key_expr=expr_any, + weight_expr=expr_float64, + y=expr_float64, + x=expr_float64, + covariates=sequenceof(expr_float64), + logistic=oneof(bool, sized_tupleof(nullable(int), nullable(float))), + max_size=int, + accuracy=numeric, + iterations=int, +) +def skat( + key_expr, + weight_expr, + y, + x, + covariates, + logistic: Union[bool, Tuple[int, float]] = False, + max_size: int = 46340, + accuracy: float = 1e-6, + iterations: int = 10000, +) -> Table: r"""Test each keyed group of rows for association by linear or logistic SKAT test. @@ -2929,10 +3019,7 @@ def skat(key_expr, """ if hl.current_backend().requires_lowering: if logistic: - kwargs = { - 'accuracy': accuracy, - 'iterations': iterations - } + kwargs = {'accuracy': accuracy, 'iterations': iterations} if logistic is not True: null_max_iterations, null_tolerance = logistic kwargs['null_max_iterations'] = null_max_iterations @@ -2969,11 +3056,11 @@ def skat(key_expr, key_field_name = '__key' cov_field_names = list(f'__cov{i}' for i in range(len(covariates))) - mt = mt._select_all(col_exprs=dict(**{y_field_name: y}, - **dict(zip(cov_field_names, covariates))), - row_exprs={weight_field_name: weight_expr, - key_field_name: key_expr}, - entry_exprs=entry_expr) + mt = mt._select_all( + col_exprs=dict(**{y_field_name: y}, **dict(zip(cov_field_names, covariates))), + row_exprs={weight_field_name: weight_expr, key_field_name: key_expr}, + entry_exprs=entry_expr, + ) if logistic is True: use_logistic = True @@ -3000,14 +3087,13 @@ def skat(key_expr, 'accuracy': accuracy, 'iterations': iterations, 'logistic_max_iterations': max_iterations, - 'logistic_tolerance': tolerance + 'logistic_tolerance': tolerance, } return Table(ir.MatrixToTableApply(mt._mir, config)).persist() -@typecheck(p_value=expr_numeric, - approximate=bool) +@typecheck(p_value=expr_numeric, approximate=bool) def lambda_gc(p_value, approximate=True): """ Compute genomic inflation factor (lambda GC) from an Expression of p-values. 
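As a hedged NumPy sketch of the per-group quantities assembled in the branches above for the linear case: the score statistic Q and the eigenvalues that weight the chi-square mixture used for the p-value (names are illustrative; the exact scaling passed to the generalized chi-square CDF follows the code, not this sketch):

import numpy as np

def linear_skat_group(G, w, r, Q):
    # G: (n_samples, n_variants) genotypes for one group; w: per-variant weights;
    # r: phenotype residual after projection onto the covariates; Q: orthonormal covariate basis
    q_stat = np.sum(w * (r @ G) ** 2)
    A = (G - Q @ (Q.T @ G)) * np.sqrt(w)              # weighted, covariate-adjusted genotypes
    eigenvalues = np.linalg.svd(A, compute_uv=False) ** 2
    # the p-value is the right tail of a weighted sum of 1-dof chi-squares with these
    # weights, evaluated at q_stat after scaling by the residual variance
    return q_stat, eigenvalues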
@@ -3032,8 +3118,7 @@ def lambda_gc(p_value, approximate=True): return t.aggregate(med_chisq) -@typecheck(p_value=expr_numeric, - approximate=bool) +@typecheck(p_value=expr_numeric, approximate=bool) def _lambda_gc_agg(p_value, approximate=True): chisq = hl.qchisqtail(p_value, 1) if approximate: @@ -3043,10 +3128,7 @@ def _lambda_gc_agg(p_value, approximate=True): return med_chisq / hl.qchisqtail(0.5, 1) -@typecheck(ds=oneof(Table, MatrixTable), - keep_star=bool, - left_aligned=bool, - permit_shuffle=bool) +@typecheck(ds=oneof(Table, MatrixTable), keep_star=bool, left_aligned=bool, permit_shuffle=bool) def split_multi(ds, keep_star=False, left_aligned=False, *, permit_shuffle=False): """Split multiallelic variants. @@ -3152,25 +3234,23 @@ def split_multi(ds, keep_star=False, left_aligned=False, *, permit_shuffle=False kept_alleles = kept_alleles.filter(lambda i: old_row.alleles[i] != "*") def new_struct(variant, i): - return hl.struct(alleles=variant.alleles, - locus=variant.locus, - a_index=i, - was_split=hl.len(old_row.alleles) > 2) + return hl.struct(alleles=variant.alleles, locus=variant.locus, a_index=i, was_split=hl.len(old_row.alleles) > 2) def split_rows(expr, rekey): if isinstance(ds, MatrixTable): - mt = (ds.annotate_rows(**{new_id: expr}) - .explode_rows(new_id)) + mt = ds.annotate_rows(**{new_id: expr}).explode_rows(new_id) if rekey: mt = mt.key_rows_by() else: mt = mt.key_rows_by('locus') - new_row_expr = mt._rvrow.annotate(locus=mt[new_id]['locus'], - alleles=mt[new_id]['alleles'], - a_index=mt[new_id]['a_index'], - was_split=mt[new_id]['was_split'], - old_locus=mt.locus, - old_alleles=mt.alleles).drop(new_id) + new_row_expr = mt._rvrow.annotate( + locus=mt[new_id]['locus'], + alleles=mt[new_id]['alleles'], + a_index=mt[new_id]['a_index'], + was_split=mt[new_id]['was_split'], + old_locus=mt.locus, + old_alleles=mt.alleles, + ).drop(new_id) mt = mt._select_rows('split_multi', new_row_expr) if rekey: @@ -3179,18 +3259,19 @@ def split_rows(expr, rekey): return MatrixTable(ir.MatrixKeyRowsBy(mt._mir, ['locus', 'alleles'], is_sorted=True)) else: assert isinstance(ds, Table) - ht = (ds.annotate(**{new_id: expr}) - .explode(new_id)) + ht = ds.annotate(**{new_id: expr}).explode(new_id) if rekey: ht = ht.key_by() else: ht = ht.key_by('locus') - new_row_expr = ht.row.annotate(locus=ht[new_id]['locus'], - alleles=ht[new_id]['alleles'], - a_index=ht[new_id]['a_index'], - was_split=ht[new_id]['was_split'], - old_locus=ht.locus, - old_alleles=ht.alleles).drop(new_id) + new_row_expr = ht.row.annotate( + locus=ht[new_id]['locus'], + alleles=ht[new_id]['alleles'], + a_index=ht[new_id]['a_index'], + was_split=ht[new_id]['was_split'], + old_locus=ht.locus, + old_alleles=ht.alleles, + ).drop(new_id) ht = ht._select('split_multi', new_row_expr) if rekey: @@ -3199,22 +3280,25 @@ def split_rows(expr, rekey): return Table(ir.TableKeyBy(ht._tir, ['locus', 'alleles'], is_sorted=True)) if left_aligned: + def make_struct(i): def error_on_moved(v): - return (hl.case() - .when(v.locus == old_row.locus, new_struct(v, i)) - .or_error("Found non-left-aligned variant in split_multi")) - return hl.bind(error_on_moved, - hl.min_rep(old_row.locus, [old_row.alleles[0], old_row.alleles[i]])) + return ( + hl.case() + .when(v.locus == old_row.locus, new_struct(v, i)) + .or_error("Found non-left-aligned variant in split_multi") + ) + + return hl.bind(error_on_moved, hl.min_rep(old_row.locus, [old_row.alleles[0], old_row.alleles[i]])) + return split_rows(hl.sorted(kept_alleles.map(make_struct)), permit_shuffle) else: 
+ def make_struct(i, cond): def struct_or_empty(v): - return (hl.case() - .when(cond(v.locus), hl.array([new_struct(v, i)])) - .or_missing()) - return hl.bind(struct_or_empty, - hl.min_rep(old_row.locus, [old_row.alleles[0], old_row.alleles[i]])) + return hl.case().when(cond(v.locus), hl.array([new_struct(v, i)])).or_missing() + + return hl.bind(struct_or_empty, hl.min_rep(old_row.locus, [old_row.alleles[0], old_row.alleles[i]])) def make_array(cond): return hl.sorted(kept_alleles.flatmap(lambda i: make_struct(i, cond))) @@ -3224,11 +3308,7 @@ def make_array(cond): return left.union(moved) if is_table else left.union_rows(moved, _check_cols=False) -@typecheck(ds=oneof(Table, MatrixTable), - keep_star=bool, - left_aligned=bool, - vep_root=str, - permit_shuffle=bool) +@typecheck(ds=oneof(Table, MatrixTable), keep_star=bool, left_aligned=bool, vep_root=str, permit_shuffle=bool) def split_multi_hts(ds, keep_star=False, left_aligned=False, vep_root='vep', *, permit_shuffle=False): """Split multiallelic variants for datasets that contain one or more fields from a standard high-throughput sequencing entry schema. @@ -3409,8 +3489,13 @@ def split_multi_hts(ds, keep_star=False, left_aligned=False, vep_root='vep', *, if vep_root in row_fields: update_rows_expression[vep_root] = split[vep_root].annotate(**{ x: split[vep_root][x].filter(lambda csq: csq.allele_num == split.a_index) - for x in ('intergenic_consequences', 'motif_feature_consequences', - 'regulatory_feature_consequences', 'transcript_consequences')}) + for x in ( + 'intergenic_consequences', + 'motif_feature_consequences', + 'regulatory_feature_consequences', + 'transcript_consequences', + ) + }) if isinstance(ds, Table): return split.annotate(**update_rows_expression).drop('old_locus', 'old_alleles') @@ -3425,7 +3510,7 @@ def split_multi_hts(ds, keep_star=False, left_aligned=False, vep_root='vep', *, 'GQ': hl.tint, 'PL': hl.tarray(hl.tint), 'PGT': hl.tcall, - 'PID': hl.tstr + 'PID': hl.tstr, } bad_fields = [] @@ -3443,24 +3528,36 @@ def split_multi_hts(ds, keep_star=False, left_aligned=False, vep_root='vep', *, if 'DP' in entry_fields: update_entries_expression['DP'] = split.DP if 'AD' in entry_fields: - update_entries_expression['AD'] = hl.or_missing(hl.is_defined(split.AD), - [hl.sum(split.AD) - split.AD[split.a_index], split.AD[split.a_index]]) + update_entries_expression['AD'] = hl.or_missing( + hl.is_defined(split.AD), [hl.sum(split.AD) - split.AD[split.a_index], split.AD[split.a_index]] + ) if 'PL' in entry_fields: pl = hl.or_missing( hl.is_defined(split.PL), - (hl.range(0, 3).map(lambda i: - hl.min((hl.range(0, hl.triangle(split.old_alleles.length())) - .filter(lambda j: hl.downcode(hl.unphased_diploid_gt_index_call(j), - split.a_index).unphased_diploid_gt_index() == i - ).map(lambda j: split.PL[j])))))) + ( + hl.range(0, 3).map( + lambda i: hl.min( + ( + hl.range(0, hl.triangle(split.old_alleles.length())) + .filter( + lambda j: hl.downcode( + hl.unphased_diploid_gt_index_call(j), split.a_index + ).unphased_diploid_gt_index() + == i + ) + .map(lambda j: split.PL[j]) + ) + ) + ) + ), + ) if 'GQ' in entry_fields: update_entries_expression['PL'] = pl update_entries_expression['GQ'] = hl.or_else(hl.gq_from_pl(pl), split.GQ) else: update_entries_expression['PL'] = pl - else: - if 'GQ' in entry_fields: - update_entries_expression['GQ'] = split.GQ + elif 'GQ' in entry_fields: + update_entries_expression['GQ'] = split.GQ if 'PGT' in entry_fields: update_entries_expression['PGT'] = hl.downcode(split.PGT, split.a_index) @@ -3531,8 +3628,7 
@@ def genetic_relatedness_matrix(call_expr) -> BlockMatrix: raise_unless_entry_indexed('genetic_relatedness_matrix/call_expr', call_expr) mt = mt.select_entries(__gt=call_expr.n_alt_alleles()).unfilter_entries() - mt = mt.select_rows(__AC=agg.sum(mt.__gt), - __n_called=agg.count_where(hl.is_defined(mt.__gt))) + mt = mt.select_rows(__AC=agg.sum(mt.__gt), __n_called=agg.count_where(hl.is_defined(mt.__gt))) mt = mt.filter_rows((mt.__AC > 0) & (mt.__AC < 2 * mt.__n_called)) mt = mt.select_rows(__mean_gt=mt.__AC / mt.__n_called) @@ -3604,11 +3700,12 @@ def realized_relationship_matrix(call_expr) -> BlockMatrix: raise_unless_entry_indexed('realized_relationship_matrix/call_expr', call_expr) mt = mt.select_entries(__gt=call_expr.n_alt_alleles()).unfilter_entries() - mt = mt.select_rows(__AC=agg.sum(mt.__gt), - __ACsq=agg.sum(mt.__gt * mt.__gt), - __n_called=agg.count_where(hl.is_defined(mt.__gt))) - mt = mt.select_rows(__mean_gt=mt.__AC / mt.__n_called, - __centered_length=hl.sqrt(mt.__ACsq - (mt.__AC ** 2) / mt.__n_called)) + mt = mt.select_rows( + __AC=agg.sum(mt.__gt), __ACsq=agg.sum(mt.__gt * mt.__gt), __n_called=agg.count_where(hl.is_defined(mt.__gt)) + ) + mt = mt.select_rows( + __mean_gt=mt.__AC / mt.__n_called, __centered_length=hl.sqrt(mt.__ACsq - (mt.__AC**2) / mt.__n_called) + ) fmt = mt.filter_rows(mt.__centered_length > 0.1) # truly non-zero values are at least sqrt(0.5) normalized_gt = hl.or_else((fmt.__gt - fmt.__mean_gt) / fmt.__centered_length, 0.0) @@ -3617,8 +3714,10 @@ def realized_relationship_matrix(call_expr) -> BlockMatrix: bm = BlockMatrix.from_entry_expr(normalized_gt) return (bm.T @ bm) / (bm.n_rows / bm.n_cols) except FatalError as fe: - raise FatalError("Could not convert MatrixTable to BlockMatrix. It's possible all variants were dropped by variance filter.\n" - "Check that the input MatrixTable has at least two samples in it: mt.count_cols().") from fe + raise FatalError( + "Could not convert MatrixTable to BlockMatrix. It's possible all variants were dropped by variance filter.\n" + "Check that the input MatrixTable has at least two samples in it: mt.count_cols()." + ) from fe @typecheck(entry_expr=expr_float64, block_size=nullable(int)) @@ -3721,11 +3820,13 @@ def row_correlation(entry_expr, block_size=None) -> BlockMatrix: return bm @ bm.T -@typecheck(entry_expr=expr_float64, - locus_expr=expr_locus(), - radius=oneof(int, float), - coord_expr=nullable(expr_float64), - block_size=nullable(int)) +@typecheck( + entry_expr=expr_float64, + locus_expr=expr_locus(), + radius=oneof(int, float), + coord_expr=nullable(expr_float64), + block_size=nullable(int), +) def ld_matrix(entry_expr, locus_expr, radius, coord_expr=None, block_size=None) -> BlockMatrix: """Computes the windowed correlation (linkage disequilibrium) matrix between variants. @@ -3854,33 +3955,39 @@ def ld_matrix(entry_expr, locus_expr, radius, coord_expr=None, block_size=None) Row and column indices correspond to matrix table variant index. 
""" starts_and_stops = hl.linalg.utils.locus_windows(locus_expr, radius, coord_expr, _localize=False) - starts_and_stops = hl.tuple([starts_and_stops[0].map(lambda i: hl.int64(i)), starts_and_stops[1].map(lambda i: hl.int64(i))]) + starts_and_stops = hl.tuple([ + starts_and_stops[0].map(lambda i: hl.int64(i)), + starts_and_stops[1].map(lambda i: hl.int64(i)), + ]) ld = hl.row_correlation(entry_expr, block_size) return ld._sparsify_row_intervals_expr(starts_and_stops, blocks_only=False) -@typecheck(n_populations=int, - n_samples=int, - n_variants=int, - n_partitions=nullable(int), - pop_dist=nullable(sequenceof(numeric)), - fst=nullable(sequenceof(numeric)), - af_dist=nullable(expr_any), - reference_genome=reference_genome_type, - mixture=bool, - phased=bool) -def balding_nichols_model(n_populations: int, - n_samples: int, - n_variants: int, - n_partitions: Optional[int] = None, - pop_dist: Optional[List[int]] = None, - fst: Optional[List[Union[float, int]]] = None, - af_dist: Optional[hl.Expression] = None, - reference_genome: str = 'default', - mixture: bool = False, - *, - phased: bool = False - ) -> MatrixTable: +@typecheck( + n_populations=int, + n_samples=int, + n_variants=int, + n_partitions=nullable(int), + pop_dist=nullable(sequenceof(numeric)), + fst=nullable(sequenceof(numeric)), + af_dist=nullable(expr_any), + reference_genome=reference_genome_type, + mixture=bool, + phased=bool, +) +def balding_nichols_model( + n_populations: int, + n_samples: int, + n_variants: int, + n_partitions: Optional[int] = None, + pop_dist: Optional[List[int]] = None, + fst: Optional[List[Union[float, int]]] = None, + af_dist: Optional[Expression] = None, + reference_genome: str = 'default', + mixture: bool = False, + *, + phased: bool = False, +) -> MatrixTable: r"""Generate a matrix table of variants, samples, and genotypes using the Balding-Nichols or Pritchard-Stephens-Donnelly model. 
@@ -4079,41 +4186,47 @@ def balding_nichols_model(n_populations: int, n_partitions = max(8, int(n_samples * n_variants / (128 * 1024 * 1024))) # verify args - for name, var in {"populations": n_populations, - "samples": n_samples, - "variants": n_variants, - "partitions": n_partitions}.items(): + for name, var in { + "populations": n_populations, + "samples": n_samples, + "variants": n_variants, + "partitions": n_partitions, + }.items(): if var < 1: raise ValueError("n_{} must be positive, got {}".format(name, var)) for name, var in {"pop_dist": pop_dist, "fst": fst}.items(): if len(var) != n_populations: - raise ValueError("{} must be of length n_populations={}, got length {}" - .format(name, n_populations, len(var))) + raise ValueError( + "{} must be of length n_populations={}, got length {}".format(name, n_populations, len(var)) + ) if any(x < 0 for x in pop_dist): - raise ValueError("pop_dist must be non-negative, got {}" - .format(pop_dist)) + raise ValueError("pop_dist must be non-negative, got {}".format(pop_dist)) if any(x <= 0 or x >= 1 for x in fst): - raise ValueError("elements of fst must satisfy 0 < x < 1, got {}" - .format(fst)) + raise ValueError("elements of fst must satisfy 0 < x < 1, got {}".format(fst)) # verify af_dist if not af_dist._is_scalar: - raise ExpressionException('balding_nichols_model expects af_dist to ' - + 'have scalar arguments: found expression ' - + 'from source {}' - .format(af_dist._indices.source)) + raise ExpressionException( + 'balding_nichols_model expects af_dist to ' + + 'have scalar arguments: found expression ' + + 'from source {}'.format(af_dist._indices.source) + ) if af_dist.dtype != tfloat64: raise ValueError("af_dist must be a hail function with return type tfloat64.") - info("balding_nichols_model: generating genotypes for {} populations, {} samples, and {} variants..." 
- .format(n_populations, n_samples, n_variants)) + info( + "balding_nichols_model: generating genotypes for {} populations, {} samples, and {} variants...".format( + n_populations, n_samples, n_variants + ) + ) # generate matrix table from numpy import linspace + n_partitions = min(n_partitions, n_variants) start_idxs = [int(x) for x in linspace(0, n_variants, n_partitions + 1)] idx_bounds = list(zip(start_idxs, start_idxs[1:])) @@ -4130,18 +4243,14 @@ def balding_nichols_model(n_populations: int, n_partitions=n_partitions, pop_dist=pop_dist, fst=fst, - mixture=mixture + mixture=mixture, ), - cols=hl.range(n_samples).map( - lambda idx: hl.struct(sample_idx=idx, pop=pop_f(pop_dist)) - ) + cols=hl.range(n_samples).map(lambda idx: hl.struct(sample_idx=idx, pop=pop_f(pop_dist))), ), partitions=[ hl.Interval(**{ - endpoint: hl.Struct( - locus=reference_genome.locus_from_global_position(idx), - alleles=['A', 'C'] - ) for endpoint, idx in [('start', lo), ('end', hi)] + endpoint: hl.Struct(locus=reference_genome.locus_from_global_position(idx), alleles=['A', 'C']) + for endpoint, idx in [('start', lo), ('end', hi)] }) for (lo, hi) in idx_bounds ], @@ -4156,9 +4265,9 @@ def balding_nichols_model(n_populations: int, ), entries=hl.repeat(hl.struct(), n_samples), ), - af_dist + af_dist, ) - ) + ), ) bn = bn._unlocalize_entries('entries', 'cols', ['sample_idx']) @@ -4172,13 +4281,12 @@ def balding_nichols_model(n_populations: int, dad = hl.rand_bool(p) return bn.select_entries(GT=hl.call(mom, dad, phased=True)) - idx = hl.rand_cat([q ** 2, 2 * p * q, p ** 2]) + idx = hl.rand_cat([q**2, 2 * p * q, p**2]) return bn.select_entries(GT=hl.unphased_diploid_gt_index_call(idx)) @typecheck(mt=MatrixTable, f=anytype) -def filter_alleles(mt: MatrixTable, - f: Callable) -> MatrixTable: +def filter_alleles(mt: MatrixTable, f: Callable) -> MatrixTable: """Filter alternate alleles. .. include:: ../_templates/req_tvariant.rst @@ -4267,15 +4375,13 @@ def filter_alleles(mt: MatrixTable, inclusion = hl.range(0, hl.len(mt.alleles)).map(lambda i: (i == 0) | hl.bind(lambda ii: f(mt.alleles[ii], ii), i)) # old locus, old alleles, new to old, old to new - mt = mt.annotate_rows(__allele_inclusion=inclusion, - old_locus=mt.locus, - old_alleles=mt.alleles) - new_to_old = (hl.enumerate(mt.__allele_inclusion) - .filter(lambda elt: elt[1]) - .map(lambda elt: elt[0])) - old_to_new_dict = (hl.dict(hl.enumerate(hl.enumerate(mt.alleles) - .filter(lambda elt: mt.__allele_inclusion[elt[0]])) - .map(lambda elt: (elt[1][1], elt[0])))) + mt = mt.annotate_rows(__allele_inclusion=inclusion, old_locus=mt.locus, old_alleles=mt.alleles) + new_to_old = hl.enumerate(mt.__allele_inclusion).filter(lambda elt: elt[1]).map(lambda elt: elt[0]) + old_to_new_dict = hl.dict( + hl.enumerate(hl.enumerate(mt.alleles).filter(lambda elt: mt.__allele_inclusion[elt[0]])).map( + lambda elt: (elt[1][1], elt[0]) + ) + ) old_to_new = hl.bind(lambda d: mt.alleles.map(lambda a: d.get(a)), old_to_new_dict) mt = mt.annotate_rows(old_to_new=old_to_new, new_to_old=new_to_old) @@ -4290,9 +4396,7 @@ def filter_alleles(mt: MatrixTable, @typecheck(mt=MatrixTable, f=anytype, subset=bool) -def filter_alleles_hts(mt: MatrixTable, - f: Callable, - subset: bool = False) -> MatrixTable: +def filter_alleles_hts(mt: MatrixTable, f: Callable, subset: bool = False) -> MatrixTable: """Filter alternate alleles and update standard GATK entry fields. 
Examples @@ -4476,11 +4580,14 @@ def filter_alleles_hts(mt: MatrixTable, :class:`.MatrixTable` """ if mt.entry.dtype != hl.hts_entry_schema: - raise FatalError("'filter_alleles_hts': entry schema must be the HTS entry schema:\n" - " found: {}\n" - " expected: {}\n" - " Use 'hl.filter_alleles' to split entries with non-HTS entry fields.".format( - mt.entry.dtype, hl.hts_entry_schema)) + raise FatalError( + "'filter_alleles_hts': entry schema must be the HTS entry schema:\n" + " found: {}\n" + " expected: {}\n" + " Use 'hl.filter_alleles' to split entries with non-HTS entry fields.".format( + mt.entry.dtype, hl.hts_entry_schema + ) + ) mt = filter_alleles(mt, f) @@ -4491,53 +4598,70 @@ def filter_alleles_hts(mt: MatrixTable, lambda unnorm: unnorm - hl.min(unnorm), hl.range(0, hl.triangle(mt.alleles.length())).map( lambda newi: hl.bind( - lambda newc: mt.PL[hl.call(mt.new_to_old[newc[0]], - mt.new_to_old[newc[1]]).unphased_diploid_gt_index()], - hl.unphased_diploid_gt_index_call(newi)))), - hl.missing(tarray(tint32))) + lambda newc: mt.PL[ + hl.call(mt.new_to_old[newc[0]], mt.new_to_old[newc[1]]).unphased_diploid_gt_index() + ], + hl.unphased_diploid_gt_index_call(newi), + ) + ), + ), + hl.missing(tarray(tint32)), + ) return mt.annotate_entries( GT=hl.unphased_diploid_gt_index_call(hl.argmin(newPL, unique=True)), AD=hl.if_else( hl.is_defined(mt.AD), - hl.range(0, mt.alleles.length()).map( - lambda newi: mt.AD[mt.new_to_old[newi]]), - hl.missing(tarray(tint32))), + hl.range(0, mt.alleles.length()).map(lambda newi: mt.AD[mt.new_to_old[newi]]), + hl.missing(tarray(tint32)), + ), # DP unchanged GQ=hl.gq_from_pl(newPL), - PL=newPL) + PL=newPL, + ) # otherwise downcode else: mt = mt.annotate_rows(__old_to_new_no_na=mt.old_to_new.map(lambda x: hl.or_else(x, 0))) newPL = hl.if_else( hl.is_defined(mt.PL), - (hl.range(0, hl.triangle(hl.len(mt.alleles))) - .map(lambda newi: hl.min(hl.range(0, hl.triangle(hl.len(mt.old_alleles))) - .filter(lambda oldi: hl.bind( - lambda oldc: hl.call(mt.__old_to_new_no_na[oldc[0]], - mt.__old_to_new_no_na[oldc[1]]) == hl.unphased_diploid_gt_index_call(newi), - hl.unphased_diploid_gt_index_call(oldi))) - .map(lambda oldi: mt.PL[oldi])))), - hl.missing(tarray(tint32))) + ( + hl.range(0, hl.triangle(hl.len(mt.alleles))).map( + lambda newi: hl.min( + hl.range(0, hl.triangle(hl.len(mt.old_alleles))) + .filter( + lambda oldi: hl.bind( + lambda oldc: hl.call(mt.__old_to_new_no_na[oldc[0]], mt.__old_to_new_no_na[oldc[1]]) + == hl.unphased_diploid_gt_index_call(newi), + hl.unphased_diploid_gt_index_call(oldi), + ) + ) + .map(lambda oldi: mt.PL[oldi]) + ) + ) + ), + hl.missing(tarray(tint32)), + ) return mt.annotate_entries( - GT=hl.call(mt.__old_to_new_no_na[mt.GT[0]], - mt.__old_to_new_no_na[mt.GT[1]]), + GT=hl.call(mt.__old_to_new_no_na[mt.GT[0]], mt.__old_to_new_no_na[mt.GT[1]]), AD=hl.if_else( hl.is_defined(mt.AD), - (hl.range(0, hl.len(mt.alleles)) - .map(lambda newi: hl.sum(hl.range(0, hl.len(mt.old_alleles)) - .filter(lambda oldi: mt.__old_to_new_no_na[oldi] == newi) - .map(lambda oldi: mt.AD[oldi])))), - hl.missing(tarray(tint32))), + ( + hl.range(0, hl.len(mt.alleles)).map( + lambda newi: hl.sum( + hl.range(0, hl.len(mt.old_alleles)) + .filter(lambda oldi: mt.__old_to_new_no_na[oldi] == newi) + .map(lambda oldi: mt.AD[oldi]) + ) + ) + ), + hl.missing(tarray(tint32)), + ), # DP unchanged GQ=hl.gq_from_pl(newPL), - PL=newPL).drop('__old_to_new_no_na') + PL=newPL, + ).drop('__old_to_new_no_na') -@typecheck(mt=MatrixTable, - call_field=str, - r2=numeric, - 
bp_window_size=int, - memory_per_core=int) +@typecheck(mt=MatrixTable, call_field=str, r2=numeric, bp_window_size=int, memory_per_core=int) def _local_ld_prune(mt, call_field, r2=0.2, bp_window_size=1000000, memory_per_core=256): bytes_per_core = memory_per_core * 1024 * 1024 fraction_memory_to_use = 0.25 @@ -4553,21 +4677,28 @@ def _local_ld_prune(mt, call_field, r2=0.2, bp_window_size=1000000, memory_per_c info(f'ld_prune: running local pruning stage with max queue size of {max_queue_size} variants') - return Table(ir.MatrixToTableApply(mt._mir, { - 'name': 'LocalLDPrune', - 'callField': call_field, - 'r2Threshold': float(r2), - 'windowSize': bp_window_size, - 'maxQueueSize': max_queue_size - })).persist() - - -@typecheck(call_expr=expr_call, - r2=numeric, - bp_window_size=int, - memory_per_core=int, - keep_higher_maf=bool, - block_size=nullable(int)) + return Table( + ir.MatrixToTableApply( + mt._mir, + { + 'name': 'LocalLDPrune', + 'callField': call_field, + 'r2Threshold': float(r2), + 'windowSize': bp_window_size, + 'maxQueueSize': max_queue_size, + }, + ) + ).persist() + + +@typecheck( + call_expr=expr_call, + r2=numeric, + bp_window_size=int, + memory_per_core=int, + keep_higher_maf=bool, + block_size=nullable(int), +) def ld_prune(call_expr, r2=0.2, bp_window_size=1000000, memory_per_core=256, keep_higher_maf=True, block_size=None): """Returns a maximal subset of variants that are nearly uncorrelated within each window. @@ -4669,18 +4800,19 @@ def ld_prune(call_expr, r2=0.2, bp_window_size=1000000, memory_per_core=256, kee mt = mt.select_rows().select_cols() mt = mt.distinct_by_row() locally_pruned_table_path = new_temp_file() - (_local_ld_prune(require_biallelic(mt, 'ld_prune'), field, r2, bp_window_size, memory_per_core) - .write(locally_pruned_table_path, overwrite=True)) + ( + _local_ld_prune(require_biallelic(mt, 'ld_prune'), field, r2, bp_window_size, memory_per_core).write( + locally_pruned_table_path, overwrite=True + ) + ) locally_pruned_table = hl.read_table(locally_pruned_table_path).add_index() mt = mt.annotate_rows(info=locally_pruned_table[mt.row_key]) mt = mt.filter_rows(hl.is_defined(mt.info)).unfilter_entries() std_gt_bm = BlockMatrix.from_entry_expr( - hl.or_else( - (mt[field].n_alt_alleles() - mt.info.mean) * mt.info.centered_length_rec, - 0.0), - block_size=block_size) + hl.or_else((mt[field].n_alt_alleles() - mt.info.mean) * mt.info.centered_length_rec, 0.0), block_size=block_size + ) r2_bm = (std_gt_bm @ std_gt_bm.T) ** 2 _, stops = hl.linalg.utils.locus_windows(locally_pruned_table.locus, bp_window_size) @@ -4695,42 +4827,56 @@ def ld_prune(call_expr, r2=0.2, bp_window_size=1000000, memory_per_core=256, kee fields = ['locus'] info = locally_pruned_table.aggregate( - hl.agg.collect(locally_pruned_table.row.select('idx', *fields)), _localize=False) + hl.agg.collect(locally_pruned_table.row.select('idx', *fields)), _localize=False + ) info = hl.sorted(info, key=lambda x: x.idx) entries = entries.annotate_globals(info=info) entries = entries.filter( (entries.info[entries.i].locus.contig == entries.info[entries.j].locus.contig) - & (entries.info[entries.j].locus.position - entries.info[entries.i].locus.position <= bp_window_size)) + & (entries.info[entries.j].locus.position - entries.info[entries.i].locus.position <= bp_window_size) + ) if keep_higher_maf: entries = entries.annotate( - i=hl.struct(idx=entries.i, - twice_maf=hl.min(entries.info[entries.i].mean, 2.0 - entries.info[entries.i].mean)), - j=hl.struct(idx=entries.j, - 
twice_maf=hl.min(entries.info[entries.j].mean, 2.0 - entries.info[entries.j].mean))) + i=hl.struct( + idx=entries.i, twice_maf=hl.min(entries.info[entries.i].mean, 2.0 - entries.info[entries.i].mean) + ), + j=hl.struct( + idx=entries.j, twice_maf=hl.min(entries.info[entries.j].mean, 2.0 - entries.info[entries.j].mean) + ), + ) def tie_breaker(left, right): return hl.sign(right.twice_maf - left.twice_maf) + else: tie_breaker = None variants_to_remove = hl.maximal_independent_set( - entries.i, entries.j, keep=False, tie_breaker=tie_breaker, keyed=False) + entries.i, entries.j, keep=False, tie_breaker=tie_breaker, keyed=False + ) locally_pruned_table = locally_pruned_table.annotate_globals( variants_to_remove=variants_to_remove.aggregate( - hl.agg.collect_as_set(variants_to_remove.node.idx), _localize=False)) - return locally_pruned_table.filter( - locally_pruned_table.variants_to_remove.contains(hl.int32(locally_pruned_table.idx)), - keep=False - ).select().persist() + hl.agg.collect_as_set(variants_to_remove.node.idx), _localize=False + ) + ) + return ( + locally_pruned_table.filter( + locally_pruned_table.variants_to_remove.contains(hl.int32(locally_pruned_table.idx)), keep=False + ) + .select() + .persist() + ) def _warn_if_no_intercept(caller, covariates): if all([e._indices.axes for e in covariates]): - warning(f'{caller}: model appears to have no intercept covariate.' - '\n To include an intercept, add 1.0 to the list of covariates.') + warning( + f'{caller}: model appears to have no intercept covariate.' + '\n To include an intercept, add 1.0 to the list of covariates.' + ) return True return False diff --git a/hail/python/hail/nd/__init__.py b/hail/python/hail/nd/__init__.py index 93cbcdb1ff3..5cc3664d4e4 100644 --- a/hail/python/hail/nd/__init__.py +++ b/hail/python/hail/nd/__init__.py @@ -1,9 +1,48 @@ -from .nd import array, from_column_major, arange, full, zeros, ones, svd, eigh, qr, solve, solve_triangular, diagonal, inv, concatenate, \ - eye, identity, vstack, hstack, maximum, minimum +from .nd import ( + arange, + array, + concatenate, + diagonal, + eigh, + eye, + from_column_major, + full, + hstack, + identity, + inv, + maximum, + minimum, + ones, + qr, + solve, + solve_triangular, + svd, + vstack, + zeros, +) newaxis = None __all__ = [ - 'array', 'from_column_major', 'arange', 'full', 'zeros', 'ones', 'qr', 'solve', 'solve_triangular', 'svd', 'eigh', 'diagonal', 'inv', - 'concatenate', 'eye', 'identity', 'vstack', 'hstack', 'newaxis', 'maximum', 'minimum' + 'array', + 'from_column_major', + 'arange', + 'full', + 'zeros', + 'ones', + 'qr', + 'solve', + 'solve_triangular', + 'svd', + 'eigh', + 'diagonal', + 'inv', + 'concatenate', + 'eye', + 'identity', + 'vstack', + 'hstack', + 'newaxis', + 'maximum', + 'minimum', ] diff --git a/hail/python/hail/nd/nd.py b/hail/python/hail/nd/nd.py index 46b15192378..32b16b6c282 100644 --- a/hail/python/hail/nd/nd.py +++ b/hail/python/hail/nd/nd.py @@ -1,17 +1,26 @@ from functools import reduce import hail as hl -from hail.expr.functions import _ndarray -from hail.expr.functions import array as aarray -from hail.expr.types import HailType, tfloat64, tfloat32, ttuple, tndarray -from hail.typecheck import typecheck, nullable, oneof, tupleof, sequenceof from hail.expr.expressions import ( - expr_int32, expr_int64, expr_tuple, expr_any, expr_array, expr_ndarray, - expr_numeric, Int64Expression, cast_expr, construct_expr, expr_bool, - unify_all) + Int64Expression, + cast_expr, + construct_expr, + expr_any, + expr_array, + expr_bool, + expr_int32, + 
expr_int64, + expr_ndarray, + expr_numeric, + expr_tuple, + unify_all, +) from hail.expr.expressions.typed_expressions import NDArrayNumericExpression -from hail.ir import NDArrayQR, NDArrayInv, NDArrayConcat, NDArraySVD, NDArrayEigh, Apply - +from hail.expr.functions import _ndarray +from hail.expr.functions import array as aarray +from hail.expr.types import HailType, tfloat32, tfloat64, tndarray, ttuple +from hail.ir import Apply, NDArrayConcat, NDArrayEigh, NDArrayInv, NDArrayQR, NDArraySVD +from hail.typecheck import nullable, oneof, sequenceof, tupleof, typecheck tsequenceof_nd = oneof(sequenceof(expr_ndarray()), expr_array(expr_ndarray())) shape_type = oneof(expr_int64, tupleof(expr_int64), expr_tuple()) @@ -139,34 +148,34 @@ def full(shape, value, dtype=None): def zeros(shape, dtype=tfloat64): """Creates a hail :class:`.NDArrayNumericExpression` full of zeros. - Examples - -------- + Examples + -------- - Create a 5 by 7 NDArray of type :py:data:`.tfloat64` zeros. + Create a 5 by 7 NDArray of type :py:data:`.tfloat64` zeros. - >>> hl.nd.zeros((5, 7)) + >>> hl.nd.zeros((5, 7)) - It is possible to specify a type other than :py:data:`.tfloat64` with the `dtype` argument. + It is possible to specify a type other than :py:data:`.tfloat64` with the `dtype` argument. - >>> hl.nd.zeros((5, 7), dtype=hl.tfloat32) + >>> hl.nd.zeros((5, 7), dtype=hl.tfloat32) - Parameters - ---------- - shape : `tuple` or :class:`.TupleExpression` - Desired shape. - dtype : :class:`.HailType` - Desired hail type. Default: `float64`. - - See Also - -------- - :func:`.full` - - Returns - ------- - :class:`.NDArrayNumericExpression` - ndarray of the specified size full of zeros. - """ + Parameters + ---------- + shape : `tuple` or :class:`.TupleExpression` + Desired shape. + dtype : :class:`.HailType` + Desired hail type. Default: `float64`. + + See Also + -------- + :func:`.full` + + Returns + ------- + :class:`.NDArrayNumericExpression` + ndarray of the specified size full of zeros. + """ return full(shape, 0, dtype) @@ -174,35 +183,35 @@ def zeros(shape, dtype=tfloat64): def ones(shape, dtype=tfloat64): """Creates a hail :class:`.NDArrayNumericExpression` full of ones. - Examples - -------- + Examples + -------- - Create a 5 by 7 NDArray of type :py:data:`.tfloat64` ones. + Create a 5 by 7 NDArray of type :py:data:`.tfloat64` ones. - >>> hl.nd.ones((5, 7)) + >>> hl.nd.ones((5, 7)) - It is possible to specify a type other than :py:data:`.tfloat64` with the `dtype` argument. + It is possible to specify a type other than :py:data:`.tfloat64` with the `dtype` argument. - >>> hl.nd.ones((5, 7), dtype=hl.tfloat32) + >>> hl.nd.ones((5, 7), dtype=hl.tfloat32) - Parameters - ---------- - shape : `tuple` or :class:`.TupleExpression` - Desired shape. - dtype : :class:`.HailType` - Desired hail type. Default: `float64`. + Parameters + ---------- + shape : `tuple` or :class:`.TupleExpression` + Desired shape. + dtype : :class:`.HailType` + Desired hail type. Default: `float64`. - See Also - -------- - :func:`.full` + See Also + -------- + :func:`.full` - Returns - ------- - :class:`.NDArrayNumericExpression` - ndarray of the specified size full of ones. - """ + Returns + ------- + :class:`.NDArrayNumericExpression` + ndarray of the specified size full of ones. 
+ """ return full(shape, 1, dtype) @@ -312,7 +321,7 @@ def solve_triangular(A, b, lower=False, no_crash=False): def solve_helper(nd_coef, nd_dep, nd_dep_ndim_orig): assert nd_coef.ndim == 2 - assert nd_dep_ndim_orig == 1 or nd_dep_ndim_orig == 2 + assert nd_dep_ndim_orig in {1, 2} if nd_dep_ndim_orig == 1: nd_dep = nd_dep.reshape((-1, 1)) @@ -422,7 +431,11 @@ def svd(nd, full_matrices=True, compute_uv=True): float_nd = nd.map(lambda x: hl.float64(x)) ir = NDArraySVD(float_nd._ir, full_matrices, compute_uv) - return_type = ttuple(tndarray(tfloat64, 2), tndarray(tfloat64, 1), tndarray(tfloat64, 2)) if compute_uv else tndarray(tfloat64, 1) + return_type = ( + ttuple(tndarray(tfloat64, 2), tndarray(tfloat64, 1), tndarray(tfloat64, 2)) + if compute_uv + else tndarray(tfloat64, 1) + ) return construct_expr(ir, return_type, nd._indices, nd._aggregations) @@ -515,7 +528,9 @@ def concatenate(nds, axis=0): element_types = {t.element_type for t in typs} if len(element_types) != 1: argument_element_types_str = ", ".join(str(nd.dtype.element_type) for nd in nds) - raise ValueError(f'hl.nd.concatenate: ndarrays must have same element types, found these element types: ({argument_element_types_str})') + raise ValueError( + f'hl.nd.concatenate: ndarrays must have same element types, found these element types: ({argument_element_types_str})' + ) ndims = {t.ndim for t in typs} assert len(ndims) != 1 @@ -532,7 +547,7 @@ def concatenate(nds, axis=0): @typecheck(N=expr_numeric, M=nullable(expr_numeric), dtype=HailType) -def eye(N, M=None, dtype=hl.tfloat64): +def eye(N, M=None, dtype=tfloat64): """ Construct a 2-D :class:`.NDArrayExpression` with ones on the *main* diagonal and zeros elsewhere. @@ -574,15 +589,15 @@ def eye(N, M=None, dtype=hl.tfloat64): else: n_col = hl.int32(M) - return hl.nd.array(hl.range(0, n_row * n_col).map( - lambda i: hl.if_else((i // n_col) == (i % n_col), - hl.literal(1, dtype), - hl.literal(0, dtype)) - )).reshape((n_row, n_col)) + return hl.nd.array( + hl.range(0, n_row * n_col).map( + lambda i: hl.if_else((i // n_col) == (i % n_col), hl.literal(1, dtype), hl.literal(0, dtype)) + ) + ).reshape((n_row, n_col)) @typecheck(N=expr_numeric, dtype=HailType) -def identity(N, dtype=hl.tfloat64): +def identity(N, dtype=tfloat64): """ Constructs a 2-D :class:`.NDArrayExpression` representing the identity array. The identity array is a square array with ones on the main diagonal. 
@@ -748,8 +763,9 @@ def maximum(nd1, nd2): """ if (nd1.dtype.element_type or nd2.dtype.element_type) == (tfloat64 or tfloat32): - return nd1.map2(nd2, lambda a, b: hl.if_else(hl.is_nan(a) | hl.is_nan(b), - hl.float64(float("NaN")), hl.if_else(a > b, a, b))) + return nd1.map2( + nd2, lambda a, b: hl.if_else(hl.is_nan(a) | hl.is_nan(b), hl.float64(float("NaN")), hl.if_else(a > b, a, b)) + ) return nd1.map2(nd2, lambda a, b: hl.if_else(a > b, a, b)) @@ -790,6 +806,7 @@ def minimum(nd1, nd2): """ if (nd1.dtype.element_type or nd2.dtype.element_type) == (tfloat64 or tfloat32): - return nd1.map2(nd2, lambda a, b: hl.if_else(hl.is_nan(a) | hl.is_nan(b), - hl.float64(float("NaN")), hl.if_else(a < b, a, b))) + return nd1.map2( + nd2, lambda a, b: hl.if_else(hl.is_nan(a) | hl.is_nan(b), hl.float64(float("NaN")), hl.if_else(a < b, a, b)) + ) return nd1.map2(nd2, lambda a, b: hl.if_else(a < b, a, b)) diff --git a/hail/python/hail/plot/__init__.py b/hail/python/hail/plot/__init__.py index fe902bd8713..97fee2aff1b 100644 --- a/hail/python/hail/plot/__init__.py +++ b/hail/python/hail/plot/__init__.py @@ -2,21 +2,39 @@ if is_notebook(): from bokeh.io import output_notebook + output_notebook() -from .plots import output_notebook, show, histogram, cumulative_histogram, histogram2d, scatter, joint_plot, qq, manhattan, smoothed_pdf, pdf, cdf, set_font_size, visualize_missingness +from .plots import ( + cdf, + cumulative_histogram, + histogram, + histogram2d, + joint_plot, + manhattan, + output_notebook, + pdf, + qq, + scatter, + set_font_size, + show, + smoothed_pdf, + visualize_missingness, +) -__all__ = ['output_notebook', - 'show', - 'histogram', - 'cumulative_histogram', - 'scatter', - 'joint_plot', - 'histogram2d', - 'qq', - 'manhattan', - 'pdf', - 'smoothed_pdf', - 'cdf', - 'set_font_size', - 'visualize_missingness'] +__all__ = [ + 'output_notebook', + 'show', + 'histogram', + 'cumulative_histogram', + 'scatter', + 'joint_plot', + 'histogram2d', + 'qq', + 'manhattan', + 'pdf', + 'smoothed_pdf', + 'cdf', + 'set_font_size', + 'visualize_missingness', +] diff --git a/hail/python/hail/plot/plots.py b/hail/python/hail/plot/plots.py index f16ae728551..4092dbea5e3 100644 --- a/hail/python/hail/plot/plots.py +++ b/hail/python/hail/plot/plots.py @@ -1,36 +1,69 @@ +import collections import math +import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union -import collections -import numpy as np -import pandas as pd import bokeh import bokeh.io import bokeh.models -import warnings -from bokeh.models import (HoverTool, ColorBar, LogTicker, LogColorMapper, LinearColorMapper, - CategoricalColorMapper, ColumnDataSource, BasicTicker, Plot, CDSView, - GroupFilter, IntersectionFilter, Legend, LegendItem, Renderer, CustomJS, - Select, Column, Span, DataRange1d, Slope, Label, ColorMapper, GridPlot) -import bokeh.plotting import bokeh.palettes +import bokeh.plotting +import numpy as np +import pandas as pd +from bokeh.layouts import gridplot +from bokeh.models import ( + BasicTicker, + CategoricalColorMapper, + CDSView, + ColorBar, + ColorMapper, + Column, + ColumnDataSource, + CustomJS, + DataRange1d, + GridPlot, + GroupFilter, + HoverTool, + IntersectionFilter, + Label, + Legend, + LegendItem, + LinearColorMapper, + LogColorMapper, + LogTicker, + Plot, + Renderer, + Select, + Slope, + Span, +) from bokeh.plotting import figure from bokeh.transform import transform -from bokeh.layouts import gridplot +import hail from hail.expr import aggregators from hail.expr.expressions import ( 
- Expression, NumericExpression, StringExpression, LocusExpression, - Int32Expression, Int64Expression, Float32Expression, Float64Expression, - expr_numeric, expr_float64, expr_any, expr_locus, expr_str, raise_unless_row_indexed + Expression, + Float32Expression, + Float64Expression, + Int32Expression, + Int64Expression, + LocusExpression, + NumericExpression, + StringExpression, + expr_any, + expr_float64, + expr_locus, + expr_numeric, + expr_str, + raise_unless_row_indexed, ) from hail.expr.functions import _error_from_cdf_python -from hail.typecheck import typecheck, oneof, nullable, sized_tupleof, numeric, \ - sequenceof, dictof -from hail import Table, MatrixTable -from hail.utils.struct import Struct +from hail.matrixtable import MatrixTable +from hail.table import Table +from hail.typecheck import dictof, nullable, numeric, oneof, sequenceof, sized_tupleof, typecheck from hail.utils.java import warning -from typing import List, Tuple, Dict, Union, Callable, Optional, Sequence, Any, Set -import hail +from hail.utils.struct import Struct palette = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'] @@ -110,7 +143,8 @@ def cdf(data, k=350, legend=None, title=None, normalize=True, log=False) -> figu height=400, background_fill_color='#EEEEEE', tools='xpan,xwheel_zoom,reset,save', - active_scroll='xwheel_zoom') + active_scroll='xwheel_zoom', + ) p.add_tools(HoverTool(tooltips=[("value", "$x"), ("rank", "@top")], mode='vline')) ranks = np.array(data.ranks) @@ -119,21 +153,14 @@ def cdf(data, k=350, legend=None, title=None, normalize=True, log=False) -> figu ranks = ranks / ranks[-1] # invisible, there to support tooltips - p.quad(top=ranks[1:-1], - bottom=ranks[1:-1], - left=values[:-1], - right=values[1:], - fill_alpha=0, - line_alpha=0) - p.step(x=[*values, values[-1]], - y=ranks, - line_width=2, - line_color='black', - legend_label=legend) + p.quad(top=ranks[1:-1], bottom=ranks[1:-1], left=values[:-1], right=values[1:], fill_alpha=0, line_alpha=0) + p.step(x=[*values, values[-1]], y=ranks, line_width=2, line_color='black', legend_label=legend) return p -def pdf(data, k=1000, confidence=5, legend=None, title=None, log=False, interactive=False) -> Union[figure, Tuple[figure, Callable]]: +def pdf( + data, k=1000, confidence=5, legend=None, title=None, log=False, interactive=False +) -> Union[figure, Tuple[figure, Callable]]: if isinstance(data, Expression): if data._indices is None: raise ValueError('Invalid input') @@ -157,7 +184,8 @@ def pdf(data, k=1000, confidence=5, legend=None, title=None, log=False, interact height=400, tools='xpan,xwheel_zoom,reset,save', active_scroll='xwheel_zoom', - background_fill_color='#EEEEEE') + background_fill_color='#EEEEEE', + ) y = np.array(data['ranks'][1:-1]) / data['ranks'][-1] x = np.array(data['values'][1:-1]) @@ -173,6 +201,7 @@ def pdf(data, k=1000, confidence=5, legend=None, title=None, log=False, interact plot = fig.quad(left=[min_x, *x[keep]], right=[*x[keep], max_x], bottom=0, top=slopes, legend_label=legend) if interactive: + def mk_interact(handle): def update(confidence=confidence): err = _error_from_cdf_python(data, 10 ** (-confidence), all_quantiles=True) / 1.8 @@ -181,12 +210,18 @@ def update(confidence=confidence): if log: new_data = {'x': [min_x, *x[keep], max_x], 'y': [*slopes, slopes[-1]]} else: - new_data = {'left': [min_x, *x[keep]], 'right': [*x[keep], max_x], 'bottom': np.full(len(slopes), 0), 'top': slopes} + new_data = { + 'left': [min_x, *x[keep]], + 'right': 
[*x[keep], max_x], + 'bottom': np.full(len(slopes), 0), + 'top': slopes, + } plot.data_source.data = new_data bokeh.io.push_notebook(handle=handle) from ipywidgets import interact - interact(update, confidence=(1, 10, .01)) + + interact(update, confidence=(1, 10, 0.01)) return fig, mk_interact else: @@ -269,7 +304,9 @@ def compare(x1, y1, x2, y2): return new_y, keep -def smoothed_pdf(data, k=350, smoothing=.5, legend=None, title=None, log=False, interactive=False, figure=None) -> Union[figure, Tuple[figure, Callable]]: +def smoothed_pdf( + data, k=350, smoothing=0.5, legend=None, title=None, log=False, interactive=False, figure=None +) -> Union[figure, Tuple[figure, Callable]]: """Create a density plot. Parameters @@ -320,7 +357,8 @@ def smoothed_pdf(data, k=350, smoothing=.5, legend=None, title=None, log=False, height=400, tools='xpan,xwheel_zoom,reset,save', active_scroll='xwheel_zoom', - background_fill_color='#EEEEEE') + background_fill_color='#EEEEEE', + ) else: p = figure @@ -344,6 +382,7 @@ def f(x, prev, smoothing=smoothing): line = p.line(x_d, final, line_width=2, line_color='black', legend_label=legend) if interactive: + def mk_interact(handle): def update(smoothing=smoothing): final = f(x_d, round1, smoothing) @@ -351,16 +390,26 @@ def update(smoothing=smoothing): bokeh.io.push_notebook(handle=handle) from ipywidgets import interact - interact(update, smoothing=(.02, .8, .005)) + + interact(update, smoothing=(0.02, 0.8, 0.005)) return p, mk_interact else: return p -@typecheck(data=oneof(Struct, expr_float64), range=nullable(sized_tupleof(numeric, numeric)), - bins=int, legend=nullable(str), title=nullable(str), log=bool, interactive=bool) -def histogram(data, range=None, bins=50, legend=None, title=None, log=False, interactive=False) -> Union[figure, Tuple[figure, Callable]]: +@typecheck( + data=oneof(Struct, expr_float64), + range=nullable(sized_tupleof(numeric, numeric)), + bins=int, + legend=nullable(str), + title=nullable(str), + log=bool, + interactive=bool, +) +def histogram( + data, range=None, bins=50, legend=None, title=None, log=False, interactive=False +) -> Union[figure, Tuple[figure, Callable]]: """Create a histogram. Notes @@ -397,8 +446,7 @@ def histogram(data, range=None, bins=50, legend=None, title=None, log=False, int end = range[1] else: finite_data = hail.bind(lambda x: hail.case().when(hail.is_finite(x), x).or_missing(), data) - start, end = agg_f((aggregators.min(finite_data), - aggregators.max(finite_data))) + start, end = agg_f((aggregators.min(finite_data), aggregators.max(finite_data))) if start is None and end is None: raise ValueError("'data' contains no values that are defined and finite") data = agg_f(aggregators.hist(data, start, end, bins)) @@ -423,12 +471,14 @@ def histogram(data, range=None, bins=50, legend=None, title=None, log=False, int bin_freq.append(math.log10(x)) if count_problems > 0: - warning(f"There were {count_problems} bins with height 0, those cannot be log transformed and were left as 0s.") + warning( + f"There were {count_problems} bins with height 0, those cannot be log transformed and were left as 0s." 
+ ) changes = { "bin_freq": bin_freq, "n_larger": math.log10(data.n_larger) if data.n_larger > 0.0 else data.n_larger, - "n_smaller": math.log10(data.n_smaller) if data.n_smaller > 0.0 else data.n_smaller + "n_smaller": math.log10(data.n_smaller) if data.n_smaller > 0.0 else data.n_smaller, } data = data.annotate(**changes) y_axis_label = 'log10 Frequency' @@ -436,29 +486,45 @@ def histogram(data, range=None, bins=50, legend=None, title=None, log=False, int y_axis_label = 'Frequency' x_span = data.bin_edges[-1] - data.bin_edges[0] - x_start = data.bin_edges[0] - .05 * x_span - x_end = data.bin_edges[-1] + .05 * x_span + x_start = data.bin_edges[0] - 0.05 * x_span + x_end = data.bin_edges[-1] + 0.05 * x_span p = figure( title=title, x_axis_label=legend, y_axis_label=y_axis_label, background_fill_color='#EEEEEE', - x_range=(x_start, x_end)) + x_range=(x_start, x_end), + ) q = p.quad( - bottom=0, top=data.bin_freq, - left=data.bin_edges[:-1], right=data.bin_edges[1:], - legend_label=legend, line_color='black') + bottom=0, + top=data.bin_freq, + left=data.bin_edges[:-1], + right=data.bin_edges[1:], + legend_label=legend, + line_color='black', + ) if data.n_larger > 0: p.quad( - bottom=0, top=data.n_larger, - left=data.bin_edges[-1], right=(data.bin_edges[-1] + (data.bin_edges[1] - data.bin_edges[0])), - line_color='black', fill_color='green', legend_label='Outliers Above') + bottom=0, + top=data.n_larger, + left=data.bin_edges[-1], + right=(data.bin_edges[-1] + (data.bin_edges[1] - data.bin_edges[0])), + line_color='black', + fill_color='green', + legend_label='Outliers Above', + ) if data.n_smaller > 0: p.quad( - bottom=0, top=data.n_smaller, - left=data.bin_edges[0] - (data.bin_edges[1] - data.bin_edges[0]), right=data.bin_edges[0], - line_color='black', fill_color='red', legend_label='Outliers Below') + bottom=0, + top=data.n_smaller, + left=data.bin_edges[0] - (data.bin_edges[1] - data.bin_edges[0]), + right=data.bin_edges[0], + line_color='black', + fill_color='red', + legend_label='Outliers Below', + ) if interactive: + def mk_interact(handle): def update(bins=bins, phase=0): if phase > 0 and phase < 1: @@ -473,15 +539,23 @@ def update(bins=bins, phase=0): bokeh.io.push_notebook(handle=handle) from ipywidgets import interact - interact(update, bins=(0, 5 * bins), phase=(0, 1, .01)) + + interact(update, bins=(0, 5 * bins), phase=(0, 1, 0.01)) return p, mk_interact else: return p -@typecheck(data=oneof(Struct, expr_float64), range=nullable(sized_tupleof(numeric, numeric)), - bins=int, legend=nullable(str), title=nullable(str), normalize=bool, log=bool) +@typecheck( + data=oneof(Struct, expr_float64), + range=nullable(sized_tupleof(numeric, numeric)), + bins=int, + legend=nullable(str), + title=nullable(str), + normalize=bool, + log=bool, +) def cumulative_histogram(data, range=None, bins=50, legend=None, title=None, normalize=True, log=False) -> figure: """Create a cumulative histogram. 
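`hl.plot.cumulative_histogram`, whose decorator is rewrapped above, accepts either a precomputed histogram `Struct` or a float64 expression. A minimal sketch of the expression form (the dataset and the `range`/`bins` choices are illustrative):

```python
import hail as hl

mt = hl.balding_nichols_model(n_populations=3, n_samples=100, n_variants=1000)
mt = hl.variant_qc(mt)

# Passing an expression lets the function aggregate the histogram itself.
p = hl.plot.cumulative_histogram(
    mt.variant_qc.AF[1], range=(0.0, 1.0), bins=50, legend='Alternate allele frequency'
)
hl.plot.show(p)
```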
@@ -530,8 +604,13 @@ def cumulative_histogram(data, range=None, bins=50, legend=None, title=None, nor if title is not None: title = f'{title} ({num_data_points:,} data points)' if log: - p = figure(title=title, x_axis_label=legend, y_axis_label='Frequency', - background_fill_color='#EEEEEE', y_axis_type='log') + p = figure( + title=title, + x_axis_label=legend, + y_axis_label='Frequency', + background_fill_color='#EEEEEE', + y_axis_type='log', + ) else: p = figure(title=title, x_axis_label=legend, y_axis_label='Frequency', background_fill_color='#EEEEEE') p.line(data.bin_edges[:-1], cumulative_data, line_color='#036564', line_width=3) @@ -565,21 +644,28 @@ def set_font_size(p, font_size: str = '12pt'): return p -@typecheck(x=expr_numeric, y=expr_numeric, bins=oneof(int, sequenceof(int)), - range=nullable(sized_tupleof(nullable(sized_tupleof(numeric, numeric)), - nullable(sized_tupleof(numeric, numeric)))), - title=nullable(str), width=int, height=int, - colors=sequenceof(str), - log=bool) -def histogram2d(x: NumericExpression, - y: NumericExpression, - bins: int = 40, - range: Optional[Tuple[int, int]] = None, - title: Optional[str] = None, - width: int = 600, - height: int = 600, - colors: Sequence[str] = bokeh.palettes.all_palettes['Blues'][7][::-1], - log: bool = False) -> figure: +@typecheck( + x=expr_numeric, + y=expr_numeric, + bins=oneof(int, sequenceof(int)), + range=nullable(sized_tupleof(nullable(sized_tupleof(numeric, numeric)), nullable(sized_tupleof(numeric, numeric)))), + title=nullable(str), + width=int, + height=int, + colors=sequenceof(str), + log=bool, +) +def histogram2d( + x: NumericExpression, + y: NumericExpression, + bins: int = 40, + range: Optional[Tuple[int, int]] = None, + title: Optional[str] = None, + width: int = 600, + height: int = 600, + colors: Sequence[str] = bokeh.palettes.all_palettes['Blues'][7][::-1], + log: bool = False, +) -> figure: """Plot a two-dimensional histogram. ``x`` and ``y`` must both be a :class:`.NumericExpression` from the same :class:`.Table`. 
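`hl.plot.histogram2d`, whose signature is rewrapped above, takes two numeric expressions from the same table. A short sketch against a throwaway table (the table, its fields, and the `bins` value are illustrative):

```python
import hail as hl

ht = hl.utils.range_table(1000)
ht = ht.annotate(x=hl.rand_norm(), y=hl.rand_norm())

# bins may be a single int (shared by both axes) or a pair of ints;
# log=True switches the color bar to a log scale.
p = hl.plot.histogram2d(ht.x, ht.y, bins=40)
hl.plot.show(p)
```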
@@ -643,38 +729,58 @@ def histogram2d(x: NumericExpression, x_axis = sorted(set(data.x), key=lambda z: float(z)) y_axis = sorted(set(data.y), key=lambda z: float(z)) - p = figure(title=title, - x_range=x_axis, y_range=y_axis, - x_axis_location="above", width=width, height=height, - tools="hover,save,pan,box_zoom,reset,wheel_zoom", toolbar_location='below') + p = figure( + title=title, + x_range=x_axis, + y_range=y_axis, + x_axis_location="above", + width=width, + height=height, + tools="hover,save,pan,box_zoom,reset,wheel_zoom", + toolbar_location='below', + ) p.grid.grid_line_color = None p.axis.axis_line_color = None p.axis.major_tick_line_color = None p.axis.major_label_standoff = 0 import math + p.xaxis.major_label_orientation = math.pi / 3 - p.rect(x='x', y='y', width=1, height=1, - source=data, - fill_color={'field': 'c', 'transform': mapper}, - line_color=None) + p.rect( + x='x', y='y', width=1, height=1, source=data, fill_color={'field': 'c', 'transform': mapper}, line_color=None + ) - color_bar = ColorBar(color_mapper=mapper, - ticker=LogTicker(desired_num_ticks=len(colors)) if log else BasicTicker(desired_num_ticks=len(colors)), - label_standoff=12 if log else 6, border_line_color=None, location=(0, 0)) + color_bar = ColorBar( + color_mapper=mapper, + ticker=LogTicker(desired_num_ticks=len(colors)) if log else BasicTicker(desired_num_ticks=len(colors)), + label_standoff=12 if log else 6, + border_line_color=None, + location=(0, 0), + ) p.add_layout(color_bar, 'right') hovertool = p.select_one(HoverTool) assert hovertool is not None - hovertool.tooltips = [('x', '@x'), ('y', '@y',), ('count', '@c')] + hovertool.tooltips = [ + ('x', '@x'), + ( + 'y', + '@y', + ), + ('count', '@c'), + ] return p -@typecheck(x=expr_numeric, y=expr_numeric, bins=oneof(int, sequenceof(int)), - range=nullable(sized_tupleof(nullable(sized_tupleof(numeric, numeric)), - nullable(sized_tupleof(numeric, numeric))))) +@typecheck( + x=expr_numeric, + y=expr_numeric, + bins=oneof(int, sequenceof(int)), + range=nullable(sized_tupleof(nullable(sized_tupleof(numeric, numeric)), nullable(sized_tupleof(numeric, numeric)))), +) def _generate_hist2d_data(x, y, bins, range): source = x._indices.source y_source = y._indices.source @@ -696,15 +802,16 @@ def _generate_hist2d_data(x, y, bins, range): x_range, y_range = range if x_range is None or y_range is None: warning('At least one range was not defined in histogram_2d. 
Doing two passes...') - ranges = source.aggregate(hail.struct(x_stats=hail.agg.stats(x), - y_stats=hail.agg.stats(y))) + ranges = source.aggregate(hail.struct(x_stats=hail.agg.stats(x), y_stats=hail.agg.stats(y))) if x_range is None: x_range = (ranges.x_stats.min, ranges.x_stats.max) if y_range is None: y_range = (ranges.y_stats.min, ranges.y_stats.max) else: - warning('If x_range or y_range are specified in histogram_2d, and there are points ' - 'outside of these ranges, they will not be plotted') + warning( + 'If x_range or y_range are specified in histogram_2d, and there are points ' + 'outside of these ranges, they will not be plotted' + ) x_range = list(map(float, x_range)) y_range = list(map(float, y_range)) x_spacing = (x_range[1] - x_range[0]) / x_bins @@ -712,30 +819,35 @@ def _generate_hist2d_data(x, y, bins, range): def frange(start, stop, step): from itertools import count, takewhile + return takewhile(lambda x: x <= stop, count(start, step)) x_levels = hail.literal(list(frange(x_range[0], x_range[1], x_spacing))[::-1]) y_levels = hail.literal(list(frange(y_range[0], y_range[1], y_spacing))[::-1]) grouped_ht = source.group_by( - x=hail.str(x_levels.find(lambda w: x >= w)), - y=hail.str(y_levels.find(lambda w: y >= w)) + x=hail.str(x_levels.find(lambda w: x >= w)), y=hail.str(y_levels.find(lambda w: y >= w)) ).aggregate(c=hail.agg.count()) - data = grouped_ht.filter(hail.is_defined(grouped_ht.x) & (grouped_ht.x != str(x_range[1])) - & hail.is_defined(grouped_ht.y) & (grouped_ht.y != str(y_range[1]))) + data = grouped_ht.filter( + hail.is_defined(grouped_ht.x) + & (grouped_ht.x != str(x_range[1])) + & hail.is_defined(grouped_ht.y) + & (grouped_ht.y != str(y_range[1])) + ) return data def _collect_scatter_plot_data( - x: Tuple[str, NumericExpression], - y: Tuple[str, NumericExpression], - fields: Optional[Dict[str, Expression]] = None, - n_divisions: Optional[int] = None, - missing_label: str = 'NA' + x: Tuple[str, NumericExpression], + y: Tuple[str, NumericExpression], + fields: Optional[Dict[str, Expression]] = None, + n_divisions: Optional[int] = None, + missing_label: str = 'NA', ) -> pd.DataFrame: - expressions = dict() if fields is not None: - expressions.update({k: hail.or_else(v, missing_label) if isinstance(v, StringExpression) else v for k, v in fields.items()}) + expressions.update({ + k: hail.or_else(v, missing_label) if isinstance(v, StringExpression) else v for k, v in fields.items() + }) if n_divisions is None: collect_expr = hail.struct(**dict((k, v) for k, v in (x, y)), **expressions) @@ -753,12 +865,17 @@ def _collect_scatter_plot_data( expressions = {k: hail.str(v) if not isinstance(v, StringExpression) else v for k, v in expressions.items()} agg_f = x[1]._aggregation_method() - res = agg_f(hail.agg.downsample(x[1], y[1], label=list(expressions.values()) if expressions else None, n_divisions=n_divisions)) + res = agg_f( + hail.agg.downsample( + x[1], y[1], label=list(expressions.values()) if expressions else None, n_divisions=n_divisions + ) + ) source_pd = pd.DataFrame([ dict( **{x[0]: point[0], y[0]: point[1]}, - **(dict(zip(expressions, point[2])) if point[2] is not None else {}) - ) for point in res + **(dict(zip(expressions, point[2])) if point[2] is not None else {}), + ) + for point in res ]) source_pd = source_pd.astype(numeric_expr, copy=False) @@ -772,45 +889,39 @@ def _get_categorical_palette(factors: List[str]) -> ColorMapper: _palette = palette elif n < 21: from bokeh.palettes import Category20 + _palette = Category20[n] else: from bokeh.palettes 
import viridis + _palette = viridis(n) return CategoricalColorMapper(factors=factors, palette=_palette) def _get_scatter_plot_elements( - sp: Plot, - source_pd: pd.DataFrame, - x_col: str, - y_col: str, - label_cols: List[str], - colors: Optional[Dict[str, ColorMapper]] = None, - size: int = 4, - hover_cols: Optional[Set[str]] = None, -) -> Union[Tuple[Plot, Dict[str, List[LegendItem]], Legend, ColorBar, Dict[str, ColorMapper], List[Renderer]], - Tuple[Plot, None, None, None, None, None]]: - + sp: Plot, + source_pd: pd.DataFrame, + x_col: str, + y_col: str, + label_cols: List[str], + colors: Optional[Dict[str, ColorMapper]] = None, + size: int = 4, + hover_cols: Optional[Set[str]] = None, +) -> Union[ + Tuple[Plot, Dict[str, List[LegendItem]], Legend, ColorBar, Dict[str, ColorMapper], List[Renderer]], + Tuple[Plot, None, None, None, None, None], +]: if not source_pd.shape[0]: print("WARN: No data to plot.") return sp, None, None, None, None, None - possible_tooltips = [ - (x_col, f'@{x_col}'), - (y_col, f'@{y_col}') - ] + [ - (c, f'@{c}') - for c in source_pd.columns - if c not in [x_col, y_col] + possible_tooltips = [(x_col, f'@{x_col}'), (y_col, f'@{y_col}')] + [ + (c, f'@{c}') for c in source_pd.columns if c not in [x_col, y_col] ] if hover_cols is not None: - possible_tooltips = [ - x - for x in possible_tooltips - if x[0] in hover_cols - ] + possible_tooltips = [x for x in possible_tooltips if x[0] in hover_cols] sp.tools.append(HoverTool(tooltips=possible_tooltips)) cds = ColumnDataSource(source_pd) @@ -818,9 +929,11 @@ def _get_scatter_plot_elements( if not label_cols: sp.circle(x_col, y_col, source=cds, size=size) return sp, None, None, None, None, None - continuous_cols = [col for col in label_cols if - (str(source_pd.dtypes[col]).startswith('float') - or str(source_pd.dtypes[col]).startswith('int'))] + continuous_cols = [ + col + for col in label_cols + if (str(source_pd.dtypes[col]).startswith('float') or str(source_pd.dtypes[col]).startswith('int')) + ] factor_cols = [col for col in label_cols if col not in continuous_cols] # Assign color mappers to columns @@ -850,29 +963,38 @@ def _get_scatter_plot_elements( legend_items: Dict[str, List[LegendItem]] = {} if not factor_cols: - all_renderers = [ - sp.circle(x_col, y_col, color=transform(initial_col, initial_mapper), source=cds, size=size) - ] + all_renderers = [sp.circle(x_col, y_col, color=transform(initial_col, initial_mapper), source=cds, size=size)] else: all_renderers = [] legend_items_by_key_by_factor = {col: collections.defaultdict(list) for col in factor_cols} for key in source_pd.groupby(factor_cols).groups.keys(): - key = key if len(factor_cols) > 1 else [key] + _key = key if len(factor_cols) > 1 else [key] cds_view = CDSView( - filter=IntersectionFilter(operands=[GroupFilter(column_name=factor_cols[i], group=key[i]) for i in range(0, len(factor_cols))]) + filter=IntersectionFilter( + operands=[ + GroupFilter(column_name=factor_cols[i], group=_key[i]) for i in range(0, len(factor_cols)) + ] + ) + ) + renderer = sp.circle( + x_col, y_col, color=transform(initial_col, initial_mapper), source=cds, view=cds_view, size=size ) - renderer = sp.circle(x_col, y_col, color=transform(initial_col, initial_mapper), source=cds, view=cds_view, size=size) all_renderers.append(renderer) for i in range(0, len(factor_cols)): - legend_items_by_key_by_factor[factor_cols[i]][key[i]].append(renderer) + legend_items_by_key_by_factor[factor_cols[i]][_key[i]].append(renderer) - legend_items = {factor: [LegendItem(label=key, 
renderers=renderers) - for key, renderers in key_renderers.items()] - for factor, key_renderers in legend_items_by_key_by_factor.items()} + legend_items = { + factor: [LegendItem(label=key, renderers=renderers) for key, renderers in key_renderers.items()] + for factor, key_renderers in legend_items_by_key_by_factor.items() + } # Add legend / color bar - legend = Legend(visible=False, click_policy='hide', orientation='vertical') if initial_col not in factor_cols else Legend(items=legend_items[initial_col], click_policy='hide', orientation='vertical') + legend = ( + Legend(visible=False, click_policy='hide', orientation='vertical') + if initial_col not in factor_cols + else Legend(items=legend_items[initial_col], click_policy='hide', orientation='vertical') + ) color_bar = ColorBar(color_mapper=color_mappers[initial_col]) if initial_col not in continuous_cols: color_bar.visible = False @@ -882,10 +1004,7 @@ def _get_scatter_plot_elements( return sp, legend_items, legend, color_bar, color_mappers, all_renderers -def _downsampling_factor(fname: str, - n_divisions: Optional[int], - collect_all: Optional[bool] - ) -> Optional[int]: +def _downsampling_factor(fname: str, n_divisions: Optional[int], collect_all: Optional[bool]) -> Optional[int]: if collect_all is not None: warnings.warn(f'{fname}: `collect_all` has been deprecated. Use `n_divisions` instead.') if n_divisions is not None and collect_all is not None: @@ -899,38 +1018,39 @@ def _downsampling_factor(fname: str, return n_divisions -@typecheck(x=oneof(expr_numeric, sized_tupleof(str, expr_numeric)), - y=oneof(expr_numeric, sized_tupleof(str, expr_numeric)), - label=nullable(oneof(dictof(str, expr_any), expr_any)), - title=nullable(str), - xlabel=nullable(str), - ylabel=nullable(str), - size=int, - legend=bool, - hover_fields=nullable(dictof(str, expr_any)), - colors=nullable(oneof(bokeh.models.mappers.ColorMapper, dictof(str, bokeh.models.mappers.ColorMapper))), - width=int, - height=int, - collect_all=nullable(bool), - n_divisions=nullable(int), - missing_label=str - ) +@typecheck( + x=oneof(expr_numeric, sized_tupleof(str, expr_numeric)), + y=oneof(expr_numeric, sized_tupleof(str, expr_numeric)), + label=nullable(oneof(dictof(str, expr_any), expr_any)), + title=nullable(str), + xlabel=nullable(str), + ylabel=nullable(str), + size=int, + legend=bool, + hover_fields=nullable(dictof(str, expr_any)), + colors=nullable(oneof(bokeh.models.mappers.ColorMapper, dictof(str, bokeh.models.mappers.ColorMapper))), + width=int, + height=int, + collect_all=nullable(bool), + n_divisions=nullable(int), + missing_label=str, +) def scatter( - x: Union[NumericExpression, Tuple[str, NumericExpression]], - y: Union[NumericExpression, Tuple[str, NumericExpression]], - label: Optional[Union[Expression, Dict[str, Expression]]] = None, - title: Optional[str] = None, - xlabel: Optional[str] = None, - ylabel: Optional[str] = None, - size: int = 4, - legend: bool = True, - hover_fields: Optional[Dict[str, Expression]] = None, - colors: Optional[Union[ColorMapper, Dict[str, ColorMapper]]] = None, - width: int = 800, - height: int = 800, - collect_all: Optional[bool] = None, - n_divisions: Optional[int] = 500, - missing_label: str = 'NA' + x: Union[NumericExpression, Tuple[str, NumericExpression]], + y: Union[NumericExpression, Tuple[str, NumericExpression]], + label: Optional[Union[Expression, Dict[str, Expression]]] = None, + title: Optional[str] = None, + xlabel: Optional[str] = None, + ylabel: Optional[str] = None, + size: int = 4, + legend: bool = True, + 
hover_fields: Optional[Dict[str, Expression]] = None, + colors: Optional[Union[ColorMapper, Dict[str, ColorMapper]]] = None, + width: int = 800, + height: int = 800, + collect_all: Optional[bool] = None, + n_divisions: Optional[int] = 500, + missing_label: str = 'NA', ) -> Union[Plot, Column]: """Create an interactive scatter plot. @@ -1032,12 +1152,11 @@ def scatter( _y, fields={**hover_fields, **label_by_col}, n_divisions=_downsampling_factor('scatter', n_divisions, collect_all), - missing_label=missing_label + missing_label=missing_label, ) sp = figure(title=title, x_axis_label=xlabel, y_axis_label=ylabel, height=height, width=width) sp, sp_legend_items, sp_legend, sp_color_bar, sp_color_mappers, sp_scatter_renderers = _get_scatter_plot_elements( - sp, source_pd, _x[0], _y[0], label_cols, colors_by_col, size, - hover_cols={'x', 'y'} | set(hover_fields) + sp, source_pd, _x[0], _y[0], label_cols, colors_by_col, size, hover_cols={'x', 'y'} | set(hover_fields) ) if not legend: @@ -1049,10 +1168,7 @@ def scatter( # If multiple labels, create JS call back selector if len(label_cols) > 1: callback_args: Dict[str, Any] - callback_args = dict( - color_mappers=sp_color_mappers, - scatter_renderers=sp_scatter_renderers - ) + callback_args = dict(color_mappers=sp_color_mappers, scatter_renderers=sp_scatter_renderers) callback_code = """ for (var i = 0; i < scatter_renderers.length; i++){ scatter_renderers[i].glyph.fill_color = {field: cb_obj.value, transform: color_mappers[cb_obj.value]} @@ -1063,11 +1179,7 @@ def scatter( """ if legend: - callback_args.update(dict( - legend_items=sp_legend_items, - legend=sp_legend, - color_bar=sp_color_bar - )) + callback_args.update(dict(legend_items=sp_legend_items, legend=sp_legend, color_bar=sp_color_bar)) callback_code += """ if (cb_obj.value in legend_items){ legend.items=legend_items[cb_obj.value] @@ -1088,107 +1200,109 @@ def scatter( return sp -@typecheck(x=oneof(expr_numeric, sized_tupleof(str, expr_numeric)), - y=oneof(expr_numeric, sized_tupleof(str, expr_numeric)), - label=nullable(oneof(dictof(str, expr_any), expr_any)), - title=nullable(str), - xlabel=nullable(str), ylabel=nullable(str), - size=int, - legend=bool, - hover_fields=nullable(dictof(str, expr_any)), - colors=nullable(oneof(bokeh.models.mappers.ColorMapper, dictof(str, bokeh.models.mappers.ColorMapper))), - width=int, - height=int, - collect_all=nullable(bool), - n_divisions=nullable(int), - missing_label=str - ) +@typecheck( + x=oneof(expr_numeric, sized_tupleof(str, expr_numeric)), + y=oneof(expr_numeric, sized_tupleof(str, expr_numeric)), + label=nullable(oneof(dictof(str, expr_any), expr_any)), + title=nullable(str), + xlabel=nullable(str), + ylabel=nullable(str), + size=int, + legend=bool, + hover_fields=nullable(dictof(str, expr_any)), + colors=nullable(oneof(bokeh.models.mappers.ColorMapper, dictof(str, bokeh.models.mappers.ColorMapper))), + width=int, + height=int, + collect_all=nullable(bool), + n_divisions=nullable(int), + missing_label=str, +) def joint_plot( - x: Union[NumericExpression, Tuple[str, NumericExpression]], - y: Union[NumericExpression, Tuple[str, NumericExpression]], - label: Optional[Union[Expression, Dict[str, Expression]]] = None, - title: Optional[str] = None, - xlabel: Optional[str] = None, - ylabel: Optional[str] = None, - size: int = 4, - legend: bool = True, - hover_fields: Optional[Dict[str, StringExpression]] = None, - colors: Optional[Union[ColorMapper, Dict[str, ColorMapper]]] = None, - width: int = 800, - height: int = 800, - collect_all: 
Optional[bool] = None, - n_divisions: Optional[int] = 500, - missing_label: str = 'NA' + x: Union[NumericExpression, Tuple[str, NumericExpression]], + y: Union[NumericExpression, Tuple[str, NumericExpression]], + label: Optional[Union[Expression, Dict[str, Expression]]] = None, + title: Optional[str] = None, + xlabel: Optional[str] = None, + ylabel: Optional[str] = None, + size: int = 4, + legend: bool = True, + hover_fields: Optional[Dict[str, StringExpression]] = None, + colors: Optional[Union[ColorMapper, Dict[str, ColorMapper]]] = None, + width: int = 800, + height: int = 800, + collect_all: Optional[bool] = None, + n_divisions: Optional[int] = 500, + missing_label: str = 'NA', ) -> GridPlot: """Create an interactive scatter plot with marginal densities on the side. - ``x`` and ``y`` must both be either: - - a :class:`.NumericExpression` from the same :class:`.Table`. - - a tuple (str, :class:`.NumericExpression`) from the same :class:`.Table`. If passed as a tuple the first element is used as the hover label. - - This function returns a :class:`bokeh.models.layouts.Column` containing two :class:`figure.Row`: - - The first row contains the X-axis marginal density and a selection widget if multiple entries are specified in the ``label`` - - The second row contains the scatter plot and the y-axis marginal density - - Points will be colored by one of the labels defined in the ``label`` using the color scheme defined in - the corresponding entry of ``colors`` if provided (otherwise a default scheme is used). To specify your color - mapper, check `the bokeh documentation `__ - for CategoricalMapper for categorical labels, and for LinearColorMapper and LogColorMapper - for continuous labels. - For categorical labels, clicking on one of the items in the legend will hide/show all points with the corresponding label in the scatter plot. - Note that using many different labelling schemes in the same plots, particularly if those labels contain many - different classes could slow down the plot interactions. - - Hovering on points in the scatter plot displays their coordinates, labels and any additional fields specified in ``hover_fields``. - - Parameters - ---------- - ---------- - x : :class:`.NumericExpression` or (str, :class:`.NumericExpression`) - List of x-values to be plotted. - y : :class:`.NumericExpression` or (str, :class:`.NumericExpression`) - List of y-values to be plotted. - label : :class:`.Expression` or Dict[str, :class:`.Expression`]], optional - Either a single expression (if a single label is desired), or a - dictionary of label name -> label value for x and y values. - Used to color each point w.r.t its label. - When multiple labels are given, a dropdown will be displayed with the different options. - Can be used with categorical or continuous expressions. - title : str, optional - Title of the scatterplot. - xlabel : str, optional - X-axis label. - ylabel : str, optional - Y-axis label. - size : int - Size of markers in screen space units. - legend: bool - Whether or not to show the legend in the resulting figure. - hover_fields : Dict[str, :class:`.Expression`], optional - Extra fields to be displayed when hovering over a point on the plot. - colors : :class:`bokeh.models.mappers.ColorMapper` or Dict[str, :class:`bokeh.models.mappers.ColorMapper`], optional - If a single label is used, then this can be a color mapper, if multiple labels are used, then this should - be a Dict of label name -> color mapper. - Used to set colors for the labels defined using ``label``. 
- If not used at all, or label names not appearing in this dict will be colored using a default color scheme. - width: int - Plot width - height: int - Plot height - collect_all : bool, optional - Deprecated. Use `n_divisions` instead. - n_divisions : int, optional - Factor by which to downsample (default value = 500). - A lower input results in fewer output datapoints. - Use `None` to collect all points. - missing_label: str - Label to use when a point is missing data for a categorical label - - - Returns - ------- - :class:`.GridPlot` - """ + ``x`` and ``y`` must both be either: + - a :class:`.NumericExpression` from the same :class:`.Table`. + - a tuple (str, :class:`.NumericExpression`) from the same :class:`.Table`. If passed as a tuple the first element is used as the hover label. + + This function returns a :class:`bokeh.models.layouts.Column` containing two :class:`figure.Row`: + - The first row contains the X-axis marginal density and a selection widget if multiple entries are specified in the ``label`` + - The second row contains the scatter plot and the y-axis marginal density + + Points will be colored by one of the labels defined in the ``label`` using the color scheme defined in + the corresponding entry of ``colors`` if provided (otherwise a default scheme is used). To specify your color + mapper, check `the bokeh documentation `__ + for CategoricalMapper for categorical labels, and for LinearColorMapper and LogColorMapper + for continuous labels. + For categorical labels, clicking on one of the items in the legend will hide/show all points with the corresponding label in the scatter plot. + Note that using many different labelling schemes in the same plots, particularly if those labels contain many + different classes could slow down the plot interactions. + + Hovering on points in the scatter plot displays their coordinates, labels and any additional fields specified in ``hover_fields``. + + Parameters + ---------- + ---------- + x : :class:`.NumericExpression` or (str, :class:`.NumericExpression`) + List of x-values to be plotted. + y : :class:`.NumericExpression` or (str, :class:`.NumericExpression`) + List of y-values to be plotted. + label : :class:`.Expression` or Dict[str, :class:`.Expression`]], optional + Either a single expression (if a single label is desired), or a + dictionary of label name -> label value for x and y values. + Used to color each point w.r.t its label. + When multiple labels are given, a dropdown will be displayed with the different options. + Can be used with categorical or continuous expressions. + title : str, optional + Title of the scatterplot. + xlabel : str, optional + X-axis label. + ylabel : str, optional + Y-axis label. + size : int + Size of markers in screen space units. + legend: bool + Whether or not to show the legend in the resulting figure. + hover_fields : Dict[str, :class:`.Expression`], optional + Extra fields to be displayed when hovering over a point on the plot. + colors : :class:`bokeh.models.mappers.ColorMapper` or Dict[str, :class:`bokeh.models.mappers.ColorMapper`], optional + If a single label is used, then this can be a color mapper, if multiple labels are used, then this should + be a Dict of label name -> color mapper. + Used to set colors for the labels defined using ``label``. + If not used at all, or label names not appearing in this dict will be colored using a default color scheme. + width: int + Plot width + height: int + Plot height + collect_all : bool, optional + Deprecated. Use `n_divisions` instead. 
+ n_divisions : int, optional + Factor by which to downsample (default value = 500). + A lower input results in fewer output datapoints. + Use `None` to collect all points. + missing_label: str + Label to use when a point is missing data for a categorical label + + + Returns + ------- + :class:`.GridPlot` + """ # Collect data hover_fields = {} if hover_fields is None else hover_fields @@ -1221,30 +1335,30 @@ def joint_plot( _y, fields={**hover_fields, **label_by_col}, n_divisions=_downsampling_factor('join_plot', n_divisions, collect_all), - missing_label=missing_label + missing_label=missing_label, ) sp = figure(title=title, x_axis_label=xlabel, y_axis_label=ylabel, height=height, width=width) sp, sp_legend_items, sp_legend, sp_color_bar, sp_color_mappers, sp_scatter_renderers = _get_scatter_plot_elements( - sp, source_pd, _x[0], _y[0], label_cols, colors_by_col, size, - hover_cols={'x', 'y'} | set(hover_fields) + sp, source_pd, _x[0], _y[0], label_cols, colors_by_col, size, hover_cols={'x', 'y'} | set(hover_fields) ) - continuous_cols = [col for col in label_cols if - (str(source_pd.dtypes[col]).startswith('float') - or str(source_pd.dtypes[col]).startswith('int'))] + continuous_cols = [ + col + for col in label_cols + if (str(source_pd.dtypes[col]).startswith('float') or str(source_pd.dtypes[col]).startswith('int')) + ] factor_cols = [col for col in label_cols if col not in continuous_cols] # Density plots def get_density_plot_items( - source_pd, - data_col, - p, - x_axis, - colors: Optional[Dict[str, ColorMapper]], - continuous_cols: List[str], - factor_cols: List[str] + source_pd, + data_col, + p, + x_axis, + colors: Optional[Dict[str, ColorMapper]], + continuous_cols: List[str], + factor_cols: List[str], ): - density_renderers = [] max_densities = {} if not factor_cols or continuous_cols: @@ -1260,23 +1374,47 @@ def get_density_plot_items( assert colors is not None, (colors, factor_cols) factor_colors = colors.get(factor_col, _get_categorical_palette(list(set(source_pd[factor_col])))) factor_colors = dict(zip(factor_colors.factors, factor_colors.palette)) - density_data = source_pd[[factor_col, data_col]].groupby(factor_col).apply(lambda df: np.histogram(df['x' if x_axis else 'y'], density=True)) + density_data = ( + source_pd[[factor_col, data_col]] + .groupby(factor_col) + .apply(lambda df: np.histogram(df['x' if x_axis else 'y'], density=True)) + ) for factor, (dens, edges) in density_data.iteritems(): - edges = edges[:-1] - xy = (edges, dens) if x_axis else (dens, edges) + _edges = edges[:-1] + xy = (_edges, dens) if x_axis else (dens, _edges) cds = ColumnDataSource({'x': xy[0], 'y': xy[1]}) - density_renderers.append((factor_col, factor, p.line('x', 'y', color=factor_colors.get(factor, 'gray'), source=cds))) - max_densities[factor_col] = np.max(list(dens) + [max_densities.get(factor_col, 0)]) + density_renderers.append(( + factor_col, + factor, + p.line('x', 'y', color=factor_colors.get(factor, 'gray'), source=cds), + )) + max_densities[factor_col] = np.max([*list(dens), max_densities.get(factor_col, 0)]) p.grid.visible = False p.outline_line_color = None return p, density_renderers, max_densities xp = figure(title=title, height=int(height / 3), width=width, x_range=sp.x_range) - xp, x_renderers, x_max_densities = get_density_plot_items(source_pd, _x[0], xp, x_axis=True, colors=sp_color_mappers, continuous_cols=continuous_cols, factor_cols=factor_cols) + xp, x_renderers, x_max_densities = get_density_plot_items( + source_pd, + _x[0], + xp, + x_axis=True, + 
colors=sp_color_mappers, + continuous_cols=continuous_cols, + factor_cols=factor_cols, + ) xp.xaxis.visible = False yp = figure(height=height, width=int(width / 3), y_range=sp.y_range) - yp, y_renderers, y_max_densities = get_density_plot_items(source_pd, _y[0], yp, x_axis=False, colors=sp_color_mappers, continuous_cols=continuous_cols, factor_cols=factor_cols) + yp, y_renderers, y_max_densities = get_density_plot_items( + source_pd, + _y[0], + yp, + x_axis=False, + colors=sp_color_mappers, + continuous_cols=continuous_cols, + factor_cols=factor_cols, + ) yp.yaxis.visible = False density_renderers = x_renderers + y_renderers first_row = [xp] @@ -1289,7 +1427,6 @@ def get_density_plot_items( # If multiple labels, create JS call back selector if len(label_cols) > 1: - for factor_col, _, renderer in density_renderers: renderer.visible = factor_col == label_cols[0] @@ -1307,7 +1444,7 @@ def get_density_plot_items( x_range=xp.y_range, x_max_densities=x_max_densities, y_range=yp.x_range, - y_max_densities=y_max_densities + y_max_densities=y_max_densities, ) callback_code = """ @@ -1329,11 +1466,7 @@ def get_density_plot_items( """ if legend: - callback_args.update(dict( - legend_items=sp_legend_items, - legend=sp_legend, - color_bar=sp_color_bar - )) + callback_args.update(dict(legend_items=sp_legend_items, legend=sp_legend, color_bar=sp_color_bar)) callback_code += """ if (cb_obj.value in legend_items){ legend.items=legend_items[cb_obj.value] @@ -1354,36 +1487,37 @@ def get_density_plot_items( return gridplot([first_row, [sp, yp]]) -@typecheck(pvals=expr_numeric, - label=nullable(oneof(dictof(str, expr_any), expr_any)), - title=nullable(str), - xlabel=nullable(str), - ylabel=nullable(str), - size=int, - legend=bool, - hover_fields=nullable(dictof(str, expr_any)), - colors=nullable(oneof(bokeh.models.mappers.ColorMapper, dictof(str, bokeh.models.mappers.ColorMapper))), - width=int, - height=int, - collect_all=nullable(bool), - n_divisions=nullable(int), - missing_label=str - ) +@typecheck( + pvals=expr_numeric, + label=nullable(oneof(dictof(str, expr_any), expr_any)), + title=nullable(str), + xlabel=nullable(str), + ylabel=nullable(str), + size=int, + legend=bool, + hover_fields=nullable(dictof(str, expr_any)), + colors=nullable(oneof(bokeh.models.mappers.ColorMapper, dictof(str, bokeh.models.mappers.ColorMapper))), + width=int, + height=int, + collect_all=nullable(bool), + n_divisions=nullable(int), + missing_label=str, +) def qq( - pvals: NumericExpression, - label: Optional[Union[Expression, Dict[str, Expression]]] = None, - title: Optional[str] = 'Q-Q plot', - xlabel: Optional[str] = 'Expected -log10(p)', - ylabel: Optional[str] = 'Observed -log10(p)', - size: int = 6, - legend: bool = True, - hover_fields: Optional[Dict[str, Expression]] = None, - colors: Optional[Union[ColorMapper, Dict[str, ColorMapper]]] = None, - width: int = 800, - height: int = 800, - collect_all: Optional[bool] = None, - n_divisions: Optional[int] = 500, - missing_label: str = 'NA' + pvals: NumericExpression, + label: Optional[Union[Expression, Dict[str, Expression]]] = None, + title: Optional[str] = 'Q-Q plot', + xlabel: Optional[str] = 'Expected -log10(p)', + ylabel: Optional[str] = 'Observed -log10(p)', + size: int = 6, + legend: bool = True, + hover_fields: Optional[Dict[str, Expression]] = None, + colors: Optional[Union[ColorMapper, Dict[str, ColorMapper]]] = None, + width: int = 800, + height: int = 800, + collect_all: Optional[bool] = None, + n_divisions: Optional[int] = 500, + missing_label: str = 'NA', ) -> 
Union[figure, Column]: """Create a Quantile-Quantile plot. (https://en.wikipedia.org/wiki/Q-Q_plot) @@ -1465,10 +1599,7 @@ def qq( ht = source.select_rows(p_value=pvals, **hover_fields, **label_by_col).rows() ht = ht.key_by().select('p_value', *hover_fields, *label_by_col).key_by('p_value') n = ht.aggregate(aggregators.count(), _localize=False) - ht = ht.annotate( - observed_p=-hail.log10(ht['p_value']), - expected_p=-hail.log10((hail.scan.count() + 1) / n) - ) + ht = ht.annotate(observed_p=-hail.log10(ht['p_value']), expected_p=-hail.log10((hail.scan.count() + 1) / n)) if 'p' not in hover_fields: hover_fields['p_value'] = ht['p_value'] p = scatter( @@ -1485,10 +1616,14 @@ def qq( width=width, height=height, n_divisions=_downsampling_factor('qq', n_divisions, collect_all), - missing_label=missing_label + missing_label=missing_label, ) from hail.methods.statgen import _lambda_gc_agg - lambda_gc, max_p = ht.aggregate((_lambda_gc_agg(ht['p_value']), hail.agg.max(hail.max(ht.observed_p, ht.expected_p)))) + + lambda_gc, max_p = ht.aggregate(( + _lambda_gc_agg(ht['p_value']), + hail.agg.max(hail.max(ht.observed_p, ht.expected_p)), + )) if isinstance(p, Column): qq = p.children[1] else: @@ -1498,31 +1633,39 @@ def qq( qq.add_layout(Slope(gradient=1, y_intercept=0, line_color='red')) label_color = 'red' if lambda_gc > 1.25 else 'orange' if lambda_gc > 1.1 else 'black' - lgc_label = Label(x=max_p * 0.85, y=1, text=f'λ GC: {lambda_gc:.2f}', - text_font_style='bold', text_color=label_color, text_font_size='14pt') + lgc_label = Label( + x=max_p * 0.85, + y=1, + text=f'λ GC: {lambda_gc:.2f}', + text_font_style='bold', + text_color=label_color, + text_font_size='14pt', + ) p.add_layout(lgc_label) return p -@typecheck(pvals=expr_float64, - locus=nullable(expr_locus()), - title=nullable(str), - size=int, - hover_fields=nullable(dictof(str, expr_any)), - collect_all=nullable(bool), - n_divisions=nullable(int), - significance_line=nullable(numeric) - ) -def manhattan(pvals: 'Float64Expression', - locus: 'Optional[LocusExpression]' = None, - title: 'Optional[str]' = None, - size: int = 4, - hover_fields: 'Optional[Dict[str, Expression]]' = None, - collect_all: 'Optional[bool]' = None, - n_divisions: 'Optional[int]' = 500, - significance_line: 'Optional[Union[int, float]]' = 5e-8 - ) -> Plot: +@typecheck( + pvals=expr_float64, + locus=nullable(expr_locus()), + title=nullable(str), + size=int, + hover_fields=nullable(dictof(str, expr_any)), + collect_all=nullable(bool), + n_divisions=nullable(int), + significance_line=nullable(numeric), +) +def manhattan( + pvals: 'Float64Expression', + locus: 'Optional[LocusExpression]' = None, + title: 'Optional[str]' = None, + size: int = 4, + hover_fields: 'Optional[Dict[str, Expression]]' = None, + collect_all: 'Optional[bool]' = None, + n_divisions: 'Optional[int]' = 500, + significance_line: 'Optional[Union[int, float]]' = 5e-8, +) -> Plot: """Create a Manhattan plot. 
(https://en.wikipedia.org/wiki/Manhattan_plot) Parameters @@ -1567,15 +1710,12 @@ def manhattan(pvals: 'Float64Expression', ('_global_locus', locus.global_position()), ('_pval', pvals), fields=hover_fields, - n_divisions=_downsampling_factor('manhattan', n_divisions, collect_all) + n_divisions=_downsampling_factor('manhattan', n_divisions, collect_all), ) source_pd['p_value'] = [10 ** (-p) for p in source_pd['_pval']] source_pd['_contig'] = [locus.split(":")[0] for locus in source_pd['locus']] - observed_contigs = [ - contig for contig in ref.contigs.copy() - if contig in set(source_pd['_contig']) - ] + observed_contigs = [contig for contig in ref.contigs.copy() if contig in set(source_pd['_contig'])] contig_ticks = [ref._contig_global_position(contig) + ref.contig_length(contig) // 2 for contig in observed_contigs] color_mapper = CategoricalColorMapper(factors=ref.contigs, palette=palette[:2] * int((len(ref.contigs) + 1) / 2)) @@ -1589,7 +1729,7 @@ def manhattan(pvals: 'Float64Expression', label_cols=['_contig'], colors={'_contig': color_mapper}, size=size, - hover_cols={'locus', 'p_value'} | set(hover_fields) + hover_cols={'locus', 'p_value'} | set(hover_fields), ) assert legend is not None legend.visible = False @@ -1597,19 +1737,30 @@ def manhattan(pvals: 'Float64Expression', p.xaxis.major_label_overrides = dict(zip(contig_ticks, [contig.replace("chr", "") for contig in observed_contigs])) if significance_line is not None: - p.renderers.append(Span(location=-math.log10(significance_line), - dimension='width', - line_color='red', - line_dash='dashed', - line_width=1.5)) + p.renderers.append( + Span( + location=-math.log10(significance_line), + dimension='width', + line_color='red', + line_dash='dashed', + line_width=1.5, + ) + ) return p -@typecheck(entry_field=expr_any, row_field=nullable(oneof(expr_numeric, expr_locus())), column_field=nullable(expr_str), - window=nullable(int), plot_width=int, plot_height=int) -def visualize_missingness(entry_field, row_field=None, column_field=None, - window=6000000, plot_width=1800, plot_height=900) -> figure: +@typecheck( + entry_field=expr_any, + row_field=nullable(oneof(expr_numeric, expr_locus())), + column_field=nullable(expr_str), + window=nullable(int), + plot_width=int, + plot_height=int, +) +def visualize_missingness( + entry_field, row_field=None, column_field=None, window=6000000, plot_width=1800, plot_height=900 +) -> figure: """Visualize missingness in a MatrixTable. Inspired by `naniar `__. 
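A minimal usage sketch for `visualize_missingness` (not part of this change): it assumes a running Hail session, uses a synthetic `balding_nichols_model` dataset, and injects missingness at an arbitrary 10% rate purely for illustration.

import hail as hl
from bokeh.io import show

hl.init()

# Synthetic MatrixTable; knock out ~10% of genotype calls at random so there
# is some missingness to visualize.
mt = hl.balding_nichols_model(n_populations=3, n_samples=40, n_variants=200)
mt = mt.annotate_entries(GT=hl.or_missing(hl.rand_bool(0.9), mt.GT))

# Heatmap of the fraction of defined GT entries per genomic window (6 Mb by
# default) and per sample; column_field must be a string expression.
p = hl.plot.visualize_missingness(
    mt.GT,
    row_field=mt.locus,
    column_field=hl.str(mt.sample_idx),
)
show(p)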
@@ -1657,33 +1808,41 @@ def visualize_missingness(entry_field, row_field=None, column_field=None, raise ValueError("visualize_missingness requires source to be MatrixTable, not Table") columns = column_field.collect() if not (mt == row_source == column_source): - raise ValueError(f"visualize_missingness expects expressions from the same 'MatrixTable', " - f"found {mt} and {row_source} and {column_source}") + raise ValueError( + f"visualize_missingness expects expressions from the same 'MatrixTable', " + f"found {mt} and {row_source} and {column_source}" + ) # raise_unless_row_indexed('visualize_missingness', row_source) if window: row_field_is_locus = isinstance(row_field.dtype, hail.tlocus) row_field_is_numeric = row_field.dtype in (hail.tint32, hail.tint64, hail.tfloat32, hail.tfloat64) if row_field_is_locus: - grouping = hail.locus_from_global_position(hail.int64(window) - * hail.int64(row_field.global_position() / window)) + grouping = hail.locus_from_global_position( + hail.int64(window) * hail.int64(row_field.global_position() / window) + ) elif row_field_is_numeric: grouping = hail.int64(window) * hail.int64(row_field / window) else: - raise ValueError(f'When window is not None and row key must be numeric, but row key type was {mt.row_key.dtype}.') - mt = mt.group_rows_by( - _new_row_key=grouping - ).partition_hint(100).aggregate( - is_defined=hail.agg.fraction(hail.is_defined(entry_field)) + raise ValueError( + f'When window is not None and row key must be numeric, but row key type was {mt.row_key.dtype}.' + ) + mt = ( + mt.group_rows_by(_new_row_key=grouping) + .partition_hint(100) + .aggregate(is_defined=hail.agg.fraction(hail.is_defined(entry_field))) ) else: - mt = mt._select_all(row_exprs={'_new_row_key': row_field}, - entry_exprs={'is_defined': hail.is_defined(entry_field)}) + mt = mt._select_all( + row_exprs={'_new_row_key': row_field}, entry_exprs={'is_defined': hail.is_defined(entry_field)} + ) ht = mt.localize_entries('entry_fields', 'phenos') ht = ht.select(entry_fields=ht.entry_fields.map(lambda entry: entry.is_defined)) data = ht.entry_fields.collect() if len(data) > 200: - warning(f'Missingness dataset has {len(data)} rows. ' - f'This may take {"a very long time" if len(data) > 1000 else "a few minutes"} to plot.') + warning( + f'Missingness dataset has {len(data)} rows. ' + f'This may take {"a very long time" if len(data) > 1000 else "a few minutes"} to plot.' 
+ ) rows = hail.str(ht._new_row_key).collect() df = pd.DataFrame(data) @@ -1693,11 +1852,15 @@ def visualize_missingness(entry_field, row_field=None, column_field=None, df = pd.DataFrame(df.stack(), columns=['defined']).reset_index() - p = figure(x_range=columns, y_range=list(reversed(rows)), - x_axis_location="above", width=plot_width, height=plot_height, - toolbar_location='below', - tooltips=[('defined', '@defined'), ('row', '@row'), ('column', '@column')] - ) + p = figure( + x_range=columns, + y_range=list(reversed(rows)), + x_axis_location="above", + width=plot_width, + height=plot_height, + toolbar_location='below', + tooltips=[('defined', '@defined'), ('row', '@row'), ('column', '@column')], + ) p.grid.grid_line_color = None p.axis.axis_line_color = None @@ -1705,18 +1868,28 @@ def visualize_missingness(entry_field, row_field=None, column_field=None, p.axis.major_label_text_font_size = "5pt" p.axis.major_label_standoff = 0 colors = ["#75968f", "#a5bab7", "#c9d9d3", "#e2e2e2", "#dfccce", "#ddb7b1", "#cc7878", "#933b41", "#550b1d"] - from bokeh.models import LinearColorMapper, ColorBar, BasicTicker, PrintfTickFormatter + from bokeh.models import BasicTicker, ColorBar, LinearColorMapper, PrintfTickFormatter mapper = LinearColorMapper(palette=colors, low=df.defined.min(), high=df.defined.max()) - p.rect(x='column', y='row', width=1, height=1, - source=df, - fill_color={'field': 'defined', 'transform': mapper}, - line_color=None) + p.rect( + x='column', + y='row', + width=1, + height=1, + source=df, + fill_color={'field': 'defined', 'transform': mapper}, + line_color=None, + ) - color_bar = ColorBar(color_mapper=mapper, major_label_text_font_size="5pt", - ticker=BasicTicker(desired_num_ticks=len(colors)), - formatter=PrintfTickFormatter(format="%d"), - label_standoff=6, border_line_color=None, location=(0, 0)) + color_bar = ColorBar( + color_mapper=mapper, + major_label_text_font_size="5pt", + ticker=BasicTicker(desired_num_ticks=len(colors)), + formatter=PrintfTickFormatter(format="%d"), + label_standoff=6, + border_line_color=None, + location=(0, 0), + ) p.add_layout(color_bar, 'right') return p diff --git a/hail/python/hail/stats/__init__.py b/hail/python/hail/stats/__init__.py index 9995cbb25f5..6b7ccb725cd 100644 --- a/hail/python/hail/stats/__init__.py +++ b/hail/python/hail/stats/__init__.py @@ -1,4 +1,3 @@ - from .linear_mixed_model import LinearMixedModel __all__ = [ diff --git a/hail/python/hail/stats/linear_mixed_model.py b/hail/python/hail/stats/linear_mixed_model.py index 672d811ab29..6a0508c3966 100644 --- a/hail/python/hail/stats/linear_mixed_model.py +++ b/hail/python/hail/stats/linear_mixed_model.py @@ -6,5 +6,6 @@ class LinearMixedModel(object): This functionality is no longer implemented/supported as of Hail 0.2.94. 
""" + def __init__(self, py, px, s, y=None, x=None, p_path=None): raise NotImplementedError("LinearMixedModel is no longer implemented/supported as of Hail 0.2.94") diff --git a/hail/python/hail/table.py b/hail/python/hail/table.py index b1cccc3d09f..f0fd6e83ba1 100644 --- a/hail/python/hail/table.py +++ b/hail/python/hail/table.py @@ -1,33 +1,76 @@ import collections import itertools -import pandas -import numpy as np -import pyspark import pprint import shutil -from typing import Optional, Dict, Callable, Sequence, Union, List, overload - -from hail.expr.expressions import Expression, StructExpression, \ - BooleanExpression, expr_struct, expr_any, expr_bool, analyze, Indices, \ - construct_reference, to_expr, construct_expr, extract_refs_by_indices, \ - ExpressionException, TupleExpression, unify_all, NumericExpression, \ - StringExpression, CallExpression, CollectionExpression, DictExpression, \ - IntervalExpression, LocusExpression, NDArrayExpression, expr_stream, \ - expr_array -from hail.expr.types import hail_type, tstruct, types_match, tarray, tset, dtypes_from_pandas +from typing import Callable, ClassVar, Dict, List, Optional, Sequence, Union, overload + +import numpy as np +import pandas +import pyspark + +import hail as hl +from hail import ir +from hail.expr.expressions import ( + ArrayExpression, + BooleanExpression, + CallExpression, + CollectionExpression, + DictExpression, + Expression, + ExpressionException, + Indices, + IntervalExpression, + LocusExpression, + NDArrayExpression, + NumericExpression, + StringExpression, + StructExpression, + TupleExpression, + analyze, + construct_expr, + construct_reference, + expr_any, + expr_array, + expr_bool, + expr_stream, + expr_struct, + extract_refs_by_indices, + to_expr, + unify_all, +) from hail.expr.table_type import ttable -import hail.ir as ir -from hail.typecheck import typecheck, typecheck_method, dictof, anytype, \ - anyfunc, nullable, sequenceof, oneof, numeric, lazy, enumeration, \ - table_key_type, func_spec +from hail.expr.types import dtypes_from_pandas, hail_type, tarray, tset, tstruct, types_match +from hail.typecheck import ( + anyfunc, + anytype, + dictof, + enumeration, + func_spec, + lazy, + nullable, + numeric, + oneof, + sequenceof, + table_key_type, + typecheck, + typecheck_method, +) from hail.utils import deduplicate from hail.utils.interval import Interval -from hail.utils.placement_tree import PlacementTree from hail.utils.java import Env, info, warning -from hail.utils.misc import wrap_to_tuple, storage_level, plural, \ - get_nice_field_error, get_nice_attr_error, get_key_by_exprs, check_keys, \ - get_select_exprs, check_annotate_exprs, process_joins -import hail as hl +from hail.utils.misc import ( + check_annotate_exprs, + check_keys, + get_key_by_exprs, + get_nice_attr_error, + get_nice_field_error, + get_select_exprs, + plural, + process_joins, + storage_level, + wrap_to_tuple, +) +from hail.utils.placement_tree import PlacementTree table_type = lazy() @@ -76,9 +119,8 @@ def desc(col): class ExprContainer: - # this can only grow as big as the object dir, so no need to worry about memory leak - _warned_about = set() + _warned_about: ClassVar = set() def __init__(self): self._fields: Dict[str, Expression] = {} @@ -96,8 +138,10 @@ def _set_field(self, key, value): if key in self._dir or key in self.__dict__: if key not in ExprContainer._warned_about: ExprContainer._warned_about.add(key) - warning(f"Name collision: field {repr(key)} already in object dict. 
" - f"\n This field must be referenced with __getitem__ syntax: obj[{repr(key)}]") + warning( + f"Name collision: field {key!r} already in object dict. " + f"\n This field must be referenced with __getitem__ syntax: obj[{key!r}]" + ) else: self.__dict__[key] = value @@ -234,20 +278,21 @@ def aggregate(self, **named_exprs) -> 'Table': Aggregated table. """ for name, expr in named_exprs.items(): - analyze(f'GroupedTable.aggregate: ({repr(name)})', expr, self._parent._global_indices, {self._parent._row_axis}) + analyze(f'GroupedTable.aggregate: ({name!r})', expr, self._parent._global_indices, {self._parent._row_axis}) if not named_exprs.keys().isdisjoint(set(self._key_expr)): intersection = set(named_exprs.keys()) & set(self._key_expr) raise ValueError( - f'GroupedTable.aggregate: Group names and aggregration expression names overlap: {intersection}') + f'GroupedTable.aggregate: Group names and aggregration expression names overlap: {intersection}' + ) base, _ = self._parent._process_joins(self._key_expr, *named_exprs.values()) key_struct = self._key_expr - return Table(ir.TableKeyByAndAggregate(base._tir, - hl.struct(**named_exprs)._ir, - key_struct._ir, - self._npartitions, - self._buffer_size)) + return Table( + ir.TableKeyByAndAggregate( + base._tir, hl.struct(**named_exprs)._ir, key_struct._ir, self._npartitions, self._buffer_size + ) + ) class Table(ExprContainer): @@ -358,14 +403,11 @@ def __init__(self, tir): self._globals = construct_reference('global', self._global_type, indices=self._global_indices) self._row = construct_reference('row', self._row_type, indices=self._row_indices) - self._indices_from_ref = {'global': self._global_indices, - 'row': self._row_indices} + self._indices_from_ref = {'global': self._global_indices, 'row': self._row_indices} - self._key = hl.struct( - **{k: self._row[k] for k in self._type.row_key}) + self._key = hl.struct(**{k: self._row[k] for k in self._type.row_key}) - for k, v in itertools.chain(self._globals.items(), - self._row.items()): + for k, v in itertools.chain(self._globals.items(), self._row.items()): self._set_field(k, v) @property @@ -379,9 +421,11 @@ def __getitem__(self, item): try: return self.index(*wrap_to_tuple(item)) except TypeError as e: - raise TypeError("Table.__getitem__: invalid index argument(s)\n" - " Usage 1: field selection: ht['field']\n" - " Usage 2: Left distinct join: ht[ht2.key] or ht[ht2.field1, ht2.field2]") from e + raise TypeError( + "Table.__getitem__: invalid index argument(s)\n" + " Usage 1: field selection: ht['field']\n" + " Usage 2: Left distinct join: ht[ht2.key] or ht[ht2.field1, ht2.field2]" + ) from e @property def key(self) -> StructExpression: @@ -414,9 +458,34 @@ def _value(self) -> 'StructExpression': def n_partitions(self): """Returns the number of partitions in the table. + Examples + -------- + + Range tables can be constructed with an explicit number of partitions: + + >>> ht = hl.utils.range_table(100, n_partitions=10) + >>> ht.n_partitions() + 10 + + Small files are often imported with one partition: + + >>> ht2 = hl.import_table('data/coordinate_matrix.tsv', impute=True) + >>> ht2.n_partitions() + 1 + + The `min_partitions` argument to :func:`.import_table` forces more partitions, but it can + produce empty partitions. Empty partitions do not affect correctness but introduce + unnecessary extra bookkeeping that slows down the pipeline. 
+ + >>> ht2 = hl.import_table('data/coordinate_matrix.tsv', impute=True, min_partitions=10) + >>> ht2.n_partitions() + 10 + Returns ------- :obj:`int` + Number of partitions. + """ return Env.backend().execute(ir.TableToValueApply(self._tir, {'name': 'NPartitionsTable'})) @@ -426,12 +495,18 @@ def count(self): Examples -------- - >>> table1.count() + Count the number of rows in a table loaded from 'data/kt_example1.tsv'. Each line of the TSV + becomes one row in the Hail Table. + + >>> ht = hl.import_table('data/kt_example1.tsv', impute=True) + >>> ht.count() 4 Returns ------- :obj:`int` + The number of rows in the table. + """ return Env.backend().execute(ir.TableCount(self._tir)) @@ -444,8 +519,7 @@ def _force_count(self): async def _async_force_count(self): return await Env.backend()._async_execute(ir.TableToValueApply(self._tir, {'name': 'ForceCountTable'})) - @typecheck_method(caller=str, - row=expr_struct()) + @typecheck_method(caller=str, row=expr_struct()) def _select(self, caller, row) -> 'Table': analyze(caller, row, self._row_indices) base, cleanup = self._process_joins(row) @@ -458,35 +532,89 @@ def _select_globals(self, caller, s) -> 'Table': return cleanup(Table(ir.TableMapGlobals(base._tir, s._ir))) @classmethod - @typecheck_method(rows=anytype, - schema=nullable(hail_type), - key=table_key_type, - n_partitions=nullable(int), - partial_type=nullable(dict), - globals=nullable(expr_struct())) - def parallelize(cls, rows, schema=None, key=None, n_partitions=None, *, - partial_type=None, - globals=None - ) -> 'Table': + @typecheck_method( + rows=anytype, + schema=nullable(hail_type), + key=table_key_type, + n_partitions=nullable(int), + partial_type=nullable(dict), + globals=nullable(expr_struct()), + ) + def parallelize(cls, rows, schema=None, key=None, n_partitions=None, *, partial_type=None, globals=None) -> 'Table': """Parallelize a local array of structs into a distributed table. Examples -------- - Parallelize a list of dictionaries: - >>> a = [ {'a': 5, 'b': 10}, {'a': 0, 'b': 200} ] - >>> t = hl.Table.parallelize(hl.literal(a, 'array')) + Parallelize a list of dictionaries into a Hail Table. The fields of the dictionary become + the fields of the Table. The schema should always be a :class:`.tstruct` whose fields + correspond to the dictionaries' fields. + + >>> t = hl.Table.parallelize( + ... [{'a': 5, 'b': 10}, {'a': 0, 'b': 200}], + ... schema=hl.tstruct(a=hl.tint, b=hl.tint) + ... ) >>> t.show() + +-------+-------+ + | a | b | + +-------+-------+ + | int32 | int32 | + +-------+-------+ + | 5 | 10 | + | 0 | 200 | + +-------+-------+ + + The `key` parameter sets the key of the Table. Notice that the order of the rows changes, + because the rows of a Table are always appear in ascending order of the key. - Parallelize complex JSON with a `partial_type`: - >>> dicts = [{"number":10038,"state":"open","user":{"login":"tpoterba","site_admin":False,"id":10562794}, "milestone":None,"labels":[]},\ - {"number":10037,"state":"open","user":{"login":"daniel-goldstein","site_admin":False,"id":24440116},"milestone":None,"labels":[]},\ - {"number":10036,"state":"open","user":{"login":"jigold","site_admin":False,"id":1693348},"milestone":None,"labels":[]},\ - {"number":10035,"state":"open","user":{"login":"tpoterba","site_admin":False,"id":10562794},"milestone":None,"labels":[]},\ - {"number":10033,"state":"open","user":{"login":"tpoterba","site_admin":False,"id":10562794},"milestone":None,"labels":[]}] >>> t = hl.Table.parallelize( - ... dicts, - ... 
partial_type={"milestone":hl.tstr, "labels":hl.tarray(hl.tstr)}) + ... [{'a': 5, 'b': 10}, {'a': 0, 'b': 200}], + ... schema=hl.tstruct(a=hl.tint, b=hl.tint), + ... key='a' + ... ) + >>> t.show() + +-------+-------+ + | a | b | + +-------+-------+ + | int32 | int32 | + +-------+-------+ + | 0 | 200 | + | 5 | 10 | + +-------+-------+ + + You may also elide schema entirely and let Hail guess the type. The list elements must + either be Hail :class:`.Struct` or :class:`.dict` s. + + >>> t = hl.Table.parallelize( + ... [{'a': 5, 'b': 10}, {'a': 0, 'b': 200}], + ... key='a' + ... ) + >>> t.show() + +-------+-------+ + | a | b | + +-------+-------+ + | int32 | int32 | + +-------+-------+ + | 0 | 200 | + | 5 | 10 | + +-------+-------+ + + You may also specify only a handful of types in `partial_type`. Hail will automatically + deduce the types of the other fields. Hail _cannot_ deduce the type of a field which only + contains empty arrays (the element type is unspecified), so we specify the type of labels + explicitly. + + >>> dictionaries = [ + ... {"number":10038,"state":"open","user":{"login":"tpoterba","site_admin":False,"id":10562794}, "milestone":None,"labels":[]}, + ... {"number":10037,"state":"open","user":{"login":"daniel-goldstein","site_admin":False,"id":24440116},"milestone":None,"labels":[]}, + ... {"number":10036,"state":"open","user":{"login":"jigold","site_admin":False,"id":1693348},"milestone":None,"labels":[]}, + ... {"number":10035,"state":"open","user":{"login":"tpoterba","site_admin":False,"id":10562794},"milestone":None,"labels":[]}, + ... {"number":10033,"state":"open","user":{"login":"tpoterba","site_admin":False,"id":10562794},"milestone":None,"labels":[]}, + ... ] + >>> t = hl.Table.parallelize( + ... dictionaries, + ... partial_type={"milestone": hl.tstr, "labels": hl.tarray(hl.tstr)} + ... ) >>> t.show() +--------+--------+--------------------+-----------------+----------+ | number | state | user.login | user.site_admin | user.id | @@ -511,8 +639,25 @@ def parallelize(cls, rows, schema=None, key=None, n_partitions=None, *, | NA | [] | +-----------+------------+ + Parallelizing with a specified number of partitions: + + >>> rows = [ {'a': i} for i in range(100) ] + >>> ht = hl.Table.parallelize(rows, n_partitions=10) + >>> ht.n_partitions() + 10 + >>> ht.count() + 100 + + Parallelizing with some global information: + + >>> rows = [ {'a': i} for i in range(5) ] + >>> ht = hl.Table.parallelize(rows, globals=hl.Struct(global_value=3)) + >>> ht.aggregate(hl.agg.sum(ht.global_value * ht.a)) + 30 + Warning ------- + Parallelizing very large local arrays will be slow. Parameters @@ -533,6 +678,7 @@ def parallelize(cls, rows, schema=None, key=None, n_partitions=None, *, Returns ------- :class:`.Table` + A distributed Hail table created from the local collection of rows. """ if schema and partial_type: @@ -540,20 +686,23 @@ def parallelize(cls, rows, schema=None, key=None, n_partitions=None, *, dtype = schema if schema is not None: + if not isinstance(schema, hl.tstruct): + raise ValueError( + "parallelize expectes the 'schema' argument to be an hl.tstruct, see docs for details." 
+ ) dtype = hl.tarray(schema) - if partial_type is not None: + elif partial_type is not None: partial_type = hl.tarray(hl.tstruct(**partial_type)) + else: + partial_type = hl.tarray(hl.tstruct()) rows = to_expr(rows, dtype=dtype, partial_type=partial_type) if not isinstance(rows.dtype.element_type, tstruct): - raise TypeError("'parallelize' expects an array with element type 'struct', found '{}'" - .format(rows.dtype)) - table = Table(ir.TableParallelize( - ir.MakeStruct([ - ('rows', rows._ir), - ('global', (globals or hl.struct())._ir) - ]), - n_partitions - )) + raise TypeError("'parallelize' expects an array with element type 'struct', found '{}'".format(rows.dtype)) + table = Table( + ir.TableParallelize( + ir.MakeStruct([('rows', rows._ir), ('global', (globals or hl.struct())._ir)]), n_partitions + ) + ) if key is not None: table = table.key_by(*key) return table @@ -563,12 +712,12 @@ def parallelize(cls, rows, schema=None, key=None, n_partitions=None, *, contexts=expr_array(expr_any), partitions=oneof(sequenceof(Interval), int), rowfn=func_spec(2, expr_array(expr_struct())), - globals=nullable(expr_struct()) + globals=nullable(expr_struct()), ) def _generate( - contexts: 'hl.ArrayExpression', + contexts: 'ArrayExpression', partitions: 'Union[Sequence[Interval], int]', - rowfn: 'Callable[[hl.Expression, hl.StructExpression], hl.ArrayExpression]', + rowfn: 'Callable[[hl.Expression, hl.StructExpression], ArrayExpression]', globals: 'Optional[hl.StructExpression]' = None, ) -> 'Table': """ @@ -585,48 +734,146 @@ def _generate( body = ir.toStream(rowfn(cexpr, gexpr)._ir) if isinstance(partitions, int): - partitions = [ - Interval(hl.Struct(), hl.Struct(), True, True) - for _ in range(partitions) - ] + partitions = [Interval(hl.Struct(), hl.Struct(), True, True) for _ in range(partitions)] partitioner = ir.Partitioner(partitions[0].point_type, partitions) - return Table(ir.TableGen( - ir.toStream(contexts._ir), globals._ir, context_name, - globals_name, body, partitioner - )) + return Table(ir.TableGen(ir.toStream(contexts._ir), globals._ir, context_name, globals_name, body, partitioner)) - @typecheck_method(keys=oneof(str, expr_any), - named_keys=expr_any) + @typecheck_method(keys=oneof(str, expr_any), named_keys=expr_any) def key_by(self, *keys, **named_keys) -> 'Table': """Key table by a new set of fields. + Table keys control both the order of the rows in the table and the ability to join or + annotate one table with the information in another table. + Examples -------- - Assume `table1` is a :class:`.Table` with three fields: `C1`, `C2` - and `C3`. - Changing key fields: + Consider a simple unkeyed table. Its rows appear are guaranteed to appear in the same order + as they were in the source text file. - >>> table_result = table1.key_by('C2', 'C3') + >>> ht = hl.import_table('data/kt_example1.tsv', impute=True) + >>> ht.show() + +-------+-------+-----+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | + | 3 | 70 | "F" | 7 | 3 | 10 | 81 | -5 | + | 4 | 60 | "F" | 8 | 2 | 11 | 90 | -10 | + +-------+-------+-----+-------+-------+-------+-------+-------+ - This keys the table by 'C2' and 'C3', preserving old keys as value fields. 
+ Changing the key forces the rows to appear in ascending order. For this reason, + :meth:`.key_by` is a relatively expensive operation. It must sort the entire dataset. - >>> table_result = table1.key_by(table1.C1) + >>> ht = ht.key_by('HT') + >>> ht.show() + +-------+-------+-----+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | 4 | 60 | "F" | 8 | 2 | 11 | 90 | -10 | + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | + | 3 | 70 | "F" | 7 | 3 | 10 | 81 | -5 | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + + Suppose that `ht` represents some human subjects in an experiment. We might need to combine + sample metadata from `ht` with sample metadata from another source. For example: + + >>> ht2 = hl.import_table('data/kt_example2.tsv', impute=True) + >>> ht2 = ht2.key_by('ID') + >>> ht2.show() + +-------+-------+----------+ + | ID | A | B | + +-------+-------+----------+ + | int32 | int32 | str | + +-------+-------+----------+ + | 1 | 65 | "cat" | + | 2 | 72 | "dog" | + | 3 | 70 | "mouse" | + | 4 | 60 | "rabbit" | + +-------+-------+----------+ + >>> combined_ht = ht + >>> combined_ht = combined_ht.key_by('ID') + >>> combined_ht = combined_ht.annotate(favorite_pet = ht2[combined_ht.key].B) + >>> combined_ht.show() + +-------+-------+-----+-------+-------+-------+-------+-------+--------------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | favorite_pet | + +-------+-------+-----+-------+-------+-------+-------+-------+--------------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | str | + +-------+-------+-----+-------+-------+-------+-------+-------+--------------+ + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | "cat" | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | "dog" | + | 3 | 70 | "F" | 7 | 3 | 10 | 81 | -5 | "mouse" | + | 4 | 60 | "F" | 8 | 2 | 11 | 90 | -10 | "rabbit" | + +-------+-------+-----+-------+-------+-------+-------+-------+--------------+ + + Hail supports compound keys which enforce a dictionary ordering on the rows of the Table. + + >>> ht = ht.key_by('SEX', 'HT') + >>> ht.show() + +-------+-------+-----+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | 4 | 60 | "F" | 8 | 2 | 11 | 90 | -10 | + | 3 | 70 | "F" | 7 | 3 | 10 | 81 | -5 | + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | + +-------+-------+-----+-------+-------+-------+-------+-------+ - This keys the table by 'C1', preserving old keys as value fields. + A key may also be shortened by removing some fields. The ordering of two rows with the same + key is undefined. You should not rely on them appearing in any particular order. 
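When a deterministic row order is needed for display or export, sorting explicitly is safer than relying on the order of rows that share a key; a sketch assuming the example file used above.

import hail as hl

ht = hl.import_table('data/kt_example1.tsv', impute=True)

# Sort by SEX ascending, then HT descending; this imposes a total order
# without relying on key-based ordering.
ht_sorted = ht.order_by(ht.SEX, hl.desc(ht.HT))
ht_sorted.show()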
- >>> table_result = table1.key_by(C1 = table1.C2, foo = table1.C1) + >>> ht = ht.key_by('SEX') + >>> ht.show() + +-------+-------+-----+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | 3 | 70 | "F" | 7 | 3 | 10 | 81 | -5 | + | 4 | 60 | "F" | 8 | 2 | 11 | 90 | -10 | + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | + +-------+-------+-----+-------+-------+-------+-------+-------+ - This keys the table by fields named 'C1' and 'foo', which have values - corresponding to the original 'C2' and 'C1' fields respectively. The original - 'C1' field has been overwritten by the new assignment, but the original - 'C2' field is preserved as a value field. + Key fields may also be a complex expression: - Remove key: + >>> ht = ht.key_by(C4 = ht.X + ht.Z) + >>> ht.show() + +-------+-------+-----+-------+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | C4 | + +-------+-------+-----+-------+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | 9 | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | 9 | + | 3 | 70 | "F" | 7 | 3 | 10 | 81 | -5 | 10 | + | 4 | 60 | "F" | 8 | 2 | 11 | 90 | -10 | 10 | + +-------+-------+-----+-------+-------+-------+-------+-------+-------+ - >>> table_result = table1.key_by() + The key can be "removed" or set to the empty key. The ordering of the rows in a table + without a key is undefined. + + >>> ht = ht.key_by() + >>> ht.show() + +-------+-------+-----+-------+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | C4 | + +-------+-------+-----+-------+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | 9 | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | 9 | + | 3 | 70 | "F" | 7 | 3 | 10 | 81 | -5 | 10 | + | 4 | 60 | "F" | 8 | 2 | 11 | 90 | -10 | 10 | + +-------+-------+-----+-------+-------+-------+-------+-------+-------+ Notes ----- @@ -647,6 +894,7 @@ def key_by(self, *keys, **named_keys) -> 'Table': ------- :class:`.Table` Table with a new key. 
+ """ key_fields, computed_keys = get_key_by_exprs("Table.key_by", keys, named_keys, self._row_indices) @@ -656,15 +904,11 @@ def key_by(self, *keys, **named_keys) -> 'Table': new_row = self.row.annotate(**computed_keys) base, cleanup = self._process_joins(new_row) - return cleanup(Table( - ir.TableKeyBy( - ir.TableMapRows( - ir.TableKeyBy(base._tir, []), - new_row._ir), - list(key_fields)))) + return cleanup( + Table(ir.TableKeyBy(ir.TableMapRows(ir.TableKeyBy(base._tir, []), new_row._ir), list(key_fields))) + ) - @typecheck_method(keys=oneof(str, expr_any), - named_keys=expr_any) + @typecheck_method(keys=oneof(str, expr_any), named_keys=expr_any) def _key_by_assert_sorted(self, *keys, **named_keys) -> 'Table': key_fields, computed_keys = get_key_by_exprs("Table.key_by", keys, named_keys, self._row_indices) @@ -674,13 +918,13 @@ def _key_by_assert_sorted(self, *keys, **named_keys) -> 'Table': new_row = self.row.annotate(**computed_keys) base, cleanup = self._process_joins(new_row) - return cleanup(Table( - ir.TableKeyBy( - ir.TableMapRows( - ir.TableKeyBy(base._tir, []), - new_row._ir), - list(key_fields), - is_sorted=True))) + return cleanup( + Table( + ir.TableKeyBy( + ir.TableMapRows(ir.TableKeyBy(base._tir, []), new_row._ir), list(key_fields), is_sorted=True + ) + ) + ) @typecheck_method(named_exprs=expr_any) def annotate_globals(self, **named_exprs) -> 'Table': @@ -691,7 +935,23 @@ def annotate_globals(self, **named_exprs) -> 'Table': Add a new global field: - >>> table_result = table1.annotate_globals(pops = ['EUR', 'AFR', 'EAS', 'SAS']) + >>> ht = hl.utils.range_table(1) + >>> ht = ht.annotate_globals(pops = ['EUR', 'AFR', 'EAS', 'SAS']) + >>> ht.globals.show() + +---------------------------+ + | .pops | + +---------------------------+ + | array | + +---------------------------+ + | ["EUR","AFR","EAS","SAS"] | + +---------------------------+ + + Global fields may be used to store metadata about an experiment: + + >>> ht = ht.annotate_globals( + ... study_name='HGDP+1kG', + ... release_date='2023-01-01' + ... ) Note ---- @@ -700,7 +960,7 @@ def annotate_globals(self, **named_exprs) -> 'Table': Parameters ---------- named_exprs : varargs of :class:`.Expression` - Annotation expressions. + Expressions defining new global fields. Returns ------- @@ -716,10 +976,48 @@ def select_globals(self, *exprs, **named_exprs) -> 'Table': Examples -------- - Select one existing field and compute a new one: - >>> table_result = table1.select_globals(table1.global_field_1, - ... another_global=['AFR', 'EUR', 'EAS', 'AMR', 'SAS']) + Selecting two global fields, one by name and one new one, replacing any previously annotated + global fields. 
+ + >>> ht = hl.utils.range_table(1) + >>> ht = ht.annotate_globals(pops = ['EUR', 'AFR', 'EAS', 'SAS']) + >>> ht = ht.annotate_globals(study_name = 'HGDP+1kg') + >>> ht.describe() + ---------------------------------------- + Global fields: + 'pops': array + 'study_name': str + ---------------------------------------- + Row fields: + 'idx': int32 + ---------------------------------------- + Key: ['idx'] + ---------------------------------------- + >>> ht = ht.select_globals(ht.pops, target_date='2025-01-01') + >>> ht.describe() + ---------------------------------------- + Global fields: + 'pops': array + 'target_date': str + ---------------------------------------- + Row fields: + 'idx': int32 + ---------------------------------------- + Key: ['idx'] + ---------------------------------------- + + Fields may also be selected by their name: + + >>> ht = ht.select_globals('target_date') + >>> ht.globals.show() + +--------------------+ + | .target_date | + +--------------------+ + | str | + +--------------------+ + | "2025-01-01" | + +--------------------+ Notes ----- @@ -746,13 +1044,10 @@ def select_globals(self, *exprs, **named_exprs) -> 'Table': ------- :class:`.Table` Table with specified global fields. + """ caller = 'Table.select_globals' - new_globals = get_select_exprs(caller, - exprs, - named_exprs, - self._global_indices, - self._globals) + new_globals = get_select_exprs(caller, exprs, named_exprs, self._global_indices, self._globals) return self._select_globals(caller, new_globals) @@ -762,10 +1057,35 @@ def transmute_globals(self, **named_exprs) -> 'Table': Notes ----- - This method adds new global fields according to `named_exprs`, and - drops all global fields referenced in those expressions. See - :meth:`.Table.transmute` for full documentation on how transmute - methods work. 
+ Consider a table with global fields `population`, `area`, and `year`: + + >>> ht = hl.utils.range_table(1) + >>> ht = ht.annotate_globals(population=1000000, area=500, year=2020) + + Compute a new field, `density` from `population` and `area` and also drop the latter two + fields: + + >>> ht = ht.transmute_globals(density=ht.population / ht.area) + >>> ht.globals.show() + +-------------+----------------+ + | .year | .density | + +-------------+----------------+ + | int32 | float64 | + +-------------+----------------+ + | 2020 | 2.00e+03 | + +-------------+----------------+ + + Introduce a new global field `next_year` based on `year`: + + >>> ht = ht.transmute_globals(next_year=ht.year + 1) + >>> ht.globals.show() + +----------------+------------------+ + | .density | .next_year | + +----------------+------------------+ + | float64 | int32 | + +----------------+------------------+ + | 2.00e+03 | 2021 | + +----------------+------------------+ See Also -------- @@ -780,13 +1100,15 @@ def transmute_globals(self, **named_exprs) -> 'Table': Returns ------- :class:`.Table` + """ caller = 'Table.transmute_globals' check_annotate_exprs(caller, named_exprs, self._global_indices, set()) - fields_referenced = extract_refs_by_indices(named_exprs.values(), self._global_indices) - set(named_exprs.keys()) + fields_referenced = extract_refs_by_indices(named_exprs.values(), self._global_indices) - set( + named_exprs.keys() + ) - return self._select_globals(caller, - self.globals.annotate(**named_exprs).drop(*fields_referenced)) + return self._select_globals(caller, self.globals.annotate(**named_exprs).drop(*fields_referenced)) @typecheck_method(named_exprs=expr_any) def transmute(self, **named_exprs) -> 'Table': @@ -795,26 +1117,62 @@ def transmute(self, **named_exprs) -> 'Table': Examples -------- - Create a single field from an expression of `C1`, `C2`, and `C3`. 
+ Consider this table: - >>> table4.show() - +-------+------+---------+-------+-------+-------+-------+-------+ - | A | B.B0 | B.B1 | C | D.cat | D.dog | E.A | E.B | - +-------+------+---------+-------+-------+-------+-------+-------+ - | int32 | bool | str | bool | int32 | int32 | int32 | int32 | - +-------+------+---------+-------+-------+-------+-------+-------+ - | 32 | True | "hello" | False | 5 | 7 | 5 | 7 | - +-------+------+---------+-------+-------+-------+-------+-------+ + >>> ht = table1 + >>> ht.show() + +-------+-------+-----+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | + | 3 | 70 | "F" | 7 | 3 | 10 | 81 | -5 | + | 4 | 60 | "F" | 8 | 2 | 11 | 90 | -10 | + +-------+-------+-----+-------+-------+-------+-------+-------+ - >>> table_result = table4.transmute(F=table4.A + 2 * table4.E.B) - >>> table_result.show() - +------+---------+-------+-------+-------+-------+ - | B.B0 | B.B1 | C | D.cat | D.dog | F | - +------+---------+-------+-------+-------+-------+ - | bool | str | bool | int32 | int32 | int32 | - +------+---------+-------+-------+-------+-------+ - | True | "hello" | False | 5 | 7 | 46 | - +------+---------+-------+-------+-------+-------+ + Transmuting a field without referencing other fields has the same effect as annotating: + + >>> ht = ht.transmute(new_field=hl.struct(x=3, y=4)) + >>> ht.show() + +-------+-------+-----+-------+-------+-------+-------+-------+-------------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | new_field.x | + +-------+-------+-----+-------+-------+-------+-------+-------+-------------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+-------------+ + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | 3 | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | 3 | + | 3 | 70 | "F" | 7 | 3 | 10 | 81 | -5 | 3 | + | 4 | 60 | "F" | 8 | 2 | 11 | 90 | -10 | 3 | + +-------+-------+-----+-------+-------+-------+-------+-------+-------------+ + +-------------+ + | new_field.y | + +-------------+ + | int32 | + +-------------+ + | 4 | + | 4 | + | 4 | + | 4 | + +-------------+ + + Transmuting a field while referencing other fields drops those other fields. Notice how the + compound field, `new_field` is dropped entirely even though we only used one of its + component fields. + + >>> ht = ht.transmute(F=ht.X + 2 * ht.new_field.x) + >>> ht.show() + +-------+-------+-----+-------+-------+-------+-------+-------+ + | ID | HT | SEX | Z | C1 | C2 | C3 | F | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | 1 | 65 | "M" | 4 | 2 | 50 | 5 | 11 | + | 2 | 72 | "M" | 3 | 2 | 61 | 1 | 12 | + | 3 | 70 | "F" | 3 | 10 | 81 | -5 | 13 | + | 4 | 60 | "F" | 2 | 11 | 90 | -10 | 14 | + +-------+-------+-----+-------+-------+-------+-------+-------+ Notes ----- @@ -830,9 +1188,9 @@ def transmute(self, **named_exprs) -> 'Table': Warning ------- - References to fields inside a top-level struct will remove the entire - struct, as field `E` was removed in the example above since `E.B` was - referenced. 
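A compact sketch of the struct-dropping behaviour warned about here; the `stats` field and its values are illustrative.

import hail as hl

ht = hl.utils.range_table(3)
ht = ht.annotate(stats=hl.struct(mean=2.0, n=10))

# Referencing ht.stats.mean and ht.stats.n drops the entire top-level
# `stats` struct; the remaining row fields are just idx and total.
ht = ht.transmute(total=ht.stats.mean * ht.stats.n)
ht.describe()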
+ + References to fields inside a top-level struct will remove the entire struct, as field + `new_field` was removed in the example above since `new_field.x` was referenced. Note ---- @@ -847,6 +1205,7 @@ def transmute(self, **named_exprs) -> 'Table': ------- :class:`.Table` Table with transmuted fields. + """ caller = "Table.transmute" check_annotate_exprs(caller, named_exprs, self._row_indices, set()) @@ -859,17 +1218,138 @@ def transmute(self, **named_exprs) -> 'Table': def annotate(self, **named_exprs) -> 'Table': """Add new fields. + New Table fields may be defined in several ways: + + 1. In terms of constant values. Every row will have the same value. + 2. In terms of other fields in the table. + 3. In terms of fields in other tables, this is called "joining". + Examples -------- - Add field `Y` by computing the square of `X`: - - >>> table_result = table1.annotate(Y = table1.X ** 2) + Consider this table: + + >>> ht = ht.drop('C1', 'C2', 'C3') + >>> ht.show() + +-------+-------+-----+-------+-------+ + | ID | HT | SEX | X | Z | + +-------+-------+-----+-------+-------+ + | int32 | int32 | str | int32 | int32 | + +-------+-------+-----+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | + | 2 | 72 | "M" | 6 | 3 | + | 3 | 70 | "F" | 7 | 3 | + | 4 | 60 | "F" | 8 | 2 | + +-------+-------+-----+-------+-------+ + + Add field Y containing the square of field X + + >>> ht = ht.annotate(Y = ht.X ** 2) + >>> ht.show() + +-------+-------+-----+-------+-------+----------+ + | ID | HT | SEX | X | Z | Y | + +-------+-------+-----+-------+-------+----------+ + | int32 | int32 | str | int32 | int32 | float64 | + +-------+-------+-----+-------+-------+----------+ + | 1 | 65 | "M" | 5 | 4 | 2.50e+01 | + | 2 | 72 | "M" | 6 | 3 | 3.60e+01 | + | 3 | 70 | "F" | 7 | 3 | 4.90e+01 | + | 4 | 60 | "F" | 8 | 2 | 6.40e+01 | + +-------+-------+-----+-------+-------+----------+ Add multiple fields simultaneously: - >>> table_result = table1.annotate(A = table1.X / 2, - ... B = table1.X + 21) + >>> ht = ht.annotate( + ... A = ht.X / 2, + ... B = ht.X + 21 + ... ) + >>> ht.show() + +-------+-------+-----+-------+-------+----------+----------+-------+ + | ID | HT | SEX | X | Z | Y | A | B | + +-------+-------+-----+-------+-------+----------+----------+-------+ + | int32 | int32 | str | int32 | int32 | float64 | float64 | int32 | + +-------+-------+-----+-------+-------+----------+----------+-------+ + | 1 | 65 | "M" | 5 | 4 | 2.50e+01 | 2.50e+00 | 26 | + | 2 | 72 | "M" | 6 | 3 | 3.60e+01 | 3.00e+00 | 27 | + | 3 | 70 | "F" | 7 | 3 | 4.90e+01 | 3.50e+00 | 28 | + | 4 | 60 | "F" | 8 | 2 | 6.40e+01 | 4.00e+00 | 29 | + +-------+-------+-----+-------+-------+----------+----------+-------+ + + Add a new field computed from extant fields and a small dictionary: + + >>> py_height_description = {65: 'sixty-five', 72: 'seventy-two', 70: 'seventy', 60: 'sixty'} + >>> hail_height_description = hl.literal(py_height_description) + >>> ht = ht.annotate(HT_DESCRIPTION=hail_height_description[ht.HT]) + >>> ht.select('HT', 'HT_DESCRIPTION').show() + +-------+-------+----------------+ + | ID | HT | HT_DESCRIPTION | + +-------+-------+----------------+ + | int32 | int32 | str | + +-------+-------+----------------+ + | 1 | 65 | "sixty-five" | + | 2 | 72 | "seventy-two" | + | 3 | 70 | "seventy" | + | 4 | 60 | "sixty" | + +-------+-------+----------------+ + + Add fields from another table onto this table: + + >>> ht2 = table2 + >>> ht2 = ht2.key_by('ID') + >>> ht2.show() + + >>> ht = ht.key_by('ID') + >>> ht = ht.annotate( + ... 
A=ht2[ht.key].A, + ... B=ht2[ht.key].B, + ... ) + >>> ht.show() + +-------+-------+-----+-------+-------+----------+-------+----------+ + | ID | HT | SEX | X | Z | Y | A | B | + +-------+-------+-----+-------+-------+----------+-------+----------+ + | int32 | int32 | str | int32 | int32 | float64 | int32 | str | + +-------+-------+-----+-------+-------+----------+-------+----------+ + | 1 | 65 | "M" | 5 | 4 | 2.50e+01 | 65 | "cat" | + | 2 | 72 | "M" | 6 | 3 | 3.60e+01 | 72 | "dog" | + | 3 | 70 | "F" | 7 | 3 | 4.90e+01 | 70 | "mouse" | + | 4 | 60 | "F" | 8 | 2 | 6.40e+01 | 60 | "rabbit" | + +-------+-------+-----+-------+-------+----------+-------+----------+ + +----------------+ + | HT_DESCRIPTION | + +----------------+ + | str | + +----------------+ + | "sixty-five" | + | "seventy-two" | + | "seventy" | + | "sixty" | + +----------------+ + + Instead of repeating all the fields from the other table, we may use Python's splat operator + to indicate we want to copy all the non-key fields from the other table: + + >>> ht = ht.annotate(**ht2[ht.key]) + >>> ht.show() + +-------+-------+-----+-------+-------+----------+-------+----------+ + | ID | HT | SEX | X | Z | Y | A | B | + +-------+-------+-----+-------+-------+----------+-------+----------+ + | int32 | int32 | str | int32 | int32 | float64 | int32 | str | + +-------+-------+-----+-------+-------+----------+-------+----------+ + | 1 | 65 | "M" | 5 | 4 | 2.50e+01 | 65 | "cat" | + | 2 | 72 | "M" | 6 | 3 | 3.60e+01 | 72 | "dog" | + | 3 | 70 | "F" | 7 | 3 | 4.90e+01 | 70 | "mouse" | + | 4 | 60 | "F" | 8 | 2 | 6.40e+01 | 60 | "rabbit" | + +-------+-------+-----+-------+-------+----------+-------+----------+ + +----------------+ + | HT_DESCRIPTION | + +----------------+ + | str | + +----------------+ + | "sixty-five" | + | "seventy-two" | + | "seventy" | + | "sixty" | + +----------------+ Parameters ---------- @@ -880,26 +1360,170 @@ def annotate(self, **named_exprs) -> 'Table': ------- :class:`.Table` Table with new fields. + """ caller = "Table.annotate" check_annotate_exprs(caller, named_exprs, self._row_indices, set()) return self._select(caller, self.row.annotate(**named_exprs)) - @typecheck_method(expr=expr_bool, - keep=bool) + @typecheck_method(expr=expr_bool, keep=bool) def filter(self, expr, keep: bool = True) -> 'Table': - """Filter rows. + """Filter rows conditional on the value of each row's fields. - Examples - -------- + Note + ---- + + Hail will can read much less data if a Table filter condition references the key field and + the Table is stored in Hail native format (i.e. read using :func:`.read_table`, _not_ + :func:`.import_table`). In other words: filtering on the key will make a pipeline faster by + reading fewer rows. This optimization is prevented by certain operations appearing between a + :func:`.read_table` and a :meth:`.filter`. For example, a `key_by` and `group_by`, both + force reading all the data. - Keep rows where ``C1`` equals 5: + Suppose we previously :meth:`.write` a Hail Table with one million rows keyed by a field + called `idx`. 
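Stepping back briefly to the join-style `annotate` shown above: the `ht2[ht.key]` indexing pattern can be tried in isolation with two tiny keyed tables. This is only a sketch with made-up field names, not part of the documented example data:

```python
import hail as hl

left = hl.Table.parallelize(
    [{'ID': 1, 'HT': 65}, {'ID': 2, 'HT': 72}],
    hl.tstruct(ID=hl.tint32, HT=hl.tint32),
    key='ID',
)
right = hl.Table.parallelize(
    [{'ID': 1, 'B': 'cat'}, {'ID': 2, 'B': 'dog'}],
    hl.tstruct(ID=hl.tint32, B=hl.tstr),
    key='ID',
)

# Index the right table by the left table's key, then splat its non-key fields in.
joined = left.annotate(**right[left.key])
joined.show()
```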
If we filter this table to one value of `idx`, the pipeline will be fast + because we read only the rows that have that value of `idx`: - >>> table_result = table1.filter(table1.C1 == 5) + >>> ht = hl.read_table('large-table.ht') # doctest: +SKIP + >>> ht = ht.filter(ht.idx == 5) # doctest: +SKIP - Remove rows where ``C1`` equals 10: + This also works with inequality conditions: - >>> table_result = table1.filter(table1.C1 == 10, keep=False) + >>> ht = hl.read_table('large-table.ht') # doctest: +SKIP + >>> ht = ht.filter(ht.idx <= 5) # doctest: +SKIP + + Examples + -------- + + Consider this table: + + >>> ht = ht.drop('C1', 'C2', 'C3') + >>> ht.show() + +-------+-------+-----+-------+-------+ + | ID | HT | SEX | X | Z | + +-------+-------+-----+-------+-------+ + | int32 | int32 | str | int32 | int32 | + +-------+-------+-----+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | + | 2 | 72 | "M" | 6 | 3 | + | 3 | 70 | "F" | 7 | 3 | + | 4 | 60 | "F" | 8 | 2 | + +-------+-------+-----+-------+-------+ + + Keep rows where ``Z`` is 3: + + >>> filtered_ht = ht.filter(ht.Z == 3) + >>> filtered_ht.show() + + +-------+-------+-----+-------+-------+ + | ID | HT | SEX | X | Z | + +-------+-------+-----+-------+-------+ + | int32 | int32 | str | int32 | int32 | + +-------+-------+-----+-------+-------+ + | 2 | 72 | "M" | 6 | 3 | + | 3 | 70 | "F" | 7 | 3 | + +-------+-------+-----+-------+-------+ + + Remove rows where ``Z`` is 3: + + >>> filtered_ht = ht.filter(ht.Z == 3, keep=False) + >>> filtered_ht.show() + +-------+-------+-----+-------+-------+ + | ID | HT | SEX | X | Z | + +-------+-------+-----+-------+-------+ + | int32 | int32 | str | int32 | int32 | + +-------+-------+-----+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | + | 4 | 60 | "F" | 8 | 2 | + +-------+-------+-----+-------+-------+ + + Keep rows where X is less than 7 and Z is greater than 2: + + >>> filtered_ht = ht.filter(hl.all( + ... ht.X < 7, + ... ht.Z > 2 + ... )) + >>> filtered_ht.show() + +-------+-------+-----+-------+-------+ + | ID | HT | SEX | X | Z | + +-------+-------+-----+-------+-------+ + | int32 | int32 | str | int32 | int32 | + +-------+-------+-----+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | + | 2 | 72 | "M" | 6 | 3 | + +-------+-------+-----+-------+-------+ + + Keep rows where X is less than 7 or Z is greater than 2: + + >>> filtered_ht = ht.filter(hl.any( + ... ht.X < 7, + ... ht.Z > 2 + ... )) + >>> filtered_ht.show() + +-------+-------+-----+-------+-------+ + | ID | HT | SEX | X | Z | + +-------+-------+-----+-------+-------+ + | int32 | int32 | str | int32 | int32 | + +-------+-------+-----+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | + | 2 | 72 | "M" | 6 | 3 | + | 3 | 70 | "F" | 7 | 3 | + +-------+-------+-----+-------+-------+ + + Keep "M" rows where ``HT`` is less than 72 and "F" rows where ``HT`` is less than 65: + + >>> filtered_ht = ht.filter( + ... hl.if_else( + ... ht.SEX == "M", + ... ht.HT < 72, + ... ht.HT < 65 + ... ) + ... 
) + >>> filtered_ht.show() + +-------+-------+-----+-------+-------+ + | ID | HT | SEX | X | Z | + +-------+-------+-----+-------+-------+ + | int32 | int32 | str | int32 | int32 | + +-------+-------+-----+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | + | 4 | 60 | "F" | 8 | 2 | + +-------+-------+-----+-------+-------+ + + Notice that if the condition evaluates to missing, the row is _always_ removed regardless of + the setting of `keep`: + + >>> ht2 = ht + >>> ht2 = ht.annotate(X = hl.or_missing(ht.X != 5, ht.X)) + >>> ht2.show() + +-------+-------+-----+-------+-------+ + | ID | HT | SEX | X | Z | + +-------+-------+-----+-------+-------+ + | int32 | int32 | str | int32 | int32 | + +-------+-------+-----+-------+-------+ + | 1 | 65 | "M" | NA | 4 | + | 2 | 72 | "M" | 6 | 3 | + | 3 | 70 | "F" | 7 | 3 | + | 4 | 60 | "F" | 8 | 2 | + +-------+-------+-----+-------+-------+ + >>> filtered_ht = ht2.filter(ht2.X < 7, keep=True) + >>> filtered_ht.show() + +-------+-------+-----+-------+-------+ + | ID | HT | SEX | X | Z | + +-------+-------+-----+-------+-------+ + | int32 | int32 | str | int32 | int32 | + +-------+-------+-----+-------+-------+ + | 2 | 72 | "M" | 6 | 3 | + +-------+-------+-----+-------+-------+ + >>> filtered_ht = ht2.filter(ht2.X < 7, keep=False) + >>> filtered_ht.show() + +-------+-------+-----+-------+-------+ + | ID | HT | SEX | X | Z | + +-------+-------+-----+-------+-------+ + | int32 | int32 | str | int32 | int32 | + +-------+-------+-----+-------+-------+ + | 3 | 70 | "F" | 7 | 3 | + | 4 | 60 | "F" | 8 | 2 | + +-------+-------+-----+-------+-------+ Notes ----- @@ -929,14 +1553,14 @@ def filter(self, expr, keep: bool = True) -> 'Table': ------- :class:`.Table` Filtered table. + """ analyze('Table.filter', expr, self._row_indices) base, cleanup = self._process_joins(expr) return cleanup(Table(ir.TableFilter(base._tir, ir.filter_predicate_with_keep(expr._ir, keep)))) - @typecheck_method(exprs=oneof(Expression, str), - named_exprs=anytype) + @typecheck_method(exprs=oneof(Expression, str), named_exprs=anytype) def select(self, *exprs, **named_exprs) -> 'Table': """Select existing fields or create new fields by name, dropping the rest. @@ -1021,11 +1645,7 @@ def select(self, *exprs, **named_exprs) -> 'Table': :class:`.Table` Table with specified fields. """ - row = get_select_exprs('Table.select', - exprs, - named_exprs, - self._row_indices, - self._row) + row = get_select_exprs('Table.select', exprs, named_exprs, self._row_indices, self._row) return self._select('Table.select', row) @@ -1073,8 +1693,9 @@ def drop(self, *exprs) -> 'Table': if e in all_field_exprs: fields_to_drop.add(all_field_exprs[e]) else: - raise ExpressionException("method 'drop' expects string field names or top-level field expressions" - " (e.g. table['foo'])") + raise ExpressionException( + "method 'drop' expects string field names or top-level field expressions" " (e.g. 
table['foo'])" + ) else: assert isinstance(e, str) if e not in self._fields: @@ -1084,8 +1705,9 @@ def drop(self, *exprs) -> 'Table': table = self if any(self._fields[field]._indices == self._global_indices for field in fields_to_drop): # need to drop globals - table = table._select_globals('drop', - self._globals.drop(*[f for f in table.globals if f in fields_to_drop])) + table = table._select_globals( + 'drop', self._globals.drop(*[f for f in table.globals if f in fields_to_drop]) + ) if any(self._fields[field]._indices == self._row_indices for field in fields_to_drop): # need to drop row fields @@ -1098,11 +1720,9 @@ def drop(self, *exprs) -> 'Table': return table - @typecheck_method(output=str, - types_file=nullable(str), - header=bool, - parallel=nullable(ir.ExportType.checker), - delimiter=str) + @typecheck_method( + output=str, types_file=nullable(str), header=bool, parallel=nullable(ir.ExportType.checker), delimiter=str + ) def export(self, output, types_file=None, header=True, parallel=None, delimiter='\t'): """Export to a text file. @@ -1151,7 +1771,8 @@ def export(self, output, types_file=None, header=True, parallel=None, delimiter= parallel = ir.ExportType.default(parallel) Env.backend().execute( - ir.TableWrite(self._tir, ir.TableTextWriter(output, types_file, header, parallel, delimiter))) + ir.TableWrite(self._tir, ir.TableTextWriter(output, types_file, header, parallel, delimiter)) + ) def group_by(self, *exprs, **named_exprs) -> 'GroupedTable': """Group by a new key for use with :meth:`.GroupedTable.aggregate`. @@ -1243,11 +1864,9 @@ def group_by(self, *exprs, **named_exprs) -> 'GroupedTable': :class:`.GroupedTable` Grouped table; use :meth:`.GroupedTable.aggregate` to complete the aggregation. """ - key, computed_key = get_key_by_exprs('Table.group_by', - exprs, - named_exprs, - self._row_indices, - override_protected_indices={self._global_indices}) + key, computed_key = get_key_by_exprs( + 'Table.group_by', exprs, named_exprs, self._row_indices, override_protected_indices={self._global_indices} + ) return GroupedTable(self, self.row.annotate(**computed_key).select(*key)) @typecheck_method(expr=expr_any, _localize=bool) @@ -1287,16 +1906,25 @@ def aggregate(self, expr, _localize=True): return construct_expr(ir.LiftMeOut(agg_ir), expr.dtype) - @typecheck_method(output=str, - overwrite=bool, - stage_locally=bool, - _codec_spec=nullable(str), - _read_if_exists=bool, - _intervals=nullable(sequenceof(anytype)), - _filter_intervals=bool) - def checkpoint(self, output: str, overwrite: bool = False, stage_locally: bool = False, - _codec_spec: Optional[str] = None, _read_if_exists: bool = False, - _intervals=None, _filter_intervals=False) -> 'Table': + @typecheck_method( + output=str, + overwrite=bool, + stage_locally=bool, + _codec_spec=nullable(str), + _read_if_exists=bool, + _intervals=nullable(sequenceof(anytype)), + _filter_intervals=bool, + ) + def checkpoint( + self, + output: str, + overwrite: bool = False, + stage_locally: bool = False, + _codec_spec: Optional[str] = None, + _read_if_exists: bool = False, + _intervals=None, + _filter_intervals=False, + ) -> 'Table': """Checkpoint the table to disk by writing and reading. 
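As the summary line says, `checkpoint` is a write immediately followed by a read of the same path, so later stages reuse the materialized rows instead of recomputing the pipeline. A rough sketch (the path is illustrative):

```python
import hail as hl

ht = hl.utils.range_table(100)
ht = ht.checkpoint('/tmp/example.ht', overwrite=True)

# Roughly equivalent to:
#   ht.write('/tmp/example.ht', overwrite=True)
#   ht = hl.read_table('/tmp/example.ht')
```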
Parameters @@ -1323,7 +1951,7 @@ def checkpoint(self, output: str, overwrite: bool = False, stage_locally: bool = Examples -------- - >>> table1 = table1.checkpoint('output/table_checkpoint.ht') + >>> table1 = table1.checkpoint('output/table_checkpoint.ht', overwrite=True) """ hl.current_backend().validate_file(output) @@ -1340,21 +1968,17 @@ def checkpoint(self, output: str, overwrite: bool = False, stage_locally: bool = _intervals=_intervals, _filter_intervals=_filter_intervals, _assert_type=_assert_type, - _load_refs=_load_refs + _load_refs=_load_refs, ) - @typecheck_method(output=str, - overwrite=bool, - stage_locally=bool, - _codec_spec=nullable(str)) - def write(self, output: str, overwrite=False, stage_locally: bool = False, - _codec_spec: Optional[str] = None): + @typecheck_method(output=str, overwrite=bool, stage_locally=bool, _codec_spec=nullable(str)) + def write(self, output: str, overwrite=False, stage_locally: bool = False, _codec_spec: Optional[str] = None): """Write to disk. Examples -------- - >>> table1.write('output/table1.ht') + >>> table1.write('output/table1.ht', overwrite=True) .. include:: _templates/write_warning.rst @@ -1375,20 +1999,20 @@ def write(self, output: str, overwrite=False, stage_locally: bool = False, hl.current_backend().validate_file(output) - Env.backend().execute(ir.TableWrite(self._tir, ir.TableNativeWriter(output, overwrite, stage_locally, _codec_spec))) - - @typecheck_method(output=str, - fields=sequenceof(str), - overwrite=bool, - stage_locally=bool, - _codec_spec=nullable(str)) - def write_many(self, - output: str, - fields: Sequence[str], - *, - overwrite: bool = False, - stage_locally: bool = False, - _codec_spec: Optional[str] = None): + Env.backend().execute( + ir.TableWrite(self._tir, ir.TableNativeWriter(output, overwrite, stage_locally, _codec_spec)) + ) + + @typecheck_method(output=str, fields=sequenceof(str), overwrite=bool, stage_locally=bool, _codec_spec=nullable(str)) + def write_many( + self, + output: str, + fields: Sequence[str], + *, + overwrite: bool = False, + stage_locally: bool = False, + _codec_spec: Optional[str] = None, + ): """Write fields to distinct tables. 
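Before the doctest below, it may help to think of `write_many` as roughly one native write per listed field, with each output table keeping the key. A hedged sketch of that equivalence (paths are illustrative; the `select`/`write` expansion is only an approximation, not the actual implementation):

```python
import hail as hl

t = hl.utils.range_table(10)
t = t.annotate(a=t.idx, b=t.idx * t.idx)

t.write_many('/tmp/output-many', fields=('a', 'b'), overwrite=True)

# Approximately the same on-disk result as:
#   t.select('a').write('/tmp/output-many/a', overwrite=True)
#   t.select('b').write('/tmp/output-many/b', overwrite=True)
```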
Examples @@ -1396,8 +2020,8 @@ def write_many(self, >>> t = hl.utils.range_table(10) >>> t = t.annotate(a = t.idx, b = t.idx * t.idx, c = hl.str(t.idx)) - >>> t.write_many('output', fields=('a', 'b', 'c')) - >>> hl.read_table('output/a').describe() + >>> t.write_many('output-many', fields=('a', 'b', 'c'), overwrite=True) + >>> hl.read_table('output-many/a').describe() ---------------------------------------- Global fields: None @@ -1408,7 +2032,7 @@ def write_many(self, ---------------------------------------- Key: ['idx'] ---------------------------------------- - >>> hl.read_table('output/a').show() + >>> hl.read_table('output-many/a').show() +-------+-------+ | a | idx | +-------+-------+ @@ -1425,7 +2049,7 @@ def write_many(self, | 8 | 8 | | 9 | 9 | +-------+-------+ - >>> hl.read_table('output/b').describe() + >>> hl.read_table('output-many/b').describe() ---------------------------------------- Global fields: None @@ -1436,7 +2060,7 @@ def write_many(self, ---------------------------------------- Key: ['idx'] ---------------------------------------- - >>> hl.read_table('output/b').show() + >>> hl.read_table('output-many/b').show() +-------+-------+ | b | idx | +-------+-------+ @@ -1453,7 +2077,7 @@ def write_many(self, | 64 | 8 | | 81 | 9 | +-------+-------+ - >>> hl.read_table('output/c').describe() + >>> hl.read_table('output-many/c').describe() ---------------------------------------- Global fields: None @@ -1464,7 +2088,7 @@ def write_many(self, ---------------------------------------- Key: ['idx'] ---------------------------------------- - >>> hl.read_table('output/c').show() + >>> hl.read_table('output-many/c').show() +-----+-------+ | c | idx | +-----+-------+ @@ -1504,10 +2128,7 @@ def write_many(self, hl.current_backend().validate_file(output) Env.backend().execute( - ir.TableWrite( - self._tir, - ir.TableNativeFanoutWriter(output, fields, overwrite, stage_locally, _codec_spec) - ) + ir.TableWrite(self._tir, ir.TableNativeFanoutWriter(output, fields, overwrite, stage_locally, _codec_spec)) ) def _show(self, n, width, truncate, types): @@ -1553,7 +2174,7 @@ def _ascii_str(self): def trunc(s): if len(s) > truncate: - return s[:truncate - 3] + "..." + return s[: truncate - 3] + "..." 
return s rows, has_more, dtype = self.data() @@ -1604,7 +2225,7 @@ def format_line(values, widths, right_align): s = '' first = True - for (start, end) in column_blocks: + for start, end in column_blocks: if first: first = False else: @@ -1621,8 +2242,8 @@ def format_line(values, widths, right_align): s += format_line(type_strs[start:end], block_column_width, block_right_align) s += hline for row in rows: - row = row[start:end] - s += format_line(row, block_column_width, block_right_align) + _row = row[start:end] + s += format_line(_row, block_column_width, block_right_align) s += hline if has_more: @@ -1633,19 +2254,19 @@ def format_line(values, widths, right_align): def _html_str(self): import html + types = self.types rows, has_more, dtype = self.data() fields = list(dtype) - default_td_style = ('white-space: nowrap; ' - 'max-width: 500px; ' - 'overflow: hidden; ' - 'text-overflow: ellipsis; ') + default_td_style = ( + 'white-space: nowrap; ' 'max-width: 500px; ' 'overflow: hidden; ' 'text-overflow: ellipsis; ' + ) def format_line(values, extra_style=''): style = default_td_style + extra_style - return (f'' + f''.join(values) + '\n') + return f'' + f''.join(values) + '\n' arranged_field_names = PlacementTree.from_named_type('row', self.table.row.dtype) @@ -1667,8 +2288,7 @@ def format_line(values, extra_style=''): s += '' s += '' if types: - s += format_line([html.escape(str(dtype[f])) for f in fields], - extra_style="text-align: left;") + s += format_line([html.escape(str(dtype[f])) for f in fields], extra_style="text-align: left;") s += '' for row in rows: s += format_line([html.escape(row[f]) for f in row]) @@ -1696,7 +2316,14 @@ def _take_n(self, n): def _hl_format(v, truncate): return hl._showstr(v, truncate) - @typecheck_method(n=nullable(int), width=nullable(int), truncate=nullable(int), types=bool, handler=nullable(anyfunc), n_rows=nullable(int)) + @typecheck_method( + n=nullable(int), + width=nullable(int), + truncate=nullable(int), + types=bool, + handler=nullable(anyfunc), + n_rows=nullable(int), + ) def show(self, n=None, width=None, truncate=None, types=True, handler=None, n_rows=None): """Print the first few rows of the table to the console. 
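A quick sketch of the most common `show` arguments (the values here are arbitrary):

```python
import hail as hl

ht = hl.utils.range_table(1000)
ht.show(5)                                               # first five rows, with the type row
ht.show(n_rows=5, width=120, truncate=20, types=False)   # wider layout, long values cut, no type row
```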
@@ -1821,9 +2448,11 @@ def index(self, *exprs, all_matches=False) -> 'Expression': try: return self._index(*exprs, all_matches=all_matches) except TableIndexKeyError as err: - raise ExpressionException(f"Key type mismatch: cannot index table with given expressions:\n" - f" Table key: {', '.join(str(t) for t in err.key_type.values()) or '<<>>'}\n" - f" Index Expressions: {', '.join(str(e.dtype) for e in err.index_expressions)}") + raise ExpressionException( + f"Key type mismatch: cannot index table with given expressions:\n" + f" Table key: {', '.join(str(t) for t in err.key_type.values()) or '<<>>'}\n" + f" Index Expressions: {', '.join(str(e.dtype) for e in err.index_expressions)}" + ) @staticmethod def _maybe_truncate_for_flexindex(indexer, indexee_dtype): @@ -1839,13 +2468,13 @@ def _maybe_truncate_for_flexindex(indexer, indexee_dtype): break matching_prefix += 1 prefix_match = matching_prefix == len(indexee_dtype) - direct_match = prefix_match and \ - len(indexer) == len(indexee_dtype) - prefix_interval_match = len(indexee_dtype) == 1 and \ - isinstance(indexee_dtype[0], hl.tinterval) and \ - indexer.dtype[0] == indexee_dtype[0].point_type - direct_interval_match = prefix_interval_match and \ - len(indexer) == 1 + direct_match = prefix_match and len(indexer) == len(indexee_dtype) + prefix_interval_match = ( + len(indexee_dtype) == 1 + and isinstance(indexee_dtype[0], hl.tinterval) + and indexer.dtype[0] == indexee_dtype[0].point_type + ) + direct_interval_match = prefix_interval_match and len(indexer) == 1 if direct_match or direct_interval_match: return indexer if prefix_match: @@ -1856,8 +2485,7 @@ def _maybe_truncate_for_flexindex(indexer, indexee_dtype): @typecheck_method(indexer=expr_any, all_matches=bool) def _maybe_flexindex_table_by_expr(self, indexer, all_matches=False): - truncated_indexer = Table._maybe_truncate_for_flexindex( - indexer, self.key.dtype) + truncated_indexer = Table._maybe_truncate_for_flexindex(indexer, self.key.dtype) if truncated_indexer is not None: return self.index(truncated_indexer, all_matches=all_matches) return None @@ -1871,6 +2499,7 @@ def _index(self, *exprs, all_matches=False) -> 'Expression': raise TypeError(f"Index arguments must be expressions, found {non_exprs}") from hail.matrixtable import MatrixTable + indices, aggregations = unify_all(*exprs) src = indices.source @@ -1878,18 +2507,18 @@ def _index(self, *exprs, all_matches=False) -> 'Expression': # FIXME: this should be OK: table[m.global_index_into_table] raise ExpressionException('Cannot index with a scalar expression') - is_interval = (len(exprs) == 1 - and len(self.key) > 0 - and isinstance(self.key[0].dtype, hl.tinterval) - and exprs[0].dtype == self.key[0].dtype.point_type) + is_interval = ( + len(exprs) == 1 + and len(self.key) > 0 + and isinstance(self.key[0].dtype, hl.tinterval) + and exprs[0].dtype == self.key[0].dtype.point_type + ) if not types_match(list(self.key.values()), list(exprs)): - if (len(exprs) == 1 - and isinstance(exprs[0], TupleExpression)): + if len(exprs) == 1 and isinstance(exprs[0], TupleExpression): return self._index(*exprs[0], all_matches=all_matches) - if (len(exprs) == 1 - and isinstance(exprs[0], StructExpression)): + if len(exprs) == 1 and isinstance(exprs[0], StructExpression): return self._index(*exprs[0].values(), all_matches=all_matches) if not is_interval: @@ -1908,7 +2537,9 @@ def _index(self, *exprs, all_matches=False) -> 'Expression': for e in exprs: analyze('Table.index', e, src._row_indices) - is_key = len(src.key) >= len(exprs) and all(expr 
is key_field for expr, key_field in zip(exprs, src.key.values())) + is_key = len(src.key) >= len(exprs) and all( + expr is key_field for expr, key_field in zip(exprs, src.key.values()) + ) if not is_key: uids = [Env.get_uid() for i in range(len(exprs))] @@ -1919,30 +2550,29 @@ def _index(self, *exprs, all_matches=False) -> 'Expression': def joiner(left): if not is_key: original_key = list(left.key) - left = Table(ir.TableMapRows(left.key_by()._tir, - ir.InsertFields(left._row._ir, - list(zip(uids, [e._ir for e in exprs])), - None))).key_by(*uids) + left = Table( + ir.TableMapRows( + left.key_by()._tir, + ir.InsertFields(left._row._ir, list(zip(uids, [e._ir for e in exprs])), None), + ) + ).key_by(*uids) def rekey_f(t): return t.key_by(*original_key) + else: + def rekey_f(t): return t if is_interval: - if all_matches: - hl.utils.no_service_backend('interval join with all_matches=True') left = Table(ir.TableIntervalJoin(left._tir, self._tir, uid, all_matches)) else: left = Table(ir.TableLeftJoinRightDistinct(left._tir, self._tir, uid)) return rekey_f(left) all_uids.append(uid) - join_ir = ir.Join(ir.ProjectedTopLevelReference('row', uid, new_schema), - all_uids, - exprs, - joiner) + join_ir = ir.Join(ir.ProjectedTopLevelReference('row', uid, new_schema), all_uids, exprs, joiner) return construct_expr(join_ir, new_schema, indices, aggregations) elif isinstance(src, MatrixTable): for e in exprs: @@ -1954,7 +2584,8 @@ def rekey_f(t): raise NotImplementedError('entry-based matrix joins') elif indices == src._row_indices: is_subset_row_key = len(exprs) <= len(src.row_key) and all( - expr is key_field for expr, key_field in zip(exprs, src.row_key.values())) + expr is key_field for expr, key_field in zip(exprs, src.row_key.values()) + ) if not (is_subset_row_key or is_interval): # foreign-key join @@ -1969,10 +2600,16 @@ def rekey_f(t): join_table = join_table.annotate(**{value_uid: right.index(join_table.key)}) # FIXME: Maybe zip join here? 
- join_table = join_table.group_by(*src.row_key).aggregate( - **{uid: - hl.dict(hl.agg.collect(hl.tuple([hl.tuple([join_table[f] for f in foreign_key_annotates]), - join_table[value_uid]])))}) + join_table = join_table.group_by(*src.row_key).aggregate(**{ + uid: hl.dict( + hl.agg.collect( + hl.tuple([ + hl.tuple([join_table[f] for f in foreign_key_annotates]), + join_table[value_uid], + ]) + ) + ) + }) def joiner(left: MatrixTable): mart = ir.MatrixAnnotateRowsTable(left._mir, join_table._tir, uid) @@ -1981,25 +2618,36 @@ def joiner(left: MatrixTable): mart, ir.InsertFields( ir.Ref('va', mart.typ.row_type), - [(uid, ir.Apply('get', join_table._row_type[uid].value_type, - ir.GetField(ir.GetField(ir.Ref('va', mart.typ.row_type), uid), uid), - ir.MakeTuple([e._ir for e in exprs])))], - None))) + [ + ( + uid, + ir.Apply( + 'get', + join_table._row_type[uid].value_type, + ir.GetField(ir.GetField(ir.Ref('va', mart.typ.row_type), uid), uid), + ir.MakeTuple([e._ir for e in exprs]), + ), + ) + ], + None, + ), + ) + ) + else: + def joiner(left: MatrixTable): return MatrixTable(ir.MatrixAnnotateRowsTable(left._mir, right._tir, uid, all_matches)) - ast = ir.Join(ir.ProjectedTopLevelReference('va', uid, new_schema), - [uid], - exprs, - joiner) + + ast = ir.Join(ir.ProjectedTopLevelReference('va', uid, new_schema), [uid], exprs, joiner) return construct_expr(ast, new_schema, indices, aggregations) elif indices == src._col_indices and not (is_interval and all_matches): all_uids = [uid] - if len(exprs) == len(src.col_key) and all([ - exprs[i] is src.col_key[i] for i in range(len(exprs))]): + if len(exprs) == len(src.col_key) and all([exprs[i] is src.col_key[i] for i in range(len(exprs))]): # key is already correct def joiner(left): return MatrixTable(ir.MatrixAnnotateColsTable(left._mir, right._tir, uid)) + else: index_uid = Env.get_uid() uids = [Env.get_uid() for _ in exprs] @@ -2009,26 +2657,24 @@ def joiner(left): def joiner(left: MatrixTable): prev_key = list(src.col_key) - joined = (src - .annotate_cols(**dict(zip(uids, exprs))) - .add_col_index(index_uid) - .key_cols_by(*uids) - .cols() - .select(index_uid) - .join(self, 'inner') - .key_by(index_uid) - .drop(*uids)) - result = MatrixTable(ir.MatrixAnnotateColsTable( - (left.add_col_index(index_uid) - .key_cols_by(index_uid) - ._mir), - joined._tir, - uid)).key_cols_by(*prev_key) + joined = ( + src.annotate_cols(**dict(zip(uids, exprs))) + .add_col_index(index_uid) + .key_cols_by(*uids) + .cols() + .select(index_uid) + .join(self, 'inner') + .key_by(index_uid) + .drop(*uids) + ) + result = MatrixTable( + ir.MatrixAnnotateColsTable( + (left.add_col_index(index_uid).key_cols_by(index_uid)._mir), joined._tir, uid + ) + ).key_cols_by(*prev_key) return result - join_ir = ir.Join(ir.ProjectedTopLevelReference('sa', uid, new_schema), - all_uids, - exprs, - joiner) + + join_ir = ir.Join(ir.ProjectedTopLevelReference('sa', uid, new_schema), all_uids, exprs, joiner) return construct_expr(join_ir, new_schema, indices, aggregations) else: raise NotImplementedError() @@ -2128,11 +2774,11 @@ def unpersist(self) -> 'Table': return Env.backend().unpersist(self) @overload - def collect(self) -> List[hl.Struct]: - ... + def collect(self) -> List[hl.Struct]: ... + @overload - def collect(self, _localize=False) -> hl.ArrayExpression: - ... + def collect(self, _localize=False) -> ArrayExpression: ... + @typecheck_method(_localize=bool, _timed=bool) def collect(self, _localize=True, *, _timed=False): """Collect the rows of the table into a local list. 
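`collect` brings every row back to the client as a Python list of `Struct` values, so it is best reserved for small or already-aggregated tables. A small sketch:

```python
import hail as hl

ht = hl.utils.range_table(3)
rows = ht.collect()                # [Struct(idx=0), Struct(idx=1), Struct(idx=2)]
idx_values = [r.idx for r in rows]
print(idx_values)                  # [0, 1, 2]
```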
@@ -2185,6 +2831,7 @@ def describe(self, handler=print, *, widget=False): """ if widget: from hail.experimental.interact import interact + return interact(self) def format_type(typ): @@ -2193,26 +2840,28 @@ def format_type(typ): if len(self.globals) == 0: global_fields = '\n None' else: - global_fields = ''.join("\n '{name}': {type} ".format( - name=f, type=format_type(t)) for f, t in self.globals.dtype.items()) + global_fields = ''.join( + "\n '{name}': {type} ".format(name=f, type=format_type(t)) for f, t in self.globals.dtype.items() + ) if len(self.row) == 0: row_fields = '\n None' else: - row_fields = ''.join("\n '{name}': {type} ".format( - name=f, type=format_type(t)) for f, t in self.row.dtype.items()) + row_fields = ''.join( + "\n '{name}': {type} ".format(name=f, type=format_type(t)) for f, t in self.row.dtype.items() + ) row_key = '[' + ', '.join("'{name}'".format(name=f) for f in self.key) + ']' - s = '----------------------------------------\n' \ - 'Global fields:{g}\n' \ - '----------------------------------------\n' \ - 'Row fields:{r}\n' \ - '----------------------------------------\n' \ - 'Key: {rk}\n' \ - '----------------------------------------'.format(g=global_fields, - rk=row_key, - r=row_fields) + s = ( + '----------------------------------------\n' + 'Global fields:{g}\n' + '----------------------------------------\n' + 'Row fields:{r}\n' + '----------------------------------------\n' + 'Key: {rk}\n' + '----------------------------------------'.format(g=global_fields, rk=row_key, r=row_fields) + ) handler(s) @typecheck_method(name=str) @@ -2295,19 +2944,26 @@ def union(self, *tables, unify: bool = False) -> 'Table': Table with all rows from each component table. """ left_key = self.key.dtype - for i, ht, in enumerate(tables): + for ( + i, + ht, + ) in enumerate(tables): if left_key != ht.key.dtype: - raise ValueError(f"'union': table {i} has a different key." - f" Expected: {left_key}\n" - f" Table {i}: {ht.key.dtype}") + raise ValueError( + f"'union': table {i} has a different key." + f" Expected: {left_key}\n" + f" Table {i}: {ht.key.dtype}" + ) if not (unify or ht.row.dtype == self.row.dtype): - raise ValueError(f"'union': table {i} has a different row type.\n" - f" Expected: {self.row.dtype}\n" - f" Table {i}: {ht.row.dtype}\n" - f" If the tables have the same fields in different orders, or some\n" - f" common and some unique fields, then the 'unify' parameter may be\n" - f" able to coerce the tables to a common type.") + raise ValueError( + f"'union': table {i} has a different row type.\n" + f" Expected: {self.row.dtype}\n" + f" Table {i}: {ht.row.dtype}\n" + f" If the tables have the same fields in different orders, or some\n" + f" common and some unique fields, then the 'unify' parameter may be\n" + f" able to coerce the tables to a common type." 
+ ) all_tables = [self] all_tables.extend(tables) @@ -2320,8 +2976,10 @@ def union(self, *tables, unify: bool = False) -> 'Table': for field_name, expr_dict in discovered.items(): *unified, can_unify = hl.expr.expressions.unify_exprs(*expr_dict.values()) if not can_unify: - raise ValueError(f"cannot unify field {field_name!r}: found fields of types " - f"{[str(t) for t in {e.dtype for e in expr_dict.values()}]}") + raise ValueError( + f"cannot unify field {field_name!r}: found fields of types " + f"{[str(t) for t in {e.dtype for e in expr_dict.values()}]}" + ) unified_map = dict(zip(expr_dict.keys(), unified)) default = hl.missing(unified[0].dtype) for i in range(len(all_tables)): @@ -2376,9 +3034,17 @@ def head(self, n) -> 'Table': -------- Subset to the first three rows: - >>> table_result = table1.head(3) - >>> table_result.count() - 3 + >>> ht = hl.import_table('data/kt_example1.tsv', impute=True) + >>> ht.head(3).show() + +-------+-------+-----+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | + | 3 | 70 | "F" | 7 | 3 | 10 | 81 | -5 | + +-------+-------+-----+-------+-------+-------+-------+-------+ Notes ----- @@ -2394,7 +3060,7 @@ def head(self, n) -> 'Table': Returns ------- :class:`.Table` - Table including the first `n` rows. + Table limited to the first `n` rows. """ return Table(ir.TableHead(self._tir, n)) @@ -2430,8 +3096,7 @@ def tail(self, n) -> 'Table': return Table(ir.TableTail(self._tir, n)) - @typecheck_method(p=numeric, - seed=nullable(int)) + @typecheck_method(p=numeric, seed=nullable(int)) def sample(self, p, seed=None) -> 'Table': """Downsample the table by keeping each row with probability ``p``. @@ -2440,7 +3105,37 @@ def sample(self, p, seed=None) -> 'Table': Downsample the table to approximately 1% of its rows. 
- >>> small_table1 = table1.sample(0.01) + >>> table1.show() + +-------+-------+-----+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | + | 3 | 70 | "F" | 7 | 3 | 10 | 81 | -5 | + | 4 | 60 | "F" | 8 | 2 | 11 | 90 | -10 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + >>> small_table1 = table1.sample(0.75, seed=0) + >>> small_table1.show() + +-------+-------+-----+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | + | 3 | 70 | "F" | 7 | 3 | 10 | 81 | -5 | + | 4 | 60 | "F" | 8 | 2 | 11 | 90 | -10 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + >>> small_table1 = table1.sample(0.25, seed=4) + >>> small_table1.show() + +-------+-------+-----+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | + +-------+-------+-----+-------+-------+-------+-------+-------+ Parameters ---------- @@ -2460,8 +3155,7 @@ def sample(self, p, seed=None) -> 'Table': return self.filter(hl.rand_bool(p, seed)) - @typecheck_method(n=int, - shuffle=bool) + @typecheck_method(n=int, shuffle=bool) def repartition(self, n, shuffle=True) -> 'Table': """Change the number of partitions. @@ -2525,8 +3219,11 @@ def repartition(self, n, shuffle=True) -> 'Table': self.checkpoint(tmp) return hl.read_table(tmp, _n_partitions=n) - return Table(ir.TableRepartition( - self._tir, n, ir.RepartitionStrategy.SHUFFLE if shuffle else ir.RepartitionStrategy.COALESCE)) + return Table( + ir.TableRepartition( + self._tir, n, ir.RepartitionStrategy.SHUFFLE if shuffle else ir.RepartitionStrategy.COALESCE + ) + ) @typecheck_method(max_partitions=int) def naive_coalesce(self, max_partitions: int) -> 'Table': @@ -2557,8 +3254,7 @@ def naive_coalesce(self, max_partitions: int) -> 'Table': :class:`.Table` Table with at most `max_partitions` partitions. 
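To contrast the two partitioning methods documented above, a sketch (the row and partition counts are arbitrary):

```python
import hail as hl

ht = hl.utils.range_table(1_000_000, n_partitions=200)

ht_exact = ht.repartition(32)      # may shuffle (or re-read) to hit the target exactly
ht_cheap = ht.naive_coalesce(32)   # only merges adjacent partitions, never splits them

print(ht_exact.n_partitions(), ht_cheap.n_partitions())
```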
""" - return Table(ir.TableRepartition( - self._tir, max_partitions, ir.RepartitionStrategy.NAIVE_COALESCE)) + return Table(ir.TableRepartition(self._tir, max_partitions, ir.RepartitionStrategy.NAIVE_COALESCE)) @typecheck_method(other=table_type) def semi_join(self, other: 'Table') -> 'Table': @@ -2584,13 +3280,50 @@ def semi_join(self, other: 'Table') -> 'Table': Examples -------- - >>> table_result = table1.semi_join(table2) + >>> table1.show() + +-------+-------+-----+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | + | 3 | 70 | "F" | 7 | 3 | 10 | 81 | -5 | + | 4 | 60 | "F" | 8 | 2 | 11 | 90 | -10 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + >>> small_table2 = table2.head(2) + >>> small_table2.show() + +-------+-------+-------+ + | ID | A | B | + +-------+-------+-------+ + | int32 | int32 | str | + +-------+-------+-------+ + | 1 | 65 | "cat" | + | 2 | 72 | "dog" | + +-------+-------+-------+ + >>> table1.semi_join(small_table2).show() + +-------+-------+-----+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | + +-------+-------+-----+-------+-------+-------+-------+-------+ It may be expensive to key the left-side table by the right-side key. 
In this case, it is possible to implement a semi-join using a non-key field as follows: - >>> table_result = table1.filter(hl.is_defined(table2.index(table1['ID']))) + >>> table1.filter(hl.is_defined(small_table2.index(table1['ID']))).show() + +-------+-------+-----+-------+-------+-------+-------+-------+ + | ID | HT | SEX | X | Z | C1 | C2 | C3 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | int32 | int32 | str | int32 | int32 | int32 | int32 | int32 | + +-------+-------+-----+-------+-------+-------+-------+-------+ + | 1 | 65 | "M" | 5 | 4 | 2 | 50 | 5 | + | 2 | 72 | "M" | 6 | 3 | 2 | 61 | 1 | + +-------+-------+-----+-------+-------+-------+-------+-------+ See Also -------- @@ -2598,10 +3331,14 @@ def semi_join(self, other: 'Table') -> 'Table': """ if len(other.key) == 0: raise ValueError('semi_join: cannot join with a table with no key') - if len(other.key) > len(self.key) or any(t[0].dtype != t[1].dtype for t in zip(self.key.values(), other.key.values())): - raise ValueError('semi_join: cannot join: table must have a key of the same type(s) and be the same length or shorter:' - f'\n Left key: {", ".join(str(x.dtype) for x in self.key.values())}' - f'\n Right key: {", ".join(str(x.dtype) for x in other.key.values())}') + if len(other.key) > len(self.key) or any( + t[0].dtype != t[1].dtype for t in zip(self.key.values(), other.key.values()) + ): + raise ValueError( + 'semi_join: cannot join: table must have a key of the same type(s) and be the same length or shorter:' + f'\n Left key: {", ".join(str(x.dtype) for x in self.key.values())}' + f'\n Right key: {", ".join(str(x.dtype) for x in other.key.values())}' + ) return self.filter(hl.is_defined(other.index(*(self.key[i] for i in range(len(other.key)))))) @@ -2643,22 +3380,27 @@ def anti_join(self, other: 'Table') -> 'Table': """ if len(other.key) == 0: raise ValueError('anti_join: cannot join with a table with no key') - if len(other.key) > len(self.key) or any(t[0].dtype != t[1].dtype for t in zip(self.key.values(), other.key.values())): - raise ValueError('anti_join: cannot join: table must have a key of the same type(s) and be the same length or shorter:' - f'\n Left key: {", ".join(str(x.dtype) for x in self.key.values())}' - f'\n Right key: {", ".join(str(x.dtype) for x in other.key.values())}') + if len(other.key) > len(self.key) or any( + t[0].dtype != t[1].dtype for t in zip(self.key.values(), other.key.values()) + ): + raise ValueError( + 'anti_join: cannot join: table must have a key of the same type(s) and be the same length or shorter:' + f'\n Left key: {", ".join(str(x.dtype) for x in self.key.values())}' + f'\n Right key: {", ".join(str(x.dtype) for x in other.key.values())}' + ) return self.filter(hl.is_missing(other.index(*(self.key[i] for i in range(len(other.key)))))) - @typecheck_method(right=table_type, - how=enumeration('inner', 'outer', 'left', 'right'), - _mangle=anyfunc, - _join_key=nullable(int)) - def join(self, - right: 'Table', - how='inner', - _mangle: Callable[[str, int], str] = lambda s, i: f'{s}_{i}', - _join_key: int = None) -> 'Table': + @typecheck_method( + right=table_type, how=enumeration('inner', 'outer', 'left', 'right'), _mangle=anyfunc, _join_key=nullable(int) + ) + def join( + self, + right: 'Table', + how='inner', + _mangle: Callable[[str, int], str] = lambda s, i: f'{s}_{i}', + _join_key: Optional[int] = None, + ) -> 'Table': """Join two tables together. 
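The keyed-join flavours side by side, as a sketch over two small keyed tables (the field names are illustrative):

```python
import hail as hl

left = hl.Table.parallelize(
    [{'ID': 1, 'HT': 65}, {'ID': 2, 'HT': 72}, {'ID': 3, 'HT': 70}],
    hl.tstruct(ID=hl.tint32, HT=hl.tint32),
    key='ID',
)
right = hl.Table.parallelize(
    [{'ID': 1, 'B': 'cat'}, {'ID': 2, 'B': 'dog'}],
    hl.tstruct(ID=hl.tint32, B=hl.tstr),
    key='ID',
)

left.join(right, how='inner').show()  # combines fields; keeps keys present in both tables
left.semi_join(right).show()          # left rows whose key appears in right; no new fields
left.anti_join(right).show()          # left rows whose key does not appear in right
```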
Examples @@ -2725,20 +3467,23 @@ def join(self, left_key_types = list(self.key.dtype.values())[:_join_key] right_key_types = list(right.key.dtype.values())[:_join_key] if not left_key_types == right_key_types: - raise ValueError(f"'join': key mismatch:\n " - f" left: [{', '.join(str(t) for t in left_key_types)}]\n " - f" right: [{', '.join(str(t) for t in right_key_types)}]") + raise ValueError( + f"'join': key mismatch:\n " + f" left: [{', '.join(str(t) for t in left_key_types)}]\n " + f" right: [{', '.join(str(t) for t in right_key_types)}]" + ) left_fields = set(self._fields) right_fields = set(right._fields) - set(right.key) - renames, _ = deduplicate( - right_fields, max_attempts=100, already_used=left_fields) + renames, _ = deduplicate(right_fields, max_attempts=100, already_used=left_fields) if renames: renames = dict(renames) right = right.rename(renames) - info('Table.join: renamed the following fields on the right to avoid name conflicts:' - + ''.join(f'\n {repr(k)} -> {repr(v)}' for k, v in renames.items())) + info( + 'Table.join: renamed the following fields on the right to avoid name conflicts:' + + ''.join(f'\n {k!r} -> {v!r}' for k, v in renames.items()) + ) return Table(ir.TableJoin(self._tir, right._tir, how, _join_key)) @@ -2822,7 +3567,9 @@ def rename(self, mapping) -> 'Table': if v in seen: raise ValueError( "Cannot rename two fields to the same name: attempted to rename {} and {} both to {}".format( - repr(seen[v]), repr(k), repr(v))) + repr(seen[v]), repr(k), repr(v) + ) + ) if v in self._fields and v not in mapping: raise ValueError("Cannot rename {} to {}: field already exists.".format(repr(k), repr(v))) seen[v] = k @@ -2869,23 +3616,18 @@ def expand_types(self) -> 'Table': t = t.order_by(*t.key) def _expand(e): - if isinstance(e, CollectionExpression) or isinstance(e, DictExpression): + if isinstance(e, (CollectionExpression, DictExpression)): return hl.map(lambda x: _expand(x), hl.array(e)) elif isinstance(e, StructExpression): return hl.struct(**{k: _expand(v) for (k, v) in e.items()}) elif isinstance(e, TupleExpression): return hl.struct(**{f'_{i}': x for (i, x) in enumerate(e)}) elif isinstance(e, IntervalExpression): - return hl.struct(start=e.start, - end=e.end, - includesStart=e.includes_start, - includesEnd=e.includes_end) + return hl.struct(start=e.start, end=e.end, includesStart=e.includes_start, includesEnd=e.includes_end) elif isinstance(e, LocusExpression): - return hl.struct(contig=e.contig, - position=e.position) + return hl.struct(contig=e.contig, position=e.position) elif isinstance(e, CallExpression): - return hl.struct(alleles=hl.map(lambda i: e[i], hl.range(0, e.ploidy)), - phased=e.phased) + return hl.struct(alleles=hl.map(lambda i: e[i], hl.range(0, e.ploidy)), phased=e.phased) elif isinstance(e, NDArrayExpression): return hl.struct(shape=e.shape, data=_expand(e._data_array())) else: @@ -3012,17 +3754,18 @@ def order_by(self, *exprs) -> 'Table': """ lifted_exprs = [] for e in exprs: + _e = e sort_type = 'A' if isinstance(e, Ascending): - e = e.col + _e = e.col elif isinstance(e, Descending): - e = e.col + _e = e.col sort_type = 'D' - if isinstance(e, str): - expr = self[e] + if isinstance(_e, str): + expr = self[_e] else: - expr = e + expr = _e lifted_exprs.append((expr, sort_type)) sort_fields = [] @@ -3033,8 +3776,9 @@ def order_by(self, *exprs) -> 'Table': if e._indices.source is None: raise ValueError("Sort fields must be fields of the callee Table, found scalar expression") else: - raise ValueError(f"Sort fields must be fields of the 
callee Table," - f" found field of {e._indices.source}") + raise ValueError( + f"Sort fields must be fields of the callee Table," f" found field of {e._indices.source}" + ) elif e._indices != self._row_indices: raise ValueError("Sort fields must be row-indexed, found global sort expression") else: @@ -3052,8 +3796,7 @@ def order_by(self, *exprs) -> 'Table': t = t.drop(*complex_exprs.keys()) return t - @typecheck_method(field=oneof(str, Expression), - name=nullable(str)) + @typecheck_method(field=oneof(str, Expression), name=nullable(str)) def explode(self, field, name=None) -> 'Table': """Explode rows along a field of type array or set, copying the entire row for each element. @@ -3136,15 +3879,17 @@ def explode(self, field, name=None) -> 'Table': if field not in self._fields: raise KeyError("Table has no field '{}'".format(field)) elif self._fields[field]._indices != self._row_indices: - raise ExpressionException("Method 'explode' expects a field indexed by row, found axes '{}'" - .format(self._fields[field]._indices.axes)) + raise ExpressionException( + "Method 'explode' expects a field indexed by row, found axes '{}'".format( + self._fields[field]._indices.axes + ) + ) root = [field] field = self._fields[field] else: analyze('Table.explode', field, self._row_indices, set(self._fields.keys())) if not field._ir.is_nested_field: - raise ExpressionException( - "method 'explode' requires a field or subfield, not a complex expression") + raise ExpressionException("method 'explode' requires a field or subfield, not a complex expression") nested = field._ir root = [] while isinstance(nested, ir.GetField): @@ -3166,11 +3911,13 @@ def explode(self, field, name=None) -> 'Table': t = t.rename({root[0]: name}) return t - @typecheck_method(row_key=sequenceof(str), - col_key=sequenceof(str), - row_fields=sequenceof(str), - col_fields=sequenceof(str), - n_partitions=nullable(int)) + @typecheck_method( + row_key=sequenceof(str), + col_key=sequenceof(str), + row_fields=sequenceof(str), + col_fields=sequenceof(str), + n_partitions=nullable(int), + ) def to_matrix_table(self, row_key, col_key, row_fields=[], col_fields=[], n_partitions=None) -> 'hl.MatrixTable': """Construct a matrix table from a table in coordinate representation. 
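In coordinate ("long") form there is one row per (row key, column key) pair, and any field not named as a key, row field, or column field becomes an entry field. A sketch with made-up identifiers:

```python
import hail as hl

coord = hl.Table.parallelize(
    [
        {'row_id': 'r1', 'col_id': 'c1', 'value': 1.0},
        {'row_id': 'r1', 'col_id': 'c2', 'value': 2.0},
        {'row_id': 'r2', 'col_id': 'c1', 'value': 3.0},
    ],
    hl.tstruct(row_id=hl.tstr, col_id=hl.tstr, value=hl.tfloat64),
)

mt = coord.to_matrix_table(row_key=['row_id'], col_key=['col_id'])
mt.describe()  # `value` becomes the entry field; absent (row, column) pairs are missing entries
```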
@@ -3234,9 +3981,9 @@ def to_matrix_table(self, row_key, col_key, row_fields=[], col_fields=[], n_part row_field_set = set(self.row) for k, v in c.items(): if k not in row_field_set: - raise ValueError(f"'to_matrix_table': field {repr(k)} is not a row field") + raise ValueError(f"'to_matrix_table': field {k!r} is not a row field") if v > 1: - raise ValueError(f"'to_matrix_table': field {repr(k)} appeared in {v} field groups") + raise ValueError(f"'to_matrix_table': field {k!r} appeared in {v} field groups") if len(row_key) == 0: raise ValueError("'to_matrix_table': require at least one row key field") @@ -3249,16 +3996,21 @@ def to_matrix_table(self, row_key, col_key, row_fields=[], col_fields=[], n_part entry_fields = [x for x in ht.row if x not in non_entry_fields] if not entry_fields: - raise ValueError("'Table.to_matrix_table': no fields remain as entry fields:\n" - " all table fields found in one of 'row_key', 'col_key', 'row_fields', 'col_fields'") + raise ValueError( + "'Table.to_matrix_table': no fields remain as entry fields:\n" + " all table fields found in one of 'row_key', 'col_key', 'row_fields', 'col_fields'" + ) col_data = hl.rbind( hl.array( ht.aggregate( hl.agg.group_by(ht.row.select(*col_key), hl.agg.take(ht.row.select(*col_fields), 1)[0]), - _localize=False)), - lambda data: hl.struct(data=data, - key_to_index=hl.dict(hl.range(0, hl.len(data)).map(lambda i: (data[i][0], i)))) + _localize=False, + ) + ), + lambda data: hl.struct( + data=data, key_to_index=hl.dict(hl.range(0, hl.len(data)).map(lambda i: (data[i][0], i))) + ), ) col_data_uid = Env.get_uid() @@ -3266,17 +4018,30 @@ def to_matrix_table(self, row_key, col_key, row_fields=[], col_fields=[], n_part ht = ht.annotate_globals(**{col_data_uid: col_data}) entries_uid = Env.get_uid() - ht = (ht.group_by(*row_key) - .partition_hint(n_partitions) - # FIXME: should be agg._prev_nonnull https://github.com/hail-is/hail/issues/5345 - .aggregate(**{x: hl.agg.take(ht[x], 1)[0] for x in row_fields}, - **{entries_uid: hl.rbind( - hl.dict(hl.agg.collect((ht[col_data_uid]['key_to_index'][ht.row.select(*col_key)], - ht.row.select(*entry_fields)))), - lambda entry_dict: hl.range(0, hl.len(ht[col_data_uid]['key_to_index'])) - .map(lambda i: entry_dict.get(i)))})) - ht = ht.annotate_globals( - **{col_data_uid: hl.array(ht[col_data_uid]['data'].map(lambda elt: hl.struct(**elt[0], **elt[1])))}) + ht = ( + ht.group_by(*row_key) + .partition_hint(n_partitions) + # FIXME: should be agg._prev_nonnull https://github.com/hail-is/hail/issues/5345 + .aggregate( + **{x: hl.agg.take(ht[x], 1)[0] for x in row_fields}, + **{ + entries_uid: hl.rbind( + hl.dict( + hl.agg.collect(( + ht[col_data_uid]['key_to_index'][ht.row.select(*col_key)], + ht.row.select(*entry_fields), + )) + ), + lambda entry_dict: hl.range(0, hl.len(ht[col_data_uid]['key_to_index'])).map( + lambda i: entry_dict.get(i) + ), + ) + }, + ) + ) + ht = ht.annotate_globals(**{ + col_data_uid: hl.array(ht[col_data_uid]['data'].map(lambda elt: hl.struct(**elt[0], **elt[1]))) + }) return ht._unlocalize_entries(entries_uid, col_data_uid, col_key) @typecheck_method(columns=sequenceof(str), entry_field_name=nullable(str), col_field_name=str) @@ -3449,8 +4214,7 @@ def row_value(self) -> 'StructExpression': return self._row.drop(*self.key.keys()) @staticmethod - @typecheck(df=pyspark.sql.DataFrame, - key=table_key_type) + @typecheck(df=pyspark.sql.DataFrame, key=table_key_type) def from_spark(df, key=[]) -> 'Table': """Convert PySpark SQL DataFrame to a table. 
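`from_spark` mirrors the pandas constructor shown further down; a sketch of the pandas round trip (the Spark path is analogous but needs an active Spark session):

```python
import pandas as pd
import hail as hl

df = pd.DataFrame({'ID': [1, 2], 'HT': [65, 72]})
ht = hl.Table.from_pandas(df, key='ID')   # pandas DataFrame -> Hail Table
df_again = ht.to_pandas()                 # Hail Table -> pandas DataFrame
print(df_again.dtypes)
```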
@@ -3541,7 +4305,7 @@ def to_pandas(self, flatten=True, types={}): hl.tint64: "Int64", hl.tfloat32: "Float32", hl.tfloat64: "Float64", - hl.tbool: "boolean" + hl.tbool: "boolean", } all_types = {**hl_default_dtypes, **types} @@ -3558,8 +4322,7 @@ def to_pandas(self, flatten=True, types={}): return pandas.DataFrame(data_dict) @staticmethod - @typecheck(df=pandas.DataFrame, - key=oneof(str, sequenceof(str))) + @typecheck(df=pandas.DataFrame, key=oneof(str, sequenceof(str))) def from_pandas(df, key=[]) -> 'Table': """Create table from Pandas DataFrame @@ -3642,52 +4405,52 @@ def _same(self, other, tolerance=1e-6, absolute=False, reorder_fields=False): return False left = self - left = left.select_globals(left_globals = left.globals) - left = left.group_by(key=left.key).aggregate(left_row = hl.agg.collect(left.row_value)) + left = left.select_globals(left_globals=left.globals) + left = left.group_by(key=left.key).aggregate(left_row=hl.agg.collect(left.row_value)) right = other - right = right.select_globals(right_globals = right.globals) - right = right.group_by(key=right.key).aggregate(right_row = hl.agg.collect(right.row_value)) + right = right.select_globals(right_globals=right.globals) + right = right.group_by(key=right.key).aggregate(right_row=hl.agg.collect(right.row_value)) t = left.join(right, how='outer') - mismatched_globals, mismatched_rows = t.aggregate(hl.tuple(( - hl.or_missing( - ~_values_similar(t.left_globals, t.right_globals, tolerance, absolute), - t.globals - ), - hl.agg.filter( - ~hl.all( - hl.is_defined(t.left_row), - hl.is_defined(t.right_row), - _values_similar(t.left_row, t.right_row, tolerance, absolute), + mismatched_globals, mismatched_rows = t.aggregate( + hl.tuple(( + hl.or_missing(~_values_similar(t.left_globals, t.right_globals, tolerance, absolute), t.globals), + hl.agg.filter( + ~hl.all( + hl.is_defined(t.left_row), + hl.is_defined(t.right_row), + _values_similar(t.left_row, t.right_row, tolerance, absolute), + ), + hl.agg.take(t.row, 10), ), - hl.agg.take(t.row, 10) - ) - ))) + )) + ) columns, _ = shutil.get_terminal_size((80, 10)) + def pretty(obj): pretty_str = pprint.pformat(obj, width=columns) return ''.join(' ' + line for line in pretty_str.splitlines(keepends=True)) is_same = True if mismatched_globals is not None: - print(f'''Table._same: globals differ: + print(f"""Table._same: globals differ: Left: {pretty(mismatched_globals.left_globals)} Right: -{pretty(mismatched_globals.right_globals)}''') +{pretty(mismatched_globals.right_globals)}""") is_same = False if len(mismatched_rows) > 0: print('Table._same: rows differ:') for r in mismatched_rows: - print(f''' Row mismatch at key={r.key}: + print(f""" Row mismatch at key={r.key}: Left: {pretty(r.left_row)} Right: -{pretty(r.right_row)}''') +{pretty(r.right_row)}""") is_same = False return is_same @@ -3744,12 +4507,11 @@ def collect_by_key(self, name: str = 'values') -> 'Table': :class:`.Table` """ - import hail.methods.misc as misc + from hail.methods import misc + misc.require_key(self, 'collect_by_key') - return Table(ir.TableAggregateByKey( - self._tir, - hl.struct(**{name: hl.agg.collect(self.row_value)})._ir)) + return Table(ir.TableAggregateByKey(self._tir, hl.struct(**{name: hl.agg.collect(self.row_value)})._ir)) def distinct(self) -> 'Table': """Deduplicate keys, keeping exactly one row for each unique key. 
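`collect_by_key` and `distinct` are the two key-collapsing operations touched in these hunks: the first gathers all rows for a key into an array field, the second keeps a single row per key. A sketch:

```python
import hail as hl

ht = hl.Table.parallelize(
    [{'k': 1, 'v': 'a'}, {'k': 1, 'v': 'b'}, {'k': 2, 'v': 'c'}],
    hl.tstruct(k=hl.tint32, v=hl.tstr),
    key='k',
)

ht.collect_by_key('values').show()  # one row per key; `values` is an array of the collapsed rows
ht.distinct().show()                # exactly one row per key
```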
@@ -3795,7 +4557,8 @@ def distinct(self) -> 'Table': :class:`.Table` """ - import hail.methods.misc as misc + from hail.methods import misc + misc.require_key(self, 'distinct') return Table(ir.TableDistinct(self._tir)) @@ -3814,12 +4577,9 @@ def summarize(self, handler=None): def _filter_partitions(self, parts, keep=True) -> 'Table': return Table(ir.TableToTableApply(self._tir, {'name': 'TableFilterPartitions', 'parts': parts, 'keep': keep})) - @typecheck_method(entries_field_name=str, - cols_field_name=str, - col_key=sequenceof(str)) + @typecheck_method(entries_field_name=str, cols_field_name=str, col_key=sequenceof(str)) def _unlocalize_entries(self, entries_field_name, cols_field_name, col_key) -> 'hl.MatrixTable': - return hl.MatrixTable(ir.CastTableToMatrix( - self._tir, entries_field_name, cols_field_name, col_key)) + return hl.MatrixTable(ir.CastTableToMatrix(self._tir, entries_field_name, cols_field_name, col_key)) @staticmethod @typecheck(tables=sequenceof(table_type), data_field_name=str, global_field_name=str) @@ -3866,16 +4626,21 @@ def multi_way_zip_join(tables, data_field_name, global_field_name) -> 'Table': raise ValueError('multi_way_zip_join must have at least one table as an argument') head = tables[0] if any(head.key.dtype != t.key.dtype for t in tables): - raise TypeError('All input tables to multi_way_zip_join must have the same key type:\n ' - + '\n '.join(str(t.key.dtype) for t in tables)) + raise TypeError( + 'All input tables to multi_way_zip_join must have the same key type:\n ' + + '\n '.join(str(t.key.dtype) for t in tables) + ) if any(head.row.dtype != t.row.dtype for t in tables): - raise TypeError('All input tables to multi_way_zip_join must have the same row type\n ' - + '\n '.join(str(t.row.dtype) for t in tables)) + raise TypeError( + 'All input tables to multi_way_zip_join must have the same row type\n ' + + '\n '.join(str(t.row.dtype) for t in tables) + ) if any(head.globals.dtype != t.globals.dtype for t in tables): - raise TypeError('All input tables to multi_way_zip_join must have the same global type\n ' - + '\n '.join(str(t.globals.dtype) for t in tables)) - return Table(ir.TableMultiWayZipJoin( - [t._tir for t in tables], data_field_name, global_field_name)) + raise TypeError( + 'All input tables to multi_way_zip_join must have the same global type\n ' + + '\n '.join(str(t.globals.dtype) for t in tables) + ) + return Table(ir.TableMultiWayZipJoin([t._tir for t in tables], data_field_name, global_field_name)) def _group_within_partitions(self, name, n): def grouping_func(part): @@ -3889,7 +4654,9 @@ def grouping_func(part): def _map_partitions(self, f): rows_uid = 'tmp_rows_' + Env.get_uid() globals_uid = 'tmp_globals_' + Env.get_uid() - expr = construct_expr(ir.Ref(rows_uid, hl.tstream(self.row.dtype)), hl.tstream(self.row.dtype), self._row_indices) + expr = construct_expr( + ir.Ref(rows_uid, hl.tstream(self.row.dtype)), hl.tstream(self.row.dtype), self._row_indices + ) body = f(expr) result_t = body.dtype if any(k not in result_t.element_type for k in self.key): @@ -3900,10 +4667,12 @@ def _map_partitions(self, f): def _calculate_new_partitions(self, n_partitions): """returns a set of range bounds that can be passed to write""" - return Env.backend().execute(ir.TableToValueApply( - self.select().select_globals()._tir, - {'name': 'TableCalculateNewPartitions', - 'nPartitions': n_partitions})) + return Env.backend().execute( + ir.TableToValueApply( + self.select().select_globals()._tir, + {'name': 'TableCalculateNewPartitions', 'nPartitions': 
n_partitions}, + ) + ) table_type.set(Table) diff --git a/hail/python/hail/typecheck/__init__.py b/hail/python/hail/typecheck/__init__.py index 5020749aa6c..5cff90ee174 100644 --- a/hail/python/hail/typecheck/__init__.py +++ b/hail/python/hail/typecheck/__init__.py @@ -1,9 +1,32 @@ -from .check import (TypeChecker, typecheck, typecheck_method, anytype, anyfunc, - nullable, sequenceof, tupleof, sized_tupleof, sliceof, - dictof, linked_list, setof, oneof, exactly, numeric, char, - lazy, enumeration, identity, transformed, func_spec, - table_key_type, TypecheckFailure, arg_check, args_check, - kwargs_check) +from .check import ( + TypeChecker, + TypecheckFailure, + anyfunc, + anytype, + arg_check, + args_check, + char, + dictof, + enumeration, + exactly, + func_spec, + identity, + kwargs_check, + lazy, + linked_list, + nullable, + numeric, + oneof, + sequenceof, + setof, + sized_tupleof, + sliceof, + table_key_type, + transformed, + tupleof, + typecheck, + typecheck_method, +) __all__ = [ 'TypeChecker', @@ -32,5 +55,5 @@ 'TypecheckFailure', 'arg_check', 'args_check', - 'kwargs_check' + 'kwargs_check', ] diff --git a/hail/python/hail/typecheck/check.py b/hail/python/hail/typecheck/check.py index 52d3e03a85a..2a0fd04134a 100644 --- a/hail/python/hail/typecheck/check.py +++ b/hail/python/hail/typecheck/check.py @@ -1,8 +1,9 @@ -from typing import TypeVar, Callable -import re -import inspect import abc import collections +import inspect +import re +from typing import Callable, TypeVar + from hailtop.hail_decorator import decorator @@ -29,12 +30,10 @@ def __init__(self): pass @abc.abstractmethod - def check(self, x, caller, param): - ... + def check(self, x, caller, param): ... @abc.abstractmethod - def expects(self): - ... + def expects(self): ... def format(self, arg): return f"{extract(type(arg))}: {arg}" @@ -220,6 +219,7 @@ def __init__(self, type): def check(self, x, caller, param): from hail.utils import LinkedList + if not isinstance(x, LinkedList): raise TypecheckFailure if x.type is not self.type: @@ -360,10 +360,7 @@ def check(self, x, caller, param): params = inspect.signature(x).parameters if self.nargs != len(params): - n_required_params = len([ - p for p in params.values() - if p.default == inspect.Parameter.empty - ]) + n_required_params = len([p for p in params.values() if p.default == inspect.Parameter.empty]) if not (self.nargs >= n_required_params and self.nargs < len(params)): raise TypecheckFailure @@ -373,12 +370,14 @@ def f(*args): try: return self.ret_checker.check(ret, caller, param) except TypecheckFailure: - raise TypeError("'{caller}': '{param}': expected return type {expected}, found {found}".format( - caller=caller, - param=param, - expected=self.ret_checker.expects(), - found=self.ret_checker.format(ret) - )) + raise TypeError( + "'{caller}': '{param}': expected return type {expected}, found {found}".format( + caller=caller, + param=param, + expected=self.ret_checker.expects(), + found=self.ret_checker.format(ret), + ) + ) return f @@ -400,7 +399,9 @@ def only(t): elif isinstance(t, TypeChecker): return t else: - raise RuntimeError("invalid typecheck signature: expected 'type', 'lambda', or 'TypeChecker', found '%s'" % type(t)) + raise RuntimeError( + "invalid typecheck signature: expected 'type', 'lambda', or 'TypeChecker', found '%s'" % type(t) + ) def exactly(v, reference_equality=False): @@ -457,8 +458,8 @@ def func_spec(n, tc): def transformed(*tcs): fs = [] for tc, f in tcs: - tc = only(tc) - fs.append((tc, f)) + _tc = only(tc) + fs.append((_tc, f)) return 
CoercionChecker(*fs) @@ -472,10 +473,7 @@ def lazy(): char = CharChecker() -table_key_type = nullable( - oneof( - transformed((str, lambda x: [x])), - sequenceof(str))) +table_key_type = nullable(oneof(transformed((str, lambda x: [x])), sequenceof(str))) def get_signature(f) -> inspect.Signature: @@ -525,17 +523,18 @@ def check_all(f, args, kwargs, checks, is_method): has_varargs = any(param.kind == param.VAR_POSITIONAL for param in spec.parameters.values()) n_pos_args = len( - list(filter( - lambda p: p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD), - spec.parameters.values()))) + list(filter(lambda p: p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD), spec.parameters.values())) + ) if not has_varargs and len(args) > n_pos_args: raise TypeError(f"'{name}' takes {n_pos_args} positional arguments, found {len(args)}") for i, (arg_name, param) in enumerate(spec.parameters.items()): if i == 0 and is_method: if not isinstance(arg_list[0], object): - raise RuntimeError("no class found as first argument. Did you mean to use 'typecheck' " - "instead of 'typecheck_method'?") + raise RuntimeError( + "no class found as first argument. Did you mean to use 'typecheck' " + "instead of 'typecheck_method'?" + ) args_.append(args[i]) continue checker = checks[arg_name] @@ -546,14 +545,12 @@ def check_all(f, args, kwargs, checks, is_method): if necessarily_positional or keyword_passed_as_positional: if i >= len(args): - raise TypeError( - f'Expected {n_pos_args} positional arguments, found {len(args)}') + raise TypeError(f'Expected {n_pos_args} positional arguments, found {len(args)}') args_.append(arg_check(args[i], name, arg_name, checker)) elif param.kind in (param.KEYWORD_ONLY, param.POSITIONAL_OR_KEYWORD): arg = kwargs.pop(arg_name, param.default) if arg is inspect._empty: - raise TypeError( - f"{name}() missing required keyword-only argument '{arg_name}'") + raise TypeError(f"{name}() missing required keyword-only argument '{arg_name}'") kwargs_[arg_name] = arg_check(arg, name, arg_name, checker) elif param.kind == param.VAR_POSITIONAL: # consume the rest of the positional arguments @@ -567,6 +564,7 @@ def check_all(f, args, kwargs, checks, is_method): kwargs_[kwarg_name] = kwargs_check(arg, name, kwarg_name, checker) return args_, kwargs_ + def typecheck_method(**checkers): return _make_dec(checkers, is_method=True) @@ -593,42 +591,35 @@ def arg_check(arg, function_name: str, arg_name: str, checker: TypeChecker): try: return checker.check(arg, function_name, arg_name) except TypecheckFailure as e: - raise TypeError("{fname}: parameter '{argname}': " - "expected {expected}, found {found}".format( - fname=function_name, - argname=arg_name, - expected=checker.expects(), - found=checker.format(arg) - )) from e - - -def args_check(arg, - function_name: str, - arg_name: str, - index: int, - total_varargs: int, - checker: TypeChecker): + raise TypeError( + "{fname}: parameter '{argname}': " "expected {expected}, found {found}".format( + fname=function_name, argname=arg_name, expected=checker.expects(), found=checker.format(arg) + ) + ) from e + + +def args_check(arg, function_name: str, arg_name: str, index: int, total_varargs: int, checker: TypeChecker): try: return checker.check(arg, function_name, arg_name) except TypecheckFailure as e: - raise TypeError("{fname}: parameter '*{argname}' (arg {idx} of {tot}): " - "expected {expected}, found {found}".format( - fname=function_name, - argname=arg_name, - idx=index, - tot=total_varargs, - expected=checker.expects(), - found=checker.format(arg) - 
)) from e + raise TypeError( + "{fname}: parameter '*{argname}' (arg {idx} of {tot}): " "expected {expected}, found {found}".format( + fname=function_name, + argname=arg_name, + idx=index, + tot=total_varargs, + expected=checker.expects(), + found=checker.format(arg), + ) + ) from e def kwargs_check(arg, function_name: str, kwarg_name: str, checker: TypeChecker): try: return checker.check(arg, function_name, kwarg_name) except TypecheckFailure as e: - raise TypeError("{fname}: keyword argument '{argname}': " - "expected {expected}, found {found}".format( - fname=function_name, - argname=kwarg_name, - expected=checker.expects(), - found=checker.format(arg))) from e + raise TypeError( + "{fname}: keyword argument '{argname}': " "expected {expected}, found {found}".format( + fname=function_name, argname=kwarg_name, expected=checker.expects(), found=checker.format(arg) + ) + ) from e diff --git a/hail/python/hail/utils/__init__.py b/hail/python/hail/utils/__init__.py index bd35290fa31..938e587c916 100644 --- a/hail/python/hail/utils/__init__.py +++ b/hail/python/hail/utils/__init__.py @@ -1,60 +1,85 @@ -from .misc import (wrap_to_list, get_env_or_default, uri_path, local_path_uri, new_temp_file, - new_local_temp_dir, new_local_temp_file, with_local_temp_file, storage_level, - range_matrix_table, range_table, run_command, timestamp_path, - _dumps_partitions, default_handler, guess_cloud_spark_provider, no_service_backend, - ANY_REGION) -from .hadoop_utils import (hadoop_copy, hadoop_open, hadoop_exists, hadoop_is_dir, hadoop_is_file, - hadoop_ls, hadoop_scheme_supported, hadoop_stat, copy_log) -from .struct import Struct -from .linkedlist import LinkedList -from .interval import Interval -from .frozendict import frozendict -from .java import error, warning, info, FatalError, HailUserError -from .tutorial import get_1kg, get_hgdp, get_movie_lens from .deduplicate import deduplicate -from .jsonx import JSONEncoder +from .frozendict import frozendict from .genomic_range_table import genomic_range_table +from .hadoop_utils import ( + copy_log, + hadoop_copy, + hadoop_exists, + hadoop_is_dir, + hadoop_is_file, + hadoop_ls, + hadoop_open, + hadoop_scheme_supported, + hadoop_stat, +) +from .interval import Interval +from .java import FatalError, HailUserError, error, info, warning +from .jsonx import JSONEncoder +from .linkedlist import LinkedList +from .misc import ( + ANY_REGION, + _dumps_partitions, + default_handler, + get_env_or_default, + guess_cloud_spark_provider, + local_path_uri, + new_local_temp_dir, + new_local_temp_file, + new_temp_file, + no_service_backend, + range_matrix_table, + range_table, + run_command, + storage_level, + timestamp_path, + uri_path, + with_local_temp_file, + wrap_to_list, +) +from .struct import Struct +from .tutorial import get_1kg, get_hgdp, get_movie_lens -__all__ = ['hadoop_open', - 'hadoop_copy', - 'hadoop_exists', - 'hadoop_is_dir', - 'hadoop_is_file', - 'hadoop_stat', - 'hadoop_ls', - 'hadoop_scheme_supported', - 'copy_log', - 'wrap_to_list', - 'new_local_temp_dir', - 'new_local_temp_file', - 'new_temp_file', - 'get_env_or_default', - 'storage_level', - 'uri_path', - 'local_path_uri', - 'run_command', - 'Struct', - 'Interval', - 'frozendict', - 'error', - 'warning', - 'info', - 'FatalError', - 'HailUserError', - 'range_table', - 'range_matrix_table', - 'LinkedList', - 'get_1kg', - 'get_hgdp', - 'get_movie_lens', - 'timestamp_path', - '_dumps_partitions', - 'default_handler', - 'deduplicate', - 'with_local_temp_file', - 'guess_cloud_spark_provider', - 
'no_service_backend', - 'JSONEncoder', - 'genomic_range_table', - 'ANY_REGION', - ] +__all__ = [ + 'hadoop_open', + 'hadoop_copy', + 'hadoop_exists', + 'hadoop_is_dir', + 'hadoop_is_file', + 'hadoop_stat', + 'hadoop_ls', + 'hadoop_scheme_supported', + 'copy_log', + 'wrap_to_list', + 'new_local_temp_dir', + 'new_local_temp_file', + 'new_temp_file', + 'get_env_or_default', + 'storage_level', + 'uri_path', + 'local_path_uri', + 'run_command', + 'Struct', + 'Interval', + 'frozendict', + 'error', + 'warning', + 'info', + 'FatalError', + 'HailUserError', + 'range_table', + 'range_matrix_table', + 'LinkedList', + 'get_1kg', + 'get_hgdp', + 'get_movie_lens', + 'timestamp_path', + '_dumps_partitions', + 'default_handler', + 'deduplicate', + 'with_local_temp_file', + 'guess_cloud_spark_provider', + 'no_service_backend', + 'JSONEncoder', + 'genomic_range_table', + 'ANY_REGION', +] diff --git a/hail/python/hail/utils/byte_reader.py b/hail/python/hail/utils/byte_reader.py index e084a7f6a65..c56efc951ef 100644 --- a/hail/python/hail/utils/byte_reader.py +++ b/hail/python/hail/utils/byte_reader.py @@ -8,12 +8,12 @@ def __init__(self, byte_memview, offset=0): self._offset = offset def read_int32(self) -> int: - res = struct.unpack('=i', self._memview[self._offset:self._offset + 4])[0] + res = struct.unpack('=i', self._memview[self._offset : self._offset + 4])[0] self._offset += 4 return res def read_int64(self) -> int: - res = struct.unpack('=q', self._memview[self._offset:self._offset + 8])[0] + res = struct.unpack('=q', self._memview[self._offset : self._offset + 8])[0] self._offset += 8 return res @@ -23,17 +23,17 @@ def read_bool(self) -> bool: return res def read_float32(self) -> float: - res = struct.unpack('=f', self._memview[self._offset:self._offset + 4])[0] + res = struct.unpack('=f', self._memview[self._offset : self._offset + 4])[0] self._offset += 4 return res def read_float64(self) -> float: - res = struct.unpack('=d', self._memview[self._offset:self._offset + 8])[0] + res = struct.unpack('=d', self._memview[self._offset : self._offset + 8])[0] self._offset += 8 return res def read_bytes_view(self, num_bytes): - res = self._memview[self._offset: self._offset + num_bytes] + res = self._memview[self._offset : self._offset + num_bytes] self._offset += num_bytes return res diff --git a/hail/python/hail/utils/deduplicate.py b/hail/python/hail/utils/deduplicate.py index d47aa86acd6..b2602dcafd9 100644 --- a/hail/python/hail/utils/deduplicate.py +++ b/hail/python/hail/utils/deduplicate.py @@ -1,11 +1,8 @@ -from typing import Tuple, Optional, Iterable, List +from typing import Iterable, List, Optional, Tuple def deduplicate( - ids: Iterable[str], - *, - max_attempts: Optional[int] = None, - already_used: Optional[Iterable[str]] = None + ids: Iterable[str], *, max_attempts: Optional[int] = None, already_used: Optional[Iterable[str]] = None ) -> Tuple[List[Tuple[str, str]], List[str]]: """Deduplicate the strings in `ids`. 
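As an aside from the patch itself, here is a minimal usage sketch of the deduplicate helper whose signature is reformatted above; the sample ids and the exact suffix scheme mentioned in the comments are illustrative assumptions, not taken from this diff. The unpacking order follows the annotated return type Tuple[List[Tuple[str, str]], List[str]].

from hail.utils import deduplicate

# ids with a collision on 's2'
renamings, unique_ids = deduplicate(['s1', 's2', 's2', 's3'])
# 'renamings' pairs each renamed original id with its replacement, and
# 'unique_ids' is the input list with collisions rewritten, e.g. the second
# 's2' becomes something like 's2_1' (suffix format assumed for illustration).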
@@ -45,8 +42,7 @@ def fmt(s, i): while s_ in uniques: i += 1 if max_attempts and i > max_attempts: - raise RecursionError( - f'cannot deduplicate {s} after {max_attempts} attempts') + raise RecursionError(f'cannot deduplicate {s} after {max_attempts} attempts') s_ = fmt(s, i) if s_ != s: diff --git a/hail/python/hail/utils/genomic_range_table.py b/hail/python/hail/utils/genomic_range_table.py index eb0690ecb92..bf7a503c1d4 100644 --- a/hail/python/hail/utils/genomic_range_table.py +++ b/hail/python/hail/utils/genomic_range_table.py @@ -1,18 +1,16 @@ +from typing import Optional from numpy import linspace -from typing import Optional import hail as hl -from .misc import check_nonnegative_and_in_range, check_positive_and_in_range + from ..genetics.reference_genome import reference_genome_type -from ..typecheck import typecheck, nullable +from ..typecheck import nullable, typecheck +from .misc import check_nonnegative_and_in_range, check_positive_and_in_range @typecheck(n=int, n_partitions=nullable(int), reference_genome=nullable(reference_genome_type)) -def genomic_range_table(n: int, - n_partitions: Optional[int] = None, - reference_genome='default' - ) -> 'hl.Table': +def genomic_range_table(n: int, n_partitions: Optional[int] = None, reference_genome='default') -> 'hl.Table': """Construct a table with a locus and no other fields. Examples @@ -57,15 +55,12 @@ def genomic_range_table(n: int, contexts=idx_bounds, partitions=[ hl.Interval(**{ - endpoint: hl.Struct( - locus=reference_genome.locus_from_global_position(idx) - ) for endpoint, idx in [('start', lo), ('end', hi)] + endpoint: hl.Struct(locus=reference_genome.locus_from_global_position(idx)) + for endpoint, idx in [('start', lo), ('end', hi)] }) for (lo, hi) in idx_bounds ], rowfn=lambda idx_range, _: hl.range(idx_range[0], idx_range[1]).map( - lambda idx: hl.struct( - locus=hl.locus_from_global_position(idx, reference_genome) - ) - ) + lambda idx: hl.struct(locus=hl.locus_from_global_position(idx, reference_genome)) + ), ) diff --git a/hail/python/hail/utils/hadoop_utils.py b/hail/python/hail/utils/hadoop_utils.py index f87203d5a2e..1920ce01f61 100644 --- a/hail/python/hail/utils/hadoop_utils.py +++ b/hail/python/hail/utils/hadoop_utils.py @@ -5,14 +5,11 @@ from typing import Any, Dict, List from hail.fs.hadoop_fs import HadoopFS -from hail.utils import local_path_uri +from hail.typecheck import enumeration, typecheck from hail.utils.java import Env, info -from hail.typecheck import typecheck, enumeration -@typecheck(path=str, - mode=enumeration('r', 'w', 'x', 'rb', 'wb', 'xb'), - buffer_size=int) +@typecheck(path=str, mode=enumeration('r', 'w', 'x', 'rb', 'wb', 'xb'), buffer_size=int) def hadoop_open(path: str, mode: str = 'r', buffer_size: int = 8192): """Open a file through the Hadoop filesystem API. Supports distributed file systems like hdfs, gs, and s3. @@ -99,8 +96,7 @@ def hadoop_open(path: str, mode: str = 'r', buffer_size: int = 8192): return file -@typecheck(src=str, - dest=str) +@typecheck(src=str, dest=str) def hadoop_copy(src, dest): """Copy a file through the Hadoop filesystem API. Supports distributed file systems like hdfs, gs, and s3. 
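For orientation, a short usage sketch of the two Hadoop helpers whose decorators are reformatted above; the bucket and file paths are hypothetical, and a Hail context must already be initialized.

import hail as hl

hl.init()

# read a remote text file through the Hadoop filesystem API
with hl.hadoop_open('gs://my-bucket/notes.txt', 'r') as f:
    first_line = f.readline()

# copy a local file into object storage
hl.hadoop_copy('file:///tmp/notes.txt', 'gs://my-bucket/notes.txt')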
@@ -287,12 +283,14 @@ def copy_log(path: str) -> None: ---------- path: :class:`str` """ + from hail.utils import local_path_uri + log = os.path.realpath(Env.hc()._log) try: if hadoop_is_dir(path): _, tail = os.path.split(log) path = os.path.join(path, tail) - info(f"copying log to {repr(path)}...") + info(f"copying log to {path!r}...") hadoop_copy(local_path_uri(log), path) except Exception as e: sys.stderr.write(f'Could not copy log: encountered error:\n {e}') diff --git a/hail/python/hail/utils/interval.py b/hail/python/hail/utils/interval.py index 0e08041e5c8..4a2f8e406d1 100644 --- a/hail/python/hail/utils/interval.py +++ b/hail/python/hail/utils/interval.py @@ -1,5 +1,5 @@ -from hail.typecheck import typecheck_method, lazy, nullable, anytype import hail as hl +from hail.typecheck import anytype, lazy, nullable, typecheck_method interval_type = lazy() @@ -33,14 +33,17 @@ class Interval(object): - :func:`.parse_locus_interval` """ - @typecheck_method(start=anytype, - end=anytype, - includes_start=bool, - includes_end=bool, - point_type=nullable(lambda: hl.expr.types.hail_type)) + @typecheck_method( + start=anytype, + end=anytype, + includes_start=bool, + includes_end=bool, + point_type=nullable(lambda: hl.expr.types.hail_type), + ) def __init__(self, start, end, includes_start=True, includes_end=False, point_type=None): if point_type is None: from hail.expr.expressions import impute_type, unify_types_limited + start_type = impute_type(start) end_type = impute_type(end) point_type = unify_types_limited(start_type, end_type) @@ -63,16 +66,21 @@ def __str__(self): return f'{open}{bounds}{close}' def __repr__(self): - return 'Interval(start={}, end={}, includes_start={}, includes_end={})'\ - .format(repr(self.start), repr(self.end), repr(self.includes_start), repr(self._includes_end)) + return 'Interval(start={}, end={}, includes_start={}, includes_end={})'.format( + repr(self.start), repr(self.end), repr(self.includes_start), repr(self._includes_end) + ) def __eq__(self, other): - return ( self._start == other._start and - self._end == other._end and - self._includes_start == other._includes_start and - self._includes_end == other._includes_end - ) if isinstance(other, Interval) else NotImplemented - + return ( + ( + self._start == other._start + and self._end == other._end + and self._includes_start == other._includes_start + and self._includes_end == other._includes_end + ) + if isinstance(other, Interval) + else NotImplemented + ) def __hash__(self): return hash(self._start) ^ hash(self._end) ^ hash(self._includes_start) ^ hash(self._includes_end) diff --git a/hail/python/hail/utils/java.py b/hail/python/hail/utils/java.py index 3f19d0d6e77..b876680807b 100644 --- a/hail/python/hail/utils/java.py +++ b/hail/python/hail/utils/java.py @@ -1,6 +1,6 @@ -from typing import Optional, Union -import sys import re +import sys +from typing import Optional, Union from hailtop.config import ConfigVariable, configuration_of @@ -27,10 +27,7 @@ def maybe_user_error(self, ir) -> Union['FatalError', HailUserError]: better_stack_trace = error_sources[0]._stack_trace error_message = str(self) - message_and_trace = (f'{error_message}\n' - '------------\n' - 'Hail stack trace:\n' - f'{better_stack_trace}') + message_and_trace = f'{error_message}\n' '------------\n' 'Hail stack trace:\n' f'{better_stack_trace}' return HailUserError(message_and_trace) @@ -63,11 +60,16 @@ def hc() -> 'hail.context.HailContext': sys.stderr.write("Initializing Hail with default parameters...\n") sys.stderr.flush() from 
..context import init + init() assert Env._hc is not None return Env._hc + @staticmethod + def is_fully_initialized() -> bool: + return Env._hc is not None + @staticmethod async def _async_hc() -> 'hail.context.HailContext': if not Env._hc: @@ -77,6 +79,7 @@ async def _async_hc() -> 'hail.context.HailContext': backend_name = choose_backend() if backend_name == 'service': from hail.context import init_batch + await init_batch() else: return Env.hc() @@ -90,22 +93,22 @@ def backend() -> 'hail.backend.Backend': @staticmethod def py4j_backend(op): from hail.backend.py4j_backend import Py4JBackend + b = Env.backend() if isinstance(b, Py4JBackend): return b else: - raise NotImplementedError( - f"{b.__class__.__name__} doesn't support {op}, only Py4JBackend") + raise NotImplementedError(f"{b.__class__.__name__} doesn't support {op}, only Py4JBackend") @staticmethod def spark_backend(op): from hail.backend.spark_backend import SparkBackend + b = Env.backend() if isinstance(b, SparkBackend): return b else: - raise NotImplementedError( - f"{b.__class__.__name__} doesn't support {op}, only SparkBackend") + raise NotImplementedError(f"{b.__class__.__name__} doesn't support {op}, only SparkBackend") @staticmethod def fs(): @@ -121,13 +124,14 @@ def spark_session(): def dummy_table(): if Env._dummy_table is None: import hail + Env._dummy_table = hail.utils.range_table(1, 1).key_by().cache() return Env._dummy_table @staticmethod def next_static_rng_uid(): result = Env._static_rng_uid - assert(result <= 0x7FFF_FFFF_FFFF_FFFF) + assert result <= 0x7FFF_FFFF_FFFF_FFFF Env._static_rng_uid += 1 return result diff --git a/hail/python/hail/utils/jsonx.py b/hail/python/hail/utils/jsonx.py index 4d56641226f..a5bfcc39098 100644 --- a/hail/python/hail/utils/jsonx.py +++ b/hail/python/hail/utils/jsonx.py @@ -1,15 +1,15 @@ -from typing import Any import json +from typing import Any + import numpy as np import pandas as pd - -from .frozendict import frozendict -from .struct import Struct -from .interval import Interval from ..genetics.locus import Locus from ..genetics.reference_genome import ReferenceGenome +from .frozendict import frozendict +from .interval import Interval from .misc import escape_str +from .struct import Struct class JSONEncoder(json.JSONEncoder): diff --git a/hail/python/hail/utils/linkedlist.py b/hail/python/hail/utils/linkedlist.py index d1a30ab4aab..93cb1375e60 100644 --- a/hail/python/hail/utils/linkedlist.py +++ b/hail/python/hail/utils/linkedlist.py @@ -42,15 +42,13 @@ def __iter__(self): return ListIterator(self.node) def __str__(self): - return f'''List({', '.join(str(x) for x in self)})''' + return f"""List({', '.join(str(x) for x in self)})""" def __repr__(self): - return f'''List({', '.join(repr(x) for x in self)})''' + return f"""List({', '.join(repr(x) for x in self)})""" def __eq__(self, other): - return list(self) == list(other) \ - if isinstance(other, LinkedList) \ - else NotImplemented + return list(self) == list(other) if isinstance(other, LinkedList) else NotImplemented def __ne__(self, other): return not self.__eq__(other) diff --git a/hail/python/hail/utils/misc.py b/hail/python/hail/utils/misc.py index a2630e277f2..5d6de2885dc 100644 --- a/hail/python/hail/utils/misc.py +++ b/hail/python/hail/utils/misc.py @@ -8,15 +8,15 @@ import shutil import string import tempfile -from collections import defaultdict, Counter +from collections import Counter, defaultdict from contextlib import contextmanager from io import StringIO -from typing import Optional, Literal +from typing import 
Literal, Optional from urllib.parse import urlparse import hail import hail as hl -from hail.typecheck import enumeration, typecheck, nullable +from hail.typecheck import enumeration, nullable, typecheck from hail.utils.java import Env, error @@ -64,17 +64,19 @@ def range_matrix_table(n_rows, n_cols, n_partitions=None) -> 'hail.MatrixTable': check_nonnegative_and_in_range('range_matrix_table', 'n_cols', n_cols) if n_partitions is not None: check_positive_and_in_range('range_matrix_table', 'n_partitions', n_partitions) - return hail.MatrixTable(hail.ir.MatrixRead( - hail.ir.MatrixRangeReader(n_rows, n_cols, n_partitions), - _assert_type=hl.tmatrix( - hl.tstruct(), - hl.tstruct(col_idx=hl.tint32), - ['col_idx'], - hl.tstruct(row_idx=hl.tint32), - ['row_idx'], - hl.tstruct() + return hail.MatrixTable( + hail.ir.MatrixRead( + hail.ir.MatrixRangeReader(n_rows, n_cols, n_partitions), + _assert_type=hl.tmatrix( + hl.tstruct(), + hl.tstruct(col_idx=hl.tint32), + ['col_idx'], + hl.tstruct(row_idx=hl.tint32), + ['row_idx'], + hl.tstruct(), + ), ) - )) + ) @typecheck(n=int, n_partitions=nullable(int)) @@ -120,16 +122,18 @@ def check_positive_and_in_range(caller, name, value): if value <= 0: raise ValueError(f"'{caller}': parameter '{name}' must be positive, found {value}") elif value > hail.tint32.max_value: - raise ValueError(f"'{caller}': parameter '{name}' must be less than or equal to {hail.tint32.max_value}, " - f"found {value}") + raise ValueError( + f"'{caller}': parameter '{name}' must be less than or equal to {hail.tint32.max_value}, " f"found {value}" + ) def check_nonnegative_and_in_range(caller, name, value): if value < 0: raise ValueError(f"'{caller}': parameter '{name}' must be non-negative, found {value}") elif value > hail.tint32.max_value: - raise ValueError(f"'{caller}': parameter '{name}' must be less than or equal to {hail.tint32.max_value}, " - f"found {value}") + raise ValueError( + f"'{caller}': parameter '{name}' must be less than or equal to {hail.tint32.max_value}, " f"found {value}" + ) def wrap_to_list(s): @@ -145,7 +149,7 @@ def wrap_to_tuple(x): if isinstance(x, tuple): return x else: - return x, + return (x,) def wrap_to_sequence(x): @@ -154,7 +158,7 @@ def wrap_to_sequence(x): if isinstance(x, list): return tuple(x) else: - return x, + return (x,) def get_env_or_default(maybe, envvar, default): @@ -208,14 +212,25 @@ def with_local_temp_file(filename: str = 'temp') -> str: pass -storage_level = enumeration('NONE', 'DISK_ONLY', 'DISK_ONLY_2', 'MEMORY_ONLY', - 'MEMORY_ONLY_2', 'MEMORY_ONLY_SER', 'MEMORY_ONLY_SER_2', - 'MEMORY_AND_DISK', 'MEMORY_AND_DISK_2', 'MEMORY_AND_DISK_SER', - 'MEMORY_AND_DISK_SER_2', 'OFF_HEAP') +storage_level = enumeration( + 'NONE', + 'DISK_ONLY', + 'DISK_ONLY_2', + 'MEMORY_ONLY', + 'MEMORY_ONLY_2', + 'MEMORY_ONLY_SER', + 'MEMORY_ONLY_SER_2', + 'MEMORY_AND_DISK', + 'MEMORY_AND_DISK_2', + 'MEMORY_AND_DISK_SER', + 'MEMORY_AND_DISK_SER_2', + 'OFF_HEAP', +) def run_command(args): import subprocess as sp + try: sp.check_output(args, stderr=sp.STDOUT) except sp.CalledProcessError as e: @@ -241,10 +256,10 @@ def plural(orig, n, alternate=None): def get_obj_metadata(obj): - from hail.matrixtable import MatrixTable, GroupedMatrixTable - from hail.table import Table, GroupedTable + from hail.expr.expressions import ArrayStructExpression, SetStructExpression, StructExpression + from hail.matrixtable import GroupedMatrixTable, MatrixTable + from hail.table import GroupedTable, Table from hail.utils import Struct - from hail.expr.expressions import 
StructExpression, ArrayStructExpression, SetStructExpression def table_error(index_obj): def fmt_field(field): @@ -259,12 +274,14 @@ def fmt_field(field): else: assert inds == index_obj._entry_indices return "'{}' [entry]".format(field) + return fmt_field def struct_error(s): def fmt_field(field): assert field in s._fields return "'{}'".format(field) + return fmt_field if isinstance(obj, MatrixTable): @@ -322,16 +339,21 @@ def get_nice_attr_error(obj, item): s.append('\n Data {}: {}'.format(word, ', '.join(handler(f) for f in fs))) if method_matches: word = plural('method', len(method_matches)) - s.append('\n {} {}: {}'.format(class_name, word, - ', '.join("'{}'".format(m) for m in method_matches))) + s.append( + '\n {} {}: {}'.format(class_name, word, ', '.join("'{}'".format(m) for m in method_matches)) + ) if prop_matches: word = plural('property', len(prop_matches), 'properties') - s.append('\n {} {}: {}'.format(class_name, word, - ', '.join("'{}'".format(p) for p in prop_matches))) + s.append( + '\n {} {}: {}'.format(class_name, word, ', '.join("'{}'".format(p) for p in prop_matches)) + ) if inherited_matches: word = plural('inherited method', len(inherited_matches)) - s.append('\n {} {}: {}'.format(class_name, word, - ', '.join("'{}'".format(m) for m in inherited_matches))) + s.append( + '\n {} {}: {}'.format( + class_name, word, ', '.join("'{}'".format(m) for m in inherited_matches) + ) + ) elif has_describe: s.append("\n Hint: use 'describe()' to show the names of all data fields.") return ''.join(s) @@ -362,12 +384,16 @@ def get_nice_field_error(obj, item): def check_collisions(caller, names, indices, override_protected_indices=None): from hail.expr.expressions import ExpressionException + fields = indices.source._fields if override_protected_indices is not None: + def invalid(e): return e._indices in override_protected_indices + else: + def invalid(e): return e._indices != indices @@ -382,11 +408,13 @@ def invalid(e): for k, v in Counter(names).items(): if v > 1: from hail.expr.expressions import ExpressionException + raise ExpressionException(f"{caller!r}: selection would produce duplicate field {k!r}") def get_key_by_exprs(caller, exprs, named_exprs, indices, override_protected_indices=None): - from hail.expr.expressions import to_expr, ExpressionException, analyze + from hail.expr.expressions import ExpressionException, analyze, to_expr + exprs = [indices.source[e] if isinstance(e, str) else e for e in exprs] named_exprs = {k: to_expr(v) for k, v in named_exprs.items()} @@ -400,11 +428,13 @@ def is_top_level_field(e): for e in exprs: analyze(caller, e, indices, broadcast=False) if not e._ir.is_nested_field: - raise ExpressionException(f"{caller!r} expects keyword arguments for complex expressions\n" - f" Correct: ht = ht.key_by('x')\n" - f" Correct: ht = ht.key_by(ht.x)\n" - f" Correct: ht = ht.key_by(x = ht.x.replace(' ', '_'))\n" - f" INCORRECT: ht = ht.key_by(ht.x.replace(' ', '_'))") + raise ExpressionException( + f"{caller!r} expects keyword arguments for complex expressions\n" + f" Correct: ht = ht.key_by('x')\n" + f" Correct: ht = ht.key_by(ht.x)\n" + f" Correct: ht = ht.key_by(x = ht.x.replace(' ', '_'))\n" + f" INCORRECT: ht = ht.key_by(ht.x.replace(' ', '_'))" + ) name = e._ir.name final_key.append(name) @@ -422,15 +452,19 @@ def is_top_level_field(e): def check_keys(caller, name, protected_key): from hail.expr.expressions import ExpressionException + if name in protected_key: - msg = f"{caller!r}: cannot overwrite key field {name!r} with annotate, select or drop; " \ 
- f"use key_by to modify keys." + msg = ( + f"{caller!r}: cannot overwrite key field {name!r} with annotate, select or drop; " + f"use key_by to modify keys." + ) error('Analysis exception: {}'.format(msg)) raise ExpressionException(msg) def get_select_exprs(caller, exprs, named_exprs, indices, base_struct): - from hail.expr.expressions import to_expr, ExpressionException, analyze + from hail.expr.expressions import ExpressionException, analyze, to_expr + exprs = [indices.source[e] if isinstance(e, str) else e for e in exprs] named_exprs = {k: to_expr(v) for k, v in named_exprs.items()} select_fields = indices.protected_key[:] @@ -444,11 +478,13 @@ def is_top_level_field(e): for e in exprs: if not e._ir.is_nested_field: - raise ExpressionException(f"{caller!r} expects keyword arguments for complex expressions\n" - f" Correct: ht = ht.select('x')\n" - f" Correct: ht = ht.select(ht.x)\n" - f" Correct: ht = ht.select(x = ht.x.replace(' ', '_'))\n" - f" INCORRECT: ht = ht.select(ht.x.replace(' ', '_'))") + raise ExpressionException( + f"{caller!r} expects keyword arguments for complex expressions\n" + f" Correct: ht = ht.select('x')\n" + f" Correct: ht = ht.select(ht.x)\n" + f" Correct: ht = ht.select(x = ht.x.replace(' ', '_'))\n" + f" INCORRECT: ht = ht.select(ht.x.replace(' ', '_'))" + ) analyze(caller, e, indices, broadcast=False) name = e._ir.name @@ -477,6 +513,7 @@ def is_top_level_field(e): def check_annotate_exprs(caller, named_exprs, indices, agg_axes): from hail.expr.expressions import analyze + protected_key = set(indices.protected_key) for k, v in named_exprs.items(): analyze(f'{caller}: field {k!r}', v, indices, agg_axes, broadcast=True) @@ -506,8 +543,9 @@ def cleanup(table): def divide_null(num, denom): + from hail.expr import if_else, missing from hail.expr.expressions.base_expression import unify_types_limited - from hail.expr import missing, if_else + typ = unify_types_limited(num.dtype, denom.dtype) assert typ is not None return if_else(denom != 0, num / denom, missing(typ)) @@ -518,10 +556,7 @@ def lookup_bit(byte, which_bit): def timestamp_path(base, suffix=''): - return ''.join([base, - '-', - datetime.datetime.now().strftime("%Y%m%d-%H%M"), - suffix]) + return ''.join([base, '-', datetime.datetime.now().strftime("%Y%m%d-%H%M"), suffix]) def upper_hex(n, num_digits=None): @@ -534,41 +569,33 @@ def upper_hex(n, num_digits=None): def escape_str(s, backticked=False): sb = StringIO() - rewrite_dict = { - '\b': '\\b', - '\n': '\\n', - '\t': '\\t', - '\f': '\\f', - '\r': '\\r' - } + rewrite_dict = {'\b': '\\b', '\n': '\\n', '\t': '\\t', '\f': '\\f', '\r': '\\r'} for ch in s: chNum = ord(ch) - if chNum > 0x7f: + if chNum > 0x7F: sb.write("\\u" + upper_hex(chNum, 4)) elif chNum < 32: if ch in rewrite_dict: sb.write(rewrite_dict[ch]) + elif chNum > 0xF: + sb.write("\\u00" + upper_hex(chNum)) else: - if chNum > 0xf: - sb.write("\\u00" + upper_hex(chNum)) - else: - sb.write("\\u000" + upper_hex(chNum)) - else: - if ch == '"': - if backticked: - sb.write('"') - else: - sb.write('\\\"') - elif ch == '`': - if backticked: - sb.write("\\`") - else: - sb.write("`") - elif ch == '\\': - sb.write('\\\\') + sb.write("\\u000" + upper_hex(chNum)) + elif ch == '"': + if backticked: + sb.write('"') + else: + sb.write('\\"') + elif ch == '`': + if backticked: + sb.write("\\`") else: - sb.write(ch) + sb.write("`") + elif ch == '\\': + sb.write('\\\\') + else: + sb.write(ch) escaped = sb.getvalue() sb.close() @@ -590,20 +617,22 @@ def parsable_strings(strs): def _dumps_partitions(partitions, 
row_key_type): parts_type = partitions.dtype - if not (isinstance(parts_type, hl.tarray) - and isinstance(parts_type.element_type, hl.tinterval)): + if not (isinstance(parts_type, hl.tarray) and isinstance(parts_type.element_type, hl.tinterval)): raise ValueError(f'partitions type invalid: {parts_type} must be array of intervals') point_type = parts_type.element_type.point_type f1, t1 = next(iter(row_key_type.items())) if point_type == t1: - partitions = hl.map(lambda x: hl.interval( - start=hl.struct(**{f1: x.start}), - end=hl.struct(**{f1: x.end}), - includes_start=x.includes_start, - includes_end=x.includes_end), - partitions) + partitions = hl.map( + lambda x: hl.interval( + start=hl.struct(**{f1: x.start}), + end=hl.struct(**{f1: x.end}), + includes_start=x.includes_start, + includes_end=x.includes_end, + ), + partitions, + ) else: if not isinstance(point_type, hl.tstruct): raise ValueError(f'partitions has wrong type: {point_type} must be struct or type of first row key field') @@ -617,6 +646,7 @@ def _dumps_partitions(partitions, row_key_type): def default_handler(): try: from IPython.display import display + return display except ImportError: return print @@ -633,10 +663,13 @@ def guess_cloud_spark_provider() -> Optional[Literal['dataproc', 'hdinsight']]: def no_service_backend(unsupported_feature): from hail import current_backend from hail.backend.service_backend import ServiceBackend + if isinstance(current_backend(), ServiceBackend): - raise NotImplementedError(f'{unsupported_feature!r} is not yet supported on the service backend.' - f'\n If this is a pressing need, please alert the team on the discussion' - f'\n forum to aid in prioritization: https://discuss.hail.is') + raise NotImplementedError( + f'{unsupported_feature!r} is not yet supported on the service backend.' 
+ f'\n If this is a pressing need, please alert the team on the discussion' + f'\n forum to aid in prioritization: https://discuss.hail.is' + ) ANY_REGION = ['any_region'] diff --git a/hail/python/hail/utils/struct.py b/hail/python/hail/utils/struct.py index 658c25175f9..f7a19efa1c4 100644 --- a/hail/python/hail/utils/struct.py +++ b/hail/python/hail/utils/struct.py @@ -1,10 +1,10 @@ -from typing import Dict, Any +import pprint from collections import OrderedDict from collections.abc import Mapping -import pprint +from typing import Any, Dict +from hail.typecheck import anytype, typecheck, typecheck_method from hail.utils.misc import get_nice_attr_error, get_nice_field_error -from hail.typecheck import typecheck, typecheck_method, anytype class Struct(Mapping): @@ -89,21 +89,11 @@ def __repr__(self): def __str__(self): if all(k.isidentifier() for k in self._fields): - return ( - 'Struct(' - + ', '.join(f'{k}={repr(v)}' for k, v in self._fields.items()) - + ')' - ) - return ( - 'Struct(**{' - + ', '.join(f'{repr(k)}: {repr(v)}' for k, v in self._fields.items()) - + '})' - ) + return 'Struct(' + ', '.join(f'{k}={v!r}' for k, v in self._fields.items()) + ')' + return 'Struct(**{' + ', '.join(f'{k!r}: {v!r}' for k, v in self._fields.items()) + '})' def __eq__(self, other): - return self._fields == other._fields \ - if isinstance(other, Struct) \ - else NotImplemented + return self._fields == other._fields if isinstance(other, Struct) else NotImplemented def __hash__(self): return 37 + hash(tuple(sorted(self._fields.items()))) diff --git a/hail/python/hail/utils/tutorial.py b/hail/python/hail/utils/tutorial.py index 9bf58d18f3c..4c59372654a 100644 --- a/hail/python/hail/utils/tutorial.py +++ b/hail/python/hail/utils/tutorial.py @@ -1,16 +1,14 @@ -import hail as hl -from .java import Env, info -from .misc import new_temp_file, local_path_uri, new_local_temp_dir import os import zipfile from urllib.request import urlretrieve + +import hail as hl from hailtop.utils import sync_retry_transient_errors -__all__ = [ - 'get_1kg', - 'get_hgdp', - 'get_movie_lens' -] +from .java import Env, info +from .misc import local_path_uri, new_local_temp_dir, new_temp_file + +__all__ = ['get_1kg', 'get_hgdp', 'get_movie_lens'] resources = { '1kg_annotations': 'https://storage.googleapis.com/hail-tutorial/1kg_annotations.txt', @@ -70,16 +68,17 @@ def get_1kg(output_dir, overwrite: bool = False): sample_annotations_path = os.path.join(output_dir, '1kg_annotations.txt') gene_annotations_path = os.path.join(output_dir, 'ensembl_gene_annotations.txt') - if (overwrite - or not _dir_exists(fs, matrix_table_path) - or not _file_exists(fs, sample_annotations_path) - or not _file_exists(fs, vcf_path) - or not _file_exists(fs, gene_annotations_path)): + if ( + overwrite + or not _dir_exists(fs, matrix_table_path) + or not _file_exists(fs, sample_annotations_path) + or not _file_exists(fs, vcf_path) + or not _file_exists(fs, gene_annotations_path) + ): init_temp_dir() tmp_vcf = os.path.join(tmp_dir, '1kg.vcf.bgz') source = resources['1kg_matrix_table'] - info(f'downloading 1KG VCF ...\n' - f' Source: {source}') + info(f'downloading 1KG VCF ...\n' f' Source: {source}') sync_retry_transient_errors(urlretrieve, resources['1kg_matrix_table'], tmp_vcf) cluster_readable_vcf = _copy_to_tmp(fs, local_path_uri(tmp_vcf), extension='vcf.bgz') info('importing VCF and writing to matrix table...') @@ -87,14 +86,12 @@ def get_1kg(output_dir, overwrite: bool = False): tmp_sample_annot = os.path.join(tmp_dir, '1kg_annotations.txt') source = 
resources['1kg_annotations'] - info(f'downloading 1KG annotations ...\n' - f' Source: {source}') + info(f'downloading 1KG annotations ...\n' f' Source: {source}') sync_retry_transient_errors(urlretrieve, source, tmp_sample_annot) tmp_gene_annot = os.path.join(tmp_dir, 'ensembl_gene_annotations.txt') source = resources['1kg_ensembl_gene_annotations'] - info(f'downloading Ensembl gene annotations ...\n' - f' Source: {source}') + info(f'downloading Ensembl gene annotations ...\n' f' Source: {source}') sync_retry_transient_errors(urlretrieve, source, tmp_gene_annot) hl.hadoop_copy(local_path_uri(tmp_sample_annot), sample_annotations_path) @@ -131,31 +128,32 @@ def get_hgdp(output_dir, overwrite: bool = False): sample_annotations_path = os.path.join(output_dir, 'HGDP_annotations.txt') gene_annotations_path = os.path.join(output_dir, 'ensembl_gene_annotations.txt') - if (overwrite - or not _dir_exists(fs, matrix_table_path) - or not _file_exists(fs, sample_annotations_path) - or not _file_exists(fs, vcf_path) - or not _file_exists(fs, gene_annotations_path)): + if ( + overwrite + or not _dir_exists(fs, matrix_table_path) + or not _file_exists(fs, sample_annotations_path) + or not _file_exists(fs, vcf_path) + or not _file_exists(fs, gene_annotations_path) + ): init_temp_dir() tmp_vcf = os.path.join(tmp_dir, 'HGDP.vcf.bgz') source = resources['HGDP_matrix_table'] - info(f'downloading HGDP VCF ...\n' - f' Source: {source}') + info(f'downloading HGDP VCF ...\n' f' Source: {source}') sync_retry_transient_errors(urlretrieve, resources['HGDP_matrix_table'], tmp_vcf) cluster_readable_vcf = _copy_to_tmp(fs, local_path_uri(tmp_vcf), extension='vcf.bgz') info('importing VCF and writing to matrix table...') - hl.import_vcf(cluster_readable_vcf, min_partitions=16, reference_genome='GRCh38').write(matrix_table_path, overwrite=True) + hl.import_vcf(cluster_readable_vcf, min_partitions=16, reference_genome='GRCh38').write( + matrix_table_path, overwrite=True + ) tmp_sample_annot = os.path.join(tmp_dir, 'HGDP_annotations.txt') source = resources['HGDP_annotations'] - info(f'downloading HGDP annotations ...\n' - f' Source: {source}') + info(f'downloading HGDP annotations ...\n' f' Source: {source}') sync_retry_transient_errors(urlretrieve, source, tmp_sample_annot) tmp_gene_annot = os.path.join(tmp_dir, 'ensembl_gene_annotations.txt') source = resources['HGDP_ensembl_gene_annotations'] - info(f'downloading Ensembl gene annotations ...\n' - f' Source: {source}') + info(f'downloading Ensembl gene annotations ...\n' f' Source: {source}') sync_retry_transient_errors(urlretrieve, source, tmp_gene_annot) hl.hadoop_copy(local_path_uri(tmp_sample_annot), sample_annotations_path) @@ -194,8 +192,7 @@ def get_movie_lens(output_dir, overwrite: bool = False): init_temp_dir() source = resources['movie_lens_100k'] tmp_path = os.path.join(tmp_dir, 'ml-100k.zip') - info(f'downloading MovieLens-100k data ...\n' - f' Source: {source}') + info(f'downloading MovieLens-100k data ...\n' f' Source: {source}') sync_retry_transient_errors(urlretrieve, source, tmp_path) with zipfile.ZipFile(tmp_path, 'r') as z: z.extractall(tmp_dir) @@ -203,9 +200,9 @@ def get_movie_lens(output_dir, overwrite: bool = False): user_table_path = os.path.join(tmp_dir, 'ml-100k', 'u.user') movie_table_path = os.path.join(tmp_dir, 'ml-100k', 'u.item') ratings_table_path = os.path.join(tmp_dir, 'ml-100k', 'u.data') - assert (os.path.exists(user_table_path)) - assert (os.path.exists(movie_table_path)) - assert (os.path.exists(ratings_table_path)) + assert 
os.path.exists(user_table_path) + assert os.path.exists(movie_table_path) + assert os.path.exists(ratings_table_path) user_cluster_readable = _copy_to_tmp(fs, local_path_uri(user_table_path), extension='txt') movie_cluster_readable = _copy_to_tmp(fs, local_path_uri(movie_table_path), 'txt') @@ -213,12 +210,26 @@ def get_movie_lens(output_dir, overwrite: bool = False): [movies_path, ratings_path, users_path] = paths - genres = ['Action', 'Adventure', 'Animation', - "Children's", 'Comedy', 'Crime', - 'Documentary', 'Drama', 'Fantasy', - 'Film-Noir', 'Horror', 'Musical', - 'Mystery', 'Romance', 'Sci-Fi', - 'Thriller', 'War', 'Western'] + genres = [ + 'Action', + 'Adventure', + 'Animation', + "Children's", + 'Comedy', + 'Crime', + 'Documentary', + 'Drama', + 'Fantasy', + 'Film-Noir', + 'Horror', + 'Musical', + 'Mystery', + 'Romance', + 'Sci-Fi', + 'Thriller', + 'War', + 'Western', + ] # utility functions for importing movies def field_to_array(ds, field): @@ -234,14 +245,16 @@ def rename_columns(ht, new_names): users = rename_columns( hl.import_table(user_cluster_readable, key=['f0'], no_header=True, impute=True, delimiter='|'), - ['id', 'age', 'sex', 'occupation', 'zipcode']) + ['id', 'age', 'sex', 'occupation', 'zipcode'], + ) users.write(users_path, overwrite=True) info(f'importing movies table and writing to {movies_path} ...') movies = hl.import_table(movie_cluster_readable, key=['f0'], no_header=True, impute=True, delimiter='|') - movies = rename_columns(movies, - ['id', 'title', 'release date', 'video release date', 'IMDb URL', 'unknown'] + genres) + movies = rename_columns( + movies, ['id', 'title', 'release date', 'video release date', 'IMDb URL', 'unknown', *genres] + ) movies = movies.drop('release date', 'video release date', 'unknown', 'IMDb URL') movies = movies.transmute(genres=fields_to_array(movies, genres)) movies.write(movies_path, overwrite=True) @@ -249,8 +262,7 @@ def rename_columns(ht, new_names): info(f'importing ratings table and writing to {ratings_path} ...') ratings = hl.import_table(ratings_cluster_readable, no_header=True, impute=True) - ratings = rename_columns(ratings, - ['user_id', 'movie_id', 'rating', 'timestamp']) + ratings = rename_columns(ratings, ['user_id', 'movie_id', 'rating', 'timestamp']) ratings = ratings.drop('timestamp') ratings.write(ratings_path, overwrite=True) diff --git a/hail/python/hail/vds/__init__.py b/hail/python/hail/vds/__init__.py index 14a5940bc45..754b7cb58c7 100644 --- a/hail/python/hail/vds/__init__.py +++ b/hail/python/hail/vds/__init__.py @@ -1,11 +1,24 @@ from . 
import combiner +from .combiner import load_combiner, new_combiner from .functions import lgt_to_gt, local_to_global -from .methods import filter_intervals, filter_samples, filter_variants, sample_qc, split_multi, to_dense_mt, \ - to_merged_sparse_mt, segment_reference_blocks, write_variant_datasets, interval_coverage, \ - impute_sex_chr_ploidy_from_interval_coverage, impute_sex_chromosome_ploidy, filter_chromosomes, \ - truncate_reference_blocks, merge_reference_blocks +from .methods import ( + filter_chromosomes, + filter_intervals, + filter_samples, + filter_variants, + impute_sex_chr_ploidy_from_interval_coverage, + impute_sex_chromosome_ploidy, + interval_coverage, + merge_reference_blocks, + segment_reference_blocks, + split_multi, + to_dense_mt, + to_merged_sparse_mt, + truncate_reference_blocks, + write_variant_datasets, +) +from .sample_qc import sample_qc from .variant_dataset import VariantDataset, read_vds, store_ref_block_max_length -from .combiner import load_combiner, new_combiner __all__ = [ 'VariantDataset', diff --git a/hail/python/hail/vds/combiner/__init__.py b/hail/python/hail/vds/combiner/__init__.py index 0b3443ed407..73a21fa954f 100644 --- a/hail/python/hail/vds/combiner/__init__.py +++ b/hail/python/hail/vds/combiner/__init__.py @@ -1,5 +1,5 @@ -from .combine import transform_gvcf, combine_variant_datasets -from .variant_dataset_combiner import new_combiner, load_combiner, VariantDatasetCombiner, VDSMetadata +from .combine import combine_variant_datasets, transform_gvcf +from .variant_dataset_combiner import VariantDatasetCombiner, VDSMetadata, load_combiner, new_combiner __all__ = [ 'combine_variant_datasets', @@ -7,5 +7,5 @@ 'new_combiner', 'load_combiner', 'VariantDatasetCombiner', - 'VDSMetadata' + 'VDSMetadata', ] diff --git a/hail/python/hail/vds/combiner/combine.py b/hail/python/hail/vds/combiner/combine.py index 780e8f9123f..ea3bd047d35 100644 --- a/hail/python/hail/vds/combiner/combine.py +++ b/hail/python/hail/vds/combiner/combine.py @@ -1,24 +1,27 @@ import math -from typing import Collection, Optional, Set, Union, List, Tuple, Dict +from typing import Collection, Dict, List, Optional, Set, Tuple, Union import hail as hl -from hail import MatrixTable, Table from hail.experimental.function import Function -from hail.expr import StructExpression, unify_all, construct_expr +from hail.expr import BooleanExpression, StructExpression, construct_expr, unify_all from hail.expr.expressions import expr_bool, expr_str +from hail.expr.functions import numeric_allele_type +from hail.expr.types import HailType +from hail.genetics.allele_type import AlleleType from hail.genetics.reference_genome import reference_genome_type from hail.ir import Apply, TableMapRows +from hail.matrixtable import MatrixTable +from hail.table import Table from hail.typecheck import oneof, sequenceof, typecheck + from ..variant_dataset import VariantDataset -_transform_variant_function_map: Dict[Tuple[hl.HailType, Tuple[str, ...]], Function] = {} -_transform_reference_fuction_map: Dict[Tuple[hl.HailType, Tuple[str, ...]], Function] = {} -_merge_function_map: Dict[Tuple[hl.HailType, hl.HailType], Function] = {} +_transform_variant_function_map: Dict[Tuple[HailType, Tuple[str, ...]], Function] = {} +_transform_reference_fuction_map: Dict[Tuple[HailType, Tuple[str, ...]], Function] = {} +_merge_function_map: Dict[Tuple[HailType, HailType], Function] = {} -def make_variants_matrix_table(mt: MatrixTable, - info_to_keep: Optional[Collection[str]] = None - ) -> MatrixTable: +def 
make_variants_matrix_table(mt: MatrixTable, info_to_keep: Optional[Collection[str]] = None) -> MatrixTable: if info_to_keep is None: info_to_keep = [] if not info_to_keep: @@ -29,24 +32,22 @@ def make_variants_matrix_table(mt: MatrixTable, transform_row = _transform_variant_function_map.get((mt.row.dtype, info_key)) if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name): + def get_lgt(gt, n_alleles, has_non_ref, row): index = gt.unphase().unphased_diploid_gt_index() n_no_nonref = n_alleles - hl.int(has_non_ref) triangle_without_nonref = hl.triangle(n_no_nonref) - return (hl.case() - .when(gt.is_haploid(), - hl.or_missing(gt[0] < n_no_nonref, gt)) - .when(index < triangle_without_nonref, gt) - .when(index < hl.triangle(n_alleles), hl.missing('call')) - .or_error('invalid call ' + hl.str(gt) + ' at site ' + hl.str(row.locus))) + return ( + hl.case() + .when(gt.is_haploid(), hl.or_missing(gt[0] < n_no_nonref, gt)) + .when(index < triangle_without_nonref, gt) + .when(index < hl.triangle(n_alleles), hl.missing('call')) + .or_error('invalid call ' + hl.str(gt) + ' at site ' + hl.str(row.locus)) + ) def make_entry_struct(e, alleles_len, has_non_ref, row): handled_fields = dict() - handled_names = {'LA', 'gvcf_info', - 'LAD', 'AD', - 'LGT', 'GT', - 'LPL', 'PL', - 'LPGT', 'PGT'} + handled_names = {'LA', 'gvcf_info', 'LAD', 'AD', 'LGT', 'GT', 'LPL', 'PL', 'LPGT', 'PGT'} if 'GT' not in e: raise hl.utils.FatalError("the Hail VDS combiner expects input GVCFs to have a 'GT' field in FORMAT.") @@ -56,44 +57,50 @@ def make_entry_struct(e, alleles_len, has_non_ref, row): if 'AD' in e: handled_fields['LAD'] = hl.if_else(has_non_ref, e.AD[:-1], e.AD) if 'PGT' in e: - handled_fields['LPGT'] = e.PGT if e.PGT.dtype != hl.tcall \ - else get_lgt(e.PGT, alleles_len, has_non_ref, row) + handled_fields['LPGT'] = ( + e.PGT if e.PGT.dtype != hl.tcall else get_lgt(e.PGT, alleles_len, has_non_ref, row) + ) if 'PL' in e: - handled_fields['LPL'] = hl.if_else(has_non_ref, - hl.if_else(alleles_len > 2, - e.PL[:-alleles_len], - hl.missing(e.PL.dtype)), - hl.if_else(alleles_len > 1, - e.PL, - hl.missing(e.PL.dtype))) + handled_fields['LPL'] = hl.if_else( + has_non_ref, + hl.if_else(alleles_len > 2, e.PL[:-alleles_len], hl.missing(e.PL.dtype)), + hl.if_else(alleles_len > 1, e.PL, hl.missing(e.PL.dtype)), + ) handled_fields['RGQ'] = hl.if_else( has_non_ref, - hl.if_else(e.GT.is_haploid(), - e.PL[alleles_len - 1], - e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()]), - hl.missing(e.PL.dtype.element_type)) - - handled_fields['gvcf_info'] = (hl.case() - .when(hl.is_missing(row.info.END), - parse_allele_specific_fields( - row.info.select(*info_to_keep), - has_non_ref - )) - .or_missing()) + hl.if_else( + e.GT.is_haploid(), + e.PL[alleles_len - 1], + e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()], + ), + hl.missing(e.PL.dtype.element_type), + ) + + handled_fields['gvcf_info'] = ( + hl.case() + .when( + hl.is_missing(row.info.END), + parse_allele_specific_fields(row.info.select(*info_to_keep), has_non_ref), + ) + .or_missing() + ) pass_through_fields = {k: v for k, v in e.items() if k not in handled_names} return hl.struct(**handled_fields, **pass_through_fields) transform_row = hl.experimental.define_function( lambda row: hl.rbind( - hl.len(row.alleles), '' == row.alleles[-1], + hl.len(row.alleles), + '' == row.alleles[-1], lambda alleles_len, has_non_ref: hl.struct( locus=row.locus, alleles=hl.if_else(has_non_ref, row.alleles[:-1], row.alleles), 
**({'rsid': row.rsid} if 'rsid' in row else {}), - __entries=row.__entries.map( - lambda e: make_entry_struct(e, alleles_len, has_non_ref, row)))), - mt.row.dtype) + __entries=row.__entries.map(lambda e: make_entry_struct(e, alleles_len, has_non_ref, row)), + ), + ), + mt.row.dtype, + ) _transform_variant_function_map[mt.row.dtype, info_key] = transform_row return unlocalize(Table(TableMapRows(mt._tir, Apply(transform_row._name, transform_row._ret_type, mt.row._ir)))) @@ -101,9 +108,7 @@ def make_entry_struct(e, alleles_len, has_non_ref, row): def defined_entry_fields(mt: MatrixTable, sample=None) -> Set[str]: if sample is not None: mt = mt.head(sample) - used = mt.aggregate_entries(hl.struct(**{ - k: hl.agg.any(hl.is_defined(v)) for k, v in mt.entry.items() - })) + used = mt.aggregate_entries(hl.struct(**{k: hl.agg.any(hl.is_defined(v)) for k, v in mt.entry.items()})) return set(k for k in mt.entry if used[k]) @@ -122,28 +127,25 @@ def make_entry_struct(e, row): if 'PL' in entry_to_keep: handled_fields['LPL'] = e['PL'][:1] - reference_fields = {k: v for k, v in e.items() - if k in entry_to_keep and k not in handled_names} - return (hl.case() - .when(e.GT.is_hom_ref(), - hl.struct(END=row.info.END, **reference_fields, **handled_fields)) - .or_error('found END with non reference-genotype at' + hl.str(row.locus))) + reference_fields = {k: v for k, v in e.items() if k in entry_to_keep and k not in handled_names} + return ( + hl.case() + .when(e.GT.is_hom_ref(), hl.struct(END=row.info.END, **reference_fields, **handled_fields)) + .or_error('found END with non reference-genotype at' + hl.str(row.locus)) + ) row_type = stream.dtype.element_type transform_row = _transform_reference_fuction_map.get((row_type, entry_key)) if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name): transform_row = hl.experimental.define_function( - lambda row: hl.struct( - locus=row.locus, - __entries=row.__entries.map( - lambda e: make_entry_struct(e, row))), - row_type) + lambda row: hl.struct(locus=row.locus, __entries=row.__entries.map(lambda e: make_entry_struct(e, row))), + row_type, + ) _transform_reference_fuction_map[row_type, entry_key] = transform_row - return stream.map(lambda row: hl.struct( - locus=row.locus, - __entries=row.__entries.map( - lambda e: make_entry_struct(e, row)))) + return stream.map( + lambda row: hl.struct(locus=row.locus, __entries=row.__entries.map(lambda e: make_entry_struct(e, row))) + ) def make_variant_stream(stream, info_to_keep): @@ -159,24 +161,22 @@ def make_variant_stream(stream, info_to_keep): transform_row = _transform_variant_function_map.get((row_type, info_key)) if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name): + def get_lgt(e, n_alleles, has_non_ref, row): index = e.GT.unphased_diploid_gt_index() n_no_nonref = n_alleles - hl.int(has_non_ref) triangle_without_nonref = hl.triangle(n_no_nonref) - return (hl.case() - .when(e.GT.is_haploid(), - hl.or_missing(e.GT[0] < n_no_nonref, e.GT)) - .when(index < triangle_without_nonref, e.GT) - .when(index < hl.triangle(n_alleles), hl.missing('call')) - .or_error('invalid GT ' + hl.str(e.GT) + ' at site ' + hl.str(row.locus))) + return ( + hl.case() + .when(e.GT.is_haploid(), hl.or_missing(e.GT[0] < n_no_nonref, e.GT)) + .when(index < triangle_without_nonref, e.GT) + .when(index < hl.triangle(n_alleles), hl.missing('call')) + .or_error('invalid GT ' + hl.str(e.GT) + ' at site ' + hl.str(row.locus)) + ) def make_entry_struct(e, 
alleles_len, has_non_ref, row): handled_fields = dict() - handled_names = {'LA', 'gvcf_info', - 'LAD', 'AD', - 'LGT', 'GT', - 'LPL', 'PL', - 'LPGT', 'PGT'} + handled_names = {'LA', 'gvcf_info', 'LAD', 'AD', 'LGT', 'GT', 'LPL', 'PL', 'LPGT', 'PGT'} if 'GT' not in e: raise hl.utils.FatalError("the Hail GVCF combiner expects GVCFs to have a 'GT' field in FORMAT.") @@ -188,55 +188,63 @@ def make_entry_struct(e, alleles_len, has_non_ref, row): if 'PGT' in e: handled_fields['LPGT'] = e.PGT if 'PL' in e: - handled_fields['LPL'] = hl.if_else(has_non_ref, - hl.if_else(alleles_len > 2, - e.PL[:-alleles_len], - hl.missing(e.PL.dtype)), - hl.if_else(alleles_len > 1, - e.PL, - hl.missing(e.PL.dtype))) + handled_fields['LPL'] = hl.if_else( + has_non_ref, + hl.if_else(alleles_len > 2, e.PL[:-alleles_len], hl.missing(e.PL.dtype)), + hl.if_else(alleles_len > 1, e.PL, hl.missing(e.PL.dtype)), + ) handled_fields['RGQ'] = hl.if_else( has_non_ref, - hl.if_else(e.GT.is_haploid(), - e.PL[alleles_len - 1], - e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()]), - hl.missing(e.PL.dtype.element_type)) - - handled_fields['gvcf_info'] = (hl.case() - .when(hl.is_missing(row.info.END), - parse_allele_specific_fields( - row.info.select(*info_to_keep), - has_non_ref - )) - .or_missing()) + hl.if_else( + e.GT.is_haploid(), + e.PL[alleles_len - 1], + e.PL[hl.call(0, alleles_len - 1).unphased_diploid_gt_index()], + ), + hl.missing(e.PL.dtype.element_type), + ) + + handled_fields['gvcf_info'] = ( + hl.case() + .when( + hl.is_missing(row.info.END), + parse_allele_specific_fields(row.info.select(*info_to_keep), has_non_ref), + ) + .or_missing() + ) pass_through_fields = {k: v for k, v in e.items() if k not in handled_names} return hl.struct(**handled_fields, **pass_through_fields) transform_row = hl.experimental.define_function( lambda row: hl.rbind( - hl.len(row.alleles), '' == row.alleles[-1], + hl.len(row.alleles), + '' == row.alleles[-1], lambda alleles_len, has_non_ref: hl.struct( locus=row.locus, alleles=hl.if_else(has_non_ref, row.alleles[:-1], row.alleles), **({'rsid': row.rsid} if 'rsid' in row else {}), - __entries=row.__entries.map( - lambda e: make_entry_struct(e, alleles_len, has_non_ref, row)))), - row_type) + __entries=row.__entries.map(lambda e: make_entry_struct(e, alleles_len, has_non_ref, row)), + ), + ), + row_type, + ) _transform_variant_function_map[row_type, info_key] = transform_row from hail.expr import construct_expr from hail.utils.java import Env + uid = Env.get_uid() - map_ir = hl.ir.ToArray(hl.ir.StreamMap(hl.ir.ToStream(stream._ir), uid, - Apply(transform_row._name, transform_row._ret_type, - hl.ir.Ref(uid, type=row_type)))) + map_ir = hl.ir.ToArray( + hl.ir.StreamMap( + hl.ir.ToStream(stream._ir), + uid, + Apply(transform_row._name, transform_row._ret_type, hl.ir.Ref(uid, type=row_type)), + ) + ) return construct_expr(map_ir, map_ir.typ, stream._indices, stream._aggregations) -def make_reference_matrix_table(mt: MatrixTable, - entry_to_keep: Collection[str] - ) -> MatrixTable: +def make_reference_matrix_table(mt: MatrixTable, entry_to_keep: Collection[str]) -> MatrixTable: mt = mt.filter_rows(hl.is_defined(mt.info.END)) entry_key = tuple(sorted(entry_to_keep)) # hashable stable value @@ -251,30 +259,28 @@ def make_entry_struct(e, row): if 'PL' in entry_to_keep: handled_fields['LPL'] = e['PL'][:1] - reference_fields = {k: v for k, v in e.items() - if k in entry_to_keep and k not in handled_names} - return (hl.case() - .when(e.GT.is_hom_ref(), - hl.struct(END=row.info.END, 
**reference_fields, **handled_fields)) - .or_error('found END with non reference-genotype at' + hl.str(row.locus))) + reference_fields = {k: v for k, v in e.items() if k in entry_to_keep and k not in handled_names} + return ( + hl.case() + .when(e.GT.is_hom_ref(), hl.struct(END=row.info.END, **reference_fields, **handled_fields)) + .or_error('found END with non reference-genotype at' + hl.str(row.locus)) + ) mt = localize(mt).key_by('locus') transform_row = _transform_reference_fuction_map.get((mt.row.dtype, entry_key)) if transform_row is None or not hl.current_backend()._is_registered_ir_function_name(transform_row._name): transform_row = hl.experimental.define_function( - lambda row: hl.struct( - locus=row.locus, - __entries=row.__entries.map( - lambda e: make_entry_struct(e, row))), - mt.row.dtype) + lambda row: hl.struct(locus=row.locus, __entries=row.__entries.map(lambda e: make_entry_struct(e, row))), + mt.row.dtype, + ) _transform_reference_fuction_map[mt.row.dtype, entry_key] = transform_row return unlocalize(Table(TableMapRows(mt._tir, Apply(transform_row._name, transform_row._ret_type, mt.row._ir)))) -def transform_gvcf(mt: MatrixTable, - reference_entry_fields_to_keep: Collection[str], - info_to_keep: Optional[Collection[str]] = None) -> VariantDataset: +def transform_gvcf( + mt: MatrixTable, reference_entry_fields_to_keep: Collection[str], info_to_keep: Optional[Collection[str]] = None +) -> VariantDataset: """Transforms a GVCF into a single sample VariantDataSet The input to this should be some result of :func:`.import_vcf` @@ -332,21 +338,23 @@ def combine_reference_row(row, globals): merge_function = _merge_function_map.get((row.dtype, globals)) if merge_function is None or not hl.current_backend()._is_registered_ir_function_name(merge_function._name): merge_function = hl.experimental.define_function( - lambda row, gbl: - hl.struct( + lambda row, gbl: hl.struct( locus=row.locus, __entries=hl.range(0, hl.len(row.data)).flatmap( - lambda i: - hl.if_else(hl.is_missing(row.data[i]), - hl.range(0, hl.len(gbl.g[i].__cols)) - .map(lambda _: hl.missing(row.data[i].__entries.dtype.element_type)), - row.data[i].__entries))), - row.dtype, globals.dtype) + lambda i: hl.if_else( + hl.is_missing(row.data[i]), + hl.range(0, hl.len(gbl.g[i].__cols)).map( + lambda _: hl.missing(row.data[i].__entries.dtype.element_type) + ), + row.data[i].__entries, + ) + ), + ), + row.dtype, + globals.dtype, + ) _merge_function_map[(row.dtype, globals.dtype)] = merge_function - apply_ir = Apply(merge_function._name, - merge_function._ret_type, - row._ir, - globals._ir) + apply_ir = Apply(merge_function._name, merge_function._ret_type, row._ir, globals._ir) indices, aggs = unify_all(row, globals) return construct_expr(apply_ir, apply_ir.typ, indices, aggs) @@ -376,8 +384,8 @@ def combine_variant_datasets(vdss: List[VariantDataset]) -> VariantDataset: return VariantDataset(reference, variants._key_rows_by_assert_sorted('locus', 'alleles')) -_transform_rows_function_map: Dict[Tuple[hl.HailType], Function] = {} -_merge_function_map: Dict[Tuple[hl.HailType, hl.HailType], Function] = {} +_transform_rows_function_map: Dict[Tuple[HailType], Function] = {} +_merge_function_map: Dict[Tuple[HailType, HailType], Function] = {} @typecheck(string=expr_str, has_non_ref=expr_bool) @@ -406,13 +414,20 @@ def parse_allele_specific_ranksum(string, has_non_ref): typ = hl.ttuple(hl.tfloat64, hl.tint32) items = string.split(r'\|') items = hl.if_else(has_non_ref, items[:-1], items) - return items.map(lambda s: hl.if_else( - 
(hl.len(s) == 0) | (s == '.'), - hl.missing(typ), - hl.rbind(s.split(','), lambda ss: hl.if_else( - hl.len(ss) != 2, # bad field, possibly 'NaN', just set it null - hl.missing(hl.ttuple(hl.tfloat64, hl.tint32)), - hl.tuple([hl.float64(ss[0]), hl.int32(ss[1])]))))) + return items.map( + lambda s: hl.if_else( + (hl.len(s) == 0) | (s == '.'), + hl.missing(typ), + hl.rbind( + s.split(','), + lambda ss: hl.if_else( + hl.len(ss) != 2, # bad field, possibly 'NaN', just set it null + hl.missing(hl.ttuple(hl.tfloat64, hl.tint32)), + hl.tuple([hl.float64(ss[0]), hl.int32(ss[1])]), + ), + ), + ) + ) _allele_specific_field_parsers = { @@ -425,9 +440,9 @@ def parse_allele_specific_ranksum(string, has_non_ref): } -def parse_allele_specific_fields(info: hl.StructExpression, - has_non_ref: Union[bool, hl.BooleanExpression] - ) -> hl.StructExpression: +def parse_allele_specific_fields( + info: StructExpression, has_non_ref: Union[bool, BooleanExpression] +) -> StructExpression: def parse_field(field: str) -> hl.Expression: if parse := _allele_specific_field_parsers.get(field): return parse(info[field], has_non_ref) @@ -449,34 +464,35 @@ def unlocalize(mt): def merge_alleles(alleles): - from hail.expr.functions import _num_allele_type, _allele_ints return hl.rbind( - alleles.map(lambda a: hl.or_else(a[0], '')) - .fold(lambda s, t: hl.if_else(hl.len(s) > hl.len(t), s, t), ''), - lambda ref: - hl.rbind( + alleles.map(lambda a: hl.or_else(a[0], '')).fold(lambda s, t: hl.if_else(hl.len(s) > hl.len(t), s, t), ''), + lambda ref: hl.rbind( alleles.map( lambda al: hl.rbind( al[0], - lambda r: - hl.array([ref]).extend( + lambda r: hl.array([ref]).extend( al[1:].map( - lambda a: - hl.rbind( - _num_allele_type(r, a), - lambda at: - hl.if_else( - (_allele_ints['SNP'] == at) - | (_allele_ints['Insertion'] == at) - | (_allele_ints['Deletion'] == at) - | (_allele_ints['MNP'] == at) - | (_allele_ints['Complex'] == at), - a + ref[hl.len(r):], - a)))))), - lambda lal: - hl.struct( - globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))), - local=lal))) + lambda a: hl.rbind( + numeric_allele_type(r, a), + lambda at: hl.if_else( + (at == AlleleType.SNP) + | (at == AlleleType.INSERTION) + | (at == AlleleType.DELETION) + | (at == AlleleType.MNP) + | (at == AlleleType.COMPLEX), + a + ref[hl.len(r) :], + a, + ), + ) + ) + ), + ) + ), + lambda lal: hl.struct( + globl=hl.array([ref]).extend(hl.array(hl.set(hl.flatten(lal)).remove(ref))), local=lal + ), + ), + ) def combine_variant_rows(row, globals): @@ -487,43 +503,48 @@ def renumber_entry(entry, old_to_new) -> StructExpression: merge_function = _merge_function_map.get((row.dtype, globals.dtype)) if merge_function is None or not hl.current_backend()._is_registered_ir_function_name(merge_function._name): merge_function = hl.experimental.define_function( - lambda row, gbl: - hl.rbind( + lambda row, gbl: hl.rbind( merge_alleles(row.data.map(lambda d: d.alleles)), - lambda alleles: - hl.struct( + lambda alleles: hl.struct( locus=row.locus, alleles=alleles.globl, - **({'rsid': hl.find(hl.is_defined, row.data.map( - lambda d: d.rsid))} if 'rsid' in row.data.dtype.element_type else {}), + **( + {'rsid': hl.find(hl.is_defined, row.data.map(lambda d: d.rsid))} + if 'rsid' in row.data.dtype.element_type + else {} + ), __entries=hl.bind( - lambda combined_allele_index: - hl.range(0, hl.len(row.data)).flatmap( - lambda i: - hl.if_else(hl.is_missing(row.data[i].__entries), - hl.range(0, hl.len(gbl.g[i].__cols)) - .map(lambda _: 
hl.missing(row.data[i].__entries.dtype.element_type)), - hl.bind( - lambda old_to_new: row.data[i].__entries.map( - lambda e: renumber_entry(e, old_to_new)), - hl.range(0, hl.len(alleles.local[i])).map( - lambda j: combined_allele_index[alleles.local[i][j]])))), - hl.dict(hl.range(0, hl.len(alleles.globl)).map( - lambda j: hl.tuple([alleles.globl[j], j])))))), - row.dtype, globals.dtype) + lambda combined_allele_index: hl.range(0, hl.len(row.data)).flatmap( + lambda i: hl.if_else( + hl.is_missing(row.data[i].__entries), + hl.range(0, hl.len(gbl.g[i].__cols)).map( + lambda _: hl.missing(row.data[i].__entries.dtype.element_type) + ), + hl.bind( + lambda old_to_new: row.data[i].__entries.map( + lambda e: renumber_entry(e, old_to_new) + ), + hl.range(0, hl.len(alleles.local[i])).map( + lambda j: combined_allele_index[alleles.local[i][j]] + ), + ), + ) + ), + hl.dict(hl.range(0, hl.len(alleles.globl)).map(lambda j: hl.tuple([alleles.globl[j], j]))), + ), + ), + ), + row.dtype, + globals.dtype, + ) _merge_function_map[(row.dtype, globals.dtype)] = merge_function indices, aggs = unify_all(row, globals) - apply_ir = Apply(merge_function._name, - merge_function._ret_type, - row._ir, - globals._ir) + apply_ir = Apply(merge_function._name, merge_function._ret_type, row._ir, globals._ir) return construct_expr(apply_ir, apply_ir.typ, indices, aggs) def combine(ts): - ts = Table(TableMapRows(ts._tir, combine_variant_rows( - ts.row, - ts.globals)._ir)) + ts = Table(TableMapRows(ts._tir, combine_variant_rows(ts.row, ts.globals)._ir)) return ts.transmute_globals(__cols=hl.flatten(ts.g.map(lambda g: g.__cols))) @@ -550,7 +571,7 @@ def combine_gvcfs(mts): return unlocalize(combined) -@typecheck(mt=hl.MatrixTable, desired_average_partition_size=int, tmp_path=str) +@typecheck(mt=MatrixTable, desired_average_partition_size=int, tmp_path=str) def calculate_new_intervals(mt, desired_average_partition_size: int, tmp_path: str): """takes a table, keyed by ['locus', ...] and produces a list of intervals suitable for repartitioning a combiner matrix table. 
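# --- Editorial sketch (not part of the patch): how calculate_new_intervals is
# typically driven, mirroring the combiner's own use of it in _step_vdses. The
# VDS path and temp path below are hypothetical placeholders.
import os
import hail as hl
from hail.vds.combiner.combine import calculate_new_intervals

vds = hl.vds.read_vds('gs://my-bucket/cohort.vds')    # hypothetical input VDS
intervals, intervals_dtype = calculate_new_intervals(
    vds.reference_data,                         # row key must be exactly ['locus']
    desired_average_partition_size=24_000,      # the combiner's default target_records
    tmp_path=os.path.join('gs://my-bucket/tmp', 'interval_checkpoint.ht'),
)
# --- end sketch --------------------------------------------------------------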
@@ -571,9 +592,11 @@ def calculate_new_intervals(mt, desired_average_partition_size: int, tmp_path: s assert list(mt.row_key) == ['locus'] assert isinstance(mt.locus.dtype, hl.tlocus) reference_genome = mt.locus.dtype.reference_genome - end = hl.Locus(reference_genome.contigs[-1], - reference_genome.lengths[reference_genome.contigs[-1]], - reference_genome=reference_genome) + end = hl.Locus( + reference_genome.contigs[-1], + reference_genome.lengths[reference_genome.contigs[-1]], + reference_genome=reference_genome, + ) (n_rows, n_cols) = mt.count() @@ -589,9 +612,9 @@ def calculate_new_intervals(mt, desired_average_partition_size: int, tmp_path: s total_weight = ht.aggregate(hl.agg.sum(ht.weight)) partition_weight = int(total_weight / (n_rows / desired_average_partition_size)) - ht = ht.annotate(cumulative_weight=hl.scan.sum(ht.weight), - last_weight=hl.scan._prev_nonnull(ht.weight), - row_idx=hl.scan.count()) + ht = ht.annotate( + cumulative_weight=hl.scan.sum(ht.weight), last_weight=hl.scan._prev_nonnull(ht.weight), row_idx=hl.scan.count() + ) def partition_bound(x): return x - (x % hl.int64(partition_weight)) @@ -599,18 +622,25 @@ def partition_bound(x): at_partition_bound = partition_bound(ht.cumulative_weight) != partition_bound(ht.cumulative_weight - ht.last_weight) ht = ht.filter(at_partition_bound | (ht.row_idx == n_rows - 1)) - ht = ht.annotate(start=hl.or_else( - hl.scan._prev_nonnull(hl.locus_from_global_position(ht.locus.global_position() + 1, - reference_genome=reference_genome)), - hl.locus_from_global_position(0, reference_genome=reference_genome))) + ht = ht.annotate( + start=hl.or_else( + hl.scan._prev_nonnull( + hl.locus_from_global_position(ht.locus.global_position() + 1, reference_genome=reference_genome) + ), + hl.locus_from_global_position(0, reference_genome=reference_genome), + ) + ) ht = ht.select( - interval=hl.interval(start=hl.struct(locus=ht.start), end=hl.struct(locus=ht.locus), includes_end=True)) + interval=hl.interval(start=hl.struct(locus=ht.start), end=hl.struct(locus=ht.locus), includes_end=True) + ) intervals_dtype = hl.tarray(ht.interval.dtype) intervals = ht.aggregate(hl.agg.collect(ht.interval)) last_st = hl.eval( - hl.locus_from_global_position(hl.literal(intervals[-1].end.locus).global_position() + 1, - reference_genome=reference_genome)) + hl.locus_from_global_position( + hl.literal(intervals[-1].end.locus).global_position() + 1, reference_genome=reference_genome + ) + ) interval = hl.Interval(start=hl.Struct(locus=last_st), end=hl.Struct(locus=end), includes_end=True) intervals.append(interval) return intervals, intervals_dtype @@ -638,7 +668,8 @@ def locus_interval(start, end): return hl.Interval( start=hl.Locus(contig=contig, position=start, reference_genome=reference_genome), end=hl.Locus(contig=contig, position=end, reference_genome=reference_genome), - includes_end=True) + includes_end=True, + ) contig_length = reference_genome.lengths[contig] n_parts = math.ceil(contig_length / interval_size) @@ -659,8 +690,8 @@ def locus_interval(start, end): contigs = [f'chr{i}' for i in range(1, 23)] + ['chrX', 'chrY', 'chrM'] else: raise ValueError( - f"Unsupported reference genome '{reference_genome.name}', " - "only 'GRCh37' and 'GRCh38' are supported") + f"Unsupported reference genome '{reference_genome.name}', " "only 'GRCh37' and 'GRCh38' are supported" + ) intervals = [] for ctg in contigs: diff --git a/hail/python/hail/vds/combiner/variant_dataset_combiner.py b/hail/python/hail/vds/combiner/variant_dataset_combiner.py index 
6ed4009e89d..54112b7a339 100644 --- a/hail/python/hail/vds/combiner/variant_dataset_combiner.py +++ b/hail/python/hail/vds/combiner/variant_dataset_combiner.py @@ -6,16 +6,26 @@ import uuid from itertools import chain from math import floor, log -from typing import Collection, Dict, List, NamedTuple, Optional, Union +from typing import ClassVar, Collection, Dict, List, NamedTuple, Optional, Union import hail as hl from hail.expr import HailType, tmatrix +from hail.genetics.reference_genome import ReferenceGenome from hail.utils import FatalError, Interval from hail.utils.java import info, warning -from .combine import combine_variant_datasets, transform_gvcf, defined_entry_fields, make_variant_stream, \ - make_reference_stream, combine_r, calculate_even_genome_partitioning, \ - calculate_new_intervals, combine + from ..variant_dataset import VariantDataset +from .combine import ( + calculate_even_genome_partitioning, + calculate_new_intervals, + combine, + combine_r, + combine_variant_datasets, + defined_entry_fields, + make_reference_stream, + make_variant_stream, + transform_gvcf, +) class VDSMetadata(NamedTuple): @@ -29,12 +39,14 @@ class VDSMetadata(NamedTuple): Number of samples contained within the Variant Dataset at `path`. """ + path: str n_samples: int class CombinerOutType(NamedTuple): """A container for the types of a VDS""" + reference_type: tmatrix variant_type: tmatrix @@ -165,6 +177,7 @@ class VariantDatasetCombiner: # pylint: disable=too-many-instance-attributes and ``PL`` will be entry fields in the resulting reference matrix in the dataset. """ + _default_gvcf_batch_size = 50 _default_branch_factor = 100 _default_target_records = 24_000 @@ -180,7 +193,7 @@ class VariantDatasetCombiner: # pylint: disable=too-many-instance-attributes default_exome_interval_size = 60_000_000 "A reasonable partition size in basepairs given the density of exomes." 
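# --- Editorial sketch (not part of the patch): VariantDatasetCombiner is
# normally obtained from hl.vds.new_combiner (defined later in this file)
# rather than constructed directly; run() then executes the merge plan.
# All paths below are hypothetical placeholders.
import hail as hl

combiner = hl.vds.new_combiner(
    output_path='gs://my-bucket/cohort.vds',              # hypothetical output VDS
    temp_path='gs://my-bucket/tmp/',                      # hypothetical temp dir
    gvcf_paths=['gs://my-bucket/gvcfs/sample1.g.vcf.gz'], # hypothetical inputs
    use_genome_default_intervals=True,  # at least one GVCF partitioning argument is required
    reference_genome='GRCh38',
)
combiner.run()
# --- end sketch --------------------------------------------------------------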
- __serialized_slots__ = [ + __serialized_slots__: ClassVar = [ '_save_path', '_output_path', '_temp_path', @@ -201,41 +214,45 @@ class VariantDatasetCombiner: # pylint: disable=too-many-instance-attributes '_call_fields', ] - __slots__ = tuple(__serialized_slots__ + ['_uuid', '_job_id', '__intervals_cache']) - - def __init__(self, - *, - save_path: str, - output_path: str, - temp_path: str, - reference_genome: hl.ReferenceGenome, - dataset_type: CombinerOutType, - gvcf_type: Optional[tmatrix] = None, - branch_factor: int = _default_branch_factor, - target_records: int = _default_target_records, - gvcf_batch_size: int = _default_gvcf_batch_size, - contig_recoding: Optional[Dict[str, str]] = None, - call_fields: Collection[str], - vdses: List[VDSMetadata], - gvcfs: List[str], - gvcf_sample_names: Optional[List[str]] = None, - gvcf_external_header: Optional[str] = None, - gvcf_import_intervals: List[Interval], - gvcf_info_to_keep: Optional[Collection[str]] = None, - gvcf_reference_entry_fields_to_keep: Optional[Collection[str]] = None, - ): + __slots__ = tuple([*__serialized_slots__, "_uuid", "_job_id", "__intervals_cache"]) + + def __init__( + self, + *, + save_path: str, + output_path: str, + temp_path: str, + reference_genome: ReferenceGenome, + dataset_type: CombinerOutType, + gvcf_type: Optional[tmatrix] = None, + branch_factor: int = _default_branch_factor, + target_records: int = _default_target_records, + gvcf_batch_size: int = _default_gvcf_batch_size, + contig_recoding: Optional[Dict[str, str]] = None, + call_fields: Collection[str], + vdses: List[VDSMetadata], + gvcfs: List[str], + gvcf_sample_names: Optional[List[str]] = None, + gvcf_external_header: Optional[str] = None, + gvcf_import_intervals: List[Interval], + gvcf_info_to_keep: Optional[Collection[str]] = None, + gvcf_reference_entry_fields_to_keep: Optional[Collection[str]] = None, + ): if gvcf_import_intervals: interval = gvcf_import_intervals[0] if not isinstance(interval.point_type, hl.tlocus): raise ValueError(f'intervals point type must be a locus, found {interval.point_type}') if interval.point_type.reference_genome != reference_genome: - raise ValueError(f'mismatch in intervals ({interval.point_type.reference_genome}) ' - f'and reference genome ({reference_genome}) types') + raise ValueError( + f'mismatch in intervals ({interval.point_type.reference_genome}) ' + f'and reference genome ({reference_genome}) types' + ) if (gvcf_sample_names is None) != (gvcf_external_header is None): raise ValueError("both 'gvcf_sample_names' and 'gvcf_external_header' must be set or unset") if gvcf_sample_names is not None and len(gvcf_sample_names) != len(gvcfs): - raise ValueError("'gvcf_sample_names' and 'gvcfs' must have the same length " - f'{len(gvcf_sample_names)} != {len(gvcfs)}') + raise ValueError( + "'gvcf_sample_names' and 'gvcfs' must have the same length " f'{len(gvcf_sample_names)} != {len(gvcfs)}' + ) if branch_factor < 2: raise ValueError(f"'branch_factor' must be at least 2, found {branch_factor}") if gvcf_batch_size < 1: @@ -258,10 +275,10 @@ def __init__(self, self._gvcf_sample_names = gvcf_sample_names self._gvcf_external_header = gvcf_external_header self._gvcf_import_intervals = gvcf_import_intervals - self._gvcf_info_to_keep = set(gvcf_info_to_keep) if gvcf_info_to_keep is not None \ - else None - self._gvcf_reference_entry_fields_to_keep = set(gvcf_reference_entry_fields_to_keep) \ - if gvcf_reference_entry_fields_to_keep is not None else None + self._gvcf_info_to_keep = set(gvcf_info_to_keep) if 
gvcf_info_to_keep is not None else None + self._gvcf_reference_entry_fields_to_keep = ( + set(gvcf_reference_entry_fields_to_keep) if gvcf_reference_entry_fields_to_keep is not None else None + ) self._uuid = uuid.uuid4() self._job_id = 1 @@ -278,8 +295,7 @@ def gvcf_batch_size(self, value: int): if value * len(self._gvcf_import_intervals) > VariantDatasetCombiner._gvcf_merge_task_limit: old_value = value value = VariantDatasetCombiner._gvcf_merge_task_limit // len(self._gvcf_import_intervals) - warning(f'gvcf_batch_size of {old_value} would produce too many tasks ' - f'using {value} instead') + warning(f'gvcf_batch_size of {old_value} would produce too many tasks ' f'using {value} instead') self._gvcf_batch_size = value def __eq__(self, other): @@ -313,8 +329,10 @@ def save(self): print(f'Failed saving {self.__class__.__name__} state at {self._save_path}') print(f'An attempt was made to copy {self._save_path} to {backup_path}') print('An old version of this state may be there.') - print('Dumping current state as json to standard output, you may wish ' - 'to save this output in order to resume the combiner.') + print( + 'Dumping current state as json to standard output, you may wish ' + 'to save this output in order to resume the combiner.' + ) json.dump(self, sys.stdout, indent=2, cls=Encoder) print() raise e @@ -326,11 +344,13 @@ def run(self): hl._set_flags(**{flagname: '1'}) vds_samples = sum(vds.n_samples for vdses in self._vdses.values() for vds in vdses) - info('Running VDS combiner:\n' - f' VDS arguments: {self._num_vdses} datasets with {vds_samples} samples\n' - f' GVCF arguments: {len(self._gvcfs)} inputs/samples\n' - f' Branch factor: {self._branch_factor}\n' - f' GVCF merge batch size: {self._gvcf_batch_size}') + info( + 'Running VDS combiner:\n' + f' VDS arguments: {self._num_vdses} datasets with {vds_samples} samples\n' + f' GVCF arguments: {len(self._gvcfs)} inputs/samples\n' + f' Branch factor: {self._branch_factor}\n' + f' GVCF merge batch size: {self._gvcf_batch_size}' + ) while not self.finished: self.save() self.step() @@ -346,45 +366,50 @@ def load(path) -> 'VariantDatasetCombiner': combiner = json.load(stream, cls=Decoder) combiner._raise_if_output_exists() if combiner._save_path != path: - warning('path/save_path mismatch in loaded VariantDatasetCombiner, using ' - f'{path} as the new save_path for this combiner') + warning( + 'path/save_path mismatch in loaded VariantDatasetCombiner, using ' + f'{path} as the new save_path for this combiner' + ) combiner._save_path = path return combiner def _raise_if_output_exists(self): + if self.finished: + return fs = hl.current_backend().fs ref_success_path = os.path.join(VariantDataset._reference_path(self._output_path), '_SUCCESS') var_success_path = os.path.join(VariantDataset._variants_path(self._output_path), '_SUCCESS') if fs.exists(ref_success_path) and fs.exists(var_success_path): - raise FatalError(f'combiner output already exists at {self._output_path}\n' - 'move or delete it before continuing') + raise FatalError( + f'combiner output already exists at {self._output_path}\n' 'move or delete it before continuing' + ) def to_dict(self) -> dict: """A serializable representation of this combiner.""" intervals_typ = hl.tarray(hl.tinterval(hl.tlocus(self._reference_genome))) - return {'name': self.__class__.__name__, - 'save_path': self._save_path, - 'output_path': self._output_path, - 'temp_path': self._temp_path, - 'reference_genome': str(self._reference_genome), - 'dataset_type': self._dataset_type, - 'gvcf_type': 
self._gvcf_type, - 'branch_factor': self._branch_factor, - 'target_records': self._target_records, - 'gvcf_batch_size': self._gvcf_batch_size, - 'gvcf_external_header': self._gvcf_external_header, # put this here for humans - 'contig_recoding': self._contig_recoding, - 'gvcf_info_to_keep': None if self._gvcf_info_to_keep is None - else list(self._gvcf_info_to_keep), - 'gvcf_reference_entry_fields_to_keep': None - if self._gvcf_reference_entry_fields_to_keep is None - else list(self._gvcf_reference_entry_fields_to_keep), - 'call_fields': self._call_fields, - 'vdses': [md for i in sorted(self._vdses, reverse=True) for md in self._vdses[i]], - 'gvcfs': self._gvcfs, - 'gvcf_sample_names': self._gvcf_sample_names, - 'gvcf_import_intervals': intervals_typ._convert_to_json(self._gvcf_import_intervals), - } + return { + 'name': self.__class__.__name__, + 'save_path': self._save_path, + 'output_path': self._output_path, + 'temp_path': self._temp_path, + 'reference_genome': str(self._reference_genome), + 'dataset_type': self._dataset_type, + 'gvcf_type': self._gvcf_type, + 'branch_factor': self._branch_factor, + 'target_records': self._target_records, + 'gvcf_batch_size': self._gvcf_batch_size, + 'gvcf_external_header': self._gvcf_external_header, # put this here for humans + 'contig_recoding': self._contig_recoding, + 'gvcf_info_to_keep': None if self._gvcf_info_to_keep is None else list(self._gvcf_info_to_keep), + 'gvcf_reference_entry_fields_to_keep': None + if self._gvcf_reference_entry_fields_to_keep is None + else list(self._gvcf_reference_entry_fields_to_keep), + 'call_fields': self._call_fields, + 'vdses': [md for i in sorted(self._vdses, reverse=True) for md in self._vdses[i]], + 'gvcfs': self._gvcfs, + 'gvcf_sample_names': self._gvcf_sample_names, + 'gvcf_import_intervals': intervals_typ._convert_to_json(self._gvcf_import_intervals), + } @property def _num_vdses(self): @@ -407,25 +432,19 @@ def step(self): self._job_id += 1 def _write_final(self, vds): - fd = VariantDataset.ref_block_max_length_field + vds.write(self._output_path) - if fd not in vds.reference_data.globals: + if VariantDataset.ref_block_max_length_field not in vds.reference_data.globals: info("VDS combiner: computing reference block max length...") - max_len = vds.reference_data.aggregate_entries( - hl.agg.max(vds.reference_data.END + 1 - vds.reference_data.locus.position)) - info(f"VDS combiner: max reference block length is {max_len}") - vds = VariantDataset(reference_data=vds.reference_data.annotate_globals(**{fd: max_len}), - variant_data=vds.variant_data) - - vds.write(self._output_path) + hl.vds.store_ref_block_max_length(self._output_path) def _step_vdses(self): current_bin = original_bin = min(self._vdses) - files_to_merge = self._vdses[current_bin][:self._branch_factor] + files_to_merge = self._vdses[current_bin][: self._branch_factor] if len(files_to_merge) == len(self._vdses[current_bin]): del self._vdses[current_bin] else: - self._vdses[current_bin] = self._vdses[current_bin][self._branch_factor:] + self._vdses[current_bin] = self._vdses[current_bin][self._branch_factor :] remaining = self._branch_factor - len(files_to_merge) while self._num_vdses > 0 and remaining > 0: @@ -443,18 +462,21 @@ def _step_vdses(self): temp_path = self._temp_out_path(f'vds-combine_job{self._job_id}') largest_vds = max(files_to_merge, key=lambda vds: vds.n_samples) - vds = hl.vds.read_vds(largest_vds.path, - _assert_reference_type=self._dataset_type.reference_type, - _assert_variant_type=self._dataset_type.variant_type, - 
_warn_no_ref_block_max_length=False) + vds = hl.vds.read_vds( + largest_vds.path, + _assert_reference_type=self._dataset_type.reference_type, + _assert_variant_type=self._dataset_type.variant_type, + _warn_no_ref_block_max_length=False, + ) interval_bin = floor(log(new_n_samples, self._branch_factor)) intervals = self.__intervals_cache.get(interval_bin) if intervals is None: # we use the reference data since it generally has more rows than the variant data - intervals, _ = calculate_new_intervals(vds.reference_data, self._target_records, - os.path.join(temp_path, 'interval_checkpoint.ht')) + intervals, _ = calculate_new_intervals( + vds.reference_data, self._target_records, os.path.join(temp_path, 'interval_checkpoint.ht') + ) self.__intervals_cache[interval_bin] = intervals paths = [f.path for f in files_to_merge] @@ -476,15 +498,17 @@ def _step_vdses(self): def _step_gvcfs(self): step = self._branch_factor - files_to_merge = self._gvcfs[:self._gvcf_batch_size * step] - self._gvcfs = self._gvcfs[self._gvcf_batch_size * step:] + files_to_merge = self._gvcfs[: self._gvcf_batch_size * step] + self._gvcfs = self._gvcfs[self._gvcf_batch_size * step :] - info(f'GVCF combine (job {self._job_id}): merging {len(files_to_merge)} GVCFs into ' - f'{(len(files_to_merge) + step - 1) // step} datasets') + info( + f'GVCF combine (job {self._job_id}): merging {len(files_to_merge)} GVCFs into ' + f'{(len(files_to_merge) + step - 1) // step} datasets' + ) if self._gvcf_external_header is not None: - sample_names = self._gvcf_sample_names[:self._gvcf_batch_size * step] - self._gvcf_sample_names = self._gvcf_sample_names[self._gvcf_batch_size * step:] + sample_names = self._gvcf_sample_names[: self._gvcf_batch_size * step] + self._gvcf_sample_names = self._gvcf_sample_names[self._gvcf_batch_size * step :] else: sample_names = None header_file = self._gvcf_external_header or files_to_merge[0] @@ -492,74 +516,97 @@ def _step_gvcfs(self): merge_vds = [] merge_n_samples = [] - intervals_literal = hl.literal([hl.Struct(contig=i.start.contig, start=i.start.position, end=i.end.position) for - i in self._gvcf_import_intervals]) + intervals_literal = hl.literal([ + hl.Struct(contig=i.start.contig, start=i.start.position, end=i.end.position) + for i in self._gvcf_import_intervals + ]) partition_interval_point_type = hl.tstruct(locus=hl.tlocus(self._reference_genome)) - partition_intervals = [hl.Interval(start=hl.Struct(locus=i.start), - end=hl.Struct(locus=i.end), - includes_start=i.includes_start, - includes_end=i.includes_end, - point_type=partition_interval_point_type) for i in - self._gvcf_import_intervals] + partition_intervals = [ + hl.Interval( + start=hl.Struct(locus=i.start), + end=hl.Struct(locus=i.end), + includes_start=i.includes_start, + includes_end=i.includes_end, + point_type=partition_interval_point_type, + ) + for i in self._gvcf_import_intervals + ] vcfs = files_to_merge if sample_names is None: vcfs_lit = hl.literal(vcfs) range_ht = hl.utils.range_table(len(vcfs), n_partitions=min(len(vcfs), 32)) - range_ht = range_ht.annotate(sample_id=hl.rbind(hl.get_vcf_header_info(vcfs_lit[range_ht.idx]), - lambda header: header.sampleIDs[0])) + range_ht = range_ht.annotate( + sample_id=hl.rbind(hl.get_vcf_header_info(vcfs_lit[range_ht.idx]), lambda header: header.sampleIDs[0]) + ) sample_ids = range_ht.aggregate(hl.agg.collect(range_ht.sample_id)) else: sample_ids = sample_names for start in range(0, len(vcfs), step): - ids = sample_ids[start:start + step] - merging = vcfs[start:start + step] - - reference_ht = 
hl.Table._generate(contexts=intervals_literal, - partitions=partition_intervals, - rowfn=lambda interval, globals: - hl._zip_join_producers(hl.enumerate(hl.literal(merging)), - lambda idx_and_path: make_reference_stream( - hl.import_gvcf_interval( - idx_and_path[1], idx_and_path[0], - interval.contig, - interval.start, interval.end, header_info, - call_fields=self._call_fields, - array_elements_required=False, - reference_genome=self._reference_genome, - contig_recoding=self._contig_recoding), - self._gvcf_reference_entry_fields_to_keep), - ['locus'], - lambda k, v: k.annotate(data=v)), - globals=hl.struct( - g=hl.literal(ids).map(lambda s: hl.struct(__cols=[hl.struct(s=s)])))) + ids = sample_ids[start : start + step] + merging = vcfs[start : start + step] + + reference_ht = hl.Table._generate( + contexts=intervals_literal, + partitions=partition_intervals, + rowfn=lambda interval, globals: hl._zip_join_producers( + hl.enumerate(hl.literal(merging)), + lambda idx_and_path: make_reference_stream( + hl.import_gvcf_interval( + idx_and_path[1], + idx_and_path[0], + interval.contig, + interval.start, + interval.end, + header_info, + call_fields=self._call_fields, + array_elements_required=False, + reference_genome=self._reference_genome, + contig_recoding=self._contig_recoding, + ), + self._gvcf_reference_entry_fields_to_keep, + ), + ['locus'], + lambda k, v: k.annotate(data=v), + ), + globals=hl.struct(g=hl.literal(ids).map(lambda s: hl.struct(__cols=[hl.struct(s=s)]))), + ) reference_ht = combine_r(reference_ht, ref_block_max_len_field=None) # compute max length at the end - variant_ht = hl.Table._generate(contexts=intervals_literal, - partitions=partition_intervals, - rowfn=lambda interval, globals: - hl._zip_join_producers(hl.enumerate(hl.literal(merging)), - lambda idx_and_path: make_variant_stream( - hl.import_gvcf_interval( - idx_and_path[1], idx_and_path[0], - interval.contig, - interval.start, interval.end, header_info, - call_fields=self._call_fields, - array_elements_required=False, - reference_genome=self._reference_genome, - contig_recoding=self._contig_recoding), - self._gvcf_info_to_keep), - ['locus'], - lambda k, v: k.annotate(data=v)), - globals=hl.struct( - g=hl.literal(ids).map(lambda s: hl.struct(__cols=[hl.struct(s=s)])))) + variant_ht = hl.Table._generate( + contexts=intervals_literal, + partitions=partition_intervals, + rowfn=lambda interval, globals: hl._zip_join_producers( + hl.enumerate(hl.literal(merging)), + lambda idx_and_path: make_variant_stream( + hl.import_gvcf_interval( + idx_and_path[1], + idx_and_path[0], + interval.contig, + interval.start, + interval.end, + header_info, + call_fields=self._call_fields, + array_elements_required=False, + reference_genome=self._reference_genome, + contig_recoding=self._contig_recoding, + ), + self._gvcf_info_to_keep, + ), + ['locus'], + lambda k, v: k.annotate(data=v), + ), + globals=hl.struct(g=hl.literal(ids).map(lambda s: hl.struct(__cols=[hl.struct(s=s)]))), + ) variant_ht = combine(variant_ht) - vds = VariantDataset(reference_ht._unlocalize_entries('__entries', '__cols', ['s']), - variant_ht._unlocalize_entries('__entries', '__cols', - ['s'])._key_rows_by_assert_sorted('locus', - 'alleles')) + vds = VariantDataset( + reference_ht._unlocalize_entries('__entries', '__cols', ['s']), + variant_ht._unlocalize_entries('__entries', '__cols', ['s'])._key_rows_by_assert_sorted( + 'locus', 'alleles' + ), + ) merge_vds.append(vds) merge_n_samples.append(len(merging)) @@ -569,9 +616,10 @@ def _step_gvcfs(self): temp_path = 
self._temp_out_path(f'gvcf-combine_job{self._job_id}/dataset_') pad = len(str(len(merge_vds) - 1)) - merge_metadata = [VDSMetadata(path=temp_path + str(count).rjust(pad, '0') + '.vds', - n_samples=n_samples) - for count, n_samples in enumerate(merge_n_samples)] + merge_metadata = [ + VDSMetadata(path=temp_path + str(count).rjust(pad, '0') + '.vds', n_samples=n_samples) + for count, n_samples in enumerate(merge_n_samples) + ] paths = [md.path for md in merge_metadata] hl.vds.write_variant_datasets(merge_vds, paths, overwrite=True, codec_spec=FAST_CODEC_SPEC) for md in merge_metadata: @@ -583,67 +631,92 @@ def _temp_out_path(self, extra): def _read_variant_datasets(self, inputs: List[str], intervals: List[Interval]): reference_type = self._dataset_type.reference_type variant_type = self._dataset_type.variant_type - return [hl.vds.read_vds(path, intervals=intervals, - _assert_reference_type=reference_type, - _assert_variant_type=variant_type, - _warn_no_ref_block_max_length=False) - for path in inputs] - - -def new_combiner(*, - output_path: str, - temp_path: str, - save_path: Optional[str] = None, - gvcf_paths: Optional[List[str]] = None, - vds_paths: Optional[List[str]] = None, - vds_sample_counts: Optional[List[int]] = None, - intervals: Optional[List[Interval]] = None, - import_interval_size: Optional[int] = None, - use_genome_default_intervals: bool = False, - use_exome_default_intervals: bool = False, - gvcf_external_header: Optional[str] = None, - gvcf_sample_names: Optional[List[str]] = None, - gvcf_info_to_keep: Optional[Collection[str]] = None, - gvcf_reference_entry_fields_to_keep: Optional[Collection[str]] = None, - call_fields: Collection[str] = ['PGT'], - branch_factor: int = VariantDatasetCombiner._default_branch_factor, - target_records: int = VariantDatasetCombiner._default_target_records, - gvcf_batch_size: Optional[int] = None, - batch_size: Optional[int] = None, - reference_genome: Union[str, hl.ReferenceGenome] = 'default', - contig_recoding: Optional[Dict[str, str]] = None, - force: bool = False, - ) -> VariantDatasetCombiner: + return [ + hl.vds.read_vds( + path, + intervals=intervals, + _assert_reference_type=reference_type, + _assert_variant_type=variant_type, + _warn_no_ref_block_max_length=False, + ) + for path in inputs + ] + + +def new_combiner( + *, + output_path: str, + temp_path: str, + save_path: Optional[str] = None, + gvcf_paths: Optional[List[str]] = None, + vds_paths: Optional[List[str]] = None, + vds_sample_counts: Optional[List[int]] = None, + intervals: Optional[List[Interval]] = None, + import_interval_size: Optional[int] = None, + use_genome_default_intervals: bool = False, + use_exome_default_intervals: bool = False, + gvcf_external_header: Optional[str] = None, + gvcf_sample_names: Optional[List[str]] = None, + gvcf_info_to_keep: Optional[Collection[str]] = None, + gvcf_reference_entry_fields_to_keep: Optional[Collection[str]] = None, + call_fields: Collection[str] = ['PGT'], + branch_factor: int = VariantDatasetCombiner._default_branch_factor, + target_records: int = VariantDatasetCombiner._default_target_records, + gvcf_batch_size: Optional[int] = None, + batch_size: Optional[int] = None, + reference_genome: Union[str, ReferenceGenome] = 'default', + contig_recoding: Optional[Dict[str, str]] = None, + force: bool = False, +) -> VariantDatasetCombiner: """Create a new :class:`.VariantDatasetCombiner` or load one from `save_path`.""" if not (gvcf_paths or vds_paths): raise ValueError("at least one of 'gvcf_paths' or 'vds_paths' must be nonempty") if 
gvcf_paths is None: gvcf_paths = [] + if len(gvcf_paths) > 0: + if len(set(gvcf_paths)) != len(gvcf_paths): + duplicates = [gvcf for gvcf, count in collections.Counter(gvcf_paths).items() if count > 1] + duplicates = '\n '.join(duplicates) + raise ValueError(f'gvcf paths should be unique, the following paths are repeated:{duplicates}') + if gvcf_sample_names is not None and len(set(gvcf_sample_names)) != len(gvcf_sample_names): + duplicates = [gvcf for gvcf, count in collections.Counter(gvcf_sample_names).items() if count > 1] + duplicates = '\n '.join(duplicates) + raise ValueError( + "provided sample names ('gvcf_sample_names') should be unique, " + f'the following names are repeated:{duplicates}' + ) + if vds_paths is None: vds_paths = [] if vds_sample_counts is not None and len(vds_paths) != len(vds_sample_counts): - raise ValueError("'vds_paths' and 'vds_sample_counts' (if present) must have the same length " - f'{len(vds_paths)} != {len(vds_sample_counts)}') + raise ValueError( + "'vds_paths' and 'vds_sample_counts' (if present) must have the same length " + f'{len(vds_paths)} != {len(vds_sample_counts)}' + ) if (gvcf_sample_names is None) != (gvcf_external_header is None): raise ValueError("both 'gvcf_sample_names' and 'gvcf_external_header' must be set or unset") if gvcf_sample_names is not None and len(gvcf_sample_names) != len(gvcf_paths): - raise ValueError("'gvcf_sample_names' and 'gvcf_paths' must have the same length " - f'{len(gvcf_sample_names)} != {len(gvcf_paths)}') + raise ValueError( + "'gvcf_sample_names' and 'gvcf_paths' must have the same length " + f'{len(gvcf_sample_names)} != {len(gvcf_paths)}' + ) if batch_size is None: if gvcf_batch_size is None: gvcf_batch_size = VariantDatasetCombiner._default_gvcf_batch_size else: pass + elif gvcf_batch_size is None: + warning( + 'The batch_size parameter is deprecated. ' + 'The batch_size parameter will be removed in a future version of Hail. ' + 'Please use gvcf_batch_size instead.' + ) + gvcf_batch_size = batch_size else: - if gvcf_batch_size is None: - warning('The batch_size parameter is deprecated. ' - 'The batch_size parameter will be removed in a future version of Hail. ' - 'Please use gvcf_batch_size instead.') - gvcf_batch_size = batch_size - else: - raise ValueError('Specify only one of batch_size and gvcf_batch_size. ' - f'Received {batch_size} and {gvcf_batch_size}.') + raise ValueError( + 'Specify only one of batch_size and gvcf_batch_size. ' f'Received {batch_size} and {gvcf_batch_size}.' 
+ ) del batch_size def maybe_load_from_saved_path(save_path: str) -> Optional[VariantDatasetCombiner]: @@ -663,8 +736,10 @@ def maybe_load_from_saved_path(save_path: str) -> Optional[VariantDatasetCombine combiner._gvcf_batch_size = gvcf_batch_size return combiner except (ValueError, TypeError, OSError, KeyError) as e: - warning(f'file exists at {save_path}, but it is not a valid combiner plan, overwriting\n' - f' caused by: {e}') + warning( + f'file exists at {save_path}, but it is not a valid combiner plan, overwriting\n' + f' caused by: {e}' + ) return None # We do the first save_path check now after validating the arguments @@ -674,19 +749,25 @@ def maybe_load_from_saved_path(save_path: str) -> Optional[VariantDatasetCombine return saved_combiner if len(gvcf_paths) > 0: - n_partition_args = (int(intervals is not None) - + int(import_interval_size is not None) - + int(use_genome_default_intervals) - + int(use_exome_default_intervals)) + n_partition_args = ( + int(intervals is not None) + + int(import_interval_size is not None) + + int(use_genome_default_intervals) + + int(use_exome_default_intervals) + ) if n_partition_args == 0: - raise ValueError("'new_combiner': require one argument from 'intervals', 'import_interval_size', " - "'use_genome_default_intervals', or 'use_exome_default_intervals' to choose GVCF partitioning") + raise ValueError( + "'new_combiner': require one argument from 'intervals', 'import_interval_size', " + "'use_genome_default_intervals', or 'use_exome_default_intervals' to choose GVCF partitioning" + ) if n_partition_args > 1: - warning("'new_combiner': multiple colliding arguments found from 'intervals', 'import_interval_size', " - "'use_genome_default_intervals', or 'use_exome_default_intervals'." - "\n The argument found first in the list in this warning will be used, and others ignored.") + warning( + "'new_combiner': multiple colliding arguments found from 'intervals', 'import_interval_size', " + "'use_genome_default_intervals', or 'use_exome_default_intervals'." + "\n The argument found first in the list in this warning will be used, and others ignored." + ) if intervals is not None: pass @@ -715,39 +796,46 @@ def maybe_load_from_saved_path(save_path: str) -> Optional[VariantDatasetCombine vds = hl.vds.read_vds(vds_paths[0], _warn_no_ref_block_max_length=False) vds_ref_entry = set(vds.reference_data.entry) - {'END'} if gvcf_reference_entry_fields_to_keep is not None and vds_ref_entry != gvcf_reference_entry_fields_to_keep: - warning("Mismatch between 'gvcf_reference_entry_fields' to keep and VDS reference data " - "entry types. Overwriting with reference entry fields from supplied VDS.\n" - f" VDS reference entry fields : {sorted(vds_ref_entry)}\n" - f" requested reference entry fields: {sorted(gvcf_reference_entry_fields_to_keep)}") + warning( + "Mismatch between 'gvcf_reference_entry_fields' to keep and VDS reference data " + "entry types. 
Overwriting with reference entry fields from supplied VDS.\n" + f" VDS reference entry fields : {sorted(vds_ref_entry)}\n" + f" requested reference entry fields: {sorted(gvcf_reference_entry_fields_to_keep)}" + ) gvcf_reference_entry_fields_to_keep = vds_ref_entry # sync up call_fields and call fields present in the VDS - all_entry_types = chain(vds.reference_data._type.entry_type.items(), - vds.variant_data._type.entry_type.items()) + all_entry_types = chain(vds.reference_data._type.entry_type.items(), vds.variant_data._type.entry_type.items()) vds_call_fields = {name for name, typ in all_entry_types if typ == hl.tcall} - {'LGT', 'GT'} if 'LPGT' in vds_call_fields: vds_call_fields = (vds_call_fields - {'LPGT'}) | {'PGT'} if set(call_fields) != vds_call_fields: - warning("Mismatch between 'call_fields' and VDS call fields. " - "Overwriting with call fields from supplied VDS.\n" - f" VDS call fields : {sorted(vds_call_fields)}\n" - f" requested call fields: {sorted(call_fields)}\n") + warning( + "Mismatch between 'call_fields' and VDS call fields. " + "Overwriting with call fields from supplied VDS.\n" + f" VDS call fields : {sorted(vds_call_fields)}\n" + f" requested call fields: {sorted(call_fields)}\n" + ) call_fields = vds_call_fields if gvcf_paths: - mt = hl.import_vcf(gvcf_paths[0], header_file=gvcf_external_header, force_bgz=True, - array_elements_required=False, reference_genome=reference_genome, - contig_recoding=contig_recoding) + mt = hl.import_vcf( + gvcf_paths[0], + header_file=gvcf_external_header, + force_bgz=True, + array_elements_required=False, + reference_genome=reference_genome, + contig_recoding=contig_recoding, + ) gvcf_type = mt._type if gvcf_reference_entry_fields_to_keep is None: rmt = mt.filter_rows(hl.is_defined(mt.info.END)) gvcf_reference_entry_fields_to_keep = defined_entry_fields(rmt, 100_000) - {'GT', 'PGT', 'PL'} if vds is None: - vds = transform_gvcf(mt._key_rows_by_assert_sorted('locus'), - gvcf_reference_entry_fields_to_keep, - gvcf_info_to_keep) - dataset_type = CombinerOutType(reference_type=vds.reference_data._type, - variant_type=vds.variant_data._type) + vds = transform_gvcf( + mt._key_rows_by_assert_sorted('locus'), gvcf_reference_entry_fields_to_keep, gvcf_info_to_keep + ) + dataset_type = CombinerOutType(reference_type=vds.reference_data._type, variant_type=vds.variant_data._type) if save_path is None: sha = hashlib.sha256() @@ -793,31 +881,36 @@ def maybe_load_from_saved_path(save_path: str) -> Optional[VariantDatasetCombine else: vdses = [] for path in vds_paths: - vds = hl.vds.read_vds(path, _assert_reference_type=dataset_type.reference_type, - _assert_variant_type=dataset_type.variant_type, - _warn_no_ref_block_max_length=False) + vds = hl.vds.read_vds( + path, + _assert_reference_type=dataset_type.reference_type, + _assert_variant_type=dataset_type.variant_type, + _warn_no_ref_block_max_length=False, + ) n_samples = vds.n_samples() vdses.append(VDSMetadata(path, n_samples)) vdses.sort(key=lambda x: x.n_samples, reverse=True) - combiner = VariantDatasetCombiner(save_path=save_path, - output_path=output_path, - temp_path=temp_path, - reference_genome=reference_genome, - dataset_type=dataset_type, - branch_factor=branch_factor, - target_records=target_records, - gvcf_batch_size=gvcf_batch_size, - contig_recoding=contig_recoding, - call_fields=call_fields, - vdses=vdses, - gvcfs=gvcf_paths, - gvcf_import_intervals=intervals, - gvcf_external_header=gvcf_external_header, - gvcf_sample_names=gvcf_sample_names, - gvcf_info_to_keep=gvcf_info_to_keep, 
- gvcf_reference_entry_fields_to_keep=gvcf_reference_entry_fields_to_keep) + combiner = VariantDatasetCombiner( + save_path=save_path, + output_path=output_path, + temp_path=temp_path, + reference_genome=reference_genome, + dataset_type=dataset_type, + branch_factor=branch_factor, + target_records=target_records, + gvcf_batch_size=gvcf_batch_size, + contig_recoding=contig_recoding, + call_fields=call_fields, + vdses=vdses, + gvcfs=gvcf_paths, + gvcf_import_intervals=intervals, + gvcf_external_header=gvcf_external_header, + gvcf_sample_names=gvcf_sample_names, + gvcf_info_to_keep=gvcf_info_to_keep, + gvcf_reference_entry_fields_to_keep=gvcf_reference_entry_fields_to_keep, + ) combiner._raise_if_output_exists() return combiner diff --git a/hail/python/hail/vds/functions.py b/hail/python/hail/vds/functions.py index af9a7166312..b130f08ad6f 100644 --- a/hail/python/hail/vds/functions.py +++ b/hail/python/hail/vds/functions.py @@ -1,7 +1,7 @@ import hail as hl -from hail.expr.expressions import expr_array, expr_call, expr_int32, expr_any +from hail.expr.expressions import expr_any, expr_array, expr_call, expr_int32 from hail.expr.functions import _func -from hail.typecheck import typecheck, enumeration +from hail.typecheck import enumeration, typecheck @typecheck(lgt=expr_call, la=expr_array(expr_int32)) @@ -27,7 +27,7 @@ def lgt_to_gt(lgt, la): local_alleles=expr_array(expr_int32), n_alleles=expr_int32, fill_value=expr_any, - number=enumeration('A', 'R', 'G') + number=enumeration('A', 'R', 'G'), ) def local_to_global(array, local_alleles, n_alleles, fill_value, number): """Reindex a locally-indexed array to globally-indexed. diff --git a/hail/python/hail/vds/methods.py b/hail/python/hail/vds/methods.py index 8f93755c400..100f381d2c2 100644 --- a/hail/python/hail/vds/methods.py +++ b/hail/python/hail/vds/methods.py @@ -1,23 +1,19 @@ -from typing import Sequence - import hail as hl from hail import ir -from hail.expr import expr_any, expr_array, expr_interval, expr_locus, expr_str, expr_bool +from hail.expr import expr_any, expr_array, expr_bool, expr_interval, expr_locus, expr_str from hail.matrixtable import MatrixTable -from hail.methods.misc import require_first_key_field_locus from hail.table import Table -from hail.typecheck import sequenceof, typecheck, nullable, oneof, enumeration, func_spec, dictof +from hail.typecheck import dictof, enumeration, func_spec, nullable, oneof, sequenceof, typecheck from hail.utils.java import Env, info, warning -from hail.utils.misc import divide_null, new_temp_file, wrap_to_list +from hail.utils.misc import new_temp_file, wrap_to_list from hail.vds.variant_dataset import VariantDataset -def write_variant_datasets(vdss, paths, *, - overwrite=False, stage_locally=False, - codec_spec=None): +def write_variant_datasets(vdss, paths, *, overwrite=False, stage_locally=False, codec_spec=None): """Write many `vdses` to their corresponding path in `paths`.""" - ref_writer = ir.MatrixNativeMultiWriter([f"{p}/reference_data" for p in paths], overwrite, stage_locally, - codec_spec) + ref_writer = ir.MatrixNativeMultiWriter( + [f"{p}/reference_data" for p in paths], overwrite, stage_locally, codec_spec + ) var_writer = ir.MatrixNativeMultiWriter([f"{p}/variant_data" for p in paths], overwrite, stage_locally, codec_spec) Env.backend().execute(ir.MatrixMultiWrite([vds.reference_data._mir for vds in vdss], ref_writer)) Env.backend().execute(ir.MatrixMultiWrite([vds.variant_data._mir for vds in vdss], var_writer)) @@ -54,33 +50,40 @@ def to_dense_mt(vds: 'VariantDataset') 
-> 'MatrixTable': joined = varl.key_by('locus').join(refl, how='outer') dr = joined.annotate( dense_ref=hl.or_missing( - joined._variant_defined, - hl.scan._densify(hl.len(joined._var_cols), joined._ref_entries) + joined._variant_defined, hl.scan._densify(hl.len(joined._var_cols), joined._ref_entries) ) ) dr = dr.filter(dr._variant_defined) def coalesce_join(ref, var): - call_field = 'GT' if 'GT' in var else 'LGT' assert call_field in var, var.dtype - shared_fields = [call_field] + list(f for f in ref.dtype if f in var.dtype) + shared_fields = [call_field, *list(f for f in ref.dtype if f in var.dtype)] shared_field_set = set(shared_fields) var_fields = [f for f in var.dtype if f not in shared_field_set] - return hl.if_else(hl.is_defined(var), - var.select(*shared_fields, *var_fields), - ref.annotate(**{call_field: hl.call(0, 0)}) - .select(*shared_fields, **{f: hl.missing(var[f].dtype) for f in var_fields})) + return hl.if_else( + hl.is_defined(var), + var.select(*shared_fields, *var_fields), + ref.annotate(**{call_field: hl.call(0, 0)}).select( + *shared_fields, **{f: hl.missing(var[f].dtype) for f in var_fields} + ), + ) dr = dr.annotate( - _dense=hl.rbind(dr._ref_entries, - lambda refs_at_this_row: hl.enumerate(hl.zip(dr._var_entries, dr.dense_ref)).map( - lambda tup: coalesce_join(hl.coalesce(refs_at_this_row[tup[0]], - hl.or_missing(tup[1][1]._END_GLOBAL >= dr.locus.global_position(), - tup[1][1])), tup[1][0]) - )), + _dense=hl.rbind( + dr._ref_entries, + lambda refs_at_this_row: hl.enumerate(hl.zip(dr._var_entries, dr.dense_ref)).map( + lambda tup: coalesce_join( + hl.coalesce( + refs_at_this_row[tup[0]], + hl.or_missing(tup[1][1]._END_GLOBAL >= dr.locus.global_position(), tup[1][1]), + ), + tup[1][0], + ) + ), + ), ) dr = dr._key_by_assert_sorted('locus', 'alleles') @@ -130,7 +133,6 @@ def to_merged_sparse_mt(vds: 'VariantDataset', *, ref_allele_function=None) -> ' ht = vht.join(rht, how='outer').drop('_ref_cols') def merge_arrays(r_array, v_array): - def rewrite_ref(r): ref_block_selector = {} for k, t in merged_schema.items(): @@ -143,207 +145,48 @@ def rewrite_ref(r): return r.select(**ref_block_selector) def rewrite_var(v): - return v.select(**{ - k: v[k] if k in v else hl.missing(t) - for k, t in merged_schema.items() - }) - - return hl.case() \ - .when(hl.is_missing(r_array), v_array.map(rewrite_var)) \ - .when(hl.is_missing(v_array), r_array.map(rewrite_ref)) \ + return v.select(**{k: v[k] if k in v else hl.missing(t) for k, t in merged_schema.items()}) + + return ( + hl.case() + .when(hl.is_missing(r_array), v_array.map(rewrite_var)) + .when(hl.is_missing(v_array), r_array.map(rewrite_ref)) .default(hl.zip(r_array, v_array).map(lambda t: hl.coalesce(rewrite_var(t[1]), rewrite_ref(t[0])))) + ) if ref_allele_function is None: rg = ht.locus.dtype.reference_genome if 'ref_allele' in ht.row: + def ref_allele_function(ht): return ht.ref_allele + elif rg.has_sequence(): + def ref_allele_function(ht): return ht.locus.sequence_context() + info("to_merged_sparse_mt: using locus sequence context to fill in reference alleles at monomorphic loci.") else: - raise ValueError("to_merged_sparse_mt: in order to construct a ref allele for reference-only sites, " - "either pass a function to fill in reference alleles (e.g. 
ref_allele_function=lambda locus: hl.missing('str'))" - " or add a sequence file with 'hl.get_reference(RG_NAME).add_sequence(FASTA_PATH)'.") + raise ValueError( + "to_merged_sparse_mt: in order to construct a ref allele for reference-only sites, " + "either pass a function to fill in reference alleles (e.g. ref_allele_function=lambda locus: hl.missing('str'))" + " or add a sequence file with 'hl.get_reference(RG_NAME).add_sequence(FASTA_PATH)'." + ) ht = ht.select( alleles=hl.coalesce(ht['alleles'], hl.array([ref_allele_function(ht)])), # handle cases where vmt is not keyed by alleles **{k: ht[k] for k in vds.variant_data.row_value if k != 'alleles'}, - _entries=merge_arrays(ht['_ref_entries'], ht['_var_entries']) + _entries=merge_arrays(ht['_ref_entries'], ht['_var_entries']), ) ht = ht._key_by_assert_sorted('locus', 'alleles') return ht._unlocalize_entries('_entries', '_var_cols', list(vds.variant_data.col_key)) -@typecheck(vds=VariantDataset, gq_bins=sequenceof(int), dp_bins=sequenceof(int), dp_field=nullable(str)) -def sample_qc(vds: 'VariantDataset', *, gq_bins: 'Sequence[int]' = (0, 20, 60), - dp_bins: 'Sequence[int]' = (0, 1, 10, 20, 30), dp_field=None) -> 'Table': - """Compute sample quality metrics about a :class:`.VariantDataset`. - - If the `dp_field` parameter is not specified, the ``DP`` is used for depth - if present. If no ``DP`` field is present, the ``MIN_DP`` field is used. If no ``DP`` - or ``MIN_DP`` field is present, no depth statistics will be calculated. - - Parameters - ---------- - vds : :class:`.VariantDataset` - Dataset in VariantDataset representation. - name : :obj:`str` - Name for resulting field. - gq_bins : :class:`tuple` of :obj:`int` - Tuple containing cutoffs for genotype quality (GQ) scores. - dp_bins : :class:`tuple` of :obj:`int` - Tuple containing cutoffs for depth (DP) scores. - dp_field : :obj:`str` - Name of depth field. If not supplied, DP or MIN_DP will be used, in that order. - - Returns - ------- - :class:`.Table` - Hail Table of results, keyed by sample. 
- """ - - require_first_key_field_locus(vds.reference_data, 'sample_qc') - require_first_key_field_locus(vds.variant_data, 'sample_qc') - - ref = vds.reference_data - - if 'DP' in ref.entry: - ref_dp_field_to_use = 'DP' - elif 'MIN_DP' in ref.entry: - ref_dp_field_to_use = 'MIN_DP' - else: - ref_dp_field_to_use = dp_field - - from hail.expr.functions import _num_allele_type, _allele_types - - allele_types = _allele_types[:] - allele_types.extend(['Transition', 'Transversion']) - allele_enum = {i: v for i, v in enumerate(allele_types)} - allele_ints = {v: k for k, v in allele_enum.items()} - - def allele_type(ref, alt): - return hl.bind( - lambda at: hl.if_else(at == allele_ints['SNP'], - hl.if_else(hl.is_transition(ref, alt), - allele_ints['Transition'], - allele_ints['Transversion']), - at), - _num_allele_type(ref, alt) - ) - - variant_ac = Env.get_uid() - variant_atypes = Env.get_uid() - - vmt = vds.variant_data - if 'GT' not in vmt.entry: - vmt = vmt.annotate_entries(GT=hl.vds.lgt_to_gt(vmt.LGT, vmt.LA)) - - vmt = vmt.annotate_rows(**{ - variant_ac: hl.agg.call_stats(vmt.GT, vmt.alleles).AC, - variant_atypes: vmt.alleles[1:].map(lambda alt: allele_type(vmt.alleles[0], alt)) - }) - - bound_exprs = {} - - bound_exprs['n_het'] = hl.agg.count_where(vmt['GT'].is_het()) - bound_exprs['n_hom_var'] = hl.agg.count_where(vmt['GT'].is_hom_var()) - bound_exprs['n_singleton'] = hl.agg.sum( - hl.rbind(vmt['GT'], lambda gt: hl.sum(hl.range(0, gt.ploidy).map( - lambda i: hl.rbind(gt[i], lambda gti: (gti != 0) & (vmt[variant_ac][gti] == 1))))) - ) - bound_exprs['n_singleton_ti'] = hl.agg.sum( - hl.rbind(vmt['GT'], lambda gt: hl.sum(hl.range(0, gt.ploidy).map( - lambda i: hl.rbind(gt[i], lambda gti: (gti != 0) & (vmt[variant_ac][gti] == 1) & ( - vmt[variant_atypes][gti - 1] == allele_ints['Transition']))))) - ) - bound_exprs['n_singleton_tv'] = hl.agg.sum( - hl.rbind(vmt['GT'], lambda gt: hl.sum(hl.range(0, gt.ploidy).map( - lambda i: hl.rbind(gt[i], lambda gti: (gti != 0) & (vmt[variant_ac][gti] == 1) & ( - vmt[variant_atypes][gti - 1] == allele_ints['Transversion']))))) - ) - - bound_exprs['allele_type_counts'] = hl.agg.explode( - lambda allele_type: hl.tuple( - hl.agg.count_where(allele_type == i) for i in range(len(allele_ints)) - ), - (hl.range(0, vmt['GT'].ploidy) - .map(lambda i: vmt['GT'][i]) - .filter(lambda allele_idx: allele_idx > 0) - .map(lambda allele_idx: vmt[variant_atypes][allele_idx - 1])) - ) - - dp_exprs = {} - if ref_dp_field_to_use is not None and 'DP' in vmt.entry: - dp_exprs['dp'] = hl.tuple(hl.agg.count_where(vmt.DP >= x) for x in dp_bins) - - gq_dp_exprs = hl.struct(**{'gq': hl.tuple(hl.agg.count_where(vmt.GQ >= x) for x in gq_bins)}, - **dp_exprs) - - result_struct = hl.rbind( - hl.struct(**bound_exprs), - lambda x: hl.rbind( - hl.struct(**{ - 'gq_dp_exprs': gq_dp_exprs, - 'n_het': x.n_het, - 'n_hom_var': x.n_hom_var, - 'n_non_ref': x.n_het + x.n_hom_var, - 'n_singleton': x.n_singleton, - 'n_singleton_ti': x.n_singleton_ti, - 'n_singleton_tv': x.n_singleton_tv, - 'n_snp': (x.allele_type_counts[allele_ints['Transition']] - + x.allele_type_counts[allele_ints['Transversion']]), - 'n_insertion': x.allele_type_counts[allele_ints['Insertion']], - 'n_deletion': x.allele_type_counts[allele_ints['Deletion']], - 'n_transition': x.allele_type_counts[allele_ints['Transition']], - 'n_transversion': x.allele_type_counts[allele_ints['Transversion']], - 'n_star': x.allele_type_counts[allele_ints['Star']] - }), - lambda s: s.annotate( - r_ti_tv=divide_null(hl.float64(s.n_transition), 
s.n_transversion), - r_ti_tv_singleton=divide_null(hl.float64(s.n_singleton_ti), s.n_singleton_tv), - r_het_hom_var=divide_null(hl.float64(s.n_het), s.n_hom_var), - r_insertion_deletion=divide_null(hl.float64(s.n_insertion), s.n_deletion) - ) - ) - ) - variant_results = vmt.select_cols(**result_struct).cols() - - rmt = vds.reference_data - - ref_dp_expr = {} - if ref_dp_field_to_use is not None: - ref_dp_expr['ref_bases_over_dp_threshold'] = hl.tuple( - hl.agg.filter(rmt[ref_dp_field_to_use] >= x, hl.agg.sum(1 + rmt.END - rmt.locus.position)) for x in - dp_bins) - ref_results = rmt.select_cols( - ref_bases_over_gq_threshold=hl.tuple( - hl.agg.filter(rmt.GQ >= x, hl.agg.sum(1 + rmt.END - rmt.locus.position)) for x in gq_bins), - **ref_dp_expr).cols() - - joined = ref_results[variant_results.key] - - joined_dp_expr = {} - dp_bins_field = {} - if ref_dp_field_to_use is not None: - joined_dp_expr['bases_over_dp_threshold'] = hl.tuple( - x + y for x, y in zip(variant_results.gq_dp_exprs.dp, joined.ref_bases_over_dp_threshold)) - dp_bins_field['dp_bins'] = hl.tuple(dp_bins) - - joined_results = variant_results.transmute( - bases_over_gq_threshold=hl.tuple( - x + y for x, y in zip(variant_results.gq_dp_exprs.gq, joined.ref_bases_over_gq_threshold)), - **joined_dp_expr) - - joined_results = joined_results.annotate_globals(gq_bins=hl.tuple(gq_bins), **dp_bins_field) - return joined_results - - @typecheck(vds=VariantDataset, samples=oneof(Table, expr_array(expr_str)), keep=bool, remove_dead_alleles=bool) -def filter_samples(vds: 'VariantDataset', samples, *, - keep: bool = True, - remove_dead_alleles: bool = False) -> 'VariantDataset': +def filter_samples( + vds: 'VariantDataset', samples, *, keep: bool = True, remove_dead_alleles: bool = False +) -> 'VariantDataset': """Filter samples in a :class:`.VariantDataset`. Parameters @@ -377,18 +220,26 @@ def filter_samples(vds: 'VariantDataset', samples, *, vd = vd.filter_rows(vd.__n > 0) vd = vd.drop('__n') - vd = vd.annotate_rows(__kept_indices=hl.dict( - hl.enumerate( - hl.range(hl.len(vd.alleles)).filter(lambda idx: (idx == 0) | (vd.__allele_counts.get(idx, 0) > 0)), - index_first=False))) + vd = vd.annotate_rows( + __kept_indices=hl.dict( + hl.enumerate( + hl.range(hl.len(vd.alleles)).filter(lambda idx: (idx == 0) | (vd.__allele_counts.get(idx, 0) > 0)), + index_first=False, + ) + ) + ) vd = vd.annotate_rows( - __old_to_new_LA=hl.range(hl.len(vd.alleles)).map(lambda idx: vd.__kept_indices.get(idx, -1))) + __old_to_new_LA=hl.range(hl.len(vd.alleles)).map(lambda idx: vd.__kept_indices.get(idx, -1)) + ) def new_la_index(old_idx): raw_idx = vd.__old_to_new_LA[old_idx] - return hl.case().when(raw_idx >= 0, raw_idx) \ + return ( + hl.case() + .when(raw_idx >= 0, raw_idx) .or_error("'filter_samples': unexpected local allele: old index=" + hl.str(old_idx)) + ) vd = vd.annotate_entries(LA=vd.LA.map(lambda la: new_la_index(la))) vd = vd.key_rows_by('locus') @@ -403,8 +254,8 @@ def new_la_index(old_idx): @typecheck(mt=MatrixTable, normalization_contig=str) def impute_sex_chr_ploidy_from_interval_coverage( - mt: 'MatrixTable', - normalization_contig: str, + mt: 'MatrixTable', + normalization_contig: str, ) -> 'Table': """Impute sex chromosome ploidy from a precomputed interval coverage MatrixTable. 
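# --- Editorial sketch (not part of the patch): this part of methods.py covers
# sex-chromosome ploidy imputation; impute_sex_chromosome_ploidy (below) is the
# usual entry point and works from reference-block or variant depth inside
# calling intervals. The VDS path, contigs, and normalization contig are
# hypothetical placeholders.
import hail as hl

vds = hl.vds.read_vds('gs://my-bucket/cohort.vds')   # hypothetical input VDS
calling_intervals = [
    hl.parse_locus_interval(c, reference_genome='GRCh38')
    for c in ['chr20', 'chrX', 'chrY']
]
ploidy_ht = hl.vds.impute_sex_chromosome_ploidy(
    vds,
    calling_intervals,
    normalization_contig='chr20',  # autosome used to normalize X/Y depth
    use_variant_dataset=False,     # use reference (GVCF block) depth
)
ploidy_ht.show()
# --- end sketch --------------------------------------------------------------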
@@ -451,36 +302,32 @@ def impute_sex_chr_ploidy_from_interval_coverage( chr_y = rg.y_contigs[0] mt = mt.annotate_rows(contig=mt.interval.start.contig) - mt = mt.annotate_cols( - __mean_dp=hl.agg.group_by( - mt.contig, hl.agg.sum(mt.sum_dp) / hl.agg.sum(mt.interval_size) - ) - ) + mt = mt.annotate_cols(__mean_dp=hl.agg.group_by(mt.contig, hl.agg.sum(mt.sum_dp) / hl.agg.sum(mt.interval_size))) mean_dp_dict = mt.__mean_dp auto_dp = mean_dp_dict.get(normalization_contig, 0.0) x_dp = mean_dp_dict.get(chr_x, 0.0) y_dp = mean_dp_dict.get(chr_y, 0.0) - per_sample = mt.transmute_cols(autosomal_mean_dp=auto_dp, - x_mean_dp=x_dp, - x_ploidy=2 * x_dp / auto_dp, - y_mean_dp=y_dp, - y_ploidy=2 * y_dp / auto_dp) + per_sample = mt.transmute_cols( + autosomal_mean_dp=auto_dp, + x_mean_dp=x_dp, + x_ploidy=2 * x_dp / auto_dp, + y_mean_dp=y_dp, + y_ploidy=2 * y_dp / auto_dp, + ) info("'impute_sex_chromosome_ploidy': computing and checkpointing coverage and karyotype metrics") return per_sample.cols().checkpoint(new_temp_file('impute_sex_karyotype', extension='ht')) -@typecheck(vds=VariantDataset, - calling_intervals=oneof(Table, expr_array(expr_interval(expr_locus()))), - normalization_contig=str, - use_variant_dataset=bool - ) +@typecheck( + vds=VariantDataset, + calling_intervals=oneof(Table, expr_array(expr_interval(expr_locus()))), + normalization_contig=str, + use_variant_dataset=bool, +) def impute_sex_chromosome_ploidy( - vds: VariantDataset, - calling_intervals, - normalization_contig: str, - use_variant_dataset: bool = False -) -> hl.Table: + vds: VariantDataset, calling_intervals, normalization_contig: str, use_variant_dataset: bool = False +) -> Table: """Impute sex chromosome ploidy from depth of reference or variant data within calling intervals. Returns a :class:`.Table` with sample ID keys, with the following fields: @@ -508,14 +355,22 @@ def impute_sex_chromosome_ploidy( """ if not isinstance(calling_intervals, Table): - calling_intervals = hl.Table.parallelize(hl.map(lambda i: hl.struct(interval=i), calling_intervals), - schema=hl.tstruct(interval=calling_intervals.dtype.element_type), - key='interval') + calling_intervals = hl.Table.parallelize( + hl.map(lambda i: hl.struct(interval=i), calling_intervals), + schema=hl.tstruct(interval=calling_intervals.dtype.element_type), + key='interval', + ) else: key_dtype = calling_intervals.key.dtype - if len(key_dtype) != 1 or not isinstance(calling_intervals.key[0].dtype, hl.tinterval) or calling_intervals.key[0].dtype.point_type != vds.reference_data.locus.dtype: - raise ValueError(f"'impute_sex_chromosome_ploidy': expect calling_intervals to be list of intervals or" - f" table with single key of type interval, found table with key: {key_dtype}") + if ( + len(key_dtype) != 1 + or not isinstance(calling_intervals.key[0].dtype, hl.tinterval) + or calling_intervals.key[0].dtype.point_type != vds.reference_data.locus.dtype + ): + raise ValueError( + f"'impute_sex_chromosome_ploidy': expect calling_intervals to be list of intervals or" + f" table with single key of type interval, found table with key: {key_dtype}" + ) rg = vds.reference_data.locus.dtype.reference_genome @@ -528,17 +383,23 @@ def impute_sex_chromosome_ploidy( calling_intervals = hl.segment_intervals(calling_intervals, par_boundaries) # remove intervals overlapping PAR - calling_intervals = calling_intervals.filter(hl.all(lambda x: ~x.overlaps(calling_intervals.interval), hl.literal(rg.par))) + calling_intervals = calling_intervals.filter( + hl.all(lambda x: 
~x.overlaps(calling_intervals.interval), hl.literal(rg.par)) + ) # checkpoint for efficient multiple downstream usages info("'impute_sex_chromosome_ploidy': checkpointing calling intervals") calling_intervals = calling_intervals.checkpoint(new_temp_file(extension='ht')) interval = calling_intervals.key[0] - (any_bad_intervals, chrs_represented) = calling_intervals.aggregate( - (hl.agg.any(interval.start.contig != interval.end.contig), hl.agg.collect_as_set(interval.start.contig))) + (any_bad_intervals, chrs_represented) = calling_intervals.aggregate(( + hl.agg.any(interval.start.contig != interval.end.contig), + hl.agg.collect_as_set(interval.start.contig), + )) if any_bad_intervals: - raise ValueError("'impute_sex_chromosome_ploidy' does not support calling intervals that span chromosome boundaries") + raise ValueError( + "'impute_sex_chromosome_ploidy' does not support calling intervals that span chromosome boundaries" + ) if len(rg.x_contigs) != 1: raise NotImplementedError( @@ -550,8 +411,10 @@ def impute_sex_chromosome_ploidy( ) kept_contig_filter = hl.array(chrs_represented).map(lambda x: hl.parse_locus_interval(x, reference_genome=rg)) - vds = VariantDataset(hl.filter_intervals(vds.reference_data, kept_contig_filter), - hl.filter_intervals(vds.variant_data, kept_contig_filter)) + vds = VariantDataset( + hl.filter_intervals(vds.reference_data, kept_contig_filter), + hl.filter_intervals(vds.variant_data, kept_contig_filter), + ) if use_variant_dataset: mt = vds.variant_data @@ -590,27 +453,29 @@ def filter_variants(vds: 'VariantDataset', variants_table: 'Table', *, keep: boo return VariantDataset(vds.reference_data, variant_data) -@typecheck(vds=VariantDataset, - intervals=oneof(Table, expr_array(expr_interval(expr_any))), - keep=bool, - mode=enumeration('variants_only', 'split_at_boundaries', 'unchecked_filter_both')) -def _parameterized_filter_intervals(vds: 'VariantDataset', - intervals, - keep: bool, - mode: str) -> 'VariantDataset': +@typecheck( + vds=VariantDataset, + intervals=oneof(Table, expr_array(expr_interval(expr_any))), + keep=bool, + mode=enumeration('variants_only', 'split_at_boundaries', 'unchecked_filter_both'), +) +def _parameterized_filter_intervals(vds: 'VariantDataset', intervals, keep: bool, mode: str) -> 'VariantDataset': intervals_table = None if isinstance(intervals, Table): expected = hl.tinterval(hl.tlocus(vds.reference_genome)) if len(intervals.key) != 1 or intervals.key[0].dtype != hl.tinterval(hl.tlocus(vds.reference_genome)): raise ValueError( f"'filter_intervals': expect a table with a single key of type {expected}; " - f"found {list(intervals.key.dtype.values())}") + f"found {list(intervals.key.dtype.values())}" + ) intervals_table = intervals intervals = hl.literal(intervals.aggregate(hl.agg.collect(intervals.key[0]), _localize=False)) if mode == 'unchecked_filter_both': - return VariantDataset(hl.filter_intervals(vds.reference_data, intervals, keep), - hl.filter_intervals(vds.variant_data, intervals, keep)) + return VariantDataset( + hl.filter_intervals(vds.reference_data, intervals, keep), + hl.filter_intervals(vds.variant_data, intervals, keep), + ) reference_data = vds.reference_data if keep: @@ -618,13 +483,17 @@ def _parameterized_filter_intervals(vds: 'VariantDataset', if rbml in vds.reference_data.globals: max_len = hl.eval(vds.reference_data.index_globals()[rbml]) ref_intervals = intervals.map( - lambda interval: hl.interval(interval.start - (max_len - 1), interval.end, interval.includes_start, - interval.includes_end)) + lambda interval: 
hl.interval( + interval.start - (max_len - 1), interval.end, interval.includes_start, interval.includes_end + ) + ) reference_data = hl.filter_intervals(reference_data, ref_intervals, keep) else: - warning("'hl.vds.filter_intervals': filtering intervals without a known max reference block length" - "\n (computed by `hl.vds.store_ref_block_max_length` or 'hl.vds.truncate_reference_blocks')" - "\n requires a full pass over the reference data (expensive!)") + warning( + "'hl.vds.filter_intervals': filtering intervals without a known max reference block length" + "\n (computed by `hl.vds.store_ref_block_max_length` or 'hl.vds.truncate_reference_blocks')" + "\n requires a full pass over the reference data (expensive!)" + ) if mode == 'variants_only': variant_data = hl.filter_intervals(vds.variant_data, intervals, keep) @@ -635,17 +504,20 @@ def _parameterized_filter_intervals(vds: 'VariantDataset', par_intervals = intervals_table or hl.Table.parallelize( intervals.map(lambda x: hl.struct(interval=x)), schema=hl.tstruct(interval=intervals.dtype.element_type), - key='interval') - ref = segment_reference_blocks(reference_data, par_intervals).drop('interval_end', - list(par_intervals.key)[0]) - return VariantDataset(ref, - hl.filter_intervals(vds.variant_data, intervals, keep)) + key='interval', + ) + ref = segment_reference_blocks(reference_data, par_intervals).drop( + 'interval_end', next(iter(par_intervals.key)) + ) + return VariantDataset(ref, hl.filter_intervals(vds.variant_data, intervals, keep)) -@typecheck(vds=VariantDataset, - keep=nullable(oneof(str, sequenceof(str))), - remove=nullable(oneof(str, sequenceof(str))), - keep_autosomes=bool) +@typecheck( + vds=VariantDataset, + keep=nullable(oneof(str, sequenceof(str))), + remove=nullable(oneof(str, sequenceof(str))), + keep_autosomes=bool, +) def filter_chromosomes(vds: 'VariantDataset', *, keep=None, remove=None, keep_autosomes=False) -> 'VariantDataset': """Filter chromosomes of a :class:`.VariantDataset` in several possible modes. 
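A short sketch of the three mutually exclusive modes of the function above, assuming it is exposed as ``hl.vds.filter_chromosomes``; the dataset path and contig names are placeholders, and, as the error message notes, combining ``keep_autosomes`` with ``keep`` or ``remove`` requires two calls::

    import hail as hl

    vds = hl.vds.read_vds('gs://my-bucket/dataset.vds')  # placeholder path

    # Keep only chromosome 22.
    chr22_only = hl.vds.filter_chromosomes(vds, keep='chr22')

    # Drop the sex chromosomes.
    no_sex = hl.vds.filter_chromosomes(vds, remove=['chrX', 'chrY'])

    # Keep autosomes only; to then drop a specific autosome, call the function twice.
    autosomes = hl.vds.filter_chromosomes(vds, keep_autosomes=True)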
@@ -683,8 +555,10 @@ def filter_chromosomes(vds: 'VariantDataset', *, keep=None, remove=None, keep_au if n_args_passed == 0: raise ValueError("filter_chromosomes: expect one of 'keep', 'remove', or 'keep_autosomes' arguments") if n_args_passed > 1: - raise ValueError("filter_chromosomes: expect ONLY one of 'keep', 'remove', or 'keep_autosomes' arguments" - "\n In order use 'keep_autosomes' with 'keep' or 'remove', call the function twice") + raise ValueError( + "filter_chromosomes: expect ONLY one of 'keep', 'remove', or 'keep_autosomes' arguments" + "\n In order use 'keep_autosomes' with 'keep' or 'remove', call the function twice" + ) rg = vds.reference_genome @@ -705,22 +579,20 @@ def filter_chromosomes(vds: 'VariantDataset', *, keep=None, remove=None, keep_au to_keep.append(c) parsed_intervals = hl.literal(to_keep, hl.tarray(hl.tstr)).map( - lambda c: hl.parse_locus_interval(c, reference_genome=rg)) - return _parameterized_filter_intervals(vds, - intervals=parsed_intervals, - keep=True, - mode='unchecked_filter_both') - - -@typecheck(vds=VariantDataset, - intervals=oneof(Table, expr_array(expr_interval(expr_any))), - split_reference_blocks=bool, - keep=bool) -def filter_intervals(vds: 'VariantDataset', - intervals, - *, - split_reference_blocks: bool = False, - keep: bool = True) -> 'VariantDataset': + lambda c: hl.parse_locus_interval(c, reference_genome=rg) + ) + return _parameterized_filter_intervals(vds, intervals=parsed_intervals, keep=True, mode='unchecked_filter_both') + + +@typecheck( + vds=VariantDataset, + intervals=oneof(Table, expr_array(expr_interval(expr_any))), + split_reference_blocks=bool, + keep=bool, +) +def filter_intervals( + vds: 'VariantDataset', intervals, *, split_reference_blocks: bool = False, keep: bool = True +) -> 'VariantDataset': """Filter intervals in a :class:`.VariantDataset`. 
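A minimal sketch of the interval filter above, assuming it is exposed as ``hl.vds.filter_intervals``; the path, interval string, and reference genome are placeholders::

    import hail as hl

    vds = hl.vds.read_vds('gs://my-bucket/dataset.vds')  # placeholder path
    intervals = [hl.parse_locus_interval('chr1:100M-200M', reference_genome='GRCh38')]

    # Keep data in the interval; reference blocks straddling the boundary are kept
    # whole unless split_reference_blocks=True, which truncates them at the boundary.
    kept = hl.vds.filter_intervals(vds, intervals, split_reference_blocks=True)

    # keep=False removes the interval instead; per the check above, it cannot be
    # combined with split_reference_blocks.
    removed = hl.vds.filter_intervals(vds, intervals, keep=False)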
Parameters @@ -743,8 +615,9 @@ def filter_intervals(vds: 'VariantDataset', """ if split_reference_blocks and not keep: raise ValueError("'filter_intervals': cannot use 'split_reference_blocks' with keep=False") - return _parameterized_filter_intervals(vds, intervals, keep=keep, - mode='split_at_boundaries' if split_reference_blocks else 'variants_only') + return _parameterized_filter_intervals( + vds, intervals, keep=keep, mode='split_at_boundaries' if split_reference_blocks else 'variants_only' + ) @typecheck(vds=VariantDataset, filter_changed_loci=bool) @@ -791,15 +664,20 @@ def segment_reference_blocks(ref: 'MatrixTable', intervals: 'Table') -> 'MatrixT ------- :class:`.MatrixTable` """ - interval_field = list(intervals.key)[0] + interval_field = next(iter(intervals.key)) if not intervals[interval_field].dtype == hl.tinterval(ref.locus.dtype): - raise ValueError(f"expect intervals to be keyed by intervals of loci matching the VariantDataset:" - f" found {intervals[interval_field].dtype} / {ref.locus.dtype}") + raise ValueError( + f"expect intervals to be keyed by intervals of loci matching the VariantDataset:" + f" found {intervals[interval_field].dtype} / {ref.locus.dtype}" + ) intervals = intervals.select(_interval_dup=intervals[interval_field]) if not intervals.aggregate( - hl.agg.all(intervals[interval_field].includes_start & ( - intervals[interval_field].start.contig == intervals[interval_field].end.contig))): + hl.agg.all( + intervals[interval_field].includes_start + & (intervals[interval_field].start.contig == intervals[interval_field].end.contig) + ) + ): raise ValueError("expect intervals to be start-inclusive") starts = intervals.key_by(_start_locus=intervals[interval_field].start) @@ -811,17 +689,25 @@ def segment_reference_blocks(ref: 'MatrixTable', intervals: 'Table') -> 'MatrixT contig_idx_map = hl.literal({contigs[i]: i for i in range(len(contigs))}, 'dict') joined = joined.annotate(__contig_idx=contig_idx_map[joined.locus.contig]) joined = joined.annotate( - _ref_entries=joined._ref_entries.map(lambda e: e.annotate(__contig_idx=joined.__contig_idx))) + _ref_entries=joined._ref_entries.map(lambda e: e.annotate(__contig_idx=joined.__contig_idx)) + ) dense = joined.annotate( dense_ref=hl.or_missing( joined._include_locus, - hl.rbind(joined.locus.position, - lambda pos: hl.enumerate(hl.scan._densify(hl.len(joined._ref_cols), joined._ref_entries)) - .map(lambda idx_and_e: hl.rbind(idx_and_e[0], idx_and_e[1], - lambda idx, e: hl.coalesce(joined._ref_entries[idx], hl.or_missing( - (e.__contig_idx == joined.__contig_idx) & (e.END >= pos), - e))).drop('__contig_idx')) - )) + hl.rbind( + joined.locus.position, + lambda pos: hl.enumerate(hl.scan._densify(hl.len(joined._ref_cols), joined._ref_entries)).map( + lambda idx_and_e: hl.rbind( + idx_and_e[0], + idx_and_e[1], + lambda idx, e: hl.coalesce( + joined._ref_entries[idx], + hl.or_missing((e.__contig_idx == joined.__contig_idx) & (e.END >= pos), e), + ), + ).drop('__contig_idx') + ), + ), + ) ) dense = dense.filter(dense._include_locus).drop('_interval_dup', '_include_locus', '__contig_idx') @@ -832,25 +718,43 @@ def segment_reference_blocks(ref: 'MatrixTable', intervals: 'Table') -> 'MatrixT # remove rows that are not contained in an interval, and rows that are the start of an # interval (interval starts come from the 'dense' table) refl_filtered = refl_filtered.filter( - hl.is_defined(refl_filtered[interval_field]) & (refl_filtered.locus != refl_filtered[interval_field].start)) + hl.is_defined(refl_filtered[interval_field]) & 
(refl_filtered.locus != refl_filtered[interval_field].start) + ) # union dense interval starts with filtered table refl_filtered = refl_filtered.union(dense.transmute(_ref_entries=dense.dense_ref)) # rewrite reference blocks to end at the first of (interval end, reference block end) refl_filtered = refl_filtered.annotate( - interval_end=refl_filtered[interval_field].end.position - ~refl_filtered[interval_field].includes_end) + interval_end=refl_filtered[interval_field].end.position - ~refl_filtered[interval_field].includes_end + ) refl_filtered = refl_filtered.annotate( _ref_entries=refl_filtered._ref_entries.map( - lambda entry: entry.annotate(END=hl.min(entry.END, refl_filtered.interval_end)))) + lambda entry: entry.annotate(END=hl.min(entry.END, refl_filtered.interval_end)) + ) + ) return refl_filtered._unlocalize_entries('_ref_entries', '_ref_cols', list(ref.col_key)) -@typecheck(vds=VariantDataset, intervals=Table, gq_thresholds=sequenceof(int), dp_thresholds=sequenceof(int), - dp_field=nullable(str)) -def interval_coverage(vds: VariantDataset, intervals: hl.Table, gq_thresholds=(0, 10, 20,), dp_thresholds=(0, 1, 10, 20, 30), - dp_field=None) -> 'MatrixTable': +@typecheck( + vds=VariantDataset, + intervals=Table, + gq_thresholds=sequenceof(int), + dp_thresholds=sequenceof(int), + dp_field=nullable(str), +) +def interval_coverage( + vds: VariantDataset, + intervals: Table, + gq_thresholds=( + 0, + 10, + 20, + ), + dp_thresholds=(0, 1, 10, 20, 30), + dp_field=None, +) -> 'MatrixTable': """Compute statistics about base coverage by interval. Returns a :class:`.MatrixTable` with interval row keys and sample column keys. @@ -915,49 +819,56 @@ def interval_coverage(vds: VariantDataset, intervals: hl.Table, gq_thresholds=(0 else: dp_field_to_use = dp_field - ref_block_length = (split.END - split.locus.position + 1) + ref_block_length = split.END - split.locus.position + 1 if dp_field_to_use is not None: dp = split[dp_field_to_use] - dp_field_dict = {'sum_dp': hl.agg.sum(ref_block_length * dp), - 'bases_over_dp_threshold': tuple( - hl.agg.filter(dp >= dp_threshold, hl.agg.sum(ref_block_length)) for dp_threshold in - dp_thresholds)} + dp_field_dict = { + 'sum_dp': hl.agg.sum(ref_block_length * dp), + 'bases_over_dp_threshold': tuple( + hl.agg.filter(dp >= dp_threshold, hl.agg.sum(ref_block_length)) for dp_threshold in dp_thresholds + ), + } else: dp_field_dict = dict() - per_interval = split.group_rows_by(interval=intervals[split.row_key[0]].interval_dup) \ - .aggregate( + per_interval = split.group_rows_by(interval=intervals[split.row_key[0]].interval_dup).aggregate( bases_over_gq_threshold=tuple( - hl.agg.filter(split.GQ >= gq_threshold, hl.agg.sum(ref_block_length)) for gq_threshold in - gq_thresholds), - **dp_field_dict + hl.agg.filter(split.GQ >= gq_threshold, hl.agg.sum(ref_block_length)) for gq_threshold in gq_thresholds + ), + **dp_field_dict, ) interval = per_interval.interval - interval_size = interval.end.position + interval.includes_end - interval.start.position - 1 + interval.includes_start + interval_size = ( + interval.end.position + interval.includes_end - interval.start.position - 1 + interval.includes_start + ) per_interval = per_interval.annotate_rows(interval_size=interval_size) dp_mod_dict = {} if dp_field_to_use is not None: dp_mod_dict['fraction_over_dp_threshold'] = tuple( - hl.float(x) / per_interval.interval_size for x in per_interval.bases_over_dp_threshold) + hl.float(x) / per_interval.interval_size for x in per_interval.bases_over_dp_threshold + ) 
dp_mod_dict['mean_dp'] = per_interval.sum_dp / per_interval.interval_size per_interval = per_interval.annotate_entries( fraction_over_gq_threshold=tuple( - hl.float(x) / per_interval.interval_size for x in per_interval.bases_over_gq_threshold), - **dp_mod_dict) + hl.float(x) / per_interval.interval_size for x in per_interval.bases_over_gq_threshold + ), + **dp_mod_dict, + ) per_interval = per_interval.annotate_globals(gq_thresholds=hl.tuple(gq_thresholds)) return per_interval -@typecheck(ds=oneof(MatrixTable, VariantDataset), - max_ref_block_base_pairs=nullable(int), - ref_block_winsorize_fraction=nullable(float)) -def truncate_reference_blocks(ds, *, max_ref_block_base_pairs=None, - ref_block_winsorize_fraction=None): +@typecheck( + ds=oneof(MatrixTable, VariantDataset), + max_ref_block_base_pairs=nullable(int), + ref_block_winsorize_fraction=nullable(float), +) +def truncate_reference_blocks(ds, *, max_ref_block_base_pairs=None, ref_block_winsorize_fraction=None): """Cap reference blocks at a maximum length in order to permit faster interval filtering. Examples @@ -1006,42 +917,54 @@ def truncate_reference_blocks(ds, *, max_ref_block_base_pairs=None, if int(ref_block_winsorize_fraction is None) + int(max_ref_block_base_pairs is None) != 1: raise ValueError( - 'truncate_reference_blocks: require exactly one of "max_ref_block_base_pairs", "ref_block_winsorize_fraction"') + 'truncate_reference_blocks: require exactly one of "max_ref_block_base_pairs", "ref_block_winsorize_fraction"' + ) if ref_block_winsorize_fraction is not None: - assert ref_block_winsorize_fraction > 0 and ref_block_winsorize_fraction < 1, \ - 'truncate_reference_blocks: "ref_block_winsorize_fraction" must be between 0 and 1 (e.g. 0.01 to truncate the top 1% of reference blocks)' + assert ( + ref_block_winsorize_fraction > 0 and ref_block_winsorize_fraction < 1 + ), 'truncate_reference_blocks: "ref_block_winsorize_fraction" must be between 0 and 1 (e.g. 0.01 to truncate the top 1% of reference blocks)' if ref_block_winsorize_fraction > 0.1: warning( f"'truncate_reference_blocks': ref_block_winsorize_fraction of {ref_block_winsorize_fraction} will lead to significant data duplication," - f" recommended values are <0.05.") + f" recommended values are <0.05." 
+ ) max_ref_block_base_pairs = rd.aggregate_entries( - hl.agg.approx_quantiles(rd.END - rd.locus.position + 1, 1 - ref_block_winsorize_fraction, k=200)) + hl.agg.approx_quantiles(rd.END - rd.locus.position + 1, 1 - ref_block_winsorize_fraction, k=200) + ) - assert max_ref_block_base_pairs > 0, \ - 'truncate_reference_blocks: "max_ref_block_base_pairs" must be between greater than zero' + assert ( + max_ref_block_base_pairs > 0 + ), 'truncate_reference_blocks: "max_ref_block_base_pairs" must be between greater than zero' info(f"splitting VDS reference blocks at {max_ref_block_base_pairs} base pairs") - rd_under_limit = (rd.filter_entries(rd.END - rd.locus.position < max_ref_block_base_pairs) - .localize_entries('fixed_blocks', 'cols')) + rd_under_limit = rd.filter_entries(rd.END - rd.locus.position < max_ref_block_base_pairs).localize_entries( + 'fixed_blocks', 'cols' + ) rd_over_limit = rd.filter_entries(rd.END - rd.locus.position >= max_ref_block_base_pairs).key_cols_by( - col_idx=hl.scan.count()) + col_idx=hl.scan.count() + ) rd_over_limit = rd_over_limit.select_rows().select_cols().key_rows_by().key_cols_by() es = rd_over_limit.entries() es = es.annotate(new_start=hl.range(es.locus.position, es.END + 1, max_ref_block_base_pairs)) es = es.explode('new_start') - es = es.transmute(locus=hl.locus(es.locus.contig, es.new_start, reference_genome=es.locus.dtype.reference_genome), - END=hl.min(es.new_start + max_ref_block_base_pairs - 1, es.END)) + es = es.transmute( + locus=hl.locus(es.locus.contig, es.new_start, reference_genome=es.locus.dtype.reference_genome), + END=hl.min(es.new_start + max_ref_block_base_pairs - 1, es.END), + ) es = es.key_by(es.locus).collect_by_key("new_blocks") - es = es.transmute(moved_blocks_dict=hl.dict(es.new_blocks - .map(lambda x: (x.col_idx, x.drop('col_idx'))))) + es = es.transmute(moved_blocks_dict=hl.dict(es.new_blocks.map(lambda x: (x.col_idx, x.drop('col_idx'))))) joined = rd_under_limit.join(es, how='outer') - joined = joined.transmute(merged_blocks=hl.range(hl.len(joined.cols)).map( - lambda idx: hl.coalesce(joined.moved_blocks_dict.get(idx), joined.fixed_blocks[idx]))) - new_rd = joined._unlocalize_entries(entries_field_name='merged_blocks', cols_field_name='cols', - col_key=list(rd.col_key)) + joined = joined.transmute( + merged_blocks=hl.range(hl.len(joined.cols)).map( + lambda idx: hl.coalesce(joined.moved_blocks_dict.get(idx), joined.fixed_blocks[idx]) + ) + ) + new_rd = joined._unlocalize_entries( + entries_field_name='merged_blocks', cols_field_name='cols', col_key=list(rd.col_key) + ) new_rd = new_rd.annotate_globals(**{fd_name: max_ref_block_base_pairs}) if isinstance(ds, hl.vds.VariantDataset): @@ -1049,9 +972,11 @@ def truncate_reference_blocks(ds, *, max_ref_block_base_pairs=None, return new_rd -@typecheck(ds=oneof(MatrixTable, VariantDataset), - equivalence_function=func_spec(2, expr_bool), - merge_functions=nullable(dictof(str, oneof(str, func_spec(1, expr_any))))) +@typecheck( + ds=oneof(MatrixTable, VariantDataset), + equivalence_function=func_spec(2, expr_bool), + merge_functions=nullable(dictof(str, oneof(str, func_spec(1, expr_any)))), +) def merge_reference_blocks(ds, equivalence_function, merge_functions=None): """Merge adjacent reference blocks according to user equivalence criteria. 
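An illustrative sketch of the truncate/merge pair defined above, assuming both are exposed under ``hl.vds`` and that the reference entries carry a ``GQ`` field (the path, GQ band, and block-length cap are placeholders); the string merge functions ``'min'``, ``'max'``, and ``'sum'`` come from the implementation below::

    import hail as hl

    vds = hl.vds.read_vds('gs://my-bucket/dataset.vds')  # placeholder path

    # Cap reference blocks so interval filters stay cheap.
    vds = hl.vds.truncate_reference_blocks(vds, max_ref_block_base_pairs=5000)

    # Collapse adjacent reference blocks that fall in the same GQ band; 'min' keeps
    # the more conservative GQ for the merged block.
    vds = hl.vds.merge_reference_blocks(
        vds,
        equivalence_function=lambda b1, b2: (b1.GQ >= 20) == (b2.GQ >= 20),
        merge_functions={'GQ': 'min'},
    )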
@@ -1096,23 +1021,33 @@ def merge(block1, block2): if merge_functions: for k, f in merge_functions.items(): if isinstance(f, str): - f = f.lower() - if f == 'min': - def f(b1, b2): + _f = f.lower() + if _f == 'min': + + def __f(b1, b2): return hl.min(block1[k], block2[k]) - elif f == 'max': - def f(b1, b2): + + elif _f == 'max': + + def __f(b1, b2): return hl.max(block1[k], block2[k]) - elif f == 'sum': - def f(b1, b2): + + elif _f == 'sum': + + def __f(b1, b2): return block1[k] + block2[k] + else: - raise ValueError(f"merge_reference_blocks: unknown merge function {f!r}," - f" support 'min', 'max', and 'sum' in addition to custom lambdas") - new_value = f(block1, block2) + raise ValueError( + f"merge_reference_blocks: unknown merge function {_f!r}," + f" support 'min', 'max', and 'sum' in addition to custom lambdas" + ) + new_value = __f(block1, block2) if new_value.dtype != block1[k].dtype: - raise ValueError(f'merge_reference_blocks: merge_function for {k!r}: new type {new_value.dtype!r} ' - f'differs from original type {block1[k].dtype!r}') + raise ValueError( + f'merge_reference_blocks: merge_function for {k!r}: new type {new_value.dtype!r} ' + f'differs from original type {block1[k].dtype!r}' + ) new_fields[k] = new_value return block1.annotate(**new_fields) @@ -1120,49 +1055,67 @@ def keep_last(t1, t2): e1 = t1[0] e2 = t2[0] are_adjacent = (e1.contig_idx == e2.contig_idx) & (e1.END + 1 == e2.start_pos) - return hl.if_else(hl.is_defined(e1) & hl.is_defined(e2) & are_adjacent & equivalence_function(e1, e2), - (merge(e1, e2), True), - t2) + return hl.if_else( + hl.is_defined(e1) & hl.is_defined(e2) & are_adjacent & equivalence_function(e1, e2), + (merge(e1, e2), True), + t2, + ) # approximate a scan that merges before result - ht = ht.annotate(prev_block=hl.zip(hl.scan.array_agg(lambda elt: hl.scan.fold((hl.missing(rd.entry.dtype), False), - lambda acc: keep_last(acc, ( - elt, False)), - keep_last), ht.entries), ht.entries) - .map(lambda tup: keep_last(tup[0], (tup[1], False)))) + ht = ht.annotate( + prev_block=hl.zip( + hl.scan.array_agg( + lambda elt: hl.scan.fold( + (hl.missing(rd.entry.dtype), False), lambda acc: keep_last(acc, (elt, False)), keep_last + ), + ht.entries, + ), + ht.entries, + ).map(lambda tup: keep_last(tup[0], (tup[1], False))) + ) ht_join = ht ht = ht.key_by() - ht = ht.select(to_shuffle=hl.enumerate(ht.prev_block) - .filter(lambda idx_and_elt: hl.is_defined(idx_and_elt[1]) & idx_and_elt[1][1])) + ht = ht.select( + to_shuffle=hl.enumerate(ht.prev_block).filter( + lambda idx_and_elt: hl.is_defined(idx_and_elt[1]) & idx_and_elt[1][1] + ) + ) ht = ht.explode('to_shuffle') rg = rd.locus.dtype.reference_genome ht = ht.transmute(col_idx=ht.to_shuffle[0], entry=ht.to_shuffle[1][0]) ht_shuf = ht.key_by( - locus=hl.locus(hl.literal(rg.contigs)[ht.entry.contig_idx], ht.entry.start_pos, reference_genome=rg)) + locus=hl.locus(hl.literal(rg.contigs)[ht.entry.contig_idx], ht.entry.start_pos, reference_genome=rg) + ) ht_shuf = ht_shuf.collect_by_key("new_starts") # new_starts can contain multiple records for a collapsed ref block, one for each folded block. 
# We want to keep the one with the highest END - ht_shuf = ht_shuf.select(moved_blocks_dict=hl.group_by(lambda elt: elt.col_idx, ht_shuf.new_starts) - .map_values( - lambda arr: arr[hl.argmax(arr.map(lambda x: x.entry.END))].entry.drop('contig_idx', 'start_pos'))) + ht_shuf = ht_shuf.select( + moved_blocks_dict=hl.group_by(lambda elt: elt.col_idx, ht_shuf.new_starts).map_values( + lambda arr: arr[hl.argmax(arr.map(lambda x: x.entry.END))].entry.drop('contig_idx', 'start_pos') + ) + ) ht_joined = ht_join.join(ht_shuf.select_globals(), 'left') def merge_f(tup): (idx, original_entry) = tup - return (hl.case() - .when(~(hl.coalesce(ht_joined.prev_block[idx][1], False)), - hl.coalesce(ht_joined.moved_blocks_dict.get(idx), original_entry.drop('contig_idx', 'start_pos'))) - .or_missing()) + return ( + hl.case() + .when( + ~(hl.coalesce(ht_joined.prev_block[idx][1], False)), + hl.coalesce(ht_joined.moved_blocks_dict.get(idx), original_entry.drop('contig_idx', 'start_pos')), + ) + .or_missing() + ) - ht_joined = ht_joined.annotate(new_entries=hl.enumerate(ht_joined.entries) - .map(lambda tup: merge_f(tup))) + ht_joined = ht_joined.annotate(new_entries=hl.enumerate(ht_joined.entries).map(lambda tup: merge_f(tup))) ht_joined = ht_joined.drop('moved_blocks_dict', 'entries', 'prev_block', 'contig_idx_row', 'start_pos_row') - new_rd = ht_joined._unlocalize_entries(entries_field_name='new_entries', cols_field_name='cols', - col_key=list(rd.col_key)) + new_rd = ht_joined._unlocalize_entries( + entries_field_name='new_entries', cols_field_name='cols', col_key=list(rd.col_key) + ) rbml = hl.vds.VariantDataset.ref_block_max_length_field if rbml in new_rd.globals: diff --git a/hail/python/hail/vds/sample_qc.py b/hail/python/hail/vds/sample_qc.py new file mode 100644 index 00000000000..9ec90a4a667 --- /dev/null +++ b/hail/python/hail/vds/sample_qc.py @@ -0,0 +1,408 @@ +from collections.abc import Sequence +from typing import Optional + +import hail as hl +from hail.expr.expressions import Expression +from hail.expr.expressions.typed_expressions import ( + ArrayExpression, + CallExpression, + LocusExpression, + NumericExpression, + StructExpression, +) +from hail.genetics.allele_type import AlleleType +from hail.methods.misc import require_first_key_field_locus +from hail.methods.qc import _qc_allele_type +from hail.table import Table +from hail.typecheck import nullable, sequenceof, typecheck +from hail.utils.java import Env +from hail.utils.misc import divide_null +from hail.vds.variant_dataset import VariantDataset + + +@typecheck(global_gt=Expression, alleles=ArrayExpression) +def vmt_sample_qc_variant_annotations( + *, + global_gt: 'Expression', + alleles: 'ArrayExpression', +) -> tuple['Expression', 'Expression']: + """Compute the necessary variant annotations for :func:`.vmt_sample_qc`, that is, + allele count (AC) and an integer representation of allele type. + + Parameters + ---------- + global_gt : :class:`.Expression` + Call expression of the global GT of a variants matrix table usually generated + by :func:`..lgt_to_gt` + alleles : :class:`.ArrayExpression` + Array expression of the alleles of a variants matrix table + (generally ``vds.variant_data.alleles``) + + Returns + ------- + :class:`tuple` + Tuple of expressions representing the AC (first element) and allele type + (second element). 
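A sketch of how these two row annotations feed the per-sample aggregation, mirroring the ``sample_qc`` driver later in this module; the dataset path is a placeholder, and the densified ``GT`` is derived from ``LGT``/``LA`` when absent::

    import hail as hl
    from hail.vds.sample_qc import vmt_sample_qc, vmt_sample_qc_variant_annotations

    vds = hl.vds.read_vds('gs://my-bucket/dataset.vds')  # placeholder path
    vmt = vds.variant_data
    if 'GT' not in vmt.entry:
        vmt = vmt.annotate_entries(GT=hl.vds.lgt_to_gt(vmt.LGT, vmt.LA))

    # Attach allele counts (AC) and per-alt allele-type codes as row fields.
    ac, atypes = vmt_sample_qc_variant_annotations(global_gt=vmt.GT, alleles=vmt.alleles)
    vmt = vmt.annotate_rows(variant_ac=ac, variant_atypes=atypes)

    # Per-sample variant metrics (n_snp, r_ti_tv, GQ/DP bin counts, ...), keyed by sample.
    variant_qc_ht = vmt.select_cols(**vmt_sample_qc(
        global_gt=vmt.GT,
        gq=vmt.GQ,
        variant_ac=vmt.variant_ac,
        variant_atypes=vmt.variant_atypes,
        dp=vmt.DP if 'DP' in vmt.entry else None,
    )).cols()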
+ """ + + return (hl.agg.call_stats(global_gt, alleles).AC, alleles[1:].map(lambda alt: _qc_allele_type(alleles[0], alt))) + + +@typecheck( + global_gt=Expression, + gq=Expression, + variant_ac=ArrayExpression, + variant_atypes=ArrayExpression, + dp=nullable(Expression), + gq_bins=sequenceof(int), + dp_bins=sequenceof(int), +) +def vmt_sample_qc( + *, + global_gt: 'CallExpression', + gq: 'Expression', + variant_ac: 'ArrayExpression', + variant_atypes: 'ArrayExpression', + dp: Optional['Expression'] = None, + gq_bins: 'Sequence[int]' = (0, 20, 60), + dp_bins: 'Sequence[int]' = (0, 1, 10, 20, 30), +) -> 'Expression': + """Computes sample quality metrics from variant data of a VDS + + Parameters + ---------- + global_gt : :class:`.CallExpression` + Global GT of a variants matrix table or subset thereof (ex. ``hl.agg.group_by``). + gq : :class:`.Expression` + GQ of a variants matrix table. + variant_ac : :class:`.ArrayExpression` + Allele counts of a the genotypes of a variants matrix table. This can + be generated by ``hl.agg.call_stats`` or alternatively + :func:`.vmt_sample_qc_variant_annotations` (which calls ``call_stats`` + internally) + variant_atypes : :class:`.ArrayExpression` + Allele types of the alternate alleles a variants matrix table. This + must be generated with :func:`.vmt_sample_qc_variant_annotations` in + order to return correct results. + dp : :class:`.Expression` or :obj:`NoneType` + DP of a variants matrix table (or ``None``) + gq_bins : :class:`tuple` of :obj:`int` + Tuple containing cutoffs for genotype quality (GQ) scores. + dp_bins : :class:`tuple` of :obj:`int` + Tuple containing cutoffs for depth (DP) scores. + + Returns + ------- + :class:`.StructExpression` + A struct expression of type:: + + struct{ + bases_over_gq_threshold: tuple(int64 * len(gq_bins)), + bases_over_dp_threshold: tuple(int64 * len(gq_bins)), # present if dp is not None + n_het: int64, + n_hom_var: int64, + n_non_ref: int64, + n_singleton: int64, + n_singleton_ti: int64, + n_singleton_tv: int64, + n_snp: int64, + n_insertion: int64, + n_deletion: int64, + n_transition: int64, + n_transversion: int64, + n_star: int64, + r_ti_tv: float64, + r_ti_tv_singleton: float64, + r_het_hom_var: float64, + r_insertion_deletion: float64, + } + + """ + bound_exprs = {} + + bound_exprs['n_het'] = hl.agg.count_where(global_gt.is_het()) + bound_exprs['n_hom_var'] = hl.agg.count_where(global_gt.is_hom_var()) + bound_exprs['n_singleton'] = hl.agg.sum( + hl.rbind( + global_gt, + lambda global_gt: hl.sum( + hl.range(0, global_gt.ploidy).map( + lambda i: hl.rbind(global_gt[i], lambda gti: (gti != 0) & (variant_ac[gti] == 1)) + ) + ), + ) + ) + bound_exprs['n_singleton_ti'] = hl.agg.sum( + hl.rbind( + global_gt, + lambda global_gt: hl.sum( + hl.range(0, global_gt.ploidy).map( + lambda i: hl.rbind( + global_gt[i], + lambda gti: (gti != 0) + & (variant_ac[gti] == 1) + & (variant_atypes[gti - 1] == AlleleType.TRANSITION), + ) + ) + ), + ) + ) + bound_exprs['n_singleton_tv'] = hl.agg.sum( + hl.rbind( + global_gt, + lambda global_gt: hl.sum( + hl.range(0, global_gt.ploidy).map( + lambda i: hl.rbind( + global_gt[i], + lambda gti: (gti != 0) + & (variant_ac[gti] == 1) + & (variant_atypes[gti - 1] == AlleleType.TRANSVERSION), + ) + ) + ), + ) + ) + + bound_exprs['allele_type_counts'] = hl.agg.explode( + lambda allele_type: hl.tuple(hl.agg.count_where(allele_type == i) for i in range(len(AlleleType))), + ( + hl.range(0, global_gt.ploidy) + .map(lambda i: global_gt[i]) + .filter(lambda allele_idx: allele_idx > 0) + .map(lambda 
allele_idx: variant_atypes[allele_idx - 1]) + ), + ) + + dp_exprs = {} + if dp is not None: + dp_exprs['bases_over_dp_threshold'] = hl.tuple(hl.agg.count_where(dp >= x) for x in dp_bins) + + gq_dp_exprs = {'bases_over_gq_threshold': hl.tuple(hl.agg.count_where(gq >= x) for x in gq_bins), **dp_exprs} + + return hl.rbind( + hl.struct(**bound_exprs), + lambda x: hl.rbind( + hl.struct(**{ + **gq_dp_exprs, + 'n_het': x.n_het, + 'n_hom_var': x.n_hom_var, + 'n_non_ref': x.n_het + x.n_hom_var, + 'n_singleton': x.n_singleton, + 'n_singleton_ti': x.n_singleton_ti, + 'n_singleton_tv': x.n_singleton_tv, + 'n_snp': x.allele_type_counts[AlleleType.TRANSITION] + x.allele_type_counts[AlleleType.TRANSVERSION], + 'n_insertion': x.allele_type_counts[AlleleType.INSERTION], + 'n_deletion': x.allele_type_counts[AlleleType.DELETION], + 'n_transition': x.allele_type_counts[AlleleType.TRANSITION], + 'n_transversion': x.allele_type_counts[AlleleType.TRANSVERSION], + 'n_star': x.allele_type_counts[AlleleType.STAR], + }), + lambda s: s.annotate( + r_ti_tv=divide_null(hl.float64(s.n_transition), s.n_transversion), + r_ti_tv_singleton=divide_null(hl.float64(s.n_singleton_ti), s.n_singleton_tv), + r_het_hom_var=divide_null(hl.float64(s.n_het), s.n_hom_var), + r_insertion_deletion=divide_null(hl.float64(s.n_insertion), s.n_deletion), + ), + ), + ) + + +@typecheck( + locus=LocusExpression, + gq=NumericExpression, + end=NumericExpression, + dp=nullable(Expression), + gq_bins=sequenceof(int), + dp_bins=sequenceof(int), +) +def rmt_sample_qc( + *, + locus: 'LocusExpression', + end: 'NumericExpression', + gq: 'NumericExpression', + dp: Optional['Expression'] = None, + gq_bins: 'Sequence[int]' = (0, 20, 60), + dp_bins: 'Sequence[int]' = (0, 1, 10, 20, 30), +) -> 'StructExpression': + """Computes sample quality metrics from reference data of a VDS + Parameters + ---------- + locus : :class:`.LocusExpression` + Locus of a refrence matrix table + end : :class:`.NumericExpression` + END of a reference matrix table + gq : :class:`.Expression` + GQ of a variants matrix table. + dp : :class:`.Expression` or :obj:`NoneType` + DP of a variants matrix table (or ``None``) + gq_bins : :class:`tuple` of :obj:`int` + Tuple containing cutoffs for genotype quality (GQ) scores. + dp_bins : :class:`tuple` of :obj:`int` + Tuple containing cutoffs for depth (DP) scores. 
+ + Returns + ------- + :class:`.StructExpression` + A struct expression of type:: + + struct{ + bases_over_gq_threshold: tuple(int64 * len(gq_bins)), + bases_over_dp_threshold: tuple(int64 * len(dp_bins)), # present if dp is not None + } + + """ + ref_dp_expr = {} + if dp is not None: + ref_dp_expr['bases_over_dp_threshold'] = hl.tuple( + hl.agg.filter(dp >= x, hl.agg.sum(1 + end - locus.position)) for x in dp_bins + ) + return hl.struct( + bases_over_gq_threshold=hl.tuple(hl.agg.filter(gq >= x, hl.agg.sum(1 + end - locus.position)) for x in gq_bins), + **ref_dp_expr, + ) + + +def combine_sample_qc( + rmt_sample_qc: Expression, + vmt_sample_qc: Expression, +) -> Expression: + """Combine reference and variants sample quality results + Parameters + ---------- + rmt_sample_qc : :class:`.Expression` + A struct expression produced by :func:`.rmt_sample_qc` + vmt_sample_qc : :class:`.Expression` + A struct expression produced by :func:`.vmt_sample_qc` + + Returns + ------- + :class:`.StructExpression` + A struct expression of type:: + + struct{ + bases_over_gq_threshold: + tuple(int64 * len(rmt_sample_qc.bases_over_gq_threshold)), + bases_over_dp_threshold: # present if dp was present for qc stats generation + tuple(int64 * len(rmt_sample_qc.bases_over_dp_threshold)), + } + + Note + ---- + It is the responsibility of the caller of this function to make sure that + the ``gq_bins`` and ``dp_bins`` that are used for the generation of both of + the arguments to this function are the same. Incorrect results will occur + if the bins are not the same. This function checks the length of the bins + used, but cannot check the bin values themselves. + """ + if 'bases_over_gq_threshold' not in rmt_sample_qc: + raise ValueError("Expect 'bases_over_gq_threshold' field in 'rmt_sample_qc' expression") + if 'bases_over_gq_threshold' not in vmt_sample_qc: + raise ValueError("Expect 'bases_over_gq_threshold' field in 'vmt_sample_qc' expression") + if sum('bases_over_dp_threshold' in expr for expr in (rmt_sample_qc, vmt_sample_qc)) % 2 == 1: + raise ValueError( + "Expect 'bases_over_dp_threshold' field in both or neither of " "'rmt_sample_qc' and 'vmt_sample_qc'" + ) + if len(rmt_sample_qc.bases_over_gq_threshold) != len(vmt_sample_qc.bases_over_gq_threshold): + raise ValueError("Expect same number of GQ bins for both variant and reference qc results") + if 'bases_over_dp_threshold' in rmt_sample_qc and len(rmt_sample_qc.bases_over_dp_threshold) != len( + vmt_sample_qc.bases_over_dp_threshold + ): + raise ValueError("Expect same number of DP bins for both variant and reference qc results") + + joined_dp_expr = {} + if 'bases_over_dp_threshold' in vmt_sample_qc: + joined_dp_expr['bases_over_dp_threshold'] = hl.tuple( + x + y for x, y in zip(vmt_sample_qc.bases_over_dp_threshold, rmt_sample_qc.bases_over_dp_threshold) + ) + + return hl.struct( + bases_over_gq_threshold=hl.tuple( + x + y for x, y in zip(vmt_sample_qc.bases_over_gq_threshold, rmt_sample_qc.bases_over_gq_threshold) + ), + **joined_dp_expr, + ) + + +@typecheck(vds=VariantDataset, gq_bins=sequenceof(int), dp_bins=sequenceof(int), dp_field=nullable(str)) +def sample_qc( + vds: 'VariantDataset', + *, + gq_bins: 'Sequence[int]' = (0, 20, 60), + dp_bins: 'Sequence[int]' = (0, 1, 10, 20, 30), + dp_field=None, +) -> 'Table': + """Compute sample quality metrics about a :class:`.VariantDataset`. + + If the `dp_field` parameter is not specified, the ``DP`` is used for depth + if present. If no ``DP`` field is present, the ``MIN_DP`` field is used. 
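For the common case, the assembled driver can be called in one line; a minimal sketch, assuming the function is surfaced as ``hl.vds.sample_qc`` and using a placeholder path (the bins shown are just the defaults)::

    import hail as hl

    vds = hl.vds.read_vds('gs://my-bucket/dataset.vds')  # placeholder path

    # Combined reference + variant per-sample QC. With dp_field unset, DP is used,
    # falling back to MIN_DP; depth stats are skipped if neither field is present.
    qc = hl.vds.sample_qc(vds, gq_bins=(0, 20, 60), dp_bins=(0, 1, 10, 20, 30))
    qc.describe()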
If no ``DP`` + or ``MIN_DP`` field is present, no depth statistics will be calculated. + + Parameters + ---------- + vds : :class:`.VariantDataset` + Dataset in VariantDataset representation. + gq_bins : :class:`tuple` of :obj:`int` + Tuple containing cutoffs for genotype quality (GQ) scores. + dp_bins : :class:`tuple` of :obj:`int` + Tuple containing cutoffs for depth (DP) scores. + dp_field : :obj:`str` + Name of depth field. If not supplied, DP or MIN_DP will be used, in that order. + + Returns + ------- + :class:`.Table` + Hail Table of results, keyed by sample. + """ + + require_first_key_field_locus(vds.reference_data, 'sample_qc') + require_first_key_field_locus(vds.variant_data, 'sample_qc') + + if dp_field is not None: + ref_dp_field_to_use = dp_field + elif 'DP' in vds.reference_data.entry: + ref_dp_field_to_use = 'DP' + elif 'MIN_DP' in vds.reference_data.entry: + ref_dp_field_to_use = 'MIN_DP' + else: + ref_dp_field_to_use = None + + vmt = vds.variant_data + if 'GT' not in vmt.entry: + vmt = vmt.annotate_entries(GT=hl.vds.lgt_to_gt(vmt.LGT, vmt.LA)) + allele_count, atypes = vmt_sample_qc_variant_annotations(global_gt=vmt.GT, alleles=vmt.alleles) + variant_ac = Env.get_uid() + variant_atypes = Env.get_uid() + vmt = vmt.annotate_rows(**{variant_ac: allele_count, variant_atypes: atypes}) + vmt_dp = vmt['DP'] if ref_dp_field_to_use is not None and 'DP' in vmt.entry else None + variant_results = vmt.select_cols( + **vmt_sample_qc( + global_gt=vmt.GT, + gq=vmt.GQ, + variant_ac=vmt[variant_ac], + variant_atypes=vmt[variant_atypes], + dp=vmt_dp, + gq_bins=gq_bins, + dp_bins=dp_bins, + ) + ).cols() + + rmt = vds.reference_data + rmt_dp = rmt[ref_dp_field_to_use] if ref_dp_field_to_use is not None else None + reference_results = rmt.select_cols( + **rmt_sample_qc( + locus=rmt.locus, + gq=rmt.GQ, + end=rmt.END, + dp=rmt_dp, + gq_bins=gq_bins, + dp_bins=dp_bins, + ) + ).cols() + + joined = reference_results[variant_results.key] + dp_bins_field = {} + if ref_dp_field_to_use is not None: + dp_bins_field['dp_bins'] = hl.tuple(dp_bins) + joined_results = variant_results.transmute(**combine_sample_qc(joined, variant_results.row)) + joined_results = joined_results.annotate_globals(gq_bins=hl.tuple(gq_bins), **dp_bins_field) + return joined_results diff --git a/hail/python/hail/vds/variant_dataset.py b/hail/python/hail/vds/variant_dataset.py index 388897298d3..db13b5ca2da 100644 --- a/hail/python/hail/vds/variant_dataset.py +++ b/hail/python/hail/vds/variant_dataset.py @@ -1,19 +1,24 @@ +import json import os import hail as hl +from hail.genetics import ReferenceGenome from hail.matrixtable import MatrixTable from hail.typecheck import typecheck_method from hail.utils.java import info, warning -from hail.genetics import ReferenceGenome - -import json extra_ref_globals_file = 'extra_reference_globals.json' -def read_vds(path, *, intervals=None, n_partitions=None, - _assert_reference_type=None, _assert_variant_type=None, - _warn_no_ref_block_max_length=True) -> 'VariantDataset': +def read_vds( + path, + *, + intervals=None, + n_partitions=None, + _assert_reference_type=None, + _assert_variant_type=None, + _warn_no_ref_block_max_length=True, +) -> 'VariantDataset': """Read in a :class:`.VariantDataset` written with :meth:`.VariantDataset.write`. 
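A short sketch of reading a VDS and, for older datasets, patching it in place so that fast interval filtering stays available; the paths are placeholders and ``n_partitions`` is optional::

    import hail as hl

    vds = hl.vds.read_vds('gs://my-bucket/dataset.vds', n_partitions=2000)

    # Older VDSes lack the reference-block max-length global; patching it in place
    # (or rewriting a copy with hl.vds.truncate_reference_blocks) avoids the
    # expensive full pass warned about below.
    hl.vds.store_ref_block_max_length('gs://my-bucket/old_dataset.vds')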
Parameters @@ -44,10 +49,12 @@ def read_vds(path, *, intervals=None, n_partitions=None, metadata = json.load(f) vds.reference_data = vds.reference_data.annotate_globals(**metadata) elif _warn_no_ref_block_max_length: - warning("You are reading a VDS written with an older version of Hail." - "\n Hail now supports much faster interval filters on VDS, but you'll need to run either" - "\n `hl.vds.truncate_reference_blocks(vds, ...)` and write a copy (see docs) or patch the" - "\n existing VDS in place with `hl.vds.store_ref_block_max_length(vds_path)`.") + warning( + "You are reading a VDS written with an older version of Hail." + "\n Hail now supports much faster interval filters on VDS, but you'll need to run either" + "\n `hl.vds.truncate_reference_blocks(vds, ...)` and write a copy (see docs) or patch the" + "\n existing VDS in place with `hl.vds.store_ref_block_max_length(vds_path)`." + ) return vds @@ -75,7 +82,7 @@ def store_ref_block_max_length(vds_path): ---------- vds_path : :obj:`str` """ - vds = hl.vds.read_vds(vds_path) + vds = read_vds(vds_path, _warn_no_ref_block_max_length=False) if VariantDataset.ref_block_max_length_field in vds.reference_data.globals: warning(f"VDS at {vds_path} already contains a global annotation with the max reference block length") @@ -115,24 +122,23 @@ def _variants_path(base: str) -> str: return os.path.join(base, 'variant_data') @staticmethod - def from_merged_representation(mt, - *, - ref_block_fields=(), - infer_ref_block_fields: bool = True, - is_split=False): + def from_merged_representation(mt, *, ref_block_fields=(), infer_ref_block_fields: bool = True, is_split=False): """Create a VariantDataset from a sparse MatrixTable containing variant and reference data.""" if 'END' not in mt.entry: raise ValueError("VariantDataset.from_merged_representation: expect field 'END' in matrix table entry") if 'LA' not in mt.entry and not is_split: - raise ValueError("VariantDataset.from_merged_representation: expect field 'LA' in matrix table entry." - "\n If this dataset is already split into biallelics, use `is_split=True` to permit a conversion" - " with no LA field.") + raise ValueError( + "VariantDataset.from_merged_representation: expect field 'LA' in matrix table entry." + "\n If this dataset is already split into biallelics, use `is_split=True` to permit a conversion" + " with no LA field." 
+ ) if 'GT' not in mt.entry and 'LGT' not in mt.entry: raise ValueError( - "VariantDataset.from_merged_representation: expect field 'LGT' or 'GT' in matrix table entry") + "VariantDataset.from_merged_representation: expect field 'LGT' or 'GT' in matrix table entry" + ) n_rows_to_use = 100 info(f"inferring reference block fields from missingness patterns in first {n_rows_to_use} rows") @@ -142,9 +148,13 @@ def from_merged_representation(mt, if infer_ref_block_fields: mt_head = mt.head(n_rows=n_rows_to_use) for k, any_present in zip( - list(mt_head.entry), - mt_head.aggregate_entries(hl.agg.filter(hl.is_defined(mt_head.END), tuple( - hl.agg.any(hl.is_defined(mt_head[x])) for x in mt_head.entry)))): + list(mt_head.entry), + mt_head.aggregate_entries( + hl.agg.filter( + hl.is_defined(mt_head.END), tuple(hl.agg.any(hl.is_defined(mt_head[x])) for x in mt_head.entry) + ) + ), + ): if any_present: used_ref_block_fields.add(k) @@ -156,15 +166,24 @@ def from_merged_representation(mt, if 'LA' in used_ref_block_fields: used_ref_block_fields.remove('LA') - info("Including the following fields in reference block table:" + "".join( - f"\n {k!r}" for k in mt.entry if k in used_ref_block_fields)) - - rmt = mt.filter_entries(hl.case() - .when(hl.is_missing(mt.END), False) - .when(hl.is_defined(mt.END) & mt[gt_field].is_hom_ref(), True) - .or_error(hl.str('cannot create VDS from merged representation -' - ' found END field with non-reference genotype at ') - + hl.str(mt.locus) + hl.str(' / ') + hl.str(mt.col_key[0]))) + info( + "Including the following fields in reference block table:" + + "".join(f"\n {k!r}" for k in mt.entry if k in used_ref_block_fields) + ) + + rmt = mt.filter_entries( + hl.case() + .when(hl.is_missing(mt.END), False) + .when(hl.is_defined(mt.END) & mt[gt_field].is_hom_ref(), True) + .or_error( + hl.str( + 'cannot create VDS from merged representation -' ' found END field with non-reference genotype at ' + ) + + hl.str(mt.locus) + + hl.str(' / ') + + hl.str(mt.col_key[0]) + ) + ) rmt = rmt.select_entries(*(x for x in rmt.entry if x in used_ref_block_fields)) rmt = rmt.filter_rows(hl.agg.count() > 0) @@ -219,31 +238,32 @@ def error(msg): raise ValueError(f'VDS.validate: {msg}') rd_row_key = rd.row_key.dtype - if (not isinstance(rd_row_key, hl.tstruct) - or len(rd_row_key) != 1 - or not rd_row_key.fields[0] == 'locus' - or not isinstance(rd_row_key.types[0], hl.tlocus)): + if ( + not isinstance(rd_row_key, hl.tstruct) + or len(rd_row_key) != 1 + or not rd_row_key.fields[0] == 'locus' + or not isinstance(rd_row_key.types[0], hl.tlocus) + ): error(f"expect reference data to have a single row key 'locus' of type locus, found {rd_row_key}") vd_row_key = vd.row_key.dtype - if (not isinstance(vd_row_key, hl.tstruct) - or len(vd_row_key) != 2 - or not vd_row_key.fields == ('locus', 'alleles') - or not isinstance(vd_row_key.types[0], hl.tlocus) - or vd_row_key.types[1] != hl.tarray(hl.tstr)): + if ( + not isinstance(vd_row_key, hl.tstruct) + or len(vd_row_key) != 2 + or not vd_row_key.fields == ('locus', 'alleles') + or not isinstance(vd_row_key.types[0], hl.tlocus) + or vd_row_key.types[1] != hl.tarray(hl.tstr) + ): error( - f"expect variant data to have a row key {{'locus': locus, alleles: array}}, found {vd_row_key}") + f"expect variant data to have a row key {{'locus': locus, alleles: array}}, found {vd_row_key}" + ) rd_col_key = rd.col_key.dtype - if (not isinstance(rd_col_key, hl.tstruct) - or len(rd_row_key) != 1 - or rd_col_key.types[0] != hl.tstr): + if not isinstance(rd_col_key, 
hl.tstruct) or len(rd_row_key) != 1 or rd_col_key.types[0] != hl.tstr: error(f"expect reference data to have a single col key of type string, found {rd_col_key}") vd_col_key = vd.col_key.dtype - if (not isinstance(vd_col_key, hl.tstruct) - or len(vd_col_key) != 1 - or vd_col_key.types[0] != hl.tstr): + if not isinstance(vd_col_key, hl.tstruct) or len(vd_col_key) != 1 or vd_col_key.types[0] != hl.tstr: error(f"expect variant data to have a single col key of type string, found {vd_col_key}") if 'END' not in rd.entry or rd.END.dtype != hl.tint32: @@ -255,14 +275,16 @@ def error(msg): var_cols = vd.col_key.collect() if len(ref_cols) != len(var_cols): error( - f"mismatch in number of columns: reference data has {ref_cols} columns, variant data has {var_cols} columns") + f"mismatch in number of columns: reference data has {ref_cols} columns, variant data has {var_cols} columns" + ) if ref_cols != var_cols: first_mismatch = 0 - while (ref_cols[first_mismatch] == var_cols[first_mismatch]): + while ref_cols[first_mismatch] == var_cols[first_mismatch]: first_mismatch += 1 error( - f"mismatch in columns keys: ref={ref_cols[first_mismatch]}, var={var_cols[first_mismatch]} at position {first_mismatch}") + f"mismatch in columns keys: ref={ref_cols[first_mismatch]}, var={var_cols[first_mismatch]} at position {first_mismatch}" + ) # check locus distinctness n_rd_rows = rd.count_rows() @@ -272,24 +294,40 @@ def error(msg): error(f'reference data loci are not distinct: found {n_rd_rows} rows, but {n_distinct} distinct loci') # check END field - (missing_end, end_before_position) = rd.aggregate_entries(( - hl.agg.filter(hl.is_missing(rd.END), hl.agg.take((rd.row_key, rd.col_key), 5)), - hl.agg.filter(rd.END < rd.locus.position, hl.agg.take((rd.row_key, rd.col_key), 5)), - )) - - if missing_end: + end_exprs = dict( + missing_end=hl.agg.filter(hl.is_missing(rd.END), hl.agg.take((rd.row_key, rd.col_key), 5)), + end_before_position=hl.agg.filter(rd.END < rd.locus.position, hl.agg.take((rd.row_key, rd.col_key), 5)), + ) + if VariantDataset.ref_block_max_length_field in rd.globals: + rbml = rd[VariantDataset.ref_block_max_length_field] + end_exprs['blocks_too_long'] = hl.agg.filter( + rd.END - rd.locus.position + 1 > rbml, hl.agg.take((rd.row_key, rd.col_key), 5) + ) + + res = rd.aggregate_entries(hl.struct(**end_exprs)) + + if res.missing_end: + error( + 'found records in reference data with missing END field\n ' + + '\n '.join(str(x) for x in res.missing_end) + ) + if res.end_before_position: error( - 'found records in reference data with missing END field\n ' + '\n '.join( - str(x) for x in missing_end)) - if end_before_position: - error('found records in reference data with END before locus position\n ' + '\n '.join( - str(x) for x in end_before_position)) + 'found records in reference data with END before locus position\n ' + + '\n '.join(str(x) for x in res.end_before_position) + ) + blocks_too_long = res.get('blocks_too_long', []) + if blocks_too_long: + error( + 'found records in reference data with blocks larger than `ref_block_max_length`\n ' + + '\n '.join(str(x) for x in blocks_too_long) + ) def _same(self, other: 'VariantDataset'): return self.reference_data._same(other.reference_data) and self.variant_data._same(other.variant_data) def union_rows(*vdses): - '''Combine many VDSes with the same samples but disjoint variants. + """Combine many VDSes with the same samples but disjoint variants. **Examples** @@ -300,7 +338,7 @@ def union_rows(*vdses): ... 
vds_per_chrom = [hl.vds.read_vds(path) for path in vds_paths) # doctest: +SKIP ... hl.vds.VariantDataset.union_rows(*vds_per_chrom) # doctest: +SKIP - ''' + """ fd = hl.vds.VariantDataset.ref_block_max_length_field mts = [vds.reference_data for vds in vdses] @@ -310,7 +348,9 @@ def union_rows(*vdses): # if some mts have max ref len but not all, drop it if all_ref_max: - new_ref_mt = hl.MatrixTable.union_rows(*mts).annotate_globals(**{fd: hl.max([mt.index_globals()[fd] for mt in mts])}) + new_ref_mt = hl.MatrixTable.union_rows(*mts).annotate_globals(**{ + fd: hl.max([mt.index_globals()[fd] for mt in mts]) + }) else: if any_ref_max: mts = [mt.drop(fd) if fd in mt.globals else mt for mt in mts] diff --git a/hail/python/hailtop/__init__.py b/hail/python/hailtop/__init__.py index cfeb992ed83..06926474381 100644 --- a/hail/python/hailtop/__init__.py +++ b/hail/python/hailtop/__init__.py @@ -5,6 +5,7 @@ def version() -> str: global _VERSION if _VERSION is None: import pkg_resources # pylint: disable=import-outside-toplevel + _VERSION = pkg_resources.resource_string(__name__, 'hail_version').decode().strip() return _VERSION @@ -21,6 +22,7 @@ def is_notebook() -> bool: if IS_NOTEBOOK is None: try: from IPython.core.getipython import get_ipython # pylint: disable=import-outside-toplevel + IS_NOTEBOOK = get_ipython().__class__.__name__ == 'ZMQInteractiveShell' except (NameError, ModuleNotFoundError): IS_NOTEBOOK = False diff --git a/hail/python/hailtop/aiocloud/aioaws/fs.py b/hail/python/hailtop/aiocloud/aioaws/fs.py index 6db1b402177..b45bcbae7cc 100644 --- a/hail/python/hailtop/aiocloud/aioaws/fs.py +++ b/hail/python/hailtop/aiocloud/aioaws/fs.py @@ -1,27 +1,50 @@ -from typing import (Any, AsyncIterator, BinaryIO, cast, AsyncContextManager, Dict, List, Optional, - Set, Tuple, Type, ClassVar) -from types import TracebackType -import sys -from concurrent.futures import ThreadPoolExecutor -import os.path -import threading import asyncio -import logging import datetime - +import logging +import os.path +import sys +import threading +from concurrent.futures import ThreadPoolExecutor +from types import TracebackType +from typing import ( + Any, + AsyncContextManager, + AsyncIterator, + BinaryIO, + Dict, + List, + Optional, + Set, + Tuple, + Type, + Union, + cast, +) + +import aiohttp +import boto3 import botocore.config import botocore.exceptions -import boto3 -from hailtop.utils import blocking_to_async -from hailtop.aiotools.fs import (FileStatus, FileListEntry, ReadableStream, WritableStream, AsyncFS, - AsyncFSURL, MultiPartCreate, FileAndDirectoryError) + +from hailtop.aiotools.fs import ( + AsyncFS, + AsyncFSURL, + FileAndDirectoryError, + FileListEntry, + FileStatus, + IsABucketError, + MultiPartCreate, + ReadableStream, + WritableStream, +) from hailtop.aiotools.fs.exceptions import UnexpectedEOFError from hailtop.aiotools.fs.stream import ( AsyncQueueWritableStream, - async_writable_blocking_readable_stream_pair, async_writable_blocking_collect_pair, - blocking_readable_stream_to_async) - + async_writable_blocking_readable_stream_pair, + blocking_readable_stream_to_async, +) +from hailtop.utils import blocking_to_async log = logging.getLogger(__name__) @@ -41,27 +64,40 @@ def __aiter__(self) -> 'PageIterator': async def __anext__(self): if self._page is None: - self._page = await blocking_to_async(self._fs._thread_pool, self._fs._s3.list_objects_v2, # type: ignore - Bucket=self._bucket, - Prefix=self._prefix, - **self._kwargs) + self._page = await blocking_to_async( + self._fs._thread_pool, + 
self._fs._s3.list_objects_v2, # type: ignore + Bucket=self._bucket, + Prefix=self._prefix, + **self._kwargs, + ) return self._page next_continuation_token = self._page.get('NextContinuationToken') if next_continuation_token is not None: - self._page = await blocking_to_async(self._fs._thread_pool, self._fs._s3.list_objects_v2, - Bucket=self._bucket, - Prefix=self._prefix, - ContinuationToken=next_continuation_token, - **self._kwargs) + self._page = await blocking_to_async( + self._fs._thread_pool, + self._fs._s3.list_objects_v2, + Bucket=self._bucket, + Prefix=self._prefix, + ContinuationToken=next_continuation_token, + **self._kwargs, + ) return self._page raise StopAsyncIteration class S3HeadObjectFileStatus(FileStatus): - def __init__(self, head_object_resp): + def __init__(self, head_object_resp, url: str): self.head_object_resp = head_object_resp + self._url = url + + def basename(self) -> str: + return os.path.basename(self._url.rstrip('/')) + + def url(self) -> str: + return self._url async def size(self) -> int: return self.head_object_resp['ContentLength'] @@ -81,8 +117,15 @@ async def __getitem__(self, key: str) -> Any: class S3ListFilesFileStatus(FileStatus): - def __init__(self, item: Dict[str, Any]): + def __init__(self, item: Dict[str, Any], url: str): self._item = item + self._url = url + + def basename(self) -> str: + return os.path.basename(self._url.rstrip('/')) + + def url(self) -> str: + return self._url async def size(self) -> int: return self._item['Size'] @@ -127,9 +170,11 @@ def put(): return async_writable async def __aexit__( - self, exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - exc_traceback: Optional[TracebackType] = None) -> None: + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + exc_traceback: Optional[TracebackType] = None, + ) -> None: assert self.async_writable assert self._put_thread await self.async_writable.wait_closed() @@ -150,8 +195,8 @@ def __init__(self, bucket: str, key: str, item: Optional[Dict[str, Any]]): self._item = item self._status: Optional[S3ListFilesFileStatus] = None - def name(self) -> str: - return os.path.basename(self._key) + def basename(self) -> str: + return os.path.basename(self._key.rstrip('/')) async def url(self) -> str: return f's3://{self._bucket}/{self._key}' @@ -166,7 +211,7 @@ async def status(self) -> FileStatus: if self._status is None: if self._item is None: raise IsADirectoryError(f's3://{self._bucket}/{self._key}') - self._status = S3ListFilesFileStatus(self._item) + self._status = S3ListFilesFileStatus(self._item, await self.url()) return self._status @@ -191,7 +236,8 @@ def put(): Key=self._mpc._name, PartNumber=self._number + 1, UploadId=self._mpc._upload_id, - Body=b) + Body=b, + ) self._mpc._etags[self._number] = resp['ETag'] except BaseException as e: self._exc = e @@ -201,9 +247,11 @@ def put(): return async_writable async def __aexit__( - self, exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - exc_traceback: Optional[TracebackType] = None) -> None: + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + exc_traceback: Optional[TracebackType] = None, + ) -> None: assert self._async_writable is not None assert self._put_thread is not None await self._async_writable.wait_closed() @@ -229,38 +277,43 @@ def __init__(self, sema: asyncio.Semaphore, fs: 'S3AsyncFS', bucket: str, name: self._etags: List[Optional[str]] = 
[None] * num_parts async def __aenter__(self) -> 'S3MultiPartCreate': - resp = await blocking_to_async(self._fs._thread_pool, self._fs._s3.create_multipart_upload, - Bucket=self._bucket, - Key=self._name) + resp = await blocking_to_async( + self._fs._thread_pool, self._fs._s3.create_multipart_upload, Bucket=self._bucket, Key=self._name + ) self._upload_id = resp['UploadId'] return self async def __aexit__( - self, exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - exc_traceback: Optional[TracebackType] = None) -> None: + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + exc_traceback: Optional[TracebackType] = None, + ) -> None: if exc_value is not None: - await blocking_to_async(self._fs._thread_pool, self._fs._s3.abort_multipart_upload, - Bucket=self._bucket, - Key=self._name, - UploadId=self._upload_id) + await blocking_to_async( + self._fs._thread_pool, + self._fs._s3.abort_multipart_upload, + Bucket=self._bucket, + Key=self._name, + UploadId=self._upload_id, + ) return parts = [] part_number = 1 for etag in self._etags: assert etag is not None - parts.append({ - 'ETag': etag, - 'PartNumber': part_number - }) + parts.append({'ETag': etag, 'PartNumber': part_number}) part_number += 1 - await blocking_to_async(self._fs._thread_pool, self._fs._s3.complete_multipart_upload, - Bucket=self._bucket, - Key=self._name, - MultipartUpload={'Parts': parts}, - UploadId=self._upload_id) + await blocking_to_async( + self._fs._thread_pool, + self._fs._s3.complete_multipart_upload, + Bucket=self._bucket, + Key=self._name, + MultipartUpload={'Parts': parts}, + UploadId=self._upload_id, + ) async def create_part(self, number: int, start: int, size_hint: Optional[int] = None) -> S3CreatePartManager: # pylint: disable=unused-argument if size_hint is None: @@ -273,6 +326,9 @@ def __init__(self, bucket: str, path: str): self._bucket = bucket self._path = path + def __repr__(self): + return f'S3AsyncFSURL({self._bucket}, {self._path})' + @property def bucket_parts(self) -> List[str]: return [self._bucket] @@ -292,28 +348,69 @@ def scheme(self) -> str: def with_path(self, path) -> 'S3AsyncFSURL': return S3AsyncFSURL(self._bucket, path) + def with_root_path(self) -> 'S3AsyncFSURL': + return self.with_path('') + def __str__(self) -> str: return f's3://{self._bucket}/{self._path}' class S3AsyncFS(AsyncFS): - schemes: ClassVar[Set[str]] = {'s3'} - - def __init__(self, thread_pool: Optional[ThreadPoolExecutor] = None, max_workers: Optional[int] = None, *, max_pool_connections: int = 10): + def __init__( + self, + thread_pool: Optional[ThreadPoolExecutor] = None, + max_workers: Optional[int] = None, + *, + max_pool_connections: int = 10, + timeout: Optional[Union[int, float, aiohttp.ClientTimeout]] = None, + ): if not thread_pool: thread_pool = ThreadPoolExecutor(max_workers=max_workers) self._thread_pool = thread_pool + + kwargs = {} + if isinstance(timeout, aiohttp.ClientTimeout): + if timeout.sock_read: + kwargs['read_timeout'] = timeout.sock_read + elif timeout.total: + kwargs['read_timeout'] = timeout.total + + if timeout.sock_connect: + kwargs['connect_timeout'] = timeout.sock_connect + elif timeout.connect: + kwargs['connect_timeout'] = timeout.connect + elif timeout.total: + kwargs['connect_timeout'] = timeout.total + elif isinstance(timeout, (int, float)): + kwargs['read_timeout'] = timeout + kwargs['connect_timeout'] = timeout + config = botocore.config.Config( max_pool_connections=max_pool_connections, + 
**kwargs, ) self._s3 = boto3.client('s3', config=config) + @staticmethod + def schemes() -> Set[str]: + return {'s3'} + + @staticmethod + def copy_part_size(url: str) -> int: # pylint: disable=unused-argument + # Because the S3 upload_part API call requires the entire part + # be loaded into memory, use a smaller part size. + return 32 * 1024 * 1024 + @staticmethod def valid_url(url: str) -> bool: return url.startswith('s3://') - def parse_url(self, url: str) -> S3AsyncFSURL: - return S3AsyncFSURL(*self.get_bucket_and_name(url)) + @staticmethod + def parse_url(url: str, *, error_if_bucket: bool = False) -> S3AsyncFSURL: + fsurl = S3AsyncFSURL(*S3AsyncFS.get_bucket_and_name(url)) + if error_if_bucket and fsurl._path == '': + raise IsABucketError + return fsurl @staticmethod def get_bucket_and_name(url: str) -> Tuple[str, str]: @@ -325,37 +422,36 @@ def get_bucket_and_name(url: str) -> Tuple[str, str]: if scheme != 's3': raise ValueError(f'invalid scheme, expected s3: {scheme}') - rest = url[(colon_index + 1):] + rest = url[(colon_index + 1) :] if not rest.startswith('//'): raise ValueError(f's3 URI must be of the form: s3://bucket/key, found: {url}') end_of_bucket = rest.find('/', 2) bucket = rest[2:end_of_bucket] - name = rest[(end_of_bucket + 1):] + name = rest[(end_of_bucket + 1) :] return (bucket, name) async def open(self, url: str) -> ReadableStream: - bucket, name = self.get_bucket_and_name(url) + fsurl = self.parse_url(url, error_if_bucket=True) try: - resp = await blocking_to_async(self._thread_pool, self._s3.get_object, - Bucket=bucket, - Key=name) + resp = await blocking_to_async( + self._thread_pool, self._s3.get_object, Bucket=fsurl._bucket, Key=fsurl._path + ) return blocking_readable_stream_to_async(self._thread_pool, cast(BinaryIO, resp['Body'])) except self._s3.exceptions.NoSuchKey as e: raise FileNotFoundError(url) from e async def _open_from(self, url: str, start: int, *, length: Optional[int] = None) -> ReadableStream: - bucket, name = self.get_bucket_and_name(url) + fsurl = self.parse_url(url, error_if_bucket=True) range_str = f'bytes={start}-' if length is not None: assert length >= 1 range_str += str(start + length - 1) try: - resp = await blocking_to_async(self._thread_pool, self._s3.get_object, - Bucket=bucket, - Key=name, - Range=range_str) + resp = await blocking_to_async( + self._thread_pool, self._s3.get_object, Bucket=fsurl._bucket, Key=fsurl._path, Range=range_str + ) return blocking_readable_stream_to_async(self._thread_pool, cast(BinaryIO, resp['Body'])) except self._s3.exceptions.NoSuchKey as e: raise FileNotFoundError(url) from e @@ -405,16 +501,12 @@ async def create(self, url: str, *, retry_writes: bool = True) -> S3CreateManage # interface. This has the disadvantage that the read must # complete before the write can begin (unlike the current # code, that copies 128MB parts in 256KB chunks). 
- bucket, name = self.get_bucket_and_name(url) - return S3CreateManager(self, bucket, name) + fsurl = self.parse_url(url, error_if_bucket=True) + return S3CreateManager(self, fsurl._bucket, fsurl._path) - async def multi_part_create( - self, - sema: asyncio.Semaphore, - url: str, - num_parts: int) -> MultiPartCreate: - bucket, name = self.get_bucket_and_name(url) - return S3MultiPartCreate(sema, self, bucket, name, num_parts) + async def multi_part_create(self, sema: asyncio.Semaphore, url: str, num_parts: int) -> MultiPartCreate: + fsurl = self.parse_url(url, error_if_bucket=True) + return S3MultiPartCreate(sema, self, fsurl._bucket, fsurl._path, num_parts) async def mkdir(self, url: str) -> None: pass @@ -423,12 +515,12 @@ async def makedirs(self, url: str, exist_ok: bool = False) -> None: pass async def statfile(self, url: str) -> FileStatus: - bucket, name = self.get_bucket_and_name(url) + fsurl = self.parse_url(url, error_if_bucket=True) try: - resp = await blocking_to_async(self._thread_pool, self._s3.head_object, - Bucket=bucket, - Key=name) - return S3HeadObjectFileStatus(resp) + resp = await blocking_to_async( + self._thread_pool, self._s3.head_object, Bucket=fsurl._bucket, Key=fsurl._path + ) + return S3HeadObjectFileStatus(resp, url) except botocore.exceptions.ClientError as e: if e.response['ResponseMetadata']['HTTPStatusCode'] == 404: raise FileNotFoundError(url) from e @@ -455,11 +547,9 @@ async def _listfiles_flat(self, bucket: str, name: str) -> AsyncIterator[S3FileL for item in contents: yield S3FileListEntry(bucket, item['Key'], item) - async def listfiles(self, - url: str, - recursive: bool = False, - exclude_trailing_slash_files: bool = True - ) -> AsyncIterator[FileListEntry]: + async def listfiles( + self, url: str, recursive: bool = False, exclude_trailing_slash_files: bool = True + ) -> AsyncIterator[FileListEntry]: bucket, name = self.get_bucket_and_name(url) if name and not name.endswith('/'): name += '/' @@ -503,11 +593,11 @@ async def staturl(self, url: str) -> str: return await self._staturl_parallel_isfile_isdir(url) async def isfile(self, url: str) -> bool: + bucket, name = self.get_bucket_and_name(url) + if name == '': + return False try: - bucket, name = self.get_bucket_and_name(url) - await blocking_to_async(self._thread_pool, self._s3.head_object, - Bucket=bucket, - Key=name) + await blocking_to_async(self._thread_pool, self._s3.head_object, Bucket=bucket, Key=name) return True except botocore.exceptions.ClientError as e: if e.response['ResponseMetadata']['HTTPStatusCode'] == 404: @@ -515,6 +605,7 @@ async def isfile(self, url: str) -> bool: raise e async def isdir(self, url: str) -> bool: + self.parse_url(url, error_if_bucket=True) try: async for _ in await self.listfiles(url, recursive=True): return True @@ -523,18 +614,11 @@ async def isdir(self, url: str) -> bool: return False async def remove(self, url: str) -> None: + fsurl = self.parse_url(url, error_if_bucket=True) try: - bucket, name = self.get_bucket_and_name(url) - await blocking_to_async(self._thread_pool, self._s3.delete_object, - Bucket=bucket, - Key=name) + await blocking_to_async(self._thread_pool, self._s3.delete_object, Bucket=fsurl._bucket, Key=fsurl._path) except self._s3.exceptions.NoSuchKey as e: raise FileNotFoundError(url) from e async def close(self) -> None: del self._s3 - - def copy_part_size(self, url: str) -> int: # pylint: disable=unused-argument - # Because the S3 upload_part API call requires the entire part - # be loaded into memory, use a smaller part size. 
- return 32 * 1024 * 1024 diff --git a/hail/python/hailtop/aiocloud/aioazure/__init__.py b/hail/python/hailtop/aiocloud/aioazure/__init__.py index 389b2bffd7b..689ab210772 100644 --- a/hail/python/hailtop/aiocloud/aioazure/__init__.py +++ b/hail/python/hailtop/aiocloud/aioazure/__init__.py @@ -1,14 +1,19 @@ -from .client import (AzureComputeClient, AzureGraphClient, AzureNetworkClient, AzureResourcesClient, - AzureResourceManagerClient, AzurePricingClient) +from .client import ( + AzureComputeClient, + AzureGraphClient, + AzureNetworkClient, + AzurePricingClient, + AzureResourceManagerClient, + AzureResourcesClient, +) from .credentials import AzureCredentials -from .fs import AzureAsyncFS, AzureAsyncFSFactory -from .session import AzureSession +from .fs import AzureAsyncFS, AzureAsyncFSFactory, AzureAsyncFSURL __all__ = [ 'AzureAsyncFS', 'AzureAsyncFSFactory', + 'AzureAsyncFSURL', 'AzureCredentials', - 'AzureSession', 'AzureComputeClient', 'AzureGraphClient', 'AzureNetworkClient', diff --git a/hail/python/hailtop/aiocloud/aioazure/client/arm_client.py b/hail/python/hailtop/aiocloud/aioazure/client/arm_client.py index f078ee4ef58..31e8a67d2a6 100644 --- a/hail/python/hailtop/aiocloud/aioazure/client/arm_client.py +++ b/hail/python/hailtop/aiocloud/aioazure/client/arm_client.py @@ -1,4 +1,4 @@ -from typing import AsyncGenerator, Any, Optional +from typing import Any, AsyncGenerator, Optional from .base_client import AzureBaseClient @@ -10,8 +10,10 @@ def __init__(self, subscription_id: str, resource_group_name: str, **kwargs): params = kwargs['params'] if 'api-version' not in params: params['api-version'] = '2021-04-01' - super().__init__(f'https://management.azure.com/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.Resources', - **kwargs) + super().__init__( + f'https://management.azure.com/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.Resources', + **kwargs, + ) async def list_deployments(self, filter: Optional[str] = None) -> AsyncGenerator[Any, None]: # https://docs.microsoft.com/en-us/rest/api/resources/deployments/list-by-resource-group diff --git a/hail/python/hailtop/aiocloud/aioazure/client/base_client.py b/hail/python/hailtop/aiocloud/aioazure/client/base_client.py index 74257649db6..c0b060dd975 100644 --- a/hail/python/hailtop/aiocloud/aioazure/client/base_client.py +++ b/hail/python/hailtop/aiocloud/aioazure/client/base_client.py @@ -1,19 +1,37 @@ -from typing import Optional, AsyncGenerator, Any +from typing import Any, AsyncGenerator, List, Mapping, Optional, Union import aiohttp + from hailtop.utils import RateLimit, sleep_before_try, url_and_params from ...common import CloudBaseClient -from ..session import AzureSession +from ...common.credentials import AnonymousCloudCredentials +from ...common.session import BaseSession, Session +from ..credentials import AzureCredentials class AzureBaseClient(CloudBaseClient): - _session: AzureSession - - def __init__(self, base_url: str, *, session: Optional[AzureSession] = None, - rate_limit: Optional[RateLimit] = None, **kwargs): + def __init__( + self, + base_url: str, + *, + session: Optional[BaseSession] = None, + rate_limit: Optional[RateLimit] = None, + credentials: Optional[Union['AzureCredentials', AnonymousCloudCredentials]] = None, + credentials_file: Optional[str] = None, + scopes: Optional[List[str]] = None, + params: Optional[Mapping[str, str]] = None, + **kwargs, + ): if session is None: - session = AzureSession(**kwargs) + session = 
Session( + credentials=credentials or AzureCredentials.from_file_or_default(credentials_file, scopes), + params=params, + **kwargs, + ) + elif credentials_file is not None or credentials is not None: + raise ValueError('Do not provide credentials_file or credentials when session is not None') + super().__init__(base_url, session, rate_limit=rate_limit) async def _paged_get(self, path, **kwargs) -> AsyncGenerator[Any, None]: @@ -28,14 +46,9 @@ async def _paged_get(self, path, **kwargs) -> AsyncGenerator[Any, None]: yield v next_link = page.get('nextLink') - async def delete(self, path: Optional[str] = None, *, url: Optional[str] = None, **kwargs) -> aiohttp.ClientResponse: - if url is None: - assert path - url = f'{self._base_url}{path}' - async with await self._session.delete(url, **kwargs) as resp: - return resp - - async def delete_and_wait(self, path: Optional[str] = None, *, url: Optional[str] = None, **kwargs) -> aiohttp.ClientResponse: + async def delete_and_wait( + self, path: Optional[str] = None, *, url: Optional[str] = None, **kwargs + ) -> aiohttp.ClientResponse: tries = 1 while True: resp = await self.delete(path, url=url, **kwargs) diff --git a/hail/python/hailtop/aiocloud/aioazure/client/compute_client.py b/hail/python/hailtop/aiocloud/aioazure/client/compute_client.py index 10a6a93f386..bee3789a41c 100644 --- a/hail/python/hailtop/aiocloud/aioazure/client/compute_client.py +++ b/hail/python/hailtop/aiocloud/aioazure/client/compute_client.py @@ -8,5 +8,7 @@ def __init__(self, subscription_id, resource_group_name, **kwargs): params = kwargs['params'] if 'api-version' not in params: params['api-version'] = '2021-07-01' - super().__init__(f'https://management.azure.com/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.Compute', - **kwargs) + super().__init__( + f'https://management.azure.com/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.Compute', + **kwargs, + ) diff --git a/hail/python/hailtop/aiocloud/aioazure/client/graph_client.py b/hail/python/hailtop/aiocloud/aioazure/client/graph_client.py index 982b3cbcc92..dcf97376a1a 100644 --- a/hail/python/hailtop/aiocloud/aioazure/client/graph_client.py +++ b/hail/python/hailtop/aiocloud/aioazure/client/graph_client.py @@ -1,16 +1,14 @@ -from typing import Optional, ClassVar, List +from typing import ClassVar, List -from ..session import AzureSession from .base_client import AzureBaseClient class AzureGraphClient(AzureBaseClient): required_scopes: ClassVar[List[str]] = ['https://graph.microsoft.com/.default'] - def __init__(self, session: Optional[AzureSession] = None, **kwargs): + def __init__(self, **kwargs): if 'scopes' in kwargs: kwargs['scopes'] += AzureGraphClient.required_scopes else: kwargs['scopes'] = AzureGraphClient.required_scopes - session = session or AzureSession(**kwargs) - super().__init__('https://graph.microsoft.com/v1.0', session=session) + super().__init__('https://graph.microsoft.com/v1.0', **kwargs) diff --git a/hail/python/hailtop/aiocloud/aioazure/client/network_client.py b/hail/python/hailtop/aiocloud/aioazure/client/network_client.py index 495771bb1ea..c56dd232153 100644 --- a/hail/python/hailtop/aiocloud/aioazure/client/network_client.py +++ b/hail/python/hailtop/aiocloud/aioazure/client/network_client.py @@ -1,21 +1,19 @@ -from typing import Optional - import aiohttp -from ..session import AzureSession from .base_client import AzureBaseClient class AzureNetworkClient(AzureBaseClient): - def __init__(self, subscription_id, 
resource_group_name, session: Optional[AzureSession] = None, **kwargs): + def __init__(self, subscription_id: str, resource_group_name: str, **kwargs): if 'params' not in kwargs: kwargs['params'] = {} params = kwargs['params'] if 'api-version' not in params: params['api-version'] = '2021-03-01' - session = session or AzureSession(**kwargs) - super().__init__(f'https://management.azure.com/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.Network', - session=session) + super().__init__( + f'https://management.azure.com/subscriptions/{subscription_id}/resourceGroups/{resource_group_name}/providers/Microsoft.Network', + **kwargs, + ) async def delete_nic(self, nic_name: str, ignore_not_found: bool = False): try: diff --git a/hail/python/hailtop/aiocloud/aioazure/client/pricing_client.py b/hail/python/hailtop/aiocloud/aioazure/client/pricing_client.py index 7e9a5e18aaa..1d784425328 100644 --- a/hail/python/hailtop/aiocloud/aioazure/client/pricing_client.py +++ b/hail/python/hailtop/aiocloud/aioazure/client/pricing_client.py @@ -1,16 +1,14 @@ -from typing import Optional, AsyncGenerator, Any - -from ...common import AnonymousCloudCredentials -from ..session import AzureSession - +from typing import Any, AsyncGenerator, Optional +from ...common import AnonymousCloudCredentials, Session from .base_client import AzureBaseClient class AzurePricingClient(AzureBaseClient): - def __init__(self): - session = AzureSession(credentials=AnonymousCloudCredentials()) - super().__init__('https://prices.azure.com/api/retail', session=session) + def __init__(self, **kwargs): + super().__init__( + 'https://prices.azure.com/api/retail', session=Session(credentials=AnonymousCloudCredentials()), **kwargs + ) async def _paged_get(self, path, **kwargs) -> AsyncGenerator[Any, None]: page = await self.get(path, **kwargs) diff --git a/hail/python/hailtop/aiocloud/aioazure/client/resources_client.py b/hail/python/hailtop/aiocloud/aioazure/client/resources_client.py index 040a48c6019..f9fa03c94fc 100644 --- a/hail/python/hailtop/aiocloud/aioazure/client/resources_client.py +++ b/hail/python/hailtop/aiocloud/aioazure/client/resources_client.py @@ -1,19 +1,15 @@ -from typing import Optional, Any, AsyncGenerator +from typing import Any, AsyncGenerator, Optional -from ..session import AzureSession from .base_client import AzureBaseClient class AzureResourcesClient(AzureBaseClient): - def __init__(self, subscription_id, session: Optional[AzureSession] = None, **kwargs): - session = session or AzureSession(**kwargs) - super().__init__(f'https://management.azure.com/subscriptions/{subscription_id}', session=session) + def __init__(self, subscription_id: str, **kwargs): + super().__init__(f'https://management.azure.com/subscriptions/{subscription_id}', **kwargs) async def _list_resources(self, filter: Optional[str] = None) -> AsyncGenerator[Any, None]: # https://docs.microsoft.com/en-us/rest/api/resources/resources/list - params = { - 'api-version': '2021-04-01' - } + params = {'api-version': '2021-04-01'} if filter is not None: params['$filter'] = filter return self._paged_get('/resources', params=params) diff --git a/hail/python/hailtop/aiocloud/aioazure/credentials.py b/hail/python/hailtop/aiocloud/aioazure/credentials.py index bb86bc6781d..9d15b2c7f67 100644 --- a/hail/python/hailtop/aiocloud/aioazure/credentials.py +++ b/hail/python/hailtop/aiocloud/aioazure/credentials.py @@ -1,18 +1,17 @@ import concurrent.futures -import os import json -import time import logging - +import os +import 
time from types import TracebackType -from typing import Any, List, Optional, Type, Union, Tuple, Dict -from azure.identity.aio import DefaultAzureCredential, ClientSecretCredential -from azure.core.credentials import AccessToken -from azure.core.credentials_async import AsyncTokenCredential +from typing import Any, Dict, List, Optional, Tuple, Type, Union import msal +from azure.core.credentials import AccessToken +from azure.core.credentials_async import AsyncTokenCredential +from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential -from hailtop.utils import first_extant_file, blocking_to_async +from hailtop.utils import blocking_to_async, first_extant_file from ..common.credentials import CloudCredentials @@ -35,7 +34,9 @@ async def get_token( # See docs: # https://msal-python.readthedocs.io/en/latest/#msal.ClientApplication.acquire_token_by_refresh_token if self._refresh_token: - res_co = blocking_to_async(self._pool, self._app.acquire_token_by_refresh_token, self._refresh_token, scopes) + res_co = blocking_to_async( + self._pool, self._app.acquire_token_by_refresh_token, self._refresh_token, scopes + ) self._refresh_token = None res = await res_co else: @@ -59,6 +60,15 @@ async def close(self) -> None: class AzureCredentials(CloudCredentials): + @staticmethod + def from_file_or_default( + credentials_file: Optional[str] = None, + scopes: Optional[List[str]] = None, + ) -> 'AzureCredentials': + if credentials_file: + return AzureCredentials.from_file(credentials_file, scopes=scopes) + return AzureCredentials.default_credentials(scopes=scopes) + @staticmethod def from_credentials_data(credentials: dict, scopes: Optional[List[str]] = None): if 'refreshToken' in credentials: @@ -74,11 +84,9 @@ def from_credentials_data(credentials: dict, scopes: Optional[List[str]] = None) assert 'password' in credentials return AzureCredentials( ClientSecretCredential( - tenant_id=credentials['tenant'], - client_id=credentials['appId'], - client_secret=credentials['password'] + tenant_id=credentials['tenant'], client_id=credentials['appId'], client_secret=credentials['password'] ), - scopes + scopes, ) @staticmethod @@ -92,7 +100,7 @@ def default_credentials(scopes: Optional[List[str]] = None): credentials_file = first_extant_file( os.environ.get('AZURE_APPLICATION_CREDENTIALS'), '/azure-credentials/credentials.json', - '/gsa-key/key.json' # FIXME: make this file path cloud-agnostic + '/gsa-key/key.json', # FIXME: make this file path cloud-agnostic ) if credentials_file: @@ -101,7 +109,11 @@ def default_credentials(scopes: Optional[List[str]] = None): return AzureCredentials(DefaultAzureCredential(), scopes) - def __init__(self, credential: Union[DefaultAzureCredential, ClientSecretCredential, RefreshTokenCredential], scopes: Optional[List[str]] = None): + def __init__( + self, + credential: Union[DefaultAzureCredential, ClientSecretCredential, RefreshTokenCredential], + scopes: Optional[List[str]] = None, + ): self.credential = credential self._access_token = None self._expires_at = None @@ -118,7 +130,7 @@ async def access_token_with_expiration(self) -> Tuple[str, Optional[float]]: now = time.time() if self._access_token is None or (self._expires_at is not None and now > self._expires_at): self._access_token = await self.get_access_token() - self._expires_at = now + (self._access_token.expires_on - now) // 2 # type: ignore + self._expires_at = now + (self._access_token.expires_on - now) // 2 # type: ignore assert self._access_token return self._access_token.token, self._expires_at 
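Note: the `access_token_with_expiration` change above keeps a cached token and treats it as stale once half of its remaining lifetime has passed (`now + (expires_on - now) // 2`), so callers refresh well before the provider-side expiry. The standalone sketch below illustrates only that caching pattern; `FakeCredential`, its one-hour token, and `CachingCredential` are hypothetical stand-ins for illustration, not Hail or Azure SDK APIs.

import asyncio
import time
from typing import Optional, Tuple


class FakeCredential:
    """Hypothetical stand-in for a cloud token credential (not a real SDK class)."""

    async def get_token(self) -> Tuple[str, float]:
        # Pretend the provider issues a token that is valid for one hour.
        now = time.time()
        return f'token-{int(now)}', now + 3600.0


class CachingCredential:
    """Caches a token and refreshes it once half of its remaining lifetime is gone."""

    def __init__(self, credential: FakeCredential):
        self._credential = credential
        self._token: Optional[str] = None
        self._expires_at: Optional[float] = None

    async def access_token(self) -> str:
        now = time.time()
        if self._token is None or (self._expires_at is not None and now > self._expires_at):
            self._token, expires_on = await self._credential.get_token()
            # Same half-lifetime rule as the diff: treat the token as stale after
            # half of its remaining validity has elapsed.
            self._expires_at = now + (expires_on - now) // 2
        return self._token


async def main():
    cred = CachingCredential(FakeCredential())
    print(await cred.access_token())  # first call fetches a token
    print(await cred.access_token())  # second call is served from the cache


if __name__ == '__main__':
    asyncio.run(main())

Refreshing at the half-life keeps a safety margin against clock skew and slow requests without tracking the provider's exact expiry semantics.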
diff --git a/hail/python/hailtop/aiocloud/aioazure/fs.py b/hail/python/hailtop/aiocloud/aioazure/fs.py index e2d60adbe09..9db30e9cc1d 100644 --- a/hail/python/hailtop/aiocloud/aioazure/fs.py +++ b/hail/python/hailtop/aiocloud/aioazure/fs.py @@ -1,29 +1,38 @@ -from typing import Any, AsyncContextManager, AsyncIterator, Dict, List, Optional, Set, Tuple, Type, Union, ClassVar -from types import TracebackType - -import abc -import re import asyncio -from functools import wraps -import secrets import logging +import os +import re +import secrets from datetime import datetime, timedelta +from functools import wraps +from types import TracebackType +from typing import Any, AsyncContextManager, AsyncIterator, Dict, List, Optional, Set, Tuple, Type, Union +import aiohttp +import azure.core.exceptions from azure.mgmt.storage.aio import StorageManagementClient from azure.storage.blob import BlobProperties, ResourceTypes, generate_account_sas -from azure.storage.blob.aio import BlobClient, ContainerClient, BlobServiceClient, StorageStreamDownloader +from azure.storage.blob.aio import BlobClient, BlobServiceClient, ContainerClient, StorageStreamDownloader from azure.storage.blob.aio._list_blobs_helper import BlobPrefix -import azure.core.exceptions -from hailtop.utils import retry_transient_errors, flatten from hailtop.aiotools import WriteBuffer -from hailtop.aiotools.fs import (AsyncFS, AsyncFSURL, AsyncFSFactory, ReadableStream, - WritableStream, MultiPartCreate, FileListEntry, FileStatus, - FileAndDirectoryError, UnexpectedEOFError) +from hailtop.aiotools.fs import ( + AsyncFS, + AsyncFSFactory, + AsyncFSURL, + FileAndDirectoryError, + FileListEntry, + FileStatus, + IsABucketError, + MultiPartCreate, + ReadableStream, + UnexpectedEOFError, + WritableStream, +) +from hailtop.utils import flatten, retry_transient_errors from .credentials import AzureCredentials - logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy") logger.setLevel(logging.WARNING) @@ -91,21 +100,21 @@ def __init__(self, sema: asyncio.Semaphore, client: BlobClient, num_parts: int): self._client = client self._block_ids: List[List[str]] = [[] for _ in range(num_parts)] - async def create_part(self, number: int, start: int, size_hint: Optional[int] = None) -> AsyncContextManager[WritableStream]: # pylint: disable=unused-argument + async def create_part( + self, number: int, start: int, size_hint: Optional[int] = None + ) -> AsyncContextManager[WritableStream]: # pylint: disable=unused-argument return AzureCreatePartManager(self._client, self._block_ids[number]) async def __aenter__(self) -> 'AzureMultiPartCreate': return self - async def __aexit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: try: # azure allows both BlockBlob and the string id here, despite # only having BlockBlob annotations - await self._client.commit_block_list(flatten(self._block_ids) # type: ignore - ) + await self._client.commit_block_list(flatten(self._block_ids)) # type: ignore except: try: await self._client.delete_blob() @@ -125,17 +134,18 @@ async def __aenter__(self) -> WritableStream: return self._writable_stream async def __aexit__( - self, exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - exc_traceback: Optional[TracebackType] = None) -> None: + self, + 
exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + exc_traceback: Optional[TracebackType] = None, + ) -> None: if self._writable_stream: await self._writable_stream.wait_closed() try: # azure allows both BlockBlob and the string id here, despite # only having BlockBlob annotations - await self._client.commit_block_list(self._block_ids # type: ignore - ) + await self._client.commit_block_list(self._block_ids) # type: ignore except: try: await self._client.delete_blob() @@ -145,9 +155,11 @@ async def __aexit__( class AzureReadableStream(ReadableStream): - def __init__(self, client: BlobClient, url: str, offset: Optional[int] = None, length: Optional[int] = None): + def __init__( + self, fs: 'AzureAsyncFS', url: 'AzureAsyncFSURL', offset: Optional[int] = None, length: Optional[int] = None + ): super().__init__() - self._client = client + self._fs = fs self._buffer = bytearray() self._url = url @@ -160,24 +172,29 @@ def __init__(self, client: BlobClient, url: str, offset: Optional[int] = None, l self._downloader: Optional[StorageStreamDownloader] = None self._chunk_it: Optional[AsyncIterator[bytes]] = None + async def _get_client(self) -> BlobClient: + return await self._fs.get_blob_client(self._url) + async def read(self, n: int = -1) -> bytes: if self._eof: return b'' if n == -1: try: - downloader = await self._client.download_blob(offset=self._offset, length=self._length) # type: ignore + client = await self._get_client() + downloader = await client.download_blob(offset=self._offset, length=self._length) # type: ignore except azure.core.exceptions.ResourceNotFoundError as e: - raise FileNotFoundError(self._url) from e + raise FileNotFoundError(self._url.base) from e data = await downloader.readall() self._eof = True return data if self._downloader is None: try: - self._downloader = await self._client.download_blob(offset=self._offset) # type: ignore + client = await self._get_client() + self._downloader = await client.download_blob(offset=self._offset) # type: ignore except azure.core.exceptions.ResourceNotFoundError as e: - raise FileNotFoundError(self._url) from e + raise FileNotFoundError(self._url.base) from e except azure.core.exceptions.HttpResponseError as e: if e.status_code == 416: raise UnexpectedEOFError from e @@ -226,8 +243,8 @@ def __init__(self, url: 'AzureAsyncFSURL', blob_props: Optional[BlobProperties]) self._blob_props = blob_props self._status: Optional[AzureFileStatus] = None - def name(self) -> str: - return self._url.path + def basename(self) -> str: + return os.path.basename(self._url.base.rstrip('/')) async def url(self) -> str: return self._url.base @@ -245,13 +262,20 @@ async def status(self) -> FileStatus: if self._status is None: if self._blob_props is None: raise IsADirectoryError(await self.url()) - self._status = AzureFileStatus(self._blob_props) + self._status = AzureFileStatus(self._blob_props, self._url) return self._status class AzureFileStatus(FileStatus): - def __init__(self, blob_props: BlobProperties): + def __init__(self, blob_props: BlobProperties, url: 'AzureAsyncFSURL'): self.blob_props = blob_props + self._url = url + + def basename(self) -> str: + return os.path.basename(self._url.base.rstrip('/')) + + def url(self) -> str: + return str(self._url) async def size(self) -> int: size = self.blob_props.size @@ -279,6 +303,13 @@ def __init__(self, account: str, container: str, path: str, query: Optional[str] self._path = path self._query = query + @property + def scheme(self) -> str: + return 'https' + + 
def __repr__(self): + return f'AzureAsyncFSURL({self._account}, {self._container}, {self._path}, {self._query})' + @property def bucket_parts(self) -> List[str]: return [self._account, self._container] @@ -300,37 +331,19 @@ def query(self) -> Optional[str]: return self._query @property - @abc.abstractmethod def base(self) -> str: - pass + return f'https://{self._account}.blob.core.windows.net/{self._container}/{self._path}' def with_path(self, path) -> 'AzureAsyncFSURL': return self.__class__(self._account, self._container, path, self._query) + def with_root_path(self) -> 'AzureAsyncFSURL': + return self.with_path('') + def __str__(self) -> str: return self.base if not self._query else f'{self.base}?{self._query}' -class AzureAsyncFSHailAzURL(AzureAsyncFSURL): - @property - def scheme(self) -> str: - return 'hail-az' - - @property - def base(self) -> str: - return f'hail-az://{self._account}/{self._container}/{self._path}' - - -class AzureAsyncFSHttpsURL(AzureAsyncFSURL): - @property - def scheme(self) -> str: - return 'https' - - @property - def base(self) -> str: - return f'https://{self._account}.blob.core.windows.net/{self._container}/{self._path}' - - # ABS errors if you attempt credentialed access for a public container, # so we try once with credentials, if that fails use anonymous access for # that container going forward. @@ -342,20 +355,25 @@ async def wrapped(self: 'AzureAsyncFS', url, *args, **kwargs): except azure.core.exceptions.ClientAuthenticationError: fs_url = self.parse_url(url) # https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob#other-client--per-operation-configuration - anon_client = BlobServiceClient(f'https://{fs_url.account}.blob.core.windows.net', - credential=None, - connection_timeout=5, - read_timeout=5) + anon_client = BlobServiceClient( + f'https://{fs_url.account}.blob.core.windows.net', credential=None, connection_timeout=5, read_timeout=5 + ) self._blob_service_clients[(fs_url.account, fs_url.container, fs_url.query)] = anon_client return await fun(self, url, *args, **kwargs) + return wrapped class AzureAsyncFS(AsyncFS): - schemes: ClassVar[Set[str]] = {'hail-az', 'https'} PATH_REGEX = re.compile('/(?P[^/]+)(?P.*)') - def __init__(self, *, credential_file: Optional[str] = None, credentials: Optional[AzureCredentials] = None): + def __init__( + self, + *, + credential_file: Optional[str] = None, + credentials: Optional[AzureCredentials] = None, + timeout: Optional[Union[int, float, aiohttp.ClientTimeout]] = None, + ): if credentials is None: scopes = ['https://storage.azure.com/.default'] if credential_file is not None: @@ -365,9 +383,23 @@ def __init__(self, *, credential_file: Optional[str] = None, credentials: Option elif credential_file is not None: raise ValueError('credential and credential_file cannot both be defined') + if isinstance(timeout, aiohttp.ClientTimeout): + self.read_timeout = timeout.sock_read or timeout.total or 5 + self.connection_timeout = timeout.sock_connect or timeout.connect or timeout.total or 5 + elif isinstance(timeout, (int, float)): + self.read_timeout = timeout + self.connection_timeout = timeout + else: + self.read_timeout = 5 + self.connection_timeout = 5 + self._credential = credentials.credential self._blob_service_clients: Dict[Tuple[str, str, Union[AzureCredentials, str, None]], BlobServiceClient] = {} + @staticmethod + def schemes() -> Set[str]: + return {'https'} + @staticmethod def valid_url(url: str) -> bool: if url.startswith('https://'): @@ -377,7 +409,7 @@ def valid_url(url: 
str) -> bool: return False _, suffix = authority.split('.', maxsplit=1) return suffix == 'blob.core.windows.net' - return url.startswith('hail-az://') + return False async def generate_sas_token( self, @@ -385,7 +417,7 @@ async def generate_sas_token( resource_group: str, account: str, permissions: str = "rw", - valid_interval: timedelta = timedelta(hours=1) + valid_interval: timedelta = timedelta(hours=1), ) -> str: assert self._credential mgmt_client = StorageManagementClient(self._credential, subscription_id) # type: ignore @@ -397,20 +429,28 @@ async def generate_sas_token( storage_key, resource_types=ResourceTypes(container=True, object=True), permission=permissions, - expiry=datetime.utcnow() + valid_interval) + expiry=datetime.utcnow() + valid_interval, + ) return token @staticmethod - def parse_url(url: str) -> AzureAsyncFSURL: + def parse_url(url: str, *, error_if_bucket: bool = False) -> AzureAsyncFSURL: + fsurl = AzureAsyncFS._parse_url(url) + if error_if_bucket and fsurl._path == '': + raise IsABucketError + return fsurl + + @staticmethod + def _parse_url(url: str) -> AzureAsyncFSURL: colon_index = url.find(':') if colon_index == -1: raise ValueError(f'invalid URL: {url}') scheme = url[:colon_index] - if scheme not in AzureAsyncFS.schemes: - raise ValueError(f'invalid scheme, expected hail-az or https: {scheme}') + if scheme not in AzureAsyncFS.schemes(): + raise ValueError(f'invalid scheme, expected https: {scheme}') - rest = url[(colon_index + 1):] + rest = url[(colon_index + 1) :] if not rest.startswith('//'): raise ValueError(f'invalid url: {url}') @@ -420,7 +460,9 @@ def parse_url(url: str) -> AzureAsyncFSURL: match = AzureAsyncFS.PATH_REGEX.fullmatch(container_and_name) if match is None: - raise ValueError(f'invalid path name, expected hail-az://account/container/blob_name: {container_and_name}') + raise ValueError( + f'invalid path name, expected https://account.blob.core.windows.net/container/blob_name: {container_and_name}' + ) container = match.groupdict()['container'] @@ -431,69 +473,64 @@ def parse_url(url: str) -> AzureAsyncFSURL: name, token = AzureAsyncFS.get_name_parts(name) - if scheme == 'hail-az': - account = authority - return AzureAsyncFSHailAzURL(account, container, name, token) - assert scheme == 'https' assert len(authority) > len('.blob.core.windows.net') - account = authority[:-len('.blob.core.windows.net')] - return AzureAsyncFSHttpsURL(account, container, name, token) + account = authority[: -len('.blob.core.windows.net')] + return AzureAsyncFSURL(account, container, name, token) @staticmethod - def get_name_parts(name: str) -> Tuple[str, str]: + def get_name_parts(name: str) -> Tuple[str, Optional[str]]: # Look for a terminating SAS token. query_index = name.rfind('?') if query_index != -1: - query_string = name[query_index + 1:] + query_string = name[query_index + 1 :] first_kv_pair = query_string.split('&')[0].split('=') # We will accept it as a token string if it begins with at least 1 key-value pair of the form 'k=v'. 
if len(first_kv_pair) == 2 and all(s != '' for s in first_kv_pair): - return (name[:query_index], query_string) - return (name, '') + return (name[:query_index], query_string) + return (name, None) - def get_blob_service_client(self, account: str, container: str, token: Optional[str]) -> BlobServiceClient: - credential = token if token else self._credential - k = account, container, token + async def get_blob_service_client(self, url: AzureAsyncFSURL) -> BlobServiceClient: + credential = url.query if url.query else self._credential + k = url.account, url.container, url.query if k not in self._blob_service_clients: # https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/storage/azure-storage-blob#other-client--per-operation-configuration - self._blob_service_clients[k] = BlobServiceClient(f'https://{account}.blob.core.windows.net', - credential=credential, # type: ignore - connection_timeout=5, - read_timeout=5) + self._blob_service_clients[k] = BlobServiceClient( + f'https://{url.account}.blob.core.windows.net', + credential=credential, # type: ignore + connection_timeout=self.connection_timeout, + read_timeout=self.read_timeout, + ) return self._blob_service_clients[k] - def get_blob_client(self, url: AzureAsyncFSURL) -> BlobClient: - blob_service_client = self.get_blob_service_client(url.account, url.container, url.query) + async def get_blob_client(self, url: AzureAsyncFSURL) -> BlobClient: + blob_service_client = await self.get_blob_service_client(url) return blob_service_client.get_blob_client(url.container, url.path) - def get_container_client(self, url: AzureAsyncFSURL) -> ContainerClient: - return self.get_blob_service_client(url.account, url.container, url.query).get_container_client(url.container) + async def get_container_client(self, url: AzureAsyncFSURL) -> ContainerClient: + return (await self.get_blob_service_client(url)).get_container_client(url.container) @handle_public_access_error async def open(self, url: str) -> ReadableStream: + parsed_url = self.parse_url(url, error_if_bucket=True) if not await self.exists(url): raise FileNotFoundError - client = self.get_blob_client(self.parse_url(url)) - return AzureReadableStream(client, url) + return AzureReadableStream(self, parsed_url) @handle_public_access_error async def _open_from(self, url: str, start: int, *, length: Optional[int] = None) -> ReadableStream: assert length is None or length >= 1 if not await self.exists(url): raise FileNotFoundError - client = self.get_blob_client(self.parse_url(url)) - return AzureReadableStream(client, url, offset=start, length=length) + parsed_url = self.parse_url(url, error_if_bucket=True) + return AzureReadableStream(self, parsed_url, offset=start, length=length) async def create(self, url: str, *, retry_writes: bool = True) -> AsyncContextManager[WritableStream]: # pylint: disable=unused-argument - return AzureCreateManager(self.get_blob_client(self.parse_url(url))) - - async def multi_part_create( - self, - sema: asyncio.Semaphore, - url: str, - num_parts: int) -> MultiPartCreate: - client = self.get_blob_client(self.parse_url(url)) + parsed_url = self.parse_url(url, error_if_bucket=True) + return AzureCreateManager(await self.get_blob_client(parsed_url)) + + async def multi_part_create(self, sema: asyncio.Semaphore, url: str, num_parts: int) -> MultiPartCreate: + client = await self.get_blob_client(self.parse_url(url)) return AzureMultiPartCreate(sema, client, num_parts) @handle_public_access_error @@ -504,16 +541,14 @@ async def isfile(self, url: str) -> bool: if not 
fs_url.path: return False - return await self.get_blob_client(fs_url).exists() + return await (await self.get_blob_client(fs_url)).exists() @handle_public_access_error async def isdir(self, url: str) -> bool: - fs_url = self.parse_url(url) + fs_url = self.parse_url(url, error_if_bucket=True) assert not fs_url.path or fs_url.path.endswith('/'), fs_url.path - client = self.get_container_client(fs_url) - async for _ in client.walk_blobs(name_starts_with=fs_url.path, - include=['metadata'], - delimiter='/'): + client = await self.get_container_client(fs_url) + async for _ in client.walk_blobs(name_starts_with=fs_url.path, include=['metadata'], delimiter='/'): return True return False @@ -525,25 +560,27 @@ async def makedirs(self, url: str, exist_ok: bool = False) -> None: @handle_public_access_error async def statfile(self, url: str) -> FileStatus: + parsed_url = self.parse_url(url, error_if_bucket=True) try: - blob_props = await self.get_blob_client(self.parse_url(url)).get_blob_properties() - return AzureFileStatus(blob_props) + blob_props = await (await self.get_blob_client(parsed_url)).get_blob_properties() + return AzureFileStatus(blob_props, parsed_url) except azure.core.exceptions.ResourceNotFoundError as e: raise FileNotFoundError(url) from e @staticmethod - async def _listfiles_recursive(client: ContainerClient, original_url: AzureAsyncFSURL, name: str) -> AsyncIterator[FileListEntry]: + async def _listfiles_recursive( + client: ContainerClient, original_url: AzureAsyncFSURL, name: str + ) -> AsyncIterator[FileListEntry]: assert not name or name.endswith('/') - async for blob_props in client.list_blobs(name_starts_with=name, - include=['metadata']): + async for blob_props in client.list_blobs(name_starts_with=name, include=['metadata']): yield AzureFileListEntry(original_url.with_path(blob_props.name), blob_props) # type: ignore @staticmethod - async def _listfiles_flat(client: ContainerClient, original_url: AzureAsyncFSURL, name: str) -> AsyncIterator[FileListEntry]: + async def _listfiles_flat( + client: ContainerClient, original_url: AzureAsyncFSURL, name: str + ) -> AsyncIterator[FileListEntry]: assert not name or name.endswith('/') - async for item in client.walk_blobs(name_starts_with=name, - include=['metadata'], - delimiter='/'): + async for item in client.walk_blobs(name_starts_with=name, include=['metadata'], delimiter='/'): if isinstance(item, BlobPrefix): yield AzureFileListEntry(original_url.with_path(item.prefix), None) # type: ignore else: @@ -551,17 +588,15 @@ async def _listfiles_flat(client: ContainerClient, original_url: AzureAsyncFSURL yield AzureFileListEntry(original_url.with_path(item.name), item) # type: ignore @handle_public_access_error - async def listfiles(self, - url: str, - recursive: bool = False, - exclude_trailing_slash_files: bool = True - ) -> AsyncIterator[FileListEntry]: + async def listfiles( + self, url: str, recursive: bool = False, exclude_trailing_slash_files: bool = True + ) -> AsyncIterator[FileListEntry]: fs_url = self.parse_url(url) name = fs_url.path if name and not name.endswith('/'): name = f'{name}/' - client = self.get_container_client(fs_url) + client = await self.get_container_client(fs_url) if recursive: it = AzureAsyncFS._listfiles_recursive(client, fs_url, name) else: @@ -604,7 +639,8 @@ async def staturl(self, url: str) -> str: async def remove(self, url: str) -> None: try: - await self.get_blob_client(self.parse_url(url)).delete_blob() + parsed_url = self.parse_url(url, error_if_bucket=True) + await (await 
self.get_blob_client(parsed_url)).delete_blob() except azure.core.exceptions.ResourceNotFoundError as e: raise FileNotFoundError(url) from e @@ -614,18 +650,15 @@ async def close(self) -> None: self._credential = None if self._blob_service_clients: - await asyncio.wait([asyncio.create_task(client.close()) - for client in self._blob_service_clients.values()]) + await asyncio.wait([asyncio.create_task(client.close()) for client in self._blob_service_clients.values()]) + class AzureAsyncFSFactory(AsyncFSFactory[AzureAsyncFS]): def from_credentials_data(self, credentials_data: dict) -> AzureAsyncFS: - return AzureAsyncFS( - credentials=AzureCredentials.from_credentials_data(credentials_data)) + return AzureAsyncFS(credentials=AzureCredentials.from_credentials_data(credentials_data)) def from_credentials_file(self, credentials_file: str) -> AzureAsyncFS: - return AzureAsyncFS( - credentials=AzureCredentials.from_file(credentials_file)) + return AzureAsyncFS(credentials=AzureCredentials.from_file(credentials_file)) def from_default_credentials(self) -> AzureAsyncFS: - return AzureAsyncFS( - credentials=AzureCredentials.default_credentials()) + return AzureAsyncFS(credentials=AzureCredentials.default_credentials()) diff --git a/hail/python/hailtop/aiocloud/aioazure/session.py b/hail/python/hailtop/aiocloud/aioazure/session.py deleted file mode 100644 index fa13eee57c9..00000000000 --- a/hail/python/hailtop/aiocloud/aioazure/session.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Mapping, Optional, List, Union - -import aiohttp - -from ..common import Session, AnonymousCloudCredentials -from .credentials import AzureCredentials - - -class AzureSession(Session): - def __init__(self, *, credentials: Optional[Union[AzureCredentials, AnonymousCloudCredentials]] = None, credentials_file: Optional[str] = None, - params: Optional[Mapping[str, str]] = None, scopes: Optional[List[str]] = None, **kwargs): - assert credentials is None or credentials_file is None, \ - f'specify only one of credentials or credentials_file: {(credentials, credentials_file)}' - if credentials is None: - if credentials_file: - credentials = AzureCredentials.from_file(credentials_file, scopes=scopes) - else: - credentials = AzureCredentials.default_credentials(scopes=scopes) - if 'timeout' not in kwargs: - kwargs['timeout'] = aiohttp.ClientTimeout(total=30) - super().__init__(credentials=credentials, params=params, **kwargs) diff --git a/hail/python/hailtop/aiocloud/aiogoogle/__init__.py b/hail/python/hailtop/aiocloud/aiogoogle/__init__.py index c9976ecd562..1e54acda88e 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/__init__.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/__init__.py @@ -1,32 +1,37 @@ from .client import ( + GCSRequesterPaysConfiguration, GoogleBigQueryClient, GoogleBillingClient, - GoogleContainerClient, GoogleComputeClient, + GoogleContainerClient, GoogleIAmClient, GoogleLoggingClient, - GoogleStorageClient, - GCSRequesterPaysConfiguration, + GoogleMetadataServerClient, GoogleStorageAsyncFS, - GoogleStorageAsyncFSFactory + GoogleStorageAsyncFSFactory, + GoogleStorageClient, +) +from .credentials import ( + GoogleApplicationDefaultCredentials, + GoogleCredentials, + GoogleInstanceMetadataCredentials, + GoogleServiceAccountCredentials, ) -from .credentials import GoogleCredentials, GoogleApplicationDefaultCredentials, GoogleServiceAccountCredentials -from .session import GoogleSession from .user_config import get_gcs_requester_pays_configuration - __all__ = [ 'GCSRequesterPaysConfiguration', 
'GoogleCredentials', 'GoogleApplicationDefaultCredentials', 'GoogleServiceAccountCredentials', - 'GoogleSession', + 'GoogleInstanceMetadataCredentials', 'GoogleBigQueryClient', 'GoogleBillingClient', 'GoogleContainerClient', 'GoogleComputeClient', 'GoogleIAmClient', 'GoogleLoggingClient', + 'GoogleMetadataServerClient', 'GoogleStorageClient', 'GoogleStorageAsyncFS', 'GoogleStorageAsyncFSFactory', diff --git a/hail/python/hailtop/aiocloud/aiogoogle/client/__init__.py b/hail/python/hailtop/aiocloud/aiogoogle/client/__init__.py index d9afd5bbb59..34ccd0379ea 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/client/__init__.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/client/__init__.py @@ -1,10 +1,16 @@ from .bigquery_client import GoogleBigQueryClient from .billing_client import GoogleBillingClient -from .container_client import GoogleContainerClient from .compute_client import GoogleComputeClient +from .container_client import GoogleContainerClient from .iam_client import GoogleIAmClient from .logging_client import GoogleLoggingClient -from .storage_client import GCSRequesterPaysConfiguration, GoogleStorageClient, GoogleStorageAsyncFS, GoogleStorageAsyncFSFactory +from .metadata_server_client import GoogleMetadataServerClient +from .storage_client import ( + GCSRequesterPaysConfiguration, + GoogleStorageAsyncFS, + GoogleStorageAsyncFSFactory, + GoogleStorageClient, +) __all__ = [ 'GoogleBigQueryClient', @@ -13,8 +19,9 @@ 'GoogleComputeClient', 'GoogleIAmClient', 'GoogleLoggingClient', + 'GoogleMetadataServerClient', 'GCSRequesterPaysConfiguration', 'GoogleStorageClient', 'GoogleStorageAsyncFS', - 'GoogleStorageAsyncFSFactory' + 'GoogleStorageAsyncFSFactory', ] diff --git a/hail/python/hailtop/aiocloud/aiogoogle/client/base_client.py b/hail/python/hailtop/aiocloud/aiogoogle/client/base_client.py index bdcf80e1291..3e94edb2c19 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/client/base_client.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/client/base_client.py @@ -1,16 +1,32 @@ -from typing import Optional +from typing import Mapping, Optional, Union from hailtop.utils import RateLimit from ...common import CloudBaseClient -from ..session import GoogleSession +from ...common.credentials import AnonymousCloudCredentials +from ...common.session import BaseSession, Session +from ..credentials import GoogleCredentials class GoogleBaseClient(CloudBaseClient): - _session: GoogleSession - - def __init__(self, base_url: str, *, session: Optional[GoogleSession] = None, - rate_limit: Optional[RateLimit] = None, **kwargs): + def __init__( + self, + base_url: str, + *, + session: Optional[BaseSession] = None, + rate_limit: Optional[RateLimit] = None, + credentials: Optional[Union[GoogleCredentials, AnonymousCloudCredentials]] = None, + credentials_file: Optional[str] = None, + params: Optional[Mapping[str, str]] = None, + **kwargs, + ): if session is None: - session = GoogleSession(**kwargs) - super().__init__(base_url, session, rate_limit=rate_limit) + session = Session( + credentials=credentials or GoogleCredentials.from_file_or_default(credentials_file), + params=params, + **kwargs, + ) + elif credentials_file is not None or credentials is not None: + raise ValueError('Do not provide credentials_file or credentials when session is not None') + + super().__init__(base_url=base_url, session=session, rate_limit=rate_limit) diff --git a/hail/python/hailtop/aiocloud/aiogoogle/client/bigquery_client.py b/hail/python/hailtop/aiocloud/aiogoogle/client/bigquery_client.py index 
4412eb45421..669dbdac76a 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/client/bigquery_client.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/client/bigquery_client.py @@ -1,4 +1,5 @@ from typing import Any, Dict, Mapping, Optional + from .base_client import GoogleBaseClient @@ -39,6 +40,7 @@ def parse_field(name: str, value: Any, schema: Dict[str, Any]) -> Any: return int(or_none(float, value)) # DATE, TIME, DATETIME raise NotImplementedError((name, value, typ, mode)) + return { field['name']: parse_field(field['name'], field['v'], field_schema) for field, field_schema in zip(data['f'], schema['fields']) @@ -68,13 +70,9 @@ def __aiter__(self) -> 'PagedQueriesIterator': async def __anext__(self): if self._page is None: - config = { - 'kind': 'bigquery#queryRequest', - 'useLegacySql': False, - 'query': self._query} + config = {'kind': 'bigquery#queryRequest', 'useLegacySql': False, 'query': self._query} - self._page = await self._client.post( - '/queries', json=config, **self._request_kwargs) + self._page = await self._client.post('/queries', json=config, **self._request_kwargs) self._row_index = 0 self._total_rows = self._page['totalRows'] self._parser = ResultsParser(self._page['schema']) @@ -97,20 +95,22 @@ async def __anext__(self): next_page_token = self._page.get('pageToken') if next_page_token is not None: - query_parameters = { - 'pageToken': next_page_token, - 'location': self._location - } + query_parameters = {'pageToken': next_page_token, 'location': self._location} self._page = await self._client.get_query_results( - self._job_id, query_parameters, **self._request_kwargs) + self._job_id, query_parameters, **self._request_kwargs + ) self._row_index = 0 else: raise StopAsyncIteration class GoogleBigQueryClient(GoogleBaseClient): - def __init__(self, project, **kwargs): - super().__init__(f'https://bigquery.googleapis.com/bigquery/v2/projects/{project}', **kwargs) + def __init__( + self, + project: str, + **kwargs, + ): + super().__init__(base_url=f'https://bigquery.googleapis.com/bigquery/v2/projects/{project}', **kwargs) # docs: # https://cloud.google.com/bigquery/docs/reference diff --git a/hail/python/hailtop/aiocloud/aiogoogle/client/compute_client.py b/hail/python/hailtop/aiocloud/aiogoogle/client/compute_client.py index baa84c8bc59..a8081f7050a 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/client/compute_client.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/client/compute_client.py @@ -1,6 +1,7 @@ -import uuid -from typing import Mapping, Any, Optional, MutableMapping, List, Dict import logging +import uuid +from typing import Any, Dict, List, Mapping, MutableMapping, Optional + import aiohttp from hailtop.utils import retry_transient_errors, sleep_before_try @@ -11,7 +12,14 @@ class GCPOperationError(Exception): - def __init__(self, status: int, message: str, error_codes: Optional[List[str]], error_messages: Optional[List[str]], response: Dict[str, Any]): + def __init__( + self, + status: int, + message: str, + error_codes: Optional[List[str]], + error_messages: Optional[List[str]], + response: Dict[str, Any], + ): super().__init__(message) self.status = status self.message = message @@ -20,11 +28,19 @@ def __init__(self, status: int, message: str, error_codes: Optional[List[str]], self.response = response def __str__(self): - return f'GCPOperationError: {self.status}:{self.message} {self.error_codes} {self.error_messages}; {self.response}' + return ( + f'GCPOperationError: {self.status}:{self.message} {self.error_codes} {self.error_messages}; {self.response}' 
+ ) class PagedIterator: - def __init__(self, client: 'GoogleComputeClient', path: str, request_params: Optional[MutableMapping[str, Any]], request_kwargs: Mapping[str, Any]): + def __init__( + self, + client: 'GoogleComputeClient', + path: str, + request_params: Optional[MutableMapping[str, Any]], + request_kwargs: Mapping[str, Any], + ): assert 'params' not in request_kwargs self._client = client self._path = path @@ -88,7 +104,9 @@ async def detach_disk(self, path: str, *, params: Optional[MutableMapping[str, A async def delete_disk(self, path: str, *, params: Optional[MutableMapping[str, Any]] = None, **kwargs): return await self.delete(path, params=params, **kwargs) - async def _request_with_zonal_operations_response(self, request_f, path, maybe_params: Optional[MutableMapping[str, Any]] = None, **kwargs): + async def _request_with_zonal_operations_response( + self, request_f, path, maybe_params: Optional[MutableMapping[str, Any]] = None, **kwargs + ): params = maybe_params or {} assert 'requestId' not in params @@ -102,8 +120,9 @@ async def request_and_wait(): tries = 0 while True: - result = await self.post(f'/zones/{zone}/operations/{operation_id}/wait', - timeout=aiohttp.ClientTimeout(total=150)) + result = await self.post( + f'/zones/{zone}/operations/{operation_id}/wait', timeout=aiohttp.ClientTimeout(total=150) + ) if result['status'] == 'DONE': error = result.get('error') if error: @@ -113,11 +132,13 @@ async def request_and_wait(): error_codes = [e['code'] for e in error['errors']] error_messages = [e['message'] for e in error['errors']] - raise GCPOperationError(result['httpErrorStatusCode'], - result['httpErrorMessage'], - error_codes, - error_messages, - result) + raise GCPOperationError( + result['httpErrorStatusCode'], + result['httpErrorMessage'], + error_codes, + error_messages, + result, + ) return result tries += 1 diff --git a/hail/python/hailtop/aiocloud/aiogoogle/client/logging_client.py b/hail/python/hailtop/aiocloud/aiogoogle/client/logging_client.py index 12024d924eb..05d7a443bec 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/client/logging_client.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/client/logging_client.py @@ -1,9 +1,12 @@ from typing import Any, Mapping, MutableMapping, Optional + from .base_client import GoogleBaseClient class PagedEntryIterator: - def __init__(self, client: 'GoogleLoggingClient', body: MutableMapping[str, Any], request_kwargs: Mapping[str, Any]): + def __init__( + self, client: 'GoogleLoggingClient', body: MutableMapping[str, Any], request_kwargs: Mapping[str, Any] + ): self._client = client self._body = body self._request_kwargs = request_kwargs @@ -16,15 +19,18 @@ def __aiter__(self) -> 'PagedEntryIterator': async def __anext__(self): if self._page is None: assert 'pageToken' not in self._body - self._page = await self._client.post( - '/entries:list', json=self._body, **self._request_kwargs) + self._page = await self._client.post('/entries:list', json=self._body, **self._request_kwargs) self._entry_index = 0 # in case a response is empty but there are more pages while True: assert self._page # an empty page has no entries - if 'entries' in self._page and self._entry_index is not None and self._entry_index < len(self._page['entries']): + if ( + 'entries' in self._page + and self._entry_index is not None + and self._entry_index < len(self._page['entries']) + ): i = self._entry_index self._entry_index += 1 return self._page['entries'][i] @@ -32,8 +38,7 @@ async def __anext__(self): next_page_token = 
self._page.get('nextPageToken') if next_page_token is not None: self._body['pageToken'] = next_page_token - self._page = await self._client.post( - '/entries:list', json=self._body, **self._request_kwargs) + self._page = await self._client.post('/entries:list', json=self._body, **self._request_kwargs) self._entry_index = 0 else: raise StopAsyncIteration diff --git a/hail/python/hailtop/aiocloud/aiogoogle/client/metadata_server_client.py b/hail/python/hailtop/aiocloud/aiogoogle/client/metadata_server_client.py new file mode 100644 index 00000000000..b716830ae06 --- /dev/null +++ b/hail/python/hailtop/aiocloud/aiogoogle/client/metadata_server_client.py @@ -0,0 +1,30 @@ +from typing import Optional + +import aiohttp + +from hailtop import httpx +from hailtop.utils import retry_transient_errors + + +class GoogleMetadataServerClient: + def __init__(self, http_session: httpx.ClientSession): + self._session = http_session + self._project_id: Optional[str] = None + self._numeric_project_id: Optional[str] = None + + async def project(self) -> str: + if self._project_id is None: + self._project_id = await retry_transient_errors(self._get_text, '/project/project-id') + return self._project_id + + async def numeric_project_id(self) -> str: + if self._numeric_project_id is None: + self._numeric_project_id = await retry_transient_errors(self._get_text, '/project/numeric-project-id') + return self._numeric_project_id + + async def _get_text(self, path: str) -> str: + url = f'http://metadata.google.internal/computeMetadata/v1{path}' + headers = {'Metadata-Flavor': 'Google'} + timeout = aiohttp.ClientTimeout(total=60) + res = await self._session.get_read(url, headers=headers, timeout=timeout) + return res.decode('utf-8') diff --git a/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py b/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py index c9043717166..f5f4fe084af 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py @@ -1,33 +1,42 @@ -import os -from typing import (Tuple, Any, Set, Optional, MutableMapping, Dict, AsyncIterator, cast, Type, - List, Coroutine, ClassVar) -from types import TracebackType -from multidict import CIMultiDictProxy # pylint: disable=unused-import -import sys -import logging import asyncio +import datetime +import logging +import os import urllib.parse +from contextlib import AsyncExitStack +from types import TracebackType +from typing import Any, AsyncIterator, Coroutine, Dict, List, MutableMapping, Optional, Set, Tuple, Type, cast + import aiohttp -import datetime +from multidict import CIMultiDictProxy # pylint: disable=unused-import # pylint: disable=unused-import + from hailtop import timex -from hailtop.utils import ( - secret_alnum_string, OnlineBoundedGather2, - TransientError, retry_transient_errors) -from hailtop.aiotools.fs import (FileStatus, FileListEntry, ReadableStream, WritableStream, AsyncFS, - AsyncFSURL, AsyncFSFactory, FileAndDirectoryError, MultiPartCreate, - UnexpectedEOFError) from hailtop.aiotools import FeedableAsyncIterable, WriteBuffer - -from .base_client import GoogleBaseClient -from ..session import GoogleSession +from hailtop.aiotools.fs import ( + AsyncFS, + AsyncFSFactory, + AsyncFSURL, + FileAndDirectoryError, + FileListEntry, + FileStatus, + IsABucketError, + MultiPartCreate, + ReadableStream, + UnexpectedEOFError, + WritableStream, +) +from hailtop.utils import OnlineBoundedGather2, TransientError, retry_transient_errors, 
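# A hedged sketch of the request the new GoogleMetadataServerClient issues: the
# GCE metadata server only answers requests that carry the Metadata-Flavor
# header. This standalone version uses aiohttp directly rather than hailtop's
# retrying session, so it is illustrative only.
import aiohttp


async def gce_project_id() -> str:
    url = 'http://metadata.google.internal/computeMetadata/v1/project/project-id'
    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers={'Metadata-Flavor': 'Google'}) as resp:
            resp.raise_for_status()
            return await resp.text()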
secret_alnum_string + +from ...common.session import BaseSession from ..credentials import GoogleCredentials -from ..user_config import get_gcs_requester_pays_configuration, GCSRequesterPaysConfiguration +from ..user_config import GCSRequesterPaysConfiguration, get_gcs_requester_pays_configuration +from .base_client import GoogleBaseClient log = logging.getLogger(__name__) class PageIterator: - def __init__(self, client: 'GoogleBaseClient', path: str, request_kwargs: MutableMapping[str, Any]): + def __init__(self, client: GoogleBaseClient, path: str, request_kwargs: MutableMapping[str, Any]): if 'params' in request_kwargs: request_params = request_kwargs['params'] del request_kwargs['params'] @@ -45,53 +54,69 @@ def __aiter__(self) -> 'PageIterator': async def __anext__(self): if self._page is None: assert 'pageToken' not in self._request_params - self._page = await retry_transient_errors(self._client.get, self._path, params=self._request_params, **self._request_kwargs) + self._page = await retry_transient_errors( + self._client.get, self._path, params=self._request_params, **self._request_kwargs + ) return self._page next_page_token = self._page.get('nextPageToken') if next_page_token is not None: self._request_params['pageToken'] = next_page_token - self._page = await retry_transient_errors(self._client.get, self._path, params=self._request_params, **self._request_kwargs) + self._page = await retry_transient_errors( + self._client.get, self._path, params=self._request_params, **self._request_kwargs + ) return self._page raise StopAsyncIteration +async def _cleanup_future(fut: asyncio.Future): + if not fut.done(): + fut.cancel() + await asyncio.wait([fut]) + if not fut.cancelled(): + if exc := fut.exception(): + raise exc + + class InsertObjectStream(WritableStream): - def __init__(self, - it: FeedableAsyncIterable[bytes], - request_task: asyncio.Task[aiohttp.ClientResponse]): + def __init__(self, it: FeedableAsyncIterable[bytes], request_task: asyncio.Task[aiohttp.ClientResponse]): super().__init__() self._it = it self._request_task = request_task self._value = None + self._exit_stack = AsyncExitStack() + + async def cleanup_request_task(): + if not self._request_task.cancelled(): + try: + async with await self._request_task as response: + self._value = await response.json() + except AttributeError as err: + raise ValueError(repr(self._request_task)) from err + await _cleanup_future(self._request_task) + + self._exit_stack.push_async_callback(cleanup_request_task) async def write(self, b): assert not self.closed fut = asyncio.ensure_future(self._it.feed(b)) - try: - await asyncio.wait([fut, self._request_task], return_when=asyncio.FIRST_COMPLETED) - if fut.done() and not fut.cancelled(): - if exc := fut.exception(): - raise exc - return len(b) - raise ValueError('request task finished early') - finally: - fut.cancel() + self._exit_stack.push_async_callback(_cleanup_future, fut) + + await asyncio.wait([fut, self._request_task], return_when=asyncio.FIRST_COMPLETED) + if fut.done(): + await fut + return len(b) + raise ValueError('request task finished early') async def _wait_closed(self): - fut = asyncio.ensure_future(self._it.stop()) try: + fut = asyncio.ensure_future(self._it.stop()) + self._exit_stack.push_async_callback(_cleanup_future, fut) await asyncio.wait([fut, self._request_task], return_when=asyncio.FIRST_COMPLETED) - async with await self._request_task as resp: - self._value = await resp.json() finally: - if fut.done() and not fut.cancelled(): - if exc := fut.exception(): - 
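# A minimal sketch of the cleanup discipline introduced above (_cleanup_future
# plus AsyncExitStack): cancel a still-pending future, wait for it to settle,
# and re-raise anything other than the cancellation itself.
import asyncio
from contextlib import AsyncExitStack


async def cleanup_future(fut: asyncio.Future) -> None:
    if not fut.done():
        fut.cancel()
    await asyncio.wait([fut])
    if not fut.cancelled() and (exc := fut.exception()) is not None:
        raise exc


async def demo() -> None:
    async with AsyncExitStack() as stack:
        fut = asyncio.ensure_future(asyncio.sleep(3600))
        # However `demo` exits, the exit stack guarantees the future is reaped.
        stack.push_async_callback(cleanup_future, fut)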
raise exc - else: - fut.cancel() + await self._exit_stack.aclose() class _TaskManager: @@ -104,35 +129,19 @@ async def __aenter__(self) -> asyncio.Task: self._task = asyncio.create_task(self._coro) return self._task - async def __aexit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: assert self._task is not None - if not self._task.done(): - if exc_val: - self._task.cancel() - try: - value = await self._task - if self._closable: - value.close() - except: - _, exc, _ = sys.exc_info() - if exc is not exc_val: - log.warning('dropping preempted task exception', exc_info=True) - else: - value = await self._task - if self._closable: - value.close() + if self._closable and self._task.done() and not self._task.cancelled(): + (await self._task).close() else: - value = await self._task - if self._closable: - value.close() + await _cleanup_future(self._task) class ResumableInsertObjectStream(WritableStream): - def __init__(self, session: GoogleSession, session_url: str, chunk_size: int): + def __init__(self, session: BaseSession, session_url: str, chunk_size: int): super().__init__() self._session = session self._session_url = session_url @@ -166,12 +175,11 @@ async def _write_chunk_1(self): # https://cloud.google.com/storage/docs/performing-resumable-uploads#status-check # note: this retries - async with await self._session.put(self._session_url, - headers={ - 'Content-Length': '0', - 'Content-Range': f'bytes */{total_size_str}' - }, - raise_for_status=False) as resp: + async with await self._session.put( + self._session_url, + headers={'Content-Length': '0', 'Content-Range': f'bytes */{total_size_str}'}, + raise_for_status=False, + ) as resp: if resp.status >= 200 and resp.status < 300: assert self._closed assert total_size is not None @@ -213,15 +221,15 @@ async def _write_chunk_1(self): # https://cloud.google.com/storage/docs/performing-resumable-uploads#chunked-upload it: FeedableAsyncIterable[bytes] = FeedableAsyncIterable() async with _TaskManager( - self._session.put(self._session_url, - data=aiohttp.AsyncIterablePayload(it), - headers={ - 'Content-Length': f'{n}', - 'Content-Range': range - }, - raise_for_status=False, - retry=False), - closable=True) as put_task: + self._session.put( + self._session_url, + data=aiohttp.AsyncIterablePayload(it), + headers={'Content-Length': f'{n}', 'Content-Range': range}, + raise_for_status=False, + retry=False, + ), + closable=True, + ) as put_task: with self._write_buffer.chunks(n) as chunks: for chunk in chunks: async with _TaskManager(it.feed(chunk)) as feed_task: @@ -307,8 +315,7 @@ async def _wait_closed(self) -> None: class GoogleStorageClient(GoogleBaseClient): - def __init__(self, gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None, - **kwargs): + def __init__(self, gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None, **kwargs): if 'timeout' not in kwargs and 'http_session' not in kwargs: # Around May 2022, GCS started timing out a lot with our default 5s timeout kwargs['timeout'] = aiohttp.ClientTimeout(total=20) @@ -352,10 +359,11 @@ async def insert_object(self, bucket: str, name: str, **kwargs) -> WritableStrea if upload_type == 'media': it: FeedableAsyncIterable[bytes] = FeedableAsyncIterable() kwargs['data'] = aiohttp.AsyncIterablePayload(it) - 
request_task = asyncio.create_task(self._session.post( - f'https://storage.googleapis.com/upload/storage/v1/b/{bucket}/o', - retry=False, - **kwargs)) + request_task = asyncio.create_task( + self._session.post( + f'https://storage.googleapis.com/upload/storage/v1/b/{bucket}/o', retry=False, **kwargs + ) + ) return InsertObjectStream(it, request_task) # Write using resumable uploads. See: @@ -364,8 +372,7 @@ async def insert_object(self, bucket: str, name: str, **kwargs) -> WritableStrea chunk_size = kwargs.get('bufsize', 8 * 1024 * 1024) async with await self._session.post( - f'https://storage.googleapis.com/upload/storage/v1/b/{bucket}/o', - **kwargs + f'https://storage.googleapis.com/upload/storage/v1/b/{bucket}/o', **kwargs ) as resp: session_url = resp.headers['Location'] return ResumableInsertObjectStream(self._session, session_url, chunk_size) @@ -383,7 +390,8 @@ async def get_object(self, bucket: str, name: str, **kwargs) -> GetObjectStream: try: resp = await self._session.get( - f'https://storage.googleapis.com/storage/v1/b/{bucket}/o/{urllib.parse.quote(name, safe="")}', **kwargs) + f'https://storage.googleapis.com/storage/v1/b/{bucket}/o/{urllib.parse.quote(name, safe="")}', **kwargs + ) return GetObjectStream(resp) except aiohttp.ClientResponseError as e: if e.status == 404: @@ -416,9 +424,7 @@ async def compose(self, bucket: str, names: List[str], destination: str, **kwarg raise ValueError(f'too many components in compose, maximum of 32: {n}') assert 'json' not in kwargs assert 'body' not in kwargs - kwargs['json'] = { - 'sourceObjects': [{'name': name} for name in names] - } + kwargs['json'] = {'sourceObjects': [{'name': name} for name in names]} self._update_params_with_user_project(kwargs, bucket) await self.post(f'/b/{bucket}/o/{urllib.parse.quote(destination, safe="")}/compose', **kwargs) @@ -436,8 +442,15 @@ def _update_params_with_user_project(self, request_kwargs, bucket): class GetObjectFileStatus(FileStatus): - def __init__(self, items: Dict[str, str]): + def __init__(self, items: Dict[str, str], url: str): self._items = items + self._url = url + + def basename(self) -> str: + return os.path.basename(self._url.rstrip('/')) + + def url(self) -> str: + return self._url async def size(self) -> int: return int(self._items['size']) @@ -459,8 +472,8 @@ def __init__(self, bucket: str, name: str, items: Optional[Dict[str, Any]]): self._items = items self._status: Optional[GetObjectFileStatus] = None - def name(self) -> str: - return os.path.basename(self._name) + def basename(self) -> str: + return os.path.basename(self._name.rstrip('/')) async def url(self) -> str: return f'gs://{self._bucket}/{self._name}' @@ -475,7 +488,7 @@ async def status(self) -> FileStatus: if self._status is None: if self._items is None: raise IsADirectoryError(await self.url()) - self._status = GetObjectFileStatus(self._items) + self._status = GetObjectFileStatus(self._items, await self.url()) return self._status @@ -506,21 +519,18 @@ def _part_name(self, number: int) -> str: async def create_part(self, number: int, start: int, size_hint: Optional[int] = None) -> WritableStream: # pylint: disable=unused-argument part_name = self._part_name(number) - params = { - 'uploadType': 'media' - } + params = {'uploadType': 'media'} return await self._fs._storage_client.insert_object(self._bucket, part_name, params=params) async def __aenter__(self) -> 'GoogleStorageMultiPartCreate': return self async def _compose(self, names: List[str], dest_name: str): - await self._fs._storage_client.compose(self._bucket, 
names, dest_name) + await retry_transient_errors(self._fs._storage_client.compose, self._bucket, names, dest_name) - async def __aexit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: async with OnlineBoundedGather2(self._sema) as pool: try: if exc_val is not None: @@ -542,7 +552,7 @@ async def tree_compose(names, dest_name): chunk_size = q if i < r: chunk_size += 1 - chunks.append(names[p:p + chunk_size]) + chunks.append(names[p : p + chunk_size]) p += chunk_size i += 1 assert p == n @@ -550,23 +560,27 @@ async def tree_compose(names, dest_name): chunk_names = [self._tmp_name(f'chunk-{secret_alnum_string()}') for _ in range(32)] - chunk_tasks = [ - pool.call(tree_compose, c, n) - for c, n in zip(chunks, chunk_names) - ] + async with AsyncExitStack() as stack: + chunk_tasks = [] + for chunk, name in zip(chunks, chunk_names): + fut = pool.call(tree_compose, chunk, name) + stack.push_async_callback(_cleanup_future, fut) + chunk_tasks.append(fut) - await pool.wait(chunk_tasks) + await pool.wait(chunk_tasks) await self._compose(chunk_names, dest_name) for name in chunk_names: - await pool.call(self._fs._remove_doesnt_exist_ok, f'gs://{self._bucket}/{name}') + await pool.call( + retry_transient_errors, self._fs._remove_doesnt_exist_ok, f'gs://{self._bucket}/{name}' + ) - await tree_compose( - [self._part_name(i) for i in range(self._num_parts)], - self._dest_name) + await tree_compose([self._part_name(i) for i in range(self._num_parts)], self._dest_name) finally: - await self._fs.rmtree(self._sema, f'gs://{self._bucket}/{self._dest_dirname}_/{self._token}') + await retry_transient_errors( + self._fs.rmtree, self._sema, f'gs://{self._bucket}/{self._dest_dirname}_/{self._token}' + ) class GoogleStorageAsyncFSURL(AsyncFSURL): @@ -574,6 +588,9 @@ def __init__(self, bucket: str, path: str): self._bucket = bucket self._path = path + def __repr__(self): + return f'GoogleStorageAsyncFSURL({self._bucket}, {self._path})' + @property def bucket_parts(self) -> List[str]: return [self._bucket] @@ -593,17 +610,21 @@ def scheme(self) -> str: def with_path(self, path) -> 'GoogleStorageAsyncFSURL': return GoogleStorageAsyncFSURL(self._bucket, path) + def with_root_path(self) -> 'GoogleStorageAsyncFSURL': + return self.with_path('') + def __str__(self) -> str: return f'gs://{self._bucket}/{self._path}' class GoogleStorageAsyncFS(AsyncFS): - schemes: ClassVar[Set[str]] = {'gs'} - - def __init__(self, *, - storage_client: Optional[GoogleStorageClient] = None, - bucket_allow_list: Optional[List[str]] = None, - **kwargs): + def __init__( + self, + *, + storage_client: Optional[GoogleStorageClient] = None, + bucket_allow_list: Optional[List[str]] = None, + **kwargs, + ): if not storage_client: storage_client = GoogleStorageClient(**kwargs) self._storage_client = storage_client @@ -611,6 +632,10 @@ def __init__(self, *, bucket_allow_list = [] self.allowed_storage_locations = bucket_allow_list + @staticmethod + def schemes() -> Set[str]: + return {'gs'} + def storage_location(self, uri: str) -> str: return self.get_bucket_and_name(uri)[0] @@ -635,8 +660,12 @@ async def is_hot_storage(self, location: str) -> bool: def valid_url(url: str) -> bool: return url.startswith('gs://') - def parse_url(self, url: str) -> GoogleStorageAsyncFSURL: - return 
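# The chunking arithmetic tree_compose relies on, as a standalone sketch: GCS
# compose accepts at most 32 source objects, so the part names are split into
# at most 32 contiguous, near-equal chunks.
from typing import List


def split_into_chunks(names: List[str], max_chunks: int = 32) -> List[List[str]]:
    n = len(names)
    q, r = divmod(n, max_chunks)
    chunks: List[List[str]] = []
    p = 0
    for i in range(max_chunks):
        size = q + 1 if i < r else q
        if size == 0:
            break
        chunks.append(names[p : p + size])
        p += size
    assert p == n
    return chunks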
GoogleStorageAsyncFSURL(*self.get_bucket_and_name(url)) + @staticmethod + def parse_url(url: str, *, error_if_bucket: bool = False) -> GoogleStorageAsyncFSURL: + fsurl = GoogleStorageAsyncFSURL(*GoogleStorageAsyncFS.get_bucket_and_name(url)) + if error_if_bucket and fsurl._path == '': + raise IsABucketError + return fsurl @staticmethod def get_bucket_and_name(url: str) -> Tuple[str, str]: @@ -648,14 +677,14 @@ def get_bucket_and_name(url: str) -> Tuple[str, str]: if scheme != 'gs': raise ValueError(f'invalid scheme, expected gs: {scheme}') - rest = url[(colon_index + 1):] + rest = url[(colon_index + 1) :] if not rest.startswith('//'): raise ValueError(f'Google Cloud Storage URI must be of the form: gs://bucket/path, found: {url}') end_of_bucket = rest.find('/', 2) if end_of_bucket != -1: bucket = rest[2:end_of_bucket] - name = rest[(end_of_bucket + 1):] + name = rest[(end_of_bucket + 1) :] else: bucket = rest[2:] name = '' @@ -663,30 +692,26 @@ def get_bucket_and_name(url: str) -> Tuple[str, str]: return (bucket, name) async def open(self, url: str) -> GetObjectStream: - bucket, name = self.get_bucket_and_name(url) - return await self._storage_client.get_object(bucket, name) + fsurl = self.parse_url(url, error_if_bucket=True) + return await self._storage_client.get_object(fsurl._bucket, fsurl._path) async def _open_from(self, url: str, start: int, *, length: Optional[int] = None) -> GetObjectStream: - bucket, name = self.get_bucket_and_name(url) + fsurl = self.parse_url(url, error_if_bucket=True) range_str = f'bytes={start}-' if length is not None: assert length >= 1 range_str += str(start + length - 1) - return await self._storage_client.get_object( - bucket, name, headers={'Range': range_str}) + return await self._storage_client.get_object(fsurl._bucket, fsurl._path, headers={'Range': range_str}) async def create(self, url: str, *, retry_writes: bool = True) -> WritableStream: - bucket, name = self.get_bucket_and_name(url) - params = { - 'uploadType': 'resumable' if retry_writes else 'media' - } - return await self._storage_client.insert_object(bucket, name, params=params) + fsurl = self.parse_url(url, error_if_bucket=True) + params = {'uploadType': 'resumable' if retry_writes else 'media'} + return await self._storage_client.insert_object(fsurl._bucket, fsurl._path, params=params) async def multi_part_create( - self, - sema: asyncio.Semaphore, - url: str, - num_parts: int) -> GoogleStorageMultiPartCreate: + self, sema: asyncio.Semaphore, url: str, num_parts: int + ) -> GoogleStorageMultiPartCreate: + self.parse_url(url, error_if_bucket=True) return GoogleStorageMultiPartCreate(sema, self, url, num_parts) async def staturl(self, url: str) -> str: @@ -700,8 +725,8 @@ async def makedirs(self, url: str, exist_ok: bool = False) -> None: async def statfile(self, url: str) -> GetObjectFileStatus: try: - bucket, name = self.get_bucket_and_name(url) - return GetObjectFileStatus(await self._storage_client.get_object_metadata(bucket, name)) + fsurl = self.parse_url(url, error_if_bucket=True) + return GetObjectFileStatus(await self._storage_client.get_object_metadata(fsurl._bucket, fsurl._path), url) except aiohttp.ClientResponseError as e: if e.status == 404: raise FileNotFoundError(url) from e @@ -709,9 +734,7 @@ async def statfile(self, url: str) -> GetObjectFileStatus: async def _listfiles_recursive(self, bucket: str, name: str) -> AsyncIterator[FileListEntry]: assert not name or name.endswith('/') - params = { - 'prefix': name - } + params = {'prefix': name} async for page in await 
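# A compact sketch of the gs:// parsing that parse_url/get_bucket_and_name
# perform above: 'gs://bucket/path/to/obj' becomes ('bucket', 'path/to/obj'),
# and a bare 'gs://bucket' yields an empty object name (which the new
# error_if_bucket flag turns into IsABucketError).
from typing import Tuple


def gs_bucket_and_name(url: str) -> Tuple[str, str]:
    scheme, sep, rest = url.partition('://')
    if not sep or scheme != 'gs':
        raise ValueError(f'expected a gs:// URL, found: {url}')
    bucket, _, name = rest.partition('/')
    if not bucket:
        raise ValueError(f'missing bucket in: {url}')
    return bucket, name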
self._storage_client.list_objects(bucket, params=params): prefixes = page.get('prefixes') assert not prefixes @@ -723,11 +746,7 @@ async def _listfiles_recursive(self, bucket: str, name: str) -> AsyncIterator[Fi async def _listfiles_flat(self, bucket: str, name: str) -> AsyncIterator[FileListEntry]: assert not name or name.endswith('/') - params = { - 'prefix': name, - 'delimiter': '/', - 'includeTrailingDelimiter': 'true' - } + params = {'prefix': name, 'delimiter': '/', 'includeTrailingDelimiter': 'true'} async for page in await self._storage_client.list_objects(bucket, params=params): prefixes = page.get('prefixes') if prefixes: @@ -740,11 +759,9 @@ async def _listfiles_flat(self, bucket: str, name: str) -> AsyncIterator[FileLis for item in page['items']: yield GoogleStorageFileListEntry(bucket, item['name'], item) - async def listfiles(self, - url: str, - recursive: bool = False, - exclude_trailing_slash_files: bool = True - ) -> AsyncIterator[FileListEntry]: + async def listfiles( + self, url: str, recursive: bool = False, exclude_trailing_slash_files: bool = True + ) -> AsyncIterator[FileListEntry]: bucket, name = self.get_bucket_and_name(url) if name and not name.endswith('/'): name = f'{name}/' @@ -787,12 +804,12 @@ async def cons(first_entry, it) -> AsyncIterator[FileListEntry]: async def isfile(self, url: str) -> bool: try: - bucket, name = self.get_bucket_and_name(url) + fsurl = self.parse_url(url) # if name is empty, get_object_metadata behaves like list objects # the urls are the same modulo the object name - if not name: + if not fsurl._path: return False - await self._storage_client.get_object_metadata(bucket, name) + await self._storage_client.get_object_metadata(fsurl._bucket, fsurl._path) return True except aiohttp.ClientResponseError as e: if e.status == 404: @@ -800,15 +817,10 @@ async def isfile(self, url: str) -> bool: raise async def isdir(self, url: str) -> bool: - bucket, name = self.get_bucket_and_name(url) - assert not name or name.endswith('/'), name - params = { - 'prefix': name, - 'delimiter': '/', - 'includeTrailingDelimiter': 'true', - 'maxResults': 1 - } - async for page in await self._storage_client.list_objects(bucket, params=params): + fsurl = self.parse_url(url, error_if_bucket=True) + assert not fsurl._path or fsurl.path.endswith('/'), fsurl._path + params = {'prefix': fsurl._path, 'delimiter': '/', 'includeTrailingDelimiter': 'true', 'maxResults': 1} + async for page in await self._storage_client.list_objects(fsurl._bucket, params=params): prefixes = page.get('prefixes') items = page.get('items') return bool(prefixes or items) @@ -816,6 +828,8 @@ async def isdir(self, url: str) -> bool: async def remove(self, url: str) -> None: bucket, name = self.get_bucket_and_name(url) + if name == '': + raise IsABucketError(url) try: await self._storage_client.delete_object(bucket, name) except aiohttp.ClientResponseError as e: @@ -831,13 +845,10 @@ async def close(self) -> None: class GoogleStorageAsyncFSFactory(AsyncFSFactory[GoogleStorageAsyncFS]): def from_credentials_data(self, credentials_data: dict) -> GoogleStorageAsyncFS: - return GoogleStorageAsyncFS( - credentials=GoogleCredentials.from_credentials_data(credentials_data)) + return GoogleStorageAsyncFS(credentials=GoogleCredentials.from_credentials_data(credentials_data)) def from_credentials_file(self, credentials_file: str) -> GoogleStorageAsyncFS: - return GoogleStorageAsyncFS( - credentials=GoogleCredentials.from_file(credentials_file)) + return 
GoogleStorageAsyncFS(credentials=GoogleCredentials.from_file(credentials_file)) def from_default_credentials(self) -> GoogleStorageAsyncFS: - return GoogleStorageAsyncFS( - credentials=GoogleCredentials.default_credentials()) + return GoogleStorageAsyncFS(credentials=GoogleCredentials.default_credentials()) diff --git a/hail/python/hailtop/aiocloud/aiogoogle/credentials.py b/hail/python/hailtop/aiocloud/aiogoogle/credentials.py index 0b1356c1604..bdb8031e3c0 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/credentials.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/credentials.py @@ -1,14 +1,16 @@ -from typing import Dict, Optional, Union, List, Literal, ClassVar, overload, Tuple -import os import json -import time import logging +import os import socket +import time +from typing import ClassVar, Dict, List, Literal, Optional, Tuple, Union, overload from urllib.parse import urlencode + import jwt -from hailtop.utils import first_extant_file, retry_transient_errors from hailtop import httpx +from hailtop.utils import first_extant_file, retry_transient_errors + from ..common.credentials import AnonymousCloudCredentials, CloudCredentials log = logging.getLogger(__name__) @@ -19,11 +21,13 @@ class GoogleExpiringAccessToken: def from_dict(data: dict) -> 'GoogleExpiringAccessToken': now = time.time() token = data['access_token'] - expiry_time = now + data['expires_in'] // 2 - return GoogleExpiringAccessToken(token, expiry_time) + expires_in = data['expires_in'] + expiry_time = now + expires_in // 2 + return GoogleExpiringAccessToken(token, expires_in, expiry_time) - def __init__(self, token, expiry_time: int): + def __init__(self, token, expires_in: int, expiry_time: int): self.token = token + self.expires_in = expires_in self._expiry_time = expiry_time def expired(self) -> bool: @@ -40,10 +44,9 @@ class GoogleCredentials(CloudCredentials): 'https://www.googleapis.com/auth/compute', ] - def __init__(self, - http_session: Optional[httpx.ClientSession] = None, - scopes: Optional[List[str]] = None, - **kwargs): + def __init__( + self, http_session: Optional[httpx.ClientSession] = None, scopes: Optional[List[str]] = None, **kwargs + ): self._access_token: Optional[GoogleExpiringAccessToken] = None self._scopes = scopes or GoogleCredentials.default_scopes if http_session is not None: @@ -52,6 +55,14 @@ def __init__(self, else: self._http_session = httpx.ClientSession(**kwargs) + @staticmethod + def from_file_or_default( + credentials_file: Optional[str] = None, + ) -> 'GoogleCredentials': + if credentials_file: + return GoogleCredentials.from_file(credentials_file) + return GoogleCredentials.default_credentials() + @staticmethod def from_file(credentials_file: str, *, scopes: Optional[List[str]] = None) -> 'GoogleCredentials': with open(credentials_file, encoding='utf-8') as f: @@ -71,17 +82,25 @@ def from_credentials_data(credentials: dict, scopes: Optional[List[str]] = None, @overload @staticmethod - def default_credentials(scopes: Optional[List[str]] = ..., *, anonymous_ok: Literal[False] = ...) -> 'GoogleCredentials': ... + def default_credentials( + scopes: Optional[List[str]] = ..., *, anonymous_ok: Literal[False] = ... + ) -> 'GoogleCredentials': ... @overload @staticmethod - def default_credentials(scopes: Optional[List[str]] = ..., *, anonymous_ok: Literal[True] = ...) -> Union['GoogleCredentials', AnonymousCloudCredentials]: ... + def default_credentials( + scopes: Optional[List[str]] = ..., *, anonymous_ok: Literal[True] = ... 
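# The half-lifetime expiry policy added to GoogleExpiringAccessToken, as a
# standalone sketch: a token valid for `expires_in` seconds is treated as
# expired once half of that window has elapsed, so it is refreshed well before
# Google would actually reject it.
import time


class ExpiringToken:
    def __init__(self, token: str, expires_in: int):
        self.token = token
        self._expiry_time = time.time() + expires_in // 2

    def expired(self) -> bool:
        return time.time() > self._expiry_time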
+ ) -> Union['GoogleCredentials', AnonymousCloudCredentials]: ... @staticmethod - def default_credentials(scopes: Optional[List[str]] = None, *, anonymous_ok: bool = True) -> Union['GoogleCredentials', AnonymousCloudCredentials]: + def default_credentials( + scopes: Optional[List[str]] = None, *, anonymous_ok: bool = True + ) -> Union['GoogleCredentials', AnonymousCloudCredentials]: credentials_file = first_extant_file( os.environ.get('GOOGLE_APPLICATION_CREDENTIALS'), - f'{os.environ["HOME"]}/.config/gcloud/application_default_credentials.json' if 'HOME' in os.environ else None, + f'{os.environ["HOME"]}/.config/gcloud/application_default_credentials.json' + if 'HOME' in os.environ + else None, ) if credentials_file: @@ -98,8 +117,10 @@ def default_credentials(scopes: Optional[List[str]] = None, *, anonymous_ok: boo raise ValueError( 'No valid Google Cloud credentials found. Run `gcloud auth application-default login` or set `GOOGLE_APPLICATION_CREDENTIALS`.' ) - log.warning('Using anonymous credentials. If accessing private data, ' - 'run `gcloud auth application-default login` first to log in.') + log.warning( + 'Using anonymous credentials. If accessing private data, ' + 'run `gcloud auth application-default login` first to log in.' + ) return AnonymousCloudCredentials() async def auth_headers_with_expiration(self) -> Tuple[Dict[str, str], Optional[float]]: @@ -114,6 +135,12 @@ async def access_token_with_expiration(self) -> Tuple[str, Optional[float]]: async def _get_access_token(self) -> GoogleExpiringAccessToken: raise NotImplementedError + async def __aenter__(self): + return self + + async def __aexit__(self, *_): + await self.close() + async def close(self): await self._http_session.close() @@ -133,15 +160,13 @@ async def _get_access_token(self) -> GoogleExpiringAccessToken: token_dict = await retry_transient_errors( self._http_session.post_read_json, 'https://www.googleapis.com/oauth2/v4/token', - headers={ - 'content-type': 'application/x-www-form-urlencoded' - }, + headers={'content-type': 'application/x-www-form-urlencoded'}, data=urlencode({ 'grant_type': 'refresh_token', 'client_id': self.credentials['client_id'], 'client_secret': self.credentials['client_secret'], - 'refresh_token': self.credentials['refresh_token'] - }) + 'refresh_token': self.credentials['refresh_token'], + }), ) return GoogleExpiringAccessToken.from_dict(token_dict) @@ -150,13 +175,17 @@ async def _get_access_token(self) -> GoogleExpiringAccessToken: # https://developers.google.com/identity/protocols/oauth2/service-account # studying `gcloud --log-http print-access-token` was also useful class GoogleServiceAccountCredentials(GoogleCredentials): - def __init__(self, key, **kwargs): + def __init__(self, key: dict, **kwargs): super().__init__(**kwargs) self.key = key def __str__(self): return f'GoogleServiceAccountCredentials for {self.key["client_email"]}' + @property + def email(self) -> str: + return self.key['client_email'] + async def _get_access_token(self) -> GoogleExpiringAccessToken: now = int(time.time()) scope = ' '.join(self._scopes) @@ -165,19 +194,17 @@ async def _get_access_token(self) -> GoogleExpiringAccessToken: "iat": now, "scope": scope, "exp": now + 300, # 5m - "iss": self.key['client_email'] + "iss": self.key['client_email'], } encoded_assertion = jwt.encode(assertion, self.key['private_key'], algorithm='RS256') token_dict = await retry_transient_errors( self._http_session.post_read_json, 'https://www.googleapis.com/oauth2/v4/token', - headers={ - 'content-type': 
'application/x-www-form-urlencoded' - }, + headers={'content-type': 'application/x-www-form-urlencoded'}, data=urlencode({ 'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer', - 'assertion': encoded_assertion - }) + 'assertion': encoded_assertion, + }), ) return GoogleExpiringAccessToken.from_dict(token_dict) @@ -188,7 +215,7 @@ async def _get_access_token(self) -> GoogleExpiringAccessToken: token_dict = await retry_transient_errors( self._http_session.get_read_json, 'http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token', - headers={'Metadata-Flavor': 'Google'} + headers={'Metadata-Flavor': 'Google'}, ) return GoogleExpiringAccessToken.from_dict(token_dict) diff --git a/hail/python/hailtop/aiocloud/aiogoogle/session.py b/hail/python/hailtop/aiocloud/aiogoogle/session.py deleted file mode 100644 index 8a883a23a5c..00000000000 --- a/hail/python/hailtop/aiocloud/aiogoogle/session.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Mapping, Optional, Union - -from ..common import Session, AnonymousCloudCredentials - -from .credentials import GoogleCredentials - - -class GoogleSession(Session): - def __init__(self, *, credentials: Optional[Union[GoogleCredentials, AnonymousCloudCredentials]] = None, credentials_file: Optional[str] = None, - params: Optional[Mapping[str, str]] = None, **kwargs): - assert credentials is None or credentials_file is None, \ - f'specify only one of credentials or credentials_file: {(credentials, credentials_file)}' - if credentials is None: - if credentials_file: - credentials = GoogleCredentials.from_file(credentials_file) - else: - credentials = GoogleCredentials.default_credentials() - super().__init__(credentials=credentials, params=params, **kwargs) diff --git a/hail/python/hailtop/aiocloud/aiogoogle/user_config.py b/hail/python/hailtop/aiocloud/aiogoogle/user_config.py index cdcee03da77..abc1d016c3f 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/user_config.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/user_config.py @@ -1,21 +1,20 @@ -from typing import Optional, List, Union, Tuple import os import warnings -from jproperties import Properties -from enum import Enum from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Tuple, Union +from jproperties import Properties from hailtop.config.user_config import configuration_of from hailtop.config.variables import ConfigVariable - GCSRequesterPaysConfiguration = Union[str, Tuple[str, List[str]]] def get_gcs_requester_pays_configuration( - *, - gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None, + *, + gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None, ) -> Optional[GCSRequesterPaysConfiguration]: if gcs_requester_pays_configuration: return gcs_requester_pays_configuration @@ -55,7 +54,7 @@ def get_gcs_requester_pays_configuration( f'When reading GCS requester pays configuration from spark-defaults.conf ' f'({spark_conf.conf_path}), no mode is set, so requester pays ' f'will be disabled.' 
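# A hedged sketch of the JWT-bearer exchange GoogleServiceAccountCredentials
# performs above: sign a five-minute assertion with the service account's
# private key, then trade it for an access token. This uses aiohttp directly;
# the production code goes through hailtop's retrying session instead.
import time
from typing import Any, Dict, List
from urllib.parse import urlencode

import aiohttp
import jwt  # PyJWT


async def service_account_access_token(key: Dict[str, str], scopes: List[str]) -> Dict[str, Any]:
    now = int(time.time())
    assertion = {
        'aud': 'https://www.googleapis.com/oauth2/v4/token',
        'iat': now,
        'exp': now + 300,  # five minutes
        'scope': ' '.join(scopes),
        'iss': key['client_email'],
    }
    signed = jwt.encode(assertion, key['private_key'], algorithm='RS256')
    data = urlencode({
        'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer',
        'assertion': signed,
    })
    async with aiohttp.ClientSession() as session:
        async with session.post(
            'https://www.googleapis.com/oauth2/v4/token',
            headers={'content-type': 'application/x-www-form-urlencoded'},
            data=data,
        ) as resp:
            resp.raise_for_status()
            return await resp.json()  # includes 'access_token' and 'expires_in'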
- ) + ) return None if spark_conf.mode == SparkConfGcsRequesterPaysMode.DISABLED: diff --git a/hail/python/hailtop/aiocloud/aioterra/__init__.py b/hail/python/hailtop/aiocloud/aioterra/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/hail/python/hailtop/aiocloud/aioterra/azure/__init__.py b/hail/python/hailtop/aiocloud/aioterra/azure/__init__.py new file mode 100644 index 00000000000..7fdb8333a66 --- /dev/null +++ b/hail/python/hailtop/aiocloud/aioterra/azure/__init__.py @@ -0,0 +1,7 @@ +from .client import TerraClient +from .fs import TerraAzureAsyncFS + +__all__ = [ + 'TerraClient', + 'TerraAzureAsyncFS', +] diff --git a/hail/python/hailtop/aiocloud/aioterra/azure/client/__init__.py b/hail/python/hailtop/aiocloud/aioterra/azure/client/__init__.py new file mode 100644 index 00000000000..f2df21b77e3 --- /dev/null +++ b/hail/python/hailtop/aiocloud/aioterra/azure/client/__init__.py @@ -0,0 +1,5 @@ +from .terra_client import TerraClient + +__all__ = [ + 'TerraClient', +] diff --git a/hail/python/hailtop/aiocloud/aioterra/azure/client/terra_client.py b/hail/python/hailtop/aiocloud/aioterra/azure/client/terra_client.py new file mode 100644 index 00000000000..93ced2821c1 --- /dev/null +++ b/hail/python/hailtop/aiocloud/aioterra/azure/client/terra_client.py @@ -0,0 +1,22 @@ +import os + +from .....auth import hail_credentials +from ....common import CloudBaseClient, Session + + +class TerraClient(CloudBaseClient): + def __init__(self): + base_url = f"{os.environ['WORKSPACE_MANAGER_URL']}/api/workspaces/v1/{os.environ['WORKSPACE_ID']}/resources/controlled/azure" + super().__init__(base_url, Session(credentials=hail_credentials())) + + async def get_storage_container_sas_token( + self, container_resource_id: str, blob_name: str, permissions: str = 'racwdl', expires_after: int = 3600 + ) -> str: + headers = {'Content-Type': 'application/json'} + params = {'sasPermissions': permissions, 'sasExpirationDuration': expires_after, 'sasBlobName': blob_name} + resp = await self.post( + f'/storageContainer/{container_resource_id}/getSasToken', + headers=headers, + params=params, + ) + return resp['url'] diff --git a/hail/python/hailtop/aiocloud/aioterra/azure/fs.py b/hail/python/hailtop/aiocloud/aioterra/azure/fs.py new file mode 100644 index 00000000000..aac9bd3c7d3 --- /dev/null +++ b/hail/python/hailtop/aiocloud/aioterra/azure/fs.py @@ -0,0 +1,55 @@ +import os +from typing import Dict, Tuple + +from azure.storage.blob.aio import BlobServiceClient + +from hailtop.aiocloud.aioazure import AzureAsyncFS, AzureAsyncFSURL +from hailtop.utils import time_msecs + +from .client import TerraClient + +WORKSPACE_STORAGE_CONTAINER_ID = os.environ.get('WORKSPACE_STORAGE_CONTAINER_ID') +WORKSPACE_STORAGE_CONTAINER_URL = os.environ.get('WORKSPACE_STORAGE_CONTAINER_URL') + + +class TerraAzureAsyncFS(AzureAsyncFS): + def __init__(self, **azure_kwargs): + assert WORKSPACE_STORAGE_CONTAINER_URL is not None + super().__init__(**azure_kwargs) + self._terra_client = TerraClient() + self._sas_token_cache: Dict[str, Tuple[AzureAsyncFSURL, int]] = {} + self._workspace_container = AzureAsyncFS.parse_url(WORKSPACE_STORAGE_CONTAINER_URL) + + @staticmethod + def enabled() -> bool: + return WORKSPACE_STORAGE_CONTAINER_ID is not None and WORKSPACE_STORAGE_CONTAINER_URL is not None + + async def get_blob_service_client(self, url: AzureAsyncFSURL) -> BlobServiceClient: + if self._in_workspace_container(url): + return await super().get_blob_service_client(await self._get_terra_sas_token_url(url)) + return await 
super().get_blob_service_client(url) + + async def _get_terra_sas_token_url(self, url: AzureAsyncFSURL) -> AzureAsyncFSURL: + if url.base in self._sas_token_cache: + sas_token_url, expiration = self._sas_token_cache[url.base] + ten_minutes_from_now = time_msecs() + 10 * 60 + if expiration > ten_minutes_from_now: + return sas_token_url + + sas_token_url, expiration = await self._create_terra_sas_token(url) + self._sas_token_cache[url.base] = (sas_token_url, expiration) + return sas_token_url + + async def _create_terra_sas_token(self, url: AzureAsyncFSURL) -> Tuple[AzureAsyncFSURL, int]: + an_hour_in_seconds = 3600 + expiration = time_msecs() + an_hour_in_seconds * 1000 + + assert WORKSPACE_STORAGE_CONTAINER_ID is not None + sas_token = await self._terra_client.get_storage_container_sas_token( + WORKSPACE_STORAGE_CONTAINER_ID, url.path, expires_after=an_hour_in_seconds + ) + + return AzureAsyncFS.parse_url(sas_token), expiration + + def _in_workspace_container(self, url: AzureAsyncFSURL) -> bool: + return url.account == self._workspace_container.account and url.container == self._workspace_container.container diff --git a/hail/python/hailtop/aiocloud/common/__init__.py b/hail/python/hailtop/aiocloud/common/__init__.py index 85c501754e1..ef2aff15337 100644 --- a/hail/python/hailtop/aiocloud/common/__init__.py +++ b/hail/python/hailtop/aiocloud/common/__init__.py @@ -1,7 +1,6 @@ from .base_client import CloudBaseClient -from .session import Session, RateLimitedSession from .credentials import AnonymousCloudCredentials - +from .session import RateLimitedSession, Session __all__ = [ 'CloudBaseClient', diff --git a/hail/python/hailtop/aiocloud/common/base_client.py b/hail/python/hailtop/aiocloud/common/base_client.py index 4ea04fd6b75..5e8a09b679b 100644 --- a/hail/python/hailtop/aiocloud/common/base_client.py +++ b/hail/python/hailtop/aiocloud/common/base_client.py @@ -1,5 +1,6 @@ from types import TracebackType from typing import Any, Optional, Type, TypeVar + from hailtop.utils import RateLimit from .session import BaseSession, RateLimitedSession @@ -8,41 +9,36 @@ class CloudBaseClient: - _session: BaseSession - def __init__(self, base_url: str, session: BaseSession, *, rate_limit: Optional[RateLimit] = None): self._base_url = base_url if rate_limit is not None: session = RateLimitedSession(session=session, rate_limit=rate_limit) self._session = session - async def get(self, path: Optional[str] = None, *, url: Optional[str] = None, **kwargs) -> Any: + async def request(self, method: str, path: Optional[str] = None, *, url: Optional[str] = None, **kwargs) -> Any: if url is None: assert path url = f'{self._base_url}{path}' - async with await self._session.get(url, **kwargs) as resp: + async with await self._session.request(method, url, **kwargs) as resp: return await resp.json() - async def post(self, path: Optional[str] = None, *, url: Optional[str] = None, **kwargs) -> Any: - if url is None: - assert path - url = f'{self._base_url}{path}' - async with await self._session.post(url, **kwargs) as resp: - return await resp.json() + async def get(self, *args, **kwargs) -> Any: + return await self.request('GET', *args, **kwargs) - async def delete(self, path: Optional[str] = None, *, url: Optional[str] = None, **kwargs) -> Any: - if url is None: - assert path - url = f'{self._base_url}{path}' - async with await self._session.delete(url, **kwargs) as resp: - return await resp.json() + async def post(self, *args, **kwargs) -> Any: + return await self.request('POST', *args, **kwargs) - async def 
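# The SAS-URL caching policy used by TerraAzureAsyncFS above, sketched with
# plain strings: reuse a signed URL until it is within ten minutes of expiring,
# then mint a fresh one. `mint` is an illustrative callable returning
# (signed_url, expiration_in_milliseconds); it is not part of Hail's API.
import time
from typing import Awaitable, Callable, Dict, Tuple


class SasUrlCache:
    def __init__(self, mint: Callable[[str], Awaitable[Tuple[str, int]]]):
        self._mint = mint
        self._cache: Dict[str, Tuple[str, int]] = {}

    async def get(self, key: str) -> str:
        now_ms = int(time.time() * 1000)
        cached = self._cache.get(key)
        if cached is not None:
            url, expiration_ms = cached
            if expiration_ms > now_ms + 10 * 60 * 1000:
                return url
        url, expiration_ms = await self._mint(key)
        self._cache[key] = (url, expiration_ms)
        return url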
put(self, path: Optional[str] = None, *, url: Optional[str] = None, **kwargs) -> Any: - if url is None: - assert path - url = f'{self._base_url}{path}' - async with await self._session.put(url, **kwargs) as resp: - return await resp.json() + async def put(self, *args, **kwargs) -> Any: + return await self.request('PUT', *args, **kwargs) + + async def patch(self, *args, **kwargs) -> Any: + return await self.request('PATCH', *args, **kwargs) + + async def delete(self, *args, **kwargs) -> Any: + return await self.request('DELETE', *args, **kwargs) + + async def head(self, *args, **kwargs) -> Any: + return await self.request('HEAD', *args, **kwargs) async def close(self) -> None: if hasattr(self, '_session'): @@ -52,8 +48,7 @@ async def close(self) -> None: async def __aenter__(self: ClientType) -> ClientType: return self - async def __aexit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: await self.close() diff --git a/hail/python/hailtop/aiocloud/common/credentials.py b/hail/python/hailtop/aiocloud/common/credentials.py index 30d7df47141..1f56cbaaa56 100644 --- a/hail/python/hailtop/aiocloud/common/credentials.py +++ b/hail/python/hailtop/aiocloud/common/credentials.py @@ -1,5 +1,5 @@ import abc -from typing import Dict, Tuple, Optional +from typing import Dict, Optional, Tuple class CloudCredentials(abc.ABC): diff --git a/hail/python/hailtop/aiocloud/common/session.py b/hail/python/hailtop/aiocloud/common/session.py index 6be8347261b..84b77d45594 100644 --- a/hail/python/hailtop/aiocloud/common/session.py +++ b/hail/python/hailtop/aiocloud/common/session.py @@ -1,14 +1,16 @@ +import abc +import logging +import time from contextlib import AsyncExitStack from types import TracebackType -from typing import Optional, Type, TypeVar, Mapping, Union -import time +from typing import Mapping, Optional, Type, TypeVar, Union + import aiohttp -import abc -import logging + from hailtop import httpx -from hailtop.utils import retry_transient_errors, RateLimit, RateLimiter -from .credentials import CloudCredentials, AnonymousCloudCredentials +from hailtop.utils import RateLimit, RateLimiter, retry_transient_errors +from .credentials import AnonymousCloudCredentials, CloudCredentials SessionType = TypeVar('SessionType', bound='BaseSession') log = logging.getLogger('hailtop.aiocloud.common.session') @@ -43,10 +45,9 @@ async def close(self) -> None: async def __aenter__(self: SessionType) -> SessionType: return self - async def __aexit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: await self.close() @@ -68,12 +69,14 @@ async def close(self) -> None: class Session(BaseSession): - def __init__(self, - *, - credentials: Union[CloudCredentials, AnonymousCloudCredentials], - params: Optional[Mapping[str, str]] = None, - http_session: Optional[httpx.ClientSession] = None, - **kwargs): + def __init__( + self, + *, + credentials: Union[CloudCredentials, AnonymousCloudCredentials], + params: Optional[Mapping[str, str]] = None, + http_session: Optional[httpx.ClientSession] = None, + **kwargs, + ): if 'raise_for_status' not in kwargs: 
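# The verb consolidation above, reduced to a toy client: one `request` method
# builds the URL and decodes JSON, and each HTTP verb becomes a one-line
# wrapper. This sketch uses aiohttp directly rather than hailtop's session.
from typing import Any, Optional

import aiohttp


class TinyClient:
    def __init__(self, base_url: str, session: aiohttp.ClientSession):
        self._base_url = base_url
        self._session = session

    async def request(self, method: str, path: Optional[str] = None, *, url: Optional[str] = None, **kwargs) -> Any:
        if url is None:
            assert path
            url = f'{self._base_url}{path}'
        async with self._session.request(method, url, **kwargs) as resp:
            resp.raise_for_status()
            return await resp.json()

    async def get(self, *args, **kwargs) -> Any:
        return await self.request('GET', *args, **kwargs)

    async def post(self, *args, **kwargs) -> Any:
        return await self.request('POST', *args, **kwargs)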
kwargs['raise_for_status'] = True self._params = params diff --git a/hail/python/hailtop/aiogoogle/__init__.py b/hail/python/hailtop/aiogoogle/__init__.py index d368712e672..464707e40c0 100644 --- a/hail/python/hailtop/aiogoogle/__init__.py +++ b/hail/python/hailtop/aiogoogle/__init__.py @@ -1,6 +1,7 @@ import warnings + from ..aiocloud.aiogoogle import * # noqa: F403 -warnings.warn("importing hailtop.aiogoogle is deprecated, please use hailtop.aiocloud.aiogoogle", - DeprecationWarning, - stacklevel=2) +warnings.warn( + "importing hailtop.aiogoogle is deprecated, please use hailtop.aiocloud.aiogoogle", DeprecationWarning, stacklevel=2 +) diff --git a/hail/python/hailtop/aiotools/__init__.py b/hail/python/hailtop/aiotools/__init__.py index c25a7dcaffc..9488c047e07 100644 --- a/hail/python/hailtop/aiotools/__init__.py +++ b/hail/python/hailtop/aiotools/__init__.py @@ -1,9 +1,21 @@ -from .fs import (FileStatus, FileListEntry, AsyncFS, Transfer, MultiPartCreate, - FileAndDirectoryError, UnexpectedEOFError, Copier, ReadableStream, - WritableStream, blocking_readable_stream_to_async, blocking_writable_stream_to_async) +from .fs import ( + AsyncFS, + Copier, + FileAndDirectoryError, + FileListEntry, + FileStatus, + IsABucketError, + MultiPartCreate, + ReadableStream, + Transfer, + UnexpectedEOFError, + WritableStream, + blocking_readable_stream_to_async, + blocking_writable_stream_to_async, +) from .local_fs import LocalAsyncFS -from .utils import FeedableAsyncIterable, WriteBuffer from .tasks import BackgroundTaskManager, TaskManagerClosedError +from .utils import FeedableAsyncIterable, WriteBuffer from .weighted_semaphore import WeightedSemaphore __all__ = [ @@ -22,6 +34,7 @@ 'FileAndDirectoryError', 'MultiPartCreate', 'UnexpectedEOFError', + 'IsABucketError', 'WeightedSemaphore', 'WriteBuffer', 'Copier', diff --git a/hail/python/hailtop/aiotools/aio_contextlib.py b/hail/python/hailtop/aiotools/aio_contextlib.py index 6bc6616139f..a28a2199440 100644 --- a/hail/python/hailtop/aiotools/aio_contextlib.py +++ b/hail/python/hailtop/aiotools/aio_contextlib.py @@ -15,6 +15,7 @@ class closing: await f.close() """ + def __init__(self, thing): self.thing = thing diff --git a/hail/python/hailtop/aiotools/copy.py b/hail/python/hailtop/aiotools/copy.py index 5d65a5a5fef..d35f96a3ec2 100644 --- a/hail/python/hailtop/aiotools/copy.py +++ b/hail/python/hailtop/aiotools/copy.py @@ -1,28 +1,19 @@ -from typing import List, Dict, AsyncContextManager, Optional, Tuple import argparse import asyncio import json import logging import sys - from concurrent.futures import ThreadPoolExecutor +from typing import AsyncContextManager, Dict, List, Optional, Tuple + from rich.progress import Progress, TaskID -from ..utils.utils import sleep_before_try +from .. import uvloopx from ..utils.rich_progress_bar import CopyToolProgressBar, make_listener -from . import Transfer, Copier +from ..utils.utils import sleep_before_try +from . 
import Copier, Transfer from .router_fs import RouterAsyncFS -try: - import uvloop - uvloop_install = uvloop.install -except ImportError as e: - if not sys.platform.startswith('win32'): - raise e - - def uvloop_install(): - pass - class GrowingSempahore(AsyncContextManager[asyncio.Semaphore]): def __init__(self, start_max: int, target_max: int, progress_and_tid: Optional[Tuple[Progress, TaskID]]): @@ -66,15 +57,16 @@ async def __aexit__(self, exc_type, exc, tb): self.task.cancel() -async def copy(*, - max_simultaneous_transfers: Optional[int] = None, - local_kwargs: Optional[dict] = None, - gcs_kwargs: Optional[dict] = None, - azure_kwargs: Optional[dict] = None, - s3_kwargs: Optional[dict] = None, - transfers: List[Transfer], - verbose: bool = False, - ) -> None: +async def copy( + *, + max_simultaneous_transfers: Optional[int] = None, + local_kwargs: Optional[dict] = None, + gcs_kwargs: Optional[dict] = None, + azure_kwargs: Optional[dict] = None, + s3_kwargs: Optional[dict] = None, + transfers: List[Transfer], + verbose: bool = False, +) -> None: with ThreadPoolExecutor() as thread_pool: if max_simultaneous_transfers is None: max_simultaneous_transfers = 75 @@ -90,19 +82,20 @@ async def copy(*, if 'max_pool_connections' not in s3_kwargs: s3_kwargs['max_pool_connections'] = max_simultaneous_transfers * 2 - async with RouterAsyncFS(local_kwargs=local_kwargs, - gcs_kwargs=gcs_kwargs, - azure_kwargs=azure_kwargs, - s3_kwargs=s3_kwargs) as fs: + async with RouterAsyncFS( + local_kwargs=local_kwargs, gcs_kwargs=gcs_kwargs, azure_kwargs=azure_kwargs, s3_kwargs=s3_kwargs + ) as fs: with CopyToolProgressBar(transient=True, disable=not verbose) as progress: initial_simultaneous_transfers = 10 - parallelism_tid = progress.add_task(description='parallelism', - completed=initial_simultaneous_transfers, - total=max_simultaneous_transfers, - visible=verbose) - async with GrowingSempahore(initial_simultaneous_transfers, - max_simultaneous_transfers, - (progress, parallelism_tid)) as sema: + parallelism_tid = progress.add_task( + description='parallelism', + completed=initial_simultaneous_transfers, + total=max_simultaneous_transfers, + visible=verbose, + ) + async with GrowingSempahore( + initial_simultaneous_transfers, max_simultaneous_transfers, (progress, parallelism_tid) + ) as sema: file_tid = progress.add_task(description='files', total=0, visible=verbose) bytes_tid = progress.add_task(description='bytes', total=0, visible=verbose) copy_report = await Copier.copy( @@ -110,7 +103,8 @@ async def copy(*, sema, transfers, files_listener=make_listener(progress, file_tid), - bytes_listener=make_listener(progress, bytes_tid)) + bytes_listener=make_listener(progress, bytes_tid), + ) if verbose: copy_report.summarize() @@ -122,15 +116,16 @@ def make_transfer(json_object: Dict[str, str]) -> Transfer: return Transfer(json_object['from'], json_object['into'], treat_dest_as=Transfer.DEST_DIR) -async def copy_from_dict(*, - max_simultaneous_transfers: Optional[int] = None, - local_kwargs: Optional[dict] = None, - gcs_kwargs: Optional[dict] = None, - azure_kwargs: Optional[dict] = None, - s3_kwargs: Optional[dict] = None, - files: List[Dict[str, str]], - verbose: bool = False, - ) -> None: +async def copy_from_dict( + *, + max_simultaneous_transfers: Optional[int] = None, + local_kwargs: Optional[dict] = None, + gcs_kwargs: Optional[dict] = None, + azure_kwargs: Optional[dict] = None, + s3_kwargs: Optional[dict] = None, + files: List[Dict[str, str]], + verbose: bool = False, +) -> None: transfers = 
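# A rough sketch of the ramp-up idea behind GrowingSempahore: start with a
# small number of permits and periodically release more until the target
# parallelism is reached. The step size and interval here are illustrative.
import asyncio


async def grow_semaphore(sema: asyncio.Semaphore, start: int, target: int, interval: float = 1.0) -> None:
    granted = start
    while granted < target:
        await asyncio.sleep(interval)
        step = min(10, target - granted)
        for _ in range(step):
            sema.release()  # releasing without a matching acquire raises the effective limit
        granted += step

# Typical use: sema = asyncio.Semaphore(10); asyncio.create_task(grow_semaphore(sema, 10, 75))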
[make_transfer(json_object) for json_object in files] await copy( max_simultaneous_transfers=max_simultaneous_transfers, @@ -145,15 +140,26 @@ async def copy_from_dict(*, async def main() -> None: parser = argparse.ArgumentParser(description='Hail copy tool') - parser.add_argument('requester_pays_project', type=str, - help='a JSON string indicating the Google project to which to charge egress costs') - parser.add_argument('files', type=str, nargs='?', - help='a JSON array of JSON objects indicating from where and to where to copy files. If empty or "-", read the array from standard input instead') - parser.add_argument('--max-simultaneous-transfers', type=int, - help='The limit on the number of simultaneous transfers. Large files are uploaded as multiple transfers. This parameter sets an upper bound on the number of open source and destination files.') - parser.add_argument('-v', '--verbose', action='store_const', - const=True, default=False, - help='show logging information') + parser.add_argument( + 'requester_pays_project', + type=str, + help='a JSON string indicating the Google project to which to charge egress costs', + ) + parser.add_argument( + 'files', + type=str, + nargs='?', + help='a JSON array of JSON objects indicating from where and to where to copy files. If empty or "-", read the array from standard input instead', + ) + parser.add_argument( + '--max-simultaneous-transfers', + type=int, + help='The limit on the number of simultaneous transfers. Large files are uploaded as multiple transfers. This parameter sets an upper bound on the number of open source and destination files.', + ) + parser.add_argument( + '-v', '--verbose', action='store_const', const=True, default=False, help='show logging information' + ) + parser.add_argument('--timeout', type=str, default=None, help='show logging information') args = parser.parse_args() if args.verbose: @@ -164,16 +170,31 @@ async def main() -> None: if args.files is None or args.files == '-': args.files = sys.stdin.read() files = json.loads(args.files) - gcs_kwargs = {'gcs_requester_pays_configuration': requester_pays_project} + + timeout = args.timeout + if timeout: + timeout = float(timeout) + gcs_kwargs = { + 'gcs_requester_pays_configuration': requester_pays_project, + 'timeout': timeout, + } + azure_kwargs = { + 'timeout': timeout, + } + s3_kwargs = { + 'timeout': timeout, + } await copy_from_dict( max_simultaneous_transfers=args.max_simultaneous_transfers, gcs_kwargs=gcs_kwargs, + azure_kwargs=azure_kwargs, + s3_kwargs=s3_kwargs, files=files, - verbose=args.verbose + verbose=args.verbose, ) if __name__ == '__main__': - uvloop_install() + uvloopx.install() asyncio.run(main()) diff --git a/hail/python/hailtop/aiotools/delete.py b/hail/python/hailtop/aiotools/delete.py index 84e22634904..1d9f61a10a7 100644 --- a/hail/python/hailtop/aiotools/delete.py +++ b/hail/python/hailtop/aiotools/delete.py @@ -1,13 +1,13 @@ -from typing import Iterator -import sys +import argparse import asyncio import logging -import argparse +import sys from concurrent.futures import ThreadPoolExecutor +from typing import Iterator -from .router_fs import RouterAsyncFS -from ..utils.rich_progress_bar import SimpleCopyToolProgressBar from ..utils import grouped +from ..utils.rich_progress_bar import SimpleCopyToolProgressBar +from .router_fs import RouterAsyncFS async def delete(paths: Iterator[str]) -> None: @@ -16,32 +16,38 @@ async def delete(paths: Iterator[str]) -> None: async with RouterAsyncFS(local_kwargs=kwargs, s3_kwargs=kwargs) as fs: sema = 
asyncio.Semaphore(50) async with sema: - with SimpleCopyToolProgressBar( - description='files', - transient=True, - total=0) as file_pbar: + with SimpleCopyToolProgressBar(description='files', transient=True, total=0) as file_pbar: listener = file_pbar.make_listener() + + async def remove(path): + try: + await fs.remove(path) + except FileNotFoundError: + await fs.rmtree(sema, path, listener=listener) + file_pbar.update(1) # only advance if file or directory removal was successful, not on error + for grouped_paths in grouped(5_000, paths): - await asyncio.gather(*[ - fs.rmtree(sema, path, listener=listener) - for path in grouped_paths - ]) + file_pbar.update(0, total=file_pbar.total() + len(grouped_paths)) + await asyncio.gather(*[remove(path) for path in grouped_paths]) async def main() -> None: - parser = argparse.ArgumentParser(description='Delete the given files and directories.', - epilog='''Examples: + parser = argparse.ArgumentParser( + description='Delete the given files and directories.', + epilog="""Examples: python3 -m hailtop.aiotools.delete dir1/ file1 dir2/file1 dir2/file3 dir3 python3 -m hailtop.aiotools.delete gs://bucket1/dir1 gs://bucket1/file1 gs://bucket2/abc/123 -''', - formatter_class=argparse.RawDescriptionHelpFormatter) - parser.add_argument('files', type=str, nargs='*', - help='the paths (files or directories) to delete; if unspecified, read from stdin') - parser.add_argument('-v', '--verbose', action='store_const', - const=True, default=False, - help='show logging information') +""", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + 'files', type=str, nargs='*', help='the paths (files or directories) to delete; if unspecified, read from stdin' + ) + parser.add_argument( + '-v', '--verbose', action='store_const', const=True, default=False, help='show logging information' + ) args = parser.parse_args() if args.verbose: logging.basicConfig() @@ -54,5 +60,6 @@ async def main() -> None: await delete(files) + if __name__ == '__main__': asyncio.run(main()) diff --git a/hail/python/hailtop/aiotools/diff.py b/hail/python/hailtop/aiotools/diff.py index cc2aec590d5..40822671b05 100644 --- a/hail/python/hailtop/aiotools/diff.py +++ b/hail/python/hailtop/aiotools/diff.py @@ -1,37 +1,29 @@ -from typing import Optional, Tuple, TypeVar, List import argparse import asyncio -import orjson import logging import sys - from concurrent.futures import ThreadPoolExecutor +from typing import List, Optional, Tuple, TypeVar + +import orjson +from .. 
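# The per-path fallback added to the delete tool above, as a standalone sketch:
# try to remove the path as a single object and, if no such file exists, treat
# it as a directory tree instead. `fs` follows the AsyncFS interface shown in
# this diff (remove(path), rmtree(sema, path)).
import asyncio


async def remove_path(fs, sema: asyncio.Semaphore, path: str) -> None:
    try:
        await fs.remove(path)
    except FileNotFoundError:
        await fs.rmtree(sema, path)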
import uvloopx from ..utils.rich_progress_bar import SimpleCopyToolProgressBar, SimpleCopyToolProgressBarTask -from .router_fs import RouterAsyncFS from .fs import AsyncFS, FileStatus +from .router_fs import RouterAsyncFS -try: - import uvloop - uvloop_install = uvloop.install -except ImportError as e: - if not sys.platform.startswith('win32'): - raise e - - def uvloop_install(): - pass - - -async def diff(*, - max_simultaneous: Optional[int] = None, - local_kwargs: Optional[dict] = None, - gcs_kwargs: Optional[dict] = None, - azure_kwargs: Optional[dict] = None, - s3_kwargs: Optional[dict] = None, - source: str, - target: str, - verbose: bool = False, - ) -> List[dict]: + +async def diff( + *, + max_simultaneous: Optional[int] = None, + local_kwargs: Optional[dict] = None, + gcs_kwargs: Optional[dict] = None, + azure_kwargs: Optional[dict] = None, + s3_kwargs: Optional[dict] = None, + source: str, + target: str, + verbose: bool = False, +) -> List[dict]: with ThreadPoolExecutor() as thread_pool: if max_simultaneous is None: max_simultaneous = 500 @@ -47,15 +39,10 @@ async def diff(*, if 'max_pool_connections' not in s3_kwargs: s3_kwargs['max_pool_connections'] = max_simultaneous * 2 - async with RouterAsyncFS(local_kwargs=local_kwargs, - gcs_kwargs=gcs_kwargs, - azure_kwargs=azure_kwargs, - s3_kwargs=s3_kwargs) as fs: - with SimpleCopyToolProgressBar( - description='files', - transient=True, - total=0, - disable=not verbose) as pbar: + async with RouterAsyncFS( + local_kwargs=local_kwargs, gcs_kwargs=gcs_kwargs, azure_kwargs=azure_kwargs, s3_kwargs=s3_kwargs + ) as fs: + with SimpleCopyToolProgressBar(description='files', transient=True, total=0, disable=not verbose) as pbar: return await do_diff(source, target, fs, max_simultaneous, pbar) @@ -68,7 +55,9 @@ class DiffException(ValueError): pass -async def do_diff(top_source: str, top_target: str, fs: AsyncFS, max_simultaneous: int, pbar: SimpleCopyToolProgressBarTask) -> List[dict]: +async def do_diff( + top_source: str, top_target: str, fs: AsyncFS, max_simultaneous: int, pbar: SimpleCopyToolProgressBarTask +) -> List[dict]: if await fs.isfile(top_source): result = await diff_one(top_source, top_target, fs) if result is None: @@ -123,6 +112,7 @@ async def worker(): return different + T = TypeVar('T') @@ -149,18 +139,18 @@ async def diff_one(source_url: str, target_url: str, fs: AsyncFS) -> Optional[di async def main() -> None: - parser = argparse.ArgumentParser(description='Hail size diff tool. Recursively finds files which differ in size or are entirely missing.') - parser.add_argument('--requester-pays-project', type=str, nargs='?', - help='The Google project to which to charge egress costs.') - parser.add_argument('source', type=str, - help='The source of truth file or directory.') - parser.add_argument('target', type=str, - help='The target file or directory to which to compare.') - parser.add_argument('--max-simultaneous', type=int, - help='The limit on the number of simultaneous diff operations.') - parser.add_argument('-v', '--verbose', action='store_const', - const=True, default=False, - help='show logging information') + parser = argparse.ArgumentParser( + description='Hail size diff tool. Recursively finds files which differ in size or are entirely missing.' + ) + parser.add_argument( + '--requester-pays-project', type=str, nargs='?', help='The Google project to which to charge egress costs.' 
+ ) + parser.add_argument('source', type=str, help='The source of truth file or directory.') + parser.add_argument('target', type=str, help='The target file or directory to which to compare.') + parser.add_argument('--max-simultaneous', type=int, help='The limit on the number of simultaneous diff operations.') + parser.add_argument( + '-v', '--verbose', action='store_const', const=True, default=False, help='show logging information' + ) args = parser.parse_args() if args.verbose: @@ -175,7 +165,7 @@ async def main() -> None: gcs_kwargs=gcs_kwargs, source=args.source, target=args.target, - verbose=args.verbose + verbose=args.verbose, ) except DiffException as exc: print(exc.args[0], file=sys.stderr) @@ -185,5 +175,5 @@ async def main() -> None: if __name__ == '__main__': - uvloop_install() + uvloopx.install() asyncio.run(main()) diff --git a/hail/python/hailtop/aiotools/fs/__init__.py b/hail/python/hailtop/aiotools/fs/__init__.py index a1922f76e49..8401750741e 100644 --- a/hail/python/hailtop/aiotools/fs/__init__.py +++ b/hail/python/hailtop/aiotools/fs/__init__.py @@ -1,8 +1,13 @@ -from .fs import AsyncFS, AsyncFSURL, AsyncFSFactory, MultiPartCreate, FileListEntry, FileStatus from .copier import Copier, CopyReport, SourceCopier, SourceReport, Transfer, TransferReport -from .exceptions import UnexpectedEOFError, FileAndDirectoryError -from .stream import (ReadableStream, EmptyReadableStream, WritableStream, - blocking_readable_stream_to_async, blocking_writable_stream_to_async) +from .exceptions import FileAndDirectoryError, IsABucketError, UnexpectedEOFError +from .fs import AsyncFS, AsyncFSFactory, AsyncFSURL, FileListEntry, FileStatus, MultiPartCreate +from .stream import ( + EmptyReadableStream, + ReadableStream, + WritableStream, + blocking_readable_stream_to_async, + blocking_writable_stream_to_async, +) __all__ = [ 'AsyncFS', @@ -24,4 +29,5 @@ 'FileStatus', 'FileAndDirectoryError', 'UnexpectedEOFError', + 'IsABucketError', ] diff --git a/hail/python/hailtop/aiotools/fs/copier.py b/hail/python/hailtop/aiotools/fs/copier.py index 913f00da4a6..187865acdf6 100644 --- a/hail/python/hailtop/aiotools/fs/copier.py +++ b/hail/python/hailtop/aiotools/fs/copier.py @@ -1,16 +1,22 @@ -from typing import Any, AsyncIterator, Awaitable, Optional, List, Union, Dict, Callable, Tuple -import os -import os.path import asyncio import functools -import humanize +import os +import os.path +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple, Union +import humanize -from ...utils import (retry_transient_errors, url_basename, url_join, bounded_gather2, time_msecs, - humanize_timedelta_msecs) +from ...utils import ( + bounded_gather2, + humanize_timedelta_msecs, + retry_transient_errors, + time_msecs, + url_basename, + url_join, +) from ..weighted_semaphore import WeightedSemaphore from .exceptions import FileAndDirectoryError, UnexpectedEOFError -from .fs import MultiPartCreate, FileStatus, AsyncFS, FileListEntry +from .fs import AsyncFS, FileListEntry, FileStatus, MultiPartCreate class Transfer: @@ -24,8 +30,7 @@ def __init__(self, src: Union[str, List[str]], dest: str, *, treat_dest_as: str if treat_dest_as == Transfer.DEST_IS_TARGET and isinstance(src, list): raise NotADirectoryError(dest) - if (treat_dest_as == Transfer.INFER_DEST - and dest.endswith('/')): + if treat_dest_as == Transfer.INFER_DEST and dest.endswith('/'): treat_dest_as = Transfer.DEST_DIR self.src = src @@ -34,11 +39,13 @@ def __init__(self, src: Union[str, List[str]], dest: str, *, treat_dest_as: str 
class SourceReport: - def __init__(self, - source, - *, - files_listener: Optional[Callable[[int], None]] = None, - bytes_listener: Optional[Callable[[int], None]] = None): + def __init__( + self, + source, + *, + files_listener: Optional[Callable[[int], None]] = None, + bytes_listener: Optional[Callable[[int], None]] = None, + ): self._source = source self._files_listener = files_listener self._bytes_listener = bytes_listener @@ -78,29 +85,28 @@ def set_exception(self, exception: Exception): def set_file_error(self, srcfile: str, destfile: str, exception: Exception): if self._first_file_error is None: - self._first_file_error = { - 'srcfile': srcfile, - 'destfile': destfile, - 'exception': exception - } + self._first_file_error = {'srcfile': srcfile, 'destfile': destfile, 'exception': exception} class TransferReport: _source_report: Union[SourceReport, List[SourceReport]] - def __init__(self, - transfer: Transfer, - *, - files_listener: Optional[Callable[[int], None]] = None, - bytes_listener: Optional[Callable[[int], None]] = None): + def __init__( + self, + transfer: Transfer, + *, + files_listener: Optional[Callable[[int], None]] = None, + bytes_listener: Optional[Callable[[int], None]] = None, + ): self._transfer = transfer if isinstance(transfer.src, str): self._source_report = SourceReport( - transfer.src, files_listener=files_listener, bytes_listener=bytes_listener) + transfer.src, files_listener=files_listener, bytes_listener=bytes_listener + ) else: self._source_report = [ - SourceReport(s, files_listener=files_listener, bytes_listener=bytes_listener) - for s in transfer.src] + SourceReport(s, files_listener=files_listener, bytes_listener=bytes_listener) for s in transfer.src + ] self._exception: Optional[Exception] = None def set_exception(self, exception: Exception): @@ -109,21 +115,24 @@ def set_exception(self, exception: Exception): class CopyReport: - def __init__(self, - transfer: Union[Transfer, List[Transfer]], - *, - files_listener: Optional[Callable[[int], None]] = None, - bytes_listener: Optional[Callable[[int], None]] = None): + def __init__( + self, + transfer: Union[Transfer, List[Transfer]], + *, + files_listener: Optional[Callable[[int], None]] = None, + bytes_listener: Optional[Callable[[int], None]] = None, + ): self._start_time = time_msecs() self._end_time: Optional[int] = None self._duration: Optional[int] = None if isinstance(transfer, Transfer): self._transfer_report: Union[TransferReport, List[TransferReport]] = TransferReport( - transfer, files_listener=files_listener, bytes_listener=bytes_listener) + transfer, files_listener=files_listener, bytes_listener=bytes_listener + ) else: self._transfer_report = [ - TransferReport(t, files_listener=files_listener, bytes_listener=bytes_listener) - for t in transfer] + TransferReport(t, files_listener=files_listener, bytes_listener=bytes_listener) for t in transfer + ] self._exception: Optional[Exception] = None def set_exception(self, exception: Exception): @@ -174,12 +183,14 @@ def add_source_reports(transfer_report): class SourceCopier: - '''This class implements copy from a single source. In general, a + """This class implements copy from a single source. In general, a transfer will have multiple sources, and a SourceCopier will be created for each source. 
- ''' + """ - def __init__(self, router_fs: AsyncFS, xfer_sema: WeightedSemaphore, src: str, dest: str, treat_dest_as: str, dest_type_task): + def __init__( + self, router_fs: AsyncFS, xfer_sema: WeightedSemaphore, src: str, dest: str, treat_dest_as: str, dest_type_task + ): self.router_fs = router_fs self.xfer_sema = xfer_sema self.src = src @@ -218,18 +229,24 @@ async def _copy_file(self, source_report: SourceReport, srcfile: str, size: int, assert written == len(b) source_report.finish_bytes(written) - async def _copy_part(self, - source_report: SourceReport, - part_size: int, - srcfile: str, - part_number: int, - this_part_size: int, - part_creator: MultiPartCreate, - return_exceptions: bool) -> None: + async def _copy_part( + self, + source_report: SourceReport, + part_size: int, + srcfile: str, + part_number: int, + this_part_size: int, + part_creator: MultiPartCreate, + return_exceptions: bool, + ) -> None: try: async with self.xfer_sema.acquire_manager(min(Copier.BUFFER_SIZE, this_part_size)): - async with await self.router_fs.open_from(srcfile, part_number * part_size, length=this_part_size) as srcf: - async with await part_creator.create_part(part_number, part_number * part_size, size_hint=this_part_size) as destf: + async with await self.router_fs.open_from( + srcfile, part_number * part_size, length=this_part_size + ) as srcf: + async with await part_creator.create_part( + part_number, part_number * part_size, size_hint=this_part_size + ) as destf: n = this_part_size while n > 0: b = await srcf.read(min(Copier.BUFFER_SIZE, n)) @@ -246,13 +263,14 @@ async def _copy_part(self, raise async def _copy_file_multi_part_main( - self, - sema: asyncio.Semaphore, - source_report: SourceReport, - srcfile: str, - srcstat: FileStatus, - destfile: str, - return_exceptions: bool): + self, + sema: asyncio.Semaphore, + source_report: SourceReport, + srcfile: str, + srcstat: FileStatus, + destfile: str, + return_exceptions: bool, + ): size = await srcstat.size() part_size = self.router_fs.copy_part_size(destfile) @@ -272,25 +290,30 @@ async def _copy_file_multi_part_main( part_creator = await self.router_fs.multi_part_create(sema, destfile, n_parts) async with part_creator: + async def f(i): this_part_size = rem if i == n_parts - 1 and rem else part_size await retry_transient_errors( self._copy_part, - source_report, part_size, srcfile, i, this_part_size, part_creator, return_exceptions) + source_report, + part_size, + srcfile, + i, + this_part_size, + part_creator, + return_exceptions, + ) - await bounded_gather2(sema, *[ - functools.partial(f, i) - for i in range(n_parts) - ], cancel_on_error=True) + await bounded_gather2(sema, *[functools.partial(f, i) for i in range(n_parts)], cancel_on_error=True) async def _copy_file_multi_part( - self, - sema: asyncio.Semaphore, - source_report: SourceReport, - srcfile: str, - srcstat: FileStatus, - destfile: str, - return_exceptions: bool + self, + sema: asyncio.Semaphore, + source_report: SourceReport, + srcfile: str, + srcstat: FileStatus, + destfile: str, + return_exceptions: bool, ) -> None: success = False try: @@ -310,23 +333,24 @@ async def _full_dest(self): else: dest_type = None - if (self.treat_dest_as == Transfer.DEST_DIR - or (self.treat_dest_as == Transfer.INFER_DEST - and dest_type == AsyncFS.DIR)): + if self.treat_dest_as == Transfer.DEST_DIR or ( + self.treat_dest_as == Transfer.INFER_DEST and dest_type == AsyncFS.DIR + ): # We know dest is a dir, but we're copying to # dest/basename(src), and we don't know its type. 
return url_join(self.dest, url_basename(self.src.rstrip('/'))), None - if (self.treat_dest_as == Transfer.DEST_IS_TARGET - and self.dest.endswith('/')): + if self.treat_dest_as == Transfer.DEST_IS_TARGET and self.dest.endswith('/'): dest_type = AsyncFS.DIR return self.dest, dest_type - async def copy_as_file(self, - sema: asyncio.Semaphore, # pylint: disable=unused-argument - source_report: SourceReport, - return_exceptions: bool): + async def copy_as_file( + self, + sema: asyncio.Semaphore, # pylint: disable=unused-argument + source_report: SourceReport, + return_exceptions: bool, + ): try: src = self.src if src.endswith('/'): @@ -392,10 +416,17 @@ async def copy_source(srcentry: FileListEntry) -> None: if srcfile.endswith('/'): return - relsrcfile = srcfile[len(src):] + relsrcfile = srcfile[len(src) :] assert not relsrcfile.startswith('/') - await self._copy_file_multi_part(sema, source_report, srcfile, await srcentry.status(), url_join(full_dest, relsrcfile), return_exceptions) + await self._copy_file_multi_part( + sema, + source_report, + srcfile, + await srcentry.status(), + url_join(full_dest, relsrcfile), + return_exceptions, + ) async def create_copies() -> Tuple[List[Callable[[], Awaitable[None]]], int]: nonlocal srcentries @@ -424,17 +455,27 @@ async def copy(self, sema: asyncio.Semaphore, source_report: SourceReport, retur # gather with return_exceptions=True to make copy # deterministic with respect to exceptions results = await asyncio.gather( - self.copy_as_file(sema, source_report, return_exceptions), self.copy_as_dir(sema, source_report, return_exceptions), - return_exceptions=True) + self.copy_as_file(sema, source_report, return_exceptions), + self.copy_as_dir(sema, source_report, return_exceptions), + return_exceptions=True, + ) assert self.pending == 0 for result in results: - if isinstance(result, Exception): + if isinstance(result, BaseException): raise result assert (self.src_is_file is None) == self.src.endswith('/') - assert self.src_is_dir is not None + assert self.src_is_dir is not None, repr(( + results, + self.src_is_file, + self.src_is_dir, + self.src, + self.dest, + self.barrier, + self.pending, + )) if (self.src_is_file is False or self.src.endswith('/')) and not self.src_is_dir: raise FileNotFoundError(self.src) @@ -446,20 +487,22 @@ async def copy(self, sema: asyncio.Semaphore, source_report: SourceReport, retur class Copier: - ''' + """ This class implements copy for a list of transfers. - ''' + """ BUFFER_SIZE = 8 * 1024 * 1024 @staticmethod - async def copy(fs: AsyncFS, - sema: asyncio.Semaphore, - transfer: Union[Transfer, List[Transfer]], - return_exceptions: bool = False, - *, - files_listener: Optional[Callable[[int], None]] = None, - bytes_listener: Optional[Callable[[int], None]] = None) -> CopyReport: + async def copy( + fs: AsyncFS, + sema: asyncio.Semaphore, + transfer: Union[Transfer, List[Transfer]], + return_exceptions: bool = False, + *, + files_listener: Optional[Callable[[int], None]] = None, + bytes_listener: Optional[Callable[[int], None]] = None, + ) -> CopyReport: copier = Copier(fs) copy_report = CopyReport(transfer, files_listener=files_listener, bytes_listener=bytes_listener) await copier._copy(sema, copy_report, transfer, return_exceptions) @@ -474,17 +517,15 @@ def __init__(self, router_fs): self.xfer_sema = WeightedSemaphore(100 * Copier.BUFFER_SIZE) async def _dest_type(self, transfer: Transfer): - '''Return the (real or assumed) type of `dest`. + """Return the (real or assumed) type of `dest`. 
If the transfer assumes the type of `dest`, return that rather than the real type. A return value of `None` mean `dest` does not exist. - ''' + """ assert transfer.treat_dest_as != Transfer.DEST_IS_TARGET - if (transfer.treat_dest_as == Transfer.DEST_DIR - or isinstance(transfer.src, list) - or transfer.dest.endswith('/')): + if transfer.treat_dest_as == Transfer.DEST_DIR or isinstance(transfer.src, list) or transfer.dest.endswith('/'): return AsyncFS.DIR assert not transfer.dest.endswith('/') @@ -495,11 +536,23 @@ async def _dest_type(self, transfer: Transfer): return dest_type - async def copy_source(self, sema: asyncio.Semaphore, transfer: Transfer, source_report: SourceReport, src: str, dest_type_task, return_exceptions: bool): - src_copier = SourceCopier(self.router_fs, self.xfer_sema, src, transfer.dest, transfer.treat_dest_as, dest_type_task) + async def copy_source( + self, + sema: asyncio.Semaphore, + transfer: Transfer, + source_report: SourceReport, + src: str, + dest_type_task, + return_exceptions: bool, + ): + src_copier = SourceCopier( + self.router_fs, self.xfer_sema, src, transfer.dest, transfer.treat_dest_as, dest_type_task + ) await src_copier.copy(sema, source_report, return_exceptions) - async def _copy_one_transfer(self, sema: asyncio.Semaphore, transfer_report: TransferReport, transfer: Transfer, return_exceptions: bool): + async def _copy_one_transfer( + self, sema: asyncio.Semaphore, transfer_report: TransferReport, transfer: Transfer, return_exceptions: bool + ): try: if transfer.treat_dest_as == Transfer.INFER_DEST: dest_type_task: Optional[asyncio.Task] = asyncio.create_task(self._dest_type(transfer)) @@ -517,10 +570,14 @@ async def _copy_one_transfer(self, sema: asyncio.Semaphore, transfer_report: Tra if transfer.treat_dest_as == Transfer.DEST_IS_TARGET: raise NotADirectoryError(transfer.dest) - await bounded_gather2(sema, *[ - functools.partial(self.copy_source, sema, transfer, r, s, dest_type_task, return_exceptions) - for r, s in zip(src_report, src) - ], cancel_on_error=True) + await bounded_gather2( + sema, + *[ + functools.partial(self.copy_source, sema, transfer, r, s, dest_type_task, return_exceptions) + for r, s in zip(src_report, src) + ], + cancel_on_error=True, + ) # raise potential exception if dest_type_task: @@ -534,11 +591,13 @@ async def _copy_one_transfer(self, sema: asyncio.Semaphore, transfer_report: Tra else: raise e - async def _copy(self, - sema: asyncio.Semaphore, - copy_report: CopyReport, - transfer: Union[Transfer, List[Transfer]], - return_exceptions: bool): + async def _copy( + self, + sema: asyncio.Semaphore, + copy_report: CopyReport, + transfer: Union[Transfer, List[Transfer]], + return_exceptions: bool, + ): transfer_report = copy_report._transfer_report try: if isinstance(transfer, Transfer): @@ -547,10 +606,15 @@ async def _copy(self, return assert isinstance(transfer_report, list) - await bounded_gather2(sema, *[ - functools.partial(self._copy_one_transfer, sema, r, t, return_exceptions) - for r, t in zip(transfer_report, transfer) - ], return_exceptions=return_exceptions, cancel_on_error=True) + await bounded_gather2( + sema, + *[ + functools.partial(self._copy_one_transfer, sema, r, t, return_exceptions) + for r, t in zip(transfer_report, transfer) + ], + return_exceptions=return_exceptions, + cancel_on_error=True, + ) except Exception as e: if return_exceptions: copy_report.set_exception(e) diff --git a/hail/python/hailtop/aiotools/fs/exceptions.py b/hail/python/hailtop/aiotools/fs/exceptions.py index 
f63907833da..ed4a24b912d 100644 --- a/hail/python/hailtop/aiotools/fs/exceptions.py +++ b/hail/python/hailtop/aiotools/fs/exceptions.py @@ -1,7 +1,10 @@ - class UnexpectedEOFError(Exception): pass class FileAndDirectoryError(Exception): pass + + +class IsABucketError(FileNotFoundError): + pass diff --git a/hail/python/hailtop/aiotools/fs/fs.py b/hail/python/hailtop/aiotools/fs/fs.py index a3472f5ee9c..5b3830023c2 100644 --- a/hail/python/hailtop/aiotools/fs/fs.py +++ b/hail/python/hailtop/aiotools/fs/fs.py @@ -1,20 +1,37 @@ -from typing import (Any, AsyncContextManager, Optional, Type, Set, AsyncIterator, Callable, TypeVar, - Generic, List, Awaitable, Union, Tuple) -from typing_extensions import ParamSpec -from types import TracebackType import abc import asyncio import datetime -from hailtop.utils import retry_transient_errors, OnlineBoundedGather2 -from .stream import EmptyReadableStream, ReadableStream, WritableStream -from .exceptions import FileAndDirectoryError +from types import TracebackType +from typing import ( + Any, + AsyncContextManager, + AsyncIterator, + Awaitable, + Callable, + Generic, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +from typing_extensions import ParamSpec, Self + +from hailtop.utils import OnlineBoundedGather2, retry_transient_errors +from .exceptions import FileAndDirectoryError +from .stream import EmptyReadableStream, ReadableStream, WritableStream T = TypeVar("T") P = ParamSpec("P") -async def with_exception(f: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs) -> Union[Tuple[T, None], Tuple[None, Exception]]: +async def with_exception( + f: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwargs +) -> Union[Tuple[T, None], Tuple[None, Exception]]: try: return (await f(*args, **kwargs)), None except Exception as e: @@ -22,27 +39,65 @@ async def with_exception(f: Callable[P, Awaitable[T]], *args: P.args, **kwargs: class FileStatus(abc.ABC): + @abc.abstractmethod + def basename(self) -> str: + """The basename of the object. + + Examples + -------- + + The basename of all of these objects is "file": + + - s3://bucket/folder/file + - gs://bucket/folder/file + - https://account.blob.core.windows.net/container/folder/file + - https://account.blob.core.windows.net/container/folder/file?sv=2023-01-01&sr=bv&sig=abc123&sp=rcw + - /folder/file + """ + + @abc.abstractmethod + def url(self) -> str: + """The URL of the object without any query parameters. + + Examples + -------- + + - s3://bucket/folder/file + - gs://bucket/folder/file + - https://account.blob.core.windows.net/container/folder/file + - /folder/file + + Note that the following URL + + https://account.blob.core.windows.net/container/folder/file?sv=2023-01-01&sr=bv&sig=abc123&sp=rcw + + becomes + + https://account.blob.core.windows.net/container/folder/file + + """ + @abc.abstractmethod async def size(self) -> int: pass @abc.abstractmethod def time_created(self) -> datetime.datetime: - '''The time the object was created in seconds since the epcoh, UTC. + """The time the object was created in seconds since the epcoh, UTC. Some filesystems do not support creation time. In that case, an error is raised. - ''' + """ @abc.abstractmethod def time_modified(self) -> datetime.datetime: - '''The time the object was last modified in seconds since the epoch, UTC. + """The time the object was last modified in seconds since the epoch, UTC. The meaning of modification time is cloud-defined. In some clouds, it is the creation time. 
In some clouds, it is the more recent of the creation time or the time of the most recent metadata modification. - ''' + """ @abc.abstractmethod async def __getitem__(self, key: str) -> Any: @@ -51,17 +106,62 @@ async def __getitem__(self, key: str) -> Any: class FileListEntry(abc.ABC): @abc.abstractmethod - def name(self) -> str: - pass + def basename(self) -> str: + """The basename of the object. + + Examples + -------- + + The basename of all of these objects is "file": + + - s3://bucket/folder/file + - gs://bucket/folder/file + - https://account.blob.core.windows.net/container/folder/file + - https://account.blob.core.windows.net/container/folder/file?sv=2023-01-01&sr=bv&sig=abc123&sp=rcw + - /folder/file + """ @abc.abstractmethod async def url(self) -> str: - pass + """The URL of the object without any query parameters. + + Examples + -------- + + - s3://bucket/folder/file + - gs://bucket/folder/file + - https://account.blob.core.windows.net/container/folder/file + - /folder/file + + Note that the following URL + + https://account.blob.core.windows.net/container/folder/file?sv=2023-01-01&sr=bv&sig=abc123&sp=rcw + + becomes + + https://account.blob.core.windows.net/container/folder/file + + """ async def url_maybe_trailing_slash(self) -> str: return await self.url() async def url_full(self) -> str: + """The URL of the object with any query parameters. + + Examples + -------- + + The only interesting case is for signed URLs in Azure. These are called shared signature tokens or SAS tokens. + For example, the following URL + + https://account.blob.core.windows.net/container/folder/file?sv=2023-01-01&sr=bv&sig=abc123&sp=rcw + + is a signed version of this URL + + https://account.blob.core.windows.net/container/folder/file + + """ return await self.url() @abc.abstractmethod @@ -79,18 +179,19 @@ async def status(self) -> FileStatus: class MultiPartCreate(abc.ABC): @abc.abstractmethod - async def create_part(self, number: int, start: int, size_hint: Optional[int] = None) -> AsyncContextManager[WritableStream]: + async def create_part( + self, number: int, start: int, size_hint: Optional[int] = None + ) -> AsyncContextManager[WritableStream]: pass @abc.abstractmethod - async def __aenter__(self) -> 'MultiPartCreate': + async def __aenter__(self) -> "MultiPartCreate": pass @abc.abstractmethod - async def __aexit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: pass @@ -116,12 +217,30 @@ def scheme(self) -> str: pass @abc.abstractmethod - def with_path(self, path) -> 'AsyncFSURL': + def with_path(self, path) -> "AsyncFSURL": + pass + + @abc.abstractmethod + def with_root_path(self) -> "AsyncFSURL": pass - def with_new_path_component(self, new_path_component) -> 'AsyncFSURL': - prefix = self.path if self.path.endswith('/') else self.path + '/' - suffix = new_path_component[1:] if new_path_component.startswith('/') else new_path_component + def with_new_path_component(self, new_path_component: str) -> "AsyncFSURL": + if new_path_component == '': + raise ValueError('new path component must be non-empty') + return self.with_new_path_components(new_path_component) + + def with_new_path_components(self, *parts: str) -> "AsyncFSURL": + if len(parts) == 0: + return self + + prefix = self.path + if not prefix.endswith("/") and not prefix == '': + prefix += "/" + + suffix = 
'/'.join(parts) + if suffix[0] == '/': + suffix = suffix[1:] + return self.with_path(prefix + suffix) @abc.abstractmethod @@ -130,21 +249,28 @@ def __str__(self) -> str: class AsyncFS(abc.ABC): - FILE = 'file' - DIR = 'dir' + FILE = "file" + DIR = "dir" - @property + @staticmethod @abc.abstractmethod - def schemes(self) -> Set[str]: + def schemes() -> Set[str]: pass + @staticmethod + def copy_part_size(url: str) -> int: # pylint: disable=unused-argument + """Part size when copying using multi-part uploads. The part size of + the destination filesystem is used.""" + return 128 * 1024 * 1024 + @staticmethod @abc.abstractmethod def valid_url(url: str) -> bool: pass + @staticmethod @abc.abstractmethod - def parse_url(self, url: str) -> AsyncFSURL: + def parse_url(url: str, *, error_if_bucket: bool = False) -> AsyncFSURL: pass @abc.abstractmethod @@ -154,12 +280,12 @@ async def open(self, url: str) -> ReadableStream: async def open_from(self, url: str, start: int, *, length: Optional[int] = None) -> ReadableStream: if length == 0: fs_url = self.parse_url(url) - if fs_url.path.endswith('/'): - file_url = str(fs_url.with_path(fs_url.path.rstrip('/'))) + if fs_url.path.endswith("/"): + file_url = str(fs_url.with_path(fs_url.path.rstrip("/"))) dir_url = str(fs_url) else: file_url = str(fs_url) - dir_url = str(fs_url.with_path(fs_url.path + '/')) + dir_url = str(fs_url.with_path(fs_url.path + "/")) isfile, isdir = await asyncio.gather(self.isfile(file_url), self.isdir(dir_url)) if isfile: if isdir: @@ -179,11 +305,7 @@ async def create(self, url: str, *, retry_writes: bool = True) -> AsyncContextMa pass @abc.abstractmethod - async def multi_part_create( - self, - sema: asyncio.Semaphore, - url: str, - num_parts: int) -> MultiPartCreate: + async def multi_part_create(self, sema: asyncio.Semaphore, url: str, num_parts: int) -> MultiPartCreate: pass @abc.abstractmethod @@ -199,10 +321,9 @@ async def statfile(self, url: str) -> FileStatus: pass @abc.abstractmethod - async def listfiles(self, - url: str, - recursive: bool = False, - exclude_trailing_slash_files: bool = True) -> AsyncIterator[FileListEntry]: + async def listfiles( + self, url: str, recursive: bool = False, exclude_trailing_slash_files: bool = True + ) -> AsyncIterator[FileListEntry]: pass @abc.abstractmethod @@ -210,10 +331,11 @@ async def staturl(self, url: str) -> str: pass async def _staturl_parallel_isfile_isdir(self, url: str) -> str: - assert not url.endswith('/') + assert not url.endswith("/") [(is_file, isfile_exc), (is_dir, isdir_exc)] = await asyncio.gather( - with_exception(self.isfile, url), with_exception(self.isdir, url + '/')) + with_exception(self.isfile, url), with_exception(self.isdir, url + "/") + ) # raise exception deterministically if isfile_exc: raise isfile_exc @@ -248,10 +370,9 @@ async def _remove_doesnt_exist_ok(self, url): except FileNotFoundError: pass - async def rmtree(self, - sema: Optional[asyncio.Semaphore], - url: str, - listener: Optional[Callable[[int], None]] = None) -> None: + async def rmtree( + self, sema: Optional[asyncio.Semaphore], url: str, listener: Optional[Callable[[int], None]] = None + ) -> None: if listener is None: listener = lambda _: None if sema is None: @@ -307,22 +428,16 @@ async def exists(self, url: str) -> bool: async def close(self) -> None: pass - async def __aenter__(self) -> 'AsyncFS': + async def __aenter__(self) -> Self: return self - async def __aexit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) 
-> None: + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: await self.close() - def copy_part_size(self, url: str) -> int: # pylint: disable=unused-argument - '''Part size when copying using multi-part uploads. The part size of - the destination filesystem is used.''' - return 128 * 1024 * 1024 - -T = TypeVar('T', bound=AsyncFS) +T = TypeVar("T", bound=AsyncFS) class AsyncFSFactory(abc.ABC, Generic[T]): diff --git a/hail/python/hailtop/aiotools/fs/stream.py b/hail/python/hailtop/aiotools/fs/stream.py index 55818cd7f07..6b864a24fc6 100644 --- a/hail/python/hailtop/aiotools/fs/stream.py +++ b/hail/python/hailtop/aiotools/fs/stream.py @@ -1,11 +1,13 @@ -from typing import BinaryIO, List, Optional, Tuple, Type -from types import TracebackType import abc import io import os import sys from concurrent.futures import ThreadPoolExecutor +from types import TracebackType +from typing import BinaryIO, List, Optional, Tuple, Type + import janus + from hailtop.utils import blocking_to_async from .exceptions import UnexpectedEOFError @@ -28,10 +30,10 @@ async def read(self, n: int = -1) -> bytes: async def readexactly(self, n: int) -> bytes: raise NotImplementedError - async def seek(self, offset, whence): + async def seek(self, offset, whence) -> int: raise OSError - def seekable(self): + def seekable(self) -> bool: return False def tell(self) -> int: @@ -60,9 +62,11 @@ async def __aenter__(self) -> 'ReadableStream': return self async def __aexit__( - self, exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - exc_traceback: Optional[TracebackType] = None) -> None: + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + exc_traceback: Optional[TracebackType] = None, + ) -> None: await self.wait_closed() @@ -113,9 +117,11 @@ async def __aenter__(self) -> 'WritableStream': return self async def __aexit__( - self, exc_type: Optional[Type[BaseException]] = None, - exc_value: Optional[BaseException] = None, - exc_traceback: Optional[TracebackType] = None) -> None: + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + exc_traceback: Optional[TracebackType] = None, + ) -> None: await self.wait_closed() @@ -136,8 +142,8 @@ async def read(self, n: int = -1) -> bytes: return await blocking_to_async(self._thread_pool, self._f.read) return await blocking_to_async(self._thread_pool, self._f.read, n) - async def seek(self, offset, whence): - self._f.seek(offset, whence) + async def seek(self, offset, whence) -> int: + return self._f.seek(offset, whence) def seekable(self): return True @@ -238,7 +244,7 @@ def readinto(self, b) -> int: self._unread = memoryview(self._unread) n = min(len(self._unread) - self._off, len(b) - total) - b[total:total + n] = self._unread[self._off:self._off + n] + b[total : total + n] = self._unread[self._off : self._off + n] self._off += n total += n assert total == len(b) or self._off == len(self._unread) @@ -320,7 +326,7 @@ def get(self) -> bytes: n = new_n assert k <= n - off - buf[off:off + k] = b + buf[off : off + k] = b off += k diff --git a/hail/python/hailtop/aiotools/local_fs.py b/hail/python/hailtop/aiotools/local_fs.py index fa58cbb7041..efe5066c267 100644 --- a/hail/python/hailtop/aiotools/local_fs.py +++ b/hail/python/hailtop/aiotools/local_fs.py @@ -1,26 +1,40 @@ -from typing import (Any, Optional, Type, BinaryIO, cast, Set, AsyncIterator, 
Callable, Dict, List, - ClassVar, Iterator) -from types import TracebackType +import asyncio +import datetime +import io import os import os.path -import io import stat -import asyncio -import datetime +import urllib.parse from concurrent.futures import ThreadPoolExecutor from contextlib import AbstractContextManager -import urllib.parse - -from ..utils import blocking_to_async, OnlineBoundedGather2 -from .fs import (FileStatus, FileListEntry, MultiPartCreate, AsyncFS, AsyncFSURL, - ReadableStream, WritableStream, blocking_readable_stream_to_async, - blocking_writable_stream_to_async) +from types import TracebackType +from typing import Any, AsyncIterator, BinaryIO, Callable, Dict, Iterator, List, Optional, Set, Type, cast + +from ..utils import OnlineBoundedGather2, blocking_to_async +from .fs import ( + AsyncFS, + AsyncFSURL, + FileListEntry, + FileStatus, + MultiPartCreate, + ReadableStream, + WritableStream, + blocking_readable_stream_to_async, + blocking_writable_stream_to_async, +) class LocalStatFileStatus(FileStatus): - def __init__(self, stat_result: os.stat_result): + def __init__(self, stat_result: os.stat_result, url: str): self._stat_result = stat_result self._items = None + self._url = url + + def basename(self) -> str: + return os.path.basename(self._url) + + def url(self) -> str: + return self._url async def size(self) -> int: return self._stat_result.st_size @@ -29,8 +43,7 @@ def time_created(self) -> datetime.datetime: raise ValueError('LocalFS does not support time created.') def time_modified(self) -> datetime.datetime: - return datetime.datetime.fromtimestamp(self._stat_result.st_mtime, - tz=datetime.timezone.utc) + return datetime.datetime.fromtimestamp(self._stat_result.st_mtime, tz=datetime.timezone.utc) async def __getitem__(self, key: str) -> Any: raise KeyError(key) @@ -46,7 +59,7 @@ def __init__(self, thread_pool: ThreadPoolExecutor, base_url: str, entry: os.Dir self._entry = entry self._status = None - def name(self) -> str: + def basename(self) -> str: return self._entry.name async def url(self) -> str: @@ -66,7 +79,9 @@ async def status(self) -> LocalStatFileStatus: if self._status is None: if await self.is_dir(): raise IsADirectoryError() - self._status = LocalStatFileStatus(await blocking_to_async(self._thread_pool, self._entry.stat)) + self._status = LocalStatFileStatus( + await blocking_to_async(self._thread_pool, self._entry.stat), await self.url() + ) return self._status @@ -76,8 +91,7 @@ def __init__(self, fs: 'LocalAsyncFS', path: str, num_parts: int): self._path = path self._num_parts = num_parts - async def create_part(self, number: int, start: int, - size_hint: Optional[int] = None): # pylint: disable=unused-argument + async def create_part(self, number: int, start: int, size_hint: Optional[int] = None): # pylint: disable=unused-argument assert 0 <= number < self._num_parts f = await blocking_to_async(self._fs._thread_pool, open, self._path, 'r+b') f.seek(start) @@ -86,10 +100,9 @@ async def create_part(self, number: int, start: int, async def __aenter__(self) -> 'LocalMultiPartCreate': return self - async def __aexit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: if exc_val: try: await self._fs.remove(self._path) @@ -101,6 +114,9 @@ class LocalAsyncFSURL(AsyncFSURL): def __init__(self, path: str): self._path = path + 
def __repr__(self) -> str: + return f'LocalAsyncFSURL({self.path})' + @property def bucket_parts(self) -> List[str]: return [] @@ -120,6 +136,9 @@ def scheme(self) -> str: def with_path(self, path) -> 'LocalAsyncFSURL': return LocalAsyncFSURL(path) + def with_root_path(self) -> 'LocalAsyncFSURL': + return self.with_path('/') + def __str__(self) -> str: return self._path @@ -219,19 +238,22 @@ def __next__(self): class LocalAsyncFS(AsyncFS): - schemes: ClassVar[Set[str]] = {'file'} - def __init__(self, thread_pool: Optional[ThreadPoolExecutor] = None, max_workers: Optional[int] = None): if not thread_pool: thread_pool = ThreadPoolExecutor(max_workers=max_workers) self._thread_pool = thread_pool + @staticmethod + def schemes() -> Set[str]: + return {'file'} + @staticmethod def valid_url(url: str) -> bool: return url.startswith('file://') or '://' not in url - def parse_url(self, url: str) -> LocalAsyncFSURL: - return LocalAsyncFSURL(self._get_path(url)) + @staticmethod + def parse_url(url: str, *, error_if_bucket: bool = False) -> LocalAsyncFSURL: + return LocalAsyncFSURL(LocalAsyncFS._get_path(url)) @staticmethod def _get_path(url): @@ -244,9 +266,10 @@ def _get_path(url): if parsed.netloc: if parsed.netloc != 'localhost': raise ValueError( - f"invalid file URL: {url}, invalid netloc: expected localhost or empty, got {parsed.netloc}") + f"invalid file URL: {url}, invalid netloc: expected localhost or empty, got {parsed.netloc}" + ) prefix += parsed.netloc - return url[len(prefix):] + return url[len(prefix) :] async def open(self, url: str) -> ReadableStream: f = await blocking_to_async(self._thread_pool, open, self._get_path(url), 'rb') @@ -266,10 +289,11 @@ async def create(self, url: str, *, retry_writes: bool = True) -> WritableStream return blocking_writable_stream_to_async(self._thread_pool, cast(BinaryIO, f)) async def multi_part_create( - self, - sema: asyncio.Semaphore, # pylint: disable=unused-argument - url: str, - num_parts: int) -> MultiPartCreate: + self, + sema: asyncio.Semaphore, + url: str, + num_parts: int, # pylint: disable=unused-argument + ) -> MultiPartCreate: # create an empty file # will be opened r+b to write the parts async with await self.create(url): @@ -281,7 +305,7 @@ async def statfile(self, url: str) -> LocalStatFileStatus: stat_result = await blocking_to_async(self._thread_pool, os.stat, path) if stat.S_ISDIR(stat_result.st_mode): raise FileNotFoundError(f'is directory: {url}') - return LocalStatFileStatus(stat_result) + return LocalStatFileStatus(stat_result, path) # entries has no type hint because the return type of os.scandir # appears to be a private type, posix.ScandirIterator. 
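Illustrative sketch (an aside, not part of the patch): how the reworked URL helpers are expected to compose, assuming the `parse_url` staticmethod, `with_new_path_components`, and `with_root_path` definitions shown above; the 'file:///data' path and the component names are placeholders.

    from hailtop.aiotools.local_fs import LocalAsyncFS

    url = LocalAsyncFS.parse_url('file:///data')        # parse_url no longer needs an instance
    url = url.with_new_path_components('batch', 'out')  # components are joined with '/'
    assert str(url) == '/data/batch/out'
    assert str(url.with_root_path()) == '/'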
@@ -294,10 +318,9 @@ async def statfile(self, url: str) -> LocalStatFileStatus: # Traceback (most recent call last): # File "", line 1, in # AttributeError: module 'posix' has no attribute 'ScandirIterator' - async def _listfiles_recursive(self, - url: str, - entries: AbstractContextManager[Iterator[os.DirEntry]] - ) -> AsyncIterator[FileListEntry]: + async def _listfiles_recursive( + self, url: str, entries: AbstractContextManager[Iterator[os.DirEntry]] + ) -> AsyncIterator[FileListEntry]: async for file in self._listfiles_flat(url, entries): if await file.is_file(): yield file @@ -308,19 +331,16 @@ async def _listfiles_recursive(self, async for subfile in self._listfiles_recursive(new_url, new_entries): yield subfile - async def _listfiles_flat(self, - url: str, - entries: AbstractContextManager[Iterator[os.DirEntry]] - ) -> AsyncIterator[FileListEntry]: + async def _listfiles_flat( + self, url: str, entries: AbstractContextManager[Iterator[os.DirEntry]] + ) -> AsyncIterator[FileListEntry]: with entries as it: for entry in it: yield LocalFileListEntry(self._thread_pool, url, entry) - async def listfiles(self, - url: str, - recursive: bool = False, - exclude_trailing_slash_files: bool = True - ) -> AsyncIterator[FileListEntry]: + async def listfiles( + self, url: str, recursive: bool = False, exclude_trailing_slash_files: bool = True + ) -> AsyncIterator[FileListEntry]: del exclude_trailing_slash_files # such files do not exist on local file systems path = self._get_path(url) entries = await blocking_to_async(self._thread_pool, os.scandir, path) @@ -359,10 +379,9 @@ async def rmdir(self, url: str) -> None: path = self._get_path(url) return await blocking_to_async(self._thread_pool, os.rmdir, path) - async def rmtree(self, - sema: Optional[asyncio.Semaphore], - url: str, - listener: Optional[Callable[[int], None]] = None) -> None: + async def rmtree( + self, sema: Optional[asyncio.Semaphore], url: str, listener: Optional[Callable[[int], None]] = None + ) -> None: path = self._get_path(url) if listener is None: listener = lambda _: None @@ -375,9 +394,7 @@ async def rm_file(path: str): await self.remove(path) listener(-1) - async def rm_dir(pool: OnlineBoundedGather2, - contents_tasks: List[asyncio.Task], - path: str): + async def rm_dir(pool: OnlineBoundedGather2, contents_tasks: List[asyncio.Task], path: str): assert listener is not None listener(1) if contents_tasks: @@ -389,10 +406,8 @@ def raise_them_all(exceptions: List[BaseException]): raise exceptions[0] finally: raise_them_all(exceptions[1:]) - excs = [exc - for t in contents_tasks - for exc in [t.exception()] - if exc is not None] + + excs = [exc for t in contents_tasks for exc in [t.exception()] if exc is not None] raise_them_all(excs) await self.rmdir(path) listener(-1) @@ -400,17 +415,14 @@ def raise_them_all(exceptions: List[BaseException]): async with OnlineBoundedGather2(sema) as pool: contents_tasks_by_dir: Dict[str, List[asyncio.Task]] = {} for dirpath, dirnames, filenames in os.walk(path, topdown=False): + def rm_dir_or_symlink(path: str): if os.path.islink(path): return pool.call(rm_file, path) return pool.call(rm_dir, pool, contents_tasks_by_dir.get(path, []), path) - contents_tasks = [ - pool.call(rm_file, os.path.join(dirpath, filename)) - for filename in filenames - ] + [ - rm_dir_or_symlink(os.path.join(dirpath, dirname)) - for dirname in dirnames + contents_tasks = [pool.call(rm_file, os.path.join(dirpath, filename)) for filename in filenames] + [ + rm_dir_or_symlink(os.path.join(dirpath, dirname)) for dirname in 
dirnames ] contents_tasks_by_dir[dirpath] = contents_tasks await rm_dir(pool, contents_tasks_by_dir.get(path, []), path) diff --git a/hail/python/hailtop/aiotools/router_fs.py b/hail/python/hailtop/aiotools/router_fs.py index 4743e5bd202..fecf48cf009 100644 --- a/hail/python/hailtop/aiotools/router_fs.py +++ b/hail/python/hailtop/aiotools/router_fs.py @@ -1,24 +1,40 @@ -from typing import Any, Optional, List, Set, AsyncIterator, Dict, AsyncContextManager, Callable import asyncio +from contextlib import AsyncExitStack +from typing import Any, AsyncContextManager, AsyncIterator, Callable, ClassVar, Dict, List, Optional, Set, Type + +from hailtop.config import ConfigVariable, configuration_of from ..aiocloud import aioaws, aioazure, aiogoogle -from .fs import (AsyncFS, MultiPartCreate, FileStatus, FileListEntry, ReadableStream, - WritableStream, AsyncFSURL) +from ..aiocloud.aioterra import azure as aioterra_azure +from .fs import AsyncFS, AsyncFSURL, FileListEntry, FileStatus, MultiPartCreate, ReadableStream, WritableStream from .local_fs import LocalAsyncFS -from hailtop.config import ConfigVariable, configuration_of - class RouterAsyncFS(AsyncFS): - def __init__(self, - *, - filesystems: Optional[List[AsyncFS]] = None, - local_kwargs: Optional[Dict[str, Any]] = None, - gcs_kwargs: Optional[Dict[str, Any]] = None, - azure_kwargs: Optional[Dict[str, Any]] = None, - s3_kwargs: Optional[Dict[str, Any]] = None, - gcs_bucket_allow_list: Optional[List[str]] = None): - self._filesystems = [] if filesystems is None else filesystems + FS_CLASSES: ClassVar[List[type[AsyncFS]]] = [ + LocalAsyncFS, + aiogoogle.GoogleStorageAsyncFS, + aioterra_azure.TerraAzureAsyncFS, # Must precede Azure since Terra URLs are also valid Azure URLs + aioazure.AzureAsyncFS, + aioaws.S3AsyncFS, + ] + + def __init__( + self, + *, + local_kwargs: Optional[Dict[str, Any]] = None, + gcs_kwargs: Optional[Dict[str, Any]] = None, + azure_kwargs: Optional[Dict[str, Any]] = None, + s3_kwargs: Optional[Dict[str, Any]] = None, + gcs_bucket_allow_list: Optional[List[str]] = None, + ): + self._local_fs: Optional[LocalAsyncFS] = None + self._google_fs: Optional[aiogoogle.GoogleStorageAsyncFS] = None + self._terra_azure_fs: Optional[aioterra_azure.TerraAzureAsyncFS] = None + self._azure_fs: Optional[aioazure.AzureAsyncFS] = None + self._s3_fs: Optional[aioaws.S3AsyncFS] = None + self._exit_stack = AsyncExitStack() + self._local_kwargs = local_kwargs or {} self._gcs_kwargs = gcs_kwargs or {} self._azure_kwargs = azure_kwargs or {} @@ -29,115 +45,122 @@ def __init__(self, else configuration_of(ConfigVariable.GCS_BUCKET_ALLOW_LIST, None, fallback="").split(",") ) - def parse_url(self, url: str) -> AsyncFSURL: - return self._get_fs(url).parse_url(url) + @staticmethod + def schemes() -> Set[str]: + return {scheme for fs_class in RouterAsyncFS.FS_CLASSES for scheme in fs_class.schemes()} + + @staticmethod + def copy_part_size(url: str) -> int: + klass = RouterAsyncFS._fs_class(url) + return klass.copy_part_size(url) + + @staticmethod + def parse_url(url: str, *, error_if_bucket: bool = False) -> AsyncFSURL: + klass = RouterAsyncFS._fs_class(url) + return klass.parse_url(url, error_if_bucket=error_if_bucket) - @property - def schemes(self) -> Set[str]: - return set().union(*(fs.schemes for fs in self._filesystems)) + @staticmethod + def _fs_class(url: str) -> Type[AsyncFS]: + for klass in RouterAsyncFS.FS_CLASSES: + if klass.valid_url(url): + return klass + raise ValueError(f'no file system found for url {url}') @staticmethod def valid_url(url) 
-> bool: return ( LocalAsyncFS.valid_url(url) or aiogoogle.GoogleStorageAsyncFS.valid_url(url) + or aioterra_azure.TerraAzureAsyncFS.valid_url(url) or aioazure.AzureAsyncFS.valid_url(url) or aioaws.S3AsyncFS.valid_url(url) ) - def _load_fs(self, uri: str): - fs: AsyncFS - - if LocalAsyncFS.valid_url(uri): - fs = LocalAsyncFS(**self._local_kwargs) - elif aiogoogle.GoogleStorageAsyncFS.valid_url(uri): - fs = aiogoogle.GoogleStorageAsyncFS( - **self._gcs_kwargs, - bucket_allow_list = self._gcs_bucket_allow_list.copy() - ) - elif aioazure.AzureAsyncFS.valid_url(uri): - fs = aioazure.AzureAsyncFS(**self._azure_kwargs) - elif aioaws.S3AsyncFS.valid_url(uri): - fs = aioaws.S3AsyncFS(**self._s3_kwargs) - else: - raise ValueError(f'no file system found for url {uri}') - - self._filesystems.append(fs) - return fs - - def _get_fs(self, uri: str) -> AsyncFS: - for fs in self._filesystems: - if fs.valid_url(uri): - return fs - return self._load_fs(uri) + async def _get_fs(self, url: str): + if LocalAsyncFS.valid_url(url): + if self._local_fs is None: + self._local_fs = LocalAsyncFS(**self._local_kwargs) + self._exit_stack.push_async_callback(self._local_fs.close) + return self._local_fs + if aiogoogle.GoogleStorageAsyncFS.valid_url(url): + if self._google_fs is None: + self._google_fs = aiogoogle.GoogleStorageAsyncFS( + **self._gcs_kwargs, bucket_allow_list=self._gcs_bucket_allow_list.copy() + ) + self._exit_stack.push_async_callback(self._google_fs.close) + return self._google_fs + if aioterra_azure.TerraAzureAsyncFS.enabled() and aioterra_azure.TerraAzureAsyncFS.valid_url(url): + if self._terra_azure_fs is None: + self._terra_azure_fs = aioterra_azure.TerraAzureAsyncFS(**self._azure_kwargs) + self._exit_stack.push_async_callback(self._terra_azure_fs.close) + return self._terra_azure_fs + if aioazure.AzureAsyncFS.valid_url(url): + if self._azure_fs is None: + self._azure_fs = aioazure.AzureAsyncFS(**self._azure_kwargs) + self._exit_stack.push_async_callback(self._azure_fs.close) + return self._azure_fs + if aioaws.S3AsyncFS.valid_url(url): + if self._s3_fs is None: + self._s3_fs = aioaws.S3AsyncFS(**self._s3_kwargs) + self._exit_stack.push_async_callback(self._s3_fs.close) + return self._s3_fs + raise ValueError(f'no file system found for url {url}') async def open(self, url: str) -> ReadableStream: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.open(url) async def _open_from(self, url: str, start: int, *, length: Optional[int] = None) -> ReadableStream: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.open_from(url, start, length=length) - async def create(self, url: str, retry_writes: bool = True) -> AsyncContextManager[WritableStream]: - fs = self._get_fs(url) + async def create(self, url: str, *, retry_writes: bool = True) -> AsyncContextManager[WritableStream]: + fs = await self._get_fs(url) return await fs.create(url, retry_writes=retry_writes) - async def multi_part_create( - self, - sema: asyncio.Semaphore, - url: str, - num_parts: int) -> MultiPartCreate: - fs = self._get_fs(url) + async def multi_part_create(self, sema: asyncio.Semaphore, url: str, num_parts: int) -> MultiPartCreate: + fs = await self._get_fs(url) return await fs.multi_part_create(sema, url, num_parts) async def statfile(self, url: str) -> FileStatus: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.statfile(url) - async def listfiles(self, - url: str, - recursive: bool = False, - exclude_trailing_slash_files: bool = True - ) -> 
AsyncIterator[FileListEntry]: - fs = self._get_fs(url) + async def listfiles( + self, url: str, recursive: bool = False, exclude_trailing_slash_files: bool = True + ) -> AsyncIterator[FileListEntry]: + fs = await self._get_fs(url) return await fs.listfiles(url, recursive, exclude_trailing_slash_files) async def staturl(self, url: str) -> str: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.staturl(url) async def mkdir(self, url: str) -> None: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.mkdir(url) async def makedirs(self, url: str, exist_ok: bool = False) -> None: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.makedirs(url, exist_ok=exist_ok) async def isfile(self, url: str) -> bool: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.isfile(url) async def isdir(self, url: str) -> bool: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.isdir(url) async def remove(self, url: str) -> None: - fs = self._get_fs(url) + fs = await self._get_fs(url) return await fs.remove(url) - async def rmtree(self, - sema: Optional[asyncio.Semaphore], - url: str, - listener: Optional[Callable[[int], None]] = None) -> None: - fs = self._get_fs(url) + async def rmtree( + self, sema: Optional[asyncio.Semaphore], url: str, listener: Optional[Callable[[int], None]] = None + ) -> None: + fs = await self._get_fs(url) return await fs.rmtree(sema, url, listener) async def close(self) -> None: - for fs in self._filesystems: - await fs.close() - - def copy_part_size(self, url: str) -> int: - fs = self._get_fs(url) - return fs.copy_part_size(url) + await self._exit_stack.aclose() diff --git a/hail/python/hailtop/aiotools/tasks.py b/hail/python/hailtop/aiotools/tasks.py index a467793d5b6..48b8f38d092 100644 --- a/hail/python/hailtop/aiotools/tasks.py +++ b/hail/python/hailtop/aiotools/tasks.py @@ -1,7 +1,6 @@ -from typing import Callable, Set import asyncio import logging - +from typing import Callable, Set log = logging.getLogger('aiotools.tasks') @@ -33,6 +32,7 @@ def callback(fut: asyncio.Future): except asyncio.CancelledError: if not self._closed: log.exception('Background task was cancelled before task manager shutdown') + return callback def shutdown(self): diff --git a/hail/python/hailtop/aiotools/utils.py b/hail/python/hailtop/aiotools/utils.py index 0a7ae6d96cf..567acbf778f 100644 --- a/hail/python/hailtop/aiotools/utils.py +++ b/hail/python/hailtop/aiotools/utils.py @@ -1,8 +1,7 @@ -from typing import Deque, TypeVar, AsyncIterator, Iterator +import asyncio import collections from contextlib import contextmanager -import asyncio - +from typing import AsyncIterator, Deque, Iterator, TypeVar _T = TypeVar('_T') diff --git a/hail/python/hailtop/aiotools/validators.py b/hail/python/hailtop/aiotools/validators.py index 7f52e6751dc..ae4e691a6d7 100644 --- a/hail/python/hailtop/aiotools/validators.py +++ b/hail/python/hailtop/aiotools/validators.py @@ -1,17 +1,13 @@ -from hailtop.aiocloud.aiogoogle.client.storage_client import GoogleStorageAsyncFS -from hailtop.aiotools.router_fs import RouterAsyncFS -from hailtop.utils import async_to_blocking from textwrap import dedent from typing import Optional from urllib.parse import urlparse +from hailtop.aiocloud.aiogoogle.client.storage_client import GoogleStorageAsyncFS +from hailtop.aiotools.router_fs import RouterAsyncFS +from hailtop.hail_event_loop import hail_event_loop -def validate_file( - uri: str, - router_async_fs: RouterAsyncFS, - *, - 
validate_scheme: Optional[bool] = False -) -> None: + +def validate_file(uri: str, router_async_fs: RouterAsyncFS, *, validate_scheme: Optional[bool] = False) -> None: """ Validates a URI's scheme if a file scheme cache was provided, and its cloud location's default storage policy if the URI points to a cloud with an ``AsyncFS`` implementation that supports checking that policy. @@ -21,6 +17,14 @@ def validate_file( :class:`ValueError` If one of the validation steps fails. """ + return hail_event_loop().run_until_complete( + _async_validate_file(uri, router_async_fs, validate_scheme=validate_scheme) + ) + + +async def _async_validate_file( + uri: str, router_async_fs: RouterAsyncFS, *, validate_scheme: Optional[bool] = False +) -> None: if validate_scheme: scheme = urlparse(uri).scheme if not scheme or scheme == "file": @@ -28,14 +32,13 @@ def validate_file( f"Local filepath detected: '{uri}'. The Hail Batch Service does not support the use of local " "filepaths. Please specify a remote URI instead (e.g. 'gs://bucket/folder')." ) - fs = router_async_fs._get_fs(uri) + fs = await router_async_fs._get_fs(uri) if isinstance(fs, GoogleStorageAsyncFS): location = fs.storage_location(uri) if location not in fs.allowed_storage_locations: - if not async_to_blocking(fs.is_hot_storage(location)): + if not await fs.is_hot_storage(location): raise ValueError( - dedent( - f"""\ + dedent(f"""\ GCS Bucket '{location}' is configured to use cold storage by default. Accessing the blob '{uri}' would incur egress charges. Either @@ -45,7 +48,6 @@ def validate_file( * accept the increased cost by adding '{location}' to the 'gcs_bucket_allow_list' configuration variable (https://hail.is/docs/0.2/configuration_reference.html). - """ - ) + """) ) fs.allowed_storage_locations.append(location) diff --git a/hail/python/hailtop/aiotools/weighted_semaphore.py b/hail/python/hailtop/aiotools/weighted_semaphore.py index e910572af11..99ad80d96b9 100644 --- a/hail/python/hailtop/aiotools/weighted_semaphore.py +++ b/hail/python/hailtop/aiotools/weighted_semaphore.py @@ -1,7 +1,8 @@ +import asyncio +from types import TracebackType from typing import Optional, Type + from sortedcontainers import SortedKeyList -from types import TracebackType -import asyncio class _AcquireManager: @@ -13,10 +14,9 @@ async def __aenter__(self) -> '_AcquireManager': await self._ws.acquire(self._n) return self - async def __aexit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: self._ws.release(self._n) diff --git a/hail/python/hailtop/auth/__init__.py b/hail/python/hailtop/auth/__init__.py index 038c652cb27..e081d6e534f 100644 --- a/hail/python/hailtop/auth/__init__.py +++ b/hail/python/hailtop/auth/__init__.py @@ -1,12 +1,18 @@ from . 
import sql_config -from .tokens import (NotLoggedInError, get_tokens, session_id_encode_to_str, - session_id_decode_from_str) from .auth import ( - get_userinfo, hail_credentials, IdentityProvider, - copy_paste_login, async_copy_paste_login, - async_create_user, async_delete_user, async_get_user, async_logout, - async_get_userinfo) + IdentityProvider, + async_copy_paste_login, + async_create_user, + async_delete_user, + async_get_user, + async_get_userinfo, + async_logout, + copy_paste_login, + get_userinfo, + hail_credentials, +) from .flow import AzureFlow, Flow, GoogleFlow +from .tokens import NotLoggedInError, get_tokens, session_id_decode_from_str, session_id_encode_to_str __all__ = [ 'NotLoggedInError', diff --git a/hail/python/hailtop/auth/auth.py b/hail/python/hailtop/auth/auth.py index 1a1566b578e..1e8ab8603c9 100644 --- a/hail/python/hailtop/auth/auth.py +++ b/hail/python/hailtop/auth/auth.py @@ -1,21 +1,22 @@ -from typing import Any, Optional, Dict, Tuple, List +import json +import os from contextlib import asynccontextmanager from dataclasses import dataclass from enum import Enum -import os -import json +from typing import Any, Dict, List, Optional, Tuple + import aiohttp from hailtop import httpx -from hailtop.aiocloud.common.credentials import CloudCredentials -from hailtop.aiocloud.common import Session -from hailtop.aiocloud.aiogoogle import GoogleCredentials from hailtop.aiocloud.aioazure import AzureCredentials -from hailtop.config import get_deploy_config, DeployConfig, get_user_identity_config_path +from hailtop.aiocloud.aiogoogle import GoogleCredentials +from hailtop.aiocloud.common import Session +from hailtop.aiocloud.common.credentials import CloudCredentials +from hailtop.config import DeployConfig, get_deploy_config, get_user_identity_config_path from hailtop.utils import async_to_blocking, retry_transient_errors -from .tokens import get_tokens, Tokens -from .flow import GoogleFlow, AzureFlow +from .flow import AzureFlow, GoogleFlow +from .tokens import Tokens, get_tokens class IdentityProvider(Enum): @@ -35,7 +36,13 @@ def from_json(config: Dict[str, Any]): class HailCredentials(CloudCredentials): - def __init__(self, tokens: Tokens, cloud_credentials: Optional[CloudCredentials], deploy_config: DeployConfig, authorize_target: bool): + def __init__( + self, + tokens: Tokens, + cloud_credentials: Optional[CloudCredentials], + deploy_config: DeployConfig, + authorize_target: bool, + ): self._tokens = tokens self._cloud_credentials = cloud_credentials self._deploy_config = deploy_config @@ -72,7 +79,10 @@ async def _get_idp_access_token_or_hail_token(self, namespace: str) -> Tuple[str async def _get_hail_token_or_idp_access_token(self, namespace: str) -> Tuple[str, Optional[float]]: if self._cloud_credentials is None: return self._tokens.namespace_token_with_expiration_or_error(namespace) - return self._tokens.namespace_token_with_expiration(namespace) or await self._cloud_credentials.access_token_with_expiration() + return ( + self._tokens.namespace_token_with_expiration(namespace) + or await self._cloud_credentials.access_token_with_expiration() + ) async def close(self): if self._cloud_credentials: @@ -90,11 +100,16 @@ def hail_credentials( tokens_file: Optional[str] = None, cloud_credentials_file: Optional[str] = None, deploy_config: Optional[DeployConfig] = None, - authorize_target: bool = True + authorize_target: bool = True, ) -> HailCredentials: tokens = get_tokens(tokens_file) deploy_config = deploy_config or get_deploy_config() - return 
HailCredentials(tokens, get_cloud_credentials_scoped_for_hail(credentials_file=cloud_credentials_file), deploy_config, authorize_target=authorize_target) + return HailCredentials( + tokens, + get_cloud_credentials_scoped_for_hail(credentials_file=cloud_credentials_file), + deploy_config, + authorize_target=authorize_target, + ) def get_cloud_credentials_scoped_for_hail(credentials_file: Optional[str] = None) -> Optional[CloudCredentials]: @@ -114,7 +129,9 @@ def get_cloud_credentials_scoped_for_hail(credentials_file: Optional[str] = None assert spec.idp == IdentityProvider.MICROSOFT if spec.oauth2_credentials is not None: - return AzureCredentials.from_credentials_data(spec.oauth2_credentials, scopes=[spec.oauth2_credentials['userOauthScope']]) + return AzureCredentials.from_credentials_data( + spec.oauth2_credentials, scopes=[spec.oauth2_credentials['userOauthScope']] + ) if 'HAIL_AZURE_OAUTH_SCOPE' in os.environ: scopes = [os.environ["HAIL_AZURE_OAUTH_SCOPE"]] @@ -227,6 +244,7 @@ async def hail_session(**session_kwargs): async with Session(credentials=credentials, **session_kwargs) as session: yield session + def get_user(username: str) -> dict: return async_to_blocking(async_get_user(username)) diff --git a/hail/python/hailtop/auth/flow.py b/hail/python/hailtop/auth/flow.py index e74a2f9d8ed..7b2d09381de 100644 --- a/hail/python/hailtop/auth/flow.py +++ b/hail/python/hailtop/auth/flow.py @@ -1,12 +1,10 @@ import abc import base64 -from cryptography import x509 -from cryptography.hazmat.primitives import serialization import json import logging import os import urllib.parse -from typing import Any, Dict, List, Mapping, Optional, TypedDict, ClassVar +from typing import Any, ClassVar, Dict, List, Mapping, Optional, TypedDict import aiohttp.web import google.auth.transport.requests @@ -14,6 +12,8 @@ import google_auth_oauthlib.flow import jwt import msal +from cryptography import x509 +from cryptography.hazmat.primitives import serialization from hailtop import httpx from hailtop.utils import retry_transient_errors @@ -22,11 +22,7 @@ class FlowResult: - def __init__(self, - login_id: str, - unverified_email: str, - organization_id: Optional[str], - token: Mapping[Any, Any]): + def __init__(self, login_id: str, unverified_email: str, organization_id: Optional[str], token: Mapping[Any, Any]): self.login_id = login_id self.unverified_email = unverified_email self.organization_id = organization_id # In Azure, a Tenant ID. In Google, a domain name. @@ -65,19 +61,22 @@ def perform_installed_app_login_flow(oauth2_client: Dict[str, Any]) -> Dict[str, @staticmethod @abc.abstractmethod - async def logout_installed_app(oauth2_credentials: Dict[str, Any]): + async def logout_installed_app(oauth2_credentials: Dict[str, Any]) -> None: """Revokes the OAuth2 credentials on the user's machine.""" raise NotImplementedError @staticmethod @abc.abstractmethod - async def get_identity_uid_from_access_token(session: httpx.ClientSession, access_token: str, *, oauth2_client: dict) -> Optional[str]: + async def get_identity_uid_from_access_token( + session: httpx.ClientSession, access_token: str, *, oauth2_client: dict + ) -> Optional[str]: """ Validate a user-provided access token. If the token is valid, return the identity to which it belongs. If it is not valid, return None. 
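The abstract ``get_identity_uid_from_access_token`` contract above is the same for both providers: given a user-supplied access token, return the UID of the identity it belongs to, or ``None`` if the token does not validate. A minimal caller sketch, assuming an installed-app ``oauth2_client`` dict is already loaded from disk (the helper name and its lack of error handling are illustrative, not part of this change)::

    from typing import Optional

    from hailtop import httpx
    from hailtop.auth import GoogleFlow


    async def uid_for_token(access_token: str, oauth2_client: dict) -> Optional[str]:
        # Each Flow implementation verifies the token against its own provider and
        # audience; None means the token was rejected or issued for another client.
        async with httpx.client_session() as session:
            return await GoogleFlow.get_identity_uid_from_access_token(
                session, access_token, oauth2_client=oauth2_client
            )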
""" raise NotImplementedError + class GoogleFlow(Flow): scopes: ClassVar[List[str]] = [ 'https://www.googleapis.com/auth/userinfo.profile', @@ -113,7 +112,8 @@ def receive_callback(self, request: aiohttp.web.Request, flow_dict: dict) -> Flo flow.redirect_uri = flow_dict['callback_uri'] flow.fetch_token(code=request.query['code']) token = google.oauth2.id_token.verify_oauth2_token( - flow.credentials.id_token, google.auth.transport.requests.Request() # type: ignore + flow.credentials.id_token, # type: ignore + google.auth.transport.requests.Request(), # type: ignore ) email = token['email'] return FlowResult(email, email, token.get('hd'), token) @@ -130,16 +130,18 @@ def perform_installed_app_login_flow(oauth2_client: Dict[str, Any]) -> Dict[str, } @staticmethod - async def logout_installed_app(oauth2_credentials: Dict[str, Any]): + async def logout_installed_app(oauth2_credentials: Dict[str, Any]) -> None: async with httpx.client_session() as session: await session.post( 'https://oauth2.googleapis.com/revoke', params={'token': oauth2_credentials['refresh_token']}, - headers={'content-type': 'application/x-www-form-urlencoded'} + headers={'content-type': 'application/x-www-form-urlencoded'}, ) @staticmethod - async def get_identity_uid_from_access_token(session: httpx.ClientSession, access_token: str, *, oauth2_client: dict) -> Optional[str]: + async def get_identity_uid_from_access_token( + session: httpx.ClientSession, access_token: str, *, oauth2_client: dict + ) -> Optional[str]: oauth2_client_audience = oauth2_client['installed']['client_id'] try: userinfo = await retry_transient_errors( @@ -187,9 +189,9 @@ def organization_id(self) -> str: def initiate_flow(self, redirect_uri: str) -> dict: flow = self._client.initiate_auth_code_flow( scopes=[], # confusingly, scopes=[] is the only way to get the openid, profile, and - # offline_access scopes - # https://github.com/AzureAD/microsoft-authentication-library-for-python/blob/dev/msal/application.py#L568-L580 - redirect_uri=redirect_uri + # offline_access scopes + # https://github.com/AzureAD/microsoft-authentication-library-for-python/blob/dev/msal/application.py#L568-L580 + redirect_uri=redirect_uri, ) return { 'flow': flow, @@ -214,7 +216,7 @@ def receive_callback(self, request: aiohttp.web.Request, flow_dict: dict) -> Flo token['id_token_claims']['oid'], token['id_token_claims']['preferred_username'], token['id_token_claims']['tid'], - token + token, ) @staticmethod @@ -226,7 +228,7 @@ def perform_installed_app_login_flow(oauth2_client: Dict[str, Any]) -> Dict[str, return {**oauth2_client, 'refreshToken': credentials['refresh_token']} @staticmethod - async def logout_installed_app(_: Dict[str, Any]): + async def logout_installed_app(oauth2_credentials: Dict[str, Any]): # AAD does not support revocation of a single refresh token, # only all refresh tokens issued to all applications for a particular # user, which we neither wish nor should have the permissions @@ -235,7 +237,9 @@ async def logout_installed_app(_: Dict[str, Any]): pass @staticmethod - async def get_identity_uid_from_access_token(session: httpx.ClientSession, access_token: str, *, oauth2_client: dict) -> Optional[str]: + async def get_identity_uid_from_access_token( + session: httpx.ClientSession, access_token: str, *, oauth2_client: dict + ) -> Optional[str]: audience = oauth2_client['appIdentifierUri'] try: @@ -253,7 +257,13 @@ async def get_identity_uid_from_access_token(session: httpx.ClientSession, acces jwk = next(key for key in AzureFlow._aad_keys if key['kid'] == 
kid) der_cert = base64.b64decode(jwk['x5c'][0]) cert = x509.load_der_x509_certificate(der_cert) - pem_key = cert.public_key().public_bytes(encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo).decode() + pem_key = ( + cert.public_key() + .public_bytes( + encoding=serialization.Encoding.PEM, format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + .decode() + ) decoded = jwt.decode(access_token, pem_key, algorithms=['RS256'], audience=audience) return decoded['oid'] diff --git a/hail/python/hailtop/auth/sql_config.py b/hail/python/hailtop/auth/sql_config.py index 32a1473f364..49984abfa3d 100644 --- a/hail/python/hailtop/auth/sql_config.py +++ b/hail/python/hailtop/auth/sql_config.py @@ -1,6 +1,6 @@ -from typing import Dict, NamedTuple, Optional, Any import json import os +from typing import Any, Dict, NamedTuple, Optional class SQLConfig(NamedTuple): @@ -20,14 +20,16 @@ def to_json(self) -> str: return json.dumps(self.to_dict()) def to_dict(self) -> Dict[str, Any]: - d = {'host': self.host, - 'port': self.port, - 'user': self.user, - 'password': self.password, - 'instance': self.instance, - 'connection_name': self.connection_name, - 'ssl-ca': self.ssl_ca, - 'ssl-mode': self.ssl_mode} + d = { + 'host': self.host, + 'port': self.port, + 'user': self.user, + 'password': self.password, + 'instance': self.instance, + 'connection_name': self.connection_name, + 'ssl-ca': self.ssl_ca, + 'ssl-mode': self.ssl_mode, + } if self.db is not None: d['db'] = self.db if self.using_mtls(): @@ -36,14 +38,14 @@ def to_dict(self) -> Dict[str, Any]: return d def to_cnf(self) -> str: - cnf = f'''[client] + cnf = f"""[client] host={self.host} user={self.user} port={self.port} password="{self.password}" ssl-ca={self.ssl_ca} ssl-mode={self.ssl_mode} -''' +""" if self.db is not None: cnf += f'database={self.db}\n' if self.using_mtls(): @@ -83,22 +85,22 @@ def from_json(s: str) -> 'SQLConfig': @staticmethod def from_dict(d: Dict[str, Any]) -> 'SQLConfig': - for k in ('host', 'port', 'user', 'password', - 'instance', 'connection_name', - 'ssl-ca', 'ssl-mode'): + for k in ('host', 'port', 'user', 'password', 'instance', 'connection_name', 'ssl-ca', 'ssl-mode'): assert k in d, f'{k} should be in {d}' assert d[k] is not None, f'{k} should not be None in {d}' - return SQLConfig(host=d['host'], - port=d['port'], - user=d['user'], - password=d['password'], - instance=d['instance'], - connection_name=d['connection_name'], - db=d.get('db'), - ssl_ca=d['ssl-ca'], - ssl_cert=d.get('ssl-cert'), - ssl_key=d.get('ssl-key'), - ssl_mode=d['ssl-mode']) + return SQLConfig( + host=d['host'], + port=d['port'], + user=d['user'], + password=d['password'], + instance=d['instance'], + connection_name=d['connection_name'], + db=d.get('db'), + ssl_ca=d['ssl-ca'], + ssl_cert=d.get('ssl-cert'), + ssl_key=d.get('ssl-key'), + ssl_mode=d['ssl-mode'], + ) @staticmethod def local_insecure_config() -> 'SQLConfig': @@ -117,11 +119,9 @@ def local_insecure_config() -> 'SQLConfig': ) -def create_secret_data_from_config(config: SQLConfig, - server_ca: str, - client_cert: Optional[str], - client_key: Optional[str] - ) -> Dict[str, str]: +def create_secret_data_from_config( + config: SQLConfig, server_ca: str, client_cert: Optional[str], client_key: Optional[str] +) -> Dict[str, str]: secret_data = {} secret_data['sql-config.json'] = config.to_json() secret_data['sql-config.cnf'] = config.to_cnf() diff --git a/hail/python/hailtop/auth/tokens.py b/hail/python/hailtop/auth/tokens.py index 6e44ece2645..b9604ac520f 
100644 --- a/hail/python/hailtop/auth/tokens.py +++ b/hail/python/hailtop/auth/tokens.py @@ -1,9 +1,10 @@ -from typing import Optional, Dict, Tuple import base64 import collections.abc -import os import json import logging +import os +from typing import Dict, Optional, Tuple + from hailtop.config import get_deploy_config from hailtop.utils import first_extant_file @@ -21,13 +22,13 @@ def session_id_decode_from_str(session_id_str: str) -> bytes: class NotLoggedInError(Exception): def __init__(self, ns_arg): super().__init__() - self.message = f''' + self.message = f""" You are not authenticated. Please log in with: $ hailctl auth login {ns_arg} to obtain new credentials. -''' +""" def __str__(self): return self.message @@ -37,11 +38,14 @@ class Tokens(collections.abc.MutableMapping): @staticmethod def get_tokens_file() -> str: default_enduser_token_file = os.path.expanduser('~/.hail/tokens.json') - return first_extant_file( - os.environ.get('HAIL_TOKENS_FILE'), - default_enduser_token_file, - '/user-tokens/tokens.json', - ) or default_enduser_token_file + return ( + first_extant_file( + os.environ.get('HAIL_TOKENS_FILE'), + default_enduser_token_file, + '/user-tokens/tokens.json', + ) + or default_enduser_token_file + ) @staticmethod def default_tokens() -> 'Tokens': @@ -100,13 +104,7 @@ def __len__(self): def write(self) -> None: # restrict permissions to user with os.fdopen( - os.open( - self.get_tokens_file(), - os.O_CREAT | os.O_WRONLY | os.O_TRUNC, - 0o600 - ), - 'w', - encoding='utf-8' + os.open(self.get_tokens_file(), os.O_CREAT | os.O_WRONLY | os.O_TRUNC, 0o600), 'w', encoding='utf-8' ) as f: json.dump(self._tokens, f) diff --git a/hail/python/hailtop/batch/Makefile b/hail/python/hailtop/batch/Makefile index 35d5bd1e69c..9f8be057005 100644 --- a/hail/python/hailtop/batch/Makefile +++ b/hail/python/hailtop/batch/Makefile @@ -6,4 +6,5 @@ doctest: -r A \ --doctest-modules \ --ignore=docs/conf.py \ - --doctest-glob='*.rst' + --doctest-glob='*.rst' \ + --doctest-continue-on-failure diff --git a/hail/python/hailtop/batch/__init__.py b/hail/python/hailtop/batch/__init__.py index 3b995a321d9..0f4709c987f 100644 --- a/hail/python/hailtop/batch/__init__.py +++ b/hail/python/hailtop/batch/__init__.py @@ -1,27 +1,28 @@ import warnings +from .backend import Backend, LocalBackend, ServiceBackend from .batch import Batch from .batch_pool_executor import BatchPoolExecutor -from .backend import LocalBackend, ServiceBackend, Backend from .docker import build_python_image from .exceptions import BatchException +from .resource import PythonResult, Resource, ResourceFile, ResourceGroup from .utils import concatenate, plink_merge -from .resource import Resource, ResourceFile, ResourceGroup, PythonResult -__all__ = ['Batch', - 'LocalBackend', - 'ServiceBackend', - 'Backend', - 'BatchException', - 'BatchPoolExecutor', - 'build_python_image', - 'concatenate', - 'plink_merge', - 'PythonResult', - 'Resource', - 'ResourceFile', - 'ResourceGroup', - ] +__all__ = [ + 'Batch', + 'LocalBackend', + 'ServiceBackend', + 'Backend', + 'BatchException', + 'BatchPoolExecutor', + 'build_python_image', + 'concatenate', + 'plink_merge', + 'PythonResult', + 'Resource', + 'ResourceFile', + 'ResourceGroup', +] warnings.filterwarnings('once', append=True) del warnings diff --git a/hail/python/hailtop/batch/backend.py b/hail/python/hailtop/batch/backend.py index d38692b7c1c..5f77d8357cc 100644 --- a/hail/python/hailtop/batch/backend.py +++ b/hail/python/hailtop/batch/backend.py @@ -1,41 +1,43 @@ -from typing import Optional, 
Dict, Any, TypeVar, Generic, List, Union, ClassVar import abc import asyncio import collections -import orjson +import copy +import functools import os import subprocess as sp -import uuid import time -import functools -import copy -from shlex import quote as shq -import webbrowser +import uuid import warnings +import webbrowser +from shlex import quote as shq +from typing import Any, ClassVar, Dict, Generic, List, Optional, TypeVar, Union + +import orjson from rich.progress import track +import hailtop.batch_client.client as bc from hailtop import pip_version -from hailtop.config import ConfigVariable, configuration_of, get_deploy_config, get_remote_tmpdir -from hailtop.utils.rich_progress_bar import SimpleCopyToolProgressBar -from hailtop.utils import parse_docker_image_reference, async_to_blocking, bounded_gather, url_scheme +from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration +from hailtop.aiotools.router_fs import RouterAsyncFS +from hailtop.aiotools.validators import validate_file from hailtop.batch.hail_genetics_images import HAIL_GENETICS_IMAGES, hailgenetics_hail_image_for_current_python_version - -from hailtop.batch_client.parse import parse_cpu_in_mcpu -import hailtop.batch_client.client as bc -from hailtop.batch_client.client import BatchClient from hailtop.batch_client.aioclient import BatchClient as AioBatchClient -from hailtop.aiotools.router_fs import RouterAsyncFS -from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration +from hailtop.batch_client.client import BatchClient +from hailtop.batch_client.parse import parse_cpu_in_mcpu +from hailtop.config import ConfigVariable, configuration_of, get_deploy_config, get_remote_tmpdir +from hailtop.utils import async_to_blocking, bounded_gather, parse_docker_image_reference, url_scheme +from hailtop.utils.gcs_requester_pays import GCSRequesterPaysFSCache +from hailtop.utils.rich_progress_bar import SimpleCopyToolProgressBar -from . import resource, batch # pylint: disable=unused-import -from .job import PythonJob +from . import batch, resource # pylint: disable=unused-import from .exceptions import BatchException from .globals import DEFAULT_SHELL -from hailtop.aiotools.validators import validate_file - +from .job import PythonJob HAIL_GENETICS_HAILTOP_IMAGE = os.environ.get('HAIL_GENETICS_HAILTOP_IMAGE', f'hailgenetics/hailtop:{pip_version()}') -HAIL_GENETICS_HAIL_IMAGE = os.environ.get('HAIL_GENETICS_HAIL_IMAGE') or hailgenetics_hail_image_for_current_python_version() +HAIL_GENETICS_HAIL_IMAGE = ( + os.environ.get('HAIL_GENETICS_HAIL_IMAGE') or hailgenetics_hail_image_for_current_python_version() +) RunningBatchType = TypeVar('RunningBatchType') @@ -51,32 +53,23 @@ class Backend(abc.ABC, Generic[RunningBatchType]): """ Abstract class for backends. 
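Among the new imports here, ``GCSRequesterPaysFSCache`` is what replaces the hand-rolled ``_requester_pays_fses`` dict further down: it behaves like a mapping from an optional requester-pays configuration to a lazily constructed ``RouterAsyncFS``, so the same filesystem is reused for a given billing setup. A small sketch of that mapping-style access, assuming only the constructor arguments shown in this diff (the project name is a placeholder)::

    from hailtop.aiotools.router_fs import RouterAsyncFS
    from hailtop.utils.gcs_requester_pays import GCSRequesterPaysFSCache

    # One cache per backend; the None key yields the default, non requester-pays filesystem.
    fs_cache = GCSRequesterPaysFSCache(fs_constructor=RouterAsyncFS)

    default_fs = fs_cache[None]
    billed_fs = fs_cache['my-gcp-project']  # requester-pays reads through this FS bill that project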
""" + _closed = False - def __init__(self): - self._requester_pays_fses: Dict[GCSRequesterPaysConfiguration, RouterAsyncFS] = {} + def __init__(self, requester_pays_fses: GCSRequesterPaysFSCache): + self._requester_pays_fses = requester_pays_fses - def requester_pays_fs(self, requester_pays_config: GCSRequesterPaysConfiguration) -> RouterAsyncFS: - try: - return self._requester_pays_fses[requester_pays_config] - except KeyError: - if requester_pays_config is not None: - self._requester_pays_fses[requester_pays_config] = RouterAsyncFS( - gcs_kwargs={"gcs_requester_pays_configuration": requester_pays_config} - ) - return self._requester_pays_fses[requester_pays_config] - return self._fs + def requester_pays_fs(self, requester_pays_config: Optional[GCSRequesterPaysConfiguration]) -> RouterAsyncFS: + return self._requester_pays_fses[requester_pays_config] def validate_file(self, uri: str, requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None) -> None: - self._validate_file( - uri, self.requester_pays_fs(requester_pays_config) if requester_pays_config is not None else self._fs - ) + self._validate_file(uri, self.requester_pays_fs(requester_pays_config)) @abc.abstractmethod def _validate_file(self, uri: str, fs: RouterAsyncFS) -> None: raise NotImplementedError - def _run(self, batch, dry_run, verbose, delete_scratch_on_exit, **backend_kwargs) -> RunningBatchType: + def _run(self, batch, dry_run, verbose, delete_scratch_on_exit, **backend_kwargs) -> Optional[RunningBatchType]: """ See :meth:`._async_run`. @@ -87,7 +80,9 @@ def _run(self, batch, dry_run, verbose, delete_scratch_on_exit, **backend_kwargs return async_to_blocking(self._async_run(batch, dry_run, verbose, delete_scratch_on_exit, **backend_kwargs)) @abc.abstractmethod - async def _async_run(self, batch, dry_run, verbose, delete_scratch_on_exit, **backend_kwargs) -> RunningBatchType: + async def _async_run( + self, batch, dry_run, verbose, delete_scratch_on_exit, **backend_kwargs + ) -> Optional[RunningBatchType]: """ Execute a batch. @@ -135,7 +130,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): class LocalBackend(Backend[None]): """ Backend that executes batches on a local computer. - Examples -------- @@ -156,11 +150,10 @@ class LocalBackend(Backend[None]): variable `HAIL_BATCH_EXTRA_DOCKER_RUN_FLAGS`. """ - def __init__(self, - tmp_dir: str = '/tmp/', - gsa_key_file: Optional[str] = None, - extra_docker_run_flags: Optional[str] = None): - super().__init__() + def __init__( + self, tmp_dir: str = '/tmp/', gsa_key_file: Optional[str] = None, extra_docker_run_flags: Optional[str] = None + ): + super().__init__(GCSRequesterPaysFSCache(fs_constructor=RouterAsyncFS)) self._tmp_dir = tmp_dir.rstrip('/') flags = '' @@ -176,7 +169,7 @@ def __init__(self, flags += f' -v {gsa_key_file}:/gsa-key/key.json' self._extra_docker_run_flags = flags - self.__fs = RouterAsyncFS() + self.__fs = self._requester_pays_fses[None] @property def _fs(self) -> RouterAsyncFS: @@ -186,12 +179,7 @@ def _validate_file(self, uri: str, fs: RouterAsyncFS) -> None: validate_file(uri, fs) async def _async_run( - self, - batch: 'batch.Batch', - dry_run: bool, - verbose: bool, - delete_scratch_on_exit: bool, - **backend_kwargs + self, batch: 'batch.Batch', dry_run: bool, verbose: bool, delete_scratch_on_exit: bool, **backend_kwargs ) -> None: # pylint: disable=R0915 """ Execute a batch. 
@@ -218,11 +206,7 @@ async def _async_run( tmpdir = self._get_scratch_dir() def new_code_block(): - return ['set -e' + ('x' if verbose else ''), - '\n', - '# change cd to tmp directory', - f"cd {tmpdir}", - '\n'] + return ['set -e' + ('x' if verbose else ''), '\n', '# change cd to tmp directory', f"cd {tmpdir}", '\n'] def run_code(code) -> Optional[sp.CalledProcessError]: code = '\n'.join(code) @@ -279,7 +263,9 @@ def symlink_input_resource_group(r): symlinks.append(f'ln -sf {shq(src)} {shq(dest)}') return symlinks - def transfer_dicts_for_resource_file(res_file: Union[resource.ResourceFile, resource.PythonResult]) -> List[dict]: + def transfer_dicts_for_resource_file( + res_file: Union[resource.ResourceFile, resource.PythonResult], + ) -> List[dict]: if isinstance(res_file, resource.InputResourceFile): source = res_file._input_path else: @@ -292,7 +278,8 @@ def transfer_dicts_for_resource_file(res_file: Union[resource.ResourceFile, reso input_transfer_dicts = [ transfer_dict for input_resource in batch._input_resources - for transfer_dict in transfer_dicts_for_resource_file(input_resource)] + for transfer_dict in transfer_dicts_for_resource_file(input_resource) + ] if input_transfer_dicts: input_transfers = orjson.dumps(input_transfer_dicts).decode('utf-8') @@ -369,25 +356,30 @@ def cancel_child_jobs(j): else: memory = '' - code.append(f"docker run " - "--entrypoint=''" - f"{self._extra_docker_run_flags} " - f"-v {tmpdir}:{tmpdir} " - f"{memory} " - f"{cpu} " - f"{job._image} " - f"{job_shell} -c {quoted_job_script}") + code.append( + f"docker run " + "--entrypoint=''" + f"{self._extra_docker_run_flags} " + f"-v {tmpdir}:{tmpdir} " + f"{memory} " + f"{cpu} " + f"{job._image} " + f"{job_shell} -c {quoted_job_script}" + ) else: code.append(f"{job_shell} -c {quoted_job_script}") output_transfer_dicts = [ transfer_dict for output_resource in job._external_outputs - for transfer_dict in transfer_dicts_for_resource_file(output_resource)] + for transfer_dict in transfer_dicts_for_resource_file(output_resource) + ] if output_transfer_dicts: output_transfers = orjson.dumps(output_transfer_dicts).decode('utf-8') - code += [f'python3 -m hailtop.aiotools.copy {shq(requester_pays_project_json)} {shq(output_transfers)}'] + code += [ + f'python3 -m hailtop.aiotools.copy {shq(requester_pays_project_json)} {shq(output_transfers)}' + ] code += ['\n'] exc = run_code(code) @@ -420,42 +412,100 @@ async def _async_close(self): class ServiceBackend(Backend[bc.Batch]): - ANY_REGION: ClassVar[List[str]] = ['any_region'] - """Backend that executes batches on Hail's Batch Service on Google Cloud. Examples -------- - >>> service_backend = ServiceBackend(billing_project='my-billing-account', remote_tmpdir='gs://my-bucket/temporary-files/') # doctest: +SKIP - >>> b = Batch(backend=service_backend) # doctest: +SKIP + Create and use a backend that bills to the Hail Batch billing project named "my-billing-account" + and stores temporary intermediate files in "gs://my-bucket/temporary-files". + + >>> import hailtop.batch as hb + >>> service_backend = hb.ServiceBackend( + ... billing_project='my-billing-account', + ... remote_tmpdir='gs://my-bucket/temporary-files/' + ... 
) # doctest: +SKIP + >>> b = hb.Batch(backend=service_backend) # doctest: +SKIP + >>> j = b.new_job() # doctest: +SKIP + >>> j.command('echo hello world!') # doctest: +SKIP >>> b.run() # doctest: +SKIP - >>> service_backend.close() # doctest: +SKIP - If the Hail configuration parameters batch/billing_project and - batch/remote_tmpdir were previously set with ``hailctl config set``, then - one may elide the `billing_project` and `remote_tmpdir` parameters. + Same as above, but set the billing project and temporary intermediate folders via a + configuration file:: - >>> service_backend = ServiceBackend() - >>> b = Batch(backend=service_backend) - >>> b.run() # doctest: +SKIP - >>> service_backend.close() + cat >my-batch-script.py >>EOF + import hailtop.batch as hb + b = hb.Batch(backend=ServiceBackend()) + j = b.new_job() + j.command('echo hello world!') + b.run() + EOF + hailctl config set batch/billing_project my-billing-account + hailctl config set batch/remote_tmpdir gs://my-bucket/temporary-files/ + python3 my-batch-script.py + + Same as above, but also specify the use of the :class:`.ServiceBackend` via configuration file:: + + cat >my-batch-script.py >>EOF + import hailtop.batch as hb + b = hb.Batch() + j = b.new_job() + j.command('echo hello world!') + b.run() + EOF + hailctl config set batch/billing_project my-billing-account + hailctl config set batch/remote_tmpdir gs://my-bucket/temporary-files/ + hailctl config set batch/backend service + python3 my-batch-script.py + + Create a backend which stores temporary intermediate files in + "https://my-account.blob.core.windows.net/my-container/tempdir". + + >>> service_backend = hb.ServiceBackend( + ... billing_project='my-billing-account', + ... remote_tmpdir='https://my-account.blob.core.windows.net/my-container/tempdir' + ... ) # doctest: +SKIP + + Require all jobs in all batches in this backend to execute in us-central1:: + + >>> b = hb.Batch(backend=hb.ServiceBackend(regions=['us-central1'])) + + Same as above, but using a configuration file:: + + hailctl config set batch/regions us-central1 + python3 my-batch-script.py + + Same as above, but using the ``HAIL_BATCH_REGIONS`` environment variable:: + + export HAIL_BATCH_REGIONS=us-central1 + python3 my-batch-script.py + + Permit jobs to execute in *either* us-central1 or us-east1:: + + >>> b = hb.Batch(backend=hb.ServiceBackend(regions=['us-central1', 'us-east1'])) + + Same as above, but using a configuration file:: + + hailctl config set batch/regions us-central1,us-east1 + + Allow reading or writing to buckets even though they are "cold" storage: + >>> b = hb.Batch( + ... backend=hb.ServiceBackend( + ... gcs_bucket_allow_list=['cold-bucket', 'cold-bucket2'], + ... ), + ... ) Parameters ---------- billing_project: Name of billing project to use. bucket: - Name of bucket to use. Should not include the ``gs://`` prefix. Cannot be used with - `remote_tmpdir`. Temporary data will be stored in the "/batch" folder of this - bucket. This argument is deprecated. Use `remote_tmpdir` instead. + This argument is deprecated. Use `remote_tmpdir` instead. remote_tmpdir: - Temporary data will be stored in this cloud storage folder. Cannot be used with deprecated - argument `bucket`. Paths should match a GCS URI like gs:/// or an ABS - URI of the form https://.blob.core.windows.net//. + Temporary data will be stored in this cloud storage folder. google_project: - DEPRECATED. Please use gcs_requester_pays_configuration. + This argument is deprecated. Use `gcs_requester_pays_configuration` instead. 
gcs_requester_pays_configuration : either :class:`str` or :class:`tuple` of :class:`str` and :class:`list` of :class:`str`, optional If a string is provided, configure the Google Cloud Storage file system to bill usage to the project identified by that string. If a tuple is provided, configure the Google Cloud @@ -465,15 +515,19 @@ class ServiceBackend(Backend[bc.Batch]): The authorization token to pass to the batch client. Should only be set for user delegation purposes. regions: - Cloud region(s) to run jobs in. Use py:staticmethod:`.ServiceBackend.supported_regions` to list the - available regions to choose from. Use py:attribute:`.ServiceBackend.ANY_REGION` to signify the default is jobs - can run in any available region. The default is jobs can run in any region unless a default value has - been set with hailctl. An example invocation is `hailctl config set batch/regions "us-central1,us-east1"`. + Cloud regions in which jobs may run. :attr:`.ServiceBackend.ANY_REGION` indicates jobs may + run in any region. If unspecified or ``None``, the ``batch/regions`` Hail configuration + variable is consulted. See examples above. If none of these variables are set, then jobs may + run in any region. :meth:`.ServiceBackend.supported_regions` lists the available regions. gcs_bucket_allow_list: A list of buckets that the :class:`.ServiceBackend` should be permitted to read from or write to, even if their - default policy is to use "cold" storage. Should look like ``["bucket1", "bucket2"]``. + default policy is to use "cold" storage. + """ + ANY_REGION: ClassVar[List[str]] = ['any_region'] + """A special value that indicates a job may run in any region.""" + @staticmethod def supported_regions(): """ @@ -502,19 +556,21 @@ def __init__( gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None, gcs_bucket_allow_list: Optional[List[str]] = None, ): - super().__init__() - if len(args) > 2: raise TypeError(f'ServiceBackend() takes 2 positional arguments but {len(args)} were given') if len(args) >= 1: if billing_project is not None: raise TypeError('ServiceBackend() got multiple values for argument \'billing_project\'') - warnings.warn('Use of deprecated positional argument \'billing_project\' in ServiceBackend(). Specify \'billing_project\' as a keyword argument instead.') + warnings.warn( + 'Use of deprecated positional argument \'billing_project\' in ServiceBackend(). Specify \'billing_project\' as a keyword argument instead.' + ) billing_project = args[0] if len(args) >= 2: if bucket is not None: raise TypeError('ServiceBackend() got multiple values for argument \'bucket\'') - warnings.warn('Use of deprecated positional argument \'bucket\' in ServiceBackend(). Specify \'bucket\' as a keyword argument instead.') + warnings.warn( + 'Use of deprecated positional argument \'bucket\' in ServiceBackend(). Specify \'bucket\' as a keyword argument instead.' 
+ ) bucket = args[1] billing_project = configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, billing_project, None) @@ -522,7 +578,8 @@ def __init__( raise ValueError( 'the billing_project parameter of ServiceBackend must be set ' 'or run `hailctl config set batch/billing_project ' - 'MY_BILLING_PROJECT`') + 'MY_BILLING_PROJECT`' + ) self._billing_project = billing_project self._token = token @@ -542,7 +599,15 @@ def __init__( gcs_kwargs = {'gcs_requester_pays_configuration': google_project} else: gcs_kwargs = {'gcs_requester_pays_configuration': gcs_requester_pays_configuration} - self.__fs = RouterAsyncFS(gcs_kwargs=gcs_kwargs, gcs_bucket_allow_list=gcs_bucket_allow_list) + + super().__init__( + GCSRequesterPaysFSCache( + fs_constructor=RouterAsyncFS, + default_kwargs={"gcs_kwargs": gcs_kwargs, "gcs_bucket_allow_list": gcs_bucket_allow_list}, + ) + ) + + self.__fs = self._requester_pays_fses[None] self.validate_file(self.remote_tmpdir) @@ -587,7 +652,7 @@ async def _async_run( disable_progress_bar: bool = False, callback: Optional[str] = None, token: Optional[str] = None, - **backend_kwargs + **backend_kwargs, ) -> Optional[bc.Batch]: # pylint: disable-msg=too-many-statements """Execute a batch. @@ -635,7 +700,10 @@ async def _async_run( if batch._async_batch is None: batch._async_batch = (await self._batch_client()).create_batch( - attributes=attributes, callback=callback, token=token, cancel_after_n_failures=batch._cancel_after_n_failures + attributes=attributes, + callback=callback, + token=token, + cancel_after_n_failures=batch._cancel_after_n_failures, ) async_batch = batch._async_batch @@ -675,18 +743,14 @@ def symlink_input_resource_group(r): write_external_inputs = [x for r in batch._input_resources for x in copy_external_output(r)] if write_external_inputs: - transfers_bytes = orjson.dumps([ - {"from": src, "to": dest} - for src, dest in write_external_inputs]) + transfers_bytes = orjson.dumps([{"from": src, "to": dest} for src, dest in write_external_inputs]) transfers = transfers_bytes.decode('utf-8') write_cmd = ['python3', '-m', 'hailtop.aiotools.copy', 'null', transfers] if dry_run: commands.append(' '.join(shq(x) for x in write_cmd)) else: j = async_batch.create_job( - image=HAIL_GENETICS_HAILTOP_IMAGE, - command=write_cmd, - attributes={'name': 'write_external_inputs'} + image=HAIL_GENETICS_HAILTOP_IMAGE, command=write_cmd, attributes={'name': 'write_external_inputs'} ) jobs_to_command[j] = ' '.join(shq(x) for x in write_cmd) n_jobs_submitted += 1 @@ -697,18 +761,18 @@ def symlink_input_resource_group(r): for pyjob in pyjobs: if pyjob._image is None: pyjob._image = HAIL_GENETICS_HAIL_IMAGE - await batch._serialize_python_functions_to_input_files( - batch_remote_tmpdir, dry_run=dry_run - ) + await batch._serialize_python_functions_to_input_files(batch_remote_tmpdir, dry_run=dry_run) disable_setup_steps_progress_bar = disable_progress_bar or len(unsubmitted_jobs) < 10_000 - with SimpleCopyToolProgressBar(total=len(unsubmitted_jobs), - description='upload code', - disable=disable_setup_steps_progress_bar) as pbar: + with SimpleCopyToolProgressBar( + total=len(unsubmitted_jobs), description='upload code', disable=disable_setup_steps_progress_bar + ) as pbar: + async def compile_job(job): used_remote_tmpdir = await job._compile(local_tmpdir, batch_remote_tmpdir, dry_run=dry_run) pbar.update(1) return used_remote_tmpdir + used_remote_tmpdir_results = await bounded_gather( *[functools.partial(compile_job, j) for j in unsubmitted_jobs], parallelism=150, @@ -734,17 +798,17 @@ 
async def compile_job(job): job_command = [cmd.strip() for cmd in job._wrapper_code] prepared_job_command = (f'{{\n{x}\n}}' for x in job_command) - cmd = f''' + cmd = f""" {bash_flags} {make_local_tmpdir} {"; ".join(symlinks)} {" && ".join(prepared_job_command)} -''' +""" user_code = '\n\n'.join(job._user_code) if job._user_code else None if dry_run: - formatted_command = f''' + formatted_command = f""" ================================================================================ # Job {job._job_id} {f": {job.name}" if job.name else ''} @@ -758,7 +822,7 @@ async def compile_job(job): -------------------------------------------------------------------------------- {cmd} ================================================================================ -''' +""" commands.append(formatted_command) continue @@ -785,8 +849,9 @@ async def compile_job(job): image = job._image if job._image else default_image image_ref = parse_docker_image_reference(image) if image_ref.hosted_in('dockerhub') and image_ref.name() not in HAIL_GENETICS_IMAGES: - warnings.warn(f'Using an image {image} from Docker Hub. ' - f'Jobs may fail due to Docker Hub rate limits.') + warnings.warn( + f'Using an image {image} from Docker Hub. ' f'Jobs may fail due to Docker Hub rate limits.' + ) env = {**job._env, 'BATCH_TMPDIR': local_tmpdir} @@ -805,7 +870,7 @@ async def compile_job(job): requester_pays_project=batch.requester_pays_project, user_code=user_code, regions=job._regions, - always_copy_output=job._always_copy_output + always_copy_output=job._always_copy_output, ) n_jobs_submitted += 1 @@ -825,7 +890,7 @@ async def compile_job(job): parents=parents, resources={'cpu': '0.25'}, attributes={'name': 'remove_tmpdir'}, - always_run=True + always_run=True, ) n_jobs_submitted += 1 @@ -843,7 +908,9 @@ async def compile_job(job): jobs_to_command = {j.id: cmd for j, cmd in jobs_to_command.items()} if verbose: - print(f'Submitted batch {batch_id} with {n_jobs_submitted} jobs in {round(time.time() - submit_batch_start, 3)} seconds:') + print( + f'Submitted batch {batch_id} with {n_jobs_submitted} jobs in {round(time.time() - submit_batch_start, 3)} seconds:' + ) for jid, cmd in jobs_to_command.items(): print(f'{jid}: {cmd}') print('') diff --git a/hail/python/hailtop/batch/batch.py b/hail/python/hailtop/batch/batch.py index 79eb47da4fb..c78f5a8bc27 100644 --- a/hail/python/hailtop/batch/batch.py +++ b/hail/python/hailtop/batch/batch.py @@ -1,21 +1,26 @@ import os -import warnings import re -from typing import Callable, Optional, Dict, Union, List, Any, Set +import warnings from io import BytesIO +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Union + import dill -from hailtop.utils import secret_alnum_string, url_scheme, async_to_blocking -from hailtop.aiotools import AsyncFS +import hailtop.batch_client.aioclient as _aiobc +import hailtop.batch_client.client as _bc from hailtop.aiocloud.aioazure.fs import AzureAsyncFS +from hailtop.aiotools import AsyncFS from hailtop.aiotools.router_fs import RouterAsyncFS -import hailtop.batch_client.client as _bc -import hailtop.batch_client.aioclient as _aiobc from hailtop.config import ConfigVariable, configuration_of +from hailtop.utils import async_to_blocking, secret_alnum_string, url_scheme -from . import backend as _backend, job, resource as _resource # pylint: disable=cyclic-import +from . import job +from . 
import resource as _resource from .exceptions import BatchException +if TYPE_CHECKING: + from hailtop.batch.backend import LocalBackend, ServiceBackend + class Batch: """Object representing the distributed acyclic graph (DAG) of jobs to run. @@ -24,7 +29,8 @@ class Batch: -------- Create a batch object: - >>> p = Batch() + >>> import hailtop.batch as hb + >>> p = hb.Batch() Create a new job that prints "hello": @@ -35,6 +41,10 @@ class Batch: >>> p.run() + Require all jobs in this batch to execute in us-central1: + + >>> b = hb.Batch(backend=hb.ServiceBackend(), default_regions=['us-central1']) + Notes ----- @@ -77,6 +87,9 @@ class Batch: default_storage: Storage setting to use by default if not specified by a job. Only applicable for the :class:`.ServiceBackend`. See :meth:`.Job.storage`. + default_regions: + Cloud regions in which jobs may run. When unspecified or ``None``, use the regions attribute of + :class:`.ServiceBackend`. See :class:`.ServiceBackend` for details. default_timeout: Maximum time in seconds for a job to run before being killed. Only applicable for the :class:`.ServiceBackend`. If `None`, there is no @@ -142,26 +155,33 @@ def from_batch_id(batch_id: int, *args, **kwargs) -> 'Batch': @staticmethod async def _async_from_batch_id(batch_id: int, *args, **kwargs) -> 'Batch': + from hailtop.batch.backend import ServiceBackend # pylint: disable=import-outside-toplevel + b = Batch(*args, **kwargs) - assert isinstance(b._backend, _backend.ServiceBackend) + assert isinstance(b._backend, ServiceBackend) b._async_batch = await (await b._backend._batch_client()).get_batch(batch_id) return b - def __init__(self, - name: Optional[str] = None, - backend: Optional[Union[_backend.LocalBackend, _backend.ServiceBackend]] = None, - attributes: Optional[Dict[str, str]] = None, - requester_pays_project: Optional[str] = None, - default_image: Optional[str] = None, - default_memory: Optional[Union[int, str]] = None, - default_cpu: Optional[Union[float, int, str]] = None, - default_storage: Optional[Union[int, str]] = None, - default_timeout: Optional[Union[float, int]] = None, - default_shell: Optional[str] = None, - default_python_image: Optional[str] = None, - default_spot: Optional[bool] = None, - project: Optional[str] = None, - cancel_after_n_failures: Optional[int] = None): + def __init__( + self, + name: Optional[str] = None, + backend: Optional[Union['LocalBackend', 'ServiceBackend']] = None, + attributes: Optional[Dict[str, str]] = None, + requester_pays_project: Optional[str] = None, + default_image: Optional[str] = None, + default_memory: Optional[Union[int, str]] = None, + default_cpu: Optional[Union[float, int, str]] = None, + default_storage: Optional[Union[int, str]] = None, + default_regions: Optional[List[str]] = None, + default_timeout: Optional[Union[float, int]] = None, + default_shell: Optional[str] = None, + default_python_image: Optional[str] = None, + default_spot: Optional[bool] = None, + project: Optional[str] = None, + cancel_after_n_failures: Optional[int] = None, + ): + from hailtop.batch.backend import LocalBackend, ServiceBackend # pylint: disable=import-outside-toplevel + self._jobs: List[job.Job] = [] self._resource_map: Dict[str, _resource.Resource] = {} self._allocated_files: Set[str] = set() @@ -174,10 +194,10 @@ def __init__(self, else: backend_config = configuration_of(ConfigVariable.BATCH_BACKEND, None, 'local') if backend_config == 'service': - self._backend = _backend.ServiceBackend() + self._backend = ServiceBackend() else: assert backend_config == 
'local' - self._backend = _backend.LocalBackend() + self._backend = LocalBackend() self.name = name @@ -193,6 +213,9 @@ def __init__(self, self._default_memory = default_memory self._default_cpu = default_cpu self._default_storage = default_storage + self._default_regions = default_regions + if self._default_regions is None and isinstance(self._backend, ServiceBackend): + self._default_regions = self._backend.regions self._default_timeout = default_timeout self._default_shell = default_shell self._default_python_image = default_python_image @@ -201,7 +224,8 @@ def __init__(self, if project is not None: warnings.warn( 'The project argument to Batch is deprecated, please instead use the google_project argument to ' - 'ServiceBackend. Use of this argument may trigger warnings from aiohttp about unclosed objects.') + 'ServiceBackend. Use of this argument may trigger warnings from aiohttp about unclosed objects.' + ) self._DEPRECATED_project = project self._DEPRECATED_fs: Optional[RouterAsyncFS] = None @@ -242,13 +266,9 @@ async def _serialize_python_to_input_file( return code_input_file - async def _serialize_python_functions_to_input_files( - self, path: str, dry_run: bool = False - ) -> None: + async def _serialize_python_functions_to_input_files(self, path: str, dry_run: bool = False) -> None: for function_id, function in self._python_function_defs.items(): - file = await self._serialize_python_to_input_file( - path, "functions", function_id, function, dry_run - ) + file = await self._serialize_python_to_input_file(path, "functions", function_id, function, dry_run) self._python_function_files[function_id] = file def _unique_job_token(self, n=5): @@ -266,20 +286,18 @@ def _fs(self) -> AsyncFS: return self._DEPRECATED_fs return self._backend._fs - def new_job(self, - name: Optional[str] = None, - attributes: Optional[Dict[str, str]] = None, - shell: Optional[str] = None) -> job.BashJob: + def new_job( + self, name: Optional[str] = None, attributes: Optional[Dict[str, str]] = None, shell: Optional[str] = None + ) -> job.BashJob: """ Alias for :meth:`.Batch.new_bash_job` """ return self.new_bash_job(name, attributes, shell) - def new_bash_job(self, - name: Optional[str] = None, - attributes: Optional[Dict[str, str]] = None, - shell: Optional[str] = None) -> job.BashJob: + def new_bash_job( + self, name: Optional[str] = None, attributes: Optional[Dict[str, str]] = None, shell: Optional[str] = None + ) -> job.BashJob: """ Initialize a :class:`.BashJob` object with default memory, storage, image, and CPU settings (defined in :class:`.Batch`) upon batch creation. 
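Batch-level defaults are copied onto each job as it is created, so ``default_cpu``, ``default_memory``, ``default_storage``, the new ``default_regions``, ``default_timeout``, and ``default_spot`` all take effect unless a job overrides them later. A short sketch of the new ``default_regions`` parameter, assuming the billing project and remote_tmpdir come from ``hailctl config`` (the region and job names are only examples)::

    import hailtop.batch as hb

    b = hb.Batch(
        backend=hb.ServiceBackend(),
        default_cpu='0.5',
        default_memory='standard',
        default_regions=['us-central1'],   # inherited by every job unless j.regions(...) overrides it
    )
    j = b.new_bash_job(name='say-hello')
    j.command('echo hello world!')
    # b.run()  # submission elided so the sketch has no side effects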
@@ -319,20 +337,17 @@ def new_bash_job(self, j.cpu(self._default_cpu) if self._default_storage is not None: j.storage(self._default_storage) + if self._default_regions is not None: + j.regions(self._default_regions) if self._default_timeout is not None: j.timeout(self._default_timeout) if self._default_spot is not None: j.spot(self._default_spot) - if isinstance(self._backend, _backend.ServiceBackend): - j.regions(self._backend.regions) - self._jobs.append(j) return j - def new_python_job(self, - name: Optional[str] = None, - attributes: Optional[Dict[str, str]] = None) -> job.PythonJob: + def new_python_job(self, name: Optional[str] = None, attributes: Optional[Dict[str, str]] = None) -> job.PythonJob: """ Initialize a new :class:`.PythonJob` object with default Python image, memory, storage, and CPU settings (defined in :class:`.Batch`) @@ -393,14 +408,13 @@ def hello(name): j.cpu(self._default_cpu) if self._default_storage is not None: j.storage(self._default_storage) + if self._default_regions is not None: + j.regions(self._default_regions) if self._default_timeout is not None: j.timeout(self._default_timeout) if self._default_spot is not None: j.spot(self._default_spot) - if isinstance(self._backend, _backend.ServiceBackend): - j.regions(self._backend.regions) - self._jobs.append(j) return j @@ -412,6 +426,14 @@ def _new_job_resource_file(self, source, value=None): return jrf def _new_input_resource_file(self, input_path, root=None): + if isinstance(input_path, str): + pass + elif isinstance(input_path, os.PathLike): + # Avoid os.fspath(), which causes some pathlikes to return a path to a downloaded copy instead. + input_path = str(input_path) + else: + raise BatchException(f"path value is neither string nor path-like. Found '{type(input_path)}' instead.") + self._backend.validate_file(input_path, self.requester_pays_project) # Take care not to include an Azure SAS token query string in the local name. @@ -452,7 +474,7 @@ def _new_python_result(self, source, value=None) -> _resource.PythonResult: self._resource_map[jrf._uid] = jrf # pylint: disable=no-member return jrf - def read_input(self, path: str) -> _resource.InputResourceFile: + def read_input(self, path: Union[str, os.PathLike]) -> _resource.InputResourceFile: """ Create a new input resource file object representing a single file. @@ -482,7 +504,7 @@ def read_input(self, path: str) -> _resource.InputResourceFile: irf = self._new_input_resource_file(path) return irf - def read_input_group(self, **kwargs: str) -> _resource.ResourceGroup: + def read_input_group(self, **kwargs: Union[str, os.PathLike]) -> _resource.ResourceGroup: """Create a new resource group representing a mapping of identifier to input resource files. @@ -605,25 +627,34 @@ def write_output(self, resource: _resource.Resource, dest: str): where `identifier` is the identifier of the file in the :class:`.ResourceGroup` map. """ + from hailtop.batch.backend import LocalBackend # pylint: disable=import-outside-toplevel if not isinstance(resource, _resource.Resource): raise BatchException(f"'write_output' only accepts Resource inputs. 
Found '{type(resource)}'.") - if (isinstance(resource, _resource.JobResourceFile) - and isinstance(resource._source, job.BashJob) - and resource not in resource._source._mentioned): + if ( + isinstance(resource, _resource.JobResourceFile) + and isinstance(resource._source, job.BashJob) + and resource not in resource._source._mentioned + ): name = resource._source._resources_inverse[resource] - raise BatchException(f"undefined resource '{name}'\n" - f"Hint: resources must be defined within the " - f"job methods 'command' or 'declare_resource_group'") - if (isinstance(resource, _resource.PythonResult) - and isinstance(resource._source, job.PythonJob) - and resource not in resource._source._mentioned): + raise BatchException( + f"undefined resource '{name}'\n" + f"Hint: resources must be defined within the " + f"job methods 'command' or 'declare_resource_group'" + ) + if ( + isinstance(resource, _resource.PythonResult) + and isinstance(resource._source, job.PythonJob) + and resource not in resource._source._mentioned + ): name = resource._source._resources_inverse[resource] - raise BatchException(f"undefined resource '{name}'\n" - f"Hint: resources must be bound as a result " - f"using the PythonJob 'call' method") + raise BatchException( + f"undefined resource '{name}'\n" + f"Hint: resources must be bound as a result " + f"using the PythonJob 'call' method" + ) - if isinstance(self._backend, _backend.LocalBackend): + if isinstance(self._backend, LocalBackend): dest_scheme = url_scheme(dest) if dest_scheme == '': dest = os.path.abspath(os.path.expanduser(dest)) @@ -655,11 +686,9 @@ def select_jobs(self, pattern: str) -> List[job.Job]: return [job for job in self._jobs if job.name is not None and re.match(pattern, job.name) is not None] # Do not try to overload this based on dry_run. LocalBackend.run also returns None. - def run(self, - dry_run: bool = False, - verbose: bool = False, - delete_scratch_on_exit: bool = True, - **backend_kwargs: Any) -> Optional[_bc.Batch]: + def run( + self, dry_run: bool = False, verbose: bool = False, delete_scratch_on_exit: bool = True, **backend_kwargs: Any + ) -> Optional[_bc.Batch]: """ Execute a batch. @@ -689,11 +718,7 @@ def run(self, # Do not try to overload this based on dry_run. LocalBackend.run also returns None. 
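``write_output`` only accepts resources the batch already knows about, and for job outputs it additionally checks that the producing job actually defined them inside ``command`` or ``declare_resource_group``. A minimal happy-path sketch against the ``LocalBackend``, run with ``dry_run=True`` so the planned work is printed rather than executed (both file paths are placeholders)::

    import hailtop.batch as hb

    b = hb.Batch(backend=hb.LocalBackend())
    data = b.read_input('/tmp/input.txt')      # str or os.PathLike are both accepted after this change
    j = b.new_bash_job(name='copy-input')
    j.command(f'cat {data} > {j.ofile}')       # referencing j.ofile inside command() defines the resource
    b.write_output(j.ofile, '/tmp/output.txt')
    b.run(dry_run=True, verbose=True)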
async def _async_run( - self, - dry_run: bool = False, - verbose: bool = False, - delete_scratch_on_exit: bool = True, - **backend_kwargs: Any + self, dry_run: bool = False, verbose: bool = False, delete_scratch_on_exit: bool = True, **backend_kwargs: Any ) -> Optional[_bc.Batch]: seen = set() ordered_jobs = [] @@ -727,6 +752,5 @@ def schedule_job(j): self._DEPRECATED_fs = None return run_result - def __str__(self): return self._uid diff --git a/hail/python/hailtop/batch/batch_pool_executor.py b/hail/python/hailtop/batch/batch_pool_executor.py index 752546d49fe..b562f0a98eb 100644 --- a/hail/python/hailtop/batch/batch_pool_executor.py +++ b/hail/python/hailtop/batch/batch_pool_executor.py @@ -1,19 +1,20 @@ -from typing import Optional, Callable, Type, Union, List, Any, Iterable, AsyncGenerator -from types import TracebackType -from io import BytesIO -import warnings import asyncio import concurrent.futures -import dill import functools +import warnings +from io import BytesIO +from types import TracebackType +from typing import Any, AsyncGenerator, Callable, Iterable, List, Optional, Type, Union + +import dill -from hailtop.utils import secret_alnum_string, partition, async_to_blocking import hailtop.batch_client.aioclient as low_level_batch_client -from hailtop.batch_client.parse import parse_cpu_in_mcpu from hailtop.aiotools.router_fs import RouterAsyncFS +from hailtop.batch_client.parse import parse_cpu_in_mcpu +from hailtop.utils import async_to_blocking, partition, secret_alnum_string, the_empty_async_generator +from .backend import HAIL_GENETICS_HAIL_IMAGE, ServiceBackend from .batch import Batch -from .backend import ServiceBackend, HAIL_GENETICS_HAIL_IMAGE def cpu_spec_to_float(spec: Union[int, str]) -> float: @@ -27,6 +28,7 @@ def cpu_spec_to_float(spec: Union[int, str]) -> float: def chunk(fn): def chunkedfn(*args): return [fn(*arglist) for arglist in zip(*args)] + return chunkedfn @@ -102,14 +104,17 @@ class BatchPoolExecutor: DEPRECATED. Please specify gcs_requester_pays_configuration in :class:`.ServiceBackend`. """ - def __init__(self, *, - name: Optional[str] = None, - backend: Optional[ServiceBackend] = None, - image: Optional[str] = None, - cpus_per_job: Optional[Union[int, str]] = None, - wait_on_exit: bool = True, - cleanup_bucket: bool = True, - project: Optional[str] = None): + def __init__( + self, + *, + name: Optional[str] = None, + backend: Optional[ServiceBackend] = None, + image: Optional[str] = None, + cpus_per_job: Optional[Union[int, str]] = None, + wait_on_exit: bool = True, + cleanup_bucket: bool = True, + project: Optional[str] = None, + ): self.name = name or "BatchPoolExecutor-" + secret_alnum_string(4) self.backend = backend or ServiceBackend() if not isinstance(self.backend, ServiceBackend): @@ -148,11 +153,9 @@ def _fs(self) -> RouterAsyncFS: def __enter__(self): return self - def map(self, - fn: Callable, - *iterables: Iterable[Any], - timeout: Optional[Union[int, float]] = None, - chunksize: int = 1): + def map( + self, fn: Callable, *iterables: Iterable[Any], timeout: Optional[Union[int, float]] = None, chunksize: int = 1 + ): """Call `fn` on cloud machines with arguments from `iterables`. This function returns a generator which will produce each result in the @@ -210,8 +213,7 @@ def map(self, amount of meaningful work done per-container. 
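``chunksize`` trades scheduling overhead against parallelism: with ``chunksize=k``, ``k`` consecutive argument tuples are shipped to a single cloud job and evaluated there in a loop, so fewer containers each do more work. A usage sketch, assuming the billing project and remote temporary directory are already configured via ``hailctl config``::

    import hailtop.batch as hb

    def square(x):
        return x * x

    with hb.BatchPoolExecutor(name='squares') as bpe:
        # 100 calls packed into roughly 10 jobs of 10 calls each
        results = list(bpe.map(square, range(100), chunksize=10))

    print(results[:5])  # [0, 1, 4, 9, 16]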
""" - agen = async_to_blocking( - self.async_map(fn, iterables, timeout=timeout, chunksize=chunksize)) + agen = async_to_blocking(self.async_map(fn, iterables, timeout=timeout, chunksize=chunksize)) def generator_from_async_generator(aiter): try: @@ -219,17 +221,19 @@ def generator_from_async_generator(aiter): yield async_to_blocking(aiter.__anext__()) except StopAsyncIteration: return + return generator_from_async_generator(agen.__aiter__()) - async def async_map(self, - fn: Callable, - iterables: Iterable[Iterable[Any]], - timeout: Optional[Union[int, float]] = None, - chunksize: int = 1 - ) -> AsyncGenerator[int, None]: + async def async_map( + self, + fn: Callable, + iterables: Iterable[Iterable[Any]], + timeout: Optional[Union[int, float]] = None, + chunksize: int = 1, + ) -> AsyncGenerator[int, None]: """Aysncio compatible version of :meth:`.map`.""" if not iterables: - return (x for x in range(0)) + return the_empty_async_generator() if chunksize > 1: list_per_argument = [list(x) for x in iterables] @@ -237,13 +241,11 @@ async def async_map(self, assert all(n == len(x) for x in list_per_argument) n_chunks = (n + chunksize - 1) // chunksize iterables_chunks = [list(partition(n_chunks, x)) for x in list_per_argument] - iterables_chunks = [ - chunk for chunk in iterables_chunks if len(chunk) > 0] + iterables_chunks = [chunk for chunk in iterables_chunks if len(chunk) > 0] fn = chunk(fn) iterables = iterables_chunks - submit_tasks = [asyncio.ensure_future(self.async_submit(fn, *arguments)) - for arguments in zip(*iterables)] + submit_tasks = [asyncio.ensure_future(self.async_submit(fn, *arguments)) for arguments in zip(*iterables)] try: bp_futures = [await t for t in submit_tasks] except: @@ -262,18 +264,11 @@ async def async_result_or_cancel_all(future): raise if chunksize > 1: - return (val - for future in bp_futures - for val in await async_result_or_cancel_all(future)) - - return (await async_result_or_cancel_all(future) - for future in bp_futures) - - def submit(self, - fn: Callable, - *args: Any, - **kwargs: Any - ) -> 'BatchPoolFuture': + return (val for future in bp_futures for val in await async_result_or_cancel_all(future)) + + return (await async_result_or_cancel_all(future) for future in bp_futures) + + def submit(self, fn: Callable, *args: Any, **kwargs: Any) -> 'BatchPoolFuture': """Call `fn` on a cloud machine with all remaining arguments and keyword arguments. The function, any objects it references, the arguments, and the keyword @@ -326,14 +321,9 @@ def submit(self, kwargs: Keyword arguments for the function. 
""" - return async_to_blocking( - self.async_submit(fn, *args, **kwargs)) - - async def async_submit(self, - unapplied: Callable, - *args: Any, - **kwargs: Any - ) -> 'BatchPoolFuture': + return async_to_blocking(self.async_submit(fn, *args, **kwargs)) + + async def async_submit(self, unapplied: Callable, *args: Any, **kwargs: Any) -> 'BatchPoolFuture': """Aysncio compatible version of :meth:`BatchPoolExecutor.submit`.""" if self._shutdown: raise RuntimeError('BatchPoolExecutor has already been shutdown.') @@ -349,11 +339,10 @@ async def async_submit(self, def run_async(*args, **kwargs): return asyncio.run(unapplied_copy(*args, **kwargs)) + unapplied = run_async - batch = Batch(name=self.name + '-' + name, - backend=self.backend, - default_image=self.image) + batch = Batch(name=self.name + '-' + name, backend=self.backend, default_image=self.image) self.batches.append(batch) j = batch.new_job(name) @@ -394,19 +383,19 @@ def run_async(*args, **kwargs): assert sync_batch is not None backend_batch = sync_batch._async_batch try: - return BatchPoolFuture(self, - backend_batch, - low_level_batch_client.Job.submitted_job( - backend_batch, 1), - output_gcs) + return BatchPoolFuture( + self, backend_batch, low_level_batch_client.Job.submitted_job(backend_batch, 1), output_gcs + ) except: await backend_batch.cancel() raise - def __exit__(self, - exc_type: Optional[Type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType]): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ): self.shutdown(wait=self.wait_on_exit) def _add_future(self, f): @@ -431,13 +420,14 @@ def shutdown(self, wait: bool = True): method. """ if wait: + async def ignore_exceptions(f): try: await f.async_result() except Exception: pass - async_to_blocking( - asyncio.gather(*[ignore_exceptions(f) for f in self.futures])) + + async_to_blocking(asyncio.gather(*[ignore_exceptions(f) for f in self.futures])) if self.finished_future_count == len(self.futures): self._cleanup() self._shutdown = True @@ -451,11 +441,13 @@ def _cleanup(self): class BatchPoolFuture: - def __init__(self, - executor: BatchPoolExecutor, - batch: low_level_batch_client.Batch, - job: low_level_batch_client.Job, - output_file: str): + def __init__( + self, + executor: BatchPoolExecutor, + batch: low_level_batch_client.Batch, + job: low_level_batch_client.Job, + output_file: str, + ): self.executor = executor self.batch = batch self.job = job @@ -491,8 +483,7 @@ async def async_cancel(self): return True def cancelled(self): - """Returns ``True`` if :meth:`.cancel` was called before a value was produced. - """ + """Returns ``True`` if :meth:`.cancel` was called before a value was produced.""" return self.fetch_coro.cancelled() def running(self): @@ -503,8 +494,7 @@ def running(self): return False def done(self): - """Returns `True` if the function is complete and not cancelled. 
- """ + """Returns `True` if the function is complete and not cancelled.""" return self.fetch_coro.done() def result(self, timeout: Optional[Union[float, int]] = None): @@ -546,15 +536,14 @@ async def _async_fetch_result(self): status = await self.job.wait() main_container_status = status['status']['container_statuses']['main'] if main_container_status['state'] == 'error': - raise ValueError( - f"submitted job failed:\n{main_container_status['error']}") + raise ValueError(f"submitted job failed:\n{main_container_status['error']}") try: - value, traceback = dill.loads( - await self.executor._fs.read(self.output_file)) + value, traceback = dill.loads(await self.executor._fs.read(self.output_file)) except FileNotFoundError as exc: job_log = await self.job.log() raise ValueError( - f"submitted job did not write output:\n{main_container_status}\n\nLog:\n{job_log}") from exc + f"submitted job did not write output:\n{main_container_status}\n\nLog:\n{job_log}" + ) from exc if traceback is None: return value assert isinstance(value, BaseException) @@ -566,13 +555,11 @@ async def _async_fetch_result(self): self.executor._finish_future() def exception(self, timeout: Optional[Union[float, int]] = None): - """Block until the job is complete and raise any exceptions. - """ + """Block until the job is complete and raise any exceptions.""" if self.cancelled(): raise concurrent.futures.CancelledError() self.result(timeout) def add_done_callback(self, _): - """NOT IMPLEMENTED - """ + """NOT IMPLEMENTED""" raise NotImplementedError() diff --git a/hail/python/hailtop/batch/conftest.py b/hail/python/hailtop/batch/conftest.py index dc5f4bc6291..daa6cd6da60 100644 --- a/hail/python/hailtop/batch/conftest.py +++ b/hail/python/hailtop/batch/conftest.py @@ -1,5 +1,6 @@ import doctest import os + import pytest from hailtop import batch @@ -11,9 +12,11 @@ def patch_doctest_check_output(monkeypatch): base_check_output = doctest.OutputChecker.check_output def patched_check_output(self, want, got, optionflags): - return ((not want) - or (want.strip() == 'None') - or base_check_output(self, want, got, optionflags | doctest.NORMALIZE_WHITESPACE)) + return ( + (not want) + or (want.strip() == 'None') + or base_check_output(self, want, got, optionflags | doctest.NORMALIZE_WHITESPACE) + ) monkeypatch.setattr('doctest.OutputChecker.check_output', patched_check_output) yield @@ -28,8 +31,7 @@ def init(doctest_namespace): doctest_namespace['Batch'] = batch.Batch olddir = os.getcwd() - os.chdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), - "docs")) + os.chdir(os.path.join(os.path.dirname(os.path.realpath(__file__)), "docs")) try: print("finished setting up doctest...") yield diff --git a/hail/python/hailtop/batch/docker.py b/hail/python/hailtop/batch/docker.py index 2333a99cde5..db56a3e70d9 100644 --- a/hail/python/hailtop/batch/docker.py +++ b/hail/python/hailtop/batch/docker.py @@ -1,17 +1,19 @@ +import os import shutil import sys -import os -from typing import Optional, List +from typing import List, Optional from ..utils import secret_alnum_string, sync_check_exec -def build_python_image(fullname: str, - requirements: Optional[List[str]] = None, - python_version: Optional[str] = None, - _tmp_dir: str = '/tmp', - *, - show_docker_output: bool = False) -> str: +def build_python_image( + fullname: str, + requirements: Optional[List[str]] = None, + python_version: Optional[str] = None, + _tmp_dir: str = '/tmp', + *, + show_docker_output: bool = False, +) -> str: """ Build a new Python image with dill and the 
specified pip packages installed. @@ -56,7 +58,8 @@ def build_python_image(fullname: str, if major_version != 3 or minor_version < 9: raise ValueError( - f'Python versions older than 3.9 (you are using {major_version}.{minor_version}) are not supported') + f'Python versions older than 3.9 (you are using {major_version}.{minor_version}) are not supported' + ) base_image = f'hailgenetics/python-dill:{major_version}.{minor_version}-slim' @@ -73,14 +76,14 @@ def build_python_image(fullname: str, f.write('\n'.join(requirements) + '\n') with open(f'{docker_path}/Dockerfile', 'w', encoding='utf-8') as f: - f.write(f''' + f.write(f""" FROM {base_image} COPY requirements.txt . RUN pip install --upgrade --no-cache-dir -r requirements.txt && \ python3 -m pip check -''') +""") sync_check_exec('docker', 'build', '-t', fullname, docker_path, capture_output=not show_docker_output) print(f'finished building image {fullname}') diff --git a/hail/python/hailtop/batch/docs/change_log.rst b/hail/python/hailtop/batch/docs/change_log.rst index 4e6ff3e665a..053ae6e5daf 100644 --- a/hail/python/hailtop/batch/docs/change_log.rst +++ b/hail/python/hailtop/batch/docs/change_log.rst @@ -15,6 +15,19 @@ versions. In particular, Hail officially supports: Change Log ========== +**Version 0.2.130** + +- (`#14425 `__) A job's 'always run' + state is rendered in the Job and Batch pages. This makes it easier to understand + why a job is queued to run when others have failed or been cancelled. +- (`#14437 `__) The billing page now + reports users' spend on the batch service. + +**Version 0.2.128** + +- (`#14224 `__) `hb.Batch` now accepts a + `default_regions` argument which is the default for all jobs in the Batch. + **Version 0.2.124** - (`#13681 `__) Fix `hailctl batch init` and `hailctl auth login` for diff --git a/hail/python/hailtop/batch/docs/conf.py b/hail/python/hailtop/batch/docs/conf.py index 72fb60ca3de..c6d2d99507e 100644 --- a/hail/python/hailtop/batch/docs/conf.py +++ b/hail/python/hailtop/batch/docs/conf.py @@ -15,8 +15,8 @@ # import os # import sys # sys.path.insert(0, os.path.abspath('.')) -import inspect import datetime +import inspect # -- Project information ----------------------------------------------------- @@ -127,6 +127,7 @@ # -- Extension configuration ------------------------------------------------- + def get_class_that_defined_method(meth): if inspect.ismethod(meth): for cls in inspect.getmro(meth.__self__.__class__): @@ -134,8 +135,7 @@ def get_class_that_defined_method(meth): return cls meth = meth.__func__ # fallback to __qualname__ parsing if inspect.isfunction(meth): - cls = getattr(inspect.getmodule(meth), - meth.__qualname__.split('.', 1)[0].rsplit('.', 1)[0]) + cls = getattr(inspect.getmodule(meth), meth.__qualname__.split('.', 1)[0].rsplit('.', 1)[0]) if isinstance(cls, type): return cls return getattr(meth, '__objclass__', None) # handle special descriptor objects @@ -149,19 +149,37 @@ def has_docstring(obj): def autodoc_skip_member(app, what, name, obj, skip, options): - exclusions = ('__delattr__', '__dict__', '__dir__', '__doc__', '__format__', - '__getattribute__', '__hash__', '__init__', - '__init_subclass__', '__new__', '__reduce__', '__reduce_ex__', - '__repr__', '__setattr__', '__sizeof__', '__str__', - '__subclasshook__', '__weakref__', 'maketrans') + exclusions = ( + '__delattr__', + '__dict__', + '__dir__', + '__doc__', + '__format__', + '__getattribute__', + '__hash__', + '__init__', + '__init_subclass__', + '__new__', + '__reduce__', + '__reduce_ex__', + '__repr__', + 
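A sketch of calling `build_python_image` with the signature shown above; the registry path is a placeholder, and a local Docker daemon with push access is assumed:

    from hailtop.batch.docker import build_python_image

    image = build_python_image(
        'us-docker.pkg.dev/my-project/my-repo/batch-python:latest',  # placeholder
        requirements=['numpy', 'pandas'],
        show_docker_output=True,
    )
    # the returned name can then be used, e.g., as a Batch default_python_image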
'__setattr__', + '__sizeof__', + '__str__', + '__subclasshook__', + '__weakref__', + 'maketrans', + ) excluded_classes = ('str',) cls = get_class_that_defined_method(obj) - exclude = (name in exclusions - or (name.startswith('_') and not has_docstring(obj)) - or (cls and cls.__name__ in excluded_classes)) + exclude = ( + name in exclusions + or (name.startswith('_') and not has_docstring(obj)) + or (cls and cls.__name__ in excluded_classes) + ) return exclude diff --git a/hail/python/hailtop/batch/docs/cookbook/files/batch_clumping.py b/hail/python/hailtop/batch/docs/cookbook/files/batch_clumping.py index eec7c9a0f24..445f31d7d3c 100644 --- a/hail/python/hailtop/batch/docs/cookbook/files/batch_clumping.py +++ b/hail/python/hailtop/batch/docs/cookbook/files/batch_clumping.py @@ -9,19 +9,16 @@ def gwas(batch, vcf, phenotypes): g = batch.new_job(name='run-gwas') g.image('us-docker.pkg.dev//1kg-gwas:latest') g.cpu(cores) - g.declare_resource_group(ofile={ - 'bed': '{root}.bed', - 'bim': '{root}.bim', - 'fam': '{root}.fam', - 'assoc': '{root}.assoc' - }) - g.command(f''' + g.declare_resource_group( + ofile={'bed': '{root}.bed', 'bim': '{root}.bim', 'fam': '{root}.fam', 'assoc': '{root}.assoc'} + ) + g.command(f""" python3 /run_gwas.py \ --vcf {vcf} \ --phenotypes {phenotypes} \ --output-file {g.ofile} \ --cores {cores} -''') +""") return g @@ -32,7 +29,7 @@ def clump(batch, bfile, assoc, chr): c = batch.new_job(name=f'clump-{chr}') c.image('hailgenetics/genetics:0.2.37') c.memory('1Gi') - c.command(f''' + c.command(f""" plink --bfile {bfile} \ --clump {assoc} \ --chr {chr} \ @@ -43,7 +40,7 @@ def clump(batch, bfile, assoc, chr): --memory 1024 mv plink.clumped {c.clumped} -''') +""") return c @@ -54,14 +51,14 @@ def merge(batch, results): merger = batch.new_job(name='merge-results') merger.image('ubuntu:22.04') if results: - merger.command(f''' + merger.command(f""" head -n 1 {results[0]} > {merger.ofile} for result in {" ".join(results)} do tail -n +2 "$result" >> {merger.ofile} done sed -i -e '/^$/d' {merger.ofile} -''') +""") return merger diff --git a/hail/python/hailtop/batch/docs/cookbook/files/run_gwas.py b/hail/python/hailtop/batch/docs/cookbook/files/run_gwas.py index 9349b7aaa7d..dc4081eba14 100644 --- a/hail/python/hailtop/batch/docs/cookbook/files/run_gwas.py +++ b/hail/python/hailtop/batch/docs/cookbook/files/run_gwas.py @@ -1,4 +1,5 @@ import argparse + import hail as hl @@ -12,9 +13,11 @@ def run_gwas(vcf_file, phenotypes_file, output_file): mt = hl.sample_qc(mt) mt = mt.filter_cols((mt.sample_qc.dp_stats.mean >= 4) & (mt.sample_qc.call_rate >= 0.97)) ab = mt.AD[1] / hl.sum(mt.AD) - filter_condition_ab = ((mt.GT.is_hom_ref() & (ab <= 0.1)) - | (mt.GT.is_het() & (ab >= 0.25) & (ab <= 0.75)) - | (mt.GT.is_hom_var() & (ab >= 0.9))) + filter_condition_ab = ( + (mt.GT.is_hom_ref() & (ab <= 0.1)) + | (mt.GT.is_het() & (ab >= 0.25) & (ab <= 0.75)) + | (mt.GT.is_hom_var() & (ab >= 0.9)) + ) mt = mt.filter_entries(filter_condition_ab) mt = hl.variant_qc(mt) mt = mt.filter_rows(mt.variant_qc.AF[1] > 0.01) @@ -26,7 +29,8 @@ def run_gwas(vcf_file, phenotypes_file, output_file): gwas = hl.linear_regression_rows( y=mt.pheno.CaffeineConsumption, x=mt.GT.n_alt_alleles(), - covariates=[1.0, mt.pheno.isFemale, mt.scores[0], mt.scores[1], mt.scores[2]]) + covariates=[1.0, mt.pheno.isFemale, mt.scores[0], mt.scores[1], mt.scores[2]], + ) gwas = gwas.select(SNP=hl.variant_str(gwas.locus, gwas.alleles), P=gwas.p_value) gwas = gwas.key_by(gwas.SNP) diff --git 
a/hail/python/hailtop/batch/docs/cookbook/files/run_rf_checkpoint.py b/hail/python/hailtop/batch/docs/cookbook/files/run_rf_checkpoint.py index 12cc07fd6f6..2394e1dbe63 100644 --- a/hail/python/hailtop/batch/docs/cookbook/files/run_rf_checkpoint.py +++ b/hail/python/hailtop/batch/docs/cookbook/files/run_rf_checkpoint.py @@ -1,9 +1,11 @@ -import hailtop.batch as hb -import hailtop.fs as hfs -import pandas as pd from typing import Tuple + +import pandas as pd from sklearn.ensemble import RandomForestRegressor +import hailtop.batch as hb +import hailtop.fs as hfs + def random_forest(df_x_path: str, df_y_path: str, window_name: str, cores: int = 1) -> Tuple[str, float, float]: # read in data @@ -19,11 +21,7 @@ def random_forest(df_x_path: str, df_y_path: str, window_name: str, cores: int = # run random forest max_features = 3 / 4 - rf = RandomForestRegressor(n_estimators=100, - n_jobs=cores, - max_features=max_features, - oob_score=True, - verbose=False) + rf = RandomForestRegressor(n_estimators=100, n_jobs=cores, max_features=max_features, oob_score=True, verbose=False) rf.fit(x_train, y_train) diff --git a/hail/python/hailtop/batch/docs/cookbook/files/run_rf_checkpoint_batching.py b/hail/python/hailtop/batch/docs/cookbook/files/run_rf_checkpoint_batching.py index 1da8f70ba5e..e2fcd2c59fe 100644 --- a/hail/python/hailtop/batch/docs/cookbook/files/run_rf_checkpoint_batching.py +++ b/hail/python/hailtop/batch/docs/cookbook/files/run_rf_checkpoint_batching.py @@ -1,9 +1,11 @@ +from typing import Tuple + +import pandas as pd +from sklearn.ensemble import RandomForestRegressor + import hailtop.batch as hb import hailtop.fs as hfs from hailtop.utils import grouped -import pandas as pd -from typing import Tuple -from sklearn.ensemble import RandomForestRegressor def random_forest(df_x_path: str, df_y_path: str, window_name: str, cores: int = 1) -> Tuple[str, float, float]: @@ -20,11 +22,7 @@ def random_forest(df_x_path: str, df_y_path: str, window_name: str, cores: int = # run random forest max_features = 3 / 4 - rf = RandomForestRegressor(n_estimators=100, - n_jobs=cores, - max_features=max_features, - oob_score=True, - verbose=False) + rf = RandomForestRegressor(n_estimators=100, n_jobs=cores, max_features=max_features, oob_score=True, verbose=False) rf.fit(x_train, y_train) diff --git a/hail/python/hailtop/batch/docs/cookbook/files/run_rf_simple.py b/hail/python/hailtop/batch/docs/cookbook/files/run_rf_simple.py index 285c27a2843..eacd660d5e7 100644 --- a/hail/python/hailtop/batch/docs/cookbook/files/run_rf_simple.py +++ b/hail/python/hailtop/batch/docs/cookbook/files/run_rf_simple.py @@ -1,9 +1,11 @@ -import hailtop.batch as hb -import hailtop.fs as hfs -import pandas as pd from typing import Tuple + +import pandas as pd from sklearn.ensemble import RandomForestRegressor +import hailtop.batch as hb +import hailtop.fs as hfs + def random_forest(df_x_path: str, df_y_path: str, window_name: str, cores: int = 1) -> Tuple[str, float, float]: # read in data @@ -19,11 +21,7 @@ def random_forest(df_x_path: str, df_y_path: str, window_name: str, cores: int = # run random forest max_features = 3 / 4 - rf = RandomForestRegressor(n_estimators=100, - n_jobs=cores, - max_features=max_features, - oob_score=True, - verbose=False) + rf = RandomForestRegressor(n_estimators=100, n_jobs=cores, max_features=max_features, oob_score=True, verbose=False) rf.fit(x_train, y_train) diff --git a/hail/python/hailtop/batch/docs/service.rst b/hail/python/hailtop/batch/docs/service.rst index f20beb1ce4c..c304b13b173 100644 --- 
a/hail/python/hailtop/batch/docs/service.rst +++ b/hail/python/hailtop/batch/docs/service.rst @@ -91,21 +91,12 @@ has the following prefix `us-docker.pkg.dev/`: gcloud artifacts repositories add-iam-policy-binding \ --member= --role=roles/artifactregistry.repoAdmin -If you want to run gcloud commands within your Batch jobs, the service account file is available in -the main container with its path specified in the `$GOOGLE_APPLICATION_CREDENTIALS` environment -variable. You can authenticate using the service account by adding -the following line to your user code and using a Docker image that has gcloud installed. - -.. code-block:: sh - - gcloud -q auth activate-service-account --key-file=$GOOGLE_APPLICATION_CREDENTIALS - Billing ------- The cost for executing a job depends on the underlying machine type, the region in which the VM is running in, -and how much CPU and memory is being requested. Currently, Batch runs most jobs on 16 core, preemptible, n1 +and how much CPU and memory is being requested. Currently, Batch runs most jobs on 16 core, spot, n1 machines with 10 GB of persistent SSD boot disk and 375 GB of local SSD. The costs are as follows: - Compute cost @@ -116,11 +107,11 @@ machines with 10 GB of persistent SSD boot disk and 375 GB of local SSD. The cos based on the current spot prices for a given worker type and the region in which the worker is running in. You can use :meth:`.Job.regions` to specify which regions to run a job in. - = $0.01 per core per hour for **preemptible standard** worker types + = $0.01 per core per hour for **spot standard** worker types - = $0.012453 per core per hour for **preemptible highmem** worker types + = $0.012453 per core per hour for **spot highmem** worker types - = $0.0074578 per core per hour for **preemptible highcpu** worker types + = $0.0074578 per core per hour for **spot highcpu** worker types = $0.04749975 per core per hour for **nonpreemptible standard** worker types @@ -163,22 +154,27 @@ machines with 10 GB of persistent SSD boot disk and 375 GB of local SSD. The cos - IP network cost - = $0.00025 per core per hour + = $0.0003125 per core per hour for **nonpreemptible** worker types + + = $0.00015625 per core per hour for **spot** worker types - Service cost = $0.01 per core per hour +- Logs, Specs, and Firewall Fee + = $0.005 per core per hour -The sum of these costs is **$0.021935** per core/hour for standard workers, **$0.024388** per core/hour -for highmem workers, and **$0.019393** per core/hour for highcpu workers. There is also an additional + +The sum of these costs is **$0.02684125** per core/hour for standard spot workers, **$0.02929425** per core/hour +for highmem spot workers, and **$0.02429905** per core/hour for highcpu spot workers. There is also an additional cost of **$0.00023** per GB per hour of extra storage requested. At any given moment as many as four cores of the cluster may come from a 4 core machine if the worker type is standard. If a job is scheduled on this machine, then the cost per core hour is **$0.02774** plus **$0.00023** per GB per hour storage of extra storage requested. -For jobs that run on non-preemptible machines, the costs are **$0.060462** per core/hour for standard workers, **$0.072114** per core/hour -for highmem workers, and **$0.048365** per core/hour for highcpu workers. 
+For jobs that run on non-preemptible machines, the costs are **$0.06449725** per core/hour for standard workers, **$0.076149** per core/hour +for highmem workers, and **$0.0524218** per core/hour for highcpu workers. .. note:: @@ -227,22 +223,15 @@ error messages in the terminal window. Submitting a Batch to the Service --------------------------------- +.. warning:: + + To avoid substantial network costs, ensure your jobs and data reside in the same `region`_. + To execute a batch on the Batch service rather than locally, first construct a :class:`.ServiceBackend` object with a billing project and bucket for storing intermediate files. Your service account must have read and write access to the bucket. -.. warning:: - - By default, the Batch Service runs jobs in any region in the US. Make sure you have considered additional `ingress and - egress fees `_ when using regional buckets and container or artifact - registries. Multi-regional buckets also have additional replication fees when writing data. A good rule of thumb is to use - a multi-regional artifact registry for Docker images and regional buckets for data. You can then specify which region(s) - you want your job to run in with :meth:`.Job.regions`. To set the default region(s) for all jobs, you can set the input - regions argument to :class:`.ServiceBackend` or use hailctl to set the default value. An example invocation is - `hailctl config set batch/regions "us-central1,us-east1"`. You can also get the full list of supported regions - with py:staticmethod:`.ServiceBackend.supported_regions`. - Next, pass the :class:`.ServiceBackend` object to the :class:`.Batch` constructor with the parameter name `backend`. @@ -252,7 +241,7 @@ and execute the following batch: .. code-block:: python - >>> import hailtop.batch as hb # doctest: +SKIP + >>> import hailtop.batch as hb >>> backend = hb.ServiceBackend('my-billing-project', remote_tmpdir='gs://my-bucket/batch/tmp/') # doctest: +SKIP >>> b = hb.Batch(backend=backend, name='test') # doctest: +SKIP >>> j = b.new_job(name='hello') # doctest: +SKIP @@ -271,6 +260,72 @@ have previously set them with ``hailctl``: A trial billing project is automatically created for you with the name {USERNAME}-trial +.. _region: + +Regions +------- + +Data and compute both reside in a physical location. In Google Cloud Platform, the location of data +is controlled by the location of the containing bucket. ``gcloud`` can determine the location of a +bucket:: + + gcloud storage buckets describe gs://my-bucket + +If your compute resides in a different location from the data it reads or writes, then you will +accrue substantial `network charges `__. + +To avoid network charges ensure all your data is in one region and specify that region in one of the +following five ways. As a running example, we consider data stored in `us-central1`. The options are +listed from highest to lowest precedence. + +1. :meth:`.Job.regions`: + + .. code-block:: python + + >>> b = hb.Batch(backend=hb.ServiceBackend()) + >>> j = b.new_job() + >>> j.regions(['us-central1']) + +2. The ``default_regions`` parameter of :class:`.Batch`: + + .. code-block:: python + + >>> b = hb.Batch(backend=hb.ServiceBackend(), default_regions=['us-central1']) + + +3. The ``regions`` parameter of :class:`.ServiceBackend`: + + .. code-block:: python + + >>> b = hb.Batch(backend=hb.ServiceBackend(regions=['us-central1'])) + +4. The ``HAIL_BATCH_REGIONS`` environment variable: + + .. 
code-block:: sh + + export HAIL_BATCH_REGIONS=us-central1 + python3 my-batch-script.py + +5. The ``batch/region`` configuration variable: + + .. code-block:: sh + + hailctl config set batch/regions us-central1 + python3 my-batch-script.py + +.. warning:: + + If none of the five options above are specified, your job may run in *any* region! + +In Google Cloud Platform, the location of a multi-region bucket is considered *different* from any +region within that multi-region. For example, if a VM in the `us-central1` region reads data from a +bucket in the `us` multi-region, this incurs network charges becuse `us` is not considered equal to +`us-central1`. + +Container (aka Docker) images are a form of data. In Google Cloud Platform, we recommend storing +your images in a multi-regional artifact registry, which at time of writing, despite being +"multi-regional", does not incur network charges in the manner described above. + Using the UI ------------ diff --git a/hail/python/hailtop/batch/exceptions.py b/hail/python/hailtop/batch/exceptions.py index 6a49c750a3a..32425471a7b 100644 --- a/hail/python/hailtop/batch/exceptions.py +++ b/hail/python/hailtop/batch/exceptions.py @@ -1,4 +1,3 @@ - class BatchException(Exception): def __init__(self, msg=''): self.msg = msg diff --git a/hail/python/hailtop/batch/globals.py b/hail/python/hailtop/batch/globals.py index 61f9ab7f90e..fad29b20bee 100644 --- a/hail/python/hailtop/batch/globals.py +++ b/hail/python/hailtop/batch/globals.py @@ -1,6 +1,5 @@ import subprocess as sp - __ARG_MAX = None diff --git a/hail/python/hailtop/batch/hail_genetics_images.py b/hail/python/hailtop/batch/hail_genetics_images.py index 12be1abe4b0..c5f8b7882aa 100644 --- a/hail/python/hailtop/batch/hail_genetics_images.py +++ b/hail/python/hailtop/batch/hail_genetics_images.py @@ -2,11 +2,10 @@ from hailtop import pip_version - HAIL_GENETICS = 'hailgenetics/' HAIL_GENETICS_IMAGES = [ - HAIL_GENETICS + name - for name in ('hail', 'hailtop', 'genetics', 'python-dill', 'vep-grch37-85', 'vep-grch38-95')] + HAIL_GENETICS + name for name in ('hail', 'hailtop', 'genetics', 'python-dill', 'vep-grch37-85', 'vep-grch38-95') +] def hailgenetics_hail_image_for_current_python_version(): diff --git a/hail/python/hailtop/batch/job.py b/hail/python/hailtop/batch/job.py index e872d875db4..05fafdc21a1 100644 --- a/hail/python/hailtop/batch/job.py +++ b/hail/python/hailtop/batch/job.py @@ -5,25 +5,26 @@ import textwrap import warnings from shlex import quote as shq -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast, Literal +from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union, cast + from typing_extensions import Self import hailtop.batch_client.client as bc +from hailtop.batch.resource import PythonResult, Resource, ResourceFile, ResourceGroup, ResourceType from . import backend, batch # pylint: disable=cyclic-import -from . 
import resource as _resource # pylint: disable=cyclic-import from .exceptions import BatchException def _add_resource_to_set(resource_set, resource, include_rg=True): - rg: Optional[_resource.ResourceGroup] - if isinstance(resource, _resource.ResourceGroup): + rg: Optional[ResourceGroup] + if isinstance(resource, ResourceGroup): rg = resource if include_rg: resource_set.add(resource) else: resource_set.add(resource) - if isinstance(resource, _resource.ResourceFile) and resource._has_resource_group(): + if isinstance(resource, ResourceFile) and resource._has_resource_group(): rg = resource._get_resource_group() else: rg = None @@ -59,13 +60,15 @@ def _new_uid(cls): cls._counter += 1 return uid - def __init__(self, - batch: 'batch.Batch', - token: str, - *, - name: Optional[str] = None, - attributes: Optional[Dict[str, str]] = None, - shell: Optional[str] = None): + def __init__( + self, + batch: 'batch.Batch', + token: str, + *, + name: Optional[str] = None, + attributes: Optional[Dict[str, str]] = None, + shell: Optional[str] = None, + ): self._batch = batch self._shell = shell self._token = token @@ -89,16 +92,16 @@ def __init__(self, self._user_code: List[str] = [] self._regions: Optional[List[str]] = None - self._resources: Dict[str, _resource.Resource] = {} - self._resources_inverse: Dict[_resource.Resource, str] = {} + self._resources: Dict[str, Resource] = {} + self._resources_inverse: Dict[Resource, str] = {} self._uid = Job._new_uid() self._job_id: Optional[int] = None - self._inputs: Set[_resource.Resource] = set() - self._internal_outputs: Set[Union[_resource.ResourceFile, _resource.PythonResult]] = set() - self._external_outputs: Set[Union[_resource.ResourceFile, _resource.PythonResult]] = set() - self._mentioned: Set[_resource.Resource] = set() # resources used in the command - self._valid: Set[_resource.Resource] = set() # resources declared in the appropriate place + self._inputs: Set[Resource] = set() + self._internal_outputs: Set[Union[ResourceFile, PythonResult]] = set() + self._external_outputs: Set[Union[ResourceFile, PythonResult]] = set() + self._mentioned: Set[Resource] = set() # resources used in the command + self._valid: Set[Resource] = set() # resources declared in the appropriate place self._dependencies: Set[Job] = set() self._submitted: bool = False self._client_job: Optional[bc.Job] = None @@ -114,7 +117,7 @@ def safe_str(s): self._dirname = f'{safe_str(name)}-{self._token}' if name else self._token - def _get_resource(self, item: str) -> '_resource.Resource': + def _get_resource(self, item: str) -> 'Resource': if item not in self._resources: r = self._batch._new_job_resource_file(self, value=item) self._resources[item] = r @@ -122,16 +125,19 @@ def _get_resource(self, item: str) -> '_resource.Resource': return self._resources[item] - def __getitem__(self, item: str) -> '_resource.Resource': + def __iter__(self): + raise TypeError(f'{type(self).__name__!r} object is not iterable') + + def __getitem__(self, item: str) -> 'Resource': return self._get_resource(item) - def __getattr__(self, item: str) -> '_resource.Resource': + def __getattr__(self, item: str) -> 'Resource': return self._get_resource(item) - def _add_internal_outputs(self, resource: '_resource.Resource') -> None: + def _add_internal_outputs(self, resource: 'Resource') -> None: _add_resource_to_set(self._internal_outputs, resource, include_rg=False) - def _add_inputs(self, resource: '_resource.Resource') -> None: + def _add_inputs(self, resource: 'Resource') -> None: 
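The new `__iter__` override deserves a brief illustration: without it, iterating a `Job` falls back to Python's legacy `__getitem__` protocol with integer indices and fails confusingly deep inside resource creation; with it, the mistake surfaces as a clear `TypeError`. A small sketch using the default local backend:

    import hailtop.batch as hb

    b = hb.Batch(name='iteration-guard')
    j = b.new_job(name='example')
    try:
        list(j)
    except TypeError as err:
        print(err)  # "'BashJob' object is not iterable"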
_add_resource_to_set(self._inputs, resource, include_rg=False) def depends_on(self, *jobs: 'Job') -> Self: @@ -594,58 +600,69 @@ def handler(match_obj): if groups['BATCH']: raise BatchException(f"found a reference to a Batch object in command '{command}'.") if groups['PYTHON_RESULT'] and not allow_python_results: - raise BatchException(f"found a reference to a PythonResult object. hint: Use one of the methods `as_str`, `as_json` or `as_repr` on a PythonResult. command: '{command}'") + raise BatchException( + f"found a reference to a PythonResult object. hint: Use one of the methods `as_str`, `as_json` or `as_repr` on a PythonResult. command: '{command}'" + ) assert groups['RESOURCE_FILE'] or groups['RESOURCE_GROUP'] or groups['PYTHON_RESULT'] r_uid = match_obj.group() r = self._batch._resource_map.get(r_uid) if r is None: - raise BatchException(f"undefined resource '{r_uid}' in command '{command}'.\n" - f"Hint: resources must be from the same batch as the current job.") + raise BatchException( + f"undefined resource '{r_uid}' in command '{command}'.\n" + f"Hint: resources must be from the same batch as the current job." + ) - if r._source != self: + source = r.source() + if source != self: self._add_inputs(r) - if r._source is not None: - if r not in r._source._valid: - name = r._source._resources_inverse[r] - raise BatchException(f"undefined resource '{name}'\n" - f"Hint: resources must be defined within " - f"the job methods 'command' or 'declare_resource_group'") + if source is not None: + if r not in source._valid: + name = source._resources_inverse[r] + raise BatchException( + f"undefined resource '{name}'\n" + f"Hint: resources must be defined within " + f"the job methods 'command' or 'declare_resource_group'" + ) if self._always_run: - warnings.warn('A job marked as always run has a resource file dependency on another job. If the dependent job fails, ' - f'the always run job with the following command may not succeed:\n{command}') + warnings.warn( + 'A job marked as always run has a resource file dependency on another job. 
If the dependent job fails, ' + f'the always run job with the following command may not succeed:\n{command}' + ) - self._dependencies.add(r._source) - r._source._add_internal_outputs(r) + self._dependencies.add(source) + source._add_internal_outputs(r) else: _add_resource_to_set(self._valid, r) self._mentioned.add(r) return '${BATCH_TMPDIR}' + shq(r._get_path('')) - regexes = [_resource.ResourceFile._regex_pattern, - _resource.ResourceGroup._regex_pattern, - _resource.PythonResult._regex_pattern, - Job._regex_pattern, - batch.Batch._regex_pattern] + regexes = [ + ResourceFile._regex_pattern, + ResourceGroup._regex_pattern, + PythonResult._regex_pattern, + Job._regex_pattern, + batch.Batch._regex_pattern, + ] - subst_command = re.sub('(' + ')|('.join(regexes) + ')', - handler, - command) + subst_command = re.sub('(' + ')|('.join(regexes) + ')', handler, command) return subst_command def _pretty(self): - s = f"Job '{self._uid}'" \ - f"\tName:\t'{self.name}'" \ - f"\tAttributes:\t'{self.attributes}'" \ - f"\tImage:\t'{self._image}'" \ - f"\tCPU:\t'{self._cpu}'" \ - f"\tMemory:\t'{self._memory}'" \ - f"\tStorage:\t'{self._storage}'" \ + s = ( + f"Job '{self._uid}'" + f"\tName:\t'{self.name}'" + f"\tAttributes:\t'{self.attributes}'" + f"\tImage:\t'{self._image}'" + f"\tCPU:\t'{self._cpu}'" + f"\tMemory:\t'{self._memory}'" + f"\tStorage:\t'{self._storage}'" f"\tCommand:\t'{self._command}'" + ) return s def __str__(self): @@ -682,13 +699,15 @@ class BashJob(Job): or :meth:`.Batch.new_bash_job` instead. """ - def __init__(self, - batch: 'batch.Batch', - token: str, - *, - name: Optional[str] = None, - attributes: Optional[Dict[str, str]] = None, - shell: Optional[str] = None): + def __init__( + self, + batch: 'batch.Batch', + token: str, + *, + name: Optional[str] = None, + attributes: Optional[Dict[str, str]] = None, + shell: Optional[str] = None, + ): super().__init__(batch, token, name=name, attributes=attributes, shell=shell) self._command: List[str] = [] @@ -865,10 +884,10 @@ async def _compile(self, local_tmpdir, remote_tmpdir, *, dry_run=False): code_path = f'{job_path}/code.sh' code = self._batch.read_input(code_path) - wrapper_command = f''' + wrapper_command = f""" chmod u+x {code} source {code} -''' +""" wrapper_command = self._interpolate_command(wrapper_command) self._wrapper_code.append(wrapper_command) @@ -879,7 +898,7 @@ async def _compile(self, local_tmpdir, remote_tmpdir, *, dry_run=False): return True -UnpreparedArg = Union['_resource.ResourceType', List['UnpreparedArg'], Tuple['UnpreparedArg', ...], Dict[str, 'UnpreparedArg'], Any] +UnpreparedArg = Union[ResourceType, List['UnpreparedArg'], Tuple['UnpreparedArg', ...], Dict[str, 'UnpreparedArg'], Any] PreparedArg = Union[ Tuple[Literal['py_path'], str], @@ -888,7 +907,7 @@ async def _compile(self, local_tmpdir, remote_tmpdir, *, dry_run=False): Tuple[Literal['list'], List['PreparedArg']], Tuple[Literal['dict'], Dict[str, 'PreparedArg']], Tuple[Literal['tuple'], Tuple['PreparedArg', ...]], - Tuple[Literal['value'], Any] + Tuple[Literal['value'], Any], ] @@ -929,24 +948,26 @@ def add(x, y): instead. 
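For context on what `_interpolate_command` resolves, a minimal sketch of two bash jobs where mentioning `j1.ofile` in `j2`'s command is what creates the dependency edge; the output path is a placeholder and the default local backend is assumed:

    import hailtop.batch as hb

    b = hb.Batch(name='interpolation-example')
    j1 = b.new_job(name='write')
    j1.command(f'echo hello > {j1.ofile}')
    j2 = b.new_job(name='read')
    j2.command(f'cat {j1.ofile}')  # referencing j1.ofile makes j2 depend on j1
    b.write_output(j1.ofile, '/tmp/hello.txt')  # placeholder path
    b.run()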
""" - def __init__(self, - batch: 'batch.Batch', - token: str, - *, - name: Optional[str] = None, - attributes: Optional[Dict[str, str]] = None): + def __init__( + self, + batch: 'batch.Batch', + token: str, + *, + name: Optional[str] = None, + attributes: Optional[Dict[str, str]] = None, + ): super().__init__(batch, token, name=name, attributes=attributes, shell=None) - self._resources: Dict[str, _resource.Resource] = {} - self._resources_inverse: Dict[_resource.Resource, str] = {} - self._function_calls: List[Tuple[_resource.PythonResult, int, Tuple[UnpreparedArg, ...], Dict[str, UnpreparedArg]]] = [] + self._resources: Dict[str, Resource] = {} + self._resources_inverse: Dict[Resource, str] = {} + self._function_calls: List[Tuple[PythonResult, int, Tuple[UnpreparedArg, ...], Dict[str, UnpreparedArg]]] = [] self.n_results = 0 - def _get_python_resource(self, item: str) -> '_resource.PythonResult': + def _get_python_resource(self, item: str) -> 'PythonResult': if item not in self._resources: r = self._batch._new_python_result(self, value=item) self._resources[item] = r self._resources_inverse[r] = item - return cast(_resource.PythonResult, self._resources[item]) + return cast(PythonResult, self._resources[item]) def image(self, image: str) -> 'PythonJob': """ @@ -984,7 +1005,7 @@ def image(self, image: str) -> 'PythonJob': self._image = image return self - def call(self, unapplied: Callable, *args: UnpreparedArg, **kwargs: UnpreparedArg) -> '_resource.PythonResult': + def call(self, unapplied: Callable, *args: UnpreparedArg, **kwargs: UnpreparedArg) -> 'PythonResult': """Execute a Python function. Examples @@ -1104,6 +1125,7 @@ def csv_to_json(path): def run_async(*args, **kwargs): return asyncio.run(unapplied_copy(*args, **kwargs)) + unapplied = run_async for arg in args: @@ -1124,22 +1146,23 @@ def run_async(*args, **kwargs): except TypeError as e: raise BatchException(f'Cannot call {unapplied.__name__} with the supplied arguments') from e - def handle_arg(r): - if r._source != self: + def handle_arg(r: Resource) -> None: + source = r.source() + if source != self: self._add_inputs(r) - if r._source is not None: - if r not in r._source._valid: - name = r._source._resources_inverse[r] + if source is not None: + if r not in source._valid: + name = source._resources_inverse[r] raise BatchException(f"undefined resource '{name}'\n") - self._dependencies.add(r._source) - r._source._add_internal_outputs(r) + self._dependencies.add(source) + source._add_internal_outputs(r) else: _add_resource_to_set(self._valid, r) self._mentioned.add(r) - def handle_args(r): - if isinstance(r, _resource.Resource): + def handle_args(r: Union[UnpreparedArg, List[UnpreparedArg], Dict[Any, UnpreparedArg]]) -> None: + if isinstance(r, Resource): handle_arg(r) elif isinstance(r, (list, tuple)): for elt in r: @@ -1163,13 +1186,15 @@ def handle_args(r): async def _compile(self, local_tmpdir, remote_tmpdir, *, dry_run=False): def preserialize(arg: UnpreparedArg) -> PreparedArg: - if isinstance(arg, _resource.PythonResult): + if isinstance(arg, PythonResult): return ('py_path', arg._get_path(local_tmpdir)) - if isinstance(arg, _resource.ResourceFile): + if isinstance(arg, ResourceFile): return ('path', arg._get_path(local_tmpdir)) - if isinstance(arg, _resource.ResourceGroup): - return ('dict_path', {name: resource._get_path(local_tmpdir) - for name, resource in arg._resources.items()}) + if isinstance(arg, ResourceGroup): + return ( + 'dict_path', + {name: resource._get_path(local_tmpdir) for name, resource in 
arg._resources.items()}, + ) if isinstance(arg, list): return ('list', [preserialize(elt) for elt in arg]) if isinstance(arg, tuple): @@ -1187,16 +1212,21 @@ def preserialize(arg: UnpreparedArg) -> PreparedArg: del kwargs args_file = await self._batch._serialize_python_to_input_file( - os.path.dirname(result._get_path(remote_tmpdir)), "args", i, (preserialized_args, preserialized_kwargs), dry_run + os.path.dirname(result._get_path(remote_tmpdir)), + "args", + i, + (preserialized_args, preserialized_kwargs), + dry_run, ) json_write, str_write, repr_write = [ - '' if not output else f''' + '' + if not output + else f""" with open('{output}', 'w') as out: out.write({formatter}(result) + '\\n') -''' - for output, formatter in - [(result._json, "json.dumps"), (result._str, "str"), (result._repr, "repr")] +""" + for output, formatter in [(result._json, "json.dumps"), (result._str, "str"), (result._repr, "repr")] ] wrapper_code = f'''python3 -c " diff --git a/hail/python/hailtop/batch/resource.py b/hail/python/hailtop/batch/resource.py index 799d80e8daf..52b19a74152 100644 --- a/hail/python/hailtop/batch/resource.py +++ b/hail/python/hailtop/batch/resource.py @@ -1,9 +1,11 @@ import abc -from typing import Optional, Set, cast, Union +from typing import TYPE_CHECKING, Optional, Set, Union, cast -from . import job # pylint: disable=cyclic-import from .exceptions import BatchException +if TYPE_CHECKING: + from hailtop.batch.job import Job, PythonJob + class Resource: """ @@ -11,7 +13,10 @@ class Resource: """ _uid: str - _source: Optional[job.Job] + + @abc.abstractmethod + def source(self) -> Optional['Job']: + pass @abc.abstractmethod def _get_path(self, directory: str) -> str: @@ -27,6 +32,7 @@ class ResourceFile(Resource, str): Class representing a single file resource. There exist two subclasses: :class:`.InputResourceFile` and :class:`.JobResourceFile`. """ + _counter = 0 _uid_prefix = "__RESOURCE_FILE__" _regex_pattern = r"(?P{}\d+)".format(_uid_prefix) # pylint: disable=consider-using-f-string @@ -47,7 +53,6 @@ def __init__(self, value: Optional[str]): super().__init__() assert value is None or isinstance(value, str) self._value = value - self._source: Optional[job.Job] = None self._output_paths: Set[str] = set() self._resource_group: Optional[ResourceGroup] = None @@ -56,8 +61,9 @@ def _get_path(self, directory: str): def _add_output_path(self, path: str) -> None: self._output_paths.add(path) - if self._source is not None: - self._source._external_outputs.add(self) + source = self.source() + if source is not None: + source._external_outputs.add(self) def _add_resource_group(self, rg: 'ResourceGroup') -> None: self._resource_group = rg @@ -103,6 +109,9 @@ def _get_path(self, directory: str) -> str: assert self._value is not None return directory + '/inputs/' + self._value + def source(self) -> None: + return None + class JobResourceFile(ResourceFile): """ @@ -124,13 +133,15 @@ class JobResourceFile(ResourceFile): to be saved. 
""" - def __init__(self, value, source: job.Job): + def __init__(self, value, source: 'Job'): super().__init__(value) self._has_extension = False - self._source: job.Job = source + self._source = source + + def source(self) -> 'Job': + return self._source def _get_path(self, directory: str) -> str: - assert self._source is not None assert self._value is not None return f'{directory}/{self._source._dirname}/{self._value}' @@ -225,7 +236,7 @@ def _new_uid(cls): cls._counter += 1 return uid - def __init__(self, source: Optional[job.Job], root: str, **values: ResourceFile): + def __init__(self, source: Optional['Job'], root: str, **values: ResourceFile): self._source = source self._resources = {} # dict of name to resource uid self._root = root @@ -236,6 +247,9 @@ def __init__(self, source: Optional[job.Job], root: str, **values: ResourceFile) self._resources[name] = resource_file resource_file._add_resource_group(self) + def source(self): + return self._source + def _get_path(self, directory: str) -> str: subdir = str(self._source._dirname) if self._source else 'inputs' return directory + '/' + subdir + '/' + self._root @@ -246,10 +260,16 @@ def _add_output_path(self, path: str) -> None: def _get_resource(self, item: str) -> ResourceFile: if item not in self._resources: - raise BatchException(f"'{item}' not found in the resource group.\n" - f"Hint: you must declare each attribute when constructing the resource group.") + raise BatchException( + f"'{item}' not found in the resource group.\n" + f"Hint: you must declare each attribute when constructing the resource group." + ) return self._resources[item] + def __iter__(self): + # TODO Iteration over the group's resources could perhaps be useful to implement + raise TypeError(f'{type(self).__name__!r} object is not iterable') + def __getitem__(self, item: str) -> ResourceFile: return self._get_resource(item) @@ -300,6 +320,7 @@ def square(x): to be saved. In most cases, you'll want to convert the :class:`.PythonResult` to a :class:`.JobResourceFile` in a human-readable format. """ + _counter = 0 _uid_prefix = "__PYTHON_RESULT__" _regex_pattern = r"(?P{}\d+)".format(_uid_prefix) # pylint: disable=consider-using-f-string @@ -316,18 +337,17 @@ def __new__(cls, *args, **kwargs): # pylint: disable=W0613 r._uid = uid return r - def __init__(self, value: str, source: job.PythonJob): + def __init__(self, value: str, source: 'PythonJob'): super().__init__() assert value is None or isinstance(value, str) self._value = value - self._source: job.PythonJob = source + self._source = source self._output_paths: Set[str] = set() self._json = None self._str = None self._repr = None def _get_path(self, directory: str) -> str: - assert self._source is not None assert self._value is not None return f'{directory}/{self._source._dirname}/{self._value}' @@ -344,11 +364,11 @@ def _add_output_path(self, path: str) -> None: if self._source is not None: self._source._external_outputs.add(self) - def source(self) -> job.PythonJob: + def source(self) -> 'PythonJob': """ Get the job that created the Python result. 
""" - return cast(job.PythonJob, self._source) + return self._source def as_json(self) -> JobResourceFile: """ diff --git a/hail/python/hailtop/batch/utils.py b/hail/python/hailtop/batch/utils.py index 2bf4cc9e5c4..7831d947dff 100644 --- a/hail/python/hailtop/batch/utils.py +++ b/hail/python/hailtop/batch/utils.py @@ -1,14 +1,16 @@ import math - from typing import List, Optional -from ..utils.utils import grouped, digits_needed +from ..config.deploy_config import TerraDeployConfig, get_deploy_config +from ..utils.utils import digits_needed, grouped from .batch import Batch from .exceptions import BatchException -from .resource import ResourceGroup, ResourceFile +from .resource import ResourceFile, ResourceGroup -def concatenate(b: Batch, files: List[ResourceFile], image: Optional[str] = None, branching_factor: int = 100) -> ResourceFile: +def concatenate( + b: Batch, files: List[ResourceFile], image: Optional[str] = None, branching_factor: int = 100 +) -> ResourceFile: """ Concatenate files using tree aggregation. @@ -59,8 +61,9 @@ def _concatenate(b, name, xs): return _combine(_concatenate, b, 'concatenate', files, branching_factor=branching_factor) -def plink_merge(b: Batch, bfiles: List[ResourceGroup], - image: Optional[str] = None, branching_factor: int = 100) -> ResourceGroup: +def plink_merge( + b: Batch, bfiles: List[ResourceGroup], image: Optional[str] = None, branching_factor: int = 100 +) -> ResourceGroup: """ Merge binary PLINK files using tree aggregation. @@ -114,3 +117,7 @@ def _combine(combop, b, name, xs, branching_factor=100): level += 1 assert len(xs) == 1 return xs[0] + + +def needs_tokens_mounted() -> bool: + return not isinstance(get_deploy_config(), TerraDeployConfig) diff --git a/hail/python/hailtop/batch_client/__init__.py b/hail/python/hailtop/batch_client/__init__.py index f9219abf05d..ad76233d2be 100644 --- a/hail/python/hailtop/batch_client/__init__.py +++ b/hail/python/hailtop/batch_client/__init__.py @@ -1,4 +1,4 @@ -from . import client, aioclient, parse, types +from . 
import aioclient, client, parse, types from .aioclient import BatchAlreadyCreatedError, BatchNotCreatedError, JobAlreadySubmittedError, JobNotSubmittedError __all__ = [ diff --git a/hail/python/hailtop/batch_client/aioclient.py b/hail/python/hailtop/batch_client/aioclient.py index ff5fe5f855b..a1e7fe6fc31 100644 --- a/hail/python/hailtop/batch_client/aioclient.py +++ b/hail/python/hailtop/batch_client/aioclient.py @@ -1,25 +1,26 @@ -from typing import Optional, Dict, Any, List, Tuple, Union, AsyncIterator, TypedDict, cast +import asyncio +import functools +import json +import logging import math import random -import logging -import json -import functools -import asyncio +import secrets +from enum import Enum +from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, TypedDict, Union, cast + import aiohttp import orjson -import secrets -from hailtop import is_notebook -from hailtop.config import get_deploy_config, DeployConfig +from hailtop import httpx, is_notebook from hailtop.aiocloud.common import Session from hailtop.aiocloud.common.credentials import CloudCredentials from hailtop.auth import hail_credentials -from hailtop.utils import bounded_gather, sleep_before_try +from hailtop.config import DeployConfig, get_deploy_config +from hailtop.utils import async_to_blocking, bounded_gather, sleep_before_try from hailtop.utils.rich_progress_bar import BatchProgressBar, BatchProgressBarTask -from hailtop import httpx -from .types import GetJobsResponseV1Alpha, JobListEntryV1Alpha, GetJobResponseV1Alpha -from .globals import tasks, complete_states +from .globals import ROOT_JOB_GROUP_ID, complete_states, tasks +from .types import GetJobGroupResponseV1Alpha, GetJobResponseV1Alpha, GetJobsResponseV1Alpha, JobListEntryV1Alpha log = logging.getLogger('batch_client.aioclient') @@ -32,14 +33,6 @@ class JobNotSubmittedError(Exception): pass -class AbsoluteJobId(int): - pass - - -class InUpdateJobId(int): - pass - - class Job: @staticmethod def _get_error(job_status, task): @@ -176,19 +169,23 @@ def _get_duration(container_status): @staticmethod def submitted_job(batch: 'Batch', job_id: int, _status: Optional[GetJobResponseV1Alpha] = None): - return Job(batch, AbsoluteJobId(job_id), _status=_status) + return Job(batch, job_id, submitted=True, _status=_status) @staticmethod def unsubmitted_job(batch: 'Batch', job_id: int): - return Job(batch, InUpdateJobId(job_id)) - - def __init__(self, - batch: 'Batch', - job_id: Union[AbsoluteJobId, InUpdateJobId], - *, - _status: Optional[GetJobResponseV1Alpha] = None): + return Job(batch, job_id, submitted=False) + + def __init__( + self, + batch: 'Batch', + job_id: int, + submitted: bool, + *, + _status: Optional[GetJobResponseV1Alpha] = None, + ): self._batch = batch self._job_id = job_id + self._submitted = submitted self._status = _status def _raise_if_not_submitted(self): @@ -201,11 +198,12 @@ def _raise_if_submitted(self): def _submit(self, in_update_start_job_id: int): self._raise_if_submitted() - self._job_id = AbsoluteJobId(in_update_start_job_id + self._job_id - 1) + self._job_id = in_update_start_job_id + self._job_id - 1 + self._submitted = True @property def is_submitted(self): - return isinstance(self._job_id, AbsoluteJobId) + return self._submitted @property def batch_id(self) -> int: @@ -272,8 +270,8 @@ async def status(self) -> Dict[str, Any]: async def wait(self) -> Dict[str, Any]: return cast( - Dict[str, Any], # https://stackoverflow.com/a/76515675/6823256 - await self._wait_for_states(*complete_states) + Dict[str, Any], # 
https://stackoverflow.com/a/76515675/6823256 + await self._wait_for_states(*complete_states), ) async def _wait_for_states(self, *states: str) -> GetJobResponseV1Alpha: @@ -287,7 +285,9 @@ async def _wait_for_states(self, *states: str) -> GetJobResponseV1Alpha: async def container_log(self, container_name: str) -> bytes: self._raise_if_not_submitted() - async with await self._client._get(f'/api/v1alpha/batches/{self.batch_id}/jobs/{self.job_id}/log/{container_name}') as resp: + async with await self._client._get( + f'/api/v1alpha/batches/{self.batch_id}/jobs/{self.job_id}/log/{container_name}' + ) as resp: return await resp.read() async def log(self): @@ -301,6 +301,245 @@ async def attempts(self): return await resp.json() +class JobGroupAlreadySubmittedError(Exception): + pass + + +class JobGroupNotSubmittedError(Exception): + pass + + +class JobGroupDebugInfo(TypedDict): + status: GetJobGroupResponseV1Alpha + jobs: List[JobListEntryV1Alpha] + job_groups: List[GetJobGroupResponseV1Alpha] + + +class JobGroup: + @staticmethod + def submitted_job_group( + batch: 'Batch', + job_group_id: int, + *, + _last_known_status: Optional[GetJobGroupResponseV1Alpha] = None, + ) -> 'JobGroup': + return JobGroup(batch, job_group_id, submitted=True, last_known_status=_last_known_status) + + @staticmethod + def unsubmitted_job_group(batch: 'Batch', job_group_id: int) -> 'JobGroup': + return JobGroup(batch, job_group_id, submitted=False) + + def __init__( + self, + batch: 'Batch', + job_group_id: int, + submitted: bool, + *, + last_known_status: Optional[GetJobGroupResponseV1Alpha] = None, + ): + self._batch = batch + self._job_group_id = job_group_id + self._submitted = submitted + self._last_known_status = last_known_status + + def _submit(self, in_update_start_job_group_id: Optional[int]): + self._raise_if_submitted() + if in_update_start_job_group_id is None: + assert self._job_group_id == ROOT_JOB_GROUP_ID + else: + self._job_group_id = in_update_start_job_group_id + self._job_group_id - 1 + self._submitted = True + + def _raise_if_not_submitted(self): + if not self.is_submitted: + raise JobGroupNotSubmittedError + + def _raise_if_submitted(self): + if self.is_submitted: + raise JobGroupAlreadySubmittedError + + async def attributes(self) -> Dict[str, str]: + self._raise_if_not_submitted() + status = await self.last_known_status() + if 'attributes' in status: + return status['attributes'] + return {} + + @property + def is_submitted(self) -> bool: + return self._submitted + + @property + def batch_id(self) -> int: + return self._batch.id + + @property + def job_group_id(self) -> int: + self._raise_if_not_submitted() + return self._job_group_id + + @property + def id(self) -> Tuple[int, int]: + self._raise_if_not_submitted() + return (self.batch_id, self.job_group_id) + + @property + def _client(self) -> 'BatchClient': + return self._batch._client + + async def cancel(self): + self._raise_if_not_submitted() + await self._client._patch(f'/api/v1alpha/batches/{self.batch_id}/job-groups/{self.job_group_id}/cancel') + + async def job_groups(self) -> AsyncIterator['JobGroup']: + self._raise_if_not_submitted() + last_job_group_id = None + while True: + params: Dict[str, Any] = {} + if last_job_group_id is not None: + params['last_job_group_id'] = last_job_group_id + resp = await self._client._get( + f'/api/v1alpha/batches/{self.batch_id}/job-groups/{self.job_group_id}/job-groups', params=params + ) + body = await resp.json() + for job_group in body['job_groups']: + yield 
JobGroup.submitted_job_group(self._batch, job_group['job_group_id'], _last_known_status=job_group) + last_job_group_id = body.get('last_job_group_id') + if last_job_group_id is None: + break + + async def jobs( + self, + q: Optional[str] = None, + version: Optional[int] = None, + recursive: bool = False, + ) -> AsyncIterator[JobListEntryV1Alpha]: + self._raise_if_not_submitted() + if version is None: + version = 1 + last_job_id = None + while True: + params: Dict[str, Any] = {'recursive': str(recursive)} + if q is not None: + params['q'] = q + if last_job_id is not None: + params['last_job_id'] = last_job_id + resp = await self._client._get( + f'/api/v{version}alpha/batches/{self.batch_id}/job-groups/{self.job_group_id}/jobs', params=params + ) + body = cast(GetJobsResponseV1Alpha, await resp.json()) + for job in body['jobs']: + yield job + last_job_id = body.get('last_job_id') + if last_job_id is None: + break + + async def status(self) -> GetJobGroupResponseV1Alpha: + self._raise_if_not_submitted() + resp = await self._client._get(f'/api/v1alpha/batches/{self.batch_id}/job-groups/{self.job_group_id}') + json_status = await resp.json() + assert isinstance(json_status, dict), json_status + self._last_known_status = cast(GetJobGroupResponseV1Alpha, json_status) + return self._last_known_status + + async def last_known_status(self) -> GetJobGroupResponseV1Alpha: + self._raise_if_not_submitted() + if self._last_known_status is None: + return await self.status() # updates _last_known_status + return self._last_known_status + + def create_job(self, image: str, command: List[str], **kwargs) -> Job: + return self._batch._create_job(self, {'command': command, 'image': image, 'type': 'docker'}, **kwargs) + + def create_jvm_job(self, jar_spec: Dict[str, str], argv: List[str], *, profile: bool = False, **kwargs): + return self._batch._create_job( + self, {'type': 'jvm', 'jar_spec': jar_spec, 'command': argv, 'profile': profile}, **kwargs + ) + + def create_job_group( + self, + *, + attributes: Optional[Dict[str, str]] = None, + callback: Optional[str] = None, + cancel_after_n_failures: Optional[int] = None, + ) -> 'JobGroup': + return self._batch._create_job_group( + self, + attributes=attributes, + callback=callback, + cancel_after_n_failures=cancel_after_n_failures, + ) + + async def _wait( + self, + description: str, + progress: BatchProgressBar, + disable_progress_bar: bool, + ) -> GetJobGroupResponseV1Alpha: + self._raise_if_not_submitted() + deploy_config = get_deploy_config() + url = deploy_config.external_url('batch', f'/batches/{self.batch_id}') + i = 0 + status = await self.status() + + if is_notebook(): + description += f'[link={url}]{self.batch_id}[/link]' + else: + description += url + + with progress.with_task(description, total=status['n_jobs'], disable=disable_progress_bar) as progress_task: + while True: + status = await self.status() + progress_task.update(None, total=status['n_jobs'], completed=status['n_completed']) + if status['complete']: + return status + j = random.randrange(math.floor(1.1**i)) + await asyncio.sleep(0.100 * j) + # max 44.5s + if i < 64: + i = i + 1 + + async def wait( + self, *, disable_progress_bar: bool = False, description: str = '', progress: Optional[BatchProgressBar] = None + ) -> GetJobGroupResponseV1Alpha: + self._raise_if_not_submitted() + if description: + description += ': ' + if progress is not None: + return await self._wait(description, progress, disable_progress_bar) + with BatchProgressBar(disable=disable_progress_bar) as progress2: + return 
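A hedged sketch of how the new `JobGroup` client surface might be exercised; it assumes a deployed Batch service, valid credentials, and an existing batch and job group, and the billing project and ids are placeholders:

    import asyncio
    from hailtop.batch_client.aioclient import BatchClient

    async def inspect_job_group(batch_id: int, job_group_id: int) -> None:
        client = await BatchClient.create('my-billing-project')  # placeholder project
        try:
            batch = await client.get_batch(batch_id)
            jg = batch.get_job_group(job_group_id)
            status = await jg.status()
            print(status['n_jobs'], status['n_completed'])
            async for job in jg.jobs():
                print(job['job_id'], job['state'])
        finally:
            await client.close()

    asyncio.run(inspect_job_group(123, 1))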
await self._wait(description, progress2, disable_progress_bar) + + async def debug_info( + self, + _jobs_query_string: Optional[str] = None, + _max_job_groups: Optional[int] = None, + _max_jobs: Optional[int] = None, + ) -> JobGroupDebugInfo: + self._raise_if_not_submitted() + jg_status = await self.status() + + job_groups = [] + jobs = [] + + async for jg in self.job_groups(): + if _max_job_groups and _max_job_groups == len(job_groups): + break + job_groups.append({'status': jg._last_known_status}) + + async for j_status in self.jobs(q=_jobs_query_string): + if _max_jobs and len(jobs) == _max_jobs: + break + id = j_status['job_id'] + log, job = await asyncio.gather(self._batch.get_job_log(id), self._batch.get_job(id)) + jobs.append({'log': log, 'status': job._status}) + return {'status': jg_status, 'job_groups': job_groups, 'jobs': jobs} + + def __str__(self): + debug_info = async_to_blocking(self.debug_info()) + return str(orjson.dumps(debug_info).decode('utf-8')) + + class BatchSubmissionInfo: def __init__(self, used_fast_path: Optional[bool] = None): self.used_fast_path = used_fast_path @@ -317,17 +556,36 @@ class BatchAlreadyCreatedError(Exception): class BatchDebugInfo(TypedDict): status: Dict[str, Any] jobs: List[JobListEntryV1Alpha] + job_groups: List[GetJobGroupResponseV1Alpha] + + +class SpecType(Enum): + JOB = 'job' + JOB_GROUP = 'job_group' + + +class SpecBytes: + def __init__(self, spec_bytes: bytes, typ: SpecType): + self.spec_bytes = spec_bytes + self.typ = typ + + @property + def n_bytes(self) -> int: + return len(self.spec_bytes) + class Batch: - def __init__(self, - client: 'BatchClient', - id: Optional[int], - *, - attributes: Optional[Dict[str, str]] = None, - callback: Optional[str] = None, - token: Optional[str] = None, - cancel_after_n_failures: Optional[int] = None, - last_known_status: Optional[Dict[str, Any]] = None): + def __init__( + self, + client: 'BatchClient', + id: Optional[int], + *, + attributes: Optional[Dict[str, str]] = None, + callback: Optional[str] = None, + token: Optional[str] = None, + cancel_after_n_failures: Optional[int] = None, + last_known_status: Optional[Dict[str, Any]] = None, + ): self._client = client self._id = id self.attributes: Dict[str, str] = attributes or {} @@ -341,10 +599,19 @@ def __init__(self, self._submission_info = BatchSubmissionInfo() self._last_known_status = last_known_status - self._job_idx = 0 + self._in_update_job_group_id = 0 + self._job_group_specs: List[Dict[str, Any]] = [] + self._job_groups: List[JobGroup] = [] + + self._in_update_job_id = 0 self._job_specs: List[Dict[str, Any]] = [] self._jobs: List[Job] = [] + if self._id is not None: + self._root_job_group = JobGroup.submitted_job_group(self, ROOT_JOB_GROUP_ID) + else: + self._root_job_group = JobGroup.unsubmitted_job_group(self, ROOT_JOB_GROUP_ID) + def _raise_if_not_created(self): if not self.is_created: raise BatchNotCreatedError @@ -363,34 +630,21 @@ def id(self) -> int: def is_created(self): return self._id is not None + def get_job_group(self, job_group_id: int) -> JobGroup: + self._raise_if_not_created() + return JobGroup.submitted_job_group(self, job_group_id) + async def cancel(self): self._raise_if_not_created() - await self._client._patch(f'/api/v1alpha/batches/{self.id}/cancel') + await self._root_job_group.cancel() - async def jobs(self, - q: Optional[str] = None, - version: Optional[int] = None - ) -> AsyncIterator[JobListEntryV1Alpha]: + def jobs(self, q: Optional[str] = None, version: Optional[int] = None) -> 
AsyncIterator[JobListEntryV1Alpha]: self._raise_if_not_created() - if version is None: - version = 1 - last_job_id = None - while True: - params = {} - if q is not None: - params['q'] = q - if last_job_id is not None: - params['last_job_id'] = last_job_id - resp = await self._client._get(f'/api/v{version}alpha/batches/{self.id}/jobs', params=params) - body = cast( - GetJobsResponseV1Alpha, - await resp.json() - ) - for job in body['jobs']: - yield job - last_job_id = body.get('last_job_id') - if last_job_id is None: - break + return self._root_job_group.jobs(q, version, recursive=True) + + def job_groups(self) -> AsyncIterator[JobGroup]: + self._raise_if_not_created() + return self._root_job_group.job_groups() async def get_job(self, job_id: int) -> Job: self._raise_if_not_created() @@ -400,27 +654,6 @@ async def get_job_log(self, job_id: int) -> Dict[str, Any]: self._raise_if_not_created() return await self._client.get_job_log(self.id, job_id) - # { - # id: int - # user: str - # billing_project: str - # token: str - # state: str, (open, failure, cancelled, success, running) - # complete: bool - # closed: bool - # n_jobs: int - # n_completed: int - # n_succeeded: int - # n_failed: int - # n_cancelled: int - # time_created: optional(str), (date) - # time_closed: optional(str), (date) - # time_completed: optional(str), (date) - # duration: optional(str) - # attributes: optional(dict(str, str)) - # msec_mcpu: int - # cost: float - # } async def status(self) -> Dict[str, Any]: self._raise_if_not_created() resp = await self._client._get(f'/api/v1alpha/batches/{self.id}') @@ -435,12 +668,9 @@ async def last_known_status(self) -> Dict[str, Any]: return await self.status() # updates _last_known_status return self._last_known_status - async def _wait(self, - description: str, - progress: BatchProgressBar, - disable_progress_bar: bool, - starting_job: int - ) -> Dict[str, Any]: + async def _wait( + self, description: str, progress: BatchProgressBar, disable_progress_bar: bool, starting_job: int + ) -> Dict[str, Any]: self._raise_if_not_created() deploy_config = get_deploy_config() url = deploy_config.external_url('batch', f'/batches/{self.id}') @@ -450,28 +680,30 @@ async def _wait(self, description += f'[link={url}]{self.id}[/link]' else: description += url - with progress.with_task(description, - total=status['n_jobs'] - starting_job + 1, - disable=disable_progress_bar) as progress_task: + with progress.with_task( + description, total=status['n_jobs'] - starting_job + 1, disable=disable_progress_bar + ) as progress_task: while True: status = await self.status() - progress_task.update(None, total=status['n_jobs'] - starting_job + 1, completed=status['n_completed'] - starting_job + 1) + progress_task.update( + None, total=status['n_jobs'] - starting_job + 1, completed=status['n_completed'] - starting_job + 1 + ) if status['complete']: return status - j = random.randrange(math.floor(1.1 ** i)) + j = random.randrange(math.floor(1.1**i)) await asyncio.sleep(0.100 * j) # max 44.5s if i < 64: i = i + 1 - # FIXME Error if this is called while within a job of the same Batch - async def wait(self, - *, - disable_progress_bar: bool = False, - description: str = '', - progress: Optional[BatchProgressBar] = None, - starting_job: int = 1, - ) -> Dict[str, Any]: + async def wait( + self, + *, + disable_progress_bar: bool = False, + description: str = '', + progress: Optional[BatchProgressBar] = None, + starting_job: int = 1, + ) -> Dict[str, Any]: self._raise_if_not_created() if description: description += ': ' 
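The hunks above introduce a client-side JobGroup abstraction in the async batch client: groups are built locally, jobs are attached to a group, and both only receive server-side ids once the batch is submitted. A rough usage sketch under that assumption; the billing project, image, and commands are placeholders, and it needs a reachable Batch deployment with valid credentials:

import asyncio

from hailtop.batch_client.aioclient import BatchClient


async def run_grouped_jobs():
    client = await BatchClient.create('my-billing-project')  # placeholder billing project
    batch = client.create_batch(attributes={'name': 'job-group-demo'})

    # The group is unsubmitted until Batch.submit() assigns its server-side id.
    group = batch.create_job_group(attributes={'name': 'shard-work'}, cancel_after_n_failures=1)
    for i in range(3):
        group.create_job('ubuntu:22.04', ['echo', f'shard {i}'])  # placeholder image/command

    await batch.submit()

    # wait() polls the job group's status until it reports complete.
    status = await group.wait(description='shard-work')
    print(status['state'], status['n_succeeded'], status['n_failed'])


asyncio.run(run_grouped_jobs())  # requires a live Batch service; shown for illustration only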
@@ -480,21 +712,28 @@ async def wait(self, with BatchProgressBar(disable=disable_progress_bar) as progress2: return await self._wait(description, progress2, disable_progress_bar, starting_job) - async def debug_info(self, - _jobs_query_string: Optional[str] = None, - _max_jobs: Optional[int] = None, - ) -> BatchDebugInfo: + async def debug_info( + self, + _jobs_query_string: Optional[str] = None, + _max_jobs: Optional[int] = None, + _max_job_groups: Optional[int] = None, + ) -> BatchDebugInfo: self._raise_if_not_created() batch_status = await self.status() + job_groups = [] + async for job_group in self._root_job_group.job_groups(): + if _max_job_groups and len(job_groups) == _max_job_groups: + break + job_groups.append({'status': (await job_group.status())}) jobs = [] - async for j_status in self.jobs(q=_jobs_query_string): + async for j_status in self._root_job_group.jobs(q=_jobs_query_string): if _max_jobs and len(jobs) == _max_jobs: break id = j_status['job_id'] log, job = await asyncio.gather(self.get_job_log(id), self.get_job(id)) jobs.append({'log': log, 'status': job._status}) - return {'status': batch_status, 'jobs': jobs} + return {'status': batch_status, 'jobs': jobs, 'job_groups': job_groups} async def delete(self): self._raise_if_not_created() @@ -505,39 +744,55 @@ async def delete(self): raise def create_job(self, image: str, command: List[str], **kwargs) -> Job: - return self._create_job( - {'command': command, 'image': image, 'type': 'docker'}, **kwargs - ) + return self._create_job(self._root_job_group, {'command': command, 'image': image, 'type': 'docker'}, **kwargs) def create_jvm_job(self, jar_spec: Dict[str, str], argv: List[str], *, profile: bool = False, **kwargs): if 'always_copy_output' in kwargs: raise ValueError("the 'always_copy_output' option is not allowed for JVM jobs") - return self._create_job({'type': 'jvm', 'jar_spec': jar_spec, 'command': argv, 'profile': profile}, **kwargs) - - def _create_job(self, - process: dict, - *, - env: Optional[Dict[str, str]] = None, - port: Optional[int] = None, - resources: Optional[dict] = None, - secrets: Optional[dict] = None, - service_account: Optional[str] = None, - attributes: Optional[Dict[str, str]] = None, - parents: Optional[List[Job]] = None, - input_files: Optional[List[Tuple[str, str]]] = None, - output_files: Optional[List[Tuple[str, str]]] = None, - always_run: bool = False, - always_copy_output: bool = False, - timeout: Optional[Union[int, float]] = None, - cloudfuse: Optional[List[Tuple[str, str, bool]]] = None, - requester_pays_project: Optional[str] = None, - mount_tokens: bool = False, - network: Optional[str] = None, - unconfined: bool = False, - user_code: Optional[str] = None, - regions: Optional[List[str]] = None - ) -> Job: - self._job_idx += 1 + return self._create_job( + self._root_job_group, {'type': 'jvm', 'jar_spec': jar_spec, 'command': argv, 'profile': profile}, **kwargs + ) + + def create_job_group( + self, + *, + attributes: Optional[Dict[str, str]] = None, + callback: Optional[str] = None, + cancel_after_n_failures: Optional[int] = None, + ) -> JobGroup: + return self._create_job_group( + self._root_job_group, + attributes=attributes, + callback=callback, + cancel_after_n_failures=cancel_after_n_failures, + ) + + def _create_job( + self, + job_group: JobGroup, + process: dict, + *, + env: Optional[Dict[str, str]] = None, + port: Optional[int] = None, + resources: Optional[dict] = None, + secrets: Optional[dict] = None, + service_account: Optional[str] = None, + attributes: 
Optional[Dict[str, str]] = None, + parents: Optional[List[Job]] = None, + input_files: Optional[List[Tuple[str, str]]] = None, + output_files: Optional[List[Tuple[str, str]]] = None, + always_run: bool = False, + always_copy_output: bool = False, + timeout: Optional[Union[int, float]] = None, + cloudfuse: Optional[List[Tuple[str, str, bool]]] = None, + requester_pays_project: Optional[str] = None, + mount_tokens: bool = False, + network: Optional[str] = None, + unconfined: bool = False, + user_code: Optional[str] = None, + regions: Optional[List[str]] = None, + ) -> Job: + self._in_update_job_id += 1 if parents is None: parents = [] @@ -548,10 +803,9 @@ def _create_job(self, invalid_job_ids = [] for parent in parents: if not parent.is_submitted: - assert isinstance(parent._job_id, InUpdateJobId) if parent._batch != self: foreign_batches.append(parent) - elif not 0 < parent._job_id < self._job_idx: + elif not 0 < parent._job_id < self._in_update_job_id: invalid_job_ids.append(parent._job_id) else: in_update_parent_ids.append(parent._job_id) @@ -579,12 +833,17 @@ def _create_job(self, job_spec = { 'always_run': always_run, 'always_copy_output': always_copy_output, - 'job_id': self._job_idx, + 'job_id': self._in_update_job_id, 'absolute_parent_ids': absolute_parent_ids, 'in_update_parent_ids': in_update_parent_ids, 'process': process, } + if job_group.is_submitted: + job_spec['absolute_job_group_id'] = job_group._job_group_id + else: + job_spec['in_update_job_group_id'] = job_group._job_group_id + if env: job_spec['env'] = [{'name': k, 'value': v} for (k, v) in env.items()] if port is not None: @@ -605,8 +864,10 @@ def _create_job(self, if output_files: job_spec['output_files'] = [{"from": src, "to": dst} for (src, dst) in output_files] if cloudfuse: - job_spec['cloudfuse'] = [{"bucket": bucket, "mount_path": mount_path, "read_only": read_only} - for (bucket, mount_path, read_only) in cloudfuse] + job_spec['cloudfuse'] = [ + {"bucket": bucket, "mount_path": mount_path, "read_only": read_only} + for (bucket, mount_path, read_only) in cloudfuse + ] if requester_pays_project: job_spec['requester_pays_project'] = requester_pays_project if mount_tokens: @@ -622,13 +883,48 @@ def _create_job(self, self._job_specs.append(job_spec) - j = Job.unsubmitted_job(self, self._job_idx) + j = Job.unsubmitted_job(self, self._in_update_job_id) self._jobs.append(j) return j - async def _create_fast(self, byte_job_specs: List[bytes], n_jobs: int, job_progress_task: BatchProgressBarTask): + def _create_job_group( + self, + parent_job_group: JobGroup, + *, + attributes: Optional[Dict[str, str]] = None, + callback: Optional[str] = None, + cancel_after_n_failures: Optional[int] = None, + ) -> JobGroup: + self._in_update_job_group_id += 1 + spec: Dict[str, Any] = {'job_group_id': self._in_update_job_group_id} + if attributes is not None: + spec['attributes'] = attributes + if callback is not None: + spec['callback'] = callback + if cancel_after_n_failures is not None: + spec['cancel_after_n_failures'] = cancel_after_n_failures + + if parent_job_group.is_submitted: + spec['absolute_parent_id'] = parent_job_group._job_group_id + else: + spec['in_update_parent_id'] = parent_job_group._job_group_id + + self._job_group_specs.append(spec) + + jg = JobGroup.unsubmitted_job_group(self, self._in_update_job_group_id) + self._job_groups.append(jg) + return jg + + async def _create_fast( + self, + byte_specs_bunch: List[SpecBytes], + job_group_progress_task: BatchProgressBarTask, + job_progress_task: BatchProgressBarTask, + ) -> 
Tuple[int, int]: + byte_job_specs = [spec.spec_bytes for spec in byte_specs_bunch if spec.typ == SpecType.JOB] + byte_job_group_specs = [spec.spec_bytes for spec in byte_specs_bunch if spec.typ == SpecType.JOB_GROUP] + self._raise_if_created() - assert n_jobs == len(self._job_specs) b = bytearray() b.extend(b'{"bunch":') b.append(ord('[')) @@ -637,24 +933,49 @@ async def _create_fast(self, byte_job_specs: List[bytes], n_jobs: int, job_progr b.append(ord(',')) b.extend(spec) b.append(ord(']')) + b.extend(b',"job_groups":') + b.append(ord('[')) + for i, spec in enumerate(byte_job_group_specs): + if i > 0: + b.append(ord(',')) + b.extend(spec) + b.append(ord(']')) b.extend(b',"batch":') - b.extend(json.dumps(self._batch_spec()).encode('utf-8')) + b.extend(orjson.dumps(self._batch_spec())) b.append(ord('}')) resp = await self._client._post( '/api/v1alpha/batches/create-fast', data=aiohttp.BytesPayload(b, content_type='application/json', encoding='utf-8'), ) batch_json = await resp.json() - job_progress_task.update(n_jobs) + job_group_progress_task.update(len(byte_job_group_specs)) + job_progress_task.update(len(byte_job_specs)) + start_job_group_id = int(batch_json['start_job_group_id']) self._id = batch_json['id'] + self._root_job_group._submit(start_job_group_id) self._submission_info = BatchSubmissionInfo(used_fast_path=True) - - async def _update_fast(self, byte_job_specs: List[bytes], n_jobs: int, job_progress_task: BatchProgressBarTask) -> int: + return (start_job_group_id, int(batch_json['start_job_id'])) + + async def _update_fast( + self, + byte_specs_bunch: List[SpecBytes], + job_group_progress_task: BatchProgressBarTask, + job_progress_task: BatchProgressBarTask, + ) -> Tuple[int, int]: self._raise_if_not_created() - assert n_jobs == len(self._job_specs) + byte_job_group_specs = [spec.spec_bytes for spec in byte_specs_bunch if spec.typ == SpecType.JOB_GROUP] + byte_job_specs = [spec.spec_bytes for spec in byte_specs_bunch if spec.typ == SpecType.JOB] + b = bytearray() - b.extend(b'{"bunch":') + b.extend(b'{"job_groups":') + b.append(ord('[')) + for i, spec in enumerate(byte_job_group_specs): + if i > 0: + b.append(ord(',')) + b.extend(spec) + b.append(ord(']')) + b.extend(b',"bunch":') b.append(ord('[')) for i, spec in enumerate(byte_job_specs): if i > 0: @@ -669,54 +990,54 @@ async def _update_fast(self, byte_job_specs: List[bytes], n_jobs: int, job_progr data=aiohttp.BytesPayload(b, content_type='application/json', encoding='utf-8'), ) update_json = await resp.json() - job_progress_task.update(n_jobs) + job_group_progress_task.update(len(byte_job_group_specs)) + job_progress_task.update(len(byte_job_specs)) self._submission_info = BatchSubmissionInfo(used_fast_path=True) - return int(update_json['start_job_id']) - - def _create_bunches(self, - specs: List[dict], - max_bunch_bytesize: int, - max_bunch_size: int, - ) -> Tuple[List[List[bytes]], List[int]]: + return (int(update_json['start_job_group_id']), int(update_json['start_job_id'])) + + def _create_bunches( + self, + job_group_specs: List[dict], + job_specs: List[dict], + max_bunch_bytesize: int, + max_bunch_size: int, + ) -> List[List[SpecBytes]]: assert max_bunch_bytesize > 0 assert max_bunch_size > 0 - byte_specs = [orjson.dumps(spec) for spec in specs] - byte_specs_bunches: List[List[bytes]] = [] - bunch_sizes = [] - bunch: List[bytes] = [] + job_group_byte_specs = [SpecBytes(orjson.dumps(spec), SpecType.JOB_GROUP) for spec in job_group_specs] + job_byte_specs = [SpecBytes(orjson.dumps(spec), SpecType.JOB) for spec in 
job_specs] + + byte_specs_bunches: List[List[SpecBytes]] = [] + bunch: List[SpecBytes] = [] bunch_n_bytes = 0 - bunch_n_jobs = 0 - for spec in byte_specs: - n_bytes = len(spec) + for spec in [*job_group_byte_specs, *job_byte_specs]: + n_bytes = spec.n_bytes assert n_bytes < max_bunch_bytesize, ( 'every spec must be less than max_bunch_bytesize,' - f' { max_bunch_bytesize }B, but {spec.decode()} is larger') + f' { max_bunch_bytesize }B, but {spec.spec_bytes.decode()} is larger' + ) if bunch_n_bytes + n_bytes < max_bunch_bytesize and len(bunch) < max_bunch_size: bunch.append(spec) bunch_n_bytes += n_bytes - bunch_n_jobs += 1 else: byte_specs_bunches.append(bunch) - bunch_sizes.append(bunch_n_jobs) bunch = [spec] bunch_n_bytes = n_bytes - bunch_n_jobs = 1 if bunch: byte_specs_bunches.append(bunch) - bunch_sizes.append(bunch_n_jobs) - return (byte_specs_bunches, bunch_sizes) + return byte_specs_bunches - async def _submit_jobs(self, update_id: int, byte_job_specs: List[bytes], n_jobs: int, progress_task: BatchProgressBarTask): + async def _submit_spec_bunch(self, url: str, byte_spec_bunch: List[bytes], progress_task: BatchProgressBarTask): self._raise_if_not_created() - assert len(byte_job_specs) > 0, byte_job_specs + assert len(byte_spec_bunch) > 0, byte_spec_bunch b = bytearray() b.append(ord('[')) i = 0 - while i < len(byte_job_specs): - spec = byte_job_specs[i] + while i < len(byte_spec_bunch): + spec = byte_spec_bunch[i] if i > 0: b.append(ord(',')) b.extend(spec) @@ -725,14 +1046,36 @@ async def _submit_jobs(self, update_id: int, byte_job_specs: List[bytes], n_jobs b.append(ord(']')) await self._client._post( - f'/api/v1alpha/batches/{self.id}/updates/{update_id}/jobs/create', + url, data=aiohttp.BytesPayload(b, content_type='application/json', encoding='utf-8'), ) - progress_task.update(n_jobs) + progress_task.update(len(byte_spec_bunch)) + + async def _submit_jobs(self, update_id: int, bunch: List[SpecBytes], progress_task: BatchProgressBarTask): + byte_job_specs = [spec.spec_bytes for spec in bunch if spec.typ == SpecType.JOB] + if len(byte_job_specs) != 0: + await self._submit_spec_bunch( + f'/api/v1alpha/batches/{self.id}/updates/{update_id}/jobs/create', byte_job_specs, progress_task + ) + + async def _submit_job_groups(self, update_id: int, bunch: List[SpecBytes], progress_task: BatchProgressBarTask): + byte_job_group_specs = [spec.spec_bytes for spec in bunch if spec.typ == SpecType.JOB_GROUP] + if len(byte_job_group_specs) != 0: + await self._submit_spec_bunch( + f'/api/v1alpha/batches/{self.id}/updates/{update_id}/job-groups/create', + byte_job_group_specs, + progress_task, + ) def _batch_spec(self): + n_job_groups = len(self._job_group_specs) n_jobs = len(self._job_specs) - batch_spec = {'billing_project': self._client.billing_project, 'n_jobs': n_jobs, 'token': self.token} + batch_spec = { + 'billing_project': self._client.billing_project, + 'n_jobs': n_jobs, + 'n_job_groups': n_job_groups, + 'token': self.token, + } if self.attributes: batch_spec['attributes'] = self.attributes if self._callback: @@ -746,108 +1089,149 @@ async def _open_batch(self) -> Optional[int]: batch_spec = self._batch_spec() batch_json = await (await self._client._post('/api/v1alpha/batches/create', json=batch_spec)).json() self._id = batch_json['id'] + self._root_job_group._submit(None) update_id = batch_json['update_id'] if update_id is None: - assert batch_spec['n_jobs'] == 0 + assert batch_spec['n_jobs'] == 0 and batch_spec['n_job_groups'] == 0 return update_id def _update_spec(self) -> dict: 
update_token = secrets.token_urlsafe(32) - return {'n_jobs': len(self._jobs), 'token': update_token} + return {'n_jobs': len(self._jobs), 'n_job_groups': len(self._job_groups), 'token': update_token} async def _create_update(self) -> int: self._raise_if_not_created() update_spec = self._update_spec() - update_json = await (await self._client._post(f'/api/v1alpha/batches/{self.id}/updates/create', json=update_spec)).json() + update_json = await ( + await self._client._post(f'/api/v1alpha/batches/{self.id}/updates/create', json=update_spec) + ).json() return int(update_json['update_id']) - async def _commit_update(self, update_id: int) -> int: + async def _commit_update(self, update_id: int) -> Tuple[int, int]: self._raise_if_not_created() - commit_json = await (await self._client._patch(f'/api/v1alpha/batches/{self.id}/updates/{update_id}/commit')).json() - return int(commit_json['start_job_id']) - - async def _submit_job_bunches(self, - update_id: int, - byte_job_specs_bunches: List[List[bytes]], - bunch_sizes: List[int], - progress_task: BatchProgressBarTask): + commit_json = await ( + await self._client._patch(f'/api/v1alpha/batches/{self.id}/updates/{update_id}/commit') + ).json() + return (int(commit_json['start_job_group_id']), int(commit_json['start_job_id'])) + + async def _submit_job_group_bunches( + self, + update_id: int, + byte_specs_bunches: List[List[SpecBytes]], + progress_task: BatchProgressBarTask, + ): + self._raise_if_not_created() + for bunch in byte_specs_bunches: + # if/when we add nested job groups, then a job group must always be submitted after its parents + await self._submit_job_groups(update_id, bunch, progress_task) + + async def _submit_job_bunches( + self, + update_id: int, + byte_specs_bunches: List[List[SpecBytes]], + progress_task: BatchProgressBarTask, + ): self._raise_if_not_created() await bounded_gather( - *[functools.partial(self._submit_jobs, update_id, bunch, size, progress_task) - for bunch, size in zip(byte_job_specs_bunches, bunch_sizes) - ], + *[functools.partial(self._submit_jobs, update_id, bunch, progress_task) for bunch in byte_specs_bunches], parallelism=6, cancel_on_error=True, ) - async def _submit(self, - max_bunch_bytesize: int, - max_bunch_size: int, - disable_progress_bar: bool, - progress: BatchProgressBar) -> Optional[int]: + async def _submit( + self, max_bunch_bytesize: int, max_bunch_size: int, disable_progress_bar: bool, progress: BatchProgressBar + ) -> Tuple[Optional[int], Optional[int]]: + n_job_groups = len(self._job_groups) n_jobs = len(self._jobs) - byte_job_specs_bunches, job_bunch_sizes = self._create_bunches(self._job_specs, max_bunch_bytesize, max_bunch_size) - n_job_bunches = len(byte_job_specs_bunches) - - with progress.with_task('submit job bunches', total=n_jobs, disable=(disable_progress_bar or n_job_bunches < 100)) as job_progress_task: - if not self.is_created: - if n_job_bunches == 0: - await self._open_batch() + byte_specs_bunches = self._create_bunches( + self._job_group_specs, self._job_specs, max_bunch_bytesize, max_bunch_size + ) + n_bunches = len(byte_specs_bunches) + + with progress.with_task( + 'submit job group bunches', total=n_job_groups, disable=(disable_progress_bar or n_bunches < 100) + ) as job_group_progress_task: + with progress.with_task( + 'submit job bunches', total=n_jobs, disable=(disable_progress_bar or n_bunches < 100) + ) as job_progress_task: + if not self.is_created: + if n_bunches == 0: + await self._open_batch() + log.info(f'created batch {self.id}') + return (None, None) + if 
n_bunches == 1: + start_job_group_id, start_job_id = await self._create_fast( + byte_specs_bunches[0], job_group_progress_task, job_progress_task + ) + else: + update_id = await self._open_batch() + assert update_id is not None + await self._submit_job_group_bunches(update_id, byte_specs_bunches, job_group_progress_task) + await self._submit_job_bunches(update_id, byte_specs_bunches, job_progress_task) + start_job_group_id, start_job_id = await self._commit_update(update_id) + self._submission_info = BatchSubmissionInfo(used_fast_path=False) + assert start_job_id == 1 and start_job_group_id == 1 log.info(f'created batch {self.id}') - return None - if n_job_bunches == 1: - await self._create_fast(byte_job_specs_bunches[0], job_bunch_sizes[0], job_progress_task) - start_job_id = 1 - else: - update_id = await self._open_batch() - assert update_id is not None - await self._submit_job_bunches(update_id, byte_job_specs_bunches, job_bunch_sizes, job_progress_task) - start_job_id = await self._commit_update(update_id) - self._submission_info = BatchSubmissionInfo(used_fast_path=False) - assert start_job_id == 1 - log.info(f'created batch {self.id}') - else: - if n_job_bunches == 0: - log.warning('Tried to submit an update with 0 jobs. Doing nothing.') - return None - if n_job_bunches == 1: - start_job_id = await self._update_fast(byte_job_specs_bunches[0], job_bunch_sizes[0], job_progress_task) else: - update_id = await self._create_update() - await self._submit_job_bunches(update_id, byte_job_specs_bunches, job_bunch_sizes, job_progress_task) - start_job_id = await self._commit_update(update_id) - self._submission_info = BatchSubmissionInfo(used_fast_path=False) - log.info(f'updated batch {self.id}') - return start_job_id + if n_bunches == 0: + log.warning('Tried to submit an update with 0 jobs and 0 job groups. 
Doing nothing.') + return (None, None) + if n_bunches == 1: + start_job_group_id, start_job_id = await self._update_fast( + byte_specs_bunches[0], + job_group_progress_task, + job_progress_task, + ) + else: + update_id = await self._create_update() + await self._submit_job_group_bunches(update_id, byte_specs_bunches, job_group_progress_task) + await self._submit_job_bunches(update_id, byte_specs_bunches, job_progress_task) + start_job_group_id, start_job_id = await self._commit_update(update_id) + self._submission_info = BatchSubmissionInfo(used_fast_path=False) + log.info(f'updated batch {self.id}') + return (start_job_group_id, start_job_id) MAX_BUNCH_BYTESIZE = 1024 * 1024 MAX_BUNCH_SIZE = 1024 - async def submit(self, - max_bunch_bytesize: int = MAX_BUNCH_BYTESIZE, - max_bunch_size: int = MAX_BUNCH_SIZE, - disable_progress_bar: bool = False, - *, - progress: Optional[BatchProgressBar] = None - ): + async def submit( + self, + max_bunch_bytesize: int = MAX_BUNCH_BYTESIZE, + max_bunch_size: int = MAX_BUNCH_SIZE, + disable_progress_bar: bool = False, + *, + progress: Optional[BatchProgressBar] = None, + ): assert max_bunch_bytesize > 0 assert max_bunch_size > 0 if progress: - start_job_id = await self._submit(max_bunch_bytesize, max_bunch_size, disable_progress_bar, progress) + start_job_group_id, start_job_id = await self._submit( + max_bunch_bytesize, max_bunch_size, disable_progress_bar, progress + ) else: with BatchProgressBar(disable=disable_progress_bar) as progress2: - start_job_id = await self._submit(max_bunch_bytesize, max_bunch_size, disable_progress_bar, progress2) + start_job_group_id, start_job_id = await self._submit( + max_bunch_bytesize, max_bunch_size, disable_progress_bar, progress2 + ) assert self.is_created + for jg in self._job_groups: + assert start_job_group_id is not None + jg._submit(start_job_group_id) + for j in self._jobs: assert start_job_id is not None j._submit(start_job_id) + self._job_group_specs = [] + self._job_groups = [] + self._in_update_job_group_id = 0 + self._job_specs = [] self._jobs = [] - self._job_idx = 0 + self._in_update_job_id = 0 class HailExplicitTokenCredentials(CloudCredentials): @@ -866,14 +1250,16 @@ async def close(self): class BatchClient: @staticmethod - async def create(billing_project: str, - deploy_config: Optional[DeployConfig] = None, - session: Optional[httpx.ClientSession] = None, - headers: Optional[Dict[str, str]] = None, - _token: Optional[str] = None, - token_file: Optional[str] = None, - *, - cloud_credentials_file: Optional[str] = None): + async def create( + billing_project: str, + deploy_config: Optional[DeployConfig] = None, + session: Optional[httpx.ClientSession] = None, + headers: Optional[Dict[str, str]] = None, + _token: Optional[str] = None, + token_file: Optional[str] = None, + *, + cloud_credentials_file: Optional[str] = None, + ): if not deploy_config: deploy_config = get_deploy_config() url = deploy_config.base_url('batch') @@ -888,13 +1274,10 @@ async def create(billing_project: str, billing_project=billing_project, url=url, session=Session(credentials=credentials, http_session=session, timeout=aiohttp.ClientTimeout(total=30)), - headers=headers) + headers=headers, + ) - def __init__(self, - billing_project: str, - url: str, - session: Session, - headers: Dict[str, str]): + def __init__(self, billing_project: str, url: str, session: Session, headers: Dict[str, str]): self.billing_project = billing_project self.url = url self._session: Session = session @@ -915,7 +1298,7 @@ async def _delete(self, path) -> 
aiohttp.ClientResponse: def reset_billing_project(self, billing_project): self.billing_project = billing_project - async def list_batches(self, q=None, last_batch_id=None, limit=2 ** 64, version=None): + async def list_batches(self, q=None, last_batch_id=None, limit=2**64, version=None): if version is None: version = 1 n = 0 @@ -947,10 +1330,7 @@ async def list_batches(self, q=None, last_batch_id=None, limit=2 ** 64, version= async def get_job(self, batch_id, job_id): b = await self.get_batch(batch_id) j_resp = await self._get(f'/api/v1alpha/batches/{batch_id}/jobs/{job_id}') - j = cast( - GetJobResponseV1Alpha, - await j_resp.json() - ) + j = cast(GetJobResponseV1Alpha, await j_resp.json()) return Job.submitted_job(b, j['job_id'], _status=j) async def get_job_log(self, batch_id, job_id) -> Dict[str, Any]: @@ -967,19 +1347,17 @@ async def get_batch(self, id) -> Batch: assert isinstance(b, dict), b attributes = b.get('attributes') assert attributes is None or isinstance(attributes, dict), attributes - return Batch(self, - id=b['id'], - attributes=attributes, - token=b['token'], - last_known_status=b) + return Batch(self, id=b['id'], attributes=attributes, token=b['token'], last_known_status=b) def create_batch(self, attributes=None, callback=None, token=None, cancel_after_n_failures=None) -> Batch: - return Batch(self, - id=None, - attributes=attributes, - callback=callback, - token=token, - cancel_after_n_failures=cancel_after_n_failures) + return Batch( + self, + id=None, + attributes=attributes, + callback=callback, + token=token, + cancel_after_n_failures=cancel_after_n_failures, + ) async def get_billing_project(self, billing_project): bp_resp = await self._get(f'/api/v1alpha/billing_projects/{billing_project}') diff --git a/hail/python/hailtop/batch_client/client.py b/hail/python/hailtop/batch_client/client.py index f6cf9f00749..4e5f72f197d 100644 --- a/hail/python/hailtop/batch_client/client.py +++ b/hail/python/hailtop/batch_client/client.py @@ -1,9 +1,11 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple -from hailtop.utils import async_to_blocking, ait_to_blocking +from hailtop.batch_client.types import GetJobGroupResponseV1Alpha +from hailtop.utils import ait_to_blocking, async_to_blocking + +from .. import httpx from ..config import DeployConfig from . import aioclient -from .. 
import httpx class Job: @@ -96,6 +98,118 @@ def attempts(self): return async_to_blocking(self._async_job.attempts()) +class JobGroup: + def __init__(self, async_job_group: aioclient.JobGroup): + self._async_job_group = async_job_group + + def attributes(self): + return async_to_blocking(self._async_job_group.attributes()) + + @property + def batch_id(self) -> int: + return self._async_job_group.batch_id + + @property + def job_group_id(self) -> int: + return self._async_job_group.job_group_id + + @property + def id(self) -> Tuple[int, int]: + return (self.batch_id, self.job_group_id) + + def cancel(self): + return async_to_blocking(self._async_job_group.cancel()) + + def jobs(self, q: Optional[str] = None, version: Optional[int] = None, recursive: bool = False): + return ait_to_blocking(self._async_job_group.jobs(q, version, recursive)) + + def job_groups(self): + return ait_to_blocking(self._async_job_group.job_groups()) + + def status(self) -> GetJobGroupResponseV1Alpha: + return async_to_blocking(self._async_job_group.status()) + + def wait(self, *args, **kwargs) -> GetJobGroupResponseV1Alpha: + return async_to_blocking(self._async_job_group.wait(*args, **kwargs)) + + def last_known_status(self) -> GetJobGroupResponseV1Alpha: + return async_to_blocking(self._async_job_group.last_known_status()) + + def create_job_group(self, *, attributes=None, callback=None, cancel_after_n_failures=None) -> 'JobGroup': + async_job_group = self._async_job_group.create_job_group( + attributes=attributes, callback=callback, cancel_after_n_failures=cancel_after_n_failures + ) + return JobGroup(async_job_group) + + def create_job( + self, + image, + command, + *, + env=None, + port=None, + resources=None, + secrets=None, + service_account=None, + attributes=None, + parents=None, + input_files=None, + output_files=None, + always_run=False, + timeout=None, + cloudfuse=None, + requester_pays_project=None, + mount_tokens=False, + network: Optional[str] = None, + unconfined: bool = False, + user_code: Optional[str] = None, + regions: Optional[List[str]] = None, + always_copy_output: bool = False, + ) -> Job: + if parents: + parents = [parent._async_job for parent in parents] + + async_job = self._async_job_group.create_job( + image, + command, + env=env, + port=port, + resources=resources, + secrets=secrets, + service_account=service_account, + attributes=attributes, + parents=parents, + input_files=input_files, + output_files=output_files, + always_run=always_run, + always_copy_output=always_copy_output, + timeout=timeout, + cloudfuse=cloudfuse, + requester_pays_project=requester_pays_project, + mount_tokens=mount_tokens, + network=network, + unconfined=unconfined, + user_code=user_code, + regions=regions, + ) + + return Job(async_job) + + def create_jvm_job(self, command, *, profile: bool = False, parents=None, **kwargs) -> Job: + if parents: + parents = [parent._async_job for parent in parents] + + async_job = self._async_job_group.create_jvm_job(command, profile=profile, parents=parents, **kwargs) + + return Job(async_job) + + def debug_info(self): + return async_to_blocking(self._async_job_group.debug_info()) + + def __str__(self): + return str(self._async_job_group) + + class Batch: @staticmethod def _open_batch(client: 'BatchClient', token: Optional[str] = None) -> 'Batch': @@ -126,30 +240,16 @@ def token(self): def _submission_info(self): return self._async_batch._submission_info + def get_job_group(self, job_group_id: int) -> JobGroup: + return JobGroup(self._async_batch.get_job_group(job_group_id)) + + 
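A small sketch of reading the same job-group information through the blocking wrappers added here; the billing project and batch id below are placeholders, and client cleanup is omitted:

from hailtop.batch_client.client import BatchClient

bc = BatchClient('my-billing-project')   # placeholder billing project
b = bc.get_batch(12345)                  # placeholder batch id

root = b.get_job_group(0)                # 0 is ROOT_JOB_GROUP_ID
for jg in b.job_groups():
    print(jg.id, jg.status()['state'])
for job in root.jobs(recursive=True):
    print(job['job_id'], job['state'])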
def job_groups(self): + for jg in ait_to_blocking(self._async_batch.job_groups()): + yield JobGroup(jg) + def cancel(self): async_to_blocking(self._async_batch.cancel()) - # { - # id: int - # user: str - # billing_project: str - # token: str - # state: str, (open, failure, cancelled, success, running) - # complete: bool - # closed: bool - # n_jobs: int - # n_completed: int - # n_succeeded: int - # n_failed: int - # n_cancelled: int - # time_created: optional(str), (date) - # time_closed: optional(str), (date) - # time_completed: optional(str), (date) - # duration: optional(str) - # attributes: optional(dict(str, str)) - # msec_mcpu: int - # cost: float - # } def status(self): return async_to_blocking(self._async_batch.status()) @@ -175,32 +275,63 @@ def debug_info(self): def delete(self): async_to_blocking(self._async_batch.delete()) - def create_job(self, - image, - command, - *, - env=None, - port=None, resources=None, secrets=None, - service_account=None, attributes=None, parents=None, - input_files=None, output_files=None, always_run=False, - timeout=None, cloudfuse=None, requester_pays_project=None, - mount_tokens=False, network: Optional[str] = None, - unconfined: bool = False, user_code: Optional[str] = None, - regions: Optional[List[str]] = None, - always_copy_output: bool = False) -> Job: + def create_job_group(self, *, attributes=None, callback=None, cancel_after_n_failures=None) -> JobGroup: + async_job_group = self._async_batch.create_job_group( + attributes=attributes, callback=callback, cancel_after_n_failures=cancel_after_n_failures + ) + return JobGroup(async_job_group) + + def create_job( + self, + image, + command, + *, + env=None, + port=None, + resources=None, + secrets=None, + service_account=None, + attributes=None, + parents=None, + input_files=None, + output_files=None, + always_run=False, + timeout=None, + cloudfuse=None, + requester_pays_project=None, + mount_tokens=False, + network: Optional[str] = None, + unconfined: bool = False, + user_code: Optional[str] = None, + regions: Optional[List[str]] = None, + always_copy_output: bool = False, + ) -> Job: if parents: parents = [parent._async_job for parent in parents] async_job = self._async_batch.create_job( - image, command, env=env, - port=port, resources=resources, secrets=secrets, + image, + command, + env=env, + port=port, + resources=resources, + secrets=secrets, service_account=service_account, - attributes=attributes, parents=parents, - input_files=input_files, output_files=output_files, always_run=always_run, - always_copy_output=always_copy_output, timeout=timeout, cloudfuse=cloudfuse, - requester_pays_project=requester_pays_project, mount_tokens=mount_tokens, - network=network, unconfined=unconfined, user_code=user_code, - regions=regions) + attributes=attributes, + parents=parents, + input_files=input_files, + output_files=output_files, + always_run=always_run, + always_copy_output=always_copy_output, + timeout=timeout, + cloudfuse=cloudfuse, + requester_pays_project=requester_pays_project, + mount_tokens=mount_tokens, + network=network, + unconfined=unconfined, + user_code=user_code, + regions=regions, + ) return Job(async_job) @@ -223,15 +354,20 @@ def from_async(async_client: aioclient.BatchClient): bc._async_client = async_client return bc - def __init__(self, - billing_project: str, - deploy_config: Optional[DeployConfig] = None, - session: Optional[httpx.ClientSession] = None, - headers: Optional[Dict[str, str]] = None, - _token: Optional[str] = None, - token_file: Optional[str] = None): - 
self._async_client = async_to_blocking(aioclient.BatchClient.create( - billing_project, deploy_config, session, headers=headers, _token=_token, token_file=token_file)) + def __init__( + self, + billing_project: str, + deploy_config: Optional[DeployConfig] = None, + session: Optional[httpx.ClientSession] = None, + headers: Optional[Dict[str, str]] = None, + _token: Optional[str] = None, + token_file: Optional[str] = None, + ): + self._async_client = async_to_blocking( + aioclient.BatchClient.create( + billing_project, deploy_config, session, headers=headers, _token=_token, token_file=token_file + ) + ) @property def billing_project(self): @@ -241,7 +377,9 @@ def reset_billing_project(self, billing_project): self._async_client.reset_billing_project(billing_project) def list_batches(self, q=None, last_batch_id=None, limit=2**64, version=None): - for b in ait_to_blocking(self._async_client.list_batches(q=q, last_batch_id=last_batch_id, limit=limit, version=version)): + for b in ait_to_blocking( + self._async_client.list_batches(q=q, last_batch_id=last_batch_id, limit=limit, version=version) + ): yield Batch(b) def get_job(self, batch_id, job_id): @@ -260,16 +398,10 @@ def get_batch(self, id): b = async_to_blocking(self._async_client.get_batch(id)) return Batch(b) - def create_batch(self, - attributes=None, - callback=None, - token=None, - cancel_after_n_failures=None - ) -> 'Batch': - batch = self._async_client.create_batch(attributes=attributes, - callback=callback, - token=token, - cancel_after_n_failures=cancel_after_n_failures) + def create_batch(self, attributes=None, callback=None, token=None, cancel_after_n_failures=None) -> 'Batch': + batch = self._async_client.create_batch( + attributes=attributes, callback=callback, token=token, cancel_after_n_failures=cancel_after_n_failures + ) return Batch(batch) def get_billing_project(self, billing_project): diff --git a/hail/python/hailtop/batch_client/globals.py b/hail/python/hailtop/batch_client/globals.py index 992ad292d15..f515148ae53 100644 --- a/hail/python/hailtop/batch_client/globals.py +++ b/hail/python/hailtop/batch_client/globals.py @@ -1,3 +1,7 @@ +ROOT_JOB_GROUP_ID = 0 + +MAX_JOB_GROUPS_DEPTH = 2 + tasks = ('input', 'main', 'output') complete_states = ('Cancelled', 'Error', 'Failed', 'Success') diff --git a/hail/python/hailtop/batch_client/parse.py b/hail/python/hailtop/batch_client/parse.py index 5c9f4023831..2775fba3217 100644 --- a/hail/python/hailtop/batch_client/parse.py +++ b/hail/python/hailtop/batch_client/parse.py @@ -1,6 +1,6 @@ -from typing import Optional, Mapping, Pattern -import re import math +import re +from typing import Mapping, Optional, Pattern MEMORY_REGEXPAT: str = r'[+]?((?:[0-9]*[.])?[0-9]+)([KMGTP][i]?)?B?' 
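As a rough illustration of what the memory pattern above and the conv_factor table below express, a memory request such as '4Gi' resolves to a byte count like this; the helper name and the abridged factor table are local to the example, not part of the module:

import re

_mem_pat = re.compile(r'[+]?((?:[0-9]*[.])?[0-9]+)([KMGTP][i]?)?B?')
_factors = {'K': 1000, 'Ki': 1024, 'M': 1000**2, 'Mi': 1024**2, 'G': 1000**3, 'Gi': 1024**3}  # abridged


def memory_to_bytes(s: str) -> int:
    m = _mem_pat.fullmatch(s)
    if m is None:
        raise ValueError(f'could not parse memory string: {s!r}')
    number, suffix = m.groups()
    return int(float(number) * (_factors[suffix] if suffix else 1))


assert memory_to_bytes('4Gi') == 4 * 1024**3      # binary suffix
assert memory_to_bytes('500MB') == 500 * 1000**2  # decimal suffix, optional trailing B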
MEMORY_REGEX: Pattern = re.compile(MEMORY_REGEXPAT) @@ -23,11 +23,16 @@ def parse_cpu_in_mcpu(cpu_string: str) -> Optional[int]: conv_factor: Mapping[str, int] = { - 'K': 1000, 'Ki': 1024, - 'M': 1000**2, 'Mi': 1024**2, - 'G': 1000**3, 'Gi': 1024**3, - 'T': 1000**4, 'Ti': 1024**4, - 'P': 1000**5, 'Pi': 1024**5 + 'K': 1000, + 'Ki': 1024, + 'M': 1000**2, + 'Mi': 1024**2, + 'G': 1000**3, + 'Gi': 1024**3, + 'T': 1000**4, + 'Ti': 1024**4, + 'P': 1000**5, + 'Pi': 1024**5, } diff --git a/hail/python/hailtop/batch_client/types.py b/hail/python/hailtop/batch_client/types.py index 2697cb34ee6..41b9ecc114f 100644 --- a/hail/python/hailtop/batch_client/types.py +++ b/hail/python/hailtop/batch_client/types.py @@ -1,4 +1,5 @@ -from typing import TypedDict, Literal, Optional, List, Any, Dict +from typing import Any, Dict, List, Literal, Optional, TypedDict + from typing_extensions import NotRequired @@ -10,6 +11,7 @@ class CostBreakdownEntry(TypedDict): class GetJobResponseV1Alpha(TypedDict): batch_id: int job_id: int + job_group_id: int name: Optional[str] user: str billing_project: str @@ -22,11 +24,14 @@ class GetJobResponseV1Alpha(TypedDict): status: Optional[Dict[str, Any]] spec: Optional[Dict[str, Any]] attributes: NotRequired[Dict[str, str]] + always_run: bool + display_state: Optional[str] class JobListEntryV1Alpha(TypedDict): batch_id: int job_id: int + job_group_id: int name: Optional[str] user: str billing_project: str @@ -36,8 +41,28 @@ class JobListEntryV1Alpha(TypedDict): cost: Optional[float] msec_mcpu: int cost_breakdown: List[CostBreakdownEntry] + always_run: bool + display_state: Optional[str] class GetJobsResponseV1Alpha(TypedDict): jobs: List[JobListEntryV1Alpha] last_job_id: NotRequired[int] + + +class GetJobGroupResponseV1Alpha(TypedDict): + batch_id: int + job_group_id: int + state: Literal['failure', 'cancelled', 'success', 'running'] + complete: bool + n_jobs: int + n_completed: int + n_succeeded: int + n_failed: int + n_cancelled: int + time_created: Optional[str] # date string + time_completed: Optional[str] # date string + duration: Optional[int] + cost: float + cost_breakdown: List[CostBreakdownEntry] + attributes: NotRequired[Dict[str, str]] diff --git a/hail/python/hailtop/cleanup_gcr/__main__.py b/hail/python/hailtop/cleanup_gcr/__main__.py index b9e445f34dd..38587f469b5 100644 --- a/hail/python/hailtop/cleanup_gcr/__main__.py +++ b/hail/python/hailtop/cleanup_gcr/__main__.py @@ -1,9 +1,11 @@ +import asyncio +import logging import sys import time from typing import Awaitable, List, TypeVar -import logging -import asyncio + import aiohttp + from hailtop import aiotools from hailtop.aiocloud import aiogoogle @@ -89,16 +91,14 @@ async def cleanup_image(self, image): await asyncio.gather(*[ self.cleanup_digest(image, digest, tags) for digest, time_uploaded, tags in manifests - if (now - time_uploaded) >= (7 * 24 * 60 * 60) or len(tags) == 0]) + if (now - time_uploaded) >= (7 * 24 * 60 * 60) or len(tags) == 0 + ]) log.info(f'cleaned up image {image}') async def run(self): images = await self._executor.submit(self._client.get('/tags/list')) - await asyncio.gather(*[ - self.cleanup_image(image) - for image in images['child'] - ]) + await asyncio.gather(*[self.cleanup_image(image) for image in images['child']]) async def main(): @@ -108,9 +108,7 @@ async def main(): raise ValueError('usage: cleanup_gcr ') project = sys.argv[1] - async with aiogoogle.GoogleContainerClient( - project=project, - timeout=aiohttp.ClientTimeout(total=5)) as client: + async with 
aiogoogle.GoogleContainerClient(project=project, timeout=aiohttp.ClientTimeout(total=5)) as client: cleanup_images = CleanupImages(client) try: await cleanup_images.run() diff --git a/hail/python/hailtop/config/__init__.py b/hail/python/hailtop/config/__init__.py index 14a1e46797d..47db5791a50 100644 --- a/hail/python/hailtop/config/__init__.py +++ b/hail/python/hailtop/config/__init__.py @@ -1,6 +1,12 @@ -from .user_config import (get_user_config, get_user_config_path, get_user_identity_config_path, - get_remote_tmpdir, configuration_of, get_hail_config_path) -from .deploy_config import get_deploy_config, DeployConfig +from .deploy_config import DeployConfig, get_deploy_config +from .user_config import ( + configuration_of, + get_hail_config_path, + get_remote_tmpdir, + get_user_config, + get_user_config_path, + get_user_identity_config_path, +) from .variables import ConfigVariable __all__ = [ diff --git a/hail/python/hailtop/config/deploy_config.py b/hail/python/hailtop/config/deploy_config.py index d6b66c7adc3..4996bd0424c 100644 --- a/hail/python/hailtop/config/deploy_config.py +++ b/hail/python/hailtop/config/deploy_config.py @@ -1,41 +1,51 @@ -from typing import Dict -import os import json import logging -from ..utils import first_extant_file +import os +import ssl +from typing import Dict, Optional, TypeVar, Union +from ..tls import external_client_ssl_context, internal_client_ssl_context, internal_server_ssl_context +from ..utils import first_extant_file from .user_config import get_user_config log = logging.getLogger('deploy_config') +T = TypeVar("T") -def env_var_or_default(name: str, defaults: Dict[str, str]) -> str: - return os.environ.get(f'HAIL_{name.upper()}') or defaults[name] +def env_var_or_default(name: str, default: T) -> Union[str, T]: + return os.environ.get(f'HAIL_{name.upper()}', default) -class DeployConfig: - @staticmethod - def from_config(config: Dict[str, str]) -> 'DeployConfig': - return DeployConfig( - env_var_or_default('location', config), - env_var_or_default('default_namespace', config), - env_var_or_default('domain', config) - ) - def get_config(self) -> Dict[str, str]: +class DeployConfig: + @classmethod + def from_config(cls, config: Dict[str, str]) -> 'DeployConfig': + location = env_var_or_default('location', config['location']) + domain = env_var_or_default('domain', config['domain']) + ns = env_var_or_default('default_namespace', config['default_namespace']) + base_path = env_var_or_default('base_path', config.get('base_path')) or None + if base_path is None and ns != 'default': + domain = f'internal.{config["domain"]}' + base_path = f'/{ns}' + + return cls(location, ns, domain, base_path) + + def get_config(self) -> Dict[str, Optional[str]]: return { 'location': self._location, 'default_namespace': self._default_namespace, - 'domain': self._domain + 'domain': self._domain, + 'base_path': self._base_path, } - @staticmethod - def from_config_file(config_file=None) -> 'DeployConfig': + @classmethod + def from_config_file(cls, config_file=None) -> 'DeployConfig': config_file = first_extant_file( config_file, os.environ.get('HAIL_DEPLOY_CONFIG_FILE'), os.path.expanduser('~/.hail/deploy-config.json'), - '/deploy-config/deploy-config.json') + '/deploy-config/deploy-config.json', + ) if config_file is not None: log.info(f'deploy config file found at {config_file}') with open(config_file, 'r', encoding='utf-8') as f: @@ -48,21 +58,22 @@ def from_config_file(config_file=None) -> 'DeployConfig': 'default_namespace': 'default', 'domain': 
get_user_config().get('global', 'domain', fallback='hail.is'), } - return DeployConfig.from_config(config) + return cls.from_config(config) - def __init__(self, location, default_namespace, domain): + def __init__(self, location: str, default_namespace: str, domain: str, base_path: Optional[str]): assert location in ('external', 'k8s', 'gce') self._location = location self._default_namespace = default_namespace self._domain = domain + self._base_path = base_path def with_default_namespace(self, default_namespace): - return DeployConfig(self._location, default_namespace, self._domain) + return DeployConfig(self._location, default_namespace, self._domain, self._base_path) def with_location(self, location): - return DeployConfig(location, self._default_namespace, self._domain) + return DeployConfig(location, self._default_namespace, self._domain, self._base_path) - def default_namespace(self) -> str: + def default_namespace(self): return self._default_namespace def location(self): @@ -77,19 +88,18 @@ def domain(self, service): if self._location == 'k8s': return f'{service}.{ns}' if self._location == 'gce': - if ns == 'default': + if self._base_path is None: return f'{service}.hail' return 'internal.hail' assert self._location == 'external' - if ns == 'default': + if self._base_path is None: return f'{service}.{self._domain}' - return f'internal.{self._domain}' + return self._domain def base_path(self, service): - ns = self._default_namespace - if ns == 'default': + if self._base_path is None: return '' - return f'/{ns}/{service}' + return f'{self._base_path}/{service}' def base_url(self, service, base_scheme='http'): return f'{self.scheme(base_scheme)}://{self.domain(service)}{self.base_path(service)}' @@ -103,15 +113,15 @@ def auth_session_cookie_name(self): return 'sesh' def external_url(self, service, path, base_scheme='http'): - ns = self._default_namespace - if ns == 'default': + if self._base_path is None: if service == 'www': return f'{base_scheme}s://{self._domain}{path}' return f'{base_scheme}s://{service}.{self._domain}{path}' - return f'{base_scheme}s://internal.{self._domain}/{ns}/{service}{path}' + return f'{base_scheme}s://{self._domain}{self._base_path}/{service}{path}' def prefix_application(self, app, service, **kwargs): from aiohttp import web # pylint: disable=import-outside-toplevel + base_path = self.base_path(service) if not base_path: return app @@ -133,6 +143,36 @@ async def get_metrics(_): log.info(f'serving paths at {base_path}') return root_app + def client_ssl_context(self) -> ssl.SSLContext: + if self._location == 'k8s': + return internal_client_ssl_context() + # no encryption on the internal gateway + return external_client_ssl_context() + + def server_ssl_context(self) -> Optional[ssl.SSLContext]: + if self._location == 'k8s': + return internal_server_ssl_context() + # local mode does not have access to self-signed certs + return None + + +class TerraDeployConfig(DeployConfig): + def domain(self, service): + if self._location == 'k8s': + return { + 'batch-driver': 'localhost:5000', + 'batch': 'localhost:5001', + }[service] + return self._domain + + def client_ssl_context(self) -> ssl.SSLContext: + # Terra app networking doesn't use self-signed certs + return external_client_ssl_context() + + def server_ssl_context(self) -> Optional[ssl.SSLContext]: + # Terra app services are in the same pod and just use http + return None + deploy_config = None @@ -141,5 +181,8 @@ def get_deploy_config() -> DeployConfig: global deploy_config if not deploy_config: - deploy_config 
= DeployConfig.from_config_file() + if os.environ.get('HAIL_TERRA'): + deploy_config = TerraDeployConfig.from_config_file() + else: + deploy_config = DeployConfig.from_config_file() return deploy_config diff --git a/hail/python/hailtop/config/user_config.py b/hail/python/hailtop/config/user_config.py index 5f1f2388f18..752031eb166 100644 --- a/hail/python/hailtop/config/user_config.py +++ b/hail/python/hailtop/config/user_config.py @@ -1,9 +1,9 @@ -from typing import Optional, Union, TypeVar +import configparser import os import re -import configparser import warnings from pathlib import Path +from typing import Optional, TypeVar, Union from .variables import ConfigVariable @@ -46,12 +46,9 @@ def get_user_config() -> configparser.ConfigParser: T = TypeVar('T') -def unchecked_configuration_of(section: str, - option: str, - explicit_argument: Optional[T], - fallback: T, - *, - deprecated_envvar: Optional[str] = None) -> Union[str, T]: +def unchecked_configuration_of( + section: str, option: str, explicit_argument: Optional[T], fallback: T, *, deprecated_envvar: Optional[str] = None +) -> Union[str, T]: if explicit_argument is not None: return explicit_argument @@ -60,13 +57,17 @@ def unchecked_configuration_of(section: str, deprecated_envval = None if deprecated_envvar is None else os.environ.get(deprecated_envvar) if envval is not None: if deprecated_envval is not None: - raise ValueError(f'Value for configuration variable {section}/{option} is ambiguous ' - f'because both {envvar} and {deprecated_envvar} are set (respectively ' - f'to: {envval} and {deprecated_envval}.') + raise ValueError( + f'Value for configuration variable {section}/{option} is ambiguous ' + f'because both {envvar} and {deprecated_envvar} are set (respectively ' + f'to: {envval} and {deprecated_envval}.' + ) return envval if deprecated_envval is not None: - warnings.warn(f'Use of deprecated envvar {deprecated_envvar} for configuration variable ' - f'{section}/{option}. Please use {envvar} instead.') + warnings.warn( + f'Use of deprecated envvar {deprecated_envvar} for configuration variable ' + f'{section}/{option}. Please use {envvar} instead.' 
+ ) return deprecated_envval from_user_config = get_user_config().get(section, option, fallback=None) @@ -76,11 +77,13 @@ def unchecked_configuration_of(section: str, return fallback -def configuration_of(config_variable: ConfigVariable, - explicit_argument: Optional[T], - fallback: T, - *, - deprecated_envvar: Optional[str] = None) -> Union[str, T]: +def configuration_of( + config_variable: ConfigVariable, + explicit_argument: Optional[T], + fallback: T, + *, + deprecated_envvar: Optional[str] = None, +) -> Union[str, T]: if '/' in config_variable.value: section, option = config_variable.value.split('/') else: @@ -89,22 +92,27 @@ def configuration_of(config_variable: ConfigVariable, return unchecked_configuration_of(section, option, explicit_argument, fallback, deprecated_envvar=deprecated_envvar) -def get_remote_tmpdir(caller_name: str, - *, - bucket: Optional[str] = None, - remote_tmpdir: Optional[str] = None, - user_config: Optional[configparser.ConfigParser] = None, - warnings_stacklevel: int = 2, - ) -> str: +def get_remote_tmpdir( + caller_name: str, + *, + bucket: Optional[str] = None, + remote_tmpdir: Optional[str] = None, + user_config: Optional[configparser.ConfigParser] = None, + warnings_stacklevel: int = 2, +) -> str: if user_config is None: user_config = get_user_config() if bucket is not None: - warnings.warn(f'Use of deprecated argument \'bucket\' in {caller_name}(...). Specify \'remote_tmpdir\' as a keyword argument instead.', - stacklevel=warnings_stacklevel) + warnings.warn( + f'Use of deprecated argument \'bucket\' in {caller_name}(...). Specify \'remote_tmpdir\' as a keyword argument instead.', + stacklevel=warnings_stacklevel, + ) if remote_tmpdir is not None and bucket is not None: - raise ValueError(f'Cannot specify both \'remote_tmpdir\' and \'bucket\' in {caller_name}(...). Specify \'remote_tmpdir\' as a keyword argument instead.') + raise ValueError( + f'Cannot specify both \'remote_tmpdir\' and \'bucket\' in {caller_name}(...). Specify \'remote_tmpdir\' as a keyword argument instead.' + ) if bucket is None and remote_tmpdir is None: remote_tmpdir = configuration_of(ConfigVariable.BATCH_REMOTE_TMPDIR, None, None) @@ -112,25 +120,30 @@ def get_remote_tmpdir(caller_name: str, if remote_tmpdir is None: if bucket is None: bucket = user_config.get('batch', 'bucket', fallback=None) - warnings.warn('Using deprecated configuration setting \'batch/bucket\'. Run `hailctl config set batch/remote_tmpdir` ' - 'to set the default for \'remote_tmpdir\' instead.', - stacklevel=warnings_stacklevel) + warnings.warn( + 'Using deprecated configuration setting \'batch/bucket\'. Run `hailctl config set batch/remote_tmpdir` ' + 'to set the default for \'remote_tmpdir\' instead.', + stacklevel=warnings_stacklevel, + ) if bucket is None: raise ValueError( f'Either the \'remote_tmpdir\' parameter of {caller_name}(...) must be set or you must ' - 'run `hailctl config set batch/remote_tmpdir REMOTE_TMPDIR`.') + 'run `hailctl config set batch/remote_tmpdir REMOTE_TMPDIR`.' + ) if 'gs://' in bucket: raise ValueError( f'The bucket parameter to {caller_name}(...) and the `batch/bucket` hailctl config setting ' 'must both be bucket names, not paths. Use the remote_tmpdir parameter or batch/remote_tmpdir ' - 'hailctl config setting instead to specify a path.') + 'hailctl config setting instead to specify a path.' 
+ ) remote_tmpdir = f'gs://{bucket}/batch' else: - schemes = {'gs', 'hail-az', 'https'} + schemes = {'gs', 'https'} found_scheme = any(remote_tmpdir.startswith(f'{scheme}://') for scheme in schemes) if not found_scheme: raise ValueError( - f'remote_tmpdir must be a storage uri path like gs://bucket/folder. Received: {remote_tmpdir}. Possible schemes include gs for GCP and https for Azure') + f'remote_tmpdir must be a storage uri path like gs://bucket/folder. Received: {remote_tmpdir}. Possible schemes include gs for GCP and https for Azure' + ) if remote_tmpdir[-1] != '/': remote_tmpdir += '/' return remote_tmpdir diff --git a/hail/python/hailtop/config/variables.py b/hail/python/hailtop/config/variables.py index cfd82a3d774..036bae78d26 100644 --- a/hail/python/hailtop/config/variables.py +++ b/hail/python/hailtop/config/variables.py @@ -19,3 +19,4 @@ class ConfigVariable(str, Enum): QUERY_BATCH_WORKER_MEMORY = 'query/batch_worker_memory' QUERY_NAME_PREFIX = 'query/name_prefix' QUERY_DISABLE_PROGRESS_BAR = 'query/disable_progress_bar' + HTTP_TIMEOUT_IN_SECONDS = 'http/timeout_in_seconds' diff --git a/hail/python/hailtop/frozendict.py b/hail/python/hailtop/frozendict.py index eef7cf0c219..0f4c23b9df0 100644 --- a/hail/python/hailtop/frozendict.py +++ b/hail/python/hailtop/frozendict.py @@ -1,5 +1,5 @@ -from typing import TypeVar, Dict, Generic from collections.abc import Mapping +from typing import Dict, Generic, TypeVar T = TypeVar("T") U = TypeVar("U") @@ -25,6 +25,7 @@ class frozendict(Mapping, Generic[T, U]): python does not. """ + def __init__(self, d: Dict[T, U]): self.d = d.copy() diff --git a/hail/python/hailtop/fs/__init__.py b/hail/python/hailtop/fs/__init__.py index a06f18018db..5a62b8f5e79 100644 --- a/hail/python/hailtop/fs/__init__.py +++ b/hail/python/hailtop/fs/__init__.py @@ -1,14 +1,14 @@ from .fs_utils import ( - open, copy, exists, - is_file, is_dir, - stat, + is_file, ls, mkdir, + open, remove, rmtree, + stat, ) __all__ = [ diff --git a/hail/python/hailtop/fs/fs.py b/hail/python/hailtop/fs/fs.py index 26a2b944a37..2b0243c2bac 100644 --- a/hail/python/hailtop/fs/fs.py +++ b/hail/python/hailtop/fs/fs.py @@ -1,12 +1,13 @@ import abc -from typing import IO, List +import io +from typing import List from .stat_result import FileListEntry class FS(abc.ABC): @abc.abstractmethod - def open(self, path: str, mode: str = 'r', buffer_size: int = 8192) -> IO: + def open(self, path: str, mode: str = 'r', buffer_size: int = 8192) -> io.IOBase: raise NotImplementedError @abc.abstractmethod diff --git a/hail/python/hailtop/fs/fs_utils.py b/hail/python/hailtop/fs/fs_utils.py index 1b6c42089ad..955bcdbfd93 100644 --- a/hail/python/hailtop/fs/fs_utils.py +++ b/hail/python/hailtop/fs/fs_utils.py @@ -1,20 +1,22 @@ import io -from typing import List +from typing import List, Optional + +from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration +from hailtop.utils.gcs_requester_pays import GCSRequesterPaysFSCache from .router_fs import RouterFS from .stat_result import FileListEntry -_router_fs = None - - -def _fs() -> RouterFS: - global _router_fs - if _router_fs is None: - _router_fs = RouterFS() - return _router_fs +_fses = GCSRequesterPaysFSCache(fs_constructor=RouterFS) -def open(path: str, mode: str = 'r', buffer_size: int = 8192) -> io.IOBase: +def open( + path: str, + mode: str = 'r', + buffer_size: int = 8192, + *, + requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None, +) -> io.IOBase: """Open a file from the local filesystem of from blob storage. 
Supported blob storage providers are GCS, S3 and ABS. @@ -31,6 +33,25 @@ def open(path: str, mode: str = 'r', buffer_size: int = 8192) -> io.IOBase: ... for line in f: ... print(line.strip()) + Access a text file stored in a Requester Pays Bucket in Google Cloud Storage: + + >>> with hfs.open( # doctest: +SKIP + ... 'gs://my-bucket/notes.txt', + ... requester_pays_config='my-project' + ... ) as f: + ... for line in f: + ... print(line.strip()) + + Specify multiple Requester Pays Buckets within a project that are acceptable + to access: + + >>> with hfs.open( # doctest: +SKIP + ... 'gs://my-bucket/notes.txt', + ... requester_pays_config=('my-project', ['my-bucket', 'bucket-2']) + ... ) as f: + ... for line in f: + ... print(line.strip()) + Write two lines directly to a file in Google Cloud Storage: >>> with hfs.open('gs://my-bucket/notes.txt', 'w') as f: # doctest: +SKIP @@ -72,10 +93,10 @@ def open(path: str, mode: str = 'r', buffer_size: int = 8192) -> io.IOBase: ------- Readable or writable file handle. """ - return _fs().open(path, mode, buffer_size) + return _fses[requester_pays_config].open(path, mode, buffer_size) -def copy(src: str, dest: str): +def copy(src: str, dest: str, *, requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None): """Copy a file between filesystems. Filesystems can be local filesystem or the blob storage providers GCS, S3 and ABS. @@ -105,10 +126,10 @@ def copy(src: str, dest: str): dest: :class:`str` Destination file URI. """ - _fs().copy(src, dest) + _fses[requester_pays_config].copy(src, dest) -def exists(path: str) -> bool: +def exists(path: str, *, requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None) -> bool: """Returns ``True`` if `path` exists. Parameters @@ -119,10 +140,10 @@ def exists(path: str) -> bool: ------- :obj:`.bool` """ - return _fs().exists(path) + return _fses[requester_pays_config].exists(path) -def is_file(path: str) -> bool: +def is_file(path: str, *, requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None) -> bool: """Returns ``True`` if `path` both exists and is a file. Parameters @@ -133,10 +154,10 @@ def is_file(path: str) -> bool: ------- :obj:`.bool` """ - return _fs().is_file(path) + return _fses[requester_pays_config].is_file(path) -def is_dir(path: str) -> bool: +def is_dir(path: str, *, requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None) -> bool: """Returns ``True`` if `path` both exists and is a directory. Parameters @@ -147,10 +168,10 @@ def is_dir(path: str) -> bool: ------- :obj:`.bool` """ - return _fs().is_dir(path) + return _fses[requester_pays_config].is_dir(path) -def stat(path: str) -> FileListEntry: +def stat(path: str, *, requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None) -> FileListEntry: """Returns information about the file or directory at a given path. Notes @@ -174,10 +195,10 @@ def stat(path: str) -> FileListEntry: ------- :obj:`dict` """ - return _fs().stat(path) + return _fses[requester_pays_config].stat(path) -def ls(path: str) -> List[FileListEntry]: +def ls(path: str, *, requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None) -> List[FileListEntry]: """Returns information about files at `path`. 
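The same `requester_pays_config` keyword is threaded through the remaining module-level helpers (`ls`, `mkdir`, `remove`, `rmtree` below), each dispatching to a cached filesystem via `_fses[requester_pays_config]`. A small usage sketch, following the `hfs` alias used in the doctests above; the bucket and project names are made up:

    import hailtop.fs as hfs

    # Access to the requester-pays bucket is billed to 'my-project'.
    entries = hfs.ls('gs://my-bucket/data', requester_pays_config='my-project')
    hfs.copy('gs://my-bucket/data/notes.txt', '/tmp/notes.txt', requester_pays_config='my-project')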
Notes @@ -205,10 +226,10 @@ def ls(path: str) -> List[FileListEntry]: ------- :obj:`list` [:obj:`dict`] """ - return _fs().ls(path) + return _fses[requester_pays_config].ls(path) -def mkdir(path: str): +def mkdir(path: str, *, requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None): """Ensure files can be created whose dirname is `path`. Warning @@ -218,10 +239,10 @@ def mkdir(path: str): on Google Cloud Storage, this operation does nothing. """ - _fs().mkdir(path) + _fses[requester_pays_config].mkdir(path) -def remove(path: str): +def remove(path: str, *, requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None): """Removes the file at `path`. If the file does not exist, this function does nothing. `path` must be a URI (uniform resource identifier) or a path on the local filesystem. @@ -230,10 +251,10 @@ def remove(path: str): ---------- path : :class:`str` """ - _fs().remove(path) + _fses[requester_pays_config].remove(path) -def rmtree(path: str): +def rmtree(path: str, *, requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None): """Recursively remove all files under the given `path`. On a local filesystem, this removes the directory tree at `path`. On blob storage providers such as GCS, S3 and ABS, this removes all files whose name starts with `path`. As such, @@ -243,4 +264,4 @@ def rmtree(path: str): ---------- path : :class:`str` """ - _fs().rmtree(path) + _fses[requester_pays_config].rmtree(path) diff --git a/hail/python/hailtop/fs/router_fs.py b/hail/python/hailtop/fs/router_fs.py index e4f23442e10..cbc1b02e1a1 100644 --- a/hail/python/hailtop/fs/router_fs.py +++ b/hail/python/hailtop/fs/router_fs.py @@ -1,18 +1,27 @@ -from typing import List, AsyncContextManager, BinaryIO, Optional, Tuple, Dict, Any import asyncio -import io -import os +import fnmatch import functools import glob -import fnmatch - -from hailtop.aiotools.fs import Copier, Transfer, FileListEntry as AIOFileListEntry, ReadableStream, WritableStream -from hailtop.aiotools.local_fs import LocalAsyncFS +import io +import os +from types import TracebackType +from typing import Any, AsyncContextManager, BinaryIO, Dict, List, Optional, Tuple, Type + +from hailtop.aiotools.fs import ( + AsyncFSURL, + Copier, + ReadableStream, + Transfer, + WritableStream, +) +from hailtop.aiotools.fs import ( + FileListEntry as AIOFileListEntry, +) from hailtop.aiotools.router_fs import RouterAsyncFS -from hailtop.utils import bounded_gather2, async_to_blocking +from hailtop.utils import async_to_blocking, bounded_gather2 from .fs import FS -from .stat_result import FileType, FileListEntry +from .stat_result import FileListEntry, FileType class SyncReadableStream(io.RawIOBase, BinaryIO): # type: ignore # https://github.com/python/typeshed/blob/a40d79a4e63c4e750a8d3a8012305da942251eb4/stdlib/http/client.pyi#L81 @@ -50,8 +59,8 @@ def isatty(self): def readable(self): return True - def seek(self, offset: int, whence: int = os.SEEK_SET): - async_to_blocking(self.ars.seek(offset, whence)) + def seek(self, offset: int, whence: int = os.SEEK_SET) -> int: + return async_to_blocking(self.ars.seek(offset, whence)) def seekable(self) -> bool: return self.ars.seekable() @@ -59,7 +68,7 @@ def seekable(self) -> bool: def tell(self) -> int: return self.ars.tell() - def truncate(self): + def truncate(self, size: Optional[int] = None): raise io.UnsupportedOperation def writable(self): @@ -117,7 +126,7 @@ def isatty(self): def readable(self): return False - def readline(self, size=-1): + def readline(self, size: 
Optional[int] = -1): raise OSError def readlines(self, hint=-1): @@ -132,7 +141,7 @@ def seekable(self): def tell(self): raise io.UnsupportedOperation - def truncate(self): + def truncate(self, size: Optional[int] = None): raise io.UnsupportedOperation def writable(self): @@ -163,28 +172,46 @@ def _stat_result(is_dir: bool, size_bytes_and_time_modified: Optional[Tuple[int, size=size_bytes, typ=FileType.DIRECTORY if is_dir else FileType.FILE, owner=None, - modification_time=time_modified) + modification_time=time_modified, + ) class RouterFS(FS): - def __init__(self, - afs: Optional[RouterAsyncFS] = None, - *, - local_kwargs: Optional[Dict[str, Any]] = None, - gcs_kwargs: Optional[Dict[str, Any]] = None, - azure_kwargs: Optional[Dict[str, Any]] = None, - s3_kwargs: Optional[Dict[str, Any]] = None): + def __init__( + self, + afs: Optional[RouterAsyncFS] = None, + *, + local_kwargs: Optional[Dict[str, Any]] = None, + gcs_kwargs: Optional[Dict[str, Any]] = None, + azure_kwargs: Optional[Dict[str, Any]] = None, + s3_kwargs: Optional[Dict[str, Any]] = None, + ): if afs and (local_kwargs or gcs_kwargs or azure_kwargs or s3_kwargs): raise ValueError( f'If afs is specified, no other arguments may be specified: {afs=}, {local_kwargs=}, {gcs_kwargs=}, {azure_kwargs=}, {s3_kwargs=}' ) self.afs = afs or RouterAsyncFS( - local_kwargs=local_kwargs, - gcs_kwargs=gcs_kwargs, - azure_kwargs=azure_kwargs, - s3_kwargs=s3_kwargs + local_kwargs=local_kwargs, gcs_kwargs=gcs_kwargs, azure_kwargs=azure_kwargs, s3_kwargs=s3_kwargs ) + def __enter__(self): + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ): + self.close() + + def close(self): + async_to_blocking(self.afs.close()) + + @property + def _gcs_kwargs(self) -> Optional[Dict[str, Any]]: + return self.afs._gcs_kwargs + def open(self, path: str, mode: str = 'r', buffer_size: int = 8192) -> io.IOBase: del buffer_size @@ -213,6 +240,7 @@ def copy(self, src: str, dest: str, *, max_simultaneous_transfers=75): async def _copy(): sema = asyncio.Semaphore(max_simultaneous_transfers) await Copier.copy(self.afs, sema, transfer) + return async_to_blocking(_copy()) def exists(self, path: str) -> bool: @@ -220,9 +248,8 @@ async def _exists(): dir_path = path if dir_path[-1] != '/': dir_path = dir_path + '/' - return any(await asyncio.gather( - self.afs.isfile(path), - self.afs.isdir(dir_path))) + return any(await asyncio.gather(self.afs.isfile(path), self.afs.isdir(dir_path))) + return async_to_blocking(_exists()) def is_file(self, path: str) -> bool: @@ -237,8 +264,9 @@ def is_dir(self, path: str) -> bool: return async_to_blocking(self._async_is_dir(path)) def stat(self, path: str) -> FileListEntry: - maybe_sb_and_t, is_dir = async_to_blocking(asyncio.gather( - self._size_bytes_and_time_modified_or_none(path), self._async_is_dir(path))) + maybe_sb_and_t, is_dir = async_to_blocking( + asyncio.gather(self._size_bytes_and_time_modified_or_none(path), self._async_is_dir(path)) + ) if maybe_sb_and_t is None: if not is_dir: raise FileNotFoundError(path) @@ -260,42 +288,46 @@ async def maybe_status() -> Optional[Tuple[int, float]]: return (await file_status.size(), file_status.time_modified().timestamp()) except IsADirectoryError: return None - return _stat_result( - *await asyncio.gather(fle.is_dir(), maybe_status(), fle.url())) - - def ls(self, - path: str, - *, - error_when_file_and_directory: bool = True, - _max_simultaneous_files: int = 50) -> 
List[FileListEntry]: - return async_to_blocking(self._async_ls( - path, - error_when_file_and_directory=error_when_file_and_directory, - _max_simultaneous_files=_max_simultaneous_files)) - - async def _async_ls(self, - path: str, - *, - error_when_file_and_directory: bool = True, - _max_simultaneous_files: int = 50) -> List[FileListEntry]: + + return _stat_result(*await asyncio.gather(fle.is_dir(), maybe_status(), fle.url())) + + def ls( + self, path: str, *, error_when_file_and_directory: bool = True, _max_simultaneous_files: int = 50 + ) -> List[FileListEntry]: + return async_to_blocking( + self._async_ls( + path, + error_when_file_and_directory=error_when_file_and_directory, + _max_simultaneous_files=_max_simultaneous_files, + ) + ) + + async def _async_ls( + self, path: str, *, error_when_file_and_directory: bool = True, _max_simultaneous_files: int = 50 + ) -> List[FileListEntry]: sema = asyncio.Semaphore(_max_simultaneous_files) async def ls_no_glob(path) -> List[FileListEntry]: try: - return await self._ls_no_glob(path, - error_when_file_and_directory=error_when_file_and_directory, - sema=sema) + return await self._ls_no_glob( + path, error_when_file_and_directory=error_when_file_and_directory, sema=sema + ) except FileNotFoundError: return [] + async def list_within_each_prefix(prefixes: List[AsyncFSURL], parts: List[str]) -> List[List[FileListEntry]]: + pfs = [functools.partial(ls_no_glob, str(prefix.with_new_path_components(*parts))) for prefix in prefixes] + return await bounded_gather2(sema, *pfs, cancel_on_error=True) + url = self.afs.parse_url(path) - if any(glob.escape(bucket_part) != bucket_part - for bucket_part in url.bucket_parts): + if any(glob.escape(bucket_part) != bucket_part for bucket_part in url.bucket_parts): raise ValueError(f'glob pattern only allowed in path (e.g. 
not in bucket): {path}') blobpath = url.path - components = blobpath.split('/') - assert len(components) > 0 + if blobpath == '': + components = [] + else: + components = blobpath.split('/') glob_components = [] running_prefix = [] @@ -309,78 +341,53 @@ async def ls_no_glob(path) -> List[FileListEntry]: running_prefix = [] suffix_components: List[str] = running_prefix - if len(url.bucket_parts) > 0: - first_prefix = [url.scheme + ':', '', *url.bucket_parts] - else: - assert url.scheme == 'file' - if path.startswith('file://'): - first_prefix = ['file:', '', ''] - else: - first_prefix = [] - cached_stats_for_each_cumulative_prefix: Optional[List[FileListEntry]] = None - cumulative_prefixes = [first_prefix] + cumulative_prefixes: List[AsyncFSURL] = [url.with_root_path()] for intervening_components, single_component_glob_pattern in glob_components: - stats_grouped_by_prefix = await bounded_gather2( - sema, - *[ - functools.partial(ls_no_glob, '/'.join([*cumulative_prefix, *intervening_components])) - for cumulative_prefix in cumulative_prefixes - ], - cancel_on_error=True - ) + stats_grouped_by_prefix = await list_within_each_prefix(cumulative_prefixes, intervening_components) cached_stats_for_each_cumulative_prefix = [ stat for stats_for_one_prefix, cumulative_prefix in zip(stats_grouped_by_prefix, cumulative_prefixes) for stat in stats_for_one_prefix - if fnmatch.fnmatch(stat.path, - '/'.join([*cumulative_prefix, *intervening_components, single_component_glob_pattern])) - ] - cumulative_prefixes = [ - stat.path.split('/') - for stat in cached_stats_for_each_cumulative_prefix + if fnmatch.fnmatch( + stat.path, + str( + cumulative_prefix.with_new_path_components( + *intervening_components, single_component_glob_pattern + ) + ), + ) ] + cumulative_prefixes = [self.afs.parse_url(stat.path) for stat in cached_stats_for_each_cumulative_prefix] if len(suffix_components) == 0 and cached_stats_for_each_cumulative_prefix is not None: found_stats = cached_stats_for_each_cumulative_prefix else: - found_stats_grouped_by_prefix = await bounded_gather2( - sema, - *[ - functools.partial(ls_no_glob, '/'.join([*cumulative_prefix, *suffix_components])) - for cumulative_prefix in cumulative_prefixes - ], - cancel_on_error=True - ) - found_stats = [ - stat - for stats in found_stats_grouped_by_prefix - for stat in stats - ] + found_stats_grouped_by_prefix = await list_within_each_prefix(cumulative_prefixes, suffix_components) + found_stats = [stat for stats in found_stats_grouped_by_prefix for stat in stats] if len(glob_components) == 0 and len(found_stats) == 0: # Unless we are using a glob pattern, a path referring to no files should error raise FileNotFoundError(path) return found_stats - async def _ls_no_glob(self, - path: str, - *, - error_when_file_and_directory: bool = True, - sema: asyncio.Semaphore) -> List[FileListEntry]: + async def _ls_no_glob( + self, path: str, *, error_when_file_and_directory: bool = True, sema: asyncio.Semaphore + ) -> List[FileListEntry]: async def ls_as_dir() -> Optional[List[FileListEntry]]: try: return await bounded_gather2( sema, - *[functools.partial(self._aiofle_to_fle, fle) - async for fle in await self.afs.listfiles(path)], - cancel_on_error=True + *[functools.partial(self._aiofle_to_fle, fle) async for fle in await self.afs.listfiles(path)], + cancel_on_error=True, ) except (FileNotFoundError, NotADirectoryError): return None + maybe_sb_and_t, maybe_contents = await asyncio.gather( - self._size_bytes_and_time_modified_or_none(path), ls_as_dir()) + 
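For context, a brief sketch of how the reshaped `RouterFS` is intended to be used after these changes: it is now a context manager whose `close()` shuts down the underlying async filesystem, and `ls` accepts glob patterns in the path portion of a URI (but not in the bucket). The bucket and file names here are hypothetical:

    from hailtop.fs.router_fs import RouterFS

    with RouterFS() as fs:                      # __exit__ calls fs.close()
        for entry in fs.ls('gs://my-bucket/data/*.vcf.gz'):
            print(entry.path, entry.size)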
self._size_bytes_and_time_modified_or_none(path), ls_as_dir() + ) if maybe_sb_and_t is not None: file_stat = _stat_result(False, maybe_sb_and_t, path) @@ -412,10 +419,11 @@ async def armtree(self, path: str) -> None: return await self.afs.rmtree(None, path) def supports_scheme(self, scheme: str) -> bool: - return scheme in self.afs.schemes + return scheme in self.afs.schemes() def canonicalize_path(self, path: str) -> str: - if isinstance(self.afs._get_fs(path), LocalAsyncFS): + url = self.afs.parse_url(path) + if url.scheme == 'file': if path.startswith('file:'): return 'file:' + os.path.realpath(path[5:]) return 'file:' + os.path.realpath(path) diff --git a/hail/python/hailtop/fs/stat_result.py b/hail/python/hailtop/fs/stat_result.py index 4cbb27c6267..e846eb0bf94 100644 --- a/hail/python/hailtop/fs/stat_result.py +++ b/hail/python/hailtop/fs/stat_result.py @@ -1,5 +1,5 @@ from enum import Enum, auto -from typing import Dict, NamedTuple, Optional, Union, Any +from typing import Any, Dict, NamedTuple, Optional, Union from hailtop.utils.filesize import filesize @@ -10,6 +10,23 @@ class FileType(Enum): SYMLINK = auto() +class FileStatus(NamedTuple): + path: str + owner: Union[None, str, int] + size: int + # common point between unix, google, and hadoop filesystems, represented as a unix timestamp + modification_time: Optional[float] + + def to_legacy_dict(self) -> Dict[str, Any]: + return { + 'path': self.path, + 'owner': self.owner, + 'size_bytes': self.size, + 'size': filesize(self.size), + 'modification_time': self.modification_time, + } + + class FileListEntry(NamedTuple): path: str owner: Union[None, str, int] diff --git a/hail/python/hailtop/hail_decorator.py b/hail/python/hailtop/hail_decorator.py index 979d1f41105..5e23bc78410 100644 --- a/hail/python/hailtop/hail_decorator.py +++ b/hail/python/hailtop/hail_decorator.py @@ -1,16 +1,15 @@ -from typing import TypeVar, Callable, cast, Protocol -from typing_extensions import ParamSpec +from typing import Callable, Protocol, TypeVar, cast + from decorator import decorator as _decorator +from typing_extensions import ParamSpec P = ParamSpec('P') T = TypeVar('T') + class Wrapper(Protocol[P, T]): - def __call__(self, fun: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: - ... + def __call__(self, fun: Callable[P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: ... -def decorator( - fun: Wrapper[P, T] -) -> Callable[[Callable[P, T]], Callable[P, T]]: +def decorator(fun: Wrapper[P, T]) -> Callable[[Callable[P, T]], Callable[P, T]]: return cast(Callable[[Callable[P, T]], Callable[P, T]], _decorator(fun)) diff --git a/hail/python/hailtop/hail_event_loop.py b/hail/python/hailtop/hail_event_loop.py index 229232604eb..b515e6b55c8 100644 --- a/hail/python/hailtop/hail_event_loop.py +++ b/hail/python/hailtop/hail_event_loop.py @@ -1,16 +1,19 @@ import asyncio + import nest_asyncio -def hail_event_loop(): - '''If a running event loop exists, use nest_asyncio to allow Hail's event loops to nest inside +def hail_event_loop() -> asyncio.AbstractEventLoop: + """If a running event loop exists, use nest_asyncio to allow Hail's event loops to nest inside it. If no event loop exists, ask asyncio to get one for us. 
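A minimal sketch of the `hail_event_loop` behaviour described in this docstring, assuming `nest_asyncio` is available as imported above:

    import asyncio
    from hailtop.hail_event_loop import hail_event_loop

    loop = hail_event_loop()        # the current event loop if one exists, otherwise a fresh one
    assert isinstance(loop, asyncio.AbstractEventLoop)
    loop.run_until_complete(asyncio.sleep(0))   # nest_asyncio.apply(loop) permits nested use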
- ''' + """ try: - asyncio.get_running_loop() - nest_asyncio.apply() - return asyncio.get_running_loop() + loop = asyncio.get_event_loop() except RuntimeError: - return asyncio.get_event_loop() + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + nest_asyncio.apply(loop) + return loop diff --git a/hail/python/hailtop/hail_frozenlist.py b/hail/python/hailtop/hail_frozenlist.py index 3d5da29bd92..8d050954af2 100644 --- a/hail/python/hailtop/hail_frozenlist.py +++ b/hail/python/hailtop/hail_frozenlist.py @@ -1,6 +1,6 @@ -from typing import TypeVar, Sequence, List -from frozenlist import FrozenList as _FrozenList +from typing import List, Sequence, TypeVar +from frozenlist import FrozenList as _FrozenList T = TypeVar('T') diff --git a/hail/python/hailtop/hail_logging.py b/hail/python/hailtop/hail_logging.py index 34cd21a62da..9a4283bcd1b 100644 --- a/hail/python/hailtop/hail_logging.py +++ b/hail/python/hailtop/hail_logging.py @@ -14,18 +14,21 @@ def __init__(self, *args, **kwargs): del kwargs -def logger_json_serializer(log_record, - default=None, - cls=None, - indent=None, - ensure_ascii=False) -> str: - assert default is None and cls is OrJsonEncoder and indent is None and ensure_ascii is False, (default, cls, indent, ensure_ascii) +def logger_json_serializer(log_record, default=None, cls=None, indent=None, ensure_ascii=False) -> str: + assert default is None and cls is OrJsonEncoder and indent is None and ensure_ascii is False, ( + default, + cls, + indent, + ensure_ascii, + ) return orjson.dumps(log_record).decode('utf-8') class CustomJsonFormatter(jsonlogger.JsonFormatter): def __init__(self, format_string): - super().__init__(format_string, json_encoder=OrJsonEncoder, json_serializer=logger_json_serializer, json_ensure_ascii=False) + super().__init__( + format_string, json_encoder=OrJsonEncoder, json_serializer=logger_json_serializer, json_ensure_ascii=False + ) def add_fields(self, log_record, record, message_dict): super().add_fields(log_record, record, message_dict) @@ -57,9 +60,10 @@ def log(self, request, response, time): 'request_start_time': start_time_str, 'request_duration': time, 'response_status': response.status, - 'x_real_ip': request.headers.get("X-Real-IP") + 'x_real_ip': request.headers.get("X-Real-IP"), } - self.logger.info(f'{request.scheme} {request.method} {request.path} ' - f'done in {time}s: {response.status}', - extra={**extra, **request.get('batch_telemetry', {})}) + self.logger.info( + f'{request.scheme} {request.method} {request.path} ' f'done in {time}s: {response.status}', + extra={**extra, **request.get('batch_telemetry', {})}, + ) diff --git a/hail/python/hailtop/hailctl/__main__.py b/hail/python/hailtop/hailctl/__main__.py index ebb9ad02e58..98e80043868 100644 --- a/hail/python/hailtop/hailctl/__main__.py +++ b/hail/python/hailtop/hailctl/__main__.py @@ -1,15 +1,15 @@ -import typer import os +import typer + from .auth import cli as auth_cli from .batch import cli as batch_cli from .config import cli as config_cli -from .describe import describe from .dataproc import cli as dataproc_cli +from .describe import describe from .dev import cli as dev_cli from .hdinsight import cli as hdinsight_cli - app = typer.Typer( help='Manage and monitor hail deployments.', no_args_is_help=True, @@ -29,8 +29,9 @@ @app.command() def version(): - '''Print version information and exit.''' + """Print version information and exit.""" import hailtop # pylint: disable=import-outside-toplevel + print(hailtop.version()) @@ -41,8 +42,9 @@ def curl( path: 
str, ctx: typer.Context, ): - '''Issue authenticated curl requests to Hail infrastructure.''' + """Issue authenticated curl requests to Hail infrastructure.""" from hailtop.utils import async_to_blocking # pylint: disable=import-outside-toplevel + async_to_blocking(_curl(namespace, service, path, ctx)) diff --git a/hail/python/hailtop/hailctl/auth/cli.py b/hail/python/hailtop/hailctl/auth/cli.py index bba44d240f6..af06c2d03a0 100644 --- a/hail/python/hailtop/hailctl/auth/cli.py +++ b/hail/python/hailtop/hailctl/auth/cli.py @@ -1,11 +1,11 @@ import asyncio +import json import sys +from typing import Annotated as Ann +from typing import Optional + import typer from typer import Argument as Arg -import json - -from typing import Optional, Annotated as Ann - app = typer.Typer( name='auth', @@ -17,14 +17,15 @@ @app.command() def login(): - '''Obtain Hail credentials.''' + """Obtain Hail credentials.""" from .login import async_login # pylint: disable=import-outside-toplevel + asyncio.run(async_login()) @app.command() def copy_paste_login(copy_paste_token: str): - '''Obtain Hail credentials with a copy paste token.''' + """Obtain Hail credentials with a copy paste token.""" from hailtop.auth import copy_paste_login # pylint: disable=import-outside-toplevel from hailtop.config import get_deploy_config # pylint: disable=import-outside-toplevel @@ -34,7 +35,7 @@ def copy_paste_login(copy_paste_token: str): @app.command() def logout(): - '''Revoke Hail credentials.''' + """Revoke Hail credentials.""" from hailtop.auth import async_logout # pylint: disable=import-outside-toplevel asyncio.run(async_logout()) @@ -42,9 +43,9 @@ def logout(): @app.command() def list(): - '''List Hail credentials.''' - from hailtop.config import get_deploy_config # pylint: disable=import-outside-toplevel + """List Hail credentials.""" from hailtop.auth import get_tokens # pylint: disable=import-outside-toplevel + from hailtop.config import get_deploy_config # pylint: disable=import-outside-toplevel deploy_config = get_deploy_config() tokens = get_tokens() @@ -58,7 +59,7 @@ def list(): @app.command() def user(): - '''Get Hail user information.''' + """Get Hail user information.""" from hailtop.auth import get_userinfo # pylint: disable=import-outside-toplevel userinfo = get_userinfo() @@ -86,12 +87,16 @@ def create_user( hail_credentials_secret_name: Optional[str] = None, wait: bool = False, ): - ''' + """ Create a new Hail user with username USERNAME and login ID LOGIN_ID. - ''' + """ from .create_user import polling_create_user # pylint: disable=import-outside-toplevel - asyncio.run(polling_create_user(username, login_id, developer, service_account, hail_identity, hail_credentials_secret_name, wait=wait)) + asyncio.run( + polling_create_user( + username, login_id, developer, service_account, hail_identity, hail_credentials_secret_name, wait=wait + ) + ) @app.command() @@ -99,9 +104,9 @@ def delete_user( username: str, wait: bool = False, ): - ''' + """ Delete the Hail user with username USERNAME. 
- ''' + """ from .delete_user import polling_delete_user # pylint: disable=import-outside-toplevel asyncio.run(polling_delete_user(username, wait)) diff --git a/hail/python/hailtop/hailctl/auth/create_user.py b/hail/python/hailtop/hailctl/auth/create_user.py index 2f622a14cca..2768e68b1a6 100644 --- a/hail/python/hailtop/hailctl/auth/create_user.py +++ b/hail/python/hailtop/hailctl/auth/create_user.py @@ -1,7 +1,7 @@ from typing import Optional -from hailtop.utils import sleep_before_try from hailtop.auth import async_create_user, async_get_user +from hailtop.utils import sleep_before_try class CreateUserException(Exception): @@ -19,7 +19,9 @@ async def polling_create_user( wait: bool = False, ): try: - await async_create_user(username, login_id, developer, service_account, hail_identity, hail_credentials_secret_name) + await async_create_user( + username, login_id, developer, service_account, hail_identity, hail_credentials_secret_name + ) if not wait: return @@ -33,7 +35,7 @@ async def _poll(): return assert user['state'] == 'creating' tries += 1 - await sleep_before_try(tries, base_delay_ms = 5_000) + await sleep_before_try(tries, base_delay_ms=5_000) await _poll() except Exception as e: diff --git a/hail/python/hailtop/hailctl/auth/delete_user.py b/hail/python/hailtop/hailctl/auth/delete_user.py index f82d20045ef..e85afb5b1d1 100644 --- a/hail/python/hailtop/hailctl/auth/delete_user.py +++ b/hail/python/hailtop/hailctl/auth/delete_user.py @@ -1,5 +1,5 @@ -from hailtop.utils import sleep_before_try from hailtop.auth import async_delete_user, async_get_user +from hailtop.utils import sleep_before_try class DeleteUserException(Exception): @@ -25,7 +25,7 @@ async def _poll(): return assert user['state'] == 'deleting' tries += 1 - await sleep_before_try(tries, base_delay_ms = 5_000) + await sleep_before_try(tries, base_delay_ms=5_000) await _poll() except Exception as e: diff --git a/hail/python/hailtop/hailctl/auth/login.py b/hail/python/hailtop/hailctl/auth/login.py index 70dad3faf93..96705d555f5 100644 --- a/hail/python/hailtop/hailctl/auth/login.py +++ b/hail/python/hailtop/hailctl/auth/login.py @@ -1,9 +1,9 @@ -import os import json +import os -from hailtop.config import get_deploy_config, DeployConfig, get_user_identity_config_path, get_hail_config_path -from hailtop.auth import hail_credentials, IdentityProvider, AzureFlow, GoogleFlow -from hailtop.httpx import client_session, ClientSession +from hailtop.auth import AzureFlow, GoogleFlow, IdentityProvider, hail_credentials +from hailtop.config import DeployConfig, get_deploy_config, get_hail_config_path, get_user_identity_config_path +from hailtop.httpx import ClientSession, client_session async def auth_flow(deploy_config: DeployConfig, session: ClientSession): diff --git a/hail/python/hailtop/hailctl/batch/batch_cli_utils.py b/hail/python/hailtop/hailctl/batch/batch_cli_utils.py index d4d286d1420..0bcaa00f9ec 100644 --- a/hail/python/hailtop/hailctl/batch/batch_cli_utils.py +++ b/hail/python/hailtop/hailctl/batch/batch_cli_utils.py @@ -1,11 +1,12 @@ -import json -from enum import Enum -import yaml import csv -from typing import Any, List, Dict, Callable, Annotated as Ann -import tabulate import io +import json +from enum import Enum +from typing import Annotated as Ann +from typing import Any, Callable, Dict, List +import tabulate +import yaml from typer import Option as Opt TableData = List[Dict[str, Any]] diff --git a/hail/python/hailtop/hailctl/batch/billing/cli.py b/hail/python/hailtop/hailctl/batch/billing/cli.py index 
7e9b198ccad..532ee4f3677 100644 --- a/hail/python/hailtop/hailctl/batch/billing/cli.py +++ b/hail/python/hailtop/hailctl/batch/billing/cli.py @@ -1,7 +1,6 @@ import typer -from ..batch_cli_utils import make_formatter, StructuredFormat, StructuredFormatOption - +from ..batch_cli_utils import StructuredFormat, StructuredFormatOption, make_formatter app = typer.Typer( name='billing', @@ -13,7 +12,7 @@ @app.command() def get(billing_project: str, output: StructuredFormatOption = StructuredFormat.YAML): - '''Get the billing information for BILLING_PROJECT.''' + """Get the billing information for BILLING_PROJECT.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -23,7 +22,7 @@ def get(billing_project: str, output: StructuredFormatOption = StructuredFormat. @app.command() def list(output: StructuredFormatOption = StructuredFormat.YAML): - '''List billing projects.''' + """List billing projects.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: diff --git a/hail/python/hailtop/hailctl/batch/cli.py b/hail/python/hailtop/hailctl/batch/cli.py index a857680da7f..c10650f02f1 100644 --- a/hail/python/hailtop/hailctl/batch/cli.py +++ b/hail/python/hailtop/hailctl/batch/cli.py @@ -1,27 +1,27 @@ import asyncio -from enum import Enum -import typer -from typer import Option as Opt, Argument as Arg import json +from enum import Enum +from typing import Annotated as Ann +from typing import Any, Dict, List, Optional, cast -from typing import Optional, List, Annotated as Ann, cast, Dict, Any +import typer +from typer import Argument as Arg +from typer import Option as Opt -from . import list_batches -from . import billing -from .initialize import async_basic_initialize +from . import billing, list_batches from . 
import submit as _submit from .batch_cli_utils import ( - get_batch_if_exists, - get_job_if_exists, - make_formatter, + ExtendedOutputFormat, + ExtendedOutputFormatOption, StructuredFormat, StructuredFormatOption, StructuredFormatPlusText, StructuredFormatPlusTextOption, - ExtendedOutputFormat, - ExtendedOutputFormatOption, + get_batch_if_exists, + get_job_if_exists, + make_formatter, ) - +from .initialize import async_basic_initialize app = typer.Typer( name='batch', @@ -40,13 +40,13 @@ def list( full: bool = False, output: ExtendedOutputFormatOption = ExtendedOutputFormat.GRID, ): - '''List batches.''' + """List batches.""" list_batches.list(query, limit, before, full, output) @app.command() def get(batch_id: int, output: StructuredFormatOption = StructuredFormat.YAML): - '''Get information on the batch with id BATCH_ID.''' + """Get information on the batch with id BATCH_ID.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -59,7 +59,7 @@ def get(batch_id: int, output: StructuredFormatOption = StructuredFormat.YAML): @app.command() def cancel(batch_id: int): - '''Cancel the batch with id BATCH_ID.''' + """Cancel the batch with id BATCH_ID.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -73,7 +73,7 @@ def cancel(batch_id: int): @app.command() def delete(batch_id: int): - '''Delete the batch with id BATCH_ID.''' + """Delete the batch with id BATCH_ID.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -98,7 +98,7 @@ def log( container: Ann[Optional[JobContainer], Opt(help='Container name of the desired job')] = None, output: StructuredFormatOption = StructuredFormat.YAML, ): - '''Get the log for the job with id JOB_ID in the batch with id BATCH_ID.''' + """Get the log for the job with id JOB_ID in the batch with id BATCH_ID.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -119,7 +119,7 @@ def wait( quiet: Ann[bool, Opt('--quiet', '-q', help='Do not print a progress bar for the batch.')] = False, output: StructuredFormatPlusTextOption = StructuredFormatPlusText.TEXT, ): - '''Wait for the batch with id BATCH_ID to complete, then print status.''' + """Wait for the batch with id BATCH_ID to complete, then print status.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -138,7 +138,7 @@ def wait( @app.command() def job(batch_id: int, job_id: int, output: StructuredFormatOption = StructuredFormat.YAML): - '''Get the status and specification for the job with id JOB_ID in the batch with id BATCH_ID.''' + """Get the status and specification for the job with id JOB_ID in the batch with id BATCH_ID.""" from hailtop.batch_client.client import BatchClient # pylint: disable=import-outside-toplevel with BatchClient('') as client: @@ -146,12 +146,11 @@ def job(batch_id: int, job_id: int, output: StructuredFormatOption = StructuredF if job is not None: assert job._status - print(make_formatter(output)([ - cast( - Dict[str, Any], # https://stackoverflow.com/q/71986632/6823256 - job._status + print( + make_formatter(output)( + [cast(Dict[str, Any], job._status)] # https://stackoverflow.com/q/71986632/6823256 ) - ])) + ) else: print(f"Job with ID {job_id} on batch {batch_id} not found") @@ -160,7 +159,9 
@@ def job(batch_id: int, job_id: int, output: StructuredFormatOption = StructuredF def submit( ctx: typer.Context, script: str, - arguments: Ann[Optional[List[str]], Arg(help='You should use -- if you want to pass option-like arguments through.')] = None, + arguments: Ann[ + Optional[List[str]], Arg(help='You should use -- if you want to pass option-like arguments through.') + ] = None, files: Ann[ Optional[List[str]], Opt(help='Files or directories to add to the working directory of the job.') ] = None, @@ -168,19 +169,17 @@ def submit( image_name: Ann[Optional[str], Opt(help='Name of Docker image for the job (default: hailgenetics/hail)')] = None, output: StructuredFormatPlusTextOption = StructuredFormatPlusText.TEXT, ): - '''Submit a batch with a single job that runs SCRIPT with the arguments ARGUMENTS. + """Submit a batch with a single job that runs SCRIPT with the arguments ARGUMENTS. If you wish to pass option-like arguments you should use "--". For example: $ hailctl batch submit --image-name docker.io/image my_script.py -- some-argument --animal dog - ''' + """ asyncio.run(_submit.submit(name, image_name, files or [], output, script, [*(arguments or []), *ctx.args])) @app.command('init', help='Initialize a Hail Batch environment.') -def initialize( - verbose: Ann[bool, Opt('--verbose', '-v', help='Print gcloud commands being executed')] = False -): +def initialize(verbose: Ann[bool, Opt('--verbose', '-v', help='Print gcloud commands being executed')] = False): asyncio.run(async_basic_initialize(verbose=verbose)) diff --git a/hail/python/hailtop/hailctl/batch/initialize.py b/hail/python/hailtop/hailctl/batch/initialize.py index 42261c71370..55af8cac34c 100644 --- a/hail/python/hailtop/hailctl/batch/initialize.py +++ b/hail/python/hailtop/hailctl/batch/initialize.py @@ -1,16 +1,22 @@ from typing import List, Optional, Tuple -import typer -from typer import Abort, Exit +import typer from rich.prompt import Confirm, IntPrompt, Prompt +from typer import Abort, Exit from hailtop.config import ConfigVariable async def setup_existing_remote_tmpdir(service_account: str, verbose: bool) -> Tuple[Optional[str], str, bool]: - from hailtop.aiogoogle import GoogleStorageAsyncFS # pylint: disable=import-outside-toplevel + from hailtop.aiocloud.aiogoogle.client.storage_client import ( # pylint: disable=import-outside-toplevel + GoogleStorageAsyncFS, + ) - from .utils import InsufficientPermissions, get_gcp_bucket_information, grant_service_account_bucket_access_with_role # pylint: disable=import-outside-toplevel + from .utils import ( # pylint: disable=import-outside-toplevel + InsufficientPermissions, + get_gcp_bucket_information, + grant_service_account_bucket_access_with_role, + ) warnings = False @@ -27,42 +33,58 @@ async def setup_existing_remote_tmpdir(service_account: str, verbose: bool) -> T location = bucket_info['location'].lower() if bucket_info['locationType'] != 'region': - typer.secho(f'WARNING: remote temporary directory {remote_tmpdir} is multi-regional. Using this bucket with the Batch Service will incur additional network fees.', - fg=typer.colors.YELLOW) + typer.secho( + f'WARNING: remote temporary directory {remote_tmpdir} is multi-regional. Using this bucket with the Batch Service will incur additional network fees.', + fg=typer.colors.YELLOW, + ) warnings = True storage_class = bucket_info['storageClass'] if storage_class.upper() != 'STANDARD': - typer.secho(f'WARNING: remote temporary directory {remote_tmpdir} does not have storage class "STANDARD". 
Additional data recovery charges will occur when accessing data.', - fg=typer.colors.YELLOW) + typer.secho( + f'WARNING: remote temporary directory {remote_tmpdir} does not have storage class "STANDARD". Additional data recovery charges will occur when accessing data.', + fg=typer.colors.YELLOW, + ) warnings = True - give_access_to_remote_tmpdir = Confirm.ask(f'Do you want to give service account {service_account} read/write access to bucket {bucket}?') + give_access_to_remote_tmpdir = Confirm.ask( + f'Do you want to give service account {service_account} read/write access to bucket {bucket}?' + ) if give_access_to_remote_tmpdir: try: - await grant_service_account_bucket_access_with_role(bucket=bucket, service_account=service_account, role= 'roles/storage.objectAdmin', verbose=verbose) + await grant_service_account_bucket_access_with_role( + bucket=bucket, service_account=service_account, role='roles/storage.objectAdmin', verbose=verbose + ) except InsufficientPermissions as e: typer.secho(e.message, fg=typer.colors.RED) raise Abort() from e - typer.secho(f'Granted service account {service_account} read and write access to {bucket}.', fg=typer.colors.GREEN) + typer.secho( + f'Granted service account {service_account} read and write access to {bucket}.', fg=typer.colors.GREEN + ) else: - typer.secho(f'WARNING: Please verify service account {service_account} has the role "roles/storage.objectAdmin" or ' - f'both "roles/storage.objectViewer" and "roles/storage.objectCreator" roles for bucket {bucket}.', - fg=typer.colors.YELLOW) + typer.secho( + f'WARNING: Please verify service account {service_account} has the role "roles/storage.objectAdmin" or ' + f'both "roles/storage.objectViewer" and "roles/storage.objectCreator" roles for bucket {bucket}.', + fg=typer.colors.YELLOW, + ) warnings = True return (remote_tmpdir, location, warnings) -async def setup_new_remote_tmpdir(*, - supported_regions: List[str], - username: str, - service_account: str, - verbose: bool) -> Tuple[Optional[str], str, bool]: +async def setup_new_remote_tmpdir( + *, supported_regions: List[str], username: str, service_account: str, verbose: bool +) -> Tuple[Optional[str], str, bool]: from hailtop.utils import secret_alnum_string # pylint: disable=import-outside-toplevel - from .utils import BucketAlreadyExistsError, InsufficientPermissions, create_gcp_bucket, \ - get_gcp_default_project, grant_service_account_bucket_access_with_role, update_gcp_bucket # pylint: disable=import-outside-toplevel + from .utils import ( # pylint: disable=import-outside-toplevel + BucketAlreadyExistsError, + InsufficientPermissions, + create_gcp_bucket, + get_gcp_default_project, + grant_service_account_bucket_access_with_role, + update_gcp_bucket, + ) token = secret_alnum_string(5).lower() maybe_bucket_name = f'hail-batch-{username}-{token}' @@ -81,9 +103,14 @@ async def setup_new_remote_tmpdir(*, bucket_region = Prompt.ask(f'Which region does your data reside in? (Example: {default_compute_region})') if bucket_region not in supported_regions: - typer.secho(f'The region where your data lives ({bucket_region}) is not in one of the supported regions of the Batch Service ({supported_regions}). 
' - f'Creating a bucket in {bucket_region} will incur additional network fees when using the Batch Service.', fg=typer.colors.YELLOW) - continue_w_region_error = Confirm.ask(f'Do you wish to continue setting up the new bucket {bucket_name} in region {bucket_region}?') + typer.secho( + f'The region where your data lives ({bucket_region}) is not in one of the supported regions of the Batch Service ({supported_regions}). ' + f'Creating a bucket in {bucket_region} will incur additional network fees when using the Batch Service.', + fg=typer.colors.YELLOW, + ) + continue_w_region_error = Confirm.ask( + f'Do you wish to continue setting up the new bucket {bucket_name} in region {bucket_region}?' + ) if not continue_w_region_error: raise Abort() @@ -91,10 +118,12 @@ async def setup_new_remote_tmpdir(*, warnings = False set_lifecycle = Confirm.ask( - f'Do you want to set a lifecycle policy (automatically delete files after a time period) on the bucket {bucket_name}?') + f'Do you want to set a lifecycle policy (automatically delete files after a time period) on the bucket {bucket_name}?' + ) if set_lifecycle: lifecycle_days = IntPrompt.ask( - f'After how many days should files be automatically deleted from bucket {bucket_name}?', default=30) + f'After how many days should files be automatically deleted from bucket {bucket_name}?', default=30 + ) if lifecycle_days <= 0: typer.secho(f'Invalid value for lifecycle rule in days {lifecycle_days}', fg=typer.colors.RED) raise Abort() @@ -120,10 +149,15 @@ async def setup_new_remote_tmpdir(*, raise Abort() from e except BucketAlreadyExistsError as e: typer.secho(e.message, fg=typer.colors.YELLOW) - continue_w_update = Confirm.ask(f'Do you wish to continue updating the lifecycle rules and permissions on bucket {bucket_name}?') + continue_w_update = Confirm.ask( + f'Do you wish to continue updating the lifecycle rules and permissions on bucket {bucket_name}?' + ) if not continue_w_update: - typer.secho(f'WARNING: The lifecycle rules and permissions on bucket {bucket_name} were not updated. ' - 'You will have to manually configure these yourself.', fg=typer.colors.YELLOW) + typer.secho( + f'WARNING: The lifecycle rules and permissions on bucket {bucket_name} were not updated. 
' + 'You will have to manually configure these yourself.', + fg=typer.colors.YELLOW, + ) warnings = True return (remote_tmpdir, bucket_region, warnings) @@ -139,33 +173,42 @@ async def setup_new_remote_tmpdir(*, typer.secho(e.message, fg=typer.colors.RED) raise Abort() from e - typer.secho(f'Updated bucket {bucket_name} in project {project} with lifecycle rule set to {lifecycle_days} days and labels {labels}.', fg=typer.colors.GREEN) + typer.secho( + f'Updated bucket {bucket_name} in project {project} with lifecycle rule set to {lifecycle_days} days and labels {labels}.', + fg=typer.colors.GREEN, + ) try: - await grant_service_account_bucket_access_with_role(bucket_name, service_account, 'roles/storage.objectAdmin', verbose=verbose) + await grant_service_account_bucket_access_with_role( + bucket_name, service_account, 'roles/storage.objectAdmin', verbose=verbose + ) except InsufficientPermissions as e: typer.secho(e.message, fg=typer.colors.RED) raise Abort() from e - typer.secho(f'Granted service account {service_account} read and write access to {bucket_name} in project {project}.', - fg=typer.colors.GREEN) + typer.secho( + f'Granted service account {service_account} read and write access to {bucket_name} in project {project}.', + fg=typer.colors.GREEN, + ) return (remote_tmpdir, bucket_region, warnings) -async def initialize_gcp(username: str, - hail_identity: str, - supported_regions: List[str], - verbose: bool) -> Tuple[Optional[str], str, bool]: +async def initialize_gcp( + username: str, hail_identity: str, supported_regions: List[str], verbose: bool +) -> Tuple[Optional[str], str, bool]: from .utils import check_for_gcloud # pylint: disable=import-outside-toplevel + assert len(supported_regions) > 0 gcloud_installed = await check_for_gcloud() if not gcloud_installed: - typer.secho('Have you installed gcloud? For directions see https://cloud.google.com/sdk/docs/install ' - 'To log into gcloud run:\n' - '> gcloud auth application-default login', - fg=typer.colors.RED) + typer.secho( + 'Have you installed gcloud? For directions see https://cloud.google.com/sdk/docs/install ' + 'To log into gcloud run:\n' + '> gcloud auth application-default login', + fg=typer.colors.RED, + ) raise Abort() create_remote_tmpdir = Confirm.ask('Do you want to create a new bucket for temporary files generated by Hail?') @@ -186,13 +229,20 @@ async def async_basic_initialize(verbose: bool = False): from hailtop.auth import async_get_userinfo # pylint: disable=import-outside-toplevel from hailtop.batch_client.aioclient import BatchClient # pylint: disable=import-outside-toplevel from hailtop.config.deploy_config import get_deploy_config # pylint: disable=import-outside-toplevel - from hailtop.hailctl.config.cli import set as set_config, list as list_config # pylint: disable=import-outside-toplevel + from hailtop.hailctl.config.cli import ( # pylint: disable=import-outside-toplevel + list as list_config, + ) + from hailtop.hailctl.config.cli import ( # pylint: disable=import-outside-toplevel + set as set_config, + ) from .utils import already_logged_into_service, login_to_service # pylint: disable=import-outside-toplevel already_logged_in = await already_logged_into_service() if not already_logged_in: - typer.secho('You are not currently logged in to Hail. Redirecting you to a login screen.', fg=typer.colors.YELLOW) + typer.secho( + 'You are not currently logged in to Hail. 
Redirecting you to a login screen.', fg=typer.colors.YELLOW + ) await login_to_service() typer.secho('In the future, you can use `hailctl auth login` to login to Hail.', fg=typer.colors.YELLOW) else: @@ -212,10 +262,17 @@ async def async_basic_initialize(verbose: bool = False): supported_regions = await batch_client.supported_regions() if cloud == 'gcp': - remote_tmpdir, tmpdir_region, warnings = await initialize_gcp(username, hail_identity, supported_regions, verbose) + remote_tmpdir, tmpdir_region, warnings = await initialize_gcp( + username, hail_identity, supported_regions, verbose + ) else: - remote_tmpdir = Prompt.ask('Enter a path to an existing remote temporary directory (ex: https://myaccount.blob.core.windows.net/mycontainer/batch/tmp)') - typer.secho(f'WARNING: You will need to grant read/write access to {remote_tmpdir} for account {hail_identity}', fg=typer.colors.YELLOW) + remote_tmpdir = Prompt.ask( + 'Enter a path to an existing remote temporary directory (ex: https://myaccount.blob.core.windows.net/mycontainer/batch/tmp)' + ) + typer.secho( + f'WARNING: You will need to grant read/write access to {remote_tmpdir} for account {hail_identity}', + fg=typer.colors.YELLOW, + ) warnings = False tmpdir_region = Prompt.ask('Which region is your remote temporary directory in? (Example: eastus)') @@ -223,8 +280,10 @@ async def async_basic_initialize(verbose: bool = False): compute_region = Prompt.ask('Which region do you want your jobs to run in?', choices=supported_regions) if tmpdir_region != compute_region and not compute_region.startswith(tmpdir_region): - typer.secho(f'WARNING: remote temporary directory "{remote_tmpdir}" is not located in the selected compute region for Batch jobs "{compute_region}". Found {tmpdir_region}.', - fg=typer.colors.YELLOW) + typer.secho( + f'WARNING: remote temporary directory "{remote_tmpdir}" is not located in the selected compute region for Batch jobs "{compute_region}". 
Found {tmpdir_region}.', + fg=typer.colors.YELLOW, + ) warnings = True if trial_bp_name: @@ -246,6 +305,8 @@ async def async_basic_initialize(verbose: bool = False): list_config() if warnings: - typer.secho('WARNING: The currently specified configuration will result in additional fees when using Hail Batch.', - fg=typer.colors.YELLOW) + typer.secho( + 'WARNING: The currently specified configuration will result in additional fees when using Hail Batch.', + fg=typer.colors.YELLOW, + ) raise Exit() diff --git a/hail/python/hailtop/hailctl/batch/submit.py b/hail/python/hailtop/hailctl/batch/submit.py index e6e7c355e92..21547f8e0b5 100644 --- a/hail/python/hailtop/hailctl/batch/submit.py +++ b/hail/python/hailtop/hailctl/batch/submit.py @@ -1,53 +1,99 @@ -import orjson import os +import re from shlex import quote as shq +from typing import Tuple + +import orjson + from hailtop import pip_version +FILE_REGEX = re.compile(r'(?P[^:]+)(:(?P.+))?') + async def submit(name, image_name, files, output, script, arguments): import hailtop.batch as hb # pylint: disable=import-outside-toplevel from hailtop.aiotools.copy import copy_from_dict # pylint: disable=import-outside-toplevel - from hailtop.config import get_remote_tmpdir, get_user_config_path, get_deploy_config # pylint: disable=import-outside-toplevel - from hailtop.utils import secret_alnum_string, unpack_comma_delimited_inputs # pylint: disable=import-outside-toplevel + from hailtop.config import ( # pylint: disable=import-outside-toplevel + get_deploy_config, + get_remote_tmpdir, + get_user_config_path, + ) + from hailtop.utils import ( # pylint: disable=import-outside-toplevel + secret_alnum_string, + unpack_comma_delimited_inputs, + ) files = unpack_comma_delimited_inputs(files) - user_config = get_user_config_path() + user_config = str(get_user_config_path()) + quiet = output != 'text' remote_tmpdir = get_remote_tmpdir('hailctl batch submit') + remote_tmpdir = remote_tmpdir.rstrip('/') + tmpdir_path_prefix = secret_alnum_string() def cloud_prefix(path): + path = path.lstrip('/') return f'{remote_tmpdir}/{tmpdir_path_prefix}/{path}' + def file_input_to_src_dest(file: str) -> Tuple[str, str, str]: + match = FILE_REGEX.match(file) + if match is None: + raise ValueError(f'invalid file specification {file}. Must have the form "src" or "src:dest"') + + result = match.groupdict() + + src = result.get('src') + if src is None: + raise ValueError(f'invalid file specification {file}. 
Must have a "src" defined.') + src = os.path.abspath(os.path.expanduser(src)) + src = src.rstrip('/') + + dest = result.get('dest') + if dest is not None: + dest = os.path.abspath(os.path.expanduser(dest)) + else: + dest = os.getcwd() + + cloud_file = cloud_prefix(src) + + return (src, dest, cloud_file) + backend = hb.ServiceBackend() b = hb.Batch(name=name, backend=backend) j = b.new_bash_job() j.image(image_name or os.environ.get('HAIL_GENETICS_HAIL_IMAGE', f'hailgenetics/hail:{pip_version()}')) - rel_file_paths = [os.path.relpath(file) for file in files] - local_files_to_cloud_files = [{'from': local, 'to': cloud_prefix(local)} for local in rel_file_paths] + local_files_to_cloud_files = [] + + for file in files: + src, dest, cloud_file = file_input_to_src_dest(file) + local_files_to_cloud_files.append({'from': src, 'to': cloud_file}) + in_file = b.read_input(cloud_file) + j.command(f'mkdir -p {os.path.dirname(dest)}; ln -s {in_file} {dest}') + + script_src, _, script_cloud_file = file_input_to_src_dest(script) + user_config_src, _, user_config_cloud_file = file_input_to_src_dest(user_config) + + await copy_from_dict(files=local_files_to_cloud_files) await copy_from_dict( files=[ - {'from': script, 'to': cloud_prefix(script)}, - {'from': str(user_config), 'to': cloud_prefix(user_config)}, - *local_files_to_cloud_files, + {'from': script_src, 'to': script_cloud_file}, + {'from': user_config_src, 'to': user_config_cloud_file}, ] ) - for file in local_files_to_cloud_files: - local_file = file['from'] - cloud_file = file['to'] - in_file = b.read_input(cloud_file) - j.command(f'ln -s {in_file} {local_file}') - script_file = b.read_input(cloud_prefix(script)) - config_file = b.read_input(cloud_prefix(user_config)) - j.command(f'mkdir -p $HOME/.config/hail && ln -s {config_file} $HOME/.config/hail/config.ini') + script_file = b.read_input(script_cloud_file) + config_file = b.read_input(user_config_cloud_file) j.env('HAIL_QUERY_BACKEND', 'batch') command = 'python3' if script.endswith('.py') else 'bash' script_arguments = " ".join(shq(x) for x in arguments) + + j.command(f'mkdir -p $HOME/.config/hail && ln -s {config_file} $HOME/.config/hail/config.ini') + j.command(f'cd {os.getcwd()}') j.command(f'{command} {script_file} {script_arguments}') batch_handle = await b._async_run(wait=False, disable_progress_bar=quiet) assert batch_handle diff --git a/hail/python/hailtop/hailctl/batch/utils.py b/hail/python/hailtop/hailctl/batch/utils.py index d5e0898dd9d..9e95d23a675 100644 --- a/hail/python/hailtop/hailctl/batch/utils.py +++ b/hail/python/hailtop/hailctl/batch/utils.py @@ -25,11 +25,13 @@ async def already_logged_into_service() -> bool: async def login_to_service(): from hailtop.hailctl.auth.login import async_login # pylint: disable=import-outside-toplevel + await async_login() async def check_for_gcloud() -> bool: from hailtop.utils import check_exec_output # pylint: disable=import-outside-toplevel + try: await check_exec_output('gcloud', 'version') return True @@ -39,6 +41,7 @@ async def check_for_gcloud() -> bool: async def get_gcp_default_project(verbose: bool) -> Optional[str]: from hailtop.utils import check_exec_output # pylint: disable=import-outside-toplevel + try: project, _ = await check_exec_output('gcloud', 'config', 'get-value', 'project', echo=verbose) project_str = project.strip().decode('utf-8') @@ -49,44 +52,56 @@ async def get_gcp_default_project(verbose: bool) -> Optional[str]: async def get_gcp_bucket_information(bucket: str, verbose: bool) -> dict: from hailtop.utils import 
CalledProcessError, check_exec_output # pylint: disable=import-outside-toplevel + try: - info, _ = await check_exec_output('gcloud', 'storage', 'buckets', 'describe', f'gs://{bucket}', '--format="json"', echo=verbose) + info, _ = await check_exec_output( + 'gcloud', 'storage', 'buckets', 'describe', f'gs://{bucket}', '--format="json"', echo=verbose + ) return json.loads(info.decode('utf-8')) except CalledProcessError as e: if 'does not have storage.buckets.get access to the Google Cloud Storage bucket' in e.stderr.decode('utf-8'): - msg = f'ERROR: You do not have sufficient permissions to get information about bucket {bucket} or it does not exist. ' \ - f'If the bucket exists, ask a project administrator to give you the permission "storage.buckets.get" or ' \ - f'assign you the StorageAdmin role in Google Cloud Storage.' + msg = ( + f'ERROR: You do not have sufficient permissions to get information about bucket {bucket} or it does not exist. ' + f'If the bucket exists, ask a project administrator to give you the permission "storage.buckets.get" or ' + f'assign you the StorageAdmin role in Google Cloud Storage.' + ) raise InsufficientPermissions(msg) from e raise -async def create_gcp_bucket(*, - project: str, - bucket: str, - location: str, - verbose: bool): +async def create_gcp_bucket(*, project: str, bucket: str, location: str, verbose: bool): from hailtop.utils import CalledProcessError, check_exec_output # pylint: disable=import-outside-toplevel try: - await check_exec_output('gcloud', '--project', project, 'storage', 'buckets', 'create', f'gs://{bucket}', f'--location={location}', echo=verbose) + await check_exec_output( + 'gcloud', + '--project', + project, + 'storage', + 'buckets', + 'create', + f'gs://{bucket}', + f'--location={location}', + echo=verbose, + ) except CalledProcessError as e: if 'does not have storage.buckets.create access to the Google Cloud project' in e.stderr.decode('utf-8'): - msg = f'ERROR: You do not have the necessary permissions to create buckets in project {project}. Ask a project administrator ' \ - f'to give you the permission "storage.buckets.create" or assign you the StorageAdmin role or ask them to create the bucket {bucket} on your behalf.' + msg = ( + f'ERROR: You do not have the necessary permissions to create buckets in project {project}. Ask a project administrator ' + f'to give you the permission "storage.buckets.create" or assign you the StorageAdmin role or ask them to create the bucket {bucket} on your behalf.' + ) raise InsufficientPermissions(msg) from e - if 'Your previous request to create the named bucket succeeded and you already own it' in e.stderr.decode('utf-8'): + if 'Your previous request to create the named bucket succeeded and you already own it' in e.stderr.decode( + 'utf-8' + ): msg = f'WARNING: Bucket {bucket} was previously created.' 
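+            # The gcloud call exits non-zero when the bucket already exists; that case is
+            # re-raised below as BucketAlreadyExistsError (with a WARNING-level message)
+            # so callers can tell it apart from a permissions problem.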
raise BucketAlreadyExistsError(msg) from e raise -async def update_gcp_bucket(*, - project: str, - bucket: str, - lifecycle_days: Optional[int], - labels: Optional[Dict[str, str]], - verbose: bool): +async def update_gcp_bucket( + *, project: str, bucket: str, lifecycle_days: Optional[int], labels: Optional[Dict[str, str]], verbose: bool +): from hailtop.utils import CalledProcessError, check_exec_output # pylint: disable=import-outside-toplevel if labels: @@ -96,28 +111,41 @@ async def update_gcp_bucket(*, try: if lifecycle_days: - lifecycle_policy = { - "rule": [ - { - "action": {"type": "Delete"}, - "condition": {"age": lifecycle_days} - } - ] - } + lifecycle_policy = {"rule": [{"action": {"type": "Delete"}, "condition": {"age": lifecycle_days}}]} with tempfile.NamedTemporaryFile(mode='w') as f: f.write(json.dumps(lifecycle_policy)) f.flush() - await check_exec_output('gcloud', '--project', project, 'storage', 'buckets', 'update', - f'--lifecycle-file={f.name}', f'gs://{bucket}', echo=verbose) + await check_exec_output( + 'gcloud', + '--project', + project, + 'storage', + 'buckets', + 'update', + f'--lifecycle-file={f.name}', + f'gs://{bucket}', + echo=verbose, + ) if labels_str: - await check_exec_output('gcloud', '--project', project, 'storage', 'buckets', 'update', - f'--update-labels={labels_str}', f'gs://{bucket}', echo=verbose) + await check_exec_output( + 'gcloud', + '--project', + project, + 'storage', + 'buckets', + 'update', + f'--update-labels={labels_str}', + f'gs://{bucket}', + echo=verbose, + ) except CalledProcessError as e: if 'does not have storage.buckets.get access to the Google Cloud Storage bucket' in e.stderr.decode('utf-8'): - msg = f'ERROR: You do not have the necessary permissions to update bucket {bucket} in project {project}. Ask a project administrator ' \ - f'to assign you the StorageAdmin role in Google Cloud Storage for bucket {bucket} or ask them to update the bucket {bucket} on your behalf.' + msg = ( + f'ERROR: You do not have the necessary permissions to update bucket {bucket} in project {project}. Ask a project administrator ' + f'to assign you the StorageAdmin role in Google Cloud Storage for bucket {bucket} or ask them to update the bucket {bucket} on your behalf.' + ) if lifecycle_days: msg += f'Update the bucket to have a lifecycle policy of {lifecycle_days} days.' if labels_str: @@ -131,12 +159,26 @@ async def grant_service_account_bucket_access_with_role(bucket: str, service_acc try: service_account_member = f'serviceAccount:{service_account}' - await check_exec_output('gcloud', 'storage', 'buckets', 'add-iam-policy-binding', f'gs://{bucket}', '--member', service_account_member, '--role', role, - echo=verbose) + await check_exec_output( + 'gcloud', + 'storage', + 'buckets', + 'add-iam-policy-binding', + f'gs://{bucket}', + '--member', + service_account_member, + '--role', + role, + echo=verbose, + ) except CalledProcessError as e: - if 'does not have storage.buckets.getIamPolicy access to the Google Cloud Storage bucket' in e.stderr.decode('utf-8'): - msg = f'ERROR: You do not have the necessary permissions to set permissions for bucket {bucket}. Ask a project administrator ' \ - f'to assign you the StorageIAMAdmin role in Google Cloud Storage or ask them to update the permissions on your behalf by giving ' \ - f'service account {service_account} the role "{role}" for bucket {bucket}.' 
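+        # As in the bucket create/update helpers above, a permissions failure from the
+        # gcloud IAM-policy update is surfaced as InsufficientPermissions with a message
+        # telling the user what to ask their project administrator for.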
+ if 'does not have storage.buckets.getIamPolicy access to the Google Cloud Storage bucket' in e.stderr.decode( + 'utf-8' + ): + msg = ( + f'ERROR: You do not have the necessary permissions to set permissions for bucket {bucket}. Ask a project administrator ' + f'to assign you the StorageIAMAdmin role in Google Cloud Storage or ask them to update the permissions on your behalf by giving ' + f'service account {service_account} the role "{role}" for bucket {bucket}.' + ) raise InsufficientPermissions(msg) from e raise diff --git a/hail/python/hailtop/hailctl/config/cli.py b/hail/python/hailtop/hailctl/config/cli.py index b55039f4f31..4c60eec3526 100644 --- a/hail/python/hailtop/hailctl/config/cli.py +++ b/hail/python/hailtop/hailctl/config/cli.py @@ -1,15 +1,15 @@ import os import sys - -from typing import Optional, Tuple, Annotated as Ann -from rich import print +from typing import Annotated as Ann +from typing import Optional, Tuple import typer +from rich import print from typer import Argument as Arg from hailtop.config.variables import ConfigVariable -from .config_variables import config_variables +from .config_variables import config_variables app = typer.Typer( name='config', @@ -26,7 +26,7 @@ def get_section_key_path(parameter: str) -> Tuple[str, str, Tuple[str, ...]]: if len(path) == 2: return path[0], path[1], tuple(path) print( - ''' + """ Parameters must contain at most one slash separating the configuration section from the configuration parameter, for example: "batch/billing_project". @@ -35,9 +35,7 @@ def get_section_key_path(parameter: str) -> Tuple[str, str, Tuple[str, ...]]: A parameter with more than one slash is invalid, for example: "batch/billing/project". -'''.lstrip( - '\n' - ), +""".lstrip('\n'), file=sys.stderr, ) sys.exit(1) @@ -50,8 +48,11 @@ def complete_config_variable(incomplete: str): @app.command() -def set(parameter: Ann[ConfigVariable, Arg(help="Configuration variable to set", autocompletion=complete_config_variable)], value: str): - '''Set a Hail configuration parameter.''' +def set( + parameter: Ann[ConfigVariable, Arg(help="Configuration variable to set", autocompletion=complete_config_variable)], + value: str, +): + """Set a Hail configuration parameter.""" from hailtop.config import get_user_config, get_user_config_path # pylint: disable=import-outside-toplevel if parameter not in config_variables(): @@ -61,7 +62,7 @@ def set(parameter: Ann[ConfigVariable, Arg(help="Configuration variable to set", section, key, _ = get_section_key_path(parameter.value) config_variable_info = config_variables()[parameter] - validation_func, error_msg = config_variable_info.validation + validation_func, error_msg = config_variable_info.validation if not validation_func(value): print(f"Error: bad value {value!r} for parameter {parameter!r} {error_msg}", file=sys.stderr) @@ -107,7 +108,7 @@ def get_config_variable(incomplete: str): @app.command() def unset(parameter: Ann[str, Arg(help="Configuration variable to unset", autocompletion=get_config_variable)]): - '''Unset a Hail configuration parameter (restore to default behavior).''' + """Unset a Hail configuration parameter (restore to default behavior).""" from hailtop.config import get_user_config, get_user_config_path # pylint: disable=import-outside-toplevel config = get_user_config() @@ -123,7 +124,7 @@ def unset(parameter: Ann[str, Arg(help="Configuration variable to unset", autoco @app.command() def get(parameter: Ann[str, Arg(help="Configuration variable to get", autocompletion=get_config_variable)]): - '''Get the 
value of a Hail configuration parameter.''' + """Get the value of a Hail configuration parameter.""" from hailtop.config import get_user_config # pylint: disable=import-outside-toplevel config = get_user_config() @@ -134,7 +135,7 @@ def get(parameter: Ann[str, Arg(help="Configuration variable to get", autocomple @app.command(name='config-location') def config_location(): - '''Print the location of the config file.''' + """Print the location of the config file.""" from hailtop.config import get_user_config_path # pylint: disable=import-outside-toplevel print(get_user_config_path()) @@ -142,7 +143,7 @@ def config_location(): @app.command() def list(section: Ann[Optional[str], Arg(show_default='all sections')] = None): - '''Lists every config variable in the section.''' + """Lists every config variable in the section.""" from hailtop.config import get_user_config # pylint: disable=import-outside-toplevel config = get_user_config() diff --git a/hail/python/hailtop/hailctl/config/config_variables.py b/hail/python/hailtop/hailctl/config/config_variables.py index 966a008a854..a6e67cab47e 100644 --- a/hail/python/hailtop/hailctl/config/config_variables.py +++ b/hail/python/hailtop/hailctl/config/config_variables.py @@ -1,17 +1,24 @@ -from collections import namedtuple import re +from collections import namedtuple from hailtop.config import ConfigVariable - _config_variables = None ConfigVariableInfo = namedtuple('ConfigVariableInfo', ['help_msg', 'validation']) +def _is_float_str(x: str) -> bool: + try: + float(x) + return True + except ValueError: + return False + + def config_variables(): - from hailtop.batch_client.parse import CPU_REGEXPAT, MEMORY_REGEXPAT # pylint: disable=import-outside-toplevel from hailtop.aiotools.router_fs import RouterAsyncFS # pylint: disable=import-outside-toplevel + from hailtop.batch_client.parse import CPU_REGEXPAT, MEMORY_REGEXPAT # pylint: disable=import-outside-toplevel global _config_variables @@ -29,7 +36,8 @@ def config_variables(): help_msg='Allowed buckets when using requester pays in GCS', validation=( lambda x: re.fullmatch(r'[^:/\s]+(,[^:/\s]+)*', x) is not None, - 'should be comma separated list of bucket names'), + 'should be comma separated list of bucket names', + ), ), ConfigVariable.GCS_BUCKET_ALLOW_LIST: ConfigVariableInfo( help_msg=( @@ -38,26 +46,36 @@ def config_variables(): validation=( # See https://cloud.google.com/storage/docs/buckets#naming for bucket naming requirements. lambda x: re.fullmatch(r'^[-\.\w]+(,[-\.\w]+)*$', x) is not None, - "should match the pattern 'bucket1,bucket2,bucket3'." 
+ "should match the pattern 'bucket1,bucket2,bucket3'.", ), ), ConfigVariable.BATCH_BUCKET: ConfigVariableInfo( help_msg='Deprecated - Name of GCS bucket to use as a temporary scratch directory', - validation=(lambda x: re.fullmatch(r'[^:/\s]+', x) is not None, - 'should be valid Google Bucket identifier, with no gs:// prefix'), + validation=( + lambda x: re.fullmatch(r'[^:/\s]+', x) is not None, + 'should be valid Google Bucket identifier, with no gs:// prefix', + ), ), ConfigVariable.BATCH_REMOTE_TMPDIR: ConfigVariableInfo( help_msg='Cloud storage URI to use as a temporary scratch directory', - validation=(RouterAsyncFS.valid_url, 'should be valid cloud storage URI such as gs://my-bucket/batch-tmp/'), + validation=( + RouterAsyncFS.valid_url, + 'should be valid cloud storage URI such as gs://my-bucket/batch-tmp/', + ), ), ConfigVariable.BATCH_REGIONS: ConfigVariableInfo( help_msg='Comma-separated list of regions to run jobs in', validation=( - lambda x: re.fullmatch(r'[^\s]+(,[^\s]+)*', x) is not None, 'should be comma separated list of regions'), + lambda x: re.fullmatch(r'[^\s]+(,[^\s]+)*', x) is not None, + 'should be comma separated list of regions', + ), ), ConfigVariable.BATCH_BILLING_PROJECT: ConfigVariableInfo( help_msg='Batch billing project', - validation=(lambda x: re.fullmatch(r'[^:/\s]+', x) is not None, 'should be valid Batch billing project name'), + validation=( + lambda x: re.fullmatch(r'[^:/\s]+', x) is not None, + 'should be valid Batch billing project name', + ), ), ConfigVariable.BATCH_BACKEND: ConfigVariableInfo( help_msg='Backend to use. One of local or service.', @@ -65,31 +83,45 @@ def config_variables(): ), ConfigVariable.QUERY_BACKEND: ConfigVariableInfo( help_msg='Backend to use for Hail Query. One of spark, local, batch.', - validation=(lambda x: x in ('local', 'spark', 'batch'), 'should be one of "local", "spark", or "batch"'), + validation=( + lambda x: x in ('local', 'spark', 'batch'), + 'should be one of "local", "spark", or "batch"', + ), ), ConfigVariable.QUERY_JAR_URL: ConfigVariableInfo( help_msg='Cloud storage URI to a Query JAR', - validation=(RouterAsyncFS.valid_url, 'should be valid cloud storage URI such as gs://my-bucket/jars/sha.jar') + validation=( + RouterAsyncFS.valid_url, + 'should be valid cloud storage URI such as gs://my-bucket/jars/sha.jar', + ), ), ConfigVariable.QUERY_BATCH_DRIVER_CORES: ConfigVariableInfo( help_msg='Cores specification for the query driver', - validation=(lambda x: re.fullmatch(CPU_REGEXPAT, x) is not None, - 'should be an integer which is a power of two from 1 to 16 inclusive'), + validation=( + lambda x: re.fullmatch(CPU_REGEXPAT, x) is not None, + 'should be an integer which is a power of two from 1 to 16 inclusive', + ), ), ConfigVariable.QUERY_BATCH_WORKER_CORES: ConfigVariableInfo( help_msg='Cores specification for the query worker', - validation=(lambda x: re.fullmatch(CPU_REGEXPAT, x) is not None, - 'should be an integer which is a power of two from 1 to 16 inclusive'), + validation=( + lambda x: re.fullmatch(CPU_REGEXPAT, x) is not None, + 'should be an integer which is a power of two from 1 to 16 inclusive', + ), ), ConfigVariable.QUERY_BATCH_DRIVER_MEMORY: ConfigVariableInfo( help_msg='Memory specification for the query driver', - validation=(lambda x: re.fullmatch(MEMORY_REGEXPAT, x) is not None or x in ('standard', 'lowmem', 'highmem'), - 'should be a valid string specifying memory "[+]?((?:[0-9]*[.])?[0-9]+)([KMGTP][i]?)?B?" 
or one of standard, lowmem, highmem'), + validation=( + lambda x: re.fullmatch(MEMORY_REGEXPAT, x) is not None or x in ('standard', 'lowmem', 'highmem'), + 'should be a valid string specifying memory "[+]?((?:[0-9]*[.])?[0-9]+)([KMGTP][i]?)?B?" or one of standard, lowmem, highmem', + ), ), ConfigVariable.QUERY_BATCH_WORKER_MEMORY: ConfigVariableInfo( help_msg='Memory specification for the query worker', - validation=(lambda x: re.fullmatch(MEMORY_REGEXPAT, x) is not None or x in ('standard', 'lowmem', 'highmem'), - 'should be a valid string specifying memory "[+]?((?:[0-9]*[.])?[0-9]+)([KMGTP][i]?)?B?" or one of standard, lowmem, highmem'), + validation=( + lambda x: re.fullmatch(MEMORY_REGEXPAT, x) is not None or x in ('standard', 'lowmem', 'highmem'), + 'should be a valid string specifying memory "[+]?((?:[0-9]*[.])?[0-9]+)([KMGTP][i]?)?B?" or one of standard, lowmem, highmem', + ), ), ConfigVariable.QUERY_NAME_PREFIX: ConfigVariableInfo( help_msg='Name used when displaying query progress in a progress bar', @@ -99,6 +131,10 @@ def config_variables(): help_msg='Disable the progress bar with a value of 1. Enable the progress bar with a value of 0', validation=(lambda x: x in ('0', '1'), 'should be a value of 0 or 1'), ), + ConfigVariable.HTTP_TIMEOUT_IN_SECONDS: ConfigVariableInfo( + help_msg='The default timeout for HTTP requests in seconds.', + validation=(_is_float_str, 'should be a float or an int like 42.42 or 42'), + ), } return _config_variables diff --git a/hail/python/hailtop/hailctl/dataproc/cli.py b/hail/python/hailtop/hailctl/dataproc/cli.py index 3ff55f6a99c..c97b2159e63 100644 --- a/hail/python/hailtop/hailctl/dataproc/cli.py +++ b/hail/python/hailtop/hailctl/dataproc/cli.py @@ -1,18 +1,20 @@ import sys +from typing import Annotated as Ann +from typing import List, Optional import typer -from typer import Option as Opt, Argument as Arg +from typer import Argument as Arg +from typer import Option as Opt -from typing import List, Optional, Annotated as Ann - -from .connect import connect as dataproc_connect, DataprocConnectService -from .submit import submit as dataproc_submit -from .diagnose import diagnose as dataproc_diagnose -from .modify import modify as dataproc_modify -from .start import start as dataproc_start, VepVersion from ..describe import describe from . import gcloud - +from .connect import DataprocConnectService +from .connect import connect as dataproc_connect +from .diagnose import diagnose as dataproc_diagnose +from .modify import modify as dataproc_modify +from .start import VepVersion +from .start import start as dataproc_start +from .submit import submit as dataproc_submit MINIMUM_REQUIRED_GCLOUD_VERSION = (285, 0, 0) @@ -190,9 +192,9 @@ def start( bool, Opt(help='Enable debug features on created cluster (heap dump on out-of-memory error)') ] = False, ): - ''' + """ Start a Dataproc cluster configured for Hail. - ''' + """ assert num_secondary_workers is not None assert num_workers is not None @@ -251,9 +253,9 @@ def stop( asink: Ann[bool, Opt('--async/--sync', help='Do not wait for cluster deletion')] = False, dry_run: DryRunOption = False, ): - ''' + """ Shut down a Dataproc cluster. - ''' + """ print("Stopping cluster '{}'...".format(name)) cmd = ['dataproc', 'clusters', 'delete', '--quiet', name] @@ -273,9 +275,9 @@ def stop( def list( ctx: typer.Context, ): - ''' + """ List active Dataproc clusters. 
- ''' + """ gcloud.run(['dataproc', 'clusters', 'list', *ctx.args]) @@ -289,10 +291,10 @@ def connect( zone: ZoneOption = None, dry_run: DryRunOption = False, ): - ''' + """ Connect to a running Dataproc cluster with name NAME and start the web service SERVICE. - ''' + """ dataproc_connect(name, service, project, port, zone, dry_run, pass_through_args or []) @@ -317,9 +319,11 @@ def submit( ] = None, dry_run: DryRunOption = False, region: Ann[Optional[str], Opt(help='Compute region for the cluster.')] = None, - arguments: Ann[Optional[List[str]], Arg(help='You should use -- if you want to pass option-like arguments through.')] = None, + arguments: Ann[ + Optional[List[str]], Arg(help='You should use -- if you want to pass option-like arguments through.') + ] = None, ): - '''Submit the Python script at path SCRIPT to a running Dataproc cluster with name NAME. + """Submit the Python script at path SCRIPT to a running Dataproc cluster with name NAME. You may pass arguments to the script being submitted by listing them after the script; however, if you wish to pass option-like arguments you should use "--". For example: @@ -328,8 +332,10 @@ def submit( $ hailctl dataproc submit name --image-name docker.io/image my_script.py -- some-argument --animal dog - ''' - dataproc_submit(name, script, files, pyfiles, properties, gcloud_configuration, dry_run, region, [*(arguments or []), *ctx.args]) + """ + dataproc_submit( + name, script, files, pyfiles, properties, gcloud_configuration, dry_run, region, [*(arguments or []), *ctx.args] + ) @app.command() @@ -343,9 +349,9 @@ def diagnose( workers: Ann[Optional[List[str]], Opt(help='Specific workers to get log files from.')] = None, take: Ann[Optional[int], Opt(help='Only download logs from the first N workers.')] = None, ): - ''' + """ Diagnose problems in a Dataproc cluster with name NAME. - ''' + """ dataproc_diagnose(name, dest, hail_log, overwrite, no_diagnose, compress, workers or [], take) @@ -395,9 +401,9 @@ def modify( ] = False, wheel: Ann[Optional[str], Opt(help='New Hail installation.')] = None, ): - ''' + """ Modify an active dataproc cluster with name NAME. - ''' + """ dataproc_modify( name, num_workers, diff --git a/hail/python/hailtop/hailctl/dataproc/connect.py b/hail/python/hailtop/hailctl/dataproc/connect.py index 0bec6db7097..fe193182027 100755 --- a/hail/python/hailtop/hailctl/dataproc/connect.py +++ b/hail/python/hailtop/hailctl/dataproc/connect.py @@ -1,12 +1,10 @@ -from enum import Enum import os import platform import shutil import subprocess import tempfile - -from typing import Optional, List - +from enum import Enum +from typing import List, Optional from . 
import gcloud diff --git a/hail/python/hailtop/hailctl/dataproc/diagnose.py b/hail/python/hailtop/hailctl/dataproc/diagnose.py index 1bbba77b6a4..ecb0c4582f5 100644 --- a/hail/python/hailtop/hailctl/dataproc/diagnose.py +++ b/hail/python/hailtop/hailctl/dataproc/diagnose.py @@ -1,8 +1,7 @@ -import re import json - +import re +from subprocess import PIPE, Popen, call from typing import List, Optional -from subprocess import call, Popen, PIPE def diagnose( diff --git a/hail/python/hailtop/hailctl/dataproc/gcloud.py b/hail/python/hailtop/hailctl/dataproc/gcloud.py index 1f314daa42f..311127004d5 100644 --- a/hail/python/hailtop/hailctl/dataproc/gcloud.py +++ b/hail/python/hailtop/hailctl/dataproc/gcloud.py @@ -1,7 +1,7 @@ -from typing import Tuple, List, Optional import json import subprocess import sys +from typing import List, Optional, Tuple def run(command: List[str]): diff --git a/hail/python/hailtop/hailctl/dataproc/modify.py b/hail/python/hailtop/hailctl/dataproc/modify.py index 351cda7d6ca..5921d63bbe5 100644 --- a/hail/python/hailtop/hailctl/dataproc/modify.py +++ b/hail/python/hailtop/hailctl/dataproc/modify.py @@ -1,6 +1,5 @@ import os.path import sys - from typing import List, Optional from . import gcloud @@ -80,41 +79,37 @@ def modify( wheelfile = os.path.basename(wheel) cmds = [] if wheel.startswith("gs://"): - cmds.append( + cmds.append([ + 'compute', + 'ssh', + '{}-m'.format(name), + '--zone={}'.format(zone), + '--', + f'sudo gsutil cp {wheel} /tmp/ && ' + 'sudo /opt/conda/default/bin/pip uninstall -y hail && ' + f'sudo /opt/conda/default/bin/pip install --no-dependencies /tmp/{wheelfile} && ' + f"unzip /tmp/{wheelfile} && " + "requirements_file=$(mktemp) && " + "grep 'Requires-Dist: ' hail*dist-info/METADATA | sed 's/Requires-Dist: //' | sed 's/ (//' | sed 's/)//' | grep -v 'pyspark' >$requirements_file &&" + "/opt/conda/default/bin/pip install -r $requirements_file", + ]) + else: + cmds.extend([ + ['compute', 'scp', '--zone={}'.format(zone), wheel, '{}-m:/tmp/'.format(name)], [ 'compute', 'ssh', - '{}-m'.format(name), - '--zone={}'.format(zone), + f'{name}-m', + f'--zone={zone}', '--', - f'sudo gsutil cp {wheel} /tmp/ && ' 'sudo /opt/conda/default/bin/pip uninstall -y hail && ' f'sudo /opt/conda/default/bin/pip install --no-dependencies /tmp/{wheelfile} && ' f"unzip /tmp/{wheelfile} && " "requirements_file=$(mktemp) && " "grep 'Requires-Dist: ' hail*dist-info/METADATA | sed 's/Requires-Dist: //' | sed 's/ (//' | sed 's/)//' | grep -v 'pyspark' >$requirements_file &&" "/opt/conda/default/bin/pip install -r $requirements_file", - ] - ) - else: - cmds.extend( - [ - ['compute', 'scp', '--zone={}'.format(zone), wheel, '{}-m:/tmp/'.format(name)], - [ - 'compute', - 'ssh', - f'{name}-m', - f'--zone={zone}', - '--', - 'sudo /opt/conda/default/bin/pip uninstall -y hail && ' - f'sudo /opt/conda/default/bin/pip install --no-dependencies /tmp/{wheelfile} && ' - f"unzip /tmp/{wheelfile} && " - "requirements_file=$(mktemp) && " - "grep 'Requires-Dist: ' hail*dist-info/METADATA | sed 's/Requires-Dist: //' | sed 's/ (//' | sed 's/)//' | grep -v 'pyspark' >$requirements_file &&" - "/opt/conda/default/bin/pip install -r $requirements_file", - ], - ] - ) + ], + ]) for cmd in cmds: print('gcloud ' + ' '.join(cmd)) diff --git a/hail/python/hailtop/hailctl/dataproc/resources/init_notebook.py b/hail/python/hailtop/hailctl/dataproc/resources/init_notebook.py index c0a0c1a31ea..f2cbd824466 100644 --- a/hail/python/hailtop/hailctl/dataproc/resources/init_notebook.py +++ 
b/hail/python/hailtop/hailctl/dataproc/resources/init_notebook.py @@ -1,9 +1,9 @@ #!/opt/conda/default/bin/python3 +import errno import json import os import subprocess as sp import sys -import errno from subprocess import check_output assert sys.version_info > (3, 0), sys.version_info diff --git a/hail/python/hailtop/hailctl/dataproc/resources/vep-GRCh37.sh b/hail/python/hailtop/hailctl/dataproc/resources/vep-GRCh37.sh index c46d66fb947..37a9b7b4e2c 100644 --- a/hail/python/hailtop/hailctl/dataproc/resources/vep-GRCh37.sh +++ b/hail/python/hailtop/hailctl/dataproc/resources/vep-GRCh37.sh @@ -31,7 +31,7 @@ sleep 60 sudo service docker restart # Get VEP cache and LOFTEE data -gcloud storage cp --billing-project $PROJECT gs://hail-us-vep/vep85-loftee-gcloud.json /vep_data/vep85-gcloud.json +gcloud storage cp --billing-project $PROJECT gs://hail-us-central1-vep/vep85-loftee-gcloud.json /vep_data/vep85-gcloud.json ln -s /vep_data/vep85-gcloud.json $VEP_CONFIG_PATH gcloud storage cat --billing-project $PROJECT gs://${VEP_BUCKET}/loftee-beta/${ASSEMBLY}.tar | tar -xf - -C /vep_data & diff --git a/hail/python/hailtop/hailctl/dataproc/resources/vep-GRCh38.sh b/hail/python/hailtop/hailctl/dataproc/resources/vep-GRCh38.sh index c6711157de9..e2ac8208db3 100644 --- a/hail/python/hailtop/hailctl/dataproc/resources/vep-GRCh38.sh +++ b/hail/python/hailtop/hailctl/dataproc/resources/vep-GRCh38.sh @@ -31,11 +31,11 @@ sleep 60 sudo service docker restart # Get VEP cache and LOFTEE data -gcloud storage cp --billing-project $PROJECT gs://hail-us-vep/vep95-GRCh38-loftee-gcloud.json /vep_data/vep95-GRCh38-gcloud.json +gcloud storage cp --billing-project $PROJECT gs://hail-us-central1-vep/vep95-GRCh38-loftee-gcloud.json /vep_data/vep95-GRCh38-gcloud.json ln -s /vep_data/vep95-GRCh38-gcloud.json $VEP_CONFIG_PATH gcloud storage cat --billing-project $PROJECT gs://${VEP_BUCKET}/loftee-beta/${ASSEMBLY}.tar | tar -xf - -C /vep_data/ & -gcloud storage cat --billing-project $PROJECT gs://${VEP_BUCKET}/homo-sapiens/95_${ASSEMBLY}.tar | tar -xf - -C /vep_data/homo_sapiens & +gcloud storage cat --billing-project $PROJECT gs://${VEP_BUCKET}/homo-sapiens/95_${ASSEMBLY}_indexed.tar | tar -xf - -C /vep_data/homo_sapiens & docker pull ${VEP_DOCKER_IMAGE} & wait diff --git a/hail/python/hailtop/hailctl/dataproc/start.py b/hail/python/hailtop/hailctl/dataproc/start.py index 5d10d958186..7a3c167188d 100755 --- a/hail/python/hailtop/hailctl/dataproc/start.py +++ b/hail/python/hailtop/hailctl/dataproc/start.py @@ -1,10 +1,10 @@ import re from enum import Enum +from shlex import quote as shq +from typing import List, Optional import yaml -from typing import Optional, List - from . 
import gcloud from .cluster_config import ClusterConfig @@ -126,25 +126,11 @@ class VepVersion(str, Enum): 'c2-standard-60': 240, } -REGION_TO_REPLICATE_MAPPING = { - 'us-central1': 'us', - 'us-east1': 'us', - 'us-east4': 'us', - 'us-west1': 'us', - 'us-west2': 'us', - 'us-west3': 'us', - # Europe != EU - 'europe-north1': 'eu', - 'europe-west1': 'eu', - 'europe-west2': 'uk', - 'europe-west3': 'eu', - 'europe-west4': 'eu', - 'australia-southeast1': 'aus-sydney', -} +VEP_SUPPORTED_REGIONS = {'us-central1', 'europe-west1', 'europe-west2', 'australia-southeast1'} -ANNOTATION_DB_BUCKETS = ["hail-datasets-us", "hail-datasets-eu"] +ANNOTATION_DB_BUCKETS = ["hail-datasets-us-central1", "hail-datasets-europe-west1"] -IMAGE_VERSION = '2.1.2-debian11' +IMAGE_VERSION = '2.1.33-debian11' def start( @@ -266,17 +252,16 @@ def start( # add VEP init script if vep: - # VEP is too expensive if you have to pay egress charges. We must choose the right replicate. - replicate = REGION_TO_REPLICATE_MAPPING.get(project_region) - if replicate is None: + if project_region not in VEP_SUPPORTED_REGIONS: + # VEP is too expensive if you have to pay egress charges. raise RuntimeError( f"The --vep argument is not currently provided in your region.\n" f" Please contact the Hail team on https://discuss.hail.is for support.\n" f" Your region: {project_region}\n" - f" Supported regions: {', '.join(REGION_TO_REPLICATE_MAPPING.keys())}" + f" Supported regions: {', '.join(VEP_SUPPORTED_REGIONS)}" ) - print(f"Pulling VEP data from bucket in {replicate}.") - conf.extend_flag('metadata', {"VEP_REPLICATE": replicate}) + print(f"Pulling VEP data from bucket in {project_region}.") + conf.extend_flag('metadata', {"VEP_REPLICATE": project_region}) vep_config_path = "/vep_data/vep-gcloud.json" conf.extend_flag( 'metadata', {"VEP_CONFIG_PATH": vep_config_path, "VEP_CONFIG_URI": f"file://{vep_config_path}"} @@ -317,9 +302,7 @@ def jvm_heap_size_gib(machine_type: str, memory_fraction: float) -> int: conf.extend_flag( 'properties', - { - "spark:spark.driver.memory": f"{jvm_heap_size_gib(master_machine_type, master_memory_fraction)}g" - }, + {"spark:spark.driver.memory": f"{jvm_heap_size_gib(master_machine_type, master_memory_fraction)}g"}, ) conf.flags['master-machine-type'] = master_machine_type conf.flags['master-boot-disk-size'] = '{}GB'.format(master_boot_disk_size) @@ -421,7 +404,13 @@ def jvm_heap_size_gib(machine_type: str, memory_fraction: float) -> int: cmd.extend(pass_through_args) # print underlying gcloud command - print(' '.join(cmd[:5]) + ' \\\n ' + ' \\\n '.join(cmd[5:])) + print( + ''.join([ + ' '.join(shq(x) for x in cmd[:5]), + ' \\\n ', + ' \\\n '.join(shq(x) for x in cmd[5:]), + ]) + ) # spin up cluster if not dry_run: diff --git a/hail/python/hailtop/hailctl/dataproc/submit.py b/hail/python/hailtop/hailctl/dataproc/submit.py index 43469961af3..56ac0006cf7 100644 --- a/hail/python/hailtop/hailctl/dataproc/submit.py +++ b/hail/python/hailtop/hailctl/dataproc/submit.py @@ -1,8 +1,7 @@ import os import tempfile import zipfile - -from typing import Optional, List +from typing import List, Optional from . 
import gcloud @@ -48,7 +47,9 @@ def _filter_pyfile(fname: str) -> bool: if os.path.isfile(path) and _filter_pyfile(path): zipf.write( os.path.join(root, pyfile), - os.path.relpath(os.path.join(root, pyfile), os.path.join(hail_script_entry, '..')), + os.path.relpath( + os.path.join(root, pyfile), os.path.join(hail_script_entry, '..') + ), ) pyfiles = tfile diff --git a/hail/python/hailtop/hailctl/describe.py b/hail/python/hailtop/hailctl/describe.py index f9cec56a086..c0257095bf9 100644 --- a/hail/python/hailtop/hailctl/describe.py +++ b/hail/python/hailtop/hailctl/describe.py @@ -1,11 +1,12 @@ import asyncio -import orjson -from typing import List, Optional, Union, Annotated as Ann -from os import path -from zlib import decompress, MAX_WBITS -from statistics import median, mean, stdev from collections import OrderedDict +from os import path +from statistics import mean, median, stdev +from typing import Annotated as Ann +from typing import List, Optional, Union +from zlib import MAX_WBITS, decompress +import orjson from typer import Option as Opt SECTION_SEPARATOR = '-' * 40 @@ -78,15 +79,13 @@ def get_partitions_info_str(j): 'Empty partitions': len([p for p in partitions if p == 0]), } if partitions_info['Partitions'] > 1: - partitions_info.update( - { - 'Min(rows/partition)': min(partitions), - 'Max(rows/partition)': max(partitions), - 'Median(rows/partition)': median(partitions), - 'Mean(rows/partition)': int(mean(partitions)), - 'StdDev(rows/partition)': int(stdev(partitions)), - } - ) + partitions_info.update({ + 'Min(rows/partition)': min(partitions), + 'Max(rows/partition)': max(partitions), + 'Median(rows/partition)': median(partitions), + 'Mean(rows/partition)': int(mean(partitions)), + 'StdDev(rows/partition)': int(stdev(partitions)), + }) return "\n{}".format(IDENT).join(['{}: {}'.format(k, v) for k, v in partitions_info.items()]) @@ -98,9 +97,9 @@ def describe( Opt('--requester-pays-project-id', '-u', help='Project to be billed for GCS requests.'), ] = None, ): - ''' + """ Describe the MatrixTable or Table at path FILE. 
- ''' + """ asyncio.run(async_describe(file, requester_pays_project_id)) diff --git a/hail/python/hailtop/hailctl/dev/ci_client.py b/hail/python/hailtop/hailctl/dev/ci_client.py index c51c6d3a646..4bb2f48112b 100644 --- a/hail/python/hailtop/hailctl/dev/ci_client.py +++ b/hail/python/hailtop/hailctl/dev/ci_client.py @@ -1,11 +1,11 @@ -import aiohttp import sys - from typing import Optional +import aiohttp + from hailtop import httpx -from hailtop.config import get_deploy_config from hailtop.auth import hail_credentials +from hailtop.config import get_deploy_config from hailtop.httpx import client_session @@ -19,9 +19,7 @@ def __init__(self, deploy_config=None): async def __aenter__(self): async with hail_credentials() as credentials: headers = await credentials.auth_headers() - self._session = client_session( - raise_for_status=False, timeout=aiohttp.ClientTimeout(total=60), headers=headers - ) # type: ignore + self._session = client_session(raise_for_status=False, timeout=aiohttp.ClientTimeout(total=60), headers=headers) # type: ignore return self async def __aexit__(self, exc_type, exc, tb): diff --git a/hail/python/hailtop/hailctl/dev/cli.py b/hail/python/hailtop/hailctl/dev/cli.py index 493dfb76afb..613a4626857 100644 --- a/hail/python/hailtop/hailctl/dev/cli.py +++ b/hail/python/hailtop/hailctl/dev/cli.py @@ -1,14 +1,13 @@ import asyncio -import typer import webbrowser +from typing import Annotated as Ann +from typing import List, Optional -from typing import List, Optional, Annotated as Ann +import typer from typer import Option as Opt - from . import config - app = typer.Typer( name='dev', no_args_is_help=True, @@ -42,13 +41,17 @@ def deploy( ] = None, open: Ann[bool, Opt('--open', '-o', help='Open the deploy batch page in a web browser.')] = False, ): - '''Deploy a branch.''' + """Deploy a branch.""" asyncio.run(_deploy(branch, steps, excluded_steps or [], extra_config or [], open)) async def _deploy(branch: str, steps: List[str], excluded_steps: List[str], extra_config: List[str], open: bool): from hailtop.config import get_deploy_config # pylint: disable=import-outside-toplevel - from hailtop.utils import unpack_comma_delimited_inputs, unpack_key_value_inputs # pylint: disable=import-outside-toplevel + from hailtop.utils import ( # pylint: disable=import-outside-toplevel + unpack_comma_delimited_inputs, + unpack_key_value_inputs, + ) + from .ci_client import CIClient # pylint: disable=import-outside-toplevel deploy_config = get_deploy_config() diff --git a/hail/python/hailtop/hailctl/dev/config.py b/hail/python/hailtop/hailctl/dev/config.py index d5e8964b195..fb16d30c02e 100644 --- a/hail/python/hailtop/hailctl/dev/config.py +++ b/hail/python/hailtop/hailctl/dev/config.py @@ -1,6 +1,7 @@ -from enum import Enum import os -import json +from enum import Enum + +import orjson import typer app = typer.Typer( @@ -18,23 +19,19 @@ class DevConfigProperty(str, Enum): @app.command() def set(property: DevConfigProperty, value: str): - '''Set dev config property PROPERTY to value VALUE.''' - from hailtop.config import get_deploy_config # pylint: disable=import-outside-toplevel - - deploy_config = get_deploy_config() - config = deploy_config.get_config() - - p = property - config[p] = value - + """Set dev config property PROPERTY to value VALUE.""" config_file = os.environ.get('HAIL_DEPLOY_CONFIG_FILE', os.path.expanduser('~/.hail/deploy-config.json')) - with open(config_file, 'w', encoding='utf-8') as f: - json.dump(config, f) + with open(config_file, 'r', encoding='utf-8') as old_config_f: 
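+        # Round-trip the deploy config JSON on disk: load it, change only the one
+        # requested property, and write it back, leaving any other keys untouched.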
+ config = orjson.loads(old_config_f.read()) + + config[property] = value + with open(config_file, 'w', encoding='utf-8') as new_config_f: + new_config_f.write(orjson.dumps(config).decode('utf-8')) @app.command() def list(): - '''List the settings in the dev config.''' + """List the settings in the dev config.""" from hailtop.config import get_deploy_config # pylint: disable=import-outside-toplevel deploy_config = get_deploy_config() diff --git a/hail/python/hailtop/hailctl/hdinsight/cli.py b/hail/python/hailtop/hailctl/hdinsight/cli.py index 76fe14ab95c..3ad7ae6f3e9 100644 --- a/hail/python/hailtop/hailctl/hdinsight/cli.py +++ b/hail/python/hailtop/hailctl/hdinsight/cli.py @@ -1,14 +1,15 @@ import subprocess - -from typing import Optional, List, Annotated as Ann +from typing import Annotated as Ann +from typing import List, Optional import typer -from typer import Option as Opt, Argument as Arg +from typer import Argument as Arg +from typer import Option as Opt -from .start import start as hdinsight_start, VepVersion +from .start import VepVersion +from .start import start as hdinsight_start from .submit import submit as hdinsight_submit - app = typer.Typer( name='hdinsight', no_args_is_help=True, @@ -75,9 +76,9 @@ def start( ), ] = None, ): - ''' + """ Start an HDInsight cluster configured for Hail. - ''' + """ from ... import pip_version # pylint: disable=import-outside-toplevel hail_version = pip_version() @@ -114,36 +115,32 @@ def stop( extra_hdinsight_delete_args: Optional[List[str]] = None, extra_storage_delete_args: Optional[List[str]] = None, ): - ''' + """ Stop an HDInsight cluster configured for Hail. - ''' + """ print(f"Stopping cluster '{name}'...") - subprocess.check_call( - [ - 'az', - 'hdinsight', - 'delete', - '--name', - name, - '--resource-group', - resource_group, - *(extra_hdinsight_delete_args or []), - ] - ) - subprocess.check_call( - [ - 'az', - 'storage', - 'container', - 'delete', - '--name', - name, - '--account-name', - storage_account, - *(extra_storage_delete_args or []), - ] - ) + subprocess.check_call([ + 'az', + 'hdinsight', + 'delete', + '--name', + name, + '--resource-group', + resource_group, + *(extra_hdinsight_delete_args or []), + ]) + subprocess.check_call([ + 'az', + 'storage', + 'container', + 'delete', + '--name', + name, + '--account-name', + storage_account, + *(extra_storage_delete_args or []), + ]) @app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) @@ -153,9 +150,11 @@ def submit( storage_account: Ann[str, Arg(help="Storage account in which the cluster's container exists.")], http_password: Ann[str, Arg(help='Web password for the cluster')], script: Ann[str, Arg(help='Path to script.')], - arguments: Ann[Optional[List[str]], Arg(help='You should use -- if you want to pass option-like arguments through.')] = None, + arguments: Ann[ + Optional[List[str]], Arg(help='You should use -- if you want to pass option-like arguments through.') + ] = None, ): - ''' + """ Submit a job to an HDInsight cluster configured for Hail. If you wish to pass option-like arguments you should use "--". 
For example: @@ -163,13 +162,13 @@ def submit( $ hailctl hdinsight submit name account password script.py --image-name docker.io/image my_script.py -- some-argument --animal dog - ''' + """ hdinsight_submit(name, storage_account, http_password, script, [*(arguments or []), *ctx.args]) @app.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}) def list(ctx: typer.Context): - ''' + """ List HDInsight clusters configured for Hail. - ''' + """ subprocess.check_call(['az', 'hdinsight', 'list', *ctx.args]) diff --git a/hail/python/hailtop/hailctl/hdinsight/start.py b/hail/python/hailtop/hailctl/hdinsight/start.py index dafa1fffb18..f2e2fa9ec83 100644 --- a/hail/python/hailtop/hailctl/hdinsight/start.py +++ b/hail/python/hailtop/hailctl/hdinsight/start.py @@ -1,12 +1,11 @@ -import re +import json import os -from enum import Enum +import re +import subprocess import sys import time -import json -import subprocess +from enum import Enum from shlex import quote as shq - from typing import List, Optional @@ -39,8 +38,9 @@ def start( ): import requests # pylint: disable=import-outside-toplevel import requests.auth # pylint: disable=import-outside-toplevel - from ...utils import secret_alnum_string # pylint: disable=import-outside-toplevel + from ... import pip_version # pylint: disable=import-outside-toplevel + from ...utils import secret_alnum_string # pylint: disable=import-outside-toplevel print(f'Starting the cluster {cluster_name}') @@ -191,20 +191,21 @@ def put_jupyter(command): timeout=60, ) - stop = json.dumps( - {"RequestInfo": {"context": "put services into STOPPED state"}, "Body": {"ServiceInfo": {"state": "INSTALLED"}}} - ) - start = json.dumps( - {"RequestInfo": {"context": "put services into STARTED state"}, "Body": {"ServiceInfo": {"state": "STARTED"}}} - ) + stop = json.dumps({ + "RequestInfo": {"context": "put services into STOPPED state"}, + "Body": {"ServiceInfo": {"state": "INSTALLED"}}, + }) + start = json.dumps({ + "RequestInfo": {"context": "put services into STARTED state"}, + "Body": {"ServiceInfo": {"state": "STARTED"}}, + }) print('Restarting Jupyter ...') put_jupyter(stop) time.sleep(10) put_jupyter(start) - print( - f'''Your cluster is ready. + print(f"""Your cluster is ready. Web username: admin Web password: {http_password} Jupyter URL: https://{cluster_name}.azurehdinsight.net/jupyter/tree @@ -214,5 +215,4 @@ def put_jupyter(command): SSH domain name: {cluster_name}-ssh.azurehdinsight.net Use the "Python3 (ipykernel)" kernel. 
-''' - ) +""") diff --git a/hail/python/hailtop/hailctl/hdinsight/submit.py b/hail/python/hailtop/hailctl/hdinsight/submit.py index 4b0544ae18a..c95a20bef44 100644 --- a/hail/python/hailtop/hailctl/hdinsight/submit.py +++ b/hail/python/hailtop/hailctl/hdinsight/submit.py @@ -1,6 +1,5 @@ import os import subprocess - from typing import List @@ -13,21 +12,20 @@ def submit( ): import requests # pylint: disable=import-outside-toplevel import requests.auth # pylint: disable=import-outside-toplevel + from ...utils import sync_sleep_before_try # pylint: disable=import-outside-toplevel print("Submitting to cluster '{}'...".format(name)) - subprocess.check_call( - [ - 'az', - 'storage', - 'copy', - '--source', - script, - '--destination', - f'https://{storage_account}.blob.core.windows.net/{name}/{os.path.basename(script)}', - ] - ) + subprocess.check_call([ + 'az', + 'storage', + 'copy', + '--source', + script, + '--destination', + f'https://{storage_account}.blob.core.windows.net/{name}/{os.path.basename(script)}', + ]) resp = requests.post( f'https://{name}.azurehdinsight.net/livy/batches', headers={'Content-Type': 'application/json', 'X-Requested-By': 'admin'}, @@ -62,4 +60,4 @@ def submit( ) break tries += 1 - sync_sleep_before_try(tries, base_delay_ms = 10) + sync_sleep_before_try(tries, base_delay_ms=10) diff --git a/hail/python/hailtop/httpx.py b/hail/python/hailtop/httpx.py index 83dbff787ba..022165d74b2 100644 --- a/hail/python/hailtop/httpx.py +++ b/hail/python/hailtop/httpx.py @@ -1,27 +1,25 @@ -from typing import Any, Tuple, Optional, Type, Union import asyncio from types import TracebackType -import orjson +from typing import Any, Optional, Tuple, Type, Union + import aiohttp import aiohttp.abc import aiohttp.typedefs +import orjson -from .tls import internal_client_ssl_context, external_client_ssl_context +from .config import ConfigVariable, configuration_of from .config.deploy_config import get_deploy_config class ClientResponseError(aiohttp.ClientResponseError): - def __init__(self, - request_info: aiohttp.RequestInfo, - history: Tuple[aiohttp.ClientResponse, ...], - body: str = "", - **kwargs): + def __init__( + self, request_info: aiohttp.RequestInfo, history: Tuple[aiohttp.ClientResponse, ...], body: str = "", **kwargs + ): super().__init__(request_info, history, **kwargs) self.body = body def __str__(self) -> str: - return (f"{self.status}, message={self.message!r}, " - f"url={self.request_info.real_url!r} body={self.body!r}") + return f"{self.status}, message={self.message!r}, " f"url={self.request_info.real_url!r} body={self.body!r}" def __repr__(self) -> str: args = f"{self.request_info!r}, {self.history!r}" @@ -85,46 +83,53 @@ async def __aexit__( class ClientSession: - def __init__(self, - *args, - raise_for_status: bool = True, - timeout: Union[aiohttp.ClientTimeout, float, None] = None, - **kwargs): - location = get_deploy_config().location() - if location == 'external': - tls = external_client_ssl_context() - elif location == 'k8s': - tls = internal_client_ssl_context() - else: - assert location in ('gce', 'azure') - # no encryption on the internal gateway - tls = external_client_ssl_context() - + def __init__( + self, + *args, + raise_for_status: bool = True, + timeout: Union[aiohttp.ClientTimeout, float, int, None] = None, + **kwargs, + ): + tls = get_deploy_config().client_ssl_context() assert 'connector' not in kwargs - if timeout is None: - timeout = aiohttp.ClientTimeout(total=20) + configuration_of_timeout = 
configuration_of(ConfigVariable.HTTP_TIMEOUT_IN_SECONDS, timeout, 20) + del timeout + if isinstance(configuration_of_timeout, str): + configuration_of_timeout = float(configuration_of_timeout) + if isinstance(configuration_of_timeout, (float, int)): + configuration_of_timeout = aiohttp.ClientTimeout(total=configuration_of_timeout) + assert isinstance(configuration_of_timeout, aiohttp.ClientTimeout) + + self.loop = asyncio.get_running_loop() self.raise_for_status = raise_for_status self.client_session = aiohttp.ClientSession( *args, - timeout=timeout, + timeout=configuration_of_timeout, raise_for_status=False, connector=aiohttp.TCPConnector(ssl=tls), - **kwargs + **kwargs, ) def request( self, method: str, url: aiohttp.typedefs.StrOrURL, **kwargs: Any ) -> aiohttp.client._RequestContextManager: + if self.loop != asyncio.get_running_loop(): + raise ValueError( + f'ClientSession must be created and used in same loop {self.loop} != {asyncio.get_running_loop()}.' + ) raise_for_status = kwargs.pop('raise_for_status', self.raise_for_status) + timeout = kwargs.get('timeout') + if timeout and isinstance(timeout, (float, int)): + kwargs['timeout'] = aiohttp.ClientTimeout(total=timeout) + async def request_and_raise_for_status(): json_data = kwargs.pop('json', None) if json_data is not None: if kwargs.get('data') is not None: - raise ValueError( - 'data and json parameters cannot be used at the same time') + raise ValueError('data and json parameters cannot be used at the same time') kwargs['data'] = aiohttp.BytesPayload( value=orjson.dumps(json_data), # https://github.com/ijl/orjson#serialize @@ -146,14 +151,13 @@ async def request_and_raise_for_status(): status=resp.status, message=resp.reason, headers=resp.headers, - body=body + body=body, ) return resp + return aiohttp.client._RequestContextManager(request_and_raise_for_status()) - def ws_connect( - self, *args, **kwargs - ) -> aiohttp.client._WSRequestContextManager: + def ws_connect(self, *args, **kwargs) -> aiohttp.client._WSRequestContextManager: return self.client_session.ws_connect(*args, **kwargs) def get( @@ -161,15 +165,11 @@ def get( ) -> aiohttp.client._RequestContextManager: return self.request('GET', url, allow_redirects=allow_redirects, **kwargs) - async def get_read_json( - self, *args, **kwargs - ) -> Any: + async def get_read_json(self, *args, **kwargs) -> Any: async with self.get(*args, **kwargs) as resp: return await resp.json() - async def get_read( - self, *args, **kwargs - ) -> bytes: + async def get_read(self, *args, **kwargs) -> bytes: async with self.get(*args, **kwargs) as resp: return await resp.read() @@ -188,15 +188,11 @@ def post( ) -> aiohttp.client._RequestContextManager: return self.request('POST', url, data=data, **kwargs) - async def post_read_json( - self, *args, **kwargs - ) -> Any: + async def post_read_json(self, *args, **kwargs) -> Any: async with self.post(*args, **kwargs) as resp: return await resp.json() - async def post_read( - self, *args, **kwargs - ) -> bytes: + async def post_read(self, *args, **kwargs) -> bytes: async with self.post(*args, **kwargs) as resp: return await resp.read() @@ -210,9 +206,7 @@ def patch( ) -> aiohttp.client._RequestContextManager: return self.request('PATCH', url, data=data, **kwargs) - def delete( - self, url: aiohttp.typedefs.StrOrURL, **kwargs: Any - ) -> aiohttp.client._RequestContextManager: + def delete(self, url: aiohttp.typedefs.StrOrURL, **kwargs: Any) -> aiohttp.client._RequestContextManager: return self.request('DELETE', url, **kwargs) async def close(self) -> 
None: diff --git a/hail/python/hailtop/humanizex.py b/hail/python/hailtop/humanizex.py index efeab738d21..f2c9e1ac5d0 100644 --- a/hail/python/hailtop/humanizex.py +++ b/hail/python/hailtop/humanizex.py @@ -1,6 +1,5 @@ from typing import Union - _MICROSECOND = 1 _MILLISECOND = 1000 * _MICROSECOND _SECOND = 1000 * _MILLISECOND @@ -34,15 +33,15 @@ def _fmt(s: Union[int, float], word: str) -> str: def naturaldelta(seconds: Union[int, float]) -> str: - return _naturaldelta(seconds, value_unit = _SECOND) + return _naturaldelta(seconds, value_unit=_SECOND) def naturaldelta_msec(milliseconds: Union[int, float]) -> str: - return _naturaldelta(milliseconds, value_unit = _MILLISECOND) + return _naturaldelta(milliseconds, value_unit=_MILLISECOND) def naturaldelta_usec(microseconds: Union[int, float]) -> str: - return _naturaldelta(microseconds, value_unit = _MICROSECOND) + return _naturaldelta(microseconds, value_unit=_MICROSECOND) def _naturaldelta(value: Union[int, float], value_unit: int) -> str: diff --git a/hail/python/hailtop/pinned-requirements.txt b/hail/python/hailtop/pinned-requirements.txt index 70adf3f3f6e..2ace504430a 100644 --- a/hail/python/hailtop/pinned-requirements.txt +++ b/hail/python/hailtop/pinned-requirements.txt @@ -6,17 +6,17 @@ # aiodns==2.0.0 # via -r hail/hail/python/hailtop/requirements.txt -aiohttp==3.9.1 +aiohttp==3.9.3 # via -r hail/hail/python/hailtop/requirements.txt aiosignal==1.3.1 # via aiohttp async-timeout==4.0.3 # via aiohttp -attrs==23.1.0 +attrs==23.2.0 # via aiohttp azure-common==1.1.28 # via azure-mgmt-storage -azure-core==1.29.5 +azure-core==1.30.1 # via # azure-identity # azure-mgmt-core @@ -30,16 +30,16 @@ azure-mgmt-storage==20.1.0 # via -r hail/hail/python/hailtop/requirements.txt azure-storage-blob==12.19.0 # via -r hail/hail/python/hailtop/requirements.txt -boto3==1.33.1 +boto3==1.34.55 # via -r hail/hail/python/hailtop/requirements.txt -botocore==1.33.1 +botocore==1.34.55 # via # -r hail/hail/python/hailtop/requirements.txt # boto3 # s3transfer -cachetools==5.3.2 +cachetools==5.3.3 # via google-auth -certifi==2023.11.17 +certifi==2024.2.2 # via # msrest # requests @@ -53,20 +53,20 @@ click==8.1.7 # via typer commonmark==0.9.1 # via rich -cryptography==41.0.7 +cryptography==42.0.5 # via # azure-identity # azure-storage-blob # msal # pyjwt -dill==0.3.7 +dill==0.3.8 # via -r hail/hail/python/hailtop/requirements.txt -frozenlist==1.4.0 +frozenlist==1.4.1 # via # -r hail/hail/python/hailtop/requirements.txt # aiohttp # aiosignal -google-auth==2.23.4 +google-auth==2.28.1 # via # -r hail/hail/python/hailtop/requirements.txt # google-auth-oauthlib @@ -90,28 +90,28 @@ jmespath==1.0.1 # botocore jproperties==2.1.1 # via -r hail/hail/python/hailtop/requirements.txt -msal==1.25.0 +msal==1.27.0 # via # azure-identity # msal-extensions -msal-extensions==1.0.0 +msal-extensions==1.1.0 # via azure-identity msrest==0.7.1 # via azure-mgmt-storage -multidict==6.0.4 +multidict==6.0.5 # via # aiohttp # yarl -nest-asyncio==1.5.8 +nest-asyncio==1.6.0 # via -r hail/hail/python/hailtop/requirements.txt oauthlib==3.2.2 # via requests-oauthlib orjson==3.9.10 # via -r hail/hail/python/hailtop/requirements.txt +packaging==23.2 + # via msal-extensions portalocker==2.8.2 # via msal-extensions -protobuf==3.20.2 - # via -r hail/hail/python/hailtop/requirements.txt pyasn1==0.5.1 # via # pyasn1-modules @@ -126,7 +126,7 @@ pygments==2.17.2 # via rich pyjwt[crypto]==2.8.0 # via msal -python-dateutil==2.8.2 +python-dateutil==2.9.0.post0 # via botocore python-json-logger==2.0.7 # via -r 
hail/hail/python/hailtop/requirements.txt @@ -146,7 +146,7 @@ rich==12.6.0 # via -r hail/hail/python/hailtop/requirements.txt rsa==4.9 # via google-auth -s3transfer==0.8.0 +s3transfer==0.10.0 # via boto3 six==1.16.0 # via @@ -160,7 +160,7 @@ tabulate==0.9.0 # via -r hail/hail/python/hailtop/requirements.txt typer==0.9.0 # via -r hail/hail/python/hailtop/requirements.txt -typing-extensions==4.8.0 +typing-extensions==4.10.0 # via # azure-core # azure-storage-blob @@ -172,5 +172,5 @@ urllib3==1.26.18 # requests uvloop==0.19.0 ; sys_platform != "win32" # via -r hail/hail/python/hailtop/requirements.txt -yarl==1.9.3 +yarl==1.9.4 # via aiohttp diff --git a/hail/python/hailtop/requirements.txt b/hail/python/hailtop/requirements.txt index a7d7f2f1899..3f3a06ed976 100644 --- a/hail/python/hailtop/requirements.txt +++ b/hail/python/hailtop/requirements.txt @@ -12,8 +12,8 @@ google-auth-oauthlib>=0.5.2,<1 humanize>=1.0.0,<2 janus>=0.6,<1.1 nest_asyncio>=1.5.8,<2 -orjson>=3.6.4,<4 -protobuf==3.20.2 +# <3.9.11: https://github.com/hail-is/hail/issues/14299 +orjson>=3.6.4,<3.9.11 rich>=12.6.0,<13 typer>=0.9.0,<1 python-json-logger>=2.0.2,<3 diff --git a/hail/python/hailtop/test_utils.py b/hail/python/hailtop/test_utils.py index 06e0aa3b634..c641dc26537 100644 --- a/hail/python/hailtop/test_utils.py +++ b/hail/python/hailtop/test_utils.py @@ -1,17 +1,12 @@ import os -import pytest +import pytest fails_in_azure = pytest.mark.xfail( - os.environ.get('HAIL_CLOUD') == 'azure', - reason="doesn't yet work on azure", - strict=True) + os.environ.get('HAIL_CLOUD') == 'azure', reason="doesn't yet work on azure", strict=True +) -skip_in_azure = pytest.mark.skipif( - os.environ.get('HAIL_CLOUD') == 'azure', - reason="not applicable to azure") +skip_in_azure = pytest.mark.skipif(os.environ.get('HAIL_CLOUD') == 'azure', reason="not applicable to azure") -run_if_azure = pytest.mark.skipif( - os.environ.get('HAIL_CLOUD') != 'azure', - reason="only applicable to azure") +run_if_azure = pytest.mark.skipif(os.environ.get('HAIL_CLOUD') != 'azure', reason="only applicable to azure") diff --git a/hail/python/hailtop/timex.py b/hail/python/hailtop/timex.py index ca205bce316..2482c9780ab 100644 --- a/hail/python/hailtop/timex.py +++ b/hail/python/hailtop/timex.py @@ -1,22 +1,21 @@ -from typing import Dict, Optional - -import re import datetime +import re +from typing import Dict, Optional rfc3339_re = re.compile( # https://www.rfc-editor.org/rfc/rfc3339#section-5.6 - '([0-9][0-9][0-9][0-9])' # YYYY + '([0-9][0-9][0-9][0-9])' # YYYY '-' - '([0-9][0-9])' # MM + '([0-9][0-9])' # MM '-' - '([0-9][0-9])' # DD - '[Tt ]' # see NOTE in link - '([0-9][0-9])' # HH + '([0-9][0-9])' # DD + '[Tt ]' # see NOTE in link + '([0-9][0-9])' # HH ':' - '([0-9][0-9])' # MM + '([0-9][0-9])' # MM ':' - '([0-9][0-9])' # SS - '(.[0-9][0-9]*)?' # optional fractional seconds + '([0-9][0-9])' # SS + '(.[0-9][0-9]*)?' 
# optional fractional seconds '([Zz]|[+-][0-9][0-9]:[0-9][0-9])' # offset / timezone ) _timezone_cache: Dict[str, datetime.timezone] = {} @@ -70,5 +69,5 @@ def parse_rfc3339(s: str) -> datetime.datetime: minute=int(minute), second=int(second), microsecond=microsecond, - tzinfo=tz + tzinfo=tz, ) diff --git a/hail/python/hailtop/tls.py b/hail/python/hailtop/tls.py index 056b69ad8f1..10817de98bf 100644 --- a/hail/python/hailtop/tls.py +++ b/hail/python/hailtop/tls.py @@ -1,9 +1,9 @@ -from typing import Dict -import logging import json +import logging import os import ssl from ssl import Purpose +from typing import Dict log = logging.getLogger('hailtop.tls') _server_ssl_context = None @@ -42,11 +42,9 @@ def internal_server_ssl_context() -> ssl.SSLContext: if _server_ssl_context is None: ssl_config = _get_ssl_config() _server_ssl_context = ssl.create_default_context( - purpose=Purpose.CLIENT_AUTH, - cafile=ssl_config['incoming_trust']) - _server_ssl_context.load_cert_chain(ssl_config['cert'], - keyfile=ssl_config['key'], - password=None) + purpose=Purpose.CLIENT_AUTH, cafile=ssl_config['incoming_trust'] + ) + _server_ssl_context.load_cert_chain(ssl_config['cert'], keyfile=ssl_config['key'], password=None) _server_ssl_context.verify_mode = ssl.CERT_OPTIONAL # FIXME: mTLS # _server_ssl_context.verify_mode = ssl.CERT_REQURIED @@ -59,15 +57,13 @@ def internal_client_ssl_context() -> ssl.SSLContext: if _client_ssl_context is None: ssl_config = _get_ssl_config() _client_ssl_context = ssl.create_default_context( - purpose=Purpose.SERVER_AUTH, - cafile=ssl_config['outgoing_trust']) + purpose=Purpose.SERVER_AUTH, cafile=ssl_config['outgoing_trust'] + ) # setting cafile in `create_default_context` ignores the system default # certificates. We must explicitly request them again with # load_default_certs. 
_client_ssl_context.load_default_certs() - _client_ssl_context.load_cert_chain(ssl_config['cert'], - keyfile=ssl_config['key'], - password=None) + _client_ssl_context.load_cert_chain(ssl_config['cert'], keyfile=ssl_config['key'], password=None) _client_ssl_context.verify_mode = ssl.CERT_REQUIRED _client_ssl_context.check_hostname = True return _client_ssl_context diff --git a/hail/python/hailtop/utils/__init__.py b/hail/python/hailtop/utils/__init__.py index 29a01a0cd55..55ad635f930 100644 --- a/hail/python/hailtop/utils/__init__.py +++ b/hail/python/hailtop/utils/__init__.py @@ -1,33 +1,79 @@ -from .time import ( - time_msecs, time_msecs_str, humanize_timedelta_msecs, parse_timestamp_msecs, - time_ns) -from .utils import (unzip, async_to_blocking, blocking_to_async, AsyncWorkerPool, bounded_gather, - grouped, sync_sleep_before_try, sleep_before_try, is_transient_error, - collect_aiter, retry_all_errors, retry_transient_errors, - retry_transient_errors_with_debug_string, retry_long_running, run_if_changed, - run_if_changed_idempotent, LoggingTimer, WaitableSharedPool, - RETRY_FUNCTION_SCRIPT, sync_retry_transient_errors, - retry_response_returning_functions, first_extant_file, secret_alnum_string, - flatten, filter_none, partition, cost_str, external_requests_client_session, - url_basename, url_join, parse_docker_image_reference, url_and_params, - url_scheme, Notice, periodically_call, dump_all_stacktraces, find_spark_home, - TransientError, bounded_gather2, OnlineBoundedGather2, - unpack_comma_delimited_inputs, unpack_key_value_inputs, - retry_all_errors_n_times, Timings, is_limited_retries_error, am_i_interactive, - is_delayed_warning_error, retry_transient_errors_with_delayed_warnings, - periodically_call_with_dynamic_sleep, delay_ms_for_try, ait_to_blocking) +from . import rich_progress_bar, serialization from .process import ( - CalledProcessError, check_shell, check_shell_output, check_exec_output, - sync_check_shell, sync_check_shell_output, sync_check_exec) -from .rates import ( - rate_cpu_hour_to_mcpu_msec, rate_gib_hour_to_mib_msec, rate_gib_month_to_mib_msec, - rate_instance_hour_to_fraction_msec + CalledProcessError, + check_exec_output, + check_shell, + check_shell_output, + sync_check_exec, + sync_check_shell, + sync_check_shell_output, ) from .rate_limiter import RateLimit, RateLimiter -from . 
import serialization, rich_progress_bar +from .rates import ( + rate_cpu_hour_to_mcpu_msec, + rate_gib_hour_to_mib_msec, + rate_gib_month_to_mib_msec, + rate_instance_hour_to_fraction_msec, +) +from .time import humanize_timedelta_msecs, parse_timestamp_msecs, time_msecs, time_msecs_str, time_ns +from .utils import ( + RETRY_FUNCTION_SCRIPT, + AsyncWorkerPool, + LoggingTimer, + Notice, + OnlineBoundedGather2, + Timings, + TransientError, + WaitableSharedPool, + ait_to_blocking, + am_i_interactive, + async_to_blocking, + blocking_to_async, + bounded_gather, + bounded_gather2, + collect_aiter, + cost_str, + delay_ms_for_try, + dump_all_stacktraces, + external_requests_client_session, + filter_none, + find_spark_home, + first_extant_file, + flatten, + grouped, + is_delayed_warning_error, + is_limited_retries_error, + is_transient_error, + parse_docker_image_reference, + partition, + periodically_call, + periodically_call_with_dynamic_sleep, + retry_all_errors, + retry_all_errors_n_times, + retry_long_running, + retry_response_returning_functions, + retry_transient_errors, + retry_transient_errors_with_debug_string, + retry_transient_errors_with_delayed_warnings, + run_if_changed, + run_if_changed_idempotent, + secret_alnum_string, + sleep_before_try, + sync_retry_transient_errors, + sync_sleep_before_try, + the_empty_async_generator, + unpack_comma_delimited_inputs, + unpack_key_value_inputs, + unzip, + url_and_params, + url_basename, + url_join, + url_scheme, +) __all__ = [ 'time_msecs', + 'the_empty_async_generator', 'time_msecs_str', 'humanize_timedelta_msecs', 'unzip', diff --git a/hail/python/hailtop/utils/gcs_requester_pays.py b/hail/python/hailtop/utils/gcs_requester_pays.py new file mode 100644 index 00000000000..40feaa909c1 --- /dev/null +++ b/hail/python/hailtop/utils/gcs_requester_pays.py @@ -0,0 +1,32 @@ +from typing import Any, Dict, FrozenSet, Generic, Optional, Tuple, Type, TypeVar, Union + +from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration + +FS = TypeVar("FS") +MaybeGCSRequesterPaysConfiguration = Optional[GCSRequesterPaysConfiguration] +FrozenKey = Optional[Union[str, Tuple[str, FrozenSet[str]]]] + + +class GCSRequesterPaysFSCache(Generic[FS]): + def __init__(self, fs_constructor: Type[FS], default_kwargs: Optional[Dict[str, Any]] = None) -> None: + self._fs_constructor = fs_constructor + self._default_kwargs = default_kwargs if default_kwargs is not None else {} + self._dict: Dict[FrozenKey, FS] = {} + + def __getitem__(self, gcs_requester_pays_configuration: MaybeGCSRequesterPaysConfiguration) -> FS: + frozen_key = self._freeze_key(gcs_requester_pays_configuration) + fs = self._dict.get(frozen_key) + if fs is None: + if gcs_requester_pays_configuration is None: + kwargs = self._default_kwargs + else: + kwargs = {"gcs_kwargs": {"gcs_requester_pays_configuration": gcs_requester_pays_configuration}} + fs = self._fs_constructor(**kwargs) + self._dict[frozen_key] = fs + return fs + + def _freeze_key(self, gcs_requester_pays_configuration: MaybeGCSRequesterPaysConfiguration) -> FrozenKey: + if isinstance(gcs_requester_pays_configuration, tuple): + project, buckets = gcs_requester_pays_configuration + return (project, frozenset(buckets)) + return gcs_requester_pays_configuration diff --git a/hail/python/hailtop/utils/process.py b/hail/python/hailtop/utils/process.py index fab670da136..439413c880e 100644 --- a/hail/python/hailtop/utils/process.py +++ b/hail/python/hailtop/utils/process.py @@ -1,6 +1,6 @@ -from typing import Tuple, List import asyncio import 
subprocess +from typing import List, Tuple from .utils import async_to_blocking @@ -15,20 +15,15 @@ def __init__(self, argv: List[str], returncode: int, outerr: Tuple[bytes, bytes] self.stderr = outerr[1] def __str__(self) -> str: - return (f'Command {self.argv} returned non-zero exit status {self.returncode}.' - f' Output:\n{self._outerr}') + return f'Command {self.argv} returned non-zero exit status {self.returncode}.' f' Output:\n{self._outerr}' -async def check_exec_output(command: str, - *args: str, - echo: bool = False - ) -> Tuple[bytes, bytes]: +async def check_exec_output(command: str, *args: str, echo: bool = False) -> Tuple[bytes, bytes]: if echo: print([command, *args]) proc = await asyncio.create_subprocess_exec( - command, *args, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE) + command, *args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) outerr = await proc.communicate() assert proc.returncode is not None if proc.returncode != 0: diff --git a/hail/python/hailtop/utils/rate_limiter.py b/hail/python/hailtop/utils/rate_limiter.py index 56f68d33dbf..36936df27e4 100644 --- a/hail/python/hailtop/utils/rate_limiter.py +++ b/hail/python/hailtop/utils/rate_limiter.py @@ -1,8 +1,8 @@ -from types import TracebackType -from typing import Optional, Type, Deque +import asyncio import collections import time -import asyncio +from types import TracebackType +from typing import Deque, Optional, Type class RateLimit: @@ -35,8 +35,7 @@ async def __aenter__(self) -> 'RateLimiter': return self - async def __aexit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: pass diff --git a/hail/python/hailtop/utils/rich_progress_bar.py b/hail/python/hailtop/utils/rich_progress_bar.py index cb1384a2af9..0bd26be7ad0 100644 --- a/hail/python/hailtop/utils/rich_progress_bar.py +++ b/hail/python/hailtop/utils/rich_progress_bar.py @@ -1,6 +1,17 @@ -from typing import Optional, Callable, Tuple, List +from typing import Callable, List, Optional, Tuple + from rich import filesize -from rich.progress import MofNCompleteColumn, BarColumn, TextColumn, TimeRemainingColumn, TimeElapsedColumn, Progress, ProgressColumn, TaskProgressColumn, Task +from rich.progress import ( + BarColumn, + MofNCompleteColumn, + Progress, + ProgressColumn, + Task, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, + TimeRemainingColumn, +) from rich.text import Text @@ -54,6 +65,7 @@ def listen(delta: int): progress.update(tid, total=total) else: progress.update(tid, advance=-delta) + return listen @@ -93,9 +105,8 @@ def render(self, task: "Task") -> Text: if speed is None: return Text("?", style="progress.data.speed") - speed = int(speed) - unit, suffix = filesize.pick_unit_and_suffix(speed, *units(task)) - precision = 0 if unit == 1 else 1 + unit, suffix = filesize.pick_unit_and_suffix(int(speed), *units(task)) + precision = 2 if unit == 1 else 1 return Text(f"{speed / unit:,.{precision}f} {suffix}/s", style="progress.data.speed") @@ -114,7 +125,7 @@ def get_default_columns() -> Tuple[ProgressColumn, ...]: BytesOrCountOrN(), RateColumn(), TimeRemainingColumn(), - TimeElapsedColumn() + TimeElapsedColumn(), ) def __enter__(self) -> Progress: @@ -145,7 +156,7 @@ def get_default_columns() -> Tuple[ProgressColumn, ...]: TaskProgressColumn(), MofNCompleteColumn(), 
TimeRemainingColumn(), - TimeElapsedColumn() + TimeElapsedColumn(), ) def __enter__(self) -> 'BatchProgressBar': @@ -161,7 +172,9 @@ def __exit__(self, exc_type, exc_value, traceback): finally: self._progress.stop() - def with_task(self, description: str, *, total: int = 0, disable: bool = False, transient: bool = False) -> 'BatchProgressBarTask': + def with_task( + self, description: str, *, total: int = 0, disable: bool = False, transient: bool = False + ) -> 'BatchProgressBarTask': tid = self._progress.add_task(description, total=total, visible=not disable) return BatchProgressBarTask(self._progress, tid, transient) diff --git a/hail/python/hailtop/utils/serialization.py b/hail/python/hailtop/utils/serialization.py index 61871eecbae..3110475e707 100644 --- a/hail/python/hailtop/utils/serialization.py +++ b/hail/python/hailtop/utils/serialization.py @@ -1,10 +1,10 @@ -from typing import Dict, Any import traceback +from typing import Any, Dict def exception_to_dict(exc: Exception) -> Dict[str, Any]: return { 'class': type(exc).__name__, 'args': exc.args, - 'traceback': traceback.format_exception(type(exc), exc, exc.__traceback__) + 'traceback': traceback.format_exception(type(exc), exc, exc.__traceback__), } diff --git a/hail/python/hailtop/utils/time.py b/hail/python/hailtop/utils/time.py index 81c97caa48d..80f787c599e 100644 --- a/hail/python/hailtop/utils/time.py +++ b/hail/python/hailtop/utils/time.py @@ -1,9 +1,10 @@ -import time -from typing import Optional, overload, Union import datetime +import time +from typing import Optional, Union, overload + import dateutil.parser -from ..humanizex import naturaldelta_msec +from ..humanizex import naturaldelta_msec def time_msecs() -> int: @@ -15,14 +16,17 @@ def time_ns() -> int: def time_msecs_str(t: Union[int, float]) -> str: - return datetime.datetime.utcfromtimestamp(t / 1000).strftime( - '%Y-%m-%dT%H:%M:%SZ') + return datetime.datetime.utcfromtimestamp(t / 1000).strftime('%Y-%m-%dT%H:%M:%SZ') @overload def humanize_timedelta_msecs(delta_msecs: None) -> None: ... + + @overload def humanize_timedelta_msecs(delta_msecs: Union[int, float]) -> str: ... + + def humanize_timedelta_msecs(delta_msecs: Optional[Union[int, float]]) -> Optional[str]: if delta_msecs is None: return None @@ -35,8 +39,12 @@ def humanize_timedelta_msecs(delta_msecs: Optional[Union[int, float]]) -> Option @overload def parse_timestamp_msecs(ts: None) -> None: ... + + @overload def parse_timestamp_msecs(ts: str) -> int: ... 
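The blank lines ruff inserts around the `@overload` stubs below do not change behavior; the overloads exist so type checkers can keep the None-ness of the result precise. A hypothetical caller (not from this patch), assuming the return value is an epoch timestamp in milliseconds:

    from hailtop.utils import parse_timestamp_msecs

    t1 = parse_timestamp_msecs('2024-03-05T12:00:00Z')  # typed as int (presumably epoch milliseconds)
    t2 = parse_timestamp_msecs(None)                     # typed as None, returned unchanged
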
+ + def parse_timestamp_msecs(ts: Optional[str]) -> Optional[int]: if ts is None: return ts diff --git a/hail/python/hailtop/utils/utils.py b/hail/python/hailtop/utils/utils.py index ae4805a10c6..9067a24ba57 100644 --- a/hail/python/hailtop/utils/utils.py +++ b/hail/python/hailtop/utils/utils.py @@ -1,33 +1,50 @@ -from typing import (Any, Callable, TypeVar, Awaitable, Mapping, Optional, Type, List, Dict, - Iterable, Tuple, AsyncIterator, Iterator, Union) -from typing import Literal, Sequence -from typing_extensions import ParamSpec -from types import TracebackType +import asyncio import concurrent.futures import contextlib -import subprocess -import traceback -import sys -import os -import re import errno -import random +import itertools import logging -import asyncio -import aiohttp -import urllib.parse -import urllib3.exceptions +import os +import random +import re import secrets import socket -import requests -import botocore.exceptions -import itertools +import subprocess +import sys import time +import traceback +import urllib.parse +from types import TracebackType +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Awaitable, + Callable, + Dict, + Iterable, + Iterator, + List, + Literal, + Mapping, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) + +import aiohttp +import botocore.exceptions +import requests +import urllib3.exceptions from requests.adapters import HTTPAdapter +from typing_extensions import ParamSpec from urllib3.poolmanager import PoolManager -from .time import time_msecs from ..hail_event_loop import hail_event_loop +from .time import time_msecs try: import aiodocker # pylint: disable=import-error @@ -50,10 +67,13 @@ P = ParamSpec("P") +async def the_empty_async_generator() -> AsyncGenerator[Any, None]: + if False: # pylint: disable=using-constant-test + yield # The appearance of the keyword `yield` forces Python to make this function into a generator + + def unpack_comma_delimited_inputs(inputs: List[str]) -> List[str]: - return [s.strip() - for comma_separated_steps in inputs - for s in comma_separated_steps.split(',') if s.strip()] + return [s.strip() for comma_separated_steps in inputs for s in comma_separated_steps.split(',') if s.strip()] def unpack_key_value_inputs(inputs: List[str]) -> Dict[str, str]: @@ -139,8 +159,9 @@ def partition(k: int, ls: Sequence[T]) -> Iterable[Sequence[T]]: def generator(): start = 0 for part in parts: - yield ls[start:start + part] + yield ls[start : start + part] start += part + return generator() @@ -175,30 +196,26 @@ def ait_to_blocking(ait: AsyncIterator[T]) -> Iterator[T]: break -async def blocking_to_async(thread_pool: concurrent.futures.Executor, - fun: Callable[..., T], - *args, - **kwargs) -> T: - return await asyncio.get_running_loop().run_in_executor( - thread_pool, lambda: fun(*args, **kwargs)) +async def blocking_to_async(thread_pool: concurrent.futures.Executor, fun: Callable[..., T], *args, **kwargs) -> T: + return await asyncio.get_running_loop().run_in_executor(thread_pool, lambda: fun(*args, **kwargs)) -async def bounded_gather(*pfs: Callable[[], Awaitable[T]], - parallelism: int = 10, - return_exceptions: bool = False, - cancel_on_error = False, - ) -> List[T]: +async def bounded_gather( + *pfs: Callable[[], Awaitable[T]], + parallelism: int = 10, + return_exceptions: bool = False, + cancel_on_error=False, +) -> List[T]: return await bounded_gather2( - asyncio.Semaphore(parallelism), - *pfs, - return_exceptions=return_exceptions, - cancel_on_error=cancel_on_error + 
asyncio.Semaphore(parallelism), *pfs, return_exceptions=return_exceptions, cancel_on_error=cancel_on_error ) class AsyncWorkerPool: def __init__(self, parallelism, queue_size=1000): - self._queue: asyncio.Queue[Tuple[Callable, Tuple[Any, ...], Mapping[str, Any]]] = asyncio.Queue(maxsize=queue_size) + self._queue: asyncio.Queue[Tuple[Callable, Tuple[Any, ...], Mapping[str, Any]]] = asyncio.Queue( + maxsize=queue_size + ) self.workers = {asyncio.ensure_future(self._worker()) for _ in range(parallelism)} async def _worker(self): @@ -261,26 +278,26 @@ async def wait(self): async def __aenter__(self) -> 'WaitableSharedPool': return self - async def __aexit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: await self.wait() class WithoutSemaphore: - def __init__(self, sema): + def __init__(self, sema, *, acquire_on_error: bool = False): self._sema = sema + self._acquire_on_error = acquire_on_error async def __aenter__(self) -> 'WithoutSemaphore': self._sema.release() return self - async def __aexit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: - await self._sema.acquire() + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: + if exc_val is None or self._acquire_on_error: + await self._sema.acquire() class PoolShutdownError(Exception): @@ -288,7 +305,7 @@ class PoolShutdownError(Exception): class OnlineBoundedGather2: - '''`OnlineBoundedGather2` provides the capability to run background + """`OnlineBoundedGather2` provides the capability to run background tasks with bounded parallelism. It is a context manager, and waits for all background tasks to complete on exit. @@ -307,7 +324,7 @@ class OnlineBoundedGather2: a background task or into the context manager exit, is raised by the context manager exit, and any further exceptions are logged and otherwise discarded. - ''' + """ def __init__(self, sema: asyncio.Semaphore): self._counter = 0 @@ -322,11 +339,11 @@ def __init__(self, sema: asyncio.Semaphore): self._exception: Optional[BaseException] = None async def _shutdown(self) -> None: - '''Shut down the pool. + """Shut down the pool. Cancel all pending tasks and wait for them to complete. Subsequent calls to call will raise `PoolShutdownError`. - ''' + """ if self._pending is None: return @@ -344,13 +361,13 @@ async def _shutdown(self) -> None: self._done_event.set() def call(self, f, *args, **kwargs) -> asyncio.Task: - '''Invoke a function as a background task. + """Invoke a function as a background task. Return the task, which can be used to wait on (using `OnlineBoundedGather2.wait()`) or cancel the task (using `asyncio.Task.cancel()`). Note, waiting on a task using `asyncio.wait()` directly can lead to deadlock. - ''' + """ if self._pending is None: raise PoolShutdownError @@ -386,14 +403,14 @@ async def run_and_cleanup(): return t async def wait(self, tasks: List[asyncio.Task]) -> None: - '''Wait for a list of tasks returned to complete. + """Wait for a list of tasks returned to complete. The tasks should be tasks returned from `OnlineBoundedGather2.call()`. 
They can be a subset of the running tasks, `OnlineBoundedGather2.wait()` can be called multiple times, and additional tasks can be submitted to the pool after waiting. - ''' + """ async with WithoutSemaphore(self._sema): await asyncio.wait(tasks) @@ -401,10 +418,9 @@ async def wait(self, tasks: List[asyncio.Task]) -> None: async def __aenter__(self) -> 'OnlineBoundedGather2': return self - async def __aexit__(self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: if exc_val: if self._exception is None: self._exception = exc_val @@ -427,10 +443,9 @@ async def __aexit__(self, async def bounded_gather2_return_exceptions( - sema: asyncio.Semaphore, - *pfs: Callable[[], Awaitable[T]] + sema: asyncio.Semaphore, *pfs: Callable[[], Awaitable[T]] ) -> List[Union[Tuple[T, None], Tuple[None, Optional[BaseException]]]]: - '''Run the partial functions `pfs` as tasks with parallelism bounded + """Run the partial functions `pfs` as tasks with parallelism bounded by `sema`, which should be `asyncio.Semaphore` whose initial value is the desired level of parallelism. @@ -438,7 +453,8 @@ async def bounded_gather2_return_exceptions( the pair `(value, None)` if the partial function returned value or `(None, exc)` if the partial function raised the exception `exc`. - ''' + """ + async def run_with_sema_return_exceptions(pf: Callable[[], Awaitable[T]]): try: async with sema: @@ -453,11 +469,9 @@ async def run_with_sema_return_exceptions(pf: Callable[[], Awaitable[T]]): async def bounded_gather2_raise_exceptions( - sema: asyncio.Semaphore, - *pfs: Callable[[], Awaitable[T]], - cancel_on_error: bool = False + sema: asyncio.Semaphore, *pfs: Callable[[], Awaitable[T]], cancel_on_error: bool = False ) -> List[T]: - '''Run the partial functions `pfs` as tasks with parallelism bounded + """Run the partial functions `pfs` as tasks with parallelism bounded by `sema`, which should be `asyncio.Semaphore` whose initial value is the level of parallelism. @@ -470,7 +484,8 @@ async def bounded_gather2_raise_exceptions( functions continue to run with bounded parallelism. If cancel_on_error is True, the unfinished tasks are all cancelled. - ''' + """ + async def run_with_sema(pf: Callable[[], Awaitable[T]]): async with sema: return await pf() @@ -500,10 +515,10 @@ async def run_with_sema(pf: Callable[[], Awaitable[T]]): async def bounded_gather2( - sema: asyncio.Semaphore, - *pfs: Callable[[], Awaitable[T]], - return_exceptions: bool = False, - cancel_on_error: bool = False + sema: asyncio.Semaphore, + *pfs: Callable[[], Awaitable[T]], + return_exceptions: bool = False, + cancel_on_error: bool = False, ) -> List[T]: if return_exceptions: if cancel_on_error: @@ -527,6 +542,7 @@ async def bounded_gather2( RETRYABLE_ERRNOS = { # these should match (where an equivalent exists) nettyRetryableErrorNumbers in # is/hail/services/package.scala + errno.EADDRNOTAVAIL, errno.ETIMEDOUT, errno.ECONNREFUSED, errno.EHOSTUNREACH, @@ -551,10 +567,9 @@ def is_limited_retries_error(e: BaseException) -> bool: # provider can manifest as this exception *and* that manifestation is indistinguishable from a # true error. 
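As an illustration of the reformatted `bounded_gather` signature above, which takes zero-argument callables ("partial functions") and runs them with bounded parallelism, a small sketch not taken from the patch; the `fetch` helper is made up for the example:

    import asyncio
    from hailtop.utils import bounded_gather

    async def fetch(i: int) -> int:
        # Stand-in for a real I/O-bound operation.
        await asyncio.sleep(0.01)
        return i * i

    async def main():
        # At most five of the twenty coroutines run concurrently.
        results = await bounded_gather(*[lambda i=i: fetch(i) for i in range(20)], parallelism=5)
        print(results)

    asyncio.run(main())
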
import hailtop.httpx # pylint: disable=import-outside-toplevel,cyclic-import + if aiodocker is not None and isinstance(e, aiodocker.exceptions.DockerError): - return (e.status == 404 - and 'azurecr.io' in e.message - and 'not found: manifest unknown: ' in e.message) + return e.status == 404 and 'azurecr.io' in e.message and 'not found: manifest unknown: ' in e.message if isinstance(e, hailtop.httpx.ClientResponseError): return e.status == 400 and any(msg in e.body for msg in RETRY_ONCE_BAD_REQUEST_ERROR_MESSAGES) if isinstance(e, ConnectionResetError): @@ -615,16 +630,18 @@ def is_transient_error(e: BaseException) -> bool: # https://hail.zulipchat.com/#narrow/stream/223457-Batch-support/topic/ssl.20error import hailtop.aiocloud.aiogoogle.client.compute_client # pylint: disable=import-outside-toplevel,cyclic-import import hailtop.httpx # pylint: disable=import-outside-toplevel,cyclic-import - if (isinstance(e, aiohttp.ClientResponseError) - and e.status in RETRYABLE_HTTP_STATUS_CODES): + + if isinstance(e, aiohttp.ClientResponseError) and e.status in RETRYABLE_HTTP_STATUS_CODES: return True - if (isinstance(e, hailtop.aiocloud.aiogoogle.client.compute_client.GCPOperationError) - and e.error_codes is not None - and 'QUOTA_EXCEEDED' in e.error_codes): + if ( + isinstance(e, hailtop.aiocloud.aiogoogle.client.compute_client.GCPOperationError) + and e.error_codes is not None + and 'QUOTA_EXCEEDED' in e.error_codes + ): return True - if (isinstance(e, hailtop.httpx.ClientResponseError) - and (e.status in RETRYABLE_HTTP_STATUS_CODES - or e.status == 403 and 'rateLimitExceeded' in e.body)): + if isinstance(e, hailtop.httpx.ClientResponseError) and ( + e.status in RETRYABLE_HTTP_STATUS_CODES or e.status == 403 and 'rateLimitExceeded' in e.body + ): return True if isinstance(e, aiohttp.ServerTimeoutError): return True @@ -632,17 +649,14 @@ def is_transient_error(e: BaseException) -> bool: return True if isinstance(e, asyncio.TimeoutError): return True - if (isinstance(e, aiohttp.ClientConnectorError) - and is_transient_error(e.os_error)): + if isinstance(e, aiohttp.ClientConnectorError) and is_transient_error(e.os_error): return True # appears to happen when the connection is lost prematurely, see: # https://github.com/aio-libs/aiohttp/issues/4581 # https://github.com/aio-libs/aiohttp/blob/v3.7.4/aiohttp/client_proto.py#L85 - if (isinstance(e, aiohttp.ClientPayloadError) - and e.args[0] == "Response payload is not completed"): + if isinstance(e, aiohttp.ClientPayloadError) and e.args[0] == "Response payload is not completed": return True - if (isinstance(e, aiohttp.ClientOSError) - and 'sslv3 alert bad record mac' in e.strerror): + if isinstance(e, aiohttp.ClientOSError) and 'sslv3 alert bad record mac' in e.strerror: # aiohttp.client_exceptions.ClientOSError: [Errno 1] [SSL: SSLV3_ALERT_BAD_RECORD_MAC] sslv3 alert bad record mac (_ssl.c:2548) # # This appears to be a symptom of Google rate-limiting as of 2023-10-15 @@ -666,7 +680,10 @@ def is_transient_error(e: BaseException) -> bool: if aiodocker is not None and isinstance(e, aiodocker.exceptions.DockerError): if e.status == 500 and 'Invalid repository name' in e.message: return False - if e.status == 500 and 'Permission "artifactregistry.repositories.downloadArtifacts" denied on resource' in e.message: + if ( + e.status == 500 + and 'Permission "artifactregistry.repositories.downloadArtifacts" denied on resource' in e.message + ): return False if e.status == 500 and 'denied: retrieving permissions failed' in e.message: return False @@ -696,9 
+713,7 @@ def is_delayed_warning_error(e: BaseException) -> bool: def delay_ms_for_try( - tries: int, - base_delay_ms: int = DEFAULT_BASE_DELAY_MS, - max_delay_ms: int = DEFAULT_MAX_DELAY_MS + tries: int, base_delay_ms: int = DEFAULT_BASE_DELAY_MS, max_delay_ms: int = DEFAULT_MAX_DELAY_MS ) -> int: # Based on AWS' recommendations: # - https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ @@ -710,17 +725,13 @@ def delay_ms_for_try( async def sleep_before_try( - tries: int, - base_delay_ms: int = DEFAULT_BASE_DELAY_MS, - max_delay_ms: int = DEFAULT_MAX_DELAY_MS + tries: int, base_delay_ms: int = DEFAULT_BASE_DELAY_MS, max_delay_ms: int = DEFAULT_MAX_DELAY_MS ): await asyncio.sleep(delay_ms_for_try(tries, base_delay_ms, max_delay_ms) / 1000.0) def sync_sleep_before_try( - tries: int, - base_delay_ms: int = DEFAULT_BASE_DELAY_MS, - max_delay_ms: int = DEFAULT_MAX_DELAY_MS + tries: int, base_delay_ms: int = DEFAULT_BASE_DELAY_MS, max_delay_ms: int = DEFAULT_MAX_DELAY_MS ): time.sleep(delay_ms_for_try(tries, base_delay_ms, max_delay_ms) / 1000.0) @@ -740,6 +751,7 @@ async def _wrapper(f: Callable[..., Awaitable[T]], *args, **kwargs) -> T: if msg and tries % error_logging_interval == 0: log.exception(msg, stack_info=True) await sleep_before_try(tries) + return _wrapper @@ -760,6 +772,7 @@ async def _wrapper(f: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwar if tries >= max_errors: raise await sleep_before_try(tries) + return _wrapper @@ -767,11 +780,15 @@ async def retry_transient_errors(f: Callable[..., Awaitable[T]], *args, **kwargs return await retry_transient_errors_with_debug_string('', 0, f, *args, **kwargs) -async def retry_transient_errors_with_delayed_warnings(warning_delay_msecs: int, f: Callable[..., Awaitable[T]], *args, **kwargs) -> T: +async def retry_transient_errors_with_delayed_warnings( + warning_delay_msecs: int, f: Callable[..., Awaitable[T]], *args, **kwargs +) -> T: return await retry_transient_errors_with_debug_string('', warning_delay_msecs, f, *args, **kwargs) -async def retry_transient_errors_with_debug_string(debug_string: str, warning_delay_msecs: int, f: Callable[..., Awaitable[T]], *args, **kwargs) -> T: +async def retry_transient_errors_with_debug_string( + debug_string: str, warning_delay_msecs: int, f: Callable[..., Awaitable[T]], *args, **kwargs +) -> T: start_time = time_msecs() tries = 0 while True: @@ -793,14 +810,19 @@ async def retry_transient_errors_with_debug_string(debug_string: str, warning_de else: log_warnings = (time_msecs() - start_time >= warning_delay_msecs) or not is_delayed_warning_error(e) if log_warnings and tries == 2: - log.warning(f'A transient error occured. We will automatically retry. Do not be alarmed. ' - f'We have thus far seen {tries} transient errors (next delay: ' - f'{delay}s). The most recent error was {type(e)} {e}. {debug_string}') + log.warning( + f'A transient error occured. We will automatically retry. Do not be alarmed. ' + f'We have thus far seen {tries} transient errors (next delay: ' + f'{delay}s). The most recent error was {type(e)} {e}. {debug_string}' + ) elif log_warnings and tries % 10 == 0: st = ''.join(traceback.format_stack()) - log.warning(f'A transient error occured. We will automatically retry. ' - f'We have thus far seen {tries} transient errors (next delay: ' - f'{delay}s). The stack trace for this call is {st}. The most recent error was {type(e)} {e}. {debug_string}', exc_info=True) + log.warning( + f'A transient error occured. We will automatically retry. 
' + f'We have thus far seen {tries} transient errors (next delay: ' + f'{delay}s). The stack trace for this call is {st}. The most recent error was {type(e)} {e}. {debug_string}', + exc_info=True, + ) await asyncio.sleep(delay) @@ -815,7 +837,9 @@ def sync_retry_transient_errors(f: Callable[..., T], *args, **kwargs) -> T: tries += 1 if tries % 10 == 0: st = ''.join(traceback.format_stack()) - log.warning(f'Encountered {tries} errors. My stack trace is {st}. Most recent error was {e}', exc_info=True) + log.warning( + f'Encountered {tries} errors. My stack trace is {st}. Most recent error was {e}', exc_info=True + ) if is_transient_error(e): pass else: @@ -825,15 +849,12 @@ def sync_retry_transient_errors(f: Callable[..., T], *args, **kwargs) -> T: def retry_response_returning_functions(fun, *args, **kwargs): tries = 0 - response = sync_retry_transient_errors( - fun, *args, **kwargs) + response = sync_retry_transient_errors(fun, *args, **kwargs) while response.status_code in RETRYABLE_HTTP_STATUS_CODES: tries += 1 if tries % 10 == 0: - log.warning(f'encountered {tries} bad status codes, most recent ' - f'one was {response.status_code}') - response = sync_retry_transient_errors( - fun, *args, **kwargs) + log.warning(f'encountered {tries} bad status codes, most recent ' f'one was {response.status_code}') + response = sync_retry_transient_errors(fun, *args, **kwargs) sync_sleep_before_try(tries) return response @@ -854,13 +875,11 @@ def __init__(self, max_retries, timeout): self.timeout = timeout super().__init__(max_retries=max_retries) - def init_poolmanager(self, connections, maxsize, block=False): + def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs): + assert len(pool_kwargs) == 0 self.poolmanager = PoolManager( - num_pools=connections, - maxsize=maxsize, - block=block, - retries=self.max_retries, - timeout=self.timeout) + num_pools=connections, maxsize=maxsize, block=block, retries=self.max_retries, timeout=self.timeout + ) async def collect_aiter(aiter: AsyncIterator[T]) -> List[T]: @@ -892,9 +911,7 @@ async def retry_long_running(name: str, f: Callable[P, Awaitable[T]], *args: P.a await asyncio.sleep(t) ran_for_secs = (end_time - start_time) * 1000 - delay_secs = min( - max(0.1, 2 * delay_secs - min(0, (ran_for_secs - t) / 2)), - 30.0) + delay_secs = min(max(0.1, 2 * delay_secs - min(0, (ran_for_secs - t) / 2)), 30.0) async def run_if_changed(changed: asyncio.Event, f: Callable[..., Awaitable[bool]], *args, **kwargs): @@ -926,6 +943,7 @@ async def loop(): while True: await f(*args, **kwargs) await asyncio.sleep(period) + await retry_long_running(f.__name__, loop) @@ -935,6 +953,7 @@ async def loop(): while True: await f(*args, **kwargs) await asyncio.sleep(period()) + await retry_long_running(f.__name__, loop) @@ -995,8 +1014,8 @@ def url_scheme(url: str) -> str: def url_and_params(url: str) -> Tuple[str, Dict[str, str]]: """Strip the query parameters from `url` and parse them into a dictionary. - Assumes that all query parameters are used only once, so have only one - value. + Assumes that all query parameters are used only once, so have only one + value. 
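A hypothetical use of the `url_and_params` helper whose docstring is reindented below; the exact form of the returned URL is inferred from the docstring ("strip the query parameters"), not shown in this hunk:

    from hailtop.utils import url_and_params

    base, params = url_and_params('https://example.org/api/v1alpha/batches?limit=50&q=running')
    # params == {'limit': '50', 'q': 'running'}; base is the URL without its query string.
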
""" parsed = urllib.parse.urlparse(url) params = {k: v[0] for k, v in urllib.parse.parse_qs(parsed.query).items()} @@ -1067,15 +1086,13 @@ def notify(self): def find_spark_home() -> str: spark_home = os.environ.get('SPARK_HOME') if spark_home is None: - find_spark_home = subprocess.run('find_spark_home.py', - capture_output=True, - check=False) + find_spark_home = subprocess.run('find_spark_home.py', capture_output=True, check=False) if find_spark_home.returncode != 0: - raise ValueError(f'''SPARK_HOME is not set and find_spark_home.py returned non-zero exit code: + raise ValueError(f"""SPARK_HOME is not set and find_spark_home.py returned non-zero exit code: STDOUT: {find_spark_home.stdout!r} STDERR: -{find_spark_home.stderr!r}''') +{find_spark_home.stderr!r}""") spark_home = find_spark_home.stdout.decode().strip() return spark_home diff --git a/hail/python/hailtop/utils/validate/__init__.py b/hail/python/hailtop/utils/validate/__init__.py index 0af209db9ef..fc4c47a6908 100644 --- a/hail/python/hailtop/utils/validate/__init__.py +++ b/hail/python/hailtop/utils/validate/__init__.py @@ -1,5 +1,20 @@ -from .validate import anyof, bool_type, dictof, keyed, listof, int_type, nullable, \ - numeric, oneof, regex, required, str_type, non_empty_str_type, switch, ValidationError +from .validate import ( + ValidationError, + anyof, + bool_type, + dictof, + int_type, + keyed, + listof, + non_empty_str_type, + nullable, + numeric, + oneof, + regex, + required, + str_type, + switch, +) __all__ = [ 'anyof', @@ -16,5 +31,5 @@ 'str_type', 'non_empty_str_type', 'switch', - 'ValidationError' + 'ValidationError', ] diff --git a/hail/python/hailtop/utils/validate/validate.py b/hail/python/hailtop/utils/validate/validate.py index a64e143fe24..4fab467eb78 100644 --- a/hail/python/hailtop/utils/validate/validate.py +++ b/hail/python/hailtop/utils/validate/validate.py @@ -1,6 +1,6 @@ -from typing import Union, Dict, Pattern, Callable, Any, List, Optional -import re import logging +import re +from typing import Any, Callable, Dict, List, Optional, Pattern, Union log = logging.getLogger('foo') @@ -118,8 +118,7 @@ def __init__(self, key: str, checkers: Dict[str, Dict[Key, 'Validator']]): super().__init__(dict) self.key = key self.valid_key = oneof(*checkers.keys()) - self.checkers = {k: keyed({required(key): self.valid_key, **fields}) - for k, fields in checkers.items()} + self.checkers = {k: keyed({required(key): self.valid_key, **fields}) for k, fields in checkers.items()} def __getitem__(self, key): return self.checkers[key] diff --git a/hail/python/hailtop/uvloopx.py b/hail/python/hailtop/uvloopx.py new file mode 100644 index 00000000000..bac9f2542f6 --- /dev/null +++ b/hail/python/hailtop/uvloopx.py @@ -0,0 +1,13 @@ +import sys + +try: + import uvloop + + def install(): + return uvloop.install() +except ImportError as e: + if not sys.platform.startswith('win32'): + raise e + + def install(): + pass diff --git a/hail/python/hailtop/yamlx.py b/hail/python/hailtop/yamlx.py index aa2b4721ad0..09d695d317c 100644 --- a/hail/python/hailtop/yamlx.py +++ b/hail/python/hailtop/yamlx.py @@ -8,12 +8,10 @@ def yaml_dump_multiline_str_as_literal_block(dumper, data): class HailDumper(yaml.SafeDumper): - @property - def yaml_representers(self): - return { - **super().yaml_representers, - str: yaml_dump_multiline_str_as_literal_block - } + pass + + +HailDumper.add_representer(str, yaml_dump_multiline_str_as_literal_block) def dump(data) -> str: diff --git a/hail/python/pinned-requirements.txt 
b/hail/python/pinned-requirements.txt index d7af149ca74..86c15d88a41 100644 --- a/hail/python/pinned-requirements.txt +++ b/hail/python/pinned-requirements.txt @@ -8,7 +8,7 @@ aiodns==2.0.0 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt -aiohttp==3.9.1 +aiohttp==3.9.3 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt @@ -20,7 +20,7 @@ async-timeout==4.0.3 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # aiohttp -attrs==23.1.0 +attrs==23.2.0 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # aiohttp @@ -30,7 +30,7 @@ azure-common==1.1.28 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # azure-mgmt-storage -azure-core==1.29.5 +azure-core==1.30.1 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # azure-identity @@ -53,23 +53,23 @@ azure-storage-blob==12.19.0 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt -bokeh==3.3.1 +bokeh==3.3.4 # via -r hail/hail/python/requirements.txt -boto3==1.33.1 +boto3==1.34.55 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt -botocore==1.33.1 +botocore==1.34.55 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt # boto3 # s3transfer -cachetools==5.3.2 +cachetools==5.3.3 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # google-auth -certifi==2023.11.17 +certifi==2024.2.2 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # msrest @@ -93,7 +93,7 @@ commonmark==0.9.1 # rich contourpy==1.2.0 # via bokeh -cryptography==41.0.7 +cryptography==42.0.5 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # azure-identity @@ -104,17 +104,17 @@ decorator==4.4.2 # via -r hail/hail/python/requirements.txt deprecated==1.2.14 # via -r hail/hail/python/requirements.txt -dill==0.3.7 +dill==0.3.8 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt -frozenlist==1.4.0 +frozenlist==1.4.1 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt # aiohttp # aiosignal -google-auth==2.23.4 +google-auth==2.28.1 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt @@ -141,7 +141,7 @@ janus==1.0.0 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt -jinja2==3.1.2 +jinja2==3.1.3 # via bokeh jmespath==1.0.1 # via @@ -152,14 +152,14 @@ jproperties==2.1.1 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt -markupsafe==2.1.3 +markupsafe==2.1.5 # via jinja2 -msal==1.25.0 +msal==1.27.0 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # azure-identity # msal-extensions -msal-extensions==1.0.0 +msal-extensions==1.1.0 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # azure-identity @@ -167,16 +167,16 @@ msrest==0.7.1 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # azure-mgmt-storage -multidict==6.0.4 +multidict==6.0.5 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # aiohttp # yarl -nest-asyncio==1.5.8 +nest-asyncio==1.6.0 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt -numpy==1.26.2 +numpy==1.26.4 # via # -r hail/hail/python/requirements.txt # bokeh @@ -193,27 +193,24 @@ 
orjson==3.9.10 # -r hail/hail/python/hailtop/requirements.txt packaging==23.2 # via + # -c hail/hail/python/hailtop/pinned-requirements.txt # bokeh + # msal-extensions # plotly -pandas==2.1.3 +pandas==2.2.1 # via # -r hail/hail/python/requirements.txt # bokeh parsimonious==0.10.0 # via -r hail/hail/python/requirements.txt -pillow==10.1.0 +pillow==10.2.0 # via bokeh -plotly==5.18.0 +plotly==5.19.0 # via -r hail/hail/python/requirements.txt portalocker==2.8.2 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # msal-extensions -protobuf==3.20.2 - # via - # -c hail/hail/python/hailtop/pinned-requirements.txt - # -r hail/hail/python/hailtop/requirements.txt - # -r hail/hail/python/requirements.txt py4j==0.10.9.5 # via pyspark pyasn1==0.5.1 @@ -241,9 +238,11 @@ pyjwt[crypto]==2.8.0 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # msal -pyspark==3.3.3 - # via -r hail/hail/python/requirements.txt -python-dateutil==2.8.2 +pyspark==3.3.2 + # via + # -c hail/hail/python/dataproc-pre-installed-requirements.txt + # -r hail/hail/python/requirements.txt +python-dateutil==2.9.0.post0 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # botocore @@ -252,14 +251,14 @@ python-json-logger==2.0.7 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt -pytz==2023.3.post1 +pytz==2024.1 # via pandas pyyaml==6.0.1 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt # bokeh -regex==2023.10.3 +regex==2023.12.25 # via parsimonious requests==2.31.0 # via @@ -282,7 +281,7 @@ rsa==4.9 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # google-auth -s3transfer==0.8.0 +s3transfer==0.10.0 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # boto3 @@ -305,20 +304,20 @@ tabulate==0.9.0 # -r hail/hail/python/hailtop/requirements.txt tenacity==8.2.3 # via plotly -tornado==6.3.3 +tornado==6.4 # via bokeh typer==0.9.0 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # -r hail/hail/python/hailtop/requirements.txt -typing-extensions==4.8.0 +typing-extensions==4.10.0 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # azure-core # azure-storage-blob # janus # typer -tzdata==2023.3 +tzdata==2024.1 # via pandas urllib3==1.26.18 # via @@ -333,7 +332,7 @@ wrapt==1.16.0 # via deprecated xyzservices==2023.10.1 # via bokeh -yarl==1.9.3 +yarl==1.9.4 # via # -c hail/hail/python/hailtop/pinned-requirements.txt # aiohttp diff --git a/hail/python/requirements.txt b/hail/python/requirements.txt index eaed2c0d116..34ad198c293 100644 --- a/hail/python/requirements.txt +++ b/hail/python/requirements.txt @@ -1,4 +1,5 @@ -c hailtop/pinned-requirements.txt +-c dataproc-pre-installed-requirements.txt -r hailtop/requirements.txt avro>=1.10,<1.12 @@ -8,8 +9,7 @@ Deprecated>=1.2.10,<1.3 numpy<2 pandas>=2,<3 parsimonious<1 -plotly>=5.5.0,<6 -protobuf==3.20.2 -pyspark>=3.3.0,<3.4 +plotly>=5.18.0,<6 +pyspark>=3.3.2,<3.4 requests>=2.31.0,<3 scipy>1.2,<1.12 diff --git a/hail/python/setup-hailtop.py b/hail/python/setup-hailtop.py index 4223b8d9e65..3a59a9b4f29 100644 --- a/hail/python/setup-hailtop.py +++ b/hail/python/setup-hailtop.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( name='hailtop', @@ -14,20 +14,14 @@ 'Repository': 'https://github.com/hail-is/hail', }, packages=find_packages('.'), - package_dir={ - 'hailtop': 'hailtop'}, - package_data={ - "hailtop": ["py.typed", "hail_version"], - 
'hailtop.hailctl': ['hail_version', 'deploy.yaml'] - }, + package_dir={'hailtop': 'hailtop'}, + package_data={"hailtop": ["py.typed", "hail_version"], 'hailtop.hailctl': ['hail_version', 'deploy.yaml']}, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", ], python_requires=">=3.9", - entry_points={ - 'console_scripts': ['hailctl = hailtop.hailctl.__main__:main'] - }, + entry_points={'console_scripts': ['hailctl = hailtop.hailctl.__main__:main']}, setup_requires=["pytest-runner", "wheel"], include_package_data=True, ) diff --git a/hail/python/setup.py b/hail/python/setup.py index 5c75c86588e..ded40cb8d31 100755 --- a/hail/python/setup.py +++ b/hail/python/setup.py @@ -1,8 +1,8 @@ #!/usr/bin/env python import os -import re -from setuptools import setup, find_packages + +from setuptools import find_packages, setup with open('hail/hail_pip_version') as f: hail_pip_version = f.read().strip() @@ -19,6 +19,7 @@ pkg = stripped + def add_dependencies(fname): with open(fname, 'r') as f: for line in f: @@ -28,7 +29,7 @@ def add_dependencies(fname): if stripped.startswith('-c'): continue if stripped.startswith('-r'): - additional_requirements = stripped[len('-r'):].strip() + additional_requirements = stripped[len('-r') :].strip() add_dependencies(additional_requirements) continue pkg = stripped @@ -39,6 +40,8 @@ def add_dependencies(fname): dependencies.append(f'pyspark>={major}.{minor},<{int(major)+1}') else: dependencies.append(pkg) + + add_dependencies('requirements.txt') setup( @@ -56,26 +59,20 @@ def add_dependencies(fname): 'Change Log': 'https://hail.is/docs/0.2/change_log.html', }, packages=find_packages('.'), - package_dir={ - 'hail': 'hail', - 'hailtop': 'hailtop'}, + package_dir={'hail': 'hail', 'hailtop': 'hailtop'}, package_data={ - 'hail': ['hail_pip_version', - 'hail_version', - 'hail_revision', - 'experimental/datasets.json'], + 'hail': ['hail_pip_version', 'hail_version', 'hail_revision', 'experimental/datasets.json'], 'hail.backend': ['hail-all-spark.jar'], 'hailtop': ['hail_version', 'py.typed'], - 'hailtop.hailctl': ['hail_version', 'deploy.yaml']}, + 'hailtop.hailctl': ['hail_version', 'deploy.yaml'], + }, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", ], python_requires=">=3.9", install_requires=dependencies, - entry_points={ - 'console_scripts': ['hailctl = hailtop.hailctl.__main__:main'] - }, + entry_points={'console_scripts': ['hailctl = hailtop.hailctl.__main__:main']}, setup_requires=["pytest-runner", "wheel"], tests_require=["pytest"], include_package_data=True, diff --git a/hail/python/test/hail/backend/test_service_backend.py b/hail/python/test/hail/backend/test_service_backend.py index 7050fb724e3..f79793c25fa 100644 --- a/hail/python/test/hail/backend/test_service_backend.py +++ b/hail/python/test/hail/backend/test_service_backend.py @@ -1,10 +1,10 @@ import os import hail as hl - -from ..helpers import skip_unless_service_backend, test_timeout, qobtest from hail.backend.service_backend import ServiceBackend +from ..helpers import qobtest, skip_unless_service_backend, test_timeout + @qobtest @skip_unless_service_backend() diff --git a/hail/python/test/hail/backend/test_spark_backend.py b/hail/python/test/hail/backend/test_spark_backend.py new file mode 100644 index 00000000000..07f6accea82 --- /dev/null +++ b/hail/python/test/hail/backend/test_spark_backend.py @@ -0,0 +1,29 @@ +import os +from test.hail.helpers import skip_unless_spark_backend + +import pytest + +import hail as 
hl + + +def fatal(typ: hl.HailType, msg: str = "") -> hl.Expression: + return hl.construct_expr(hl.ir.Die(hl.to_expr(msg, hl.tstr)._ir, typ), typ) + + +@skip_unless_spark_backend() +@pytest.mark.parametrize('copy', [True, False]) +def test_copy_spark_log(copy): + hl.stop() + hl.init(copy_spark_log_on_error=copy) + + expr = fatal(hl.tint32) + with pytest.raises(Exception): + hl.eval(expr) + + from hail.utils.java import Env + + hc = Env.hc() + _, filename = os.path.split(hc._log) + log = os.path.join(hc._tmpdir, filename) + + assert Env.fs().exists(log) if copy else not Env.fs().exists(log) diff --git a/hail/python/test/hail/conftest.py b/hail/python/test/hail/conftest.py index 933b522181c..9036152bd35 100644 --- a/hail/python/test/hail/conftest.py +++ b/hail/python/test/hail/conftest.py @@ -1,23 +1,31 @@ -from typing import Dict import asyncio import hashlib -import os import logging +import os +from typing import Dict import pytest -from pytest import StashKey, CollectReport +from pytest import CollectReport, StashKey -from hail import current_backend, init, reset_global_randomness +from hail import current_backend, reset_global_randomness from hail.backend.service_backend import ServiceBackend from hailtop.hail_event_loop import hail_event_loop -from hailtop.utils import secret_alnum_string -from .helpers import hl_init_for_test, hl_stop_for_test +from .helpers import hl_init_for_test, hl_stop_for_test log = logging.getLogger(__name__) -def pytest_collection_modifyitems(config, items): +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop() + try: + yield loop + finally: + loop.close() + + +def pytest_collection_modifyitems(items): n_splits = int(os.environ.get('HAIL_RUN_IMAGE_SPLITS', '1')) split_index = int(os.environ.get('HAIL_RUN_IMAGE_SPLIT_INDEX', '-1')) if n_splits <= 1: @@ -34,15 +42,6 @@ def digest(s): item.add_marker(skip_this) -@pytest.fixture(scope="session", autouse=True) -def ensure_event_loop_is_initialized_in_test_thread(): - try: - asyncio.get_running_loop() - except RuntimeError as err: - assert err.args[0] == "no running event loop" - asyncio.set_event_loop(asyncio.new_event_loop()) - - @pytest.fixture(scope="session", autouse=True) def init_hail(): hl_init_for_test() diff --git a/hail/python/test/hail/experimental/test_annotation_db.py b/hail/python/test/hail/experimental/test_annotation_db.py index 182118240eb..5edcfe858b1 100644 --- a/hail/python/test/hail/experimental/test_annotation_db.py +++ b/hail/python/test/hail/experimental/test_annotation_db.py @@ -23,24 +23,32 @@ def db_json(init_hail): 'description': 'now with unique rows!', 'url': 'https://example.com', 'annotation_db': {'key_properties': ['unique']}, - 'versions': [{ - 'url': {"aws": {"eu": fname, "us": fname}, - "gcp": {"eu": fname, "us": fname}}, - 'version': 'v1', - 'reference_genome': 'GRCh37' - }] + 'versions': [ + { + 'url': { + "aws": {"eu": fname, "us": fname}, + "gcp": {"europe-west1": fname, "us-central1": fname}, + }, + 'version': 'v1', + 'reference_genome': 'GRCh37', + } + ], }, 'nonunique_dataset': { 'description': 'non-unique rows :(', 'url': 'https://example.net', 'annotation_db': {'key_properties': []}, - 'versions': [{ - 'url': {"aws": {"eu": fname, "us": fname}, - "gcp": {"eu": fname, "us": fname}}, - 'version': 'v1', - 'reference_genome': 'GRCh37' - }] - } + 'versions': [ + { + 'url': { + "aws": {"eu": fname, "us": fname}, + "gcp": {"europe-west1": fname, "us-central1": fname}, + }, + 'version': 'v1', + 'reference_genome': 'GRCh37', + } + ], + }, } yield 
db_json @@ -48,7 +56,7 @@ def db_json(init_hail): tempdir_manager.__exit__(None, None, None) def test_uniqueness(self, db_json): - db = hl.experimental.DB(region='us', cloud='gcp', config=db_json) + db = hl.experimental.DB(region='us-central1', cloud='gcp', config=db_json) t = hl.utils.genomic_range_table(10) t = db.annotate_rows_db(t, 'unique_dataset', 'nonunique_dataset') assert t.unique_dataset.dtype == hl.dtype('struct{annotation: str}') diff --git a/hail/python/test/hail/experimental/test_experimental.py b/hail/python/test/hail/experimental/test_experimental.py index 0c62a70e890..1964f939ea5 100644 --- a/hail/python/test/hail/experimental/test_experimental.py +++ b/hail/python/test/hail/experimental/test_experimental.py @@ -1,48 +1,58 @@ -import numpy as np -import hail as hl import unittest + +import numpy as np import pytest -from ..helpers import * + +import hail as hl from hail.utils import new_temp_file +from ..helpers import ( + assert_evals_to, + doctest_resource, + fails_local_backend, + fails_service_backend, + qobtest, + resource, + test_timeout, +) + class Tests(unittest.TestCase): @qobtest def get_ld_score_mt(self): - ht = hl.import_table(doctest_resource('ldsc.annot'), - types={'BP': hl.tint, - 'CM': hl.tfloat, - 'binary': hl.tint, - 'continuous': hl.tfloat}) + ht = hl.import_table( + doctest_resource('ldsc.annot'), + types={'BP': hl.tint, 'CM': hl.tfloat, 'binary': hl.tint, 'continuous': hl.tfloat}, + ) ht = ht.annotate(locus=hl.locus(ht.CHR, ht.BP)) ht = ht.key_by('locus') - mt = hl.import_plink(bed=doctest_resource('ldsc.bed'), - bim=doctest_resource('ldsc.bim'), - fam=doctest_resource('ldsc.fam')) - return mt.annotate_rows(binary=ht[mt.locus].binary, - continuous=ht[mt.locus].continuous) + mt = hl.import_plink( + bed=doctest_resource('ldsc.bed'), bim=doctest_resource('ldsc.bim'), fam=doctest_resource('ldsc.fam') + ) + return mt.annotate_rows(binary=ht[mt.locus].binary, continuous=ht[mt.locus].continuous) @fails_service_backend() @fails_local_backend def test_ld_score_univariate(self): mt = self.get_ld_score_mt() ht_univariate = hl.experimental.ld_score( - entry_expr=mt.GT.n_alt_alleles(), - locus_expr=mt.locus, - radius=1.0, - coord_expr=mt.cm_position) - - univariate = ht_univariate.aggregate(hl.struct( - chr20=hl.agg.filter( - (ht_univariate.locus.contig == '20') & - (ht_univariate.locus.position == 82079), - hl.agg.collect(ht_univariate.univariate))[0], - chr22 =hl.agg.filter( - (ht_univariate.locus.contig == '22') & - (ht_univariate.locus.position == 16894090), - hl.agg.collect(ht_univariate.univariate))[0], - mean=hl.agg.mean(ht_univariate.univariate))) + entry_expr=mt.GT.n_alt_alleles(), locus_expr=mt.locus, radius=1.0, coord_expr=mt.cm_position + ) + + univariate = ht_univariate.aggregate( + hl.struct( + chr20=hl.agg.filter( + (ht_univariate.locus.contig == '20') & (ht_univariate.locus.position == 82079), + hl.agg.collect(ht_univariate.univariate), + )[0], + chr22=hl.agg.filter( + (ht_univariate.locus.contig == '22') & (ht_univariate.locus.position == 16894090), + hl.agg.collect(ht_univariate.univariate), + )[0], + mean=hl.agg.mean(ht_univariate.univariate), + ) + ) self.assertAlmostEqual(univariate.chr20, 1.601, places=3) self.assertAlmostEqual(univariate.chr22, 1.140, places=3) @@ -58,29 +68,35 @@ def test_ld_score_annotated(self): locus_expr=mt.locus, radius=1.0, coord_expr=mt.cm_position, - annotation_exprs=[mt.binary, - mt.continuous]) + annotation_exprs=[mt.binary, mt.continuous], + ) annotated = ht_annotated.aggregate( hl.struct( - 
chr20=hl.struct(binary=hl.agg.filter( - (ht_annotated.locus.contig == '20') & - (ht_annotated.locus.position == 82079), - hl.agg.collect(ht_annotated.binary))[0], - continuous=hl.agg.filter( - (ht_annotated.locus.contig == '20') & - (ht_annotated.locus.position == 82079), - hl.agg.collect(ht_annotated.continuous))[0]), + chr20=hl.struct( + binary=hl.agg.filter( + (ht_annotated.locus.contig == '20') & (ht_annotated.locus.position == 82079), + hl.agg.collect(ht_annotated.binary), + )[0], + continuous=hl.agg.filter( + (ht_annotated.locus.contig == '20') & (ht_annotated.locus.position == 82079), + hl.agg.collect(ht_annotated.continuous), + )[0], + ), chr22=hl.struct( binary=hl.agg.filter( - (ht_annotated.locus.contig == '22') & - (ht_annotated.locus.position == 16894090), - hl.agg.collect(ht_annotated.binary))[0], + (ht_annotated.locus.contig == '22') & (ht_annotated.locus.position == 16894090), + hl.agg.collect(ht_annotated.binary), + )[0], continuous=hl.agg.filter( - (ht_annotated.locus.contig == '22') & - (ht_annotated.locus.position == 16894090), - hl.agg.collect(ht_annotated.continuous))[0]), - mean_stats=hl.struct(binary=hl.agg.mean(ht_annotated.binary), - continuous=hl.agg.mean(ht_annotated.continuous)))) + (ht_annotated.locus.contig == '22') & (ht_annotated.locus.position == 16894090), + hl.agg.collect(ht_annotated.continuous), + )[0], + ), + mean_stats=hl.struct( + binary=hl.agg.mean(ht_annotated.binary), continuous=hl.agg.mean(ht_annotated.continuous) + ), + ) + ) self.assertAlmostEqual(annotated.chr20.binary, 1.152, places=3) self.assertAlmostEqual(annotated.chr20.continuous, 73.014, places=3) @@ -102,75 +118,74 @@ def test_import_keyby_count_ldsc_lowered_shuffle(self): # if this comment no longer reflects the backend system, that's a really good thing ht_scores = hl.import_table( doctest_resource('ld_score_regression.univariate_ld_scores.tsv'), - key='SNP', types={'L2': hl.tfloat, 'BP': hl.tint}) + key='SNP', + types={'L2': hl.tfloat, 'BP': hl.tint}, + ) ht_20160 = hl.import_table( - doctest_resource('ld_score_regression.20160.sumstats.tsv'), - key='SNP', types={'N': hl.tint, 'Z': hl.tfloat}) + doctest_resource('ld_score_regression.20160.sumstats.tsv'), key='SNP', types={'N': hl.tint, 'Z': hl.tfloat} + ) j1 = ht_scores[ht_20160['SNP']] ht_20160 = ht_20160.annotate( - ld_score=j1['L2'], - locus=hl.locus(j1['CHR'], - j1['BP']), - alleles=hl.array([ht_20160['A2'], ht_20160['A1']])) + ld_score=j1['L2'], locus=hl.locus(j1['CHR'], j1['BP']), alleles=hl.array([ht_20160['A2'], ht_20160['A1']]) + ) - ht_20160 = ht_20160.key_by(ht_20160['locus'], - ht_20160['alleles']) + ht_20160 = ht_20160.key_by(ht_20160['locus'], ht_20160['alleles']) assert ht_20160._force_count() == 151 def get_ht_50_irnt(self): ht_scores = hl.import_table( doctest_resource('ld_score_regression.univariate_ld_scores.tsv'), - key='SNP', types={'L2': hl.tfloat, 'BP': hl.tint}) + key='SNP', + types={'L2': hl.tfloat, 'BP': hl.tint}, + ) ht_50_irnt = hl.import_table( doctest_resource('ld_score_regression.50_irnt.sumstats.tsv'), - key='SNP', types={'N': hl.tint, 'Z': hl.tfloat}) + key='SNP', + types={'N': hl.tint, 'Z': hl.tfloat}, + ) ht_50_irnt = ht_50_irnt.annotate( - chi_squared=ht_50_irnt['Z']**2, + chi_squared=ht_50_irnt['Z'] ** 2, n=ht_50_irnt['N'], ld_score=ht_scores[ht_50_irnt['SNP']]['L2'], - locus=hl.locus(ht_scores[ht_50_irnt['SNP']]['CHR'], - ht_scores[ht_50_irnt['SNP']]['BP']), + locus=hl.locus(ht_scores[ht_50_irnt['SNP']]['CHR'], ht_scores[ht_50_irnt['SNP']]['BP']), alleles=hl.array([ht_50_irnt['A2'], 
ht_50_irnt['A1']]), - phenotype='50_irnt') + phenotype='50_irnt', + ) - ht_50_irnt = ht_50_irnt.key_by(ht_50_irnt['locus'], - ht_50_irnt['alleles']) + ht_50_irnt = ht_50_irnt.key_by(ht_50_irnt['locus'], ht_50_irnt['alleles']) - ht_50_irnt = ht_50_irnt.select(ht_50_irnt['chi_squared'], - ht_50_irnt['n'], - ht_50_irnt['ld_score'], - ht_50_irnt['phenotype']) + ht_50_irnt = ht_50_irnt.select( + ht_50_irnt['chi_squared'], ht_50_irnt['n'], ht_50_irnt['ld_score'], ht_50_irnt['phenotype'] + ) return ht_50_irnt def get_ht_20160(self): ht_scores = hl.import_table( doctest_resource('ld_score_regression.univariate_ld_scores.tsv'), - key='SNP', types={'L2': hl.tfloat, 'BP': hl.tint}) + key='SNP', + types={'L2': hl.tfloat, 'BP': hl.tint}, + ) ht_20160 = hl.import_table( - doctest_resource('ld_score_regression.20160.sumstats.tsv'), - key='SNP', types={'N': hl.tint, 'Z': hl.tfloat}) + doctest_resource('ld_score_regression.20160.sumstats.tsv'), key='SNP', types={'N': hl.tint, 'Z': hl.tfloat} + ) ht_20160 = ht_20160.annotate( - chi_squared=ht_20160['Z']**2, + chi_squared=ht_20160['Z'] ** 2, n=ht_20160['N'], ld_score=ht_scores[ht_20160['SNP']]['L2'], - locus=hl.locus(ht_scores[ht_20160['SNP']]['CHR'], - ht_scores[ht_20160['SNP']]['BP']), + locus=hl.locus(ht_scores[ht_20160['SNP']]['CHR'], ht_scores[ht_20160['SNP']]['BP']), alleles=hl.array([ht_20160['A2'], ht_20160['A1']]), - phenotype='20160') + phenotype='20160', + ) - ht_20160 = ht_20160.key_by(ht_20160['locus'], - ht_20160['alleles']) + ht_20160 = ht_20160.key_by(ht_20160['locus'], ht_20160['alleles']) - ht_20160 = ht_20160.select(ht_20160['chi_squared'], - ht_20160['n'], - ht_20160['ld_score'], - ht_20160['phenotype']) + ht_20160 = ht_20160.select(ht_20160['chi_squared'], ht_20160['n'], ht_20160['ld_score'], ht_20160['phenotype']) return ht_20160 @pytest.mark.unchecked_allocator @@ -180,10 +195,9 @@ def test_ld_score_regression_1(self): ht_20160 = self.get_ht_20160() ht = ht_50_irnt.union(ht_20160) - mt = ht.to_matrix_table(row_key=['locus', 'alleles'], - col_key=['phenotype'], - row_fields=['ld_score'], - col_fields=[]) + mt = ht.to_matrix_table( + row_key=['locus', 'alleles'], col_key=['phenotype'], row_fields=['ld_score'], col_fields=[] + ) mt_tmp = new_temp_file() mt.write(mt_tmp, overwrite=True) @@ -196,7 +210,8 @@ def test_ld_score_regression_1(self): n_samples_exprs=mt['n'], n_blocks=20, two_step_threshold=5, - n_reference_panel_variants=1173569) + n_reference_panel_variants=1173569, + ) results = { x['phenotype']: { @@ -204,41 +219,22 @@ def test_ld_score_regression_1(self): 'intercept_estimate': x['intercept']['estimate'], 'intercept_standard_error': x['intercept']['standard_error'], 'snp_heritability_estimate': x['snp_heritability']['estimate'], - 'snp_heritability_standard_error': - x['snp_heritability']['standard_error']} - for x in ht_results.collect()} - - self.assertAlmostEqual( - results['50_irnt']['mean_chi_sq'], - 3.4386, places=4) - self.assertAlmostEqual( - results['50_irnt']['intercept_estimate'], - 0.7727, places=4) - self.assertAlmostEqual( - results['50_irnt']['intercept_standard_error'], - 0.2461, places=4) - self.assertAlmostEqual( - results['50_irnt']['snp_heritability_estimate'], - 0.3845, places=4) - self.assertAlmostEqual( - results['50_irnt']['snp_heritability_standard_error'], - 0.1067, places=4) - - self.assertAlmostEqual( - results['20160']['mean_chi_sq'], - 1.5209, places=4) - self.assertAlmostEqual( - results['20160']['intercept_estimate'], - 1.2109, places=4) - self.assertAlmostEqual( - 
results['20160']['intercept_standard_error'], - 0.2238, places=4) - self.assertAlmostEqual( - results['20160']['snp_heritability_estimate'], - 0.0486, places=4) - self.assertAlmostEqual( - results['20160']['snp_heritability_standard_error'], - 0.0416, places=4) + 'snp_heritability_standard_error': x['snp_heritability']['standard_error'], + } + for x in ht_results.collect() + } + + self.assertAlmostEqual(results['50_irnt']['mean_chi_sq'], 3.4386, places=4) + self.assertAlmostEqual(results['50_irnt']['intercept_estimate'], 0.7727, places=4) + self.assertAlmostEqual(results['50_irnt']['intercept_standard_error'], 0.2461, places=4) + self.assertAlmostEqual(results['50_irnt']['snp_heritability_estimate'], 0.3845, places=4) + self.assertAlmostEqual(results['50_irnt']['snp_heritability_standard_error'], 0.1067, places=4) + + self.assertAlmostEqual(results['20160']['mean_chi_sq'], 1.5209, places=4) + self.assertAlmostEqual(results['20160']['intercept_estimate'], 1.2109, places=4) + self.assertAlmostEqual(results['20160']['intercept_standard_error'], 0.2238, places=4) + self.assertAlmostEqual(results['20160']['snp_heritability_estimate'], 0.0486, places=4) + self.assertAlmostEqual(results['20160']['snp_heritability_standard_error'], 0.0416, places=4) @pytest.mark.unchecked_allocator @test_timeout(6 * 60, local=10 * 60, batch=10 * 60) @@ -250,18 +246,18 @@ def test_ld_score_regression_2(self): chi_squared_50_irnt=ht_50_irnt['chi_squared'], n_50_irnt=ht_50_irnt['n'], chi_squared_20160=ht_20160[ht_50_irnt.key]['chi_squared'], - n_20160=ht_20160[ht_50_irnt.key]['n']) + n_20160=ht_20160[ht_50_irnt.key]['n'], + ) ht_results = hl.experimental.ld_score_regression( weight_expr=ht['ld_score'], ld_score_expr=ht['ld_score'], - chi_sq_exprs=[ht['chi_squared_50_irnt'], - ht['chi_squared_20160']], - n_samples_exprs=[ht['n_50_irnt'], - ht['n_20160']], + chi_sq_exprs=[ht['chi_squared_50_irnt'], ht['chi_squared_20160']], + n_samples_exprs=[ht['n_50_irnt'], ht['n_20160']], n_blocks=20, two_step_threshold=5, - n_reference_panel_variants=1173569) + n_reference_panel_variants=1173569, + ) results = { x['phenotype']: { @@ -269,58 +265,40 @@ def test_ld_score_regression_2(self): 'intercept_estimate': x['intercept']['estimate'], 'intercept_standard_error': x['intercept']['standard_error'], 'snp_heritability_estimate': x['snp_heritability']['estimate'], - 'snp_heritability_standard_error': - x['snp_heritability']['standard_error']} - for x in ht_results.collect()} - - self.assertAlmostEqual( - results[0]['mean_chi_sq'], - 3.4386, places=4) - self.assertAlmostEqual( - results[0]['intercept_estimate'], - 0.7727, places=4) - self.assertAlmostEqual( - results[0]['intercept_standard_error'], - 0.2461, places=4) - self.assertAlmostEqual( - results[0]['snp_heritability_estimate'], - 0.3845, places=4) - self.assertAlmostEqual( - results[0]['snp_heritability_standard_error'], - 0.1067, places=4) - - self.assertAlmostEqual( - results[1]['mean_chi_sq'], - 1.5209, places=4) - self.assertAlmostEqual( - results[1]['intercept_estimate'], - 1.2109, places=4) - self.assertAlmostEqual( - results[1]['intercept_standard_error'], - 0.2238, places=4) - self.assertAlmostEqual( - results[1]['snp_heritability_estimate'], - 0.0486, places=4) - self.assertAlmostEqual( - results[1]['snp_heritability_standard_error'], - 0.0416, places=4) + 'snp_heritability_standard_error': x['snp_heritability']['standard_error'], + } + for x in ht_results.collect() + } + + self.assertAlmostEqual(results[0]['mean_chi_sq'], 3.4386, places=4) + 
self.assertAlmostEqual(results[0]['intercept_estimate'], 0.7727, places=4) + self.assertAlmostEqual(results[0]['intercept_standard_error'], 0.2461, places=4) + self.assertAlmostEqual(results[0]['snp_heritability_estimate'], 0.3845, places=4) + self.assertAlmostEqual(results[0]['snp_heritability_standard_error'], 0.1067, places=4) + + self.assertAlmostEqual(results[1]['mean_chi_sq'], 1.5209, places=4) + self.assertAlmostEqual(results[1]['intercept_estimate'], 1.2109, places=4) + self.assertAlmostEqual(results[1]['intercept_standard_error'], 0.2238, places=4) + self.assertAlmostEqual(results[1]['snp_heritability_estimate'], 0.0486, places=4) + self.assertAlmostEqual(results[1]['snp_heritability_standard_error'], 0.0416, places=4) @test_timeout(local=6 * 60) def test_sparse(self): expected_split_mt = hl.import_vcf(resource('sparse_split_test_b.vcf')) unsplit_mt = hl.import_vcf(resource('sparse_split_test.vcf'), call_fields=['LGT', 'LPGT']) - mt = (hl.experimental.sparse_split_multi(unsplit_mt) - .drop('a_index', 'was_split').select_entries(*expected_split_mt.entry.keys())) + mt = ( + hl.experimental.sparse_split_multi(unsplit_mt) + .drop('a_index', 'was_split') + .select_entries(*expected_split_mt.entry.keys()) + ) assert mt._same(expected_split_mt) def test_define_function(self): - f1 = hl.experimental.define_function( - lambda a, b: (a + 7) * b, hl.tint32, hl.tint32) + f1 = hl.experimental.define_function(lambda a, b: (a + 7) * b, hl.tint32, hl.tint32) self.assertEqual(hl.eval(f1(1, 3)), 24) - f2 = hl.experimental.define_function( - lambda a, b: (a + 7) * b, hl.tint32, hl.tint32) - self.assertEqual(hl.eval(f1(1, 3)), 24) # idempotent - self.assertEqual(hl.eval(f2(1, 3)), 24) # idempotent + f2 = hl.experimental.define_function(lambda a, b: (a + 7) * b, hl.tint32, hl.tint32) + self.assertEqual(hl.eval(f1(1, 3)), 24) # idempotent + self.assertEqual(hl.eval(f2(1, 3)), 24) # idempotent @fails_local_backend() @test_timeout(batch=8 * 60) @@ -345,66 +323,58 @@ def test_mt_full_outer_join(self): mt2 = mt2.annotate_entries(e1=hl.rand_unif(0, 1)) mtj = hl.experimental.full_outer_join_mt(mt1, mt2) - assert(mtj.aggregate_entries(hl.agg.all(mtj.left_entry == mt1.index_entries(mtj.row_key, mtj.col_key)))) - assert(mtj.aggregate_entries(hl.agg.all(mtj.right_entry == mt2.index_entries(mtj.row_key, mtj.col_key)))) + assert mtj.aggregate_entries(hl.agg.all(mtj.left_entry == mt1.index_entries(mtj.row_key, mtj.col_key))) + assert mtj.aggregate_entries(hl.agg.all(mtj.right_entry == mt2.index_entries(mtj.row_key, mtj.col_key))) - mt2 = mt2.key_cols_by(new_col_key = 5 - (mt2.col_idx // 2)) # duplicate col keys - mt1 = mt1.key_rows_by(new_row_key = 5 - (mt1.row_idx // 2)) # duplicate row keys + mt2 = mt2.key_cols_by(new_col_key=5 - (mt2.col_idx // 2)) # duplicate col keys + mt1 = mt1.key_rows_by(new_row_key=5 - (mt1.row_idx // 2)) # duplicate row keys mtj = hl.experimental.full_outer_join_mt(mt1, mt2) - assert(mtj.count() == (15, 15)) + assert mtj.count() == (15, 15) def test_mt_full_outer_join_self(self): mt = hl.import_vcf(resource('sample.vcf')) jmt = hl.experimental.full_outer_join_mt(mt, mt) - assert jmt.filter_cols(hl.is_defined(jmt.left_col) & hl.is_defined(jmt.right_col)).count_cols() == mt.count_cols() - assert jmt.filter_rows(hl.is_defined(jmt.left_row) & hl.is_defined(jmt.right_row)).count_rows() == mt.count_rows() - assert jmt.filter_entries(hl.is_defined(jmt.left_entry) & hl.is_defined(jmt.right_entry)).entries().count() == mt.entries().count() + assert ( + jmt.filter_cols(hl.is_defined(jmt.left_col) & 
hl.is_defined(jmt.right_col)).count_cols() == mt.count_cols() + ) + assert ( + jmt.filter_rows(hl.is_defined(jmt.left_row) & hl.is_defined(jmt.right_row)).count_rows() == mt.count_rows() + ) + assert ( + jmt.filter_entries(hl.is_defined(jmt.left_entry) & hl.is_defined(jmt.right_entry)).entries().count() + == mt.entries().count() + ) @fails_service_backend() @fails_local_backend() def test_block_matrices_tofiles(self): - data = [ - np.random.rand(11*12), - np.random.rand(5*17) - ] - arrs = [ - data[0].reshape((11, 12)), - data[1].reshape((5, 17)) - ] + data = [np.random.rand(11 * 12), np.random.rand(5 * 17)] bms = [ hl.linalg.BlockMatrix._create(11, 12, data[0].tolist(), block_size=4), - hl.linalg.BlockMatrix._create(5, 17, data[1].tolist(), block_size=8) + hl.linalg.BlockMatrix._create(5, 17, data[1].tolist(), block_size=8), ] with hl.TemporaryDirectory() as prefix: hl.experimental.block_matrices_tofiles(bms, f'{prefix}/files') for i in range(len(bms)): a = data[i] - a2 = np.frombuffer( - hl.current_backend().fs.open(f'{prefix}/files/{i}', mode='rb').read()) + a2 = np.frombuffer(hl.current_backend().fs.open(f'{prefix}/files/{i}', mode='rb').read()) self.assertTrue(np.array_equal(a, a2)) @fails_service_backend() @fails_local_backend() def test_export_block_matrices(self): - data = [ - np.random.rand(11*12), - np.random.rand(5*17) - ] - arrs = [ - data[0].reshape((11, 12)), - data[1].reshape((5, 17)) - ] + data = [np.random.rand(11 * 12), np.random.rand(5 * 17)] + arrs = [data[0].reshape((11, 12)), data[1].reshape((5, 17))] bms = [ hl.linalg.BlockMatrix._create(11, 12, data[0].tolist(), block_size=4), - hl.linalg.BlockMatrix._create(5, 17, data[1].tolist(), block_size=8) + hl.linalg.BlockMatrix._create(5, 17, data[1].tolist(), block_size=8), ] with hl.TemporaryDirectory() as prefix: hl.experimental.export_block_matrices(bms, f'{prefix}/files') for i in range(len(bms)): a = arrs[i] - a2 = np.loadtxt( - hl.current_backend().fs.open(f'{prefix}/files/{i}.tsv')) + a2 = np.loadtxt(hl.current_backend().fs.open(f'{prefix}/files/{i}.tsv')) self.assertTrue(np.array_equal(a, a2)) with hl.TemporaryDirectory() as prefix2: @@ -412,8 +382,7 @@ def test_export_block_matrices(self): hl.experimental.export_block_matrices(bms, f'{prefix2}/files', custom_filenames=custom_names) for i in range(len(bms)): a = arrs[i] - a2 = np.loadtxt( - hl.current_backend().fs.open(f'{prefix2}/files/{custom_names[i]}')) + a2 = np.loadtxt(hl.current_backend().fs.open(f'{prefix2}/files/{custom_names[i]}')) self.assertTrue(np.array_equal(a, a2)) def test_trivial_loop(self): @@ -421,14 +390,12 @@ def test_trivial_loop(self): def test_loop(self): def triangle_with_ints(n): - return hl.experimental.loop( - lambda f, x, c: hl.if_else(x > 0, f(x - 1, c + x), c), - hl.tint32, n, 0) + return hl.experimental.loop(lambda f, x, c: hl.if_else(x > 0, f(x - 1, c + x), c), hl.tint32, n, 0) def triangle_with_tuple(n): return hl.experimental.loop( - lambda f, xc: hl.if_else(xc[0] > 0, f((xc[0] - 1, xc[1] + xc[0])), xc[1]), - hl.tint32, (n, 0)) + lambda f, xc: hl.if_else(xc[0] > 0, f((xc[0] - 1, xc[1] + xc[0])), xc[1]), hl.tint32, (n, 0) + ) for triangle in [triangle_with_ints, triangle_with_tuple]: assert_evals_to(triangle(20), sum(range(21))) @@ -439,16 +406,11 @@ def fails_typecheck(regex, f): with self.assertRaisesRegex(TypeError, regex): hl.eval(hl.experimental.loop(f, hl.tint32, 1)) - fails_typecheck("outside of tail position", - lambda f, x: x + f(x)) - fails_typecheck("wrong number of arguments", - lambda f, x: f(x, x + 1)) - 
fails_typecheck("bound value", - lambda f, x: hl.bind(lambda x: x, f(x))) - fails_typecheck("branch condition", - lambda f, x: hl.if_else(f(x) == 0, x, 1)) - fails_typecheck("Type error", - lambda f, x: hl.if_else(x == 0, f("foo"), 1)) + fails_typecheck("outside of tail position", lambda f, x: x + f(x)) + fails_typecheck("wrong number of arguments", lambda f, x: f(x, x + 1)) + fails_typecheck("bound value", lambda f, x: hl.bind(lambda x: x, f(x))) + fails_typecheck("branch condition", lambda f, x: hl.if_else(f(x) == 0, x, 1)) + fails_typecheck("Type error", lambda f, x: hl.if_else(x == 0, f("foo"), 1)) def test_nested_loops(self): def triangle_loop(n, add_f): @@ -456,40 +418,63 @@ def triangle_loop(n, add_f): return hl.experimental.loop(recur, hl.tint32, 0, 0) assert_evals_to(triangle_loop(5, lambda x, c: c + x), 15) - assert_evals_to(triangle_loop(5, lambda x, c: c + triangle_loop(x, lambda x2, c2: c2 + x2)), 15 + 10 + 6 + 3 + 1) + assert_evals_to( + triangle_loop(5, lambda x, c: c + triangle_loop(x, lambda x2, c2: c2 + x2)), 15 + 10 + 6 + 3 + 1 + ) n1 = 5 calls_recur_from_nested_loop = hl.experimental.loop( - lambda f, x1, c1: - hl.if_else(x1 <= n1, - hl.experimental.loop( - lambda f2, x2, c2: - hl.if_else(x2 <= x1, - f2(x2 + 1, c2 + x2), - f(x1 + 1, c1 + c2)), - 'int32', 0, 0), - c1), - 'int32', 0, 0) + lambda f, x1, c1: hl.if_else( + x1 <= n1, + hl.experimental.loop( + lambda f2, x2, c2: hl.if_else(x2 <= x1, f2(x2 + 1, c2 + x2), f(x1 + 1, c1 + c2)), 'int32', 0, 0 + ), + c1, + ), + 'int32', + 0, + 0, + ) assert_evals_to(calls_recur_from_nested_loop, 15 + 10 + 6 + 3 + 1) def test_loop_errors(self): - with pytest.raises(TypeError, match="requested type ndarray does not match inferred type ndarray"): - result = hl.experimental.loop( - lambda f, my_nd: - hl.if_else(my_nd[0, 0] == 1000, my_nd, f(my_nd + 1)), - hl.tndarray(hl.tint32, 2), hl.nd.zeros((20, 10), hl.tfloat64)) + with pytest.raises( + TypeError, match="requested type ndarray does not match inferred type ndarray" + ): + hl.experimental.loop( + lambda f, my_nd: hl.if_else(my_nd[0, 0] == 1000, my_nd, f(my_nd + 1)), + hl.tndarray(hl.tint32, 2), + hl.nd.zeros((20, 10), hl.tfloat64), + ) def test_loop_with_struct_of_strings(self): def loop_func(recur_f, my_struct): - return hl.if_else(hl.len(my_struct.s1) > hl.len(my_struct.s2), - my_struct, - recur_f(hl.struct(s1=my_struct.s1 + my_struct.s2[-1], s2=my_struct.s2[:-1]))) + return hl.if_else( + hl.len(my_struct.s1) > hl.len(my_struct.s2), + my_struct, + recur_f(hl.struct(s1=my_struct.s1 + my_struct.s2[-1], s2=my_struct.s2[:-1])), + ) initial_struct = hl.struct(s1="a", s2="gfedcb") - assert hl.eval(hl.experimental.loop(loop_func, hl.tstruct(s1=hl.tstr, s2=hl.tstr), initial_struct)) == hl.Struct(s1="abcd", s2="gfe") + assert hl.eval( + hl.experimental.loop(loop_func, hl.tstruct(s1=hl.tstr, s2=hl.tstr), initial_struct) + ) == hl.Struct(s1="abcd", s2="gfe") def test_loop_memory(self): - def foo(recur, arr, idx): return hl.if_else(idx > 10, arr, recur(arr.append(hl.str(idx)), idx+1)) - - assert hl.eval(hl.experimental.loop(foo, hl.tarray(hl.tstr), hl.literal(['foo']), 1)) == ['foo', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10'] + def foo(recur, arr, idx): + return hl.if_else(idx > 10, arr, recur(arr.append(hl.str(idx)), idx + 1)) + + assert hl.eval(hl.experimental.loop(foo, hl.tarray(hl.tstr), hl.literal(['foo']), 1)) == [ + 'foo', + '1', + '2', + '3', + '4', + '5', + '6', + '7', + '8', + '9', + '10', + ] diff --git a/hail/python/test/hail/experimental/test_local_whitening.py 
b/hail/python/test/hail/experimental/test_local_whitening.py index f2e616ff9b8..77e6ba5c371 100644 --- a/hail/python/test/hail/experimental/test_local_whitening.py +++ b/hail/python/test/hail/experimental/test_local_whitening.py @@ -1,14 +1,17 @@ import numpy as np from numpy.random import default_rng -from hail.methods.pca import _make_tsm + import hail as hl -from ..helpers import * +from hail.methods.pca import _make_tsm + +from ..helpers import test_timeout + def naive_whiten(X, w): m, n = np.shape(X) Xw = np.zeros((m, n)) for j in range(n): - Q, _ = np.linalg.qr(X[:, max(j - w, 0): j]) + Q, _ = np.linalg.qr(X[:, max(j - w, 0) : j]) Xw[:, j] = X[:, j] - Q @ (Q.T @ X[:, j]) return Xw.T @@ -26,15 +29,13 @@ def run_local_whitening_test(vec_size, num_rows, chunk_size, window_size, partit whitened_naive = naive_whiten(data.T, window_size) np.testing.assert_allclose(whitened_hail, whitened_naive, rtol=1e-04) + @test_timeout(local=5 * 60, batch=12 * 60) def test_local_whitening(): run_local_whitening_test( - vec_size=100, - num_rows=10000, - chunk_size=32, - window_size=64, - partition_size=32 * 40, - initial_num_partitions=50) + vec_size=100, num_rows=10000, chunk_size=32, window_size=64, partition_size=32 * 40, initial_num_partitions=50 + ) + @test_timeout(local=5 * 60, batch=12 * 60) def test_local_whitening_singleton_final_partition(): @@ -44,4 +45,5 @@ def test_local_whitening_singleton_final_partition(): chunk_size=32, window_size=64, partition_size=32 * 40, - initial_num_partitions=50) + initial_num_partitions=50, + ) diff --git a/hail/python/test/hail/experimental/test_time.py b/hail/python/test/hail/experimental/test_time.py index fabc249ff92..616546ff060 100644 --- a/hail/python/test/hail/experimental/test_time.py +++ b/hail/python/test/hail/experimental/test_time.py @@ -1,16 +1,24 @@ -import pytest -from ..helpers import * - +import hail as hl import hail.experimental.time as htime + def test_strftime(): - assert hl.eval(htime.strftime("%A, %B %e, %Y. %r", 876541523, "America/New_York")) == "Friday, October 10, 1997. 11:45:23 PM" + assert ( + hl.eval(htime.strftime("%A, %B %e, %Y. %r", 876541523, "America/New_York")) + == "Friday, October 10, 1997. 11:45:23 PM" + ) assert hl.eval(htime.strftime("%A, %B %e, %Y. %r", 876541523, "GMT+2")) == "Saturday, October 11, 1997. 05:45:23 AM" - assert hl.eval(htime.strftime("%A, %B %e, %Y. %r", 876541523, "+08:00")) == "Saturday, October 11, 1997. 11:45:23 AM" + assert ( + hl.eval(htime.strftime("%A, %B %e, %Y. %r", 876541523, "+08:00")) == "Saturday, October 11, 1997. 11:45:23 AM" + ) assert hl.eval(htime.strftime("%A, %B %e, %Y. %r", -876541523, "+08:00")) == "Tuesday, March 24, 1942. 04:14:37 AM" + def test_strptime(): - assert hl.eval(htime.strptime("Friday, October 10, 1997. 11:45:23 PM", "%A, %B %e, %Y. %r", "America/New_York")) == 876541523 + assert ( + hl.eval(htime.strptime("Friday, October 10, 1997. 11:45:23 PM", "%A, %B %e, %Y. %r", "America/New_York")) + == 876541523 + ) assert hl.eval(htime.strptime("Friday, October 10, 1997. 11:45:23 PM", "%A, %B %e, %Y. %r", "GMT+2")) == 876519923 assert hl.eval(htime.strptime("Friday, October 10, 1997. 11:45:23 PM", "%A, %B %e, %Y. %r", "+08:00")) == 876498323 assert hl.eval(htime.strptime("Tuesday, March 24, 1942. 04:14:37 AM", "%A, %B %e, %Y. 
%r", "+08:00")) == -876541523 diff --git a/hail/python/test/hail/expr/test_expr.py b/hail/python/test/hail/expr/test_expr.py index f17549fbe63..4f0d9b9605a 100644 --- a/hail/python/test/hail/expr/test_expr.py +++ b/hail/python/test/hail/expr/test_expr.py @@ -1,15 +1,18 @@ import math -import pytest import random -from scipy.stats import pearsonr +import unittest + import numpy as np +import pytest +from scipy.stats import pearsonr import hail as hl import hail.expr.aggregators as agg -from hail.expr.types import * -from hail.expr.functions import _error_from_cdf, _cdf_combine, _result_from_raw_cdf -import hail.ir as ir -from ..helpers import * +from hail import ir +from hail.expr.functions import _cdf_combine, _error_from_cdf, _result_from_raw_cdf +from hail.expr.types import tarray, tbool, tcall, tfloat, tfloat32, tfloat64, tint, tint32, tint64, tstr, tstruct + +from ..helpers import assert_evals_to, convert_struct_to_dict, qobtest, resource, test_timeout, with_flags def _test_many_equal(test_cases): @@ -28,7 +31,8 @@ def _test_many_equal_typed(test_cases): expecteds = [t[1] for t in test_cases] expected_types = [t[2] for t in test_cases] for expression, actual, expected, actual_type, expected_type in zip( - expressions, actuals, expecteds, actual_type.types, expected_types): + expressions, actuals, expecteds, actual_type.types, expected_types + ): assert expression.dtype == expected_type, (expression.dtype, expected_type) assert actual_type == expected_type, (actual_type, expected_type) assert actual == expected, (actual, expected) @@ -36,7 +40,7 @@ def _test_many_equal_typed(test_cases): class Tests(unittest.TestCase): def collect_unindexed_expression(self): - self.assertEqual(hl.array([4,1,2,3]).collect(), [4,1,2,3]) + self.assertEqual(hl.array([4, 1, 2, 3]).collect(), [4, 1, 2, 3]) def test_key_by_random(self): ht = hl.utils.range_table(10, 4) @@ -45,7 +49,6 @@ def test_key_by_random(self): self.assertEqual(ht._force_count(), 10) def test_seeded_same(self): - def test_random_function(rand_f): ht = hl.utils.range_table(10, 4) sample1 = rand_f() @@ -110,88 +113,135 @@ def test_order_by_head_optimization_with_randomness(self): def test_operators(self): schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr, f=hl.tarray(hl.tint32)) - rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]}, - {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []}, - {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}] + rows = [ + {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]}, + {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []}, + {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}, + ] kt = hl.Table.parallelize(rows, schema) - result = convert_struct_to_dict(kt.annotate( - x1=kt.a + 5, - x2=5 + kt.a, - x3=kt.a + kt.b, - x4=kt.a - 5, - x5=5 - kt.a, - x6=kt.a - kt.b, - x7=kt.a * 5, - x8=5 * kt.a, - x9=kt.a * kt.b, - x10=kt.a / 5, - x11=5 / kt.a, - x12=kt.a / kt.b, - x13=-kt.a, - x14=+kt.a, - x15=kt.a == kt.b, - x16=kt.a == 5, - x17=5 == kt.a, - x18=kt.a != kt.b, - x19=kt.a != 5, - x20=5 != kt.a, - x21=kt.a > kt.b, - x22=kt.a > 5, - x23=5 > kt.a, - x24=kt.a >= kt.b, - x25=kt.a >= 5, - x26=5 >= kt.a, - x27=kt.a < kt.b, - x28=kt.a < 5, - x29=5 < kt.a, - x30=kt.a <= kt.b, - x31=kt.a <= 5, - x32=5 <= kt.a, - x33=(kt.a == 0) & (kt.b == 5), - x34=(kt.a == 0) | (kt.b == 5), - x35=False, - x36=True, - x37=kt.e > "helln", - x38=kt.e < "hellp", - x39=kt.e <= "hello", - x40=kt.e >= "hello", - x41="helln" > kt.e, - x42="hellp" < kt.e, 
- x43="hello" >= kt.e, - x44="hello" <= kt.e, - x45=kt.f > [1, 2], - x46=kt.f < [1, 3], - x47=kt.f >= [1, 2, 3], - x48=kt.f <= [1, 2, 3], - x49=kt.f < [1.0, 2.0], - x50=kt.f > [1.0, 3.0], - x51=[1.0, 2.0, 3.0] <= kt.f, - x52=[1.0, 2.0, 3.0] >= kt.f, - x53=hl.tuple([True, 1.0]) < (1.0, 0.0), - x54=kt.e * kt.a, - ).take(1)[0]) - - expected = {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3], - 'x1': 9, 'x2': 9, 'x3': 5, - 'x4': -1, 'x5': 1, 'x6': 3, - 'x7': 20, 'x8': 20, 'x9': 4, - 'x10': 4.0 / 5, 'x11': 5.0 / 4, 'x12': 4, 'x13': -4, 'x14': 4, - 'x15': False, 'x16': False, 'x17': False, - 'x18': True, 'x19': True, 'x20': True, - 'x21': True, 'x22': False, 'x23': True, - 'x24': True, 'x25': False, 'x26': True, - 'x27': False, 'x28': True, 'x29': False, - 'x30': False, 'x31': True, 'x32': False, - 'x33': False, 'x34': False, 'x35': False, - 'x36': True, 'x37': True, 'x38': True, - 'x39': True, 'x40': True, 'x41': False, - 'x42': False, 'x43': True, 'x44': True, - 'x45': True, 'x46': True, 'x47': True, - 'x48': True, 'x49': False, 'x50': False, - 'x51': True, 'x52': True, 'x53': False, - 'x54': "hellohellohellohello"} + result = convert_struct_to_dict( + kt.annotate( + x1=kt.a + 5, + x2=5 + kt.a, + x3=kt.a + kt.b, + x4=kt.a - 5, + x5=5 - kt.a, + x6=kt.a - kt.b, + x7=kt.a * 5, + x8=5 * kt.a, + x9=kt.a * kt.b, + x10=kt.a / 5, + x11=5 / kt.a, + x12=kt.a / kt.b, + x13=-kt.a, + x14=+kt.a, + x15=kt.a == kt.b, + x16=kt.a == 5, + x17=5 == kt.a, + x18=kt.a != kt.b, + x19=kt.a != 5, + x20=5 != kt.a, + x21=kt.a > kt.b, + x22=kt.a > 5, + x23=5 > kt.a, + x24=kt.a >= kt.b, + x25=kt.a >= 5, + x26=5 >= kt.a, + x27=kt.a < kt.b, + x28=kt.a < 5, + x29=5 < kt.a, + x30=kt.a <= kt.b, + x31=kt.a <= 5, + x32=5 <= kt.a, + x33=(kt.a == 0) & (kt.b == 5), + x34=(kt.a == 0) | (kt.b == 5), + x35=False, + x36=True, + x37=kt.e > "helln", + x38=kt.e < "hellp", + x39=kt.e <= "hello", + x40=kt.e >= "hello", + x41="helln" > kt.e, + x42="hellp" < kt.e, + x43="hello" >= kt.e, + x44="hello" <= kt.e, + x45=kt.f > [1, 2], + x46=kt.f < [1, 3], + x47=kt.f >= [1, 2, 3], + x48=kt.f <= [1, 2, 3], + x49=kt.f < [1.0, 2.0], + x50=kt.f > [1.0, 3.0], + x51=[1.0, 2.0, 3.0] <= kt.f, + x52=[1.0, 2.0, 3.0] >= kt.f, + x53=hl.tuple([True, 1.0]) < (1.0, 0.0), + x54=kt.e * kt.a, + ).take(1)[0] + ) + + expected = { + 'a': 4, + 'b': 1, + 'c': 3, + 'd': 5, + 'e': "hello", + 'f': [1, 2, 3], + 'x1': 9, + 'x2': 9, + 'x3': 5, + 'x4': -1, + 'x5': 1, + 'x6': 3, + 'x7': 20, + 'x8': 20, + 'x9': 4, + 'x10': 4.0 / 5, + 'x11': 5.0 / 4, + 'x12': 4, + 'x13': -4, + 'x14': 4, + 'x15': False, + 'x16': False, + 'x17': False, + 'x18': True, + 'x19': True, + 'x20': True, + 'x21': True, + 'x22': False, + 'x23': True, + 'x24': True, + 'x25': False, + 'x26': True, + 'x27': False, + 'x28': True, + 'x29': False, + 'x30': False, + 'x31': True, + 'x32': False, + 'x33': False, + 'x34': False, + 'x35': False, + 'x36': True, + 'x37': True, + 'x38': True, + 'x39': True, + 'x40': True, + 'x41': False, + 'x42': False, + 'x43': True, + 'x44': True, + 'x45': True, + 'x46': True, + 'x47': True, + 'x48': True, + 'x49': False, + 'x50': False, + 'x51': True, + 'x52': True, + 'x53': False, + 'x54': "hellohellohellohello", + } for k, v in expected.items(): if isinstance(v, float): @@ -206,32 +256,44 @@ def test_array_slicing(self): ha = hl.array(hl.range(100)) pa = list(range(100)) - result = convert_struct_to_dict(kt.annotate( - x1=kt.a[0], - x2=kt.a[2], - x3=kt.a[:], - x4=kt.a[1:2], - x5=kt.a[-1:4], - x6=kt.a[:2], - x7=kt.a[-20:20:-2], - x8=kt.a[20:-20:2], - x9=kt.a[-20:20:2], - 
x10=kt.a[20:-20:-2] - ).take(1)[0]) - - expected = {'a': [1, 2, 3, 4, 5], 'x1': 1, 'x2': 3, 'x3': [1, 2, 3, 4, 5], - 'x4': [2], 'x5': [], 'x6': [1, 2], 'x7': [], 'x8': [], 'x9': [1, 3, 5], - 'x10': [5, 3, 1]} + result = convert_struct_to_dict( + kt.annotate( + x1=kt.a[0], + x2=kt.a[2], + x3=kt.a[:], + x4=kt.a[1:2], + x5=kt.a[-1:4], + x6=kt.a[:2], + x7=kt.a[-20:20:-2], + x8=kt.a[20:-20:2], + x9=kt.a[-20:20:2], + x10=kt.a[20:-20:-2], + ).take(1)[0] + ) + + expected = { + 'a': [1, 2, 3, 4, 5], + 'x1': 1, + 'x2': 3, + 'x3': [1, 2, 3, 4, 5], + 'x4': [2], + 'x5': [], + 'x6': [1, 2], + 'x7': [], + 'x8': [], + 'x9': [1, 3, 5], + 'x10': [5, 3, 1], + } self.assertDictEqual(result, expected) - self.assertEqual(pa[60:1:-3], hl.eval(ha[hl.int32(60):hl.int32(1):hl.int32(-3)])) - self.assertEqual(pa[::5], hl.eval(ha[::hl.int32(5)])) + self.assertEqual(pa[60:1:-3], hl.eval(ha[hl.int32(60) : hl.int32(1) : hl.int32(-3)])) + self.assertEqual(pa[::5], hl.eval(ha[:: hl.int32(5)])) self.assertEqual(pa[::-3], hl.eval(ha[::-3])) - self.assertEqual(pa[:-77:-3], hl.eval(ha[:hl.int32(-77):-3])) + self.assertEqual(pa[:-77:-3], hl.eval(ha[: hl.int32(-77) : -3])) self.assertEqual(pa[44::-7], hl.eval(ha[44::-7])) self.assertEqual(pa[2:59:7], hl.eval(ha[2:59:7])) self.assertEqual(pa[4:40:2], hl.eval(ha[4:40:2])) - self.assertEqual(pa[-400:-300:2], hl.eval(ha[hl.int32(-400):-300:2])) + self.assertEqual(pa[-400:-300:2], hl.eval(ha[hl.int32(-400) : -300 : 2])) self.assertEqual(pa[-300:-400:-2], hl.eval(ha[-300:-400:-2])) self.assertEqual(pa[300:400:2], hl.eval(ha[300:400:2])) self.assertEqual(pa[400:300:-2], hl.eval(ha[400:300:-2])) @@ -246,23 +308,35 @@ def test_dict_methods(self): kt = kt.annotate(a={'cat': 3, 'dog': 7}) - result = convert_struct_to_dict(kt.annotate( - x1=kt.a['cat'], - x2=kt.a['dog'], - x3=kt.a.keys().contains('rabbit'), - x4=kt.a.size() == 0, - x5=kt.a.key_set(), - x6=kt.a.keys(), - x7=kt.a.values(), - x8=kt.a.size(), - x9=kt.a.map_values(lambda v: v * 2.0), - x10=kt.a.items(), - ).take(1)[0]) - - expected = {'a': {'cat': 3, 'dog': 7}, 'x': 2.0, 'x1': 3, 'x2': 7, 'x3': False, - 'x4': False, 'x5': {'cat', 'dog'}, 'x6': ['cat', 'dog'], - 'x7': [3, 7], 'x8': 2, 'x9': {'cat': 6.0, 'dog': 14.0}, - 'x10': [('cat', 3), ('dog', 7)]} + result = convert_struct_to_dict( + kt.annotate( + x1=kt.a['cat'], + x2=kt.a['dog'], + x3=kt.a.keys().contains('rabbit'), + x4=kt.a.size() == 0, + x5=kt.a.key_set(), + x6=kt.a.keys(), + x7=kt.a.values(), + x8=kt.a.size(), + x9=kt.a.map_values(lambda v: v * 2.0), + x10=kt.a.items(), + ).take(1)[0] + ) + + expected = { + 'a': {'cat': 3, 'dog': 7}, + 'x': 2.0, + 'x1': 3, + 'x2': 7, + 'x3': False, + 'x4': False, + 'x5': {'cat', 'dog'}, + 'x6': ['cat', 'dog'], + 'x7': [3, 7], + 'x8': 2, + 'x9': {'cat': 6.0, 'dog': 14.0}, + 'x10': [('cat', 3), ('dog', 7)], + } self.assertDictEqual(result, expected) @@ -277,21 +351,19 @@ def test_numeric_conversion(self): kt = hl.Table.parallelize(rows, schema) kt = kt.annotate(d=hl.int64(kt.d)) - kt = kt.annotate(x1=[1.0, kt.a, 1], - x2=[1, 1.0], - x3=[kt.a, kt.c], - x4=[kt.c, kt.d], - x5=[1, kt.c]) - - expected_schema = {'a': hl.tfloat64, - 'b': hl.tfloat64, - 'c': hl.tint32, - 'd': hl.tint64, - 'x1': hl.tarray(hl.tfloat64), - 'x2': hl.tarray(hl.tfloat64), - 'x3': hl.tarray(hl.tfloat64), - 'x4': hl.tarray(hl.tint64), - 'x5': hl.tarray(hl.tint32)} + kt = kt.annotate(x1=[1.0, kt.a, 1], x2=[1, 1.0], x3=[kt.a, kt.c], x4=[kt.c, kt.d], x5=[1, kt.c]) + + expected_schema = { + 'a': hl.tfloat64, + 'b': hl.tfloat64, + 'c': hl.tint32, + 'd': hl.tint64, + 'x1': 
hl.tarray(hl.tfloat64), + 'x2': hl.tarray(hl.tfloat64), + 'x3': hl.tarray(hl.tfloat64), + 'x4': hl.tarray(hl.tint64), + 'x5': hl.tarray(hl.tint32), + } for f, t in kt.row.dtype.items(): self.assertEqual(expected_schema[f], t) @@ -304,15 +376,23 @@ def test_genetics_constructors(self): kt = hl.Table.parallelize(rows, schema) kt = kt.annotate(d=hl.int64(kt.d)) - kt = kt.annotate(l1=hl.parse_locus("1:51"), - l2=hl.locus("1", 51, reference_genome=rg), - i1=hl.parse_locus_interval("1:51-56", reference_genome=rg), - i2=hl.interval(hl.locus("1", 51, reference_genome=rg), - hl.locus("1", 56, reference_genome=rg))) + kt = kt.annotate( + l1=hl.parse_locus("1:51"), + l2=hl.locus("1", 51, reference_genome=rg), + i1=hl.parse_locus_interval("1:51-56", reference_genome=rg), + i2=hl.interval(hl.locus("1", 51, reference_genome=rg), hl.locus("1", 56, reference_genome=rg)), + ) - expected_schema = {'a': hl.tfloat64, 'b': hl.tfloat64, 'c': hl.tint32, 'd': hl.tint64, - 'l1': hl.tlocus(), 'l2': hl.tlocus(rg), - 'i1': hl.tinterval(hl.tlocus(rg)), 'i2': hl.tinterval(hl.tlocus(rg))} + expected_schema = { + 'a': hl.tfloat64, + 'b': hl.tfloat64, + 'c': hl.tint32, + 'd': hl.tint64, + 'l1': hl.tlocus(), + 'l2': hl.tlocus(rg), + 'i1': hl.tinterval(hl.tlocus(rg)), + 'i2': hl.tinterval(hl.tlocus(rg)), + } self.assertTrue(all([expected_schema[f] == t for f, t in kt.row.dtype.items()])) @@ -335,7 +415,11 @@ def test_rbind_placement(self): def test_translate(self): strs = [None, '', 'TATAN'] - assert hl.eval(hl.literal(strs, 'array').map(lambda x: x.translate({'T': 'A', 'A': 'T'}))) == [None, '', 'ATATN'] + assert hl.eval(hl.literal(strs, 'array').map(lambda x: x.translate({'T': 'A', 'A': 'T'}))) == [ + None, + '', + 'ATATN', + ] with pytest.raises(hl.utils.FatalError, match='mapping keys must be one character'): hl.eval(hl.str('foo').translate({'foo': 'bar'})) @@ -346,8 +430,16 @@ def test_translate(self): def test_reverse_complement(self): strs = ['NNGATTACA', 'NNGATTACA'.lower(), 'foo'] rna_strs = ['NNGATTACA', 'NNGAUUACA'.lower(), 'foo'] - assert hl.eval(hl.literal(strs).map(lambda s: hl.reverse_complement(s))) == ['TGTAATCNN', 'TGTAATCNN'.lower(), 'oof'] - assert hl.eval(hl.literal(rna_strs).map(lambda s: hl.reverse_complement(s, rna=True))) == ['UGUAAUCNN', 'UGUAAUCNN'.lower(), 'oof'] + assert hl.eval(hl.literal(strs).map(lambda s: hl.reverse_complement(s))) == [ + 'TGTAATCNN', + 'TGTAATCNN'.lower(), + 'oof', + ] + assert hl.eval(hl.literal(rna_strs).map(lambda s: hl.reverse_complement(s, rna=True))) == [ + 'UGUAAUCNN', + 'UGUAAUCNN'.lower(), + 'oof', + ] def test_matches(self): self.assertEqual(hl.eval('\\d+'), '\\d+') @@ -397,14 +489,18 @@ def test_if_else(self): @qobtest def test_aggregators(self): table = hl.utils.range_table(10) - r = table.aggregate(hl.struct(x=hl.agg.count(), - y=hl.agg.count_where(table.idx % 2 == 0), - z=hl.agg.filter(table.idx % 2 == 0, hl.agg.count()), - arr_sum=hl.agg.array_sum([1, 2, hl.missing(tint32)]), - bind_agg=hl.agg.count_where(hl.bind(lambda x: x % 2 == 0, table.idx)), - mean=hl.agg.mean(table.idx), - mean2=hl.agg.mean(hl.if_else(table.idx == 9, table.idx, hl.missing(tint32))), - foo=hl.min(3, hl.agg.sum(table.idx)))) + r = table.aggregate( + hl.struct( + x=hl.agg.count(), + y=hl.agg.count_where(table.idx % 2 == 0), + z=hl.agg.filter(table.idx % 2 == 0, hl.agg.count()), + arr_sum=hl.agg.array_sum([1, 2, hl.missing(tint32)]), + bind_agg=hl.agg.count_where(hl.bind(lambda x: x % 2 == 0, table.idx)), + mean=hl.agg.mean(table.idx), + mean2=hl.agg.mean(hl.if_else(table.idx == 9, 
table.idx, hl.missing(tint32))), + foo=hl.min(3, hl.agg.sum(table.idx)), + ) + ) self.assertEqual(r.x, 10) self.assertEqual(r.y, 5) @@ -418,11 +514,15 @@ def test_aggregators(self): a = hl.literal([1, 2], tarray(tint32)) self.assertEqual(table.aggregate(hl.agg.filter(True, hl.agg.array_sum(a))), [10, 20]) - r = table.aggregate(hl.struct(fraction_odd=hl.agg.fraction(table.idx % 2 == 0), - lessthan6=hl.agg.fraction(table.idx < 6), - gt6=hl.agg.fraction(table.idx > 6), - assert1=hl.agg.fraction(table.idx > 6) < 0.50, - assert2=hl.agg.fraction(table.idx < 6) >= 0.50)) + r = table.aggregate( + hl.struct( + fraction_odd=hl.agg.fraction(table.idx % 2 == 0), + lessthan6=hl.agg.fraction(table.idx < 6), + gt6=hl.agg.fraction(table.idx > 6), + assert1=hl.agg.fraction(table.idx > 6) < 0.50, + assert2=hl.agg.fraction(table.idx < 6) >= 0.50, + ) + ) self.assertEqual(r.fraction_odd, 0.50) self.assertEqual(r.lessthan6, 0.60) self.assertEqual(r.gt6, 0.30) @@ -431,7 +531,7 @@ def test_aggregators(self): def test_agg_nesting(self): t = hl.utils.range_table(10) - aggregated_count = t.aggregate(hl.agg.count(), _localize=False) #10 + aggregated_count = t.aggregate(hl.agg.count(), _localize=False) # 10 filter_count = t.aggregate(hl.agg.filter(aggregated_count == 10, hl.agg.count())) self.assertEqual(filter_count, 10) @@ -473,21 +573,31 @@ def test_aggfold_agg(self): self.assertEqual(ht.aggregate(hl.agg.fold(0, lambda x: x + ht.idx, lambda a, b: a + b)), 4950) ht = ht.annotate(s=hl.struct(x=ht.idx, y=ht.idx + 1)) - sum_and_product = (ht.aggregate( + sum_and_product = ht.aggregate( hl.agg.fold( hl.struct(x=0, y=1.0), lambda accum: hl.struct(x=accum.x + ht.s.x, y=accum.y * ht.s.y), - lambda a, b: hl.struct(x=a.x + b.x, y=a.y * b.y)))) - self.assertEqual(sum_and_product, hl.Struct(x=4950, y=9.332621544394414e+157)) + lambda a, b: hl.struct(x=a.x + b.x, y=a.y * b.y), + ) + ) + self.assertEqual(sum_and_product, hl.Struct(x=4950, y=9.332621544394414e157)) ht = ht.annotate(maybe=hl.if_else(ht.idx % 2 == 0, ht.idx, hl.missing(hl.tint32))) sum_evens_missing = ht.aggregate(hl.agg.fold(0, lambda x: x + ht.maybe, lambda a, b: a + b)) assert sum_evens_missing is None - sum_evens_only = ht.aggregate(hl.agg.fold(0, lambda x: x + hl.coalesce(ht.maybe, 0), lambda a, b: hl.coalesce(a + b, a, b))) + sum_evens_only = ht.aggregate( + hl.agg.fold(0, lambda x: x + hl.coalesce(ht.maybe, 0), lambda a, b: hl.coalesce(a + b, a, b)) + ) self.assertEqual(sum_evens_only, 2450) - #Testing types work out - sum_float64 = ht.aggregate(hl.agg.fold(hl.int32(0), lambda acc: acc + hl.float32(ht.idx), lambda acc1, acc2: hl.float64(acc1) + hl.float64(acc2))) + # Testing types work out + sum_float64 = ht.aggregate( + hl.agg.fold( + hl.int32(0), + lambda acc: acc + hl.float32(ht.idx), + lambda acc1, acc2: hl.float64(acc1) + hl.float64(acc2), + ) + ) self.assertEqual(sum_float64, 4950.0) ht = ht.annotate_globals(foo=7) @@ -501,42 +611,56 @@ def test_aggfold_agg(self): def test_aggfold_scan(self): ht = hl.utils.range_table(15, 5) - ht = ht.annotate(s=hl.scan.fold(0, lambda a: a + ht.idx, lambda a, b: a + b), ) + ht = ht.annotate( + s=hl.scan.fold(0, lambda a: a + ht.idx, lambda a, b: a + b), + ) self.assertEqual(ht.s.collect(), [0, 0, 1, 3, 6, 10, 15, 21, 28, 36, 45, 55, 66, 78, 91]) mt = hl.utils.range_matrix_table(15, 10, 5) mt = mt.annotate_rows(s=hl.scan.fold(0, lambda a: a + mt.row_idx, lambda a, b: a + b)) mt = mt.annotate_rows(x=hl.scan.fold(0, lambda s: s + 1, lambda a, b: a + b)) self.assertEqual(mt.s.collect(), [0, 0, 1, 3, 6, 10, 15, 21, 28, 36, 
45, 55, 66, 78, 91]) - self.assertEqual(mt.rows().collect(), - [hl.Struct(row_idx=0, s=0, x=0), hl.Struct(row_idx=1, s=0, x=1), hl.Struct(row_idx=2, s=1, x=2), - hl.Struct(row_idx=3, s=3, x=3), hl.Struct(row_idx=4, s=6, x=4), hl.Struct(row_idx=5, s=10, x=5), - hl.Struct(row_idx=6, s=15, x=6), hl.Struct(row_idx=7, s=21, x=7), hl.Struct(row_idx=8, s=28, x=8), - hl.Struct(row_idx=9, s=36, x=9), hl.Struct(row_idx=10, s=45, x=10), hl.Struct(row_idx=11, s=55, x=11), - hl.Struct(row_idx=12, s=66, x=12), hl.Struct(row_idx=13, s=78, x=13), hl.Struct(row_idx=14, s=91, x=14)]) + self.assertEqual( + mt.rows().collect(), + [ + hl.Struct(row_idx=0, s=0, x=0), + hl.Struct(row_idx=1, s=0, x=1), + hl.Struct(row_idx=2, s=1, x=2), + hl.Struct(row_idx=3, s=3, x=3), + hl.Struct(row_idx=4, s=6, x=4), + hl.Struct(row_idx=5, s=10, x=5), + hl.Struct(row_idx=6, s=15, x=6), + hl.Struct(row_idx=7, s=21, x=7), + hl.Struct(row_idx=8, s=28, x=8), + hl.Struct(row_idx=9, s=36, x=9), + hl.Struct(row_idx=10, s=45, x=10), + hl.Struct(row_idx=11, s=55, x=11), + hl.Struct(row_idx=12, s=66, x=12), + hl.Struct(row_idx=13, s=78, x=13), + hl.Struct(row_idx=14, s=91, x=14), + ], + ) def test_agg_filter(self): t = hl.utils.range_table(10) - tests = [(hl.agg.filter(t.idx > 7, - hl.agg.collect(t.idx + 1).append(0)), - [9, 10, 0]), - (hl.agg.filter(t.idx > 7, - hl.agg.explode(lambda elt: hl.agg.collect(elt + 1).append(0), - [t.idx, t.idx + 1])), - [9, 10, 10, 11, 0]), - (hl.agg.filter(t.idx > 7, - hl.agg.group_by(t.idx % 3, - hl.array(hl.agg.collect_as_set(t.idx + 1)).append(0))), - {0: [10, 0], 2: [9, 0]}), - (hl.agg.filter(t.idx > 7, hl.agg.count()), 2), - (hl.agg.filter(t.idx > 7, - hl.agg.explode(lambda elt: hl.agg.count(), - [t.idx, t.idx + 1])), 4), - (hl.agg.filter(t.idx > 7, - hl.agg.group_by(t.idx % 3, - hl.agg.count())), - {0: 1, 2: 1}), - ] + tests = [ + (hl.agg.filter(t.idx > 7, hl.agg.collect(t.idx + 1).append(0)), [9, 10, 0]), + ( + hl.agg.filter( + t.idx > 7, hl.agg.explode(lambda elt: hl.agg.collect(elt + 1).append(0), [t.idx, t.idx + 1]) + ), + [9, 10, 10, 11, 0], + ), + ( + hl.agg.filter( + t.idx > 7, hl.agg.group_by(t.idx % 3, hl.array(hl.agg.collect_as_set(t.idx + 1)).append(0)) + ), + {0: [10, 0], 2: [9, 0]}, + ), + (hl.agg.filter(t.idx > 7, hl.agg.count()), 2), + (hl.agg.filter(t.idx > 7, hl.agg.explode(lambda elt: hl.agg.count(), [t.idx, t.idx + 1])), 4), + (hl.agg.filter(t.idx > 7, hl.agg.group_by(t.idx % 3, hl.agg.count())), {0: 1, 2: 1}), + ] for aggregation, expected in tests: self.assertEqual(t.aggregate(aggregation), expected) @@ -549,34 +673,39 @@ def test_agg_densify(self): ht = ht.drop('entries', 'cols') assert ht.collect() == [ hl.utils.Struct(row_idx=0, dense=[None, None, None, None, None]), - hl.utils.Struct(row_idx=1, dense=[hl.utils.Struct(x=(0, 0), y='0,0'), - None, - None, - None, - None]), - hl.utils.Struct(row_idx=2, dense=[hl.utils.Struct(x=(0, 0), y='0,0'), - hl.utils.Struct(x=(1, 1), y='1,1'), - None, - None, - None]), - hl.utils.Struct(row_idx=3, dense=[hl.utils.Struct(x=(0, 0), y='0,0'), - hl.utils.Struct(x=(1, 1), y='1,1'), - hl.utils.Struct(x=(2, 2), y='2,2'), - None, - None]), - hl.utils.Struct(row_idx=4, dense=[hl.utils.Struct(x=(0, 0), y='0,0'), - hl.utils.Struct(x=(1, 1), y='1,1'), - hl.utils.Struct(x=(2, 2), y='2,2'), - hl.utils.Struct(x=(3, 3), y='3,3'), - None]), + hl.utils.Struct(row_idx=1, dense=[hl.utils.Struct(x=(0, 0), y='0,0'), None, None, None, None]), + hl.utils.Struct( + row_idx=2, + dense=[hl.utils.Struct(x=(0, 0), y='0,0'), hl.utils.Struct(x=(1, 1), y='1,1'), None, None, 
None], + ), + hl.utils.Struct( + row_idx=3, + dense=[ + hl.utils.Struct(x=(0, 0), y='0,0'), + hl.utils.Struct(x=(1, 1), y='1,1'), + hl.utils.Struct(x=(2, 2), y='2,2'), + None, + None, + ], + ), + hl.utils.Struct( + row_idx=4, + dense=[ + hl.utils.Struct(x=(0, 0), y='0,0'), + hl.utils.Struct(x=(1, 1), y='1,1'), + hl.utils.Struct(x=(2, 2), y='2,2'), + hl.utils.Struct(x=(3, 3), y='3,3'), + None, + ], + ), ] @qobtest @with_flags(distributed_scan_comb_op='1') def test_densify_table(self): ht = hl.utils.range_table(100, n_partitions=33) - ht = ht.annotate(arr = hl.range(100).map(lambda idx: hl.or_missing(idx == ht.idx, idx))) - ht = ht.annotate(dense = hl.scan._densify(100, ht.arr)) + ht = ht.annotate(arr=hl.range(100).map(lambda idx: hl.or_missing(idx == ht.idx, idx))) + ht = ht.annotate(dense=hl.scan._densify(100, ht.arr)) assert ht.all(ht.dense == hl.range(100).map(lambda idx: hl.or_missing(idx < ht.idx, idx))) def test_agg_array_inside_annotate_rows(self): @@ -588,14 +717,14 @@ def test_agg_array_inside_annotate_rows(self): def test_agg_array_empty(self): ht = hl.utils.range_table(1).annotate(a=[0]).filter(False) - assert ht.aggregate(hl.agg.array_agg(lambda x: hl.agg.sum(x), ht.a)) == None + assert ht.aggregate(hl.agg.array_agg(lambda x: hl.agg.sum(x), ht.a)) is None def test_agg_array_non_trivial_post_op(self): ht = hl.utils.range_table(10) ht = ht.annotate(a=[ht.idx, 2 * ht.idx]) - assert ht.aggregate(hl.agg.array_agg( - lambda x: hl.agg.sum(x) + hl.agg.filter(x % 3 == 0, hl.agg.sum(x)), - ht.a)) == [63, 126] + assert ht.aggregate( + hl.agg.array_agg(lambda x: hl.agg.sum(x) + hl.agg.filter(x % 3 == 0, hl.agg.sum(x)), ht.a) + ) == [63, 126] def test_agg_array_agg_empty_partitions(self): ht = hl.utils.range_table(11, 11) @@ -648,31 +777,28 @@ def test_agg_array_filter(self): def test_agg_array_group_by(self): ht = hl.utils.range_table(10) ht = ht.annotate(a=[ht.idx, ht.idx + 1]) - r = ht.aggregate( - hl.agg.group_by(ht.idx % 2, hl.agg.array_agg(lambda x: hl.agg.sum(x), ht.a))) + r = ht.aggregate(hl.agg.group_by(ht.idx % 2, hl.agg.array_agg(lambda x: hl.agg.sum(x), ht.a))) assert r == {0: [20, 25], 1: [25, 30]} - r2 = ht.aggregate( - hl.agg.array_agg(lambda x: hl.agg.group_by(x % 2, hl.agg.sum(x)), ht.a)) + r2 = ht.aggregate(hl.agg.array_agg(lambda x: hl.agg.group_by(x % 2, hl.agg.sum(x)), ht.a)) assert r2 == [{0: 20, 1: 25}, {0: 30, 1: 25}] - r3 = ht.aggregate( - hl.agg.group_by(ht.idx % 2, hl.agg.array_agg(lambda x: hl.agg.count(), ht.a))) + r3 = ht.aggregate(hl.agg.group_by(ht.idx % 2, hl.agg.array_agg(lambda x: hl.agg.count(), ht.a))) assert r3 == {0: [5, 5], 1: [5, 5]} - r4 = ht.aggregate( - hl.agg.array_agg(lambda x: hl.agg.group_by(x % 2, hl.agg.count()), ht.a)) + r4 = ht.aggregate(hl.agg.array_agg(lambda x: hl.agg.group_by(x % 2, hl.agg.count()), ht.a)) assert r4 == [{0: 5, 1: 5}, {0: 5, 1: 5}] def test_agg_array_nested(self): ht = hl.utils.range_table(10) ht = ht.annotate(a=[[[ht.idx]]]) - assert ht.aggregate(hl.agg.array_agg( - lambda x1: hl.agg.array_agg( - lambda x2: hl.agg.array_agg( - lambda x3: hl.agg.sum(x3), x2), x1), ht.a)) == [[[45]]] + assert ht.aggregate( + hl.agg.array_agg( + lambda x1: hl.agg.array_agg(lambda x2: hl.agg.array_agg(lambda x3: hl.agg.sum(x3), x2), x1), ht.a + ) + ) == [[[45]]] def test_agg_array_take(self): ht = hl.utils.range_table(10) @@ -680,14 +806,16 @@ def test_agg_array_take(self): assert r == [[0, 1], [0, 2]] def test_agg_array_init_op(self): - ht = hl.utils.range_table(1).annotate_globals(n_alleles = ['A', 'T']).annotate(gts = [hl.call(0, 1), 
hl.call(1, 1)]) + ht = hl.utils.range_table(1).annotate_globals(n_alleles=['A', 'T']).annotate(gts=[hl.call(0, 1), hl.call(1, 1)]) r = ht.aggregate(hl.agg.array_agg(lambda a: hl.agg.call_stats(a, ht.n_alleles), ht.gts)) - assert r == [hl.utils.Struct(AC=[1, 1], AF=[0.5, 0.5], AN=2, homozygote_count=[0, 0]), - hl.utils.Struct(AC=[0, 2], AF=[0.0, 1.0], AN=2, homozygote_count=[0, 1])] + assert r == [ + hl.utils.Struct(AC=[1, 1], AF=[0.5, 0.5], AN=2, homozygote_count=[0, 0]), + hl.utils.Struct(AC=[0, 2], AF=[0.0, 1.0], AN=2, homozygote_count=[0, 1]), + ] def test_agg_collect_all_types_runs(self): ht = hl.utils.range_table(2) - ht = ht.annotate(x = hl.case().when(ht.idx % 1 == 0, True).or_missing()) + ht = ht.annotate(x=hl.case().when(ht.idx % 1 == 0, True).or_missing()) ht.aggregate(( hl.agg.collect(ht.x), hl.agg.collect(hl.int32(ht.x)), @@ -696,7 +824,7 @@ def test_agg_collect_all_types_runs(self): hl.agg.collect(hl.float64(ht.x)), hl.agg.collect(hl.str(ht.x)), hl.agg.collect(hl.call(0, 0, phased=ht.x)), - hl.agg.collect(hl.struct(foo = ht.x)), + hl.agg.collect(hl.struct(foo=ht.x)), hl.agg.collect(hl.tuple([ht.x])), hl.agg.collect([ht.x]), hl.agg.collect({ht.x}), @@ -704,99 +832,140 @@ def test_agg_collect_all_types_runs(self): hl.agg.collect(hl.interval(0, 1, includes_start=ht.x)), )) - def test_agg_explode(self): t = hl.utils.range_table(10) - tests = [(hl.agg.explode(lambda elt: hl.agg.collect(elt + 1).append(0), - hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32))), - [9, 10, 10, 11, 0]), - (hl.agg.explode(lambda elt: hl.agg.explode(lambda elt2: hl.agg.collect(elt2 + 1).append(0), - [elt, elt + 1]), - hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32))), - [9, 10, 10, 11, 10, 11, 11, 12, 0]), - (hl.agg.explode(lambda elt: hl.agg.filter(elt > 8, - hl.agg.collect(elt + 1).append(0)), - hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32))), - [10, 10, 11, 0]), - (hl.agg.explode(lambda elt: hl.agg.group_by(elt % 3, - hl.agg.collect(elt + 1).append(0)), - hl.if_else(t.idx > 7, - [t.idx, t.idx + 1], - hl.empty_array(hl.tint32))), - {0: [10, 10, 0], 1: [11, 0], 2:[9, 0]}), - (hl.agg.explode(lambda elt: hl.agg.count(), - hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32))), - 4), - (hl.agg.explode(lambda elt: hl.agg.explode(lambda elt2: hl.agg.count(), - [elt, elt + 1]), - hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32))), - 8), - (hl.agg.explode(lambda elt: hl.agg.filter(elt > 8, - hl.agg.count()), - hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32))), - 3), - (hl.agg.explode(lambda elt: hl.agg.group_by(elt % 3, - hl.agg.count()), - hl.if_else(t.idx > 7, - [t.idx, t.idx + 1], - hl.empty_array(hl.tint32))), - {0: 2, 1: 1, 2: 1}) - ] + tests = [ + ( + hl.agg.explode( + lambda elt: hl.agg.collect(elt + 1).append(0), + hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32)), + ), + [9, 10, 10, 11, 0], + ), + ( + hl.agg.explode( + lambda elt: hl.agg.explode(lambda elt2: hl.agg.collect(elt2 + 1).append(0), [elt, elt + 1]), + hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32)), + ), + [9, 10, 10, 11, 10, 11, 11, 12, 0], + ), + ( + hl.agg.explode( + lambda elt: hl.agg.filter(elt > 8, hl.agg.collect(elt + 1).append(0)), + hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32)), + ), + [10, 10, 11, 0], + ), + ( + hl.agg.explode( + lambda elt: hl.agg.group_by(elt % 3, hl.agg.collect(elt + 1).append(0)), + hl.if_else(t.idx > 7, [t.idx, t.idx + 1], 
hl.empty_array(hl.tint32)), + ), + {0: [10, 10, 0], 1: [11, 0], 2: [9, 0]}, + ), + ( + hl.agg.explode( + lambda elt: hl.agg.count(), hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32)) + ), + 4, + ), + ( + hl.agg.explode( + lambda elt: hl.agg.explode(lambda elt2: hl.agg.count(), [elt, elt + 1]), + hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32)), + ), + 8, + ), + ( + hl.agg.explode( + lambda elt: hl.agg.filter(elt > 8, hl.agg.count()), + hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32)), + ), + 3, + ), + ( + hl.agg.explode( + lambda elt: hl.agg.group_by(elt % 3, hl.agg.count()), + hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32)), + ), + {0: 2, 1: 1, 2: 1}, + ), + ] for aggregation, expected in tests: self.assertEqual(t.aggregate(aggregation), expected) def test_agg_group_by_1(self): t = hl.utils.range_table(10) - tests = [(hl.agg.group_by(t.idx % 2, - hl.array(hl.agg.collect_as_set(t.idx + 1)).append(0)), - {0: [1, 3, 5, 7, 9, 0], 1: [2, 4, 6, 8, 10, 0]}), - (hl.agg.group_by(t.idx % 3, - hl.agg.filter(t.idx > 7, - hl.array(hl.agg.collect_as_set(t.idx + 1)).append(0))), - {0: [10, 0], 1: [0], 2: [9, 0]}), - (hl.agg.group_by(t.idx % 3, - hl.agg.explode(lambda elt: hl.agg.collect(elt + 1).append(0), - hl.if_else(t.idx > 7, - [t.idx, t.idx + 1], - hl.empty_array(hl.tint32)))), - {0: [10, 11, 0], 1: [0], 2:[9, 10, 0]}), - (hl.agg.group_by(t.idx % 2, hl.agg.count()), {0: 5, 1: 5}), - (hl.agg.group_by(t.idx % 3, - hl.agg.filter(t.idx > 7, hl.agg.count())), - {0: 1, 1: 0, 2: 1}), - (hl.agg.group_by(t.idx % 3, - hl.agg.explode(lambda elt: hl.agg.count(), - hl.if_else(t.idx > 7, - [t.idx, t.idx + 1], - hl.empty_array(hl.tint32)))), - {0: 2, 1: 0, 2: 2}), - (hl.agg.group_by(t.idx % 5, - hl.agg.group_by(t.idx % 2, hl.agg.count())), - {i: {0: 1, 1: 1} for i in range(5)}), - ] + tests = [ + ( + hl.agg.group_by(t.idx % 2, hl.array(hl.agg.collect_as_set(t.idx + 1)).append(0)), + {0: [1, 3, 5, 7, 9, 0], 1: [2, 4, 6, 8, 10, 0]}, + ), + ( + hl.agg.group_by( + t.idx % 3, hl.agg.filter(t.idx > 7, hl.array(hl.agg.collect_as_set(t.idx + 1)).append(0)) + ), + {0: [10, 0], 1: [0], 2: [9, 0]}, + ), + ( + hl.agg.group_by( + t.idx % 3, + hl.agg.explode( + lambda elt: hl.agg.collect(elt + 1).append(0), + hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32)), + ), + ), + {0: [10, 11, 0], 1: [0], 2: [9, 10, 0]}, + ), + (hl.agg.group_by(t.idx % 2, hl.agg.count()), {0: 5, 1: 5}), + (hl.agg.group_by(t.idx % 3, hl.agg.filter(t.idx > 7, hl.agg.count())), {0: 1, 1: 0, 2: 1}), + ( + hl.agg.group_by( + t.idx % 3, + hl.agg.explode( + lambda elt: hl.agg.count(), hl.if_else(t.idx > 7, [t.idx, t.idx + 1], hl.empty_array(hl.tint32)) + ), + ), + {0: 2, 1: 0, 2: 2}, + ), + ( + hl.agg.group_by(t.idx % 5, hl.agg.group_by(t.idx % 2, hl.agg.count())), + {i: {0: 1, 1: 1} for i in range(5)}, + ), + ] results = t.aggregate(hl.tuple([x[0] for x in tests])) for aggregate, (_, expected) in zip(results, tests): assert aggregate == expected def test_agg_group_by_2(self): - t = hl.Table.parallelize([ - {"cohort": None, "pop": "EUR", "GT": hl.Call([0, 0])}, - {"cohort": None, "pop": "ASN", "GT": hl.Call([0, 1])}, - {"cohort": None, "pop": None, "GT": hl.Call([0, 0])}, - {"cohort": "SIGMA", "pop": "AFR", "GT": hl.Call([0, 1])}, - {"cohort": "SIGMA", "pop": "EUR", "GT": hl.Call([1, 1])}, - {"cohort": "IBD", "pop": "EUR", "GT": None}, - {"cohort": "IBD", "pop": "EUR", "GT": hl.Call([0, 0])}, - {"cohort": "IBD", "pop": None, "GT": hl.Call([0, 1])} - ], 
hl.tstruct(cohort=hl.tstr, pop=hl.tstr, GT=hl.tcall), n_partitions=3) - - r = t.aggregate(hl.struct(count=hl.agg.group_by(t.cohort, hl.agg.group_by(t.pop, hl.agg.count_where(hl.is_defined(t.GT)))), - inbreeding=hl.agg.group_by(t.cohort, hl.agg.inbreeding(t.GT, 0.1)))) - - expected_count = {None: {'EUR': 1, 'ASN': 1, None: 1}, - 'SIGMA': {'AFR': 1, 'EUR': 1}, - 'IBD': {'EUR': 1, None: 1}} + t = hl.Table.parallelize( + [ + {"cohort": None, "pop": "EUR", "GT": hl.Call([0, 0])}, + {"cohort": None, "pop": "ASN", "GT": hl.Call([0, 1])}, + {"cohort": None, "pop": None, "GT": hl.Call([0, 0])}, + {"cohort": "SIGMA", "pop": "AFR", "GT": hl.Call([0, 1])}, + {"cohort": "SIGMA", "pop": "EUR", "GT": hl.Call([1, 1])}, + {"cohort": "IBD", "pop": "EUR", "GT": None}, + {"cohort": "IBD", "pop": "EUR", "GT": hl.Call([0, 0])}, + {"cohort": "IBD", "pop": None, "GT": hl.Call([0, 1])}, + ], + hl.tstruct(cohort=hl.tstr, pop=hl.tstr, GT=hl.tcall), + n_partitions=3, + ) + + r = t.aggregate( + hl.struct( + count=hl.agg.group_by(t.cohort, hl.agg.group_by(t.pop, hl.agg.count_where(hl.is_defined(t.GT)))), + inbreeding=hl.agg.group_by(t.cohort, hl.agg.inbreeding(t.GT, 0.1)), + ) + ) + + expected_count = { + None: {'EUR': 1, 'ASN': 1, None: 1}, + 'SIGMA': {'AFR': 1, 'EUR': 1}, + 'IBD': {'EUR': 1, None: 1}, + } self.assertEqual(r.count, expected_count) @@ -817,13 +986,15 @@ def test_agg_group_by_2(self): def test_agg_group_by_on_call(self): t = hl.utils.range_table(10) - t = t.annotate(call = hl.call(0, 0), x = 1) + t = t.annotate(call=hl.call(0, 0), x=1) res = t.aggregate(hl.agg.group_by(t.call, hl.agg.sum(t.x))) self.assertEqual(res, {hl.Call([0, 0]): 10}) def test_aggregators_with_randomness(self): t = hl.utils.range_table(10) - res = t.aggregate(hl.agg.filter(hl.rand_bool(0.5), hl.struct(collection=hl.agg.collect(t.idx), sum=hl.agg.sum(t.idx)))) + res = t.aggregate( + hl.agg.filter(hl.rand_bool(0.5), hl.struct(collection=hl.agg.collect(t.idx), sum=hl.agg.sum(t.idx))) + ) self.assertEqual(sum(res.collection), res.sum) def test_aggregator_scope(self): @@ -839,30 +1010,35 @@ def test_aggregator_scope(self): with self.assertRaises(hl.expr.ExpressionException): hl.agg.counter(hl.agg.explode(lambda elt: elt, [t.idx, t.idx + 1])) - tests = [(hl.agg.filter(t.idx > 7, - hl.agg.explode(lambda x: hl.agg.collect(hl.int64(x + 1)), - [t.idx, t.idx + 1]).append( - hl.agg.group_by(t.idx % 3, hl.agg.sum(t.idx))[0]) - ), - [9, 10, 10, 11, 9]), - (hl.agg.explode(lambda x: - hl.agg.filter(x > 7, - hl.agg.collect(x) - ).extend(hl.agg.group_by(t.idx % 3, - hl.array(hl.agg.collect_as_set(x)))[0]), - [t.idx, t.idx + 1]), - [8, 8, 9, 9, 10, 0, 1, 3, 4, 6, 7, 9, 10]), - (hl.agg.group_by(t.idx % 3, - hl.agg.filter(t.idx > 7, - hl.agg.collect(t.idx) - ).extend(hl.agg.explode( - lambda x: hl.array(hl.agg.collect_as_set(x)), - [t.idx, t.idx + 34])) - ), - {0: [9, 0, 3, 6, 9, 34, 37, 40, 43], - 1: [1, 4, 7, 35, 38, 41], - 2: [8, 2, 5, 8, 36, 39, 42]}) - ] + tests = [ + ( + hl.agg.filter( + t.idx > 7, + hl.agg.explode(lambda x: hl.agg.collect(hl.int64(x + 1)), [t.idx, t.idx + 1]).append( + hl.agg.group_by(t.idx % 3, hl.agg.sum(t.idx))[0] + ), + ), + [9, 10, 10, 11, 9], + ), + ( + hl.agg.explode( + lambda x: hl.agg.filter(x > 7, hl.agg.collect(x)).extend( + hl.agg.group_by(t.idx % 3, hl.array(hl.agg.collect_as_set(x)))[0] + ), + [t.idx, t.idx + 1], + ), + [8, 8, 9, 9, 10, 0, 1, 3, 4, 6, 7, 9, 10], + ), + ( + hl.agg.group_by( + t.idx % 3, + hl.agg.filter(t.idx > 7, hl.agg.collect(t.idx)).extend( + hl.agg.explode(lambda x: 
hl.array(hl.agg.collect_as_set(x)), [t.idx, t.idx + 34]) + ), + ), + {0: [9, 0, 3, 6, 9, 34, 37, 40, 43], 1: [1, 4, 7, 35, 38, 41], 2: [8, 2, 5, 8, 36, 39, 42]}, + ), + ] for aggregation, expected in tests: self.assertEqual(t.aggregate(aggregation), expected) @@ -872,7 +1048,7 @@ def test_aggregator_bindings(self): t.aggregate(hl.bind(lambda i: hl.agg.sum(t.idx + i), 1)) with self.assertRaises(hl.expr.ExpressionException): t.annotate(x=hl.bind(lambda i: hl.scan.sum(t.idx + i), 1)) - #filter + # filter with self.assertRaises(hl.expr.ExpressionException): t.aggregate(hl.bind(lambda i: hl.agg.filter(i == 1, hl.agg.sum(t.idx)), 1)) with self.assertRaises(hl.expr.ExpressionException): @@ -881,7 +1057,7 @@ def test_aggregator_bindings(self): t.annotate(x=hl.bind(lambda i: hl.scan.filter(i == 1, hl.scan.sum(t.idx)), 1)) with self.assertRaises(hl.expr.ExpressionException): t.annotate(x=hl.bind(lambda i: hl.scan.filter(t.idx == 1, hl.scan.sum(t.idx) + i), 1)) - #explode + # explode with self.assertRaises(hl.expr.ExpressionException): t.aggregate(hl.bind(lambda i: hl.agg.explode(lambda elt: hl.agg.sum(elt), [t.idx, t.idx + i]), 1)) with self.assertRaises(hl.expr.ExpressionException): @@ -890,7 +1066,7 @@ def test_aggregator_bindings(self): t.annotate(x=hl.bind(lambda i: hl.scan.explode(lambda elt: hl.scan.sum(elt), [t.idx, t.idx + i]), 1)) with self.assertRaises(hl.expr.ExpressionException): t.annotate(x=hl.bind(lambda i: hl.scan.explode(lambda elt: hl.scan.sum(elt) + i, [t.idx, t.idx + 1]), 1)) - #group_by + # group_by with self.assertRaises(hl.expr.ExpressionException): t.aggregate(hl.bind(lambda i: hl.agg.group_by(t.idx % 3 + i, hl.agg.sum(t.idx)), 1)) with self.assertRaises(hl.expr.ExpressionException): @@ -908,14 +1084,16 @@ def test_aggregator_bindings(self): def test_scan(self): table = hl.utils.range_table(10) - t = table.select(scan_count=hl.scan.count(), - scan_count_where=hl.scan.count_where(table.idx % 2 == 0), - scan_count_where2=hl.scan.filter(table.idx % 2 == 0, hl.scan.count()), - arr_sum=hl.scan.array_sum([1, 2, hl.missing(tint32)]), - bind_agg=hl.scan.count_where(hl.bind(lambda x: x % 2 == 0, table.idx)), - mean=hl.scan.mean(table.idx), - foo=hl.min(3, hl.scan.sum(table.idx)), - fraction_odd=hl.scan.fraction(table.idx % 2 == 0)) + t = table.select( + scan_count=hl.scan.count(), + scan_count_where=hl.scan.count_where(table.idx % 2 == 0), + scan_count_where2=hl.scan.filter(table.idx % 2 == 0, hl.scan.count()), + arr_sum=hl.scan.array_sum([1, 2, hl.missing(tint32)]), + bind_agg=hl.scan.count_where(hl.bind(lambda x: x % 2 == 0, table.idx)), + mean=hl.scan.mean(table.idx), + foo=hl.min(3, hl.scan.sum(table.idx)), + fraction_odd=hl.scan.fraction(table.idx % 2 == 0), + ) rows = t.collect() r = hl.Struct(**{n: [i[n] for i in rows] for n in t.row.keys()}) @@ -925,7 +1103,7 @@ def test_scan(self): self.assertEqual(r.arr_sum, [None] + [[i * 1, i * 2, 0] for i in range(1, 10)]) self.assertEqual(r.bind_agg, [(i + 1) // 2 for i in range(10)]) self.assertEqual(r.foo, [min(sum(range(i)), 3) for i in range(10)]) - for (x, y) in zip(r.fraction_odd, [None] + [((i + 1)//2)/i for i in range(1, 10)]): + for x, y in zip(r.fraction_odd, [None] + [((i + 1) // 2) / i for i in range(1, 10)]): self.assertAlmostEqual(x, y) table = hl.utils.range_table(10) @@ -936,17 +1114,21 @@ def test_scan(self): def test_scan_filter(self): t = hl.utils.range_table(5) tests = [ - (hl.scan.filter((t.idx % 2) == 0, - hl.scan.collect(t.idx).append(t.idx)), - [[0], [0, 1], [0, 2], [0, 2, 3], [0, 2, 4]]), - (hl.scan.filter((t.idx 
% 2) == 0, - hl.scan.explode(lambda elt: hl.scan.collect(elt).append(t.idx), - [t.idx, t.idx + 1])), - [[0], [0, 1, 1], [0, 1, 2], [0, 1, 2, 3, 3], [0, 1, 2, 3, 4]]), - (hl.scan.filter((t.idx % 2) == 0, - hl.scan.group_by(t.idx % 3, - hl.scan.collect(t.idx).append(t.idx))), - [{}, {0: [0, 1]}, {0: [0, 2]}, {0: [0, 3], 2: [2, 3]}, {0: [0, 4], 2: [2, 4]}]) + ( + hl.scan.filter((t.idx % 2) == 0, hl.scan.collect(t.idx).append(t.idx)), + [[0], [0, 1], [0, 2], [0, 2, 3], [0, 2, 4]], + ), + ( + hl.scan.filter( + (t.idx % 2) == 0, + hl.scan.explode(lambda elt: hl.scan.collect(elt).append(t.idx), [t.idx, t.idx + 1]), + ), + [[0], [0, 1, 1], [0, 1, 2], [0, 1, 2, 3, 3], [0, 1, 2, 3, 4]], + ), + ( + hl.scan.filter((t.idx % 2) == 0, hl.scan.group_by(t.idx % 3, hl.scan.collect(t.idx).append(t.idx))), + [{}, {0: [0, 1]}, {0: [0, 2]}, {0: [0, 3], 2: [2, 3]}, {0: [0, 4], 2: [2, 4]}], + ), ] for aggregation, expected in tests: @@ -955,29 +1137,35 @@ def test_scan_filter(self): def test_scan_explode(self): t = hl.utils.range_table(5) tests = [ - (hl.scan.explode(lambda elt: hl.scan.collect(elt).append(t.idx), - [t.idx, t.idx + 1]), - [[0], [0, 1, 1], [0, 1, 1, 2, 2], [0, 1, 1, 2, 2, 3, 3], [0, 1, 1, 2, 2, 3, 3, 4, 4]]), - (hl.scan.explode(lambda elt: - hl.scan.explode(lambda elt2: - hl.scan.collect(elt).append(t.idx), - [elt]), - [t.idx, t.idx + 1]), - [[0], [0, 1, 1], [0, 1, 1, 2, 2], [0, 1, 1, 2, 2, 3, 3], [0, 1, 1, 2, 2, 3, 3, 4, 4]]), - (hl.scan.explode(lambda elt: - hl.scan.filter((elt % 2) == 0, - hl.scan.collect(elt).append(t.idx)), - [t.idx, t.idx + 1]), - [[0], [0, 1], [0, 2, 2], [0, 2, 2, 3], [0, 2, 2, 4, 4]]), - (hl.scan.explode(lambda elt: - hl.scan.group_by(elt % 3, - hl.scan.collect(elt).append(t.idx)), - [t.idx, t.idx + 1]), - [{}, - {0: [0, 1], 1: [1, 1]}, - {0: [0, 2], 1: [1, 1, 2], 2: [2, 2]}, - {0: [0, 3, 3], 1: [1, 1, 3], 2: [2, 2, 3]}, - {0: [0, 3, 3, 4], 1: [1, 1, 4, 4], 2: [2, 2, 4]}]), + ( + hl.scan.explode(lambda elt: hl.scan.collect(elt).append(t.idx), [t.idx, t.idx + 1]), + [[0], [0, 1, 1], [0, 1, 1, 2, 2], [0, 1, 1, 2, 2, 3, 3], [0, 1, 1, 2, 2, 3, 3, 4, 4]], + ), + ( + hl.scan.explode( + lambda elt: hl.scan.explode(lambda elt2: hl.scan.collect(elt).append(t.idx), [elt]), + [t.idx, t.idx + 1], + ), + [[0], [0, 1, 1], [0, 1, 1, 2, 2], [0, 1, 1, 2, 2, 3, 3], [0, 1, 1, 2, 2, 3, 3, 4, 4]], + ), + ( + hl.scan.explode( + lambda elt: hl.scan.filter((elt % 2) == 0, hl.scan.collect(elt).append(t.idx)), [t.idx, t.idx + 1] + ), + [[0], [0, 1], [0, 2, 2], [0, 2, 2, 3], [0, 2, 2, 4, 4]], + ), + ( + hl.scan.explode( + lambda elt: hl.scan.group_by(elt % 3, hl.scan.collect(elt).append(t.idx)), [t.idx, t.idx + 1] + ), + [ + {}, + {0: [0, 1], 1: [1, 1]}, + {0: [0, 2], 1: [1, 1, 2], 2: [2, 2]}, + {0: [0, 3, 3], 1: [1, 1, 3], 2: [2, 2, 3]}, + {0: [0, 3, 3, 4], 1: [1, 1, 4, 4], 2: [2, 2, 4]}, + ], + ), ] for aggregation, expected in tests: @@ -986,29 +1174,32 @@ def test_scan_explode(self): def test_scan_group_by(self): t = hl.utils.range_table(5) tests = [ - (hl.scan.group_by(t.idx % 3, - hl.scan.collect(t.idx).append(t.idx)), - [{}, - {0: [0, 1]}, - {0: [0, 2], 1: [1, 2]}, - {0: [0, 3], 1: [1, 3], 2: [2, 3]}, - {0: [0, 3, 4], 1: [1, 4], 2: [2, 4]}]), - (hl.scan.group_by(t.idx % 3, - hl.scan.filter((t.idx % 2) == 0, - hl.scan.collect(t.idx).append(t.idx))), - [{}, - {0: [0, 1]}, - {0: [0, 2], 1: [2]}, - {0: [0, 3], 1: [3], 2: [2, 3]}, - {0: [0, 4], 1: [4], 2: [2, 4]}]), - (hl.scan.group_by(t.idx % 3, - hl.scan.explode(lambda elt: hl.scan.collect(elt).append(t.idx), - [t.idx, t.idx + 1])), - [{}, - {0: 
[0, 1, 1]}, - {0: [0, 1, 2], 1: [1, 2, 2]}, - {0: [0, 1, 3], 1: [1, 2, 3], 2: [2, 3, 3]}, - {0: [0, 1, 3, 4, 4], 1: [1, 2, 4], 2: [2, 3, 4]}]) + ( + hl.scan.group_by(t.idx % 3, hl.scan.collect(t.idx).append(t.idx)), + [ + {}, + {0: [0, 1]}, + {0: [0, 2], 1: [1, 2]}, + {0: [0, 3], 1: [1, 3], 2: [2, 3]}, + {0: [0, 3, 4], 1: [1, 4], 2: [2, 4]}, + ], + ), + ( + hl.scan.group_by(t.idx % 3, hl.scan.filter((t.idx % 2) == 0, hl.scan.collect(t.idx).append(t.idx))), + [{}, {0: [0, 1]}, {0: [0, 2], 1: [2]}, {0: [0, 3], 1: [3], 2: [2, 3]}, {0: [0, 4], 1: [4], 2: [2, 4]}], + ), + ( + hl.scan.group_by( + t.idx % 3, hl.scan.explode(lambda elt: hl.scan.collect(elt).append(t.idx), [t.idx, t.idx + 1]) + ), + [ + {}, + {0: [0, 1, 1]}, + {0: [0, 1, 2], 1: [1, 2, 2]}, + {0: [0, 1, 3], 1: [1, 2, 3], 2: [2, 3, 3]}, + {0: [0, 1, 3, 4, 4], 1: [1, 2, 4], 2: [2, 3, 4]}, + ], + ), ] for aggregation, expected in tests: @@ -1023,28 +1214,54 @@ def test_scan_array_agg(self): def test_aggregators_max_min(self): table = hl.utils.range_table(10) # FIXME: add boolean when function registry is removed - for (f, typ) in [(lambda x: hl.int32(x), tint32), (lambda x: hl.int64(x), tint64), - (lambda x: hl.float32(x), tfloat32), (lambda x: hl.float64(x), tfloat64)]: + for f, typ in [ + (lambda x: hl.int32(x), tint32), + (lambda x: hl.int64(x), tint64), + (lambda x: hl.float32(x), tfloat32), + (lambda x: hl.float64(x), tfloat64), + ]: t = table.annotate(x=-1 * f(table.idx) - 5, y=hl.missing(typ)) - r = t.aggregate(hl.struct(max=hl.agg.max(t.x), max_empty=hl.agg.max(t.y), - min=hl.agg.min(t.x), min_empty=hl.agg.min(t.y))) - self.assertTrue(r.max == -5 and r.max_empty is None and - r.min == -14 and r.min_empty is None) + r = t.aggregate( + hl.struct( + max=hl.agg.max(t.x), max_empty=hl.agg.max(t.y), min=hl.agg.min(t.x), min_empty=hl.agg.min(t.y) + ) + ) + self.assertTrue(r.max == -5 and r.max_empty is None and r.min == -14 and r.min_empty is None) def test_aggregators_sum_product(self): table = hl.utils.range_table(5) - for (f, typ) in [(lambda x: hl.int32(x), tint32), (lambda x: hl.int64(x), tint64), - (lambda x: hl.float32(x), tfloat32), (lambda x: hl.float64(x), tfloat64)]: + for f, typ in [ + (lambda x: hl.int32(x), tint32), + (lambda x: hl.int64(x), tint64), + (lambda x: hl.float32(x), tfloat32), + (lambda x: hl.float64(x), tfloat64), + ]: t = table.annotate(x=-1 * f(table.idx) - 1, y=f(table.idx), z=hl.missing(typ)) - r = t.aggregate(hl.struct(sum_x=hl.agg.sum(t.x), sum_y=hl.agg.sum(t.y), sum_empty=hl.agg.sum(t.z), - prod_x=hl.agg.product(t.x), prod_y=hl.agg.product(t.y), prod_empty=hl.agg.product(t.z))) - self.assertTrue(r.sum_x == -15 and r.sum_y == 10 and r.sum_empty == 0 and - r.prod_x == -120 and r.prod_y == 0 and r.prod_empty == 1) + r = t.aggregate( + hl.struct( + sum_x=hl.agg.sum(t.x), + sum_y=hl.agg.sum(t.y), + sum_empty=hl.agg.sum(t.z), + prod_x=hl.agg.product(t.x), + prod_y=hl.agg.product(t.y), + prod_empty=hl.agg.product(t.z), + ) + ) + self.assertTrue( + r.sum_x == -15 + and r.sum_y == 10 + and r.sum_empty == 0 + and r.prod_x == -120 + and r.prod_y == 0 + and r.prod_empty == 1 + ) def test_aggregators_hist(self): table = hl.utils.range_table(11) r = table.aggregate(hl.agg.hist(table.idx - 1, 0, 8, 4)) - self.assertTrue(r.bin_edges == [0, 2, 4, 6, 8] and r.bin_freq == [2, 2, 2, 3] and r.n_smaller == 1 and r.n_larger == 1) + self.assertTrue( + r.bin_edges == [0, 2, 4, 6, 8] and r.bin_freq == [2, 2, 2, 3] and r.n_smaller == 1 and r.n_larger == 1 + ) def test_aggregators_hist_neg0(self): table = 
hl.utils.range_table(32) @@ -1069,13 +1286,10 @@ def test_aggregator_cse(self): mt = hl.utils.range_matrix_table(10, 10) x = hl.int64(5) - rows = mt.annotate_rows(agg=hl.agg.sum(x+x), scan=hl.scan.sum(x+x), val=x+x).rows() + rows = mt.annotate_rows(agg=hl.agg.sum(x + x), scan=hl.scan.sum(x + x), val=x + x).rows() expected = hl.utils.range_table(10) expected = expected.key_by(row_idx=expected.idx) - expected = expected.select( - agg=hl.int64(100), - scan=hl.int64(expected.row_idx*10), - val=hl.int64(10)) + expected = expected.select(agg=hl.int64(100), scan=hl.int64(expected.row_idx * 10), val=hl.int64(10)) self.assertTrue(rows._same(expected)) # Tested against R code @@ -1091,16 +1305,20 @@ def test_aggregator_cse(self): # f = sumfit$fstatistic # p = pf(f[1],f[2],f[3],lower.tail=F) def test_aggregators_linreg(self): - t = hl.Table.parallelize([ - {"y": None, "x": 1.0}, - {"y": 0.0, "x": None}, - {"y": None, "x": None}, - {"y": 0.22848042, "x": 0.2575928}, - {"y": 0.09159706, "x": -0.3445442}, - {"y": -0.43881935, "x": 1.6590146}, - {"y": -0.99106171, "x": -1.1688806}, - {"y": 2.12823289, "x": 0.5587043} - ], hl.tstruct(y=hl.tfloat64, x=hl.tfloat64), n_partitions=3) + t = hl.Table.parallelize( + [ + {"y": None, "x": 1.0}, + {"y": 0.0, "x": None}, + {"y": None, "x": None}, + {"y": 0.22848042, "x": 0.2575928}, + {"y": 0.09159706, "x": -0.3445442}, + {"y": -0.43881935, "x": 1.6590146}, + {"y": -0.99106171, "x": -1.1688806}, + {"y": 2.12823289, "x": 0.5587043}, + ], + hl.tstruct(y=hl.tfloat64, x=hl.tfloat64), + n_partitions=3, + ) r = t.aggregate(hl.struct(linreg=hl.agg.linreg(t.y, [1, t.x]))).linreg self.assertAlmostEqual(r.beta[0], 0.14069227) self.assertAlmostEqual(r.beta[1], 0.32744807) @@ -1130,8 +1348,7 @@ def test_aggregators_linreg(self): # weighted OLS t = t.add_index() - r = t.aggregate(hl.struct( - linreg=hl.agg.linreg(t.y, [1, t.x], weight=t.idx))).linreg + r = t.aggregate(hl.struct(linreg=hl.agg.linreg(t.y, [1, t.x], weight=t.idx))).linreg self.assertAlmostEqual(r.beta[0], 0.2339059) self.assertAlmostEqual(r.beta[1], 0.4275577) self.assertAlmostEqual(r.standard_error[0], 0.6638324) @@ -1161,16 +1378,27 @@ def test_aggregator_downsample(self): ys = [2, 6, 4, 9, 1, 8, 5, 10, 3, 7] label1 = ["2", "6", "4", "9", "1", "8", "5", "10", "3", "7"] label2 = ["two", "six", "four", "nine", "one", "eight", "five", "ten", "three", "seven"] - table = hl.Table.parallelize([hl.struct(x=x, y=y, label1=label1, label2=label2) - for x, y, label1, label2 in zip(xs, ys, label1, label2)]) - r = table.aggregate(hl.agg.downsample(table.x, table.y, label=hl.array([table.label1, table.label2]), n_divisions=10)) + table = hl.Table.parallelize([ + hl.struct(x=x, y=y, label1=label1, label2=label2) for x, y, label1, label2 in zip(xs, ys, label1, label2) + ]) + r = table.aggregate( + hl.agg.downsample(table.x, table.y, label=hl.array([table.label1, table.label2]), n_divisions=10) + ) xs = [x for (x, y, l) in r] ys = [y for (x, y, l) in r] label = [tuple(l) for (x, y, l) in r] - expected = set([(1.0, 1.0, ('1', 'one')), (2.0, 2.0, ('2', 'two')), (3.0, 3.0, ('3', 'three')), - (4.0, 4.0, ('4', 'four')), (5.0, 5.0, ('5', 'five')), (6.0, 6.0, ('6', 'six')), - (7.0, 7.0, ('7', 'seven')), (8.0, 8.0, ('8', 'eight')), (9.0, 9.0, ('9', 'nine')), - (10.0, 10.0, ('10', 'ten'))]) + expected = set([ + (1.0, 1.0, ('1', 'one')), + (2.0, 2.0, ('2', 'two')), + (3.0, 3.0, ('3', 'three')), + (4.0, 4.0, ('4', 'four')), + (5.0, 5.0, ('5', 'five')), + (6.0, 6.0, ('6', 'six')), + (7.0, 7.0, ('7', 'seven')), + (8.0, 8.0, ('8', 'eight')), + 
(9.0, 9.0, ('9', 'nine')), + (10.0, 10.0, ('10', 'ten')), + ]) for point in zip(xs, ys, label): self.assertTrue(point in expected) @@ -1182,15 +1410,8 @@ def test_downsample_aggregator_on_empty_table(self): def test_downsample_in_array_agg(self): mt = hl.utils.range_matrix_table(50, 50) - mt = mt.annotate_rows(y = hl.rand_unif(0, 1)) - mt = mt.annotate_cols( - binned=hl.agg.downsample( - mt.row_idx, - mt.y, - label=hl.str(mt.y), - n_divisions=4 - ) - ) + mt = mt.annotate_rows(y=hl.rand_unif(0, 1)) + mt = mt.annotate_cols(binned=hl.agg.downsample(mt.row_idx, mt.y, label=hl.str(mt.y), n_divisions=4)) mt.cols()._force_count() def test_aggregator_info_score(self): @@ -1199,7 +1420,7 @@ def test_aggregator_info_score(self): truth_result_file = resource('infoScoreTest.result') mt = hl.import_gen(gen_file, sample_file=sample_file) - mt = mt.annotate_rows(info_score = hl.agg.info_score(mt.GP)) + mt = mt.annotate_rows(info_score=hl.agg.info_score(mt.GP)) truth = hl.import_table(truth_result_file, impute=True, delimiter=' ', no_header=True, missing='None') truth = truth.drop('f1', 'f2').rename({'f0': 'variant', 'f3': 'score', 'f4': 'n_included'}) @@ -1208,13 +1429,15 @@ def test_aggregator_info_score(self): computed = mt.rows() joined = truth[computed.key] - computed = computed.select(score = computed.info_score.score, - score_truth = joined.score, - n_included = computed.info_score.n_included, - n_included_truth = joined.n_included) + computed = computed.select( + score=computed.info_score.score, + score_truth=joined.score, + n_included=computed.info_score.n_included, + n_included_truth=joined.n_included, + ) violations = computed.filter( - (computed.n_included != computed.n_included_truth) | - (hl.abs(computed.score - computed.score_truth) > 1e-3)) + (computed.n_included != computed.n_included_truth) | (hl.abs(computed.score - computed.score_truth) > 1e-3) + ) if not violations.count() == 0: violations.show() self.fail("disagreement between computed info score and truth") @@ -1227,25 +1450,35 @@ def test_aggregator_info_score_works_with_bgen_import(self): self.assertEqual(result.n_included, 8) def test_aggregator_group_by_sorts_result(self): - t = hl.Table.parallelize([ # the `s` key is stored before the `m` in java.util.HashMap - {"group": "m", "x": 1}, - {"group": "s", "x": 2}, - {"group": "s", "x": 3}, - {"group": "m", "x": 4}, - {"group": "m", "x": 5} - ], hl.tstruct(group=hl.tstr, x=hl.tint32), n_partitions=1) + t = hl.Table.parallelize( + [ # the `s` key is stored before the `m` in java.util.HashMap + {"group": "m", "x": 1}, + {"group": "s", "x": 2}, + {"group": "s", "x": 3}, + {"group": "m", "x": 4}, + {"group": "m", "x": 5}, + ], + hl.tstruct(group=hl.tstr, x=hl.tint32), + n_partitions=1, + ) grouped_expr = t.aggregate(hl.array(hl.agg.group_by(t.group, hl.agg.sum(t.x)))) self.assertEqual(grouped_expr, hl.eval(hl.sorted(grouped_expr))) def test_agg_corr(self): ht = hl.utils.range_table(10) - ht = ht.annotate(tests=hl.range(0, 10).map( - lambda i: hl.struct( - x=hl.if_else(hl.rand_bool(0.1), hl.missing(hl.tfloat64), hl.rand_unif(-10, 10)), - y=hl.if_else(hl.rand_bool(0.1), hl.missing(hl.tfloat64), hl.rand_unif(-10, 10))))) + ht = ht.annotate( + tests=hl.range(0, 10).map( + lambda i: hl.struct( + x=hl.if_else(hl.rand_bool(0.1), hl.missing(hl.tfloat64), hl.rand_unif(-10, 10)), + y=hl.if_else(hl.rand_bool(0.1), hl.missing(hl.tfloat64), hl.rand_unif(-10, 10)), + ) + ) + ) - results = ht.aggregate(hl.agg.array_agg(lambda test: (hl.agg.corr(test.x, test.y), hl.agg.collect((test.x, test.y))), 
ht.tests)) + results = ht.aggregate( + hl.agg.array_agg(lambda test: (hl.agg.corr(test.x, test.y), hl.agg.collect((test.x, test.y))), ht.tests) + ) for corr, xy in results: filtered = [(x, y) for x, y in xy if x is not None and y is not None] @@ -1261,40 +1494,27 @@ def test_switch(self): x = hl.literal('1') na = hl.missing(tint32) - expr1 = (hl.switch(x) - .when('123', 5) - .when('1', 6) - .when('0', 2) - .or_missing()) + expr1 = hl.switch(x).when('123', 5).when('1', 6).when('0', 2).or_missing() self.assertEqual(hl.eval(expr1), 6) - expr2 = (hl.switch(x) - .when('123', 5) - .when('0', 2) - .or_missing()) + expr2 = hl.switch(x).when('123', 5).when('0', 2).or_missing() self.assertEqual(hl.eval(expr2), None) - expr3 = (hl.switch(x) - .when('123', 5) - .when('0', 2) - .default(100)) + expr3 = hl.switch(x).when('123', 5).when('0', 2).default(100) self.assertEqual(hl.eval(expr3), 100) - expr4 = (hl.switch(na) - .when(5, 0) - .when(6, 1) - .when(0, 2) - .when(hl.missing(tint32), 3) # NA != NA - .default(4)) + expr4 = hl.switch(na).when(5, 0).when(6, 1).when(0, 2).when(hl.missing(tint32), 3).default(4) # NA != NA self.assertEqual(hl.eval(expr4), None) - expr5 = (hl.switch(na) + expr5 = ( + hl.switch(na) .when(5, 0) .when(6, 1) .when(0, 2) .when(hl.missing(tint32), 3) # NA != NA .when_missing(-1) - .default(4)) + .default(4) + ) self.assertEqual(hl.eval(expr5), -1) with pytest.raises(hl.utils.java.HailUserError) as exc: @@ -1304,12 +1524,7 @@ def test_switch(self): def test_case(self): def make_case(x): x = hl.literal(x) - return (hl.case() - .when(x == 6, 'A') - .when(x % 3 == 0, 'B') - .when(x == 5, 'C') - .when(x < 2, 'D') - .or_missing()) + return hl.case().when(x == 6, 'A').when(x % 3 == 0, 'B').when(x == 5, 'C').when(x < 2, 'D').or_missing() self.assertEqual(hl.eval(make_case(6)), 'A') self.assertEqual(hl.eval(make_case(12)), 'B') @@ -1334,41 +1549,31 @@ def assert_typed(expr, result, dtype): self.assertEqual(t, dtype) self.assertEqual(result, r) - assert_typed(s.drop('f3'), - hl.Struct(f1=1, f2=2), - tstruct(f1=tint32, f2=tint32)) + assert_typed(s.drop('f3'), hl.Struct(f1=1, f2=2), tstruct(f1=tint32, f2=tint32)) - assert_typed(s.drop('f1'), - hl.Struct(f2=2, f3=3), - tstruct(f2=tint32, f3=tint32)) + assert_typed(s.drop('f1'), hl.Struct(f2=2, f3=3), tstruct(f2=tint32, f3=tint32)) - assert_typed(s.drop(), - hl.Struct(f1=1, f2=2, f3=3), - tstruct(f1=tint32, f2=tint32, f3=tint32)) + assert_typed(s.drop(), hl.Struct(f1=1, f2=2, f3=3), tstruct(f1=tint32, f2=tint32, f3=tint32)) - assert_typed(s.select('f1', 'f2'), - hl.Struct(f1=1, f2=2), - tstruct(f1=tint32, f2=tint32)) + assert_typed(s.select('f1', 'f2'), hl.Struct(f1=1, f2=2), tstruct(f1=tint32, f2=tint32)) - assert_typed(s.select('f2', 'f1', f4=5, f5=6), - hl.Struct(f2=2, f1=1, f4=5, f5=6), - tstruct(f2=tint32, f1=tint32, f4=tint32, f5=tint32)) + assert_typed( + s.select('f2', 'f1', f4=5, f5=6), + hl.Struct(f2=2, f1=1, f4=5, f5=6), + tstruct(f2=tint32, f1=tint32, f4=tint32, f5=tint32), + ) - assert_typed(s.select(), - hl.Struct(), - tstruct()) + assert_typed(s.select(), hl.Struct(), tstruct()) - assert_typed(s.annotate(f1=5, f2=10, f4=15), - hl.Struct(f1=5, f2=10, f3=3, f4=15), - tstruct(f1=tint32, f2=tint32, f3=tint32, f4=tint32)) + assert_typed( + s.annotate(f1=5, f2=10, f4=15), + hl.Struct(f1=5, f2=10, f3=3, f4=15), + tstruct(f1=tint32, f2=tint32, f3=tint32, f4=tint32), + ) - assert_typed(s.annotate(f1=5), - hl.Struct(f1=5, f2=2, f3=3), - tstruct(f1=tint32, f2=tint32, f3=tint32)) + assert_typed(s.annotate(f1=5), hl.Struct(f1=5, f2=2, f3=3), 
tstruct(f1=tint32, f2=tint32, f3=tint32)) - assert_typed(s.annotate(), - hl.Struct(f1=1, f2=2, f3=3), - tstruct(f1=tint32, f2=tint32, f3=tint32)) + assert_typed(s.annotate(), hl.Struct(f1=1, f2=2, f3=3), tstruct(f1=tint32, f2=tint32, f3=tint32)) def test_shadowed_struct_fields(self): from typing import Callable @@ -1386,7 +1591,12 @@ def test_shadowed_struct_fields(self): assert isinstance(s._ir, ir.IR) assert '_ir' not in s._warn_on_shadowed_name - s = hl.StructExpression._from_fields({'foo': hl.int(1), 'values': hl.int(2), 'collect': hl.int(3), '_ir': hl.int(4)}) + s = hl.StructExpression._from_fields({ + 'foo': hl.int(1), + 'values': hl.int(2), + 'collect': hl.int(3), + '_ir': hl.int(4), + }) assert 'foo' not in s._warn_on_shadowed_name assert isinstance(s.foo, hl.Expression) assert 'values' in s._warn_on_shadowed_name @@ -1430,40 +1640,37 @@ def test_functions_any_and_all(self): x7 = hl.literal([False, None], dtype='array') x8 = hl.literal([True, False, None], dtype='array') - assert hl.eval( - ( - (x1.any(lambda x: x), x1.all(lambda x: x)), - (x2.any(lambda x: x), x2.all(lambda x: x)), - (x3.any(lambda x: x), x3.all(lambda x: x)), - (x4.any(lambda x: x), x4.all(lambda x: x)), - (x5.any(lambda x: x), x5.all(lambda x: x)), - (x6.any(lambda x: x), x6.all(lambda x: x)), - (x7.any(lambda x: x), x7.all(lambda x: x)), - (x8.any(lambda x: x), x8.all(lambda x: x)), - ) - ) == ( - (False, True), - (True, True), - (False, False), - (None, None), - (True, False), - (True, None), - (None, False), - (True, False) - ) + assert hl.eval(( + (x1.any(lambda x: x), x1.all(lambda x: x)), + (x2.any(lambda x: x), x2.all(lambda x: x)), + (x3.any(lambda x: x), x3.all(lambda x: x)), + (x4.any(lambda x: x), x4.all(lambda x: x)), + (x5.any(lambda x: x), x5.all(lambda x: x)), + (x6.any(lambda x: x), x6.all(lambda x: x)), + (x7.any(lambda x: x), x7.all(lambda x: x)), + (x8.any(lambda x: x), x8.all(lambda x: x)), + )) == ( + (False, True), + (True, True), + (False, False), + (None, None), + (True, False), + (True, None), + (None, False), + (True, False), + ) def test_aggregator_any_and_all(self): df = hl.utils.range_table(10) - df = df.annotate(all_true=True, - all_false=False, - true_or_missing=hl.if_else(df.idx % 2 == 0, True, hl.missing(tbool)), - false_or_missing=hl.if_else(df.idx % 2 == 0, False, hl.missing(tbool)), - all_missing=hl.missing(tbool), - mixed_true_false=hl.if_else(df.idx % 2 == 0, True, False), - mixed_all=hl.switch(df.idx % 3) - .when(0, True) - .when(1, False) - .or_missing()).cache() + df = df.annotate( + all_true=True, + all_false=False, + true_or_missing=hl.if_else(df.idx % 2 == 0, True, hl.missing(tbool)), + false_or_missing=hl.if_else(df.idx % 2 == 0, False, hl.missing(tbool)), + all_missing=hl.missing(tbool), + mixed_true_false=hl.if_else(df.idx % 2 == 0, True, False), + mixed_all=hl.switch(df.idx % 3).when(0, True).when(1, False).or_missing(), + ).cache() self.assertEqual(df.aggregate(hl.agg.any(df.all_true)), True) self.assertEqual(df.aggregate(hl.agg.all(df.all_true)), True) @@ -1485,32 +1692,37 @@ def test_aggregator_any_and_all(self): def test_agg_prev_nonnull(self): t = hl.utils.range_table(17, n_partitions=8) - t = t.annotate( - prev = hl.scan._prev_nonnull( - hl.or_missing((t.idx % 3) != 0, t.row))) + t = t.annotate(prev=hl.scan._prev_nonnull(hl.or_missing((t.idx % 3) != 0, t.row))) self.assertTrue( - t.all(hl._values_similar(t.prev.idx, - hl.case() - .when(t.idx < 2, hl.missing(hl.tint32)) - .when(((t.idx - 1) % 3) == 0, t.idx - 2) - .default(t.idx - 1)))) + t.all( + 
hl._values_similar( + t.prev.idx, + hl.case() + .when(t.idx < 2, hl.missing(hl.tint32)) + .when(((t.idx - 1) % 3) == 0, t.idx - 2) + .default(t.idx - 1), + ) + ) + ) def test_agg_table_take(self): - ht = hl.utils.range_table(10).annotate(x = 'a') + ht = hl.utils.range_table(10).annotate(x='a') self.assertEqual(ht.aggregate(agg.take(ht.x, 2)), ['a', 'a']) def test_agg_take_by(self): ht = hl.utils.range_table(10, 3) data1 = hl.literal([str(i) for i in range(10)]) data2 = hl.literal([i**2 for i in range(10)]) - ht = ht.annotate(d1 = data1[ht.idx], d2=data2[ht.idx]) - - tb1, tb2, tb3, tb4 = ht.aggregate((hl.agg.take(ht.d1, 5, ordering=-ht.idx), - hl.agg.take(ht.d2, 5, ordering=-ht.idx), - hl.agg.take(ht.idx, 7, ordering=ht.idx // 5), # stable sort - hl.agg.array_agg( - lambda elt: hl.agg.take(hl.str(elt) + "_" + hl.str(ht.idx), 4, - ordering=ht.idx), hl.range(0, 2)))) + ht = ht.annotate(d1=data1[ht.idx], d2=data2[ht.idx]) + + tb1, tb2, tb3, tb4 = ht.aggregate(( + hl.agg.take(ht.d1, 5, ordering=-ht.idx), + hl.agg.take(ht.d2, 5, ordering=-ht.idx), + hl.agg.take(ht.idx, 7, ordering=ht.idx // 5), # stable sort + hl.agg.array_agg( + lambda elt: hl.agg.take(hl.str(elt) + "_" + hl.str(ht.idx), 4, ordering=ht.idx), hl.range(0, 2) + ), + )) assert tb1 == ['9', '8', '7', '6', '5'] assert tb2 == [81, 64, 49, 36, 25] @@ -1522,11 +1734,11 @@ def test_agg_minmax(self): na = hl.missing(hl.tfloat32) size = 200 for aggfunc in (agg.min, agg.max): - array_with_nan = hl.array([0. if i == 1 else nan for i in range(size)]) - array_with_na = hl.array([0. if i == 1 else na for i in range(size)]) + array_with_nan = hl.array([0.0 if i == 1 else nan for i in range(size)]) + array_with_na = hl.array([0.0 if i == 1 else na for i in range(size)]) t = hl.utils.range_table(size) - self.assertEqual(t.aggregate(aggfunc(array_with_nan[t.idx])), 0.) - self.assertEqual(t.aggregate(aggfunc(array_with_na[t.idx])), 0.) 
+ self.assertEqual(t.aggregate(aggfunc(array_with_nan[t.idx])), 0.0) + self.assertEqual(t.aggregate(aggfunc(array_with_na[t.idx])), 0.0) def test_str_ops(self): s = hl.literal('abcABC123') @@ -1549,40 +1761,46 @@ def test_str_ops(self): (s_whitespace.startswith(' \t'), True), (s_whitespace.endswith('\t\n'), True), (s_whitespace.startswith('a'), False), - (s_whitespace.endswith('a'), False)]) + (s_whitespace.endswith('a'), False), + ]) def test_str_parsing(self): int_parsers = (hl.int32, hl.int64, hl.parse_int32, hl.parse_int64) float_parsers = (hl.float, hl.float32, hl.float64, hl.parse_float32, hl.parse_float64) infinity_strings = ('inf', 'Inf', 'iNf', 'InF', 'infinity', 'InfiNitY', 'INFINITY') _test_many_equal([ - *[(hl.bool(x), True) - for x in ('true', 'True', 'TRUE')], - *[(hl.bool(x), False) - for x in ('false', 'False', 'FALSE')], - *[(hl.is_nan(f(sgn + x)), True) - for x in ('nan', 'Nan', 'naN', 'NaN') - for sgn in ('', '+', '-') - for f in float_parsers], - *[(hl.is_infinite(f(sgn + x)), True) - for x in infinity_strings - for sgn in ('', '+', '-') - for f in float_parsers], - *[(f('-' + x) < 0.0, True) - for x in infinity_strings - for f in float_parsers], - *[(hl.tuple([int_parser(hl.literal(x)), float_parser(hl.literal(x))]), - (int(x), float(x))) - for int_parser in int_parsers - for float_parser in float_parsers - for x in ('0', '1', '-5', '12382421')], - *[(hl.tuple([float_parser(hl.literal(x)), flexible_int_parser(hl.literal(x))]), (float(x), None)) - for float_parser in float_parsers - for flexible_int_parser in (hl.parse_int32, hl.parse_int64) - for x in ('-1.5', '0.0', '2.5')], - *[(flexible_numeric_parser(hl.literal(x)), None) - for flexible_numeric_parser in (hl.parse_float32, hl.parse_float64, hl.parse_int32, hl.parse_int64) - for x in ('abc', '1abc', '')] + *[(hl.bool(x), True) for x in ('true', 'True', 'TRUE')], + *[(hl.bool(x), False) for x in ('false', 'False', 'FALSE')], + *[ + (hl.is_nan(f(sgn + x)), True) + for x in ('nan', 'Nan', 'naN', 'NaN') + for sgn in ('', '+', '-') + for f in float_parsers + ], + *[ + (hl.is_infinite(f(sgn + x)), True) + for x in infinity_strings + for sgn in ('', '+', '-') + for f in float_parsers + ], + *[(f('-' + x) < 0.0, True) for x in infinity_strings for f in float_parsers], + *[ + (hl.tuple([int_parser(hl.literal(x)), float_parser(hl.literal(x))]), (int(x), float(x))) + for int_parser in int_parsers + for float_parser in float_parsers + for x in ('0', '1', '-5', '12382421') + ], + *[ + (hl.tuple([float_parser(hl.literal(x)), flexible_int_parser(hl.literal(x))]), (float(x), None)) + for float_parser in float_parsers + for flexible_int_parser in (hl.parse_int32, hl.parse_int64) + for x in ('-1.5', '0.0', '2.5') + ], + *[ + (flexible_numeric_parser(hl.literal(x)), None) + for flexible_numeric_parser in (hl.parse_float32, hl.parse_float64, hl.parse_int32, hl.parse_int64) + for x in ('abc', '1abc', '') + ], ]) def test_str_missingness(self): @@ -1617,68 +1835,57 @@ def test_division(self): (a_int64 / 4, expected, tarray(tfloat64)), (a_float32 / 4, expected, tarray(tfloat32)), (a_float64 / 4, expected, tarray(tfloat64)), - (int32_4s / a_int32, expected_inv, tarray(tfloat64)), (int32_4s / a_int64, expected_inv, tarray(tfloat64)), (int32_4s / a_float32, expected_inv, tarray(tfloat32)), (int32_4s / a_float64, expected_inv, tarray(tfloat64)), - (a_int32 / int32_4s, expected, tarray(tfloat64)), (a_int64 / int32_4s, expected, tarray(tfloat64)), (a_float32 / int32_4s, expected, tarray(tfloat32)), (a_float64 / int32_4s, expected, 
tarray(tfloat64)), - (a_int32 / int64_4, expected, tarray(tfloat64)), (a_int64 / int64_4, expected, tarray(tfloat64)), (a_float32 / int64_4, expected, tarray(tfloat32)), (a_float64 / int64_4, expected, tarray(tfloat64)), - (int64_4 / a_int32, expected_inv, tarray(tfloat64)), (int64_4 / a_int64, expected_inv, tarray(tfloat64)), (int64_4 / a_float32, expected_inv, tarray(tfloat32)), (int64_4 / a_float64, expected_inv, tarray(tfloat64)), - (a_int32 / int64_4s, expected, tarray(tfloat64)), (a_int64 / int64_4s, expected, tarray(tfloat64)), (a_float32 / int64_4s, expected, tarray(tfloat32)), (a_float64 / int64_4s, expected, tarray(tfloat64)), - (a_int32 / float32_4, expected, tarray(tfloat32)), (a_int64 / float32_4, expected, tarray(tfloat32)), (a_float32 / float32_4, expected, tarray(tfloat32)), (a_float64 / float32_4, expected, tarray(tfloat64)), - (float32_4 / a_int32, expected_inv, tarray(tfloat32)), (float32_4 / a_int64, expected_inv, tarray(tfloat32)), (float32_4 / a_float32, expected_inv, tarray(tfloat32)), (float32_4 / a_float64, expected_inv, tarray(tfloat64)), - (a_int32 / float32_4s, expected, tarray(tfloat32)), (a_int64 / float32_4s, expected, tarray(tfloat32)), (a_float32 / float32_4s, expected, tarray(tfloat32)), (a_float64 / float32_4s, expected, tarray(tfloat64)), - (a_int32 / float64_4, expected, tarray(tfloat64)), (a_int64 / float64_4, expected, tarray(tfloat64)), (a_float32 / float64_4, expected, tarray(tfloat64)), (a_float64 / float64_4, expected, tarray(tfloat64)), - (float64_4 / a_int32, expected_inv, tarray(tfloat64)), (float64_4 / a_int64, expected_inv, tarray(tfloat64)), (float64_4 / a_float32, expected_inv, tarray(tfloat64)), (float64_4 / a_float64, expected_inv, tarray(tfloat64)), - (a_int32 / float64_4s, expected, tarray(tfloat64)), (a_int64 / float64_4s, expected, tarray(tfloat64)), (a_float32 / float64_4s, expected, tarray(tfloat64)), - (a_float64 / float64_4s, expected, tarray(tfloat64))]) + (a_float64 / float64_4s, expected, tarray(tfloat64)), + ]) def test_floor_division(self): a_int32 = hl.array([2, 4, 8, 16, hl.missing(tint32)]) a_int64 = a_int32.map(lambda x: hl.int64(x)) a_float32 = a_int32.map(lambda x: hl.float32(x)) a_float64 = a_int32.map(lambda x: hl.float64(x)) - int32_4s = hl.array([4, 4, 4, 4, hl.missing(tint32)]) int32_3s = hl.array([3, 3, 3, 3, hl.missing(tint32)]) int64_3 = hl.int64(3) int64_3s = int32_3s.map(lambda x: hl.int64(x)) @@ -1695,68 +1902,58 @@ def test_floor_division(self): (a_int64 // 3, expected, tarray(tint64)), (a_float32 // 3, expected, tarray(tfloat32)), (a_float64 // 3, expected, tarray(tfloat64)), - (3 // a_int32, expected_inv, tarray(tint32)), (3 // a_int64, expected_inv, tarray(tint64)), (3 // a_float32, expected_inv, tarray(tfloat32)), (3 // a_float64, expected_inv, tarray(tfloat64)), - (a_int32 // int32_3s, expected, tarray(tint32)), (a_int64 // int32_3s, expected, tarray(tint64)), (a_float32 // int32_3s, expected, tarray(tfloat32)), (a_float64 // int32_3s, expected, tarray(tfloat64)), - (a_int32 // int64_3, expected, tarray(tint64)), (a_int64 // int64_3, expected, tarray(tint64)), (a_float32 // int64_3, expected, tarray(tfloat32)), (a_float64 // int64_3, expected, tarray(tfloat64)), - (int64_3 // a_int32, expected_inv, tarray(tint64)), (int64_3 // a_int64, expected_inv, tarray(tint64)), (int64_3 // a_float32, expected_inv, tarray(tfloat32)), (int64_3 // a_float64, expected_inv, tarray(tfloat64)), - (a_int32 // int64_3s, expected, tarray(tint64)), (a_int64 // int64_3s, expected, tarray(tint64)), (a_float32 // int64_3s, 
expected, tarray(tfloat32)), (a_float64 // int64_3s, expected, tarray(tfloat64)), - (a_int32 // float32_3, expected, tarray(tfloat32)), (a_int64 // float32_3, expected, tarray(tfloat32)), (a_float32 // float32_3, expected, tarray(tfloat32)), (a_float64 // float32_3, expected, tarray(tfloat64)), - (float32_3 // a_int32, expected_inv, tarray(tfloat32)), (float32_3 // a_int64, expected_inv, tarray(tfloat32)), (float32_3 // a_float32, expected_inv, tarray(tfloat32)), (float32_3 // a_float64, expected_inv, tarray(tfloat64)), - (a_int32 // float32_3s, expected, tarray(tfloat32)), (a_int64 // float32_3s, expected, tarray(tfloat32)), (a_float32 // float32_3s, expected, tarray(tfloat32)), (a_float64 // float32_3s, expected, tarray(tfloat64)), - (a_int32 // float64_3, expected, tarray(tfloat64)), (a_int64 // float64_3, expected, tarray(tfloat64)), (a_float32 // float64_3, expected, tarray(tfloat64)), (a_float64 // float64_3, expected, tarray(tfloat64)), - (float64_3 // a_int32, expected_inv, tarray(tfloat64)), (float64_3 // a_int64, expected_inv, tarray(tfloat64)), (float64_3 // a_float32, expected_inv, tarray(tfloat64)), (float64_3 // a_float64, expected_inv, tarray(tfloat64)), - (a_int32 // float64_3s, expected, tarray(tfloat64)), (a_int64 // float64_3s, expected, tarray(tfloat64)), (a_float32 // float64_3s, expected, tarray(tfloat64)), - (a_float64 // float64_3s, expected, tarray(tfloat64))]) + (a_float64 // float64_3s, expected, tarray(tfloat64)), + ]) def test_addition(self): a_int32 = hl.array([2, 4, 8, 16, hl.missing(tint32)]) a_int64 = a_int32.map(lambda x: hl.int64(x)) a_float32 = a_int32.map(lambda x: hl.float32(x)) a_float64 = a_int32.map(lambda x: hl.float64(x)) - int32_4s = hl.array([4, 4, 4, 4, hl.missing(tint32)]) + hl.array([4, 4, 4, 4, hl.missing(tint32)]) int32_3s = hl.array([3, 3, 3, 3, hl.missing(tint32)]) int64_3 = hl.int64(3) int64_3s = int32_3s.map(lambda x: hl.int64(x)) @@ -1773,68 +1970,58 @@ def test_addition(self): (a_int64 + 3, expected, tarray(tint64)), (a_float32 + 3, expected, tarray(tfloat32)), (a_float64 + 3, expected, tarray(tfloat64)), - (3 + a_int32, expected_inv, tarray(tint32)), (3 + a_int64, expected_inv, tarray(tint64)), (3 + a_float32, expected_inv, tarray(tfloat32)), (3 + a_float64, expected_inv, tarray(tfloat64)), - (a_int32 + int32_3s, expected, tarray(tint32)), (a_int64 + int32_3s, expected, tarray(tint64)), (a_float32 + int32_3s, expected, tarray(tfloat32)), (a_float64 + int32_3s, expected, tarray(tfloat64)), - (a_int32 + int64_3, expected, tarray(tint64)), (a_int64 + int64_3, expected, tarray(tint64)), (a_float32 + int64_3, expected, tarray(tfloat32)), (a_float64 + int64_3, expected, tarray(tfloat64)), - (int64_3 + a_int32, expected_inv, tarray(tint64)), (int64_3 + a_int64, expected_inv, tarray(tint64)), (int64_3 + a_float32, expected_inv, tarray(tfloat32)), (int64_3 + a_float64, expected_inv, tarray(tfloat64)), - (a_int32 + int64_3s, expected, tarray(tint64)), (a_int64 + int64_3s, expected, tarray(tint64)), (a_float32 + int64_3s, expected, tarray(tfloat32)), (a_float64 + int64_3s, expected, tarray(tfloat64)), - (a_int32 + float32_3, expected, tarray(tfloat32)), (a_int64 + float32_3, expected, tarray(tfloat32)), (a_float32 + float32_3, expected, tarray(tfloat32)), (a_float64 + float32_3, expected, tarray(tfloat64)), - (float32_3 + a_int32, expected_inv, tarray(tfloat32)), (float32_3 + a_int64, expected_inv, tarray(tfloat32)), (float32_3 + a_float32, expected_inv, tarray(tfloat32)), (float32_3 + a_float64, expected_inv, tarray(tfloat64)), - (a_int32 + 
float32_3s, expected, tarray(tfloat32)), (a_int64 + float32_3s, expected, tarray(tfloat32)), (a_float32 + float32_3s, expected, tarray(tfloat32)), (a_float64 + float32_3s, expected, tarray(tfloat64)), - (a_int32 + float64_3, expected, tarray(tfloat64)), (a_int64 + float64_3, expected, tarray(tfloat64)), (a_float32 + float64_3, expected, tarray(tfloat64)), (a_float64 + float64_3, expected, tarray(tfloat64)), - (float64_3 + a_int32, expected_inv, tarray(tfloat64)), (float64_3 + a_int64, expected_inv, tarray(tfloat64)), (float64_3 + a_float32, expected_inv, tarray(tfloat64)), (float64_3 + a_float64, expected_inv, tarray(tfloat64)), - (a_int32 + float64_3s, expected, tarray(tfloat64)), (a_int64 + float64_3s, expected, tarray(tfloat64)), (a_float32 + float64_3s, expected, tarray(tfloat64)), - (a_float64 + float64_3s, expected, tarray(tfloat64))]) + (a_float64 + float64_3s, expected, tarray(tfloat64)), + ]) def test_subtraction(self): a_int32 = hl.array([2, 4, 8, 16, hl.missing(tint32)]) a_int64 = a_int32.map(lambda x: hl.int64(x)) a_float32 = a_int32.map(lambda x: hl.float32(x)) a_float64 = a_int32.map(lambda x: hl.float64(x)) - int32_4s = hl.array([4, 4, 4, 4, hl.missing(tint32)]) + hl.array([4, 4, 4, 4, hl.missing(tint32)]) int32_3s = hl.array([3, 3, 3, 3, hl.missing(tint32)]) int64_3 = hl.int64(3) int64_3s = int32_3s.map(lambda x: hl.int64(x)) @@ -1851,68 +2038,58 @@ def test_subtraction(self): (a_int64 - 3, expected, tarray(tint64)), (a_float32 - 3, expected, tarray(tfloat32)), (a_float64 - 3, expected, tarray(tfloat64)), - (3 - a_int32, expected_inv, tarray(tint32)), (3 - a_int64, expected_inv, tarray(tint64)), (3 - a_float32, expected_inv, tarray(tfloat32)), (3 - a_float64, expected_inv, tarray(tfloat64)), - (a_int32 - int32_3s, expected, tarray(tint32)), (a_int64 - int32_3s, expected, tarray(tint64)), (a_float32 - int32_3s, expected, tarray(tfloat32)), (a_float64 - int32_3s, expected, tarray(tfloat64)), - (a_int32 - int64_3, expected, tarray(tint64)), (a_int64 - int64_3, expected, tarray(tint64)), (a_float32 - int64_3, expected, tarray(tfloat32)), (a_float64 - int64_3, expected, tarray(tfloat64)), - (int64_3 - a_int32, expected_inv, tarray(tint64)), (int64_3 - a_int64, expected_inv, tarray(tint64)), (int64_3 - a_float32, expected_inv, tarray(tfloat32)), (int64_3 - a_float64, expected_inv, tarray(tfloat64)), - (a_int32 - int64_3s, expected, tarray(tint64)), (a_int64 - int64_3s, expected, tarray(tint64)), (a_float32 - int64_3s, expected, tarray(tfloat32)), (a_float64 - int64_3s, expected, tarray(tfloat64)), - (a_int32 - float32_3, expected, tarray(tfloat32)), (a_int64 - float32_3, expected, tarray(tfloat32)), (a_float32 - float32_3, expected, tarray(tfloat32)), (a_float64 - float32_3, expected, tarray(tfloat64)), - (float32_3 - a_int32, expected_inv, tarray(tfloat32)), (float32_3 - a_int64, expected_inv, tarray(tfloat32)), (float32_3 - a_float32, expected_inv, tarray(tfloat32)), (float32_3 - a_float64, expected_inv, tarray(tfloat64)), - (a_int32 - float32_3s, expected, tarray(tfloat32)), (a_int64 - float32_3s, expected, tarray(tfloat32)), (a_float32 - float32_3s, expected, tarray(tfloat32)), (a_float64 - float32_3s, expected, tarray(tfloat64)), - (a_int32 - float64_3, expected, tarray(tfloat64)), (a_int64 - float64_3, expected, tarray(tfloat64)), (a_float32 - float64_3, expected, tarray(tfloat64)), (a_float64 - float64_3, expected, tarray(tfloat64)), - (float64_3 - a_int32, expected_inv, tarray(tfloat64)), (float64_3 - a_int64, expected_inv, tarray(tfloat64)), (float64_3 - a_float32, 
expected_inv, tarray(tfloat64)), (float64_3 - a_float64, expected_inv, tarray(tfloat64)), - (a_int32 - float64_3s, expected, tarray(tfloat64)), (a_int64 - float64_3s, expected, tarray(tfloat64)), (a_float32 - float64_3s, expected, tarray(tfloat64)), - (a_float64 - float64_3s, expected, tarray(tfloat64))]) + (a_float64 - float64_3s, expected, tarray(tfloat64)), + ]) def test_multiplication(self): a_int32 = hl.array([2, 4, 8, 16, hl.missing(tint32)]) a_int64 = a_int32.map(lambda x: hl.int64(x)) a_float32 = a_int32.map(lambda x: hl.float32(x)) a_float64 = a_int32.map(lambda x: hl.float64(x)) - int32_4s = hl.array([4, 4, 4, 4, hl.missing(tint32)]) + hl.array([4, 4, 4, 4, hl.missing(tint32)]) int32_3s = hl.array([3, 3, 3, 3, hl.missing(tint32)]) int64_3 = hl.int64(3) int64_3s = int32_3s.map(lambda x: hl.int64(x)) @@ -1929,68 +2106,58 @@ def test_multiplication(self): (a_int64 * 3, expected, tarray(tint64)), (a_float32 * 3, expected, tarray(tfloat32)), (a_float64 * 3, expected, tarray(tfloat64)), - (3 * a_int32, expected_inv, tarray(tint32)), (3 * a_int64, expected_inv, tarray(tint64)), (3 * a_float32, expected_inv, tarray(tfloat32)), (3 * a_float64, expected_inv, tarray(tfloat64)), - (a_int32 * int32_3s, expected, tarray(tint32)), (a_int64 * int32_3s, expected, tarray(tint64)), (a_float32 * int32_3s, expected, tarray(tfloat32)), (a_float64 * int32_3s, expected, tarray(tfloat64)), - (a_int32 * int64_3, expected, tarray(tint64)), (a_int64 * int64_3, expected, tarray(tint64)), (a_float32 * int64_3, expected, tarray(tfloat32)), (a_float64 * int64_3, expected, tarray(tfloat64)), - (int64_3 * a_int32, expected_inv, tarray(tint64)), (int64_3 * a_int64, expected_inv, tarray(tint64)), (int64_3 * a_float32, expected_inv, tarray(tfloat32)), (int64_3 * a_float64, expected_inv, tarray(tfloat64)), - (a_int32 * int64_3s, expected, tarray(tint64)), (a_int64 * int64_3s, expected, tarray(tint64)), (a_float32 * int64_3s, expected, tarray(tfloat32)), (a_float64 * int64_3s, expected, tarray(tfloat64)), - (a_int32 * float32_3, expected, tarray(tfloat32)), (a_int64 * float32_3, expected, tarray(tfloat32)), (a_float32 * float32_3, expected, tarray(tfloat32)), (a_float64 * float32_3, expected, tarray(tfloat64)), - (float32_3 * a_int32, expected_inv, tarray(tfloat32)), (float32_3 * a_int64, expected_inv, tarray(tfloat32)), (float32_3 * a_float32, expected_inv, tarray(tfloat32)), (float32_3 * a_float64, expected_inv, tarray(tfloat64)), - (a_int32 * float32_3s, expected, tarray(tfloat32)), (a_int64 * float32_3s, expected, tarray(tfloat32)), (a_float32 * float32_3s, expected, tarray(tfloat32)), (a_float64 * float32_3s, expected, tarray(tfloat64)), - (a_int32 * float64_3, expected, tarray(tfloat64)), (a_int64 * float64_3, expected, tarray(tfloat64)), (a_float32 * float64_3, expected, tarray(tfloat64)), (a_float64 * float64_3, expected, tarray(tfloat64)), - (float64_3 * a_int32, expected_inv, tarray(tfloat64)), (float64_3 * a_int64, expected_inv, tarray(tfloat64)), (float64_3 * a_float32, expected_inv, tarray(tfloat64)), (float64_3 * a_float64, expected_inv, tarray(tfloat64)), - (a_int32 * float64_3s, expected, tarray(tfloat64)), (a_int64 * float64_3s, expected, tarray(tfloat64)), (a_float32 * float64_3s, expected, tarray(tfloat64)), - (a_float64 * float64_3s, expected, tarray(tfloat64))]) + (a_float64 * float64_3s, expected, tarray(tfloat64)), + ]) def test_exponentiation(self): a_int32 = hl.array([2, 4, 8, 16, hl.missing(tint32)]) a_int64 = a_int32.map(lambda x: hl.int64(x)) a_float32 = a_int32.map(lambda x: hl.float32(x)) 
a_float64 = a_int32.map(lambda x: hl.float64(x)) - int32_4s = hl.array([4, 4, 4, 4, hl.missing(tint32)]) + hl.array([4, 4, 4, 4, hl.missing(tint32)]) int32_3s = hl.array([3, 3, 3, 3, hl.missing(tint32)]) int64_3 = hl.int64(3) int64_3s = int32_3s.map(lambda x: hl.int64(x)) @@ -2003,72 +2170,62 @@ def test_exponentiation(self): expected_inv = [9.0, 81.0, 6561.0, 43046721.0, None] _test_many_equal_typed([ - (a_int32 ** 3, expected, tarray(tfloat64)), - (a_int64 ** 3, expected, tarray(tfloat64)), - (a_float32 ** 3, expected, tarray(tfloat64)), - (a_float64 ** 3, expected, tarray(tfloat64)), - - (3 ** a_int32, expected_inv, tarray(tfloat64)), - (3 ** a_int64, expected_inv, tarray(tfloat64)), - (3 ** a_float32, expected_inv, tarray(tfloat64)), - (3 ** a_float64, expected_inv, tarray(tfloat64)), - - (a_int32 ** int32_3s, expected, tarray(tfloat64)), - (a_int64 ** int32_3s, expected, tarray(tfloat64)), - (a_float32 ** int32_3s, expected, tarray(tfloat64)), - (a_float64 ** int32_3s, expected, tarray(tfloat64)), - - (a_int32 ** int64_3, expected, tarray(tfloat64)), - (a_int64 ** int64_3, expected, tarray(tfloat64)), - (a_float32 ** int64_3, expected, tarray(tfloat64)), - (a_float64 ** int64_3, expected, tarray(tfloat64)), - - (int64_3 ** a_int32, expected_inv, tarray(tfloat64)), - (int64_3 ** a_int64, expected_inv, tarray(tfloat64)), - (int64_3 ** a_float32, expected_inv, tarray(tfloat64)), - (int64_3 ** a_float64, expected_inv, tarray(tfloat64)), - - (a_int32 ** int64_3s, expected, tarray(tfloat64)), - (a_int64 ** int64_3s, expected, tarray(tfloat64)), - (a_float32 ** int64_3s, expected, tarray(tfloat64)), - (a_float64 ** int64_3s, expected, tarray(tfloat64)), - - (a_int32 ** float32_3, expected, tarray(tfloat64)), - (a_int64 ** float32_3, expected, tarray(tfloat64)), - (a_float32 ** float32_3, expected, tarray(tfloat64)), - (a_float64 ** float32_3, expected, tarray(tfloat64)), - - (float32_3 ** a_int32, expected_inv, tarray(tfloat64)), - (float32_3 ** a_int64, expected_inv, tarray(tfloat64)), - (float32_3 ** a_float32, expected_inv, tarray(tfloat64)), - (float32_3 ** a_float64, expected_inv, tarray(tfloat64)), - - (a_int32 ** float32_3s, expected, tarray(tfloat64)), - (a_int64 ** float32_3s, expected, tarray(tfloat64)), - (a_float32 ** float32_3s, expected, tarray(tfloat64)), - (a_float64 ** float32_3s, expected, tarray(tfloat64)), - - (a_int32 ** float64_3, expected, tarray(tfloat64)), - (a_int64 ** float64_3, expected, tarray(tfloat64)), - (a_float32 ** float64_3, expected, tarray(tfloat64)), - (a_float64 ** float64_3, expected, tarray(tfloat64)), - - (float64_3 ** a_int32, expected_inv, tarray(tfloat64)), - (float64_3 ** a_int64, expected_inv, tarray(tfloat64)), - (float64_3 ** a_float32, expected_inv, tarray(tfloat64)), - (float64_3 ** a_float64, expected_inv, tarray(tfloat64)), - - (a_int32 ** float64_3s, expected, tarray(tfloat64)), - (a_int64 ** float64_3s, expected, tarray(tfloat64)), - (a_float32 ** float64_3s, expected, tarray(tfloat64)), - (a_float64 ** float64_3s, expected, tarray(tfloat64))]) + (a_int32**3, expected, tarray(tfloat64)), + (a_int64**3, expected, tarray(tfloat64)), + (a_float32**3, expected, tarray(tfloat64)), + (a_float64**3, expected, tarray(tfloat64)), + (3**a_int32, expected_inv, tarray(tfloat64)), + (3**a_int64, expected_inv, tarray(tfloat64)), + (3**a_float32, expected_inv, tarray(tfloat64)), + (3**a_float64, expected_inv, tarray(tfloat64)), + (a_int32**int32_3s, expected, tarray(tfloat64)), + (a_int64**int32_3s, expected, tarray(tfloat64)), + (a_float32**int32_3s, 
expected, tarray(tfloat64)), + (a_float64**int32_3s, expected, tarray(tfloat64)), + (a_int32**int64_3, expected, tarray(tfloat64)), + (a_int64**int64_3, expected, tarray(tfloat64)), + (a_float32**int64_3, expected, tarray(tfloat64)), + (a_float64**int64_3, expected, tarray(tfloat64)), + (int64_3**a_int32, expected_inv, tarray(tfloat64)), + (int64_3**a_int64, expected_inv, tarray(tfloat64)), + (int64_3**a_float32, expected_inv, tarray(tfloat64)), + (int64_3**a_float64, expected_inv, tarray(tfloat64)), + (a_int32**int64_3s, expected, tarray(tfloat64)), + (a_int64**int64_3s, expected, tarray(tfloat64)), + (a_float32**int64_3s, expected, tarray(tfloat64)), + (a_float64**int64_3s, expected, tarray(tfloat64)), + (a_int32**float32_3, expected, tarray(tfloat64)), + (a_int64**float32_3, expected, tarray(tfloat64)), + (a_float32**float32_3, expected, tarray(tfloat64)), + (a_float64**float32_3, expected, tarray(tfloat64)), + (float32_3**a_int32, expected_inv, tarray(tfloat64)), + (float32_3**a_int64, expected_inv, tarray(tfloat64)), + (float32_3**a_float32, expected_inv, tarray(tfloat64)), + (float32_3**a_float64, expected_inv, tarray(tfloat64)), + (a_int32**float32_3s, expected, tarray(tfloat64)), + (a_int64**float32_3s, expected, tarray(tfloat64)), + (a_float32**float32_3s, expected, tarray(tfloat64)), + (a_float64**float32_3s, expected, tarray(tfloat64)), + (a_int32**float64_3, expected, tarray(tfloat64)), + (a_int64**float64_3, expected, tarray(tfloat64)), + (a_float32**float64_3, expected, tarray(tfloat64)), + (a_float64**float64_3, expected, tarray(tfloat64)), + (float64_3**a_int32, expected_inv, tarray(tfloat64)), + (float64_3**a_int64, expected_inv, tarray(tfloat64)), + (float64_3**a_float32, expected_inv, tarray(tfloat64)), + (float64_3**a_float64, expected_inv, tarray(tfloat64)), + (a_int32**float64_3s, expected, tarray(tfloat64)), + (a_int64**float64_3s, expected, tarray(tfloat64)), + (a_float32**float64_3s, expected, tarray(tfloat64)), + (a_float64**float64_3s, expected, tarray(tfloat64)), + ]) def test_modulus(self): a_int32 = hl.array([2, 4, 8, 16, hl.missing(tint32)]) a_int64 = a_int32.map(lambda x: hl.int64(x)) a_float32 = a_int32.map(lambda x: hl.float32(x)) a_float64 = a_int32.map(lambda x: hl.float64(x)) - int32_4s = hl.array([4, 4, 4, 4, hl.missing(tint32)]) + hl.array([4, 4, 4, 4, hl.missing(tint32)]) int32_3s = hl.array([3, 3, 3, 3, hl.missing(tint32)]) int64_3 = hl.int64(3) int64_3s = int32_3s.map(lambda x: hl.int64(x)) @@ -2085,61 +2242,51 @@ def test_modulus(self): (a_int64 % 3, expected, tarray(tint64)), (a_float32 % 3, expected, tarray(tfloat32)), (a_float64 % 3, expected, tarray(tfloat64)), - (3 % a_int32, expected_inv, tarray(tint32)), (3 % a_int64, expected_inv, tarray(tint64)), (3 % a_float32, expected_inv, tarray(tfloat32)), (3 % a_float64, expected_inv, tarray(tfloat64)), - (a_int32 % int32_3s, expected, tarray(tint32)), (a_int64 % int32_3s, expected, tarray(tint64)), (a_float32 % int32_3s, expected, tarray(tfloat32)), (a_float64 % int32_3s, expected, tarray(tfloat64)), - (a_int32 % int64_3, expected, tarray(tint64)), (a_int64 % int64_3, expected, tarray(tint64)), (a_float32 % int64_3, expected, tarray(tfloat32)), (a_float64 % int64_3, expected, tarray(tfloat64)), - (int64_3 % a_int32, expected_inv, tarray(tint64)), (int64_3 % a_int64, expected_inv, tarray(tint64)), (int64_3 % a_float32, expected_inv, tarray(tfloat32)), (int64_3 % a_float64, expected_inv, tarray(tfloat64)), - (a_int32 % int64_3s, expected, tarray(tint64)), (a_int64 % int64_3s, expected, 
tarray(tint64)), (a_float32 % int64_3s, expected, tarray(tfloat32)), (a_float64 % int64_3s, expected, tarray(tfloat64)), - (a_int32 % float32_3, expected, tarray(tfloat32)), (a_int64 % float32_3, expected, tarray(tfloat32)), (a_float32 % float32_3, expected, tarray(tfloat32)), (a_float64 % float32_3, expected, tarray(tfloat64)), - (float32_3 % a_int32, expected_inv, tarray(tfloat32)), (float32_3 % a_int64, expected_inv, tarray(tfloat32)), (float32_3 % a_float32, expected_inv, tarray(tfloat32)), (float32_3 % a_float64, expected_inv, tarray(tfloat64)), - (a_int32 % float32_3s, expected, tarray(tfloat32)), (a_int64 % float32_3s, expected, tarray(tfloat32)), (a_float32 % float32_3s, expected, tarray(tfloat32)), (a_float64 % float32_3s, expected, tarray(tfloat64)), - (a_int32 % float64_3, expected, tarray(tfloat64)), (a_int64 % float64_3, expected, tarray(tfloat64)), (a_float32 % float64_3, expected, tarray(tfloat64)), (a_float64 % float64_3, expected, tarray(tfloat64)), - (float64_3 % a_int32, expected_inv, tarray(tfloat64)), (float64_3 % a_int64, expected_inv, tarray(tfloat64)), (float64_3 % a_float32, expected_inv, tarray(tfloat64)), (float64_3 % a_float64, expected_inv, tarray(tfloat64)), - (a_int32 % float64_3s, expected, tarray(tfloat64)), (a_int64 % float64_3s, expected, tarray(tfloat64)), (a_float32 % float64_3s, expected, tarray(tfloat64)), - (a_float64 % float64_3s, expected, tarray(tfloat64))]) + (a_float64 % float64_3s, expected, tarray(tfloat64)), + ]) def test_comparisons(self): f0 = hl.float(0.0) @@ -2151,16 +2298,14 @@ def test_comparisons(self): (f0 == fnull, None, tbool), (f0 < fnull, None, tbool), (f0 != fnull, None, tbool), - - (fnan == fnan, False, tbool), - (f0 == f0, True, tbool), - (finf == finf, True, tbool), - + (fnan == fnan, False, tbool), # noqa: PLR0124 + (f0 == f0, True, tbool), # noqa: PLR0124 + (finf == finf, True, tbool), # noqa: PLR0124 (f0 < finf, True, tbool), (f0 > finf, False, tbool), - (fnan <= finf, False, tbool), - (fnan >= finf, False, tbool)]) + (fnan >= finf, False, tbool), + ]) def test_bools_can_math(self): b1 = hl.literal(True) @@ -2181,12 +2326,11 @@ def test_bools_can_math(self): (b1 / b1, 1.0), (f1 * b2, 0.0), (b_array + f1, [6.5, 5.5]), - (b_array + f_array, [2.5, 2.5])]) + (b_array + f_array, [2.5, 2.5]), + ]) def test_int_typecheck(self): - _test_many_equal([ - (hl.literal(None, dtype='int32'), None), - (hl.literal(None, dtype='int64'), None)]) + _test_many_equal([(hl.literal(None, dtype='int32'), None), (hl.literal(None, dtype='int64'), None)]) def test_is_transition(self): _test_many_equal([ @@ -2195,7 +2339,8 @@ def test_is_transition(self): (hl.is_transition("AA", "AG"), True), (hl.is_transition("AA", "G"), False), (hl.is_transition("ACA", "AGA"), False), - (hl.is_transition("A", "T"), False)]) + (hl.is_transition("A", "T"), False), + ]) def test_is_transversion(self): _test_many_equal([ @@ -2203,7 +2348,8 @@ def test_is_transversion(self): (hl.is_transversion("A", "G"), False), (hl.is_transversion("AA", "AT"), True), (hl.is_transversion("AA", "T"), False), - (hl.is_transversion("ACCC", "ACCT"), False)]) + (hl.is_transversion("ACCC", "ACCT"), False), + ]) def test_is_snp(self): _test_many_equal([ @@ -2212,17 +2358,14 @@ def test_is_snp(self): (hl.is_snp("C", "G"), True), (hl.is_snp("CC", "CG"), True), (hl.is_snp("AT", "AG"), True), - (hl.is_snp("ATCCC", "AGCCC"), True)]) + (hl.is_snp("ATCCC", "AGCCC"), True), + ]) def test_is_mnp(self): - _test_many_equal([ - (hl.is_mnp("ACTGAC", "ATTGTT"), True), - (hl.is_mnp("CA", "TT"), True)]) + 
_test_many_equal([(hl.is_mnp("ACTGAC", "ATTGTT"), True), (hl.is_mnp("CA", "TT"), True)]) def test_is_insertion(self): - _test_many_equal([ - (hl.is_insertion("A", "ATGC"), True), - (hl.is_insertion("ATT", "ATGCTT"), True)]) + _test_many_equal([(hl.is_insertion("A", "ATGC"), True), (hl.is_insertion("ATT", "ATGCTT"), True)]) def test_is_deletion(self): self.assertTrue(hl.eval(hl.is_deletion("ATGC", "A"))) @@ -2248,27 +2391,29 @@ def test_is_strand_ambiguous(self): def test_allele_type(self): self.assertEqual( - hl.eval(hl.tuple(( - hl.allele_type('A', 'C'), - hl.allele_type('AC', 'CT'), - hl.allele_type('C', 'CT'), - hl.allele_type('CT', 'C'), - hl.allele_type('CTCA', 'AAC'), - hl.allele_type('CTCA', '*'), - hl.allele_type('C', ''), - hl.allele_type('C', ''), - hl.allele_type('C', 'H'), - hl.allele_type('C', ''), - hl.allele_type('A', 'A'), - hl.allele_type('', 'CCT'), - hl.allele_type('F', 'CCT'), - hl.allele_type('A', '[ASDASD[A'), - hl.allele_type('A', ']ASDASD]A'), - hl.allele_type('A', 'T]ASDASD]'), - hl.allele_type('A', 'T[ASDASD['), - hl.allele_type('A', '.T'), - hl.allele_type('A', 'T.'), - ))), + hl.eval( + hl.tuple(( + hl.allele_type('A', 'C'), + hl.allele_type('AC', 'CT'), + hl.allele_type('C', 'CT'), + hl.allele_type('CT', 'C'), + hl.allele_type('CTCA', 'AAC'), + hl.allele_type('CTCA', '*'), + hl.allele_type('C', ''), + hl.allele_type('C', ''), + hl.allele_type('C', 'H'), + hl.allele_type('C', ''), + hl.allele_type('A', 'A'), + hl.allele_type('', 'CCT'), + hl.allele_type('F', 'CCT'), + hl.allele_type('A', '[ASDASD[A'), + hl.allele_type('A', ']ASDASD]A'), + hl.allele_type('A', 'T]ASDASD]'), + hl.allele_type('A', 'T[ASDASD['), + hl.allele_type('A', '.T'), + hl.allele_type('A', 'T.'), + )) + ), ( 'SNP', 'MNP', @@ -2289,14 +2434,15 @@ def test_allele_type(self): 'Symbolic', 'Symbolic', 'Symbolic', - ) + ), ) def test_hamming(self): _test_many_equal([ (hl.hamming('A', 'T'), 1), (hl.hamming('AAAAA', 'AAAAT'), 1), - (hl.hamming('abcde', 'edcba'), 4)]) + (hl.hamming('abcde', 'edcba'), 4), + ]) def test_gp_dosage(self): self.assertAlmostEqual(hl.eval(hl.gp_dosage([1.0, 0.0, 0.0])), 0.0) @@ -2327,59 +2473,49 @@ def test_call(self): (c2_homref[1], 0, tint32), (c2_homref.phased, False, tbool), (c2_homref.is_hom_ref(), True, tbool), - (c2_het.ploidy, 2, tint32), (c2_het[0], 1, tint32), (c2_het[1], 0, tint32), (c2_het.phased, True, tbool), (c2_het.is_het(), True, tbool), - (c2_homvar.ploidy, 2, tint32), (c2_homvar[0], 1, tint32), (c2_homvar[1], 1, tint32), (c2_homvar.phased, False, tbool), (c2_homvar.is_hom_var(), True, tbool), (c2_homvar.unphased_diploid_gt_index(), 2, tint32), - (c2_hetvar.ploidy, 2, tint32), (c2_hetvar[0], 2, tint32), (c2_hetvar[1], 1, tint32), (c2_hetvar.phased, True, tbool), (c2_hetvar.is_hom_var(), False, tbool), (c2_hetvar.is_het_non_ref(), True, tbool), - (c1.ploidy, 1, tint32), (c1[0], 1, tint32), (c1.phased, False, tbool), (c1.is_hom_var(), True, tbool), - (c0.ploidy, 0, tint32), (c0.phased, False, tbool), (c0.is_hom_var(), False, tbool), - (cNull.ploidy, None, tint32), (cNull[0], None, tint32), (cNull.phased, None, tbool), (cNull.is_hom_var(), None, tbool), - (call_expr_1[0], 1, tint32), (call_expr_1[1], 2, tint32), (call_expr_1.ploidy, 2, tint32), - (call_expr_2[0], 1, tint32), (call_expr_2[1], 2, tint32), (call_expr_2.ploidy, 2, tint32), - (call_expr_3[0], 1, tint32), (call_expr_3[1], 2, tint32), (call_expr_3.ploidy, 2, tint32), - (call_expr_4[0], 1, tint32), (call_expr_4[1], 1, tint32), - (call_expr_4.ploidy, 2, tint32)]) + (call_expr_4.ploidy, 2, tint32), + ]) 
def test_call_unphase(self): - calls = [ hl.Call([0], phased=True), hl.Call([0], phased=False), @@ -2407,26 +2543,31 @@ def test_call_unphase(self): def test_call_contains_allele(self): c1 = hl.call(1, phased=True) c2 = hl.call(1, phased=False) - c3 = hl.call(3,1, phased=True) - c4 = hl.call(1,3, phased=False) - - for i, b in enumerate(hl.eval(tuple([ - c1.contains_allele(1), - ~c1.contains_allele(0), - ~c1.contains_allele(2), - c2.contains_allele(1), - ~c2.contains_allele(0), - ~c2.contains_allele(2), - c3.contains_allele(1), - c3.contains_allele(3), - ~c3.contains_allele(0), - ~c3.contains_allele(2), - c4.contains_allele(1), - c4.contains_allele(3), - ~c4.contains_allele(0), - ~c4.contains_allele(2), - ]))): + c3 = hl.call(3, 1, phased=True) + c4 = hl.call(1, 3, phased=False) + + for i, b in enumerate( + hl.eval( + tuple([ + c1.contains_allele(1), + ~c1.contains_allele(0), + ~c1.contains_allele(2), + c2.contains_allele(1), + ~c2.contains_allele(0), + ~c2.contains_allele(2), + c3.contains_allele(1), + c3.contains_allele(3), + ~c3.contains_allele(0), + ~c3.contains_allele(2), + c4.contains_allele(1), + c4.contains_allele(3), + ~c4.contains_allele(0), + ~c4.contains_allele(2), + ]) + ) + ): assert b, i + def test_call_unphase_diploid_gt_index(self): calls_and_indices = [ (hl.call(0, 0), 0), @@ -2442,44 +2583,41 @@ def test_call_unphase_diploid_gt_index(self): assert hl.eval(gt_idx) == tuple(i for c, i in calls_and_indices) def test_parse_variant(self): - self.assertEqual(hl.eval(hl.parse_variant('1:1:A:T')), - hl.Struct(locus=hl.Locus('1', 1), alleles=['A', 'T'])) + self.assertEqual(hl.eval(hl.parse_variant('1:1:A:T')), hl.Struct(locus=hl.Locus('1', 1), alleles=['A', 'T'])) def test_locus_to_global_position(self): self.assertEqual(hl.eval(hl.locus('chr22', 1, 'GRCh38').global_position()), 2824183054) def test_locus_from_global_position(self): - self.assertEqual(hl.eval(hl.locus_from_global_position(2824183054, 'GRCh38')), - hl.eval(hl.locus('chr22', 1, 'GRCh38'))) + self.assertEqual( + hl.eval(hl.locus_from_global_position(2824183054, 'GRCh38')), hl.eval(hl.locus('chr22', 1, 'GRCh38')) + ) def test_locus_window(self): locus = hl.Locus('22', 123456, reference_genome='GRCh37') lit = hl.literal(locus) - results = hl.eval(hl.struct( - zeros=lit.window(0, 0), - ones=lit.window(1, 1), - big_windows=lit.window(1_000_000_000, 1_000_000_000) - )) + results = hl.eval( + hl.struct( + zeros=lit.window(0, 0), ones=lit.window(1, 1), big_windows=lit.window(1_000_000_000, 1_000_000_000) + ) + ) pt = hl.tinterval(hl.tlocus('GRCh37')) - assert results.zeros == hl.Interval(hl.Locus('22', 123456), - hl.Locus('22', 123456), - includes_start=True, - includes_end=True, - point_type=pt) - assert results.ones == hl.Interval(hl.Locus('22', 123455), - hl.Locus('22', 123457), - includes_start=True, - includes_end=True, - point_type=pt) - assert results.big_windows == hl.Interval(hl.Locus('22', 1), - hl.Locus('22', hl.get_reference('GRCh37').contig_length('22')), - includes_start=True, - includes_end=True, - point_type=pt) - + assert results.zeros == hl.Interval( + hl.Locus('22', 123456), hl.Locus('22', 123456), includes_start=True, includes_end=True, point_type=pt + ) + assert results.ones == hl.Interval( + hl.Locus('22', 123455), hl.Locus('22', 123457), includes_start=True, includes_end=True, point_type=pt + ) + assert results.big_windows == hl.Interval( + hl.Locus('22', 1), + hl.Locus('22', hl.get_reference('GRCh37').contig_length('22')), + includes_start=True, + includes_end=True, + point_type=pt, + ) def 
test_dict_conversions(self): self.assertEqual(sorted(hl.eval(hl.array({1: 1, 2: 2}))), [(1, 1), (2, 2)]) @@ -2491,17 +2629,17 @@ def test_dict_conversions(self): self.assertEqual(hl.eval(hl.dict({('1', 2), (hl.missing(tstr), 3)})), {'1': 2, None: 3}) def test_zip(self): - a1 = [1,2,3] + a1 = [1, 2, 3] a2 = ['a', 'b'] a3 = [[1]] self.assertEqual(hl.eval(hl.zip(a1, a2)), [(1, 'a'), (2, 'b')]) self.assertEqual(hl.eval(hl.zip(a1, a2, fill_missing=True)), [(1, 'a'), (2, 'b'), (3, None)]) - self.assertEqual(hl.eval(hl.zip(a3, a2, a1, fill_missing=True)), - [([1], 'a', 1), (None, 'b', 2), (None, None, 3)]) - self.assertEqual(hl.eval(hl.zip(a3, a2, a1)), - [([1], 'a', 1)]) + self.assertEqual( + hl.eval(hl.zip(a3, a2, a1, fill_missing=True)), [([1], 'a', 1), (None, 'b', 2), (None, None, 3)] + ) + self.assertEqual(hl.eval(hl.zip(a3, a2, a1)), [([1], 'a', 1)]) def test_any_form_1(self): self.assertEqual(hl.eval(hl.any()), False) @@ -2556,26 +2694,37 @@ def test_all_form_3(self): self.assertEqual(hl.eval(hl.all(lambda x: x % 2 == 0, [2, 6])), True) def test_array_methods(self): - _test_many_equal([ - (hl.map(lambda x: x % 2 == 0, [0, 1, 4, 6]), [True, False, True, True]), - (hl.len([0, 1, 4, 6]), 4), - (math.isnan(hl.eval(hl.mean(hl.empty_array(hl.tint)))), True), - (hl.mean([0, 1, 4, 6, hl.missing(tint32)]), 2.75), - (hl.median(hl.empty_array(hl.tint)), None), - (1 <= hl.eval(hl.median([0, 1, 4, 6])) <= 4, True) - ] + [test - for f in [lambda x: hl.int32(x), lambda x: hl.int64(x), lambda x: hl.float32(x), lambda x: hl.float64(x)] - for test in [(hl.product([f(x) for x in [1, 4, 6]]), 24), - (hl.sum([f(x) for x in [1, 4, 6]]), 11)] - ] + [ - (hl.group_by(lambda x: x % 2 == 0, [0, 1, 4, 6]), {True: [0, 4, 6], False: [1]}), - (hl.flatmap(lambda x: hl.range(0, x), [1, 2, 3]), [0, 0, 1, 0, 1, 2]), - (hl.flatmap(lambda x: hl.set(hl.range(0, x.length()).map(lambda i: x[i])), {"ABC", "AAa", "BD"}), - {'A', 'a', 'B', 'C', 'D'}) - ]) + _test_many_equal( + [ + (hl.map(lambda x: x % 2 == 0, [0, 1, 4, 6]), [True, False, True, True]), + (hl.len([0, 1, 4, 6]), 4), + (math.isnan(hl.eval(hl.mean(hl.empty_array(hl.tint)))), True), + (hl.mean([0, 1, 4, 6, hl.missing(tint32)]), 2.75), + (hl.median(hl.empty_array(hl.tint)), None), + (1 <= hl.eval(hl.median([0, 1, 4, 6])) <= 4, True), + ] + + [ + test + for f in [ + lambda x: hl.int32(x), + lambda x: hl.int64(x), + lambda x: hl.float32(x), + lambda x: hl.float64(x), + ] + for test in [(hl.product([f(x) for x in [1, 4, 6]]), 24), (hl.sum([f(x) for x in [1, 4, 6]]), 11)] + ] + + [ + (hl.group_by(lambda x: x % 2 == 0, [0, 1, 4, 6]), {True: [0, 4, 6], False: [1]}), + (hl.flatmap(lambda x: hl.range(0, x), [1, 2, 3]), [0, 0, 1, 0, 1, 2]), + ( + hl.flatmap(lambda x: hl.set(hl.range(0, x.length()).map(lambda i: x[i])), {"ABC", "AAa", "BD"}), + {'A', 'a', 'B', 'C', 'D'}, + ), + ] + ) def test_starmap(self): - self.assertEqual(hl.eval(hl.array([(1, 2), (2, 3)]).starmap(lambda x,y: x+y)), [3, 5]) + self.assertEqual(hl.eval(hl.array([(1, 2), (2, 3)]).starmap(lambda x, y: x + y)), [3, 5]) def test_array_corr(self): x1 = [random.uniform(-10, 10) for x in range(10)] @@ -2585,8 +2734,10 @@ def test_array_corr(self): def test_array_corr_missingness(self): x1 = [None, None, 5.0] + [random.uniform(-10, 10) for x in range(15)] x2 = [None, 5.0, None] + [random.uniform(-10, 10) for x in range(15)] - self.assertAlmostEqual(hl.eval(hl.corr(hl.literal(x1, 'array'), hl.literal(x2, 'array'))), - pearsonr(x1[3:], x2[3:])[0]) + self.assertAlmostEqual( + hl.eval(hl.corr(hl.literal(x1, 'array'), 
hl.literal(x2, 'array'))), + pearsonr(x1[3:], x2[3:])[0], + ) def test_array_grouped(self): x = hl.array([0, 1, 2, 3, 4]) @@ -2612,23 +2763,39 @@ def test_sorted(self): self.assertEqual(hl.eval(hl.sorted([0, 1, 4, 3, 2], lambda x: x % 2, reverse=True)), [1, 3, 0, 4, 2]) self.assertEqual(hl.eval(hl.sorted([0, 1, 4, hl.missing(tint), 3, 2], lambda x: x)), [0, 1, 2, 3, 4, None]) - self.assertEqual(hl.sorted([0, 1, 4, hl.missing(tint), 3, 2], lambda x: x, reverse=True).collect()[0], [4, 3, 2, 1, 0, None]) - self.assertEqual(hl.eval(hl.sorted([0, 1, 4, hl.missing(tint), 3, 2], lambda x: x, reverse=True)), [4, 3, 2, 1, 0, None]) + self.assertEqual( + hl.sorted([0, 1, 4, hl.missing(tint), 3, 2], lambda x: x, reverse=True).collect()[0], [4, 3, 2, 1, 0, None] + ) + self.assertEqual( + hl.eval(hl.sorted([0, 1, 4, hl.missing(tint), 3, 2], lambda x: x, reverse=True)), [4, 3, 2, 1, 0, None] + ) self.assertEqual(hl.eval(hl.sorted({0, 1, 4, 3, 2})), [0, 1, 2, 3, 4]) self.assertEqual(hl.eval(hl.sorted({"foo": 1, "bar": 2})), [("bar", 2), ("foo", 1)]) def test_sort_by(self): - self.assertEqual(hl.eval(hl._sort_by(["c", "aaa", "bb", hl.missing(hl.tstr)], lambda l, r: hl.len(l) < hl.len(r))), ["c", "bb", "aaa", None]) - self.assertEqual(hl.eval(hl._sort_by([hl.Struct(x=i, y="foo", z=5.5) for i in [5, 3, 8, 2, 5]], lambda l, r: l.x < r.x)), - [hl.Struct(x=i, y="foo", z=5.5) for i in [2, 3, 5, 5, 8]]) + self.assertEqual( + hl.eval(hl._sort_by(["c", "aaa", "bb", hl.missing(hl.tstr)], lambda l, r: hl.len(l) < hl.len(r))), + ["c", "bb", "aaa", None], + ) + self.assertEqual( + hl.eval(hl._sort_by([hl.Struct(x=i, y="foo", z=5.5) for i in [5, 3, 8, 2, 5]], lambda l, r: l.x < r.x)), + [hl.Struct(x=i, y="foo", z=5.5) for i in [2, 3, 5, 5, 8]], + ) with self.assertRaises(hl.utils.java.FatalError): - self.assertEqual(hl.eval(hl._sort_by([hl.Struct(x=i, y="foo", z=5.5) for i in [5, 3, 8, 2, 5, hl.missing(hl.tint32)]], lambda l, r: l.x < r.x)), - [hl.Struct(x=i, y="foo", z=5.5) for i in [2, 3, 5, 5, 8, None]]) + self.assertEqual( + hl.eval( + hl._sort_by( + [hl.Struct(x=i, y="foo", z=5.5) for i in [5, 3, 8, 2, 5, hl.missing(hl.tint32)]], + lambda l, r: l.x < r.x, + ) + ), + [hl.Struct(x=i, y="foo", z=5.5) for i in [2, 3, 5, 5, 8, None]], + ) def test_array_first(self): - a = hl.array([1,2,3]) + a = hl.array([1, 2, 3]) assert hl.eval(a.first()) == 1 assert hl.eval(a.filter(lambda x: x > 5).first()) is None @@ -2638,7 +2805,7 @@ def test_array_last(self): assert hl.eval(a.filter(lambda x: x > 5).last()) is None def test_array_index(self): - a = hl.array([1,2,3]) + a = hl.array([1, 2, 3]) assert hl.eval(a.index(2) == 1) assert hl.eval(a.index(4)) is None assert hl.eval(a.index(lambda x: x % 2 == 0) == 1) @@ -2667,7 +2834,18 @@ def test_max(self): (hl.max(0, 1.0, 2), 2.0), (hl.nanmax(0, 1.0, 2), 2.0), (hl.max(0, 1, 2), 2), - (hl.max([0, 10, 2, 3, 4, 5, 6, ]), 10), + ( + hl.max([ + 0, + 10, + 2, + 3, + 4, + 5, + 6, + ]), + 10, + ), (hl.max(0, 10, 2, 3, 4, 5, 6), 10), (hl.max([-5, -4, hl.missing(tint32), -3, -2, hl.missing(tint32)]), -2), (hl.max([float('nan'), -4, float('nan'), -3, -2, hl.missing(tint32)]), float('nan')), @@ -2700,10 +2878,8 @@ def test_max(self): actual = r[i] expected = exprs_and_results[i][1] assert actual == expected or ( - actual is not None - and expected is not None - and (math.isnan(actual) and math.isnan(expected))), \ - f'{i}: {actual}, {expected}' + actual is not None and expected is not None and (math.isnan(actual) and math.isnan(expected)) + ), f'{i}: {actual}, {expected}' def test_min(self): 
exprs_and_results = [ @@ -2747,10 +2923,8 @@ def test_min(self): actual = r[i] expected = exprs_and_results[i][1] assert actual == expected or ( - actual is not None - and expected is not None - and (math.isnan(actual) and math.isnan(expected))), \ - f'{i}: {actual}, {expected}' + actual is not None and expected is not None and (math.isnan(actual) and math.isnan(expected)) + ), f'{i}: {actual}, {expected}' def test_abs(self): self.assertEqual(hl.eval(hl.abs(-5)), 5) @@ -2784,7 +2958,9 @@ def test_show_row_key_regression(self): def test_show_expression(self): ds = hl.utils.range_matrix_table(3, 3) result = ds.col_idx.show(handler=str) - assert result == '''+---------+ + assert ( + result + == """+---------+ | col_idx | +---------+ | int32 | @@ -2793,23 +2969,24 @@ def test_show_expression(self): | 1 | | 2 | +---------+ -''' +""" + ) @test_timeout(4 * 60) def test_export_genetic_data(self): mt = hl.balding_nichols_model(1, 3, 3) - mt = mt.key_cols_by(s = 's' + hl.str(mt.sample_idx)) + mt = mt.key_cols_by(s='s' + hl.str(mt.sample_idx)) with hl.TemporaryFilename() as f: mt.GT.export(f) - actual = hl.import_matrix_table(f, - row_fields={'locus': hl.tstr, - 'alleles': hl.tstr}, - row_key=['locus', 'alleles'], - entry_type=hl.tstr) + actual = hl.import_matrix_table( + f, row_fields={'locus': hl.tstr, 'alleles': hl.tstr}, row_key=['locus', 'alleles'], entry_type=hl.tstr + ) actual = actual.rename({'col_id': 's'}) - actual = actual.key_rows_by(locus = hl.parse_locus(actual.locus), - alleles = actual.alleles.replace('"', '').replace(r'\[', '').replace(r'\]', '').split(',')) - actual = actual.transmute_entries(GT = hl.parse_call(actual.x)) + actual = actual.key_rows_by( + locus=hl.parse_locus(actual.locus), + alleles=actual.alleles.replace('"', '').replace(r'\[', '').replace(r'\]', '').split(','), + ) + actual = actual.transmute_entries(GT=hl.parse_call(actual.x)) expected = mt.select_cols().select_globals().select_rows() expected.show() actual.show() @@ -2859,8 +3036,7 @@ def test_interval_ops(self): self.assertTrue(hl.eval_typed(interval.overlaps(hl.interval(5, 9))) == (True, hl.tbool)) li = hl.parse_locus_interval('1:100-110') - self.assertEqual(hl.eval(li), hl.utils.Interval(hl.genetics.Locus("1", 100), - hl.genetics.Locus("1", 110))) + self.assertEqual(hl.eval(li), hl.utils.Interval(hl.genetics.Locus("1", 100), hl.genetics.Locus("1", 110))) self.assertTrue(li.dtype.point_type == hl.tlocus()) self.assertTrue(hl.eval(li.contains(hl.locus("1", 100)))) self.assertTrue(hl.eval(li.contains(hl.locus("1", 109)))) @@ -2877,37 +3053,40 @@ def test_interval_ops(self): self.assertFalse(hl.eval(li.overlaps(li5))) def test_locus_interval_constructors(self): - li_contig_start = hl.locus_interval('1', 0, 2, False, False, - invalid_missing=True) - self.assertTrue(hl.eval(li_contig_start) == hl.utils.Interval( - hl.genetics.Locus("1", 1), - hl.genetics.Locus("1", 2), - includes_start=True, - includes_end=False)) - - li_contig_middle1 = hl.locus_interval('1', 100, 100, True, False, - invalid_missing=True) - self.assertTrue(hl.eval(li_contig_middle1) == hl.utils.Interval( - hl.genetics.Locus("1", 99), - hl.genetics.Locus("1", 100), - includes_start=False, - includes_end=False)) - - li_contig_middle2 = hl.locus_interval('1', 100, 100, False, True, - invalid_missing=True) - self.assertTrue(hl.eval(li_contig_middle2) == hl.utils.Interval( - hl.genetics.Locus("1", 100), - hl.genetics.Locus("1", 101), - includes_start=False, - includes_end=False)) - - li_contig_end = hl.locus_interval('1', 249250621, 249250622, True, - 
False, invalid_missing=True) - self.assertTrue(hl.eval(li_contig_end) == hl.utils.Interval( - hl.genetics.Locus("1", 249250621), - hl.genetics.Locus("1", 249250621), - includes_start=True, - includes_end=True)) + li_contig_start = hl.locus_interval('1', 0, 2, False, False, invalid_missing=True) + self.assertTrue( + hl.eval(li_contig_start) + == hl.utils.Interval( + hl.genetics.Locus("1", 1), hl.genetics.Locus("1", 2), includes_start=True, includes_end=False + ) + ) + + li_contig_middle1 = hl.locus_interval('1', 100, 100, True, False, invalid_missing=True) + self.assertTrue( + hl.eval(li_contig_middle1) + == hl.utils.Interval( + hl.genetics.Locus("1", 99), hl.genetics.Locus("1", 100), includes_start=False, includes_end=False + ) + ) + + li_contig_middle2 = hl.locus_interval('1', 100, 100, False, True, invalid_missing=True) + self.assertTrue( + hl.eval(li_contig_middle2) + == hl.utils.Interval( + hl.genetics.Locus("1", 100), hl.genetics.Locus("1", 101), includes_start=False, includes_end=False + ) + ) + + li_contig_end = hl.locus_interval('1', 249250621, 249250622, True, False, invalid_missing=True) + self.assertTrue( + hl.eval(li_contig_end) + == hl.utils.Interval( + hl.genetics.Locus("1", 249250621), + hl.genetics.Locus("1", 249250621), + includes_start=True, + includes_end=True, + ) + ) li1 = hl.locus_interval('1', 0, 1, False, False, invalid_missing=True) li2 = hl.locus_interval('1', 0, 1, True, False, invalid_missing=True) @@ -2944,121 +3123,170 @@ def test_reference_genome_fns(self): @test_timeout(batch=5 * 60) def test_initop_table(self): - t = (hl.utils.range_table(5, 3) - .annotate(GT=hl.call(0, 1)) - .annotate_globals(alleles=["A", "T"])) + t = hl.utils.range_table(5, 3).annotate(GT=hl.call(0, 1)).annotate_globals(alleles=["A", "T"]) - self.assertTrue(t.aggregate(hl.agg.call_stats(t.GT, t.alleles)) == - hl.Struct(AC=[5, 5], AF=[0.5, 0.5], AN=10, homozygote_count=[0, 0])) # Tests table.aggregate initOp + self.assertTrue( + t.aggregate(hl.agg.call_stats(t.GT, t.alleles)) + == hl.Struct(AC=[5, 5], AF=[0.5, 0.5], AN=10, homozygote_count=[0, 0]) + ) # Tests table.aggregate initOp @test_timeout(batch=5 * 60) def test_initop_matrix_table(self): - mt = (hl.utils.range_matrix_table(10, 5, 5) - .annotate_entries(GT=hl.call(0, 1)) - .annotate_rows(alleles=["A", "T"]) - .annotate_globals(alleles2=["G", "C"])) + mt = ( + hl.utils.range_matrix_table(10, 5, 5) + .annotate_entries(GT=hl.call(0, 1)) + .annotate_rows(alleles=["A", "T"]) + .annotate_globals(alleles2=["G", "C"]) + ) - row_agg = mt.annotate_rows(call_stats=hl.agg.call_stats(mt.GT, mt.alleles)).rows() # Tests MatrixMapRows initOp - col_agg = mt.annotate_cols(call_stats=hl.agg.call_stats(mt.GT, mt.alleles2)).cols() # Tests MatrixMapCols initOp + row_agg = mt.annotate_rows(call_stats=hl.agg.call_stats(mt.GT, mt.alleles)).rows() # Tests MatrixMapRows initOp + col_agg = mt.annotate_cols( + call_stats=hl.agg.call_stats(mt.GT, mt.alleles2) + ).cols() # Tests MatrixMapCols initOp # must test that call_stats isn't null, because equality doesn't test for that - self.assertTrue(row_agg.all( - hl.is_defined(row_agg.call_stats) - & (row_agg.call_stats == hl.struct(AC=[5, 5], AF=[0.5, 0.5], AN=10, homozygote_count=[0, 0])))) - self.assertTrue(col_agg.all( - hl.is_defined(col_agg.call_stats) - & (col_agg.call_stats == hl.struct(AC=[10, 10], AF=[0.5, 0.5], AN=20, homozygote_count=[0, 0])))) + self.assertTrue( + row_agg.all( + hl.is_defined(row_agg.call_stats) + & (row_agg.call_stats == hl.struct(AC=[5, 5], AF=[0.5, 0.5], AN=10, homozygote_count=[0, 
0])) + ) + ) + self.assertTrue( + col_agg.all( + hl.is_defined(col_agg.call_stats) + & (col_agg.call_stats == hl.struct(AC=[10, 10], AF=[0.5, 0.5], AN=20, homozygote_count=[0, 0])) + ) + ) @test_timeout(batch=5 * 60) def test_initop_table_aggregate_by_key(self): - t = (hl.utils.range_table(5, 3) - .annotate(GT=hl.call(0, 1)) - .annotate_globals(alleles=["A", "T"])) + t = hl.utils.range_table(5, 3).annotate(GT=hl.call(0, 1)).annotate_globals(alleles=["A", "T"]) t2 = t.annotate(group=t.idx < 3) group_agg = t2.group_by(t2['group']).aggregate(call_stats=hl.agg.call_stats(t2.GT, t2.alleles)) - self.assertTrue(group_agg.all( - hl.if_else(group_agg.group, + self.assertTrue( + group_agg.all( + hl.if_else( + group_agg.group, hl.is_defined(group_agg.call_stats) & (group_agg.call_stats == hl.struct(AC=[3, 3], AF=[0.5, 0.5], AN=6, homozygote_count=[0, 0])), hl.is_defined(group_agg.call_stats) - & (group_agg.call_stats == hl.struct(AC=[2, 2], AF=[0.5, 0.5], AN=4, homozygote_count=[0, 0]))))) + & (group_agg.call_stats == hl.struct(AC=[2, 2], AF=[0.5, 0.5], AN=4, homozygote_count=[0, 0])), + ) + ) + ) @test_timeout(batch=5 * 60) def test_initop_matrix_aggregate_cols_by_key_entries(self): - mt = (hl.utils.range_matrix_table(10, 5, 5) - .annotate_entries(GT=hl.call(0, 1)) - .annotate_rows(alleles=["A", "T"]) - .annotate_globals(alleles2=["G", "C"])) + mt = ( + hl.utils.range_matrix_table(10, 5, 5) + .annotate_entries(GT=hl.call(0, 1)) + .annotate_rows(alleles=["A", "T"]) + .annotate_globals(alleles2=["G", "C"]) + ) mt2 = mt.annotate_cols(group=mt.col_idx < 3) - group_cols_agg = (mt2.group_cols_by(mt2['group']) - .aggregate(call_stats=hl.agg.call_stats(mt2.GT, mt2.alleles2)).entries()) + group_cols_agg = ( + mt2.group_cols_by(mt2['group']).aggregate(call_stats=hl.agg.call_stats(mt2.GT, mt2.alleles2)).entries() + ) - self.assertTrue(group_cols_agg.all( - hl.if_else(group_cols_agg.group, + self.assertTrue( + group_cols_agg.all( + hl.if_else( + group_cols_agg.group, hl.is_defined(group_cols_agg.call_stats) & (group_cols_agg.call_stats == hl.struct(AC=[3, 3], AF=[0.5, 0.5], AN=6, homozygote_count=[0, 0])), hl.is_defined(group_cols_agg.call_stats) - & (group_cols_agg.call_stats == hl.struct(AC=[2, 2], AF=[0.5, 0.5], AN=4, homozygote_count=[0, 0]))))) + & (group_cols_agg.call_stats == hl.struct(AC=[2, 2], AF=[0.5, 0.5], AN=4, homozygote_count=[0, 0])), + ) + ) + ) @test_timeout(batch=5 * 60) def test_initop_matrix_aggregate_cols_by_key_cols(self): - mt = (hl.utils.range_matrix_table(10, 5, 5) - .annotate_entries(GT=hl.call(0, 1)) - .annotate_rows(alleles=["A", "T"]) - .annotate_globals(alleles2=["G", "C"])) + mt = ( + hl.utils.range_matrix_table(10, 5, 5) + .annotate_entries(GT=hl.call(0, 1)) + .annotate_rows(alleles=["A", "T"]) + .annotate_globals(alleles2=["G", "C"]) + ) mt2 = mt.annotate_cols(group=mt.col_idx < 3, GT_col=hl.call(0, 1)) - group_cols_agg = (mt2.group_cols_by(mt2['group']) - .aggregate_cols(call_stats=hl.agg.call_stats(mt2.GT_col, mt2.alleles2)) - .result() - ).entries() + group_cols_agg = ( + mt2.group_cols_by(mt2['group']) + .aggregate_cols(call_stats=hl.agg.call_stats(mt2.GT_col, mt2.alleles2)) + .result() + ).entries() - self.assertTrue(group_cols_agg.all( - hl.if_else(group_cols_agg.group, + self.assertTrue( + group_cols_agg.all( + hl.if_else( + group_cols_agg.group, hl.is_defined(group_cols_agg.call_stats) & (group_cols_agg.call_stats == hl.struct(AC=[3, 3], AF=[0.5, 0.5], AN=6, homozygote_count=[0, 0])), hl.is_defined(group_cols_agg.call_stats) - & (group_cols_agg.call_stats == 
hl.struct(AC=[2, 2], AF=[0.5, 0.5], AN=4, homozygote_count=[0, 0]))))) + & (group_cols_agg.call_stats == hl.struct(AC=[2, 2], AF=[0.5, 0.5], AN=4, homozygote_count=[0, 0])), + ) + ) + ) @test_timeout(batch=5 * 60) def test_initop_matrix_aggregate_rows_by_key_entries(self): - mt = (hl.utils.range_matrix_table(10, 5, 5) - .annotate_entries(GT=hl.call(0, 1)) - .annotate_rows(alleles=["A", "T"]) - .annotate_globals(alleles2=["G", "C"])) + mt = ( + hl.utils.range_matrix_table(10, 5, 5) + .annotate_entries(GT=hl.call(0, 1)) + .annotate_rows(alleles=["A", "T"]) + .annotate_globals(alleles2=["G", "C"]) + ) mt2 = mt.annotate_rows(group=mt.row_idx < 3) - group_rows_agg = (mt2.group_rows_by(mt2['group']) - .aggregate(call_stats=hl.agg.call_stats(mt2.GT, mt2.alleles2)).entries()) + group_rows_agg = ( + mt2.group_rows_by(mt2['group']).aggregate(call_stats=hl.agg.call_stats(mt2.GT, mt2.alleles2)).entries() + ) - self.assertTrue(group_rows_agg.all( - hl.if_else(group_rows_agg.group, + self.assertTrue( + group_rows_agg.all( + hl.if_else( + group_rows_agg.group, hl.is_defined(group_rows_agg.call_stats) & (group_rows_agg.call_stats == hl.struct(AC=[3, 3], AF=[0.5, 0.5], AN=6, homozygote_count=[0, 0])), hl.is_defined(group_rows_agg.call_stats) - & (group_rows_agg.call_stats == hl.struct(AC=[7, 7], AF=[0.5, 0.5], AN=14, homozygote_count=[0, 0]))))) + & ( + group_rows_agg.call_stats == hl.struct(AC=[7, 7], AF=[0.5, 0.5], AN=14, homozygote_count=[0, 0]) + ), + ) + ) + ) @test_timeout(batch=5 * 60) def test_initop_matrix_aggregate_rows_by_key_rows(self): - mt = (hl.utils.range_matrix_table(10, 5, 5) - .annotate_entries(GT=hl.call(0, 1)) - .annotate_rows(alleles=["A", "T"]) - .annotate_globals(alleles2=["G", "C"])) + mt = ( + hl.utils.range_matrix_table(10, 5, 5) + .annotate_entries(GT=hl.call(0, 1)) + .annotate_rows(alleles=["A", "T"]) + .annotate_globals(alleles2=["G", "C"]) + ) mt2 = mt.annotate_rows(group=mt.row_idx < 3, GT_row=hl.call(0, 1)) - group_rows_agg = (mt2.group_rows_by(mt2['group']) - .aggregate_rows(call_stats=hl.agg.call_stats(mt2.GT_row, mt2.alleles2)) - .result() - ).entries() + group_rows_agg = ( + mt2.group_rows_by(mt2['group']) + .aggregate_rows(call_stats=hl.agg.call_stats(mt2.GT_row, mt2.alleles2)) + .result() + ).entries() - self.assertTrue(group_rows_agg.all( - hl.if_else(group_rows_agg.group, + self.assertTrue( + group_rows_agg.all( + hl.if_else( + group_rows_agg.group, hl.is_defined(group_rows_agg.call_stats) & (group_rows_agg.call_stats == hl.struct(AC=[3, 3], AF=[0.5, 0.5], AN=6, homozygote_count=[0, 0])), hl.is_defined(group_rows_agg.call_stats) - & (group_rows_agg.call_stats == hl.struct(AC=[7, 7], AF=[0.5, 0.5], AN=14, homozygote_count=[0, 0]))))) + & ( + group_rows_agg.call_stats == hl.struct(AC=[7, 7], AF=[0.5, 0.5], AN=14, homozygote_count=[0, 0]) + ), + ) + ) + ) def test_call_stats_init(self): ht = hl.utils.range_table(3) - ht = ht.annotate(GT = hl.unphased_diploid_gt_index_call(ht.idx)) + ht = ht.annotate(GT=hl.unphased_diploid_gt_index_call(ht.idx)) assert ht.aggregate(hl.agg.call_stats(ht.GT, 2).AC) == [3, 3] def test_mendel_error_code(self): @@ -3067,16 +3295,23 @@ def test_mendel_error_code(self): locus_x_nonpar = hl.Locus(locus_x_par.contig, locus_x_par.position - 1) locus_y_nonpar = hl.Locus('Y', hl.get_reference('default').lengths['Y'] - 1) - self.assertTrue(hl.eval(hl.all(lambda x: x, hl.array([ - hl.literal(locus_auto).in_autosome_or_par(), - hl.literal(locus_auto).in_autosome_or_par(), - ~hl.literal(locus_x_par).in_autosome(), - 
hl.literal(locus_x_par).in_autosome_or_par(), - ~hl.literal(locus_x_nonpar).in_autosome_or_par(), - hl.literal(locus_x_nonpar).in_x_nonpar(), - ~hl.literal(locus_y_nonpar).in_autosome_or_par(), - hl.literal(locus_y_nonpar).in_y_nonpar() - ])))) + self.assertTrue( + hl.eval( + hl.all( + lambda x: x, + hl.array([ + hl.literal(locus_auto).in_autosome_or_par(), + hl.literal(locus_auto).in_autosome_or_par(), + ~hl.literal(locus_x_par).in_autosome(), + hl.literal(locus_x_par).in_autosome_or_par(), + ~hl.literal(locus_x_nonpar).in_autosome_or_par(), + hl.literal(locus_x_nonpar).in_x_nonpar(), + ~hl.literal(locus_y_nonpar).in_autosome_or_par(), + hl.literal(locus_y_nonpar).in_y_nonpar(), + ]), + ) + ) + ) hr = hl.Call([0, 0]) het = hl.Call([0, 1]) @@ -3130,14 +3365,12 @@ def test_mendel_error_code(self): (locus_auto, True, hv, het, het): None, (locus_auto, True, het, hr, het): None, (locus_auto, True, hv, hr, het): None, - (locus_auto, True, hv, hr, het): None, (locus_x_nonpar, True, hv, hr, het): None, (locus_x_nonpar, False, hv, hr, hr): None, (locus_x_nonpar, None, hv, hr, hr): None, (locus_x_nonpar, False, het, hr, hr): None, (locus_y_nonpar, True, het, hr, het): None, (locus_y_nonpar, True, het, hr, hr): None, - (locus_y_nonpar, True, het, hr, het): None, (locus_y_nonpar, True, het, het, het): None, (locus_y_nonpar, True, hr, hr, hr): None, (locus_y_nonpar, None, hr, hr, hr): None, @@ -3146,19 +3379,20 @@ def test_mendel_error_code(self): (locus_y_nonpar, None, hv, hv, hv): None, } - arg_list = hl.literal(list(expected.keys()), - hl.tarray(hl.ttuple(hl.tlocus(), hl.tbool, hl.tcall, hl.tcall, hl.tcall))) + arg_list = hl.literal( + list(expected.keys()), hl.tarray(hl.ttuple(hl.tlocus(), hl.tbool, hl.tcall, hl.tcall, hl.tcall)) + ) values = arg_list.map(lambda args: hl.mendel_error_code(*args)) expr = hl.dict(hl.zip(arg_list, values)) results = hl.eval(expr) for args, result in results.items(): - self.assertEqual(result, expected[args], msg=f'expected {expected[args]}, found {result} at {str(args)}') + self.assertEqual(result, expected[args], msg=f'expected {expected[args]}, found {result} at {args!s}') def test_min_rep(self): def assert_min_reps_to(old, new, pos_change=0): self.assertEqual( hl.eval(hl.min_rep(hl.locus('1', 10), old)), - hl.Struct(locus=hl.Locus('1', 10 + pos_change), alleles=new) + hl.Struct(locus=hl.Locus('1', 10 + pos_change), alleles=new), ) assert_min_reps_to(['TAA', 'TA'], ['TA', 'T']) @@ -3261,13 +3495,15 @@ def test_uniroot_2(self): tol = 1.220703e-4 self.assertAlmostEqual(hl.eval(hl.uniroot(lambda x: x - 1, 0, 3, tolerance=tol)), 1) - self.assertAlmostEqual(hl.eval(hl.uniroot(lambda x: hl.log(x) - 1, 0, 3, tolerance=tol)), 2.718281828459045, delta=tol) + self.assertAlmostEqual( + hl.eval(hl.uniroot(lambda x: hl.log(x) - 1, 0, 3, tolerance=tol)), 2.718281828459045, delta=tol + ) def test_uniroot_3(self): with self.assertRaisesRegex(hl.utils.FatalError, r"value of f\(x\) is missing"): hl.eval(hl.uniroot(lambda x: hl.missing('float'), 0, 1)) with self.assertRaisesRegex(hl.utils.HailUserError, 'opposite signs'): - hl.eval(hl.uniroot(lambda x: x ** 2 - 0.5, -1, 1)) + hl.eval(hl.uniroot(lambda x: x**2 - 0.5, -1, 1)) with self.assertRaisesRegex(hl.utils.HailUserError, 'min must be less than max'): hl.eval(hl.uniroot(lambda x: x, 1, -1)) @@ -3364,6 +3600,48 @@ def test_contingency_table_test(self): self.assertAlmostEqual(res['p_value'] / 2.1565e-7, 1.0, places=4) self.assertAlmostEqual(res['odds_ratio'], 4.91805817) + def test_cochran_mantel_haenszel_test(self): + # 
https://cran.r-project.org/web/packages/samplesizeCMH/vignettes/samplesizeCMH-introduction.html + a = [118, 154, 422, 670] + b = [62, 25, 88, 192] + c = [4, 13, 106, 3] + d = [141, 93, 90, 20] + + result = hl.eval(hl.cochran_mantel_haenszel_test(a, b, c, d)) + self.assertEqual(360.3311519725744, result['test_statistic']) + self.assertEqual(2.384935629406975e-80, result['p_value']) + + # https://www.biostathandbook.com/cmh.html + a = [708, 136, 106, 109, 801, 159, 151, 950] + b = [50, 24, 32, 22, 102, 27, 51, 173] + c = [169, 73, 17, 16, 180, 18, 28, 218] + d = [13, 14, 4, 26, 25, 13, 15, 33] + + expr = hl.cochran_mantel_haenszel_test(a, b, c, d) + result = hl.eval(expr) + self.assertEqual(6.07023412667767, result['test_statistic']) + self.assertEqual(0.013747873638119005, result['p_value']) + + a = [56, 61, 73, 71] + b = [69, 257, 65, 48] + c = [40, 57, 71, 55] + d = [77, 301, 79, 48] + + expr = hl.cochran_mantel_haenszel_test(a, b, c, d) + result = hl.eval(expr) + self.assertEqual(5.0496881823306765, result['test_statistic']) + self.assertEqual(0.024630370456863417, result['p_value']) + + a = hl.array([2, 4, 1, 1, 2]) + b = hl.array([46, 67, 86, 37, 92]) + c = hl.array([11, 12, 4, 6, 1]) + d = hl.array([41, 60, 76, 32, 93]) + + expr = hl.cochran_mantel_haenszel_test(a, b, c, d) + result = hl.eval(expr) + self.assertEqual(12.74572269532737, result['test_statistic']) + self.assertEqual(0.0003568242404514306, result['p_value']) + def test_hardy_weinberg_test(self): two_sided_res = hl.eval(hl.hardy_weinberg_test(1, 2, 1, one_sided=False)) self.assertAlmostEqual(two_sided_res['p_value'], 0.65714285) @@ -3389,8 +3667,12 @@ def test_hardy_weinberg_agg_1(self): mt = hl.utils.range_matrix_table(n_rows=3, n_cols=5) mt = mt.annotate_rows( - hwe_two_sided = hl.agg.hardy_weinberg_test(hl.literal(row_idx_col_idx_to_call).get((mt.row_idx, mt.col_idx)), one_sided=False), - hwe_one_sided = hl.agg.hardy_weinberg_test(hl.literal(row_idx_col_idx_to_call).get((mt.row_idx, mt.col_idx)), one_sided=True) + hwe_two_sided=hl.agg.hardy_weinberg_test( + hl.literal(row_idx_col_idx_to_call).get((mt.row_idx, mt.col_idx)), one_sided=False + ), + hwe_one_sided=hl.agg.hardy_weinberg_test( + hl.literal(row_idx_col_idx_to_call).get((mt.row_idx, mt.col_idx)), one_sided=True + ), ) rows = mt.rows().collect() all_hwe_one_sided = [r.hwe_one_sided for r in rows] @@ -3429,8 +3711,8 @@ def test_hardy_weinberg_agg_2(self): ht = hl.utils.range_table(6) ht = ht.annotate( - x_two_sided = hl.scan.hardy_weinberg_test(hl.literal(calls)[ht.idx % 5], one_sided=False), - x_one_sided = hl.scan.hardy_weinberg_test(hl.literal(calls)[ht.idx % 5], one_sided=True) + x_two_sided=hl.scan.hardy_weinberg_test(hl.literal(calls)[ht.idx % 5], one_sided=False), + x_one_sided=hl.scan.hardy_weinberg_test(hl.literal(calls)[ht.idx % 5], one_sided=True), ) rows = ht.collect() all_x_one_sided = [r.x_one_sided for r in rows] @@ -3512,16 +3794,24 @@ def test_collection_method_missingness(self): self.assertIsNone(hl.eval(hl.sum(a, filter_missing=False))) def test_literal_with_nested_expr(self): - self.assertEqual(hl.eval(hl.literal(hl.set(['A','B']))), {'A', 'B'}) + self.assertEqual(hl.eval(hl.literal(hl.set(['A', 'B']))), {'A', 'B'}) self.assertEqual(hl.eval(hl.literal({hl.str('A'), hl.str('B')})), {'A', 'B'}) def test_format(self): self.assertEqual(hl.eval(hl.format("%.4f %s %.3e", 0.25, 'hello', 0.114)), '0.2500 hello 1.140e-01') self.assertEqual(hl.eval(hl.format("%.4f %d", hl.missing(hl.tint32), hl.missing(hl.tint32))), 'null null') - 
self.assertEqual(hl.eval(hl.format("%s", hl.struct(foo=5, bar=True, baz=hl.array([4, 5])))), - '{foo: 5, bar: true, baz: [4,5]}') - self.assertEqual(hl.eval(hl.format("%s %s", hl.locus("1", 356), hl.tuple([9, True, hl.missing(hl.tstr)]))), '1:356 (9, true, null)') - self.assertEqual(hl.eval(hl.format("%b %B %b %b", hl.missing(hl.tint), hl.missing(hl.tstr), True, "hello")), "false FALSE true true") + self.assertEqual( + hl.eval(hl.format("%s", hl.struct(foo=5, bar=True, baz=hl.array([4, 5])))), + '{foo: 5, bar: true, baz: [4,5]}', + ) + self.assertEqual( + hl.eval(hl.format("%s %s", hl.locus("1", 356), hl.tuple([9, True, hl.missing(hl.tstr)]))), + '1:356 (9, true, null)', + ) + self.assertEqual( + hl.eval(hl.format("%b %B %b %b", hl.missing(hl.tint), hl.missing(hl.tstr), True, "hello")), + "false FALSE true true", + ) def test_dict_and_set_type_promotion(self): d = hl.literal({5: 5}, dtype='dict') @@ -3542,9 +3832,9 @@ def test_dict_keyed_by_set(self): assert hl.eval(dict_with_set_key) == {frozenset([1, 2, 3]): 4} def test_dict_keyed_by_dict(self): - dict_with_dict_key = hl.dict({hl.dict({1:2, 3:5}): 4}) + dict_with_dict_key = hl.dict({hl.dict({1: 2, 3: 5}): 4}) # Test that it's evalable, since python dicts aren't hashable. - assert hl.eval(dict_with_dict_key) == {hl.utils.frozendict({1:2, 3:5}): 4} + assert hl.eval(dict_with_dict_key) == {hl.utils.frozendict({1: 2, 3: 5}): 4} def test_frozendict_as_literal(self): fd = hl.utils.frozendict({"a": 4, "b": 8}) @@ -3569,10 +3859,7 @@ def test_approx_equal(self): def test_issue3729(self): t = hl.utils.range_table(10, 3) - fold_expr = hl.if_else(t.idx == 3, - [1, 2, 3], - [4, 5, 6]).fold(lambda accum, i: accum & (i == t.idx), - True) + fold_expr = hl.if_else(t.idx == 3, [1, 2, 3], [4, 5, 6]).fold(lambda accum, i: accum & (i == t.idx), True) t.annotate(foo=hl.if_else(fold_expr, 1, 3))._force_count() def assertValueEqual(self, expr, value, t): @@ -3583,10 +3870,15 @@ def test_array_fold_and_scan(self): self.assertValueEqual(hl.fold(lambda x, y: x + y, 0, [1, 2, 3]), 6, tint32) self.assertValueEqual(hl.array_scan(lambda x, y: x + y, 0, [1, 2, 3]), [0, 1, 3, 6], tarray(tint32)) - self.assertValueEqual(hl.fold(lambda x, y: x + y, 0., [1, 2, 3]), 6., tfloat64) - self.assertValueEqual(hl.fold(lambda x, y: x + y, 0, [1., 2., 3.]), 6., tfloat64) - self.assertValueEqual(hl.array_scan(lambda x, y: x + y, 0., [1, 2, 3]), [0., 1., 3., 6.], tarray(tfloat64)) - self.assertValueEqual(hl.array_scan(lambda x, y: x + y, 0, [1., 2., 3.]), [0., 1., 3., 6.], tarray(tfloat64)) + self.assertValueEqual(hl.fold(lambda x, y: x + y, 0.0, [1, 2, 3]), 6.0, tfloat64) + self.assertValueEqual(hl.fold(lambda x, y: x + y, 0, [1.0, 2.0, 3.0]), 6.0, tfloat64) + self.assertValueEqual(hl.array_scan(lambda x, y: x + y, 0.0, [1, 2, 3]), [0.0, 1.0, 3.0, 6.0], tarray(tfloat64)) + self.assertValueEqual( + hl.array_scan(lambda x, y: x + y, 0, [1.0, 2.0, 3.0]), [0.0, 1.0, 3.0, 6.0], tarray(tfloat64) + ) + + def test_sum(self): + self.assertValueEqual(hl.sum([1, 2, 3, 4]), 10, tint32) def test_cumulative_sum(self): self.assertValueEqual(hl.cumulative_sum([1, 2, 3, 4]), [1, 3, 6, 10], tarray(tint32)) @@ -3598,26 +3890,26 @@ def test_nan_inf_checks(self): nan = math.nan na = hl.missing('float64') - assert hl.eval(hl.is_finite(finite)) == True - assert hl.eval(hl.is_finite(infinite)) == False - assert hl.eval(hl.is_finite(nan)) == False - assert hl.eval(hl.is_finite(na)) == None + assert hl.eval(hl.is_finite(finite)) is True + assert hl.eval(hl.is_finite(infinite)) is False + assert 
hl.eval(hl.is_finite(nan)) is False + assert hl.eval(hl.is_finite(na)) is None - assert hl.eval(hl.is_infinite(finite)) == False - assert hl.eval(hl.is_infinite(infinite)) == True - assert hl.eval(hl.is_infinite(nan)) == False - assert hl.eval(hl.is_infinite(na)) == None + assert hl.eval(hl.is_infinite(finite)) is False + assert hl.eval(hl.is_infinite(infinite)) is True + assert hl.eval(hl.is_infinite(nan)) is False + assert hl.eval(hl.is_infinite(na)) is None - assert hl.eval(hl.is_nan(finite)) == False - assert hl.eval(hl.is_nan(infinite)) == False - assert hl.eval(hl.is_nan(nan)) == True - assert hl.eval(hl.is_nan(na)) == None + assert hl.eval(hl.is_nan(finite)) is False + assert hl.eval(hl.is_nan(infinite)) is False + assert hl.eval(hl.is_nan(nan)) is True + assert hl.eval(hl.is_nan(na)) is None def test_array_and_if_requiredness(self): mt = hl.import_vcf(resource('sample.vcf'), array_elements_required=True) hl.tuple((mt.AD, mt.PL)).show() hl.array([mt.AD, mt.PL]).show() - hl.array([mt.AD, [1,2]]).show() + hl.array([mt.AD, [1, 2]]).show() def test_string_unicode(self): self.assertTrue(hl.eval(hl.str("李") == "李")) @@ -3709,18 +4001,18 @@ def test_prev_non_null(self): @test_timeout(batch=5 * 60) def test_summarize_runs(self): - mt = hl.utils.range_matrix_table(3,3).annotate_entries( - x1 = 'a', - x2 = 1, - x3 = 1.5, - x4 = True, - x5 = ['1'], - x6 = {'1'}, + mt = hl.utils.range_matrix_table(3, 3).annotate_entries( + x1='a', + x2=1, + x3=1.5, + x4=True, + x5=['1'], + x6={'1'}, x7={'1': 5}, x8=hl.struct(a=5, b='7'), - x9=(1,2,3), + x9=(1, 2, 3), x10=hl.locus('1', 123123), - x11=hl.call(0, 1, phased=True) + x11=hl.call(0, 1, phased=True), ) mt.summarize() @@ -3728,23 +4020,22 @@ def test_summarize_runs(self): mt.x1.summarize() def test_variant_str(self): - assert hl.eval( - hl.variant_str(hl.struct(locus=hl.locus('1', 10000), alleles=['A', 'T', 'CCC']))) == '1:10000:A:T,CCC' + assert ( + hl.eval(hl.variant_str(hl.struct(locus=hl.locus('1', 10000), alleles=['A', 'T', 'CCC']))) + == '1:10000:A:T,CCC' + ) assert hl.eval(hl.variant_str(hl.locus('1', 10000), ['A', 'T', 'CCC'])) == '1:10000:A:T,CCC' with pytest.raises(ValueError): hl.variant_str() def test_collection_getitem(self): collection_types = [(hl.array, list), (hl.set, frozenset)] - for (htyp, pytyp) in collection_types: + for htyp, pytyp in collection_types: x = htyp([hl.struct(a='foo', b=3), hl.struct(a='bar', b=4)]) assert hl.eval(x.a) == pytyp(['foo', 'bar']) - a = hl.array([hl.struct(b=[hl.struct(inner=1), - hl.struct(inner=2)]), - hl.struct(b=[hl.struct(inner=3)])]) - assert hl.eval(a.b) == [[hl.Struct(inner=1), hl.Struct(inner=2)], - [hl.Struct(inner=3)]] + a = hl.array([hl.struct(b=[hl.struct(inner=1), hl.struct(inner=2)]), hl.struct(b=[hl.struct(inner=3)])]) + assert hl.eval(a.b) == [[hl.Struct(inner=1), hl.Struct(inner=2)], [hl.Struct(inner=3)]] assert hl.eval(hl.flatten(a.b).inner) == [1, 2, 3] assert hl.eval(a.b.inner) == [[1, 2], [3]] assert hl.eval(a["b"].inner) == [[1, 2], [3]] @@ -3756,8 +4047,8 @@ def test_struct_collection_getattr(self): for htyp in collection_types: a = htyp([hl.struct(x='foo'), hl.struct(x='bar')]) - assert hasattr(a, 'x') == True - assert hasattr(a, 'y') == False + assert hasattr(a, 'x') is True + assert hasattr(a, 'y') is False with pytest.raises(AttributeError, match="has no field"): getattr(a, 'y') @@ -3777,8 +4068,7 @@ def verify_6930_still_holds(self): mt = rmt33.key_rows_by(rowkey=-mt.row_idx) assert mt.row.collect() == [hl.Struct(row_idx=x) for x in [2, 1, 0]] - mt = rmt33.annotate_entries( - 
x=(rmt33.row_idx + 1) * (rmt33.col_idx + 1)) + mt = rmt33.annotate_entries(x=(rmt33.row_idx + 1) * (rmt33.col_idx + 1)) mt = mt.key_rows_by(rowkey=-mt.row_idx) mt = mt.choose_cols([2, 1, 0]) assert mt.x.collect() == [9, 6, 3, 6, 4, 2, 3, 2, 1] @@ -3822,7 +4112,7 @@ def test_parse_json(self): hl.call(0, 2, phased=True), hl.locus_interval('1', 10000, 10005), hl.struct(foo='bar'), - hl.tuple([1, 2, 'str']) + hl.tuple([1, 2, 'str']), ] assert hl.eval(hl._compare(hl.tuple(values), hl.tuple(hl.parse_json(hl.json(v), v.dtype) for v in values)) == 0) @@ -3841,12 +4131,11 @@ def test_struct_expression_expr_rename(self): s = hl.struct(f1=1, f2=2, f3=3) assert hl.eval(s.rename({'f1': 'foo'})) == hl.Struct(f2=2, f3=3, foo=1) - assert hl.eval(s.rename({'f3': 'fiddle', 'f1': 'hello'})) == \ - hl.Struct(f2=2, fiddle=3, hello=1) - assert hl.eval(s.rename({'f3': 'fiddle', 'f1': 'hello', 'f2': 'ohai'})) == \ - hl.Struct(fiddle=3, hello=1, ohai=2) - assert hl.eval(s.rename({'f3': 'fiddle', 'f1': 'hello', 'f2': 's p a c e'})) == \ - hl.Struct(fiddle=3, hello=1, **{'s p a c e': 2}) + assert hl.eval(s.rename({'f3': 'fiddle', 'f1': 'hello'})) == hl.Struct(f2=2, fiddle=3, hello=1) + assert hl.eval(s.rename({'f3': 'fiddle', 'f1': 'hello', 'f2': 'ohai'})) == hl.Struct(fiddle=3, hello=1, ohai=2) + assert hl.eval(s.rename({'f3': 'fiddle', 'f1': 'hello', 'f2': 's p a c e'})) == hl.Struct( + fiddle=3, hello=1, **{'s p a c e': 2} + ) try: hl.eval(s.rename({'f1': 'f2'})) @@ -3877,13 +4166,13 @@ def test_enumerate(self): hl.enumerate(a1), hl.enumerate(a1, start=-1000), hl.enumerate(a1, start=10, index_first=False), - hl.enumerate(a_empty, start=5) + hl.enumerate(a_empty, start=5), ) assert hl.eval(exprs) == ( [(0, 'foo'), (1, 'bar'), (2, 'baz')], [(-1000, 'foo'), (-999, 'bar'), (-998, 'baz')], [('foo', 10), ('bar', 11), ('baz', 12)], - [] + [], ) def test_split_line(self): @@ -3894,8 +4183,18 @@ def test_split_line(self): assert hl.eval(hl.str(s1)._split_line(' ', ['NA'], quote=None, regex=False)) == s1.split(' ') assert hl.eval(hl.str(s1)._split_line(r'\s+', ['NA'], quote=None, regex=True)) == s1.split(' ') assert hl.eval(hl.str(s3)._split_line(' ', ['1'], quote='"', regex=False)) == [None, '2'] - assert hl.eval(hl.str(s2)._split_line(' ', ['1', '2'], quote='"', regex=False)) == [None, None, '3 4', 'a b c d'] - assert hl.eval(hl.str(s2)._split_line(r'\s+', ['1', '2'], quote='"', regex=True)) == [None, None, '3 4', 'a b c d'] + assert hl.eval(hl.str(s2)._split_line(' ', ['1', '2'], quote='"', regex=False)) == [ + None, + None, + '3 4', + 'a b c d', + ] + assert hl.eval(hl.str(s2)._split_line(r'\s+', ['1', '2'], quote='"', regex=True)) == [ + None, + None, + '3 4', + 'a b c d', + ] def test_approx_cdf(): @@ -3910,9 +4209,9 @@ def test_approx_cdf(): # assumes cdf was computed from a (possibly shuffled) range table def cdf_max_observed_error(cdf): rank_error = max( - max(abs(cdf['values'][i+1] - cdf.ranks[i+1]), - abs(cdf['values'][i] + 1 - cdf.ranks[i+1])) - for i in range(len(cdf['values']) - 1)) + max(abs(cdf['values'][i + 1] - cdf.ranks[i + 1]), abs(cdf['values'][i] + 1 - cdf.ranks[i + 1])) + for i in range(len(cdf['values']) - 1) + ) return rank_error / cdf.ranks[-1] @@ -3932,7 +4231,7 @@ def test_approx_cdf_accuracy(cdf_test_data): t = cdf_test_data cdf = t.aggregate(hl.agg.approx_cdf(t.idx, 200)) error = cdf_max_observed_error(cdf) - assert(error < 0.015) + assert error < 0.015 def test_approx_cdf_all_missing(): @@ -3959,8 +4258,8 @@ def test_error_from_cdf(): table = hl.utils.range_table(100) table = 
table.annotate(i=table.idx) cdf = hl.agg.approx_cdf(table.i) - table.aggregate(_error_from_cdf(cdf, .001)) - table.aggregate(_error_from_cdf(cdf, .001, all_quantiles=True)) + table.aggregate(_error_from_cdf(cdf, 0.001)) + table.aggregate(_error_from_cdf(cdf, 0.001, all_quantiles=True)) def test_cdf_combine(cdf_test_data): @@ -3972,69 +4271,60 @@ def test_cdf_combine(cdf_test_data): cdf = _cdf_combine(200, cdf1, cdf2) cdf = hl.eval(_result_from_raw_cdf(cdf)) error = cdf_max_observed_error(cdf) - assert(error < 0.015) + assert error < 0.015 def test_approx_cdf_array_agg(): mt = hl.utils.range_matrix_table(5, 5) - mt = mt.annotate_entries(x = mt.col_idx) - mt = mt.group_cols_by(mt.col_idx).aggregate(cdf = hl.agg.approx_cdf(mt.x)) + mt = mt.annotate_entries(x=mt.col_idx) + mt = mt.group_cols_by(mt.col_idx).aggregate(cdf=hl.agg.approx_cdf(mt.x)) mt._force_count_rows() + @pytest.mark.parametrize("delimiter", ['\t', ',', '@']) @pytest.mark.parametrize("missing", ['NA', 'null']) @pytest.mark.parametrize("header", [True, False]) @test_timeout(local=6 * 60, batch=6 * 60) def test_export_entry(delimiter, missing, header): mt = hl.utils.range_matrix_table(3, 3) - mt = mt.key_cols_by(col_idx = mt.col_idx + 1) - mt = mt.annotate_entries(x = mt.row_idx * mt.col_idx) - mt = mt.annotate_entries(x = hl.or_missing(mt.x != 4, mt.x)) + mt = mt.key_cols_by(col_idx=mt.col_idx + 1) + mt = mt.annotate_entries(x=mt.row_idx * mt.col_idx) + mt = mt.annotate_entries(x=hl.or_missing(mt.x != 4, mt.x)) with hl.TemporaryFilename() as f: - mt.x.export(f, - delimiter=delimiter, - header=header, - missing=missing) + mt.x.export(f, delimiter=delimiter, header=header, missing=missing) if header: - actual = hl.import_matrix_table(f, - row_fields={'row_idx': hl.tint32}, - row_key=['row_idx'], - sep=delimiter, - missing=missing) + actual = hl.import_matrix_table( + f, row_fields={'row_idx': hl.tint32}, row_key=['row_idx'], sep=delimiter, missing=missing + ) else: - actual = hl.import_matrix_table(f, - row_fields={'f0': hl.tint32}, - row_key=['f0'], - sep=delimiter, - no_header=True, - missing=missing) + actual = hl.import_matrix_table( + f, row_fields={'f0': hl.tint32}, row_key=['f0'], sep=delimiter, no_header=True, missing=missing + ) actual = actual.rename({'f0': 'row_idx'}) - actual = actual.key_cols_by(col_idx = hl.int(actual.col_id)) + actual = actual.key_cols_by(col_idx=hl.int(actual.col_id)) actual = actual.drop('col_id') if not header: - actual = actual.key_cols_by(col_idx = actual.col_idx + 1) + actual = actual.key_cols_by(col_idx=actual.col_idx + 1) assert mt._same(actual) - expected_collect = [0, 0, 0, - 1, 2, 3, - 2, None, 6] + expected_collect = [0, 0, 0, 1, 2, 3, 2, None, 6] assert expected_collect == actual.x.collect() def test_stream_randomness(): def assert_contains_node(expr, node): - assert(expr._ir.base_search(lambda x: isinstance(x, node))) + assert expr._ir.base_search(lambda x: isinstance(x, node)) def assert_unique_uids(a): n1 = hl.eval(a.to_array().length()) n2 = len(hl.eval(hl.set(a.map(lambda x: hl.rand_int64()).to_array()))) - assert(n1 == n2) + assert n1 == n2 # test NA a = hl.missing('array') a = a.map(lambda x: x + hl.rand_int32(10)) assert_contains_node(a, ir.NA) - assert(hl.eval(a) == None) + assert hl.eval(a) is None # test If a1 = hl._stream_range(0, 5) @@ -4052,14 +4342,13 @@ def assert_unique_uids(a): a = hl._stream_range(10) a = a.map(lambda x: hl.rand_int64()).to_array() assert_contains_node(a, ir.ToArray) - assert(len(set(hl.eval(a))) == 10) + assert len(set(hl.eval(a))) == 10 # test 
ToStream - t = hl.rbind(hl.range(10), - lambda a: (a, a.map(lambda x: hl.rand_int64()))) + t = hl.rbind(hl.range(10), lambda a: (a, a.map(lambda x: hl.rand_int64()))) assert_contains_node(t, ir.ToStream) (a, r) = hl.eval(t) - assert(len(set(r)) == len(a)) + assert len(set(r)) == len(a) # test StreamZip a1 = hl._stream_range(10) @@ -4085,78 +4374,76 @@ def assert_unique_uids(a): a = hl._stream_range(10) a = a.fold(lambda acc, x: acc.append(hl.rand_int64()), hl.empty_array(hl.tint64)) assert_contains_node(a, ir.StreamFold) - assert(len(set(hl.eval(a))) == 10) + assert len(set(hl.eval(a))) == 10 # test StreamScan a = hl._stream_range(5) a = a.scan(lambda acc, x: acc.append(hl.rand_int64()), hl.empty_array(hl.tint64)) assert_contains_node(a, ir.StreamScan) - assert(len(set(hl.eval(a.to_array())[-1])) == 5) + assert len(set(hl.eval(a.to_array())[-1])) == 5 # test StreamAgg a = hl._stream_range(10) a = a.aggregate(lambda x: hl.agg.collect(hl.rand_int64())) assert_contains_node(a, ir.StreamAgg) - assert(len(set(hl.eval(a))) == 10) + assert len(set(hl.eval(a))) == 10 a = hl._stream_range(10) a = a.map(lambda x: hl._stream_range(10).aggregate(lambda y: hl.agg.count() + hl.rand_int64())) assert_contains_node(a, ir.StreamAgg) # test AggExplode t = hl.utils.range_table(5) - t = t.annotate(a = hl.range(t.idx)) + t = t.annotate(a=hl.range(t.idx)) a = hl.agg.explode(lambda x: hl.agg.collect_as_set(hl.rand_int64()), t.a) assert_contains_node(a, ir.AggExplode) - assert(len(t.aggregate(a)) == 10) + assert len(t.aggregate(a)) == 10 # test TableCount t = hl.utils.range_table(10) - t = t.annotate(x = hl.rand_int64()) - assert(t.count() == 10) + t = t.annotate(x=hl.rand_int64()) + assert t.count() == 10 # test TableGetGlobals t = hl.utils.range_table(10) - t = t.annotate(x = hl.rand_int64()) + t = t.annotate(x=hl.rand_int64()) g = t.index_globals() assert_contains_node(g, ir.TableGetGlobals) - assert(len(hl.eval(g)) == 0) + assert len(hl.eval(g)) == 0 # test TableCollect t = hl.utils.range_table(10) - t = t.annotate(x = hl.rand_int64()) + t = t.annotate(x=hl.rand_int64()) a = t.collect() - assert(len(set(a)) == 10) + assert len(set(a)) == 10 # test TableAggregate t = hl.utils.range_table(10) a = t.aggregate(hl.agg.collect(hl.rand_int64()).map(lambda x: x + hl.rand_int64())) - assert(len(set(a)) == 10) + assert len(set(a)) == 10 # test MatrixCount mt = hl.utils.range_matrix_table(10, 10) - mt = mt.annotate_entries(x = hl.rand_int64()) - assert(mt.count() == (10, 10)) + mt = mt.annotate_entries(x=hl.rand_int64()) + assert mt.count() == (10, 10) # test MatrixAggregate mt = hl.utils.range_matrix_table(5, 5) a = mt.aggregate_entries(hl.agg.collect(hl.rand_int64()).map(lambda x: x + hl.rand_int64())) - assert(len(set(a)) == 25) + assert len(set(a)) == 25 + + def test_keyed_intersection(): - a1 = hl.literal( - [ - hl.Struct(a=5, b='foo'), - hl.Struct(a=7, b='bar'), - hl.Struct(a=9, b='baz'), - ] - ) - a2 = hl.literal( - [ - hl.Struct(a=5, b='foo'), - hl.Struct(a=6, b='qux'), - hl.Struct(a=8, b='qux'), - hl.Struct(a=9, b='baz'), - ] - ) + a1 = hl.literal([ + hl.Struct(a=5, b='foo'), + hl.Struct(a=7, b='bar'), + hl.Struct(a=9, b='baz'), + ]) + a2 = hl.literal([ + hl.Struct(a=5, b='foo'), + hl.Struct(a=6, b='qux'), + hl.Struct(a=8, b='qux'), + hl.Struct(a=9, b='baz'), + ]) assert hl.eval(hl.keyed_intersection(a1, a2, key=['a'])) == [ hl.Struct(a=5, b='foo'), hl.Struct(a=9, b='baz'), @@ -4164,21 +4451,17 @@ def test_keyed_intersection(): def test_keyed_union(): - a1 = hl.literal( - [ - hl.Struct(a=5, b='foo'), - 
hl.Struct(a=7, b='bar'), - hl.Struct(a=9, b='baz'), - ] - ) - a2 = hl.literal( - [ - hl.Struct(a=5, b='foo'), - hl.Struct(a=6, b='qux'), - hl.Struct(a=8, b='qux'), - hl.Struct(a=9, b='baz'), - ] - ) + a1 = hl.literal([ + hl.Struct(a=5, b='foo'), + hl.Struct(a=7, b='bar'), + hl.Struct(a=9, b='baz'), + ]) + a2 = hl.literal([ + hl.Struct(a=5, b='foo'), + hl.Struct(a=6, b='qux'), + hl.Struct(a=8, b='qux'), + hl.Struct(a=9, b='baz'), + ]) assert hl.eval(hl.keyed_union(a1, a2, key=['a'])) == [ hl.Struct(a=5, b='foo'), hl.Struct(a=6, b='qux'), @@ -4194,15 +4477,18 @@ def test_to_relational_row_and_col_refs(): mt = mt.annotate_cols(y=1) mt = mt.annotate_entries(z=1) - assert mt.row._to_relational_preserving_rows_and_cols('x')[1].row.dtype == hl.tstruct(row_idx=hl.tint32, x=hl.tint32) + assert mt.row._to_relational_preserving_rows_and_cols('x')[1].row.dtype == hl.tstruct( + row_idx=hl.tint32, x=hl.tint32 + ) assert mt.row_key._to_relational_preserving_rows_and_cols('x')[1].row.dtype == hl.tstruct(row_idx=hl.tint32) - assert mt.col._to_relational_preserving_rows_and_cols('x')[1].row.dtype == hl.tstruct(col_idx=hl.tint32, y=hl.tint32) + assert mt.col._to_relational_preserving_rows_and_cols('x')[1].row.dtype == hl.tstruct( + col_idx=hl.tint32, y=hl.tint32 + ) assert mt.col_key._to_relational_preserving_rows_and_cols('x')[1].row.dtype == hl.tstruct(col_idx=hl.tint32) def test_locus_addition(): - rg = hl.get_reference('GRCh37') len_1 = rg.lengths['1'] loc = hl.locus('1', 5, reference_genome='GRCh37') @@ -4225,42 +4511,60 @@ def test_reservoir_sampling(): ) sample_sizes = [99, 811, 900, 1000, 3333] - (stats, samples) = ht.aggregate((hl.agg.stats(ht.idx), tuple([hl.sorted(hl.agg._reservoir_sample(ht.idx, size)) for size in sample_sizes]))) + (stats, samples) = ht.aggregate(( + hl.agg.stats(ht.idx), + tuple([hl.sorted(hl.agg._reservoir_sample(ht.idx, size)) for size in sample_sizes]), + )) sample_variance = stats['stdev'] ** 2 sample_mean = stats['mean'] - for sample, sample_size in zip(samples, sample_sizes): + for iteration, (sample, sample_size) in enumerate(zip(samples, sample_sizes)): mean = np.mean(sample) expected_stdev = math.sqrt(sample_variance / sample_size) - assert abs(mean - sample_mean) / expected_stdev < 4 , (iteration, sample_size, abs(mean - sample_mean) / expected_stdev) + assert abs(mean - sample_mean) / expected_stdev < 4, ( + iteration, + sample_size, + abs(mean - sample_mean) / expected_stdev, + ) def test_local_agg(): - x = hl.literal([1,2,3,4]) + x = hl.literal([1, 2, 3, 4]) assert hl.eval(x.aggregate(lambda x: hl.agg.sum(x))) == 10 def test_zip_join_producers(): - contexts = hl.literal([1,2,3]) - zj = hl._zip_join_producers(contexts, - lambda i: hl.range(i).map(lambda x: hl.struct(k=x, stream_id=i)), - ['k'], - lambda k, vals: k.annotate(vals=vals)) + contexts = hl.literal([1, 2, 3]) + zj = hl._zip_join_producers( + contexts, + lambda i: hl.range(i).map(lambda x: hl.struct(k=x, stream_id=i)), + ['k'], + lambda k, vals: k.annotate(vals=vals), + ) assert hl.eval(zj) == [ - hl.utils.Struct(k=0, vals=[ - hl.utils.Struct(k=0, stream_id=1), - hl.utils.Struct(k=0, stream_id=2), - hl.utils.Struct(k=0, stream_id=3), - ]), - hl.utils.Struct(k=1, vals=[ - None, - hl.utils.Struct(k=1, stream_id=2), - hl.utils.Struct(k=1, stream_id=3), - ]), - hl.utils.Struct(k=2, vals=[ - None, - None, - hl.utils.Struct(k=2, stream_id=3), - ]) + hl.utils.Struct( + k=0, + vals=[ + hl.utils.Struct(k=0, stream_id=1), + hl.utils.Struct(k=0, stream_id=2), + hl.utils.Struct(k=0, stream_id=3), + ], + ), + 
hl.utils.Struct( + k=1, + vals=[ + None, + hl.utils.Struct(k=1, stream_id=2), + hl.utils.Struct(k=1, stream_id=3), + ], + ), + hl.utils.Struct( + k=2, + vals=[ + None, + None, + hl.utils.Struct(k=2, stream_id=3), + ], + ), ] diff --git a/hail/python/test/hail/expr/test_freezing.py b/hail/python/test/hail/expr/test_freezing.py index 9a39acaa22a..5418ad908e2 100644 --- a/hail/python/test/hail/expr/test_freezing.py +++ b/hail/python/test/hail/expr/test_freezing.py @@ -1,4 +1,5 @@ from typing import Dict + import hail as hl from hailtop.frozendict import frozendict from hailtop.hail_frozenlist import frozenlist @@ -6,7 +7,7 @@ def test_collect_as_set_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = ['hello']) + t = t.annotate_entries(l=['hello']) result = t.aggregate_entries(hl.agg.collect_as_set(t.l)) assert result == {frozenlist(['hello'])} @@ -14,7 +15,7 @@ def test_collect_as_set_list(): def test_counter_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = ['hello']) + t = t.annotate_entries(l=['hello']) result = t.aggregate_entries(hl.agg.counter(t.l)) assert list(result) == [frozenlist(['hello'])] @@ -24,7 +25,7 @@ def test_counter_list(): def test_collect_as_set_tuple_of_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = (['hello'],)) + t = t.annotate_entries(l=(['hello'],)) result = t.aggregate_entries(hl.agg.collect_as_set(t.l)) assert result == {(frozenlist(['hello']),)} @@ -32,7 +33,7 @@ def test_collect_as_set_tuple_of_list(): def test_counter_tuple_of_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = (['hello'],)) + t = t.annotate_entries(l=(['hello'],)) result = t.aggregate_entries(hl.agg.counter(t.l)) assert list(result) == [(frozenlist(['hello']),)] @@ -42,7 +43,7 @@ def test_counter_tuple_of_list(): def test_collect_as_set_struct_of_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = hl.struct(bad=['hello'], good=3)) + t = t.annotate_entries(l=hl.struct(bad=['hello'], good=3)) result = t.aggregate_entries(hl.agg.collect_as_set(t.l)) assert result == {(hl.Struct(bad=frozenlist(['hello']), good=3))} @@ -50,7 +51,7 @@ def test_collect_as_set_struct_of_list(): def test_counter_struct_of_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = hl.struct(bad=['hello'], good=3)) + t = t.annotate_entries(l=hl.struct(bad=['hello'], good=3)) result = t.aggregate_entries(hl.agg.counter(t.l)) assert list(result) == [hl.Struct(bad=frozenlist(['hello']), good=3)] @@ -60,7 +61,7 @@ def test_counter_struct_of_list(): def test_collect_as_set_dict_value_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = hl.dict([(3, ['hello'])])) + t = t.annotate_entries(l=hl.dict([(3, ['hello'])])) result = t.aggregate_entries(hl.agg.collect_as_set(t.l)) assert result == {frozendict({3: frozenlist(['hello'])})} @@ -68,7 +69,7 @@ def test_collect_as_set_dict_value_list(): def test_counter_dict_value_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = hl.dict([(3, ['hello'])])) + t = t.annotate_entries(l=hl.dict([(3, ['hello'])])) result = t.aggregate_entries(hl.agg.counter(t.l)) assert list(result) == [frozendict({3: frozenlist(['hello'])})] @@ -78,7 +79,7 @@ def test_counter_dict_value_list(): def test_collect_as_set_list_list_list_set_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = [[[hl.set([['hello']])]]]) + t = t.annotate_entries(l=[[[hl.set([['hello']])]]]) result = 
t.aggregate_entries(hl.agg.collect_as_set(t.l)) assert result == {frozenlist([frozenlist([frozenlist([frozenset([frozenlist(['hello'])])])])])} @@ -86,7 +87,7 @@ def test_collect_as_set_list_list_list_set_list(): def test_counter_list_list_list_set_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = [[[hl.set([['hello']])]]]) + t = t.annotate_entries(l=[[[hl.set([['hello']])]]]) result = t.aggregate_entries(hl.agg.counter(t.l)) assert list(result) == [frozenlist([frozenlist([frozenlist([frozenset([frozenlist(['hello'])])])])])] @@ -96,7 +97,7 @@ def test_counter_list_list_list_set_list(): def test_collect_dict_value_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = hl.dict([(3, ['hello'])])) + t = t.annotate_entries(l=hl.dict([(3, ['hello'])])) result = t.aggregate_entries(hl.agg.collect(t.l)) # NB: We never return dict, only frozendict, so we assert that. However, dict *values* must be @@ -109,7 +110,7 @@ def test_collect_dict_value_list(): def test_collect_dict_key_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = hl.dict([(['hello'], 3)])) + t = t.annotate_entries(l=hl.dict([(['hello'], 3)])) result = t.aggregate_entries(hl.agg.collect(t.l)) assert result == [frozendict({frozenlist(['hello']): 3})] @@ -117,7 +118,7 @@ def test_collect_dict_key_list(): def test_collect_dict_key_and_value_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = hl.dict([(['hello'], ['goodbye'])])) + t = t.annotate_entries(l=hl.dict([(['hello'], ['goodbye'])])) result = t.aggregate_entries(hl.agg.collect(t.l)) # NB: See note in test_collect_dict_value_list. @@ -126,7 +127,7 @@ def test_collect_dict_key_and_value_list(): def test_collect_set_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = hl.set([['hello']])) + t = t.annotate_entries(l=hl.set([['hello']])) result = t.aggregate_entries(hl.agg.collect(t.l)) assert result == [frozenset({frozenlist(['hello'])})] @@ -134,7 +135,7 @@ def test_collect_set_list(): def test_collect_set_dict_list_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = hl.set([hl.dict([(['hello'], ['goodbye'])])])) + t = t.annotate_entries(l=hl.set([hl.dict([(['hello'], ['goodbye'])])])) result = t.aggregate_entries(hl.agg.collect(t.l)) d: Dict[frozenlist[str], frozenlist[str]] = {frozenlist(['hello']): frozenlist(['goodbye'])} @@ -143,7 +144,7 @@ def test_collect_set_dict_list_list(): def test_collect_set_tuple_struct_struct_list(): t = hl.utils.range_matrix_table(1, 1) - t = t.annotate_entries(l = hl.set([(hl.struct(a=hl.struct(inside=['hello'], aside=4.0), b='abc'), 3)])) + t = t.annotate_entries(l=hl.set([(hl.struct(a=hl.struct(inside=['hello'], aside=4.0), b='abc'), 3)])) result = t.aggregate_entries(hl.agg.collect(t.l)) assert result == [frozenset({(hl.Struct(a=hl.Struct(inside=frozenlist(['hello']), aside=4.0), b='abc'), 3)})] diff --git a/hail/python/test/hail/expr/test_functions.py b/hail/python/test/hail/expr/test_functions.py index 5019de46db7..c68e1435b58 100644 --- a/hail/python/test/hail/expr/test_functions.py +++ b/hail/python/test/hail/expr/test_functions.py @@ -1,21 +1,21 @@ -import hail as hl -import scipy.stats as spst import pytest +import scipy.stats as spst + +import hail as hl + from ..helpers import resource def test_deprecated_binom_test(): - assert hl.eval(hl.binom_test(2, 10, 0.5, 'two.sided')) == \ - pytest.approx(spst.binom_test(2, 10, 0.5, 'two-sided')) + assert hl.eval(hl.binom_test(2, 10, 0.5, 'two.sided')) == 
pytest.approx(spst.binom_test(2, 10, 0.5, 'two-sided')) def test_binom_test(): - arglists = [[2, 10, 0.5, 'two-sided'], - [4, 10, 0.5, 'less'], - [32, 50, 0.4, 'greater']] + arglists = [[2, 10, 0.5, 'two-sided'], [4, 10, 0.5, 'less'], [32, 50, 0.4, 'greater']] for args in arglists: assert hl.eval(hl.binom_test(*args)) == pytest.approx(spst.binom_test(*args)), args + def test_pchisqtail(): def right_tail_from_scipy(x, df, ncp): if ncp: @@ -23,12 +23,7 @@ def right_tail_from_scipy(x, df, ncp): else: return 1 - spst.chi2.cdf(x, df) - arglists = [[3, 1, 2], - [5, 1, None], - [1, 3, 4], - [1, 3, None], - [3, 6, 0], - [3, 6, None]] + arglists = [[3, 1, 2], [5, 1, None], [1, 3, 4], [1, 3, None], [3, 6, 0], [3, 6, None]] for args in arglists: assert hl.eval(hl.pchisqtail(*args)) == pytest.approx(right_tail_from_scipy(*args)), args @@ -50,19 +45,21 @@ def test_pgenchisq(): 'lim': hl.tint32, 'acc': hl.tfloat64, 'expected': hl.tfloat64, - 'expected_n_iterations': hl.tint32 - } + 'expected_n_iterations': hl.tint32, + }, ) ht = ht.add_index('line_number') - ht = ht.annotate(line_number = ht.line_number + 1) - ht = ht.annotate(genchisq_result = hl.pgenchisq( - ht.c, ht.weights, ht.k, ht.lam, 0.0, ht.sigma, max_iterations=ht.lim, min_accuracy=ht.acc - )) + ht = ht.annotate(line_number=ht.line_number + 1) + ht = ht.annotate( + genchisq_result=hl.pgenchisq( + ht.c, ht.weights, ht.k, ht.lam, 0.0, ht.sigma, max_iterations=ht.lim, min_accuracy=ht.acc + ) + ) tests = ht.collect() for test in tests: assert abs(test.genchisq_result.value - test.expected) < 0.0000005, str(test) assert test.genchisq_result.fault == 0, str(test) - assert test.genchisq_result.converged == True, str(test) + assert test.genchisq_result.converged is True, str(test) assert test.genchisq_result.n_iterations == test.expected_n_iterations, str(test) @@ -74,12 +71,7 @@ def test_array(): hl.array(hl.nd.array([1, 2, 3, 3])), )) - expected = ( - [1, 2, 3, 3], - [1, 2, 3], - [(1, 5), (7, 4)], - [1, 2, 3, 3] - ) + expected = ([1, 2, 3, 3], [1, 2, 3], [(1, 5), (7, 4)], [1, 2, 3, 3]) assert actual == expected diff --git a/hail/python/test/hail/expr/test_math.py b/hail/python/test/hail/expr/test_math.py index d13a3766df0..f8f10691812 100644 --- a/hail/python/test/hail/expr/test_math.py +++ b/hail/python/test/hail/expr/test_math.py @@ -1,15 +1,17 @@ -import hail as hl import scipy.special as scsp -import pytest + +import hail as hl + def test_logit(): - assert hl.eval(hl.logit(.5)) == 0.0 + assert hl.eval(hl.logit(0.5)) == 0.0 assert hl.eval(hl.is_infinite(hl.logit(1.0))) assert hl.eval(hl.is_nan(hl.logit(1.01))) - assert hl.eval(hl.logit(.27)) == scsp.logit(.27) + assert hl.eval(hl.logit(0.27)) == scsp.logit(0.27) + def test_expit(): assert hl.eval(hl.expit(0.0)) == 0.5 assert hl.eval(hl.expit(800)) == 1.0 assert hl.eval(hl.expit(-920)) == 0.0 - assert hl.eval(hl.expit(.75)) == scsp.expit(.75) + assert hl.eval(hl.expit(0.75)) == scsp.expit(0.75) diff --git a/hail/python/test/hail/expr/test_ndarrays.py b/hail/python/test/hail/expr/test_ndarrays.py index 90d4d317cb8..45104ddecc9 100644 --- a/hail/python/test/hail/expr/test_ndarrays.py +++ b/hail/python/test/hail/expr/test_ndarrays.py @@ -1,11 +1,15 @@ import math -import numpy as np import re -from ..helpers import * + +import numpy as np import pytest +import hail as hl from hail.utils.java import FatalError, HailUserError +from ..helpers import assert_all_eval_to, assert_evals_to + + def assert_ndarrays(asserter, exprs_and_expecteds): exprs, expecteds = zip(*exprs_and_expecteds) @@ -13,7 +17,7 @@ 
def assert_ndarrays(asserter, exprs_and_expecteds): evaled_exprs = hl.eval(expr_tuple) evaled_and_expected = zip(evaled_exprs, expecteds) - for (idx, (evaled, expected)) in enumerate(evaled_and_expected): + for idx, (evaled, expected) in enumerate(evaled_and_expected): assert asserter(evaled, expected), f"NDArray comparison {idx} failed, got: {evaled}, expected: {expected}" @@ -26,7 +30,6 @@ def assert_ndarrays_almost_eq(*expr_and_expected): def test_ndarray_ref(): - scalar = 5.0 np_scalar = np.array(scalar) h_scalar = hl.nd.array(scalar) @@ -35,10 +38,7 @@ def test_ndarray_ref(): assert_evals_to(h_scalar[()], 5.0) assert_evals_to(h_np_scalar[()], 5.0) - cube = [[[0, 1], - [2, 3]], - [[4, 5], - [6, 7]]] + cube = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] h_cube = hl.nd.array(cube) h_np_cube = hl.nd.array(np.array(cube)) missing = hl.nd.array(hl.missing(hl.tarray(hl.tint32))) @@ -52,7 +52,7 @@ def test_ndarray_ref(): (hl.nd.array([[[1, 2]], [[3, 4]]])[1, 0, 0], 3), (missing[1], None), (hl.nd.array([1, 2, 3])[hl.missing(hl.tint32)], None), - (h_cube[0, 0, hl.missing(hl.tint32)], None) + (h_cube[0, 0, hl.missing(hl.tint32)], None), ) @@ -96,7 +96,7 @@ def test_ndarray_slice(): a = [0, 1] an = np.array(a) ah = hl.nd.array(a) - ae_np = np.arange(4*4*5*6*5*4).reshape((4, 4, 5, 6, 5, 4)) + ae_np = np.arange(4 * 4 * 5 * 6 * 5 * 4).reshape((4, 4, 5, 6, 5, 4)) ae = hl.nd.array(ae_np) assert_ndarrays_eq( (rect_prism[:, :, :], np_rect_prism[:, :, :]), @@ -105,10 +105,11 @@ def test_ndarray_slice(): (rect_prism[:, :, 1:4:2], np_rect_prism[:, :, 1:4:2]), (rect_prism[:, 2, 1:4:2], np_rect_prism[:, 2, 1:4:2]), (rect_prism[0, 2, 1:4:2], np_rect_prism[0, 2, 1:4:2]), - (rect_prism[0, :, 1:4:2] + rect_prism[:, :1, 1:4:2], - np_rect_prism[0, :, 1:4:2] + np_rect_prism[:, :1, 1:4:2]), - (rect_prism[0:, :, 1:4:2] + rect_prism[:, :1, 1:4:2], - np_rect_prism[0:, :, 1:4:2] + np_rect_prism[:, :1, 1:4:2]), + (rect_prism[0, :, 1:4:2] + rect_prism[:, :1, 1:4:2], np_rect_prism[0, :, 1:4:2] + np_rect_prism[:, :1, 1:4:2]), + ( + rect_prism[0:, :, 1:4:2] + rect_prism[:, :1, 1:4:2], + np_rect_prism[0:, :, 1:4:2] + np_rect_prism[:, :1, 1:4:2], + ), (rect_prism[0, 0, -3:-1], np_rect_prism[0, 0, -3:-1]), (rect_prism[-1, 0:1, 3:0:-1], np_rect_prism[-1, 0:1, 3:0:-1]), # partial indexing @@ -143,8 +144,6 @@ def test_ndarray_slice(): (flat[-4:-1:2], np_flat[-4:-1:2]), # ellipses inclusion (flat[...], np_flat[...]), - - (mat[::-1, :], np_mat[::-1, :]), (mat[0, 1:4:2] + mat[:, 1:4:2], np_mat[0, 1:4:2] + np_mat[:, 1:4:2]), (mat[-1:4:1, 0], np_mat[-1:4:1, 0]), @@ -179,25 +178,23 @@ def test_ndarray_slice(): (mat[0:1], np_mat[0:1]), # ellipses inclusion (mat[...], np_mat[...]), - (ah[:-3:1], an[:-3:1]), (ah[:-3:-1], an[:-3:-1]), (ah[-3::-1], an[-3::-1]), (ah[-3::1], an[-3::1]), - # ellipses inclusion (ae[..., 3], ae_np[..., 3]), (ae[3, ...], ae_np[3, ...]), (ae[2, 3, 1:2:2, ...], ae_np[2, 3, 1:2:2, ...]), (ae[3, 2, 3, ..., 2], ae_np[3, 2, 3, ..., 2]), (ae[3, 2, 2, ..., 2, 1:2:2], ae_np[3, 2, 2, ..., 2, 1:2:2]), - (ae[3, :, hl.nd.newaxis, ..., :, hl.nd.newaxis, 2], ae_np[3, :, np.newaxis, ..., :, np.newaxis, 2]) + (ae[3, :, hl.nd.newaxis, ..., :, hl.nd.newaxis, 2], ae_np[3, :, np.newaxis, ..., :, np.newaxis, 2]), ) - assert hl.eval(flat[hl.missing(hl.tint32):4:1]) is None - assert hl.eval(flat[4:hl.missing(hl.tint32)]) is None - assert hl.eval(flat[4:10:hl.missing(hl.tint32)]) is None - assert hl.eval(rect_prism[:, :, 0:hl.missing(hl.tint32):1]) is None + assert hl.eval(flat[hl.missing(hl.tint32) : 4 : 1]) is None + assert hl.eval(flat[4 : 
hl.missing(hl.tint32)]) is None + assert hl.eval(flat[4 : 10 : hl.missing(hl.tint32)]) is None + assert hl.eval(rect_prism[:, :, 0 : hl.missing(hl.tint32) : 1]) is None assert hl.eval(rect_prism[hl.missing(hl.tint32), :, :]) is None with pytest.raises(HailUserError, match="Slice step cannot be zero"): @@ -221,10 +218,7 @@ def test_ndarray_transposed_slice(): np_a = np.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) aT = a.T np_aT = np_a.T - assert_ndarrays_eq( - (a, np_a), - (aT[0:aT.shape[0], 0:5], np_aT[0:np_aT.shape[0], 0:5]) - ) + assert_ndarrays_eq((a, np_a), (aT[0 : aT.shape[0], 0:5], np_aT[0 : np_aT.shape[0], 0:5])) def test_ndarray_eval(): @@ -238,7 +232,7 @@ def test_ndarray_eval(): np_equiv = np.array(data_list, dtype=np.int32) np_equiv_fortran_style = np.asfortranarray(np_equiv) np_equiv_extra_dimension = np_equiv.reshape((3, 1, 3)) - assert(np.array_equal(evaled, np_equiv)) + assert np.array_equal(evaled, np_equiv) assert np.array_equal(hl.eval(hl.nd.array([])), np.array([])) @@ -258,7 +252,8 @@ def test_ndarray_eval(): # Testing from nested hail arrays assert np.array_equal( - hl.eval(hl.nd.array(hl.array([hl.array(x) for x in data_list]))), np.arange(9).reshape((3, 3)) + 1) + hl.eval(hl.nd.array(hl.array([hl.array(x) for x in data_list]))), np.arange(9).reshape((3, 3)) + 1 + ) # Testing missing data assert hl.eval(hl.nd.array(hl.missing(hl.tarray(hl.tint32)))) is None @@ -306,7 +301,7 @@ def test_ndarray_shape(): ((row + nd).shape, (np_row + np_nd).shape), ((row + col).shape, (np_row + np_col).shape), (m.transpose().shape, np_m.transpose().shape), - (missing.shape, None) + (missing.shape, None), ) @@ -348,12 +343,11 @@ def test_ndarray_reshape(): (hypercube.reshape((5, 7, 9, 3)).reshape((7, 9, 3, 5)), np_hypercube.reshape((7, 9, 3, 5))), (hypercube.reshape(hl.tuple([5, 7, 9, 3])), np_hypercube.reshape((5, 7, 9, 3))), (shape_zero.reshape((0, 5)), np_shape_zero.reshape((0, 5))), - (shape_zero.reshape((-1, 5)), np_shape_zero.reshape((-1, 5))) + (shape_zero.reshape((-1, 5)), np_shape_zero.reshape((-1, 5))), ) assert hl.eval(hl.missing(hl.tndarray(hl.tfloat, 2)).reshape((4, 5))) is None - assert hl.eval(hl.nd.array(hl.range(20)).reshape( - hl.missing(hl.ttuple(hl.tint64, hl.tint64)))) is None + assert hl.eval(hl.nd.array(hl.range(20)).reshape(hl.missing(hl.ttuple(hl.tint64, hl.tint64)))) is None with pytest.raises(HailUserError) as exc: hl.eval(hl.literal(np_cube).reshape((-1, -1))) @@ -396,8 +390,8 @@ def test_ndarray_map1(): assert_ndarrays_eq( (b, [[-2, -3, -4], [-5, -6, -7]]), (b2, [[4, 9, 16], [25, 36, 49]]), - (c, [[True, True, True], - [True, True, True]])) + (c, [[True, True, True], [True, True, True]]), + ) assert hl.eval(hl.missing(hl.tndarray(hl.tfloat, 1)).map(lambda x: x * 2)) is None @@ -411,20 +405,13 @@ def test_ndarray_map1(): def test_ndarray_map2(): - a = 2.0 b = 3.0 x = np.array([a, b]) y = np.array([b, a]) row_vec = np.array([[1, 2]]) - cube1 = np.array([[[1, 2], - [3, 4]], - [[5, 6], - [7, 8]]]) - cube2 = np.array([[[9, 10], - [11, 12]], - [[13, 14], - [15, 16]]]) + cube1 = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + cube2 = np.array([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]) empty = np.array([], np.int32).reshape((0, 2, 2)) na = hl.nd.array(a) @@ -444,8 +431,8 @@ def test_ndarray_map2(): (na + na, np.array(a + a)), (nx + ny, x + y), (ncube1 + ncube2, cube1 + cube2), - (nx.map2(y, lambda c, d: c+d), x + y), - (ncube1.map2(cube2, lambda c, d: c+d), cube1 + cube2), + (nx.map2(y, lambda c, d: c + d), x + y), + (ncube1.map2(cube2, lambda c, d: c + d), cube1 + 
cube2), # Broadcasting (ncube1 + na, cube1 + a), (na + ncube1, a + cube1), @@ -456,10 +443,8 @@ def test_ndarray_map2(): (nrow_vec + ncube1, row_vec + cube1), (ncube1 + nrow_vec, cube1 + row_vec), (nrow_vec + nempty, row_vec + empty), - (ncube1.map2(na, lambda c, d: c+d), cube1 + a), - (nrow_vec.map2(ncube1, lambda c, d: c+d), row_vec + cube1), - - + (ncube1.map2(na, lambda c, d: c + d), cube1 + a), + (nrow_vec.map2(ncube1, lambda c, d: c + d), row_vec + cube1), # Subtraction (na - na, np.array(a - a)), (nx - nx, x - x), @@ -471,7 +456,6 @@ def test_ndarray_map2(): (ny - ncube1, y - cube1), (ncube1 - nrow_vec, cube1 - row_vec), (nrow_vec - ncube1, row_vec - cube1), - # Multiplication (na * na, np.array(a * a)), (nx * nx, x * x), @@ -485,8 +469,6 @@ def test_ndarray_map2(): (ny * ncube1, y * cube1), (ncube1 * nrow_vec, cube1 * row_vec), (nrow_vec * ncube1, row_vec * cube1), - - # Floor div (na // na, np.array(a // a)), (nx // nx, x // x), @@ -499,7 +481,7 @@ def test_ndarray_map2(): (ncube1 // ny, cube1 // y), (ny // ncube1, y // cube1), (ncube1 // nrow_vec, cube1 // row_vec), - (nrow_vec // ncube1, row_vec // cube1) + (nrow_vec // ncube1, row_vec // cube1), ) # Division @@ -515,7 +497,8 @@ def test_ndarray_map2(): (ncube1 / ny, cube1 / y), (ny / ncube1, y / cube1), (ncube1 / nrow_vec, cube1 / row_vec), - (nrow_vec / ncube1, row_vec / cube1)) + (nrow_vec / ncube1, row_vec / cube1), + ) # Missingness tests missing = hl.missing(hl.tndarray(hl.tfloat64, 2)) @@ -525,14 +508,14 @@ def test_ndarray_map2(): assert hl.eval(missing + present) is None assert hl.eval(present + missing) is None + def test_ndarray_sum(): np_m = np.array([[1, 2], [3, 4]]) m = hl.nd.array(np_m) assert_ndarrays_eq( - (m.sum(axis=0), np_m.sum(axis=0)), - (m.sum(axis=1), np_m.sum(axis=1)), - (m.sum(tuple([])), np_m.sum(tuple([])))) + (m.sum(axis=0), np_m.sum(axis=0)), (m.sum(axis=1), np_m.sum(axis=1)), (m.sum(tuple([])), np_m.sum(tuple([]))) + ) assert hl.eval(m.sum()) == 10 assert hl.eval(m.sum((0, 1))) == 10 @@ -552,10 +535,7 @@ def test_ndarray_sum(): def test_ndarray_transpose(): np_v = np.array([1, 2, 3]) np_m = np.array([[1, 2, 3], [4, 5, 6]]) - np_cube = np.array([[[1, 2], - [3, 4]], - [[5, 6], - [7, 8]]]) + np_cube = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) v = hl.nd.array(np_v) m = hl.nd.array(np_m) cube = hl.nd.array(np_cube) @@ -565,7 +545,8 @@ def test_ndarray_transpose(): (v.T, np_v), (m.T, np_m.T), (cube.transpose((0, 2, 1)), np_cube.transpose((0, 2, 1))), - (cube.T, np_cube.T)) + (cube.T, np_cube.T), + ) assert hl.eval(hl.missing(hl.tndarray(hl.tfloat, 1)).T) is None @@ -581,6 +562,7 @@ def test_ndarray_transpose(): cube.transpose((1, 1, 1)) assert "Axes cannot contain duplicates" in str(exc.value) + def test_ndarray_matmul(): np_v = np.array([1, 2]) np_y = np.array([1, 1, 1]) @@ -641,15 +623,12 @@ def test_ndarray_matmul(): (broadcasted_mat @ rect_prism, np_broadcasted_mat @ np_rect_prism), (six_dim_tensor @ five_dim_tensor, np_six_dim_tensor @ np_five_dim_tensor), (zero_by_four @ ones_float64, np_zero_by_four, np_ones_float64), - (zero_by_four.transpose() @ zero_by_four, np_zero_by_four.transpose() @ np_zero_by_four) + (zero_by_four.transpose() @ zero_by_four, np_zero_by_four.transpose() @ np_zero_by_four), ) - assert hl.eval(hl.missing(hl.tndarray(hl.tfloat64, 2)) @ - hl.missing(hl.tndarray(hl.tfloat64, 2))) is None - assert hl.eval(hl.missing(hl.tndarray(hl.tint64, 2)) @ - hl.nd.array(np.arange(10).reshape(5, 2))) is None - assert hl.eval(hl.nd.array(np.arange(10).reshape(5, 2)) @ - 
hl.missing(hl.tndarray(hl.tint64, 2))) is None + assert hl.eval(hl.missing(hl.tndarray(hl.tfloat64, 2)) @ hl.missing(hl.tndarray(hl.tfloat64, 2))) is None + assert hl.eval(hl.missing(hl.tndarray(hl.tint64, 2)) @ hl.nd.array(np.arange(10).reshape(5, 2))) is None + assert hl.eval(hl.nd.array(np.arange(10).reshape(5, 2)) @ hl.missing(hl.tndarray(hl.tint64, 2))) is None assert np.array_equal(hl.eval(ones_int32 @ ones_float64), np_ones_int32 @ np_ones_float64) @@ -664,12 +643,15 @@ def test_ndarray_matmul(): with pytest.raises(HailUserError) as exc: hl.eval(r @ r) - assert "Matrix dimensions incompatible: (2, 3) can't be multiplied by matrix with dimensions (2, 3)" in str(exc.value), str(exc.value) + assert "Matrix dimensions incompatible: (2, 3) can't be multiplied by matrix with dimensions (2, 3)" in str( + exc.value + ), str(exc.value) with pytest.raises(HailUserError) as exc: hl.eval(hl.nd.array([1, 2]) @ hl.nd.array([1, 2, 3])) assert "Matrix dimensions incompatible" in str(exc.value) + def test_ndarray_matmul_dgemv(): np_mat_3_4 = np.arange(12, dtype=np.float64).reshape((3, 4)) np_mat_4_3 = np.arange(12, dtype=np.float64).reshape((4, 3)) @@ -684,9 +666,10 @@ def test_ndarray_matmul_dgemv(): assert_ndarrays_eq( (mat_3_4 @ vec_4, np_mat_3_4 @ np_vec_4), (mat_4_3 @ vec_3, np_mat_4_3 @ np_vec_3), - (mat_3_4.T @ vec_3, np_mat_3_4.T @ np_vec_3) + (mat_3_4.T @ vec_3, np_mat_3_4.T @ np_vec_3), ) + def test_ndarray_big(): assert hl.eval(hl.nd.array(hl.range(100_000))).size == 100_000 @@ -698,7 +681,7 @@ def test_ndarray_full(): (hl.nd.ones(6), np.ones(6)), (hl.nd.ones((6, 6, 6)), np.ones((6, 6, 6))), (hl.nd.full(7, 9), np.full(7, 9)), - (hl.nd.full((3, 4, 5), 9), np.full((3, 4, 5), 9)) + (hl.nd.full((3, 4, 5), 9), np.full((3, 4, 5), 9)), ) assert hl.eval(hl.nd.zeros((5, 5), dtype=hl.tfloat32)).dtype, np.float32 @@ -710,7 +693,7 @@ def test_ndarray_arange(): assert_ndarrays_eq( (hl.nd.arange(40), np.arange(40)), (hl.nd.arange(5, 50), np.arange(5, 50)), - (hl.nd.arange(2, 47, 13), np.arange(2, 47, 13)) + (hl.nd.arange(2, 47, 13), np.arange(2, 47, 13)), ) with pytest.raises(HailUserError) as exc: @@ -719,13 +702,21 @@ def test_ndarray_arange(): def test_ndarray_mixed(): - assert hl.eval(hl.missing(hl.tndarray(hl.tint64, 2)).map( - lambda x: x * x).reshape((4, 5)).T) is None - assert hl.eval( - (hl.nd.zeros((5, 10)).map(lambda x: x - 2) + - hl.nd.ones((5, 10)).map(lambda x: x + 5)).reshape(hl.missing(hl.ttuple(hl.tint64, hl.tint64))).T.reshape((10, 5))) is None - assert hl.eval(hl.or_missing(False, hl.nd.array(np.arange(10)).reshape( - (5, 2)).map(lambda x: x * 2)).map(lambda y: y * 2)) is None + assert hl.eval(hl.missing(hl.tndarray(hl.tint64, 2)).map(lambda x: x * x).reshape((4, 5)).T) is None + assert ( + hl.eval( + (hl.nd.zeros((5, 10)).map(lambda x: x - 2) + hl.nd.ones((5, 10)).map(lambda x: x + 5)) + .reshape(hl.missing(hl.ttuple(hl.tint64, hl.tint64))) + .T.reshape((10, 5)) + ) + is None + ) + assert ( + hl.eval( + hl.or_missing(False, hl.nd.array(np.arange(10)).reshape((5, 2)).map(lambda x: x * 2)).map(lambda y: y * 2) + ) + is None + ) def test_ndarray_show(): @@ -737,10 +728,8 @@ def test_ndarray_show(): def test_ndarray_diagonal(): assert np.array_equal(hl.eval(hl.nd.diagonal(hl.nd.array([[1, 2], [3, 4]]))), np.array([1, 4])) - assert np.array_equal(hl.eval(hl.nd.diagonal( - hl.nd.array([[1, 2, 3], [4, 5, 6]]))), np.array([1, 5])) - assert np.array_equal(hl.eval(hl.nd.diagonal( - hl.nd.array([[1, 2], [3, 4], [5, 6]]))), np.array([1, 4])) + assert 
np.array_equal(hl.eval(hl.nd.diagonal(hl.nd.array([[1, 2, 3], [4, 5, 6]]))), np.array([1, 5])) + assert np.array_equal(hl.eval(hl.nd.diagonal(hl.nd.array([[1, 2], [3, 4], [5, 6]]))), np.array([1, 4])) with pytest.raises(AssertionError) as exc: hl.nd.diagonal(hl.nd.array([1, 2])) @@ -758,21 +747,22 @@ def test_ndarray_solve_triangular(): a_sing = hl.nd.array([[0, 1], [0, 1]]) b_sing = hl.nd.array([2, 2]) - assert np.allclose(hl.eval(hl.nd.solve_triangular(a, b)), np.array([1., 1.])) - assert np.allclose(hl.eval(hl.nd.solve_triangular(a, b2)), np.array([[5., 2.], [6., 3.]])) - assert np.allclose(hl.eval(hl.nd.solve_triangular(a_low, b_low, True)), np.array([[1., 3.]])) + assert np.allclose(hl.eval(hl.nd.solve_triangular(a, b)), np.array([1.0, 1.0])) + assert np.allclose(hl.eval(hl.nd.solve_triangular(a, b2)), np.array([[5.0, 2.0], [6.0, 3.0]])) + assert np.allclose(hl.eval(hl.nd.solve_triangular(a_low, b_low, True)), np.array([[1.0, 3.0]])) with pytest.raises(HailUserError) as exc: hl.eval(hl.nd.solve_triangular(a_sing, b_sing)) assert "singular" in str(exc.value), str(exc.value) + def test_ndarray_solve(): a = hl.nd.array([[1, 2], [3, 5]]) b = hl.nd.array([1, 2]) b2 = hl.nd.array([[1, 8], [2, 12]]) - assert np.allclose(hl.eval(hl.nd.solve(a, b)), np.array([-1., 1.])) - assert np.allclose(hl.eval(hl.nd.solve(a, b2)), np.array([[-1., -16.], [1, 12]])) - assert np.allclose(hl.eval(hl.nd.solve(a.T, b2.T)), np.array([[19., 26.], [-6, -8]])) + assert np.allclose(hl.eval(hl.nd.solve(a, b)), np.array([-1.0, 1.0])) + assert np.allclose(hl.eval(hl.nd.solve(a, b2)), np.array([[-1.0, -16.0], [1, 12]])) + assert np.allclose(hl.eval(hl.nd.solve(a.T, b2.T)), np.array([[19.0, 26.0], [-6, -8]])) with pytest.raises(HailUserError) as exc: hl.eval(hl.nd.solve(hl.nd.array([[1, 2], [1, 2]]), hl.nd.array([8, 10]))) @@ -795,8 +785,7 @@ def assert_raw_equivalence(hl_ndarray, np_ndarray): assert np.allclose(ndarray_tau[:rank], np_ndarray_tau[:rank]) def assert_r_equivalence(hl_ndarray, np_ndarray): - assert np.allclose(hl.eval(hl.nd.qr(hl_ndarray, mode="r")), - np.linalg.qr(np_ndarray, mode="r")) + assert np.allclose(hl.eval(hl.nd.qr(hl_ndarray, mode="r")), np.linalg.qr(np_ndarray, mode="r")) def assert_reduced_equivalence(hl_ndarray, np_ndarray): q, r = hl.eval(hl.nd.qr(hl_ndarray, mode="reduced")) @@ -859,9 +848,7 @@ def assert_same_qr(hl_ndarray, np_ndarray): assert_same_qr(nine_square, np_nine_square) - np_wiki_example = np.array([[12, -51, 4], - [6, 167, -68], - [-4, 24, -41]]) + np_wiki_example = np.array([[12, -51, 4], [6, 167, -68], [-4, 24, -41]]) wiki_example = hl.nd.array(np_wiki_example) assert_same_qr(wiki_example, np_wiki_example) @@ -905,7 +892,6 @@ def assert_evals_to_same_svd(nd_expr, np_array, full_matrices=True, compute_uv=T assert h.shape == n.shape k = min(np_array.shape) - rank = np.linalg.matrix_rank(np_array) if compute_uv: hu, hs, hv = evaled @@ -1002,10 +988,8 @@ def test_numpy_interop(): assert np.array_equal(hl.eval(np.array(a) @ hl.nd.array(b)), np.array([[21]])) assert np.array_equal(hl.eval(hl.nd.array(a) @ np.array(b)), np.array([[21]])) - assert np.array_equal(hl.eval(hl.nd.array(b) @ np.array(a)), - np.array([[6, 9], [10, 15]])) - assert np.array_equal(hl.eval(np.array(b) @ hl.nd.array(a)), - np.array([[6, 9], [10, 15]])) + assert np.array_equal(hl.eval(hl.nd.array(b) @ np.array(a)), np.array([[6, 9], [10, 15]])) + assert np.array_equal(hl.eval(np.array(b) @ hl.nd.array(a)), np.array([[6, 9], [10, 15]])) def test_ndarray_emitter_extract(): @@ -1028,7 +1012,7 @@ def test_ndarray(): 
an1 = np.array((1, 2, 3)) an2 = np.array([1, 2, 3]) - assert(np.array_equal(a1, a2) and np.array_equal(a2, an2)) + assert np.array_equal(a1, a2) and np.array_equal(a2, an2) a1 = hl.eval(hl.nd.array(((1), (2), (3)))) a2 = hl.eval(hl.nd.array(([1], [2], [3]))) @@ -1038,7 +1022,7 @@ def test_ndarray(): an2 = np.array(([1], [2], [3])) an3 = np.array([[1], [2], [3]]) - assert(np.array_equal(a1, an1) and np.array_equal(a2, an2) and np.array_equal(a3, an3)) + assert np.array_equal(a1, an1) and np.array_equal(a2, an2) and np.array_equal(a3, an3) a1 = hl.eval(hl.nd.array(((1, 2), (2, 5), (3, 8)))) a2 = hl.eval(hl.nd.array([[1, 2], [2, 5], [3, 8]])) @@ -1046,7 +1030,7 @@ def test_ndarray(): an1 = np.array(((1, 2), (2, 5), (3, 8))) an2 = np.array([[1, 2], [2, 5], [3, 8]]) - assert(np.array_equal(a1, an1) and np.array_equal(a2, an2)) + assert np.array_equal(a1, an1) and np.array_equal(a2, an2) def test_cast(): @@ -1054,7 +1038,7 @@ def testequal(a, hdtype, ndtype): ah = hl.eval(hl.nd.array(a, dtype=hdtype)) an = np.array(a, dtype=ndtype) - assert(ah.dtype == an.dtype) + assert ah.dtype == an.dtype def test(a): testequal(a, hl.tfloat64, np.float64) @@ -1063,8 +1047,8 @@ def test(a): testequal(a, hl.tint64, np.int64) test([1, 2, 3]) - test([1, 2, 3.]) - test([1., 2., 3.]) + test([1, 2, 3.0]) + test([1.0, 2.0, 3.0]) test([[1, 2], [3, 4]]) @@ -1076,8 +1060,8 @@ def test_inv(): def test_concatenate(): - x = np.array([[1., 2.], [3., 4.]]) - y = np.array([[5.], [6.]]) + x = np.array([[1.0, 2.0], [3.0, 4.0]]) + y = np.array([[5.0], [6.0]]) np_res = np.concatenate([x, y], axis=1) res = hl.eval(hl.nd.concatenate([x, y], axis=1)) @@ -1109,23 +1093,26 @@ def test_concatenate(): def test_concatenate_differing_shapes(): - with pytest.raises(ValueError, match='hl.nd.concatenate: ndarrays must have same number of dimensions, found: 1, 2'): - hl.nd.concatenate([ - hl.nd.array([1]), - hl.nd.array([[1]]) - ]) - - with pytest.raises(ValueError, match=re.escape('hl.nd.concatenate: ndarrays must have same element types, found these element types: (int32, float64)')): - hl.nd.concatenate([ - hl.nd.array([1]), - hl.nd.array([1.0]) - ]) - - with pytest.raises(ValueError, match=re.escape('hl.nd.concatenate: ndarrays must have same element types, found these element types: (int32, float64)')): - hl.nd.concatenate([ - hl.nd.array([1]), - hl.nd.array([[1.0]]) - ]) + with pytest.raises( + ValueError, match='hl.nd.concatenate: ndarrays must have same number of dimensions, found: 1, 2' + ): + hl.nd.concatenate([hl.nd.array([1]), hl.nd.array([[1]])]) + + with pytest.raises( + ValueError, + match=re.escape( + 'hl.nd.concatenate: ndarrays must have same element types, found these element types: (int32, float64)' + ), + ): + hl.nd.concatenate([hl.nd.array([1]), hl.nd.array([1.0])]) + + with pytest.raises( + ValueError, + match=re.escape( + 'hl.nd.concatenate: ndarrays must have same element types, found these element types: (int32, float64)' + ), + ): + hl.nd.concatenate([hl.nd.array([1]), hl.nd.array([[1.0]])]) def make_test_vstack_data(): @@ -1137,7 +1124,6 @@ def make_test_vstack_data(): yield a, empty, b yield empty, a, b - a = np.array([1, 2, 3]) b = np.array([2, 3, 4]) yield a, b @@ -1145,8 +1131,8 @@ def make_test_vstack_data(): @pytest.mark.parametrize("data", make_test_vstack_data()) def test_vstack(data): - assert(np.array_equal(hl.eval(hl.nd.vstack(data)), np.vstack(data))) - assert(np.array_equal(hl.eval(hl.nd.vstack(hl.array(list(data)))), np.vstack(data))) + assert np.array_equal(hl.eval(hl.nd.vstack(data)), 
np.vstack(data)) + assert np.array_equal(hl.eval(hl.nd.vstack(hl.array(list(data)))), np.vstack(data)) def make_test_vstack_2_data(): @@ -1195,8 +1181,8 @@ def assert_table(a, b): ht2 = ht2.annotate(stacked=hl.nd.hstack([ht2.x, ht2.y])) assert np.array_equal(ht2.collect()[0].stacked, np.hstack([a, b])) - assert(np.array_equal(hl.eval(hl.nd.hstack((a, b))), np.hstack((a, b)))) - assert(np.array_equal(hl.eval(hl.nd.hstack(hl.array([a, b]))), np.hstack((a, b)))) + assert np.array_equal(hl.eval(hl.nd.hstack((a, b))), np.hstack((a, b))) + assert np.array_equal(hl.eval(hl.nd.hstack(hl.array([a, b]))), np.hstack((a, b))) assert_table(a, b) @@ -1232,7 +1218,7 @@ def test_agg_ndarray_sum_ones_2d(): def test_agg_ndarray_sum_with_transposes(): transposes = hl.utils.range_table(4).annotate(x=hl.nd.arange(16).reshape((4, 4))) - transposes = transposes.annotate(x = hl.if_else((transposes.idx % 2) == 0, transposes.x, transposes.x.T)) + transposes = transposes.annotate(x=hl.if_else((transposes.idx % 2) == 0, transposes.x, transposes.x.T)) np_arange_4_by_4 = np.arange(16).reshape((4, 4)) transposes_result = (np_arange_4_by_4 * 2) + (np_arange_4_by_4.T * 2) assert np.array_equal(transposes.aggregate(hl.agg.ndarray_sum(transposes.x)), transposes_result) @@ -1261,7 +1247,7 @@ def test_maximum_minimuim(): (hl.nd.maximum(nx, ny), np.maximum(x, y)), (hl.nd.maximum(ny, z), np.maximum(y, z)), (hl.nd.minimum(nx, ny), np.minimum(x, y)), - (hl.nd.minimum(ny, z), np.minimum(y, z)), + (hl.nd.minimum(ny, z), np.minimum(y, z)), ) np_nan_max = np.maximum(nan_elem, f) @@ -1281,39 +1267,39 @@ def test_maximum_minimuim(): elif np.isnan(a) and np.isnan(b): min_matches += 1 - assert(nan_max.size == max_matches) - assert(nan_min.size == min_matches) + assert nan_max.size == max_matches + assert nan_min.size == min_matches def test_ndarray_broadcasting_with_decorator(): nd = hl.nd.array([[1, 4, 9], [16, 25, 36]]) nd_sqrt = hl.eval(hl.nd.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) nd = hl.eval(hl.sqrt(nd)) - assert(np.array_equal(nd, nd_sqrt)) + assert np.array_equal(nd, nd_sqrt) nd = hl.nd.array([[10, 100, 1000], [10000, 100000, 1000000]]) nd_log10 = hl.eval(hl.nd.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) nd = hl.eval(hl.log10(nd)) - assert(np.array_equal(nd, nd_log10)) + assert np.array_equal(nd, nd_log10) nd = hl.nd.array([[1.2, 2.3, 3.3], [4.3, 5.3, 6.3]]) nd_floor = hl.eval(hl.nd.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) nd = hl.eval(hl.floor(nd)) - assert(np.array_equal(nd, nd_floor)) + assert np.array_equal(nd, nd_floor) def test_ndarray_indices_aggregations(): ht = hl.utils.range_table(1) - ht = ht.annotate_globals(g = hl.nd.ones((2, 2))) - ht = ht.annotate(x = hl.nd.ones((2, 2))) - ht = ht.annotate(a = hl.nd.solve(ht.x, 2 * ht.g)) - ht = ht.annotate(b = hl.nd.solve(2 * ht.g, ht.x)) - ht = ht.annotate(c = hl.nd.solve_triangular(2 * ht.g, hl.nd.eye(2))) - ht = ht.annotate(d = hl.nd.solve_triangular(hl.nd.eye(2), 2 * ht.g)) - ht = ht.annotate(e = hl.nd.svd(ht.x)) - ht = ht.annotate(f = hl.nd.inv(ht.x)) - ht = ht.annotate(h = hl.nd.concatenate((ht.x, ht.g))) - ht = ht.annotate(i = hl.nd.concatenate((ht.g, ht.x))) + ht = ht.annotate_globals(g=hl.nd.ones((2, 2))) + ht = ht.annotate(x=hl.nd.ones((2, 2))) + ht = ht.annotate(a=hl.nd.solve(ht.x, 2 * ht.g)) + ht = ht.annotate(b=hl.nd.solve(2 * ht.g, ht.x)) + ht = ht.annotate(c=hl.nd.solve_triangular(2 * ht.g, hl.nd.eye(2))) + ht = ht.annotate(d=hl.nd.solve_triangular(hl.nd.eye(2), 2 * ht.g)) + ht = ht.annotate(e=hl.nd.svd(ht.x)) + ht = ht.annotate(f=hl.nd.inv(ht.x)) + ht = 
ht.annotate(h=hl.nd.concatenate((ht.x, ht.g))) + ht = ht.annotate(i=hl.nd.concatenate((ht.g, ht.x))) def test_ndarray_log_broadcasting(): diff --git a/hail/python/test/hail/expr/test_show.py b/hail/python/test/hail/expr/test_show.py index 3dce8a8de21..873ba5b8c99 100644 --- a/hail/python/test/hail/expr/test_show.py +++ b/hail/python/test/hail/expr/test_show.py @@ -1,7 +1,5 @@ import hail as hl -from ..helpers import test_timeout - def test_show_1(): mt = hl.balding_nichols_model(3, 10, 10) @@ -26,6 +24,7 @@ def test_show_4(): mt.bn.fst.show() mt.GT.n_alt_alleles().show() + def test_show_5(): mt = hl.balding_nichols_model(3, 10, 10) (mt.GT.n_alt_alleles() * mt.GT.n_alt_alleles()).show() @@ -46,7 +45,7 @@ def test_show_mt_duplicate_col_key(): shown_cols = 2 mt = hl.utils.range_matrix_table(5, 5) - mt = mt.key_cols_by(c = 0) + mt = mt.key_cols_by(c=0) showobj = mt.show(n_cols=shown_cols, handler=lambda x: x) assert len(showobj.table_show.table.row) == len(mt.row) + shown_cols @@ -56,7 +55,7 @@ def test_show_mt_fewer_cols(): shown_cols = 7 mt = hl.utils.range_matrix_table(5, 5) - mt = mt.key_cols_by(c = 0) + mt = mt.key_cols_by(c=0) showobj = mt.show(n_cols=shown_cols, handler=lambda x: x) assert len(showobj.table_show.table.row) == len(mt.row) + mt.count_cols() diff --git a/hail/python/test/hail/expr/test_types.py b/hail/python/test/hail/expr/test_types.py index 1074873d369..20c0b589b2b 100644 --- a/hail/python/test/hail/expr/test_types.py +++ b/hail/python/test/hail/expr/test_types.py @@ -1,12 +1,32 @@ -from typing import Optional import unittest +from typing import Optional +import hail as hl from hail.expr import coercer_from_dtype -from hail.expr.types import * -from hail.genetics import reference_genome -from ..helpers import * +from hail.expr.types import ( + HailType, + dtype, + tarray, + tbool, + tcall, + tdict, + tfloat32, + tfloat64, + tint32, + tint64, + tinterval, + tlocus, + tndarray, + tset, + tstr, + tstruct, + ttuple, + tunion, +) from hail.utils.java import Env +from ..helpers import fails_local_backend, resource, skip_unless_spark_backend, skip_when_service_backend + class Tests(unittest.TestCase): def types_to_test(self): @@ -28,9 +48,11 @@ def types_to_test(self): tlocus('GRCh38'), tstruct(), tstruct(x=tint32, y=tint64, z=tarray(tset(tstr))), - tstruct(**{'weird field name 1': tint32, - r"""this one ' has "" quotes and `` backticks```""": tint64, - '!@#$%^&({[': tarray(tset(tstr))}), + tstruct(**{ + 'weird field name 1': tint32, + r"""this one ' has "" quotes and `` backticks```""": tint64, + '!@#$%^&({[': tarray(tset(tstr)), + }), tinterval(tlocus()), tset(tinterval(tlocus())), tstruct(a=tint32, b=tint32, c=tarray(tstr)), @@ -42,7 +64,8 @@ def types_to_test(self): tunion(**{'!@#$%^&({[': tstr}), ttuple(tstr, tint32), ttuple(tarray(tint32), tstr, tstr, tint32, tbool), - ttuple()] + ttuple(), + ] def test_parser_roundtrip(self): for t in self.types_to_test(): @@ -104,38 +127,48 @@ def test_get_context(self): tl2 = tlocus('GRCh38') types_and_rgs = [ - ([ - tint32, - tint64, - tfloat32, - tfloat64, - tstr, - tbool, - tcall, - tinterval(tset(tint32)), - tdict(tstr, tarray(tint32)), - tndarray(tstr, 1), - tstruct(), - tstruct(x=tint32, y=tint64, z=tarray(tset(tstr))), - tunion(), - tunion(a=tint32, b=tstr), - ttuple(tstr, tint32), - ttuple()], set()), - ([ - tl1, - tinterval(tl1), - tdict(tstr, tl1), - tndarray(tl1, 2), - tinterval(tl1), - tset(tinterval(tl1)), - tstruct(a=tint32, b=tint32, c=tarray(tl1)), - tunion(a=tint32, b=tl1), - ttuple(tarray(tint32), tl1, tstr, tint32, 
tbool), - ], {"GRCh37"}), - ([ - tdict(tl1, tl2), - ttuple(tarray(tl2), tl1, tstr, tint32, tbool), - ], {"GRCh37", "GRCh38"}) + ( + [ + tint32, + tint64, + tfloat32, + tfloat64, + tstr, + tbool, + tcall, + tinterval(tset(tint32)), + tdict(tstr, tarray(tint32)), + tndarray(tstr, 1), + tstruct(), + tstruct(x=tint32, y=tint64, z=tarray(tset(tstr))), + tunion(), + tunion(a=tint32, b=tstr), + ttuple(tstr, tint32), + ttuple(), + ], + set(), + ), + ( + [ + tl1, + tinterval(tl1), + tdict(tstr, tl1), + tndarray(tl1, 2), + tinterval(tl1), + tset(tinterval(tl1)), + tstruct(a=tint32, b=tint32, c=tarray(tl1)), + tunion(a=tint32, b=tl1), + ttuple(tarray(tint32), tl1, tstr, tint32, tbool), + ], + {"GRCh37"}, + ), + ( + [ + tdict(tl1, tl2), + ttuple(tarray(tl2), tl1, tstr, tint32, tbool), + ], + {"GRCh37", "GRCh38"}, + ), ] for types, rgs in types_and_rgs: diff --git a/hail/python/test/hail/extract_intervals/conftest.py b/hail/python/test/hail/extract_intervals/conftest.py index a57420c1e84..56ff38e5aa4 100644 --- a/hail/python/test/hail/extract_intervals/conftest.py +++ b/hail/python/test/hail/extract_intervals/conftest.py @@ -1,4 +1,5 @@ import pytest + import hail as hl from ..helpers import resource diff --git a/hail/python/test/hail/extract_intervals/test_full_key.py b/hail/python/test/hail/extract_intervals/test_full_key.py index 4d16b22f841..d68fedc2059 100644 --- a/hail/python/test/hail/extract_intervals/test_full_key.py +++ b/hail/python/test/hail/extract_intervals/test_full_key.py @@ -13,41 +13,43 @@ def test_ht_eq(ht, probe_variant): expr = ht.filter(ht.key == probe_variant) assert expr.n_partitions() == 1 actual = expr.collect() - expected = [hl.Struct( - locus=hl.Locus(contig=20, position=17434581, reference_genome='GRCh37'), - alleles=['A', 'G'], - rsid='rs16999198', - qual=21384.8, - filters=set(), - info=hl.Struct( - NEGATIVE_TRAIN_SITE=False, - HWP=1.0, - AC=[2], - culprit='InbreedingCoeff', - MQ0=0, - ReadPosRankSum=0.534, - AN=200, - InbreedingCoeff=-0.0134, - AF=[0.013], - GQ_STDDEV=134.2, - FS=2.944, - DP=22586, - GQ_MEAN=83.43, - POSITIVE_TRAIN_SITE=True, - VQSLOD=4.77, - ClippingRankSum=0.175, - BaseQRankSum=4.78, - MLEAF=[0.013], - MLEAC=[23], - MQ=59.75, - QD=14.65, - END=None, - DB=True, - HaplotypeScore=None, - MQRankSum=-0.192, - CCC=1740, - NCC=0, - DS=False + expected = [ + hl.Struct( + locus=hl.Locus(contig=20, position=17434581, reference_genome='GRCh37'), + alleles=['A', 'G'], + rsid='rs16999198', + qual=21384.8, + filters=set(), + info=hl.Struct( + NEGATIVE_TRAIN_SITE=False, + HWP=1.0, + AC=[2], + culprit='InbreedingCoeff', + MQ0=0, + ReadPosRankSum=0.534, + AN=200, + InbreedingCoeff=-0.0134, + AF=[0.013], + GQ_STDDEV=134.2, + FS=2.944, + DP=22586, + GQ_MEAN=83.43, + POSITIVE_TRAIN_SITE=True, + VQSLOD=4.77, + ClippingRankSum=0.175, + BaseQRankSum=4.78, + MLEAF=[0.013], + MLEAC=[23], + MQ=59.75, + QD=14.65, + END=None, + DB=True, + HaplotypeScore=None, + MQRankSum=-0.192, + CCC=1740, + NCC=0, + DS=False, + ), ) - )] + ] assert actual == expected diff --git a/hail/python/test/hail/extract_intervals/test_key_prefix.py b/hail/python/test/hail/extract_intervals/test_key_prefix.py index b42d38b5703..fa28f213a70 100644 --- a/hail/python/test/hail/extract_intervals/test_key_prefix.py +++ b/hail/python/test/hail/extract_intervals/test_key_prefix.py @@ -57,43 +57,45 @@ def test_ht_eq(ht, probe_locus): expr = ht.filter(ht.locus == probe_locus) assert expr.n_partitions() == 1 actual = expr.collect() - expected = [hl.Struct( - locus=hl.Locus(contig=20, position=17434581, 
reference_genome='GRCh37'), - alleles=['A', 'G'], - rsid='rs16999198', - qual=21384.8, - filters=set(), - info=hl.Struct( - NEGATIVE_TRAIN_SITE=False, - HWP=1.0, - AC=[2], - culprit='InbreedingCoeff', - MQ0=0, - ReadPosRankSum=0.534, - AN=200, - InbreedingCoeff=-0.0134, - AF=[0.013], - GQ_STDDEV=134.2, - FS=2.944, - DP=22586, - GQ_MEAN=83.43, - POSITIVE_TRAIN_SITE=True, - VQSLOD=4.77, - ClippingRankSum=0.175, - BaseQRankSum=4.78, - MLEAF=[0.013], - MLEAC=[23], - MQ=59.75, - QD=14.65, - END=None, - DB=True, - HaplotypeScore=None, - MQRankSum=-0.192, - CCC=1740, - NCC=0, - DS=False + expected = [ + hl.Struct( + locus=hl.Locus(contig=20, position=17434581, reference_genome='GRCh37'), + alleles=['A', 'G'], + rsid='rs16999198', + qual=21384.8, + filters=set(), + info=hl.Struct( + NEGATIVE_TRAIN_SITE=False, + HWP=1.0, + AC=[2], + culprit='InbreedingCoeff', + MQ0=0, + ReadPosRankSum=0.534, + AN=200, + InbreedingCoeff=-0.0134, + AF=[0.013], + GQ_STDDEV=134.2, + FS=2.944, + DP=22586, + GQ_MEAN=83.43, + POSITIVE_TRAIN_SITE=True, + VQSLOD=4.77, + ClippingRankSum=0.175, + BaseQRankSum=4.78, + MLEAF=[0.013], + MLEAC=[23], + MQ=59.75, + QD=14.65, + END=None, + DB=True, + HaplotypeScore=None, + MQRankSum=-0.192, + CCC=1740, + NCC=0, + DS=False, + ), ) - )] + ] assert actual == expected diff --git a/hail/python/test/hail/extract_intervals/test_locus_position.py b/hail/python/test/hail/extract_intervals/test_locus_position.py index 43ff881afe0..1e0cf7032e5 100644 --- a/hail/python/test/hail/extract_intervals/test_locus_position.py +++ b/hail/python/test/hail/extract_intervals/test_locus_position.py @@ -49,43 +49,45 @@ def test_ht_eq(ht): expr = ht.filter(ht.locus.position == 17434581) assert expr.n_partitions() == 1 actual = expr.collect() - expected = [hl.Struct( - locus=hl.Locus(contig=20, position=17434581, reference_genome='GRCh37'), - alleles=['A', 'G'], - rsid='rs16999198', - qual=21384.8, - filters=set(), - info=hl.Struct( - NEGATIVE_TRAIN_SITE=False, - HWP=1.0, - AC=[2], - culprit='InbreedingCoeff', - MQ0=0, - ReadPosRankSum=0.534, - AN=200, - InbreedingCoeff=-0.0134, - AF=[0.013], - GQ_STDDEV=134.2, - FS=2.944, - DP=22586, - GQ_MEAN=83.43, - POSITIVE_TRAIN_SITE=True, - VQSLOD=4.77, - ClippingRankSum=0.175, - BaseQRankSum=4.78, - MLEAF=[0.013], - MLEAC=[23], - MQ=59.75, - QD=14.65, - END=None, - DB=True, - HaplotypeScore=None, - MQRankSum=-0.192, - CCC=1740, - NCC=0, - DS=False + expected = [ + hl.Struct( + locus=hl.Locus(contig=20, position=17434581, reference_genome='GRCh37'), + alleles=['A', 'G'], + rsid='rs16999198', + qual=21384.8, + filters=set(), + info=hl.Struct( + NEGATIVE_TRAIN_SITE=False, + HWP=1.0, + AC=[2], + culprit='InbreedingCoeff', + MQ0=0, + ReadPosRankSum=0.534, + AN=200, + InbreedingCoeff=-0.0134, + AF=[0.013], + GQ_STDDEV=134.2, + FS=2.944, + DP=22586, + GQ_MEAN=83.43, + POSITIVE_TRAIN_SITE=True, + VQSLOD=4.77, + ClippingRankSum=0.175, + BaseQRankSum=4.78, + MLEAF=[0.013], + MLEAC=[23], + MQ=59.75, + QD=14.65, + END=None, + DB=True, + HaplotypeScore=None, + MQRankSum=-0.192, + CCC=1740, + NCC=0, + DS=False, + ), ) - )] + ] assert actual == expected diff --git a/hail/python/test/hail/fs/test_worker_driver_fs.py b/hail/python/test/hail/fs/test_worker_driver_fs.py index 43565225f04..568760e7462 100644 --- a/hail/python/test/hail/fs/test_worker_driver_fs.py +++ b/hail/python/test/hail/fs/test_worker_driver_fs.py @@ -2,11 +2,11 @@ import os import hail as hl -from hailtop.utils import secret_alnum_string -from hailtop.test_utils import skip_in_azure, run_if_azure from 
hailtop.aiocloud.aioazure import AzureAsyncFS +from hailtop.test_utils import run_if_azure, skip_in_azure +from hailtop.utils import secret_alnum_string -from ..helpers import fails_local_backend, hl_stop_for_test, hl_init_for_test, test_timeout, resource +from ..helpers import fails_local_backend, hl_init_for_test, hl_stop_for_test, resource, test_timeout @skip_in_azure @@ -21,7 +21,9 @@ def test_requester_pays_no_settings(): @skip_in_azure def test_requester_pays_write_no_settings(): - random_filename = 'gs://hail-test-requester-pays-fds32/test_requester_pays_on_worker_driver_' + secret_alnum_string(10) + random_filename = 'gs://hail-test-requester-pays-fds32/test_requester_pays_on_worker_driver_' + secret_alnum_string( + 10 + ) try: hl.utils.range_table(4, n_partitions=4).write(random_filename, overwrite=True) except Exception as exc: @@ -36,7 +38,9 @@ def test_requester_pays_write_no_settings(): def test_requester_pays_write_with_project(): hl_stop_for_test() hl_init_for_test(gcs_requester_pays_configuration='hail-vdc') - random_filename = 'gs://hail-test-requester-pays-fds32/test_requester_pays_on_worker_driver_' + secret_alnum_string(10) + random_filename = 'gs://hail-test-requester-pays-fds32/test_requester_pays_on_worker_driver_' + secret_alnum_string( + 10 + ) try: hl.utils.range_table(4, n_partitions=4).write(random_filename, overwrite=True) finally: @@ -48,15 +52,21 @@ def test_requester_pays_write_with_project(): def test_requester_pays_with_project(): hl_stop_for_test() hl_init_for_test(gcs_requester_pays_configuration='hail-vdc') - assert hl.import_table('gs://hail-test-requester-pays-fds32/hello', no_header=True).collect() == [hl.Struct(f0='hello')] + assert hl.import_table('gs://hail-test-requester-pays-fds32/hello', no_header=True).collect() == [ + hl.Struct(f0='hello') + ] hl_stop_for_test() hl_init_for_test(gcs_requester_pays_configuration=('hail-vdc', ['hail-test-requester-pays-fds32'])) - assert hl.import_table('gs://hail-test-requester-pays-fds32/hello', no_header=True).collect() == [hl.Struct(f0='hello')] + assert hl.import_table('gs://hail-test-requester-pays-fds32/hello', no_header=True).collect() == [ + hl.Struct(f0='hello') + ] hl_stop_for_test() hl_init_for_test(gcs_requester_pays_configuration=('hail-vdc', ['hail-test-requester-pays-fds32', 'other-bucket'])) - assert hl.import_table('gs://hail-test-requester-pays-fds32/hello', no_header=True).collect() == [hl.Struct(f0='hello')] + assert hl.import_table('gs://hail-test-requester-pays-fds32/hello', no_header=True).collect() == [ + hl.Struct(f0='hello') + ] hl_stop_for_test() hl_init_for_test(gcs_requester_pays_configuration=('hail-vdc', ['other-bucket'])) @@ -69,7 +79,9 @@ def test_requester_pays_with_project(): hl_stop_for_test() hl_init_for_test(gcs_requester_pays_configuration='hail-vdc') - assert hl.import_table('gs://hail-test-requester-pays-fds32/hello', no_header=True).collect() == [hl.Struct(f0='hello')] + assert hl.import_table('gs://hail-test-requester-pays-fds32/hello', no_header=True).collect() == [ + hl.Struct(f0='hello') + ] @skip_in_azure @@ -99,15 +111,24 @@ def test_requester_pays_with_project_more_than_one_partition(): hl_stop_for_test() hl_init_for_test(gcs_requester_pays_configuration='hail-vdc') - assert hl.import_table('gs://hail-test-requester-pays-fds32/zero-to-nine', no_header=True, min_partitions=8).collect() == expected_file_contents + assert ( + hl.import_table('gs://hail-test-requester-pays-fds32/zero-to-nine', no_header=True, min_partitions=8).collect() + == expected_file_contents + 
) hl_stop_for_test() hl_init_for_test(gcs_requester_pays_configuration=('hail-vdc', ['hail-test-requester-pays-fds32'])) - assert hl.import_table('gs://hail-test-requester-pays-fds32/zero-to-nine', no_header=True, min_partitions=8).collect() == expected_file_contents + assert ( + hl.import_table('gs://hail-test-requester-pays-fds32/zero-to-nine', no_header=True, min_partitions=8).collect() + == expected_file_contents + ) hl_stop_for_test() hl_init_for_test(gcs_requester_pays_configuration=('hail-vdc', ['hail-test-requester-pays-fds32', 'other-bucket'])) - assert hl.import_table('gs://hail-test-requester-pays-fds32/zero-to-nine', no_header=True, min_partitions=8).collect() == expected_file_contents + assert ( + hl.import_table('gs://hail-test-requester-pays-fds32/zero-to-nine', no_header=True, min_partitions=8).collect() + == expected_file_contents + ) hl_stop_for_test() hl_init_for_test(gcs_requester_pays_configuration=('hail-vdc', ['other-bucket'])) @@ -120,22 +141,26 @@ def test_requester_pays_with_project_more_than_one_partition(): hl_stop_for_test() hl_init_for_test(gcs_requester_pays_configuration='hail-vdc') - assert hl.import_table('gs://hail-test-requester-pays-fds32/zero-to-nine', no_header=True, min_partitions=8).collect() == expected_file_contents + assert ( + hl.import_table('gs://hail-test-requester-pays-fds32/zero-to-nine', no_header=True, min_partitions=8).collect() + == expected_file_contents + ) @run_if_azure @fails_local_backend def test_can_access_public_blobs(): - public_mt = 'hail-az://azureopendatastorage/gnomad/release/3.1/mt/genomes/gnomad.genomes.v3.1.hgdp_1kg_subset.mt' + public_mt = 'https://azureopendatastorage.blob.core.windows.net/gnomad/release/3.1/mt/genomes/gnomad.genomes.v3.1.hgdp_1kg_subset.mt' assert hl.hadoop_exists(public_mt) with hl.hadoop_open(f'{public_mt}/README.txt') as readme: assert len(readme.read()) > 0 mt = hl.read_matrix_table(public_mt) mt.describe() + @run_if_azure @fails_local_backend -def test_qob_can_use_sas_tokens(): +async def test_qob_can_use_sas_tokens(): vcf = resource('sample.vcf') account = AzureAsyncFS.parse_url(vcf).account diff --git a/hail/python/test/hail/genetics/test_call.py b/hail/python/test/hail/genetics/test_call.py index 6ab0d82caa3..8e5b9988298 100644 --- a/hail/python/test/hail/genetics/test_call.py +++ b/hail/python/test/hail/genetics/test_call.py @@ -1,7 +1,7 @@ import unittest -from hail.genetics import * -from ..helpers import * +import hail as hl +from hail.genetics import Call class Tests(unittest.TestCase): @@ -87,10 +87,10 @@ def test_zeroploid(self): self.assertFalse(c_zeroploid.is_het_non_ref()) self.assertFalse(c_zeroploid.is_het_ref()) - self.assertRaisesRegex(NotImplementedError, - "Calls with greater than 2 alleles are not supported.", - Call, - [1, 1, 1, 1]) + self.assertRaisesRegex( + NotImplementedError, "Calls with greater than 2 alleles are not supported.", Call, [1, 1, 1, 1] + ) + def test_call_rich_comparison(): val = Call([0, 0]) diff --git a/hail/python/test/hail/genetics/test_locus.py b/hail/python/test/hail/genetics/test_locus.py index 94488bd143e..2a5c2073020 100644 --- a/hail/python/test/hail/genetics/test_locus.py +++ b/hail/python/test/hail/genetics/test_locus.py @@ -1,5 +1,6 @@ -from hail.genetics import Locus import hail as hl +from hail.genetics import Locus + def test_constructor(): l = Locus.parse('1:100') diff --git a/hail/python/test/hail/genetics/test_pedigree.py b/hail/python/test/hail/genetics/test_pedigree.py index e42084b0ace..544852673cc 100644 --- 
a/hail/python/test/hail/genetics/test_pedigree.py +++ b/hail/python/test/hail/genetics/test_pedigree.py @@ -1,12 +1,12 @@ import unittest -from hail.genetics import Trio, Pedigree -from ..helpers import * +from hail.genetics import Pedigree, Trio from hail.utils.java import FatalError +from ..helpers import resource -class Tests(unittest.TestCase): +class Tests(unittest.TestCase): def test_trios(self): t1 = Trio('kid1', pat_id='dad1', is_female=True) t2 = Trio('kid1', pat_id='dad1', is_female=True) @@ -32,7 +32,6 @@ def test_trios(self): self.assertEqual(t5.is_complete(), False) self.assertEqual(t6.is_complete(), True) - def test_pedigree(self): ped = Pedigree.read(resource('sample.fam')) ped.write('/tmp/sample_out.fam') diff --git a/hail/python/test/hail/genetics/test_reference_genome.py b/hail/python/test/hail/genetics/test_reference_genome.py index 30a15b88b9d..88c2940bf9e 100644 --- a/hail/python/test/hail/genetics/test_reference_genome.py +++ b/hail/python/test/hail/genetics/test_reference_genome.py @@ -1,11 +1,13 @@ -import pytest from random import randint +import pytest + import hail as hl -from hail.genetics import * -from ..helpers import * +from hail.genetics import ReferenceGenome from hail.utils import FatalError +from ..helpers import qobtest, resource + @qobtest def test_reference_genome(): @@ -39,15 +41,20 @@ def test_reference_genome(): with hl.TemporaryFilename() as filename: gr2.write(filename) + @qobtest def test_reference_genome_sequence(): gr3 = ReferenceGenome.read(resource("fake_ref_genome.json")) assert gr3.name == "my_reference_genome" assert not gr3.has_sequence() - gr4 = ReferenceGenome.from_fasta_file("test_rg", resource("fake_reference.fasta"), - resource("fake_reference.fasta.fai"), - mt_contigs=["b", "c"], x_contigs=["a"]) + gr4 = ReferenceGenome.from_fasta_file( + "test_rg", + resource("fake_reference.fasta"), + resource("fake_reference.fasta.fai"), + mt_contigs=["b", "c"], + x_contigs=["a"], + ) assert gr4.has_sequence() assert gr4._sequence_files == (resource("fake_reference.fasta"), resource("fake_reference.fasta.fai")) assert gr4.x_contigs == ["a"] @@ -61,8 +68,7 @@ def test_reference_genome_sequence(): gr4.remove_sequence() assert not gr4.has_sequence() - gr4.add_sequence(resource("fake_reference.fasta"), - resource("fake_reference.fasta.fai")) + gr4.add_sequence(resource("fake_reference.fasta"), resource("fake_reference.fasta.fai")) assert gr4.has_sequence() assert gr4._sequence_files == (resource("fake_reference.fasta"), resource("fake_reference.fasta.fai")) @@ -99,13 +105,15 @@ def test_reference_genome_liftover(): {'l37': hl.locus('20', 278691, 'GRCh37'), 'l38': hl.locus('chr20', 298047, 'GRCh38')}, {'l37': hl.locus('20', 37007586, 'GRCh37'), 'l38': hl.locus('chr12', 32563117, 'GRCh38')}, {'l37': hl.locus('20', 62965520, 'GRCh37'), 'l38': hl.locus('chr20', 64334167, 'GRCh38')}, - {'l37': hl.locus('20', 62965521, 'GRCh37'), 'l38': null_locus} + {'l37': hl.locus('20', 62965521, 'GRCh37'), 'l38': null_locus}, ] schema = hl.tstruct(l37=hl.tlocus(grch37), l38=hl.tlocus(grch38)) t = hl.Table.parallelize(rows, schema) - assert t.all(hl.if_else(hl.is_defined(t.l38), - hl.liftover(t.l37, 'GRCh38') == t.l38, - hl.is_missing(hl.liftover(t.l37, 'GRCh38')))) + assert t.all( + hl.if_else( + hl.is_defined(t.l38), hl.liftover(t.l37, 'GRCh38') == t.l38, hl.is_missing(hl.liftover(t.l37, 'GRCh38')) + ) + ) t = t.filter(hl.is_defined(t.l38)) assert t.count() == 6 @@ -117,8 +125,10 @@ def test_reference_genome_liftover(): null_locus_interval = 
hl.missing(hl.tinterval(hl.tlocus('GRCh38'))) rows = [ {'i37': hl.locus_interval('20', 1, 60000, True, False, 'GRCh37'), 'i38': null_locus_interval}, - {'i37': hl.locus_interval('20', 60001, 82456, True, True, 'GRCh37'), - 'i38': hl.locus_interval('chr20', 79360, 101815, True, True, 'GRCh38')} + { + 'i37': hl.locus_interval('20', 60001, 82456, True, True, 'GRCh37'), + 'i38': hl.locus_interval('chr20', 79360, 101815, True, True, 'GRCh38'), + }, ] schema = hl.tstruct(i37=hl.tinterval(hl.tlocus(grch37)), i38=hl.tinterval(hl.tlocus(grch38))) t = hl.Table.parallelize(rows, schema) @@ -138,13 +148,20 @@ def test_liftover_strand(): expected = hl.Struct(result=hl.Locus('chr20', 79360, 'GRCh38'), is_negative_strand=False) assert actual == expected - actual = hl.eval(hl.liftover(hl.locus_interval('20', 37007582, 37007586, True, True, 'GRCh37'), - 'GRCh38', include_strand=True)) - expected = hl.Struct(result=hl.Interval(hl.Locus('chr12', 32563117, 'GRCh38'), - hl.Locus('chr12', 32563121, 'GRCh38'), - includes_start=True, - includes_end=True), - is_negative_strand=True) + actual = hl.eval( + hl.liftover( + hl.locus_interval('20', 37007582, 37007586, True, True, 'GRCh37'), 'GRCh38', include_strand=True + ) + ) + expected = hl.Struct( + result=hl.Interval( + hl.Locus('chr12', 32563117, 'GRCh38'), + hl.Locus('chr12', 32563121, 'GRCh38'), + includes_start=True, + includes_end=True, + ), + is_negative_strand=True, + ) assert actual == expected with pytest.raises(FatalError): @@ -170,7 +187,6 @@ def assert_rg_loaded_correctly(name): assert rg.mt_contigs == ["MT"] assert rg.par == [hl.Interval(start=hl.Locus("X", 2, name), end=hl.Locus("X", 4, name))] - assert hl.read_table(resource('custom_references.t')).count() == 14 assert_rg_loaded_correctly('test_rg_0') assert_rg_loaded_correctly('test_rg_1') @@ -189,7 +205,7 @@ def assert_rg_loaded_correctly(name): def test_custom_reference_read_write(): hl.ReferenceGenome("dk", ['hello'], {"hello": 123}) ht = hl.utils.range_table(5) - ht = ht.key_by(locus=hl.locus('hello', ht.idx+1, 'dk')) + ht = ht.key_by(locus=hl.locus('hello', ht.idx + 1, 'dk')) with hl.TemporaryDirectory(ensure_exists=False) as foo: ht.write(foo) expected = ht @@ -208,6 +224,7 @@ def test_locus_from_global_position(): assert python == scala + def test_locus_from_global_position_negative_pos(): with pytest.raises(ValueError): hl.get_reference('GRCh37').locus_from_global_position(-1) @@ -215,4 +232,4 @@ def test_locus_from_global_position_negative_pos(): def test_locus_from_global_position_too_long(): with pytest.raises(ValueError): - hl.get_reference('GRCh37').locus_from_global_position(2**64-1) + hl.get_reference('GRCh37').locus_from_global_position(2**64 - 1) diff --git a/hail/python/test/hail/ggplot/test_ggplot.py b/hail/python/test/hail/ggplot/test_ggplot.py index b2349c7455f..41800daebe8 100644 --- a/hail/python/test/hail/ggplot/test_ggplot.py +++ b/hail/python/test/hail/ggplot/test_ggplot.py @@ -1,63 +1,107 @@ -import hail as hl -from hail.ggplot import * -import numpy as np import math +import numpy as np + +import hail as hl +from hail.ggplot import ( + aes, + coord_cartesian, + facet_wrap, + geom_area, + geom_bar, + geom_col, + geom_histogram, + geom_hline, + geom_line, + geom_point, + geom_ribbon, + geom_text, + ggplot, + ggtitle, + scale_color_manual, + vars, + xlab, + ylab, +) + def test_geom_point_line_text_col_area(): ht = hl.utils.range_table(20) ht = ht.annotate(double=ht.idx * 2) ht = ht.annotate(triple=ht.idx * 3) - fig = (ggplot(ht, aes(x=ht.idx)) + - aes(y=ht.double) + - 
geom_point() + - geom_line(aes(y=ht.triple)) + - geom_text(aes(label=hl.str(ht.idx))) + - geom_col(aes(y=ht.triple + ht.double)) + - geom_area(aes(y=ht.triple - ht.double)) + - coord_cartesian((0, 100), (0, 80)) + - xlab("my_x") + - ylab("my_y") + - ggtitle("Title") - ) + fig = ( + ggplot(ht, aes(x=ht.idx)) + + aes(y=ht.double) + + geom_point() + + geom_line(aes(y=ht.triple)) + + geom_text(aes(label=hl.str(ht.idx))) + + geom_col(aes(y=ht.triple + ht.double)) + + geom_area(aes(y=ht.triple - ht.double)) + + coord_cartesian((0, 100), (0, 80)) + + xlab("my_x") + + ylab("my_y") + + ggtitle("Title") + ) fig.to_plotly() def test_manhattan_plot(): mt = hl.balding_nichols_model(3, 10, 100) ht = mt.rows() - ht = ht.annotate(pval=.02) + ht = ht.annotate(pval=0.02) fig = ggplot(ht, aes(x=ht.locus, y=-hl.log10(ht.pval))) + geom_point() + geom_hline(yintercept=-math.log10(5e-8)) pfig = fig.to_plotly() - expected_ticks = ('1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', 'X', 'Y') + expected_ticks = ( + '1', + '2', + '3', + '4', + '5', + '6', + '7', + '8', + '9', + '10', + '11', + '12', + '13', + '14', + '15', + '16', + '17', + '18', + '19', + '20', + '21', + '22', + 'X', + 'Y', + ) assert pfig.layout.xaxis.ticktext == expected_ticks + def test_histogram(): num_rows = 101 num_groups = 5 ht = hl.utils.range_table(num_rows) ht = ht.annotate(mod_3=hl.str(ht.idx % num_groups)) for position in ["stack", "dodge", "identity"]: - fig = (ggplot(ht, aes(x=ht.idx)) + - geom_histogram(aes(fill=ht.mod_3), alpha=0.5, position=position, bins=10) - ) + fig = ggplot(ht, aes(x=ht.idx)) + geom_histogram(aes(fill=ht.mod_3), alpha=0.5, position=position, bins=10) pfig = fig.to_plotly() assert len(pfig.data) == num_groups for idx, bar in enumerate(pfig.data): if position in {"stack", "identity"}: - assert (bar.x == [float(e) for e in range(num_groups, num_rows-1, num_groups*2)]).all() + assert (bar.x == [float(e) for e in range(num_groups, num_rows - 1, num_groups * 2)]).all() else: dist_between_bars_in_one_group = (num_rows - 1) / (num_groups * 2) - single_bar_width = (dist_between_bars_in_one_group / num_groups) + single_bar_width = dist_between_bars_in_one_group / num_groups first_bar_start = single_bar_width / 2 + idx * single_bar_width assert (bar.x == np.arange(first_bar_start, num_rows - 1, dist_between_bars_in_one_group)).all() def test_separate_traces_per_group(): ht = hl.utils.range_table(30) - fig = (ggplot(ht, aes(x=ht.idx)) + - geom_bar(aes(fill=hl.str(ht.idx))) - ) + fig = ggplot(ht, aes(x=ht.idx)) + geom_bar(aes(fill=hl.str(ht.idx))) assert len(fig.to_plotly().data) == 30 @@ -82,7 +126,11 @@ def test_scale_color_manual(): num_rows = 4 colors = set(["red", "blue"]) ht = hl.utils.range_table(num_rows) - fig = ggplot(ht, aes(x=ht.idx, y=ht.idx, color=hl.str(ht.idx % 2))) + geom_point() + scale_color_manual(values=list(colors)) + fig = ( + ggplot(ht, aes(x=ht.idx, y=ht.idx, color=hl.str(ht.idx % 2))) + + geom_point() + + scale_color_manual(values=list(colors)) + ) pfig = fig.to_plotly() assert set([scatter.marker.color for scatter in pfig.data]) == colors @@ -97,7 +145,7 @@ def test_weighted_bar(): result = [8, 9, 5, 6] for idx, y in enumerate(fig.to_plotly().data[0].y): - assert(y == result[idx]) + assert y == result[idx] def test_faceting(): @@ -105,7 +153,7 @@ def test_faceting(): ht = ht.annotate(x=hl.if_else(ht.idx < 4, "less", "more")) pfig = (ggplot(ht) + geom_point(aes(x=ht.idx, y=ht.idx)) + facet_wrap(vars(ht.x))).to_plotly() - 
assert(len(pfig.layout.annotations) == 2) + assert len(pfig.layout.annotations) == 2 def test_matrix_tables(): @@ -113,12 +161,12 @@ def test_matrix_tables(): mt = mt.annotate_rows(row_doubled=mt.row_idx * 2) mt = mt.annotate_entries(entry_idx=mt.row_idx + mt.col_idx) for field, expected in [ - (mt.row_doubled, [(0, 0), (1, 2), (2, 4)]), - (mt.entry_idx, [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (2, 4)]) + (mt.row_doubled, [(0, 0), (1, 2), (2, 4)]), + (mt.entry_idx, [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (2, 4)]), ]: data = (ggplot(mt, aes(x=mt.row_idx, y=field)) + geom_point()).to_plotly().data[0] assert len(data.x) == len(expected) assert len(data.y) == len(expected) for idx, (x, y) in enumerate(zip(data.x, data.y)): - assert(x == expected[idx][0]) - assert(y == expected[idx][1]) + assert x == expected[idx][0] + assert y == expected[idx][1] diff --git a/hail/python/test/hail/helpers.py b/hail/python/test/hail/helpers.py index d7a1cc93cf9..edfc717293c 100644 --- a/hail/python/test/hail/helpers.py +++ b/hail/python/test/hail/helpers.py @@ -1,14 +1,14 @@ -from typing import Callable, TypeVar -from typing_extensions import ParamSpec import os -from timeit import default_timer as timer import unittest +from timeit import default_timer as timer +from typing import Callable, TypeVar + import pytest -from hailtop.hail_decorator import decorator +from typing_extensions import ParamSpec -from hail.utils.java import choose_backend import hail as hl - +from hail.utils.java import choose_backend +from hailtop.hail_decorator import decorator GCS_REQUESTER_PAYS_PROJECT = os.environ.get('GCS_REQUESTER_PAYS_PROJECT') HAIL_QUERY_N_CORES = os.environ.get('HAIL_QUERY_N_CORES', '2') @@ -26,9 +26,10 @@ def hl_stop_for_test(): hl.stop() -_test_dir = os.environ.get('HAIL_TEST_RESOURCES_DIR', - os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(hl.__file__))), - 'src/test/resources')) +_test_dir = os.environ.get( + 'HAIL_TEST_RESOURCES_DIR', + os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(hl.__file__))), 'src/test/resources'), +) _doctest_dir = os.environ.get('HAIL_DOCTEST_DATA_DIR', 'hail/docs/data') @@ -68,6 +69,7 @@ def get_dataset(): _dataset = hl.read_matrix_table(resource('split-multi-sample.vcf.mt')).select_globals() return _dataset + def assert_time(f, max_duration): start = timer() x = f() @@ -76,6 +78,7 @@ def assert_time(f, max_duration): print(f'took {end - start:.3f}') return x + def create_all_values(): return hl.struct( f32=hl.float32(3.14), @@ -89,9 +92,7 @@ def create_all_values(): md=hl.missing(hl.tdict(hl.tint32, hl.tstr)), h38=hl.locus('chr22', 33878978, 'GRCh38'), ml=hl.missing(hl.tlocus('GRCh37')), - i=hl.interval( - hl.locus('1', 999), - hl.locus('1', 1001)), + i=hl.interval(hl.locus('1', 999), hl.locus('1', 1001)), c=hl.call(0, 1), mc=hl.missing(hl.tcall), t=hl.tuple([hl.call(1, 2, phased=True), 'foo', hl.missing(hl.tstr)]), @@ -99,24 +100,31 @@ def create_all_values(): nd=hl.nd.arange(0, 10).reshape((2, 5)), ) + def prefix_struct(s, prefix): return hl.struct(**{prefix + k: s[k] for k in s}) + def create_all_values_table(): all_values = create_all_values() - return (hl.utils.range_table(5, n_partitions=3) - .annotate_globals(**prefix_struct(all_values, 'global_')) - .annotate(**all_values) - .cache()) + return ( + hl.utils.range_table(5, n_partitions=3) + .annotate_globals(**prefix_struct(all_values, 'global_')) + .annotate(**all_values) + .cache() + ) + def create_all_values_matrix_table(): all_values = 
create_all_values() - return (hl.utils.range_matrix_table(3, 2, n_partitions=2) - .annotate_globals(**prefix_struct(all_values, 'global_')) - .annotate_rows(**prefix_struct(all_values, 'row_')) - .annotate_cols(**prefix_struct(all_values, 'col_')) - .annotate_entries(**prefix_struct(all_values, 'entry_')) - .cache()) + return ( + hl.utils.range_matrix_table(3, 2, n_partitions=2) + .annotate_globals(**prefix_struct(all_values, 'global_')) + .annotate_rows(**prefix_struct(all_values, 'row_')) + .annotate_cols(**prefix_struct(all_values, 'col_')) + .annotate_entries(**prefix_struct(all_values, 'entry_')) + .cache() + ) def create_all_values_datasets(): @@ -129,6 +137,7 @@ def create_all_values_datasets(): def skip_unless_spark_backend(reason='requires Spark'): from hail.backend.spark_backend import SparkBackend + @decorator def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: if isinstance(hl.utils.java.Env.backend(), SparkBackend): @@ -141,6 +150,7 @@ def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: def skip_when_local_backend(reason='skipping for Local Backend'): from hail.backend.local_backend import LocalBackend + @decorator def wrapper(func, *args, **kwargs): if isinstance(hl.utils.java.Env.backend(), LocalBackend): @@ -153,6 +163,7 @@ def wrapper(func, *args, **kwargs): def skip_when_service_backend(reason='skipping for Service Backend'): from hail.backend.service_backend import ServiceBackend + @decorator def wrapper(func, *args, **kwargs): if isinstance(hl.utils.java.Env.backend(), ServiceBackend): @@ -165,6 +176,7 @@ def wrapper(func, *args, **kwargs): def skip_when_service_backend_in_azure(reason='skipping for Service Backend in Azure'): from hail.backend.service_backend import ServiceBackend + @decorator def wrapper(func, *args, **kwargs): if isinstance(hl.utils.java.Env.backend(), ServiceBackend) and os.environ.get('HAIL_CLOUD') == 'azure': @@ -177,6 +189,7 @@ def wrapper(func, *args, **kwargs): def skip_unless_service_backend(reason='only relevant to service backend', clouds=None): from hail.backend.service_backend import ServiceBackend + @decorator def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: if not isinstance(hl.utils.java.Env.backend(), ServiceBackend): @@ -190,21 +203,18 @@ def wrapper(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: fails_local_backend = pytest.mark.xfail( - choose_backend() == 'local', - reason="doesn't yet work on local backend", - strict=True) + choose_backend() == 'local', reason="doesn't yet work on local backend", strict=True +) fails_service_backend = pytest.mark.xfail( - choose_backend() == 'batch', - reason="doesn't yet work on service backend", - strict=True) + choose_backend() == 'batch', reason="doesn't yet work on service backend", strict=True +) fails_spark_backend = pytest.mark.xfail( - choose_backend() == 'spark', - reason="doesn't yet work on spark backend", - strict=True) + choose_backend() == 'spark', reason="doesn't yet work on spark backend", strict=True +) qobtest = pytest.mark.qobtest @@ -241,6 +251,7 @@ def wrapper(func, *args, **kwargs): return func(*args, **kwargs) finally: hl._set_flags(**prev_flags) + return wrapper diff --git a/hail/python/test/hail/linalg/test_linalg.py b/hail/python/test/hail/linalg/test_linalg.py index cfecbc9382a..9abf3958710 100644 --- a/hail/python/test/hail/linalg/test_linalg.py +++ b/hail/python/test/hail/linalg/test_linalg.py @@ -1,14 +1,15 @@ -import pytest import math -import numpy as np -import hail as hl - 
from contextlib import contextmanager +import numpy as np +import pytest + +import hail as hl from hail.expr.expressions import ExpressionException from hail.linalg import BlockMatrix from hail.utils import FatalError, HailUserError -from ..helpers import * + +from ..helpers import fails_local_backend, fails_service_backend, fails_spark_backend, test_timeout def sparsify_numpy(np_mat, block_size, blocks_to_sparsify): @@ -37,9 +38,7 @@ def sparsify_numpy(np_mat, block_size, blocks_to_sparsify): def _np_matrix(a): - return hl.eval(a.to_ndarray()) \ - if isinstance(a, BlockMatrix) \ - else np.array(a) + return hl.eval(a.to_ndarray()) if isinstance(a, BlockMatrix) else np.array(a) def _assert_eq(a, b): @@ -53,14 +52,15 @@ def _assert_close(a, b): def _assert_rectangles_eq(expected, rect_path, export_rects, binary=False): - for (i, r) in enumerate(export_rects): + for i, r in enumerate(export_rects): piece_path = rect_path + '/rect-' + str(i) + '_' + '-'.join(map(str, r)) with hl.current_backend().fs.open(piece_path, mode='rb' if binary else 'r') as file: - expected_rect = expected[r[0]:r[1], r[2]:r[3]] - actual_rect = np.loadtxt(file, ndmin=2) if not binary else np.reshape( - np.frombuffer(file.read()), - (r[1] - r[0], r[3] - r[2]) + expected_rect = expected[r[0] : r[1], r[2] : r[3]] + actual_rect = ( + np.loadtxt(file, ndmin=2) + if not binary + else np.reshape(np.frombuffer(file.read()), (r[1] - r[0], r[3] - r[2])) ) _assert_eq(expected_rect, actual_rect) @@ -86,15 +86,46 @@ def test_from_entry_expr_empty_parts(): @pytest.mark.parametrize( 'mean_impute, center, normalize, mk_expected', - [ (False, False, False, lambda a: a, ) - , (False, False, True, lambda a: a / np.sqrt(5), ) - , (False, True, False, lambda a: a - 1.0, ) - , (False, True, True, lambda a: (a - 1.0) / np.sqrt(2)) - , (True, False, False, lambda a: a, ) - , (True, False, True, lambda a: a / np.sqrt(5), ) - , (True, True, False, lambda a: a - 1.0, ) - , (True, True, True, lambda a: (a - 1.0) / np.sqrt(2)) - ] + [ + ( + False, + False, + False, + lambda a: a, + ), + ( + False, + False, + True, + lambda a: a / np.sqrt(5), + ), + ( + False, + True, + False, + lambda a: a - 1.0, + ), + (False, True, True, lambda a: (a - 1.0) / np.sqrt(2)), + ( + True, + False, + False, + lambda a: a, + ), + ( + True, + False, + True, + lambda a: a / np.sqrt(5), + ), + ( + True, + True, + False, + lambda a: a - 1.0, + ), + (True, True, True, lambda a: (a - 1.0) / np.sqrt(2)), + ], ) def test_from_entry_expr_options(mean_impute, center, normalize, mk_expected): a = np.array([0.0, 1.0, 2.0]) @@ -103,7 +134,7 @@ def test_from_entry_expr_options(mean_impute, center, normalize, mk_expected): mt = mt.rename({'row_idx': 'v', 'col_idx': 's'}) xs = hl.array([0.0, hl.missing(hl.tfloat), 2.0]) if mean_impute else hl.literal(a) - mt = mt.annotate_entries(x = xs[mt.s]) + mt = mt.annotate_entries(x=xs[mt.s]) expected = mk_expected(a) @@ -116,7 +147,7 @@ def test_from_entry_expr_raises_when_values_missing(): mt = hl.utils.range_matrix_table(1, 3) mt = mt.rename({'row_idx': 'v', 'col_idx': 's'}) actual = hl.array([0.0, hl.missing(hl.tfloat), 2.0]) - mt = mt.annotate_entries(x = actual[mt.s]) + mt = mt.annotate_entries(x=actual[mt.s]) with pytest.raises(Exception, match='Cannot construct an ndarray with missing values'): BlockMatrix.from_entry_expr(mt.x) @@ -199,19 +230,13 @@ def test_block_matrix_from_numpy_transpose(): def test_block_matrix_to_file_transpose(): data = np.random.rand(10, 11) with block_matrix_to_tmp_file(data, transpose=True) as f: - 
_assert_eq(data.T, np - .frombuffer(hl.current_backend().fs.open(f, mode='rb').read()) - .reshape((11, 10)) - ) + _assert_eq(data.T, np.frombuffer(hl.current_backend().fs.open(f, mode='rb').read()).reshape((11, 10))) def test_numpy_read_block_matrix_to_file(): data = np.random.rand(10, 11) with block_matrix_to_tmp_file(data) as f: - _assert_eq(data, np - .frombuffer(hl.current_backend().fs.open(f, mode='rb').read()) - .reshape((10, 11)) - ) + _assert_eq(data, np.frombuffer(hl.current_backend().fs.open(f, mode='rb').read()).reshape((10, 11))) def test_block_matrix_from_numpy_bytes(): @@ -244,19 +269,17 @@ def test_numpy_round_trip_force_blocking(): @fails_service_backend() @fails_local_backend() @pytest.mark.parametrize( - 'n_partitions,block_size', - [ (n_partitions, block_size) - for n_partitions in [1, 2, 3] - for block_size in [1, 2, 5] - ] + 'n_partitions,block_size', [(n_partitions, block_size) for n_partitions in [1, 2, 3] for block_size in [1, 2, 5]] ) def test_to_table(n_partitions, block_size): schema = hl.tstruct(row_idx=hl.tint64, entries=hl.tarray(hl.tfloat64)) - rows = [{'row_idx': 0, 'entries': [0.0, 1.0]}, - {'row_idx': 1, 'entries': [2.0, 3.0]}, - {'row_idx': 2, 'entries': [4.0, 5.0]}, - {'row_idx': 3, 'entries': [6.0, 7.0]}, - {'row_idx': 4, 'entries': [8.0, 9.0]}] + rows = [ + {'row_idx': 0, 'entries': [0.0, 1.0]}, + {'row_idx': 1, 'entries': [2.0, 3.0]}, + {'row_idx': 2, 'entries': [4.0, 5.0]}, + {'row_idx': 3, 'entries': [6.0, 7.0]}, + {'row_idx': 4, 'entries': [8.0, 9.0]}, + ] expected = hl.Table.parallelize(rows, schema, 'row_idx', n_partitions) bm = BlockMatrix._create(5, 2, [float(i) for i in range(10)], block_size) @@ -272,7 +295,10 @@ def test_to_table_maximum_cache_memory_in_bytes_limits(): with pytest.raises(Exception) as exc_info: bm.to_table_row_major(2, maximum_cache_memory_in_bytes=15)._force_count() - assert 'BlockMatrixCachedPartFile must be able to hold at least one row of every block in memory' in exc_info.value.args[0] + assert ( + 'BlockMatrixCachedPartFile must be able to hold at least one row of every block in memory' + in exc_info.value.args[0] + ) bm = BlockMatrix._create(5, 2, [float(i) for i in range(10)], 2) bm.to_table_row_major(2, maximum_cache_memory_in_bytes=16)._force_count() @@ -311,254 +337,240 @@ def block_matrix_bindings(): nrow = np.array([[7.0, 8.0, 9.0]]) nsquare = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) - yield { - 'nx': np.array([[2.0]]), - 'nc': nc, - 'nr': nr, - 'nm': nm, - 'nrow': nrow, + 'nx': np.array([[2.0]]), + 'nc': nc, + 'nr': nr, + 'nm': nm, + 'nrow': nrow, 'nsquare': nsquare, - - 'e': 2.0, - + 'e': 2.0, # BlockMatrixMap requires very simple IRs on the SparkBackend. If I use # `from_ndarray` here, it generates an `NDArrayRef` expression that it can't handle. # Will be fixed by improving FoldConstants handling of ndarrays or fully lowering BlockMatrix. 
- 'x': BlockMatrix._create(1, 1, [2.0], block_size=8), - 'c': BlockMatrix.from_ndarray(hl.literal(nc), block_size=8), - 'r': BlockMatrix.from_ndarray(hl.literal(nr), block_size=8), - 'm': BlockMatrix.from_ndarray(hl.literal(nm), block_size=8), - 'row': BlockMatrix.from_ndarray(hl.nd.array(nrow), block_size=8), - 'square': BlockMatrix.from_ndarray(hl.nd.array(nsquare), block_size=8) + 'x': BlockMatrix._create(1, 1, [2.0], block_size=8), + 'c': BlockMatrix.from_ndarray(hl.literal(nc), block_size=8), + 'r': BlockMatrix.from_ndarray(hl.literal(nr), block_size=8), + 'm': BlockMatrix.from_ndarray(hl.literal(nm), block_size=8), + 'row': BlockMatrix.from_ndarray(hl.nd.array(nrow), block_size=8), + 'square': BlockMatrix.from_ndarray(hl.nd.array(nsquare), block_size=8), } + @pytest.mark.parametrize( 'x, y', - [ # addition - ('+m', '0 + m') - , ('x + e', 'nx + e') - , ('c + e', 'nc + e') - , ('r + e', 'nr + e') - , ('m + e', 'nm + e') - , ('x + e', 'e + x') - , ('c + e', 'e + c') - , ('r + e', 'e + r') - , ('m + e', 'e + m') - , ('x + x', '2 * x') - , ('c + c', '2 * c') - , ('r + r', '2 * r') - , ('m + m', '2 * m') - , ('x + c', 'np.array([[3.0], [4.0]])') - , ('x + r', 'np.array([[3.0, 4.0, 5.0]])') - , ('x + m', 'np.array([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])') - , ('c + m', 'np.array([[2.0, 3.0, 4.0], [6.0, 7.0, 8.0]])') - , ('r + m', 'np.array([[2.0, 4.0, 6.0], [5.0, 7.0, 9.0]])') - , ('x + c', 'c + x') - , ('x + r', 'r + x') - , ('x + m', 'm + x') - , ('c + m', 'm + c') - , ('r + m', 'm + r') - , ('x + nx', 'x + x') - , ('x + nc', 'x + c') - , ('x + nr', 'x + r') - , ('x + nm', 'x + m') - , ('c + nx', 'c + x') - , ('c + nc', 'c + c') - , ('c + nm', 'c + m') - , ('r + nx', 'r + x') - , ('r + nr', 'r + r') - , ('r + nm', 'r + m') - , ('m + nx', 'm + x') - , ('m + nc', 'm + c') - , ('m + nr', 'm + r') - , ('m + nm', 'm + m') - - # subtraction - , ('-m', '0 - m') - , ('x - e', 'nx - e') - , ('c - e', 'nc - e') - , ('r - e', 'nr - e') - , ('m - e', 'nm - e') - , ('x - e', '-(e - x)') - , ('c - e', '-(e - c)') - , ('r - e', '-(e - r)') - , ('m - e', '-(e - m)') - , ('x - x', 'np.zeros((1, 1))') - , ('c - c', 'np.zeros((2, 1))') - , ('r - r', 'np.zeros((1, 3))') - , ('m - m', 'np.zeros((2, 3))') - , ('x - c', 'np.array([[1.0], [0.0]])') - , ('x - r', 'np.array([[1.0, 0.0, -1.0]])') - , ('x - m', 'np.array([[1.0, 0.0, -1.0], [-2.0, -3.0, -4.0]])') - , ('c - m', 'np.array([[0.0, -1.0, -2.0], [-2.0, -3.0, -4.0]])') - , ('r - m', 'np.array([[0.0, 0.0, 0.0], [-3.0, -3.0, -3.0]])') - , ('x - c', '-(c - x)') - , ('x - r', '-(r - x)') - , ('x - m', '-(m - x)') - , ('c - m', '-(m - c)') - , ('r - m', '-(m - r)') - , ('x - nx', 'x - x') - , ('x - nc', 'x - c') - , ('x - nr', 'x - r') - , ('x - nm', 'x - m') - , ('c - nx', 'c - x') - , ('c - nc', 'c - c') - , ('c - nm', 'c - m') - , ('r - nx', 'r - x') - , ('r - nr', 'r - r') - , ('r - nm', 'r - m') - , ('m - nx', 'm - x') - , ('m - nc', 'm - c') - , ('m - nr', 'm - r') - , ('m - nm', 'm - m') - - # multiplication - , ('x * e', 'nx * e') - , ('c * e', 'nc * e') - , ('r * e', 'nr * e') - , ('m * e', 'nm * e') - , ('x * e', 'e * x') - , ('c * e', 'e * c') - , ('r * e', 'e * r') - , ('m * e', 'e * m') - , ('x * x', 'x ** 2') - , ('c * c', 'c ** 2') - , ('r * r', 'r ** 2') - , ('m * m', 'm ** 2') - , ('x * c', 'np.array([[2.0], [4.0]])') - , ('x * r', 'np.array([[2.0, 4.0, 6.0]])') - , ('x * m', 'np.array([[2.0, 4.0, 6.0], [8.0, 10.0, 12.0]])') - , ('c * m', 'np.array([[1.0, 2.0, 3.0], [8.0, 10.0, 12.0]])') - , ('r * m', 'np.array([[1.0, 4.0, 9.0], [4.0, 10.0, 
18.0]])') - , ('x * c', 'c * x') - , ('x * r', 'r * x') - , ('x * m', 'm * x') - , ('c * m', 'm * c') - , ('r * m', 'm * r') - , ('x * nx', 'x * x') - , ('x * nc', 'x * c') - , ('x * nr', 'x * r') - , ('x * nm', 'x * m') - , ('c * nx', 'c * x') - , ('c * nc', 'c * c') - , ('c * nm', 'c * m') - , ('r * nx', 'r * x') - , ('r * nr', 'r * r') - , ('r * nm', 'r * m') - , ('m * nx', 'm * x') - , ('m * nc', 'm * c') - , ('m * nr', 'm * r') - , ('m * nm', 'm * m') - - , ('m.T', 'nm.T') - , ('m.T', 'nm.T') - , ('row.T', 'nrow.T') - , ('m @ m.T', 'nm @ nm.T') - , ('m @ nm.T', 'nm @ nm.T') - , ('row @ row.T', 'nrow @ nrow.T') - , ('row @ nrow.T', 'nrow @ nrow.T') - , ('m.T @ m', 'nm.T @ nm') - , ('m.T @ nm', 'nm.T @ nm') - , ('row.T @ row', 'nrow.T @ nrow') - , ('row.T @ nrow', 'nrow.T @ nrow') - ] + [ # addition + ('+m', '0 + m'), + ('x + e', 'nx + e'), + ('c + e', 'nc + e'), + ('r + e', 'nr + e'), + ('m + e', 'nm + e'), + ('x + e', 'e + x'), + ('c + e', 'e + c'), + ('r + e', 'e + r'), + ('m + e', 'e + m'), + ('x + x', '2 * x'), + ('c + c', '2 * c'), + ('r + r', '2 * r'), + ('m + m', '2 * m'), + ('x + c', 'np.array([[3.0], [4.0]])'), + ('x + r', 'np.array([[3.0, 4.0, 5.0]])'), + ('x + m', 'np.array([[3.0, 4.0, 5.0], [6.0, 7.0, 8.0]])'), + ('c + m', 'np.array([[2.0, 3.0, 4.0], [6.0, 7.0, 8.0]])'), + ('r + m', 'np.array([[2.0, 4.0, 6.0], [5.0, 7.0, 9.0]])'), + ('x + c', 'c + x'), + ('x + r', 'r + x'), + ('x + m', 'm + x'), + ('c + m', 'm + c'), + ('r + m', 'm + r'), + ('x + nx', 'x + x'), + ('x + nc', 'x + c'), + ('x + nr', 'x + r'), + ('x + nm', 'x + m'), + ('c + nx', 'c + x'), + ('c + nc', 'c + c'), + ('c + nm', 'c + m'), + ('r + nx', 'r + x'), + ('r + nr', 'r + r'), + ('r + nm', 'r + m'), + ('m + nx', 'm + x'), + ('m + nc', 'm + c'), + ('m + nr', 'm + r'), + ('m + nm', 'm + m'), + # subtraction + ('-m', '0 - m'), + ('x - e', 'nx - e'), + ('c - e', 'nc - e'), + ('r - e', 'nr - e'), + ('m - e', 'nm - e'), + ('x - e', '-(e - x)'), + ('c - e', '-(e - c)'), + ('r - e', '-(e - r)'), + ('m - e', '-(e - m)'), + ('x - x', 'np.zeros((1, 1))'), + ('c - c', 'np.zeros((2, 1))'), + ('r - r', 'np.zeros((1, 3))'), + ('m - m', 'np.zeros((2, 3))'), + ('x - c', 'np.array([[1.0], [0.0]])'), + ('x - r', 'np.array([[1.0, 0.0, -1.0]])'), + ('x - m', 'np.array([[1.0, 0.0, -1.0], [-2.0, -3.0, -4.0]])'), + ('c - m', 'np.array([[0.0, -1.0, -2.0], [-2.0, -3.0, -4.0]])'), + ('r - m', 'np.array([[0.0, 0.0, 0.0], [-3.0, -3.0, -3.0]])'), + ('x - c', '-(c - x)'), + ('x - r', '-(r - x)'), + ('x - m', '-(m - x)'), + ('c - m', '-(m - c)'), + ('r - m', '-(m - r)'), + ('x - nx', 'x - x'), + ('x - nc', 'x - c'), + ('x - nr', 'x - r'), + ('x - nm', 'x - m'), + ('c - nx', 'c - x'), + ('c - nc', 'c - c'), + ('c - nm', 'c - m'), + ('r - nx', 'r - x'), + ('r - nr', 'r - r'), + ('r - nm', 'r - m'), + ('m - nx', 'm - x'), + ('m - nc', 'm - c'), + ('m - nr', 'm - r'), + ('m - nm', 'm - m'), + # multiplication + ('x * e', 'nx * e'), + ('c * e', 'nc * e'), + ('r * e', 'nr * e'), + ('m * e', 'nm * e'), + ('x * e', 'e * x'), + ('c * e', 'e * c'), + ('r * e', 'e * r'), + ('m * e', 'e * m'), + ('x * x', 'x ** 2'), + ('c * c', 'c ** 2'), + ('r * r', 'r ** 2'), + ('m * m', 'm ** 2'), + ('x * c', 'np.array([[2.0], [4.0]])'), + ('x * r', 'np.array([[2.0, 4.0, 6.0]])'), + ('x * m', 'np.array([[2.0, 4.0, 6.0], [8.0, 10.0, 12.0]])'), + ('c * m', 'np.array([[1.0, 2.0, 3.0], [8.0, 10.0, 12.0]])'), + ('r * m', 'np.array([[1.0, 4.0, 9.0], [4.0, 10.0, 18.0]])'), + ('x * c', 'c * x'), + ('x * r', 'r * x'), + ('x * m', 'm * x'), + ('c * m', 'm * c'), + ('r * m', 
'm * r'), + ('x * nx', 'x * x'), + ('x * nc', 'x * c'), + ('x * nr', 'x * r'), + ('x * nm', 'x * m'), + ('c * nx', 'c * x'), + ('c * nc', 'c * c'), + ('c * nm', 'c * m'), + ('r * nx', 'r * x'), + ('r * nr', 'r * r'), + ('r * nm', 'r * m'), + ('m * nx', 'm * x'), + ('m * nc', 'm * c'), + ('m * nr', 'm * r'), + ('m * nm', 'm * m'), + ('m.T', 'nm.T'), + ('m.T', 'nm.T'), + ('row.T', 'nrow.T'), + ('m @ m.T', 'nm @ nm.T'), + ('m @ nm.T', 'nm @ nm.T'), + ('row @ row.T', 'nrow @ nrow.T'), + ('row @ nrow.T', 'nrow @ nrow.T'), + ('m.T @ m', 'nm.T @ nm'), + ('m.T @ nm', 'nm.T @ nm'), + ('row.T @ row', 'nrow.T @ nrow'), + ('row.T @ nrow', 'nrow.T @ nrow'), + ], ) def test_block_matrix_elementwise_arithmetic(block_matrix_bindings, x, y): lhs = eval(x, block_matrix_bindings) - rhs = eval(y, { 'np': np }, block_matrix_bindings) + rhs = eval(y, {'np': np}, block_matrix_bindings) _assert_eq(lhs, rhs) @pytest.mark.parametrize( 'x, y', - [ # division - ('x / e', 'nx / e') - , ('c / e', 'nc / e') - , ('r / e', 'nr / e') - , ('m / e', 'nm / e') - , ('x / e', '1 / (e / x)') - , ('c / e', '1 / (e / c)') - , ('r / e', '1 / (e / r)') - , ('m / e', '1 / (e / m)') - , ('x / x', 'np.ones((1, 1))') - , ('c / c', 'np.ones((2, 1))') - , ('r / r', 'np.ones((1, 3))') - , ('m / m', 'np.ones((2, 3))') - , ('x / c', 'np.array([[2 / 1.0], [2 / 2.0]])') - , ('x / r', 'np.array([[2 / 1.0, 2 / 2.0, 2 / 3.0]])') - , ('x / m', 'np.array([[2 / 1.0, 2 / 2.0, 2 / 3.0], [2 / 4.0, 2 / 5.0, 2 / 6.0]])') - , ('c / m', 'np.array([[1 / 1.0, 1 / 2.0, 1 / 3.0], [2 / 4.0, 2 / 5.0, 2 / 6.0]])') - , ('r / m', 'np.array([[1 / 1.0, 2 / 2.0, 3 / 3.0], [1 / 4.0, 2 / 5.0, 3 / 6.0]])') - , ('x / c', '1 / (c / x)') - , ('x / r', '1 / (r / x)') - , ('x / m', '1 / (m / x)') - , ('c / m', '1 / (m / c)') - , ('r / m', '1 / (m / r)') - , ('x / nx', 'x / x') - , ('x / nc', 'x / c') - , ('x / nr', 'x / r') - , ('x / nm', 'x / m') - , ('c / nx', 'c / x') - , ('c / nc', 'c / c') - , ('c / nm', 'c / m') - , ('r / nx', 'r / x') - , ('r / nr', 'r / r') - , ('r / nm', 'r / m') - , ('m / nx', 'm / x') - , ('m / nc', 'm / c') - , ('m / nr', 'm / r') - , ('m / nm', 'm / m') - - # other ops - , ('m ** 3' , 'nm ** 3') - , ('m.sqrt()' , 'np.sqrt(nm)') - , ('m.ceil()' , 'np.ceil(nm)') - , ('m.floor()' , 'np.floor(nm)') - , ('m.log()' , 'np.log(nm)') - , ('(m - 4).abs()', 'np.abs(nm - 4)') - ] + [ # division + ('x / e', 'nx / e'), + ('c / e', 'nc / e'), + ('r / e', 'nr / e'), + ('m / e', 'nm / e'), + ('x / e', '1 / (e / x)'), + ('c / e', '1 / (e / c)'), + ('r / e', '1 / (e / r)'), + ('m / e', '1 / (e / m)'), + ('x / x', 'np.ones((1, 1))'), + ('c / c', 'np.ones((2, 1))'), + ('r / r', 'np.ones((1, 3))'), + ('m / m', 'np.ones((2, 3))'), + ('x / c', 'np.array([[2 / 1.0], [2 / 2.0]])'), + ('x / r', 'np.array([[2 / 1.0, 2 / 2.0, 2 / 3.0]])'), + ('x / m', 'np.array([[2 / 1.0, 2 / 2.0, 2 / 3.0], [2 / 4.0, 2 / 5.0, 2 / 6.0]])'), + ('c / m', 'np.array([[1 / 1.0, 1 / 2.0, 1 / 3.0], [2 / 4.0, 2 / 5.0, 2 / 6.0]])'), + ('r / m', 'np.array([[1 / 1.0, 2 / 2.0, 3 / 3.0], [1 / 4.0, 2 / 5.0, 3 / 6.0]])'), + ('x / c', '1 / (c / x)'), + ('x / r', '1 / (r / x)'), + ('x / m', '1 / (m / x)'), + ('c / m', '1 / (m / c)'), + ('r / m', '1 / (m / r)'), + ('x / nx', 'x / x'), + ('x / nc', 'x / c'), + ('x / nr', 'x / r'), + ('x / nm', 'x / m'), + ('c / nx', 'c / x'), + ('c / nc', 'c / c'), + ('c / nm', 'c / m'), + ('r / nx', 'r / x'), + ('r / nr', 'r / r'), + ('r / nm', 'r / m'), + ('m / nx', 'm / x'), + ('m / nc', 'm / c'), + ('m / nr', 'm / r'), + ('m / nm', 'm / m'), + # other ops + ('m ** 3', 'nm 
** 3'), + ('m.sqrt()', 'np.sqrt(nm)'), + ('m.ceil()', 'np.ceil(nm)'), + ('m.floor()', 'np.floor(nm)'), + ('m.log()', 'np.log(nm)'), + ('(m - 4).abs()', 'np.abs(nm - 4)'), + ], ) def test_block_matrix_elementwise_close_arithmetic(block_matrix_bindings, x, y): lhs = eval(x, block_matrix_bindings) - rhs = eval(y, { 'np': np }, block_matrix_bindings) + rhs = eval(y, {'np': np}, block_matrix_bindings) _assert_close(lhs, rhs) @pytest.mark.parametrize( 'expr, expectation', - [ ('x + np.array([\'one\'], dtype=str)', pytest.raises(TypeError)) - , ('m @ m ', pytest.raises(ValueError)) - , ('m @ nm', pytest.raises(ValueError)) - ] + [ + ('x + np.array([\'one\'], dtype=str)', pytest.raises(TypeError)), + ('m @ m ', pytest.raises(ValueError)), + ('m @ nm', pytest.raises(ValueError)), + ], ) def test_block_matrix_raises(block_matrix_bindings, expr, expectation): with expectation: - eval(expr, { 'np': np }, block_matrix_bindings) + eval(expr, {'np': np}, block_matrix_bindings) @pytest.mark.parametrize( 'x, y', - [ ( 'm.sum(axis=0).T' - , 'np.array([[5.0], [7.0], [9.0]])' - ) - , ( 'm.sum(axis=1).T' - , 'np.array([[6.0, 15.0]])' - ) - , ( 'm.sum(axis=0).T + row' - , 'np.array([[12.0, 13.0, 14.0],[14.0, 15.0, 16.0],[16.0, 17.0, 18.0]])' - ) - , ( 'm.sum(axis=0) + row.T' - , 'np.array([[12.0, 14.0, 16.0],[13.0, 15.0, 17.0],[14.0, 16.0, 18.0]])' - ) - , ( 'square.sum(axis=0).T + square.sum(axis=1)' - , 'np.array([[18.0], [30.0], [42.0]])' - ) - ] + [ + ('m.sum(axis=0).T', 'np.array([[5.0], [7.0], [9.0]])'), + ('m.sum(axis=1).T', 'np.array([[6.0, 15.0]])'), + ('m.sum(axis=0).T + row', 'np.array([[12.0, 13.0, 14.0],[14.0, 15.0, 16.0],[16.0, 17.0, 18.0]])'), + ('m.sum(axis=0) + row.T', 'np.array([[12.0, 14.0, 16.0],[13.0, 15.0, 17.0],[14.0, 16.0, 18.0]])'), + ('square.sum(axis=0).T + square.sum(axis=1)', 'np.array([[18.0], [30.0], [42.0]])'), + ], ) def test_matrix_sums(block_matrix_bindings, x, y): lhs = eval(x, block_matrix_bindings) - rhs = eval(y, { 'np': np }, block_matrix_bindings) + rhs = eval(y, {'np': np}, block_matrix_bindings) _assert_eq(lhs, rhs) @@ -566,15 +578,16 @@ def test_matrix_sums(block_matrix_bindings, x, y): @fails_local_backend() @pytest.mark.parametrize( 'x, y', - [ ('m.tree_matmul(m.T, splits=2)', 'nm @ nm.T') - , ('m.tree_matmul(nm.T, splits=2)', 'nm @ nm.T') - , ('row.tree_matmul(row.T, splits=2)', 'nrow @ nrow.T') - , ('row.tree_matmul(nrow.T, splits=2)', 'nrow @ nrow.T') - , ('m.T.tree_matmul(m, splits=2)', 'nm.T @ nm') - , ('m.T.tree_matmul(nm, splits=2)', 'nm.T @ nm') - , ('row.T.tree_matmul(row, splits=2)', 'nrow.T @ nrow') - , ('row.T.tree_matmul(nrow, splits=2)', 'nrow.T @ nrow') - ] + [ + ('m.tree_matmul(m.T, splits=2)', 'nm @ nm.T'), + ('m.tree_matmul(nm.T, splits=2)', 'nm @ nm.T'), + ('row.tree_matmul(row.T, splits=2)', 'nrow @ nrow.T'), + ('row.tree_matmul(nrow.T, splits=2)', 'nrow @ nrow.T'), + ('m.T.tree_matmul(m, splits=2)', 'nm.T @ nm'), + ('m.T.tree_matmul(nm, splits=2)', 'nm.T @ nm'), + ('row.T.tree_matmul(row, splits=2)', 'nrow.T @ nrow'), + ('row.T.tree_matmul(nrow, splits=2)', 'nrow.T @ nrow'), + ], ) def test_tree_matmul(block_matrix_bindings, x, y): lhs = eval(x, block_matrix_bindings) @@ -586,11 +599,12 @@ def test_tree_matmul(block_matrix_bindings, x, y): @fails_local_backend() @pytest.mark.parametrize( 'nrows,ncols,block_size,split_size', - [ (nrows,ncols,block_size,split_size) - for (nrows, ncols) in [(50, 60), (60, 25)] - for block_size in [7, 10] - for split_size in [2, 9] - ] + [ + (nrows, ncols, block_size, split_size) + for (nrows, ncols) in [(50, 60), 
(60, 25)] + for block_size in [7, 10] + for split_size in [2, 9] + ], ) def test_tree_matmul_splits(block_size, split_size, nrows, ncols): # Variety of block sizes and splits @@ -650,15 +664,16 @@ def test_slicing_0(indices): @pytest.mark.parametrize( 'indices', - [ (slice(0, 8), slice(0, 10)) - , (slice(0, 8, 2), slice(0, 10, 2)) - , (slice(2, 4), slice(5, 7)) - , (slice(-8, -1), slice(-10, -1)) - , (slice(-8, -1, 2), slice(-10, -1, 2)) - , (slice(None, 4, 1), slice(None, 4, 1)) - , (slice(4, None), slice(4, None)) - , (slice(None, None), slice(None, None)) - ] + [ + (slice(0, 8), slice(0, 10)), + (slice(0, 8, 2), slice(0, 10, 2)), + (slice(2, 4), slice(5, 7)), + (slice(-8, -1), slice(-10, -1)), + (slice(-8, -1, 2), slice(-10, -1, 2)), + (slice(None, 4, 1), slice(None, 4, 1)), + (slice(4, None), slice(4, None)), + (slice(None, None), slice(None, None)), + ], ) def test_slicing_1(indices): nd = np.array(np.arange(0, 80, dtype=float)).reshape(8, 10) @@ -670,15 +685,16 @@ def test_slicing_1(indices): @pytest.mark.parametrize( 'indices, axis', - [ ((0, slice(3, 4)) , 0) - , ((1, slice(3, 4)) , 0) - , ((-8, slice(3, 4)), 0) - , ((-1, slice(3, 4)), 0) - , ((slice(3, 4), 0), 1) - , ((slice(3, 4), 1), 1) - , ((slice(3, 4), -8), 1) - , ((slice(3, 4), -1), 1) - ] + [ + ((0, slice(3, 4)), 0), + ((1, slice(3, 4)), 0), + ((-8, slice(3, 4)), 0), + ((-1, slice(3, 4)), 0), + ((slice(3, 4), 0), 1), + ((slice(3, 4), 1), 1), + ((slice(3, 4), -8), 1), + ((slice(3, 4), -1), 1), + ], ) def test_slicing_2(indices, axis): nd = np.array(np.arange(0, 80, dtype=float)).reshape(8, 10) @@ -690,22 +706,23 @@ def test_slicing_2(indices, axis): @pytest.mark.parametrize( 'expr', - [ 'square[0, ]' - , 'square[9, 0]' - , 'square[-9, 0]' - , 'square[0, 11]' - , 'square[0, -11]' - , 'square[::-1, 0]' - , 'square[0, ::-1]' - , 'square[:0, 0]' - , 'square[0, :0]' - , 'square[0:9, 0]' - , 'square[-9:, 0]' - , 'square[:-9, 0]' - , 'square[0, :11]' - , 'square[0, -11:]' - , 'square[0, :-11] ' - ] + [ + 'square[0, ]', + 'square[9, 0]', + 'square[-9, 0]', + 'square[0, 11]', + 'square[0, -11]', + 'square[::-1, 0]', + 'square[0, ::-1]', + 'square[:0, 0]', + 'square[0, :0]', + 'square[0:9, 0]', + 'square[-9:, 0]', + 'square[:-9, 0]', + 'square[0, :11]', + 'square[0, -11:]', + 'square[0, :-11] ', + ], ) def test_block_matrix_illegal_indexing(block_matrix_bindings, expr): with pytest.raises(ValueError): @@ -713,11 +730,13 @@ def test_block_matrix_illegal_indexing(block_matrix_bindings, expr): def test_diagonal_sparse(): - nd = np.array([[ 1.0, 2.0, 3.0, 4.0], - [ 5.0, 6.0, 7.0, 8.0], - [ 9.0, 10.0, 11.0, 12.0], - [13.0, 14.0, 15.0, 16.0], - [17.0, 18.0, 19.0, 20.0]]) + nd = np.array([ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0], + [17.0, 18.0, 19.0, 20.0], + ]) bm = BlockMatrix.from_numpy(nd, block_size=2) bm = bm.sparsify_row_intervals([0, 0, 0, 0, 0], [2, 2, 2, 2, 2]) @@ -746,38 +765,27 @@ def test_slices_with_sparsify(): def test_sparsify_row_intervals_0(): - nd = np.array([[ 1.0, 2.0, 3.0, 4.0], - [ 5.0, 6.0, 7.0, 8.0], - [ 9.0, 10.0, 11.0, 12.0], - [13.0, 14.0, 15.0, 16.0]]) + nd = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]) bm = BlockMatrix.from_numpy(nd, block_size=2) _assert_eq( - bm.sparsify_row_intervals( - starts=[1, 0, 2, 2], - stops= [2, 0, 3, 4]), - np.array([[ 0., 2., 0., 0.], - [ 0., 0., 0., 0.], - [ 0., 0., 11., 0.], - [ 0., 0., 15., 16.]])) + bm.sparsify_row_intervals(starts=[1, 0, 2, 2], stops=[2, 0, 
3, 4]), + np.array([[0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 11.0, 0.0], [0.0, 0.0, 15.0, 16.0]]), + ) _assert_eq( - bm.sparsify_row_intervals( - starts=[1, 0, 2, 2], - stops= [2, 0, 3, 4], - blocks_only=True), - np.array([[ 1., 2., 0., 0.], - [ 5., 6., 0., 0.], - [ 0., 0., 11., 12.], - [ 0., 0., 15., 16.]])) + bm.sparsify_row_intervals(starts=[1, 0, 2, 2], stops=[2, 0, 3, 4], blocks_only=True), + np.array([[1.0, 2.0, 0.0, 0.0], [5.0, 6.0, 0.0, 0.0], [0.0, 0.0, 11.0, 12.0], [0.0, 0.0, 15.0, 16.0]]), + ) @pytest.mark.parametrize( 'starts, stops', - [ ([0, 1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 8]) - , ([0, 0, 5, 3, 4, 5, 8, 2], [9, 0, 5, 3, 4, 5, 9, 5]) - , ([0, 5, 10, 8, 7, 6, 5, 4], [0, 5, 10, 9, 8, 7, 6, 5]) - ] + [ + ([0, 1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 8]), + ([0, 0, 5, 3, 4, 5, 8, 2], [9, 0, 5, 3, 4, 5, 9, 5]), + ([0, 5, 10, 8, 7, 6, 5, 4], [0, 5, 10, 9, 8, 7, 6, 5]), + ], ) def test_row_intervals_1(starts, stops): nd2 = np.random.normal(size=(8, 10)) @@ -793,31 +801,21 @@ def test_row_intervals_1(starts, stops): def test_sparsify_band_0(): - nd = np.array([[ 1.0, 2.0, 3.0, 4.0], - [ 5.0, 6.0, 7.0, 8.0], - [ 9.0, 10.0, 11.0, 12.0], - [13.0, 14.0, 15.0, 16.0]]) + nd = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]) bm = BlockMatrix.from_numpy(nd, block_size=2) _assert_eq( bm.sparsify_band(lower=-1, upper=2), - np.array([[ 1., 2., 3., 0.], - [ 5., 6., 7., 8.], - [ 0., 10., 11., 12.], - [ 0., 0., 15., 16.]])) + np.array([[1.0, 2.0, 3.0, 0.0], [5.0, 6.0, 7.0, 8.0], [0.0, 10.0, 11.0, 12.0], [0.0, 0.0, 15.0, 16.0]]), + ) _assert_eq( bm.sparsify_band(lower=0, upper=0, blocks_only=True), - np.array([[ 1., 2., 0., 0.], - [ 5., 6., 0., 0.], - [ 0., 0., 11., 12.], - [ 0., 0., 15., 16.]])) + np.array([[1.0, 2.0, 0.0, 0.0], [5.0, 6.0, 0.0, 0.0], [0.0, 0.0, 11.0, 12.0], [0.0, 0.0, 15.0, 16.0]]), + ) -@pytest.mark.parametrize( - 'lower, upper', - [ (0, 0), (1, 1), (2, 2), (-5, 5), (-7, 0), (0, 9), (-100, 100) ] -) +@pytest.mark.parametrize('lower, upper', [(0, 0), (1, 1), (2, 2), (-5, 5), (-7, 0), (0, 9), (-100, 100)]) def test_sparsify_band_1(lower, upper): nd2 = np.arange(0, 80, dtype=float).reshape(8, 10) bm2 = BlockMatrix.from_numpy(nd2, block_size=3) @@ -827,10 +825,7 @@ def test_sparsify_band_1(lower, upper): def test_sparsify_triangle(): - nd = np.array([[ 1.0, 2.0, 3.0, 4.0], - [ 5.0, 6.0, 7.0, 8.0], - [ 9.0, 10.0, 11.0, 12.0], - [13.0, 14.0, 15.0, 16.0]]) + nd = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]) bm = BlockMatrix.from_numpy(nd, block_size=2) # FIXME doesn't work in service, if test_is_sparse works, uncomment below @@ -839,39 +834,28 @@ def test_sparsify_triangle(): _assert_eq( bm.sparsify_triangle(), - np.array([[ 1., 2., 3., 4.], - [ 0., 6., 7., 8.], - [ 0., 0., 11., 12.], - [ 0., 0., 0., 16.]])) + np.array([[1.0, 2.0, 3.0, 4.0], [0.0, 6.0, 7.0, 8.0], [0.0, 0.0, 11.0, 12.0], [0.0, 0.0, 0.0, 16.0]]), + ) _assert_eq( bm.sparsify_triangle(lower=True), - np.array([[ 1., 0., 0., 0.], - [ 5., 6., 0., 0.], - [ 9., 10., 11., 0.], - [13., 14., 15., 16.]])) + np.array([[1.0, 0.0, 0.0, 0.0], [5.0, 6.0, 0.0, 0.0], [9.0, 10.0, 11.0, 0.0], [13.0, 14.0, 15.0, 16.0]]), + ) _assert_eq( bm.sparsify_triangle(blocks_only=True), - np.array([[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 0., 0., 11., 12.], - [ 0., 0., 15., 16.]])) + np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [0.0, 0.0, 11.0, 12.0], [0.0, 0.0, 15.0, 16.0]]), + ) def 
test_sparsify_rectangles(): - nd = np.array([[ 1.0, 2.0, 3.0, 4.0], - [ 5.0, 6.0, 7.0, 8.0], - [ 9.0, 10.0, 11.0, 12.0], - [13.0, 14.0, 15.0, 16.0]]) + nd = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]) bm = BlockMatrix.from_numpy(nd, block_size=2) _assert_eq( bm.sparsify_rectangles([[0, 1, 0, 1], [0, 3, 0, 2], [1, 2, 0, 4]]), - np.array([[ 1., 2., 3., 4.], - [ 5., 6., 7., 8.], - [ 9., 10., 0., 0.], - [13., 14., 0., 0.]])) + np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 0.0, 0.0], [13.0, 14.0, 0.0, 0.0]]), + ) _assert_eq(bm.sparsify_rectangles([]), np.zeros(shape=(4, 4))) @@ -880,28 +864,26 @@ def test_sparsify_rectangles(): @fails_local_backend() @pytest.mark.parametrize( 'rects,block_size,binary', - [ (rects, block_size, binary) - for binary in [False, True] - for block_size in [3, 4, 10] - for rects in - [ [ [0, 1, 0, 1] - , [4, 5, 7, 8] - ] - , [ [4, 5, 0, 10] - , [0, 8, 4, 5] - ] - , [ [0, 1, 0, 1] - , [1, 2, 1, 2] - , [2, 3, 2, 3] - , [3, 5, 3, 6] - , [3, 6, 3, 7] - , [3, 7, 3, 8] - , [4, 5, 0, 10] - , [0, 8, 4, 5] - , [0, 8, 0, 10] - ] + [ + (rects, block_size, binary) + for binary in [False, True] + for block_size in [3, 4, 10] + for rects in [ + [[0, 1, 0, 1], [4, 5, 7, 8]], + [[4, 5, 0, 10], [0, 8, 4, 5]], + [ + [0, 1, 0, 1], + [1, 2, 1, 2], + [2, 3, 2, 3], + [3, 5, 3, 6], + [3, 6, 3, 7], + [3, 7, 3, 8], + [4, 5, 0, 10], + [0, 8, 4, 5], + [0, 8, 0, 10], + ], ] - ] + ], ) def test_export_rectangles(rects, block_size, binary): nd = np.arange(0, 80, dtype=float).reshape(8, 10) @@ -915,19 +897,13 @@ def test_export_rectangles(rects, block_size, binary): @fails_local_backend() def test_export_rectangles_sparse(): with hl.TemporaryDirectory() as rect_uri: - nd = np.array([[1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [9.0, 10.0, 11.0, 12.0], - [13.0, 14.0, 15.0, 16.0]]) + nd = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]) bm = BlockMatrix.from_numpy(nd, block_size=2) sparsify_rects = [[0, 1, 0, 1], [0, 3, 0, 2], [1, 2, 0, 4]] export_rects = [[0, 1, 0, 1], [0, 3, 0, 2], [1, 2, 0, 4], [2, 4, 2, 4]] bm.sparsify_rectangles(sparsify_rects).export_rectangles(rect_uri, export_rects) - expected = np.array([[1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [9.0, 10.0, 0.0, 0.0], - [13.0, 14.0, 0.0, 0.0]]) + expected = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 0.0, 0.0], [13.0, 14.0, 0.0, 0.0]]) _assert_rectangles_eq(expected, rect_uri, export_rects) @@ -936,17 +912,13 @@ def test_export_rectangles_sparse(): @fails_local_backend() def test_export_rectangles_filtered(): with hl.TemporaryDirectory() as rect_uri: - nd = np.array([[1.0, 2.0, 3.0, 4.0], - [5.0, 6.0, 7.0, 8.0], - [9.0, 10.0, 11.0, 12.0], - [13.0, 14.0, 15.0, 16.0]]) + nd = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]) bm = BlockMatrix.from_numpy(nd) bm = bm[1:3, 1:3] export_rects = [[0, 1, 0, 2], [1, 2, 0, 2]] bm.export_rectangles(rect_uri, export_rects) - expected = np.array([[6.0, 7.0], - [10.0, 11.0]]) + expected = np.array([[6.0, 7.0], [10.0, 11.0]]) _assert_rectangles_eq(expected, rect_uri, export_rects) @@ -981,8 +953,12 @@ def test_to_ndarray(): assert np.array_equal(np_mat, hl.eval(mat)) blocks_to_sparsify = [1, 4, 7, 12, 20, 42, 48] - sparsed_numpy = sparsify_numpy(np.arange(25*25).reshape((25, 25)), 4, blocks_to_sparsify) - sparsed = BlockMatrix.from_ndarray(hl.nd.array(sparsed_numpy), 
block_size=4)._sparsify_blocks(blocks_to_sparsify).to_ndarray() + sparsed_numpy = sparsify_numpy(np.arange(25 * 25).reshape((25, 25)), 4, blocks_to_sparsify) + sparsed = ( + BlockMatrix.from_ndarray(hl.nd.array(sparsed_numpy), block_size=4) + ._sparsify_blocks(blocks_to_sparsify) + .to_ndarray() + ) assert np.array_equal(sparsed_numpy, hl.eval(sparsed)) @@ -992,10 +968,7 @@ def test_block_matrix_entries(block_size): n_rows, n_cols = 5, 3 rows = [{'i': i, 'j': j, 'entry': float(i + j)} for i in range(n_rows) for j in range(n_cols)] schema = hl.tstruct(i=hl.tint32, j=hl.tint32, entry=hl.tfloat64) - table = hl.Table.parallelize( - [hl.struct(i=row['i'], j=row['j'], entry=row['entry']) for row in rows], - schema - ) + table = hl.Table.parallelize([hl.struct(i=row['i'], j=row['j'], entry=row['entry']) for row in rows], schema) table = table.annotate(i=hl.int64(table.i), j=hl.int64(table.j)).key_by('i', 'j') ndarray = np.reshape(list(map(lambda row: row['entry'], rows)), (n_rows, n_cols)) @@ -1009,9 +982,10 @@ def test_block_matrix_entries(block_size): def test_from_entry_expr_filtered(): mt = hl.utils.range_matrix_table(1, 1).filter_entries(False) - bm = hl.linalg.BlockMatrix.from_entry_expr(mt.row_idx + mt.col_idx, mean_impute=True) # should run without error + bm = hl.linalg.BlockMatrix.from_entry_expr(mt.row_idx + mt.col_idx, mean_impute=True) # should run without error assert np.isnan(bm.entries().entry.collect()[0]) + def test_array_windows(): def assert_eq(a, b): assert np.array_equal(a, np.array(b)) @@ -1039,14 +1013,15 @@ def assert_eq(a, b): @pytest.mark.parametrize( 'array,radius', - [ ([1, 0], -1) - , ([0, float('nan')], 1) - , ([float('nan')], 1) - , ([0.0, float('nan')], 1) - , ([None], 1) - , ([], -1) - , (['str'], 1) - ] + [ + ([1, 0], -1), + ([0, float('nan')], 1), + ([float('nan')], 1), + ([0.0, float('nan')], 1), + ([None], 1), + ([], -1), + (['str'], 1), + ], ) def test_array_windows_illegal_arguments(array, radius): with pytest.raises(ValueError): @@ -1092,16 +1067,16 @@ def test_locus_windows_3(): def test_locus_windows_4(): - rows = [{'locus': hl.Locus('1', 1), 'cm': 1.0}, - {'locus': hl.Locus('1', 2), 'cm': 3.0}, - {'locus': hl.Locus('1', 4), 'cm': 4.0}, - {'locus': hl.Locus('2', 1), 'cm': 2.0}, - {'locus': hl.Locus('2', 1), 'cm': 2.0}, - {'locus': hl.Locus('3', 3), 'cm': 5.0}] + rows = [ + {'locus': hl.Locus('1', 1), 'cm': 1.0}, + {'locus': hl.Locus('1', 2), 'cm': 3.0}, + {'locus': hl.Locus('1', 4), 'cm': 4.0}, + {'locus': hl.Locus('2', 1), 'cm': 2.0}, + {'locus': hl.Locus('2', 1), 'cm': 2.0}, + {'locus': hl.Locus('3', 3), 'cm': 5.0}, + ] - ht = hl.Table.parallelize(rows, - hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64), - key=['locus']) + ht = hl.Table.parallelize(rows, hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64), key=['locus']) starts, stops = hl.linalg.utils.locus_windows(ht.locus, 1) assert_np_arrays_eq(starts, [0, 0, 2, 3, 3, 5]) @@ -1109,16 +1084,16 @@ def test_locus_windows_4(): def dummy_table_with_loci_and_cms(): - rows = [{'locus': hl.Locus('1', 1), 'cm': 1.0}, - {'locus': hl.Locus('1', 2), 'cm': 3.0}, - {'locus': hl.Locus('1', 4), 'cm': 4.0}, - {'locus': hl.Locus('2', 1), 'cm': 2.0}, - {'locus': hl.Locus('2', 1), 'cm': 2.0}, - {'locus': hl.Locus('3', 3), 'cm': 5.0}] + rows = [ + {'locus': hl.Locus('1', 1), 'cm': 1.0}, + {'locus': hl.Locus('1', 2), 'cm': 3.0}, + {'locus': hl.Locus('1', 4), 'cm': 4.0}, + {'locus': hl.Locus('2', 1), 'cm': 2.0}, + {'locus': hl.Locus('2', 1), 'cm': 2.0}, + {'locus': hl.Locus('3', 3), 'cm': 5.0}, + ] - return 
hl.Table.parallelize(rows, - hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64), - key=['locus']) + return hl.Table.parallelize(rows, hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64), key=['locus']) def test_locus_windows_5(): @@ -1153,7 +1128,7 @@ def test_locus_windows_9(): def test_locus_windows_10(): ht = dummy_table_with_loci_and_cms() - ht = ht.annotate_globals(x = hl.locus('1', 1), y = 1.0) + ht = ht.annotate_globals(x=hl.locus('1', 1), y=1.0) with pytest.raises(ExpressionException, match='row-indexed'): hl.linalg.utils.locus_windows(ht.x, 1.0) @@ -1162,8 +1137,11 @@ def test_locus_windows_10(): def test_locus_windows_11(): - ht = hl.Table.parallelize([{'locus': hl.missing(hl.tlocus()), 'cm': 1.0}], - hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64), key=['locus']) + ht = hl.Table.parallelize( + [{'locus': hl.missing(hl.tlocus()), 'cm': 1.0}], + hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64), + key=['locus'], + ) with pytest.raises(HailUserError, match='missing value for \'locus_expr\''): hl.linalg.utils.locus_windows(ht.locus, 1.0) @@ -1175,7 +1153,7 @@ def test_locus_windows_12(): ht = hl.Table.parallelize( [{'locus': hl.Locus('1', 1), 'cm': hl.missing(hl.tfloat64)}], hl.tstruct(locus=hl.tlocus('GRCh37'), cm=hl.tfloat64), - key=['locus'] + key=['locus'], ) with pytest.raises(FatalError, match='missing value for \'coord_expr\''): hl.linalg.utils.locus_windows(ht.locus, 1.0, coord_expr=ht.cm) @@ -1206,8 +1184,7 @@ def assert_same_columns_up_to_sign(a, b): for j in range(a.shape[1]): assert np.allclose(a[:, j], b[:, j]) or np.allclose(-a[:, j], b[:, j]) - x0 = np.array([[-2.0, 0.0, 3.0], - [-1.0, 2.0, 4.0]]) + x0 = np.array([[-2.0, 0.0, 3.0], [-1.0, 2.0, 4.0]]) u0, s0, vt0 = np.linalg.svd(x0, full_matrices=False) x = BlockMatrix.from_numpy(x0) @@ -1284,20 +1261,10 @@ def test_is_sparse(): bm = BlockMatrix.from_numpy(np_square, block_size=2) bm = bm._sparsify_blocks(block_list) assert bm.is_sparse - assert np.array_equal( - bm.to_numpy(), - np.array([[0, 0, 2, 3], - [0, 0, 6, 7], - [8, 9, 0, 0], - [12, 13, 0, 0]])) + assert np.array_equal(bm.to_numpy(), np.array([[0, 0, 2, 3], [0, 0, 6, 7], [8, 9, 0, 0], [12, 13, 0, 0]])) -@pytest.mark.parametrize( - 'block_list,nrows,ncols,block_size', - [ ([1, 2], 4, 4, 2) - , ([4, 8, 10, 12, 13, 14], 15, 15, 4) - ] -) +@pytest.mark.parametrize('block_list,nrows,ncols,block_size', [([1, 2], 4, 4, 2), ([4, 8, 10, 12, 13, 14], 15, 15, 4)]) def test_sparsify_blocks(block_list, nrows, ncols, block_size): np_square = np.arange(nrows * ncols, dtype=np.float64).reshape((nrows, ncols)) bm = BlockMatrix.from_numpy(np_square, block_size=block_size) @@ -1308,11 +1275,12 @@ def test_sparsify_blocks(block_list, nrows, ncols, block_size): @pytest.mark.parametrize( 'block_list,nrows,ncols,block_size', - [ ([1, 2], 4, 4, 2) - , ([4, 8, 10, 12, 13, 14], 15, 15, 4) - , ([2, 5, 8, 10, 11], 10, 15, 4) - , ([2, 5, 8, 10, 11], 15, 11, 4) - ] + [ + ([1, 2], 4, 4, 2), + ([4, 8, 10, 12, 13, 14], 15, 15, 4), + ([2, 5, 8, 10, 11], 10, 15, 4), + ([2, 5, 8, 10, 11], 15, 11, 4), + ], ) def test_sparse_transposition(block_list, nrows, ncols, block_size): np_square = np.arange(nrows * ncols, dtype=np.float64).reshape((nrows, ncols)) @@ -1328,11 +1296,11 @@ def test_row_blockmatrix_sum(): # Summing vertically along a column vector to get a single value b = col.sum(axis=0) - assert b.to_numpy().shape == (1,1) + assert b.to_numpy().shape == (1, 1) # Summing horizontally along a row vector to create a single value d = row.sum(axis=1) - assert d.to_numpy().shape == 
(1,1) + assert d.to_numpy().shape == (1, 1) # Summing vertically along a row vector to make sure nothing changes e = row.sum(axis=0) diff --git a/hail/python/test/hail/matrixtable/test_file_formats.py b/hail/python/test/hail/matrixtable/test_file_formats.py index a0bda5b694e..42f28785df1 100644 --- a/hail/python/test/hail/matrixtable/test_file_formats.py +++ b/hail/python/test/hail/matrixtable/test_file_formats.py @@ -1,12 +1,17 @@ import asyncio -import pytest -import os from typing import List, Tuple -from pathlib import Path + +import pytest import hail as hl from hail.utils.java import Env, scala_object -from ..helpers import * + +from ..helpers import ( + create_all_values_datasets, + create_all_values_matrix_table, + create_all_values_table, + resource, +) def create_backward_compatibility_files(): @@ -28,8 +33,9 @@ def create_backward_compatibility_files(): i = 0 for codec in supported_codecs: all_values_table.write(os.path.join(table_dir, f'{i}.ht'), overwrite=True, _codec_spec=codec.toString()) - all_values_matrix_table.write(os.path.join(matrix_table_dir, f'{i}.hmt'), overwrite=True, - _codec_spec=codec.toString()) + all_values_matrix_table.write( + os.path.join(matrix_table_dir, f'{i}.hmt'), overwrite=True, _codec_spec=codec.toString() + ) i += 1 @@ -48,28 +54,31 @@ def all_values_table_fixture(init_hail): return create_all_values_table() +resource_dir = resource('backward_compatability/') + + async def collect_paths() -> Tuple[List[str], List[str]]: - resource_dir = resource('backward_compatability/') from hailtop.aiotools.router_fs import RouterAsyncFS + fs = RouterAsyncFS() async def contents_if_present(url: str): try: return await fs.listfiles(url) except FileNotFoundError: + async def empty(): if False: yield + return empty() try: versions = [await x.url() async for x in await fs.listfiles(resource_dir)] - ht_paths = [await x.url() - for version in versions - async for x in await contents_if_present(version + 'table/')] - mt_paths = [await x.url() - for version in versions - async for x in await contents_if_present(version + 'matrix_table/')] + ht_paths = [await x.url() for version in versions async for x in await contents_if_present(version + 'table/')] + mt_paths = [ + await x.url() for version in versions async for x in await contents_if_present(version + 'matrix_table/') + ] return ht_paths, mt_paths finally: await fs.close() diff --git a/hail/python/test/hail/matrixtable/test_grouped_matrix_table.py b/hail/python/test/hail/matrixtable/test_grouped_matrix_table.py index 74b986bf0f8..08930f9f7e6 100644 --- a/hail/python/test/hail/matrixtable/test_grouped_matrix_table.py +++ b/hail/python/test/hail/matrixtable/test_grouped_matrix_table.py @@ -1,28 +1,26 @@ import unittest import hail as hl -from ..helpers import * +from ..helpers import qobtest, test_timeout -class Tests(unittest.TestCase): +class Tests(unittest.TestCase): @staticmethod def get_groupable_matrix(): rt = hl.utils.range_matrix_table(n_rows=100, n_cols=20) rt = rt.annotate_globals(foo="foo") - rt = rt.annotate_rows(group1=rt['row_idx'] % 6, - group2=hl.Struct(a=rt['row_idx'] % 6, - b="foo")) - rt = rt.annotate_cols(group3=rt['col_idx'] % 6, - group4=hl.Struct(a=rt['col_idx'] % 6, - b="foo")) - return rt.annotate_entries(c=rt['row_idx'], - d=rt['col_idx'], - e="foo", - f=rt['group1'], - g=rt['group2']['a'], - h=rt['group3'], - i=rt['group4']['a']) + rt = rt.annotate_rows(group1=rt['row_idx'] % 6, group2=hl.Struct(a=rt['row_idx'] % 6, b="foo")) + rt = rt.annotate_cols(group3=rt['col_idx'] % 6, 
group4=hl.Struct(a=rt['col_idx'] % 6, b="foo")) + return rt.annotate_entries( + c=rt['row_idx'], + d=rt['col_idx'], + e="foo", + f=rt['group1'], + g=rt['group2']['a'], + h=rt['group3'], + i=rt['group4']['a'], + ) @staticmethod def get_groupable_matrix2(): @@ -52,40 +50,76 @@ def test_errors_caught_correctly(self): self.assertRaises(ExpressionException, mt.group_cols_by, foo=mt['group3']) a = mt.group_rows_by(group5=(mt['group2']['a'] + 1)) - self.assertRaises(NotImplementedError, a.aggregate_cols, bar=hl.agg.sum(mt['col_idx'])) # cannot aggregate cols when grouped by rows - - self.assertRaises(ExpressionException, a.aggregate_entries, group3=hl.agg.sum(mt['c'])) # duplicate column field - self.assertRaises(ExpressionException, a.aggregate_entries, group5=hl.agg.sum(mt['c'])) # duplicate row field - self.assertRaises(ExpressionException, a.aggregate_entries, foo=hl.agg.sum(mt['c'])) # duplicate globals field - - self.assertRaises(ExpressionException, a.aggregate_rows, group3=hl.agg.sum(mt['row_idx'])) # duplicate column field - self.assertRaises(ExpressionException, a.aggregate_rows, group5=hl.agg.sum(mt['row_idx'])) # duplicate row field - self.assertRaises(ExpressionException, a.aggregate_rows, foo=hl.agg.sum(mt['row_idx'])) # duplicate globals field - self.assertRaises(ExpressionException, a.aggregate_rows, bar=mt['row_idx'] + hl.agg.sum(mt['row_idx'])) # expression has to have global indices - self.assertRaises(ExpressionException, a.aggregate_rows, bar=mt['col_idx'] + hl.agg.sum(mt['row_idx'])) # expression has to have global indices - self.assertRaises(ExpressionException, a.aggregate_rows, bar=hl.agg.sum(mt['c'])) # aggregation scope is rows only - entry field - self.assertRaises(ExpressionException, a.aggregate_rows, bar=hl.agg.sum(mt['col_idx'])) # aggregation scope is rows only - column field + self.assertRaises( + NotImplementedError, a.aggregate_cols, bar=hl.agg.sum(mt['col_idx']) + ) # cannot aggregate cols when grouped by rows + + self.assertRaises( + ExpressionException, a.aggregate_entries, group3=hl.agg.sum(mt['c']) + ) # duplicate column field + self.assertRaises(ExpressionException, a.aggregate_entries, group5=hl.agg.sum(mt['c'])) # duplicate row field + self.assertRaises(ExpressionException, a.aggregate_entries, foo=hl.agg.sum(mt['c'])) # duplicate globals field + + self.assertRaises( + ExpressionException, a.aggregate_rows, group3=hl.agg.sum(mt['row_idx']) + ) # duplicate column field + self.assertRaises( + ExpressionException, a.aggregate_rows, group5=hl.agg.sum(mt['row_idx']) + ) # duplicate row field + self.assertRaises( + ExpressionException, a.aggregate_rows, foo=hl.agg.sum(mt['row_idx']) + ) # duplicate globals field + self.assertRaises( + ExpressionException, a.aggregate_rows, bar=mt['row_idx'] + hl.agg.sum(mt['row_idx']) + ) # expression has to have global indices + self.assertRaises( + ExpressionException, a.aggregate_rows, bar=mt['col_idx'] + hl.agg.sum(mt['row_idx']) + ) # expression has to have global indices + self.assertRaises( + ExpressionException, a.aggregate_rows, bar=hl.agg.sum(mt['c']) + ) # aggregation scope is rows only - entry field + self.assertRaises( + ExpressionException, a.aggregate_rows, bar=hl.agg.sum(mt['col_idx']) + ) # aggregation scope is rows only - column field b = mt.group_cols_by(group5=(mt['group4']['a'] + 1)) - self.assertRaises(NotImplementedError, b.aggregate_rows, bar=hl.agg.sum(mt['row_idx'])) # cannot aggregate rows when grouped by cols - - self.assertRaises(ExpressionException, b.aggregate_entries, 
group1=hl.agg.sum(mt['c'])) # duplicate row field - self.assertRaises(ExpressionException, b.aggregate_entries, group5=hl.agg.sum(mt['c'])) # duplicate column field - self.assertRaises(ExpressionException, b.aggregate_entries, foo=hl.agg.sum(mt['c'])) # duplicate globals field - - self.assertRaises(ExpressionException, b.aggregate_cols, group1=hl.agg.sum(mt['col_idx'])) # duplicate row field - self.assertRaises(ExpressionException, b.aggregate_cols, group5=hl.agg.sum(mt['col_idx'])) # duplicate column field - self.assertRaises(ExpressionException, b.aggregate_cols, foo=hl.agg.sum(mt['col_idx'])) # duplicate globals field - self.assertRaises(ExpressionException, b.aggregate_cols, bar=mt['col_idx'] + hl.agg.sum(mt['col_idx'])) # expression has to have global indices - self.assertRaises(ExpressionException, b.aggregate_cols, bar=mt['row_idx'] + hl.agg.sum(mt['col_idx'])) # expression has to have global indices - self.assertRaises(ExpressionException, b.aggregate_cols, bar=hl.agg.sum(mt['c'])) # aggregation scope is cols only - entry field - self.assertRaises(ExpressionException, b.aggregate_cols, bar=hl.agg.sum(mt['row_idx'])) # aggregation scope is cols only - row field + self.assertRaises( + NotImplementedError, b.aggregate_rows, bar=hl.agg.sum(mt['row_idx']) + ) # cannot aggregate rows when grouped by cols + + self.assertRaises(ExpressionException, b.aggregate_entries, group1=hl.agg.sum(mt['c'])) # duplicate row field + self.assertRaises( + ExpressionException, b.aggregate_entries, group5=hl.agg.sum(mt['c']) + ) # duplicate column field + self.assertRaises(ExpressionException, b.aggregate_entries, foo=hl.agg.sum(mt['c'])) # duplicate globals field + + self.assertRaises( + ExpressionException, b.aggregate_cols, group1=hl.agg.sum(mt['col_idx']) + ) # duplicate row field + self.assertRaises( + ExpressionException, b.aggregate_cols, group5=hl.agg.sum(mt['col_idx']) + ) # duplicate column field + self.assertRaises( + ExpressionException, b.aggregate_cols, foo=hl.agg.sum(mt['col_idx']) + ) # duplicate globals field + self.assertRaises( + ExpressionException, b.aggregate_cols, bar=mt['col_idx'] + hl.agg.sum(mt['col_idx']) + ) # expression has to have global indices + self.assertRaises( + ExpressionException, b.aggregate_cols, bar=mt['row_idx'] + hl.agg.sum(mt['col_idx']) + ) # expression has to have global indices + self.assertRaises( + ExpressionException, b.aggregate_cols, bar=hl.agg.sum(mt['c']) + ) # aggregation scope is cols only - entry field + self.assertRaises( + ExpressionException, b.aggregate_cols, bar=hl.agg.sum(mt['row_idx']) + ) # aggregation scope is cols only - row field c = mt.group_rows_by(group5=(mt['group2']['a'] + 1)).aggregate_rows(x=hl.agg.count()) - self.assertRaises(ExpressionException, c.aggregate_rows, x=hl.agg.count()) # duplicate field + self.assertRaises(ExpressionException, c.aggregate_rows, x=hl.agg.count()) # duplicate field d = mt.group_cols_by(group5=(mt['group4']['a'] + 1)).aggregate_cols(x=hl.agg.count()) - self.assertRaises(ExpressionException, d.aggregate_cols, x=hl.agg.count()) # duplicate field + self.assertRaises(ExpressionException, d.aggregate_cols, x=hl.agg.count()) # duplicate field def test_fields_work_correctly(self): mt = self.get_groupable_matrix() @@ -122,42 +156,54 @@ def test_named_fields_work_correctly(self): def test_joins_work_correctly(self): mt, mt2 = self.get_groupable_matrix2() - col_result = (mt.group_cols_by(group=mt2.cols()[mt.col_idx].col_idx2 < 2) - .aggregate(sum=hl.agg.sum(mt2[mt.row_idx, mt.col_idx].x + mt.glob) + mt.glob - 15) - 
.drop('r1')) + col_result = ( + mt.group_cols_by(group=mt2.cols()[mt.col_idx].col_idx2 < 2) + .aggregate(sum=hl.agg.sum(mt2[mt.row_idx, mt.col_idx].x + mt.glob) + mt.glob - 15) + .drop('r1') + ) col_expected = ( hl.Table.parallelize( - [{'row_idx': 0, 'group': True, 'sum': 1}, - {'row_idx': 0, 'group': False, 'sum': 5}, - {'row_idx': 1, 'group': True, 'sum': 3}, - {'row_idx': 1, 'group': False, 'sum': 7}, - {'row_idx': 2, 'group': True, 'sum': 5}, - {'row_idx': 2, 'group': False, 'sum': 9}, - {'row_idx': 3, 'group': True, 'sum': 7}, - {'row_idx': 3, 'group': False, 'sum': 11}], - hl.tstruct(row_idx=hl.tint32, group=hl.tbool, sum=hl.tint64) - ).annotate_globals(glob=5).key_by('row_idx', 'group') + [ + {'row_idx': 0, 'group': True, 'sum': 1}, + {'row_idx': 0, 'group': False, 'sum': 5}, + {'row_idx': 1, 'group': True, 'sum': 3}, + {'row_idx': 1, 'group': False, 'sum': 7}, + {'row_idx': 2, 'group': True, 'sum': 5}, + {'row_idx': 2, 'group': False, 'sum': 9}, + {'row_idx': 3, 'group': True, 'sum': 7}, + {'row_idx': 3, 'group': False, 'sum': 11}, + ], + hl.tstruct(row_idx=hl.tint32, group=hl.tbool, sum=hl.tint64), + ) + .annotate_globals(glob=5) + .key_by('row_idx', 'group') ) self.assertTrue(col_result.entries()._same(col_expected)) - row_result = (mt.group_rows_by(group=mt2.rows()[mt.row_idx].row_idx2 < 2) - .aggregate(sum=hl.agg.sum(mt2[mt.row_idx, mt.col_idx].x + mt.glob) + mt.glob - 15) - .drop('c1')) + row_result = ( + mt.group_rows_by(group=mt2.rows()[mt.row_idx].row_idx2 < 2) + .aggregate(sum=hl.agg.sum(mt2[mt.row_idx, mt.col_idx].x + mt.glob) + mt.glob - 15) + .drop('c1') + ) row_expected = ( hl.Table.parallelize( - [{'group': True, 'col_idx': 0, 'sum': 1}, - {'group': True, 'col_idx': 1, 'sum': 3}, - {'group': True, 'col_idx': 2, 'sum': 5}, - {'group': True, 'col_idx': 3, 'sum': 7}, - {'group': False, 'col_idx': 0, 'sum': 5}, - {'group': False, 'col_idx': 1, 'sum': 7}, - {'group': False, 'col_idx': 2, 'sum': 9}, - {'group': False, 'col_idx': 3, 'sum': 11}], - hl.tstruct(group=hl.tbool, col_idx=hl.tint32, sum=hl.tint64) - ).annotate_globals(glob=5).key_by('group', 'col_idx') + [ + {'group': True, 'col_idx': 0, 'sum': 1}, + {'group': True, 'col_idx': 1, 'sum': 3}, + {'group': True, 'col_idx': 2, 'sum': 5}, + {'group': True, 'col_idx': 3, 'sum': 7}, + {'group': False, 'col_idx': 0, 'sum': 5}, + {'group': False, 'col_idx': 1, 'sum': 7}, + {'group': False, 'col_idx': 2, 'sum': 9}, + {'group': False, 'col_idx': 3, 'sum': 11}, + ], + hl.tstruct(group=hl.tbool, col_idx=hl.tint32, sum=hl.tint64), + ) + .annotate_globals(glob=5) + .key_by('group', 'col_idx') ) self.assertTrue(row_result.entries()._same(row_expected)) @@ -165,26 +211,41 @@ def test_joins_work_correctly(self): def test_group_rows_by_aggregate(self): mt, mt2 = self.get_groupable_matrix2() - row_result = (mt.group_rows_by(group=mt2.rows()[mt.row_idx].row_idx2 < 2) - .aggregate_rows(collect=hl.agg.collect(mt.row_idx)) - .aggregate_rows(count=hl.agg.count()) - .aggregate_entries(sum=hl.agg.sum(mt2[mt.row_idx, mt.col_idx].x + mt.glob) + mt.glob - 15 - mt.col_idx) # tests fixed indices - .aggregate_entries(x=5) - .result()) + row_result = ( + mt.group_rows_by(group=mt2.rows()[mt.row_idx].row_idx2 < 2) + .aggregate_rows(collect=hl.agg.collect(mt.row_idx)) + .aggregate_rows(count=hl.agg.count()) + .aggregate_entries( + sum=hl.agg.sum(mt2[mt.row_idx, mt.col_idx].x + mt.glob) + mt.glob - 15 - mt.col_idx + ) # tests fixed indices + .aggregate_entries(x=5) + .result() + ) row_expected = ( hl.Table.parallelize( - [{'group': True, 'col_idx': 0, 
'sum': 1, 'collect': [0, 1], 'count': 2, 'c1': 3, 'x': 5}, - {'group': True, 'col_idx': 1, 'sum': 2, 'collect': [0, 1], 'count': 2, 'c1': 3, 'x': 5}, - {'group': True, 'col_idx': 2, 'sum': 3, 'collect': [0, 1], 'count': 2, 'c1': 3, 'x': 5}, - {'group': True, 'col_idx': 3, 'sum': 4, 'collect': [0, 1], 'count': 2, 'c1': 3, 'x': 5}, - {'group': False, 'col_idx': 0, 'sum': 5, 'collect': [2, 3], 'count': 2, 'c1': 3, 'x': 5}, - {'group': False, 'col_idx': 1, 'sum': 6, 'collect': [2, 3], 'count': 2, 'c1': 3, 'x': 5}, - {'group': False, 'col_idx': 2, 'sum': 7, 'collect': [2, 3], 'count': 2, 'c1': 3, 'x': 5}, - {'group': False, 'col_idx': 3, 'sum': 8, 'collect': [2, 3], 'count': 2, 'c1': 3, 'x': 5}], - hl.tstruct(group=hl.tbool, collect=hl.tarray(hl.tint32), count=hl.tint64, - col_idx=hl.tint32, c1=hl.tint32, sum=hl.tint64, x=hl.tint32) - ).annotate_globals(glob=5).key_by('group', 'col_idx') + [ + {'group': True, 'col_idx': 0, 'sum': 1, 'collect': [0, 1], 'count': 2, 'c1': 3, 'x': 5}, + {'group': True, 'col_idx': 1, 'sum': 2, 'collect': [0, 1], 'count': 2, 'c1': 3, 'x': 5}, + {'group': True, 'col_idx': 2, 'sum': 3, 'collect': [0, 1], 'count': 2, 'c1': 3, 'x': 5}, + {'group': True, 'col_idx': 3, 'sum': 4, 'collect': [0, 1], 'count': 2, 'c1': 3, 'x': 5}, + {'group': False, 'col_idx': 0, 'sum': 5, 'collect': [2, 3], 'count': 2, 'c1': 3, 'x': 5}, + {'group': False, 'col_idx': 1, 'sum': 6, 'collect': [2, 3], 'count': 2, 'c1': 3, 'x': 5}, + {'group': False, 'col_idx': 2, 'sum': 7, 'collect': [2, 3], 'count': 2, 'c1': 3, 'x': 5}, + {'group': False, 'col_idx': 3, 'sum': 8, 'collect': [2, 3], 'count': 2, 'c1': 3, 'x': 5}, + ], + hl.tstruct( + group=hl.tbool, + collect=hl.tarray(hl.tint32), + count=hl.tint64, + col_idx=hl.tint32, + c1=hl.tint32, + sum=hl.tint64, + x=hl.tint32, + ), + ) + .annotate_globals(glob=5) + .key_by('group', 'col_idx') ) row_result.entries().show() @@ -194,26 +255,41 @@ def test_group_rows_by_aggregate(self): def test_group_cols_by_aggregate(self): mt, mt2 = self.get_groupable_matrix2() - col_result = (mt.group_cols_by(group=mt2.cols()[mt.col_idx].col_idx2 < 2) - .aggregate_cols(collect=hl.agg.collect(mt.col_idx)) - .aggregate_cols(count=hl.agg.count()) - .aggregate_entries(sum=hl.agg.sum(mt2[mt.row_idx, mt.col_idx].x + mt.glob) + mt.glob - 15 - mt.row_idx) # tests fixed indices - .aggregate_entries(x=5) - .result()) + col_result = ( + mt.group_cols_by(group=mt2.cols()[mt.col_idx].col_idx2 < 2) + .aggregate_cols(collect=hl.agg.collect(mt.col_idx)) + .aggregate_cols(count=hl.agg.count()) + .aggregate_entries( + sum=hl.agg.sum(mt2[mt.row_idx, mt.col_idx].x + mt.glob) + mt.glob - 15 - mt.row_idx + ) # tests fixed indices + .aggregate_entries(x=5) + .result() + ) col_expected = ( hl.Table.parallelize( - [{'group': True, 'row_idx': 0, 'sum': 1, 'collect': [0, 1], 'count': 2, 'r1': 3, 'x': 5}, - {'group': True, 'row_idx': 1, 'sum': 2, 'collect': [0, 1], 'count': 2, 'r1': 3, 'x': 5}, - {'group': True, 'row_idx': 2, 'sum': 3, 'collect': [0, 1], 'count': 2, 'r1': 3, 'x': 5}, - {'group': True, 'row_idx': 3, 'sum': 4, 'collect': [0, 1], 'count': 2, 'r1': 3, 'x': 5}, - {'group': False, 'row_idx': 0, 'sum': 5, 'collect': [2, 3], 'count': 2, 'r1': 3, 'x': 5}, - {'group': False, 'row_idx': 1, 'sum': 6, 'collect': [2, 3], 'count': 2, 'r1': 3, 'x': 5}, - {'group': False, 'row_idx': 2, 'sum': 7, 'collect': [2, 3], 'count': 2, 'r1': 3, 'x': 5}, - {'group': False, 'row_idx': 3, 'sum': 8, 'collect': [2, 3], 'count': 2, 'r1': 3, 'x': 5}], - hl.tstruct(row_idx=hl.tint32, r1=hl.tint32, group=hl.tbool, 
collect=hl.tarray(hl.tint32), - count=hl.tint64, sum=hl.tint64, x=hl.tint32) - ).annotate_globals(glob=5).key_by('row_idx', 'group') + [ + {'group': True, 'row_idx': 0, 'sum': 1, 'collect': [0, 1], 'count': 2, 'r1': 3, 'x': 5}, + {'group': True, 'row_idx': 1, 'sum': 2, 'collect': [0, 1], 'count': 2, 'r1': 3, 'x': 5}, + {'group': True, 'row_idx': 2, 'sum': 3, 'collect': [0, 1], 'count': 2, 'r1': 3, 'x': 5}, + {'group': True, 'row_idx': 3, 'sum': 4, 'collect': [0, 1], 'count': 2, 'r1': 3, 'x': 5}, + {'group': False, 'row_idx': 0, 'sum': 5, 'collect': [2, 3], 'count': 2, 'r1': 3, 'x': 5}, + {'group': False, 'row_idx': 1, 'sum': 6, 'collect': [2, 3], 'count': 2, 'r1': 3, 'x': 5}, + {'group': False, 'row_idx': 2, 'sum': 7, 'collect': [2, 3], 'count': 2, 'r1': 3, 'x': 5}, + {'group': False, 'row_idx': 3, 'sum': 8, 'collect': [2, 3], 'count': 2, 'r1': 3, 'x': 5}, + ], + hl.tstruct( + row_idx=hl.tint32, + r1=hl.tint32, + group=hl.tbool, + collect=hl.tarray(hl.tint32), + count=hl.tint64, + sum=hl.tint64, + x=hl.tint32, + ), + ) + .annotate_globals(glob=5) + .key_by('row_idx', 'group') ) self.assertTrue(col_result.entries()._same(col_expected)) diff --git a/hail/python/test/hail/matrixtable/test_matrix_table.py b/hail/python/test/hail/matrixtable/test_matrix_table.py index 43833bc539e..ba950fd26a0 100644 --- a/hail/python/test/hail/matrixtable/test_matrix_table.py +++ b/hail/python/test/hail/matrixtable/test_matrix_table.py @@ -1,14 +1,27 @@ import math import operator import random +import unittest + import pytest import hail as hl -import hail.ir as ir import hail.expr.aggregators as agg +from hail import ir from hail.utils.java import Env from hail.utils.misc import new_temp_file -from ..helpers import * + +from ..helpers import ( + convert_struct_to_dict, + create_all_values_matrix_table, + fails_local_backend, + fails_service_backend, + get_dataset, + qobtest, + resource, + schema_eq, + test_timeout, +) class Tests(unittest.TestCase): @@ -34,25 +47,21 @@ def test_annotate(self): self.assertEqual(mt.globals.dtype, hl.tstruct(foo=hl.tint32)) - mt = mt.annotate_rows(x1=agg.count(), - x2=agg.fraction(False), - x3=agg.count_where(True), - x4=mt.info.AC + mt.foo) + mt = mt.annotate_rows(x1=agg.count(), x2=agg.fraction(False), x3=agg.count_where(True), x4=mt.info.AC + mt.foo) mt = mt.annotate_cols(apple=6) - mt = mt.annotate_cols(y1=agg.count(), - y2=agg.fraction(False), - y3=agg.count_where(True), - y4=mt.foo + mt.apple) + mt = mt.annotate_cols(y1=agg.count(), y2=agg.fraction(False), y3=agg.count_where(True), y4=mt.foo + mt.apple) - expected_schema = hl.tstruct(s=hl.tstr, apple=hl.tint32, y1=hl.tint64, y2=hl.tfloat64, y3=hl.tint64, - y4=hl.tint32) + expected_schema = hl.tstruct( + s=hl.tstr, apple=hl.tint32, y1=hl.tint64, y2=hl.tfloat64, y3=hl.tint64, y4=hl.tint32 + ) - self.assertTrue(schema_eq(mt.col.dtype, expected_schema), - "expected: " + str(mt.col.dtype) + "\nactual: " + str(expected_schema)) + self.assertTrue( + schema_eq(mt.col.dtype, expected_schema), + "expected: " + str(mt.col.dtype) + "\nactual: " + str(expected_schema), + ) - mt = mt.select_entries(z1=mt.x1 + mt.foo, - z2=mt.x1 + mt.y1 + mt.foo) + mt = mt.select_entries(z1=mt.x1 + mt.foo, z2=mt.x1 + mt.y1 + mt.foo) self.assertTrue(schema_eq(mt.entry.dtype, hl.tstruct(z1=hl.tint64, z2=hl.tint64))) def test_annotate_globals(self): @@ -64,8 +73,11 @@ def test_annotate_globals(self): (float('inf'), hl.tfloat64, lambda x, y: str(x) == str(y)), (float('-inf'), hl.tfloat64, lambda x, y: str(x) == str(y)), (1.111, hl.tfloat64, operator.eq), - 
([hl.Struct(**{'a': None, 'b': 5}), - hl.Struct(**{'a': 'hello', 'b': 10})], hl.tarray(hl.tstruct(a=hl.tstr, b=hl.tint)), operator.eq) + ( + [hl.Struct(**{'a': None, 'b': 5}), hl.Struct(**{'a': 'hello', 'b': 10})], + hl.tarray(hl.tstruct(a=hl.tstr, b=hl.tint)), + operator.eq, + ), ] for x, t, f in data: @@ -161,8 +173,8 @@ def expected(n, m): def test_tail_scan(self): mt = hl.utils.range_matrix_table(30, 40) - mt = mt.annotate_rows(i = hl.scan.count()) - mt = mt.annotate_cols(j = hl.scan.count()) + mt = mt.annotate_rows(i=hl.scan.count()) + mt = mt.annotate_cols(j=hl.scan.count()) mt = mt.tail(10, 11) ht = mt.entries() assert ht.aggregate(agg.collect_as_set(hl.tuple([ht.i, ht.j]))) == set( @@ -192,8 +204,7 @@ def test_aggregate_rows(self): qv = mt.aggregate_rows(agg.count()) self.assertEqual(qv, 346) - mt.aggregate_rows(hl.Struct(x=agg.collect(mt.locus.contig), - y=agg.collect(mt.x1))) + mt.aggregate_rows(hl.Struct(x=agg.collect(mt.locus.contig), y=agg.collect(mt.x1))) def test_aggregate_cols(self): mt = self.get_mt() @@ -208,8 +219,7 @@ def test_aggregate_cols(self): qs = hl.eval(mt.aggregate_cols(agg.count(), _localize=False)) self.assertEqual(qs, 100) - mt.aggregate_cols(hl.Struct(x=agg.collect(mt.s), - y=agg.collect(mt.y1))) + mt.aggregate_cols(hl.Struct(x=agg.collect(mt.s), y=agg.collect(mt.y1))) def test_aggregate_cols_order(self): path = new_temp_file(extension='mt') @@ -231,13 +241,14 @@ def test_aggregate_entries(self): qg = mt.aggregate_entries(agg.count()) self.assertEqual(qg, 34600) - mt.aggregate_entries(hl.Struct(x=agg.filter(False, agg.collect(mt.y1)), - y=agg.filter(hl.rand_bool(0.1), agg.collect(mt.GT)))) + mt.aggregate_entries( + hl.Struct(x=agg.filter(False, agg.collect(mt.y1)), y=agg.filter(hl.rand_bool(0.1), agg.collect(mt.GT))) + ) self.assertIsNotNone(mt.aggregate_entries(hl.agg.take(mt.s, 1)[0])) def test_aggregate_rows_array_agg(self): mt = hl.utils.range_matrix_table(10, 10) - mt = mt.annotate_rows(maf_flag = hl.empty_array('bool')) + mt = mt.annotate_rows(maf_flag=hl.empty_array('bool')) mt.aggregate_rows(hl.agg.array_agg(lambda x: hl.agg.counter(x), mt.maf_flag)) def test_aggregate_rows_bn_counter(self): @@ -246,7 +257,7 @@ def test_aggregate_rows_bn_counter(self): def test_col_agg_no_rows(self): mt = hl.utils.range_matrix_table(3, 3).filter_rows(False) - mt = mt.annotate_cols(x = hl.agg.count()) + mt = mt.annotate_cols(x=hl.agg.count()) assert mt.x.collect() == [0, 0, 0] def test_col_collect(self): @@ -254,19 +265,20 @@ def test_col_collect(self): mt.cols().collect() def test_aggregate_ir(self): - ds = (hl.utils.range_matrix_table(5, 5) - .annotate_globals(g1=5) - .annotate_entries(e1=3)) + ds = hl.utils.range_matrix_table(5, 5).annotate_globals(g1=5).annotate_entries(e1=3) - x = [("col_idx", lambda e: ds.aggregate_cols(e)), - ("row_idx", lambda e: ds.aggregate_rows(e))] + x = [("col_idx", lambda e: ds.aggregate_cols(e)), ("row_idx", lambda e: ds.aggregate_rows(e))] for name, f in x: - r = f(hl.struct(x=agg.sum(ds[name]) + ds.g1, - y=agg.filter(ds[name] % 2 != 0, agg.sum(ds[name] + 2)) + ds.g1, - z=agg.sum(ds.g1 + ds[name]) + ds.g1, - mean=agg.mean(ds[name]))) - self.assertEqual(convert_struct_to_dict(r), {u'x': 15, u'y': 13, u'z': 40, u'mean': 2.0}) + r = f( + hl.struct( + x=agg.sum(ds[name]) + ds.g1, + y=agg.filter(ds[name] % 2 != 0, agg.sum(ds[name] + 2)) + ds.g1, + z=agg.sum(ds.g1 + ds[name]) + ds.g1, + mean=agg.mean(ds[name]), + ) + ) + self.assertEqual(convert_struct_to_dict(r), {'x': 15, 'y': 13, 'z': 40, 'mean': 2.0}) r = f(5) self.assertEqual(r, 5) @@ 
-277,8 +289,10 @@ def test_aggregate_ir(self): r = f(agg.filter(ds[name] % 2 != 0, agg.sum(ds[name] + 2)) + ds.g1) self.assertEqual(r, 13) - r = ds.aggregate_entries(agg.filter((ds.row_idx % 2 != 0) & (ds.col_idx % 2 != 0), - agg.sum(ds.e1 + ds.g1 + ds.row_idx + ds.col_idx)) + ds.g1) + r = ds.aggregate_entries( + agg.filter((ds.row_idx % 2 != 0) & (ds.col_idx % 2 != 0), agg.sum(ds.e1 + ds.g1 + ds.row_idx + ds.col_idx)) + + ds.g1 + ) self.assertTrue(r, 48) def test_select_entries(self): @@ -288,15 +302,17 @@ def test_select_entries(self): mt = mt.annotate_entries(bc=mt.b * 10 + mt.c) mt_entries = mt.entries() - assert (mt_entries.all(mt_entries.bc == mt_entries.foo)) + assert mt_entries.all(mt_entries.bc == mt_entries.foo) def test_select_cols(self): mt = hl.utils.range_matrix_table(3, 5, n_partitions=4) mt = mt.annotate_entries(e=mt.col_idx * mt.row_idx) mt = mt.annotate_globals(g=1) - mt = mt.annotate_cols(sum=agg.sum(mt.e + mt.col_idx + mt.row_idx + mt.g) + mt.col_idx + mt.g, - count=agg.count_where(mt.e % 2 == 0), - foo=agg.count()) + mt = mt.annotate_cols( + sum=agg.sum(mt.e + mt.col_idx + mt.row_idx + mt.g) + mt.col_idx + mt.g, + count=agg.count_where(mt.e % 2 == 0), + foo=agg.count(), + ) result = convert_struct_to_dict(mt.cols().collect()[-2]) self.assertEqual(result, {'col_idx': 3, 'sum': 28, 'count': 2, 'foo': 3}) @@ -350,36 +366,38 @@ def test_explode_key_errors(self): def test_group_by_field_lifetimes(self): mt = hl.utils.range_matrix_table(3, 3) - mt2 = (mt.group_rows_by(row_idx='100') - .aggregate(x=hl.agg.collect_as_set(mt.row_idx + 5))) + mt2 = mt.group_rows_by(row_idx='100').aggregate(x=hl.agg.collect_as_set(mt.row_idx + 5)) assert mt2.aggregate_entries(hl.agg.all(mt2.x == hl.set({5, 6, 7}))) - mt3 = (mt.group_cols_by(col_idx='100') - .aggregate(x=hl.agg.collect_as_set(mt.col_idx + 5))) + mt3 = mt.group_cols_by(col_idx='100').aggregate(x=hl.agg.collect_as_set(mt.col_idx + 5)) assert mt3.aggregate_entries(hl.agg.all(mt3.x == hl.set({5, 6, 7}))) def test_aggregate_cols_by(self): mt = hl.utils.range_matrix_table(2, 4) - mt = (mt.annotate_cols(group=mt.col_idx < 2) - .annotate_globals(glob=5)) + mt = mt.annotate_cols(group=mt.col_idx < 2).annotate_globals(glob=5) grouped = mt.group_cols_by(mt.group) result = grouped.aggregate(sum=hl.agg.sum(mt.row_idx * 2 + mt.col_idx + mt.glob) + 3) - expected = (hl.Table.parallelize([ - {'row_idx': 0, 'group': True, 'sum': 14}, - {'row_idx': 0, 'group': False, 'sum': 18}, - {'row_idx': 1, 'group': True, 'sum': 18}, - {'row_idx': 1, 'group': False, 'sum': 22} - ], hl.tstruct(row_idx=hl.tint, group=hl.tbool, sum=hl.tint64)) - .annotate_globals(glob=5) - .key_by('row_idx', 'group')) + expected = ( + hl.Table.parallelize( + [ + {'row_idx': 0, 'group': True, 'sum': 14}, + {'row_idx': 0, 'group': False, 'sum': 18}, + {'row_idx': 1, 'group': True, 'sum': 18}, + {'row_idx': 1, 'group': False, 'sum': 22}, + ], + hl.tstruct(row_idx=hl.tint, group=hl.tbool, sum=hl.tint64), + ) + .annotate_globals(glob=5) + .key_by('row_idx', 'group') + ) self.assertTrue(result.entries()._same(expected)) def test_aggregate_cols_by_init_op(self): mt = hl.import_vcf(resource('sample.vcf')) - cs = mt.group_cols_by(mt.s).aggregate(cs = hl.agg.call_stats(mt.GT, mt.alleles)) - cs._force_count_rows() # should run without error + cs = mt.group_cols_by(mt.s).aggregate(cs=hl.agg.call_stats(mt.GT, mt.alleles)) + cs._force_count_rows() # should run without error def test_aggregate_cols_scope_violation(self): mt = get_dataset() @@ -389,19 +407,23 @@ def 
test_aggregate_cols_scope_violation(self): def test_aggregate_rows_by(self): mt = hl.utils.range_matrix_table(4, 2) - mt = (mt.annotate_rows(group=mt.row_idx < 2) - .annotate_globals(glob=5)) + mt = mt.annotate_rows(group=mt.row_idx < 2).annotate_globals(glob=5) grouped = mt.group_rows_by(mt.group) result = grouped.aggregate(sum=hl.agg.sum(mt.col_idx * 2 + mt.row_idx + mt.glob) + 3) - expected = (hl.Table.parallelize([ - {'col_idx': 0, 'group': True, 'sum': 14}, - {'col_idx': 1, 'group': True, 'sum': 18}, - {'col_idx': 0, 'group': False, 'sum': 18}, - {'col_idx': 1, 'group': False, 'sum': 22} - ], hl.tstruct(group=hl.tbool, col_idx=hl.tint, sum=hl.tint64)) - .annotate_globals(glob=5) - .key_by('group', 'col_idx')) + expected = ( + hl.Table.parallelize( + [ + {'col_idx': 0, 'group': True, 'sum': 14}, + {'col_idx': 1, 'group': True, 'sum': 18}, + {'col_idx': 0, 'group': False, 'sum': 18}, + {'col_idx': 1, 'group': False, 'sum': 22}, + ], + hl.tstruct(group=hl.tbool, col_idx=hl.tint, sum=hl.tint64), + ) + .annotate_globals(glob=5) + .key_by('group', 'col_idx') + ) self.assertTrue(result.entries()._same(expected)) @@ -409,34 +431,35 @@ def test_aggregate_rows_by(self): def test_collect_cols_by_key(self): mt = hl.utils.range_matrix_table(3, 3) col_dict = hl.literal({0: [1], 1: [2, 3], 2: [4, 5, 6]}) - mt = mt.annotate_cols(foo=col_dict.get(mt.col_idx)) \ - .explode_cols('foo') + mt = mt.annotate_cols(foo=col_dict.get(mt.col_idx)).explode_cols('foo') mt = mt.annotate_entries(bar=mt.row_idx * mt.foo) grouped = mt.collect_cols_by_key() - self.assertListEqual(grouped.cols().order_by('col_idx').collect(), - [hl.Struct(col_idx=0, foo=[1]), - hl.Struct(col_idx=1, foo=[2, 3]), - hl.Struct(col_idx=2, foo=[4, 5, 6])]) self.assertListEqual( - grouped.entries().select('bar') - .order_by('row_idx', 'col_idx').collect(), - [hl.Struct(row_idx=0, col_idx=0, bar=[0]), - hl.Struct(row_idx=0, col_idx=1, bar=[0, 0]), - hl.Struct(row_idx=0, col_idx=2, bar=[0, 0, 0]), - hl.Struct(row_idx=1, col_idx=0, bar=[1]), - hl.Struct(row_idx=1, col_idx=1, bar=[2, 3]), - hl.Struct(row_idx=1, col_idx=2, bar=[4, 5, 6]), - hl.Struct(row_idx=2, col_idx=0, bar=[2]), - hl.Struct(row_idx=2, col_idx=1, bar=[4, 6]), - hl.Struct(row_idx=2, col_idx=2, bar=[8, 10, 12])]) + grouped.cols().order_by('col_idx').collect(), + [hl.Struct(col_idx=0, foo=[1]), hl.Struct(col_idx=1, foo=[2, 3]), hl.Struct(col_idx=2, foo=[4, 5, 6])], + ) + self.assertListEqual( + grouped.entries().select('bar').order_by('row_idx', 'col_idx').collect(), + [ + hl.Struct(row_idx=0, col_idx=0, bar=[0]), + hl.Struct(row_idx=0, col_idx=1, bar=[0, 0]), + hl.Struct(row_idx=0, col_idx=2, bar=[0, 0, 0]), + hl.Struct(row_idx=1, col_idx=0, bar=[1]), + hl.Struct(row_idx=1, col_idx=1, bar=[2, 3]), + hl.Struct(row_idx=1, col_idx=2, bar=[4, 5, 6]), + hl.Struct(row_idx=2, col_idx=0, bar=[2]), + hl.Struct(row_idx=2, col_idx=1, bar=[4, 6]), + hl.Struct(row_idx=2, col_idx=2, bar=[8, 10, 12]), + ], + ) def test_collect_cols_by_key_with_rand(self): mt = hl.utils.range_matrix_table(3, 3) - mt = mt.annotate_cols(x = hl.rand_norm()) + mt = mt.annotate_cols(x=hl.rand_norm()) mt = mt.collect_cols_by_key() - mt = mt.annotate_cols(x = hl.rand_norm()) + mt = mt.annotate_cols(x=hl.rand_norm()) mt.cols().collect() def test_weird_names(self): @@ -469,8 +492,8 @@ def test_weird_names(self): def test_semi_anti_join_rows(self): mt = hl.utils.range_matrix_table(10, 3) ht = hl.utils.range_table(3) - mt2 = mt.key_rows_by(k1 = mt.row_idx, k2 = hl.str(mt.row_idx * 2)) - ht2 = ht.key_by(k1 = ht.idx, k2 = 
hl.str(ht.idx * 2)) + mt2 = mt.key_rows_by(k1=mt.row_idx, k2=hl.str(mt.row_idx * 2)) + ht2 = ht.key_by(k1=ht.idx, k2=hl.str(ht.idx * 2)) assert mt.semi_join_rows(ht).count() == (3, 3) assert mt.anti_join_rows(ht).count() == (7, 3) @@ -492,8 +515,8 @@ def test_semi_anti_join_rows(self): def test_semi_anti_join_cols(self): mt = hl.utils.range_matrix_table(3, 10) ht = hl.utils.range_table(3) - mt2 = mt.key_cols_by(k1 = mt.col_idx, k2 = hl.str(mt.col_idx * 2)) - ht2 = ht.key_by(k1 = ht.idx, k2 = hl.str(ht.idx * 2)) + mt2 = mt.key_cols_by(k1=mt.col_idx, k2=hl.str(mt.col_idx * 2)) + ht2 = ht.key_by(k1=ht.idx, k2=hl.str(ht.idx * 2)) assert mt.semi_join_cols(ht).count() == (3, 3) assert mt.anti_join_cols(ht).count() == (3, 7) @@ -553,8 +576,7 @@ def test_index_keyless(self): def test_table_join(self): ds = self.get_mt() # test different row schemas - self.assertTrue(ds.union_cols(ds.drop(ds.info)) - .count_rows(), 346) + self.assertTrue(ds.union_cols(ds.drop(ds.info)).count_rows(), 346) def test_table_product_join(self): left = hl.utils.range_matrix_table(5, 1) @@ -577,26 +599,32 @@ def test_coalesce_with_no_rows(self): def test_literals_rebuild(self): mt = hl.utils.range_matrix_table(1, 1) - mt = mt.annotate_rows(x=hl.if_else(hl.literal([1,2,3])[mt.row_idx] < hl.rand_unif(10, 11), mt.globals, hl.struct())) + mt = mt.annotate_rows( + x=hl.if_else(hl.literal([1, 2, 3])[mt.row_idx] < hl.rand_unif(10, 11), mt.globals, hl.struct()) + ) mt._force_count_rows() def test_globals_lowering(self): mt = hl.utils.range_matrix_table(1, 1).annotate_globals(x=1) - lit = hl.literal(hl.utils.Struct(x = 0)) + lit = hl.literal(hl.utils.Struct(x=0)) mt.annotate_rows(foo=hl.agg.collect(mt.globals == lit))._force_count_rows() mt.annotate_cols(foo=hl.agg.collect(mt.globals == lit))._force_count_rows() mt.filter_rows(mt.globals == lit)._force_count_rows() mt.filter_cols(mt.globals == lit)._force_count_rows() mt.filter_entries(mt.globals == lit)._force_count_rows() - (mt.group_rows_by(mt.row_idx) - .aggregate_rows(foo=hl.agg.collect(mt.globals == lit)) - .aggregate(bar=hl.agg.collect(mt.globals == lit)) - ._force_count_rows()) - (mt.group_cols_by(mt.col_idx) - .aggregate_cols(foo=hl.agg.collect(mt.globals == lit)) - .aggregate(bar=hl.agg.collect(mt.globals == lit)) - ._force_count_rows()) + ( + mt.group_rows_by(mt.row_idx) + .aggregate_rows(foo=hl.agg.collect(mt.globals == lit)) + .aggregate(bar=hl.agg.collect(mt.globals == lit)) + ._force_count_rows() + ) + ( + mt.group_cols_by(mt.col_idx) + .aggregate_cols(foo=hl.agg.collect(mt.globals == lit)) + .aggregate(bar=hl.agg.collect(mt.globals == lit)) + ._force_count_rows() + ) def test_unions_1(self): dataset = hl.import_vcf(resource('sample2.vcf')) @@ -630,14 +658,14 @@ def test_union_cols_example(self): def test_union_cols_distinct(self): mt = hl.utils.range_matrix_table(10, 10) - mt = mt.key_rows_by(x = mt.row_idx // 2) + mt = mt.key_rows_by(x=mt.row_idx // 2) assert mt.union_cols(mt).count_rows() == 5 def test_union_cols_no_error_on_duplicate_names(self): mt = hl.utils.range_matrix_table(10, 10) - mt = mt.annotate_rows(both = 'hi') - mt2 = mt.annotate_rows(both = 3, right_only = 'abc') - mt = mt.annotate_rows(left_only = '123') + mt = mt.annotate_rows(both='hi') + mt2 = mt.annotate_rows(both=3, right_only='abc') + mt = mt.annotate_rows(left_only='123') mt = mt.union_cols(mt2, drop_right_row_fields=False) assert 'both' in mt.row_value assert 'left_only' in mt.row_value @@ -646,22 +674,25 @@ def test_union_cols_no_error_on_duplicate_names(self): def 
test_union_cols_outer(self): r, c = 10, 10 - mt = hl.utils.range_matrix_table(2*r, c) + mt = hl.utils.range_matrix_table(2 * r, c) mt = mt.annotate_entries(entry=hl.tuple([mt.row_idx, mt.col_idx])) mt = mt.annotate_rows(left=mt.row_idx) - mt2 = hl.utils.range_matrix_table(2*r, c) + mt2 = hl.utils.range_matrix_table(2 * r, c) mt2 = mt2.key_rows_by(row_idx=mt2.row_idx + r) mt2 = mt2.key_cols_by(col_idx=mt2.col_idx + c) mt2 = mt2.annotate_entries(entry=hl.tuple([mt2.row_idx, mt2.col_idx])) mt2 = mt2.annotate_rows(right=mt2.row_idx) - expected = hl.utils.range_matrix_table(3*r, 2*c) + expected = hl.utils.range_matrix_table(3 * r, 2 * c) missing = hl.missing(hl.ttuple(hl.tint, hl.tint)) - expected = expected.annotate_entries(entry=hl.if_else( - expected.col_idx < c, - hl.if_else(expected.row_idx < 2*r, hl.tuple([expected.row_idx, expected.col_idx]), missing), - hl.if_else(expected.row_idx >= r, hl.tuple([expected.row_idx, expected.col_idx]), missing))) + expected = expected.annotate_entries( + entry=hl.if_else( + expected.col_idx < c, + hl.if_else(expected.row_idx < 2 * r, hl.tuple([expected.row_idx, expected.col_idx]), missing), + hl.if_else(expected.row_idx >= r, hl.tuple([expected.row_idx, expected.col_idx]), missing), + ) + ) expected = expected.annotate_rows( - left=hl.if_else(expected.row_idx < 2*r, expected.row_idx, hl.missing(hl.tint)), + left=hl.if_else(expected.row_idx < 2 * r, expected.row_idx, hl.missing(hl.tint)), right=hl.if_else(expected.row_idx >= r, expected.row_idx, hl.missing(hl.tint)), ) assert mt.union_cols(mt2, row_join_type='outer', drop_right_row_fields=False)._same(expected) @@ -691,11 +722,9 @@ def test_choose_cols(self): random.shuffle(indices) old_order = ds.key_cols_by()['s'].collect() - self.assertEqual(ds.choose_cols(indices).key_cols_by()['s'].collect(), - [old_order[i] for i in indices]) + self.assertEqual(ds.choose_cols(indices).key_cols_by()['s'].collect(), [old_order[i] for i in indices]) - self.assertEqual(ds.choose_cols(list(range(10))).s.collect(), - old_order[:10]) + self.assertEqual(ds.choose_cols(list(range(10))).s.collect(), old_order[:10]) def test_choose_cols_vs_explode(self): ds = self.get_mt() @@ -726,48 +755,46 @@ def test_aggregation_with_no_aggregators(self): def test_computed_key_join_1(self): ds = self.get_mt() kt = hl.Table.parallelize( - [{'key': 0, 'value': True}, - {'key': 1, 'value': False}], + [{'key': 0, 'value': True}, {'key': 1, 'value': False}], hl.tstruct(key=hl.tint32, value=hl.tbool), - key=['key']) + key=['key'], + ) ds = ds.annotate_rows(key=ds.locus.position % 2) ds = ds.annotate_rows(value=kt[ds['key']]['value']) rt = ds.rows() - self.assertTrue( - rt.all(((rt.locus.position % 2) == 0) == rt['value'])) + self.assertTrue(rt.all(((rt.locus.position % 2) == 0) == rt['value'])) def test_computed_key_join_multiple_keys(self): ds = self.get_mt() kt = hl.Table.parallelize( - [{'key1': 0, 'key2': 0, 'value': 0}, - {'key1': 1, 'key2': 0, 'value': 1}, - {'key1': 0, 'key2': 1, 'value': -2}, - {'key1': 1, 'key2': 1, 'value': -1}], + [ + {'key1': 0, 'key2': 0, 'value': 0}, + {'key1': 1, 'key2': 0, 'value': 1}, + {'key1': 0, 'key2': 1, 'value': -2}, + {'key1': 1, 'key2': 1, 'value': -1}, + ], hl.tstruct(key1=hl.tint32, key2=hl.tint32, value=hl.tint32), - key=['key1', 'key2']) + key=['key1', 'key2'], + ) ds = ds.annotate_rows(key1=ds.locus.position % 2, key2=ds.info.DP % 2) ds = ds.annotate_rows(value=kt[ds.key1, ds.key2]['value']) rt = ds.rows() - self.assertTrue( - rt.all((rt.locus.position % 2) - 2 * (rt.info.DP % 2) == rt['value'])) + 
self.assertTrue(rt.all((rt.locus.position % 2) - 2 * (rt.info.DP % 2) == rt['value'])) def test_computed_key_join_duplicate_row_keys(self): ds = self.get_mt() kt = hl.Table.parallelize( [{'culprit': 'InbreedingCoeff', 'foo': 'bar', 'value': 'IB'}], hl.tstruct(culprit=hl.tstr, foo=hl.tstr, value=hl.tstr), - key=['culprit', 'foo']) - ds = ds.annotate_rows( - dsfoo='bar', - info=ds.info.annotate(culprit=[ds.info.culprit, "foo"])) + key=['culprit', 'foo'], + ) + ds = ds.annotate_rows(dsfoo='bar', info=ds.info.annotate(culprit=[ds.info.culprit, "foo"])) ds = ds.explode_rows(ds.info.culprit) ds = ds.annotate_rows(value=kt[ds.info.culprit, ds.dsfoo]['value']) rt = ds.rows() self.assertTrue( - rt.all(hl.if_else( - rt.info.culprit == "InbreedingCoeff", - rt['value'] == "IB", - hl.is_missing(rt['value'])))) + rt.all(hl.if_else(rt.info.culprit == "InbreedingCoeff", rt['value'] == "IB", hl.is_missing(rt['value']))) + ) def test_interval_join(self): left = hl.utils.range_matrix_table(50, 1, n_partitions=10) @@ -775,23 +802,31 @@ def test_interval_join(self): intervals = intervals.key_by(interval=hl.interval(intervals.idx * 10, intervals.idx * 10 + 5)) left = left.annotate_rows(interval_matches=intervals.index(left.row_key)) rows = left.rows() - self.assertTrue(rows.all(hl.case() - .when(rows.row_idx % 10 < 5, rows.interval_matches.idx == rows.row_idx // 10) - .default(hl.is_missing(rows.interval_matches)))) + self.assertTrue( + rows.all( + hl.case() + .when(rows.row_idx % 10 < 5, rows.interval_matches.idx == rows.row_idx // 10) + .default(hl.is_missing(rows.interval_matches)) + ) + ) - @fails_service_backend() - @fails_local_backend() def test_interval_product_join(self): left = hl.utils.range_matrix_table(50, 1, n_partitions=8) intervals = hl.utils.range_table(25) - intervals = intervals.key_by(interval=hl.interval( - 1 + (intervals.idx // 5) * 10 + (intervals.idx % 5), - (1 + intervals.idx // 5) * 10 - (intervals.idx % 5))) + intervals = intervals.key_by( + interval=hl.interval( + 1 + (intervals.idx // 5) * 10 + (intervals.idx % 5), (1 + intervals.idx // 5) * 10 - (intervals.idx % 5) + ) + ) intervals = intervals.annotate(i=intervals.idx % 5) left = left.annotate_rows(interval_matches=intervals.index(left.row_key, all_matches=True)) rows = left.rows() - self.assertTrue(rows.all(hl.sorted(rows.interval_matches.map(lambda x: x.i)) - == hl.range(0, hl.min(rows.row_idx % 10, 10 - rows.row_idx % 10)))) + self.assertTrue( + rows.all( + hl.sorted(rows.interval_matches.map(lambda x: x.i)) + == hl.range(0, hl.min(rows.row_idx % 10, 10 - rows.row_idx % 10)) + ) + ) def test_entry_join_self(self): mt1 = hl.utils.range_matrix_table(10, 10, n_partitions=4).choose_cols([9, 8, 7, 6, 5, 4, 3, 2, 1, 0]) @@ -839,7 +874,7 @@ def test_entries_table_length_and_fields(self): def test_entries_table_no_keys(self): mt = hl.utils.range_matrix_table(2, 2) - mt = mt.annotate_entries(x = (mt.row_idx, mt.col_idx)) + mt = mt.annotate_entries(x=(mt.row_idx, mt.col_idx)) original_order = [ hl.utils.Struct(row_idx=0, col_idx=0, x=(0, 0)), @@ -875,8 +910,7 @@ def test_filter_cols_agg(self): def test_vcf_regression(self): ds = hl.import_vcf(resource('33alleles.vcf')) - self.assertEqual( - ds.filter_rows(ds.alleles.length() == 2).count_rows(), 0) + self.assertEqual(ds.filter_rows(ds.alleles.length() == 2).count_rows(), 0) def test_field_groups(self): ds = self.get_mt() @@ -889,10 +923,9 @@ def test_field_groups(self): self.assertTrue(df.all((df.col_idx == df.col_struct.col_idx))) df = ds.annotate_entries(entry_struct=ds.entry).entries() 
- self.assertTrue(df.all( - ((hl.is_missing(df.GT) | - (df.GT == df.entry_struct.GT)) & - (df.AD == df.entry_struct.AD)))) + self.assertTrue( + df.all(((hl.is_missing(df.GT) | (df.GT == df.entry_struct.GT)) & (df.AD == df.entry_struct.AD))) + ) @test_timeout(batch=5 * 60) def test_filter_partitions(self): @@ -902,9 +935,12 @@ def test_filter_partitions(self): self.assertEqual(ds._filter_partitions(range(3)).n_partitions(), 3) self.assertEqual(ds._filter_partitions([4, 5, 7], keep=False).n_partitions(), 5) self.assertTrue( - ds._same(hl.MatrixTable.union_rows( - ds._filter_partitions([0, 3, 7]), - ds._filter_partitions([0, 3, 7], keep=False)))) + ds._same( + hl.MatrixTable.union_rows( + ds._filter_partitions([0, 3, 7]), ds._filter_partitions([0, 3, 7], keep=False) + ) + ) + ) def test_from_rows_table(self): mt = hl.import_vcf(resource('sample.vcf')) @@ -947,17 +983,24 @@ def test_indexed_read(self): f = new_temp_file(extension='mt') mt.write(f) mt1 = hl.read_matrix_table(f) - mt2 = hl.read_matrix_table(f, _intervals=[ - hl.Interval(start=150, end=250, includes_start=True, includes_end=False), - hl.Interval(start=250, end=500, includes_start=True, includes_end=False), - ]) + mt2 = hl.read_matrix_table( + f, + _intervals=[ + hl.Interval(start=150, end=250, includes_start=True, includes_end=False), + hl.Interval(start=250, end=500, includes_start=True, includes_end=False), + ], + ) self.assertEqual(mt2.n_partitions(), 2) self.assertTrue(mt1.filter_rows((mt1.row_idx >= 150) & (mt1.row_idx < 500))._same(mt2)) - mt2 = hl.read_matrix_table(f, _intervals=[ - hl.Interval(start=150, end=250, includes_start=True, includes_end=False), - hl.Interval(start=250, end=500, includes_start=True, includes_end=False), - ], _filter_intervals=True) + mt2 = hl.read_matrix_table( + f, + _intervals=[ + hl.Interval(start=150, end=250, includes_start=True, includes_end=False), + hl.Interval(start=250, end=500, includes_start=True, includes_end=False), + ], + _filter_intervals=True, + ) self.assertEqual(mt2.n_partitions(), 3) self.assertTrue(mt1.filter_rows((mt1.row_idx >= 150) & (mt1.row_idx < 500))._same(mt2)) @@ -965,11 +1008,19 @@ def test_indexed_read_vcf(self): vcf = self.get_mt(10) f = new_temp_file(extension='mt') vcf.write(f) - l1, l2, l3, l4 = hl.Locus('20', 10000000), hl.Locus('20', 11000000), hl.Locus('20', 13000000), hl.Locus('20', 14000000) - mt = hl.read_matrix_table(f, _intervals=[ - hl.Interval(start=l1, end=l2), - hl.Interval(start=l3, end=l4), - ]) + l1, l2, l3, l4 = ( + hl.Locus('20', 10000000), + hl.Locus('20', 11000000), + hl.Locus('20', 13000000), + hl.Locus('20', 14000000), + ) + mt = hl.read_matrix_table( + f, + _intervals=[ + hl.Interval(start=l1, end=l2), + hl.Interval(start=l3, end=l4), + ], + ) self.assertEqual(mt.n_partitions(), 2) p = (vcf.locus >= l1) & (vcf.locus < l2) q = (vcf.locus >= l3) & (vcf.locus < l4) @@ -983,30 +1034,29 @@ def test_interval_filter_partitions(self): hl.Interval(hl.Struct(idx=5), hl.Struct(idx=10)), hl.Interval(hl.Struct(idx=12), hl.Struct(idx=13)), hl.Interval(hl.Struct(idx=15), hl.Struct(idx=17)), - hl.Interval(hl.Struct(idx=19), hl.Struct(idx=20)) + hl.Interval(hl.Struct(idx=19), hl.Struct(idx=20)), ] - assert hl.read_matrix_table(path, _intervals=intervals, _filter_intervals = True).n_partitions() == 1 + assert hl.read_matrix_table(path, _intervals=intervals, _filter_intervals=True).n_partitions() == 1 intervals = [ hl.Interval(hl.Struct(idx=5), hl.Struct(idx=10)), hl.Interval(hl.Struct(idx=12), hl.Struct(idx=13)), hl.Interval(hl.Struct(idx=15), 
hl.Struct(idx=17)), - hl.Interval(hl.Struct(idx=45), hl.Struct(idx=50)), hl.Interval(hl.Struct(idx=52), hl.Struct(idx=53)), hl.Interval(hl.Struct(idx=55), hl.Struct(idx=57)), - hl.Interval(hl.Struct(idx=75), hl.Struct(idx=80)), hl.Interval(hl.Struct(idx=82), hl.Struct(idx=83)), hl.Interval(hl.Struct(idx=85), hl.Struct(idx=87)), ] - assert hl.read_matrix_table(path, _intervals=intervals, _filter_intervals = True).n_partitions() == 3 + assert hl.read_matrix_table(path, _intervals=intervals, _filter_intervals=True).n_partitions() == 3 @fails_service_backend() @test_timeout(3 * 60, local=6 * 60) def test_codecs_matrix(self): from hail.utils.java import scala_object + supported_codecs = scala_object(Env.hail().io, 'BufferSpec').specs() ds = self.get_mt() temp = new_temp_file(extension='mt') @@ -1019,6 +1069,7 @@ def test_codecs_matrix(self): @test_timeout(local=6 * 60) def test_codecs_table(self): from hail.utils.java import scala_object + supported_codecs = scala_object(Env.hail().io, 'BufferSpec').specs() rt = self.get_mt().rows() temp = new_temp_file(extension='ht') @@ -1121,25 +1172,19 @@ def test_to_table_on_row_fields(self): def test_to_table_on_col_and_col_key(self): mt, _, sample_ids, _, _, _ = self.get_example_mt_for_to_table_on_various_fields() - self.assertEqual(mt.col_key.collect(), - [hl.Struct(s=s) for s in sample_ids]) - self.assertEqual(mt.col.collect(), - [hl.Struct(s=s, col_idx=i) for i, s in enumerate(sample_ids)]) + self.assertEqual(mt.col_key.collect(), [hl.Struct(s=s) for s in sample_ids]) + self.assertEqual(mt.col.collect(), [hl.Struct(s=s, col_idx=i) for i, s in enumerate(sample_ids)]) def test_to_table_on_row_and_row_key(self): mt, _, _, _, rows, sorted_rows = self.get_example_mt_for_to_table_on_various_fields() - self.assertEqual(mt.row_key.collect(), - [hl.Struct(r=r) for r in sorted_rows]) - self.assertEqual(mt.row.collect(), - sorted([hl.Struct(r=r, row_idx=i) for i, r in enumerate(rows)], - key=lambda x: x.r)) + self.assertEqual(mt.row_key.collect(), [hl.Struct(r=r) for r in sorted_rows]) + self.assertEqual( + mt.row.collect(), sorted([hl.Struct(r=r, row_idx=i) for i, r in enumerate(rows)], key=lambda x: x.r) + ) def test_to_table_on_entry(self): mt, _, _, entries, _, sorted_rows = self.get_example_mt_for_to_table_on_various_fields() - self.assertEqual(mt.entry.collect(), - [hl.Struct(e=e) - for _ in sorted_rows - for e in entries]) + self.assertEqual(mt.entry.collect(), [hl.Struct(e=e) for _ in sorted_rows for e in entries]) def test_to_table_on_cols_method(self): mt, _, sample_ids, _, _, _ = self.get_example_mt_for_to_table_on_various_fields() @@ -1168,11 +1213,13 @@ def test_order_by_complex_exprs(self): assert ht.order_by(-ht.idx).idx.collect() == list(range(10))[::-1] def test_order_by_intervals(self): - intervals = {0: hl.Interval(0, 3, includes_start=True, includes_end=False), - 1: hl.Interval(0, 4, includes_start=True, includes_end=True), - 2: hl.Interval(1, 4, includes_start=True, includes_end=False), - 3: hl.Interval(0, 4, includes_start=False, includes_end=False), - 4: hl.Interval(0, 4, includes_start=True, includes_end=False)} + intervals = { + 0: hl.Interval(0, 3, includes_start=True, includes_end=False), + 1: hl.Interval(0, 4, includes_start=True, includes_end=True), + 2: hl.Interval(1, 4, includes_start=True, includes_end=False), + 3: hl.Interval(0, 4, includes_start=False, includes_end=False), + 4: hl.Interval(0, 4, includes_start=True, includes_end=False), + } ht = hl.utils.range_table(5) ht = ht.annotate_globals(ilist=intervals) @@ -1210,11 
+1257,14 @@ def test_make_table(self): mt = mt.key_cols_by(col_idx=hl.str(mt.col_idx)) t = hl.Table.parallelize( - [{'row_idx': 0, '0.x': 0, '1.x': 0}, - {'row_idx': 1, '0.x': 0, '1.x': 1}, - {'row_idx': 2, '0.x': 0, '1.x': 2}], + [ + {'row_idx': 0, '0.x': 0, '1.x': 0}, + {'row_idx': 1, '0.x': 0, '1.x': 1}, + {'row_idx': 2, '0.x': 0, '1.x': 2}, + ], hl.tstruct(**{'row_idx': hl.tint32, '0.x': hl.tint32, '1.x': hl.tint32}), - key='row_idx') + key='row_idx', + ) self.assertTrue(mt.make_table()._same(t)) @@ -1224,9 +1274,7 @@ def test_make_table_empty_entry_field(self): mt = mt.key_cols_by(col_idx=hl.str(mt.col_idx)) t = mt.make_table() - self.assertEqual( - t.row.dtype, - hl.tstruct(**{'row_idx': hl.tint32, '0': hl.tint32, '1': hl.tint32})) + self.assertEqual(t.row.dtype, hl.tstruct(**{'row_idx': hl.tint32, '0': hl.tint32, '1': hl.tint32})) def test_make_table_sep(self): mt = hl.utils.range_matrix_table(3, 2) @@ -1241,25 +1289,26 @@ def test_make_table_sep(self): def test_make_table_row_equivalence(self): mt = hl.utils.range_matrix_table(3, 3) - mt = mt.annotate_rows(r1 = hl.rand_norm(), r2 = hl.rand_norm()) - mt = mt.annotate_entries(e1 = hl.rand_norm(), e2 = hl.rand_norm()) + mt = mt.annotate_rows(r1=hl.rand_norm(), r2=hl.rand_norm()) + mt = mt.annotate_entries(e1=hl.rand_norm(), e2=hl.rand_norm()) mt = mt.key_cols_by(col_idx=hl.str(mt.col_idx)) assert mt.make_table().select(*mt.row_value)._same(mt.rows()) def test_make_table_na_error(self): - mt = hl.utils.range_matrix_table(3, 3).key_cols_by(s = hl.missing('str')) - mt = mt.annotate_entries(e1 = 1) + mt = hl.utils.range_matrix_table(3, 3).key_cols_by(s=hl.missing('str')) + mt = mt.annotate_entries(e1=1) with pytest.raises(ValueError): mt.make_table() def test_transmute(self): mt = ( hl.utils.range_matrix_table(1, 1) - .annotate_globals(g1=0, g2=0) - .annotate_cols(c1=0, c2=0) - .annotate_rows(r1=0, r2=0) - .annotate_entries(e1=0, e2=0)) + .annotate_globals(g1=0, g2=0) + .annotate_cols(c1=0, c2=0) + .annotate_rows(r1=0, r2=0) + .annotate_entries(e1=0, e2=0) + ) self.assertEqual(mt.transmute_globals(g3=mt.g2 + 1).globals.dtype, hl.tstruct(g1=hl.tint, g3=hl.tint)) self.assertEqual(mt.transmute_rows(r3=mt.r2 + 1).row_value.dtype, hl.tstruct(r1=hl.tint, r3=hl.tint)) self.assertEqual(mt.transmute_cols(c3=mt.c2 + 1).col_value.dtype, hl.tstruct(c1=hl.tint, c3=hl.tint)) @@ -1267,7 +1316,7 @@ def test_transmute(self): def test_transmute_agg(self): mt = hl.utils.range_matrix_table(1, 1).annotate_entries(x=5) - mt = mt.transmute_rows(y = hl.agg.mean(mt.x)) + mt = mt.transmute_rows(y=hl.agg.mean(mt.x)) def test_agg_explode(self): t = hl.Table.parallelize([ @@ -1275,10 +1324,9 @@ def test_agg_explode(self): hl.struct(a=hl.empty_array(hl.tint32)), hl.struct(a=hl.missing(hl.tarray(hl.tint32))), hl.struct(a=[3]), - hl.struct(a=[hl.missing(hl.tint32)]) + hl.struct(a=[hl.missing(hl.tint32)]), ]) - self.assertCountEqual(t.aggregate(hl.agg.explode(lambda elt: hl.agg.collect(elt), t.a)), - [1, 2, None, 3]) + self.assertCountEqual(t.aggregate(hl.agg.explode(lambda elt: hl.agg.collect(elt), t.a)), [1, 2, None, 3]) def test_agg_call_stats(self): t = hl.Table.parallelize([ @@ -1287,35 +1335,30 @@ def test_agg_call_stats(self): hl.struct(c=hl.call(0, 2, phased=True)), hl.struct(c=hl.call(1)), hl.struct(c=hl.call(0)), - hl.struct(c=hl.call()) + hl.struct(c=hl.call()), ]) actual = t.aggregate(hl.agg.call_stats(t.c, ['A', 'T', 'G'])) - expected = hl.struct(AC=[5, 2, 1], - AF=[5.0 / 8.0, 2.0 / 8.0, 1.0 / 8.0], - AN=8, - homozygote_count=[1, 0, 0]) + expected = 
hl.struct(AC=[5, 2, 1], AF=[5.0 / 8.0, 2.0 / 8.0, 1.0 / 8.0], AN=8, homozygote_count=[1, 0, 0]) - self.assertTrue(hl.Table.parallelize([actual]), - hl.Table.parallelize([expected])) + self.assertTrue(hl.Table.parallelize([actual]), hl.Table.parallelize([expected])) def test_hardy_weinberg_test(self): mt = hl.import_vcf(resource('HWE_test.vcf')) mt_two_sided = mt.select_rows(**hl.agg.hardy_weinberg_test(mt.GT, one_sided=False)) rt_two_sided = mt_two_sided.rows() - expected_two_sided = hl.Table.parallelize([ - hl.struct( - locus=hl.locus('20', pos), - alleles=alleles, - het_freq_hwe=r, - p_value=p - ) - for (pos, alleles, r, p) in [ - (1, ['A', 'G'], 0.0, 0.5), - (2, ['A', 'G'], 0.25, 0.5), - (3, ['T', 'C'], 0.5357142857142857, 0.21428571428571427), - (4, ['T', 'A'], 0.5714285714285714, 0.6571428571428573), - (5, ['G', 'A'], 0.3333333333333333, 0.5)]], - key=['locus', 'alleles']) + expected_two_sided = hl.Table.parallelize( + [ + hl.struct(locus=hl.locus('20', pos), alleles=alleles, het_freq_hwe=r, p_value=p) + for (pos, alleles, r, p) in [ + (1, ['A', 'G'], 0.0, 0.5), + (2, ['A', 'G'], 0.25, 0.5), + (3, ['T', 'C'], 0.5357142857142857, 0.21428571428571427), + (4, ['T', 'A'], 0.5714285714285714, 0.6571428571428573), + (5, ['G', 'A'], 0.3333333333333333, 0.5), + ] + ], + key=['locus', 'alleles'], + ) self.assertTrue(rt_two_sided.filter(rt_two_sided.locus.position != 6)._same(expected_two_sided)) rt6_two_sided = rt_two_sided.filter(rt_two_sided.locus.position == 6).collect()[0] @@ -1324,20 +1367,19 @@ def test_hardy_weinberg_test(self): mt_one_sided = mt.select_rows(**hl.agg.hardy_weinberg_test(mt.GT, one_sided=True)) rt_one_sided = mt_one_sided.rows() - expected_one_sided = hl.Table.parallelize([ - hl.struct( - locus=hl.locus('20', pos), - alleles=alleles, - het_freq_hwe=r, - p_value=p - ) - for (pos, alleles, r, p) in [ - (1, ['A', 'G'], 0.0, 0.5), - (2, ['A', 'G'], 0.25, 0.5), - (3, ['T', 'C'], 0.5357142857142857, 0.7857142857142857), - (4, ['T', 'A'], 0.5714285714285714, 0.5714285714285715), - (5, ['G', 'A'], 0.3333333333333333, 0.5)]], - key=['locus', 'alleles']) + expected_one_sided = hl.Table.parallelize( + [ + hl.struct(locus=hl.locus('20', pos), alleles=alleles, het_freq_hwe=r, p_value=p) + for (pos, alleles, r, p) in [ + (1, ['A', 'G'], 0.0, 0.5), + (2, ['A', 'G'], 0.25, 0.5), + (3, ['T', 'C'], 0.5357142857142857, 0.7857142857142857), + (4, ['T', 'A'], 0.5714285714285714, 0.5714285714285715), + (5, ['G', 'A'], 0.3333333333333333, 0.5), + ] + ], + key=['locus', 'alleles'], + ) self.assertTrue(rt_one_sided.filter(rt_one_sided.locus.position != 6)._same(expected_one_sided)) rt6_one_sided = rt_one_sided.filter(rt_one_sided.locus.position == 6).collect()[0] @@ -1347,30 +1389,28 @@ def test_hardy_weinberg_test(self): def test_hw_func_and_agg_agree(self): mt = hl.import_vcf(resource('sample.vcf')) mt_two_sided = mt.annotate_rows( - stats=hl.agg.call_stats(mt.GT, mt.alleles), - hw=hl.agg.hardy_weinberg_test(mt.GT, one_sided=False) + stats=hl.agg.call_stats(mt.GT, mt.alleles), hw=hl.agg.hardy_weinberg_test(mt.GT, one_sided=False) ) mt_two_sided = mt_two_sided.annotate_rows( hw2=hl.hardy_weinberg_test( mt_two_sided.stats.homozygote_count[0], mt_two_sided.stats.AC[1] - 2 * mt_two_sided.stats.homozygote_count[1], mt_two_sided.stats.homozygote_count[1], - one_sided=False + one_sided=False, ) ) rt_two_sided = mt_two_sided.rows() self.assertTrue(rt_two_sided.all(rt_two_sided.hw == rt_two_sided.hw2)) mt_one_sided = mt.annotate_rows( - stats=hl.agg.call_stats(mt.GT, mt.alleles), - 
hw=hl.agg.hardy_weinberg_test(mt.GT, one_sided=True) + stats=hl.agg.call_stats(mt.GT, mt.alleles), hw=hl.agg.hardy_weinberg_test(mt.GT, one_sided=True) ) mt_one_sided = mt_one_sided.annotate_rows( hw2=hl.hardy_weinberg_test( mt_one_sided.stats.homozygote_count[0], mt_one_sided.stats.AC[1] - 2 * mt_one_sided.stats.homozygote_count[1], mt_one_sided.stats.homozygote_count[1], - one_sided=True + one_sided=True, ) ) rt_one_sided = mt_one_sided.rows() @@ -1394,38 +1434,44 @@ def test_write_no_parts(self): def test_nulls_in_distinct_joins_1(self): # MatrixAnnotateRowsTable uses left distinct join mr = hl.utils.range_matrix_table(7, 3, 4) - matrix1 = mr.key_rows_by(new_key=hl.if_else((mr.row_idx == 3) | (mr.row_idx == 5), hl.missing(hl.tint32), mr.row_idx)) - matrix2 = mr.key_rows_by(new_key=hl.if_else((mr.row_idx == 4) | (mr.row_idx == 6), hl.missing(hl.tint32), mr.row_idx)) - joined = matrix1.select_rows(idx1=matrix1.row_idx, - idx2=matrix2.rows()[matrix1.new_key].row_idx) + matrix1 = mr.key_rows_by( + new_key=hl.if_else((mr.row_idx == 3) | (mr.row_idx == 5), hl.missing(hl.tint32), mr.row_idx) + ) + matrix2 = mr.key_rows_by( + new_key=hl.if_else((mr.row_idx == 4) | (mr.row_idx == 6), hl.missing(hl.tint32), mr.row_idx) + ) + joined = matrix1.select_rows(idx1=matrix1.row_idx, idx2=matrix2.rows()[matrix1.new_key].row_idx) def row(new_key, idx1, idx2): return hl.Struct(new_key=new_key, idx1=idx1, idx2=idx2) - expected = [row(0, 0, 0), - row(1, 1, 1), - row(2, 2, 2), - row(4, 4, None), - row(6, 6, None), - row(None, 3, None), - row(None, 5, None)] + expected = [ + row(0, 0, 0), + row(1, 1, 1), + row(2, 2, 2), + row(4, 4, None), + row(6, 6, None), + row(None, 3, None), + row(None, 5, None), + ] self.assertEqual(joined.rows().collect(), expected) def test_nulls_in_distinct_joins_2(self): mr = hl.utils.range_matrix_table(7, 3, 4) - matrix1 = mr.key_rows_by(new_key=hl.if_else((mr.row_idx == 3) | (mr.row_idx == 5), hl.missing(hl.tint32), mr.row_idx)) - matrix2 = mr.key_rows_by(new_key=hl.if_else((mr.row_idx == 4) | (mr.row_idx == 6), hl.missing(hl.tint32), mr.row_idx)) + matrix1 = mr.key_rows_by( + new_key=hl.if_else((mr.row_idx == 3) | (mr.row_idx == 5), hl.missing(hl.tint32), mr.row_idx) + ) + matrix2 = mr.key_rows_by( + new_key=hl.if_else((mr.row_idx == 4) | (mr.row_idx == 6), hl.missing(hl.tint32), mr.row_idx) + ) # union_cols uses inner distinct join - matrix1 = matrix1.annotate_entries(ridx=matrix1.row_idx, - cidx=matrix1.col_idx) - matrix2 = matrix2.annotate_entries(ridx=matrix2.row_idx, - cidx=matrix2.col_idx) + matrix1 = matrix1.annotate_entries(ridx=matrix1.row_idx, cidx=matrix1.col_idx) + matrix2 = matrix2.annotate_entries(ridx=matrix2.row_idx, cidx=matrix2.col_idx) matrix2 = matrix2.key_cols_by(col_idx=matrix2.col_idx + 3) expected = hl.utils.range_matrix_table(3, 6, 1) expected = expected.key_rows_by(new_key=expected.row_idx) - expected = expected.annotate_entries(ridx=expected.row_idx, - cidx=expected.col_idx % 3) + expected = expected.annotate_entries(ridx=expected.row_idx, cidx=expected.col_idx % 3) self.assertTrue(matrix1.union_cols(matrix2)._same(expected)) @@ -1514,100 +1560,120 @@ def test_refs_with_process_joins(self): def test_aggregate_localize_false(self): dim1, dim2 = 10, 10 mt = hl.utils.range_matrix_table(dim1, dim2) - mt = mt.annotate_entries(x = mt.aggregate_rows(hl.agg.max(mt.row_idx), _localize=False) - + mt.aggregate_cols(hl.agg.max(mt.col_idx), _localize=False) - + mt.aggregate_entries(hl.agg.max(mt.row_idx * mt.col_idx), _localize=False) - ) - assert mt.x.take(1)[0] == 
(dim1 - 1) + (dim2 - 1) + (dim1 -1) * (dim2 - 1) + mt = mt.annotate_entries( + x=mt.aggregate_rows(hl.agg.max(mt.row_idx), _localize=False) + + mt.aggregate_cols(hl.agg.max(mt.col_idx), _localize=False) + + mt.aggregate_entries(hl.agg.max(mt.row_idx * mt.col_idx), _localize=False) + ) + assert mt.x.take(1)[0] == (dim1 - 1) + (dim2 - 1) + (dim1 - 1) * (dim2 - 1) def test_agg_cols_filter(self): t = hl.utils.range_matrix_table(1, 10) - tests = [(agg.filter(t.col_idx > 7, - agg.collect(t.col_idx + 1).append(0)), - [9, 10, 0]), - (agg.filter(t.col_idx > 7, - agg.explode(lambda elt: agg.collect(elt + 1).append(0), - [t.col_idx, t.col_idx + 1])), - [9, 10, 10, 11, 0]), - (agg.filter(t.col_idx > 7, - agg.group_by(t.col_idx % 3, - hl.array(agg.collect_as_set(t.col_idx + 1)).append(0))), - {0: [10, 0], 2: [9, 0]}) - ] + tests = [ + (agg.filter(t.col_idx > 7, agg.collect(t.col_idx + 1).append(0)), [9, 10, 0]), + ( + agg.filter( + t.col_idx > 7, agg.explode(lambda elt: agg.collect(elt + 1).append(0), [t.col_idx, t.col_idx + 1]) + ), + [9, 10, 10, 11, 0], + ), + ( + agg.filter( + t.col_idx > 7, agg.group_by(t.col_idx % 3, hl.array(agg.collect_as_set(t.col_idx + 1)).append(0)) + ), + {0: [10, 0], 2: [9, 0]}, + ), + ] for aggregation, expected in tests: - self.assertEqual(t.select_rows(result = aggregation).result.collect()[0], expected) + self.assertEqual(t.select_rows(result=aggregation).result.collect()[0], expected) def test_agg_cols_explode(self): t = hl.utils.range_matrix_table(1, 10) - tests = [(agg.explode(lambda elt: agg.collect(elt + 1).append(0), - hl.if_else(t.col_idx > 7, [t.col_idx, t.col_idx + 1], hl.empty_array(hl.tint32))), - [9, 10, 10, 11, 0]), - (agg.explode(lambda elt: agg.explode(lambda elt2: agg.collect(elt2 + 1).append(0), - [elt, elt + 1]), - hl.if_else(t.col_idx > 7, [t.col_idx, t.col_idx + 1], hl.empty_array(hl.tint32))), - [9, 10, 10, 11, 10, 11, 11, 12, 0]), - (agg.explode(lambda elt: agg.filter(elt > 8, - agg.collect(elt + 1).append(0)), - hl.if_else(t.col_idx > 7, [t.col_idx, t.col_idx + 1], hl.empty_array(hl.tint32))), - [10, 10, 11, 0]), - (agg.explode(lambda elt: agg.group_by(elt % 3, - agg.collect(elt + 1).append(0)), - hl.if_else(t.col_idx > 7, - [t.col_idx, t.col_idx + 1], - hl.empty_array(hl.tint32))), - {0: [10, 10, 0], 1: [11, 0], 2:[9, 0]}) - ] + tests = [ + ( + agg.explode( + lambda elt: agg.collect(elt + 1).append(0), + hl.if_else(t.col_idx > 7, [t.col_idx, t.col_idx + 1], hl.empty_array(hl.tint32)), + ), + [9, 10, 10, 11, 0], + ), + ( + agg.explode( + lambda elt: agg.explode(lambda elt2: agg.collect(elt2 + 1).append(0), [elt, elt + 1]), + hl.if_else(t.col_idx > 7, [t.col_idx, t.col_idx + 1], hl.empty_array(hl.tint32)), + ), + [9, 10, 10, 11, 10, 11, 11, 12, 0], + ), + ( + agg.explode( + lambda elt: agg.filter(elt > 8, agg.collect(elt + 1).append(0)), + hl.if_else(t.col_idx > 7, [t.col_idx, t.col_idx + 1], hl.empty_array(hl.tint32)), + ), + [10, 10, 11, 0], + ), + ( + agg.explode( + lambda elt: agg.group_by(elt % 3, agg.collect(elt + 1).append(0)), + hl.if_else(t.col_idx > 7, [t.col_idx, t.col_idx + 1], hl.empty_array(hl.tint32)), + ), + {0: [10, 10, 0], 1: [11, 0], 2: [9, 0]}, + ), + ] for aggregation, expected in tests: - self.assertEqual(t.select_rows(result = aggregation).result.collect()[0], expected) + self.assertEqual(t.select_rows(result=aggregation).result.collect()[0], expected) def test_agg_cols_group_by(self): t = hl.utils.range_matrix_table(1, 10) - tests = [(agg.group_by(t.col_idx % 2, - hl.array(agg.collect_as_set(t.col_idx + 1)).append(0)), - {0: 
[1, 3, 5, 7, 9, 0], 1: [2, 4, 6, 8, 10, 0]}), - (agg.group_by(t.col_idx % 3, - agg.filter(t.col_idx > 7, - hl.array(agg.collect_as_set(t.col_idx + 1)).append(0))), - {0: [10, 0], 1: [0], 2: [9, 0]}), - (agg.group_by(t.col_idx % 3, - agg.explode(lambda elt: agg.collect(elt + 1).append(0), - hl.if_else(t.col_idx > 7, - [t.col_idx, t.col_idx + 1], - hl.empty_array(hl.tint32)))), - {0: [10, 11, 0], 1: [0], 2:[9, 10, 0]}), - ] + tests = [ + ( + agg.group_by(t.col_idx % 2, hl.array(agg.collect_as_set(t.col_idx + 1)).append(0)), + {0: [1, 3, 5, 7, 9, 0], 1: [2, 4, 6, 8, 10, 0]}, + ), + ( + agg.group_by( + t.col_idx % 3, agg.filter(t.col_idx > 7, hl.array(agg.collect_as_set(t.col_idx + 1)).append(0)) + ), + {0: [10, 0], 1: [0], 2: [9, 0]}, + ), + ( + agg.group_by( + t.col_idx % 3, + agg.explode( + lambda elt: agg.collect(elt + 1).append(0), + hl.if_else(t.col_idx > 7, [t.col_idx, t.col_idx + 1], hl.empty_array(hl.tint32)), + ), + ), + {0: [10, 11, 0], 1: [0], 2: [9, 10, 0]}, + ), + ] for aggregation, expected in tests: - self.assertEqual(t.select_rows(result = aggregation).result.collect()[0], expected) + self.assertEqual(t.select_rows(result=aggregation).result.collect()[0], expected) def test_localize_entries_with_both_none_is_rows_table(self): mt = hl.utils.range_matrix_table(10, 10) mt = mt.select_entries(x=mt.row_idx * mt.col_idx) - localized = mt.localize_entries(entries_array_field_name=None, - columns_array_field_name=None) + localized = mt.localize_entries(entries_array_field_name=None, columns_array_field_name=None) rows_table = mt.rows() assert rows_table._same(localized) def test_localize_entries_with_none_cols_adds_no_globals(self): mt = hl.utils.range_matrix_table(10, 10) mt = mt.select_entries(x=mt.row_idx * mt.col_idx) - localized = mt.localize_entries(entries_array_field_name=Env.get_uid(), - columns_array_field_name=None) + localized = mt.localize_entries(entries_array_field_name=Env.get_uid(), columns_array_field_name=None) assert hl.eval(mt.globals) == hl.eval(localized.globals) def test_localize_entries_with_none_entries_changes_no_rows(self): mt = hl.utils.range_matrix_table(10, 10) mt = mt.select_entries(x=mt.row_idx * mt.col_idx) - localized = mt.localize_entries(entries_array_field_name=None, - columns_array_field_name=Env.get_uid()) + localized = mt.localize_entries(entries_array_field_name=None, columns_array_field_name=Env.get_uid()) rows_table = mt.rows() assert rows_table.select_globals()._same(localized.select_globals()) def test_localize_entries_creates_arrays_of_entries_and_array_of_cols(self): mt = hl.utils.range_matrix_table(10, 10) mt = mt.select_entries(x=mt.row_idx * mt.col_idx) - localized = mt.localize_entries(entries_array_field_name='entries', - columns_array_field_name='cols') + localized = mt.localize_entries(entries_array_field_name='entries', columns_array_field_name='cols') t = hl.utils.range_table(10) t = t.select(entries=hl.range(10).map(lambda y: hl.struct(x=t.idx * y))) t = t.select_globals(cols=hl.range(10).map(lambda y: hl.struct(col_idx=y))) @@ -1637,56 +1703,56 @@ def test_entry_filtering(self): mt = mt.filter_entries((mt.col_idx + mt.row_idx) % 2 == 0) assert mt.aggregate_entries(hl.agg.count()) == 50 - assert all(x == 5 for x in mt.annotate_cols(x = hl.agg.count()).x.collect()) - assert all(x == 5 for x in mt.annotate_rows(x = hl.agg.count()).x.collect()) + assert all(x == 5 for x in mt.annotate_cols(x=hl.agg.count()).x.collect()) + assert all(x == 5 for x in mt.annotate_rows(x=hl.agg.count()).x.collect()) mt = mt.unfilter_entries() assert 
mt.aggregate_entries(hl.agg.count()) == 100 - assert all(x == 10 for x in mt.annotate_cols(x = hl.agg.count()).x.collect()) - assert all(x == 10 for x in mt.annotate_rows(x = hl.agg.count()).x.collect()) + assert all(x == 10 for x in mt.annotate_cols(x=hl.agg.count()).x.collect()) + assert all(x == 10 for x in mt.annotate_rows(x=hl.agg.count()).x.collect()) def test_entry_filter_stats(self): mt = hl.utils.range_matrix_table(40, 20) mt = mt.filter_entries((mt.row_idx % 4 == 0) & (mt.col_idx % 4 == 0), keep=False) mt = mt.compute_entry_filter_stats() - row_expected = hl.dict({True: hl.struct(n_filtered=5, - n_remaining=15, - fraction_filtered=hl.float32(0.25)), - False: hl.struct(n_filtered=0, - n_remaining=20, - fraction_filtered=hl.float32(0.0))}) + row_expected = hl.dict({ + True: hl.struct(n_filtered=5, n_remaining=15, fraction_filtered=hl.float32(0.25)), + False: hl.struct(n_filtered=0, n_remaining=20, fraction_filtered=hl.float32(0.0)), + }) assert mt.aggregate_rows(hl.agg.all(mt.entry_stats_row == row_expected[mt.row_idx % 4 == 0])) - col_expected = hl.dict({True: hl.struct(n_filtered=10, - n_remaining=30, - fraction_filtered=hl.float32(0.25)), - False: hl.struct(n_filtered=0, - n_remaining=40, - fraction_filtered=hl.float32(0.0))}) + col_expected = hl.dict({ + True: hl.struct(n_filtered=10, n_remaining=30, fraction_filtered=hl.float32(0.25)), + False: hl.struct(n_filtered=0, n_remaining=40, fraction_filtered=hl.float32(0.0)), + }) assert mt.aggregate_cols(hl.agg.all(mt.entry_stats_col == col_expected[mt.col_idx % 4 == 0])) def test_annotate_col_agg_lowering(self): mt = hl.utils.range_matrix_table(10, 10, 2) mt = mt.annotate_cols(c1=[mt.col_idx, mt.col_idx * 2]) - mt = mt.annotate_entries(e1=mt.col_idx + mt.row_idx, e2=[mt.col_idx * mt.row_idx, mt.col_idx * mt.row_idx ** 2]) + mt = mt.annotate_entries(e1=mt.col_idx + mt.row_idx, e2=[mt.col_idx * mt.row_idx, mt.col_idx * mt.row_idx**2]) common_ref = mt.c1[1] - mt = mt.annotate_cols(exploded=hl.agg.explode(lambda e: common_ref + hl.agg.sum(e), mt.e2), - array=hl.agg.array_agg(lambda e: common_ref + hl.agg.sum(e), mt.e2), - filt=hl.agg.filter(mt.e1 < 5, hl.agg.sum(mt.e1) + common_ref), - grouped=hl.agg.group_by(mt.e1 % 5, hl.agg.sum(mt.e1) + common_ref)) + mt = mt.annotate_cols( + exploded=hl.agg.explode(lambda e: common_ref + hl.agg.sum(e), mt.e2), + array=hl.agg.array_agg(lambda e: common_ref + hl.agg.sum(e), mt.e2), + filt=hl.agg.filter(mt.e1 < 5, hl.agg.sum(mt.e1) + common_ref), + grouped=hl.agg.group_by(mt.e1 % 5, hl.agg.sum(mt.e1) + common_ref), + ) mt.cols()._force_count() def test_annotate_rows_scan_lowering(self): mt = hl.utils.range_matrix_table(10, 10, 2) mt = mt.annotate_rows(r1=[mt.row_idx, mt.row_idx * 2]) common_ref = mt.r1[1] - mt = mt.annotate_rows(exploded=hl.scan.explode(lambda e: common_ref + hl.scan.sum(e), mt.r1), - array=hl.scan.array_agg(lambda e: common_ref + hl.scan.sum(e), mt.r1), - filt=hl.scan.filter(mt.row_idx < 5, hl.scan.sum(mt.row_idx) + common_ref), - grouped=hl.scan.group_by(mt.row_idx % 5, hl.scan.sum(mt.row_idx) + common_ref), - an_agg = hl.agg.sum(mt.row_idx * mt.col_idx)) + mt = mt.annotate_rows( + exploded=hl.scan.explode(lambda e: common_ref + hl.scan.sum(e), mt.r1), + array=hl.scan.array_agg(lambda e: common_ref + hl.scan.sum(e), mt.r1), + filt=hl.scan.filter(mt.row_idx < 5, hl.scan.sum(mt.row_idx) + common_ref), + grouped=hl.scan.group_by(mt.row_idx % 5, hl.scan.sum(mt.row_idx) + common_ref), + an_agg=hl.agg.sum(mt.row_idx * mt.col_idx), + ) mt.cols()._force_count() def test_show_runs(self): @@ 
-1698,13 +1764,15 @@ def test_show_header(self): mt = mt.annotate_entries(x=1) mt = mt.key_cols_by(col_idx=mt.col_idx + 10) - expected = ('+---------+-------+\n' - '| row_idx | 10.x |\n' - '+---------+-------+\n' - '| int32 | int32 |\n' - '+---------+-------+\n' - '| 0 | 1 |\n' - '+---------+-------+\n') + expected = ( + '+---------+-------+\n' + '| row_idx | 10.x |\n' + '+---------+-------+\n' + '| int32 | int32 |\n' + '+---------+-------+\n' + '| 0 | 1 |\n' + '+---------+-------+\n' + ) actual = mt.show(handler=str) assert actual == expected @@ -1714,8 +1782,7 @@ def test_partitioned_write(self): def test_parts(parts, expected=mt): parts = [ - hl.Interval(start=hl.Struct(row_idx=s), end=hl.Struct(row_idx=e), - includes_start=_is, includes_end=ie) + hl.Interval(start=hl.Struct(row_idx=s), end=hl.Struct(row_idx=e), includes_start=_is, includes_end=ie) for (s, e, _is, ie) in parts ] @@ -1726,38 +1793,24 @@ def test_parts(parts, expected=mt): self.assertEqual(mt2.n_partitions(), len(parts)) self.assertTrue(mt2._same(expected)) - test_parts([ - (0, 40, True, False) - ]) + test_parts([(0, 40, True, False)]) - test_parts([ - (-34, -31, True, True), - (-30, 9, True, True), - (10, 107, True, True), - (108, 1000, True, True) - ]) + test_parts([(-34, -31, True, True), (-30, 9, True, True), (10, 107, True, True), (108, 1000, True, True)]) - test_parts([ - (0, 5, True, False), - (35, 40, True, True) - ], - mt.filter_rows((mt.row_idx < 5) | (mt.row_idx >= 35))) + test_parts([(0, 5, True, False), (35, 40, True, True)], mt.filter_rows((mt.row_idx < 5) | (mt.row_idx >= 35))) - test_parts([ - (5, 35, True, False) - ], - mt.filter_rows((mt.row_idx >= 5) & (mt.row_idx < 35))) + test_parts([(5, 35, True, False)], mt.filter_rows((mt.row_idx >= 5) & (mt.row_idx < 35))) def test_partitioned_write_coerce(self): mt = hl.import_vcf(resource('sample.vcf')) - parts = [ - hl.Interval(hl.Locus('20', 10277621), hl.Locus('20', 11898992)) - ] + parts = [hl.Interval(hl.Locus('20', 10277621), hl.Locus('20', 11898992))] tmp = new_temp_file(extension='mt') mt.write(tmp, _partitions=parts) mt2 = hl.read_matrix_table(tmp) - assert mt2.aggregate_rows(hl.agg.all(hl.literal(hl.Interval(hl.Locus('20', 10277621), hl.Locus('20', 11898992))).contains(mt2.locus))) + assert mt2.aggregate_rows( + hl.agg.all(hl.literal(hl.Interval(hl.Locus('20', 10277621), hl.Locus('20', 11898992))).contains(mt2.locus)) + ) assert mt2.n_partitions() == len(parts) assert hl.filter_intervals(mt, parts)._same(mt2) @@ -1766,7 +1819,7 @@ def test_overwrite(self): f = new_temp_file(extension='mt') mt.write(f) - with pytest.raises(hl.utils.FatalError, match= "file already exists"): + with pytest.raises(hl.utils.FatalError, match="file already exists"): mt.write(f) mt.write(f, overwrite=True) @@ -1788,7 +1841,7 @@ def test_matrix_native_write_range(self): def test_matrix_multi_write_range(self): mts = [ hl.utils.range_matrix_table(11, 27, n_partitions=10), - hl.utils.range_matrix_table(11, 3, n_partitions=10) + hl.utils.range_matrix_table(11, 3, n_partitions=10), ] f = new_temp_file() hl.experimental.write_matrix_tables(mts, f) @@ -1797,7 +1850,7 @@ def test_matrix_multi_write_range(self): def test_key_cols_by_extract_issue(self): mt = hl.utils.range_matrix_table(1000, 100) - mt = mt.key_cols_by(col_id = hl.str(mt.col_idx)) + mt = mt.key_cols_by(col_id=hl.str(mt.col_idx)) mt = mt.add_col_index() mt.show() @@ -1817,36 +1870,38 @@ def test_invalid_field_ref_error(self): mt = hl.balding_nichols_model(2, 5, 5) mt2 = hl.balding_nichols_model(2, 5, 5) with 
pytest.raises(hl.expr.ExpressionException, match='Found fields from 2 objects:'): - mt.annotate_entries(x = mt.GT.n_alt_alleles() * mt2.af) + mt.annotate_entries(x=mt.GT.n_alt_alleles() * mt2.af) def test_invalid_field_ref_annotate(self): mt = hl.balding_nichols_model(2, 5, 5) mt2 = hl.balding_nichols_model(2, 5, 5) with pytest.raises(hl.expr.ExpressionException, match='source mismatch'): - mt.annotate_entries(x = mt2.af) + mt.annotate_entries(x=mt2.af) def test_filter_locus_position_collect_returns_data(self): t = hl.utils.range_table(1) t = t.key_by(locus=hl.locus('2', t.idx + 1)) assert t.filter(t.locus.position >= 1).collect() == [ - hl.utils.Struct(idx=0, locus=hl.genetics.Locus(contig='2', position=1, reference_genome='GRCh37'))] + hl.utils.Struct(idx=0, locus=hl.genetics.Locus(contig='2', position=1, reference_genome='GRCh37')) + ] @fails_local_backend() def test_lower_row_agg_init_arg(self): mt = hl.balding_nichols_model(5, 200, 200) mt2 = hl.variant_qc(mt) mt2 = mt2.filter_rows((mt2.variant_qc.AF[0] > 0.05) & (mt2.variant_qc.AF[0] < 0.95)) - mt2 = mt2.sample_rows(.99) + mt2 = mt2.sample_rows(0.99) rows = mt2.rows() mt = mt.semi_join_rows(rows) hl.hwe_normalized_pca(mt.GT) + def test_keys_before_scans(): mt = hl.utils.range_matrix_table(6, 6) - mt = mt.annotate_rows(rev_idx = -mt.row_idx) + mt = mt.annotate_rows(rev_idx=-mt.row_idx) mt = mt.key_rows_by(mt.rev_idx) - mt = mt.annotate_rows(idx_scan = hl.scan.collect(mt.row_idx)) + mt = mt.annotate_rows(idx_scan=hl.scan.collect(mt.row_idx)) mt = mt.key_rows_by(mt.row_idx) assert mt.rows().idx_scan.collect() == [[5, 4, 3, 2, 1], [5, 4, 3, 2], [5, 4, 3], [5, 4], [5], []] @@ -1881,15 +1936,15 @@ def test_filter_against_invalid_contig(): def assert_unique_uids(mt): x = mt.aggregate_rows(hl.struct(r=hl.agg.collect_as_set(hl.rand_int64()), n=hl.agg.count())) - assert(len(x.r) == x.n) + assert len(x.r) == x.n x = mt.aggregate_cols(hl.struct(r=hl.agg.collect_as_set(hl.rand_int64()), n=hl.agg.count())) - assert(len(x.r) == x.n) + assert len(x.r) == x.n x = mt.aggregate_entries(hl.struct(r=hl.agg.collect_as_set(hl.rand_int64()), n=hl.agg.count())) - assert(len(x.r) == x.n) + assert len(x.r) == x.n def assert_contains_node(t, node): - assert(t._mir.base_search(lambda x: isinstance(x, node))) + assert t._mir.base_search(lambda x: isinstance(x, node)) def test_matrix_randomness_read(): @@ -1901,40 +1956,46 @@ def test_matrix_randomness_read(): @test_timeout(batch=8 * 60) def test_matrix_randomness_aggregate_rows_by_key_with_body_randomness(): rmt = hl.utils.range_matrix_table(20, 10, 3) - mt = (rmt.group_rows_by(k=rmt.row_idx % 5) + mt = ( + rmt.group_rows_by(k=rmt.row_idx % 5) .aggregate_rows(r=hl.rand_int64()) .aggregate_entries(e=hl.rand_int64()) - .result()) + .result() + ) assert_contains_node(mt, ir.MatrixAggregateRowsByKey) x = mt.aggregate_rows(hl.struct(r=hl.agg.collect_as_set(mt.r), n=hl.agg.count())) - assert(len(x.r) == x.n) + assert len(x.r) == x.n x = mt.aggregate_entries(hl.struct(r=hl.agg.collect_as_set(mt.e), n=hl.agg.count())) - assert(len(x.r) == x.n) + assert len(x.r) == x.n assert_unique_uids(mt) @test_timeout(batch=8 * 60) def test_matrix_randomness_aggregate_rows_by_key_then_aggregate_entries_with_agg_randomness(): rmt = hl.utils.range_matrix_table(20, 10, 3) - mt = (rmt.group_rows_by(k=rmt.row_idx % 5) - .aggregate_rows(r=hl.agg.collect(hl.rand_int64())) - .aggregate_entries(e=hl.agg.collect(hl.rand_int64())) - .result()) + mt = ( + rmt.group_rows_by(k=rmt.row_idx % 5) + .aggregate_rows(r=hl.agg.collect(hl.rand_int64())) + 
.aggregate_entries(e=hl.agg.collect(hl.rand_int64())) + .result() + ) assert_contains_node(mt, ir.MatrixAggregateRowsByKey) x = mt.aggregate_rows(hl.agg.explode(lambda r: hl.struct(r=hl.agg.collect_as_set(r), n=hl.agg.count()), mt.r)) - assert(len(x.r) == x.n) + assert len(x.r) == x.n x = mt.aggregate_entries(hl.agg.explode(lambda r: hl.struct(r=hl.agg.collect_as_set(r), n=hl.agg.count()), mt.e)) - assert(len(x.r) == x.n) + assert len(x.r) == x.n assert_unique_uids(mt) @test_timeout(batch=8 * 60) def test_matrix_randomness_aggregate_rows_by_key_without_body_randomness(): rmt = hl.utils.range_matrix_table(20, 10, 3) - mt = (rmt.group_rows_by(k=rmt.row_idx % 5) - .aggregate_rows(row_agg=hl.agg.sum(rmt.row_idx)) - .aggregate_entries(entry_agg=hl.agg.sum(rmt.row_idx + rmt.col_idx)) - .result()) + mt = ( + rmt.group_rows_by(k=rmt.row_idx % 5) + .aggregate_rows(row_agg=hl.agg.sum(rmt.row_idx)) + .aggregate_entries(entry_agg=hl.agg.sum(rmt.row_idx + rmt.col_idx)) + .result() + ) assert_contains_node(mt, ir.MatrixAggregateRowsByKey) assert_unique_uids(mt) @@ -1943,7 +2004,7 @@ def test_matrix_randomness_filter_rows_with_cond_randomness(): rmt = hl.utils.range_matrix_table(10, 10, 3) mt = rmt.filter_rows(hl.rand_int64() % 2 == 0) assert_contains_node(mt, ir.MatrixFilterRows) - mt.entries()._force_count() # test with no consumer randomness + mt.entries()._force_count() # test with no consumer randomness assert_unique_uids(mt) @@ -1966,7 +2027,7 @@ def test_matrix_randomness_map_cols_with_body_randomness(): mt = rmt.annotate_cols(r=hl.rand_int64()) assert_contains_node(mt, ir.MatrixMapCols) x = mt.aggregate_cols(hl.struct(r=hl.agg.collect_as_set(mt.r), n=hl.agg.count())) - assert(len(x.r) == x.n) + assert len(x.r) == x.n assert_unique_uids(mt) @@ -1975,7 +2036,7 @@ def test_matrix_randomness_map_cols_with_agg_randomness(): mt = rmt.annotate_cols(r=hl.agg.collect(hl.rand_int64())) assert_contains_node(mt, ir.MatrixMapCols) x = mt.aggregate_cols(hl.agg.explode(lambda r: hl.struct(r=hl.agg.collect_as_set(r), n=hl.agg.count()), mt.r)) - assert(len(x.r) == x.n) + assert len(x.r) == x.n assert_unique_uids(mt) @@ -1984,21 +2045,21 @@ def test_matrix_randomness_map_cols_with_scan_randomness(): mt = rmt.annotate_cols(r=hl.scan.collect(hl.rand_int64())) assert_contains_node(mt, ir.MatrixMapCols) x = mt.aggregate_cols(hl.struct(r=hl.agg.explode(lambda r: hl.agg.collect_as_set(r), mt.r), n=hl.agg.count())) - assert(len(x.r) == x.n - 1) + assert len(x.r) == x.n - 1 assert_unique_uids(mt) def test_matrix_randomness_map_cols_without_body_randomness(): rmt = hl.utils.range_matrix_table(10, 10, 3) - mt = rmt.annotate_cols(x=2*rmt.col_idx) + mt = rmt.annotate_cols(x=2 * rmt.col_idx) assert_contains_node(mt, ir.MatrixMapCols) assert_unique_uids(mt) def test_matrix_randomness_union_cols(): r, c = 5, 5 - mt = hl.utils.range_matrix_table(2*r, c) - mt2 = hl.utils.range_matrix_table(2*r, c) + mt = hl.utils.range_matrix_table(2 * r, c) + mt2 = hl.utils.range_matrix_table(2 * r, c) mt2 = mt2.key_rows_by(row_idx=mt2.row_idx + r) mt2 = mt2.key_cols_by(col_idx=mt2.col_idx + c) mt = mt.union_cols(mt2) @@ -2011,7 +2072,7 @@ def test_matrix_randomness_map_entries_with_body_randomness(): mt = rmt.annotate_entries(r=hl.rand_int64()) assert_contains_node(mt, ir.MatrixMapEntries) x = mt.aggregate_entries(hl.struct(r=hl.agg.collect_as_set(mt.r), n=hl.agg.count())) - assert(len(x.r) == x.n) + assert len(x.r) == x.n assert_unique_uids(mt) @@ -2026,7 +2087,7 @@ def test_matrix_randomness_filter_entries_with_cond_randomness(): rmt = 
hl.utils.range_matrix_table(10, 10, 3) mt = rmt.filter_entries(hl.rand_int64() % 2 == 0) assert_contains_node(mt, ir.MatrixFilterEntries) - mt.entries()._force_count() # test with no consumer randomness + mt.entries()._force_count() # test with no consumer randomness assert_unique_uids(mt) @@ -2049,7 +2110,7 @@ def test_matrix_randomness_map_rows(): mt = rmt.annotate_rows(r=hl.rand_int64()) assert_contains_node(mt, ir.MatrixMapRows) x = mt.aggregate_rows(hl.struct(r=hl.agg.collect_as_set(mt.r), n=hl.agg.count())) - assert(len(x.r) == x.n) + assert len(x.r) == x.n assert_unique_uids(mt) @@ -2058,7 +2119,7 @@ def test_matrix_randomness_map_rows_with_agg_randomness(): mt = rmt.annotate_rows(r=hl.agg.collect(hl.rand_int64())) assert_contains_node(mt, ir.MatrixMapRows) x = mt.aggregate_rows(hl.agg.explode(lambda r: hl.struct(r=hl.agg.collect_as_set(r), n=hl.agg.count()), mt.r)) - assert(len(x.r) == x.n) + assert len(x.r) == x.n assert_unique_uids(mt) @@ -2067,13 +2128,13 @@ def test_matrix_randomness_map_rows_with_scan_randomness(): mt = rmt.annotate_rows(r=hl.scan.collect(hl.rand_int64())) assert_contains_node(mt, ir.MatrixMapRows) x = mt.aggregate_rows(hl.struct(r=hl.agg.explode(lambda r: hl.agg.collect_as_set(r), mt.r), n=hl.agg.count())) - assert(len(x.r) == x.n - 1) + assert len(x.r) == x.n - 1 assert_unique_uids(mt) def test_matrix_randomness_map_rows_without_body_randomness(): rmt = hl.utils.range_matrix_table(10, 10, 3) - mt = rmt.annotate_rows(x=2*rmt.row_idx) + mt = rmt.annotate_rows(x=2 * rmt.row_idx) assert_contains_node(mt, ir.MatrixMapRows) assert_unique_uids(mt) @@ -2082,7 +2143,7 @@ def test_matrix_randomness_map_globals_with_body_randomness(): rmt = hl.utils.range_matrix_table(10, 10, 3) mt = rmt.annotate_globals(x=hl.rand_int64()) assert_contains_node(mt, ir.MatrixMapGlobals) - mt.entries()._force_count() # test with no consumer randomness + mt.entries()._force_count() # test with no consumer randomness assert_unique_uids(mt) @@ -2097,7 +2158,7 @@ def test_matrix_randomness_filter_cols_with_cond_randomness(): rmt = hl.utils.range_matrix_table(10, 10, 3) mt = rmt.filter_cols(hl.rand_int64() % 2 == 0) assert_contains_node(mt, ir.MatrixFilterCols) - mt.entries()._force_count() # test with no consumer randomness + mt.entries()._force_count() # test with no consumer randomness assert_unique_uids(mt) @@ -2119,40 +2180,46 @@ def test_matrix_randomness_collect_cols_by_key(): @test_timeout(batch=5 * 60) def test_matrix_randomness_aggregate_cols_by_key_with_body_randomness(): rmt = hl.utils.range_matrix_table(20, 10, 3) - mt = (rmt.group_cols_by(k=rmt.col_idx % 5) - .aggregate_cols(r=hl.rand_int64()) - .aggregate_entries(e=hl.rand_int64()) - .result()) + mt = ( + rmt.group_cols_by(k=rmt.col_idx % 5) + .aggregate_cols(r=hl.rand_int64()) + .aggregate_entries(e=hl.rand_int64()) + .result() + ) assert_contains_node(mt, ir.MatrixAggregateColsByKey) x = mt.aggregate_cols(hl.struct(r=hl.agg.collect_as_set(mt.r), n=hl.agg.count())) - assert(len(x.r) == x.n) + assert len(x.r) == x.n x = mt.aggregate_entries(hl.struct(r=hl.agg.collect_as_set(mt.e), n=hl.agg.count())) - assert(len(x.r) == x.n) + assert len(x.r) == x.n assert_unique_uids(mt) @test_timeout(batch=5 * 60) def test_matrix_randomness_aggregate_cols_by_key_with_agg_randomness(): rmt = hl.utils.range_matrix_table(20, 10, 3) - mt = (rmt.group_cols_by(k=rmt.col_idx % 5) - .aggregate_cols(r=hl.agg.collect(hl.rand_int64())) - .aggregate_entries(e=hl.agg.collect(hl.rand_int64())) - .result()) + mt = ( + rmt.group_cols_by(k=rmt.col_idx % 5) + 
.aggregate_cols(r=hl.agg.collect(hl.rand_int64())) + .aggregate_entries(e=hl.agg.collect(hl.rand_int64())) + .result() + ) assert_contains_node(mt, ir.MatrixAggregateColsByKey) x = mt.aggregate_cols(hl.agg.explode(lambda r: hl.struct(r=hl.agg.collect_as_set(r), n=hl.agg.count()), mt.r)) - assert(len(x.r) == x.n) + assert len(x.r) == x.n x = mt.aggregate_entries(hl.agg.explode(lambda r: hl.struct(r=hl.agg.collect_as_set(r), n=hl.agg.count()), mt.e)) - assert(len(x.r) == x.n) + assert len(x.r) == x.n assert_unique_uids(mt) @test_timeout(batch=5 * 60) def test_matrix_randomness_aggregate_cols_by_key_without_body_randomness(): rmt = hl.utils.range_matrix_table(20, 10, 3) - mt = (rmt.group_cols_by(k=rmt.col_idx % 5) - .aggregate_cols(row_agg=hl.agg.sum(rmt.col_idx)) - .aggregate_entries(entry_agg=hl.agg.sum(rmt.row_idx + rmt.col_idx)) - .result()) + mt = ( + rmt.group_cols_by(k=rmt.col_idx % 5) + .aggregate_cols(row_agg=hl.agg.sum(rmt.col_idx)) + .aggregate_entries(entry_agg=hl.agg.sum(rmt.row_idx + rmt.col_idx)) + .result() + ) assert_contains_node(mt, ir.MatrixAggregateColsByKey) assert_unique_uids(mt) @@ -2175,8 +2242,8 @@ def test_matrix_randomness_repartition(): def test_matrix_randomness_union_rows(): r, c = 5, 5 - mt = hl.utils.range_matrix_table(2*r, c) - mt2 = hl.utils.range_matrix_table(2*r, c) + mt = hl.utils.range_matrix_table(2 * r, c) + mt2 = hl.utils.range_matrix_table(2 * r, c) mt2 = mt2.key_rows_by(row_idx=mt2.row_idx + r) mt = mt.union_rows(mt2) assert_contains_node(mt, ir.MatrixUnionRows) @@ -2279,9 +2346,16 @@ def test_matrix_randomness_filter_intervals(): def test_upcast_tuples(): - t = hl.utils.range_matrix_table(1,1) + t = hl.utils.range_matrix_table(1, 1) t = t.annotate_cols(foo=[('0', 1)]) t = t.explode_cols(t.foo) t = t.annotate_cols(x=t.foo[1]) t = t.drop('foo') t.cols().collect() + + +def test_sample_entries(): + mt = hl.utils.range_matrix_table(10, 10) + ht = mt.entries() + ht = ht.sample(0.5) + ht._force_count() diff --git a/hail/python/test/hail/matrixtable/test_matrix_table_from_parts.py b/hail/python/test/hail/matrixtable/test_matrix_table_from_parts.py index c300731541a..b45fe9d574b 100644 --- a/hail/python/test/hail/matrixtable/test_matrix_table_from_parts.py +++ b/hail/python/test/hail/matrixtable/test_matrix_table_from_parts.py @@ -1,15 +1,19 @@ +from typing import ClassVar + import pytest + import hail as hl def unless(test: bool, kvs): return {} if test else kvs -class TestData(): - globals={'hello': 'world'} - rows={'foo': ['a', 'b']} - cols={'bar': ['c', 'd']} - entries={'baz': [[1, 2], [3, 4]]} + +class TestData: + globals: ClassVar = {'hello': 'world'} + rows: ClassVar = {'foo': ['a', 'b']} + cols: ClassVar = {'bar': ['c', 'd']} + entries: ClassVar = {'baz': [[1, 2], [3, 4]]} @staticmethod def assert_matches_globals(mt: 'hl.MatrixTable', no_props=False): @@ -21,77 +25,70 @@ def assert_no_globals(mt: 'hl.MatrixTable'): @staticmethod def assert_matches_rows(mt: 'hl.MatrixTable', no_props=False): - assert mt.rows().collect() == \ - [ hl.Struct(row_idx=0, **unless(no_props, {'foo': 'a'})) - , hl.Struct(row_idx=1, **unless(no_props, {'foo': 'b'})) - ] + assert mt.rows().collect() == [ + hl.Struct(row_idx=0, **unless(no_props, {'foo': 'a'})), + hl.Struct(row_idx=1, **unless(no_props, {'foo': 'b'})), + ] @staticmethod def assert_matches_cols(mt: 'hl.MatrixTable', no_props=False): - assert mt.cols().collect() == \ - [ hl.Struct(col_idx=0, **unless(no_props, {'bar': 'c'})) - , hl.Struct(col_idx=1, **unless(no_props, {'bar': 'd'})) - ] + assert 
mt.cols().collect() == [ + hl.Struct(col_idx=0, **unless(no_props, {'bar': 'c'})), + hl.Struct(col_idx=1, **unless(no_props, {'bar': 'd'})), + ] @staticmethod - def assert_matches_entries( mt: 'hl.MatrixTable', no_props=False): - assert mt.select_rows().select_cols().entries().collect() == \ - [ hl.Struct(row_idx=0, col_idx=0, **unless(no_props, {'baz': 1})) - , hl.Struct(row_idx=0, col_idx=1, **unless(no_props, {'baz': 2})) - , hl.Struct(row_idx=1, col_idx=0, **unless(no_props, {'baz': 3})) - , hl.Struct(row_idx=1, col_idx=1, **unless(no_props, {'baz': 4})) - ] + def assert_matches_entries(mt: 'hl.MatrixTable', no_props=False): + assert mt.select_rows().select_cols().entries().collect() == [ + hl.Struct(row_idx=0, col_idx=0, **unless(no_props, {'baz': 1})), + hl.Struct(row_idx=0, col_idx=1, **unless(no_props, {'baz': 2})), + hl.Struct(row_idx=1, col_idx=0, **unless(no_props, {'baz': 3})), + hl.Struct(row_idx=1, col_idx=1, **unless(no_props, {'baz': 4})), + ] + def test_from_parts(): - mt = hl.MatrixTable.from_parts( globals=TestData.globals - , rows=TestData.rows - , cols=TestData.cols - , entries=TestData.entries - ) + mt = hl.MatrixTable.from_parts( + globals=TestData.globals, rows=TestData.rows, cols=TestData.cols, entries=TestData.entries + ) TestData.assert_matches_globals(mt) TestData.assert_matches_rows(mt) TestData.assert_matches_cols(mt) TestData.assert_matches_entries(mt) + def test_optional_globals(): - mt = hl.MatrixTable.from_parts( rows=TestData.rows - , cols=TestData.cols - , entries=TestData.entries - ) + mt = hl.MatrixTable.from_parts(rows=TestData.rows, cols=TestData.cols, entries=TestData.entries) TestData.assert_no_globals(mt) TestData.assert_matches_rows(mt) TestData.assert_matches_cols(mt) TestData.assert_matches_entries(mt) + def test_optional_rows(): - mt = hl.MatrixTable.from_parts( globals=TestData.globals - , cols=TestData.cols - , entries=TestData.entries - ) + mt = hl.MatrixTable.from_parts(globals=TestData.globals, cols=TestData.cols, entries=TestData.entries) TestData.assert_matches_globals(mt) TestData.assert_matches_rows(mt, no_props=True) TestData.assert_matches_cols(mt) TestData.assert_matches_entries(mt) + def test_optional_cols(): - mt = hl.MatrixTable.from_parts( globals=TestData.globals - , rows=TestData.rows - , entries=TestData.entries - ) + mt = hl.MatrixTable.from_parts(globals=TestData.globals, rows=TestData.rows, entries=TestData.entries) TestData.assert_matches_globals(mt) TestData.assert_matches_rows(mt) TestData.assert_matches_cols(mt, no_props=True) TestData.assert_matches_entries(mt) + def test_optional_globals_and_cols(): - mt = hl.MatrixTable.from_parts( rows=TestData.rows - , entries=TestData.entries - ) + mt = hl.MatrixTable.from_parts(rows=TestData.rows, entries=TestData.entries) TestData.assert_no_globals(mt) TestData.assert_matches_rows(mt) TestData.assert_matches_cols(mt, no_props=True) TestData.assert_matches_entries(mt) + def test_optional_globals_and_rows_and_cols(): mt = hl.MatrixTable.from_parts(entries=TestData.entries) TestData.assert_no_globals(mt) @@ -99,6 +96,7 @@ def test_optional_globals_and_rows_and_cols(): TestData.assert_matches_cols(mt, no_props=True) TestData.assert_matches_entries(mt) + def test_optional_entries(): mt = hl.MatrixTable.from_parts(rows=TestData.rows, cols=TestData.cols) TestData.assert_no_globals(mt) @@ -106,40 +104,40 @@ def test_optional_entries(): TestData.assert_matches_cols(mt) TestData.assert_matches_entries(mt, no_props=True) + def test_rectangular_matrices(): mt = 
hl.MatrixTable.from_parts(entries={'foo': [[1], [2]]}) - assert mt.select_rows().select_cols().entries().collect() == \ - [ hl.Struct(row_idx=0, col_idx=0, foo=1) - , hl.Struct(row_idx=1, col_idx=0, foo=2) - ] + assert mt.select_rows().select_cols().entries().collect() == [ + hl.Struct(row_idx=0, col_idx=0, foo=1), + hl.Struct(row_idx=1, col_idx=0, foo=2), + ] + def test_raises_when_no_rows_and_entries(): with pytest.raises(AssertionError): hl.MatrixTable.from_parts(cols=TestData.cols) + def test_raises_when_no_cols_and_entries(): with pytest.raises(AssertionError): hl.MatrixTable.from_parts(rows=TestData.rows) + def test_raises_when_mismatched_row_property_dimensions(): with pytest.raises(ValueError): - hl.MatrixTable.from_parts( rows={'foo': [1], 'bar': [1, 2]} - , entries=TestData.entries - ) + hl.MatrixTable.from_parts(rows={'foo': [1], 'bar': [1, 2]}, entries=TestData.entries) + def test_raises_when_mismatched_col_property_dimensions(): with pytest.raises(ValueError): - hl.MatrixTable.from_parts( cols={'foo': [1], 'bar': [1, 2]} - , entries=TestData.entries - ) + hl.MatrixTable.from_parts(cols={'foo': [1], 'bar': [1, 2]}, entries=TestData.entries) + def test_raises_when_mismatched_entry_property_dimensions(): with pytest.raises(ValueError): hl.MatrixTable.from_parts(entries={'foo': [[1]], 'bar': [[1, 2]]}) + def test_raises_when_mismatched_rows_cols_entry_dimensions(): with pytest.raises(ValueError): - hl.MatrixTable.from_parts( rows={'foo': [1]} - , cols={'bar': [1]} - , entries={'baz': [[1, 2]]} - ) + hl.MatrixTable.from_parts(rows={'foo': [1]}, cols={'bar': [1]}, entries={'baz': [[1, 2]]}) diff --git a/hail/python/test/hail/methods/relatedness/test_identity_by_descent.py b/hail/python/test/hail/methods/relatedness/test_identity_by_descent.py index 21a6d69db40..bcc13fd7da4 100644 --- a/hail/python/test/hail/methods/relatedness/test_identity_by_descent.py +++ b/hail/python/test/hail/methods/relatedness/test_identity_by_descent.py @@ -1,16 +1,25 @@ import os -import pytest import subprocess as sp import unittest +import pytest + import hail as hl -import hail.utils as utils -from ...helpers import get_dataset, test_timeout, qobtest +from hail import utils + +from ...helpers import qobtest, test_timeout + + +@pytest.fixture(scope='module') +def ds(): + dataset = hl.balding_nichols_model(1, 100, 100) + dataset = dataset.key_cols_by(s=hl.str(dataset.sample_idx + 1)) + return dataset -def plinkify(ds, min=None, max=None): +def plinkify(dataset, min=None, max=None): vcf = utils.new_temp_file(prefix="plink", extension="vcf") - hl.export_vcf(ds, vcf) + hl.export_vcf(dataset, vcf) local_tmpdir = utils.new_local_temp_dir() plinkpath = f'{local_tmpdir}/plink-ibd' @@ -18,13 +27,11 @@ def plinkify(ds, min=None, max=None): hl.hadoop_copy(vcf, local_vcf) - threshold_string = "{} {}".format("--min {}".format(min) if min else "", - "--max {}".format(max) if max else "") + threshold_string = "{} {}".format("--min {}".format(min) if min else "", "--max {}".format(max) if max else "") - plink_command = "plink --double-id --allow-extra-chr --vcf {} --genome full --out {} {}" \ - .format(utils.uri_path(local_vcf), - utils.uri_path(plinkpath), - threshold_string) + plink_command = "plink --double-id --allow-extra-chr --vcf {} --genome full --out {} {}".format( + utils.uri_path(local_vcf), utils.uri_path(plinkpath), threshold_string + ) result_file = utils.uri_path(plinkpath + ".genome") sp.run(plink_command, check=True, capture_output=True, shell=True) @@ -40,17 +47,14 @@ def plinkify(ds, min=None, 
max=None): f.readline() for line in f: row = line.strip().split() - results[(row[1], row[3])] = (list(map(float, row[6:10])), - list(map(int, row[14:17]))) + results[(row[1], row[3])] = (list(map(float, row[6:10])), list(map(int, row[14:17]))) return results @qobtest @unittest.skipIf('HAIL_TEST_SKIP_PLINK' in os.environ, 'Skipping tests requiring plink') @test_timeout(local=10 * 60, batch=10 * 60) -def test_ibd_default_arguments(): - ds = get_dataset() - +def test_ibd_default_arguments(ds): plink_results = plinkify(ds) hail_results = hl.identity_by_descent(ds).collect() @@ -65,11 +69,10 @@ def test_ibd_default_arguments(): assert plink_results[key][1][2] == row.ibs2 +@qobtest @unittest.skipIf('HAIL_TEST_SKIP_PLINK' in os.environ, 'Skipping tests requiring plink') @test_timeout(local=10 * 60, batch=10 * 60) -def test_ibd_0_and_1(): - ds = get_dataset() - +def test_ibd_0_and_1(ds): plink_results = plinkify(ds, min=0.0, max=1.0) hail_results = hl.identity_by_descent(ds).collect() @@ -84,15 +87,15 @@ def test_ibd_0_and_1(): assert plink_results[key][1][2] == row.ibs2 +@qobtest @test_timeout(local=10 * 60, batch=10 * 60) -def test_ibd_does_not_error_with_dummy_maf_float64(): - dataset = get_dataset() - dataset = dataset.annotate_rows(dummy_maf=0.01) - hl.identity_by_descent(dataset, dataset['dummy_maf'], min=0.0, max=1.0) +def test_ibd_does_not_error_with_dummy_maf_float64(ds): + ds = ds.annotate_rows(dummy_maf=0.01) + hl.identity_by_descent(ds, ds['dummy_maf'], min=0.0, max=1.0) +@qobtest @test_timeout(local=10 * 60, batch=10 * 60) -def test_ibd_does_not_error_with_dummy_maf_float32(): - dataset = get_dataset() - dataset = dataset.annotate_rows(dummy_maf=0.01) - hl.identity_by_descent(dataset, hl.float32(dataset['dummy_maf']), min=0.0, max=1.0) +def test_ibd_does_not_error_with_dummy_maf_float32(ds): + ds = ds.annotate_rows(dummy_maf=0.01) + hl.identity_by_descent(ds, hl.float32(ds['dummy_maf']), min=0.0, max=1.0) diff --git a/hail/python/test/hail/methods/relatedness/test_pc_relate.py b/hail/python/test/hail/methods/relatedness/test_pc_relate.py index 96e1e2f1a41..002ad17fb36 100644 --- a/hail/python/test/hail/methods/relatedness/test_pc_relate.py +++ b/hail/python/test/hail/methods/relatedness/test_pc_relate.py @@ -1,25 +1,27 @@ import hail as hl -from ...helpers import resource, skip_when_service_backend, test_timeout, skip_when_service_backend_in_azure, qobtest +from ...helpers import qobtest, resource, skip_when_service_backend, skip_when_service_backend_in_azure, test_timeout @test_timeout(local=6 * 60, batch=14 * 60) def test_pc_relate_against_R_truth(): - with hl.TemporaryDirectory(ensure_exists=False) as vcf_f, \ - hl.TemporaryDirectory(ensure_exists=False) as hail_kin_f: + with hl.TemporaryDirectory(ensure_exists=False) as vcf_f, hl.TemporaryDirectory(ensure_exists=False) as hail_kin_f: mt = hl.import_vcf(resource('pc_relate_bn_input.vcf.bgz')).checkpoint(vcf_f) hail_kin = hl.pc_relate(mt.GT, 0.00, k=2).checkpoint(hail_kin_f) with hl.TemporaryDirectory(ensure_exists=False) as r_kin_f: - r_kin = hl.import_table(resource('pc_relate_r_truth.tsv.bgz'), - types={'i': 'struct{s:str}', - 'j': 'struct{s:str}', - 'kin': 'float', - 'ibd0': 'float', - 'ibd1': 'float', - 'ibd2': 'float'}, - key=['i', 'j'] - ).checkpoint(r_kin_f) + r_kin = hl.import_table( + resource('pc_relate_r_truth.tsv.bgz'), + types={ + 'i': 'struct{s:str}', + 'j': 'struct{s:str}', + 'kin': 'float', + 'ibd0': 'float', + 'ibd1': 'float', + 'ibd2': 'float', + }, + key=['i', 'j'], + ).checkpoint(r_kin_f) assert 
r_kin.select("kin")._same(hail_kin.select("kin"), tolerance=1e-3, absolute=True) assert r_kin.select("ibd0")._same(hail_kin.select("ibd0"), tolerance=1.3e-2, absolute=True) assert r_kin.select("ibd1")._same(hail_kin.select("ibd1"), tolerance=2.6e-2, absolute=True) @@ -28,42 +30,44 @@ def test_pc_relate_against_R_truth(): @qobtest def test_pc_relate_simple_example(): - gs = hl.literal([[0, 0, 0, 0, 1, 1, 1, 1], - [0, 0, 1, 1, 0, 0, 1, 1], - [0, 1, 0, 1, 0, 1, 0, 1], - [0, 0, 1, 1, 0, 0, 1, 1]]) + gs = hl.literal([ + [0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 0, 0, 1, 1], + [0, 1, 0, 1, 0, 1, 0, 1], + [0, 0, 1, 1, 0, 0, 1, 1], + ]) scores = hl.literal([[1, 1], [-1, 0], [1, -1], [-1, 0]]) mt = hl.utils.range_matrix_table(n_rows=8, n_cols=4) mt = mt.annotate_entries(GT=hl.unphased_diploid_gt_index_call(gs[mt.col_idx][mt.row_idx])) mt = mt.annotate_cols(scores=scores[mt.col_idx]) pcr = hl.pc_relate(mt.GT, min_individual_maf=0, scores_expr=mt.scores) - expected = [hl.Struct(i=0, j=1, kin=0.0, ibd0=1.0, ibd1=0.0, ibd2=0.0), - hl.Struct(i=0, j=2, kin=0.0, ibd0=1.0, ibd1=0.0, ibd2=0.0), - hl.Struct(i=0, j=3, kin=0.0, ibd0=1.0, ibd1=0.0, ibd2=0.0), - hl.Struct(i=1, j=2, kin=0.0, ibd0=1.0, ibd1=0.0, ibd2=0.0), - hl.Struct(i=1, j=3, kin=0.0, ibd0=1.0, ibd1=0.0, ibd2=0.0), - hl.Struct(i=2, j=3, kin=0.0, ibd0=1.0, ibd1=0.0, ibd2=0.0)] + expected = [ + hl.Struct(i=0, j=1, kin=0.0, ibd0=1.0, ibd1=0.0, ibd2=0.0), + hl.Struct(i=0, j=2, kin=0.0, ibd0=1.0, ibd1=0.0, ibd2=0.0), + hl.Struct(i=0, j=3, kin=0.0, ibd0=1.0, ibd1=0.0, ibd2=0.0), + hl.Struct(i=1, j=2, kin=0.0, ibd0=1.0, ibd1=0.0, ibd2=0.0), + hl.Struct(i=1, j=3, kin=0.0, ibd0=1.0, ibd1=0.0, ibd2=0.0), + hl.Struct(i=2, j=3, kin=0.0, ibd0=1.0, ibd1=0.0, ibd2=0.0), + ] ht_expected = hl.Table.parallelize(expected) - ht_expected = ht_expected.key_by(i=hl.struct(col_idx=ht_expected.i), - j=hl.struct(col_idx=ht_expected.j)) + ht_expected = ht_expected.key_by(i=hl.struct(col_idx=ht_expected.i), j=hl.struct(col_idx=ht_expected.j)) assert ht_expected._same(pcr, tolerance=1e-12, absolute=True) @test_timeout(6 * 60, batch=14 * 60) @skip_when_service_backend_in_azure(reason='takes >14 minutes in QoB in Azure') def test_pc_relate_paths_1(): - with hl.TemporaryDirectory(ensure_exists=False) as bn_f, \ - hl.TemporaryDirectory(ensure_exists=False) as scores_f, \ - hl.TemporaryDirectory(ensure_exists=False) as kin1_f, \ - hl.TemporaryDirectory(ensure_exists=False) as kins1_f: + with hl.TemporaryDirectory(ensure_exists=False) as bn_f, hl.TemporaryDirectory( + ensure_exists=False + ) as scores_f, hl.TemporaryDirectory(ensure_exists=False) as kin1_f, hl.TemporaryDirectory( + ensure_exists=False + ) as kins1_f: mt = hl.balding_nichols_model(3, 50, 100).checkpoint(bn_f) _, scores3, _ = hl._hwe_normalized_blanczos(mt.GT, k=3, compute_loadings=False, q_iterations=10) scores3 = scores3.checkpoint(scores_f) - kin1 = hl.pc_relate( - mt.GT, 0.10, k=2, statistics='kin', block_size=64 - ).checkpoint(kin1_f) + kin1 = hl.pc_relate(mt.GT, 0.10, k=2, statistics='kin', block_size=64).checkpoint(kin1_f) kin_s1 = hl.pc_relate( mt.GT, 0.10, scores_expr=scores3[mt.col_key].scores[:2], statistics='kin', block_size=64 ).checkpoint(kins1_f) @@ -94,7 +98,9 @@ def test_pc_relate_paths_3(): def test_self_kinship_1(): mt = hl.balding_nichols_model(3, 10, 50).cache() with hl.TemporaryDirectory(ensure_exists=False) as f: - with_self = hl.pc_relate(mt.GT, 0.10, k=2, statistics='kin', block_size=16, include_self_kinship=True).checkpoint(f) + with_self = hl.pc_relate( + mt.GT, 0.10, k=2, 
statistics='kin', block_size=16, include_self_kinship=True + ).checkpoint(f) assert with_self.count() == 55 with_self_self_kin_only = with_self.filter(with_self.i.sample_idx == with_self.j.sample_idx) assert with_self_self_kin_only.count() == 10, with_self_self_kin_only.collect() @@ -113,9 +119,12 @@ def test_self_kinship_2(): @test_timeout(6 * 60, batch=14 * 60) def test_self_kinship_3(): mt = hl.balding_nichols_model(3, 10, 50).cache() - with hl.TemporaryDirectory(ensure_exists=False) as with_self_f, \ - hl.TemporaryDirectory(ensure_exists=False) as without_self_f: - with_self = hl.pc_relate(mt.GT, 0.10, k=2, statistics='kin20', block_size=16, include_self_kinship=True).checkpoint(with_self_f) + with hl.TemporaryDirectory(ensure_exists=False) as with_self_f, hl.TemporaryDirectory( + ensure_exists=False + ) as without_self_f: + with_self = hl.pc_relate( + mt.GT, 0.10, k=2, statistics='kin20', block_size=16, include_self_kinship=True + ).checkpoint(with_self_f) without_self = hl.pc_relate(mt.GT, 0.10, k=2, statistics='kin20', block_size=16).checkpoint(without_self_f) with_self_no_self_kin = with_self.filter(with_self.i.sample_idx != with_self.j.sample_idx) @@ -127,7 +136,6 @@ def test_self_kinship_3(): def test_pc_relate_issue_5263(): mt = hl.balding_nichols_model(3, 50, 100) expected = hl.pc_relate(mt.GT, 0.10, k=2, statistics='all') - mt = mt.select_entries(GT2=mt.GT, - GT=hl.call(hl.rand_bool(0.5), hl.rand_bool(0.5))) + mt = mt.select_entries(GT2=mt.GT, GT=hl.call(hl.rand_bool(0.5), hl.rand_bool(0.5))) actual = hl.pc_relate(mt.GT2, 0.10, k=2, statistics='all') assert expected._same(actual, tolerance=1e-3) diff --git a/hail/python/test/hail/methods/test_family_methods.py b/hail/python/test/hail/methods/test_family_methods.py index c5cad6e51bd..31d357bdff8 100644 --- a/hail/python/test/hail/methods/test_family_methods.py +++ b/hail/python/test/hail/methods/test_family_methods.py @@ -1,7 +1,8 @@ import unittest import hail as hl -from ..helpers import * + +from ..helpers import qobtest, resource, test_timeout class Tests(unittest.TestCase): @@ -27,27 +28,30 @@ def test_trio_matrix_1(self): moms = ht.filter(hl.is_defined(ht.mat_id)) moms = moms.select(moms.mat_id, is_mom=True).key_by('mat_id') - et = (mt.entries() - .key_by('s') - .join(dads, how='left') - .join(moms, how='left')) - et = et.annotate(is_dad=hl.is_defined(et.is_dad), - is_mom=hl.is_defined(et.is_mom)) + et = mt.entries().key_by('s').join(dads, how='left').join(moms, how='left') + et = et.annotate(is_dad=hl.is_defined(et.is_dad), is_mom=hl.is_defined(et.is_mom)) - et = (et - .group_by(et.locus, et.alleles, fam=et.fam) - .aggregate(data=hl.agg.collect(hl.struct( - role=hl.case().when(et.is_dad, 1).when(et.is_mom, 2).default(0), - g=hl.struct(GT=et.GT, AD=et.AD, DP=et.DP, GQ=et.GQ, PL=et.PL))))) + et = et.group_by(et.locus, et.alleles, fam=et.fam).aggregate( + data=hl.agg.collect( + hl.struct( + role=hl.case().when(et.is_dad, 1).when(et.is_mom, 2).default(0), + g=hl.struct(GT=et.GT, AD=et.AD, DP=et.DP, GQ=et.GQ, PL=et.PL), + ) + ) + ) et = et.filter(hl.len(et.data) == 3) et = et.select('data').explode('data') tt = hl.trio_matrix(mt, ped, complete_trios=True).entries().key_by('locus', 'alleles') - tt = tt.annotate(fam=tt.proband.fam, - data=[hl.struct(role=0, g=tt.proband_entry.select('GT', 'AD', 'DP', 'GQ', 'PL')), - hl.struct(role=1, g=tt.father_entry.select('GT', 'AD', 'DP', 'GQ', 'PL')), - hl.struct(role=2, g=tt.mother_entry.select('GT', 'AD', 'DP', 'GQ', 'PL'))]) + tt = tt.annotate( + fam=tt.proband.fam, + data=[ + 
hl.struct(role=0, g=tt.proband_entry.select('GT', 'AD', 'DP', 'GQ', 'PL')), + hl.struct(role=1, g=tt.father_entry.select('GT', 'AD', 'DP', 'GQ', 'PL')), + hl.struct(role=2, g=tt.mother_entry.select('GT', 'AD', 'DP', 'GQ', 'PL')), + ], + ) tt = tt.select('fam', 'data').explode('data') tt = tt.filter(hl.is_defined(tt.data.g)).key_by('locus', 'alleles', 'fam') @@ -77,23 +81,32 @@ def test_trio_matrix_2(self): moms = moms.select(moms.mat_id, is_mom=True).key_by('mat_id') # test annotations - e_cols = (mt.cols() - .join(dads, how='left') - .join(moms, how='left')) - e_cols = e_cols.annotate(is_dad=hl.is_defined(e_cols.is_dad), - is_mom=hl.is_defined(e_cols.is_mom)) - e_cols = (e_cols.group_by(fam=e_cols.fam) - .aggregate(data=hl.agg.collect(hl.struct(role=hl.case() - .when(e_cols.is_dad, 1).when(e_cols.is_mom, 2).default(0), - sa=hl.struct(**e_cols.row.select(*mt.col)))))) + e_cols = mt.cols().join(dads, how='left').join(moms, how='left') + e_cols = e_cols.annotate(is_dad=hl.is_defined(e_cols.is_dad), is_mom=hl.is_defined(e_cols.is_mom)) + e_cols = e_cols.group_by(fam=e_cols.fam).aggregate( + data=hl.agg.collect( + hl.struct( + role=hl.case().when(e_cols.is_dad, 1).when(e_cols.is_mom, 2).default(0), + sa=hl.struct(**e_cols.row.select(*mt.col)), + ) + ) + ) e_cols = e_cols.filter(hl.len(e_cols.data) == 3).select('data').explode('data') t_cols = hl.trio_matrix(mt, ped, complete_trios=True).cols() - t_cols = t_cols.annotate(fam=t_cols.proband.fam, - data=[ - hl.struct(role=0, sa=t_cols.proband), - hl.struct(role=1, sa=t_cols.father), - hl.struct(role=2, sa=t_cols.mother)]).key_by('fam').select('data').explode('data') + t_cols = ( + t_cols.annotate( + fam=t_cols.proband.fam, + data=[ + hl.struct(role=0, sa=t_cols.proband), + hl.struct(role=1, sa=t_cols.father), + hl.struct(role=2, sa=t_cols.mother), + ], + ) + .key_by('fam') + .select('data') + .explode('data') + ) t_cols = t_cols.filter(hl.is_defined(t_cols.data.sa)) assert e_cols.key.dtype == t_cols.key.dtype @@ -124,32 +137,18 @@ def test_mendel_errors_1(self): ped = hl.Pedigree.read(resource('mendel.fam')) men, fam, ind, var = hl.mendel_errors(mt['GT'], ped) - assert men.key.dtype == hl.tstruct(locus=mt.locus.dtype, - alleles=hl.tarray(hl.tstr), - s=hl.tstr) - assert men.row.dtype == hl.tstruct(locus=mt.locus.dtype, - alleles=hl.tarray(hl.tstr), - s=hl.tstr, - fam_id=hl.tstr, - mendel_code=hl.tint) - assert fam.key.dtype == hl.tstruct(pat_id=hl.tstr, - mat_id=hl.tstr) - assert fam.row.dtype == hl.tstruct(pat_id=hl.tstr, - mat_id=hl.tstr, - fam_id=hl.tstr, - children=hl.tint, - errors=hl.tint64, - snp_errors=hl.tint64) + assert men.key.dtype == hl.tstruct(locus=mt.locus.dtype, alleles=hl.tarray(hl.tstr), s=hl.tstr) + assert men.row.dtype == hl.tstruct( + locus=mt.locus.dtype, alleles=hl.tarray(hl.tstr), s=hl.tstr, fam_id=hl.tstr, mendel_code=hl.tint + ) + assert fam.key.dtype == hl.tstruct(pat_id=hl.tstr, mat_id=hl.tstr) + assert fam.row.dtype == hl.tstruct( + pat_id=hl.tstr, mat_id=hl.tstr, fam_id=hl.tstr, children=hl.tint, errors=hl.tint64, snp_errors=hl.tint64 + ) assert ind.key.dtype == hl.tstruct(s=hl.tstr) - assert ind.row.dtype == hl.tstruct(s=hl.tstr, - fam_id=hl.tstr, - errors=hl.tint64, - snp_errors=hl.tint64) - assert var.key.dtype == hl.tstruct(locus=mt.locus.dtype, - alleles=hl.tarray(hl.tstr)) - assert var.row.dtype == hl.tstruct(locus=mt.locus.dtype, - alleles=hl.tarray(hl.tstr), - errors=hl.tint64) + assert ind.row.dtype == hl.tstruct(s=hl.tstr, fam_id=hl.tstr, errors=hl.tint64, snp_errors=hl.tint64) + assert var.key.dtype == 
hl.tstruct(locus=mt.locus.dtype, alleles=hl.tarray(hl.tstr)) + assert var.row.dtype == hl.tstruct(locus=mt.locus.dtype, alleles=hl.tarray(hl.tstr), errors=hl.tint64) @test_timeout(4 * 60) def test_mendel_errors_2(self): @@ -191,10 +190,8 @@ def test_mendel_errors_6(self): actual = set(fam.select('children', 'errors', 'snp_errors').collect()) expected = { - hl.utils.Struct(pat_id='Dad1', mat_id='Mom1', children=2, - errors=41, snp_errors=39), - hl.utils.Struct(pat_id='Dad2', mat_id='Mom2', children=1, - errors=0, snp_errors=0) + hl.utils.Struct(pat_id='Dad1', mat_id='Mom1', children=2, errors=41, snp_errors=39), + hl.utils.Struct(pat_id='Dad2', mat_id='Mom2', children=1, errors=0, snp_errors=0), } assert actual == expected @@ -212,7 +209,7 @@ def test_mendel_errors_7(self): hl.utils.Struct(s='Mom1', errors=22, snp_errors=21), hl.utils.Struct(s='Dad2', errors=0, snp_errors=0), hl.utils.Struct(s='Mom2', errors=0, snp_errors=0), - hl.utils.Struct(s='Son2', errors=0, snp_errors=0) + hl.utils.Struct(s='Son2', errors=0, snp_errors=0), } assert actual == expected @@ -228,7 +225,7 @@ def test_mendel_errors_8(self): (hl.Locus("X", 1), ['C', 'T']), (hl.Locus("X", 3), ['C', 'T']), (hl.Locus("Y", 1), ['C', 'T']), - (hl.Locus("Y", 3), ['C', 'T']) + (hl.Locus("Y", 3), ['C', 'T']), ]) var = var.filter(to_keep.contains((var.locus, var.alleles))) var = var.order_by('locus') @@ -257,31 +254,34 @@ def test_mendel_errors_9(self): def test_tdt(self): pedigree = hl.Pedigree.read(resource('tdt.fam')) - tdt_tab = (hl.transmission_disequilibrium_test( - hl.split_multi_hts(hl.import_vcf(resource('tdt.vcf'), min_partitions=4)), - pedigree)) + tdt_tab = hl.transmission_disequilibrium_test( + hl.split_multi_hts(hl.import_vcf(resource('tdt.vcf'), min_partitions=4)), pedigree + ) truth = hl.import_table( resource('tdt_results.tsv'), - types={'POSITION': hl.tint32, 'T': hl.tint32, 'U': hl.tint32, - 'Chi2': hl.tfloat64, 'Pval': hl.tfloat64}) - truth = (truth - .transmute(locus=hl.locus(truth.CHROM, truth.POSITION), - alleles=[truth.REF, truth.ALT]) - .key_by('locus', 'alleles')) + types={'POSITION': hl.tint32, 'T': hl.tint32, 'U': hl.tint32, 'Chi2': hl.tfloat64, 'Pval': hl.tfloat64}, + ) + truth = truth.transmute(locus=hl.locus(truth.CHROM, truth.POSITION), alleles=[truth.REF, truth.ALT]).key_by( + 'locus', 'alleles' + ) if tdt_tab.count() != truth.count(): self.fail('Result has {} rows but should have {} rows'.format(tdt_tab.count(), truth.count())) - bad = (tdt_tab.filter(hl.is_nan(tdt_tab.p_value), keep=False) - .join(truth.filter(hl.is_nan(truth.Pval), keep=False), how='outer')) + bad = tdt_tab.filter(hl.is_nan(tdt_tab.p_value), keep=False).join( + truth.filter(hl.is_nan(truth.Pval), keep=False), how='outer' + ) bad.describe() - bad = bad.filter(~( - (bad.t == bad.T) & - (bad.u == bad.U) & - (hl.abs(bad.chi_sq - bad.Chi2) < 0.001) & - (hl.abs(bad.p_value - bad.Pval) < 0.001))) + bad = bad.filter( + ~( + (bad.t == bad.T) + & (bad.u == bad.U) + & (hl.abs(bad.chi_sq - bad.Chi2) < 0.001) + & (hl.abs(bad.p_value - bad.Pval) < 0.001) + ) + ) if bad.count() != 0: bad.order_by(hl.asc(bad.u)).show() @@ -298,7 +298,8 @@ def test_de_novo(self): dad_id=r.father.s, mom_id=r.mother.s, p_de_novo=r.p_de_novo, - confidence=r.confidence).key_by('locus', 'alleles', 'kid_id', 'dad_id', 'mom_id') + confidence=r.confidence, + ).key_by('locus', 'alleles', 'kid_id', 'dad_id', 'mom_id') truth = hl.import_table(resource('denovo.out'), impute=True, comment='#') truth = truth.select( @@ -308,8 +309,8 @@ def test_de_novo(self): 
dad_id=truth['Dad_ID'], mom_id=truth['Mom_ID'], p_de_novo=truth['Prob_dn'], - confidence=truth['Validation_Likelihood'].split('_')[0]).key_by('locus', 'alleles', 'kid_id', 'dad_id', - 'mom_id') + confidence=truth['Validation_Likelihood'].split('_')[0], + ).key_by('locus', 'alleles', 'kid_id', 'dad_id', 'mom_id') j = r.join(truth, how='outer') self.assertTrue(j.all((j.confidence == j.confidence_1) & (hl.abs(j.p_de_novo - j.p_de_novo_1) < 1e-4))) diff --git a/hail/python/test/hail/methods/test_impex.py b/hail/python/test/hail/methods/test_impex.py index bb0c82c7b80..fb71fd96787 100644 --- a/hail/python/test/hail/methods/test_impex.py +++ b/hail/python/test/hail/methods/test_impex.py @@ -1,20 +1,29 @@ import json import os -import pytest +import re import shutil import unittest - from unittest import mock +import pytest from avro.datafile import DataFileReader from avro.io import DatumReader -from hail.context import TemporaryFilename -import pytest import hail as hl -from ..helpers import * from hail import ir -from hail.utils import new_temp_file, new_local_temp_file, FatalError, run_command, uri_path, HailUserError +from hail.context import TemporaryFilename +from hail.utils import FatalError, HailUserError, new_local_temp_file, new_temp_file, run_command, uri_path + +from ..helpers import ( + doctest_resource, + fails_local_backend, + fails_service_backend, + get_dataset, + qobtest, + resource, + test_timeout, + with_flags, +) _FLOAT_INFO_FIELDS = [ 'BaseQRankSum', @@ -83,8 +92,7 @@ def test_not_identical_headers(self): hl.export_vcf(mt.filter_cols((mt.s != "C1048::HG02024") & (mt.s != "HG00255")), t) with self.assertRaisesRegex(FatalError, 'invalid sample IDs'): - (hl.import_vcf([resource('sample.vcf'), t]) - ._force_count_rows()) + (hl.import_vcf([resource('sample.vcf'), t])._force_count_rows()) def test_filter(self): mt = hl.import_vcf(resource('malformed.vcf'), filter='rs685723') @@ -100,15 +108,14 @@ def test_find_replace(self): def test_haploid(self): expected = hl.Table.parallelize( - [hl.struct(locus = hl.locus("X", 16050036), s = "C1046::HG02024", - GT = hl.call(0, 0), AD = [10, 0], GQ = 44), - hl.struct(locus = hl.locus("X", 16050036), s = "C1046::HG02025", - GT = hl.call(1), AD = [0, 6], GQ = 70), - hl.struct(locus = hl.locus("X", 16061250), s = "C1046::HG02024", - GT = hl.call(2, 2), AD = [0, 0, 11], GQ = 33), - hl.struct(locus = hl.locus("X", 16061250), s = "C1046::HG02025", - GT = hl.call(2), AD = [0, 0, 9], GQ = 24)], - key=['locus', 's']) + [ + hl.struct(locus=hl.locus("X", 16050036), s="C1046::HG02024", GT=hl.call(0, 0), AD=[10, 0], GQ=44), + hl.struct(locus=hl.locus("X", 16050036), s="C1046::HG02025", GT=hl.call(1), AD=[0, 6], GQ=70), + hl.struct(locus=hl.locus("X", 16061250), s="C1046::HG02024", GT=hl.call(2, 2), AD=[0, 0, 11], GQ=33), + hl.struct(locus=hl.locus("X", 16061250), s="C1046::HG02025", GT=hl.call(2), AD=[0, 0, 9], GQ=24), + ], + key=['locus', 's'], + ) mt = hl.import_vcf(resource('haploid.vcf')) entries = mt.entries() @@ -118,15 +125,38 @@ def test_haploid(self): def test_call_fields(self): expected = hl.Table.parallelize( - [hl.struct(locus = hl.locus("X", 16050036), s = "C1046::HG02024", - GT = hl.call(0, 0), GTA = hl.missing(hl.tcall), GTZ = hl.call(0, 1)), - hl.struct(locus = hl.locus("X", 16050036), s = "C1046::HG02025", - GT = hl.call(1), GTA = hl.missing(hl.tcall), GTZ = hl.call(0)), - hl.struct(locus = hl.locus("X", 16061250), s = "C1046::HG02024", - GT = hl.call(2, 2), GTA = hl.call(2, 1), GTZ = hl.call(1, 1)), - hl.struct(locus = hl.locus("X", 
16061250), s = "C1046::HG02025", - GT = hl.call(2), GTA = hl.missing(hl.tcall), GTZ = hl.call(1))], - key=['locus', 's']) + [ + hl.struct( + locus=hl.locus("X", 16050036), + s="C1046::HG02024", + GT=hl.call(0, 0), + GTA=hl.missing(hl.tcall), + GTZ=hl.call(0, 1), + ), + hl.struct( + locus=hl.locus("X", 16050036), + s="C1046::HG02025", + GT=hl.call(1), + GTA=hl.missing(hl.tcall), + GTZ=hl.call(0), + ), + hl.struct( + locus=hl.locus("X", 16061250), + s="C1046::HG02024", + GT=hl.call(2, 2), + GTA=hl.call(2, 1), + GTZ=hl.call(1, 1), + ), + hl.struct( + locus=hl.locus("X", 16061250), + s="C1046::HG02025", + GT=hl.call(2), + GTA=hl.missing(hl.tcall), + GTZ=hl.call(1), + ), + ], + key=['locus', 's'], + ) mt = hl.import_vcf(resource('generic.vcf'), call_fields=['GT', 'GTA', 'GTZ']) entries = mt.entries() @@ -136,9 +166,10 @@ def test_call_fields(self): def test_import_vcf(self): vcf = hl.split_multi_hts( - hl.import_vcf(resource('sample2.vcf'), - reference_genome=hl.get_reference('GRCh38'), - contig_recoding={"22": "chr22"})) + hl.import_vcf( + resource('sample2.vcf'), reference_genome=hl.get_reference('GRCh38'), contig_recoding={"22": "chr22"} + ) + ) vcf_table = vcf.rows() self.assertTrue(vcf_table.all(vcf_table.locus.contig == "chr22")) @@ -149,8 +180,7 @@ def test_import_vcf_empty(self): assert mt._same(hl.import_vcf(resource('3var.vcf.bgz'))) def test_import_vcf_no_reference_specified(self): - vcf = hl.import_vcf(resource('sample2.vcf'), - reference_genome=None) + vcf = hl.import_vcf(resource('sample2.vcf'), reference_genome=None) self.assertEqual(vcf.locus.dtype, hl.tstruct(contig=hl.tstr, position=hl.tint32)) self.assertEqual(vcf.count_rows(), 735) @@ -161,10 +191,14 @@ def test_import_vcf_bad_reference_allele(self): def test_import_vcf_flags_are_defined(self): # issue 3277 t = hl.import_vcf(resource('sample.vcf')).rows() - self.assertTrue(t.all(hl.is_defined(t.info.NEGATIVE_TRAIN_SITE) & - hl.is_defined(t.info.POSITIVE_TRAIN_SITE) & - hl.is_defined(t.info.DB) & - hl.is_defined(t.info.DS))) + self.assertTrue( + t.all( + hl.is_defined(t.info.NEGATIVE_TRAIN_SITE) + & hl.is_defined(t.info.POSITIVE_TRAIN_SITE) + & hl.is_defined(t.info.DB) + & hl.is_defined(t.info.DS) + ) + ) def test_import_vcf_can_import_float_array_format(self): mt = hl.import_vcf(resource('floating_point_array.vcf')) @@ -178,38 +212,101 @@ def test_import_vcf_can_import_float32_array_format(self): def test_import_vcf_can_import_negative_numbers(self): mt = hl.import_vcf(resource('negative_format_fields.vcf')) - self.assertTrue(mt.aggregate_entries(hl.agg.all(mt.negative_int == -1) & - hl.agg.all(mt.negative_float == -1.5) & - hl.agg.all(mt.negative_int_array == [-1, -2]) & - hl.agg.all(mt.negative_float_array == [-0.5, -1.5]))) + self.assertTrue( + mt.aggregate_entries( + hl.agg.all(mt.negative_int == -1) + & hl.agg.all(mt.negative_float == -1.5) + & hl.agg.all(mt.negative_int_array == [-1, -2]) + & hl.agg.all(mt.negative_float_array == [-0.5, -1.5]) + ) + ) + + def test_import_vcf_has_good_error_message_when_info_fields_have_missing_elements(self): + mt = hl.import_vcf(resource('missingInfoArray.vcf'), reference_genome='GRCh37') + with pytest.raises( + FatalError, + match=".*Missing value in INFO array. 
Use 'hl.import_vcf[(][.][.][.], array_elements_required=False[)]'[.].*", + ): + mt._force_count_rows() - def test_import_vcf_missing_info_field_elements(self): + def test_import_vcf_array_elements_required_is_false_parses_info_fields_with_missing_elements(self): mt = hl.import_vcf(resource('missingInfoArray.vcf'), reference_genome='GRCh37', array_elements_required=False) - mt = mt.select_rows(FOO=mt.info.FOO, BAR=mt.info.BAR) - expected = hl.Table.parallelize([{'locus': hl.Locus('X', 16050036), 'alleles': ['A', 'C'], - 'FOO': [1, None], 'BAR': [2, None, None]}, - {'locus': hl.Locus('X', 16061250), 'alleles': ['T', 'A', 'C'], - 'FOO': [None, 2, None], 'BAR': [None, 1.0, None]}], - hl.tstruct(locus=hl.tlocus('GRCh37'), alleles=hl.tarray(hl.tstr), - FOO=hl.tarray(hl.tint), BAR=hl.tarray(hl.tfloat64)), - key=['locus', 'alleles']) - self.assertTrue(mt.rows()._same(expected)) + mt = mt.select_rows(**mt.info) + expected = hl.Table.parallelize( + [ + { + 'locus': hl.Locus('X', 16050036), + 'alleles': ['A', 'C'], + 'FOO': [1, None], + 'BAR': [2, None, None], + 'JUST_A_DOT': None, + 'NOT_EVEN_PRESENT': None, + }, + { + 'locus': hl.Locus('X', 16061250), + 'alleles': ['T', 'A', 'C'], + 'FOO': [None, 2, None], + 'BAR': [None, 1.0, None], + 'JUST_A_DOT': None, + 'NOT_EVEN_PRESENT': None, + }, + ], + hl.tstruct( + locus=hl.tlocus('GRCh37'), + alleles=hl.tarray(hl.tstr), + FOO=hl.tarray(hl.tint), + BAR=hl.tarray(hl.tfloat64), + JUST_A_DOT=hl.tarray(hl.tfloat64), + NOT_EVEN_PRESENT=hl.tarray(hl.tfloat64), + ), + key=['locus', 'alleles'], + ) + assert mt.rows()._same(expected) def test_import_vcf_missing_format_field_elements(self): mt = hl.import_vcf(resource('missingFormatArray.vcf'), reference_genome='GRCh37', array_elements_required=False) mt = mt.select_rows().select_entries('AD', 'PL') - expected = hl.Table.parallelize([{'locus': hl.Locus('X', 16050036), 'alleles': ['A', 'C'], 's': 'C1046::HG02024', - 'AD': [None, None], 'PL': [0, None, 180]}, - {'locus': hl.Locus('X', 16050036), 'alleles': ['A', 'C'], 's': 'C1046::HG02025', - 'AD': [None, 6], 'PL': [70, None]}, - {'locus': hl.Locus('X', 16061250), 'alleles': ['T', 'A', 'C'], 's': 'C1046::HG02024', - 'AD': [0, 0, None], 'PL': [396, None, None, 33, None, 0]}, - {'locus': hl.Locus('X', 16061250), 'alleles': ['T', 'A', 'C'], 's': 'C1046::HG02025', - 'AD': [0, 0, 9], 'PL': [None, None, None]}], - hl.tstruct(locus=hl.tlocus('GRCh37'), alleles=hl.tarray(hl.tstr), s=hl.tstr, - AD=hl.tarray(hl.tint), PL=hl.tarray(hl.tint)), - key=['locus', 'alleles', 's']) + expected = hl.Table.parallelize( + [ + { + 'locus': hl.Locus('X', 16050036), + 'alleles': ['A', 'C'], + 's': 'C1046::HG02024', + 'AD': [None, None], + 'PL': [0, None, 180], + }, + { + 'locus': hl.Locus('X', 16050036), + 'alleles': ['A', 'C'], + 's': 'C1046::HG02025', + 'AD': [None, 6], + 'PL': [70, None], + }, + { + 'locus': hl.Locus('X', 16061250), + 'alleles': ['T', 'A', 'C'], + 's': 'C1046::HG02024', + 'AD': [0, 0, None], + 'PL': [396, None, None, 33, None, 0], + }, + { + 'locus': hl.Locus('X', 16061250), + 'alleles': ['T', 'A', 'C'], + 's': 'C1046::HG02025', + 'AD': [0, 0, 9], + 'PL': [None, None, None], + }, + ], + hl.tstruct( + locus=hl.tlocus('GRCh37'), + alleles=hl.tarray(hl.tstr), + s=hl.tstr, + AD=hl.tarray(hl.tint), + PL=hl.tarray(hl.tint), + ), + key=['locus', 'alleles', 's'], + ) self.assertTrue(mt.entries()._same(expected)) @@ -220,8 +317,7 @@ def test_vcf_unsorted_alleles_no_codegen(self): with_flags(no_whole_stage_codegen="1")(_vcf_unsorted_alleles)() def 
test_import_vcf_skip_invalid_loci(self): - mt = hl.import_vcf(resource('skip_invalid_loci.vcf'), reference_genome='GRCh37', - skip_invalid_loci=True) + mt = hl.import_vcf(resource('skip_invalid_loci.vcf'), reference_genome='GRCh37', skip_invalid_loci=True) self.assertEqual(mt._force_count_rows(), 3) with self.assertRaisesRegex(FatalError, 'Invalid locus'): @@ -234,7 +330,7 @@ def test_import_vcf_set_field_missing(self): def test_import_vcf_dosages_as_doubles_or_floats(self): mt = hl.import_vcf(resource('small-ds.vcf')) self.assertEqual(hl.expr.expressions.typed_expressions.Float64Expression, type(mt.entry.DS)) - mt32 = hl.import_vcf(resource('small-ds.vcf'), entry_float_type=hl.tfloat32) + mt32 = hl.import_vcf(resource('small-ds.vcf'), entry_float_type=hl.tfloat32) self.assertEqual(hl.expr.expressions.typed_expressions.Float32Expression, type(mt32.entry.DS)) mt_result = mt.annotate_entries(DS32=mt32.index_entries(mt.row_key, mt.col_key).DS) compare = mt_result.annotate_entries( @@ -244,19 +340,18 @@ def test_import_vcf_dosages_as_doubles_or_floats(self): def test_import_vcf_invalid_float_type(self): with self.assertRaises(TypeError): - mt = hl.import_vcf(resource('small-ds.vcf'), entry_float_type=hl.tstr) + hl.import_vcf(resource('small-ds.vcf'), entry_float_type=hl.tstr) with self.assertRaises(TypeError): - mt = hl.import_vcf(resource('small-ds.vcf'), entry_float_type=hl.tint) + hl.import_vcf(resource('small-ds.vcf'), entry_float_type=hl.tint) with self.assertRaises(TypeError): - mt = hl.import_vcf(resource('small-ds.vcf'), entry_float_type=hl.tint32) + hl.import_vcf(resource('small-ds.vcf'), entry_float_type=hl.tint32) with self.assertRaises(TypeError): - mt = hl.import_vcf(resource('small-ds.vcf'), entry_float_type=hl.tint64) + hl.import_vcf(resource('small-ds.vcf'), entry_float_type=hl.tint64) def test_export_vcf(self): dataset = hl.import_vcf(resource('sample.vcf.bgz')) vcf_metadata = hl.get_vcf_metadata(resource('sample.vcf.bgz')) - with TemporaryFilename(suffix='.vcf') as sample_vcf, \ - TemporaryFilename(suffix='.vcf') as no_sample_vcf: + with TemporaryFilename(suffix='.vcf') as sample_vcf, TemporaryFilename(suffix='.vcf') as no_sample_vcf: hl.export_vcf(dataset, sample_vcf, metadata=vcf_metadata) dataset_imported = hl.import_vcf(sample_vcf) self.assertTrue(dataset._same(dataset_imported)) @@ -275,11 +370,12 @@ def test_export_vcf_quotes_and_backslash_in_description(self): meta = hl.get_vcf_metadata(resource("sample.vcf")) meta["info"]["AF"]["Description"] = 'foo "bar" \\' with TemporaryFilename(suffix='.vcf') as test_vcf: - hl.export_vcf(ds, test_vcf, metadata=meta) - af_lines = [ - line for line in hl.current_backend().fs.open(test_vcf).read().split('\n') - if line.startswith("##INFO="), line @@ -312,9 +408,7 @@ def test_export_vcf_no_alt_alleles(self): self.assertTrue(mt._same(mt2)) def test_export_sites_only_from_table(self): - mt = hl.import_vcf(resource('sample.vcf.bgz'))\ - .select_entries()\ - .filter_cols(False) + mt = hl.import_vcf(resource('sample.vcf.bgz')).select_entries().filter_cols(False) tmp = new_temp_file(extension="vcf") hl.export_vcf(mt.rows(), tmp) @@ -323,52 +417,74 @@ def test_export_sites_only_from_table(self): def test_export_vcf_invalid_info_types(self): ds = hl.import_vcf(resource("sample.vcf")) ds = ds.annotate_rows( - info=ds.info.annotate(arr_bool=hl.missing(hl.tarray(hl.tbool)), - arr_arr_i32=hl.missing(hl.tarray(hl.tarray(hl.tint32))))) - with pytest.raises(FatalError) as exp, \ - TemporaryFilename(suffix='.vcf') as export_path: + 
info=ds.info.annotate( + arr_bool=hl.missing(hl.tarray(hl.tbool)), arr_arr_i32=hl.missing(hl.tarray(hl.tarray(hl.tint32))) + ) + ) + with pytest.raises(FatalError) as exp, TemporaryFilename(suffix='.vcf') as export_path: hl.export_vcf(ds, export_path) - msg = '''VCF does not support the type(s) for the following INFO field(s): + msg = """VCF does not support the type(s) for the following INFO field(s): \t'arr_bool': 'array'. \t'arr_arr_i32': 'array>'. -''' +""" assert msg in str(exp.value) def test_export_vcf_invalid_format_types(self): ds = hl.import_vcf(resource("sample.vcf")) - ds = ds.annotate_entries( - boolean=hl.missing(hl.tbool), - arr_arr_i32=hl.missing(hl.tarray(hl.tarray(hl.tint32)))) - with pytest.raises(FatalError) as exp, \ - TemporaryFilename(suffix='.vcf') as export_path: + ds = ds.annotate_entries(boolean=hl.missing(hl.tbool), arr_arr_i32=hl.missing(hl.tarray(hl.tarray(hl.tint32)))) + with pytest.raises(FatalError) as exp, TemporaryFilename(suffix='.vcf') as export_path: hl.export_vcf(ds, export_path) - msg = '''VCF does not support the type(s) for the following FORMAT field(s): + msg = """VCF does not support the type(s) for the following FORMAT field(s): \t'boolean': 'bool'. \t'arr_arr_i32': 'array>'. -''' +""" assert msg in str(exp.value) + def test_export_vcf_haploid(self): + ds = hl.import_vcf(resource("sample.vcf")) + ds = ds.select_entries(GT=hl.call(0)) + with TemporaryFilename(suffix='.vcf') as export_path: + hl.export_vcf(ds, export_path) + def import_gvcfs_sample_vcf(self, path): parts_type = hl.tarray(hl.tinterval(hl.tstruct(locus=hl.tlocus('GRCh37')))) parts = [ - hl.Interval(start=hl.Struct(locus=hl.Locus('20', 1)), - end=hl.Struct(locus=hl.Locus('20', 13509135)), - includes_end=True), - hl.Interval(start=hl.Struct(locus=hl.Locus('20', 13509136)), - end=hl.Struct(locus=hl.Locus('20', 16493533)), - includes_end=True), - hl.Interval(start=hl.Struct(locus=hl.Locus('20', 16493534)), - end=hl.Struct(locus=hl.Locus('20', 20000000)), - includes_end=True) + hl.Interval( + start=hl.Struct(locus=hl.Locus('20', 1)), + end=hl.Struct(locus=hl.Locus('20', 13509135)), + includes_end=True, + ), + hl.Interval( + start=hl.Struct(locus=hl.Locus('20', 13509136)), + end=hl.Struct(locus=hl.Locus('20', 16493533)), + includes_end=True, + ), + hl.Interval( + start=hl.Struct(locus=hl.Locus('20', 16493534)), + end=hl.Struct(locus=hl.Locus('20', 20000000)), + includes_end=True, + ), ] parts_str = json.dumps(parts_type._convert_to_json(parts)) - vir = ir.MatrixVCFReader(path=path, call_fields=['PGT'], entry_float_type=hl.tfloat64, - header_file=None, block_size=None, min_partitions=None, - reference_genome='default', contig_recoding=None, - array_elements_required=True, skip_invalid_loci=False, - force_bgz=False, force_gz=False, filter=None, find_replace=None, - n_partitions=None, _partitions_json=parts_str, - _partitions_type=parts_type) + vir = ir.MatrixVCFReader( + path=path, + call_fields=['PGT'], + entry_float_type=hl.tfloat64, + header_file=None, + block_size=None, + min_partitions=None, + reference_genome='default', + contig_recoding=None, + array_elements_required=True, + skip_invalid_loci=False, + force_bgz=False, + force_gz=False, + filter=None, + find_replace=None, + n_partitions=None, + _partitions_json=parts_str, + _partitions_type=parts_type, + ) vcf1 = hl.import_vcf(path) vcf2 = hl.MatrixTable(ir.MatrixRead(vir)) @@ -409,6 +525,7 @@ def test_import_gvcfs(self): def test_import_gvcfs_long_line(self): import bz2 + fs = hl.current_backend().fs path = 
resource('gvcfs/long_line.g.vcf.gz') vcf = hl.import_vcf(path, force_bgz=True) @@ -428,25 +545,26 @@ def test_vcf_parser_golden_master__gvcf_GRCh37(self): self._test_vcf_parser_golden_master(resource('gvcfs/HG00096.g.vcf.gz'), 'GRCh38') def _test_vcf_parser_golden_master(self, vcf_path, rg): - vcf = hl.import_vcf( - vcf_path, - reference_genome=rg, - array_elements_required=False, - force_bgz=True) + vcf = hl.import_vcf(vcf_path, reference_genome=rg, array_elements_required=False, force_bgz=True) mt = hl.read_matrix_table(vcf_path + '.mt') self.assertTrue(mt._same(vcf)) def test_combiner_works(self): - from hail.vds.combiner.combine import transform_gvcf, combine_variant_datasets + from hail.vds.combiner.combine import combine_variant_datasets, transform_gvcf + _paths = ['gvcfs/HG00096.g.vcf.gz', 'gvcfs/HG00268.g.vcf.gz'] paths = [resource(p) for p in _paths] vdses = [] for path in paths: mt = hl.import_vcf(path, reference_genome='GRCh38', array_elements_required=False, force_bgz=True) - mt = transform_gvcf(mt.annotate_rows(info=mt.info.annotate( - MQ_DP=hl.missing(hl.tint32), - VarDP=hl.missing(hl.tint32), - QUALapprox=hl.missing(hl.tint32))), reference_entry_fields_to_keep=[]) + mt = transform_gvcf( + mt.annotate_rows( + info=mt.info.annotate( + MQ_DP=hl.missing(hl.tint32), VarDP=hl.missing(hl.tint32), QUALapprox=hl.missing(hl.tint32) + ) + ), + reference_entry_fields_to_keep=[], + ) vdses.append(mt) comb = combine_variant_datasets(vdses) assert comb.reference_data._force_count_rows() == 458646 @@ -454,6 +572,7 @@ def test_combiner_works(self): def test_haploid_combiner_ok(self): from hail.vds.combiner.combine import transform_gvcf + # make a combiner table mt = hl.utils.range_matrix_table(2, 1) mt = mt.annotate_cols(s='S01') @@ -475,6 +594,7 @@ def test_haploid_combiner_ok(self): def test_combiner_parse_allele_specific_annotations(self): from hail.vds.combiner.combine import parse_allele_specific_fields + infos = hl.array([ hl.struct( AS_QUALapprox="|1171|", @@ -482,14 +602,17 @@ def test_combiner_parse_allele_specific_annotations(self): AS_VarDP="0|57|0", AS_RAW_MQ="0.00|15100.00|0.00", AS_RAW_MQRankSum="|0.0,1|NaN", - AS_RAW_ReadPosRankSum="|0.7,1|NaN"), + AS_RAW_ReadPosRankSum="|0.7,1|NaN", + ), hl.struct( AS_QUALapprox="|1171|", AS_SB_TABLE="0,0|30,27|0,0", AS_VarDP="0|57|0", AS_RAW_MQ="0.00|15100.00|0.00", AS_RAW_MQRankSum="|NaN|NaN", - AS_RAW_ReadPosRankSum="|NaN|NaN")]) + AS_RAW_ReadPosRankSum="|NaN|NaN", + ), + ]) output = hl.eval(infos.map(lambda info: parse_allele_specific_fields(info, False))) expected = [ @@ -499,14 +622,17 @@ def test_combiner_parse_allele_specific_annotations(self): AS_VarDP=[0, 57, 0], AS_RAW_MQ=[0.00, 15100.00, 0.00], AS_RAW_MQRankSum=[None, (0.0, 1), None], - AS_RAW_ReadPosRankSum=[None, (0.7, 1), None]), + AS_RAW_ReadPosRankSum=[None, (0.7, 1), None], + ), hl.Struct( AS_QUALapprox=[None, 1171, None], AS_SB_TABLE=[[0, 0], [30, 27], [0, 0]], AS_VarDP=[0, 57, 0], AS_RAW_MQ=[0.00, 15100.00, 0.00], AS_RAW_MQRankSum=[None, None, None], - AS_RAW_ReadPosRankSum=[None, None, None])] + AS_RAW_ReadPosRankSum=[None, None, None], + ), + ] assert output == expected def test_flag_at_eol(self): @@ -515,11 +641,12 @@ def test_flag_at_eol(self): assert mt._force_count_rows() == 1 def test_missing_float_entries(self): - vcf = hl.import_vcf(resource('noglgp.vcf'), array_elements_required=False, - reference_genome='GRCh38') + vcf = hl.import_vcf(resource('noglgp.vcf'), array_elements_required=False, reference_genome='GRCh38') gl_gp = 
vcf.aggregate_entries(hl.agg.collect(hl.struct(GL=vcf.GL, GP=vcf.GP))) - assert gl_gp == [hl.Struct(GL=[None, None, None], GP=[0.22, 0.5, 0.27]), - hl.Struct(GL=[None, None, None], GP=[None, None, None])] + assert gl_gp == [ + hl.Struct(GL=[None, None, None], GP=[0.22, 0.5, 0.27]), + hl.Struct(GL=[None, None, None], GP=[None, None, None]), + ] def test_same_bgzip(self): mt = hl.import_vcf(resource('sample.vcf'), min_partitions=4) @@ -529,6 +656,7 @@ def test_same_bgzip(self): def test_vcf_parallel_separate_header_export(self): fs = hl.current_backend().fs + def concat_files(outpath, inpaths): with fs.open(outpath, 'wb') as outfile: for path in inpaths: @@ -541,8 +669,7 @@ def concat_files(outpath, inpaths): stat = fs.stat(f) assert stat assert stat.is_dir() - shard_paths = [info.path for info in fs.ls(f) - if os.path.splitext(info.path)[-1] == '.bgz'] + shard_paths = [info.path for info in fs.ls(f) if os.path.splitext(info.path)[-1] == '.bgz'] assert shard_paths shard_paths.sort() nf = new_temp_file(extension='vcf.bgz') @@ -625,11 +752,13 @@ def test_format_genotypes(self): with hl.current_backend().fs.open(f, 'r') as i: for line in i: if line.startswith('20\t13029920'): - expected = "GT:AD:DP:GQ:PL\t1/1:0,6:6:18:234,18,0\t1/1:0,4:4:12:159,12,0\t" \ - "1/1:0,4:4:12:163,12,0\t1/1:0,12:12:36:479,36,0\t" \ - "1/1:0,4:4:12:149,12,0\t1/1:0,6:6:18:232,18,0\t" \ - "1/1:0,6:6:18:242,18,0\t1/1:0,3:3:9:119,9,0\t1/1:0,9:9:27:374,27,0" \ - "\t./.:1,0:1\t1/1:0,3:3:9:133,9,0" + expected = ( + "GT:AD:DP:GQ:PL\t1/1:0,6:6:18:234,18,0\t1/1:0,4:4:12:159,12,0\t" + "1/1:0,4:4:12:163,12,0\t1/1:0,12:12:36:479,36,0\t" + "1/1:0,4:4:12:149,12,0\t1/1:0,6:6:18:232,18,0\t" + "1/1:0,6:6:18:242,18,0\t1/1:0,3:3:9:119,9,0\t1/1:0,9:9:27:374,27,0" + "\t./.:1,0:1\t1/1:0,3:3:9:133,9,0" + ) assert expected in line break else: @@ -653,7 +782,7 @@ def test_metadata_argument(self): metadata = { 'filter': {'LowQual': {'Description': 'Low quality'}}, 'format': {'GT': {'Description': 'Genotype call.', 'Number': 'foo'}}, - 'fakeField': {} + 'fakeField': {}, } hl.export_vcf(mt, f, metadata=metadata) @@ -677,7 +806,6 @@ def test_invalid_info_fields(self): t = new_temp_file(extension='vcf') mt = hl.import_vcf(resource('sample.vcf')) - with mock.patch("hail.methods.impex.warning", autospec=True) as warning: hl.export_vcf(mt, t) assert warning.call_count == 0 @@ -706,17 +834,25 @@ def test_import_fam(self): def test_export_import_plink_same(self): mt = get_dataset() - mt = mt.select_rows(rsid=hl.delimit([mt.locus.contig, hl.str(mt.locus.position), mt.alleles[0], mt.alleles[1]], ':'), - cm_position=15.0) - mt = mt.select_cols(fam_id=hl.missing(hl.tstr), pat_id=hl.missing(hl.tstr), mat_id=hl.missing(hl.tstr), - is_female=hl.missing(hl.tbool), is_case=hl.missing(hl.tbool)) + mt = mt.select_rows( + rsid=hl.delimit([mt.locus.contig, hl.str(mt.locus.position), mt.alleles[0], mt.alleles[1]], ':'), + cm_position=15.0, + ) + mt = mt.select_cols( + fam_id=hl.missing(hl.tstr), + pat_id=hl.missing(hl.tstr), + mat_id=hl.missing(hl.tstr), + is_female=hl.missing(hl.tbool), + is_case=hl.missing(hl.tbool), + ) mt = mt.select_entries('GT') bfile = new_temp_file(prefix='test_import_export_plink') hl.export_plink(mt, bfile, ind_id=mt.s, cm_position=mt.cm_position) - mt_imported = hl.import_plink(bfile + '.bed', bfile + '.bim', bfile + '.fam', - a2_reference=True, reference_genome='GRCh37', n_partitions=8) + mt_imported = hl.import_plink( + bfile + '.bed', bfile + '.bim', bfile + '.fam', a2_reference=True, reference_genome='GRCh37', n_partitions=8 + ) 
self.assertTrue(mt._same(mt_imported)) self.assertTrue(mt.aggregate_rows(hl.agg.all(mt.cm_position == 15.0))) @@ -740,24 +876,27 @@ def test_import_plink_a1_major(self): hl.export_plink(mt, bfile, ind_id=mt.s) def get_data(a2_reference): - mt_imported = hl.import_plink(bfile + '.bed', bfile + '.bim', - bfile + '.fam', a2_reference=a2_reference) - return (hl.variant_qc(mt_imported) - .rows() - .key_by('rsid')) + mt_imported = hl.import_plink(bfile + '.bed', bfile + '.bim', bfile + '.fam', a2_reference=a2_reference) + return hl.variant_qc(mt_imported).rows().key_by('rsid') a2 = get_data(a2_reference=True) a1 = get_data(a2_reference=False) - j = (a2.annotate(a1_alleles=a1[a2.rsid].alleles, a1_vqc=a1[a2.rsid].variant_qc) - .rename({'variant_qc': 'a2_vqc', 'alleles': 'a2_alleles'})) + j = a2.annotate(a1_alleles=a1[a2.rsid].alleles, a1_vqc=a1[a2.rsid].variant_qc).rename({ + 'variant_qc': 'a2_vqc', + 'alleles': 'a2_alleles', + }) - self.assertTrue(j.all((j.a1_alleles[0] == j.a2_alleles[1]) & - (j.a1_alleles[1] == j.a2_alleles[0]) & - (j.a1_vqc.n_not_called == j.a2_vqc.n_not_called) & - (j.a1_vqc.n_het == j.a2_vqc.n_het) & - (j.a1_vqc.homozygote_count[0] == j.a2_vqc.homozygote_count[1]) & - (j.a1_vqc.homozygote_count[1] == j.a2_vqc.homozygote_count[0]))) + self.assertTrue( + j.all( + (j.a1_alleles[0] == j.a2_alleles[1]) + & (j.a1_alleles[1] == j.a2_alleles[0]) + & (j.a1_vqc.n_not_called == j.a2_vqc.n_not_called) + & (j.a1_vqc.n_het == j.a2_vqc.n_het) + & (j.a1_vqc.homozygote_count[0] == j.a2_vqc.homozygote_count[1]) + & (j.a1_vqc.homozygote_count[1] == j.a2_vqc.homozygote_count[0]) + ) + ) def test_import_plink_same_locus(self): mt = hl.balding_nichols_model(n_populations=2, n_samples=10, n_variants=100) @@ -785,28 +924,30 @@ def test_import_plink_partitions(self): def test_import_plink_contig_recoding_w_reference(self): vcf = hl.split_multi_hts( - hl.import_vcf(resource('sample2.vcf'), - reference_genome=hl.get_reference('GRCh38'), - contig_recoding={"22": "chr22"})) + hl.import_vcf( + resource('sample2.vcf'), reference_genome=hl.get_reference('GRCh38'), contig_recoding={"22": "chr22"} + ) + ) bfile = new_temp_file(prefix='sample_plink') hl.export_plink(vcf, bfile) plink = hl.import_plink( - bfile + '.bed', bfile + '.bim', bfile + '.fam', + bfile + '.bed', + bfile + '.bim', + bfile + '.fam', a2_reference=True, contig_recoding={'chr22': '22'}, - reference_genome='GRCh37').rows() + reference_genome='GRCh37', + ).rows() self.assertTrue(plink.all(plink.locus.contig == "22")) self.assertEqual(vcf.count_rows(), plink.count()) self.assertTrue(plink.locus.dtype, hl.tlocus('GRCh37')) def test_import_plink_no_reference_specified(self): bfile = resource('fastlmmTest') - plink = hl.import_plink(bfile + '.bed', bfile + '.bim', bfile + '.fam', - reference_genome=None) - self.assertEqual(plink.locus.dtype, - hl.tstruct(contig=hl.tstr, position=hl.tint32)) + plink = hl.import_plink(bfile + '.bed', bfile + '.bim', bfile + '.fam', reference_genome=None) + self.assertEqual(plink.locus.dtype, hl.tstruct(contig=hl.tstr, position=hl.tint32)) def test_import_plink_and_ignore_rows(self): bfile = doctest_resource('ldsc') @@ -814,19 +955,24 @@ def test_import_plink_and_ignore_rows(self): self.assertEqual(plink.aggregate_cols(hl.agg.count()), 489) def test_import_plink_skip_invalid_loci(self): - mt = hl.import_plink(resource('skip_invalid_loci.bed'), - resource('skip_invalid_loci.bim'), - resource('skip_invalid_loci.fam'), - reference_genome='GRCh37', - skip_invalid_loci=True, - contig_recoding={'chr1': '1'}) + mt = 
hl.import_plink( + resource('skip_invalid_loci.bed'), + resource('skip_invalid_loci.bim'), + resource('skip_invalid_loci.fam'), + reference_genome='GRCh37', + skip_invalid_loci=True, + contig_recoding={'chr1': '1'}, + ) self.assertEqual(mt._force_count_rows(), 3) with self.assertRaisesRegex(FatalError, 'Invalid locus'): - (hl.import_plink(resource('skip_invalid_loci.bed'), - resource('skip_invalid_loci.bim'), - resource('skip_invalid_loci.fam')) - ._force_count_rows()) + ( + hl.import_plink( + resource('skip_invalid_loci.bed'), + resource('skip_invalid_loci.bim'), + resource('skip_invalid_loci.fam'), + )._force_count_rows() + ) @unittest.skipIf('HAIL_TEST_SKIP_PLINK' in os.environ, 'Skipping tests requiring plink') def test_export_plink(self): @@ -835,6 +981,7 @@ def test_export_plink(self): # permute columns so not in alphabetical order! import random + indices = list(range(mt.count_cols())) random.shuffle(indices) mt = mt.choose_cols(indices) @@ -854,9 +1001,16 @@ def test_export_plink(self): hl.hadoop_copy(hl_output + '.bim', local_hl_output + '.bim') hl.hadoop_copy(hl_output + '.fam', local_hl_output + '.fam') - run_command(["plink", "--vcf", local_split_vcf_file, - "--make-bed", "--out", plink_output, - "--const-fid", "--keep-allele-order"]) + run_command([ + "plink", + "--vcf", + local_split_vcf_file, + "--make-bed", + "--out", + plink_output, + "--const-fid", + "--keep-allele-order", + ]) data = [] with open(uri_path(plink_output + ".bim")) as file: @@ -868,9 +1022,17 @@ def test_export_plink(self): with open(plink_output + ".bim", 'w') as f: f.writelines(data) - run_command(["plink", "--bfile", plink_output, - "--bmerge", local_hl_output, "--merge-mode", - "6", "--out", merge_output]) + run_command([ + "plink", + "--bfile", + plink_output, + "--bmerge", + local_hl_output, + "--merge-mode", + "6", + "--out", + merge_output, + ]) same = True with open(merge_output + ".diff") as f: @@ -884,62 +1046,72 @@ def test_export_plink(self): def test_export_plink_default_arguments(self): ds = get_dataset() - fam_mapping = {'f0': 'fam_id', 'f1': 'ind_id', 'f2': 'pat_id', 'f3': 'mat_id', - 'f4': 'is_female', 'f5': 'pheno'} - bim_mapping = {'f0': 'contig', 'f1': 'varid', 'f2': 'cm_position', - 'f3': 'position', 'f4': 'a1', 'f5': 'a2'} + fam_mapping = {'f0': 'fam_id', 'f1': 'ind_id', 'f2': 'pat_id', 'f3': 'mat_id', 'f4': 'is_female', 'f5': 'pheno'} + bim_mapping = {'f0': 'contig', 'f1': 'varid', 'f2': 'cm_position', 'f3': 'position', 'f4': 'a1', 'f5': 'a2'} out1 = new_temp_file() hl.export_plink(ds, out1) - fam1 = (hl.import_table(out1 + '.fam', no_header=True, impute=False, missing="") - .rename(fam_mapping)) - bim1 = (hl.import_table(out1 + '.bim', no_header=True, impute=False) - .rename(bim_mapping)) + fam1 = hl.import_table(out1 + '.fam', no_header=True, impute=False, missing="").rename(fam_mapping) + bim1 = hl.import_table(out1 + '.bim', no_header=True, impute=False).rename(bim_mapping) - self.assertTrue(fam1.all((fam1.fam_id == "0") & (fam1.pat_id == "0") & - (fam1.mat_id == "0") & (fam1.is_female == "0") & - (fam1.pheno == "NA"))) - self.assertTrue(bim1.all((bim1.varid == bim1.contig + ":" + bim1.position + ":" + bim1.a2 + ":" + bim1.a1) & - (bim1.cm_position == "0.0"))) + self.assertTrue( + fam1.all( + (fam1.fam_id == "0") + & (fam1.pat_id == "0") + & (fam1.mat_id == "0") + & (fam1.is_female == "0") + & (fam1.pheno == "NA") + ) + ) + self.assertTrue( + bim1.all( + (bim1.varid == bim1.contig + ":" + bim1.position + ":" + bim1.a2 + ":" + bim1.a1) + & (bim1.cm_position == "0.0") + ) + ) def 
test_export_plink_non_default_arguments(self): ds = get_dataset() - fam_mapping = {'f0': 'fam_id', 'f1': 'ind_id', 'f2': 'pat_id', 'f3': 'mat_id', - 'f4': 'is_female', 'f5': 'pheno'} + fam_mapping = {'f0': 'fam_id', 'f1': 'ind_id', 'f2': 'pat_id', 'f3': 'mat_id', 'f4': 'is_female', 'f5': 'pheno'} out2 = new_temp_file() - hl.export_plink(ds, out2, ind_id=ds.s, fam_id=ds.s, pat_id="nope", - mat_id="nada", is_female=True, pheno=False) - fam2 = (hl.import_table(out2 + '.fam', no_header=True, impute=False, missing="") - .rename(fam_mapping)) + hl.export_plink(ds, out2, ind_id=ds.s, fam_id=ds.s, pat_id="nope", mat_id="nada", is_female=True, pheno=False) + fam2 = hl.import_table(out2 + '.fam', no_header=True, impute=False, missing="").rename(fam_mapping) - self.assertTrue(fam2.all((fam2.fam_id == fam2.ind_id) & (fam2.pat_id == "nope") & - (fam2.mat_id == "nada") & (fam2.is_female == "2") & - (fam2.pheno == "1"))) + self.assertTrue( + fam2.all( + (fam2.fam_id == fam2.ind_id) + & (fam2.pat_id == "nope") + & (fam2.mat_id == "nada") + & (fam2.is_female == "2") + & (fam2.pheno == "1") + ) + ) def test_export_plink_quantitative_phenotype(self): ds = get_dataset() - fam_mapping = {'f0': 'fam_id', 'f1': 'ind_id', 'f2': 'pat_id', 'f3': 'mat_id', - 'f4': 'is_female', 'f5': 'pheno'} - bim_mapping = {'f0': 'contig', 'f1': 'varid', 'f2': 'cm_position', - 'f3': 'position', 'f4': 'a1', 'f5': 'a2'} + fam_mapping = {'f0': 'fam_id', 'f1': 'ind_id', 'f2': 'pat_id', 'f3': 'mat_id', 'f4': 'is_female', 'f5': 'pheno'} out3 = new_temp_file() hl.export_plink(ds, out3, ind_id=ds.s, pheno=hl.float64(hl.len(ds.s))) - fam3 = (hl.import_table(out3 + '.fam', no_header=True, impute=False, missing="") - .rename(fam_mapping)) + fam3 = hl.import_table(out3 + '.fam', no_header=True, impute=False, missing="").rename(fam_mapping) - self.assertTrue(fam3.all((fam3.fam_id == "0") & (fam3.pat_id == "0") & - (fam3.mat_id == "0") & (fam3.is_female == "0") & - (fam3.pheno != "0") & (fam3.pheno != "NA"))) + self.assertTrue( + fam3.all( + (fam3.fam_id == "0") + & (fam3.pat_id == "0") + & (fam3.mat_id == "0") + & (fam3.is_female == "0") + & (fam3.pheno != "0") + & (fam3.pheno != "NA") + ) + ) def test_export_plink_non_default_bim_arguments(self): ds = get_dataset() - bim_mapping = {'f0': 'contig', 'f1': 'varid', 'f2': 'cm_position', - 'f3': 'position', 'f4': 'a1', 'f5': 'a2'} + bim_mapping = {'f0': 'contig', 'f1': 'varid', 'f2': 'cm_position', 'f3': 'position', 'f4': 'a1', 'f5': 'a2'} out4 = new_temp_file() hl.export_plink(ds, out4, varid="hello", cm_position=100) - bim4 = (hl.import_table(out4 + '.bim', no_header=True, impute=False) - .rename(bim_mapping)) + bim4 = hl.import_table(out4 + '.bim', no_header=True, impute=False).rename(bim_mapping) self.assertTrue(bim4.all((bim4.varid == "hello") & (bim4.cm_position == "100.0"))) @@ -963,38 +1135,45 @@ def test_export_plink_white_space_in_varid_raises_error(self): hl.export_plink(ds, new_temp_file(), varid="hello world") def test_contig_recoding_defaults(self): - hl.import_plink(resource('sex_mt_contigs.bed'), - resource('sex_mt_contigs.bim'), - resource('sex_mt_contigs.fam'), - reference_genome='GRCh37') + hl.import_plink( + resource('sex_mt_contigs.bed'), + resource('sex_mt_contigs.bim'), + resource('sex_mt_contigs.fam'), + reference_genome='GRCh37', + ) - hl.import_plink(resource('sex_mt_contigs.bed'), - resource('sex_mt_contigs.bim'), - resource('sex_mt_contigs.fam'), - reference_genome='GRCh38') + hl.import_plink( + resource('sex_mt_contigs.bed'), + resource('sex_mt_contigs.bim'), + 
resource('sex_mt_contigs.fam'), + reference_genome='GRCh38', + ) - rg_random = hl.ReferenceGenome("random", ['1', '23', '24', '25', '26'], - {'1': 10, '23': 10, '24': 10, '25': 10, '26': 10}) + hl.ReferenceGenome("random", ['1', '23', '24', '25', '26'], {'1': 10, '23': 10, '24': 10, '25': 10, '26': 10}) - hl.import_plink(resource('sex_mt_contigs.bed'), - resource('sex_mt_contigs.bim'), - resource('sex_mt_contigs.fam'), - reference_genome='random') + hl.import_plink( + resource('sex_mt_contigs.bed'), + resource('sex_mt_contigs.bim'), + resource('sex_mt_contigs.fam'), + reference_genome='random', + ) def test_export_plink_struct_locus(self): mt = hl.utils.range_matrix_table(10, 10) - mt = mt.key_rows_by(locus=hl.struct(contig=hl.str(mt.row_idx), position=mt.row_idx), alleles=['A', 'T']).select_rows() + mt = mt.key_rows_by( + locus=hl.struct(contig=hl.str(mt.row_idx), position=mt.row_idx), alleles=['A', 'T'] + ).select_rows() mt = mt.key_cols_by(s=hl.str(mt.col_idx)).select_cols() mt = mt.annotate_entries(GT=hl.call(0, 0)) out = new_temp_file() hl.export_plink(mt, out) - mt2 = hl.import_plink( - bed=out + '.bed', - bim=out + '.bim', - fam=out + '.fam', - reference_genome=None).select_rows().select_cols() + mt2 = ( + hl.import_plink(bed=out + '.bed', bim=out + '.bim', fam=out + '.fam', reference_genome=None) + .select_rows() + .select_cols() + ) assert mt._same(mt2) @@ -1016,17 +1195,14 @@ def test_export_plink_struct_locus(self): # qctool -g random-c-disjoint.gen -s random.sample -og random-c-disjoint.bgen -bgen-bits 8 def generate_random_gen(): mt = hl.utils.range_matrix_table(30, 10) - mt = (mt.annotate_rows(locus = hl.locus('20', mt.row_idx + 1), - alleles = ['A', 'G']) - .key_rows_by('locus', 'alleles')) - mt = (mt.annotate_cols(s = hl.str(mt.col_idx)) - .key_cols_by('s')) + mt = mt.annotate_rows(locus=hl.locus('20', mt.row_idx + 1), alleles=['A', 'G']).key_rows_by('locus', 'alleles') + mt = mt.annotate_cols(s=hl.str(mt.col_idx)).key_cols_by('s') # using totally random values leads rounding differences where # identical GEN values get rounded differently, leading to # differences in the GT call between import_{gen, bgen} - mt = mt.annotate_entries(a = hl.int32(hl.rand_unif(0.0, 255.0))) - mt = mt.annotate_entries(b = hl.int32(hl.rand_unif(0.0, 255.0 - mt.a))) - mt = mt.transmute_entries(GP = hl.array([mt.a, mt.b, 255.0 - mt.a - mt.b]) / 255.0) + mt = mt.annotate_entries(a=hl.int32(hl.rand_unif(0.0, 255.0))) + mt = mt.annotate_entries(b=hl.int32(hl.rand_unif(0.0, 255.0 - mt.a))) + mt = mt.transmute_entries(GP=hl.array([mt.a, mt.b, 255.0 - mt.a - mt.b]) / 255.0) # 20% missing mt = mt.filter_entries(hl.rand_bool(0.8)) hl.export_gen(mt, 'random', precision=4) @@ -1044,37 +1220,36 @@ def test_error_if_no_gp(self): hl.export_gen(mt, tmp_path) def test_import_bgen_dosage_entry(self): - bgen = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['dosage']) + bgen = hl.import_bgen(resource('example.8bits.bgen'), entry_fields=['dosage']) self.assertEqual(bgen.entry.dtype, hl.tstruct(dosage=hl.tfloat64)) self.assertEqual(bgen.count_rows(), 199) self.assertEqual(bgen._force_count_rows(), 199) def test_import_bgen_GT_GP_entries(self): - bgen = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['GT', 'GP'], - sample_file=resource('example.sample')) + bgen = hl.import_bgen( + resource('example.8bits.bgen'), entry_fields=['GT', 'GP'], sample_file=resource('example.sample') + ) self.assertEqual(bgen.entry.dtype, hl.tstruct(GT=hl.tcall, GP=hl.tarray(hl.tfloat64))) def 
test_import_bgen_no_entries(self): - bgen = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=[], - sample_file=resource('example.sample')) + bgen = hl.import_bgen(resource('example.8bits.bgen'), entry_fields=[], sample_file=resource('example.sample')) self.assertEqual(bgen.entry.dtype, hl.tstruct()) def test_import_bgen_no_reference(self): - bgen = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['GT', 'GP', 'dosage'], - index_file_map={resource('example.8bits.bgen'): resource('example.8bits.bgen-NO-REFERENCE-GENOME.idx2')}) + bgen = hl.import_bgen( + resource('example.8bits.bgen'), + entry_fields=['GT', 'GP', 'dosage'], + index_file_map={resource('example.8bits.bgen'): resource('example.8bits.bgen-NO-REFERENCE-GENOME.idx2')}, + ) assert bgen.locus.dtype == hl.tstruct(contig=hl.tstr, position=hl.tint32) assert bgen.count_rows() == 199 def test_import_bgen_skip_invalid_loci_does_not_error_with_invalid_loci(self): # Note: the skip_invalid_loci.bgen has 16-bit probabilities, and Hail # will crash if the genotypes are decoded - mt = hl.import_bgen(resource('skip_invalid_loci.bgen'), - entry_fields=[], - sample_file=resource('skip_invalid_loci.sample')) + mt = hl.import_bgen( + resource('skip_invalid_loci.bgen'), entry_fields=[], sample_file=resource('skip_invalid_loci.sample') + ) assert mt.rows().count() == 3 def test_import_bgen_errors_with_invalid_loci(self): @@ -1082,84 +1257,61 @@ def test_import_bgen_errors_with_invalid_loci(self): hl.current_backend().fs.copy(resource('skip_invalid_loci.bgen'), f) with pytest.raises(FatalError, match='Invalid locus'): hl.index_bgen(f) - mt = hl.import_bgen(f, - entry_fields=[], - sample_file=resource('skip_invalid_loci.sample')) + mt = hl.import_bgen(f, entry_fields=[], sample_file=resource('skip_invalid_loci.sample')) mt.rows().count() def test_import_bgen_gavin_example(self): recoding = {'0{}'.format(i): str(i) for i in range(1, 10)} sample_file = resource('example.sample') - genmt = hl.import_gen(resource('example.gen'), sample_file, - contig_recoding=recoding, - reference_genome="GRCh37") + genmt = hl.import_gen(resource('example.gen'), sample_file, contig_recoding=recoding, reference_genome="GRCh37") bgen_file = resource('example.8bits.bgen') bgenmt = hl.import_bgen(bgen_file, ['GT', 'GP'], sample_file) - self.assertTrue( - bgenmt._same(genmt, tolerance=1.0 / 255, absolute=True)) + self.assertTrue(bgenmt._same(genmt, tolerance=1.0 / 255, absolute=True)) def test_import_bgen_random(self): sample_file = resource('random.sample') genmt = hl.import_gen(resource('random.gen'), sample_file) bgenmt = hl.import_bgen(resource('random.bgen'), ['GT', 'GP'], sample_file) - self.assertTrue( - bgenmt._same(genmt, tolerance=1.0 / 255, absolute=True)) + self.assertTrue(bgenmt._same(genmt, tolerance=1.0 / 255, absolute=True)) def test_parallel_import(self): - mt = hl.import_bgen(resource('parallelBgenExport.bgen'), - ['GT', 'GP'], - resource('parallelBgenExport.sample')) + mt = hl.import_bgen(resource('parallelBgenExport.bgen'), ['GT', 'GP'], resource('parallelBgenExport.sample')) self.assertEqual(mt.count(), (16, 10)) def test_import_bgen_dosage_and_gp_dosage_function_agree(self): - recoding = {'0{}'.format(i): str(i) for i in range(1, 10)} - sample_file = resource('example.sample') bgen_file = resource('example.8bits.bgen') bgenmt = hl.import_bgen(bgen_file, ['GP', 'dosage'], sample_file) et = bgenmt.entries() - et = et.transmute(gp_dosage = hl.gp_dosage(et.GP)) - self.assertTrue(et.all( - (hl.is_missing(et.dosage) & 
hl.is_missing(et.gp_dosage)) | - (hl.abs(et.dosage - et.gp_dosage) < 1e-6))) + et = et.transmute(gp_dosage=hl.gp_dosage(et.GP)) + self.assertTrue( + et.all((hl.is_missing(et.dosage) & hl.is_missing(et.gp_dosage)) | (hl.abs(et.dosage - et.gp_dosage) < 1e-6)) + ) def test_import_bgen_row_fields(self): - default_row_fields = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['dosage']) - self.assertEqual(default_row_fields.row.dtype, - hl.tstruct(locus=hl.tlocus('GRCh37'), - alleles=hl.tarray(hl.tstr), - rsid=hl.tstr, - varid=hl.tstr)) - no_row_fields = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['dosage'], - _row_fields=[]) - self.assertEqual(no_row_fields.row.dtype, - hl.tstruct(locus=hl.tlocus('GRCh37'), - alleles=hl.tarray(hl.tstr))) - varid_only = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['dosage'], - _row_fields=['varid']) - self.assertEqual(varid_only.row.dtype, - hl.tstruct(locus=hl.tlocus('GRCh37'), - alleles=hl.tarray(hl.tstr), - varid=hl.tstr)) - rsid_only = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['dosage'], - _row_fields=['rsid']) - self.assertEqual(rsid_only.row.dtype, - hl.tstruct(locus=hl.tlocus('GRCh37'), - alleles=hl.tarray(hl.tstr), - rsid=hl.tstr)) + default_row_fields = hl.import_bgen(resource('example.8bits.bgen'), entry_fields=['dosage']) + self.assertEqual( + default_row_fields.row.dtype, + hl.tstruct(locus=hl.tlocus('GRCh37'), alleles=hl.tarray(hl.tstr), rsid=hl.tstr, varid=hl.tstr), + ) + no_row_fields = hl.import_bgen(resource('example.8bits.bgen'), entry_fields=['dosage'], _row_fields=[]) + self.assertEqual(no_row_fields.row.dtype, hl.tstruct(locus=hl.tlocus('GRCh37'), alleles=hl.tarray(hl.tstr))) + varid_only = hl.import_bgen(resource('example.8bits.bgen'), entry_fields=['dosage'], _row_fields=['varid']) + self.assertEqual( + varid_only.row.dtype, hl.tstruct(locus=hl.tlocus('GRCh37'), alleles=hl.tarray(hl.tstr), varid=hl.tstr) + ) + rsid_only = hl.import_bgen(resource('example.8bits.bgen'), entry_fields=['dosage'], _row_fields=['rsid']) + self.assertEqual( + rsid_only.row.dtype, hl.tstruct(locus=hl.tlocus('GRCh37'), alleles=hl.tarray(hl.tstr), rsid=hl.tstr) + ) self.assertTrue(default_row_fields.drop('varid')._same(rsid_only)) self.assertTrue(default_row_fields.drop('rsid')._same(varid_only)) - self.assertTrue( - default_row_fields.drop('varid', 'rsid')._same(no_row_fields)) + self.assertTrue(default_row_fields.drop('varid', 'rsid')._same(no_row_fields)) def test_import_bgen_variant_filtering_from_literals(self): bgen_file = resource('example.8bits.bgen') @@ -1179,20 +1331,24 @@ def test_import_bgen_variant_filtering_from_literals(self): hl.Struct(locus=hl.Locus('1', 2001), alleles=alleles), hl.Struct(locus=hl.Locus('1', 4000), alleles=alleles), hl.Struct(locus=hl.Locus('1', 10000), alleles=alleles), - hl.Struct(locus=hl.Locus('1', 10000), alleles=alleles), # Duplicated variant + hl.Struct(locus=hl.Locus('1', 10000), alleles=alleles), # Duplicated variant hl.Struct(locus=hl.Locus('1', 100001), alleles=alleles), ] - part_1 = hl.import_bgen(bgen_file, - ['GT'], - n_partitions=1, # forcing seek to be called - variants=desired_variants) + part_1 = hl.import_bgen( + bgen_file, + ['GT'], + n_partitions=1, + variants=desired_variants, # forcing seek to be called + ) self.assertEqual(part_1.rows().key_by('locus', 'alleles').select().collect(), expected_result) - part_199 = hl.import_bgen(bgen_file, - ['GT'], - n_partitions=199, # forcing each variant to be its own partition for testing duplicates 
work properly - variants=desired_variants) + part_199 = hl.import_bgen( + bgen_file, + ['GT'], + n_partitions=199, # forcing each variant to be its own partition for testing duplicates work properly + variants=desired_variants, + ) self.assertEqual(part_199.rows().key_by('locus', 'alleles').select().collect(), expected_result) everything = hl.import_bgen(bgen_file, ['GT']) @@ -1210,12 +1366,10 @@ def test_import_bgen_locus_filtering_from_Struct_object(self): expected_result = [ hl.Struct(locus=hl.Locus('1', 10000), alleles=['A', 'G']), - hl.Struct(locus=hl.Locus('1', 10000), alleles=['A', 'G']) # Duplicated variant + hl.Struct(locus=hl.Locus('1', 10000), alleles=['A', 'G']), # Duplicated variant ] - data = hl.import_bgen(bgen_file, - ['GT'], - variants=desired_loci) + data = hl.import_bgen(bgen_file, ['GT'], variants=desired_loci) assert data.rows().key_by('locus', 'alleles').select().collect() == expected_result def test_import_bgen_locus_filtering_from_struct_expression(self): @@ -1226,12 +1380,10 @@ def test_import_bgen_locus_filtering_from_struct_expression(self): expected_result = [ hl.Struct(locus=hl.Locus('1', 10000), alleles=['A', 'G']), - hl.Struct(locus=hl.Locus('1', 10000), alleles=['A', 'G']) # Duplicated variant + hl.Struct(locus=hl.Locus('1', 10000), alleles=['A', 'G']), # Duplicated variant ] - data = hl.import_bgen(bgen_file, - ['GT'], - variants=desired_loci) + data = hl.import_bgen(bgen_file, ['GT'], variants=desired_loci) assert data.rows().key_by('locus', 'alleles').select().collect() == expected_result def test_import_bgen_variant_filtering_from_exprs(self): @@ -1242,10 +1394,9 @@ def test_import_bgen_variant_filtering_from_exprs(self): desired_variants = hl.struct(locus=everything.locus, alleles=everything.alleles) - actual = hl.import_bgen(bgen_file, - ['GT'], - n_partitions=10, - variants=desired_variants) # filtering with everything + actual = hl.import_bgen( + bgen_file, ['GT'], n_partitions=10, variants=desired_variants + ) # filtering with everything self.assertTrue(everything._same(actual)) @@ -1255,15 +1406,11 @@ def test_import_bgen_locus_filtering_from_exprs(self): everything = hl.import_bgen(bgen_file, ['GT']) self.assertEqual(everything.count(), (199, 500)) - actual_struct = hl.import_bgen(bgen_file, - ['GT'], - variants=hl.struct(locus=everything.locus)) + actual_struct = hl.import_bgen(bgen_file, ['GT'], variants=hl.struct(locus=everything.locus)) self.assertTrue(everything._same(actual_struct)) - actual_locus = hl.import_bgen(bgen_file, - ['GT'], - variants=everything.locus) + actual_locus = hl.import_bgen(bgen_file, ['GT'], variants=everything.locus) self.assertTrue(everything._same(actual_locus)) @@ -1275,39 +1422,32 @@ def test_import_bgen_variant_filtering_from_table(self): desired_variants = everything.rows() - actual = hl.import_bgen(bgen_file, - ['GT'], - n_partitions=10, - variants=desired_variants) # filtering with everything + actual = hl.import_bgen( + bgen_file, ['GT'], n_partitions=10, variants=desired_variants + ) # filtering with everything self.assertTrue(everything._same(actual)) def test_import_bgen_locus_filtering_from_table(self): bgen_file = resource('example.8bits.bgen') - desired_loci = hl.Table.parallelize([{'locus': hl.Locus('1', 10000)}], - schema=hl.tstruct(locus=hl.tlocus()), - key='locus') + desired_loci = hl.Table.parallelize( + [{'locus': hl.Locus('1', 10000)}], schema=hl.tstruct(locus=hl.tlocus()), key='locus' + ) expected_result = [ hl.Struct(locus=hl.Locus('1', 10000), alleles=['A', 'G']), - 
hl.Struct(locus=hl.Locus('1', 10000), alleles=['A', 'G']) # Duplicated variant + hl.Struct(locus=hl.Locus('1', 10000), alleles=['A', 'G']), # Duplicated variant ] - result = hl.import_bgen(bgen_file, - ['GT'], - variants=desired_loci) + result = hl.import_bgen(bgen_file, ['GT'], variants=desired_loci) - self.assertEqual(result.rows().key_by('locus', 'alleles').select().collect(), - expected_result) + self.assertEqual(result.rows().key_by('locus', 'alleles').select().collect(), expected_result) def test_import_bgen_empty_variant_filter(self): bgen_file = resource('example.8bits.bgen') - actual = hl.import_bgen(bgen_file, - ['GT'], - n_partitions=10, - variants=[]) + actual = hl.import_bgen(bgen_file, ['GT'], n_partitions=10, variants=[]) self.assertEqual(actual.count_rows(), 0) nothing = hl.import_bgen(bgen_file, ['GT']).filter_rows(False) @@ -1315,22 +1455,16 @@ def test_import_bgen_empty_variant_filter(self): desired_variants = hl.struct(locus=nothing.locus, alleles=nothing.alleles) - actual = hl.import_bgen(bgen_file, - ['GT'], - n_partitions=10, - variants=desired_variants) + actual = hl.import_bgen(bgen_file, ['GT'], n_partitions=10, variants=desired_variants) self.assertEqual(actual.count_rows(), 0) # FIXME testing block_size (in MB) requires large BGEN def test_n_partitions(self): - bgen = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['dosage'], - n_partitions=210) - self.assertEqual(bgen.n_partitions(), 199) # only 199 variants in the file + bgen = hl.import_bgen(resource('example.8bits.bgen'), entry_fields=['dosage'], n_partitions=210) + self.assertEqual(bgen.n_partitions(), 199) # only 199 variants in the file def test_drop(self): - bgen = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['dosage']) + bgen = hl.import_bgen(resource('example.8bits.bgen'), entry_fields=['dosage']) dr = bgen.filter_rows(False) self.assertEqual(dr._force_count_rows(), 0) @@ -1342,9 +1476,9 @@ def test_drop(self): def test_index_multiple_bgen_files_does_not_fail_and_is_importable(self): original_bgen_files = [resource('random-b.bgen'), resource('random-c.bgen'), resource('random-a.bgen')] - with hl.TemporaryFilename(suffix='.bgen') as f, \ - hl.TemporaryFilename(suffix='.bgen') as g, \ - hl.TemporaryFilename(suffix='.bgen') as h: + with hl.TemporaryFilename(suffix='.bgen') as f, hl.TemporaryFilename(suffix='.bgen') as g, hl.TemporaryFilename( + suffix='.bgen' + ) as h: newly_indexed_bgen_files = [f, g, h] for source, temp in zip(original_bgen_files, newly_indexed_bgen_files): hl.current_backend().fs.copy(source, temp) @@ -1370,14 +1504,10 @@ def test_multiple_files_variant_filtering(self): hl.Struct(locus=hl.Locus('20', 12), alleles=alleles), ] - actual = hl.import_bgen(bgen_file, - ['GT'], - n_partitions=10, - variants=desired_variants) + actual = hl.import_bgen(bgen_file, ['GT'], n_partitions=10, variants=desired_variants) assert actual.count_rows() == 6 - everything = hl.import_bgen(bgen_file, - ['GT']) + everything = hl.import_bgen(bgen_file, ['GT']) assert everything.count() == (30, 10) expected = everything.filter_rows(hl.set(desired_variants).contains(everything.row_key)) @@ -1386,8 +1516,14 @@ def test_multiple_files_variant_filtering(self): def test_multiple_files_disjoint(self): sample_file = resource('random.sample') - bgen_file = [resource('random-b-disjoint.bgen'), resource('random-c-disjoint.bgen'), resource('random-a-disjoint.bgen')] - with pytest.raises(FatalError, match='Each BGEN file must contain a region of the genome disjoint from other files'): + 
bgen_file = [ + resource('random-b-disjoint.bgen'), + resource('random-c-disjoint.bgen'), + resource('random-a-disjoint.bgen'), + ] + with pytest.raises( + FatalError, match='Each BGEN file must contain a region of the genome disjoint from other files' + ): hl.import_bgen(bgen_file, ['GT', 'GP'], sample_file, n_partitions=3) def test_multiple_references_throws_error(self): @@ -1396,13 +1532,15 @@ def test_multiple_references_throws_error(self): bgen_file2 = resource('random-c.bgen') with pytest.raises(FatalError, match='Found multiple reference genomes were specified in the BGEN index files'): - hl.import_bgen([bgen_file1, bgen_file2], - ['GT'], - sample_file=sample_file, - index_file_map={ - resource('random-b.bgen'): resource('random-b.bgen-NO-REFERENCE-GENOME.idx2'), - resource('random-c.bgen'): resource('random-c.bgen.idx2'), - }) + hl.import_bgen( + [bgen_file1, bgen_file2], + ['GT'], + sample_file=sample_file, + index_file_map={ + resource('random-b.bgen'): resource('random-b.bgen-NO-REFERENCE-GENOME.idx2'), + resource('random-c.bgen'): resource('random-c.bgen.idx2'), + }, + ) def test_old_index_file_throws_error(self): sample_file = resource('random.sample') @@ -1410,14 +1548,17 @@ def test_old_index_file_throws_error(self): with hl.TemporaryFilename() as f: hl.current_backend().fs.copy(bgen_file, f) - with pytest.raises(FatalError, match='have no .idx2 index file'): + + expected_missing_idx2_error_message = re.compile(f'have no .idx2 index file.*{f}.*', re.DOTALL) + + with pytest.raises(FatalError, match=expected_missing_idx2_error_message): hl.import_bgen(f, ['GT', 'GP'], sample_file, n_partitions=3) try: with hl.current_backend().fs.open(f + '.idx', 'wb') as fobj: fobj.write(b'') - with pytest.raises(FatalError, match='have no .idx2 index file'): + with pytest.raises(FatalError, match=expected_missing_idx2_error_message): hl.import_bgen(f, ['GT', 'GP'], sample_file) finally: hl.current_backend().fs.remove(f + '.idx') @@ -1428,12 +1569,8 @@ def test_specify_different_index_file(self): with hl.TemporaryDirectory(suffix='.idx2', ensure_exists=False) as index_file: index_file_map = {bgen_file: index_file} - hl.index_bgen(bgen_file, - index_file_map=index_file_map) - mt = hl.import_bgen(bgen_file, - ['GT', 'GP'], - sample_file, - index_file_map=index_file_map) + hl.index_bgen(bgen_file, index_file_map=index_file_map) + mt = hl.import_bgen(bgen_file, ['GT', 'GP'], sample_file, index_file_map=index_file_map) assert mt.count() == (30, 10) def test_index_bgen_errors_when_index_file_has_wrong_extension(self): @@ -1445,53 +1582,44 @@ def test_index_bgen_errors_when_index_file_has_wrong_extension(self): hl.index_bgen(bgen_file, index_file_map=index_file_map) def test_export_bgen(self): - bgen = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['GP'], - sample_file=resource('example.sample')) + bgen = hl.import_bgen( + resource('example.8bits.bgen'), entry_fields=['GP'], sample_file=resource('example.sample') + ) with hl.TemporaryDirectory(ensure_exists=False) as tmpdir: tmp = tmpdir + '/dataset' hl.export_bgen(bgen, tmp) hl.index_bgen(tmp + '.bgen') - bgen2 = hl.import_bgen(tmp + '.bgen', - entry_fields=['GP'], - sample_file=tmp + '.sample') + bgen2 = hl.import_bgen(tmp + '.bgen', entry_fields=['GP'], sample_file=tmp + '.sample') assert bgen._same(bgen2) def test_export_bgen_zstd(self): - bgen = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['GP'], - sample_file=resource('example.sample')) + bgen = hl.import_bgen( + resource('example.8bits.bgen'), 
entry_fields=['GP'], sample_file=resource('example.sample') + ) with hl.TemporaryDirectory(prefix='zstd', ensure_exists=False) as tmpdir: tmp = tmpdir + '/dataset' hl.export_bgen(bgen, tmp, compression_codec='zstd') hl.index_bgen(tmp + '.bgen') - bgen2 = hl.import_bgen(tmp + '.bgen', - entry_fields=['GP'], - sample_file=tmp + '.sample') + bgen2 = hl.import_bgen(tmp + '.bgen', entry_fields=['GP'], sample_file=tmp + '.sample') assert bgen._same(bgen2) def test_export_bgen_parallel(self): - bgen = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['GP'], - sample_file=resource('example.sample'), - n_partitions=3) + bgen = hl.import_bgen( + resource('example.8bits.bgen'), entry_fields=['GP'], sample_file=resource('example.sample'), n_partitions=3 + ) with hl.TemporaryDirectory(ensure_exists=False) as tmpdir: tmp = tmpdir + '/dataset' hl.export_bgen(bgen, tmp, parallel='header_per_shard') hl.index_bgen(tmp + '.bgen') - bgen2 = hl.import_bgen(tmp + '.bgen', - entry_fields=['GP'], - sample_file=tmp + '.sample') + bgen2 = hl.import_bgen(tmp + '.bgen', entry_fields=['GP'], sample_file=tmp + '.sample') assert bgen._same(bgen2) fs = hl.current_backend().fs with fs.open(f'{tmp}.bgen/shard-manifest.txt') as lines: manifest_files = [os.path.join(f'{tmp}.bgen/', line.strip()) for line in lines] - bgen3 = hl.import_bgen(manifest_files, - entry_fields=['GP'], - sample_file=tmp + '.sample') + bgen3 = hl.import_bgen(manifest_files, entry_fields=['GP'], sample_file=tmp + '.sample') assert bgen._same(bgen3) def test_export_bgen_from_vcf(self): @@ -1499,19 +1627,22 @@ def test_export_bgen_from_vcf(self): with hl.TemporaryDirectory(ensure_exists=False) as tmpdir: tmp = tmpdir + '/dataset' - hl.export_bgen(mt, tmp, - gp=hl.or_missing( - hl.is_defined(mt.GT), - hl.map(lambda i: hl.if_else(mt.GT.unphased_diploid_gt_index() == i, 1.0, 0.0), - hl.range(0, hl.triangle(hl.len(mt.alleles)))))) - + hl.export_bgen( + mt, + tmp, + gp=hl.or_missing( + hl.is_defined(mt.GT), + hl.map( + lambda i: hl.if_else(mt.GT.unphased_diploid_gt_index() == i, 1.0, 0.0), + hl.range(0, hl.triangle(hl.len(mt.alleles))), + ), + ), + ) hl.index_bgen(tmp + '.bgen') - bgen2 = hl.import_bgen(tmp + '.bgen', - entry_fields=['GT'], - sample_file=tmp + '.sample') + bgen2 = hl.import_bgen(tmp + '.bgen', entry_fields=['GT'], sample_file=tmp + '.sample') mt = mt.select_entries('GT').select_rows().select_cols() - bgen2 = bgen2.unfilter_entries().select_rows() # drop varid, rsid + bgen2 = bgen2.unfilter_entries().select_rows() # drop varid, rsid assert bgen2._same(mt) def test_randomness(self): @@ -1525,18 +1656,19 @@ def test_randomness(self): hl.Struct(locus=hl.Locus('1', 100001), alleles=alleles), ] - bgen1 = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['GT'], - sample_file=resource('example.sample'), - n_partitions=3) + bgen1 = hl.import_bgen( + resource('example.8bits.bgen'), entry_fields=['GT'], sample_file=resource('example.sample'), n_partitions=3 + ) bgen1 = bgen1.filter_rows(hl.literal(desired_variants).contains(bgen1.row_key)) c1 = bgen1.filter_entries(hl.rand_bool(0.2, seed=1234)) - bgen2 = hl.import_bgen(resource('example.8bits.bgen'), - entry_fields=['GT'], - sample_file=resource('example.sample'), - n_partitions=5, - variants=desired_variants) + bgen2 = hl.import_bgen( + resource('example.8bits.bgen'), + entry_fields=['GT'], + sample_file=resource('example.sample'), + n_partitions=5, + variants=desired_variants, + ) c2 = bgen2.filter_entries(hl.rand_bool(0.2, seed=1234)) assert c1._same(c2) @@ -1544,84 
+1676,104 @@ def test_randomness(self): class GENTests(unittest.TestCase): def test_import_gen(self): - gen = hl.import_gen(resource('example.gen'), - sample_file=resource('example.sample'), - contig_recoding={"01": "1"}, - reference_genome = 'GRCh37').rows() + gen = hl.import_gen( + resource('example.gen'), + sample_file=resource('example.sample'), + contig_recoding={"01": "1"}, + reference_genome='GRCh37', + ).rows() self.assertTrue(gen.all(gen.locus.contig == "1")) self.assertEqual(gen.count(), 199) self.assertEqual(gen.locus.dtype, hl.tlocus('GRCh37')) def test_import_gen_no_chromosome_in_file(self): - gen = hl.import_gen(resource('no_chromosome.gen'), - resource('skip_invalid_loci.sample'), - chromosome="1", - reference_genome=None, - skip_invalid_loci=True) + gen = hl.import_gen( + resource('no_chromosome.gen'), + resource('skip_invalid_loci.sample'), + chromosome="1", + reference_genome=None, + skip_invalid_loci=True, + ) self.assertEqual(gen.aggregate_rows(hl.agg.all(gen.locus.contig == "1")), True) def test_import_gen_no_reference_specified(self): - gen = hl.import_gen(resource('example.gen'), - sample_file=resource('example.sample'), - reference_genome=None) + gen = hl.import_gen(resource('example.gen'), sample_file=resource('example.sample'), reference_genome=None) - self.assertEqual(gen.locus.dtype, - hl.tstruct(contig=hl.tstr, position=hl.tint32)) + self.assertEqual(gen.locus.dtype, hl.tstruct(contig=hl.tstr, position=hl.tint32)) self.assertEqual(gen.count_rows(), 199) def test_import_gen_skip_invalid_loci(self): - mt = hl.import_gen(resource('skip_invalid_loci.gen'), - resource('skip_invalid_loci.sample'), - reference_genome='GRCh37', - skip_invalid_loci=True) + mt = hl.import_gen( + resource('skip_invalid_loci.gen'), + resource('skip_invalid_loci.sample'), + reference_genome='GRCh37', + skip_invalid_loci=True, + ) self.assertEqual(mt._force_count_rows(), 3) with self.assertRaisesRegex(FatalError, 'Invalid locus'): - mt = hl.import_gen(resource('skip_invalid_loci.gen'), - resource('skip_invalid_loci.sample')) + mt = hl.import_gen(resource('skip_invalid_loci.gen'), resource('skip_invalid_loci.sample')) mt._force_count_rows() @test_timeout(local=4 * 60, batch=8 * 60) def test_export_gen(self): - gen = hl.import_gen(resource('example.gen'), - sample_file=resource('example.sample'), - contig_recoding={"01": "1"}, - reference_genome='GRCh37', - min_partitions=3) + gen = hl.import_gen( + resource('example.gen'), + sample_file=resource('example.sample'), + contig_recoding={"01": "1"}, + reference_genome='GRCh37', + min_partitions=3, + ) # permute columns so not in alphabetical order! 
import random + indices = list(range(gen.count_cols())) random.shuffle(indices) gen = gen.choose_cols(indices) file = new_temp_file() hl.export_gen(gen, file) - gen2 = hl.import_gen(file + '.gen', - sample_file=file + '.sample', - reference_genome='GRCh37', - min_partitions=3) + gen2 = hl.import_gen(file + '.gen', sample_file=file + '.sample', reference_genome='GRCh37', min_partitions=3) - self.assertTrue(gen._same(gen2, tolerance=3E-4, absolute=True)) + self.assertTrue(gen._same(gen2, tolerance=3e-4, absolute=True)) def test_export_gen_exprs(self): - gen = hl.import_gen(resource('example.gen'), - sample_file=resource('example.sample'), - contig_recoding={"01": "1"}, - reference_genome='GRCh37', - min_partitions=3).add_col_index().add_row_index() + gen = ( + hl.import_gen( + resource('example.gen'), + sample_file=resource('example.sample'), + contig_recoding={"01": "1"}, + reference_genome='GRCh37', + min_partitions=3, + ) + .add_col_index() + .add_row_index() + ) out1 = new_temp_file() - hl.export_gen(gen, out1, id1=hl.str(gen.col_idx), id2=hl.str(gen.col_idx), missing=0.5, - varid=hl.str(gen.row_idx), rsid=hl.str(gen.row_idx), gp=[0.0, 1.0, 0.0]) - - in1 = (hl.import_gen(out1 + '.gen', sample_file=out1 + '.sample', min_partitions=3) - .add_col_index() - .add_row_index()) - self.assertTrue(in1.aggregate_entries(hl.agg.fraction((hl.is_missing(in1.GP) | (in1.GP == [0.0, 1.0, 0.0])) == 1.0))) - self.assertTrue(in1.aggregate_rows(hl.agg.fraction((in1.varid == hl.str(in1.row_idx)) & - (in1.rsid == hl.str(in1.row_idx)))) == 1.0) + hl.export_gen( + gen, + out1, + id1=hl.str(gen.col_idx), + id2=hl.str(gen.col_idx), + missing=0.5, + varid=hl.str(gen.row_idx), + rsid=hl.str(gen.row_idx), + gp=[0.0, 1.0, 0.0], + ) + + in1 = ( + hl.import_gen(out1 + '.gen', sample_file=out1 + '.sample', min_partitions=3).add_col_index().add_row_index() + ) + self.assertTrue( + in1.aggregate_entries(hl.agg.fraction((hl.is_missing(in1.GP) | (in1.GP == [0.0, 1.0, 0.0])) == 1.0)) + ) + self.assertTrue( + in1.aggregate_rows(hl.agg.fraction((in1.varid == hl.str(in1.row_idx)) & (in1.rsid == hl.str(in1.row_idx)))) + == 1.0 + ) self.assertTrue(in1.aggregate_cols(hl.agg.fraction((in1.s == hl.str(in1.col_idx))))) @@ -1642,10 +1794,11 @@ def test_import_locus_intervals(self): tmp_file = new_temp_file(prefix="test", extension="interval_list") start = t.interval.start end = t.interval.end - (t - .key_by(interval=hl.locus_interval(start.contig, start.position, end.position, True, True)) - .select() - .export(tmp_file, header=False)) + ( + t.key_by(interval=hl.locus_interval(start.contig, start.position, end.position, True, True)) + .select() + .export(tmp_file, header=False) + ) t2 = hl.import_locus_intervals(tmp_file) @@ -1659,9 +1812,11 @@ def test_import_locus_intervals_no_reference_specified(self): def test_import_locus_intervals_recoding(self): interval_file = resource('annotinterall.grch38.no.chr.interval_list') - t = hl.import_locus_intervals(interval_file, - contig_recoding={str(i): f'chr{i}' for i in [*range(1, 23), 'X', 'Y', 'M']}, - reference_genome='GRCh38') + t = hl.import_locus_intervals( + interval_file, + contig_recoding={str(i): f'chr{i}' for i in [*range(1, 23), 'X', 'Y', 'M']}, + reference_genome='GRCh38', + ) self.assertEqual(t._force_count(), 3) self.assertEqual(t.interval.dtype.point_type, hl.tlocus('GRCh38')) @@ -1695,21 +1850,27 @@ def test_import_bed(self): t = hl.import_bed(bed_file, reference_genome='GRCh37') self.assertEqual(t.interval.dtype.point_type, hl.tlocus('GRCh37')) 
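# --- illustrative sketch, not part of the patch above ---
# A short example of the interval-import pattern covered by the nearby recoding tests,
# using placeholder file paths; the contig_recoding mapping has the same shape as in
# test_import_locus_intervals_recoding and test_import_bed_recoding.
import hail as hl

recode = {str(i): f'chr{i}' for i in [*range(1, 23), 'X', 'Y', 'M']}

# Files written with bare contig names ('1', 'X', ...) can be recoded to
# GRCh38-style names ('chr1', 'chrX', ...) at import time:
bed = hl.import_bed('/path/to/regions.bed', reference_genome='GRCh38', contig_recoding=recode)
intervals = hl.import_locus_intervals(
    '/path/to/targets.interval_list', reference_genome='GRCh38', contig_recoding=recode
)

# Both return tables keyed by an interval of tlocus('GRCh38'), so they support interval
# joins of the form ds.annotate_rows(in_region=bed[ds.locus]) used elsewhere in these tests.
# --- end sketch ---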
self.assertEqual(list(t.key.dtype), ['interval']) - self.assertEqual(list(t.row.dtype), ['interval','target']) + self.assertEqual(list(t.row.dtype), ['interval', 'target']) - expected = [hl.interval(hl.locus('20', 1), hl.locus('20', 11), True, False), # 20 0 10 gene0 - hl.interval(hl.locus('20', 2), hl.locus('20', 14000001), True, False), # 20 1 14000000 gene1 - hl.interval(hl.locus('20', 5), hl.locus('20', 6), False, False), # 20 5 5 gene4 - hl.interval(hl.locus('20', 17000001), hl.locus('20', 18000001), True, False), # 20 17000000 18000000 gene2 - hl.interval(hl.locus('20', 63025511), hl.locus('20', 63025520), True, True)] # 20 63025510 63025520 gene3 + expected = [ + hl.interval(hl.locus('20', 1), hl.locus('20', 11), True, False), # 20 0 10 gene0 + hl.interval(hl.locus('20', 2), hl.locus('20', 14000001), True, False), # 20 1 14000000 gene1 + hl.interval(hl.locus('20', 5), hl.locus('20', 6), False, False), # 20 5 5 gene4 + hl.interval( + hl.locus('20', 17000001), hl.locus('20', 18000001), True, False + ), # 20 17000000 18000000 gene2 + hl.interval(hl.locus('20', 63025511), hl.locus('20', 63025520), True, True), + ] # 20 63025510 63025520 gene3 self.assertEqual(t.interval.collect(), hl.eval(expected)) def test_import_bed_recoding(self): bed_file = resource('some-missing-chr-grch38.bed') - bed = hl.import_bed(bed_file, - reference_genome='GRCh38', - contig_recoding={str(i): f'chr{i}' for i in [*range(1, 23), 'X', 'Y', 'M']}) + bed = hl.import_bed( + bed_file, + reference_genome='GRCh38', + contig_recoding={str(i): f'chr{i}' for i in [*range(1, 23), 'X', 'Y', 'M']}, + ) self.assertEqual(bed._force_count(), 5) self.assertEqual(bed.interval.dtype.point_type, hl.tlocus('GRCh38')) @@ -1734,17 +1895,17 @@ def test_import_bed_badly_defined_intervals(self): def test_pass_through_args(self): interval_file = resource('example3.interval_list') - t = hl.import_locus_intervals(interval_file, - reference_genome='GRCh37', - skip_invalid_intervals=True, - filter=r'target_\d\d') + t = hl.import_locus_intervals( + interval_file, reference_genome='GRCh37', skip_invalid_intervals=True, filter=r'target_\d\d' + ) assert t.count() == 9 class ImportMatrixTableTests(unittest.TestCase): def test_import_matrix_table_1(self): - mt = hl.import_matrix_table(doctest_resource('matrix1.tsv'), - row_fields={'Barcode': hl.tstr, 'Tissue': hl.tstr, 'Days': hl.tfloat32}) + mt = hl.import_matrix_table( + doctest_resource('matrix1.tsv'), row_fields={'Barcode': hl.tstr, 'Tissue': hl.tstr, 'Days': hl.tfloat32} + ) self.assertEqual(mt['Barcode']._indices, mt._row_indices) self.assertEqual(mt['Tissue']._indices, mt._row_indices) self.assertEqual(mt['Days']._indices, mt._row_indices) @@ -1755,16 +1916,14 @@ def test_import_matrix_table_1(self): def test_import_matrix_table_2(self): hl.import_matrix_table( - doctest_resource('matrix2.tsv'), - row_fields={'f0': hl.tstr, 'f1': hl.tstr, 'f2': hl.tfloat32}, - row_key=[] + doctest_resource('matrix2.tsv'), row_fields={'f0': hl.tstr, 'f1': hl.tstr, 'f2': hl.tfloat32}, row_key=[] )._force_count_rows() def test_import_matrix_table_3(self): hl.import_matrix_table( doctest_resource('matrix3.tsv'), row_fields={'f0': hl.tstr, 'f1': hl.tstr, 'f2': hl.tfloat32}, - no_header=True + no_header=True, )._force_count_rows() def test_import_matrix_table_4(self): @@ -1772,11 +1931,18 @@ def test_import_matrix_table_4(self): doctest_resource('matrix3.tsv'), row_fields={'f0': hl.tstr, 'f1': hl.tstr, 'f2': hl.tfloat32}, no_header=True, - row_key=[] + row_key=[], )._force_count_rows() def 
test_import_matrix_table_no_cols(self): - fields = {'Chromosome': hl.tstr, 'Position': hl.tint32, 'Ref': hl.tstr, 'Alt': hl.tstr, 'Rand1': hl.tfloat64, 'Rand2': hl.tfloat64} + fields = { + 'Chromosome': hl.tstr, + 'Position': hl.tint32, + 'Ref': hl.tstr, + 'Alt': hl.tstr, + 'Rand1': hl.tfloat64, + 'Rand2': hl.tfloat64, + } file = resource('sample2_va_nomulti.tsv') mt = hl.import_matrix_table(file, row_fields=fields, row_key=['Chromosome', 'Position']) t = hl.import_table(file, types=fields, key=['Chromosome', 'Position']) @@ -1788,88 +1954,64 @@ def test_import_matrix_table_no_cols(self): def test_import_matrix_comment(self): no_comment = doctest_resource('matrix1.tsv') comment = doctest_resource('matrix1_comment.tsv') - row_fields={'Barcode': hl.tstr, 'Tissue': hl.tstr, 'Days': hl.tfloat32} - mt1 = hl.import_matrix_table(no_comment, - row_fields=row_fields, - row_key=[]) - mt2 = hl.import_matrix_table(comment, - row_fields=row_fields, - row_key=[], - comment=['#', '%']) + row_fields = {'Barcode': hl.tstr, 'Tissue': hl.tstr, 'Days': hl.tfloat32} + mt1 = hl.import_matrix_table(no_comment, row_fields=row_fields, row_key=[]) + mt2 = hl.import_matrix_table(comment, row_fields=row_fields, row_key=[], comment=['#', '%']) assert mt1._same(mt2) def test_headers_not_identical(self): with pytest.raises(ValueError, match='invalid header: lengths of headers differ'): - hl.import_matrix_table([resource("sampleheader1.txt"), resource("sampleheader2.txt")], - row_fields={'f0': hl.tstr}, row_key=['f0']) + hl.import_matrix_table( + [resource("sampleheader1.txt"), resource("sampleheader2.txt")], + row_fields={'f0': hl.tstr}, + row_key=['f0'], + ) def test_headers_same_len_diff_elem(self): with pytest.raises(ValueError, match='invalid header: expected elements to be identical for all input paths'): - hl.import_matrix_table([resource("sampleheader2.txt"), - resource("sampleheaderdiffelem.txt")], row_fields={'f0': hl.tstr}, row_key=['f0']) + hl.import_matrix_table( + [resource("sampleheader2.txt"), resource("sampleheaderdiffelem.txt")], + row_fields={'f0': hl.tstr}, + row_key=['f0'], + ) def test_too_few_entries(self): def boom(): - hl.import_matrix_table(resource("samplesmissing.txt"), - row_fields={'f0': hl.tstr}, - row_key=['f0'] - )._force_count_rows() + hl.import_matrix_table( + resource("samplesmissing.txt"), row_fields={'f0': hl.tstr}, row_key=['f0'] + )._force_count_rows() + with pytest.raises(HailUserError, match='unexpected end of line while reading entries'): boom() def test_wrong_row_field_type(self): with pytest.raises(HailUserError, match="error parsing value into int32 at row field 'f0'"): - hl.import_matrix_table(resource("sampleheader1.txt"), - row_fields={'f0': hl.tint32}, - row_key=['f0'])._force_count_rows() + hl.import_matrix_table( + resource("sampleheader1.txt"), row_fields={'f0': hl.tint32}, row_key=['f0'] + )._force_count_rows() def test_wrong_entry_type(self): with pytest.raises(HailUserError, match="error parsing value into int32 at column id 'col000003'"): - hl.import_matrix_table(resource("samplenonintentries.txt"), - row_fields={'f0': hl.tstr}, - row_key=['f0'])._force_count_rows() - - def test_key_by_after_empty_key_import(self): - fields = {'Chromosome':hl.tstr, - 'Position': hl.tint32, - 'Ref': hl.tstr, - 'Alt': hl.tstr} - mt = hl.import_matrix_table(resource('sample2_va_nomulti.tsv'), - row_fields=fields, - row_key=[], - entry_type=hl.tfloat) - mt = mt.key_rows_by('Chromosome', 'Position') - assert 0.001 < abs(0.50965 - mt.aggregate_entries(hl.agg.mean(mt.x))) + 
hl.import_matrix_table( + resource("samplenonintentries.txt"), row_fields={'f0': hl.tstr}, row_key=['f0'] + )._force_count_rows() def test_key_by_after_empty_key_import(self): - fields = {'Chromosome':hl.tstr, - 'Position': hl.tint32, - 'Ref': hl.tstr, - 'Alt': hl.tstr} - mt = hl.import_matrix_table(resource('sample2_va_nomulti.tsv'), - row_fields=fields, - row_key=[], - entry_type=hl.tfloat) + fields = {'Chromosome': hl.tstr, 'Position': hl.tint32, 'Ref': hl.tstr, 'Alt': hl.tstr} + mt = hl.import_matrix_table( + resource('sample2_va_nomulti.tsv'), row_fields=fields, row_key=[], entry_type=hl.tfloat + ) mt = mt.key_rows_by('Chromosome', 'Position') mt._force_count_rows() @test_timeout(local=4 * 60) def test_devilish_nine_separated_eight_missing_file(self): - fields = {'chr': hl.tstr, - '': hl.tint32, - 'ref': hl.tstr, - 'alt': hl.tstr} - mt = hl.import_matrix_table(resource('import_matrix_table_devlish.ninesv'), - row_fields=fields, - row_key=['chr', ''], - sep='9', - missing='8') + fields = {'chr': hl.tstr, '': hl.tint32, 'ref': hl.tstr, 'alt': hl.tstr} + mt = hl.import_matrix_table( + resource('import_matrix_table_devlish.ninesv'), row_fields=fields, row_key=['chr', ''], sep='9', missing='8' + ) actual = mt.x.collect() - expected = [ - 1, 2, 3, 4, - 11, 12, 13, 14, - 21, 22, 23, 24, - 31, None, None, 34] + expected = [1, 2, 3, 4, 11, 12, 13, 14, 21, 22, 23, 24, 31, None, None, 34] assert actual == expected assert mt.count_rows() == len(mt.rows().collect()) @@ -1895,18 +2037,15 @@ def test_empty_import_matrix_table(self): def test_import_row_id_multiple_partitions(self): path = new_temp_file(extension='txt') - (hl.utils.range_matrix_table(50, 50) - .annotate_entries(x=1) - .key_rows_by() - .key_cols_by() - .x - .export(path, header=False, delimiter=' ')) - - mt = hl.import_matrix_table(path, - no_header=True, - entry_type=hl.tint32, - delimiter=' ', - min_partitions=10) + ( + hl.utils.range_matrix_table(50, 50) + .annotate_entries(x=1) + .key_rows_by() + .key_cols_by() + .x.export(path, header=False, delimiter=' ') + ) + + mt = hl.import_matrix_table(path, no_header=True, entry_type=hl.tint32, delimiter=' ', min_partitions=10) assert mt.row_id.collect() == list(range(50)) def test_long_parsing(self): @@ -1915,7 +2054,7 @@ def test_long_parsing(self): collected = mt.entries().collect() assert collected == [ hl.utils.Struct(foo=7, row_id=0, col_id='s1', x=1234), - hl.utils.Struct(foo=7, row_id=0, col_id='s2', x=2345) + hl.utils.Struct(foo=7, row_id=0, col_id='s2', x=2345), ] @@ -1925,25 +2064,20 @@ def test_long_parsing(self): @pytest.mark.parametrize("missing", ['.', '9']) def test_import_matrix_table_round_trip(missing, delimiter, header, entry_fun): mt = hl.utils.range_matrix_table(10, 10, n_partitions=2) - mt = mt.annotate_entries(x = entry_fun(mt.row_idx * mt.col_idx)) - mt = mt.annotate_rows(row_str = hl.str(mt.row_idx)) - mt = mt.annotate_rows(row_float = hl.float(mt.row_idx)) + mt = mt.annotate_entries(x=entry_fun(mt.row_idx * mt.col_idx)) + mt = mt.annotate_rows(row_str=hl.str(mt.row_idx)) + mt = mt.annotate_rows(row_float=hl.float(mt.row_idx)) entry_type = mt.x.dtype path = new_temp_file(extension='tsv') - mt.key_rows_by(*mt.row).x.export(path, - missing=missing, - delimiter=delimiter, - header=header) + mt.key_rows_by(*mt.row).x.export(path, missing=missing, delimiter=delimiter, header=header) row_fields = {f: mt.row[f].dtype for f in mt.row} row_key = 'row_idx' if not header: - pseudonym = {'row_idx': 'f0', - 'row_str': 'f1', - 'row_float': 'f2'} + pseudonym = {'row_idx': 'f0', 
'row_str': 'f1', 'row_float': 'f2'} row_fields = {pseudonym[k]: v for k, v in row_fields.items()} row_key = pseudonym[row_key] mt = mt.rename(pseudonym) @@ -1957,21 +2091,14 @@ def test_import_matrix_table_round_trip(missing, delimiter, header, entry_fun): entry_type=entry_type, missing=missing, no_header=not header, - sep=delimiter) + sep=delimiter, + ) actual = actual.rename({'col_id': 'col_idx'}) row_key = mt.row_key - col_key = mt.col_key mt = mt.key_rows_by() - mt = mt.annotate_entries( - x = hl.if_else(hl.str(mt.x) == missing, - hl.missing(entry_type), - mt.x)) - mt = mt.annotate_rows(**{ - f: hl.if_else(hl.str(mt[f]) == missing, - hl.missing(mt[f].dtype), - mt[f]) - for f in mt.row}) + mt = mt.annotate_entries(x=hl.if_else(hl.str(mt.x) == missing, hl.missing(entry_type), mt.x)) + mt = mt.annotate_rows(**{f: hl.if_else(hl.str(mt[f]) == missing, hl.missing(mt[f].dtype), mt[f]) for f in mt.row}) mt = mt.key_rows_by(*row_key) assert mt._same(actual) @@ -2019,7 +2146,10 @@ def test_import_table_empty(self): try: rows = hl.import_table(resource('empty.tsv')).collect() except ValueError as err: - assert f'Invalid file: no lines remaining after filters\n Files provided: {resource("empty.tsv")}' in err.args[0] + assert ( + f'Invalid file: no lines remaining after filters\n Files provided: {resource("empty.tsv")}' + in err.args[0] + ) else: assert False, rows @@ -2036,23 +2166,24 @@ def test_type_imputation(self): ht = hl.import_table(resource('variantAnnotations.tsv'), impute=True) assert ht.row.dtype == hl.dtype( - 'struct{Chromosome: int32, Position: int32, Ref: str, Alt: str, Rand1: float64, Rand2: float64, Gene: str}') + 'struct{Chromosome: int32, Position: int32, Ref: str, Alt: str, Rand1: float64, Rand2: float64, Gene: str}' + ) ht = hl.import_table(resource('variantAnnotations.tsv'), impute=True, types={'Chromosome': 'str'}) assert ht.row.dtype == hl.dtype( - 'struct{Chromosome: str, Position: int32, Ref: str, Alt: str, Rand1: float64, Rand2: float64, Gene: str}') + 'struct{Chromosome: str, Position: int32, Ref: str, Alt: str, Rand1: float64, Rand2: float64, Gene: str}' + ) ht = hl.import_table(resource('variantAnnotations.alternateformat.tsv'), impute=True) assert ht.row.dtype == hl.dtype( - 'struct{`Chromosome:Position:Ref:Alt`: str, Rand1: float64, Rand2: float64, Gene: str}') + 'struct{`Chromosome:Position:Ref:Alt`: str, Rand1: float64, Rand2: float64, Gene: str}' + ) ht = hl.import_table(resource('sampleAnnotations.tsv'), impute=True) - assert ht.row.dtype == hl.dtype( - 'struct{Sample: str, Status: str, qPhen: int32}') + assert ht.row.dtype == hl.dtype('struct{Sample: str, Status: str, qPhen: int32}') ht = hl.import_table(resource('integer_imputation.txt'), impute=True, delimiter=r'\s+') - assert ht.row.dtype == hl.dtype( - 'struct{A:int64, B:int32}') + assert ht.row.dtype == hl.dtype('struct{A:int64, B:int32}') def test_import_export_identity(self): fs = hl.current_backend().fs @@ -2069,21 +2200,25 @@ def test_import_export_identity(self): def small_dataset_1(self): data = [ - hl.Struct(Sample='Sample1',field1=5,field2=5), - hl.Struct(Sample='Sample2',field1=3,field2=5), - hl.Struct(Sample='Sample3',field1=2,field2=5), - hl.Struct(Sample='Sample4',field1=1,field2=5), + hl.Struct(Sample='Sample1', field1=5, field2=5), + hl.Struct(Sample='Sample2', field1=3, field2=5), + hl.Struct(Sample='Sample3', field1=2, field2=5), + hl.Struct(Sample='Sample4', field1=1, field2=5), ] return hl.Table.parallelize(data, key='Sample') def test_source_file(self): ht = 
hl.import_table(resource('variantAnnotations.split.*.tsv'), source_file_field='source') ht = ht.add_index() - assert ht.aggregate(hl.agg.all( - hl.if_else(ht.idx < 239, - ht.source.endswith('variantAnnotations.split.1.tsv'), - ht.source.endswith('variantAnnotations.split.2.tsv')))) - + assert ht.aggregate( + hl.agg.all( + hl.if_else( + ht.idx < 239, + ht.source.endswith('variantAnnotations.split.1.tsv'), + ht.source.endswith('variantAnnotations.split.2.tsv'), + ) + ) + ) def test_read_write_identity(self): ht = self.small_dataset_1() @@ -2115,24 +2250,30 @@ class GrepTests(unittest.TestCase): @fails_local_backend() def test_grep_show_false(self): from hail.backend.service_backend import ServiceBackend + if isinstance(hl.current_backend(), ServiceBackend): prefix = resource('') else: prefix = '' - expected = {prefix + 'sampleAnnotations.tsv': ['HG00120\tCASE\t19599', 'HG00121\tCASE\t4832'], - prefix + 'sample2_rename.tsv': ['HG00120\tB_HG00120', 'HG00121\tB_HG00121'], - prefix + 'sampleAnnotations2.tsv': ['HG00120\t3919.8\t19589', - 'HG00121\t966.4\t4822', - 'HG00120_B\t3919.8\t19589', - 'HG00121_B\t966.4\t4822', - 'HG00120_B_B\t3919.8\t19589', - 'HG00121_B_B\t966.4\t4822']} + expected = { + prefix + 'sampleAnnotations.tsv': ['HG00120\tCASE\t19599', 'HG00121\tCASE\t4832'], + prefix + 'sample2_rename.tsv': ['HG00120\tB_HG00120', 'HG00121\tB_HG00121'], + prefix + 'sampleAnnotations2.tsv': [ + 'HG00120\t3919.8\t19589', + 'HG00121\t966.4\t4822', + 'HG00120_B\t3919.8\t19589', + 'HG00121_B\t966.4\t4822', + 'HG00120_B_B\t3919.8\t19589', + 'HG00121_B_B\t966.4\t4822', + ], + } assert hl.grep('HG0012[0-1]', resource('*.tsv'), show=False) == expected class AvroTests(unittest.TestCase): - @fails_service_backend(reason=''' + @fails_service_backend( + reason=""" E java.io.NotSerializableException: org.apache.avro.Schema$RecordSchema E at java.io.ObjectOutputStream.writeObject0(ObjectOutputStream.java:1184) E at java.io.ObjectOutputStream.writeArray(ObjectOutputStream.java:1378) @@ -2166,7 +2307,8 @@ class AvroTests(unittest.TestCase): E at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149) E at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624) E at java.lang.Thread.run(Thread.java:748) -''') +""" + ) def test_simple_avro(self): avro_file = resource('avro/weather.avro') fs = hl.current_backend().fs diff --git a/hail/python/test/hail/methods/test_king.py b/hail/python/test/hail/methods/test_king.py index 1e86a1ae563..4e5fe9a44aa 100644 --- a/hail/python/test/hail/methods/test_king.py +++ b/hail/python/test/hail/methods/test_king.py @@ -1,28 +1,24 @@ import pytest import hail as hl -from ..helpers import resource, fails_local_backend, fails_service_backend + +from ..helpers import fails_local_backend, fails_service_backend, resource def assert_c_king_same_as_hail_king(c_king_path, hail_king_mt): actual = hail_king_mt.entries() - expected = hl.import_table(c_king_path, - types={'Kinship': hl.tfloat}, - key=['ID1', 'ID2']) - expected = expected.rename({'ID1': 's_1', - 'ID2': 's', - 'Kinship': 'phi'}) + expected = hl.import_table(c_king_path, types={'Kinship': hl.tfloat}, key=['ID1', 'ID2']) + expected = expected.rename({'ID1': 's_1', 'ID2': 's', 'Kinship': 'phi'}) expected = expected.key_by('s_1', 's') expected = expected.annotate(actual=actual[expected.key]) expected = expected.select( - expected=expected.phi, - actual=expected.actual.phi, - diff=expected.phi - expected.actual.phi + expected=expected.phi, actual=expected.actual.phi, diff=expected.phi 
- expected.actual.phi ) expected = expected.annotate( # KING prints 4 significant digits; but there are several instances # where we calculate 0.XXXX5 whereas KING outputs 0.XXXX - failure=hl.abs(expected.diff) > 0.00006) + failure=hl.abs(expected.diff) > 0.00006 + ) expected = expected.filter(expected.failure) assert expected.count() == 0, expected.collect() @@ -31,23 +27,19 @@ def assert_c_king_same_as_hail_king(c_king_path, hail_king_mt): @fails_local_backend() def test_king_small(): plink_path = resource('balding-nichols-1024-variants-4-samples-3-populations') - mt = hl.import_plink(bed=f'{plink_path}.bed', - bim=f'{plink_path}.bim', - fam=f'{plink_path}.fam') + mt = hl.import_plink(bed=f'{plink_path}.bed', bim=f'{plink_path}.bim', fam=f'{plink_path}.fam') kinship = hl.king(mt.GT) - assert_c_king_same_as_hail_king( - resource('balding-nichols-1024-variants-4-samples-3-populations.kin0'), - kinship) + assert_c_king_same_as_hail_king(resource('balding-nichols-1024-variants-4-samples-3-populations.kin0'), kinship) + @pytest.mark.unchecked_allocator @fails_service_backend() @fails_local_backend() def test_king_large(): plink_path = resource('fastlmmTest') - mt = hl.import_plink(bed=f'{plink_path}.bed', - bim=f'{plink_path}.bim', - fam=f'{plink_path}.fam', - reference_genome=None) + mt = hl.import_plink( + bed=f'{plink_path}.bed', bim=f'{plink_path}.bim', fam=f'{plink_path}.fam', reference_genome=None + ) kinship = hl.king(mt.GT) assert_c_king_same_as_hail_king(resource('fastlmmTest.kin0.bgz'), kinship) @@ -56,8 +48,6 @@ def test_king_large(): @fails_local_backend() def test_king_filtered_entries_no_error(): plink_path = resource('balding-nichols-1024-variants-4-samples-3-populations') - mt = hl.import_plink(bed=f'{plink_path}.bed', - bim=f'{plink_path}.bim', - fam=f'{plink_path}.fam') + mt = hl.import_plink(bed=f'{plink_path}.bed', bim=f'{plink_path}.bim', fam=f'{plink_path}.fam') mt = mt.filter_entries(hl.rand_bool(0.5)) hl.king(mt.GT)._force_count_rows() diff --git a/hail/python/test/hail/methods/test_misc.py b/hail/python/test/hail/methods/test_misc.py index 2e045b866f0..2064eb79616 100644 --- a/hail/python/test/hail/methods/test_misc.py +++ b/hail/python/test/hail/methods/test_misc.py @@ -1,29 +1,29 @@ import unittest import hail as hl -from ..helpers import * + +from ..helpers import get_dataset, resource, test_timeout class Tests(unittest.TestCase): def test_rename_duplicates(self): mt = hl.utils.range_matrix_table(5, 5) - assert hl.rename_duplicates( - mt.key_cols_by(s=hl.str(mt.col_idx)) - ).unique_id.collect() == ['0', '1', '2', '3', '4'] + assert hl.rename_duplicates(mt.key_cols_by(s=hl.str(mt.col_idx))).unique_id.collect() == [ + '0', + '1', + '2', + '3', + '4', + ] - assert hl.rename_duplicates( - mt.key_cols_by(s='0') - ).unique_id.collect() == ['0', '0_1', '0_2', '0_3', '0_4'] + assert hl.rename_duplicates(mt.key_cols_by(s='0')).unique_id.collect() == ['0', '0_1', '0_2', '0_3', '0_4'] assert hl.rename_duplicates( mt.key_cols_by(s=hl.literal(['0', '0_1', '0', '0_2', '0'])[mt.col_idx]) ).unique_id.collect() == ['0', '0_1', '0_2', '0_2_1', '0_3'] - assert hl.rename_duplicates( - mt.key_cols_by(s=hl.str(mt.col_idx)), - 'foo' - )['foo'].dtype == hl.tstr + assert hl.rename_duplicates(mt.key_cols_by(s=hl.str(mt.col_idx)), 'foo')['foo'].dtype == hl.tstr @test_timeout(local=3 * 60) def test_annotate_intervals_bed1(self): @@ -32,9 +32,9 @@ def test_annotate_intervals_bed1(self): interval_list1 = hl.import_locus_intervals(resource('exampleAnnotation1.interval_list')) ann = 
ds.annotate_rows(in_interval=bed1[ds.locus]).rows() - assert ann.all(hl.any(ann.locus.position <= 14000000, - ann.locus.position >= 17000000, - hl.is_missing(ann.in_interval))) + assert ann.all( + hl.any(ann.locus.position <= 14000000, ann.locus.position >= 17000000, hl.is_missing(ann.in_interval)) + ) intervallist = ds.annotate_rows(in_interval=interval_list1[ds.locus]).rows() bed = ds.annotate_rows(in_interval=bed1[ds.locus]).rows() assert intervallist._same(bed) @@ -51,10 +51,12 @@ def test_annotate_intervals_bed2(self): assert list(interval_list2.row.dtype) == ['interval', 'target'] ann = ds.annotate_rows(target=bed2[ds.locus].target).rows() - expr = (hl.case() - .when(ann.locus.position <= 14000000, ann.target == 'gene1') - .when(ann.locus.position >= 17000000, ann.target == 'gene2') - .default(ann.target == hl.missing(hl.tstr))) + expr = ( + hl.case() + .when(ann.locus.position <= 14000000, ann.target == 'gene1') + .when(ann.locus.position >= 17000000, ann.target == 'gene2') + .default(ann.target == hl.missing(hl.tstr)) + ) assert ann.all(expr) intervallist = ds.annotate_rows(target=interval_list2[ds.locus].target).rows() @@ -66,10 +68,12 @@ def test_annotate_intervals_bed3(self): ds = get_dataset() bed3 = hl.import_bed(resource('example3.bed'), reference_genome='GRCh37') ann = ds.annotate_rows(target=bed3[ds.locus].target).rows() - expr = (hl.case() - .when(ann.locus.position <= 14000000, ann.target == 'gene1') - .when(ann.locus.position >= 17000000, ann.target == 'gene2') - .default(ann.target == hl.missing(hl.tstr))) + expr = ( + hl.case() + .when(ann.locus.position <= 14000000, ann.target == 'gene1') + .when(ann.locus.position >= 17000000, ann.target == 'gene2') + .default(ann.target == hl.missing(hl.tstr)) + ) assert ann.all(expr) def test_maximal_independent_set(self): @@ -88,14 +92,12 @@ def test_maximal_independent_set(self): self.assertRaises(ValueError, lambda: hl.maximal_independent_set(hl.literal(1), hl.literal(2), True)) def test_maximal_independent_set2(self): - edges = [(0, 4), (0, 1), (0, 2), (1, 5), (1, 3), (2, 3), (2, 6), - (3, 7), (4, 5), (4, 6), (5, 7), (6, 7)] + edges = [(0, 4), (0, 1), (0, 2), (1, 5), (1, 3), (2, 3), (2, 6), (3, 7), (4, 5), (4, 6), (5, 7), (6, 7)] edges = [{"i": l, "j": r} for l, r in edges] t = hl.Table.parallelize(edges, hl.tstruct(i=hl.tint64, j=hl.tint64)) mis_t = hl.maximal_independent_set(t.i, t.j) - self.assertTrue(mis_t.row.dtype == hl.tstruct(node=hl.tint64) and - mis_t.globals.dtype == hl.tstruct()) + self.assertTrue(mis_t.row.dtype == hl.tstruct(node=hl.tint64) and mis_t.globals.dtype == hl.tstruct()) mis = set([row.node for row in mis_t.collect()]) maximal_indep_sets = [{0, 6, 5, 3}, {1, 4, 7, 2}] @@ -105,16 +107,15 @@ def test_maximal_independent_set2(self): def test_maximal_independent_set3(self): is_case = {"A", "C", "E", "G", "H"} edges = [("A", "B"), ("C", "D"), ("E", "F"), ("G", "H")] - edges = [{"i": {"id": l, "is_case": l in is_case}, - "j": {"id": r, "is_case": r in is_case}} for l, r in edges] + edges = [{"i": {"id": l, "is_case": l in is_case}, "j": {"id": r, "is_case": r in is_case}} for l, r in edges] - t = hl.Table.parallelize(edges, hl.tstruct(i=hl.tstruct(id=hl.tstr, is_case=hl.tbool), - j=hl.tstruct(id=hl.tstr, is_case=hl.tbool))) + t = hl.Table.parallelize( + edges, hl.tstruct(i=hl.tstruct(id=hl.tstr, is_case=hl.tbool), j=hl.tstruct(id=hl.tstr, is_case=hl.tbool)) + ) - tiebreaker = lambda l, r: (hl.case() - .when(l.is_case & (~r.is_case), -1) - .when(~(l.is_case) & r.is_case, 1) - .default(0)) + tiebreaker = lambda 
l, r: ( + hl.case().when(l.is_case & (~r.is_case), -1).when(~(l.is_case) & r.is_case, 1).default(0) + ) mis = hl.maximal_independent_set(t.i, t.j, tie_breaker=tiebreaker) @@ -125,22 +126,22 @@ def test_maximal_independent_set3(self): def test_maximal_independent_set_types(self): ht = hl.utils.range_table(10) - ht = ht.annotate(i=hl.struct(a='1', b=hl.rand_norm(0, 1)), - j=hl.struct(a='2', b=hl.rand_norm(0, 1))) - ht = ht.annotate(ii=hl.struct(id=ht.i, rank=hl.rand_norm(0, 1)), - jj=hl.struct(id=ht.j, rank=hl.rand_norm(0, 1))) + ht = ht.annotate(i=hl.struct(a='1', b=hl.rand_norm(0, 1)), j=hl.struct(a='2', b=hl.rand_norm(0, 1))) + ht = ht.annotate(ii=hl.struct(id=ht.i, rank=hl.rand_norm(0, 1)), jj=hl.struct(id=ht.j, rank=hl.rand_norm(0, 1))) hl.maximal_independent_set(ht.ii, ht.jj).count() def test_maximal_independent_set_on_floats(self): - t = hl.utils.range_table(1).annotate(l = hl.struct(s="a", x=3.0), r = hl.struct(s="b", x=2.82)) + t = hl.utils.range_table(1).annotate(l=hl.struct(s="a", x=3.0), r=hl.struct(s="b", x=2.82)) expected = [hl.Struct(node=hl.Struct(s="a", x=3.0))] - actual = hl.maximal_independent_set(t.l, t.r, keep=False, tie_breaker=lambda l,r: l.x - r.x).collect() + actual = hl.maximal_independent_set(t.l, t.r, keep=False, tie_breaker=lambda l, r: l.x - r.x).collect() assert actual == expected def test_maximal_independent_set_string_node_names(self): - ht = hl.Table.parallelize([hl.Struct(i='A', j='B', kin=0.25), - hl.Struct(i='A', j='C', kin=0.25), - hl.Struct(i='D', j='E', kin=0.5)]) + ht = hl.Table.parallelize([ + hl.Struct(i='A', j='B', kin=0.25), + hl.Struct(i='A', j='C', kin=0.25), + hl.Struct(i='D', j='E', kin=0.5), + ]) ret = hl.maximal_independent_set(ht.i, ht.j, False).collect() exp = [hl.Struct(node='A'), hl.Struct(node='D')] assert exp == ret @@ -148,61 +149,74 @@ def test_maximal_independent_set_string_node_names(self): def test_matrix_filter_intervals(self): ds = hl.import_vcf(resource('sample.vcf'), min_partitions=20) - self.assertEqual( - hl.filter_intervals(ds, [hl.parse_locus_interval('20:10639222-10644705')]).count_rows(), 3) + self.assertEqual(hl.filter_intervals(ds, [hl.parse_locus_interval('20:10639222-10644705')]).count_rows(), 3) - intervals = [hl.parse_locus_interval('20:10639222-10644700'), - hl.parse_locus_interval('20:10644700-10644705')] + intervals = [hl.parse_locus_interval('20:10639222-10644700'), hl.parse_locus_interval('20:10644700-10644705')] self.assertEqual(hl.filter_intervals(ds, intervals).count_rows(), 3) - intervals = hl.array([hl.parse_locus_interval('20:10639222-10644700'), - hl.parse_locus_interval('20:10644700-10644705')]) + intervals = hl.array([ + hl.parse_locus_interval('20:10639222-10644700'), + hl.parse_locus_interval('20:10644700-10644705'), + ]) self.assertEqual(hl.filter_intervals(ds, intervals).count_rows(), 3) - intervals = hl.array([hl.eval(hl.parse_locus_interval('20:10639222-10644700')), - hl.parse_locus_interval('20:10644700-10644705')]) + intervals = hl.array([ + hl.eval(hl.parse_locus_interval('20:10639222-10644700')), + hl.parse_locus_interval('20:10644700-10644705'), + ]) self.assertEqual(hl.filter_intervals(ds, intervals).count_rows(), 3) - intervals = [hl.eval(hl.parse_locus_interval('[20:10019093-10026348]')), - hl.eval(hl.parse_locus_interval('[20:17705793-17716416]'))] + intervals = [ + hl.eval(hl.parse_locus_interval('[20:10019093-10026348]')), + hl.eval(hl.parse_locus_interval('[20:17705793-17716416]')), + ] self.assertEqual(hl.filter_intervals(ds, intervals).count_rows(), 4) def 
test_table_filter_intervals(self): ds = hl.import_vcf(resource('sample.vcf'), min_partitions=20).rows() - self.assertEqual( - hl.filter_intervals(ds, [hl.parse_locus_interval('20:10639222-10644705')]).count(), 3) + self.assertEqual(hl.filter_intervals(ds, [hl.parse_locus_interval('20:10639222-10644705')]).count(), 3) - intervals = [hl.parse_locus_interval('20:10639222-10644700'), - hl.parse_locus_interval('20:10644700-10644705')] + intervals = [hl.parse_locus_interval('20:10639222-10644700'), hl.parse_locus_interval('20:10644700-10644705')] self.assertEqual(hl.filter_intervals(ds, intervals).count(), 3) - intervals = hl.array([hl.parse_locus_interval('20:10639222-10644700'), - hl.parse_locus_interval('20:10644700-10644705')]) + intervals = hl.array([ + hl.parse_locus_interval('20:10639222-10644700'), + hl.parse_locus_interval('20:10644700-10644705'), + ]) self.assertEqual(hl.filter_intervals(ds, intervals).count(), 3) - intervals = hl.array([hl.eval(hl.parse_locus_interval('20:10639222-10644700')), - hl.parse_locus_interval('20:10644700-10644705')]) + intervals = hl.array([ + hl.eval(hl.parse_locus_interval('20:10639222-10644700')), + hl.parse_locus_interval('20:10644700-10644705'), + ]) self.assertEqual(hl.filter_intervals(ds, intervals).count(), 3) - intervals = [hl.eval(hl.parse_locus_interval('[20:10019093-10026348]')), - hl.eval(hl.parse_locus_interval('[20:17705793-17716416]'))] + intervals = [ + hl.eval(hl.parse_locus_interval('[20:10019093-10026348]')), + hl.eval(hl.parse_locus_interval('[20:17705793-17716416]')), + ] self.assertEqual(hl.filter_intervals(ds, intervals).count(), 4) def test_filter_intervals_compound_key(self): ds = hl.import_vcf(resource('sample.vcf'), min_partitions=20) - ds = (ds.annotate_rows(variant=hl.struct(locus=ds.locus, alleles=ds.alleles)) - .key_rows_by('locus', 'alleles')) - - intervals = [hl.Interval(hl.Struct(locus=hl.Locus('20', 10639222), alleles=['A', 'T']), - hl.Struct(locus=hl.Locus('20', 10644700), alleles=['A', 'T']))] + ds = ds.annotate_rows(variant=hl.struct(locus=ds.locus, alleles=ds.alleles)).key_rows_by('locus', 'alleles') + + intervals = [ + hl.Interval( + hl.Struct(locus=hl.Locus('20', 10639222), alleles=['A', 'T']), + hl.Struct(locus=hl.Locus('20', 10644700), alleles=['A', 'T']), + ) + ] self.assertEqual(hl.filter_intervals(ds, intervals).count_rows(), 3) def test_summarize_variants(self): mt = hl.utils.range_matrix_table(3, 3) - variants = hl.literal({0: hl.Struct(locus=hl.Locus('1', 1), alleles=['A', 'T', 'C']), - 1: hl.Struct(locus=hl.Locus('2', 1), alleles=['A', 'AT', '@']), - 2: hl.Struct(locus=hl.Locus('2', 1), alleles=['AC', 'GT'])}) + variants = hl.literal({ + 0: hl.Struct(locus=hl.Locus('1', 1), alleles=['A', 'T', 'C']), + 1: hl.Struct(locus=hl.Locus('2', 1), alleles=['A', 'AT', '@']), + 2: hl.Struct(locus=hl.Locus('2', 1), alleles=['AC', 'GT']), + }) mt = mt.annotate_rows(**variants[mt.row_idx]).key_rows_by('locus', 'alleles') r = hl.summarize_variants(mt, show=False) self.assertEqual(r.n_variants, 3) @@ -217,7 +231,7 @@ def test_verify_biallelic(self): def test_lambda_gc(self): N = 5000000 - ht = hl.utils.range_table(N).annotate(x = hl.scan.count() / N, x2 = (hl.scan.count() / N) ** 1.5) + ht = hl.utils.range_table(N).annotate(x=hl.scan.count() / N, x2=(hl.scan.count() / N) ** 1.5) lgc = hl.lambda_gc(ht.x) lgc2 = hl.lambda_gc(ht.x2) self.assertAlmostEqual(lgc, 1, places=1) # approximate, 1 place is safe @@ -225,7 +239,7 @@ def test_lambda_gc(self): def test_lambda_gc_nans(self): N = 5000000 - ht = 
hl.utils.range_table(N).annotate(x = hl.scan.count() / N, is_even=hl.scan.count() % 2 == 0) + ht = hl.utils.range_table(N).annotate(x=hl.scan.count() / N, is_even=hl.scan.count() % 2 == 0) lgc_nan = hl.lambda_gc(hl.case().when(ht.is_even, hl.float('nan')).default(ht.x)) self.assertAlmostEqual(lgc_nan, 1, places=1) # approximate, 1 place is safe @@ -234,21 +248,24 @@ def test_segment_intervals(self): [ hl.struct(interval=hl.interval(0, 10)), hl.struct(interval=hl.interval(20, 50)), - hl.struct(interval=hl.interval(52, 52)) + hl.struct(interval=hl.interval(52, 52)), ], schema=hl.tstruct(interval=hl.tinterval(hl.tint32)), - key='interval' + key='interval', ) points1 = [-1, 5, 30, 40, 52, 53] segmented1 = hl.segment_intervals(intervals, points1) - assert segmented1.aggregate(hl.agg.collect(segmented1.interval) == [ - hl.interval(0, 5), - hl.interval(5, 10), - hl.interval(20, 30), - hl.interval(30, 40), - hl.interval(40, 50), - hl.interval(52, 52) - ]) + assert segmented1.aggregate( + hl.agg.collect(segmented1.interval) + == [ + hl.interval(0, 5), + hl.interval(5, 10), + hl.interval(20, 30), + hl.interval(30, 40), + hl.interval(40, 50), + hl.interval(52, 52), + ] + ) diff --git a/hail/python/test/hail/methods/test_pca.py b/hail/python/test/hail/methods/test_pca.py index bb18daf85d2..adf3b331672 100644 --- a/hail/python/test/hail/methods/test_pca.py +++ b/hail/python/test/hail/methods/test_pca.py @@ -1,10 +1,12 @@ import math -import pytest + import numpy as np +import pytest import hail as hl from hail.methods.pca import _make_tsm -from ..helpers import resource, fails_local_backend, skip_when_service_backend, test_timeout + +from ..helpers import fails_local_backend, resource, skip_when_service_backend, test_timeout @fails_local_backend() @@ -26,20 +28,16 @@ def test_hwe_normalized_pca(): @test_timeout(batch=10 * 60) def test_pca_against_numpy(): mt = hl.import_vcf(resource('tiny_m.vcf')) - mt = mt.annotate_rows(AC=hl.agg.sum(mt.GT.n_alt_alleles()), - n_called=hl.agg.count_where(hl.is_defined(mt.GT))) + mt = mt.annotate_rows(AC=hl.agg.sum(mt.GT.n_alt_alleles()), n_called=hl.agg.count_where(hl.is_defined(mt.GT))) n_rows = 3 n_cols = 4 k = 3 mean = mt.AC / mt.n_called eigen, scores, loadings = hl.pca( - hl.coalesce( - (mt.GT.n_alt_alleles() - mean) / hl.sqrt(mean * (2 - mean) * n_rows / 2), - 0 - ), + hl.coalesce((mt.GT.n_alt_alleles() - mean) / hl.sqrt(mean * (2 - mean) * n_rows / 2), 0), k=k, - compute_loadings=True + compute_loadings=True, ) hail_scores = scores.explode('scores').scores.collect() @@ -75,17 +73,13 @@ def concatToNumpy(blocks, horizontal=True): return np.concatenate(blocks, axis=1) mt = hl.import_vcf(resource('tiny_m.vcf')) - mt = mt.annotate_rows(AC=hl.agg.sum(mt.GT.n_alt_alleles()), - n_called=hl.agg.count_where(hl.is_defined(mt.GT))) + mt = mt.annotate_rows(AC=hl.agg.sum(mt.GT.n_alt_alleles()), n_called=hl.agg.count_where(hl.is_defined(mt.GT))) n_rows = 3 n_cols = 4 k = 3 mean = mt.AC / mt.n_called - float_expr = hl.coalesce( - (mt.GT.n_alt_alleles() - mean) / hl.sqrt(mean * (2 - mean) * n_rows / 2), - 0 - ) + float_expr = hl.coalesce((mt.GT.n_alt_alleles() - mean) / hl.sqrt(mean * (2 - mean) * n_rows / 2), 0) eigens, scores_t, loadings_t = hl._blanczos_pca(float_expr, k=k, q_iterations=7, compute_loadings=True) A = np.array(float_expr.collect()).reshape((3, 4)).T @@ -118,7 +112,7 @@ def normalize(a): np_eigenvalues = np.multiply(s, s) def bound(vs, us): # equation 12 from https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4827102/pdf/main.pdf - return 1/k * 
sum([np.linalg.norm(us.T @ vs[:,i]) for i in range(k)]) + return 1 / k * sum([np.linalg.norm(us.T @ vs[:, i]) for i in range(k)]) np.testing.assert_allclose(eigens, np_eigenvalues, rtol=0.05) assert bound(np_loadings, loadings) > 0.9 @@ -127,8 +121,8 @@ def bound(vs, us): # equation 12 from https://www.ncbi.nlm.nih.gov/pmc/articles def matrix_table_from_numpy(np_mat): rows, cols = np_mat.shape mt = hl.utils.range_matrix_table(rows, cols) - mt = mt.annotate_globals(entries_global = np_mat) - mt = mt.annotate_entries(ent = mt.entries_global[mt.row_idx, mt.col_idx]) + mt = mt.annotate_globals(entries_global=np_mat) + mt = mt.annotate_entries(ent=mt.entries_global[mt.row_idx, mt.col_idx]) return mt @@ -147,7 +141,9 @@ def test_blanczos_T(): A = (U * sigma) @ V.T mt_A_T = matrix_table_from_numpy(A.T) - eigenvalues, scores, loadings = hl._blanczos_pca(mt_A_T.ent, k=k, oversampling_param=k, q_iterations=4, compute_loadings=True, transpose=True) + eigenvalues, scores, loadings = hl._blanczos_pca( + mt_A_T.ent, k=k, oversampling_param=k, q_iterations=4, compute_loadings=True, transpose=True + ) singulars = np.sqrt(eigenvalues) hail_V = (np.array(scores.scores.collect()) / singulars).T hail_U = np.array(loadings.loadings.collect()) @@ -157,12 +153,15 @@ def test_blanczos_T(): np.testing.assert_allclose(singulars, sigma[:k], rtol=1e-01) -@pytest.mark.parametrize("compute_loadings,compute_scores,transpose", [ - (compute_loadings, compute_scores, transpose) - for compute_loadings in [True, False] - for compute_scores in [True, False] - for transpose in [True, False] -]) +@pytest.mark.parametrize( + "compute_loadings,compute_scores,transpose", + [ + (compute_loadings, compute_scores, transpose) + for compute_loadings in [True, False] + for compute_scores in [True, False] + for transpose in [True, False] + ], +) @skip_when_service_backend() def test_blanczos_flags(compute_loadings, compute_scores, transpose): k, m, n = 10, 100, 200 @@ -177,11 +176,18 @@ def test_blanczos_flags(compute_loadings, compute_scores, transpose): # compare absolute values to account for +-1 indeterminacy factor in singular vectors U = np.abs(U[:, :k]) V = np.abs(V[:, :k]) - Usigma = U * sigma[:k] Vsigma = V * sigma[:k] mt = mt_A_T if transpose else mt_A - eigenvalues, scores, loadings = hl._blanczos_pca(mt.ent, k=k, oversampling_param=k, q_iterations=4, compute_loadings=compute_loadings, compute_scores=compute_scores, transpose=transpose) + eigenvalues, scores, loadings = hl._blanczos_pca( + mt.ent, + k=k, + oversampling_param=k, + q_iterations=4, + compute_loadings=compute_loadings, + compute_scores=compute_scores, + transpose=transpose, + ) if compute_loadings: loadings = np.array(loadings.loadings.collect()) np.testing.assert_allclose(np.abs(loadings), U, rtol=1e-02) @@ -207,18 +213,22 @@ def spectra_helper(spec_func, triplet): A = U @ sigma @ V.T mt_A = matrix_table_from_numpy(A) - eigenvalues, scores, loadings = hl._blanczos_pca(mt_A.ent, k=k, oversampling_param=k, compute_loadings=True, q_iterations=4) + eigenvalues, scores, loadings = hl._blanczos_pca( + mt_A.ent, k=k, oversampling_param=k, compute_loadings=True, q_iterations=4 + ) singulars = np.sqrt(eigenvalues) hail_V = (np.array(scores.scores.collect()) / singulars).T hail_U = np.array(loadings.loadings.collect()) approx_A = hail_U @ np.diag(singulars) @ hail_V norm_of_diff = np.linalg.norm(A - approx_A, 2) - np.testing.assert_allclose(norm_of_diff, spec_func(k + 1, k), rtol=1e-02, err_msg=f"Norm test failed on triplet {triplet} ") + np.testing.assert_allclose( + 
norm_of_diff, spec_func(k + 1, k), rtol=1e-02, err_msg=f"Norm test failed on triplet {triplet} " + ) np.testing.assert_allclose(singulars, np.diag(sigma)[:k], rtol=1e-01, err_msg=f"Failed on triplet {triplet}") def spec1(j, k): - return 1/j + return 1 / j def spec2(j, k): @@ -227,19 +237,19 @@ def spec2(j, k): if j <= k: return 2 * 10**-5 else: - return (10**-5) * (k + 1)/j + return (10**-5) * (k + 1) / j def spec3(j, k): if j <= k: - return 10**(-5*(j-1)/(k-1)) + return 10 ** (-5 * (j - 1) / (k - 1)) else: - return (10**-5)*(k+1)/j + return (10**-5) * (k + 1) / j def spec4(j, k): if j <= k: - return 10**(-5*(j-1)/(k-1)) + return 10 ** (-5 * (j - 1) / (k - 1)) elif j == (k + 1): return 10**-5 else: @@ -248,9 +258,9 @@ def spec4(j, k): def spec5(j, k): if j <= k: - return 10**-5 + (1 - 10**-5)*(k - j)/(k - 1) + return 10**-5 + (1 - 10**-5) * (k - j) / (k - 1) else: - return 10**-5 * math.sqrt((k + 1)/j) + return 10**-5 * math.sqrt((k + 1) / j) @pytest.mark.parametrize("triplet", dim_triplets) @@ -287,7 +297,7 @@ def spectral_moments_helper(spec_func): for triplet in [(20, 1000, 1000)]: k, m, n = triplet min_dim = min(m, n) - sigma = np.diag([spec_func(i+1, k) for i in range(min_dim)]) + sigma = np.diag([spec_func(i + 1, k) for i in range(min_dim)]) seed = 1025 np.random.seed(seed) U = np.linalg.qr(np.random.normal(0, 1, (m, min_dim)))[0] @@ -296,7 +306,7 @@ def spectral_moments_helper(spec_func): mt_A = matrix_table_from_numpy(A) moments, stdevs = hl._spectral_moments(_make_tsm(mt_A.ent, 128), 7) - true_moments = np.array([np.sum(np.power(sigma, 2*i)) for i in range(1, 8)]) + true_moments = np.array([np.sum(np.power(sigma, 2 * i)) for i in range(1, 8)]) np.testing.assert_allclose(moments, true_moments, rtol=2e-01) @@ -334,7 +344,7 @@ def spectra_and_moments_helper(spec_func): for triplet in [(20, 1000, 1000)]: k, m, n = triplet min_dim = min(m, n) - sigma = np.diag([spec_func(i+1, k) for i in range(min_dim)]) + sigma = np.diag([spec_func(i + 1, k) for i in range(min_dim)]) seed = 1025 np.random.seed(seed) U = np.linalg.qr(np.random.normal(0, 1, (m, min_dim)))[0] @@ -342,16 +352,20 @@ def spectra_and_moments_helper(spec_func): A = U @ sigma @ V.T mt_A = matrix_table_from_numpy(A) - eigenvalues, scores, loadings, moments, stdevs = hl._pca_and_moments(_make_tsm(mt_A.ent, 128), k=k, num_moments=7, oversampling_param=k, compute_loadings=True, q_iterations=4) + eigenvalues, scores, loadings, moments, stdevs = hl._pca_and_moments( + _make_tsm(mt_A.ent, 128), k=k, num_moments=7, oversampling_param=k, compute_loadings=True, q_iterations=4 + ) singulars = np.sqrt(eigenvalues) hail_V = (np.array(scores.scores.collect()) / singulars).T hail_U = np.array(loadings.loadings.collect()) approx_A = hail_U @ np.diag(singulars) @ hail_V norm_of_diff = np.linalg.norm(A - approx_A, 2) - np.testing.assert_allclose(norm_of_diff, spec_func(k + 1, k), rtol=1e-02, err_msg=f"Norm test failed on triplet {triplet}") + np.testing.assert_allclose( + norm_of_diff, spec_func(k + 1, k), rtol=1e-02, err_msg=f"Norm test failed on triplet {triplet}" + ) np.testing.assert_allclose(singulars, np.diag(sigma)[:k], rtol=1e-01, err_msg=f"Failed on triplet {triplet}") - true_moments = np.array([np.sum(np.power(sigma, 2*i)) for i in range(1, 8)]) + true_moments = np.array([np.sum(np.power(sigma, 2 * i)) for i in range(1, 8)]) np.testing.assert_allclose(moments, true_moments, rtol=1e-04) diff --git a/hail/python/test/hail/methods/test_qc.py b/hail/python/test/hail/methods/test_qc.py index a7a953b4b75..f4aaed7ae40 100644 --- 
a/hail/python/test/hail/methods/test_qc.py +++ b/hail/python/test/hail/methods/test_qc.py @@ -1,15 +1,47 @@ +import os import unittest +import pytest + import hail as hl import hail.expr.aggregators as agg -from hail.utils.misc import new_temp_file -from ..helpers import * +from hail.methods.qc import VEPConfigGRCh37Version85, VEPConfigGRCh38Version95 +from ..helpers import ( + get_dataset, + qobtest, + resource, + set_gcs_requester_pays_configuration, + skip_unless_service_backend, + test_timeout, +) GCS_REQUESTER_PAYS_PROJECT = os.environ.get('GCS_REQUESTER_PAYS_PROJECT') class Tests(unittest.TestCase): + @property + def vep_config_grch37_85(self): + return VEPConfigGRCh37Version85( + data_bucket='hail-qob-vep-grch37-us-central1', + data_mount='/vep_data/', + image=os.environ['HAIL_GENETICS_VEP_GRCH37_85_IMAGE'], + regions=['us-central1'], + cloud='gcp', + data_bucket_is_requester_pays=True, + ) + + @property + def vep_config_grch38_95(self): + return VEPConfigGRCh38Version95( + data_bucket='hail-qob-vep-grch38-us-central1', + data_mount='/vep_data/', + image=os.environ['HAIL_GENETICS_VEP_GRCH38_95_IMAGE'], + regions=['us-central1'], + cloud='gcp', + data_bucket_is_requester_pays=True, + ) + @qobtest def test_sample_qc(self): data = [ @@ -116,7 +148,9 @@ def test_variant_qc(self): def test_variant_qc_alleles_field(self): mt = hl.balding_nichols_model(1, 1, 1) mt = mt.key_rows_by().drop('alleles') - with pytest.raises(ValueError, match="Method 'variant_qc' requires a field 'alleles' \\(type 'array'\\).*"): + with pytest.raises( + ValueError, match="Method 'variant_qc' requires a field 'alleles' \\(type 'array'\\).*" + ): hl.variant_qc(mt).variant_qc.collect() mt = hl.balding_nichols_model(1, 1, 1) @@ -130,13 +164,14 @@ def test_concordance(self): self.assertEqual(sum([sum(glob_conc[i]) for i in range(5)]), dataset.count_rows() * dataset.count_cols()) - counts = dataset.aggregate_entries(hl.Struct(n_het=agg.filter(dataset.GT.is_het(), agg.count()), - n_hom_ref=agg.filter(dataset.GT.is_hom_ref(), - agg.count()), - n_hom_var=agg.filter(dataset.GT.is_hom_var(), - agg.count()), - nNoCall=agg.filter(hl.is_missing(dataset.GT), - agg.count()))) + counts = dataset.aggregate_entries( + hl.Struct( + n_het=agg.filter(dataset.GT.is_het(), agg.count()), + n_hom_ref=agg.filter(dataset.GT.is_hom_ref(), agg.count()), + n_hom_var=agg.filter(dataset.GT.is_hom_var(), agg.count()), + nNoCall=agg.filter(hl.is_missing(dataset.GT), agg.count()), + ) + ) self.assertEqual(glob_conc[0][0], 0) self.assertEqual(glob_conc[1][1], counts.nNoCall) @@ -167,7 +202,7 @@ def test_concordance_n_discordant_2(self): hl.Struct(**{'locus': hl.Locus('1', 100), 'alleles': ['A', 'T'], 's': '4', 'GT': hl.Call([1, 1])}), hl.Struct(**{'locus': hl.Locus('1', 101), 'alleles': ['A', 'T'], 's': '1', 'GT': hl.Call([1, 1])}), ] - rows2=[ + rows2 = [ hl.Struct(**{'locus': hl.Locus('1', 100), 'alleles': ['A', 'T'], 's': '1', 'GT': None}), hl.Struct(**{'locus': hl.Locus('1', 100), 'alleles': ['A', 'T'], 's': '2', 'GT': hl.Call([0, 1])}), hl.Struct(**{'locus': hl.Locus('1', 100), 'alleles': ['A', 'T'], 's': '3', 'GT': hl.Call([0, 1])}), @@ -180,56 +215,42 @@ def make_mt(rows): global_conc_2, cols_conc_2, rows_conc_2 = hl.concordance(make_mt(rows1), make_mt(rows2)) assert cols_conc_2.collect() == [ - hl.Struct(s='1', - concordance=[[0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 1, 0, 0, 0], - [0, 0, 0, 0, 0], - [1, 0, 0, 0, 0]], - n_discordant=0), - hl.Struct(s='2', - concordance=[[1, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 1, 0], - [0, 0, 0, 0, 
0], - [0, 0, 0, 0, 0]], - n_discordant=1), - hl.Struct(s='3', - concordance=[[1, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 1, 0]], - n_discordant=1), - hl.Struct(s='4', - concordance=[[1, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 1]], - n_discordant=0), + hl.Struct( + s='1', + concordance=[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]], + n_discordant=0, + ), + hl.Struct( + s='2', + concordance=[[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], + n_discordant=1, + ), + hl.Struct( + s='3', + concordance=[[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 1, 0]], + n_discordant=1, + ), + hl.Struct( + s='4', + concordance=[[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 1]], + n_discordant=0, + ), ] - assert global_conc_2 == [[3, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 1, 0, 1, 0], - [0, 0, 0, 0, 0], - [1, 0, 0, 1, 1]] + assert global_conc_2 == [[3, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0], [1, 0, 0, 1, 1]] assert rows_conc_2.collect() == [ - hl.Struct(locus=hl.Locus('1', 100), alleles=['A', 'T'], - concordance=[[0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 1, 0, 1, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 1, 1]], - n_discordant=2), - hl.Struct(locus=hl.Locus('1', 101), alleles=['A', 'T'], - concordance=[[3, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0], - [1, 0, 0, 0, 0]], - n_discordant=0), + hl.Struct( + locus=hl.Locus('1', 100), + alleles=['A', 'T'], + concordance=[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 1, 1]], + n_discordant=2, + ), + hl.Struct( + locus=hl.Locus('1', 101), + alleles=['A', 'T'], + concordance=[[3, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0]], + n_discordant=0, + ), ] def test_concordance_no_values_doesnt_error(self): @@ -240,13 +261,10 @@ def test_concordance_no_values_doesnt_error(self): def test_filter_alleles(self): # poor man's Gen - paths = [resource('sample.vcf'), - resource('multipleChromosomes.vcf'), - resource('sample2.vcf')] + paths = [resource('sample.vcf'), resource('multipleChromosomes.vcf'), resource('sample2.vcf')] for path in paths: ds = hl.import_vcf(path) - self.assertEqual( - hl.filter_alleles(ds, lambda a, i: False).count_rows(), 0) + self.assertEqual(hl.filter_alleles(ds, lambda a, i: False).count_rows(), 0) self.assertEqual(hl.filter_alleles(ds, lambda a, i: True).count_rows(), ds.count_rows()) def test_filter_alleles_hts_1(self): @@ -255,8 +273,9 @@ def test_filter_alleles_hts_1(self): self.assertTrue( hl.filter_alleles_hts(ds, lambda a, i: a == 'T', subset=True) - .drop('old_alleles', 'old_locus', 'new_to_old', 'old_to_new') - ._same(hl.import_vcf(resource('filter_alleles/keep_allele1_subset.vcf')))) + .drop('old_alleles', 'old_locus', 'new_to_old', 'old_to_new') + ._same(hl.import_vcf(resource('filter_alleles/keep_allele1_subset.vcf'))) + ) def test_filter_alleles_hts_2(self): # 1 variant: A:T,G @@ -264,8 +283,8 @@ def test_filter_alleles_hts_2(self): self.assertTrue( hl.filter_alleles_hts(ds, lambda a, i: a == 'G', subset=True) - .drop('old_alleles', 'old_locus', 'new_to_old', 'old_to_new') - ._same(hl.import_vcf(resource('filter_alleles/keep_allele2_subset.vcf'))) + .drop('old_alleles', 'old_locus', 'new_to_old', 'old_to_new') + ._same(hl.import_vcf(resource('filter_alleles/keep_allele2_subset.vcf'))) ) def 
test_filter_alleles_hts_3(self): @@ -274,8 +293,8 @@ def test_filter_alleles_hts_3(self): self.assertTrue( hl.filter_alleles_hts(ds, lambda a, i: a != 'G', subset=False) - .drop('old_alleles', 'old_locus', 'new_to_old', 'old_to_new') - ._same(hl.import_vcf(resource('filter_alleles/keep_allele1_downcode.vcf'))) + .drop('old_alleles', 'old_locus', 'new_to_old', 'old_to_new') + ._same(hl.import_vcf(resource('filter_alleles/keep_allele1_downcode.vcf'))) ) def test_filter_alleles_hts_4(self): @@ -284,8 +303,8 @@ def test_filter_alleles_hts_4(self): self.assertTrue( hl.filter_alleles_hts(ds, lambda a, i: a == 'G', subset=False) - .drop('old_alleles', 'old_locus', 'new_to_old', 'old_to_new') - ._same(hl.import_vcf(resource('filter_alleles/keep_allele2_downcode.vcf'))) + .drop('old_alleles', 'old_locus', 'new_to_old', 'old_to_new') + ._same(hl.import_vcf(resource('filter_alleles/keep_allele2_downcode.vcf'))) ) def test_sample_and_variant_qc_call_rate(self): @@ -311,151 +330,188 @@ def test_summarize_variants_ti_tv(self): def test_charr(self): mt = hl.import_vcf(resource('sample.vcf')) - es = mt.select_rows().entries() charr = hl.compute_charr(mt, ref_AF=0.9) d = charr.aggregate(hl.dict(hl.agg.collect((charr.s, charr.charr)))) - assert pytest.approx(d['C1046::HG02024'], abs=0.0001) == .00126 - assert pytest.approx(d['C1046::HG02025'], abs=0.0001) == .00124 + assert pytest.approx(d['C1046::HG02024'], abs=0.0001) == 0.00126 + assert pytest.approx(d['C1046::HG02025'], abs=0.0001) == 0.00124 + @qobtest @skip_unless_service_backend(clouds=['gcp']) @set_gcs_requester_pays_configuration(GCS_REQUESTER_PAYS_PROJECT) @test_timeout(batch=5 * 60) def test_vep_grch37_consequence_true(self): - gnomad_vep_result = hl.import_vcf(resource('sample.gnomad.exomes.r2.1.1.sites.chr1.vcf.gz'), reference_genome='GRCh37', force=True) - hail_vep_result = hl.vep(gnomad_vep_result, csq=True) + gnomad_vep_result = hl.import_vcf( + resource('sample.gnomad.exomes.r2.1.1.sites.chr1.vcf.gz'), reference_genome='GRCh37', force=True + ) + hail_vep_result = hl.vep(gnomad_vep_result, self.vep_config_grch37_85, csq=True) - expected = gnomad_vep_result.select_rows( - vep=gnomad_vep_result.info.vep.map(lambda x: x.split('|')[:8]) - ).rows() + expected = gnomad_vep_result.select_rows(vep=gnomad_vep_result.info.vep.map(lambda x: x.split('|')[:8])).rows() - actual = hail_vep_result.select_rows( - vep=hail_vep_result.vep.map(lambda x: x.split('|')[:8]) - ).rows().drop('vep_csq_header') + actual = ( + hail_vep_result.select_rows(vep=hail_vep_result.vep.map(lambda x: x.split('|')[:8])) + .rows() + .drop('vep_csq_header') + ) assert expected._same(actual) vep_csq_header = hl.eval(hail_vep_result.vep_csq_header) assert 'Consequence annotations from Ensembl VEP' in vep_csq_header, vep_csq_header + @qobtest @skip_unless_service_backend(clouds=['gcp']) @set_gcs_requester_pays_configuration(GCS_REQUESTER_PAYS_PROJECT) @test_timeout(batch=5 * 60) def test_vep_grch38_consequence_true(self): - gnomad_vep_result = hl.import_vcf(resource('sample.gnomad.genomes.r3.0.sites.chr1.vcf.gz'), reference_genome='GRCh38', force=True) - hail_vep_result = hl.vep(gnomad_vep_result, csq=True) + gnomad_vep_result = hl.import_vcf( + resource('sample.gnomad.genomes.r3.0.sites.chr1.vcf.gz'), reference_genome='GRCh38', force=True + ) + hail_vep_result = hl.vep(gnomad_vep_result, self.vep_config_grch38_95, csq=True) expected = gnomad_vep_result.select_rows( - vep=gnomad_vep_result.info.vep.map(lambda x: x.split('|')[:8]) + vep=gnomad_vep_result.info.vep.map(lambda x: 
x.split(r'\|')[:8]) ).rows() - actual = hail_vep_result.select_rows( - vep=hail_vep_result.vep.map(lambda x: x.split('|')[:8]) - ).rows().drop('vep_csq_header') + actual = ( + hail_vep_result.select_rows(vep=hail_vep_result.vep.map(lambda x: x.split(r'\|')[:8])) + .rows() + .drop('vep_csq_header') + ) assert expected._same(actual) vep_csq_header = hl.eval(hail_vep_result.vep_csq_header) assert 'Consequence annotations from Ensembl VEP' in vep_csq_header, vep_csq_header + @qobtest @skip_unless_service_backend(clouds=['gcp']) @set_gcs_requester_pays_configuration(GCS_REQUESTER_PAYS_PROJECT) @test_timeout(batch=5 * 60) def test_vep_grch37_consequence_false(self): - mt = hl.import_vcf(resource('sample.gnomad.exomes.r2.1.1.sites.chr1.vcf.gz'), reference_genome='GRCh37', force=True) - hail_vep_result = hl.vep(mt, csq=False) + mt = hl.import_vcf( + resource('sample.gnomad.exomes.r2.1.1.sites.chr1.vcf.gz'), reference_genome='GRCh37', force=True + ) + hail_vep_result = hl.vep(mt, self.vep_config_grch37_85, csq=False) ht = hail_vep_result.rows() ht = ht.select(variant_class=ht.vep.variant_class) result = ht.head(1).collect()[0] assert result.variant_class == 'SNV', result + @qobtest @skip_unless_service_backend(clouds=['gcp']) @set_gcs_requester_pays_configuration(GCS_REQUESTER_PAYS_PROJECT) @test_timeout(batch=5 * 60) def test_vep_grch38_consequence_false(self): - mt = hl.import_vcf(resource('sample.gnomad.genomes.r3.0.sites.chr1.vcf.gz'), reference_genome='GRCh38', force=True) - hail_vep_result = hl.vep(mt, csq=False) + mt = hl.import_vcf( + resource('sample.gnomad.genomes.r3.0.sites.chr1.vcf.gz'), reference_genome='GRCh38', force=True + ) + hail_vep_result = hl.vep(mt, self.vep_config_grch38_95, csq=False) ht = hail_vep_result.rows() ht = ht.select(variant_class=ht.vep.variant_class) result = ht.head(1).collect()[0] assert result.variant_class == 'SNV', result + @qobtest @skip_unless_service_backend(clouds=['gcp']) @set_gcs_requester_pays_configuration(GCS_REQUESTER_PAYS_PROJECT) @test_timeout(batch=5 * 60) def test_vep_grch37_against_dataproc(self): mt = hl.import_vcf(resource('sample.vcf.gz'), reference_genome='GRCh37', force_bgz=True, n_partitions=4) mt = mt.head(20) - hail_vep_result = hl.vep(mt) + hail_vep_result = hl.vep(mt, self.vep_config_grch37_85) initial_vep_dtype = hail_vep_result.vep.dtype - hail_vep_result = hail_vep_result.annotate_rows(vep=hail_vep_result.vep.annotate( - input=hl.str('\t').join([ - hail_vep_result.locus.contig, - hl.str(hail_vep_result.locus.position), - ".", - hail_vep_result.alleles[0], - hail_vep_result.alleles[1], - ".", - ".", - "GT", - ]) - )) + hail_vep_result = hail_vep_result.annotate_rows( + vep=hail_vep_result.vep.annotate( + input=hl.str('\t').join([ + hail_vep_result.locus.contig, + hl.str(hail_vep_result.locus.position), + ".", + hail_vep_result.alleles[0], + hail_vep_result.alleles[1], + ".", + ".", + "GT", + ]) + ) + ) hail_vep_result = hail_vep_result.rows().select('vep') def parse_lof_info_into_dict(ht): def tuple2(arr): return hl.tuple([arr[0], arr[1]]) - return ht.annotate(vep=ht.vep.annotate( - transcript_consequences=ht.vep.transcript_consequences.map( - lambda csq: csq.annotate( - lof_info=hl.or_missing(csq.lof_info != 'null', - hl.dict(csq.lof_info.split(',').map(lambda kv: tuple2(kv.split(':'))))))))) + return ht.annotate( + vep=ht.vep.annotate( + transcript_consequences=ht.vep.transcript_consequences.map( + lambda csq: csq.annotate( + lof_info=hl.or_missing( + csq.lof_info != 'null', + hl.dict(csq.lof_info.split(',').map(lambda kv: 
tuple2(kv.split(':')))), + ) + ) + ) + ) + ) hail_vep_result = parse_lof_info_into_dict(hail_vep_result) - dataproc_result = hl.import_table(resource('dataproc_vep_grch37_annotations.tsv.gz'), - key=['locus', 'alleles'], - types={'locus': hl.tlocus('GRCh37'), 'alleles': hl.tarray(hl.tstr), - 'vep': initial_vep_dtype}, force=True) + dataproc_result = hl.import_table( + resource('dataproc_vep_grch37_annotations.tsv.gz'), + key=['locus', 'alleles'], + types={'locus': hl.tlocus('GRCh37'), 'alleles': hl.tarray(hl.tstr), 'vep': initial_vep_dtype}, + force=True, + ) dataproc_result = parse_lof_info_into_dict(dataproc_result) assert hail_vep_result._same(dataproc_result) + @qobtest @skip_unless_service_backend(clouds=['gcp']) @set_gcs_requester_pays_configuration(GCS_REQUESTER_PAYS_PROJECT) @test_timeout(batch=5 * 60) def test_vep_grch38_against_dataproc(self): - dataproc_result = hl.import_table(resource('dataproc_vep_grch38_annotations.tsv.gz'), - key=['locus', 'alleles'], - types={'locus': hl.tlocus('GRCh38'), 'alleles': hl.tarray(hl.tstr), - 'vep': hl.tstr}, force=True) + dataproc_result = hl.import_table( + resource('dataproc_vep_grch38_annotations.tsv.gz'), + key=['locus', 'alleles'], + types={'locus': hl.tlocus('GRCh38'), 'alleles': hl.tarray(hl.tstr), 'vep': hl.tstr}, + force=True, + ) loftee_variants = dataproc_result.select() - hail_vep_result = hl.vep(loftee_variants) - hail_vep_result = hail_vep_result.annotate(vep=hail_vep_result.vep.annotate( - input=hl.str('\t').join([ - hail_vep_result.locus.contig, - hl.str(hail_vep_result.locus.position), - ".", - hail_vep_result.alleles[0], - hail_vep_result.alleles[1], - ".", - ".", - "GT", - ]) - )) + hail_vep_result = hl.vep(loftee_variants, self.vep_config_grch38_95) + hail_vep_result = hail_vep_result.annotate( + vep=hail_vep_result.vep.annotate( + input=hl.str('\t').join([ + hail_vep_result.locus.contig, + hl.str(hail_vep_result.locus.position), + ".", + hail_vep_result.alleles[0], + hail_vep_result.alleles[1], + ".", + ".", + "GT", + ]) + ) + ) hail_vep_result = hail_vep_result.select('vep') def parse_lof_info_into_dict(ht): def tuple2(arr): return hl.tuple([arr[0], arr[1]]) - return ht.annotate(vep=ht.vep.annotate( - transcript_consequences=ht.vep.transcript_consequences.map( - lambda csq: csq.annotate( - lof_info=hl.or_missing(csq.lof_info != 'null', - hl.dict(csq.lof_info.split(',').map(lambda kv: tuple2(kv.split(':'))))))))) + return ht.annotate( + vep=ht.vep.annotate( + transcript_consequences=ht.vep.transcript_consequences.map( + lambda csq: csq.annotate( + lof_info=hl.or_missing( + csq.lof_info != 'null', + hl.dict(csq.lof_info.split(',').map(lambda kv: tuple2(kv.split(':')))), + ) + ) + ) + ) + ) dataproc_result = dataproc_result.annotate(vep=hl.parse_json(dataproc_result.vep, hail_vep_result.vep.dtype)) @@ -463,3 +519,35 @@ def tuple2(arr): dataproc_result = parse_lof_info_into_dict(dataproc_result) assert hail_vep_result._same(dataproc_result) + + @qobtest + @skip_unless_service_backend(clouds=['gcp']) + @set_gcs_requester_pays_configuration(GCS_REQUESTER_PAYS_PROJECT) + @test_timeout(batch=5 * 60) + def test_vep_grch38_with_large_positions(self): + bad_variants = hl.import_table( + resource('vep_grch38_input_req_indexed_cache.tsv'), + key=['locus', 'alleles'], + types={'locus': hl.tlocus('GRCh38'), 'alleles': hl.tarray(hl.tstr)}, + force=True, + delimiter=' ', + ) + loftee_variants = bad_variants.select() + + hail_vep_result = hl.vep(loftee_variants, self.vep_config_grch38_95) + hail_vep_result = hail_vep_result.annotate( + 
vep=hail_vep_result.vep.annotate( + input=hl.str('\t').join([ + hail_vep_result.locus.contig, + hl.str(hail_vep_result.locus.position), + ".", + hail_vep_result.alleles[0], + hail_vep_result.alleles[1], + ".", + ".", + "GT", + ]) + ) + ) + hail_vep_result = hail_vep_result.select('vep') + hail_vep_result.collect() diff --git a/hail/python/test/hail/methods/test_simulation.py b/hail/python/test/hail/methods/test_simulation.py index c058eeecf9e..9bbc6fc9eb5 100644 --- a/hail/python/test/hail/methods/test_simulation.py +++ b/hail/python/test/hail/methods/test_simulation.py @@ -1,6 +1,6 @@ import hail as hl -from ..helpers import * +from ..helpers import get_dataset def test_mating_simulation(): @@ -8,10 +8,21 @@ def test_mating_simulation(): n_samples = mt.count_cols() - assert hl.simulate_random_mating(mt, n_rounds=1, generation_size_multiplier=2, keep_founders=False).count_cols() == n_samples * 2 - assert hl.simulate_random_mating(mt, n_rounds=4, generation_size_multiplier=2, keep_founders=False).count_cols() == n_samples * 16 - assert hl.simulate_random_mating(mt, n_rounds=2, generation_size_multiplier=1, keep_founders=False).count_cols() == n_samples - assert hl.simulate_random_mating(mt, n_rounds=2, generation_size_multiplier=2, keep_founders=True).count_cols() == n_samples * 9 + assert ( + hl.simulate_random_mating(mt, n_rounds=1, generation_size_multiplier=2, keep_founders=False).count_cols() + == n_samples * 2 + ) + assert ( + hl.simulate_random_mating(mt, n_rounds=4, generation_size_multiplier=2, keep_founders=False).count_cols() + == n_samples * 16 + ) + assert ( + hl.simulate_random_mating(mt, n_rounds=2, generation_size_multiplier=1, keep_founders=False).count_cols() + == n_samples + ) + assert ( + hl.simulate_random_mating(mt, n_rounds=2, generation_size_multiplier=2, keep_founders=True).count_cols() + == n_samples * 9 + ) - - hl.simulate_random_mating(mt, n_rounds=2, generation_size_multiplier=0.5, keep_founders=True)._force_count_rows() \ No newline at end of file + hl.simulate_random_mating(mt, n_rounds=2, generation_size_multiplier=0.5, keep_founders=True)._force_count_rows() diff --git a/hail/python/test/hail/methods/test_skat.py b/hail/python/test/hail/methods/test_skat.py index c39dbc29d5c..b6718559ff8 100644 --- a/hail/python/test/hail/methods/test_skat.py +++ b/hail/python/test/hail/methods/test_skat.py @@ -1,59 +1,37 @@ -import hail as hl import pytest +import hail as hl from hail.utils import FatalError, HailUserError -from ..helpers import resource, test_timeout, qobtest +from ..helpers import qobtest, resource, test_timeout -@pytest.mark.parametrize("skat_model", [('hl._linear_skat', hl._linear_skat), - ('hl._logistic_skat', hl._logistic_skat)]) +@pytest.mark.parametrize("skat_model", [('hl._linear_skat', hl._linear_skat), ('hl._logistic_skat', hl._logistic_skat)]) def test_skat_negative_weights_errors(skat_model): skat_name, skat = skat_model genotypes = [ [2, 1, 1, 1, 0, 1, 1, 2, 1, 1, 2, 1, 0, 0, 1], [1, 0, 1, 1, 1, 2, 0, 2, 1, 1, 0, 1, 1, 0, 0], [0, 2, 0, 0, 2, 1, 1, 2, 2, 1, 1, 1, 0, 1, 1], - [1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0]] - covariates = [ - [1], - [1], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [1], - [1], - [0]] + [1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0], + ] + covariates = [[1], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [1], [0]] phenotypes = [0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0] weights = [-1, 0, 1, -1] mt = hl.utils.range_matrix_table(4, 15) - mt = mt.annotate_entries( - GT = 
hl.unphased_diploid_gt_index_call( - hl.literal(genotypes)[mt.row_idx][mt.col_idx]) - ) - mt = mt.annotate_cols( - phenotype = hl.literal(phenotypes)[mt.col_idx], - cov1 = hl.literal(covariates)[mt.col_idx][0] - ) - mt = mt.annotate_rows( - weight = hl.literal(weights)[mt.row_idx] - ) - mt = mt.annotate_globals( - group = 0 - ) + mt = mt.annotate_entries(GT=hl.unphased_diploid_gt_index_call(hl.literal(genotypes)[mt.row_idx][mt.col_idx])) + mt = mt.annotate_cols(phenotype=hl.literal(phenotypes)[mt.col_idx], cov1=hl.literal(covariates)[mt.col_idx][0]) + mt = mt.annotate_rows(weight=hl.literal(weights)[mt.row_idx]) + mt = mt.annotate_globals(group=0) ht = skat(mt.group, mt.weight, mt.phenotype, mt.GT.n_alt_alleles(), [1.0, mt.cov1]) try: ht.collect() except Exception as exc: - assert skat_name + ': every weight must be positive, in group 0, the weights were: [-1.0,0.0,1.0,-1.0]' in exc.args[0] + assert ( + skat_name + ': every weight must be positive, in group 0, the weights were: [-1.0,0.0,1.0,-1.0]' + in exc.args[0] + ) else: assert False @@ -64,46 +42,25 @@ def test_logistic_skat_phenotypes_are_binary(): [2, 1, 1, 1, 0, 1, 1, 2, 1, 1, 2, 1, 0, 0, 1], [1, 0, 1, 1, 1, 2, 0, 2, 1, 1, 0, 1, 1, 0, 0], [0, 2, 0, 0, 2, 1, 1, 2, 2, 1, 1, 1, 0, 1, 1], - [1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0]] - covariates = [ - [1], - [1], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [1], - [1], - [0]] + [1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0], + ] + covariates = [[1], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [1], [0]] phenotypes = [0, 0, 0, 3, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0] weights = [1, 1, 1, 1] mt = hl.utils.range_matrix_table(4, 15) - mt = mt.annotate_entries( - GT = hl.unphased_diploid_gt_index_call( - hl.literal(genotypes)[mt.row_idx][mt.col_idx]) - ) - mt = mt.annotate_cols( - phenotype = hl.literal(phenotypes)[mt.col_idx], - cov1 = hl.literal(covariates)[mt.col_idx][0] - ) - mt = mt.annotate_rows( - weight = hl.literal(weights)[mt.row_idx] - ) - mt = mt.annotate_globals( - group = 0 - ) + mt = mt.annotate_entries(GT=hl.unphased_diploid_gt_index_call(hl.literal(genotypes)[mt.row_idx][mt.col_idx])) + mt = mt.annotate_cols(phenotype=hl.literal(phenotypes)[mt.col_idx], cov1=hl.literal(covariates)[mt.col_idx][0]) + mt = mt.annotate_rows(weight=hl.literal(weights)[mt.row_idx]) + mt = mt.annotate_globals(group=0) try: ht = hl._logistic_skat(mt.group, mt.weight, mt.phenotype, mt.GT.n_alt_alleles(), [1.0, mt.cov1]) ht.collect() except Exception as exc: - assert 'hl._logistic_skat: phenotypes must either be True, False, 0, or 1, found: 3.0 of type float64' in exc.args[0] + assert ( + 'hl._logistic_skat: phenotypes must either be True, False, 0, or 1, found: 3.0 of type float64' + in exc.args[0] + ) else: assert False @@ -142,41 +99,17 @@ def test_logistic_skat_no_weights_R_truth(): [2, 1, 1, 1, 0, 1, 1, 2, 1, 1, 2, 1, 0, 0, 1], [1, 0, 1, 1, 1, 2, 0, 2, 1, 1, 0, 1, 1, 0, 0], [0, 2, 0, 0, 2, 1, 1, 2, 2, 1, 1, 1, 0, 1, 1], - [1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0]] - covariates = [ - [1], - [1], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [1], - [1], - [0]] + [1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0], + ] + covariates = [[1], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [1], [0]] phenotypes = [0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0] weights = [1, 1, 1, 1] mt = hl.utils.range_matrix_table(4, 15) - mt = mt.annotate_entries( - GT = hl.unphased_diploid_gt_index_call( - 
hl.literal(genotypes)[mt.row_idx][mt.col_idx]) - ) - mt = mt.annotate_cols( - phenotype = hl.literal(phenotypes)[mt.col_idx], - cov1 = hl.literal(covariates)[mt.col_idx][0] - ) - mt = mt.annotate_rows( - weight = hl.literal(weights)[mt.row_idx] - ) - mt = mt.annotate_globals( - group = 0 - ) + mt = mt.annotate_entries(GT=hl.unphased_diploid_gt_index_call(hl.literal(genotypes)[mt.row_idx][mt.col_idx])) + mt = mt.annotate_cols(phenotype=hl.literal(phenotypes)[mt.col_idx], cov1=hl.literal(covariates)[mt.col_idx][0]) + mt = mt.annotate_rows(weight=hl.literal(weights)[mt.row_idx]) + mt = mt.annotate_globals(group=0) ht = hl._logistic_skat(mt.group, mt.weight, mt.phenotype, mt.GT.n_alt_alleles(), [1.0, mt.cov1]) results = ht.collect() @@ -222,41 +155,17 @@ def test_logistic_skat_R_truth(): [2, 1, 1, 1, 0, 1, 1, 2, 1, 1, 2, 1, 0, 0, 1], [1, 0, 1, 1, 1, 2, 0, 2, 1, 1, 0, 1, 1, 0, 0], [0, 2, 0, 0, 2, 1, 1, 2, 2, 1, 1, 1, 0, 1, 1], - [1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0]] - covariates = [ - [1], - [1], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [0], - [1], - [1], - [0]] + [1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0], + ] + covariates = [[1], [1], [0], [0], [0], [0], [0], [0], [0], [0], [0], [0], [1], [1], [0]] phenotypes = [0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0] weights = [1, 2, 1, 1] mt = hl.utils.range_matrix_table(4, 15) - mt = mt.annotate_entries( - GT = hl.unphased_diploid_gt_index_call( - hl.literal(genotypes)[mt.row_idx][mt.col_idx]) - ) - mt = mt.annotate_cols( - phenotype = hl.literal(phenotypes)[mt.col_idx], - cov1 = hl.literal(covariates)[mt.col_idx][0] - ) - mt = mt.annotate_rows( - weight = hl.literal(weights)[mt.row_idx] - ) - mt = mt.annotate_globals( - group = 0 - ) + mt = mt.annotate_entries(GT=hl.unphased_diploid_gt_index_call(hl.literal(genotypes)[mt.row_idx][mt.col_idx])) + mt = mt.annotate_cols(phenotype=hl.literal(phenotypes)[mt.col_idx], cov1=hl.literal(covariates)[mt.col_idx][0]) + mt = mt.annotate_rows(weight=hl.literal(weights)[mt.row_idx]) + mt = mt.annotate_globals(group=0) ht = hl._logistic_skat(mt.group, mt.weight, mt.phenotype, mt.GT.n_alt_alleles(), [1.0, mt.cov1]) results = ht.collect() @@ -283,10 +192,12 @@ def test_logistic_skat_on_big_matrix(): expected_p_value = 2.697155e-24 expected_Q_value = 10046.37 - mt = hl.import_matrix_table(resource('skat_genotype_matrix_variants_are_rows.csv'), - delimiter=',', - row_fields={'row_idx': hl.tint64}, - row_key=['row_idx']) + mt = hl.import_matrix_table( + resource('skat_genotype_matrix_variants_are_rows.csv'), + delimiter=',', + row_fields={'row_idx': hl.tint64}, + row_key=['row_idx'], + ) mt = mt.key_cols_by(col_id=hl.int64(mt.col_id)) ht = hl.import_table(resource('skat_phenotypes.csv'), no_header=True, types={'f0': hl.tfloat}) @@ -323,36 +234,20 @@ def test_linear_skat_no_weights_R_truth(): expected_p_value = 0.2700286 expected_Q_value = 2.854975 - genotypes = [ - [0, 1, 0, 0, 0], - [1, 0, 0, 0, 0], - [0, 1, 2, 0, 2], - [1, 0, 0, 2, 1]] - covariates = [ - [1, 2], - [3, 4], - [0, 9], - [6, 1], - [1, 1]] + genotypes = [[0, 1, 0, 0, 0], [1, 0, 0, 0, 0], [0, 1, 2, 0, 2], [1, 0, 0, 2, 1]] + covariates = [[1, 2], [3, 4], [0, 9], [6, 1], [1, 1]] phenotypes = [3, 4, 6, 4, 1] weights = [1, 1, 1, 1] mt = hl.utils.range_matrix_table(4, 5) - mt = mt.annotate_entries( - GT = hl.unphased_diploid_gt_index_call( - hl.literal(genotypes)[mt.row_idx][mt.col_idx]) - ) + mt = mt.annotate_entries(GT=hl.unphased_diploid_gt_index_call(hl.literal(genotypes)[mt.row_idx][mt.col_idx])) mt = mt.annotate_cols( - 
phenotype = hl.literal(phenotypes)[mt.col_idx], - cov1 = hl.literal(covariates)[mt.col_idx][0], - cov2 = hl.literal(covariates)[mt.col_idx][1] - ) - mt = mt.annotate_rows( - weight = hl.literal(weights)[mt.row_idx] - ) - mt = mt.annotate_globals( - group = 0 + phenotype=hl.literal(phenotypes)[mt.col_idx], + cov1=hl.literal(covariates)[mt.col_idx][0], + cov2=hl.literal(covariates)[mt.col_idx][1], ) + mt = mt.annotate_rows(weight=hl.literal(weights)[mt.row_idx]) + mt = mt.annotate_globals(group=0) ht = hl._linear_skat(mt.group, mt.weight, mt.phenotype, mt.GT.n_alt_alleles(), [1.0, mt.cov1, mt.cov2]) results = ht.collect() @@ -383,36 +278,20 @@ def test_linear_skat_R_truth(): expected_p_value = 0.2497489 expected_Q_value = 3.404505 - genotypes = [ - [0, 1, 0, 0, 0], - [1, 0, 0, 0, 0], - [0, 1, 2, 0, 2], - [1, 0, 0, 2, 1]] - covariates = [ - [1, 2], - [3, 4], - [0, 9], - [6, 1], - [1, 1]] + genotypes = [[0, 1, 0, 0, 0], [1, 0, 0, 0, 0], [0, 1, 2, 0, 2], [1, 0, 0, 2, 1]] + covariates = [[1, 2], [3, 4], [0, 9], [6, 1], [1, 1]] phenotypes = [3, 4, 6, 4, 1] weights = [1, 2, 1, 1] mt = hl.utils.range_matrix_table(4, 5) - mt = mt.annotate_entries( - GT = hl.unphased_diploid_gt_index_call( - hl.literal(genotypes)[mt.row_idx][mt.col_idx]) - ) + mt = mt.annotate_entries(GT=hl.unphased_diploid_gt_index_call(hl.literal(genotypes)[mt.row_idx][mt.col_idx])) mt = mt.annotate_cols( - phenotype = hl.literal(phenotypes)[mt.col_idx], - cov1 = hl.literal(covariates)[mt.col_idx][0], - cov2 = hl.literal(covariates)[mt.col_idx][1] - ) - mt = mt.annotate_rows( - weight = hl.literal(weights)[mt.row_idx] - ) - mt = mt.annotate_globals( - group = 0 + phenotype=hl.literal(phenotypes)[mt.col_idx], + cov1=hl.literal(covariates)[mt.col_idx][0], + cov2=hl.literal(covariates)[mt.col_idx][1], ) + mt = mt.annotate_rows(weight=hl.literal(weights)[mt.row_idx]) + mt = mt.annotate_globals(group=0) ht = hl._linear_skat(mt.group, mt.weight, mt.phenotype, mt.GT.n_alt_alleles(), [1.0, mt.cov1, mt.cov2]) results = ht.collect() @@ -459,10 +338,12 @@ def test_linear_skat_on_big_matrix(): expected_p_value = 4.072862e-57 expected_Q_value = 125247 - mt = hl.import_matrix_table(resource('skat_genotype_matrix_variants_are_rows.csv'), - delimiter=',', - row_fields={'row_idx': hl.tint64}, - row_key=['row_idx']) + mt = hl.import_matrix_table( + resource('skat_genotype_matrix_variants_are_rows.csv'), + delimiter=',', + row_fields={'row_idx': hl.tint64}, + row_key=['row_idx'], + ) mt = mt.key_cols_by(col_id=hl.int64(mt.col_id)) ht = hl.import_table(resource('skat_phenotypes.csv'), no_header=True, types={'f0': hl.tfloat}) @@ -484,147 +365,118 @@ def test_linear_skat_on_big_matrix(): def skat_dataset(): ds2 = hl.import_vcf(resource('sample2.vcf')) - covariates = (hl.import_table(resource("skat.cov"), impute=True) - .key_by("Sample")) + covariates = hl.import_table(resource("skat.cov"), impute=True).key_by("Sample") - phenotypes = (hl.import_table(resource("skat.pheno"), - types={"Pheno": hl.tfloat64}, - missing="0") - .key_by("Sample")) + phenotypes = hl.import_table(resource("skat.pheno"), types={"Pheno": hl.tfloat64}, missing="0").key_by("Sample") - intervals = (hl.import_locus_intervals(resource("skat.interval_list"))) + intervals = hl.import_locus_intervals(resource("skat.interval_list")) - weights = (hl.import_table(resource("skat.weights"), - types={"locus": hl.tlocus(), - "weight": hl.tfloat64}) - .key_by("locus")) + weights = hl.import_table(resource("skat.weights"), types={"locus": hl.tlocus(), "weight": hl.tfloat64}).key_by( + "locus" + ) 
     ds = hl.split_multi_hts(ds2)
-    ds = ds.annotate_rows(gene=intervals[ds.locus],
-                          weight=weights[ds.locus].weight)
-    ds = ds.annotate_cols(pheno=phenotypes[ds.s].Pheno,
-                          cov=covariates[ds.s])
-    ds = ds.annotate_cols(pheno=hl.if_else(ds.pheno == 1.0,
-                                           False,
-                                           hl.if_else(ds.pheno == 2.0,
-                                                      True,
-                                                      hl.missing(hl.tbool))))
+    ds = ds.annotate_rows(gene=intervals[ds.locus], weight=weights[ds.locus].weight)
+    ds = ds.annotate_cols(pheno=phenotypes[ds.s].Pheno, cov=covariates[ds.s])
+    ds = ds.annotate_cols(
+        pheno=hl.if_else(ds.pheno == 1.0, False, hl.if_else(ds.pheno == 2.0, True, hl.missing(hl.tbool)))
+    )
     return ds


 @test_timeout(3 * 60)
 def test_skat_1():
     ds = skat_dataset()
-    hl.skat(key_expr=ds.gene,
-            weight_expr=ds.weight,
-            y=ds.pheno,
-            x=ds.GT.n_alt_alleles(),
-            covariates=[1.0],
-            logistic=False)._force_count()
+    hl.skat(
+        key_expr=ds.gene, weight_expr=ds.weight, y=ds.pheno, x=ds.GT.n_alt_alleles(), covariates=[1.0], logistic=False
+    )._force_count()


 @test_timeout(3 * 60)
 def test_skat_2():
     ds = skat_dataset()
-    hl.skat(key_expr=ds.gene,
-            weight_expr=ds.weight,
-            y=ds.pheno,
-            x=ds.GT.n_alt_alleles(),
-            covariates=[1.0],
-            logistic=True)._force_count()
+    hl.skat(
+        key_expr=ds.gene, weight_expr=ds.weight, y=ds.pheno, x=ds.GT.n_alt_alleles(), covariates=[1.0], logistic=True
+    )._force_count()
+

 @test_timeout(3 * 60)
 def test_skat_3():
     ds = skat_dataset()
-    hl.skat(key_expr=ds.gene,
-            weight_expr=ds.weight,
-            y=ds.pheno,
-            x=ds.GT.n_alt_alleles(),
-            covariates=[1.0, ds.cov.Cov1, ds.cov.Cov2],
-            logistic=False)._force_count()
+    hl.skat(
+        key_expr=ds.gene,
+        weight_expr=ds.weight,
+        y=ds.pheno,
+        x=ds.GT.n_alt_alleles(),
+        covariates=[1.0, ds.cov.Cov1, ds.cov.Cov2],
+        logistic=False,
+    )._force_count()
+

 @test_timeout(3 * 60)
 def test_skat_4():
     ds = skat_dataset()
-    hl.skat(key_expr=ds.gene,
-            weight_expr=ds.weight,
-            y=ds.pheno,
-            x=hl.pl_dosage(ds.PL),
-            covariates=[1.0, ds.cov.Cov1, ds.cov.Cov2],
-            logistic=True)._force_count()
+    hl.skat(
+        key_expr=ds.gene,
+        weight_expr=ds.weight,
+        y=ds.pheno,
+        x=hl.pl_dosage(ds.PL),
+        covariates=[1.0, ds.cov.Cov1, ds.cov.Cov2],
+        logistic=True,
+    )._force_count()
+

 @test_timeout(3 * 60)
 def test_skat_5():
     ds = skat_dataset()
-    hl.skat(key_expr=ds.gene,
-            weight_expr=ds.weight,
-            y=ds.pheno,
-            x=hl.pl_dosage(ds.PL),
-            covariates=[1.0, ds.cov.Cov1, ds.cov.Cov2],
-            logistic=(25, 1e-6))._force_count()
+    hl.skat(
+        key_expr=ds.gene,
+        weight_expr=ds.weight,
+        y=ds.pheno,
+        x=hl.pl_dosage(ds.PL),
+        covariates=[1.0, ds.cov.Cov1, ds.cov.Cov2],
+        logistic=(25, 1e-6),
+    )._force_count()


 @test_timeout(local=4 * 60)
 def test_linear_skat_produces_same_results_as_old_scala_method():
     mt = hl.import_vcf(resource('sample2.vcf'))
-    covariates_ht = hl.import_table(
-        resource("skat.cov"),
-        key='Sample',
-        types={'Cov1': hl.tint, 'Cov2': hl.tint}
-    )
+    covariates_ht = hl.import_table(resource("skat.cov"), key='Sample', types={'Cov1': hl.tint, 'Cov2': hl.tint})
     phenotypes_ht = hl.import_table(
-        resource("skat.pheno"),
-        key='Sample',
-        types={"Pheno": hl.tfloat64}, missing="0",
-        impute=True
-    )
-    genes = hl.import_locus_intervals(
-        resource("skat.interval_list")
+        resource("skat.pheno"), key='Sample', types={"Pheno": hl.tfloat64}, missing="0", impute=True
     )
+    genes = hl.import_locus_intervals(resource("skat.interval_list"))
     weights = hl.import_table(
-        resource("skat.weights"),
-        key='locus',
-        types={"locus": hl.tlocus(), "weight": hl.tfloat64}
+        resource("skat.weights"), key='locus', types={"locus": hl.tlocus(), "weight": hl.tfloat64}
     )
     mt = hl.split_multi_hts(mt)
     pheno = phenotypes_ht[mt.s].Pheno
     mt = mt.annotate_cols(
-        cov = covariates_ht[mt.s],
-        pheno = (hl.case()
-                 .when(pheno == 1.0, False)
-                 .when(pheno == 2.0, True)
-                 .or_missing())
+        cov=covariates_ht[mt.s], pheno=(hl.case().when(pheno == 1.0, False).when(pheno == 2.0, True).or_missing())
     )
-    mt = mt.annotate_rows(
-        gene = genes[mt.locus].target,
-        weight = weights[mt.locus].weight
+    mt = mt.annotate_rows(gene=genes[mt.locus].target, weight=weights[mt.locus].weight)
+    skat_results = (
+        hl._linear_skat(
+            mt.gene, mt.weight, y=mt.pheno, x=mt.GT.n_alt_alleles(), covariates=[1, mt.cov.Cov1, mt.cov.Cov2]
+        )
+        .rename({'group': 'id'})
+        .select_globals()
     )
-    skat_results = hl._linear_skat(
-        mt.gene,
-        mt.weight,
-        y=mt.pheno,
-        x=mt.GT.n_alt_alleles(),
-        covariates=[1, mt.cov.Cov1, mt.cov.Cov2]
-    ).rename({'group': 'id'}).select_globals()
     old_scala_results = hl.import_table(
         resource('scala-skat-results.tsv'),
         types=dict(id=hl.tstr, size=hl.tint64, q_stat=hl.tfloat, p_value=hl.tfloat, fault=hl.tint),
-        key='id'
+        key='id',
     )
-    assert skat_results._same(old_scala_results, tolerance=5e-5)  # TSV has 5 sigfigs, so we should match within 5e-5 relative
-
+    assert skat_results._same(
+        old_scala_results, tolerance=5e-5
+    )  # TSV has 5 sigfigs, so we should match within 5e-5 relative


 def test_skat_max_iteration_fails_explodes_in_37_steps():
     mt = hl.utils.range_matrix_table(3, 3)
     mt = mt.annotate_cols(y=hl.literal([1, 0, 1])[mt.col_idx])
-    mt = mt.annotate_entries(
-        x=hl.literal([
-            [1, 0, 0],
-            [10, 0, 0],
-            [10, 5, 1]
-        ])[mt.row_idx]
-    )
+    mt = mt.annotate_entries(x=hl.literal([[1, 0, 0], [10, 0, 0], [10, 5, 1]])[mt.row_idx])
     try:
         ht = hl.skat(
             hl.literal(0),
@@ -634,13 +486,19 @@ def test_skat_max_iteration_fails_explodes_in_37_steps():
             logistic=(37, 1e-10),
             # The logistic settings are only used when fitting the null model, so we need to use a
             # covariate that triggers nonconvergence
-            covariates=[mt.y]
+            covariates=[mt.y],
         )
         ht.collect()[0]
     except FatalError as err:
-        assert 'Failed to fit logistic regression null model (MLE with covariates only): exploded at Newton iteration 37' in err.args[0]
+        assert (
+            'Failed to fit logistic regression null model (MLE with covariates only): exploded at Newton iteration 37'
+            in err.args[0]
+        )
     except HailUserError as err:
-        assert 'hl._logistic_skat: null model did not converge: {b: null, score: null, fisher: null, mu: null, n_iterations: 37, log_lkhd: -0.6931471805599453, converged: false, exploded: true}' in err.args[0]
+        assert (
+            'hl._logistic_skat: null model did not converge: {b: null, score: null, fisher: null, mu: null, n_iterations: 37, log_lkhd: -0.6931471805599453, converged: false, exploded: true}'
+            in err.args[0]
+        )
     else:
         assert False

@@ -648,13 +506,7 @@ def test_skat_max_iterations_fails_to_converge_in_fewer_than_36_steps():
     mt = hl.utils.range_matrix_table(3, 3)
     mt = mt.annotate_cols(y=hl.literal([1, 0, 1])[mt.col_idx])
-    mt = mt.annotate_entries(
-        x=hl.literal([
-            [1, 0, 0],
-            [10, 0, 0],
-            [10, 5, 1]
-        ])[mt.row_idx]
-    )
+    mt = mt.annotate_entries(x=hl.literal([[1, 0, 0], [10, 0, 0], [10, 5, 1]])[mt.row_idx])
     try:
         ht = hl.skat(
             hl.literal(0),
@@ -664,12 +516,18 @@ def test_skat_max_iterations_fails_to_converge_in_fewer_than_36_steps():
             logistic=(36, 1e-10),
             # The logistic settings are only used when fitting the null model, so we need to use a
             # covariate that triggers nonconvergence
-            covariates=[mt.y]
+            covariates=[mt.y],
         )
         ht.collect()[0]
     except FatalError as err:
-        assert 'Failed to fit logistic regression null model (MLE with covariates only): Newton iteration failed to converge' in err.args[0]
+        assert (
+            'Failed to fit logistic regression null model (MLE with covariates only): Newton iteration failed to converge'
+            in err.args[0]
+        )
     except HailUserError as err:
-        assert 'hl._logistic_skat: null model did not converge: {b: null, score: null, fisher: null, mu: null, n_iterations: 36, log_lkhd: -0.6931471805599457, converged: false, exploded: false}' in err.args[0]
+        assert (
+            'hl._logistic_skat: null model did not converge: {b: null, score: null, fisher: null, mu: null, n_iterations: 36, log_lkhd: -0.6931471805599457, converged: false, exploded: false}'
+            in err.args[0]
+        )
     else:
         assert False
diff --git a/hail/python/test/hail/methods/test_statgen.py b/hail/python/test/hail/methods/test_statgen.py
index 63a0730a2f9..e3bca97656c 100644
--- a/hail/python/test/hail/methods/test_statgen.py
+++ b/hail/python/test/hail/methods/test_statgen.py
@@ -1,17 +1,18 @@
-import os
 import math
-import pytest
+import os
+import unittest
+
 import numpy as np
+import pytest

 import hail as hl
 import hail.expr.aggregators as agg
-import hail.utils as utils
+from hail import utils
 from hail.linalg import BlockMatrix
 from hail.utils import FatalError, new_temp_file
-from hail.utils.java import choose_backend, Env
-from ..helpers import resource, fails_service_backend, skip_when_service_backend, test_timeout, qobtest
+from hail.utils.java import Env, choose_backend

-import unittest
+from ..helpers import fails_service_backend, qobtest, resource, skip_when_service_backend, test_timeout


 class Tests(unittest.TestCase):
@@ -28,22 +29,19 @@ def test_impute_sex_same_as_plink(self):
         hl.export_vcf(ds, vcf_file)

-        utils.run_command(["plink", "--vcf", vcf_file, "--const-fid",
-                           "--check-sex", "--silent", "--out", out_file])
+        utils.run_command(["plink", "--vcf", vcf_file, "--const-fid", "--check-sex", "--silent", "--out", out_file])

-        plink_sex = hl.import_table(out_file + '.sexcheck',
-                                    delimiter=' +',
-                                    types={'SNPSEX': hl.tint32,
-                                           'F': hl.tfloat64})
+        plink_sex = hl.import_table(
+            out_file + '.sexcheck', delimiter=' +', types={'SNPSEX': hl.tint32, 'F': hl.tfloat64}
+        )
         plink_sex = plink_sex.select('IID', 'SNPSEX', 'F')
         plink_sex = plink_sex.select(
             s=plink_sex.IID,
-            is_female=hl.if_else(plink_sex.SNPSEX == 2,
-                                 True,
-                                 hl.if_else(plink_sex.SNPSEX == 1,
-                                            False,
-                                            hl.missing(hl.tbool))),
-            f_stat=plink_sex.F).key_by('s')
+            is_female=hl.if_else(
+                plink_sex.SNPSEX == 2, True, hl.if_else(plink_sex.SNPSEX == 1, False, hl.missing(hl.tbool))
+            ),
+            f_stat=plink_sex.F,
+        ).key_by('s')

         sex = sex.select('is_female', 'f_stat')

@@ -55,17 +53,19 @@ def test_impute_sex_same_as_plink(self):
     backend_name = choose_backend()
     # Outside of Spark backend, "linear_regression_rows" just defers to the underscore nd version.
- linreg_functions = [hl.linear_regression_rows, hl._linear_regression_rows_nd] if backend_name == "spark" else [hl.linear_regression_rows] + linreg_functions = ( + [hl.linear_regression_rows, hl._linear_regression_rows_nd] + if backend_name == "spark" + else [hl.linear_regression_rows] + ) @qobtest @test_timeout(4 * 60) def test_linreg_basic(self): - phenos = hl.import_table(resource('regressionLinear.pheno'), - types={'Pheno': hl.tfloat64}, - key='Sample') - covs = hl.import_table(resource('regressionLinear.cov'), - types={'Cov1': hl.tfloat64, 'Cov2': hl.tfloat64}, - key='Sample') + phenos = hl.import_table(resource('regressionLinear.pheno'), types={'Pheno': hl.tfloat64}, key='Sample') + covs = hl.import_table( + resource('regressionLinear.cov'), types={'Cov1': hl.tfloat64, 'Cov2': hl.tfloat64}, key='Sample' + ) mt = hl.import_vcf(resource('regressionLinear.vcf')) mt = mt.annotate_cols(pheno=phenos[mt.s].Pheno, cov=covs[mt.s]) @@ -73,19 +73,17 @@ def test_linreg_basic(self): for linreg_function in self.linreg_functions: t1 = linreg_function( - y=mt.pheno, x=mt.GT.n_alt_alleles(), covariates=[1.0, mt.cov.Cov1, mt.cov.Cov2 + 1 - 1]) + y=mt.pheno, x=mt.GT.n_alt_alleles(), covariates=[1.0, mt.cov.Cov1, mt.cov.Cov2 + 1 - 1] + ) t1 = t1.select(p=t1.p_value) - t2 = linreg_function( - y=mt.pheno, x=mt.x, covariates=[1.0, mt.cov.Cov1, mt.cov.Cov2]) + t2 = linreg_function(y=mt.pheno, x=mt.x, covariates=[1.0, mt.cov.Cov1, mt.cov.Cov2]) t2 = t2.select(p=t2.p_value) - t3 = linreg_function( - y=[mt.pheno], x=mt.x, covariates=[1.0, mt.cov.Cov1, mt.cov.Cov2]) + t3 = linreg_function(y=[mt.pheno], x=mt.x, covariates=[1.0, mt.cov.Cov1, mt.cov.Cov2]) t3 = t3.select(p=t3.p_value[0]) - t4 = linreg_function( - y=[mt.pheno, mt.pheno], x=mt.x, covariates=[1.0, mt.cov.Cov1, mt.cov.Cov2]) + t4 = linreg_function(y=[mt.pheno, mt.pheno], x=mt.x, covariates=[1.0, mt.cov.Cov1, mt.cov.Cov2]) t4a = t4.select(p=t4.p_value[0]) t4b = t4.select(p=t4.p_value[1]) @@ -95,26 +93,22 @@ def test_linreg_basic(self): self.assertTrue(t1._same(t4b)) def test_linreg_pass_through(self): - phenos = hl.import_table(resource('regressionLinear.pheno'), - types={'Pheno': hl.tfloat64}, - key='Sample') - covs = hl.import_table(resource('regressionLinear.cov'), - types={'Cov1': hl.tfloat64, 'Cov2': hl.tfloat64}, - key='Sample') + phenos = hl.import_table(resource('regressionLinear.pheno'), types={'Pheno': hl.tfloat64}, key='Sample') - mt = hl.import_vcf(resource('regressionLinear.vcf')).annotate_rows(foo = hl.struct(bar=hl.rand_norm(0, 1))) + mt = hl.import_vcf(resource('regressionLinear.vcf')).annotate_rows(foo=hl.struct(bar=hl.rand_norm(0, 1))) for linreg_function in self.linreg_functions: - # single group - lr_result = linreg_function(phenos[mt.s].Pheno, mt.GT.n_alt_alleles(), [1.0], - pass_through=['filters', mt.foo.bar, mt.qual]) + lr_result = linreg_function( + phenos[mt.s].Pheno, mt.GT.n_alt_alleles(), [1.0], pass_through=['filters', mt.foo.bar, mt.qual] + ) assert mt.aggregate_rows(hl.agg.all(mt.foo.bar == lr_result[mt.row_key].bar)) # chained - lr_result = linreg_function([[phenos[mt.s].Pheno]], mt.GT.n_alt_alleles(), [1.0], - pass_through=['filters', mt.foo.bar, mt.qual]) + lr_result = linreg_function( + [[phenos[mt.s].Pheno]], mt.GT.n_alt_alleles(), [1.0], pass_through=['filters', mt.foo.bar, mt.qual] + ) assert mt.aggregate_rows(hl.agg.all(mt.foo.bar == lr_result[mt.row_key].bar)) @@ -129,52 +123,52 @@ def test_linreg_pass_through(self): assert lr_result.qual.dtype == mt.qual.dtype # should run successfully with key fields - 
linreg_function([[phenos[mt.s].Pheno]], mt.GT.n_alt_alleles(), [1.0], - pass_through=['locus', 'alleles']) + linreg_function([[phenos[mt.s].Pheno]], mt.GT.n_alt_alleles(), [1.0], pass_through=['locus', 'alleles']) # complex expression with pytest.raises(ValueError): - linreg_function([[phenos[mt.s].Pheno]], mt.GT.n_alt_alleles(), [1.0], - pass_through=[mt.filters.length()]) + linreg_function( + [[phenos[mt.s].Pheno]], mt.GT.n_alt_alleles(), [1.0], pass_through=[mt.filters.length()] + ) @test_timeout(local=3 * 60) def test_linreg_chained(self): - phenos = hl.import_table(resource('regressionLinear.pheno'), - types={'Pheno': hl.tfloat64}, - key='Sample') - covs = hl.import_table(resource('regressionLinear.cov'), - types={'Cov1': hl.tfloat64, 'Cov2': hl.tfloat64}, - key='Sample') + phenos = hl.import_table(resource('regressionLinear.pheno'), types={'Pheno': hl.tfloat64}, key='Sample') + covs = hl.import_table( + resource('regressionLinear.cov'), types={'Cov1': hl.tfloat64, 'Cov2': hl.tfloat64}, key='Sample' + ) mt = hl.import_vcf(resource('regressionLinear.vcf')) mt = mt.annotate_cols(pheno=phenos[mt.s].Pheno, cov=covs[mt.s]) mt = mt.annotate_entries(x=mt.GT.n_alt_alleles()).cache() for linreg_function in self.linreg_functions: - t1 = linreg_function(y=[[mt.pheno], [mt.pheno]], x=mt.x, covariates=[1, mt.cov.Cov1, mt.cov.Cov2]) def all_eq(*args): pred = True for a in args: - if isinstance(a, hl.expr.Expression) \ - and isinstance(a.dtype, hl.tarray) \ - and isinstance(a.dtype.element_type, hl.tarray): - pred = pred & (hl.all(lambda x: x, - hl.map(lambda elt: ((hl.is_nan(elt[0]) & hl.is_nan(elt[1])) | (elt[0] == elt[1])), - hl.zip(a[0], a[1])))) + if ( + isinstance(a, hl.expr.Expression) + and isinstance(a.dtype, hl.tarray) + and isinstance(a.dtype.element_type, hl.tarray) + ): + pred = pred & ( + hl.all( + lambda x: x, + hl.map( + lambda elt: ((hl.is_nan(elt[0]) & hl.is_nan(elt[1])) | (elt[0] == elt[1])), + hl.zip(a[0], a[1]), + ), + ) + ) else: pred = pred & ((hl.is_nan(a[0]) & hl.is_nan(a[1])) | (a[0] == a[1])) return pred - assert t1.aggregate(hl.agg.all( - all_eq(t1.n, - t1.sum_x, - t1.y_transpose_x, - t1.beta, - t1.standard_error, - t1.t_stat, - t1.p_value))) + assert t1.aggregate( + hl.agg.all(all_eq(t1.n, t1.sum_x, t1.y_transpose_x, t1.beta, t1.standard_error, t1.t_stat, t1.p_value)) + ) mt2 = mt.filter_cols(mt.cov.Cov2 >= 0) mt3 = mt.filter_cols(mt.cov.Cov2 <= 0) @@ -183,46 +177,56 @@ def all_eq(*args): t2 = hl.linear_regression_rows(y=mt2.pheno, x=mt2.x, covariates=[1, mt2.cov.Cov1]) t3 = hl.linear_regression_rows(y=mt3.pheno, x=mt3.x, covariates=[1, mt3.cov.Cov1]) - chained = hl.linear_regression_rows(y=[[hl.case().when(mt.cov.Cov2 >= 0, mt.pheno).or_missing()], - [hl.case().when(mt.cov.Cov2 <= 0, mt.pheno).or_missing()]], - x=mt.x, - covariates=[1, mt.cov.Cov1]) + chained = hl.linear_regression_rows( + y=[ + [hl.case().when(mt.cov.Cov2 >= 0, mt.pheno).or_missing()], + [hl.case().when(mt.cov.Cov2 <= 0, mt.pheno).or_missing()], + ], + x=mt.x, + covariates=[1, mt.cov.Cov1], + ) chained = chained.annotate(r0=t2[chained.key], r1=t3[chained.key]) - assert chained.aggregate(hl.agg.all( - all_eq([chained.n[0], chained.r0.n], - [chained.n[1], chained.r1.n], - [chained.sum_x[0], chained.r0.sum_x], - [chained.sum_x[1], chained.r1.sum_x], - [chained.y_transpose_x[0][0], chained.r0.y_transpose_x], - [chained.y_transpose_x[1][0], chained.r1.y_transpose_x], - [chained.beta[0][0], chained.r0.beta], - [chained.beta[1][0], chained.r1.beta], - [chained.standard_error[0][0], chained.r0.standard_error], 
- [chained.standard_error[1][0], chained.r1.standard_error], - [chained.t_stat[0][0], chained.r0.t_stat], - [chained.t_stat[1][0], chained.r1.t_stat], - [chained.p_value[0][0], chained.r0.p_value], - [chained.p_value[1][0], chained.r1.p_value]))) + assert chained.aggregate( + hl.agg.all( + all_eq( + [chained.n[0], chained.r0.n], + [chained.n[1], chained.r1.n], + [chained.sum_x[0], chained.r0.sum_x], + [chained.sum_x[1], chained.r1.sum_x], + [chained.y_transpose_x[0][0], chained.r0.y_transpose_x], + [chained.y_transpose_x[1][0], chained.r1.y_transpose_x], + [chained.beta[0][0], chained.r0.beta], + [chained.beta[1][0], chained.r1.beta], + [chained.standard_error[0][0], chained.r0.standard_error], + [chained.standard_error[1][0], chained.r1.standard_error], + [chained.t_stat[0][0], chained.r0.t_stat], + [chained.t_stat[1][0], chained.r1.t_stat], + [chained.p_value[0][0], chained.r0.p_value], + [chained.p_value[1][0], chained.r1.p_value], + ) + ) + ) # test differential missingness against each other - phenos = [hl.case().when(mt.cov.Cov2 >= -1, mt.pheno).or_missing(), - hl.case().when(mt.cov.Cov2 <= 1, mt.pheno).or_missing()] + phenos = [ + hl.case().when(mt.cov.Cov2 >= -1, mt.pheno).or_missing(), + hl.case().when(mt.cov.Cov2 <= 1, mt.pheno).or_missing(), + ] t4 = hl.linear_regression_rows(phenos, mt.x, covariates=[1]) t5 = hl.linear_regression_rows([phenos], mt.x, covariates=[1]) - t5 = t5.annotate(**{x: t5[x][0] for x in ['n', 'sum_x', 'y_transpose_x', 'beta', 'standard_error', 't_stat', 'p_value']}) + t5 = t5.annotate(**{ + x: t5[x][0] for x in ['n', 'sum_x', 'y_transpose_x', 'beta', 'standard_error', 't_stat', 'p_value'] + }) assert t4._same(t5) def test_linear_regression_without_intercept(self): for linreg_function in self.linreg_functions: - pheno = hl.import_table(resource('regressionLinear.pheno'), - key='Sample', - missing='0', - types={'Pheno': hl.tfloat}) + pheno = hl.import_table( + resource('regressionLinear.pheno'), key='Sample', missing='0', types={'Pheno': hl.tfloat} + ) mt = hl.import_vcf(resource('regressionLinear.vcf')) - ht = linreg_function(y=pheno[mt.s].Pheno, - x=mt.GT.n_alt_alleles(), - covariates=[]) + ht = linreg_function(y=pheno[mt.s].Pheno, x=mt.GT.n_alt_alleles(), covariates=[]) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) self.assertAlmostEqual(results[1].beta, 1.5, places=6) self.assertAlmostEqual(results[1].standard_error, 1.161895, places=6) @@ -239,21 +243,19 @@ def test_linear_regression_without_intercept(self): # summary(fit)["coefficients"] @pytest.mark.unchecked_allocator def test_linear_regression_with_cov(self): - - covariates = hl.import_table(resource('regressionLinear.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionLinear.pheno'), - key='Sample', - missing='0', - types={'Pheno': hl.tfloat}) + covariates = hl.import_table( + resource('regressionLinear.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionLinear.pheno'), key='Sample', missing='0', types={'Pheno': hl.tfloat} + ) mt = hl.import_vcf(resource('regressionLinear.vcf')) for linreg_function in self.linreg_functions: - ht = linreg_function(y=pheno[mt.s].Pheno, - x=mt.GT.n_alt_alleles(), - covariates=[1.0] + list(covariates[mt.s].values())) + ht = linreg_function( + y=pheno[mt.s].Pheno, x=mt.GT.n_alt_alleles(), covariates=[1.0, *list(covariates[mt.s].values())] + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) @@ -282,22 
+284,19 @@ def test_linear_regression_with_cov(self): self.assertTrue(np.isnan(results[10].standard_error)) def test_linear_regression_pl(self): - - covariates = hl.import_table(resource('regressionLinear.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionLinear.pheno'), - key='Sample', - missing='0', - types={'Pheno': hl.tfloat}) + covariates = hl.import_table( + resource('regressionLinear.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionLinear.pheno'), key='Sample', missing='0', types={'Pheno': hl.tfloat} + ) mt = hl.import_vcf(resource('regressionLinear.vcf')) for linreg_function in self.linreg_functions: - - ht = linreg_function(y=pheno[mt.s].Pheno, - x=hl.pl_dosage(mt.PL), - covariates=[1.0] + list(covariates[mt.s].values())) + ht = linreg_function( + y=pheno[mt.s].Pheno, x=hl.pl_dosage(mt.PL), covariates=[1.0, *list(covariates[mt.s].values())] + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) @@ -317,20 +316,18 @@ def test_linear_regression_pl(self): self.assertAlmostEqual(results[3].p_value, 0.2533675, places=6) def test_linear_regression_with_dosage(self): - - covariates = hl.import_table(resource('regressionLinear.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionLinear.pheno'), - key='Sample', - missing='0', - types={'Pheno': hl.tfloat}) + covariates = hl.import_table( + resource('regressionLinear.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionLinear.pheno'), key='Sample', missing='0', types={'Pheno': hl.tfloat} + ) mt = hl.import_gen(resource('regressionLinear.gen'), sample_file=resource('regressionLinear.sample')) for linreg_function in self.linreg_functions: - ht = linreg_function(y=pheno[mt.s].Pheno, - x=hl.gp_dosage(mt.GP), - covariates=[1.0] + list(covariates[mt.s].values())) + ht = linreg_function( + y=pheno[mt.s].Pheno, x=hl.gp_dosage(mt.GP), covariates=[1.0, *list(covariates[mt.s].values())] + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) @@ -367,16 +364,16 @@ def test_linear_regression_equivalence_between_ds_and_gt(self): self.assertTrue(all(hl.approx_equal(results_t.ds_p_value, results_t.gt_p_value, nan_same=True).collect())) def test_linear_regression_with_import_fam_boolean(self): - covariates = hl.import_table(resource('regressionLinear.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) + covariates = hl.import_table( + resource('regressionLinear.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) fam = hl.import_fam(resource('regressionLinear.fam')) mt = hl.import_vcf(resource('regressionLinear.vcf')) for linreg_function in self.linreg_functions: - ht = linreg_function(y=fam[mt.s].is_case, - x=mt.GT.n_alt_alleles(), - covariates=[1.0] + list(covariates[mt.s].values())) + ht = linreg_function( + y=fam[mt.s].is_case, x=mt.GT.n_alt_alleles(), covariates=[1.0, *list(covariates[mt.s].values())] + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) @@ -397,18 +394,16 @@ def test_linear_regression_with_import_fam_boolean(self): self.assertTrue(np.isnan(results[10].standard_error)) def test_linear_regression_with_import_fam_quant(self): - covariates = hl.import_table(resource('regressionLinear.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - fam = 
hl.import_fam(resource('regressionLinear.fam'), - quant_pheno=True, - missing='0') + covariates = hl.import_table( + resource('regressionLinear.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + fam = hl.import_fam(resource('regressionLinear.fam'), quant_pheno=True, missing='0') mt = hl.import_vcf(resource('regressionLinear.vcf')) for linreg_function in self.linreg_functions: - ht = linreg_function(y=fam[mt.s].quant_pheno, - x=mt.GT.n_alt_alleles(), - covariates=[1.0] + list(covariates[mt.s].values())) + ht = linreg_function( + y=fam[mt.s].quant_pheno, x=mt.GT.n_alt_alleles(), covariates=[1.0, *list(covariates[mt.s].values())] + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) @@ -429,65 +424,71 @@ def test_linear_regression_with_import_fam_quant(self): self.assertTrue(np.isnan(results[10].standard_error)) def test_linear_regression_multi_pheno_same(self): - covariates = hl.import_table(resource('regressionLinear.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionLinear.pheno'), - key='Sample', - missing='0', - types={'Pheno': hl.tfloat}) + covariates = hl.import_table( + resource('regressionLinear.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionLinear.pheno'), key='Sample', missing='0', types={'Pheno': hl.tfloat} + ) mt = hl.import_vcf(resource('regressionLinear.vcf')) for linreg_function in self.linreg_functions: - single = linreg_function(y=pheno[mt.s].Pheno, - x=mt.GT.n_alt_alleles(), - covariates=list(covariates[mt.s].values())) - multi = linreg_function(y=[pheno[mt.s].Pheno, pheno[mt.s].Pheno], - x=mt.GT.n_alt_alleles(), - covariates=list(covariates[mt.s].values())) + single = linreg_function( + y=pheno[mt.s].Pheno, x=mt.GT.n_alt_alleles(), covariates=list(covariates[mt.s].values()) + ) + multi = linreg_function( + y=[pheno[mt.s].Pheno, pheno[mt.s].Pheno], + x=mt.GT.n_alt_alleles(), + covariates=list(covariates[mt.s].values()), + ) def eq(x1, x2): return (hl.is_nan(x1) & hl.is_nan(x2)) | (hl.abs(x1 - x2) < 1e-4) - combined = single.annotate(multi = multi[single.key]) - self.assertTrue(combined.aggregate(hl.agg.all( - eq(combined.p_value, combined.multi.p_value[0]) & - eq(combined.multi.p_value[0], combined.multi.p_value[1])))) - + combined = single.annotate(multi=multi[single.key]) + self.assertTrue( + combined.aggregate( + hl.agg.all( + eq(combined.p_value, combined.multi.p_value[0]) + & eq(combined.multi.p_value[0], combined.multi.p_value[1]) + ) + ) + ) def test_logistic_regression_rows_max_iter_zero(self): import hail as hl + mt = hl.utils.range_matrix_table(1, 3) mt = mt.annotate_entries(x=hl.literal([1, 1, 10])) try: ht = hl.logistic_regression_rows( - test='wald', - y=hl.literal([0, 0, 1])[mt.col_idx], - x=mt.x[mt.col_idx], - covariates=[1], - max_iterations=0 + test='wald', y=hl.literal([0, 0, 1])[mt.col_idx], x=mt.x[mt.col_idx], covariates=[1], max_iterations=0 ) ht.globals.collect() # null model is a global except Exception as exc: - assert 'Failed to fit logistic regression null model (standard MLE with covariates only): Newton iteration failed to converge' in exc.args[0] + assert ( + 'Failed to fit logistic regression null model (standard MLE with covariates only): Newton iteration failed to converge' + in exc.args[0] + ) else: assert False # Outside the spark backend, "logistic_regression_rows" automatically defers to the _ version. 
- logreg_functions = [hl.logistic_regression_rows, hl._logistic_regression_rows_nd] if backend_name == "spark" else [hl.logistic_regression_rows] + logreg_functions = ( + [hl.logistic_regression_rows, hl._logistic_regression_rows_nd] + if backend_name == "spark" + else [hl.logistic_regression_rows] + ) def test_logistic_regression_rows_max_iter_explodes(self): for logreg in self.logreg_functions: import hail as hl + mt = hl.utils.range_matrix_table(1, 3) mt = mt.annotate_entries(x=hl.literal([1, 1, 10])) ht = logreg( - test='wald', - y=hl.literal([0, 0, 1])[mt.col_idx], - x=mt.x[mt.col_idx], - covariates=[1], - max_iterations=100 + test='wald', y=hl.literal([0, 0, 1])[mt.col_idx], x=mt.x[mt.col_idx], covariates=[1], max_iterations=100 ) fit = ht.collect()[0].fit assert fit.n_iterations < 100 @@ -496,14 +497,11 @@ def test_logistic_regression_rows_max_iter_explodes(self): def test_firth_logistic_regression_rows_explodes_in_12_steps(self): import hail as hl + mt = hl.utils.range_matrix_table(1, 3) mt = mt.annotate_entries(x=hl.literal([1, 1, 10])) ht = hl.logistic_regression_rows( - test='firth', - y=hl.literal([0, 1, 1, 0])[mt.col_idx], - x=mt.x[mt.col_idx], - covariates=[1], - max_iterations=100 + test='firth', y=hl.literal([0, 1, 1, 0])[mt.col_idx], x=mt.x[mt.col_idx], covariates=[1], max_iterations=100 ) fit = ht.collect()[0].fit assert fit.n_iterations == 12 @@ -512,14 +510,11 @@ def test_firth_logistic_regression_rows_explodes_in_12_steps(self): def test_firth_logistic_regression_rows_does_not_converge_with_105_iterations(self): import hail as hl + mt = hl.utils.range_matrix_table(1, 3) mt = mt.annotate_entries(x=hl.literal([1, 3, 10])) ht = hl.logistic_regression_rows( - test='firth', - y=hl.literal([0, 1, 1])[mt.col_idx], - x=mt.x[mt.col_idx], - covariates=[1], - max_iterations=105 + test='firth', y=hl.literal([0, 1, 1])[mt.col_idx], x=mt.x[mt.col_idx], covariates=[1], max_iterations=105 ) fit = ht.collect()[0].fit assert fit.n_iterations == 105 @@ -528,6 +523,7 @@ def test_firth_logistic_regression_rows_does_not_converge_with_105_iterations(se def test_firth_logistic_regression_rows_does_converge_with_more_iterations(self): import hail as hl + mt = hl.utils.range_matrix_table(1, 3) mt = mt.annotate_entries(x=hl.literal([1, 3, 10])) ht = hl.logistic_regression_rows( @@ -536,7 +532,7 @@ def test_firth_logistic_regression_rows_does_converge_with_more_iterations(self) x=mt.x[mt.col_idx], covariates=[1], max_iterations=106, - tolerance=1e-6 + tolerance=1e-6, ) result = ht.collect()[0] fit = result.fit @@ -555,40 +551,37 @@ def both_nan_or_none(a, b): @test_timeout(3 * 60) def test_weighted_linear_regression(self): - covariates = hl.import_table(resource('regressionLinear.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionLinear.pheno'), - key='Sample', - missing='0', - types={'Pheno': hl.tfloat}) - - weights = hl.import_table(resource('regressionLinear.weights'), - key='Sample', - missing='0', - types={'Sample': hl.tstr, 'Weight1': hl.tfloat, 'Weight2': hl.tfloat}) + covariates = hl.import_table( + resource('regressionLinear.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionLinear.pheno'), key='Sample', missing='0', types={'Pheno': hl.tfloat} + ) mt = hl.import_vcf(resource('regressionLinear.vcf')) mt = mt.add_col_index() mt = mt.annotate_cols(y=hl.coalesce(pheno[mt.s].Pheno, 1.0)) mt = mt.annotate_entries(x=hl.coalesce(mt.GT.n_alt_alleles(), 1.0)) - 
my_covs = [1.0] + list(covariates[mt.s].values()) + my_covs = [1.0, *list(covariates[mt.s].values())] - ht_with_weights = hl._linear_regression_rows_nd(y=mt.y, - x=mt.x, - covariates=my_covs, - weights=mt.col_idx) + ht_with_weights = hl._linear_regression_rows_nd(y=mt.y, x=mt.x, covariates=my_covs, weights=mt.col_idx) - ht_pre_weighted_1 = hl._linear_regression_rows_nd(y=mt.y * hl.sqrt(mt.col_idx), - x=mt.x * hl.sqrt(mt.col_idx), - covariates=list(map(lambda e: e * hl.sqrt(mt.col_idx), my_covs))) + ht_pre_weighted_1 = hl._linear_regression_rows_nd( + y=mt.y * hl.sqrt(mt.col_idx), + x=mt.x * hl.sqrt(mt.col_idx), + covariates=list(map(lambda e: e * hl.sqrt(mt.col_idx), my_covs)), + ) - ht_pre_weighted_2 = hl._linear_regression_rows_nd(y=mt.y * hl.sqrt(mt.col_idx + 5), - x=mt.x * hl.sqrt(mt.col_idx + 5), - covariates=list(map(lambda e: e * hl.sqrt(mt.col_idx + 5), my_covs))) + ht_pre_weighted_2 = hl._linear_regression_rows_nd( + y=mt.y * hl.sqrt(mt.col_idx + 5), + x=mt.x * hl.sqrt(mt.col_idx + 5), + covariates=list(map(lambda e: e * hl.sqrt(mt.col_idx + 5), my_covs)), + ) - ht_from_agg = mt.annotate_rows(my_linreg=hl.agg.linreg(mt.y, [1, mt.x] + list(covariates[mt.s].values()), weight=mt.col_idx)).rows() + ht_from_agg = mt.annotate_rows( + my_linreg=hl.agg.linreg(mt.y, [1, mt.x, *list(covariates[mt.s].values())], weight=mt.col_idx) + ).rows() betas_with_weights = ht_with_weights.beta.collect() betas_pre_weighted_1 = ht_pre_weighted_1.beta.collect() @@ -599,10 +592,9 @@ def test_weighted_linear_regression(self): assert self.equal_with_nans(betas_with_weights, betas_pre_weighted_1) assert self.equal_with_nans(betas_with_weights, betas_from_agg) - ht_with_multiple_weights = hl._linear_regression_rows_nd(y=[[mt.y], [hl.abs(mt.y)]], - x=mt.x, - covariates=my_covs, - weights=[mt.col_idx, mt.col_idx + 5]) + ht_with_multiple_weights = hl._linear_regression_rows_nd( + y=[[mt.y], [hl.abs(mt.y)]], x=mt.x, covariates=my_covs, weights=[mt.col_idx, mt.col_idx + 5] + ) # Check that preweighted 1 and preweighted 2 match up with fields 1 and 2 of multiple multi_weight_betas = ht_with_multiple_weights.beta.collect() @@ -617,29 +609,41 @@ def test_weighted_linear_regression(self): @test_timeout(3 * 60) def test_weighted_linear_regression__missing_weights_are_excluded(self): mt = hl.import_vcf(resource('regressionLinear.vcf')) - pheno = hl.import_table(resource('regressionLinear.pheno'), - key='Sample', - missing='0', - types={'Pheno': hl.tfloat}) + pheno = hl.import_table( + resource('regressionLinear.pheno'), key='Sample', missing='0', types={'Pheno': hl.tfloat} + ) mt = mt.annotate_cols(y=hl.coalesce(pheno[mt.s].Pheno, 1.0)) - weights = hl.import_table(resource('regressionLinear.weights'), - key='Sample', - missing='0', - types={'Sample': hl.tstr, 'Weight1': hl.tfloat, 'Weight2': hl.tfloat}) + weights = hl.import_table( + resource('regressionLinear.weights'), + key='Sample', + missing='0', + types={'Sample': hl.tstr, 'Weight1': hl.tfloat, 'Weight2': hl.tfloat}, + ) mt = mt.annotate_entries(x=hl.coalesce(mt.GT.n_alt_alleles(), 1.0)) - ht_with_missing_weights = hl._linear_regression_rows_nd(y=[[mt.y], [hl.abs(mt.y)]], - x=mt.x, - covariates=[1], - weights=[weights[mt.s].Weight1, weights[mt.s].Weight2]) - - mt_with_missing_weights = mt.annotate_cols(Weight1 = weights[mt.s].Weight1, Weight2 = weights[mt.s].Weight2) - mt_with_missing_weight1_filtered = mt_with_missing_weights.filter_cols(hl.is_defined(mt_with_missing_weights.Weight1)) - mt_with_missing_weight2_filtered = 
mt_with_missing_weights.filter_cols(hl.is_defined(mt_with_missing_weights.Weight2)) + ht_with_missing_weights = hl._linear_regression_rows_nd( + y=[[mt.y], [hl.abs(mt.y)]], x=mt.x, covariates=[1], weights=[weights[mt.s].Weight1, weights[mt.s].Weight2] + ) + + mt_with_missing_weights = mt.annotate_cols(Weight1=weights[mt.s].Weight1, Weight2=weights[mt.s].Weight2) + mt_with_missing_weight1_filtered = mt_with_missing_weights.filter_cols( + hl.is_defined(mt_with_missing_weights.Weight1) + ) + mt_with_missing_weight2_filtered = mt_with_missing_weights.filter_cols( + hl.is_defined(mt_with_missing_weights.Weight2) + ) ht_from_agg_weight_1 = mt_with_missing_weight1_filtered.annotate_rows( - my_linreg=hl.agg.linreg(mt_with_missing_weight1_filtered.y, [1, mt_with_missing_weight1_filtered.x], weight=weights[mt_with_missing_weight1_filtered.s].Weight1) + my_linreg=hl.agg.linreg( + mt_with_missing_weight1_filtered.y, + [1, mt_with_missing_weight1_filtered.x], + weight=weights[mt_with_missing_weight1_filtered.s].Weight1, + ) ).rows() ht_from_agg_weight_2 = mt_with_missing_weight2_filtered.annotate_rows( - my_linreg=hl.agg.linreg(mt_with_missing_weight2_filtered.y, [1, mt_with_missing_weight2_filtered.x], weight=weights[mt_with_missing_weight2_filtered.s].Weight2) + my_linreg=hl.agg.linreg( + mt_with_missing_weight2_filtered.y, + [1, mt_with_missing_weight2_filtered.x], + weight=weights[mt_with_missing_weight2_filtered.s].Weight2, + ) ).rows() multi_weight_missing_results = ht_with_missing_weights.collect() @@ -688,20 +692,18 @@ def test_errors_weighted_linear_regression(self): mt = hl.utils.range_matrix_table(20, 10).annotate_entries(x=2) mt = mt.annotate_cols(**{f"col_{i}": i for i in range(4)}) - self.assertRaises(ValueError, lambda: hl._linear_regression_rows_nd(y=[[mt.col_1]], - x=mt.x, - covariates=[1], - weights=[mt.col_2, mt.col_3])) + self.assertRaises( + ValueError, + lambda: hl._linear_regression_rows_nd(y=[[mt.col_1]], x=mt.x, covariates=[1], weights=[mt.col_2, mt.col_3]), + ) - self.assertRaises(ValueError, lambda: hl._linear_regression_rows_nd(y=[mt.col_1], - x=mt.x, - covariates=[1], - weights=[mt.col_2])) + self.assertRaises( + ValueError, lambda: hl._linear_regression_rows_nd(y=[mt.col_1], x=mt.x, covariates=[1], weights=[mt.col_2]) + ) - self.assertRaises(ValueError, lambda: hl._linear_regression_rows_nd(y=[[mt.col_1]], - x=mt.x, - covariates=[1], - weights=mt.col_2)) + self.assertRaises( + ValueError, lambda: hl._linear_regression_rows_nd(y=[[mt.col_1]], x=mt.x, covariates=[1], weights=mt.col_2) + ) # comparing to R: # x = c(0, 1, 0, 0, 0, 1, 0, 0, 0, 0) @@ -715,20 +717,21 @@ def test_errors_weighted_linear_regression(self): # zstat <- waldtest["x", "z value"] # pval <- waldtest["x", "Pr(>|z|)"] def test_logistic_regression_wald_test(self): - covariates = hl.import_table(resource('regressionLogistic.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionLogisticBoolean.pheno'), - key='Sample', - missing='0', - types={'isCase': hl.tbool}) + covariates = hl.import_table( + resource('regressionLogistic.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionLogisticBoolean.pheno'), key='Sample', missing='0', types={'isCase': hl.tbool} + ) mt = hl.import_vcf(resource('regressionLogistic.vcf')) for logistic_regression_function in self.logreg_functions: - ht = logistic_regression_function('wald', - y=pheno[mt.s].isCase, - x=mt.GT.n_alt_alleles(), - covariates=[1.0, 
covariates[mt.s].Cov1, covariates[mt.s].Cov2]) + ht = logistic_regression_function( + 'wald', + y=pheno[mt.s].isCase, + x=mt.GT.n_alt_alleles(), + covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) @@ -753,40 +756,44 @@ def is_constant(r): self.assertTrue(is_constant(results[10])) def test_logistic_regression_wald_test_apply_multi_pheno(self): - covariates = hl.import_table(resource('regressionLogistic.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionLogisticBoolean.pheno'), - key='Sample', - missing='0', - types={'isCase': hl.tbool}) + covariates = hl.import_table( + resource('regressionLogistic.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionLogisticBoolean.pheno'), key='Sample', missing='0', types={'isCase': hl.tbool} + ) mt = hl.import_vcf(resource('regressionLogistic.vcf')) for logistic_regression_function in self.logreg_functions: - - ht = logistic_regression_function('wald', - y=[pheno[mt.s].isCase], - x=mt.GT.n_alt_alleles(), - covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2]) + ht = logistic_regression_function( + 'wald', + y=[pheno[mt.s].isCase], + x=mt.GT.n_alt_alleles(), + covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) - self.assertEqual(len(results[1].logistic_regression),1) + self.assertEqual(len(results[1].logistic_regression), 1) self.assertAlmostEqual(results[1].logistic_regression[0].beta, -0.81226793796, places=6) self.assertAlmostEqual(results[1].logistic_regression[0].standard_error, 2.1085483421, places=6) self.assertAlmostEqual(results[1].logistic_regression[0].z_stat, -0.3852261396, places=6) self.assertAlmostEqual(results[1].logistic_regression[0].p_value, 0.7000698784, places=6) - self.assertEqual(len(results[2].logistic_regression),1) + self.assertEqual(len(results[2].logistic_regression), 1) self.assertAlmostEqual(results[2].logistic_regression[0].beta, -0.43659460858, places=6) self.assertAlmostEqual(results[2].logistic_regression[0].standard_error, 1.0296902941, places=6) self.assertAlmostEqual(results[2].logistic_regression[0].z_stat, -0.4240057531, places=6) self.assertAlmostEqual(results[2].logistic_regression[0].p_value, 0.6715616176, places=6) def is_constant(r): - return (not r.logistic_regression[0].fit.converged) or np.isnan(r.logistic_regression[0].p_value) or abs(r.logistic_regression[0].p_value - 1) < 1e-4 + return ( + (not r.logistic_regression[0].fit.converged) + or np.isnan(r.logistic_regression[0].p_value) + or abs(r.logistic_regression[0].p_value - 1) < 1e-4 + ) - self.assertEqual(len(results[3].logistic_regression),1) + self.assertEqual(len(results[3].logistic_regression), 1) self.assertFalse(results[3].logistic_regression[0].fit.converged) # separable self.assertTrue(is_constant(results[6])) self.assertTrue(is_constant(results[7])) @@ -795,62 +802,68 @@ def is_constant(r): self.assertTrue(is_constant(results[10])) def test_logistic_regression_wald_test_multi_pheno_bgen_dosage(self): - covariates = hl.import_table(resource('regressionLogisticMultiPheno.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}).cache() - pheno = hl.import_table(resource('regressionLogisticMultiPheno.pheno'), - key='Sample', - missing='NA', - types={'Pheno1': hl.tint32, 'Pheno2': hl.tint32}).cache() + covariates = hl.import_table( + 
resource('regressionLogisticMultiPheno.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ).cache() + pheno = hl.import_table( + resource('regressionLogisticMultiPheno.pheno'), + key='Sample', + missing='NA', + types={'Pheno1': hl.tint32, 'Pheno2': hl.tint32}, + ).cache() bgen_path = new_temp_file(extension='bgen') Env.fs().copy(resource('example.8bits.bgen'), bgen_path) - hl.index_bgen(bgen_path, - contig_recoding={'01': '1'}, - reference_genome='GRCh37') + hl.index_bgen(bgen_path, contig_recoding={'01': '1'}, reference_genome='GRCh37') - mt = hl.import_bgen(bgen_path, - entry_fields=['dosage']) + mt = hl.import_bgen(bgen_path, entry_fields=['dosage']) for logistic_regression_function in self.logreg_functions: + ht_single_pheno = logistic_regression_function( + 'wald', + y=pheno[mt.s].Pheno1, + x=mt.dosage, + covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], + ) - ht_single_pheno = logistic_regression_function('wald', - y=pheno[mt.s].Pheno1, - x=mt.dosage, - covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2]) - - ht_multi_pheno = logistic_regression_function('wald', - y=[pheno[mt.s].Pheno1, pheno[mt.s].Pheno2], - x=mt.dosage, - covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2]) + ht_multi_pheno = logistic_regression_function( + 'wald', + y=[pheno[mt.s].Pheno1, pheno[mt.s].Pheno2], + x=mt.dosage, + covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], + ) single_results = dict(hl.tuple([ht_single_pheno.locus.position, ht_single_pheno.row]).collect()) multi_results = dict(hl.tuple([ht_multi_pheno.locus.position, ht_multi_pheno.row]).collect()) - self.assertEqual(len(multi_results[1001].logistic_regression),2) + self.assertEqual(len(multi_results[1001].logistic_regression), 2) self.assertAlmostEqual(multi_results[1001].logistic_regression[0].beta, single_results[1001].beta, places=6) - self.assertAlmostEqual(multi_results[1001].logistic_regression[0].standard_error,single_results[1001].standard_error, places=6) - self.assertAlmostEqual(multi_results[1001].logistic_regression[0].z_stat, single_results[1001].z_stat, places=6) - self.assertAlmostEqual(multi_results[1001].logistic_regression[0].p_value,single_results[1001].p_value, places=6) - #TODO test handling of missingness - + self.assertAlmostEqual( + multi_results[1001].logistic_regression[0].standard_error, single_results[1001].standard_error, places=6 + ) + self.assertAlmostEqual( + multi_results[1001].logistic_regression[0].z_stat, single_results[1001].z_stat, places=6 + ) + self.assertAlmostEqual( + multi_results[1001].logistic_regression[0].p_value, single_results[1001].p_value, places=6 + ) + # TODO test handling of missingness def test_logistic_regression_wald_test_pl(self): - covariates = hl.import_table(resource('regressionLogistic.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionLogisticBoolean.pheno'), - key='Sample', - missing='0', - types={'isCase': hl.tbool}) + covariates = hl.import_table( + resource('regressionLogistic.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionLogisticBoolean.pheno'), key='Sample', missing='0', types={'isCase': hl.tbool} + ) mt = hl.import_vcf(resource('regressionLogistic.vcf')) for logistic_regression_function in self.logreg_functions: - ht = logistic_regression_function( test='wald', y=pheno[mt.s].isCase, x=hl.pl_dosage(mt.PL), - covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2]) + 
covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) @@ -875,23 +888,21 @@ def is_constant(r): self.assertTrue(is_constant(results[10])) def test_logistic_regression_wald_dosage(self): - covariates = hl.import_table(resource('regressionLogistic.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionLogisticBoolean.pheno'), - key='Sample', - missing='0', - types={'isCase': hl.tbool}) - mt = hl.import_gen(resource('regressionLogistic.gen'), - sample_file=resource('regressionLogistic.sample')) + covariates = hl.import_table( + resource('regressionLogistic.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionLogisticBoolean.pheno'), key='Sample', missing='0', types={'isCase': hl.tbool} + ) + mt = hl.import_gen(resource('regressionLogistic.gen'), sample_file=resource('regressionLogistic.sample')) for logistic_regression_function in self.logreg_functions: - ht = logistic_regression_function( test='wald', y=pheno[mt.s].isCase, x=hl.gp_dosage(mt.GP), - covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2]) + covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) @@ -927,13 +938,12 @@ def is_constant(r): # chi2 <- lrtest[["Deviance"]][2] # pval <- lrtest[["Pr(>Chi)"]][2] def test_logistic_regression_lrt(self): - covariates = hl.import_table(resource('regressionLogistic.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionLogisticBoolean.pheno'), - key='Sample', - missing='0', - types={'isCase': hl.tbool}) + covariates = hl.import_table( + resource('regressionLogistic.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionLogisticBoolean.pheno'), key='Sample', missing='0', types={'isCase': hl.tbool} + ) mt = hl.import_vcf(resource('regressionLogistic.vcf')) for logistic_regression_function in self.logreg_functions: @@ -941,7 +951,8 @@ def test_logistic_regression_lrt(self): test='lrt', y=pheno[mt.s].isCase, x=mt.GT.n_alt_alleles(), - covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2]) + covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) @@ -974,13 +985,12 @@ def is_constant(r): # chi2 <- scoretest[["Rao"]][2] # pval <- scoretest[["Pr(>Chi)"]][2] def test_logistic_regression_score(self): - covariates = hl.import_table(resource('regressionLogistic.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionLogisticBoolean.pheno'), - key='Sample', - missing='0', - types={'isCase': hl.tbool}) + covariates = hl.import_table( + resource('regressionLogistic.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionLogisticBoolean.pheno'), key='Sample', missing='0', types={'isCase': hl.tbool} + ) mt = hl.import_vcf(resource('regressionLogistic.vcf')) def is_constant(r): @@ -991,7 +1001,8 @@ def is_constant(r): test='score', y=pheno[mt.s].isCase, x=mt.GT.n_alt_alleles(), - covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2]) + covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) @@ -1011,21 +1022,22 
@@ def is_constant(r): self.assertTrue(is_constant(results[10])) def test_logreg_pass_through(self): - covariates = hl.import_table(resource('regressionLogistic.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionLogisticBoolean.pheno'), - key='Sample', - missing='0', - types={'isCase': hl.tbool}) + covariates = hl.import_table( + resource('regressionLogistic.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionLogisticBoolean.pheno'), key='Sample', missing='0', types={'isCase': hl.tbool} + ) mt = hl.import_vcf(resource('regressionLogistic.vcf')).annotate_rows(foo=hl.struct(bar=hl.rand_norm(0, 1))) for logreg_function in self.logreg_functions: - ht = logreg_function('wald', - y=pheno[mt.s].isCase, - x=mt.GT.n_alt_alleles(), - covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], - pass_through=['filters', mt.foo.bar, mt.qual]) + ht = logreg_function( + 'wald', + y=pheno[mt.s].isCase, + x=mt.GT.n_alt_alleles(), + covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], + pass_through=['filters', mt.foo.bar, mt.qual], + ) assert mt.aggregate_rows(hl.agg.all(mt.foo.bar == ht[mt.row_key].bar)) @@ -1041,19 +1053,19 @@ def test_logreg_pass_through(self): # zstat <- waldtest["x", "z value"] # pval <- waldtest["x", "Pr(>|z|)"] def test_poission_regression_wald_test(self): - covariates = hl.import_table(resource('regressionLogistic.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionPoisson.pheno'), - key='Sample', - missing='-1', - types={'count': hl.tint32}) + covariates = hl.import_table( + resource('regressionLogistic.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionPoisson.pheno'), key='Sample', missing='-1', types={'count': hl.tint32} + ) mt = hl.import_vcf(resource('regressionLogistic.vcf')) ht = hl.poisson_regression_rows( test='wald', y=pheno[mt.s].count, x=mt.GT.n_alt_alleles(), - covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2]) + covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) @@ -1078,10 +1090,12 @@ def is_constant(r): def test_poisson_regression_max_iterations(self): import hail as hl + mt = hl.utils.range_matrix_table(1, 3) mt = mt.annotate_entries(x=hl.literal([1, 3, 10, 5])) ht = hl.poisson_regression_rows( - 'wald', y=hl.literal([0, 1, 1, 0])[mt.col_idx], x=mt.x[mt.col_idx], covariates=[1], max_iterations=1) + 'wald', y=hl.literal([0, 1, 1, 0])[mt.col_idx], x=mt.x[mt.col_idx], covariates=[1], max_iterations=1 + ) fit = ht.collect()[0].fit assert fit.n_iterations == 1 assert not fit.converged @@ -1099,19 +1113,19 @@ def test_poisson_regression_max_iterations(self): # chi2 <- lrtest[["Deviance"]][2] # pval <- lrtest[["Pr(>Chi)"]][2] def test_poisson_regression_lrt(self): - covariates = hl.import_table(resource('regressionLogistic.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionPoisson.pheno'), - key='Sample', - missing='-1', - types={'count': hl.tint32}) + covariates = hl.import_table( + resource('regressionLogistic.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionPoisson.pheno'), key='Sample', missing='-1', types={'count': hl.tint32} + ) mt = 
hl.import_vcf(resource('regressionLogistic.vcf')) ht = hl.poisson_regression_rows( test='lrt', y=pheno[mt.s].count, x=mt.GT.n_alt_alleles(), - covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2]) + covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) @@ -1143,19 +1157,19 @@ def is_constant(r): # chi2 <- scoretest[["Rao"]][2] # pval <- scoretest[["Pr(>Chi)"]][2] def test_poisson_regression_score_test(self): - covariates = hl.import_table(resource('regressionLogistic.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionPoisson.pheno'), - key='Sample', - missing='-1', - types={'count': hl.tint32}) + covariates = hl.import_table( + resource('regressionLogistic.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionPoisson.pheno'), key='Sample', missing='-1', types={'count': hl.tint32} + ) mt = hl.import_vcf(resource('regressionLogistic.vcf')) ht = hl.poisson_regression_rows( test='score', y=pheno[mt.s].count, x=mt.GT.n_alt_alleles(), - covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2]) + covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], + ) results = dict(hl.tuple([ht.locus.position, ht.row]).collect()) @@ -1178,26 +1192,26 @@ def is_constant(r): self.assertTrue(is_constant(results[10])) def test_poisson_pass_through(self): - covariates = hl.import_table(resource('regressionLogistic.cov'), - key='Sample', - types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat}) - pheno = hl.import_table(resource('regressionPoisson.pheno'), - key='Sample', - missing='-1', - types={'count': hl.tint32}) - mt = hl.import_vcf(resource('regressionLogistic.vcf')).annotate_rows(foo = hl.struct(bar=hl.rand_norm(0, 1))) + covariates = hl.import_table( + resource('regressionLogistic.cov'), key='Sample', types={'Cov1': hl.tfloat, 'Cov2': hl.tfloat} + ) + pheno = hl.import_table( + resource('regressionPoisson.pheno'), key='Sample', missing='-1', types={'count': hl.tint32} + ) + mt = hl.import_vcf(resource('regressionLogistic.vcf')).annotate_rows(foo=hl.struct(bar=hl.rand_norm(0, 1))) ht = hl.poisson_regression_rows( test='wald', y=pheno[mt.s].count, x=mt.GT.n_alt_alleles(), covariates=[1.0, covariates[mt.s].Cov1, covariates[mt.s].Cov2], - pass_through=['filters', mt.foo.bar, mt.qual]) + pass_through=['filters', mt.foo.bar, mt.qual], + ) assert mt.aggregate_rows(hl.agg.all(mt.foo.bar == ht[mt.row_key].bar)) def test_genetic_relatedness_matrix(self): n, m = 100, 200 - mt = hl.balding_nichols_model(3, n, m, fst=[.9, .9, .9], n_partitions=4) + mt = hl.balding_nichols_model(3, n, m, fst=[0.9, 0.9, 0.9], n_partitions=4) g = BlockMatrix.from_entry_expr(mt.GT.n_alt_alleles()).to_numpy().T @@ -1229,7 +1243,7 @@ def _filter_and_standardize_cols(a): def test_realized_relationship_matrix(self): n, m = 100, 200 hl.reset_global_randomness() - mt = hl.balding_nichols_model(3, n, m, fst=[.9, .9, .9], n_partitions=4) + mt = hl.balding_nichols_model(3, n, m, fst=[0.9, 0.9, 0.9], n_partitions=4) g = BlockMatrix.from_entry_expr(mt.GT.n_alt_alleles()).to_numpy().T g_std = self._filter_and_standardize_cols(g) @@ -1244,23 +1258,25 @@ def test_realized_relationship_matrix(self): self.assertRaises(FatalError, lambda: hl.realized_relationship_matrix(one_sample.GT)) def test_row_correlation_vs_hardcode(self): - data = [{'v': '1:1:A:C', 's': '1', 'GT': hl.Call([0, 0])}, - {'v': '1:1:A:C', 's': '2', 'GT': hl.Call([0, 
0])}, - {'v': '1:1:A:C', 's': '3', 'GT': hl.Call([0, 1])}, - {'v': '1:1:A:C', 's': '4', 'GT': hl.Call([1, 1])}, - {'v': '1:2:G:T', 's': '1', 'GT': hl.Call([0, 1])}, - {'v': '1:2:G:T', 's': '2', 'GT': hl.Call([1, 1])}, - {'v': '1:2:G:T', 's': '3', 'GT': hl.Call([0, 1])}, - {'v': '1:2:G:T', 's': '4', 'GT': hl.Call([0, 0])}, - {'v': '1:3:C:G', 's': '1', 'GT': hl.Call([0, 1])}, - {'v': '1:3:C:G', 's': '2', 'GT': hl.Call([0, 0])}, - {'v': '1:3:C:G', 's': '3', 'GT': hl.Call([1, 1])}, - {'v': '1:3:C:G', 's': '4', 'GT': hl.missing(hl.tcall)}] + data = [ + {'v': '1:1:A:C', 's': '1', 'GT': hl.Call([0, 0])}, + {'v': '1:1:A:C', 's': '2', 'GT': hl.Call([0, 0])}, + {'v': '1:1:A:C', 's': '3', 'GT': hl.Call([0, 1])}, + {'v': '1:1:A:C', 's': '4', 'GT': hl.Call([1, 1])}, + {'v': '1:2:G:T', 's': '1', 'GT': hl.Call([0, 1])}, + {'v': '1:2:G:T', 's': '2', 'GT': hl.Call([1, 1])}, + {'v': '1:2:G:T', 's': '3', 'GT': hl.Call([0, 1])}, + {'v': '1:2:G:T', 's': '4', 'GT': hl.Call([0, 0])}, + {'v': '1:3:C:G', 's': '1', 'GT': hl.Call([0, 1])}, + {'v': '1:3:C:G', 's': '2', 'GT': hl.Call([0, 0])}, + {'v': '1:3:C:G', 's': '3', 'GT': hl.Call([1, 1])}, + {'v': '1:3:C:G', 's': '4', 'GT': hl.missing(hl.tcall)}, + ] ht = hl.Table.parallelize(data, hl.dtype('struct{v: str, s: str, GT: call}')) mt = ht.to_matrix_table(['v'], ['s']) actual = hl.row_correlation(mt.GT.n_alt_alleles()).to_numpy() - expected = [[1., -0.85280287, 0.42640143], [-0.85280287, 1., -0.5], [0.42640143, -0.5, 1.]] + expected = [[1.0, -0.85280287, 0.42640143], [-0.85280287, 1.0, -0.5], [0.42640143, -0.5, 1.0]] self.assertTrue(np.allclose(actual, expected)) @@ -1281,45 +1297,59 @@ def test_row_correlation_vs_numpy(self): self.assertTrue(np.allclose(l, cor)) def get_ld_matrix_mt(self): - data = [{'v': '1:1:A:C', 'cm': 0.1, 's': 'a', 'GT': hl.Call([0, 0])}, - {'v': '1:1:A:C', 'cm': 0.1, 's': 'b', 'GT': hl.Call([0, 0])}, - {'v': '1:1:A:C', 'cm': 0.1, 's': 'c', 'GT': hl.Call([0, 1])}, - {'v': '1:1:A:C', 'cm': 0.1, 's': 'd', 'GT': hl.Call([1, 1])}, - {'v': '1:2000000:G:T', 'cm': 0.9, 's': 'a', 'GT': hl.Call([0, 1])}, - {'v': '1:2000000:G:T', 'cm': 0.9, 's': 'b', 'GT': hl.Call([1, 1])}, - {'v': '1:2000000:G:T', 'cm': 0.9, 's': 'c', 'GT': hl.Call([0, 1])}, - {'v': '1:2000000:G:T', 'cm': 0.9, 's': 'd', 'GT': hl.Call([0, 0])}, - {'v': '2:1:C:G', 'cm': 0.2, 's': 'a', 'GT': hl.Call([0, 1])}, - {'v': '2:1:C:G', 'cm': 0.2, 's': 'b', 'GT': hl.Call([0, 0])}, - {'v': '2:1:C:G', 'cm': 0.2, 's': 'c', 'GT': hl.Call([1, 1])}, - {'v': '2:1:C:G', 'cm': 0.2, 's': 'd', 'GT': hl.missing(hl.tcall)}] + data = [ + {'v': '1:1:A:C', 'cm': 0.1, 's': 'a', 'GT': hl.Call([0, 0])}, + {'v': '1:1:A:C', 'cm': 0.1, 's': 'b', 'GT': hl.Call([0, 0])}, + {'v': '1:1:A:C', 'cm': 0.1, 's': 'c', 'GT': hl.Call([0, 1])}, + {'v': '1:1:A:C', 'cm': 0.1, 's': 'd', 'GT': hl.Call([1, 1])}, + {'v': '1:2000000:G:T', 'cm': 0.9, 's': 'a', 'GT': hl.Call([0, 1])}, + {'v': '1:2000000:G:T', 'cm': 0.9, 's': 'b', 'GT': hl.Call([1, 1])}, + {'v': '1:2000000:G:T', 'cm': 0.9, 's': 'c', 'GT': hl.Call([0, 1])}, + {'v': '1:2000000:G:T', 'cm': 0.9, 's': 'd', 'GT': hl.Call([0, 0])}, + {'v': '2:1:C:G', 'cm': 0.2, 's': 'a', 'GT': hl.Call([0, 1])}, + {'v': '2:1:C:G', 'cm': 0.2, 's': 'b', 'GT': hl.Call([0, 0])}, + {'v': '2:1:C:G', 'cm': 0.2, 's': 'c', 'GT': hl.Call([1, 1])}, + {'v': '2:1:C:G', 'cm': 0.2, 's': 'd', 'GT': hl.missing(hl.tcall)}, + ] ht = hl.Table.parallelize(data, hl.dtype('struct{v: str, s: str, cm: float64, GT: call}')) ht = ht.transmute(**hl.parse_variant(ht.v)) return ht.to_matrix_table(row_key=['locus', 'alleles'], 
col_key=['s'], row_fields=['cm']) def test_ld_matrix_1(self): mt = self.get_ld_matrix_mt() - self.assertTrue(np.allclose( - hl.ld_matrix(mt.GT.n_alt_alleles(), mt.locus, radius=1e6).to_numpy(), - [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])) + self.assertTrue( + np.allclose( + hl.ld_matrix(mt.GT.n_alt_alleles(), mt.locus, radius=1e6).to_numpy(), + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + ) + ) def test_ld_matrix_2(self): mt = self.get_ld_matrix_mt() - self.assertTrue(np.allclose( - hl.ld_matrix(mt.GT.n_alt_alleles(), mt.locus, radius=2e6).to_numpy(), - [[1., -0.85280287, 0.], [-0.85280287, 1., 0.], [0., 0., 1.]])) + self.assertTrue( + np.allclose( + hl.ld_matrix(mt.GT.n_alt_alleles(), mt.locus, radius=2e6).to_numpy(), + [[1.0, -0.85280287, 0.0], [-0.85280287, 1.0, 0.0], [0.0, 0.0, 1.0]], + ) + ) def test_ld_matrix_3(self): mt = self.get_ld_matrix_mt() - self.assertTrue(np.allclose( - hl.ld_matrix(mt.GT.n_alt_alleles(), mt.locus, radius=0.5, coord_expr=mt.cm).to_numpy(), - [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]])) + self.assertTrue( + np.allclose( + hl.ld_matrix(mt.GT.n_alt_alleles(), mt.locus, radius=0.5, coord_expr=mt.cm).to_numpy(), + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + ) + ) def test_ld_matrix_4(self): mt = self.get_ld_matrix_mt() - self.assertTrue(np.allclose( - hl.ld_matrix(mt.GT.n_alt_alleles(), mt.locus, radius=1.0, coord_expr=mt.cm).to_numpy(), - [[1., -0.85280287, 0.], [-0.85280287, 1., 0.], [0., 0., 1.]])) + self.assertTrue( + np.allclose( + hl.ld_matrix(mt.GT.n_alt_alleles(), mt.locus, radius=1.0, coord_expr=mt.cm).to_numpy(), + [[1.0, -0.85280287, 0.0], [-0.85280287, 1.0, 0.0], [0.0, 0.0, 1.0]], + ) + ) @qobtest def test_split_multi_hts(self): @@ -1350,14 +1380,18 @@ def test_split_multi_table(self): @qobtest def test_split_multi_shuffle(self): ht = hl.utils.range_table(1) - ht = ht.annotate(keys=[hl.struct(locus=hl.locus('1', 1180), alleles=['A', 'C', 'T']), - hl.struct(locus=hl.locus('1', 1180), alleles=['A', 'G'])]) + ht = ht.annotate( + keys=[ + hl.struct(locus=hl.locus('1', 1180), alleles=['A', 'C', 'T']), + hl.struct(locus=hl.locus('1', 1180), alleles=['A', 'G']), + ] + ) ht = ht.explode(ht.keys) ht = ht.key_by(**ht.keys).drop('keys') alleles = hl.split_multi(ht, permit_shuffle=True).alleles.collect() assert alleles == [['A', 'C'], ['A', 'G'], ['A', 'T']] - ht = ht.annotate_globals(cols = [hl.struct(s='sample1'), hl.struct(s='sample2')]) + ht = ht.annotate_globals(cols=[hl.struct(s='sample1'), hl.struct(s='sample2')]) ht = ht.annotate(entries=[hl.struct(GT=hl.call(0, 1)), hl.struct(GT=hl.call(1, 1))]) mt = ht._unlocalize_entries('entries', 'cols', ['s']) mt = hl.split_multi_hts(mt, permit_shuffle=True) @@ -1367,7 +1401,7 @@ def test_split_multi_shuffle(self): @qobtest def test_issue_4527(self): mt = hl.utils.range_matrix_table(1, 1) - mt = mt.key_rows_by(locus=hl.locus(hl.str(mt.row_idx+1), mt.row_idx+1), alleles=['A', 'T']) + mt = mt.key_rows_by(locus=hl.locus(hl.str(mt.row_idx + 1), mt.row_idx + 1), alleles=['A', 'T']) mt = hl.split_multi(mt) self.assertEqual(1, mt._force_count_rows()) @@ -1381,14 +1415,16 @@ def test_ld_prune(self): filtered_ds = ds.filter_rows(hl.is_defined(pruned_table[ds.row_key])) filtered_ds = filtered_ds.annotate_rows(stats=agg.stats(filtered_ds.GT.n_alt_alleles())) - filtered_ds = filtered_ds.annotate_rows( - mean=filtered_ds.stats.mean, sd_reciprocal=1 / filtered_ds.stats.stdev) + filtered_ds = filtered_ds.annotate_rows(mean=filtered_ds.stats.mean, sd_reciprocal=1 / filtered_ds.stats.stdev) n_samples = 
filtered_ds.count_cols() - normalized_mean_imputed_genotype_expr = ( - hl.if_else(hl.is_defined(filtered_ds['GT']), - (filtered_ds['GT'].n_alt_alleles() - filtered_ds['mean']) - * filtered_ds['sd_reciprocal'] * (1 / hl.sqrt(n_samples)), 0)) + normalized_mean_imputed_genotype_expr = hl.if_else( + hl.is_defined(filtered_ds['GT']), + (filtered_ds['GT'].n_alt_alleles() - filtered_ds['mean']) + * filtered_ds['sd_reciprocal'] + * (1 / hl.sqrt(n_samples)), + 0, + ) std_bm = BlockMatrix.from_entry_expr(normalized_mean_imputed_genotype_expr) @@ -1400,10 +1436,10 @@ def test_ld_prune(self): entries = entries.annotate(locus_i=index_table[entries.i].locus, locus_j=index_table[entries.j].locus) bad_pair = ( - (entries.entry >= r2_threshold) & - (entries.locus_i.contig == entries.locus_j.contig) & - (hl.abs(entries.locus_j.position - entries.locus_i.position) <= window_size) & - (entries.i != entries.j) + (entries.entry >= r2_threshold) + & (entries.locus_i.contig == entries.locus_j.contig) + & (hl.abs(entries.locus_j.position - entries.locus_i.position) <= window_size) + & (entries.i != entries.j) ) self.assertEqual(entries.filter(bad_pair).count(), 0) @@ -1462,10 +1498,15 @@ def test_ld_prune_with_duplicate_row_keys(self): def test_balding_nichols_model(self): hl.reset_global_randomness() - ds = hl.balding_nichols_model(2, 20, 25, 3, - pop_dist=[1.0, 2.0], - fst=[.02, .06], - af_dist=hl.rand_beta(a=0.01, b=2.0, lower=0.05, upper=0.95)) + ds = hl.balding_nichols_model( + 2, + 20, + 25, + 3, + pop_dist=[1.0, 2.0], + fst=[0.02, 0.06], + af_dist=hl.rand_beta(a=0.01, b=2.0, lower=0.05, upper=0.95), + ) ds.entries().show(100, width=200) @@ -1478,36 +1519,43 @@ def test_balding_nichols_model(self): self.assertEqual(hl.eval(glob.bn.n_samples), 20) self.assertEqual(hl.eval(glob.bn.n_variants), 25) self.assertEqual(hl.eval(glob.bn.pop_dist), [1, 2]) - self.assertEqual(hl.eval(glob.bn.fst), [.02, .06]) + self.assertEqual(hl.eval(glob.bn.fst), [0.02, 0.06]) def test_balding_nichols_model_same_results(self): for mixture in [True, False]: hl.reset_global_randomness() - ds1 = hl.balding_nichols_model(2, 20, 25, 3, - pop_dist=[1.0, 2.0], - fst=[.02, .06], - af_dist=hl.rand_beta(a=0.01, b=2.0, lower=0.05, upper=0.95), - mixture=mixture) + ds1 = hl.balding_nichols_model( + 2, + 20, + 25, + 3, + pop_dist=[1.0, 2.0], + fst=[0.02, 0.06], + af_dist=hl.rand_beta(a=0.01, b=2.0, lower=0.05, upper=0.95), + mixture=mixture, + ) hl.reset_global_randomness() - ds2 = hl.balding_nichols_model(2, 20, 25, 3, - pop_dist=[1.0, 2.0], - fst=[.02, .06], - af_dist=hl.rand_beta(a=0.01, b=2.0, lower=0.05, upper=0.95), - mixture=mixture) + ds2 = hl.balding_nichols_model( + 2, + 20, + 25, + 3, + pop_dist=[1.0, 2.0], + fst=[0.02, 0.06], + af_dist=hl.rand_beta(a=0.01, b=2.0, lower=0.05, upper=0.95), + mixture=mixture, + ) self.assertTrue(ds1._same(ds2)) def test_balding_nichols_model_af_ranges(self): def test_af_range(rand_func, min, max, seed): hl.reset_global_randomness() bn = hl.balding_nichols_model(3, 400, 400, af_dist=rand_func) - self.assertTrue( - bn.aggregate_rows( - hl.agg.all((bn.ancestral_af > min) & - (bn.ancestral_af < max)))) + self.assertTrue(bn.aggregate_rows(hl.agg.all((bn.ancestral_af > min) & (bn.ancestral_af < max)))) - test_af_range(hl.rand_beta(.01, 2, .2, .8), .2, .8, 0) - test_af_range(hl.rand_beta(3, 3, .4, .6), .4, .6, 1) - test_af_range(hl.rand_unif(.4, .7), .4, .7, 2) + test_af_range(hl.rand_beta(0.01, 2, 0.2, 0.8), 0.2, 0.8, 0) + test_af_range(hl.rand_beta(3, 3, 0.4, 0.6), 0.4, 0.6, 1) + 
test_af_range(hl.rand_unif(0.4, 0.7), 0.4, 0.7, 2) test_af_range(hl.rand_beta(4, 6), 0, 1, 3) @test_timeout(batch=6 * 60) @@ -1524,22 +1572,25 @@ def test_stat(k, n, m, seed): # test af distribution def variance(expr): return hl.bind(lambda mean: hl.mean(hl.map(lambda elt: (elt - mean) ** 2, expr)), hl.mean(expr)) + delta_mean = 0.2 # consider alternatives to 0.2 delta_var = 0.1 - per_row = hl.bind(lambda mean, var, ancestral: - (ancestral > mean - delta_mean) & - (ancestral < mean + delta_mean) & - (.1 * ancestral * (1 - ancestral) > var - delta_var) & - (.1 * ancestral * (1 - ancestral) < var + delta_var), - hl.mean(bn.af), - variance(bn.af), - bn.ancestral_af) + per_row = hl.bind( + lambda mean, var, ancestral: (ancestral > mean - delta_mean) + & (ancestral < mean + delta_mean) + & (0.1 * ancestral * (1 - ancestral) > var - delta_var) + & (0.1 * ancestral * (1 - ancestral) < var + delta_var), + hl.mean(bn.af), + variance(bn.af), + bn.ancestral_af, + ) self.assertTrue(bn.aggregate_rows(hl.agg.all(per_row))) # test genotype distribution stats_gt_by_pop = hl.agg.group_by(bn.pop, hl.agg.stats(hl.float(bn.GT.n_alt_alleles()))).values() - bn = bn.select_rows(sum_af=hl.sum(bn.af), - sum_mean_gt_by_pop=hl.sum(hl.map(lambda x: x.mean, stats_gt_by_pop))) + bn = bn.select_rows( + sum_af=hl.sum(bn.af), sum_mean_gt_by_pop=hl.sum(hl.map(lambda x: x.mean, stats_gt_by_pop)) + ) sum_af = bn.aggregate_rows(hl.agg.sum(bn.sum_af)) sum_mean_gt = bn.aggregate_rows(hl.agg.sum(bn.sum_mean_gt_by_pop)) self.assertAlmostEqual(sum_mean_gt, 2 * sum_af, delta=0.1 * m * k) @@ -1550,16 +1601,37 @@ def variance(expr): @skip_when_service_backend(reason='flaky, incorrect alleles in output') def test_balding_nichols_model_phased(self): bn_ds = hl.balding_nichols_model(1, 5, 5, phased=True) - assert bn_ds.aggregate_entries(hl.agg.all(bn_ds.GT.phased)) == True + assert bn_ds.aggregate_entries(hl.agg.all(bn_ds.GT.phased)) is True actual = bn_ds.GT.collect() self.assertListEqual( - [ c.alleles for c in actual ], - [ [0, 0], [0, 0], [0, 0], [0, 0], [0, 0] - , [1, 1], [1, 1], [1, 1], [1, 0], [1, 1] - , [1, 1], [0, 1], [1, 0], [1, 0], [0, 1] - , [0, 0], [0, 0], [0, 0], [0, 0], [1, 0] - , [1, 1], [1, 1], [0, 1], [1, 1], [1, 1] - ] + [c.alleles for c in actual], + [ + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [1, 1], + [1, 1], + [1, 1], + [1, 0], + [1, 1], + [1, 1], + [0, 1], + [1, 0], + [1, 0], + [0, 1], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [1, 0], + [1, 1], + [1, 1], + [0, 1], + [1, 1], + [1, 1], + ], ) def test_de_novo(self): @@ -1573,7 +1645,8 @@ def test_de_novo(self): dad_id=r.father.s, mom_id=r.mother.s, p_de_novo=r.p_de_novo, - confidence=r.confidence).key_by('locus', 'alleles', 'kid_id', 'dad_id', 'mom_id') + confidence=r.confidence, + ).key_by('locus', 'alleles', 'kid_id', 'dad_id', 'mom_id') truth = hl.import_table(resource('denovo.out'), impute=True, comment='#') truth = truth.select( @@ -1583,7 +1656,8 @@ def test_de_novo(self): dad_id=truth['Dad_ID'], mom_id=truth['Mom_ID'], p_de_novo=truth['Prob_dn'], - confidence=truth['Validation_Likelihood'].split('_')[0]).key_by('locus', 'alleles', 'kid_id', 'dad_id', 'mom_id') + confidence=truth['Validation_Likelihood'].split('_')[0], + ).key_by('locus', 'alleles', 'kid_id', 'dad_id', 'mom_id') j = r.join(truth, how='outer') self.assertTrue(j.all((j.confidence == j.confidence_1) & (hl.abs(j.p_de_novo - j.p_de_novo_1) < 1e-4))) @@ -1605,18 +1679,20 @@ def test_warn_if_no_intercept(self): mt = hl.balding_nichols_model(1, 1, 1).add_row_index().add_col_index() 
intercept = hl.float64(1.0) - for covariates in [[], - [mt.row_idx], - [mt.col_idx], - [mt.GT.n_alt_alleles()], - [mt.row_idx, mt.col_idx, mt.GT.n_alt_alleles()]]: + for covariates in [ + [], + [mt.row_idx], + [mt.col_idx], + [mt.GT.n_alt_alleles()], + [mt.row_idx, mt.col_idx, mt.GT.n_alt_alleles()], + ]: self.assertTrue(hl.methods.statgen._warn_if_no_intercept('', covariates)) - self.assertFalse(hl.methods.statgen._warn_if_no_intercept('', [intercept] + covariates)) + self.assertFalse(hl.methods.statgen._warn_if_no_intercept('', [intercept, *covariates])) def test_regression_field_dependence(self): mt = hl.utils.range_matrix_table(10, 10) - mt = mt.annotate_cols(c1 = hl.literal([x % 2 == 0 for x in range(10)])[mt.col_idx], c2 = hl.rand_norm(0, 1)) - mt = mt.annotate_entries(e1 = hl.int(hl.rand_norm(0, 1) * 10)) + mt = mt.annotate_cols(c1=hl.literal([x % 2 == 0 for x in range(10)])[mt.col_idx], c2=hl.rand_norm(0, 1)) + mt = mt.annotate_entries(e1=hl.int(hl.rand_norm(0, 1) * 10)) x_expr = hl.case().when(mt.c2 < 0, 0).default(mt.e1) @@ -1633,9 +1709,9 @@ def logistic_epacts_mt(): # Locus("22", 16115882) # MAC 1207 # Locus("22", 16117940) # MAC 7 # Locus("22", 16117953) # MAC 21 - covariates = hl.import_table(resource('regressionLogisticEpacts.cov'), - key='IND_ID', - types={'PC1': hl.tfloat, 'PC2': hl.tfloat}) + covariates = hl.import_table( + resource('regressionLogisticEpacts.cov'), key='IND_ID', types={'PC1': hl.tfloat, 'PC2': hl.tfloat} + ) fam = hl.import_fam(resource('regressionLogisticEpacts.fam')) mt = hl.import_vcf(resource('regressionLogisticEpacts.vcf')) @@ -1646,10 +1722,8 @@ def logistic_epacts_mt(): def test_logistic_regression_epacts_wald(logistic_epacts_mt): mt = logistic_epacts_mt actual = hl.logistic_regression_rows( - test='wald', - y=mt.is_case, - x=mt.GT.n_alt_alleles(), - covariates=[1.0, mt.is_female, mt.PC1, mt.PC2]).collect() + test='wald', y=mt.is_case, x=mt.GT.n_alt_alleles(), covariates=[1.0, mt.is_female, mt.PC1, mt.PC2] + ).collect() assert actual[0].locus == hl.Locus("22", 16060511, 'GRCh37') assert actual[0].beta == pytest.approx(-0.097476, rel=1e-4) @@ -1685,10 +1759,7 @@ def test_logistic_regression_epacts_wald(logistic_epacts_mt): def test_logistic_regression_epacts_lrt(logistic_epacts_mt): mt = logistic_epacts_mt actual = hl.logistic_regression_rows( - test='lrt', - y=mt.is_case, - x=mt.GT.n_alt_alleles(), - covariates=[1.0, mt.is_female, mt.PC1, mt.PC2] + test='lrt', y=mt.is_case, x=mt.GT.n_alt_alleles(), covariates=[1.0, mt.is_female, mt.PC1, mt.PC2] ).collect() assert actual[0].locus == hl.Locus("22", 16060511, 'GRCh37') @@ -1739,10 +1810,7 @@ def test_logistic_regression_epacts_score(logistic_epacts_mt): # mt = logistic_epacts_mt actual = hl.logistic_regression_rows( - test='score', - y=mt.is_case, - x=mt.GT.n_alt_alleles(), - covariates=[1.0, mt.is_female, mt.PC1, mt.PC2] + test='score', y=mt.is_case, x=mt.GT.n_alt_alleles(), covariates=[1.0, mt.is_female, mt.PC1, mt.PC2] ).collect() assert actual[0].locus == hl.Locus("22", 16060511, 'GRCh37') @@ -1769,10 +1837,7 @@ def test_logistic_regression_epacts_score(logistic_epacts_mt): def test_logistic_regression_epacts_firth(logistic_epacts_mt): mt = logistic_epacts_mt actual = hl.logistic_regression_rows( - test='firth', - y=mt.is_case, - x=mt.GT.n_alt_alleles(), - covariates=[1.0, mt.is_female, mt.PC1, mt.PC2] + test='firth', y=mt.is_case, x=mt.GT.n_alt_alleles(), covariates=[1.0, mt.is_female, mt.PC1, mt.PC2] ).collect() assert actual[0].locus == hl.Locus("22", 16060511, 'GRCh37') @@ -1798,13 +1863,8 @@ 
def test_logistic_regression_epacts_firth(logistic_epacts_mt): ## issue 13788 def test_logistic_regression_y_parameter_sanity(): - mt = hl.utils.range_matrix_table(2,2) - mt = mt.annotate_entries(prod = mt.row_idx * mt.col_idx) + mt = hl.utils.range_matrix_table(2, 2) + mt = mt.annotate_entries(prod=mt.row_idx * mt.col_idx) with pytest.raises(hl.ExpressionException): - hl.logistic_regression_rows( - test='wald', - x=mt.prod, - y=mt.row_idx, - covariates=[1.0] - ).describe() + hl.logistic_regression_rows(test='wald', x=mt.prod, y=mt.row_idx, covariates=[1.0]).describe() diff --git a/hail/python/test/hail/plot/test_plot.py b/hail/python/test/hail/plot/test_plot.py index 8ded69e3fa6..8d79df5d1e3 100644 --- a/hail/python/test/hail/plot/test_plot.py +++ b/hail/python/test/hail/plot/test_plot.py @@ -1,6 +1,7 @@ -import pytest from unittest.mock import patch +import pytest + import hail as hl @@ -31,11 +32,6 @@ def test_cumulative_histogram(): hl.plot.cumulative_histogram(ht.idx) -def test_histogram2d(): - ht = hl.utils.range_matrix_table(100) - hl.plot.histogram2d(ht.idx, ht.col_idx) - - def test_histogram2d(): ht = hl.utils.range_table(100) hl.plot.histogram2d(ht.idx, ht.idx * ht.idx) @@ -55,6 +51,7 @@ def test_qq(): ht = hl.utils.range_table(100) hl.plot.qq(ht.idx / 100) + def test_manhattan(): ht = hl.balding_nichols_model(1, n_variants=100, n_samples=1).rows() ht = ht.add_index('idx') @@ -62,13 +59,14 @@ def test_manhattan(): @pytest.mark.parametrize( - 'name, plot' - , [ ('manhattan', hl.plot.manhattan ) - , ('scatter', lambda x, **kwargs: hl.plot.scatter(x, x, **kwargs) ) - , ('join_plot', lambda x, **kwargs: hl.plot.joint_plot(x, x,** kwargs)) - , ('qq', hl.plot.qq ) - ] - ) + 'name, plot', + [ + ('manhattan', hl.plot.manhattan), + ('scatter', lambda x, **kwargs: hl.plot.scatter(x, x, **kwargs)), + ('join_plot', lambda x, **kwargs: hl.plot.joint_plot(x, x, **kwargs)), + ('qq', hl.plot.qq), + ], +) def test_plots_deprecated_collect_all(name, plot): ht = hl.balding_nichols_model(1, n_variants=100, n_samples=1).rows() ht = ht.add_index('idx') diff --git a/hail/python/test/hail/table/test_grouped_table.py b/hail/python/test/hail/table/test_grouped_table.py index 384995234e1..1305d8d8cf1 100644 --- a/hail/python/test/hail/table/test_grouped_table.py +++ b/hail/python/test/hail/table/test_grouped_table.py @@ -15,10 +15,11 @@ def test_aggregate_by(self): expected = ( hl.Table.parallelize( - [{'group': True, 'sum': 1, 'max': 1}, - {'group': False, 'sum': 5, 'max': 3}], - hl.tstruct(group=hl.tbool, sum=hl.tint64, max=hl.tint32) - ).annotate_globals(glob=5).key_by('group') + [{'group': True, 'sum': 1, 'max': 1}, {'group': False, 'sum': 5, 'max': 3}], + hl.tstruct(group=hl.tbool, sum=hl.tint64, max=hl.tint32), + ) + .annotate_globals(glob=5) + .key_by('group') ) self.assertTrue(result._same(expected)) @@ -33,15 +34,17 @@ def test_aggregate_by_with_joins(self): ht = ht.annotate_globals(glob=5) grouped = ht.group_by(group=ht2[ht.idx].idx2 < 2) - result = grouped.aggregate(sum=hl.agg.sum(ht2[ht.idx].idx2 + ht.glob) + ht.glob - 15, - max=hl.agg.max(ht2[ht.idx].idx2)) + result = grouped.aggregate( + sum=hl.agg.sum(ht2[ht.idx].idx2 + ht.glob) + ht.glob - 15, max=hl.agg.max(ht2[ht.idx].idx2) + ) expected = ( hl.Table.parallelize( - [{'group': True, 'sum': 1, 'max': 1}, - {'group': False, 'sum': 5, 'max': 3}], - hl.tstruct(group=hl.tbool, sum=hl.tint64, max=hl.tint32) - ).annotate_globals(glob=5).key_by('group') + [{'group': True, 'sum': 1, 'max': 1}, {'group': False, 'sum': 5, 'max': 3}], + 
hl.tstruct(group=hl.tbool, sum=hl.tint64, max=hl.tint32), + ) + .annotate_globals(glob=5) + .key_by('group') ) self.assertTrue(result._same(expected)) diff --git a/hail/python/test/hail/table/test_table.py b/hail/python/test/hail/table/test_table.py index 82664d1a0d5..c3f44479253 100644 --- a/hail/python/test/hail/table/test_table.py +++ b/hail/python/test/hail/table/test_table.py @@ -1,18 +1,28 @@ +import os import unittest -import pandas as pd import numpy as np +import pandas as pd import pyspark.sql import pytest import hail as hl import hail.expr.aggregators as agg +from hail import ExpressionException, ir from hail.utils import new_temp_file from hail.utils.java import Env -import hail.ir as ir -from hail import ExpressionException -from ..helpers import * +from ..helpers import ( + assert_time, + convert_struct_to_dict, + create_all_values_datasets, + create_all_values_table, + lower_only, + qobtest, + resource, + skip_unless_spark_backend, + test_timeout, +) class Tests(unittest.TestCase): @@ -20,56 +30,68 @@ class Tests(unittest.TestCase): def test_annotate(self): schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr, f=hl.tarray(hl.tint32)) - rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]}, - {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []}, - {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}] + rows = [ + {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]}, + {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []}, + {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}, + ] kt = hl.Table.parallelize(rows, schema) self.assertTrue(kt.annotate()._same(kt)) - result1 = convert_struct_to_dict(kt.annotate(foo=kt.a + 1, - foo2=kt.a).take(1)[0]) - - self.assertDictEqual(result1, {'a': 4, - 'b': 1, - 'c': 3, - 'd': 5, - 'e': "hello", - 'f': [1, 2, 3], - 'foo': 5, - 'foo2': 4}) - - result3 = convert_struct_to_dict(kt.annotate( - x1=kt.f.map(lambda x: x * 2), - x2=kt.f.map(lambda x: [x, x + 1]).flatmap(lambda x: x), - x3=hl.min(kt.f), - x4=hl.max(kt.f), - x5=hl.sum(kt.f), - x6=hl.product(kt.f), - x7=kt.f.length(), - x8=kt.f.filter(lambda x: x == 3), - x9=kt.f[1:], - x10=kt.f[:], - x11=kt.f[1:2], - x12=kt.f.map(lambda x: [x, x + 1]), - x13=kt.f.map(lambda x: [[x, x + 1], [x + 2]]).flatmap(lambda x: x), - x14=hl.if_else(kt.a < kt.b, kt.c, hl.missing(hl.tint32)), - x15={1, 2, 3} - ).take(1)[0]) - - self.assertDictEqual(result3, {'a': 4, - 'b': 1, - 'c': 3, - 'd': 5, - 'e': "hello", - 'f': [1, 2, 3], - 'x1': [2, 4, 6], 'x2': [1, 2, 2, 3, 3, 4], - 'x3': 1, 'x4': 3, 'x5': 6, 'x6': 6, 'x7': 3, 'x8': [3], - 'x9': [2, 3], 'x10': [1, 2, 3], 'x11': [2], - 'x12': [[1, 2], [2, 3], [3, 4]], - 'x13': [[1, 2], [3], [2, 3], [4], [3, 4], [5]], - 'x14': None, 'x15': set([1, 2, 3])}) + result1 = convert_struct_to_dict(kt.annotate(foo=kt.a + 1, foo2=kt.a).take(1)[0]) + + self.assertDictEqual( + result1, {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3], 'foo': 5, 'foo2': 4} + ) + + result3 = convert_struct_to_dict( + kt.annotate( + x1=kt.f.map(lambda x: x * 2), + x2=kt.f.map(lambda x: [x, x + 1]).flatmap(lambda x: x), + x3=hl.min(kt.f), + x4=hl.max(kt.f), + x5=hl.sum(kt.f), + x6=hl.product(kt.f), + x7=kt.f.length(), + x8=kt.f.filter(lambda x: x == 3), + x9=kt.f[1:], + x10=kt.f[:], + x11=kt.f[1:2], + x12=kt.f.map(lambda x: [x, x + 1]), + x13=kt.f.map(lambda x: [[x, x + 1], [x + 2]]).flatmap(lambda x: x), + x14=hl.if_else(kt.a < kt.b, kt.c, hl.missing(hl.tint32)), + x15={1, 2, 3}, + ).take(1)[0] + ) + + 
self.assertDictEqual( + result3, + { + 'a': 4, + 'b': 1, + 'c': 3, + 'd': 5, + 'e': "hello", + 'f': [1, 2, 3], + 'x1': [2, 4, 6], + 'x2': [1, 2, 2, 3, 3, 4], + 'x3': 1, + 'x4': 3, + 'x5': 6, + 'x6': 6, + 'x7': 3, + 'x8': [3], + 'x9': [2, 3], + 'x10': [1, 2, 3], + 'x11': [2], + 'x12': [[1, 2], [2, 3], [3, 4]], + 'x13': [[1, 2], [3], [2, 3], [4], [3, 4], [5]], + 'x14': None, + 'x15': set([1, 2, 3]), + }, + ) kt.annotate( x1=kt.a + 5, x2=5 + kt.a, @@ -106,23 +128,29 @@ def test_annotate(self): x33=(kt.a == 0) & (kt.b == 5), x34=(kt.a == 0) | (kt.b == 5), x35=False, - x36=True + x36=True, ) @qobtest def test_aggregate1(self): schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr, f=hl.tarray(hl.tint32)) - rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]}, - {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []}, - {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}] + rows = [ + {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]}, + {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []}, + {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}, + ] kt = hl.Table.parallelize(rows, schema) - results = kt.aggregate(hl.Struct(q1=agg.sum(kt.b), - q2=agg.count(), - q3=agg.collect(kt.e), - q4=agg.filter((kt.d >= 5) | (kt.a == 0), agg.collect(kt.e)), - q5=agg.explode(lambda elt: agg.mean(elt), kt.f))) + results = kt.aggregate( + hl.Struct( + q1=agg.sum(kt.b), + q2=agg.count(), + q3=agg.collect(kt.e), + q4=agg.filter((kt.d >= 5) | (kt.a == 0), agg.collect(kt.e)), + q5=agg.explode(lambda elt: agg.mean(elt), kt.f), + ) + ) self.assertEqual(results.q1, 8) self.assertEqual(results.q2, 3) @@ -133,14 +161,13 @@ def test_aggregate1(self): def test_aggregate2(self): schema = hl.tstruct(status=hl.tint32, GT=hl.tcall, qPheno=hl.tint32) - rows = [{'status': 0, 'GT': hl.Call([0, 0]), 'qPheno': 3}, - {'status': 0, 'GT': hl.Call([0, 1]), 'qPheno': 13}] + rows = [{'status': 0, 'GT': hl.Call([0, 0]), 'qPheno': 3}, {'status': 0, 'GT': hl.Call([0, 1]), 'qPheno': 13}] kt = hl.Table.parallelize(rows, schema) result = convert_struct_to_dict( kt.group_by(status=kt.status) - .aggregate( + .aggregate( x1=agg.collect(kt.qPheno * 2), x2=agg.explode(lambda elt: agg.collect(elt), [kt.qPheno, kt.qPheno + 1]), x3=agg.min(kt.qPheno), @@ -158,21 +185,32 @@ def test_aggregate2(self): x16=agg.collect(hl.Struct(a=5, b="foo", c=hl.Struct(banana='apple')).c.banana)[0], x17=agg.explode(lambda elt: agg.collect(elt), hl.missing(hl.tarray(hl.tint32))), x18=agg.explode(lambda elt: agg.collect(elt), hl.missing(hl.tset(hl.tint32))), - x19=agg.take(kt.GT, 1, ordering=-kt.qPheno) - ).take(1)[0]) - - expected = {u'status': 0, - u'x13': {u'n_called': 2, u'expected_homs': 1.64, u'f_stat': -1.777777777777777, - u'observed_homs': 1}, - u'x14': {u'AC': [3, 1], u'AF': [0.75, 0.25], u'AN': 4, u'homozygote_count': [1, 0]}, - u'x15': {u'a': 5, u'c': {u'banana': u'apple'}, u'b': u'foo'}, - u'x10': {u'min': 3.0, u'max': 13.0, u'sum': 16.0, u'stdev': 5.0, u'n': 2, u'mean': 8.0}, - u'x8': 1, u'x9': 0.0, u'x16': u'apple', - u'x11': {u'het_freq_hwe': 0.5, u'p_value': 0.5}, - u'x2': [3, 4, 13, 14], u'x3': 3, u'x1': [6, 26], u'x6': 39, u'x7': 2, u'x4': 13, u'x5': 16, - u'x17': [], - u'x18': [], - u'x19': [hl.Call([0, 1])]} + x19=agg.take(kt.GT, 1, ordering=-kt.qPheno), + ) + .take(1)[0] + ) + + expected = { + 'status': 0, + 'x13': {'n_called': 2, 'expected_homs': 1.64, 'f_stat': -1.777777777777777, 'observed_homs': 1}, + 'x14': {'AC': [3, 1], 'AF': [0.75, 0.25], 'AN': 4, 
'homozygote_count': [1, 0]}, + 'x15': {'a': 5, 'c': {'banana': 'apple'}, 'b': 'foo'}, + 'x10': {'min': 3.0, 'max': 13.0, 'sum': 16.0, 'stdev': 5.0, 'n': 2, 'mean': 8.0}, + 'x8': 1, + 'x9': 0.0, + 'x16': 'apple', + 'x11': {'het_freq_hwe': 0.5, 'p_value': 0.5}, + 'x2': [3, 4, 13, 14], + 'x3': 3, + 'x1': [6, 26], + 'x6': 39, + 'x7': 2, + 'x4': 13, + 'x5': 16, + 'x17': [], + 'x18': [], + 'x19': [hl.Call([0, 1])], + } self.maxDiff = None @@ -180,10 +218,14 @@ def test_aggregate2(self): def test_aggregate_ir(self): kt = hl.utils.range_table(10).annotate_globals(g1=5) - r = kt.aggregate(hl.struct(x=agg.sum(kt.idx) + kt.g1, - y=agg.filter(kt.idx % 2 != 0, agg.sum(kt.idx + 2)) + kt.g1, - z=agg.sum(kt.g1 + kt.idx) + kt.g1)) - self.assertEqual(convert_struct_to_dict(r), {u'x': 50, u'y': 40, u'z': 100}) + r = kt.aggregate( + hl.struct( + x=agg.sum(kt.idx) + kt.g1, + y=agg.filter(kt.idx % 2 != 0, agg.sum(kt.idx + 2)) + kt.g1, + z=agg.sum(kt.g1 + kt.idx) + kt.g1, + ) + ) + self.assertEqual(convert_struct_to_dict(r), {'x': 50, 'y': 40, 'z': 100}) r = kt.aggregate(5) self.assertEqual(r, 5) @@ -196,7 +238,7 @@ def test_aggregate_ir(self): def test_java_array_string_encoding(self): ht = hl.utils.range_table(10) - ht = ht.annotate(foo = hl.str(ht.idx).split(",")) + ht = ht.annotate(foo=hl.str(ht.idx).split(",")) path = new_temp_file(extension='ht') ht.write(path) hl.read_table(path)._force_count() @@ -204,10 +246,10 @@ def test_java_array_string_encoding(self): def test_to_matrix_table(self): N, M = 50, 50 mt = hl.utils.range_matrix_table(N, M) - mt = mt.key_cols_by(s = 'Col' + hl.str(M - mt.col_idx)) - mt = mt.annotate_cols(c1 = hl.rand_bool(0.5)) - mt = mt.annotate_rows(r1 = hl.rand_bool(0.5)) - mt = mt.annotate_entries(e1 = hl.rand_bool(0.5)) + mt = mt.key_cols_by(s='Col' + hl.str(M - mt.col_idx)) + mt = mt.annotate_cols(c1=hl.rand_bool(0.5)) + mt = mt.annotate_rows(r1=hl.rand_bool(0.5)) + mt = mt.annotate_entries(e1=hl.rand_bool(0.5)) re_mt = mt.entries().to_matrix_table(['row_idx'], ['s'], ['r1'], ['col_idx', 'c1']) new_col_order = re_mt.col_idx.collect() @@ -220,7 +262,9 @@ def test_to_matrix_table_row_major(self): t = t.annotate(foo=t.idx, bar=2 * t.idx, baz=3 * t.idx) mt = t.to_matrix_table_row_major(['bar', 'baz'], 'entry', 'col') round_trip = mt.localize_entries('entries', 'cols') - round_trip = round_trip.transmute(**{col.col: round_trip.entries[i].entry for i, col in enumerate(hl.eval(round_trip.cols))}) + round_trip = round_trip.transmute(**{ + col.col: round_trip.entries[i].entry for i, col in enumerate(hl.eval(round_trip.cols)) + }) round_trip = round_trip.drop(round_trip.cols) self.assertTrue(t._same(round_trip)) @@ -229,7 +273,9 @@ def test_to_matrix_table_row_major(self): t = t.annotate(foo=t.idx, bar=hl.struct(val=2 * t.idx), baz=hl.struct(val=3 * t.idx)) mt = t.to_matrix_table_row_major(['bar', 'baz']) round_trip = mt.localize_entries('entries', 'cols') - round_trip = round_trip.transmute(**{col.col: round_trip.entries[i] for i, col in enumerate(hl.eval(round_trip.cols))}) + round_trip = round_trip.transmute(**{ + col.col: round_trip.entries[i] for i, col in enumerate(hl.eval(round_trip.cols)) + }) round_trip = round_trip.drop(round_trip.cols) self.assertTrue(t._same(round_trip)) @@ -242,43 +288,37 @@ def test_to_matrix_table_row_major(self): def test_group_by_field_lifetimes(self): ht = hl.utils.range_table(3) - ht2 = (ht.group_by(idx='100') - .aggregate(x=hl.agg.collect_as_set(ht.idx + 5))) - assert (ht2.all(ht2.x == hl.set({5, 6, 7}))) + ht2 = 
ht.group_by(idx='100').aggregate(x=hl.agg.collect_as_set(ht.idx + 5)) + assert ht2.all(ht2.x == hl.set({5, 6, 7})) def test_group_aggregate_by_key(self): ht = hl.utils.range_table(100, n_partitions=10) - r1 = ht.group_by(k = ht.idx % 5)._set_buffer_size(3).aggregate(n = hl.agg.count()) - r2 = ht.group_by(k = ht.idx // 20)._set_buffer_size(3).aggregate(n = hl.agg.count()) + r1 = ht.group_by(k=ht.idx % 5)._set_buffer_size(3).aggregate(n=hl.agg.count()) + r2 = ht.group_by(k=ht.idx // 20)._set_buffer_size(3).aggregate(n=hl.agg.count()) assert r1.all(r1.n == 20) assert r2.all(r2.n == 20) def test_aggregate_by_key_partitioning(self): - ht1 = hl.Table.parallelize([ - {'k': 'foo', 'b': 1}, - {'k': 'bar', 'b': 2}, - {'k': 'bar', 'b': 2}], + ht1 = hl.Table.parallelize( + [{'k': 'foo', 'b': 1}, {'k': 'bar', 'b': 2}, {'k': 'bar', 'b': 2}], hl.tstruct(k=hl.tstr, b=hl.tint32), - key='k') + key='k', + ) self.assertEqual( - set(ht1.group_by('k').aggregate(mean_b = hl.agg.mean(ht1.b)).collect()), - {hl.Struct(k='foo', mean_b=1.0), hl.Struct(k='bar', mean_b=2.0)}) + set(ht1.group_by('k').aggregate(mean_b=hl.agg.mean(ht1.b)).collect()), + {hl.Struct(k='foo', mean_b=1.0), hl.Struct(k='bar', mean_b=2.0)}, + ) @test_timeout(batch=6 * 60) def test_group_aggregate_na(self): ht = hl.utils.range_table(100, 8) ht = ht.key_by(k=hl.or_missing(ht.idx % 10 == 0, ht.idx % 4)) - expected = [ - hl.utils.Struct(k=0, n=5), - hl.utils.Struct(k=2, n=5), - hl.utils.Struct(k=None, n=90) - ] + expected = [hl.utils.Struct(k=0, n=5), hl.utils.Struct(k=2, n=5), hl.utils.Struct(k=None, n=90)] # test map side combine and shuffle aggregation assert ht.group_by(ht.k).aggregate(n=hl.agg.count()).collect() == expected - ht = ht.checkpoint(new_temp_file()) # test sorted aggregation assert ht.group_by(ht.k).aggregate(n=hl.agg.count()).collect() == expected @@ -286,9 +326,11 @@ def test_group_aggregate_na(self): def test_filter(self): schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr, f=hl.tarray(hl.tint32)) - rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]}, - {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []}, - {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}] + rows = [ + {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3]}, + {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': []}, + {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7]}, + ] kt = hl.Table.parallelize(rows, schema) @@ -303,12 +345,21 @@ def test_filter_missing(self): self.assertEqual(ht.filter(hl.missing(hl.tbool)).count(), 0) def test_transmute(self): - schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr, f=hl.tarray(hl.tint32), - g=hl.tstruct(x=hl.tbool, y=hl.tint32)) + schema = hl.tstruct( + a=hl.tint32, + b=hl.tint32, + c=hl.tint32, + d=hl.tint32, + e=hl.tstr, + f=hl.tarray(hl.tint32), + g=hl.tstruct(x=hl.tbool, y=hl.tint32), + ) - rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3], 'g': {'x': True, 'y': 2}}, - {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': [], 'g': {'x': True, 'y': 2}}, - {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7], 'g': None}] + rows = [ + {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3], 'g': {'x': True, 'y': 2}}, + {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': [], 'g': {'x': True, 'y': 2}}, + {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7], 'g': None}, + ] df = hl.Table.parallelize(rows, schema) df = df.transmute(h=df.a + df.b + df.c + df.g.y) @@ -323,17 
+374,26 @@ def test_transmute_globals(self): def test_transmute_key(self): ht = hl.utils.range_table(10) - self.assertEqual(ht.transmute(y = ht.idx + 2).row.dtype, hl.dtype('struct{idx: int32, y: int32}')) + self.assertEqual(ht.transmute(y=ht.idx + 2).row.dtype, hl.dtype('struct{idx: int32, y: int32}')) ht = ht.key_by() - self.assertEqual(ht.transmute(y = ht.idx + 2).row.dtype, hl.dtype('struct{y: int32}')) + self.assertEqual(ht.transmute(y=ht.idx + 2).row.dtype, hl.dtype('struct{y: int32}')) def test_select(self): - schema = hl.tstruct(a=hl.tint32, b=hl.tint32, c=hl.tint32, d=hl.tint32, e=hl.tstr, f=hl.tarray(hl.tint32), - g=hl.tstruct(x=hl.tbool, y=hl.tint32)) + schema = hl.tstruct( + a=hl.tint32, + b=hl.tint32, + c=hl.tint32, + d=hl.tint32, + e=hl.tstr, + f=hl.tarray(hl.tint32), + g=hl.tstruct(x=hl.tbool, y=hl.tint32), + ) - rows = [{'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3], 'g': {'x': True, 'y': 2}}, - {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': [], 'g': {'x': True, 'y': 2}}, - {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7], 'g': None}] + rows = [ + {'a': 4, 'b': 1, 'c': 3, 'd': 5, 'e': "hello", 'f': [1, 2, 3], 'g': {'x': True, 'y': 2}}, + {'a': 0, 'b': 5, 'c': 13, 'd': -1, 'e': "cat", 'f': [], 'g': {'x': True, 'y': 2}}, + {'a': 4, 'b': 2, 'c': 20, 'd': 3, 'e': "dog", 'f': [5, 6, 7], 'g': None}, + ] kt = hl.Table.parallelize(rows, schema) @@ -357,9 +417,11 @@ def test_select(self): def test_errors(self): schema = hl.tstruct(status=hl.tint32, gt=hl.tcall, qPheno=hl.tint32) - rows = [{'status': 0, 'gt': hl.Call([0, 0]), 'qPheno': 3}, - {'status': 0, 'gt': hl.Call([0, 1]), 'qPheno': 13}, - {'status': 1, 'gt': hl.Call([0, 1]), 'qPheno': 20}] + rows = [ + {'status': 0, 'gt': hl.Call([0, 0]), 'qPheno': 3}, + {'status': 0, 'gt': hl.Call([0, 1]), 'qPheno': 13}, + {'status': 1, 'gt': hl.Call([0, 1]), 'qPheno': 20}, + ] kt = hl.Table.parallelize(rows, schema) @@ -370,21 +432,21 @@ def f(): def test_scan_filter(self): ht = hl.utils.range_table(10, n_partitions=10) - ht = ht.annotate(x = hl.scan.count()) + ht = ht.annotate(x=hl.scan.count()) ht = ht.filter(ht.idx == 9) assert ht.x.collect() == [9] def test_scan_tail(self): ht = hl.utils.range_table(100, n_partitions=16) - ht = ht.annotate(x = hl.scan.count()) + ht = ht.annotate(x=hl.scan.count()) ht = ht.tail(30) assert ht.x.collect() == list(range(70, 100)) def test_semi_anti_join(self): ht = hl.utils.range_table(10) ht2 = ht.filter(ht.idx < 3) - ht_2k = ht.key_by(k1 = ht.idx, k2 = hl.str(ht.idx * 2)) - ht2_2k = ht2.key_by(k1 = ht2.idx, k2 = hl.str(ht2.idx * 2)) + ht_2k = ht.key_by(k1=ht.idx, k2=hl.str(ht.idx * 2)) + ht2_2k = ht2.key_by(k1=ht2.idx, k2=hl.str(ht2.idx * 2)) assert ht.semi_join(ht2).count() == 3 assert ht.anti_join(ht2).count() == 7 @@ -446,70 +508,85 @@ def test_interval_join(self): intervals = hl.utils.range_table(4) intervals = intervals.key_by(interval=hl.interval(intervals.idx * 10, intervals.idx * 10 + 5)) left = left.annotate(interval_matches=intervals.index(left.key)) - self.assertTrue(left.all(hl.case() - .when(left.idx % 10 < 5, left.interval_matches.idx == left.idx // 10) - .default(hl.is_missing(left.interval_matches)))) + self.assertTrue( + left.all( + hl.case() + .when(left.idx % 10 < 5, left.interval_matches.idx == left.idx // 10) + .default(hl.is_missing(left.interval_matches)) + ) + ) def test_interval_filter_unordered(self): ht = hl.utils.range_table(100) - ht1 = hl.filter_intervals(ht, - [ - hl.utils.Interval(hl.utils.Struct(idx=10), hl.utils.Struct(idx=30)), - 
hl.utils.Interval(hl.utils.Struct(idx=50), hl.utils.Struct(idx=60)), - ] - ) + ht1 = hl.filter_intervals( + ht, + [ + hl.utils.Interval(hl.utils.Struct(idx=10), hl.utils.Struct(idx=30)), + hl.utils.Interval(hl.utils.Struct(idx=50), hl.utils.Struct(idx=60)), + ], + ) assert ht1.count() == 30 - ht2 = hl.filter_intervals(ht1, - [ - hl.utils.Interval(hl.utils.Struct(idx=25), hl.utils.Struct(idx=35)), - hl.utils.Interval(hl.utils.Struct(idx=70), hl.utils.Struct(idx=80)), - ] - ) + ht2 = hl.filter_intervals( + ht1, + [ + hl.utils.Interval(hl.utils.Struct(idx=25), hl.utils.Struct(idx=35)), + hl.utils.Interval(hl.utils.Struct(idx=70), hl.utils.Struct(idx=80)), + ], + ) assert ht2.count() == 5 - ht3 = hl.filter_intervals(ht, - [ - hl.utils.Interval(hl.utils.Struct(idx=50), hl.utils.Struct(idx=60)), - hl.utils.Interval(hl.utils.Struct(idx=10), hl.utils.Struct(idx=30)), - ] - ) + ht3 = hl.filter_intervals( + ht, + [ + hl.utils.Interval(hl.utils.Struct(idx=50), hl.utils.Struct(idx=60)), + hl.utils.Interval(hl.utils.Struct(idx=10), hl.utils.Struct(idx=30)), + ], + ) assert ht3.count() == 30 - ht4 = hl.filter_intervals(ht3, - [ - hl.utils.Interval(hl.utils.Struct(idx=25), hl.utils.Struct(idx=35)), - ] - ) + ht4 = hl.filter_intervals( + ht3, + [ + hl.utils.Interval(hl.utils.Struct(idx=25), hl.utils.Struct(idx=35)), + ], + ) assert ht4.count() == 5 - @fails_service_backend() - @fails_local_backend() def test_interval_product_join(self): left = hl.utils.range_table(50, n_partitions=8) intervals = hl.utils.range_table(25) - intervals = intervals.key_by(interval=hl.interval( - 1 + (intervals.idx // 5) * 10 + (intervals.idx % 5), - (1 + intervals.idx // 5) * 10 - (intervals.idx % 5))) + intervals = intervals.key_by( + interval=hl.interval( + 1 + (intervals.idx // 5) * 10 + (intervals.idx % 5), (1 + intervals.idx // 5) * 10 - (intervals.idx % 5) + ) + ) intervals = intervals.annotate(i=intervals.idx % 5) left = left.annotate(interval_matches=intervals.index(left.key, all_matches=True)) - self.assertTrue(left.all(hl.sorted(left.interval_matches.map(lambda x: x.i)) - == hl.range(0, hl.min(left.idx % 10, 10 - left.idx % 10)))) + self.assertTrue( + left.all( + hl.sorted(left.interval_matches.map(lambda x: x.i)) + == hl.range(0, hl.min(left.idx % 10, 10 - left.idx % 10)) + ) + ) - @fails_service_backend() - @fails_local_backend() def test_interval_product_join_long_key(self): left = hl.utils.range_table(50, n_partitions=8) intervals = hl.utils.range_table(25) intervals = intervals.key_by( interval=hl.interval( - 1 + (intervals.idx // 5) * 10 + (intervals.idx % 5), - (1 + intervals.idx // 5) * 10 - (intervals.idx % 5)), - k2=1) + 1 + (intervals.idx // 5) * 10 + (intervals.idx % 5), (1 + intervals.idx // 5) * 10 - (intervals.idx % 5) + ), + k2=1, + ) intervals = intervals.checkpoint('/tmp/bar.ht', overwrite=True) intervals = intervals.annotate(i=intervals.idx % 5) intervals = intervals.key_by('interval') left = left.annotate(interval_matches=intervals.index(left.idx, all_matches=True)) - self.assertTrue(left.all(hl.sorted(left.interval_matches.map(lambda x: x.i)) - == hl.range(0, hl.min(left.idx % 10, 10 - left.idx % 10)))) + self.assertTrue( + left.all( + hl.sorted(left.interval_matches.map(lambda x: x.i)) + == hl.range(0, hl.min(left.idx % 10, 10 - left.idx % 10)) + ) + ) def test_join_with_empty(self): kt = hl.utils.range_table(10) @@ -533,38 +610,38 @@ def test_multiple_entry_joins(self): mt = hl.utils.range_matrix_table(4, 4) mt2 = hl.utils.range_matrix_table(4, 4) mt2 = mt2.annotate_entries(x=mt2.row_idx + 
mt2.col_idx) - mt.select_entries(a=mt2[mt.row_idx, mt.col_idx].x, - b=mt2[mt.row_idx, mt.col_idx].x) + mt.select_entries(a=mt2[mt.row_idx, mt.col_idx].x, b=mt2[mt.row_idx, mt.col_idx].x) @test_timeout(batch=8 * 60) def test_multi_way_zip_join(self): - d1 = [{"id": 0, "name": "a", "data": 0.0}, - {"id": 1, "name": "b", "data": 3.14}, - {"id": 2, "name": "c", "data": 2.78}] - d2 = [{"id": 0, "name": "d", "data": 1.1}, - {"id": 2, "name": "v", "data": 7.89}] - d3 = [{"id": 1, "name": "f", "data": 9.99}, - {"id": 2, "name": "g", "data": -1.0}, - {"id": 3, "name": "z", "data": 0.01}] + d1 = [ + {"id": 0, "name": "a", "data": 0.0}, + {"id": 1, "name": "b", "data": 3.14}, + {"id": 2, "name": "c", "data": 2.78}, + ] + d2 = [{"id": 0, "name": "d", "data": 1.1}, {"id": 2, "name": "v", "data": 7.89}] + d3 = [ + {"id": 1, "name": "f", "data": 9.99}, + {"id": 2, "name": "g", "data": -1.0}, + {"id": 3, "name": "z", "data": 0.01}, + ] s = hl.tstruct(id=hl.tint32, name=hl.tstr, data=hl.tfloat64) ts = [hl.Table.parallelize(r, schema=s, key='id') for r in [d1, d2, d3]] joined = hl.Table.multi_way_zip_join(ts, '__data', '__globals').drop('__globals') - dexpected = [{"id": 0, "__data": [{"name": "a", "data": 0.0}, - {"name": "d", "data": 1.1}, - None]}, - {"id": 1, "__data": [{"name": "b", "data": 3.14}, - None, - {"name": "f", "data": 9.99}]}, - {"id": 2, "__data": [{"name": "c", "data": 2.78}, - {"name": "v", "data": 7.89}, - {"name": "g", "data": -1.0}]}, - {"id": 3, "__data": [None, - None, - {"name": "z", "data": 0.01}]}] + dexpected = [ + {"id": 0, "__data": [{"name": "a", "data": 0.0}, {"name": "d", "data": 1.1}, None]}, + {"id": 1, "__data": [{"name": "b", "data": 3.14}, None, {"name": "f", "data": 9.99}]}, + { + "id": 2, + "__data": [{"name": "c", "data": 2.78}, {"name": "v", "data": 7.89}, {"name": "g", "data": -1.0}], + }, + {"id": 3, "__data": [None, None, {"name": "z", "data": 0.01}]}, + ] expected = hl.Table.parallelize( dexpected, schema=hl.tstruct(id=hl.tint32, __data=hl.tarray(hl.tstruct(name=hl.tstr, data=hl.tfloat64))), - key='id') + key='id', + ) self.assertTrue(expected._same(joined)) expected2 = expected.transmute(data=expected['__data']) @@ -578,10 +655,7 @@ def test_multi_way_zip_join_globals(self): t1 = hl.utils.range_table(1).annotate_globals(x=hl.missing(hl.tint32)) t2 = hl.utils.range_table(1).annotate_globals(x=5) t3 = hl.utils.range_table(1).annotate_globals(x=0) - expected = hl.struct(__globals=hl.array([ - hl.struct(x=hl.missing(hl.tint32)), - hl.struct(x=5), - hl.struct(x=0)])) + expected = hl.struct(__globals=hl.array([hl.struct(x=hl.missing(hl.tint32)), hl.struct(x=5), hl.struct(x=0)])) joined = hl.Table.multi_way_zip_join([t1, t2, t3], '__data', '__globals') self.assertEqual(hl.eval(joined.globals), hl.eval(expected)) @@ -596,25 +670,42 @@ def test_multi_way_zip_join_key_downcast2(self): vcf2 = hl.import_vcf(resource('gvcfs/HG00268.g.vcf.gz'), force_bgz=True, reference_genome='GRCh38') vcf1 = hl.import_vcf(resource('gvcfs/HG00096.g.vcf.gz'), force_bgz=True, reference_genome='GRCh38') vcfs = [vcf1.rows().key_by('locus'), vcf2.rows().key_by('locus')] - exp_count = (vcfs[0].count() + vcfs[1].count() - - vcfs[0].aggregate(hl.agg.count_where(hl.is_defined(vcfs[1][vcfs[0].locus])))) + exp_count = ( + vcfs[0].count() + + vcfs[1].count() + - vcfs[0].aggregate(hl.agg.count_where(hl.is_defined(vcfs[1][vcfs[0].locus]))) + ) ht = hl.Table.multi_way_zip_join(vcfs, 'data', 'new_globals') assert exp_count == ht._force_count() + def 
test_multi_way_zip_join_highly_unbalanced_partitions__issue_14245(self): + def import_vcf(file: str, partitions: int): + return ( + hl.import_vcf(file, force_bgz=True, reference_genome='GRCh38', min_partitions=partitions) + .rows() + .select() + ) + + hl.Table.multi_way_zip_join( + [ + import_vcf(resource('gvcfs/HG00096.g.vcf.gz'), 100), + import_vcf(resource('gvcfs/HG00268.g.vcf.gz'), 1), + ], + 'data', + 'new_globals', + ).write(new_temp_file(extension='ht')) + def test_index_maintains_count(self): - t1 = hl.Table.parallelize([ - {'a': 'foo', 'b': 1}, - {'a': 'bar', 'b': 2}, - {'a': 'bar', 'b': 2}], + t1 = hl.Table.parallelize( + [{'a': 'foo', 'b': 1}, {'a': 'bar', 'b': 2}, {'a': 'bar', 'b': 2}], hl.tstruct(a=hl.tstr, b=hl.tint32), - key='a') - t2 = hl.Table.parallelize([ - {'t': 'foo', 'x': 3.14}, - {'t': 'bar', 'x': 2.78}, - {'t': 'bar', 'x': -1}, - {'t': 'quam', 'x': 0}], + key='a', + ) + t2 = hl.Table.parallelize( + [{'t': 'foo', 'x': 3.14}, {'t': 'bar', 'x': 2.78}, {'t': 'bar', 'x': -1}, {'t': 'quam', 'x': 0}], hl.tstruct(t=hl.tstr, x=hl.tfloat64), - key='t') + key='t', + ) j = t1.annotate(f=t2[t1.a].x) self.assertEqual(j.count(), t1.count()) @@ -630,7 +721,7 @@ def test_aggregation_with_no_aggregators(self): def test_drop(self): kt = hl.utils.range_table(10) - kt = kt.annotate(sq=kt.idx ** 2, foo='foo', bar='bar').key_by('foo') + kt = kt.annotate(sq=kt.idx**2, foo='foo', bar='bar').key_by('foo') ktd = kt.drop('idx') self.assertEqual(set(ktd.row), {'foo', 'sq', 'bar'}) @@ -665,9 +756,9 @@ def test_weird_names(self): df.group_by(**{'*``81': df.a}).aggregate(c=agg.count()) def test_sample(self): - kt = hl.utils.range_table(10) - kt_small = kt.sample(0.01) - self.assertTrue(kt_small.count() < kt.count()) + kt = hl.utils.range_table(16) + kt_small = kt.sample(0.25, seed=0) + self.assertEqual(4, kt_small.count()) @skip_unless_spark_backend() def test_from_spark_works(self): @@ -692,16 +783,21 @@ def test_from_pandas_works(self): def test_from_pandas_objects(self): import numpy as np - d = {'a': [[1, 2], [3, 4]], 'b': [{'a': 22, 'b': 21}, {'a': 23, 'b': 23}], 'c': - [np.array([np.array([1], dtype=np.int32), np.array([1], dtype=np.int32)]), - np.array([np.array([2], dtype=np.int32), np.array([2], dtype=np.int32)])]} + d = { + 'a': [[1, 2], [3, 4]], + 'b': [{'a': 22, 'b': 21}, {'a': 23, 'b': 23}], + 'c': [ + np.array([np.array([1], dtype=np.int32), np.array([1], dtype=np.int32)]), + np.array([np.array([2], dtype=np.int32), np.array([2], dtype=np.int32)]), + ], + } df = pd.DataFrame(data=d) t = hl.Table.from_pandas(df) - d2 = [hl.struct(a=hl.array([1, 2]), b=hl.literal({'a': 22, 'b': 21}), - c=hl.nd.array([[1], [1]])), - hl.struct(a=hl.array([3, 4]), b=hl.literal({'a': 23, 'b': 23}), - c=hl.nd.array([[2], [2]]))] + d2 = [ + hl.struct(a=hl.array([1, 2]), b=hl.literal({'a': 22, 'b': 21}), c=hl.nd.array([[1], [1]])), + hl.struct(a=hl.array([3, 4]), b=hl.literal({'a': 23, 'b': 23}), c=hl.nd.array([[2], [2]])), + ] t2 = hl.Table.parallelize(d2) self.assertTrue(t._same(t2)) @@ -712,7 +808,7 @@ def test_from_pandas_missing_and_nans(self): "x": pd.Series([None, 1, 2, None, 4], dtype=pd.Int64Dtype()), "y": pd.Series([None, 1, 2, None, 4], dtype=pd.Int32Dtype()), "z": pd.Series([np.nan, 1.0, 3.0, 4.0, np.nan]), - "s": pd.Series([None, "cat", None, "fox", "dog"], dtype=pd.StringDtype()) + "s": pd.Series([None, "cat", None, "fox", "dog"], dtype=pd.StringDtype()), }) ht = hl.Table.from_pandas(df) collected = ht.collect() @@ -734,26 +830,56 @@ def test_from_pandas_mismatched_object_rows(self): def 
test_table_parallelize_infer_types(self): import numpy as np + a = hl.array([{"b": 1, "c": "d"}, {"b": 1, "c": "d"}]) d = hl.array([[3, 4, 5], [1, 2, 3]]) e = hl.array([{"a": 1, "b": 2}, {"a": 3, "b": 4}]) - f = hl.array([.01, .00000002]) + f = hl.array([0.01, 0.00000002]) g = hl.array([(True, False), (False, True)]) h = hl.array([np.array([1, 2, 3]), np.array([3, 4, 5])]) i = hl.array([hl.Call([0, 0]), hl.Call([0, 1])]) j = hl.array([hl.locus('20', 17434581), hl.locus('19', 15434581)]) k = hl.array([hl.struct(a=1, b="2"), hl.struct(a=3, b="5")]) - data = [{"idx": 0, "a": {"b": 1, "c": "d"}, "d": [3, 4, 5], "e": {"a": 1, "b": 2}, "f": .01, - "g": (True, False), "h": np.array([1, 2, 3]), "i": hl.Call([0, 0]), "j": hl.locus('20', 17434581), - "k": hl.struct(a=1, b="2")}, - {"idx": 1, "a": {"b": 1, "c": "d"}, "d": [1, 2, 3], "e": {"a": 3, "b": 4}, "f": .00000002, - "g": (False, True), "h": np.array([3, 4, 5]), "i": hl.Call([0, 1]), "j": hl.locus('19', 15434581), - "k": hl.struct(a=3, b="5")}] + data = [ + { + "idx": 0, + "a": {"b": 1, "c": "d"}, + "d": [3, 4, 5], + "e": {"a": 1, "b": 2}, + "f": 0.01, + "g": (True, False), + "h": np.array([1, 2, 3]), + "i": hl.Call([0, 0]), + "j": hl.locus('20', 17434581), + "k": hl.struct(a=1, b="2"), + }, + { + "idx": 1, + "a": {"b": 1, "c": "d"}, + "d": [1, 2, 3], + "e": {"a": 3, "b": 4}, + "f": 0.00000002, + "g": (False, True), + "h": np.array([3, 4, 5]), + "i": hl.Call([0, 1]), + "j": hl.locus('19', 15434581), + "k": hl.struct(a=3, b="5"), + }, + ] table = hl.Table.parallelize(data, key='idx') ht = hl.utils.range_table(2) - ht = ht.annotate(a=hl.struct(b=a[ht.idx]['b'], c=a[ht.idx]['c']), d=d[ht.idx], e=e[ht.idx], f=f[ht.idx] - , g=g[ht.idx], h=h[ht.idx], i=i[ht.idx], j=j[ht.idx], k=k[ht.idx]) + ht = ht.annotate( + a=hl.struct(b=a[ht.idx]['b'], c=a[ht.idx]['c']), + d=d[ht.idx], + e=e[ht.idx], + f=f[ht.idx], + g=g[ht.idx], + h=h[ht.idx], + i=i[ht.idx], + j=j[ht.idx], + k=k[ht.idx], + ) self.assertTrue(table._same(ht)) @@ -761,8 +887,10 @@ def test_table_parallelize_partial_infer_types(self): b = hl.array([{"c": {1, 2, 3}, "d": {3, 4, 5}}, {"c": {6, 7, 8}, "d": {9, 10, 11}}]) e = hl.array([[[3], [4], [5]], [[1], [2], [3]]]) f = hl.array([hl.struct(a=1, b=2), hl.struct(a=3, b=4)]) - data = [{"idx": 0, "b": {"c": {1, 2, 3}, "d": {3, 4, 5}}, "e": [[3], [4], [5]], "f": {"a": 1, "b": 2}}, - {"idx": 1, "b": {"c": {6, 7, 8}, "d": {9, 10, 11}}, "e": [[1], [2], [3]], "f": {"a": 3, "b": 4}}] + data = [ + {"idx": 0, "b": {"c": {1, 2, 3}, "d": {3, 4, 5}}, "e": [[3], [4], [5]], "f": {"a": 1, "b": 2}}, + {"idx": 1, "b": {"c": {6, 7, 8}, "d": {9, 10, 11}}, "e": [[1], [2], [3]], "f": {"a": 3, "b": 4}}, + ] partial_type = {"idx": hl.tint32, "f": hl.tstruct(a=hl.tint32, b=hl.tint32)} table = hl.Table.parallelize(data, partial_type=partial_type, key='idx') ht = hl.utils.range_table(2) @@ -771,7 +899,7 @@ def test_table_parallelize_partial_infer_types(self): self.assertTrue(table._same(ht)) def test_table_parallelize_error_both_schema_partial_type_defined(self): - data= [{"a": 1, "b": "a"}, {"a": 2, "b": "c"}] + data = [{"a": 1, "b": "a"}, {"a": 2, "b": "c"}] schema = 'array' partial_type = {"a": hl.tint32} @@ -806,22 +934,25 @@ def test_rename(self): kt.rename({'hello': 'a'}) def test_distinct(self): - t1 = hl.Table.parallelize([ - {'a': 'foo', 'b': 1}, - {'a': 'bar', 'b': 2}, - {'a': 'bar', 'b': 2}, - {'a': 'bar', 'b': 3}, - {'a': 'bar', 'b': 3}, - {'a': 'baz', 'b': 2}, - {'a': 'baz', 'b': 0}, - {'a': 'baz', 'b': 0}, - {'a': 'foo', 'b': 0}, - {'a': '1', 'b': 0}, - {'a': '2', 
'b': 0}, - {'a': '3', 'b': 0}], + t1 = hl.Table.parallelize( + [ + {'a': 'foo', 'b': 1}, + {'a': 'bar', 'b': 2}, + {'a': 'bar', 'b': 2}, + {'a': 'bar', 'b': 3}, + {'a': 'bar', 'b': 3}, + {'a': 'baz', 'b': 2}, + {'a': 'baz', 'b': 0}, + {'a': 'baz', 'b': 0}, + {'a': 'foo', 'b': 0}, + {'a': '1', 'b': 0}, + {'a': '2', 'b': 0}, + {'a': '3', 'b': 0}, + ], hl.tstruct(a=hl.tstr, b=hl.tint32), key='a', - n_partitions=4) + n_partitions=4, + ) dist = t1.distinct().collect_by_key() self.assertTrue(dist.all(hl.len(dist.values) == 1)) @@ -829,29 +960,31 @@ def test_distinct(self): @test_timeout(batch=6 * 60) def test_group_by_key(self): - t1 = hl.Table.parallelize([ - {'a': 'foo', 'b': 1}, - {'a': 'bar', 'b': 2}, - {'a': 'bar', 'b': 2}, - {'a': 'bar', 'b': 3}, - {'a': 'bar', 'b': 3}, - {'a': 'baz', 'b': 2}, - {'a': 'baz', 'b': 0}, - {'a': 'baz', 'b': 0}, - {'a': 'foo', 'b': 0}, - {'a': '1', 'b': 0}, - {'a': '2', 'b': 0}, - {'a': '3', 'b': 0}], + t1 = hl.Table.parallelize( + [ + {'a': 'foo', 'b': 1}, + {'a': 'bar', 'b': 2}, + {'a': 'bar', 'b': 2}, + {'a': 'bar', 'b': 3}, + {'a': 'bar', 'b': 3}, + {'a': 'baz', 'b': 2}, + {'a': 'baz', 'b': 0}, + {'a': 'baz', 'b': 0}, + {'a': 'foo', 'b': 0}, + {'a': '1', 'b': 0}, + {'a': '2', 'b': 0}, + {'a': '3', 'b': 0}, + ], hl.tstruct(a=hl.tstr, b=hl.tint32), key='a', - n_partitions=4) + n_partitions=4, + ) g = t1.collect_by_key().explode('values') g = g.transmute(**g.values) self.assertTrue(g._same(t1)) def test_str_annotation_regression(self): - t = hl.Table.parallelize([{'alleles': ['A', 'T']}], - hl.tstruct(alleles=hl.tarray(hl.tstr))) + t = hl.Table.parallelize([{'alleles': ['A', 'T']}], hl.tstruct(alleles=hl.tarray(hl.tstr))) t = t.annotate(ref=t.alleles[0]) t._force_count() @@ -884,10 +1017,10 @@ def test_explode_on_set(self): t = hl.utils.range_table(1) t = t.annotate(a=hl.set(['a', 'b', 'c'])) t = t.explode('a') - self.assertEqual(set(t.collect()), - hl.eval(hl.set([hl.struct(idx=0, a='a'), - hl.struct(idx=0, a='b'), - hl.struct(idx=0, a='c')]))) + self.assertEqual( + set(t.collect()), + hl.eval(hl.set([hl.struct(idx=0, a='a'), hl.struct(idx=0, a='b'), hl.struct(idx=0, a='c')])), + ) def test_explode_nested(self): t = hl.utils.range_table(2) @@ -910,7 +1043,7 @@ def test_export(self): assert f_in.read() == 'idx\tfoo\n0\t3\n' def test_export_delim(self): - t = hl.utils.range_table(1).annotate(foo = 3) + t = hl.utils.range_table(1).annotate(foo=3) tmp_file = new_temp_file() t.export(tmp_file, delimiter=',') @@ -927,8 +1060,7 @@ def test_export_parallel_manifest(self): with fs.open(f'{tmp_file}/shard-manifest.txt') as lines: manifest_files = [os.path.join(tmp_file, line.strip()) for line in lines] - ht2 = hl.import_table(manifest_files, - types={'idx': hl.tint32}) + ht2 = hl.import_table(manifest_files, types={'idx': hl.tint32}) assert ht2.collect() == values tmp_file2 = new_temp_file() @@ -936,8 +1068,7 @@ def test_export_parallel_manifest(self): with fs.open(f'{tmp_file2}/shard-manifest.txt') as lines: manifest_files = [os.path.join(tmp_file2, line.strip()) for line in lines] - ht3 = hl.import_table(manifest_files, - types={'idx': hl.tint32}) + ht3 = hl.import_table(manifest_files, types={'idx': hl.tint32}) assert ht3.collect() == values def test_write_stage_locally(self): @@ -966,10 +1097,14 @@ def test_read_back_same_as_exported(self): def test_indexed_read_1(self): t1 = hl.read_table(resource('range-table-2000-with-10-parts.ht'), _create_row_uids=True) - t2 = hl.read_table(resource('range-table-2000-with-10-parts.ht'), _intervals=[ - hl.Interval(start=150, 
end=250, includes_start=True, includes_end=False), - hl.Interval(start=250, end=500, includes_start=True, includes_end=False), - ], _create_row_uids=True) + t2 = hl.read_table( + resource('range-table-2000-with-10-parts.ht'), + _intervals=[ + hl.Interval(start=150, end=250, includes_start=True, includes_end=False), + hl.Interval(start=250, end=500, includes_start=True, includes_end=False), + ], + _create_row_uids=True, + ) self.assertEqual(t2.n_partitions(), 2) self.assertEqual(t2.count(), 350) self.assertEqual(t2._force_count(), 350) @@ -977,19 +1112,28 @@ def test_indexed_read_1(self): def test_indexed_read_2(self): t1 = hl.read_table(resource('range-table-2000-with-10-parts.ht'), _create_row_uids=True) - t2 = hl.read_table(resource('range-table-2000-with-10-parts.ht'), _intervals=[ - hl.Interval(start=150, end=250, includes_start=True, includes_end=False), - hl.Interval(start=250, end=500, includes_start=True, includes_end=False), - ], _filter_intervals=True, _create_row_uids=True) + t2 = hl.read_table( + resource('range-table-2000-with-10-parts.ht'), + _intervals=[ + hl.Interval(start=150, end=250, includes_start=True, includes_end=False), + hl.Interval(start=250, end=500, includes_start=True, includes_end=False), + ], + _filter_intervals=True, + _create_row_uids=True, + ) self.assertEqual(t2.n_partitions(), 3) self.assertTrue(t1.filter((t1.idx >= 150) & (t1.idx < 500))._same(t2)) def test_indexed_read_3(self): t1 = hl.read_table(resource('range-table-2000-with-10-parts.ht'), _create_row_uids=True) - t2 = hl.read_table(resource('range-table-2000-with-10-parts.ht'), _intervals=[ - hl.Interval(start=150, end=250, includes_start=False, includes_end=True), - hl.Interval(start=250, end=500, includes_start=False, includes_end=True), - ], _create_row_uids=True) + t2 = hl.read_table( + resource('range-table-2000-with-10-parts.ht'), + _intervals=[ + hl.Interval(start=150, end=250, includes_start=False, includes_end=True), + hl.Interval(start=250, end=500, includes_start=False, includes_end=True), + ], + _create_row_uids=True, + ) self.assertEqual(t2.n_partitions(), 2) self.assertEqual(t2.count(), 350) self.assertEqual(t2._force_count(), 350) @@ -997,15 +1141,20 @@ def test_indexed_read_3(self): def test_indexed_read_4(self): t1 = hl.read_table(resource('range-table-2000-with-10-parts.ht'), _create_row_uids=True) - t2 = hl.read_table(resource('range-table-2000-with-10-parts.ht'), _intervals=[ - hl.Interval(start=150, end=250, includes_start=False, includes_end=True), - hl.Interval(start=250, end=500, includes_start=False, includes_end=True), - ], _filter_intervals=True, _create_row_uids=True) + t2 = hl.read_table( + resource('range-table-2000-with-10-parts.ht'), + _intervals=[ + hl.Interval(start=150, end=250, includes_start=False, includes_end=True), + hl.Interval(start=250, end=500, includes_start=False, includes_end=True), + ], + _filter_intervals=True, + _create_row_uids=True, + ) self.assertEqual(t2.n_partitions(), 3) self.assertTrue(t1.filter((t1.idx > 150) & (t1.idx <= 500))._same(t2)) def test_order_by_parsing(self): - hl.utils.range_table(1).annotate(**{'a b c' : 5}).order_by('a b c')._force_count() + hl.utils.range_table(1).annotate(**{'a b c': 5}).order_by('a b c')._force_count() def test_take_order(self): t = hl.utils.range_table(20, n_partitions=2) @@ -1020,35 +1169,28 @@ def test_filter_partitions(self): self.assertEqual(ht._filter_partitions(range(3)).n_partitions(), 3) self.assertEqual(ht._filter_partitions([4, 5, 7], keep=False).n_partitions(), 5) self.assertTrue( - 
ht._same(hl.Table.union( - ht._filter_partitions([0, 3, 7]), - ht._filter_partitions([0, 3, 7], keep=False)))) + ht._same(hl.Table.union(ht._filter_partitions([0, 3, 7]), ht._filter_partitions([0, 3, 7], keep=False))) + ) # ht = [0, 1, 2], [3, 4, 5], ..., [21, 22] - self.assertEqual( - ht._filter_partitions([0, 7]).idx.collect(), - [0, 1, 2, 21, 22]) + self.assertEqual(ht._filter_partitions([0, 7]).idx.collect(), [0, 1, 2, 21, 22]) def test_localize_entries(self): - ref_schema = hl.tstruct(row_idx=hl.tint32, - __entries=hl.tarray(hl.tstruct(v=hl.tint32))) - ref_data = [{'row_idx': i, '__entries': [{'v': i+j} for j in range(6)]} - for i in range(8)] + ref_schema = hl.tstruct(row_idx=hl.tint32, __entries=hl.tarray(hl.tstruct(v=hl.tint32))) + ref_data = [{'row_idx': i, '__entries': [{'v': i + j} for j in range(6)]} for i in range(8)] ref_tab = hl.Table.parallelize(ref_data, ref_schema).key_by('row_idx') ref_tab = ref_tab.select_globals(__cols=[hl.struct(col_idx=i) for i in range(6)]) mt = hl.utils.range_matrix_table(8, 6) - mt = mt.annotate_entries(v=mt.row_idx+mt.col_idx) + mt = mt.annotate_entries(v=mt.row_idx + mt.col_idx) t = mt._localize_entries('__entries', '__cols') self.assertTrue(t._same(ref_tab)) def test_localize_self_join(self): - ref_schema = hl.tstruct(row_idx=hl.tint32, - __entries=hl.tarray(hl.tstruct(v=hl.tint32))) - ref_data = [{'row_idx': i, '__entries': [{'v': i+j} for j in range(6)]} - for i in range(8)] + ref_schema = hl.tstruct(row_idx=hl.tint32, __entries=hl.tarray(hl.tstruct(v=hl.tint32))) + ref_data = [{'row_idx': i, '__entries': [{'v': i + j} for j in range(6)]} for i in range(8)] ref_tab = hl.Table.parallelize(ref_data, ref_schema).key_by('row_idx') ref_tab = ref_tab.join(ref_tab, how='outer') mt = hl.utils.range_matrix_table(8, 6) - mt = mt.annotate_entries(v=mt.row_idx+mt.col_idx) + mt = mt.annotate_entries(v=mt.row_idx + mt.col_idx) t = mt._localize_entries('__entries', '__cols').drop('__cols') t = t.join(t, how='outer') self.assertTrue(t._same(ref_tab)) @@ -1108,14 +1250,14 @@ def test_union_unify(self): def test_table_order_by_head_rewrite(self): rt = hl.utils.range_table(10, 2) - rt = rt.annotate(x = 10 - rt.idx) + rt = rt.annotate(x=10 - rt.idx) expected = list(range(10))[::-1] self.assertEqual(rt.order_by('x').idx.take(10), expected) self.assertEqual(rt.order_by('x').idx.collect(), expected) def test_order_by_expr(self): ht = hl.utils.range_table(10, 3) - ht = ht.annotate(xs = hl.range(0, 1).map(lambda x: hl.int(hl.rand_unif(0, 100)))) + ht = ht.annotate(xs=hl.range(0, 1).map(lambda x: hl.int(hl.rand_unif(0, 100)))) asc = ht.order_by(ht.xs[0]) desc = ht.order_by(hl.desc(ht.xs[0])) @@ -1131,13 +1273,9 @@ def test_order_by_expr(self): def test_null_joins(self): tr = hl.utils.range_table(7, 1) - table1 = tr.key_by(new_key=hl.if_else((tr.idx == 3) | (tr.idx == 5), - hl.missing(hl.tint32), tr.idx), - key2=1) + table1 = tr.key_by(new_key=hl.if_else((tr.idx == 3) | (tr.idx == 5), hl.missing(hl.tint32), tr.idx), key2=1) table1 = table1.select(idx1=table1.idx) - table2 = tr.key_by(new_key=hl.if_else((tr.idx == 4) | (tr.idx == 6), - hl.missing(hl.tint32), tr.idx), - key2=1) + table2 = tr.key_by(new_key=hl.if_else((tr.idx == 4) | (tr.idx == 6), hl.missing(hl.tint32), tr.idx), key2=1) table2 = table2.select(idx2=table2.idx) left_join = table1.join(table2, 'left') @@ -1148,21 +1286,41 @@ def test_null_joins(self): def row(new_key, idx1, idx2): return hl.Struct(new_key=new_key, key2=1, idx1=idx1, idx2=idx2) - left_join_expected = [row(0, 0, 0), row(1, 1, 1), row(2, 2, 
2), - row(4, 4, None), row(6, 6, None), - row(None, 3, None), row(None, 5, None)] + left_join_expected = [ + row(0, 0, 0), + row(1, 1, 1), + row(2, 2, 2), + row(4, 4, None), + row(6, 6, None), + row(None, 3, None), + row(None, 5, None), + ] - right_join_expected = [row(0, 0, 0), row(1, 1, 1), row(2, 2, 2), - row(3, None, 3), row(5, None, 5), - row(None, None, 4), row(None, None, 6)] + right_join_expected = [ + row(0, 0, 0), + row(1, 1, 1), + row(2, 2, 2), + row(3, None, 3), + row(5, None, 5), + row(None, None, 4), + row(None, None, 6), + ] inner_join_expected = [row(0, 0, 0), row(1, 1, 1), row(2, 2, 2)] - outer_join_expected = [row(0, 0, 0), row(1, 1, 1), row(2, 2, 2), - row(3, None, 3), row(4, 4, None), - row(5, None, 5), row(6, 6, None), - row(None, 3, None), row(None, 5, None), - row(None, None, 4), row(None, None, 6)] + outer_join_expected = [ + row(0, 0, 0), + row(1, 1, 1), + row(2, 2, 2), + row(3, None, 3), + row(4, 4, None), + row(5, None, 5), + row(6, 6, None), + row(None, 3, None), + row(None, 5, None), + row(None, None, 4), + row(None, None, 6), + ] self.assertEqual(left_join.collect(), left_join_expected) self.assertEqual(right_join.collect(), right_join_expected) @@ -1171,13 +1329,13 @@ def row(new_key, idx1, idx2): def test_null_joins_2(self): tr = hl.utils.range_table(7, 1) - table1 = tr.key_by(new_key=hl.if_else((tr.idx == 3) | (tr.idx == 5), - hl.missing(hl.tint32), tr.idx), - key2=tr.idx) + table1 = tr.key_by( + new_key=hl.if_else((tr.idx == 3) | (tr.idx == 5), hl.missing(hl.tint32), tr.idx), key2=tr.idx + ) table1 = table1.select(idx1=table1.idx) - table2 = tr.key_by(new_key=hl.if_else((tr.idx == 4) | (tr.idx == 6), - hl.missing(hl.tint32), tr.idx), - key2=tr.idx) + table2 = tr.key_by( + new_key=hl.if_else((tr.idx == 4) | (tr.idx == 6), hl.missing(hl.tint32), tr.idx), key2=tr.idx + ) table2 = table2.select(idx2=table2.idx) left_join = table1.join(table2, 'left') @@ -1188,22 +1346,44 @@ def test_null_joins_2(self): def row(new_key, key2, idx1, idx2): return hl.Struct(new_key=new_key, key2=key2, idx1=idx1, idx2=idx2) - left_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2), - row(4, 4, 4, None), row(6, 6, 6, None), - row(None, 3, 3, None), row(None, 5, 5, None)] + left_join_expected = [ + row(0, 0, 0, 0), + row(1, 1, 1, 1), + row(2, 2, 2, 2), + row(4, 4, 4, None), + row(6, 6, 6, None), + row(None, 3, 3, None), + row(None, 5, 5, None), + ] - right_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2), - row(3, 3, None, 3), row(5, 5, None, 5), - row(None, 4, None, 4), row(None, 6, None, 6)] + right_join_expected = [ + row(0, 0, 0, 0), + row(1, 1, 1, 1), + row(2, 2, 2, 2), + row(3, 3, None, 3), + row(5, 5, None, 5), + row(None, 4, None, 4), + row(None, 6, None, 6), + ] inner_join_expected = [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2)] def check_outer(actual): - assert actual[:7] == [row(0, 0, 0, 0), row(1, 1, 1, 1), row(2, 2, 2, 2), - row(3, 3, None, 3), row(4, 4, 4, None), - row(5, 5, None, 5), row(6, 6, 6, None)] - assert set(actual[7:]) == {row(None, 3, 3, None), row(None, 4, None, 4), - row(None, 5, 5, None), row(None, 6, None, 6)} + assert actual[:7] == [ + row(0, 0, 0, 0), + row(1, 1, 1, 1), + row(2, 2, 2, 2), + row(3, 3, None, 3), + row(4, 4, 4, None), + row(5, 5, None, 5), + row(6, 6, 6, None), + ] + assert set(actual[7:]) == { + row(None, 3, 3, None), + row(None, 4, None, 4), + row(None, 5, 5, None), + row(None, 6, None, 6), + } self.assertEqual(left_join.collect(), left_join_expected) self.assertEqual(right_join.collect(), 
right_join_expected) @@ -1225,19 +1405,39 @@ def test_joins_one_null(self): def row(new_key, idx1, idx2): return hl.Struct(new_key=new_key, idx1=idx1, idx2=idx2) - left_join_expected = [row(0, 0, 0), row(1, 1, 1), row(2, 2, 2), row(3, 3, 3), - row(4, 4, None), row(5, 5, 5), row(6, 6, None)] + left_join_expected = [ + row(0, 0, 0), + row(1, 1, 1), + row(2, 2, 2), + row(3, 3, 3), + row(4, 4, None), + row(5, 5, 5), + row(6, 6, None), + ] - right_join_expected = [row(0, 0, 0), row(1, 1, 1), row(2, 2, 2), - row(3, 3, 3), row(5, 5, 5), - row(None, None, 4), row(None, None, 6)] + right_join_expected = [ + row(0, 0, 0), + row(1, 1, 1), + row(2, 2, 2), + row(3, 3, 3), + row(5, 5, 5), + row(None, None, 4), + row(None, None, 6), + ] inner_join_expected = [row(0, 0, 0), row(1, 1, 1), row(2, 2, 2), row(3, 3, 3), row(5, 5, 5)] - outer_join_expected = [row(0, 0, 0), row(1, 1, 1), row(2, 2, 2), - row(3, 3, 3), row(4, 4, None), - row(5, 5, 5), row(6, 6, None), - row(None, None, 4), row(None, None, 6)] + outer_join_expected = [ + row(0, 0, 0), + row(1, 1, 1), + row(2, 2, 2), + row(3, 3, 3), + row(4, 4, None), + row(5, 5, 5), + row(6, 6, None), + row(None, None, 4), + row(None, None, 6), + ] self.assertEqual(left_join.collect(), left_join_expected) self.assertEqual(right_join.collect(), right_join_expected) @@ -1290,21 +1490,21 @@ def test_partitioning_rewrite(self): def test_flatten(self): t1 = hl.utils.range_table(10) - t1 = t1.key_by(x = hl.struct(a=t1.idx, b=0)).flatten() + t1 = t1.key_by(x=hl.struct(a=t1.idx, b=0)).flatten() t2 = hl.utils.range_table(10).key_by() t2 = t2.annotate(**{'x.a': t2.idx, 'x.b': 0}) self.assertTrue(t1._same(t2)) def test_expand_types(self): t1 = hl.utils.range_table(10) - t1 = t1.key_by(x = hl.locus('1', t1.idx+1)).expand_types() + t1 = t1.key_by(x=hl.locus('1', t1.idx + 1)).expand_types() t2 = hl.utils.range_table(10).key_by() - t2 = t2.annotate(x=hl.struct(contig='1', position=t2.idx+1)) + t2 = t2.annotate(x=hl.struct(contig='1', position=t2.idx + 1)) self.assertTrue(t1._same(t2)) def test_expand_types_ordering(self): ht = hl.utils.range_table(10) - ht = ht.key_by(x = 9 - ht.idx) + ht = ht.key_by(x=9 - ht.idx) assert ht.expand_types().x.collect() == list(range(10)) def test_expand_types_on_all_types(self): @@ -1336,8 +1536,8 @@ def test_join_with_filter_intervals(self): def test_key_by_aggregate_rewriting(self): ht = hl.utils.range_table(10) - ht = ht.group_by(x=ht.idx % 5).aggregate(aggr = hl.agg.count()) - assert(ht.count() == 5) + ht = ht.group_by(x=ht.idx % 5).aggregate(aggr=hl.agg.count()) + assert ht.count() == 5 def test_field_method_assignment(self): ht = hl.utils.range_table(10) @@ -1348,14 +1548,16 @@ def test_field_method_assignment(self): def test_refs_with_process_joins(self): ht = hl.utils.range_table(10).annotate(foo=5) - ht.annotate(a_join=ht[ht.key], - a_literal=hl.literal(['a']), - the_row_failure=hl.if_else(True, ht.row, hl.missing(ht.row.dtype)), - the_global_failure=hl.if_else(True, ht.globals, hl.missing(ht.globals.dtype))).count() + ht.annotate( + a_join=ht[ht.key], + a_literal=hl.literal(['a']), + the_row_failure=hl.if_else(True, ht.row, hl.missing(ht.row.dtype)), + the_global_failure=hl.if_else(True, ht.globals, hl.missing(ht.globals.dtype)), + ).count() def test_aggregate_localize_false(self): ht = hl.utils.range_table(10) - ht = ht.annotate(y = ht.idx + ht.aggregate(hl.agg.max(ht.idx), _localize=False)) + ht = ht.annotate(y=ht.idx + ht.aggregate(hl.agg.max(ht.idx), _localize=False)) assert ht.y.collect() == [x + 9 for x in range(10)] def 
test_collect_localize_false(self): @@ -1378,19 +1580,15 @@ def test_expr_collect(self): fields = [1, 0, 3] t = t.annotate_globals(globe=globe) - t = t.annotate(k = hl.array(keys)[t.idx], - field = hl.array(fields)[t.idx]) + t = t.annotate(k=hl.array(keys)[t.idx], field=hl.array(fields)[t.idx]) t = t.key_by(t.k) - rows = [hl.Struct(k=k, field=field) - for k, field in zip(keys, fields)] + rows = [hl.Struct(k=k, field=field) for k, field in zip(keys, fields)] ordered_rows = sorted(rows, key=lambda x: x.k) assert t.globe.collect() == [globe] - assert t.row.collect() == sorted([hl.Struct(idx=i, **r) - for i, r in enumerate(rows)], - key=lambda x: x.k) + assert t.row.collect() == sorted([hl.Struct(idx=i, **r) for i, r in enumerate(rows)], key=lambda x: x.k) assert t.key.collect() == [hl.Struct(k=r.k) for r in ordered_rows] assert t.k.collect() == [r.k for r in ordered_rows] @@ -1413,33 +1611,31 @@ def test_same_equal(self): def test_same_within_tolerance(self): t = hl.utils.range_table(1) - t1 = t.annotate(x = 1.0) - t2 = t.annotate(x = 1.0 + 1e-7) + t1 = t.annotate(x=1.0) + t2 = t.annotate(x=1.0 + 1e-7) self.assertTrue(t1._same(t2)) def test_same_different_type(self): t1 = hl.utils.range_table(1) - t2 = t1.annotate_globals(x = 7) + t2 = t1.annotate_globals(x=7) self.assertFalse(t1._same(t2)) - t3 = t1.annotate(x = 7) + t3 = t1.annotate(x=7) self.assertFalse(t1._same(t3)) t4 = t1.key_by() self.assertFalse(t1._same(t4)) def test_same_different_global(self): - t1 = (hl.utils.range_table(1) - .annotate_globals(x = 7)) - t2 = t1.annotate_globals(x = 8) + t1 = hl.utils.range_table(1).annotate_globals(x=7) + t2 = t1.annotate_globals(x=8) self.assertFalse(t1._same(t2)) def test_same_different_rows(self): - t1 = (hl.utils.range_table(2) - .annotate(x = 7)) + t1 = hl.utils.range_table(2).annotate(x=7) - t2 = t1.annotate(x = 8) + t2 = t1.annotate(x=8) self.assertFalse(t1._same(t2)) t3 = t1.filter(t1.idx == 0) @@ -1452,7 +1648,7 @@ def test_rvd_key_write(self): ht1 = hl.read_table(tempfile) ht2 = hl.utils.range_table(1).annotate(foo='a') - assert ht2.annotate(x = ht1.key_by('foo')[ht2.foo])._force_count() == 1 + assert ht2.annotate(x=ht1.key_by('foo')[ht2.foo])._force_count() == 1 def test_show_long_field_names(self): hl.utils.range_table(1).annotate(**{'a' * 256: 5}).show() @@ -1460,20 +1656,22 @@ def test_show_long_field_names(self): def test_show__various_types(self): ht = hl.utils.range_table(1) ht = ht.annotate( - x1 = [1], - x2 = [hl.struct(y=[1])], - x3 = {1}, - x4 = {1: 'foo'}, - x5 = {hl.struct(foo=5): 'bar'}, - x6 = hl.tuple(()), - x7 = hl.tuple(('3',)), - x8 = hl.tuple(('3', 3)), - x9 = 4.2, - x10 = hl.dict({'hello': 3, 'bar': 5}), - x11 = (True, False) + x1=[1], + x2=[hl.struct(y=[1])], + x3={1}, + x4={1: 'foo'}, + x5={hl.struct(foo=5): 'bar'}, + x6=hl.tuple(()), + x7=hl.tuple(('3',)), + x8=hl.tuple(('3', 3)), + x9=4.2, + x10=hl.dict({'hello': 3, 'bar': 5}), + x11=(True, False), ) result = ht.show(handler=str) - assert result == '''+-------+--------------+--------------------------------+------------+ + assert ( + result + == """+-------+--------------+--------------------------------+------------+ | idx | x1 | x2 | x3 | +-------+--------------+--------------------------------+------------+ | int32 | array | array}> | set | @@ -1496,22 +1694,24 @@ def test_show__various_types(self): +-------------------+----------+---------------------+-------------------+ | ("3",3) | 4.20e+00 | {"bar":5,"hello":3} | (True,False) | +-------------------+----------+---------------------+-------------------+ -''' 
+""" + ) def test_import_filter_replace(self): def assert_filter_equals(filter, find_replace, to): - assert hl.import_table(resource('filter_replace.txt'), - filter=filter, - find_replace=find_replace)['HEADER1'].collect() == to + assert ( + hl.import_table(resource('filter_replace.txt'), filter=filter, find_replace=find_replace)[ + 'HEADER1' + ].collect() + == to + ) assert_filter_equals('Foo', None, ['(Baz),(Qux)(']) assert_filter_equals(None, (r',', ''), ['(Foo(Bar))', '(Baz)(Qux)(']) assert_filter_equals(None, (r'\((\w+)\)', '$1'), ['(Foo,Bar)', 'Baz,Qux(']) def test_import_multiple_missing(self): - ht = hl.import_table(resource('global_list.txt'), - missing=['gene1', 'gene2'], - no_header=True) + ht = hl.import_table(resource('global_list.txt'), missing=['gene1', 'gene2'], no_header=True) assert ht.f0.collect() == [None, None, 'gene5', 'gene4', 'gene3'] @@ -1550,25 +1750,26 @@ def test_path_collision_error(self): ht.write(path) assert "both an input and output source" in str(exc.value) + def test_large_number_of_fields(): ht = hl.utils.range_table(100) - ht = ht.annotate(**{ - str(k): k for k in range(1000) - }) + ht = ht.annotate(**{str(k): k for k in range(1000)}) with hl.TemporaryDirectory(ensure_exists=False) as f: assert_time(lambda: ht.count(), 5) assert_time(lambda: ht.write(str(f)), 5) ht = assert_time(lambda: hl.read_table(str(f)), 5) assert_time(lambda: ht.count(), 5) + def test_import_many_fields(): assert_time(lambda: hl.import_table(resource('many_cols.txt')), 5) + def test_segfault(): t = hl.utils.range_table(1) t2 = hl.utils.range_table(3) - t = t.annotate(foo = [0]) - t2 = t2.annotate(foo = [0]) + t = t.annotate(foo=[0]) + t2 = t2.annotate(foo=[0]) joined = t.key_by('foo').join(t2.key_by('foo')) joined = joined.filter(hl.is_missing(joined.idx)) assert joined.collect() == [] @@ -1614,7 +1815,7 @@ def test_maybe_flexindex_table_by_expr_prefix_match(): def test_maybe_flexindex_table_by_expr_direct_interval_match(): t1 = hl.utils.range_table(1) - t1 = t1.key_by(interval=hl.interval(t1.idx, t1.idx+1)) + t1 = t1.key_by(interval=hl.interval(t1.idx, t1.idx + 1)) t2 = hl.utils.range_table(1) match_key = t1._maybe_flexindex_table_by_expr(t2.key) t2.annotate(foo=match_key)._force_count() @@ -1634,7 +1835,7 @@ def test_maybe_flexindex_table_by_expr_direct_interval_match(): def test_maybe_flexindex_table_by_expr_prefix_interval_match(): t1 = hl.utils.range_table(1) - t1 = t1.key_by(interval=hl.interval(t1.idx, t1.idx+1)) + t1 = t1.key_by(interval=hl.interval(t1.idx, t1.idx + 1)) t2 = hl.utils.range_table(1) t2 = t2.key_by(idx=t2.idx, idx2=t2.idx) match_key = t1._maybe_flexindex_table_by_expr(t2.key) @@ -1668,7 +1869,6 @@ def create_width_scale_files(): def write_file(n, n_rows=5): assert n % 4 == 0 n2 = n // 4 - d = {} header = [] for i in range(n2): header.append(f'i{i}') @@ -1680,7 +1880,7 @@ def write_file(n, n_rows=5): for i in range(n_rows): out.write('\n') for j in range(n2): - if (j > 0): + if j > 0: out.write('\t') out.write(str(j)) out.write('\t') @@ -1690,6 +1890,7 @@ def write_file(n, n_rows=5): out.write('\t') out.write(str(i % 2 == 0)) + widths = [1 << k for k in range(8, 14)] for w in widths: write_file(w) @@ -1705,6 +1906,7 @@ def test_join_with_key_prefix(): assert t.aggregate(hl.agg.all(t.foo == 1)) assert t.n_partitions() == 2 + def test_join_distinct_preserves_count(): left_pos = [1, 2, 4, 4, 5, 5, 9, 13, 13, 14, 15] right_pos = [1, 1, 1, 3, 4, 4, 6, 6, 8, 9, 13, 15] @@ -1716,23 +1918,28 @@ def test_join_distinct_preserves_count(): assert keys == left_pos 
right_table_2 = hl.utils.range_table(1).filter(False) - joined_2 = left_table.annotate(r = right_table_2.index(left_table.i)) - n_defined_2, keys_2 = joined_2.aggregate((hl.agg.count_where(hl.is_defined(joined_2.r)), hl.agg.collect(joined_2.i))) + joined_2 = left_table.annotate(r=right_table_2.index(left_table.i)) + n_defined_2, keys_2 = joined_2.aggregate(( + hl.agg.count_where(hl.is_defined(joined_2.r)), + hl.agg.collect(joined_2.i), + )) assert n_defined_2 == 0 assert keys_2 == left_pos + def test_write_table_containing_ndarray(): t = hl.utils.range_table(5) - t = t.annotate(n = hl.nd.arange(t.idx)) + t = t.annotate(n=hl.nd.arange(t.idx)) f = new_temp_file(extension='ht') t.write(f) t2 = hl.read_table(f) assert t._same(t2) + @test_timeout(batch=6 * 60) def test_group_within_partitions(): t = hl.utils.range_table(10).repartition(2) - t = t.annotate(sq=t.idx ** 2) + t = t.annotate(sq=t.idx**2) grouped1_collected = t._group_within_partitions("grouped_fields", 1).collect() grouped2_collected = t._group_within_partitions("grouped_fields", 2).collect() @@ -1745,15 +1952,21 @@ def test_group_within_partitions(): assert len(grouped3_collected) == 4 assert len(grouped5_collected) == 2 assert grouped5_collected == grouped6_collected - assert grouped3_collected == [hl.Struct(idx=0, grouped_fields=[hl.Struct(idx=0, sq=0.0), hl.Struct(idx=1, sq=1.0), hl.Struct(idx=2, sq=4.0)]), - hl.Struct(idx=3, grouped_fields=[hl.Struct(idx=3, sq=9.0), hl.Struct(idx=4, sq=16.0)]), - hl.Struct(idx=5, grouped_fields=[hl.Struct(idx=5, sq=25.0), hl.Struct(idx=6, sq=36.0), hl.Struct(idx=7, sq=49.0)]), - hl.Struct(idx=8, grouped_fields=[hl.Struct(idx=8, sq=64.0), hl.Struct(idx=9, sq=81.0)])] + assert grouped3_collected == [ + hl.Struct(idx=0, grouped_fields=[hl.Struct(idx=0, sq=0.0), hl.Struct(idx=1, sq=1.0), hl.Struct(idx=2, sq=4.0)]), + hl.Struct(idx=3, grouped_fields=[hl.Struct(idx=3, sq=9.0), hl.Struct(idx=4, sq=16.0)]), + hl.Struct( + idx=5, grouped_fields=[hl.Struct(idx=5, sq=25.0), hl.Struct(idx=6, sq=36.0), hl.Struct(idx=7, sq=49.0)] + ), + hl.Struct(idx=8, grouped_fields=[hl.Struct(idx=8, sq=64.0), hl.Struct(idx=9, sq=81.0)]), + ] # Testing after a filter ht = hl.utils.range_table(100).naive_coalesce(10) filter_then_group = ht.filter(ht.idx % 2 == 0)._group_within_partitions("grouped_fields", 5).collect() - assert filter_then_group[0] == hl.Struct(idx=0, grouped_fields=[hl.Struct(idx=0), hl.Struct(idx=2), hl.Struct(idx=4), hl.Struct(idx=6), hl.Struct(idx=8)]) + assert filter_then_group[0] == hl.Struct( + idx=0, grouped_fields=[hl.Struct(idx=0), hl.Struct(idx=2), hl.Struct(idx=4), hl.Struct(idx=6), hl.Struct(idx=8)] + ) # Test that names other than "grouped_fields" work assert "foo" in t._group_within_partitions("foo", 1).collect()[0] @@ -1764,21 +1977,23 @@ def test_group_within_partitions_after_explode(): t = t.annotate(arr=hl.range(0, 20)) t = t.explode(t.arr) t = t._group_within_partitions("grouped_fields", 10) - assert(t._force_count() == 20) + assert t._force_count() == 20 + def test_group_within_partitions_after_import_vcf(): gt_mt = hl.import_vcf(resource('small-gt.vcf')) ht = gt_mt.rows() ht = ht._group_within_partitions("grouped_fields", 16) - ht.collect() # Just testing import without segault + ht.collect()  # Just testing import without segfault assert True def test_range_annotate_range(): # tests left join right distinct requiredness ht1 = hl.utils.range_table(10) - ht2 = 
hl.utils.range_table(5).annotate(x=1) + ht1.annotate(x=ht2[ht1.idx].x)._force_count() + @test_timeout(batch=5 * 60) def test_read_write_all_types(): @@ -1801,6 +2016,7 @@ def test_map_partitions_errors(): with pytest.raises(ValueError, match='must preserve key fields'): ht._map_partitions(lambda rows: rows.map(lambda r: r.drop('idx'))) + def test_map_partitions_indexed(): tmp_file = new_temp_file() hl.utils.range_table(100, 8).write(tmp_file) @@ -1808,12 +2024,13 @@ def test_map_partitions_indexed(): ht = ht.key_by()._map_partitions(lambda partition: hl.array([hl.struct(foo=partition.to_array())])._to_stream()) assert [inner.idx for outer in ht.foo.collect() for inner in outer] == list(range(11, 55)) + def test_keys_before_scans(): ht = hl.utils.range_table(6) - ht = ht.annotate(rev_idx = -ht.idx) + ht = ht.annotate(rev_idx=-ht.idx) ht = ht.key_by(ht.rev_idx) - ht = ht.annotate(idx_scan = hl.scan.collect(ht.idx)) + ht = ht.annotate(idx_scan=hl.scan.collect(ht.idx)) ht = ht.key_by(ht.idx) assert ht.idx_scan.collect() == [[5, 4, 3, 2, 1], [5, 4, 3, 2], [5, 4, 3], [5, 4], [5], []] @@ -1827,7 +2044,6 @@ def test_lowered_persist(): assert ht.filter(ht.idx == 55).count() == 1 - @qobtest @lower_only() def test_lowered_shuffle(): @@ -1835,6 +2051,7 @@ def test_lowered_shuffle(): ht = ht.order_by(-ht.idx) assert ht.aggregate(hl.agg.take(ht.idx, 3)) == [99, 98, 97] + def test_read_partitions(): ht = hl.utils.range_table(100, 3) path = new_temp_file() @@ -1865,33 +2082,31 @@ def test_interval_filter_partitions(): hl.Interval(hl.Struct(idx=5), hl.Struct(idx=10)), hl.Interval(hl.Struct(idx=12), hl.Struct(idx=13)), hl.Interval(hl.Struct(idx=15), hl.Struct(idx=17)), - hl.Interval(hl.Struct(idx=19), hl.Struct(idx=20)) + hl.Interval(hl.Struct(idx=19), hl.Struct(idx=20)), ] - assert hl.read_table(path, _intervals=intervals, _filter_intervals = True).n_partitions() == 1 + assert hl.read_table(path, _intervals=intervals, _filter_intervals=True).n_partitions() == 1 intervals = [ hl.Interval(hl.Struct(idx=5), hl.Struct(idx=10)), hl.Interval(hl.Struct(idx=12), hl.Struct(idx=13)), hl.Interval(hl.Struct(idx=15), hl.Struct(idx=17)), - hl.Interval(hl.Struct(idx=45), hl.Struct(idx=50)), hl.Interval(hl.Struct(idx=52), hl.Struct(idx=53)), hl.Interval(hl.Struct(idx=55), hl.Struct(idx=57)), - hl.Interval(hl.Struct(idx=75), hl.Struct(idx=80)), hl.Interval(hl.Struct(idx=82), hl.Struct(idx=83)), hl.Interval(hl.Struct(idx=85), hl.Struct(idx=87)), ] - assert hl.read_table(path, _intervals=intervals, _filter_intervals = True).n_partitions() == 3 - + assert hl.read_table(path, _intervals=intervals, _filter_intervals=True).n_partitions() == 3 def test_grouped_flatmap_streams(): ht = hl.import_vcf(resource('sample.vcf')).rows() ht = ht.annotate(x=hl.str(ht.locus)) # add a map node - ht = ht._map_partitions(lambda part: part.grouped(8).flatmap( - lambda group: group._to_stream().map(lambda x: x.annotate(z=1)))) + ht = ht._map_partitions( + lambda part: part.grouped(8).flatmap(lambda group: group._to_stream().map(lambda x: x.annotate(z=1))) + ) ht._force_count() @@ -1901,25 +2116,30 @@ def test(): if table_name == 'rt': table = hl.utils.range_table(10, n_partitions=num_parts) elif table_name == 'par': - table = hl.Table.parallelize([hl.Struct(x=x) for x in range(10)], schema='struct{x: int32}', - n_partitions=num_parts) + table = hl.Table.parallelize( + [hl.Struct(x=x) for x in range(10)], schema='struct{x: int32}', n_partitions=num_parts + ) elif table_name == 'rtcache': table = hl.utils.range_table(10, 
n_partitions=num_parts).cache() else: assert table_name == 'chkpt' table = hl.utils.range_table(10, n_partitions=num_parts).checkpoint(new_temp_file(extension='ht')) assert counter(truncator(table, n)) == min(10, n) + return test head_tail_test_data = [ - pytest.param(make_test(table_name, num_parts, counter, truncator, n), - id='__'.join([table_name, str(num_parts), str(n), truncator_name, counter_name])) + pytest.param( + make_test(table_name, num_parts, counter, truncator, n), + id='__'.join([table_name, str(num_parts), str(n), truncator_name, counter_name]), + ) for table_name in ['rt', 'par', 'rtcache', 'chkpt'] for num_parts in [3, 11] for n in (10, 9, 11, 0, 7) for truncator_name, truncator in (('head', hl.Table.head), ('tail', hl.Table.tail)) - for counter_name, counter in (('count', hl.Table.count), ('_force_count', hl.Table._force_count))] + for counter_name, counter in (('count', hl.Table.count), ('_force_count', hl.Table._force_count)) +] @pytest.mark.parametrize("test", head_tail_test_data) @@ -1930,14 +2150,15 @@ def test_table_head_and_tail(test): def test_to_pandas(): ht = hl.utils.range_table(3) strs = ["foo", "bar", "baz"] - ht = ht.annotate(s = hl.array(strs)[ht.idx], nested=hl.struct(foo = ht.idx, bar=hl.range(ht.idx))) + ht = ht.annotate(s=hl.array(strs)[ht.idx], nested=hl.struct(foo=ht.idx, bar=hl.range(ht.idx))) df_from_hail = ht.to_pandas(flatten=False) python_data = { "idx": pd.Series([0, 1, 2], dtype='Int32'), "s": pd.Series(["foo", "bar", "baz"], dtype='string'), - "nested": pd.Series([hl.Struct(foo=0, bar=[]), hl.Struct(foo=1, bar=[0]), - hl.Struct(foo=2, bar=[0, 1])], dtype=object) + "nested": pd.Series( + [hl.Struct(foo=0, bar=[]), hl.Struct(foo=1, bar=[0]), hl.Struct(foo=2, bar=[0, 1])], dtype=object + ), } df_from_python = pd.DataFrame(python_data) @@ -1946,11 +2167,7 @@ def test_to_pandas(): def test_to_pandas_types_type_to_type(): ht = hl.utils.range_table(3) - ht = ht.annotate( - s=hl.array(["foo", "bar", "baz"])[ht.idx], - nested=hl.struct(foo=ht.idx, - bar=hl.range(ht.idx)) - ) + ht = ht.annotate(s=hl.array(["foo", "bar", "baz"])[ht.idx], nested=hl.struct(foo=ht.idx, bar=hl.range(ht.idx))) actual = dict(ht.to_pandas(types={hl.tint32: 'Int64'}).dtypes) assert isinstance(actual['idx'], pd.Int64Dtype) assert isinstance(actual['s'], pd.StringDtype) @@ -1960,11 +2177,7 @@ def test_to_pandas_types_type_to_type(): def test_to_pandas_types_column_to_type(): ht = hl.utils.range_table(3) - ht = ht.annotate( - s=hl.array(["foo", "bar", "baz"])[ht.idx], - nested=hl.struct(foo=ht.idx, - bar=hl.range(ht.idx)) - ) + ht = ht.annotate(s=hl.array(["foo", "bar", "baz"])[ht.idx], nested=hl.struct(foo=ht.idx, bar=hl.range(ht.idx))) actual = dict(ht.to_pandas(types={'nested.foo': 'Int64'}).dtypes) assert isinstance(actual['idx'], pd.Int32Dtype) assert isinstance(actual['s'], pd.StringDtype) @@ -1975,14 +2188,14 @@ def test_to_pandas_types_column_to_type(): def test_to_pandas_flatten(): ht = hl.utils.range_table(3) strs = ["foo", "bar", "baz"] - ht = ht.annotate(s = hl.array(strs)[ht.idx], nested = hl.struct(foo = ht.idx, bar=hl.range(ht.idx))) + ht = ht.annotate(s=hl.array(strs)[ht.idx], nested=hl.struct(foo=ht.idx, bar=hl.range(ht.idx))) df_from_hail = ht.to_pandas(flatten=True) python_data = { "idx": pd.Series([0, 1, 2], dtype='Int32'), "s": pd.Series(["foo", "bar", "baz"], dtype='string'), "nested.foo": pd.Series([0, 1, 2], dtype='Int32'), - "nested.bar": pd.Series([[], [0], [0, 1]], dtype=object) + "nested.bar": pd.Series([[], [0], [0, 1]], dtype=object), } 
df_from_python = pd.DataFrame(python_data) @@ -1991,12 +2204,14 @@ def test_to_pandas_flatten(): def test_to_pandas_null_ints(): ht = hl.utils.range_table(3) - ht = ht.annotate(missing_int32 = hl.or_missing(ht.idx == 0, ht.idx), - missing_int64 = hl.or_missing(ht.idx == 0, hl.int64(ht.idx)), - missing_float32 = hl.or_missing(ht.idx == 0, hl.float32(ht.idx)), - missing_float64 = hl.or_missing(ht.idx == 0, hl.float64(ht.idx)), - missing_bool = hl.or_missing(ht.idx == 0, True), - missing_str = hl.or_missing(ht.idx == 0, 'foo')) + ht = ht.annotate( + missing_int32=hl.or_missing(ht.idx == 0, ht.idx), + missing_int64=hl.or_missing(ht.idx == 0, hl.int64(ht.idx)), + missing_float32=hl.or_missing(ht.idx == 0, hl.float32(ht.idx)), + missing_float64=hl.or_missing(ht.idx == 0, hl.float64(ht.idx)), + missing_bool=hl.or_missing(ht.idx == 0, True), + missing_str=hl.or_missing(ht.idx == 0, 'foo'), + ) df_from_hail = ht.to_pandas() python_data = { @@ -2015,13 +2230,14 @@ def test_to_pandas_null_ints(): def test_to_pandas_nd_array(): import numpy as np + ht = hl.utils.range_table(3) ht = ht.annotate(nd=hl.nd.arange(3)) df_from_hail = ht.to_pandas() python_data = { "idx": pd.Series([0, 1, 2], dtype='Int32'), - "nd": pd.Series([np.arange(3), np.arange(3), np.arange(3)]) + "nd": pd.Series([np.arange(3), np.arange(3), np.arange(3)]), } df_from_python = pd.DataFrame(python_data) @@ -2042,6 +2258,7 @@ def test_literal_of_numpy_int32(): def test_literal_of_pandas_NA_and_numpy_int64(): import hail as hl + t = hl.utils.range_table(10) x = t.key_by(idx=hl.or_missing(t.idx == 5, hl.int64(t.idx))).to_pandas().idx.tolist() hl.eval(hl.literal(x)) @@ -2049,6 +2266,7 @@ def test_literal_of_pandas_NA_and_numpy_int64(): def test_literal_of_pandas_NA_and_numpy_int32(): import hail as hl + t = hl.utils.range_table(10) x = t.key_by(idx=hl.or_missing(t.idx == 5, t.idx)).to_pandas().idx.tolist() hl.eval(hl.literal(x)) @@ -2057,7 +2275,7 @@ def test_literal_of_pandas_NA_and_numpy_int32(): @test_timeout(batch=5 * 60) def test_write_many(): t = hl.utils.range_table(5) - t = t.annotate(a = t.idx, b = t.idx * t.idx, c = hl.str(t.idx)) + t = t.annotate(a=t.idx, b=t.idx * t.idx, c=hl.str(t.idx)) with hl.TemporaryDirectory(ensure_exists=False) as f: t.write_many(f, fields=('a', 'b', 'c')) @@ -2066,7 +2284,7 @@ def test_write_many(): hl.Struct(idx=1, a=1), hl.Struct(idx=2, a=2), hl.Struct(idx=3, a=3), - hl.Struct(idx=4, a=4) + hl.Struct(idx=4, a=4), ] assert hl.read_table(f + '/b').collect() == [ @@ -2074,7 +2292,7 @@ def test_write_many(): hl.Struct(idx=1, b=1), hl.Struct(idx=2, b=4), hl.Struct(idx=3, b=9), - hl.Struct(idx=4, b=16) + hl.Struct(idx=4, b=16), ] assert hl.read_table(f + '/c').collect() == [ @@ -2082,9 +2300,10 @@ def test_write_many(): hl.Struct(idx=1, c='1'), hl.Struct(idx=2, c='2'), hl.Struct(idx=3, c='3'), - hl.Struct(idx=4, c='4') + hl.Struct(idx=4, c='4'), ] + @pytest.mark.parametrize('branching_factor', [2, 3, 5, 7, 121]) def test_indexed_read_boundaries(branching_factor): with hl._with_flags(index_branching_factor=str(branching_factor)): @@ -2092,10 +2311,13 @@ def test_indexed_read_boundaries(branching_factor): t = t.filter(t.idx % 5 != 0) f = new_temp_file(extension='ht') t.write(f) - t1 = hl.read_table(f, _intervals=[ - hl.Interval(start=140, end=145, includes_start=True, includes_end=True), - hl.Interval(start=151, end=153, includes_start=False, includes_end=False), - ]) + t1 = hl.read_table( + f, + _intervals=[ + hl.Interval(start=140, end=145, includes_start=True, includes_end=True), + hl.Interval(start=151, 
end=153, includes_start=False, includes_end=False), + ], + ) assert t1.idx.collect() == [141, 142, 143, 144, 152] @@ -2103,10 +2325,11 @@ def test_indexed_read_boundaries(branching_factor): def assert_unique_uids(ht): ht = ht.annotate(r=hl.rand_int64()) x = ht.aggregate(hl.struct(r=hl.agg.collect_as_set(ht.r), n=hl.agg.count())) - assert(len(x.r) == x.n) + assert len(x.r) == x.n + def assert_contains_node(t, node): - assert(t._tir.base_search(lambda x: isinstance(x, node))) + assert t._tir.base_search(lambda x: isinstance(x, node)) def test_table_randomness_range_table(): @@ -2167,7 +2390,7 @@ def test_table_randomness_map_globals_with_body_randomness(): rt = hl.utils.range_table(5) t1 = rt.annotate_globals(x=hl.rand_int64()) assert_contains_node(t1, ir.TableMapGlobals) - t1._force_count() # test with no consumer randomness + t1._force_count() # test with no consumer randomness assert_unique_uids(t1) @@ -2197,7 +2420,7 @@ def test_table_randomness_map_rows_with_body_randomness(): rt = hl.utils.range_table(12, 3) t = rt.annotate(x=hl.rand_int64()) assert_contains_node(t, ir.TableMapRows) - t._force_count() # test with no consumer randomness + t._force_count() # test with no consumer randomness assert_unique_uids(t) @@ -2220,7 +2443,7 @@ def test_table_randomness_map_partitions(): t = rt.annotate(x=hl.rand_int64()) t = t._map_partitions(lambda part: part.map(lambda row: row.annotate(x=row.x / 2))) assert_contains_node(t, ir.TableMapPartitions) - t._force_count() # test with no consumer randomness + t._force_count() # test with no consumer randomness def test_table_randomness_read(): @@ -2243,7 +2466,7 @@ def test_table_randomness_filter_with_cond_randomness(): rt = hl.utils.range_table(20, 3) t = rt.filter(hl.rand_int64() % 2 == 0) assert_contains_node(t, ir.TableFilter) - t._force_count() # test with no consumer randomness + t._force_count() # test with no consumer randomness assert_unique_uids(t) @@ -2258,7 +2481,7 @@ def test_table_randomness_key_by_and_aggregate_with_body_randomness(): rt = hl.utils.range_table(20, 3) t = rt.group_by(k=rt.idx % 5).aggregate(x=hl.agg.sum(rt.idx) + hl.rand_int64()) assert_contains_node(t, ir.TableKeyByAndAggregate) - t._force_count() # test with no consumer randomness + t._force_count() # test with no consumer randomness assert_unique_uids(t) @@ -2266,7 +2489,7 @@ def test_table_randomness_key_by_and_aggregate_with_agg_randomness(): rt = hl.utils.range_table(20, 3) t = rt.group_by(k=rt.idx % 5).aggregate(x=hl.agg.sum(hl.rand_int64())) assert_contains_node(t, ir.TableKeyByAndAggregate) - t._force_count() # test with no consumer randomness + t._force_count() # test with no consumer randomness assert_unique_uids(t) @@ -2293,19 +2516,18 @@ def test_table_randomness_matrix_cols_table(): def test_table_randomness_parallelize_with_body_randomness(): - rt = hl.utils.range_table(20, 3) t = hl.Table.parallelize(hl.array([1, 2, 3]).map(lambda x: hl.struct(x=x, r=hl.rand_int64()))) assert_contains_node(t, ir.TableParallelize) - t._force_count() # test with no consumer randomness + t._force_count() # test with no consumer randomness assert_unique_uids(t) def test_table_randomness_parallelize_without_body_randomness(): - rt = hl.utils.range_table(20, 3) t = hl.Table.parallelize(hl.array([1, 2, 3]).map(lambda x: hl.struct(x=x))) assert_contains_node(t, ir.TableParallelize) assert_unique_uids(t) + def test_table_randomness_head(): t = hl.utils.range_table(20, 3) t = t.head(10) @@ -2429,22 +2651,17 @@ def test_query_table(): [], [], [], - [hl.Struct(idx=30, s='30'), - 
hl.Struct(idx=40, s='40'), - hl.Struct(idx=50, s='50'), - hl.Struct(idx=60, s='60')], + [hl.Struct(idx=30, s='30'), hl.Struct(idx=40, s='40'), hl.Struct(idx=50, s='50'), hl.Struct(idx=60, s='60')], [], [], - [hl.Struct(idx=30, s='30'), - hl.Struct(idx=40, s='40'), - hl.Struct(idx=50, s='50'), - hl.Struct(idx=60, s='60')], - [hl.Struct(idx=40, s='40'), - hl.Struct(idx=50, s='50'), - hl.Struct(idx=60, s='60'), - hl.Struct(idx=70, s='70'), - hl.Struct(idx=80, s='80'), - ] + [hl.Struct(idx=30, s='30'), hl.Struct(idx=40, s='40'), hl.Struct(idx=50, s='50'), hl.Struct(idx=60, s='60')], + [ + hl.Struct(idx=40, s='40'), + hl.Struct(idx=50, s='50'), + hl.Struct(idx=60, s='60'), + hl.Struct(idx=70, s='70'), + hl.Struct(idx=80, s='80'), + ], ] assert hl.eval(queries) == expected @@ -2469,14 +2686,10 @@ def test_query_table_compound_key(): queries = [ hl.query_table(f, 50), hl.query_table(f, hl.struct(idx=50)), - hl.query_table(f, hl.interval(hl.struct(idx=50, idx2=11), hl.struct(idx=60, idx2=-1))) + hl.query_table(f, hl.interval(hl.struct(idx=50, idx2=11), hl.struct(idx=60, idx2=-1))), ] - expected = [ - [hl.Struct(idx=50, idx2=10, s='50')], - [hl.Struct(idx=50, idx2=10, s='50')], - [] - ] + expected = [[hl.Struct(idx=50, idx2=10, s='50')], [hl.Struct(idx=50, idx2=10, s='50')], []] assert hl.eval(queries) == expected @@ -2494,7 +2707,7 @@ def test_query_table_interval_key(): hl.query_table(f, hl.interval(20, 70)), hl.query_table(f, hl.interval(20, 0)), hl.query_table(f, hl.struct(interval=hl.interval(20, 0))), - hl.query_table(f, hl.interval(hl.interval(15, 10), hl.interval(20, 71))) + hl.query_table(f, hl.interval(hl.interval(15, 10), hl.interval(20, 71))), ] expected = [ diff --git a/hail/python/test/hail/test_call_caching.py b/hail/python/test/hail/test_call_caching.py index 825a96f3931..deada99ea35 100644 --- a/hail/python/test/hail/test_call_caching.py +++ b/hail/python/test/hail/test_call_caching.py @@ -1,8 +1,9 @@ import hail as hl - from hail.utils.misc import new_temp_file + from .helpers import with_flags + def test_execution_cache_creation(): """Asserts creation of execution cache folder""" folder = new_temp_file('hail-execution-cache') @@ -11,9 +12,7 @@ def test_execution_cache_creation(): @with_flags(use_fast_restarts='1', cachedir=folder) def test(): - (hl.utils.range_table(10) - .annotate(another_field=5) - ._force_count()) + (hl.utils.range_table(10).annotate(another_field=5)._force_count()) assert fs.exists(folder) assert len(fs.ls(folder)) > 0 diff --git a/hail/python/test/hail/test_context.py b/hail/python/test/hail/test_context.py index 53fc736eea2..c768a326102 100644 --- a/hail/python/test/hail/test_context.py +++ b/hail/python/test/hail/test_context.py @@ -1,15 +1,10 @@ -from typing import Tuple, Dict, Optional +from test.hail.helpers import hl_init_for_test, hl_stop_for_test, qobtest, skip_unless_spark_backend, with_flags +from typing import Dict, Optional, Tuple import hail as hl -from hail.utils.java import Env from hail.backend.backend import Backend from hail.backend.spark_backend import SparkBackend -from test.hail.helpers import (skip_unless_spark_backend, - hl_init_for_test, - hl_stop_for_test, - qobtest, - with_flags - ) +from hail.utils.java import Env def _scala_map_str_to_tuple_str_str_to_dict(scala) -> Dict[str, Tuple[Optional[str], Optional[str]]]: @@ -57,8 +52,8 @@ def test_tmpdir_runs(): def test_get_flags(): - assert hl._get_flags() == {} - assert list(hl._get_flags('use_new_shuffle')) == ['use_new_shuffle'] + assert hl._get_flags() == {} + assert 
list(hl._get_flags('use_new_shuffle')) == ['use_new_shuffle'] @skip_unless_spark_backend(reason='requires JVM') @@ -72,23 +67,17 @@ def test_flags_same_in_scala_and_python(): def test_fast_restarts_feature(): def is_featured_off(): - return hl._get_flags('use_fast_restarts', 'cachedir') == { - 'use_fast_restarts': None, - 'cachedir': None - } + return hl._get_flags('use_fast_restarts', 'cachedir') == {'use_fast_restarts': None, 'cachedir': None} @with_flags(use_fast_restarts='1') def uses_fast_restarts(): - return hl._get_flags('use_fast_restarts', 'cachedir') == { - 'use_fast_restarts': '1', - 'cachedir': None - } + return hl._get_flags('use_fast_restarts', 'cachedir') == {'use_fast_restarts': '1', 'cachedir': None} @with_flags(use_fast_restarts='1', cachedir='gs://my-bucket/object-prefix') def uses_cachedir(): return hl._get_flags('use_fast_restarts', 'cachedir') == { 'use_fast_restarts': '1', - 'cachedir': 'gs://my-bucket/object-prefix' + 'cachedir': 'gs://my-bucket/object-prefix', } assert is_featured_off() diff --git a/hail/python/test/hail/test_exceptions_from_workers_have_stack_traces.py b/hail/python/test/hail/test_exceptions_from_workers_have_stack_traces.py new file mode 100644 index 00000000000..5391a63a4d5 --- /dev/null +++ b/hail/python/test/hail/test_exceptions_from_workers_have_stack_traces.py @@ -0,0 +1,25 @@ +import re + +import pytest + +import hail as hl +from hail.utils.java import FatalError + +from .helpers import qobtest + + +@qobtest +def test_exceptions_from_workers_have_stack_traces(): + ht = hl.utils.range_table(10, n_partitions=10) + ht = ht.annotate(x=hl.int(1) // hl.int(hl.rand_norm(0, 0.1))) + pattern = ( + '.*' + + re.escape('java.lang.Math.floorDiv(Math.java:1052)') + + '.*' + + re.escape('(BackendUtils.scala:') + + '[0-9]+' + + re.escape(')\n') + + '.*' + ) + with pytest.raises(FatalError, match=re.compile(pattern, re.DOTALL)): + ht.collect() diff --git a/hail/python/test/hail/test_hail_in_notebook.py b/hail/python/test/hail/test_hail_in_notebook.py index 681f747315a..fec55f216b2 100644 --- a/hail/python/test/hail/test_hail_in_notebook.py +++ b/hail/python/test/hail/test_hail_in_notebook.py @@ -1,12 +1,18 @@ -from hailtop.utils.process import sync_check_exec import os import pathlib + +from hailtop.utils.process import sync_check_exec + from .helpers import skip_when_local_backend -@skip_when_local_backend('In the LocalBackend, writing to a gs:// URL hangs indefinitely https://github.com/hail-is/hail/issues/13904') +@skip_when_local_backend( + 'In the LocalBackend, writing to a gs:// URL hangs indefinitely https://github.com/hail-is/hail/issues/13904' +) def test_hail_in_notebook(): folder = pathlib.Path(__file__).parent.resolve() source_ipynb = os.path.join(folder, 'test_hail_in_notebook.ipynb') output_ipynb = os.path.join(folder, 'test_hail_in_notebook_out.ipynb') - sync_check_exec('jupyter', 'nbconvert', '--to', 'notebook', '--execute', str(source_ipynb), '--output', str(output_ipynb)) + sync_check_exec( + 'jupyter', 'nbconvert', '--to', 'notebook', '--execute', str(source_ipynb), '--output', str(output_ipynb) + ) diff --git a/hail/python/test/hail/test_indices_aggregations.py b/hail/python/test/hail/test_indices_aggregations.py index 38d7119f946..92935bdc8c7 100644 --- a/hail/python/test/hail/test_indices_aggregations.py +++ b/hail/python/test/hail/test_indices_aggregations.py @@ -4,7 +4,7 @@ def test_array_slice_end(): ht = hl.utils.range_matrix_table(1, 1) try: - ht = ht.annotate_rows(c = hl.array([1,2,3])[:ht.col_idx]) + ht = 
ht.annotate_rows(c=hl.array([1, 2, 3])[: ht.col_idx]) except hl.ExpressionException as exc: assert 'scope violation' in exc.args[0] assert "'col_idx' (indices ['column'])" in exc.args[0] @@ -15,7 +15,7 @@ def test_array_slice_end(): def test_array_slice_start(): ht = hl.utils.range_matrix_table(1, 1) try: - ht = ht.annotate_rows(c = hl.array([1,2,3])[ht.col_idx:]) + ht = ht.annotate_rows(c=hl.array([1, 2, 3])[ht.col_idx :]) except hl.ExpressionException as exc: assert 'scope violation' in exc.args[0] assert "'col_idx' (indices ['column'])" in exc.args[0] @@ -26,7 +26,7 @@ def test_array_slice_start(): def test_array_slice_step(): ht = hl.utils.range_matrix_table(1, 1) try: - ht = ht.annotate_rows(c = hl.array([1,2,3])[::ht.col_idx]) + ht = ht.annotate_rows(c=hl.array([1, 2, 3])[:: ht.col_idx]) except hl.ExpressionException as exc: assert 'scope violation' in exc.args[0] assert "'col_idx' (indices ['column'])" in exc.args[0] @@ -36,10 +36,10 @@ def test_array_slice_step(): def test_matmul(): ht = hl.utils.range_matrix_table(1, 1) - ht = ht.annotate_cols(a = hl.nd.array([0])) - ht = ht.annotate_rows(b = hl.nd.array([0])) + ht = ht.annotate_cols(a=hl.nd.array([0])) + ht = ht.annotate_rows(b=hl.nd.array([0])) try: - ht = ht.annotate_rows(c = ht.b @ ht.a) + ht = ht.annotate_rows(c=ht.b @ ht.a) except hl.ExpressionException as exc: assert 'scope violation' in exc.args[0] assert "'a' (indices ['column'])" in exc.args[0] @@ -49,9 +49,9 @@ def test_matmul(): def test_ndarray_index(): ht = hl.utils.range_matrix_table(1, 1) - ht = ht.annotate_rows(b = hl.nd.array([0])) + ht = ht.annotate_rows(b=hl.nd.array([0])) try: - ht = ht.annotate_rows(c = ht.b[ht.col_idx]) + ht = ht.annotate_rows(c=ht.b[ht.col_idx]) except hl.ExpressionException as exc: assert 'scope violation' in exc.args[0] assert "'col_idx' (indices ['column'])" in exc.args[0] @@ -61,9 +61,9 @@ def test_ndarray_index(): def test_ndarray_index_with_slice_1(): ht = hl.utils.range_matrix_table(1, 1) - ht = ht.annotate_rows(b = hl.nd.array([[0]])) + ht = ht.annotate_rows(b=hl.nd.array([[0]])) try: - ht = ht.annotate_rows(c = ht.b[ht.col_idx, :]) + ht = ht.annotate_rows(c=ht.b[ht.col_idx, :]) except hl.ExpressionException as exc: assert 'scope violation' in exc.args[0] assert "'col_idx' (indices ['column'])" in exc.args[0] @@ -73,9 +73,9 @@ def test_ndarray_index_with_slice_1(): def test_ndarray_index_with_slice_2(): ht = hl.utils.range_matrix_table(1, 1) - ht = ht.annotate_rows(b = hl.nd.array([[0]])) + ht = ht.annotate_rows(b=hl.nd.array([[0]])) try: - ht = ht.annotate_rows(c = ht.b[:, ht.col_idx]) + ht = ht.annotate_rows(c=ht.b[:, ht.col_idx]) except hl.ExpressionException as exc: assert 'scope violation' in exc.args[0] assert "'col_idx' (indices ['column'])" in exc.args[0] @@ -85,9 +85,9 @@ def test_ndarray_index_with_slice_2(): def test_ndarray_index_with_None_1(): ht = hl.utils.range_matrix_table(1, 1) - ht = ht.annotate_rows(b = hl.nd.array([[0]])) + ht = ht.annotate_rows(b=hl.nd.array([[0]])) try: - ht = ht.annotate_rows(c = ht.b[ht.col_idx, None]) + ht = ht.annotate_rows(c=ht.b[ht.col_idx, None]) except hl.ExpressionException as exc: assert 'scope violation' in exc.args[0] assert "'col_idx' (indices ['column'])" in exc.args[0] @@ -97,9 +97,9 @@ def test_ndarray_index_with_None_1(): def test_ndarray_index_with_None_2(): ht = hl.utils.range_matrix_table(1, 1) - ht = ht.annotate_rows(b = hl.nd.array([[0]])) + ht = ht.annotate_rows(b=hl.nd.array([[0]])) try: - ht = ht.annotate_rows(c = ht.b[None, ht.col_idx]) + ht = 
ht.annotate_rows(c=ht.b[None, ht.col_idx]) except hl.ExpressionException as exc: assert 'scope violation' in exc.args[0] assert "'col_idx' (indices ['column'])" in exc.args[0] @@ -109,9 +109,9 @@ def test_ndarray_index_with_None_2(): def test_ndarray_reshape_1(): ht = hl.utils.range_matrix_table(1, 1) - ht = ht.annotate_rows(b = hl.nd.array([[0]])) + ht = ht.annotate_rows(b=hl.nd.array([[0]])) try: - ht = ht.annotate_rows(c = ht.b.reshape((ht.col_idx, 1))) + ht = ht.annotate_rows(c=ht.b.reshape((ht.col_idx, 1))) except hl.ExpressionException as exc: assert 'scope violation' in exc.args[0] assert "'col_idx' (indices ['column'])" in exc.args[0] @@ -121,9 +121,9 @@ def test_ndarray_reshape_1(): def test_ndarray_reshape_2(): ht = hl.utils.range_matrix_table(1, 1) - ht = ht.annotate_rows(b = hl.nd.array([[0]])) + ht = ht.annotate_rows(b=hl.nd.array([[0]])) try: - ht = ht.annotate_rows(c = ht.b.reshape((1, ht.col_idx))) + ht = ht.annotate_rows(c=ht.b.reshape((1, ht.col_idx))) except hl.ExpressionException as exc: assert 'scope violation' in exc.args[0] assert "'col_idx' (indices ['column'])" in exc.args[0] @@ -133,10 +133,10 @@ def test_ndarray_reshape_2(): def test_ndarray_reshape_tuple(): ht = hl.utils.range_matrix_table(1, 1) - ht = ht.annotate_cols(a = hl.tuple((1, 1))) - ht = ht.annotate_rows(b = hl.nd.array([[0]])) + ht = ht.annotate_cols(a=hl.tuple((1, 1))) + ht = ht.annotate_rows(b=hl.nd.array([[0]])) try: - ht = ht.annotate_rows(c = ht.b.reshape(ht.a)) + ht = ht.annotate_rows(c=ht.b.reshape(ht.a)) except hl.ExpressionException as exc: assert 'scope violation' in exc.args[0] assert "'a' (indices ['column'])" in exc.args[0] diff --git a/hail/python/test/hail/test_ir.py b/hail/python/test/hail/test_ir.py index b2340e99e33..e21ca060a69 100644 --- a/hail/python/test/hail/test_ir.py +++ b/hail/python/test/hail/test_ir.py @@ -1,15 +1,18 @@ import re import unittest +from test.hail.helpers import resource, skip_unless_spark_backend + import numpy as np +import pytest from numpy.testing import assert_array_equal + import hail as hl -import hail.ir as ir -from hail.ir.renderer import CSERenderer +from hail import ir from hail.expr import construct_expr from hail.expr.types import tint32 -from hail.utils.java import Env +from hail.ir.renderer import CSERenderer from hail.utils import new_temp_file -from test.hail.helpers import * +from hail.utils.java import Env class ValueIRTests(unittest.TestCase): @@ -18,7 +21,9 @@ def value_irs_env(self): 'c': hl.tbool, 'a': hl.tarray(hl.tint32), 'st': hl.tstream(hl.tint32), - 'whitenStream': hl.tstream(hl.tstruct(prevWindow=hl.tndarray(hl.tfloat64, 2), newChunk=hl.tndarray(hl.tfloat64, 2))), + 'whitenStream': hl.tstream( + hl.tstruct(prevWindow=hl.tndarray(hl.tfloat64, 2), newChunk=hl.tndarray(hl.tfloat64, 2)) + ), 'mat': hl.tndarray(hl.tfloat64, 2), 'aa': hl.tarray(hl.tarray(hl.tint32)), 'sta': hl.tstream(hl.tarray(hl.tint32)), @@ -29,7 +34,8 @@ def value_irs_env(self): 's': hl.tstruct(x=hl.tint32, y=hl.tint64, z=hl.tfloat64), 't': hl.ttuple(hl.tint32, hl.tint64, hl.tfloat64), 'call': hl.tcall, - 'x': hl.tint32} + 'x': hl.tint32, + } def value_irs(self): env = self.value_irs_env() @@ -40,8 +46,6 @@ def value_irs(self): a = ir.Ref('a', env['a']) st = ir.Ref('st', env['st']) whitenStream = ir.Ref('whitenStream') - mat = ir.Ref('mat') - aa = ir.Ref('aa', env['aa']) sta = ir.Ref('sta', env['sta']) sts = ir.Ref('sts', env['sts']) da = ir.Ref('da', env['da']) @@ -49,14 +53,16 @@ def value_irs(self): v = ir.Ref('v', env['v']) s = ir.Ref('s', env['s']) t = 
ir.Ref('t', env['t']) - call = ir.Ref('call', env['call']) + ir.Ref('call', env['call']) rngState = ir.RNGStateLiteral() table = ir.TableRange(5, 3) - matrix_read = ir.MatrixRead(ir.MatrixNativeReader( - resource('backward_compatability/1.0.0/matrix_table/0.hmt'), None, False), - False, False) + matrix_read = ir.MatrixRead( + ir.MatrixNativeReader(resource('backward_compatability/1.0.0/matrix_table/0.hmt'), None, False), + False, + False, + ) block_matrix_read = ir.BlockMatrixRead(ir.BlockMatrixNativeReader(resource('blockmatrix_example/0'))) @@ -64,7 +70,14 @@ def aggregate(x): return ir.TableAggregate(table, x) value_irs = [ - i, ir.I64(5), ir.F32(3.14), ir.F64(3.14), s, ir.TrueIR(), ir.FalseIR(), ir.Void(), + i, + ir.I64(5), + ir.F32(3.14), + ir.F64(3.14), + s, + ir.TrueIR(), + ir.FalseIR(), + ir.Void(), ir.Cast(i, hl.tfloat64), ir.NA(hl.tint32), ir.IsNA(i), @@ -78,14 +91,18 @@ def aggregate(x): ir.MakeArray([i, ir.NA(hl.tint32), ir.I32(-3)], hl.tarray(hl.tint32)), ir.ArrayRef(a, i), ir.ArrayLen(a), - ir.ArraySort(ir.ToStream(a), 'l', 'r', ir.ApplyComparisonOp("LT", ir.Ref('l', hl.tint32), ir.Ref('r', hl.tint32))), + ir.ArraySort( + ir.ToStream(a), 'l', 'r', ir.ApplyComparisonOp("LT", ir.Ref('l', hl.tint32), ir.Ref('r', hl.tint32)) + ), ir.ToSet(st), ir.ToDict(da), ir.ToArray(st), ir.CastToArray(ir.NA(hl.tset(hl.tint32))), - ir.MakeNDArray(ir.MakeArray([ir.F64(-1.0), ir.F64(1.0)], hl.tarray(hl.tfloat64)), - ir.MakeTuple([ir.I64(1), ir.I64(2)]), - ir.TrueIR()), + ir.MakeNDArray( + ir.MakeArray([ir.F64(-1.0), ir.F64(1.0)], hl.tarray(hl.tfloat64)), + ir.MakeTuple([ir.I64(1), ir.I64(2)]), + ir.TrueIR(), + ), ir.NDArrayShape(nd), ir.NDArrayReshape(nd, ir.MakeTuple([ir.I64(5)])), ir.NDArrayRef(nd, [ir.I64(1), ir.I64(2)]), @@ -106,7 +123,11 @@ def aggregate(x): aggregate(ir.AggFilter(ir.TrueIR(), ir.I32(0), False)), aggregate(ir.AggExplode(ir.StreamRange(ir.I32(0), ir.I32(2), ir.I32(1)), 'x', ir.I32(0), False)), aggregate(ir.AggGroupBy(ir.TrueIR(), ir.I32(0), False)), - aggregate(ir.AggArrayPerElement(ir.ToArray(ir.StreamRange(ir.I32(0), ir.I32(2), ir.I32(1))), 'x', 'y', ir.I32(0), False)), + aggregate( + ir.AggArrayPerElement( + ir.ToArray(ir.StreamRange(ir.I32(0), ir.I32(2), ir.I32(1))), 'x', 'y', ir.I32(0), False + ) + ), aggregate(ir.ApplyAggOp('Collect', [], [ir.I32(0)])), aggregate(ir.ApplyAggOp('CallStats', [ir.I32(2)], [ir.NA(hl.tcall)])), aggregate(ir.ApplyAggOp('TakeBy', [ir.I32(10)], [ir.F64(-2.11), ir.F64(-2.11)])), @@ -132,14 +153,26 @@ def aggregate(x): ir.TableWrite(table, ir.TableTextWriter(new_temp_file(), None, True, "concatenated", ",")), ir.MatrixAggregate(matrix_read, ir.MakeStruct([('foo', ir.ApplyAggOp('Collect', [], [ir.I32(0)]))])), ir.MatrixWrite(matrix_read, ir.MatrixNativeWriter(new_temp_file(), False, False, "", None, None)), - ir.MatrixWrite(matrix_read, ir.MatrixNativeWriter(new_temp_file(), False, False, "", - '[{"start":{"row_idx":0},"end":{"row_idx": 10},"includeStart":true,"includeEnd":false}]', - hl.dtype('array>'))), - ir.MatrixWrite(matrix_read, ir.MatrixVCFWriter(new_temp_file(), None, ir.ExportType.CONCATENATED, None, False)), + ir.MatrixWrite( + matrix_read, + ir.MatrixNativeWriter( + new_temp_file(), + False, + False, + "", + '[{"start":{"row_idx":0},"end":{"row_idx": 10},"includeStart":true,"includeEnd":false}]', + hl.dtype('array>'), + ), + ), + ir.MatrixWrite( + matrix_read, ir.MatrixVCFWriter(new_temp_file(), None, ir.ExportType.CONCATENATED, None, False) + ), ir.MatrixWrite(matrix_read, ir.MatrixGENWriter(new_temp_file(), 4)), 
ir.MatrixWrite(matrix_read, ir.MatrixPLINKWriter(new_temp_file())), - ir.MatrixMultiWrite([matrix_read, matrix_read], - ir.MatrixNativeMultiWriter([new_temp_file(), new_temp_file()], False, False, None)), + ir.MatrixMultiWrite( + [matrix_read, matrix_read], + ir.MatrixNativeMultiWriter([new_temp_file(), new_temp_file()], False, False, None), + ), ir.BlockMatrixWrite(block_matrix_read, ir.BlockMatrixNativeWriter('fake.bm', False, False, False)), ir.LiftMeOut(ir.I32(1)), ir.BlockMatrixWrite(block_matrix_read, ir.BlockMatrixPersistWriter('x', 'MEMORY_ONLY')), @@ -164,52 +197,53 @@ class TableIRTests(unittest.TestCase): def table_irs(self): b = ir.TrueIR() table_read = ir.TableRead( - ir.TableNativeReader(resource('backward_compatability/1.1.0/table/0.ht'), None, False), False) - table_read_row_type = hl.dtype('struct{idx: int32, f32: float32, i64: int64, m: float64, astruct: struct{a: int32, b: float64}, mstruct: struct{x: int32, y: str}, aset: set, mset: set, d: dict, float64>, md: dict, h38: locus, ml: locus, i: interval>, c: call, mc: call, t: tuple(call, str, str), mt: tuple(locus, bool)}') + ir.TableNativeReader(resource('backward_compatability/1.1.0/table/0.ht'), None, False), False + ) + table_read_row_type = hl.dtype( + 'struct{idx: int32, f32: float32, i64: int64, m: float64, astruct: struct{a: int32, b: float64}, mstruct: struct{x: int32, y: str}, aset: set, mset: set, d: dict, float64>, md: dict, h38: locus, ml: locus, i: interval>, c: call, mc: call, t: tuple(call, str, str), mt: tuple(locus, bool)}' + ) matrix_read = ir.MatrixRead( ir.MatrixNativeReader(resource('backward_compatability/1.0.0/matrix_table/0.hmt'), None, False), - False, False) + False, + False, + ) block_matrix_read = ir.BlockMatrixRead(ir.BlockMatrixNativeReader(resource('blockmatrix_example/0'))) - aa = hl.literal([[0.00],[0.01],[0.02]])._ir + aa = hl.literal([[0.00], [0.01], [0.02]])._ir table_irs = [ ir.TableKeyBy(table_read, ['m', 'd'], False), ir.TableFilter(table_read, b), table_read, ir.MatrixColsTable(matrix_read), - ir.TableAggregateByKey( - table_read, - ir.MakeStruct([('a', ir.I32(5))])), + ir.TableAggregateByKey(table_read, ir.MakeStruct([('a', ir.I32(5))])), ir.TableKeyByAndAggregate( - table_read, - ir.MakeStruct([('a', ir.I32(5))]), - ir.MakeStruct([('b', ir.I32(5))]), - 1, 2), - ir.TableJoin( - table_read, - ir.TableRange(100, 10), 'inner', 1), + table_read, ir.MakeStruct([('a', ir.I32(5))]), ir.MakeStruct([('b', ir.I32(5))]), 1, 2 + ), + ir.TableJoin(table_read, ir.TableRange(100, 10), 'inner', 1), ir.MatrixEntriesTable(matrix_read), ir.MatrixRowsTable(matrix_read), - ir.TableParallelize(ir.MakeStruct([ - ('rows', ir.Literal(hl.tarray(hl.tstruct(a=hl.tint32)), [{'a':None}, {'a':5}, {'a':-3}])), - ('global', ir.MakeStruct([]))]), None), + ir.TableParallelize( + ir.MakeStruct([ + ('rows', ir.Literal(hl.tarray(hl.tstruct(a=hl.tint32)), [{'a': None}, {'a': 5}, {'a': -3}])), + ('global', ir.MakeStruct([])), + ]), + None, + ), ir.TableMapRows( ir.TableKeyBy(table_read, []), ir.MakeStruct([ ('a', ir.GetField(ir.Ref('row', table_read_row_type), 'f32')), ('b', ir.F64(-2.11)), - ('c', ir.ApplyScanOp('Collect', [], [ir.I32(0)]))])), - ir.TableMapGlobals( - table_read, - ir.MakeStruct([ - ('foo', ir.NA(hl.tarray(hl.tint32)))])), + ('c', ir.ApplyScanOp('Collect', [], [ir.I32(0)])), + ]), + ), + ir.TableMapGlobals(table_read, ir.MakeStruct([('foo', ir.NA(hl.tarray(hl.tint32)))])), ir.TableRange(100, 10), ir.TableRepartition(table_read, 10, ir.RepartitionStrategy.COALESCE), - ir.TableUnion( - 
[ir.TableRange(100, 10), ir.TableRange(50, 10)]), + ir.TableUnion([ir.TableRange(100, 10), ir.TableRange(50, 10)]), ir.TableExplode(table_read, ['mset']), ir.TableHead(table_read, 10), ir.TableOrderBy(ir.TableKeyBy(table_read, []), [('m', 'A'), ('m', 'D')]), @@ -217,10 +251,25 @@ def table_irs(self): ir.CastMatrixToTable(matrix_read, '__entries', '__cols'), ir.TableRename(table_read, {'idx': 'idx_foo'}, {'global_f32': 'global_foo'}), ir.TableMultiWayZipJoin([table_read, table_read], '__data', '__globals'), - ir.MatrixToTableApply(matrix_read, {'name': 'LinearRegressionRowsSingle', 'yFields': ['col_m'], 'xField': 'entry_m', 'covFields': [], 'rowBlockSize': 10, 'passThrough': []}), + ir.MatrixToTableApply( + matrix_read, + { + 'name': 'LinearRegressionRowsSingle', + 'yFields': ['col_m'], + 'xField': 'entry_m', + 'covFields': [], + 'rowBlockSize': 10, + 'passThrough': [], + }, + ), ir.TableToTableApply(table_read, {'name': 'TableFilterPartitions', 'parts': [0], 'keep': True}), ir.BlockMatrixToTableApply(block_matrix_read, aa, {'name': 'PCRelate', 'maf': 0.01, 'blockSize': 4096}), - ir.TableFilterIntervals(table_read, [hl.utils.Interval(hl.utils.Struct(row_idx=0), hl.utils.Struct(row_idx=10))], hl.tstruct(row_idx=hl.tint32), keep=False), + ir.TableFilterIntervals( + table_read, + [hl.utils.Interval(hl.utils.Struct(row_idx=0), hl.utils.Struct(row_idx=10))], + hl.tstruct(row_idx=hl.tint32), + keep=False, + ), ir.TableMapPartitions(table_read, 'glob', 'rows', ir.Ref('rows', hl.tstream(table_read_row_type)), 0, 1), ir.TableGen( contexts=ir.StreamRange(ir.I32(0), ir.I32(10), ir.I32(1)), @@ -229,10 +278,9 @@ def table_irs(self): gname="globals", body=ir.ToStream(ir.MakeArray([ir.MakeStruct([('a', ir.I32(1))])], type=None)), partitioner=ir.Partitioner( - hl.tstruct(a=hl.tint), - [hl.Interval(hl.Struct(a=1), hl.Struct(a=2), True, True)] - ) - ) + hl.tstruct(a=hl.tint), [hl.Interval(hl.Struct(a=1), hl.Struct(a=2), True, True)] + ), + ), ] return table_irs @@ -248,11 +296,13 @@ def matrix_irs(self): collect = ir.MakeStruct([('x', ir.ApplyAggOp('Collect', [], [ir.I32(0)]))]) matrix_read = ir.MatrixRead( - ir.MatrixNativeReader( - resource('backward_compatability/1.0.0/matrix_table/0.hmt'), None, False), - False, False) + ir.MatrixNativeReader(resource('backward_compatability/1.0.0/matrix_table/0.hmt'), None, False), + False, + False, + ) table_read = ir.TableRead( - ir.TableNativeReader(resource('backward_compatability/1.1.0/table/0.ht'), None, False), False) + ir.TableNativeReader(resource('backward_compatability/1.1.0/table/0.ht'), None, False), False + ) matrix_range = ir.MatrixRead(ir.MatrixRangeReader(1, 1, 10)) matrix_irs = [ @@ -261,17 +311,30 @@ def matrix_irs(self): ir.MatrixDistinctByRow(matrix_range), ir.MatrixRowsHead(matrix_read, 5), ir.MatrixColsHead(matrix_read, 5), - ir.CastTableToMatrix( - ir.CastMatrixToTable(matrix_read, '__entries', '__cols'), - '__entries', - '__cols', - []), + ir.CastTableToMatrix(ir.CastMatrixToTable(matrix_read, '__entries', '__cols'), '__entries', '__cols', []), ir.MatrixAggregateRowsByKey(matrix_read, collect, collect), ir.MatrixAggregateColsByKey(matrix_read, collect, collect), matrix_read, matrix_range, - ir.MatrixRead(ir.MatrixVCFReader(resource('sample.vcf'), ['GT'], hl.tfloat64, None, None, None, None, None, None, - False, True, False, True, None, None)), + ir.MatrixRead( + ir.MatrixVCFReader( + resource('sample.vcf'), + ['GT'], + hl.tfloat64, + None, + None, + None, + None, + None, + None, + False, + True, + False, + True, + None, + None, + ) + ), 
ir.MatrixRead(ir.MatrixBGENReader(resource('example.8bits.bgen'), None, {}, 10, 1, None)), ir.MatrixFilterRows(matrix_read, ir.FalseIR()), ir.MatrixFilterCols(matrix_read, ir.FalseIR()), @@ -288,8 +351,19 @@ def matrix_irs(self): ir.MatrixAnnotateRowsTable(matrix_read, table_read, '__foo'), ir.MatrixAnnotateColsTable(matrix_read, table_read, '__foo'), ir.MatrixToMatrixApply(matrix_read, {'name': 'MatrixFilterPartitions', 'parts': [0], 'keep': True}), - ir.MatrixRename(matrix_read, {'global_f32': 'global_foo'}, {'col_f32': 'col_foo'}, {'row_aset': 'row_aset2'}, {'entry_f32': 'entry_foo'}), - ir.MatrixFilterIntervals(matrix_read, [hl.utils.Interval(hl.utils.Struct(row_idx=0), hl.utils.Struct(row_idx=10))], hl.tstruct(row_idx=hl.tint32), keep=False), + ir.MatrixRename( + matrix_read, + {'global_f32': 'global_foo'}, + {'col_f32': 'col_foo'}, + {'row_aset': 'row_aset2'}, + {'entry_f32': 'entry_foo'}, + ), + ir.MatrixFilterIntervals( + matrix_read, + [hl.utils.Interval(hl.utils.Struct(row_idx=0), hl.utils.Struct(row_idx=10))], + hl.tstruct(row_idx=hl.tint32), + keep=False, + ), ] return matrix_irs @@ -320,9 +394,13 @@ def blockmatrix_irs(self): vector_ir = ir.MakeArray([ir.F64(3), ir.F64(2)], hl.tarray(hl.tfloat64)) read = ir.BlockMatrixRead(ir.BlockMatrixNativeReader(resource('blockmatrix_example/0'))) - add_two_bms = ir.BlockMatrixMap2(read, read, 'l', 'r', ir.ApplyBinaryPrimOp('+', ir.Ref('l', hl.tfloat64), ir.Ref('r', hl.tfloat64)), "Union") + add_two_bms = ir.BlockMatrixMap2( + read, read, 'l', 'r', ir.ApplyBinaryPrimOp('+', ir.Ref('l', hl.tfloat64), ir.Ref('r', hl.tfloat64)), "Union" + ) negate_bm = ir.BlockMatrixMap(read, 'element', ir.ApplyUnaryPrimOp('-', ir.Ref('element', hl.tfloat64)), False) - sqrt_bm = ir.BlockMatrixMap(read, 'element', hl.sqrt(construct_expr(ir.Ref('element', hl.tfloat64), hl.tfloat64))._ir, False) + sqrt_bm = ir.BlockMatrixMap( + read, 'element', hl.sqrt(construct_expr(ir.Ref('element', hl.tfloat64), hl.tfloat64))._ir, False + ) scalar_to_bm = ir.ValueToBlockMatrix(scalar_ir, [1, 1], 1) col_vector_to_bm = ir.ValueToBlockMatrix(vector_ir, [2, 1], 1) @@ -343,7 +421,10 @@ def blockmatrix_irs(self): densify = ir.BlockMatrixDensify(read) - pow_ir = (construct_expr(ir.Ref('l', hl.tfloat64), hl.tfloat64) ** construct_expr(ir.Ref('r', hl.tfloat64), hl.tfloat64))._ir + pow_ir = ( + construct_expr(ir.Ref('l', hl.tfloat64), hl.tfloat64) + ** construct_expr(ir.Ref('r', hl.tfloat64), hl.tfloat64) + )._ir squared_bm = ir.BlockMatrixMap2(scalar_to_bm, scalar_to_bm, 'l', 'r', pow_ir, "NeedsDense") slice_bm = ir.BlockMatrixSlice(matmul, [slice(0, 2, 1), slice(0, 1, 1)]) @@ -365,7 +446,7 @@ def blockmatrix_irs(self): sparsify3, densify, matmul, - slice_bm + slice_bm, ] @skip_unless_spark_backend() @@ -376,12 +457,11 @@ def test_parses(self): backend.execute(ir.BlockMatrixWrite(bmir, ir.BlockMatrixPersistWriter('x', 'MEMORY_ONLY'))) persist = ir.BlockMatrixRead(ir.BlockMatrixPersistReader('x', bmir)) - for x in (self.blockmatrix_irs() + [persist]): + for x in [*self.blockmatrix_irs(), persist]: backend._parse_blockmatrix_ir(str(x)) class ValueTests(unittest.TestCase): - def values(self): values = [ (hl.tbool, True), @@ -396,7 +476,7 @@ def values(self): (hl.tdict(hl.tstr, hl.tint32), {"a": 0, "b": 1, "c": 4}), (hl.tinterval(hl.tint32), hl.Interval(0, 1, True, False)), (hl.tlocus(hl.default_reference()), hl.Locus("1", 1)), - (hl.tcall, hl.Call([0, 1])) + (hl.tcall, hl.Call([0, 1])), ] return values @@ -407,11 +487,8 @@ def test_value_same_after_parsing(self): row_v = ir.Literal(t, v) 
range = ir.TableRange(1, 1) map_globals_ir = ir.TableMapGlobals( - range, - ir.InsertFields( - ir.Ref("global", range.typ.global_type), - [("foo", row_v)], - None)) + range, ir.InsertFields(ir.Ref("global", range.typ.global_type), [("foo", row_v)], None) + ) test_exprs.append(hl.Table(map_globals_ir).index_globals()) expecteds.append(hl.Struct(foo=v)) @@ -425,11 +502,7 @@ class CSETests(unittest.TestCase): def test_cse(self): x = ir.I32(5) x = ir.ApplyBinaryPrimOp('+', x, x) - expected = ( - '(Let __cse_1 (I32 5)' - ' (ApplyBinaryPrimOp `+`' - ' (Ref __cse_1)' - ' (Ref __cse_1)))') + expected = '(Let eval __cse_1 (I32 5)' ' (ApplyBinaryPrimOp `+`' ' (Ref __cse_1)' ' (Ref __cse_1)))' assert expected == CSERenderer()(x) def test_cse_debug(self): @@ -444,12 +517,12 @@ def test_cse_complex_lifting(self): prod = ir.ApplyBinaryPrimOp('*', sum, sum) cond = ir.If(ir.ApplyComparisonOp('EQ', prod, x), sum, x) expected = ( - '(Let __cse_1 (I32 5)' - ' (Let __cse_2 (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' - ' (If (ApplyComparisonOp EQ (ApplyBinaryPrimOp `*` (Ref __cse_2) (Ref __cse_2)) (Ref __cse_1))' - ' (Let __cse_3 (I32 5)' - ' (ApplyBinaryPrimOp `+` (Ref __cse_3) (Ref __cse_3)))' - ' (I32 5))))' + '(Let eval __cse_1 (I32 5)' + ' (Let eval __cse_2 (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' + ' (If (ApplyComparisonOp EQ (ApplyBinaryPrimOp `*` (Ref __cse_2) (Ref __cse_2)) (Ref __cse_1))' + ' (Let eval __cse_3 (I32 5)' + ' (ApplyBinaryPrimOp `+` (Ref __cse_3) (Ref __cse_3)))' + ' (I32 5))))' ) assert expected == CSERenderer()(cond) @@ -459,12 +532,12 @@ def test_stream_cse(self): a2 = ir.ToArray(x) t = ir.MakeTuple([a1, a2]) expected_re = ( - '(Let __cse_1 (I32 0)' - ' (Let __cse_2 (I32 10)' - ' (Let __cse_3 (I32 1)' + '(Let eval __cse_1 (I32 0)' + ' (Let eval __cse_2 (I32 10)' + ' (Let eval __cse_3 (I32 1)' ' (MakeTuple (0 1)' - ' (ToArray (StreamRange [0-9]+ False (Ref __cse_1) (Ref __cse_2) (Ref __cse_3)))' - ' (ToArray (StreamRange [0-9]+ False (Ref __cse_1) (Ref __cse_2) (Ref __cse_3)))))))' + ' (ToArray (StreamRange [0-9]+ False (Ref __cse_1) (Ref __cse_2) (Ref __cse_3)))' + ' (ToArray (StreamRange [0-9]+ False (Ref __cse_1) (Ref __cse_2) (Ref __cse_3)))))))' ) expected_re = expected_re.replace('(', '\\(').replace(')', '\\)') assert re.match(expected_re, CSERenderer()(t)) @@ -476,13 +549,14 @@ def test_cse2(self): prod = ir.ApplyBinaryPrimOp('*', sum, y) div = ir.ApplyBinaryPrimOp('/', prod, sum) expected = ( - '(Let __cse_1 (I32 5)' - ' (Let __cse_2 (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' + '(Let eval __cse_1 (I32 5)' + ' (Let eval __cse_2 (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' ' (ApplyBinaryPrimOp `/`' - ' (ApplyBinaryPrimOp `*`' - ' (Ref __cse_2)' - ' (I32 4))' - ' (Ref __cse_2))))') + ' (ApplyBinaryPrimOp `*`' + ' (Ref __cse_2)' + ' (I32 4))' + ' (Ref __cse_2))))' + ) assert expected == CSERenderer()(div) def test_cse_ifs(self): @@ -493,11 +567,11 @@ def test_cse_ifs(self): cond = ir.If(ir.TrueIR(), prod, outer_repeated) expected = ( '(If (True)' - ' (Let __cse_1 (I32 1)' - ' (ApplyBinaryPrimOp `*`' - ' (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' - ' (I32 5)))' - ' (I32 5))' + ' (Let eval __cse_1 (I32 1)' + ' (ApplyBinaryPrimOp `*`' + ' (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' + ' (I32 5)))' + ' (I32 5))' ) assert expected == CSERenderer()(cond) @@ -507,12 +581,13 @@ def test_shadowing(self): inner = ir.Let('row', sum, sum) outer = ir.Let('row', ir.I32(5), inner) expected = ( - '(Let __cse_2 (I32 2)' - ' (Let row (I32 5)' 
- ' (Let __cse_1 (ApplyBinaryPrimOp `*` (Ref row) (Ref __cse_2))' - ' (Let row (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' - ' (Let __cse_3 (ApplyBinaryPrimOp `*` (Ref row) (Ref __cse_2))' - ' (ApplyBinaryPrimOp `+` (Ref __cse_3) (Ref __cse_3)))))))') + '(Let eval __cse_2 (I32 2)' + ' (Let eval row (I32 5)' + ' (Let eval __cse_1 (ApplyBinaryPrimOp `*` (Ref row) (Ref __cse_2))' + ' (Let eval row (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' + ' (Let eval __cse_3 (ApplyBinaryPrimOp `*` (Ref row) (Ref __cse_2))' + ' (ApplyBinaryPrimOp `+` (Ref __cse_3) (Ref __cse_3)))))))' + ) assert expected == CSERenderer()(outer) def test_agg_cse(self): @@ -525,14 +600,15 @@ def test_agg_cse(self): table_agg = ir.TableAggregate(table, ir.MakeTuple([outer_sum, filter])) expected = ( '(TableAggregate (TableRange 5 1)' - ' (AggLet __cse_1 False (GetField idx (Ref row))' - ' (AggLet __cse_3 False (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' - ' (Let __cse_2 (ApplyAggOp AggOp () ((Ref __cse_3)))' - ' (MakeTuple (0 1)' - ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2))' - ' (AggFilter False (True)' - ' (Let __cse_4 (ApplyAggOp AggOp () ((Ref __cse_3)))' - ' (ApplyBinaryPrimOp `+` (Ref __cse_4) (Ref __cse_4)))))))))') + ' (AggLet __cse_1 False (GetField idx (Ref row))' + ' (AggLet __cse_3 False (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' + ' (Let eval __cse_2 (ApplyAggOp AggOp () ((Ref __cse_3)))' + ' (MakeTuple (0 1)' + ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2))' + ' (AggFilter False (True)' + ' (Let eval __cse_4 (ApplyAggOp AggOp () ((Ref __cse_3)))' + ' (ApplyBinaryPrimOp `+` (Ref __cse_4) (Ref __cse_4)))))))))' + ) assert expected == CSERenderer()(table_agg) def test_init_op(self): @@ -541,14 +617,15 @@ def test_init_op(self): agg = ir.ApplyAggOp('CallStats', [sum], [sum]) top = ir.ApplyBinaryPrimOp('+', sum, agg) expected = ( - '(Let __cse_1 (I32 5)' + '(Let eval __cse_1 (I32 5)' ' (AggLet __cse_3 False (I32 5)' ' (ApplyBinaryPrimOp `+`' - ' (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' - ' (ApplyAggOp CallStats' - ' ((Let __cse_2 (I32 5)' - ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2))))' - ' ((ApplyBinaryPrimOp `+` (Ref __cse_3) (Ref __cse_3)))))))') + ' (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))' + ' (ApplyAggOp CallStats' + ' ((Let eval __cse_2 (I32 5)' + ' (ApplyBinaryPrimOp `+` (Ref __cse_2) (Ref __cse_2))))' + ' ((ApplyBinaryPrimOp `+` (Ref __cse_3) (Ref __cse_3)))))))' + ) assert expected == CSERenderer()(top) def test_agg_let(self): @@ -557,7 +634,7 @@ def test_agg_let(self): agglet = ir.AggLet('foo', ir.I32(2), sum, False) expected = ( '(AggLet foo False (I32 2)' - ' (Let __cse_1 (ApplyAggOp AggOp () ((Ref foo)))' + ' (Let eval __cse_1 (ApplyAggOp AggOp () ((Ref foo)))' ' (ApplyBinaryPrimOp `+` (Ref __cse_1) (Ref __cse_1))))' ) assert expected == CSERenderer()(agglet) @@ -565,14 +642,12 @@ def test_agg_let(self): def test_refs(self): table = ir.TableRange(10, 1) ref = ir.Ref('row', table.typ.row_type) - x = ir.TableMapRows(table, - ir.MakeStruct([('foo', ir.GetField(ref, 'idx')), - ('bar', ir.GetField(ref, 'idx'))])) + x = ir.TableMapRows(table, ir.MakeStruct([('foo', ir.GetField(ref, 'idx')), ('bar', ir.GetField(ref, 'idx'))])) expected = ( '(TableMapRows (TableRange 10 1)' - ' (MakeStruct' - ' (foo (GetField idx (Ref row)))' - ' (bar (GetField idx (Ref row)))))' + ' (MakeStruct' + ' (foo (GetField idx (Ref row)))' + ' (bar (GetField idx (Ref row)))))' ) assert expected == CSERenderer()(x) @@ -607,11 +682,12 @@ def 
_assert_encoding_roundtrip(value): hl.Call([]), hl.Call([1, 1]), hl.Call([17495, 17495]), - ] + ], ) def test_literal_encodings(value): _assert_encoding_roundtrip(value) + @pytest.mark.parametrize( 'value', [ @@ -620,7 +696,7 @@ def test_literal_encodings(value): np.array([1, 2, 3, 4]), np.array([[1, 2], [3, 4], [5, 6]]), np.array([[[[1]], [[2]]], [[[3]], [[4]]], [[[5]], [[6]]]]), - ] + ], ) def test_literal_ndarray_encodings(value): _assert_encoding_roundtrip(value) @@ -630,7 +706,7 @@ def test_literal_ndarray_encodings(value): def test_decoding_multiple_dicts(): dict = {0: 'a', 1: 'b', 2: 'c'} dict2 = {0: 'x', 1: 'y', 2: 'z'} - ht = hl.utils.range_table(1).annotate(indices = hl.array([0, 1, 2])) + ht = hl.utils.range_table(1).annotate(indices=hl.array([0, 1, 2])) ht.select(a=ht.indices.map(lambda i: hl.struct(x=hl.dict(dict).get(i), y=hl.dict(dict2).get(i)))).collect() diff --git a/hail/python/test/hail/test_no_context.py b/hail/python/test/hail/test_no_context.py index 010b1839b8a..ff87c6d82cb 100644 --- a/hail/python/test/hail/test_no_context.py +++ b/hail/python/test/hail/test_no_context.py @@ -1,8 +1,8 @@ import unittest + import hail as hl -from .helpers import * class Tests(unittest.TestCase): def test_get_reference_before_init(self): - hl.get_reference('GRCh37') # Should be no error + hl.get_reference('GRCh37') # Should be no error diff --git a/hail/python/test/hail/test_randomness.py b/hail/python/test/hail/test_randomness.py index 3a53f0213ad..e91cd3e047c 100644 --- a/hail/python/test/hail/test_randomness.py +++ b/hail/python/test/hail/test_randomness.py @@ -7,7 +7,7 @@ def test_table_explode(): hl.reset_global_randomness() ht = hl.utils.range_table(5) - ht = ht.annotate(x = hl.range(hl.rand_int32(5))) + ht = ht.annotate(x=hl.range(hl.rand_int32(5))) ht = ht.explode('x') expected = [ hl.Struct(idx=0, x=0), @@ -24,7 +24,7 @@ def test_table_explode(): hl.Struct(idx=3, x=2), hl.Struct(idx=4, x=0), hl.Struct(idx=4, x=1), - hl.Struct(idx=4, x=2) + hl.Struct(idx=4, x=2), ] actual = ht.collect() assert expected == actual @@ -33,14 +33,14 @@ def test_table_explode(): def test_table_key_by(): hl.reset_global_randomness() ht = hl.utils.range_table(5) - ht = ht.annotate(x = hl.rand_int32(5)) + ht = ht.annotate(x=hl.rand_int32(5)) ht = ht.key_by('x') expected = [ hl.Struct(idx=2, x=2), hl.Struct(idx=1, x=3), hl.Struct(idx=3, x=3), hl.Struct(idx=4, x=3), - hl.Struct(idx=0, x=4) + hl.Struct(idx=0, x=4), ] actual = ht.collect() assert expected == actual @@ -49,8 +49,8 @@ def test_table_key_by(): def test_table_annotate(): hl.reset_global_randomness() ht = hl.utils.range_table(5) - ht = ht.annotate(x = hl.rand_int32(5)) - ht = ht.annotate(y = ht.x * 10) + ht = ht.annotate(x=hl.rand_int32(5)) + ht = ht.annotate(y=ht.x * 10) expected = [ hl.Struct(idx=0, x=4, y=40), hl.Struct(idx=1, x=3, y=30), @@ -65,7 +65,7 @@ def test_table_annotate(): def test_matrix_table_entries(): hl.reset_global_randomness() mt = hl.utils.range_matrix_table(5, 2) - mt = mt.annotate_entries(x = hl.rand_int32(5)) + mt = mt.annotate_entries(x=hl.rand_int32(5)) expected = [ hl.Struct(row_idx=0, col_idx=0, x=0), hl.Struct(row_idx=0, col_idx=1, x=3), @@ -85,7 +85,7 @@ def test_matrix_table_entries(): def test_table_filter(): hl.reset_global_randomness() ht = hl.utils.range_table(5) - ht = ht.annotate(x = hl.rand_int32(5)) + ht = ht.annotate(x=hl.rand_int32(5)) ht = ht.filter(ht.x % 3 == 0) expected = [hl.Struct(idx=1, x=3), hl.Struct(idx=3, x=3), hl.Struct(idx=4, x=3)] actual = ht.collect() @@ -95,7 +95,7 @@ def 
test_table_filter(): def test_table_key_by_aggregate(): hl.reset_global_randomness() ht = hl.utils.range_table(5) - ht = ht.annotate(x = hl.rand_int32(5)) + ht = ht.annotate(x=hl.rand_int32(5)) ht = ht.group_by(ht.x).aggregate(y=hl.agg.count()) expected = [hl.Struct(x=2, y=1), hl.Struct(x=3, y=3), hl.Struct(x=4, y=1)] actual = ht.collect() diff --git a/hail/python/test/hail/typecheck/test_typecheck.py b/hail/python/test/hail/typecheck/test_typecheck.py index 846f369bffb..003496c749e 100644 --- a/hail/python/test/hail/typecheck/test_typecheck.py +++ b/hail/python/test/hail/typecheck/test_typecheck.py @@ -1,6 +1,20 @@ import unittest -from hail.typecheck.check import * +from hail.typecheck.check import ( + anytype, + dictof, + func_spec, + lazy, + nullable, + numeric, + oneof, + sequenceof, + sized_tupleof, + transformed, + tupleof, + typecheck, + typecheck_method, +) class Tests(unittest.TestCase): @@ -92,8 +106,7 @@ def good_signature_5(x): self.assertRaises(TypeError, lambda: good_signature_5("1", 2, 2)) self.assertRaises(TypeError, lambda: good_signature_5(("1", 5, 10), ("2", 10, 20))) - @typecheck(x=int, y=str, z=sequenceof(sized_tupleof(str, int, int)), - args=int) + @typecheck(x=int, y=str, z=sequenceof(sized_tupleof(str, int, int)), args=int) def good_signature_6(x, y, z, *args): pass @@ -137,19 +150,16 @@ def f(x): pass f('str') - f(u'unicode') + f('unicode') self.assertRaises(TypeError, lambda: f(['abc'])) def test_nested(self): - @typecheck( - x=int, - y=oneof(nullable(str), sequenceof(sequenceof(dictof(oneof(str, int), anytype)))) - ) + @typecheck(x=int, y=oneof(nullable(str), sequenceof(sequenceof(dictof(oneof(str, int), anytype))))) def f(x, y): pass f(5, None) - f(5, u'7') + f(5, '7') f(5, []) f(5, [[]]) f(5, [[{}]]) @@ -224,10 +234,10 @@ def bar(self, other): self.assertRaises(TypeError, lambda: foo.bar(2)) def test_coercion(self): - @typecheck(a=transformed((int, lambda x: 'int'), - (str, lambda x: 'str')), - b=sequenceof(dictof(str, transformed((int, lambda x: 'int'), - (str, lambda x: 'str'))))) + @typecheck( + a=transformed((int, lambda x: 'int'), (str, lambda x: 'str')), + b=sequenceof(dictof(str, transformed((int, lambda x: 'int'), (str, lambda x: 'str')))), + ) def foo(a, b): return a, b @@ -276,7 +286,10 @@ def test_complex_signature(self): def f(a, b='5', c=[10], *d, **e): pass - f(1, 'a', ) + f( + 1, + 'a', + ) f(1, foo={}) f(1, 'a', foo={}) f(1, c=[25, 2]) diff --git a/hail/python/test/hail/utils/test_deduplicate.py b/hail/python/test/hail/utils/test_deduplicate.py index 13d210ccc9d..31ac89d613d 100644 --- a/hail/python/test/hail/utils/test_deduplicate.py +++ b/hail/python/test/hail/utils/test_deduplicate.py @@ -25,7 +25,6 @@ def test_deduplicate_max_attempts(): def test_deduplicate_already_used(): - mappings, new_ids = deduplicate(['0', '0_1', '0'], - already_used={'0_1', '0_2'}) + mappings, new_ids = deduplicate(['0', '0_1', '0'], already_used={'0_1', '0_2'}) assert mappings == [('0_1', '0_1_1'), ('0', '0_3')] assert new_ids == ['0', '0_1_1', '0_3'] diff --git a/hail/python/test/hail/utils/test_genomic_range_table.py b/hail/python/test/hail/utils/test_genomic_range_table.py index 97a527313a6..223bf205a6c 100644 --- a/hail/python/test/hail/utils/test_genomic_range_table.py +++ b/hail/python/test/hail/utils/test_genomic_range_table.py @@ -3,27 +3,23 @@ def test_genomic_range_table_grch38(): actual = hl.utils.genomic_range_table(10, reference_genome='GRCh38').collect() - expected = [hl.Struct(locus=hl.Locus("chr1", pos + 1, reference_genome='GRCh38')) - for pos in 
range(10)] + expected = [hl.Struct(locus=hl.Locus("chr1", pos + 1, reference_genome='GRCh38')) for pos in range(10)] assert actual == expected def test_genomic_range_table_grch37(): actual = hl.utils.genomic_range_table(10, reference_genome='GRCh37').collect() - expected = [hl.Struct(locus=hl.Locus("1", pos + 1, reference_genome='GRCh37')) - for pos in range(10)] + expected = [hl.Struct(locus=hl.Locus("1", pos + 1, reference_genome='GRCh37')) for pos in range(10)] assert actual == expected def test_genomic_range_table_canfam3(): actual = hl.utils.genomic_range_table(10, reference_genome='CanFam3').collect() - expected = [hl.Struct(locus=hl.Locus("chr1", pos + 1, reference_genome='CanFam3')) - for pos in range(10)] + expected = [hl.Struct(locus=hl.Locus("chr1", pos + 1, reference_genome='CanFam3')) for pos in range(10)] assert actual == expected def test_genomic_range_table_grcm38(): actual = hl.utils.genomic_range_table(10, reference_genome='GRCm38').collect() - expected = [hl.Struct(locus=hl.Locus("1", pos + 1, reference_genome='GRCm38')) - for pos in range(10)] + expected = [hl.Struct(locus=hl.Locus("1", pos + 1, reference_genome='GRCm38')) for pos in range(10)] assert actual == expected diff --git a/hail/python/test/hail/utils/test_hl_hadoop_and_hail_fs.py b/hail/python/test/hail/utils/test_hl_hadoop_and_hail_fs.py index 41aede3e07a..dc896cb693e 100644 --- a/hail/python/test/hail/utils/test_hl_hadoop_and_hail_fs.py +++ b/hail/python/test/hail/utils/test_hl_hadoop_and_hail_fs.py @@ -1,14 +1,15 @@ +import os +import secrets from typing import Generator + import pytest -import secrets -import os import hail as hl -import hailtop.fs as fs from hail.context import _get_local_tmpdir -from hail.utils import hadoop_open, hadoop_copy, hadoop_ls -from hailtop.utils import secret_alnum_string +from hail.utils import hadoop_copy, hadoop_ls, hadoop_open from hail.utils.java import FatalError +from hailtop import fs +from hailtop.utils import secret_alnum_string from ..helpers import qobtest @@ -22,7 +23,7 @@ def touch(fs, filename: str): def tmpdir(request) -> Generator[str, None, None]: if request.param == 'local': tmpdir = _get_local_tmpdir(None) - tmpdir = tmpdir[len('file://'):] + tmpdir = tmpdir[len('file://') :] else: tmpdir = os.environ['HAIL_TEST_STORAGE_URI'] tmpdir = os.path.join(tmpdir, secret_alnum_string(5)) @@ -74,8 +75,7 @@ def test_hadoop_methods_3(tmpdir: str): f.write(d) f.write('\n') - hadoop_copy(f'{tmpdir}/test_out.txt.gz', - f'{tmpdir}/test_out.copy.txt.gz') + hadoop_copy(f'{tmpdir}/test_out.txt.gz', f'{tmpdir}/test_out.copy.txt.gz') with hadoop_open(f'{tmpdir}/test_out.copy.txt.gz') as f: data4 = [line.strip() for line in f] @@ -134,17 +134,16 @@ def test_hadoop_stat(tmpdir: str): f.write('\n') stat1 = hl.hadoop_stat(f'{tmpdir}') - assert stat1['is_dir'] == True + assert stat1['is_dir'] is True - hadoop_copy(f'{tmpdir}/test_hadoop_stat.txt.gz', - f'{tmpdir}/test_hadoop_stat.copy.txt.gz') + hadoop_copy(f'{tmpdir}/test_hadoop_stat.txt.gz', f'{tmpdir}/test_hadoop_stat.copy.txt.gz') stat2 = hl.hadoop_stat(f'{tmpdir}/test_hadoop_stat.copy.txt.gz') # The gzip format permits metadata which makes the compressed file's size unpredictable. In # practice, Hadoop creates a 175 byte file and gzip.GzipFile creates a 202 byte file. The 27 # extra bytes appear to include at least the filename (20 bytes) and a modification timestamp. 
assert stat2['size_bytes'] == 175 or stat2['size_bytes'] == 202 - assert stat2['is_dir'] == False + assert stat2['is_dir'] is False assert 'path' in stat2 diff --git a/hail/python/test/hail/utils/test_pickle.py b/hail/python/test/hail/utils/test_pickle.py index e51337848c7..f1a2e48ae5e 100644 --- a/hail/python/test/hail/utils/test_pickle.py +++ b/hail/python/test/hail/utils/test_pickle.py @@ -1,5 +1,7 @@ import pickle + import dill + import hail as hl diff --git a/hail/python/test/hail/utils/test_placement_tree.py b/hail/python/test/hail/utils/test_placement_tree.py index 55c70bd68b9..81434770f75 100644 --- a/hail/python/test/hail/utils/test_placement_tree.py +++ b/hail/python/test/hail/utils/test_placement_tree.py @@ -1,13 +1,12 @@ import unittest import hail as hl - from hail.utils.placement_tree import PlacementTree class Tests(unittest.TestCase): def test_realistic(self): - dtype = hl.dtype('''struct{ + dtype = hl.dtype("""struct{ locus: locus, alleles: array, rsid: str, @@ -34,7 +33,7 @@ def test_realistic(self): AF: array, AN: int32, homozygote_count: array, - call_rate: float64}}''') + call_rate: float64}}""") tree = PlacementTree.from_named_type('row', dtype) grid = tree.to_grid() assert len(grid) == 4 @@ -78,4 +77,5 @@ def test_realistic(self): ('AF', 1), ('AN', 1), ('homozygote_count', 1), - ('call_rate', 1)] + ('call_rate', 1), + ] diff --git a/hail/python/test/hail/utils/test_struct_repr_pprint.py b/hail/python/test/hail/utils/test_struct_repr_pprint.py index 338fa33f57b..69576b53320 100644 --- a/hail/python/test/hail/utils/test_struct_repr_pprint.py +++ b/hail/python/test/hail/utils/test_struct_repr_pprint.py @@ -1,6 +1,7 @@ -import hail as hl from pprint import pformat +import hail as hl + def test_repr_empty_struct(): assert repr(hl.Struct()) == 'Struct()' @@ -55,7 +56,9 @@ def test_pformat_struct_in_struct_some_non_identifiers1(): def test_pformat_struct_in_struct_some_non_identifiers2(): - assert pformat(hl.Struct(**{'x': 3, 'y ': 3, 'z': hl.Struct(a=5)})) == "Struct(**{'x': 3, 'y ': 3, 'z': Struct(a=5)})" + assert ( + pformat(hl.Struct(**{'x': 3, 'y ': 3, 'z': hl.Struct(a=5)})) == "Struct(**{'x': 3, 'y ': 3, 'z': Struct(a=5)})" + ) def test_pformat_small_struct_in_big_struct(): @@ -71,20 +74,6 @@ def test_pformat_small_struct_in_big_struct(): assert pformat(x) == expected -def test_pformat_big_struct_in_small_struct(): - x = hl.Struct(a5=hl.Struct(b0='', b1='na', b2='nana', b3='nanana', b5='ndasdfhjwafdhjskfdshjkfhdjksfhdsjk')) - expected = """ -Struct(a5=Struct(b0='', - b1='na', - b2='nana', - b3='nanana', - b5='ndasdfhjwafdhjskfdshjkfhdjksfhdsjk')) -""".strip() - assert pformat(x) == expected - - - - def test_pformat_big_struct_in_small_struct(): x = hl.Struct(a5=hl.Struct(b0='', b1='na', b2='nana', b3='nanana', b5='ndasdfhjwafdhjskfdshjkfhdjksfhdsjk')) expected = """ diff --git a/hail/python/test/hail/utils/test_utils.py b/hail/python/test/hail/utils/test_utils.py index be7f1ebfeba..e51d8a0f6fc 100644 --- a/hail/python/test/hail/utils/test_utils.py +++ b/hail/python/test/hail/utils/test_utils.py @@ -1,13 +1,25 @@ import json +import os import unittest +import pytest + import hail as hl -from hail.utils import * -from hail.utils.misc import escape_str, escape_id +from hail.utils import ( + HailUserError, + Interval, + Struct, + frozendict, + hadoop_copy, + hadoop_open, + range_table, + with_local_temp_file, +) from hail.utils.java import FatalError from hail.utils.linkedlist import LinkedList +from hail.utils.misc import escape_id, escape_str -from ..helpers import * 
+from ..helpers import fails_local_backend, fails_service_backend, qobtest, resource def normalize_path(path: str) -> str: @@ -21,7 +33,6 @@ def touch(filename): @qobtest class Tests(unittest.TestCase): - def test_hadoop_methods(self): data = ['foo', 'bar', 'baz'] data.extend(map(str, range(100))) @@ -79,7 +90,7 @@ def test_hadoop_mkdir_p(self): self.assertTrue(hl.hadoop_exists(resource('./some/foo/bar.txt'))) with hadoop_open(resource('./some/foo/bar.txt')) as f: - assert(f.read() == test_text) + assert f.read() == test_text hl.current_backend().fs.rmtree(resource('./some')) @@ -129,7 +140,10 @@ def test_hadoop_no_glob_in_bucket(self): except ValueError as err: assert f'glob pattern only allowed in path (e.g. not in bucket): {glob_in_bucket_url}' in err.args[0] except FatalError as err: - assert f"Invalid GCS bucket name 'glob*{bucket}': bucket name must contain only 'a-z0-9_.-' characters." in err.args[0] + assert ( + f"Invalid GCS bucket name 'glob*{bucket}': bucket name must contain only 'a-z0-9_.-' characters." + in err.args[0] + ) else: assert False @@ -137,11 +151,11 @@ def test_hadoop_ls_simple(self): with hl.TemporaryDirectory() as dirname: with hl.current_backend().fs.open(dirname + '/a', 'w') as fobj: fobj.write('hello world') - dirname = normalize_path(dirname) + _dirname = normalize_path(dirname) - results = hl.hadoop_ls(dirname + '/[a]') + results = hl.hadoop_ls(_dirname + '/[a]') assert len(results) == 1 - assert results[0]['path'] == dirname + '/a' + assert results[0]['path'] == _dirname + '/a' def test_hadoop_ls(self): path1 = resource('ls_test/f_50') @@ -170,7 +184,7 @@ def test_hadoop_ls_file_that_does_not_exist(self): except FileNotFoundError: pass except FatalError as err: - assert 'FileNotFoundException: a_file_that_does_not_exist' in err.args[0] + assert 'FileNotFoundException: file:/io/a_file_that_does_not_exist' in err.args[0] else: assert False @@ -182,20 +196,20 @@ def test_hadoop_glob_heterogenous_structure(self): touch(dirname + '/def/dog') touch(dirname + '/ghi/cat') touch(dirname + '/ghi/cat') - dirname = normalize_path(dirname) + _dirname = normalize_path(dirname) - actual = {x['path'] for x in hl.hadoop_ls(dirname + '/*/cat')} + actual = {x['path'] for x in hl.hadoop_ls(_dirname + '/*/cat')} expected = { - dirname + '/abc/cat', - dirname + '/def/cat', - dirname + '/ghi/cat', + _dirname + '/abc/cat', + _dirname + '/def/cat', + _dirname + '/ghi/cat', } assert actual == expected - actual = {x['path'] for x in hl.hadoop_ls(dirname + '/*/dog')} + actual = {x['path'] for x in hl.hadoop_ls(_dirname + '/*/dog')} expected = { - dirname + '/abc/dog', - dirname + '/def/dog', + _dirname + '/abc/dog', + _dirname + '/def/dog', } assert actual == expected @@ -254,8 +268,7 @@ def test_struct_ops(self): self.assertEqual(s.annotate(), s) self.assertEqual(s.annotate(x=5), Struct(a=1, b=2, c=3, x=5)) - self.assertEqual(s.annotate(**{'a': 5, 'x': 10, 'y': 15}), - Struct(a=5, b=2, c=3, x=10, y=15)) + self.assertEqual(s.annotate(**{'a': 5, 'x': 10, 'y': 15}), Struct(a=5, b=2, c=3, x=10, y=15)) def test_expr_exception_results_in_hail_user_error(self): df = range_table(10) @@ -296,7 +309,7 @@ def test_escape_id(self): self.assertEqual(escape_id("123abc"), "`123abc`") def test_frozen_dict(self): - self.assertEqual(frozendict({1:2, 4:7}), frozendict({1:2, 4:7})) + self.assertEqual(frozendict({1: 2, 4: 7}), frozendict({1: 2, 4: 7})) my_frozen_dict = frozendict({"a": "apple", "h": "hail"}) self.assertEqual(my_frozen_dict["a"], "apple") @@ -310,24 +323,18 @@ def test_frozen_dict(self): 
my_frozen_dict["a"] = "b" def test_json_encoder(self): - self.assertEqual( - json.dumps(frozendict({"foo": "bar"}), cls=hl.utils.JSONEncoder), - '{"foo": "bar"}' - ) + self.assertEqual(json.dumps(frozendict({"foo": "bar"}), cls=hl.utils.JSONEncoder), '{"foo": "bar"}') - self.assertEqual( - json.dumps(Struct(foo="bar"), cls=hl.utils.JSONEncoder), - '{"foo": "bar"}' - ) + self.assertEqual(json.dumps(Struct(foo="bar"), cls=hl.utils.JSONEncoder), '{"foo": "bar"}') self.assertEqual( json.dumps(Interval(start=1, end=10), cls=hl.utils.JSONEncoder), - '{"start": 1, "end": 10, "includes_start": true, "includes_end": false}' + '{"start": 1, "end": 10, "includes_start": true, "includes_end": false}', ) self.assertEqual( json.dumps(hl.Locus(1, 100, "GRCh38"), cls=hl.utils.JSONEncoder), - '{"contig": "1", "position": 100, "reference_genome": "GRCh38"}' + '{"contig": "1", "position": 100, "reference_genome": "GRCh38"}', ) @@ -358,39 +365,41 @@ def glob_tests_directory(init_hail): def test_hadoop_ls_folder_glob(glob_tests_directory): - expected = [glob_tests_directory + '/abc/ghi/123', - glob_tests_directory + '/abc/jkl/123'] + expected = [glob_tests_directory + '/abc/ghi/123', glob_tests_directory + '/abc/jkl/123'] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/abc/*/123')] assert set(actual) == set(expected) + def test_hadoop_ls_prefix_folder_glob_qmarks(glob_tests_directory): - expected = [glob_tests_directory + '/abc/ghi/78', - glob_tests_directory + '/abc/jkl/78'] + expected = [glob_tests_directory + '/abc/ghi/78', glob_tests_directory + '/abc/jkl/78'] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/abc/*/??')] assert set(actual) == set(expected) def test_hadoop_ls_two_folder_globs(glob_tests_directory): - expected = [glob_tests_directory + '/abc/ghi/123', - glob_tests_directory + '/abc/jkl/123', - glob_tests_directory + '/def/ghi/123', - glob_tests_directory + '/def/jkl/123'] + expected = [ + glob_tests_directory + '/abc/ghi/123', + glob_tests_directory + '/abc/jkl/123', + glob_tests_directory + '/def/ghi/123', + glob_tests_directory + '/def/jkl/123', + ] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/*/*/123')] assert set(actual) == set(expected) def test_hadoop_ls_two_folder_globs_and_two_qmarks(glob_tests_directory): - expected = [glob_tests_directory + '/abc/ghi/78', - glob_tests_directory + '/abc/jkl/78', - glob_tests_directory + '/def/ghi/78', - glob_tests_directory + '/def/jkl/78'] + expected = [ + glob_tests_directory + '/abc/ghi/78', + glob_tests_directory + '/abc/jkl/78', + glob_tests_directory + '/def/ghi/78', + glob_tests_directory + '/def/jkl/78', + ] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/*/*/??')] assert set(actual) == set(expected) def test_hadoop_ls_one_folder_glob_and_qmarks_in_multiple_components(glob_tests_directory): - expected = [glob_tests_directory + '/abc/ghi/78', - glob_tests_directory + '/def/ghi/78'] + expected = [glob_tests_directory + '/abc/ghi/78', glob_tests_directory + '/def/ghi/78'] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/*/?h?/??')] assert set(actual) == set(expected) @@ -408,33 +417,28 @@ def test_hadoop_ls_size_one_groups(glob_tests_directory): def test_hadoop_ls_component_with_only_groups(glob_tests_directory): - expected = [glob_tests_directory + '/abc/ghi/123', - glob_tests_directory + '/abc/ghi/!23', - glob_tests_directory + '/abc/ghi/?23', - glob_tests_directory + '/abc/ghi/456', - glob_tests_directory + '/abc/ghi/78'] + expected = [ + 
glob_tests_directory + '/abc/ghi/123', + glob_tests_directory + '/abc/ghi/!23', + glob_tests_directory + '/abc/ghi/?23', + glob_tests_directory + '/abc/ghi/456', + glob_tests_directory + '/abc/ghi/78', + ] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/abc/[g][h][i]/*')] assert set(actual) == set(expected) def test_hadoop_ls_negated_group(glob_tests_directory): - expected = [glob_tests_directory + '/abc/ghi/!23', - glob_tests_directory + '/abc/ghi/?23'] + expected = [glob_tests_directory + '/abc/ghi/!23', glob_tests_directory + '/abc/ghi/?23'] actual = [x['path'] for x in hl.hadoop_ls(glob_tests_directory + '/abc/ghi/[!1]23')] assert set(actual) == set(expected) def test_struct_rich_comparison(): """Asserts comparisons between structs and struct expressions are symmetric""" - struct = hl.Struct( - locus=hl.Locus(contig=10, position=60515, reference_genome='GRCh37'), - alleles=['C', 'T'] - ) - - expr = hl.struct( - locus=hl.locus(contig='10', pos=60515, reference_genome='GRCh37'), - alleles=['C', 'T'] - ) + struct = hl.Struct(locus=hl.Locus(contig=10, position=60515, reference_genome='GRCh37'), alleles=['C', 'T']) + + expr = hl.struct(locus=hl.locus(contig='10', pos=60515, reference_genome='GRCh37'), alleles=['C', 'T']) assert hl.eval(struct == expr) and hl.eval(expr == struct) assert hl.eval(struct >= expr) and hl.eval(expr >= struct) diff --git a/hail/python/test/hail/vds/test_combiner.py b/hail/python/test/hail/vds/test_combiner.py index 83af1297efc..ca18e5bbd61 100644 --- a/hail/python/test/hail/vds/test_combiner.py +++ b/hail/python/test/hail/vds/test_combiner.py @@ -1,24 +1,64 @@ import os -import pytest import hail as hl - from hail.utils.java import Env from hail.utils.misc import new_temp_file -from hail.vds.combiner import combine_variant_datasets, new_combiner, load_combiner, transform_gvcf -from hail.vds.combiner.combine import defined_entry_fields -from ..helpers import resource, skip_when_service_backend, test_timeout, qobtest - -all_samples = ['HG00308', 'HG00592', 'HG02230', 'NA18534', 'NA20760', - 'NA18530', 'HG03805', 'HG02223', 'HG00637', 'NA12249', - 'HG02224', 'NA21099', 'NA11830', 'HG01378', 'HG00187', - 'HG01356', 'HG02188', 'NA20769', 'HG00190', 'NA18618', - 'NA18507', 'HG03363', 'NA21123', 'HG03088', 'NA21122', - 'HG00373', 'HG01058', 'HG00524', 'NA18969', 'HG03833', - 'HG04158', 'HG03578', 'HG00339', 'HG00313', 'NA20317', - 'HG00553', 'HG01357', 'NA19747', 'NA18609', 'HG01377', - 'NA19456', 'HG00590', 'HG01383', 'HG00320', 'HG04001', - 'NA20796', 'HG00323', 'HG01384', 'NA18613', 'NA20802'] +from hail.vds.combiner import load_combiner, new_combiner + +from ..helpers import qobtest, resource, skip_when_service_backend, test_timeout + +all_samples = [ + 'HG00308', + 'HG00592', + 'HG02230', + 'NA18534', + 'NA20760', + 'NA18530', + 'HG03805', + 'HG02223', + 'HG00637', + 'NA12249', + 'HG02224', + 'NA21099', + 'NA11830', + 'HG01378', + 'HG00187', + 'HG01356', + 'HG02188', + 'NA20769', + 'HG00190', + 'NA18618', + 'NA18507', + 'HG03363', + 'NA21123', + 'HG03088', + 'NA21122', + 'HG00373', + 'HG01058', + 'HG00524', + 'NA18969', + 'HG03833', + 'HG04158', + 'HG03578', + 'HG00339', + 'HG00313', + 'NA20317', + 'HG00553', + 'HG01357', + 'NA19747', + 'NA18609', + 'HG01377', + 'NA19456', + 'HG00590', + 'HG01383', + 'HG00320', + 'HG04001', + 'NA20796', + 'HG00323', + 'HG01384', + 'NA18613', + 'NA20802', +] @qobtest @@ -26,19 +66,27 @@ def test_combiner_works(): _paths = ['gvcfs/HG00096.g.vcf.gz', 'gvcfs/HG00268.g.vcf.gz'] paths = [resource(p) for p in _paths] 
parts = [ - hl.Interval(start=hl.Locus('chr20', 17821257, reference_genome='GRCh38'), - end=hl.Locus('chr20', 18708366, reference_genome='GRCh38'), - includes_end=True), - hl.Interval(start=hl.Locus('chr20', 18708367, reference_genome='GRCh38'), - end=hl.Locus('chr20', 19776611, reference_genome='GRCh38'), - includes_end=True), - hl.Interval(start=hl.Locus('chr20', 19776612, reference_genome='GRCh38'), - end=hl.Locus('chr20', 21144633, reference_genome='GRCh38'), - includes_end=True) + hl.Interval( + start=hl.Locus('chr20', 17821257, reference_genome='GRCh38'), + end=hl.Locus('chr20', 18708366, reference_genome='GRCh38'), + includes_end=True, + ), + hl.Interval( + start=hl.Locus('chr20', 18708367, reference_genome='GRCh38'), + end=hl.Locus('chr20', 19776611, reference_genome='GRCh38'), + includes_end=True, + ), + hl.Interval( + start=hl.Locus('chr20', 19776612, reference_genome='GRCh38'), + end=hl.Locus('chr20', 21144633, reference_genome='GRCh38'), + includes_end=True, + ), ] with hl.TemporaryDirectory() as tmpdir: out = os.path.join(tmpdir, 'out.vds') - hl.vds.new_combiner(temp_path=tmpdir, output_path=out, gvcf_paths=paths, intervals=parts, reference_genome='GRCh38').run() + hl.vds.new_combiner( + temp_path=tmpdir, output_path=out, gvcf_paths=paths, intervals=parts, reference_genome='GRCh38' + ).run() comb = hl.vds.read_vds(out) # see https://github.com/hail-is/hail/issues/13367 for why these assertions are here @@ -55,42 +103,50 @@ def test_combiner_plan_round_trip_serialization(): paths = [os.path.join(resource('gvcfs'), '1kg_chr22', f'{s}.hg38.g.vcf.gz') for s in sample_names] plan_path = new_temp_file(extension='json') out_file = new_temp_file(extension='vds') - plan = new_combiner(gvcf_paths=paths, - output_path=out_file, - temp_path=Env.hc()._tmpdir, - save_path=plan_path, - reference_genome='GRCh38', - use_exome_default_intervals=True, - branch_factor=2, - batch_size=2) + plan = new_combiner( + gvcf_paths=paths, + output_path=out_file, + temp_path=Env.hc()._tmpdir, + save_path=plan_path, + reference_genome='GRCh38', + use_exome_default_intervals=True, + branch_factor=2, + batch_size=2, + ) plan.save() plan_loaded = load_combiner(plan_path) assert plan == plan_loaded + def test_reload_combiner_plan(): sample_names = all_samples[:5] paths = [os.path.join(resource('gvcfs'), '1kg_chr22', f'{s}.hg38.g.vcf.gz') for s in sample_names] plan_path = new_temp_file(extension='json') out_file = new_temp_file(extension='vds') - plan = new_combiner(gvcf_paths=paths, - output_path=out_file, - temp_path=Env.hc()._tmpdir, - save_path=plan_path, - reference_genome='GRCh38', - use_exome_default_intervals=True, - branch_factor=2, - batch_size=2) + plan = new_combiner( + gvcf_paths=paths, + output_path=out_file, + temp_path=Env.hc()._tmpdir, + save_path=plan_path, + reference_genome='GRCh38', + use_exome_default_intervals=True, + branch_factor=2, + batch_size=2, + ) plan.save() - plan_loaded = new_combiner(gvcf_paths=paths, - output_path=out_file, - temp_path=Env.hc()._tmpdir, - save_path=plan_path, - reference_genome='GRCh38', - use_exome_default_intervals=True, - branch_factor=2, - batch_size=2) + plan_loaded = new_combiner( + gvcf_paths=paths, + output_path=out_file, + temp_path=Env.hc()._tmpdir, + save_path=plan_path, + reference_genome='GRCh38', + use_exome_default_intervals=True, + branch_factor=2, + batch_size=2, + ) assert plan == plan_loaded + def test_move_load_combiner_plan(): fs = hl.current_backend().fs sample_names = all_samples[:5] @@ -98,14 +154,16 @@ def 
test_move_load_combiner_plan(): plan_path = new_temp_file(extension='json') out_file = new_temp_file(extension='vds') new_plan_path = new_temp_file(extension='json') - plan = new_combiner(gvcf_paths=paths, - output_path=out_file, - temp_path=Env.hc()._tmpdir, - save_path=plan_path, - reference_genome='GRCh38', - use_exome_default_intervals=True, - branch_factor=2, - batch_size=2) + plan = new_combiner( + gvcf_paths=paths, + output_path=out_file, + temp_path=Env.hc()._tmpdir, + save_path=plan_path, + reference_genome='GRCh38', + use_exome_default_intervals=True, + branch_factor=2, + batch_size=2, + ) plan.save() fs.copy(plan_path, new_plan_path) plan_loaded = load_combiner(new_plan_path) @@ -115,7 +173,9 @@ def test_move_load_combiner_plan(): @test_timeout(10 * 60) -@skip_when_service_backend(reason='Combiner makes extensive use of the Backend API which are serviced by starting a Hail Batch job to execute them. This test will be too slow until we change the combiner to use many fewer executes.') +@skip_when_service_backend( + reason='Combiner makes extensive use of the Backend API which are serviced by starting a Hail Batch job to execute them. This test will be too slow until we change the combiner to use many fewer executes.' +) def test_combiner_run(): tmpdir = new_temp_file() samples = all_samples[:5] @@ -128,22 +188,32 @@ def test_combiner_run(): parts = hl.eval([hl.parse_locus_interval('chr22:start-end', reference_genome='GRCh38')]) for input_gvcf, path in zip(input_paths[:2], final_paths_individual[:2]): - combiner = hl.vds.new_combiner(output_path=path, intervals=parts, - temp_path=tmpdir, - gvcf_paths=[input_gvcf], - reference_genome='GRCh38') + combiner = hl.vds.new_combiner( + output_path=path, intervals=parts, temp_path=tmpdir, gvcf_paths=[input_gvcf], reference_genome='GRCh38' + ) combiner.run() - combiner = hl.vds.new_combiner(output_path=final_path_1, intervals=parts, temp_path=tmpdir, - gvcf_paths=input_paths[2:], vds_paths=final_paths_individual[:2], - reference_genome='GRCh38', - branch_factor=2, batch_size=2) + combiner = hl.vds.new_combiner( + output_path=final_path_1, + intervals=parts, + temp_path=tmpdir, + gvcf_paths=input_paths[2:], + vds_paths=final_paths_individual[:2], + reference_genome='GRCh38', + branch_factor=2, + batch_size=2, + ) combiner.run() - combiner2 = hl.vds.new_combiner(output_path=final_path_2, intervals=parts, temp_path=tmpdir, - gvcf_paths=input_paths, - reference_genome='GRCh38', - branch_factor=2, batch_size=2) + combiner2 = hl.vds.new_combiner( + output_path=final_path_2, + intervals=parts, + temp_path=tmpdir, + gvcf_paths=input_paths, + reference_genome='GRCh38', + branch_factor=2, + batch_size=2, + ) combiner2.run() assert hl.vds.read_vds(final_path_1)._same(hl.vds.read_vds(final_path_2)) @@ -153,14 +223,16 @@ def test_combiner_manual_filtration(): sample_names = all_samples[:2] paths = [os.path.join(resource('gvcfs'), '1kg_chr22', f'{s}.hg38.g.vcf.gz') for s in sample_names] out_file = new_temp_file(extension='vds') - plan = new_combiner(gvcf_paths=paths, - output_path=out_file, - temp_path=Env.hc()._tmpdir, - reference_genome='GRCh38', - use_exome_default_intervals=True, - gvcf_reference_entry_fields_to_keep=['GQ'], - gvcf_info_to_keep=['ExcessHet'], - force=True) + plan = new_combiner( + gvcf_paths=paths, + output_path=out_file, + temp_path=Env.hc()._tmpdir, + reference_genome='GRCh38', + use_exome_default_intervals=True, + gvcf_reference_entry_fields_to_keep=['GQ'], + gvcf_info_to_keep=['ExcessHet'], + force=True, + ) assert 
plan._gvcf_info_to_keep == {'ExcessHet'} @@ -178,19 +250,22 @@ def test_ref_block_max_len_propagates_in_combiner(): for i, gvcf in enumerate(gvcfs): p = os.path.join(tmpdir, f'{i}.vds') vds_paths.append(p) - c = hl.vds.new_combiner(output_path=p, temp_path=tmpdir, - gvcf_paths=[os.path.join(resource('gvcfs'), '1kg_chr22', gvcf)], - reference_genome='GRCh38', - import_interval_size=1000000000) + c = hl.vds.new_combiner( + output_path=p, + temp_path=tmpdir, + gvcf_paths=[os.path.join(resource('gvcfs'), '1kg_chr22', gvcf)], + reference_genome='GRCh38', + import_interval_size=1000000000, + ) c.run() for path in vds_paths: vds = hl.vds.read_vds(path) assert hl.vds.VariantDataset.ref_block_max_length_field in vds.reference_data.globals final_path = os.path.join(tmpdir, 'final.vds') - hl.vds.new_combiner(output_path=final_path, temp_path=tmpdir, - vds_paths=vds_paths, - reference_genome='GRCh38').run() + hl.vds.new_combiner( + output_path=final_path, temp_path=tmpdir, vds_paths=vds_paths, reference_genome='GRCh38' + ).run() vds = hl.vds.read_vds(final_path) assert hl.vds.VariantDataset.ref_block_max_length_field in vds.reference_data.globals @@ -199,13 +274,22 @@ def test_custom_call_fields(): _paths = ['gvcfs/HG00096.g.vcf.gz', 'gvcfs/HG00268.g.vcf.gz'] paths = [resource(p) for p in _paths] parts = [ - hl.Interval(start=hl.Locus('chr20', 17821257, reference_genome='GRCh38'), - end=hl.Locus('chr20', 21144633, reference_genome='GRCh38'), - includes_end=True), + hl.Interval( + start=hl.Locus('chr20', 17821257, reference_genome='GRCh38'), + end=hl.Locus('chr20', 21144633, reference_genome='GRCh38'), + includes_end=True, + ), ] with hl.TemporaryDirectory() as tmpdir: out = os.path.join(tmpdir, 'out.vds') - hl.vds.new_combiner(temp_path=tmpdir, output_path=out, gvcf_paths=paths, intervals=parts, call_fields=[], reference_genome='GRCh38').run() + hl.vds.new_combiner( + temp_path=tmpdir, + output_path=out, + gvcf_paths=paths, + intervals=parts, + call_fields=[], + reference_genome='GRCh38', + ).run() comb = hl.vds.read_vds(out) assert 'LPGT' in comb.variant_data.entry diff --git a/hail/python/test/hail/vds/test_vds.py b/hail/python/test/hail/vds/test_vds.py index 660a52ca29f..77576e0e5e7 100644 --- a/hail/python/test/hail/vds/test_vds.py +++ b/hail/python/test/hail/vds/test_vds.py @@ -1,24 +1,36 @@ import os + import pytest import hail as hl from hail.utils import new_temp_file from hail.vds.combiner.combine import defined_entry_fields -from ..helpers import resource, test_timeout, qobtest + +from ..helpers import qobtest, resource, test_timeout # run this method to regenerate the combined VDS from 5 samples def generate_5_sample_vds(): - paths = [os.path.join(resource('gvcfs'), '1kg_chr22', path) for path in ['HG00187.hg38.g.vcf.gz', - 'HG00190.hg38.g.vcf.gz', - 'HG00308.hg38.g.vcf.gz', - 'HG00313.hg38.g.vcf.gz', - 'HG00320.hg38.g.vcf.gz']] + paths = [ + os.path.join(resource('gvcfs'), '1kg_chr22', path) + for path in [ + 'HG00187.hg38.g.vcf.gz', + 'HG00190.hg38.g.vcf.gz', + 'HG00308.hg38.g.vcf.gz', + 'HG00313.hg38.g.vcf.gz', + 'HG00320.hg38.g.vcf.gz', + ] + ] parts = [ - hl.Interval(start=hl.Struct(locus=hl.Locus('chr22', 1, reference_genome='GRCh38')), - end=hl.Struct(locus=hl.Locus('chr22', hl.get_reference('GRCh38').contig_length('chr22') - 1, - reference_genome='GRCh38')), - includes_end=True) + hl.Interval( + start=hl.Struct(locus=hl.Locus('chr22', 1, reference_genome='GRCh38')), + end=hl.Struct( + locus=hl.Locus( + 'chr22', hl.get_reference('GRCh38').contig_length('chr22') - 1, 
reference_genome='GRCh38' + ) + ), + includes_end=True, + ) ] vcfs = hl.import_gvcfs(paths, parts, reference_genome='GRCh38', array_elements_required=False) to_keep = defined_entry_fields(vcfs[0].filter_rows(hl.is_defined(vcfs[0].info.END)), 100_000) @@ -33,14 +45,22 @@ def test_validate(): with pytest.raises(ValueError): hl.vds.VariantDataset( - vds.reference_data.annotate_rows(arr=[0, 1]).explode_rows('arr'), - vds.variant_data).validate() + vds.reference_data.annotate_rows(arr=[0, 1]).explode_rows('arr'), vds.variant_data + ).validate() with pytest.raises(ValueError): hl.vds.VariantDataset( vds.reference_data.annotate_entries( - END=hl.or_missing(vds.reference_data.locus.position % 2 == 0, vds.reference_data.END)), - vds.variant_data).validate() + END=hl.or_missing(vds.reference_data.locus.position % 2 == 0, vds.reference_data.END) + ), + vds.variant_data, + ).validate() + + with pytest.raises(ValueError): + hl.vds.VariantDataset( + vds.reference_data.annotate_entries(END=vds.reference_data.END + 1), + vds.variant_data, + ).validate() @qobtest @@ -80,12 +100,12 @@ def test_sampleqc_old_new_equivalence(): 'n_star', 'r_ti_tv', 'r_het_hom_var', - 'r_insertion_deletion' + 'r_insertion_deletion', ] - assert res.aggregate_cols(hl.all( - *(hl.agg.all(res.sample_qc[field] == res.sample_qc_new[field]) for field in fields_to_test) - )) + assert res.aggregate_cols( + hl.all(*(hl.agg.all(res.sample_qc[field] == res.sample_qc_new[field]) for field in fields_to_test)) + ) def test_sampleqc_gq_dp(): @@ -95,21 +115,19 @@ def test_sampleqc_gq_dp(): assert hl.eval(sqc.index_globals()) == hl.Struct(gq_bins=(0, 20, 60), dp_bins=(0, 1, 10, 20, 30)) hg00320 = sqc.filter(sqc.s == 'HG00320').select('bases_over_gq_threshold', 'bases_over_dp_threshold').collect()[0] - assert hg00320 == hl.Struct(s='HG00320', - bases_over_gq_threshold=(334822, 515, 82), - bases_over_dp_threshold=(334822, 10484, 388, 111, 52)) + assert hg00320 == hl.Struct( + s='HG00320', bases_over_gq_threshold=(334822, 515, 82), bases_over_dp_threshold=(334822, 10484, 388, 111, 52) + ) def test_sampleqc_singleton_r_ti_tv(): vds = hl.vds.read_vds(os.path.join(resource('vds'), '1kg_chr22_5_samples.vds')) sqc = hl.vds.sample_qc(vds) - hg00313 = sqc.filter(sqc.s == 'HG00313').select('r_ti_tv_singleton', 'n_singleton_ti', 'n_singleton_tv').collect()[0] - assert hg00313 == hl.Struct(s='HG00313', - r_ti_tv_singleton=4.0, - n_singleton_ti=4, - n_singleton_tv=1) - + hg00313 = ( + sqc.filter(sqc.s == 'HG00313').select('r_ti_tv_singleton', 'n_singleton_ti', 'n_singleton_tv').collect()[0] + ) + assert hg00313 == hl.Struct(s='HG00313', r_ti_tv_singleton=4.0, n_singleton_ti=4, n_singleton_tv=1) def test_filter_samples_and_merge(): @@ -159,9 +177,11 @@ def test_segment_intervals(): contig_len = vds.reference_data.locus.dtype.reference_genome.lengths['chr22'] breakpoints = hl.literal([*range(1, contig_len, 5_000_000), contig_len]) - intervals = hl.range(hl.len(breakpoints) - 1) \ - .map(lambda i: hl.struct( - interval=hl.locus_interval('chr22', breakpoints[i], breakpoints[i + 1], reference_genome='GRCh38'))) + intervals = hl.range(hl.len(breakpoints) - 1).map( + lambda i: hl.struct( + interval=hl.locus_interval('chr22', breakpoints[i], breakpoints[i + 1], reference_genome='GRCh38') + ) + ) intervals_ht = hl.Table.parallelize(intervals, key='interval') path = new_temp_file() @@ -178,7 +198,8 @@ def test_segment_intervals(): before = vds.reference_data sum_per_sample_before = before.select_cols( - ref_block_bases=hl.agg.sum(before.END + 1 - 
before.locus.position)).cols() + ref_block_bases=hl.agg.sum(before.END + 1 - before.locus.position) + ).cols() sum_per_sample_after = after.select_cols(ref_block_bases=hl.agg.sum(after.END + 1 - after.locus.position)).cols() before_coverage = sum_per_sample_before.collect() @@ -195,38 +216,102 @@ def test_interval_coverage(): intervals = hl.Table.parallelize( list(hl.struct(interval=hl.parse_locus_interval(x, reference_genome='GRCh38')) for x in [interval1, interval2]), - key='interval') + key='interval', + ) checkpoint_path = new_temp_file() r = hl.vds.interval_coverage(vds, intervals, gq_thresholds=(1, 21), dp_thresholds=(0, 1, 6)).checkpoint( - checkpoint_path) - assert r.aggregate_rows(hl.agg.collect((hl.format('%s:%d-%d', r.interval.start.contig, r.interval.start.position, - r.interval.end.position), r.interval_size))) == [(interval1, 10), - (interval2, 9)] + checkpoint_path + ) + assert r.aggregate_rows( + hl.agg.collect(( + hl.format('%s:%d-%d', r.interval.start.contig, r.interval.start.position, r.interval.end.position), + r.interval_size, + )) + ) == [(interval1, 10), (interval2, 9)] observed = r.aggregate_entries(hl.agg.collect(r.entry)) expected = [ - hl.Struct(bases_over_gq_threshold=(10, 0), bases_over_dp_threshold=(10, 10, 5), sum_dp=55, - fraction_over_gq_threshold=(1.0, 0.0), fraction_over_dp_threshold=(1.0, 1.0, 0.5), mean_dp=5.5), - hl.Struct(bases_over_gq_threshold=(10, 0), bases_over_dp_threshold=(10, 10, 0), sum_dp=45, - fraction_over_gq_threshold=(1.0, 0.0), fraction_over_dp_threshold=(1.0, 1.0, 0), mean_dp=4.5), - hl.Struct(bases_over_gq_threshold=(0, 0), bases_over_dp_threshold=(10, 0, 0), sum_dp=0, - fraction_over_gq_threshold=(0.0, 0.0), fraction_over_dp_threshold=(1.0, 0, 0), mean_dp=0), - hl.Struct(bases_over_gq_threshold=(10, 0), bases_over_dp_threshold=(10, 10, 0), sum_dp=30, - fraction_over_gq_threshold=(1.0, 0.0), fraction_over_dp_threshold=(1.0, 1.0, 0.0), mean_dp=3.0), - hl.Struct(bases_over_gq_threshold=(9, 0), bases_over_dp_threshold=(10, 10, 0), sum_dp=10, - fraction_over_gq_threshold=(0.9, 0.0), fraction_over_dp_threshold=(1.0, 1.0, 0.0), mean_dp=1.0), - - hl.Struct(bases_over_gq_threshold=(9, 9), bases_over_dp_threshold=(9, 9, 9), sum_dp=153, - fraction_over_gq_threshold=(1.0, 1.0), fraction_over_dp_threshold=(1.0, 1.0, 1.0), mean_dp=17.0), - hl.Struct(bases_over_gq_threshold=(9, 9), bases_over_dp_threshold=(9, 9, 9), sum_dp=159, - fraction_over_gq_threshold=(1.0, 1.0), fraction_over_dp_threshold=(1.0, 1.0, 1.0), mean_dp=159 / 9), - hl.Struct(bases_over_gq_threshold=(9, 9), bases_over_dp_threshold=(9, 9, 9), sum_dp=98, - fraction_over_gq_threshold=(1.0, 1.0), fraction_over_dp_threshold=(1.0, 1.0, 1.0), mean_dp=98 / 9), - hl.Struct(bases_over_gq_threshold=(9, 9), bases_over_dp_threshold=(9, 9, 9), sum_dp=72, - fraction_over_gq_threshold=(1.0, 1.0), fraction_over_dp_threshold=(1.0, 1.0, 1.0), mean_dp=8), - hl.Struct(bases_over_gq_threshold=(9, 0), bases_over_dp_threshold=(9, 9, 0), sum_dp=20, - fraction_over_gq_threshold=(1.0, 0.0), fraction_over_dp_threshold=(1.0, 1.0, 0.0), mean_dp=2 / 9), + hl.Struct( + bases_over_gq_threshold=(10, 0), + bases_over_dp_threshold=(10, 10, 5), + sum_dp=55, + fraction_over_gq_threshold=(1.0, 0.0), + fraction_over_dp_threshold=(1.0, 1.0, 0.5), + mean_dp=5.5, + ), + hl.Struct( + bases_over_gq_threshold=(10, 0), + bases_over_dp_threshold=(10, 10, 0), + sum_dp=45, + fraction_over_gq_threshold=(1.0, 0.0), + fraction_over_dp_threshold=(1.0, 1.0, 0), + mean_dp=4.5, + ), + hl.Struct( + bases_over_gq_threshold=(0, 0), + 
bases_over_dp_threshold=(10, 0, 0), + sum_dp=0, + fraction_over_gq_threshold=(0.0, 0.0), + fraction_over_dp_threshold=(1.0, 0, 0), + mean_dp=0, + ), + hl.Struct( + bases_over_gq_threshold=(10, 0), + bases_over_dp_threshold=(10, 10, 0), + sum_dp=30, + fraction_over_gq_threshold=(1.0, 0.0), + fraction_over_dp_threshold=(1.0, 1.0, 0.0), + mean_dp=3.0, + ), + hl.Struct( + bases_over_gq_threshold=(9, 0), + bases_over_dp_threshold=(10, 10, 0), + sum_dp=10, + fraction_over_gq_threshold=(0.9, 0.0), + fraction_over_dp_threshold=(1.0, 1.0, 0.0), + mean_dp=1.0, + ), + hl.Struct( + bases_over_gq_threshold=(9, 9), + bases_over_dp_threshold=(9, 9, 9), + sum_dp=153, + fraction_over_gq_threshold=(1.0, 1.0), + fraction_over_dp_threshold=(1.0, 1.0, 1.0), + mean_dp=17.0, + ), + hl.Struct( + bases_over_gq_threshold=(9, 9), + bases_over_dp_threshold=(9, 9, 9), + sum_dp=159, + fraction_over_gq_threshold=(1.0, 1.0), + fraction_over_dp_threshold=(1.0, 1.0, 1.0), + mean_dp=159 / 9, + ), + hl.Struct( + bases_over_gq_threshold=(9, 9), + bases_over_dp_threshold=(9, 9, 9), + sum_dp=98, + fraction_over_gq_threshold=(1.0, 1.0), + fraction_over_dp_threshold=(1.0, 1.0, 1.0), + mean_dp=98 / 9, + ), + hl.Struct( + bases_over_gq_threshold=(9, 9), + bases_over_dp_threshold=(9, 9, 9), + sum_dp=72, + fraction_over_gq_threshold=(1.0, 1.0), + fraction_over_dp_threshold=(1.0, 1.0, 1.0), + mean_dp=8, + ), + hl.Struct( + bases_over_gq_threshold=(9, 0), + bases_over_dp_threshold=(9, 9, 0), + sum_dp=20, + fraction_over_gq_threshold=(1.0, 0.0), + fraction_over_dp_threshold=(1.0, 1.0, 0.0), + mean_dp=2 / 9, + ), ] for i in range(len(expected)): @@ -248,36 +333,28 @@ def test_impute_sex_chr_ploidy_from_interval_coverage(): y_interval_1 = hl.parse_locus_interval('Y:10-20', reference_genome='GRCh37') y_interval_2 = hl.parse_locus_interval('Y:25-30', reference_genome='GRCh37') - mt = hl.Table.parallelize([hl.Struct(s='sample_xx', interval=norm_interval_1, sum_dp=195), - hl.Struct(s='sample_xx', interval=norm_interval_2, sum_dp=55), - hl.Struct(s='sample_xx', interval=x_interval_1, sum_dp=95), - hl.Struct(s='sample_xx', interval=x_interval_2, sum_dp=85), - hl.Struct(s='sample_xy', interval=norm_interval_1, sum_dp=190), - hl.Struct(s='sample_xy', interval=norm_interval_2, sum_dp=85), - hl.Struct(s='sample_xy', interval=x_interval_1, sum_dp=61), - hl.Struct(s='sample_xy', interval=x_interval_2, sum_dp=49), - hl.Struct(s='sample_xy', interval=y_interval_1, sum_dp=54), - hl.Struct(s='sample_xy', interval=y_interval_2, sum_dp=45)], - schema=hl.dtype( - 'struct{s:str,interval:interval>,sum_dp:int32}')).to_matrix_table( - row_key=['interval'], col_key=['s']) + mt = hl.Table.parallelize( + [ + hl.Struct(s='sample_xx', interval=norm_interval_1, sum_dp=195), + hl.Struct(s='sample_xx', interval=norm_interval_2, sum_dp=55), + hl.Struct(s='sample_xx', interval=x_interval_1, sum_dp=95), + hl.Struct(s='sample_xx', interval=x_interval_2, sum_dp=85), + hl.Struct(s='sample_xy', interval=norm_interval_1, sum_dp=190), + hl.Struct(s='sample_xy', interval=norm_interval_2, sum_dp=85), + hl.Struct(s='sample_xy', interval=x_interval_1, sum_dp=61), + hl.Struct(s='sample_xy', interval=x_interval_2, sum_dp=49), + hl.Struct(s='sample_xy', interval=y_interval_1, sum_dp=54), + hl.Struct(s='sample_xy', interval=y_interval_2, sum_dp=45), + ], + schema=hl.dtype('struct{s:str,interval:interval>,sum_dp:int32}'), + ).to_matrix_table(row_key=['interval'], col_key=['s']) mt = mt.annotate_rows(interval_size=mt.interval.end.position - mt.interval.start.position) r = 
hl.vds.impute_sex_chr_ploidy_from_interval_coverage(mt, normalization_contig='20') assert r.collect() == [ - hl.Struct(s='sample_xx', - autosomal_mean_dp=10.0, - x_mean_dp=9.0, - x_ploidy=1.8, - y_mean_dp=0.0, - y_ploidy=0.0), - hl.Struct(s='sample_xy', - autosomal_mean_dp=11.0, - x_mean_dp=5.5, - x_ploidy=1.0, - y_mean_dp=6.6, - y_ploidy=1.2) + hl.Struct(s='sample_xx', autosomal_mean_dp=10.0, x_mean_dp=9.0, x_ploidy=1.8, y_mean_dp=0.0, y_ploidy=0.0), + hl.Struct(s='sample_xy', autosomal_mean_dp=11.0, x_mean_dp=5.5, x_ploidy=1.0, y_mean_dp=6.6, y_ploidy=1.2), ] @@ -288,55 +365,127 @@ def test_impute_sex_chr_ploidy_from_interval_coverage(): def get_impute_sex_chromosome_ploidy_ref_mt(): ref_blocks = [ hl.Struct(s='sample_xx', locus=hl.Locus('22', 1000000, 'GRCh37'), END=2000000, GQ=15, DP=5), - hl.Struct(s='sample_xx', locus=hl.Locus('X', X_PAR_END-10, 'GRCh37'), END=X_PAR_END+9, GQ=18, DP=6), - hl.Struct(s='sample_xx', locus=hl.Locus('X', X_PAR_END+10, 'GRCh37'), END=X_PAR_END+29, GQ=15, DP=5), + hl.Struct(s='sample_xx', locus=hl.Locus('X', X_PAR_END - 10, 'GRCh37'), END=X_PAR_END + 9, GQ=18, DP=6), + hl.Struct(s='sample_xx', locus=hl.Locus('X', X_PAR_END + 10, 'GRCh37'), END=X_PAR_END + 29, GQ=15, DP=5), hl.Struct(s='sample_xy', locus=hl.Locus('22', 1000000, 'GRCh37'), END=2000000, GQ=15, DP=5), - hl.Struct(s='sample_xy', locus=hl.Locus('X', X_PAR_END-10, 'GRCh37'), END=X_PAR_END+9, GQ=9, DP=3), - hl.Struct(s='sample_xy', locus=hl.Locus('X', X_PAR_END+10, 'GRCh37'), END=X_PAR_END+29, GQ=6, DP=2), - hl.Struct(s='sample_xy', locus=hl.Locus('Y', Y_PAR_END-10, 'GRCh37'), END=Y_PAR_END+9, GQ=12, DP=4), - hl.Struct(s='sample_xy', locus=hl.Locus('Y', Y_PAR_END+10, 'GRCh37'), END=Y_PAR_END+29, GQ=9, DP=3), + hl.Struct(s='sample_xy', locus=hl.Locus('X', X_PAR_END - 10, 'GRCh37'), END=X_PAR_END + 9, GQ=9, DP=3), + hl.Struct(s='sample_xy', locus=hl.Locus('X', X_PAR_END + 10, 'GRCh37'), END=X_PAR_END + 29, GQ=6, DP=2), + hl.Struct(s='sample_xy', locus=hl.Locus('Y', Y_PAR_END - 10, 'GRCh37'), END=Y_PAR_END + 9, GQ=12, DP=4), + hl.Struct(s='sample_xy', locus=hl.Locus('Y', Y_PAR_END + 10, 'GRCh37'), END=Y_PAR_END + 29, GQ=9, DP=3), ] return hl.Table.parallelize( - ref_blocks, - schema=hl.dtype('struct{s:str,locus:locus,END:int32,GQ:int32,DP:int32}') + ref_blocks, schema=hl.dtype('struct{s:str,locus:locus,END:int32,GQ:int32,DP:int32}') ).to_matrix_table(row_key=['locus'], row_fields=[], col_key=['s']) def get_impute_sex_chromosome_ploidy_var_mt(): var = [ - hl.Struct(locus=hl.Locus('22', 2000021, 'GRCh37'), alleles=hl.array(["A", "C"]), s="sample_xx", LA=hl.array([0, 1]), - LGT=hl.call(0, 1, phased=False), GQ=15, DP=5), - hl.Struct(locus=hl.Locus('X', X_PAR_END-11, 'GRCh37'), alleles=hl.array(["A", "C"]), s="sample_xx", LA=hl.array([0, 1]), - LGT=hl.call(0, 1, phased=False), GQ=18, DP=6), - hl.Struct(locus=hl.Locus('X', X_PAR_END+30, 'GRCh37'), alleles=hl.array(["A", "C"]), s="sample_xx", - LA=hl.array([0, 1]), - LGT=hl.call(0, 1, phased=False), GQ=18, DP=6), - hl.Struct(locus=hl.Locus('X', X_PAR_END + 33, 'GRCh37'), alleles=hl.array(["A", "C", "G"]), s="sample_xx", - LA=hl.array([0, 1, 2]), - LGT=hl.call(0, 2, phased=False), GQ=15, DP=5), - hl.Struct(locus=hl.Locus('22', 2000021, 'GRCh37'), alleles=hl.array(["A", "C"]), s="sample_xy", LA=hl.array([0, 1]), - LGT=hl.call(0, 1, phased=False), GQ=15, DP=5), - hl.Struct(locus=hl.Locus('X', X_PAR_END - 11, 'GRCh37'), alleles=hl.array(["A", "C"]), s="sample_xy", - LA=hl.array([0, 1]), - LGT=hl.call(1, 1, phased=False), GQ=5, DP=2), - 
hl.Struct(locus=hl.Locus('X', X_PAR_END + 30, 'GRCh37'), alleles=hl.array(["A", "C"]), s="sample_xy", - LA=hl.array([0, 1]), - LGT=hl.call(1, 1, phased=False), GQ=7, DP=4), - hl.Struct(locus=hl.Locus('X', X_PAR_END + 33, 'GRCh37'), alleles=hl.array(["A", "C", "G"]), s="sample_xy", - LA=hl.array([0, 1, 2]), - LGT=hl.call(2, 2, phased=False), GQ=5, DP=3), - hl.Struct(locus=hl.Locus('Y', Y_PAR_END-11, 'GRCh37'), alleles=hl.array(["A", "C"]), s="sample_xy", LA=hl.array([0, 1]), - LGT=hl.call(1, 1, phased=False), GQ=9, DP=2), - hl.Struct(locus=hl.Locus('Y', Y_PAR_END+30, 'GRCh37'), alleles=hl.array(["A", "C"]), s="sample_xy", LA=hl.array([0, 1]), - LGT=hl.call(1, 1, phased=False), GQ=12, DP=4), - hl.Struct(locus=hl.Locus('Y', Y_PAR_END+33, 'GRCh37'), alleles=hl.array(["A", "C"]), s="sample_xy", - LA=hl.array([0, 1]), - LGT=hl.call(1, 1, phased=False), GQ=6, DP=2), + hl.Struct( + locus=hl.Locus('22', 2000021, 'GRCh37'), + alleles=hl.array(["A", "C"]), + s="sample_xx", + LA=hl.array([0, 1]), + LGT=hl.call(0, 1, phased=False), + GQ=15, + DP=5, + ), + hl.Struct( + locus=hl.Locus('X', X_PAR_END - 11, 'GRCh37'), + alleles=hl.array(["A", "C"]), + s="sample_xx", + LA=hl.array([0, 1]), + LGT=hl.call(0, 1, phased=False), + GQ=18, + DP=6, + ), + hl.Struct( + locus=hl.Locus('X', X_PAR_END + 30, 'GRCh37'), + alleles=hl.array(["A", "C"]), + s="sample_xx", + LA=hl.array([0, 1]), + LGT=hl.call(0, 1, phased=False), + GQ=18, + DP=6, + ), + hl.Struct( + locus=hl.Locus('X', X_PAR_END + 33, 'GRCh37'), + alleles=hl.array(["A", "C", "G"]), + s="sample_xx", + LA=hl.array([0, 1, 2]), + LGT=hl.call(0, 2, phased=False), + GQ=15, + DP=5, + ), + hl.Struct( + locus=hl.Locus('22', 2000021, 'GRCh37'), + alleles=hl.array(["A", "C"]), + s="sample_xy", + LA=hl.array([0, 1]), + LGT=hl.call(0, 1, phased=False), + GQ=15, + DP=5, + ), + hl.Struct( + locus=hl.Locus('X', X_PAR_END - 11, 'GRCh37'), + alleles=hl.array(["A", "C"]), + s="sample_xy", + LA=hl.array([0, 1]), + LGT=hl.call(1, 1, phased=False), + GQ=5, + DP=2, + ), + hl.Struct( + locus=hl.Locus('X', X_PAR_END + 30, 'GRCh37'), + alleles=hl.array(["A", "C"]), + s="sample_xy", + LA=hl.array([0, 1]), + LGT=hl.call(1, 1, phased=False), + GQ=7, + DP=4, + ), + hl.Struct( + locus=hl.Locus('X', X_PAR_END + 33, 'GRCh37'), + alleles=hl.array(["A", "C", "G"]), + s="sample_xy", + LA=hl.array([0, 1, 2]), + LGT=hl.call(2, 2, phased=False), + GQ=5, + DP=3, + ), + hl.Struct( + locus=hl.Locus('Y', Y_PAR_END - 11, 'GRCh37'), + alleles=hl.array(["A", "C"]), + s="sample_xy", + LA=hl.array([0, 1]), + LGT=hl.call(1, 1, phased=False), + GQ=9, + DP=2, + ), + hl.Struct( + locus=hl.Locus('Y', Y_PAR_END + 30, 'GRCh37'), + alleles=hl.array(["A", "C"]), + s="sample_xy", + LA=hl.array([0, 1]), + LGT=hl.call(1, 1, phased=False), + GQ=12, + DP=4, + ), + hl.Struct( + locus=hl.Locus('Y', Y_PAR_END + 33, 'GRCh37'), + alleles=hl.array(["A", "C"]), + s="sample_xy", + LA=hl.array([0, 1]), + LGT=hl.call(1, 1, phased=False), + GQ=6, + DP=2, + ), ] return hl.Table.parallelize( var, - schema=hl.dtype('struct{locus:locus,alleles:array,s:str,LA:array,LGT:call,GQ:int32,DP:int32}') + schema=hl.dtype( + 'struct{locus:locus,alleles:array,s:str,LA:array,LGT:call,GQ:int32,DP:int32}' + ), ).to_matrix_table(row_key=['locus', 'alleles'], col_key=['s']) @@ -344,7 +493,9 @@ def test_impute_sex_chromosome_ploidy_1(): ref_mt = get_impute_sex_chromosome_ploidy_ref_mt() var_mt = hl.Table.parallelize( [], - schema=hl.dtype('struct{locus:locus,alleles:array,s:str,LA:array,LGT:call,GQ:int32,DP:int32}') + schema=hl.dtype( + 
'struct{locus:locus,alleles:array,s:str,LA:array,LGT:call,GQ:int32,DP:int32}' + ), ).to_matrix_table(row_key=['locus', 'alleles'], col_key=['s']) vds = hl.vds.VariantDataset(ref_mt, var_mt) calling_intervals = [ @@ -356,18 +507,8 @@ def test_impute_sex_chromosome_ploidy_1(): r = hl.vds.impute_sex_chromosome_ploidy(vds, calling_intervals, normalization_contig='22') assert r.collect() == [ - hl.Struct(s='sample_xx', - autosomal_mean_dp=5.0, - x_mean_dp=5.5, - x_ploidy=2.2, - y_mean_dp=0.0, - y_ploidy=0.0), - hl.Struct(s='sample_xy', - autosomal_mean_dp=5.0, - x_mean_dp=2.5, - x_ploidy=1.0, - y_mean_dp=3.5, - y_ploidy=1.4) + hl.Struct(s='sample_xx', autosomal_mean_dp=5.0, x_mean_dp=5.5, x_ploidy=2.2, y_mean_dp=0.0, y_ploidy=0.0), + hl.Struct(s='sample_xy', autosomal_mean_dp=5.0, x_mean_dp=2.5, x_ploidy=1.0, y_mean_dp=3.5, y_ploidy=1.4), ] @@ -387,18 +528,8 @@ def test_impute_sex_chromosome_ploidy_2(): r = hl.vds.impute_sex_chromosome_ploidy(vds, calling_intervals, normalization_contig='22', use_variant_dataset=True) assert r.collect() == [ - hl.Struct(s='sample_xx', - autosomal_mean_dp=5.0, - x_mean_dp=5.0, - x_ploidy=2.0, - y_mean_dp=0.0, - y_ploidy=0.0), - hl.Struct(s='sample_xy', - autosomal_mean_dp=5.0, - x_mean_dp=3.0, - x_ploidy=1.2, - y_mean_dp=2.0, - y_ploidy=0.8) + hl.Struct(s='sample_xx', autosomal_mean_dp=5.0, x_mean_dp=5.0, x_ploidy=2.0, y_mean_dp=0.0, y_ploidy=0.0), + hl.Struct(s='sample_xy', autosomal_mean_dp=5.0, x_mean_dp=3.0, x_ploidy=1.2, y_mean_dp=2.0, y_ploidy=0.8), ] @@ -409,8 +540,7 @@ def test_filter_intervals_segment(): intervals = [hl.parse_locus_interval('chr22:10514784-10517000', reference_genome='GRCh38')] filt = hl.vds.filter_intervals(vds, intervals, split_reference_blocks=True) - assert hl.vds.to_dense_mt(filt)._same( - hl.filter_intervals(hl.vds.to_dense_mt(vds), intervals)) + assert hl.vds.to_dense_mt(filt)._same(hl.filter_intervals(hl.vds.to_dense_mt(vds), intervals)) ref = filt.reference_data var = filt.variant_data @@ -423,13 +553,10 @@ def test_filter_intervals_segment_table(): vds = hl.vds.read_vds(os.path.join(resource('vds'), '1kg_2samples_starts.vds')) intervals = [hl.parse_locus_interval('chr22:10514784-10517000', reference_genome='GRCh38')] - intervals_table = hl.Table.parallelize( - hl.array(intervals).map(lambda x: hl.struct(interval=x)), - key='interval') + intervals_table = hl.Table.parallelize(hl.array(intervals).map(lambda x: hl.struct(interval=x)), key='interval') filt = hl.vds.filter_intervals(vds, intervals_table, split_reference_blocks=True) - assert hl.vds.to_dense_mt(filt)._same( - hl.filter_intervals(hl.vds.to_dense_mt(vds), intervals)) + assert hl.vds.to_dense_mt(filt)._same(hl.filter_intervals(hl.vds.to_dense_mt(vds), intervals)) ref = filt.reference_data var = filt.variant_data @@ -443,8 +570,7 @@ def test_filter_intervals_default(): intervals = [hl.parse_locus_interval('chr22:10514784-10517000', reference_genome='GRCh38')] filt = hl.vds.filter_intervals(vds, intervals) - assert hl.vds.to_dense_mt(filt)._same( - hl.filter_intervals(hl.vds.to_dense_mt(vds), intervals)) + assert hl.vds.to_dense_mt(filt)._same(hl.filter_intervals(hl.vds.to_dense_mt(vds), intervals)) var = filt.variant_data assert var.aggregate_rows(hl.agg.all(intervals[0].contains(var.locus))) @@ -454,13 +580,10 @@ def test_filter_intervals_default_table(): vds = hl.vds.read_vds(os.path.join(resource('vds'), '1kg_2samples_starts.vds')) intervals = [hl.parse_locus_interval('chr22:10514784-10517000', reference_genome='GRCh38')] - intervals_table = hl.Table.parallelize( - 
hl.array(intervals).map(lambda x: hl.struct(interval=x)), - key='interval') + intervals_table = hl.Table.parallelize(hl.array(intervals).map(lambda x: hl.struct(interval=x)), key='interval') filt = hl.vds.filter_intervals(vds, intervals_table) - assert hl.vds.to_dense_mt(filt)._same( - hl.filter_intervals(hl.vds.to_dense_mt(vds), intervals)) + assert hl.vds.to_dense_mt(filt)._same(hl.filter_intervals(hl.vds.to_dense_mt(vds), intervals)) var = filt.variant_data assert var.aggregate_rows(hl.agg.all(intervals[0].contains(var.locus))) @@ -469,11 +592,11 @@ def test_filter_intervals_default_table(): def test_filter_chromosomes(): vds = hl.vds.read_vds(os.path.join(resource('vds'), '1kg_2samples_starts.vds')) - autosomes = [f'chr{i}' for i in range(1, 23)] sex_chrs = ['chrX', 'chrY'] all_chrs = autosomes + sex_chrs + def assert_contigs(vds, expected): expected_set = set(expected) @@ -501,16 +624,19 @@ def test_to_dense_mt(): dense = hl.vds.to_dense_mt(vds).select_entries('LGT', 'LA', 'GQ', 'DP') - assert dense.rows().select()._same( - vds.variant_data.rows().select()), "rows differ between variant data and dense mt" + assert ( + dense.rows().select()._same(vds.variant_data.rows().select()) + ), "rows differ between variant data and dense mt" assert dense.filter_entries(hl.is_defined(dense.LA))._same( - vds.variant_data.select_entries('LGT', 'LA', 'GQ', 'DP')), "cannot recover variant data" + vds.variant_data.select_entries('LGT', 'LA', 'GQ', 'DP') + ), "cannot recover variant data" as_dict = dense.aggregate_entries( - hl.dict(hl.zip(hl.agg.collect((hl.str(dense.locus), dense.s)), hl.agg.collect(dense.entry)))) + hl.dict(hl.zip(hl.agg.collect((hl.str(dense.locus), dense.s)), hl.agg.collect(dense.entry))) + ) - assert as_dict.get(('chr22:10514784', 'NA12891')) == None + assert as_dict.get(('chr22:10514784', 'NA12891')) is None assert as_dict.get(('chr22:10514784', 'NA12878')) == hl.Struct(LGT=hl.Call([0, 1]), LA=[0, 1], GQ=23, DP=4) assert as_dict.get(('chr22:10516102', 'NA12891')) == hl.Struct(LGT=hl.Call([0, 0]), LA=None, GQ=12, DP=7) @@ -520,7 +646,7 @@ def test_to_dense_mt(): assert as_dict.get(('chr22:10516150', 'NA12878')) == hl.Struct(LGT=hl.Call([0, 1]), LA=[0, 1], GQ=99, DP=10) assert as_dict.get(('chr22:10519088', 'NA12891')) == hl.Struct(LGT=hl.Call([0, 1]), LA=[0, 1], GQ=99, DP=21) - assert as_dict.get(('chr22:10519088', 'NA12878')) == None + assert as_dict.get(('chr22:10519088', 'NA12878')) is None assert as_dict.get(('chr22:10557694', 'NA12891')) == hl.Struct(LGT=hl.Call([0, 1]), LA=[0, 1], GQ=28, DP=19) assert as_dict.get(('chr22:10557694', 'NA12878')) == hl.Struct(LGT=hl.Call([0, 0]), LA=None, GQ=13, DP=16) @@ -539,15 +665,15 @@ def test_merge_reference_blocks(): rd = vds.reference_data vds.reference_data = rd.annotate_entries(GQ=rd.GQ - rd.GQ % 10) vds.reference_data = vds.reference_data.annotate_entries( - LEN=vds.reference_data.END - vds.reference_data.locus.position + 1, N_BLOCKS=1) + LEN=vds.reference_data.END - vds.reference_data.locus.position + 1, N_BLOCKS=1 + ) sampqc_1 = hl.vds.sample_qc(vds, gq_bins=(0, 10, 20, 30), dp_bins=()).select('bases_over_gq_threshold') - merged = hl.vds.merge_reference_blocks(vds, - equivalence_function=lambda b1, b2: b1.GQ == b2.GQ, - merge_functions={'LEN': 'sum', - 'MIN_DP': 'min', - 'N_BLOCKS': 'sum'} - ).checkpoint(new_temp_file(extension='vds')) + merged = hl.vds.merge_reference_blocks( + vds, + equivalence_function=lambda b1, b2: b1.GQ == b2.GQ, + merge_functions={'LEN': 'sum', 'MIN_DP': 'min', 'N_BLOCKS': 'sum'}, + 
).checkpoint(new_temp_file(extension='vds')) sampqc_2 = hl.vds.sample_qc(merged, gq_bins=(0, 10, 20, 30), dp_bins=()).select('bases_over_gq_threshold') assert sampqc_1._same(sampqc_2), "gq bins aren't the same" @@ -585,34 +711,35 @@ def test_truncate_reference_blocks(): def test_union_rows1(): vds = hl.vds.read_vds(os.path.join(resource('vds'), '1kg_chr22_5_samples.vds')) - vds1 = hl.vds.filter_intervals(vds, - [hl.parse_locus_interval('chr22:start-10754094', reference_genome='GRCh38')], - split_reference_blocks=True) - vds2 = hl.vds.filter_intervals(vds, - [hl.parse_locus_interval('chr22:10754094-end', reference_genome='GRCh38')], - split_reference_blocks=True) - + vds1 = hl.vds.filter_intervals( + vds, [hl.parse_locus_interval('chr22:start-10754094', reference_genome='GRCh38')], split_reference_blocks=True + ) + vds2 = hl.vds.filter_intervals( + vds, [hl.parse_locus_interval('chr22:10754094-end', reference_genome='GRCh38')], split_reference_blocks=True + ) vds_union = vds1.union_rows(vds2) assert hl.vds.to_dense_mt(vds)._same(hl.vds.to_dense_mt(vds_union)) + @test_timeout(local=3 * 60) def test_union_rows2(): vds = hl.vds.read_vds(os.path.join(resource('vds'), '1kg_chr22_5_samples.vds')) - vds1 = hl.vds.filter_intervals(vds, - [hl.parse_locus_interval('chr22:start-10754094', reference_genome='GRCh38')], - split_reference_blocks=True) - vds2 = hl.vds.filter_intervals(vds, - [hl.parse_locus_interval('chr22:10754094-end', reference_genome='GRCh38')], - split_reference_blocks=True) - + vds1 = hl.vds.filter_intervals( + vds, [hl.parse_locus_interval('chr22:start-10754094', reference_genome='GRCh38')], split_reference_blocks=True + ) + vds2 = hl.vds.filter_intervals( + vds, [hl.parse_locus_interval('chr22:10754094-end', reference_genome='GRCh38')], split_reference_blocks=True + ) vds1_trunc = hl.vds.truncate_reference_blocks(vds1, max_ref_block_base_pairs=50) vds2_trunc = hl.vds.truncate_reference_blocks(vds1, max_ref_block_base_pairs=75) vds_trunc_union = vds1_trunc.union_rows(vds2_trunc) - assert hl.eval(vds_trunc_union.reference_data.index_globals()[hl.vds.VariantDataset.ref_block_max_length_field]) == 75 + assert ( + hl.eval(vds_trunc_union.reference_data.index_globals()[hl.vds.VariantDataset.ref_block_max_length_field]) == 75 + ) assert 'max_ref_block_length' not in vds1_trunc.union_rows(vds2).reference_data.globals @@ -620,8 +747,8 @@ def test_union_rows2(): def test_combiner_max_len(): vds = hl.vds.read_vds(os.path.join(resource('vds'), '1kg_chr22_5_samples.vds')) all_samples = vds.reference_data.s.collect() - samp1 = all_samples[:len(all_samples)//2] - samp2 = all_samples[len(all_samples)//2:] + samp1 = all_samples[: len(all_samples) // 2] + samp2 = all_samples[len(all_samples) // 2 :] vds1 = hl.vds.filter_samples(vds, samp1, remove_dead_alleles=True) vds2 = hl.vds.filter_samples(vds, samp2, remove_dead_alleles=True) @@ -634,7 +761,10 @@ def test_combiner_max_len(): combined1 = combine_references([vds1_trunc.reference_data, vds2_trunc.reference_data]) assert hl.eval(combined1.index_globals()[hl.vds.VariantDataset.ref_block_max_length_field]) == 75 - combined2 = combine_references([vds1_trunc.reference_data, vds2.reference_data.drop(hl.vds.VariantDataset.ref_block_max_length_field)]) + combined2 = combine_references([ + vds1_trunc.reference_data, + vds2.reference_data.drop(hl.vds.VariantDataset.ref_block_max_length_field), + ]) assert hl.vds.VariantDataset.ref_block_max_length_field not in combined2.globals @@ -644,12 +774,14 @@ def test_split_sparse_roundtrip(): vds = 
hl.vds.read_vds(os.path.join(resource('vds'), '1kg_chr22_5_samples.vds')) smt = hl.vds.to_merged_sparse_mt(vds) smt = hl.experimental.sparse_split_multi(smt) - vds2 = hl.vds.VariantDataset.from_merged_representation(smt, - ref_block_fields=list(vds.reference_data.entry), - is_split=True) + vds2 = hl.vds.VariantDataset.from_merged_representation( + smt, ref_block_fields=list(vds.reference_data.entry), is_split=True + ) vds_split = hl.vds.split_multi(vds) - assert vds2.variant_data.select_entries(*vds_split.variant_data.entry).select_globals()._same(vds_split.variant_data) + assert ( + vds2.variant_data.select_entries(*vds_split.variant_data.entry).select_globals()._same(vds_split.variant_data) + ) assert vds2.reference_data._same(vds_split.reference_data.drop('ref_allele')) @@ -659,7 +791,8 @@ def test_ref_block_max_len_patch(): vds.reference_data = vds.reference_data.drop('ref_block_max_len') max_rb_len = vds.reference_data.aggregate_entries( - hl.agg.max(vds.reference_data.END - vds.reference_data.locus.position + 1)) + hl.agg.max(vds.reference_data.END - vds.reference_data.locus.position + 1) + ) with hl.TemporaryDirectory() as tmpdir: vds_path = os.path.join(tmpdir, 'to_patch.vds') vds.write(vds_path) @@ -667,7 +800,9 @@ def test_ref_block_max_len_patch(): hl.vds.store_ref_block_max_length(vds_path) vds2 = hl.vds.read_vds(vds_path) - assert hl.eval(vds2.reference_data.index_globals()[hl.vds.VariantDataset.ref_block_max_length_field]) == max_rb_len + assert ( + hl.eval(vds2.reference_data.index_globals()[hl.vds.VariantDataset.ref_block_max_length_field]) == max_rb_len + ) def test_filter_intervals_table(): @@ -686,9 +821,13 @@ def test_ref_block_does_not_densify_to_next_contig(): ref = vds.reference_data var = vds.variant_data.filter_entries(False) # max out all chr1 refblocks, and truncate all chr2 refblocks so that nothing in chr2 should be densified - ref = ref.annotate_entries(END=hl.if_else(ref.locus.contig == 'chr1', - hl.parse_locus_interval('chr1', reference_genome=ref.locus.dtype.reference_genome).end.position, - ref.locus.position)) + ref = ref.annotate_entries( + END=hl.if_else( + ref.locus.contig == 'chr1', + hl.parse_locus_interval('chr1', reference_genome=ref.locus.dtype.reference_genome).end.position, + ref.locus.position, + ) + ) vds = hl.vds.VariantDataset(reference_data=ref, variant_data=var) mt = hl.vds.to_dense_mt(vds) mt = mt.filter_rows(mt.locus.contig == 'chr2') diff --git a/hail/python/test/hail/vds/test_vds_functions.py b/hail/python/test/hail/vds/test_vds_functions.py index 831f792e134..dfbe9035afe 100644 --- a/hail/python/test/hail/vds/test_vds_functions.py +++ b/hail/python/test/hail/vds/test_vds_functions.py @@ -2,6 +2,7 @@ import hail as hl + def test_lgt_to_gt(): call_0_0_f = hl.call(0, 0, phased=False) call_0_0_t = hl.call(0, 0, phased=True) @@ -12,30 +13,49 @@ def test_lgt_to_gt(): la = [0, 3, 5] - assert hl.eval(tuple(hl.vds.lgt_to_gt(c, la) for c in [call_0_0_f, call_0_0_t, call_0_1_f, call_2_0_t, call_1])) == \ - tuple([hl.Call([0, 0], phased=False), hl.Call([0, 0], phased=True), hl.Call([0, 3], phased=False), hl.Call([5, 0], phased=True), hl.Call([3], phased=False)]) + assert hl.eval( + tuple(hl.vds.lgt_to_gt(c, la) for c in [call_0_0_f, call_0_0_t, call_0_1_f, call_2_0_t, call_1]) + ) == tuple([ + hl.Call([0, 0], phased=False), + hl.Call([0, 0], phased=True), + hl.Call([0, 3], phased=False), + hl.Call([5, 0], phased=True), + hl.Call([3], phased=False), + ]) - assert hl.eval(hl.vds.lgt_to_gt(call_0_0_f, hl.missing('array'))) == hl.Call([0,0], 
phased=False) + assert hl.eval(hl.vds.lgt_to_gt(call_0_0_f, hl.missing('array'))) == hl.Call([0, 0], phased=False) def test_lgt_to_gt_invalid(): - c1 = hl.call(1, 1) - c2 = hl.call(1, 1, phased=True) assert hl.eval(hl.vds.lgt_to_gt(c1, [0, 17495])) == hl.Call([17495, 17495]) # the below fails because phasing uses the sum of j and k for its second allele. # we cannot represent this allele index in 28 bits + # c2 = hl.call(1, 1, phased=True) # assert hl.eval(hl.vds.lgt_to_gt(c2, [0, 17495])) == hl.Call([17495, 17495], phased=True) + def test_local_to_global(): local_alleles = [0, 1, 3] lad = [1, 9, 10] lpl = [1001, 1002, 1003, 1004, 0, 1005] assert hl.eval(hl.vds.local_to_global(lad, local_alleles, 4, 0, number='R')) == [1, 9, 0, 10] - assert hl.eval(hl.vds.local_to_global(lpl, local_alleles, 4, 999, number='G')) == [1001, 1002, 1003, 999, 999, 999, 1004, 0, 999, 1005] - assert hl.eval(hl.vds.local_to_global(lad, [0,1,2], 3, 0, number='R')) == lad - assert hl.eval(hl.vds.local_to_global(lpl, [0,1,2], 3, 999, number='G')) == lpl + assert hl.eval(hl.vds.local_to_global(lpl, local_alleles, 4, 999, number='G')) == [ + 1001, + 1002, + 1003, + 999, + 999, + 999, + 1004, + 0, + 999, + 1005, + ] + assert hl.eval(hl.vds.local_to_global(lad, [0, 1, 2], 3, 0, number='R')) == lad + assert hl.eval(hl.vds.local_to_global(lpl, [0, 1, 2], 3, 999, number='G')) == lpl + def test_local_to_global_alleles_non_increasing(): local_alleles = [0, 3, 1] @@ -43,22 +63,46 @@ def test_local_to_global_alleles_non_increasing(): lpl = [1001, 1004, 0, 1002, 1003, 1005] assert hl.eval(hl.vds.local_to_global(lad, local_alleles, 4, 0, number='R')) == [1, 9, 0, 10] - assert hl.eval(hl.vds.local_to_global(lpl, local_alleles, 4, 999, number='G')) == [1001, 1002, 1005, 999, 999, 999, 1004, 1003, 999, 0] + assert hl.eval(hl.vds.local_to_global(lpl, local_alleles, 4, 999, number='G')) == [ + 1001, + 1002, + 1005, + 999, + 999, + 999, + 1004, + 1003, + 999, + 0, + ] assert hl.eval(hl.vds.local_to_global([0, 1, 2, 3, 4, 5], [0, 2, 1], 3, 0, number='G')) == [0, 3, 5, 1, 4, 2] + def test_local_to_global_missing_fill(): local_alleles = [0, 3, 1] lad = [1, 10, 9] assert hl.eval(hl.vds.local_to_global(lad, local_alleles, 4, hl.missing('int32'), number='R')) == [1, 9, None, 10] + def test_local_to_global_out_of_bounds(): local_alleles = [0, 2] lad = [1, 9] lpl = [1001, 0, 1002] - with pytest.raises(hl.utils.HailUserError, match='local_to_global: local allele of 2 out of bounds given n_total_alleles of 2'): + with pytest.raises( + hl.utils.HailUserError, match='local_to_global: local allele of 2 out of bounds given n_total_alleles of 2' + ): assert hl.eval(hl.vds.local_to_global(lad, local_alleles, 2, 0, number='R')) == [1, 0] - with pytest.raises(hl.utils.HailUserError, match='local_to_global: local allele of 2 out of bounds given n_total_alleles of 2'): - assert hl.eval(hl.vds.local_to_global(lpl, local_alleles, 2, 10001, number='G')) == [1001, 10001, 0, 10001, 10001, 1002] + with pytest.raises( + hl.utils.HailUserError, match='local_to_global: local allele of 2 out of bounds given n_total_alleles of 2' + ): + assert hl.eval(hl.vds.local_to_global(lpl, local_alleles, 2, 10001, number='G')) == [ + 1001, + 10001, + 0, + 10001, + 10001, + 1002, + ] diff --git a/hail/python/test/hailtop/batch/conftest.py b/hail/python/test/hailtop/batch/conftest.py new file mode 100644 index 00000000000..4b01195df98 --- /dev/null +++ b/hail/python/test/hailtop/batch/conftest.py @@ -0,0 +1,63 @@ +import asyncio +import os +from typing import AsyncIterator, 
Tuple + +import pytest + +from hailtop.aiotools.router_fs import RouterAsyncFS +from hailtop.batch import ServiceBackend +from hailtop.config import get_remote_tmpdir +from hailtop.utils import secret_alnum_string + + +@pytest.fixture(scope="session") +async def service_backend() -> AsyncIterator[ServiceBackend]: + sb = ServiceBackend() + try: + yield sb + finally: + await sb.async_close() + + +@pytest.fixture(scope="session") +async def fs() -> AsyncIterator[RouterAsyncFS]: + fs = RouterAsyncFS() + try: + yield fs + finally: + await fs.close() + + +@pytest.fixture(scope="session") +def tmpdir() -> str: + return os.path.join( + get_remote_tmpdir('test_batch_service_backend.py::tmpdir'), + secret_alnum_string(5), # create a unique URL for each split of the tests + ) + + +@pytest.fixture +def output_tmpdir(tmpdir: str) -> str: + return os.path.join(tmpdir, 'output', secret_alnum_string(5)) + + +@pytest.fixture +def output_bucket_path(fs: RouterAsyncFS, output_tmpdir: str) -> Tuple[str, str, str]: + url = fs.parse_url(output_tmpdir) + bucket = '/'.join(url.bucket_parts) + path = url.path + path = '/' + os.path.join(bucket, path) + return bucket, path, output_tmpdir + + +@pytest.fixture(scope="session") +async def upload_test_files( + fs: RouterAsyncFS, tmpdir: str +) -> Tuple[Tuple[str, bytes], Tuple[str, bytes], Tuple[str, bytes]]: + test_files = ( + (os.path.join(tmpdir, 'inputs/hello.txt'), b'hello world'), + (os.path.join(tmpdir, 'inputs/hello spaces.txt'), b'hello'), + (os.path.join(tmpdir, 'inputs/hello (foo) spaces.txt'), b'hello'), + ) + await asyncio.gather(*(fs.write(url, data) for url, data in test_files)) + return test_files diff --git a/hail/python/test/hailtop/batch/test_batch.py b/hail/python/test/hailtop/batch/test_batch.py deleted file mode 100644 index 8c5b04d7e15..00000000000 --- a/hail/python/test/hailtop/batch/test_batch.py +++ /dev/null @@ -1,1500 +0,0 @@ -import asyncio -import inspect -import secrets -import unittest - -import pytest -import os -import subprocess as sp -import tempfile -from shlex import quote as shq -import uuid -import re -import orjson - -import hailtop.fs as hfs -import hailtop.batch_client.client as bc -from hailtop import pip_version -from hailtop.batch import Batch, ServiceBackend, LocalBackend, ResourceGroup -from hailtop.batch.resource import JobResourceFile -from hailtop.batch.exceptions import BatchException -from hailtop.batch.globals import arg_max -from hailtop.utils import grouped, async_to_blocking -from hailtop.config import get_remote_tmpdir, configuration_of -from hailtop.batch.utils import concatenate -from hailtop.aiotools.router_fs import RouterAsyncFS -from hailtop.test_utils import skip_in_azure -from hailtop.httpx import ClientResponseError - -from configparser import ConfigParser -from hailtop.config import get_user_config, user_config -from hailtop.config.variables import ConfigVariable -from _pytest.monkeypatch import MonkeyPatch - - -DOCKER_ROOT_IMAGE = os.environ.get('DOCKER_ROOT_IMAGE', 'ubuntu:22.04') -PYTHON_DILL_IMAGE = 'hailgenetics/python-dill:3.9-slim' -HAIL_GENETICS_HAIL_IMAGE = os.environ.get('HAIL_GENETICS_HAIL_IMAGE', f'hailgenetics/hail:{pip_version()}') -REQUESTER_PAYS_PROJECT = os.environ.get('GCS_REQUESTER_PAYS_PROJECT') - - -class LocalTests(unittest.TestCase): - def batch(self, requester_pays_project=None): - return Batch(backend=LocalBackend(), - requester_pays_project=requester_pays_project) - - def read(self, file): - with open(file, 'r') as f: - result = f.read().rstrip() - return result - - def 
assert_same_file(self, file1, file2): - assert self.read(file1).rstrip() == self.read(file2).rstrip() - - def test_read_input_and_write_output(self): - with tempfile.NamedTemporaryFile('w') as input_file, \ - tempfile.NamedTemporaryFile('w') as output_file: - input_file.write('abc') - input_file.flush() - - b = self.batch() - input = b.read_input(input_file.name) - b.write_output(input, output_file.name) - b.run() - - self.assert_same_file(input_file.name, output_file.name) - - def test_read_input_group(self): - with tempfile.NamedTemporaryFile('w') as input_file1, \ - tempfile.NamedTemporaryFile('w') as input_file2, \ - tempfile.NamedTemporaryFile('w') as output_file1, \ - tempfile.NamedTemporaryFile('w') as output_file2: - - input_file1.write('abc') - input_file2.write('123') - input_file1.flush() - input_file2.flush() - - b = self.batch() - input = b.read_input_group(in1=input_file1.name, - in2=input_file2.name) - - b.write_output(input.in1, output_file1.name) - b.write_output(input.in2, output_file2.name) - b.run() - - self.assert_same_file(input_file1.name, output_file1.name) - self.assert_same_file(input_file2.name, output_file2.name) - - def test_write_resource_group(self): - with tempfile.NamedTemporaryFile('w') as input_file1, \ - tempfile.NamedTemporaryFile('w') as input_file2, \ - tempfile.TemporaryDirectory() as output_dir: - - b = self.batch() - input = b.read_input_group(in1=input_file1.name, - in2=input_file2.name) - - b.write_output(input, output_dir + '/foo') - b.run() - - self.assert_same_file(input_file1.name, output_dir + '/foo.in1') - self.assert_same_file(input_file2.name, output_dir + '/foo.in2') - - def test_single_job(self): - with tempfile.NamedTemporaryFile('w') as output_file: - msg = 'hello world' - - b = self.batch() - j = b.new_job() - j.command(f'echo "{msg}" > {j.ofile}') - b.write_output(j.ofile, output_file.name) - b.run() - - assert self.read(output_file.name) == msg - - def test_single_job_with_shell(self): - with tempfile.NamedTemporaryFile('w') as output_file: - msg = 'hello world' - - b = self.batch() - j = b.new_job(shell='/bin/bash') - j.command(f'echo "{msg}" > {j.ofile}') - - b.write_output(j.ofile, output_file.name) - b.run() - - assert self.read(output_file.name) == msg - - def test_single_job_with_nonsense_shell(self): - b = self.batch() - j = b.new_job(shell='/bin/ajdsfoijasidojf') - j.image(DOCKER_ROOT_IMAGE) - j.command(f'echo "hello"') - self.assertRaises(Exception, b.run) - - b = self.batch() - j = b.new_job(shell='/bin/nonexistent') - j.command(f'echo "hello"') - self.assertRaises(Exception, b.run) - - def test_single_job_with_intermediate_failure(self): - b = self.batch() - j = b.new_job() - j.command(f'echoddd "hello"') - j2 = b.new_job() - j2.command(f'echo "world"') - - self.assertRaises(Exception, b.run) - - def test_single_job_w_input(self): - with tempfile.NamedTemporaryFile('w') as input_file, \ - tempfile.NamedTemporaryFile('w') as output_file: - msg = 'abc' - input_file.write(msg) - input_file.flush() - - b = self.batch() - input = b.read_input(input_file.name) - j = b.new_job() - j.command(f'cat {input} > {j.ofile}') - b.write_output(j.ofile, output_file.name) - b.run() - - assert self.read(output_file.name) == msg - - def test_single_job_w_input_group(self): - with tempfile.NamedTemporaryFile('w') as input_file1, \ - tempfile.NamedTemporaryFile('w') as input_file2, \ - tempfile.NamedTemporaryFile('w') as output_file: - msg1 = 'abc' - msg2 = '123' - - input_file1.write(msg1) - input_file2.write(msg2) - input_file1.flush() - 
input_file2.flush() - - b = self.batch() - input = b.read_input_group(in1=input_file1.name, - in2=input_file2.name) - j = b.new_job() - j.command(f'cat {input.in1} {input.in2} > {j.ofile}') - j.command(f'cat {input}.in1 {input}.in2') - b.write_output(j.ofile, output_file.name) - b.run() - - assert self.read(output_file.name) == msg1 + msg2 - - def test_single_job_bad_command(self): - b = self.batch() - j = b.new_job() - j.command("foo") # this should fail! - with self.assertRaises(sp.CalledProcessError): - b.run() - - def test_declare_resource_group(self): - with tempfile.NamedTemporaryFile('w') as output_file: - msg = 'hello world' - b = self.batch() - j = b.new_job() - j.declare_resource_group(ofile={'log': "{root}.txt"}) - assert isinstance(j.ofile, ResourceGroup) - j.command(f'echo "{msg}" > {j.ofile.log}') - b.write_output(j.ofile.log, output_file.name) - b.run() - - assert self.read(output_file.name) == msg - - def test_resource_group_get_all_inputs(self): - b = self.batch() - input = b.read_input_group(fasta="foo", - idx="bar") - j = b.new_job() - j.command(f"cat {input.fasta}") - assert input.fasta in j._inputs - assert input.idx in j._inputs - - def test_resource_group_get_all_mentioned(self): - b = self.batch() - j = b.new_job() - j.declare_resource_group(foo={'bed': '{root}.bed', 'bim': '{root}.bim'}) - assert isinstance(j.foo, ResourceGroup) - j.command(f"cat {j.foo.bed}") - assert j.foo.bed in j._mentioned - assert j.foo.bim not in j._mentioned - - def test_resource_group_get_all_mentioned_dependent_jobs(self): - b = self.batch() - j = b.new_job() - j.declare_resource_group(foo={'bed': '{root}.bed', 'bim': '{root}.bim'}) - j.command(f"cat") - j2 = b.new_job() - j2.command(f"cat {j.foo}") - - def test_resource_group_get_all_outputs(self): - b = self.batch() - j1 = b.new_job() - j1.declare_resource_group(foo={'bed': '{root}.bed', 'bim': '{root}.bim'}) - assert isinstance(j1.foo, ResourceGroup) - j1.command(f"cat {j1.foo.bed}") - j2 = b.new_job() - j2.command(f"cat {j1.foo.bed}") - - for r in [j1.foo.bed, j1.foo.bim]: - assert r in j1._internal_outputs - assert r in j2._inputs - - assert j1.foo.bed in j1._mentioned - assert j1.foo.bim not in j1._mentioned - - assert j1.foo.bed in j2._mentioned - assert j1.foo.bim not in j2._mentioned - - assert j1.foo not in j1._mentioned - - def test_multiple_isolated_jobs(self): - b = self.batch() - - output_files = [] - try: - output_files = [tempfile.NamedTemporaryFile('w') for _ in range(5)] - - for i, ofile in enumerate(output_files): - msg = f'hello world {i}' - j = b.new_job() - j.command(f'echo "{msg}" > {j.ofile}') - b.write_output(j.ofile, ofile.name) - b.run() - - for i, ofile in enumerate(output_files): - msg = f'hello world {i}' - assert self.read(ofile.name) == msg - finally: - [ofile.close() for ofile in output_files] - - def test_multiple_dependent_jobs(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - j = b.new_job() - j.command(f'echo "0" >> {j.ofile}') - - for i in range(1, 3): - j2 = b.new_job() - j2.command(f'echo "{i}" > {j2.tmp1}') - j2.command(f'cat {j.ofile} {j2.tmp1} > {j2.ofile}') - j = j2 - - b.write_output(j.ofile, output_file.name) - b.run() - - assert self.read(output_file.name) == "0\n1\n2" - - def test_select_jobs(self): - b = self.batch() - for i in range(3): - b.new_job(name=f'foo{i}') - self.assertTrue(len(b.select_jobs('foo')) == 3) - - def test_scatter_gather(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - - for i in range(3): - j = 
b.new_job(name=f'foo{i}') - j.command(f'echo "{i}" > {j.ofile}') - - merger = b.new_job() - merger.command('cat {files} > {ofile}'.format(files=' '.join([j.ofile for j in sorted(b.select_jobs('foo'), - key=lambda x: x.name, # type: ignore - reverse=True)]), - ofile=merger.ofile)) - - b.write_output(merger.ofile, output_file.name) - b.run() - - assert self.read(output_file.name) == '2\n1\n0' - - def test_add_extension_job_resource_file(self): - b = self.batch() - j = b.new_job() - j.command(f'echo "hello" > {j.ofile}') - assert isinstance(j.ofile, JobResourceFile) - j.ofile.add_extension('.txt.bgz') - assert j.ofile._value - assert j.ofile._value.endswith('.txt.bgz') - - def test_add_extension_input_resource_file(self): - input_file1 = '/tmp/data/example1.txt.bgz.foo' - b = self.batch() - in1 = b.read_input(input_file1) - assert in1._value - assert in1._value.endswith('.txt.bgz.foo') - - def test_file_name_space(self): - with tempfile.NamedTemporaryFile('w', prefix="some file name with (foo) spaces") as input_file, \ - tempfile.NamedTemporaryFile('w', prefix="another file name with (foo) spaces") as output_file: - - input_file.write('abc') - input_file.flush() - - b = self.batch() - input = b.read_input(input_file.name) - j = b.new_job() - j.command(f'cat {input} > {j.ofile}') - b.write_output(j.ofile, output_file.name) - b.run() - - self.assert_same_file(input_file.name, output_file.name) - - def test_resource_group_mentioned(self): - b = self.batch() - j = b.new_job() - j.declare_resource_group(foo={'bed': '{root}.bed'}) - assert isinstance(j.foo, ResourceGroup) - j.command(f'echo "hello" > {j.foo}') - - t2 = b.new_job() - t2.command(f'echo "hello" >> {j.foo.bed}') - b.run() - - def test_envvar(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - j = b.new_job() - j.env('SOME_VARIABLE', '123abcdef') - j.command(f'echo $SOME_VARIABLE > {j.ofile}') - b.write_output(j.ofile, output_file.name) - b.run() - assert self.read(output_file.name) == '123abcdef' - - def test_concatenate(self): - b = self.batch() - files = [] - for _ in range(10): - j = b.new_job() - j.command(f'touch {j.ofile}') - files.append(j.ofile) - concatenate(b, files, branching_factor=2) - assert len(b._jobs) == 10 + (5 + 3 + 2 + 1) - b.run() - - def test_python_job(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - head = b.new_job() - head.command(f'echo "5" > {head.r5}') - head.command(f'echo "3" > {head.r3}') - - def read(path): - with open(path, 'r') as f: - i = f.read() - return int(i) - - def multiply(x, y): - return x * y - - def reformat(x, y): - return {'x': x, 'y': y} - - middle = b.new_python_job() - r3 = middle.call(read, head.r3) - r5 = middle.call(read, head.r5) - r_mult = middle.call(multiply, r3, r5) - - middle2 = b.new_python_job() - r_mult = middle2.call(multiply, r_mult, 2) - r_dict = middle2.call(reformat, r3, r5) - - tail = b.new_job() - tail.command(f'cat {r3.as_str()} {r5.as_repr()} {r_mult.as_str()} {r_dict.as_json()} > {tail.ofile}') - - b.write_output(tail.ofile, output_file.name) - b.run() - assert self.read(output_file.name) == '3\n5\n30\n{\"x\": 3, \"y\": 5}' - - def test_backend_context_manager(self): - with LocalBackend() as backend: - b = Batch(backend=backend) - b.run() - - def test_failed_jobs_dont_stop_non_dependent_jobs(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - - head = b.new_job() - head.command(f'echo 1 > {head.ofile}') - - head2 = b.new_job() - head2.command('false') - - tail = 
b.new_job() - tail.command(f'cat {head.ofile} > {tail.ofile}') - b.write_output(tail.ofile, output_file.name) - self.assertRaises(Exception, b.run) - assert self.read(output_file.name) == '1' - - def test_failed_jobs_stop_child_jobs(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - - head = b.new_job() - head.command(f'echo 1 > {head.ofile}') - head.command('false') - - head2 = b.new_job() - head2.command(f'echo 2 > {head2.ofile}') - - tail = b.new_job() - tail.command(f'cat {head.ofile} > {tail.ofile}') - - b.write_output(head2.ofile, output_file.name) - b.write_output(tail.ofile, output_file.name) - self.assertRaises(Exception, b.run) - assert self.read(output_file.name) == '2' - - def test_failed_jobs_stop_grandchild_jobs(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - - head = b.new_job() - head.command(f'echo 1 > {head.ofile}') - head.command('false') - - head2 = b.new_job() - head2.command(f'echo 2 > {head2.ofile}') - - tail = b.new_job() - tail.command(f'cat {head.ofile} > {tail.ofile}') - - tail2 = b.new_job() - tail2.depends_on(tail) - tail2.command(f'echo foo > {tail2.ofile}') - - b.write_output(head2.ofile, output_file.name) - b.write_output(tail2.ofile, output_file.name) - self.assertRaises(Exception, b.run) - assert self.read(output_file.name) == '2' - - def test_failed_jobs_dont_stop_always_run_jobs(self): - with tempfile.NamedTemporaryFile('w') as output_file: - b = self.batch() - - head = b.new_job() - head.command(f'echo 1 > {head.ofile}') - head.command('false') - - tail = b.new_job() - tail.command(f'cat {head.ofile} > {tail.ofile}') - tail.always_run() - - b.write_output(tail.ofile, output_file.name) - self.assertRaises(Exception, b.run) - assert self.read(output_file.name) == '1' - - -class ServiceTests(unittest.TestCase): - def setUp(self): - # https://stackoverflow.com/questions/42332030/pytest-monkeypatch-setattr-inside-of-test-class-method - self.monkeypatch = MonkeyPatch() - - self.backend = ServiceBackend() - - remote_tmpdir = get_remote_tmpdir('hailtop_test_batch_service_tests') - if not remote_tmpdir.endswith('/'): - remote_tmpdir += '/' - self.remote_tmpdir = remote_tmpdir + str(uuid.uuid4()) + '/' - - if remote_tmpdir.startswith('gs://'): - match = re.fullmatch('gs://(?P[^/]+).*', remote_tmpdir) - assert match - self.bucket = match.groupdict()['bucket_name'] - else: - assert remote_tmpdir.startswith('hail-az://') - if remote_tmpdir.startswith('hail-az://'): - match = re.fullmatch('hail-az://(?P[^/]+)/(?P[^/]+).*', remote_tmpdir) - assert match - storage_account, container_name = match.groups() - else: - assert remote_tmpdir.startswith('https://') - match = re.fullmatch('https://(?P[^/]+).blob.core.windows.net/(?P[^/]+).*', remote_tmpdir) - assert match - storage_account, container_name = match.groups() - self.bucket = f'{storage_account}/{container_name}' - - self.cloud_input_dir = f'{self.remote_tmpdir}batch-tests/resources' - - token = uuid.uuid4() - self.cloud_output_path = f'/batch-tests/{token}' - self.cloud_output_dir = f'{self.remote_tmpdir}{self.cloud_output_path}' - - self.router_fs = RouterAsyncFS() - - if not self.sync_exists(f'{self.remote_tmpdir}batch-tests/resources/hello.txt'): - self.sync_write(f'{self.remote_tmpdir}batch-tests/resources/hello.txt', b'hello world') - if not self.sync_exists(f'{self.remote_tmpdir}batch-tests/resources/hello spaces.txt'): - self.sync_write(f'{self.remote_tmpdir}batch-tests/resources/hello spaces.txt', b'hello') - if not 
self.sync_exists(f'{self.remote_tmpdir}batch-tests/resources/hello (foo) spaces.txt'): - self.sync_write(f'{self.remote_tmpdir}batch-tests/resources/hello (foo) spaces.txt', b'hello') - - def tearDown(self): - self.backend.close() - - def sync_exists(self, url): - return async_to_blocking(self.router_fs.exists(url)) - - def sync_write(self, url, data): - return async_to_blocking(self.router_fs.write(url, data)) - - def batch(self, **kwargs): - name_of_test_method = inspect.stack()[1][3] - return Batch(name=name_of_test_method, - backend=self.backend, - default_image=DOCKER_ROOT_IMAGE, - attributes={'foo': 'a', 'bar': 'b'}, - **kwargs) - - def test_single_task_no_io(self): - b = self.batch() - j = b.new_job() - j.command('echo hello') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_task_input(self): - b = self.batch() - input = b.read_input(f'{self.cloud_input_dir}/hello.txt') - j = b.new_job() - j.command(f'cat {input}') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_task_input_resource_group(self): - b = self.batch() - input = b.read_input_group(foo=f'{self.cloud_input_dir}/hello.txt') - j = b.new_job() - j.storage('10Gi') - j.command(f'cat {input.foo}') - j.command(f'cat {input}.foo') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_task_output(self): - b = self.batch() - j = b.new_job(attributes={'a': 'bar', 'b': 'foo'}) - j.command(f'echo hello > {j.ofile}') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_task_write_output(self): - b = self.batch() - j = b.new_job() - j.command(f'echo hello > {j.ofile}') - b.write_output(j.ofile, f'{self.cloud_output_dir}/test_single_task_output.txt') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_task_resource_group(self): - b = self.batch() - j = b.new_job() - j.declare_resource_group(output={'foo': '{root}.foo'}) - assert isinstance(j.output, ResourceGroup) - j.command(f'echo "hello" > {j.output.foo}') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_task_write_resource_group(self): - b = self.batch() - j = b.new_job() - j.declare_resource_group(output={'foo': '{root}.foo'}) - assert isinstance(j.output, ResourceGroup) - j.command(f'echo "hello" > {j.output.foo}') - b.write_output(j.output, f'{self.cloud_output_dir}/test_single_task_write_resource_group') - b.write_output(j.output.foo, f'{self.cloud_output_dir}/test_single_task_write_resource_group_file.txt') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_multiple_dependent_tasks(self): - output_file = f'{self.cloud_output_dir}/test_multiple_dependent_tasks.txt' - b = self.batch() - j = b.new_job() - j.command(f'echo "0" >> {j.ofile}') - - for i in range(1, 3): - j2 = b.new_job() - j2.command(f'echo "{i}" > {j2.tmp1}') - j2.command(f'cat {j.ofile} {j2.tmp1} > {j2.ofile}') - j = j2 - - b.write_output(j.ofile, output_file) - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def 
test_specify_cpu(self): - b = self.batch() - j = b.new_job() - j.cpu('0.5') - j.command(f'echo "hello" > {j.ofile}') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_specify_memory(self): - b = self.batch() - j = b.new_job() - j.memory('100M') - j.command(f'echo "hello" > {j.ofile}') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_scatter_gather(self): - b = self.batch() - - for i in range(3): - j = b.new_job(name=f'foo{i}') - j.command(f'echo "{i}" > {j.ofile}') - - merger = b.new_job() - merger.command('cat {files} > {ofile}'.format(files=' '.join([j.ofile for j in sorted(b.select_jobs('foo'), - key=lambda x: x.name, # type: ignore - reverse=True)]), - ofile=merger.ofile)) - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_file_name_space(self): - b = self.batch() - input = b.read_input(f'{self.cloud_input_dir}/hello (foo) spaces.txt') - j = b.new_job() - j.command(f'cat {input} > {j.ofile}') - b.write_output(j.ofile, f'{self.cloud_output_dir}/hello (foo) spaces.txt') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_dry_run(self): - b = self.batch() - j = b.new_job() - j.command(f'echo hello > {j.ofile}') - b.write_output(j.ofile, f'{self.cloud_output_dir}/test_single_job_output.txt') - b.run(dry_run=True) - - def test_verbose(self): - b = self.batch() - input = b.read_input(f'{self.cloud_input_dir}/hello.txt') - j = b.new_job() - j.command(f'cat {input}') - b.write_output(input, f'{self.cloud_output_dir}/hello.txt') - res = b.run(verbose=True) - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_cloudfuse_fails_with_read_write_mount_option(self): - assert self.bucket - path = f'/{self.bucket}{self.cloud_output_path}' - - b = self.batch() - j = b.new_job() - j.command(f'mkdir -p {path}; echo head > {path}/cloudfuse_test_1') - j.cloudfuse(self.bucket, f'/{self.bucket}', read_only=False) - - try: - b.run() - except ClientResponseError as e: - assert 'Only read-only cloudfuse requests are supported' in e.body, e.body - else: - assert False - - def test_cloudfuse_fails_with_io_mount_point(self): - assert self.bucket - path = f'/{self.bucket}{self.cloud_output_path}' - - b = self.batch() - j = b.new_job() - j.command(f'mkdir -p {path}; echo head > {path}/cloudfuse_test_1') - j.cloudfuse(self.bucket, f'/io', read_only=True) - - try: - b.run() - except ClientResponseError as e: - assert 'Cloudfuse requests with mount_path=/io are not supported' in e.body, e.body - else: - assert False - - def test_cloudfuse_read_only(self): - assert self.bucket - path = f'/{self.bucket}{self.cloud_output_path}' - - b = self.batch() - j = b.new_job() - j.command(f'mkdir -p {path}; echo head > {path}/cloudfuse_test_1') - j.cloudfuse(self.bucket, f'/{self.bucket}', read_only=True) - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - def test_cloudfuse_implicit_dirs(self): - assert self.bucket - path = self.router_fs.parse_url(f'{self.remote_tmpdir}batch-tests/resources/hello.txt').path - b = self.batch() - j = b.new_job() - j.command(f'cat /cloudfuse/{path}') - j.cloudfuse(self.bucket, f'/cloudfuse', read_only=True) - - res = 
b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_cloudfuse_empty_string_bucket_fails(self): - assert self.bucket - b = self.batch() - j = b.new_job() - with self.assertRaises(BatchException): - j.cloudfuse('', '/empty_bucket') - with self.assertRaises(BatchException): - j.cloudfuse(self.bucket, '') - - def test_cloudfuse_submount_in_io_doesnt_rm_bucket(self): - assert self.bucket - b = self.batch() - j = b.new_job() - j.cloudfuse(self.bucket, '/io/cloudfuse') - j.command(f'ls /io/cloudfuse/') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert self.sync_exists(f'{self.remote_tmpdir}batch-tests/resources/hello.txt') - - @skip_in_azure - def test_fuse_requester_pays(self): - assert REQUESTER_PAYS_PROJECT - b = self.batch(requester_pays_project=REQUESTER_PAYS_PROJECT) - j = b.new_job() - j.cloudfuse('hail-test-requester-pays-fds32', '/fuse-bucket') - j.command('cat /fuse-bucket/hello') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - @skip_in_azure - def test_fuse_non_requester_pays_bucket_when_requester_pays_project_specified(self): - assert REQUESTER_PAYS_PROJECT - assert self.bucket - b = self.batch(requester_pays_project=REQUESTER_PAYS_PROJECT) - j = b.new_job() - j.command(f'ls /fuse-bucket') - j.cloudfuse(self.bucket, f'/fuse-bucket', read_only=True) - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - @skip_in_azure - def test_requester_pays(self): - assert REQUESTER_PAYS_PROJECT - b = self.batch(requester_pays_project=REQUESTER_PAYS_PROJECT) - input = b.read_input('gs://hail-test-requester-pays-fds32/hello') - j = b.new_job() - j.command(f'cat {input}') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_benchmark_lookalike_workflow(self): - b = self.batch() - - setup_jobs = [] - for i in range(10): - j = b.new_job(f'setup_{i}').cpu(0.25) - j.command(f'echo "foo" > {j.ofile}') - setup_jobs.append(j) - - jobs = [] - for i in range(500): - j = b.new_job(f'create_file_{i}').cpu(0.25) - j.command(f'echo {setup_jobs[i % len(setup_jobs)].ofile} > {j.ofile}') - j.command(f'echo "bar" >> {j.ofile}') - jobs.append(j) - - combine = b.new_job(f'combine_output').cpu(0.25) - for _ in grouped(arg_max(), jobs): - combine.command(f'cat {" ".join(shq(j.ofile) for j in jobs)} >> {combine.ofile}') - b.write_output(combine.ofile, f'{self.cloud_output_dir}/pipeline_benchmark_test.txt') - # too slow - # assert b.run().status()['state'] == 'success' - - def test_envvar(self): - b = self.batch() - j = b.new_job() - j.env('SOME_VARIABLE', '123abcdef') - j.command('[ $SOME_VARIABLE = "123abcdef" ]') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_job_with_shell(self): - msg = 'hello world' - b = self.batch() - j = b.new_job(shell='/bin/sh') - j.command(f'echo "{msg}"') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_single_job_with_nonsense_shell(self): - b = self.batch() - j = b.new_job(shell='/bin/ajdsfoijasidojf') - j.command(f'echo "hello"') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'failure', 
str((res_status, res.debug_info())) - - def test_single_job_with_intermediate_failure(self): - b = self.batch() - j = b.new_job() - j.command(f'echoddd "hello"') - j2 = b.new_job() - j2.command(f'echo "world"') - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - def test_input_directory(self): - b = self.batch() - input1 = b.read_input(self.cloud_input_dir) - input2 = b.read_input(self.cloud_input_dir.rstrip('/') + '/') - j = b.new_job() - j.command(f'ls {input1}/hello.txt') - j.command(f'ls {input2}/hello.txt') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_python_job(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - head = b.new_job() - head.command(f'echo "5" > {head.r5}') - head.command(f'echo "3" > {head.r3}') - - def read(path): - with open(path, 'r') as f: - i = f.read() - return int(i) - - def multiply(x, y): - return x * y - - def reformat(x, y): - return {'x': x, 'y': y} - - middle = b.new_python_job() - r3 = middle.call(read, head.r3) - r5 = middle.call(read, head.r5) - r_mult = middle.call(multiply, r3, r5) - - middle2 = b.new_python_job() - r_mult = middle2.call(multiply, r_mult, 2) - r_dict = middle2.call(reformat, r3, r5) - - tail = b.new_job() - tail.command(f'cat {r3.as_str()} {r5.as_repr()} {r_mult.as_str()} {r_dict.as_json()}') - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(4)['main'] == "3\n5\n30\n{\"x\": 3, \"y\": 5}\n", str(res.debug_info()) - - def test_python_job_w_resource_group_unpack_individually(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - head = b.new_job() - head.declare_resource_group(count={'r5': '{root}.r5', - 'r3': '{root}.r3'}) - assert isinstance(head.count, ResourceGroup) - - head.command(f'echo "5" > {head.count.r5}') - head.command(f'echo "3" > {head.count.r3}') - - def read(path): - with open(path, 'r') as f: - r = int(f.read()) - return r - - def multiply(x, y): - return x * y - - def reformat(x, y): - return {'x': x, 'y': y} - - middle = b.new_python_job() - r3 = middle.call(read, head.count.r3) - r5 = middle.call(read, head.count.r5) - r_mult = middle.call(multiply, r3, r5) - - middle2 = b.new_python_job() - r_mult = middle2.call(multiply, r_mult, 2) - r_dict = middle2.call(reformat, r3, r5) - - tail = b.new_job() - tail.command(f'cat {r3.as_str()} {r5.as_repr()} {r_mult.as_str()} {r_dict.as_json()}') - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(4)['main'] == "3\n5\n30\n{\"x\": 3, \"y\": 5}\n", str(res.debug_info()) - - def test_python_job_can_write_to_resource_path(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - - def write(path): - with open(path, 'w') as f: - f.write('foo') - head = b.new_python_job() - head.call(write, head.ofile) - - tail = b.new_bash_job() - tail.command(f'cat {head.ofile}') - - res = b.run() - assert res - assert tail._job_id - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(tail._job_id)['main'] == 'foo', str(res.debug_info()) - - def test_python_job_w_resource_group_unpack_jointly(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - head = b.new_job() - head.declare_resource_group(count={'r5': '{root}.r5', - 
'r3': '{root}.r3'}) - assert isinstance(head.count, ResourceGroup) - - head.command(f'echo "5" > {head.count.r5}') - head.command(f'echo "3" > {head.count.r3}') - - def read_rg(root): - with open(root['r3'], 'r') as f: - r3 = int(f.read()) - with open(root['r5'], 'r') as f: - r5 = int(f.read()) - return (r3, r5) - - def multiply(r): - x, y = r - return x * y - - middle = b.new_python_job() - r = middle.call(read_rg, head.count) - r_mult = middle.call(multiply, r) - - tail = b.new_job() - tail.command(f'cat {r_mult.as_str()}') - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - job_log_3 = res.get_job_log(3) - assert job_log_3['main'] == "15\n", str((job_log_3, res.debug_info())) - - def test_python_job_w_non_zero_ec(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - j = b.new_python_job() - - def error(): - raise Exception("this should fail") - - j.call(error) - res = b.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - def test_python_job_incorrect_signature(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - - def foo(pos_arg1, pos_arg2, *, kwarg1, kwarg2=1): - print(pos_arg1, pos_arg2, kwarg1, kwarg2) - - j = b.new_python_job() - - with pytest.raises(BatchException): - j.call(foo) - with pytest.raises(BatchException): - j.call(foo, 1) - with pytest.raises(BatchException): - j.call(foo, 1, 2) - with pytest.raises(BatchException): - j.call(foo, 1, kwarg1=2) - with pytest.raises(BatchException): - j.call(foo, 1, 2, 3) - with pytest.raises(BatchException): - j.call(foo, 1, 2, kwarg1=3, kwarg2=4, kwarg3=5) - - j.call(foo, 1, 2, kwarg1=3) - j.call(foo, 1, 2, kwarg1=3, kwarg2=4) - - # `print` doesn't have a signature but other builtins like `abs` do - j.call(print, 5) - j.call(abs, -1) - with pytest.raises(BatchException): - j.call(abs, -1, 5) - - def test_fail_fast(self): - b = self.batch(cancel_after_n_failures=1) - - j1 = b.new_job() - j1.command('false') - - j2 = b.new_job() - j2.command('sleep 300') - - res = b.run() - job_status = res.get_job(2).status() - assert job_status['state'] == 'Cancelled', str((job_status, res.debug_info())) - - def test_service_backend_remote_tempdir_with_trailing_slash(self): - backend = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files/') - b = Batch(backend=backend) - j1 = b.new_job() - j1.command(f'echo hello > {j1.ofile}') - j2 = b.new_job() - j2.command(f'cat {j1.ofile}') - b.run() - - def test_service_backend_remote_tempdir_with_no_trailing_slash(self): - backend = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files') - b = Batch(backend=backend) - j1 = b.new_job() - j1.command(f'echo hello > {j1.ofile}') - j2 = b.new_job() - j2.command(f'cat {j1.ofile}') - b.run() - - def test_large_command(self): - backend = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files') - b = Batch(backend=backend) - j1 = b.new_job() - long_str = secrets.token_urlsafe(15 * 1024) - j1.command(f'echo "{long_str}"') - b.run() - - def test_big_batch_which_uses_slow_path(self): - backend = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files') - b = Batch(backend=backend) - # 8 * 256 * 1024 = 2 MiB > 1 MiB max bunch size - for _ in range(8): - j1 = b.new_job() - long_str = secrets.token_urlsafe(256 * 1024) - j1.command(f'echo "{long_str}" > /dev/null') - batch = b.run() - assert not batch._submission_info.used_fast_path - batch_status = batch.status() 
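# --- Editorial sketch (not part of the diff): the removed test above asserts the slow
# --- submission path is used. A minimal illustration of the size arithmetic it relies on,
# --- assuming the ~1 MiB fast-path bunch limit cited in the test's own comment; the
# --- constant name below is hypothetical and not a Batch client API.
ASSUMED_MAX_BUNCH_BYTES = 1024 * 1024       # ~1 MiB fast-path limit, per the original comment
total_command_bytes = 8 * 256 * 1024        # eight jobs, each embedding a ~256 KiB command string
assert total_command_bytes > ASSUMED_MAX_BUNCH_BYTES  # ~2 MiB > 1 MiB, so the fast path is skipped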
- assert batch_status['state'] == 'success', str((batch.debug_info())) - - def test_query_on_batch_in_batch(self): - sb = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files') - bb = Batch(backend=sb, default_python_image=HAIL_GENETICS_HAIL_IMAGE) - - tmp_ht_path = self.remote_tmpdir + '/' + secrets.token_urlsafe(32) - - def qob_in_batch(): - import hail as hl - hl.utils.range_table(10).write(tmp_ht_path, overwrite=True) - - j = bb.new_python_job() - j.env('HAIL_QUERY_BACKEND', 'batch') - j.env('HAIL_BATCH_BILLING_PROJECT', configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, None, '')) - j.env('HAIL_BATCH_REMOTE_TMPDIR', self.remote_tmpdir) - j.call(qob_in_batch) - - bb.run() - - def test_basic_async_fun(self): - backend = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files') - b = Batch(backend=backend) - - j = b.new_python_job() - j.call(asyncio.sleep, 1) - - batch = b.run() - batch_status = batch.status() - assert batch_status['state'] == 'success', str((batch.debug_info())) - - def test_async_fun_returns_value(self): - backend = ServiceBackend(remote_tmpdir=f'{self.remote_tmpdir}/temporary-files') - b = Batch(backend=backend) - - async def foo(i, j): - await asyncio.sleep(1) - return i * j - - j = b.new_python_job() - result = j.call(foo, 2, 3) - - j = b.new_job() - j.command(f'cat {result.as_str()}') - - batch = b.run() - batch_status = batch.status() - assert batch_status['state'] == 'success', str((batch_status, batch.debug_info())) - job_log_2 = batch.get_job_log(2) - assert job_log_2['main'] == "6\n", str((job_log_2, batch.debug_info())) - - def test_specify_job_region(self): - b = self.batch(cancel_after_n_failures=1) - j = b.new_job('region') - possible_regions = self.backend.supported_regions() - j.regions(possible_regions) - j.command('true') - res = b.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_always_copy_output(self): - output_path = f'{self.cloud_output_dir}/test_always_copy_output.txt' - - b = self.batch() - j = b.new_job() - j.always_copy_output() - j.command(f'echo "hello" > {j.ofile} && false') - - b.write_output(j.ofile, output_path) - res = b.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - b2 = self.batch() - input = b2.read_input(output_path) - file_exists_j = b2.new_job() - file_exists_j.command(f'cat {input}') - - res = b2.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(1)['main'] == "hello\n", str(res.debug_info()) - - def test_no_copy_output_on_failure(self): - output_path = f'{self.cloud_output_dir}/test_no_copy_output.txt' - - b = self.batch() - j = b.new_job() - j.command(f'echo "hello" > {j.ofile} && false') - - b.write_output(j.ofile, output_path) - res = b.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - b2 = self.batch() - input = b2.read_input(output_path) - file_exists_j = b2.new_job() - file_exists_j.command(f'cat {input}') - - res = b2.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - def test_update_batch(self): - b = self.batch() - j = b.new_job() - j.command('true') - res = b.run() - - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - j2 = b.new_job() - j2.command('true') - res = b.run() 
- res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_update_batch_with_dependencies(self): - b = self.batch() - j1 = b.new_job() - j1.command('true') - j2 = b.new_job() - j2.command('false') - res = b.run() - - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - j3 = b.new_job() - j3.command('true') - j3.depends_on(j1) - - j4 = b.new_job() - j4.command('true') - j4.depends_on(j2) - - res = b.run() - res_status = res.status() - assert res_status['state'] == 'failure', str((res_status, res.debug_info())) - - assert res.get_job(3).status()['state'] == 'Success', str((res_status, res.debug_info())) - assert res.get_job(4).status()['state'] == 'Cancelled', str((res_status, res.debug_info())) - - def test_update_batch_with_python_job_dependencies(self): - b = self.batch() - - async def foo(i, j): - await asyncio.sleep(1) - return i * j - - j1 = b.new_python_job() - j1.call(foo, 2, 3) - - batch = b.run() - batch_status = batch.status() - assert batch_status['state'] == 'success', str((batch_status, batch.debug_info())) - - j2 = b.new_python_job() - j2.call(foo, 2, 3) - - batch = b.run() - batch_status = batch.status() - assert batch_status['state'] == 'success', str((batch_status, batch.debug_info())) - - j3 = b.new_python_job() - j3.depends_on(j2) - j3.call(foo, 2, 3) - - batch = b.run() - batch_status = batch.status() - assert batch_status['state'] == 'success', str((batch_status, batch.debug_info())) - - def test_update_batch_from_batch_id(self): - b = self.batch() - j = b.new_job() - j.command('true') - res = b.run() - - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - b2 = Batch.from_batch_id(res.id, backend=b._backend) - j2 = b2.new_job() - j2.command('true') - res = b2.run() - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - - def test_python_job_with_kwarg(self): - def foo(*, kwarg): - return kwarg - - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - j = b.new_python_job() - r = j.call(foo, kwarg='hello world') - - output_path = f'{self.cloud_output_dir}/test_python_job_with_kwarg' - b.write_output(r.as_json(), output_path) - res = b.run() - assert isinstance(res, bc.Batch) - - assert res.status()['state'] == 'success', str((res, res.debug_info())) - with hfs.open(output_path) as f: - assert orjson.loads(f.read()) == 'hello world' - - def test_tuple_recursive_resource_extraction_in_python_jobs(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - - def write(paths): - if not isinstance(paths, tuple): - raise ValueError('paths must be a tuple') - for i, path in enumerate(paths): - with open(path, 'w') as f: - f.write(f'{i}') - - head = b.new_python_job() - head.call(write, (head.ofile1, head.ofile2)) - - tail = b.new_bash_job() - tail.command(f'cat {head.ofile1}') - tail.command(f'cat {head.ofile2}') - - res = b.run() - assert res - assert tail._job_id - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(tail._job_id)['main'] == '01', str(res.debug_info()) - - def test_list_recursive_resource_extraction_in_python_jobs(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - - def write(paths): - for i, path in enumerate(paths): - with open(path, 'w') as f: - f.write(f'{i}') - - head = b.new_python_job() - head.call(write, [head.ofile1, 
head.ofile2]) - - tail = b.new_bash_job() - tail.command(f'cat {head.ofile1}') - tail.command(f'cat {head.ofile2}') - - res = b.run() - assert res - assert tail._job_id - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(tail._job_id)['main'] == '01', str(res.debug_info()) - - def test_dict_recursive_resource_extraction_in_python_jobs(self): - b = self.batch(default_python_image=PYTHON_DILL_IMAGE) - - def write(kwargs): - for k, v in kwargs.items(): - with open(v, 'w') as f: - f.write(k) - - head = b.new_python_job() - head.call(write, {'a': head.ofile1, 'b': head.ofile2}) - - tail = b.new_bash_job() - tail.command(f'cat {head.ofile1}') - tail.command(f'cat {head.ofile2}') - - res = b.run() - assert res - assert tail._job_id - res_status = res.status() - assert res_status['state'] == 'success', str((res_status, res.debug_info())) - assert res.get_job_log(tail._job_id)['main'] == 'ab', str(res.debug_info()) - - def test_wait_on_empty_batch_update(self): - b = self.batch() - b.run(wait=True) - b.run(wait=True) - - def test_non_spot_job(self): - b = self.batch() - j = b.new_job() - j.spot(False) - j.command('echo hello') - res = b.run() - assert res is not None - assert res.get_job(1).status()['spec']['resources']['preemptible'] == False - - def test_spot_unspecified_job(self): - b = self.batch() - j = b.new_job() - j.command('echo hello') - res = b.run() - assert res is not None - assert res.get_job(1).status()['spec']['resources']['preemptible'] == True - - def test_spot_true_job(self): - b = self.batch() - j = b.new_job() - j.spot(True) - j.command('echo hello') - res = b.run() - assert res is not None - assert res.get_job(1).status()['spec']['resources']['preemptible'] == True - - def test_non_spot_batch(self): - b = self.batch(default_spot=False) - j1 = b.new_job() - j1.command('echo hello') - j2 = b.new_job() - j2.command('echo hello') - j3 = b.new_job() - j3.spot(True) - j3.command('echo hello') - res = b.run() - assert res is not None - assert res.get_job(1).status()['spec']['resources']['preemptible'] == False - assert res.get_job(2).status()['spec']['resources']['preemptible'] == False - assert res.get_job(3).status()['spec']['resources']['preemptible'] == True - - def test_local_file_paths_error(self): - b = self.batch() - j = b.new_job() - for input in ["hi.txt", "~/hello.csv", "./hey.tsv", "/sup.json", "file://yo.yaml"]: - with pytest.raises(ValueError) as e: - b.read_input(input) - assert str(e.value).startswith("Local filepath detected") - - @skip_in_azure - def test_validate_cloud_storage_policy(self): - # buckets do not exist (bucket names can't contain the string "google" per - # https://cloud.google.com/storage/docs/buckets) - fake_bucket1 = "google" - fake_bucket2 = "google1" - no_bucket_error = "bucket does not exist" - # bucket exists, but account does not have permissions on it - no_perms_bucket = "test" - no_perms_error = "does not have storage.buckets.get access" - # bucket exists and account has permissions, but is set to use cold storage by default - cold_bucket = "hail-test-cold-storage" - cold_error = "configured to use cold storage by default" - fake_uri1, fake_uri2, no_perms_uri, cold_uri = [ - f"gs://{bucket}/test" for bucket in [fake_bucket1, fake_bucket2, no_perms_bucket, cold_bucket] - ] - - def _test_raises(exception_type, exception_msg, func): - with pytest.raises(exception_type) as e: - func() - assert exception_msg in str(e.value) - - def 
_test_raises_no_bucket_error(remote_tmpdir, arg = None): - _test_raises(ClientResponseError, no_bucket_error, lambda: ServiceBackend(remote_tmpdir=remote_tmpdir, gcs_bucket_allow_list=arg)) - - def _test_raises_cold_error(func): - _test_raises(ValueError, cold_error, func) - - # no configuration, nonexistent buckets error - _test_raises_no_bucket_error(fake_uri1) - _test_raises_no_bucket_error(fake_uri2) - - # no configuration, no perms bucket errors - _test_raises(ClientResponseError, no_perms_error, lambda: ServiceBackend(remote_tmpdir=no_perms_uri)) - - # no configuration, cold bucket errors - _test_raises_cold_error(lambda: ServiceBackend(remote_tmpdir=cold_uri)) - b = self.batch() - _test_raises_cold_error(lambda: b.read_input(cold_uri)) - j = b.new_job() - j.command(f"echo hello > {j.ofile}") - _test_raises_cold_error(lambda: b.write_output(j.ofile, cold_uri)) - - # hailctl config, allowlisted nonexistent buckets don't error - base_config = get_user_config() - local_config = ConfigParser() - local_config.read_dict({ - **{ - section: {key: val for key, val in base_config[section].items()} - for section in base_config.sections() - }, - **{"gcs": {"bucket_allow_list": f"{fake_bucket1},{fake_bucket2}"}} - }) - def _get_user_config(): - return local_config - self.monkeypatch.setattr(user_config, "get_user_config", _get_user_config) - ServiceBackend(remote_tmpdir=fake_uri1) - ServiceBackend(remote_tmpdir=fake_uri2) - - # environment variable config, only allowlisted nonexistent buckets don't error - self.monkeypatch.setenv("HAIL_GCS_BUCKET_ALLOW_LIST", fake_bucket2) - _test_raises_no_bucket_error(fake_uri1) - ServiceBackend(remote_tmpdir=fake_uri2) - - # arg to constructor config, only allowlisted nonexistent buckets don't error - arg = [fake_bucket1] - ServiceBackend(remote_tmpdir=fake_uri1, gcs_bucket_allow_list=arg) - _test_raises_no_bucket_error(fake_uri2, arg) diff --git a/hail/python/test/hailtop/batch/test_batch_local_backend.py b/hail/python/test/hailtop/batch/test_batch_local_backend.py new file mode 100644 index 00000000000..dc059237d4c --- /dev/null +++ b/hail/python/test/hailtop/batch/test_batch_local_backend.py @@ -0,0 +1,506 @@ +import os +import subprocess as sp +import tempfile +from typing import AsyncIterator + +import pytest + +from hailtop import pip_version +from hailtop.batch import Batch, LocalBackend, ResourceGroup +from hailtop.batch.resource import JobResourceFile +from hailtop.batch.utils import concatenate + +DOCKER_ROOT_IMAGE = os.environ.get('DOCKER_ROOT_IMAGE', 'ubuntu:22.04') +PYTHON_DILL_IMAGE = 'hailgenetics/python-dill:3.9-slim' +HAIL_GENETICS_HAIL_IMAGE = os.environ.get('HAIL_GENETICS_HAIL_IMAGE', f'hailgenetics/hail:{pip_version()}') +REQUESTER_PAYS_PROJECT = os.environ.get('GCS_REQUESTER_PAYS_PROJECT') + + +@pytest.fixture(scope="session") +async def backend() -> AsyncIterator[LocalBackend]: + lb = LocalBackend() + try: + yield lb + finally: + await lb.async_close() + + +@pytest.fixture +def batch(backend, requester_pays_project=None): + return Batch(backend=backend, requester_pays_project=requester_pays_project) + + +def test_read_input_and_write_output(batch): + with tempfile.NamedTemporaryFile('w') as input_file, tempfile.NamedTemporaryFile('w') as output_file: + input_file.write('abc') + input_file.flush() + + b = batch + input = b.read_input(input_file.name) + b.write_output(input, output_file.name) + b.run() + + assert open(input_file.name).read() == open(output_file.name).read() + + +def test_read_input_group(batch): + with 
tempfile.NamedTemporaryFile('w') as input_file1, tempfile.NamedTemporaryFile( + 'w' + ) as input_file2, tempfile.NamedTemporaryFile('w') as output_file1, tempfile.NamedTemporaryFile( + 'w' + ) as output_file2: + input_file1.write('abc') + input_file2.write('123') + input_file1.flush() + input_file2.flush() + + b = batch + input = b.read_input_group(in1=input_file1.name, in2=input_file2.name) + + b.write_output(input.in1, output_file1.name) + b.write_output(input.in2, output_file2.name) + b.run() + + assert open(input_file1.name).read() == open(output_file1.name).read() + assert open(input_file2.name).read() == open(output_file2.name).read() + + +def test_write_resource_group(batch): + with tempfile.NamedTemporaryFile('w') as input_file1, tempfile.NamedTemporaryFile( + 'w' + ) as input_file2, tempfile.TemporaryDirectory() as output_dir: + b = batch + input = b.read_input_group(in1=input_file1.name, in2=input_file2.name) + + b.write_output(input, output_dir + '/foo') + b.run() + + assert open(input_file1.name).read() == open(output_dir + '/foo.in1').read() + assert open(input_file2.name).read() == open(output_dir + '/foo.in2').read() + + +def test_single_job(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + msg = 'hello world' + + b = batch + j = b.new_job() + j.command(f'printf "{msg}" > {j.ofile}') + b.write_output(j.ofile, output_file.name) + b.run() + + assert open(output_file.name).read() == msg + + +def test_single_job_with_shell(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + msg = 'hello world' + + b = batch + j = b.new_job(shell='/bin/bash') + j.command(f'printf "{msg}" > {j.ofile}') + + b.write_output(j.ofile, output_file.name) + b.run() + + assert open(output_file.name).read() == msg + + +def test_single_job_with_nonsense_shell(batch): + b = batch + j = b.new_job(shell='/bin/ajdsfoijasidojf') + j.image(DOCKER_ROOT_IMAGE) + j.command('printf "hello"') + with pytest.raises(Exception): + b.run() + + b = batch + j = b.new_job(shell='/bin/nonexistent') + j.command('printf "hello"') + with pytest.raises(Exception): + b.run() + + +def test_single_job_with_intermediate_failure(batch): + b = batch + j = b.new_job() + j.command('echoddd "hello"') + j2 = b.new_job() + j2.command('echo "world"') + + with pytest.raises(Exception): + b.run() + + +def test_single_job_w_input(batch): + with tempfile.NamedTemporaryFile('w') as input_file, tempfile.NamedTemporaryFile('w') as output_file: + msg = 'abc' + input_file.write(msg) + input_file.flush() + + b = batch + input = b.read_input(input_file.name) + j = b.new_job() + j.command(f'cat {input} > {j.ofile}') + b.write_output(j.ofile, output_file.name) + b.run() + + assert open(output_file.name).read() == msg + + +def test_single_job_w_input_group(batch): + with tempfile.NamedTemporaryFile('w') as input_file1, tempfile.NamedTemporaryFile( + 'w' + ) as input_file2, tempfile.NamedTemporaryFile('w') as output_file: + msg1 = 'abc' + msg2 = '123' + + input_file1.write(msg1) + input_file2.write(msg2) + input_file1.flush() + input_file2.flush() + + b = batch + input = b.read_input_group(in1=input_file1.name, in2=input_file2.name) + j = b.new_job() + j.command(f'cat {input.in1} {input.in2} > {j.ofile}') + j.command(f'cat {input}.in1 {input}.in2') + b.write_output(j.ofile, output_file.name) + b.run() + + assert open(output_file.name).read() == msg1 + msg2 + + +def test_single_job_bad_command(batch): + b = batch + j = b.new_job() + j.command("foo") # this should fail! 
+ with pytest.raises(sp.CalledProcessError): + b.run() + + +def test_declare_resource_group(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + msg = 'hello world' + b = batch + j = b.new_job() + j.declare_resource_group(ofile={'log': "{root}.txt"}) + assert isinstance(j.ofile, ResourceGroup) + j.command(f'printf "{msg}" > {j.ofile.log}') + b.write_output(j.ofile.log, output_file.name) + b.run() + + assert open(output_file.name).read() == msg + + +def test_resource_group_get_all_inputs(batch): + b = batch + input = b.read_input_group(fasta="foo", idx="bar") + j = b.new_job() + j.command(f"cat {input.fasta}") + assert input.fasta in j._inputs + assert input.idx in j._inputs + + +def test_resource_group_get_all_mentioned(batch): + b = batch + j = b.new_job() + j.declare_resource_group(foo={'bed': '{root}.bed', 'bim': '{root}.bim'}) + assert isinstance(j.foo, ResourceGroup) + j.command(f"cat {j.foo.bed}") + assert j.foo.bed in j._mentioned + assert j.foo.bim not in j._mentioned + + +def test_resource_group_get_all_mentioned_dependent_jobs(batch): + b = batch + j = b.new_job() + j.declare_resource_group(foo={'bed': '{root}.bed', 'bim': '{root}.bim'}) + j.command("cat") + j2 = b.new_job() + j2.command(f"cat {j.foo}") + + +def test_resource_group_get_all_outputs(batch): + b = batch + j1 = b.new_job() + j1.declare_resource_group(foo={'bed': '{root}.bed', 'bim': '{root}.bim'}) + assert isinstance(j1.foo, ResourceGroup) + j1.command(f"cat {j1.foo.bed}") + j2 = b.new_job() + j2.command(f"cat {j1.foo.bed}") + + for r in [j1.foo.bed, j1.foo.bim]: + assert r in j1._internal_outputs + assert r in j2._inputs + + assert j1.foo.bed in j1._mentioned + assert j1.foo.bim not in j1._mentioned + + assert j1.foo.bed in j2._mentioned + assert j1.foo.bim not in j2._mentioned + + assert j1.foo not in j1._mentioned + + +def test_multiple_isolated_jobs(batch): + b = batch + + output_files = [] + try: + output_files = [tempfile.NamedTemporaryFile('w') for _ in range(5)] + + for i, ofile in enumerate(output_files): + msg = f'hello world {i}' + j = b.new_job() + j.command(f'printf "{msg}" > {j.ofile}') + b.write_output(j.ofile, ofile.name) + b.run() + + for i, ofile in enumerate(output_files): + msg = f'hello world {i}' + assert open(ofile.name).read() == msg + finally: + [ofile.close() for ofile in output_files] + + +def test_multiple_dependent_jobs(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + j = b.new_job() + j.command(f'echo "0" >> {j.ofile}') + + for i in range(1, 3): + j2 = b.new_job() + j2.command(f'echo "{i}" > {j2.tmp1}') + j2.command(f'cat {j.ofile} {j2.tmp1} > {j2.ofile}') + j = j2 + + b.write_output(j.ofile, output_file.name) + b.run() + + assert open(output_file.name).read() == "0\n1\n2\n" + + +def test_select_jobs(batch): + b = batch + for i in range(3): + b.new_job(name=f'foo{i}') + assert len(b.select_jobs('foo')) == 3 + + +def test_scatter_gather(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + + for i in range(3): + j = b.new_job(name=f'foo{i}') + j.command(f'echo "{i}" > {j.ofile}') + + merger = b.new_job() + merger.command( + 'cat {files} > {ofile}'.format( + files=' '.join( + [j.ofile for j in sorted(b.select_jobs('foo'), key=lambda x: x.name, reverse=True)] # type: ignore + ), + ofile=merger.ofile, + ) + ) + + b.write_output(merger.ofile, output_file.name) + b.run() + + assert open(output_file.name).read() == '2\n1\n0\n' + + +def test_add_extension_job_resource_file(batch): + b = batch + j = b.new_job() + j.command(f'echo 
"hello" > {j.ofile}') + assert isinstance(j.ofile, JobResourceFile) + j.ofile.add_extension('.txt.bgz') + assert j.ofile._value + assert j.ofile._value.endswith('.txt.bgz') + + +def test_add_extension_input_resource_file(batch): + input_file1 = '/tmp/data/example1.txt.bgz.foo' + b = batch + in1 = b.read_input(input_file1) + assert in1._value + assert in1._value.endswith('.txt.bgz.foo') + + +def test_file_name_space(batch): + with tempfile.NamedTemporaryFile( + 'w', prefix="some file name with (foo) spaces" + ) as input_file, tempfile.NamedTemporaryFile('w', prefix="another file name with (foo) spaces") as output_file: + input_file.write('abc') + input_file.flush() + + b = batch + input = b.read_input(input_file.name) + j = b.new_job() + j.command(f'cat {input} > {j.ofile}') + b.write_output(j.ofile, output_file.name) + b.run() + + assert open(input_file.name).read() == open(output_file.name).read() + + +def test_resource_group_mentioned(batch): + b = batch + j = b.new_job() + j.declare_resource_group(foo={'bed': '{root}.bed'}) + assert isinstance(j.foo, ResourceGroup) + j.command(f'echo "hello" > {j.foo}') + + t2 = b.new_job() + t2.command(f'echo "hello" >> {j.foo.bed}') + b.run() + + +def test_envvar(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + j = b.new_job() + j.env('SOME_VARIABLE', '123abcdef') + j.command(f'printf $SOME_VARIABLE > {j.ofile}') + b.write_output(j.ofile, output_file.name) + b.run() + assert open(output_file.name).read() == '123abcdef' + + +def test_concatenate(batch): + b = batch + files = [] + for _ in range(10): + j = b.new_job() + j.command(f'touch {j.ofile}') + files.append(j.ofile) + concatenate(b, files, branching_factor=2) + assert len(b._jobs) == 10 + (5 + 3 + 2 + 1) + b.run() + + +def test_python_job(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + head = b.new_job() + head.command(f'echo "5" > {head.r5}') + head.command(f'echo "3" > {head.r3}') + + def read(path): + with open(path, 'r') as f: + i = f.read() + return int(i) + + def multiply(x, y): + return x * y + + def reformat(x, y): + return {'x': x, 'y': y} + + middle = b.new_python_job() + r3 = middle.call(read, head.r3) + r5 = middle.call(read, head.r5) + r_mult = middle.call(multiply, r3, r5) + + middle2 = b.new_python_job() + r_mult = middle2.call(multiply, r_mult, 2) + r_dict = middle2.call(reformat, r3, r5) + + tail = b.new_job() + tail.command(f'cat {r3.as_str()} {r5.as_repr()} {r_mult.as_str()} {r_dict.as_json()} > {tail.ofile}') + + b.write_output(tail.ofile, output_file.name) + b.run() + assert open(output_file.name).read() == '3\n5\n30\n{"x": 3, "y": 5}\n' + + +def test_backend_context_manager(): + with LocalBackend() as backend: + b = Batch(backend=backend) + b.run() + + +def test_failed_jobs_dont_stop_non_dependent_jobs(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + + head = b.new_job() + head.command(f'printf 1 > {head.ofile}') + + head2 = b.new_job() + head2.command('false') + + tail = b.new_job() + tail.command(f'cat {head.ofile} > {tail.ofile}') + b.write_output(tail.ofile, output_file.name) + with pytest.raises(Exception): + b.run() + assert open(output_file.name).read() == '1' + + +def test_failed_jobs_stop_child_jobs(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + + head = b.new_job() + head.command(f'printf 1 > {head.ofile}') + head.command('false') + + head2 = b.new_job() + head2.command(f'printf 2 > {head2.ofile}') + + tail = b.new_job() + tail.command(f'cat 
{head.ofile} > {tail.ofile}') + + b.write_output(head2.ofile, output_file.name) + b.write_output(tail.ofile, output_file.name) + with pytest.raises(Exception): + b.run() + assert open(output_file.name).read() == '2' + + +def test_failed_jobs_stop_grandchild_jobs(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + + head = b.new_job() + head.command(f'printf 1 > {head.ofile}') + head.command('false') + + head2 = b.new_job() + head2.command(f'printf 2 > {head2.ofile}') + + tail = b.new_job() + tail.command(f'cat {head.ofile} > {tail.ofile}') + + tail2 = b.new_job() + tail2.depends_on(tail) + tail2.command(f'printf foo > {tail2.ofile}') + + b.write_output(head2.ofile, output_file.name) + b.write_output(tail2.ofile, output_file.name) + with pytest.raises(Exception): + b.run() + assert open(output_file.name).read() == '2' + + +def test_failed_jobs_dont_stop_always_run_jobs(batch): + with tempfile.NamedTemporaryFile('w') as output_file: + b = batch + + head = b.new_job() + head.command(f'printf 1 > {head.ofile}') + head.command('false') + + tail = b.new_job() + tail.command(f'cat {head.ofile} > {tail.ofile}') + tail.always_run() + + b.write_output(tail.ofile, output_file.name) + with pytest.raises(Exception): + b.run() + assert open(output_file.name).read() == '1' diff --git a/hail/python/test/hailtop/batch/test_batch_pool_executor.py b/hail/python/test/hailtop/batch/test_batch_pool_executor.py index 97d4f6a0364..29a20958a3c 100644 --- a/hail/python/test/hailtop/batch/test_batch_pool_executor.py +++ b/hail/python/test/hailtop/batch/test_batch_pool_executor.py @@ -4,9 +4,8 @@ import pytest from hailtop.batch import BatchPoolExecutor, ServiceBackend -from hailtop.config import get_user_config -from hailtop.utils import sync_sleep_before_try from hailtop.batch_client.client import BatchClient +from hailtop.utils import sync_sleep_before_try PYTHON_DILL_IMAGE = 'hailgenetics/python-dill:3.9' @@ -62,6 +61,7 @@ def test_simple_submit_result(backend): def test_cancel_future(backend): with BatchPoolExecutor(backend=backend, project='hail-vdc', image=PYTHON_DILL_IMAGE) as bpe: + def sleep_forever(): while True: time.sleep(3600) @@ -74,6 +74,7 @@ def sleep_forever(): def test_cancel_future_after_shutdown_no_wait(backend): bpe = BatchPoolExecutor(backend=backend, project='hail-vdc', image=PYTHON_DILL_IMAGE) + def sleep_forever(): while True: time.sleep(3600) @@ -87,6 +88,7 @@ def sleep_forever(): def test_cancel_future_after_exit_no_wait_on_exit(backend): with BatchPoolExecutor(backend=backend, project='hail-vdc', wait_on_exit=False, image=PYTHON_DILL_IMAGE) as bpe: + def sleep_forever(): while True: time.sleep(3600) @@ -99,6 +101,7 @@ def sleep_forever(): def test_result_with_timeout(backend): with BatchPoolExecutor(backend=backend, project='hail-vdc', image=PYTHON_DILL_IMAGE) as bpe: + def sleep_forever(): while True: time.sleep(3600) @@ -115,30 +118,20 @@ def sleep_forever(): def test_map_chunksize(backend): - row_args = [x - for row in range(5) - for x in [row, row, row, row, row]] - col_args = [x - for _ in range(5) - for x in list(range(5))] + row_args = [x for row in range(5) for x in [row, row, row, row, row]] + col_args = [x for _ in range(5) for x in list(range(5))] with BatchPoolExecutor(backend=backend, project='hail-vdc', image=PYTHON_DILL_IMAGE) as bpe: - multiplication_table = list(bpe.map(lambda x, y: x * y, - row_args, - col_args, - chunksize=5)) - assert multiplication_table == [ - 0, 0, 0, 0, 0, - 0, 1, 2, 3, 4, - 0, 2, 4, 6, 8, - 0, 3, 6, 9, 12, - 0, 4, 8, 12, 
16] + multiplication_table = list(bpe.map(lambda x, y: x * y, row_args, col_args, chunksize=5)) + assert multiplication_table == [0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 0, 2, 4, 6, 8, 0, 3, 6, 9, 12, 0, 4, 8, 12, 16] def test_map_timeout(backend): with BatchPoolExecutor(backend=backend, project='hail-vdc', image=PYTHON_DILL_IMAGE) as bpe: + def sleep_forever(): while True: time.sleep(3600) + try: list(bpe.map(lambda _: sleep_forever(), range(5), timeout=2)) except asyncio.TimeoutError: @@ -155,6 +148,7 @@ def test_map_error_without_wait_no_error(backend): def test_exception_in_map(backend): def raise_value_error(): raise ValueError('dead') + with BatchPoolExecutor(backend=backend, project='hail-vdc', image=PYTHON_DILL_IMAGE) as bpe: try: gen = bpe.map(lambda _: raise_value_error(), range(5)) @@ -168,6 +162,7 @@ def raise_value_error(): def test_exception_in_result(backend): def raise_value_error(): raise ValueError('dead') + with BatchPoolExecutor(backend=backend, project='hail-vdc', image=PYTHON_DILL_IMAGE) as bpe: try: future = bpe.submit(raise_value_error) @@ -181,6 +176,7 @@ def raise_value_error(): def test_exception_in_exception(backend): def raise_value_error(): raise ValueError('dead') + with BatchPoolExecutor(backend=backend, project='hail-vdc', image=PYTHON_DILL_IMAGE) as bpe: try: future = bpe.submit(raise_value_error) @@ -194,6 +190,7 @@ def raise_value_error(): def test_no_exception_when_exiting_context(backend): def raise_value_error(): raise ValueError('dead') + with BatchPoolExecutor(backend=backend, project='hail-vdc', image=PYTHON_DILL_IMAGE) as bpe: future = bpe.submit(raise_value_error) try: @@ -205,10 +202,7 @@ def raise_value_error(): def test_bad_image_gives_good_error(backend): - with BatchPoolExecutor( - backend=backend, - project='hail-vdc', - image='hailgenetics/not-a-valid-image:123abc') as bpe: + with BatchPoolExecutor(backend=backend, project='hail-vdc', image='hailgenetics/not-a-valid-image:123abc') as bpe: future = bpe.submit(lambda: 3) try: future.exception() @@ -220,6 +214,7 @@ def test_bad_image_gives_good_error(backend): def test_call_result_after_timeout(): with BatchPoolExecutor(project='hail-vdc', image=PYTHON_DILL_IMAGE) as bpe: + def sleep_forever(): while True: time.sleep(3600) diff --git a/hail/python/test/hailtop/batch/test_batch_service_backend.py b/hail/python/test/hailtop/batch/test_batch_service_backend.py new file mode 100644 index 00000000000..9f85fedeb67 --- /dev/null +++ b/hail/python/test/hailtop/batch/test_batch_service_backend.py @@ -0,0 +1,783 @@ +import os +import secrets +from configparser import ConfigParser +from shlex import quote as shq +from typing import Tuple + +import pytest + +from hailtop.aiotools.router_fs import RouterAsyncFS +from hailtop.batch import Batch, ResourceGroup, ServiceBackend +from hailtop.batch.exceptions import BatchException +from hailtop.batch.globals import arg_max +from hailtop.config import get_user_config, user_config +from hailtop.httpx import ClientResponseError +from hailtop.test_utils import skip_in_azure +from hailtop.utils import grouped + +from .utils import ( + REQUESTER_PAYS_PROJECT, + batch, +) + + +def test_single_task_no_io(service_backend: ServiceBackend): + b = batch(service_backend) + j = b.new_job() + j.command('echo hello') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_task_input( + service_backend: ServiceBackend, upload_test_files: Tuple[Tuple[str, bytes], Tuple[str, bytes], 
Tuple[str, bytes]] +): + (url1, data1), _, _ = upload_test_files + b = batch(service_backend) + input = b.read_input(url1) + j = b.new_job() + j.command(f'cat {input}') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_task_input_resource_group( + service_backend: ServiceBackend, upload_test_files: Tuple[Tuple[str, bytes], Tuple[str, bytes], Tuple[str, bytes]] +): + (url1, data1), _, _ = upload_test_files + b = batch(service_backend) + input = b.read_input_group(foo=url1) + j = b.new_job() + j.storage('10Gi') + j.command(f'cat {input.foo}') + j.command(f'cat {input}.foo') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_task_output(service_backend: ServiceBackend): + b = batch(service_backend) + j = b.new_job(attributes={'a': 'bar', 'b': 'foo'}) + j.command(f'echo hello > {j.ofile}') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_task_write_output(service_backend: ServiceBackend, output_tmpdir: str): + b = batch(service_backend) + j = b.new_job() + j.command(f'echo hello > {j.ofile}') + b.write_output(j.ofile, os.path.join(output_tmpdir, 'test_single_task_output.txt')) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_task_resource_group(service_backend: ServiceBackend): + b = batch(service_backend) + j = b.new_job() + j.declare_resource_group(output={'foo': '{root}.foo'}) + assert isinstance(j.output, ResourceGroup) + j.command(f'echo "hello" > {j.output.foo}') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_task_write_resource_group(service_backend: ServiceBackend, output_tmpdir: str): + b = batch(service_backend) + j = b.new_job() + j.declare_resource_group(output={'foo': '{root}.foo'}) + assert isinstance(j.output, ResourceGroup) + j.command(f'echo "hello" > {j.output.foo}') + b.write_output(j.output, os.path.join(output_tmpdir, 'test_single_task_write_resource_group')) + b.write_output(j.output.foo, os.path.join(output_tmpdir, 'test_single_task_write_resource_group_file.txt')) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_multiple_dependent_tasks(service_backend: ServiceBackend, output_tmpdir: str): + output_file = os.path.join(output_tmpdir, 'test_multiple_dependent_tasks.txt') + b = batch(service_backend) + j = b.new_job() + j.command(f'echo "0" >> {j.ofile}') + + for i in range(1, 3): + j2 = b.new_job() + j2.command(f'echo "{i}" > {j2.tmp1}') + j2.command(f'cat {j.ofile} {j2.tmp1} > {j2.ofile}') + j = j2 + + b.write_output(j.ofile, output_file) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_specify_cpu(service_backend: ServiceBackend): + b = batch(service_backend) + j = b.new_job() + j.cpu('0.5') + j.command(f'echo "hello" > {j.ofile}') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_specify_memory(service_backend: ServiceBackend): + b = 
batch(service_backend) + j = b.new_job() + j.memory('100M') + j.command(f'echo "hello" > {j.ofile}') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_scatter_gather(service_backend: ServiceBackend): + b = batch(service_backend) + + for i in range(3): + j = b.new_job(name=f'foo{i}') + j.command(f'echo "{i}" > {j.ofile}') + + merger = b.new_job() + merger.command( + 'cat {files} > {ofile}'.format( + files=' '.join( + [j.ofile for j in sorted(b.select_jobs('foo'), key=lambda x: x.name, reverse=True)] # type: ignore + ), + ofile=merger.ofile, + ) + ) + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_file_name_space( + service_backend: ServiceBackend, + upload_test_files: Tuple[Tuple[str, bytes], Tuple[str, bytes], Tuple[str, bytes]], + output_tmpdir: str, +): + _, _, (url3, data3) = upload_test_files + b = batch(service_backend) + input = b.read_input(url3) + j = b.new_job() + j.command(f'cat {input} > {j.ofile}') + b.write_output(j.ofile, os.path.join(output_tmpdir, 'hello (foo) spaces.txt')) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_dry_run(service_backend: ServiceBackend, output_tmpdir: str): + b = batch(service_backend) + j = b.new_job() + j.command(f'echo hello > {j.ofile}') + b.write_output(j.ofile, os.path.join(output_tmpdir, 'test_single_job_output.txt')) + b.run(dry_run=True) + + +def test_verbose( + service_backend: ServiceBackend, + upload_test_files: Tuple[Tuple[str, bytes], Tuple[str, bytes], Tuple[str, bytes]], + output_tmpdir: str, +): + (url1, data1), _, _ = upload_test_files + b = batch(service_backend) + input = b.read_input(url1) + j = b.new_job() + j.command(f'cat {input}') + b.write_output(input, os.path.join(output_tmpdir, 'hello.txt')) + res = b.run(verbose=True) + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_cloudfuse_fails_with_read_write_mount_option( + fs: RouterAsyncFS, service_backend: ServiceBackend, output_bucket_path +): + bucket, path, output_tmpdir = output_bucket_path + + b = batch(service_backend) + j = b.new_job() + j.command(f'mkdir -p {path}; echo head > {path}/cloudfuse_test_1') + j.cloudfuse(bucket, f'/{bucket}', read_only=False) + + try: + b.run() + except ClientResponseError as e: + assert 'Only read-only cloudfuse requests are supported' in e.body, e.body + else: + assert False + + +def test_cloudfuse_fails_with_io_mount_point(fs: RouterAsyncFS, service_backend: ServiceBackend, output_bucket_path): + bucket, path, output_tmpdir = output_bucket_path + + b = batch(service_backend) + j = b.new_job() + j.command(f'mkdir -p {path}; echo head > {path}/cloudfuse_test_1') + j.cloudfuse(bucket, '/io', read_only=True) + + try: + b.run() + except ClientResponseError as e: + assert 'Cloudfuse requests with mount_path=/io are not supported' in e.body, e.body + else: + assert False + + +def test_cloudfuse_read_only(service_backend: ServiceBackend, output_bucket_path): + bucket, path, output_tmpdir = output_bucket_path + + b = batch(service_backend) + j = b.new_job() + j.command(f'mkdir -p {path}; echo head > {path}/cloudfuse_test_1') + j.cloudfuse(bucket, f'/{bucket}', read_only=True) + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 
'failure', str((res_status, res.debug_info())) + + +def test_cloudfuse_implicit_dirs(fs: RouterAsyncFS, service_backend: ServiceBackend, upload_test_files): + (url1, data1), _, _ = upload_test_files + parsed_url1 = fs.parse_url(url1) + object_name = parsed_url1.path + bucket_name = '/'.join(parsed_url1.bucket_parts) + + b = batch(service_backend) + j = b.new_job() + j.command('cat ' + os.path.join('/cloudfuse', object_name)) + j.cloudfuse(bucket_name, '/cloudfuse', read_only=True) + + res = b.run() + assert res + res_status = res.status() + assert res.get_job_log(1)['main'] == data1.decode() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_cloudfuse_empty_string_bucket_fails(service_backend: ServiceBackend, output_bucket_path): + bucket, path, output_tmpdir = output_bucket_path + + b = batch(service_backend) + j = b.new_job() + with pytest.raises(BatchException): + j.cloudfuse('', '/empty_bucket') + with pytest.raises(BatchException): + j.cloudfuse(bucket, '') + + +async def test_cloudfuse_submount_in_io_doesnt_rm_bucket( + fs: RouterAsyncFS, service_backend: ServiceBackend, output_bucket_path +): + bucket, path, output_tmpdir = output_bucket_path + + should_still_exist_url = os.path.join(output_tmpdir, 'should-still-exist') + await fs.write(should_still_exist_url, b'should-still-exist') + + b = batch(service_backend) + j = b.new_job() + j.cloudfuse(bucket, '/io/cloudfuse') + j.command('ls /io/cloudfuse/') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert await fs.read(should_still_exist_url) == b'should-still-exist' + + +@skip_in_azure +def test_fuse_requester_pays(service_backend: ServiceBackend): + assert REQUESTER_PAYS_PROJECT + b = batch(service_backend, requester_pays_project=REQUESTER_PAYS_PROJECT) + j = b.new_job() + j.cloudfuse('hail-test-requester-pays-fds32', '/fuse-bucket') + j.command('cat /fuse-bucket/hello') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +@skip_in_azure +def test_fuse_non_requester_pays_bucket_when_requester_pays_project_specified( + service_backend: ServiceBackend, output_bucket_path +): + bucket, path, output_tmpdir = output_bucket_path + assert REQUESTER_PAYS_PROJECT + + b = batch(service_backend, requester_pays_project=REQUESTER_PAYS_PROJECT) + j = b.new_job() + j.command('ls /fuse-bucket') + j.cloudfuse(bucket, '/fuse-bucket', read_only=True) + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +@skip_in_azure +def test_requester_pays(service_backend: ServiceBackend): + assert REQUESTER_PAYS_PROJECT + b = batch(service_backend, requester_pays_project=REQUESTER_PAYS_PROJECT) + input = b.read_input('gs://hail-test-requester-pays-fds32/hello') + j = b.new_job() + j.command(f'cat {input}') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_benchmark_lookalike_workflow(service_backend: ServiceBackend, output_tmpdir): + b = batch(service_backend) + + setup_jobs = [] + for i in range(10): + j = b.new_job(f'setup_{i}').cpu(0.25) + j.command(f'echo "foo" > {j.ofile}') + setup_jobs.append(j) + + jobs = [] + for i in range(500): + j = b.new_job(f'create_file_{i}').cpu(0.25) + j.command(f'echo {setup_jobs[i % len(setup_jobs)].ofile} > {j.ofile}') + 
j.command(f'echo "bar" >> {j.ofile}') + jobs.append(j) + + combine = b.new_job('combine_output').cpu(0.25) + for _ in grouped(arg_max(), jobs): + combine.command(f'cat {" ".join(shq(j.ofile) for j in jobs)} >> {combine.ofile}') + b.write_output(combine.ofile, os.path.join(output_tmpdir, 'pipeline_benchmark_test.txt')) + # too slow + # assert b.run().status()['state'] == 'success' + + +def test_envvar(service_backend: ServiceBackend): + b = batch(service_backend) + j = b.new_job() + j.env('SOME_VARIABLE', '123abcdef') + j.command('[ $SOME_VARIABLE = "123abcdef" ]') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_job_with_shell(service_backend: ServiceBackend): + msg = 'hello world' + b = batch(service_backend) + j = b.new_job(shell='/bin/sh') + j.command(f'echo "{msg}"') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_single_job_with_nonsense_shell(service_backend: ServiceBackend): + b = batch(service_backend) + j = b.new_job(shell='/bin/ajdsfoijasidojf') + j.command('echo "hello"') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + +def test_single_job_with_intermediate_failure(service_backend: ServiceBackend): + b = batch(service_backend) + j = b.new_job() + j.command('echoddd "hello"') + j2 = b.new_job() + j2.command('echo "world"') + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + +def test_input_directory( + service_backend: ServiceBackend, upload_test_files: Tuple[Tuple[str, bytes], Tuple[str, bytes], Tuple[str, bytes]] +): + (url1, data1), _, _ = upload_test_files + b = batch(service_backend) + containing_folder = '/'.join(url1.rstrip('/').split('/')[:-1]) + input1 = b.read_input(containing_folder) + input2 = b.read_input(containing_folder + '/') + j = b.new_job() + j.command(f'ls {input1}/hello.txt') + j.command(f'ls {input2}/hello.txt') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_fail_fast(service_backend: ServiceBackend): + b = batch(service_backend, cancel_after_n_failures=1) + j1 = b.new_job() + j1.command('false') + + j2 = b.new_job() + j2.command('sleep 300') + + res = b.run() + assert res + job_status = res.get_job(2).status() + assert job_status['state'] == 'Cancelled', str((job_status, res.debug_info())) + + +def test_service_backend_remote_tempdir_with_trailing_slash(service_backend: ServiceBackend): + b = Batch(backend=service_backend) + j1 = b.new_job() + j1.command(f'echo hello > {j1.ofile}') + j2 = b.new_job() + j2.command(f'cat {j1.ofile}') + b.run() + + +def test_service_backend_remote_tempdir_with_no_trailing_slash(service_backend: ServiceBackend): + b = Batch(backend=service_backend) + j1 = b.new_job() + j1.command(f'echo hello > {j1.ofile}') + j2 = b.new_job() + j2.command(f'cat {j1.ofile}') + b.run() + + +def test_large_command(service_backend: ServiceBackend): + b = Batch(backend=service_backend) + j1 = b.new_job() + long_str = secrets.token_urlsafe(15 * 1024) + j1.command(f'echo "{long_str}"') + b.run() + + +def test_big_batch_which_uses_slow_path(service_backend: ServiceBackend): + b = Batch(backend=service_backend) + # 8 * 256 * 1024 = 2 MiB > 1 MiB max bunch size + for 
_ in range(8): + j1 = b.new_job() + long_str = secrets.token_urlsafe(256 * 1024) + j1.command(f'echo "{long_str}" > /dev/null') + res = b.run() + assert res + assert not res._submission_info.used_fast_path + batch_status = res.status() + assert batch_status['state'] == 'success', str((res.debug_info())) + + +def test_specify_job_region(service_backend: ServiceBackend): + b = batch(service_backend) + j = b.new_job('region') + possible_regions = service_backend.supported_regions() + j.regions(possible_regions) + j.command('true') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_job_regions_controls_job_execution_region(service_backend: ServiceBackend): + the_region = service_backend.supported_regions()[0] + + b = batch(service_backend) + j = b.new_job() + j.regions([the_region]) + j.command('true') + res = b.run() + + assert res + job_status = res.get_job(1).status() + assert job_status['status']['region'] == the_region, str((job_status, res.debug_info())) + + +def test_job_regions_overrides_batch_regions(service_backend: ServiceBackend): + the_region = service_backend.supported_regions()[0] + + b = batch(service_backend, default_regions=['some-other-region']) + j = b.new_job() + j.regions([the_region]) + j.command('true') + res = b.run() + + assert res + job_status = res.get_job(1).status() + assert job_status['status']['region'] == the_region, str((job_status, res.debug_info())) + + +def test_always_copy_output(service_backend: ServiceBackend, output_tmpdir: str): + output_path = os.path.join(output_tmpdir, 'test_always_copy_output.txt') + + b = batch(service_backend) + j = b.new_job() + j.always_copy_output() + j.command(f'echo "hello" > {j.ofile} && false') + + b.write_output(j.ofile, output_path) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + b2 = batch(service_backend) + input = b2.read_input(output_path) + file_exists_j = b2.new_job() + file_exists_j.command(f'cat {input}') + + res = b2.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(1)['main'] == "hello\n", str(res.debug_info()) + + +def test_no_copy_output_on_failure(service_backend: ServiceBackend, output_tmpdir: str): + output_path = os.path.join(output_tmpdir, 'test_no_copy_output.txt') + + b = batch(service_backend) + j = b.new_job() + j.command(f'echo "hello" > {j.ofile} && false') + + b.write_output(j.ofile, output_path) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + b2 = batch(service_backend) + input = b2.read_input(output_path) + file_exists_j = b2.new_job() + file_exists_j.command(f'cat {input}') + + res = b2.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + +def test_update_batch(service_backend: ServiceBackend): + b = batch(service_backend) + j = b.new_job() + j.command('true') + res = b.run() + assert res + + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + j2 = b.new_job() + j2.command('true') + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_update_batch_with_dependencies(service_backend: 
ServiceBackend): + b = batch(service_backend) + j1 = b.new_job() + j1.command('true') + j2 = b.new_job() + j2.command('false') + res = b.run() + assert res + + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + j3 = b.new_job() + j3.command('true') + j3.depends_on(j1) + + j4 = b.new_job() + j4.command('true') + j4.depends_on(j2) + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + assert res.get_job(3).status()['state'] == 'Success', str((res_status, res.debug_info())) + assert res.get_job(4).status()['state'] == 'Cancelled', str((res_status, res.debug_info())) + + +def test_update_batch_from_batch_id(service_backend: ServiceBackend): + b = batch(service_backend) + j = b.new_job() + j.command('true') + res = b.run() + assert res + + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + b2 = Batch.from_batch_id(res.id, backend=b._backend) + j2 = b2.new_job() + j2.command('true') + res = b2.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + + +def test_wait_on_empty_batch_update(service_backend: ServiceBackend): + b = batch(service_backend) + b.run(wait=True) + b.run(wait=True) + + +def test_non_spot_job(service_backend: ServiceBackend): + b = batch(service_backend) + j = b.new_job() + j.spot(False) + j.command('echo hello') + res = b.run() + assert res + assert res.get_job(1).status()['spec']['resources']['preemptible'] is False + + +def test_spot_unspecified_job(service_backend: ServiceBackend): + b = batch(service_backend) + j = b.new_job() + j.command('echo hello') + res = b.run() + assert res + assert res.get_job(1).status()['spec']['resources']['preemptible'] is True + + +def test_spot_true_job(service_backend: ServiceBackend): + b = batch(service_backend) + j = b.new_job() + j.spot(True) + j.command('echo hello') + res = b.run() + assert res + assert res.get_job(1).status()['spec']['resources']['preemptible'] is True + + +def test_non_spot_batch(service_backend: ServiceBackend): + b = batch(service_backend, default_spot=False) + j1 = b.new_job() + j1.command('echo hello') + j2 = b.new_job() + j2.command('echo hello') + j3 = b.new_job() + j3.spot(True) + j3.command('echo hello') + res = b.run() + assert res + assert res.get_job(1).status()['spec']['resources']['preemptible'] is False + assert res.get_job(2).status()['spec']['resources']['preemptible'] is False + assert res.get_job(3).status()['spec']['resources']['preemptible'] is True + + +def test_local_file_paths_error(service_backend: ServiceBackend): + b = batch(service_backend) + b.new_job() + for input in ["hi.txt", "~/hello.csv", "./hey.tsv", "/sup.json", "file://yo.yaml"]: + with pytest.raises(ValueError) as e: + b.read_input(input) + assert str(e.value).startswith("Local filepath detected") + + +@skip_in_azure +def test_validate_cloud_storage_policy(service_backend: ServiceBackend, monkeypatch): + # buckets do not exist (bucket names can't contain the string "google" per + # https://cloud.google.com/storage/docs/buckets) + fake_bucket1 = "google" + fake_bucket2 = "google1" + no_bucket_error = "bucket does not exist" + # bucket exists, but account does not have permissions on it + no_perms_bucket = "test" + no_perms_error = "does not have storage.buckets.get access" + # bucket exists and account has permissions, but is set to use cold storage by default 
+ cold_bucket = "hail-test-cold-storage" + cold_error = "configured to use cold storage by default" + fake_uri1, fake_uri2, no_perms_uri, cold_uri = [ + f"gs://{bucket}/test" for bucket in [fake_bucket1, fake_bucket2, no_perms_bucket, cold_bucket] + ] + + def _test_raises(exception_type, exception_msg, func): + with pytest.raises(exception_type) as e: + func() + assert exception_msg in str(e.value) + + def _test_raises_no_bucket_error(remote_tmpdir, arg=None): + _test_raises( + ClientResponseError, + no_bucket_error, + lambda: ServiceBackend(remote_tmpdir=remote_tmpdir, gcs_bucket_allow_list=arg), + ) + + def _test_raises_cold_error(func): + _test_raises(ValueError, cold_error, func) + + # no configuration, nonexistent buckets error + _test_raises_no_bucket_error(fake_uri1) + _test_raises_no_bucket_error(fake_uri2) + + # no configuration, no perms bucket errors + _test_raises(ClientResponseError, no_perms_error, lambda: ServiceBackend(remote_tmpdir=no_perms_uri)) + + # no configuration, cold bucket errors + _test_raises_cold_error(lambda: ServiceBackend(remote_tmpdir=cold_uri)) + b = batch(service_backend) + _test_raises_cold_error(lambda: b.read_input(cold_uri)) + j = b.new_job() + j.command(f"echo hello > {j.ofile}") + _test_raises_cold_error(lambda: b.write_output(j.ofile, cold_uri)) + + # hailctl config, allowlisted nonexistent buckets don't error + base_config = get_user_config() + local_config = ConfigParser() + local_config.read_dict({ + **{section: {key: val for key, val in base_config[section].items()} for section in base_config.sections()}, + **{"gcs": {"bucket_allow_list": f"{fake_bucket1},{fake_bucket2}"}}, + }) + + def _get_user_config(): + return local_config + + monkeypatch.setattr(user_config, "get_user_config", _get_user_config) + ServiceBackend(remote_tmpdir=fake_uri1) + ServiceBackend(remote_tmpdir=fake_uri2) + + # environment variable config, only allowlisted nonexistent buckets don't error + monkeypatch.setenv("HAIL_GCS_BUCKET_ALLOW_LIST", fake_bucket2) + _test_raises_no_bucket_error(fake_uri1) + ServiceBackend(remote_tmpdir=fake_uri2) + + # arg to constructor config, only allowlisted nonexistent buckets don't error + arg = [fake_bucket1] + ServiceBackend(remote_tmpdir=fake_uri1, gcs_bucket_allow_list=arg) + _test_raises_no_bucket_error(fake_uri2, arg) diff --git a/hail/python/test/hailtop/batch/test_python_job_in_service.py b/hail/python/test/hailtop/batch/test_python_job_in_service.py new file mode 100644 index 00000000000..fae274effc9 --- /dev/null +++ b/hail/python/test/hailtop/batch/test_python_job_in_service.py @@ -0,0 +1,368 @@ +import asyncio +import os +import secrets + +import orjson +import pytest + +import hailtop.batch_client.client as bc +from hailtop.aiotools.router_fs import RouterAsyncFS +from hailtop.batch import Batch, ResourceGroup, ServiceBackend +from hailtop.batch.exceptions import BatchException +from hailtop.config import configuration_of +from hailtop.config.variables import ConfigVariable + +from .utils import ( + HAIL_GENETICS_HAIL_IMAGE, + PYTHON_DILL_IMAGE, + batch, +) + + +def test_python_job(service_backend: ServiceBackend): + b = batch(service_backend, default_python_image=PYTHON_DILL_IMAGE) + head = b.new_job() + head.command(f'echo "5" > {head.r5}') + head.command(f'echo "3" > {head.r3}') + + def read(path): + with open(path, 'r') as f: + i = f.read() + return int(i) + + def multiply(x, y): + return x * y + + def reformat(x, y): + return {'x': x, 'y': y} + + middle = b.new_python_job() + r3 = middle.call(read, head.r3) + r5 = 
middle.call(read, head.r5) + r_mult = middle.call(multiply, r3, r5) + + middle2 = b.new_python_job() + r_mult = middle2.call(multiply, r_mult, 2) + r_dict = middle2.call(reformat, r3, r5) + + tail = b.new_job() + tail.command(f'cat {r3.as_str()} {r5.as_repr()} {r_mult.as_str()} {r_dict.as_json()}') + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(4)['main'] == "3\n5\n30\n{\"x\": 3, \"y\": 5}\n", str(res.debug_info()) + + +def test_python_job_w_resource_group_unpack_individually(service_backend: ServiceBackend): + b = batch(service_backend, default_python_image=PYTHON_DILL_IMAGE) + head = b.new_job() + head.declare_resource_group(count={'r5': '{root}.r5', 'r3': '{root}.r3'}) + assert isinstance(head.count, ResourceGroup) + + head.command(f'echo "5" > {head.count.r5}') + head.command(f'echo "3" > {head.count.r3}') + + def read(path): + with open(path, 'r') as f: + r = int(f.read()) + return r + + def multiply(x, y): + return x * y + + def reformat(x, y): + return {'x': x, 'y': y} + + middle = b.new_python_job() + r3 = middle.call(read, head.count.r3) + r5 = middle.call(read, head.count.r5) + r_mult = middle.call(multiply, r3, r5) + + middle2 = b.new_python_job() + r_mult = middle2.call(multiply, r_mult, 2) + r_dict = middle2.call(reformat, r3, r5) + + tail = b.new_job() + tail.command(f'cat {r3.as_str()} {r5.as_repr()} {r_mult.as_str()} {r_dict.as_json()}') + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(4)['main'] == "3\n5\n30\n{\"x\": 3, \"y\": 5}\n", str(res.debug_info()) + + +def test_python_job_can_write_to_resource_path(service_backend: ServiceBackend): + b = batch(service_backend, default_python_image=PYTHON_DILL_IMAGE) + + def write(path): + with open(path, 'w') as f: + f.write('foo') + + head = b.new_python_job() + head.call(write, head.ofile) + + tail = b.new_bash_job() + tail.command(f'cat {head.ofile}') + + res = b.run() + assert res + assert tail._job_id + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(tail._job_id)['main'] == 'foo', str(res.debug_info()) + + +def test_python_job_w_resource_group_unpack_jointly(service_backend: ServiceBackend): + b = batch(service_backend, default_python_image=PYTHON_DILL_IMAGE) + head = b.new_job() + head.declare_resource_group(count={'r5': '{root}.r5', 'r3': '{root}.r3'}) + assert isinstance(head.count, ResourceGroup) + + head.command(f'echo "5" > {head.count.r5}') + head.command(f'echo "3" > {head.count.r3}') + + def read_rg(root): + with open(root['r3'], 'r') as f: + r3 = int(f.read()) + with open(root['r5'], 'r') as f: + r5 = int(f.read()) + return (r3, r5) + + def multiply(r): + x, y = r + return x * y + + middle = b.new_python_job() + r = middle.call(read_rg, head.count) + r_mult = middle.call(multiply, r) + + tail = b.new_job() + tail.command(f'cat {r_mult.as_str()}') + + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + job_log_3 = res.get_job_log(3) + assert job_log_3['main'] == "15\n", str((job_log_3, res.debug_info())) + + +def test_python_job_w_non_zero_ec(service_backend: ServiceBackend): + b = batch(service_backend, default_python_image=PYTHON_DILL_IMAGE) + j = b.new_python_job() + + def error(): + raise Exception("this should fail") + + 
j.call(error) + res = b.run() + assert res + res_status = res.status() + assert res_status['state'] == 'failure', str((res_status, res.debug_info())) + + +def test_python_job_incorrect_signature(service_backend: ServiceBackend): + b = batch(service_backend, default_python_image=PYTHON_DILL_IMAGE) + + def foo(pos_arg1, pos_arg2, *, kwarg1, kwarg2=1): + print(pos_arg1, pos_arg2, kwarg1, kwarg2) + + j = b.new_python_job() + + with pytest.raises(BatchException): + j.call(foo) + with pytest.raises(BatchException): + j.call(foo, 1) + with pytest.raises(BatchException): + j.call(foo, 1, 2) + with pytest.raises(BatchException): + j.call(foo, 1, kwarg1=2) + with pytest.raises(BatchException): + j.call(foo, 1, 2, 3) + with pytest.raises(BatchException): + j.call(foo, 1, 2, kwarg1=3, kwarg2=4, kwarg3=5) + + j.call(foo, 1, 2, kwarg1=3) + j.call(foo, 1, 2, kwarg1=3, kwarg2=4) + + # `print` doesn't have a signature but other builtins like `abs` do + j.call(print, 5) + j.call(abs, -1) + with pytest.raises(BatchException): + j.call(abs, -1, 5) + + +def test_query_on_batch_in_batch(service_backend: ServiceBackend, output_tmpdir: str): + bb = Batch(backend=service_backend, default_python_image=HAIL_GENETICS_HAIL_IMAGE) + + tmp_ht_path = os.path.join(output_tmpdir, secrets.token_urlsafe(32)) + + def qob_in_batch(): + import hail as hl + + hl.utils.range_table(10).write(tmp_ht_path, overwrite=True) + + j = bb.new_python_job() + j.env('HAIL_QUERY_BACKEND', 'batch') + j.env('HAIL_BATCH_BILLING_PROJECT', configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, None, '')) + j.env('HAIL_BATCH_REMOTE_TMPDIR', output_tmpdir) + j.call(qob_in_batch) + + bb.run() + + +def test_basic_async_fun(service_backend: ServiceBackend): + b = Batch(backend=service_backend) + + j = b.new_python_job() + j.call(asyncio.sleep, 1) + + res = b.run() + assert res + batch_status = res.status() + assert batch_status['state'] == 'success', str((res.debug_info())) + + +def test_async_fun_returns_value(service_backend: ServiceBackend): + b = Batch(backend=service_backend) + + async def foo(i, j): + await asyncio.sleep(1) + return i * j + + j = b.new_python_job() + result = j.call(foo, 2, 3) + + j = b.new_job() + j.command(f'cat {result.as_str()}') + + res = b.run() + assert res + batch_status = res.status() + assert batch_status['state'] == 'success', str((batch_status, res.debug_info())) + job_log_2 = res.get_job_log(2) + assert job_log_2['main'] == "6\n", str((job_log_2, res.debug_info())) + + +def test_update_batch_with_python_job_dependencies(service_backend: ServiceBackend): + b = batch(service_backend) + + async def foo(i, j): + await asyncio.sleep(1) + return i * j + + j1 = b.new_python_job() + j1.call(foo, 2, 3) + + res = b.run() + assert res + batch_status = res.status() + assert batch_status['state'] == 'success', str((batch_status, res.debug_info())) + + j2 = b.new_python_job() + j2.call(foo, 2, 3) + + res = b.run() + assert res + batch_status = res.status() + assert batch_status['state'] == 'success', str((batch_status, res.debug_info())) + + j3 = b.new_python_job() + j3.depends_on(j2) + j3.call(foo, 2, 3) + + res = b.run() + assert res + batch_status = res.status() + assert batch_status['state'] == 'success', str((batch_status, res.debug_info())) + + +async def test_python_job_with_kwarg(fs: RouterAsyncFS, service_backend: ServiceBackend, output_tmpdir: str): + def foo(*, kwarg): + return kwarg + + b = batch(service_backend, default_python_image=PYTHON_DILL_IMAGE) + j = b.new_python_job() + r = j.call(foo, kwarg='hello world') + + 
output_path = os.path.join(output_tmpdir, 'test_python_job_with_kwarg') + b.write_output(r.as_json(), output_path) + res = b.run() + assert isinstance(res, bc.Batch) + + assert res.status()['state'] == 'success', str((res, res.debug_info())) + assert orjson.loads(await fs.read(output_path)) == 'hello world' + + +def test_tuple_recursive_resource_extraction_in_python_jobs(service_backend: ServiceBackend): + b = batch(service_backend, default_python_image=PYTHON_DILL_IMAGE) + + def write(paths): + if not isinstance(paths, tuple): + raise ValueError('paths must be a tuple') + for i, path in enumerate(paths): + with open(path, 'w') as f: + f.write(f'{i}') + + head = b.new_python_job() + head.call(write, (head.ofile1, head.ofile2)) + + tail = b.new_bash_job() + tail.command(f'cat {head.ofile1}') + tail.command(f'cat {head.ofile2}') + + res = b.run() + assert res + assert tail._job_id + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(tail._job_id)['main'] == '01', str(res.debug_info()) + + +def test_list_recursive_resource_extraction_in_python_jobs(service_backend: ServiceBackend): + b = batch(service_backend, default_python_image=PYTHON_DILL_IMAGE) + + def write(paths): + for i, path in enumerate(paths): + with open(path, 'w') as f: + f.write(f'{i}') + + head = b.new_python_job() + head.call(write, [head.ofile1, head.ofile2]) + + tail = b.new_bash_job() + tail.command(f'cat {head.ofile1}') + tail.command(f'cat {head.ofile2}') + + res = b.run() + assert res + assert tail._job_id + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(tail._job_id)['main'] == '01', str(res.debug_info()) + + +def test_dict_recursive_resource_extraction_in_python_jobs(service_backend: ServiceBackend): + b = batch(service_backend, default_python_image=PYTHON_DILL_IMAGE) + + def write(kwargs): + for k, v in kwargs.items(): + with open(v, 'w') as f: + f.write(k) + + head = b.new_python_job() + head.call(write, {'a': head.ofile1, 'b': head.ofile2}) + + tail = b.new_bash_job() + tail.command(f'cat {head.ofile1}') + tail.command(f'cat {head.ofile2}') + + res = b.run() + assert res + assert tail._job_id + res_status = res.status() + assert res_status['state'] == 'success', str((res_status, res.debug_info())) + assert res.get_job_log(tail._job_id)['main'] == 'ab', str(res.debug_info()) diff --git a/hail/python/test/hailtop/batch/utils.py b/hail/python/test/hailtop/batch/utils.py new file mode 100644 index 00000000000..4915962f9c5 --- /dev/null +++ b/hail/python/test/hailtop/batch/utils.py @@ -0,0 +1,21 @@ +import inspect +import os + +from hailtop import pip_version +from hailtop.batch import Batch + +DOCKER_ROOT_IMAGE = os.environ.get('DOCKER_ROOT_IMAGE', 'ubuntu:22.04') +PYTHON_DILL_IMAGE = 'hailgenetics/python-dill:3.9-slim' +HAIL_GENETICS_HAIL_IMAGE = os.environ.get('HAIL_GENETICS_HAIL_IMAGE', f'hailgenetics/hail:{pip_version()}') +REQUESTER_PAYS_PROJECT = os.environ.get('GCS_REQUESTER_PAYS_PROJECT') + + +def batch(backend, **kwargs): + name_of_test_method = inspect.stack()[1][3] + return Batch( + name=name_of_test_method, + backend=backend, + default_image=DOCKER_ROOT_IMAGE, + attributes={'foo': 'a', 'bar': 'b'}, + **kwargs, + ) diff --git a/hail/python/test/hailtop/config/test_deploy_config.py b/hail/python/test/hailtop/config/test_deploy_config.py index 8f9c50db819..7ccf118bb69 100644 --- a/hail/python/test/hailtop/config/test_deploy_config.py +++ 
b/hail/python/test/hailtop/config/test_deploy_config.py @@ -1,9 +1,11 @@ import unittest + from hailtop.config.deploy_config import DeployConfig + class Test(unittest.TestCase): def test_deploy_external_default(self): - deploy_config = DeployConfig('external', 'default', 'organization.tld') + deploy_config = DeployConfig('external', 'default', 'organization.tld', None) self.assertEqual(deploy_config.location(), 'external') self.assertEqual(deploy_config.default_namespace(), 'default') @@ -17,7 +19,7 @@ def test_deploy_external_default(self): self.assertEqual(deploy_config.external_url('quam', '/moo'), 'https://quam.organization.tld/moo') def test_deploy_external_bar(self): - deploy_config = DeployConfig('external', 'bar', 'organization.tld') + deploy_config = DeployConfig('external', 'bar', 'internal.organization.tld', '/bar') self.assertEqual(deploy_config.location(), 'external') self.assertEqual(deploy_config.default_namespace(), 'bar') @@ -30,7 +32,7 @@ def test_deploy_external_bar(self): self.assertEqual(deploy_config.external_url('foo', '/moo'), 'https://internal.organization.tld/bar/foo/moo') def test_deploy_k8s_default(self): - deploy_config = DeployConfig('k8s', 'default', 'organization.tld') + deploy_config = DeployConfig('k8s', 'default', 'organization.tld', None) self.assertEqual(deploy_config.location(), 'k8s') self.assertEqual(deploy_config.default_namespace(), 'default') @@ -44,7 +46,7 @@ def test_deploy_k8s_default(self): self.assertEqual(deploy_config.external_url('quam', '/moo'), 'https://quam.organization.tld/moo') def test_deploy_k8s_bar(self): - deploy_config = DeployConfig('k8s', 'bar', 'organization.tld') + deploy_config = DeployConfig('k8s', 'bar', 'internal.organization.tld', '/bar') self.assertEqual(deploy_config.location(), 'k8s') self.assertEqual(deploy_config.default_namespace(), 'bar') @@ -57,7 +59,7 @@ def test_deploy_k8s_bar(self): self.assertEqual(deploy_config.external_url('foo', '/moo'), 'https://internal.organization.tld/bar/foo/moo') def test_deploy_batch_job_default(self): - deploy_config = DeployConfig('gce', 'default', 'organization.tld') + deploy_config = DeployConfig('gce', 'default', 'organization.tld', None) self.assertEqual(deploy_config.location(), 'gce') self.assertEqual(deploy_config.default_namespace(), 'default') @@ -71,7 +73,7 @@ def test_deploy_batch_job_default(self): self.assertEqual(deploy_config.external_url('quam', '/moo'), 'https://quam.organization.tld/moo') def test_deploy_batch_job_bar(self): - deploy_config = DeployConfig('gce', 'bar', 'organization.tld') + deploy_config = DeployConfig('gce', 'bar', 'internal.organization.tld', '/bar') self.assertEqual(deploy_config.location(), 'gce') self.assertEqual(deploy_config.default_namespace(), 'bar') diff --git a/hail/python/test/hailtop/conftest.py b/hail/python/test/hailtop/conftest.py index 9b612b07d6d..0b618863e42 100644 --- a/hail/python/test/hailtop/conftest.py +++ b/hail/python/test/hailtop/conftest.py @@ -1,10 +1,20 @@ +import asyncio import hashlib import os import pytest -def pytest_collection_modifyitems(config, items): +@pytest.fixture(scope="session") +def event_loop(): + loop = asyncio.get_event_loop() + try: + yield loop + finally: + loop.close() + + +def pytest_collection_modifyitems(items): n_splits = int(os.environ.get('HAIL_RUN_IMAGE_SPLITS', '1')) split_index = int(os.environ.get('HAIL_RUN_IMAGE_SPLIT_INDEX', '-1')) if n_splits <= 1: diff --git a/hail/python/test/hailtop/hailctl/config/conftest.py b/hail/python/test/hailtop/hailctl/config/conftest.py index 
95ca40e0f4d..4f4af976c80 100644 --- a/hail/python/test/hailtop/hailctl/config/conftest.py +++ b/hail/python/test/hailtop/hailctl/config/conftest.py @@ -1,6 +1,6 @@ -import pytest import tempfile +import pytest from typer.testing import CliRunner diff --git a/hail/python/test/hailtop/hailctl/config/test_cli.py b/hail/python/test/hailtop/hailctl/config/test_cli.py index 8cb5af3a8da..3d516fedd26 100644 --- a/hail/python/test/hailtop/hailctl/config/test_cli.py +++ b/hail/python/test/hailtop/hailctl/config/test_cli.py @@ -1,10 +1,10 @@ import os -import pytest +import pytest from typer.testing import CliRunner -from hailtop.config.variables import ConfigVariable from hailtop.config.user_config import get_user_config_path +from hailtop.config.variables import ConfigVariable from hailtop.hailctl.config import cli, config_variables @@ -66,13 +66,13 @@ def test_config_get_unknown_names(runner: CliRunner, config_dir: str): config_path = get_user_config_path(_config_dir=config_dir) os.makedirs(os.path.dirname(config_path)) with open(config_path, 'w', encoding='utf-8') as config: - config.write(f''' + config.write(""" [global] email = johndoe@gmail.com [batch] foo = 5 -''') +""") res = runner.invoke(cli.app, ['get', 'email'], catch_exceptions=False) assert res.exit_code == 0 diff --git a/hail/python/test/hailtop/hailctl/dataproc/conftest.py b/hail/python/test/hailtop/hailctl/dataproc/conftest.py index 519500c7b4c..a18d7aa49fd 100644 --- a/hail/python/test/hailtop/hailctl/dataproc/conftest.py +++ b/hail/python/test/hailtop/hailctl/dataproc/conftest.py @@ -25,7 +25,9 @@ def gcloud_run(): def patch_gcloud(monkeypatch, gcloud_run, gcloud_config): """Automatically replace gcloud functions with mocks.""" monkeypatch.setattr("hailtop.hailctl.dataproc.gcloud.run", gcloud_run) - monkeypatch.setattr("hailtop.hailctl.dataproc.gcloud.get_version", Mock(return_value=MINIMUM_REQUIRED_GCLOUD_VERSION)) + monkeypatch.setattr( + "hailtop.hailctl.dataproc.gcloud.get_version", Mock(return_value=MINIMUM_REQUIRED_GCLOUD_VERSION) + ) def mock_gcloud_get_config(setting): return gcloud_config.get(setting, None) diff --git a/hail/python/test/hailtop/hailctl/dataproc/test_cli.py b/hail/python/test/hailtop/hailctl/dataproc/test_cli.py index 39497ca486e..6ca6c88302a 100644 --- a/hail/python/test/hailtop/hailctl/dataproc/test_cli.py +++ b/hail/python/test/hailtop/hailctl/dataproc/test_cli.py @@ -1,14 +1,16 @@ from unittest.mock import Mock + from typer.testing import CliRunner from hailtop.hailctl.dataproc import cli - runner = CliRunner(mix_stderr=False) def test_required_gcloud_version_met(gcloud_run, monkeypatch): - monkeypatch.setattr("hailtop.hailctl.dataproc.gcloud.get_version", Mock(return_value=cli.MINIMUM_REQUIRED_GCLOUD_VERSION)) + monkeypatch.setattr( + "hailtop.hailctl.dataproc.gcloud.get_version", Mock(return_value=cli.MINIMUM_REQUIRED_GCLOUD_VERSION) + ) runner.invoke(cli.app, ['list']) assert gcloud_run.call_count == 1 diff --git a/hail/python/test/hailtop/hailctl/dataproc/test_connect.py b/hail/python/test/hailtop/hailctl/dataproc/test_connect.py index 36212906ee8..831aa8e1b86 100644 --- a/hail/python/test/hailtop/hailctl/dataproc/test_connect.py +++ b/hail/python/test/hailtop/hailctl/dataproc/test_connect.py @@ -1,11 +1,10 @@ from unittest.mock import Mock -from typer.testing import CliRunner import pytest +from typer.testing import CliRunner from hailtop.hailctl.dataproc import cli - runner = CliRunner(mix_stderr=False) @@ -18,10 +17,7 @@ def subprocess(): def patch_subprocess(monkeypatch, subprocess): 
"""Automatically mock subprocess module.""" monkeypatch.setattr("hailtop.hailctl.dataproc.connect.subprocess", subprocess) - monkeypatch.setattr( - "hailtop.hailctl.dataproc.connect.get_chrome_path", - Mock(return_value="chromium") - ) + monkeypatch.setattr("hailtop.hailctl.dataproc.connect.get_chrome_path", Mock(return_value="chromium")) yield monkeypatch.undo() @@ -48,7 +44,7 @@ def test_connect(gcloud_run, subprocess): gcloud_args = gcloud_run.call_args[0][0] assert gcloud_args[:2] == ["compute", "ssh"] - assert gcloud_args[2][(gcloud_args[2].find("@") + 1):] == "test-cluster-m" + assert gcloud_args[2][(gcloud_args[2].find("@") + 1) :] == "test-cluster-m" assert "--ssh-flag=-D 10000" in gcloud_args assert "--ssh-flag=-N" in gcloud_args @@ -63,14 +59,17 @@ def test_connect(gcloud_run, subprocess): assert any(arg.startswith("--user-data-dir=") for arg in popen_args) -@pytest.mark.parametrize("service,expected_port_and_path", [ - ("spark-ui", "18080/?showIncomplete=true"), - ("ui", "18080/?showIncomplete=true"), - ("spark-history", "18080"), - ("hist", "18080"), - ("notebook", "8123"), - ("nb", "8123"), -]) +@pytest.mark.parametrize( + "service,expected_port_and_path", + [ + ("spark-ui", "18080/?showIncomplete=true"), + ("ui", "18080/?showIncomplete=true"), + ("spark-history", "18080"), + ("hist", "18080"), + ("notebook", "8123"), + ("nb", "8123"), + ], +) def test_service_port_and_path(subprocess, service, expected_port_and_path): runner.invoke(cli.app, ['connect', 'test-cluster', service]) @@ -80,8 +79,7 @@ def test_service_port_and_path(subprocess, service, expected_port_and_path): def test_hailctl_chrome(subprocess, monkeypatch): monkeypatch.setattr( - "hailtop.hailctl.dataproc.connect.get_chrome_path", - Mock(side_effect=Exception("Unable to find chrome")) + "hailtop.hailctl.dataproc.connect.get_chrome_path", Mock(side_effect=Exception("Unable to find chrome")) ) monkeypatch.setenv("HAILCTL_CHROME", "/path/to/chrome.exe") diff --git a/hail/python/test/hailtop/hailctl/dataproc/test_list_clusters.py b/hail/python/test/hailtop/hailctl/dataproc/test_list_clusters.py index 0b1fc333066..66574b53cfd 100644 --- a/hail/python/test/hailtop/hailctl/dataproc/test_list_clusters.py +++ b/hail/python/test/hailtop/hailctl/dataproc/test_list_clusters.py @@ -2,7 +2,6 @@ from hailtop.hailctl.dataproc import cli - runner = CliRunner(mix_stderr=False) diff --git a/hail/python/test/hailtop/hailctl/dataproc/test_modify.py b/hail/python/test/hailtop/hailctl/dataproc/test_modify.py index 7468f5c03e9..e9ad008de52 100644 --- a/hail/python/test/hailtop/hailctl/dataproc/test_modify.py +++ b/hail/python/test/hailtop/hailctl/dataproc/test_modify.py @@ -3,7 +3,6 @@ from hailtop.hailctl.dataproc import cli - runner = CliRunner(mix_stderr=False) @@ -38,22 +37,28 @@ def test_modify_dry_run(gcloud_run): assert gcloud_run.call_count == 0 -@pytest.mark.parametrize("workers_arg", [ - "--num-workers=2", - "--n-workers=2", - "-w2", -]) +@pytest.mark.parametrize( + "workers_arg", + [ + "--num-workers=2", + "--n-workers=2", + "-w2", + ], +) def test_modify_workers(gcloud_run, workers_arg): runner.invoke(cli.app, ['modify', 'test-cluster', workers_arg]) assert "--num-workers=2" in gcloud_run.call_args[0][0] -@pytest.mark.parametrize("workers_arg", [ - "--num-secondary-workers=2", - "--num-preemptible-workers=2", - "--n-pre-workers=2", - "-p2", -]) +@pytest.mark.parametrize( + "workers_arg", + [ + "--num-secondary-workers=2", + "--num-preemptible-workers=2", + "--n-pre-workers=2", + "-p2", + ], +) def 
test_modify_secondary_workers(gcloud_run, workers_arg): runner.invoke(cli.app, ['modify', 'test-cluster', workers_arg]) assert "--num-secondary-workers=2" in gcloud_run.call_args[0][0] @@ -64,10 +69,13 @@ def test_modify_max_idle(gcloud_run): assert "--max-idle=1h" in gcloud_run.call_args[0][0] -@pytest.mark.parametrize("workers_arg", [ - "--num-workers=2", - "--num-secondary-workers=2", -]) +@pytest.mark.parametrize( + "workers_arg", + [ + "--num-workers=2", + "--num-secondary-workers=2", + ], +) def test_graceful_decommission_timeout(gcloud_run, workers_arg): runner.invoke(cli.app, ['modify', 'test-cluster', workers_arg, '--graceful-decommission-timeout=1h']) assert workers_arg in gcloud_run.call_args[0][0] @@ -87,13 +95,15 @@ def test_modify_wheel_remote_wheel(gcloud_run): assert gcloud_args[:3] == ["compute", "ssh", "test-cluster-m"] remote_command = gcloud_args[gcloud_args.index("--") + 1] - assert remote_command == ("sudo gsutil cp gs://some-bucket/hail.whl /tmp/ && " + - "sudo /opt/conda/default/bin/pip uninstall -y hail && " + - "sudo /opt/conda/default/bin/pip install --no-dependencies /tmp/hail.whl && " + - "unzip /tmp/hail.whl && " + - "requirements_file=$(mktemp) && " + - "grep 'Requires-Dist: ' hail*dist-info/METADATA | sed 's/Requires-Dist: //' | sed 's/ (//' | sed 's/)//' | grep -v 'pyspark' >$requirements_file &&" + - "/opt/conda/default/bin/pip install -r $requirements_file") + assert remote_command == ( + "sudo gsutil cp gs://some-bucket/hail.whl /tmp/ && " + + "sudo /opt/conda/default/bin/pip uninstall -y hail && " + + "sudo /opt/conda/default/bin/pip install --no-dependencies /tmp/hail.whl && " + + "unzip /tmp/hail.whl && " + + "requirements_file=$(mktemp) && " + + "grep 'Requires-Dist: ' hail*dist-info/METADATA | sed 's/Requires-Dist: //' | sed 's/ (//' | sed 's/)//' | grep -v 'pyspark' >$requirements_file &&" + + "/opt/conda/default/bin/pip install -r $requirements_file" + ) def test_modify_wheel_local_wheel(gcloud_run): @@ -108,18 +118,23 @@ def test_modify_wheel_local_wheel(gcloud_run): assert install_gcloud_args[:3] == ["compute", "ssh", "test-cluster-m"] remote_command = install_gcloud_args[install_gcloud_args.index("--") + 1] - assert remote_command == ("sudo /opt/conda/default/bin/pip uninstall -y hail && " + - "sudo /opt/conda/default/bin/pip install --no-dependencies /tmp/local-hail.whl && " + - "unzip /tmp/local-hail.whl && " + - "requirements_file=$(mktemp) && " + - "grep 'Requires-Dist: ' hail*dist-info/METADATA | sed 's/Requires-Dist: //' | sed 's/ (//' | sed 's/)//' | grep -v 'pyspark' >$requirements_file &&" + - "/opt/conda/default/bin/pip install -r $requirements_file") - - -@pytest.mark.parametrize("wheel_arg", [ - "--wheel=gs://some-bucket/hail.whl", - "--wheel=./hail.whl", -]) + assert remote_command == ( + "sudo /opt/conda/default/bin/pip uninstall -y hail && " + + "sudo /opt/conda/default/bin/pip install --no-dependencies /tmp/local-hail.whl && " + + "unzip /tmp/local-hail.whl && " + + "requirements_file=$(mktemp) && " + + "grep 'Requires-Dist: ' hail*dist-info/METADATA | sed 's/Requires-Dist: //' | sed 's/ (//' | sed 's/)//' | grep -v 'pyspark' >$requirements_file &&" + + "/opt/conda/default/bin/pip install -r $requirements_file" + ) + + +@pytest.mark.parametrize( + "wheel_arg", + [ + "--wheel=gs://some-bucket/hail.whl", + "--wheel=./hail.whl", + ], +) def test_modify_wheel_zone(gcloud_run, gcloud_config, wheel_arg): gcloud_config["compute/zone"] = "us-central1-b" @@ -128,10 +143,13 @@ def test_modify_wheel_zone(gcloud_run, gcloud_config, 
wheel_arg): assert "--zone=us-east1-d" in call_args[0][0] -@pytest.mark.parametrize("wheel_arg", [ - "--wheel=gs://some-bucket/hail.whl", - "--wheel=./hail.whl", -]) +@pytest.mark.parametrize( + "wheel_arg", + [ + "--wheel=gs://some-bucket/hail.whl", + "--wheel=./hail.whl", + ], +) def test_modify_wheel_default_zone(gcloud_run, gcloud_config, wheel_arg): gcloud_config["compute/zone"] = "us-central1-b" @@ -140,10 +158,13 @@ def test_modify_wheel_default_zone(gcloud_run, gcloud_config, wheel_arg): assert "--zone=us-central1-b" in call_args[0][0] -@pytest.mark.parametrize("wheel_arg", [ - "--wheel=gs://some-bucket/hail.whl", - "--wheel=./hail.whl", -]) +@pytest.mark.parametrize( + "wheel_arg", + [ + "--wheel=gs://some-bucket/hail.whl", + "--wheel=./hail.whl", + ], +) def test_modify_wheel_zone_required(gcloud_run, gcloud_config, wheel_arg): gcloud_config["compute/zone"] = None @@ -152,10 +173,13 @@ def test_modify_wheel_zone_required(gcloud_run, gcloud_config, wheel_arg): assert gcloud_run.call_count == 0 -@pytest.mark.parametrize("wheel_arg", [ - "--wheel=gs://some-bucket/hail.whl", - "--wheel=./hail.whl", -]) +@pytest.mark.parametrize( + "wheel_arg", + [ + "--wheel=gs://some-bucket/hail.whl", + "--wheel=./hail.whl", + ], +) def test_modify_wheel_dry_run(gcloud_run, wheel_arg): runner.invoke(cli.app, ['modify', 'test-cluster', wheel_arg, '--dry-run']) assert gcloud_run.call_count == 0 @@ -179,11 +203,11 @@ def test_update_hail_version(gcloud_run, monkeypatch, deploy_metadata): remote_command = gcloud_args[gcloud_args.index("--") + 1] assert remote_command == ( - "sudo gsutil cp gs://hail-common/hailctl/dataproc/test-version/hail-test-version-py3-none-any.whl /tmp/ && " + - "sudo /opt/conda/default/bin/pip uninstall -y hail && " + - "sudo /opt/conda/default/bin/pip install --no-dependencies /tmp/hail-test-version-py3-none-any.whl && " + - "unzip /tmp/hail-test-version-py3-none-any.whl && " + - "requirements_file=$(mktemp) && " + - "grep 'Requires-Dist: ' hail*dist-info/METADATA | sed 's/Requires-Dist: //' | sed 's/ (//' | sed 's/)//' | grep -v 'pyspark' >$requirements_file &&" + - "/opt/conda/default/bin/pip install -r $requirements_file" + "sudo gsutil cp gs://hail-common/hailctl/dataproc/test-version/hail-test-version-py3-none-any.whl /tmp/ && " + + "sudo /opt/conda/default/bin/pip uninstall -y hail && " + + "sudo /opt/conda/default/bin/pip install --no-dependencies /tmp/hail-test-version-py3-none-any.whl && " + + "unzip /tmp/hail-test-version-py3-none-any.whl && " + + "requirements_file=$(mktemp) && " + + "grep 'Requires-Dist: ' hail*dist-info/METADATA | sed 's/Requires-Dist: //' | sed 's/ (//' | sed 's/)//' | grep -v 'pyspark' >$requirements_file &&" + + "/opt/conda/default/bin/pip install -r $requirements_file" ) diff --git a/hail/python/test/hailtop/hailctl/dataproc/test_start.py b/hail/python/test/hailtop/hailctl/dataproc/test_start.py index 2075879ef41..1b3c3e5a762 100644 --- a/hail/python/test/hailtop/hailctl/dataproc/test_start.py +++ b/hail/python/test/hailtop/hailctl/dataproc/test_start.py @@ -3,7 +3,6 @@ from hailtop.hailctl.dataproc import cli - runner = CliRunner(mix_stderr=False) @@ -24,10 +23,13 @@ def test_cluster_project(gcloud_run): assert "--project=foo" in gcloud_run.call_args[0][0] -@pytest.mark.parametrize("location_arg", [ - "--region=europe-north1", - "--zone=us-central1-b", -]) +@pytest.mark.parametrize( + "location_arg", + [ + "--region=europe-north1", + "--zone=us-central1-b", + ], +) def test_cluster_location(gcloud_run, location_arg): runner.invoke(cli.app, 
['start', location_arg, 'test-cluster']) assert location_arg in gcloud_run.call_args[0][0] @@ -48,29 +50,28 @@ def test_workers_configuration(gcloud_run): assert "--num-workers=4" in gcloud_run.call_args[0][0] -@pytest.mark.parametrize("workers_arg", [ - "--num-secondary-workers=8", - "--num-preemptible-workers=8" -]) +@pytest.mark.parametrize("workers_arg", ["--num-secondary-workers=8", "--num-preemptible-workers=8"]) def test_secondary_workers_configuration(gcloud_run, workers_arg): runner.invoke(cli.app, ['start', workers_arg, 'test-cluster']) assert "--num-secondary-workers=8" in gcloud_run.call_args[0][0] -@pytest.mark.parametrize("machine_arg", [ - "--master-machine-type=n1-highmem-16", - "--worker-machine-type=n1-standard-32", -]) +@pytest.mark.parametrize( + "machine_arg", + [ + "--master-machine-type=n1-highmem-16", + "--worker-machine-type=n1-standard-32", + ], +) def test_machine_type_configuration(gcloud_run, machine_arg): runner.invoke(cli.app, ['start', machine_arg, 'test-cluster']) assert machine_arg in gcloud_run.call_args[0][0] -@pytest.mark.parametrize("machine_arg", [ - "--master-boot-disk-size=250", - "--worker-boot-disk-size=200", - "--secondary-worker-boot-disk-size=100" -]) +@pytest.mark.parametrize( + "machine_arg", + ["--master-boot-disk-size=250", "--worker-boot-disk-size=200", "--secondary-worker-boot-disk-size=100"], +) def test_boot_disk_size_configuration(gcloud_run, machine_arg): runner.invoke(cli.app, ['start', machine_arg, 'test-cluster']) assert f"{machine_arg}GB" in gcloud_run.call_args[0][0] @@ -87,11 +88,14 @@ def test_vep_defaults_to_larger_worker_boot_disk(gcloud_run): assert "--secondary-worker-boot-disk-size=200GB" in gcloud_run.call_args[0][0] -@pytest.mark.parametrize("requester_pays_arg", [ - "--requester-pays-allow-all", - "--requester-pays-allow-buckets=example-bucket", - "--requester-pays-allow-annotation-db", -]) +@pytest.mark.parametrize( + "requester_pays_arg", + [ + "--requester-pays-allow-all", + "--requester-pays-allow-buckets=example-bucket", + "--requester-pays-allow-annotation-db", + ], +) def test_requester_pays_project_configuration(gcloud_run, gcloud_config, requester_pays_arg): gcloud_config["project"] = "foo-project" @@ -104,11 +108,14 @@ def test_requester_pays_project_configuration(gcloud_run, gcloud_config, request assert "spark:spark.hadoop.fs.gs.requester.pays.project.id=bar-project" in properties -@pytest.mark.parametrize("requester_pays_arg,expected_mode", [ - ("--requester-pays-allow-all", "AUTO"), - ("--requester-pays-allow-buckets=example-bucket", "CUSTOM"), - ("--requester-pays-allow-annotation-db", "CUSTOM"), -]) +@pytest.mark.parametrize( + "requester_pays_arg,expected_mode", + [ + ("--requester-pays-allow-all", "AUTO"), + ("--requester-pays-allow-buckets=example-bucket", "CUSTOM"), + ("--requester-pays-allow-annotation-db", "CUSTOM"), + ], +) def test_requester_pays_mode_configuration(gcloud_run, requester_pays_arg, expected_mode): runner.invoke(cli.app, ['start', 'test-cluster', requester_pays_arg]) properties = next(arg for arg in gcloud_run.call_args[0][0] if arg.startswith("--properties=")) @@ -118,13 +125,16 @@ def test_requester_pays_mode_configuration(gcloud_run, requester_pays_arg, expec def test_requester_pays_buckets_configuration(gcloud_run): runner.invoke(cli.app, ['start', 'test-cluster', '--requester-pays-allow-buckets=foo,bar']) properties = next(arg for arg in gcloud_run.call_args[0][0] if arg.startswith("--properties=")) - assert f"spark:spark.hadoop.fs.gs.requester.pays.buckets=foo,bar" in 
properties + assert "spark:spark.hadoop.fs.gs.requester.pays.buckets=foo,bar" in properties -@pytest.mark.parametrize("scheduled_deletion_arg", [ - "--max-idle=30m", - "--max-age=1h", -]) +@pytest.mark.parametrize( + "scheduled_deletion_arg", + [ + "--max-idle=30m", + "--max-age=1h", + ], +) def test_scheduled_deletion_configuration(gcloud_run, scheduled_deletion_arg): runner.invoke(cli.app, ['start', scheduled_deletion_arg, 'test-cluster']) assert scheduled_deletion_arg in gcloud_run.call_args[0][0] diff --git a/hail/python/test/hailtop/hailctl/dataproc/test_stop.py b/hail/python/test/hailtop/hailctl/dataproc/test_stop.py index cb26bc79df0..f6fa9ccc721 100644 --- a/hail/python/test/hailtop/hailctl/dataproc/test_stop.py +++ b/hail/python/test/hailtop/hailctl/dataproc/test_stop.py @@ -2,7 +2,6 @@ from hailtop.hailctl.dataproc import cli - runner = CliRunner(mix_stderr=False) diff --git a/hail/python/test/hailtop/hailctl/dataproc/test_submit.py b/hail/python/test/hailtop/hailctl/dataproc/test_submit.py index 1cf45a5382d..6000f769f21 100644 --- a/hail/python/test/hailtop/hailctl/dataproc/test_submit.py +++ b/hail/python/test/hailtop/hailctl/dataproc/test_submit.py @@ -2,7 +2,6 @@ from hailtop.hailctl.dataproc import cli - runner = CliRunner(mix_stderr=False) @@ -32,7 +31,7 @@ def test_dry_run(gcloud_run): def test_script_args(gcloud_run): runner.invoke(cli.app, ['submit', 'test-cluster', 'a-script.py', '--foo', 'bar']) gcloud_args = gcloud_run.call_args[0][0] - job_args = gcloud_args[gcloud_args.index("--") + 1:] + job_args = gcloud_args[gcloud_args.index("--") + 1 :] assert job_args == ["--foo", "bar"] diff --git a/hail/python/test/hailtop/hailctl/dev/conftest.py b/hail/python/test/hailtop/hailctl/dev/conftest.py new file mode 100644 index 00000000000..caec59bfea8 --- /dev/null +++ b/hail/python/test/hailtop/hailctl/dev/conftest.py @@ -0,0 +1,15 @@ +import tempfile + +import pytest +from typer.testing import CliRunner + + +@pytest.fixture() +def deploy_config_file(): + with tempfile.NamedTemporaryFile() as f: + yield f.name + + +@pytest.fixture +def runner(deploy_config_file): + yield CliRunner(mix_stderr=False, env={'HAIL_DEPLOY_CONFIG_FILE': deploy_config_file}) diff --git a/hail/python/test/hailtop/hailctl/dev/test_config.py b/hail/python/test/hailtop/hailctl/dev/test_config.py new file mode 100644 index 00000000000..a743cf0171a --- /dev/null +++ b/hail/python/test/hailtop/hailctl/dev/test_config.py @@ -0,0 +1,44 @@ +import orjson +from typer.testing import CliRunner + +from hailtop.hailctl.dev import config as cli + +default_config = { + 'domain': 'example.com', + 'location': 'external', + 'default_namespace': 'default', +} + + +def set_deploy_config(deploy_config_file: str, config: dict): + with open(deploy_config_file, 'w', encoding='utf-8') as f: + f.write(orjson.dumps(config).decode('utf-8')) + + +def load_deploy_config_dict(deploy_config_file: str) -> dict: + with open(deploy_config_file, 'r', encoding='utf-8') as f: + return orjson.loads(f.read()) + + +def test_dev_config_set(runner: CliRunner, deploy_config_file: str): + set_deploy_config(deploy_config_file, default_config) + + res = runner.invoke(cli.app, ['set', 'default_namespace', 'foo']) + assert res.exit_code == 0, res.stderr + + expected = {'domain': 'example.com', 'location': 'external', 'default_namespace': 'foo'} + assert load_deploy_config_dict(deploy_config_file) == expected + + +def test_dev_config_set_not_affected_by_env_vars(runner: CliRunner, deploy_config_file: str): + set_deploy_config(deploy_config_file, 
default_config) + + res = runner.invoke( + cli.app, + ['set', 'default_namespace', 'foo'], + env={'HAIL_DOMAIN': 'foo.example.com'}, + ) + assert res.exit_code == 0 + + expected = {'domain': 'example.com', 'location': 'external', 'default_namespace': 'foo'} + assert load_deploy_config_dict(deploy_config_file) == expected diff --git a/hail/python/test/hailtop/inter_cloud/conftest.py b/hail/python/test/hailtop/inter_cloud/conftest.py new file mode 100644 index 00000000000..841650d9b6e --- /dev/null +++ b/hail/python/test/hailtop/inter_cloud/conftest.py @@ -0,0 +1,47 @@ +import asyncio +import functools +import os +import secrets +from typing import AsyncIterator, Dict, Tuple + +import pytest + +from hailtop.aiotools.router_fs import AsyncFS, RouterAsyncFS +from hailtop.utils import bounded_gather2 + + +@pytest.fixture(scope='module') +async def router_filesystem() -> AsyncIterator[Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]]: + token = secrets.token_hex(16) + + async with RouterAsyncFS() as fs: + file_base = f'/tmp/{token}/' + await fs.mkdir(file_base) + + gs_bucket = os.environ['HAIL_TEST_GCS_BUCKET'] + gs_base = f'gs://{gs_bucket}/tmp/{token}/' + + s3_bucket = os.environ['HAIL_TEST_S3_BUCKET'] + s3_base = f's3://{s3_bucket}/tmp/{token}/' + + azure_account = os.environ['HAIL_TEST_AZURE_ACCOUNT'] + azure_container = os.environ['HAIL_TEST_AZURE_CONTAINER'] + azure_base = f'https://{azure_account}.blob.core.windows.net/{azure_container}/tmp/{token}/' + + bases = {'file': file_base, 'gs': gs_base, 's3': s3_base, 'azure-https': azure_base} + + sema = asyncio.Semaphore(50) + async with sema: + yield (sema, fs, bases) + await bounded_gather2( + sema, + functools.partial(fs.rmtree, sema, file_base), + functools.partial(fs.rmtree, sema, gs_base), + functools.partial(fs.rmtree, sema, s3_base), + functools.partial(fs.rmtree, sema, azure_base), + ) + + assert not await fs.isdir(file_base) + assert not await fs.isdir(gs_base) + assert not await fs.isdir(s3_base) + assert not await fs.isdir(azure_base) diff --git a/hail/python/test/hailtop/inter_cloud/copy_test_specs.py b/hail/python/test/hailtop/inter_cloud/copy_test_specs.py index 4606cb095df..00a8314933f 100644 --- a/hail/python/test/hailtop/inter_cloud/copy_test_specs.py +++ b/hail/python/test/hailtop/inter_cloud/copy_test_specs.py @@ -1,2547 +1,3309 @@ -COPY_TEST_SPECS = [{'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'files': {'/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 
'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'files': {'/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 
'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', '/keep': '', '/x/a': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', '/keep': '', '/x/a': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', '/keep': '', '/x': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', '/keep': '', '/x/a': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', '/keep': '', '/x': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 
'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/a': 'src/a', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/a': 'src/a', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 
'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/a': 'src/a', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/a': 'src/a', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/a': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/a': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': 
True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/a': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, 
- 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a': 'src/a', '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', '/x/a': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', '/x/a': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 
'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'IsADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', '/x': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', '/x/a': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', '/x': 'src/a'}}, - 'src_trailing_slash': False, - 'src_type': 'file', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/file1': 'src/a/file1', - '/keep': '', - '/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/file1': 'src/a/file1', - '/keep': '', - '/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/file1': 'src/a/file1', - '/keep': '', - '/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 
None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/file1': 'src/a/file1', - '/keep': '', - '/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 
'dest_type': 'file', - 'result': {'exception': 'NotADirectoryError'}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'files': {'/a': 'dest/a', - '/keep': '', - 
'/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/file1': 'src/a/file1', - '/keep': '', - '/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/file1': 'src/a/file1', - '/keep': '', - '/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/file1': 'src/a/file1', - '/keep': '', - '/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/file1': 'src/a/file1', - '/keep': '', - '/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 
'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - '/a/a/subdir/file2': 'src/a/subdir/file2', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - '/a/a/subdir/file2': 'src/a/subdir/file2', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - '/a/a/subdir/file2': 'src/a/subdir/file2', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - '/a/a/subdir/file2': 'src/a/subdir/file2', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - '/a/a/subdir/file2': 'src/a/subdir/file2', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - 
'/a/a/subdir/file2': 'src/a/subdir/file2', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - '/a/a/subdir/file2': 'src/a/subdir/file2', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - '/a/a/subdir/file2': 'src/a/subdir/file2', - '/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - 
'/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'files': {'/a/file3': 'dest/a/file3', - '/a/subdir/file2': 'dest/a/subdir/file2', - '/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/file1': 'src/a/file1', - '/keep': '', - '/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/file1': 'src/a/file1', - '/keep': '', - '/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/file1': 'src/a/file1', - '/keep': '', - '/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 
'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/file1': 'src/a/file1', - '/keep': '', - '/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - '/a/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - '/a/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - '/a/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - '/a/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': 
False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - '/a/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/a/a/file1': 'src/a/file1', - '/a/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/a/file1': 'src/a/file1', - '/a/subdir/file2': 'src/a/subdir/file2', - '/keep': ''}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', - 
'/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': True, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', - '/x/a/file1': 'src/a/file1', - '/x/a/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'files': {'/keep': '', - '/x/file1': 'src/a/file1', - '/x/subdir/file2': 'src/a/subdir/file2'}}, - 'src_trailing_slash': False, - 'src_type': 'dir', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 
'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 
'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'file', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 
'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 
'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'dir', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 
'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': None, - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 
'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'a', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_dir'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'dest_is_target'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': True, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': True, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}, - {'dest_basename': 'x', - 'dest_trailing_slash': False, - 'dest_type': 'noexist', - 'result': {'exception': 'FileNotFoundError'}, - 'src_trailing_slash': False, - 'src_type': 'noexist', - 'treat_dest_as': 'infer_dest'}] +COPY_TEST_SPECS = [ + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 
'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'files': {'/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'files': {'/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'files': {'/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'files': {'/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, 
+ 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'files': {'/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'files': {'/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'files': {'/a': 'dest/a', '/keep': '', '/x/a': 'src/a'}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'files': {'/a': 'dest/a', '/keep': '', '/x/a': 'src/a'}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'files': {'/a': 'dest/a', '/keep': '', '/x': 'src/a'}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 
'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'files': {'/a': 'dest/a', '/keep': '', '/x/a': 'src/a'}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'files': {'/a': 'dest/a', '/keep': '', '/x': 'src/a'}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 
'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/a': 'src/a', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/a': 'src/a', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/a': 'src/a', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/a': 'src/a', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + 
'/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/a': 'src/a', + } + }, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/a': 'src/a', + } + }, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': {'/a/file3': 'dest/a/file3', '/a/subdir/file2': 'dest/a/subdir/file2', '/keep': '', '/x': 'src/a'} + }, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/a': 'src/a', + } + }, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': {'/a/file3': 'dest/a/file3', '/a/subdir/file2': 'dest/a/subdir/file2', '/keep': '', '/x': 'src/a'} + }, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': 
{'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 
'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a': 'src/a', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/a': 'src/a'}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/a': 'src/a'}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'IsADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x': 'src/a'}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/a': 'src/a'}}, + 'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x': 'src/a'}}, + 
'src_trailing_slash': False, + 'src_type': 'file', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': { + 'files': {'/a': 'dest/a', '/file1': 'src/a/file1', '/keep': '', '/subdir/file2': 'src/a/subdir/file2'} + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': { + 'files': {'/a': 'dest/a', '/file1': 'src/a/file1', '/keep': '', '/subdir/file2': 'src/a/subdir/file2'} + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': { + 'files': {'/a': 'dest/a', '/file1': 'src/a/file1', '/keep': '', '/subdir/file2': 'src/a/subdir/file2'} + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': { + 'files': {'/a': 'dest/a', '/file1': 'src/a/file1', '/keep': '', '/subdir/file2': 'src/a/subdir/file2'} + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': 
True, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'NotADirectoryError'}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': { + 'files': { + '/a': 'dest/a', + '/keep': '', + '/x/a/file1': 'src/a/file1', + '/x/a/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': { + 'files': { + '/a': 'dest/a', + '/keep': '', + '/x/a/file1': 'src/a/file1', + '/x/a/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': { + 'files': { + '/a': 'dest/a', + '/keep': '', + '/x/a/file1': 'src/a/file1', + '/x/a/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': { + 'files': { + '/a': 'dest/a', + '/keep': '', + '/x/a/file1': 'src/a/file1', + '/x/a/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': { + 'files': {'/a': 'dest/a', 
'/keep': '', '/x/file1': 'src/a/file1', '/x/subdir/file2': 'src/a/subdir/file2'} + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': { + 'files': {'/a': 'dest/a', '/keep': '', '/x/file1': 'src/a/file1', '/x/subdir/file2': 'src/a/subdir/file2'} + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': { + 'files': {'/a': 'dest/a', '/keep': '', '/x/file1': 'src/a/file1', '/x/subdir/file2': 'src/a/subdir/file2'} + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': { + 'files': {'/a': 'dest/a', '/keep': '', '/x/file1': 'src/a/file1', '/x/subdir/file2': 'src/a/subdir/file2'} + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': { + 'files': { + '/a': 'dest/a', + '/keep': '', + '/x/a/file1': 'src/a/file1', + '/x/a/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': { + 'files': {'/a': 'dest/a', '/keep': '', '/x/file1': 'src/a/file1', '/x/subdir/file2': 'src/a/subdir/file2'} + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': { + 'files': { + '/a': 'dest/a', + '/keep': '', + '/x/a/file1': 'src/a/file1', + '/x/a/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': { + 'files': {'/a': 'dest/a', '/keep': '', '/x/file1': 'src/a/file1', '/x/subdir/file2': 'src/a/subdir/file2'} + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file1': 'src/a/file1', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'src/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file1': 'src/a/file1', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'src/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file1': 'src/a/file1', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'src/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file1': 'src/a/file1', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'src/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': 
False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/file1': 'src/a/file1', + '/keep': '', + '/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/file1': 'src/a/file1', + '/keep': '', + '/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/file1': 'src/a/file1', + '/keep': '', + '/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/file1': 'src/a/file1', + '/keep': '', + '/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file1': 'src/a/file1', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'src/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file1': 'src/a/file1', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'src/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file1': 'src/a/file1', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'src/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file1': 'src/a/file1', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'src/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/a/file1': 'src/a/file1', + '/a/a/subdir/file2': 'src/a/subdir/file2', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/a/file1': 'src/a/file1', + '/a/a/subdir/file2': 'src/a/subdir/file2', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 
'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/a/file1': 'src/a/file1', + '/a/a/subdir/file2': 'src/a/subdir/file2', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/a/file1': 'src/a/file1', + '/a/a/subdir/file2': 'src/a/subdir/file2', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file1': 'src/a/file1', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'src/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file1': 'src/a/file1', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'src/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file1': 'src/a/file1', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'src/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file1': 'src/a/file1', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'src/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/a/file1': 'src/a/file1', + '/a/a/subdir/file2': 'src/a/subdir/file2', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/a/file1': 'src/a/file1', + '/a/a/subdir/file2': 'src/a/subdir/file2', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/a/file1': 'src/a/file1', + '/a/a/subdir/file2': 'src/a/subdir/file2', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/a/file1': 'src/a/file1', + '/a/a/subdir/file2': 'src/a/subdir/file2', + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 
'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/a/file1': 'src/a/file1', + '/x/a/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/a/file1': 'src/a/file1', + '/x/a/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/a/file1': 'src/a/file1', + '/x/a/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/a/file1': 'src/a/file1', + '/x/a/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/file1': 'src/a/file1', + '/x/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/file1': 'src/a/file1', + '/x/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/file1': 'src/a/file1', + '/x/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/file1': 'src/a/file1', + '/x/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/a/file1': 'src/a/file1', + '/x/a/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 
'dest/a/subdir/file2', + '/keep': '', + '/x/file1': 'src/a/file1', + '/x/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/a/file1': 'src/a/file1', + '/x/a/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': { + 'files': { + '/a/file3': 'dest/a/file3', + '/a/subdir/file2': 'dest/a/subdir/file2', + '/keep': '', + '/x/file1': 'src/a/file1', + '/x/subdir/file2': 'src/a/subdir/file2', + } + }, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/file1': 'src/a/file1', '/keep': '', '/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/file1': 'src/a/file1', '/keep': '', '/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/file1': 'src/a/file1', '/keep': '', '/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/file1': 'src/a/file1', '/keep': '', '/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 
'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a/a/file1': 'src/a/file1', '/a/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a/a/file1': 'src/a/file1', '/a/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a/a/file1': 'src/a/file1', '/a/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a/a/file1': 'src/a/file1', '/a/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a/a/file1': 'src/a/file1', '/a/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': True, + 
'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/a/a/file1': 'src/a/file1', '/a/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/a/file1': 'src/a/file1', '/a/subdir/file2': 'src/a/subdir/file2', '/keep': ''}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/a/file1': 'src/a/file1', '/x/a/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/a/file1': 'src/a/file1', '/x/a/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/a/file1': 'src/a/file1', '/x/a/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/a/file1': 'src/a/file1', '/x/a/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/file1': 'src/a/file1', '/x/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/file1': 'src/a/file1', '/x/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/file1': 'src/a/file1', '/x/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/file1': 'src/a/file1', '/x/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/a/file1': 'src/a/file1', '/x/a/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/file1': 'src/a/file1', '/x/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': True, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 
'result': {'files': {'/keep': '', '/x/a/file1': 'src/a/file1', '/x/a/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'files': {'/keep': '', '/x/file1': 'src/a/file1', '/x/subdir/file2': 'src/a/subdir/file2'}}, + 'src_trailing_slash': False, + 'src_type': 'dir', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 
'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': 
{'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'file', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 
'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 
'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'dir', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 
'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': None, + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 
'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'a', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_dir', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'dest_is_target', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': True, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': True, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 'infer_dest', + }, + { + 'dest_basename': 'x', + 'dest_trailing_slash': False, + 'dest_type': 'noexist', + 'result': {'exception': 'FileNotFoundError'}, + 'src_trailing_slash': False, + 'src_type': 'noexist', + 'treat_dest_as': 
'infer_dest', + }, +] diff --git a/hail/python/test/hailtop/inter_cloud/generate_copy_test_specs.py b/hail/python/test/hailtop/inter_cloud/generate_copy_test_specs.py index 5638372c77b..f9d11fe5d1f 100644 --- a/hail/python/test/hailtop/inter_cloud/generate_copy_test_specs.py +++ b/hail/python/test/hailtop/inter_cloud/generate_copy_test_specs.py @@ -1,8 +1,9 @@ +import asyncio +import pprint import secrets from concurrent.futures import ThreadPoolExecutor -import pprint -import asyncio -from hailtop.aiotools import LocalAsyncFS, Transfer, Copier + +from hailtop.aiotools import Copier, Transfer from hailtop.aiotools.router_fs import RouterAsyncFS @@ -20,7 +21,7 @@ async def create_test_file(fs, name, base, path): async def create_test_dir(fs, name, base, path): - '''Create a directory of test data. + """Create a directory of test data. The directory test data depends on the name (src or dest) so, when testing overwriting for example, there is a file in src which does @@ -34,7 +35,7 @@ async def create_test_dir(fs, name, base, path): The dest configuration looks like: - {base}/dest/a/subdir/file2 - {base}/dest/a/file3 - ''' + """ assert name in ('src', 'dest') assert path.endswith('/') @@ -70,7 +71,7 @@ def copy_test_configurations(): 'dest_basename': dest_basename, 'treat_dest_as': treat_dest_as, 'src_trailing_slash': src_trailing_slash, - 'dest_trailing_slash': dest_trailing_slash + 'dest_trailing_slash': dest_trailing_slash, } @@ -121,7 +122,7 @@ async def copy_test_specs(): test_specs = [] with ThreadPoolExecutor() as thread_pool: - async with RouterAsyncFS(filesystems=[LocalAsyncFS(thread_pool)]) as fs: + async with RouterAsyncFS(local_kwargs={'thread_pool': thread_pool}) as fs: for config in copy_test_configurations(): token = secrets.token_hex(16) @@ -152,7 +153,7 @@ async def copy_test_specs(): async def main(): test_specs = await copy_test_specs() with open('test/hailtop/aiotools/copy_test_specs.py', 'w') as f: - f.write(f'COPY_TEST_SPECS = ') + f.write('COPY_TEST_SPECS = ') pprint.pprint(test_specs, stream=f) diff --git a/hail/python/test/hailtop/inter_cloud/test_copy.py b/hail/python/test/hailtop/inter_cloud/test_copy.py index de4c4f8acfa..baaa4fbb1d3 100644 --- a/hail/python/test/hailtop/inter_cloud/test_copy.py +++ b/hail/python/test/hailtop/inter_cloud/test_copy.py @@ -1,29 +1,15 @@ -from typing import Tuple, Dict, AsyncIterator, List -import os -import secrets -from concurrent.futures import ThreadPoolExecutor import asyncio -import functools -import pytest -from hailtop.utils import url_scheme, bounded_gather2 -from hailtop.aiotools import LocalAsyncFS, Transfer, FileAndDirectoryError, Copier, AsyncFS, FileListEntry -from hailtop.aiotools.router_fs import RouterAsyncFS -from hailtop.aiocloud.aiogoogle import GoogleStorageAsyncFS -from hailtop.aiocloud.aioaws import S3AsyncFS -from hailtop.aiocloud.aioazure import AzureAsyncFS +import secrets +from typing import AsyncIterator, Dict, List, Tuple +import pytest -from .generate_copy_test_specs import run_test_spec, create_test_file, create_test_dir +from hailtop.aiotools import AsyncFS, Copier, FileAndDirectoryError, FileListEntry, Transfer +from hailtop.utils import url_scheme from .copy_test_specs import COPY_TEST_SPECS - - -@pytest.fixture(scope='module') -def event_loop(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - yield loop - loop.close() +from .generate_copy_test_specs import create_test_dir, create_test_file, run_test_spec +from .utils import fresh_dir # This fixture is for test_copy_behavior. 
It runs a series of copy @@ -42,65 +28,27 @@ async def test_spec(request): async def cloud_scheme(request): yield request.param -@pytest.fixture(scope='module') -async def router_filesystem(request) -> AsyncIterator[Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]]: - token = secrets.token_hex(16) - - with ThreadPoolExecutor() as thread_pool: - async with RouterAsyncFS( - filesystems=[ - LocalAsyncFS(thread_pool), - GoogleStorageAsyncFS(), - S3AsyncFS(thread_pool), - AzureAsyncFS() - ] - ) as fs: - file_base = f'/tmp/{token}/' - await fs.mkdir(file_base) - - gs_bucket = os.environ['HAIL_TEST_GCS_BUCKET'] - gs_base = f'gs://{gs_bucket}/tmp/{token}/' - - s3_bucket = os.environ['HAIL_TEST_S3_BUCKET'] - s3_base = f's3://{s3_bucket}/tmp/{token}/' - - azure_account = os.environ['HAIL_TEST_AZURE_ACCOUNT'] - azure_container = os.environ['HAIL_TEST_AZURE_CONTAINER'] - azure_base = f'https://{azure_account}.blob.core.windows.net/{azure_container}/tmp/{token}/' - - bases = { - 'file': file_base, - 'gs': gs_base, - 's3': s3_base, - 'azure-https': azure_base - } - - sema = asyncio.Semaphore(50) - async with sema: - yield (sema, fs, bases) - await bounded_gather2(sema, - functools.partial(fs.rmtree, sema, file_base), - functools.partial(fs.rmtree, sema, gs_base), - functools.partial(fs.rmtree, sema, s3_base), - functools.partial(fs.rmtree, sema, azure_base)) - - assert not await fs.isdir(file_base) - assert not await fs.isdir(gs_base) - assert not await fs.isdir(s3_base) - assert not await fs.isdir(azure_base) - - -async def fresh_dir(fs, bases, scheme): - token = secrets.token_hex(16) - dir = f'{bases[scheme]}{token}/' - await fs.mkdir(dir) - return dir - - -@pytest.fixture(params=['file/file', 'file/gs', 'file/s3', 'file/azure-https', - 'gs/file', 'gs/gs', 'gs/s3', 'gs/azure-https', - 's3/file', 's3/gs', 's3/s3', 's3/azure-https', - 'azure-https/file', 'azure-https/gs', 'azure-https/s3', 'azure-https/azure-https']) + +@pytest.fixture( + params=[ + 'file/file', + 'file/gs', + 'file/s3', + 'file/azure-https', + 'gs/file', + 'gs/gs', + 'gs/s3', + 'gs/azure-https', + 's3/file', + 's3/gs', + 's3/s3', + 's3/azure-https', + 'azure-https/file', + 'azure-https/gs', + 'azure-https/s3', + 'azure-https/azure-https', + ] +) async def copy_test_context(request, router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]): sema, fs, bases = router_filesystem @@ -116,7 +64,6 @@ async def copy_test_context(request, router_filesystem: Tuple[asyncio.Semaphore, yield sema, fs, src_base, dest_base -@pytest.mark.asyncio async def test_copy_behavior(copy_test_context, test_spec): sema, fs, src_base, dest_base = copy_test_context @@ -125,9 +72,11 @@ async def test_copy_behavior(copy_test_context, test_spec): expected = test_spec['result'] dest_scheme = url_scheme(dest_base) - if ((dest_scheme == 'gs' or dest_scheme == 's3' or dest_scheme == 'https') - and (result is not None and 'files' in result) - and expected.get('exception') in ('IsADirectoryError', 'NotADirectoryError')): + if ( + (dest_scheme in {'gs', 's3', 'https'}) + and (result is not None and 'files' in result) + and expected.get('exception') in ('IsADirectoryError', 'NotADirectoryError') + ): return assert result == expected, (test_spec, result, expected) @@ -153,7 +102,7 @@ class RaisedWrongExceptionError(Exception): class RaisesOrObjectStore: def __init__(self, dest_base, expected_type): scheme = url_scheme(dest_base) - self._object_store = (scheme == 'gs' or scheme == 's3' or scheme == 'https') + self._object_store = scheme in {'gs', 's3', 'https'} 
self._expected_type = expected_type def __enter__(self): @@ -171,7 +120,6 @@ def __exit__(self, type, value, traceback): return True -@pytest.mark.asyncio async def test_copy_doesnt_exist(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -179,7 +127,6 @@ async def test_copy_doesnt_exist(copy_test_context): await Copier.copy(fs, sema, Transfer(f'{src_base}a', dest_base)) -@pytest.mark.asyncio async def test_copy_file(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -190,7 +137,6 @@ async def test_copy_file(copy_test_context): await expect_file(fs, f'{dest_base}a', 'src/a') -@pytest.mark.asyncio async def test_copy_large_file(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -206,7 +152,6 @@ async def test_copy_large_file(copy_test_context): assert copy_contents == contents -@pytest.mark.asyncio async def test_copy_rename_file(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -217,7 +162,6 @@ async def test_copy_rename_file(copy_test_context): await expect_file(fs, f'{dest_base}x', 'src/a') -@pytest.mark.asyncio async def test_copy_rename_file_dest_target_file(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -228,7 +172,6 @@ async def test_copy_rename_file_dest_target_file(copy_test_context): await expect_file(fs, f'{dest_base}x', 'src/a') -@pytest.mark.asyncio async def test_copy_file_dest_target_directory_doesnt_exist(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -239,7 +182,6 @@ async def test_copy_file_dest_target_directory_doesnt_exist(copy_test_context): await expect_file(fs, f'{dest_base}x/a', 'src/a') -@pytest.mark.asyncio async def test_overwrite_rename_file(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -251,7 +193,6 @@ async def test_overwrite_rename_file(copy_test_context): await expect_file(fs, f'{dest_base}x', 'src/a') -@pytest.mark.asyncio async def test_copy_rename_dir(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -263,7 +204,6 @@ async def test_copy_rename_dir(copy_test_context): await expect_file(fs, f'{dest_base}x/subdir/file2', 'src/a/subdir/file2') -@pytest.mark.asyncio async def test_copy_rename_dir_dest_is_target(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -275,7 +215,6 @@ async def test_copy_rename_dir_dest_is_target(copy_test_context): await expect_file(fs, f'{dest_base}x/subdir/file2', 'src/a/subdir/file2') -@pytest.mark.asyncio async def test_overwrite_rename_dir(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -289,7 +228,6 @@ async def test_overwrite_rename_dir(copy_test_context): await expect_file(fs, f'{dest_base}x/file3', 'dest/x/file3') -@pytest.mark.asyncio async def test_copy_file_dest_trailing_slash_target_dir(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -300,7 +238,6 @@ async def test_copy_file_dest_trailing_slash_target_dir(copy_test_context): await expect_file(fs, f'{dest_base}a', 'src/a') -@pytest.mark.asyncio async def test_copy_file_dest_target_dir(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -311,7 +248,6 @@ async def test_copy_file_dest_target_dir(copy_test_context): await expect_file(fs, f'{dest_base}a', 'src/a') -@pytest.mark.asyncio async def test_copy_file_dest_target_file(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -322,17 +258,17 @@ async def test_copy_file_dest_target_file(copy_test_context): await 
expect_file(fs, f'{dest_base}a', 'src/a') -@pytest.mark.asyncio async def test_copy_dest_target_file_is_dir(copy_test_context): sema, fs, src_base, dest_base = copy_test_context await create_test_file(fs, 'src', src_base, 'a') with RaisesOrObjectStore(dest_base, IsADirectoryError): - await Copier.copy(fs, sema, Transfer(f'{src_base}a', dest_base.rstrip('/'), treat_dest_as=Transfer.DEST_IS_TARGET)) + await Copier.copy( + fs, sema, Transfer(f'{src_base}a', dest_base.rstrip('/'), treat_dest_as=Transfer.DEST_IS_TARGET) + ) -@pytest.mark.asyncio async def test_overwrite_file(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -344,7 +280,6 @@ async def test_overwrite_file(copy_test_context): await expect_file(fs, f'{dest_base}a', 'src/a') -@pytest.mark.asyncio async def test_copy_file_src_trailing_slash(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -354,7 +289,6 @@ async def test_copy_file_src_trailing_slash(copy_test_context): await Copier.copy(fs, sema, Transfer(f'{src_base}a/', dest_base)) -@pytest.mark.asyncio async def test_copy_dir(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -366,7 +300,6 @@ async def test_copy_dir(copy_test_context): await expect_file(fs, f'{dest_base}a/subdir/file2', 'src/a/subdir/file2') -@pytest.mark.asyncio async def test_overwrite_dir(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -380,7 +313,6 @@ async def test_overwrite_dir(copy_test_context): await expect_file(fs, f'{dest_base}a/file3', 'dest/a/file3') -@pytest.mark.asyncio async def test_copy_multiple(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -393,7 +325,6 @@ async def test_copy_multiple(copy_test_context): await expect_file(fs, f'{dest_base}b', 'src/b') -@pytest.mark.asyncio async def test_copy_multiple_dest_target_file(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -401,10 +332,13 @@ async def test_copy_multiple_dest_target_file(copy_test_context): await create_test_file(fs, 'src', src_base, 'b') with RaisesOrObjectStore(dest_base, NotADirectoryError): - await Copier.copy(fs, sema, Transfer([f'{src_base}a', f'{src_base}b'], dest_base.rstrip('/'), treat_dest_as=Transfer.DEST_IS_TARGET)) + await Copier.copy( + fs, + sema, + Transfer([f'{src_base}a', f'{src_base}b'], dest_base.rstrip('/'), treat_dest_as=Transfer.DEST_IS_TARGET), + ) -@pytest.mark.asyncio async def test_copy_multiple_dest_file(copy_test_context): sema, fs, src_base, dest_base = copy_test_context @@ -416,18 +350,20 @@ async def test_copy_multiple_dest_file(copy_test_context): await Copier.copy(fs, sema, Transfer([f'{src_base}a', f'{src_base}b'], f'{dest_base}x')) -@pytest.mark.asyncio async def test_file_overwrite_dir(copy_test_context): sema, fs, src_base, dest_base = copy_test_context await create_test_file(fs, 'src', src_base, 'a') with RaisesOrObjectStore(dest_base, IsADirectoryError): - await Copier.copy(fs, sema, Transfer(f'{src_base}a', dest_base.rstrip('/'), treat_dest_as=Transfer.DEST_IS_TARGET)) + await Copier.copy( + fs, sema, Transfer(f'{src_base}a', dest_base.rstrip('/'), treat_dest_as=Transfer.DEST_IS_TARGET) + ) -@pytest.mark.asyncio -async def test_file_and_directory_error(router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str): +async def test_file_and_directory_error( + router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str +): sema, fs, bases = router_filesystem src_base = await fresh_dir(fs, bases, 
cloud_scheme) @@ -440,13 +376,16 @@ async def test_file_and_directory_error(router_filesystem: Tuple[asyncio.Semapho await Copier.copy(fs, sema, Transfer(f'{src_base}a', dest_base.rstrip('/'))) -@pytest.mark.asyncio async def test_copy_src_parts(copy_test_context): sema, fs, src_base, dest_base = copy_test_context await create_test_dir(fs, 'src', src_base, 'a/') - await Copier.copy(fs, sema, Transfer([f'{src_base}a/file1', f'{src_base}a/subdir'], dest_base.rstrip('/'), treat_dest_as=Transfer.DEST_DIR)) + await Copier.copy( + fs, + sema, + Transfer([f'{src_base}a/file1', f'{src_base}a/subdir'], dest_base.rstrip('/'), treat_dest_as=Transfer.DEST_DIR), + ) await expect_file(fs, f'{dest_base}file1', 'src/a/file1') await expect_file(fs, f'{dest_base}subdir/file2', 'src/a/subdir/file2') @@ -461,8 +400,9 @@ async def collect_files(it: AsyncIterator[FileListEntry]) -> List[str]: return [await x.url() async for x in it] -@pytest.mark.asyncio -async def test_file_and_directory_error_with_slash_empty_file(router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str): +async def test_file_and_directory_error_with_slash_empty_file( + router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str +): sema, fs, bases = router_filesystem src_base = await fresh_dir(fs, bases, cloud_scheme) @@ -478,10 +418,6 @@ async def test_file_and_directory_error_with_slash_empty_file(router_filesystem: for transfer_type in (Transfer.DEST_IS_TARGET, Transfer.DEST_DIR, Transfer.INFER_DEST): dest_base = await fresh_dir(fs, bases, cloud_scheme) - await Copier.copy(fs, sema, Transfer(f'{src_base}', dest_base.rstrip('/'), treat_dest_as=transfer_type)) - - dest_base = await fresh_dir(fs, bases, cloud_scheme) - await Copier.copy(fs, sema, Transfer(f'{src_base}empty/', dest_base.rstrip('/'), treat_dest_as=transfer_type)) await collect_files(await fs.listfiles(f'{dest_base}')) @@ -498,8 +434,10 @@ async def test_file_and_directory_error_with_slash_empty_file(router_filesystem: exp_dest = f'{dest_base}foo' await expect_file(fs, exp_dest, 'foo') -@pytest.mark.asyncio -async def test_file_and_directory_error_with_slash_non_empty_file_for_google_non_recursive(router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]): + +async def test_file_and_directory_error_with_slash_non_empty_file_for_google_non_recursive( + router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], +): _, fs, bases = router_filesystem src_base = await fresh_dir(fs, bases, 'gs') @@ -514,8 +452,9 @@ async def test_file_and_directory_error_with_slash_non_empty_file_for_google_non await collect_files(await fs.listfiles(f'{src_base}not-empty/')) -@pytest.mark.asyncio -async def test_file_and_directory_error_with_slash_non_empty_file(router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str): +async def test_file_and_directory_error_with_slash_non_empty_file( + router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str +): sema, fs, bases = router_filesystem src_base = await fresh_dir(fs, bases, cloud_scheme) @@ -532,7 +471,9 @@ async def test_file_and_directory_error_with_slash_non_empty_file(router_filesys for transfer_type in (Transfer.DEST_IS_TARGET, Transfer.DEST_DIR, Transfer.INFER_DEST): dest_base = await fresh_dir(fs, bases, cloud_scheme) - await Copier.copy(fs, sema, Transfer(f'{src_base}not-empty/bar', dest_base.rstrip('/'), treat_dest_as=transfer_type)) + await Copier.copy( + fs, sema, Transfer(f'{src_base}not-empty/bar', 
dest_base.rstrip('/'), treat_dest_as=transfer_type) + ) if transfer_type == Transfer.DEST_DIR: exp_dest = f'{dest_base}bar' await expect_file(fs, exp_dest, 'bar') @@ -545,15 +486,18 @@ async def test_file_and_directory_error_with_slash_non_empty_file(router_filesys with pytest.raises(FileAndDirectoryError): dest_base = await fresh_dir(fs, bases, cloud_scheme) - await Copier.copy(fs, sema, Transfer(f'{src_base}not-empty/', dest_base.rstrip('/'), treat_dest_as=transfer_type)) + await Copier.copy( + fs, sema, Transfer(f'{src_base}not-empty/', dest_base.rstrip('/'), treat_dest_as=transfer_type) + ) with pytest.raises(FileAndDirectoryError): dest_base = await fresh_dir(fs, bases, cloud_scheme) await Copier.copy(fs, sema, Transfer(f'{src_base}', dest_base.rstrip('/'), treat_dest_as=transfer_type)) -@pytest.mark.asyncio -async def test_file_and_directory_error_with_slash_non_empty_file_only_for_google_non_recursive(router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]): +async def test_file_and_directory_error_with_slash_non_empty_file_only_for_google_non_recursive( + router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], +): sema, fs, bases = router_filesystem src_base = await fresh_dir(fs, bases, 'gs') @@ -565,15 +509,18 @@ async def test_file_and_directory_error_with_slash_non_empty_file_only_for_googl for transfer_type in (Transfer.DEST_IS_TARGET, Transfer.DEST_DIR, Transfer.INFER_DEST): dest_base = await fresh_dir(fs, bases, 'gs') - await Copier.copy(fs, sema, Transfer(f'{src_base}empty-only/', dest_base.rstrip('/'), treat_dest_as=transfer_type)) + await Copier.copy( + fs, sema, Transfer(f'{src_base}empty-only/', dest_base.rstrip('/'), treat_dest_as=transfer_type) + ) # We ignore empty directories when copying with pytest.raises(FileNotFoundError): await collect_files(await fs.listfiles(f'{dest_base}empty-only/')) -@pytest.mark.asyncio -async def test_file_and_directory_error_with_slash_empty_file_only(router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str): +async def test_file_and_directory_error_with_slash_empty_file_only( + router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str +): sema, fs, bases = router_filesystem src_base = await fresh_dir(fs, bases, cloud_scheme) @@ -585,7 +532,9 @@ async def test_file_and_directory_error_with_slash_empty_file_only(router_filesy for transfer_type in (Transfer.DEST_IS_TARGET, Transfer.DEST_DIR, Transfer.INFER_DEST): dest_base = await fresh_dir(fs, bases, cloud_scheme) - await Copier.copy(fs, sema, Transfer(f'{src_base}empty-only/', dest_base.rstrip('/'), treat_dest_as=transfer_type)) + await Copier.copy( + fs, sema, Transfer(f'{src_base}empty-only/', dest_base.rstrip('/'), treat_dest_as=transfer_type) + ) with pytest.raises(FileNotFoundError): await collect_files(await fs.listfiles(f'{dest_base}empty-only/', recursive=True)) @@ -594,8 +543,9 @@ async def test_file_and_directory_error_with_slash_empty_file_only(router_filesy await Copier.copy(fs, sema, Transfer(f'{src_base}', dest_base.rstrip('/'), treat_dest_as=transfer_type)) -@pytest.mark.asyncio -async def test_file_and_directory_error_with_slash_non_empty_file_only_google_non_recursive(router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]): +async def test_file_and_directory_error_with_slash_non_empty_file_only_google_non_recursive( + router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], +): _, fs, bases = router_filesystem src_base = await fresh_dir(fs, bases, 'gs') @@ 
-609,8 +559,9 @@ async def test_file_and_directory_error_with_slash_non_empty_file_only_google_no await collect_files(await fs.listfiles(f'{src_base}not-empty-file-w-slash/')) -@pytest.mark.asyncio -async def test_file_and_directory_error_with_slash_non_empty_file_only(router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str): +async def test_file_and_directory_error_with_slash_non_empty_file_only( + router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]], cloud_scheme: str +): sema, fs, bases = router_filesystem src_base = await fresh_dir(fs, bases, cloud_scheme) @@ -626,7 +577,11 @@ async def test_file_and_directory_error_with_slash_non_empty_file_only(router_fi for transfer_type in (Transfer.DEST_IS_TARGET, Transfer.DEST_DIR, Transfer.INFER_DEST): with pytest.raises(FileAndDirectoryError): dest_base = await fresh_dir(fs, bases, cloud_scheme) - await Copier.copy(fs, sema, Transfer(f'{src_base}not-empty-file-w-slash/', dest_base.rstrip('/'), treat_dest_as=transfer_type)) + await Copier.copy( + fs, + sema, + Transfer(f'{src_base}not-empty-file-w-slash/', dest_base.rstrip('/'), treat_dest_as=transfer_type), + ) with pytest.raises(FileAndDirectoryError): dest_base = await fresh_dir(fs, bases, cloud_scheme) diff --git a/hail/python/test/hailtop/inter_cloud/test_delete.py b/hail/python/test/hailtop/inter_cloud/test_delete.py new file mode 100644 index 00000000000..52eff72a93d --- /dev/null +++ b/hail/python/test/hailtop/inter_cloud/test_delete.py @@ -0,0 +1,38 @@ +import asyncio +from typing import Dict, Tuple + +import pytest + +from hailtop.aiotools.delete import delete +from hailtop.aiotools.fs import AsyncFS + +from .utils import fresh_dir + + +@pytest.mark.parametrize('scheme', ['file', 'gs', 's3', 'azure-https']) +async def test_delete_one_file(scheme, router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]): + sema, fs, bases = router_filesystem + dirname = await fresh_dir(fs, bases, scheme) + + url = f'{dirname}/file' + await fs.write(url, b'hello world') + assert await fs.isfile(url) + await delete(iter([url])) + assert not await fs.isfile(url) + + +@pytest.mark.parametrize('scheme', ['file', 'gs', 's3', 'azure-https']) +async def test_delete_folder(scheme, router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]): + sema, fs, bases = router_filesystem + dirname = await fresh_dir(fs, bases, scheme) + + url = f'{dirname}/folder' + await asyncio.gather( + fs.write(f'{url}/1', b'hello world'), + fs.write(f'{url}/2', b'hello world'), + fs.write(f'{url}/3', b'hello world'), + fs.write(f'{url}/4', b'hello world'), + ) + assert await fs.isdir(url) + await delete(iter([url])) + assert not await fs.isdir(url) diff --git a/hail/python/test/hailtop/inter_cloud/test_diff.py b/hail/python/test/hailtop/inter_cloud/test_diff.py index 25676e90375..7d26e086fb5 100644 --- a/hail/python/test/hailtop/inter_cloud/test_diff.py +++ b/hail/python/test/hailtop/inter_cloud/test_diff.py @@ -1,77 +1,35 @@ -from typing import Tuple, AsyncIterator, Dict -import secrets -import os import asyncio +from typing import Dict, Tuple + import pytest -import functools +from hailtop.aiotools.diff import DiffException, diff from hailtop.aiotools.fs import AsyncFS from hailtop.frozendict import frozendict -from hailtop.aiotools.diff import diff, DiffException -from hailtop.utils import bounded_gather2 -from hailtop.aiotools.router_fs import RouterAsyncFS - - -@pytest.fixture(scope='module') -def event_loop(): - loop = 
asyncio.new_event_loop() - asyncio.set_event_loop(loop) - yield loop - loop.close() - - -@pytest.fixture(scope='module') -async def router_filesystem() -> AsyncIterator[Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]]: - token = secrets.token_hex(16) - - async with RouterAsyncFS() as fs: - file_base = f'/tmp/{token}/' - await fs.mkdir(file_base) - - gs_bucket = os.environ['HAIL_TEST_GCS_BUCKET'] - gs_base = f'gs://{gs_bucket}/tmp/{token}/' - - s3_bucket = os.environ['HAIL_TEST_S3_BUCKET'] - s3_base = f's3://{s3_bucket}/tmp/{token}/' - - azure_account = os.environ['HAIL_TEST_AZURE_ACCOUNT'] - azure_container = os.environ['HAIL_TEST_AZURE_CONTAINER'] - azure_base = f'https://{azure_account}.blob.core.windows.net/{azure_container}/tmp/{token}/' - - bases = { - 'file': file_base, - 'gs': gs_base, - 's3': s3_base, - 'azure-https': azure_base - } - - sema = asyncio.Semaphore(50) - async with sema: - yield (sema, fs, bases) - await bounded_gather2(sema, - functools.partial(fs.rmtree, sema, file_base), - functools.partial(fs.rmtree, sema, gs_base), - functools.partial(fs.rmtree, sema, s3_base), - functools.partial(fs.rmtree, sema, azure_base) - ) - - assert not await fs.isdir(file_base) - assert not await fs.isdir(gs_base) - assert not await fs.isdir(s3_base) - assert not await fs.isdir(azure_base) - - -async def fresh_dir(fs, bases, scheme): - token = secrets.token_hex(16) - dir = f'{bases[scheme]}{token}/' - await fs.mkdir(dir) - return dir - - -@pytest.fixture(params=['file/file', 'file/gs', 'file/s3', 'file/azure-https', - 'gs/file', 'gs/gs', 'gs/s3', 'gs/azure-https', - 's3/file', 's3/gs', 's3/s3', 's3/azure-https', - 'azure-https/file', 'azure-https/gs', 'azure-https/s3', 'azure-https/azure-https']) + +from .utils import fresh_dir + + +@pytest.fixture( + params=[ + 'file/file', + 'file/gs', + 'file/s3', + 'file/azure-https', + 'gs/file', + 'gs/gs', + 'gs/s3', + 'gs/azure-https', + 's3/file', + 's3/gs', + 's3/s3', + 's3/azure-https', + 'azure-https/file', + 'azure-https/gs', + 'azure-https/s3', + 'azure-https/azure-https', + ] +) async def diff_test_context(request, router_filesystem: Tuple[asyncio.Semaphore, AsyncFS, Dict[str, str]]): sema, fs, bases = router_filesystem @@ -80,8 +38,7 @@ async def diff_test_context(request, router_filesystem: Tuple[asyncio.Semaphore, src_base = await fresh_dir(fs, bases, src_scheme) dest_base = await fresh_dir(fs, bases, dest_scheme) - await asyncio.gather(*[ - fs.mkdir(x) for x in [f'{src_base}a/', f'{src_base}b/', f'{dest_base}a/', f'{dest_base}b/']]) + await asyncio.gather(*[fs.mkdir(x) for x in [f'{src_base}a/', f'{src_base}b/', f'{dest_base}a/', f'{dest_base}b/']]) await asyncio.gather( fs.write(f'{src_base}same', b'123'), @@ -90,14 +47,12 @@ async def diff_test_context(request, router_filesystem: Tuple[asyncio.Semaphore, fs.write(f'{dest_base}diff', b'1'), fs.write(f'{src_base}src-only', b'123'), fs.write(f'{dest_base}dest-only', b'123'), - fs.write(f'{src_base}a/same', b'123'), fs.write(f'{dest_base}a/same', b'123'), fs.write(f'{src_base}a/diff', b'123'), fs.write(f'{dest_base}a/diff', b'1'), fs.write(f'{src_base}a/src-only', b'123'), fs.write(f'{dest_base}a/dest-only', b'123'), - fs.write(f'{src_base}b/same', b'123'), fs.write(f'{dest_base}b/same', b'123'), fs.write(f'{src_base}b/diff', b'123'), @@ -106,11 +61,9 @@ async def diff_test_context(request, router_filesystem: Tuple[asyncio.Semaphore, fs.write(f'{dest_base}b/dest-only', b'123'), ) - yield sema, fs, src_base, dest_base -@pytest.mark.asyncio async def test_diff(diff_test_context): sema, fs, 
src_base, dest_base = diff_test_context @@ -133,15 +86,11 @@ async def test_diff(diff_test_context): else: assert False, result - expected = [ - {'from': f'{src_base}src-only', 'to': f'{dest_base}', 'from_size': 3, 'to_size': None} - ] + expected = [{'from': f'{src_base}src-only', 'to': f'{dest_base}', 'from_size': 3, 'to_size': None}] actual = await diff(source=f'{src_base}src-only', target=f'{dest_base}') assert actual == expected - expected = [ - {'from': f'{src_base}diff', 'to': f'{dest_base}diff', 'from_size': 3, 'to_size': 1} - ] + expected = [{'from': f'{src_base}diff', 'to': f'{dest_base}diff', 'from_size': 3, 'to_size': 1}] actual = await diff(source=f'{src_base}diff', target=f'{dest_base}diff') assert actual == expected diff --git a/hail/python/test/hailtop/inter_cloud/test_fs.py b/hail/python/test/hailtop/inter_cloud/test_fs.py index 3bb5643776c..689de9b5d28 100644 --- a/hail/python/test/hailtop/inter_cloud/test_fs.py +++ b/hail/python/test/hailtop/inter_cloud/test_fs.py @@ -1,34 +1,47 @@ -from typing import Tuple, AsyncIterator +import asyncio import datetime -import random import functools import os +import random import secrets from concurrent.futures import ThreadPoolExecutor -import asyncio -from hailtop.aiotools.fs.fs import AsyncFSURL +from typing import AsyncIterator, Tuple + import pytest -from hailtop.utils import secret_alnum_string, retry_transient_errors, bounded_gather2 -from hailtop.aiotools import LocalAsyncFS, UnexpectedEOFError, AsyncFS -from hailtop.aiotools.router_fs import RouterAsyncFS + from hailtop.aiocloud.aioaws import S3AsyncFS from hailtop.aiocloud.aioazure import AzureAsyncFS from hailtop.aiocloud.aiogoogle import GoogleStorageAsyncFS - - -@pytest.fixture(params=['file', 'gs', 's3', 'azure-https', 'router/file', 'router/gs', 'router/s3', 'router/azure-https', 'sas/azure-https']) +from hailtop.aiotools import AsyncFS, IsABucketError, LocalAsyncFS, UnexpectedEOFError +from hailtop.aiotools.fs.fs import AsyncFSURL +from hailtop.aiotools.router_fs import RouterAsyncFS +from hailtop.fs.router_fs import RouterFS +from hailtop.utils import bounded_gather2, retry_transient_errors, secret_alnum_string + + +@pytest.fixture( + params=[ + 'file', + 'gs', + 's3', + 'azure-https', + 'router/file', + 'router/gs', + 'router/s3', + 'router/azure-https', + 'sas/azure-https', + ] +) async def filesystem(request) -> AsyncIterator[Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]]: token = secret_alnum_string() with ThreadPoolExecutor() as thread_pool: fs: AsyncFS if request.param.startswith('router/'): - fs = RouterAsyncFS(filesystems=[ - LocalAsyncFS(thread_pool), - GoogleStorageAsyncFS(), - S3AsyncFS(thread_pool), - AzureAsyncFS() - ]) + fs = RouterAsyncFS( + local_kwargs={'thread_pool': thread_pool}, + s3_kwargs={'thread_pool': thread_pool}, + ) elif request.param == 'file': fs = LocalAsyncFS(thread_pool) elif request.param.endswith('gs'): @@ -95,9 +108,12 @@ def file_data(request): return [secrets.token_bytes(1_000_000)] -@pytest.mark.asyncio async def test_write_read(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL], file_data): - _, fs, base, = filesystem + ( + _, + fs, + base, + ) = filesystem file = str(base.with_new_path_component('foo')) @@ -112,7 +128,6 @@ async def test_write_read(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSU assert expected == actual -@pytest.mark.asyncio async def test_open_from(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -126,7 +141,6 @@ async def test_open_from(filesystem: 
Tuple[asyncio.Semaphore, AsyncFS, AsyncFSUR assert r == b'cde' -@pytest.mark.asyncio async def test_open_from_with_length(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -162,7 +176,6 @@ async def test_open_from_with_length(filesystem: Tuple[asyncio.Semaphore, AsyncF assert False -@pytest.mark.asyncio async def test_open_empty(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -176,7 +189,6 @@ async def test_open_empty(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSU assert r == b'' -@pytest.mark.asyncio async def test_open_nonexistent_file(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -185,7 +197,6 @@ async def test_open_nonexistent_file(filesystem: Tuple[asyncio.Semaphore, AsyncF await fs.open(file) -@pytest.mark.asyncio async def test_open_from_nonexistent_file(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -194,7 +205,6 @@ async def test_open_from_nonexistent_file(filesystem: Tuple[asyncio.Semaphore, A await fs.open_from(file, 2) -@pytest.mark.asyncio async def test_read_from(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -205,7 +215,6 @@ async def test_read_from(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSUR assert r == b'cde' -@pytest.mark.asyncio async def test_read_range(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -224,10 +233,9 @@ async def test_read_range(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSU except UnexpectedEOFError: pass else: - assert False + assert False -@pytest.mark.asyncio async def test_read_range_end_exclusive_empty_file(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -237,8 +245,10 @@ async def test_read_range_end_exclusive_empty_file(filesystem: Tuple[asyncio.Sem assert await fs.read_range(file, 0, 0, end_inclusive=False) == b'' -@pytest.mark.asyncio -async def test_read_range_end_inclusive_empty_file_should_error(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): + +async def test_read_range_end_inclusive_empty_file_should_error( + filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL], +): _, fs, base = filesystem file = str(base.with_new_path_component('foo')) @@ -253,7 +263,6 @@ async def test_read_range_end_inclusive_empty_file_should_error(filesystem: Tupl assert False -@pytest.mark.asyncio async def test_read_range_end_exclusive_nonempty_file(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -264,7 +273,6 @@ async def test_read_range_end_exclusive_nonempty_file(filesystem: Tuple[asyncio. 
assert await fs.read_range(file, 2, 4, end_inclusive=False) == b'cd' -@pytest.mark.asyncio async def test_write_read_range(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL], file_data): _, fs, base = filesystem @@ -279,13 +287,12 @@ async def test_write_read_range(filesystem: Tuple[asyncio.Semaphore, AsyncFS, As start = min(pt1, pt2) end = max(pt1, pt2) - expected = b''.join(file_data)[start:end+1] + expected = b''.join(file_data)[start : end + 1] actual = await fs.read_range(file, start, end) # end is inclusive assert expected == actual -@pytest.mark.asyncio async def test_isfile(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -299,7 +306,6 @@ async def test_isfile(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]) assert await fs.isfile(file) -@pytest.mark.asyncio async def test_isdir(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -323,7 +329,6 @@ async def test_isdir(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): assert await fs.isdir(dir) -@pytest.mark.asyncio async def test_isdir_subdir_only(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -341,7 +346,6 @@ async def test_isdir_subdir_only(filesystem: Tuple[asyncio.Semaphore, AsyncFS, A assert await fs.isdir(subdir) -@pytest.mark.asyncio async def test_remove(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -355,7 +359,6 @@ async def test_remove(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]) assert not await fs.isfile(file) -@pytest.mark.asyncio async def test_rmtree(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): sema, fs, base = filesystem @@ -383,10 +386,14 @@ async def test_rmtree(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]) await fs.mkdir(str(subdir4_empty)) sema = asyncio.Semaphore(100) - await bounded_gather2(sema, *[ - functools.partial(fs.touch, str(subdir.with_new_path_component(f'a{i:02}'))) - for subdir in [dir, subdir1, subdir2, subdir3, subdir1subdir1, subdir1subdir2, subdir1subdir3] - for i in range(30)]) + await bounded_gather2( + sema, + *[ + functools.partial(fs.touch, str(subdir.with_new_path_component(f'a{i:02}'))) + for subdir in [dir, subdir1, subdir2, subdir3, subdir1subdir1, subdir1subdir2, subdir1subdir3] + for i in range(30) + ], + ) assert await fs.isdir(str(dir)) assert await fs.isdir(str(subdir1)) @@ -426,7 +433,6 @@ async def test_rmtree(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]) assert not await fs.isdir(str(dir)) -@pytest.mark.asyncio async def test_rmtree_empty_dir(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): sema, fs, base = filesystem @@ -437,7 +443,6 @@ async def test_rmtree_empty_dir(filesystem: Tuple[asyncio.Semaphore, AsyncFS, As assert not await fs.isdir(dir) -@pytest.mark.asyncio async def test_cloud_rmtree_file_ending_in_slash(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): sema, fs, base = filesystem @@ -454,7 +459,6 @@ async def test_cloud_rmtree_file_ending_in_slash(filesystem: Tuple[asyncio.Semap assert not await fs.exists(fname) -@pytest.mark.asyncio async def test_statfile_nonexistent_file(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -462,7 +466,6 @@ async def test_statfile_nonexistent_file(filesystem: Tuple[asyncio.Semaphore, As await fs.statfile(str(base.with_new_path_component('foo'))) -@pytest.mark.asyncio async def test_statfile_directory(filesystem: Tuple[asyncio.Semaphore, 
AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -474,7 +477,6 @@ async def test_statfile_directory(filesystem: Tuple[asyncio.Semaphore, AsyncFS, await fs.statfile(str(base.with_new_path_component('dir/'))) -@pytest.mark.asyncio async def test_statfile(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -485,7 +487,6 @@ async def test_statfile(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL assert await status.size() == n -@pytest.mark.asyncio async def test_statfile_creation_and_modified_time(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -495,11 +496,10 @@ async def test_statfile_creation_and_modified_time(filesystem: Tuple[asyncio.Sem status = await fs.statfile(file) if isinstance(fs, RouterAsyncFS): - is_local = isinstance(fs._get_fs(file), LocalAsyncFS) + is_local = isinstance(await fs._get_fs(file), LocalAsyncFS) else: is_local = isinstance(fs, LocalAsyncFS) - if is_local: try: status.time_created() @@ -517,7 +517,6 @@ async def test_statfile_creation_and_modified_time(filesystem: Tuple[asyncio.Sem assert modified_time == create_time -@pytest.mark.asyncio async def test_file_can_contain_url_query_delimiter(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -525,13 +524,19 @@ async def test_file_can_contain_url_query_delimiter(filesystem: Tuple[asyncio.Se await fs.write(file, secrets.token_bytes(10)) assert await fs.exists(file) async for f in await fs.listfiles(str(base)): - if 'bar?baz' in f.name(): + if 'bar?baz' in f.basename(): break else: assert False, 'File bar?baz not found' -@pytest.mark.asyncio +async def test_basename_is_not_path(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): + _, fs, base = filesystem + + await fs.write(str(base.with_new_path_component('abc123')), b'foo') + assert (await fs.statfile(str(base.with_new_path_component('abc123')))).basename() == 'abc123' + + async def test_listfiles(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): _, fs, base = filesystem @@ -574,14 +579,8 @@ async def listfiles(dir, recursive): stat = await entry.status() assert await stat.size() == 0 -@pytest.mark.asyncio -@pytest.mark.parametrize("permutation", [ - None, - [0, 1, 2], - [0, 2, 1], - [1, 2, 0], - [2, 1, 0] -]) + +@pytest.mark.parametrize("permutation", [None, [0, 1, 2], [0, 2, 1], [1, 2, 0], [2, 1, 0]]) async def test_multi_part_create(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL], permutation): sema, fs, base = filesystem @@ -601,6 +600,7 @@ async def test_multi_part_create(filesystem: Tuple[asyncio.Semaphore, AsyncFS, A path = str(base.with_new_path_component('a')) async with await fs.multi_part_create(sema, path, len(part_data)) as c: + async def create_part(i): async with await c.create_part(i, part_start[i]) as f: await f.write(part_data[i]) @@ -611,9 +611,7 @@ async def create_part(i): await retry_transient_errors(create_part, i) else: # do in parallel - await asyncio.gather(*[ - retry_transient_errors(create_part, i) - for i in range(len(part_data))]) + await asyncio.gather(*[retry_transient_errors(create_part, i) for i in range(len(part_data))]) expected = b''.join(part_data) async with await fs.open(path) as f: @@ -621,7 +619,6 @@ async def create_part(i): assert expected == actual -@pytest.mark.asyncio async def test_rmtree_on_symlink_to_directory(): token = secret_alnum_string() with ThreadPoolExecutor() as thread_pool: @@ -636,3 +633,52 @@ async def test_rmtree_on_symlink_to_directory(): finally: await 
fs.rmtree(sema, str(base)) assert not await fs.isdir(str(base)) + + +async def test_operations_on_a_bucket_url_is_error(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): + _, fs, base = filesystem + + if base.scheme in ('', 'file'): + return + + bucket_url = str(base.with_path('')) + + with pytest.raises(IsABucketError): + await fs.isdir(bucket_url) + + assert await fs.isfile(bucket_url) is False + + with pytest.raises(IsABucketError): + await fs.statfile(bucket_url) + + with pytest.raises(IsABucketError): + await fs.remove(bucket_url) + + with pytest.raises(IsABucketError): + await fs.create(bucket_url) + + with pytest.raises(IsABucketError): + await fs.open(bucket_url) + + +async def test_hfs_ls_bucket_url_not_an_error(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): + _, fs, base = filesystem + + if base.scheme in ('', 'file'): + return + + await fs.write(str(base.with_new_path_component('abc123')), b'foo') # ensure the bucket is non-empty + + bucket_url = str(base.with_path('')) + with RouterFS() as fs: + fs.ls(bucket_url) + + +async def test_with_new_path_component(filesystem: Tuple[asyncio.Semaphore, AsyncFS, AsyncFSURL]): + _, _, base = filesystem + + assert str(base.with_path('').with_new_path_component('abc')) == str(base.with_path('abc')) + assert str(base.with_path('abc').with_new_path_component('def')) == str(base.with_path('abc/def')) + + actual = base.with_path('abc').with_new_path_component('def').with_new_path_component('ghi') + assert str(actual) == str(base.with_path('abc/def/ghi')) diff --git a/hail/python/test/hailtop/inter_cloud/test_into_copy.py b/hail/python/test/hailtop/inter_cloud/test_into_copy.py index edf32336ca7..0ac74df543d 100644 --- a/hail/python/test/hailtop/inter_cloud/test_into_copy.py +++ b/hail/python/test/hailtop/inter_cloud/test_into_copy.py @@ -1,8 +1,8 @@ -import pytest import os.path -from hailtop.aiotools.copy import copy_from_dict import tempfile +from hailtop.aiotools.copy import copy_from_dict + def write_file(path, data): with open(path, "w") as f: @@ -14,13 +14,14 @@ def read_file(path): return f.read() -@pytest.mark.asyncio async def test_copy_file(): with tempfile.TemporaryDirectory() as test_dir: write_file(f"{test_dir}/file1", "hello world\n") - inputs = [{"from": f"{test_dir}/file1", "to": f"{test_dir}/file2"}, - {"from": f"{test_dir}/file1", "into": f"{test_dir}/dir1"}] + inputs = [ + {"from": f"{test_dir}/file1", "to": f"{test_dir}/file2"}, + {"from": f"{test_dir}/file1", "into": f"{test_dir}/dir1"}, + ] await copy_from_dict(files=inputs) @@ -29,7 +30,6 @@ async def test_copy_file(): assert read_file(file) == "hello world\n" -@pytest.mark.asyncio async def test_copy_dir(): with tempfile.TemporaryDirectory() as test_dir: os.makedirs(f"{test_dir}/subdir1") @@ -41,7 +41,6 @@ async def test_copy_dir(): assert read_file(f"{test_dir}/subdir2/subdir1/file1") == "hello world\n" -@pytest.mark.asyncio async def test_error_function(): with tempfile.TemporaryDirectory() as test_dir: write_file(f"{test_dir}/foo", "hello world\n") diff --git a/hail/python/test/hailtop/inter_cloud/utils.py b/hail/python/test/hailtop/inter_cloud/utils.py new file mode 100644 index 00000000000..b9adba791bc --- /dev/null +++ b/hail/python/test/hailtop/inter_cloud/utils.py @@ -0,0 +1,11 @@ +import secrets +from typing import Dict + +from hailtop.aiotools.fs import AsyncFS + + +async def fresh_dir(fs: AsyncFS, bases: Dict[str, str], scheme: str): + token = secrets.token_hex(16) + dir = f'{bases[scheme]}{token}/' + await fs.mkdir(dir) + return dir diff 
--git a/hail/python/test/hailtop/test_aiogoogle.py b/hail/python/test/hailtop/test_aiogoogle.py index c493612da5f..b0615c1287f 100644 --- a/hail/python/test/hailtop/test_aiogoogle.py +++ b/hail/python/test/hailtop/test_aiogoogle.py @@ -1,14 +1,15 @@ +import asyncio +import concurrent.futures +import functools import os import secrets from concurrent.futures import ThreadPoolExecutor -import asyncio + import pytest -import concurrent.futures -import functools -from hailtop.utils import secret_alnum_string, bounded_gather2, retry_transient_errors -from hailtop.aiotools import LocalAsyncFS + +from hailtop.aiocloud.aiogoogle import GoogleStorageAsyncFS, GoogleStorageClient from hailtop.aiotools.router_fs import RouterAsyncFS -from hailtop.aiocloud.aiogoogle import GoogleStorageClient, GoogleStorageAsyncFS +from hailtop.utils import bounded_gather2, retry_transient_errors, secret_alnum_string @pytest.fixture(params=['gs', 'router/gs']) @@ -17,14 +18,14 @@ async def gs_filesystem(request): with ThreadPoolExecutor() as thread_pool: if request.param.startswith('router/'): - fs = RouterAsyncFS(filesystems=[LocalAsyncFS(thread_pool), GoogleStorageAsyncFS()]) + fs = RouterAsyncFS(local_kwargs={'thread_pool': thread_pool}) else: assert request.param.endswith('gs') fs = GoogleStorageAsyncFS() async with fs: test_storage_uri = os.environ['HAIL_TEST_STORAGE_URI'] protocol = 'gs://' - assert test_storage_uri[:len(protocol)] == protocol + assert test_storage_uri[: len(protocol)] == protocol base = f'{test_storage_uri}/tmp/{token}/' await fs.mkdir(base) @@ -49,14 +50,15 @@ def test_bucket_path_parsing(): assert bucket == 'foo' and prefix == 'bar/baz' -@pytest.mark.asyncio async def test_get_object_metadata(bucket_and_temporary_file): bucket, file = bucket_and_temporary_file async with GoogleStorageClient() as client: + async def upload(): async with await client.insert_object(bucket, file) as f: await f.write(b'foo') + await retry_transient_errors(upload) metadata = await client.get_object_metadata(bucket, file) assert 'etag' in metadata @@ -65,14 +67,15 @@ async def upload(): assert int(metadata['size']) == 3 -@pytest.mark.asyncio async def test_get_object_headers(bucket_and_temporary_file): bucket, file = bucket_and_temporary_file async with GoogleStorageClient() as client: + async def upload(): async with await client.insert_object(bucket, file) as f: await f.write(b'foo') + await retry_transient_errors(upload) async with await client.get_object(bucket, file) as f: headers = f.headers() # type: ignore @@ -81,16 +84,17 @@ async def upload(): assert await f.read() == b'foo' -@pytest.mark.asyncio async def test_compose(bucket_and_temporary_file): bucket, file = bucket_and_temporary_file part_data = [b'a', b'bb', b'ccc'] async with GoogleStorageClient() as client: + async def upload(i, b): async with await client.insert_object(bucket, f'{file}/{i}') as f: await f.write(b) + for i, b in enumerate(part_data): await retry_transient_errors(upload, i, b) await client.compose(bucket, [f'{file}/{i}' for i in range(len(part_data))], f'{file}/combined') @@ -101,7 +105,6 @@ async def upload(i, b): assert actual == expected -@pytest.mark.asyncio async def test_multi_part_create_many_two_level_merge(gs_filesystem): # This is a white-box test. 
compose has a maximum of 32 inputs, # so if we're composing more than 32 parts, the @@ -121,13 +124,15 @@ async def test_multi_part_create_many_two_level_merge(gs_filesystem): path = f'{base}a' async with await fs.multi_part_create(sema, path, len(part_data)) as c: + async def create_part(i): async with await c.create_part(i, part_start[i]) as f: await f.write(part_data[i]) # do in parallel - await bounded_gather2(sema, *[ - functools.partial(retry_transient_errors, create_part, i) for i in range(len(part_data))]) + await bounded_gather2( + sema, *[functools.partial(retry_transient_errors, create_part, i) for i in range(len(part_data))] + ) expected = b''.join(part_data) actual = await fs.read(path) @@ -135,7 +140,7 @@ async def create_part(i): except (concurrent.futures._base.CancelledError, asyncio.CancelledError) as err: raise AssertionError('uncaught cancelled error') from err -@pytest.mark.asyncio + async def test_weird_urls(gs_filesystem): _, fs, base = gs_filesystem diff --git a/hail/python/test/hailtop/test_dictfix.py b/hail/python/test/hailtop/test_dictfix.py index c8a4a0aa173..ab5e306e9f1 100644 --- a/hail/python/test/hailtop/test_dictfix.py +++ b/hail/python/test/hailtop/test_dictfix.py @@ -2,9 +2,11 @@ def test_batch_example(): - spec = {'input': dictfix.NoneOr({'logs': None}), - 'main': dictfix.NoneOr({'logs': None}), - 'output': dictfix.NoneOr({'logs': None})} + spec = { + 'input': dictfix.NoneOr({'logs': None}), + 'main': dictfix.NoneOr({'logs': None}), + 'output': dictfix.NoneOr({'logs': None}), + } expected = {'input': None, 'main': None, 'output': None} @@ -24,16 +26,12 @@ def test_batch_example(): dictfix.dictfix(actual, spec) assert actual == expected - expected = {'input': None, - 'main': {'id': 3, 'logs': None}, - 'output': None} + expected = {'input': None, 'main': {'id': 3, 'logs': None}, 'output': None} actual = {'main': {'id': 3}, 'output': None} dictfix.dictfix(actual, spec) assert actual == expected - expected = {'input': None, - 'main': {'id': 3, 'logs': 'abc\n123'}, - 'output': {'id': 4, 'logs': None}} + expected = {'input': None, 'main': {'id': 3, 'logs': 'abc\n123'}, 'output': {'id': 4, 'logs': None}} actual = {'main': {'id': 3, 'logs': 'abc\n123'}, 'output': {'id': 4}} dictfix.dictfix(actual, spec) assert actual == expected diff --git a/hail/python/test/hailtop/test_humanizex.py b/hail/python/test/hailtop/test_humanizex.py index feddb642d5b..b4581939408 100644 --- a/hail/python/test/hailtop/test_humanizex.py +++ b/hail/python/test/hailtop/test_humanizex.py @@ -23,7 +23,6 @@ def test_humanize(): assert naturaldelta(0.001) == '1ms' assert naturaldelta(0.000001) == '1μs' - assert naturaldelta_msec(15 * 24 * 60 * 60 * 1000) == '2 weeks 1 day' assert naturaldelta_msec(200_000) == '3 minutes 20s' assert naturaldelta_msec(120_000) == '2 minutes' diff --git a/hail/python/test/hailtop/test_timex.py b/hail/python/test/hailtop/test_timex.py index 9bb1af6faf6..f40db6de46a 100644 --- a/hail/python/test/hailtop/test_timex.py +++ b/hail/python/test/hailtop/test_timex.py @@ -1,6 +1,7 @@ -from hailtop import timex import datetime +from hailtop import timex + def test_google_cloud_storage_example(): actual = timex.parse_rfc3339('2022-12-27T16:48:06.404Z') @@ -53,7 +54,6 @@ def test_space_instead_of_T(): expected = datetime.datetime(2022, 12, 27, 16, 48, 6, 404000, tzinfo) assert actual == expected - actual = timex.parse_rfc3339('2022-12-27 16:48:06.404-04:35') tzinfo = datetime.timezone(datetime.timedelta(hours=-4, minutes=-35)) @@ -68,14 +68,12 @@ def test_lowercase_T(): 
expected = datetime.datetime(2022, 12, 27, 16, 48, 6, 0, tzinfo) assert actual == expected - actual = timex.parse_rfc3339('2022-12-27t16:48:06.404Z') tzinfo = datetime.timezone.utc expected = datetime.datetime(2022, 12, 27, 16, 48, 6, 404000, tzinfo) assert actual == expected - actual = timex.parse_rfc3339('2022-12-27t16:48:06.404-04:35') tzinfo = datetime.timezone(datetime.timedelta(hours=-4, minutes=-35)) @@ -90,7 +88,6 @@ def test_lowercase_z(): expected = datetime.datetime(2022, 12, 27, 16, 48, 6, 0, tzinfo) assert actual == expected - actual = timex.parse_rfc3339('2022-12-27T16:48:06.404z') tzinfo = datetime.timezone.utc @@ -98,7 +95,6 @@ def test_lowercase_z(): assert actual == expected - def test_one_fractional_second_digit(): actual = timex.parse_rfc3339('2022-12-27T16:48:06.1Z') diff --git a/hail/python/test/hailtop/test_yamlx.py b/hail/python/test/hailtop/test_yamlx.py index 34b52fb7320..3e4e8886a78 100644 --- a/hail/python/test/hailtop/test_yamlx.py +++ b/hail/python/test/hailtop/test_yamlx.py @@ -3,9 +3,9 @@ def test_multiline_str_is_literal_block(): actual = yamlx.dump({'hello': 'abc', 'multiline': 'abc\ndef'}) - expected = '''hello: abc + expected = """hello: abc multiline: |- abc def -''' +""" assert actual == expected diff --git a/hail/python/test/hailtop/utils/test_filesize.py b/hail/python/test/hailtop/utils/test_filesize.py index 7eb3aeda10a..d49ae4c4541 100644 --- a/hail/python/test/hailtop/utils/test_filesize.py +++ b/hail/python/test/hailtop/utils/test_filesize.py @@ -1,8 +1,10 @@ from pytest import raises + from hailtop.utils.filesize import filesize + def test_filesize(): - for n in [-1, -1023, -1024, -1025, -1024**3]: + for n in [-1, -1023, -1024, -1025, -(1024**3)]: with raises(ValueError): filesize(n) assert filesize(0) == "0B" @@ -13,13 +15,12 @@ def test_filesize(): assert filesize(1025) == "1KiB" prefixes = ["K", "M", "G", "T", "P", "E", "Z", "Y"] for exp in range(2, 9): - assert filesize(1024 ** exp - 1) == f"1023{prefixes[exp - 2]}iB" - assert filesize(1024 ** exp) == f"1{prefixes[exp - 1]}iB" - assert filesize(1024 ** exp + 1) == f"1{prefixes[exp - 1]}iB" - assert filesize(1024 ** exp * 2) == f"2{prefixes[exp - 1]}iB" - assert filesize(1024 ** 9 - 1) == "1023YiB" - assert filesize(1024 ** 9) == "1024YiB" - assert filesize(1024 ** 9 + 1) == "1024YiB" - assert filesize(1024 ** 9 * 2) == "2048YiB" - assert filesize(1024 ** 10) == "1048576YiB" - + assert filesize(1024**exp - 1) == f"1023{prefixes[exp - 2]}iB" + assert filesize(1024**exp) == f"1{prefixes[exp - 1]}iB" + assert filesize(1024**exp + 1) == f"1{prefixes[exp - 1]}iB" + assert filesize(1024**exp * 2) == f"2{prefixes[exp - 1]}iB" + assert filesize(1024**9 - 1) == "1023YiB" + assert filesize(1024**9) == "1024YiB" + assert filesize(1024**9 + 1) == "1024YiB" + assert filesize(1024**9 * 2) == "2048YiB" + assert filesize(1024**10) == "1048576YiB" diff --git a/hail/python/test/hailtop/utils/test_gcs_requester_pays.py b/hail/python/test/hailtop/utils/test_gcs_requester_pays.py new file mode 100644 index 00000000000..e27290bc1d1 --- /dev/null +++ b/hail/python/test/hailtop/utils/test_gcs_requester_pays.py @@ -0,0 +1,21 @@ +import pytest + +from hailtop.aiotools.router_fs import RouterAsyncFS +from hailtop.fs.router_fs import RouterFS +from hailtop.utils.gcs_requester_pays import GCSRequesterPaysFSCache + + +@pytest.mark.parametrize("cls", [RouterFS, RouterAsyncFS]) +def test_get_fs_by_requester_pays_config(cls): + config_1 = "foo" + config_2 = ("foo", ["bar", "baz", "bat"]) + kwargs_1 = 
{"gcs_requester_pays_configuration": config_1} + fses = GCSRequesterPaysFSCache(cls) + assert fses[None]._gcs_kwargs == {} + assert fses[config_1]._gcs_kwargs == kwargs_1 + set_kwargs = fses[config_2]._gcs_kwargs + assert set_kwargs["gcs_requester_pays_configuration"][0] == config_2[0] + assert len(set_kwargs["gcs_requester_pays_configuration"][1]) == len(config_2[1]) + assert set(set_kwargs["gcs_requester_pays_configuration"][1]) == set(config_2[1]) + default_kwargs_fses = GCSRequesterPaysFSCache(cls, {"gcs_kwargs": kwargs_1}) + assert default_kwargs_fses[None]._gcs_kwargs == kwargs_1 diff --git a/hail/python/test/hailtop/utils/test_utils.py b/hail/python/test/hailtop/utils/test_utils.py index 809a6f219dd..40193eb9227 100644 --- a/hail/python/test/hailtop/utils/test_utils.py +++ b/hail/python/test/hailtop/utils/test_utils.py @@ -1,6 +1,13 @@ -from hailtop.utils import (partition, url_basename, url_join, url_scheme, - url_and_params, parse_docker_image_reference, grouped) -from hailtop.utils.utils import digits_needed, unzip, filter_none, flatten +from hailtop.utils import ( + grouped, + parse_docker_image_reference, + partition, + url_and_params, + url_basename, + url_join, + url_scheme, +) +from hailtop.utils.utils import digits_needed, filter_none, flatten, unzip def test_partition_zero_empty(): @@ -20,8 +27,14 @@ def test_partition_uneven_big(): def test_partition_toofew(): - assert list(partition(6, range(3))) == [range(0, 1), range(1, 2), range(2, 3), - range(3, 3), range(3, 3), range(3, 3)] + assert list(partition(6, range(3))) == [ + range(0, 1), + range(1, 2), + range(2, 3), + range(3, 3), + range(3, 3), + range(3, 3), + ] def test_url_basename(): @@ -42,11 +55,13 @@ def test_url_scheme(): assert url_scheme('https://hail.is/path/to') == 'https' assert url_scheme('/path/to') == '' + def test_url_and_params(): assert url_and_params('https://example.com/') == ('https://example.com/', {}) assert url_and_params('https://example.com/foo?') == ('https://example.com/foo', {}) assert url_and_params('https://example.com/foo?a=b&c=d') == ('https://example.com/foo', {'a': 'b', 'c': 'd'}) + def test_parse_docker_image_reference(): x = parse_docker_image_reference('animage') assert x.domain is None @@ -131,54 +146,63 @@ def test_parse_docker_image_reference(): def test_grouped_size_0_groups_9_elements(): try: - list(grouped(0, [1,2,3,4,5,6,7,8,9])) + list(grouped(0, [1, 2, 3, 4, 5, 6, 7, 8, 9])) except ValueError: pass else: assert False + def test_grouped_size_1_groups_9_elements(): - actual = list(grouped(1, [1,2,3,4,5,6,7,8,9])) + actual = list(grouped(1, [1, 2, 3, 4, 5, 6, 7, 8, 9])) expected = [[1], [2], [3], [4], [5], [6], [7], [8], [9]] assert actual == expected + def test_grouped_size_5_groups_9_elements(): - actual = list(grouped(5, [1,2,3,4,5,6,7,8,9])) + actual = list(grouped(5, [1, 2, 3, 4, 5, 6, 7, 8, 9])) expected = [[1, 2, 3, 4, 5], [6, 7, 8, 9]] assert actual == expected + def test_grouped_size_3_groups_0_elements(): - actual = list(grouped(3,[])) + actual = list(grouped(3, [])) expected = [] assert actual == expected + def test_grouped_size_2_groups_1_elements(): - actual = list(grouped(2,[1])) + actual = list(grouped(2, [1])) expected = [[1]] assert actual == expected + def test_grouped_size_1_groups_0_elements(): - actual = list(grouped(1,[0])) + actual = list(grouped(1, [0])) expected = [[0]] assert actual == expected + def test_grouped_size_1_groups_5_elements(): - actual = list(grouped(1,['abc', 'def', 'ghi', 'jkl', 'mno'])) + actual = list(grouped(1, ['abc', 'def', 'ghi', 
'jkl', 'mno'])) expected = [['abc'], ['def'], ['ghi'], ['jkl'], ['mno']] assert actual == expected + def test_grouped_size_2_groups_5_elements(): - actual = list(grouped(2,['abc', 'def', 'ghi', 'jkl', 'mno'])) + actual = list(grouped(2, ['abc', 'def', 'ghi', 'jkl', 'mno'])) expected = [['abc', 'def'], ['ghi', 'jkl'], ['mno']] assert actual == expected + def test_grouped_size_3_groups_6_elements(): - actual = list(grouped(3,['abc', 'def', 'ghi', 'jkl', 'mno', ''])) + actual = list(grouped(3, ['abc', 'def', 'ghi', 'jkl', 'mno', ''])) expected = [['abc', 'def', 'ghi'], ['jkl', 'mno', '']] assert actual == expected + def test_grouped_size_3_groups_7_elements(): - actual = list(grouped(3,['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr', 'stu'])) + actual = list(grouped(3, ['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr', 'stu'])) expected = [['abc', 'def', 'ghi'], ['jkl', 'mno', 'pqr'], ['stu']] assert actual == expected @@ -207,7 +231,12 @@ def test_filter_none(): assert filter_none([None, []]) == [[]] assert filter_none([0, []]) == [0, []] assert filter_none([1, 2, [None]]) == [1, 2, [None]] - assert filter_none([1, 3.5, 2, 4,]) == [1, 3.5, 2, 4] + assert filter_none([ + 1, + 3.5, + 2, + 4, + ]) == [1, 3.5, 2, 4] assert filter_none([1, 2, 3.0, None, 5]) == [1, 2, 3.0, 5] assert filter_none(['a', 'b', 'c', None]) == ['a', 'b', 'c'] assert filter_none([None, [None, [None, [None]]]]) == [[None, [None, [None]]]] @@ -222,5 +251,22 @@ def test_flatten(): assert flatten([['a', 'b', 'c'], ['d', 'e']]) == ['a', 'b', 'c', 'd', 'e'] assert flatten([[['a'], ['b']], [[1, 2, 3], [4, 5]]]) == [['a'], ['b'], [1, 2, 3], [4, 5]] assert flatten([['apples'], ['bannanas'], ['oranges']]) == ['apples', 'bannanas', 'oranges'] - assert flatten([['apple', 'bannana'], ['a', 'b', 'c'], [1, 2, 3, 4]]) == ['apple', 'bannana', 'a', 'b', 'c', 1, 2, 3, 4] - assert flatten([['apples'], [''], ['bannanas'], [''], ['oranges'], ['']]) == ['apples', '', 'bannanas', '', 'oranges', ''] + assert flatten([['apple', 'bannana'], ['a', 'b', 'c'], [1, 2, 3, 4]]) == [ + 'apple', + 'bannana', + 'a', + 'b', + 'c', + 1, + 2, + 3, + 4, + ] + assert flatten([['apples'], [''], ['bannanas'], [''], ['oranges'], ['']]) == [ + 'apples', + '', + 'bannanas', + '', + 'oranges', + '', + ] diff --git a/hail/scripts/release.sh b/hail/scripts/release.sh index cd5a625769f..c7a1c203818 100755 --- a/hail/scripts/release.sh +++ b/hail/scripts/release.sh @@ -1,33 +1,58 @@ -#!/bin/bash +#!/usr/bin/env bash set -ex SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +usage() { +cat << EOF +usage: $(basename "$0") + + All arguments are specified by environment variables. 
For example: + + HAIL_PIP_VERSION=0.2.123 + HAIL_VERSION=0.2.123-abcdef123 + GIT_VERSION=abcdef123 + REMOTE=origin + WHEEL=/path/to/the.whl + GITHUB_OAUTH_HEADER_FILE=/path/to/github/oauth/header/file + HAIL_GENETICS_HAIL_IMAGE=docker://us-docker.pkg.dev/hail-vdc/hail/hailgenetics/hail:deploy-123abc + HAIL_GENETICS_HAIL_IMAGE_PY_3_10=docker://us-docker.pkg.dev/hail-vdc/hail/hailgenetics/hail:deploy-123abc + HAIL_GENETICS_HAIL_IMAGE_PY_3_11=docker://us-docker.pkg.dev/hail-vdc/hail/hailgenetics/hail:deploy-123abc + HAIL_GENETICS_HAILTOP_IMAGE=docker://us-docker.pkg.dev/hail-vdc/hail/hailgenetics/hailtop:deploy-123abc + HAIL_GENETICS_VEP_GRCH37_85_IMAGE=docker://us-docker.pkg.dev/hail-vdc/hail/hailgenetics/vep-grch37-85:deploy-123abc + HAIL_GENETICS_VEP_GRCH38_95_IMAGE=docker://us-docker.pkg.dev/hail-vdc/hail/hailgenetics/vep-grch38-95:deploy-123abc + AZURE_WHEEL=/path/to/wheel/for/azure + WEBSITE_TAR=/path/to/www.tar.gz + bash $(basename "$0") +EOF +} + retry() { "$@" || (sleep 2 && "$@") || (sleep 5 && "$@"); } -[[ $# -eq 16 ]] || (echo "./release.sh HAIL_PIP_VERSION HAIL_VERSION GIT_VERSION REMOTE WHEEL GITHUB_OAUTH_HEADER_FILE HAIL_GENETICS_HAIL_IMAGE HAIL_GENETICS_HAIL_IMAGE_PY_3_10 HAIL_GENETICS_HAIL_IMAGE_PY_3_11 HAIL_GENETICS_HAILTOP_IMAGE HAIL_GENETICS_VEP_GRCH37_85_IMAGE HAIL_GENETICS_VEP_GRCH38_95_IMAGE WHEEL_FOR_AZURE WEBSITE_TAR GCP_PROJECT GCP_AR_CLEANUP_POLICY" ; exit 1) - -HAIL_PIP_VERSION=$1 -HAIL_VERSION=$2 -GIT_VERSION=$3 -REMOTE=$4 -WHEEL=$5 -GITHUB_OAUTH_HEADER_FILE=$6 -HAIL_GENETICS_HAIL_IMAGE=$7 -HAIL_GENETICS_HAIL_IMAGE_PY_3_10=$8 -HAIL_GENETICS_HAIL_IMAGE_PY_3_11=$9 -HAIL_GENETICS_HAILTOP_IMAGE=${10} -HAIL_GENETICS_VEP_GRCH37_85_IMAGE=${11} -HAIL_GENETICS_VEP_GRCH38_95_IMAGE=${12} -WHEEL_FOR_AZURE=${13} -WEBSITE_TAR=${14} -GCP_PROJECT=${15} -GCP_AR_CLEANUP_POLICY=${16} +arguments="HAIL_PIP_VERSION HAIL_VERSION GIT_VERSION REMOTE WHEEL GITHUB_OAUTH_HEADER_FILE \ + HAIL_GENETICS_HAIL_IMAGE HAIL_GENETICS_HAIL_IMAGE_PY_3_10 \ + HAIL_GENETICS_HAIL_IMAGE_PY_3_11 HAIL_GENETICS_HAILTOP_IMAGE \ + HAIL_GENETICS_VEP_GRCH37_85_IMAGE HAIL_GENETICS_VEP_GRCH38_95_IMAGE AZURE_WHEEL \ + WEBSITE_TAR" + +for varname in $arguments +do + if [ -z "${!varname}" ] # A bash-ism, but we are #!/bin/bash + then + echo + usage + echo + echo "$varname is unset or empty" + exit 1 + else + echo "$varname=${!varname}" + fi +done retry skopeo inspect $HAIL_GENETICS_HAIL_IMAGE || (echo "could not pull $HAIL_GENETICS_HAIL_IMAGE" ; exit 1) retry skopeo inspect $HAIL_GENETICS_HAIL_IMAGE_PY_3_10 || (echo "could not pull $HAIL_GENETICS_HAIL_IMAGE_PY_3_10" ; exit 1) @@ -48,6 +73,21 @@ then exit 1 fi +if [ ! -f $WHEEL_FOR_AZURE ] +then + echo "wheel for azure not found $WHEEL_FOR_AZURE" +fi + +if [ ! -f python/hail/experimental/datasets.json ] +then + echo "datasets.json not found at python/hail/experimental/datasets.json" +fi + +if [ ! 
-f $WEBSITE_TAR ] +then + echo "website tar not found at $WEBSITE_TAR" +fi + pip_versions_file=$(mktemp) pip install hail== 2>&1 \ | head -n 1 \ @@ -111,14 +151,9 @@ retry skopeo copy $HAIL_GENETICS_VEP_GRCH38_95_IMAGE docker://us-docker.pkg.dev/ twine upload $WHEEL # deploy wheel for Azure HDInsight -wheel_for_azure_url=gs://hail-common/azure-hdinsight-wheels/$(basename $WHEEL_FOR_AZURE) -gcloud storage cp $WHEEL_FOR_AZURE $wheel_for_azure_url -gcloud storage objects update $wheel_for_azure_url --temporary-hold - -# update docs sha -cloud_sha_location=gs://hail-common/builds/0.2/latest-hash/cloudtools-5-spark-2.4.0.txt -printf "$GIT_VERSION" | gcloud storage cp - $cloud_sha_location -gcloud storage objects update -r $cloud_sha_location --add-acl-grant=entity=AllUsers,role=READER +azure_wheel_url=gs://hail-common/azure-hdinsight-wheels/$(basename $AZURE_WHEEL) +gcloud storage cp $AZURE_WHEEL $azure_wheel_url +gcloud storage objects update $azure_wheel_url --temporary-hold # deploy datasets (annotation db) json datasets_json_url=gs://hail-common/annotationdb/$HAIL_VERSION/datasets.json @@ -160,9 +195,3 @@ make_pr_for() { make_pr_for terra-jupyter-hail make_pr_for terra-jupyter-aou - -gcloud artifacts repositories set-cleanup-policies hail \ - --project=$GCP_PROJECT \ - --location=us \ - --policy=$GCP_AR_CLEANUP_POLICY \ - --no-dry-run diff --git a/hail/scripts/test-dataproc.sh b/hail/scripts/test-dataproc.sh index a205393b5a9..e514233a20d 100755 --- a/hail/scripts/test-dataproc.sh +++ b/hail/scripts/test-dataproc.sh @@ -33,8 +33,10 @@ hailctl dataproc \ --max-age 120m \ --vep $1 \ --num-preemptible-workers=4 \ - --requester-pays-allow-buckets hail-us-vep \ - --subnet=default + --requester-pays-allow-buckets hail-us-central1-vep,hail-1kg \ + --subnet=default \ + --bucket=hail-dataproc-staging-bucket-us-central1 \ + --temp-bucket=hail-dataproc-temp-bucket-us-central1 for file in $cluster_test_files do hailctl dataproc \ diff --git a/hail/scripts/test-gcp.sh b/hail/scripts/test-gcp.sh index 82ae452916a..d21036fdb42 100755 --- a/hail/scripts/test-gcp.sh +++ b/hail/scripts/test-gcp.sh @@ -1,17 +1,16 @@ #!/bin/bash -if [[ $# -ne 1 ]] +if [[ $# -ne 0 ]] then cat < /dev/null); then fi function cleanup { - gcloud --project broad-ctsa -q dataproc clusters delete --async $CLUSTER + hailctl dataproc stop $CLUSTER } trap cleanup EXIT SIGINT -gcloud --project broad-ctsa dataproc clusters create $CLUSTER \ +hailctl dataproc start $CLUSTER \ + --project broad-ctsa \ --zone $ZONE \ - --master-machine-type n1-standard-2 \ - --master-boot-disk-size 100 \ - --num-workers 2 \ - --worker-machine-type n1-standard-2 \ - --worker-boot-disk-size 100 \ - --image-version ${DATAPROC_VERSION} \ - --initialization-actions 'gs://hail-dataproc-deps/initialization-actions.sh' + --subnet=default \ + --bucket=gs://hail-dataproc-staging-bucket-us-central1 \ + --temp-bucket=gs://hail-dataproc-temp-bucket-us-central1 # copy up necessary files gcloud --project broad-ctsa compute scp \ diff --git a/hail/scripts/test_requester_pays_parsing.py b/hail/scripts/test_requester_pays_parsing.py index d8c5794d377..9d1ae809aee 100644 --- a/hail/scripts/test_requester_pays_parsing.py +++ b/hail/scripts/test_requester_pays_parsing.py @@ -9,7 +9,9 @@ from hailtop.utils.process import check_exec_output if 'YOU_MAY_OVERWRITE_MY_SPARK_DEFAULTS_CONF_AND_HAILCTL_SETTINGS' not in os.environ: - print('This script will overwrite your spark-defaults.conf and hailctl settings. 
It is intended to be executed inside a container.') + print( + 'This script will overwrite your spark-defaults.conf and hailctl settings. It is intended to be executed inside a container.' + ) sys.exit(1) @@ -33,7 +35,6 @@ async def unset_hailctl(): ) -@pytest.mark.asyncio async def test_no_configuration(): with open(SPARK_CONF_PATH, 'w'): pass @@ -44,8 +45,6 @@ async def test_no_configuration(): assert actual is None - -@pytest.mark.asyncio async def test_no_project_is_error(): with open(SPARK_CONF_PATH, 'w') as f: f.write('spark.hadoop.fs.gs.requester.pays.mode AUTO\n') @@ -56,7 +55,6 @@ async def test_no_project_is_error(): get_gcs_requester_pays_configuration() -@pytest.mark.asyncio async def test_auto_with_project(): with open(SPARK_CONF_PATH, 'w') as f: f.write('spark.hadoop.fs.gs.requester.pays.project.id my_project\n') @@ -68,8 +66,6 @@ async def test_auto_with_project(): assert actual == 'my_project' - -@pytest.mark.asyncio async def test_custom_no_buckets(): with open(SPARK_CONF_PATH, 'w') as f: f.write('spark.hadoop.fs.gs.requester.pays.project.id my_project\n') @@ -81,7 +77,6 @@ async def test_custom_no_buckets(): get_gcs_requester_pays_configuration() -@pytest.mark.asyncio async def test_custom_with_buckets(): with open(SPARK_CONF_PATH, 'w') as f: f.write('spark.hadoop.fs.gs.requester.pays.project.id my_project\n') @@ -94,8 +89,6 @@ async def test_custom_with_buckets(): assert actual == ('my_project', ['abc', 'def']) - -@pytest.mark.asyncio async def test_disabled(): with open(SPARK_CONF_PATH, 'w') as f: f.write('spark.hadoop.fs.gs.requester.pays.project.id my_project\n') @@ -108,8 +101,6 @@ async def test_disabled(): assert actual is None - -@pytest.mark.asyncio async def test_enabled(): with open(SPARK_CONF_PATH, 'w') as f: f.write('spark.hadoop.fs.gs.requester.pays.project.id my_project\n') @@ -122,8 +113,6 @@ async def test_enabled(): assert actual == 'my_project' - -@pytest.mark.asyncio async def test_hailctl_takes_precedence_1(): await unset_hailctl() @@ -132,25 +121,17 @@ async def test_hailctl_takes_precedence_1(): f.write('spark.hadoop.fs.gs.requester.pays.mode ENABLED\n') f.write('spark.hadoop.fs.gs.requester.pays.buckets abc,def\n') - await check_exec_output( - 'hailctl', - 'config', - 'set', - 'gcs_requester_pays/project', - 'hailctl_project', - echo=True - ) + await check_exec_output('hailctl', 'config', 'set', 'gcs_requester_pays/project', 'hailctl_project', echo=True) actual = get_gcs_requester_pays_configuration() assert actual == 'hailctl_project', str(( configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_PROJECT, None, None), configuration_of(ConfigVariable.GCS_REQUESTER_PAYS_BUCKETS, None, None), get_spark_conf_gcs_requester_pays_configuration(), - open('/Users/dking/.config/hail/config.ini', 'r').readlines() + open('/Users/dking/.config/hail/config.ini', 'r').readlines(), )) -@pytest.mark.asyncio async def test_hailctl_takes_precedence_2(): await unset_hailctl() @@ -159,23 +140,9 @@ async def test_hailctl_takes_precedence_2(): f.write('spark.hadoop.fs.gs.requester.pays.mode ENABLED\n') f.write('spark.hadoop.fs.gs.requester.pays.buckets abc,def\n') - await check_exec_output( - 'hailctl', - 'config', - 'set', - 'gcs_requester_pays/project', - 'hailctl_project2', - echo=True - ) + await check_exec_output('hailctl', 'config', 'set', 'gcs_requester_pays/project', 'hailctl_project2', echo=True) - await check_exec_output( - 'hailctl', - 'config', - 'set', - 'gcs_requester_pays/buckets', - 'bucket1,bucket2', - echo=True - ) + await check_exec_output('hailctl', 
'config', 'set', 'gcs_requester_pays/buckets', 'bucket1,bucket2', echo=True) actual = get_gcs_requester_pays_configuration() assert actual == ('hailctl_project2', ['bucket1', 'bucket2']) diff --git a/hail/scripts/update-terra-image.py b/hail/scripts/update-terra-image.py index 9f9e4433ddb..3e7b28bb3d6 100644 --- a/hail/scripts/update-terra-image.py +++ b/hail/scripts/update-terra-image.py @@ -1,7 +1,6 @@ +import datetime import json import sys -import datetime - assert len(sys.argv) == 3, sys.argv hail_pip_version = sys.argv[1] @@ -46,8 +45,4 @@ def update_version_line(line): dockerfile = fobj.read() with open(f'{image_name}/Dockerfile', 'w') as fobj: - fobj.write( - '\n'.join( - [update_version_line(line) for line in dockerfile.split('\n')] - ) - ) + fobj.write('\n'.join([update_version_line(line) for line in dockerfile.split('\n')])) diff --git a/hail/scripts/upload_qob_jar.sh b/hail/scripts/upload_qob_jar.sh index 7aeb20b6c77..1ca4f1cc7e5 100644 --- a/hail/scripts/upload_qob_jar.sh +++ b/hail/scripts/upload_qob_jar.sh @@ -22,5 +22,9 @@ else JAR_LOCATION="${TEST_STORAGE_URI}/${NAMESPACE}/jars/${TOKEN}/${REVISION}.jar" fi -python3 -m hailtop.aiotools.copy -vvv 'null' '[{"from":"'${SHADOW_JAR}'", "to":"'${JAR_LOCATION}'"}]' +python3 -m hailtop.aiotools.copy \ + -vvv \ + 'null' \ + '[{"from":"'${SHADOW_JAR}'", "to":"'${JAR_LOCATION}'"}]' \ + --timeout 600 echo ${JAR_LOCATION} > ${PATH_FILE} diff --git a/hail/settings.gradle b/hail/settings.gradle deleted file mode 100644 index fe9cdb5db97..00000000000 --- a/hail/settings.gradle +++ /dev/null @@ -1,2 +0,0 @@ -rootProject.name='hail' -include 'shadedazure' diff --git a/hail/shadedazure/build.gradle b/hail/shadedazure/build.gradle deleted file mode 100644 index cc5df468676..00000000000 --- a/hail/shadedazure/build.gradle +++ /dev/null @@ -1,57 +0,0 @@ -buildscript { - repositories { - mavenCentral() - } -} - -plugins { - id 'java' - id 'com.github.johnrengelman.shadow' -} - -repositories { - mavenCentral() -} - -dependencies { - implementation group: 'com.azure', name: 'azure-storage-blob', version: '12.22.0' - implementation group: 'com.azure', name: 'azure-core-http-netty', version: '1.13.7' - implementation group: 'com.azure', name: 'azure-identity', version:'1.8.3' -} - -import com.github.jengelman.gradle.plugins.shadow.tasks.ShadowJar -tasks.withType(ShadowJar) { - archiveBaseName = 'shadedazure' - zip64 true - - // generate the jar then lightly prune the output of: - // jar tf shadedazure/build/libs/shadedazure-all.jar \ - // | awk -F'/' '{print "relocate '\''"$1"."$2"'\'', '\''is.hail.shadedazure."$1"."$2"'\''"}' \ - // | sort -u - relocate 'com.azure', 'is.hail.shadedazure.com.azure' - relocate 'com.ctc', 'is.hail.shadedazure.com.ctc' - relocate 'com.fasterxml', 'is.hail.shadedazure.com.fasterxml' - relocate 'com.microsoft', 'is.hail.shadedazure.com.microsoft' - relocate 'com.nimbusds', 'is.hail.shadedazure.com.nimbusds' - relocate 'com.sun', 'is.hail.shadedazure.com.sun' - relocate 'io.netty', 'is.hail.shadedazure.io.netty' - relocate 'is.hail', 'is.hail.shadedazure.is.hail' - relocate 'net.jcip', 'is.hail.shadedazure.net.jcip' - relocate 'net.minidev', 'is.hail.shadedazure.net.minidev' - relocate 'org.apache', 'is.hail.shadedazure.org.apache' - relocate 'org.codehaus', 'is.hail.shadedazure.org.codehaus' - relocate 'org.objectweb', 'is.hail.shadedazure.org.objectweb' - relocate 'org.reactivestreams', 'is.hail.shadedazure.org.reactivestreams' - relocate 'org.slf4j', 'is.hail.shadedazure.org.slf4j' - relocate 'reactor.adapter', 
'is.hail.shadedazure.reactor.adapter' - relocate 'reactor.core', 'is.hail.shadedazure.reactor.core' - relocate 'reactor.netty', 'is.hail.shadedazure.reactor.netty' - relocate 'reactor.util', 'is.hail.shadedazure.reactor.util' - - exclude 'META-INF/*.RSA' - exclude 'META-INF/*.SF' - exclude 'META-INF/*.DSA' -} - -// you can make the jar from the parent directory with ./gradlew :shadedazure:shadowJar -shadowJar {} diff --git a/hail/src/debug/scala/is/hail/annotations/Memory.java b/hail/src/debug/java/is/hail/annotations/Memory.java similarity index 100% rename from hail/src/debug/scala/is/hail/annotations/Memory.java rename to hail/src/debug/java/is/hail/annotations/Memory.java diff --git a/hail/src/main/java/is/hail/io/compress/BGzipCodec.java b/hail/src/main/java/is/hail/io/compress/BGzipCodec.java index 9b313e1c826..1423eed6b03 100644 --- a/hail/src/main/java/is/hail/io/compress/BGzipCodec.java +++ b/hail/src/main/java/is/hail/io/compress/BGzipCodec.java @@ -47,12 +47,12 @@ public SplitCompressionInputStream createInputStream(InputStream seekableIn, } @Override - public CompressionOutputStream createOutputStream(OutputStream out) throws IOException { + public CompressionOutputStream createOutputStream(OutputStream out) { return new BGzipOutputStream(out); } @Override - public CompressionOutputStream createOutputStream(OutputStream out, Compressor compressor) throws IOException { + public CompressionOutputStream createOutputStream(OutputStream out, Compressor compressor) { return createOutputStream(out); } diff --git a/hail/src/main/java/is/hail/io/compress/BGzipConstants.java b/hail/src/main/java/is/hail/io/compress/BGzipConstants.java new file mode 100644 index 00000000000..802593ee94b --- /dev/null +++ b/hail/src/main/java/is/hail/io/compress/BGzipConstants.java @@ -0,0 +1,78 @@ +package is.hail.io.compress; + +public interface BGzipConstants { + + /** + * Number of bytes in the gzip block before the deflated data. + */ + int blockHeaderLength = 18; + + /** + * Location in the gzip block of the total block size (actually total block size - 1) + */ + int blockLengthOffset = 16; + + /** + * Number of bytes that follow the deflated data + */ + int blockFooterLength = 8; + + /** + * We require that a compressed block (including header and footer, be <= this) + */ + int maxCompressedBlockSize = 64 * 1024; + + /** + * Gzip overhead is the header, the footer, and the block size (encoded as a short). + */ + int gzipOverhead = blockHeaderLength + blockFooterLength + 2; + + /** + * If Deflater has compression level == NO_COMPRESSION, 10 bytes of overhead (determined experimentally). + */ + int noCompressionOverhead = 10; + + /** + * Push out a gzip block when this many uncompressed bytes have been accumulated. 
+ */ + int defaultUncompressedBlockSize = 64 * 1024 - (gzipOverhead + noCompressionOverhead); + + // gzip magic numbers + + int gzipId1 = 31; + int gzipId2 = 139; + + int gzipModificationTime = 0; + + /** + * set extra fields to true + */ + int gzipFlag = 4; + + /** + * extra flags + */ + int gzipXFL = 0; + + /** + * length of extra subfield + */ + int gzipXLEN = 6; + + /** + * The deflate compression, which is customarily used by gzip + */ + int gzipCMDeflate = 8; + int defaultCompressionLevel = 5; + int gzipOsUnknown = 255; + int bgzfId1 = 66; + int bgzfId2 = 67; + int bgzfLen = 2; + + byte[] emptyGzipBlock = new byte[]{ + 0x1f, (byte) 0x8b, 0x08, 0x04, 0x00, 0x00, 0x00, 0x00, + 0x00, (byte) 0xff, 0x06, 0x00, 0x42, 0x43, 0x02, 0x00, + 0x1b, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00 + }; +} diff --git a/hail/src/main/java/is/hail/io/compress/BGzipOutputStream.java b/hail/src/main/java/is/hail/io/compress/BGzipOutputStream.java new file mode 100644 index 00000000000..757bdb22e8c --- /dev/null +++ b/hail/src/main/java/is/hail/io/compress/BGzipOutputStream.java @@ -0,0 +1,143 @@ +package is.hail.io.compress; + +import org.apache.hadoop.io.compress.CompressionOutputStream; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.zip.CRC32; +import java.util.zip.Deflater; + +public class BGzipOutputStream extends CompressionOutputStream { + + private final byte[] uncompressedBuffer = + new byte[BGzipConstants.defaultUncompressedBlockSize]; + private final byte[] compressedBuffer = + new byte[BGzipConstants.maxCompressedBlockSize - BGzipConstants.blockHeaderLength]; + private final Deflater deflater = + new Deflater(BGzipConstants.defaultCompressionLevel, true); + private final Deflater noCompressionDeflater = + new Deflater(Deflater.NO_COMPRESSION, true); + private final CRC32 crc32 = new CRC32(); + + private boolean finished; + protected int numUncompressedBytes; + + + public BGzipOutputStream(OutputStream out) { + super(out); + finished = false; + } + + @Override + public void write(int b) throws IOException { + assert (numUncompressedBytes < uncompressedBuffer.length); + uncompressedBuffer[numUncompressedBytes] = (byte) b; + numUncompressedBytes += 1; + + if (numUncompressedBytes == uncompressedBuffer.length) { + deflateBlock(); + } + } + + @Override + public void write(byte[] bytes, int offset, int length) throws IOException { + assert (numUncompressedBytes < uncompressedBuffer.length); + + int currentPosition = offset; + int numBytesRemaining = length; + + while (numBytesRemaining > 0) { + int bytesToWrite = + Math.min(uncompressedBuffer.length - numUncompressedBytes, numBytesRemaining); + System.arraycopy(bytes, currentPosition, uncompressedBuffer, numUncompressedBytes, bytesToWrite); + numUncompressedBytes += bytesToWrite; + currentPosition += bytesToWrite; + numBytesRemaining -= bytesToWrite; + assert (numBytesRemaining >= 0); + + if (numUncompressedBytes == uncompressedBuffer.length) + deflateBlock(); + } + } + + final protected void deflateBlock() throws IOException { + assert (numUncompressedBytes != 0); + assert (!finished); + + deflater.reset(); + deflater.setInput(uncompressedBuffer, 0, numUncompressedBytes); + deflater.finish(); + int compressedSize = deflater.deflate(compressedBuffer, 0, compressedBuffer.length); + + // If it didn't all fit in compressedBuffer.length, set compression level to NO_COMPRESSION + // and try again. This should always fit. 
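Aside, for orientation only and not code from this patch: the constants above describe the standard BGZF framing, in which each block is an ordinary gzip member whose FEXTRA subfield ("BC") carries the total block length minus one, followed by a raw-deflate payload and an eight-byte footer holding the CRC32 and the uncompressed length. The Python sketch below frames a single buffer that way; the helper name bgzf_block is invented for illustration and only the struct and zlib standard-library modules are assumed.

    import struct
    import zlib

    def bgzf_block(data: bytes, level: int = 5) -> bytes:
        # Raw deflate, i.e. no zlib/gzip wrapper (wbits = -15), like Deflater(level, nowrap = true).
        deflater = zlib.compressobj(level, zlib.DEFLATED, -15)
        payload = deflater.compress(data) + deflater.flush()
        # blockHeaderLength (18) + deflate payload + blockFooterLength (8)
        bsize = 18 + len(payload) + 8
        header = struct.pack(
            '<4BI2BH2B2H',
            31, 139, 8, 4,   # gzipId1, gzipId2, CM = deflate, FLG = FEXTRA
            0,               # MTIME
            0, 255,          # XFL, OS = unknown
            6,               # XLEN: one six-byte extra subfield
            66, 67,          # 'B', 'C' subfield identifiers
            2,               # SLEN: the subfield holds a two-byte value
            bsize - 1,       # total block size minus one, as in writeGzipBlock
        )
        footer = struct.pack('<II', zlib.crc32(data) & 0xFFFFFFFF, len(data) & 0xFFFFFFFF)
        return header + payload + footer

A BGZF file is then just a run of such blocks; BGzipOutputStream.finish terminates the run by appending the emptyGzipBlock constant defined above.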
+ if (!deflater.finished()) { + noCompressionDeflater.reset(); + noCompressionDeflater.setInput(uncompressedBuffer, 0, numUncompressedBytes); + noCompressionDeflater.finish(); + compressedSize = noCompressionDeflater.deflate(compressedBuffer, 0, compressedBuffer.length); + assert (noCompressionDeflater.finished()); + } + + // Data compressed small enough, so write it out. + crc32.reset(); + crc32.update(uncompressedBuffer, 0, numUncompressedBytes); + + writeGzipBlock(compressedSize, numUncompressedBytes, crc32.getValue()); + + numUncompressedBytes = 0; // reset variable + } + + public void writeInt8(int i) throws IOException { + out.write(i & 0xff); + } + + public void writeInt16(int i) throws IOException { + out.write(i & 0xff); + out.write((i >> 8) & 0xff); + } + + public void writeInt32(int i) throws IOException { + out.write(i & 0xff); + out.write((i >> 8) & 0xff); + out.write((i >> 16) & 0xff); + out.write((i >> 24) & 0xff); + } + + public int writeGzipBlock(int compressedSize, int bytesToCompress, long crc32val) throws IOException { + int totalBlockSize = compressedSize + BGzipConstants.blockHeaderLength + BGzipConstants.blockFooterLength; + + writeInt8(BGzipConstants.gzipId1); + writeInt8(BGzipConstants.gzipId2); + writeInt8(BGzipConstants.gzipCMDeflate); + writeInt8(BGzipConstants.gzipFlag); + writeInt32(BGzipConstants.gzipModificationTime); + writeInt8(BGzipConstants.gzipXFL); + writeInt8(BGzipConstants.gzipOsUnknown); + writeInt16(BGzipConstants.gzipXLEN); + writeInt8(BGzipConstants.bgzfId1); + writeInt8(BGzipConstants.bgzfId2); + writeInt16(BGzipConstants.bgzfLen); + writeInt16(totalBlockSize - 1); + out.write(compressedBuffer, 0, compressedSize); + writeInt32((int) crc32val); + writeInt32(bytesToCompress); + return totalBlockSize; + } + + @Override + public void resetState() { + throw new UnsupportedOperationException(); + } + + @Override + public void finish() throws IOException { + if (numUncompressedBytes != 0) + deflateBlock(); + + if (!finished) { + out.write(BGzipConstants.emptyGzipBlock); + finished = true; + } + } +} diff --git a/hail/src/main/java/is/hail/io/compress/ComposableBGzipOutputStream.java b/hail/src/main/java/is/hail/io/compress/ComposableBGzipOutputStream.java new file mode 100644 index 00000000000..6d469528cab --- /dev/null +++ b/hail/src/main/java/is/hail/io/compress/ComposableBGzipOutputStream.java @@ -0,0 +1,18 @@ +package is.hail.io.compress; + +import java.io.IOException; +import java.io.OutputStream; + +public final class ComposableBGzipOutputStream extends BGzipOutputStream { + + public ComposableBGzipOutputStream(OutputStream out) { + super(out); + } + + @Override + public void finish() throws IOException { + if (numUncompressedBytes != 0) { + deflateBlock(); + } + } +} diff --git a/hail/src/main/java/is/hail/io/fs/Positioned.java b/hail/src/main/java/is/hail/io/fs/Positioned.java new file mode 100644 index 00000000000..e74f37406e2 --- /dev/null +++ b/hail/src/main/java/is/hail/io/fs/Positioned.java @@ -0,0 +1,5 @@ +package is.hail.io.fs; + +public interface Positioned { + long getPosition(); +} diff --git a/hail/src/main/java/is/hail/io/fs/Seekable.java b/hail/src/main/java/is/hail/io/fs/Seekable.java new file mode 100644 index 00000000000..e1f048d7e2e --- /dev/null +++ b/hail/src/main/java/is/hail/io/fs/Seekable.java @@ -0,0 +1,5 @@ +package is.hail.io.fs; + +public interface Seekable extends Positioned { + void seek(long position); +} diff --git a/hail/src/main/scala/is/hail/HailContext.scala b/hail/src/main/scala/is/hail/HailContext.scala index 
4e4063378b8..71ef4eb0be8 100644 --- a/hail/src/main/scala/is/hail/HailContext.scala +++ b/hail/src/main/scala/is/hail/HailContext.scala @@ -2,24 +2,24 @@ package is.hail import is.hail.backend.Backend import is.hail.backend.spark.SparkBackend -import is.hail.expr.ir.BaseIR import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.io.fs.FS import is.hail.io.vcf._ import is.hail.types.virtual._ import is.hail.utils._ -import org.apache.log4j.{ConsoleAppender, LogManager, PatternLayout, PropertyConfigurator} + +import scala.reflect.ClassTag + +import java.io.InputStream +import java.util.Properties + +import org.apache.log4j.{LogManager, PropertyConfigurator} import org.apache.spark._ import org.apache.spark.executor.InputMetrics import org.apache.spark.rdd.RDD import org.json4s.Extraction import org.json4s.jackson.JsonMethods -import java.io.InputStream -import java.util.Properties -import scala.collection.mutable -import scala.reflect.ClassTag - case class FilePartition(index: Int, file: String) extends Partition object HailContext { @@ -78,11 +78,12 @@ object HailContext { versionString match { // old-style version: 1.MAJOR.MINOR // new-style version: MAJOR.MINOR.SECURITY (started in JRE 9) - // see: https://docs.oracle.com/javase/9/migrate/toc.htm#JSMIG-GUID-3A71ECEF-5FC5-46FE-9BA9-88CBFCE828CB - case javaVersion("1", major, minor) => + /* see: + * https://docs.oracle.com/javase/9/migrate/toc.htm#JSMIG-GUID-3A71ECEF-5FC5-46FE-9BA9-88CBFCE828CB */ + case javaVersion("1", major, _) => if (major.toInt < 8) fatal(s"Hail requires Java 1.8, found $versionString") - case javaVersion(major, minor, security) => + case javaVersion(major, _, _) => if (major.toInt != 11) fatal(s"Hail requires Java 8 or 11, found $versionString") case _ => @@ -90,24 +91,26 @@ object HailContext { } } - def getOrCreate(backend: Backend, - branchingFactor: Int = 50, - optimizerIterations: Int = 3): HailContext = { + def getOrCreate(backend: Backend, branchingFactor: Int = 50, optimizerIterations: Int = 3) + : HailContext = { if (theContext == null) return HailContext(backend, branchingFactor, optimizerIterations) if (theContext.branchingFactor != branchingFactor) - warn(s"Requested branchingFactor $branchingFactor, but already initialized to ${ theContext.branchingFactor }. Ignoring requested setting.") + warn( + s"Requested branchingFactor $branchingFactor, but already initialized to ${theContext.branchingFactor}. Ignoring requested setting." + ) if (theContext.optimizerIterations != optimizerIterations) - warn(s"Requested optimizerIterations $optimizerIterations, but already initialized to ${ theContext.optimizerIterations }. Ignoring requested setting.") + warn( + s"Requested optimizerIterations $optimizerIterations, but already initialized to ${theContext.optimizerIterations}. Ignoring requested setting." 
+ ) theContext } - def apply(backend: Backend, - branchingFactor: Int = 50, - optimizerIterations: Int = 3): HailContext = synchronized { + def apply(backend: Backend, branchingFactor: Int = 50, optimizerIterations: Int = 3) + : HailContext = synchronized { require(theContext == null) checkJavaVersion() @@ -115,13 +118,19 @@ object HailContext { import breeze.linalg._ import breeze.linalg.operators.{BinaryRegistry, OpMulMatrix} - implicitly[BinaryRegistry[DenseMatrix[Double], Vector[Double], OpMulMatrix.type, DenseVector[Double]]].register( - DenseMatrix.implOpMulMatrix_DMD_DVD_eq_DVD) + implicitly[BinaryRegistry[ + DenseMatrix[Double], + Vector[Double], + OpMulMatrix.type, + DenseVector[Double], + ]].register( + DenseMatrix.implOpMulMatrix_DMD_DVD_eq_DVD + ) } theContext = new HailContext(backend, branchingFactor, optimizerIterations) - info(s"Running Hail version ${ theContext.version }") + info(s"Running Hail version ${theContext.version}") theContext } @@ -138,7 +147,8 @@ object HailContext { path: String, partFiles: IndexedSeq[String], read: (Int, InputStream, InputMetrics) => Iterator[T], - optPartitioner: Option[Partitioner] = None): RDD[T] = { + optPartitioner: Option[Partitioner] = None, + ): RDD[T] = { val nPartitions = partFiles.length val fsBc = fs.broadcast @@ -159,10 +169,11 @@ object HailContext { } } -class HailContext private( +class HailContext private ( var backend: Backend, val branchingFactor: Int, - val optimizerIterations: Int) { + val optimizerIterations: Int, +) { def stop(): Unit = HailContext.stop() def sparkBackend(op: String): SparkBackend = backend.asSpark(op) @@ -175,7 +186,7 @@ class HailContext private( fs: FS, regex: String, files: Seq[String], - maxLines: Int + maxLines: Int, ): Map[String, Array[WithContext[String]]] = { val regexp = regex.r SparkBackend.sparkContext("fileAndLineCounts").textFilesLines(fs.globAll(files).map(_.getPath)) @@ -184,9 +195,9 @@ class HailContext private( .groupBy(_.source.file) } - def grepPrint(fs: FS, regex: String, files: Seq[String], maxLines: Int) { + def grepPrint(fs: FS, regex: String, files: Seq[String], maxLines: Int): Unit = { fileAndLineCounts(fs, regex, files, maxLines).foreach { case (file, lines) => - info(s"$file: ${ lines.length } ${ plural(lines.length, "match", "matches") }:") + info(s"$file: ${lines.length} ${plural(lines.length, "match", "matches")}:") lines.map(_.value).foreach { line => val (screen, logged) = line.truncatable().strings log.info("\t" + logged) @@ -195,12 +206,12 @@ class HailContext private( } } - def grepReturn(fs: FS, regex: String, files: Seq[String], maxLines: Int): Array[(String, Array[String])] = + def grepReturn(fs: FS, regex: String, files: Seq[String], maxLines: Int) + : Array[(String, Array[String])] = fileAndLineCounts(fs: FS, regex, files, maxLines).mapValues(_.map(_.value)).toArray - def parseVCFMetadata(fs: FS, file: String): Map[String, Map[String, Map[String, String]]] = { + def parseVCFMetadata(fs: FS, file: String): Map[String, Map[String, Map[String, String]]] = LoadVCF.parseHeaderMetadata(fs, Set.empty, TFloat64, file) - } def pyParseVCFMetadataJSON(fs: FS, file: String): String = { val metadata = LoadVCF.parseHeaderMetadata(fs, Set.empty, TFloat64, file) diff --git a/hail/src/main/scala/is/hail/HailFeatureFlags.scala b/hail/src/main/scala/is/hail/HailFeatureFlags.scala index 225baf4e274..3dcf4829797 100644 --- a/hail/src/main/scala/is/hail/HailFeatureFlags.scala +++ b/hail/src/main/scala/is/hail/HailFeatureFlags.scala @@ -2,10 +2,11 @@ package is.hail import 
is.hail.backend.ExecutionCache import is.hail.utils._ -import org.json4s.JsonAST.{JArray, JObject, JString} import scala.collection.mutable +import org.json4s.JsonAST.{JArray, JObject, JString} + object HailFeatureFlags { val defaults = Map[String, (String, String)]( // Must match __flags_env_vars_and_defaults in hail/backend/backend.py @@ -75,9 +76,10 @@ class HailFeatureFlags private ( def toJSONEnv: JArray = JArray(flags.filter { case (_, v) => v != null - }.map{ case (name, v) => + }.map { case (name, v) => JObject( "name" -> JString(HailFeatureFlags.defaults(name)._1), - "value" -> JString(v)) + "value" -> JString(v), + ) }.toList) } diff --git a/hail/src/main/scala/is/hail/annotations/Annotation.scala b/hail/src/main/scala/is/hail/annotations/Annotation.scala index 390d111d5f7..9f483f0c767 100644 --- a/hail/src/main/scala/is/hail/annotations/Annotation.scala +++ b/hail/src/main/scala/is/hail/annotations/Annotation.scala @@ -2,6 +2,7 @@ package is.hail.annotations import is.hail.types.virtual._ import is.hail.utils._ + import org.apache.spark.sql.Row object Annotation { @@ -33,14 +34,21 @@ object Annotation { case t: TInterval => val i = a.asInstanceOf[Interval] - i.copy(start = Annotation.copy(t.pointType, i.start), end = Annotation.copy(t.pointType, i.end)) + i.copy( + start = Annotation.copy(t.pointType, i.start), + end = Annotation.copy(t.pointType, i.end), + ) case t: TNDArray => val nd = a.asInstanceOf[NDArray] val rme = nd.getRowMajorElements() - SafeNDArray(nd.shape, Array.tabulate(rme.length)(i => Annotation.copy(t.elementType, rme(i))).toFastSeq) + SafeNDArray( + nd.shape, + Array.tabulate(rme.length)(i => Annotation.copy(t.elementType, rme(i))).toFastSeq, + ) - case TInt32 | TInt64 | TFloat32 | TFloat64 | TBoolean | TString | TCall | _: TLocus | TBinary => a + case TInt32 | TInt64 | TFloat32 | TFloat64 | TBoolean | TString | TCall | _: TLocus | TBinary => + a } } } diff --git a/hail/src/main/scala/is/hail/annotations/BroadcastValue.scala b/hail/src/main/scala/is/hail/annotations/BroadcastValue.scala index 18197d935d0..885e98d90c8 100644 --- a/hail/src/main/scala/is/hail/annotations/BroadcastValue.scala +++ b/hail/src/main/scala/is/hail/annotations/BroadcastValue.scala @@ -1,19 +1,22 @@ package is.hail.annotations -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} import is.hail.asm4s.HailClassLoader import is.hail.backend.{BroadcastValue, ExecuteContext} import is.hail.expr.ir.EncodedLiteral +import is.hail.io.{BufferSpec, Decoder, TypedCodecSpec} import is.hail.types.physical.{PArray, PStruct, PType} import is.hail.types.virtual.{TBaseStruct, TStruct} -import is.hail.io.{BufferSpec, Decoder, TypedCodecSpec} -import is.hail.utils.{ArrayOfByteArrayOutputStream, formatSpace, log} +import is.hail.utils.{formatSpace, log, ArrayOfByteArrayOutputStream} import is.hail.utils.prettyPrint.ArrayOfByteArrayInputStream + +import java.io.InputStream + import org.apache.spark.sql.Row case class SerializableRegionValue( - encodedValue: Array[Array[Byte]], t: PType, - makeDecoder: (InputStream, HailClassLoader) => Decoder + encodedValue: Array[Array[Byte]], + t: PType, + makeDecoder: (InputStream, HailClassLoader) => Decoder, ) { def readRegionValue(r: Region, theHailClassLoader: HailClassLoader): Long = { val dec = makeDecoder(new ArrayOfByteArrayInputStream(encodedValue), theHailClassLoader) @@ -69,7 +72,9 @@ trait BroadcastRegionValue { if (broadcasted == null) { val arrays = encodeToByteArrays(theHailClassLoader) val totalSize = arrays.map(_.length).sum - 
log.info(s"BroadcastRegionValue.broadcast: broadcasting ${ arrays.length } byte arrays of total size $totalSize (${ formatSpace(totalSize) }") + log.info( + s"BroadcastRegionValue.broadcast: broadcasting ${arrays.length} byte arrays of total size $totalSize (${formatSpace(totalSize)}" + ) val srv = SerializableRegionValue(arrays, decodedPType, makeDec) broadcasted = ctx.backend.broadcast(srv) } @@ -83,17 +88,16 @@ trait BroadcastRegionValue { def safeJavaValue: Any override def equals(obj: Any): Boolean = obj match { - case b: BroadcastRegionValue => t == b.t && (ctx eq b.ctx) && t.unsafeOrdering(ctx.stateManager).compare(value, b.value) == 0 + case b: BroadcastRegionValue => + t == b.t && (ctx eq b.ctx) && t.unsafeOrdering(ctx.stateManager).compare(value, b.value) == 0 case _ => false } override def hashCode(): Int = javaValue.hashCode() } -case class BroadcastRow(ctx: ExecuteContext, - value: RegionValue, - t: PStruct -) extends BroadcastRegionValue { +case class BroadcastRow(ctx: ExecuteContext, value: RegionValue, t: PStruct) + extends BroadcastRegionValue { def javaValue: UnsafeRow = UnsafeRow.readBaseStruct(t, value.region, value.offset) @@ -104,20 +108,24 @@ case class BroadcastRow(ctx: ExecuteContext, if (t == newT) return this - BroadcastRow(ctx, - RegionValue(value.region, newT.copyFromAddress(ctx.stateManager, value.region, t, value.offset, deepCopy = false)), - newT) + BroadcastRow( + ctx, + RegionValue( + value.region, + newT.copyFromAddress(ctx.stateManager, value.region, t, value.offset, deepCopy = false), + ), + newT, + ) } - def toEncodedLiteral(theHailClassLoader: HailClassLoader): EncodedLiteral = { + def toEncodedLiteral(theHailClassLoader: HailClassLoader): EncodedLiteral = EncodedLiteral(encoding, encodeToByteArrays(theHailClassLoader)) - } } case class BroadcastIndexedSeq( ctx: ExecuteContext, value: RegionValue, - t: PArray + t: PArray, ) extends BroadcastRegionValue { def safeJavaValue: IndexedSeq[Row] = SafeRow.read(t, value).asInstanceOf[IndexedSeq[Row]] @@ -129,8 +137,13 @@ case class BroadcastIndexedSeq( if (t == newT) return this - BroadcastIndexedSeq(ctx, - RegionValue(value.region, newT.copyFromAddress(ctx.stateManager, value.region, t, value.offset, deepCopy = false)), - newT) + BroadcastIndexedSeq( + ctx, + RegionValue( + value.region, + newT.copyFromAddress(ctx.stateManager, value.region, t, value.offset, deepCopy = false), + ), + newT, + ) } } diff --git a/hail/src/main/scala/is/hail/annotations/ChunkCache.scala b/hail/src/main/scala/is/hail/annotations/ChunkCache.scala index c8b8dec7001..6e389f8d3a6 100644 --- a/hail/src/main/scala/is/hail/annotations/ChunkCache.scala +++ b/hail/src/main/scala/is/hail/annotations/ChunkCache.scala @@ -2,29 +2,26 @@ package is.hail.annotations import is.hail.expr.ir.LongArrayBuilder +import scala.collection.mutable + import java.util.TreeMap import java.util.function.BiConsumer -import scala.collection.mutable - /** - * ChunkCache minimizes calls to free and allocate by holding onto - * chunks when they are no longer in use. When a chunk is needed, the cache - * is searched. If the size requested is less than a certain amount, the size - * is rounded up to the nearest power of 2 and the small chunk cache is checked - * for available chunk. If bigger, the big chunk cache returns the chunk whose size - * is the ceiling match. If the size requested is at least 90 percent of the size of - * the chunk returned, then that chunk is used. If no acceptable chunk is found, a new - * chunk is created. 
If the chunk created plus the current allocation is greater than - * peak usage, than chunks from the cache are deallocated until this condition is not - * true or the cache is empty. - * When freeChunk is called on RegionPool, the chunks get put in the cache that - * corresponds to their size. freeAll releases all chunks and is called when - * RegionPool is closed. - */ - -private class ChunkCache (allocator: Long => Long, freer: Long => Unit){ +/** ChunkCache minimizes calls to free and allocate by holding onto chunks when they are no longer + * in use. When a chunk is needed, the cache is searched. If the size requested is less than a + * certain amount, the size is rounded up to the nearest power of 2 and the small chunk cache is + * checked for available chunk. If bigger, the big chunk cache returns the chunk whose size is the + * ceiling match. If the size requested is at least 90 percent of the size of the chunk returned, + * then that chunk is used. If no acceptable chunk is found, a new chunk is created. If the chunk + * created plus the current allocation is greater than peak usage, than chunks from the cache are + * deallocated until this condition is not true or the cache is empty. When freeChunk is called on + * RegionPool, the chunks get put in the cache that corresponds to their size. freeAll releases all + * chunks and is called when RegionPool is closed. + */ + +private class ChunkCache(allocator: Long => Long, freer: Long => Unit) { private[this] val highestSmallChunkPowerOf2 = 24 - private[this] val biggestSmallChunk = Math.pow(2,highestSmallChunkPowerOf2) + private[this] val biggestSmallChunk = Math.pow(2, highestSmallChunkPowerOf2) private[this] val bigChunkCache = new TreeMap[Long, LongArrayBuilder]() private[this] val chunksEncountered = mutable.Map[Long, Long]() private[this] val minSpaceRequirements = .9 @@ -32,13 +29,14 @@ private class ChunkCache (allocator: Long => Long, freer: Long => Unit){ private[this] var cacheHits = 0 private[this] var smallChunkCacheSize = 0 private[this] val smallChunkCache = new Array[LongArrayBuilder](highestSmallChunkPowerOf2 + 1) - (0 until highestSmallChunkPowerOf2 + 1).foreach(index => { + + (0 until highestSmallChunkPowerOf2 + 1).foreach { index => smallChunkCache(index) = new LongArrayBuilder() - }) + } - def getChunkSize(chunkPointer: Long): Long = chunksEncountered(chunkPointer) + def getChunkSize(chunkPointer: Long): Long = chunksEncountered(chunkPointer) - def freeChunkFromMemory(pool: RegionPool, chunkPointer: Long):Unit = { + def freeChunkFromMemory(pool: RegionPool, chunkPointer: Long): Unit = { val size = chunksEncountered(chunkPointer) pool.decrementAllocatedBytes(size) freer(chunkPointer) @@ -47,14 +45,15 @@ private class ChunkCache (allocator: Long => Long, freer: Long => Unit){ def freeChunksFromCacheToFit(pool: RegionPool, sizeToFit: Long): Unit = { var smallChunkIndex = highestSmallChunkPowerOf2 - while((sizeToFit + pool.getTotalAllocatedBytes) > pool.getHighestTotalUsage && - smallChunkIndex >= 0 && !chunksEncountered.isEmpty) { + while ( + (sizeToFit + pool.getTotalAllocatedBytes) > pool.getHighestTotalUsage && + smallChunkIndex >= 0 && !chunksEncountered.isEmpty + ) { if (!bigChunkCache.isEmpty) { val toFree = bigChunkCache.lastEntry() freeChunkFromMemory(pool, toFree.getValue.pop()) if (toFree.getValue.size == 0) bigChunkCache.remove(toFree.getKey) - } - else { + } else { if (smallChunkCacheSize == 0) smallChunkIndex = -1 else { val toFree = smallChunkCache(smallChunkIndex) @@ -80,12 +79,13 @@ private class 
ChunkCache (allocator: Long => Long, freer: Long => Unit){ def freeAll(pool: RegionPool): Unit = { if (!chunksEncountered.isEmpty) { - smallChunkCache.foreach(ab => { + smallChunkCache.foreach { ab => while (ab.size > 0) { freeChunkFromMemory(pool, ab.pop()) smallChunkCacheSize -= 1 - }}) - //BiConsumer needed to work with scala 2.11.12 + } + } + // BiConsumer needed to work with scala 2.11.12 bigChunkCache.forEach(new BiConsumer[Long, LongArrayBuilder]() { def accept(key: Long, value: LongArrayBuilder): Unit = while (value.size > 0) freeChunkFromMemory(pool, value.pop()) @@ -93,14 +93,13 @@ private class ChunkCache (allocator: Long => Long, freer: Long => Unit){ } } - def getUsage(): (Int, Int) = { + def getUsage(): (Int, Int) = (chunksRequested, cacheHits) - } def indexInSmallChunkCache(size: Long): Int = { var closestPower = highestSmallChunkPowerOf2 - while((size >> closestPower) != 1) closestPower = closestPower - 1 - if (size % (1 << closestPower) != 0) closestPower +=1 + while ((size >> closestPower) != 1) closestPower = closestPower - 1 + if (size % (1 << closestPower) != 0) closestPower += 1 closestPower } @@ -109,25 +108,24 @@ private class ChunkCache (allocator: Long => Long, freer: Long => Unit){ assert(size > 0L) if (size <= biggestSmallChunk) { val closestPower = indexInSmallChunkCache(size) - if(smallChunkCache(closestPower).size == 0 ) { + if (smallChunkCache(closestPower).size == 0) { val sizePowerOf2 = (1 << closestPower).toLong (newChunk(pool, sizePowerOf2), sizePowerOf2) - } - else { + } else { cacheHits += 1 (smallChunkCache(closestPower).pop(), size) } - } - else { + } else { val closestSize = bigChunkCache.ceilingEntry(size) - if (closestSize != null && (closestSize.getKey == size - || ((closestSize.getKey * minSpaceRequirements) <= size))) { + if ( + closestSize != null && (closestSize.getKey == size + || ((closestSize.getKey * minSpaceRequirements) <= size)) + ) { cacheHits += 1 val chunkPointer = closestSize.getValue.pop() if (closestSize.getValue.size == 0) bigChunkCache.remove(closestSize.getKey) (chunkPointer, size) - } - else (newChunk(pool, size), size) + } else (newChunk(pool, size), size) } } @@ -136,21 +134,16 @@ private class ChunkCache (allocator: Long => Long, freer: Long => Unit){ if (chunkSize <= biggestSmallChunk) { smallChunkCache(indexInSmallChunkCache(chunkSize)) += chunkPointer smallChunkCacheSize += 1 - } - else { + } else { val sameSizeEntries = bigChunkCache.get(chunkSize) if (sameSizeEntries == null) { val newSize = new LongArrayBuilder() newSize += chunkPointer bigChunkCache.put(chunkSize, newSize) - } - else sameSizeEntries += chunkPointer + } else sameSizeEntries += chunkPointer } } - def freeChunksToCache( ab: LongArrayBuilder): Unit = { + def freeChunksToCache(ab: LongArrayBuilder): Unit = while (ab.size > 0) freeChunkToCache(ab.pop()) - } } - - diff --git a/hail/src/main/scala/is/hail/annotations/ExtendedOrdering.scala b/hail/src/main/scala/is/hail/annotations/ExtendedOrdering.scala index bab800578dc..1fe31337e6d 100644 --- a/hail/src/main/scala/is/hail/annotations/ExtendedOrdering.scala +++ b/hail/src/main/scala/is/hail/annotations/ExtendedOrdering.scala @@ -1,6 +1,7 @@ package is.hail.annotations import is.hail.utils._ + import org.apache.spark.sql.Row object ExtendedOrdering { @@ -14,7 +15,8 @@ object ExtendedOrdering { override def lteqNonnull(x: T, y: T): Boolean = ord.lteq(x.asInstanceOf[S], y.asInstanceOf[S]) - override def equivNonnull(x: T, y: T): Boolean = ord.equiv(x.asInstanceOf[S], y.asInstanceOf[S]) + override def 
equivNonnull(x: T, y: T): Boolean = + ord.equiv(x.asInstanceOf[S], y.asInstanceOf[S]) } } @@ -91,29 +93,29 @@ object ExtendedOrdering { // ord can be null if the element type is a TVariable val elemOrd = if (ord != null) ord.toOrdering else null - def compareNonnull(x: T, y: T): Int = { + def compareNonnull(x: T, y: T): Int = itOrd.compareNonnull( x.asInstanceOf[Array[T]].sorted(elemOrd).toFastSeq, - y.asInstanceOf[Array[T]].sorted(elemOrd).toFastSeq) - } + y.asInstanceOf[Array[T]].sorted(elemOrd).toFastSeq, + ) - override def ltNonnull(x: T, y: T): Boolean = { + override def ltNonnull(x: T, y: T): Boolean = itOrd.ltNonnull( x.asInstanceOf[Array[T]].sorted(elemOrd).toFastSeq, - y.asInstanceOf[Array[T]].sorted(elemOrd).toFastSeq) - } + y.asInstanceOf[Array[T]].sorted(elemOrd).toFastSeq, + ) - override def lteqNonnull(x: T, y: T): Boolean = { + override def lteqNonnull(x: T, y: T): Boolean = itOrd.lteqNonnull( x.asInstanceOf[Array[T]].sorted(elemOrd).toFastSeq, - y.asInstanceOf[Array[T]].sorted(elemOrd).toFastSeq) - } + y.asInstanceOf[Array[T]].sorted(elemOrd).toFastSeq, + ) - override def equivNonnull(x: T, y: T): Boolean = { + override def equivNonnull(x: T, y: T): Boolean = itOrd.equivNonnull( x.asInstanceOf[Array[T]].sorted(elemOrd).toFastSeq, - y.asInstanceOf[Array[T]].sorted(elemOrd).toFastSeq) - } + y.asInstanceOf[Array[T]].sorted(elemOrd).toFastSeq, + ) } def setOrdering(ord: ExtendedOrdering, _missingEqual: Boolean = true): ExtendedOrdering = @@ -122,29 +124,29 @@ object ExtendedOrdering { val missingEqual = _missingEqual - def compareNonnull(x: T, y: T): Int = { + def compareNonnull(x: T, y: T): Int = saOrd.compareNonnull( x.asInstanceOf[Iterable[T]].toArray, - y.asInstanceOf[Iterable[T]].toArray) - } + y.asInstanceOf[Iterable[T]].toArray, + ) - override def ltNonnull(x: T, y: T): Boolean = { + override def ltNonnull(x: T, y: T): Boolean = saOrd.ltNonnull( x.asInstanceOf[Iterable[T]].toArray, - y.asInstanceOf[Iterable[T]].toArray) - } + y.asInstanceOf[Iterable[T]].toArray, + ) - override def lteqNonnull(x: T, y: T): Boolean = { + override def lteqNonnull(x: T, y: T): Boolean = saOrd.lteqNonnull( x.asInstanceOf[Iterable[T]].toArray, - y.asInstanceOf[Iterable[T]].toArray) - } + y.asInstanceOf[Iterable[T]].toArray, + ) - override def equivNonnull(x: T, y: T): Boolean = { + override def equivNonnull(x: T, y: T): Boolean = saOrd.equivNonnull( x.asInstanceOf[Iterable[T]].toArray, - y.asInstanceOf[Iterable[T]].toArray) - } + y.asInstanceOf[Iterable[T]].toArray, + ) } def mapOrdering(ord: ExtendedOrdering, _missingEqual: Boolean = true): ExtendedOrdering = @@ -156,31 +158,35 @@ object ExtendedOrdering { private def toArrayOfT(x: T): Array[T] = x.asInstanceOf[Map[_, _]].iterator.map { case (k, v) => Row(k, v): T }.toArray - def compareNonnull(x: T, y: T): Int = { + def compareNonnull(x: T, y: T): Int = saOrd.compareNonnull( - toArrayOfT(x), toArrayOfT(y)) - } + toArrayOfT(x), + toArrayOfT(y), + ) - override def ltNonnull(x: T, y: T): Boolean = { + override def ltNonnull(x: T, y: T): Boolean = saOrd.ltNonnull( - toArrayOfT(x), toArrayOfT(y)) - } + toArrayOfT(x), + toArrayOfT(y), + ) - override def lteqNonnull(x: T, y: T): Boolean = { + override def lteqNonnull(x: T, y: T): Boolean = saOrd.lteqNonnull( - toArrayOfT(x), toArrayOfT(y)) - } + toArrayOfT(x), + toArrayOfT(y), + ) - override def equivNonnull(x: T, y: T): Boolean = { + override def equivNonnull(x: T, y: T): Boolean = saOrd.equivNonnull( - toArrayOfT(x), toArrayOfT(y)) - } + toArrayOfT(x), + toArrayOfT(y), + ) } - def 
rowOrdering(fieldOrd: Array[ExtendedOrdering], _missingEqual: Boolean = true): ExtendedOrdering = + def rowOrdering(fieldOrd: Array[ExtendedOrdering], _missingEqual: Boolean = true) + : ExtendedOrdering = new ExtendedOrdering { outer => - val missingEqual = _missingEqual override def compareNonnull(x: T, y: T): Int = { @@ -281,7 +287,7 @@ abstract class ExtendedOrdering extends Serializable { def ltNonnull(x: T, y: T): Boolean = compareNonnull(x, y) < 0 - def lteqNonnull(x: T, y: T): Boolean = compareNonnull(x, y) <= 0 + def lteqNonnull(x: T, y: T): Boolean = compareNonnull(x, y) <= 0 def equivNonnull(x: T, y: T): Boolean = compareNonnull(x, y) == 0 @@ -401,7 +407,8 @@ abstract class ExtendedOrdering extends Serializable { Integer.compare(xs, ys) } - override def lteqWithOverlap(allowedOverlap: Int)(x: IntervalEndpoint, y: IntervalEndpoint): Boolean = { + override def lteqWithOverlap(allowedOverlap: Int)(x: IntervalEndpoint, y: IntervalEndpoint) + : Boolean = { val xp = x.point val xs = x.sign val yp = y.point diff --git a/hail/src/main/scala/is/hail/annotations/JoinedRegionValue.scala b/hail/src/main/scala/is/hail/annotations/JoinedRegionValue.scala index 0969addc243..f387820c97c 100644 --- a/hail/src/main/scala/is/hail/annotations/JoinedRegionValue.scala +++ b/hail/src/main/scala/is/hail/annotations/JoinedRegionValue.scala @@ -3,5 +3,6 @@ package is.hail.annotations object JoinedRegionValue { def apply(): JoinedRegionValue = new JoinedRegionValue(null, null) - def apply(left: RegionValue, right: RegionValue): JoinedRegionValue = new JoinedRegionValue(left, right) + def apply(left: RegionValue, right: RegionValue): JoinedRegionValue = + new JoinedRegionValue(left, right) } diff --git a/hail/src/main/scala/is/hail/annotations/OrderedRVIterator.scala b/hail/src/main/scala/is/hail/annotations/OrderedRVIterator.scala index 05b57b2fb96..31ce3cffb5e 100644 --- a/hail/src/main/scala/is/hail/annotations/OrderedRVIterator.scala +++ b/hail/src/main/scala/is/hail/annotations/OrderedRVIterator.scala @@ -1,8 +1,8 @@ package is.hail.annotations import is.hail.backend.HailStateManager -import is.hail.types.physical.PInterval import is.hail.rvd.{RVDContext, RVDType} +import is.hail.types.physical.PInterval import is.hail.utils._ import scala.collection.generic.Growable @@ -11,14 +11,14 @@ import scala.collection.mutable object OrderedRVIterator { def multiZipJoin( sm: HailStateManager, - its: IndexedSeq[OrderedRVIterator] + its: IndexedSeq[OrderedRVIterator], ): Iterator[BoxedArrayBuilder[(RegionValue, Int)]] = { require(its.length > 0) val first = its(0) val flipbooks = its.map(_.iterator.toFlipbookIterator) FlipbookIterator.multiZipJoin( flipbooks.toArray, - first.t.joinComp(sm, first.t).compare + first.t.joinComp(sm, first.t).compare, ) } } @@ -33,13 +33,13 @@ case class OrderedRVIterator( def staircase: StagingIterator[FlipbookIterator[RegionValue]] = iterator.toFlipbookIterator.staircased(t.kRowOrdView(sm, ctx.freshRegion())) - def cogroup(other: OrderedRVIterator): - FlipbookIterator[Muple[FlipbookIterator[RegionValue], FlipbookIterator[RegionValue]]] = + def cogroup(other: OrderedRVIterator) + : FlipbookIterator[Muple[FlipbookIterator[RegionValue], FlipbookIterator[RegionValue]]] = this.iterator.toFlipbookIterator.cogroup( other.iterator.toFlipbookIterator, this.t.kRowOrdView(sm, ctx.freshRegion()), other.t.kRowOrdView(sm, ctx.freshRegion()), - this.t.kComp(sm, other.t).compare + this.t.kComp(sm, other.t).compare, ) def leftJoinDistinct(other: OrderedRVIterator): Iterator[JoinedRegionValue] = @@ 
-47,7 +47,7 @@ case class OrderedRVIterator( other.iterator.toFlipbookIterator, null, null, - this.t.joinComp(sm, other.t).compare + this.t.joinComp(sm, other.t).compare, ) def leftIntervalJoinDistinct(other: OrderedRVIterator): Iterator[JoinedRegionValue] = @@ -55,10 +55,11 @@ case class OrderedRVIterator( other.iterator.toFlipbookIterator, null, null, - this.t.intervalJoinComp(sm, other.t).compare + this.t.intervalJoinComp(sm, other.t).compare, ) - def leftIntervalJoin(other: OrderedRVIterator): Iterator[Muple[RegionValue, Iterable[RegionValue]]] = { + def leftIntervalJoin(other: OrderedRVIterator) + : Iterator[Muple[RegionValue, Iterable[RegionValue]]] = { val left = iterator.toFlipbookIterator val right = other.iterator.toFlipbookIterator val rightEndpointOrdering: Ordering[RegionValue] = RVDType.selectUnsafeOrdering( @@ -67,7 +68,7 @@ case class OrderedRVIterator( other.t.rowType, other.t.kFieldIdx, Array(other.t.kType.types(0).asInstanceOf[PInterval].endPrimaryUnsafeOrdering(sm)), - missingEqual = true + missingEqual = true, ).toRVOrdering.reverse val mixedOrd: (RegionValue, RegionValue) => Int = this.t.intervalJoinComp(sm, other.t).compare @@ -78,7 +79,7 @@ case class OrderedRVIterator( var isValid: Boolean = true - def setValue() { + def setValue(): Unit = { if (left.isValid) { while (buffer.nonEmpty && mixedOrd(left.value, buffer.head) > 0) buffer.dequeue() @@ -94,7 +95,7 @@ case class OrderedRVIterator( } } - def advance() { + def advance(): Unit = { left.advance() setValue() } @@ -107,7 +108,7 @@ case class OrderedRVIterator( def innerJoin( other: OrderedRVIterator, - rightBuffer: Iterable[RegionValue] with Growable[RegionValue] + rightBuffer: Iterable[RegionValue] with Growable[RegionValue], ): Iterator[JoinedRegionValue] = { iterator.toFlipbookIterator.innerJoin( other.iterator.toFlipbookIterator, @@ -116,13 +117,13 @@ case class OrderedRVIterator( null, null, rightBuffer, - this.t.joinComp(sm, other.t).compare + this.t.joinComp(sm, other.t).compare, ) } def leftJoin( other: OrderedRVIterator, - rightBuffer: Iterable[RegionValue] with Growable[RegionValue] + rightBuffer: Iterable[RegionValue] with Growable[RegionValue], ): Iterator[JoinedRegionValue] = { iterator.toFlipbookIterator.leftJoin( other.iterator.toFlipbookIterator, @@ -131,13 +132,13 @@ case class OrderedRVIterator( null, null, rightBuffer, - this.t.joinComp(sm, other.t).compare + this.t.joinComp(sm, other.t).compare, ) } def rightJoin( other: OrderedRVIterator, - rightBuffer: Iterable[RegionValue] with Growable[RegionValue] + rightBuffer: Iterable[RegionValue] with Growable[RegionValue], ): Iterator[JoinedRegionValue] = { iterator.toFlipbookIterator.rightJoin( other.iterator.toFlipbookIterator, @@ -146,13 +147,13 @@ case class OrderedRVIterator( null, null, rightBuffer, - this.t.joinComp(sm, other.t).compare + this.t.joinComp(sm, other.t).compare, ) } def outerJoin( other: OrderedRVIterator, - rightBuffer: Iterable[RegionValue] with Growable[RegionValue] + rightBuffer: Iterable[RegionValue] with Growable[RegionValue], ): Iterator[JoinedRegionValue] = { iterator.toFlipbookIterator.outerJoin( other.iterator.toFlipbookIterator, @@ -161,16 +162,15 @@ case class OrderedRVIterator( null, null, rightBuffer, - this.t.joinComp(sm, other.t).compare + this.t.joinComp(sm, other.t).compare, ) } - def merge(other: OrderedRVIterator): Iterator[RegionValue] = { + def merge(other: OrderedRVIterator): Iterator[RegionValue] = iterator.toFlipbookIterator.merge( other.iterator.toFlipbookIterator, - this.t.kComp(sm, other.t).compare + 
this.t.kComp(sm, other.t).compare, ) - } def localKeySort( newKey: IndexedSeq[String] @@ -184,7 +184,8 @@ case class OrderedRVIterator( private val bit = iterator.buffered private val q = new mutable.PriorityQueue[RegionValue]()( - t.copy(key = newKey).kInRowOrd(sm).toRVOrdering.reverse) + t.copy(key = newKey).kInRowOrd(sm).toRVOrdering.reverse + ) private val rvb = new RegionValueBuilder(sm, consumerRegion) private val rv = RegionValue() diff --git a/hail/src/main/scala/is/hail/annotations/Region.scala b/hail/src/main/scala/is/hail/annotations/Region.scala index 8c46a3a66af..c04ccca32e6 100644 --- a/hail/src/main/scala/is/hail/annotations/Region.scala +++ b/hail/src/main/scala/is/hail/annotations/Region.scala @@ -1,11 +1,9 @@ package is.hail.annotations import is.hail.asm4s -import is.hail.asm4s.{Code, coerce} -import is.hail.backend.HailTaskContext +import is.hail.asm4s.Code import is.hail.types.physical._ import is.hail.utils._ -import org.apache.spark.TaskContext object Region { type Size = Int @@ -88,71 +86,103 @@ object Region { (loadByte(b) & (1 << (bitOff & 7).toInt)) != 0 } - def setBit(byteOff: Long, bitOff: Long) { + def setBit(byteOff: Long, bitOff: Long): Unit = { val b = byteOff + (bitOff >> 3) - storeByte(b, - (loadByte(b) | (1 << (bitOff & 7).toInt)).toByte) + storeByte(b, (loadByte(b) | (1 << (bitOff & 7).toInt)).toByte) } - def clearBit(byteOff: Long, bitOff: Long) { + def clearBit(byteOff: Long, bitOff: Long): Unit = { val b = byteOff + (bitOff >> 3) - storeByte(b, - (loadByte(b) & ~(1 << (bitOff & 7).toInt)).toByte) + storeByte(b, (loadByte(b) & ~(1 << (bitOff & 7).toInt)).toByte) } - def storeBit(byteOff: Long, bitOff: Long, b: Boolean) { + def storeBit(byteOff: Long, bitOff: Long, b: Boolean): Unit = if (b) setBit(byteOff, bitOff) else clearBit(byteOff, bitOff) - } - - def loadInt(addr: Code[Long]): Code[Int] = Code.invokeScalaObject1[Long, Int](Region.getClass, "loadInt", addr) + def loadInt(addr: Code[Long]): Code[Int] = + Code.invokeScalaObject1[Long, Int](Region.getClass, "loadInt", addr) - def loadLong(addr: Code[Long]): Code[Long] = Code.invokeScalaObject1[Long, Long](Region.getClass, "loadLong", addr) + def loadLong(addr: Code[Long]): Code[Long] = + Code.invokeScalaObject1[Long, Long](Region.getClass, "loadLong", addr) - def loadFloat(addr: Code[Long]): Code[Float] = Code.invokeScalaObject1[Long, Float](Region.getClass, "loadFloat", addr) + def loadFloat(addr: Code[Long]): Code[Float] = + Code.invokeScalaObject1[Long, Float](Region.getClass, "loadFloat", addr) - def loadDouble(addr: Code[Long]): Code[Double] = Code.invokeScalaObject1[Long, Double](Region.getClass, "loadDouble", addr) + def loadDouble(addr: Code[Long]): Code[Double] = + Code.invokeScalaObject1[Long, Double](Region.getClass, "loadDouble", addr) - def loadAddress(addr: Code[Long]): Code[Long] = Code.invokeScalaObject1[Long, Long](Region.getClass, "loadAddress", addr) + def loadAddress(addr: Code[Long]): Code[Long] = + Code.invokeScalaObject1[Long, Long](Region.getClass, "loadAddress", addr) - def loadByte(addr: Code[Long]): Code[Byte] = Code.invokeScalaObject1[Long, Byte](Region.getClass, "loadByte", addr) + def loadByte(addr: Code[Long]): Code[Byte] = + Code.invokeScalaObject1[Long, Byte](Region.getClass, "loadByte", addr) - def loadShort(addr: Code[Long]): Code[Short] = Code.invokeScalaObject1[Long, Short](Region.getClass, "loadShort", addr) + def loadShort(addr: Code[Long]): Code[Short] = + Code.invokeScalaObject1[Long, Short](Region.getClass, "loadShort", addr) - def loadChar(addr: 
Code[Long]): Code[Char] = Code.invokeScalaObject1[Long, Char](Region.getClass, "loadChar", addr) + def loadChar(addr: Code[Long]): Code[Char] = + Code.invokeScalaObject1[Long, Char](Region.getClass, "loadChar", addr) - def storeInt(addr: Code[Long], v: Code[Int]): Code[Unit] = Code.invokeScalaObject2[Long, Int, Unit](Region.getClass, "storeInt", addr, v) + def storeInt(addr: Code[Long], v: Code[Int]): Code[Unit] = + Code.invokeScalaObject2[Long, Int, Unit](Region.getClass, "storeInt", addr, v) - def storeLong(addr: Code[Long], v: Code[Long]): Code[Unit] = Code.invokeScalaObject2[Long, Long, Unit](Region.getClass, "storeLong", addr, v) + def storeLong(addr: Code[Long], v: Code[Long]): Code[Unit] = + Code.invokeScalaObject2[Long, Long, Unit](Region.getClass, "storeLong", addr, v) - def storeFloat(addr: Code[Long], v: Code[Float]): Code[Unit] = Code.invokeScalaObject2[Long, Float, Unit](Region.getClass, "storeFloat", addr, v) + def storeFloat(addr: Code[Long], v: Code[Float]): Code[Unit] = + Code.invokeScalaObject2[Long, Float, Unit](Region.getClass, "storeFloat", addr, v) - def storeDouble(addr: Code[Long], v: Code[Double]): Code[Unit] = Code.invokeScalaObject2[Long, Double, Unit](Region.getClass, "storeDouble", addr, v) + def storeDouble(addr: Code[Long], v: Code[Double]): Code[Unit] = + Code.invokeScalaObject2[Long, Double, Unit](Region.getClass, "storeDouble", addr, v) - def storeChar(addr: Code[Long], v: Code[Char]): Code[Unit] = Code.invokeScalaObject2[Long, Char, Unit](Region.getClass, "storeChar", addr, v) + def storeChar(addr: Code[Long], v: Code[Char]): Code[Unit] = + Code.invokeScalaObject2[Long, Char, Unit](Region.getClass, "storeChar", addr, v) - def storeAddress(addr: Code[Long], v: Code[Long]): Code[Unit] = Code.invokeScalaObject2[Long, Long, Unit](Region.getClass, "storeAddress", addr, v) + def storeAddress(addr: Code[Long], v: Code[Long]): Code[Unit] = + Code.invokeScalaObject2[Long, Long, Unit](Region.getClass, "storeAddress", addr, v) - def storeByte(addr: Code[Long], v: Code[Byte]): Code[Unit] = Code.invokeScalaObject2[Long, Byte, Unit](Region.getClass, "storeByte", addr, v) + def storeByte(addr: Code[Long], v: Code[Byte]): Code[Unit] = + Code.invokeScalaObject2[Long, Byte, Unit](Region.getClass, "storeByte", addr, v) - def storeShort(addr: Code[Long], v: Code[Short]): Code[Unit] = Code.invokeScalaObject2[Long, Short, Unit](Region.getClass, "storeShort", addr, v) + def storeShort(addr: Code[Long], v: Code[Short]): Code[Unit] = + Code.invokeScalaObject2[Long, Short, Unit](Region.getClass, "storeShort", addr, v) - def loadBoolean(addr: Code[Long]): Code[Boolean] = Code.invokeScalaObject1[Long, Boolean](Region.getClass, "loadBoolean", addr) + def loadBoolean(addr: Code[Long]): Code[Boolean] = + Code.invokeScalaObject1[Long, Boolean](Region.getClass, "loadBoolean", addr) - def storeBoolean(addr: Code[Long], v: Code[Boolean]): Code[Unit] = Code.invokeScalaObject2[Long, Boolean, Unit](Region.getClass, "storeBoolean", addr, v) + def storeBoolean(addr: Code[Long], v: Code[Boolean]): Code[Unit] = + Code.invokeScalaObject2[Long, Boolean, Unit](Region.getClass, "storeBoolean", addr, v) - def loadBytes(addr: Code[Long], n: Code[Int]): Code[Array[Byte]] = Code.invokeScalaObject2[Long, Int, Array[Byte]](Region.getClass, "loadBytes", addr, n) + def loadBytes(addr: Code[Long], n: Code[Int]): Code[Array[Byte]] = + Code.invokeScalaObject2[Long, Int, Array[Byte]](Region.getClass, "loadBytes", addr, n) - def loadBytes(addr: Code[Long], dst: Code[Array[Byte]], dstOff: Code[Long], n: 
Code[Long]): Code[Unit] = - Code.invokeScalaObject4[Long, Array[Byte], Long, Long, Unit](Region.getClass, "loadBytes", addr, dst, dstOff, n) + def loadBytes(addr: Code[Long], dst: Code[Array[Byte]], dstOff: Code[Long], n: Code[Long]) + : Code[Unit] = + Code.invokeScalaObject4[Long, Array[Byte], Long, Long, Unit]( + Region.getClass, + "loadBytes", + addr, + dst, + dstOff, + n, + ) - def storeBytes(addr: Code[Long], src: Code[Array[Byte]]): Code[Unit] = Code.invokeScalaObject2[Long, Array[Byte], Unit](Region.getClass, "storeBytes", addr, src) + def storeBytes(addr: Code[Long], src: Code[Array[Byte]]): Code[Unit] = + Code.invokeScalaObject2[Long, Array[Byte], Unit](Region.getClass, "storeBytes", addr, src) - def storeBytes(addr: Code[Long], src: Code[Array[Byte]], srcOff: Code[Long], n: Code[Long]): Code[Unit] = - Code.invokeScalaObject4[Long, Array[Byte], Long, Long, Unit](Region.getClass, "storeBytes", addr, src, srcOff, n) + def storeBytes(addr: Code[Long], src: Code[Array[Byte]], srcOff: Code[Long], n: Code[Long]) + : Code[Unit] = + Code.invokeScalaObject4[Long, Array[Byte], Long, Long, Unit]( + Region.getClass, + "storeBytes", + addr, + src, + srcOff, + n, + ) def copyFrom(srcOff: Code[Long], dstOff: Code[Long], n: Code[Long]): Code[Unit] = Code.invokeScalaObject3[Long, Long, Long, Unit](Region.getClass, "copyFrom", srcOff, dstOff, n) @@ -167,13 +197,24 @@ object Region { Code.invokeScalaObject2[Long, Long, Unit](Region.getClass, "clearBit", byteOff, bitOff) def storeBit(byteOff: Code[Long], bitOff: Code[Long], b: Code[Boolean]): Code[Unit] = - Code.invokeScalaObject3[Long, Long, Boolean, Unit](Region.getClass, "storeBit", byteOff, bitOff, b) + Code.invokeScalaObject3[Long, Long, Boolean, Unit]( + Region.getClass, + "storeBit", + byteOff, + bitOff, + b, + ) def setMemory(offset: Code[Long], size: Code[Long], b: Code[Byte]): Code[Unit] = Code.invokeScalaObject3[Long, Long, Byte, Unit](Region.getClass, "setMemory", offset, size, b) def containsNonZeroBits(address: Code[Long], nBits: Code[Long]): Code[Boolean] = - Code.invokeScalaObject2[Long, Long, Boolean](Region.getClass, "containsNonZeroBits", address, nBits) + Code.invokeScalaObject2[Long, Long, Boolean]( + Region.getClass, + "containsNonZeroBits", + address, + nBits, + ) def containsNonZeroBits(address: Long, nBits: Long): Boolean = { assert((address & 0x3) == 0) @@ -188,21 +229,21 @@ object Region { } while (nBits - bitsRead >= 64) { - if (loadLong(address + bitsRead/8) != 0) + if (loadLong(address + bitsRead / 8) != 0) return true bitsRead += 64 } while (nBits - bitsRead >= 32) { - if (loadInt(address + bitsRead/8) != 0) + if (loadInt(address + bitsRead / 8) != 0) return true bitsRead += 32 } while (nBits - bitsRead >= 8) { - if (loadByte(address + bitsRead/8) != 0) + if (loadByte(address + bitsRead / 8) != 0) return true bitsRead += 8 @@ -220,30 +261,47 @@ object Region { val sharedChunkHeaderBytes = 16L def getSharedChunkRefCount(ndAddr: Long): Long = Region.loadLong(ndAddr - sharedChunkHeaderBytes) - def storeSharedChunkRefCount(ndAddr: Long, newCount: Long): Unit = Region.storeLong(ndAddr - sharedChunkHeaderBytes, newCount) + + def storeSharedChunkRefCount(ndAddr: Long, newCount: Long): Unit = + Region.storeLong(ndAddr - sharedChunkHeaderBytes, newCount) + def getSharedChunkByteSize(ndAddr: Long): Long = Region.loadLong(ndAddr - 8L) def getSharedChunkByteSize(ndAddr: Code[Long]): Code[Long] = Region.loadLong(ndAddr - 8L) - def storeSharedChunkByteSize(ndAddr: Long, byteSize: Long): Unit = Region.storeLong(ndAddr - 8L, byteSize) 
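The wrappers reshaped above are thin staged shims: each Code[_] overload reflects into an unstaged primitive of the same name on object Region. Below is a minimal sketch of how those unstaged primitives fit together; it is not part of this patch, and the enclosing object, the main method, and the 4-byte alignment and size arguments passed to allocate are illustrative assumptions.

import is.hail.annotations.{Region, RegionPool}

object RegionScratchSketch {
  def main(args: Array[String]): Unit = {
    val pool = RegionPool() // strictMemoryCheck defaults to false
    try
      pool.scopedRegion { region =>
        // ask the region for 4 bytes at 4-byte alignment, then round-trip an Int
        val addr = region.allocate(4L, 4L)
        Region.storeInt(addr, 42)
        assert(Region.loadInt(addr) == 42)
      }
    finally pool.close()
  }
}
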
- def stagedCreate(blockSize: Size, pool: Code[RegionPool]): Code[Region] = - Code.invokeScalaObject2[Int, RegionPool, Region](Region.getClass, "apply", asm4s.const(blockSize), pool) + def storeSharedChunkByteSize(ndAddr: Long, byteSize: Long): Unit = + Region.storeLong(ndAddr - 8L, byteSize) - def apply(blockSize: Region.Size = Region.REGULAR, pool: RegionPool): Region = { + def stagedCreate(blockSize: Size, pool: Code[RegionPool]): Code[Region] = + Code.invokeScalaObject2[Int, RegionPool, Region]( + Region.getClass, + "apply", + asm4s.const(blockSize), + pool, + ) + + def apply(blockSize: Region.Size = Region.REGULAR, pool: RegionPool): Region = pool.getRegion(blockSize) - } def pretty(off: Long, n: Int, header: String): String = { val linewidth = 4 s"$header\n" + - Region.loadBytes(off, n) - .map(b => "%02x".format(b)).grouped(8).map(_.mkString(" ")) - .grouped(linewidth).zipWithIndex - .map { case (s, i) => " %016x ".format(off + (8 * 8 * linewidth * i)) + s.mkString(" ") } - .mkString("\n") + Region.loadBytes(off, n) + .map(b => "%02x".format(b)).grouped(8).map(_.mkString(" ")) + .grouped(linewidth).zipWithIndex + .map { case (s, i) => + " %016x ".format(off + (8 * 8 * linewidth * i)) + s.mkString(" ") + } + .mkString("\n") } def pretty(off: Code[Long], n: Int, header: Code[String]): Code[String] = - Code.invokeScalaObject3[Long, Int, String, String](Region.getClass, "pretty", off, asm4s.const(n), header) + Code.invokeScalaObject3[Long, Int, String, String]( + Region.getClass, + "pretty", + off, + asm4s.const(n), + header, + ) def pretty(t: PType, off: Long): String = { val v = new PrettyVisitor() @@ -251,7 +309,7 @@ object Region { v.result() } - def visit(t: PType, off: Long, v: ValueVisitor) { + def visit(t: PType, off: Long, v: ValueVisitor): Unit = { t match { case _: PBoolean => v.visitBoolean(Region.loadBoolean(off)) case _: PInt32 => v.visitInt32(Region.loadInt(off)) @@ -317,7 +375,11 @@ object Region { } } -final class Region protected[annotations](var blockSize: Region.Size, var pool: RegionPool, var memory: RegionMemory = null) extends AutoCloseable { +final class Region protected[annotations] ( + var blockSize: Region.Size, + var pool: RegionPool, + var memory: RegionMemory = null, +) extends AutoCloseable { def getMemory(): RegionMemory = memory def isValid(): Boolean = memory != null @@ -332,12 +394,11 @@ final class Region protected[annotations](var blockSize: Region.Size, var pool: memory.allocate(a, n) } - def invalidate(): Unit = { + def invalidate(): Unit = if (memory != null) { memory.release() memory = null } - } def clear(): Unit = { if (memory.getReferenceCount == 1) { @@ -353,17 +414,14 @@ final class Region protected[annotations](var blockSize: Region.Size, var pool: memory.allocateSharedChunk(nBytes) } - def trackSharedChunk(addr: Long): Unit = { + def trackSharedChunk(addr: Long): Unit = memory.trackSharedChunk(addr) - } - def close(): Unit = { + def close(): Unit = invalidate() - } - def addReferenceTo(r: Region): Unit = { + def addReferenceTo(r: Region): Unit = memory.addReferenceTo(r.memory) - } def move(r: Region): Unit = { r.memory.takeOwnershipOf(memory) @@ -378,19 +436,16 @@ final class Region protected[annotations](var blockSize: Region.Size, var pool: memory = pool.getMemory(blockSize) } - def setNumParents(n: Int): Unit = { + def setNumParents(n: Int): Unit = memory.setNumParents(n) - } - def setParentReference(child: Region, idx: Int): Unit = { + def setParentReference(child: Region, idx: Int): Unit = memory.setReferenceAtIndex(child.memory, idx) 
- } def getReferenceCount(): Long = memory.getReferenceCount - def getParentReference(idx: Int, blockSize: Region.Size): Region = { + def getParentReference(idx: Int, blockSize: Region.Size): Region = new Region(blockSize, pool, memory.getReferenceAtIndex(idx, blockSize)) - } def setFromParentReference(r: Region, idx: Int, blockSize: Region.Size): Unit = { invalidate() @@ -404,37 +459,48 @@ final class Region protected[annotations](var blockSize: Region.Size, var pool: r } - def unreferenceRegionAtIndex(idx: Int): Unit = { + def unreferenceRegionAtIndex(idx: Int): Unit = memory.releaseReferenceAtIndex(idx) - } def storeJavaObject(obj: AnyRef): Int = memory.storeJavaObject(obj) def lookupJavaObject(idx: Int): AnyRef = memory.lookupJavaObject(idx) - def prettyBits(): String = { + def prettyBits(): String = "FIXME: implement prettyBits on Region" - } - def getPool(): RegionPool = { + def getPool(): RegionPool = pool - } def totalManagedBytes(): Long = memory.totalManagedBytes() } object RegionUtils { def printAddr(off: Long, name: String): String = s"$name: ${"%016x".format(off)}" - def printAddr(off: Code[Long], name: String): Code[String] = Code.invokeScalaObject2[Long, String, String](RegionUtils.getClass, "printAddr", off, name) + + def printAddr(off: Code[Long], name: String): Code[String] = + Code.invokeScalaObject2[Long, String, String](RegionUtils.getClass, "printAddr", off, name) def printBytes(off: Long, n: Int, header: String): String = Region.loadBytes(off, n).zipWithIndex .grouped(16) - .map(bs => bs.map { case (b, _) => "%02x".format(b) }.mkString(" %016x: ".format(off + bs(0)._2), " ", "")) + .map(bs => + bs.map { case (b, _) => "%02x".format(b) }.mkString( + " %016x: ".format(off + bs(0)._2), + " ", + "", + ) + ) .mkString(if (header != null) s"$header\n" else "\n", "\n", "") def printBytes(off: Code[Long], n: Int, header: String): Code[String] = - Code.invokeScalaObject3[Long, Int, String, String](RegionUtils.getClass, "printBytes", off, n, asm4s.const(header)) + Code.invokeScalaObject3[Long, Int, String, String]( + RegionUtils.getClass, + "printBytes", + off, + n, + asm4s.const(header), + ) def logRegionStats(header: String, region: RegionMemory): Unit = { val size = region.blockSize @@ -456,9 +522,15 @@ object RegionUtils { | ndarrays: $ndarrays | block addr: $addr | referenced: $nReferenced - """.stripMargin) + """.stripMargin + ) } def logRegionStats(header: String, region: Code[Region]): Code[Unit] = - Code.invokeScalaObject2[String, RegionMemory, Unit](RegionUtils.getClass, "logRegionStats", header, region.invoke[RegionMemory]("getMemory")) + Code.invokeScalaObject2[String, RegionMemory, Unit]( + RegionUtils.getClass, + "logRegionStats", + header, + region.invoke[RegionMemory]("getMemory"), + ) } diff --git a/hail/src/main/scala/is/hail/annotations/RegionMemory.scala b/hail/src/main/scala/is/hail/annotations/RegionMemory.scala index 9eb44c69a33..f71bef2ceda 100644 --- a/hail/src/main/scala/is/hail/annotations/RegionMemory.scala +++ b/hail/src/main/scala/is/hail/annotations/RegionMemory.scala @@ -1,10 +1,8 @@ package is.hail.annotations -import is.hail.expr.ir.{AnyRefArrayBuilder, LongArrayBuilder, LongMissingArrayBuilder} -import is.hail.types.physical.{PCanonicalNDArray, PNDArray} +import is.hail.expr.ir.{AnyRefArrayBuilder, LongArrayBuilder} import is.hail.utils._ - final class RegionMemory(pool: RegionPool) extends AutoCloseable { private[this] val usedBlocks = new LongArrayBuilder(4) private[this] val bigChunks = new LongArrayBuilder(4) @@ -16,7 +14,8 @@ final class 
RegionMemory(pool: RegionPool) extends AutoCloseable { private[this] var offsetWithinBlock: Long = _ // var stackTrace: Option[IndexedSeq[StackTraceElement]] = None - // blockThreshold and blockByteSize are mutable because RegionMemory objects are reused with different sizes + /* blockThreshold and blockByteSize are mutable because RegionMemory objects are reused with + * different sizes */ protected[annotations] var blockSize: Region.Size = -1 private[this] var blockThreshold: Long = _ private[this] var blockByteSize: Long = _ @@ -31,18 +30,17 @@ final class RegionMemory(pool: RegionPool) extends AutoCloseable { idx } - def lookupJavaObject(idx: Int): AnyRef = { + def lookupJavaObject(idx: Int): AnyRef = jObjects(idx) - } def dumpMemoryInfo(): String = { s""" |Blocks Used = ${usedBlocks.size}, Chunks used = ${bigChunks.size} |Block Info: - | BlockSize = ${blockSize} ($blockByteSize bytes) + | BlockSize = $blockSize ($blockByteSize bytes) | Current Block Info: - | Current Block Address: ${currentBlock} - | Offset Within Block: ${offsetWithinBlock} + | Current Block Address: $currentBlock + | Offset Within Block: $offsetWithinBlock | Used Blocks Info: | BlockStarts: ${usedBlocks.result().toIndexedSeq} |""".stripMargin @@ -50,7 +48,8 @@ final class RegionMemory(pool: RegionPool) extends AutoCloseable { def allocateNewBlock(): Unit = { val newBlock = pool.getBlock(blockSize) - // don't add currentBlock to usedBlocks until pool.getBlock returns successfully (could throw OOM exception) + /* don't add currentBlock to usedBlocks until pool.getBlock returns successfully (could throw + * OOM exception) */ if (currentBlock != 0) usedBlocks.add(currentBlock) currentBlock = newBlock @@ -60,7 +59,8 @@ final class RegionMemory(pool: RegionPool) extends AutoCloseable { private def allocateBigChunk(size: Long): Long = { val ret = pool.getChunk(size) - val chunkPointer = ret._1 // Match expressions allocate https://github.com/hail-is/hail/pull/13794 + val chunkPointer = + ret._1 // Match expressions allocate https://github.com/hail-is/hail/pull/13794 val chunkSize = ret._2 bigChunks.add(chunkPointer) totalChunkMemory += chunkSize @@ -181,13 +181,12 @@ final class RegionMemory(pool: RegionPool) extends AutoCloseable { } } - private def free(): Unit = { + private def free(): Unit = if (!isFreed) { freeMemory() pool.reclaim(this) } // stackTrace = None - } def getReferenceCount: Long = referenceCount @@ -227,9 +226,8 @@ final class RegionMemory(pool: RegionPool) extends AutoCloseable { offsetWithinBlock = 0L } - def close(): Unit = { + def close(): Unit = free() - } def numChunks: Int = bigChunks.size @@ -289,7 +287,9 @@ final class RegionMemory(pool: RegionPool) extends AutoCloseable { def allocateSharedChunk(size: Long): Long = { if (size < 0L) { - throw new IllegalArgumentException(s"Can't request ndarray of negative memory size, got ${size}") + throw new IllegalArgumentException( + s"Can't request ndarray of negative memory size, got $size" + ) } val extra = Region.sharedChunkHeaderBytes @@ -310,7 +310,6 @@ final class RegionMemory(pool: RegionPool) extends AutoCloseable { Region.storeSharedChunkRefCount(alloc, curRefCount + 1L) } - def listNDArrayRefs(): IndexedSeq[Long] = { + def listNDArrayRefs(): IndexedSeq[Long] = this.ndarrayRefs.result().toIndexedSeq - } } diff --git a/hail/src/main/scala/is/hail/annotations/RegionPool.scala b/hail/src/main/scala/is/hail/annotations/RegionPool.scala index b51417ef041..bd8671c0d54 100644 --- a/hail/src/main/scala/is/hail/annotations/RegionPool.scala +++ 
b/hail/src/main/scala/is/hail/annotations/RegionPool.scala @@ -3,10 +3,6 @@ package is.hail.annotations import is.hail.expr.ir.LongArrayBuilder import is.hail.utils._ -import java.util.TreeMap -import java.util.function.BiConsumer -import scala.collection.mutable - object RegionPool { def apply(strictMemoryCheck: Boolean = false): RegionPool = { @@ -25,9 +21,13 @@ object RegionPool { } } -final class RegionPool private(strictMemoryCheck: Boolean, threadName: String, threadID: Long) extends AutoCloseable { +final class RegionPool private (strictMemoryCheck: Boolean, threadName: String, threadID: Long) + extends AutoCloseable { log.info(s"RegionPool: initialized for thread $threadID: $threadName") - protected[annotations] val freeBlocks: Array[LongArrayBuilder] = Array.fill[LongArrayBuilder](4)(new LongArrayBuilder(8)) + + protected[annotations] val freeBlocks: Array[LongArrayBuilder] = + Array.fill[LongArrayBuilder](4)(new LongArrayBuilder(8)) + protected[annotations] val regions = new BoxedArrayBuilder[RegionMemory]() private[this] val freeRegions = new BoxedArrayBuilder[RegionMemory]() private[this] val blocks: Array[Long] = Array(0L, 0L, 0L, 0L) @@ -38,25 +38,25 @@ final class RegionPool private(strictMemoryCheck: Boolean, threadName: String, t private[this] val chunkCache = new ChunkCache(Memory.malloc, Memory.free) private[this] val maxSize = RegionPool.maxRegionPoolSize - def addJavaObject(): Unit = { + def addJavaObject(): Unit = numJavaObjects += 1 - } - def removeJavaObjects(n: Int): Unit = { + def removeJavaObjects(n: Int): Unit = numJavaObjects -= n - } def getTotalAllocatedBytes: Long = totalAllocatedBytes def getHighestTotalUsage: Long = highestTotalUsage def getUsage: (Int, Int) = chunkCache.getUsage() - private[annotations] def decrementAllocatedBytes(toSubtract: Long): Unit = totalAllocatedBytes -= toSubtract + private[annotations] def decrementAllocatedBytes(toSubtract: Long): Unit = + totalAllocatedBytes -= toSubtract def closeAndThrow(msg: String): Unit = { close() fatal(msg) } + private[annotations] def incrementAllocatedBytes(toAdd: Long): Unit = { totalAllocatedBytes += toAdd if (totalAllocatedBytes >= allocationEchoThreshold) { @@ -67,17 +67,18 @@ final class RegionPool private(strictMemoryCheck: Boolean, threadName: String, t highestTotalUsage = totalAllocatedBytes if (totalAllocatedBytes > maxSize) { val inBlocks = bytesInBlocks() - closeAndThrow(s"Hail off-heap memory exceeded maximum threshold: limit ${ formatSpace(maxSize) }, allocated ${ formatSpace(totalAllocatedBytes) }\n" - + s"Report: ${readableBytes(totalAllocatedBytes)} allocated (${readableBytes(inBlocks)} blocks / " - + s"${readableBytes(totalAllocatedBytes - inBlocks)} chunks), regions.size = ${regions.size}, " - + s"$numJavaObjects current java objects, thread $threadID: $threadName") + closeAndThrow( + s"Hail off-heap memory exceeded maximum threshold: limit ${formatSpace(maxSize)}, allocated ${formatSpace(totalAllocatedBytes)}\n" + + s"Report: ${readableBytes(totalAllocatedBytes)} allocated (${readableBytes(inBlocks)} blocks / " + + s"${readableBytes(totalAllocatedBytes - inBlocks)} chunks), regions.size = ${regions.size}, " + + s"$numJavaObjects current java objects, thread $threadID: $threadName" + ) } } } - protected[annotations] def reclaim(memory: RegionMemory): Unit = { + protected[annotations] def reclaim(memory: RegionMemory): Unit = freeRegions += memory - } protected[annotations] def getBlock(size: Int): Long = { val pool = freeBlocks(size) @@ -92,16 +93,14 @@ final class RegionPool 
private(strictMemoryCheck: Boolean, threadName: String, t } } - protected[annotations] def getChunk(size: Long): (Long, Long) = { + protected[annotations] def getChunk(size: Long): (Long, Long) = chunkCache.getChunk(this, size) - } - protected[annotations] def freeChunks(ab: LongArrayBuilder, totalSize: Long): Unit = { + protected[annotations] def freeChunks(ab: LongArrayBuilder, totalSize: Long): Unit = chunkCache.freeChunksToCache(ab) - } - protected[annotations] def freeChunk(chunkPointer: Long): Unit = { + + protected[annotations] def freeChunk(chunkPointer: Long): Unit = chunkCache.freeChunkToCache(chunkPointer) - } protected[annotations] def getMemory(size: Int): RegionMemory = { if (freeRegions.size > 0) { @@ -130,7 +129,9 @@ final class RegionPool private(strictMemoryCheck: Boolean, threadName: String, t def numFreeBlocks(): Int = freeBlocks.map(_.size).sum - def bytesInBlocks(): Long = Region.SIZES.zip(blocks).map { case (size, block) => size * block }.sum[Long] + def bytesInBlocks(): Long = Region.SIZES.zip(blocks).map { case (size, block) => + size * block + }.sum[Long] def logStats(context: String): Unit = { val nFree = this.numFreeRegions() @@ -143,16 +144,19 @@ final class RegionPool private(strictMemoryCheck: Boolean, threadName: String, t s"""Region count for $context | regions: $nRegions active, $nFree free | blocks: $nBlocks - | free: ${ freeBlockCounts.mkString(", ") } - | used: ${ usedBlockCounts.mkString(", ") }""".stripMargin) + | free: ${freeBlockCounts.mkString(", ")} + | used: ${usedBlockCounts.mkString(", ")}""".stripMargin + ) } def report(context: String): Unit = { val inBlocks = bytesInBlocks() - log.info(s"RegionPool: $context: ${readableBytes(totalAllocatedBytes)} allocated (${readableBytes(inBlocks)} blocks / " + - s"${readableBytes(totalAllocatedBytes - inBlocks)} chunks), regions.size = ${regions.size}, " + - s"$numJavaObjects current java objects, thread $threadID: $threadName") + log.info( + s"RegionPool: $context: ${readableBytes(totalAllocatedBytes)} allocated (${readableBytes(inBlocks)} blocks / " + + s"${readableBytes(totalAllocatedBytes - inBlocks)} chunks), regions.size = ${regions.size}, " + + s"$numJavaObjects current java objects, thread $threadID: $threadName" + ) // log.info("-----------STACK_TRACES---------") // val stacks: String = regions.result().toIndexedSeq.flatMap(r => r.stackTrace.map((r.getTotalChunkMemory(), _))).foldLeft("")((a: String, b) => a + "\n" + b.toString()) // log.info(stacks) @@ -160,8 +164,8 @@ final class RegionPool private(strictMemoryCheck: Boolean, threadName: String, t } def scopedRegion[T](f: Region => T): T = using(Region(pool = this))(f) - def scopedSmallRegion[T](f: Region => T): T = using(Region(Region.SMALL, pool=this))(f) - def scopedTinyRegion[T](f: Region => T): T = using(Region(Region.TINY, pool=this))(f) + def scopedSmallRegion[T](f: Region => T): T = using(Region(Region.SMALL, pool = this))(f) + def scopedTinyRegion[T](f: Region => T): T = using(Region(Region.TINY, pool = this))(f) override def finalize(): Unit = close() @@ -194,7 +198,7 @@ final class RegionPool private(strictMemoryCheck: Boolean, threadName: String, t chunkCache.freeAll(pool = this) if (totalAllocatedBytes != 0) { val msg = s"RegionPool: total allocated bytes not 0 after closing! 
total allocated: " + - s"$totalAllocatedBytes (${ readableBytes(totalAllocatedBytes) })" + s"$totalAllocatedBytes (${readableBytes(totalAllocatedBytes)})" if (strictMemoryCheck) fatal(msg) else diff --git a/hail/src/main/scala/is/hail/annotations/RegionValue.scala b/hail/src/main/scala/is/hail/annotations/RegionValue.scala index c5617b88906..a4457195ba3 100644 --- a/hail/src/main/scala/is/hail/annotations/RegionValue.scala +++ b/hail/src/main/scala/is/hail/annotations/RegionValue.scala @@ -1,11 +1,11 @@ package is.hail.annotations -import java.io._ - import is.hail.asm4s.HailClassLoader -import is.hail.types.physical.PType -import is.hail.utils.{using, RestartableByteArrayInputStream} import is.hail.io._ +import is.hail.types.physical.PType + +import java.io._ + import sun.reflect.generics.reflectiveObjects.NotImplementedException object RegionValue { @@ -19,7 +19,7 @@ object RegionValue { theHailClassLoader: HailClassLoader, makeDec: (InputStream, HailClassLoader) => Decoder, r: Region, - byteses: Iterator[Array[Byte]] + byteses: Iterator[Array[Byte]], ): Iterator[Long] = { val bad = new ByteArrayDecoder(theHailClassLoader, makeDec) byteses.map(bad.regionValueFromBytes(r, _)) @@ -29,18 +29,16 @@ object RegionValue { theHailClassLoader: HailClassLoader, makeDec: (InputStream, HailClassLoader) => Decoder, r: Region, - byteses: Iterator[Array[Byte]] + byteses: Iterator[Array[Byte]], ): Iterator[Long] = { val bad = new ByteArrayDecoder(theHailClassLoader, makeDec) - byteses.map { bytes => - bad.regionValueFromBytes(r, bytes) - } + byteses.map(bytes => bad.regionValueFromBytes(r, bytes)) } def toBytes( theHailClassLoader: HailClassLoader, makeEnc: (OutputStream, HailClassLoader) => Encoder, - rvs: Iterator[Long] + rvs: Iterator[Long], ): Iterator[Array[Byte]] = { val bae = new ByteArrayEncoder(theHailClassLoader, makeEnc) rvs.map(bae.regionValueToBytes) @@ -49,30 +47,26 @@ object RegionValue { final class RegionValue( var region: Region, - var offset: Long + var offset: Long, ) extends UnKryoSerializable { def getOffset: Long = offset - def set(newRegion: Region, newOffset: Long) { + def set(newRegion: Region, newOffset: Long): Unit = { region = newRegion offset = newOffset } - def setRegion(newRegion: Region) { + def setRegion(newRegion: Region): Unit = region = newRegion - } - def setOffset(newOffset: Long) { + def setOffset(newOffset: Long): Unit = offset = newOffset - } def pretty(t: PType): String = Region.pretty(t, offset) - private def writeObject(s: ObjectOutputStream): Unit = { + private def writeObject(s: ObjectOutputStream): Unit = throw new NotImplementedException() - } - private def readObject(s: ObjectInputStream): Unit = { + private def readObject(s: ObjectInputStream): Unit = throw new NotImplementedException() - } } diff --git a/hail/src/main/scala/is/hail/annotations/RegionValueBuilder.scala b/hail/src/main/scala/is/hail/annotations/RegionValueBuilder.scala index ab26171833a..4aa85dbaffc 100644 --- a/hail/src/main/scala/is/hail/annotations/RegionValueBuilder.scala +++ b/hail/src/main/scala/is/hail/annotations/RegionValueBuilder.scala @@ -1,11 +1,9 @@ package is.hail.annotations -import is.hail.backend.{ExecuteContext, HailStateManager} +import is.hail.backend.HailStateManager import is.hail.types.physical._ import is.hail.types.virtual._ import is.hail.utils._ -import is.hail.variant.Locus -import org.apache.spark.sql.Row class RegionValueBuilder(sm: HailStateManager, var region: Region) { def this(sm: HailStateManager) = this(sm, null) @@ -18,7 +16,8 @@ class 
RegionValueBuilder(sm: HailStateManager, var region: Region) { val offsetstk = new LongArrayStack() val elementsOffsetstk = new LongArrayStack() - def inactive: Boolean = root == null && typestk.isEmpty && offsetstk.isEmpty && elementsOffsetstk.isEmpty && indexstk.isEmpty + def inactive: Boolean = + root == null && typestk.isEmpty && offsetstk.isEmpty && elementsOffsetstk.isEmpty && indexstk.isEmpty def clear(): Unit = { root = null @@ -28,7 +27,7 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { indexstk.clear() } - def set(newRegion: Region) { + def set(newRegion: Region): Unit = { assert(inactive) region = newRegion } @@ -61,15 +60,15 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { } } - def start(newRoot: PType) { + def start(newRoot: PType): Unit = { assert(inactive) root = newRoot } - def allocateRoot() { + def allocateRoot(): Unit = { assert(typestk.isEmpty) root match { - case t: PArray => + case _: PArray => case _: PBinary => case _ => start = region.allocate(root.alignment, root.byteSize) @@ -83,12 +82,11 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { start } - def advance() { + def advance(): Unit = if (indexstk.nonEmpty) indexstk(0) = indexstk(0) + 1 - } - def startBaseStruct(init: Boolean = true, setMissing: Boolean = false) { + def startBaseStruct(init: Boolean = true, setMissing: Boolean = false): Unit = { val t = currentType().asInstanceOf[PBaseStruct] if (typestk.isEmpty) allocateRoot() @@ -102,7 +100,7 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { t.initialize(off, setMissing) } - def endBaseStruct() { + def endBaseStruct(): Unit = { val t = typestk.top.asInstanceOf[PBaseStruct] typestk.pop() offsetstk.pop() @@ -112,40 +110,39 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { advance() } - def startStruct(init: Boolean = true, setMissing: Boolean = false) { + def startStruct(init: Boolean = true, setMissing: Boolean = false): Unit = { assert(currentType().isInstanceOf[PStruct]) startBaseStruct(init, setMissing) } - def endStruct() { + def endStruct(): Unit = { assert(typestk.top.isInstanceOf[PStruct]) endBaseStruct() } - def startTuple(init: Boolean = true) { + def startTuple(init: Boolean = true): Unit = { assert(currentType().isInstanceOf[PTuple]) startBaseStruct(init) } - def endTuple() { + def endTuple(): Unit = { assert(typestk.top.isInstanceOf[PTuple]) endBaseStruct() } - def startArray(length: Int, init: Boolean = true) { + def startArray(length: Int, init: Boolean = true): Unit = startArrayInternal(length, init, false) - } // using this function, rather than startArray will set all elements of the array to missing by // default, you will need to use setPresent to add a value to this array. 
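(The startMissingArray variant described by the comment just above continues below.) The following is a hedged sketch of the builder protocol whose methods gain explicit ": Unit" return types in this hunk; the enclosing object, the buildRow helper, and the assumption that field 0 is an int32 field while field 1 is an optional field are illustrative, not taken from the patch.

import is.hail.annotations.{Region, RegionValueBuilder}
import is.hail.backend.HailStateManager
import is.hail.types.physical.PStruct

object RegionValueBuilderSketch {
  // Builds one two-field struct row in region and returns its offset.
  def buildRow(sm: HailStateManager, region: Region, rowPType: PStruct): Long = {
    val rvb = new RegionValueBuilder(sm, region)
    rvb.start(rowPType) // declare the root type
    rvb.startStruct()   // enter the struct
    rvb.addInt(7)       // field 0: an int32 value
    rvb.setMissing()    // field 1: left missing (must not be a required field)
    rvb.endStruct()     // leave the struct
    rvb.end()           // offset of the finished value within region
  }
}
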
- def startMissingArray(length: Int, init: Boolean = true) { + def startMissingArray(length: Int, init: Boolean = true): Unit = { val t = currentType().asInstanceOf[PArray] if (t.elementType.required) - fatal(s"cannot use random array pattern for required type ${ t.elementType }") + fatal(s"cannot use random array pattern for required type ${t.elementType}") startArrayInternal(length, init, true) } - private def startArrayInternal(length: Int, init: Boolean, setMissing: Boolean) { + private def startArrayInternal(length: Int, init: Boolean, setMissing: Boolean): Unit = { val t = currentType() match { case abc: PArrayBackedContainer => abc.arrayRep case arr: PArray => arr @@ -167,7 +164,7 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { t.initialize(aoff, length, setMissing) } - def endArray() { + def endArray(): Unit = { val t = typestk.top.asInstanceOf[PArray] val aoff = offsetstk.top val length = t.loadLength(aoff) @@ -176,7 +173,7 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { endArrayUnchecked() } - def endArrayUnchecked() { + def endArrayUnchecked(): Unit = { typestk.pop() offsetstk.pop() elementsOffsetstk.pop() @@ -185,32 +182,32 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { advance() } - def setArrayIndex(newI: Int) { + def setArrayIndex(newI: Int): Unit = { assert(typestk.top.isInstanceOf[PArray]) indexstk(0) = newI } - def setFieldIndex(newI: Int) { + def setFieldIndex(newI: Int): Unit = { assert(typestk.top.isInstanceOf[PBaseStruct]) indexstk(0) = newI } - def setMissing() { + def setMissing(): Unit = { val i = indexstk.top typestk.top match { case t: PBaseStruct => if (t.fieldRequired(i)) - fatal(s"cannot set missing field for required type ${ t.types(i) }") + fatal(s"cannot set missing field for required type ${t.types(i)}") t.setFieldMissing(offsetstk.top, i) case t: PArray => if (t.elementType.required) - fatal(s"cannot set missing field for required type ${ t.elementType }") + fatal(s"cannot set missing field for required type ${t.elementType}") t.setElementMissing(offsetstk.top, i) } advance() } - def setPresent() { + def setPresent(): Unit = { val i = indexstk.top typestk.top match { case t: PBaseStruct => @@ -220,7 +217,7 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { } } - def addBoolean(b: Boolean) { + def addBoolean(b: Boolean): Unit = { assert(currentType().isInstanceOf[PBoolean]) if (typestk.isEmpty) allocateRoot() @@ -229,7 +226,7 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { advance() } - def addInt(i: Int) { + def addInt(i: Int): Unit = { assert(currentType().isInstanceOf[PInt32]) addIntInternal(i) } @@ -239,7 +236,7 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { addIntInternal(c) } - def addIntInternal(i: Int) { + def addIntInternal(i: Int): Unit = { if (typestk.isEmpty) allocateRoot() val off = currentOffset() @@ -247,7 +244,7 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { advance() } - def addLong(l: Long) { + def addLong(l: Long): Unit = { assert(currentType().isInstanceOf[PInt64]) if (typestk.isEmpty) allocateRoot() @@ -256,7 +253,7 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { advance() } - def addFloat(f: Float) { + def addFloat(f: Float): Unit = { assert(currentType().isInstanceOf[PFloat32]) if (typestk.isEmpty) allocateRoot() @@ -265,7 +262,7 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { advance() } - def addDouble(d: Double) { + 
def addDouble(d: Double): Unit = { assert(currentType().isInstanceOf[PFloat64]) if (typestk.isEmpty) allocateRoot() @@ -274,9 +271,14 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { advance() } - def addString(s: String) { + def addString(s: String): Unit = { assert(currentType().isInstanceOf[PString]) - currentType().asInstanceOf[PString].unstagedStoreJavaObjectAtAddress(sm, currentOffset(), s, region) + currentType().asInstanceOf[PString].unstagedStoreJavaObjectAtAddress( + sm, + currentOffset(), + s, + region, + ) advance() } @@ -286,18 +288,16 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { advance() } - def addField(t: PBaseStruct, fromRegion: Region, fromOff: Long, i: Int) { + def addField(t: PBaseStruct, fromRegion: Region, fromOff: Long, i: Int): Unit = addField(t, fromOff, i, region.ne(fromRegion)) - } - def addField(t: PBaseStruct, fromOff: Long, i: Int, deepCopy: Boolean) { + def addField(t: PBaseStruct, fromOff: Long, i: Int, deepCopy: Boolean): Unit = if (t.isFieldDefined(fromOff, i)) addRegionValue(t.types(i), t.loadField(fromOff, i), deepCopy) else setMissing() - } - def skipFields(n: Int) { + def skipFields(n: Int): Unit = { var i = 0 while (i < n) { setMissing() @@ -305,7 +305,7 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { } } - def addAllFields(t: PBaseStruct, fromRegion: Region, fromOff: Long) { + def addAllFields(t: PBaseStruct, fromRegion: Region, fromOff: Long): Unit = { var i = 0 while (i < t.size) { addField(t, fromRegion, fromOff, i) @@ -313,11 +313,10 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { } } - def addAllFields(t: PBaseStruct, fromRV: RegionValue) { + def addAllFields(t: PBaseStruct, fromRV: RegionValue): Unit = addAllFields(t, fromRV.region, fromRV.offset) - } - def addFields(t: PBaseStruct, fromRegion: Region, fromOff: Long, fieldIdx: Array[Int]) { + def addFields(t: PBaseStruct, fromRegion: Region, fromOff: Long, fieldIdx: Array[Int]): Unit = { var i = 0 while (i < fieldIdx.length) { addField(t, fromRegion, fromOff, fieldIdx(i)) @@ -325,15 +324,14 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { } } - def addFields(t: PBaseStruct, fromRV: RegionValue, fieldIdx: Array[Int]) { + def addFields(t: PBaseStruct, fromRV: RegionValue, fieldIdx: Array[Int]): Unit = addFields(t, fromRV.region, fromRV.offset, fieldIdx) - } - def selectRegionValue(fromT: PStruct, fromFieldIdx: Array[Int], fromRV: RegionValue) { + def selectRegionValue(fromT: PStruct, fromFieldIdx: Array[Int], fromRV: RegionValue): Unit = selectRegionValue(fromT, fromFieldIdx, fromRV.region, fromRV.offset) - } - def selectRegionValue(fromT: PStruct, fromFieldIdx: Array[Int], region: Region, offset: Long) { + def selectRegionValue(fromT: PStruct, fromFieldIdx: Array[Int], region: Region, offset: Long) + : Unit = { // too expensive! 
// val t = fromT.typeAfterSelect(fromFieldIdx) // assert(currentType().setRequired(true) == t.setRequired(true), s"${currentType()} != ${t}") @@ -343,15 +341,13 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { endStruct() } - def addRegionValue(t: PType, rv: RegionValue) { + def addRegionValue(t: PType, rv: RegionValue): Unit = addRegionValue(t, rv.region, rv.offset) - } - def addRegionValue(t: PType, fromRegion: Region, fromOff: Long) { + def addRegionValue(t: PType, fromRegion: Region, fromOff: Long): Unit = addRegionValue(t, fromOff, region.ne(fromRegion)) - } - def addRegionValue(t: PType, fromOff: Long, deepCopy: Boolean) { + def addRegionValue(t: PType, fromOff: Long, deepCopy: Boolean): Unit = { val toT = currentType() if (typestk.isEmpty) { @@ -368,7 +364,7 @@ class RegionValueBuilder(sm: HailStateManager, var region: Region) { advance() } - def addAnnotation(t: Type, a: Annotation) { + def addAnnotation(t: Type, a: Annotation): Unit = { assert(typestk.nonEmpty) if (a == null) { setMissing() diff --git a/hail/src/main/scala/is/hail/annotations/UnsafeRow.scala b/hail/src/main/scala/is/hail/annotations/UnsafeRow.scala index c17e59122ab..dcccc9e6ff8 100644 --- a/hail/src/main/scala/is/hail/annotations/UnsafeRow.scala +++ b/hail/src/main/scala/is/hail/annotations/UnsafeRow.scala @@ -1,29 +1,30 @@ package is.hail.annotations -import java.io.{ObjectInputStream, ObjectOutputStream} -import com.esotericsoftware.kryo.{Kryo, KryoSerializable} -import com.esotericsoftware.kryo.io.{Input, Output} -import is.hail.annotations.UnsafeRow.read -import is.hail.types.virtual._ import is.hail.types.physical._ +import is.hail.types.virtual._ import is.hail.utils._ import is.hail.variant.Locus + +import java.io.{ObjectInputStream, ObjectOutputStream} + +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.sql.Row import sun.reflect.generics.reflectiveObjects.NotImplementedException trait UnKryoSerializable extends KryoSerializable { - def write(kryo: Kryo, output: Output): Unit = { + def write(kryo: Kryo, output: Output): Unit = throw new NotImplementedException() - } - def read(kryo: Kryo, input: Input): Unit = { + def read(kryo: Kryo, input: Input): Unit = throw new NotImplementedException() - } } class UnsafeIndexedSeq( val t: PContainer, - val region: Region, val aoff: Long) extends IndexedSeq[Annotation] with UnKryoSerializable { + val region: Region, + val aoff: Long, +) extends IndexedSeq[Annotation] with UnKryoSerializable { val length: Int = t.loadLength(aoff) @@ -52,17 +53,17 @@ object UnsafeRow { def readString(boff: Long, t: PString): String = new String(readBinary(boff, t.binaryRepresentation)) - def readLocus(offset: Long, t: PLocus): Locus = { + def readLocus(offset: Long, t: PLocus): Locus = Locus( t.contig(offset), - t.position(offset)) - } + t.position(offset), + ) - def readNDArray(offset: Long, region: Region, nd: PNDArray): UnsafeNDArray = { + def readNDArray(offset: Long, region: Region, nd: PNDArray): UnsafeNDArray = new UnsafeNDArray(nd, region, offset) - } - def readAnyRef(t: PType, region: Region, offset: Long): AnyRef = read(t, region, offset).asInstanceOf[AnyRef] + def readAnyRef(t: PType, region: Region, offset: Long): AnyRef = + read(t, region, offset).asInstanceOf[AnyRef] def read(t: PType, region: Region, offset: Long): Any = { t match { @@ -97,15 +98,14 @@ object UnsafeRow { val includesStart = x.includesStart(offset) val includesEnd = x.includesEnd(offset) 
Interval(start, end, includesStart, includesEnd) - case nd: PNDArray => { + case nd: PNDArray => readNDArray(offset, region, nd) - } } } } -class UnsafeRow(val t: PBaseStruct, - var region: Region, var offset: Long) extends Row with UnKryoSerializable { +class UnsafeRow(val t: PBaseStruct, var region: Region, var offset: Long) + extends Row with UnKryoSerializable { override def toString: String = { if (t.isInstanceOf[PStruct]) { @@ -149,7 +149,7 @@ class UnsafeRow(val t: PBaseStruct, def this() = this(null, null, 0) - def set(newRegion: Region, newOffset: Long) { + def set(newRegion: Region, newOffset: Long): Unit = { region = newRegion offset = newOffset } @@ -158,17 +158,15 @@ class UnsafeRow(val t: PBaseStruct, def length: Int = t.size - private def assertDefined(i: Int) { + private def assertDefined(i: Int): Unit = if (isNullAt(i)) throw new NullPointerException(s"null value at index $i") - } - def get(i: Int): Any = { + def get(i: Int): Any = if (isNullAt(i)) null else UnsafeRow.read(t.types(i), region, t.loadField(offset, i)) - } def copy(): Row = new UnsafeRow(t, region, offset) @@ -210,19 +208,16 @@ class UnsafeRow(val t: PBaseStruct, !t.isFieldDefined(offset, i) } - private def writeObject(s: ObjectOutputStream): Unit = { + private def writeObject(s: ObjectOutputStream): Unit = throw new NotImplementedException() - } - private def readObject(s: ObjectInputStream): Unit = { + private def readObject(s: ObjectInputStream): Unit = throw new NotImplementedException() - } } object SafeRow { - def apply(t: PBaseStruct, off: Long): Row = { + def apply(t: PBaseStruct, off: Long): Row = Annotation.copy(t.virtualType, new UnsafeRow(t, null, off)).asInstanceOf[Row] - } def apply(t: PBaseStruct, rv: RegionValue): Row = SafeRow(t, rv.offset) @@ -237,7 +232,8 @@ object SafeRow { def read(t: PType, off: Long): Annotation = Annotation.copy(t.virtualType, UnsafeRow.read(t, null, off)) - def readAnyRef(t: PType, region: Region, offset: Long): AnyRef = read(t, offset).asInstanceOf[AnyRef] + def readAnyRef(t: PType, region: Region, offset: Long): AnyRef = + read(t, offset).asInstanceOf[AnyRef] def read(t: PType, rv: RegionValue): Annotation = read(t, rv.offset) @@ -273,36 +269,37 @@ object SafeIndexedSeq { class SelectFieldsRow( private[this] var old: Row, - private[this] val fieldMapping: Array[Int] + private[this] val fieldMapping: Array[Int], ) extends Row { def this( old: Row, oldPType: TStruct, - newPType: TStruct + newPType: TStruct, ) = this(old, newPType.fieldNames.map(name => oldPType.fieldIdx(name))) def this( old: Row, oldPType: PStruct, - newPType: PStruct - ) = { - this(old, + newPType: PStruct, + ) = + this( + old, (require( oldPType.fields.length <= old.length && newPType.fields.length <= old.length, - s"${oldPType}, ${newPType} ${old.length} $old") + s"$oldPType, $newPType ${old.length} $old", + ) -> - newPType.fieldNames.map(name => oldPType.fieldIdx(name)))._2 + newPType.fieldNames.map(name => oldPType.fieldIdx(name)))._2, ) - } - require(fieldMapping.forall(x => x < old.length), - s"${fieldMapping.toSeq}, ${old.length} $old") + require(fieldMapping.forall(x => x < old.length), s"${fieldMapping.toSeq}, ${old.length} $old") override def length = fieldMapping.length override def get(i: Int) = old.get(fieldMapping(i)) override def isNullAt(i: Int) = old.isNullAt(fieldMapping(i)) override def copy(): Row = new SelectFieldsRow(old.copy(), fieldMapping) + def set(newRow: Row): SelectFieldsRow = { old = newRow this @@ -343,7 +340,9 @@ class UnsafeNDArray(val pnd: PNDArray, val region: 
Region, val ndAddr: Long) ext val flat = new Array[Annotation](numElements.toInt) if (numElements > Int.MaxValue) { - throw new IllegalArgumentException(s"Cannot make an UnsafeNDArray with greater than Int.MaxValue entries. Shape was ${shape}") + throw new IllegalArgumentException( + s"Cannot make an UnsafeNDArray with greater than Int.MaxValue entries. Shape was $shape" + ) } while (idxIntoFlat < numElements) { @@ -388,17 +387,17 @@ class UnsafeNDArray(val pnd: PNDArray, val region: Region, val ndAddr: Long) ext true } - override def toString: String = { + override def toString: String = s"UnsafeNDArray of shape (${shape.mkString(", ")}) with elements ${getRowMajorElements()}" - } } -case class SafeNDArray(val shape: IndexedSeq[Long], rowMajorElements: IndexedSeq[Annotation]) extends NDArray { +case class SafeNDArray(val shape: IndexedSeq[Long], rowMajorElements: IndexedSeq[Annotation]) + extends NDArray { assert(shape.foldLeft(1L)(_ * _) == rowMajorElements.size) override def getRowMajorElements: IndexedSeq[Annotation] = rowMajorElements override def lookupElement(indices: IndexedSeq[Long]): Annotation = { - val flatIdx = indices.zip(shape).foldLeft(0L){ case (flatIdx, (index, shape)) => + val flatIdx = indices.zip(shape).foldLeft(0L) { case (flatIdx, (index, shape)) => flatIdx + index * shape } rowMajorElements(flatIdx.toInt) diff --git a/hail/src/main/scala/is/hail/annotations/ValueVisitor.scala b/hail/src/main/scala/is/hail/annotations/ValueVisitor.scala index 6ddd0cc80ab..dafc85a653c 100644 --- a/hail/src/main/scala/is/hail/annotations/ValueVisitor.scala +++ b/hail/src/main/scala/is/hail/annotations/ValueVisitor.scala @@ -1,8 +1,6 @@ package is.hail.annotations import is.hail.types.physical._ -import is.hail.utils.Interval -import is.hail.variant.{Call, Locus} trait ValueVisitor { def visitMissing(t: PType): Unit @@ -47,43 +45,34 @@ final class PrettyVisitor extends ValueVisitor { def result(): String = sb.result() - def visitMissing(t: PType) { + def visitMissing(t: PType): Unit = sb.append("NA") - } - def visitBoolean(b: Boolean) { + def visitBoolean(b: Boolean): Unit = sb.append(b) - } - def visitInt32(i: Int) { + def visitInt32(i: Int): Unit = sb.append(i) - } - def visitInt64(l: Long) { + def visitInt64(l: Long): Unit = sb.append(l) - } - def visitFloat32(f: Float) { + def visitFloat32(f: Float): Unit = sb.append(f) - } - def visitFloat64(d: Double) { + def visitFloat64(d: Double): Unit = sb.append(d) - } - def visitBinary(a: Array[Byte]) { + def visitBinary(a: Array[Byte]): Unit = sb.append("bytes...") - } - def visitString(s: String) { + def visitString(s: String): Unit = sb.append(s) - } - def enterStruct(t: PStruct) { + def enterStruct(t: PStruct): Unit = sb.append("{") - } - def enterField(f: PField) { + def enterField(f: PField): Unit = { if (f.index > 0) sb.append(",") sb.append(" ") @@ -91,25 +80,22 @@ final class PrettyVisitor extends ValueVisitor { sb.append(": ") } - def leaveField() {} + def leaveField(): Unit = {} - def leaveStruct() { + def leaveStruct(): Unit = sb.append(" }") - } - def enterTuple(t: PTuple) { + def enterTuple(t: PTuple): Unit = sb.append('(') - } - def leaveTuple() { + def leaveTuple(): Unit = sb.append(')') - } - def enterArray(t: PContainer, length: Int) { + def enterArray(t: PContainer, length: Int): Unit = { t match { - case t: PSet => + case _: PSet => sb.append("Set") - case t: PDict => + case _: PDict => sb.append("Dict") case _ => } @@ -118,15 +104,14 @@ final class PrettyVisitor extends ValueVisitor { sb.append(";") } - def leaveArray() 
{ + def leaveArray(): Unit = sb.append("]") - } - def enterElement(i: Int) { + def enterElement(i: Int): Unit = { if (i > 0) sb.append(",") sb.append(" ") } - def leaveElement() {} + def leaveElement(): Unit = {} } diff --git a/hail/src/main/scala/is/hail/annotations/WritableRegionValue.scala b/hail/src/main/scala/is/hail/annotations/WritableRegionValue.scala index f0b84e174e6..26d101a2ec8 100644 --- a/hail/src/main/scala/is/hail/annotations/WritableRegionValue.scala +++ b/hail/src/main/scala/is/hail/annotations/WritableRegionValue.scala @@ -1,55 +1,63 @@ package is.hail.annotations -import java.io.{ObjectInputStream, ObjectOutputStream} +import is.hail.backend.HailStateManager +import is.hail.rvd.RVDContext +import is.hail.types.physical.{PStruct, PType} import scala.collection.generic.Growable import scala.collection.mutable.{ArrayBuffer, PriorityQueue} -import is.hail.backend.HailStateManager -import is.hail.types.physical.{PStruct, PType} -import is.hail.rvd.RVDContext +import java.io.{ObjectInputStream, ObjectOutputStream} + import sun.reflect.generics.reflectiveObjects.NotImplementedException object WritableRegionValue { - def apply(sm: HailStateManager, t: PType, initial: RegionValue, region: Region): WritableRegionValue = + def apply(sm: HailStateManager, t: PType, initial: RegionValue, region: Region) + : WritableRegionValue = WritableRegionValue(sm, t, initial.region, initial.offset, region) - def apply(sm: HailStateManager, t: PType, initialOffset: Long, targetRegion: Region): WritableRegionValue = { + def apply(sm: HailStateManager, t: PType, initialOffset: Long, targetRegion: Region) + : WritableRegionValue = { val wrv = WritableRegionValue(sm, t, targetRegion) wrv.set(initialOffset, deepCopy = true) wrv } - def apply(sm: HailStateManager, t: PType, initialRegion: Region, initialOffset: Long, targetRegion: Region): WritableRegionValue = { + def apply( + sm: HailStateManager, + t: PType, + initialRegion: Region, + initialOffset: Long, + targetRegion: Region, + ): WritableRegionValue = { val wrv = WritableRegionValue(sm, t, targetRegion) wrv.set(initialRegion, initialOffset) wrv } - def apply(sm: HailStateManager, t: PType, region: Region): WritableRegionValue = { + def apply(sm: HailStateManager, t: PType, region: Region): WritableRegionValue = new WritableRegionValue(t, region, sm) - } } class WritableRegionValue private ( val t: PType, val region: Region, - sm: HailStateManager + sm: HailStateManager, ) extends UnKryoSerializable { val value = RegionValue(region, 0) private val rvb: RegionValueBuilder = new RegionValueBuilder(sm, region) def offset: Long = value.offset - def setSelect(fromT: PStruct, fromFieldIdx: Array[Int], fromRV: RegionValue) { + def setSelect(fromT: PStruct, fromFieldIdx: Array[Int], fromRV: RegionValue): Unit = setSelect(fromT, fromFieldIdx, fromRV.region, fromRV.offset) - } - def setSelect(fromT: PStruct, fromFieldIdx: Array[Int], fromRegion: Region, fromOffset: Long) { + def setSelect(fromT: PStruct, fromFieldIdx: Array[Int], fromRegion: Region, fromOffset: Long) + : Unit = setSelect(fromT, fromFieldIdx, fromOffset, region.ne(fromRegion)) - } - def setSelect(fromT: PStruct, fromFieldIdx: Array[Int], fromOffset: Long, deepCopy: Boolean) { + def setSelect(fromT: PStruct, fromFieldIdx: Array[Int], fromOffset: Long, deepCopy: Boolean) + : Unit = { (t: @unchecked) match { case t: PStruct => region.clear() @@ -67,11 +75,10 @@ class WritableRegionValue private ( def set(rv: RegionValue): Unit = set(rv.region, rv.offset) - def set(fromRegion: Region, 
fromOffset: Long) { + def set(fromRegion: Region, fromOffset: Long): Unit = set(fromOffset, region.ne(fromRegion)) - } - def set(fromOffset: Long, deepCopy: Boolean) { + def set(fromOffset: Long, deepCopy: Boolean): Unit = { region.clear() rvb.start(t) rvb.addRegionValue(t, fromOffset, deepCopy) @@ -80,18 +87,19 @@ class WritableRegionValue private ( def pretty: String = value.pretty(t) - private def writeObject(s: ObjectOutputStream): Unit = { + private def writeObject(s: ObjectOutputStream): Unit = throw new NotImplementedException() - } - private def readObject(s: ObjectInputStream): Unit = { + private def readObject(s: ObjectInputStream): Unit = throw new NotImplementedException() - } } -class RegionValuePriorityQueue(sm: HailStateManager, val t: PType, ctx: RVDContext, ord: Ordering[RegionValue]) - extends Iterable[RegionValue] -{ +class RegionValuePriorityQueue( + sm: HailStateManager, + val t: PType, + ctx: RVDContext, + ord: Ordering[RegionValue], +) extends Iterable[RegionValue] { private val queue = new PriorityQueue[RegionValue]()(ord) private val rvb = new RegionValueBuilder(sm) @@ -101,7 +109,7 @@ class RegionValuePriorityQueue(sm: HailStateManager, val t: PType, ctx: RVDConte override def head: RegionValue = queue.head - def enqueue(rv: RegionValue) { + def enqueue(rv: RegionValue): Unit = { val region = ctx.freshRegion() rvb.set(region) rvb.start(t) @@ -114,7 +122,7 @@ class RegionValuePriorityQueue(sm: HailStateManager, val t: PType, ctx: RVDConte this } - def dequeue() { + def dequeue(): Unit = { val popped = queue.dequeue() popped.region.close() } @@ -123,7 +131,7 @@ class RegionValuePriorityQueue(sm: HailStateManager, val t: PType, ctx: RVDConte } class RegionValueArrayBuffer(val t: PType, region: Region, sm: HailStateManager) - extends Iterable[RegionValue] with Growable[RegionValue] { + extends Iterable[RegionValue] with Growable[RegionValue] { val value = RegionValue(region, 0) @@ -132,9 +140,8 @@ class RegionValueArrayBuffer(val t: PType, region: Region, sm: HailStateManager) def length = idx.length - def +=(rv: RegionValue): this.type = { + def +=(rv: RegionValue): this.type = this.append(rv.region, rv.offset) - } def append(fromRegion: Region, fromOffset: Long): this.type = { rvb.start(t) @@ -146,7 +153,8 @@ class RegionValueArrayBuffer(val t: PType, region: Region, sm: HailStateManager) def appendSelect( fromT: PStruct, fromFieldIdx: Array[Int], - fromRV: RegionValue): this.type = { + fromRV: RegionValue, + ): this.type = { (t: @unchecked) match { case t: PStruct => @@ -157,19 +165,21 @@ class RegionValueArrayBuffer(val t: PType, region: Region, sm: HailStateManager) this } - def clear() { + def clear(): Unit = { region.clear() idx.clear() rvb.clear() // remove } private var itIdx = 0 + private val it = new Iterator[RegionValue] { def next(): RegionValue = { value.setOffset(idx(itIdx)) itIdx += 1 value } + def hasNext: Boolean = itIdx < idx.size } diff --git a/hail/src/main/scala/is/hail/annotations/package.scala b/hail/src/main/scala/is/hail/annotations/package.scala index f0afdea37e9..c6dd8268188 100644 --- a/hail/src/main/scala/is/hail/annotations/package.scala +++ b/hail/src/main/scala/is/hail/annotations/package.scala @@ -8,7 +8,7 @@ package annotations { package object annotations { - type Annotation = Any + type Annotation = Any type Deleter = (Annotation) => Annotation diff --git a/hail/src/main/scala/is/hail/asm4s/AsmFunction.scala b/hail/src/main/scala/is/hail/asm4s/AsmFunction.scala index a6a606bcda0..b08cde595c4 100644 --- 
a/hail/src/main/scala/is/hail/asm4s/AsmFunction.scala +++ b/hail/src/main/scala/is/hail/asm4s/AsmFunction.scala @@ -1,23 +1,64 @@ package is.hail.asm4s -import is.hail.annotations.{Region, RegionValue} +import is.hail.annotations.Region trait AsmFunction0[R] { def apply(): R } -trait AsmFunction1[A,R] { def apply(a: A): R } -trait AsmFunction2[A,B,R] { def apply(a: A, b: B): R } -trait AsmFunction3[A,B,C,R] { def apply(a: A, b: B, c: C): R } -trait AsmFunction4[A,B,C,D,R] { def apply(a: A, b: B, c: C, d: D): R } -trait AsmFunction5[A,B,C,D,E,R] { def apply(a: A, b: B, c: C, d: D, e: E): R } -trait AsmFunction6[A,B,C,D,E,F,R] { def apply(a: A, b: B, c: C, d: D, e: E, f: F): R } -trait AsmFunction7[A,B,C,D,E,F,G,R] { def apply(a: A, b: B, c: C, d: D, e: E, f: F, g: G): R } -trait AsmFunction8[A,B,C,D,E,F,G,H,R] { def apply(a: A, b: B, c: C, d: D, e: E, f: F, g: G, h: H): R } -trait AsmFunction9[A,B,C,D,E,F,G,H,I,R] { def apply(a: A, b: B, c: C, d: D, e: E, f: F, g: G, h: H, i: I): R } -trait AsmFunction10[A,B,C,D,E,F,G,H,I,J,R] { def apply(a: A, b: B, c: C, d: D, e: E, f: F, g: G, h: H, i: I, j: J): R } -trait AsmFunction12[T1,T2,T3,T4,T5,T6,T7,T8,T9,T10,T11,T12,R] { - def apply(t1: T1, t2: T2, t3: T3, t4: T4, t5: T5, t6: T6, t7: T7, t8: T8, t9: T9, t10: T10, t11: T11, t12: T12): R -} -trait AsmFunction13[T1,T2,T3,T4,T5,T6,T7,T8,T9,T10,T11,T12,T13,R] { - def apply(t1: T1, t2: T2, t3: T3, t4: T4, t5: T5, t6: T6, t7: T7, t8: T8, t9: T9, t10: T10, t11: T11, t12: T12, t13: T13): R +trait AsmFunction1[A, R] { def apply(a: A): R } +trait AsmFunction2[A, B, R] { def apply(a: A, b: B): R } +trait AsmFunction3[A, B, C, R] { def apply(a: A, b: B, c: C): R } +trait AsmFunction4[A, B, C, D, R] { def apply(a: A, b: B, c: C, d: D): R } +trait AsmFunction5[A, B, C, D, E, R] { def apply(a: A, b: B, c: C, d: D, e: E): R } +trait AsmFunction6[A, B, C, D, E, F, R] { def apply(a: A, b: B, c: C, d: D, e: E, f: F): R } + +trait AsmFunction7[A, B, C, D, E, F, G, R] { + def apply(a: A, b: B, c: C, d: D, e: E, f: F, g: G): R +} + +trait AsmFunction8[A, B, C, D, E, F, G, H, R] { + def apply(a: A, b: B, c: C, d: D, e: E, f: F, g: G, h: H): R +} + +trait AsmFunction9[A, B, C, D, E, F, G, H, I, R] { + def apply(a: A, b: B, c: C, d: D, e: E, f: F, g: G, h: H, i: I): R +} + +trait AsmFunction10[A, B, C, D, E, F, G, H, I, J, R] { + def apply(a: A, b: B, c: C, d: D, e: E, f: F, g: G, h: H, i: I, j: J): R +} + +trait AsmFunction12[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, R] { + def apply( + t1: T1, + t2: T2, + t3: T3, + t4: T4, + t5: T5, + t6: T6, + t7: T7, + t8: T8, + t9: T9, + t10: T10, + t11: T11, + t12: T12, + ): R +} + +trait AsmFunction13[T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, R] { + def apply( + t1: T1, + t2: T2, + t3: T3, + t4: T4, + t5: T5, + t6: T6, + t7: T7, + t8: T8, + t9: T9, + t10: T10, + t11: T11, + t12: T12, + t13: T13, + ): R } trait AsmFunction1RegionUnit { diff --git a/hail/src/main/scala/is/hail/asm4s/ClassBuilder.scala b/hail/src/main/scala/is/hail/asm4s/ClassBuilder.scala index 7ce44085666..7b793fa365d 100644 --- a/hail/src/main/scala/is/hail/asm4s/ClassBuilder.scala +++ b/hail/src/main/scala/is/hail/asm4s/ClassBuilder.scala @@ -3,22 +3,30 @@ package is.hail.asm4s import is.hail.expr.ir.EmitCodeBuilder import is.hail.lir import is.hail.utils._ + +import scala.collection.mutable +import scala.language.existentials + +import java.io._ +import java.nio.charset.StandardCharsets + +import javassist.bytecode.DuplicateMemberException import org.apache.spark.TaskContext import 
org.objectweb.asm.ClassReader -import org.objectweb.asm.Opcodes._ -import org.objectweb.asm.tree._ import org.objectweb.asm.util.{Textifier, TraceClassVisitor} -import java.io._ -import java.nio.charset.StandardCharsets -import scala.collection.mutable +object Field { + def apply[T](cb: ClassBuilder[_], name: String)(implicit ti: TypeInfo[T]): Field[T] = + new Field[T](cb.lclass.newField(name, ti)) +} -class Field[T: TypeInfo](classBuilder: ClassBuilder[_], val name: String) { - val ti: TypeInfo[T] = implicitly +class Field[T] private (private val lf: lir.Field) extends AnyVal { - val lf: lir.Field = classBuilder.lclass.newField(name, typeInfo[T]) + def ti: TypeInfo[T] = + lf.ti.asInstanceOf[TypeInfo[T]] - def get(obj: Code[_]): Code[T] = Code(obj, lir.getField(lf)) + def name: String = + lf.name def get(obj: Value[_]): Value[T] = new Value[T] { override def get: Code[T] = Code(obj, lir.getField(lf)) @@ -36,10 +44,18 @@ class Field[T: TypeInfo](classBuilder: ClassBuilder[_], val name: String) { } } -class StaticField[T: TypeInfo](classBuilder: ClassBuilder[_], val name: String) { - val ti: TypeInfo[T] = implicitly +object StaticField { + def apply[T](cb: ClassBuilder[_], name: String)(implicit ti: TypeInfo[T]): StaticField[T] = + new StaticField[T](cb.lclass.newStaticField(name, ti)) +} + +case class StaticField[T] private (lf: lir.StaticField) extends AnyVal { + + def ti: TypeInfo[T] = + lf.ti.asInstanceOf[TypeInfo[T]] - val lf: lir.StaticField = classBuilder.lclass.newStaticField(name, typeInfo[T]) + def name: String = + lf.name def get(): Code[T] = Code(lir.getStaticField(lf)) @@ -59,14 +75,14 @@ class ClassesBytes(classesBytes: Array[(String, Array[Byte])]) extends Serializa synchronized { if (!loaded) { classesBytes.foreach { case (n, bytes) => - try { + try hcl.loadOrDefineClass(n, bytes) - } catch { + catch { case e: Exception => val buffer = new ByteArrayOutputStream() FunctionBuilder.bytesToBytecodeString(bytes, buffer) val classJVMByteCodeAsEscapedStr = buffer.toString(StandardCharsets.UTF_8.name()) - log.error(s"Failed to load bytecode ${e}:\n" + classJVMByteCodeAsEscapedStr) + log.error(s"Failed to load bytecode $e:\n" + classJVMByteCodeAsEscapedStr) throw e } } @@ -77,12 +93,16 @@ class ClassesBytes(classesBytes: Array[(String, Array[Byte])]) extends Serializa } } -class AsmTuple[C](val cb: ClassBuilder[C], val fields: IndexedSeq[Field[_]], val ctor: MethodBuilder[C]) { +class AsmTuple[C]( + val cb: ClassBuilder[C], + val fields: IndexedSeq[Field[_]], + val ctor: MethodBuilder[C], +) { val ti: TypeInfo[_] = cb.ti def newTuple(elems: IndexedSeq[Code[_]]): Code[C] = Code.newInstance(cb, ctor, elems) - def loadElementsAny(t: Value[_]): IndexedSeq[Value[_]] = fields.map(_.get(coerce[C](t) )) + def loadElementsAny(t: Value[_]): IndexedSeq[Value[_]] = fields.map(_.get(coerce[C](t))) def loadElements(t: Value[C]): IndexedSeq[Value[_]] = fields.map(_.get(t)) } @@ -92,15 +112,18 @@ trait WrappedModuleBuilder { def newClass[C](name: String)(implicit cti: TypeInfo[C]): ClassBuilder[C] = modb.newClass[C](name) - def genClass[C](baseName: String)(implicit cti: TypeInfo[C]): ClassBuilder[C] = modb.genClass[C](baseName) + def genClass[C](baseName: String)(implicit cti: TypeInfo[C]): ClassBuilder[C] = + modb.genClass[C](baseName) - def classesBytes(writeIRs: Boolean, print: Option[PrintWriter] = None): ClassesBytes = modb.classesBytes(writeIRs, print) + def classesBytes(writeIRs: Boolean, print: Option[PrintWriter] = None): ClassesBytes = + modb.classesBytes(writeIRs, print) } class 
ModuleBuilder() { val classes = new mutable.ArrayBuffer[ClassBuilder[_]]() - def newClass[C](name: String, sourceFile: Option[String] = None)(implicit cti: TypeInfo[C]): ClassBuilder[C] = { + def newClass[C](name: String, sourceFile: Option[String] = None)(implicit cti: TypeInfo[C]) + : ClassBuilder[C] = { val c = new ClassBuilder[C](this, name, sourceFile) if (cti != UnitInfo) c.addInterface(cti.iname) @@ -111,34 +134,27 @@ class ModuleBuilder() { private val tuples = mutable.Map[IndexedSeq[TypeInfo[_]], AsmTuple[_]]() def tupleClass(fieldTypes: IndexedSeq[TypeInfo[_]]): AsmTuple[_] = { - tuples.getOrElseUpdate(fieldTypes, { - val kb = genClass[Unit](s"Tuple${fieldTypes.length}") - val fields = fieldTypes.zipWithIndex.map { case (ti, i) => - kb.newField(s"_$i")(ti) - } - val ctor = kb.newMethod("", fieldTypes, UnitInfo) - ctor.emitWithBuilder { cb => - // FIXME, maybe a more elegant way to do this? - val L = new lir.Block() - L.append( - lir.methodStmt(INVOKESPECIAL, - "java/lang/Object", - "", - "()V", - false, - UnitInfo, - FastSeq(lir.load(ctor._this.asInstanceOf[LocalRef[_]].l)))) - cb += new VCode(L, L, null) - fields.zipWithIndex.foreach { case (f, i) => - cb += f.putAny(ctor._this, ctor.getArg(i + 1)(f.ti).get) + tuples.getOrElseUpdate( + fieldTypes, { + val kb = genClass[Unit](s"Tuple${fieldTypes.length}") + val fields = fieldTypes.zipWithIndex.map { case (ti, i) => + kb.newField(s"_$i")(ti) } - Code._empty - } - new AsmTuple(kb, fields, ctor) - }) + val ctor = kb.newMethod("", fieldTypes, UnitInfo) + ctor.emitWithBuilder { cb => + cb += kb.super_.invoke(coerce[Object](cb.this_), Array()) + fields.zipWithIndex.foreach { case (f, i) => + cb += f.putAny(ctor.this_, ctor.getArg(i + 1)(f.ti).get) + } + Code._empty + } + new AsmTuple(kb, fields, ctor) + }, + ) } - def genClass[C](baseName: String)(implicit cti: TypeInfo[C]): ClassBuilder[C] = newClass[C](genName("C", baseName)) + def genClass[C](baseName: String)(implicit cti: TypeInfo[C]): ClassBuilder[C] = + newClass[C](genName("C", baseName)) var classesBytes: ClassesBytes = _ @@ -148,7 +164,8 @@ class ModuleBuilder() { classes .iterator .flatMap(c => c.classBytes(writeIRs, print)) - .toArray) + .toArray + ) } classesBytes @@ -159,13 +176,12 @@ class ModuleBuilder() { private var nStaticFieldsOnThisClass: Int = maxFieldsOrMethodsOnClass private var staticCls: ClassBuilder[_] = null - private def incrStaticClassSize(n: Int = 1): Unit = { + private def incrStaticClassSize(n: Int = 1): Unit = if (nStaticFieldsOnThisClass + n >= maxFieldsOrMethodsOnClass) { nStaticFieldsOnThisClass = n staticFieldWrapperIdx += 1 staticCls = genClass[Unit](s"staticWrapperClass_$staticFieldWrapperIdx") } - } def genStaticField[T: TypeInfo](name: String = null): StaticFieldRef[T] = { incrStaticClassSize() @@ -176,11 +192,10 @@ class ModuleBuilder() { var _objectsField: Settable[Array[AnyRef]] = _ var _objects: BoxedArrayBuilder[AnyRef] = _ - def setObjects(cb: EmitCodeBuilder, objects: Code[Array[AnyRef]]): Unit = { + def setObjects(cb: EmitCodeBuilder, objects: Code[Array[AnyRef]]): Unit = cb.assign(_objectsField, objects) - } - def getObject[T <: AnyRef : TypeInfo](obj: T): Code[T] = { + def getObject[T <: AnyRef: TypeInfo](obj: T): Code[T] = { if (_objectsField == null) { _objectsField = genStaticField[Array[AnyRef]]() _objects = new BoxedArrayBuilder[AnyRef]() @@ -211,95 +226,116 @@ trait WrappedClassBuilder[C] extends WrappedModuleBuilder { def newStaticField[T: TypeInfo](name: String): StaticField[T] = cb.newStaticField[T](name) - def 
newStaticField[T: TypeInfo](name: String, init: Code[T]): StaticField[T] = cb.newStaticField[T](name, init) + def newStaticField[T: TypeInfo](name: String, init: Code[T]): StaticField[T] = + cb.newStaticField[T](name, init) def genField[T: TypeInfo](baseName: String): Field[T] = cb.genField(baseName) - def genFieldThisRef[T: TypeInfo](name: String = null): ThisFieldRef[T] = cb.genFieldThisRef[T](name) + def getField[T: TypeInfo](name: String): Field[T] = cb.getField(name) - def genLazyFieldThisRef[T: TypeInfo](setup: Code[T], name: String = null): Value[T] = cb.genLazyFieldThisRef(setup, name) + def genFieldThisRef[T: TypeInfo](name: String = null): ThisFieldRef[T] = + cb.genFieldThisRef[T](name) + + def genLazyFieldThisRef[T: TypeInfo](setup: Code[T], name: String = null): Value[T] = + cb.genLazyFieldThisRef(setup, name) - def getOrDefineLazyField[T: TypeInfo](setup: Code[T], id: Any): Value[T] = cb.getOrDefineLazyField(setup, id) + def getOrDefineLazyField[T: TypeInfo](setup: Code[T], id: Any): Value[T] = + cb.getOrDefineLazyField(setup, id) def fieldBuilder: SettableBuilder = cb.fieldBuilder - def newMethod(name: String, parameterTypeInfo: IndexedSeq[TypeInfo[_]], returnTypeInfo: TypeInfo[_]): MethodBuilder[C] = + def newMethod( + name: String, + parameterTypeInfo: IndexedSeq[TypeInfo[_]], + returnTypeInfo: TypeInfo[_], + ): MethodBuilder[C] = cb.newMethod(name, parameterTypeInfo, returnTypeInfo) - def newMethod(name: String, + def newMethod( + name: String, maybeGenericParameterTypeInfo: IndexedSeq[MaybeGenericTypeInfo[_]], - maybeGenericReturnTypeInfo: MaybeGenericTypeInfo[_]): MethodBuilder[C] = + maybeGenericReturnTypeInfo: MaybeGenericTypeInfo[_], + ): MethodBuilder[C] = cb.newMethod(name, maybeGenericParameterTypeInfo, maybeGenericReturnTypeInfo) - def newStaticMethod(name: String, parameterTypeInfo: IndexedSeq[TypeInfo[_]], returnTypeInfo: TypeInfo[_]): MethodBuilder[C] = + def newStaticMethod( + name: String, + parameterTypeInfo: IndexedSeq[TypeInfo[_]], + returnTypeInfo: TypeInfo[_], + ): MethodBuilder[C] = cb.newStaticMethod(name, parameterTypeInfo, returnTypeInfo) - def getOrGenMethod( - baseName: String, key: Any, argsInfo: IndexedSeq[TypeInfo[_]], returnInfo: TypeInfo[_] - )(body: MethodBuilder[C] => Unit): MethodBuilder[C] = - cb.getOrGenMethod(baseName, key, argsInfo, returnInfo)(body) - - def result(writeIRs: Boolean, print: Option[PrintWriter] = None): (HailClassLoader) => C = cb.result(writeIRs, print) - - def _this: Value[C] = cb._this + def result(writeIRs: Boolean, print: Option[PrintWriter] = None): (HailClassLoader) => C = + cb.result(writeIRs, print) - def genMethod(baseName: String, argsInfo: IndexedSeq[TypeInfo[_]], returnInfo: TypeInfo[_]): MethodBuilder[C] = + def genMethod(baseName: String, argsInfo: IndexedSeq[TypeInfo[_]], returnInfo: TypeInfo[_]) + : MethodBuilder[C] = cb.genMethod(baseName, argsInfo, returnInfo) def genMethod[R: TypeInfo](baseName: String): MethodBuilder[C] = cb.genMethod[R](baseName) - def genMethod[A: TypeInfo, R: TypeInfo](baseName: String): MethodBuilder[C] = cb.genMethod[A, R](baseName) + def genMethod[A: TypeInfo, R: TypeInfo](baseName: String): MethodBuilder[C] = + cb.genMethod[A, R](baseName) - def genMethod[A1: TypeInfo, A2: TypeInfo, R: TypeInfo](baseName: String): MethodBuilder[C] = cb.genMethod[A1, A2, R](baseName) + def genMethod[A1: TypeInfo, A2: TypeInfo, R: TypeInfo](baseName: String): MethodBuilder[C] = + cb.genMethod[A1, A2, R](baseName) - def genMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, R: TypeInfo](baseName: 
String): MethodBuilder[C] = cb.genMethod[A1, A2, A3, R](baseName) + def genMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, R: TypeInfo](baseName: String) + : MethodBuilder[C] = cb.genMethod[A1, A2, A3, R](baseName) - def genMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, R: TypeInfo](baseName: String): MethodBuilder[C] = cb.genMethod[A1, A2, A3, A4, R](baseName) + def genMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, R: TypeInfo]( + baseName: String + ): MethodBuilder[C] = cb.genMethod[A1, A2, A3, A4, R](baseName) - def genMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, A5: TypeInfo, R: TypeInfo](baseName: String): MethodBuilder[C] = cb.genMethod[A1, A2, A3, A4, A5, R](baseName) + def genMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, A5: TypeInfo, R: TypeInfo]( + baseName: String + ): MethodBuilder[C] = cb.genMethod[A1, A2, A3, A4, A5, R](baseName) - def genStaticMethod(name: String, parameterTypeInfo: IndexedSeq[TypeInfo[_]], returnTypeInfo: TypeInfo[_]): MethodBuilder[C] = + def genStaticMethod( + name: String, + parameterTypeInfo: IndexedSeq[TypeInfo[_]], + returnTypeInfo: TypeInfo[_], + ): MethodBuilder[C] = cb.genStaticMethod(name, parameterTypeInfo, returnTypeInfo) } class ClassBuilder[C]( val modb: ModuleBuilder, val className: String, - val sourceFile: Option[String] + val sourceFile: Option[String], ) extends WrappedModuleBuilder { val ti: ClassInfo[C] = new ClassInfo[C](className) val lclass = new lir.Classx[C](className, "java/lang/Object", sourceFile) - val methods: mutable.ArrayBuffer[MethodBuilder[C]] = new mutable.ArrayBuffer[MethodBuilder[C]](16) - val fields: mutable.ArrayBuffer[FieldNode] = new mutable.ArrayBuffer[FieldNode](16) - - val lazyFieldMemo: mutable.Map[Any, Value[_]] = mutable.Map.empty - - val lInitBuilder = new MethodBuilder[C](this, "", FastSeq(), UnitInfo) - val lInit = lInitBuilder.lmethod - - var initBody: Code[Unit] = { - val L = new lir.Block() - L.append( - lir.methodStmt(INVOKESPECIAL, - "java/lang/Object", - "", - "()V", - false, - UnitInfo, - FastSeq(lir.load(lInit.getParam(0))))) - new VCode(L, L, null) - } + private[this] val methods: mutable.ArrayBuffer[MethodBuilder[C]] = + new mutable.ArrayBuffer[MethodBuilder[C]](16) + + private[this] val fields: mutable.Map[String, Either[StaticField[_], Field[_]]] = + new mutable.HashMap() - private var lClinit: lir.Method = _ + private[this] val lazyFieldMemo: mutable.Map[Any, Value[_]] = + mutable.Map.empty - var clinitBody: Option[Code[Unit]] = None + private[this] val lInitBuilder = new MethodBuilder[C](this, "", FastSeq(), UnitInfo) + private[this] val lInit = lInitBuilder.lmethod - def emitInit(c: Code[Unit]): Unit = { + val super_ : Invokeable[Object, Unit] = + Invokeable(classOf[Object], classOf[Object].getConstructor()) + + private[this] var initBody: Code[Unit] = + super_.invoke(coerce[Object](this_), Array()) + + private[this] var lClinit: lir.Method = _ + + private[this] var clinitBody: Option[Code[Unit]] = None + + def ctor: MethodBuilder[C] = + lInitBuilder + + def emitInit(c: Code[Unit]): Unit = initBody = Code(initBody, c) - } def emitInitI(f: CodeBuilder => Unit): Unit = { val body = CodeBuilder.scopedVoid(lInitBuilder)(f) @@ -316,66 +352,125 @@ class ClassBuilder[C]( } } - def addInterface(name: String): Unit = lclass.addInterface(name) + def addInterface(name: String): Unit = + lclass.addInterface(name) + + def lookupMethod( + name: String, + paramsTyInfo: IndexedSeq[TypeInfo[_]], + retTyInfo: TypeInfo[_], + isStatic: Boolean, 
+ ): Option[MethodBuilder[C]] = + methods.find { m => + m.methodName == name && + m.parameterTypeInfo == paramsTyInfo && + m.returnTypeInfo == retTyInfo && + m.isStatic == isStatic + } + + def newMethod( + name: String, + parameterTypeInfo: IndexedSeq[TypeInfo[_]], + returnTypeInfo: TypeInfo[_], + ): MethodBuilder[C] = { + if (lookupMethod(name, parameterTypeInfo, returnTypeInfo, isStatic = false).isDefined) { + val signature = s"${parameterTypeInfo.mkString("(", ",", ")")} => $returnTypeInfo" + throw new DuplicateMemberException( + s"Method '$name: $signature' already defined in class '$className'." + ) + } - def newMethod(name: String, parameterTypeInfo: IndexedSeq[TypeInfo[_]], returnTypeInfo: TypeInfo[_]): MethodBuilder[C] = { val mb = new MethodBuilder[C](this, name, parameterTypeInfo, returnTypeInfo) - methods.append(mb) + methods += mb mb } - def newMethod(name: String, + def newStaticMethod( + name: String, + parameterTypeInfo: IndexedSeq[TypeInfo[_]], + returnTypeInfo: TypeInfo[_], + ): MethodBuilder[C] = { + if (lookupMethod(name, parameterTypeInfo, returnTypeInfo, isStatic = true).isDefined) { + val signature = s"${parameterTypeInfo.mkString("(", ",", ")")} => $returnTypeInfo" + throw new DuplicateMemberException( + s"Static method '$name: $signature' already defined in class '$className'." + ) + } + + val mb = new MethodBuilder[C](this, name, parameterTypeInfo, returnTypeInfo, isStatic = true) + methods += mb + mb + } + + def newMethod( + name: String, maybeGenericParameterTypeInfo: IndexedSeq[MaybeGenericTypeInfo[_]], - maybeGenericReturnTypeInfo: MaybeGenericTypeInfo[_]): MethodBuilder[C] = { + maybeGenericReturnTypeInfo: MaybeGenericTypeInfo[_], + ): MethodBuilder[C] = { val parameterTypeInfo: IndexedSeq[TypeInfo[_]] = maybeGenericParameterTypeInfo.map(_.base) val returnTypeInfo: TypeInfo[_] = maybeGenericReturnTypeInfo.base val m = newMethod(name, parameterTypeInfo, returnTypeInfo) if (maybeGenericParameterTypeInfo.exists(_.isGeneric) || maybeGenericReturnTypeInfo.isGeneric) { - val generic = newMethod(name, maybeGenericParameterTypeInfo.map(_.generic), maybeGenericReturnTypeInfo.generic) + val generic = newMethod( + name, + maybeGenericParameterTypeInfo.map(_.generic), + maybeGenericReturnTypeInfo.generic, + ) generic.emitWithBuilder { cb => - maybeGenericReturnTypeInfo.castToGeneric(cb, - m.invoke(cb, maybeGenericParameterTypeInfo.zipWithIndex.map { case (ti, i) => - ti.castFromGeneric(cb, generic.getArg(i + 1)(ti.generic)) - }: _*)) -} + maybeGenericReturnTypeInfo.castToGeneric( + cb, + cb.invoke( + m, + cb.mb.cb.this_ +: maybeGenericParameterTypeInfo.zipWithIndex.map { case (ti, i) => + ti.castFromGeneric(cb, generic.getArg(i + 1)(ti.generic)) + }: _* + ), + ) + } } m } - def newStaticMethod(name: String, parameterTypeInfo: IndexedSeq[TypeInfo[_]], returnTypeInfo: TypeInfo[_]): MethodBuilder[C] = { - val mb = new MethodBuilder[C](this, name, parameterTypeInfo, returnTypeInfo, isStatic = true) - methods.append(mb) - mb - } + private def raiseIfFieldExists(name: String): Unit = + fields.get(name).foreach { f => + val (static_, name, ti) = f.fold(f => ("Static ", f.name, f.ti), f => ("", f.name, f.ti)) + throw new DuplicateMemberException( + s"${static_}Field '$name: $ti' already defined in '$className'." 
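When a parameter or return type is generic, newMethod also emits a bridge: a boxed entry point that casts its arguments (castFromGeneric), delegates to the specialized method, and re-boxes the result (castToGeneric). A minimal plain-Scala analogue of that shape, illustrative only and not the generated code:

    // `applyGeneric` plays the role of the generated bridge: it accepts and returns
    // boxed values, casting on the way in and boxing on the way out, while `apply`
    // is the specialized method it delegates to.
    trait BoxedFn { def applyGeneric(a: AnyRef): AnyRef }

    final class AddOne extends BoxedFn {
      def apply(a: Int): Int = a + 1
      def applyGeneric(a: AnyRef): AnyRef = Int.box(apply(Int.unbox(a)))
    }

    object BridgeDemo extends App {
      val f = new AddOne
      println(f.apply(41))                  // 42, specialized call
      println(f.applyGeneric(Int.box(41)))  // 42, through the bridge
    }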
+ ) + } - def newField[T: TypeInfo](name: String): Field[T] = new Field[T](this, name) + def newField[T](name: String)(implicit ty: TypeInfo[T]): Field[T] = { + raiseIfFieldExists(name) + val field = Field[T](this, name) + fields += name -> Right(field) + field + } - def newStaticField[T: TypeInfo](name: String): StaticField[T] = new StaticField[T](this, name) + def newStaticField[T](name: String)(implicit ty: TypeInfo[T]): StaticField[T] = { + raiseIfFieldExists(name) + val field = StaticField[T](this, name) + fields += name -> Left(field) + field + } def newStaticField[T: TypeInfo](name: String, init: Code[T]): StaticField[T] = { - val f = new StaticField[T](this, name) + val f = newStaticField[T](name) emitClinit(f.put(init)) f } - def genField[T: TypeInfo](baseName: String): Field[T] = newField(genName("f", baseName)) + def genField[T: TypeInfo](baseName: String): Field[T] = + newField(genName("f", baseName)) - private[this] val methodMemo: mutable.Map[Any, MethodBuilder[C]] = mutable.HashMap.empty - - def getOrGenMethod(baseName: String, key: Any, argsInfo: IndexedSeq[TypeInfo[_]], returnInfo: TypeInfo[_]) - (f: MethodBuilder[C] => Unit): MethodBuilder[C] = { - methodMemo.get(key) match { - case Some(mb) => mb - case None => - val mb = newMethod(genName("M", baseName), argsInfo, returnInfo) - f(mb) - methodMemo(key) = mb - mb - } - } + def getField[T](name: String)(implicit ti: TypeInfo[T]): Field[T] = + fields.get(name).fold(Option.empty[Field[T]]) { + case Right(field) if field.ti == ti => Some(field.asInstanceOf[Field[T]]) + case _ => None + }.getOrElse(throw new NoSuchFieldError(s"No field matching '$name: $ti' in '$className'.")) - def classBytes(writeIRs: Boolean, print: Option[PrintWriter] = None): Array[(String, Array[Byte])] = { + def classBytes(writeIRs: Boolean, print: Option[PrintWriter] = None) + : Array[(String, Array[Byte])] = { assert(initBody.start != null) initBody.end.append(lir.returnx()) lInit.setEntry(initBody.start) @@ -393,14 +488,16 @@ class ClassBuilder[C]( lclass.asBytes(writeIRs, print) } - def result(writeIRs: Boolean, print: Option[PrintWriter] = None): (HailClassLoader) => C = { + def result(writeIRs: Boolean, print: Option[PrintWriter] = None): HailClassLoader => C = { val n = className.replace("/", ".") val classesBytes = modb.classesBytes(writeIRs) - assert(TaskContext.get() == null, - "FunctionBuilder emission should happen on master, but happened on worker") + assert( + TaskContext.get() == null, + "FunctionBuilder emission should happen on master, but happened on worker", + ) - new ((HailClassLoader) => C) with java.io.Serializable { + new (HailClassLoader => C) with java.io.Serializable { @transient @volatile private var theClass: Class[_] = null def apply(hcl: HailClassLoader): C = { @@ -418,10 +515,12 @@ class ClassBuilder[C]( } } - def _this: Value[C] = new LocalRef[C](new lir.Parameter(null, 0, ti)) + def this_ : Value[C] = + new LocalRef[C](new lir.Parameter(null, 0, ti)) val fieldBuilder: SettableBuilder = new SettableBuilder { - def newSettable[T](name: String)(implicit tti: TypeInfo[T]): Settable[T] = genFieldThisRef[T](name) + def newSettable[T](name: String)(implicit tti: TypeInfo[T]): Settable[T] = + genFieldThisRef[T](name) } def genFieldThisRef[T: TypeInfo](name: String = null): ThisFieldRef[T] = @@ -430,11 +529,13 @@ class ClassBuilder[C]( def genLazyFieldThisRef[T: TypeInfo](setup: Code[T], name: String = null): Value[T] = new ThisLazyFieldRef[T](this, name, setup) - def getOrDefineLazyField[T: TypeInfo](setup: Code[T], id: Any): 
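newField and newStaticField now record every declaration in a name-keyed map, so re-declaring a name fails fast (raiseIfFieldExists) and getField only returns a field whose declared TypeInfo matches the requested one, throwing NoSuchFieldError otherwise. A stripped-down sketch of that contract using plain runtime classes in place of TypeInfo; the names here are illustrative, not the builder's API:

    import scala.collection.mutable
    import scala.reflect.ClassTag

    final class FieldRegistry {
      private val fields = mutable.HashMap.empty[String, Class[_]]

      // Declaring the same name twice fails fast, like raiseIfFieldExists.
      def newField[T](name: String)(implicit ct: ClassTag[T]): Unit = {
        if (fields.contains(name))
          throw new IllegalStateException(s"Field '$name' already defined.")
        fields += name -> ct.runtimeClass
      }

      // Lookup succeeds only when the stored type equals the requested one,
      // mirroring getField's TypeInfo equality check.
      def getField[T](name: String)(implicit ct: ClassTag[T]): Class[_] =
        fields.get(name).filter(_ == ct.runtimeClass).getOrElse(
          throw new NoSuchFieldError(s"No field matching '$name: ${ct.runtimeClass.getName}'.")
        )
    }

    object FieldRegistryDemo extends App {
      val r = new FieldRegistry
      r.newField[Int]("counter")
      println(r.getField[Int]("counter"))  // int
      // r.getField[String]("counter")     // NoSuchFieldError: type mismatch
      // r.newField[Long]("counter")       // IllegalStateException: already defined
    }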
Value[T] = { - lazyFieldMemo.getOrElseUpdate(id, genLazyFieldThisRef[T](setup)).asInstanceOf[ThisLazyFieldRef[T]] - } + def getOrDefineLazyField[T: TypeInfo](setup: Code[T], id: Any): Value[T] = + lazyFieldMemo.getOrElseUpdate(id, genLazyFieldThisRef[T](setup)).asInstanceOf[ThisLazyFieldRef[ + T + ]] - def genMethod(baseName: String, argsInfo: IndexedSeq[TypeInfo[_]], returnInfo: TypeInfo[_]): MethodBuilder[C] = + def genMethod(baseName: String, argsInfo: IndexedSeq[TypeInfo[_]], returnInfo: TypeInfo[_]) + : MethodBuilder[C] = newMethod(genName("m", baseName), argsInfo, returnInfo) def genMethod[R: TypeInfo](baseName: String): MethodBuilder[C] = @@ -446,21 +547,35 @@ class ClassBuilder[C]( def genMethod[A1: TypeInfo, A2: TypeInfo, R: TypeInfo](baseName: String): MethodBuilder[C] = genMethod(baseName, FastSeq[TypeInfo[_]](typeInfo[A1], typeInfo[A2]), typeInfo[R]) - def genMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, R: TypeInfo](baseName: String): MethodBuilder[C] = + def genMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, R: TypeInfo](baseName: String) + : MethodBuilder[C] = genMethod(baseName, FastSeq[TypeInfo[_]](typeInfo[A1], typeInfo[A2], typeInfo[A3]), typeInfo[R]) - def genMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, R: TypeInfo](baseName: String): MethodBuilder[C] = - genMethod(baseName, FastSeq[TypeInfo[_]](typeInfo[A1], typeInfo[A2], typeInfo[A3], typeInfo[A4]), typeInfo[R]) - - def genMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, A5: TypeInfo, R: TypeInfo](baseName: String): MethodBuilder[C] = - genMethod(baseName, FastSeq[TypeInfo[_]](typeInfo[A1], typeInfo[A2], typeInfo[A3], typeInfo[A4], typeInfo[A5]), typeInfo[R]) - - def genStaticMethod(baseName: String, argsInfo: IndexedSeq[TypeInfo[_]], returnInfo: TypeInfo[_]): MethodBuilder[C] = + def genMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, R: TypeInfo]( + baseName: String + ): MethodBuilder[C] = + genMethod( + baseName, + FastSeq[TypeInfo[_]](typeInfo[A1], typeInfo[A2], typeInfo[A3], typeInfo[A4]), + typeInfo[R], + ) + + def genMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, A5: TypeInfo, R: TypeInfo]( + baseName: String + ): MethodBuilder[C] = + genMethod( + baseName, + FastSeq[TypeInfo[_]](typeInfo[A1], typeInfo[A2], typeInfo[A3], typeInfo[A4], typeInfo[A5]), + typeInfo[R], + ) + + def genStaticMethod(baseName: String, argsInfo: IndexedSeq[TypeInfo[_]], returnInfo: TypeInfo[_]) + : MethodBuilder[C] = newStaticMethod(genName("sm", baseName), argsInfo, returnInfo) } object FunctionBuilder { - def bytesToBytecodeString(bytes: Array[Byte], out: OutputStream) { + def bytesToBytecodeString(bytes: Array[Byte], out: OutputStream): Unit = { val tcv = new TraceClassVisitor(null, new Textifier, new PrintWriter(out)) new ClassReader(bytes).accept(tcv, 0) } @@ -468,8 +583,9 @@ object FunctionBuilder { def apply[F]( baseName: String, argInfo: IndexedSeq[MaybeGenericTypeInfo[_]], - returnInfo: MaybeGenericTypeInfo[_] - )(implicit fti: TypeInfo[F]): FunctionBuilder[F] = { + returnInfo: MaybeGenericTypeInfo[_], + )(implicit fti: TypeInfo[F] + ): FunctionBuilder[F] = { val modb: ModuleBuilder = new ModuleBuilder() val cb: ClassBuilder[F] = modb.genClass[F](baseName) val apply = cb.newMethod("apply", argInfo, returnInfo) @@ -482,14 +598,29 @@ object FunctionBuilder { def apply[A1: TypeInfo, R: TypeInfo](baseName: String): FunctionBuilder[AsmFunction1[A1, R]] = apply[AsmFunction1[A1, R]](baseName, Array(GenericTypeInfo[A1]), GenericTypeInfo[R]) - def apply[A1: TypeInfo, A2: 
TypeInfo, R: TypeInfo](baseName: String): FunctionBuilder[AsmFunction2[A1, A2, R]] = - apply[AsmFunction2[A1, A2, R]](baseName, Array(GenericTypeInfo[A1], GenericTypeInfo[A2]), GenericTypeInfo[R]) - - def apply[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, R: TypeInfo](baseName: String): FunctionBuilder[AsmFunction3[A1, A2, A3, R]] = - apply[AsmFunction3[A1, A2, A3, R]](baseName, Array(GenericTypeInfo[A1], GenericTypeInfo[A2], GenericTypeInfo[A3]), GenericTypeInfo[R]) - - def apply[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, R: TypeInfo](baseName: String): FunctionBuilder[AsmFunction4[A1, A2, A3, A4, R]] = - apply[AsmFunction4[A1, A2, A3, A4, R]](baseName, Array(GenericTypeInfo[A1], GenericTypeInfo[A2], GenericTypeInfo[A3], GenericTypeInfo[A4]), GenericTypeInfo[R]) + def apply[A1: TypeInfo, A2: TypeInfo, R: TypeInfo](baseName: String) + : FunctionBuilder[AsmFunction2[A1, A2, R]] = + apply[AsmFunction2[A1, A2, R]]( + baseName, + Array(GenericTypeInfo[A1], GenericTypeInfo[A2]), + GenericTypeInfo[R], + ) + + def apply[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, R: TypeInfo](baseName: String) + : FunctionBuilder[AsmFunction3[A1, A2, A3, R]] = + apply[AsmFunction3[A1, A2, A3, R]]( + baseName, + Array(GenericTypeInfo[A1], GenericTypeInfo[A2], GenericTypeInfo[A3]), + GenericTypeInfo[R], + ) + + def apply[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, R: TypeInfo](baseName: String) + : FunctionBuilder[AsmFunction4[A1, A2, A3, A4, R]] = + apply[AsmFunction4[A1, A2, A3, A4, R]]( + baseName, + Array(GenericTypeInfo[A1], GenericTypeInfo[A2], GenericTypeInfo[A3], GenericTypeInfo[A4]), + GenericTypeInfo[R], + ) } trait WrappedMethodBuilder[C] extends WrappedClassBuilder[C] { @@ -513,35 +644,45 @@ trait WrappedMethodBuilder[C] extends WrappedClassBuilder[C] { def emit(body: Code[_]): Unit = mb.emit(body) - def emitWithBuilder[T](f: (CodeBuilder) => Code[T]): Unit = mb.emitWithBuilder(f) + def emitWithBuilder[T](f: CodeBuilder => Code[T]): Unit = mb.emitWithBuilder(f) - def invoke[T](cb: EmitCodeBuilder, args: Value[_]*): Value[T] = mb.invoke(cb, args: _*) } class MethodBuilder[C]( - val cb: ClassBuilder[C], _mname: String, + val cb: ClassBuilder[C], + _mname: String, val parameterTypeInfo: IndexedSeq[TypeInfo[_]], val returnTypeInfo: TypeInfo[_], - val isStatic: Boolean = false + val isStatic: Boolean = false, ) extends WrappedClassBuilder[C] { - require(parameterTypeInfo.length + isStatic.toInt <= 255, + require( + parameterTypeInfo.length + isStatic.toInt <= 255, s"""Invalid method, methods may have at most 255 arguments, found ${parameterTypeInfo.length + isStatic.toInt} |Return Type Info: $returnTypeInfo - |Parameter Type Info: ${parameterTypeInfo.mkString}""".stripMargin) + |Parameter Type Info: ${parameterTypeInfo.mkString}""".stripMargin, + ) + // very long method names, repeated hundreds of thousands of times can cause memory issues. // If necessary to find the name of a method precisely, this can be set to around the constant // limit of 65535 characters, but usually, this can be much smaller. 
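The cap discussed in the comment above is a simple substring clamp. A small, self-contained illustration of the pattern with a hypothetical helper (not Hail code), keeping generated names far below the 65535 limit the comment mentions:

    object MethodNames {
      // Clamp a generated name; 2000 characters is far below the 65535 limit on a
      // single constant-pool entry, but short enough to keep memory use in check.
      def clamp(name: String, max: Int = 2000): String =
        name.substring(0, scala.math.min(name.length, max))
    }

    object ClampDemo extends App {
      val longName = "apply_region_value_" * 500   // 9500 characters
      println(MethodNames.clamp(longName).length)  // 2000
    }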
- val methodName: String = _mname.substring(0, scala.math.min(_mname.length, 2000 /* 65535 */)) + val methodName: String = _mname.substring(0, scala.math.min(_mname.length, 2000 /* 65535 */ )) if (methodName != "" && !isJavaIdentifier(methodName)) throw new IllegalArgumentException(s"Illegal method name, not Java identifier: $methodName") - val lmethod: lir.Method = cb.lclass.newMethod(methodName, parameterTypeInfo, returnTypeInfo, isStatic) + val lmethod: lir.Method = + cb.lclass.newMethod(methodName, parameterTypeInfo, returnTypeInfo, isStatic) val localBuilder: SettableBuilder = new SettableBuilder { def newSettable[T](name: String)(implicit tti: TypeInfo[T]): Settable[T] = newLocal[T](name) } + def this_ : Value[C] = + if (!isStatic) cb.this_ + else throw new IllegalAccessException( + s"Cannot access 'this' from static context '${cb.className}.$methodName'." + ) + def newLocal[T: TypeInfo](name: String = null): LocalRef[T] = new LocalRef[T](lmethod.newLocal(name, typeInfo[T])) @@ -549,11 +690,13 @@ class MethodBuilder[C]( val ti = implicitly[TypeInfo[T]] if (i == 0 && !isStatic) - assert(ti == cb.ti, s"$ti != ${ cb.ti }") + assert(ti == cb.ti, s"$ti != ${cb.ti}") else { val static = (!isStatic).toInt - assert(ti == parameterTypeInfo(i - static), - s"$ti != ${ parameterTypeInfo(i - static) }\n params: $parameterTypeInfo") + assert( + ti == parameterTypeInfo(i - static), + s"$ti != ${parameterTypeInfo(i - static)}\n params: $parameterTypeInfo", + ) } new LocalRef(lmethod.getParam(i)) } @@ -567,7 +710,8 @@ class MethodBuilder[C]( startup = Code(startup, c) } - def emitWithBuilder[T](f: (CodeBuilder) => Code[T]): Unit = emit(CodeBuilder.scopedCode[T](this)(f)) + def emitWithBuilder[T](f: CodeBuilder => Code[T]): Unit = + emit(CodeBuilder.scopedCode[T](this)(f)) def emit(body: Code[_]): Unit = { assert(!emitted) @@ -575,65 +719,20 @@ class MethodBuilder[C]( val start = startup.start startup.end.append(lir.goto(body.start)) - body.end.append( - if (body.v != null) - lir.returnx(body.v) - else - lir.returnx()) + if (body.isOpenEnded) { + val ret = + if (body.v != null) lir.returnx(body.v) + else lir.returnx() + body.end.append(ret) + } assert(start != null) lmethod.setEntry(start) body.clear() } - - def invokeCode[T](args: Value[_]*): Code[T] = { - val (start, end, argvs) = Code.sequenceValues(args.toFastSeq.map(_.get)) - if (returnTypeInfo eq UnitInfo) { - if (isStatic) { - end.append(lir.methodStmt(INVOKESTATIC, lmethod, argvs)) - } else { - end.append( - lir.methodStmt(INVOKEVIRTUAL, lmethod, - lir.load(new lir.Parameter(null, 0, cb.ti)) +: argvs)) - } - new VCode(start, end, null) - } else { - val value = if (isStatic) { - lir.methodInsn(INVOKESTATIC, lmethod, argvs) - } else { - lir.methodInsn(INVOKEVIRTUAL, lmethod, - lir.load(new lir.Parameter(null, 0, cb.ti)) +: argvs) - } - new VCode(start, end, value) - } - } - - def invoke[T](codeBuilder: CodeBuilderLike, args: Value[_]*): Value[T] = { - val (start, end, argvs) = Code.sequenceValues(args.toFastSeq.map(_.get)) - if (returnTypeInfo eq UnitInfo) { - if (isStatic) { - end.append(lir.methodStmt(INVOKESTATIC, lmethod, argvs)) - } else { - end.append( - lir.methodStmt(INVOKEVIRTUAL, lmethod, - lir.load(new lir.Parameter(null, 0, cb.ti)) +: argvs)) - } - codeBuilder.append(new VCode(start, end, null)) - coerce[T](Code._empty) - } else { - val value = if (isStatic) { - lir.methodInsn(INVOKESTATIC, lmethod, argvs) - } else { - lir.methodInsn(INVOKEVIRTUAL, lmethod, - lir.load(new lir.Parameter(null, 0, cb.ti)) +: argvs) - } - 
coerce[T](codeBuilder.memoizeAny(new VCode(start, end, value), returnTypeInfo)) - } - } } -class FunctionBuilder[F]( - val apply_method: MethodBuilder[F] -) extends WrappedMethodBuilder[F] { - val mb: MethodBuilder[F] = apply_method +final case class FunctionBuilder[F] private (apply_method: MethodBuilder[F]) + extends WrappedMethodBuilder[F] { + override val mb: MethodBuilder[F] = apply_method } diff --git a/hail/src/main/scala/is/hail/asm4s/Code.scala b/hail/src/main/scala/is/hail/asm4s/Code.scala index 2861858f2b4..8ea1b392ab3 100644 --- a/hail/src/main/scala/is/hail/asm4s/Code.scala +++ b/hail/src/main/scala/is/hail/asm4s/Code.scala @@ -4,12 +4,14 @@ import is.hail.expr.ir.EmitCodeBuilder import is.hail.lir import is.hail.lir.{Block, ControlX, ValueX} import is.hail.utils._ -import org.objectweb.asm.Opcodes._ -import org.objectweb.asm.Type + +import scala.reflect.ClassTag import java.io.PrintStream import java.lang.reflect -import scala.reflect.ClassTag + +import org.objectweb.asm.Opcodes._ +import org.objectweb.asm.Type abstract class Thrower[T] { def apply[U](cerr: Code[T])(implicit uti: TypeInfo[U]): Code[U] @@ -38,7 +40,12 @@ object Code { newC } - def void[T](c1: Code[_], c2: Code[_], c3: Code[_], f: (lir.ValueX, lir.ValueX, lir.ValueX) => lir.StmtX): Code[T] = { + def void[T]( + c1: Code[_], + c2: Code[_], + c3: Code[_], + f: (lir.ValueX, lir.ValueX, lir.ValueX) => lir.StmtX, + ): Code[T] = { c3.end.append(f(c1.v, c2.v, c3.v)) c2.end.append(lir.goto(c3.start)) c1.end.append(lir.goto(c2.start)) @@ -68,7 +75,12 @@ object Code { newC } - def apply[T](c1: Code[_], c2: Code[_], c3: Code[_], f: (lir.ValueX, lir.ValueX, lir.ValueX) => lir.ValueX): Code[T] = { + def apply[T]( + c1: Code[_], + c2: Code[_], + c3: Code[_], + f: (lir.ValueX, lir.ValueX, lir.ValueX) => lir.ValueX, + ): Code[T] = { c1.end.append(lir.goto(c2.start)) c2.end.append(lir.goto(c3.start)) val newC = new VCode(c1.start, c3.end, f(c1.v, c2.v, c3.v)) @@ -111,19 +123,54 @@ object Code { def apply[T](c1: Code[Unit], c2: Code[Unit], c3: Code[Unit], c4: Code[T]): Code[T] = sequence1(FastSeq(c1, c2, c3), c4) - def apply[T](c1: Code[Unit], c2: Code[Unit], c3: Code[Unit], c4: Code[Unit], c5: Code[T]): Code[T] = + def apply[T](c1: Code[Unit], c2: Code[Unit], c3: Code[Unit], c4: Code[Unit], c5: Code[T]) + : Code[T] = sequence1(FastSeq(c1, c2, c3, c4), c5) - def apply[T](c1: Code[Unit], c2: Code[Unit], c3: Code[Unit], c4: Code[Unit], c5: Code[Unit], c6: Code[T]): Code[T] = + def apply[T]( + c1: Code[Unit], + c2: Code[Unit], + c3: Code[Unit], + c4: Code[Unit], + c5: Code[Unit], + c6: Code[T], + ): Code[T] = sequence1(FastSeq(c1, c2, c3, c4, c5), c6) - def apply[T](c1: Code[Unit], c2: Code[Unit], c3: Code[Unit], c4: Code[Unit], c5: Code[Unit], c6: Code[Unit], c7: Code[T]): Code[T] = + def apply[T]( + c1: Code[Unit], + c2: Code[Unit], + c3: Code[Unit], + c4: Code[Unit], + c5: Code[Unit], + c6: Code[Unit], + c7: Code[T], + ): Code[T] = sequence1(FastSeq(c1, c2, c3, c4, c5, c6), c7) - def apply[T](c1: Code[Unit], c2: Code[Unit], c3: Code[Unit], c4: Code[Unit], c5: Code[Unit], c6: Code[Unit], c7: Code[Unit], c8: Code[T]): Code[T] = + def apply[T]( + c1: Code[Unit], + c2: Code[Unit], + c3: Code[Unit], + c4: Code[Unit], + c5: Code[Unit], + c6: Code[Unit], + c7: Code[Unit], + c8: Code[T], + ): Code[T] = sequence1(FastSeq(c1, c2, c3, c4, c5, c6, c7), c8) - def apply[T](c1: Code[Unit], c2: Code[Unit], c3: Code[Unit], c4: Code[Unit], c5: Code[Unit], c6: Code[Unit], c7: Code[Unit], c8: Code[Unit], c9: Code[T]): Code[T] = + def apply[T]( + c1: 
Code[Unit], + c2: Code[Unit], + c3: Code[Unit], + c4: Code[Unit], + c5: Code[Unit], + c6: Code[Unit], + c7: Code[Unit], + c8: Code[Unit], + c9: Code[T], + ): Code[T] = sequence1(FastSeq(c1, c2, c3, c4, c5, c6, c7, c8), c9) def apply(cs: Seq[Code[Unit]]): Code[Unit] = { @@ -138,31 +185,47 @@ object Code { def foreach[A](it: Seq[A])(f: A => Code[Unit]): Code[Unit] = Code(it.map(f)) - def newInstance[T <: AnyRef](parameterTypes: Array[Class[_]], args: Array[Code[_]])(implicit tct: ClassTag[T]): Code[T] = + def newInstance[T <: AnyRef]( + parameterTypes: Array[Class[_]], + args: Array[Code[_]], + )(implicit tct: ClassTag[T] + ): Code[T] = newInstance(parameterTypes, args, 0) - def newInstance[T <: AnyRef](parameterTypes: Array[Class[_]], args: Array[Code[_]], lineNumber: Int)(implicit tct: ClassTag[T]): Code[T] = { + def newInstance[T <: AnyRef]( + parameterTypes: Array[Class[_]], + args: Array[Code[_]], + lineNumber: Int, + )(implicit tct: ClassTag[T] + ): Code[T] = { val tti = classInfo[T] val tcls = tct.runtimeClass val c = tcls.getDeclaredConstructor(parameterTypes: _*) - assert(c != null, - s"no such method ${ tcls.getName }(${ - parameterTypes.map(_.getName).mkString(", ") - })") + assert( + c != null, + s"no such method ${tcls.getName}(${parameterTypes.map(_.getName).mkString(", ")})", + ) val (start, end, argvs) = Code.sequenceValues(args) val linst = new lir.Local(null, "new_inst", tti) val newInstX = lir.newInstance( - tti, Type.getInternalName(tcls), "", - Type.getConstructorDescriptor(c), tti, argvs, lineNumber) + tti, + Type.getInternalName(tcls), + "", + Type.getConstructorDescriptor(c), + tti, + argvs, + lineNumber, + ) end.append(lir.store(linst, newInstX, lineNumber)) new VCode(start, end, lir.load(linst)) } - def newInstance[C](cb: ClassBuilder[C], ctor: MethodBuilder[C], args: IndexedSeq[Code[_]]): Code[C] = { + def newInstance[C](cb: ClassBuilder[C], ctor: MethodBuilder[C], args: IndexedSeq[Code[_]]) + : Code[C] = { val (start, end, argvs) = sequenceValues(args) val linst = new lir.Local(null, "new_inst", cb.ti) @@ -175,49 +238,222 @@ object Code { def newInstance[T <: AnyRef]()(implicit tct: ClassTag[T], tti: TypeInfo[T]): Code[T] = newInstance[T](Array[Class[_]](), Array[Code[_]]()) - def newInstance[T <: AnyRef, A1](a1: Code[A1])(implicit a1ct: ClassTag[A1], - tct: ClassTag[T], tti: TypeInfo[T]): Code[T] = + def newInstance[T <: AnyRef, A1]( + a1: Code[A1] + )(implicit + a1ct: ClassTag[A1], + tct: ClassTag[T], + tti: TypeInfo[T], + ): Code[T] = newInstance[T](Array[Class[_]](a1ct.runtimeClass), Array[Code[_]](a1)) - def newInstance[T <: AnyRef, A1, A2](a1: Code[A1], a2: Code[A2])(implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], - tct: ClassTag[T], tti: TypeInfo[T]): Code[T] = + def newInstance[T <: AnyRef, A1, A2]( + a1: Code[A1], + a2: Code[A2], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + tct: ClassTag[T], + tti: TypeInfo[T], + ): Code[T] = newInstance[T](Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass), Array[Code[_]](a1, a2)) - def newInstance[T <: AnyRef, A1, A2, A3](a1: Code[A1], a2: Code[A2], a3: Code[A3])(implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], - a3ct: ClassTag[A3], tct: ClassTag[T], tti: TypeInfo[T]): Code[T] = + def newInstance[T <: AnyRef, A1, A2, A3]( + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + tct: ClassTag[T], + tti: TypeInfo[T], + ): Code[T] = newInstance(a1, a2, a3, 0) - def newInstance[T <: AnyRef, A1, A2, A3](a1: Code[A1], a2: Code[A2], a3: 
Code[A3], lineNumber: Int)(implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], - a3ct: ClassTag[A3], tct: ClassTag[T], tti: TypeInfo[T]): Code[T] = - newInstance[T](Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass), Array[Code[_]](a1, a2, a3), lineNumber) + def newInstance[T <: AnyRef, A1, A2, A3]( + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + lineNumber: Int, + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + tct: ClassTag[T], + tti: TypeInfo[T], + ): Code[T] = + newInstance[T]( + Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass), + Array[Code[_]](a1, a2, a3), + lineNumber, + ) - def newInstance[T <: AnyRef, A1, A2, A3, A4](a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4] - )(implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], tct: ClassTag[T], tti: TypeInfo[T]): Code[T] = - newInstance[T](Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass), Array[Code[_]](a1, a2, a3, a4)) + def newInstance[T <: AnyRef, A1, A2, A3, A4]( + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + tct: ClassTag[T], + tti: TypeInfo[T], + ): Code[T] = + newInstance[T]( + Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass), + Array[Code[_]](a1, a2, a3, a4), + ) - def newInstance[T <: AnyRef, A1, A2, A3, A4, A5](a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5] - )(implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], tct: ClassTag[T], tti: TypeInfo[T]): Code[T] = - newInstance[T](Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass), Array[Code[_]](a1, a2, a3, a4, a5)) + def newInstance[T <: AnyRef, A1, A2, A3, A4, A5]( + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + tct: ClassTag[T], + tti: TypeInfo[T], + ): Code[T] = + newInstance[T]( + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + ), + Array[Code[_]](a1, a2, a3, a4, a5), + ) - def newInstance7[T <: AnyRef, A1, A2, A3, A4, A5, A6, A7](a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5], a6: Code[A6], a7: Code[A7] - )(implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], a7ct: ClassTag[A7], tct: ClassTag[T], tti: TypeInfo[T]): Code[T] = - newInstance[T](Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass, a7ct.runtimeClass), Array[Code[_]](a1, a2, a3, a4, a5, a6, a7)) + def newInstance7[T <: AnyRef, A1, A2, A3, A4, A5, A6, A7]( + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + a7: Code[A7], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + a7ct: ClassTag[A7], + tct: ClassTag[T], + tti: TypeInfo[T], + ): Code[T] = + newInstance[T]( + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + 
a7ct.runtimeClass, + ), + Array[Code[_]](a1, a2, a3, a4, a5, a6, a7), + ) - def newInstance8[T <: AnyRef, A1, A2, A3, A4, A5, A6, A7, A8](a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5], a6: Code[A6], a7: Code[A7], a8: Code[A8] - )(implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], a7ct: ClassTag[A7], a8ct: ClassTag[A8], tct: ClassTag[T], tti: TypeInfo[T]): Code[T] = - newInstance[T](Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass, a7ct.runtimeClass, a8ct.runtimeClass), Array[Code[_]](a1, a2, a3, a4, a5, a6, a7, a8)) + def newInstance8[T <: AnyRef, A1, A2, A3, A4, A5, A6, A7, A8]( + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + a7: Code[A7], + a8: Code[A8], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + a7ct: ClassTag[A7], + a8ct: ClassTag[A8], + tct: ClassTag[T], + tti: TypeInfo[T], + ): Code[T] = + newInstance[T]( + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + a7ct.runtimeClass, + a8ct.runtimeClass, + ), + Array[Code[_]](a1, a2, a3, a4, a5, a6, a7, a8), + ) - def newInstance11[T <: AnyRef, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11](a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], - a5: Code[A5], a6: Code[A6], a7: Code[A7], a8: Code[A8], a9: Code[A9], a10: Code[A10], a11: Code[A11] - )(implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], a7ct: ClassTag[A7], - a8ct: ClassTag[A8], a9ct: ClassTag[A9], a10ct: ClassTag[A10], a11ct: ClassTag[A11], tct: ClassTag[T], tti: TypeInfo[T]): Code[T] = - newInstance[T](Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass, a7ct.runtimeClass, - a8ct.runtimeClass, a9ct.runtimeClass, a10ct.runtimeClass, a11ct.runtimeClass), Array[Code[_]](a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11)) + def newInstance11[T <: AnyRef, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11]( + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + a7: Code[A7], + a8: Code[A8], + a9: Code[A9], + a10: Code[A10], + a11: Code[A11], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + a7ct: ClassTag[A7], + a8ct: ClassTag[A8], + a9ct: ClassTag[A9], + a10ct: ClassTag[A10], + a11ct: ClassTag[A11], + tct: ClassTag[T], + tti: TypeInfo[T], + ): Code[T] = + newInstance[T]( + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + a7ct.runtimeClass, + a8ct.runtimeClass, + a9ct.runtimeClass, + a10ct.runtimeClass, + a11ct.runtimeClass, + ), + Array[Code[_]](a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11), + ) def newArray[T](size: Code[Int])(implicit tti: TypeInfo[T]): Code[Array[T]] = Code(size, lir.newArray(tti)) - def invokeScalaObject[S](cls: Class[_], method: String, parameterTypes: Array[Class[_]], args: Array[Code[_]])(implicit sct: ClassTag[S]): Code[S] = { + def invokeScalaObject[S]( + cls: Class[_], + method: String, + parameterTypes: Array[Class[_]], + args: Array[Code[_]], + )(implicit sct: 
ClassTag[S] + ): Code[S] = { val m = Invokeable.lookupMethod(cls, method, parameterTypes)(sct) val staticObj = FieldRef("MODULE$")(ClassTag(cls), ClassTag(cls), classInfo(ClassTag(cls))) m.invoke(staticObj.getField(), args) @@ -226,116 +462,509 @@ object Code { def invokeScalaObject0[S](cls: Class[_], method: String)(implicit sct: ClassTag[S]): Code[S] = invokeScalaObject[S](cls, method, Array[Class[_]](), Array[Code[_]]()) - def invokeScalaObject1[A1, S](cls: Class[_], method: String, a1: Code[A1])(implicit a1ct: ClassTag[A1], sct: ClassTag[S]): Code[S] = + def invokeScalaObject1[A1, S]( + cls: Class[_], + method: String, + a1: Code[A1], + )(implicit + a1ct: ClassTag[A1], + sct: ClassTag[S], + ): Code[S] = invokeScalaObject[S](cls, method, Array[Class[_]](a1ct.runtimeClass), Array[Code[_]](a1)) - def invokeScalaObject2[A1, A2, S](cls: Class[_], method: String, a1: Code[A1], a2: Code[A2])(implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], sct: ClassTag[S]): Code[S] = - invokeScalaObject[S](cls, method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass), Array(a1, a2)) + def invokeScalaObject2[A1, A2, S]( + cls: Class[_], + method: String, + a1: Code[A1], + a2: Code[A2], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + sct: ClassTag[S], + ): Code[S] = + invokeScalaObject[S]( + cls, + method, + Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass), + Array(a1, a2), + ) - def invokeScalaObject3[A1, A2, A3, S](cls: Class[_], method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3])(implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], sct: ClassTag[S]): Code[S] = - invokeScalaObject[S](cls, method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass), Array(a1, a2, a3)) + def invokeScalaObject3[A1, A2, A3, S]( + cls: Class[_], + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + sct: ClassTag[S], + ): Code[S] = + invokeScalaObject[S]( + cls, + method, + Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass), + Array(a1, a2, a3), + ) def invokeScalaObject4[A1, A2, A3, A4, S]( - cls: Class[_], method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4])( - implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], sct: ClassTag[S]): Code[S] = - invokeScalaObject[S](cls, method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass), Array(a1, a2, a3, a4)) + cls: Class[_], + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + sct: ClassTag[S], + ): Code[S] = + invokeScalaObject[S]( + cls, + method, + Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass), + Array(a1, a2, a3, a4), + ) def invokeScalaObject5[A1, A2, A3, A4, A5, S]( - cls: Class[_], method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5])( - implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], sct: ClassTag[S] + cls: Class[_], + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + sct: ClassTag[S], ): Code[S] = invokeScalaObject[S]( - cls, method, Array[Class[_]]( - a1ct.runtimeClass, 
a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass), Array(a1, a2, a3, a4, a5)) + cls, + method, + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + ), + Array(a1, a2, a3, a4, a5), + ) def invokeScalaObject6[A1, A2, A3, A4, A5, A6, S]( - cls: Class[_], method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5], a6: Code[A6])( - implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], sct: ClassTag[S] + cls: Class[_], + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + sct: ClassTag[S], ): Code[S] = invokeScalaObject[S]( - cls, method, Array[Class[_]]( - a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass), Array(a1, a2, a3, a4, a5, a6)) + cls, + method, + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + ), + Array(a1, a2, a3, a4, a5, a6), + ) def invokeScalaObject7[A1, A2, A3, A4, A5, A6, A7, S]( - cls: Class[_], method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5], a6: Code[A6], a7: Code[A7])( - implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], a7ct: ClassTag[A7], sct: ClassTag[S] + cls: Class[_], + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + a7: Code[A7], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + a7ct: ClassTag[A7], + sct: ClassTag[S], ): Code[S] = invokeScalaObject[S]( - cls, method, Array[Class[_]]( - a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass, a7ct.runtimeClass), Array(a1, a2, a3, a4, a5, a6, a7)) + cls, + method, + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + a7ct.runtimeClass, + ), + Array(a1, a2, a3, a4, a5, a6, a7), + ) def invokeScalaObject8[A1, A2, A3, A4, A5, A6, A7, A8, S]( - cls: Class[_], method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5], a6: Code[A6], a7: Code[A7], a8: Code[A8])( - implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], a7ct: ClassTag[A7], a8ct: ClassTag[A8], sct: ClassTag[S] + cls: Class[_], + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + a7: Code[A7], + a8: Code[A8], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + a7ct: ClassTag[A7], + a8ct: ClassTag[A8], + sct: ClassTag[S], ): Code[S] = invokeScalaObject[S]( - cls, method, Array[Class[_]]( - a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass, a7ct.runtimeClass, a8ct.runtimeClass), Array(a1, a2, a3, a4, a5, a6, a7, a8)) + cls, + method, + Array[Class[_]]( + a1ct.runtimeClass, 
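invokeScalaObject reaches the target object's singleton through the static MODULE$ field that the Scala compiler emits, then invokes the requested method on that instance. The same lookup works with plain JVM reflection; a standalone illustration with a hypothetical Greeter object, not Hail code:

    // A top-level `object Greeter` compiles to a Greeter$ class whose singleton
    // instance is stored in a public static field named MODULE$.
    object Greeter { def hello(name: String): String = s"hello, $name" }

    object ModuleFieldDemo extends App {
      val cls = Greeter.getClass                        // the generated Greeter$ class
      val instance = cls.getField("MODULE$").get(null)  // the singleton, same as Greeter
      val m = cls.getMethod("hello", classOf[String])
      println(m.invoke(instance, "asm4s"))              // hello, asm4s
    }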
+ a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + a7ct.runtimeClass, + a8ct.runtimeClass, + ), + Array(a1, a2, a3, a4, a5, a6, a7, a8), + ) def invokeScalaObject9[A1, A2, A3, A4, A5, A6, A7, A8, A9, S]( - cls: Class[_], method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5], a6: Code[A6], a7: Code[A7], a8: Code[A8], a9: Code[A9])( - implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], a7ct: ClassTag[A7], a8ct: ClassTag[A8], a9ct: ClassTag[A9], sct: ClassTag[S] - ): Code[S] = + cls: Class[_], + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + a7: Code[A7], + a8: Code[A8], + a9: Code[A9], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + a7ct: ClassTag[A7], + a8ct: ClassTag[A8], + a9ct: ClassTag[A9], + sct: ClassTag[S], + ): Code[S] = invokeScalaObject[S]( - cls, method, Array[Class[_]]( - a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass, a7ct.runtimeClass, a8ct.runtimeClass, a9ct.runtimeClass), Array(a1, a2, a3, a4, a5, a6, a7, a8, a9)) + cls, + method, + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + a7ct.runtimeClass, + a8ct.runtimeClass, + a9ct.runtimeClass, + ), + Array(a1, a2, a3, a4, a5, a6, a7, a8, a9), + ) def invokeScalaObject11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, S]( - cls: Class[_], method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5], a6: Code[A6], a7: Code[A7], a8: Code[A8], - a9: Code[A9], a10: Code[A10], a11: Code[A11])( - implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], a7ct: ClassTag[A7], - a8ct: ClassTag[A8], a9ct: ClassTag[A9], a10ct: ClassTag[A10], a11ct: ClassTag[A11], sct: ClassTag[S] + cls: Class[_], + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + a7: Code[A7], + a8: Code[A8], + a9: Code[A9], + a10: Code[A10], + a11: Code[A11], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + a7ct: ClassTag[A7], + a8ct: ClassTag[A8], + a9ct: ClassTag[A9], + a10ct: ClassTag[A10], + a11ct: ClassTag[A11], + sct: ClassTag[S], ): Code[S] = invokeScalaObject[S]( - cls, method, + cls, + method, Array[Class[_]]( - a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass, a7ct.runtimeClass, a8ct.runtimeClass, - a9ct.runtimeClass, a10ct.runtimeClass, a11ct.runtimeClass), - Array(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + a7ct.runtimeClass, + a8ct.runtimeClass, + a9ct.runtimeClass, + a10ct.runtimeClass, + a11ct.runtimeClass, + ), + Array(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11), ) def invokeScalaObject13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, S]( - cls: Class[_], method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5], a6: Code[A6], a7: Code[A7], a8: Code[A8], - a9: Code[A9], a10: Code[A10], a11: 
Code[A11], a12: Code[A12], a13: Code[A13])( - implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], a7ct: ClassTag[A7], - a8ct: ClassTag[A8], a9ct: ClassTag[A9], a10ct: ClassTag[A10], a11ct: ClassTag[A11], a12ct: ClassTag[A12], a13ct: ClassTag[A13], sct: ClassTag[S]): Code[S] = + cls: Class[_], + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + a7: Code[A7], + a8: Code[A8], + a9: Code[A9], + a10: Code[A10], + a11: Code[A11], + a12: Code[A12], + a13: Code[A13], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + a7ct: ClassTag[A7], + a8ct: ClassTag[A8], + a9ct: ClassTag[A9], + a10ct: ClassTag[A10], + a11ct: ClassTag[A11], + a12ct: ClassTag[A12], + a13ct: ClassTag[A13], + sct: ClassTag[S], + ): Code[S] = invokeScalaObject[S]( - cls, method, + cls, + method, Array[Class[_]]( - a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass, a7ct.runtimeClass, a8ct.runtimeClass, - a9ct.runtimeClass, a10ct.runtimeClass, a11ct.runtimeClass, a12ct.runtimeClass, a13ct.runtimeClass), - Array(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + a7ct.runtimeClass, + a8ct.runtimeClass, + a9ct.runtimeClass, + a10ct.runtimeClass, + a11ct.runtimeClass, + a12ct.runtimeClass, + a13ct.runtimeClass, + ), + Array(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13), ) def invokeScalaObject16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, S]( - cls: Class[_], method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5], a6: Code[A6], a7: Code[A7], a8: Code[A8], - a9: Code[A9], a10: Code[A10], a11: Code[A11], a12: Code[A12], a13: Code[A13], a14: Code[A14], a15: Code[A15], a16: Code[A16])( - implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], a7ct: ClassTag[A7], - a8ct: ClassTag[A8], a9ct: ClassTag[A9], a10ct: ClassTag[A10], a11ct: ClassTag[A11], a12ct: ClassTag[A12], a13ct: ClassTag[A13], a14ct: ClassTag[A14], - a15ct: ClassTag[A15], a16ct: ClassTag[A16], sct: ClassTag[S]): Code[S] = + cls: Class[_], + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + a7: Code[A7], + a8: Code[A8], + a9: Code[A9], + a10: Code[A10], + a11: Code[A11], + a12: Code[A12], + a13: Code[A13], + a14: Code[A14], + a15: Code[A15], + a16: Code[A16], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + a7ct: ClassTag[A7], + a8ct: ClassTag[A8], + a9ct: ClassTag[A9], + a10ct: ClassTag[A10], + a11ct: ClassTag[A11], + a12ct: ClassTag[A12], + a13ct: ClassTag[A13], + a14ct: ClassTag[A14], + a15ct: ClassTag[A15], + a16ct: ClassTag[A16], + sct: ClassTag[S], + ): Code[S] = invokeScalaObject[S]( - cls, method, + cls, + method, Array[Class[_]]( - a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass, a7ct.runtimeClass, a8ct.runtimeClass, - a9ct.runtimeClass, a10ct.runtimeClass, a11ct.runtimeClass, a12ct.runtimeClass, a13ct.runtimeClass, a14ct.runtimeClass, a15ct.runtimeClass, a16ct.runtimeClass), - 
Array(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16)) - - def invokeScalaObject19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, S]( - cls: Class[_], method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5], a6: Code[A6], a7: Code[A7], a8: Code[A8], - a9: Code[A9], a10: Code[A10], a11: Code[A11], a12: Code[A12], a13: Code[A13], a14: Code[A14], a15: Code[A15], a16: Code[A16], - a17: Code[A17], a18: Code[A18], a19: Code[A19])( - implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], a7ct: ClassTag[A7], - a8ct: ClassTag[A8], a9ct: ClassTag[A9], a10ct: ClassTag[A10], a11ct: ClassTag[A11], a12ct: ClassTag[A12], a13ct: ClassTag[A13], a14ct: ClassTag[A14], - a15ct: ClassTag[A15], a16ct: ClassTag[A16], a17ct: ClassTag[A17], a18ct: ClassTag[A18], a19ct: ClassTag[A19], sct: ClassTag[S]): Code[S] = + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + a7ct.runtimeClass, + a8ct.runtimeClass, + a9ct.runtimeClass, + a10ct.runtimeClass, + a11ct.runtimeClass, + a12ct.runtimeClass, + a13ct.runtimeClass, + a14ct.runtimeClass, + a15ct.runtimeClass, + a16ct.runtimeClass, + ), + Array(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16), + ) + + def invokeScalaObject19[ + A1, + A2, + A3, + A4, + A5, + A6, + A7, + A8, + A9, + A10, + A11, + A12, + A13, + A14, + A15, + A16, + A17, + A18, + A19, + S, + ]( + cls: Class[_], + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + a7: Code[A7], + a8: Code[A8], + a9: Code[A9], + a10: Code[A10], + a11: Code[A11], + a12: Code[A12], + a13: Code[A13], + a14: Code[A14], + a15: Code[A15], + a16: Code[A16], + a17: Code[A17], + a18: Code[A18], + a19: Code[A19], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + a7ct: ClassTag[A7], + a8ct: ClassTag[A8], + a9ct: ClassTag[A9], + a10ct: ClassTag[A10], + a11ct: ClassTag[A11], + a12ct: ClassTag[A12], + a13ct: ClassTag[A13], + a14ct: ClassTag[A14], + a15ct: ClassTag[A15], + a16ct: ClassTag[A16], + a17ct: ClassTag[A17], + a18ct: ClassTag[A18], + a19ct: ClassTag[A19], + sct: ClassTag[S], + ): Code[S] = invokeScalaObject[S]( - cls, method, + cls, + method, Array[Class[_]]( - a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass, a7ct.runtimeClass, a8ct.runtimeClass, - a9ct.runtimeClass, a10ct.runtimeClass, a11ct.runtimeClass, a12ct.runtimeClass, a13ct.runtimeClass, a14ct.runtimeClass, a15ct.runtimeClass, a16ct.runtimeClass, - a17ct.runtimeClass, a18ct.runtimeClass, a19ct.runtimeClass), - Array(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19)) + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + a7ct.runtimeClass, + a8ct.runtimeClass, + a9ct.runtimeClass, + a10ct.runtimeClass, + a11ct.runtimeClass, + a12ct.runtimeClass, + a13ct.runtimeClass, + a14ct.runtimeClass, + a15ct.runtimeClass, + a16ct.runtimeClass, + a17ct.runtimeClass, + a18ct.runtimeClass, + a19ct.runtimeClass, + ), + Array(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15, a16, a17, a18, a19), + ) - def invokeStatic[S](cls: Class[_], method: String, parameterTypes: Array[Class[_]], args: 
Array[Code[_]])(implicit sct: ClassTag[S]): Code[S] = { + def invokeStatic[S]( + cls: Class[_], + method: String, + parameterTypes: Array[Class[_]], + args: Array[Code[_]], + )(implicit sct: ClassTag[S] + ): Code[S] = { val m = Invokeable.lookupMethod(cls, method, parameterTypes)(sct) assert(m.isStatic) m.invoke(null, args) @@ -344,22 +973,110 @@ object Code { def invokeStatic0[T, S](method: String)(implicit tct: ClassTag[T], sct: ClassTag[S]): Code[S] = invokeStatic[S](tct.runtimeClass, method, Array[Class[_]](), Array[Code[_]]()) - def invokeStatic1[T, A1, S](method: String, a1: Code[A1])(implicit tct: ClassTag[T], sct: ClassTag[S], a1ct: ClassTag[A1]): Code[S] = - invokeStatic[S](tct.runtimeClass, method, Array[Class[_]](a1ct.runtimeClass), Array[Code[_]](a1))(sct) - - def invokeStatic2[T, A1, A2, S](method: String, a1: Code[A1], a2: Code[A2])(implicit tct: ClassTag[T], sct: ClassTag[S], a1ct: ClassTag[A1], a2ct: ClassTag[A2]): Code[S] = - invokeStatic[S](tct.runtimeClass, method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass), Array[Code[_]](a1, a2))(sct) - - def invokeStatic3[T, A1, A2, A3, S](method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3])(implicit tct: ClassTag[T], sct: ClassTag[S], a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3]): Code[S] = - invokeStatic[S](tct.runtimeClass, method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass), Array[Code[_]](a1, a2, a3))(sct) - - def invokeStatic4[T, A1, A2, A3, A4, S](method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4])(implicit tct: ClassTag[T], sct: ClassTag[S], a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4]): Code[S] = - invokeStatic[S](tct.runtimeClass, method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass), Array[Code[_]](a1, a2, a3, a4))(sct) - - def invokeStatic5[T, A1, A2, A3, A4, A5, S](method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5])(implicit tct: ClassTag[T], sct: ClassTag[S], a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5]): Code[S] = - invokeStatic[S](tct.runtimeClass, method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass), Array[Code[_]](a1, a2, a3, a4, a5))(sct) + def invokeStatic1[T, A1, S]( + method: String, + a1: Code[A1], + )(implicit + tct: ClassTag[T], + sct: ClassTag[S], + a1ct: ClassTag[A1], + ): Code[S] = + invokeStatic[S]( + tct.runtimeClass, + method, + Array[Class[_]](a1ct.runtimeClass), + Array[Code[_]](a1), + )(sct) + + def invokeStatic2[T, A1, A2, S]( + method: String, + a1: Code[A1], + a2: Code[A2], + )(implicit + tct: ClassTag[T], + sct: ClassTag[S], + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + ): Code[S] = + invokeStatic[S]( + tct.runtimeClass, + method, + Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass), + Array[Code[_]](a1, a2), + )(sct) + + def invokeStatic3[T, A1, A2, A3, S]( + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + )(implicit + tct: ClassTag[T], + sct: ClassTag[S], + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + ): Code[S] = + invokeStatic[S]( + tct.runtimeClass, + method, + Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass), + Array[Code[_]](a1, a2, a3), + )(sct) + + def invokeStatic4[T, A1, A2, A3, A4, S]( + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + )(implicit + tct: ClassTag[T], + sct: 
ClassTag[S], + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + ): Code[S] = + invokeStatic[S]( + tct.runtimeClass, + method, + Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass), + Array[Code[_]](a1, a2, a3, a4), + )(sct) + + def invokeStatic5[T, A1, A2, A3, A4, A5, S]( + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + )(implicit + tct: ClassTag[T], + sct: ClassTag[S], + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + ): Code[S] = + invokeStatic[S]( + tct.runtimeClass, + method, + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + ), + Array[Code[_]](a1, a2, a3, a4, a5), + )(sct) + + def _null[T >: Null](implicit tti: TypeInfo[T]): Value[T] = + Value.fromLIR[T](lir.insn0(ACONST_NULL, tti)) - def _null[T >: Null](implicit tti: TypeInfo[T]): Value[T] = Value.fromLIR[T](lir.insn0(ACONST_NULL, tti)) def _uncheckednull[T](tti: TypeInfo[T]): Value[T] = Value.fromLIR[T](lir.insn0(ACONST_NULL, tti)) def _empty: Value[Unit] = Value.fromLIR[Unit](null: lir.ValueX) @@ -376,17 +1093,17 @@ object Code { } } - private def getEmitLineNum: Int = { + private def getEmitLineNum: Int = // val st = Thread.currentThread().getStackTrace // val i = st.indexWhere(ste => ste.getFileName == "Emit.scala") // if (i == -1) 0 else st(i).getLineNumber 0 - } def _throw[T <: java.lang.Throwable, U](cerr: Code[T])(implicit uti: TypeInfo[U]): Code[U] = _throw[T, U](cerr, getEmitLineNum) - def _throw[T <: java.lang.Throwable, U](cerr: Code[T], lineNumber: Int)(implicit uti: TypeInfo[U]): Code[U] = { + def _throw[T <: java.lang.Throwable, U](cerr: Code[T], lineNumber: Int)(implicit uti: TypeInfo[U]) + : Code[U] = { if (uti eq UnitInfo) { cerr.end.append(lir.throwx(cerr.v, lineNumber)) val newC = new VCode(cerr.start, cerr.end, null) @@ -404,14 +1121,20 @@ object Code { msg, Code.invokeStatic0[scala.Option[String], scala.Option[String]]("empty"), Code._null[Throwable], - lineNumber) + lineNumber, + ) Code._throw[is.hail.utils.HailException, U](cerr, lineNumber) } def _fatalWithID[U](msg: Code[String], errorId: Code[Int])(implicit uti: TypeInfo[U]): Code[U] = - Code._throw[is.hail.utils.HailException, U](Code.newInstance[is.hail.utils.HailException, String, Int]( + Code._throw[is.hail.utils.HailException, U](Code.newInstance[ + is.hail.utils.HailException, + String, + Int, + ]( msg, - errorId)) + errorId, + )) def _return[T](c: Code[T]): Code[Unit] = { c.end.append(if (c.v != null) @@ -423,47 +1146,51 @@ object Code { newC } - def _printlns(cs: Code[String]*): Code[Unit] = { + def _printlns(cs: Code[String]*): Code[Unit] = _println(cs.reduce[Code[String]] { case (l, r) => (l.concat(r)) }) - } - def _println(c: Code[AnyRef]): Code[Unit] = { + def _println(c: Code[AnyRef]): Code[Unit] = Code( Code.invokeScalaObject1[AnyRef, Unit](scala.Console.getClass, "println", c), - Code.invokeScalaObject0[Unit](scala.Console.getClass, "flush") + Code.invokeScalaObject0[Unit](scala.Console.getClass, "flush"), ) - } def checkcast[T](v: Code[_])(implicit tti: TypeInfo[T]): Code[T] = Code(v, lir.checkcast(tti)) - def boxBoolean(cb: Code[Boolean]): Code[java.lang.Boolean] = Code.newInstance[java.lang.Boolean, Boolean](cb) + def boxBoolean(cb: Code[Boolean]): Code[java.lang.Boolean] = + Code.newInstance[java.lang.Boolean, Boolean](cb) def boxInt(ci: Code[Int]): Code[java.lang.Integer] = 
Code.newInstance[java.lang.Integer, Int](ci) def boxLong(cl: Code[Long]): Code[java.lang.Long] = Code.newInstance[java.lang.Long, Long](cl) - def boxFloat(cf: Code[Float]): Code[java.lang.Float] = Code.newInstance[java.lang.Float, Float](cf) + def boxFloat(cf: Code[Float]): Code[java.lang.Float] = + Code.newInstance[java.lang.Float, Float](cf) - def boxDouble(cd: Code[Double]): Code[java.lang.Double] = Code.newInstance[java.lang.Double, Double](cd) + def boxDouble(cd: Code[Double]): Code[java.lang.Double] = + Code.newInstance[java.lang.Double, Double](cd) - def booleanValue(x: Code[java.lang.Boolean]): Code[Boolean] = toCodeObject(x).invoke[Boolean]("booleanValue") + def booleanValue(x: Code[java.lang.Boolean]): Code[Boolean] = + toCodeObject(x).invoke[Boolean]("booleanValue") def intValue(x: Code[java.lang.Number]): Code[Int] = toCodeObject(x).invoke[Int]("intValue") def longValue(x: Code[java.lang.Number]): Code[Long] = toCodeObject(x).invoke[Long]("longValue") - def floatValue(x: Code[java.lang.Number]): Code[Float] = toCodeObject(x).invoke[Float]("floatValue") + def floatValue(x: Code[java.lang.Number]): Code[Float] = + toCodeObject(x).invoke[Float]("floatValue") - def doubleValue(x: Code[java.lang.Number]): Code[Double] = toCodeObject(x).invoke[Double]("doubleValue") + def doubleValue(x: Code[java.lang.Number]): Code[Double] = + toCodeObject(x).invoke[Double]("doubleValue") - def getStatic[T: ClassTag, S: ClassTag : TypeInfo](field: String): Code[S] = { + def getStatic[T: ClassTag, S: ClassTag: TypeInfo](field: String): Code[S] = { val f = FieldRef[T, S](field) assert(f.isStatic) f.getField(null) } - def putStatic[T: ClassTag, S: ClassTag : TypeInfo](field: String, rhs: Code[S]): Code[Unit] = { + def putStatic[T: ClassTag, S: ClassTag: TypeInfo](field: String, rhs: Code[S]): Code[Unit] = { val f = FieldRef[T, S](field) assert(f.isStatic) f.put(null, rhs) @@ -475,11 +1202,15 @@ object Code { case _ => None } - def currentTimeMillis(): Code[Long] = Code.invokeStatic0[java.lang.System, Long]("currentTimeMillis") + def currentTimeMillis(): Code[Long] = + Code.invokeStatic0[java.lang.System, Long]("currentTimeMillis") - def memoize[T, U](c: Code[T], name: String)(f: (Value[T]) => Code[U])(implicit tti: TypeInfo[T]): Code[U] = { - if (c.start.first == null && - c.v != null) { + def memoize[T, U](c: Code[T], name: String)(f: (Value[T]) => Code[U])(implicit tti: TypeInfo[T]) + : Code[U] = { + if ( + c.start.first == null && + c.v != null + ) { c.v match { case v: lir.LdcX => val t = new Value[T] { @@ -496,14 +1227,27 @@ object Code { Code(lr := c, f(lr)) } - def memoizeAny[T, U](c: Code[_], name: String)(f: (Value[_]) => Code[U])(implicit tti: TypeInfo[T]): Code[U] = + def memoizeAny[T, U]( + c: Code[_], + name: String, + )( + f: (Value[_]) => Code[U] + )(implicit tti: TypeInfo[T] + ): Code[U] = memoize[T, U](coerce[T](c), name)(f)(tti) - def memoize[T1, T2, U](c1: Code[T1], name1: String, - c2: Code[T2], name2: String - )(f: (Value[T1], Value[T2]) => Code[U])(implicit t1ti: TypeInfo[T1], t2ti: TypeInfo[T2]): Code[U] = { + def memoize[T1, T2, U]( + c1: Code[T1], + name1: String, + c2: Code[T2], + name2: String, + )( + f: (Value[T1], Value[T2]) => Code[U] + )(implicit + t1ti: TypeInfo[T1], + t2ti: TypeInfo[T2], + ): Code[U] = memoize(c1, name1)(v1 => memoize(c2, name2)(v2 => f(v1, v2))) - } def toUnit(c: Code[_]): Code[Unit] = { val newC = new VCode(c.start, c.end, null) @@ -519,7 +1263,8 @@ object Code { t.newTuple(elems) } - def loadTuple(modb: ModuleBuilder, elemTypes: 
IndexedSeq[TypeInfo[_]], v: Value[_]): IndexedSeq[Value[_]] = { + def loadTuple(modb: ModuleBuilder, elemTypes: IndexedSeq[TypeInfo[_]], v: Value[_]) + : IndexedSeq[Value[_]] = { val t = modb.tupleClass(elemTypes) t.loadElementsAny(v) } @@ -541,18 +1286,18 @@ trait Code[+T] { def clear(): Unit - def ti: TypeInfo[_] = { + def ti: TypeInfo[_] = if (v == null) UnitInfo else v.ti - } } class VCode[+T]( var _start: lir.Block, var _end: lir.Block, - var _v: lir.ValueX) extends Code[T] { + var _v: lir.ValueX, +) extends Code[T] { // for debugging // val stack = Thread.currentThread().getStackTrace // var clearStack: Array[StackTraceElement] = _ @@ -572,25 +1317,14 @@ class VCode[+T]( _v } - def check(): Unit = { - /* - if (_start == null) { - println(clearStack.mkString("\n")) - println("-----") - println(stack.mkString("\n")) - } - */ + def check(): Unit = + /* if (_start == null) { println(clearStack.mkString("\n")) println("-----") + * println(stack.mkString("\n")) } */ assert(_start != null) - } def clear(): Unit = { - /* - if (clearStack != null) { - println(clearStack.mkString("\n")) - } - assert(clearStack == null) - clearStack = Thread.currentThread().getStackTrace - */ + /* if (clearStack != null) { println(clearStack.mkString("\n")) } assert(clearStack == null) + * clearStack = Thread.currentThread().getStackTrace */ _start = null _end = null @@ -606,7 +1340,8 @@ object CodeKind extends Enumeration { class CCode( private var _entry: lir.Block, private var _Ltrue: lir.Block, - private var _Lfalse: lir.Block) extends Code[Boolean] { + private var _Lfalse: lir.Block, +) extends Code[Boolean] { private var _kind: CodeKind.Kind = _ @@ -719,6 +1454,7 @@ class CCode( class ConstCodeBoolean(val b: Boolean) extends Code[Boolean] { private[this] lazy val ldc = new lir.LdcX(if (b) 1 else 0, BooleanInfo, 0) + private[this] lazy val vc = { val L = new lir.Block() new VCode(L, L, ldc) @@ -750,7 +1486,8 @@ class CodeBoolean(val lhs: Code[Boolean]) extends AnyVal { if (v.a.asInstanceOf[Int] != 0) Ltrue else - Lfalse)) + Lfalse + )) case _ => assert(lhs.v.ti == BooleanInfo, lhs.v.ti) lhs.end.append(lir.ifx(IFNE, lhs.v, Ltrue, Lfalse)) @@ -947,9 +1684,11 @@ class CodeLong(val lhs: Code[Long]) extends AnyVal { def hexString: Code[String] = Code.invokeStatic1[java.lang.Long, Long, String]("toHexString", lhs) - def numberOfLeadingZeros: Code[Int] = Code.invokeStatic1[java.lang.Long, Long, Int]("numberOfLeadingZeros", lhs) + def numberOfLeadingZeros: Code[Int] = + Code.invokeStatic1[java.lang.Long, Long, Int]("numberOfLeadingZeros", lhs) - def numberOfTrailingZeros: Code[Int] = Code.invokeStatic1[java.lang.Long, Long, Int]("numberOfTrailingZeros", lhs) + def numberOfTrailingZeros: Code[Int] = + Code.invokeStatic1[java.lang.Long, Long, Int]("numberOfTrailingZeros", lhs) def bitCount: Code[Int] = Code.invokeStatic1[java.lang.Long, Long, Int]("bitCount", lhs) } @@ -1050,7 +1789,8 @@ class CodeString(val lhs: Code[String]) extends AnyVal { def concat(other: Code[String]): Code[String] = lhs.invoke[String, String]("concat", other) - def println(): Code[Unit] = Code.getStatic[System, PrintStream]("out").invoke[String, Unit]("println", lhs) + def println(): Code[Unit] = + Code.getStatic[System, PrintStream]("out").invoke[String, Unit]("println", lhs) def length(): Code[Int] = lhs.invoke[Int]("length") @@ -1059,8 +1799,10 @@ class CodeString(val lhs: Code[String]) extends AnyVal { class CodeArray[T](val lhs: Code[Array[T]])(implicit tti: TypeInfo[T]) { assert(lhs.ti.asInstanceOf[ArrayInfo[_]].tti == tti) + def 
apply(i: Code[Int]): Code[T] = { - val f: (ValueX, ValueX) => ValueX = (v1: ValueX, v2: ValueX) => lir.insn(tti.aloadOp, tti, FastSeq(v1, v2)) + val f: (ValueX, ValueX) => ValueX = + (v1: ValueX, v2: ValueX) => lir.insn(tti.aloadOp, tti, FastSeq(v1, v2)) Code(lhs, i, f) } @@ -1083,9 +1825,8 @@ class UntypedCodeArray(val lhs: Value[_], tti: TypeInfo[_]) { def apply(i: Code[Int]): Code[_] = Code(lhs, i, lir.insn2(tti.aloadOp)) - def index(cb: EmitCodeBuilder, i: Code[Int]): Value[_] = { + def index(cb: EmitCodeBuilder, i: Code[Int]): Value[_] = cb.memoizeAny(apply(i), tti) - } def update(i: Code[Int], x: Code[_]): Code[Unit] = Code.void(lhs.get, i, x, (lhs, i, x) => lir.stmtOp(tti.astoreOp, lhs, i, x)) @@ -1125,27 +1866,15 @@ class CodeLabel(val L: lir.Block) extends Code[Unit] { null } - def check(): Unit = { - /* - if (_start == null) { - println(clearStack.mkString("\n")) - println("-----") - println(stack.mkString("\n")) - } - */ + def check(): Unit = + /* if (_start == null) { println(clearStack.mkString("\n")) println("-----") + * println(stack.mkString("\n")) } */ assert(_start != null) - } - def clear(): Unit = { - /* - if (clearStack != null) { - println(clearStack.mkString("\n")) - } - assert(clearStack == null) - clearStack = Thread.currentThread().getStackTrace - */ + def clear(): Unit = + /* if (clearStack != null) { println(clearStack.mkString("\n")) } assert(clearStack == null) + * clearStack = Thread.currentThread().getStackTrace */ _start = null - } def goto: Code[Unit] = { val M = new lir.Block() @@ -1155,20 +1884,23 @@ class CodeLabel(val L: lir.Block) extends Code[Unit] { } object Invokeable { - def apply[T](cls: Class[T], c: reflect.Constructor[_]): Invokeable[T, Unit] = new Invokeable[T, Unit]( - cls, - "", - isStatic = false, - isInterface = false, - INVOKESPECIAL, - Type.getConstructorDescriptor(c), - implicitly[ClassTag[Unit]].runtimeClass) + def apply[T](cls: Class[T], c: reflect.Constructor[_]): Invokeable[T, Unit] = + new Invokeable[T, Unit]( + cls, + "", + isStatic = false, + isInterface = false, + INVOKESPECIAL, + Type.getConstructorDescriptor(c), + implicitly[ClassTag[Unit]].runtimeClass, + ) def apply[T, S](cls: Class[T], m: reflect.Method)(implicit sct: ClassTag[S]): Invokeable[T, S] = { val isInterface = m.getDeclaringClass.isInterface val isStatic = reflect.Modifier.isStatic(m.getModifiers) assert(!(isInterface && isStatic)) - new Invokeable[T, S](cls, + new Invokeable[T, S]( + cls, m.getName, isStatic, isInterface, @@ -1179,47 +1911,76 @@ object Invokeable { else INVOKEVIRTUAL, Type.getMethodDescriptor(m), - m.getReturnType) + m.getReturnType, + ) } - def lookupMethod[T, S](cls: Class[T], method: String, parameterTypes: Array[Class[_]])(implicit sct: ClassTag[S]): Invokeable[T, S] = { + def lookupMethod[T, S]( + cls: Class[T], + method: String, + parameterTypes: Array[Class[_]], + )(implicit sct: ClassTag[S] + ): Invokeable[T, S] = { val m = cls.getMethod(method, parameterTypes: _*) - assert(m != null, - s"no such method ${ cls.getName }.$method(${ - parameterTypes.map(_.getName).mkString(", ") - })") + assert( + m != null, + s"no such method ${cls.getName}.$method(${parameterTypes.map(_.getName).mkString(", ")})", + ) // generic type parameters return java.lang.Object instead of the correct class - assert(m.getReturnType.isAssignableFrom(sct.runtimeClass), - s"when invoking ${ cls.getName }.$method(): ${ m.getReturnType.getName }: wrong return type ${ sct.runtimeClass.getName }") + assert( + m.getReturnType.isAssignableFrom(sct.runtimeClass), + s"when 
invoking ${cls.getName}.$method(): ${m.getReturnType.getName}: wrong return type ${sct.runtimeClass.getName}", + ) Invokeable(cls, m) } } -class Invokeable[T, S](tcls: Class[T], +class Invokeable[T, S]( + tcls: Class[T], val name: String, val isStatic: Boolean, val isInterface: Boolean, val invokeOp: Int, val descriptor: String, - val concreteReturnType: Class[_])(implicit sct: ClassTag[S]) { + val concreteReturnType: Class[_], +)(implicit sct: ClassTag[S] +) { def invoke(lhs: Code[T], args: Array[Code[_]]): Code[S] = { val (start, end, argvs) = Code.sequenceValues( if (isStatic) args else - lhs +: args) + lhs +: args + ) val sti = typeInfoFromClassTag(sct) if (sct.runtimeClass == java.lang.Void.TYPE) { end.append( - lir.methodStmt(invokeOp, Type.getInternalName(tcls), name, descriptor, isInterface, sti, argvs)) + lir.methodStmt( + invokeOp, + Type.getInternalName(tcls), + name, + descriptor, + isInterface, + sti, + argvs, + ) + ) new VCode(start, end, null) } else { val t = new lir.Local(null, s"invoke_$name", sti) - var r = lir.methodInsn(invokeOp, Type.getInternalName(tcls), name, descriptor, isInterface, sti, argvs) + var r = lir.methodInsn( + invokeOp, + Type.getInternalName(tcls), + name, + descriptor, + isInterface, + sti, + argvs, + ) if (concreteReturnType != sct.runtimeClass) r = lir.checkcast(sti, r) end.append(lir.store(t, r)) @@ -1229,10 +1990,13 @@ class Invokeable[T, S](tcls: Class[T], } object FieldRef { - def apply[T, S](field: String)(implicit tct: ClassTag[T], sct: ClassTag[S], sti: TypeInfo[S]): FieldRef[T, S] = { + def apply[T, S](field: String)(implicit tct: ClassTag[T], sct: ClassTag[S], sti: TypeInfo[S]) + : FieldRef[T, S] = { val f = tct.runtimeClass.getDeclaredField(field) - assert(f.getType == sct.runtimeClass, - s"when getting field ${ tct.runtimeClass.getName }.$field: ${ f.getType.getName }: wrong type ${ sct.runtimeClass.getName } ") + assert( + f.getType == sct.runtimeClass, + s"when getting field ${tct.runtimeClass.getName}.$field: ${f.getType.getName}: wrong type ${sct.runtimeClass.getName} ", + ) new FieldRef(f) } @@ -1244,7 +2008,7 @@ object Value { } } -trait Value[+T] { self => +trait Value[+T] { def get: Code[T] } @@ -1258,7 +2022,8 @@ trait Settable[T] extends Value[T] { def load(): Code[T] = get } -class ThisLazyFieldRef[T: TypeInfo](cb: ClassBuilder[_], name: String, setup: Code[T]) extends Value[T] { +class ThisLazyFieldRef[T: TypeInfo](cb: ClassBuilder[_], name: String, setup: Code[T]) + extends Value[T] { private[this] val value: Settable[T] = cb.genFieldThisRef[T](name) private[this] val present: Settable[Boolean] = cb.genFieldThisRef[Boolean](s"${name}_present") @@ -1267,7 +2032,7 @@ class ThisLazyFieldRef[T: TypeInfo](cb: ClassBuilder[_], name: String, setup: Co override def get: Code[T] = CodeBuilder.scopedCode(null) { cb => - cb.if_(!present, cb += setm.invoke(cb) ) + cb.if_(!present, cb += cb.invoke(setm, this.cb.this_)) value } } @@ -1275,9 +2040,9 @@ class ThisLazyFieldRef[T: TypeInfo](cb: ClassBuilder[_], name: String, setup: Co class ThisFieldRef[T: TypeInfo](cb: ClassBuilder[_], f: Field[T]) extends Settable[T] { def name: String = f.name - def get: Code[T] = f.get(cb._this) + def get: Code[T] = f.get(cb.this_) - def store(rhs: Code[T]): Code[Unit] = f.put(cb._this, rhs) + def store(rhs: Code[T]): Code[Unit] = f.put(cb.this_, rhs) } class StaticFieldRef[T: TypeInfo](f: StaticField[T]) extends Settable[T] { @@ -1341,66 +2106,246 @@ class FieldRef[T, S](f: reflect.Field)(implicit tct: ClassTag[T], sti: TypeInfo[ Code.void(lhs, rhs, 
lir.putField(tiname, f.getName, sti)) } -class CodeObject[T <: AnyRef : ClassTag](val lhs: Code[T]) { +class CodeObject[T <: AnyRef: ClassTag](val lhs: Code[T]) { def getField[S](field: String)(implicit sct: ClassTag[S], sti: TypeInfo[S]): Code[S] = FieldRef[T, S](field).getField(lhs) def put[S](field: String, rhs: Code[S])(implicit sct: ClassTag[S], sti: TypeInfo[S]): Code[Unit] = FieldRef[T, S](field).put(lhs, rhs) - def invoke[S](method: String, parameterTypes: Array[Class[_]], args: Array[Code[_]]) - (implicit sct: ClassTag[S]): Code[S] = - Invokeable.lookupMethod[T, S](implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]], method, parameterTypes).invoke(lhs, args) + def invoke[S]( + method: String, + parameterTypes: Array[Class[_]], + args: Array[Code[_]], + )(implicit sct: ClassTag[S] + ): Code[S] = + Invokeable.lookupMethod[T, S]( + implicitly[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]], + method, + parameterTypes, + ).invoke(lhs, args) def invoke[S](method: String)(implicit sct: ClassTag[S]): Code[S] = invoke[S](method, Array[Class[_]](), Array[Code[_]]()) - def invoke[A1, S](method: String, a1: Code[A1])(implicit a1ct: ClassTag[A1], - sct: ClassTag[S]): Code[S] = + def invoke[A1, S](method: String, a1: Code[A1])(implicit a1ct: ClassTag[A1], sct: ClassTag[S]) + : Code[S] = invoke[S](method, Array[Class[_]](a1ct.runtimeClass), Array[Code[_]](a1)) - def invoke[A1, A2, S](method: String, a1: Code[A1], a2: Code[A2])(implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], - sct: ClassTag[S]): Code[S] = + def invoke[A1, A2, S]( + method: String, + a1: Code[A1], + a2: Code[A2], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + sct: ClassTag[S], + ): Code[S] = invoke[S](method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass), Array[Code[_]](a1, a2)) - def invoke[A1, A2, A3, S](method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3]) - (implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], sct: ClassTag[S]): Code[S] = - invoke[S](method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass), Array[Code[_]](a1, a2, a3)) + def invoke[A1, A2, A3, S]( + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + sct: ClassTag[S], + ): Code[S] = + invoke[S]( + method, + Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass), + Array[Code[_]](a1, a2, a3), + ) - def invoke[A1, A2, A3, A4, S](method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4]) - (implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], sct: ClassTag[S]): Code[S] = - invoke[S](method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass), Array[Code[_]](a1, a2, a3, a4)) + def invoke[A1, A2, A3, A4, S]( + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + sct: ClassTag[S], + ): Code[S] = + invoke[S]( + method, + Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass), + Array[Code[_]](a1, a2, a3, a4), + ) - def invoke[A1, A2, A3, A4, A5, S](method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5]) - (implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], sct: ClassTag[S]): Code[S] = - invoke[S](method, Array[Class[_]](a1ct.runtimeClass, 
a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass), Array[Code[_]](a1, a2, a3, a4, a5)) + def invoke[A1, A2, A3, A4, A5, S]( + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + sct: ClassTag[S], + ): Code[S] = + invoke[S]( + method, + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + ), + Array[Code[_]](a1, a2, a3, a4, a5), + ) - def invoke[A1, A2, A3, A4, A5, A6, S](method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5], a6: Code[A6]) - (implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], sct: ClassTag[S]): Code[S] = - invoke[S](method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass), Array[Code[_]](a1, a2, a3, a4, a5, a6)) + def invoke[A1, A2, A3, A4, A5, A6, S]( + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + sct: ClassTag[S], + ): Code[S] = + invoke[S]( + method, + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + ), + Array[Code[_]](a1, a2, a3, a4, a5, a6), + ) - def invoke[A1, A2, A3, A4, A5, A6, A7, S](method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], a5: Code[A5], a6: Code[A6], a7: Code[A7]) - (implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], a6ct: ClassTag[A6], a7ct: ClassTag[A7], sct: ClassTag[S]): Code[S] = - invoke[S](method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, a6ct.runtimeClass, a7ct.runtimeClass), Array[Code[_]](a1, a2, a3, a4, a5, a6, a7)) + def invoke[A1, A2, A3, A4, A5, A6, A7, S]( + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + a7: Code[A7], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + a7ct: ClassTag[A7], + sct: ClassTag[S], + ): Code[S] = + invoke[S]( + method, + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + a7ct.runtimeClass, + ), + Array[Code[_]](a1, a2, a3, a4, a5, a6, a7), + ) - def invoke[A1, A2, A3, A4, A5, A6, A7, A8, S](method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], - a5: Code[A5], a6: Code[A6], a7: Code[A7], a8: Code[A8]) - (implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], - a6ct: ClassTag[A6], a7ct: ClassTag[A7], a8ct: ClassTag[A8], sct: ClassTag[S]): Code[S] = { - invoke[S](method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, - a6ct.runtimeClass, a7ct.runtimeClass, a8ct.runtimeClass), Array[Code[_]](a1, a2, a3, a4, a5, a6, a7, a8)) - } + def invoke[A1, A2, A3, A4, A5, A6, A7, A8, S]( + method: String, + a1: Code[A1], + a2: Code[A2], + a3: 
Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + a7: Code[A7], + a8: Code[A8], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + a7ct: ClassTag[A7], + a8ct: ClassTag[A8], + sct: ClassTag[S], + ): Code[S] = + invoke[S]( + method, + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + a7ct.runtimeClass, + a8ct.runtimeClass, + ), + Array[Code[_]](a1, a2, a3, a4, a5, a6, a7, a8), + ) - def invoke[A1, A2, A3, A4, A5, A6, A7, A8, A9, S](method: String, a1: Code[A1], a2: Code[A2], a3: Code[A3], a4: Code[A4], - a5: Code[A5], a6: Code[A6], a7: Code[A7], a8: Code[A8], a9: Code[A9]) - (implicit a1ct: ClassTag[A1], a2ct: ClassTag[A2], a3ct: ClassTag[A3], a4ct: ClassTag[A4], a5ct: ClassTag[A5], - a6ct: ClassTag[A6], a7ct: ClassTag[A7], a8ct: ClassTag[A8], a9ct: ClassTag[A9], sct: ClassTag[S]): Code[S] = { - invoke[S](method, Array[Class[_]](a1ct.runtimeClass, a2ct.runtimeClass, a3ct.runtimeClass, a4ct.runtimeClass, a5ct.runtimeClass, - a6ct.runtimeClass, a7ct.runtimeClass, a8ct.runtimeClass, a9ct.runtimeClass), Array[Code[_]](a1, a2, a3, a4, a5, a6, a7, a8, a9)) - } + def invoke[A1, A2, A3, A4, A5, A6, A7, A8, A9, S]( + method: String, + a1: Code[A1], + a2: Code[A2], + a3: Code[A3], + a4: Code[A4], + a5: Code[A5], + a6: Code[A6], + a7: Code[A7], + a8: Code[A8], + a9: Code[A9], + )(implicit + a1ct: ClassTag[A1], + a2ct: ClassTag[A2], + a3ct: ClassTag[A3], + a4ct: ClassTag[A4], + a5ct: ClassTag[A5], + a6ct: ClassTag[A6], + a7ct: ClassTag[A7], + a8ct: ClassTag[A8], + a9ct: ClassTag[A9], + sct: ClassTag[S], + ): Code[S] = + invoke[S]( + method, + Array[Class[_]]( + a1ct.runtimeClass, + a2ct.runtimeClass, + a3ct.runtimeClass, + a4ct.runtimeClass, + a5ct.runtimeClass, + a6ct.runtimeClass, + a7ct.runtimeClass, + a8ct.runtimeClass, + a9ct.runtimeClass, + ), + Array[Code[_]](a1, a2, a3, a4, a5, a6, a7, a8, a9), + ) } -class CodeNullable[T >: Null : TypeInfo](val lhs: Code[T]) { +class CodeNullable[T >: Null: TypeInfo](val lhs: Code[T]) { def isNull: Code[Boolean] = { val Ltrue = new lir.Block() val Lfalse = new lir.Block() diff --git a/hail/src/main/scala/is/hail/asm4s/CodeBuilder.scala b/hail/src/main/scala/is/hail/asm4s/CodeBuilder.scala index 27b2229a2ca..5cba9e8439e 100644 --- a/hail/src/main/scala/is/hail/asm4s/CodeBuilder.scala +++ b/hail/src/main/scala/is/hail/asm4s/CodeBuilder.scala @@ -1,6 +1,9 @@ package is.hail.asm4s -import is.hail.lir +import is.hail.{lir, HAIL_BUILD_CONFIGURATION} +import is.hail.utils.{toRichIterable, Traceback} + +import org.objectweb.asm.Opcodes.{INVOKESTATIC, INVOKEVIRTUAL} abstract class SettableBuilder { def newSettable[T](name: String)(implicit tti: TypeInfo[T]): Settable[T] @@ -31,6 +34,9 @@ object CodeBuilder { trait CodeBuilderLike { def mb: MethodBuilder[_] + def this_ : Value[_] = + mb.this_ + def isOpenEnded: Boolean // def code: Code[Unit] // debugging only @@ -50,16 +56,18 @@ trait CodeBuilderLike { def updateArray[T: TypeInfo](array: Code[Array[T]], index: Code[Int], value: Code[T]): Unit = append(array.update(index, value)) - def memoize[T: TypeInfo](v: Code[T], optionalName: String = "") - (implicit ev: T =!= Unit) - : Value[T] = + def memoize[T: TypeInfo](v: Code[T], optionalName: String = "")(implicit ev: T =!= Unit) + : Value[T] = v match { case b: ConstCodeBoolean => coerce[T](b.b) case _ => newLocal[T]("memoize" + optionalName, v) } def memoizeAny(v: Code[_], 
ti: TypeInfo[_]): Value[_] = - memoize(v.asInstanceOf[Code[AnyVal]])(ti.asInstanceOf[TypeInfo[AnyVal]], implicitly[AnyVal =!= Unit]) + memoize(v.asInstanceOf[Code[AnyVal]])( + ti.asInstanceOf[TypeInfo[AnyVal]], + implicitly[AnyVal =!= Unit], + ) def assign[T](s: Settable[T], v: Code[T]): Unit = append(s := v) @@ -67,45 +75,49 @@ trait CodeBuilderLike { def assignAny[T](s: Settable[T], v: Code[_]): Unit = append(s := coerce[T](v)) - /* - Note [Evidence Is Unit] - ----------------------- - Here's an example of a common `CodeBuilderLike` foot-gun: - - // previously: - // def ifx(cond: Code[Bool], csq: => Unit): Unit - - cb.ifx(cond, a := expr) - - What's wrong? - -> It doesn't generate the right code despite passing the type-checker! - -> `a := expr` is never emitted because `ifx` evaluates its parameters for effects only; - -> their values are discarded by a conversion to `Unit`. [1] - - How do we fix this? We could write a test and catch this at runtime, but some errors will - slip through. Better to use the type system to prevent these kind of foot-guns. - - The key observation is that the compiler inserts conversions to Unit. We can prevent it - from doing this if we parameterise the type of `csq`: - - // def ifx[A](code: Code[Bool], csq: => A)(implicit ev: A =:= Unit): Unit - - The compiler now infers the type `A` from `csq`; no conversions to `Unit` are made. - We can use an implicit constraint on `A` to fail compilation if `A` is inferred to - anything other than `Unit`, thus catching and preventing this foot-gun! - - It's worth bearing in mind that while this handles simple cases, it won't prevent the - dedicated hacker from working around it. - - [1]: https://github.com/scala/scala/blob/2.13.x/spec/06-expressions.md#value-discarding - */ - - def if_[A](c: => Code[Boolean], emitThen: => A) - (implicit ev: A =:= Unit /* Note [Evidence Is Unit] */): Unit = + /* Note [Evidence Is Unit] + * ----------------------- Here's an example of a common `CodeBuilderLike` foot-gun: + * + * // previously: + * // def ifx(cond: Code[Bool], csq: => Unit): Unit + * + * cb.ifx(cond, a := expr) + * + * What's wrong? + * -> It doesn't generate the right code despite passing the type-checker! + * -> `a := expr` is never emitted because `ifx` evaluates its parameters for effects only; + * -> their values are discarded by a conversion to `Unit`. [1] + * + * How do we fix this? We could write a test and catch this at runtime, but some errors will slip + * through. Better to use the type system to prevent these kind of foot-guns. + * + * The key observation is that the compiler inserts conversions to Unit. We can prevent it from + * doing this if we parameterise the type of `csq`: + * + * // def ifx[A](code: Code[Bool], csq: => A)(implicit ev: A =:= Unit): Unit + * + * The compiler now infers the type `A` from `csq`; no conversions to `Unit` are made. + * We can use an implicit constraint on `A` to fail compilation if `A` is inferred to anything + * other than `Unit`, thus catching and preventing this foot-gun! + * + * It's worth bearing in mind that while this handles simple cases, it won't prevent the dedicated + * hacker from working around it. 
+ * + * [1]: https://github.com/scala/scala/blob/2.13.x/spec/06-expressions.md#value-discarding */ + + def if_[A]( + c: => Code[Boolean], + emitThen: => A, + )(implicit ev: A =:= Unit /* Note [Evidence Is Unit] */ + ): Unit = if_(c, emitThen, ().asInstanceOf[A]) - def if_[A](cond: => Code[Boolean], emitThen: => A, emitElse: => A) - (implicit ev: A =:= Unit /* Note [Evidence Is Unit] */): Unit = { + def if_[A]( + cond: => Code[Boolean], + emitThen: => A, + emitElse: => A, + )(implicit ev: A =:= Unit /* Note [Evidence Is Unit] */ + ): Unit = { val Ltrue = CodeLabel() val Lfalse = CodeLabel() val Lexit = CodeLabel() @@ -119,8 +131,12 @@ trait CodeBuilderLike { define(Lexit) } - def switch[A](discriminant: => Code[Int], emitDefault: => A, cases: IndexedSeq[() => A]) - (implicit ev: A =:= Unit /* Note [Evidence Is Unit] */): Unit = { + def switch[A]( + discriminant: => Code[Int], + emitDefault: => A, + cases: IndexedSeq[() => A], + )(implicit ev: A =:= Unit /* Note [Evidence Is Unit] */ + ): Unit = { val Lexit = CodeLabel() val Lcases = IndexedSeq.fill(cases.length)(CodeLabel()) val Ldefault = CodeLabel() @@ -136,47 +152,65 @@ trait CodeBuilderLike { define(Lexit) } - def loop[A](emitBody: CodeLabel => A) - (implicit ev: A =:= Unit /* Note [Evidence Is Unit] */): Unit = { + def loop[A](emitBody: CodeLabel => A)(implicit ev: A =:= Unit /* Note [Evidence Is Unit] */ ) + : Unit = { val Lstart = CodeLabel() define(Lstart) emitBody(Lstart) } - def while_[A](cond: => Code[Boolean], emitBody: CodeLabel => A) - (implicit ev: A =:= Unit /* Note [Evidence Is Unit] */): Unit = + def while_[A]( + cond: => Code[Boolean], + emitBody: CodeLabel => A, + )(implicit ev: A =:= Unit /* Note [Evidence Is Unit] */ + ): Unit = loop { Lstart => - if_(cond, { - emitBody(Lstart) - goto(Lstart) - }) + if_( + cond, { + emitBody(Lstart) + goto(Lstart) + }, + ) } - def while_[A](c: => Code[Boolean], emitBody: => A) - (implicit ev: A =:= Unit /* Note [Evidence Is Unit] */): Unit = + def while_[A]( + c: => Code[Boolean], + emitBody: => A, + )(implicit ev: A =:= Unit /* Note [Evidence Is Unit] */ + ): Unit = while_(c, (_: CodeLabel) => emitBody) - def for_[A](setup: => A, cond: => Code[Boolean], incr: => A, emitBody: CodeLabel => A) - (implicit ev: A =:= Unit /* Note [Evidence Is Unit] */): Unit = { + def for_[A]( + setup: => A, + cond: => Code[Boolean], + incr: => A, + emitBody: CodeLabel => A, + )(implicit ev: A =:= Unit /* Note [Evidence Is Unit] */ + ): Unit = { setup - while_(cond, { - val Lincr = CodeLabel() - emitBody(Lincr) - define(Lincr) - incr - }) + while_( + cond, { + val Lincr = CodeLabel() + emitBody(Lincr) + define(Lincr) + incr + }, + ) } - def for_[A](setup: => A, cond: => Code[Boolean], incr: => A, body: => A) - (implicit ev: A =:= Unit /* Note [Evidence Is Unit] */): Unit = + def for_[A]( + setup: => A, + cond: => Code[Boolean], + incr: => A, + body: => A, + )(implicit ev: A =:= Unit /* Note [Evidence Is Unit] */ + ): Unit = for_(setup, cond, incr, (_: CodeLabel) => body) - def newLocal[T: TypeInfo](name: String)(implicit ev: T =!= Unit) - : LocalRef[T] = + def newLocal[T: TypeInfo](name: String)(implicit ev: T =!= Unit): LocalRef[T] = mb.newLocal[T](name) - def newLocal[T: TypeInfo](name: String, c: Code[T])(implicit ev: T =!= Unit) - : LocalRef[T] = { + def newLocal[T: TypeInfo](name: String, c: Code[T])(implicit ev: T =!= Unit): LocalRef[T] = { val l = newLocal[T](name) append(l := c) l @@ -191,12 +225,10 @@ trait CodeBuilderLike { ref } - def newField[T: TypeInfo](name: String)(implicit ev: T =!= Unit) 
- : ThisFieldRef[T] = + def newField[T: TypeInfo](name: String)(implicit ev: T =!= Unit): ThisFieldRef[T] = mb.genFieldThisRef[T](name) - def newField[T: TypeInfo](name: String, c: Code[T])(implicit ev: T =!= Unit) - : ThisFieldRef[T] = { + def newField[T: TypeInfo](name: String, c: Code[T])(implicit ev: T =!= Unit): ThisFieldRef[T] = { val f = newField[T](name) append(f := c) f @@ -214,6 +246,21 @@ trait CodeBuilderLike { def goto(L: CodeLabel): Unit = append(L.goto) + def invoke[T](m: MethodBuilder[_], args: Value[_]*): Value[T] = { + val (start, end, argvs) = Code.sequenceValues(args.toFastSeq.map(_.get)) + val op = if (m.isStatic) INVOKESTATIC else INVOKEVIRTUAL + + if (m.returnTypeInfo eq UnitInfo) { + end.append(lir.methodStmt(op, m.lmethod, argvs)) + append(new VCode(start, end, null)) + coerce[T](Code._empty) + } else { + val value = lir.methodInsn(op, m.lmethod, argvs) + val result = new VCode(start, end, value) + memoize[T](result)(m.returnTypeInfo.asInstanceOf[TypeInfo[T]], implicitly[T =!= Unit]) + } + } + def _fatal(msgs: Code[String]*): Unit = append(Code._fatal[Unit](msgs.reduce(_.concat(_)))) @@ -222,13 +269,20 @@ trait CodeBuilderLike { def _throw[T <: java.lang.Throwable](cerr: Code[T]): Unit = append(Code._throw[T, Unit](cerr)) + + def _assert(cond: => Code[Boolean], message: Code[String]): Unit = + if (HAIL_BUILD_CONFIGURATION.isDebug) { + val traceback = mb.cb.modb.getObject[Throwable](new Traceback().fillInStackTrace()) + val assertion = Code.newInstance[AssertionError, String, Throwable](message, traceback) + if_(cond, {}, _throw(assertion)) + } else { + if_(cond, {}, _throw(Code.newInstance[AssertionError, String](message))) + } } class CodeBuilder(val mb: MethodBuilder[_], var code: Code[Unit]) extends CodeBuilderLike { - def isOpenEnded: Boolean = { - val last = code.end.last - (last == null) || !last.isInstanceOf[lir.ControlX] || last.isInstanceOf[lir.ThrowX] - } + def isOpenEnded: Boolean = + code.isOpenEnded override def append(c: Code[Unit]): Unit = { assert(isOpenEnded) @@ -236,16 +290,16 @@ class CodeBuilder(val mb: MethodBuilder[_], var code: Code[Unit]) extends CodeBu } override def define(L: CodeLabel): Unit = - if (isOpenEnded) append(L) else { + if (isOpenEnded) append(L) + else { val tmp = code code = new VCode(code.start, L.end, null) tmp.clear() L.clear() } - def uncheckedAppend(c: Code[Unit]): Unit = { + def uncheckedAppend(c: Code[Unit]): Unit = code = Code(code, c) - } def result(): Code[Unit] = { val tmp = code diff --git a/hail/src/main/scala/is/hail/asm4s/GenericTypeInfo.scala b/hail/src/main/scala/is/hail/asm4s/GenericTypeInfo.scala index 7e71c314e91..f3283162d82 100644 --- a/hail/src/main/scala/is/hail/asm4s/GenericTypeInfo.scala +++ b/hail/src/main/scala/is/hail/asm4s/GenericTypeInfo.scala @@ -1,8 +1,6 @@ package is.hail.asm4s -import is.hail.expr.ir.EmitCodeBuilder - -sealed abstract class MaybeGenericTypeInfo[T : TypeInfo] { +sealed abstract class MaybeGenericTypeInfo[T: TypeInfo] { def castFromGeneric(cb: CodeBuilderLike, x: Value[_]): Value[T] def castToGeneric(cb: CodeBuilderLike, x: Value[T]): Value[_] @@ -11,7 +9,7 @@ sealed abstract class MaybeGenericTypeInfo[T : TypeInfo] { val isGeneric: Boolean } -final case class GenericTypeInfo[T : TypeInfo]() extends MaybeGenericTypeInfo[T] { +final case class GenericTypeInfo[T: TypeInfo]() extends MaybeGenericTypeInfo[T] { val base = typeInfo[T] def castFromGeneric(cb: CodeBuilderLike, _x: Value[_]): Value[T] = { @@ -61,9 +59,9 @@ final case class GenericTypeInfo[T : TypeInfo]() extends 
MaybeGenericTypeInfo[T] cb.memoize(Code.newInstance[java.lang.Character, Char](coerce[Char](x))) case _: UnitInfo.type => Code._null[java.lang.Void] - case cti: ClassInfo[_] => + case _: ClassInfo[_] => x - case ati: ArrayInfo[_] => + case _: ArrayInfo[_] => x } @@ -71,7 +69,7 @@ final case class GenericTypeInfo[T : TypeInfo]() extends MaybeGenericTypeInfo[T] val isGeneric = true } -final case class NotGenericTypeInfo[T : TypeInfo]() extends MaybeGenericTypeInfo[T] { +final case class NotGenericTypeInfo[T: TypeInfo]() extends MaybeGenericTypeInfo[T] { def castFromGeneric(cb: CodeBuilderLike, x: Value[_]): Value[T] = coerce[T](x) def castToGeneric(cb: CodeBuilderLike, x: Value[T]): Value[_] = x diff --git a/hail/src/main/scala/is/hail/asm4s/HailClassLoader.scala b/hail/src/main/scala/is/hail/asm4s/HailClassLoader.scala index b8e9ba097f7..ad660d1d02e 100644 --- a/hail/src/main/scala/is/hail/asm4s/HailClassLoader.scala +++ b/hail/src/main/scala/is/hail/asm4s/HailClassLoader.scala @@ -3,9 +3,9 @@ package is.hail.asm4s class HailClassLoader(parent: ClassLoader) extends ClassLoader(parent) { def loadOrDefineClass(name: String, b: Array[Byte]): Class[_] = { getClassLoadingLock(name).synchronized { - try { + try loadClass(name) - } catch { + catch { case _: java.lang.ClassNotFoundException => defineClass(name, b, 0, b.length) } diff --git a/hail/src/main/scala/is/hail/asm4s/package.scala b/hail/src/main/scala/is/hail/asm4s/package.scala index 5f0b6f4ab7a..ad75e00e2d5 100644 --- a/hail/src/main/scala/is/hail/asm4s/package.scala +++ b/hail/src/main/scala/is/hail/asm4s/package.scala @@ -1,18 +1,19 @@ package is.hail -import org.objectweb.asm.Opcodes._ -import org.objectweb.asm.tree._ - import scala.language.implicitConversions import scala.reflect.ClassTag +import org.objectweb.asm.Opcodes._ +import org.objectweb.asm.tree._ + package asm4s { // lifted from https://github.com/milessabin/shapeless @scala.annotation.implicitAmbiguous("${A} must not be equal to ${B}") sealed abstract class =!=[A, B] extends Serializable + object =!= { - implicit def refl[A, B]: A =!= B = new=!=[A, B] {} + implicit def refl[A, B]: A =!= B = new =!=[A, B] {} implicit def ambig1[A]: A =!= A = ??? implicit def ambig2[A]: A =!= A = ??? } @@ -43,7 +44,7 @@ package asm4s { } class ClassInfo[C](className: String) extends TypeInfo[C] { - val desc = s"L${ className.replace(".", "/") };" + val desc = s"L${className.replace(".", "/")};" override val iname = className.replace(".", "/") val loadOp = ALOAD val storeOp = ASTORE @@ -57,7 +58,7 @@ package asm4s { } class ArrayInfo[T](implicit val tti: TypeInfo[T]) extends TypeInfo[Array[T]] { - val desc = s"[${ tti.desc }" + val desc = s"[${tti.desc}" override val iname = desc.replace(".", "/") val loadOp = ALOAD val storeOp = ASTORE @@ -72,10 +73,9 @@ package asm4s { } package object asm4s { - lazy val theHailClassLoaderForSparkWorkers = { + lazy val theHailClassLoaderForSparkWorkers = // FIXME: how do I ensure this is only created in Spark workers? 
new HailClassLoader(getClass().getClassLoader()) - } def genName(tag: String, baseName: String): String = lir.genName(tag, baseName) @@ -199,7 +199,6 @@ package object asm4s { val astoreOp = FASTORE val returnOp = FRETURN - def newArray() = new IntInsnNode(NEWARRAY, T_FLOAT) override def uninitializedValue: Value[_] = const(0f) @@ -270,7 +269,7 @@ package object asm4s { implicit def toCodeInt(c: Code[Int]): CodeInt = new CodeInt(c) - implicit def byteToCodeInt(c: Code[Byte]): Code[Int] = c.asInstanceOf[Code[Int]] + implicit def byteToCodeInt(c: Code[Byte]): Code[Int] = coerce(c) implicit def byteToCodeInt2(c: Code[Byte]): CodeInt = toCodeInt(byteToCodeInt(c)) @@ -284,12 +283,13 @@ package object asm4s { implicit def toCodeString(c: Code[String]): CodeString = new CodeString(c) - implicit def toCodeArray[T](c: Code[Array[T]])(implicit tti: TypeInfo[T]): CodeArray[T] = new CodeArray(c) + implicit def toCodeArray[T](c: Code[Array[T]])(implicit tti: TypeInfo[T]): CodeArray[T] = + new CodeArray(c) - implicit def toCodeObject[T <: AnyRef : ClassTag](c: Code[T]): CodeObject[T] = + implicit def toCodeObject[T <: AnyRef: ClassTag](c: Code[T]): CodeObject[T] = new CodeObject(c) - implicit def toCodeNullable[T >: Null : TypeInfo](c: Code[T]): CodeNullable[T] = + implicit def toCodeNullable[T >: Null: TypeInfo](c: Code[T]): CodeNullable[T] = new CodeNullable(c) implicit def indexedSeqValueToCode[T](v: IndexedSeq[Value[T]]): IndexedSeq[Code[T]] = v.map(_.get) @@ -308,13 +308,16 @@ package object asm4s { implicit def valueToCodeString(f: Value[String]): CodeString = new CodeString(f.get) - implicit def valueToCodeObject[T <: AnyRef](f: Value[T])(implicit tct: ClassTag[T]): CodeObject[T] = new CodeObject(f.get) + implicit def valueToCodeObject[T <: AnyRef](f: Value[T])(implicit tct: ClassTag[T]) + : CodeObject[T] = new CodeObject(f.get) - implicit def valueToCodeArray[T](c: Value[Array[T]])(implicit tti: TypeInfo[T]): CodeArray[T] = new CodeArray(c) + implicit def valueToCodeArray[T](c: Value[Array[T]])(implicit tti: TypeInfo[T]): CodeArray[T] = + new CodeArray(c) implicit def valueToCodeBoolean(f: Value[Boolean]): CodeBoolean = new CodeBoolean(f.get) - implicit def valueToCodeNullable[T >: Null : TypeInfo](c: Value[T]): CodeNullable[T] = new CodeNullable(c) + implicit def valueToCodeNullable[T >: Null: TypeInfo](c: Value[T]): CodeNullable[T] = + new CodeNullable(c) implicit def toCode[T](f: Settable[T]): Code[T] = f.load() @@ -330,13 +333,16 @@ package object asm4s { implicit def toCodeString(f: Settable[String]): CodeString = new CodeString(f.load()) - implicit def toCodeArray[T](f: Settable[Array[T]])(implicit tti: TypeInfo[T]): CodeArray[T] = new CodeArray(f.load()) + implicit def toCodeArray[T](f: Settable[Array[T]])(implicit tti: TypeInfo[T]): CodeArray[T] = + new CodeArray(f.load()) implicit def toCodeBoolean(f: Settable[Boolean]): CodeBoolean = new CodeBoolean(f.load()) - implicit def toCodeObject[T <: AnyRef : ClassTag](f: Settable[T]): CodeObject[T] = new CodeObject[T](f.load()) + implicit def toCodeObject[T <: AnyRef: ClassTag](f: Settable[T]): CodeObject[T] = + new CodeObject[T](f.load()) - implicit def toCodeNullable[T >: Null : TypeInfo](f: Settable[T]): CodeNullable[T] = new CodeNullable[T](f.load()) + implicit def toCodeNullable[T >: Null: TypeInfo](f: Settable[T]): CodeNullable[T] = + new CodeNullable[T](f.load()) implicit def toLocalRefInt(f: LocalRef[Int]): LocalRefInt = new LocalRefInt(f) diff --git a/hail/src/main/scala/is/hail/backend/Backend.scala 
b/hail/src/main/scala/is/hail/backend/Backend.scala index 4ebdde16688..43be7467ffe 100644 --- a/hail/src/main/scala/is/hail/backend/Backend.scala +++ b/hail/src/main/scala/is/hail/backend/Backend.scala @@ -1,42 +1,44 @@ package is.hail.backend -import java.io._ -import java.nio.charset.StandardCharsets - -import org.json4s._ -import org.json4s.jackson.{JsonMethods, Serialization} - import is.hail.asm4s._ import is.hail.backend.spark.SparkBackend +import is.hail.expr.ir.{ + BaseIR, CodeCacheKey, CompiledFunction, IRParser, IRParserEnvironment, LoweringAnalyses, + SortField, TableIR, TableReader, +} import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} -import is.hail.expr.ir.{CodeCacheKey, CompiledFunction, LoweringAnalyses, SortField, TableIR, TableReader} import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.io.fs._ import is.hail.io.plink.LoadPlink import is.hail.io.vcf.LoadVCF -import is.hail.expr.ir.{IRParser, BaseIR} import is.hail.linalg.BlockMatrix import is.hail.types._ import is.hail.types.encoded.EType -import is.hail.types.virtual.TFloat64 import is.hail.types.physical.PTuple +import is.hail.types.virtual.TFloat64 import is.hail.utils._ import is.hail.variant.ReferenceGenome import scala.collection.mutable import scala.reflect.ClassTag -import is.hail.expr.ir.IRParserEnvironment +import java.io._ +import java.nio.charset.StandardCharsets + +import org.json4s._ +import org.json4s.jackson.{JsonMethods, Serialization} object Backend { private var id: Long = 0L + def nextID(): String = { id += 1 s"hail_query_$id" } private var irID: Int = 0 + def nextIRID(): Int = { irID += 1 irID @@ -66,7 +68,8 @@ abstract class Backend { def broadcast[T: ClassTag](value: T): BroadcastValue[T] - def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String): Unit + def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String) + : Unit def unpersist(backendContext: BackendContext, id: String): Unit @@ -79,7 +82,7 @@ abstract class Backend { fs: FS, collection: Array[Array[Byte]], stageIdentifier: String, - dependency: Option[TableStageDependency] = None + dependency: Option[TableStageDependency] = None, )( f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): Array[Array[Byte]] @@ -89,7 +92,7 @@ abstract class Backend { fs: FS, collection: IndexedSeq[(Array[Byte], Int)], stageIdentifier: String, - dependency: Option[TableStageDependency] = None + dependency: Option[TableStageDependency] = None, )( f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) @@ -97,24 +100,27 @@ abstract class Backend { def stop(): Unit def asSpark(op: String): SparkBackend = - fatal(s"${ getClass.getSimpleName }: $op requires SparkBackend") + fatal(s"${getClass.getSimpleName}: $op requires SparkBackend") def shouldCacheQueryInfo: Boolean = true - def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]): CompiledFunction[T] + def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) + : CompiledFunction[T] var references: Map[String, ReferenceGenome] = Map.empty - def addDefaultReferences(): Unit = { + def addDefaultReferences(): Unit = references = ReferenceGenome.builtinReferences() - } - def addReference(rg: ReferenceGenome) { + def addReference(rg: ReferenceGenome): Unit = { references.get(rg.name) match { case Some(rg2) => if (rg != rg2) { - fatal(s"Cannot add reference genome '${ rg.name }', a 
different reference with that name already exists. Choose a reference name NOT in the following list:\n " + - s"@1", references.keys.truncatable("\n ")) + fatal( + s"Cannot add reference genome '${rg.name}', a different reference with that name already exists. Choose a reference name NOT in the following list:\n " + + s"@1", + references.keys.truncatable("\n "), + ) } case None => references += (rg.name -> rg) @@ -122,23 +128,23 @@ abstract class Backend { } def hasReference(name: String) = references.contains(name) - def removeReference(name: String): Unit = { + + def removeReference(name: String): Unit = references -= name - } def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, sortFields: IndexedSeq[SortField], rt: RTable, - nPartitions: Option[Int] + nPartitions: Option[Int], ): TableReader final def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, sortFields: IndexedSeq[SortField], - rt: RTable + rt: RTable, ): TableReader = lowerDistributedSort(ctx, stage, sortFields, rt, None) @@ -147,30 +153,26 @@ abstract class Backend { inputIR: TableIR, sortFields: IndexedSeq[SortField], rt: RTable, - nPartitions: Option[Int] = None + nPartitions: Option[Int] = None, ): TableReader = { val analyses = LoweringAnalyses.apply(inputIR, ctx) val inputStage = tableToTableStage(ctx, inputIR, analyses) lowerDistributedSort(ctx, inputStage, sortFields, rt, nPartitions) } - def tableToTableStage(ctx: ExecuteContext, - inputIR: TableIR, - analyses: LoweringAnalyses - ): TableStage + def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) + : TableStage - def withExecuteContext[T](methodName: String): (ExecuteContext => T) => T + def withExecuteContext[T](methodName: String)(f: ExecuteContext => T): T - final def valueType(s: String): Array[Byte] = { + final def valueType(s: String): Array[Byte] = withExecuteContext("valueType") { ctx => val v = IRParser.parse_value_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) v.typ.toString.getBytes(StandardCharsets.UTF_8) } - } - private[this] def jsonToBytes(f: => JValue): Array[Byte] = { + private[this] def jsonToBytes(f: => JValue): Array[Byte] = JsonMethods.compact(f).getBytes(StandardCharsets.UTF_8) - } final def tableType(s: String): Array[Byte] = jsonToBytes { withExecuteContext("tableType") { ctx => @@ -193,7 +195,7 @@ abstract class Backend { "element_type" -> JString(t.elementType.toString), "shape" -> JArray(t.shape.map(s => JInt(s)).toList), "is_row_vector" -> JBool(t.isRowVector), - "block_size" -> JInt(t.blockSize) + "block_size" -> JInt(t.blockSize), ) } } @@ -208,15 +210,20 @@ abstract class Backend { } } - def fromFASTAFile(name: String, fastaFile: String, indexFile: String, - xContigs: Array[String], yContigs: Array[String], mtContigs: Array[String], - parInput: Array[String]): Array[Byte] = { + def fromFASTAFile( + name: String, + fastaFile: String, + indexFile: String, + xContigs: Array[String], + yContigs: Array[String], + mtContigs: Array[String], + parInput: Array[String], + ): Array[Byte] = withExecuteContext("fromFASTAFile") { ctx => val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, xContigs, yContigs, mtContigs, parInput) rg.toJSONString.getBytes(StandardCharsets.UTF_8) } - } def parseVCFMetadata(path: String): Array[Byte] = jsonToBytes { withExecuteContext("parseVCFMetadata") { ctx => @@ -226,20 +233,36 @@ abstract class Backend { } } - def importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String): Array[Byte] = { + def 
importFam(path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String) + : Array[Byte] = withExecuteContext("importFam") { ctx => - LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missingValue).getBytes(StandardCharsets.UTF_8) + LoadPlink.importFamJSON(ctx.fs, path, isQuantPheno, delimiter, missingValue).getBytes( + StandardCharsets.UTF_8 + ) } - } - def execute(ir: String, timed: Boolean)(consume: (ExecuteContext, Either[Unit, (PTuple, Long)], String) => Unit): Unit = () + def execute( + ir: String, + timed: Boolean, + )( + consume: (ExecuteContext, Either[Unit, (PTuple, Long)], String) => Unit + ): Unit = () - def encodeToOutputStream(ctx: ExecuteContext, t: PTuple, off: Long, bufferSpecString: String, os: OutputStream): Unit = { + def encodeToOutputStream( + ctx: ExecuteContext, + t: PTuple, + off: Long, + bufferSpecString: String, + os: OutputStream, + ): Unit = { val bs = BufferSpec.parseOrDefault(bufferSpecString) assert(t.size == 1) val elementType = t.fields(0).typ val codec = TypedCodecSpec( - EType.fromPythonTypeEncoding(elementType.virtualType), elementType.virtualType, bs) + EType.fromPythonTypeEncoding(elementType.virtualType), + elementType.virtualType, + bs, + ) assert(t.isFieldDefined(off, 0)) codec.encode(ctx, elementType, t.loadField(off, 0), os) } @@ -247,7 +270,9 @@ abstract class Backend { trait BackendWithCodeCache { private[this] val codeCache: Cache[CodeCacheKey, CompiledFunction[_]] = new Cache(50) - def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]): CompiledFunction[T] = { + + def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) + : CompiledFunction[T] = { codeCache.get(k) match { case Some(v) => v.asInstanceOf[CompiledFunction[T]] case None => @@ -259,5 +284,6 @@ trait BackendWithCodeCache { } trait BackendWithNoCodeCache { - def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]): CompiledFunction[T] = f + def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T]) + : CompiledFunction[T] = f } diff --git a/hail/src/main/scala/is/hail/backend/BackendServer.scala b/hail/src/main/scala/is/hail/backend/BackendServer.scala index 2ce6c7dc263..a36a13b7741 100644 --- a/hail/src/main/scala/is/hail/backend/BackendServer.scala +++ b/hail/src/main/scala/is/hail/backend/BackendServer.scala @@ -1,20 +1,28 @@ package is.hail.backend +import is.hail.utils._ + import java.net.InetSocketAddress import java.nio.charset.StandardCharsets import java.util.concurrent._ -import com.sun.net.httpserver.{HttpContext, HttpExchange, HttpHandler, HttpServer} +import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer} import org.json4s._ -import org.json4s.jackson.{JsonMethods, Serialization} - -import is.hail.utils._ +import org.json4s.jackson.JsonMethods case class IRTypePayload(ir: String) case class LoadReferencesFromDatasetPayload(path: String) -case class FromFASTAFilePayload(name: String, fasta_file: String, index_file: String, - x_contigs: Array[String], y_contigs: Array[String], mt_contigs: Array[String], - par: Array[String]) + +case class FromFASTAFilePayload( + name: String, + fasta_file: String, + index_file: String, + x_contigs: Array[String], + y_contigs: Array[String], + mt_contigs: Array[String], + par: Array[String], +) + case class ParseVCFMetadataPayload(path: String) case class ImportFamPayload(path: String, quant_pheno: Boolean, delimiter: String, missing: String) case class ExecutePayload(ir: String, stream_codec: String, 
timed: Boolean) @@ -27,6 +35,7 @@ class BackendServer(backend: Backend) { // 0 => let the OS pick an available port private[this] val httpServer = HttpServer.create(new InetSocketAddress(0), 10) private[this] val handler = new BackendHttpHandler(backend) + private[this] val thread = { // This HTTP server *must not* start non-daemon threads because such threads keep the JVM // alive. A living JVM indicates to Spark that the job is incomplete. This does not manifest @@ -46,14 +55,14 @@ class BackendServer(backend: Backend) { // > a default implementation is used, which uses the thread which was created by the start() // > method. // - // Source: https://docs.oracle.com/javase/8/docs/jre/api/net/httpserver/spec/com/sun/net/httpserver/HttpServer.html#setExecutor-java.util.concurrent.Executor- + /* Source: + * https://docs.oracle.com/javase/8/docs/jre/api/net/httpserver/spec/com/sun/net/httpserver/HttpServer.html#setExecutor-java.util.concurrent.Executor- */ // httpServer.createContext("/", handler) httpServer.setExecutor(null) val t = Executors.defaultThreadFactory().newThread(new Runnable() { - def run(): Unit = { + def run(): Unit = httpServer.start() - } }) t.setDaemon(true) t @@ -61,13 +70,11 @@ class BackendServer(backend: Backend) { def port = httpServer.getAddress.getPort - def start(): Unit = { + def start(): Unit = thread.start() - } - def stop(): Unit = { + def stop(): Unit = httpServer.stop(10) - } } class BackendHttpHandler(backend: Backend) extends HttpHandler { @@ -77,31 +84,40 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler { try { val body = using(exchange.getRequestBody)(JsonMethods.parse(_)) if (exchange.getRequestURI.getPath == "/execute") { - val config = body.extract[ExecutePayload] - backend.execute(config.ir, config.timed) { (ctx, res, timings) => - exchange.getResponseHeaders().add("X-Hail-Timings", timings) - res match { - case Left(_) => exchange.sendResponseHeaders(200, -1L) - case Right((t, off)) => - exchange.sendResponseHeaders(200, 0L) // 0 => an arbitrarily long response body - using(exchange.getResponseBody()) { os => - backend.encodeToOutputStream(ctx, t, off, config.stream_codec, os) - } - } + val config = body.extract[ExecutePayload] + backend.execute(config.ir, config.timed) { (ctx, res, timings) => + exchange.getResponseHeaders().add("X-Hail-Timings", timings) + res match { + case Left(_) => exchange.sendResponseHeaders(200, -1L) + case Right((t, off)) => + exchange.sendResponseHeaders(200, 0L) // 0 => an arbitrarily long response body + using(exchange.getResponseBody()) { os => + backend.encodeToOutputStream(ctx, t, off, config.stream_codec, os) + } } - return + } + return } val response: Array[Byte] = exchange.getRequestURI.getPath match { case "/value/type" => backend.valueType(body.extract[IRTypePayload].ir) case "/table/type" => backend.tableType(body.extract[IRTypePayload].ir) case "/matrixtable/type" => backend.matrixTableType(body.extract[IRTypePayload].ir) case "/blockmatrix/type" => backend.blockMatrixType(body.extract[IRTypePayload].ir) - case "/references/load" => backend.loadReferencesFromDataset(body.extract[LoadReferencesFromDatasetPayload].path) + case "/references/load" => + backend.loadReferencesFromDataset(body.extract[LoadReferencesFromDatasetPayload].path) case "/references/from_fasta" => val config = body.extract[FromFASTAFilePayload] - backend.fromFASTAFile(config.name, config.fasta_file, config.index_file, - config.x_contigs, config.y_contigs, config.mt_contigs, config.par) - case "/vcf/metadata/parse" => 
backend.parseVCFMetadata(body.extract[ParseVCFMetadataPayload].path) + backend.fromFASTAFile( + config.name, + config.fasta_file, + config.index_file, + config.x_contigs, + config.y_contigs, + config.mt_contigs, + config.par, + ) + case "/vcf/metadata/parse" => + backend.parseVCFMetadata(body.extract[ParseVCFMetadataPayload].path) case "/fam/import" => val config = body.extract[ImportFamPayload] backend.importFam(config.path, config.quant_pheno, config.delimiter, config.missing) @@ -115,7 +131,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler { val errorJson = JObject( "short" -> JString(shortMessage), "expanded" -> JString(expandedMessage), - "error_id" -> JInt(errorId) + "error_id" -> JInt(errorId), ) val errorBytes = JsonMethods.compact(errorJson).getBytes(StandardCharsets.UTF_8) exchange.sendResponseHeaders(500, errorBytes.length) diff --git a/hail/src/main/scala/is/hail/backend/BackendUtils.scala b/hail/src/main/scala/is/hail/backend/BackendUtils.scala index 8d744bd9567..78bb30fd0f0 100644 --- a/hail/src/main/scala/is/hail/backend/BackendUtils.scala +++ b/hail/src/main/scala/is/hail/backend/BackendUtils.scala @@ -10,17 +10,21 @@ import is.hail.io.fs._ import is.hail.services._ import is.hail.utils._ +import scala.annotation.nowarn import scala.util.Try object BackendUtils { type F = AsmFunction3[Region, Array[Byte], Array[Byte], Array[Byte]] } -class BackendUtils(mods: Array[(String, (HailClassLoader, FS, HailTaskContext, Region) => BackendUtils.F)]) { +class BackendUtils( + mods: Array[(String, (HailClassLoader, FS, HailTaskContext, Region) => BackendUtils.F)] +) { import BackendUtils.F - private[this] val loadedModules: Map[String, (HailClassLoader, FS, HailTaskContext, Region) => F] = mods.toMap + private[this] val loadedModules + : Map[String, (HailClassLoader, FS, HailTaskContext, Region) => F] = mods.toMap def getModule(id: String): (HailClassLoader, FS, HailTaskContext, Region) => F = loadedModules(id) @@ -44,7 +48,7 @@ class BackendUtils(mods: Array[(String, (HailClassLoader, FS, HailTaskContext, R globals: Array[Byte], stageName: String, semhash: Option[SemanticHash.Type], - tsd: Option[TableStageDependency] + tsd: Option[TableStageDependency], ): Array[Array[Byte]] = lookupSemanticHashResults(backendContext, stageName, semhash) match { case None => if (contexts.isEmpty) @@ -76,24 +80,24 @@ class BackendUtils(mods: Array[(String, (HailClassLoader, FS, HailTaskContext, R val fsConfigBC = backend.broadcast(fs.getConfiguration()) backend.parallelizeAndComputeWithIndex(backendContext, fs, contexts, stageName, tsd) { (ctx, htc, theHailClassLoader, fs) => - val fsConfig = fsConfigBC.value - val gs = globalsBC.value - fs.setConfiguration(fsConfig) - htc.getRegionPool().scopedRegion { region => - f(theHailClassLoader, fs, htc, region)(region, ctx, gs) - } + val fsConfig = fsConfigBC.value + val gs = globalsBC.value + fs.setConfiguration(fsConfig) + htc.getRegionPool().scopedRegion { region => + f(theHailClassLoader, fs, htc, region)(region, ctx, gs) + } } } log.info(s"[collectDArray|$stageName]: executed ${contexts.length} tasks " + - s"in ${formatTime(System.nanoTime() - t)}" - ) + s"in ${formatTime(System.nanoTime() - t)}") results case Some(cachedResults) => + @nowarn("cat=unused-pat-vars&msg=pattern var c") val remainingContexts = for { - c@(_, k) <- contexts.zipWithIndex + c @ (_, k) <- contexts.zipWithIndex if !cachedResults.containsOrdered[Int](k, _ < _, _._2) } yield c val results = @@ -130,23 +134,24 @@ class BackendUtils(mods: Array[(String, 
(HailClassLoader, FS, HailTaskContext, R val globalsBC = backend.broadcast(globals) val fsConfigBC = backend.broadcast(fs.getConfiguration()) val (failureOpt, successes) = - backend.parallelizeAndComputeWithIndexReturnAllErrors(backendContext, fs, remainingContexts, stageName, tsd) { + backend.parallelizeAndComputeWithIndexReturnAllErrors(backendContext, fs, + remainingContexts, stageName, tsd) { (ctx, htc, theHailClassLoader, fs) => - val fsConfig = fsConfigBC.value - val gs = globalsBC.value - fs.setConfiguration(fsConfig) - htc.getRegionPool().scopedRegion { region => - f(theHailClassLoader, fs, htc, region)(region, ctx, gs) - } + val fsConfig = fsConfigBC.value + val gs = globalsBC.value + fs.setConfiguration(fsConfig) + htc.getRegionPool().scopedRegion { region => + f(theHailClassLoader, fs, htc, region)(region, ctx, gs) + } } (failureOpt, successes) } log.info(s"[collectDArray|$stageName]: executed ${remainingContexts.length} tasks " + - s"in ${formatTime(System.nanoTime() - t)}" - ) + s"in ${formatTime(System.nanoTime() - t)}") - val results = merge[(Array[Byte], Int)](cachedResults, successes.sortBy(_._2), _._2 < _._2) + val results = + merge[(Array[Byte], Int)](cachedResults, successes.sortBy(_._2), _._2 < _._2) semhash.foreach(s => backendContext.executionCache.put(s, results)) failureOpt.foreach(throw _) diff --git a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala index ad079f0a4e6..07a411bb309 100644 --- a/hail/src/main/scala/is/hail/backend/ExecuteContext.scala +++ b/hail/src/main/scala/is/hail/backend/ExecuteContext.scala @@ -1,5 +1,6 @@ package is.hail.backend +import is.hail.{HailContext, HailFeatureFlags} import is.hail.annotations.{Region, RegionPool} import is.hail.asm4s.HailClassLoader import is.hail.backend.local.LocalTaskContext @@ -7,11 +8,11 @@ import is.hail.expr.ir.lowering.IrMetadata import is.hail.io.fs.FS import is.hail.utils._ import is.hail.variant.ReferenceGenome -import is.hail.{HailContext, HailFeatureFlags} + +import scala.collection.mutable import java.io._ import java.security.SecureRandom -import scala.collection.mutable trait TempFileManager { def own(path: String): Unit @@ -40,7 +41,10 @@ class NonOwningTempFileManager(owner: TempFileManager) extends TempFileManager { object ExecuteContext { def scoped[T]()(f: ExecuteContext => T): T = { val (result, _) = ExecutionTimer.time("ExecuteContext.scoped") { timer => - HailContext.sparkBackend("ExecuteContext.scoped").withExecuteContext(timer, selfContainedExecution = false)(f) + HailContext.sparkBackend("ExecuteContext.scoped").withExecuteContext( + timer, + selfContainedExecution = false, + )(f) } result } @@ -53,7 +57,6 @@ object ExecuteContext { timer: ExecutionTimer, tempFileManager: TempFileManager, theHailClassLoader: HailClassLoader, - referenceGenomes: Map[String, ReferenceGenome], flags: HailFeatureFlags, backendContext: BackendContext, )( @@ -69,10 +72,9 @@ object ExecuteContext { timer, tempFileManager, theHailClassLoader, - referenceGenomes, flags, backendContext, - IrMetadata(None) + IrMetadata(None), ))(f(_)) } } @@ -109,20 +111,23 @@ class ExecuteContext( val timer: ExecutionTimer, _tempFileManager: TempFileManager, val theHailClassLoader: HailClassLoader, - val referenceGenomes: Map[String, ReferenceGenome], val flags: HailFeatureFlags, val backendContext: BackendContext, - var irMetadata: IrMetadata + var irMetadata: IrMetadata, ) extends Closeable { - val rngNonce: Long = try { - 
java.lang.Long.decode(getFlag("rng_nonce")) - } catch { - case exc: NumberFormatException => - fatal(s"Could not parse flag rng_nonce as a 64-bit signed integer: ${getFlag("rng_nonce")}", exc) - } + val rngNonce: Long = + try + java.lang.Long.decode(getFlag("rng_nonce")) + catch { + case exc: NumberFormatException => + fatal( + s"Could not parse flag rng_nonce as a 64-bit signed integer: ${getFlag("rng_nonce")}", + exc, + ) + } - val stateManager = HailStateManager(referenceGenomes) + val stateManager = HailStateManager(backend.references) val tempFileManager: TempFileManager = if (_tempFileManager != null) _tempFileManager else new OwningTempFileManager(fs) @@ -131,32 +136,29 @@ class ExecuteContext( private val cleanupFunctions = mutable.ArrayBuffer[() => Unit]() - private[this] val broadcasts = mutable.ArrayBuffer.empty[BroadcastValue[_]] - val memo: mutable.Map[Any, Any] = new mutable.HashMap[Any, Any]() val taskContext: HailTaskContext = new LocalTaskContext(0, 0) - def scopedExecution[T](f: (HailClassLoader, FS, HailTaskContext, Region) => T): T = { + + def scopedExecution[T](f: (HailClassLoader, FS, HailTaskContext, Region) => T): T = using(new LocalTaskContext(0, 0))(f(theHailClassLoader, fs, _, r)) - } def createTmpPath(prefix: String, extension: String = null, local: Boolean = false): String = { - val path = ExecuteContext.createTmpPathNoCleanup(if (local) localTmpdir else tmpdir, prefix, extension) + val path = + ExecuteContext.createTmpPathNoCleanup(if (local) localTmpdir else tmpdir, prefix, extension) tempFileManager.own(path) path } - def ownCloseable(c: Closeable): Unit = { + def ownCloseable(c: Closeable): Unit = cleanupFunctions += c.close - } - def ownCleanup(cleanupFunction: () => Unit): Unit = { + def ownCleanup(cleanupFunction: () => Unit): Unit = cleanupFunctions += cleanupFunction - } def getFlag(name: String): String = flags.get(name) - def getReference(name: String): ReferenceGenome = referenceGenomes(name) + def getReference(name: String): ReferenceGenome = backend.references(name) def shouldWriteIRFiles(): Boolean = getFlag("write_ir_files") != null @@ -170,9 +172,9 @@ class ExecuteContext( var exception: Exception = null for (cleanupFunction <- cleanupFunctions) { - try { + try cleanupFunction() - } catch { + catch { case exc: Exception => if (exception == null) { exception = new RuntimeException("ExecuteContext could not cleanup all resources") diff --git a/hail/src/main/scala/is/hail/backend/ExecutionCache.scala b/hail/src/main/scala/is/hail/backend/ExecutionCache.scala index 5c215201e41..a0e9919ae66 100644 --- a/hail/src/main/scala/is/hail/backend/ExecutionCache.scala +++ b/hail/src/main/scala/is/hail/backend/ExecutionCache.scala @@ -3,14 +3,14 @@ package is.hail.backend import is.hail.HailFeatureFlags import is.hail.expr.ir.analyses.SemanticHash import is.hail.io.fs.FS -import is.hail.utils.{Logging, using} +import is.hail.utils.{using, Logging} -import java.io.{FileNotFoundException, OutputStream} -import java.util.Base64 -import java.util.concurrent.ConcurrentHashMap import scala.io.Source import scala.util.control.NonFatal +import java.io.{FileNotFoundException, OutputStream} +import java.util.Base64 +import java.util.concurrent.ConcurrentHashMap trait ExecutionCache extends Serializable { def lookup(s: SemanticHash.Type): IndexedSeq[(Array[Byte], Int)] @@ -25,10 +25,16 @@ case object ExecutionCache { def fromFlags(flags: HailFeatureFlags, fs: FS, tmpdir: String): ExecutionCache = if (Option(flags.get(Flags.UseFastRestarts)).isEmpty) noCache - else 
fsCache(fs, Option(flags.get(Flags.Cachedir)).getOrElse(s"$tmpdir/hail/${is.hail.HAIL_PIP_VERSION}")) + else fsCache( + fs, + Option(flags.get(Flags.Cachedir)).getOrElse(s"$tmpdir/hail/${is.hail.HAIL_PIP_VERSION}"), + ) def fsCache(fs: FS, cachedir: String): ExecutionCache = { - assert(fs.validUrl(cachedir), s"""Invalid execution cache location (${fs.getClass.getSimpleName}): "$cachedir".""") + assert( + fs.validUrl(cachedir), + s"""Invalid execution cache location (${fs.getClass.getSimpleName}): "$cachedir".""", + ) FSExecutionCache(fs, cachedir) } @@ -51,9 +57,7 @@ case object ExecutionCache { } } -private case class FSExecutionCache(fs: FS, cacheDir: String) - extends ExecutionCache - with Logging { +private case class FSExecutionCache(fs: FS, cacheDir: String) extends ExecutionCache with Logging { private val base64Encode: Array[Byte] => Array[Byte] = Base64.getUrlEncoder.encode @@ -62,11 +66,11 @@ private case class FSExecutionCache(fs: FS, cacheDir: String) Base64.getUrlDecoder.decode override def lookup(s: SemanticHash.Type): IndexedSeq[(Array[Byte], Int)] = - try { + try using(fs.open(at(s))) { Source.fromInputStream(_).getLines().map(Line.read).toIndexedSeq } - } catch { + catch { case _: FileNotFoundException => IndexedSeq.empty @@ -76,13 +80,14 @@ private case class FSExecutionCache(fs: FS, cacheDir: String) } override def put(s: SemanticHash.Type, r: IndexedSeq[(Array[Byte], Int)]): Unit = - fs.write(at(s)) { ostream => r.foreach(Line.write(_, ostream)) } + fs.write(at(s))(ostream => r.foreach(Line.write(_, ostream))) private def at(s: SemanticHash.Type): String = s"$cacheDir/${base64Encode(s.toString.getBytes).mkString}" private case object Line { private type Type = (Array[Byte], Int) + def write(entry: Type, ostream: OutputStream): Unit = { ostream.write(entry._2.toString.getBytes) ostream.write(','.toInt) diff --git a/hail/src/main/scala/is/hail/backend/HailStateManager.scala b/hail/src/main/scala/is/hail/backend/HailStateManager.scala index 07d067c0007..e9053c64e33 100644 --- a/hail/src/main/scala/is/hail/backend/HailStateManager.scala +++ b/hail/src/main/scala/is/hail/backend/HailStateManager.scala @@ -2,5 +2,5 @@ package is.hail.backend import is.hail.variant.ReferenceGenome -case class HailStateManager(val referenceGenomes: Map[String, ReferenceGenome]) extends Serializable { -} +case class HailStateManager(val referenceGenomes: Map[String, ReferenceGenome]) + extends Serializable {} diff --git a/hail/src/main/scala/is/hail/backend/HailTaskContext.scala b/hail/src/main/scala/is/hail/backend/HailTaskContext.scala index d1186cd3239..b490319f277 100644 --- a/hail/src/main/scala/is/hail/backend/HailTaskContext.scala +++ b/hail/src/main/scala/is/hail/backend/HailTaskContext.scala @@ -8,19 +8,14 @@ import java.io.Closeable class TaskFinalizer { val closeables = new BoxedArrayBuilder[Closeable]() - def clear(): Unit = { + def clear(): Unit = closeables.clear() - } - - def addCloseable(c: Closeable): Unit = { + + def addCloseable(c: Closeable): Unit = closeables += c - } - def closeAll(): Unit = { - (0 until closeables.size).foreach { i => - closeables(i).close() - } - } + def closeAll(): Unit = + (0 until closeables.size).foreach(i => closeables(i).close()) } abstract class HailTaskContext extends AutoCloseable { @@ -37,7 +32,7 @@ abstract class HailTaskContext extends AutoCloseable { def partSuffix(): String = { val rng = new java.security.SecureRandom() val fileUUID = new java.util.UUID(rng.nextLong(), rng.nextLong()) - s"${ stageId() }-${ partitionId() }-${ attemptNumber() 
}-$fileUUID" + s"${stageId()}-${partitionId()}-${attemptNumber()}-$fileUUID" } val finalizers = new BoxedArrayBuilder[TaskFinalizer]() @@ -49,12 +44,12 @@ abstract class HailTaskContext extends AutoCloseable { } def close(): Unit = { - log.info(s"TaskReport: stage=${ stageId() }, partition=${ partitionId() }, attempt=${ attemptNumber() }, " + - s"peakBytes=${ thePool.getHighestTotalUsage }, peakBytesReadable=${ formatSpace(thePool.getHighestTotalUsage) }, "+ - s"chunks requested=${thePool.getUsage._1}, cache hits=${thePool.getUsage._2}") - (0 until finalizers.size).foreach { i => - finalizers(i).closeAll() - } + log.info( + s"TaskReport: stage=${stageId()}, partition=${partitionId()}, attempt=${attemptNumber()}, " + + s"peakBytes=${thePool.getHighestTotalUsage}, peakBytesReadable=${formatSpace(thePool.getHighestTotalUsage)}, " + + s"chunks requested=${thePool.getUsage._1}, cache hits=${thePool.getUsage._2}" + ) + (0 until finalizers.size).foreach(i => finalizers(i).closeAll()) thePool.close() } } diff --git a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala index 96d73a34719..cdb3105b012 100644 --- a/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala +++ b/hail/src/main/scala/is/hail/backend/local/LocalBackend.scala @@ -1,34 +1,34 @@ package is.hail.backend.local -import is.hail.annotations.{Region, SafeRow, UnsafeRow} +import is.hail.{HailContext, HailFeatureFlags} +import is.hail.annotations.{Region, SafeRow} import is.hail.asm4s._ import is.hail.backend._ +import is.hail.expr.Validate +import is.hail.expr.ir.{IRParser, _} import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.lowering._ -import is.hail.expr.ir.{IRParser, _} -import is.hail.expr.{JSONAnnotationImpex, Validate} -import is.hail.io.fs._ -import is.hail.io.plink.LoadPlink import is.hail.io.{BufferSpec, TypedCodecSpec} +import is.hail.io.fs._ import is.hail.linalg.BlockMatrix import is.hail.types._ import is.hail.types.encoded.EType import is.hail.types.physical.PTuple -import is.hail.types.physical.stypes.{PTypeReferenceSingleCodeType, SingleCodeType} +import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual.TVoid import is.hail.utils._ import is.hail.variant.ReferenceGenome -import is.hail.{HailContext, HailFeatureFlags} -import org.apache.hadoop -import org.json4s._ -import org.json4s.jackson.{JsonMethods, Serialization} -import org.sparkproject.guava.util.concurrent.MoreExecutors -import java.io.PrintWriter -import java.nio.charset.StandardCharsets import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import java.io.PrintWriter + +import org.apache.hadoop +import org.json4s._ +import org.json4s.jackson.Serialization +import org.sparkproject.guava.util.concurrent.MoreExecutors + class LocalBroadcastValue[T](val value: T) extends BroadcastValue[T] with Serializable class LocalTaskContext(val partitionId: Int, val stageId: Int) extends HailTaskContext { @@ -45,7 +45,7 @@ object LocalBackend { logFile: String = "hail.log", quiet: Boolean = false, append: Boolean = false, - skipLoggingConfiguration: Boolean = false + skipLoggingConfiguration: Boolean = false, ): LocalBackend = synchronized { require(theLocalBackend == null) @@ -54,7 +54,7 @@ object LocalBackend { theLocalBackend = new LocalBackend( tmpdir, gcsRequesterPaysProject, - gcsRequesterPaysBuckets + gcsRequesterPaysBuckets, ) theLocalBackend.addDefaultReferences() theLocalBackend @@ -75,10 +75,11 @@ object 
LocalBackend { class LocalBackend( val tmpdir: String, gcsRequesterPaysProject: String, - gcsRequesterPaysBuckets: String + gcsRequesterPaysBuckets: String, ) extends Backend with BackendWithCodeCache { // FIXME don't rely on hadoop val hadoopConf = new hadoop.conf.Configuration() + if (gcsRequesterPaysProject != null) { if (gcsRequesterPaysBuckets == null) { hadoopConf.set("fs.gs.requester.pays.mode", "AUTO") @@ -89,12 +90,14 @@ class LocalBackend( hadoopConf.set("fs.gs.requester.pays.buckets", gcsRequesterPaysBuckets) } } + hadoopConf.set( "hadoop.io.compression.codecs", "org.apache.hadoop.io.compress.DefaultCodec," + "is.hail.io.compress.BGzipCodec," + "is.hail.io.compress.BGzipCodecTbi," - + "org.apache.hadoop.io.compress.GzipCodec") + + "org.apache.hadoop.io.compress.GzipCodec", + ) private[this] val flags = HailFeatureFlags.fromEnv() private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) @@ -108,19 +111,38 @@ class LocalBackend( val fs: FS = new HadoopFS(new SerializableHadoopConfiguration(hadoopConf)) def withExecuteContext[T](timer: ExecutionTimer): (ExecuteContext => T) => T = - ExecuteContext.scoped(tmpdir, tmpdir, this, fs, timer, null, theHailClassLoader, this.references, flags, new BackendContext { - override val executionCache: ExecutionCache = - ExecutionCache.fromFlags(flags, fs, tmpdir) - }) - - def withExecuteContext[T](methodName: String): (ExecuteContext => T) => T = { f => - ExecutionTimer.logTime(methodName) { timer => - ExecuteContext.scoped(tmpdir, tmpdir, this, fs, timer, null, theHailClassLoader, this.references, flags, new BackendContext { + ExecuteContext.scoped( + tmpdir, + tmpdir, + this, + fs, + timer, + null, + theHailClassLoader, + flags, + new BackendContext { override val executionCache: ExecutionCache = ExecutionCache.fromFlags(flags, fs, tmpdir) - })(f) + }, + ) + + override def withExecuteContext[T](methodName: String)(f: ExecuteContext => T): T = + ExecutionTimer.logTime(methodName) { timer => + ExecuteContext.scoped( + tmpdir, + tmpdir, + this, + fs, + timer, + null, + theHailClassLoader, + flags, + new BackendContext { + override val executionCache: ExecutionCache = + ExecutionCache.fromFlags(flags, fs, tmpdir) + }, + )(f) } - } def broadcast[T: ClassTag](value: T): BroadcastValue[T] = new LocalBroadcastValue[T](value) @@ -137,15 +159,13 @@ class LocalBackend( fs: FS, collection: Array[Array[Byte]], stageIdentifier: String, - dependency: Option[TableStageDependency] = None + dependency: Option[TableStageDependency] = None, )( f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): Array[Array[Byte]] = { val stageId = nextStageId() collection.zipWithIndex.map { case (c, i) => - using(new LocalTaskContext(i, stageId)) { htc => - f(c, htc, theHailClassLoader, fs) - } + using(new LocalTaskContext(i, stageId))(htc => f(c, htc, theHailClassLoader, fs)) } } @@ -154,17 +174,19 @@ class LocalBackend( fs: FS, collection: IndexedSeq[(Array[Byte], Int)], stageIdentifier: String, - dependency: Option[TableStageDependency] = None - )(f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]) - : (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = { + dependency: Option[TableStageDependency] = None, + )( + f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] + ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = { val stageId = nextStageId() runAllKeepFirstError(MoreExecutors.sameThreadExecutor) { collection.map { case (c, i) => ( - () => using(new LocalTaskContext(i, stageId)) { 
- f(c, _, theHailClassLoader, fs) - }, - i + () => + using(new LocalTaskContext(i, stageId)) { + f(c, _, theHailClassLoader, fs) + }, + i, ) } } @@ -174,19 +196,27 @@ class LocalBackend( def stop(): Unit = LocalBackend.stop() - private[this] def _jvmLowerAndExecute(ctx: ExecuteContext, ir0: IR, print: Option[PrintWriter] = None): Either[Unit, (PTuple, Long)] = { - val ir = LoweringPipeline.darrayLowerer(true)(DArrayLowering.All).apply(ctx, ir0).asInstanceOf[IR] + private[this] def _jvmLowerAndExecute( + ctx: ExecuteContext, + ir0: IR, + print: Option[PrintWriter] = None, + ): Either[Unit, (PTuple, Long)] = { + val ir = + LoweringPipeline.darrayLowerer(true)(DArrayLowering.All).apply(ctx, ir0).asInstanceOf[IR] if (!Compilable(ir)) - throw new LowererUnsupportedOperation(s"lowered to uncompilable IR: ${ Pretty(ctx, ir) }") + throw new LowererUnsupportedOperation(s"lowered to uncompilable IR: ${Pretty(ctx, ir)}") if (ir.typ == TVoid) { - val (pt, f) = ctx.timer.time("Compile") { - Compile[AsmFunction1RegionUnit](ctx, + val (_, f) = ctx.timer.time("Compile") { + Compile[AsmFunction1RegionUnit]( + ctx, FastSeq(), - FastSeq(classInfo[Region]), UnitInfo, + FastSeq(classInfo[Region]), + UnitInfo, ir, - print = print) + print = print, + ) } ctx.timer.time("Run") { @@ -194,11 +224,14 @@ class LocalBackend( } } else { val (Some(PTypeReferenceSingleCodeType(pt: PTuple)), f) = ctx.timer.time("Compile") { - Compile[AsmFunction1RegionLong](ctx, + Compile[AsmFunction1RegionLong]( + ctx, FastSeq(), - FastSeq(classInfo[Region]), LongInfo, + FastSeq(classInfo[Region]), + LongInfo, MakeTuple.ordered(FastSeq(ir)), - print = print) + print = print, + ) } ctx.timer.time("Run") { @@ -211,14 +244,13 @@ class LocalBackend( TypeCheck(ctx, ir) Validate(ir) val queryID = Backend.nextID() - log.info(s"starting execution of query $queryID of initial size ${ IRSize(ir) }") + log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}") ctx.irMetadata = ctx.irMetadata.copy(semhash = SemanticHash(ctx)(ir)) val res = _jvmLowerAndExecute(ctx, ir) log.info(s"finished execution of query $queryID") res } - def executeToJavaValue(timer: ExecutionTimer, ir: IR): (Any, ExecutionTimer) = withExecuteContext(timer) { ctx => val result = _execute(ctx, ir) match { @@ -238,7 +270,10 @@ class LocalBackend( val elementType = pt.fields(0).typ assert(pt.isFieldDefined(off, 0)) val codec = TypedCodecSpec( - EType.fromPythonTypeEncoding(elementType.virtualType), elementType.virtualType, bs) + EType.fromPythonTypeEncoding(elementType.virtualType), + elementType.virtualType, + bs, + ) codec.encode(ctx, elementType, pt.loadField(off, 0)) } result @@ -251,11 +286,12 @@ class LocalBackend( val t = ir.typ assert(t.isRealizable) val queryID = Backend.nextID() - log.info(s"starting execution of query $queryID} of initial size ${ IRSize(ir) }") + log.info(s"starting execution of query $queryID} of initial size ${IRSize(ir)}") val retVal = _execute(ctx, ir) val literalIR = retVal match { - case Left(x) => throw new HailException("Can't create literal") - case Right((pt, addr)) => GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) + case Left(_) => throw new HailException("Can't create literal") + case Right((pt, addr)) => + GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) } log.info(s"finished execution of query $queryID") addJavaIR(literalIR) @@ -263,16 +299,24 @@ class LocalBackend( } } - override def execute(ir: String, timed: Boolean)(consume: (ExecuteContext, Either[Unit, (PTuple, Long)], 
String) => Unit): Unit = { + override def execute( + ir: String, + timed: Boolean, + )( + consume: (ExecuteContext, Either[Unit, (PTuple, Long)], String) => Unit + ): Unit = { withExecuteContext("LocalBackend.execute") { ctx => val res = ctx.timer.time("execute") { - val irData = IRParser.parse_value_ir(ir, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + val irData = + IRParser.parse_value_ir(ir, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) val queryID = Backend.nextID() - log.info(s"starting execution of query $queryID of initial size ${ IRSize(irData) }") + log.info(s"starting execution of query $queryID of initial size ${IRSize(irData)}") _execute(ctx, irData) } ctx.timer.finish() - val timings = if (timed) Serialization.write(Map("timings" -> ctx.timer.toMap))(new DefaultFormats {}) else "" + val timings = if (timed) + Serialization.write(Map("timings" -> ctx.timer.toMap))(new DefaultFormats {}) + else "" consume(ctx, res, timings) } } @@ -280,78 +324,90 @@ class LocalBackend( def pyAddReference(jsonConfig: String): Unit = addReference(ReferenceGenome.fromJSON(jsonConfig)) def pyRemoveReference(name: String): Unit = removeReference(name) - def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = { + def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = ExecutionTimer.logTime("LocalBackend.pyReferenceAddLiftover") { timer => - withExecuteContext(timer) { ctx => - references(name).addLiftover(ctx, chainFile, destRGName) - } + withExecuteContext(timer)(ctx => references(name).addLiftover(ctx, chainFile, destRGName)) } - } - def pyRemoveLiftover(name: String, destRGName: String) = references(name).removeLiftover(destRGName) - def pyFromFASTAFile(name: String, fastaFile: String, indexFile: String, - xContigs: java.util.List[String], yContigs: java.util.List[String], mtContigs: java.util.List[String], - parInput: java.util.List[String]): String = { + def pyRemoveLiftover(name: String, destRGName: String) = + references(name).removeLiftover(destRGName) + + def pyFromFASTAFile( + name: String, + fastaFile: String, + indexFile: String, + xContigs: java.util.List[String], + yContigs: java.util.List[String], + mtContigs: java.util.List[String], + parInput: java.util.List[String], + ): String = { ExecutionTimer.logTime("LocalBackend.pyFromFASTAFile") { timer => withExecuteContext(timer) { ctx => - val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, - xContigs.asScala.toArray, yContigs.asScala.toArray, mtContigs.asScala.toArray, parInput.asScala.toArray) + val rg = ReferenceGenome.fromFASTAFile( + ctx, + name, + fastaFile, + indexFile, + xContigs.asScala.toArray, + yContigs.asScala.toArray, + mtContigs.asScala.toArray, + parInput.asScala.toArray, + ) rg.toJSONString } } } - def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = { + def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = ExecutionTimer.logTime("LocalBackend.pyAddSequence") { timer => - withExecuteContext(timer) { ctx => - references(name).addSequence(ctx, fastaFile, indexFile) - } + withExecuteContext(timer)(ctx => references(name).addSequence(ctx, fastaFile, indexFile)) } - } + def pyRemoveSequence(name: String) = references(name).removeSequence() - def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = { + def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = ExecutionTimer.logTime("LocalBackend.parse_value_ir") { timer => withExecuteContext(timer) { ctx => - 
IRParser.parse_value_ir(s, IRParserEnvironment(ctx, persistedIR.toMap), BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*)) + IRParser.parse_value_ir( + s, + IRParserEnvironment(ctx, persistedIR.toMap), + BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*), + ) } } - } - def parse_table_ir(s: String): TableIR = { + def parse_table_ir(s: String): TableIR = ExecutionTimer.logTime("LocalBackend.parse_table_ir") { timer => withExecuteContext(timer) { ctx => IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) } } - } - def parse_matrix_ir(s: String): MatrixIR = { + def parse_matrix_ir(s: String): MatrixIR = ExecutionTimer.logTime("LocalBackend.parse_matrix_ir") { timer => withExecuteContext(timer) { ctx => IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) } } - } - def parse_blockmatrix_ir(s: String): BlockMatrixIR = { + def parse_blockmatrix_ir(s: String): BlockMatrixIR = ExecutionTimer.logTime("LocalBackend.parse_blockmatrix_ir") { timer => withExecuteContext(timer) { ctx => IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) } } - } override def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, sortFields: IndexedSeq[SortField], rt: RTable, - nPartitions: Option[Int] + nPartitions: Option[Int], ): TableReader = LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt, nPartitions) - def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String): Unit = ??? + def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String) + : Unit = ??? def unpersist(backendContext: BackendContext, id: String): Unit = ??? @@ -359,10 +415,7 @@ class LocalBackend( def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = ??? 
- def tableToTableStage(ctx: ExecuteContext, - inputIR: TableIR, - analyses: LoweringAnalyses - ): TableStage = { + def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) + : TableStage = LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses) - } } diff --git a/hail/src/main/scala/is/hail/backend/service/Main.scala b/hail/src/main/scala/is/hail/backend/service/Main.scala index 910f1e930ad..698f5ffa23c 100644 --- a/hail/src/main/scala/is/hail/backend/service/Main.scala +++ b/hail/src/main/scala/is/hail/backend/service/Main.scala @@ -1,19 +1,13 @@ package is.hail.backend.service -import is.hail.HailContext -import org.apache.log4j.{LogManager, PropertyConfigurator} - -import java.util.Properties - object Main { val WORKER = "worker" val DRIVER = "driver" - def main(argv: Array[String]): Unit = { + def main(argv: Array[String]): Unit = argv(3) match { case WORKER => Worker.main(argv) case DRIVER => ServiceBackendAPI.main(argv) - case kind => throw new RuntimeException(s"unknown kind: ${kind}") + case kind => throw new RuntimeException(s"unknown kind: $kind") } - } } diff --git a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala index 205361f35fb..5187e0776bf 100644 --- a/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala +++ b/hail/src/main/scala/is/hail/backend/service/ServiceBackend.scala @@ -1,17 +1,18 @@ package is.hail.backend.service +import is.hail.{HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ import is.hail.expr.Validate +import is.hail.expr.ir.{ + Compile, IR, IRParser, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck, +} import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.functions.IRFunctionRegistry import is.hail.expr.ir.lowering._ -import is.hail.expr.ir.{Compile, IR, IRParser, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck} -import is.hail.io.fs._ -import is.hail.io.plink.LoadPlink -import is.hail.io.vcf.LoadVCF import is.hail.io.{BufferSpec, TypedCodecSpec} +import is.hail.io.fs._ import is.hail.linalg.BlockMatrix import is.hail.services._ import is.hail.services.batch_client.BatchClient @@ -22,20 +23,19 @@ import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual._ import is.hail.utils._ import is.hail.variant.ReferenceGenome -import is.hail.{HailContext, HailFeatureFlags} -import org.apache.log4j.Logger -import org.json4s.JsonAST._ -import org.json4s.jackson.{JsonMethods, Serialization} -import org.json4s.{DefaultFormats, Extraction, Formats} + +import scala.annotation.switch +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import java.io._ import java.nio.charset.StandardCharsets import java.util.concurrent._ -import scala.annotation.switch -import scala.collection.mutable -import scala.language.higherKinds -import scala.reflect.ClassTag -import scala.collection.JavaConverters._ + +import org.apache.log4j.Logger +import org.json4s.{DefaultFormats, Formats} +import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods class ServiceBackendContext( val billingProject: String, @@ -47,8 +47,7 @@ class ServiceBackendContext( val cloudfuseConfig: Array[CloudfuseConfig], val profile: Boolean, val executionCache: ExecutionCache, -) extends BackendContext with Serializable { -} +) extends BackendContext with Serializable {} object ServiceBackend { 
private val log = Logger.getLogger(getClass.getName()) @@ -60,12 +59,12 @@ object ServiceBackend { batchClient: BatchClient, batchId: Option[Long], scratchDir: String = sys.env.get("HAIL_WORKER_SCRATCH_DIR").getOrElse(""), - rpcConfig: ServiceBackendRPCPayload + rpcConfig: ServiceBackendRPCPayload, ): ServiceBackend = { val flags = HailFeatureFlags.fromMap(rpcConfig.flags) val shouldProfile = flags.get("profile") != null - val fs = FS.cloudSpecificFS(s"${scratchDir}/secrets/gsa-key/key.json", Some(flags)) + val fs = FS.cloudSpecificFS(s"$scratchDir/secrets/gsa-key/key.json", Some(flags)) val backendContext = new ServiceBackendContext( rpcConfig.billing_project, @@ -76,7 +75,7 @@ object ServiceBackend { rpcConfig.regions, rpcConfig.cloudfuse_configs, shouldProfile, - ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir) + ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir), ) val backend = new ServiceBackend( @@ -89,13 +88,11 @@ object ServiceBackend { rpcConfig.tmp_dir, fs, backendContext, - scratchDir + scratchDir, ) backend.addDefaultReferences() - rpcConfig.custom_references.foreach { s => - backend.addReference(ReferenceGenome.fromJSON(s)) - } + rpcConfig.custom_references.foreach(s => backend.addReference(ReferenceGenome.fromJSON(s))) rpcConfig.liftovers.foreach { case (sourceGenome, liftoversForSource) => liftoversForSource.foreach { case (destGenome, chainFile) => backend.addLiftover(sourceGenome, chainFile, destGenome) @@ -133,9 +130,9 @@ class ServiceBackend( def broadcast[T: ClassTag](_value: T): BroadcastValue[T] = { using(new ObjectOutputStream(new ByteArrayOutputStream())) { os => - try { + try os.writeObject(_value) - } catch { + catch { case e: Exception => fatal(_value.toString, e) } @@ -157,13 +154,12 @@ class ServiceBackend( fs: FS, collection: Array[Array[Byte]], stageIdentifier: String, - dependency: Option[TableStageDependency] = None, - f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] + f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte], ): (String, String, Int) = { val backendContext = _backendContext.asInstanceOf[ServiceBackendContext] val n = collection.length val token = tokenUrlSafe(32) - val root = s"${ backendContext.remoteTmpDir }parallelizeAndComputeWithIndex/$token" + val root = s"${backendContext.remoteTmpDir}parallelizeAndComputeWithIndex/$token" log.info(s"parallelizeAndComputeWithIndex: $token: nPartitions $n") log.info(s"parallelizeAndComputeWithIndex: $token: writing f and contexts") @@ -171,7 +167,7 @@ class ServiceBackend( val uploadFunction = executor.submit[Unit](() => retryTransientErrors { fs.writePDOS(s"$root/f") { fos => - using(new ObjectOutputStream(fos)) { oos => oos.writeObject(f) } + using(new ObjectOutputStream(fos))(oos => oos.writeObject(f)) } } ) @@ -186,9 +182,7 @@ class ServiceBackend( os.writeInt(len) o += len } - collection.foreach { context => - os.write(context) - } + collection.foreach(context => os.write(context)) } } ) @@ -205,7 +199,8 @@ class ServiceBackend( resources = resources.merge(JObject("memory" -> JString(backendContext.workerMemory))) } if (backendContext.storageRequirement != "0Gi") { - resources = resources.merge(JObject("storage" -> JString(backendContext.storageRequirement))) + resources = + resources.merge(JObject("storage" -> JString(backendContext.storageRequirement))) } JObject( "always_run" -> JBool(false), @@ -214,29 +209,29 @@ class ServiceBackend( "process" -> JObject( "jar_spec" -> JObject( "type" -> JString("jar_url"), - "value" -> 
JString(jarLocation) + "value" -> JString(jarLocation), ), "command" -> JArray(List( JString(Main.WORKER), JString(root), JString(s"$i"), - JString(s"$n"))), + JString(s"$n"), + )), "type" -> JString("jvm"), "profile" -> JBool(backendContext.profile), ), "attributes" -> JObject( - "name" -> JString(s"${ name }_stage${ stageCount }_${ stageIdentifier }_job$i"), + "name" -> JString(s"${name}_stage${stageCount}_${stageIdentifier}_job$i") ), - "mount_tokens" -> JBool(true), "resources" -> resources, "regions" -> JArray(backendContext.regions.map(JString).toList), "cloudfuse" -> JArray(backendContext.cloudfuseConfig.map { config => JObject( "bucket" -> JString(config.bucket), "mount_path" -> JString(config.mount_path), - "read_only" -> JBool(config.read_only) + "read_only" -> JBool(config.read_only), ) - }.toList) + }.toList), ) } @@ -251,8 +246,10 @@ class ServiceBackend( "billing_project" -> JString(backendContext.billingProject), "n_jobs" -> JInt(n), "token" -> JString(token), - "attributes" -> JObject("name" -> JString(name + "_" + stageCount))), - jobs) + "attributes" -> JObject("name" -> JString(name + "_" + stageCount)), + ), + jobs, + ) (batchId, 1L) } @@ -287,23 +284,25 @@ class ServiceBackend( fs: FS, collection: Array[Array[Byte]], stageIdentifier: String, - dependency: Option[TableStageDependency] = None + dependency: Option[TableStageDependency] = None, )( f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): Array[Array[Byte]] = { - val (token, root, n) = submitAndWaitForBatch(_backendContext, fs, collection, stageIdentifier, dependency, f) + val (token, root, n) = + submitAndWaitForBatch(_backendContext, fs, collection, stageIdentifier, f) log.info(s"parallelizeAndComputeWithIndex: $token: reading results") val startTime = System.nanoTime() - val results = try { - executor.invokeAll[Array[Byte]]( - IndexedSeq.range(0, n).map { i => - (() => readResult(root, i)): Callable[Array[Byte]] - }.asJavaCollection - ).asScala.map(_.get).toArray - } catch { - case exc: ExecutionException if exc.getCause() != null => throw exc.getCause() - } + val results = + try + executor.invokeAll[Array[Byte]]( + IndexedSeq.range(0, n).map { i => + (() => readResult(root, i)): Callable[Array[Byte]] + }.asJavaCollection + ).asScala.map(_.get).toArray + catch { + case exc: ExecutionException if exc.getCause() != null => throw exc.getCause() + } val resultsReadingSeconds = (System.nanoTime() - startTime) / 1000000000.0 val rate = results.length / resultsReadingSeconds val byterate = results.map(_.length).sum / resultsReadingSeconds / 1024 / 1024 @@ -316,13 +315,15 @@ class ServiceBackend( fs: FS, collection: IndexedSeq[(Array[Byte], Int)], stageIdentifier: String, - dependency: Option[TableStageDependency] = None - )(f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] + dependency: Option[TableStageDependency] = None, + )( + f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = { - val (token, root, n) = submitAndWaitForBatch(_backendContext, fs, collection.map(_._1).toArray, stageIdentifier, dependency, f) + val (token, root, _) = + submitAndWaitForBatch(_backendContext, fs, collection.map(_._1).toArray, stageIdentifier, f) log.info(s"parallelizeAndComputeWithIndex: $token: reading results") val startTime = System.nanoTime() - val r@(_, results) = runAllKeepFirstError(executor) { + val r @ (_, results) = runAllKeepFirstError(executor) { collection.zipWithIndex.map { case ((_, i), jobIndex) => (() => 
readResult(root, jobIndex), i) } @@ -343,27 +344,32 @@ class ServiceBackend( val x = LoweringPipeline.darrayLowerer(true)(DArrayLowering.All).apply(ctx, _x) .asInstanceOf[IR] if (x.typ == TVoid) { - val (_, f) = Compile[AsmFunction1RegionUnit](ctx, + val (_, f) = Compile[AsmFunction1RegionUnit]( + ctx, FastSeq(), - FastSeq[TypeInfo[_]](classInfo[Region]), UnitInfo, + FastSeq[TypeInfo[_]](classInfo[Region]), + UnitInfo, x, - optimize = true) + optimize = true, + ) ctx.scopedExecution((hcl, fs, htc, r) => f(hcl, fs, htc, r).apply(r)) Array() } else { - val (Some(PTypeReferenceSingleCodeType(pt: PTuple)), f) = Compile[AsmFunction1RegionLong](ctx, + val (Some(PTypeReferenceSingleCodeType(pt: PTuple)), f) = Compile[AsmFunction1RegionLong]( + ctx, FastSeq(), - FastSeq(classInfo[Region]), LongInfo, + FastSeq(classInfo[Region]), + LongInfo, MakeTuple.ordered(FastSeq(x)), - optimize = true) - val retPType = pt.asInstanceOf[PBaseStruct] + optimize = true, + ) val elementType = pt.fields(0).typ val off = ctx.scopedExecution((hcl, fs, htc, r) => f(hcl, fs, htc, r).apply(r)) val codec = TypedCodecSpec( EType.fromPythonTypeEncoding(elementType.virtualType), elementType.virtualType, - BufferSpec.parseOrDefault(bufferSpecString) + BufferSpec.parseOrDefault(bufferSpecString), ) assert(pt.isFieldDefined(off, 0)) codec.encode(ctx, elementType, pt.loadField(off, 0)) @@ -374,9 +380,9 @@ class ServiceBackend( ctx: ExecuteContext, code: String, token: String, - bufferSpecString: String + bufferSpecString: String, ): Array[Byte] = { - log.info(s"executing: ${token} ${ctx.fs.getConfiguration()}") + log.info(s"executing: $token ${ctx.fs.getConfiguration()}") val ir = IRParser.parse_value_ir(ctx, code) ctx.irMetadata = ctx.irMetadata.copy(semhash = SemanticHash(ctx)(ir)) execute(ctx, ir, bufferSpecString) @@ -387,10 +393,12 @@ class ServiceBackend( inputStage: TableStage, sortFields: IndexedSeq[SortField], rt: RTable, - nPartitions: Option[Int] - ): TableReader = LowerDistributedSort.distributedSort(ctx, inputStage, sortFields, rt, nPartitions) + nPartitions: Option[Int], + ): TableReader = + LowerDistributedSort.distributedSort(ctx, inputStage, sortFields, rt, nPartitions) - def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String): Unit = ??? + def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String) + : Unit = ??? def unpersist(backendContext: BackendContext, id: String): Unit = ??? @@ -398,14 +406,11 @@ class ServiceBackend( def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = ??? 
- def tableToTableStage(ctx: ExecuteContext, - inputIR: TableIR, - analyses: LoweringAnalyses - ): TableStage = { + def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) + : TableStage = LowerTableIR.applyTable(inputIR, DArrayLowering.All, ctx, analyses) - } - def withExecuteContext[T](methodName: String): (ExecuteContext => T) => T = { f => + override def withExecuteContext[T](methodName: String)(f: ExecuteContext => T): T = ExecutionTimer.logTime(methodName) { timer => ExecuteContext.scoped( tmpdir, @@ -415,24 +420,20 @@ class ServiceBackend( timer, null, theHailClassLoader, - references, flags, - serviceBackendContext + serviceBackendContext, )(f) } - } - def addLiftover(name: String, chainFile: String, destRGName: String): Unit = { + def addLiftover(name: String, chainFile: String, destRGName: String): Unit = withExecuteContext("addLiftover") { ctx => references(name).addLiftover(ctx, chainFile, destRGName) } - } - def addSequence(name: String, fastaFile: String, indexFile: String): Unit = { + def addSequence(name: String, fastaFile: String, indexFile: String): Unit = withExecuteContext("addSequence") { ctx => references(name).addSequence(ctx, fastaFile, indexFile) } - } } class EndOfInputException extends RuntimeException @@ -445,7 +446,7 @@ object ServiceBackendAPI { assert(argv.length == 7, argv.toFastSeq) val scratchDir = argv(0) - val logFile = argv(1) + // val logFile = argv(1) val jarLocation = argv(2) val kind = argv(3) assert(kind == Main.DRIVER) @@ -455,14 +456,16 @@ object ServiceBackendAPI { val fs = FS.cloudSpecificFS(s"$scratchDir/secrets/gsa-key/key.json", None) val deployConfig = DeployConfig.fromConfigFile( - s"$scratchDir/secrets/deploy-config/deploy-config.json") + s"$scratchDir/secrets/deploy-config/deploy-config.json" + ) DeployConfig.set(deployConfig) sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir(_)) val batchClient = new BatchClient(s"$scratchDir/secrets/gsa-key/key.json") log.info("BatchClient allocated.") - var batchId = BatchConfig.fromConfigFile(s"$scratchDir/batch-config/batch-config.json").map(_.batchId) + val batchId = + BatchConfig.fromConfigFile(s"$scratchDir/batch-config/batch-config.json").map(_.batchId) log.info("BatchConfig parsed.") implicit val formats: Formats = DefaultFormats @@ -471,8 +474,13 @@ object ServiceBackendAPI { // FIXME: when can the classloader be shared? (optimizer benefits!) 
val backend = ServiceBackend( - jarLocation, name, new HailClassLoader(getClass().getClassLoader()), batchClient, batchId, scratchDir, - rpcConfig + jarLocation, + name, + new HailClassLoader(getClass().getClassLoader()), + batchClient, + batchId, + scratchDir, + rpcConfig, ) log.info("ServiceBackend allocated.") if (HailContext.isInitialized) { @@ -495,9 +503,8 @@ private class HailSocketAPIOutputStream( private[this] var closed: Boolean = false private[this] val dummy = new Array[Byte](8) - def writeBool(b: Boolean): Unit = { + def writeBool(b: Boolean): Unit = out.write(if (b) 1 else 0) - } def writeInt(v: Int): Unit = { Memory.storeInt(dummy, 0, v) @@ -516,12 +523,11 @@ private class HailSocketAPIOutputStream( def writeString(s: String): Unit = writeBytes(s.getBytes(StandardCharsets.UTF_8)) - def close(): Unit = { + def close(): Unit = if (!closed) { out.close() closed = true } - } } case class CloudfuseConfig(bucket: String, mount_path: String, read_only: Boolean) @@ -622,14 +628,14 @@ class ServiceBackendAPI( fastaPayload.x_contigs, fastaPayload.y_contigs, fastaPayload.mt_contigs, - fastaPayload.par + fastaPayload.par, ) } } private[this] def withIRFunctionsReadFromInput( serializedFunctions: Array[SerializedIRFunction], - ctx: ExecuteContext + ctx: ExecuteContext, )( body: () => Array[Byte] ): Array[Byte] = { @@ -642,13 +648,12 @@ class ServiceBackendAPI( func.value_parameter_names, func.value_parameter_types, func.return_type, - func.rendered_body + func.rendered_body, ) } body() - } finally { + } finally IRFunctionRegistry.clearUserFunctions() - } } def executeOneCommand(action: Int, payload: JValue): Unit = { @@ -672,7 +677,9 @@ class ServiceBackendAPI( output.writeInt(exc.errorId) } } - log.error("A worker failed. The exception was written for Python but we will also throw an exception to fail this driver job.") + log.error( + "A worker failed. The exception was written for Python but we will also throw an exception to fail this driver job." + ) throw exc case t: Throwable => val (shortMessage, expandedMessage, errorId) = handleForPython(t) @@ -685,7 +692,9 @@ class ServiceBackendAPI( output.writeInt(errorId) } } - log.error("An exception occurred in the driver. The exception was written for Python but we will re-throw to fail this driver job.") + log.error( + "An exception occurred in the driver. The exception was written for Python but we will re-throw to fail this driver job." 
+ ) throw t } } diff --git a/hail/src/main/scala/is/hail/backend/service/Worker.scala b/hail/src/main/scala/is/hail/backend/service/Worker.scala index 53619d5a956..ad0b2498954 100644 --- a/hail/src/main/scala/is/hail/backend/service/Worker.scala +++ b/hail/src/main/scala/is/hail/backend/service/Worker.scala @@ -1,24 +1,24 @@ package is.hail.backend.service -import java.util -import java.io._ -import java.nio.charset._ -import java.util.{concurrent => javaConcurrent} - -import is.hail.asm4s._ import is.hail.{HAIL_REVISION, HailContext} +import is.hail.asm4s._ import is.hail.backend.HailTaskContext import is.hail.io.fs._ import is.hail.services._ import is.hail.utils._ -import org.apache.commons.io.IOUtils -import org.apache.log4j.Logger import scala.collection.mutable -import scala.concurrent.duration.{Duration, MILLISECONDS} -import scala.concurrent.{Future, Await, ExecutionContext} +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration.Duration import scala.util.control.NonFatal +import java.io._ +import java.nio.charset._ +import java.util +import java.util.{concurrent => javaConcurrent} + +import org.apache.log4j.Logger + class ServiceTaskContext(val partitionId: Int) extends HailTaskContext { override def stageId(): Int = 0 @@ -33,9 +33,9 @@ class WorkerTimer() { import WorkerTimer._ var startTimes: mutable.Map[String, Long] = mutable.Map() - def start(label: String): Unit = { + + def start(label: String): Unit = startTimes.put(label, System.nanoTime()) - } def end(label: String): Unit = { val endTime = System.nanoTime() @@ -54,7 +54,7 @@ class WorkerTimer() { // For more context, see: https://github.com/scala/bug/issues/9237#issuecomment-292436652 object ExplicitClassLoaderInputStream { val primClasses: util.HashMap[String, Class[_]] = { - val m = new util.HashMap[String, Class[_]](8, 1.0F) + val m = new util.HashMap[String, Class[_]](8, 1.0f) m.put("boolean", Boolean.getClass) m.put("byte", Byte.getClass) m.put("char", Char.getClass) @@ -67,7 +67,9 @@ object ExplicitClassLoaderInputStream { m } } -class ExplicitClassLoaderInputStream(is: InputStream, cl: ClassLoader) extends ObjectInputStream(is) { + +class ExplicitClassLoaderInputStream(is: InputStream, cl: ClassLoader) + extends ObjectInputStream(is) { override def resolveClass(desc: ObjectStreamClass): Class[_] = { val name = desc.getName @@ -84,8 +86,10 @@ class ExplicitClassLoaderInputStream(is: InputStream, cl: ClassLoader) extends O object Worker { private[this] val log = Logger.getLogger(getClass.getName()) private[this] val myRevision = HAIL_REVISION - private[this] implicit val ec = ExecutionContext.fromExecutorService( - javaConcurrent.Executors.newCachedThreadPool()) + + implicit private[this] val ec = ExecutionContext.fromExecutorService( + javaConcurrent.Executors.newCachedThreadPool() + ) private[this] def writeString(out: DataOutputStream, s: String): Unit = { val bytes = s.getBytes(StandardCharsets.UTF_8) @@ -97,11 +101,11 @@ object Worker { val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) if (argv.length != 7) { - throw new IllegalArgumentException(s"expected seven arguments, not: ${ argv.length }") + throw new IllegalArgumentException(s"expected seven arguments, not: ${argv.length}") } val scratchDir = argv(0) - val logFile = argv(1) - var jarLocation = argv(2) + // val logFile = argv(1) + // var jarLocation = argv(2) val kind = argv(3) assert(kind == Main.WORKER) val root = argv(4) @@ -110,7 +114,8 @@ object Worker { val timer = new WorkerTimer() val 
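ExplicitClassLoaderInputStream, reformatted above, exists so that deserialization can resolve classes against a caller-supplied loader (the HailClassLoader) instead of the default loader ObjectInputStream would otherwise pick. A generic sketch of the same idea, not the Hail implementation; the fallback to super.resolveClass also covers primitive type names, which is the scala/bug#9237 wrinkle the primClasses map works around:

import java.io.{InputStream, ObjectInputStream, ObjectStreamClass}

class LoaderAwareObjectInputStream(in: InputStream, cl: ClassLoader)
    extends ObjectInputStream(in) {
  override def resolveClass(desc: ObjectStreamClass): Class[_] =
    try Class.forName(desc.getName, false, cl) // prefer the explicit loader
    catch { case _: ClassNotFoundException => super.resolveClass(desc) }
}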
deployConfig = DeployConfig.fromConfigFile( - s"$scratchDir/secrets/deploy-config/deploy-config.json") + s"$scratchDir/secrets/deploy-config/deploy-config.json" + ) DeployConfig.set(deployConfig) sys.env.get("HAIL_SSL_CONFIG_DIR").foreach(tls.setSSLConfigFromDir(_)) @@ -122,18 +127,18 @@ object Worker { timer.start("readInputs") val fs = FS.cloudSpecificFS(s"$scratchDir/secrets/gsa-key/key.json", None) - def open(x: String): SeekableDataInputStream = { + def open(x: String): SeekableDataInputStream = fs.openNoCompression(x) - } - def write(x: String)(writer: PositionedDataOutputStream => Unit): Unit = { + def write(x: String)(writer: PositionedDataOutputStream => Unit): Unit = fs.writePDOS(x)(writer) - } val fFuture = Future { retryTransientErrors { using(new ExplicitClassLoaderInputStream(open(s"$root/f"), theHailClassLoader)) { is => - is.readObject().asInstanceOf[(Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte]] + is.readObject().asInstanceOf[(Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[ + Byte + ]] } } } @@ -159,19 +164,40 @@ object Worker { timer.start("executeFunction") if (HailContext.isInitialized) { - HailContext.get.backend = new ServiceBackend(null, null, new HailClassLoader(getClass().getClassLoader()), null, None, null, null, null, null) + HailContext.get.backend = new ServiceBackend( + null, + null, + new HailClassLoader(getClass().getClassLoader()), + null, + None, + null, + null, + null, + null, + ) } else { HailContext( // FIXME: workers should not have backends, but some things do need hail contexts - new ServiceBackend(null, null, new HailClassLoader(getClass().getClassLoader()), null, None, null, null, null, null)) + new ServiceBackend( + null, + null, + new HailClassLoader(getClass().getClassLoader()), + null, + None, + null, + null, + null, + null, + ) + ) } val result = using(new ServiceTaskContext(i)) { htc => - try { + try retryTransientErrors { Right(f(context, htc, theHailClassLoader, fs)) } - } catch { + catch { case NonFatal(err) => Left(err) } } @@ -186,7 +212,8 @@ object Worker { dos.writeBoolean(true) dos.write(bytes) case Left(throwableWhileExecutingUserCode) => - val (shortMessage, expandedMessage, errorId) = handleForPython(throwableWhileExecutingUserCode) + val (shortMessage, expandedMessage, errorId) = + handleForPython(throwableWhileExecutingUserCode) dos.writeBoolean(false) writeString(dos, shortMessage) writeString(dos, expandedMessage) diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala index fdc1b3b64b2..6d1a9201e9c 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala @@ -1,27 +1,36 @@ package is.hail.backend.spark +import is.hail.{HailContext, HailFeatureFlags} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend._ +import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex, Validate} +import is.hail.expr.ir.{IRParser, _} import is.hail.expr.ir.IRParser.parseType import is.hail.expr.ir.analyses.SemanticHash import is.hail.expr.ir.lowering._ -import is.hail.expr.ir.{IRParser, _} -import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex, Validate} -import is.hail.io.fs._ -import is.hail.io.plink.LoadPlink import is.hail.io.{BufferSpec, TypedCodecSpec} +import is.hail.io.fs._ import is.hail.linalg.{BlockMatrix, RowMatrix} import is.hail.rvd.RVD import is.hail.stats.LinearMixedModel import is.hail.types._ import 
is.hail.types.encoded.EType -import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.physical.{PStruct, PTuple} +import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual.{TArray, TInterval, TStruct, TVoid} import is.hail.utils._ import is.hail.variant.ReferenceGenome -import is.hail.{HailContext, HailFeatureFlags} + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag +import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal + +import java.io.{Closeable, PrintWriter} + import org.apache.hadoop import org.apache.hadoop.conf.Configuration import org.apache.spark._ @@ -29,17 +38,8 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SparkSession} import org.json4s +import org.json4s.DefaultFormats import org.json4s.jackson.{JsonMethods, Serialization} -import org.json4s.{DefaultFormats, Formats} - -import com.sun.net.httpserver.{HttpExchange} -import java.io.{Closeable, PrintWriter, OutputStream} -import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag -import scala.util.{Failure, Success, Try} - class SparkBroadcastValue[T](bc: Broadcast[T]) extends BroadcastValue[T] with Serializable { def value: T = bc.value @@ -48,18 +48,17 @@ class SparkBroadcastValue[T](bc: Broadcast[T]) extends BroadcastValue[T] with Se object SparkTaskContext { def get(): SparkTaskContext = taskContext.get - private[this] val taskContext: ThreadLocal[SparkTaskContext] = new ThreadLocal[SparkTaskContext]() { - override def initialValue(): SparkTaskContext = { - val sparkTC = TaskContext.get() - assert(sparkTC != null, "Spark Task Context was null, maybe this ran on the driver?") - sparkTC.addTaskCompletionListener[Unit] { (_: TaskContext) => - SparkTaskContext.finish() - } + private[this] val taskContext: ThreadLocal[SparkTaskContext] = + new ThreadLocal[SparkTaskContext]() { + override def initialValue(): SparkTaskContext = { + val sparkTC = TaskContext.get() + assert(sparkTC != null, "Spark Task Context was null, maybe this ran on the driver?") + sparkTC.addTaskCompletionListener[Unit]((_: TaskContext) => SparkTaskContext.finish()) - // this must be the only place where SparkTaskContext classes are created - new SparkTaskContext(sparkTC) + // this must be the only place where SparkTaskContext classes are created + new SparkTaskContext(sparkTC) + } } - } def finish(): Unit = { taskContext.get().close() @@ -67,8 +66,7 @@ object SparkTaskContext { } } - -class SparkTaskContext private[spark](ctx: TaskContext) extends HailTaskContext { +class SparkTaskContext private[spark] (ctx: TaskContext) extends HailTaskContext { self => override def stageId(): Int = ctx.stageId() override def partitionId(): Int = ctx.partitionId() @@ -84,15 +82,19 @@ object SparkBackend { def majorMinor(version: String): String = version.split("\\.", 3).take(2).mkString(".") if (majorMinor(jarVersion) != majorMinor(sparkVersion)) - fatal(s"This Hail JAR was compiled for Spark $jarVersion, cannot run with Spark $sparkVersion.\n" + - s" The major and minor versions must agree, though the patch version can differ.") + fatal( + s"This Hail JAR was compiled for Spark $jarVersion, cannot run with Spark $sparkVersion.\n" + + s" The major and minor versions must agree, though the patch version can differ." 
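The compatibility check above compares only the first two dotted components of the version strings, so a patch-level mismatch is tolerated. A small worked example of majorMinor (the version strings are illustrative):

object SparkVersionCheckSketch {
  def majorMinor(version: String): String = version.split("\\.", 3).take(2).mkString(".")

  // majorMinor("3.3.0") == "3.3" and majorMinor("3.3.2") == "3.3": the versions
  //   "agree", so the check at most warns that compatibility is not guaranteed.
  // majorMinor("3.4.1") == "3.4", which differs from "3.3": the check is fatal.
}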
+ ) else if (jarVersion != sparkVersion) - warn(s"This Hail JAR was compiled for Spark $jarVersion, running with Spark $sparkVersion.\n" + - s" Compatibility is not guaranteed.") + warn( + s"This Hail JAR was compiled for Spark $jarVersion, running with Spark $sparkVersion.\n" + + s" Compatibility is not guaranteed." + ) } - def createSparkConf(appName: String, master: String, - local: String, blockSize: Long): SparkConf = { + def createSparkConf(appName: String, master: String, local: String, blockSize: Long) + : SparkConf = { require(blockSize >= 0) checkSparkCompatibility(is.hail.HAIL_SPARK_VERSION, org.apache.spark.SPARK_VERSION) @@ -116,12 +118,16 @@ object SparkBackend { "org.apache.hadoop.io.compress.DefaultCodec," + "is.hail.io.compress.BGzipCodec," + "is.hail.io.compress.BGzipCodecTbi," + - "org.apache.hadoop.io.compress.GzipCodec") + "org.apache.hadoop.io.compress.GzipCodec", + ) conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.kryo.registrator", "is.hail.kryo.HailKryoRegistrator") - conf.set("spark.hadoop.mapreduce.input.fileinputformat.split.minsize", (blockSize * 1024L * 1024L).toString) + conf.set( + "spark.hadoop.mapreduce.input.fileinputformat.split.minsize", + (blockSize * 1024L * 1024L).toString, + ) // load additional Spark properties from HAIL_SPARK_PROPERTIES val hailSparkProperties = System.getenv("HAIL_SPARK_PROPERTIES") @@ -141,10 +147,13 @@ object SparkBackend { conf } - def configureAndCreateSparkContext(appName: String, master: String, - local: String, blockSize: Long): SparkContext = { + def configureAndCreateSparkContext( + appName: String, + master: String, + local: String, + blockSize: Long, + ): SparkContext = new SparkContext(createSparkConf(appName, master, local, blockSize)) - } def checkSparkConfiguration(sc: SparkContext): Unit = { val conf = sc.getConf @@ -155,31 +164,37 @@ object SparkBackend { val kryoSerializer = "org.apache.spark.serializer.KryoSerializer" if (!serializer.contains(kryoSerializer)) problems += s"Invalid configuration property spark.serializer: required $kryoSerializer. " + - s"Found: ${ serializer.getOrElse("empty parameter") }." + s"Found: ${serializer.getOrElse("empty parameter")}." - if (!conf.getOption("spark.kryo.registrator").exists(_.split(",").contains("is.hail.kryo.HailKryoRegistrator"))) + if ( + !conf.getOption("spark.kryo.registrator").exists( + _.split(",").contains("is.hail.kryo.HailKryoRegistrator") + ) + ) problems += s"Invalid config parameter: spark.kryo.registrator must include is.hail.kryo.HailKryoRegistrator." + - s"Found ${ conf.getOption("spark.kryo.registrator").getOrElse("empty parameter.") }" + s"Found ${conf.getOption("spark.kryo.registrator").getOrElse("empty parameter.")}" if (problems.nonEmpty) fatal( s"""Found problems with SparkContext configuration: - | ${ problems.mkString("\n ") }""".stripMargin) + | ${problems.mkString("\n ")}""".stripMargin + ) } def hailCompressionCodecs: Array[String] = Array( "org.apache.hadoop.io.compress.DefaultCodec", "is.hail.io.compress.BGzipCodec", "is.hail.io.compress.BGzipCodecTbi", - "org.apache.hadoop.io.compress.GzipCodec") + "org.apache.hadoop.io.compress.GzipCodec", + ) - /** - * If a SparkBackend has already been initialized, this function returns it regardless of the + /** If a SparkBackend has already been initialized, this function returns it regardless of the * parameters with which it was initialized. * * Otherwise, it initializes and returns a new HailContext. 
*/ - def getOrCreate(sc: SparkContext = null, + def getOrCreate( + sc: SparkContext = null, appName: String = "Hail", master: String = null, local: String = "local[*]", @@ -191,30 +206,39 @@ object SparkBackend { tmpdir: String = "/tmp", localTmpdir: String = "file:///tmp", gcsRequesterPaysProject: String = null, - gcsRequesterPaysBuckets: String = null + gcsRequesterPaysBuckets: String = null, ): SparkBackend = synchronized { if (theSparkBackend == null) - return SparkBackend(sc, appName, master, local, logFile, quiet, append, skipLoggingConfiguration, + return SparkBackend(sc, appName, master, local, logFile, quiet, append, + skipLoggingConfiguration, minBlockSize, tmpdir, localTmpdir, gcsRequesterPaysProject, gcsRequesterPaysBuckets) // there should be only one SparkContext assert(sc == null || (sc eq theSparkBackend.sc)) val initializedMinBlockSize = - theSparkBackend.sc.getConf.getLong("spark.hadoop.mapreduce.input.fileinputformat.split.minsize", 0L) / 1024L / 1024L + theSparkBackend.sc.getConf.getLong( + "spark.hadoop.mapreduce.input.fileinputformat.split.minsize", + 0L, + ) / 1024L / 1024L if (minBlockSize != initializedMinBlockSize) - warn(s"Requested minBlockSize $minBlockSize, but already initialized to $initializedMinBlockSize. Ignoring requested setting.") + warn( + s"Requested minBlockSize $minBlockSize, but already initialized to $initializedMinBlockSize. Ignoring requested setting." + ) if (master != null) { val initializedMaster = theSparkBackend.sc.master if (master != initializedMaster) - warn(s"Requested master $master, but already initialized to $initializedMaster. Ignoring requested setting.") + warn( + s"Requested master $master, but already initialized to $initializedMaster. Ignoring requested setting." + ) } theSparkBackend } - def apply(sc: SparkContext = null, + def apply( + sc: SparkContext = null, appName: String = "Hail", master: String = null, local: String = "local[*]", @@ -226,7 +250,7 @@ object SparkBackend { tmpdir: String, localTmpdir: String, gcsRequesterPaysProject: String = null, - gcsRequesterPaysBuckets: String = null + gcsRequesterPaysBuckets: String = null, ): SparkBackend = synchronized { require(theSparkBackend == null) @@ -246,7 +270,8 @@ object SparkBackend { sc1.uiWebUrl.foreach(ui => info(s"SparkUI: $ui")) - theSparkBackend = new SparkBackend(tmpdir, localTmpdir, sc1, gcsRequesterPaysProject, gcsRequesterPaysBuckets) + theSparkBackend = + new SparkBackend(tmpdir, localTmpdir, sc1, gcsRequesterPaysProject, gcsRequesterPaysBuckets) theSparkBackend.addDefaultReferences() theSparkBackend } @@ -275,11 +300,13 @@ class SparkBackend( val localTmpdir: String, val sc: SparkContext, gcsRequesterPaysProject: String, - gcsRequesterPaysBuckets: String + gcsRequesterPaysBuckets: String, ) extends Backend with Closeable with BackendWithCodeCache { assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null) lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate() - private[this] val theHailClassLoader: HailClassLoader = new HailClassLoader(getClass().getClassLoader()) + + private[this] val theHailClassLoader: HailClassLoader = + new HailClassLoader(getClass().getClassLoader()) override def canExecuteParallelTasksOnDriver: Boolean = false @@ -297,6 +324,7 @@ class SparkBackend( } new HadoopFS(new SerializableHadoopConfiguration(conf)) } + private[this] val longLifeTempFileManager: TempFileManager = new OwningTempFileManager(fs) val bmCache: SparkBlockMatrixCache = SparkBlockMatrixCache() @@ -309,20 +337,23 
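The getOrCreate scaladoc above describes a process-wide singleton: once a SparkBackend exists, later calls return that same instance and only warn when the requested settings differ. A minimal sketch of that contract with a hypothetical stand-in type:

object GetOrCreateSketch {
  final case class Backend(master: String, minBlockSize: Long) // stand-in for SparkBackend
  private var theBackend: Backend = null

  def getOrCreate(master: String = "local[*]", minBlockSize: Long = 1L): Backend =
    synchronized {
      if (theBackend == null)
        theBackend = Backend(master, minBlockSize)
      else {
        if (minBlockSize != theBackend.minBlockSize)
          println(s"Requested minBlockSize $minBlockSize, but already initialized to ${theBackend.minBlockSize}. Ignoring requested setting.")
        if (master != theBackend.master)
          println(s"Requested master $master, but already initialized to ${theBackend.master}. Ignoring requested setting.")
      }
      theBackend
    }
}
// getOrCreate("local[4]") followed by getOrCreate("yarn") returns the same
// Backend("local[4]", 1) instance and logs a warning about the ignored master.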
@@ class SparkBackend( val availableFlags: java.util.ArrayList[String] = flags.available - def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String): Unit = bmCache.persistBlockMatrix(id, value, storageLevel) + def persist(backendContext: BackendContext, id: String, value: BlockMatrix, storageLevel: String) + : Unit = bmCache.persistBlockMatrix(id, value, storageLevel) def unpersist(backendContext: BackendContext, id: String): Unit = unpersist(id) - def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix = bmCache.getPersistedBlockMatrix(id) + def getPersistedBlockMatrix(backendContext: BackendContext, id: String): BlockMatrix = + bmCache.getPersistedBlockMatrix(id) - def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = bmCache.getPersistedBlockMatrixType(id) + def getPersistedBlockMatrixType(backendContext: BackendContext, id: String): BlockMatrixType = + bmCache.getPersistedBlockMatrixType(id) def unpersist(id: String): Unit = bmCache.unpersistBlockMatrix(id) def createExecuteContextForTests( timer: ExecutionTimer, region: Region, - selfContainedExecution: Boolean = true + selfContainedExecution: Boolean = true, ): ExecuteContext = new ExecuteContext( tmpdir, @@ -333,17 +364,16 @@ class SparkBackend( timer, if (selfContainedExecution) null else new NonOwningTempFileManager(longLifeTempFileManager), theHailClassLoader, - this.references, flags, new BackendContext { override val executionCache: ExecutionCache = ExecutionCache.forTesting }, - IrMetadata(None) + IrMetadata(None), ) def withExecuteContext[T](timer: ExecutionTimer, selfContainedExecution: Boolean = true) - : (ExecuteContext => T) => T = + : (ExecuteContext => T) => T = ExecuteContext.scoped( tmpdir, localTmpdir, @@ -352,36 +382,49 @@ class SparkBackend( timer, if (selfContainedExecution) null else new NonOwningTempFileManager(longLifeTempFileManager), theHailClassLoader, - this.references, flags, new BackendContext { override val executionCache: ExecutionCache = ExecutionCache.fromFlags(flags, fs, tmpdir) - } + }, ) - def withExecuteContext[T](methodName: String): (ExecuteContext => T) => T = { f => + override def withExecuteContext[T](methodName: String)(f: ExecuteContext => T): T = ExecutionTimer.logTime(methodName) { timer => - ExecuteContext.scoped(tmpdir, tmpdir, this, fs, timer, null, theHailClassLoader, this.references, flags, new BackendContext { - override val executionCache: ExecutionCache = - ExecutionCache.fromFlags(flags, fs, tmpdir) - })(f) + ExecuteContext.scoped( + tmpdir, + tmpdir, + this, + fs, + timer, + null, + theHailClassLoader, + flags, + new BackendContext { + override val executionCache: ExecutionCache = + ExecutionCache.fromFlags(flags, fs, tmpdir) + }, + )(f) } - } - def broadcast[T : ClassTag](value: T): BroadcastValue[T] = new SparkBroadcastValue[T](sc.broadcast(value)) + def broadcast[T: ClassTag](value: T): BroadcastValue[T] = + new SparkBroadcastValue[T](sc.broadcast(value)) override def parallelizeAndComputeWithIndex( backendContext: BackendContext, fs: FS, collection: Array[Array[Byte]], stageIdentifier: String, - dependency: Option[TableStageDependency] = None + dependency: Option[TableStageDependency] = None, )( f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): Array[Array[Byte]] = { val sparkDeps = dependency.toIndexedSeq - .flatMap(dep => dep.deps.map(rvdDep => new AnonymousDependency(rvdDep.asInstanceOf[RVDDependency].rvd.crdd.rdd))) + .flatMap(dep => + 
dep.deps.map(rvdDep => + new AnonymousDependency(rvdDep.asInstanceOf[RVDDependency].rvd.crdd.rdd) + ) + ) new SparkBackendComputeRDD(sc, collection, f, sparkDeps).collect() } @@ -391,40 +434,50 @@ class SparkBackend( fs: FS, contexts: IndexedSeq[(Array[Byte], Int)], stageIdentifier: String, - dependency: Option[TableStageDependency] = None + dependency: Option[TableStageDependency] = None, )( f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte] ): (Option[Throwable], IndexedSeq[(Array[Byte], Int)]) = { val sparkDeps = - for {rvdDep <- dependency.toIndexedSeq; dep <- rvdDep.deps} - yield new AnonymousDependency(dep.asInstanceOf[RVDDependency].rvd.crdd.rdd) + for { + rvdDep <- dependency.toIndexedSeq + dep <- rvdDep.deps + } yield new AnonymousDependency(dep.asInstanceOf[RVDDependency].rvd.crdd.rdd) val rdd = new RDD[(Try[Array[Byte]], Int)](sc, sparkDeps) { - /* Spark insists that `Partition.index` is indeed the index that partition - * appears in the result of `RDD.getPartitions`. + /* Spark insists that `Partition.index` is indeed the index that partition appears in the + * result of `RDD.getPartitions`. * - * We accept contexts in the form (data, index) and return results in the - * form (result, index). The index is the index of input context in the - * original array of contexts. This function may receive a subset of those - * contexts when retrying queries. We can't use it as the RDD Partition index, - * therefore; instead store it as a "tag" and use it to transform the RDD result. + * We accept contexts in the form (data, index) and return results in the form (result, + * index). The index is the index of input context in the original array of contexts. This + * function may receive a subset of those contexts when retrying queries. We can't use it as + * the RDD Partition index, therefore; instead store it as a "tag" and use it to transform + * the RDD result. * - * See `BackendUtils.collectDArray` for how the index is generated. - */ - case class TaggedRDDPartition(data: Array[Byte], tag: Int, index: Int) - extends Partition + * See `BackendUtils.collectDArray` for how the index is generated. */ + case class TaggedRDDPartition(data: Array[Byte], tag: Int, index: Int) extends Partition override protected def getPartitions: Array[Partition] = - for {((data, index), rddIndex) <- contexts.zipWithIndex.toArray} - yield TaggedRDDPartition(data, index, rddIndex) + for { + ((data, index), rddIndex) <- contexts.zipWithIndex.toArray + } yield TaggedRDDPartition(data, index, rddIndex) - override def compute(partition: Partition, context: TaskContext): Iterator[(Try[Array[Byte]], Int)] = { + override def compute(partition: Partition, context: TaskContext) + : Iterator[(Try[Array[Byte]], Int)] = { val sp = partition.asInstanceOf[TaggedRDDPartition] val fs = new HadoopFS(null) // FIXME: this is broken: the partitionId of SparkTaskContext will be incorrect - val result = Try(f(sp.data, SparkTaskContext.get(), theHailClassLoaderForSparkWorkers, fs)) + val result = + try + Success(f(sp.data, SparkTaskContext.get(), theHailClassLoaderForSparkWorkers, fs)) + catch { + case NonFatal(exc) => + exc.getStackTrace() // Calling getStackTrace appears to ensure the exception is + // serialized with its stack trace. 
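The comment above separates the original context index, carried as a tag, from Spark's Partition.index, which must equal the position within getPartitions. A small worked illustration with made-up values, outside any real RDD:

object TagVersusIndexSketch {
  // on a retry, only contexts #5 and #9 of the original array might be resubmitted
  val contexts: IndexedSeq[(Array[Byte], Int)] =
    IndexedSeq((Array[Byte](1), 5), (Array[Byte](2), 9))

  // Partition.index must be 0 and 1 (positions in getPartitions); the original
  // positions 5 and 9 ride along as tags
  case class TaggedPartition(data: Array[Byte], tag: Int, index: Int)

  val partitions: IndexedSeq[TaggedPartition] =
    for (((data, tag), rddIndex) <- contexts.zipWithIndex)
      yield TaggedPartition(data, tag, rddIndex)
  // partitions has index 0 and 1 but tags 5 and 9; compute emits (result, tag),
  // letting the caller slot results back into the original positions
}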
+ Failure(exc) + } Iterator.single((result, sp.tag)) } } @@ -442,11 +495,13 @@ class SparkBackend( def stop(): Unit = SparkBackend.stop() - def startProgressBar() { + def startProgressBar(): Unit = ProgressBarBuilder.build(sc) - } - private[this] def executionResultToAnnotation(ctx: ExecuteContext, result: Either[Unit, (PTuple, Long)]) = result match { + private[this] def executionResultToAnnotation( + ctx: ExecuteContext, + result: Either[Unit, (PTuple, Long)], + ) = result match { case Left(x) => x case Right((pt, off)) => SafeRow(pt, off).get(0) } @@ -458,7 +513,7 @@ class SparkBackend( optimize: Boolean, lowerTable: Boolean, lowerBM: Boolean, - print: Option[PrintWriter] = None + print: Option[PrintWriter] = None, ): Any = { val l = _jvmLowerAndExecute(ctx, ir0, optimize, lowerTable, lowerBM, print) executionResultToAnnotation(ctx, l) @@ -470,7 +525,7 @@ class SparkBackend( optimize: Boolean, lowerTable: Boolean, lowerBM: Boolean, - print: Option[PrintWriter] = None + print: Option[PrintWriter] = None, ): Either[Unit, (PTuple, Long)] = { val typesToLower: DArrayLowering.Type = (lowerTable, lowerBM) match { case (true, true) => DArrayLowering.All @@ -481,28 +536,39 @@ class SparkBackend( val ir = LoweringPipeline.darrayLowerer(optimize)(typesToLower).apply(ctx, ir0).asInstanceOf[IR] if (!Compilable(ir)) - throw new LowererUnsupportedOperation(s"lowered to uncompilable IR: ${ Pretty(ctx, ir) }") + throw new LowererUnsupportedOperation(s"lowered to uncompilable IR: ${Pretty(ctx, ir)}") val res = ir.typ match { case TVoid => val (_, f) = ctx.timer.time("Compile") { - Compile[AsmFunction1RegionUnit](ctx, + Compile[AsmFunction1RegionUnit]( + ctx, FastSeq(), - FastSeq(classInfo[Region]), UnitInfo, + FastSeq(classInfo[Region]), + UnitInfo, ir, - print = print) + print = print, + ) } - ctx.timer.time("Run")(Left(ctx.scopedExecution((hcl, fs, htc, r) => f(hcl, fs, htc, r).apply(r)))) + ctx.timer.time("Run")(Left(ctx.scopedExecution((hcl, fs, htc, r) => + f(hcl, fs, htc, r).apply(r) + ))) case _ => val (Some(PTypeReferenceSingleCodeType(pt: PTuple)), f) = ctx.timer.time("Compile") { - Compile[AsmFunction1RegionLong](ctx, + Compile[AsmFunction1RegionLong]( + ctx, FastSeq(), - FastSeq(classInfo[Region]), LongInfo, + FastSeq(classInfo[Region]), + LongInfo, MakeTuple.ordered(FastSeq(ir)), - print = print) + print = print, + ) } - ctx.timer.time("Run")(Right((pt, ctx.scopedExecution((hcl, fs, htc, r) => f(hcl, fs, htc, r).apply(r))))) + ctx.timer.time("Run")(Right(( + pt, + ctx.scopedExecution((hcl, fs, htc, r) => f(hcl, fs, htc, r).apply(r)), + ))) } res @@ -511,14 +577,16 @@ class SparkBackend( def execute(timer: ExecutionTimer, ir: IR, optimize: Boolean): Any = withExecuteContext(timer) { ctx => val queryID = Backend.nextID() - log.info(s"starting execution of query $queryID of initial size ${ IRSize(ir) }") + log.info(s"starting execution of query $queryID of initial size ${IRSize(ir)}") val l = _execute(ctx, ir, optimize) - val javaObjResult = ctx.timer.time("convertRegionValueToAnnotation")(executionResultToAnnotation(ctx, l)) + val javaObjResult = + ctx.timer.time("convertRegionValueToAnnotation")(executionResultToAnnotation(ctx, l)) log.info(s"finished execution of query $queryID") javaObjResult } - private[this] def _execute(ctx: ExecuteContext, ir: IR, optimize: Boolean): Either[Unit, (PTuple, Long)] = { + private[this] def _execute(ctx: ExecuteContext, ir: IR, optimize: Boolean) + : Either[Unit, (PTuple, Long)] = { TypeCheck(ctx, ir) Validate(ir) ctx.irMetadata = ctx.irMetadata.copy(semhash 
= SemanticHash(ctx)(ir)) @@ -540,11 +608,12 @@ class SparkBackend( val t = ir.typ assert(t.isRealizable) val queryID = Backend.nextID() - log.info(s"starting execution of query $queryID} of initial size ${ IRSize(ir) }") + log.info(s"starting execution of query $queryID} of initial size ${IRSize(ir)}") val retVal = _execute(ctx, ir, true) val literalIR = retVal match { - case Left(x) => throw new HailException("Can't create literal") - case Right((pt, addr)) => GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) + case Left(_) => throw new HailException("Can't create literal") + case Right((pt, addr)) => + GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0) } log.info(s"finished execution of query $queryID") addJavaIR(literalIR) @@ -552,26 +621,38 @@ class SparkBackend( } } - override def execute(ir: String, timed: Boolean)(consume: (ExecuteContext, Either[Unit, (PTuple, Long)], String) => Unit): Unit = { + override def execute( + ir: String, + timed: Boolean, + )( + consume: (ExecuteContext, Either[Unit, (PTuple, Long)], String) => Unit + ): Unit = { withExecuteContext("SparkBackend.execute") { ctx => val res = ctx.timer.time("execute") { - val irData = IRParser.parse_value_ir(ir, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) + val irData = + IRParser.parse_value_ir(ir, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) val queryID = Backend.nextID() - log.info(s"starting execution of query $queryID of initial size ${ IRSize(irData) }") + log.info(s"starting execution of query $queryID of initial size ${IRSize(irData)}") _execute(ctx, irData, true) } ctx.timer.finish() - val timings = if (timed) Serialization.write(Map("timings" -> ctx.timer.toMap))(new DefaultFormats {}) else "" + val timings = if (timed) + Serialization.write(Map("timings" -> ctx.timer.toMap))(new DefaultFormats {}) + else "" consume(ctx, res, timings) } } - def encodeToBytes(ctx: ExecuteContext, t: PTuple, off: Long, bufferSpecString: String): Array[Byte] = { + def encodeToBytes(ctx: ExecuteContext, t: PTuple, off: Long, bufferSpecString: String) + : Array[Byte] = { val bs = BufferSpec.parseOrDefault(bufferSpecString) assert(t.size == 1) val elementType = t.fields(0).typ val codec = TypedCodecSpec( - EType.fromPythonTypeEncoding(elementType.virtualType), elementType.virtualType, bs) + EType.fromPythonTypeEncoding(elementType.virtualType), + elementType.virtualType, + bs, + ) assert(t.isFieldDefined(off, 0)) codec.encode(ctx, elementType, t.loadField(off, 0)) } @@ -579,9 +660,19 @@ class SparkBackend( def pyFromDF(df: DataFrame, jKey: java.util.List[String]): (Int, String) = { ExecutionTimer.logTime("SparkBackend.pyFromDF") { timer => val key = jKey.asScala.toArray.toFastSeq - val signature = SparkAnnotationImpex.importType(df.schema).setRequired(true).asInstanceOf[PStruct] + val signature = + SparkAnnotationImpex.importType(df.schema).setRequired(true).asInstanceOf[PStruct] withExecuteContext(timer, selfContainedExecution = false) { ctx => - val tir = TableLiteral(TableValue(ctx, signature.virtualType.asInstanceOf[TStruct], key, df.rdd, Some(signature)), ctx.theHailClassLoader) + val tir = TableLiteral( + TableValue( + ctx, + signature.virtualType.asInstanceOf[TStruct], + key, + df.rdd, + Some(signature), + ), + ctx.theHailClassLoader, + ) val id = addJavaIR(tir) (id, JsonMethods.compact(tir.typ.toJSON)) } @@ -603,11 +694,14 @@ class SparkBackend( case json4s.JObject(values) => values.toMap } - val paths = kvs("paths").asInstanceOf[json4s.JArray].arr.toArray.map { case 
json4s.JString(s) => s } + val paths = kvs("paths").asInstanceOf[json4s.JArray].arr.toArray.map { case json4s.JString(s) => + s + } val intervalPointType = parseType(kvs("intervalPointType").asInstanceOf[json4s.JString].s) - val intervalObjects = JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType))) - .asInstanceOf[IndexedSeq[Interval]] + val intervalObjects = + JSONAnnotationImpex.importAnnotation(kvs("intervals"), TArray(TInterval(intervalPointType))) + .asInstanceOf[IndexedSeq[Interval]] val opts = NativeReaderOptions(intervalObjects, intervalPointType, filterIntervals = false) val matrixReaders: IndexedSeq[MatrixIR] = paths.map { p => @@ -622,39 +716,57 @@ class SparkBackend( def pyAddReference(jsonConfig: String): Unit = addReference(ReferenceGenome.fromJSON(jsonConfig)) def pyRemoveReference(name: String): Unit = removeReference(name) - def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = { + def pyAddLiftover(name: String, chainFile: String, destRGName: String): Unit = ExecutionTimer.logTime("SparkBackend.pyReferenceAddLiftover") { timer => - withExecuteContext(timer) { ctx => - references(name).addLiftover(ctx, chainFile, destRGName) - } + withExecuteContext(timer)(ctx => references(name).addLiftover(ctx, chainFile, destRGName)) } - } - def pyRemoveLiftover(name: String, destRGName: String) = references(name).removeLiftover(destRGName) - def pyFromFASTAFile(name: String, fastaFile: String, indexFile: String, - xContigs: java.util.List[String], yContigs: java.util.List[String], mtContigs: java.util.List[String], - parInput: java.util.List[String]): String = { + def pyRemoveLiftover(name: String, destRGName: String) = + references(name).removeLiftover(destRGName) + + def pyFromFASTAFile( + name: String, + fastaFile: String, + indexFile: String, + xContigs: java.util.List[String], + yContigs: java.util.List[String], + mtContigs: java.util.List[String], + parInput: java.util.List[String], + ): String = { ExecutionTimer.logTime("SparkBackend.pyFromFASTAFile") { timer => withExecuteContext(timer) { ctx => - val rg = ReferenceGenome.fromFASTAFile(ctx, name, fastaFile, indexFile, - xContigs.asScala.toArray, yContigs.asScala.toArray, mtContigs.asScala.toArray, parInput.asScala.toArray) + val rg = ReferenceGenome.fromFASTAFile( + ctx, + name, + fastaFile, + indexFile, + xContigs.asScala.toArray, + yContigs.asScala.toArray, + mtContigs.asScala.toArray, + parInput.asScala.toArray, + ) rg.toJSONString } } } - def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = { + def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit = ExecutionTimer.logTime("SparkBackend.pyAddSequence") { timer => - withExecuteContext(timer) { ctx => - references(name).addSequence(ctx, fastaFile, indexFile) - } + withExecuteContext(timer)(ctx => references(name).addSequence(ctx, fastaFile, indexFile)) } - } + def pyRemoveSequence(name: String) = references(name).removeSequence() def pyExportBlockMatrix( - pathIn: String, pathOut: String, delimiter: String, header: String, addIndex: Boolean, exportType: String, - partitionSize: java.lang.Integer, entries: String): Unit = { + pathIn: String, + pathOut: String, + delimiter: String, + header: String, + addIndex: Boolean, + exportType: String, + partitionSize: java.lang.Integer, + entries: String, + ): Unit = { ExecutionTimer.logTime("SparkBackend.pyExportBlockMatrix") { timer => withExecuteContext(timer) { ctx => val rm = RowMatrix.readBlockMatrix(fs, pathIn, 
partitionSize) @@ -664,62 +776,75 @@ class SparkBackend( case "lower" => rm.exportLowerTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) case "strict_lower" => - rm.exportStrictLowerTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + rm.exportStrictLowerTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) case "upper" => rm.exportUpperTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) case "strict_upper" => - rm.exportStrictUpperTriangle(ctx, pathOut, delimiter, Option(header), addIndex, exportType) + rm.exportStrictUpperTriangle( + ctx, + pathOut, + delimiter, + Option(header), + addIndex, + exportType, + ) } } } } - def pyFitLinearMixedModel(lmm: LinearMixedModel, pa_t: RowMatrix, a_t: RowMatrix): TableIR = { + def pyFitLinearMixedModel(lmm: LinearMixedModel, pa_t: RowMatrix, a_t: RowMatrix): TableIR = ExecutionTimer.logTime("SparkBackend.pyAddSequence") { timer => withExecuteContext(timer, selfContainedExecution = false) { ctx => lmm.fit(ctx, pa_t, Option(a_t)) } } - } - def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = { + def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR = ExecutionTimer.logTime("SparkBackend.parse_value_ir") { timer => withExecuteContext(timer) { ctx => - IRParser.parse_value_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap), BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*)) + IRParser.parse_value_ir( + s, + IRParserEnvironment(ctx, irMap = persistedIR.toMap), + BindingEnv.eval(refMap.asScala.toMap.mapValues(IRParser.parseType).toSeq: _*), + ) } } - } - def parse_table_ir(s: String): TableIR = { + def parse_table_ir(s: String): TableIR = ExecutionTimer.logTime("SparkBackend.parse_table_ir") { timer => withExecuteContext(timer, selfContainedExecution = false) { ctx => IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) } } - } - def parse_matrix_ir(s: String): MatrixIR = { + def parse_matrix_ir(s: String): MatrixIR = ExecutionTimer.logTime("SparkBackend.parse_matrix_ir") { timer => withExecuteContext(timer, selfContainedExecution = false) { ctx => IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) } } - } - def parse_blockmatrix_ir(s: String): BlockMatrixIR = { + def parse_blockmatrix_ir(s: String): BlockMatrixIR = ExecutionTimer.logTime("SparkBackend.parse_blockmatrix_ir") { timer => withExecuteContext(timer, selfContainedExecution = false) { ctx => IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap)) } } - } override def lowerDistributedSort( ctx: ExecuteContext, stage: TableStage, sortFields: IndexedSeq[SortField], rt: RTable, - nPartitions: Option[Int] + nPartitions: Option[Int], ): TableReader = { if (getFlag("use_new_shuffle") != null) return LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) @@ -744,19 +869,19 @@ class SparkBackend( val act = implicitly[ClassTag[Annotation]] val codec = TypedCodecSpec(rvd.rowPType, BufferSpec.wireSpec) - val rdd = rvd.keyedEncodedRDD(ctx, codec, sortFields.map(_.field)).sortBy(_._1, numPartitions = nPartitions.getOrElse(rvd.getNumPartitions))(ord, act) + val rdd = rvd.keyedEncodedRDD(ctx, codec, sortFields.map(_.field)).sortBy( + _._1, + numPartitions = nPartitions.getOrElse(rvd.getNumPartitions), + )(ord, act) val (rowPType: PStruct, orderedCRDD) = codec.decodeRDD(ctx, rowType, rdd.map(_._2)) RVDTableReader(RVD.unkeyed(rowPType, 
orderedCRDD), globalsLit, rt) } - def close(): Unit = { + def close(): Unit = longLifeTempFileManager.cleanup() - } - def tableToTableStage(ctx: ExecuteContext, - inputIR: TableIR, - analyses: LoweringAnalyses - ): TableStage = { + def tableToTableStage(ctx: ExecuteContext, inputIR: TableIR, analyses: LoweringAnalyses) + : TableStage = { CanLowerEfficiently(ctx, inputIR) match { case Some(failReason) => log.info(s"SparkBackend: could not lower IR to table stage: $failReason") @@ -773,12 +898,11 @@ class SparkBackendComputeRDD( sc: SparkContext, @transient private val collection: Array[Array[Byte]], f: (Array[Byte], HailTaskContext, HailClassLoader, FS) => Array[Byte], - deps: Seq[Dependency[_]] + deps: Seq[Dependency[_]], ) extends RDD[Array[Byte]](sc, deps) { - override def getPartitions: Array[Partition] = { + override def getPartitions: Array[Partition] = Array.tabulate(collection.length)(i => SparkBackendComputeRDDPartition(collection(i), i)) - } override def compute(partition: Partition, context: TaskContext): Iterator[Array[Byte]] = { val sp = partition.asInstanceOf[SparkBackendComputeRDDPartition] diff --git a/hail/src/main/scala/is/hail/backend/spark/SparkBlockMatrixCache.scala b/hail/src/main/scala/is/hail/backend/spark/SparkBlockMatrixCache.scala index c7593172408..8a0830d43e7 100644 --- a/hail/src/main/scala/is/hail/backend/spark/SparkBlockMatrixCache.scala +++ b/hail/src/main/scala/is/hail/backend/spark/SparkBlockMatrixCache.scala @@ -13,8 +13,7 @@ case class SparkBlockMatrixCache() { blockmatrices.update(id, value.persist(storageLevel)) def getPersistedBlockMatrix(id: String): BlockMatrix = - blockmatrices.getOrElse(id, - fatal(s"Persisted BlockMatrix with id ${ id } does not exist.")) + blockmatrices.getOrElse(id, fatal(s"Persisted BlockMatrix with id $id does not exist.")) def getPersistedBlockMatrixType(id: String): BlockMatrixType = BlockMatrixType.fromBlockMatrix(getPersistedBlockMatrix(id)) diff --git a/hail/src/main/scala/is/hail/check/Arbitrary.scala b/hail/src/main/scala/is/hail/check/Arbitrary.scala index a7312e39434..8d7b8ce68db 100644 --- a/hail/src/main/scala/is/hail/check/Arbitrary.scala +++ b/hail/src/main/scala/is/hail/check/Arbitrary.scala @@ -8,40 +8,80 @@ object Arbitrary { new Arbitrary(arbitrary) implicit def arbBoolean: Arbitrary[Boolean] = new Arbitrary( - Gen.oneOf(true, false)) + Gen.oneOf(true, false) + ) - implicit def arbByte: Arbitrary[Byte] = new Arbitrary(Gen.oneOfGen(Gen.oneOf(Byte.MinValue, -1, 0, 1, Byte.MaxValue), - Gen { p => p.rng.getRandomGenerator.nextInt().toByte })) + implicit def arbByte: Arbitrary[Byte] = new Arbitrary(Gen.oneOfGen( + Gen.oneOf(Byte.MinValue, -1, 0, 1, Byte.MaxValue), + Gen(p => p.rng.getRandomGenerator.nextInt().toByte), + )) implicit def arbInt: Arbitrary[Int] = new Arbitrary( - Gen.oneOfGen(Gen.oneOf(Int.MinValue, -1, 0, 1, Int.MaxValue), + Gen.oneOfGen( + Gen.oneOf(Int.MinValue, -1, 0, 1, Int.MaxValue), Gen.choose(-100, 100), - Gen { p => p.rng.getRandomGenerator.nextInt() })) + Gen(p => p.rng.getRandomGenerator.nextInt()), + ) + ) implicit def arbLong: Arbitrary[Long] = new Arbitrary( - Gen.oneOfGen(Gen.oneOf(Long.MinValue, -1L, 0L, 1L, Long.MaxValue), + Gen.oneOfGen( + Gen.oneOf(Long.MinValue, -1L, 0L, 1L, Long.MaxValue), Gen.choose(-100, 100), - Gen { p => p.rng.getRandomGenerator.nextLong() })) + Gen(p => p.rng.getRandomGenerator.nextLong()), + ) + ) implicit def arbFloat: Arbitrary[Float] = new Arbitrary( - Gen.oneOfGen(Gen.oneOf(Float.MinValue, -1.0f, -Float.MinPositiveValue, 0.0f, Float.MinPositiveValue, 
1.0f, Float.MaxValue), + Gen.oneOfGen( + Gen.oneOf( + Float.MinValue, + -1.0f, + -Float.MinPositiveValue, + 0.0f, + Float.MinPositiveValue, + 1.0f, + Float.MaxValue, + ), Gen.choose(-100.0f, 100.0f), - Gen { p => p.rng.nextUniform(Float.MinValue, Float.MaxValue, true).toFloat })) + Gen(p => p.rng.nextUniform(Float.MinValue, Float.MaxValue, true).toFloat), + ) + ) implicit def arbDouble: Arbitrary[Double] = new Arbitrary( - Gen.oneOfGen(Gen.oneOf(Double.MinValue, -1.0, -Double.MinPositiveValue, 0.0, Double.MinPositiveValue, 1.0, Double.MaxValue), + Gen.oneOfGen( + Gen.oneOf( + Double.MinValue, + -1.0, + -Double.MinPositiveValue, + 0.0, + Double.MinPositiveValue, + 1.0, + Double.MaxValue, + ), Gen.choose(-100.0, 100.0), - Gen { p => p.rng.nextUniform(Double.MinValue, Double.MaxValue, true) })) + Gen(p => p.rng.nextUniform(Double.MinValue, Double.MaxValue, true)), + ) + ) - implicit def arbString: Arbitrary[String] = new Arbitrary(Gen.frequency((1, Gen.const("")), (10, Gen { (p: Parameters) => - val s = p.rng.getRandomGenerator.nextInt(12) - val b = new StringBuilder() - for (i <- 0 until s) - b += Gen.randomOneOf(p.rng, Gen.printableChars) - b.result() - }))) + implicit def arbString: Arbitrary[String] = new Arbitrary(Gen.frequency( + (1, Gen.const("")), + ( + 10, + Gen { (p: Parameters) => + val s = p.rng.getRandomGenerator.nextInt(12) + val b = new StringBuilder() + for (i <- 0 until s) + b += Gen.randomOneOf(p.rng, Gen.printableChars) + b.result() + }, + ), + )) - implicit def arbBuildableOf[C[_], T](implicit a: Arbitrary[T], cbf: CanBuildFrom[Nothing, T, C[T]]): Arbitrary[C[T]] = + implicit def arbBuildableOf[C[_], T]( + implicit a: Arbitrary[T], + cbf: CanBuildFrom[Nothing, T, C[T]], + ): Arbitrary[C[T]] = Arbitrary(Gen.buildableOf(a.arbitrary)) def arbitrary[T](implicit arb: Arbitrary[T]): Gen[T] = arb.arbitrary diff --git a/hail/src/main/scala/is/hail/check/Gen.scala b/hail/src/main/scala/is/hail/check/Gen.scala index be05d686850..ae157bf802a 100644 --- a/hail/src/main/scala/is/hail/check/Gen.scala +++ b/hail/src/main/scala/is/hail/check/Gen.scala @@ -1,10 +1,7 @@ package is.hail.check -import breeze.linalg.DenseMatrix -import breeze.storage.Zero import is.hail.check.Arbitrary.arbitrary import is.hail.utils.roundWithConstantSum -import org.apache.commons.math3.random._ import scala.collection.generic.CanBuildFrom import scala.collection.mutable @@ -12,6 +9,10 @@ import scala.language.higherKinds import scala.math.Numeric.Implicits._ import scala.reflect.ClassTag +import breeze.linalg.DenseMatrix +import breeze.storage.Zero +import org.apache.commons.math3.random._ + object Parameters { val default = Parameters(new RandomDataGenerator(), 1000, 10) } @@ -29,7 +30,8 @@ object Gen { val nonExtremeDouble: Gen[Double] = oneOfGen( oneOf(1e30, -1.0, -1e-30, 0.0, 1e-30, 1.0, 1e30), choose(-100.0, 100.0), - choose(-1e150, 1e150)) + choose(-1e150, 1e150), + ) def squareOfAreaAtMostSize: Gen[(Int, Int)] = nCubeOfVolumeAtMostSize(2).map(x => (x(0), x(1))) @@ -38,14 +40,27 @@ object Gen { nonEmptyNCubeOfVolumeAtMostSize(2).map(x => (x(0), x(1))) def nCubeOfVolumeAtMostSize(n: Int): Gen[Array[Int]] = - Gen { (p: Parameters) => nCubeOfVolumeAtMost(p.rng, n, p.size) } + Gen((p: Parameters) => nCubeOfVolumeAtMost(p.rng, n, p.size)) def nonEmptyNCubeOfVolumeAtMostSize(n: Int): Gen[Array[Int]] = - Gen { (p: Parameters) => nCubeOfVolumeAtMost(p.rng, n, p.size).map(x => if (x == 0) 1 else x).toArray } + Gen { (p: Parameters) => + nCubeOfVolumeAtMost(p.rng, n, p.size).map(x => if (x == 0) 1 else x).toArray 
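The Arbitrary and Gen pair reformatted above follows the usual property-testing pattern: an Arbitrary[T] wraps a default Gen[T], and arbitrary[T] summons it implicitly. A brief usage sketch against the signatures visible in this diff; Temperature and its range are hypothetical:

import is.hail.check.{Arbitrary, Gen}
import is.hail.check.Arbitrary.arbitrary

final case class Temperature(celsius: Double)

object TemperatureArbitrary {
  // default generator for Temperature, built from the Double overload of choose
  implicit val arbTemperature: Arbitrary[Temperature] =
    Arbitrary(Gen.choose(-90.0, 60.0).map(Temperature))

  // collections compose via buildableOf, mirroring arbBuildableOf above
  val manyTemperatures: Gen[Seq[Temperature]] =
    Gen.buildableOf[Seq](arbitrary[Temperature])
}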
+ } - def partition[T](rng: RandomDataGenerator, size: T, parts: Int, f: (RandomDataGenerator, T) => T)(implicit tn: Numeric[T], tct: ClassTag[T]): Array[T] = { + def partition[T]( + rng: RandomDataGenerator, + size: T, + parts: Int, + f: (RandomDataGenerator, T) => T, + )(implicit + tn: Numeric[T], + tct: ClassTag[T], + ): Array[T] = { import tn.mkOrderingOps - assert(size >= tn.zero, s"size must be greater than or equal to 0. Found $size. tn.zero=${ tn.zero }.") + assert( + size >= tn.zero, + s"size must be greater than or equal to 0. Found $size. tn.zero=${tn.zero}.", + ) if (parts == 0) return Array() @@ -70,29 +85,26 @@ object Gen { def partition(rng: RandomDataGenerator, size: Int, parts: Int): Array[Int] = partition(rng, size, parts, (rng: RandomDataGenerator, avail: Int) => rng.nextInt(0, avail)) - /** - * Picks a number of bins, n, from a BetaBinomial(alpha, beta), then takes - * {@code size} balls and places them into n bins according to a - * dirichlet-multinomial distribution with all alpha_i equal to n. - * - **/ - def partitionBetaDirichlet(rng: RandomDataGenerator, size: Int, alpha: Double, beta: Double): Array[Int] = + /** Picks a number of bins, n, from a BetaBinomial(alpha, beta), then takes {@code size} balls and + * places them into n bins according to a dirichlet-multinomial distribution with all alpha_i + * equal to n. + */ + def partitionBetaDirichlet(rng: RandomDataGenerator, size: Int, alpha: Double, beta: Double) + : Array[Int] = partitionDirichlet(rng, size, sampleBetaBinomial(rng, size, alpha, beta)) - /** - * Takes {@code size} balls and places them into {@code parts} bins according - * to a dirichlet-multinomial distribution with alpha_n equal to {@code - * parts} for all n. The outputs of this function tend towards uniformly - * distributed balls, i.e. vectors close to the center of the simplex in - * {@code parts} dimensions. - * - **/ + /** Takes {@code size} balls and places them into {@code parts} bins according to a + * dirichlet-multinomial distribution with alpha_n equal to {@code parts} for all n. The outputs + * of this function tend towards uniformly distributed balls, i.e. vectors close to the center of + * the simplex in {@code parts} dimensions. 
+ */ def partitionDirichlet(rng: RandomDataGenerator, size: Int, parts: Int): Array[Int] = { val simplexVector = sampleDirichlet(rng, Array.fill(parts)(parts.toDouble)) roundWithConstantSum(simplexVector.map((x: Double) => x * size).toArray) } - def nCubeOfVolumeAtMost(rng: RandomDataGenerator, n: Int, size: Int, alpha: Int = 1): Array[Int] = { + def nCubeOfVolumeAtMost(rng: RandomDataGenerator, n: Int, size: Int, alpha: Int = 1) + : Array[Int] = { val sizeOfSum = math.log(size) val simplexVector = sampleDirichlet(rng, Array.fill(n)(alpha.toDouble)) roundWithConstantSum(simplexVector.map((x: Double) => x * sizeOfSum).toArray) @@ -106,32 +118,52 @@ object Gen { } def partition(parts: Int, sum: Int): Gen[Array[Int]] = - Gen { p => partition(p.rng, sum, parts, (rng: RandomDataGenerator, avail: Int) => rng.nextInt(0, avail)) } + Gen { p => + partition(p.rng, sum, parts, (rng: RandomDataGenerator, avail: Int) => rng.nextInt(0, avail)) + } def partition(parts: Int, sum: Long): Gen[Array[Long]] = - Gen { p => partition(p.rng, sum, parts, (rng: RandomDataGenerator, avail: Long) => rng.nextLong(0, avail)) } + Gen { p => + partition( + p.rng, + sum, + parts, + (rng: RandomDataGenerator, avail: Long) => rng.nextLong(0, avail), + ) + } def partition(parts: Int, sum: Double): Gen[Array[Double]] = - Gen { p => partition(p.rng, sum, parts, (rng: RandomDataGenerator, avail: Double) => rng.nextUniform(0, avail)) } + Gen { p => + partition( + p.rng, + sum, + parts, + (rng: RandomDataGenerator, avail: Double) => rng.nextUniform(0, avail), + ) + } def partitionSize(parts: Int): Gen[Array[Int]] = - Gen { p => partitionDirichlet(p.rng, p.size, parts) } + Gen(p => partitionDirichlet(p.rng, p.size, parts)) - def size: Gen[Int] = Gen { p => p.size } + def size: Gen[Int] = Gen(p => p.size) val printableChars = (0 to 127).map(_.toChar).filter(!_.isControl).toArray + val identifierLeadingChars = (0 to 127).map(_.toChar) .filter(c => c == '_' || c.isLetter) + val identifierChars = (0 to 127).map(_.toChar) .filter(c => c == '_' || c.isLetterOrDigit) + val plinkSafeStartOfIdentifierChars = (0 to 127).map(_.toChar) .filter(c => c.isLetter) + val plinkSafeChars = (0 to 127).map(_.toChar) .filter(c => c.isLetterOrDigit) def apply[T](gen: (Parameters) => T): Gen[T] = new Gen[T](gen) - def const[T](x: T): Gen[T] = Gen { (p: Parameters) => x } + def const[T](x: T): Gen[T] = Gen((p: Parameters) => x) def coin(p: Double = 0.5): Gen[Boolean] = { require(0.0 < p) @@ -141,28 +173,24 @@ object Gen { def oneOfSeq[T](xs: Seq[T]): Gen[T] = { assert(xs.nonEmpty) - Gen { (p: Parameters) => - xs(p.rng.getRandomGenerator.nextInt(xs.length)) - } + Gen((p: Parameters) => xs(p.rng.getRandomGenerator.nextInt(xs.length))) } def oneOfGen[T](gs: Gen[T]*): Gen[T] = { assert(gs.nonEmpty) - Gen { (p: Parameters) => - gs(p.rng.getRandomGenerator.nextInt(gs.length))(p) - } + Gen((p: Parameters) => gs(p.rng.getRandomGenerator.nextInt(gs.length))(p)) } def oneOf[T](xs: T*): Gen[T] = oneOfSeq(xs) def choose(min: Int, max: Int): Gen[Int] = { assert(max >= min) - Gen { (p: Parameters) => p.rng.nextInt(min, max) } + Gen((p: Parameters) => p.rng.nextInt(min, max)) } def choose(min: Long, max: Long): Gen[Long] = { assert(max >= min) - Gen { (p: Parameters) => p.rng.nextLong(min, max) } + Gen((p: Parameters) => p.rng.nextLong(min, max)) } def choose(min: Float, max: Float): Gen[Float] = Gen { (p: Parameters) => @@ -184,7 +212,8 @@ object Gen { def nextCoin(p: Double) = choose(0.0, 1.0).map(_ < p) - private def sampleBetaBinomial(rng: RandomDataGenerator, n: 
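The partition helpers documented above preserve the total: partitionDirichlet(rng, size, parts) returns parts non-negative entries summing to size, with the symmetric Dirichlet weights pulling each entry toward size / parts. A short usage sketch (the concrete numbers are illustrative):

import is.hail.check.Gen
import org.apache.commons.math3.random.RandomDataGenerator

object PartitionSketch {
  val rng = new RandomDataGenerator()

  // four bins, one hundred balls: bins.length == 4 and bins.sum == 100, with each
  // entry tending toward 25 because every concentration parameter equals parts
  val bins: Array[Int] = Gen.partitionDirichlet(rng, size = 100, parts = 4)
}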
Int, alpha: Double, beta: Double): Int = + private def sampleBetaBinomial(rng: RandomDataGenerator, n: Int, alpha: Double, beta: Double) + : Int = rng.nextBinomial(n, rng.nextBeta(alpha, beta)) def nextBetaBinomial(n: Int, alpha: Double, beta: Double): Gen[Int] = Gen { p => @@ -230,47 +259,50 @@ object Gen { def subset[T](s: Set[T]): Gen[Set[T]] = Gen.parameterized { p => Gen.choose(0.0, 1.0).map(cutoff => - s.filter(_ => p.rng.getRandomGenerator.nextDouble <= cutoff)) + s.filter(_ => p.rng.getRandomGenerator.nextDouble <= cutoff) + ) } - def sequence[C[_], T](gs: Traversable[Gen[T]])(implicit cbf: CanBuildFrom[Nothing, T, C[T]]): Gen[C[T]] = + def sequence[C[_], T](gs: Traversable[Gen[T]])(implicit cbf: CanBuildFrom[Nothing, T, C[T]]) + : Gen[C[T]] = Gen { (p: Parameters) => val b = cbf() - gs.foreach { g => b += g(p) } + gs.foreach(g => b += g(p)) b.result() } - def denseMatrix[T : ClassTag : Zero : Arbitrary](): Gen[DenseMatrix[T]] = for { + def denseMatrix[T: ClassTag: Zero: Arbitrary](): Gen[DenseMatrix[T]] = for { (l, w) <- Gen.nonEmptySquareOfAreaAtMostSize m <- denseMatrix(l, w) } yield m - def denseMatrix[T : ClassTag : Zero : Arbitrary](n: Int, m: Int): Gen[DenseMatrix[T]] = + def denseMatrix[T: ClassTag: Zero: Arbitrary](n: Int, m: Int): Gen[DenseMatrix[T]] = denseMatrix[T](n, m, arbitrary[T]) - def denseMatrix[T : ClassTag : Zero](n: Int, m: Int, g: Gen[T]): Gen[DenseMatrix[T]] = Gen { (p: Parameters) => - DenseMatrix.fill[T](n, m)(g.resize(p.size / (n * m))(p)) - } + def denseMatrix[T: ClassTag: Zero](n: Int, m: Int, g: Gen[T]): Gen[DenseMatrix[T]] = + Gen((p: Parameters) => DenseMatrix.fill[T](n, m)(g.resize(p.size / (n * m))(p))) - def twoMultipliableDenseMatrices[T : ClassTag : Zero : Arbitrary](): Gen[(DenseMatrix[T], DenseMatrix[T])] = + def twoMultipliableDenseMatrices[T: ClassTag: Zero: Arbitrary]() + : Gen[(DenseMatrix[T], DenseMatrix[T])] = twoMultipliableDenseMatrices(arbitrary[T]) - def twoMultipliableDenseMatrices[T : ClassTag : Zero](g: Gen[T]): Gen[(DenseMatrix[T], DenseMatrix[T])] = for { + def twoMultipliableDenseMatrices[T: ClassTag: Zero](g: Gen[T]) + : Gen[(DenseMatrix[T], DenseMatrix[T])] = for { Array(rows, inner, columns) <- Gen.nonEmptyNCubeOfVolumeAtMostSize(3) l <- denseMatrix(rows, inner, g) r <- denseMatrix(inner, columns, g) } yield (l, r) - /** - * In general, for any Traversable type T and any Monad M, we may convert an {@code F[M[T]]} to an {@code M[F[T]]} by - * choosing to perform the actions in the order defined by the traversable. With {@code Gen} we must also consider - * the distribution of size. {@code uniformSequence} distributes the size uniformly across all elements of the - * traversable. - * - **/ - def uniformSequence[C[_], T](gs: Traversable[Gen[T]])(implicit cbf: CanBuildFrom[Nothing, T, C[T]]): Gen[C[T]] = { + /** In general, for any Traversable type T and any Monad M, we may convert an {@code F[M[T]]} to + * an {@code M[F[T]]} by choosing to perform the actions in the order defined by the traversable. + * With {@code Gen} we must also consider the distribution of size. {@code uniformSequence} + * distributes the size uniformly across all elements of the traversable. 
+ */ + def uniformSequence[C[_], T]( + gs: Traversable[Gen[T]] + )(implicit cbf: CanBuildFrom[Nothing, T, C[T]] + ): Gen[C[T]] = partitionSize(gs.size).map(resizeMany(gs, _)).flatMap(sequence[C, T]) - } private def resizeMany[T](gs: Traversable[Gen[T]], partition: Array[Int]): Iterable[Gen[T]] = (gs.toIterable, partition).zipped.map((gen, size) => gen.resize(size)) @@ -287,11 +319,15 @@ object Gen { def buildableOf[C[_]] = buildableOfInstance.asInstanceOf[BuildableOf[C]] - implicit def buildableOfFromElements[C[_], T](implicit g: Gen[T], cbf: CanBuildFrom[Nothing, T, C[T]]): Gen[C[T]] = + implicit def buildableOfFromElements[C[_], T]( + implicit g: Gen[T], + cbf: CanBuildFrom[Nothing, T, C[T]], + ): Gen[C[T]] = buildableOf[C](g) sealed trait BuildableOf2[C[_, _]] { - def apply[T, U](g: Gen[(T, U)])(implicit cbf: CanBuildFrom[Nothing, (T, U), C[T, U]]): Gen[C[T, U]] = + def apply[T, U](g: Gen[(T, U)])(implicit cbf: CanBuildFrom[Nothing, (T, U), C[T, U]]) + : Gen[C[T, U]] = unsafeBuildableOf(g) } @@ -301,7 +337,9 @@ object Gen { private val buildableOfAlpha = 3 private val buildableOfBeta = 6 - private def unsafeBuildableOf[C, T](g: Gen[T])(implicit cbf: CanBuildFrom[Nothing, T, C]): Gen[C] = + + private def unsafeBuildableOf[C, T](g: Gen[T])(implicit cbf: CanBuildFrom[Nothing, T, C]) + : Gen[C] = Gen { (p: Parameters) => val b = cbf() if (p.size == 0) @@ -309,7 +347,12 @@ object Gen { else { // scale up a bit by log, so that we can spread out a bit more with // higher sizes - val part = partitionBetaDirichlet(p.rng, p.size, buildableOfAlpha, buildableOfBeta * math.log(p.size + 0.01)) + val part = partitionBetaDirichlet( + p.rng, + p.size, + buildableOfAlpha, + buildableOfBeta * math.log(p.size + 0.01), + ) val s = part.length for (i <- 0 until s) b += g(p.copy(size = part(i))) @@ -326,7 +369,12 @@ object Gen { else { // scale up a bit by log, so that we can spread out a bit more with // higher sizes - val part = partitionBetaDirichlet(p.rng, p.size, buildableOfAlpha, buildableOfBeta * math.log(p.size + 0.01)) + val part = partitionBetaDirichlet( + p.rng, + p.size, + buildableOfAlpha, + buildableOfBeta * math.log(p.size + 0.01), + ) val s = part.length val t = mutable.Set.empty[T] for (i <- 0 until s) @@ -341,29 +389,34 @@ object Gen { def distinctBuildableOf[C[_]] = distinctBuildableOfInstance.asInstanceOf[DistinctBuildableOf[C]] - /** - * This function terminates with probability equal to the probability of {@code g} generating {@code min} distinct - * elements in finite time. + /** This function terminates with probability equal to the probability of {@code g} generating + * {@code min} distinct elements in finite time. 
*/ sealed trait DistinctBuildableOfAtLeast[C[_]] { def apply[T](min: Int, g: Gen[T])(implicit cbf: CanBuildFrom[Nothing, T, C[T]]): Gen[C[T]] = { Gen { (p: Parameters) => val b = cbf() if (p.size < min) { - throw new RuntimeException(s"Size (${ p.size }) is too small for buildable of size at least $min") + throw new RuntimeException( + s"Size (${p.size}) is too small for buildable of size at least $min" + ) } else if (p.size == 0) b.result() else { // scale up a bit by log, so that we can spread out a bit more with // higher sizes - val s = min + sampleBetaBinomial(p.rng, p.size - min, buildableOfAlpha, buildableOfBeta * math.log((p.size - min) + 0.01)) + val s = min + sampleBetaBinomial( + p.rng, + p.size - min, + buildableOfAlpha, + buildableOfBeta * math.log((p.size - min) + 0.01), + ) val part = partitionDirichlet(p.rng, p.size, s) val t = mutable.Set.empty[T] for (i <- 0 until s) { var element = g.resize(part(i))(p) - while (t.contains(element)) { + while (t.contains(element)) element = g.resize(part(i))(p) - } t += element } b ++= t @@ -375,7 +428,8 @@ object Gen { private object distinctBuildableOfAtLeastInstance extends DistinctBuildableOfAtLeast[Nothing] - def distinctBuildableOfAtLeast[C[_]] = distinctBuildableOfAtLeastInstance.asInstanceOf[DistinctBuildableOfAtLeast[C]] + def distinctBuildableOfAtLeast[C[_]] = + distinctBuildableOfAtLeastInstance.asInstanceOf[DistinctBuildableOfAtLeast[C]] sealed trait BuildableOfN[C[_]] { def apply[T](n: Int, g: Gen[T])(implicit cbf: CanBuildFrom[Nothing, T, C[T]]): Gen[C[T]] = @@ -410,7 +464,8 @@ object Gen { private object distinctBuildableOfNInstance extends DistinctBuildableOfN[Nothing] - def distinctBuildableOfN[C[_]] = distinctBuildableOfNInstance.asInstanceOf[DistinctBuildableOfN[C]] + def distinctBuildableOfN[C[_]] = + distinctBuildableOfNInstance.asInstanceOf[DistinctBuildableOfN[C]] def randomOneOf[T](rng: RandomDataGenerator, is: IndexedSeq[T]): T = { assert(is.nonEmpty) @@ -423,7 +478,10 @@ object Gen { def plinkSafeIdentifier: Gen[String] = identifierGen(plinkSafeStartOfIdentifierChars, plinkSafeChars) - private def identifierGen(leadingCharacter: IndexedSeq[Char], trailingCharacters: IndexedSeq[Char]): Gen[String] = Gen { p => + private def identifierGen( + leadingCharacter: IndexedSeq[Char], + trailingCharacters: IndexedSeq[Char], + ): Gen[String] = Gen { p => val s = 1 + p.rng.getRandomGenerator.nextInt(11) val b = new StringBuilder() b += randomOneOf(p.rng, leadingCharacter) @@ -439,18 +497,17 @@ object Gen { None } - def nonnegInt: Gen[Int] = Gen { p => - p.rng.getRandomGenerator.nextInt() & Int.MaxValue - } + def nonnegInt: Gen[Int] = Gen(p => p.rng.getRandomGenerator.nextInt() & Int.MaxValue) def posInt: Gen[Int] = Gen { (p: Parameters) => p.rng.getRandomGenerator.nextInt(Int.MaxValue - 1) + 1 } def interestingPosInt: Gen[Int] = oneOfGen( - oneOf(1, 2, Int.MaxValue - 1, Int.MaxValue), - choose(1, 100), - posInt) + oneOf(1, 2, Int.MaxValue - 1, Int.MaxValue), + choose(1, 100), + posInt, + ) def zip[T1](g1: Gen[T1]): Gen[T1] = g1 @@ -467,7 +524,8 @@ object Gen { z <- g3.resize(s3) } yield (x, y, z) - def zip[T1, T2, T3, T4](g1: Gen[T1], g2: Gen[T2], g3: Gen[T3], g4: Gen[T4]): Gen[(T1, T2, T3, T4)] = for { + def zip[T1, T2, T3, T4](g1: Gen[T1], g2: Gen[T2], g3: Gen[T3], g4: Gen[T4]) + : Gen[(T1, T2, T3, T4)] = for { Array(s1, s2, s3, s4) <- partitionSize(4) x <- g1.resize(s1) y <- g2.resize(s2) @@ -475,9 +533,9 @@ object Gen { w <- g4.resize(s4) } yield (x, y, z, w) - def parameterized[T](f: (Parameters => Gen[T])) = Gen { p => 
f(p)(p) } + def parameterized[T](f: (Parameters => Gen[T])) = Gen(p => f(p)(p)) - def sized[T](f: (Int) => Gen[T]): Gen[T] = Gen { (p: Parameters) => f(p.size)(p) } + def sized[T](f: (Int) => Gen[T]): Gen[T] = Gen((p: Parameters) => f(p.size)(p)) def applyGen[T, S](gf: Gen[(T) => S], gx: Gen[T]): Gen[S] = Gen { p => val f = gf(p) @@ -492,15 +550,11 @@ class Gen[+T](val gen: (Parameters) => T) extends AnyVal { def sample(): T = apply(Parameters.default) - def map[U](f: (T) => U): Gen[U] = Gen { p => f(apply(p)) } + def map[U](f: (T) => U): Gen[U] = Gen(p => f(apply(p))) - def flatMap[U](f: (T) => Gen[U]): Gen[U] = Gen { p => - f(apply(p))(p) - } + def flatMap[U](f: (T) => Gen[U]): Gen[U] = Gen(p => f(apply(p))(p)) - def resize(newSize: Int): Gen[T] = Gen { (p: Parameters) => - apply(p.copy(size = newSize)) - } + def resize(newSize: Int): Gen[T] = Gen((p: Parameters) => apply(p.copy(size = newSize))) // FIXME should be non-strict def withFilter(f: (T) => Boolean): Gen[T] = Gen { (p: Parameters) => diff --git a/hail/src/main/scala/is/hail/check/Prop.scala b/hail/src/main/scala/is/hail/check/Prop.scala index 7dc7f5cb82d..b5fd2487e25 100644 --- a/hail/src/main/scala/is/hail/check/Prop.scala +++ b/hail/src/main/scala/is/hail/check/Prop.scala @@ -1,14 +1,14 @@ package is.hail.check -import org.apache.commons.math3.random.RandomDataGenerator - import scala.collection.mutable.ArrayBuffer import scala.util.{Failure, Random, Success, Try} +import org.apache.commons.math3.random.RandomDataGenerator + abstract class Prop { def apply(p: Parameters, name: Option[String] = None): Unit - def check() { + def check(): Unit = { val size = System.getProperty("check.size", "1000").toInt val count = System.getProperty("check.count", "10").toInt @@ -21,7 +21,7 @@ abstract class Prop { } class GenProp1[T1](g1: Gen[T1], f: (T1) => Boolean) extends Prop { - override def apply(p: Parameters, name: Option[String]) { + override def apply(p: Parameters, name: Option[String]): Unit = { val prefix = name.map(_ + ": ").getOrElse("") for (i <- 0 until p.count) { val v1 = g1(p) @@ -29,21 +29,21 @@ class GenProp1[T1](g1: Gen[T1], f: (T1) => Boolean) extends Prop { r match { case Success(true) => case Success(false) => - println(s"""! ${ prefix }Falsified after $i passed tests.""") + println(s"""! ${prefix}Falsified after $i passed tests.""") println(s"> ARG_0: $v1") throw new AssertionError(null) case Failure(e) => - println(s"""! ${ prefix }Error after $i passed tests.""") + println(s"""! ${prefix}Error after $i passed tests.""") println(s"> ARG_0: $v1") throw new AssertionError(e) } } - println(s" + ${ prefix }OK, passed ${ p.count } tests.") + println(s" + ${prefix}OK, passed ${p.count} tests.") } } class GenProp2[T1, T2](g1: Gen[T1], g2: Gen[T2], f: (T1, T2) => Boolean) extends Prop { - override def apply(p: Parameters, name: Option[String]) { + override def apply(p: Parameters, name: Option[String]): Unit = { val prefix = name.map(_ + ": ").getOrElse("") for (i <- 0 until p.count) { val v1 = g1(p) @@ -52,21 +52,22 @@ class GenProp2[T1, T2](g1: Gen[T1], g2: Gen[T2], f: (T1, T2) => Boolean) extends r match { case Success(true) => case Success(false) => - println(s"""! ${ prefix }Falsified after $i passed tests.""") + println(s"""! ${prefix}Falsified after $i passed tests.""") println(s"> ARG_0: $v1") throw new AssertionError(null) case Failure(e) => - println(s"""! ${ prefix }Error after $i passed tests.""") + println(s"""! 
${prefix}Error after $i passed tests.""") println(s"> ARG_0: $v1") throw new AssertionError(e) } } - println(s" + ${ prefix }OK, passed ${ p.count } tests.") + println(s" + ${prefix}OK, passed ${p.count} tests.") } } -class GenProp3[T1, T2, T3](g1: Gen[T1], g2: Gen[T2], g3: Gen[T3], f: (T1, T2, T3) => Boolean) extends Prop { - override def apply(p: Parameters, name: Option[String]) { +class GenProp3[T1, T2, T3](g1: Gen[T1], g2: Gen[T2], g3: Gen[T3], f: (T1, T2, T3) => Boolean) + extends Prop { + override def apply(p: Parameters, name: Option[String]): Unit = { val prefix = name.map(_ + ": ").getOrElse("") for (i <- 0 until p.count) { val v1 = g1(p) @@ -76,16 +77,16 @@ class GenProp3[T1, T2, T3](g1: Gen[T1], g2: Gen[T2], g3: Gen[T3], f: (T1, T2, T3 r match { case Success(true) => case Success(false) => - println(s"""! ${ prefix }Falsified after $i passed tests.""") + println(s"""! ${prefix}Falsified after $i passed tests.""") println(s"> ARG_0: $v1") throw new AssertionError(null) case Failure(e) => - println(s"""! ${ prefix }Error after $i passed tests.""") + println(s"""! ${prefix}Error after $i passed tests.""") println(s"> ARG_0: $v1") throw new AssertionError(e) } } - println(s" + ${ prefix }OK, passed ${ p.count } tests.") + println(s" + ${prefix}OK, passed ${p.count} tests.") } } @@ -93,17 +94,15 @@ class Properties(val name: String) extends Prop { val properties = ArrayBuffer.empty[(String, Prop)] class PropertySpecifier { - def update(propName: String, prop: Prop) { + def update(propName: String, prop: Prop): Unit = properties += (name + "." + propName) -> prop - } } lazy val property = new PropertySpecifier - override def apply(p: Parameters, prefix: Option[String]) { + override def apply(p: Parameters, prefix: Option[String]): Unit = for ((propName, prop) <- properties) prop.apply(p, prefix.map(_ + "." 
+ propName).orElse(Some(propName))) - } } @@ -119,13 +118,12 @@ object Prop { } def seed: Int = { - println(s"check: seed = ${ _seed }") + println(s"check: seed = ${_seed}") _seed } - def check(prop: Prop) { + def check(prop: Prop): Unit = prop.check() - } def forAll[T1](g1: Gen[Boolean]): Prop = new GenProp1(g1, identity[Boolean]) @@ -145,7 +143,13 @@ object Prop { def forAll[T1, T2](p: (T1, T2) => Boolean)(implicit a1: Arbitrary[T1], a2: Arbitrary[T2]): Prop = new GenProp2(a1.arbitrary, a2.arbitrary, p) - def forAll[T1, T2, T3](p: (T1, T2, T3) => Boolean)(implicit a1: Arbitrary[T1], a2: Arbitrary[T2], a3: Arbitrary[T3]): Prop = + def forAll[T1, T2, T3]( + p: (T1, T2, T3) => Boolean + )(implicit + a1: Arbitrary[T1], + a2: Arbitrary[T2], + a3: Arbitrary[T3], + ): Prop = new GenProp3(a1.arbitrary, a2.arbitrary, a3.arbitrary, p) } diff --git a/hail/src/main/scala/is/hail/compatibility/LegacyBufferSpecs.scala b/hail/src/main/scala/is/hail/compatibility/LegacyBufferSpecs.scala index 4db1e46d250..95d595ac6b0 100644 --- a/hail/src/main/scala/is/hail/compatibility/LegacyBufferSpecs.scala +++ b/hail/src/main/scala/is/hail/compatibility/LegacyBufferSpecs.scala @@ -1,8 +1,8 @@ package is.hail.compatibility import is.hail.asm4s.Code -import is.hail.io.compress.LZ4 import is.hail.io._ +import is.hail.io.compress.LZ4 final case class LZ4BlockBufferSpec(blockSize: Int, child: BlockBufferSpec) extends LZ4BlockBufferSpecCommon { @@ -10,4 +10,3 @@ final case class LZ4BlockBufferSpec(blockSize: Int, child: BlockBufferSpec) def stagedlz4: Code[LZ4] = Code.invokeScalaObject0[LZ4](LZ4.getClass, "hc") def typeName = "LZ4BlockBufferSpec" } - diff --git a/hail/src/main/scala/is/hail/compatibility/LegacyEncodedTypeParser.scala b/hail/src/main/scala/is/hail/compatibility/LegacyEncodedTypeParser.scala index b3526223302..592ad8be637 100644 --- a/hail/src/main/scala/is/hail/compatibility/LegacyEncodedTypeParser.scala +++ b/hail/src/main/scala/is/hail/compatibility/LegacyEncodedTypeParser.scala @@ -1,7 +1,7 @@ package is.hail.compatibility -import is.hail.expr.ir.IRParser._ import is.hail.expr.ir.{IRParser, PunctuationToken, TokenIterator} +import is.hail.expr.ir.IRParser._ import is.hail.types.encoded._ import is.hail.types.virtual._ import is.hail.utils.FastSeq @@ -21,12 +21,18 @@ object LegacyEncodedTypeParser { punctuation(it, "[") val (pointType, ePointType) = legacy_type_expr(it) punctuation(it, "]") - (TInterval(pointType), EBaseStruct(FastSeq( - EField("start", ePointType, 0), - EField("end", ePointType, 1), - EField("includesStart", EBooleanRequired, 2), - EField("includesEnd", EBooleanRequired, 3) - ), req)) + ( + TInterval(pointType), + EBaseStruct( + FastSeq( + EField("start", ePointType, 0), + EField("end", ePointType, 1), + EField("includesStart", EBooleanRequired, 2), + EField("includesEnd", EBooleanRequired, 3), + ), + req, + ), + ) case "Boolean" => (TBoolean, EBoolean(req)) case "Int32" => (TInt32, EInt32(req)) case "Int64" => (TInt64, EInt64(req)) @@ -38,9 +44,16 @@ object LegacyEncodedTypeParser { punctuation(it, "(") val rg = identifier(it) punctuation(it, ")") - (TLocus(rg), EBaseStruct(FastSeq( - EField("contig", EBinaryRequired, 0), - EField("position", EInt32Required, 1)), req)) + ( + TLocus(rg), + EBaseStruct( + FastSeq( + EField("contig", EBinaryRequired, 0), + EField("position", EInt32Required, 1), + ), + req, + ), + ) case "Call" => (TCall, EInt32(req)) case "Array" => punctuation(it, "[") @@ -58,20 +71,42 @@ object LegacyEncodedTypeParser { punctuation(it, ",") val (valueType, valueEType) 
= legacy_type_expr(it) punctuation(it, "]") - (TDict(keyType, valueType), EArray(EBaseStruct(FastSeq( - EField("key", keyEType, 0), - EField("value", valueEType, 1)), required = true), - req)) + ( + TDict(keyType, valueType), + EArray( + EBaseStruct( + FastSeq( + EField("key", keyEType, 0), + EField("value", valueEType, 1), + ), + required = true, + ), + req, + ), + ) case "Tuple" => punctuation(it, "[") val types = repsepUntil(it, legacy_type_expr, PunctuationToken(","), PunctuationToken("]")) punctuation(it, "]") - (TTuple(types.map(_._1): _*), EBaseStruct(types.zipWithIndex.map { case ((_, t), idx) => EField(idx.toString, t, idx) }, req)) + ( + TTuple(types.map(_._1): _*), + EBaseStruct( + types.zipWithIndex.map { case ((_, t), idx) => EField(idx.toString, t, idx) }, + req, + ), + ) case "Struct" => punctuation(it, "{") - val args = repsepUntil(it, struct_field(legacy_type_expr), PunctuationToken(","), PunctuationToken("}")) + val args = repsepUntil( + it, + struct_field(legacy_type_expr), + PunctuationToken(","), + PunctuationToken("}"), + ) punctuation(it, "}") - val (vFields, eFields) = args.zipWithIndex.map { case ((id, (vt, et)), i) => (Field(id, vt, i), EField(id, et, i)) }.unzip + val (vFields, eFields) = args.zipWithIndex.map { case ((id, (vt, et)), i) => + (Field(id, vt, i), EField(id, et, i)) + }.unzip (TStruct(vFields), EBaseStruct(eFields, req)) } assert(eType.required == req) @@ -96,10 +131,8 @@ object LegacyEncodedTypeParser { } } - - def parseTypeAndEType(str: String): (Type, EType) = { + def parseTypeAndEType(str: String): (Type, EType) = IRParser.parse(str, it => legacy_type_expr(it)) - } def parseLegacyRVDType(str: String): LegacyRVDType = IRParser.parse(str, it => rvd_type_expr(it)) } diff --git a/hail/src/main/scala/is/hail/compatibility/LegacyRVDSpecs.scala b/hail/src/main/scala/is/hail/compatibility/LegacyRVDSpecs.scala index e863e000784..64a07f50c2e 100644 --- a/hail/src/main/scala/is/hail/compatibility/LegacyRVDSpecs.scala +++ b/hail/src/main/scala/is/hail/compatibility/LegacyRVDSpecs.scala @@ -7,68 +7,99 @@ import is.hail.rvd.{AbstractRVDSpec, IndexSpec2, IndexedRVDSpec2, RVDPartitioner import is.hail.types.encoded._ import is.hail.types.virtual._ import is.hail.utils.{FastSeq, Interval} + import org.json4s.JValue -case class IndexSpec private( +case class IndexSpec private ( relPath: String, keyType: String, annotationType: String, - offsetField: Option[String] + offsetField: Option[String], ) { val baseSpec = LEB128BufferSpec( - BlockingBufferSpec(32 * 1024, - LZ4BlockBufferSpec(32 * 1024, - new StreamBlockBufferSpec))) + BlockingBufferSpec(32 * 1024, LZ4BlockBufferSpec(32 * 1024, new StreamBlockBufferSpec)) + ) val (keyVType, keyEType) = LegacyEncodedTypeParser.parseTypeAndEType(keyType) val (annotationVType, annotationEType) = LegacyEncodedTypeParser.parseTypeAndEType(annotationType) val leafEType = EBaseStruct(FastSeq( EField("first_idx", EInt64Required, 0), - EField("keys", EArray(EBaseStruct(FastSeq( - EField("key", keyEType, 0), - EField("offset", EInt64Required, 1), - EField("annotation", annotationEType, 2) - ), required = true), required = true), 1) + EField( + "keys", + EArray( + EBaseStruct( + FastSeq( + EField("key", keyEType, 0), + EField("offset", EInt64Required, 1), + EField("annotation", annotationEType, 2), + ), + required = true, + ), + required = true, + ), + 1, + ), )) + val leafVType = TStruct(FastSeq( Field("first_idx", TInt64, 0), - Field("keys", TArray(TStruct(FastSeq( - Field("key", keyVType, 0), - Field("offset", TInt64, 1), - 
Field("annotation", annotationVType, 2) - ))), 1))) + Field( + "keys", + TArray(TStruct(FastSeq( + Field("key", keyVType, 0), + Field("offset", TInt64, 1), + Field("annotation", annotationVType, 2), + ))), + 1, + ), + )) val internalNodeEType = EBaseStruct(FastSeq( - EField("children", EArray(EBaseStruct(FastSeq( - EField("index_file_offset", EInt64Required, 0), - EField("first_idx", EInt64Required, 1), - EField("first_key", keyEType, 2), - EField("first_record_offset", EInt64Required, 3), - EField("first_annotation", annotationEType, 4) - ), required = true), required = true), 0) + EField( + "children", + EArray( + EBaseStruct( + FastSeq( + EField("index_file_offset", EInt64Required, 0), + EField("first_idx", EInt64Required, 1), + EField("first_key", keyEType, 2), + EField("first_record_offset", EInt64Required, 3), + EField("first_annotation", annotationEType, 4), + ), + required = true, + ), + required = true, + ), + 0, + ) )) val internalNodeVType = TStruct(FastSeq( - Field("children", TArray(TStruct(FastSeq( - Field("index_file_offset", TInt64, 0), - Field("first_idx", TInt64, 1), - Field("first_key", keyVType, 2), - Field("first_record_offset", TInt64, 3), - Field("first_annotation", annotationVType, 4) - ))), 0) + Field( + "children", + TArray(TStruct(FastSeq( + Field("index_file_offset", TInt64, 0), + Field("first_idx", TInt64, 1), + Field("first_key", keyVType, 2), + Field("first_record_offset", TInt64, 3), + Field("first_annotation", annotationVType, 4), + ))), + 0, + ) )) - val leafCodec: AbstractTypedCodecSpec = TypedCodecSpec(leafEType, leafVType, baseSpec) - val internalNodeCodec: AbstractTypedCodecSpec = TypedCodecSpec(internalNodeEType, internalNodeVType, baseSpec) + + val internalNodeCodec: AbstractTypedCodecSpec = + TypedCodecSpec(internalNodeEType, internalNodeVType, baseSpec) def toIndexSpec2: IndexSpec2 = IndexSpec2( - relPath, leafCodec, internalNodeCodec, keyVType, annotationVType, offsetField + relPath, leafCodec, internalNodeCodec, keyVType, annotationVType, offsetField, ) } -case class PackCodecSpec private(child: BufferSpec) +case class PackCodecSpec private (child: BufferSpec) case class LegacyRVDType(rowType: TStruct, rowEType: EType, key: IndexedSeq[String]) { def keyType: TStruct = rowType.select(key)._1 @@ -91,41 +122,48 @@ trait ShimRVDSpec extends AbstractRVDSpec { lazy val attrs: Map[String, String] = shim.attrs } -case class IndexedRVDSpec private( +case class IndexedRVDSpec private ( rvdType: String, codecSpec: PackCodecSpec, indexSpec: IndexSpec, override val partFiles: Array[String], - jRangeBounds: JValue + jRangeBounds: JValue, ) extends ShimRVDSpec { private val lRvdType = LegacyEncodedTypeParser.parseLegacyRVDType(rvdType) - lazy val shim = IndexedRVDSpec2(lRvdType.key, + lazy val shim = IndexedRVDSpec2( + lRvdType.key, TypedCodecSpec(lRvdType.rowEType.setRequired(true), lRvdType.rowType, codecSpec.child), - indexSpec.toIndexSpec2, partFiles, jRangeBounds, Map.empty[String, String]) + indexSpec.toIndexSpec2, + partFiles, + jRangeBounds, + Map.empty[String, String], + ) } -case class UnpartitionedRVDSpec private( +case class UnpartitionedRVDSpec private ( rowType: String, codecSpec: PackCodecSpec, - partFiles: Array[String] + partFiles: Array[String], ) extends AbstractRVDSpec { private val (rowVType: TStruct, rowEType) = LegacyEncodedTypeParser.parseTypeAndEType(rowType) - def partitioner(sm: HailStateManager): RVDPartitioner = RVDPartitioner.unkeyed(sm, partFiles.length) + def partitioner(sm: HailStateManager): RVDPartitioner = + 
RVDPartitioner.unkeyed(sm, partFiles.length) def key: IndexedSeq[String] = FastSeq() - def typedCodecSpec: AbstractTypedCodecSpec = TypedCodecSpec(rowEType.setRequired(true), rowVType, codecSpec.child) + def typedCodecSpec: AbstractTypedCodecSpec = + TypedCodecSpec(rowEType.setRequired(true), rowVType, codecSpec.child) val attrs: Map[String, String] = Map.empty } -case class OrderedRVDSpec private( +case class OrderedRVDSpec private ( rvdType: String, codecSpec: PackCodecSpec, partFiles: Array[String], - jRangeBounds: JValue + jRangeBounds: JValue, ) extends AbstractRVDSpec { private val lRvdType = LegacyEncodedTypeParser.parseLegacyRVDType(rvdType) @@ -133,11 +171,19 @@ case class OrderedRVDSpec private( def partitioner(sm: HailStateManager): RVDPartitioner = { val rangeBoundsType = TArray(TInterval(lRvdType.keyType)) - new RVDPartitioner(sm, lRvdType.keyType, - JSONAnnotationImpex.importAnnotation(jRangeBounds, rangeBoundsType, padNulls = false).asInstanceOf[IndexedSeq[Interval]]) + new RVDPartitioner( + sm, + lRvdType.keyType, + JSONAnnotationImpex.importAnnotation( + jRangeBounds, + rangeBoundsType, + padNulls = false, + ).asInstanceOf[IndexedSeq[Interval]], + ) } - override def typedCodecSpec: AbstractTypedCodecSpec = TypedCodecSpec(lRvdType.rowEType.setRequired(true), lRvdType.rowType, codecSpec.child) + override def typedCodecSpec: AbstractTypedCodecSpec = + TypedCodecSpec(lRvdType.rowEType.setRequired(true), lRvdType.rowType, codecSpec.child) val attrs: Map[String, String] = Map.empty } diff --git a/hail/src/main/scala/is/hail/cxx/RegionValueIterator.scala b/hail/src/main/scala/is/hail/cxx/RegionValueIterator.scala index a0ce947b699..9b4cc37f4cd 100644 --- a/hail/src/main/scala/is/hail/cxx/RegionValueIterator.scala +++ b/hail/src/main/scala/is/hail/cxx/RegionValueIterator.scala @@ -7,4 +7,4 @@ class RegionValueIterator(it: Iterator[RegionValue]) extends Iterator[Long] { def next(): Long = it.next().offset def hasNext: Boolean = it.hasNext -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/experimental/ExperimentalFunctions.scala b/hail/src/main/scala/is/hail/experimental/ExperimentalFunctions.scala index b93be8b0d41..e19680bbdc7 100644 --- a/hail/src/main/scala/is/hail/experimental/ExperimentalFunctions.scala +++ b/hail/src/main/scala/is/hail/experimental/ExperimentalFunctions.scala @@ -1,17 +1,27 @@ package is.hail.experimental import is.hail.expr.ir.functions._ +import is.hail.types.physical.{PCanonicalArray, PFloat64} import is.hail.types.physical.stypes.SType import is.hail.types.physical.stypes.concrete.SIndexablePointer -import is.hail.types.physical.{PCanonicalArray, PFloat64, PType} import is.hail.types.virtual.{TArray, TFloat64, TInt32, Type} object ExperimentalFunctions extends RegistryFunctions { - def registerAll() { + def registerAll(): Unit = { val experimentalPackageClass = Class.forName("is.hail.experimental.package$") - registerScalaFunction("filtering_allele_frequency", Array(TInt32, TInt32, TFloat64), TFloat64, null)(experimentalPackageClass, "calcFilterAlleleFreq") - registerWrappedScalaFunction1("haplotype_freq_em", TArray(TInt32), TArray(TFloat64), (_: Type, pt: SType) => SIndexablePointer(PCanonicalArray(PFloat64(true))))(experimentalPackageClass, "haplotypeFreqEM") + registerScalaFunction( + "filtering_allele_frequency", + Array(TInt32, TInt32, TFloat64), + TFloat64, + null, + )(experimentalPackageClass, "calcFilterAlleleFreq") + registerWrappedScalaFunction1( + "haplotype_freq_em", + TArray(TInt32), + TArray(TFloat64), + (_: Type, 
pt: SType) => SIndexablePointer(PCanonicalArray(PFloat64(true))), + )(experimentalPackageClass, "haplotypeFreqEM") } -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/experimental/package.scala b/hail/src/main/scala/is/hail/experimental/package.scala index d9ab5dfd158..623c7e4cb69 100644 --- a/hail/src/main/scala/is/hail/experimental/package.scala +++ b/hail/src/main/scala/is/hail/experimental/package.scala @@ -1,14 +1,15 @@ package is.hail -import breeze.linalg.{DenseVector, max, sum} -import breeze.numerics._ import is.hail.stats._ import is.hail.utils._ +import breeze.linalg.{max, sum, DenseVector} +import breeze.numerics._ + package object experimental { def findMaxAC(af: Double, an: Int, ci: Double = .95): Int = { - if (af == 0) + if (af == 0) 0 else { val quantile_limit = ci // ci for one-sided, 1-(1-ci)/2 for two-sided @@ -17,66 +18,76 @@ package object experimental { } } - def calcFilterAlleleFreq(ac: Int, an: Int, ci: Double = .95, lower: Double = 1e-10, upper: Double = 2, tol: Double = 1e-7, precision: Double = 1e-6): Double = { + def calcFilterAlleleFreq( + ac: Int, + an: Int, + ci: Double = .95, + lower: Double = 1e-10, + upper: Double = 2, + tol: Double = 1e-7, + precision: Double = 1e-6, + ): Double = { if (ac <= 1 || an == 0) // FAF should not be calculated on singletons 0.0 else { - var f = (af: Double) => ac.toDouble - 1 - qpois(ci, an.toDouble * af) + val f = (af: Double) => ac.toDouble - 1 - qpois(ci, an.toDouble * af) val root = uniroot(f, lower, upper, tol) val rounder = 1d / (precision / 100d) var max_af = math.round(root.getOrElse(0.0) * rounder) / rounder - while (findMaxAC(max_af, an, ci) < ac) { + while (findMaxAC(max_af, an, ci) < ac) max_af += precision - } max_af - precision } } - def calcFilterAlleleFreq(ac: Int, an: Int, ci: Double): Double = calcFilterAlleleFreq(ac, an, ci, lower = 1e-10, upper = 2, tol = 1e-7, precision = 1e-6) + def calcFilterAlleleFreq(ac: Int, an: Int, ci: Double): Double = + calcFilterAlleleFreq(ac, an, ci, lower = 1e-10, upper = 2, tol = 1e-7, precision = 1e-6) + def haplotypeFreqEM(gtCounts: IndexedSeq[Int]): IndexedSeq[Double] = { - def haplotypeFreqEM(gtCounts : IndexedSeq[Int]) : IndexedSeq[Double] = { - - assert(gtCounts.size == 9, "haplotypeFreqEM requires genotype counts for the 9 possible genotype combinations.") + assert( + gtCounts.size == 9, + "haplotypeFreqEM requires genotype counts for the 9 possible genotype combinations.", + ) val _gtCounts = new DenseVector(gtCounts.toArray) val nSamples = sum(_gtCounts) - //Needs some non-ref samples to compute - if(_gtCounts(0) >= nSamples){ return FastSeq(_gtCounts(0),0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0)} + // Needs some non-ref samples to compute + if (_gtCounts(0) >= nSamples) { + return FastSeq(_gtCounts(0), 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + } - val nHaplotypes = 2.0*nSamples.toDouble + val nHaplotypes = 2.0 * nSamples.toDouble - /** - * Constant quantities for each of the different haplotypes: - * n.AB => 2*n.AABB + n.AaBB + n.AABb - * n.Ab => 2*n.AAbb + n.Aabb + n.AABb - * n.aB => 2*n.aaBB + n.AaBB + n.aaBb - * n.ab => 2*n.aabb + n.aaBb + n.Aabb + /** Constant quantities for each of the different haplotypes: n.AB => 2*n.AABB + n.AaBB + n.AABb + * n.Ab => 2*n.AAbb + n.Aabb + n.AABb n.aB => 2*n.aaBB + n.AaBB + n.aaBb n.ab => 2*n.aabb + + * n.aaBb + n.Aabb */ val const_counts = new DenseVector(Array[Double]( - 2.0*_gtCounts(0) + _gtCounts(1) + _gtCounts(3), //n.AB - 2.0*_gtCounts(6) + _gtCounts(3) + _gtCounts(7), //n.Ab - 2.0*_gtCounts(2) + 
_gtCounts(1) + _gtCounts(5), //n.aB - 2.0*_gtCounts(8) + _gtCounts(5) + _gtCounts(7) //n.ab + 2.0 * _gtCounts(0) + _gtCounts(1) + _gtCounts(3), // n.AB + 2.0 * _gtCounts(6) + _gtCounts(3) + _gtCounts(7), // n.Ab + 2.0 * _gtCounts(2) + _gtCounts(1) + _gtCounts(5), // n.aB + 2.0 * _gtCounts(8) + _gtCounts(5) + _gtCounts(7), // n.ab )) - //Initial estimate with AaBb contributing equally to each haplotype - var p_next = (const_counts +:+ new DenseVector(Array.fill[Double](4)(_gtCounts(4)/2.0))) /:/ nHaplotypes + // Initial estimate with AaBb contributing equally to each haplotype + var p_next = + (const_counts +:+ new DenseVector(Array.fill[Double](4)(_gtCounts(4) / 2.0))) /:/ nHaplotypes var p_cur = p_next +:+ 1.0 - //EM - while(max(abs(p_next -:- p_cur)) > 1e-7){ + // EM + while (max(abs(p_next -:- p_cur)) > 1e-7) { p_cur = p_next - p_next = (const_counts +:+ - (new DenseVector(Array[Double]( - p_cur(0)*p_cur(3), //n.AB - p_cur(1)*p_cur(2), //n.Ab - p_cur(1)*p_cur(2), //n.aB - p_cur(0)*p_cur(3) //n.ab - )) * (_gtCounts(4) / ((p_cur(0)*p_cur(3))+(p_cur(1)*p_cur(2))))) - ) / nHaplotypes + p_next = + (const_counts +:+ + (new DenseVector(Array[Double]( + p_cur(0) * p_cur(3), // n.AB + p_cur(1) * p_cur(2), // n.Ab + p_cur(1) * p_cur(2), // n.aB + p_cur(0) * p_cur(3), // n.ab + )) * (_gtCounts(4) / ((p_cur(0) * p_cur(3)) + (p_cur(1) * p_cur(2)))))) / nHaplotypes } diff --git a/hail/src/main/scala/is/hail/expr/AnnotationImpex.scala b/hail/src/main/scala/is/hail/expr/AnnotationImpex.scala index 69162d3129a..aed45885ccf 100644 --- a/hail/src/main/scala/is/hail/expr/AnnotationImpex.scala +++ b/hail/src/main/scala/is/hail/expr/AnnotationImpex.scala @@ -1,19 +1,22 @@ package is.hail.expr -import is.hail.annotations.{Annotation, NDArray, SafeNDArray, UnsafeNDArray} -import is.hail.expr.ir.functions.UtilFunctions -import is.hail.types.physical.{PBoolean, PCanonicalArray, PCanonicalBinary, PCanonicalString, PCanonicalStruct, PFloat32, PFloat64, PInt32, PInt64, PType} +import is.hail.annotations.{Annotation, NDArray, SafeNDArray} +import is.hail.types.physical.{ + PBoolean, PCanonicalArray, PCanonicalBinary, PCanonicalString, PCanonicalStruct, PFloat32, + PFloat64, PInt32, PInt64, PType, +} import is.hail.types.virtual._ import is.hail.utils.{Interval, _} import is.hail.variant._ + +import scala.collection.mutable + import org.apache.spark.sql.Row import org.apache.spark.sql.types._ import org.json4s import org.json4s._ import org.json4s.jackson.{JsonMethods, Serialization} -import scala.collection.mutable - object SparkAnnotationImpex { val invalidCharacters: Set[Char] = " ,;{}()\n\t=".toSet @@ -34,9 +37,12 @@ object SparkAnnotationImpex { case DoubleType => PFloat64() case StringType => PCanonicalString() case BinaryType => PCanonicalBinary() - case ArrayType(elementType, containsNull) => PCanonicalArray(importType(elementType).setRequired(!containsNull)) + case ArrayType(elementType, containsNull) => + PCanonicalArray(importType(elementType).setRequired(!containsNull)) case StructType(fields) => - PCanonicalStruct(fields.map { f => (f.name, importType(f.dataType).setRequired(!f.nullable)) }: _*) + PCanonicalStruct(fields.map { f => + (f.name, importType(f.dataType).setRequired(!f.nullable)) + }: _*) } def exportType(t: Type): DataType = (t: @unchecked) match { @@ -51,11 +57,12 @@ object SparkAnnotationImpex { ArrayType(exportType(elementType), containsNull = true) case tbs: TBaseStruct => if (tbs.fields.isEmpty) - BooleanType //placeholder + BooleanType // placeholder else StructType(tbs.fields .map(f => 
- StructField(escapeColumnName(f.name), f.typ.schema, nullable = true))) + StructField(escapeColumnName(f.name), f.typ.schema, nullable = true) + )) } } @@ -65,11 +72,24 @@ case class JSONExtractIntervalLocus(start: Locus, end: Locus) { case class JSONExtractContig(name: String, length: Int) -case class JSONExtractReferenceGenome(name: String, contigs: Array[JSONExtractContig], xContigs: Set[String], - yContigs: Set[String], mtContigs: Set[String], par: Array[JSONExtractIntervalLocus]) { - - def toReferenceGenome: ReferenceGenome = ReferenceGenome(name, contigs.map(_.name), - contigs.map(c => (c.name, c.length)).toMap, xContigs, yContigs, mtContigs, par.map(_.toLocusTuple)) +case class JSONExtractReferenceGenome( + name: String, + contigs: Array[JSONExtractContig], + xContigs: Set[String], + yContigs: Set[String], + mtContigs: Set[String], + par: Array[JSONExtractIntervalLocus], +) { + + def toReferenceGenome: ReferenceGenome = ReferenceGenome( + name, + contigs.map(_.name), + contigs.map(c => (c.name, c.length)).toMap, + xContigs, + yContigs, + mtContigs, + par.map(_.toLocusTuple), + ) } object JSONAnnotationImpex { @@ -83,7 +103,7 @@ object JSONAnnotationImpex { "inf" -> Double.PositiveInfinity, "Infinity" -> Double.PositiveInfinity, "-inf" -> Double.NegativeInfinity, - "-Infinity" -> Double.NegativeInfinity + "-Infinity" -> Double.NegativeInfinity, ) val floatConv = Map( @@ -92,15 +112,16 @@ object JSONAnnotationImpex { "inf" -> Float.PositiveInfinity, "Infinity" -> Float.PositiveInfinity, "-inf" -> Float.NegativeInfinity, - "-Infinity" -> Float.NegativeInfinity + "-Infinity" -> Float.NegativeInfinity, ) - def exportAnnotation(a: Annotation, t: Type): JValue = try { - _exportAnnotation(a, t) - } catch { - case exc: Exception => - fatal(s"Could not export annotation with type $t: $a", exc) - } + def exportAnnotation(a: Annotation, t: Type): JValue = + try + _exportAnnotation(a, t) + catch { + case exc: Exception => + fatal(s"Could not export annotation with type $t: $a", exc) + } def _exportAnnotation(a: Annotation, t: Type): JValue = if (a == null) @@ -123,9 +144,11 @@ object JSONAnnotationImpex { JArray(arr.map(elem => exportAnnotation(elem, elementType)).toList) case TDict(keyType, valueType) => val m = a.asInstanceOf[Map[_, _]] - JArray(m.map { case (k, v) => JObject( - "key" -> exportAnnotation(k, keyType), - "value" -> exportAnnotation(v, valueType)) + JArray(m.map { case (k, v) => + JObject( + "key" -> exportAnnotation(k, keyType), + "value" -> exportAnnotation(v, valueType), + ) }.toList) case TCall => JString(Call.toString(a.asInstanceOf[Call])) case TLocus(_) => a.asInstanceOf[Locus].toJSON @@ -137,37 +160,56 @@ object JSONAnnotationImpex { }) case TTuple(types) => val row = a.asInstanceOf[Row] - JArray(List.tabulate(row.size) { i => exportAnnotation(row.get(i), types(i).typ) }) + JArray(List.tabulate(row.size)(i => exportAnnotation(row.get(i), types(i).typ))) case TNDArray(elementType, _) => val jnd = a.asInstanceOf[NDArray] JObject( "shape" -> JArray(jnd.shape.map(shapeEntry => JInt(shapeEntry)).toList), - "data" -> JArray(jnd.getRowMajorElements().map(a => exportAnnotation(a, elementType)).toList) + "data" -> JArray(jnd.getRowMajorElements().map(a => + exportAnnotation(a, elementType) + ).toList), ) } } def irImportAnnotation(s: String, t: Type, warnContext: mutable.HashSet[String]): Row = { - try { + try // wraps in a Row to handle returned missingness Row(importAnnotation(JsonMethods.parse(s), t, true, warnContext)) - } catch { + catch { case e: Throwable => fatal(s"Error 
parsing JSON:\n type: $t\n value: $s", e) } } - def importAnnotation(jv: JValue, t: Type, padNulls: Boolean = true, warnContext: mutable.HashSet[String] = null): Annotation = - importAnnotationInternal(jv, t, "", padNulls, if (warnContext == null) new mutable.HashSet[String] else warnContext) - - private def importAnnotationInternal(jv: JValue, t: Type, parent: String, padNulls: Boolean, warnContext: mutable.HashSet[String]): Annotation = { - def imp(jv: JValue, t: Type, parent: String): Annotation = importAnnotationInternal(jv, t, parent, padNulls, warnContext) - def warnOnce(msg: String, path: String): Unit = { + def importAnnotation( + jv: JValue, + t: Type, + padNulls: Boolean = true, + warnContext: mutable.HashSet[String] = null, + ): Annotation = + importAnnotationInternal( + jv, + t, + "", + padNulls, + if (warnContext == null) new mutable.HashSet[String] else warnContext, + ) + + private def importAnnotationInternal( + jv: JValue, + t: Type, + parent: String, + padNulls: Boolean, + warnContext: mutable.HashSet[String], + ): Annotation = { + def imp(jv: JValue, t: Type, parent: String): Annotation = + importAnnotationInternal(jv, t, parent, padNulls, warnContext) + def warnOnce(msg: String, path: String): Unit = if (!warnContext.contains(path)) { warn(msg) warnContext += path } - } (jv, t) match { case (JNull | JNothing, _) => null @@ -199,30 +241,37 @@ object JSONAnnotationImpex { case (JArray(arr), TDict(keyType, valueType)) => val keyPath = parent + "[key]" val valuePath = parent + "[value]" - arr.map { case JObject(a) => - a match { - case List(k, v) => - (k, v) match { - case (("key", ka), ("value", va)) => - (imp(ka, keyType, keyPath), imp(va, valueType, valuePath)) - } - case _ => - warnOnce(s"Can't convert JSON value $jv to type $t at $parent.", parent) - null + arr.map { + case JObject(a) => + a match { + case List(k, v) => + (k, v) match { + case (("key", ka), ("value", va)) => + (imp(ka, keyType, keyPath), imp(va, valueType, valuePath)) + } + case _ => + warnOnce(s"Can't convert JSON value $jv to type $t at $parent.", parent) + null - } - case _ => - warnOnce(s"Can't convert JSON value $jv to type $t at $parent.", parent) - null + } + case _ => + warnOnce(s"Can't convert JSON value $jv to type $t at $parent.", parent) + null }.toMap case (JObject(jfields), t: TStruct) => if (t.size == 0) Annotation.empty else { - val annotationSize = - if (padNulls) t.size - else jfields.map { case (name, jv2) => t.selfField(name).map(_.index).getOrElse(-1) }.max + 1 + val annotationSize = if (padNulls) { + t.size + } else if (jfields.size == 0) { + 0 + } else { + jfields.map { case (name, _) => + t.selfField(name).map(_.index).getOrElse(-1) + }.max + 1 + } val a = Array.fill[Any](annotationSize)(null) for ((name, jv2) <- jfields) { @@ -255,22 +304,26 @@ object JSONAnnotationImpex { } case (_, TLocus(_)) => jv.extract[Locus] - case (JObject(List(("shape", shapeJson: JArray), ("data", dataJson: JArray))), t@TNDArray(_, _)) => { - val shapeArray = shapeJson.arr.map(imp(_, TInt64, parent)).map(_.asInstanceOf[Long]).toIndexedSeq + case ( + JObject(List(("shape", shapeJson: JArray), ("data", dataJson: JArray))), + t @ TNDArray(_, _), + ) => + val shapeArray = + shapeJson.arr.map(imp(_, TInt64, parent)).map(_.asInstanceOf[Long]).toIndexedSeq val dataArray = dataJson.arr.map(imp(_, t.elementType, parent)).toIndexedSeq new SafeNDArray(shapeArray, dataArray) - } case (_, TInterval(pointType)) => jv match { case JObject(list) => val m = list.toMap (m.get("start"), m.get("end"), 
m.get("includeStart"), m.get("includeEnd")) match { case (Some(sjv), Some(ejv), Some(isjv), Some(iejv)) => - Interval(imp(sjv, pointType, parent + ".start"), + Interval( + imp(sjv, pointType, parent + ".start"), imp(ejv, pointType, parent + ".end"), imp(isjv, TBoolean, parent + ".includeStart").asInstanceOf[Boolean], - imp(iejv, TBoolean, parent + ".includeEnd").asInstanceOf[Boolean] + imp(iejv, TBoolean, parent + ".includeEnd").asInstanceOf[Boolean], ) case _ => warnOnce(s"Can't convert JSON value $jv to type $t at $parent.", parent) @@ -283,7 +336,9 @@ object JSONAnnotationImpex { case (JString(x), TCall) => Call.parse(x) case (JArray(a), TArray(elementType)) => - a.iterator.map(jv2 => imp(jv2, elementType, parent + "[element]")).toArray[Any]: IndexedSeq[Any] + a.iterator.map(jv2 => imp(jv2, elementType, parent + "[element]")).toArray[Any]: IndexedSeq[ + Any + ] case (JArray(a), TSet(elementType)) => a.iterator.map(jv2 => imp(jv2, elementType, parent + "[element]")).toSet[Any] @@ -312,10 +367,10 @@ object TableAnnotationImpex { case TInterval(TLocus(_)) => val i = a.asInstanceOf[Interval] val bounds = if (i.start.asInstanceOf[Locus].contig == i.end.asInstanceOf[Locus].contig) - s"${ i.start }-${ i.end.asInstanceOf[Locus].position }" + s"${i.start}-${i.end.asInstanceOf[Locus].position}" else - s"${ i.start }-${ i.end }" - s"${ if (i.includesStart) "[" else "(" }$bounds${ if (i.includesEnd) "]" else ")" }" + s"${i.start}-${i.end}" + s"${if (i.includesStart) "[" else "("}$bounds${if (i.includesEnd) "]" else ")"}" case _: TInterval => JsonMethods.compact(t.toJSON(a)) case TCall => Call.toString(a.asInstanceOf[Call]) diff --git a/hail/src/main/scala/is/hail/expr/NatBase.scala b/hail/src/main/scala/is/hail/expr/NatBase.scala index 68d04d451a5..fea44e9caf0 100644 --- a/hail/src/main/scala/is/hail/expr/NatBase.scala +++ b/hail/src/main/scala/is/hail/expr/NatBase.scala @@ -9,14 +9,13 @@ abstract class NatBase { case class Nat(n: Int) extends NatBase { override def toString: String = n.toString - override def clear() {} + override def clear(): Unit = {} - override def unify(concrete: NatBase): Boolean = { + override def unify(concrete: NatBase): Boolean = concrete match { case Nat(cN) => cN == n case _ => false } - } override def subst(): NatBase = this } @@ -24,13 +23,12 @@ case class Nat(n: Int) extends NatBase { case class NatVariable(var nat: NatBase = null) extends NatBase { override def toString: String = "?nat" - override def clear() { nat = null } + override def clear(): Unit = nat = null override def unify(concrete: NatBase): Boolean = { if (nat != null) { nat.unify(concrete) - } - else { + } else { nat = concrete true } diff --git a/hail/src/main/scala/is/hail/expr/Parser.scala b/hail/src/main/scala/is/hail/expr/Parser.scala index cc8ee39bb24..b71ea4961e4 100644 --- a/hail/src/main/scala/is/hail/expr/Parser.scala +++ b/hail/src/main/scala/is/hail/expr/Parser.scala @@ -3,86 +3,85 @@ package is.hail.expr import is.hail.utils._ import is.hail.variant._ -import scala.collection.mutable.ArrayBuffer import scala.util.parsing.combinator.JavaTokenParsers import scala.util.parsing.input.Position class RichParser[T](parser: Parser.Parser[T]) { - def parse(input: String): T = { + def parse(input: String): T = Parser.parseAll(parser, input) match { case Parser.Success(result, _) => result case Parser.NoSuccess(msg, next) => ParserUtils.error(next.pos, msg) } - } - def parseOpt(input: String): Option[T] = { + def parseOpt(input: String): Option[T] = Parser.parseAll(parser, input) match { case 
Parser.Success(result, _) => Some(result) - case Parser.NoSuccess(msg, next) => None + case Parser.NoSuccess(_, _) => None } - } } object ParserUtils { def error(pos: Position, msg: String): Nothing = { val lineContents = pos.longString.split("\n").head - val prefix = s":${ pos.line }:" + val prefix = s":${pos.line}:" fatal( s"""$msg |$prefix$lineContents - |${ " " * prefix.length }${ - lineContents.take(pos.column - 1).map { c => if (c == '\t') c else ' ' } - }^""".stripMargin) + |${" " * prefix.length}${lineContents.take(pos.column - 1).map { c => + if (c == '\t') c else ' ' + }}^""".stripMargin + ) } def error(pos: Position, msg: String, tr: Truncatable): Nothing = { val lineContents = pos.longString.split("\n").head - val prefix = s":${ pos.line }:" + val prefix = s":${pos.line}:" fatal( s"""$msg |$prefix$lineContents - |${ " " * prefix.length }${ - lineContents.take(pos.column - 1).map { c => if (c == '\t') c else ' ' } - }^""".stripMargin, tr) + |${" " * prefix.length}${lineContents.take(pos.column - 1).map { c => + if (c == '\t') c else ' ' + }}^""".stripMargin, + tr, + ) } } object Parser extends JavaTokenParsers { - def parse[T](parser: Parser[T], code: String): T = { + def parse[T](parser: Parser[T], code: String): T = parseAll(parser, code) match { case Success(result, _) => result case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg) } - } def parseLocusInterval(input: String, rg: ReferenceGenome, invalidMissing: Boolean): Interval = { parseAll[Interval](locusInterval(rg, invalidMissing), input) match { case Success(r, _) => r - case NoSuccess(msg, next) => fatal( - s"""invalid interval expression: '$input': $msg - | Acceptable formats: - | CHR:POS-CHR:POS e.g. 1:12345-1:17299 or [5:151111-8:191293] - | An interval from the starting locus (chromosome, position) - | to the ending locus. By default the bounds are left-inclusive, - | right-exclusive, but may be configured by inclusion of square - | brackets ('[' or ']') for open endpoints, or parenthesis ('(' - | or ')') for closed endpoints. The POS field may be the words - | 'START' or 'END' to denote the start or end of the chromosome. - | CHR:POS-POS e.g. 1:14244-912382 - | The same interval as '[1:14244-1:912382)' - | CHR-CHR e.g. '1-22' or 'X-Y' - | The same intervals as '[1:START-22:END'] or '[X:START-Y:END]' - | CHR e.g. '5' or 'X' - | The same intervals as '[5:START-5:END]' or '[X:START-X:END]' """.stripMargin) + case NoSuccess(msg, _) => fatal( + s"""invalid interval expression: '$input': $msg + | Acceptable formats: + | CHR:POS-CHR:POS e.g. 1:12345-1:17299 or [5:151111-8:191293] + | An interval from the starting locus (chromosome, position) + | to the ending locus. By default the bounds are left-inclusive, + | right-exclusive, but may be configured by inclusion of square + | brackets ('[' or ']') for open endpoints, or parenthesis ('(' + | or ')') for closed endpoints. The POS field may be the words + | 'START' or 'END' to denote the start or end of the chromosome. + | CHR:POS-POS e.g. 1:14244-912382 + | The same interval as '[1:14244-1:912382)' + | CHR-CHR e.g. '1-22' or 'X-Y' + | The same intervals as '[1:START-22:END'] or '[X:START-Y:END]' + | CHR e.g. 
'5' or 'X' + | The same intervals as '[5:START-5:END]' or '[X:START-X:END]' """.stripMargin + ) } } - def parseCall(input: String): Call = { + def parseCall(input: String): Call = parseAll[Call](call, input) match { case Success(r, _) => r - case NoSuccess(msg, next) => fatal(s"invalid call expression: '$input': $msg") + case NoSuccess(msg, _) => fatal(s"invalid call expression: '$input': $msg") } - } def oneOfLiteral(a: Array[String]): Parser[String] = new Parser[String] { private[this] val root = ParseTrieNode.generate(a) @@ -139,17 +138,21 @@ object Parser extends JavaTokenParsers { val contig = rg.contigParser val valueParser = - locusUnchecked(rg) ~ "-" ~ rg.contigParser ~ ":" ~ pos ^^ { case l1 ~ _ ~ c2 ~ _ ~ p2 => p2 match { - case Some(p) => (l1, Locus(c2, p), true, false) - case None => (l1, Locus(c2, rg.contigLength(c2)), true, true) - } - } | - locusUnchecked(rg) ~ "-" ~ pos ^^ { case l1 ~ _ ~ p2 => p2 match { - case Some(p) => (l1, l1.copy(position = p), true, false) - case None => (l1, l1.copy(position = rg.contigLength(l1.contig)), true, true) + locusUnchecked(rg) ~ "-" ~ rg.contigParser ~ ":" ~ pos ^^ { case l1 ~ _ ~ c2 ~ _ ~ p2 => + p2 match { + case Some(p) => (l1, Locus(c2, p), true, false) + case None => (l1, Locus(c2, rg.contigLength(c2)), true, true) } + } | + locusUnchecked(rg) ~ "-" ~ pos ^^ { case l1 ~ _ ~ p2 => + p2 match { + case Some(p) => (l1, l1.copy(position = p), true, false) + case None => (l1, l1.copy(position = rg.contigLength(l1.contig)), true, true) + } + } | + contig ~ "-" ~ contig ^^ { case c1 ~ _ ~ c2 => + (Locus(c1, 1), Locus(c2, rg.contigLength(c2)), true, true) } | - contig ~ "-" ~ contig ^^ { case c1 ~ _ ~ c2 => (Locus(c1, 1), Locus(c2, rg.contigLength(c2)), true, true) } | contig ^^ { c => (Locus(c, 1), Locus(c, rg.contigLength(c)), true, true) } intervalWithEndpoints(valueParser) ^^ { i => rg.toLocusInterval(i, invalidMissing) } } @@ -158,13 +161,16 @@ object Parser extends JavaTokenParsers { (rg.contigParser ~ ":" ~ pos) ^^ { case c ~ _ ~ p => Locus(c, p.getOrElse(rg.contigLength(c))) } def locus(rg: ReferenceGenome): Parser[Locus] = - (rg.contigParser ~ ":" ~ pos) ^^ { case c ~ _ ~ p => Locus(c, p.getOrElse(rg.contigLength(c)), rg) } + (rg.contigParser ~ ":" ~ pos) ^^ { case c ~ _ ~ p => + Locus(c, p.getOrElse(rg.contigLength(c)), rg) + } - def coerceInt(s: String): Int = try { - s.toInt - } catch { - case e: java.lang.NumberFormatException => Int.MaxValue - } + def coerceInt(s: String): Int = + try + s.toInt + catch { + case _: java.lang.NumberFormatException => Int.MaxValue + } def exp10(i: Int): Int = { var mult = 1 @@ -181,8 +187,12 @@ object Parser extends JavaTokenParsers { "[Ee][Nn][Dd]".r ^^ { _ => None } | "\\d+".r <~ "[Kk]".r ^^ { i => Some(coerceInt(i) * 1000) } | "\\d+".r <~ "[Mm]".r ^^ { i => Some(coerceInt(i) * 1000000) } | - "\\d+".r ~ "." ~ "\\d{1,3}".r ~ "[Kk]".r ^^ { case lft ~ _ ~ rt ~ _ => Some(coerceInt(lft + rt) * exp10(3 - rt.length)) } | - "\\d+".r ~ "." ~ "\\d{1,6}".r ~ "[Mm]".r ^^ { case lft ~ _ ~ rt ~ _ => Some(coerceInt(lft + rt) * exp10(6 - rt.length)) } | + "\\d+".r ~ "." ~ "\\d{1,3}".r ~ "[Kk]".r ^^ { case lft ~ _ ~ rt ~ _ => + Some(coerceInt(lft + rt) * exp10(3 - rt.length)) + } | + "\\d+".r ~ "." 
~ "\\d{1,6}".r ~ "[Mm]".r ^^ { case lft ~ _ ~ rt ~ _ => + Some(coerceInt(lft + rt) * exp10(6 - rt.length)) + } | "\\d+".r ^^ { i => Some(coerceInt(i)) } } } diff --git a/hail/src/main/scala/is/hail/expr/Validate.scala b/hail/src/main/scala/is/hail/expr/Validate.scala index 7d4d6de24ab..82df3ebbdd3 100644 --- a/hail/src/main/scala/is/hail/expr/Validate.scala +++ b/hail/src/main/scala/is/hail/expr/Validate.scala @@ -1,6 +1,8 @@ package is.hail.expr -import is.hail.expr.ir.{BaseIR, BlockMatrixRead, BlockMatrixWrite, MatrixRead, MatrixWrite, TableRead, TableWrite} +import is.hail.expr.ir.{ + BaseIR, BlockMatrixRead, BlockMatrixWrite, MatrixRead, MatrixWrite, TableRead, TableWrite, +} import is.hail.utils._ case class ValidateState(writeFilePaths: Set[String]) @@ -16,14 +18,14 @@ object Validate { private def validate(ir: BaseIR, state: ValidateState): Unit = { ir match { case tr: TableRead => tr.tr.pathsUsed.foreach { path => - if (state.writeFilePaths.contains(path)) - fileReadWriteError(path) - } + if (state.writeFilePaths.contains(path)) + fileReadWriteError(path) + } case mr: MatrixRead => mr.reader.pathsUsed.foreach { path => - if (state.writeFilePaths.contains(path)) - fileReadWriteError(path) - } - case bmr: BlockMatrixRead => + if (state.writeFilePaths.contains(path)) + fileReadWriteError(path) + } + case _: BlockMatrixRead => case tw: TableWrite => val newState = state.copy(writeFilePaths = state.writeFilePaths + tw.writer.path) validate(tw.child, newState) diff --git a/hail/src/main/scala/is/hail/expr/ir/AbstractMatrixTableSpec.scala b/hail/src/main/scala/is/hail/expr/ir/AbstractMatrixTableSpec.scala index 4f6379201b0..9b12157d5af 100644 --- a/hail/src/main/scala/is/hail/expr/ir/AbstractMatrixTableSpec.scala +++ b/hail/src/main/scala/is/hail/expr/ir/AbstractMatrixTableSpec.scala @@ -5,35 +5,53 @@ import is.hail.rvd._ import is.hail.types._ import is.hail.utils._ import is.hail.variant.ReferenceGenome + +import scala.collection.mutable + +import java.io.{FileNotFoundException, OutputStreamWriter} + import org.json4s._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods.parse -import java.io.OutputStreamWriter -import scala.collection.mutable -import scala.language.{existentials, implicitConversions} - abstract class ComponentSpec object RelationalSpec { - implicit val formats: Formats = new DefaultFormats() { - override val typeHints = ShortTypeHints(List( - classOf[ComponentSpec], classOf[RVDComponentSpec], classOf[PartitionCountsComponentSpec], classOf[PropertiesSpec], - classOf[RelationalSpec], classOf[MatrixTableSpec], classOf[TableSpec]), typeHintFieldName="name") - } + - new TableTypeSerializer + - new MatrixTypeSerializer + implicit val formats: Formats = + new DefaultFormats() { + override val typeHints = ShortTypeHints( + List( + classOf[ComponentSpec], + classOf[RVDComponentSpec], + classOf[PartitionCountsComponentSpec], + classOf[PropertiesSpec], + classOf[RelationalSpec], + classOf[MatrixTableSpec], + classOf[TableSpec], + ), + typeHintFieldName = "name", + ) + } + + new TableTypeSerializer + + new MatrixTypeSerializer def readMetadata(fs: FS, path: String): JValue = { - if (!fs.isDir(path)) { - if (!fs.exists(path)) { - fatal(s"No file or directory found at ${path}") - } else { - fatal(s"MatrixTable and Table files are directories; path '$path' is not a directory") - } - } val metadataFile = path + "/metadata.json.gz" - val jv = using(fs.open(metadataFile)) { in => parse(in) } + val jv = + try + using(fs.open(metadataFile))(in => parse(in)) + catch 
{ + case _: FileNotFoundException => + if (fs.isFile(path)) { + fatal(s"MatrixTable and Table files are directories; path '$path' is a file.") + } else { + if (fs.isDir(path)) { + fatal(s"MatrixTable is corrupted: $path/metadata.json.gz is missing.") + } else { + fatal(s"No file or directory found at $path.") + } + } + } val fileVersion = jv \ "file_version" match { case JInt(rep) => SemanticVersion(rep.toInt) @@ -41,20 +59,22 @@ object RelationalSpec { fatal( s"""cannot read file: metadata does not contain file version: $metadataFile | Common causes: - | - File is an 0.1 VariantDataset or KeyTable (0.1 and 0.2 native formats are not compatible!)""".stripMargin) + | - File is an 0.1 VariantDataset or KeyTable (0.1 and 0.2 native formats are not compatible!)""".stripMargin + ) } if (!FileFormat.version.supports(fileVersion)) - fatal(s"incompatible file format when reading: $path\n supported file format version: ${ FileFormat.version }, found file format version $fileVersion" + - s"\n The cause of this error is usually an attempt to use an older version of Hail to read files " + - s"generated by a newer version. This is not supported (Hail native files are back-compatible, but not forward-compatible)." + - s"\n To read this file, use a newer version of Hail. Note that the file format version and the Hail Python library version are not the same.") + fatal( + s"incompatible file format when reading: $path\n supported file format version: ${FileFormat.version}, found file format version $fileVersion" + + s"\n The cause of this error is usually an attempt to use an older version of Hail to read files " + + s"generated by a newer version. This is not supported (Hail native files are back-compatible, but not forward-compatible)." + + s"\n To read this file, use a newer version of Hail. Note that the file format version and the Hail Python library version are not the same." 
+ ) jv } def read(fs: FS, path: String): RelationalSpec = { val jv = readMetadata(fs, path) - val references = readReferences(fs, path, jv) (jv \ "name").extract[String] match { case "TableSpec" => TableSpec.fromJValue(fs, path, jv) @@ -81,13 +101,17 @@ abstract class RelationalSpec { def getComponent[T <: ComponentSpec](name: String): T = components(name).asInstanceOf[T] - def getOptionalComponent[T <: ComponentSpec](name: String): Option[T] = components.get(name).map(_.asInstanceOf[T]) + def getOptionalComponent[T <: ComponentSpec](name: String): Option[T] = + components.get(name).map(_.asInstanceOf[T]) def globalsComponent: RVDComponentSpec = getComponent[RVDComponentSpec]("globals") - def partitionCounts: Array[Long] = getComponent[PartitionCountsComponentSpec]("partition_counts").counts.toArray + def partitionCounts: Array[Long] = + getComponent[PartitionCountsComponentSpec]("partition_counts").counts.toArray - def isDistinctlyKeyed: Boolean = getOptionalComponent[PropertiesSpec]("properties").flatMap(_.properties.values.get("distinctlyKeyed").map(_.asInstanceOf[Boolean])).getOrElse(false) + def isDistinctlyKeyed: Boolean = getOptionalComponent[PropertiesSpec]("properties").flatMap( + _.properties.values.get("distinctlyKeyed").map(_.asInstanceOf[Boolean]) + ).getOrElse(false) def indexed: Boolean @@ -100,9 +124,9 @@ case class RVDComponentSpec(rel_path: String) extends ComponentSpec { def absolutePath(path: String): String = path + "/" + rel_path private[this] val specCache = mutable.Map.empty[String, AbstractRVDSpec] - def rvdSpec(fs: FS, path: String): AbstractRVDSpec = { + + def rvdSpec(fs: FS, path: String): AbstractRVDSpec = specCache.getOrElseUpdate(path, AbstractRVDSpec.read(fs, absolutePath(path))) - } def indexed(fs: FS, path: String): Boolean = rvdSpec(fs, path).indexed } @@ -135,11 +159,18 @@ abstract class AbstractMatrixTableSpec extends RelationalSpec { object MatrixTableSpec { def fromJValue(fs: FS, path: String, jv: JValue): MatrixTableSpec = { - implicit val formats: Formats = new DefaultFormats() { - override val typeHints = ShortTypeHints(List( - classOf[ComponentSpec], classOf[RVDComponentSpec], classOf[PartitionCountsComponentSpec]), typeHintFieldName = "name") - } + - new MatrixTypeSerializer + implicit val formats: Formats = + new DefaultFormats() { + override val typeHints = ShortTypeHints( + List( + classOf[ComponentSpec], + classOf[RVDComponentSpec], + classOf[PartitionCountsComponentSpec], + ), + typeHintFieldName = "name", + ) + } + + new MatrixTypeSerializer val params = jv.extract[MatrixTableSpecParameters] val globalsSpec = RelationalSpec.read(fs, path + "/globals").asInstanceOf[AbstractTableSpec] @@ -148,11 +179,17 @@ object MatrixTableSpec { val rowsSpec = RelationalSpec.read(fs, path + "/rows").asInstanceOf[AbstractTableSpec] - // some legacy files written as MatrixTableSpec wrote the wrong type to the entries table metadata + /* some legacy files written as MatrixTableSpec wrote the wrong type to the entries table + * metadata */ var entriesSpec = RelationalSpec.read(fs, path + "/entries").asInstanceOf[TableSpec] - entriesSpec = TableSpec(fs, path + "/entries", + entriesSpec = TableSpec( + fs, + path + "/entries", entriesSpec.params.copy( - table_type = TableType(params.matrix_type.entriesRVType, FastSeq(), params.matrix_type.globalType))) + table_type = + TableType(params.matrix_type.entriesRVType, FastSeq(), params.matrix_type.globalType) + ), + ) new MatrixTableSpec(params, globalsSpec, colsSpec, rowsSpec, entriesSpec) } @@ -163,13 +200,15 @@ case 
class MatrixTableSpecParameters( hail_version: String, references_rel_path: String, matrix_type: MatrixType, - components: Map[String, ComponentSpec]) { + components: Map[String, ComponentSpec], +) { - def write(fs: FS, path: String) { + def write(fs: FS, path: String): Unit = using(new OutputStreamWriter(fs.create(path + "/metadata.json.gz"))) { out => - out.write(JsonMethods.compact(decomposeWithName(this, "MatrixTableSpec")(RelationalSpec.formats))) + out.write( + JsonMethods.compact(decomposeWithName(this, "MatrixTableSpec")(RelationalSpec.formats)) + ) } - } } @@ -178,7 +217,8 @@ class MatrixTableSpec( val globalsSpec: AbstractTableSpec, val colsSpec: AbstractTableSpec, val rowsSpec: AbstractTableSpec, - val entriesSpec: AbstractTableSpec) extends AbstractMatrixTableSpec { + val entriesSpec: AbstractTableSpec, +) extends AbstractMatrixTableSpec { def references_rel_path: String = params.references_rel_path def file_version: Int = params.file_version @@ -189,9 +229,8 @@ class MatrixTableSpec( def components: Map[String, ComponentSpec] = params.components - def toJValue: JValue = { + def toJValue: JValue = decomposeWithName(params, "MatrixTableSpec")(RelationalSpec.formats) - } } object FileFormat { diff --git a/hail/src/main/scala/is/hail/expr/ir/AbstractTableSpec.scala b/hail/src/main/scala/is/hail/expr/ir/AbstractTableSpec.scala index e2e164782b7..9332c15e587 100644 --- a/hail/src/main/scala/is/hail/expr/ir/AbstractTableSpec.scala +++ b/hail/src/main/scala/is/hail/expr/ir/AbstractTableSpec.scala @@ -1,15 +1,14 @@ package is.hail.expr.ir -import java.io.OutputStreamWriter - -import is.hail.utils._ -import is.hail.types._ import is.hail.io.fs.FS import is.hail.rvd._ -import org.json4s.jackson.JsonMethods -import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} +import is.hail.types._ +import is.hail.utils._ + +import java.io.OutputStreamWriter -import scala.language.implicitConversions +import org.json4s.{Formats, JValue} +import org.json4s.jackson.JsonMethods object SortOrder { def deserialize(b: Byte): SortOrder = @@ -86,19 +85,20 @@ case class TableSpecParameters( hail_version: String, references_rel_path: String, table_type: TableType, - components: Map[String, ComponentSpec]) { + components: Map[String, ComponentSpec], +) { - def write(fs: FS, path: String) { + def write(fs: FS, path: String): Unit = using(new OutputStreamWriter(fs.create(path + "/metadata.json.gz"))) { out => out.write(JsonMethods.compact(decomposeWithName(this, "TableSpec")(RelationalSpec.formats))) } - } } class TableSpec( val params: TableSpecParameters, val globalsSpec: AbstractRVDSpec, - val rowsSpec: AbstractRVDSpec) extends AbstractTableSpec { + val rowsSpec: AbstractRVDSpec, +) extends AbstractTableSpec { def file_version: Int = params.file_version def hail_version: String = params.hail_version @@ -109,7 +109,6 @@ class TableSpec( def table_type: TableType = params.table_type - def toJValue: JValue = { + def toJValue: JValue = decomposeWithName(params, "TableSpec")(RelationalSpec.formats) - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/AggOp.scala b/hail/src/main/scala/is/hail/expr/ir/AggOp.scala index 81f500a4288..6223aa939a1 100644 --- a/hail/src/main/scala/is/hail/expr/ir/AggOp.scala +++ b/hail/src/main/scala/is/hail/expr/ir/AggOp.scala @@ -1,8 +1,6 @@ package is.hail.expr.ir import is.hail.expr.ir.agg._ -import is.hail.types.TypeWithRequiredness -import is.hail.types.physical._ import is.hail.types.virtual._ import is.hail.utils.FastSeq @@ -13,9 +11,17 @@ object 
AggSignature { case AggSignature(Take(), Seq(n), Seq(_)) => AggSignature(Take(), FastSeq(n), FastSeq(requestedType.asInstanceOf[TArray].elementType)) case AggSignature(ReservoirSample(), Seq(n), Seq(_)) => - AggSignature(ReservoirSample(), FastSeq(n), FastSeq(requestedType.asInstanceOf[TArray].elementType)) + AggSignature( + ReservoirSample(), + FastSeq(n), + FastSeq(requestedType.asInstanceOf[TArray].elementType), + ) case AggSignature(TakeBy(reverse), Seq(n), Seq(_, k)) => - AggSignature(TakeBy(reverse), FastSeq(n), FastSeq(requestedType.asInstanceOf[TArray].elementType, k)) + AggSignature( + TakeBy(reverse), + FastSeq(n), + FastSeq(requestedType.asInstanceOf[TArray].elementType, k), + ) case AggSignature(PrevNonnull(), Seq(), Seq(_)) => AggSignature(PrevNonnull(), FastSeq(), FastSeq(requestedType)) case AggSignature(Densify(), Seq(), Seq(_)) => @@ -27,10 +33,32 @@ object AggSignature { case class AggSignature( op: AggOp, var initOpArgs: Seq[Type], - var seqOpArgs: Seq[Type] + var seqOpArgs: Seq[Type], ) { // only to be used with virtual non-nested signatures on ApplyAggOp and ApplyScanOp - lazy val returnType: Type = Extract.getResultType(this) + lazy val returnType: Type = (op, seqOpArgs) match { + case (Sum(), Seq(t)) => t + case (Product(), Seq(t)) => t + case (Min(), Seq(t)) => t + case (Max(), Seq(t)) => t + case (Count(), _) => TInt64 + case (Take(), Seq(t)) => TArray(t) + case (ReservoirSample(), Seq(t)) => TArray(t) + case (CallStats(), _) => CallStatsState.resultPType.virtualType + case (TakeBy(_), Seq(value, _)) => TArray(value) + case (PrevNonnull(), Seq(t)) => t + case (CollectAsSet(), Seq(t)) => TSet(t) + case (Collect(), Seq(t)) => TArray(t) + case (Densify(), Seq(t)) => t + case (ImputeType(), _) => ImputeTypeState.resultEmitType.virtualType + case (LinearRegression(), _) => + LinearRegressionAggregator.resultPType.virtualType + case (ApproxCDF(), _) => QuantilesAggregator.resultPType.virtualType + case (Downsample(), Seq(_, _, _)) => DownsampleAggregator.resultType + case (NDArraySum(), Seq(t)) => t + case (NDArrayMultiplyAdd(), Seq(a: TNDArray, _)) => a + case _ => throw new UnsupportedExtraction(this.toString) + } } sealed trait AggOp {} diff --git a/hail/src/main/scala/is/hail/expr/ir/ArraySorter.scala b/hail/src/main/scala/is/hail/expr/ir/ArraySorter.scala index ef1109fb6b6..5134d44c19f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ArraySorter.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ArraySorter.scala @@ -2,8 +2,8 @@ package is.hail.expr.ir import is.hail.annotations.Region import is.hail.asm4s._ -import is.hail.types.physical.stypes.interfaces.SIndexableValue import is.hail.types.physical.{PCanonicalArray, PCanonicalDict, PCanonicalSet} +import is.hail.types.physical.stypes.interfaces.SIndexableValue import is.hail.types.virtual.{TArray, TDict, TSet, Type} import is.hail.utils.FastSeq @@ -22,32 +22,53 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) { private[this] def arrayRef(workingArray: Value[Array[_]]): UntypedCodeArray = new UntypedCodeArray(workingArray, array.ti) - def sort(cb: EmitCodeBuilder, region: Value[Region], comparesLessThan: (EmitCodeBuilder, Value[Region], Value[_], Value[_]) => Value[Boolean]): Unit = { + def sort( + cb: EmitCodeBuilder, + region: Value[Region], + comparesLessThan: (EmitCodeBuilder, Value[Region], Value[_], Value[_]) => Value[Boolean], + ): Unit = { - val sortMB = cb.emb.ecb.genEmitMethod("arraySorter_outer", FastSeq[ParamType](classInfo[Region]), UnitInfo) + val sortMB = + 
cb.emb.ecb.genEmitMethod("arraySorter_outer", FastSeq[ParamType](classInfo[Region]), UnitInfo) sortMB.voidWithBuilder { cb => - val newEnd = cb.newLocal[Int]("newEnd", 0) val i = cb.newLocal[Int]("i", 0) val size = cb.newLocal[Int]("size", array.size) - cb.while_(i < size, { - cb.if_(!array.isMissing(i), { - cb.if_(newEnd.cne(i), array.update(cb, newEnd, array.apply(i))) - cb.assign(newEnd, newEnd + 1) - }) - cb.assign(i, i + 1) - }) + cb.while_( + i < size, { + cb.if_( + !array.isMissing(i), { + cb.if_(newEnd.cne(i), array.update(cb, newEnd, array.apply(i))) + cb.assign(newEnd, newEnd + 1) + }, + ) + cb.assign(i, i + 1) + }, + ) cb.assign(i, newEnd) - cb.while_(i < size, { - array.setMissing(cb, i, true) - cb.assign(i, i + 1) - }) + cb.while_( + i < size, { + array.setMissing(cb, i, true) + cb.assign(i, i + 1) + }, + ) // sort elements in [0, newEnd] // merging into B - val mergeMB = cb.emb.ecb.genEmitMethod("arraySorter_merge", FastSeq[ParamType](classInfo[Region], IntInfo, IntInfo, IntInfo, workingArrayInfo, workingArrayInfo), UnitInfo) + val mergeMB = cb.emb.ecb.genEmitMethod( + "arraySorter_merge", + FastSeq[ParamType]( + classInfo[Region], + IntInfo, + IntInfo, + IntInfo, + workingArrayInfo, + workingArrayInfo, + ), + UnitInfo, + ) mergeMB.voidWithBuilder { cb => val r = mergeMB.getCodeParam[Region](1) val begin = mergeMB.getCodeParam[Int](2) @@ -62,33 +83,46 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) { val j = cb.newLocal[Int]("mergemb_j", mid) val k = cb.newLocal[Int]("mergemb_k", i) - cb.while_(k < end, { - - val LtakeFromLeft = CodeLabel() - val LtakeFromRight = CodeLabel() - val Ldone = CodeLabel() - - cb.if_(j < end, { - cb.if_(i >= mid, cb.goto(LtakeFromRight)) - cb.if_(comparesLessThan(cb, r, arrayA.index(cb, j), arrayA.index(cb, i)), cb.goto(LtakeFromRight), cb.goto(LtakeFromLeft)) - }, cb.goto(LtakeFromLeft)) - - cb.define(LtakeFromLeft) - cb += arrayB.update(k, arrayA(i)) - cb.assign(i, i + 1) - cb.goto(Ldone) - - cb.define(LtakeFromRight) - cb += arrayB.update(k, arrayA(j)) - cb.assign(j, j + 1) - cb.goto(Ldone) - - cb.define(Ldone) - cb.assign(k, k + 1) - }) + cb.while_( + k < end, { + + val LtakeFromLeft = CodeLabel() + val LtakeFromRight = CodeLabel() + val Ldone = CodeLabel() + + cb.if_( + j < end, { + cb.if_(i >= mid, cb.goto(LtakeFromRight)) + cb.if_( + comparesLessThan(cb, r, arrayA.index(cb, j), arrayA.index(cb, i)), + cb.goto(LtakeFromRight), + cb.goto(LtakeFromLeft), + ) + }, + cb.goto(LtakeFromLeft), + ) + + cb.define(LtakeFromLeft) + cb += arrayB.update(k, arrayA(i)) + cb.assign(i, i + 1) + cb.goto(Ldone) + + cb.define(LtakeFromRight) + cb += arrayB.update(k, arrayA(j)) + cb.assign(j, j + 1) + cb.goto(Ldone) + + cb.define(Ldone) + cb.assign(k, k + 1) + }, + ) } - val splitMergeMB = cb.emb.ecb.genEmitMethod("arraySorter_splitMerge", FastSeq[ParamType](classInfo[Region], IntInfo, IntInfo, workingArrayInfo, workingArrayInfo), UnitInfo) + val splitMergeMB = cb.emb.ecb.genEmitMethod( + "arraySorter_splitMerge", + FastSeq[ParamType](classInfo[Region], IntInfo, IntInfo, workingArrayInfo, workingArrayInfo), + UnitInfo, + ) splitMergeMB.voidWithBuilder { cb => val r = splitMergeMB.getCodeParam[Region](1) val begin = splitMergeMB.getCodeParam[Int](2) @@ -97,51 +131,67 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) { val arrayB = splitMergeMB.getCodeParam(4)(workingArrayInfo) val arrayA = splitMergeMB.getCodeParam(5)(workingArrayInfo) - cb.if_(end - begin > 1, { - val mid = cb.newLocal[Int]("splitMerge_mid", (begin + end) / 2) + 
cb.if_( + end - begin > 1, { + val mid = cb.newLocal[Int]("splitMerge_mid", (begin + end) / 2) - cb.invokeVoid(splitMergeMB, r, begin, mid, arrayA, arrayB) - cb.invokeVoid(splitMergeMB, r, mid, end, arrayA, arrayB) + cb.invokeVoid(splitMergeMB, cb.this_, r, begin, mid, arrayA, arrayB) + cb.invokeVoid(splitMergeMB, cb.this_, r, mid, end, arrayA, arrayB) - // result goes in A - cb.invokeVoid(mergeMB, r, begin, mid, end, arrayB, arrayA) - }) + // result goes in A + cb.invokeVoid(mergeMB, cb.this_, r, begin, mid, end, arrayB, arrayA) + }, + ) } // these arrays should be allocated once and reused - cb.if_(workingArray1.isNull || arrayRef(workingArray1).length() < newEnd, { - cb.assignAny(workingArray1, Code.newArray(newEnd)(array.ti)) - cb.assignAny(workingArray2, Code.newArray(newEnd)(array.ti)) - }) + cb.if_( + workingArray1.isNull || arrayRef(workingArray1).length() < newEnd, { + cb.assignAny(workingArray1, Code.newArray(newEnd)(array.ti)) + cb.assignAny(workingArray2, Code.newArray(newEnd)(array.ti)) + }, + ) cb.assign(i, 0) - cb.while_(i < newEnd, { - cb += arrayRef(workingArray1).update(i, array(i)) - cb += arrayRef(workingArray2).update(i, array(i)) - cb.assign(i, i + 1) - }) + cb.while_( + i < newEnd, { + cb += arrayRef(workingArray1).update(i, array(i)) + cb += arrayRef(workingArray2).update(i, array(i)) + cb.assign(i, i + 1) + }, + ) // elements are sorted in workingArray2 after calling splitMergeMB - cb.invokeVoid(splitMergeMB, sortMB.getCodeParam[Region](1), const(0), newEnd, workingArray1, workingArray2) + cb.invokeVoid( + splitMergeMB, + cb.this_, + sortMB.getCodeParam[Region](1), + const(0), + newEnd, + workingArray1, + workingArray2, + ) cb.assign(i, 0) - cb.while_(i < newEnd, { - array.update(cb, i, arrayRef(workingArray2)(i)) - cb.assign(i, i + 1) - }) + cb.while_( + i < newEnd, { + array.update(cb, i, arrayRef(workingArray2)(i)) + cb.assign(i, i + 1) + }, + ) } - cb.invokeVoid(sortMB, region) - - + cb.invokeVoid(sortMB, cb.this_, region) } def toRegion(cb: EmitCodeBuilder, t: Type): SIndexableValue = { t match { - case pca: TArray => + case _: TArray => val len = cb.newLocal[Int]("arraysorter_to_region_len", array.size) // fixme element requiredness should be set here - val arrayType = PCanonicalArray(array.elt.loadedSType.storageType().setRequired(this.prunedMissing || array.eltRequired)) + val arrayType = PCanonicalArray( + array.elt.loadedSType.storageType().setRequired(this.prunedMissing || array.eltRequired) + ) arrayType.constructFromElements(cb, r.region, len, deepCopy = false) { (cb, idx) => array.loadFromIndex(cb, r.region, idx) @@ -159,52 +209,69 @@ class ArraySorter(r: EmitRegion, array: StagedArrayBuilder) { val i = cb.newLocal[Int]("i", 0) val n = cb.newLocal[Int]("n", 0) val size = cb.newLocal[Int]("size", array.size) - cb.while_(i < size, { - cb.if_(!array.isMissing(i), { - cb.if_(i.cne(n), - array.update(cb, n, array(i))) - cb.assign(n, n + 1) - }) - cb.assign(i, i + 1) - }) + cb.while_( + i < size, { + cb.if_( + !array.isMissing(i), { + cb.if_(i.cne(n), array.update(cb, n, array(i))) + cb.assign(n, n + 1) + }, + ) + cb.assign(i, i + 1) + }, + ) array.setSize(cb, n) } - def distinctFromSorted(cb: EmitCodeBuilder, region: Value[Region], discardNext: (EmitCodeBuilder, Value[Region], EmitCode, EmitCode) => Code[Boolean]): Unit = { + def distinctFromSorted( + cb: EmitCodeBuilder, + region: Value[Region], + discardNext: (EmitCodeBuilder, Value[Region], EmitCode, EmitCode) => Code[Boolean], + ): Unit = { - val distinctMB = cb.emb.genEmitMethod("distinctFromSorted", 
FastSeq[ParamType](classInfo[Region]), UnitInfo) + val distinctMB = + cb.emb.genEmitMethod("distinctFromSorted", FastSeq[ParamType](classInfo[Region]), UnitInfo) distinctMB.voidWithBuilder { cb => val region = distinctMB.getCodeParam[Region](1) val i = cb.newLocal[Int]("i", 0) val n = cb.newLocal[Int]("n", 0) val size = cb.newLocal[Int]("size", array.size) - cb.while_(i < size, { - cb.assign(i, i + 1) + cb.while_( + i < size, { + cb.assign(i, i + 1) - val LskipLoopBegin = CodeLabel() - val LskipLoopEnd = CodeLabel() - cb.define(LskipLoopBegin) - cb.if_(i >= size, cb.goto(LskipLoopEnd)) - cb.if_(!discardNext(cb, region, - EmitCode.fromI(distinctMB)(cb => array.loadFromIndex(cb, region, n)), - EmitCode.fromI(distinctMB)(cb => array.loadFromIndex(cb, region, i))), - cb.goto(LskipLoopEnd)) - cb.assign(i, i + 1) - cb.goto(LskipLoopBegin) + val LskipLoopBegin = CodeLabel() + val LskipLoopEnd = CodeLabel() + cb.define(LskipLoopBegin) + cb.if_(i >= size, cb.goto(LskipLoopEnd)) + cb.if_( + !discardNext( + cb, + region, + EmitCode.fromI(distinctMB)(cb => array.loadFromIndex(cb, region, n)), + EmitCode.fromI(distinctMB)(cb => array.loadFromIndex(cb, region, i)), + ), + cb.goto(LskipLoopEnd), + ) + cb.assign(i, i + 1) + cb.goto(LskipLoopBegin) - cb.define(LskipLoopEnd) + cb.define(LskipLoopEnd) - cb.assign(n, n + 1) + cb.assign(n, n + 1) - cb.if_(i < size && i.cne(n), { - array.setMissing(cb, n, array.isMissing(i)) - cb.if_(!array.isMissing(n), array.update(cb, n, array(i))) - }) + cb.if_( + i < size && i.cne(n), { + array.setMissing(cb, n, array.isMissing(i)) + cb.if_(!array.isMissing(n), array.update(cb, n, array(i))) + }, + ) - }) + }, + ) array.setSize(cb, n) } - cb.invokeVoid(distinctMB, region) + cb.invokeVoid(distinctMB, cb.this_, region) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala b/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala index 51611bbd13e..fb625679ce7 100644 --- a/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/BaseIR.scala @@ -3,8 +3,8 @@ package is.hail.expr.ir import is.hail.backend.ExecuteContext import is.hail.types.BaseType import is.hail.types.virtual.Type -import is.hail.utils.StackSafe._ import is.hail.utils._ +import is.hail.utils.StackSafe._ abstract class BaseIR { def typ: BaseType @@ -15,12 +15,12 @@ abstract class BaseIR { protected def copy(newChildren: IndexedSeq[BaseIR]): BaseIR - def deepCopy(): this.type = copy(newChildren = childrenSeq.map(_.deepCopy())).asInstanceOf[this.type] + def deepCopy(): this.type = + copy(newChildren = childrenSeq.map(_.deepCopy())).asInstanceOf[this.type] def noSharing(ctx: ExecuteContext): this.type = if (HasIRSharing(ctx)(this)) this.deepCopy() else this - // For use as a boolean flag by IR passes. Each pass uses a different sentinel value to encode // "true" (and anything else is false). 
As long as we maintain the global invariant that no // two passes use the same sentinel value, this allows us to reuse this field across passes @@ -62,18 +62,17 @@ abstract class BaseIR { } } - def forEachChildWithEnv(env: BindingEnv[Type])(f: (BaseIR, BindingEnv[Type]) => Unit): Unit = { + def forEachChildWithEnv(env: BindingEnv[Type])(f: (BaseIR, BindingEnv[Type]) => Unit): Unit = childrenSeq.view.zipWithIndex.foreach { case (child, i) => - val childEnv = ChildBindings(this, i, env) + val childEnv = Bindings(this, i, env) f(child, childEnv) } - } def mapChildrenWithEnv(env: BindingEnv[Type])(f: (BaseIR, BindingEnv[Type]) => BaseIR): BaseIR = { val newChildren = childrenSeq.toArray var res = this for (i <- newChildren.indices) { - val childEnv = ChildBindings(res, i, env) + val childEnv = Bindings(res, i, env) val child = newChildren(i) val newChild = f(child, childEnv) if (!(newChild eq child)) { @@ -84,10 +83,13 @@ abstract class BaseIR { res } - def forEachChildWithEnvStackSafe(env: BindingEnv[Type])(f: (BaseIR, Int, BindingEnv[Type]) => StackFrame[Unit]): StackFrame[Unit] = { + def forEachChildWithEnvStackSafe( + env: BindingEnv[Type] + )( + f: (BaseIR, Int, BindingEnv[Type]) => StackFrame[Unit] + ): StackFrame[Unit] = childrenSeq.view.zipWithIndex.foreachRecur { case (child, i) => - val childEnv = ChildBindings(this, i, env) + val childEnv = Bindings(this, i, env) f(child, i, childEnv) } - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/BinaryOp.scala b/hail/src/main/scala/is/hail/expr/ir/BinaryOp.scala index f6e58347ce0..91be6146c88 100644 --- a/hail/src/main/scala/is/hail/expr/ir/BinaryOp.scala +++ b/hail/src/main/scala/is/hail/expr/ir/BinaryOp.scala @@ -1,10 +1,9 @@ package is.hail.expr.ir import is.hail.asm4s._ -import is.hail.types._ -import is.hail.types.physical.stypes.{SCode, SType, SValue} +import is.hail.types.physical.{typeToTypeInfo, PType} +import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.physical.stypes.interfaces._ -import is.hail.types.physical.{PType, typeToTypeInfo} import is.hail.types.virtual._ import is.hail.utils._ @@ -14,11 +13,19 @@ object BinaryOp { case (FloatingPointDivide(), TInt64, TInt64) => TFloat64 case (FloatingPointDivide(), TFloat32, TFloat32) => TFloat32 case (FloatingPointDivide(), TFloat64, TFloat64) => TFloat64 - case (Add() | Subtract() | Multiply() | RoundToNegInfDivide() | BitAnd() | BitOr() | BitXOr(), TInt32, TInt32) => TInt32 - case (Add() | Subtract() | Multiply() | RoundToNegInfDivide() | BitAnd() | BitOr() | BitXOr(), TInt64, TInt64) => TInt64 + case ( + Add() | Subtract() | Multiply() | RoundToNegInfDivide() | BitAnd() | BitOr() | BitXOr(), + TInt32, + TInt32, + ) => TInt32 + case ( + Add() | Subtract() | Multiply() | RoundToNegInfDivide() | BitAnd() | BitOr() | BitXOr(), + TInt64, + TInt64, + ) => TInt64 case (Add() | Subtract() | Multiply() | RoundToNegInfDivide(), TFloat32, TFloat32) => TFloat32 case (Add() | Subtract() | Multiply() | RoundToNegInfDivide(), TFloat64, TFloat64) => TFloat64 - case (LeftShift() | RightShift() | LogicalRightShift(), t@(TInt32 | TInt64), TInt32) => t + case (LeftShift() | RightShift() | LogicalRightShift(), t @ (TInt32 | TInt64), TInt32) => t } def defaultDivideOp(t: Type): BinaryOp = t match { @@ -80,7 +87,8 @@ object BinaryOp { case Subtract() => ll - rr case Multiply() => ll * rr case FloatingPointDivide() => ll.toD / rr.toD - case RoundToNegInfDivide() => Code.invokeStatic2[Math, Long, Long, Long]("floorDiv", ll, rr) + case RoundToNegInfDivide() => + 
Code.invokeStatic2[Math, Long, Long, Long]("floorDiv", ll, rr) case BitAnd() => ll & rr case BitOr() => ll | rr case BitXOr() => ll ^ rr @@ -94,7 +102,8 @@ object BinaryOp { case Subtract() => ll - rr case Multiply() => ll * rr case FloatingPointDivide() => ll / rr - case RoundToNegInfDivide() => Code.invokeStatic1[Math, Double, Double]("floor", ll.toD / rr.toD).toF + case RoundToNegInfDivide() => + Code.invokeStatic1[Math, Double, Double]("floor", ll.toD / rr.toD).toF case _ => incompatible(lt, rt, op) } case (TFloat64, TFloat64) => @@ -109,8 +118,6 @@ object BinaryOp { case _ => incompatible(lt, rt, op) } case (TBoolean, TBoolean) => - val ll = coerce[Boolean](l) - val rr = coerce[Boolean](r) op match { case _ => incompatible(lt, rt, op) } @@ -122,7 +129,7 @@ object BinaryOp { case "+" | "Add" => Add() case "-" | "Subtract" => Subtract() case "*" | "Multiply" => Multiply() - case "/" | "FloatingPointDivide" => FloatingPointDivide() + case "/" | "FloatingPointDivide" => FloatingPointDivide() case "//" | "RoundToNegInfDivide" => RoundToNegInfDivide() case "|" | "BitOr" => BitOr() case "&" | "BitAnd" => BitAnd() diff --git a/hail/src/main/scala/is/hail/expr/ir/BinarySearch.scala b/hail/src/main/scala/is/hail/expr/ir/BinarySearch.scala index e43edc13ef6..146502af3ee 100644 --- a/hail/src/main/scala/is/hail/expr/ir/BinarySearch.scala +++ b/hail/src/main/scala/is/hail/expr/ir/BinarySearch.scala @@ -6,90 +6,112 @@ import is.hail.types.physical.stypes._ import is.hail.types.physical.stypes.interfaces._ import is.hail.utils.FastSeq -import scala.language.existentials - object BinarySearch { object Comparator { def fromLtGt( ltNeedle: IEmitCode => Code[Boolean], - gtNeedle: IEmitCode => Code[Boolean] + gtNeedle: IEmitCode => Code[Boolean], ): Comparator = new Comparator { - def apply(cb: EmitCodeBuilder, elt: IEmitCode, ifLtNeedle: => Unit, ifGtNeedle: => Unit, ifNeither: => Unit): Unit = { + def apply( + cb: EmitCodeBuilder, + elt: IEmitCode, + ifLtNeedle: => Unit, + ifGtNeedle: => Unit, + ifNeither: => Unit, + ): Unit = { val eltVal = cb.memoize(elt) - cb.if_(ltNeedle(eltVal.loadI(cb)), + cb.if_( + ltNeedle(eltVal.loadI(cb)), ifLtNeedle, - cb.if_(gtNeedle(eltVal.loadI(cb)), - ifGtNeedle, - ifNeither)) + cb.if_(gtNeedle(eltVal.loadI(cb)), ifGtNeedle, ifNeither), + ) } } def fromCompare(compare: IEmitCode => Value[Int]): Comparator = new Comparator { - def apply(cb: EmitCodeBuilder, elt: IEmitCode, ifLtNeedle: => Unit, ifGtNeedle: => Unit, ifNeither: => Unit): Unit = { + def apply( + cb: EmitCodeBuilder, + elt: IEmitCode, + ifLtNeedle: => Unit, + ifGtNeedle: => Unit, + ifNeither: => Unit, + ): Unit = { val c = cb.memoize(compare(elt)) - cb.if_(c < 0, - ifLtNeedle, - cb.if_(c > 0, - ifGtNeedle, - ifNeither)) + cb.if_(c < 0, ifLtNeedle, cb.if_(c > 0, ifGtNeedle, ifNeither)) } } def fromPred(pred: IEmitCode => Code[Boolean]): Comparator = new Comparator { - def apply(cb: EmitCodeBuilder, elt: IEmitCode, ifLtNeedle: => Unit, ifGtNeedle: => Unit, ifNeither: => Unit): Unit = { + def apply( + cb: EmitCodeBuilder, + elt: IEmitCode, + ifLtNeedle: => Unit, + ifGtNeedle: => Unit, + ifNeither: => Unit, + ): Unit = cb.if_(pred(elt), ifGtNeedle, ifLtNeedle) - } } } - /** Represents a discriminator of values of some type into one of three - * mutually exclusive categories: - * - less than the needle (whatever is being searched for) - * - greater than the needle - * - neither (interpretation depends on the application, e.g. 
"equals needle", - * "contains needle", "contained in needle") + /** Represents a discriminator of values of some type into one of three mutually exclusive + * categories: + * - less than the needle (whatever is being searched for) + * - greater than the needle + * - neither (interpretation depends on the application, e.g. "equals needle", "contains + * needle", "contained in needle") */ abstract class Comparator { - def apply(cb: EmitCodeBuilder, elt: IEmitCode, ltNeedle: => Unit, gtNeedle: => Unit, neither: => Unit): Unit + def apply( + cb: EmitCodeBuilder, + elt: IEmitCode, + ltNeedle: => Unit, + gtNeedle: => Unit, + neither: => Unit, + ): Unit } - /** Returns true if haystack contains an element x such that !ltNeedle(x) - * and !gtNeedle(x), false otherwise. + /** Returns true if haystack contains an element x such that !ltNeedle(x) and !gtNeedle(x), false + * otherwise. */ def containsOrdered( cb: EmitCodeBuilder, haystack: SIndexableValue, ltNeedle: IEmitCode => Code[Boolean], - gtNeedle: IEmitCode => Code[Boolean] + gtNeedle: IEmitCode => Code[Boolean], ): Value[Boolean] = containsOrdered(cb, haystack, Comparator.fromLtGt(ltNeedle, gtNeedle)) - /** Returns true if haystack contains an element x such that !lt(x, needle) - * and !lt(needle, x), false otherwise. + /** Returns true if haystack contains an element x such that !lt(x, needle) and !lt(needle, x), + * false otherwise. */ def containsOrdered( cb: EmitCodeBuilder, haystack: SIndexableValue, needle: EmitValue, lt: (IEmitCode, IEmitCode) => Code[Boolean], - key: IEmitCode => IEmitCode + key: IEmitCode => IEmitCode, ): Value[Boolean] = - containsOrdered(cb, haystack, x => lt(key(x), needle.loadI(cb)), x => lt(needle.loadI(cb), key(x))) + containsOrdered( + cb, + haystack, + x => lt(key(x), needle.loadI(cb)), + x => lt(needle.loadI(cb), key(x)), + ) def containsOrdered( cb: EmitCodeBuilder, haystack: SIndexableValue, - compare: Comparator + compare: Comparator, ): Value[Boolean] = runSearch[Boolean](cb, haystack, compare, (_, _, _) => true, (_) => false) /** Returns (l, u) such that - * - range [0, l) is < needle - * - range [l, u) is incomparable ("equal") to needle - * - range [u, size) is > needle + * - range [0, l) is < needle + * - range [l, u) is incomparable ("equal") to needle + * - range [u, size) is > needle * - * Assumes comparator separates haystack into < needle, followed by - * incomparable to needle, followed by > needle. + * Assumes comparator separates haystack into < needle, followed by incomparable to needle, + * followed by > needle. 
*/ def equalRange( cb: EmitCodeBuilder, @@ -98,36 +120,41 @@ object BinarySearch { ltNeedle: IEmitCode => Code[Boolean], gtNeedle: IEmitCode => Code[Boolean], start: Value[Int], - end: Value[Int] + end: Value[Int], ): (Value[Int], Value[Int]) = { val l = cb.newLocal[Int]("equalRange_l") val u = cb.newLocal[Int]("equalRange_u") - runSearchBoundedUnit(cb, haystack, compare, start, end, + runSearchBoundedUnit( + cb, + haystack, + compare, + start, + end, (curL, m, curU) => { // [start, curL) is < needle // [start, m] is <= needle // [m, end) is >= needle // [curR, end) is > needle - cb.assign(l, - lowerBound(cb, haystack, ltNeedle, curL, m)) + cb.assign(l, lowerBound(cb, haystack, ltNeedle, curL, m)) // [curL, l) is < needle // [l, m) is >= needle - cb.assign(u, - upperBound(cb, haystack, gtNeedle, cb.memoize(m + 1), curU)) + cb.assign(u, upperBound(cb, haystack, gtNeedle, cb.memoize(m + 1), curU)) // [m+1, u) is <= needle // [u, curU) is > needle - }, m => { + }, + m => { // [start, m) is < needle // [m, end) is > needle cb.assign(l, m) cb.assign(u, m) - }) + }, + ) (l, u) } /** Returns i in ['start', 'end'] such that - * - range [start, i) is < needle - * - range [i, end) is >= needle + * - range [start, i) is < needle + * - range [i, end) is >= needle * * Assumes ltNeedle is down-closed, i.e. all trues precede all falses */ @@ -136,20 +163,20 @@ object BinarySearch { haystack: SIndexableValue, ltNeedle: IEmitCode => Code[Boolean], start: Value[Int], - end: Value[Int] + end: Value[Int], ): Value[Int] = partitionPoint(cb, haystack, x => !ltNeedle(x), start, end) def lowerBound( cb: EmitCodeBuilder, haystack: SIndexableValue, - ltNeedle: IEmitCode => Code[Boolean] + ltNeedle: IEmitCode => Code[Boolean], ): Value[Int] = lowerBound(cb, haystack, ltNeedle, 0, haystack.loadLength()) /** Returns i in ['start', 'end'] such that - * - range [start, i) is <= needle - * - range [i, end) is > needle + * - range [start, i) is <= needle + * - range [i, end) is > needle * * Assumes gtNeedle is up-closed, i.e. all falses precede all trues */ @@ -158,58 +185,63 @@ object BinarySearch { haystack: SIndexableValue, gtNeedle: IEmitCode => Code[Boolean], start: Value[Int], - end: Value[Int] + end: Value[Int], ): Value[Int] = partitionPoint(cb, haystack, gtNeedle, start, end) def upperBound( cb: EmitCodeBuilder, haystack: SIndexableValue, - gtNeedle: IEmitCode => Code[Boolean] + gtNeedle: IEmitCode => Code[Boolean], ): Value[Int] = lowerBound(cb, haystack, gtNeedle, 0, haystack.loadLength()) /** Returns 'start' <= i <= 'end' such that - * - pred is false on range [start, i), and - * - pred is true on range [i, end). + * - pred is false on range [start, i), and + * - pred is true on range [i, end). * - * Assumes pred partitions a, i.e. for all 0 <= i <= j < haystack.size, - * if pred(i) then pred(j), i.e. all falses precede all trues. + * Assumes pred partitions a, i.e. for all 0 <= i <= j < haystack.size, if pred(i) then pred(j), + * i.e. all falses precede all trues. 
*/ def partitionPoint( cb: EmitCodeBuilder, haystack: SIndexableValue, pred: IEmitCode => Code[Boolean], start: Value[Int], - end: Value[Int] + end: Value[Int], ): Value[Int] = { var i: Value[Int] = null - runSearchBoundedUnit(cb, haystack, Comparator.fromPred(pred), start, end, + runSearchBoundedUnit( + cb, + haystack, + Comparator.fromPred(pred), + start, + end, (_, _, _) => {}, // unreachable - _i => i = _i) + _i => i = _i, + ) i } def partitionPoint( cb: EmitCodeBuilder, haystack: SIndexableValue, - pred: IEmitCode => Code[Boolean] + pred: IEmitCode => Code[Boolean], ): Value[Int] = partitionPoint(cb, haystack, pred, const(0), haystack.loadLength()) /** Perform binary search until either - * - an index m is found for which haystack(i) is incomparable with the needle, - * i.e. neither ltNeedle(m) nor gtNeedle(m). - * In this case, call found(l, m, u), where - * - haystack(m) is incomparable to needle - * - range [start, l) is < needle - * - range [r, end) is > needle - * - it is certain that no such m exists. In this case, call notFound(j), where - * - range [start, j) is < needle - * - range [j, end) is > needle + * - an index m is found for which haystack(i) is incomparable with the needle, i.e. neither + * ltNeedle(m) nor gtNeedle(m). In this case, call found(l, m, u), where + * - haystack(m) is incomparable to needle + * - range [start, l) is < needle + * - range [r, end) is > needle + * - it is certain that no such m exists. In this case, call notFound(j), where + * - range [start, j) is < needle + * - range [j, end) is > needle * - * Assumes comparator separates haystack into < needle, followed by - * incomparable to needle, followed by > needle. + * Assumes comparator separates haystack into < needle, followed by incomparable to needle, + * followed by > needle. 
*/ private def runSearchBoundedUnit( cb: EmitCodeBuilder, @@ -218,7 +250,7 @@ object BinarySearch { start: Value[Int], end: Value[Int], found: (Value[Int], Value[Int], Value[Int]) => Unit, - notFound: Value[Int] => Unit + notFound: Value[Int] => Unit, ): Unit = { val left = cb.newLocal[Int]("left", start) val right = cb.newLocal[Int]("right", end) @@ -228,38 +260,32 @@ object BinarySearch { // - left <= right // terminates b/c (right - left) strictly decreases each iteration cb.loop { recur => - cb.if_(left < right, { - val mid = cb.memoize((left + right) >>> 1) // works even when sum overflows - compare(cb, haystack.loadElement(cb, mid), { - // range [start, mid] is < needle - cb.assign(left, mid + 1) - cb.goto(recur) - }, { - // range [mid, end) is > needle - cb.assign(right, mid) - cb.goto(recur) - }, { - // haystack(mid) is incomparable to needle - found(left, mid, right) - }) - }, { + cb.if_( + left < right, { + val mid = cb.memoize((left + right) >>> 1) // works even when sum overflows + compare( + cb, + haystack.loadElement(cb, mid), { + // range [start, mid] is < needle + cb.assign(left, mid + 1) + cb.goto(recur) + }, { + // range [mid, end) is > needle + cb.assign(right, mid) + cb.goto(recur) + }, + // haystack(mid) is incomparable to needle + found(left, mid, right), + ) + }, // now loop invariants hold, with left = right, so // - range [start, left) is < needle // - range [left, end) is > needle - notFound(left) - }) + notFound(left), + ) } } - private def runSearchUnit( - cb: EmitCodeBuilder, - haystack: SIndexableValue, - compare: Comparator, - found: (Value[Int], Value[Int], Value[Int]) => Unit, - notFound: Value[Int] => Unit - ): Unit = - runSearchBoundedUnit(cb, haystack, compare, 0, haystack.loadLength(), found, notFound) - private def runSearchBounded[T: TypeInfo]( cb: EmitCodeBuilder, haystack: SIndexableValue, @@ -267,12 +293,18 @@ object BinarySearch { start: Value[Int], end: Value[Int], found: (Value[Int], Value[Int], Value[Int]) => Code[T], - notFound: Value[Int] => Code[T] + notFound: Value[Int] => Code[T], ): Value[T] = { val ret = cb.newLocal[T]("runSearch_ret") - runSearchBoundedUnit(cb, haystack, compare, start, end, + runSearchBoundedUnit( + cb, + haystack, + compare, + start, + end, (l, m, r) => cb.assign(ret, found(l, m, r)), - i => cb.assign(ret, notFound(i))) + i => cb.assign(ret, notFound(i)), + ) ret } @@ -281,18 +313,25 @@ object BinarySearch { haystack: SIndexableValue, compare: Comparator, found: (Value[Int], Value[Int], Value[Int]) => Code[T], - notFound: Value[Int] => Code[T] + notFound: Value[Int] => Code[T], ): Value[T] = runSearchBounded[T](cb, haystack, compare, 0, haystack.loadLength(), found, notFound) } -class BinarySearch[C](mb: EmitMethodBuilder[C], +class BinarySearch[C]( + mb: EmitMethodBuilder[C], containerType: SContainer, eltType: EmitType, getKey: (EmitCodeBuilder, EmitValue) => EmitValue, - bound: String = "lower") { + bound: String = "lower", +) { val containerElementType: EmitType = containerType.elementEmitType - val findElt = mb.genEmitMethod("findElt", FastSeq[ParamType](containerType.paramType, eltType.paramType), typeInfo[Int]) + + val findElt = mb.genEmitMethod( + "findElt", + FastSeq[ParamType](containerType.paramType, eltType.paramType), + typeInfo[Int], + ) // Returns i in [0, n] such that a(j) < key for j in [0, i), and a(j) >= key // for j in [i, n) @@ -301,30 +340,33 @@ class BinarySearch[C](mb: EmitMethodBuilder[C], val needle = findElt.getEmitParam(cb, 2) val f: ( - EmitCodeBuilder, - SIndexableValue, - IEmitCode => 
Code[Boolean] + EmitCodeBuilder, + SIndexableValue, + IEmitCode => Code[Boolean], ) => Value[Int] = bound match { case "upper" => BinarySearch.upperBound case "lower" => BinarySearch.lowerBound } - f(cb, haystack, { containerElement => - val elementVal = cb.memoize(containerElement, "binary_search_elt") - val compareVal = getKey(cb, elementVal) - bound match { - case "upper" => - val gt = mb.ecb.getOrderingFunction(compareVal.st, eltType.st, CodeOrdering.Gt()) - gt(cb, compareVal, needle) - case "lower" => - val lt = mb.ecb.getOrderingFunction(compareVal.st, eltType.st, CodeOrdering.Lt()) - lt(cb, compareVal, needle) - } - }) + f( + cb, + haystack, + { containerElement => + val elementVal = cb.memoize(containerElement, "binary_search_elt") + val compareVal = getKey(cb, elementVal) + bound match { + case "upper" => + val gt = mb.ecb.getOrderingFunction(compareVal.st, eltType.st, CodeOrdering.Gt()) + gt(cb, compareVal, needle) + case "lower" => + val lt = mb.ecb.getOrderingFunction(compareVal.st, eltType.st, CodeOrdering.Lt()) + lt(cb, compareVal, needle) + } + }, + ) } // check missingness of v before calling - def search(cb: EmitCodeBuilder, array: SValue, v: EmitCode): Value[Int] = { - cb.memoize(cb.invokeCode[Int](findElt, array, v)) - } + def search(cb: EmitCodeBuilder, array: SValue, v: EmitCode): Value[Int] = + cb.memoize(cb.invokeCode[Int](findElt, cb.this_, array, v)) } diff --git a/hail/src/main/scala/is/hail/expr/ir/Binds.scala b/hail/src/main/scala/is/hail/expr/ir/Binds.scala index 60ad922da29..6a2c8fb752d 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Binds.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Binds.scala @@ -2,249 +2,417 @@ package is.hail.expr.ir import is.hail.types.tcoerce import is.hail.types.virtual._ -import is.hail.utils._ +import is.hail.types.virtual.TIterable.elementType -object Binds { - def apply(x: IR, v: String, i: Int): Boolean = Bindings(x, i).exists(_._1 == v) +object SegregatedBindingEnv { + def apply[A, B](env: BindingEnv[A]): SegregatedBindingEnv[A, B] = + SegregatedBindingEnv(env, env.dropBindings) } -object Bindings { - private val empty: Array[(String, Type)] = Array() - - // A call to Bindings(x, i) may only query the types of children with - // index < i - def apply(x: BaseIR, i: Int): Iterable[(String, Type)] = x match { - case Let(bindings, _) => - val result = Array.ofDim[(String, Type)](i) - for (k <- 0 until i) result(k) = bindings(k)._1 -> bindings(k)._2.typ - result - case TailLoop(name, args, resultType, _) => if (i == args.length) - args.map { case (name, ir) => name -> ir.typ } :+ - name -> TTuple(TTuple(args.map(_._2.typ): _*), resultType) else empty - case StreamMap(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty - case StreamZip(as, names, _, _, _) => if (i == as.length) names.zip(as.map(a => tcoerce[TStream](a.typ).elementType)) else empty - case StreamZipJoin(as, key, curKey, curVals, _) => - val eltType = tcoerce[TStruct](tcoerce[TStream](as.head.typ).elementType) - if (i == as.length) - Array(curKey -> eltType.typeAfterSelectNames(key), - curVals -> TArray(eltType)) - else - empty - case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, _) => - if (i == 1) { - val contextType = TIterable.elementType(contexts.typ) - Array(ctxName -> contextType) - } else if (i == 2) { - val eltType = tcoerce[TStruct](tcoerce[TStream](makeProducer.typ).elementType) - Array(curKey -> eltType.typeAfterSelectNames(key), - curVals -> TArray(eltType)) - } else - empty - case 
StreamFor(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty - case StreamFlatMap(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty - case StreamFilter(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty - case StreamTakeWhile(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty - case StreamDropWhile(a, name, _) => if (i == 1) Array(name -> tcoerce[TStream](a.typ).elementType) else empty - case StreamFold(a, zero, accumName, valueName, _) => if (i == 2) Array(accumName -> zero.typ, valueName -> tcoerce[TStream](a.typ).elementType) else empty - case StreamFold2(a, accum, valueName, seq, result) => - if (i <= accum.length) - empty - else if (i < 2 * accum.length + 1) - Array((valueName, tcoerce[TStream](a.typ).elementType)) ++ accum.map { case (name, value) => (name, value.typ) } - else - accum.map { case (name, value) => (name, value.typ) } - case StreamBufferedAggregate(stream, _, _, _, name, _, _) => if (i > 0) Array(name -> tcoerce[TStream](stream.typ).elementType) else empty - case RunAggScan(a, name, _, _, _, _) => if (i == 2 || i == 3) Array(name -> tcoerce[TStream](a.typ).elementType) else empty - case StreamScan(a, zero, accumName, valueName, _) => if (i == 2) Array(accumName -> zero.typ, valueName -> tcoerce[TStream](a.typ).elementType) else empty - case StreamAggScan(a, name, _) => if (i == 1) FastSeq(name -> a.typ.asInstanceOf[TStream].elementType) else empty - case StreamJoinRightDistinct(ll, rr, _, _, l, r, _, _) => if (i == 2) Array(l -> tcoerce[TStream](ll.typ).elementType, r -> tcoerce[TStream](rr.typ).elementType) else empty - case ArraySort(a, left, right, _) => if (i == 1) Array(left -> tcoerce[TStream](a.typ).elementType, right -> tcoerce[TStream](a.typ).elementType) else empty - case ArrayMaximalIndependentSet(a, Some((left, right, _))) => - if (i == 1) { - val typ = tcoerce[TArray](a.typ).elementType.asInstanceOf[TBaseStruct].types.head - val tupleType = TTuple(typ) - Array(left -> tupleType, right -> tupleType) - } else { - empty - } - case AggArrayPerElement(a, _, indexName, _, _, _) => if (i == 1) FastSeq(indexName -> TInt32) else empty - case AggFold(zero, seqOp, combOp, accumName, otherAccumName, _) => { - if (i == 1) FastSeq(accumName -> zero.typ) - else if (i == 2) FastSeq(accumName -> zero.typ, otherAccumName -> zero.typ) - else empty - } - case NDArrayMap(nd, name, _) => if (i == 1) Array(name -> tcoerce[TNDArray](nd.typ).elementType) else empty - case NDArrayMap2(l, r, lName, rName, _, _) => if (i == 2) Array(lName -> tcoerce[TNDArray](l.typ).elementType, rName -> tcoerce[TNDArray](r.typ).elementType) else empty - case CollectDistributedArray(contexts, globals, cname, gname, _, _, _, _) => if (i == 2) Array(cname -> tcoerce[TStream](contexts.typ).elementType, gname -> globals.typ) else empty - case TableAggregate(child, _) => if (i == 1) child.typ.globalEnv.m else empty - case MatrixAggregate(child, _) => if (i == 1) child.typ.globalEnv.m else empty - case TableFilter(child, _) => if (i == 1) child.typ.rowEnv.m else empty - case TableGen(contexts, globals, cname, gname, _, _, _) => - if (i == 2) Array(cname -> TIterable.elementType(contexts.typ), gname -> globals.typ) - else empty - case TableMapGlobals(child, _) => if (i == 1) child.typ.globalEnv.m else empty - case TableMapRows(child, _) => if (i == 1) child.typ.rowEnv.m else empty - case TableAggregateByKey(child, _) => if (i == 1) child.typ.globalEnv.m else 
empty - case TableKeyByAndAggregate(child, _, _, _, _) => if (i == 1) child.typ.globalEnv.m else if (i == 2) child.typ.rowEnv.m else empty - case TableMapPartitions(child, g, p, _, _, _) => if (i == 1) Array(g -> child.typ.globalType, p -> TStream(child.typ.rowType)) else empty - case MatrixMapRows(child, _) => if (i == 1) child.typ.rowEnv.bind("n_cols", TInt32).m else empty - case MatrixFilterRows(child, _) => if (i == 1) child.typ.rowEnv.m else empty - case MatrixMapCols(child, _, _) => if (i == 1) child.typ.colEnv.bind("n_rows", TInt64).m else empty - case MatrixFilterCols(child, _) => if (i == 1) child.typ.colEnv.m else empty - case MatrixMapEntries(child, _) => if (i == 1) child.typ.entryEnv.m else empty - case MatrixFilterEntries(child, _) => if (i == 1) child.typ.entryEnv.m else empty - case MatrixMapGlobals(child, _) => if (i == 1) child.typ.globalEnv.m else empty - case MatrixAggregateColsByKey(child, _, _) => if (i == 1) child.typ.rowEnv.m else if (i == 2) child.typ.globalEnv.m else empty - case MatrixAggregateRowsByKey(child, _, _) => if (i == 1) child.typ.colEnv.m else if (i == 2) child.typ.globalEnv.m else empty - case BlockMatrixMap(_, eltName, _, _) => if (i == 1) Array(eltName -> TFloat64) else empty - case BlockMatrixMap2(_, _, lName, rName, _, _) => if (i == 2) Array(lName -> TFloat64, rName -> TFloat64) else empty - case _ => empty - } -} +case class SegregatedBindingEnv[A, B]( + childEnvWithoutBindings: BindingEnv[A], + newBindings: BindingEnv[B], +) extends GenericBindingEnv[SegregatedBindingEnv[A, B], B] { + def unified(implicit ev: BindingEnv[B] =:= BindingEnv[A]): BindingEnv[A] = + childEnvWithoutBindings.merge(newBindings) + def mapNewBindings[C](f: (String, B) => C): SegregatedBindingEnv[A, C] = SegregatedBindingEnv( + childEnvWithoutBindings, + newBindings.mapValuesWithKey(f), + ) -object AggBindings { + override def promoteAgg: SegregatedBindingEnv[A, B] = SegregatedBindingEnv( + childEnvWithoutBindings.promoteAgg, + newBindings.promoteAgg, + ) - def apply(x: BaseIR, i: Int, parent: BindingEnv[_]): Option[Iterable[(String, Type)]] = { - def wrapped(bindings: Iterable[(String, Type)]): Option[Iterable[(String, Type)]] = { - if (parent.agg.isEmpty) - throw new RuntimeException(s"aggEnv was None for child $i of $x") - Some(bindings) - } + override def promoteScan: SegregatedBindingEnv[A, B] = SegregatedBindingEnv( + childEnvWithoutBindings.promoteScan, + newBindings.promoteScan, + ) - def base: Option[Iterable[(String, Type)]] = parent.agg.map(_ => FastSeq()) - - x match { - case AggLet(name, value, _, false) => if (i == 1) wrapped(FastSeq(name -> value.typ)) else None - case AggFilter(_, _, false) => if (i == 0) None else base - case AggGroupBy(_, _, false) => if (i == 0) None else base - case AggExplode(a, name, _, false) => if (i == 1) wrapped(FastSeq(name -> a.typ.asInstanceOf[TIterable].elementType)) else None - case AggArrayPerElement(a, elementName, indexName, _, _, false) => if (i == 1) wrapped(FastSeq(elementName -> a.typ.asInstanceOf[TIterable].elementType, indexName -> TInt32)) else if (i == 2) base else None - case StreamAgg(a, name, _) => if (i == 1) Some(FastSeq(name -> a.typ.asInstanceOf[TIterable].elementType)) else base - case TableAggregate(child, _) => if (i == 1) Some(child.typ.rowEnv.m) else None - case MatrixAggregate(child, _) => if (i == 1) Some(child.typ.entryEnv.m) else None - case RelationalLet(_, _, _) => None - case CollectDistributedArray(_, _, _, _, _, _, _, _) if (i == 2) => None - case _: ApplyAggOp => None - case AggFold(_, _, _, 
_, _, false) => None - case _: IR => base - - case TableAggregateByKey(child, _) => if (i == 1) Some(child.typ.rowEnv.m) else None - case TableKeyByAndAggregate(child, _, _, _, _) => if (i == 1) Some(child.typ.rowEnv.m) else None - case _: TableIR => None - - case MatrixMapRows(child, _) => if (i == 1) Some(child.typ.entryEnv.m) else None - case MatrixMapCols(child, _, _) => if (i == 1) Some(child.typ.entryEnv.m) else None - case MatrixAggregateColsByKey(child, _, _) => if (i == 1) Some(child.typ.entryEnv.m) else if (i == 2) Some(child.typ.colEnv.m) else None - case MatrixAggregateRowsByKey(child, _, _) => if (i == 1) Some(child.typ.entryEnv.m) else if (i == 2) Some(child.typ.rowEnv.m) else None - case _: MatrixIR => None - - case _: BlockMatrixIR => None + override def bindEval(bindings: (String, B)*): SegregatedBindingEnv[A, B] = + copy(newBindings = newBindings.bindEval(bindings: _*)) - } - } -} + override def dropEval: SegregatedBindingEnv[A, B] = SegregatedBindingEnv( + childEnvWithoutBindings.copy(eval = Env.empty), + newBindings.copy(eval = Env.empty), + ) -object ScanBindings { - def apply(x: BaseIR, i: Int, parent: BindingEnv[_]): Option[Iterable[(String, Type)]] = { - def wrapped(bindings: Iterable[(String, Type)]): Option[Iterable[(String, Type)]] = { - if (parent.scan.isEmpty) - throw new RuntimeException(s"scanEnv was None for child $i of $x") - Some(bindings) - } + override def bindAgg(bindings: (String, B)*): SegregatedBindingEnv[A, B] = + copy(newBindings = newBindings.bindAgg(bindings: _*)) - def base: Option[Iterable[(String, Type)]] = parent.scan.map(_ => FastSeq()) - - x match { - case AggLet(name, value, _, true) => if (i == 1) wrapped(FastSeq(name -> value.typ)) else None - case AggFilter(_, _, true) => if (i == 0) None else base - case AggGroupBy(_, _, true) => if (i == 0) None else base - case AggExplode(a, name, _, true) => if (i == 1) wrapped(FastSeq(name -> a.typ.asInstanceOf[TIterable].elementType)) else None - case AggArrayPerElement(a, elementName, indexName, _, _, true) => if (i == 1) wrapped(FastSeq(elementName -> a.typ.asInstanceOf[TIterable].elementType, indexName -> TInt32)) else if (i == 2) base else None - case AggFold(_, _, _, _, _, true) => None - case StreamAggScan(a, name, _) => if (i == 1) Some(FastSeq(name -> a.typ.asInstanceOf[TIterable].elementType)) else base - case TableAggregate(_, _) => None - case MatrixAggregate(_, _) => None - case RelationalLet(_, _, _) => None - case CollectDistributedArray(_, _, _, _, _, _, _, _) if (i == 2) => None - case _: ApplyScanOp => None - case _: IR => base - - case TableMapRows(child, _) => if (i == 1) Some(child.typ.rowEnv.m) else None - case _: TableIR => None - - case MatrixMapRows(child, _) => if (i == 1) Some(child.typ.rowEnv.m) else None - case MatrixMapCols(child, _, _) => if (i == 1) Some(child.typ.colEnv.m) else None - case _: MatrixIR => None - - case _: BlockMatrixIR => None - } - } + override def bindScan(bindings: (String, B)*): SegregatedBindingEnv[A, B] = + copy(newBindings = newBindings.bindScan(bindings: _*)) + + override def createAgg: SegregatedBindingEnv[A, B] = SegregatedBindingEnv( + childEnvWithoutBindings.createAgg, + newBindings.createAgg, + ) + + override def createScan: SegregatedBindingEnv[A, B] = SegregatedBindingEnv( + childEnvWithoutBindings.createScan, + newBindings.createScan, + ) + + override def noAgg: SegregatedBindingEnv[A, B] = SegregatedBindingEnv( + childEnvWithoutBindings.noAgg, + newBindings.noAgg, + ) + + override def noScan: SegregatedBindingEnv[A, B] = 
SegregatedBindingEnv( + childEnvWithoutBindings.noScan, + newBindings.noScan, + ) + + override def onlyRelational(keepAggCapabilities: Boolean = false): SegregatedBindingEnv[A, B] = + SegregatedBindingEnv( + childEnvWithoutBindings.onlyRelational(keepAggCapabilities), + newBindings.onlyRelational(keepAggCapabilities), + ) + + override def bindRelational(bindings: (String, B)*): SegregatedBindingEnv[A, B] = + copy(newBindings = newBindings.bindRelational(bindings: _*)) } -object RelationalBindings { - private val empty: Array[(String, Type)] = Array() +case class EvalOnlyBindingEnv[T](env: Env[T]) extends GenericBindingEnv[EvalOnlyBindingEnv[T], T] { + override def promoteAgg: EvalOnlyBindingEnv[T] = + EvalOnlyBindingEnv(Env.empty) - def apply(x: BaseIR, i: Int): Iterable[(String, Type)] = { - x match { - case RelationalLet(name, value, _) => if (i == 1) FastSeq(name -> value.typ) else empty - case RelationalLetTable(name, value, _) => if (i == 1) FastSeq(name -> value.typ) else empty - case RelationalLetMatrixTable(name, value, _) => if (i == 1) FastSeq(name -> value.typ) else empty - case RelationalLetBlockMatrix(name, value, _) => if (i == 1) FastSeq(name -> value.typ) else empty - case _ => empty - } - } + override def promoteScan: EvalOnlyBindingEnv[T] = + EvalOnlyBindingEnv(Env.empty) + + override def bindEval(bindings: (String, T)*): EvalOnlyBindingEnv[T] = + EvalOnlyBindingEnv(env.bindIterable(bindings)) + + override def dropEval: EvalOnlyBindingEnv[T] = + EvalOnlyBindingEnv(Env.empty) + + override def bindAgg(bindings: (String, T)*): EvalOnlyBindingEnv[T] = + this + + override def bindScan(bindings: (String, T)*): EvalOnlyBindingEnv[T] = + this + + override def createAgg: EvalOnlyBindingEnv[T] = + this + + override def createScan: EvalOnlyBindingEnv[T] = + this + + override def noAgg: EvalOnlyBindingEnv[T] = + this + + override def noScan: EvalOnlyBindingEnv[T] = + this + + override def onlyRelational(keepAggCapabilities: Boolean = false): EvalOnlyBindingEnv[T] = + EvalOnlyBindingEnv(Env.empty) + + override def bindRelational(bindings: (String, T)*): EvalOnlyBindingEnv[T] = + this } -object NewBindings { - def apply(x: BaseIR, i: Int, parent: BindingEnv[_]): BindingEnv[Type] = { - BindingEnv(Env.fromSeq(Bindings(x, i)), - agg = AggBindings(x, i, parent).map(b => Env.fromSeq(b)), - scan = ScanBindings(x, i, parent).map(b => Env.fromSeq(b)), - relational = Env.fromSeq(RelationalBindings(x, i))) - } +object Binds { + def apply(x: IR, v: String, i: Int): Boolean = + Bindings(x, i, EvalOnlyBindingEnv(Env.empty[Type])).env.contains(v) } -object ChildEnvWithoutBindings { - def apply[T](ir: BaseIR, i: Int, env: BindingEnv[T]): BindingEnv[T] = { +object Bindings { + + /** Returns the environment of the `i`th child of `ir` given the environment of the parent node + * `ir`. + */ + def apply[E <: GenericBindingEnv[E, Type]](ir: BaseIR, i: Int, baseEnv: E): E = + ir match { + case ir: MatrixIR => childEnvMatrix(ir, i, baseEnv) + case ir: TableIR => childEnvTable(ir, i, baseEnv) + case ir: BlockMatrixIR => childEnvBlockMatrix(ir, i, baseEnv) + case ir: IR => childEnvValue(ir, i, baseEnv) + } + + /** Like [[Bindings.apply]], but keeps separate any new bindings introduced by `ir`. 
Always + * satisfies the identity + * {{{ + * Bindings.segregated(ir, i, baseEnv).unified == Bindings(ir, i, baseEnv) + * }}} + */ + def segregated[A](ir: BaseIR, i: Int, baseEnv: BindingEnv[A]): SegregatedBindingEnv[A, Type] = + apply(ir, i, SegregatedBindingEnv(baseEnv)) + + private def childEnvMatrix[E <: GenericBindingEnv[E, Type]](ir: MatrixIR, i: Int, _baseEnv: E) + : E = { + val baseEnv = _baseEnv.onlyRelational() ir match { - case ArrayMaximalIndependentSet(_, Some(_)) if (i == 1) => env.copy(eval = Env.empty) - case StreamAgg(_, _, _) => if (i == 1) env.createAgg else env - case StreamAggScan(_, _, _) => if (i == 1) env.createScan else env - case ApplyAggOp(init, _, _) => if (i < init.length) env.copy(agg = None) else env.promoteAgg - case ApplyScanOp(init, _, _) => if (i < init.length) env.copy(scan = None) else env.promoteScan - case AggFold(zero, seqOp, combOp, elementName, accumName, isScan) => (isScan, i) match { - case (true, 0) => env.noScan - case (false, 0) => env.noAgg - case (true, 1) => env.promoteScan - case (false, 1) => env.promoteAgg - case (true, 2) => env.copy(eval = Env.empty, scan = None) - case (false, 2) => env.copy(eval = Env.empty, agg = None) - } - case CollectDistributedArray(_, _, _, _, _, _, _, _) => if (i == 2) BindingEnv(relational = env.relational) else env - case MatrixAggregate(_, _) => if (i == 0) env.onlyRelational else BindingEnv(Env.empty, agg = Some(Env.empty), relational = env.relational) - case TableAggregate(_, _) => if (i == 0) env.onlyRelational else BindingEnv(Env.empty, agg = Some(Env.empty), relational = env.relational) - case RelationalLet(_, _, _) => if (i == 0) env.onlyRelational else env.copy(agg = None, scan = None) - case LiftMeOut(_) => BindingEnv(Env.empty[T], env.agg.map(_ => Env.empty), env.scan.map(_ => Env.empty), relational = env.relational) - case _: IR => if (UsesAggEnv(ir, i)) env.promoteAgg else if (UsesScanEnv(ir, i)) env.promoteScan else env - case x => BindingEnv( - agg = AggBindings(x, i, env).map(_ => Env.empty), - scan = ScanBindings(x, i, env).map(_ => Env.empty), - relational = env.relational) + case MatrixMapRows(child, _) if i == 1 => + baseEnv + .createAgg.createScan + .bindEval(child.typ.rowBindings: _*) + .bindEval("n_cols" -> TInt32) + .bindAgg(child.typ.entryBindings: _*) + .bindScan(child.typ.rowBindings: _*) + case MatrixFilterRows(child, _) if i == 1 => + baseEnv.bindEval(child.typ.rowBindings: _*) + case MatrixMapCols(child, _, _) if i == 1 => + baseEnv + .createAgg.createScan + .bindEval(child.typ.colBindings: _*) + .bindEval("n_rows" -> TInt64) + .bindAgg(child.typ.entryBindings: _*) + .bindScan(child.typ.colBindings: _*) + case MatrixFilterCols(child, _) if i == 1 => + baseEnv.bindEval(child.typ.colBindings: _*) + case MatrixMapEntries(child, _) if i == 1 => + baseEnv.bindEval(child.typ.entryBindings: _*) + case MatrixFilterEntries(child, _) if i == 1 => + baseEnv.bindEval(child.typ.entryBindings: _*) + case MatrixMapGlobals(child, _) if i == 1 => + baseEnv.bindEval(child.typ.globalBindings: _*) + case MatrixAggregateColsByKey(child, _, _) => + if (i == 1) + baseEnv + .bindEval(child.typ.rowBindings: _*) + .createAgg.bindAgg(child.typ.entryBindings: _*) + else if (i == 2) + baseEnv + .bindEval(child.typ.globalBindings: _*) + .createAgg.bindAgg(child.typ.colBindings: _*) + else baseEnv + case MatrixAggregateRowsByKey(child, _, _) => + if (i == 1) + baseEnv + .bindEval(child.typ.colBindings: _*) + .createAgg.bindAgg(child.typ.entryBindings: _*) + else if (i == 2) + baseEnv + 
.bindEval(child.typ.globalBindings: _*) + .createAgg.bindAgg(child.typ.rowBindings: _*) + else baseEnv + case RelationalLetMatrixTable(name, value, _) if i == 1 => + baseEnv.bindRelational(name -> value.typ) + case _ => + baseEnv } } -} -object ChildBindings { - def apply(ir: BaseIR, i: Int, baseEnv: BindingEnv[Type]): BindingEnv[Type] = { - val env = ChildEnvWithoutBindings(ir, i, baseEnv) - val newBindings = NewBindings(ir, i, env) - env.merge(newBindings) + private def childEnvTable[E <: GenericBindingEnv[E, Type]](ir: TableIR, i: Int, _baseEnv: E) + : E = { + val baseEnv = _baseEnv.onlyRelational() + ir match { + case TableFilter(child, _) if i == 1 => + baseEnv.bindEval(child.typ.rowBindings: _*) + case TableGen(contexts, globals, cname, gname, _, _, _) if i == 2 => + baseEnv.bindEval( + cname -> elementType(contexts.typ), + gname -> globals.typ, + ) + case TableMapGlobals(child, _) if i == 1 => + baseEnv.bindEval(child.typ.globalBindings: _*) + case TableMapRows(child, _) if i == 1 => + baseEnv + .bindEval(child.typ.rowBindings: _*) + .createScan.bindScan(child.typ.rowBindings: _*) + case TableAggregateByKey(child, _) if i == 1 => + baseEnv + .bindEval(child.typ.globalBindings: _*) + .createAgg.bindAgg(child.typ.rowBindings: _*) + case TableKeyByAndAggregate(child, _, _, _, _) => + if (i == 1) + baseEnv + .bindEval(child.typ.globalBindings: _*) + .createAgg.bindAgg(child.typ.rowBindings: _*) + else if (i == 2) + baseEnv.bindEval(child.typ.rowBindings: _*) + else baseEnv + case TableMapPartitions(child, g, p, _, _, _) if i == 1 => + baseEnv.bindEval( + g -> child.typ.globalType, + p -> TStream(child.typ.rowType), + ) + case RelationalLetTable(name, value, _) if i == 1 => + baseEnv.bindRelational(name -> value.typ) + case _ => + baseEnv + } } - def transformed[T](ir: BaseIR, i: Int, baseEnv: BindingEnv[T], f: (String, Type) => T): BindingEnv[T] = { - val env = ChildEnvWithoutBindings(ir, i, baseEnv) - val newBindings = NewBindings(ir, i, env).mapValuesWithKey(f) - env.merge(newBindings) + private def childEnvBlockMatrix[E <: GenericBindingEnv[E, Type]]( + ir: BlockMatrixIR, + i: Int, + _baseEnv: E, + ): E = { + val baseEnv = _baseEnv.onlyRelational() + ir match { + case BlockMatrixMap(_, eltName, _, _) if i == 1 => + baseEnv.bindEval(eltName -> TFloat64) + case BlockMatrixMap2(_, _, lName, rName, _, _) if i == 2 => + baseEnv.bindEval(lName -> TFloat64, rName -> TFloat64) + case RelationalLetBlockMatrix(name, value, _) if i == 1 => + baseEnv.bindRelational(name -> value.typ) + case _ => + baseEnv + } } + + private def childEnvValue[E <: GenericBindingEnv[E, Type]](ir: IR, i: Int, baseEnv: E): E = + ir match { + case Block(bindings, _) => + var env = baseEnv + for (k <- 0 until i) bindings(k) match { + case Binding(name, value, scope) => + env = env.bindInScope(name, value.typ, scope) + } + if (i < bindings.length) bindings(i).scope match { + case Scope.EVAL => env + case Scope.AGG => env.promoteAgg + case Scope.SCAN => env.promoteScan + } + else env + case TailLoop(name, args, resultType, _) if i == args.length => + baseEnv + .bindEval(args.map { case (name, ir) => name -> ir.typ }: _*) + .bindEval(name -> TTuple(TTuple(args.map(_._2.typ): _*), resultType)) + case StreamMap(a, name, _) if i == 1 => + baseEnv.bindEval(name -> elementType(a.typ)) + case StreamZip(as, names, _, _, _) if i == as.length => + baseEnv.bindEval(names.zip(as.map(a => elementType(a.typ))): _*) + case StreamZipJoin(as, key, curKey, curVals, _) if i == as.length => + val eltType = 
tcoerce[TStruct](elementType(as.head.typ)) + baseEnv.bindEval( + curKey -> eltType.typeAfterSelectNames(key), + curVals -> TArray(eltType), + ) + case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, _) => + if (i == 1) { + val contextType = elementType(contexts.typ) + baseEnv.bindEval(ctxName -> contextType) + } else if (i == 2) { + val eltType = tcoerce[TStruct](elementType(makeProducer.typ)) + baseEnv.bindEval( + curKey -> eltType.typeAfterSelectNames(key), + curVals -> TArray(eltType), + ) + } else baseEnv + case StreamLeftIntervalJoin(left, right, _, _, lEltName, rEltName, _) if i == 2 => + baseEnv.bindEval( + lEltName -> elementType(left.typ), + rEltName -> TArray(elementType(right.typ)), + ) + case StreamFor(a, name, _) if i == 1 => + baseEnv.bindEval(name -> elementType(a.typ)) + case StreamFlatMap(a, name, _) if i == 1 => + baseEnv.bindEval(name -> elementType(a.typ)) + case StreamFilter(a, name, _) if i == 1 => + baseEnv.bindEval(name -> elementType(a.typ)) + case StreamTakeWhile(a, name, _) if i == 1 => + baseEnv.bindEval(name -> elementType(a.typ)) + case StreamDropWhile(a, name, _) if i == 1 => + baseEnv.bindEval(name -> elementType(a.typ)) + case StreamFold(a, zero, accumName, valueName, _) if i == 2 => + baseEnv.bindEval(accumName -> zero.typ, valueName -> elementType(a.typ)) + case StreamFold2(a, accum, valueName, _, _) => + if (i <= accum.length) + baseEnv + else if (i < 2 * accum.length + 1) + baseEnv + .bindEval(valueName -> elementType(a.typ)) + .bindEval(accum.map { case (name, value) => (name, value.typ) }: _*) + else + baseEnv.bindEval(accum.map { case (name, value) => (name, value.typ) }: _*) + case StreamBufferedAggregate(stream, _, _, _, name, _, _) if i > 0 => + baseEnv.bindEval(name -> elementType(stream.typ)) + case RunAggScan(a, name, _, _, _, _) if i == 2 || i == 3 => + baseEnv.bindEval(name -> elementType(a.typ)) + case StreamScan(a, zero, accumName, valueName, _) if i == 2 => + baseEnv.bindEval( + accumName -> zero.typ, + valueName -> elementType(a.typ), + ) + case StreamAggScan(a, name, _) if i == 1 => + val eltType = elementType(a.typ) + baseEnv + .bindEval(name -> eltType) + .createScan.bindScan(name -> eltType) + case StreamJoinRightDistinct(ll, rr, _, _, l, r, _, _) if i == 2 => + baseEnv.bindEval( + l -> elementType(ll.typ), + r -> elementType(rr.typ), + ) + case ArraySort(a, left, right, _) if i == 1 => + baseEnv.bindEval( + left -> elementType(a.typ), + right -> elementType(a.typ), + ) + case ArrayMaximalIndependentSet(a, Some((left, right, _))) if i == 1 => + val typ = tcoerce[TBaseStruct](elementType(a.typ)).types.head + val tupleType = TTuple(typ) + baseEnv.dropEval.bindEval(left -> tupleType, right -> tupleType) + case AggArrayPerElement(a, elementName, indexName, _, _, isScan) => + if (i == 0) baseEnv.promoteAggOrScan(isScan) + else if (i == 1) + baseEnv + .bindEval(indexName -> TInt32) + .bindAggOrScan( + isScan, + elementName -> elementType(a.typ), + indexName -> TInt32, + ) + else baseEnv + case AggFold(zero, _, _, accumName, otherAccumName, isScan) => + if (i == 0) baseEnv.noAggOrScan(isScan) + else if (i == 1) baseEnv.promoteAggOrScan(isScan).bindEval(accumName -> zero.typ) + else baseEnv.dropEval.noAggOrScan(isScan) + .bindEval(accumName -> zero.typ, otherAccumName -> zero.typ) + case NDArrayMap(nd, name, _) if i == 1 => + baseEnv.bindEval(name -> tcoerce[TNDArray](nd.typ).elementType) + case NDArrayMap2(l, r, lName, rName, _, _) if i == 2 => + baseEnv.bindEval( + lName -> tcoerce[TNDArray](l.typ).elementType, 
+ rName -> tcoerce[TNDArray](r.typ).elementType, + ) + case CollectDistributedArray(contexts, globals, cname, gname, _, _, _, _) if i == 2 => + baseEnv.onlyRelational().bindEval( + cname -> elementType(contexts.typ), + gname -> globals.typ, + ) + case TableAggregate(child, _) => + if (i == 1) + baseEnv.onlyRelational() + .bindEval(child.typ.globalBindings: _*) + .createAgg.bindAgg(child.typ.rowBindings: _*) + else baseEnv.onlyRelational() + case MatrixAggregate(child, _) => + if (i == 1) + baseEnv.onlyRelational() + .bindEval(child.typ.globalBindings: _*) + .createAgg.bindAgg(child.typ.entryBindings: _*) + else baseEnv.onlyRelational() + case ApplyAggOp(init, _, _) => + if (i < init.length) baseEnv.noAgg + else baseEnv.promoteAgg + case ApplyScanOp(init, _, _) => + if (i < init.length) baseEnv.noScan + else baseEnv.promoteScan + case AggFilter(_, _, isScan) if i == 0 => + baseEnv.promoteAggOrScan(isScan) + case AggGroupBy(_, _, isScan) if i == 0 => + baseEnv.promoteAggOrScan(isScan) + case AggExplode(a, name, _, isScan) => + if (i == 0) baseEnv.promoteAggOrScan(isScan) + else baseEnv.bindAggOrScan(isScan, name -> elementType(a.typ)) + case StreamAgg(a, name, _) if i == 1 => + baseEnv.createAgg + .bindAgg(name -> elementType(a.typ)) + case RelationalLet(name, value, _) => + if (i == 1) + baseEnv.noAgg.noScan.bindRelational(name -> value.typ) + else + baseEnv.onlyRelational() + case _: LiftMeOut => + baseEnv.onlyRelational(keepAggCapabilities = true) + case _ => + if (UsesAggEnv(ir, i)) baseEnv.promoteAgg + else if (UsesScanEnv(ir, i)) baseEnv.promoteScan + else baseEnv + } } diff --git a/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala index b56a02edbc2..133c423d19c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/BlockMatrixIR.scala @@ -1,43 +1,48 @@ package is.hail.expr.ir -import breeze.linalg.DenseMatrix -import breeze.numerics import is.hail.HailContext import is.hail.annotations.NDArray import is.hail.backend.{BackendContext, ExecuteContext} import is.hail.expr.Nat import is.hail.expr.ir.lowering.{BMSContexts, BlockMatrixStage2, LowererUnsupportedOperation} -import is.hail.io.fs.FS import is.hail.io.{StreamBufferSpec, TypedCodecSpec} +import is.hail.io.fs.FS import is.hail.linalg.{BlockMatrix, BlockMatrixMetadata} +import is.hail.types.{BlockMatrixSparsity, BlockMatrixType} import is.hail.types.encoded.{EBlockMatrixNDArray, EFloat64, ENumpyBinaryNDArray} import is.hail.types.virtual._ -import is.hail.types.{BlockMatrixSparsity, BlockMatrixType} import is.hail.utils._ import is.hail.utils.richUtils.RichDenseMatrixDouble -import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} import scala.collection.immutable.NumericRange import scala.collection.mutable.ArrayBuffer +import breeze.linalg.DenseMatrix +import breeze.numerics +import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} + object BlockMatrixIR { - def checkFitsIntoArray(nRows: Long, nCols: Long) { + def checkFitsIntoArray(nRows: Long, nCols: Long): Unit = { require(nRows <= Int.MaxValue, s"Number of rows exceeds Int.MaxValue: $nRows") require(nCols <= Int.MaxValue, s"Number of columns exceeds Int.MaxValue: $nCols") - require(nRows * nCols <= Int.MaxValue, s"Number of values exceeds Int.MaxValue: ${ nRows * nCols }") + require( + nRows * nCols <= Int.MaxValue, + s"Number of values exceeds Int.MaxValue: ${nRows * nCols}", + ) } def toBlockMatrix( nRows: Int, 
nCols: Int, data: Array[Double], - blockSize: Int = BlockMatrix.defaultBlockSize): BlockMatrix = { - + blockSize: Int = BlockMatrix.defaultBlockSize, + ): BlockMatrix = BlockMatrix.fromBreezeMatrix( - new DenseMatrix[Double](nRows, nCols, data, 0, nCols, isTranspose = true), blockSize) - } + new DenseMatrix[Double](nRows, nCols, data, 0, nCols, isTranspose = true), + blockSize, + ) - def matrixShapeToTensorShape(nRows: Long, nCols: Long): (IndexedSeq[Long], Boolean) = { + def matrixShapeToTensorShape(nRows: Long, nCols: Long): (IndexedSeq[Long], Boolean) = { (nRows, nCols) match { case (1, 1) => (FastSeq(), false) case (_, 1) => (FastSeq(nRows), false) @@ -59,7 +64,7 @@ object BlockMatrixIR { } } -abstract sealed class BlockMatrixIR extends BaseIR { +sealed abstract class BlockMatrixIR extends BaseIR { def typ: BlockMatrixType protected[ir] def execute(ctx: ExecuteContext): BlockMatrix = @@ -90,29 +95,34 @@ case class BlockMatrixRead(reader: BlockMatrixReader) extends BlockMatrixIR { object BlockMatrixReader { implicit val formats: Formats = new DefaultFormats() { override val typeHints = ShortTypeHints( - List(classOf[BlockMatrixNativeReader], classOf[BlockMatrixBinaryReader], classOf[BlockMatrixPersistReader]), - typeHintFieldName = "name") + List( + classOf[BlockMatrixNativeReader], + classOf[BlockMatrixBinaryReader], + classOf[BlockMatrixPersistReader], + ), + typeHintFieldName = "name", + ) } - def fromJValue(ctx: ExecuteContext, jv: JValue): BlockMatrixReader = { + def fromJValue(ctx: ExecuteContext, jv: JValue): BlockMatrixReader = (jv \ "name").extract[String] match { case "BlockMatrixNativeReader" => BlockMatrixNativeReader.fromJValue(ctx.fs, jv) case "BlockMatrixPersistReader" => BlockMatrixPersistReader.fromJValue(ctx.backendContext, jv) case _ => jv.extract[BlockMatrixReader] } - } } - abstract class BlockMatrixReader { def pathsUsed: Seq[String] def apply(ctx: ExecuteContext): BlockMatrix + def lower(ctx: ExecuteContext, evalCtx: IRBuilder): BlockMatrixStage2 = - throw new LowererUnsupportedOperation(s"BlockMatrixReader not implemented: ${ this.getClass }") + throw new LowererUnsupportedOperation(s"BlockMatrixReader not implemented: ${this.getClass}") + def fullType: BlockMatrixType - def toJValue: JValue = { + + def toJValue: JValue = Extraction.decompose(this)(BlockMatrixReader.formats) - } } object BlockMatrixNativeReader { @@ -135,13 +145,20 @@ case class BlockMatrixNativeReaderParameters(path: String) class BlockMatrixNativeReader( val params: BlockMatrixNativeReaderParameters, - val metadata: BlockMatrixMetadata) extends BlockMatrixReader { + val metadata: BlockMatrixMetadata, +) extends BlockMatrixReader { def pathsUsed: Seq[String] = Array(params.path) lazy val fullType: BlockMatrixType = { - val (tensorShape, isRowVector) = BlockMatrixIR.matrixShapeToTensorShape(metadata.nRows, metadata.nCols) - - val sparsity = BlockMatrixSparsity.fromLinearBlocks(metadata.nRows, metadata.nCols, metadata.blockSize, metadata.maybeFiltered) + val (tensorShape, isRowVector) = + BlockMatrixIR.matrixShapeToTensorShape(metadata.nRows, metadata.nCols) + + val sparsity = BlockMatrixSparsity.fromLinearBlocks( + metadata.nRows, + metadata.nCols, + metadata.blockSize, + metadata.maybeFiltered, + ) BlockMatrixType(TFloat64, tensorShape, isRowVector, metadata.blockSize, sparsity) } @@ -149,8 +166,7 @@ class BlockMatrixNativeReader( val key = ("BlockMatrixNativeReader.apply", params.path) if (ctx.memo.contains(key)) { ctx.memo(key).asInstanceOf[BlockMatrix] - } - else { + } else { val bm = 
BlockMatrix.read(ctx.fs, params.path) ctx.memo.update(key, bm) bm @@ -164,13 +180,21 @@ class BlockMatrixNativeReader( val contexts = BMSContexts(fullType, fileNames, evalCtx) val vType = TNDArray(fullType.elementType, Nat(2)) - val spec = TypedCodecSpec(EBlockMatrixNDArray(EFloat64(required = true), required = true), vType, BlockMatrix.bufferSpec) + val spec = TypedCodecSpec( + EBlockMatrixNDArray(EFloat64(required = true), required = true), + vType, + BlockMatrix.bufferSpec, + ) val reader = ETypeValueReader(spec) def blockIR(ctx: IR): IR = { - val path = Apply("concat", FastSeq(), - FastSeq(Str(s"${ params.path }/parts/"), ctx), - TString, ErrorIDs.NO_ERROR) + val path = Apply( + "concat", + FastSeq(), + FastSeq(Str(s"${params.path}/parts/"), ctx), + TString, + ErrorIDs.NO_ERROR, + ) ReadValue(path, reader, vType) } @@ -179,12 +203,12 @@ class BlockMatrixNativeReader( FastSeq(), fullType, contexts, - blockIR) + blockIR, + ) } - override def toJValue: JValue = { + override def toJValue: JValue = decomposeWithName(params, "BlockMatrixNativeReader")(BlockMatrixReader.formats) - } override def hashCode(): Int = params.hashCode() @@ -194,33 +218,52 @@ } } -case class BlockMatrixBinaryReader(path: String, shape: IndexedSeq[Long], blockSize: Int) extends BlockMatrixReader { +case class BlockMatrixBinaryReader(path: String, shape: IndexedSeq[Long], blockSize: Int) + extends BlockMatrixReader { def pathsUsed: Seq[String] = Array(path) val IndexedSeq(nRows, nCols) = shape BlockMatrixIR.checkFitsIntoArray(nRows, nCols) - lazy val fullType: BlockMatrixType = { + lazy val fullType: BlockMatrixType = BlockMatrixType.dense(TFloat64, nRows, nCols, blockSize) - } def apply(ctx: ExecuteContext): BlockMatrix = { - val breezeMatrix = RichDenseMatrixDouble.importFromDoubles(ctx.fs, path, nRows.toInt, nCols.toInt, rowMajor = true) + val breezeMatrix = RichDenseMatrixDouble.importFromDoubles( + ctx.fs, + path, + nRows.toInt, + nCols.toInt, + rowMajor = true, + ) BlockMatrix.fromBreezeMatrix(breezeMatrix, blockSize) } override def lower(ctx: ExecuteContext, evalCtx: IRBuilder): BlockMatrixStage2 = { // FIXME numpy should be its own value reader val readFromNumpyEType = ENumpyBinaryNDArray(nRows, nCols, true) - val readFromNumpySpec = TypedCodecSpec(readFromNumpyEType, TNDArray(TFloat64, Nat(2)), new StreamBufferSpec()) + val readFromNumpySpec = + TypedCodecSpec(readFromNumpyEType, TNDArray(TFloat64, Nat(2)), new StreamBufferSpec()) val reader = ETypeValueReader(readFromNumpySpec) val nd = evalCtx.memoize(ReadValue(Str(path), reader, TNDArray(TFloat64, nDimsBase = Nat(2)))) val typ = fullType val contexts = BMSContexts.tabulate(typ, evalCtx) { (blockRow, blockCol) => - NDArraySlice(nd, MakeTuple.ordered(FastSeq( - MakeTuple.ordered(FastSeq(blockRow.toL * blockSize.toLong, minIR((blockRow + 1).toL * blockSize.toLong, nRows), 1L)), - MakeTuple.ordered(FastSeq(blockCol.toL * blockSize.toLong, minIR((blockCol + 1).toL * blockSize.toLong, nCols), 1L))))) + NDArraySlice( + nd, + MakeTuple.ordered(FastSeq( + MakeTuple.ordered(FastSeq( + blockRow.toL * blockSize.toLong, + minIR((blockRow + 1).toL * blockSize.toLong, nRows), + 1L, + )), + MakeTuple.ordered(FastSeq( + blockCol.toL * blockSize.toLong, + minIR((blockCol + 1).toL * blockSize.toLong, nCols), + 1L, + )), + )), + ) } def blockIR(ctx: IR) = ctx @@ -235,22 +278,25 @@ object BlockMatrixPersistReader { def fromJValue(ctx: BackendContext, jv: JValue): BlockMatrixPersistReader = { implicit val formats: Formats = 
BlockMatrixReader.formats val params = jv.extract[BlockMatrixNativePersistParameters] - BlockMatrixPersistReader(params.id, HailContext.backend.getPersistedBlockMatrixType(ctx, params.id)) + BlockMatrixPersistReader( + params.id, + HailContext.backend.getPersistedBlockMatrixType(ctx, params.id), + ) } } case class BlockMatrixPersistReader(id: String, typ: BlockMatrixType) extends BlockMatrixReader { def pathsUsed: Seq[String] = FastSeq() lazy val fullType: BlockMatrixType = typ - def apply(ctx: ExecuteContext): BlockMatrix = { + + def apply(ctx: ExecuteContext): BlockMatrix = HailContext.backend.getPersistedBlockMatrix(ctx.backendContext, id) - } } -case class BlockMatrixMap(child: BlockMatrixIR, eltName: String, f: IR, needsDense: Boolean) extends BlockMatrixIR { - override def typecheck(): Unit = { +case class BlockMatrixMap(child: BlockMatrixIR, eltName: String, f: IR, needsDense: Boolean) + extends BlockMatrixIR { + override def typecheck(): Unit = assert(!(needsDense && child.typ.isSparse)) - } override def typ: BlockMatrixType = child.typ @@ -270,7 +316,8 @@ case class BlockMatrixMap(child: BlockMatrixIR, eltName: String, f: IR, needsDen res.asInstanceOf[Double] } - private def binaryOp(scalar: Double, f: (DenseMatrix[Double], Double) => DenseMatrix[Double]): DenseMatrix[Double] => DenseMatrix[Double] = + private def binaryOp(scalar: Double, f: (DenseMatrix[Double], Double) => DenseMatrix[Double]) + : DenseMatrix[Double] => DenseMatrix[Double] = f(_, scalar) override protected[ir] def execute(ctx: ExecuteContext): BlockMatrix = { @@ -291,18 +338,19 @@ case class BlockMatrixMap(child: BlockMatrixIR, eltName: String, f: IR, needsDen case Constant(k) => IndexedSeq(k) } - assert(functionArgs.forall(ir => IsConstant(ir) || ir.isInstanceOf[Ref]), + assert( + functionArgs.forall(ir => IsConstant(ir) || ir.isInstanceOf[Ref]), "Spark backend without lowering does not support general mapping over " + - "BlockMatrix entries. Use predefined functions like `BlockMatrix.abs`.") - + "BlockMatrix entries. 
Use predefined functions like `BlockMatrix.abs`.", + ) val (name, breezeF): (String, DenseMatrix[Double] => DenseMatrix[Double]) = f match { case ApplyUnaryPrimOp(Negate, _) => ("negate", BlockMatrix.negationOp) case Apply("abs", _, _, _, _) => ("abs", numerics.abs(_)) case Apply("log", _, _, _, _) => ("log", numerics.log(_)) - case Apply("sqrt", _, _, _,_) => ("sqrt", numerics.sqrt(_)) - case Apply("ceil", _, _, _,_) => ("ceil", numerics.ceil(_)) - case Apply("floor", _, _, _,_) => ("floor", numerics.floor(_)) + case Apply("sqrt", _, _, _, _) => ("sqrt", numerics.sqrt(_)) + case Apply("ceil", _, _, _, _) => ("ceil", numerics.ceil(_)) + case Apply("floor", _, _, _, _) => ("floor", numerics.floor(_)) case Apply("pow", _, Seq(Ref(`eltName`, _), r), _, _) if !Mentions(r, eltName) => ("**", binaryOp(evalIR(ctx, r), numerics.pow(_, _))) @@ -320,9 +368,11 @@ case class BlockMatrixMap(child: BlockMatrixIR, eltName: String, f: IR, needsDen ("-", binaryOp(evalIR(ctx, r), (m, s) => m - s)) case ApplyBinaryPrimOp(Subtract(), l, Ref(`eltName`, _)) if !Mentions(l, eltName) => ("-", binaryOp(evalIR(ctx, l), (m, s) => s - m)) - case ApplyBinaryPrimOp(FloatingPointDivide(), Ref(`eltName`, _), r) if !Mentions(r, eltName) => + case ApplyBinaryPrimOp(FloatingPointDivide(), Ref(`eltName`, _), r) + if !Mentions(r, eltName) => ("/", binaryOp(evalIR(ctx, r), (m, s) => m /:/ s)) - case ApplyBinaryPrimOp(FloatingPointDivide(), l, Ref(`eltName`, _)) if !Mentions(l, eltName) => + case ApplyBinaryPrimOp(FloatingPointDivide(), l, Ref(`eltName`, _)) + if !Mentions(l, eltName) => ("/", binaryOp(evalIR(ctx, l), BlockMatrix.reverseScalarDiv)) case Ref(`eltName`, _) => ("identity", identity(_)) @@ -349,26 +399,30 @@ abstract class SparsityStrategy { def exists(leftBlock: Boolean, rightBlock: Boolean): Boolean def mergeSparsity(left: BlockMatrixSparsity, right: BlockMatrixSparsity): BlockMatrixSparsity } + case object UnionBlocks extends SparsityStrategy { def exists(leftBlock: Boolean, rightBlock: Boolean): Boolean = leftBlock || rightBlock - def mergeSparsity(left: BlockMatrixSparsity, right: BlockMatrixSparsity): BlockMatrixSparsity = { + + def mergeSparsity(left: BlockMatrixSparsity, right: BlockMatrixSparsity): BlockMatrixSparsity = if (left.isSparse && right.isSparse) { BlockMatrixSparsity(left.blockSet.union(right.blockSet).toFastSeq) } else BlockMatrixSparsity.dense - } } + case object IntersectionBlocks extends SparsityStrategy { def exists(leftBlock: Boolean, rightBlock: Boolean): Boolean = leftBlock && rightBlock - def mergeSparsity(left: BlockMatrixSparsity, right: BlockMatrixSparsity): BlockMatrixSparsity = { + + def mergeSparsity(left: BlockMatrixSparsity, right: BlockMatrixSparsity): BlockMatrixSparsity = if (right.isSparse) { if (left.isSparse) BlockMatrixSparsity(left.blockSet.intersect(right.blockSet).toFastSeq) else right } else left - } } + case object NeedsDense extends SparsityStrategy { def exists(leftBlock: Boolean, rightBlock: Boolean): Boolean = true + def mergeSparsity(left: BlockMatrixSparsity, right: BlockMatrixSparsity): BlockMatrixSparsity = { assert(!left.isSparse && !right.isSparse) BlockMatrixSparsity.dense @@ -381,7 +435,7 @@ case class BlockMatrixMap2( leftName: String, rightName: String, f: IR, - sparsityStrategy: SparsityStrategy + sparsityStrategy: SparsityStrategy, ) extends BlockMatrixIR { override def typecheck(): Unit = { assert(left.typ.nRows == right.typ.nRows) @@ -389,7 +443,8 @@ case class BlockMatrixMap2( assert(left.typ.blockSize == right.typ.blockSize) } - override lazy val 
typ: BlockMatrixType = left.typ.copy(sparsity = sparsityStrategy.mergeSparsity(left.typ.sparsity, right.typ.sparsity)) + override lazy val typ: BlockMatrixType = + left.typ.copy(sparsity = sparsityStrategy.mergeSparsity(left.typ.sparsity, right.typ.sparsity)) lazy val childrenSeq: IndexedSeq[BaseIR] = Array(left, right, f) @@ -400,9 +455,11 @@ case class BlockMatrixMap2( BlockMatrixMap2( newChildren(0).asInstanceOf[BlockMatrixIR], newChildren(1).asInstanceOf[BlockMatrixIR], - leftName, rightName, + leftName, + rightName, newChildren(2).asInstanceOf[IR], - sparsityStrategy) + sparsityStrategy, + ) } override protected[ir] def execute(ctx: ExecuteContext): BlockMatrix = { @@ -410,7 +467,7 @@ case class BlockMatrixMap2( left match { case BlockMatrixBroadcast(vectorIR: BlockMatrixIR, IndexedSeq(x), _, _) => - val vector = coerceToVector(ctx , vectorIR) + val vector = coerceToVector(ctx, vectorIR) x match { case 1 => rowVectorOnLeft(ctx, vector, right, f) case 0 => colVectorOnLeft(ctx, vector, right, f) @@ -420,13 +477,24 @@ case class BlockMatrixMap2( } } - private def rowVectorOnLeft(ctx: ExecuteContext, rowVector: Array[Double], right: BlockMatrixIR, f: IR): BlockMatrix = + private def rowVectorOnLeft( + ctx: ExecuteContext, + rowVector: Array[Double], + right: BlockMatrixIR, + f: IR, + ): BlockMatrix = opWithRowVector(right.execute(ctx), rowVector, f, reverse = true) - private def colVectorOnLeft(ctx: ExecuteContext, colVector: Array[Double], right: BlockMatrixIR, f: IR): BlockMatrix = + private def colVectorOnLeft( + ctx: ExecuteContext, + colVector: Array[Double], + right: BlockMatrixIR, + f: IR, + ): BlockMatrix = opWithColVector(right.execute(ctx), colVector, f, reverse = true) - private def matrixOnLeft(ctx: ExecuteContext, matrix: BlockMatrix, right: BlockMatrixIR, f: IR): BlockMatrix = { + private def matrixOnLeft(ctx: ExecuteContext, matrix: BlockMatrix, right: BlockMatrixIR, f: IR) + : BlockMatrix = { right match { case BlockMatrixBroadcast(vectorIR, IndexedSeq(x), _, _) => x match { @@ -447,17 +515,17 @@ case class BlockMatrixMap2( case ValueToBlockMatrix(child, _, _) => Interpret[Any](ctx, child) match { case vector: IndexedSeq[_] => vector.asInstanceOf[IndexedSeq[Double]].toArray - case vector: NDArray => { + case vector: NDArray => val IndexedSeq(numRows, numCols) = vector.shape assert(numRows == 1L || numCols == 1L) vector.getRowMajorElements().asInstanceOf[IndexedSeq[Double]].toArray - } } case _ => ir.execute(ctx).toBreezeMatrix().data } } - private def opWithRowVector(left: BlockMatrix, right: Array[Double], f: IR, reverse: Boolean): BlockMatrix = { + private def opWithRowVector(left: BlockMatrix, right: Array[Double], f: IR, reverse: Boolean) + : BlockMatrix = { (f: @unchecked) match { case ApplyBinaryPrimOp(Add(), _, _) => left.rowVectorAdd(right) case ApplyBinaryPrimOp(Multiply(), _, _) => left.rowVectorMul(right) @@ -468,7 +536,8 @@ case class BlockMatrixMap2( } } - private def opWithColVector(left: BlockMatrix, right: Array[Double], f: IR, reverse: Boolean): BlockMatrix = { + private def opWithColVector(left: BlockMatrix, right: Array[Double], f: IR, reverse: Boolean) + : BlockMatrix = { (f: @unchecked) match { case ApplyBinaryPrimOp(Add(), _, _) => left.colVectorAdd(right) case ApplyBinaryPrimOp(Multiply(), _, _) => left.colVectorMul(right) @@ -500,11 +569,13 @@ case class BlockMatrixDot(left: BlockMatrixIR, right: BlockMatrixIR) extends Blo val sparsity = if (left.typ.isSparse || right.typ.isSparse) BlockMatrixSparsity.constructFromShapeAndFunction( 
BlockMatrixType.numBlocks(lRows, blockSize), - BlockMatrixType.numBlocks(rCols, blockSize)) { (i: Int, j: Int) => + BlockMatrixType.numBlocks(rCols, blockSize), + ) { (i: Int, j: Int) => Array.tabulate(BlockMatrixType.numBlocks(rCols, blockSize)) { k => left.typ.hasBlock(i -> k) && right.typ.hasBlock(k -> j) }.reduce(_ || _) - } else BlockMatrixSparsity.dense + } + else BlockMatrixSparsity.dense BlockMatrixType(left.typ.elementType, tensorShape, isRowVector, blockSize, sparsity) } @@ -512,7 +583,10 @@ case class BlockMatrixDot(left: BlockMatrixIR, right: BlockMatrixIR) extends Blo def copy(newChildren: IndexedSeq[BaseIR]): BlockMatrixDot = { assert(newChildren.length == 2) - BlockMatrixDot(newChildren(0).asInstanceOf[BlockMatrixIR], newChildren(1).asInstanceOf[BlockMatrixIR]) + BlockMatrixDot( + newChildren(0).asInstanceOf[BlockMatrixIR], + newChildren(1).asInstanceOf[BlockMatrixIR], + ) } val blockCostIsLinear: Boolean = false @@ -523,15 +597,17 @@ case class BlockMatrixDot(left: BlockMatrixIR, right: BlockMatrixIR) extends Blo val fs = ctx.fs if (!left.blockCostIsLinear) { val path = ctx.createTmpPath("blockmatrix-dot-left", "bm") - info(s"BlockMatrix multiply: writing left input with ${ leftBM.nRows } rows and ${ leftBM.nCols } cols " + - s"(${ leftBM.gp.nBlocks } blocks of size ${ leftBM.blockSize }) to temporary file $path") + info(s"BlockMatrix multiply: writing left input with ${leftBM.nRows} rows and ${leftBM.nCols} cols " + + s"(${leftBM.gp.nBlocks} blocks of size ${leftBM.blockSize}) to temporary file $path") leftBM.write(ctx, path) leftBM = BlockMatrixNativeReader(fs, path).apply(ctx) } if (!right.blockCostIsLinear) { val path = ctx.createTmpPath("blockmatrix-dot-right", "bm") - info(s"BlockMatrix multiply: writing right input with ${ rightBM.nRows } rows and ${ rightBM.nCols } cols " + - s"(${ rightBM.gp.nBlocks } blocks of size ${ rightBM.blockSize }) to temporary file $path") + info( + s"BlockMatrix multiply: writing right input with ${rightBM.nRows} rows and ${rightBM.nCols} cols " + + s"(${rightBM.gp.nBlocks} blocks of size ${rightBM.blockSize}) to temporary file $path" + ) rightBM.write(ctx, path) rightBM = BlockMatrixNativeReader(fs, path).apply(ctx) } @@ -543,7 +619,7 @@ case class BlockMatrixBroadcast( child: BlockMatrixIR, inIndexExpr: IndexedSeq[Int], shape: IndexedSeq[Long], - blockSize: Int + blockSize: Int, ) extends BlockMatrixIR { val blockCostIsLinear: Boolean = child.blockCostIsLinear @@ -555,9 +631,9 @@ case class BlockMatrixBroadcast( val (nRows, nCols) = BlockMatrixIR.tensorShapeToMatrixShape(child) val childMatrixShape = IndexedSeq(nRows, nCols) - assert(inIndexExpr.zipWithIndex.forall({ case (out: Int, in: Int) => + assert(inIndexExpr.zipWithIndex.forall { case (out: Int, in: Int) => !child.typ.shape.contains(in) || childMatrixShape(in) == shape(out) - })) + }) } override lazy val typ: BlockMatrixType = { @@ -572,24 +648,32 @@ case class BlockMatrixBroadcast( BlockMatrixSparsity.dense case IndexedSeq(0) => // broadcast col vector assert(Set(1, shape(0)) == Set(child.typ.nRows, child.typ.nCols)) - BlockMatrixSparsity.constructFromShapeAndFunction(nRowBlocks, nColBlocks)((i: Int, j: Int) => child.typ.hasBlock(0 -> j)) + BlockMatrixSparsity.constructFromShapeAndFunction(nRowBlocks, nColBlocks)( + (i: Int, j: Int) => child.typ.hasBlock(0 -> j) + ) case IndexedSeq(1) => // broadcast row vector assert(Set(1, shape(1)) == Set(child.typ.nRows, child.typ.nCols)) - BlockMatrixSparsity.constructFromShapeAndFunction(nRowBlocks, nColBlocks)((i: Int, j: Int) => 
child.typ.hasBlock(i -> 0)) + BlockMatrixSparsity.constructFromShapeAndFunction(nRowBlocks, nColBlocks)( + (i: Int, j: Int) => child.typ.hasBlock(i -> 0) + ) case IndexedSeq(0, 0) => // diagonal as col vector assert(shape(0) == 1L) assert(shape(1) == java.lang.Math.min(child.typ.nRows, child.typ.nCols)) - BlockMatrixSparsity.constructFromShapeAndFunction(nRowBlocks, nColBlocks)((_, j: Int) => child.typ.hasBlock(j -> j)) + BlockMatrixSparsity.constructFromShapeAndFunction(nRowBlocks, nColBlocks)((_, j: Int) => + child.typ.hasBlock(j -> j) + ) case IndexedSeq(1, 0) => // transpose assert(child.typ.blockSize == blockSize) assert(shape(0) == child.typ.nCols && shape(1) == child.typ.nRows) - BlockMatrixSparsity(child.typ.sparsity.definedBlocks.map(seq => seq.map { case (i, j) => (j, i)})) + BlockMatrixSparsity(child.typ.sparsity.definedBlocks.map(seq => + seq.map { case (i, j) => (j, i) } + )) case IndexedSeq(0, 1) => assert(child.typ.blockSize == blockSize) assert(shape(0) == child.typ.nRows && shape(1) == child.typ.nCols) child.typ.sparsity } - else BlockMatrixSparsity.dense + else BlockMatrixSparsity.dense BlockMatrixType(child.typ.elementType, tensorShape, isRowVector, blockSize, sparsity) } @@ -616,7 +700,7 @@ case class BlockMatrixBroadcast( case IndexedSeq(1) => BlockMatrixIR.checkFitsIntoArray(nRows, nCols) broadcastRowVector(childBm.toBreezeMatrix().data, nRows.toInt, nCols.toInt) - // FIXME: I'm pretty sure this case is broken. + // FIXME: I'm pretty sure this case is broken. case IndexedSeq(0, 0) => BlockMatrixIR.checkFitsIntoArray(nRows, nCols) BlockMatrixIR.toBlockMatrix(nRows.toInt, nCols.toInt, childBm.diagonal(), blockSize) @@ -642,7 +726,7 @@ case class BlockMatrixBroadcast( case class BlockMatrixAgg( child: BlockMatrixIR, - axesToSumOut: IndexedSeq[Int] + axesToSumOut: IndexedSeq[Int], ) extends BlockMatrixIR { val blockCostIsLinear: Boolean = child.blockCostIsLinear @@ -652,7 +736,9 @@ case class BlockMatrixAgg( override lazy val typ: BlockMatrixType = { val matrixShape = BlockMatrixIR.tensorShapeToMatrixShape(child) val matrixShapeArr = Array[Long](matrixShape._1, matrixShape._2) - val shape = IndexedSeq(0, 1).filter(i => !axesToSumOut.contains(i)).map({ i: Int => matrixShapeArr(i) }).toFastSeq + val shape = IndexedSeq(0, 1).filter(i => !axesToSumOut.contains(i)).map({ i: Int => + matrixShapeArr(i) + }).toFastSeq val isRowVector = axesToSumOut == FastSeq(0) val sparsity = if (child.typ.isSparse) { @@ -683,7 +769,8 @@ case class BlockMatrixAgg( val childBm = child.execute(ctx) axesToSumOut match { - case IndexedSeq(0, 1) => BlockMatrixIR.toBlockMatrix(nRows = 1, nCols = 1, Array(childBm.sum()), typ.blockSize) + case IndexedSeq(0, 1) => + BlockMatrixIR.toBlockMatrix(nRows = 1, nCols = 1, Array(childBm.sum()), typ.blockSize) case IndexedSeq(0) => childBm.rowSum() case IndexedSeq(1) => childBm.colSum() } @@ -692,7 +779,7 @@ case class BlockMatrixAgg( case class BlockMatrixFilter( child: BlockMatrixIR, - indices: Array[Array[Long]] + indices: Array[Array[Long]], ) extends BlockMatrixIR { assert(indices.length == 2) @@ -715,9 +802,9 @@ case class BlockMatrixFilter( case (IndexedSeq(numRows, numCols), false) => IndexedSeq(numRows, numCols) } - val matrixShape = indices.zipWithIndex.map({ case (dim, i) => + val matrixShape = indices.zipWithIndex.map { case (dim, i) => if (dim.isEmpty) childMatrixShape(i) else dim.length - }) + } val IndexedSeq(nRows: Long, nCols: Long) = matrixShape.toFastSeq val (tensorShape, isRowVector) = BlockMatrixIR.matrixShapeToTensorShape(nRows, nCols) @@ 
-748,7 +835,10 @@ case class BlockMatrixFilter( case class BlockMatrixDensify(child: BlockMatrixIR) extends BlockMatrixIR { override lazy val typ: BlockMatrixType = BlockMatrixType.dense( child.typ.elementType, - child.typ.nRows, child.typ.nCols, child.typ.blockSize) + child.typ.nRows, + child.typ.nCols, + child.typ.blockSize, + ) def blockCostIsLinear: Boolean = child.blockCostIsLinear @@ -773,48 +863,56 @@ sealed abstract class BlockMatrixSparsifier { //lower <= j - i <= upper case class BandSparsifier(blocksOnly: Boolean, l: Long, u: Long) extends BlockMatrixSparsifier { val typ: Type = TTuple(TInt64, TInt64) + def definedBlocks(childType: BlockMatrixType): BlockMatrixSparsity = { val lowerBlock = java.lang.Math.floorDiv(l, childType.blockSize).toInt val upperBlock = java.lang.Math.floorDiv(u + childType.blockSize - 1, childType.blockSize).toInt - val blocks = (for { j <- 0 until childType.nColBlocks - i <- ((j - upperBlock) max 0) to - ((j - lowerBlock) min (childType.nRowBlocks - 1)) - if (childType.hasBlock(i -> j)) + val blocks = (for { + j <- 0 until childType.nColBlocks + i <- ((j - upperBlock) max 0) to + ((j - lowerBlock) min (childType.nRowBlocks - 1)) + if (childType.hasBlock(i -> j)) } yield (i -> j)).toArray BlockMatrixSparsity(blocks) } - def sparsify(bm: BlockMatrix): BlockMatrix = { + def sparsify(bm: BlockMatrix): BlockMatrix = bm.filterBand(l, u, blocksOnly) - } + def pretty(): String = s"(BandSparsifier ${Pretty.prettyBooleanLiteral(blocksOnly)} $l $u)" } // interval per row, [start, end) -case class RowIntervalSparsifier(blocksOnly: Boolean, starts: IndexedSeq[Long], stops: IndexedSeq[Long]) extends BlockMatrixSparsifier { +case class RowIntervalSparsifier( + blocksOnly: Boolean, + starts: IndexedSeq[Long], + stops: IndexedSeq[Long], +) extends BlockMatrixSparsifier { val typ: Type = TTuple(TArray(TInt64), TArray(TInt64)) def definedBlocks(childType: BlockMatrixType): BlockMatrixSparsity = { - val blockStarts = starts.grouped(childType.blockSize).map(idxs => childType.getBlockIdx(idxs.min)).toArray - val blockStops = stops.grouped(childType.blockSize).map(idxs => childType.getBlockIdx(idxs.max - 1)).toArray + val blockStarts = + starts.grouped(childType.blockSize).map(idxs => childType.getBlockIdx(idxs.min)).toArray + val blockStops = + stops.grouped(childType.blockSize).map(idxs => childType.getBlockIdx(idxs.max - 1)).toArray - BlockMatrixSparsity.constructFromShapeAndFunction(childType.nRowBlocks, childType.nColBlocks) { (i, j) => - blockStarts(i) <= j && blockStops(i) >= j && childType.hasBlock(i -> j) + BlockMatrixSparsity.constructFromShapeAndFunction(childType.nRowBlocks, childType.nColBlocks) { + (i, j) => blockStarts(i) <= j && blockStops(i) >= j && childType.hasBlock(i -> j) } } - def sparsify(bm: BlockMatrix): BlockMatrix = { + def sparsify(bm: BlockMatrix): BlockMatrix = bm.filterRowIntervals(starts.toArray, stops.toArray, blocksOnly) - } def pretty(): String = - s"(RowIntervalSparsifier ${ Pretty.prettyBooleanLiteral(blocksOnly) } ${ starts.mkString("(", " ", ")") } ${ stops.mkString("(", " ", ")") })" + s"(RowIntervalSparsifier ${Pretty.prettyBooleanLiteral(blocksOnly)} ${starts.mkString("(", " ", ")")} ${stops.mkString("(", " ", ")")})" } //rectangle, starts/ends inclusive -case class RectangleSparsifier(rectangles: IndexedSeq[IndexedSeq[Long]]) extends BlockMatrixSparsifier { +case class RectangleSparsifier(rectangles: IndexedSeq[IndexedSeq[Long]]) + extends BlockMatrixSparsifier { val typ: Type = TArray(TInt64) def definedBlocks(childType: 
BlockMatrixType): BlockMatrixSparsity = { @@ -825,20 +923,19 @@ case class RectangleSparsifier(rectangles: IndexedSeq[IndexedSeq[Long]]) extends val ce = childType.getBlockIdx(java.lang.Math.min(colEnd - 1, childType.nCols)) + 1 Array.range(rs, re).flatMap { i => Array.range(cs, ce) - .filter { j => childType.hasBlock(i -> j) } - .map { j => i -> j } + .filter(j => childType.hasBlock(i -> j)) + .map(j => i -> j) } }.distinct BlockMatrixSparsity(definedBlocks) } - def sparsify(bm: BlockMatrix): BlockMatrix = { + def sparsify(bm: BlockMatrix): BlockMatrix = bm.filterRectangles(rectangles.flatten.toArray) - } def pretty(): String = - s"(RectangleSparsifier ${ rectangles.flatten.mkString("(", " ", ")") })" + s"(RectangleSparsifier ${rectangles.flatten.mkString("(", " ", ")")})" } case class PerBlockSparsifier(blocks: IndexedSeq[Int]) extends BlockMatrixSparsifier { @@ -846,11 +943,11 @@ case class PerBlockSparsifier(blocks: IndexedSeq[Int]) extends BlockMatrixSparsi val blockSet = blocks.toSet - override def definedBlocks(childType: BlockMatrixType): BlockMatrixSparsity = { - BlockMatrixSparsity.constructFromShapeAndFunction(childType.nRowBlocks, childType.nColBlocks){ case(i: Int, j: Int) => - blockSet.contains(i + j * childType.nRowBlocks) + override def definedBlocks(childType: BlockMatrixType): BlockMatrixSparsity = + BlockMatrixSparsity.constructFromShapeAndFunction(childType.nRowBlocks, childType.nColBlocks) { + case (i: Int, j: Int) => + blockSet.contains(i + j * childType.nRowBlocks) } - } override def sparsify(bm: BlockMatrix): BlockMatrix = bm.filterBlocks(blocks.toArray) @@ -859,9 +956,10 @@ case class PerBlockSparsifier(blocks: IndexedSeq[Int]) extends BlockMatrixSparsi case class BlockMatrixSparsify( child: BlockMatrixIR, - sparsifier: BlockMatrixSparsifier + sparsifier: BlockMatrixSparsifier, ) extends BlockMatrixIR { - override lazy val typ: BlockMatrixType = child.typ.copy(sparsity=sparsifier.definedBlocks(child.typ)) + override lazy val typ: BlockMatrixType = + child.typ.copy(sparsity = sparsifier.definedBlocks(child.typ)) def blockCostIsLinear: Boolean = child.blockCostIsLinear @@ -876,18 +974,23 @@ case class BlockMatrixSparsify( sparsifier.sparsify(child.execute(ctx)) } -case class BlockMatrixSlice(child: BlockMatrixIR, slices: IndexedSeq[IndexedSeq[Long]]) extends BlockMatrixIR { +case class BlockMatrixSlice(child: BlockMatrixIR, slices: IndexedSeq[IndexedSeq[Long]]) + extends BlockMatrixIR { assert(slices.length == 2) assert(slices.forall(_.length == 3)) val blockCostIsLinear: Boolean = child.blockCostIsLinear - lazy val IndexedSeq(rowBlockDependents: Array[Array[Int]], colBlockDependents: Array[Array[Int]]) = slices.map { case IndexedSeq(start, stop, step) => + lazy val IndexedSeq( + rowBlockDependents: Array[Array[Int]], + colBlockDependents: Array[Array[Int]], + ) = slices.map { case IndexedSeq(start, stop, step) => val size = 1 + (stop - start - 1) / step val nBlocks = BlockMatrixType.numBlocks(size, child.typ.blockSize) Array.tabulate(nBlocks) { blockIdx => val blockStart = start + blockIdx * child.typ.blockSize * step - val blockEnd = java.lang.Math.min(start + ((blockIdx + 1) * child.typ.blockSize - 1) * step, stop) + val blockEnd = + java.lang.Math.min(start + ((blockIdx + 1) * child.typ.blockSize - 1) * step, stop) Array.range(child.typ.getBlockIdx(blockStart), child.typ.getBlockIdx(blockEnd) + 1) } } @@ -900,7 +1003,8 @@ case class BlockMatrixSlice(child: BlockMatrixIR, slices: IndexedSeq[IndexedSeq[ val sparsity = child.typ.sparsity.condense(rowBlockDependents 
-> colBlockDependents) - val (tensorShape, isRowVector) = BlockMatrixIR.matrixShapeToTensorShape(matrixShape(0), matrixShape(1)) + val (tensorShape, isRowVector) = + BlockMatrixIR.matrixShapeToTensorShape(matrixShape(0), matrixShape(1)) BlockMatrixType(child.typ.elementType, tensorShape, isRowVector, child.typ.blockSize, sparsity) } @@ -928,35 +1032,33 @@ case class BlockMatrixSlice(child: BlockMatrixIR, slices: IndexedSeq[IndexedSeq[ } } - private def isFullRange(r: NumericRange[Long], dimLength: Long): Boolean = { + private def isFullRange(r: NumericRange[Long], dimLength: Long): Boolean = r.start == 0 && r.end == dimLength && r.step == 1 - } } case class ValueToBlockMatrix( child: IR, shape: IndexedSeq[Long], - blockSize: Int + blockSize: Int, ) extends BlockMatrixIR { - override def typecheck(): Unit = { - assert(child.typ.isInstanceOf[TArray] || child.typ.isInstanceOf[TNDArray] || child.typ == TFloat64) - } + override def typecheck(): Unit = + assert( + child.typ.isInstanceOf[TArray] || child.typ.isInstanceOf[TNDArray] || child.typ == TFloat64 + ) assert(shape.length == 2) val blockCostIsLinear: Boolean = true - override lazy val typ: BlockMatrixType = { + override lazy val typ: BlockMatrixType = BlockMatrixType.dense(elementType(child.typ), shape(0), shape(1), blockSize) - } - private def elementType(childType: Type): Type = { + private def elementType(childType: Type): Type = childType match { case ndarray: TNDArray => ndarray.elementType case array: TArray => array.elementType case _ => childType } - } lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child) @@ -973,9 +1075,19 @@ case class ValueToBlockMatrix( assert(nRows == 1 && nCols == 1) BlockMatrix.fill(nRows, nCols, scalar, blockSize) case data: IndexedSeq[_] => - BlockMatrixIR.toBlockMatrix(nRows.toInt, nCols.toInt, data.asInstanceOf[IndexedSeq[Double]].toArray, blockSize) + BlockMatrixIR.toBlockMatrix( + nRows.toInt, + nCols.toInt, + data.asInstanceOf[IndexedSeq[Double]].toArray, + blockSize, + ) case ndData: NDArray => - BlockMatrixIR.toBlockMatrix(nRows.toInt, nCols.toInt, ndData.getRowMajorElements().asInstanceOf[IndexedSeq[Double]].toArray, blockSize) + BlockMatrixIR.toBlockMatrix( + nRows.toInt, + nCols.toInt, + ndData.getRowMajorElements().asInstanceOf[IndexedSeq[Double]].toArray, + blockSize, + ) } } } @@ -984,7 +1096,8 @@ case class BlockMatrixRandom( staticUID: Long, gaussian: Boolean, shape: IndexedSeq[Long], - blockSize: Int) extends BlockMatrixIR { + blockSize: Int, +) extends BlockMatrixIR { assert(shape.length == 2) @@ -1000,12 +1113,12 @@ case class BlockMatrixRandom( BlockMatrixRandom(staticUID, gaussian, shape, blockSize) } - override protected[ir] def execute(ctx: ExecuteContext): BlockMatrix = { + override protected[ir] def execute(ctx: ExecuteContext): BlockMatrix = BlockMatrix.random(shape(0), shape(1), blockSize, ctx.rngNonce, staticUID, gaussian) - } } -case class RelationalLetBlockMatrix(name: String, value: IR, body: BlockMatrixIR) extends BlockMatrixIR { +case class RelationalLetBlockMatrix(name: String, value: IR, body: BlockMatrixIR) + extends BlockMatrixIR { override def typ: BlockMatrixType = body.typ def childrenSeq: IndexedSeq[BaseIR] = Array(value, body) diff --git a/hail/src/main/scala/is/hail/expr/ir/BlockMatrixWriter.scala b/hail/src/main/scala/is/hail/expr/ir/BlockMatrixWriter.scala index 6da4b7a9b9e..25c7bbe6059 100644 --- a/hail/src/main/scala/is/hail/expr/ir/BlockMatrixWriter.scala +++ b/hail/src/main/scala/is/hail/expr/ir/BlockMatrixWriter.scala @@ -6,82 +6,124 @@ import 
is.hail.asm4s._ import is.hail.backend.ExecuteContext import is.hail.expr.Nat import is.hail.expr.ir.lowering.{BlockMatrixStage2, LowererUnsupportedOperation} -import is.hail.io.fs.FS import is.hail.io.{StreamBufferSpec, TypedCodecSpec} +import is.hail.io.fs.FS import is.hail.linalg.{BlockMatrix, BlockMatrixMetadata} +import is.hail.types.{BlockMatrixType, TypeWithRequiredness} import is.hail.types.encoded.{EBlockMatrixNDArray, ENumpyBinaryNDArray, EType} import is.hail.types.virtual._ -import is.hail.types.{BlockMatrixType, TypeWithRequiredness} import is.hail.utils._ import is.hail.utils.richUtils.RichDenseMatrixDouble -import org.json4s.{DefaultFormats, Formats, ShortTypeHints, jackson} import java.io.DataOutputStream +import org.json4s.{jackson, DefaultFormats, Formats, ShortTypeHints} + object BlockMatrixWriter { implicit val formats: Formats = new DefaultFormats() { override val typeHints = ShortTypeHints( - List(classOf[BlockMatrixNativeWriter], classOf[BlockMatrixBinaryWriter], classOf[BlockMatrixRectanglesWriter], - classOf[BlockMatrixBinaryMultiWriter], classOf[BlockMatrixTextMultiWriter], - classOf[BlockMatrixPersistWriter], classOf[BlockMatrixNativeMultiWriter]), typeHintFieldName = "name") + List( + classOf[BlockMatrixNativeWriter], + classOf[BlockMatrixBinaryWriter], + classOf[BlockMatrixRectanglesWriter], + classOf[BlockMatrixBinaryMultiWriter], + classOf[BlockMatrixTextMultiWriter], + classOf[BlockMatrixPersistWriter], + classOf[BlockMatrixNativeMultiWriter], + ), + typeHintFieldName = "name", + ) } } - abstract class BlockMatrixWriter { def pathOpt: Option[String] def apply(ctx: ExecuteContext, bm: BlockMatrix): Any def loweredTyp: Type - def lower(ctx: ExecuteContext, s: BlockMatrixStage2, evalCtx: IRBuilder, eltR: TypeWithRequiredness): IR = - throw new LowererUnsupportedOperation(s"unimplemented writer: \n${ this.getClass }") + + def lower( + ctx: ExecuteContext, + s: BlockMatrixStage2, + evalCtx: IRBuilder, + eltR: TypeWithRequiredness, + ): IR = + throw new LowererUnsupportedOperation(s"unimplemented writer: \n${this.getClass}") } case class BlockMatrixNativeWriter( path: String, overwrite: Boolean, forceRowMajor: Boolean, - stageLocally: Boolean) extends BlockMatrixWriter { + stageLocally: Boolean, +) extends BlockMatrixWriter { def pathOpt: Option[String] = Some(path) - def apply(ctx: ExecuteContext, bm: BlockMatrix): Unit = bm.write(ctx, path, overwrite, forceRowMajor, stageLocally) + def apply(ctx: ExecuteContext, bm: BlockMatrix): Unit = + bm.write(ctx, path, overwrite, forceRowMajor, stageLocally) def loweredTyp: Type = TVoid - override def lower(ctx: ExecuteContext, s: BlockMatrixStage2, evalCtx: IRBuilder, eltR: TypeWithRequiredness): IR = { - val etype = EBlockMatrixNDArray(EType.fromTypeAndAnalysis(s.typ.elementType, eltR), encodeRowMajor = forceRowMajor, required = true) + override def lower( + ctx: ExecuteContext, + s: BlockMatrixStage2, + evalCtx: IRBuilder, + eltR: TypeWithRequiredness, + ): IR = { + val etype = EBlockMatrixNDArray( + EType.fromTypeAndAnalysis(s.typ.elementType, eltR), + encodeRowMajor = forceRowMajor, + required = true, + ) val spec = TypedCodecSpec(etype, TNDArray(s.typ.elementType, Nat(2)), BlockMatrix.bufferSpec) val writer = ETypeValueWriter(spec) val paths = s.collectBlocks(evalCtx, "block_matrix_native_writer") { (_, idx, block) => val suffix = strConcat("parts/part-", idx, UUID4()) val filepath = strConcat(s"$path/", suffix) - WriteValue(block, filepath, writer, - if (stageLocally) Some(strConcat(s"${ctx.localTmpdir}/", suffix)) 
else None + WriteValue( + block, + filepath, + writer, + if (stageLocally) Some(strConcat(s"${ctx.localTmpdir}/", suffix)) else None, ) } - RelationalWriter.scoped(path, overwrite, None)(WriteMetadata(paths, BlockMatrixNativeMetadataWriter(path, stageLocally, s.typ))) + RelationalWriter.scoped(path, overwrite, None)(WriteMetadata( + paths, + BlockMatrixNativeMetadataWriter(path, stageLocally, s.typ), + )) } } -case class BlockMatrixNativeMetadataWriter(path: String, stageLocally: Boolean, typ: BlockMatrixType) extends MetadataWriter { - - case class BMMetadataHelper(path: String, blockSize: Int, nRows: Long, nCols: Long, partIdxToBlockIdx: Option[IndexedSeq[Int]]) { +case class BlockMatrixNativeMetadataWriter( + path: String, + stageLocally: Boolean, + typ: BlockMatrixType, +) extends MetadataWriter { + + case class BMMetadataHelper( + path: String, + blockSize: Int, + nRows: Long, + nCols: Long, + partIdxToBlockIdx: Option[IndexedSeq[Int]], + ) { def write(fs: FS, rawPartFiles: Array[String]): Unit = { val partFiles = rawPartFiles.map(_.split('/').last) using(new DataOutputStream(fs.create(s"$path/metadata.json"))) { os => implicit val formats = defaultJSONFormats jackson.Serialization.write( BlockMatrixMetadata(blockSize, nRows, nCols, partIdxToBlockIdx, partFiles), - os) + os, + ) } val nBlocks = partIdxToBlockIdx.map(_.length).getOrElse { BlockMatrixType.numBlocks(nRows, blockSize) * BlockMatrixType.numBlocks(nCols, blockSize) } - assert(nBlocks == partFiles.length, s"$nBlocks vs ${ partFiles.mkString(", ") }") + assert(nBlocks == partFiles.length, s"$nBlocks vs ${partFiles.mkString(", ")}") - info(s"wrote matrix with $nRows ${ plural(nRows, "row") } " + - s"and $nCols ${ plural(nCols, "column") } " + - s"as $nBlocks ${ plural(nBlocks, "block") } " + + info(s"wrote matrix with $nRows ${plural(nRows, "row")} " + + s"and $nCols ${plural(nCols, "column")} " + + s"as $nBlocks ${plural(nBlocks, "block")} " + s"of size $blockSize to $path") } } @@ -91,29 +133,36 @@ case class BlockMatrixNativeMetadataWriter(path: String, stageLocally: Boolean, def writeMetadata( writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, - region: Value[Region]): Unit = { - val metaHelper = BMMetadataHelper(path, typ.blockSize, typ.nRows, typ.nCols, typ.linearizedDefinedBlocks) + region: Value[Region], + ): Unit = { + val metaHelper = + BMMetadataHelper(path, typ.blockSize, typ.nRows, typ.nCols, typ.linearizedDefinedBlocks) - val pc = writeAnnotations.get(cb, "write annotations can't be missing!").asIndexable + val pc = writeAnnotations.getOrFatal(cb, "write annotations can't be missing!").asIndexable val partFiles = cb.newLocal[Array[String]]("partFiles") val n = cb.newLocal[Int]("n", pc.loadLength()) val i = cb.newLocal[Int]("i", 0) cb.assign(partFiles, Code.newArray[String](n)) - cb.while_(i < n, { - val s = pc.loadElement(cb, i).get(cb, "file name can't be missing!").asString - cb += partFiles.update(i, s.loadString(cb)) - cb.assign(i, i + 1) - }) - cb += cb.emb.getObject(metaHelper).invoke[FS, Array[String], Unit]("write", cb.emb.getFS, partFiles) + cb.while_( + i < n, { + val s = pc.loadElement(cb, i).getOrFatal(cb, "file name can't be missing!").asString + cb += partFiles.update(i, s.loadString(cb)) + cb.assign(i, i + 1) + }, + ) + cb += cb.emb.getObject(metaHelper).invoke[FS, Array[String], Unit]( + "write", + cb.emb.getFS, + partFiles, + ) } def loweredTyp: Type = TVoid } - - case class BlockMatrixBinaryWriter(path: String) extends BlockMatrixWriter { def pathOpt: Option[String] = Some(path) + def 
apply(ctx: ExecuteContext, bm: BlockMatrix): String = { RichDenseMatrixDouble.exportToDoubles(ctx.fs, path, bm.toBreezeMatrix(), forceRowMajor = true) path @@ -121,7 +170,12 @@ case class BlockMatrixBinaryWriter(path: String) extends BlockMatrixWriter { def loweredTyp: Type = TString - override def lower(ctx: ExecuteContext, s: BlockMatrixStage2, evalCtx: IRBuilder, eltR: TypeWithRequiredness): IR = { + override def lower( + ctx: ExecuteContext, + s: BlockMatrixStage2, + evalCtx: IRBuilder, + eltR: TypeWithRequiredness, + ): IR = { val nd = s.collectLocal(evalCtx, "block_matrix_binary_writer") // FIXME remove numpy encoder @@ -134,8 +188,10 @@ case class BlockMatrixBinaryWriter(path: String) extends BlockMatrixWriter { case class BlockMatrixPersistWriter(id: String, storageLevel: String) extends BlockMatrixWriter { def pathOpt: Option[String] = None + def apply(ctx: ExecuteContext, bm: BlockMatrix): Unit = HailContext.backend.persist(ctx.backendContext, id, bm, storageLevel) + def loweredTyp: Type = TVoid } @@ -143,13 +199,13 @@ case class BlockMatrixRectanglesWriter( path: String, rectangles: Array[Array[Long]], delimiter: String, - binary: Boolean) extends BlockMatrixWriter { + binary: Boolean, +) extends BlockMatrixWriter { def pathOpt: Option[String] = Some(path) - def apply(ctx: ExecuteContext, bm: BlockMatrix): Unit = { + def apply(ctx: ExecuteContext, bm: BlockMatrix): Unit = bm.exportRectangles(ctx, path, rectangles, delimiter, binary) - } def loweredTyp: Type = TVoid } @@ -160,7 +216,8 @@ abstract class BlockMatrixMultiWriter { case class BlockMatrixBinaryMultiWriter( prefix: String, - overwrite: Boolean) extends BlockMatrixMultiWriter { + overwrite: Boolean, +) extends BlockMatrixMultiWriter { def apply(ctx: ExecuteContext, bms: IndexedSeq[BlockMatrix]): Unit = BlockMatrix.binaryWriteBlockMatrices(ctx.fs, bms, prefix, overwrite) @@ -175,10 +232,21 @@ case class BlockMatrixTextMultiWriter( header: Option[String], addIndex: Boolean, compression: Option[String], - customFilenames: Option[Array[String]]) extends BlockMatrixMultiWriter { + customFilenames: Option[Array[String]], +) extends BlockMatrixMultiWriter { def apply(ctx: ExecuteContext, bms: IndexedSeq[BlockMatrix]): Unit = - BlockMatrix.exportBlockMatrices(ctx.fs, bms, prefix, overwrite, delimiter, header, addIndex, compression, customFilenames) + BlockMatrix.exportBlockMatrices( + ctx.fs, + bms, + prefix, + overwrite, + delimiter, + header, + addIndex, + compression, + customFilenames, + ) def loweredTyp: Type = TVoid } @@ -186,11 +254,11 @@ case class BlockMatrixTextMultiWriter( case class BlockMatrixNativeMultiWriter( prefix: String, overwrite: Boolean, - forceRowMajor: Boolean) extends BlockMatrixMultiWriter { + forceRowMajor: Boolean, +) extends BlockMatrixMultiWriter { - def apply(ctx: ExecuteContext, bms: IndexedSeq[BlockMatrix]): Unit = { + def apply(ctx: ExecuteContext, bms: IndexedSeq[BlockMatrix]): Unit = BlockMatrix.writeBlockMatrices(ctx, bms, prefix, overwrite, forceRowMajor) - } def loweredTyp: Type = TVoid } diff --git a/hail/src/main/scala/is/hail/expr/ir/Casts.scala b/hail/src/main/scala/is/hail/expr/ir/Casts.scala index 79cda4ab471..9bd368e645c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Casts.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Casts.scala @@ -1,31 +1,53 @@ package is.hail.expr.ir import is.hail.asm4s._ -import is.hail.types._ -import is.hail.types.physical.stypes.{SCode, SValue} +import is.hail.types.physical.stypes.SValue import is.hail.types.physical.stypes.interfaces._ import 
is.hail.types.virtual._ -import scala.language.existentials - object Casts { private val casts: Map[(Type, Type), (EmitCodeBuilder, SValue) => SValue] = Map( (TInt32, TInt32) -> ((cb: EmitCodeBuilder, x: SValue) => x), - (TInt32, TInt64) -> ((cb: EmitCodeBuilder, x: SValue) => primitive(cb.memoize(x.asInt.value.toL))), - (TInt32, TFloat32) -> ((cb: EmitCodeBuilder, x: SValue) => primitive(cb.memoize(x.asInt.value.toF))), - (TInt32, TFloat64) -> ((cb: EmitCodeBuilder, x: SValue) => primitive(cb.memoize(x.asInt.value.toD))), - (TInt64, TInt32) -> ((cb: EmitCodeBuilder, x: SValue) => primitive(cb.memoize(x.asLong.value.toI))), + (TInt32, TInt64) -> ((cb: EmitCodeBuilder, x: SValue) => + primitive(cb.memoize(x.asInt.value.toL)) + ), + (TInt32, TFloat32) -> ((cb: EmitCodeBuilder, x: SValue) => + primitive(cb.memoize(x.asInt.value.toF)) + ), + (TInt32, TFloat64) -> ((cb: EmitCodeBuilder, x: SValue) => + primitive(cb.memoize(x.asInt.value.toD)) + ), + (TInt64, TInt32) -> ((cb: EmitCodeBuilder, x: SValue) => + primitive(cb.memoize(x.asLong.value.toI)) + ), (TInt64, TInt64) -> ((cb: EmitCodeBuilder, x: SValue) => x), - (TInt64, TFloat32) -> ((cb: EmitCodeBuilder, x: SValue) => primitive(cb.memoize(x.asLong.value.toF))), - (TInt64, TFloat64) -> ((cb: EmitCodeBuilder, x: SValue) => primitive(cb.memoize(x.asLong.value.toD))), - (TFloat32, TInt32) -> ((cb: EmitCodeBuilder, x: SValue) => primitive(cb.memoize(x.asFloat.value.toI))), - (TFloat32, TInt64) -> ((cb: EmitCodeBuilder, x: SValue) => primitive(cb.memoize(x.asFloat.value.toL))), + (TInt64, TFloat32) -> ((cb: EmitCodeBuilder, x: SValue) => + primitive(cb.memoize(x.asLong.value.toF)) + ), + (TInt64, TFloat64) -> ((cb: EmitCodeBuilder, x: SValue) => + primitive(cb.memoize(x.asLong.value.toD)) + ), + (TFloat32, TInt32) -> ((cb: EmitCodeBuilder, x: SValue) => + primitive(cb.memoize(x.asFloat.value.toI)) + ), + (TFloat32, TInt64) -> ((cb: EmitCodeBuilder, x: SValue) => + primitive(cb.memoize(x.asFloat.value.toL)) + ), (TFloat32, TFloat32) -> ((cb: EmitCodeBuilder, x: SValue) => x), - (TFloat32, TFloat64) -> ((cb: EmitCodeBuilder, x: SValue) => primitive(cb.memoize(x.asFloat.value.toD))), - (TFloat64, TInt32) -> ((cb: EmitCodeBuilder, x: SValue) => primitive(cb.memoize(x.asDouble.value.toI))), - (TFloat64, TInt64) -> ((cb: EmitCodeBuilder, x: SValue) => primitive(cb.memoize(x.asDouble.value.toL))), - (TFloat64, TFloat32) -> ((cb: EmitCodeBuilder, x: SValue) => primitive(cb.memoize(x.asDouble.value.toF))), - (TFloat64, TFloat64) -> ((cb: EmitCodeBuilder, x: SValue) => x)) + (TFloat32, TFloat64) -> ((cb: EmitCodeBuilder, x: SValue) => + primitive(cb.memoize(x.asFloat.value.toD)) + ), + (TFloat64, TInt32) -> ((cb: EmitCodeBuilder, x: SValue) => + primitive(cb.memoize(x.asDouble.value.toI)) + ), + (TFloat64, TInt64) -> ((cb: EmitCodeBuilder, x: SValue) => + primitive(cb.memoize(x.asDouble.value.toL)) + ), + (TFloat64, TFloat32) -> ((cb: EmitCodeBuilder, x: SValue) => + primitive(cb.memoize(x.asDouble.value.toF)) + ), + (TFloat64, TFloat64) -> ((cb: EmitCodeBuilder, x: SValue) => x), + ) def get(from: Type, to: Type): (EmitCodeBuilder, SValue) => SValue = casts(from -> to) diff --git a/hail/src/main/scala/is/hail/expr/ir/Children.scala b/hail/src/main/scala/is/hail/expr/ir/Children.scala index 2d486940b09..e69e46d46b1 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Children.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Children.scala @@ -6,58 +6,56 @@ object Children { private val none: IndexedSeq[BaseIR] = Array.empty[BaseIR] def apply(x: IR): 
IndexedSeq[BaseIR] = x match { - case I32(x) => none - case I64(x) => none - case F32(x) => none - case F64(x) => none - case Str(x) => none + case I32(_) => none + case I64(_) => none + case F32(_) => none + case F64(_) => none + case Str(_) => none case UUID4(_) => none case True() => none case False() => none case Literal(_, _) => none case EncodedLiteral(_, _) => none case Void() => none - case Cast(v, typ) => + case Cast(v, _) => Array(v) - case CastRename(v, typ) => + case CastRename(v, _) => Array(v) - case NA(typ) => none + case NA(_) => none case IsNA(value) => Array(value) case Coalesce(values) => values.toFastSeq case Consume(value) => FastSeq(value) case If(cond, cnsq, altr) => Array(cond, cnsq, altr) - case s@Switch(x, default, cases) => + case s @ Switch(x, default, cases) => val children = Array.ofDim[BaseIR](s.size) children(0) = x children(1) = default for (i <- cases.indices) children(2 + i) = cases(i) children - case Let(bindings, body) => + case Block(bindings, body) => val children = Array.ofDim[BaseIR](x.size) - for (i <- bindings.indices) children(i) = bindings(i)._2 + for (i <- bindings.indices) children(i) = bindings(i).value children(bindings.size) = body children - case RelationalLet(name, value, body) => - Array(value, body) - case AggLet(name, value, body, _) => + case RelationalLet(_, value, body) => Array(value, body) case TailLoop(_, args, _, body) => args.map(_._2).toFastSeq :+ body case Recur(_, args, _) => args.toFastSeq - case Ref(name, typ) => + case Ref(_, _) => none case RelationalRef(_, _) => none - case ApplyBinaryPrimOp(op, l, r) => + case ApplyBinaryPrimOp(_, l, r) => Array(l, r) - case ApplyUnaryPrimOp(op, x) => + case ApplyUnaryPrimOp(_, x) => Array(x) - case ApplyComparisonOp(op, l, r) => + case ApplyComparisonOp(_, l, r) => Array(l, r) - case MakeArray(args, typ) => + case MakeArray(args, _) => args.toFastSeq case MakeStream(args, _, _) => args.toFastSeq @@ -121,9 +119,9 @@ object Children { Array(a, size) case StreamGroupByKey(a, _, _) => Array(a) - case StreamMap(a, name, body) => + case StreamMap(a, _, body) => Array(a, body) - case StreamZip(as, names, body, _, _) => + case StreamZip(as, _, body, _, _) => as :+ body case StreamZipJoin(as, _, _, _, joinF) => as :+ joinF @@ -131,27 +129,27 @@ object Children { Array(contexts, makeProducer, joinF) case StreamMultiMerge(as, _) => as - case StreamFilter(a, name, cond) => + case StreamFilter(a, _, cond) => Array(a, cond) - case StreamTakeWhile(a, name, cond) => + case StreamTakeWhile(a, _, cond) => Array(a, cond) - case StreamDropWhile(a, name, cond) => + case StreamDropWhile(a, _, cond) => Array(a, cond) - case StreamFlatMap(a, name, body) => + case StreamFlatMap(a, _, body) => Array(a, body) - case StreamFold(a, zero, accumName, valueName, body) => + case StreamFold(a, zero, _, _, body) => Array(a, zero, body) - case StreamFold2(a, accum, valueName, seq, result) => + case StreamFold2(a, accum, _, seq, result) => Array(a) ++ accum.map(_._2) ++ seq ++ Array(result) - case StreamScan(a, zero, accumName, valueName, body) => + case StreamScan(a, zero, _, _, body) => Array(a, zero, body) - case StreamJoinRightDistinct(left, right, lKey, rKey, l, r, join, joinType) => + case StreamJoinRightDistinct(left, right, _, _, _, _, join, _) => Array(left, right, join) - case StreamFor(a, valueName, body) => + case StreamFor(a, _, body) => Array(a, body) - case StreamAgg(a, name, query) => + case StreamAgg(a, _, query) => Array(a, query) - case StreamAggScan(a, name, query) => + case StreamAggScan(a, _, query) 
=> Array(a, query)
case StreamBufferedAggregate(streamChild, initAggs, newKey, seqOps, _, _, _) =>
Array(streamChild, initAggs, newKey, seqOps)
@@ -183,7 +181,7 @@ object Children {
Array(nd)
case NDArrayEigh(nd, _, _) =>
Array(nd)
- case NDArrayInv(nd, errorID) =>
+ case NDArrayInv(nd, _) =>
Array(nd)
case NDArrayWrite(nd, path) =>
Array(nd, path)
@@ -193,10 +191,11 @@ object Children {
Array(array, aggBody)
case AggGroupBy(key, aggIR, _) =>
Array(key, aggIR)
- case AggArrayPerElement(a, _, _, aggBody, knownLength, _) => Array(a, aggBody) ++ knownLength.toArray[IR]
+ case AggArrayPerElement(a, _, _, aggBody, knownLength, _) =>
+ Array(a, aggBody) ++ knownLength.toArray[IR]
case MakeStruct(fields) =>
fields.map(_._2).toFastSeq
- case SelectFields(old, fields) =>
+ case SelectFields(old, _) =>
Array(old)
case InsertFields(old, fields, _) =>
(old +: fields.map(_._2)).toFastSeq
@@ -209,23 +208,21 @@ object Children {
case InitFromSerializedValue(_, value, _) =>
Array(value)
case SerializeAggs(_, _, _, _) => none
case DeserializeAggs(_, _, _, _) => none
- case Begin(xs) =>
- xs
- case ApplyAggOp(initOpArgs, seqOpArgs, aggSig) =>
+ case ApplyAggOp(initOpArgs, seqOpArgs, _) =>
initOpArgs ++ seqOpArgs
- case ApplyScanOp(initOpArgs, seqOpArgs, aggSig) =>
+ case ApplyScanOp(initOpArgs, seqOpArgs, _) =>
initOpArgs ++ seqOpArgs
- case AggFold(zero, seqOp, combOp, elementName, accumName, _) =>
+ case AggFold(zero, seqOp, combOp, _, _, _) =>
Array(zero, seqOp, combOp)
- case GetField(o, name) =>
+ case GetField(o, _) =>
Array(o)
case MakeTuple(fields) =>
fields.map(_._2).toFastSeq
- case GetTupleElement(o, idx) =>
+ case GetTupleElement(o, _) =>
Array(o)
- case In(i, typ) =>
+ case In(_, _) =>
none
- case Die(message, typ, errorId) =>
+ case Die(message, _, _) =>
Array(message)
case Trap(child) => Array(child)
case ConsoleLog(message, result) =>
@@ -257,7 +254,8 @@ object Children {
case BlockMatrixCollect(child) => Array(child)
case BlockMatrixWrite(child, _) => Array(child)
case BlockMatrixMultiWrite(blockMatrices, _) => blockMatrices
- case CollectDistributedArray(ctxs, globals, _, _, body, dynamicID, _, _) => Array(ctxs, globals, body, dynamicID)
+ case CollectDistributedArray(ctxs, globals, _, _, body, dynamicID, _, _) =>
+ Array(ctxs, globals, body, dynamicID)
case ReadPartition(path, _, _) => Array(path)
case WritePartition(stream, ctx, _) => Array(stream, ctx)
case WriteMetadata(writeAnnotations, _) => Array(writeAnnotations)
diff --git a/hail/src/main/scala/is/hail/expr/ir/ComparisonOp.scala b/hail/src/main/scala/is/hail/expr/ir/ComparisonOp.scala
index f9d83ebb980..03428d3da06 100644
--- a/hail/src/main/scala/is/hail/expr/ir/ComparisonOp.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/ComparisonOp.scala
@@ -1,12 +1,10 @@
package is.hail.expr.ir

-import is.hail.asm4s.Code
-import is.hail.expr.ir.orderings.{CodeOrdering, StructOrdering}
+import is.hail.expr.ir.orderings.CodeOrdering
import is.hail.expr.ir.orderings.CodeOrdering.F
-import is.hail.types.physical.PType
import is.hail.types.physical.stypes.SType
import is.hail.types.physical.stypes.interfaces.SBaseStruct
-import is.hail.types.virtual.{TStruct, Type}
+import is.hail.types.virtual.Type

object ComparisonOp {
@@ -71,7 +69,9 @@ sealed trait ComparisonOp[ReturnType] {
def t2: Type
val op: CodeOrdering.Op
val strict: Boolean = true
- def codeOrdering(ecb: EmitClassBuilder[_], t1p: SType, t2p: SType): CodeOrdering.F[op.ReturnType] = {
+
+ def codeOrdering(ecb: EmitClassBuilder[_], t1p: SType, t2p: SType)
+ : CodeOrdering.F[op.ReturnType] = {
ComparisonOp.checkCompatible(t1p.virtualType, t2p.virtualType)
ecb.getOrderingFunction(t1p, t2p, op).asInstanceOf[CodeOrdering.F[op.ReturnType]]
}
@@ -85,49 +85,66 @@ case class GT(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
val op: CodeOrdering.Op = CodeOrdering.Gt()
override def copy(t1: Type = t1, t2: Type = t2): GT = GT(t1, t2)
}
+
object GT { def apply(typ: Type): GT = GT(typ, typ) }
+
case class GTEQ(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
val op: CodeOrdering.Op = CodeOrdering.Gteq()
override def copy(t1: Type = t1, t2: Type = t2): GTEQ = GTEQ(t1, t2)
}
+
object GTEQ { def apply(typ: Type): GTEQ = GTEQ(typ, typ) }
+
case class LTEQ(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
val op: CodeOrdering.Op = CodeOrdering.Lteq()
override def copy(t1: Type = t1, t2: Type = t2): LTEQ = LTEQ(t1, t2)
}
+
object LTEQ { def apply(typ: Type): LTEQ = LTEQ(typ, typ) }
+
case class LT(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
val op: CodeOrdering.Op = CodeOrdering.Lt()
override def copy(t1: Type = t1, t2: Type = t2): LT = LT(t1, t2)
}
+
object LT { def apply(typ: Type): LT = LT(typ, typ) }
+
case class EQ(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
val op: CodeOrdering.Op = CodeOrdering.Equiv()
override def copy(t1: Type = t1, t2: Type = t2): EQ = EQ(t1, t2)
}
+
object EQ { def apply(typ: Type): EQ = EQ(typ, typ) }
+
case class NEQ(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
val op: CodeOrdering.Op = CodeOrdering.Neq()
override def copy(t1: Type = t1, t2: Type = t2): NEQ = NEQ(t1, t2)
}
+
object NEQ { def apply(typ: Type): NEQ = NEQ(typ, typ) }
+
case class EQWithNA(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
val op: CodeOrdering.Op = CodeOrdering.Equiv()
override val strict: Boolean = false
override def copy(t1: Type = t1, t2: Type = t2): EQWithNA = EQWithNA(t1, t2)
}
+
object EQWithNA { def apply(typ: Type): EQWithNA = EQWithNA(typ, typ) }
+
case class NEQWithNA(t1: Type, t2: Type) extends ComparisonOp[Boolean] {
val op: CodeOrdering.Op = CodeOrdering.Neq()
override val strict: Boolean = false
override def copy(t1: Type = t1, t2: Type = t2): NEQWithNA = NEQWithNA(t1, t2)
}
+
object NEQWithNA { def apply(typ: Type): NEQWithNA = NEQWithNA(typ, typ) }
+
case class Compare(t1: Type, t2: Type) extends ComparisonOp[Int] {
override val strict: Boolean = false
val op: CodeOrdering.Op = CodeOrdering.Compare()
override def copy(t1: Type = t1, t2: Type = t2): Compare = Compare(t1, t2)
}
+
object Compare { def apply(typ: Type): Compare = Compare(typ, typ) }

trait StructComparisonOp[T] extends ComparisonOp[T] {
@@ -135,40 +152,62 @@ trait StructComparisonOp[T] extends ComparisonOp[T] {
override def codeOrdering(ecb: EmitClassBuilder[_], t1: SType, t2: SType): F[op.ReturnType] = {
ComparisonOp.checkCompatible(t1.virtualType, t2.virtualType)
- ecb.getStructOrderingFunction(t1.asInstanceOf[SBaseStruct], t2.asInstanceOf[SBaseStruct], sortFields, op).asInstanceOf[CodeOrdering.F[op.ReturnType]]
+ ecb.getStructOrderingFunction(
+ t1.asInstanceOf[SBaseStruct],
+ t2.asInstanceOf[SBaseStruct],
+ sortFields,
+ op,
+ ).asInstanceOf[CodeOrdering.F[op.ReturnType]]
}
}

-case class StructCompare(t1: Type, t2: Type, sortFields: Array[SortField]) extends StructComparisonOp[Int] {
+case class StructCompare(t1: Type, t2: Type, sortFields: Array[SortField])
+ extends StructComparisonOp[Int] {
val op: CodeOrdering.Op = CodeOrdering.StructCompare()
override val strict: Boolean = false
override def copy(t1: Type = t1, t2: Type = t2): StructCompare = StructCompare(t1, t2, sortFields)
}

-case class StructLT(t1: Type, t2: Type, sortFields: Array[SortField]) extends StructComparisonOp[Boolean] {
+case class StructLT(t1: Type, t2: Type, sortFields: Array[SortField])
+ extends StructComparisonOp[Boolean] {
val op: CodeOrdering.Op = CodeOrdering.StructLt()
override def copy(t1: Type = t1, t2: Type = t2): StructLT = StructLT(t1, t2, sortFields)
}

-object StructLT { def apply(typ: Type, sortFields: IndexedSeq[SortField]): StructLT = StructLT(typ, typ, sortFields.toArray) }
+object StructLT {
+ def apply(typ: Type, sortFields: IndexedSeq[SortField]): StructLT =
+ StructLT(typ, typ, sortFields.toArray)
+}

-case class StructLTEQ(t1: Type, t2: Type, sortFields: Array[SortField]) extends StructComparisonOp[Boolean] {
+case class StructLTEQ(t1: Type, t2: Type, sortFields: Array[SortField])
+ extends StructComparisonOp[Boolean] {
val op: CodeOrdering.Op = CodeOrdering.StructLteq()
override def copy(t1: Type = t1, t2: Type = t2): StructLTEQ = StructLTEQ(t1, t2, sortFields)
}

-object StructLTEQ { def apply(typ: Type, sortFields: IndexedSeq[SortField]): StructLTEQ = StructLTEQ(typ, typ, sortFields.toArray) }
+object StructLTEQ {
+ def apply(typ: Type, sortFields: IndexedSeq[SortField]): StructLTEQ =
+ StructLTEQ(typ, typ, sortFields.toArray)
+}

-case class StructGT(t1: Type, t2: Type, sortFields: Array[SortField]) extends StructComparisonOp[Boolean] {
+case class StructGT(t1: Type, t2: Type, sortFields: Array[SortField])
+ extends StructComparisonOp[Boolean] {
val op: CodeOrdering.Op = CodeOrdering.StructGt()
override def copy(t1: Type = t1, t2: Type = t2): StructGT = StructGT(t1, t2, sortFields)
}

-object StructGT { def apply(typ: Type, sortFields: IndexedSeq[SortField]): StructGT = StructGT(typ, typ, sortFields.toArray) }
+object StructGT {
+ def apply(typ: Type, sortFields: IndexedSeq[SortField]): StructGT =
+ StructGT(typ, typ, sortFields.toArray)
+}

-case class StructGTEQ(t1: Type, t2: Type, sortFields: Array[SortField]) extends StructComparisonOp[Boolean] {
+case class StructGTEQ(t1: Type, t2: Type, sortFields: Array[SortField])
+ extends StructComparisonOp[Boolean] {
val op: CodeOrdering.Op = CodeOrdering.StructGteq()
override def copy(t1: Type = t1, t2: Type = t2): StructGTEQ = StructGTEQ(t1, t2, sortFields)
}

-object StructGTEQ { def apply(typ: Type, sortFields: IndexedSeq[SortField]): StructGTEQ = StructGTEQ(typ, typ, sortFields.toArray) }
+object StructGTEQ {
+ def apply(typ: Type, sortFields: IndexedSeq[SortField]): StructGTEQ =
+ StructGTEQ(typ, typ, sortFields.toArray)
+}
diff --git a/hail/src/main/scala/is/hail/expr/ir/Compilable.scala b/hail/src/main/scala/is/hail/expr/ir/Compilable.scala
index 1e54da05e98..457525c01de 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Compilable.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Compilable.scala
@@ -62,6 +62,7 @@ object Emittable {
case _: AggExplode => true
case _ => false
}
+
def apply(ir: IR): Boolean = ir match {
case x if isNonEmittableAgg(x) => false
case _: ApplyIR => false
diff --git a/hail/src/main/scala/is/hail/expr/ir/Compile.scala b/hail/src/main/scala/is/hail/expr/ir/Compile.scala
index 56e127af9b4..304dcc0bc0c 100644
--- a/hail/src/main/scala/is/hail/expr/ir/Compile.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/Compile.scala
@@ -8,66 +8,83 @@ import is.hail.expr.ir.lowering.LoweringPipeline
import is.hail.expr.ir.streams.EmitStream
import is.hail.io.fs.FS
import is.hail.rvd.RVDContext
-import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SStream}
-import
is.hail.types.physical.stypes.{PTypeReferenceSingleCodeType, SingleCodeType, StreamSingleCodeType} import is.hail.types.physical.{PStruct, PType} +import is.hail.types.physical.stypes.{ + PTypeReferenceSingleCodeType, SingleCodeType, StreamSingleCodeType, +} +import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SStream} import is.hail.types.virtual.Type import is.hail.utils._ import java.io.PrintWriter -case class CodeCacheKey(aggSigs: IndexedSeq[AggStateSig], args: Seq[(String, EmitParamType)], body: IR) - -case class CompiledFunction[T](typ: Option[SingleCodeType], f: (HailClassLoader, FS, HailTaskContext, Region) => T) { - def tuple: (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => T) = (typ, f) +case class CodeCacheKey( + aggSigs: IndexedSeq[AggStateSig], + args: Seq[(String, EmitParamType)], + body: IR, +) + +case class CompiledFunction[T]( + typ: Option[SingleCodeType], + f: (HailClassLoader, FS, HailTaskContext, Region) => T, +) { + def tuple: (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => T) = + (typ, f) } object Compile { def apply[F: TypeInfo]( ctx: ExecuteContext, params: IndexedSeq[(String, EmitParamType)], - expectedCodeParamTypes: IndexedSeq[TypeInfo[_]], expectedCodeReturnType: TypeInfo[_], + expectedCodeParamTypes: IndexedSeq[TypeInfo[_]], + expectedCodeReturnType: TypeInfo[_], body: IR, optimize: Boolean = true, - print: Option[PrintWriter] = None + print: Option[PrintWriter] = None, ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) = { - val normalizedBody = new NormalizeNames(_.toString)(ctx, body, - Env(params.map { case (n, _) => n -> n }: _*) - ) - val k = CodeCacheKey(FastSeq[AggStateSig](), params.map { case (n, pt) => (n, pt) }, normalizedBody) + val normalizedBody = + new NormalizeNames(_.toString)(ctx, body, Env(params.map { case (n, _) => n -> n }: _*)) + val k = + CodeCacheKey(FastSeq[AggStateSig](), params.map { case (n, pt) => (n, pt) }, normalizedBody) (ctx.backend.lookupOrCompileCachedFunction[F](k) { var ir = body - ir = Subst(ir, BindingEnv(params - .zipWithIndex - .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) })) + ir = Subst( + ir, + BindingEnv(params + .zipWithIndex + .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }), + ) ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx) TypeCheck(ctx, ir, BindingEnv.empty) val returnParam = CodeParamType(SingleCodeType.typeInfoFromType(ir.typ)) - val fb = EmitFunctionBuilder[F](ctx, "Compiled", + val fb = EmitFunctionBuilder[F]( + ctx, + "Compiled", CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => pt - }, returnParam, Some("Emit.scala")) - - /* - { - def visit(x: IR): Unit = { - println(f"${ System.identityHashCode(x) }%08x ${ x.getClass.getSimpleName } ${ x.pType }") - Children(x).foreach { - case c: IR => visit(c) - } - } - - visit(ir) - } - */ - - assert(fb.mb.parameterTypeInfo == expectedCodeParamTypes, s"expected $expectedCodeParamTypes, got ${ fb.mb.parameterTypeInfo }") - assert(fb.mb.returnTypeInfo == expectedCodeReturnType, s"expected $expectedCodeReturnType, got ${ fb.mb.returnTypeInfo }") + }, + returnParam, + Some("Emit.scala"), + ) + + /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${ + * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) } } + * + * visit(ir) } */ + + assert( + fb.mb.parameterTypeInfo == expectedCodeParamTypes, + 
s"expected $expectedCodeParamTypes, got ${fb.mb.parameterTypeInfo}", + ) + assert( + fb.mb.returnTypeInfo == expectedCodeReturnType, + s"expected $expectedCodeReturnType, got ${fb.mb.returnTypeInfo}", + ) val emitContext = EmitContext.analyze(ctx, ir) val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length) @@ -81,46 +98,60 @@ object CompileWithAggregators { ctx: ExecuteContext, aggSigs: Array[AggStateSig], params: IndexedSeq[(String, EmitParamType)], - expectedCodeParamTypes: IndexedSeq[TypeInfo[_]], expectedCodeReturnType: TypeInfo[_], + expectedCodeParamTypes: IndexedSeq[TypeInfo[_]], + expectedCodeReturnType: TypeInfo[_], body: IR, - optimize: Boolean = true - ): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion)) = { - val normalizedBody = new NormalizeNames(_.toString)(ctx, body, - Env(params.map { case (n, _) => n -> n }: _*) - ) + optimize: Boolean = true, + ): ( + Option[SingleCodeType], + (HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion), + ) = { + val normalizedBody = + new NormalizeNames(_.toString)(ctx, body, Env(params.map { case (n, _) => n -> n }: _*)) val k = CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody) (ctx.backend.lookupOrCompileCachedFunction[F with FunctionWithAggRegion](k) { var ir = body - ir = Subst(ir, BindingEnv(params - .zipWithIndex - .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) })) + ir = Subst( + ir, + BindingEnv(params + .zipWithIndex + .foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }), + ) ir = LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx) - TypeCheck(ctx, ir, BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType }))) + TypeCheck( + ctx, + ir, + BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType })), + ) - val fb = EmitFunctionBuilder[F](ctx, "CompiledWithAggs", + val fb = EmitFunctionBuilder[F]( + ctx, + "CompiledWithAggs", CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => pt }, - SingleCodeType.typeInfoFromType(ir.typ), Some("Emit.scala")) - - /* - { - def visit(x: IR): Unit = { - println(f"${ System.identityHashCode(x) }%08x ${ x.getClass.getSimpleName } ${ x.pType }") - Children(x).foreach { - case c: IR => visit(c) - } - } + SingleCodeType.typeInfoFromType(ir.typ), + Some("Emit.scala"), + ) - visit(ir) - } - */ + /* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${ + * x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) } } + * + * visit(ir) } */ val emitContext = EmitContext.analyze(ctx, ir) val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, Some(aggSigs)) val f = fb.resultWithIndex() - CompiledFunction(rt, f.asInstanceOf[(HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion)]) + CompiledFunction( + rt, + f.asInstanceOf[( + HailClassLoader, + FS, + HailTaskContext, + Region, + ) => (F with FunctionWithAggRegion)], + ) }).tuple } } @@ -143,7 +174,7 @@ object CompileIterator { def setRegions(outerRegion: Region, eltRegion: Region): Unit } - private abstract class LongIteratorWrapper extends Iterator[java.lang.Long] { + abstract private class LongIteratorWrapper extends Iterator[java.lang.Long] { def step(): Boolean protected val stepFunction: StepFunctionBase @@ -166,18 +197,31 @@ object CompileIterator { } } - private def compileStepper[F >: 
Null <: StepFunctionBase : TypeInfo]( + private def compileStepper[F >: Null <: StepFunctionBase: TypeInfo]( ctx: ExecuteContext, body: IR, argTypeInfo: Array[ParamType], - printWriter: Option[PrintWriter] + printWriter: Option[PrintWriter], ): (PType, (HailClassLoader, FS, HailTaskContext, Region) => F) = { - val fb = EmitFunctionBuilder.apply[F](ctx, s"stream_${body.getClass.getSimpleName}", argTypeInfo.toFastSeq, CodeParamType(BooleanInfo), Some("Emit.scala")) + val fb = EmitFunctionBuilder.apply[F]( + ctx, + s"stream_${body.getClass.getSimpleName}", + argTypeInfo.toFastSeq, + CodeParamType(BooleanInfo), + Some("Emit.scala"), + ) val outerRegionField = fb.genFieldThisRef[Region]("outerRegion") val eltRegionField = fb.genFieldThisRef[Region]("eltRegion") - val setF = fb.newEmitMethod("setRegions", FastSeq(CodeParamType(typeInfo[Region]), CodeParamType(typeInfo[Region])), CodeParamType(typeInfo[Unit])) - setF.emit(Code(outerRegionField := setF.getCodeParam[Region](1), eltRegionField := setF.getCodeParam[Region](2))) + val setF = fb.newEmitMethod( + "setRegions", + FastSeq(CodeParamType(typeInfo[Region]), CodeParamType(typeInfo[Region])), + CodeParamType(typeInfo[Unit]), + ) + setF.emit(Code( + outerRegionField := setF.getCodeParam[Region](1), + eltRegionField := setF.getCodeParam[Region](2), + )) val stepF = fb.apply_method val stepFECB = stepF.ecb @@ -194,8 +238,15 @@ object CompileIterator { val emitContext = EmitContext.analyze(ctx, ir) val emitter = new Emit(emitContext, stepFECB) - val env = EmitEnv(Env.empty, argTypeInfo.indices.filter(i => argTypeInfo(i).isInstanceOf[EmitParamType]).map(i => stepF.getEmitParam(cb, i + 1))) - val optStream = EmitCode.fromI(stepF)(cb => EmitStream.produce(emitter, ir, cb, cb.emb, outerRegion, env, None)) + val env = EmitEnv( + Env.empty, + argTypeInfo.indices.filter(i => argTypeInfo(i).isInstanceOf[EmitParamType]).map(i => + stepF.getEmitParam(cb, i + 1) + ), + ) + val optStream = EmitCode.fromI(stepF)(cb => + EmitStream.produce(emitter, ir, cb, cb.emb, outerRegion, env, None) + ) returnType = optStream.st.asInstanceOf[SStream].elementEmitType.storageType.setRequired(true) elementAddress = stepF.genFieldThisRef[Long]("elementAddr") @@ -210,19 +261,23 @@ object CompileIterator { val ret = cb.newLocal[Boolean]("stepf_ret") val Lreturn = CodeLabel() - cb.if_(!didSetup, { - optStream.toI(cb).get(cb) // handle missing, but bound stream producer above + cb.if_( + !didSetup, { + optStream.toI(cb).getOrAssert(cb) // handle missing, but bound stream producer above - cb.assign(producer.elementRegion, eltRegionField) - producer.initialize(cb, outerRegion) - cb.assign(didSetup, true) - cb.assign(eosField, false) - }) + cb.assign(producer.elementRegion, eltRegionField) + producer.initialize(cb, outerRegion) + cb.assign(didSetup, true) + cb.assign(eosField, false) + }, + ) - cb.if_(eosField, { - cb.assign(ret, false) - cb.goto(Lreturn) - }) + cb.if_( + eosField, { + cb.assign(ret, false) + cb.goto(Lreturn) + }, + ) cb.goto(producer.LproduceElement) @@ -234,7 +289,7 @@ object CompileIterator { } stepF.implementLabel(producer.LproduceElementDone) { cb => - val pc = producer.element.toI(cb).get(cb) + val pc = producer.element.toI(cb).getOrAssert(cb) cb.assign(elementAddress, returnType.store(cb, producer.elementRegion, pc, false)) cb.assign(ret, true) cb.goto(Lreturn) @@ -244,7 +299,6 @@ object CompileIterator { ret } - val getMB = fb.newEmitMethod("loadAddress", FastSeq(), LongInfo) getMB.emit(elementAddress.load()) @@ -253,55 +307,72 @@ object CompileIterator { def 
forTableMapPartitions( ctx: ExecuteContext, - typ0: PStruct, streamElementType: PType, - ir: IR - ): (PType, (HailClassLoader, FS, HailTaskContext, RVDContext, Long, NoBoxLongIterator) => Iterator[java.lang.Long]) = { + typ0: PStruct, + streamElementType: PType, + ir: IR, + ): ( + PType, + (HailClassLoader, FS, HailTaskContext, RVDContext, Long, NoBoxLongIterator) => Iterator[java.lang.Long], + ) = { assert(typ0.required) assert(streamElementType.required) val (eltPType, makeStepper) = compileStepper[TMPStepFunction]( - ctx, ir, + ctx, + ir, Array[ParamType]( CodeParamType(typeInfo[Object]), SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(typ0)), - SingleCodeEmitParamType(true, StreamSingleCodeType(true, streamElementType, true)) + SingleCodeEmitParamType(true, StreamSingleCodeType(true, streamElementType, true)), ), - None + None, + ) + ( + eltPType, + (theHailClassLoader, fs, htc, consumerCtx, v0, part) => { + val stepper = makeStepper(theHailClassLoader, fs, htc, consumerCtx.partitionRegion) + stepper.setRegions(consumerCtx.partitionRegion, consumerCtx.region) + new LongIteratorWrapper { + val stepFunction: TMPStepFunction = stepper + + def step(): Boolean = stepper.apply(null, v0, part) + } + }, ) - (eltPType, (theHailClassLoader, fs, htc, consumerCtx, v0, part) => { - val stepper = makeStepper(theHailClassLoader, fs, htc, consumerCtx.partitionRegion) - stepper.setRegions(consumerCtx.partitionRegion, consumerCtx.region) - new LongIteratorWrapper { - val stepFunction: TMPStepFunction = stepper - - def step(): Boolean = stepper.apply(null, v0, part) - } - }) } def forTableStageToRVD( ctx: ExecuteContext, - ctxType: PStruct, bcValsType: PType, - ir: IR - ): (PType, (HailClassLoader, FS, HailTaskContext, RVDContext, Long, Long) => Iterator[java.lang.Long]) = { + ctxType: PStruct, + bcValsType: PType, + ir: IR, + ): ( + PType, + (HailClassLoader, FS, HailTaskContext, RVDContext, Long, Long) => Iterator[java.lang.Long], + ) = { assert(ctxType.required) assert(bcValsType.required) val (eltPType, makeStepper) = compileStepper[TableStageToRVDStepFunction]( - ctx, ir, + ctx, + ir, Array[ParamType]( CodeParamType(typeInfo[Object]), SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(ctxType)), - SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(bcValsType))), - None + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(bcValsType)), + ), + None, + ) + ( + eltPType, + (theHailClassLoader, fs, htc, consumerCtx, v0, v1) => { + val stepper = makeStepper(theHailClassLoader, fs, htc, consumerCtx.partitionRegion) + stepper.setRegions(consumerCtx.partitionRegion, consumerCtx.region) + new LongIteratorWrapper { + val stepFunction: TableStageToRVDStepFunction = stepper + + def step(): Boolean = stepper.apply(null, v0, v1) + } + }, ) - (eltPType, (theHailClassLoader, fs, htc, consumerCtx, v0, v1) => { - val stepper = makeStepper(theHailClassLoader, fs, htc, consumerCtx.partitionRegion) - stepper.setRegions(consumerCtx.partitionRegion, consumerCtx.region) - new LongIteratorWrapper { - val stepFunction: TableStageToRVDStepFunction = stepper - - def step(): Boolean = stepper.apply(null, v0, v1) - } - }) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala b/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala index 2ae6d527864..56b253402b9 100644 --- a/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala +++ b/hail/src/main/scala/is/hail/expr/ir/CompileAndEvaluate.scala @@ -8,13 +8,11 @@ import is.hail.types.physical.PTuple import 
is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual._ import is.hail.utils.FastSeq + import org.apache.spark.sql.Row object CompileAndEvaluate { - def apply[T](ctx: ExecuteContext, - ir0: IR, - optimize: Boolean = true - ): T = { + def apply[T](ctx: ExecuteContext, ir0: IR, optimize: Boolean = true): T = { ctx.timer.time("CompileAndEvaluate") { _apply(ctx, ir0, optimize) match { case Left(()) => ().asInstanceOf[T] @@ -23,10 +21,7 @@ object CompileAndEvaluate { } } - def evalToIR(ctx: ExecuteContext, - ir0: IR, - optimize: Boolean = true - ): IR = { + def evalToIR(ctx: ExecuteContext, ir0: IR, optimize: Boolean = true): IR = { if (IsConstant(ir0)) return ir0 @@ -45,16 +40,20 @@ object CompileAndEvaluate { def _apply( ctx: ExecuteContext, ir0: IR, - optimize: Boolean = true + optimize: Boolean = true, ): Either[Unit, (PTuple, Long)] = { val ir = LoweringPipeline.relationalLowerer(optimize).apply(ctx, ir0).asInstanceOf[IR] if (ir.typ == TVoid) { - val (_, f) = ctx.timer.time("Compile")(Compile[AsmFunction1RegionUnit](ctx, + val (_, f) = ctx.timer.time("Compile")(Compile[AsmFunction1RegionUnit]( + ctx, FastSeq(), - FastSeq(classInfo[Region]), UnitInfo, + FastSeq(classInfo[Region]), + UnitInfo, ir, - print = None, optimize = optimize)) + print = None, + optimize = optimize, + )) ctx.scopedExecution { (hcl, fs, htc, r) => val fRunnable = ctx.timer.time("InitializeCompiledFunction")(f(hcl, fs, htc, r)) @@ -63,14 +62,23 @@ object CompileAndEvaluate { return Left(()) } - val (Some(PTypeReferenceSingleCodeType(resType: PTuple)), f) = ctx.timer.time("Compile")(Compile[AsmFunction1RegionLong](ctx, - FastSeq(), - FastSeq(classInfo[Region]), LongInfo, - MakeTuple.ordered(FastSeq(ir)), - print = None, optimize = optimize)) - + val (Some(PTypeReferenceSingleCodeType(resType: PTuple)), f) = + ctx.timer.time("Compile")(Compile[AsmFunction1RegionLong]( + ctx, + FastSeq(), + FastSeq(classInfo[Region]), + LongInfo, + MakeTuple.ordered(FastSeq(ir)), + print = None, + optimize = optimize, + )) - val fRunnable = ctx.timer.time("InitializeCompiledFunction")(f(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)) + val fRunnable = ctx.timer.time("InitializeCompiledFunction")(f( + ctx.theHailClassLoader, + ctx.fs, + ctx.taskContext, + ctx.r, + )) val resultAddress = ctx.timer.time("RunCompiledFunction")(fRunnable(ctx.r)) Right((resType, resultAddress)) diff --git a/hail/src/main/scala/is/hail/expr/ir/ComputeUsesAndDefs.scala b/hail/src/main/scala/is/hail/expr/ir/ComputeUsesAndDefs.scala index 85d3abd8f4e..5ee4ed4d146 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ComputeUsesAndDefs.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ComputeUsesAndDefs.scala @@ -2,10 +2,11 @@ package is.hail.expr.ir import scala.collection.mutable -case class UsesAndDefs(uses: Memo[mutable.Set[RefEquality[BaseRef]]], - defs: Memo[BaseIR], - free: mutable.Set[RefEquality[BaseRef]] - ) +case class UsesAndDefs( + uses: Memo[mutable.Set[RefEquality[BaseRef]]], + defs: Memo[BaseIR], + free: mutable.Set[RefEquality[BaseRef]], +) object ComputeUsesAndDefs { def apply(ir0: BaseIR, errorIfFreeVariables: Boolean = true): UsesAndDefs = { @@ -31,7 +32,7 @@ object ComputeUsesAndDefs { } case None => if (errorIfFreeVariables) - throw new RuntimeException(s"found variable with no definition: ${ r.name }") + throw new RuntimeException(s"found variable with no definition: ${r.name}") else free += RefEquality(r) } @@ -41,12 +42,10 @@ object ComputeUsesAndDefs { ir.children .zipWithIndex .foreach { case (child, i) 
=> - val e = ChildEnvWithoutBindings(ir, i, env) - val b = NewBindings(ir, i, env).mapValues[BaseIR](_ => ir) - if (!b.allEmpty && !uses.contains(ir)) + val bindings = Bindings.segregated(ir, i, env).mapNewBindings((_, _) => ir) + if (!bindings.newBindings.allEmpty && !uses.contains(ir)) uses.bind(ir, mutable.Set.empty[RefEquality[BaseRef]]) - val childEnv = e.merge(b) - compute(child, childEnv) + compute(child, bindings.unified) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/Copy.scala b/hail/src/main/scala/is/hail/expr/ir/Copy.scala index 7a5ac9d4a74..79522e2d94a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Copy.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Copy.scala @@ -21,7 +21,7 @@ object Copy { assert(newChildren.length == 1) CastRename(newChildren(0).asInstanceOf[IR], typ) case NA(t) => NA(t) - case IsNA(value) => + case IsNA(_) => assert(newChildren.length == 1) IsNA(newChildren(0).asInstanceOf[IR]) case Coalesce(_) => @@ -30,23 +30,33 @@ object Copy { Consume(newChildren(0).asInstanceOf[IR]) case If(_, _, _) => assert(newChildren.length == 3) - If(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], newChildren(2).asInstanceOf[IR]) + If( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + newChildren(2).asInstanceOf[IR], + ) case s: Switch => assert(s.size == newChildren.size) - Switch(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], newChildren.drop(2).asInstanceOf[IndexedSeq[IR]]) - case Let(bindings, _) => + Switch( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + newChildren.drop(2).asInstanceOf[IndexedSeq[IR]], + ) + case Block(bindings, _) => assert(newChildren.length == x.size) val newBindings = (bindings, newChildren.init) .zipped - .map { case ((name, _), ir: IR) => name -> ir } - Let(newBindings, newChildren.last.asInstanceOf[IR]) - case AggLet(name, _, _, isScan) => - assert(newChildren.length == 2) - AggLet(name, newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], isScan) - case TailLoop(name, params, resultType, _) => + .map { case (binding, ir: IR) => binding.copy(value = ir) } + Block(newBindings, newChildren.last.asInstanceOf[IR]) + case TailLoop(name, params, resultType, _) => assert(newChildren.length == params.length + 1) - TailLoop(name, params.map(_._1).zip(newChildren.init.map(_.asInstanceOf[IR])), resultType, newChildren.last.asInstanceOf[IR]) + TailLoop( + name, + params.map(_._1).zip(newChildren.init.map(_.asInstanceOf[IR])), + resultType, + newChildren.last.asInstanceOf[IR], + ) case Recur(name, args, t) => assert(newChildren.length == args.length) Recur(name, newChildren.map(_.asInstanceOf[IR]), t) @@ -73,62 +83,117 @@ object Copy { case ArrayRef(_, _, errorID) => assert(newChildren.length == 2) ArrayRef(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], errorID) - case ArraySlice(_,_, stop, _, errorID) => + case ArraySlice(_, _, stop, _, errorID) => if (stop.isEmpty) { assert(newChildren.length == 3) - ArraySlice(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], None, - newChildren(2).asInstanceOf[IR], errorID) + ArraySlice( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + None, + newChildren(2).asInstanceOf[IR], + errorID, + ) + } else { + assert(newChildren.length == 4) + ArraySlice( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + Some(newChildren(2).asInstanceOf[IR]), + newChildren(3).asInstanceOf[IR], + errorID, + ) } - else { - assert(newChildren.length == 4) - 
ArraySlice(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], Some(newChildren(2).asInstanceOf[IR]), - newChildren(3).asInstanceOf[IR], errorID) - } case ArrayLen(_) => assert(newChildren.length == 1) ArrayLen(newChildren(0).asInstanceOf[IR]) case StreamIota(_, _, requiresMemoryManagementPerElement) => assert(newChildren.length == 2) - StreamIota(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], requiresMemoryManagementPerElement) + StreamIota( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + requiresMemoryManagementPerElement, + ) case StreamRange(_, _, _, requiresMemoryManagementPerElement, errorID) => assert(newChildren.length == 3) - StreamRange(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], newChildren(2).asInstanceOf[IR], - requiresMemoryManagementPerElement, errorID) + StreamRange( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + newChildren(2).asInstanceOf[IR], + requiresMemoryManagementPerElement, + errorID, + ) case SeqSample(_, _, _, requiresMemoryManagementPerElement) => assert(newChildren.length == 3) - SeqSample(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], newChildren(2).asInstanceOf[IR], requiresMemoryManagementPerElement) + SeqSample( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + newChildren(2).asInstanceOf[IR], + requiresMemoryManagementPerElement, + ) case StreamDistribute(_, _, _, op, spec) => - StreamDistribute(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], newChildren(2).asInstanceOf[IR], op, spec) - case StreamWhiten(stream, newChunk, prevWindow, vecSize, windowSize, chunkSize, blockSize, normalizeAfterWhiten) => - StreamWhiten(newChildren(0).asInstanceOf[IR], newChunk, prevWindow, vecSize, windowSize, chunkSize, blockSize, normalizeAfterWhiten) + StreamDistribute( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + newChildren(2).asInstanceOf[IR], + op, + spec, + ) + case StreamWhiten(_, newChunk, prevWindow, vecSize, windowSize, chunkSize, blockSize, + normalizeAfterWhiten) => + StreamWhiten( + newChildren(0).asInstanceOf[IR], + newChunk, + prevWindow, + vecSize, + windowSize, + chunkSize, + blockSize, + normalizeAfterWhiten, + ) case ArrayZeros(_) => assert(newChildren.length == 1) ArrayZeros(newChildren(0).asInstanceOf[IR]) case MakeNDArray(_, _, _, errorId) => assert(newChildren.length == 3) - MakeNDArray(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], newChildren(2).asInstanceOf[IR], errorId) + MakeNDArray( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + newChildren(2).asInstanceOf[IR], + errorId, + ) case NDArrayShape(_) => assert(newChildren.length == 1) NDArrayShape(newChildren(0).asInstanceOf[IR]) case NDArrayReshape(_, _, errorID) => - assert(newChildren.length == 2) + assert(newChildren.length == 2) NDArrayReshape(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], errorID) case NDArrayConcat(_, axis) => - assert(newChildren.length == 1) + assert(newChildren.length == 1) NDArrayConcat(newChildren(0).asInstanceOf[IR], axis) case NDArrayRef(_, _, errorId) => - NDArrayRef(newChildren(0).asInstanceOf[IR], newChildren.tail.map(_.asInstanceOf[IR]), errorId) + NDArrayRef( + newChildren(0).asInstanceOf[IR], + newChildren.tail.map(_.asInstanceOf[IR]), + errorId, + ) case NDArraySlice(_, _) => - assert(newChildren.length == 2) + assert(newChildren.length == 2) NDArraySlice(newChildren(0).asInstanceOf[IR], 
newChildren(1).asInstanceOf[IR]) case NDArrayFilter(_, _) => NDArrayFilter(newChildren(0).asInstanceOf[IR], newChildren.tail.map(_.asInstanceOf[IR])) case NDArrayMap(_, name, _) => - assert(newChildren.length == 2) + assert(newChildren.length == 2) NDArrayMap(newChildren(0).asInstanceOf[IR], name, newChildren(1).asInstanceOf[IR]) case NDArrayMap2(_, _, lName, rName, _, errorID) => - assert(newChildren.length == 3) - NDArrayMap2(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], lName, rName, newChildren(2).asInstanceOf[IR], errorID) + assert(newChildren.length == 3) + NDArrayMap2( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + lName, + rName, + newChildren(2).asInstanceOf[IR], + errorID, + ) case NDArrayReindex(_, indexExpr) => assert(newChildren.length == 1) NDArrayReindex(newChildren(0).asInstanceOf[IR], indexExpr) @@ -157,7 +222,10 @@ object Copy { assert(newChildren.length == 2) ArraySort(newChildren(0).asInstanceOf[IR], l, r, newChildren(1).asInstanceOf[IR]) case ArrayMaximalIndependentSet(_, tb) => - ArrayMaximalIndependentSet(newChildren(0).asInstanceOf[IR], tb.map { case (l, r, _) => (l, r, newChildren(1).asInstanceOf[IR]) } ) + ArrayMaximalIndependentSet( + newChildren(0).asInstanceOf[IR], + tb.map { case (l, r, _) => (l, r, newChildren(1).asInstanceOf[IR]) }, + ) case ToSet(_) => assert(newChildren.length == 1) ToSet(newChildren(0).asInstanceOf[IR]) @@ -175,7 +243,11 @@ object Copy { ToStream(newChildren(0).asInstanceOf[IR], requiresMemoryManagementPerElement) case LowerBoundOnOrderedCollection(_, _, asKey) => assert(newChildren.length == 2) - LowerBoundOnOrderedCollection(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], asKey) + LowerBoundOnOrderedCollection( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + asKey, + ) case GroupByKey(_) => assert(newChildren.length == 1) GroupByKey(newChildren(0).asInstanceOf[IR]) @@ -202,15 +274,33 @@ object Copy { StreamMap(newChildren(0).asInstanceOf[IR], name, newChildren(1).asInstanceOf[IR]) case StreamZip(_, names, _, behavior, errorID) => assert(newChildren.length == names.length + 1) - StreamZip(newChildren.init.asInstanceOf[IndexedSeq[IR]], names, newChildren(names.length).asInstanceOf[IR], - behavior, errorID) + StreamZip( + newChildren.init.asInstanceOf[IndexedSeq[IR]], + names, + newChildren(names.length).asInstanceOf[IR], + behavior, + errorID, + ) case StreamZipJoin(as, key, curKey, curVals, _) => assert(newChildren.length == as.length + 1) - StreamZipJoin(newChildren.init.asInstanceOf[IndexedSeq[IR]], key, curKey, curVals, newChildren(as.length).asInstanceOf[IR]) + StreamZipJoin( + newChildren.init.asInstanceOf[IndexedSeq[IR]], + key, + curKey, + curVals, + newChildren(as.length).asInstanceOf[IR], + ) case StreamZipJoinProducers(_, ctxName, _, key, curKey, curVals, _) => assert(newChildren.length == 3) - StreamZipJoinProducers(newChildren(0).asInstanceOf[IR], ctxName, newChildren(1).asInstanceOf[IR], - key, curKey, curVals, newChildren(2).asInstanceOf[IR]) + StreamZipJoinProducers( + newChildren(0).asInstanceOf[IR], + ctxName, + newChildren(1).asInstanceOf[IR], + key, + curKey, + curVals, + newChildren(2).asInstanceOf[IR], + ) case StreamMultiMerge(as, key) => assert(newChildren.length == as.length) StreamMultiMerge(newChildren.asInstanceOf[IndexedSeq[IR]], key) @@ -228,22 +318,58 @@ object Copy { StreamFlatMap(newChildren(0).asInstanceOf[IR], name, newChildren(1).asInstanceOf[IR]) case StreamFold(_, _, accumName, valueName, _) => 
assert(newChildren.length == 3) - StreamFold(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], accumName, valueName, newChildren(2).asInstanceOf[IR]) + StreamFold( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + accumName, + valueName, + newChildren(2).asInstanceOf[IR], + ) case StreamFold2(_, accum, valueName, seq, _) => val ncIR = newChildren.map(_.asInstanceOf[IR]) assert(newChildren.length == 2 + accum.length + seq.length) - StreamFold2(ncIR(0), + StreamFold2( + ncIR(0), accum.indices.map(i => (accum(i)._1, ncIR(i + 1))), valueName, - seq.indices.map(i => ncIR(i + 1 + accum.length)), ncIR.last) + seq.indices.map(i => ncIR(i + 1 + accum.length)), + ncIR.last, + ) case StreamScan(_, _, accumName, valueName, _) => assert(newChildren.length == 3) - StreamScan(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], accumName, valueName, newChildren(2).asInstanceOf[IR]) + StreamScan( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + accumName, + valueName, + newChildren(2).asInstanceOf[IR], + ) + case StreamLeftIntervalJoin(_, _, lKeyNames, rIntrvlName, lname, rname, _) => + assert(newChildren.length == 3) + StreamLeftIntervalJoin( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + lKeyNames, + rIntrvlName, + lname, + rname, + newChildren(2).asInstanceOf[IR], + ) case StreamJoinRightDistinct(_, _, lKey, rKey, l, r, _, joinType) => assert(newChildren.length == 3) - StreamJoinRightDistinct(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], lKey, rKey, l, r, newChildren(2).asInstanceOf[IR], joinType) + StreamJoinRightDistinct( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + lKey, + rKey, + l, + r, + newChildren(2).asInstanceOf[IR], + joinType, + ) case _: StreamLocalLDPrune => - val IndexedSeq(child: IR, r2Threshold: IR, windowSize: IR, maxQueueSize: IR, nSamples: IR) = newChildren + val IndexedSeq(child: IR, r2Threshold: IR, windowSize: IR, maxQueueSize: IR, nSamples: IR) = + newChildren StreamLocalLDPrune(child, r2Threshold, windowSize, maxQueueSize, nSamples) case StreamFor(_, valueName, _) => assert(newChildren.length == 2) @@ -257,11 +383,24 @@ object Copy { case RunAgg(_, _, signatures) => RunAgg(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], signatures) case RunAggScan(_, name, _, _, _, signatures) => - RunAggScan(newChildren(0).asInstanceOf[IR], name, newChildren(1).asInstanceOf[IR], - newChildren(2).asInstanceOf[IR], newChildren(3).asInstanceOf[IR], signatures) + RunAggScan( + newChildren(0).asInstanceOf[IR], + name, + newChildren(1).asInstanceOf[IR], + newChildren(2).asInstanceOf[IR], + newChildren(3).asInstanceOf[IR], + signatures, + ) case StreamBufferedAggregate(_, _, _, _, name, aggSignatures, bufferSize) => - StreamBufferedAggregate(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], - newChildren(2).asInstanceOf[IR], newChildren(3).asInstanceOf[IR], name, aggSignatures, bufferSize) + StreamBufferedAggregate( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + newChildren(2).asInstanceOf[IR], + newChildren(3).asInstanceOf[IR], + name, + aggSignatures, + bufferSize, + ) case AggFilter(_, _, isScan) => assert(newChildren.length == 2) AggFilter(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], isScan) @@ -278,7 +417,14 @@ object Copy { assert(newChildren.length == 2) None } - AggArrayPerElement(newChildren(0).asInstanceOf[IR], elementName, indexName, newChildren(1).asInstanceOf[IR], 
newKnownLength, isScan) + AggArrayPerElement( + newChildren(0).asInstanceOf[IR], + elementName, + indexName, + newChildren(1).asInstanceOf[IR], + newKnownLength, + isScan, + ) case MakeStruct(fields) => assert(fields.length == newChildren.length) MakeStruct(fields.zip(newChildren).map { case ((n, _), a) => (n, a.asInstanceOf[IR]) }) @@ -287,7 +433,11 @@ object Copy { SelectFields(newChildren(0).asInstanceOf[IR], fields) case InsertFields(_, fields, fieldOrder) => assert(newChildren.length == fields.length + 1) - InsertFields(newChildren.head.asInstanceOf[IR], fields.zip(newChildren.tail).map { case ((n, _), a) => (n, a.asInstanceOf[IR]) }, fieldOrder) + InsertFields( + newChildren.head.asInstanceOf[IR], + fields.zip(newChildren.tail).map { case ((n, _), a) => (n, a.asInstanceOf[IR]) }, + fieldOrder, + ) case GetField(_, name) => assert(newChildren.length == 1) GetField(newChildren(0).asInstanceOf[IR], name) @@ -307,29 +457,40 @@ object Copy { case InitFromSerializedValue(i, _, aggSig) => assert(newChildren.length == 1) InitFromSerializedValue(i, newChildren(0).asInstanceOf[IR], aggSig) - case SerializeAggs(startIdx, serIdx, spec, aggSigs) => SerializeAggs(startIdx, serIdx, spec, aggSigs) - case DeserializeAggs(startIdx, serIdx, spec, aggSigs) => DeserializeAggs(startIdx, serIdx, spec, aggSigs) - case Begin(_) => - Begin(newChildren.map(_.asInstanceOf[IR])) - case x@ApplyAggOp(initOpArgs, seqOpArgs, aggSig) => + case SerializeAggs(startIdx, serIdx, spec, aggSigs) => + SerializeAggs(startIdx, serIdx, spec, aggSigs) + case DeserializeAggs(startIdx, serIdx, spec, aggSigs) => + DeserializeAggs(startIdx, serIdx, spec, aggSigs) + case x @ ApplyAggOp(_, _, aggSig) => val args = newChildren.map(_.asInstanceOf[IR]) assert(args.length == x.nInitArgs + x.nSeqOpArgs) ApplyAggOp( args.take(x.nInitArgs), args.drop(x.nInitArgs), - aggSig) - case x@ApplyScanOp(initOpArgs, _, aggSig) => + aggSig, + ) + case x @ ApplyScanOp(_, _, aggSig) => val args = newChildren.map(_.asInstanceOf[IR]) assert(args.length == x.nInitArgs + x.nSeqOpArgs) ApplyScanOp( args.take(x.nInitArgs), args.drop(x.nInitArgs), - aggSig) + aggSig, + ) case AggFold(_, _, _, accumName, otherAccumName, isScan) => - AggFold(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], newChildren(2).asInstanceOf[IR], accumName, otherAccumName, isScan) + AggFold( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + newChildren(2).asInstanceOf[IR], + accumName, + otherAccumName, + isScan, + ) case MakeTuple(fields) => assert(fields.length == newChildren.length) - MakeTuple(fields.zip(newChildren).map { case ((i, _), newValue) => (i, newValue.asInstanceOf[IR]) }) + MakeTuple(fields.zip(newChildren).map { case ((i, _), newValue) => + (i, newValue.asInstanceOf[IR]) + }) case GetTupleElement(_, idx) => assert(newChildren.length == 1) GetTupleElement(newChildren(0).asInstanceOf[IR], idx) @@ -337,22 +498,28 @@ object Copy { case Die(_, typ, errorId) => assert(newChildren.length == 1) Die(newChildren(0).asInstanceOf[IR], typ, errorId) - case Trap(child) => + case Trap(_) => assert(newChildren.length == 1) Trap(newChildren(0).asInstanceOf[IR]) - case ConsoleLog(message, result) => + case ConsoleLog(_, _) => assert(newChildren.length == 2) ConsoleLog(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR]) - case x@ApplyIR(fn, typeArgs, args, rt, errorID) => + case x @ ApplyIR(fn, typeArgs, _, rt, errorID) => val r = ApplyIR(fn, typeArgs, newChildren.map(_.asInstanceOf[IR]), rt, errorID) r.conversion = x.conversion r.inline 
= x.inline r - case Apply(fn, typeArgs, args, t, errorID) => + case Apply(fn, typeArgs, _, t, errorID) => Apply(fn, typeArgs, newChildren.map(_.asInstanceOf[IR]), t, errorID) - case ApplySeeded(fn, args, rngState, staticUID, t) => - ApplySeeded(fn, newChildren.init.map(_.asInstanceOf[IR]), newChildren.last.asInstanceOf[IR], staticUID, t) - case ApplySpecial(fn, typeArgs, args, t, errorID) => + case ApplySeeded(fn, _, _, staticUID, t) => + ApplySeeded( + fn, + newChildren.init.map(_.asInstanceOf[IR]), + newChildren.last.asInstanceOf[IR], + staticUID, + t, + ) + case ApplySpecial(fn, typeArgs, _, t, errorID) => ApplySpecial(fn, typeArgs, newChildren.map(_.asInstanceOf[IR]), t, errorID) // from MatrixIR case MatrixWrite(_, writer) => @@ -403,17 +570,26 @@ object Copy { BlockMatrixMultiWrite(newChildren.map(_.asInstanceOf[BlockMatrixIR]), writer) case CollectDistributedArray(_, _, cname, gname, _, _, id, tsd) => assert(newChildren.length == 4) - CollectDistributedArray(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], cname, gname, newChildren(2).asInstanceOf[IR], newChildren(3).asInstanceOf[IR], id, tsd) - case ReadPartition(context, rowType, reader) => + CollectDistributedArray( + newChildren(0).asInstanceOf[IR], + newChildren(1).asInstanceOf[IR], + cname, + gname, + newChildren(2).asInstanceOf[IR], + newChildren(3).asInstanceOf[IR], + id, + tsd, + ) + case ReadPartition(_, rowType, reader) => assert(newChildren.length == 1) ReadPartition(newChildren(0).asInstanceOf[IR], rowType, reader) - case WritePartition(stream, ctx, writer) => + case WritePartition(_, _, writer) => assert(newChildren.length == 2) WritePartition(newChildren(0).asInstanceOf[IR], newChildren(1).asInstanceOf[IR], writer) - case WriteMetadata(ctx, writer) => + case WriteMetadata(_, writer) => assert(newChildren.length == 1) WriteMetadata(newChildren(0).asInstanceOf[IR], writer) - case ReadValue(path, writer, requestedType) => + case ReadValue(_, writer, requestedType) => assert(newChildren.length == 1) ReadValue(newChildren(0).asInstanceOf[IR], writer, requestedType) case WriteValue(_, _, writer, _) => diff --git a/hail/src/main/scala/is/hail/expr/ir/DeprecatedIRBuilder.scala b/hail/src/main/scala/is/hail/expr/ir/DeprecatedIRBuilder.scala index 837ea5c51b6..96e91f36bda 100644 --- a/hail/src/main/scala/is/hail/expr/ir/DeprecatedIRBuilder.scala +++ b/hail/src/main/scala/is/hail/expr/ir/DeprecatedIRBuilder.scala @@ -2,7 +2,7 @@ package is.hail.expr.ir import is.hail.types._ import is.hail.types.virtual._ -import is.hail.utils.FastSeq +import is.hail.utils.{toRichIterable, FastSeq} import scala.language.{dynamics, implicitConversions} @@ -28,7 +28,7 @@ object DeprecatedIRBuilder { implicit def symbolToSymbolProxy(s: Symbol): SymbolProxy = new SymbolProxy(s) implicit def arrayToProxy(seq: IndexedSeq[IRProxy]): IRProxy = (env: E) => { - val irs = seq.map(_ (env)) + val irs = seq.map(_(env)) val elType = irs.head.typ MakeArray(irs, TArray(elType)) } @@ -40,7 +40,6 @@ object DeprecatedIRBuilder { def irArrayLen(a: IRProxy): IRProxy = (env: E) => ArrayLen(a(env)) - def irIf(cond: IRProxy)(cnsq: IRProxy)(altr: IRProxy): IRProxy = (env: E) => If(cond(env), cnsq(env), altr(env)) @@ -61,12 +60,13 @@ object DeprecatedIRBuilder { } def makeTuple(values: IRProxy*): IRProxy = (env: E) => - MakeTuple.ordered(values.toArray.map(_ (env))) + MakeTuple.ordered(values.toArray.map(_(env))) def applyAggOp( - op: AggOp, - initOpArgs: IndexedSeq[IRProxy] = FastSeq(), - seqOpArgs: IndexedSeq[IRProxy] = FastSeq()): IRProxy = (env: E) => 
{ + op: AggOp, + initOpArgs: IndexedSeq[IRProxy] = FastSeq(), + seqOpArgs: IndexedSeq[IRProxy] = FastSeq(), + ): IRProxy = (env: E) => { val i = initOpArgs.map(x => x(env)) val s = seqOpArgs.map(x => x(env)) @@ -121,7 +121,10 @@ object DeprecatedIRBuilder { keyBy(FastSeq()) .collect() .apply('rows) - .map(Symbol(uid) ~> makeTuple(Symbol(uid).selectFields(keyFields: _*), Symbol(uid).selectFields(valueFields: _*))) + .map(Symbol(uid) ~> makeTuple( + Symbol(uid).selectFields(keyFields: _*), + Symbol(uid).selectFields(valueFields: _*), + )) .toDict } @@ -134,7 +137,7 @@ object DeprecatedIRBuilder { ArrayRef(ir(env), idx(env)) def invoke(name: String, rt: Type, args: IRProxy*): IRProxy = { env: E => - val irArgs = Array(ir(env)) ++ args.map(_ (env)) + val irArgs = Array(ir(env)) ++ args.map(_(env)) is.hail.expr.ir.invoke(name, rt, irArgs: _*) } @@ -147,9 +150,11 @@ object DeprecatedIRBuilder { def *(other: IRProxy): IRProxy = (env: E) => ApplyBinaryPrimOp(Multiply(), ir(env), other(env)) - def /(other: IRProxy): IRProxy = (env: E) => ApplyBinaryPrimOp(FloatingPointDivide(), ir(env), other(env)) + def /(other: IRProxy): IRProxy = + (env: E) => ApplyBinaryPrimOp(FloatingPointDivide(), ir(env), other(env)) - def floorDiv(other: IRProxy): IRProxy = (env: E) => ApplyBinaryPrimOp(RoundToNegInfDivide(), ir(env), other(env)) + def floorDiv(other: IRProxy): IRProxy = + (env: E) => ApplyBinaryPrimOp(RoundToNegInfDivide(), ir(env), other(env)) def &&(other: IRProxy): IRProxy = invoke("land", TBoolean, ir, other) @@ -217,8 +222,11 @@ object DeprecatedIRBuilder { def insertFields(fields: (Symbol, IRProxy)*): IRProxy = insertFieldsList(fields) - def insertFieldsList(fields: Seq[(Symbol, IRProxy)], ordering: Option[IndexedSeq[String]] = None): IRProxy = (env: E) => - InsertFields(ir(env), fields.map { case (s, fir) => (s.name, fir(env))}, ordering) + def insertFieldsList( + fields: Seq[(Symbol, IRProxy)], + ordering: Option[IndexedSeq[String]] = None, + ): IRProxy = (env: E) => + InsertFields(ir(env), fields.map { case (s, fir) => (s.name, fir(env)) }, ordering) def selectFields(fields: String*): IRProxy = (env: E) => SelectFields(ir(env), fields.toArray[String]) @@ -231,17 +239,21 @@ object DeprecatedIRBuilder { def dropFields(fields: Symbol*): IRProxy = dropFieldList(fields.map(_.name).toArray[String]) - def insertStruct(other: IRProxy, ordering: Option[IndexedSeq[String]] = None): IRProxy = (env: E) => { - val right = other(env) - val sym = genUID() - Let(FastSeq(sym -> right), - InsertFields( - ir(env), - right.typ.asInstanceOf[TStruct].fieldNames.map(f => f -> GetField(Ref(sym, right.typ), f)), - ordering + def insertStruct(other: IRProxy, ordering: Option[IndexedSeq[String]] = None): IRProxy = + (env: E) => { + val right = other(env) + val sym = genUID() + Let( + FastSeq(sym -> right), + InsertFields( + ir(env), + right.typ.asInstanceOf[TStruct].fieldNames.map(f => + f -> GetField(Ref(sym, right.typ), f) + ), + ordering, + ), ) - ) - } + } def len: IRProxy = (env: E) => ArrayLen(ir(env)) @@ -256,7 +268,11 @@ object DeprecatedIRBuilder { def filter(pred: LambdaProxy): IRProxy = (env: E) => { val array = ir(env) val eltType = array.typ.asInstanceOf[TArray].elementType - ToArray(StreamFilter(ToStream(array), pred.s.name, pred.body(env.bind(pred.s.name -> eltType)))) + ToArray(StreamFilter( + ToStream(array), + pred.s.name, + pred.body(env.bind(pred.s.name -> eltType)), + )) } def map(f: LambdaProxy): IRProxy = (env: E) => { @@ -267,13 +283,22 @@ object DeprecatedIRBuilder { def aggExplode(f: LambdaProxy): 
IRProxy = (env: E) => { val array = ir(env) - AggExplode(ToStream(array), f.s.name, f.body(env.bind(f.s.name, array.typ.asInstanceOf[TArray].elementType)), isScan = false) + AggExplode( + ToStream(array), + f.s.name, + f.body(env.bind(f.s.name, array.typ.asInstanceOf[TArray].elementType)), + isScan = false, + ) } def flatMap(f: LambdaProxy): IRProxy = (env: E) => { val array = ir(env) val eltType = array.typ.asInstanceOf[TArray].elementType - ToArray(StreamFlatMap(ToStream(array), f.s.name, ToStream(f.body(env.bind(f.s.name -> eltType))))) + ToArray(StreamFlatMap( + ToStream(array), + f.s.name, + ToStream(f.body(env.bind(f.s.name -> eltType))), + )) } def streamAgg(f: LambdaProxy): IRProxy = (env: E) => { @@ -289,10 +314,23 @@ object DeprecatedIRBuilder { } def arraySlice(start: IRProxy, stop: Option[IRProxy], step: IRProxy): IRProxy = { - (env: E) => ArraySlice(this.ir(env), start.ir(env), stop.map(inner => inner.ir(env)), step.ir(env), ErrorIDs.NO_ERROR) + (env: E) => + ArraySlice( + this.ir(env), + start.ir(env), + stop.map(inner => inner.ir(env)), + step.ir(env), + ErrorIDs.NO_ERROR, + ) } - def aggElements(elementsSym: Symbol, indexSym: Symbol, knownLength: Option[IRProxy])(aggBody: IRProxy): IRProxy = (env: E) => { + def aggElements( + elementsSym: Symbol, + indexSym: Symbol, + knownLength: Option[IRProxy], + )( + aggBody: IRProxy + ): IRProxy = (env: E) => { val array = ir(env) val eltType = array.typ.asInstanceOf[TArray].elementType AggArrayPerElement( @@ -300,11 +338,13 @@ object DeprecatedIRBuilder { elementsSym.name, indexSym.name, aggBody.apply(env.bind(elementsSym.name -> eltType, indexSym.name -> TInt32)), - knownLength.map(_ (env)), - isScan = false) + knownLength.map(_(env)), + isScan = false, + ) } - def sort(ascending: IRProxy, onKey: Boolean = false): IRProxy = (env: E) => ArraySort(ToStream(ir(env)), ascending(env), onKey) + def sort(ascending: IRProxy, onKey: Boolean = false): IRProxy = + (env: E) => ArraySort(ToStream(ir(env)), ascending(env), onKey) def groupByKey: IRProxy = (env: E) => GroupByKey(ToStream(ir(env))) @@ -312,7 +352,8 @@ object DeprecatedIRBuilder { def toDict: IRProxy = (env: E) => ToDict(ToStream(ir(env))) - def parallelize(nPartitions: Option[Int] = None): TableIR = TableParallelize(ir(Env.empty), nPartitions) + def parallelize(nPartitions: Option[Int] = None): TableIR = + TableParallelize(ir(Env.empty), nPartitions) def arrayStructToDict(keyFields: IndexedSeq[String]): IRProxy = { val element = Symbol(genUID()) @@ -320,7 +361,8 @@ object DeprecatedIRBuilder { .map(element ~> makeTuple( element.selectFields(keyFields: _*), - element.dropFieldList(keyFields))) + element.dropFieldList(keyFields), + )) .toDict } @@ -335,65 +377,65 @@ object DeprecatedIRBuilder { def ~>(body: IRProxy): LambdaProxy = new LambdaProxy(s, body) } - case class BindingProxy(s: Symbol, value: IRProxy) - - object LetProxy { - def bind(bindings: Seq[BindingProxy], body: IRProxy, env: E, scope: Int): IR = - bindings match { - case BindingProxy(sym, binding) +: rest => - val name = sym.name - val value = binding(env) - scope match { - case Scope.EVAL => Let(FastSeq(name -> value), bind(rest, body, env.bind(name -> value.typ), scope)) - case Scope.AGG => AggLet(name, value, bind(rest, body, env.bind(name -> value.typ), scope), isScan = false) - case Scope.SCAN => AggLet(name, value, bind(rest, body, env.bind(name -> value.typ), scope), isScan = true) - } - case Seq() => - body(env) + case class BindingProxy(s: Symbol, value: IRProxy, scope: Int) + + private object LetProxy { + def 
bind(bindings: IndexedSeq[BindingProxy], body: IRProxy, env: E): IR = { + var newEnv = env + val resolvedBindings = bindings.map { case BindingProxy(sym, value, scope) => + val resolvedValue = value(newEnv) + newEnv = newEnv.bind(sym.name -> resolvedValue.typ) + Binding(sym.name, resolvedValue, scope) } + Block(resolvedBindings, body(newEnv)) + } } object let extends Dynamic { def applyDynamicNamed(method: String)(args: (String, IRProxy)*): LetProxy = { assert(method == "apply") - new LetProxy(args.map { case (s, b) => BindingProxy(Symbol(s), b) }) + letDyn(args: _*) } } - class LetProxy(val bindings: Seq[BindingProxy]) extends AnyVal with Dynamic { + object letDyn { + def apply(args: (String, IRProxy)*): LetProxy = + new LetProxy(args.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.EVAL) }.toFastSeq) + } + + class LetProxy(val bindings: IndexedSeq[BindingProxy]) extends AnyVal { def apply(body: IRProxy): IRProxy = in(body) - def in(body: IRProxy): IRProxy = { (env: E) => - LetProxy.bind(bindings, body, env, scope = Scope.EVAL) - } + def in(body: IRProxy): IRProxy = { (env: E) => LetProxy.bind(bindings, body, env) } } object aggLet extends Dynamic { def applyDynamicNamed(method: String)(args: (String, IRProxy)*): AggLetProxy = { assert(method == "apply") - new AggLetProxy(args.map { case (s, b) => BindingProxy(Symbol(s), b) }) + new AggLetProxy(args.map { case (s, b) => BindingProxy(Symbol(s), b, Scope.AGG) }.toFastSeq) } } - class AggLetProxy(val bindings: Seq[BindingProxy]) extends AnyVal with Dynamic { + class AggLetProxy(val bindings: IndexedSeq[BindingProxy]) extends AnyVal { def apply(body: IRProxy): IRProxy = in(body) - def in(body: IRProxy): IRProxy = { (env: E) => - LetProxy.bind(bindings, body, env, scope = Scope.AGG) - } + def in(body: IRProxy): IRProxy = { (env: E) => LetProxy.bind(bindings, body, env) } } object MapIRProxy { - def apply(f: (IRProxy) => IRProxy)(x: IRProxy): IRProxy = (e: E) => { + def apply(f: (IRProxy) => IRProxy)(x: IRProxy): IRProxy = (e: E) => MapIR(x => f(x)(e))(x(e)) - } } - def subst(x: IRProxy, env: BindingEnv[IRProxy]): IRProxy = (e: E) => { - Subst(x(e), BindingEnv(env.eval.mapValues(_ (e)), - agg = env.agg.map(_.mapValues(_ (e))), - scan = env.scan.map(_.mapValues(_ (e))))) - } + def subst(x: IRProxy, env: BindingEnv[IRProxy]): IRProxy = (e: E) => + Subst( + x(e), + BindingEnv( + env.eval.mapValues(_(e)), + agg = env.agg.map(_.mapValues(_(e))), + scan = env.scan.map(_.mapValues(_(e))), + ), + ) def lift(f: (IR) => IRProxy)(x: IRProxy): IRProxy = (e: E) => f(x(e))(e) } diff --git a/hail/src/main/scala/is/hail/expr/ir/DistinctlyKeyed.scala b/hail/src/main/scala/is/hail/expr/ir/DistinctlyKeyed.scala index 64ab867e6d9..0275f9ff55f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/DistinctlyKeyed.scala +++ b/hail/src/main/scala/is/hail/expr/ir/DistinctlyKeyed.scala @@ -8,36 +8,38 @@ object DistinctlyKeyed { case t: TableRead => memo.bindIf(t.isDistinctlyKeyed, t, ()) - case t@TableKeyBy(child, keys, _) => + case t @ TableKeyBy(child, keys, _) => memo.bindIf(child.typ.key.forall(keys.contains) && memo.contains(child), t, ()) - case t@(_: TableFilter - | _: TableIntervalJoin - | _: TableLeftJoinRightDistinct - | _: TableMapRows - | _: TableMapGlobals) => + case t @ (_: TableFilter + | _: TableIntervalJoin + | _: TableLeftJoinRightDistinct + | _: TableMapRows + | _: TableMapGlobals) => memo.bindIf(memo.contains(t.children.head), t, ()) - case t@RelationalLetTable(_, _, body) => + case t @ RelationalLetTable(_, _, body) => 
memo.bindIf(memo.contains(body), t, ()) - case t@(_: TableHead - | _: TableTail - | _: TableRepartition - | _: TableJoin - | _: TableMultiWayZipJoin - | _: TableRename - | _: TableFilterIntervals) => + case t @ (_: TableHead + | _: TableTail + | _: TableRepartition + | _: TableJoin + | _: TableMultiWayZipJoin + | _: TableRename + | _: TableFilterIntervals) => memo.bindIf(t.children.forall(memo.contains), t, ()) - case t@(_: TableRange - | _: TableDistinct - | _: TableKeyByAndAggregate - | _: TableAggregateByKey) => + case t @ (_: TableRange + | _: TableDistinct + | _: TableKeyByAndAggregate + | _: TableAggregateByKey) => memo.bind(t, ()) case _: MatrixIR => - throw new IllegalArgumentException("MatrixIR should be lowered when it reaches distinct analysis") + throw new IllegalArgumentException( + "MatrixIR should be lowered when it reaches distinct analysis" + ) case _ => memo @@ -47,9 +49,8 @@ object DistinctlyKeyed { } case class DistinctKeyedAnalysis(distinctMemo: Memo[Unit]) { - def contains(tableIR: BaseIR): Boolean = { + def contains(tableIR: BaseIR): Boolean = distinctMemo.contains(tableIR) - } override def toString: String = distinctMemo.toString -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/expr/ir/Emit.scala b/hail/src/main/scala/is/hail/expr/ir/Emit.scala index 7465aea42e0..272cea273ef 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Emit.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Emit.scala @@ -4,26 +4,30 @@ import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend.{BackendContext, ExecuteContext, HailTaskContext} import is.hail.expr.ir.agg.{AggStateSig, ArrayAggStateSig, GroupedStateSig} -import is.hail.expr.ir.analyses.{ComputeMethodSplits, ControlFlowPreventsSplit, ParentPointers, SemanticHash} +import is.hail.expr.ir.analyses.{ + ComputeMethodSplits, ControlFlowPreventsSplit, ParentPointers, SemanticHash, +} import is.hail.expr.ir.lowering.TableStageDependency import is.hail.expr.ir.ndarrays.EmitNDArray import is.hail.expr.ir.streams.{EmitStream, StreamProducer, StreamUtils} -import is.hail.io.fs.FS import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer, TypedCodecSpec} +import is.hail.io.fs.FS import is.hail.linalg.{BLAS, LAPACK, LinalgCodeUtils} +import is.hail.types.{tcoerce, TypeWithRequiredness, VirtualTypeWithReq} import is.hail.types.physical._ import is.hail.types.physical.stypes._ import is.hail.types.physical.stypes.concrete._ import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives._ import is.hail.types.virtual._ -import is.hail.types.{TypeWithRequiredness, VirtualTypeWithReq, tcoerce} import is.hail.utils._ import is.hail.variant.ReferenceGenome -import java.io._ +import scala.annotation.{nowarn, tailrec} import scala.collection.mutable -import scala.language.{existentials, postfixOps} +import scala.language.existentials + +import java.io._ // class for holding all information computed ahead-of-time that we need in the emitter object EmitContext { @@ -33,7 +37,14 @@ object EmitContext { val requiredness = Requiredness(ir, usesAndDefs, ctx, pTypeEnv) val inLoopCriticalPath = ControlFlowPreventsSplit(ir, ParentPointers(ir), usesAndDefs) val methodSplits = ComputeMethodSplits(ctx, ir, inLoopCriticalPath) - new EmitContext(ctx, requiredness, usesAndDefs, methodSplits, inLoopCriticalPath, Memo.empty[Unit]) + new EmitContext( + ctx, + requiredness, + usesAndDefs, + methodSplits, + inLoopCriticalPath, + Memo.empty[Unit], + ) } } } @@ -44,33 +55,52 @@ class EmitContext( val 
usesAndDefs: UsesAndDefs, val methodSplits: Memo[Unit], val inLoopCriticalPath: Memo[Unit], - val tryingToSplit: Memo[Unit] + val tryingToSplit: Memo[Unit], ) case class EmitEnv(bindings: Env[EmitValue], inputValues: IndexedSeq[EmitValue]) { def bind(name: String, v: EmitValue): EmitEnv = copy(bindings = bindings.bind(name, v)) - def bind(newBindings: (String, EmitValue)*): EmitEnv = copy(bindings = bindings.bindIterable(newBindings)) + def bind(newBindings: (String, EmitValue)*): EmitEnv = + copy(bindings = bindings.bindIterable(newBindings)) - def asParams(freeVariables: Env[Unit]): (IndexedSeq[ParamType], IndexedSeq[Value[_]], (EmitCodeBuilder, Int) => EmitEnv) = { + def asParams(freeVariables: Env[Unit]) + : (IndexedSeq[ParamType], IndexedSeq[Value[_]], (EmitCodeBuilder, Int) => EmitEnv) = { val m = bindings.m.filterKeys(freeVariables.contains) val bindingNames = m.keys.toArray - val paramTypes = bindingNames.map(name => m(name).emitType.paramType) ++ inputValues.map(_.emitType.paramType) - val params = bindingNames.flatMap(name => m(name).valueTuple()) ++ inputValues.flatMap(_.valueTuple()) + val paramTypes = + bindingNames.map(name => m(name).emitType.paramType) ++ inputValues.map(_.emitType.paramType) + val params = + bindingNames.flatMap(name => m(name).valueTuple()) ++ inputValues.flatMap(_.valueTuple()) val recreateFromMB = { (cb: EmitCodeBuilder, startIdx: Int) => val emb = cb.emb EmitEnv( - Env.fromSeq(bindingNames.zipWithIndex.map { case (name, bindingIdx) => (name, cb.memoizeField(emb.getEmitParam(cb, startIdx + bindingIdx), name))}), - inputValues.indices.map(inputIdx => cb.memoizeField(emb.getEmitParam(cb, startIdx + bindingNames.length + inputIdx), s"arg_$inputIdx")) + Env.fromSeq(bindingNames.zipWithIndex.map { case (name, bindingIdx) => + (name, cb.memoizeField(emb.getEmitParam(cb, startIdx + bindingIdx), name)) + }), + inputValues.indices.map(inputIdx => + cb.memoizeField( + emb.getEmitParam(cb, startIdx + bindingNames.length + inputIdx), + s"arg_$inputIdx", + ) + ), ) } (paramTypes, params, recreateFromMB) } + } object Emit { - def apply[C](ctx: EmitContext, ir: IR, fb: EmitFunctionBuilder[C], rti: TypeInfo[_], nParams: Int, aggs: Option[Array[AggStateSig]] = None): Option[SingleCodeType] = { + def apply[C]( + ctx: EmitContext, + ir: IR, + fb: EmitFunctionBuilder[C], + rti: TypeInfo[_], + nParams: Int, + aggs: Option[Array[AggStateSig]] = None, + ): Option[SingleCodeType] = { TypeCheck(ctx.executeContext, ir) val mb = fb.apply_method @@ -82,21 +112,27 @@ object Emit { val region = mb.getCodeParam[Region](1) val returnTypeOption: Option[SingleCodeType] = if (ir.typ == TVoid) { fb.apply_method.voidWithBuilder { cb => - val env = EmitEnv(Env.empty, (0 until nParams).map(i => mb.storeEmitParamAsField(cb, i + 2))) // this, region, ... + val env = EmitEnv( + Env.empty, + (0 until nParams).map(i => mb.storeEmitParamAsField(cb, i + 2)), + ) // this, region, ... emitter.emitVoid(cb, ir, region, env, container, None) } None } else { var sct: SingleCodeType = null fb.emitWithBuilder { cb => - - val env = EmitEnv(Env.empty, (0 until nParams).map(i => mb.storeEmitParamAsField(cb, i + 2))) // this, region, ... - val sc = emitter.emitI(ir, cb, region, env, container, None).handle(cb, { - cb._throw(Code.newInstance[RuntimeException, String]("cannot return empty")) - }) + val env = EmitEnv( + Env.empty, + (0 until nParams).map(i => mb.storeEmitParamAsField(cb, i + 2)), + ) // this, region, ... 
+ val sc = emitter.emitI(ir, cb, region, env, container, None).handle( + cb, + cb._throw(Code.newInstance[RuntimeException, String]("cannot return empty")), + ) val scp = SingleCodeSCode.fromSCode(cb, sc, region) - assert(scp.typ.ti == rti, s"type info mismatch: expect $rti, got ${ scp.typ.ti }") + assert(scp.typ.ti == rti, s"type info mismatch: expect $rti, got ${scp.typ.ti}") sct = scp.typ scp.code } @@ -108,14 +144,22 @@ object Emit { object AggContainer { // FIXME remove this when EmitStream also has a codebuilder - def fromVars(aggs: Array[AggStateSig], mb: EmitMethodBuilder[_], region: Settable[Region], off: Settable[Long]): (AggContainer, EmitCodeBuilder => Unit, EmitCodeBuilder => Unit) = { + def fromVars( + aggs: Array[AggStateSig], + mb: EmitMethodBuilder[_], + region: Settable[Region], + off: Settable[Long], + ): (AggContainer, EmitCodeBuilder => Unit, EmitCodeBuilder => Unit) = { val (setup, aggState) = EmitCodeBuilder.scoped(mb) { cb => val states = agg.StateTuple(aggs.map(a => agg.AggStateSig.getState(a, cb.emb.ecb))) val aggState = new agg.TupleAggregatorState(mb.ecb, states, region, off) cb += (region := Region.stagedCreate(Region.REGULAR, cb.emb.ecb.pool())) cb += region.load().setNumParents(aggs.length) - cb += (off := region.load().allocate(aggState.storageType.alignment, aggState.storageType.byteSize)) + cb += (off := region.load().allocate( + aggState.storageType.alignment, + aggState.storageType.byteSize, + )) states.createStates(cb) aggState } @@ -129,27 +173,48 @@ object AggContainer { (AggContainer(aggs, aggState, () => ()), (cb: EmitCodeBuilder) => cb += setup, cleanup) } - def fromMethodBuilder[C](aggs: Array[AggStateSig], mb: EmitMethodBuilder[C], varPrefix: String): (AggContainer, EmitCodeBuilder => Unit, EmitCodeBuilder => Unit) = - fromVars(aggs, mb, mb.genFieldThisRef[Region](s"${ varPrefix }_top_region"), mb.genFieldThisRef[Long](s"${ varPrefix }_off")) - - def fromBuilder[C](cb: EmitCodeBuilder, aggs: Array[AggStateSig], varPrefix: String): AggContainer = { - val off = cb.newField[Long](s"${ varPrefix }_off") - val region = cb.newField[Region](s"${ varPrefix }_top_region", Region.stagedCreate(Region.REGULAR, cb.emb.ecb.pool())) + def fromMethodBuilder[C](aggs: Array[AggStateSig], mb: EmitMethodBuilder[C], varPrefix: String) + : (AggContainer, EmitCodeBuilder => Unit, EmitCodeBuilder => Unit) = + fromVars( + aggs, + mb, + mb.genFieldThisRef[Region](s"${varPrefix}_top_region"), + mb.genFieldThisRef[Long](s"${varPrefix}_off"), + ) + + def fromBuilder[C](cb: EmitCodeBuilder, aggs: Array[AggStateSig], varPrefix: String) + : AggContainer = { + val off = cb.newField[Long](s"${varPrefix}_off") + val region = cb.newField[Region]( + s"${varPrefix}_top_region", + Region.stagedCreate(Region.REGULAR, cb.emb.ecb.pool()), + ) val states = agg.StateTuple(aggs.map(a => agg.AggStateSig.getState(a, cb.emb.ecb))) val aggState = new agg.TupleAggregatorState(cb.emb.ecb, states, region, off) cb += region.load().setNumParents(aggs.length) - cb.assign(off, region.load().allocate(aggState.storageType.alignment, aggState.storageType.byteSize)) + cb.assign( + off, + region.load().allocate(aggState.storageType.alignment, aggState.storageType.byteSize), + ) states.createStates(cb) - AggContainer(aggs, aggState, { () => - aggState.store(cb) - cb += region.load().invalidate() - cb.assign(region, Code._null[Region]) - }) + AggContainer( + aggs, + aggState, + { () => + aggState.store(cb) + cb += region.load().invalidate() + cb.assign(region, Code._null[Region]) + }, + ) } } -case class 
AggContainer(aggs: Array[AggStateSig], container: agg.TupleAggregatorState, cleanup: () => Unit) { +case class AggContainer( + aggs: Array[AggStateSig], + container: agg.TupleAggregatorState, + cleanup: () => Unit, +) { def nested(i: Int, init: Boolean): Option[AggContainer] = { aggs(i).n.map { nested => @@ -178,14 +243,15 @@ object EmitValue { def apply(missing: Option[Value[Boolean]], v: SValue): EmitValue = new EmitValue( missing.filterNot(m => Code.constBoolValue(m).contains(false)), - v) + v, + ) def present(v: SValue): EmitValue = EmitValue(None, v) def missing(t: SType): EmitValue = EmitValue(Some(const(true)), t.defaultValue) } -class EmitValue protected(missing: Option[Value[Boolean]], val v: SValue) { +class EmitValue protected (missing: Option[Value[Boolean]], val v: SValue) { def m: Value[Boolean] = missing.getOrElse(const(false)) def required: Boolean = missing.isEmpty @@ -208,9 +274,7 @@ class EmitValue protected(missing: Option[Value[Boolean]], val v: SValue) { def loadI(cb: EmitCodeBuilder): IEmitCode = load.toI(cb) def get(cb: EmitCodeBuilder): SValue = { - missing.foreach { m => - cb.if_(m, cb._fatal(s"Can't convert missing ${ v.st } to SValue")) - } + missing.foreach(m => cb.if_(m, cb._fatal(s"Can't convert missing ${v.st} to SValue"))) v } @@ -223,13 +287,11 @@ class EmitValue protected(missing: Option[Value[Boolean]], val v: SValue) { } } -/** - * Notes on IEmitCode; - * 1. It is the responsibility of the producers of IEmitCode to emit the relevant - * jumps for the Lmissing and Lpresent labels (cb.goto or similar) - * 2. It is the responsibility of consumers to define these labels and to - * prevent the SCode from being used on any code path taken as a result of - * jumping to Lmissing. +/** Notes on IEmitCode; + * 1. It is the responsibility of the producers of IEmitCode to emit the relevant jumps for the + * Lmissing and Lpresent labels (cb.goto or similar) 2. It is the responsibility of consumers + * to define these labels and to prevent the SCode from being used on any code path taken as a + * result of jumping to Lmissing. 
*/ object IEmitCode { def apply[A](cb: EmitCodeBuilder, m: Code[Boolean], value: => A): IEmitCodeGen[A] = { @@ -245,7 +307,8 @@ object IEmitCode { } } - def apply[A](Lmissing: CodeLabel, Lpresent: CodeLabel, value: A, required: Boolean): IEmitCodeGen[A] = + def apply[A](Lmissing: CodeLabel, Lpresent: CodeLabel, value: A, required: Boolean) + : IEmitCodeGen[A] = IEmitCodeGen(Lmissing, Lpresent, value, required) def present[A](cb: EmitCodeBuilder, value: => A): IEmitCodeGen[A] = { @@ -261,12 +324,19 @@ object IEmitCode { IEmitCodeGen(Lmissing, CodeLabel(), defaultValue, false) } - def multiMapEmitCodes(cb: EmitCodeBuilder, seq: IndexedSeq[EmitCode])(f: IndexedSeq[SValue] => SValue): IEmitCode = + def multiMapEmitCodes( + cb: EmitCodeBuilder, + seq: IndexedSeq[EmitCode], + )( + f: IndexedSeq[SValue] => SValue + ): IEmitCode = multiMap(cb, seq.map(ec => cb => ec.toI(cb)))(f) - def multiMap(cb: EmitCodeBuilder, - seq: IndexedSeq[EmitCodeBuilder => IEmitCode] - )(f: IndexedSeq[SValue] => SValue + def multiMap( + cb: EmitCodeBuilder, + seq: IndexedSeq[EmitCodeBuilder => IEmitCode], + )( + f: IndexedSeq[SValue] => SValue ): IEmitCode = { val Lmissing = CodeLabel() val Lpresent = CodeLabel() @@ -288,14 +358,21 @@ object IEmitCode { IEmitCodeGen(Lmissing, Lpresent, pc, required) } - def multiFlatMap(cb: EmitCodeBuilder, - seq: IndexedSeq[EmitCodeBuilder => IEmitCode] - )(f: IndexedSeq[SValue] => IEmitCode + def multiFlatMap( + cb: EmitCodeBuilder, + seq: IndexedSeq[EmitCodeBuilder => IEmitCode], + )( + f: IndexedSeq[SValue] => IEmitCode ): IEmitCode = multiFlatMap[EmitCodeBuilder => IEmitCode, SValue, SValue](seq, x => x(cb), cb)(f) - def multiFlatMap[A, B, C](seq: IndexedSeq[A], toIec: A => IEmitCodeGen[B], cb: EmitCodeBuilder) - (f: IndexedSeq[B] => IEmitCodeGen[C]): IEmitCodeGen[C] = { + def multiFlatMap[A, B, C]( + seq: IndexedSeq[A], + toIec: A => IEmitCodeGen[B], + cb: EmitCodeBuilder, + )( + f: IndexedSeq[B] => IEmitCodeGen[C] + ): IEmitCodeGen[C] = { val Lmissing = CodeLabel() var required: Boolean = true @@ -303,7 +380,6 @@ object IEmitCode { val iec = toIec(elem) required = required && iec.required - cb.define(iec.Lmissing) cb.goto(Lmissing) cb.define(iec.Lpresent) @@ -321,7 +397,7 @@ object IEmitCode { object IEmitCodeGen { - implicit class IEmitCode(val iec: IEmitCodeGen[SValue]) extends AnyVal { + implicit class IEmitCode(private val iec: IEmitCodeGen[SValue]) extends AnyVal { def pc: SValue = iec.value def st: SType = pc.st @@ -336,12 +412,11 @@ object IEmitCodeGen { } case class IEmitCodeGen[+A](Lmissing: CodeLabel, Lpresent: CodeLabel, value: A, required: Boolean) { - lazy val emitType: EmitType = { + lazy val emitType: EmitType = value match { case pc: SValue => EmitType(pc.st, required) case _ => throw new UnsupportedOperationException(s"emitType on $value") } - } def setOptional: IEmitCodeGen[A] = copy(required = false) @@ -376,9 +451,16 @@ case class IEmitCodeGen[+A](Lmissing: CodeLabel, Lpresent: CodeLabel, value: A, value } - def get(cb: EmitCodeBuilder, errorMsg: Code[String]=s"expected non-missing", errorID: Code[Int] = const(ErrorIDs.NO_ERROR)): A = + def getOrFatal( + cb: EmitCodeBuilder, + errorMsg: Code[String], + errorID: Code[Int] = const(ErrorIDs.NO_ERROR), + ): A = handle(cb, cb._fatalWithError(errorID, errorMsg)) + def getOrAssert(cb: EmitCodeBuilder, debugMsg: Code[String] = const("expected non-missing")): A = + handle(cb, cb._assert(false, debugMsg)) + def consume(cb: EmitCodeBuilder, ifMissing: => Unit, ifPresent: (A) => Unit): Unit = { val Lafter = CodeLabel() 
cb.define(Lmissing) @@ -405,7 +487,8 @@ case class IEmitCodeGen[+A](Lmissing: CodeLabel, Lpresent: CodeLabel, value: A, ret } - def consumeI(cb: EmitCodeBuilder, ifMissing: => IEmitCode, ifPresent: A => IEmitCode): IEmitCode = { + def consumeI(cb: EmitCodeBuilder, ifMissing: => IEmitCode, ifPresent: A => IEmitCode) + : IEmitCode = { val Lmissing2 = CodeLabel() val Lpresent2 = CodeLabel() cb.define(Lmissing) @@ -419,7 +502,11 @@ case class IEmitCodeGen[+A](Lmissing: CodeLabel, Lpresent: CodeLabel, value: A, IEmitCode(Lmissing2, Lpresent2, ret, required = missingI.required && presentI.required) } - def consumeCode[B: TypeInfo](cb: EmitCodeBuilder, ifMissing: => Value[B], ifPresent: (A) => Value[B]): Value[B] = { + def consumeCode[B: TypeInfo]( + cb: EmitCodeBuilder, + ifMissing: => Value[B], + ifPresent: (A) => Value[B], + ): Value[B] = { val ret = cb.emb.newLocal[B]("iec_consumeCode") consume(cb, cb.assign(ret, ifMissing), a => cb.assign(ret, ifPresent(a))) ret @@ -431,10 +518,14 @@ object EmitCode { Code.constBoolValue(m) match { case Some(false) => val Lpresent = CodeLabel() - new EmitCode(new CodeLabel(Code(setup, Lpresent.goto).start), IEmitCode(CodeLabel(), Lpresent, pv, required = true)) + new EmitCode( + new CodeLabel(Code(setup, Lpresent.goto).start), + IEmitCode(CodeLabel(), Lpresent, pv, required = true), + ) case _ => val mCC = Code(setup, m).toCCode - val iec = IEmitCode(new CodeLabel(mCC.Ltrue), new CodeLabel(mCC.Lfalse), pv, required = false) + val iec = + IEmitCode(new CodeLabel(mCC.Ltrue), new CodeLabel(mCC.Lfalse), pv, required = false) val result = new EmitCode(new CodeLabel(mCC.entry), iec) result } @@ -484,9 +575,13 @@ class EmitCode(private val start: CodeLabel, private val iec: IEmitCode) { iec } - def castTo(mb: EmitMethodBuilder[_], region: Value[Region], destType: SType, deepCopy: Boolean = false): EmitCode = { + def castTo( + mb: EmitMethodBuilder[_], + region: Value[Region], + destType: SType, + deepCopy: Boolean = false, + ): EmitCode = EmitCode.fromI(mb)(cb => toI(cb).map(cb)(_.castTo(cb, region, destType))) - } def missingIf(mb: EmitMethodBuilder[_], cond: Code[Boolean]): EmitCode = EmitCode.fromI(mb) { cb => @@ -518,29 +613,29 @@ object EmitSettable { class EmitSettable( missing: Option[Settable[Boolean]], // required if None - vs: SSettable + vs: SSettable, ) extends EmitValue(missing, vs) { - def settableTuple(): IndexedSeq[Settable[_]] = { + def settableTuple(): IndexedSeq[Settable[_]] = missing match { case Some(m) => vs.settableTuple() :+ m case None => vs.settableTuple() } - } - def store(cb: EmitCodeBuilder, ec: EmitCode): Unit = { + def store(cb: EmitCodeBuilder, ec: EmitCode): Unit = store(cb, ec.toI(cb)) - } def store(cb: EmitCodeBuilder, iec: IEmitCode): Unit = if (required) - cb.assign(vs, iec.get(cb, s"Required EmitSettable cannot be missing ${ vs.st }")) + cb.assign(vs, iec.getOrFatal(cb, s"Required EmitSettable cannot be missing ${vs.st}")) else - iec.consume(cb, { - cb.assign(missing.get, true) - }, { value => - cb.assign(missing.get, false) - cb.assign(vs, value) - }) + iec.consume( + cb, + cb.assign(missing.get, true), + { value => + cb.assign(missing.get, false) + cb.assign(vs, value) + }, + ) } class RichIndexedSeqEmitSettable(is: IndexedSeq[EmitSettable]) { @@ -548,7 +643,13 @@ class RichIndexedSeqEmitSettable(is: IndexedSeq[EmitSettable]) { } object LoopRef { - def apply(cb: EmitCodeBuilder, L: CodeLabel, args: IndexedSeq[(String, EmitType)], pool: Value[RegionPool], resultType: EmitType): LoopRef = { + def apply( + cb: EmitCodeBuilder, + 
L: CodeLabel, + args: IndexedSeq[(String, EmitType)], + pool: Value[RegionPool], + resultType: EmitType, + ): LoopRef = { val (loopArgs, tmpLoopArgs) = args.zipWithIndex.map { case ((name, et), i) => (cb.emb.newEmitField(s"$name$i", et), cb.emb.newEmitField(s"tmp$name$i", et)) }.unzip @@ -570,7 +671,8 @@ class LoopRef( val tmpLoopArgs: IndexedSeq[EmitSettable], val r1: Settable[Region], val r2: Settable[Region], - val resultType: EmitType) + val resultType: EmitType, +) abstract class EstimableEmitter[C] { def emit(mb: EmitMethodBuilder[C]): Code[Unit] @@ -578,14 +680,20 @@ abstract class EstimableEmitter[C] { def estimatedSize: Int } -class Emit[C]( - val ctx: EmitContext, - val cb: EmitClassBuilder[C]) { - emitSelf => +class Emit[C](val ctx: EmitContext, val cb: EmitClassBuilder[C]) { - val methods: mutable.Map[(String, Seq[Type], Seq[SType], SType), EmitMethodBuilder[C]] = mutable.Map() + val methods: mutable.Map[(String, Seq[Type], Seq[SType], SType), EmitMethodBuilder[C]] = + mutable.Map() - def emitVoidInSeparateMethod(context: String, cb: EmitCodeBuilder, ir: IR, region: Value[Region], env: EmitEnv, container: Option[AggContainer], loopEnv: Option[Env[LoopRef]]): Unit = { + def emitVoidInSeparateMethod( + context: String, + cb: EmitCodeBuilder, + ir: IR, + region: Value[Region], + env: EmitEnv, + container: Option[AggContainer], + loopEnv: Option[Env[LoopRef]], + ): Unit = { assert(!ctx.inLoopCriticalPath.contains(ir)) val mb = cb.emb.genEmitMethod(context, FastSeq[ParamType](), UnitInfo) val r = cb.newField[Region]("emitVoidSeparate_region", region) @@ -593,10 +701,18 @@ class Emit[C]( ctx.tryingToSplit.bind(ir, ()) emitVoid(cb, ir, r, env, container, loopEnv) } - cb.invokeVoid(mb) + cb.invokeVoid(mb, cb.this_) } - def emitSplitMethod(context: String, cb: EmitCodeBuilder, ir: IR, region: Value[Region], env: EmitEnv, container: Option[AggContainer], loopEnv: Option[Env[LoopRef]]): (EmitSettable, EmitMethodBuilder[_]) = { + def emitSplitMethod( + context: String, + cb: EmitCodeBuilder, + ir: IR, + region: Value[Region], + env: EmitEnv, + container: Option[AggContainer], + loopEnv: Option[Env[LoopRef]], + ): (EmitSettable, EmitMethodBuilder[_]) = { val mb = cb.emb.genEmitMethod(context, FastSeq[ParamType](), UnitInfo) val r = cb.newField[Region]("emitInSeparate_region", region) @@ -611,7 +727,15 @@ class Emit[C]( (ev, mb) } - def emitInSeparateMethod(context: String, cb: EmitCodeBuilder, ir: IR, region: Value[Region], env: EmitEnv, container: Option[AggContainer], loopEnv: Option[Env[LoopRef]]): IEmitCode = { + def emitInSeparateMethod( + context: String, + cb: EmitCodeBuilder, + ir: IR, + region: Value[Region], + env: EmitEnv, + container: Option[AggContainer], + loopEnv: Option[Env[LoopRef]], + ): IEmitCode = { if (ir.typ == TVoid) { emitVoidInSeparateMethod(context, cb, ir, region, env, container, loopEnv) return IEmitCode.present(cb, SVoidValue) @@ -619,29 +743,72 @@ class Emit[C]( assert(!ctx.inLoopCriticalPath.contains(ir)) val (ev, mb) = emitSplitMethod(context, cb, ir, region, env, container, loopEnv) - cb.invokeVoid(mb) + cb.invokeVoid(mb, cb.this_) ev.toI(cb) } - private[ir] def emitVoid(cb: EmitCodeBuilder, ir: IR, region: Value[Region], env: EmitEnv, container: Option[AggContainer], loopEnv: Option[Env[LoopRef]]): Unit = { + private[ir] def emitVoid( + cb: EmitCodeBuilder, + ir: IR, + region: Value[Region], + env: EmitEnv, + container: Option[AggContainer], + loopEnv: Option[Env[LoopRef]], + ): Unit = { if (ctx.methodSplits.contains(ir) && !ctx.tryingToSplit.contains(ir)) 
{ - emitVoidInSeparateMethod(s"split_${ir.getClass.getSimpleName}", cb, ir, region, env, container, loopEnv) + emitVoidInSeparateMethod( + s"split_${ir.getClass.getSimpleName}", + cb, + ir, + region, + env, + container, + loopEnv, + ) return } val mb: EmitMethodBuilder[C] = cb.emb.asInstanceOf[EmitMethodBuilder[C]] - - def emit(ir: IR, mb: EmitMethodBuilder[C] = mb, region: Value[Region] = region, env: EmitEnv = env, container: Option[AggContainer] = container, loopEnv: Option[Env[LoopRef]] = loopEnv): EmitCode = + @nowarn("cat=unused-locals&msg=local default argument") + def emit( + ir: IR, + mb: EmitMethodBuilder[C] = mb, + region: Value[Region] = region, + env: EmitEnv = env, + container: Option[AggContainer] = container, + loopEnv: Option[Env[LoopRef]] = loopEnv, + ): EmitCode = this.emit(ir, mb, region, env, container, loopEnv) - def emitStream(ir: IR, outerRegion: Value[Region], mb: EmitMethodBuilder[C] = mb, env: EmitEnv = env): EmitCode = - EmitCode.fromI(mb)(cb => EmitStream.produce(this, ir, cb, cb.emb, outerRegion, env, container)) - - def emitVoid(ir: IR, cb: EmitCodeBuilder = cb, region: Value[Region] = region, env: EmitEnv = env, container: Option[AggContainer] = container, loopEnv: Option[Env[LoopRef]] = loopEnv): Unit = + def emitStream( + ir: IR, + outerRegion: Value[Region], + mb: EmitMethodBuilder[C] = mb, + env: EmitEnv = env, + ): EmitCode = + EmitCode.fromI(mb)(cb => + EmitStream.produce(this, ir, cb, cb.emb, outerRegion, env, container) + ) + + def emitVoid( + ir: IR, + cb: EmitCodeBuilder = cb, + region: Value[Region] = region, + env: EmitEnv = env, + container: Option[AggContainer] = container, + loopEnv: Option[Env[LoopRef]] = loopEnv, + ): Unit = this.emitVoid(cb, ir, region, env, container, loopEnv) - def emitI(ir: IR, region: Value[Region] = region, env: EmitEnv = env, container: Option[AggContainer] = container, loopEnv: Option[Env[LoopRef]] = loopEnv): IEmitCode = + def emitI( + ir: IR, + cb: EmitCodeBuilder = cb, + region: Value[Region] = region, + env: EmitEnv = env, + container: Option[AggContainer] = container, + loopEnv: Option[Env[LoopRef]] = loopEnv, + ): IEmitCode = this.emitI(ir, cb, region, env, container, loopEnv) (ir: @unchecked) match { @@ -651,41 +818,18 @@ class Emit[C]( case Void() => Code._empty - case x@Begin(xs) => - if (!ctx.inLoopCriticalPath.contains(x) && xs.forall(x => !ctx.inLoopCriticalPath.contains(x))) { - xs.grouped(16).zipWithIndex.foreach { case (group, idx) => - val mb = cb.emb.genEmitMethod(s"begin_group_$idx", FastSeq[ParamType](classInfo[Region]), UnitInfo) - mb.voidWithBuilder { cb => - group.foreach(x => emitVoid(x, cb, mb.getCodeParam[Region](1), env, container, loopEnv)) - } - cb.invokeVoid(mb, region) - } - } else - xs.foreach(x => emitVoid(x)) - case If(cond, cnsq, altr) => assert(cnsq.typ == TVoid && altr.typ == TVoid) emitI(cond).consume(cb, {}, m => cb.if_(m.asBoolean.value, emitVoid(cnsq), emitVoid(altr))) - case Let(bindings, body) => - def go(env: EmitEnv): IndexedSeq[(String, IR)] => Unit = { - case (name, value) +: rest => - val xVal = - if (value.typ.isInstanceOf[TStream]) emitStream(value, region, env = env) - else emit(value, env = env) - - cb.withScopedMaybeStreamValue(xVal, s"let_$name") { ev => - go(env.bind(name, ev))(rest) - } - case Seq() => - emitVoid(body, env = env) - } - - go(env)(bindings) + case let: Block => + val newEnv = emitBlock(let, cb, env, region, container, loopEnv) + emitVoid(let.body, env = newEnv) case StreamFor(a, valueName, body) => - emitStream(a, region).toI(cb).consume(cb, + 
emitStream(a, region).toI(cb).consume( + cb, {}, { case stream: SStreamValue => val producer = stream.getProducer(mb) @@ -694,38 +838,39 @@ class Emit[C]( emitVoid(body, region = producer.elementRegion, env = env.bind(valueName -> ev)) } } - }) + }, + ) - case x@InitOp(i, args, sig) => + case InitOp(i, args, sig) => val AggContainer(aggs, sc, _) = container.get assert(aggs(i) == sig.state) val rvAgg = agg.Extract.getAgg(sig) val argVars = args - .map { a => emit(a, container = container.flatMap(_.nested(i, init = true))) } + .map(a => emit(a, container = container.flatMap(_.nested(i, init = true)))) .toArray sc.newState(cb, i) rvAgg.initOp(cb, sc.states(i), argVars) - case x@SeqOp(i, args, sig) => + case SeqOp(i, args, sig) => val AggContainer(aggs, sc, _) = container.get assert(sig.state == aggs(i)) val rvAgg = agg.Extract.getAgg(sig) val argVars = args - .map { a => emit(a, container = container.flatMap(_.nested(i, init = false))) } + .map(a => emit(a, container = container.flatMap(_.nested(i, init = false)))) .toArray rvAgg.seqOp(cb, sc.states(i), argVars) - case x@CombOp(i1, i2, sig) => + case CombOp(i1, i2, sig) => val AggContainer(aggs, sc, _) = container.get assert(sig.state == aggs(i1) && sig.state == aggs(i2)) val rvAgg = agg.Extract.getAgg(sig) - rvAgg.combOp(ctx.executeContext, cb, sc.states(i1), sc.states(i2)) + rvAgg.combOp(ctx.executeContext, cb, region, sc.states(i1), sc.states(i2)) - case x@SerializeAggs(start, sIdx, spec, sigs) => + case SerializeAggs(start, sIdx, spec, sigs) => val AggContainer(_, sc, _) = container.get val ob = mb.genFieldThisRef[OutputBuffer]() val baos = mb.genFieldThisRef[ByteArrayOutputStream]() @@ -734,9 +879,7 @@ class Emit[C]( cb.assign(ob, spec.buildCodeOutputBuffer(baos)) Array.range(start, start + sigs.length) - .foreach { idx => - sc.states(idx).serialize(spec)(cb, ob) - } + .foreach(idx => sc.states(idx).serialize(spec)(cb, ob)) cb += ob.invoke[Unit]("flush") cb += ob.invoke[Unit]("close") @@ -753,18 +896,21 @@ class Emit[C]( Array.range(start, start + ns).foreach(i => sc.newState(cb, i)) - cb.assign(ib, spec.buildCodeInputBuffer( - Code.newInstance[ByteArrayInputStream, Array[Byte]]( - mb.getSerializedAgg(sIdx)))) + cb.assign( + ib, + spec.buildCodeInputBuffer( + Code.newInstance[ByteArrayInputStream, Array[Byte]]( + mb.getSerializedAgg(sIdx) + ) + ), + ) cb += mb.freeSerializedAgg(sIdx) - (0 until ns).foreach { j => - deserializers(j)(cb, ib) - } + (0 until ns).foreach(j => deserializers(j)(cb, ib)) cb.assign(ib, Code._null[InputBuffer]) - case Die(m, typ, errorId) => + case Die(m, _, errorId) => val cm = emitI(m) val msg = cm.consumeCode(cb, "", _.asString.loadString(cb)) cb._throw(Code.newInstance[HailException, String, Int](msg, errorId)) @@ -776,18 +922,22 @@ class Emit[C]( val AggContainer(_, sc, _) = container.get val rvAgg = agg.Extract.getAgg(aggSig) val tempState = AggStateSig.getState(aggSig.state, mb.ecb) - val aggStateOffset = mb.genFieldThisRef[Long](s"combOpValue_${ i }_state"); + val aggStateOffset = mb.genFieldThisRef[Long](s"combOpValue_${i}_state"); val v = emitI(value) - v.consume(cb, + v.consume( + cb, cb._fatal("cannot combOp a missing value"), { case serializedValue: SBinaryValue => - cb.assign(aggStateOffset, region.allocate(tempState.storageType.alignment, tempState.storageType.byteSize)) + cb.assign( + aggStateOffset, + region.allocate(tempState.storageType.alignment, tempState.storageType.byteSize), + ) tempState.createState(cb) tempState.newState(cb) tempState.deserializeFromBytes(cb, serializedValue) - 
rvAgg.combOp(ctx.executeContext, cb, sc.states(i), tempState) - } + rvAgg.combOp(ctx.executeContext, cb, region, sc.states(i), tempState) + }, ) case InitFromSerializedValue(i, value, sig) => @@ -795,63 +945,105 @@ class Emit[C]( assert(aggs(i) == sig) val v = emitI(value) - v.consume(cb, + v.consume( + cb, cb._fatal("cannot initialize aggs from a missing value"), { case serializedValue: SBinaryValue => sc.states(i).createState(cb) sc.newState(cb, i) sc.states(i).deserializeFromBytes(cb, serializedValue) - } + }, ) } } - private[ir] def emitI(ir: IR, cb: EmitCodeBuilder, env: EmitEnv, container: Option[AggContainer]): IEmitCode = { + private[ir] def emitI(ir: IR, cb: EmitCodeBuilder, env: EmitEnv, container: Option[AggContainer]) + : IEmitCode = { val region = cb.emb.getCodeParam[Region](1) emitI(ir, cb, region, env, container, None) } - private[ir] def emitI(ir: IR, cb: EmitCodeBuilder, region: Value[Region], env: EmitEnv, - container: Option[AggContainer], loopEnv: Option[Env[LoopRef]] + private[ir] def emitI( + ir: IR, + cb: EmitCodeBuilder, + region: Value[Region], + env: EmitEnv, + container: Option[AggContainer], + loopEnv: Option[Env[LoopRef]], ): IEmitCode = { if (ctx.methodSplits.contains(ir) && !ctx.tryingToSplit.contains(ir)) { - return emitInSeparateMethod(s"split_${ir.getClass.getSimpleName}", cb, ir, region, env, container, loopEnv) + return emitInSeparateMethod( + s"split_${ir.getClass.getSimpleName}", + cb, + ir, + region, + env, + container, + loopEnv, + ) } val mb: EmitMethodBuilder[C] = cb.emb.asInstanceOf[EmitMethodBuilder[C]] - def emitI(ir: IR, region: Value[Region] = region, env: EmitEnv = env, container: Option[AggContainer] = container, loopEnv: Option[Env[LoopRef]] = loopEnv): IEmitCode = + def emitI( + ir: IR, + region: Value[Region] = region, + env: EmitEnv = env, + container: Option[AggContainer] = container, + loopEnv: Option[Env[LoopRef]] = loopEnv, + ): IEmitCode = this.emitI(ir, cb, region, env, container, loopEnv) - def emitInNewBuilder(cb: EmitCodeBuilder, ir: IR, region: Value[Region] = region, env: EmitEnv = env, container: Option[AggContainer] = container, loopEnv: Option[Env[LoopRef]] = loopEnv): IEmitCode = + def emitInNewBuilder( + cb: EmitCodeBuilder, + ir: IR, + region: Value[Region] = region, + env: EmitEnv = env, + container: Option[AggContainer] = container, + loopEnv: Option[Env[LoopRef]] = loopEnv, + ): IEmitCode = this.emitI(ir, cb, region, env, container, loopEnv) - def emitStream(ir: IR, cb: EmitCodeBuilder, outerRegion: Value[Region]): IEmitCode = + def emitStream(ir: IR, cb: EmitCodeBuilder, outerRegion: Value[Region], env: EmitEnv = env) + : IEmitCode = EmitStream.produce(this, ir, cb, cb.emb, outerRegion, env, container) - def emitVoid(ir: IR, env: EmitEnv = env, container: Option[AggContainer] = container, loopEnv: Option[Env[LoopRef]] = loopEnv): Unit = + def emitVoid( + ir: IR, + cb: EmitCodeBuilder = cb, + region: Value[Region] = region, + env: EmitEnv = env, + container: Option[AggContainer] = container, + loopEnv: Option[Env[LoopRef]] = loopEnv, + ): Unit = this.emitVoid(cb, ir: IR, region, env, container, loopEnv) - def emitFallback(ir: IR, env: EmitEnv = env, container: Option[AggContainer] = container, loopEnv: Option[Env[LoopRef]] = loopEnv): IEmitCode = + def emitFallback( + ir: IR, + env: EmitEnv = env, + container: Option[AggContainer] = container, + loopEnv: Option[Env[LoopRef]] = loopEnv, + ): IEmitCode = this.emit(ir, mb, region, env, container, loopEnv, fallingBackFromEmitI = true).toI(cb) - def 
emitDeforestedNDArrayI(ir: IR): IEmitCode = EmitNDArray(this, ir, cb, region, env, container, loopEnv) + def emitDeforestedNDArrayI(ir: IR): IEmitCode = + EmitNDArray(this, ir, cb, region, env, container, loopEnv) - def emitNDArrayColumnMajorStrides(ir: IR): IEmitCode = { + def emitNDArrayColumnMajorStrides(ir: IR): IEmitCode = emitI(ir).map(cb) { case pNDValue: SNDArrayValue => LinalgCodeUtils.checkColMajorAndCopyIfNeeded(pNDValue, cb, region) } - } - // Returns an IEmitCode along with a Boolean that is true if the returned value is column major. If false it's row + /* Returns an IEmitCode along with a Boolean that is true if the returned value is column major. + * If false it's row */ // major instead. - def emitNDArrayStandardStriding(ir: IR): IEmitCodeGen[(SNDArrayValue, Value[Boolean])] = { + def emitNDArrayStandardStriding(ir: IR): IEmitCodeGen[(SNDArrayValue, Value[Boolean])] = emitI(ir).map(cb) { case pNDValue: SNDArrayValue => LinalgCodeUtils.checkStandardStriding(pNDValue, cb, region) } - } - def typeWithReqx(node: IR): VirtualTypeWithReq = VirtualTypeWithReq(node.typ, ctx.req.lookup(node).asInstanceOf[TypeWithRequiredness]) + def typeWithReqx(node: IR): VirtualTypeWithReq = + VirtualTypeWithReq(node.typ, ctx.req.lookup(node).asInstanceOf[TypeWithRequiredness]) def typeWithReq: VirtualTypeWithReq = typeWithReqx(ir) @@ -863,7 +1055,7 @@ class Emit[C]( def presentPC(pc: SValue): IEmitCode = IEmitCode.present(cb, pc) val result: IEmitCode = (ir: @unchecked) match { - case In(i, expectedPType) => + case In(i, _) => val ev = env.inputValues(i) ev.toI(cb) case I32(x) => @@ -874,16 +1066,24 @@ class Emit[C]( presentPC(primitive(const(x))) case F64(x) => presentPC(primitive(const(x))) - case s@Str(x) => + case Str(x) => presentPC(mb.addLiteral(cb, x, typeWithReq)) - case x@UUID4(_) => + case UUID4(_) => val pt = PCanonicalString() - presentPC(pt.loadCheapSCode(cb, pt. 
- allocateAndStoreString(cb, region, Code.invokeScalaObject0[String]( - Class.forName("is.hail.expr.ir.package$"), "uuid4")))) - case x@Literal(t, v) => + presentPC(pt.loadCheapSCode( + cb, + pt.allocateAndStoreString( + cb, + region, + Code.invokeScalaObject0[String]( + Class.forName("is.hail.expr.ir.package$"), + "uuid4", + ), + ), + )) + case Literal(_, v) => presentPC(mb.addLiteral(cb, v, typeWithReq)) - case x@EncodedLiteral(codec, value) => + case x @ EncodedLiteral(_, _) => presentPC(mb.addEncodedLiteral(cb, x)) case True() => presentPC(primitive(const(true))) @@ -900,15 +1100,18 @@ class Emit[C]( iec.map(cb)(pc => cast(cb, pc)) case CastRename(v, _typ) => emitI(v) - .map(cb)(pc => pc.st.castRename(_typ).fromValues(pc.valueTuple)) + .map(cb)(_.castRename(_typ)) case NA(typ) => IEmitCode.missing(cb, SUnreachable.fromVirtualType(typ).defaultValue) case IsNA(v) => val m = emitI(v).consumeCode(cb, true, _ => false) presentPC(primitive(m)) - case Coalesce(values) => + case let: Block => + val newEnv = emitBlock(let, cb, env, region, container, loopEnv) + emitI(let.body, env = newEnv) + case Coalesce(values) => val emittedValues = values.map(v => EmitCode.fromI(cb.emb)(cb => emitInNewBuilder(cb, v))) val unifiedType = SType.chooseCompatibleType(typeWithReq, emittedValues.map(_.st): _*) val coalescedValue = mb.newPLocal("coalesce_value", unifiedType) @@ -917,12 +1120,14 @@ class Emit[C]( val Lmissing = CodeLabel() emittedValues.foreach { value => - value.toI(cb).consume(cb, + value.toI(cb).consume( + cb, {}, // fall through to next check { sc => cb.assign(coalescedValue, sc.castTo(cb, region, unifiedType)) cb.goto(Ldefined) - }) + }, + ) } cb.goto(Lmissing) @@ -933,8 +1138,6 @@ class Emit[C]( assert(cnsq.typ == altr.typ) emitI(cond).flatMap(cb) { case condValue: SBooleanValue => - - val codeCnsq = EmitCode.fromI(cb.emb)(cb => emitInNewBuilder(cb, cnsq)) val codeAltr = EmitCode.fromI(cb.emb)(cb => emitInNewBuilder(cb, altr)) val outType = SType.chooseCompatibleType(typeWithReq, codeCnsq.st, codeAltr.st) @@ -942,21 +1145,19 @@ class Emit[C]( val Lmissing = CodeLabel() val Ldefined = CodeLabel() val out = mb.newPLocal(outType) - cb.if_(condValue.value, { - codeCnsq.toI(cb).consume(cb, - { - cb.goto(Lmissing) - }, { sc => - cb.assign(out, sc.castTo(cb, region, outType)) - }) - }, { - codeAltr.toI(cb).consume(cb, - { - cb.goto(Lmissing) - }, { sc => - cb.assign(out, sc.castTo(cb, region, outType)) - }) - }) + cb.if_( + condValue.value, + codeCnsq.toI(cb).consume( + cb, + cb.goto(Lmissing), + sc => cb.assign(out, sc.castTo(cb, region, outType)), + ), + codeAltr.toI(cb).consume( + cb, + cb.goto(Lmissing), + sc => cb.assign(out, sc.castTo(cb, region, outType)), + ), + ) cb.goto(Ldefined) IEmitCode(Lmissing, Ldefined, out, codeCnsq.required && codeAltr.required) @@ -971,54 +1172,67 @@ class Emit[C]( val Ldefined = CodeLabel() val Lundefined = CodeLabel() - val sType = SType.chooseCompatibleType(typeWithReq, emitCases.map(_.st): _ *) + val sType = SType.chooseCompatibleType(typeWithReq, emitCases.map(_.st): _*) val res = cb.newSLocal(sType, genName("l", "switch")) def mkCase(cb: EmitCodeBuilder, case_ : EmitCode): Unit = - case_.toI(cb).consume(cb, { cb.goto(Lundefined) }, { svalue => - cb.assign(res, svalue.castTo(cb, region, sType)) - cb.goto(Ldefined) - }) + case_.toI(cb).consume( + cb, + cb.goto(Lundefined), + { svalue => + cb.assign(res, svalue.castTo(cb, region, sType)) + cb.goto(Ldefined) + }, + ) - cb.switch(x.value, mkCase(cb, emitCases.last), emitCases.init.map { case_ => - () => 
mkCase(cb, case_) - }) + cb.switch( + x.value, + mkCase(cb, emitCases.last), + emitCases.init.map(case_ => () => mkCase(cb, case_)), + ) IEmitCode(Lundefined, Ldefined, res, emitCases.forall(_.required)) } - case x@MakeStruct(fields) => - presentPC(SStackStruct.constructFromArgs(cb, region, x.typ.asInstanceOf[TBaseStruct], + case x @ MakeStruct(fields) => + presentPC(SStackStruct.constructFromArgs( + cb, + region, + x.typ.asInstanceOf[TBaseStruct], fields.map { case (_, x) => EmitCode.fromI(cb.emb)(cb => emitInNewBuilder(cb, x)) }: _* )) - case x@MakeTuple(fields) => - presentPC(SStackStruct.constructFromArgs(cb, region, x.typ.asInstanceOf[TBaseStruct], + case x @ MakeTuple(fields) => + presentPC(SStackStruct.constructFromArgs( + cb, + region, + x.typ.asInstanceOf[TBaseStruct], fields.map { case (_, x) => EmitCode.fromI(cb.emb)(cb => emitInNewBuilder(cb, x)) }: _* )) case SelectFields(oldStruct, fields) => - emitI(oldStruct).map(cb) { _.asBaseStruct.subset(fields: _*) } + emitI(oldStruct).map(cb)(_.asBaseStruct.subset(fields: _*)) - case x@InsertFields(old, fields, _) => + case x @ InsertFields(old, fields, _) => if (fields.isEmpty) emitI(old) else { emitI(old).map(cb) { case old: SBaseStructValue => val newFields = fields.map { case (name, x) => - (name, cb.memoize(EmitCode.fromI(cb.emb)(cb => emitInNewBuilder(cb, x)), "InsertFields")) + ( + name, + cb.memoize(EmitCode.fromI(cb.emb)(cb => emitInNewBuilder(cb, x)), "InsertFields"), + ) } old.insert(cb, region, x.typ, newFields: _*) } } case ApplyBinaryPrimOp(op, l, r) => - emitI(l).flatMap(cb) { pcL => - emitI(r).map(cb)(pcR => BinaryOp.emit(cb, op, pcL, pcR)) - } + emitI(l).flatMap(cb)(pcL => emitI(r).map(cb)(pcR => BinaryOp.emit(cb, op, pcL, pcR))) case ApplyUnaryPrimOp(op, x) => emitI(x).map(cb)(pc => UnaryOp.emit(cb, op, pc)) @@ -1038,14 +1252,13 @@ class Emit[C]( presentPC(primitive(ir.typ, f(cb, lc, rc))) } - case x@MakeArray(args, _) => - + case MakeArray(args, _) => val emittedArgs = args.map(a => EmitCode.fromI(mb)(cb => emitInNewBuilder(cb, a))) val pType = typeWithReq.canonicalPType.asInstanceOf[PCanonicalArray] - val (pushElement, finish) = pType.constructFromFunctions(cb, region, args.size, deepCopy = false) - for (arg <- emittedArgs) { + val (pushElement, finish) = + pType.constructFromFunctions(cb, region, args.size, deepCopy = false) + for (arg <- emittedArgs) pushElement(cb, arg.toI(cb)) - } presentPC(finish(cb)) case ArrayZeros(length) => @@ -1056,32 +1269,38 @@ class Emit[C]( } case ArrayRef(a, i, errorID) => - def boundsCheck(cb: EmitCodeBuilder, index: Value[Int], len: Value[Int]): Unit = { - val bcMb = mb.getOrGenEmitMethod("arrayref_bounds_check", "arrayref_bounds_check", - IndexedSeq[ParamType](IntInfo, IntInfo, IntInfo), UnitInfo)({ mb => - mb.voidWithBuilder { cb => - val index = mb.getCodeParam[Int](1) - val len = mb.getCodeParam[Int](2) - val errorID = mb.getCodeParam[Int](3) - cb.if_(index < 0 || index >= len, { - cb._fatalWithError(errorID, const("array index out of bounds: index=") + val boundsCheck: EmitMethodBuilder[_] = + mb.ecb.getOrGenEmitMethod( + "arrayref_bounds_check", + "arrayref_bounds_check", + FastSeq(IntInfo, IntInfo, IntInfo), + UnitInfo, + ) { mb => + mb.voidWithBuilder { cb => + val index = mb.getCodeParam[Int](1) + val len = mb.getCodeParam[Int](2) + val errorID = mb.getCodeParam[Int](3) + cb.if_( + index < 0 || index >= len, + cb._fatalWithError( + errorID, + const("array index out of bounds: index=") .concat(index.toS) .concat(", length=") - .concat(len.toS)) - }) - - } - }) - 
cb.invokeVoid(bcMb, index, len, const(errorID)) - } + .concat(len.toS), + ), + ) + } + } emitI(a).flatMap(cb) { case av: SIndexableValue => emitI(i).flatMap(cb) { case ic: SInt32Value => val iv = ic.value - boundsCheck(cb, iv, av.loadLength()) + cb.invokeVoid(boundsCheck, cb.this_, iv, av.loadLength(), const(errorID)) av.loadElement(cb, iv) } } + case ArraySlice(a, start, stop, step, errorID) => emitI(a).flatMap(cb) { case arrayValue: SIndexableValue => emitI(start).flatMap(cb) { startCode => @@ -1089,39 +1308,70 @@ class Emit[C]( val arrayLength = arrayValue.loadLength() val realStep = cb.newLocal[Int]("array_slice_requestedStep", stepCode.asInt.value) - cb.if_(realStep ceq const(0), cb._fatalWithError(const(errorID), const("step cannot be 0 for array slice"))) + cb.if_( + realStep ceq const(0), + cb._fatalWithError(const(errorID), const("step cannot be 0 for array slice")), + ) val noneStop = cb.newLocal[Int]("array_slice_noneStop") - cb.if_(realStep < 0, cb.assign(noneStop, const(-1) * arrayLength - const(1)), cb.assign(noneStop, arrayLength)) + cb.if_( + realStep < 0, + cb.assign(noneStop, const(-1) * arrayLength - const(1)), + cb.assign(noneStop, arrayLength), + ) val maxBound = cb.newLocal[Int]("array_slice_maxBound") val minBound = cb.newLocal[Int]("array_slice_minBound") - cb.if_(realStep > 0, cb.assign(maxBound, arrayLength), cb.assign(maxBound, arrayLength - 1)) + cb.if_( + realStep > 0, + cb.assign(maxBound, arrayLength), + cb.assign(maxBound, arrayLength - 1), + ) cb.if_(realStep > 0, cb.assign(minBound, 0), cb.assign(minBound, -1)) - val stopI = stop.map(emitI(_)).getOrElse(IEmitCode.present(cb, new SInt32Value(noneStop))) + val stopI = + stop.map(emitI(_)).getOrElse(IEmitCode.present(cb, new SInt32Value(noneStop))) stopI.map(cb) { stopCode => - val requestedStart = cb.newLocal[Int]("array_slice_requestedStart", startCode.asInt.value) + val requestedStart = + cb.newLocal[Int]("array_slice_requestedStart", startCode.asInt.value) val realStart = cb.newLocal[Int]("array_slice_realStart") - cb.if_(requestedStart >= arrayLength, + cb.if_( + requestedStart >= arrayLength, cb.assign(realStart, maxBound), - cb.if_(requestedStart >= 0, cb.assign(realStart, requestedStart), - cb.if_(arrayLength + requestedStart >= 0, + cb.if_( + requestedStart >= 0, + cb.assign(realStart, requestedStart), + cb.if_( + arrayLength + requestedStart >= 0, cb.assign(realStart, arrayLength + requestedStart), - cb.assign(realStart, minBound)))) + cb.assign(realStart, minBound), + ), + ), + ) - val requestedStop = cb.newLocal[Int]("array_slice_requestedStop", stopCode.asInt.value) + val requestedStop = + cb.newLocal[Int]("array_slice_requestedStop", stopCode.asInt.value) val realStop = cb.newLocal[Int]("array_slice_realStop") - cb.if_(requestedStop > arrayLength, + cb.if_( + requestedStop > arrayLength, cb.assign(realStop, maxBound), - cb.if_(requestedStop >= 0, + cb.if_( + requestedStop >= 0, cb.assign(realStop, requestedStop), - cb.if_(arrayLength + requestedStop > 0, + cb.if_( + arrayLength + requestedStop > 0, cb.assign(realStop, arrayLength + requestedStop), - cb.assign(realStop, minBound)))) + cb.assign(realStop, minBound), + ), + ), + ) - val resultLen = cb.newLocal[Int]("array_slice_resultLength", (realStop - realStart) / realStep) - cb.if_(((realStop - realStart) % realStep cne 0), cb.assign(resultLen, resultLen + 1)) + val resultLen = + cb.newLocal[Int]("array_slice_resultLength", (realStop - realStart) / realStep) + cb.if_( + ((realStop - realStart) % realStep cne 0), + cb.assign(resultLen, resultLen 
+ 1), + ) cb.if_(resultLen < 0, cb.assign(resultLen, 0)) val resultArray = typeWithReq.canonicalPType.asInstanceOf[PCanonicalArray] @@ -1134,38 +1384,39 @@ class Emit[C]( } case ArrayLen(a) => - emitI(a).map(cb) { ac => - primitive(ac.asIndexable.loadLength()) - } + emitI(a).map(cb)(ac => primitive(ac.asIndexable.loadLength())) case GetField(o, name) => - emitI(o).flatMap(cb) { oc => - oc.asBaseStruct.loadField(cb, name) - } + emitI(o).flatMap(cb)(oc => oc.asBaseStruct.loadField(cb, name)) case GetTupleElement(o, i) => emitI(o).flatMap(cb) { oc => oc.asBaseStruct.loadField(cb, o.typ.asInstanceOf[TTuple].fieldIndex(i)) } - case x@LowerBoundOnOrderedCollection(orderedCollection, elem, onKey) => + case LowerBoundOnOrderedCollection(orderedCollection, elem, onKey) => emitI(orderedCollection).map(cb) { a => - val e = EmitCode.fromI(cb.emb)(cb => this.emitI(elem, cb, region, env, container, loopEnv)) - val bs = new BinarySearch[C](mb, a.st.asInstanceOf[SContainer], e.emitType, { (cb, elt) => - - if (onKey) { - cb.memoize(elt.toI(cb).flatMap(cb) { - case x: SBaseStructValue => - x.loadField(cb, 0) - case x: SIntervalValue => - x.loadStart(cb) - }) - } else elt - }) + val e = + EmitCode.fromI(cb.emb)(cb => this.emitI(elem, cb, region, env, container, loopEnv)) + val bs = new BinarySearch[C]( + mb, + a.st.asInstanceOf[SContainer], + e.emitType, + { (cb, elt) => + if (onKey) { + cb.memoize(elt.toI(cb).flatMap(cb) { + case x: SBaseStructValue => + x.loadField(cb, 0) + case x: SIntervalValue => + x.loadStart(cb) + }) + } else elt + }, + ) primitive(bs.search(cb, a, e)) } - case x@ArraySort(a, left, right, lessThan) => + case x @ ArraySort(a, left, right, lessThan) => emitStream(a, cb, region).map(cb) { case stream: SStreamValue => val producer = stream.getProducer(mb) @@ -1174,30 +1425,73 @@ class Emit[C]( val vab = new StagedArrayBuilder(cb, sct, producer.element.required, 0) StreamUtils.writeToArrayBuilder(cb, producer, vab, region) val sorter = new ArraySorter(EmitRegion(mb, region), vab) - sorter.sort(cb, region, makeDependentSortingFunction(cb, sct, lessThan, env, emitSelf, Array(left, right))) + sorter.sort( + cb, + region, + makeDependentSortingFunction(cb, sct, lessThan, env, this, Array(left, right)), + ) sorter.toRegion(cb, x.typ) } case ArrayMaximalIndependentSet(edges, tieBreaker) => emitI(edges).map(cb) { edgesCode => - val jEdges: Value[UnsafeIndexedSeq] = cb.memoize(Code.checkcast[UnsafeIndexedSeq]((is.hail.expr.ir.functions.ArrayFunctions.svalueToJavaValue(cb, region, edgesCode)))) + val jEdges: Value[UnsafeIndexedSeq] = cb.memoize(Code.checkcast[UnsafeIndexedSeq]( + (is.hail.expr.ir.functions.ArrayFunctions.svalueToJavaValue(cb, region, edgesCode)) + )) val ms = tieBreaker match { case None => - Code.invokeScalaObject1[UnsafeIndexedSeq, IndexedSeq[Any]](Graph.getClass, "maximalIndependentSet", jEdges) + Code.invokeScalaObject1[UnsafeIndexedSeq, IndexedSeq[Any]]( + Graph.getClass, + "maximalIndependentSet", + jEdges, + ) case Some((leftName, rightName, tieBreaker)) => - val nodeType = tcoerce[TArray](edges.typ).elementType.asInstanceOf[TBaseStruct].types.head + val nodeType = + tcoerce[TArray](edges.typ).elementType.asInstanceOf[TBaseStruct].types.head val wrappedNodeType = PCanonicalTuple(true, PType.canonical(nodeType)) - val (Some(PTypeReferenceSingleCodeType(t)), f) = Compile[AsmFunction3RegionLongLongLong](ctx.executeContext, - IndexedSeq((leftName, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(wrappedNodeType))), - (rightName, SingleCodeEmitParamType(true, 
PTypeReferenceSingleCodeType(wrappedNodeType)))), - FastSeq(classInfo[Region], LongInfo, LongInfo), LongInfo, - MakeTuple.ordered(FastSeq(tieBreaker))) + val (Some(PTypeReferenceSingleCodeType(t)), f) = + Compile[AsmFunction3RegionLongLongLong]( + ctx.executeContext, + IndexedSeq( + ( + leftName, + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(wrappedNodeType)), + ), + ( + rightName, + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(wrappedNodeType)), + ), + ), + FastSeq(classInfo[Region], LongInfo, LongInfo), + LongInfo, + MakeTuple.ordered(FastSeq(tieBreaker)), + ) assert(t.virtualType == TTuple(TFloat64)) val resultType = t.asInstanceOf[PTuple] - Code.invokeScalaObject9[Map[String, ReferenceGenome], UnsafeIndexedSeq, HailClassLoader, FS, HailTaskContext, Region, PTuple, PTuple, (HailClassLoader, FS, HailTaskContext, Region) => AsmFunction3RegionLongLongLong, IndexedSeq[Any]]( - Graph.getClass, "maximalIndependentSet", cb.emb.ecb.emodb.referenceGenomeMap, - jEdges, mb.getHailClassLoader, mb.getFS, mb.getTaskContext, region, - mb.getPType[PTuple](wrappedNodeType), mb.getPType[PTuple](resultType), mb.getObject(f)) + Code.invokeScalaObject9[ + Map[String, ReferenceGenome], + UnsafeIndexedSeq, + HailClassLoader, + FS, + HailTaskContext, + Region, + PTuple, + PTuple, + (HailClassLoader, FS, HailTaskContext, Region) => AsmFunction3RegionLongLongLong, + IndexedSeq[Any], + ]( + Graph.getClass, + "maximalIndependentSet", + cb.emb.ecb.emodb.referenceGenomeMap, + jEdges, + mb.getHailClassLoader, + mb.getFS, + mb.getTaskContext, + region, + mb.getPType[PTuple](wrappedNodeType), + mb.getPType[PTuple](resultType), + mb.getObject(f), + ) } val (rt, maxSet: Code[_]) = typeWithReq.t match { @@ -1205,9 +1499,15 @@ class Emit[C]( val rawSet = cb.memoize(ms) val maxSet = cb.memoize(Code.newArray[String](rawSet.invoke[Int]("length"))) val i = cb.newLocal[Int]("mis_str_iseq_to_arr_i") - cb.for_(cb.assign(i, 0), i < maxSet.length(), cb.assign(i, i + 1), { - cb += maxSet.update(i, Code.checkcast[String](rawSet.invoke[Int, java.lang.Object]("apply", i))) - }) + cb.for_( + cb.assign(i, 0), + i < maxSet.length(), + cb.assign(i, i + 1), + cb += maxSet.update( + i, + Code.checkcast[String](rawSet.invoke[Int, java.lang.Object]("apply", i)), + ), + ) SJavaArrayString(typeWithReq.r.required) -> maxSet.get case _ => @@ -1216,7 +1516,7 @@ class Emit[C]( is.hail.expr.ir.functions.ArrayFunctions.unwrapReturn(cb, region, rt, maxSet) } - case x@ToSet(a) => + case x @ ToSet(a) => emitStream(a, cb, region).map(cb) { case stream: SStreamValue => val producer = stream.getProducer(mb) @@ -1226,23 +1526,23 @@ class Emit[C]( StreamUtils.writeToArrayBuilder(cb, producer, vab, region) val sorter = new ArraySorter(EmitRegion(mb, region), vab) - def lessThan(cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]): Value[Boolean] = { + def lessThan(cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]) + : Value[Boolean] = cb.emb.ecb.getOrdering(sct.loadedSType, sct.loadedSType) .ltNonnull(cb, sct.loadToSValue(cb, l), sct.loadToSValue(cb, r)) - } sorter.sort(cb, region, lessThan) - def skipNext(cb: EmitCodeBuilder, region: Value[Region], l: EmitCode, r: EmitCode): Value[Boolean] = { + def skipNext(cb: EmitCodeBuilder, region: Value[Region], l: EmitCode, r: EmitCode) + : Value[Boolean] = cb.emb.ecb.getOrdering(l.st, r.st) .equiv(cb, cb.memoize(l), cb.memoize(r), missingEqual = true) - } sorter.distinctFromSorted(cb, region, skipNext) sorter.toRegion(cb, x.typ) } - case x@ToDict(a) => + 
case x @ ToDict(a) => emitStream(a, cb, region).map(cb) { case stream: SStreamValue => val producer = stream.getProducer(mb) @@ -1252,7 +1552,8 @@ class Emit[C]( StreamUtils.writeToArrayBuilder(cb, producer, vab, region) val sorter = new ArraySorter(EmitRegion(mb, region), vab) - def lessThan(cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]): Value[Boolean] = { + def lessThan(cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]) + : Value[Boolean] = { val lk = cb.memoize(sct.loadToSValue(cb, l).asBaseStruct.loadField(cb, 0)) val rk = cb.memoize(sct.loadToSValue(cb, r).asBaseStruct.loadField(cb, 0)) @@ -1264,17 +1565,16 @@ class Emit[C]( sorter.sort(cb, region, lessThan) sorter.pruneMissing(cb) - def skipNext(cb: EmitCodeBuilder, region: Value[Region], l: EmitCode, r: EmitCode): Code[Boolean] = { + def skipNext(cb: EmitCodeBuilder, region: Value[Region], l: EmitCode, r: EmitCode) + : Code[Boolean] = { val lk = cb.memoize( - l.toI(cb).flatMap(cb) { x => - x.asBaseStruct.loadField(cb, 0) - }) + l.toI(cb).flatMap(cb)(x => x.asBaseStruct.loadField(cb, 0)) + ) val rk = cb.memoize( - r.toI(cb).flatMap(cb) { x => - x.asBaseStruct.loadField(cb, 0) - }) + r.toI(cb).flatMap(cb)(x => x.asBaseStruct.loadField(cb, 0)) + ) cb.emb.ecb.getOrdering(lk.st, rk.st) .equiv(cb, lk, rk, missingEqual = true) @@ -1286,14 +1586,14 @@ class Emit[C]( case GroupByKey(collection) => emitStream(collection, cb, region).map(cb) { case stream: SStreamValue => - val producer = stream.getProducer(mb) val sct = SingleCodeType.fromSType(producer.element.st) val sortedElts = new StagedArrayBuilder(cb, sct, producer.element.required, 16) StreamUtils.writeToArrayBuilder(cb, producer, sortedElts, region) val sorter = new ArraySorter(EmitRegion(mb, region), sortedElts) - def lt(cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]): Value[Boolean] = { + def lt(cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]) + : Value[Boolean] = { val lk = cb.memoize(sct.loadToSValue(cb, l).asBaseStruct.loadField(cb, 0)) val rk = cb.memoize(sct.loadToSValue(cb, r).asBaseStruct.loadField(cb, 0)) @@ -1315,42 +1615,51 @@ class Emit[C]( cb.assign(eltIdx, 0) cb.assign(groupSize, 0) - def sameKeyAtIndices(cb: EmitCodeBuilder, region: Value[Region], idx1: Value[Int], idx2: Value[Int]): Code[Boolean] = { + def sameKeyAtIndices( + cb: EmitCodeBuilder, + region: Value[Region], + idx1: Value[Int], + idx2: Value[Int], + ): Code[Boolean] = { val lk = cb.memoize( sortedElts.loadFromIndex(cb, region, idx1).flatMap(cb) { x => x.asBaseStruct.loadField(cb, 0) - }) + } + ) val rk = cb.memoize( sortedElts.loadFromIndex(cb, region, idx2).flatMap(cb) { x => x.asBaseStruct.loadField(cb, 0) - }) + } + ) cb.emb.ecb.getOrdering(lk.st, rk.st) .equiv(cb, lk, rk, missingEqual = true) } - cb.while_(eltIdx < sortedElts.size, { - val bottomOfLoop = CodeLabel() - val newGroup = CodeLabel() - - cb.assign(groupSize, groupSize + 1) - cb.if_(eltIdx.ceq(sortedElts.size - 1), { - cb.goto(newGroup) - }, { - cb.if_(sameKeyAtIndices(cb, region, eltIdx, cb.memoize(eltIdx + 1)), { - cb.goto(bottomOfLoop) - }, { - cb.goto(newGroup) - }) - }) - cb.define(newGroup) - groupSizes.add(cb, groupSize) - cb.assign(groupSize, 0) - - cb.define(bottomOfLoop) - cb.assign(eltIdx, eltIdx + 1) - }) + cb.while_( + eltIdx < sortedElts.size, { + val bottomOfLoop = CodeLabel() + val newGroup = CodeLabel() + + cb.assign(groupSize, groupSize + 1) + cb.if_( + eltIdx.ceq(sortedElts.size - 1), + cb.goto(newGroup), + cb.if_( + sameKeyAtIndices(cb, 
region, eltIdx, cb.memoize(eltIdx + 1)), + cb.goto(bottomOfLoop), + cb.goto(newGroup), + ), + ) + cb.define(newGroup) + groupSizes.add(cb, groupSize) + cb.assign(groupSize, 0) + + cb.define(bottomOfLoop) + cb.assign(eltIdx, eltIdx + 1) + }, + ) cb.assign(outerSize, groupSizes.size) val loadedElementType = sct.loadedSType.asInstanceOf[SBaseStruct] @@ -1358,31 +1667,39 @@ class Emit[C]( val kt = loadedElementType.fieldEmitTypes(0).storageType val groupType = PCanonicalStruct(true, ("key", kt), ("value", innerType)) val dictType = PCanonicalDict(kt, innerType, false) - val (addGroup, finishOuter) = dictType.arrayRep.constructFromFunctions(cb, region, outerSize, deepCopy = false) + val (addGroup, finishOuter) = + dictType.arrayRep.constructFromFunctions(cb, region, outerSize, deepCopy = false) cb.assign(eltIdx, 0) cb.assign(grpIdx, 0) - cb.while_(grpIdx < outerSize, { - cb.assign(groupSize, coerce[Int](groupSizes(grpIdx))) - cb.assign(withinGrpIdx, 0) - val firstStruct = sortedElts.loadFromIndex(cb, region, eltIdx).get(cb).asBaseStruct - val key = EmitCode.fromI(mb) { cb => firstStruct.loadField(cb, 0) } - val group = EmitCode.fromI(mb) { cb => - val (addElt, finishInner) = innerType - .constructFromFunctions(cb, region, groupSize, deepCopy = false) - cb.while_(withinGrpIdx < groupSize, { - val struct = sortedElts.loadFromIndex(cb, region, eltIdx).get(cb).asBaseStruct - addElt(cb, struct.loadField(cb, 1)) - cb.assign(eltIdx, eltIdx + 1) - cb.assign(withinGrpIdx, withinGrpIdx + 1) - }) - IEmitCode.present(cb, finishInner(cb)) - } - val elt = groupType.constructFromFields(cb, region, FastSeq(key, group), deepCopy = false) - addGroup(cb, IEmitCode.present(cb, elt)) - cb.assign(grpIdx, grpIdx + 1) - }) + cb.while_( + grpIdx < outerSize, { + cb.assign(groupSize, coerce[Int](groupSizes(grpIdx))) + cb.assign(withinGrpIdx, 0) + val firstStruct = + sortedElts.loadFromIndex(cb, region, eltIdx).getOrAssert(cb).asBaseStruct + val key = EmitCode.fromI(mb)(cb => firstStruct.loadField(cb, 0)) + val group = EmitCode.fromI(mb) { cb => + val (addElt, finishInner) = innerType + .constructFromFunctions(cb, region, groupSize, deepCopy = false) + cb.while_( + withinGrpIdx < groupSize, { + val struct = + sortedElts.loadFromIndex(cb, region, eltIdx).getOrAssert(cb).asBaseStruct + addElt(cb, struct.loadField(cb, 1)) + cb.assign(eltIdx, eltIdx + 1) + cb.assign(withinGrpIdx, withinGrpIdx + 1) + }, + ) + IEmitCode.present(cb, finishInner(cb)) + } + val elt = + groupType.constructFromFields(cb, region, FastSeq(key, group), deepCopy = false) + addGroup(cb, IEmitCode.present(cb, elt)) + cb.assign(grpIdx, grpIdx + 1) + }, + ) dictType.construct(finishOuter(cb)) } @@ -1391,15 +1708,15 @@ class Emit[C]( IEmitCode.present(cb, SRNGStateStaticSizeValue(cb)) case RNGSplit(state, dynBitstring) => - val stateValue = emitI(state).get(cb) - val tupleOrLong = emitI(dynBitstring).get(cb) + val stateValue = emitI(state).getOrAssert(cb) + val tupleOrLong = emitI(dynBitstring).getOrAssert(cb) val longs = if (tupleOrLong.isInstanceOf[SInt64Value]) { Array(tupleOrLong.asInt64.value) } else { val tuple = tupleOrLong.asBaseStruct Array.tabulate(tuple.st.size) { i => tuple.loadField(cb, i) - .get(cb, "RNGSplit tuple components are required") + .getOrFatal(cb, "RNGSplit tuple components are required") .asInt64 .value } @@ -1408,7 +1725,7 @@ class Emit[C]( longs.foreach(l => result = result.splitDyn(cb, l)) presentPC(result) - case x@StreamLen(a) => + case StreamLen(a) => emitStream(a, cb, region).map(cb) { case stream: SStreamValue => val 
producer = stream.getProducer(mb) producer.length match { @@ -1435,12 +1752,13 @@ class Emit[C]( emitI(path).flatMap(cb) { pathValue => emitI(pivots).flatMap(cb) { case pivotsVal: SIndexableValue => emitStream(child, cb, region).map(cb) { case childStream: SStreamValue => - EmitStreamDistribute.emit(cb, region, pivotsVal, childStream, pathValue, comparisonOp, spec) + EmitStreamDistribute.emit(cb, region, pivotsVal, childStream, pathValue, comparisonOp, + spec) } } } - case x@MakeNDArray(dataIR, shapeIR, rowMajorIR, errorId) => + case x @ MakeNDArray(dataIR, shapeIR, rowMajorIR, errorId) => val nDims = x.typ.nDims emitI(rowMajorIR).flatMap(cb) { isRowMajorCode => @@ -1448,73 +1766,110 @@ class Emit[C]( dataIR.typ match { case _: TArray => emitI(dataIR).map(cb) { case dataValue: SIndexableValue => - val xP = PCanonicalNDArray(PType.canonical(dataValue.st.elementType.storageType().setRequired(true)), nDims) + val xP = PCanonicalNDArray( + PType.canonical(dataValue.st.elementType.storageType().setRequired(true)), + nDims, + ) - cb.if_(dataValue.hasMissingValues(cb), { + cb.if_( + dataValue.hasMissingValues(cb), cb._throw(Code.newInstance[HailException, String, Int]( - "Cannot construct an ndarray with missing values.", errorId - )) - }) + "Cannot construct an ndarray with missing values.", + errorId, + )), + ) (0 until nDims).foreach { index => - cb.if_(shapeTupleValue.isFieldMissing(cb, index), - cb._fatalWithError(errorId, s"shape missing at index $index")) + cb.if_( + shapeTupleValue.isFieldMissing(cb, index), + cb._fatalWithError(errorId, s"shape missing at index $index"), + ) } - val stridesSettables = (0 until nDims).map(i => cb.newLocal[Long](s"make_ndarray_stride_$i")) + val stridesSettables = + (0 until nDims).map(i => cb.newLocal[Long](s"make_ndarray_stride_$i")) val shapeValues = (0 until nDims).map { i => - val shape = SingleCodeSCode.fromSCode(cb, shapeTupleValue.loadField(cb, i).get(cb), region) - cb.newLocal[Long](s"make_ndarray_shape_${ i }", coerce[Long](shape.code)) + val shape = SingleCodeSCode.fromSCode( + cb, + shapeTupleValue.loadField(cb, i).getOrAssert(cb), + region, + ) + cb.newLocal[Long](s"make_ndarray_shape_$i", coerce[Long](shape.code)) } - cb.if_(isRowMajorCode.asBoolean.value, { - val strides = xP.makeRowMajorStrides(shapeValues, cb) + cb.if_( + isRowMajorCode.asBoolean.value, { + val strides = xP.makeRowMajorStrides(shapeValues, cb) - stridesSettables.zip(strides).foreach { case (settable, stride) => - cb.assign(settable, stride) - } - }, { - val strides = xP.makeColumnMajorStrides(shapeValues, cb) - stridesSettables.zip(strides).foreach { case (settable, stride) => - cb.assign(settable, stride) - } - }) + stridesSettables.zip(strides).foreach { case (settable, stride) => + cb.assign(settable, stride) + } + }, { + val strides = xP.makeColumnMajorStrides(shapeValues, cb) + stridesSettables.zip(strides).foreach { case (settable, stride) => + cb.assign(settable, stride) + } + }, + ) - xP.constructByCopyingArray(shapeValues, stridesSettables, dataValue.asIndexable, cb, region) + xP.constructByCopyingArray( + shapeValues, + stridesSettables, + dataValue.asIndexable, + cb, + region, + ) } case _: TStream => EmitStream.produce(this, dataIR, cb, cb.emb, region, env, container) .map(cb) { case stream: SStreamValue => - val xP = PCanonicalNDArray(PType.canonical(stream.st.elementType.storageType().setRequired(true)), nDims) + val xP = PCanonicalNDArray( + PType.canonical(stream.st.elementType.storageType().setRequired(true)), + nDims, + ) (0 until nDims).foreach { index 
=> - cb.if_(shapeTupleValue.isFieldMissing(cb, index), - cb.append(Code._fatal[Unit](s"shape missing at index $index"))) + cb.if_( + shapeTupleValue.isFieldMissing(cb, index), + cb.append(Code._fatal[Unit](s"shape missing at index $index")), + ) } - val stridesSettables = (0 until nDims).map(i => cb.newLocal[Long](s"make_ndarray_stride_$i")) + val stridesSettables = + (0 until nDims).map(i => cb.newLocal[Long](s"make_ndarray_stride_$i")) val shapeValues = (0 until nDims).map { i => - cb.newLocal[Long](s"make_ndarray_shape_${i}", shapeTupleValue.loadField(cb, i).get(cb).asLong.value) + cb.newLocal[Long]( + s"make_ndarray_shape_$i", + shapeTupleValue.loadField(cb, i).getOrAssert(cb).asLong.value, + ) } - cb.if_(isRowMajorCode.asBoolean.value, { - val strides = xP.makeRowMajorStrides(shapeValues, cb) - - - stridesSettables.zip(strides).foreach { case (settable, stride) => - cb.assign(settable, stride) - } - }, { - val strides = xP.makeColumnMajorStrides(shapeValues, cb) - stridesSettables.zip(strides).foreach { case (settable, stride) => - cb.assign(settable, stride) - } - }) - - val (firstElementAddress, finisher) = xP.constructDataFunction(shapeValues, stridesSettables, cb, region) - StreamUtils.storeNDArrayElementsAtAddress(cb, stream.getProducer(mb), region, firstElementAddress, errorId) + cb.if_( + isRowMajorCode.asBoolean.value, { + val strides = xP.makeRowMajorStrides(shapeValues, cb) + + stridesSettables.zip(strides).foreach { case (settable, stride) => + cb.assign(settable, stride) + } + }, { + val strides = xP.makeColumnMajorStrides(shapeValues, cb) + stridesSettables.zip(strides).foreach { case (settable, stride) => + cb.assign(settable, stride) + } + }, + ) + + val (firstElementAddress, finisher) = + xP.constructDataFunction(shapeValues, stridesSettables, cb, region) + StreamUtils.storeNDArrayElementsAtAddress( + cb, + stream.getProducer(mb), + region, + firstElementAddress, + errorId, + ) finisher(cb) } } @@ -1524,7 +1879,7 @@ class Emit[C]( case NDArrayShape(ndIR) => emitI(ndIR).map(cb) { case pc: SNDArrayValue => pc.shapeStruct(cb) } - case x@NDArrayReindex(child, indexMap) => + case NDArrayReindex(child, indexMap) => val childEC = emitI(child) childEC.map(cb) { case sndVal: SNDArrayPointerValue => val childPType = sndVal.st.pType @@ -1555,7 +1910,7 @@ class Emit[C]( ndt.flatMap(cb) { case ndValue: SNDArrayValue => val indexEmitCodes = idxs.map(idx => EmitCode.fromI(cb.emb)(emitInNewBuilder(_, idx))) IEmitCode.multiMapEmitCodes(cb, indexEmitCodes) { idxPCodes: IndexedSeq[SValue] => - val idxValues = idxPCodes.zipWithIndex.map { case (pc, idx) => + val idxValues = idxPCodes.zipWithIndex.map { case (pc, _) => pc.asInt64.value } @@ -1566,174 +1921,254 @@ class Emit[C]( } case NDArrayMatMul(lChild, rChild, errorID) => - emitNDArrayStandardStriding(lChild).flatMap(cb) { case (leftPVal: SNDArrayValue, leftIsColumnMajor: Value[Boolean]) => - emitNDArrayStandardStriding(rChild).map(cb) { case (rightPVal: SNDArrayValue, rightIsColumnMajor: Value[Boolean]) => - val lSType = leftPVal.st - val rSType = rightPVal.st - - val lShape = leftPVal.shapes - val rShape = rightPVal.shapes - - val unifiedShape = NDArrayEmitter.matmulShape(cb, lShape, rShape, errorID) - - val leftBroadcastMask = if (lSType.nDims > 2) NDArrayEmitter.broadcastMask(lShape) else IndexedSeq[Value[Long]]() - val rightBroadcastMask = if (rSType.nDims > 2) NDArrayEmitter.broadcastMask(rShape) else IndexedSeq[Value[Long]]() + emitNDArrayStandardStriding(lChild).flatMap(cb) { + case (leftPVal: SNDArrayValue, leftIsColumnMajor: 
Value[Boolean]) => + emitNDArrayStandardStriding(rChild).map(cb) { + case (rightPVal: SNDArrayValue, rightIsColumnMajor: Value[Boolean]) => + val lSType = leftPVal.st + val rSType = rightPVal.st + + val lShape = leftPVal.shapes + val rShape = rightPVal.shapes + + val unifiedShape = NDArrayEmitter.matmulShape(cb, lShape, rShape, errorID) + + val leftBroadcastMask = if (lSType.nDims > 2) NDArrayEmitter.broadcastMask(lShape) + else IndexedSeq[Value[Long]]() + val rightBroadcastMask = if (rSType.nDims > 2) NDArrayEmitter.broadcastMask(rShape) + else IndexedSeq[Value[Long]]() + + val outputPType = PCanonicalNDArray( + lSType.elementType.storageType().setRequired(true), + TNDArray.matMulNDims(lSType.nDims, rSType.nDims), + ) - val outputPType = PCanonicalNDArray(lSType.elementType.storageType().setRequired(true), - TNDArray.matMulNDims(lSType.nDims, rSType.nDims)) + if ( + (lSType.elementType.virtualType == TFloat64 || lSType.elementType.virtualType == TFloat32) && lSType.nDims == 2 && rSType.nDims == 2 + ) { + val leftDataAddress = leftPVal.firstDataAddress + val rightDataAddress = rightPVal.firstDataAddress - if ((lSType.elementType.virtualType == TFloat64 || lSType.elementType.virtualType == TFloat32) && lSType.nDims == 2 && rSType.nDims == 2) { - val leftDataAddress = leftPVal.firstDataAddress - val rightDataAddress = rightPVal.firstDataAddress + val M = lShape(lSType.nDims - 2) + val N = rShape(rSType.nDims - 1) + val K = lShape(lSType.nDims - 1) - val M = lShape(lSType.nDims - 2) - val N = rShape(rSType.nDims - 1) - val K = lShape(lSType.nDims - 1) + val LDA = leftIsColumnMajor.mux(M, K) + val LDB = rightIsColumnMajor.mux(K, N) + val LDC = M - val LDA = leftIsColumnMajor.mux(M, K) - val LDB = rightIsColumnMajor.mux(K, N) - val LDC = M + val TRANSA: Code[String] = leftIsColumnMajor.mux("N", "T") + val TRANSB: Code[String] = rightIsColumnMajor.mux("N", "T") - val TRANSA: Code[String] = leftIsColumnMajor.mux("N", "T") - val TRANSB: Code[String] = rightIsColumnMajor.mux("N", "T") + val (answerFirstElementAddr, answerFinisher) = outputPType.constructDataFunction( + IndexedSeq(M, N), + outputPType.makeColumnMajorStrides(IndexedSeq(M, N), cb), + cb, + region, + ) - val (answerFirstElementAddr, answerFinisher) = outputPType.constructDataFunction( - IndexedSeq(M, N), - outputPType.makeColumnMajorStrides(IndexedSeq(M, N), cb), - cb, - region) - - cb.if_((M.get cne 0L) && (N.get cne 0L) && (K.get cne 0L), { - cb.append(lSType.elementType.virtualType match { - case TFloat32 => - Code.invokeScalaObject13[String, String, Int, Int, Int, Float, Long, Int, Long, Int, Float, Long, Int, Unit](BLAS.getClass, method = "sgemm", - TRANSA, - TRANSB, - M.toI, - N.toI, - K.toI, - 1.0f, - leftDataAddress, - LDA.toI, - rightDataAddress, - LDB.toI, - 0.0f, - answerFirstElementAddr, - LDC.toI - ) - case TFloat64 => - Code.invokeScalaObject13[String, String, Int, Int, Int, Double, Long, Int, Long, Int, Double, Long, Int, Unit](BLAS.getClass, method = "dgemm", - TRANSA, - TRANSB, - M.toI, - N.toI, - K.toI, - 1.0, - leftDataAddress, - LDA.toI, - rightDataAddress, - LDB.toI, - 0.0, + cb.if_( + (M.get cne 0L) && (N.get cne 0L) && (K.get cne 0L), { + cb.append(lSType.elementType.virtualType match { + case TFloat32 => + Code.invokeScalaObject13[ + String, + String, + Int, + Int, + Int, + Float, + Long, + Int, + Long, + Int, + Float, + Long, + Int, + Unit, + ]( + BLAS.getClass, + method = "sgemm", + TRANSA, + TRANSB, + M.toI, + N.toI, + K.toI, + 1.0f, + leftDataAddress, + LDA.toI, + rightDataAddress, + LDB.toI, + 0.0f, + 
answerFirstElementAddr, + LDC.toI, + ) + case TFloat64 => + Code.invokeScalaObject13[ + String, + String, + Int, + Int, + Int, + Double, + Long, + Int, + Long, + Int, + Double, + Long, + Int, + Unit, + ]( + BLAS.getClass, + method = "dgemm", + TRANSA, + TRANSB, + M.toI, + N.toI, + K.toI, + 1.0, + leftDataAddress, + LDA.toI, + rightDataAddress, + LDB.toI, + 0.0, + answerFirstElementAddr, + LDC.toI, + ) + }) + }, + // Fill with zeroes + cb.append(Region.setMemory( answerFirstElementAddr, - LDC.toI - ) - }) - }, - { // Fill with zeroes - cb.append(Region.setMemory(answerFirstElementAddr, (M * N) * outputPType.elementType.byteSize, 0.toByte)) - } - ) - - answerFinisher(cb) - } else if (lSType.elementType.virtualType == TFloat64 && lSType.nDims == 2 && rSType.nDims == 1) { - val leftDataAddress = leftPVal.firstDataAddress - val rightDataAddress = rightPVal.firstDataAddress - - val numRows = lShape(lSType.nDims - 2) - val numCols = lShape(lSType.nDims - 1) - val M = cb.newLocal[Long]("dgemv_m", leftIsColumnMajor.mux(numRows, numCols)) - val N = cb.newLocal[Long]("dgemv_n", leftIsColumnMajor.mux(numCols, numRows)) - val outputSize = cb.newLocal[Long]("output_size", numRows) - - val alpha = 1.0 - val beta = 0.0 - - val LDA = M - val TRANS: Code[String] = leftIsColumnMajor.mux("N", "T") + (M * N) * outputPType.elementType.byteSize, + 0.toByte, + )), + ) - val (answerFirstElementAddr, answerFinisher) = outputPType.constructDataFunction( - IndexedSeq(outputSize), - outputPType.makeColumnMajorStrides(IndexedSeq(outputSize), cb), - cb, - region) - - cb.append(Code.invokeScalaObject11[String, Int, Int, Double, Long, Int, Long, Int, Double, Long, Int, Unit](BLAS.getClass, method="dgemv", - TRANS, - M.toI, - N.toI, - alpha, - leftDataAddress, - LDA.toI, - rightDataAddress, - 1, - beta, - answerFirstElementAddr, - 1 - )) + answerFinisher(cb) + } else if ( + lSType.elementType.virtualType == TFloat64 && lSType.nDims == 2 && rSType.nDims == 1 + ) { + val leftDataAddress = leftPVal.firstDataAddress + val rightDataAddress = rightPVal.firstDataAddress + + val numRows = lShape(lSType.nDims - 2) + val numCols = lShape(lSType.nDims - 1) + val M = cb.newLocal[Long]("dgemv_m", leftIsColumnMajor.mux(numRows, numCols)) + val N = cb.newLocal[Long]("dgemv_n", leftIsColumnMajor.mux(numCols, numRows)) + val outputSize = cb.newLocal[Long]("output_size", numRows) + + val alpha = 1.0 + val beta = 0.0 + + val LDA = M + val TRANS: Code[String] = leftIsColumnMajor.mux("N", "T") + + val (answerFirstElementAddr, answerFinisher) = outputPType.constructDataFunction( + IndexedSeq(outputSize), + outputPType.makeColumnMajorStrides(IndexedSeq(outputSize), cb), + cb, + region, + ) + cb.append(Code.invokeScalaObject11[ + String, + Int, + Int, + Double, + Long, + Int, + Long, + Int, + Double, + Long, + Int, + Unit, + ]( + BLAS.getClass, + method = "dgemv", + TRANS, + M.toI, + N.toI, + alpha, + leftDataAddress, + LDA.toI, + rightDataAddress, + 1, + beta, + answerFirstElementAddr, + 1, + )) + + answerFinisher(cb) + } else { + val numericElementType = tcoerce[PNumeric]( + PType.canonical(lSType.elementType.storageType().setRequired(true)) + ) + val eVti = typeToTypeInfo(numericElementType) + + val emitter = new NDArrayEmitter(unifiedShape, leftPVal.st.elementType) { + override def outputElement( + cb: EmitCodeBuilder, + idxVars: IndexedSeq[Value[Long]], + ): SValue = { + val element = cb.newFieldAny("matmul_element", eVti) + val k = cb.newField[Long]("ndarray_matmul_k") + + val (lIndices: IndexedSeq[Value[Long]], rIndices: 
IndexedSeq[Value[Long]]) = + (lSType.nDims, rSType.nDims, idxVars) match { + case (1, 1, Seq()) => (IndexedSeq(k), IndexedSeq(k)) + case (1, _, stack :+ m) => + val rStackVars = + NDArrayEmitter.zeroBroadcastedDims(stack, rightBroadcastMask) + (IndexedSeq(k), rStackVars :+ k :+ m) + case (_, 1, stack :+ n) => + val lStackVars = + NDArrayEmitter.zeroBroadcastedDims(stack, leftBroadcastMask) + (lStackVars :+ n :+ k, FastSeq(k)) + case (_, _, stack :+ n :+ m) => + val lStackVars = + NDArrayEmitter.zeroBroadcastedDims(stack, leftBroadcastMask) + val rStackVars = + NDArrayEmitter.zeroBroadcastedDims(stack, rightBroadcastMask) + (lStackVars :+ n :+ k, rStackVars :+ k :+ m) + } - answerFinisher(cb) - } else { - val numericElementType = tcoerce[PNumeric](PType.canonical(lSType.elementType.storageType().setRequired(true))) - val eVti = typeToTypeInfo(numericElementType) - - val emitter = new NDArrayEmitter(unifiedShape, leftPVal.st.elementType) { - override def outputElement(cb: EmitCodeBuilder, idxVars: IndexedSeq[Value[Long]]): SValue = { - val element = cb.newFieldAny("matmul_element", eVti) - val k = cb.newField[Long]("ndarray_matmul_k") - - val (lIndices: IndexedSeq[Value[Long]], rIndices: IndexedSeq[Value[Long]]) = (lSType.nDims, rSType.nDims, idxVars) match { - case (1, 1, Seq()) => (IndexedSeq(k), IndexedSeq(k)) - case (1, _, stack :+ m) => - val rStackVars = NDArrayEmitter.zeroBroadcastedDims(stack, rightBroadcastMask) - (IndexedSeq(k), rStackVars :+ k :+ m) - case (_, 1, stack :+ n) => - val lStackVars = NDArrayEmitter.zeroBroadcastedDims(stack, leftBroadcastMask) - (lStackVars :+ n :+ k, FastSeq(k)) - case (_, _, stack :+ n :+ m) => - val lStackVars = NDArrayEmitter.zeroBroadcastedDims(stack, leftBroadcastMask) - val rStackVars = NDArrayEmitter.zeroBroadcastedDims(stack, rightBroadcastMask) - (lStackVars :+ n :+ k, rStackVars :+ k :+ m) - } + def multiply(l: SValue, r: SValue): Code[_] = { + (l.st, r.st) match { + case (SInt32, SInt32) => + l.asInt.value * r.asInt.value + case (SInt64, SInt64) => + l.asLong.value * r.asLong.value + case (SFloat32, SFloat32) => + l.asFloat.value * r.asFloat.value + case (SFloat64, SFloat64) => + l.asDouble.value * r.asDouble.value + } + } - def multiply(l: SValue, r: SValue): Code[_] = { - (l.st, r.st) match { - case (SInt32, SInt32) => - l.asInt.value * r.asInt.value - case (SInt64, SInt64) => - l.asLong.value * r.asLong.value - case (SFloat32, SFloat32) => - l.asFloat.value * r.asFloat.value - case (SFloat64, SFloat64) => - l.asDouble.value * r.asDouble.value + val kLen = lShape(lSType.nDims - 1) + cb.assignAny(element, numericElementType.zero) + cb.for_( + cb.assign(k, 0L), + k < kLen, + cb.assign(k, k + 1L), { + val lElem = leftPVal.loadElement(lIndices, cb) + val rElem = rightPVal.loadElement(rIndices, cb) + cb.assignAny( + element, + numericElementType.add(multiply(lElem, rElem), element), + ) + }, + ) + + primitive(outputPType.elementType.virtualType, element) } } - - - val kLen = lShape(lSType.nDims - 1) - cb.assignAny(element, numericElementType.zero) - cb.for_(cb.assign(k, 0L), k < kLen, cb.assign(k, k + 1L), { - val lElem = leftPVal.loadElement(lIndices, cb) - val rElem = rightPVal.loadElement(rIndices, cb) - cb.assignAny(element, numericElementType.add(multiply(lElem, rElem), element)) - }) - - primitive(outputPType.elementType.virtualType, element) + emitter.emit(cb, outputPType, region) } - } - emitter.emit(cb, outputPType, region) } - } } case NDArrayInv(nd, errorID) => @@ -1744,7 +2179,6 @@ class Emit[C]( val shapeArray = pndVal.shapes val 
stridesArray = ndPT.makeColumnMajorStrides(shapeArray, cb) - assert(shapeArray.length == 2) val M = shapeArray(0) @@ -1762,8 +2196,12 @@ class Emit[C]( val INFOdgetri = mb.newLocal[Int]() def INFOerror(cb: EmitCodeBuilder, fun: String, info: LocalRef[Int]): Unit = - cb.if_(info cne 0, - cb._fatalWithError(errorID, const(s"LAPACK error $fun. Error code = ").concat(info.toS)) + cb.if_( + info cne 0, + cb._fatalWithError( + errorID, + const(s"LAPACK error $fun. Error code = ").concat(info.toS), + ), ) cb.if_(N cne M, cb._fatalWithError(errorID, "Can only invert square matrix")) @@ -1772,43 +2210,58 @@ class Emit[C]( cb.assign(IPIVaddr, IPIVptype.allocate(region, N.toI)) IPIVptype.stagedInitialize(cb, IPIVaddr, N.toI) - val (aAadrFirstElement, finish) = ndPT.constructDataFunction(shapeArray, stridesArray, cb, region) - cb.append(Region.copyFrom(dataFirstAddress, - aAadrFirstElement, An.toL * 8L)) - - cb.assign(INFOdgetrf, Code.invokeScalaObject5[Int, Int, Long, Int, Long, Int](LAPACK.getClass, "dgetrf", - M.toI, - N.toI, - aAadrFirstElement, - LDA.toI, - IPIVptype.firstElementOffset(IPIVaddr, N.toI) - )) + val (aAadrFirstElement, finish) = + ndPT.constructDataFunction(shapeArray, stridesArray, cb, region) + cb.append(Region.copyFrom(dataFirstAddress, aAadrFirstElement, An.toL * 8L)) + + cb.assign( + INFOdgetrf, + Code.invokeScalaObject5[Int, Int, Long, Int, Long, Int]( + LAPACK.getClass, + "dgetrf", + M.toI, + N.toI, + aAadrFirstElement, + LDA.toI, + IPIVptype.firstElementOffset(IPIVaddr, N.toI), + ), + ) INFOerror(cb, "dgetrf", INFOdgetrf) cb.assign(WORKaddr, Code.invokeStatic1[Memory, Long, Long]("malloc", An.toL * 8L)) - cb.assign(INFOdgetri, Code.invokeScalaObject6[Int, Long, Int, Long, Long, Int, Int](LAPACK.getClass, "dgetri", - N.toI, - aAadrFirstElement, - LDA.toI, - IPIVptype.firstElementOffset(IPIVaddr, N.toI), - WORKaddr, - N.toI - )) + cb.assign( + INFOdgetri, + Code.invokeScalaObject6[Int, Long, Int, Long, Long, Int, Int]( + LAPACK.getClass, + "dgetri", + N.toI, + aAadrFirstElement, + LDA.toI, + IPIVptype.firstElementOffset(IPIVaddr, N.toI), + WORKaddr, + N.toI, + ), + ) INFOerror(cb, "dgetri", INFOdgetri) finish(cb) } - case x@NDArraySVD(nd, full_matrices, computeUV, errorID) => + case NDArraySVD(nd, full_matrices, computeUV, errorID) => emitNDArrayColumnMajorStrides(nd).flatMap(cb) { case ndPVal: SNDArrayValue => - val infoDGESDDResult = cb.newLocal[Int]("infoDGESDD") def infoDGESDDErrorTest(cb: EmitCodeBuilder, extraErrorMsg: String): Unit = - cb.if_(infoDGESDDResult cne 0, - cb._fatalWithError(errorID, const(s"LAPACK error DGESDD. $extraErrorMsg Error code = ").concat(infoDGESDDResult.toS)) + cb.if_( + infoDGESDDResult cne 0, + cb._fatalWithError( + errorID, + const(s"LAPACK error DGESDD. 
$extraErrorMsg Error code = ").concat( + infoDGESDDResult.toS + ), + ), ) val LWORKAddress = mb.newLocal[Long]("svd_lwork_address") @@ -1836,43 +2289,86 @@ class Emit[C]( val vtPType = outputPType.fields(2).typ.asInstanceOf[PCanonicalNDArray] val uShapeSeq = FastSeq[Value[Long]](M, UCOL) - val (uData, uFinisher) = uPType.constructDataFunction(uShapeSeq, uPType.makeColumnMajorStrides(uShapeSeq, cb), cb, region) + val (uData, uFinisher) = uPType.constructDataFunction( + uShapeSeq, + uPType.makeColumnMajorStrides(uShapeSeq, cb), + cb, + region, + ) val vtShapeSeq = FastSeq[Value[Long]](LDVT, N) - val (vtData, vtFinisher) = vtPType.constructDataFunction(vtShapeSeq, vtPType.makeColumnMajorStrides(vtShapeSeq, cb), cb, region) + val (vtData, vtFinisher) = vtPType.constructDataFunction( + vtShapeSeq, + vtPType.makeColumnMajorStrides(vtShapeSeq, cb), + cb, + region, + ) (if (full_matrices) "A" else "S", sPType, uData, uFinisher, vtData, vtFinisher) - } - else { + } else { val outputPType = retPTypeUncast.asInstanceOf[PCanonicalNDArray] - def noOp(cb: EmitCodeBuilder): SNDArrayValue = { + def noOp(cb: EmitCodeBuilder): SNDArrayValue = throw new IllegalStateException("Can't happen") - } - ("N", outputPType.asInstanceOf[PCanonicalNDArray], const(0L), noOp(_), const(0L), noOp(_)) + ( + "N", + outputPType.asInstanceOf[PCanonicalNDArray], + const(0L), + noOp(_), + const(0L), + noOp(_), + ) } - val (sDataAddress, sFinisher) = sPType.constructDataFunction(IndexedSeq(K), sPType.makeColumnMajorStrides(IndexedSeq(K), cb), cb, region) - - cb.assign(infoDGESDDResult, Code.invokeScalaObject13[String, Int, Int, Long, Int, Long, Long, Int, Long, Int, Long, Int, Long, Int](LAPACK.getClass, "dgesdd", - jobz, - M.toI, - N.toI, - A, - LDA.toI, - sDataAddress, - uData, - LDU.toI, - vtData, - LDVT.toI, - LWORKAddress, - -1, - IWORK - )) + val (sDataAddress, sFinisher) = sPType.constructDataFunction( + IndexedSeq(K), + sPType.makeColumnMajorStrides(IndexedSeq(K), cb), + cb, + region, + ) + + cb.assign( + infoDGESDDResult, + Code.invokeScalaObject13[ + String, + Int, + Int, + Long, + Int, + Long, + Long, + Int, + Long, + Int, + Long, + Int, + Long, + Int, + ]( + LAPACK.getClass, + "dgesdd", + jobz, + M.toI, + N.toI, + A, + LDA.toI, + sDataAddress, + uData, + LDU.toI, + vtData, + LDVT.toI, + LWORKAddress, + -1, + IWORK, + ), + ) infoDGESDDErrorTest(cb, "Failed size query.") - cb.assign(IWORK, Code.invokeStatic1[Memory, Long, Long]("malloc", K.toL * 8L * 4L)) // 8K 4 byte integers. + cb.assign( + IWORK, + Code.invokeStatic1[Memory, Long, Long]("malloc", K.toL * 8L * 4L), + ) // 8K 4 byte integers. 
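// Editorial sketch (not part of this change): the dgesdd/dgeqrf calls around here follow the
// standard LAPACK workspace-query protocol -- call once with lwork = -1 so the routine writes
// its optimal workspace size into work(0), then allocate that much memory and call again to do
// the real computation. `routine` below is a hypothetical stand-in for a LAPACK entry point,
// not a Hail or netlib API; it only illustrates the two-call shape used above.
object WorkspaceQuerySketch {
  // Fake LAPACK-style routine: lwork == -1 means "size query only"; 0 is the success code.
  def routine(work: Array[Double], lwork: Int): Int =
    if (lwork == -1) { work(0) = 64.0; 0 }
    else { require(lwork >= 64); /* ... real factorization would happen here ... */ 0 }

  def main(args: Array[String]): Unit = {
    val query = new Array[Double](1)
    require(routine(query, -1) == 0, "size query failed")        // first call: size query
    val lwork = math.max(query(0).toInt, 1)                      // clamp to >= 1, as the dgeqrf path below does
    val work = new Array[Double](lwork)                          // allocate the reported workspace
    require(routine(work, lwork) == 0, "computation failed")     // second call: real work
  }
}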
cb.assign(A, Code.invokeStatic1[Memory, Long, Long]("malloc", M * N * 8L)) // Copy data into A because dgesdd destroys the input array: cb.append(Region.copyFrom(firstElementDataAddress, A, (M * N) * 8L)) @@ -1883,21 +2379,41 @@ class Emit[C]( cb.assign(WORK, Code.invokeStatic1[Memory, Long, Long]("malloc", LWORK.toL * 8L)) - cb.assign(infoDGESDDResult, Code.invokeScalaObject13[String, Int, Int, Long, Int, Long, Long, Int, Long, Int, Long, Int, Long, Int](LAPACK.getClass, "dgesdd", - jobz, - M.toI, - N.toI, - A, - LDA.toI, - sDataAddress, - uData, - LDU.toI, - vtData, - LDVT.toI, - WORK, - LWORK, - IWORK - )) + cb.assign( + infoDGESDDResult, + Code.invokeScalaObject13[ + String, + Int, + Int, + Long, + Int, + Long, + Long, + Int, + Long, + Int, + Long, + Int, + Long, + Int, + ]( + LAPACK.getClass, + "dgesdd", + jobz, + M.toI, + N.toI, + A, + LDA.toI, + sDataAddress, + uData, + LDU.toI, + vtData, + LDVT.toI, + WORK, + LWORK, + IWORK, + ), + ) cb.append(Code.invokeStatic1[Memory, Long, Unit]("free", IWORK.load())) cb.append(Code.invokeStatic1[Memory, Long, Unit]("free", A.load())) @@ -1913,7 +2429,16 @@ class Emit[C]( val vt = vtFinisher(cb) val outputPType = NDArraySVD.pTypes(true, false).asInstanceOf[PCanonicalTuple] - outputPType.constructFromFields(cb, region, FastSeq(EmitCode.present(cb.emb, u), EmitCode.present(cb.emb, s), EmitCode.present(cb.emb, vt)), deepCopy = false) + outputPType.constructFromFields( + cb, + region, + FastSeq( + EmitCode.present(cb.emb, u), + EmitCode.present(cb.emb, s), + EmitCode.present(cb.emb, vt), + ), + deepCopy = false, + ) } else { s } @@ -1921,7 +2446,7 @@ class Emit[C]( } - case NDArrayEigh(nd, eigvalsOnly, errorID) => + case NDArrayEigh(nd, eigvalsOnly, _) => emitNDArrayColumnMajorStrides(nd).map(cb) { case mat: SNDArrayValue => val n = mat.shapes(0) val jobz = if (eigvalsOnly) "N" else "V" @@ -1942,18 +2467,24 @@ class Emit[C]( } else { val resultType = NDArrayEigh.pTypes(false, false).asInstanceOf[PCanonicalTuple] val Z = matType.constructUninitialized(FastSeq(n, n), cb, region) - val iSuppZ = vecType.constructUninitialized(FastSeq(SizeValueDyn(cb.memoize(n * 2))), cb, region) + val iSuppZ = + vecType.constructUninitialized(FastSeq(SizeValueDyn(cb.memoize(n * 2))), cb, region) SNDArray.syevr(cb, "U", mat, W, Some((Z, iSuppZ)), work, iWork) - resultType.constructFromFields(cb, region, FastSeq(EmitCode.present(cb.emb, W), EmitCode.present(cb.emb, Z)), false) + resultType.constructFromFields( + cb, + region, + FastSeq(EmitCode.present(cb.emb, W), EmitCode.present(cb.emb, Z)), + false, + ) } } - case x@NDArrayQR(nd, mode, errorID) => - // See here to understand different modes: https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.qr.html + case NDArrayQR(nd, mode, errorID) => + /* See here to understand different modes: + * https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.qr.html */ emitNDArrayColumnMajorStrides(nd).map(cb) { case pndValue: SNDArrayValue => - val resultPType = NDArrayQR.pType(mode, false) // This does a lot of byte level copying currently, so only trust @@ -1970,10 +2501,12 @@ class Emit[C]( def get: Code[Long] = (M < N).mux(M, N) } val LDA = new Value[Long] { - override def get: Code[Long] = (M > 1L).mux(M, 1L) // Possible stride tricks could change this in the future. + override def get: Code[Long] = + (M > 1L).mux(M, 1L) // Possible stride tricks could change this in the future. 
} - def LWORK = (Region.loadDouble(LWORKAddress).toI > 0).mux(Region.loadDouble(LWORKAddress).toI, 1) + def LWORK = + (Region.loadDouble(LWORKAddress).toI > 0).mux(Region.loadDouble(LWORKAddress).toI, 1) val ndPT = pType val dataFirstElementAddress = pndValue.firstDataAddress @@ -1981,10 +2514,12 @@ class Emit[C]( val hPType = ndPT val hShapeArray = FastSeq[Value[Long]](N, M) val hStridesArray = hPType.makeRowMajorStrides(hShapeArray, cb) - val (hFirstElement, hFinisher) = hPType.constructDataFunction(hShapeArray, hStridesArray, cb, region) + val (hFirstElement, hFinisher) = + hPType.constructDataFunction(hShapeArray, hStridesArray, cb, region) val tauNDPType = PCanonicalNDArray(PFloat64Required, 1, true) - val (tauFirstElementAddress, tauFinisher) = tauNDPType.constructDataFunction(IndexedSeq(K), IndexedSeq(const(8L)), cb, region) + val (tauFirstElementAddress, tauFinisher) = + tauNDPType.constructDataFunction(IndexedSeq(K), IndexedSeq(const(8L)), cb, region) val workAddress = cb.newLocal[Long]("ndarray_qr_workAddress") val aNumElements = cb.newLocal[Long]("ndarray_qr_aNumElements") @@ -1992,8 +2527,14 @@ class Emit[C]( val infoDGEQRFResult = cb.newLocal[Int]("ndaray_qr_infoDGEQRFResult") def infoDGEQRFErrorTest(cb: EmitCodeBuilder, extraErrorMsg: String): Unit = - cb.if_(infoDGEQRFResult cne 0, - cb._fatalWithError(errorID, const(s"LAPACK error DGEQRF. $extraErrorMsg Error code = ").concat(infoDGEQRFResult.toS)) + cb.if_( + infoDGEQRFResult cne 0, + cb._fatalWithError( + errorID, + const(s"LAPACK error DGEQRF. $extraErrorMsg Error code = ").concat( + infoDGEQRFResult.toS + ), + ), ) // Computing H and Tau @@ -2002,28 +2543,38 @@ class Emit[C]( cb.assign(LWORKAddress, region.allocate(8L, 8L)) - cb.assign(infoDGEQRFResult, Code.invokeScalaObject7[Int, Int, Long, Int, Long, Long, Int, Int](LAPACK.getClass, "dgeqrf", - M.toI, - N.toI, - hFirstElement, - LDA.toI, - tauFirstElementAddress, - LWORKAddress, - -1 - )) + cb.assign( + infoDGEQRFResult, + Code.invokeScalaObject7[Int, Int, Long, Int, Long, Long, Int, Int]( + LAPACK.getClass, + "dgeqrf", + M.toI, + N.toI, + hFirstElement, + LDA.toI, + tauFirstElementAddress, + LWORKAddress, + -1, + ), + ) infoDGEQRFErrorTest(cb, "Failed size query.") cb.assign(workAddress, Code.invokeStatic1[Memory, Long, Long]("malloc", LWORK.toL * 8L)) - cb.assign(infoDGEQRFResult, Code.invokeScalaObject7[Int, Int, Long, Int, Long, Long, Int, Int](LAPACK.getClass, "dgeqrf", - M.toI, - N.toI, - hFirstElement, - LDA.toI, - tauFirstElementAddress, - workAddress, - LWORK - )) + cb.assign( + infoDGEQRFResult, + Code.invokeScalaObject7[Int, Int, Long, Int, Long, Long, Int, Int]( + LAPACK.getClass, + "dgeqrf", + M.toI, + N.toI, + hFirstElement, + LDA.toI, + tauFirstElementAddress, + workAddress, + LWORK, + ), + ) cb.append(Code.invokeStatic1[Memory, Long, Unit]("free", workAddress.load())) infoDGEQRFErrorTest(cb, "Failed to compute H and Tau.") @@ -2033,10 +2584,15 @@ class Emit[C]( val resultType = resultPType.asInstanceOf[PCanonicalBaseStruct] val tau = tauFinisher(cb) - resultType.constructFromFields(cb, region, FastSeq( - EmitCode.present(cb.emb, h), - EmitCode.present(cb.emb, tau) - ), deepCopy = false) + resultType.constructFromFields( + cb, + region, + FastSeq( + EmitCode.present(cb.emb, h), + EmitCode.present(cb.emb, tau), + ), + deepCopy = false, + ) } else { val (rPType, rRows, rCols) = if (mode == "r") { @@ -2053,35 +2609,45 @@ class Emit[C]( val rStridesArray = rPType.makeColumnMajorStrides(rShapeArray, cb) - val (rDataAddress, rFinisher) = 
rPType.constructDataFunction(rShapeArray, rStridesArray, cb, region) + val (rDataAddress, rFinisher) = + rPType.constructDataFunction(rShapeArray, rStridesArray, cb, region) - // This block assumes that `rDataAddress` and `aAddressDGEQRF` point to column major arrays. + /* This block assumes that `rDataAddress` and `aAddressDGEQRF` point to column major + * arrays. */ // TODO: Abstract this into ndarray ptype/SCode interface methods. val currRow = cb.newLocal[Long]("ndarray_qr_currRow") val currCol = cb.newLocal[Long]("ndarray_qr_currCol") val curWriteAddress = cb.newLocal[Long]("ndarray_qr_curr_write_addr", rDataAddress) - // I think this just copies out the upper triangle into new ndarray in column major order - cb.for_(cb.assign(currCol, 0L), currCol < rCols, cb.assign(currCol, currCol + 1L), { - cb.for_(cb.assign(currRow, 0L), currRow < rRows, cb.assign(currRow, currRow + 1L), { - cb.append(Region.storeDouble( - curWriteAddress, - (currCol >= currRow).mux( - h.loadElement(IndexedSeq(currCol, currRow), cb).asDouble.value, - 0.0 - ) - )) - cb.assign(curWriteAddress, curWriteAddress + rPType.elementType.byteSize) - }) - }) + /* I think this just copies out the upper triangle into new ndarray in column major + * order */ + cb.for_( + cb.assign(currCol, 0L), + currCol < rCols, + cb.assign(currCol, currCol + 1L), { + cb.for_( + cb.assign(currRow, 0L), + currRow < rRows, + cb.assign(currRow, currRow + 1L), { + cb.append(Region.storeDouble( + curWriteAddress, + (currCol >= currRow).mux( + h.loadElement(IndexedSeq(currCol, currRow), cb).asDouble.value, + 0.0, + ), + )) + cb.assign(curWriteAddress, curWriteAddress + rPType.elementType.byteSize) + }, + ) + }, + ) val computeR = rFinisher(cb) if (mode == "r") { computeR - } - else { + } else { val crPType = resultPType.asInstanceOf[PCanonicalTuple] val qPType = crPType.types(0).asInstanceOf[PCanonicalNDArray] @@ -2091,8 +2657,14 @@ class Emit[C]( val infoDORGQRResult = cb.newLocal[Int]("ndarray_qr_DORGQR_info") def infoDORQRErrorTest(cb: EmitCodeBuilder, extraErrorMsg: String): Unit = - cb.if_(infoDORGQRResult cne 0, - cb._fatalWithError(errorID, const(s"LAPACK error DORGQR. $extraErrorMsg Error code = ").concat(infoDORGQRResult.toS)) + cb.if_( + infoDORGQRResult cne 0, + cb._fatalWithError( + errorID, + const(s"LAPACK error DORGQR. $extraErrorMsg Error code = ").concat( + infoDORGQRResult.toS + ), + ), ) val qCondition = cb.newLocal[Boolean]("ndarray_qr_qCondition") @@ -2101,54 +2673,82 @@ class Emit[C]( val qNumElements = cb.newLocal[Long]("ndarray_qr_qNumElements") - val rNDArray = computeR cb.assign(qCondition, const(mode == "complete") && (M > N)) cb.assign(numColsToUse, qCondition.mux(M, K)) cb.assign(qNumElements, M * numColsToUse) - cb.if_(qCondition, { - cb.assign(aAddressDORGQRFirstElement, region.allocate(8L, qNumElements * ndPT.elementType.byteSize)) - cb.append(Region.copyFrom(hFirstElement, - aAddressDORGQRFirstElement, aNumElements * 8L)) - }, { + cb.if_( + qCondition, { + cb.assign( + aAddressDORGQRFirstElement, + region.allocate(8L, qNumElements * ndPT.elementType.byteSize), + ) + cb.append(Region.copyFrom( + hFirstElement, + aAddressDORGQRFirstElement, + aNumElements * 8L, + )) + }, // We are intentionally clobbering h, since we aren't going to return it to anyone. 
- cb.assign(aAddressDORGQRFirstElement, hFirstElement) - }) + cb.assign(aAddressDORGQRFirstElement, hFirstElement), + ) - cb.assign(infoDORGQRResult, Code.invokeScalaObject8[Int, Int, Int, Long, Int, Long, Long, Int, Int](LAPACK.getClass, "dorgqr", - M.toI, - numColsToUse.toI, - K.toI, - aAddressDORGQRFirstElement, - LDA.toI, - tauFirstElementAddress, - LWORKAddress, - -1 - )) + cb.assign( + infoDORGQRResult, + Code.invokeScalaObject8[Int, Int, Int, Long, Int, Long, Long, Int, Int]( + LAPACK.getClass, + "dorgqr", + M.toI, + numColsToUse.toI, + K.toI, + aAddressDORGQRFirstElement, + LDA.toI, + tauFirstElementAddress, + LWORKAddress, + -1, + ), + ) infoDORQRErrorTest(cb, "Failed size query.") - cb.append(workAddress := Code.invokeStatic1[Memory, Long, Long]("malloc", LWORK.toL * 8L)) - cb.assign(infoDORGQRResult, Code.invokeScalaObject8[Int, Int, Int, Long, Int, Long, Long, Int, Int](LAPACK.getClass, "dorgqr", - M.toI, - numColsToUse.toI, - K.toI, - aAddressDORGQRFirstElement, - LDA.toI, - tauFirstElementAddress, - workAddress, - LWORK + cb.append(workAddress := Code.invokeStatic1[Memory, Long, Long]( + "malloc", + LWORK.toL * 8L, )) + cb.assign( + infoDORGQRResult, + Code.invokeScalaObject8[Int, Int, Int, Long, Int, Long, Long, Int, Int]( + LAPACK.getClass, + "dorgqr", + M.toI, + numColsToUse.toI, + K.toI, + aAddressDORGQRFirstElement, + LDA.toI, + tauFirstElementAddress, + workAddress, + LWORK, + ), + ) cb.append(Code.invokeStatic1[Memory, Long, Unit]("free", workAddress.load())) infoDORQRErrorTest(cb, "Failed to compute Q.") - val (qFirstElementAddress, qFinisher) = qPType.constructDataFunction(qShapeArray, qStridesArray, cb, region) - cb.append(Region.copyFrom(aAddressDORGQRFirstElement, - qFirstElementAddress, (M * numColsToUse) * 8L)) - - crPType.constructFromFields(cb, region, FastSeq( - EmitCode.present(cb.emb, qFinisher(cb)), - EmitCode.present(cb.emb, rNDArray) - ), deepCopy = false) + val (qFirstElementAddress, qFinisher) = + qPType.constructDataFunction(qShapeArray, qStridesArray, cb, region) + cb.append(Region.copyFrom( + aAddressDORGQRFirstElement, + qFirstElementAddress, + (M * numColsToUse) * 8L, + )) + + crPType.constructFromFields( + cb, + region, + FastSeq( + EmitCode.present(cb.emb, qFinisher(cb)), + EmitCode.present(cb.emb, rNDArray), + ), + deepCopy = false, + ) } } result @@ -2174,19 +2774,26 @@ class Emit[C]( } case ResultOp(idx, sig) => - val AggContainer(aggs, sc, _) = container.get + val AggContainer(_, sc, _) = container.get val rvAgg = agg.Extract.getAgg(sig) rvAgg.result(cb, sc.states(idx), region) - case x@ApplySeeded(fn, args, rngState, staticUID, rt) => + case x @ ApplySeeded(_, args, rngState, staticUID, rt) => val codeArgs = args.map(a => EmitCode.fromI(cb.emb)(emitInNewBuilder(_, a))) val codeArgsMem = codeArgs.map(_.memoize(cb, "ApplySeeded_arg")) - val state = emitI(rngState).get(cb) + val state = emitI(rngState).getOrAssert(cb) val impl = x.implementation assert(impl.unify(Array.empty[Type], x.argTypes, rt)) val newState = EmitCode.present(mb, state.asRNGState.splitStatic(cb, staticUID)) - impl.applyI(EmitRegion(cb.emb, region), cb, impl.computeReturnEmitType(x.typ, newState.emitType +: codeArgs.map(_.emitType)).st, Seq[Type](), const(0), newState +: codeArgsMem.map(_.load): _*) + impl.applyI( + EmitRegion(cb.emb, region), + cb, + impl.computeReturnEmitType(x.typ, newState.emitType +: codeArgs.map(_.emitType)).st, + Seq[Type](), + const(0), + newState +: codeArgsMem.map(_.load): _* + ) case AggStateValue(i, _) => val AggContainer(_, sc, _) = 
container.get @@ -2194,14 +2801,19 @@ class Emit[C]( case ToArray(a) => EmitStream.produce(this, a, cb, cb.emb, region, env, container) - .map(cb) { case stream: SStreamValue => StreamUtils.toArray(cb, stream.getProducer(mb), region) } + .map(cb) { case stream: SStreamValue => + StreamUtils.toArray(cb, stream.getProducer(mb), region) + } - case x@StreamFold(a, zero, accumName, valueName, body) => + case x @ StreamFold(a, zero, accumName, valueName, body) => EmitStream.produce(this, a, cb, cb.emb, region, env, container) .flatMap(cb) { case stream: SStreamValue => val producer = stream.getProducer(mb) - val stateEmitType = VirtualTypeWithReq(zero.typ, ctx.req.lookupState(x).head.asInstanceOf[TypeWithRequiredness]).canonicalEmitType + val stateEmitType = VirtualTypeWithReq( + zero.typ, + ctx.req.lookupState(x).head.asInstanceOf[TypeWithRequiredness], + ).canonicalEmitType val xAcc = mb.newEmitField(accumName, stateEmitType) val xElt = mb.newEmitField(valueName, producer.element.emitType) @@ -2209,44 +2821,73 @@ class Emit[C]( var tmpRegion: Settable[Region] = null if (producer.requiresMemoryManagementPerElement) { - cb.assign(producer.elementRegion, Region.stagedCreate(Region.REGULAR, region.getPool())) + cb.assign( + producer.elementRegion, + Region.stagedCreate(Region.REGULAR, region.getPool()), + ) tmpRegion = mb.genFieldThisRef[Region]("streamfold_tmpregion") cb.assign(tmpRegion, Region.stagedCreate(Region.REGULAR, region.getPool())) - cb.assign(xAcc, emitI(zero, tmpRegion) - .map(cb)(pc => pc.castTo(cb, tmpRegion, stateEmitType.st))) + cb.assign( + xAcc, + emitI(zero, tmpRegion) + .map(cb)(pc => pc.castTo(cb, tmpRegion, stateEmitType.st)), + ) } else { cb.assign(producer.elementRegion, region) - cb.assign(xAcc, emitI(zero, producer.elementRegion) - .map(cb)(pc => pc.castTo(cb, producer.elementRegion, stateEmitType.st))) + cb.assign( + xAcc, + emitI(zero, producer.elementRegion) + .map(cb)(pc => pc.castTo(cb, producer.elementRegion, stateEmitType.st)), + ) } producer.unmanagedConsume(cb, region) { cb => cb.assign(xElt, producer.element) if (producer.requiresMemoryManagementPerElement) { - cb.assign(xAcc, emitI(body, producer.elementRegion, env.bind(accumName -> xAcc, valueName -> xElt)) - .map(cb)(pc => pc.castTo(cb, tmpRegion, stateEmitType.st, deepCopy = true))) + cb.assign( + xAcc, + emitI( + body, + producer.elementRegion, + env.bind(accumName -> xAcc, valueName -> xElt), + ) + .map(cb)(pc => pc.castTo(cb, tmpRegion, stateEmitType.st, deepCopy = true)), + ) cb += producer.elementRegion.clearRegion() - val swapRegion = cb.newLocal[Region]("streamfold_swap_region", producer.elementRegion) + val swapRegion = + cb.newLocal[Region]("streamfold_swap_region", producer.elementRegion) cb.assign(producer.elementRegion, tmpRegion.load()) cb.assign(tmpRegion, swapRegion.load()) } else { - cb.assign(xAcc, emitI(body, producer.elementRegion, env.bind(accumName -> xAcc, valueName -> xElt)) - .map(cb)(pc => pc.castTo(cb, producer.elementRegion, stateEmitType.st, deepCopy = false))) + cb.assign( + xAcc, + emitI( + body, + producer.elementRegion, + env.bind(accumName -> xAcc, valueName -> xElt), + ) + .map(cb)(pc => + pc.castTo(cb, producer.elementRegion, stateEmitType.st, deepCopy = false) + ), + ) } } if (producer.requiresMemoryManagementPerElement) { - cb.assign(xAcc, xAcc.toI(cb).map(cb)(pc => pc.castTo(cb, region, pc.st, deepCopy = true))) + cb.assign( + xAcc, + xAcc.toI(cb).map(cb)(pc => pc.castTo(cb, region, pc.st, deepCopy = true)), + ) cb += producer.elementRegion.invalidate() cb += 
tmpRegion.invalidate() } xAcc.toI(cb) } - case x@StreamFold2(a, acc, valueName, seq, res) => + case x @ StreamFold2(a, acc, valueName, seq, res) => emitStream(a, cb, region) .flatMap(cb) { case stream: SStreamValue => val producer = stream.getProducer(mb) @@ -2254,8 +2895,9 @@ class Emit[C]( var tmpRegion: Settable[Region] = null val accTypes = ctx.req.lookupState(x).zip(acc.map(_._2.typ)) - .map { case (btwr, t) => VirtualTypeWithReq(t, btwr.asInstanceOf[TypeWithRequiredness]) - .canonicalEmitType + .map { case (btwr, t) => + VirtualTypeWithReq(t, btwr.asInstanceOf[TypeWithRequiredness]) + .canonicalEmitType } val xElt = mb.newEmitField(valueName, producer.element.emitType) @@ -2266,7 +2908,10 @@ class Emit[C]( val seqEnv = resEnv.bind(valueName, xElt) if (producer.requiresMemoryManagementPerElement) { - cb.assign(producer.elementRegion, Region.stagedCreate(Region.REGULAR, region.getPool())) + cb.assign( + producer.elementRegion, + Region.stagedCreate(Region.REGULAR, region.getPool()), + ) tmpRegion = mb.genFieldThisRef[Region]("streamfold_tmpregion") cb.assign(tmpRegion, Region.stagedCreate(Region.REGULAR, region.getPool())) @@ -2285,26 +2930,36 @@ class Emit[C]( cb.assign(xElt, producer.element) if (producer.requiresMemoryManagementPerElement) { (accVars, seq).zipped.foreach { (accVar, ir) => - cb.assign(accVar, + cb.assign( + accVar, emitI(ir, producer.elementRegion, env = seqEnv) - .map(cb)(pc => pc.castTo(cb, tmpRegion, accVar.st, deepCopy = true))) + .map(cb)(pc => pc.castTo(cb, tmpRegion, accVar.st, deepCopy = true)), + ) } cb += producer.elementRegion.clearRegion() - val swapRegion = cb.newLocal[Region]("streamfold2_swap_region", producer.elementRegion) + val swapRegion = + cb.newLocal[Region]("streamfold2_swap_region", producer.elementRegion) cb.assign(producer.elementRegion, tmpRegion.load()) cb.assign(tmpRegion, swapRegion.load()) } else { (accVars, seq).zipped.foreach { (accVar, ir) => - cb.assign(accVar, + cb.assign( + accVar, emitI(ir, producer.elementRegion, env = seqEnv) - .map(cb)(pc => pc.castTo(cb, producer.elementRegion, accVar.st, deepCopy = false))) + .map(cb)(pc => + pc.castTo(cb, producer.elementRegion, accVar.st, deepCopy = false) + ), + ) } } } if (producer.requiresMemoryManagementPerElement) { accVars.foreach { xAcc => - cb.assign(xAcc, xAcc.toI(cb).map(cb)(pc => pc.castTo(cb, region, pc.st, deepCopy = true))) + cb.assign( + xAcc, + xAcc.toI(cb).map(cb)(pc => pc.castTo(cb, region, pc.st, deepCopy = true)), + ) } cb += producer.elementRegion.invalidate() cb += tmpRegion.invalidate() @@ -2312,37 +2967,68 @@ class Emit[C]( emitI(res, env = resEnv) } - case t@Trap(child) => + case t @ Trap(child) => val (ev, mb) = emitSplitMethod("trap", cb, child, region, env, container, loopEnv) - val maybeException = cb.newLocal[(String, java.lang.Integer)]("trap_msg", cb.emb.ecb.runMethodWithHailExceptionHandler(mb.mb.methodName)) + val maybeException = cb.newLocal[(String, java.lang.Integer)]( + "trap_msg", + cb.emb.ecb.runMethodWithHailExceptionHandler(mb.mb.methodName), + ) val sst = SStringPointer(PCanonicalString(false)) val tt = t.typ.asInstanceOf[TTuple] val errTupleType = tt.types(0).asInstanceOf[TTuple] - val errTuple = SStackStruct(errTupleType, FastSeq(EmitType(sst, true), EmitType(SInt32, true))) + val errTuple = + SStackStruct(errTupleType, FastSeq(EmitType(sst, true), EmitType(SInt32, true))) val tv = cb.emb.newEmitField("trap_errTuple", EmitType(errTuple, false)) val maybeMissingEV = cb.emb.newEmitField("trap_value", ev.emitType.copy(required = false)) - 
cb.if_(maybeException.isNull, { - cb.assign(tv, EmitCode.missing(cb.emb, errTuple)) - cb.assign(maybeMissingEV, ev) - }, { - val str = EmitCode.fromI(cb.emb)(cb => IEmitCode.present(cb, sst.constructFromString(cb, region, maybeException.invoke[String]("_1")))) - val errorId = EmitCode.fromI(cb.emb)(cb => - IEmitCode.present(cb, primitive(cb.memoize(maybeException.invoke[java.lang.Integer]("_2").invoke[Int]("intValue"))))) - cb.assign(tv, IEmitCode.present(cb, SStackStruct.constructFromArgs(cb, region, errTupleType, str, errorId))) - cb.assign(maybeMissingEV, EmitCode.missing(cb.emb, ev.st)) - }) - IEmitCode.present(cb, { - SStackStruct.constructFromArgs(cb, region, t.typ.asInstanceOf[TBaseStruct], tv, maybeMissingEV) - }) + cb.if_( + maybeException.isNull, { + cb.assign(tv, EmitCode.missing(cb.emb, errTuple)) + cb.assign(maybeMissingEV, ev) + }, { + val str = EmitCode.fromI(cb.emb)(cb => + IEmitCode.present( + cb, + sst.constructFromString(cb, region, maybeException.invoke[String]("_1")), + ) + ) + val errorId = EmitCode.fromI(cb.emb)(cb => + IEmitCode.present( + cb, + primitive( + cb.memoize(maybeException.invoke[java.lang.Integer]("_2").invoke[Int]("intValue")) + ), + ) + ) + cb.assign( + tv, + IEmitCode.present( + cb, + SStackStruct.constructFromArgs(cb, region, errTupleType, str, errorId), + ), + ) + cb.assign(maybeMissingEV, EmitCode.missing(cb.emb, ev.st)) + }, + ) + IEmitCode.present( + cb, + SStackStruct.constructFromArgs( + cb, + region, + t.typ.asInstanceOf[TBaseStruct], + tv, + maybeMissingEV, + ), + ) case Die(m, typ, errorId) => val cm = emitI(m) val msg = cb.newLocal[String]("die_msg") - cm.consume(cb, + cm.consume( + cb, cb.assign(msg, ""), - { sc => cb.assign(msg, sc.asString.loadString(cb)) } + sc => cb.assign(msg, sc.asString.loadString(cb)), ) cb._throw[HailException](Code.newInstance[HailException, String, Int](msg, errorId)) IEmitCode(CodeLabel(), CodeLabel(), SUnreachable.fromVirtualType(typ).defaultValue, true) @@ -2354,7 +3040,7 @@ class Emit[C]( emitI(result) case CastToArray(a) => - emitI(a).map(cb) { ind => ind.asIndexable.castToArray(cb) } + emitI(a).map(cb)(ind => ind.asIndexable.castToArray(cb)) case ReadValue(path, reader, requestedType) => emitI(path).map(cb) { pv => @@ -2367,23 +3053,29 @@ class Emit[C]( case WriteValue(value, path, writer, stagingFile) => emitI(path).flatMap(cb) { case pv: SStringValue => emitI(value).map(cb) { v => - val s = stagingFile.map(emitI(_).get(cb).asString) + val s = stagingFile.map(emitI(_).getOrAssert(cb).asString) val os = cb.memoize(mb.createUnbuffered(s.getOrElse(pv).loadString(cb))) writer.writeValue(cb, v, os) cb += os.invoke[Unit]("close") s.foreach { stage => - cb += mb.getFS.invoke[String, String, Boolean, Unit]("copy", stage.loadString(cb), pv.loadString(cb), const(true)) + cb += mb.getFS.invoke[String, String, Boolean, Unit]( + "copy", + stage.loadString(cb), + pv.loadString(cb), + const(true), + ) } pv } } - case x@TailLoop(name, args, _, body) => + case x @ TailLoop(name, args, _, body) => val loopStartLabel = CodeLabel() val accTypes = ctx.req.lookupState(x).zip(args.map(_._2.typ)) - .map { case (btwr, t) => VirtualTypeWithReq(t, btwr.asInstanceOf[TypeWithRequiredness]) - .canonicalEmitType + .map { case (btwr, t) => + VirtualTypeWithReq(t, btwr.asInstanceOf[TypeWithRequiredness]) + .canonicalEmitType } val inits = args.zip(accTypes) @@ -2391,8 +3083,15 @@ class Emit[C]( val stagedPool = cb.newLocal[RegionPool]("tail_loop_pool_ref") cb.assign(stagedPool, region.getPool()) - val resultEmitType = 
ctx.req.lookup(body).asInstanceOf[TypeWithRequiredness].canonicalEmitType(body.typ) - val loopRef = LoopRef(cb, loopStartLabel, inits.map { case ((name, _), pt) => (name, pt) }, stagedPool, resultEmitType) + val resultEmitType = + ctx.req.lookup(body).asInstanceOf[TypeWithRequiredness].canonicalEmitType(body.typ) + val loopRef = LoopRef( + cb, + loopStartLabel, + inits.map { case ((name, _), pt) => (name, pt) }, + stagedPool, + resultEmitType, + ) val argEnv = env .bind((args.map(_._1), loopRef.loopArgs).zipped.toArray: _*) @@ -2406,13 +3105,21 @@ class Emit[C]( cb.define(loopStartLabel) - val result = emitI(body, region=loopRef.r1, env = argEnv, loopEnv = Some(newLoopEnv.bind(name, loopRef))).map(cb) { pc => + val result = emitI( + body, + region = loopRef.r1, + env = argEnv, + loopEnv = Some(newLoopEnv.bind(name, loopRef)), + ).map(cb) { pc => val answerInRightRegion = pc.copyToRegion(cb, region, pc.st) cb.append(loopRef.r1.clearRegion()) cb.append(loopRef.r2.clearRegion()) answerInRightRegion } - assert(result.emitType == resultEmitType, s"loop type mismatch: emitted=${ result.emitType }, expected=${ resultEmitType }") + assert( + result.emitType == resultEmitType, + s"loop type mismatch: emitted=${result.emitType}, expected=$resultEmitType", + ) result case Recur(name, args, _) => @@ -2420,7 +3127,14 @@ class Emit[C]( // Need to emit into region 1, copy to region 2, then clear region 1, then swap them. (loopRef.tmpLoopArgs, loopRef.loopTypes, args).zipped.map { case (tmpLoopArg, et, arg) => - tmpLoopArg.store(cb, emitI(arg, loopEnv = None, region = loopRef.r1).map(cb)(_.copyToRegion(cb, loopRef.r2, et.st))) + tmpLoopArg.store( + cb, + emitI(arg, loopEnv = None, region = loopRef.r1).map(cb)(_.copyToRegion( + cb, + loopRef.r2, + et.st, + )), + ) } cb.append(loopRef.r1.clearRegion()) @@ -2434,7 +3148,8 @@ class Emit[C]( cb.assign(loopRef.loopArgs, loopRef.tmpLoopArgs.load()) cb.goto(loopRef.L) - // Dead code. The dead label is necessary because you can't append anything else to a code builder + /* Dead code. The dead label is necessary because you can't append anything else to a code + * builder */ // after a goto. 
val deadLabel = CodeLabel() cb.define(deadLabel) @@ -2442,13 +3157,18 @@ class Emit[C]( val rt = loopRef.resultType IEmitCode(CodeLabel(), CodeLabel(), rt.st.defaultValue, rt.required) - case x@CollectDistributedArray(contexts, globals, cname, gname, body, dynamicID, staticID, tsd) => + case CollectDistributedArray(contexts, globals, cname, gname, body, dynamicID, staticID, + tsd) => val parentCB = mb.ecb emitStream(contexts, cb, region).map(cb) { case ctxStream: SStreamValue => - - def wrapInTuple(cb: EmitCodeBuilder, region: Value[Region], et: EmitCode): SBaseStructPointerValue = { - PCanonicalTuple(true, et.emitType.storageType).constructFromFields(cb, region, FastSeq(et), deepCopy = false) - } + def wrapInTuple(cb: EmitCodeBuilder, region: Value[Region], et: EmitCode) + : SBaseStructPointerValue = + PCanonicalTuple(true, et.emitType.storageType).constructFromFields( + cb, + region, + FastSeq(et), + deepCopy = false, + ) val bufferSpec: BufferSpec = BufferSpec.blockedUncompressed @@ -2456,49 +3176,76 @@ class Emit[C]( val ctxType = ctxStream.st.elementEmitType val contextPTuple: PTuple = PCanonicalTuple(required = true, ctxType.storageType) - val globalPTuple: PTuple = PCanonicalTuple(required = true, emitGlobals.emitType.storageType) + val globalPTuple: PTuple = + PCanonicalTuple(required = true, emitGlobals.emitType.storageType) val contextSpec: TypedCodecSpec = TypedCodecSpec(contextPTuple, bufferSpec) val globalSpec: TypedCodecSpec = TypedCodecSpec(globalPTuple, bufferSpec) // emit body in new FB - val bodyFB = EmitFunctionBuilder[Region, Array[Byte], Array[Byte], Array[Byte]](ctx.executeContext, s"collect_distributed_array_$staticID") + val bodyFB = EmitFunctionBuilder[Region, Array[Byte], Array[Byte], Array[Byte]]( + ctx.executeContext, + s"collect_distributed_array_$staticID", + ) var bodySpec: TypedCodecSpec = null bodyFB.emitWithBuilder { cb => val region = bodyFB.getCodeParam[Region](1) - val ctxIB = cb.newLocal[InputBuffer]("cda_ctx_ib", contextSpec.buildCodeInputBuffer( - Code.newInstance[ByteArrayInputStream, Array[Byte]](bodyFB.getCodeParam[Array[Byte]](2)))) - val gIB = cb.newLocal[InputBuffer]("cda_g_ib", globalSpec.buildCodeInputBuffer( - Code.newInstance[ByteArrayInputStream, Array[Byte]](bodyFB.getCodeParam[Array[Byte]](3)))) - - val decodedContext = contextSpec.encodedType.buildDecoder(contextSpec.encodedVirtualType, bodyFB.ecb) - .apply(cb, region, ctxIB) - .asBaseStruct - .loadField(cb, 0) - .memoizeField(cb, "decoded_context") - - val decodedGlobal = globalSpec.encodedType.buildDecoder(globalSpec.encodedVirtualType, bodyFB.ecb) - .apply(cb, region, gIB) - .asBaseStruct - .loadField(cb, 0) - .memoizeField(cb, "decoded_global") + val ctxIB = cb.newLocal[InputBuffer]( + "cda_ctx_ib", + contextSpec.buildCodeInputBuffer( + Code.newInstance[ByteArrayInputStream, Array[Byte]]( + bodyFB.getCodeParam[Array[Byte]](2) + ) + ), + ) + val gIB = cb.newLocal[InputBuffer]( + "cda_g_ib", + globalSpec.buildCodeInputBuffer( + Code.newInstance[ByteArrayInputStream, Array[Byte]]( + bodyFB.getCodeParam[Array[Byte]](3) + ) + ), + ) - val env = EmitEnv(Env[EmitValue]( - (cname, decodedContext), - (gname, decodedGlobal)), FastSeq()) + val decodedContext = + contextSpec.encodedType.buildDecoder(contextSpec.encodedVirtualType, bodyFB.ecb) + .apply(cb, region, ctxIB) + .asBaseStruct + .loadField(cb, 0) + .memoizeField(cb, "decoded_context") + + val decodedGlobal = + globalSpec.encodedType.buildDecoder(globalSpec.encodedVirtualType, bodyFB.ecb) + .apply(cb, region, gIB) + .asBaseStruct + 
.loadField(cb, 0) + .memoizeField(cb, "decoded_global") + + val env = EmitEnv( + Env[EmitValue]( + (cname, decodedContext), + (gname, decodedGlobal), + ), + FastSeq(), + ) if (ctx.executeContext.getFlag("print_ir_on_worker") != null) cb.consoleInfo(Pretty(ctx.executeContext, body, elideLiterals = true)) if (ctx.executeContext.getFlag("print_inputs_on_worker") != null) cb.consoleInfo(cb.strValue(decodedContext)) - val bodyResult = wrapInTuple(cb, + val bodyResult = wrapInTuple( + cb, region, - EmitCode.fromI(cb.emb)(cb => new Emit(ctx, bodyFB.ecb).emitI(body, cb, env, None))) + EmitCode.fromI(cb.emb)(cb => new Emit(ctx, bodyFB.ecb).emitI(body, cb, env, None)), + ) bodySpec = TypedCodecSpec(bodyResult.st.storageType().setRequired(true), bufferSpec) - val bOS = cb.newLocal[ByteArrayOutputStream]("cda_baos", Code.newInstance[ByteArrayOutputStream]()) + val bOS = cb.newLocal[ByteArrayOutputStream]( + "cda_baos", + Code.newInstance[ByteArrayOutputStream](), + ) val bOB = cb.newLocal[OutputBuffer]("cda_ob", bodySpec.buildCodeOutputBuffer(bOS)) bodySpec.encodedType.buildEncoder(bodyResult.st, cb.emb.ecb) .apply(cb, bodyResult, bOB) @@ -2517,11 +3264,17 @@ class Emit[C]( val buf = mb.genFieldThisRef[OutputBuffer]() val ctxab = mb.genFieldThisRef[ByteArrayArrayBuilder]() - def addContexts(cb: EmitCodeBuilder, ctxStream: StreamProducer): Unit = { - ctxStream.memoryManagedConsume(region, cb, setup = { cb => - cb += ctxab.invoke[Int, Unit]("ensureCapacity", ctxStream.length.map(_.apply(cb)).getOrElse(16)) - }) { cb => + ctxStream.memoryManagedConsume( + region, + cb, + setup = { cb => + cb += ctxab.invoke[Int, Unit]( + "ensureCapacity", + ctxStream.length.map(_.apply(cb)).getOrElse(16), + ) + }, + ) { cb => cb += baos.invoke[Unit]("reset") val ctxTuple = wrapInTuple(cb, region, ctxStream.element) contextSpec.encodedType.buildEncoder(ctxTuple.st, parentCB) @@ -2539,7 +3292,10 @@ class Emit[C]( } cb.assign(baos, Code.newInstance[ByteArrayOutputStream]()) - cb.assign(buf, contextSpec.buildCodeOutputBuffer(baos)) // TODO: take a closer look at whether we need two codec buffers? + cb.assign( + buf, + contextSpec.buildCodeOutputBuffer(baos), + ) // TODO: take a closer look at whether we need two codec buffers? 
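// Editorial sketch (not part of this change): CollectDistributedArray ships one encoded byte
// array per context plus a single encoded copy of the globals, and decodes the per-partition
// results on the way back. The plain-JDK round trip below mirrors that wire pattern with Ints
// standing in for the encoded tuples; all names here are illustrative only, not Hail APIs.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object CdaWireSketch {
  def encodeInt(x: Int): Array[Byte] = {
    val baos = new ByteArrayOutputStream()
    val out = new DataOutputStream(baos)
    out.writeInt(x); out.flush()
    baos.toByteArray                                 // one byte array per context, as ctxab collects above
  }

  def decodeInt(bytes: Array[Byte]): Int =
    new DataInputStream(new ByteArrayInputStream(bytes)).readInt()

  def main(args: Array[String]): Unit = {
    val contexts = Array(1, 2, 3).map(encodeInt)     // analogous to ctxab.result()
    val globals = encodeInt(10)                      // analogous to the single encoded globals blob
    // "worker" side: decode globals once, decode each context, run the body, re-encode the result
    val g = decodeInt(globals)
    val encRes = contexts.map(c => encodeInt(decodeInt(c) + g))
    assert(encRes.map(decodeInt).sameElements(Array(11, 12, 13)))
  }
}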
cb.assign(ctxab, Code.newInstance[ByteArrayArrayBuilder, Int](16)) addContexts(cb, ctxStream.getProducer(mb)) cb += baos.invoke[Unit]("reset") @@ -2549,31 +3305,33 @@ class Emit[C]( val stageName = cb.newLocal[String]("stagename") cb.assign(stageName, staticID) - val semhash = cb.newLocal[Option[SemanticHash.Type]]("semhash", + val semhash = cb.newLocal[Option[SemanticHash.Type]]( + "semhash", Code.invokeScalaObject[Option[SemanticHash.Type]]( Option.getClass, "empty", Array(), - Array() - ) + Array(), + ), ) - emitI(dynamicID).consume(cb, + emitI(dynamicID).consume( + cb, ctx.executeContext.irMetadata.nextHash.foreach { hash => - cb.assign(semhash, + cb.assign( + semhash, Code.invokeScalaObject[Option[SemanticHash.Type]]( SemanticHash.CodeGenSupport.getClass, "lift", Array(classOf[SemanticHash.Type]), - Array(hash) - ) + Array(hash), + ), ) }, { dynamicID => val dynV = dynamicID.asString.loadString(cb) cb.assign(stageName, stageName.concat("|").concat(dynV)) ctx.executeContext.irMetadata.nextHash.foreach { staticHash => - val dynamicHash = dynV.invoke[Array[Byte]]("getBytes") @@ -2582,39 +3340,59 @@ class Emit[C]( SemanticHash.getClass, "extend", Array(classOf[SemanticHash.Type], classOf[Array[Byte]]), - Array(staticHash, dynamicHash) + Array(staticHash, dynamicHash), ) - cb.assign(semhash, + cb.assign( + semhash, Code.invokeScalaObject[Option[SemanticHash.Type]]( SemanticHash.CodeGenSupport.getClass, "lift", Array(classOf[SemanticHash.Type]), - Array(combined) - ) + Array(combined), + ), ) } - } + }, ) val encRes = cb.newLocal[Array[Array[Byte]]]("encRes") - cb.assign(encRes, backend.invoke[BackendContext, HailClassLoader, FS, String, Array[Array[Byte]], Array[Byte], String, Option[SemanticHash.Type], Option[TableStageDependency], Array[Array[Byte]]]( - "collectDArray", - mb.getObject(ctx.executeContext.backendContext), - mb.getHailClassLoader, - mb.getFS, - functionID, - ctxab.invoke[Array[Array[Byte]]]("result"), - baos.invoke[Array[Byte]]("toByteArray"), - stageName, - semhash, - mb.getObject(tsd)) + cb.assign( + encRes, + backend.invoke[ + BackendContext, + HailClassLoader, + FS, + String, + Array[Array[Byte]], + Array[Byte], + String, + Option[SemanticHash.Type], + Option[TableStageDependency], + Array[Array[Byte]], + ]( + "collectDArray", + mb.getObject(ctx.executeContext.backendContext), + mb.getHailClassLoader, + mb.getFS, + functionID, + ctxab.invoke[Array[Array[Byte]]]("result"), + baos.invoke[Array[Byte]]("toByteArray"), + stageName, + semhash, + mb.getObject(tsd), + ), ) val len = cb.memoize(encRes.length()) - val pt = PCanonicalArray(bodySpec.encodedType.decodedSType(bodySpec.encodedVirtualType).asInstanceOf[SBaseStruct].fieldEmitTypes(0).storageType) + val pt = PCanonicalArray(bodySpec.encodedType.decodedSType( + bodySpec.encodedVirtualType + ).asInstanceOf[SBaseStruct].fieldEmitTypes(0).storageType) val resultArray = pt.constructFromElements(cb, region, len, deepCopy = false) { (cb, i) => - val ib = cb.memoize(bodySpec.buildCodeInputBuffer(Code.newInstance[ByteArrayInputStream, Array[Byte]](encRes(i)))) + val ib = cb.memoize(bodySpec.buildCodeInputBuffer(Code.newInstance[ + ByteArrayInputStream, + Array[Byte], + ](encRes(i)))) val eltTupled = bodySpec.encodedType.buildDecoder(bodySpec.encodedVirtualType, parentCB) .apply(cb, region, ib) .asBaseStruct @@ -2632,7 +3410,9 @@ class Emit[C]( ctx.req.lookupOpt(ir) match { case Some(r) => if (result.required != r.required) { - throw new RuntimeException(s"requiredness mismatch: EC=${ result.required } / Analysis=${ r.required 
}\n${ result.st }\n${ Pretty(ctx.executeContext, ir) }") + throw new RuntimeException( + s"requiredness mismatch: EC=${result.required} / Analysis=${r.required}\n${result.st}\n${Pretty(ctx.executeContext, ir)}" + ) } case _ => @@ -2640,50 +3420,20 @@ class Emit[C]( } if (result.st.virtualType != ir.typ) - throw new RuntimeException(s"type mismatch:\n EC=${ result.st.virtualType }\n IR=${ ir.typ }\n node: ${ Pretty(ctx.executeContext, ir).take(50) }") + throw new RuntimeException( + s"type mismatch:\n EC=${result.st.virtualType}\n IR=${ir.typ}\n node: ${Pretty(ctx.executeContext, ir).take(50)}" + ) result } - /** - * Invariants of the Returned Triplet - * ---------------------------------- - * - * The elements of the triplet are called (precompute, missingness, value) - * - * 1. evaluate each returned Code[_] at most once - * 2. evaluate precompute *on all static code-paths* leading to missingness or value - * 3. guard the the evaluation of value by missingness - * - * Triplets returning values cannot have side-effects. For void triplets, precompute - * contains the side effect, missingness is false, and value is {@code Code._empty}. - * - * JVM gotcha: - * a variable must be initialized on all static code-paths prior to its use (ergo defaultValue) - * - * Argument Convention - * ------------------- - * - * {@code In(i)} occupies two argument slots, one for the value and one for a - * missing bit. The value for {@code  In(0)} is passed as argument - * {@code  nSpecialArguments + 1}. The missingness bit is the subsequent - * argument. In general, the value for {@code  In(i)} appears at - * {@code  nSpecialArguments + 1 + 2 * i}. - * - * There must always be at least one special argument: a {@code  Region} in - * which the IR can allocate memory. - * - * When compiling an aggregation expression, {@code AggIn} refers to the first - * argument {@code In(0)} whose type must be of type - * {@code tAggIn.elementType}. {@code tAggIn.symTab} is not used by Emit. 
- * - **/ - private[ir] def emit(ir: IR, mb: EmitMethodBuilder[C], env: EmitEnv, container: Option[AggContainer]): EmitCode = { - val region = mb.getCodeParam[Region](1) - emit(ir, mb, region, env, container, None) - } - - private[ir] def emitWithRegion(ir: IR, mb: EmitMethodBuilder[C], region: Value[Region], env: EmitEnv, container: Option[AggContainer]): EmitCode = + private[ir] def emit( + ir: IR, + mb: EmitMethodBuilder[C], + region: Value[Region], + env: EmitEnv, + container: Option[AggContainer], + ): EmitCode = emit(ir, mb, region, env, container, None) private def emit( @@ -2693,28 +3443,45 @@ class Emit[C]( env: EmitEnv, container: Option[AggContainer], loopEnv: Option[Env[LoopRef]], - fallingBackFromEmitI: Boolean = false + fallingBackFromEmitI: Boolean = false, ): EmitCode = { if (ctx.methodSplits.contains(ir) && !ctx.tryingToSplit.contains(ir)) { - return EmitCode.fromI(mb)(cb => emitInSeparateMethod(s"split_${ir.getClass.getSimpleName}", cb, ir, region, env, container, loopEnv)) + return EmitCode.fromI(mb)(cb => + emitInSeparateMethod( + s"split_${ir.getClass.getSimpleName}", + cb, + ir, + region, + env, + container, + loopEnv, + ) + ) } - - def emit(ir: IR, region: Value[Region] = region, env: EmitEnv = env, container: Option[AggContainer] = container, loopEnv: Option[Env[LoopRef]] = loopEnv): EmitCode = + def emit( + ir: IR, + region: Value[Region] = region, + env: EmitEnv = env, + container: Option[AggContainer] = container, + loopEnv: Option[Env[LoopRef]] = loopEnv, + ): EmitCode = this.emit(ir, mb, region, env, container, loopEnv) - def emitI(ir: IR, cb: EmitCodeBuilder, env: EmitEnv = env, container: Option[AggContainer] = container, loopEnv: Option[Env[LoopRef]] = loopEnv): IEmitCode = + def emitI( + ir: IR, + cb: EmitCodeBuilder, + env: EmitEnv = env, + container: Option[AggContainer] = container, + loopEnv: Option[Env[LoopRef]] = loopEnv, + ): IEmitCode = this.emitI(ir, cb, region, env, container, loopEnv) - def emitVoid(ir: IR, env: EmitEnv = env, container: Option[AggContainer] = container, loopEnv: Option[Env[LoopRef]] = loopEnv): Code[Unit] = { - EmitCodeBuilder.scopedVoid(mb) { cb => - this.emitVoid(cb, ir, region, env, container, loopEnv) - } - } - def emitStream(ir: IR, outerRegion: Value[Region], env: EmitEnv = env): EmitCode = - EmitCode.fromI(mb)(cb => EmitStream.produce(this, ir, cb, cb.emb, outerRegion, env, container)) + EmitCode.fromI(mb)(cb => + EmitStream.produce(this, ir, cb, cb.emb, outerRegion, env, container) + ) // ideally, emit would not be called with void values, but initOp args can be void // working towards removing this @@ -2726,31 +3493,15 @@ class Emit[C]( val result: EmitCode = (ir: @unchecked) match { - case Let(bindings, body) => - EmitCode.fromI(mb) { cb => - def go(env: EmitEnv): IndexedSeq[(String, IR)] => IEmitCode = { - case (name, value) +: rest => - val xVal = - if (value.typ.isInstanceOf[TStream]) emitStream(value, region, env = env) - else emit(value, env = env) - - cb.withScopedMaybeStreamValue(xVal, s"let_$name") { ev => - go(env.bind(name, ev))(rest) - } - case Seq() => - emitI(body, cb, env = env) - } - - go(env)(bindings) - } - case Ref(name, t) => val ev = env.bindings.lookup(name) if (ev.st.virtualType != t) - throw new RuntimeException(s"emit value type did not match specified type:\n name: $name\n ev: ${ ev.st.virtualType }\n ir: ${ ir.typ }") + throw new RuntimeException( + s"emit value type did not match specified type:\n name: $name\n ev: ${ev.st.virtualType}\n ir: ${ir.typ}" + ) ev.load - case ir@Apply(fn, 
typeArgs, args, rt, errorID) => + case ir @ Apply(fn, typeArgs, args, rt, errorID) => val impl = ir.implementation val unified = impl.unify(typeArgs, args.map(_.typ), rt) assert(unified) @@ -2772,11 +3523,16 @@ class Emit[C]( EmitCode.fromI(mb) { cb => val emitArgs = args.map(a => EmitCode.fromI(cb.emb)(emitI(a, _))).toFastSeq IEmitCode.multiMapEmitCodes(cb, emitArgs) { codeArgs => - cb.invokeSCode(meth, FastSeq[Param](CodeParam(region), CodeParam(errorID)) ++ codeArgs.map(pc => pc: Param): _*) + cb.invokeSCode( + meth, + FastSeq[Param](cb.this_, CodeParam(region), CodeParam(errorID)) ++ codeArgs.map(pc => + pc: Param + ): _* + ) } } - case x@ApplySpecial(_, typeArgs, args, rt, errorID) => + case x @ ApplySpecial(_, typeArgs, args, rt, errorID) => val codeArgs = args.map(a => emit(a)) val impl = x.implementation val unified = impl.unify(typeArgs, args.map(_.typ), rt) @@ -2797,15 +3553,15 @@ class Emit[C]( if (fallingBackFromEmitI) { fatal(s"ir is not defined in emit or emitI $x") } - EmitCode.fromI(mb) { cb => - emitI(ir, cb) - } + EmitCode.fromI(mb)(cb => emitI(ir, cb)) } ctx.req.lookupOpt(ir) match { case Some(r) => if (result.required != r.required) { - throw new RuntimeException(s"requiredness mismatch: EC=${ result.required } / Analysis=${ r.required }\n${ result.emitType }\n${ Pretty(ctx.executeContext, ir) }") + throw new RuntimeException( + s"requiredness mismatch: EC=${result.required} / Analysis=${r.required}\n${result.emitType}\n${Pretty(ctx.executeContext, ir)}" + ) } case _ => @@ -2813,87 +3569,238 @@ class Emit[C]( } if (result.st.virtualType != ir.typ) - throw new RuntimeException(s"type mismatch: EC=${ result.st.virtualType } / IR=${ ir.typ }\n${ ir.getClass.getSimpleName }") + throw new RuntimeException( + s"type mismatch: EC=${result.st.virtualType} / IR=${ir.typ}\n${ir.getClass.getSimpleName}" + ) result } private def makeDependentSortingFunction( cb: EmitCodeBuilder, - elemSCT: SingleCodeType, ir: IR, env: EmitEnv, emitter: Emit[_], leftRightComparatorNames: Array[String] + elemSCT: SingleCodeType, + ir: IR, + env: EmitEnv, + emitter: Emit[_], + leftRightComparatorNames: Array[String], ): (EmitCodeBuilder, Value[Region], Value[_], Value[_]) => Value[Boolean] = { val fb = cb.emb.ecb var newEnv = env - val sort = fb.genEmitMethod("dependent_sorting_func", + val sort = fb.genEmitMethod( + "dependent_sorting_func", FastSeq(typeInfo[Region], CodeParamType(elemSCT.ti), CodeParamType(elemSCT.ti)), - BooleanInfo) + BooleanInfo, + ) sort.emitWithBuilder[Boolean] { cb => - val region = sort.getCodeParam[Region](1) - val leftEC = cb.memoize(EmitCode.present(sort, elemSCT.loadToSValue(cb, sort.getCodeParam(2)(elemSCT.ti))), "sort_leftEC") - val rightEC = cb.memoize(EmitCode.present(sort, elemSCT.loadToSValue(cb, sort.getCodeParam(3)(elemSCT.ti))), "sort_rightEC") + val leftEC = cb.memoize( + EmitCode.present(sort, elemSCT.loadToSValue(cb, sort.getCodeParam(2)(elemSCT.ti))), + "sort_leftEC", + ) + val rightEC = cb.memoize( + EmitCode.present(sort, elemSCT.loadToSValue(cb, sort.getCodeParam(3)(elemSCT.ti))), + "sort_rightEC", + ) if (leftRightComparatorNames.nonEmpty) { assert(leftRightComparatorNames.length == 2) newEnv = newEnv.bind( - (leftRightComparatorNames(0), leftEC), - (leftRightComparatorNames(1), rightEC)) + (leftRightComparatorNames(0), leftEC), + (leftRightComparatorNames(1), rightEC), + ) } val iec = emitter.emitI(ir, cb, newEnv, None) - iec.get(cb, "Result of sorting function cannot be missing").asBoolean.value + iec.getOrFatal(cb, "Result of sorting function cannot be 
missing").asBoolean.value + } + (cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]) => + cb.memoize(cb.invokeCode[Boolean](sort, cb.this_, region, l, r)) + } + + /** Emit the bindings (but not the body) of `let`. If possible, split bindings into chunks, and + * emit each chunk in a separate method. + */ + // TODO: splitting logic should get lifted into ComputeMethodSplits + def emitBlock( + let: Block, + cb: EmitCodeBuilder, + env: EmitEnv, + r: Value[Region], + container: Option[AggContainer], + loopEnv: Option[Env[LoopRef]], + ): EmitEnv = { + def emitI(ir: IR, cb: EmitCodeBuilder, env: EmitEnv, r: Value[Region]): IEmitCode = + if (ir.typ.isInstanceOf[TStream]) + EmitStream.produce(this, ir, cb, cb.emb, r, env, container) + else this.emitI(ir, cb, r, env, container, loopEnv) + + def emitVoid(ir: IR, cb: EmitCodeBuilder, env: EmitEnv, r: Value[Region]): Unit = + this.emitVoid(cb, ir, r, env, container, loopEnv) + + val uses: mutable.Set[String] = + ctx.usesAndDefs.uses.get(let) match { + case Some(refs) => refs.map(_.t.name) + case None => mutable.Set.empty + } + + /* Emit a sequence of bindings into a code builder. Each is added to the environment of all + * following bindings. Any binding which is unused and has no side effects is skipped (this is + * mostly an optimization, but it is important not to emit unused streams). */ + def emitChunk(cb: EmitCodeBuilder, bindings: Seq[Binding], env: EmitEnv, r: Value[Region]) + : EmitEnv = + bindings.foldLeft(env) { case (newEnv, Binding(name, ir, Scope.EVAL)) => + if (ir.typ == TVoid) { + emitVoid(ir, cb, newEnv, r) + newEnv + } else if (IsPure(ir) && !uses.contains(name)) { + newEnv + } else { + val value = emitI(ir, cb, newEnv, r) + val memo = cb.memoizeMaybeStreamValue(value, s"let_$name") + newEnv.bind(name, memo) + } + } + + /* Bindings before chunkStart have been emitted. Bindings in the range chunkStart <= i < pos are + * a pending chunk, which have not yet been emitted. chunkSize is the number of non-skipped + * bindings in the pending chunk. groupIdx is how many chunks have already been emitted.
*/ + @tailrec def go( + env: EmitEnv, + chunkStart: Int, + pos: Int, + chunkSize: Int, + groupIdx: Int, + ): EmitEnv = { + + def emitChunkInSeparateMethod(): EmitEnv = { + val mb = cb.emb.genEmitMethod( + s"begin_group_$groupIdx", + FastSeq[ParamType](classInfo[Region]), + UnitInfo, + ) + var newEnv = env + mb.voidWithBuilder { cb => + newEnv = + emitChunk(cb, let.bindings.slice(chunkStart, pos), env, mb.getCodeParam[Region](1)) + } + cb.invokeVoid(mb, cb.this_, r) + newEnv + } + + def cantEmitInSeparateMethod(ir: IR): Boolean = + ir.typ.isInstanceOf[TStream] || ctx.inLoopCriticalPath.contains(ir) + + // end of bindings, emit any pending chunk and return the final environment + if (pos == let.bindings.length) { + if (chunkSize > 0) + return emitChunkInSeparateMethod() + else + return env + } + + val Binding(curName, curIR, Scope.EVAL) = let.bindings(pos) + + // skip over unused streams + if (curIR.typ.isInstanceOf[TStream] && !uses.contains(curName)) { + go(env, chunkStart, pos + 1, chunkSize, groupIdx) + } else if (chunkSize == 16 || (chunkSize > 0 && cantEmitInSeparateMethod(curIR))) { + /* emit the current chunk if it's either max size, or broken by a stream or other control + * flow */ + val newEnv = emitChunkInSeparateMethod() + go(newEnv, pos, pos, 0, groupIdx + 1) + } else if (curIR.typ.isInstanceOf[TStream]) { + // emit a stream, assuming we've already emitted any prior chunk + assert(chunkSize == 0) // no pending bindings + val value = emitI(curIR, cb, env, r) + val memo = cb.memoizeMaybeStreamValue(value, s"let_$curName") + val newEnv = env.bind(curName, memo) + go(newEnv, pos + 1, pos + 1, 0, groupIdx) + } else { + // add cur binding to pending chunk + go(env, chunkStart, pos + 1, chunkSize + 1, groupIdx) + } + } + + // don't split into separate methods if the bindings list is small + if (let.bindings.size > 4) { + go(env, 0, 0, 0, 0) + } else { + emitChunk(cb, let.bindings, env, r) } - (cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]) => cb.memoize(cb.invokeCode[Boolean](sort, region, l, r)) } } object NDArrayEmitter { - def zeroBroadcastedDims2(mb: EmitMethodBuilder[_], loopVars: IndexedSeq[Value[Long]], nDims: Int, shapeArray: IndexedSeq[Value[Long]]): IndexedSeq[Value[Long]] = { + def zeroBroadcastedDims2( + mb: EmitMethodBuilder[_], + loopVars: IndexedSeq[Value[Long]], + nDims: Int, + shapeArray: IndexedSeq[Value[Long]], + ): IndexedSeq[Value[Long]] = { val broadcasted = 0L val notBroadcasted = 1L - Array.tabulate(nDims)(dim => new Value[Long] { - def get: Code[Long] = (shapeArray(dim) > 1L).mux(notBroadcasted, broadcasted) * loopVars(dim) - }) + Array.tabulate(nDims)(dim => + new Value[Long] { + def get: Code[Long] = + (shapeArray(dim) > 1L).mux(notBroadcasted, broadcasted) * loopVars(dim) + } + ) } def broadcastMask(shapeArray: IndexedSeq[Code[Long]]): IndexedSeq[Value[Long]] = { val broadcasted = 0L val notBroadcasted = 1L - shapeArray.map(shapeElement => new Value[Long] { - def get: Code[Long] = (shapeElement > 1L).mux(notBroadcasted, broadcasted) - }) + shapeArray.map(shapeElement => + new Value[Long] { + def get: Code[Long] = (shapeElement > 1L).mux(notBroadcasted, broadcasted) + } + ) } - def zeroBroadcastedDims(indices: IndexedSeq[Code[Long]], broadcastMask: IndexedSeq[Code[Long]]): IndexedSeq[Value[Long]] = { - indices.zip(broadcastMask).map { case (index, flag) => new Value[Long] { - def get: Code[Long] = index * flag - } + def zeroBroadcastedDims(indices: IndexedSeq[Code[Long]], broadcastMask: IndexedSeq[Code[Long]]) + : 
IndexedSeq[Value[Long]] = + indices.zip(broadcastMask).map { case (index, flag) => + new Value[Long] { + def get: Code[Long] = index * flag + } } - } - def unifyShapes2(cb: EmitCodeBuilder, leftShape: IndexedSeq[Value[Long]], rightShape: IndexedSeq[Value[Long]], errorID: Int): IndexedSeq[Value[Long]] = { + def unifyShapes2( + cb: EmitCodeBuilder, + leftShape: IndexedSeq[Value[Long]], + rightShape: IndexedSeq[Value[Long]], + errorID: Int, + ): IndexedSeq[Value[Long]] = { val shape = leftShape.zip(rightShape).zipWithIndex.map { case ((left, right), i) => val notSameAndNotBroadcastable = !((left ceq right) || (left ceq 1L) || (right ceq 1L)) cb.newField[Long]( s"unify_shapes2_shape$i", notSameAndNotBroadcastable.mux( - Code._fatalWithID[Long](rightShape.foldLeft[Code[String]]( - leftShape.foldLeft[Code[String]]( - const("Incompatible NDArray shapes: [ ") + Code._fatalWithID[Long]( + rightShape.foldLeft[Code[String]]( + leftShape.foldLeft[Code[String]]( + const("Incompatible NDArray shapes: [ ") + )((accum, v) => accum.concat(v.toS).concat(" ")) + .concat("] vs [ ") )((accum, v) => accum.concat(v.toS).concat(" ")) - .concat("] vs [ ") - )((accum, v) => accum.concat(v.toS).concat(" ")) - .concat("]"), errorID), - (right ceq 1L).mux(left, right))) + .concat("]"), + errorID, + ), + (right ceq 1L).mux(left, right), + ), + ) } shape } - def matmulShape(cb: EmitCodeBuilder, leftShape: IndexedSeq[Value[Long]], rightShape: IndexedSeq[Value[Long]], errorID: Int): IndexedSeq[Value[Long]] = { - val mb = cb.emb - + def matmulShape( + cb: EmitCodeBuilder, + leftShape: IndexedSeq[Value[Long]], + rightShape: IndexedSeq[Value[Long]], + errorID: Int, + ): IndexedSeq[Value[Long]] = { assert(leftShape.nonEmpty) assert(rightShape.nonEmpty) @@ -2917,21 +3824,27 @@ object NDArrayEmitter { shape = leftShape.slice(0, leftShape.length - 1) } else { rK = rightShape(rightShape.length - 2) - val unifiedShape = unifyShapes2(cb, + val unifiedShape = unifyShapes2( + cb, leftShape.slice(0, leftShape.length - 2), - rightShape.slice(0, rightShape.length - 2), errorID) + rightShape.slice(0, rightShape.length - 2), + errorID, + ) shape = unifiedShape :+ leftShape(leftShape.length - 2) :+ rightShape.last } } - val leftShapeString = const("(").concat(leftShape.map(_.toS).reduce((a, b) => a.concat(", ").concat(b))).concat(")") - val rightShapeString = const("(").concat(rightShape.map(_.toS).reduce((a, b) => a.concat(", ").concat(b))).concat(")") - + val leftShapeString = + const("(").concat(leftShape.map(_.toS).reduce((a, b) => a.concat(", ").concat(b))).concat(")") + val rightShapeString = const("(").concat(rightShape.map(_.toS).reduce((a, b) => + a.concat(", ").concat(b) + )).concat(")") - cb.if_(lK.cne(rK), { - cb._fatalWithError(errorID,"Matrix dimensions incompatible: ", leftShapeString, - " can't be multiplied by matrix with dimensions ", rightShapeString) - }) + cb.if_( + lK.cne(rK), + cb._fatalWithError(errorID, "Matrix dimensions incompatible: ", leftShapeString, + " can't be multiplied by matrix with dimensions ", rightShapeString), + ) shape } @@ -2951,15 +3864,24 @@ abstract class NDArrayEmitter(val outputShape: IndexedSeq[Value[Long]], val elem outputShape, targetType.makeColumnMajorStrides(shapeArray, cb), cb, - region) - - SNDArray.forEachIndexColMajor(cb, shapeArray, "ndarrayemitter_emitloops") { case (cb, idxVars) => - IEmitCode.present(cb, outputElement(cb, idxVars)).consume(cb, { - cb._fatal("NDArray elements cannot be missing") - }, { elementPc => - targetType.elementType.storeAtAddress(cb, firstElementAddress + 
(idx.toL * targetType.elementType.byteSize), region, elementPc, true) - }) - cb.assign(idx, idx + 1) + region, + ) + + SNDArray.forEachIndexColMajor(cb, shapeArray, "ndarrayemitter_emitloops") { + case (cb, idxVars) => + IEmitCode.present(cb, outputElement(cb, idxVars)).consume( + cb, + cb._fatal("NDArray elements cannot be missing"), + elementPc => + targetType.elementType.storeAtAddress( + cb, + firstElementAddress + (idx.toL * targetType.elementType.byteSize), + region, + elementPc, + true, + ), + ) + cb.assign(idx, idx + 1) } finish(cb) diff --git a/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala b/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala index 115df824b3e..78391c6e975 100644 --- a/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala +++ b/hail/src/main/scala/is/hail/expr/ir/EmitClassBuilder.scala @@ -5,28 +5,36 @@ import is.hail.asm4s._ import is.hail.backend.{BackendUtils, ExecuteContext, HailTaskContext} import is.hail.expr.ir.functions.IRRandomness import is.hail.expr.ir.orderings.{CodeOrdering, StructOrdering} -import is.hail.io.fs.FS import is.hail.io.{BufferSpec, InputBuffer, TypedCodecSpec} +import is.hail.io.fs.FS import is.hail.types.VirtualTypeWithReq +import is.hail.types.physical.{PCanonicalTuple, PType} import is.hail.types.physical.stypes._ import is.hail.types.physical.stypes.interfaces.SBaseStruct -import is.hail.types.physical.{PCanonicalTuple, PType} import is.hail.types.virtual.Type import is.hail.utils._ import is.hail.utils.prettyPrint.ArrayOfByteArrayInputStream import is.hail.variant.ReferenceGenome -import org.apache.spark.TaskContext -import java.io._ -import java.lang.reflect.InvocationTargetException import scala.collection.mutable import scala.language.existentials +import java.io._ +import java.lang.reflect.InvocationTargetException + +import org.apache.spark.TaskContext + class EmitModuleBuilder(val ctx: ExecuteContext, val modb: ModuleBuilder) { - def newEmitClass[C](name: String, sourceFile: Option[String] = None)(implicit cti: TypeInfo[C]): EmitClassBuilder[C] = - new EmitClassBuilder(this, modb.newClass(name, sourceFile)) - def genEmitClass[C](baseName: String, sourceFile: Option[String] = None)(implicit cti: TypeInfo[C]): EmitClassBuilder[C] = + def newEmitClass[C](name: String, sourceFile: Option[String] = None)(implicit cti: TypeInfo[C]) + : EmitClassBuilder[C] = + new EmitClassBuilder[C](this, modb.newClass[C](name, sourceFile)) + + def genEmitClass[C]( + baseName: String, + sourceFile: Option[String] = None, + )(implicit cti: TypeInfo[C] + ): EmitClassBuilder[C] = newEmitClass[C](genName("C", baseName), sourceFile) private[this] val _staticHailClassLoader: StaticField[HailClassLoader] = { @@ -34,7 +42,8 @@ class EmitModuleBuilder(val ctx: ExecuteContext, val modb: ModuleBuilder) { cls.newStaticField[HailClassLoader]("hailClassLoader", Code._null[HailClassLoader]) } - def setHailClassLoader(cb: EmitCodeBuilder, fs: Code[HailClassLoader]): Unit = cb += _staticHailClassLoader.put(fs) + def setHailClassLoader(cb: EmitCodeBuilder, fs: Code[HailClassLoader]): Unit = + cb += _staticHailClassLoader.put(fs) def getHailClassLoader: Value[HailClassLoader] = new StaticFieldRef(_staticHailClassLoader) @@ -52,58 +61,73 @@ class EmitModuleBuilder(val ctx: ExecuteContext, val modb: ModuleBuilder) { def hasReferences: Boolean = rgContainers.nonEmpty def getReferenceGenome(rg: String): Value[ReferenceGenome] = { - val rgField = rgContainers.getOrElseUpdate(rg, { - val cls = genEmitClass[Unit](s"RGContainer_${rg}") - 
cls.newStaticField("reference_genome", Code._null[ReferenceGenome]) - }) + val rgField = rgContainers.getOrElseUpdate( + rg, { + val cls = genEmitClass[Unit](s"RGContainer_$rg") + cls.newStaticField("reference_genome", Code._null[ReferenceGenome]) + }, + ) new StaticFieldRef(rgField) } - def referenceGenomes(): IndexedSeq[ReferenceGenome] = rgContainers.keys.map(ctx.getReference(_)).toIndexedSeq.sortBy(_.name) - def referenceGenomeFields(): IndexedSeq[StaticField[ReferenceGenome]] = rgContainers.toFastSeq.sortBy(_._1).map(_._2) + def referenceGenomes(): IndexedSeq[ReferenceGenome] = + rgContainers.keys.map(ctx.getReference(_)).toIndexedSeq.sortBy(_.name) + + def referenceGenomeFields(): IndexedSeq[StaticField[ReferenceGenome]] = + rgContainers.toFastSeq.sortBy(_._1).map(_._2) var _rgMapField: StaticFieldRef[Map[String, ReferenceGenome]] = null def referenceGenomeMap: Value[Map[String, ReferenceGenome]] = { if (_rgMapField == null) { val cls = genEmitClass[Unit](s"RGMapContainer") - _rgMapField = new StaticFieldRef(cls.newStaticField("reference_genome_map", Code._null[Map[String, ReferenceGenome]])) + _rgMapField = new StaticFieldRef(cls.newStaticField( + "reference_genome_map", + Code._null[Map[String, ReferenceGenome]], + )) } _rgMapField } - def setObjects(cb: EmitCodeBuilder, objects: Code[Array[AnyRef]]): Unit = modb.setObjects(cb, objects) + def setObjects(cb: EmitCodeBuilder, objects: Code[Array[AnyRef]]): Unit = + modb.setObjects(cb, objects) - def getObject[T <: AnyRef : TypeInfo](obj: T): Code[T] = modb.getObject(obj) + def getObject[T <: AnyRef: TypeInfo](obj: T): Code[T] = modb.getObject(obj) private[this] var currLitIndex: Int = 0 private[this] val literalsBuilder = mutable.Map.empty[(VirtualTypeWithReq, Any), (PType, Int)] private[this] val encodedLiteralsBuilder = mutable.Map.empty[EncodedLiteral, (PType, Int)] def registerLiteral(t: VirtualTypeWithReq, value: Any): (PType, Int) = { - literalsBuilder.getOrElseUpdate((t, value), { - val curr = currLitIndex - val pt = t.canonicalPType - literalsBuilder.put((t, value), (pt, curr)) - currLitIndex += 1 - (pt, curr) - }) + literalsBuilder.getOrElseUpdate( + (t, value), { + val curr = currLitIndex + val pt = t.canonicalPType + literalsBuilder.put((t, value), (pt, curr)) + currLitIndex += 1 + (pt, curr) + }, + ) } def registerEncodedLiteral(el: EncodedLiteral): (PType, Int) = { - encodedLiteralsBuilder.getOrElseUpdate(el, { - val curr = currLitIndex - val pt = el.codec.decodedPType() - encodedLiteralsBuilder.put(el, (pt, curr)) - currLitIndex += 1 - (pt, curr) - }) + encodedLiteralsBuilder.getOrElseUpdate( + el, { + val curr = currLitIndex + val pt = el.codec.decodedPType() + encodedLiteralsBuilder.put(el, (pt, curr)) + currLitIndex += 1 + (pt, curr) + }, + ) } - def literalsResult(): (Array[(VirtualTypeWithReq, Any, PType, Int)], Array[(EncodedLiteral, PType, Int)]) = { - (literalsBuilder.toArray.map { case ((vt, a), (pt, i)) => (vt, a, pt, i) }, - encodedLiteralsBuilder.toArray.map { case (el, (pt, i)) => (el, pt, i) }) - } + def literalsResult() + : (Array[(VirtualTypeWithReq, Any, PType, Int)], Array[(EncodedLiteral, PType, Int)]) = + ( + literalsBuilder.toArray.map { case ((vt, a), (pt, i)) => (vt, a, pt, i) }, + encodedLiteralsBuilder.toArray.map { case (el, (pt, i)) => (el, pt, i) }, + ) def hasLiterals: Boolean = currLitIndex > 0 } @@ -115,9 +139,11 @@ trait WrappedEmitModuleBuilder { def ctx: ExecuteContext = emodb.ctx - def newEmitClass[C](name: String)(implicit cti: TypeInfo[C]): EmitClassBuilder[C] = 
emodb.newEmitClass[C](name) + def newEmitClass[C](name: String)(implicit cti: TypeInfo[C]): EmitClassBuilder[C] = + emodb.newEmitClass[C](name) - def genEmitClass[C](baseName: String)(implicit cti: TypeInfo[C]): EmitClassBuilder[C] = emodb.genEmitClass[C](baseName) + def genEmitClass[C](baseName: String)(implicit cti: TypeInfo[C]): EmitClassBuilder[C] = + emodb.genEmitClass[C](baseName) def getReferenceGenome(rg: String): Value[ReferenceGenome] = emodb.getReferenceGenome(rg) } @@ -135,17 +161,24 @@ trait WrappedEmitClassBuilder[C] extends WrappedEmitModuleBuilder { def newStaticField[T: TypeInfo](name: String): StaticField[T] = ecb.newStaticField[T](name) - def newStaticField[T: TypeInfo](name: String, init: Code[T]): StaticField[T] = ecb.newStaticField[T](name, init) + def newStaticField[T: TypeInfo](name: String, init: Code[T]): StaticField[T] = + ecb.newStaticField[T](name, init) def genField[T: TypeInfo](baseName: String): Field[T] = ecb.genField(baseName) - def genFieldThisRef[T: TypeInfo](name: String = null): ThisFieldRef[T] = ecb.genFieldThisRef[T](name) + def getField[T: TypeInfo](name: String): Field[T] = ecb.getField(name) + + def genFieldThisRef[T: TypeInfo](name: String = null): ThisFieldRef[T] = + ecb.genFieldThisRef[T](name) - def genLazyFieldThisRef[T: TypeInfo](setup: Code[T], name: String = null): Value[T] = ecb.genLazyFieldThisRef(setup, name) + def genLazyFieldThisRef[T: TypeInfo](setup: Code[T], name: String = null): Value[T] = + ecb.genLazyFieldThisRef(setup, name) - def getOrDefineLazyField[T: TypeInfo](setup: Code[T], id: Any): Value[T] = ecb.getOrDefineLazyField(setup, id) + def getOrDefineLazyField[T: TypeInfo](setup: Code[T], id: Any): Value[T] = + ecb.getOrDefineLazyField(setup, id) - def newPSettable(sb: SettableBuilder, pt: SType, name: String = null): SSettable = ecb.newPSettable(sb, pt, name) + def newPSettable(sb: SettableBuilder, pt: SType, name: String = null): SSettable = + ecb.newPSettable(sb, pt, name) def newPField(pt: SType): SSettable = ecb.newPField(pt) @@ -155,9 +188,11 @@ trait WrappedEmitClassBuilder[C] extends WrappedEmitModuleBuilder { def newEmitField(pt: SType, required: Boolean): EmitSettable = ecb.newEmitField(pt, required) - def newEmitField(name: String, et: EmitType): EmitSettable = ecb.newEmitField(name, et.st, et.required) + def newEmitField(name: String, et: EmitType): EmitSettable = + ecb.newEmitField(name, et.st, et.required) - def newEmitField(name: String, pt: SType, required: Boolean): EmitSettable = ecb.newEmitField(name, pt, required) + def newEmitField(name: String, pt: SType, required: Boolean): EmitSettable = + ecb.newEmitField(name, pt, required) def fieldBuilder: SettableBuilder = cb.fieldBuilder @@ -170,7 +205,7 @@ trait WrappedEmitClassBuilder[C] extends WrappedEmitModuleBuilder { def getTaskContext: Value[HailTaskContext] = ecb.getTaskContext - def getObject[T <: AnyRef : TypeInfo](obj: T): Code[T] = ecb.getObject(obj) + def getObject[T <: AnyRef: TypeInfo](obj: T): Code[T] = ecb.getObject(obj) def getSerializedAgg(i: Int): Code[Array[Byte]] = ecb.getSerializedAgg(i) @@ -180,47 +215,72 @@ trait WrappedEmitClassBuilder[C] extends WrappedEmitModuleBuilder { def backend(): Code[BackendUtils] = ecb.backend() - def addModule(name: String, mod: (HailClassLoader, FS, HailTaskContext, Region) => AsmFunction3[Region, Array[Byte], Array[Byte], Array[Byte]]): Unit = + def addModule( + name: String, + mod: (HailClassLoader, FS, HailTaskContext, Region) => AsmFunction3[ + Region, + Array[Byte], + Array[Byte], + Array[Byte], + ], + 
): Unit = ecb.addModule(name, mod) def partitionRegion: Settable[Region] = ecb.partitionRegion - def addLiteral(cb: EmitCodeBuilder, v: Any, t: VirtualTypeWithReq): SValue = ecb.addLiteral(cb, v, t) + def addLiteral(cb: EmitCodeBuilder, v: Any, t: VirtualTypeWithReq): SValue = + ecb.addLiteral(cb, v, t) - def addEncodedLiteral(cb: EmitCodeBuilder, encodedLiteral: EncodedLiteral) = ecb.addEncodedLiteral(cb, encodedLiteral) + def addEncodedLiteral(cb: EmitCodeBuilder, encodedLiteral: EncodedLiteral) = + ecb.addEncodedLiteral(cb, encodedLiteral) - def getPType[T <: PType : TypeInfo](t: T): Code[T] = ecb.getPType(t) + def getPType[T <: PType: TypeInfo](t: T): Code[T] = ecb.getPType(t) - def getType[T <: Type : TypeInfo](t: T): Code[T] = ecb.getType(t) + def getType[T <: Type: TypeInfo](t: T): Code[T] = ecb.getType(t) - def newEmitMethod(name: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType): EmitMethodBuilder[C] = + def newEmitMethod(name: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType) + : EmitMethodBuilder[C] = ecb.newEmitMethod(name, argsInfo, returnInfo) - def newEmitMethod(name: String, argsInfo: IndexedSeq[MaybeGenericTypeInfo[_]], returnInfo: MaybeGenericTypeInfo[_]): EmitMethodBuilder[C] = + def newEmitMethod( + name: String, + argsInfo: IndexedSeq[MaybeGenericTypeInfo[_]], + returnInfo: MaybeGenericTypeInfo[_], + ): EmitMethodBuilder[C] = ecb.newEmitMethod(name, argsInfo, returnInfo) - def newStaticEmitMethod(name: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType): EmitMethodBuilder[C] = + def newStaticEmitMethod(name: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType) + : EmitMethodBuilder[C] = ecb.newStaticEmitMethod(name, argsInfo, returnInfo) - def genEmitMethod(baseName: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType): EmitMethodBuilder[C] = + def genEmitMethod(baseName: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType) + : EmitMethodBuilder[C] = ecb.genEmitMethod(baseName, argsInfo, returnInfo) - def genStaticEmitMethod(baseName: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType): EmitMethodBuilder[C] = + def genStaticEmitMethod(baseName: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType) + : EmitMethodBuilder[C] = ecb.genStaticEmitMethod(baseName, argsInfo, returnInfo) - def addAggStates(aggSigs: Array[agg.AggStateSig]): agg.TupleAggregatorState = ecb.addAggStates(aggSigs) + def addAggStates(aggSigs: Array[agg.AggStateSig]): agg.TupleAggregatorState = + ecb.addAggStates(aggSigs) def newRNG(seed: Long): Value[IRRandomness] = ecb.newRNG(seed) def getThreefryRNG(): Value[ThreefryRandomEngine] = ecb.getThreefryRNG() def resultWithIndex(print: Option[PrintWriter] = None) - : (HailClassLoader, FS, HailTaskContext, Region) => C = + : (HailClassLoader, FS, HailTaskContext, Region) => C = ecb.resultWithIndex(print) def getOrGenEmitMethod( - baseName: String, key: Any, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType - )(body: EmitMethodBuilder[C] => Unit): EmitMethodBuilder[C] = ecb.getOrGenEmitMethod(baseName, key, argsInfo, returnInfo)(body) + baseName: String, + key: Any, + argsInfo: IndexedSeq[ParamType], + returnInfo: ParamType, + )( + body: EmitMethodBuilder[C] => Unit + ): EmitMethodBuilder[C] = + ecb.getOrGenEmitMethod(baseName, key, argsInfo, returnInfo)(body) def genEmitMethod[R: TypeInfo](baseName: String): EmitMethodBuilder[C] = ecb.genEmitMethod[R](baseName) @@ -228,16 +288,29 @@ trait WrappedEmitClassBuilder[C] extends WrappedEmitModuleBuilder { def 
genEmitMethod[A: TypeInfo, R: TypeInfo](baseName: String): EmitMethodBuilder[C] = ecb.genEmitMethod[A, R](baseName) - def genEmitMethod[A1: TypeInfo, A2: TypeInfo, R: TypeInfo](baseName: String): EmitMethodBuilder[C] = + def genEmitMethod[A1: TypeInfo, A2: TypeInfo, R: TypeInfo](baseName: String) + : EmitMethodBuilder[C] = ecb.genEmitMethod[A1, A2, R](baseName) - def genEmitMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, R: TypeInfo](baseName: String): EmitMethodBuilder[C] = + def genEmitMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, R: TypeInfo](baseName: String) + : EmitMethodBuilder[C] = ecb.genEmitMethod[A1, A2, A3, R](baseName) - def geEmitMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, R: TypeInfo](baseName: String): EmitMethodBuilder[C] = + def geEmitMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, R: TypeInfo]( + baseName: String + ): EmitMethodBuilder[C] = ecb.genEmitMethod[A1, A2, A3, A4, R](baseName) - def genEmitMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, A5: TypeInfo, R: TypeInfo](baseName: String): EmitMethodBuilder[C] = + def genEmitMethod[ + A1: TypeInfo, + A2: TypeInfo, + A3: TypeInfo, + A4: TypeInfo, + A5: TypeInfo, + R: TypeInfo, + ]( + baseName: String + ): EmitMethodBuilder[C] = ecb.genEmitMethod[A1, A2, A3, A4, A5, R](baseName) def openUnbuffered(path: Code[String], checkCodec: Code[Boolean]): Code[InputStream] = @@ -245,20 +318,20 @@ trait WrappedEmitClassBuilder[C] extends WrappedEmitModuleBuilder { def open(path: Code[String], checkCodec: Code[Boolean]): Code[InputStream] = Code.newInstance[java.io.BufferedInputStream, InputStream]( - getFS.invoke[String, Boolean, InputStream]("open", path, checkCodec)) + getFS.invoke[String, Boolean, InputStream]("open", path, checkCodec) + ) def createUnbuffered(path: Code[String]): Code[OutputStream] = getFS.invoke[String, OutputStream]("create", path) def create(path: Code[String]): Code[OutputStream] = Code.newInstance[java.io.BufferedOutputStream, OutputStream]( - getFS.invoke[String, OutputStream]("create", path)) + getFS.invoke[String, OutputStream]("create", path) + ) } -class EmitClassBuilder[C]( - val emodb: EmitModuleBuilder, - val cb: ClassBuilder[C] -) extends WrappedEmitModuleBuilder { self => +final class EmitClassBuilder[C](val emodb: EmitModuleBuilder, val cb: ClassBuilder[C]) + extends WrappedEmitModuleBuilder { self => // wrapped ClassBuilder methods def className: String = cb.className @@ -266,56 +339,68 @@ class EmitClassBuilder[C]( def newStaticField[T: TypeInfo](name: String): StaticField[T] = cb.newStaticField[T](name) - def newStaticField[T: TypeInfo](name: String, init: Code[T]): StaticField[T] = cb.newStaticField[T](name, init) + def newStaticField[T: TypeInfo](name: String, init: Code[T]): StaticField[T] = + cb.newStaticField[T](name, init) def genField[T: TypeInfo](baseName: String): Field[T] = cb.genField(baseName) - def genFieldThisRef[T: TypeInfo](name: String = null): ThisFieldRef[T] = cb.genFieldThisRef[T](name) + def getField[T: TypeInfo](name: String): Field[T] = cb.getField(name) - def genLazyFieldThisRef[T: TypeInfo](setup: Code[T], name: String = null): Value[T] = cb.genLazyFieldThisRef(setup, name) + def genFieldThisRef[T: TypeInfo](name: String = null): ThisFieldRef[T] = + cb.genFieldThisRef[T](name) - def getOrDefineLazyField[T: TypeInfo](setup: Code[T], id: Any): Value[T] = cb.getOrDefineLazyField(setup, id) + def genLazyFieldThisRef[T: TypeInfo](setup: Code[T], name: String = null): Value[T] = + cb.genLazyFieldThisRef(setup, name) + + 
def getOrDefineLazyField[T: TypeInfo](setup: Code[T], id: Any): Value[T] = + cb.getOrDefineLazyField(setup, id) def fieldBuilder: SettableBuilder = cb.fieldBuilder - def result(writeIRs: Boolean, print: Option[PrintWriter] = None): (HailClassLoader) => C = cb.result(writeIRs, print) + def result(writeIRs: Boolean, print: Option[PrintWriter] = None): (HailClassLoader) => C = + cb.result(writeIRs, print) // EmitClassBuilder methods - def newPSettable(sb: SettableBuilder, st: SType, name: String = null): SSettable = SSettable(sb, st, name) + def newPSettable(sb: SettableBuilder, st: SType, name: String = null): SSettable = + SSettable(sb, st, name) def newPField(st: SType): SSettable = newPSettable(fieldBuilder, st) def newPField(name: String, st: SType): SSettable = newPSettable(fieldBuilder, st, name) def newEmitField(st: SType, required: Boolean): EmitSettable = - new EmitSettable(if (required) None else Some(genFieldThisRef[Boolean]("emitfield_missing")), newPField(st)) + new EmitSettable( + if (required) None else Some(genFieldThisRef[Boolean]("emitfield_missing")), + newPField(st), + ) - def newEmitField(name: String, emitType: EmitType): EmitSettable = newEmitField(name, emitType.st, emitType.required) + def newEmitField(name: String, emitType: EmitType): EmitSettable = + newEmitField(name, emitType.st, emitType.required) def newEmitField(name: String, st: SType, required: Boolean): EmitSettable = - new EmitSettable(if (required) None else Some(genFieldThisRef[Boolean](name + "_missing")), newPField(name, st)) - - private[this] val typMap: mutable.Map[Type, Value[_ <: Type]] = - mutable.Map() - - private[this] val pTypeMap: mutable.Map[PType, Value[_ <: PType]] = mutable.Map() + new EmitSettable( + if (required) None else Some(genFieldThisRef[Boolean](name + "_missing")), + newPField(name, st), + ) private[this] type CompareMapKey = (SType, SType) + private[this] val memoizedComparisons: mutable.Map[CompareMapKey, CodeOrdering] = mutable.Map[CompareMapKey, CodeOrdering]() - - def numTypes: Int = typMap.size - private[this] val decodedLiteralsField = genFieldThisRef[Array[Long]]("decoded_lits") def literalsArray(): Value[Array[Long]] = decodedLiteralsField - def setLiteralsArray(cb: EmitCodeBuilder, arr: Code[Array[Long]]): Unit = cb.assign(decodedLiteralsField, arr) + def setLiteralsArray(cb: EmitCodeBuilder, arr: Code[Array[Long]]): Unit = + cb.assign(decodedLiteralsField, arr) lazy val partitionRegion: Settable[Region] = genFieldThisRef[Region]("partitionRegion") - private[this] lazy val _taskContext: Settable[HailTaskContext] = genFieldThisRef[HailTaskContext]("taskContext") + + private[this] lazy val _taskContext: Settable[HailTaskContext] = + genFieldThisRef[HailTaskContext]("taskContext") + private[this] lazy val poolField: Settable[RegionPool] = genFieldThisRef[RegionPool]() def addLiteral(cb: EmitCodeBuilder, v: Any, t: VirtualTypeWithReq): SValue = { @@ -336,30 +421,47 @@ class EmitClassBuilder[C]( val spec = TypedCodecSpec(litType, BufferSpec.wireSpec) cb.addInterface(typeInfo[FunctionWithLiterals].iname) - val mb2 = newEmitMethod("addAndDecodeLiterals", FastSeq[ParamType](typeInfo[Array[AnyRef]]), typeInfo[Unit]) + val mb2 = newEmitMethod( + "addAndDecodeLiterals", + FastSeq[ParamType](typeInfo[Array[AnyRef]]), + typeInfo[Unit], + ) mb2.voidWithBuilder { cb => val allEncodedFields = mb2.getCodeParam[Array[AnyRef]](1) - cb.assign(decodedLiteralsField, Code.newArray[Long](literals.length + preEncodedLiterals.length)) - - val ib = cb.newLocal[InputBuffer]("ib", - 
spec.buildCodeInputBuffer(Code.newInstance[ByteArrayInputStream, Array[Byte]](Code.checkcast[Array[Byte]](allEncodedFields(0))))) + cb.assign( + decodedLiteralsField, + Code.newArray[Long](literals.length + preEncodedLiterals.length), + ) + + val ib = cb.newLocal[InputBuffer]( + "ib", + spec.buildCodeInputBuffer(Code.newInstance[ByteArrayInputStream, Array[Byte]]( + Code.checkcast[Array[Byte]](allEncodedFields(0)) + )), + ) val lits = spec.encodedType.buildDecoder(spec.encodedVirtualType, this) .apply(cb, partitionRegion, ib) .asBaseStruct - literals.zipWithIndex.foreach { case ((t, _, pt, arrIdx), i) => + literals.zipWithIndex.foreach { case ((_, _, pt, arrIdx), i) => lits.loadField(cb, i) - .consume(cb, + .consume( + cb, cb._fatal("expect non-missing literals!"), - { pc => - cb += (decodedLiteralsField(arrIdx) = pt.store(cb, partitionRegion, pc, false)) - }) + pc => cb += (decodedLiteralsField(arrIdx) = pt.store(cb, partitionRegion, pc, false)), + ) } // Handle the pre-encoded literals, which only need to be decoded. preEncodedLiterals.zipWithIndex.foreach { case ((encLit, pt, arrIdx), index) => val spec = encLit.codec - cb.assign(ib, spec.buildCodeInputBuffer(Code.newInstance[ArrayOfByteArrayInputStream, Array[Array[Byte]]](Code.checkcast[Array[Array[Byte]]](allEncodedFields(index + 1))))) + cb.assign( + ib, + spec.buildCodeInputBuffer(Code.newInstance[ + ArrayOfByteArrayInputStream, + Array[Array[Byte]], + ](Code.checkcast[Array[Array[Byte]]](allEncodedFields(index + 1)))), + ) val decodedValue = encLit.codec.encodedType.buildDecoder(encLit.typ, this) .apply(cb, partitionRegion, ib) cb += (decodedLiteralsField(arrIdx) = pt.store(cb, partitionRegion, decodedValue, false)) @@ -381,7 +483,11 @@ class EmitClassBuilder[C]( Array[AnyRef](baos.toByteArray) ++ preEncodedLiterals.map(_._1.value.ba) } - private[this] var _mods: BoxedArrayBuilder[(String, (HailClassLoader, FS, HailTaskContext, Region) => AsmFunction3[Region, Array[Byte], Array[Byte], Array[Byte]])] = new BoxedArrayBuilder() + private[this] var _mods: BoxedArrayBuilder[( + String, + (HailClassLoader, FS, HailTaskContext, Region) => AsmFunction3[Region, Array[Byte], Array[Byte], Array[Byte]], + )] = new BoxedArrayBuilder() + private[this] var _backendField: Settable[BackendUtils] = _ private[this] var _aggSigs: Array[agg.AggStateSig] = _ @@ -403,19 +509,32 @@ class EmitClassBuilder[C]( _aggSerialized = genFieldThisRef[Array[Array[Byte]]]("agg_serialized") val newF = newEmitMethod("newAggState", FastSeq[ParamType](typeInfo[Region]), typeInfo[Unit]) - val setF = newEmitMethod("setAggState", FastSeq[ParamType](typeInfo[Region], typeInfo[Long]), typeInfo[Unit]) + val setF = newEmitMethod( + "setAggState", + FastSeq[ParamType](typeInfo[Region], typeInfo[Long]), + typeInfo[Unit], + ) val getF = newEmitMethod("getAggOffset", FastSeq[ParamType](), typeInfo[Long]) val storeF = newEmitMethod("storeAggsToRegion", FastSeq[ParamType](), typeInfo[Unit]) - val setNSer = newEmitMethod("setNumSerialized", FastSeq[ParamType](typeInfo[Int]), typeInfo[Unit]) - val setSer = newEmitMethod("setSerializedAgg", FastSeq[ParamType](typeInfo[Int], typeInfo[Array[Byte]]), typeInfo[Unit]) - val getSer = newEmitMethod("getSerializedAgg", FastSeq[ParamType](typeInfo[Int]), typeInfo[Array[Byte]]) + val setNSer = + newEmitMethod("setNumSerialized", FastSeq[ParamType](typeInfo[Int]), typeInfo[Unit]) + val setSer = newEmitMethod( + "setSerializedAgg", + FastSeq[ParamType](typeInfo[Int], typeInfo[Array[Byte]]), + typeInfo[Unit], + ) + val getSer = + 
newEmitMethod("getSerializedAgg", FastSeq[ParamType](typeInfo[Int]), typeInfo[Array[Byte]]) val (nfcode, states) = EmitCodeBuilder.scoped(newF) { cb => val states = agg.StateTuple(aggSigs.map(a => agg.AggStateSig.getState(a, cb.emb.ecb)).toArray) _aggState = new agg.TupleAggregatorState(this, states, _aggRegion, _aggOff) cb += (_aggRegion := newF.getCodeParam[Region](1)) cb += _aggState.topRegion.setNumParents(aggSigs.length) - cb += (_aggOff := _aggRegion.load().allocate(states.storageType.alignment, states.storageType.byteSize)) + cb += (_aggOff := _aggRegion.load().allocate( + states.storageType.alignment, + states.storageType.byteSize, + )) states.createStates(cb) _aggState.newState(cb) @@ -434,17 +553,18 @@ class EmitClassBuilder[C]( } getF.emitWithBuilder { cb => - storeF.invokeCode[Unit](cb) + cb.invokeVoid(storeF, cb.this_) _aggOff } - storeF.voidWithBuilder { cb => - _aggState.store(cb) - } + storeF.voidWithBuilder(cb => _aggState.store(cb)) setNSer.emit(_aggSerialized := Code.newArray[Array[Byte]](setNSer.getCodeParam[Int](1))) - setSer.emit(_aggSerialized.load().update(setSer.getCodeParam[Int](1), setSer.getCodeParam[Array[Byte]](2))) + setSer.emit(_aggSerialized.load().update( + setSer.getCodeParam[Int](1), + setSer.getCodeParam[Array[Byte]](2), + )) getSer.emit(_aggSerialized.load()(getSer.getCodeParam[Int](1))) @@ -468,30 +588,39 @@ class EmitClassBuilder[C]( _aggSerialized.load().update(i, Code._null[Array[Byte]]) } - def runMethodWithHailExceptionHandler(mname: String): Code[(String, java.lang.Integer)] = { - Code.invokeScalaObject2[AnyRef, String, (String, java.lang.Integer)](CodeExceptionHandler.getClass, + def runMethodWithHailExceptionHandler(mname: String): Code[(String, java.lang.Integer)] = + Code.invokeScalaObject2[AnyRef, String, (String, java.lang.Integer)]( + CodeExceptionHandler.getClass, "handleUserException", - cb._this.get.asInstanceOf[Code[AnyRef]], mname) - } + cb.this_.get.asInstanceOf[Code[AnyRef]], + mname, + ) def backend(): Code[BackendUtils] = { if (_backendField == null) { cb.addInterface(typeInfo[FunctionWithBackend].iname) val backendField = genFieldThisRef[BackendUtils]() - val mb = newEmitMethod("setBackend", FastSeq[ParamType](typeInfo[BackendUtils]), typeInfo[Unit]) + val mb = + newEmitMethod("setBackend", FastSeq[ParamType](typeInfo[BackendUtils]), typeInfo[Unit]) mb.emit(backendField := mb.getCodeParam[BackendUtils](1)) _backendField = backendField } _backendField } - def pool(): Value[RegionPool] = { + def pool(): Value[RegionPool] = poolField - } - def addModule(name: String, mod: (HailClassLoader, FS, HailTaskContext, Region) => AsmFunction3[Region, Array[Byte], Array[Byte], Array[Byte]]): Unit = { + def addModule( + name: String, + mod: (HailClassLoader, FS, HailTaskContext, Region) => AsmFunction3[ + Region, + Array[Byte], + Array[Byte], + Array[Byte], + ], + ): Unit = _mods += name -> mod - } def getHailClassLoader: Code[HailClassLoader] = emodb.getHailClassLoader @@ -499,32 +628,30 @@ class EmitClassBuilder[C]( def getTaskContext: Value[HailTaskContext] = _taskContext - def setObjects(cb: EmitCodeBuilder, objects: Code[Array[AnyRef]]): Unit = modb.setObjects(cb, objects) + def setObjects(cb: EmitCodeBuilder, objects: Code[Array[AnyRef]]): Unit = + modb.setObjects(cb, objects) - def getObject[T <: AnyRef : TypeInfo](obj: T): Code[T] = modb.getObject(obj) + def getObject[T <: AnyRef: TypeInfo](obj: T): Code[T] = modb.getObject(obj) def makeAddObjects(): Array[AnyRef] = { if (emodb.modb._objects == null) null else { 
cb.addInterface(typeInfo[FunctionWithObjects].iname) - val mb = newEmitMethod("setObjects", FastSeq[ParamType](typeInfo[Array[AnyRef]]), typeInfo[Unit]) + val mb = + newEmitMethod("setObjects", FastSeq[ParamType](typeInfo[Array[AnyRef]]), typeInfo[Unit]) mb.voidWithBuilder(cb => emodb.setObjects(cb, mb.getCodeParam[Array[AnyRef]](1))) emodb.modb._objects.result() } } - def getPType[T <: PType : TypeInfo](t: T): Code[T] = emodb.getObject(t) + def getPType[T <: PType: TypeInfo](t: T): Code[T] = emodb.getObject(t) - def getType[T <: Type : TypeInfo](t: T): Code[T] = emodb.getObject(t) + def getType[T <: Type: TypeInfo](t: T): Code[T] = emodb.getObject(t) - def getOrdering(t1: SType, - t2: SType, - sortOrder: SortOrder = Ascending - ): CodeOrdering = { - val baseOrd = memoizedComparisons.getOrElseUpdate((t1, t2), { - CodeOrdering.makeOrdering(t1, t2, this) - }) + def getOrdering(t1: SType, t2: SType, sortOrder: SortOrder = Ascending): CodeOrdering = { + val baseOrd = + memoizedComparisons.getOrElseUpdate((t1, t2), CodeOrdering.makeOrdering(t1, t2, this)) sortOrder match { case Ascending => baseOrd case Descending => baseOrd.reverse @@ -535,12 +662,11 @@ class EmitClassBuilder[C]( t1: SType, t2: SType, sortOrder: SortOrder, - op: CodeOrdering.Op + op: CodeOrdering.Op, ): CodeOrdering.F[op.ReturnType] = { val ord = getOrdering(t1, t2, sortOrder) { (cb: EmitCodeBuilder, v1: EmitValue, v2: EmitValue) => - val r: Code[_] = op match { case CodeOrdering.Compare(missingEqual) => ord.compare(cb, v1, v2, missingEqual) case CodeOrdering.Equiv(missingEqual) => ord.equiv(cb, v1, v2, missingEqual) @@ -550,7 +676,10 @@ class EmitClassBuilder[C]( case CodeOrdering.Gteq(missingEqual) => ord.gteq(cb, v1, v2, missingEqual) case CodeOrdering.Neq(missingEqual) => !ord.equiv(cb, v1, v2, missingEqual) } - cb.memoize[op.ReturnType](coerce[op.ReturnType](r))(op.rtti, implicitly[op.ReturnType =!= Unit]) + cb.memoize[op.ReturnType](coerce[op.ReturnType](r))( + op.rtti, + implicitly[op.ReturnType =!= Unit], + ) } } @@ -558,9 +687,9 @@ class EmitClassBuilder[C]( t1: SBaseStruct, t2: SBaseStruct, sortFields: Array[SortField], - op: CodeOrdering.Op + op: CodeOrdering.Op, ): CodeOrdering.F[op.ReturnType] = { - { (cb: EmitCodeBuilder, v1: EmitValue, v2: EmitValue) => + (cb: EmitCodeBuilder, v1: EmitValue, v2: EmitValue) => val ord = StructOrdering.make(t1, t2, cb.emb.ecb, sortFields.map(_.sortOrder)) val r: Code[_] = op match { @@ -570,25 +699,29 @@ class EmitClassBuilder[C]( case CodeOrdering.StructGteq(missingEqual) => ord.gteq(cb, v1, v2, missingEqual) case CodeOrdering.StructCompare(missingEqual) => ord.compare(cb, v1, v2, missingEqual) } - cb.memoize[op.ReturnType](coerce[op.ReturnType](r))(op.rtti, implicitly[op.ReturnType =!= Unit]) - } + cb.memoize[op.ReturnType](coerce[op.ReturnType](r))( + op.rtti, + implicitly[op.ReturnType =!= Unit], + ) } // derived functions def getOrderingFunction(t: SType, op: CodeOrdering.Op): CodeOrdering.F[op.ReturnType] = getOrderingFunction(t, t, sortOrder = Ascending, op) - def getOrderingFunction(t1: SType, t2: SType, op: CodeOrdering.Op): CodeOrdering.F[op.ReturnType] = + def getOrderingFunction(t1: SType, t2: SType, op: CodeOrdering.Op) + : CodeOrdering.F[op.ReturnType] = getOrderingFunction(t1, t2, sortOrder = Ascending, op) def getOrderingFunction( t: SType, op: CodeOrdering.Op, - sortOrder: SortOrder + sortOrder: SortOrder, ): CodeOrdering.F[op.ReturnType] = getOrderingFunction(t, t, sortOrder, op) - private def getCodeArgsInfo(argsInfo: IndexedSeq[ParamType], returnInfo: 
ParamType): (IndexedSeq[TypeInfo[_]], TypeInfo[_], AsmTuple[_]) = { + private def getCodeArgsInfo(argsInfo: IndexedSeq[ParamType], returnInfo: ParamType) + : (IndexedSeq[TypeInfo[_]], TypeInfo[_], AsmTuple[_]) = { val codeArgsInfo = argsInfo.flatMap { case CodeParamType(ti) => FastSeq(ti) case t: EmitParamType => t.valueTupleTypes @@ -613,33 +746,60 @@ class EmitClassBuilder[C]( (codeArgsInfo, codeReturnInfo, asmTuple) } - def newEmitMethod(name: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType): EmitMethodBuilder[C] = { - val (codeArgsInfo, codeReturnInfo, asmTuple) = getCodeArgsInfo(argsInfo, returnInfo) + def ctor: EmitMethodBuilder[C] = + new EmitMethodBuilder[C](FastSeq(), CodeParamType(UnitInfo), this, cb.ctor, null) - new EmitMethodBuilder[C](argsInfo, returnInfo, this, cb.newMethod(name, codeArgsInfo, codeReturnInfo), asmTuple) - } + def emitInitI(f: EmitCodeBuilder => Unit): Unit = + ctor.cb.emitInit(EmitCodeBuilder.scopedVoid(ctor)(f)) + + def newEmitMethod(name: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType) + : EmitMethodBuilder[C] = { + val (codeArgsInfo, codeReturnInfo, asmTuple) = getCodeArgsInfo(argsInfo, returnInfo) - def newEmitMethod(name: String, argsInfo: IndexedSeq[MaybeGenericTypeInfo[_]], returnInfo: MaybeGenericTypeInfo[_]): EmitMethodBuilder[C] = { new EmitMethodBuilder[C]( - argsInfo.map(ai => CodeParamType(ai.base)), CodeParamType(returnInfo.base), - this, cb.newMethod(name, argsInfo, returnInfo), asmTuple = null) + argsInfo, + returnInfo, + this, + cb.newMethod(name, codeArgsInfo, codeReturnInfo), + asmTuple, + ) } - def newStaticEmitMethod(name: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType): EmitMethodBuilder[C] = { + def newEmitMethod( + name: String, + argsInfo: IndexedSeq[MaybeGenericTypeInfo[_]], + returnInfo: MaybeGenericTypeInfo[_], + ): EmitMethodBuilder[C] = + new EmitMethodBuilder[C]( + argsInfo.map(ai => CodeParamType(ai.base)), + CodeParamType(returnInfo.base), + this, + cb.newMethod(name, argsInfo, returnInfo), + asmTuple = null, + ) + + def newStaticEmitMethod(name: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType) + : EmitMethodBuilder[C] = { val (codeArgsInfo, codeReturnInfo, asmTuple) = getCodeArgsInfo(argsInfo, returnInfo) - new EmitMethodBuilder[C](argsInfo, returnInfo, this, + new EmitMethodBuilder[C]( + argsInfo, + returnInfo, + this, cb.newStaticMethod(name, codeArgsInfo, codeReturnInfo), - asmTuple) + asmTuple, + ) } - val rngs: BoxedArrayBuilder[(Settable[IRRandomness], Code[IRRandomness])] = new BoxedArrayBuilder() + val rngs: BoxedArrayBuilder[(Settable[IRRandomness], Code[IRRandomness])] = + new BoxedArrayBuilder() var threefryRNG: Option[(Settable[ThreefryRandomEngine], Code[ThreefryRandomEngine])] = None def makeAddPartitionRegion(): Unit = { cb.addInterface(typeInfo[FunctionWithPartitionRegion].iname) - val mb = newEmitMethod("addPartitionRegion", FastSeq[ParamType](typeInfo[Region]), typeInfo[Unit]) + val mb = + newEmitMethod("addPartitionRegion", FastSeq[ParamType](typeInfo[Region]), typeInfo[Unit]) mb.emit(partitionRegion := mb.getCodeParam[Region](1)) val mb2 = newEmitMethod("setPool", FastSeq[ParamType](typeInfo[RegionPool]), typeInfo[Unit]) mb2.emit(poolField := mb2.getCodeParam[RegionPool](1)) @@ -647,42 +807,59 @@ class EmitClassBuilder[C]( def makeAddHailClassLoader(): Unit = { cb.addInterface(typeInfo[FunctionWithHailClassLoader].iname) - val mb = newEmitMethod("addHailClassLoader", FastSeq[ParamType](typeInfo[HailClassLoader]), typeInfo[Unit]) - 
mb.voidWithBuilder { cb => - emodb.setHailClassLoader(cb, mb.getCodeParam[HailClassLoader](1)) - } + val mb = newEmitMethod( + "addHailClassLoader", + FastSeq[ParamType](typeInfo[HailClassLoader]), + typeInfo[Unit], + ) + mb.voidWithBuilder(cb => emodb.setHailClassLoader(cb, mb.getCodeParam[HailClassLoader](1))) } def makeAddFS(): Unit = { cb.addInterface(typeInfo[FunctionWithFS].iname) val mb = newEmitMethod("addFS", FastSeq[ParamType](typeInfo[FS]), typeInfo[Unit]) - mb.voidWithBuilder { cb => - emodb.setFS(cb, mb.getCodeParam[FS](1)) - } + mb.voidWithBuilder(cb => emodb.setFS(cb, mb.getCodeParam[FS](1))) } def makeAddTaskContext(): Unit = { cb.addInterface(typeInfo[FunctionWithTaskContext].iname) - val mb = newEmitMethod("addTaskContext", FastSeq[ParamType](typeInfo[HailTaskContext]), typeInfo[Unit]) - mb.voidWithBuilder { cb => - cb.assign(_taskContext, mb.getCodeParam[HailTaskContext](1)) - } + val mb = + newEmitMethod("addTaskContext", FastSeq[ParamType](typeInfo[HailTaskContext]), typeInfo[Unit]) + mb.voidWithBuilder(cb => cb.assign(_taskContext, mb.getCodeParam[HailTaskContext](1))) } def makeAddReferenceGenomes(): Unit = { cb.addInterface(typeInfo[FunctionWithReferences].iname) - val mb = newEmitMethod("addReferenceGenomes", FastSeq[ParamType](typeInfo[Array[ReferenceGenome]]), typeInfo[Unit]) + val mb = newEmitMethod( + "addReferenceGenomes", + FastSeq[ParamType](typeInfo[Array[ReferenceGenome]]), + typeInfo[Unit], + ) mb.voidWithBuilder { cb => val rgFields = emodb.referenceGenomeFields() val rgs = mb.getCodeParam[Array[ReferenceGenome]](1) - cb.if_(rgs.length().cne(const(rgFields.length)), cb._fatal("Invalid number of references, expected ", rgFields.length.toString, " got ", rgs.length().toS)) + cb.if_( + rgs.length().cne(const(rgFields.length)), + cb._fatal( + "Invalid number of references, expected ", + rgFields.length.toString, + " got ", + rgs.length().toS, + ), + ) for ((fld, i) <- rgFields.zipWithIndex) { cb += rgs(i).invoke[String, FS, Unit]("heal", ctx.localTmpdir, getFS) cb += fld.put(rgs(i)) } Option(emodb._rgMapField).foreach { fld => - cb.assign(fld, Code.invokeStatic1[ReferenceGenome, Array[ReferenceGenome], Map[String, ReferenceGenome]]("getMapFromArray", rgs)) + cb.assign( + fld, + Code.invokeStatic1[ReferenceGenome, Array[ReferenceGenome], Map[String, ReferenceGenome]]( + "getMapFromArray", + rgs, + ), + ) } } } @@ -691,21 +868,24 @@ class EmitClassBuilder[C]( cb.addInterface(typeInfo[FunctionWithSeededRandomness].iname) val initialized = genFieldThisRef[Boolean]("initialized") - val mb = newEmitMethod("setPartitionIndex", IndexedSeq[ParamType](typeInfo[Int]), typeInfo[Unit]) + val mb = + newEmitMethod("setPartitionIndex", IndexedSeq[ParamType](typeInfo[Int]), typeInfo[Unit]) val rngFields = rngs.result() mb.voidWithBuilder { cb => - cb.if_(!initialized, { - rngFields.foreach { case (field, init) => - cb.assign(field, init) - } + cb.if_( + !initialized, { + rngFields.foreach { case (field, init) => + cb.assign(field, init) + } - threefryRNG.foreach { case (field, init) => - cb.assign(field, init) - } + threefryRNG.foreach { case (field, init) => + cb.assign(field, init) + } - cb.assign(initialized, true) - }) + cb.assign(initialized, true) + }, + ) rngFields.foreach { case (field, _) => cb += field.invoke[Int, Unit]("reset", mb.getCodeParam[Int](1)) @@ -725,14 +905,16 @@ class EmitClassBuilder[C]( case None => val rngField = genFieldThisRef[ThreefryRandomEngine]() val rngInit = Code.invokeScalaObject0[ThreefryRandomEngine]( - ThreefryRandomEngine.getClass, 
"apply") + ThreefryRandomEngine.getClass, + "apply", + ) threefryRNG = Some(rngField -> rngInit) rngField } } def resultWithIndex(print: Option[PrintWriter] = None) - : (HailClassLoader, FS, HailTaskContext, Region) => C = { + : (HailClassLoader, FS, HailTaskContext, Region) => C = { makeRNGs() makeAddPartitionRegion() makeAddHailClassLoader() @@ -757,14 +939,15 @@ class EmitClassBuilder[C]( else null - val nSerializedAggs = _nSerialized val useBackend = _backendField != null val backend = if (useBackend) new BackendUtils(_mods.result()) else null - assert(TaskContext.get() == null, - "FunctionBuilder emission should happen on master, but happened on worker") + assert( + TaskContext.get() == null, + "FunctionBuilder emission should happen on master, but happened on worker", + ) val n = cb.className.replace("/", ".") val classesBytes = modb.classesBytes(ctx.shouldWriteIRFiles(), print) @@ -808,17 +991,36 @@ class EmitClassBuilder[C]( private[this] val methodMemo: mutable.Map[Any, EmitMethodBuilder[C]] = mutable.Map() def getOrGenEmitMethod( - baseName: String, key: Any, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType - )(body: EmitMethodBuilder[C] => Unit): EmitMethodBuilder[C] = { - methodMemo.getOrElse(key, { - val mb = genEmitMethod(baseName, argsInfo, returnInfo) - methodMemo(key) = mb - body(mb) - mb - }) + baseName: String, + key: Any, + argsInfo: IndexedSeq[ParamType], + returnInfo: ParamType, + )( + body: EmitMethodBuilder[C] => Unit + ): EmitMethodBuilder[C] = + methodMemo.getOrElse( + key, { + val mb = genEmitMethod(baseName, argsInfo, returnInfo) + methodMemo(key) = mb + body(mb) + mb + }, + ) + + def defineEmitMethod( + name: String, + paramTys: IndexedSeq[ParamType], + retTy: ParamType, + )( + body: EmitMethodBuilder[C] => Unit + ): EmitMethodBuilder[C] = { + val mb = newEmitMethod(name, paramTys, retTy) + body(mb) + mb } - def genEmitMethod(baseName: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType): EmitMethodBuilder[C] = + def genEmitMethod(baseName: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType) + : EmitMethodBuilder[C] = newEmitMethod(genName("m", baseName), argsInfo, returnInfo) def genEmitMethod[R: TypeInfo](baseName: String): EmitMethodBuilder[C] = @@ -830,16 +1032,41 @@ class EmitClassBuilder[C]( def genEmitMethod[A: TypeInfo, B: TypeInfo, R: TypeInfo](baseName: String): EmitMethodBuilder[C] = genEmitMethod(baseName, FastSeq[ParamType](typeInfo[A], typeInfo[B]), typeInfo[R]) - def genEmitMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, R: TypeInfo](baseName: String): EmitMethodBuilder[C] = - genEmitMethod(baseName, FastSeq[ParamType](typeInfo[A1], typeInfo[A2], typeInfo[A3]), typeInfo[R]) - - def genEmitMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, R: TypeInfo](baseName: String): EmitMethodBuilder[C] = - genEmitMethod(baseName, FastSeq[ParamType](typeInfo[A1], typeInfo[A2], typeInfo[A3], typeInfo[A4]), typeInfo[R]) - - def genEmitMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, A5: TypeInfo, R: TypeInfo](baseName: String): EmitMethodBuilder[C] = - genEmitMethod(baseName, FastSeq[ParamType](typeInfo[A1], typeInfo[A2], typeInfo[A3], typeInfo[A4], typeInfo[A5]), typeInfo[R]) - - def genStaticEmitMethod(baseName: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType): EmitMethodBuilder[C] = + def genEmitMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, R: TypeInfo](baseName: String) + : EmitMethodBuilder[C] = + genEmitMethod( + baseName, + FastSeq[ParamType](typeInfo[A1], typeInfo[A2], 
typeInfo[A3]), + typeInfo[R], + ) + + def genEmitMethod[A1: TypeInfo, A2: TypeInfo, A3: TypeInfo, A4: TypeInfo, R: TypeInfo]( + baseName: String + ): EmitMethodBuilder[C] = + genEmitMethod( + baseName, + FastSeq[ParamType](typeInfo[A1], typeInfo[A2], typeInfo[A3], typeInfo[A4]), + typeInfo[R], + ) + + def genEmitMethod[ + A1: TypeInfo, + A2: TypeInfo, + A3: TypeInfo, + A4: TypeInfo, + A5: TypeInfo, + R: TypeInfo, + ]( + baseName: String + ): EmitMethodBuilder[C] = + genEmitMethod( + baseName, + FastSeq[ParamType](typeInfo[A1], typeInfo[A2], typeInfo[A3], typeInfo[A4], typeInfo[A5]), + typeInfo[R], + ) + + def genStaticEmitMethod(baseName: String, argsInfo: IndexedSeq[ParamType], returnInfo: ParamType) + : EmitMethodBuilder[C] = newStaticEmitMethod(genName("sm", baseName), argsInfo, returnInfo) def getUnsafeReader(path: Code[String], checkCodec: Code[Boolean]): Code[InputStream] = @@ -850,47 +1077,150 @@ class EmitClassBuilder[C]( } object EmitFunctionBuilder { - def apply[F]( - ctx: ExecuteContext, baseName: String, paramTypes: IndexedSeq[ParamType], returnType: ParamType, sourceFile: Option[String] = None - )(implicit fti: TypeInfo[F]): EmitFunctionBuilder[F] = { + def apply[F: TypeInfo]( + ctx: ExecuteContext, + baseName: String, + paramTypes: IndexedSeq[ParamType], + returnType: ParamType, + sourceFile: Option[String] = None, + ): EmitFunctionBuilder[F] = { val modb = new EmitModuleBuilder(ctx, new ModuleBuilder()) val cb = modb.genEmitClass[F](baseName, sourceFile) val apply = cb.newEmitMethod("apply", paramTypes, returnType) new EmitFunctionBuilder(apply) } - def apply[F]( - ctx: ExecuteContext, baseName: String, argInfo: IndexedSeq[MaybeGenericTypeInfo[_]], returnInfo: MaybeGenericTypeInfo[_] - )(implicit fti: TypeInfo[F]): EmitFunctionBuilder[F] = { + def apply[F: TypeInfo]( + ctx: ExecuteContext, + baseName: String, + argInfo: IndexedSeq[MaybeGenericTypeInfo[_]], + returnInfo: MaybeGenericTypeInfo[_], + ): EmitFunctionBuilder[F] = { val modb = new EmitModuleBuilder(ctx, new ModuleBuilder()) val cb = modb.genEmitClass[F](baseName) - val apply = cb.newEmitMethod("apply", argInfo, returnInfo) + val apply = cb.newEmitMethod("apply", argInfo, returnInfo) new EmitFunctionBuilder(apply) } - def apply[R: TypeInfo](ctx: ExecuteContext, baseName: String): EmitFunctionBuilder[AsmFunction0[R]] = - EmitFunctionBuilder[AsmFunction0[R]](ctx, baseName, FastSeq[MaybeGenericTypeInfo[_]](), GenericTypeInfo[R]) - - def apply[A: TypeInfo, R: TypeInfo](ctx: ExecuteContext, baseName: String): EmitFunctionBuilder[AsmFunction1[A, R]] = - EmitFunctionBuilder[AsmFunction1[A, R]](ctx, baseName, Array(GenericTypeInfo[A]), GenericTypeInfo[R]) - - def apply[A: TypeInfo, B: TypeInfo, R: TypeInfo](ctx: ExecuteContext, baseName: String): EmitFunctionBuilder[AsmFunction2[A, B, R]] = - EmitFunctionBuilder[AsmFunction2[A, B, R]](ctx, baseName, Array(GenericTypeInfo[A], GenericTypeInfo[B]), GenericTypeInfo[R]) - - def apply[A: TypeInfo, B: TypeInfo, C: TypeInfo, R: TypeInfo](ctx: ExecuteContext, baseName: String): EmitFunctionBuilder[AsmFunction3[A, B, C, R]] = - EmitFunctionBuilder[AsmFunction3[A, B, C, R]](ctx, baseName, Array(GenericTypeInfo[A], GenericTypeInfo[B], GenericTypeInfo[C]), GenericTypeInfo[R]) - - def apply[A: TypeInfo, B: TypeInfo, C: TypeInfo, D: TypeInfo, R: TypeInfo](ctx: ExecuteContext, baseName: String): EmitFunctionBuilder[AsmFunction4[A, B, C, D, R]] = - EmitFunctionBuilder[AsmFunction4[A, B, C, D, R]](ctx, baseName, Array(GenericTypeInfo[A], GenericTypeInfo[B], GenericTypeInfo[C], 
GenericTypeInfo[D]), GenericTypeInfo[R]) - - def apply[A: TypeInfo, B: TypeInfo, C: TypeInfo, D: TypeInfo, E: TypeInfo, R: TypeInfo](ctx: ExecuteContext, baseName: String): EmitFunctionBuilder[AsmFunction5[A, B, C, D, E, R]] = - EmitFunctionBuilder[AsmFunction5[A, B, C, D, E, R]](ctx, baseName, Array(GenericTypeInfo[A], GenericTypeInfo[B], GenericTypeInfo[C], GenericTypeInfo[D], GenericTypeInfo[E]), GenericTypeInfo[R]) - - def apply[A: TypeInfo, B: TypeInfo, C: TypeInfo, D: TypeInfo, E: TypeInfo, F: TypeInfo, R: TypeInfo](ctx: ExecuteContext, baseName: String): EmitFunctionBuilder[AsmFunction6[A, B, C, D, E, F, R]] = - EmitFunctionBuilder[AsmFunction6[A, B, C, D, E, F, R]](ctx, baseName, Array(GenericTypeInfo[A], GenericTypeInfo[B], GenericTypeInfo[C], GenericTypeInfo[D], GenericTypeInfo[E], GenericTypeInfo[F]), GenericTypeInfo[R]) - - def apply[A: TypeInfo, B: TypeInfo, C: TypeInfo, D: TypeInfo, E: TypeInfo, F: TypeInfo, G: TypeInfo, R: TypeInfo](ctx: ExecuteContext, baseName: String): EmitFunctionBuilder[AsmFunction7[A, B, C, D, E, F, G, R]] = - EmitFunctionBuilder[AsmFunction7[A, B, C, D, E, F, G, R]](ctx, baseName, Array(GenericTypeInfo[A], GenericTypeInfo[B], GenericTypeInfo[C], GenericTypeInfo[D], GenericTypeInfo[E], GenericTypeInfo[F], GenericTypeInfo[G]), GenericTypeInfo[R]) + def apply[R: TypeInfo](ctx: ExecuteContext, baseName: String) + : EmitFunctionBuilder[AsmFunction0[R]] = + EmitFunctionBuilder[AsmFunction0[R]]( + ctx, + baseName, + FastSeq[MaybeGenericTypeInfo[_]](), + GenericTypeInfo[R], + ) + + def apply[A: TypeInfo, R: TypeInfo](ctx: ExecuteContext, baseName: String) + : EmitFunctionBuilder[AsmFunction1[A, R]] = + EmitFunctionBuilder[AsmFunction1[A, R]]( + ctx, + baseName, + Array(GenericTypeInfo[A]), + GenericTypeInfo[R], + ) + + def apply[A: TypeInfo, B: TypeInfo, R: TypeInfo](ctx: ExecuteContext, baseName: String) + : EmitFunctionBuilder[AsmFunction2[A, B, R]] = + EmitFunctionBuilder[AsmFunction2[A, B, R]]( + ctx, + baseName, + Array(GenericTypeInfo[A], GenericTypeInfo[B]), + GenericTypeInfo[R], + ) + + def apply[A: TypeInfo, B: TypeInfo, C: TypeInfo, R: TypeInfo]( + ctx: ExecuteContext, + baseName: String, + ): EmitFunctionBuilder[AsmFunction3[A, B, C, R]] = + EmitFunctionBuilder[AsmFunction3[A, B, C, R]]( + ctx, + baseName, + Array(GenericTypeInfo[A], GenericTypeInfo[B], GenericTypeInfo[C]), + GenericTypeInfo[R], + ) + + def apply[A: TypeInfo, B: TypeInfo, C: TypeInfo, D: TypeInfo, R: TypeInfo]( + ctx: ExecuteContext, + baseName: String, + ): EmitFunctionBuilder[AsmFunction4[A, B, C, D, R]] = + EmitFunctionBuilder[AsmFunction4[A, B, C, D, R]]( + ctx, + baseName, + Array(GenericTypeInfo[A], GenericTypeInfo[B], GenericTypeInfo[C], GenericTypeInfo[D]), + GenericTypeInfo[R], + ) + + def apply[A: TypeInfo, B: TypeInfo, C: TypeInfo, D: TypeInfo, E: TypeInfo, R: TypeInfo]( + ctx: ExecuteContext, + baseName: String, + ): EmitFunctionBuilder[AsmFunction5[A, B, C, D, E, R]] = + EmitFunctionBuilder[AsmFunction5[A, B, C, D, E, R]]( + ctx, + baseName, + Array( + GenericTypeInfo[A], + GenericTypeInfo[B], + GenericTypeInfo[C], + GenericTypeInfo[D], + GenericTypeInfo[E], + ), + GenericTypeInfo[R], + ) + + def apply[ + A: TypeInfo, + B: TypeInfo, + C: TypeInfo, + D: TypeInfo, + E: TypeInfo, + F: TypeInfo, + R: TypeInfo, + ]( + ctx: ExecuteContext, + baseName: String, + ): EmitFunctionBuilder[AsmFunction6[A, B, C, D, E, F, R]] = + EmitFunctionBuilder[AsmFunction6[A, B, C, D, E, F, R]]( + ctx, + baseName, + Array( + GenericTypeInfo[A], + GenericTypeInfo[B], + GenericTypeInfo[C], + 
GenericTypeInfo[D], + GenericTypeInfo[E], + GenericTypeInfo[F], + ), + GenericTypeInfo[R], + ) + + def apply[ + A: TypeInfo, + B: TypeInfo, + C: TypeInfo, + D: TypeInfo, + E: TypeInfo, + F: TypeInfo, + G: TypeInfo, + R: TypeInfo, + ]( + ctx: ExecuteContext, + baseName: String, + ): EmitFunctionBuilder[AsmFunction7[A, B, C, D, E, F, G, R]] = + EmitFunctionBuilder[AsmFunction7[A, B, C, D, E, F, G, R]]( + ctx, + baseName, + Array( + GenericTypeInfo[A], + GenericTypeInfo[B], + GenericTypeInfo[C], + GenericTypeInfo[D], + GenericTypeInfo[E], + GenericTypeInfo[F], + GenericTypeInfo[G], + ), + GenericTypeInfo[R], + ) } trait FunctionWithObjects { @@ -901,7 +1231,8 @@ trait FunctionWithAggRegion { // Calls storeAggsToRegion, and returns the aggregator state offset in the top agg region def getAggOffset(): Long - // stores agg regions into the top agg region, so that all agg resources are referenced solely by that region + /* stores agg regions into the top agg region, so that all agg resources are referenced solely by + * that region */ def storeAggsToRegion(): Unit // Sets the function's agg container to the agg state at $offset, loads agg regions onto class @@ -950,10 +1281,9 @@ trait FunctionWithBackend { } object CodeExceptionHandler { - /** - * This method assumes that the method referred to by `methodName` - * is a -argument class method (only takes the class itself as an arg) - * which returns void. + + /** This method assumes that the method referred to by `methodName` is a -argument class method + * (only takes the class itself as an arg) which returns void. */ def handleUserException(obj: AnyRef, methodName: String): (String, java.lang.Integer) = { try { @@ -974,12 +1304,14 @@ class EmitMethodBuilder[C]( val emitReturnType: ParamType, val ecb: EmitClassBuilder[C], val mb: MethodBuilder[C], - private[ir] val asmTuple: AsmTuple[_] + private[ir] val asmTuple: AsmTuple[_], ) extends WrappedEmitClassBuilder[C] { private[this] val nCodeArgs = emitParamTypes.map(_.nCodes).sum + if (nCodeArgs > 255) - throw new RuntimeException(s"invalid method ${ mb.methodName }: ${ nCodeArgs } code arguments:" + - s"\n ${ emitParamTypes.map(p => s"${ p.nCodes } - $p").mkString("\n ") }") + throw new RuntimeException(s"invalid method ${mb.methodName}: $nCodeArgs code arguments:" + + s"\n ${emitParamTypes.map(p => s"${p.nCodes} - $p").mkString("\n ")}") + // wrapped MethodBuilder methods def newLocal[T: TypeInfo](name: String = null): LocalRef[T] = mb.newLocal[T](name) @@ -1022,7 +1354,7 @@ class EmitMethodBuilder[C]( def storeEmitParamAsField(cb: EmitCodeBuilder, emitIndex: Int): EmitValue = { val param = getEmitParam(cb, emitIndex) - val fd = newEmitField(s"${mb.methodName}_param_${emitIndex}", param.emitType) + val fd = newEmitField(s"${mb.methodName}_param_$emitIndex", param.emitType) cb.assign(fd, param) fd } @@ -1032,7 +1364,9 @@ class EmitMethodBuilder[C]( val static = (!mb.isStatic).toInt val et = emitParamTypes(emitIndex - static) match { case t: EmitParamType => t - case _ => throw new RuntimeException(s"isStatic=${ mb.isStatic }, emitIndex=$emitIndex, params=$emitParamTypes") + case _ => throw new RuntimeException( + s"isStatic=${mb.isStatic}, emitIndex=$emitIndex, params=$emitParamTypes" + ) } val codeIndex = emitParamCodeIndex(emitIndex - static) @@ -1054,33 +1388,32 @@ class EmitMethodBuilder[C]( } } - def invokeCode[T](cb: CodeBuilderLike, args: Param*): Value[T] = { - assert(emitReturnType.isInstanceOf[CodeParamType]) - assert(args.forall(_.isInstanceOf[CodeParam])) - mb.invoke(cb, args.flatMap { 
- case CodeParam(c) => FastSeq(c) - // If you hit this assertion, it means that an EmitParam was passed to - // invokeCode. Code with EmitParams must be invoked using the EmitCodeBuilder - // interface to ensure that setup is run and missingness is evaluated for the - // EmitCode - case EmitParam(ec) => fatal("EmitParam passed to invokeCode") - }: _*) - } def newPLocal(st: SType): SSettable = newPSettable(localBuilder, st) def newPLocal(name: String, st: SType): SSettable = newPSettable(localBuilder, st, name) def newEmitLocal(emitType: EmitType): EmitSettable = newEmitLocal(emitType.st, emitType.required) + def newEmitLocal(st: SType, required: Boolean): EmitSettable = - new EmitSettable(if (required) None else Some(newLocal[Boolean]("anon_emitlocal_m")), newPLocal("anon_emitlocal_v", st)) + new EmitSettable( + if (required) None else Some(newLocal[Boolean]("anon_emitlocal_m")), + newPLocal("anon_emitlocal_v", st), + ) + + def newEmitLocal(name: String, emitType: EmitType): EmitSettable = + newEmitLocal(name, emitType.st, emitType.required) - def newEmitLocal(name: String, emitType: EmitType): EmitSettable = newEmitLocal(name, emitType.st, emitType.required) def newEmitLocal(name: String, st: SType, required: Boolean): EmitSettable = - new EmitSettable(if (required) None else Some(newLocal[Boolean](name + "_missing")), newPLocal(name, st)) + new EmitSettable( + if (required) None else Some(newLocal[Boolean](name + "_missing")), + newPLocal(name, st), + ) - def emitWithBuilder[T](f: (EmitCodeBuilder) => Code[T]): Unit = emit(EmitCodeBuilder.scopedCode[T](this)(f)) + def emitWithBuilder[T](f: (EmitCodeBuilder) => Code[T]): Unit = + emit(EmitCodeBuilder.scopedCode[T](this)(f)) - def voidWithBuilder(f: (EmitCodeBuilder) => Unit): Unit = emit(EmitCodeBuilder.scopedVoid(this)(f)) + def voidWithBuilder(f: (EmitCodeBuilder) => Unit): Unit = + emit(EmitCodeBuilder.scopedVoid(this)(f)) def emitSCode(f: (EmitCodeBuilder) => SValue): Unit = { emit(EmitCodeBuilder.scopedCode(this) { cb => @@ -1097,15 +1430,11 @@ class EmitMethodBuilder[C]( cb.define(label) f(cb) // assert(!cb.isOpenEnded) - /* - FIXME: The above assertion should hold, but currently does not. This is - likely due to client code with patterns like the following, which incorrectly - leaves the code builder open-ended: - - cb.ifx(b, - cb.goto(L1), - cb.goto(L2)) - */ + /* FIXME: The above assertion should hold, but currently does not. 
This is likely due to + * client code with patterns like the following, which incorrectly leaves the code builder + * open-ended: + * + * cb.ifx(b, cb.goto(L1), cb.goto(L2)) */ } } @@ -1115,9 +1444,8 @@ class EmitMethodBuilder[C]( label } - override def toString: String = { - s"[${ mb.methodName }]${ super.toString }" - } + override def toString: String = + s"[${mb.methodName}]${super.toString}" } trait WrappedEmitMethodBuilder[C] extends WrappedEmitClassBuilder[C] { @@ -1148,9 +1476,11 @@ trait WrappedEmitMethodBuilder[C] extends WrappedEmitClassBuilder[C] { def newEmitLocal(st: SType, required: Boolean): EmitSettable = emb.newEmitLocal(st, required) - def newEmitLocal(name: String, pt: SType, required: Boolean): EmitSettable = emb.newEmitLocal(name, pt, required) + def newEmitLocal(name: String, pt: SType, required: Boolean): EmitSettable = + emb.newEmitLocal(name, pt, required) } -class EmitFunctionBuilder[F](val apply_method: EmitMethodBuilder[F]) extends WrappedEmitMethodBuilder[F] { - def emb: EmitMethodBuilder[F] = apply_method +final case class EmitFunctionBuilder[F] private (apply_method: EmitMethodBuilder[F]) + extends WrappedEmitMethodBuilder[F] { + override val emb: EmitMethodBuilder[F] = apply_method } diff --git a/hail/src/main/scala/is/hail/expr/ir/EmitCodeBuilder.scala b/hail/src/main/scala/is/hail/expr/ir/EmitCodeBuilder.scala index f02eef8a8a4..8f8d7f5326f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/EmitCodeBuilder.scala +++ b/hail/src/main/scala/is/hail/expr/ir/EmitCodeBuilder.scala @@ -1,16 +1,16 @@ package is.hail.expr.ir -import is.hail.asm4s.{coerce => _, _} +import is.hail.asm4s._ import is.hail.expr.ir.functions.StringFunctions -import is.hail.types.physical.stypes.interfaces.{SStream, SStreamValue} import is.hail.types.physical.stypes.{SSettable, SType, SValue} +import is.hail.types.physical.stypes.interfaces.{SStream, SStreamValue} import is.hail.utils._ - object EmitCodeBuilder { def apply(mb: EmitMethodBuilder[_]): EmitCodeBuilder = new EmitCodeBuilder(mb, Code._empty) - def apply(mb: EmitMethodBuilder[_], code: Code[Unit]): EmitCodeBuilder = new EmitCodeBuilder(mb, code) + def apply(mb: EmitMethodBuilder[_], code: Code[Unit]): EmitCodeBuilder = + new EmitCodeBuilder(mb, code) def scoped[T](mb: EmitMethodBuilder[_])(f: (EmitCodeBuilder) => T): (Code[Unit], T) = { val cb = EmitCodeBuilder(mb) @@ -40,12 +40,12 @@ class EmitCodeBuilder(val emb: EmitMethodBuilder[_], var code: Code[Unit]) exten def mb: MethodBuilder[_] = emb.mb - override def append(c: Code[Unit]): Unit = { + override def append(c: Code[Unit]): Unit = code = Code(code, c) - } override def define(L: CodeLabel): Unit = - if (isOpenEnded) append(L) else { + if (isOpenEnded) append(L) + else { val tmp = code code = new VCode(code.start, L.end, null) tmp.clear() @@ -83,20 +83,24 @@ class EmitCodeBuilder(val emb: EmitMethodBuilder[_], var code: Code[Unit]) exten define(Ltrue) val tval = emitThen val value = newSLocal(tval.st, "ifx_value") - tval.consume(this, { - goto(Lmissing) - }, { tval => - assign(value, tval) - goto(Lpresent) - }) + tval.consume( + this, + goto(Lmissing), + { tval => + assign(value, tval) + goto(Lpresent) + }, + ) define(Lfalse) val fval = emitElse - fval.consume(this, { - goto(Lmissing) - }, { fval => - assign(value, fval) - goto(Lpresent) - }) + fval.consume( + this, + goto(Lmissing), + { fval => + assign(value, fval) + goto(Lpresent) + }, + ) IEmitCode(Lmissing, Lpresent, value, tval.required && fval.required) } @@ -107,17 +111,14 @@ class EmitCodeBuilder(val emb: 
EmitMethodBuilder[_], var code: Code[Unit]) exten s.store(this, v) } - def assign(s: EmitSettable, v: EmitCode): Unit = { + def assign(s: EmitSettable, v: EmitCode): Unit = s.store(this, v) - } - def assign(s: EmitSettable, v: IEmitCode): Unit = { + def assign(s: EmitSettable, v: IEmitCode): Unit = s.store(this, v) - } - def assign(is: IndexedSeq[EmitSettable], ix: IndexedSeq[EmitCode]): Unit = { - (is, ix).zipped.foreach { (s, c) => s.store(this, c) } - } + def assign(is: IndexedSeq[EmitSettable], ix: IndexedSeq[EmitCode]): Unit = + (is, ix).zipped.foreach((s, c) => s.store(this, c)) def memoizeField(pc: SValue, name: String): SValue = { val f = emb.newPField(name, pc.st) @@ -125,13 +126,11 @@ class EmitCodeBuilder(val emb: EmitMethodBuilder[_], var code: Code[Unit]) exten f } - def memoizeField[T: TypeInfo](v: Code[T], name: String): Value[T] = { + def memoizeField[T: TypeInfo](v: Code[T], name: String): Value[T] = newField[T](name, v) - } - def memoizeField[T: TypeInfo](v: Code[T]): Value[T] = { + def memoizeField[T: TypeInfo](v: Code[T]): Value[T] = memoizeField[T](v, "memoize") - } def memoize(v: EmitCode): EmitValue = memoize(v, "memoize") @@ -161,24 +160,27 @@ class EmitCodeBuilder(val emb: EmitMethodBuilder[_], var code: Code[Unit]) exten } def withScopedMaybeStreamValue[T](ec: EmitCode, name: String)(f: EmitValue => T): T = { - if (ec.st.isRealizable) { - f(memoizeField(ec, name)) - } else { - assert(ec.st.isInstanceOf[SStream]) - val ev = if (ec.required) - EmitValue(None, ec.toI(this).get(this, "")) + val ev = memoizeMaybeStreamValue(ec.toI(this), name) + val res = f(ev) + ec.pv match { + case ss: SStreamValue => + ss.defineUnusedLabels(emb) + case _ => + } + res + } + + def memoizeMaybeStreamValue(iec: IEmitCode, name: String): EmitValue = + if (iec.st.isRealizable) memoizeField(iec, name) + else { + assert(iec.st.isInstanceOf[SStream]) + if (iec.required) EmitValue(None, iec.getOrFatal(this, s"'$name' cannot be missing.")) else { val m = emb.genFieldThisRef[Boolean](name + "_missing") - ec.toI(this).consume(this, assign(m, true), _ => assign(m, false)) - EmitValue(Some(m), ec.pv) - } - val res = f(ev) - ec.pv match { - case ss: SStreamValue => ss.defineUnusedLabels(emb) + iec.consume(this, assign(m, true), _ => assign(m, false)) + EmitValue(Some(m), iec.value) } - res } - } def memoizeField(v: IEmitCode, name: String): EmitValue = { require(v.st.isRealizable) @@ -188,27 +190,32 @@ class EmitCodeBuilder(val emb: EmitMethodBuilder[_], var code: Code[Unit]) exten } private def _invoke[T](callee: EmitMethodBuilder[_], _args: Param*): Value[T] = { - val expectedArgs = callee.emitParamTypes + + // Instance methods must supply `this` in first position. 
+ val expectedArgs = + if (callee.mb.isStatic) callee.emitParamTypes + else CodeParamType(callee.ecb.cb.ti) +: callee.emitParamTypes + val args = _args.toArray + if (expectedArgs.size != args.length) - throw new RuntimeException(s"invoke ${ callee.mb.methodName }: wrong number of parameters: " + - s"expected ${ expectedArgs.size }, found ${ args.length }") - val codeArgs = args.indices.flatMap { i => - val arg = args(i) - val pt = expectedArgs(i) + throw new RuntimeException(s"invoke ${callee.mb.methodName}: wrong number of parameters: " + + s"expected ${expectedArgs.size}, found ${args.length}") + + val codeArgs = args.zip(expectedArgs).zipWithIndex.flatMap { case ((arg, pt), i) => (arg, pt) match { case (CodeParam(c), cpt: CodeParamType) => if (c.ti != cpt.ti) - throw new RuntimeException(s"invoke ${ callee.mb.methodName }: arg $i: type mismatch:" + - s"\n got ${ c.ti }" + - s"\n expected ${ cpt.ti }" + - s"\n all param types: ${expectedArgs}-") + throw new RuntimeException(s"invoke ${callee.mb.methodName}: arg $i: type mismatch:" + + s"\n got ${c.ti}" + + s"\n expected ${cpt.ti}" + + s"\n all param types: $expectedArgs-") FastSeq(c) case (SCodeParam(pc), pcpt: SCodeParamType) => if (pc.st != pcpt.st) - throw new RuntimeException(s"invoke ${ callee.mb.methodName }: arg $i: type mismatch:" + - s"\n got ${ pc.st }" + - s"\n expected ${ pcpt.st }") + throw new RuntimeException(s"invoke ${callee.mb.methodName}: arg $i: type mismatch:" + + s"\n got ${pc.st}" + + s"\n expected ${pcpt.st}") pc.valueTuple case (EmitParam(ec), SCodeEmitParamType(et)) => if (!ec.emitType.equalModuloRequired(et)) { @@ -220,23 +227,24 @@ class EmitCodeBuilder(val emb: EmitMethodBuilder[_], var code: Code[Unit]) exten val castEc = (ec.required, et.required) match { case (true, false) => ec.setOptional case (false, true) => - EmitCode.fromI(emb) { cb => IEmitCode.present(cb, ec.toI(cb).get(cb)) } + EmitCode.fromI(emb)(cb => IEmitCode.present(cb, ec.toI(cb).getOrAssert(cb))) case _ => ec } val castEv = memoize(castEc, "_invoke") castEv.valueTuple() case (arg, expected) => - throw new RuntimeException(s"invoke ${ callee.mb.methodName }: arg $i: type mismatch:" + - s"\n got ${ arg }" + - s"\n expected ${ expected }") + throw new RuntimeException(s"invoke ${callee.mb.methodName}: arg $i: type mismatch:" + + s"\n got $arg" + + s"\n expected $expected") } } - callee.mb.invoke(this, codeArgs: _*) + + super.invoke[T](callee.mb, codeArgs: _*) } def invokeVoid(callee: EmitMethodBuilder[_], args: Param*): Unit = { assert(callee.emitReturnType == CodeParamType(UnitInfo)) - append(_invoke[Unit](callee, args: _*)) + _invoke[Unit](callee, args: _*) } def invokeCode[T](callee: EmitMethodBuilder[_], args: Param*): Value[T] = { @@ -244,7 +252,8 @@ class EmitCodeBuilder(val emb: EmitMethodBuilder[_], var code: Code[Unit]) exten case CodeParamType(UnitInfo) => throw new AssertionError("CodeBuilder.invokeCode had unit return type, use invokeVoid") case _: CodeParamType => - case x => throw new AssertionError(s"CodeBuilder.invokeCode expects CodeParamType return, got $x") + case x => + throw new AssertionError(s"CodeBuilder.invokeCode expects CodeParamType return, got $x") } _invoke[T](callee, args: _*) } @@ -260,9 +269,8 @@ class EmitCodeBuilder(val emb: EmitMethodBuilder[_], var code: Code[Unit]) exten } // for debugging - def strValue(sc: SValue): Code[String] = { + def strValue(sc: SValue): Code[String] = StringFunctions.svalueToJavaValue(this, emb.partitionRegion, sc).invoke[String]("toString") - } def strValue(ec: EmitCode): 
Code[String] = { val s = newLocal[String]("s") @@ -273,15 +281,24 @@ class EmitCodeBuilder(val emb: EmitMethodBuilder[_], var code: Code[Unit]) exten // for debugging def println(cString: Code[String]*) = this += Code._printlns(cString: _*) - def logInfo(cs: Code[String]*): Unit = { - this += Code.invokeScalaObject1[String, Unit](LogHelper.getClass, "logInfo", cs.reduce[Code[String]] { case (l, r) => (l.concat(r)) }) - } - - def warning(cs: Code[String]*): Unit = { - this += Code.invokeScalaObject1[String, Unit](LogHelper.getClass, "warning", cs.reduce[Code[String]] { case (l, r) => (l.concat(r)) }) - } - - def consoleInfo(cs: Code[String]*): Unit = { - this += Code.invokeScalaObject1[String, Unit](LogHelper.getClass, "consoleInfo", cs.reduce[Code[String]] { case (l, r) => (l.concat(r)) }) - } + def logInfo(cs: Code[String]*): Unit = + this += Code.invokeScalaObject1[String, Unit]( + LogHelper.getClass, + "logInfo", + cs.reduce[Code[String]] { case (l, r) => (l.concat(r)) }, + ) + + def warning(cs: Code[String]*): Unit = + this += Code.invokeScalaObject1[String, Unit]( + LogHelper.getClass, + "warning", + cs.reduce[Code[String]] { case (l, r) => (l.concat(r)) }, + ) + + def consoleInfo(cs: Code[String]*): Unit = + this += Code.invokeScalaObject1[String, Unit]( + LogHelper.getClass, + "consoleInfo", + cs.reduce[Code[String]] { case (l, r) => (l.concat(r)) }, + ) } diff --git a/hail/src/main/scala/is/hail/expr/ir/EmitStreamDistribute.scala b/hail/src/main/scala/is/hail/expr/ir/EmitStreamDistribute.scala index 57883f2c8fa..943c2877a3a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/EmitStreamDistribute.scala +++ b/hail/src/main/scala/is/hail/expr/ir/EmitStreamDistribute.scala @@ -1,24 +1,39 @@ package is.hail.expr.ir import is.hail.annotations.Region -import is.hail.asm4s.{Code, Value, const, _} +import is.hail.asm4s.{const, Code, Value, _} import is.hail.expr.ir.functions.MathFunctions -import is.hail.expr.ir.orderings.StructOrdering import is.hail.io.{AbstractTypedCodecSpec, OutputBuffer} import is.hail.types.physical._ +import is.hail.types.physical.stypes.{EmitType, SValue} import is.hail.types.physical.stypes.concrete._ -import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SIndexableValue, SStreamValue, primitive} +import is.hail.types.physical.stypes.interfaces.{ + primitive, SBaseStruct, SIndexableValue, SStreamValue, +} import is.hail.types.physical.stypes.primitives.{SBooleanValue, SInt32, SInt32Value} -import is.hail.types.physical.stypes.{EmitType, SValue} import is.hail.types.virtual.TBaseStruct import is.hail.utils._ object EmitStreamDistribute { - def emit(cb: EmitCodeBuilder, region: Value[Region], requestedSplittersAndEndsVal: SIndexableValue, childStream: SStreamValue, pathVal: SValue, comparisonOp: ComparisonOp[_], spec: AbstractTypedCodecSpec): SIndexableValue = { + def emit( + cb: EmitCodeBuilder, + region: Value[Region], + requestedSplittersAndEndsVal: SIndexableValue, + childStream: SStreamValue, + pathVal: SValue, + comparisonOp: ComparisonOp[_], + spec: AbstractTypedCodecSpec, + ): SIndexableValue = { val mb = cb.emb val pivotsPType = requestedSplittersAndEndsVal.st.storageType().asInstanceOf[PCanonicalArray] - val requestedSplittersVal = requestedSplittersAndEndsVal.sliceArray(cb, region, pivotsPType, 1, requestedSplittersAndEndsVal.loadLength() - 1) + val requestedSplittersVal = requestedSplittersAndEndsVal.sliceArray( + cb, + region, + pivotsPType, + 1, + requestedSplittersAndEndsVal.loadLength() - 1, + ) val keyType = 
requestedSplittersVal.st.elementType.asInstanceOf[SBaseStruct] val keyPType = pivotsPType.elementType @@ -27,67 +42,162 @@ object EmitStreamDistribute { def compare(cb: EmitCodeBuilder, lelt: EmitValue, relt: EmitValue): Code[Int] = { val lhs = lelt.map(cb)(_.asBaseStruct.subset(keyFieldNames: _*)) val rhs = relt.map(cb)(_.asBaseStruct.subset(keyFieldNames: _*)) - val codeOrdering = comparisonOp.codeOrdering(cb.emb.ecb, lhs.st.asInstanceOf[SBaseStruct], rhs.st.asInstanceOf[SBaseStruct]) + val codeOrdering = comparisonOp.codeOrdering( + cb.emb.ecb, + lhs.st.asInstanceOf[SBaseStruct], + rhs.st.asInstanceOf[SBaseStruct], + ) codeOrdering(cb, lhs, rhs).asInstanceOf[Value[Int]] } - def equal(cb: EmitCodeBuilder, lelt: EmitValue, relt: EmitValue): Code[Boolean] = compare(cb, lelt, relt) ceq 0 - - def lessThan(cb: EmitCodeBuilder, lelt: EmitValue, relt: EmitValue): Code[Boolean] = compare(cb, lelt, relt) < 0 - - val filledInTreeSize = Code.invokeScalaObject1[Int, Int](MathFunctions.getClass, "roundToNextPowerOf2", requestedSplittersVal.loadLength() + 1) - val treeHeight = cb.memoize[Int](Code.invokeScalaObject1[Int, Int](MathFunctions.getClass, "log2", filledInTreeSize)) + def equal(cb: EmitCodeBuilder, lelt: EmitValue, relt: EmitValue): Code[Boolean] = + compare(cb, lelt, relt) ceq 0 + + def lessThan(cb: EmitCodeBuilder, lelt: EmitValue, relt: EmitValue): Code[Boolean] = + compare(cb, lelt, relt) < 0 + + val filledInTreeSize = Code.invokeScalaObject1[Int, Int]( + MathFunctions.getClass, + "roundToNextPowerOf2", + requestedSplittersVal.loadLength() + 1, + ) + val treeHeight = cb.memoize[Int](Code.invokeScalaObject1[Int, Int]( + MathFunctions.getClass, + "log2", + filledInTreeSize, + )) val paddedSplittersSize = cb.memoize[Int](const(1) << treeHeight) val uniqueSplittersIdx = cb.newLocal[Int]("unique_splitters_idx", 0) - def cleanupSplitters()= { - // Copy each unique splitter into array. If it is seen twice, set a boolean in a parallel array for that + def cleanupSplitters() = { + /* Copy each unique splitter into array. 
If it is seen twice, set a boolean in a parallel + * array for that */ // splitter, so we know what identity buckets to make later val paddedSplittersPType = PCanonicalArray(keyPType) val splittersWasDuplicatedPType = PCanonicalArray(PBooleanRequired) - val paddedSplittersAddr = cb.memoize[Long](paddedSplittersPType.allocate(region, paddedSplittersSize)) + val paddedSplittersAddr = + cb.memoize[Long](paddedSplittersPType.allocate(region, paddedSplittersSize)) paddedSplittersPType.stagedInitialize(cb, paddedSplittersAddr, paddedSplittersSize) val splittersWasDuplicatedLength = paddedSplittersSize - val splittersWasDuplicatedAddr = cb.memoize[Long](splittersWasDuplicatedPType.allocate(region, splittersWasDuplicatedLength)) - splittersWasDuplicatedPType.stagedInitialize(cb, splittersWasDuplicatedAddr, splittersWasDuplicatedLength) + val splittersWasDuplicatedAddr = + cb.memoize[Long](splittersWasDuplicatedPType.allocate(region, splittersWasDuplicatedLength)) + splittersWasDuplicatedPType.stagedInitialize( + cb, + splittersWasDuplicatedAddr, + splittersWasDuplicatedLength, + ) val splitters: SIndexableValue = paddedSplittersPType.loadCheapSCode(cb, paddedSplittersAddr) val requestedSplittersIdx = cb.newLocal[Int]("stream_distribute_splitters_index") val lastKeySeen = cb.emb.newEmitLocal("stream_distribute_last_seen", keyType, false) - cb.for_(cb.assign(requestedSplittersIdx, 0), requestedSplittersIdx < requestedSplittersVal.loadLength(), cb.assign(requestedSplittersIdx, requestedSplittersIdx + 1), { - val currentSplitter = requestedSplittersVal.loadElement(cb, requestedSplittersIdx).memoize(cb, "stream_distribute_current_splitter") - cb.if_(requestedSplittersIdx ceq 0, { - paddedSplittersPType.elementType.storeAtAddress(cb, paddedSplittersPType.loadElement(paddedSplittersAddr, paddedSplittersSize, 0), region, currentSplitter.get(cb), false) - splittersWasDuplicatedPType.elementType.storeAtAddress(cb, splittersWasDuplicatedPType.loadElement(splittersWasDuplicatedAddr, splittersWasDuplicatedLength, uniqueSplittersIdx), region, new SBooleanValue(false), false) - cb.assign(uniqueSplittersIdx, uniqueSplittersIdx + 1) - }, { - cb.if_(!equal(cb, lastKeySeen, currentSplitter), { - // write to pos in splitters - paddedSplittersPType.elementType.storeAtAddress(cb, paddedSplittersPType.loadElement(paddedSplittersAddr, paddedSplittersSize, uniqueSplittersIdx), region, currentSplitter.get(cb), false) - splittersWasDuplicatedPType.elementType.storeAtAddress(cb, splittersWasDuplicatedPType.loadElement(splittersWasDuplicatedAddr, splittersWasDuplicatedLength, uniqueSplittersIdx), region, new SBooleanValue(false), false) - cb.assign(uniqueSplittersIdx, uniqueSplittersIdx + 1) - }, { - splittersWasDuplicatedPType.elementType.storeAtAddress(cb, splittersWasDuplicatedPType.loadElement(splittersWasDuplicatedAddr, splittersWasDuplicatedLength, uniqueSplittersIdx - 1), region, new SBooleanValue(true), false) - }) - }) - cb.assign(lastKeySeen, currentSplitter) - }) + cb.for_( + cb.assign(requestedSplittersIdx, 0), + requestedSplittersIdx < requestedSplittersVal.loadLength(), + cb.assign(requestedSplittersIdx, requestedSplittersIdx + 1), { + val currentSplitter = requestedSplittersVal.loadElement( + cb, + requestedSplittersIdx, + ).memoize(cb, "stream_distribute_current_splitter") + cb.if_( + requestedSplittersIdx ceq 0, { + paddedSplittersPType.elementType.storeAtAddress( + cb, + paddedSplittersPType.loadElement(paddedSplittersAddr, paddedSplittersSize, 0), + region, + currentSplitter.get(cb), + false, + ) + 
splittersWasDuplicatedPType.elementType.storeAtAddress( + cb, + splittersWasDuplicatedPType.loadElement( + splittersWasDuplicatedAddr, + splittersWasDuplicatedLength, + uniqueSplittersIdx, + ), + region, + new SBooleanValue(false), + false, + ) + cb.assign(uniqueSplittersIdx, uniqueSplittersIdx + 1) + }, { + cb.if_( + !equal(cb, lastKeySeen, currentSplitter), { + // write to pos in splitters + paddedSplittersPType.elementType.storeAtAddress( + cb, + paddedSplittersPType.loadElement( + paddedSplittersAddr, + paddedSplittersSize, + uniqueSplittersIdx, + ), + region, + currentSplitter.get(cb), + false, + ) + splittersWasDuplicatedPType.elementType.storeAtAddress( + cb, + splittersWasDuplicatedPType.loadElement( + splittersWasDuplicatedAddr, + splittersWasDuplicatedLength, + uniqueSplittersIdx, + ), + region, + new SBooleanValue(false), + false, + ) + cb.assign(uniqueSplittersIdx, uniqueSplittersIdx + 1) + }, + splittersWasDuplicatedPType.elementType.storeAtAddress( + cb, + splittersWasDuplicatedPType.loadElement( + splittersWasDuplicatedAddr, + splittersWasDuplicatedLength, + uniqueSplittersIdx - 1, + ), + region, + new SBooleanValue(true), + false, + ), + ) + }, + ) + cb.assign(lastKeySeen, currentSplitter) + }, + ) val numUniqueSplitters = cb.memoize[Int](uniqueSplittersIdx) // Pad out the rest of the splitters array so tree later is balanced. - cb.for_({}, uniqueSplittersIdx < paddedSplittersSize, cb.assign(uniqueSplittersIdx, uniqueSplittersIdx + 1), { - cb.if_(lastKeySeen.get(cb).asInstanceOf[SBaseStructPointerSettable].a ceq const(0L), cb._fatal("paddedSplitterSize was ", paddedSplittersSize.toS)) - val loaded = paddedSplittersPType.loadElement(paddedSplittersAddr, paddedSplittersSize, uniqueSplittersIdx) - paddedSplittersPType.elementType.storeAtAddress(cb, loaded, region, lastKeySeen.get(cb), false) - }) + cb.for_( + {}, + uniqueSplittersIdx < paddedSplittersSize, + cb.assign(uniqueSplittersIdx, uniqueSplittersIdx + 1), { + cb.if_( + lastKeySeen.get(cb).asInstanceOf[SBaseStructPointerSettable].a ceq const(0L), + cb._fatal("paddedSplitterSize was ", paddedSplittersSize.toS), + ) + val loaded = paddedSplittersPType.loadElement( + paddedSplittersAddr, + paddedSplittersSize, + uniqueSplittersIdx, + ) + paddedSplittersPType.elementType.storeAtAddress( + cb, + loaded, + region, + lastKeySeen.get(cb), + false, + ) + }, + ) - val splitterWasDuplicated = splittersWasDuplicatedPType.loadCheapSCode(cb, splittersWasDuplicatedAddr) + val splitterWasDuplicated = + splittersWasDuplicatedPType.loadCheapSCode(cb, splittersWasDuplicatedAddr) (splitters, numUniqueSplitters, splitterWasDuplicated) } @@ -95,48 +205,99 @@ object EmitStreamDistribute { val treeAddr = cb.memoize[Long](treePType.allocate(region, paddedSplittersSize)) treePType.stagedInitialize(cb, treeAddr, paddedSplittersSize) - /* - Walk through the array one level of the tree at a time, filling in the tree as you go to get a breadth - first traversal of the tree. - */ + /* Walk through the array one level of the tree at a time, filling in the tree as you go to + * get a breadth first traversal of the tree. 
*/ val currentHeight = cb.newLocal[Int]("stream_dist_current_height") val treeFillingIndex = cb.newLocal[Int]("stream_dist_tree_filling_idx", 1) - cb.for_(cb.assign(currentHeight, treeHeight - 1), currentHeight >= 0, cb.assign(currentHeight, currentHeight - 1), { - val startingPoint = cb.memoize[Int]((const(1) << currentHeight) - 1) - val inner = cb.newLocal[Int]("stream_dist_tree_inner") - cb.for_(cb.assign(inner, 0), inner < (const(1) << (treeHeight - 1 - currentHeight)), cb.assign(inner, inner + 1), { - val elementLoaded = paddedSplitters.loadElement(cb, startingPoint + inner * (const(1) << (currentHeight + 1))).get(cb) - keyPType.storeAtAddress(cb, treePType.loadElement(treeAddr, treeFillingIndex), region, - elementLoaded, false) - cb.assign(treeFillingIndex, treeFillingIndex + 1) - }) - }) + cb.for_( + cb.assign(currentHeight, treeHeight - 1), + currentHeight >= 0, + cb.assign(currentHeight, currentHeight - 1), { + val startingPoint = cb.memoize[Int]((const(1) << currentHeight) - 1) + val inner = cb.newLocal[Int]("stream_dist_tree_inner") + cb.for_( + cb.assign(inner, 0), + inner < (const(1) << (treeHeight - 1 - currentHeight)), + cb.assign(inner, inner + 1), { + val elementLoaded = paddedSplitters.loadElement( + cb, + startingPoint + inner * (const(1) << (currentHeight + 1)), + ).getOrAssert(cb) + keyPType.storeAtAddress( + cb, + treePType.loadElement(treeAddr, treeFillingIndex), + region, + elementLoaded, + false, + ) + cb.assign(treeFillingIndex, treeFillingIndex + 1) + }, + ) + }, + ) // 0th element is garbage, tree elements start at idx 1. treePType.loadCheapSCode(cb, treeAddr) } - def createFileMapping(numFilesToWrite: Value[Int], splitterWasDuplicated: SIndexableValue, numberOfBuckets: Value[Int], shouldUseIdentityBuckets: Value[Boolean]): SIndexablePointerValue = { - // The element classifying algorithm acts as though there are identity buckets for every splitter. We only use identity buckets for elements that repeat - // in splitters list. Since We don't want many empty files, we need to make an array mapping output buckets to files. + def createFileMapping( + numFilesToWrite: Value[Int], + splitterWasDuplicated: SIndexableValue, + numberOfBuckets: Value[Int], + shouldUseIdentityBuckets: Value[Boolean], + ): SIndexablePointerValue = { + /* The element classifying algorithm acts as though there are identity buckets for every + * splitter. We only use identity buckets for elements that repeat */ + /* in splitters list. Since We don't want many empty files, we need to make an array mapping + * output buckets to files. 
*/ val fileMappingType = PCanonicalArray(PInt32Required) val fileMappingAddr = cb.memoize(fileMappingType.allocate(region, numberOfBuckets)) fileMappingType.stagedInitialize(cb, fileMappingAddr, numberOfBuckets) val bucketIdx = cb.newLocal[Int]("stream_dist_bucket_idx") val currentFileToMapTo = cb.newLocal[Int]("stream_dist_mapping_cur_storage", 0) - def destFileSCode(cb: EmitCodeBuilder) = new SInt32Value(cb.memoize((currentFileToMapTo >= numFilesToWrite).mux(numFilesToWrite - 1, currentFileToMapTo))) + def destFileSCode(cb: EmitCodeBuilder) = + new SInt32Value(cb.memoize((currentFileToMapTo >= numFilesToWrite).mux( + numFilesToWrite - 1, + currentFileToMapTo, + ))) val indexIncrement = cb.newLocal[Int]("stream_dist_create_file_mapping_increment") cb.if_(shouldUseIdentityBuckets, cb.assign(indexIncrement, 2), cb.assign(indexIncrement, 1)) - cb.for_(cb.assign(bucketIdx, 0), bucketIdx < numberOfBuckets, cb.assign(bucketIdx, bucketIdx + indexIncrement), { - fileMappingType.elementType.storeAtAddress(cb, fileMappingType.loadElement(fileMappingAddr, numberOfBuckets, bucketIdx), region, destFileSCode(cb), false) - cb.if_(shouldUseIdentityBuckets, { - cb.assign(currentFileToMapTo, currentFileToMapTo + splitterWasDuplicated.loadElement(cb, bucketIdx / 2).get(cb).asBoolean.value.toI) - fileMappingType.elementType.storeAtAddress(cb, fileMappingType.loadElement(fileMappingAddr, numberOfBuckets, bucketIdx + 1), region, destFileSCode(cb), false) - }) - cb.assign(currentFileToMapTo, currentFileToMapTo + 1) - }) + cb.for_( + cb.assign(bucketIdx, 0), + bucketIdx < numberOfBuckets, + cb.assign(bucketIdx, bucketIdx + indexIncrement), { + fileMappingType.elementType.storeAtAddress( + cb, + fileMappingType.loadElement(fileMappingAddr, numberOfBuckets, bucketIdx), + region, + destFileSCode(cb), + false, + ) + cb.if_( + shouldUseIdentityBuckets, { + cb.assign( + currentFileToMapTo, + currentFileToMapTo + splitterWasDuplicated.loadElement( + cb, + bucketIdx / 2, + ).getOrAssert( + cb + ).asBoolean.value.toI, + ) + fileMappingType.elementType.storeAtAddress( + cb, + fileMappingType.loadElement(fileMappingAddr, numberOfBuckets, bucketIdx + 1), + region, + destFileSCode(cb), + false, + ) + }, + ) + cb.assign(currentFileToMapTo, currentFileToMapTo + 1) + }, + ) fileMappingType.loadCheapSCode(cb, fileMappingAddr) } @@ -144,40 +305,63 @@ object EmitStreamDistribute { val (paddedSplitters, numUniqueSplitters, splitterWasDuplicated) = cleanupSplitters() val tree = buildTree(paddedSplitters, PCanonicalArray(keyPType)) - val shouldUseIdentityBuckets = cb.memoize[Boolean](numUniqueSplitters < requestedSplittersVal.loadLength()) + val shouldUseIdentityBuckets = + cb.memoize[Boolean](numUniqueSplitters < requestedSplittersVal.loadLength()) val numberOfBuckets = cb.newLocal[Int]("stream_dist_number_of_buckets") - cb.if_(shouldUseIdentityBuckets, + cb.if_( + shouldUseIdentityBuckets, cb.assign(numberOfBuckets, const(1) << (treeHeight + 1)), - cb.assign(numberOfBuckets, const(1) << treeHeight)) + cb.assign(numberOfBuckets, const(1) << treeHeight), + ) - // Without identity buckets you'd have numUniqueSplitters + 1 buckets, but we have to add an extra for each identity bucket. + /* Without identity buckets you'd have numUniqueSplitters + 1 buckets, but we have to add an + * extra for each identity bucket. */ // FIXME: We should have less files if we aren't writing endpoint buckets. 
val numFilesToWrite = cb.newLocal[Int]("stream_dist_num_files_to_write", 1) - cb.for_(cb.assign(uniqueSplittersIdx, 0), uniqueSplittersIdx < numUniqueSplitters, cb.assign(uniqueSplittersIdx, uniqueSplittersIdx + 1), { - cb.assign(numFilesToWrite, numFilesToWrite + 1 + splitterWasDuplicated.loadElement(cb, uniqueSplittersIdx).get(cb).asBoolean.value.toI) - }) - - val fileMapping = createFileMapping(numFilesToWrite, splitterWasDuplicated, numberOfBuckets, shouldUseIdentityBuckets) - - val outputBuffers = cb.memoize[Array[OutputBuffer]](Code.newArray[OutputBuffer](numFilesToWrite)) + cb.for_( + cb.assign(uniqueSplittersIdx, 0), + uniqueSplittersIdx < numUniqueSplitters, + cb.assign(uniqueSplittersIdx, uniqueSplittersIdx + 1), + cb.assign( + numFilesToWrite, + numFilesToWrite + 1 + splitterWasDuplicated.loadElement(cb, uniqueSplittersIdx).getOrAssert( + cb + ).asBoolean.value.toI, + ), + ) + + val fileMapping = createFileMapping( + numFilesToWrite, + splitterWasDuplicated, + numberOfBuckets, + shouldUseIdentityBuckets, + ) + + val outputBuffers = + cb.memoize[Array[OutputBuffer]](Code.newArray[OutputBuffer](numFilesToWrite)) val numElementsPerFile = cb.memoize[Array[Int]](Code.newArray[Int](numFilesToWrite)) val numBytesPerFile = cb.memoize[Array[Long]](Code.newArray[Long](numFilesToWrite)) val fileArrayIdx = cb.newLocal[Int]("stream_dist_file_array_idx") - def makeFileName(cb: EmitCodeBuilder, fileIdx: Value[Int]): Value[String] = { + def makeFileName(cb: EmitCodeBuilder, fileIdx: Value[Int]): Value[String] = cb.memoize(pathVal.asString.loadString(cb) concat const("/sorted_part_") concat fileIdx.toS) - } - cb.for_(cb.assign(fileArrayIdx, 0), fileArrayIdx < numFilesToWrite, cb.assign(fileArrayIdx, fileArrayIdx + 1), { - val fileName = makeFileName(cb, fileArrayIdx) - val ob = cb.memoize(spec.buildCodeOutputBuffer(mb.createUnbuffered(fileName))) - cb += outputBuffers.update(fileArrayIdx, ob) - cb += numElementsPerFile.update(fileArrayIdx, 0) - cb += numBytesPerFile.update(fileArrayIdx, 0) - }) - // The element classifying algorithm acts as though there are identity buckets for every splitter. We only use identity buckets for elements that repeat - // in splitters list. Since We don't want many empty files, we need to make an array mapping output buckets to files. + cb.for_( + cb.assign(fileArrayIdx, 0), + fileArrayIdx < numFilesToWrite, + cb.assign(fileArrayIdx, fileArrayIdx + 1), { + val fileName = makeFileName(cb, fileArrayIdx) + val ob = cb.memoize(spec.buildCodeOutputBuffer(mb.createUnbuffered(fileName))) + cb += outputBuffers.update(fileArrayIdx, ob) + cb += numElementsPerFile.update(fileArrayIdx, 0) + cb += numBytesPerFile.update(fileArrayIdx, 0) + }, + ) + /* The element classifying algorithm acts as though there are identity buckets for every + * splitter. We only use identity buckets for elements that repeat */ + /* in splitters list. Since We don't want many empty files, we need to make an array mapping + * output buckets to files. 
*/ val encoder = spec.encodedType.buildEncoder(childStream.st.elementType, cb.emb.ecb) val producer = childStream.getProducer(mb) @@ -187,15 +371,33 @@ object EmitStreamDistribute { cb.assign(current, producer.element) val r = cb.newLocal[Int]("stream_dist_r") - cb.for_(cb.assign(r, 0), r < treeHeight, cb.assign(r, r + 1), { - val treeAtB = tree.loadElement(cb, b).memoize(cb, "stream_dist_tree_b") - cb.assign(b, const(2) * b + lessThan(cb, treeAtB, current).toI) - }) - cb.if_(shouldUseIdentityBuckets, { - cb.assign(b, const(2) * b + 1 - lessThan(cb, current, paddedSplitters.loadElement(cb, b - numberOfBuckets / 2).memoize(cb, "stream_dist_splitter_compare")).toI) - }) + cb.for_( + cb.assign(r, 0), + r < treeHeight, + cb.assign(r, r + 1), { + val treeAtB = tree.loadElement(cb, b).memoize(cb, "stream_dist_tree_b") + cb.assign(b, const(2) * b + lessThan(cb, treeAtB, current).toI) + }, + ) + cb.if_( + shouldUseIdentityBuckets, + cb.assign( + b, + const(2) * b + 1 - lessThan( + cb, + current, + paddedSplitters.loadElement(cb, b - numberOfBuckets / 2).memoize( + cb, + "stream_dist_splitter_compare", + ), + ).toI, + ), + ) - val fileToUse = cb.memoize[Int](fileMapping.loadElement(cb, b - numberOfBuckets).get(cb).asInt.value) + val fileToUse = + cb.memoize[Int]( + fileMapping.loadElement(cb, b - numberOfBuckets).getOrAssert(cb).asInt.value + ) val ob = cb.memoize[OutputBuffer](outputBuffers(fileToUse)) @@ -203,106 +405,204 @@ object EmitStreamDistribute { val curSV = current.get(cb) encoder(cb, curSV, ob) cb += numElementsPerFile.update(fileToUse, numElementsPerFile(fileToUse) + 1) - cb += numBytesPerFile.update(fileToUse, numBytesPerFile(fileToUse) + curSV.sizeToStoreInBytes(cb).value) + cb += numBytesPerFile.update( + fileToUse, + numBytesPerFile(fileToUse) + curSV.sizeToStoreInBytes(cb).value, + ) } - cb.for_(cb.assign(fileArrayIdx, 0), fileArrayIdx < numFilesToWrite, cb.assign(fileArrayIdx, fileArrayIdx + 1), { - val ob = cb.memoize[OutputBuffer](outputBuffers(fileArrayIdx)) - cb += ob.writeByte(0.asInstanceOf[Byte]) - cb += ob.invoke[Unit]("close") - }) + cb.for_( + cb.assign(fileArrayIdx, 0), + fileArrayIdx < numFilesToWrite, + cb.assign(fileArrayIdx, fileArrayIdx + 1), { + val ob = cb.memoize[OutputBuffer](outputBuffers(fileArrayIdx)) + cb += ob.writeByte(0.asInstanceOf[Byte]) + cb += ob.invoke[Unit]("close") + }, + ) val intervalType = PCanonicalInterval(keyPType.setRequired(false), true) - val returnType = PCanonicalArray(PCanonicalStruct(("interval", intervalType), ("fileName", PCanonicalStringRequired), ("numElements", PInt32Required), ("numBytes", PInt64Required)), true) + val returnType = PCanonicalArray( + PCanonicalStruct( + ("interval", intervalType), + ("fileName", PCanonicalStringRequired), + ("numElements", PInt32Required), + ("numBytes", PInt64Required), + ), + true, + ) val min = requestedSplittersAndEndsVal.loadElement(cb, 0).memoize(cb, "stream_dist_min") val firstSplitter = paddedSplitters.loadElement(cb, 0).memoize(cb, "stream_dist_first_splitter") - val max = requestedSplittersAndEndsVal.loadElement(cb, requestedSplittersAndEndsVal.loadLength() - 1).memoize(cb, "stream_dist_min") - val lastSplitter = paddedSplitters.loadElement(cb, paddedSplitters.loadLength() - 1).memoize(cb, "stream_dist_last_splitter") - - val skipMinInterval = cb.memoize(equal(cb, min, firstSplitter) && splitterWasDuplicated.loadElement(cb, 0).get(cb).asBoolean.value) + val max = requestedSplittersAndEndsVal.loadElement( + cb, + requestedSplittersAndEndsVal.loadLength() - 1, + ).memoize(cb, 
"stream_dist_min") + val lastSplitter = paddedSplitters.loadElement(cb, paddedSplitters.loadLength() - 1).memoize( + cb, + "stream_dist_last_splitter", + ) + + val skipMinInterval = cb.memoize(equal( + cb, + min, + firstSplitter, + ) && splitterWasDuplicated.loadElement(cb, 0).getOrAssert(cb).asBoolean.value) val skipMaxInterval = cb.memoize(equal(cb, max, lastSplitter)) - val (pushElement, finisher) = returnType.constructFromFunctions(cb, region, cb.memoize(numFilesToWrite - skipMinInterval.toI - skipMaxInterval.toI), false) - - val stackStructType = new SStackStruct(returnType.virtualType.elementType.asInstanceOf[TBaseStruct], IndexedSeq( - EmitType(intervalType.sType, true), - EmitType(SJavaString, true), - EmitType(SInt32, true) - )) + val (pushElement, finisher) = returnType.constructFromFunctions( + cb, + region, + cb.memoize(numFilesToWrite - skipMinInterval.toI - skipMaxInterval.toI), + false, + ) + + val stackStructType = new SStackStruct( + returnType.virtualType.elementType.asInstanceOf[TBaseStruct], + IndexedSeq( + EmitType(intervalType.sType, true), + EmitType(SJavaString, true), + EmitType(SInt32, true), + ), + ) // Add first, but only if min != first key. - cb.if_(!skipMinInterval, { - val firstInterval = intervalType.constructFromCodes(cb, region, - min, - firstSplitter, - true, - cb.memoize(!splitterWasDuplicated.loadElement(cb, 0).get(cb).asBoolean.value) - ) - - pushElement(cb, IEmitCode.present(cb, new SStackStructValue(stackStructType, IndexedSeq( - EmitValue.present(firstInterval), - EmitValue.present(SJavaString.construct(cb, makeFileName(cb, 0))), - EmitValue.present(primitive(cb.memoize(numElementsPerFile(0)))), - EmitValue.present(primitive(cb.memoize(numBytesPerFile(0)))) - )))) - }) - - cb.for_({cb.assign(uniqueSplittersIdx, 0); cb.assign(fileArrayIdx, 1) }, uniqueSplittersIdx < numUniqueSplitters, cb.assign(uniqueSplittersIdx, uniqueSplittersIdx + 1), { - cb.if_(uniqueSplittersIdx cne 0, { - val intervalFromLastToThis = intervalType.constructFromCodes(cb, region, - EmitCode.fromI(cb.emb)(cb => paddedSplitters.loadElement(cb, uniqueSplittersIdx - 1)), - EmitCode.fromI(cb.emb)(cb => paddedSplitters.loadElement(cb, uniqueSplittersIdx)), - false, - cb.memoize(!splitterWasDuplicated.loadElement(cb, uniqueSplittersIdx).get(cb).asBoolean.value) - ) - - pushElement(cb, IEmitCode.present(cb, new SStackStructValue(stackStructType, IndexedSeq( - EmitValue.present(intervalFromLastToThis), - EmitValue.present(SJavaString.construct(cb, makeFileName(cb, fileArrayIdx))), - EmitValue.present(primitive(cb.memoize(numElementsPerFile(fileArrayIdx)))), - EmitValue.present(primitive(cb.memoize(numBytesPerFile(fileArrayIdx)))) - )))) - - cb.assign(fileArrayIdx, fileArrayIdx + 1) - }) - - // Now, maybe have to make an identity bucket. 
- cb.if_(splitterWasDuplicated.loadElement(cb, uniqueSplittersIdx).get(cb).asBoolean.value, { - val identityInterval = intervalType.constructFromCodes(cb, region, - EmitCode.fromI(cb.emb)(cb => paddedSplitters.loadElement(cb, uniqueSplittersIdx)), - EmitCode.fromI(cb.emb)(cb => paddedSplitters.loadElement(cb, uniqueSplittersIdx)), + cb.if_( + !skipMinInterval, { + val firstInterval = intervalType.constructFromCodes( + cb, + region, + min, + firstSplitter, true, - true + cb.memoize(!splitterWasDuplicated.loadElement(cb, 0).getOrAssert(cb).asBoolean.value), ) - pushElement(cb, IEmitCode.present(cb, new SStackStructValue(stackStructType, IndexedSeq( - EmitValue.present(identityInterval), - EmitValue.present(SJavaString.construct(cb, makeFileName(cb, fileArrayIdx))), - EmitValue.present(primitive(cb.memoize(numElementsPerFile(fileArrayIdx)))), - EmitValue.present(primitive(cb.memoize(numBytesPerFile(fileArrayIdx)))) - )))) + pushElement( + cb, + IEmitCode.present( + cb, + new SStackStructValue( + stackStructType, + IndexedSeq( + EmitValue.present(firstInterval), + EmitValue.present(SJavaString.construct(cb, makeFileName(cb, 0))), + EmitValue.present(primitive(cb.memoize(numElementsPerFile(0)))), + EmitValue.present(primitive(cb.memoize(numBytesPerFile(0)))), + ), + ), + ), + ) + }, + ) + + cb.for_( + { cb.assign(uniqueSplittersIdx, 0); cb.assign(fileArrayIdx, 1) }, + uniqueSplittersIdx < numUniqueSplitters, + cb.assign(uniqueSplittersIdx, uniqueSplittersIdx + 1), { + cb.if_( + uniqueSplittersIdx cne 0, { + val intervalFromLastToThis = intervalType.constructFromCodes( + cb, + region, + EmitCode.fromI(cb.emb)(cb => paddedSplitters.loadElement(cb, uniqueSplittersIdx - 1)), + EmitCode.fromI(cb.emb)(cb => paddedSplitters.loadElement(cb, uniqueSplittersIdx)), + false, + cb.memoize( + !splitterWasDuplicated.loadElement(cb, uniqueSplittersIdx).getOrAssert( + cb + ).asBoolean.value + ), + ) + + pushElement( + cb, + IEmitCode.present( + cb, + new SStackStructValue( + stackStructType, + IndexedSeq( + EmitValue.present(intervalFromLastToThis), + EmitValue.present(SJavaString.construct(cb, makeFileName(cb, fileArrayIdx))), + EmitValue.present(primitive(cb.memoize(numElementsPerFile(fileArrayIdx)))), + EmitValue.present(primitive(cb.memoize(numBytesPerFile(fileArrayIdx)))), + ), + ), + ), + ) + + cb.assign(fileArrayIdx, fileArrayIdx + 1) + }, + ) - cb.assign(fileArrayIdx, fileArrayIdx + 1) - }) - }) + // Now, maybe have to make an identity bucket. 
+ cb.if_( + splitterWasDuplicated.loadElement(cb, uniqueSplittersIdx).getOrAssert(cb).asBoolean.value, { + val identityInterval = intervalType.constructFromCodes( + cb, + region, + EmitCode.fromI(cb.emb)(cb => paddedSplitters.loadElement(cb, uniqueSplittersIdx)), + EmitCode.fromI(cb.emb)(cb => paddedSplitters.loadElement(cb, uniqueSplittersIdx)), + true, + true, + ) + + pushElement( + cb, + IEmitCode.present( + cb, + new SStackStructValue( + stackStructType, + IndexedSeq( + EmitValue.present(identityInterval), + EmitValue.present(SJavaString.construct(cb, makeFileName(cb, fileArrayIdx))), + EmitValue.present(primitive(cb.memoize(numElementsPerFile(fileArrayIdx)))), + EmitValue.present(primitive(cb.memoize(numBytesPerFile(fileArrayIdx)))), + ), + ), + ), + ) + + cb.assign(fileArrayIdx, fileArrayIdx + 1) + }, + ) + }, + ) // Add last, but only if max != last key - cb.if_(!skipMaxInterval, { - val lastInterval = intervalType.constructFromCodes(cb, region, - EmitCode.fromI(cb.emb)(cb => paddedSplitters.loadElement(cb, uniqueSplittersIdx - 1)), - EmitCode.fromI(cb.emb)(cb => requestedSplittersAndEndsVal.loadElement(cb, requestedSplittersAndEndsVal.loadLength() - 1)), - false, - true - ) + cb.if_( + !skipMaxInterval, { + val lastInterval = intervalType.constructFromCodes( + cb, + region, + EmitCode.fromI(cb.emb)(cb => paddedSplitters.loadElement(cb, uniqueSplittersIdx - 1)), + EmitCode.fromI(cb.emb)(cb => + requestedSplittersAndEndsVal.loadElement( + cb, + requestedSplittersAndEndsVal.loadLength() - 1, + ) + ), + false, + true, + ) - pushElement(cb, IEmitCode.present(cb, new SStackStructValue(stackStructType, IndexedSeq( - EmitValue.present(lastInterval), - EmitValue.present(SJavaString.construct(cb, makeFileName(cb, fileArrayIdx))), - EmitValue.present(primitive(cb.memoize(numElementsPerFile(fileArrayIdx)))), - EmitValue.present(primitive(cb.memoize(numBytesPerFile(fileArrayIdx)))) - )))) - }) + pushElement( + cb, + IEmitCode.present( + cb, + new SStackStructValue( + stackStructType, + IndexedSeq( + EmitValue.present(lastInterval), + EmitValue.present(SJavaString.construct(cb, makeFileName(cb, fileArrayIdx))), + EmitValue.present(primitive(cb.memoize(numElementsPerFile(fileArrayIdx)))), + EmitValue.present(primitive(cb.memoize(numBytesPerFile(fileArrayIdx)))), + ), + ), + ), + ) + }, + ) finisher(cb) } diff --git a/hail/src/main/scala/is/hail/expr/ir/Env.scala b/hail/src/main/scala/is/hail/expr/ir/Env.scala index d1f4c9297d2..e6efe4075b5 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Env.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Env.scala @@ -10,24 +10,75 @@ object Env { def fromSeq[V](bindings: Iterable[(String, V)]): Env[V] = empty[V].bindIterable(bindings) } +trait GenericBindingEnv[Self, V] { + def promoteAgg: Self + + def promoteScan: Self + + def promoteAggOrScan(isScan: Boolean): Self = + if (isScan) promoteScan else promoteAgg + + def bindEval(bindings: (String, V)*): Self + + def dropEval: Self + + def bindAgg(bindings: (String, V)*): Self + + def bindScan(bindings: (String, V)*): Self + + def bindAggOrScan(isScan: Boolean, bindings: (String, V)*): Self = + if (isScan) bindScan(bindings: _*) else bindAgg(bindings: _*) + + def bindInScope(name: String, v: V, scope: Int): Self = scope match { + case Scope.EVAL => bindEval(name -> v) + case Scope.AGG => bindAgg(name -> v) + case Scope.SCAN => bindScan(name -> v) + } + + def createAgg: Self + + def createScan: Self + + def createAggOrScan(isScan: Boolean): Self = + if (isScan) createScan else createAgg + + def noAgg: Self + + def 
noScan: Self + + def noAggOrScan(isScan: Boolean): Self = if (isScan) noScan else noAgg + + def onlyRelational(keepAggCapabilities: Boolean = false): Self + + def bindRelational(bindings: (String, V)*): Self +} object BindingEnv { def empty[T]: BindingEnv[T] = BindingEnv(Env.empty[T], None, None) - def eval[T](bindings: (String, T)*): BindingEnv[T] = BindingEnv(Env.fromSeq[T](bindings), None, None) + + def eval[T](bindings: (String, T)*): BindingEnv[T] = + BindingEnv(Env.fromSeq[T](bindings), None, None) } case class BindingEnv[V]( eval: Env[V] = Env.empty[V], agg: Option[Env[V]] = None, scan: Option[Env[V]] = None, - relational: Env[V] = Env.empty[V] -) { - def allEmpty: Boolean = eval.isEmpty && agg.forall(_.isEmpty) && scan.forall(_.isEmpty) && relational.isEmpty + relational: Env[V] = Env.empty[V], +) extends GenericBindingEnv[BindingEnv[V], V] { + def allEmpty: Boolean = + eval.isEmpty && agg.forall(_.isEmpty) && scan.forall(_.isEmpty) && relational.isEmpty def promoteAgg: BindingEnv[V] = copy(eval = agg.get, agg = None) def promoteScan: BindingEnv[V] = copy(eval = scan.get, scan = None) + def promoteScope(scope: Int): BindingEnv[V] = scope match { + case Scope.EVAL => this + case Scope.AGG => promoteAgg + case Scope.SCAN => promoteScan + } + def noAgg: BindingEnv[V] = copy(agg = None) def noScan: BindingEnv[V] = copy(scan = None) @@ -38,7 +89,12 @@ case class BindingEnv[V]( def createScan: BindingEnv[V] = copy(scan = Some(eval), agg = agg.map(_ => Env.empty)) - def onlyRelational: BindingEnv[V] = BindingEnv(relational = relational) + def onlyRelational(keepAggCapabilities: Boolean = false): BindingEnv[V] = + BindingEnv( + agg = if (keepAggCapabilities) agg.map(_ => Env.empty) else None, + scan = if (keepAggCapabilities) scan.map(_ => Env.empty) else None, + relational = relational, + ) def bindEval(name: String, v: V): BindingEnv[V] = copy(eval = eval.bind(name, v)) @@ -47,7 +103,9 @@ case class BindingEnv[V]( copy(eval = eval.bindIterable(bindings)) def deleteEval(name: String): BindingEnv[V] = copy(eval = eval.delete(name)) - def deleteEval(names: IndexedSeq[String]) : BindingEnv[V] = copy(eval = eval.delete(names)) + def deleteEval(names: IndexedSeq[String]): BindingEnv[V] = copy(eval = eval.delete(names)) + + def dropEval: BindingEnv[V] = copy(eval = Env.empty) def bindAgg(name: String, v: V): BindingEnv[V] = copy(agg = Some(agg.get.bind(name, v))) @@ -63,64 +121,90 @@ case class BindingEnv[V]( def bindScan(bindings: (String, V)*): BindingEnv[V] = copy(scan = Some(scan.get.bindIterable(bindings))) - def bindRelational(name: String, v: V): BindingEnv[V] = copy(relational = relational.bind(name, v)) + def bindRelational(name: String, v: V): BindingEnv[V] = + copy(relational = relational.bind(name, v)) + + def bindRelational(bindings: (String, V)*): BindingEnv[V] = + copy(relational = relational.bind(bindings: _*)) def scanOrEmpty: Env[V] = scan.getOrElse(Env.empty) - def pretty(valuePrinter: V => String = _.toString): String = { + def pretty(valuePrinter: V => String = _.toString): String = s"""BindingEnv: - | Eval:${ eval.m.map { case (k, v) => s"\n $k -> ${ valuePrinter(v) }" }.mkString("") } - | Agg: ${ agg.map(_.m.map { case (k, v) => s"\n $k -> ${ valuePrinter(v) }" }.mkString("")).getOrElse("None") } - | Scan: ${ scan.map(_.m.map { case (k, v) => s"\n $k -> ${ valuePrinter(v) }" }.mkString("")).getOrElse("None") } - | Relational: ${ relational.m.map { case (k, v) => s"\n $k -> ${ valuePrinter(v) }" }.mkString("") }""".stripMargin - } + | Eval:${eval.m.map { case (k, v) => 
s"\n $k -> ${valuePrinter(v)}" }.mkString("")} + | Agg: ${agg.map(_.m.map { case (k, v) => s"\n $k -> ${valuePrinter(v)}" }.mkString("")).getOrElse("None")} + | Scan: ${scan.map(_.m.map { case (k, v) => s"\n $k -> ${valuePrinter(v)}" }.mkString("")).getOrElse("None")} + | Relational: ${relational.m.map { case (k, v) => + s"\n $k -> ${valuePrinter(v)}" + }.mkString("")}""".stripMargin def merge(newBindings: BindingEnv[V]): BindingEnv[V] = { if (agg.isDefined != newBindings.agg.isDefined || scan.isDefined != newBindings.scan.isDefined) throw new RuntimeException(s"found inconsistent agg or scan environments:" + - s"\n left: ${ agg.isDefined }, ${ scan.isDefined }" + - s"\n right: ${ newBindings.agg.isDefined }, ${ newBindings.scan.isDefined }") + s"\n left: ${agg.isDefined}, ${scan.isDefined}" + + s"\n right: ${newBindings.agg.isDefined}, ${newBindings.scan.isDefined}") if (allEmpty) newBindings else if (newBindings.allEmpty) this else { - copy(eval = eval.bindIterable(newBindings.eval.m), + copy( + eval = eval.bindIterable(newBindings.eval.m), agg = agg.map(a => a.bindIterable(newBindings.agg.get.m)), scan = scan.map(a => a.bindIterable(newBindings.scan.get.m)), - relational = relational.bindIterable(newBindings.relational.m)) + relational = relational.bindIterable(newBindings.relational.m), + ) } } def subtract(newBindings: BindingEnv[_]): BindingEnv[V] = { if (agg.isDefined != newBindings.agg.isDefined || scan.isDefined != newBindings.scan.isDefined) throw new RuntimeException(s"found inconsistent agg or scan environments:" + - s"\n left: ${ agg.isDefined }, ${ scan.isDefined }" + - s"\n right: ${ newBindings.agg.isDefined }, ${ newBindings.scan.isDefined }") + s"\n left: ${agg.isDefined}, ${scan.isDefined}" + + s"\n right: ${newBindings.agg.isDefined}, ${newBindings.scan.isDefined}") if (allEmpty || newBindings.allEmpty) this else { - copy(eval = eval.delete(newBindings.eval.m.keys), + copy( + eval = eval.delete(newBindings.eval.m.keys), agg = agg.map(a => a.delete(newBindings.agg.get.m.keys)), scan = scan.map(a => a.delete(newBindings.scan.get.m.keys)), - relational = relational.delete(newBindings.relational.m.keys)) + relational = relational.delete(newBindings.relational.m.keys), + ) } } - def mapValues[T](f: V => T): BindingEnv[T] = { - copy[T](eval = eval.mapValues(f), agg = agg.map(_.mapValues(f)), scan = scan.map(_.mapValues(f)), relational = relational.mapValues(f)) - } - - def mapValuesWithKey[T](f: (Env.K, V) => T): BindingEnv[T] = { - copy[T](eval = eval.mapValuesWithKey(f), agg = agg.map(_.mapValuesWithKey(f)), scan = scan.map(_.mapValuesWithKey(f)), relational = relational.mapValuesWithKey(f)) - } - - def dropBindings[T]: BindingEnv[T] = copy(eval = Env.empty, agg = agg.map(_ => Env.empty), scan = scan.map(_ => Env.empty), relational = Env.empty) + def mapValues[T](f: V => T): BindingEnv[T] = + copy[T]( + eval = eval.mapValues(f), + agg = agg.map(_.mapValues(f)), + scan = scan.map(_.mapValues(f)), + relational = relational.mapValues(f), + ) + + def mapValuesWithKey[T](f: (Env.K, V) => T): BindingEnv[T] = + copy[T]( + eval = eval.mapValuesWithKey(f), + agg = agg.map(_.mapValuesWithKey(f)), + scan = scan.map(_.mapValuesWithKey(f)), + relational = relational.mapValuesWithKey(f), + ) + + def dropBindings[T]: BindingEnv[T] = copy( + eval = Env.empty, + agg = agg.map(_ => Env.empty), + scan = scan.map(_ => Env.empty), + relational = Env.empty, + ) } -class Env[V] private(val m: Map[Env.K, V]) { - def this() { +final class Env[V] private (val m: Map[Env.K, V]) { + def this() = 
this(Map()) + + override def equals(other: Any): Boolean = other match { + case env: Env[V] => this.m == env.m + case _ => false } def contains(k: Env.K): Boolean = m.contains(k) @@ -130,7 +214,7 @@ class Env[V] private(val m: Map[Env.K, V]) { def apply(name: String): V = m(name) def lookup(name: String): V = - m.get(name).getOrElse(throw new RuntimeException(s"Cannot find $name in $m")) + m.getOrElse(name, throw new RuntimeException(s"Cannot find $name in $m")) def lookupOption(name: String): Option[V] = m.get(name) @@ -146,7 +230,8 @@ class Env[V] private(val m: Map[Env.K, V]) { def bind(bindings: (String, V)*): Env[V] = bindIterable(bindings) - def bindIterable(bindings: Iterable[(String, V)]): Env[V] = if (bindings.isEmpty) this else new Env(m ++ bindings) + def bindIterable(bindings: Iterable[(String, V)]): Env[V] = + if (bindings.isEmpty) this else new Env(m ++ bindings) def freshName(prefix: String): String = { var i = 0 @@ -182,7 +267,8 @@ class Env[V] private(val m: Map[Env.K, V]) { def mapValues[U](f: (V) => U): Env[U] = new Env(m.mapValues(f)) - def mapValuesWithKey[U](f: (Env.K, V) => U): Env[U] = new Env(m.map { case (k, v) => (k, f(k, v)) }) + def mapValuesWithKey[U](f: (Env.K, V) => U): Env[U] = + new Env(m.map { case (k, v) => (k, f(k, v)) }) override def toString: String = m.map { case (k, v) => s"$k -> $v" }.mkString("(", ",", ")") } diff --git a/hail/src/main/scala/is/hail/expr/ir/Exists.scala b/hail/src/main/scala/is/hail/expr/ir/Exists.scala index c53348734eb..1e4d7fb3f79 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Exists.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Exists.scala @@ -1,10 +1,5 @@ package is.hail.expr.ir -import is.hail.expr._ -import is.hail.utils.{BoxedArrayBuilder, _} - -import scala.collection.mutable._ - // // Search an IR tree for the first node satisfying some condition // @@ -22,12 +17,11 @@ object Exists { } } - def apply(node: BaseIR, visitor: BaseIR => Boolean): Boolean = { + def apply(node: BaseIR, visitor: BaseIR => Boolean): Boolean = if (visitor(node)) true else node.children.exists(Exists(_, visitor)) - } } object Forall { @@ -62,33 +56,39 @@ object IsAggResult { object ContainsAgg { def apply(root: IR): Boolean = IsAggResult(root) || (root match { - case l: AggLet => !l.isScan + case Block(bindings, body) => + bindings.exists { + case Binding(_, value, Scope.EVAL) => ContainsAgg(value) + case Binding(_, _, Scope.AGG) => true + case Binding(_, _, Scope.SCAN) => false + } || ContainsAgg(body) case _: TableAggregate => false case _: MatrixAggregate => false case _: StreamAgg => false case _ => root.children.exists { - case child: IR => ContainsAgg(child) - case _ => false - } + case child: IR => ContainsAgg(child) + case _ => false + } }) } object ContainsAggIntermediate { - def apply(root: IR): Boolean = (root match { - case _: ResultOp => true - case _: SeqOp => true - case _: InitOp => true - case _: CombOp => true - case _: DeserializeAggs => true - case _: SerializeAggs => true - case _: AggStateValue => true - case _: CombOpValue => true - case _: InitFromSerializedValue => true - case _ => false - }) || root.children.exists { - case child: IR => ContainsAggIntermediate(child) - case _ => false - } + def apply(root: IR): Boolean = + (root match { + case _: ResultOp => true + case _: SeqOp => true + case _: InitOp => true + case _: CombOp => true + case _: DeserializeAggs => true + case _: SerializeAggs => true + case _: AggStateValue => true + case _: CombOpValue => true + case _: InitFromSerializedValue => true + case _ => 
false + }) || root.children.exists { + case child: IR => ContainsAggIntermediate(child) + case _ => false + } } object AggIsCommutative { @@ -104,21 +104,26 @@ object ContainsNonCommutativeAgg { case _: TableAggregate => false case _: MatrixAggregate => false case _ => root.children.exists { - case child: IR => ContainsNonCommutativeAgg(child) - case _ => false - } + case child: IR => ContainsNonCommutativeAgg(child) + case _ => false + } } } object ContainsScan { def apply(root: IR): Boolean = IsScanResult(root) || (root match { - case l: AggLet => l.isScan + case Block(bindings, body) => + bindings.exists { + case Binding(_, value, Scope.EVAL) => ContainsScan(value) + case Binding(_, _, Scope.AGG) => false + case Binding(_, _, Scope.SCAN) => true + } || ContainsScan(body) case _: TableAggregate => false case _: MatrixAggregate => false case _: StreamAggScan => false case _ => root.children.exists { - case child: IR => ContainsScan(child) - case _ => false - } + case child: IR => ContainsScan(child) + case _ => false + } }) } diff --git a/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala b/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala index 90db484f1dd..c2d90654383 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ExtractIntervalFilters.scala @@ -6,11 +6,12 @@ import is.hail.rvd.PartitionBoundOrdering import is.hail.types.virtual._ import is.hail.utils.{Interval, IntervalEndpoint, _} import is.hail.variant.{Locus, ReferenceGenome} -import org.apache.spark.sql.Row import scala.Option.option2Iterable import scala.collection.mutable +import org.apache.spark.sql.Row + trait Lattice { type Value <: AnyRef def top: Value @@ -35,32 +36,37 @@ object ExtractIntervalFilters { log.info( s"generated TableFilterIntervals node with ${intervals.length} intervals:\n " + s"Intervals: ${intervals.mkString(", ")}\n " + - s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}") + s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}" + ) TableFilter(TableFilterIntervals(child, intervals, keep = true), newCond) } case MatrixFilterRows(child, pred) => extractPartitionFilters( ctx, pred, Ref("va", child.typ.rowType), - child.typ.rowKey).map { case (newCond, intervals) => + child.typ.rowKey, + ).map { case (newCond, intervals) => log.info( s"generated MatrixFilterIntervals node with ${intervals.length} intervals:\n " + s"Intervals: ${intervals.mkString(", ")}\n " + - s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}") + s"Predicate: ${Pretty(ctx, pred)}\n " + s"Post: ${Pretty(ctx, newCond)}" + ) MatrixFilterRows(MatrixFilterIntervals(child, intervals, keep = true), newCond) } case _ => None } ).getOrElse(ir) - } + }, ) } - def extractPartitionFilters(ctx: ExecuteContext, cond: IR, ref: Ref, key: IndexedSeq[String]): Option[(IR, IndexedSeq[Interval])] = { + def extractPartitionFilters(ctx: ExecuteContext, cond: IR, ref: Ref, key: IndexedSeq[String]) + : Option[(IR, IndexedSeq[Interval])] = { if (key.isEmpty) None else { - val extract = new ExtractIntervalFilters(ctx, ref.typ.asInstanceOf[TStruct].typeAfterSelectNames(key)) + val extract = + new ExtractIntervalFilters(ctx, ref.typ.asInstanceOf[TStruct].typeAfterSelectNames(key)) val trueSet = extract.analyze(cond, ref.name) if (trueSet == extract.KeySetLattice.top) None @@ -72,7 +78,8 @@ object ExtractIntervalFilters { } } - def liftPosIntervalsToLocus(pos: IndexedSeq[Interval], rg: ReferenceGenome, ctx: 
ExecuteContext): IndexedSeq[Interval] = { + def liftPosIntervalsToLocus(pos: IndexedSeq[Interval], rg: ReferenceGenome, ctx: ExecuteContext) + : IndexedSeq[Interval] = { val ord = PartitionBoundOrdering(ctx, TTuple(TInt32)) val nonNull = rg.contigs.indices.flatMap { cont => pos.flatMap { i => @@ -82,7 +89,8 @@ object ExtractIntervalFilters { Row(Locus(rg.contigs(cont), interval.start.asInstanceOf[Row].getAs[Int](0))), Row(Locus(rg.contigs(cont), interval.right.point.asInstanceOf[Row].getAs[Int](0))), interval.includesStart, - interval.includesEnd) + interval.includesEnd, + ) } } } @@ -93,7 +101,6 @@ object ExtractIntervalFilters { } } - // A set of key values, represented by an ordered sequence of disjoint intervals // Supports lattice ops, plus complements. class KeySetLattice(ctx: ExecuteContext, keyType: TStruct) extends Lattice { @@ -101,9 +108,8 @@ class KeySetLattice(ctx: ExecuteContext, keyType: TStruct) extends Lattice { type KeySet = Value object KeySet { - def apply(intervals: Interval*): IndexedSeq[Interval] = { + def apply(intervals: Interval*): IndexedSeq[Interval] = apply(intervals.toFastSeq) - } def apply(intervals: IndexedSeq[Interval]): IndexedSeq[Interval] = { assert(intervals.isEmpty || KeySet.intervalIsReduced(intervals.last)) @@ -122,54 +128,51 @@ class KeySetLattice(ctx: ExecuteContext, keyType: TStruct) extends Lattice { (init :+ reducedLast).toFastSeq } - private def intervalIsReduced(interval: Interval): Boolean = { + private def intervalIsReduced(interval: Interval): Boolean = interval.right != IntervalEndpoint(Row(null), 1) - } } def keyOrd: ExtendedOrdering = PartitionBoundOrdering(ctx, keyType) val iord: IntervalEndpointOrdering = keyOrd.intervalEndpointOrdering // l is contained in r - def specializes(l: Value, r: Value): Boolean = { + def specializes(l: Value, r: Value): Boolean = if (l == bottom) true else if (r == top) true else l.forall { i => - r.containsOrdered[Interval, Interval](i, (h: Interval, n: Interval) => iord.lt(h.right, n.right), (n: Interval, h: Interval) => iord.lt(n.left, h.left)) + r.containsOrdered[Interval, Interval]( + i, + (h: Interval, n: Interval) => iord.lt(h.right, n.right), + (n: Interval, h: Interval) => iord.lt(n.left, h.left), + ) } - } override def top: Value = KeySet(Interval(Row(), Row(), true, true)) override def bottom: Value = KeySet.empty - override def join(l: Value, r: Value): Value = { + override def join(l: Value, r: Value): Value = if (l == bottom) r else if (r == bottom) l else KeySet(Interval.union(l ++ r, iord)) - } - def combineMulti(vs: Value*): Value = { + def combineMulti(vs: Value*): Value = KeySet(Interval.union(vs.flatten.toFastSeq, iord)) - } - override def meet(l: Value, r: Value): Value = { + override def meet(l: Value, r: Value): Value = if (l == top) r else if (r == top) l else KeySet(Interval.intersection(l, r, iord)) - } def complement(v: Value): Value = { if (v.isEmpty) return top val builder = mutable.ArrayBuilder.make[Interval]() - var i = 0 if (v.head.left != IntervalEndpoint(Row(), -1)) { builder += Interval(IntervalEndpoint(Row(), -1), v.head.left) } - for (i <- 0 until v.length - 1) { + for (i <- 0 until v.length - 1) builder += Interval(v(i).right, v(i + 1).left) - } if (v.last.right != IntervalEndpoint(Row(), 1)) { builder += Interval(v.last.right, IntervalEndpoint(Row(), 1)) } @@ -227,7 +230,6 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { case _ => ConcreteKeyField(idx) } - def unapply(v: Value): Option[Int] = v match { case v: KeyField => Some(v.idx) case _ => 
None @@ -302,12 +304,14 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { assert(lKeySet eq rKeySet) ConstantBool(l || r, lKeySet) case _ => ConcreteBool( - KeySetLattice.join(l.trueBound, r.trueBound), - KeySetLattice.meet(l.falseBound, r.falseBound), - KeySetLattice.combineMulti( - KeySetLattice.meet(l.naBound, r.falseBound), - KeySetLattice.meet(l.naBound, r.naBound), - KeySetLattice.meet(l.falseBound, r.naBound))) + KeySetLattice.join(l.trueBound, r.trueBound), + KeySetLattice.meet(l.falseBound, r.falseBound), + KeySetLattice.combineMulti( + KeySetLattice.meet(l.naBound, r.falseBound), + KeySetLattice.meet(l.naBound, r.naBound), + KeySetLattice.meet(l.falseBound, r.naBound), + ), + ) } def and(l: BoolValue, r: BoolValue): BoolValue = (l, r) match { @@ -315,12 +319,14 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { assert(lKeySet eq rKeySet) ConstantBool(l && r, lKeySet) case _ => ConcreteBool( - KeySetLattice.meet(l.trueBound, r.trueBound), - KeySetLattice.join(l.falseBound, r.falseBound), - KeySetLattice.combineMulti( - KeySetLattice.meet(l.naBound, r.trueBound), - KeySetLattice.meet(l.naBound, r.naBound), - KeySetLattice.meet(l.trueBound, r.naBound))) + KeySetLattice.meet(l.trueBound, r.trueBound), + KeySetLattice.join(l.falseBound, r.falseBound), + KeySetLattice.combineMulti( + KeySetLattice.meet(l.naBound, r.trueBound), + KeySetLattice.meet(l.naBound, r.naBound), + KeySetLattice.meet(l.trueBound, r.naBound), + ), + ) } def not(x: BoolValue): BoolValue = x match { @@ -328,98 +334,140 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { case _ => ConcreteBool(x.falseBound, x.trueBound, x.naBound) } - def coalesce(l: BoolValue, r: BoolValue): BoolValue = { + def coalesce(l: BoolValue, r: BoolValue): BoolValue = ConcreteBool( KeySetLattice.join(l.trueBound, KeySetLattice.meet(l.naBound, r.trueBound)), KeySetLattice.join(l.falseBound, KeySetLattice.meet(l.naBound, r.falseBound)), - KeySetLattice.meet(l.naBound, r.naBound)) - } + KeySetLattice.meet(l.naBound, r.naBound), + ) def fromComparison(v: Any, op: ComparisonOp[_], wrapped: Boolean = true): BoolValue = { (op: @unchecked) match { case _: EQ => BoolValue( // value == key - KeySet(Interval(endpoint(v, -1, wrapped), endpoint(v, 1, wrapped))), - KeySet(Interval(negInf, endpoint(v, -1, wrapped)), Interval(endpoint(v, 1, wrapped), endpoint(null, -1))), - KeySet(Interval(endpoint(null, -1), posInf))) + KeySet(Interval(endpoint(v, -1, wrapped), endpoint(v, 1, wrapped))), + KeySet( + Interval(negInf, endpoint(v, -1, wrapped)), + Interval(endpoint(v, 1, wrapped), endpoint(null, -1)), + ), + KeySet(Interval(endpoint(null, -1), posInf)), + ) case _: NEQ => BoolValue( // value != key - KeySet(Interval(negInf, endpoint(v, -1, wrapped)), Interval(endpoint(v, 1, wrapped), endpoint(null, -1))), - KeySet(Interval(endpoint(v, -1, wrapped), endpoint(v, 1, wrapped))), - KeySet(Interval(endpoint(null, -1), posInf))) + KeySet( + Interval(negInf, endpoint(v, -1, wrapped)), + Interval(endpoint(v, 1, wrapped), endpoint(null, -1)), + ), + KeySet(Interval(endpoint(v, -1, wrapped), endpoint(v, 1, wrapped))), + KeySet(Interval(endpoint(null, -1), posInf)), + ) case _: GT => BoolValue( // value > key - KeySet(Interval(negInf, endpoint(v, -1, wrapped))), - KeySet(Interval(endpoint(v, -1, wrapped), endpoint(null, -1))), - KeySet(Interval(endpoint(null, -1), posInf))) + KeySet(Interval(negInf, endpoint(v, -1, wrapped))), + KeySet(Interval(endpoint(v, -1, wrapped), endpoint(null, -1))), + 
KeySet(Interval(endpoint(null, -1), posInf)), + ) case _: GTEQ => BoolValue( // value >= key - KeySet(Interval(negInf, endpoint(v, 1, wrapped))), - KeySet(Interval(endpoint(v, 1, wrapped), endpoint(null, -1))), - KeySet(Interval(endpoint(null, -1), posInf))) + KeySet(Interval(negInf, endpoint(v, 1, wrapped))), + KeySet(Interval(endpoint(v, 1, wrapped), endpoint(null, -1))), + KeySet(Interval(endpoint(null, -1), posInf)), + ) case _: LT => BoolValue( // value < key - KeySet(Interval(endpoint(v, 1, wrapped), endpoint(null, -1))), - KeySet(Interval(negInf, endpoint(v, 1, wrapped))), - KeySet(Interval(endpoint(null, -1), posInf))) + KeySet(Interval(endpoint(v, 1, wrapped), endpoint(null, -1))), + KeySet(Interval(negInf, endpoint(v, 1, wrapped))), + KeySet(Interval(endpoint(null, -1), posInf)), + ) case _: LTEQ => BoolValue( // value <= key - KeySet(Interval(endpoint(v, -1, wrapped), endpoint(null, -1))), - KeySet(Interval(negInf, endpoint(v, -1, wrapped))), - KeySet(Interval(endpoint(null, -1), posInf))) + KeySet(Interval(endpoint(v, -1, wrapped), endpoint(null, -1))), + KeySet(Interval(negInf, endpoint(v, -1, wrapped))), + KeySet(Interval(endpoint(null, -1), posInf)), + ) case _: EQWithNA => // value == key if (v == null) BoolValue( KeySet(Interval(endpoint(v, -1, wrapped), posInf)), KeySet(Interval(negInf, endpoint(v, -1, wrapped))), - KeySetLattice.bottom) + KeySetLattice.bottom, + ) else BoolValue( KeySet(Interval(endpoint(v, -1, wrapped), endpoint(v, 1, wrapped))), - KeySet(Interval(negInf, endpoint(v, -1, wrapped)), Interval(endpoint(v, 1, wrapped), posInf)), - KeySetLattice.bottom) + KeySet( + Interval(negInf, endpoint(v, -1, wrapped)), + Interval(endpoint(v, 1, wrapped), posInf), + ), + KeySetLattice.bottom, + ) case _: NEQWithNA => // value != key if (v == null) BoolValue( KeySet(Interval(negInf, endpoint(v, -1, wrapped))), KeySet(Interval(endpoint(v, -1, wrapped), posInf)), - KeySetLattice.bottom) + KeySetLattice.bottom, + ) else BoolValue( - KeySet(Interval(negInf, endpoint(v, -1, wrapped)), Interval(endpoint(v, 1, wrapped), posInf)), + KeySet( + Interval(negInf, endpoint(v, -1, wrapped)), + Interval(endpoint(v, 1, wrapped), posInf), + ), KeySet(Interval(endpoint(v, -1, wrapped), endpoint(v, 1, wrapped))), - KeySetLattice.bottom) + KeySetLattice.bottom, + ) } } def fromComparisonKeyPrefix(v: Row, op: ComparisonOp[_]): BoolValue = { (op: @unchecked) match { case _: EQ => BoolValue( // value == key - KeySet(Interval(endpoint(v, -1, false), endpoint(v, 1, false))), - KeySet(Interval(negInf, endpoint(v, -1, false)), Interval(endpoint(v, 1, false), posInf)), - KeySetLattice.bottom) + KeySet(Interval(endpoint(v, -1, false), endpoint(v, 1, false))), + KeySet( + Interval(negInf, endpoint(v, -1, false)), + Interval(endpoint(v, 1, false), posInf), + ), + KeySetLattice.bottom, + ) case _: NEQ => BoolValue( // value != key - KeySet(Interval(negInf, endpoint(v, -1, false)), Interval(endpoint(v, 1, false), posInf)), - KeySet(Interval(endpoint(v, -1, false), endpoint(v, 1, false))), - KeySetLattice.bottom) + KeySet( + Interval(negInf, endpoint(v, -1, false)), + Interval(endpoint(v, 1, false), posInf), + ), + KeySet(Interval(endpoint(v, -1, false), endpoint(v, 1, false))), + KeySetLattice.bottom, + ) case _: GT => BoolValue( // value > key - KeySet(Interval(negInf, endpoint(v, -1, false))), - KeySet(Interval(endpoint(v, -1, false), posInf)), - KeySetLattice.bottom) + KeySet(Interval(negInf, endpoint(v, -1, false))), + KeySet(Interval(endpoint(v, -1, false), posInf)), + KeySetLattice.bottom, + ) 
case _: GTEQ => BoolValue( // value >= key - KeySet(Interval(negInf, endpoint(v, 1, false))), - KeySet(Interval(endpoint(v, 1, false), posInf)), - KeySetLattice.bottom) + KeySet(Interval(negInf, endpoint(v, 1, false))), + KeySet(Interval(endpoint(v, 1, false), posInf)), + KeySetLattice.bottom, + ) case _: LT => BoolValue( // value < key - KeySet(Interval(endpoint(v, 1, false), posInf)), - KeySet(Interval(negInf, endpoint(v, 1, false))), - KeySetLattice.bottom) + KeySet(Interval(endpoint(v, 1, false), posInf)), + KeySet(Interval(negInf, endpoint(v, 1, false))), + KeySetLattice.bottom, + ) case _: LTEQ => BoolValue( // value <= key - KeySet(Interval(endpoint(v, -1, false), posInf)), - KeySet(Interval(negInf, endpoint(v, 1, false))), - KeySetLattice.bottom) + KeySet(Interval(endpoint(v, -1, false), posInf)), + KeySet(Interval(negInf, endpoint(v, 1, false))), + KeySetLattice.bottom, + ) case _: EQWithNA => BoolValue( // value == key - KeySet(Interval(endpoint(v, -1, false), endpoint(v, 1, false))), - KeySet(Interval(negInf, endpoint(v, -1, false)), Interval(endpoint(v, 1, false), posInf)), - KeySetLattice.bottom) + KeySet(Interval(endpoint(v, -1, false), endpoint(v, 1, false))), + KeySet( + Interval(negInf, endpoint(v, -1, false)), + Interval(endpoint(v, 1, false), posInf), + ), + KeySetLattice.bottom, + ) case _: NEQWithNA => BoolValue( // value != key - KeySet(Interval(negInf, endpoint(v, -1, false)), Interval(endpoint(v, 1, false), posInf)), - KeySet(Interval(endpoint(v, -1, false), endpoint(v, 1, false))), - KeySetLattice.bottom) + KeySet( + Interval(negInf, endpoint(v, -1, false)), + Interval(endpoint(v, 1, false), posInf), + ), + KeySet(Interval(endpoint(v, -1, false), endpoint(v, 1, false))), + KeySetLattice.bottom, + ) } } } @@ -431,27 +479,33 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { def restrict(keySet: KeySet): BoolValue - def isNA: BoolValue = ConcreteBool(naBound, KeySetLattice.join(trueBound, falseBound), KeySetLattice.bottom) + def isNA: BoolValue = + ConcreteBool(naBound, KeySetLattice.join(trueBound, falseBound), KeySetLattice.bottom) } private case class ConcreteBool( trueBound: KeySet, falseBound: KeySet, - naBound: KeySet + naBound: KeySet, ) extends BoolValue { - override def restrict(keySet: KeySet): BoolValue = { - ConcreteBool(KeySetLattice.meet(trueBound, keySet), KeySetLattice.meet(falseBound, keySet), KeySetLattice.meet(naBound, keySet)) - } + override def restrict(keySet: KeySet): BoolValue = + ConcreteBool( + KeySetLattice.meet(trueBound, keySet), + KeySetLattice.meet(falseBound, keySet), + KeySetLattice.meet(naBound, keySet), + ) } - private case class ConstantStruct(value: Row, t: TStruct) extends StructValue with ConstantValue { + private case class ConstantStruct(value: Row, t: TStruct) + extends StructValue with ConstantValue { def apply(field: String): Value = this(t.field(field)) def values: Iterable[Value] = t.fields.map(apply) private def apply(field: Field): ConstantValue = ConstantValue(value(field.index), field.typ) def isKeyPrefix: Boolean = false } - private case class ConstantBool(value: Boolean, keySet: KeySet) extends BoolValue with ConstantValue { + private case class ConstantBool(value: Boolean, keySet: KeySet) + extends BoolValue with ConstantValue { override def trueBound: KeySet = if (value) keySet else KeySetLattice.bottom @@ -531,9 +585,10 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { f => f -> join(l(f), r(f)) }.toMap) case (l: BoolValue, r: BoolValue) => BoolValue( - 
KeySetLattice.join(l.trueBound, r.trueBound), - KeySetLattice.join(l.falseBound, r.falseBound), - KeySetLattice.join(l.naBound, r.naBound)) + KeySetLattice.join(l.trueBound, r.trueBound), + KeySetLattice.join(l.falseBound, r.falseBound), + KeySetLattice.join(l.naBound, r.naBound), + ) case _ => Top } @@ -553,9 +608,10 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { f => f -> meet(l(f), r(f)) }.toMap) case (l: BoolValue, r: BoolValue) => BoolValue( - KeySetLattice.meet(l.trueBound, r.trueBound), - KeySetLattice.meet(l.falseBound, r.falseBound), - KeySetLattice.meet(l.naBound, r.naBound)) + KeySetLattice.meet(l.trueBound, r.trueBound), + KeySetLattice.meet(l.falseBound, r.falseBound), + KeySetLattice.meet(l.naBound, r.naBound), + ) case _ => Top } @@ -563,33 +619,47 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { if (opIsSupported(op)) (l, r) match { case (ConstantValue(l), r) => compareWithConstant(l, r, op, keySet) case (l, ConstantValue(r)) => - compareWithConstant(r, l, ComparisonOp.swap(op.asInstanceOf[ComparisonOp[Boolean]]), keySet) + compareWithConstant( + r, + l, + ComparisonOp.swap(op.asInstanceOf[ComparisonOp[Boolean]]), + keySet, + ) case _ => BoolValue.top(keySet) - } else { + } + else { BoolValue.top(keySet) } } - private def compareWithConstant(l: Any, r: Value, op: ComparisonOp[_], keySet: KeySet): BoolValue = { - if (op.strict && l == null) return BoolValue.allNA(keySet) - r match { + private def compareWithConstant(l: Any, r: Value, op: ComparisonOp[_], keySet: KeySet) + : BoolValue = + if (op.strict && l == null) BoolValue.allNA(keySet) + else r match { case r: KeyField if r.idx == 0 => // simple key comparison BoolValue.fromComparison(l, op).restrict(keySet) case Contig(rgStr) => - // locus contig comparison - assert(op.isInstanceOf[EQ]) + // locus contig equality comparison val b = getIntervalFromContig(l.asInstanceOf[String], ctx.getReference(rgStr)) match { case Some(i) => - BoolValue( + val b = BoolValue( KeySet(i), KeySet(Interval(negInf, i.left), Interval(i.right, endpoint(null, -1))), - KeySet(Interval(endpoint(null, -1), posInf))) + KeySet(Interval(endpoint(null, -1), posInf)), + ) + + op match { + case _: EQ => b + case _: NEQ => BoolValue.not(b) + case _ => BoolValue.top(keySet) + } case None => BoolValue( KeySetLattice.bottom, KeySet(Interval(negInf, endpoint(null, -1))), - KeySet(Interval(endpoint(null, -1), posInf))) + KeySet(Interval(endpoint(null, -1), posInf)), + ) } b.restrict(keySet) case Position(rgStr) => @@ -599,13 +669,13 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { val b = BoolValue( KeySet(liftPosIntervalsToLocus(posBoolValue.trueBound, rg, ctx)), KeySet(liftPosIntervalsToLocus(posBoolValue.falseBound, rg, ctx)), - KeySet(liftPosIntervalsToLocus(posBoolValue.naBound, rg, ctx))) + KeySet(liftPosIntervalsToLocus(posBoolValue.naBound, rg, ctx)), + ) b.restrict(keySet) case s: StructValue if s.isKeyPrefix => BoolValue.fromComparisonKeyPrefix(l.asInstanceOf[Row], op).restrict(keySet) case _ => BoolValue.top(keySet) } - } private def opIsSupported(op: ComparisonOp[_]): Boolean = op match { case _: EQ | _: NEQ | _: LTEQ | _: LT | _: GTEQ | _: GT | _: EQWithNA | _: NEQWithNA => true @@ -613,7 +683,9 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { } } - import AbstractLattice.{BoolValue, ConstantValue, Contig, KeyField, Position, StructValue, Value => AbstractValue} + import AbstractLattice.{ + BoolValue, ConstantValue, Contig, KeyField, Position, 
StructValue, Value => AbstractValue, + } case class AbstractEnv(keySet: KeySet, env: Env[AbstractValue]) { def apply(name: String): AbstractValue = @@ -636,21 +708,30 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { case x: Traversable[_] => intervalsFromCollection(x, ordering, wrapped) } - private def intervalsFromCollection(lit: Traversable[Any], ordering: Ordering[Any], wrapped: Boolean): KeySet = + private def intervalsFromCollection( + lit: Traversable[Any], + ordering: Ordering[Any], + wrapped: Boolean, + ): KeySet = KeySet.reduce( lit.toArray.distinct.filter(x => wrapped || x != null).sorted(ordering) .map(elt => Interval(endpoint(elt, -1, wrapped), endpoint(elt, 1, wrapped))) - .toFastSeq) + .toFastSeq + ) private def intervalsFromLiteralContigs(contigs: Any, rg: ReferenceGenome): KeySet = { KeySet((contigs: @unchecked) match { - case x: Map[_, _] => x.keys.asInstanceOf[Iterable[String]].toFastSeq - .sortBy(rg.contigsIndex.get(_))(TInt32.ordering(null).toOrdering.asInstanceOf[Ordering[Integer]]) - .flatMap(getIntervalFromContig(_, rg)) - case x: Traversable[_] => x.asInstanceOf[Traversable[String]].toArray.toFastSeq - .sortBy(rg.contigsIndex.get(_))(TInt32.ordering(null).toOrdering.asInstanceOf[Ordering[Integer]]) - .flatMap(getIntervalFromContig(_, rg)) - }) + case x: Map[_, _] => x.keys.asInstanceOf[Iterable[String]].toFastSeq + .sortBy(rg.contigsIndex.get(_))( + TInt32.ordering(null).toOrdering.asInstanceOf[Ordering[Integer]] + ) + .flatMap(getIntervalFromContig(_, rg)) + case x: Traversable[_] => x.asInstanceOf[Traversable[String]].toArray.toFastSeq + .sortBy(rg.contigsIndex.get(_))( + TInt32.ordering(null).toOrdering.asInstanceOf[Ordering[Integer]] + ) + .flatMap(getIntervalFromContig(_, rg)) + }) } private def getIntervalFromContig(c: String, rg: ReferenceGenome): Option[Interval] = { @@ -661,7 +742,8 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { Some(Interval(Row(null), Row(), true, true)) } else { warn( - s"Filtered with contig '$c', but '$c' is not a valid contig in reference genome ${rg.name}") + s"Filtered with contig '$c', but '$c' is not a valid contig in reference genome ${rg.name}" + ) None } } @@ -672,32 +754,22 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { private def literalSizeOkay(lit: Any): Boolean = lit.asInstanceOf[Iterable[_]].size <= MAX_LITERAL_SIZE - private def wrapInRow(intervals: IndexedSeq[Interval]): IndexedSeq[Interval] = intervals - .map { interval => - Interval( - IntervalEndpoint(Row(interval.left.point), interval.left.sign), - IntervalEndpoint(Row(interval.right.point), interval.right.sign)) - } - - private def intervalFromComparison(v: Any, op: ComparisonOp[_]): Interval = { - (op: @unchecked) match { - case _: EQ => Interval(endpoint(v, -1), endpoint(v, 1)) - case GT(_, _) => Interval(negInf, endpoint(v, -1)) // value > key - case GTEQ(_, _) => Interval(negInf, endpoint(v, 1)) // value >= key - case LT(_, _) => Interval(endpoint(v, 1), posInf) // value < key - case LTEQ(_, _) => Interval(endpoint(v, -1), posInf) // value <= key - } - } - private def posInf: IntervalEndpoint = IntervalEndpoint(Row(), 1) private def negInf: IntervalEndpoint = IntervalEndpoint(Row(), -1) - def analyze(x: IR, rowName: String, rw: Option[Rewrites] = None, constraint: KeySet = KeySetLattice.top): KeySet = { + def analyze( + x: IR, + rowName: String, + rw: Option[Rewrites] = None, + constraint: KeySet = KeySetLattice.top, + ): KeySet = { val env = Env.empty[AbstractValue].bind( rowName, 
StructValue( - Map(keyType.fieldNames.zipWithIndex.map(t => t._1 -> KeyField(t._2)): _*))) + Map(keyType.fieldNames.zipWithIndex.map(t => t._1 -> KeyField(t._2)): _*) + ), + ) val bool = _analyze(x, AbstractEnv(constraint, env), rw).asInstanceOf[BoolValue] bool.trueBound } @@ -711,64 +783,79 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { } } - private def computeKeyOrConst(x: IR, children: IndexedSeq[AbstractValue]): AbstractValue = x match { - case False() => ConstantValue(false, TBoolean) - case True() => ConstantValue(true, TBoolean) - case I32(v) => ConstantValue(v, x.typ) - case I64(v) => ConstantValue(v, x.typ) - case F32(v) => ConstantValue(v, x.typ) - case F64(v) => ConstantValue(v, x.typ) - case Str(v) => ConstantValue(v, x.typ) - case NA(_) => ConstantValue(null, x.typ) - case Literal(_, value) => ConstantValue(value, x.typ) - case EncodedLiteral(codec, arrays) => - ConstantValue(ctx.r.getPool().scopedRegion { r => - val (pt, addr) = codec.decodeArrays(ctx, codec.encodedVirtualType, arrays.ba, ctx.r) - SafeRow.read(pt, addr) - }, x.typ) - case ApplySpecial("lor", _, _, _, _) => children match { - case Seq(ConstantValue(l: Boolean), ConstantValue(r: Boolean)) => ConstantValue(l || r, TBoolean) - case _ => AbstractLattice.top - } - case ApplySpecial("land", _, _, _, _) => children match { - case Seq(ConstantValue(l: Boolean), ConstantValue(r: Boolean)) => ConstantValue(l && r, TBoolean) - case _ => AbstractLattice.top - } - case Apply("contig", _, Seq(k), _, _) => children match { - case Seq(KeyField(0)) => Contig(k.typ.asInstanceOf[TLocus].rg) - case _ => AbstractLattice.top - } - case Apply("position", _, Seq(k), _, _) => children match { - case Seq(KeyField(0)) => Position(k.typ.asInstanceOf[TLocus].rg) + private def computeKeyOrConst(x: IR, children: IndexedSeq[AbstractValue]): AbstractValue = + x match { + case False() => ConstantValue(false, TBoolean) + case True() => ConstantValue(true, TBoolean) + case I32(v) => ConstantValue(v, x.typ) + case I64(v) => ConstantValue(v, x.typ) + case F32(v) => ConstantValue(v, x.typ) + case F64(v) => ConstantValue(v, x.typ) + case Str(v) => ConstantValue(v, x.typ) + case NA(_) => ConstantValue(null, x.typ) + case Literal(_, value) => ConstantValue(value, x.typ) + case EncodedLiteral(codec, arrays) => + ConstantValue( + ctx.r.getPool().scopedRegion { r => + val (pt, addr) = codec.decodeArrays(ctx, codec.encodedVirtualType, arrays.ba, ctx.r) + SafeRow.read(pt, addr) + }, + x.typ, + ) + case ApplySpecial("lor", _, _, _, _) => children match { + case Seq(ConstantValue(l: Boolean), ConstantValue(r: Boolean)) => + ConstantValue(l || r, TBoolean) + case _ => AbstractLattice.top + } + case ApplySpecial("land", _, _, _, _) => children match { + case Seq(ConstantValue(l: Boolean), ConstantValue(r: Boolean)) => + ConstantValue(l && r, TBoolean) + case _ => AbstractLattice.top + } + case Apply("contig", _, Seq(k), _, _) => children match { + case Seq(KeyField(0)) => Contig(k.typ.asInstanceOf[TLocus].rg) + case _ => AbstractLattice.top + } + case Apply("position", _, Seq(k), _, _) => children match { + case Seq(KeyField(0)) => Position(k.typ.asInstanceOf[TLocus].rg) + case _ => AbstractLattice.top + } case _ => AbstractLattice.top } - case _ => AbstractLattice.top - } - private def computeBoolean(x: IR, children: IndexedSeq[AbstractValue], keySet: KeySet): BoolValue = (x, children) match { + private def computeBoolean(x: IR, children: IndexedSeq[AbstractValue], keySet: KeySet) + : BoolValue = (x, children) match { case (False(), 
_) => BoolValue.allFalse(keySet) case (True(), _) => BoolValue.allTrue(keySet) case (IsNA(_), Seq(KeyField(0))) => BoolValue( - KeySet(Interval(endpoint(null, -1), posInf)), - KeySet(Interval(negInf, endpoint(null, -1))), - KeySetLattice.bottom) - .restrict(keySet) + KeySet(Interval(endpoint(null, -1), posInf)), + KeySet(Interval(negInf, endpoint(null, -1))), + KeySetLattice.bottom, + ) + .restrict(keySet) case (IsNA(_), Seq(b: BoolValue)) => b.isNA.restrict(keySet) // collection contains - case (ApplyIR("contains", _, _, _, _), Seq(ConstantValue(collectionVal), queryVal)) if literalSizeOkay(collectionVal) => + case (ApplyIR("contains", _, _, _, _), Seq(ConstantValue(collectionVal), queryVal)) + if literalSizeOkay(collectionVal) => if (collectionVal == null) { BoolValue.allNA(keySet) } else queryVal match { case Contig(rgStr) => val rg = ctx.stateManager.referenceGenomes(rgStr) val intervals = intervalsFromLiteralContigs(collectionVal, rg) - BoolValue(intervals, KeySetLattice.complement(intervals), KeySetLattice.bottom).restrict(keySet) + BoolValue(intervals, KeySetLattice.complement(intervals), KeySetLattice.bottom).restrict( + keySet + ) case KeyField(0) => val intervals = intervalsFromLiteral(collectionVal, firstKeyOrd.toOrdering, true) - BoolValue(intervals, KeySetLattice.complement(intervals), KeySetLattice.bottom).restrict(keySet) + BoolValue(intervals, KeySetLattice.complement(intervals), KeySetLattice.bottom).restrict( + keySet + ) case struct: StructValue if struct.isKeyPrefix => val intervals = intervalsFromLiteral(collectionVal, keyOrd.toOrdering, false) - BoolValue(intervals, KeySetLattice.complement(intervals), KeySetLattice.bottom).restrict(keySet) + BoolValue(intervals, KeySetLattice.complement(intervals), KeySetLattice.bottom).restrict( + keySet + ) case _ => BoolValue.top(keySet) } // interval contains @@ -776,22 +863,24 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { (intervalVal: @unchecked) match { case null => BoolValue.allNA(keySet) case i: Interval => queryVal match { - case KeyField(0) => - val l = IntervalEndpoint(Row(i.left.point), i.left.sign) - val r = IntervalEndpoint(Row(i.right.point), i.right.sign) - BoolValue( - KeySet(Interval(l, r)), - KeySet(Interval(negInf, l), Interval(r, endpoint(null, -1))), - KeySet(Interval(endpoint(null, -1), posInf))) - .restrict(keySet) - case struct: StructValue if struct.isKeyPrefix => - BoolValue( - KeySet(i), - KeySet(Interval(negInf, i.left), Interval(i.right, posInf)), - KeySet.empty) - .restrict(keySet) - case _ => BoolValue.top(keySet) - } + case KeyField(0) => + val l = IntervalEndpoint(Row(i.left.point), i.left.sign) + val r = IntervalEndpoint(Row(i.right.point), i.right.sign) + BoolValue( + KeySet(Interval(l, r)), + KeySet(Interval(negInf, l), Interval(r, endpoint(null, -1))), + KeySet(Interval(endpoint(null, -1), posInf)), + ) + .restrict(keySet) + case struct: StructValue if struct.isKeyPrefix => + BoolValue( + KeySet(i), + KeySet(Interval(negInf, i.left), Interval(i.right, posInf)), + KeySet.empty, + ) + .restrict(keySet) + case _ => BoolValue.top(keySet) + } } case (ApplyComparisonOp(op, _, _), Seq(l, r)) => AbstractLattice.compare(l, r, op, keySet) @@ -806,7 +895,8 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { case class Rewrites( replaceWithTrue: mutable.Set[RefEquality[IR]], - replaceWithFalse: mutable.Set[RefEquality[IR]]) + replaceWithFalse: mutable.Set[RefEquality[IR]], + ) private def _analyze(x: IR, env: AbstractEnv, rewrites: Option[Rewrites]): 
AbstractValue = { def recur(x: IR, env: AbstractEnv = env): AbstractValue = @@ -815,15 +905,21 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { var res: AbstractLattice.Value = if (env.keySet == KeySetLattice.bottom) AbstractLattice.bottom else x match { - case Let(bindings, body) => - recur(body, bindings.foldLeft(env) { case (env, (name, value)) => - env.bind(name -> recur(value, env)) - }) + case Block(bindings, body) => + val newEnv = bindings.foldLeft[Option[AbstractEnv]](Some(env)) { + case (Some(env), Binding(name, value, Scope.EVAL)) => + Some(env.bind(name -> recur(value, env))) + case _ => None + } + newEnv match { + case Some(env) => recur(body, env) + case None => null + } case Ref(name, _) => env(name) case GetField(o, name) => recur(o).asInstanceOf[StructValue](name) case MakeStruct(fields) => StructValue(fields.view.map { case (name, field) => - name -> recur(field) - }) + name -> recur(field) + }) case SelectFields(old, fields) => val oldVal = recur(old) StructValue(fields.view.map(name => name -> oldVal.asInstanceOf[StructValue](name))) @@ -831,11 +927,15 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { val c = recur(cond).asInstanceOf[BoolValue] val res = AbstractLattice.join( recur(cnsq, env.restrict(c.trueBound)), - recur(altr, env.restrict(c.falseBound))) + recur(altr, env.restrict(c.falseBound)), + ) // If cond is missing, then the result of If is missing. But only // our boolean abstraction tracks missingness. if (x.typ == TBoolean) - AbstractLattice.join(res, BoolValue(KeySetLattice.bottom, KeySetLattice.bottom, c.naBound)) + AbstractLattice.join( + res, + BoolValue(KeySetLattice.bottom, KeySetLattice.bottom, c.naBound), + ) else res case Switch(y_, default_, cases_) => @@ -845,7 +945,10 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { case _ => val res = cases_.foldLeft(recur(default_))((e, c) => AbstractLattice.join(e, recur(c))) if (x.typ == TBoolean) - AbstractLattice.join(res, BoolValue(KeySetLattice.bottom, KeySetLattice.bottom, KeySetLattice.top)) + AbstractLattice.join( + res, + BoolValue(KeySetLattice.bottom, KeySetLattice.bottom, KeySetLattice.top), + ) else res } @@ -853,11 +956,17 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { case StreamFold(a, zero, accumName, valueName, body) => recur(a) match { case ConstantValue(array) => array.asInstanceOf[Iterable[Any]] .foldLeft(recur(zero)) { (accum, value) => - recur(body, env.bind(accumName -> accum, valueName -> ConstantValue(value, TIterable.elementType(a.typ)))) + recur( + body, + env.bind( + accumName -> accum, + valueName -> ConstantValue(value, TIterable.elementType(a.typ)), + ), + ) } case _ => AbstractLattice.top } - case x@Coalesce(values) => + case x @ Coalesce(values) => val aVals = values.map(recur(_)) if (x.typ == TBoolean) { aVals.asInstanceOf[Seq[BoolValue]].reduce(BoolValue.coalesce) @@ -890,18 +999,37 @@ class ExtractIntervalFilters(ctx: ExecuteContext, keyType: TStruct) { res match { case res: BoolValue if x.typ == TBoolean => - assert(KeySetLattice.specializes(res.trueBound, env.keySet), s"\n trueBound = ${res.trueBound}\n env = ${env.keySet}\n ir = ${Pretty.sexprStyle(x, allowUnboundRefs = true)}") - assert(KeySetLattice.specializes(res.falseBound, env.keySet), s"\n falseBound = ${res.falseBound}\n env = ${env.keySet}") - assert(KeySetLattice.specializes(res.naBound, env.keySet), s"\n naBound = ${res.naBound}\n env = ${env.keySet}") + assert( + KeySetLattice.specializes(res.trueBound, 
env.keySet), + s"\n trueBound = ${res.trueBound}\n env = ${env.keySet}\n ir = ${Pretty.sexprStyle(x, allowUnboundRefs = true)}", + ) + assert( + KeySetLattice.specializes(res.falseBound, env.keySet), + s"\n falseBound = ${res.falseBound}\n env = ${env.keySet}", + ) + assert( + KeySetLattice.specializes(res.naBound, env.keySet), + s"\n naBound = ${res.naBound}\n env = ${env.keySet}", + ) case _ => } rewrites.foreach { rw => if (x.typ == TBoolean) { val bool = res.asInstanceOf[BoolValue] - if (KeySetLattice.meet(KeySetLattice.join(bool.falseBound, bool.naBound), env.keySet) == KeySetLattice.bottom) + if ( + KeySetLattice.meet( + KeySetLattice.join(bool.falseBound, bool.naBound), + env.keySet, + ) == KeySetLattice.bottom + ) rw.replaceWithTrue += RefEquality(x) - else if (KeySetLattice.meet(KeySetLattice.join(bool.trueBound, bool.naBound), env.keySet) == KeySetLattice.bottom) { + else if ( + KeySetLattice.meet( + KeySetLattice.join(bool.trueBound, bool.naBound), + env.keySet, + ) == KeySetLattice.bottom + ) { rw.replaceWithFalse += RefEquality(x) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/FoldConstants.scala b/hail/src/main/scala/is/hail/expr/ir/FoldConstants.scala index 37dc84cfdbd..2ec474e40ae 100644 --- a/hail/src/main/scala/is/hail/expr/ir/FoldConstants.scala +++ b/hail/src/main/scala/is/hail/expr/ir/FoldConstants.scala @@ -1,53 +1,54 @@ package is.hail.expr.ir import is.hail.backend.ExecuteContext -import is.hail.types.virtual.TStream +import is.hail.types.virtual.{TStream, TVoid} import is.hail.utils.HailException object FoldConstants { def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR = - ExecuteContext.scopedNewRegion(ctx) { ctx => - foldConstants(ctx, ir) - } + ExecuteContext.scopedNewRegion(ctx)(ctx => foldConstants(ctx, ir)) + private def foldConstants(ctx: ExecuteContext, ir: BaseIR): BaseIR = - RewriteBottomUp(ir, { - case _: Ref | - _: In | - _: RelationalRef | - _: RelationalLet | - _: ApplySeeded | - _: UUID4 | - _: ApplyAggOp | - _: ApplyScanOp | - _: AggLet | - _: Begin | - _: MakeNDArray | - _: NDArrayShape | - _: NDArrayReshape | - _: NDArrayConcat | - _: NDArraySlice | - _: NDArrayFilter | - _: NDArrayMap | - _: NDArrayMap2 | - _: NDArrayReindex | - _: NDArrayAgg | - _: NDArrayWrite | - _: NDArrayMatMul | - _: Trap | - _: Die | - _: RNGStateLiteral => None - case ir: IR if ir.typ.isInstanceOf[TStream] => None - case ir: IR if !IsConstant(ir) && - Interpretable(ir) && - ir.children.forall { - case c: IR => IsConstant(c) - case _ => false - } => - try { - Some(Literal.coerce(ir.typ, Interpret.alreadyLowered(ctx, ir))) - } catch { - case _: HailException => None - } - case _ => None - }) + RewriteBottomUp( + ir, + { + case _: Ref | + _: In | + _: RelationalRef | + _: RelationalLet | + _: ApplySeeded | + _: UUID4 | + _: ApplyAggOp | + _: ApplyScanOp | + _: MakeNDArray | + _: NDArrayShape | + _: NDArrayReshape | + _: NDArrayConcat | + _: NDArraySlice | + _: NDArrayFilter | + _: NDArrayMap | + _: NDArrayMap2 | + _: NDArrayReindex | + _: NDArrayAgg | + _: NDArrayWrite | + _: NDArrayMatMul | + _: Trap | + _: Die | + _: RNGStateLiteral => None + case ir: IR if ir.typ.isInstanceOf[TStream] || ir.typ == TVoid => None + case ir: IR + if !IsConstant(ir) && + Interpretable(ir) && + ir.children.forall { + case c: IR => IsConstant(c) + case _ => false + } => + try + Some(Literal.coerce(ir.typ, Interpret.alreadyLowered(ctx, ir))) + catch { + case _: HailException => None + } + case _ => None + }, + ) } diff --git a/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala 
b/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala index d78c0630ea3..58425ebcee3 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ForwardLets.scala @@ -1,6 +1,7 @@ package is.hail.expr.ir import is.hail.backend.ExecuteContext +import is.hail.types.virtual.TVoid import is.hail.utils.BoxedArrayBuilder import scala.collection.Set @@ -13,52 +14,55 @@ object ForwardLets { def rewrite(ir: BaseIR, env: BindingEnv[IR]): BaseIR = { - def shouldForward(value: IR, refs: Set[RefEquality[BaseRef]], base: IR): Boolean = { - value.isInstanceOf[Ref] || - value.isInstanceOf[In] || - (IsConstant(value) && !value.isInstanceOf[Str]) || - refs.isEmpty || - (refs.size == 1 && - nestingDepth.lookup(refs.head) == nestingDepth.lookup(base) && - !ContainsScan(value) && - !ContainsAgg(value)) && + def shouldForward(value: IR, refs: Set[RefEquality[BaseRef]], base: Block, scope: Int) + : Boolean = { + IsPure(value) && ( + value.isInstanceOf[Ref] || + value.isInstanceOf[In] || + (IsConstant(value) && !value.isInstanceOf[Str]) || + refs.isEmpty || + (refs.size == 1 && + nestingDepth.lookupRef(refs.head) == nestingDepth.lookupBinding(base, scope) && + !ContainsScan(value) && + !ContainsAgg(value)) && !ContainsAggIntermediate(value) + ) } ir match { - case l@Let(bindings, body) => - val keep = new BoxedArrayBuilder[(String, IR)] - val refs = uses(ir) - val newEnv = bindings.foldLeft(env) { case (env, (name, value)) => - val rewriteValue = rewrite(value, env).asInstanceOf[IR] - if (shouldForward(rewriteValue, refs.filter(_.t.name == name), l)) - env.bindEval(name -> rewriteValue) - else {keep += (name -> rewriteValue); env} + case l: Block => + val keep = new BoxedArrayBuilder[Binding] + val refs = uses(l) + val newEnv = l.bindings.foldLeft(env) { + case (env, Binding(name, value, scope)) => + val rewriteValue = rewrite(value, env.promoteScope(scope)).asInstanceOf[IR] + if ( + rewriteValue.typ != TVoid + && shouldForward(rewriteValue, refs.filter(_.t.name == name), l, scope) + ) { + env.bindInScope(name, rewriteValue, scope) + } else { + keep += Binding(name, rewriteValue, scope) + env + } } - val newBody = rewrite(body, newEnv).asInstanceOf[IR] + val newBody = rewrite(l.body, newEnv).asInstanceOf[IR] if (keep.isEmpty) newBody - else Let(keep.result(), newBody) + else Block(keep.result(), newBody) - case l@AggLet(name, value, body, isScan) => - val refs = uses.lookup(ir) - val rewriteValue = rewrite(value, if (isScan) env.promoteScan else env.promoteAgg).asInstanceOf[IR] - if (shouldForward(rewriteValue, refs, l)) - if (isScan) - rewrite(body, env.copy(scan = Some(env.scan.get.bind(name -> rewriteValue)))) - else - rewrite(body, env.copy(agg = Some(env.agg.get.bind(name -> rewriteValue)))) - else - AggLet(name, rewriteValue, rewrite(body, env).asInstanceOf[IR], isScan) - case x@Ref(name, _) => + case x @ Ref(name, _) => env.eval .lookupOption(name) - .map { forwarded => if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy() else forwarded } + .map { forwarded => + if (uses.lookup(defs.lookup(x)).count(_.t.name == name) > 1) forwarded.deepCopy() + else forwarded + } .getOrElse(x) case _ => - ir.mapChildrenWithIndex { (ir1, i) => - rewrite(ir1, ChildEnvWithoutBindings(ir, i, env)) - } + ir.mapChildrenWithIndex((ir1, i) => + rewrite(ir1, Bindings.segregated(ir, i, env).childEnvWithoutBindings) + ) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/ForwardRelationalLets.scala 
b/hail/src/main/scala/is/hail/expr/ir/ForwardRelationalLets.scala index 19617827186..4e356844772 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ForwardRelationalLets.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ForwardRelationalLets.scala @@ -19,9 +19,9 @@ object ForwardRelationalLets { usages(name) = (0, 0) case RelationalLetBlockMatrix(name, _, _) => usages(name) = (0, 0) - case x@RelationalRef(name, _) => + case x @ RelationalRef(name, _) => val (n, nd) = usages(name) - usages(name) = (n + 1, math.max(nd, nestingDepth.lookup(x))) + usages(name) = (n + 1, math.max(nd, nestingDepth.lookupRef(x))) case _ => } ir1.children.foreach(visit) @@ -47,18 +47,30 @@ object ForwardRelationalLets { if (shouldForward(usages(name))) { m(name) = recur(value).asInstanceOf[IR] recur(body) - } else RelationalLetTable(name, recur(value).asInstanceOf[IR], recur(body).asInstanceOf[TableIR]) + } else RelationalLetTable( + name, + recur(value).asInstanceOf[IR], + recur(body).asInstanceOf[TableIR], + ) case RelationalLetMatrixTable(name, value, body) => if (shouldForward(usages(name))) { m(name) = recur(value).asInstanceOf[IR] recur(body) - } else RelationalLetMatrixTable(name, recur(value).asInstanceOf[IR], recur(body).asInstanceOf[MatrixIR]) + } else RelationalLetMatrixTable( + name, + recur(value).asInstanceOf[IR], + recur(body).asInstanceOf[MatrixIR], + ) case RelationalLetBlockMatrix(name, value, body) => if (shouldForward(usages(name))) { m(name) = recur(value).asInstanceOf[IR] recur(body) - } else RelationalLetBlockMatrix(name, recur(value).asInstanceOf[IR], recur(body).asInstanceOf[BlockMatrixIR]) - case x@RelationalRef(name, _) => + } else RelationalLetBlockMatrix( + name, + recur(value).asInstanceOf[IR], + recur(body).asInstanceOf[BlockMatrixIR], + ) + case x @ RelationalRef(name, _) => m.getOrElse(name, x) case _ => ir1.mapChildren(recur) } diff --git a/hail/src/main/scala/is/hail/expr/ir/FreeVariables.scala b/hail/src/main/scala/is/hail/expr/ir/FreeVariables.scala index d2e4f7c8a0d..cde037e4c41 100644 --- a/hail/src/main/scala/is/hail/expr/ir/FreeVariables.scala +++ b/hail/src/main/scala/is/hail/expr/ir/FreeVariables.scala @@ -17,7 +17,7 @@ object FreeVariables { val aE = compute(a, baseEnv) val qE = compute(query, baseEnv.copy(agg = Some(Env.empty))) aE.merge(qE.copy(eval = qE.eval.bindIterable(qE.agg.get.m - name), agg = baseEnv.agg)) - case ApplyAggOp(init, seq, sig) => + case ApplyAggOp(init, seq, _) => val initEnv = baseEnv.copy(agg = None) val initFreeVars = init.iterator.map(x => compute(x, initEnv)).fold(initEnv)(_.merge(_)) .copy(agg = Some(Env.empty[Unit])) @@ -26,7 +26,7 @@ object FreeVariables { val e = compute(x, seqEnv) e.copy(eval = Env.empty[Unit], agg = Some(e.eval)) }.fold(initFreeVars)(_.merge(_)) - case ApplyScanOp(init, seq, sig) => + case ApplyScanOp(init, seq, _) => val initEnv = baseEnv.copy(scan = None) val initFreeVars = init.iterator.map(x => compute(x, initEnv)).fold(initEnv)(_.merge(_)) .copy(scan = Some(Env.empty[Unit])) @@ -38,26 +38,33 @@ object FreeVariables { case AggFold(zero, seqOp, combOp, accumName, otherAccumName, isScan) => val zeroEnv = if (isScan) baseEnv.copy(scan = None) else baseEnv.copy(agg = None) val zeroFreeVarsCompute = compute(zero, zeroEnv) - val zeroFreeVars = if (isScan) zeroFreeVarsCompute.copy(scan = Some(Env.empty[Unit])) else zeroFreeVarsCompute.copy(agg = Some(Env.empty[Unit])) + val zeroFreeVars = if (isScan) zeroFreeVarsCompute.copy(scan = Some(Env.empty[Unit])) + else zeroFreeVarsCompute.copy(agg = Some(Env.empty[Unit])) val seqOpEnv = 
if (isScan) baseEnv.promoteScan else baseEnv.promoteAgg val seqOpFreeVarsCompute = compute(seqOp, seqOpEnv) val seqOpFreeVars = if (isScan) { - seqOpFreeVarsCompute.copy(eval = Env.empty[Unit], scan = Some(seqOpFreeVarsCompute.eval)) + seqOpFreeVarsCompute.copy( + eval = Env.empty[Unit], + scan = Some(seqOpFreeVarsCompute.eval), + ) } else { seqOpFreeVarsCompute.copy(eval = Env.empty[Unit], agg = Some(seqOpFreeVarsCompute.eval)) } val combEval = Env.fromSeq(IndexedSeq((accumName, {}), (otherAccumName, {}))) - val combOpFreeVarsCompute = compute(combOp, baseEnv.copy(eval=combEval)) - val combOpFreeVars = combOpFreeVarsCompute.copy(eval = Env.empty[Unit], scan = Some(combOpFreeVarsCompute.eval)) + val combOpFreeVarsCompute = compute(combOp, baseEnv.copy(eval = combEval)) + val combOpFreeVars = combOpFreeVarsCompute.copy( + eval = Env.empty[Unit], + scan = Some(combOpFreeVarsCompute.eval), + ) zeroFreeVars.merge(seqOpFreeVars).merge(combOpFreeVars) case _ => ir1.children .zipWithIndex .map { case (child: IR, i) => - val childEnv = ChildEnvWithoutBindings(ir1, i, baseEnv) - val sub = compute(child, childEnv) - .subtract(NewBindings(ir1, i, childEnv)) + val bindings = Bindings.segregated(ir1, i, baseEnv) + val childEnv = bindings.childEnvWithoutBindings + val sub = compute(child, childEnv).subtract(bindings.newBindings) if (UsesAggEnv(ir1, i)) sub.copy(eval = Env.empty[Unit], agg = Some(sub.eval), scan = baseEnv.scan) else if (UsesScanEnv(ir1, i)) @@ -71,8 +78,13 @@ object FreeVariables { } } - compute(ir, BindingEnv(Env.empty, - if (supportsAgg) Some(Env.empty[Unit]) else None, - if (supportsScan) Some(Env.empty[Unit]) else None)) + compute( + ir, + BindingEnv( + Env.empty, + if (supportsAgg) Some(Env.empty[Unit]) else None, + if (supportsScan) Some(Env.empty[Unit]) else None, + ), + ) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/GenericLines.scala b/hail/src/main/scala/is/hail/expr/ir/GenericLines.scala index 6fe9b601778..8adb4fe75eb 100644 --- a/hail/src/main/scala/is/hail/expr/ir/GenericLines.scala +++ b/hail/src/main/scala/is/hail/expr/ir/GenericLines.scala @@ -2,20 +2,20 @@ package is.hail.expr.ir import is.hail.backend.spark.SparkBackend import is.hail.io.compress.BGzipInputStream -import is.hail.io.fs.{FS, FileListEntry, Positioned, PositionedInputStream, BGZipCompressionCodec} -import is.hail.io.tabix.{TabixReader, TabixLineIterator} +import is.hail.io.fs.{BGZipCompressionCodec, FS, FileStatus, Positioned, PositionedInputStream} +import is.hail.io.tabix.{TabixLineIterator, TabixReader} import is.hail.types.virtual.{TBoolean, TInt32, TInt64, TString, TStruct, Type} import is.hail.utils._ import is.hail.variant.Locus +import scala.annotation.meta.param + import org.apache.commons.io.input.{CountingInputStream, ProxyInputStream} import org.apache.hadoop.io.compress.SplittableCompressionCodec import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -import scala.annotation.meta.param - trait CloseableIterator[T] extends Iterator[T] with AutoCloseable object CloseableIterator { @@ -27,7 +27,8 @@ object CloseableIterator { } object GenericLines { - def read(fs: FS, contexts: IndexedSeq[Any], gzAsBGZ: Boolean, filePerPartition: Boolean): GenericLines = { + def read(fs: FS, contexts: IndexedSeq[Any], gzAsBGZ: Boolean, filePerPartition: Boolean) + : GenericLines = { val body: (FS, Any) => CloseableIterator[GenericLine] = { (fs: FS, context: Any) => val contextRow = context.asInstanceOf[Row] @@ -50,7 +51,8 @@ object GenericLines { } 
else if (codec == BGZipCompressionCodec) { assert(split || filePerPartition) splitCompressed = true - val bgzIS = new BGzipInputStream(rawIS, start, end, SplittableCompressionCodec.READ_MODE.BYBLOCK) + val bgzIS = + new BGzipInputStream(rawIS, start, end, SplittableCompressionCodec.READ_MODE.BYBLOCK) new ProxyInputStream(bgzIS) with Positioned { def getPosition: Long = bgzIS.getVirtualOffset } @@ -65,14 +67,14 @@ object GenericLines { private var eof = false private var closed = false - private var buf = new Array[Byte](64 * 1024) + private val buf = new Array[Byte](64 * 1024) private var bufOffset = 0L private var bufMark = 0 private var bufPos = 0 private var realEnd = if (splitCompressed) - -1L // end really means first block >= end + -1L // end really means first block >= end else end @@ -141,10 +143,12 @@ object GenericLines { eol = true } else { // look for end of line in buf - while (bufPos < bufMark && { - val c = buf(bufPos) - c != '\n' && c != '\r' - }) + while ( + bufPos < bufMark && { + val c = buf(bufPos) + c != '\n' && c != '\r' + } + ) bufPos += 1 if (bufPos < bufMark) { @@ -173,10 +177,12 @@ object GenericLines { val copySize = linePos.toLong + n // Maximum array size compatible with common JDK implementations - // https://github.com/openjdk/jdk14u/blob/84917a040a81af2863fddc6eace3dda3e31bf4b5/src/java.base/share/classes/jdk/internal/util/ArraysSupport.java#L577 + /* https://github.com/openjdk/jdk14u/blob/84917a040a81af2863fddc6eace3dda3e31bf4b5/src/java.base/share/classes/jdk/internal/util/ArraysSupport.java#L577 */ val maxArraySize = Int.MaxValue - 8 if (copySize > maxArraySize) - fatal(s"GenericLines: line size reached: cannot read a line with more than 2^31-1 bytes") + fatal( + s"GenericLines: line size reached: cannot read a line with more than 2^31-1 bytes" + ) val newSize = Math.min(copySize * 2, maxArraySize).toInt if (newSize > (1 << 20)) { log.info(s"GenericLines: growing line buffer to $newSize") @@ -227,12 +233,11 @@ object GenericLines { line } - def close(): Unit = { + def close(): Unit = if (!closed) { is.close() closed = true } - } } } @@ -242,26 +247,27 @@ object GenericLines { "file" -> TString, "start" -> TInt64, "end" -> TInt64, - "split" -> TBoolean) + "split" -> TBoolean, + ) new GenericLines( contextType, contexts, - body) + body, + ) } - def read( fs: FS, - fileListEntries0: IndexedSeq[FileListEntry], + fileStatuses0: IndexedSeq[_ <: FileStatus], nPartitions: Option[Int], blockSizeInMB: Option[Int], minPartitions: Option[Int], gzAsBGZ: Boolean, allowSerialRead: Boolean, - filePerPartition: Boolean = false + filePerPartition: Boolean = false, ): GenericLines = { - val fileListEntries = fileListEntries0.zipWithIndex.filter(_._1.getLen > 0) - val totalSize = fileListEntries.map(_._1.getLen).sum + val fileStatuses = fileStatuses0.zipWithIndex.filter(_._1.getLen > 0) + val totalSize = fileStatuses.map(_._1.getLen).sum var totalPartitions = nPartitions match { case Some(nPartitions) => nPartitions @@ -276,7 +282,7 @@ object GenericLines { case None => } - val contexts = fileListEntries.flatMap { case (fileListEntry, fileNum) => + val contexts = fileStatuses.flatMap { case (fileListEntry, fileNum) => val size = fileListEntry.getLen val codec = fs.getCodecFromPath(fileListEntry.getPath, gzAsBGZ) @@ -298,7 +304,7 @@ object GenericLines { } } else { if (!allowSerialRead && !filePerPartition) - fatal(s"Cowardly refusing to read file serially: ${ fileListEntry.getPath }.") + fatal(s"Cowardly refusing to read file serially: ${fileListEntry.getPath}.") 
Iterator.single { Row(0, fileNum, fileListEntry.getPath, 0L, size, false) @@ -309,7 +315,12 @@ object GenericLines { GenericLines.read(fs, contexts, gzAsBGZ, filePerPartition) } - def readTabix(fs: FS, path: String, contigMapping: Map[String, String], partitions: IndexedSeq[Interval]): GenericLines = { + def readTabix( + fs: FS, + path: String, + contigMapping: Map[String, String], + partitions: IndexedSeq[Interval], + ): GenericLines = { val reverseContigMapping: Map[String, String] = contigMapping.toArray .groupBy(_._2) @@ -317,8 +328,8 @@ object GenericLines { if (mappings.length > 1) fatal(s"contig_recoding may not map multiple contigs to the same target contig, " + s"due to ambiguity when querying the tabix index." + - s"\n Duplicate mappings: ${ mappings.map(_._1).mkString(",") } all map to ${ target }") - (target, mappings.head._1) + s"\n Duplicate mappings: ${mappings.map(_._1).mkString(",")} all map to $target") + (target, mappings.head._1) }.toMap val contexts = partitions.zipWithIndex.map { case (interval, i) => val start = interval.start.asInstanceOf[Row].getAs[Locus](0) @@ -328,7 +339,6 @@ object GenericLines { } val body: (FS, Any) => CloseableIterator[GenericLine] = { (fs: FS, context: Any) => val contextRow = context.asInstanceOf[Row] - val index = contextRow.getAs[Int](0) val file = contextRow.getAs[String](1) val chrom = contextRow.getAs[String](2) val start = contextRow.getAs[Int](3) @@ -368,7 +378,7 @@ object GenericLines { val pos = s.substring(t1 + 1, t2).toInt if (chr != chrom) { - throw new RuntimeException(s"bad chromosome! ${chrom}, $s") + throw new RuntimeException(s"bad chromosome! $chrom, $s") } start <= pos && pos <= end } @@ -384,17 +394,15 @@ object GenericLines { "file" -> TString, "chrom" -> TString, "start" -> TInt32, - "end" -> TInt32) + "end" -> TInt32, + ) new GenericLines(contextType, contexts, body) } - def collect(fs: FS, lines: GenericLines): IndexedSeq[String] = { + def collect(fs: FS, lines: GenericLines): IndexedSeq[String] = lines.contexts.flatMap { context => - using(lines.body(fs, context)) { it => - it.map(_.toString).toArray - } + using(lines.body(fs, context))(it => it.map(_.toString).toArray) } - } } class GenericLine( @@ -403,7 +411,8 @@ class GenericLine( // possibly virtual private var _offset: Long, var data: Array[Byte], - private var _lineLength: Int) { + private var _lineLength: Int, +) { def this(file: String, fileNum: Int) = this(file, fileNum, 0, null, 0) private var _str: String = null @@ -438,8 +447,8 @@ class GenericLine( class GenericLinesRDDPartition(val index: Int, val context: Any) extends Partition class GenericLinesRDD( - @(transient@param) contexts: IndexedSeq[Any], - body: (Any) => CloseableIterator[GenericLine] + @(transient @param) contexts: IndexedSeq[Any], + body: (Any) => CloseableIterator[GenericLine], ) extends RDD[GenericLine](SparkBackend.sparkContext("GenericLinesRDD"), Seq()) { protected def getPartitions: Array[Partition] = @@ -449,9 +458,7 @@ class GenericLinesRDD( def compute(split: Partition, context: TaskContext): Iterator[GenericLine] = { val it = body(split.asInstanceOf[GenericLinesRDDPartition].context) - TaskContext.get.addTaskCompletionListener[Unit] { _ => - it.close() - } + TaskContext.get.addTaskCompletionListener[Unit](_ => it.close()) it } } @@ -459,7 +466,8 @@ class GenericLinesRDD( class GenericLines( val contextType: Type, val contexts: IndexedSeq[Any], - val body: (FS, Any) => CloseableIterator[GenericLine]) { + val body: (FS, Any) => CloseableIterator[GenericLine], +) { def nPartitions: 
Int = contexts.length diff --git a/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala b/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala index cfc04a596fc..11b890c8a7b 100644 --- a/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala +++ b/hail/src/main/scala/is/hail/expr/ir/GenericTableValue.scala @@ -1,35 +1,32 @@ package is.hail.expr.ir -import is.hail.annotations.{BroadcastRow, Region} +import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext -import is.hail.backend.spark.SparkBackend import is.hail.expr.ir.functions.UtilFunctions import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.expr.ir.streams.StreamProducer import is.hail.io.fs.FS import is.hail.rvd._ -import is.hail.sparkextras.ContextRDD +import is.hail.types.{RStruct, TableType, TypeWithRequiredness} +import is.hail.types.physical.PStruct import is.hail.types.physical.stypes.EmitType import is.hail.types.physical.stypes.concrete.{SStackStruct, SStackStructValue} -import is.hail.types.physical.stypes.interfaces.{SBaseStructValue, SStream, SStreamValue, primitive} -import is.hail.types.physical.stypes.primitives.{SInt64, SInt64Value} -import is.hail.types.physical.{PStruct, PType} +import is.hail.types.physical.stypes.interfaces.{primitive, SBaseStructValue, SStreamValue} +import is.hail.types.physical.stypes.primitives.SInt64 import is.hail.types.virtual.{TArray, TInt32, TInt64, TStruct, TTuple, Type} -import is.hail.types.{RStruct, TableType, TypeWithRequiredness} import is.hail.utils._ -import org.apache.spark.rdd.RDD + import org.apache.spark.sql.Row -import org.apache.spark.{Partition, TaskContext} -import org.json4s.JsonAST.{JObject, JString} import org.json4s.{Extraction, JValue} +import org.json4s.JsonAST.{JObject, JString} class PartitionIteratorLongReader( rowType: TStruct, override val uidFieldName: String, override val contextType: TStruct, bodyPType: TStruct => PStruct, - body: TStruct => (Region, HailClassLoader, FS, Any) => Iterator[Long] + body: TStruct => (Region, HailClassLoader, FS, Any) => Iterator[Long], ) extends PartitionReader { assert(contextType.hasField("partitionIndex")) assert(contextType.fieldType("partitionIndex") == TInt32) @@ -51,7 +48,8 @@ class PartitionIteratorLongReader( cb: EmitCodeBuilder, mb: EmitMethodBuilder[_], context: EmitCode, - requestedType: TStruct): IEmitCode = { + requestedType: TStruct, + ): IEmitCode = { val insertUID: Boolean = requestedType.hasField(uidFieldName) && !rowType.hasField(uidFieldName) @@ -63,11 +61,13 @@ class PartitionIteratorLongReader( val eltPType = bodyPType(concreteType) val uidSType: SStackStruct = SStackStruct( TTuple(TInt64, TInt64), - Array(EmitType(SInt64, true), EmitType(SInt64, true))) + Array(EmitType(SInt64, true), EmitType(SInt64, true)), + ) context.toI(cb).map(cb) { case _ctxStruct: SBaseStructValue => val ctxStruct = cb.memoizeField(_ctxStruct, "ctxStruct") - val partIdx = cb.memoizeField(_ctxStruct.loadField(cb, "partitionIndex").get(cb).asInt.value.toL) + val partIdx = + cb.memoizeField(_ctxStruct.loadField(cb, "partitionIndex").getOrAssert(cb).asInt.value.toL) val rowIdx = mb.genFieldThisRef[Long]("pnr_rowidx") val region = mb.genFieldThisRef[Region]("pilr_region") val it = mb.genFieldThisRef[Iterator[java.lang.Long]]("pilr_it") @@ -79,32 +79,53 @@ class PartitionIteratorLongReader( override def initialize(cb: EmitCodeBuilder, partitionRegion: Value[Region]): Unit = { val ctxJavaValue = UtilFunctions.svalueToJavaValue(cb, partitionRegion, 
ctxStruct) - cb.assign(it, cb.emb.getObject(body(requestedType)) - .invoke[java.lang.Object, java.lang.Object, java.lang.Object, java.lang.Object, Iterator[java.lang.Long]]( - "apply", region, cb.emb.getHailClassLoader, cb.emb.getFS, ctxJavaValue)) + cb.assign( + it, + cb.emb.getObject(body(requestedType)) + .invoke[ + java.lang.Object, + java.lang.Object, + java.lang.Object, + java.lang.Object, + Iterator[java.lang.Long], + ]( + "apply", + region, + cb.emb.getHailClassLoader, + cb.emb.getFS, + ctxJavaValue, + ), + ) } override val elementRegion: Settable[Region] = region override val requiresMemoryManagementPerElement: Boolean = true override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - cb.if_(!it.get.hasNext, - cb.goto(LendOfStream)) + cb.if_(!it.get.hasNext, cb.goto(LendOfStream)) cb.assign(rv, Code.longValue(it.get.next())) cb.assign(rowIdx, rowIdx + 1L) - cb.goto(LproduceElementDone) } - override val element: EmitCode = EmitCode.fromI(mb)(cb => IEmitCode.present(cb, - if (insertUID) { - val uid = EmitValue.present( - new SStackStructValue(uidSType, Array( - EmitValue.present(primitive(partIdx)), - EmitValue.present(primitive(rowIdx))))) - eltPType.loadCheapSCode(cb, rv)._insert(requestedType, uidFieldName -> uid) - } else { - eltPType.loadCheapSCode(cb, rv) - })) + override val element: EmitCode = EmitCode.fromI(mb)(cb => + IEmitCode.present( + cb, + if (insertUID) { + val uid = EmitValue.present( + new SStackStructValue( + uidSType, + Array( + EmitValue.present(primitive(partIdx)), + EmitValue.present(primitive(rowIdx)), + ), + ) + ) + eltPType.loadCheapSCode(cb, rv)._insert(requestedType, uidFieldName -> uid) + } else { + eltPType.loadCheapSCode(cb, rv) + }, + ) + ) override def close(cb: EmitCodeBuilder): Unit = {} } @@ -113,21 +134,23 @@ class PartitionIteratorLongReader( } } - def toJValue: JValue = { + def toJValue: JValue = JObject( "category" -> JString("PartitionIteratorLongReader"), "fullRowType" -> Extraction.decompose(fullRowType)(PartitionReader.formats), "uidFieldName" -> JString(uidFieldName), - "contextType" -> Extraction.decompose(contextType)(PartitionReader.formats)) - } + "contextType" -> Extraction.decompose(contextType)(PartitionReader.formats), + ) } abstract class LoweredTableReaderCoercer { - def coerce(ctx: ExecuteContext, + def coerce( + ctx: ExecuteContext, globals: IR, contextType: Type, contexts: IndexedSeq[Any], - body: IR => IR): TableStage + body: IR => IR, + ): TableStage } class GenericTableValue( @@ -138,14 +161,17 @@ class GenericTableValue( val contextType: TStruct, var contexts: IndexedSeq[Any], val bodyPType: TStruct => PStruct, - val body: TStruct => (Region, HailClassLoader, FS, Any) => Iterator[Long]) { + val body: TStruct => (Region, HailClassLoader, FS, Any) => Iterator[Long], +) { assert(fullTableType.rowType.hasField(uidFieldName), s"uid=$uidFieldName, t=$fullTableType") assert(contextType.hasField("partitionIndex")) assert(contextType.fieldType("partitionIndex") == TInt32) private var ltrCoercer: LoweredTableReaderCoercer = _ - private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any): LoweredTableReaderCoercer = { + + private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any) + : LoweredTableReaderCoercer = { if (ltrCoercer == null) { ltrCoercer = LoweredTableReader.makeCoercer( ctx, @@ -158,19 +184,27 @@ class GenericTableValue( bodyPType, body, context, - cacheKey) + cacheKey, + ) } ltrCoercer } - def toTableStage(ctx: ExecuteContext, requestedType: TableType, context: 
String, cacheKey: Any): TableStage = { + def toTableStage(ctx: ExecuteContext, requestedType: TableType, context: String, cacheKey: Any) + : TableStage = { val globalsIR = Literal(requestedType.globalType, globals(requestedType.globalType)) - val requestedBody: (IR) => (IR) = (ctx: IR) => ReadPartition(ctx, - requestedType.rowType, - new PartitionIteratorLongReader( - fullTableType.rowType, uidFieldName, contextType, - (requestedType: Type) => bodyPType(requestedType.asInstanceOf[TStruct]), - (requestedType: Type) => body(requestedType.asInstanceOf[TStruct]))) + val requestedBody: (IR) => (IR) = (ctx: IR) => + ReadPartition( + ctx, + requestedType.rowType, + new PartitionIteratorLongReader( + fullTableType.rowType, + uidFieldName, + contextType, + (requestedType: Type) => bodyPType(requestedType.asInstanceOf[TStruct]), + (requestedType: Type) => body(requestedType.asInstanceOf[TStruct]), + ), + ) var p: RVDPartitioner = null partitioner match { case Some(partitioner) => diff --git a/hail/src/main/scala/is/hail/expr/ir/IR.scala b/hail/src/main/scala/is/hail/expr/ir/IR.scala index 1fcc1fb7ccf..e18a2056301 100644 --- a/hail/src/main/scala/is/hail/expr/ir/IR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/IR.scala @@ -8,33 +8,33 @@ import is.hail.expr.ir.agg.{AggStateSig, PhysicalAggSig} import is.hail.expr.ir.functions._ import is.hail.expr.ir.lowering.TableStageDependency import is.hail.expr.ir.streams.StreamProducer +import is.hail.io.{AbstractTypedCodecSpec, BufferSpec, TypedCodecSpec} import is.hail.io.avro.{AvroPartitionReader, AvroSchemaSerializer} import is.hail.io.bgen.BgenPartitionReader import is.hail.io.vcf.{GVCFPartitionReader, VCFHeaderInfo} -import is.hail.io.{AbstractTypedCodecSpec, BufferSpec, TypedCodecSpec} import is.hail.rvd.RVDSpecMaker +import is.hail.types.{tcoerce, RIterable, RStruct, TypeWithRequiredness} import is.hail.types.encoded._ import is.hail.types.physical._ import is.hail.types.physical.stypes._ import is.hail.types.physical.stypes.concrete.SJavaString import is.hail.types.physical.stypes.interfaces._ import is.hail.types.virtual._ -import is.hail.types.{RIterable, RStruct, TypeWithRequiredness, tcoerce} import is.hail.utils._ -import org.json4s.JsonAST.{JNothing, JString} -import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} import java.io.OutputStream -import scala.language.existentials + +import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} +import org.json4s.JsonAST.{JNothing, JString} sealed trait IR extends BaseIR { private var _typ: Type = null def typ: Type = { if (_typ == null) - try { + try _typ = InferType(this) - } catch { + catch { case e: Throwable => throw new RuntimeException(s"typ: inference failure:", e) } _typ @@ -43,12 +43,13 @@ sealed trait IR extends BaseIR { protected lazy val childrenSeq: IndexedSeq[BaseIR] = Children(this) - protected override def copy(newChildren: IndexedSeq[BaseIR]): IR = + override protected def copy(newChildren: IndexedSeq[BaseIR]): IR = Copy(this, newChildren) override def mapChildren(f: BaseIR => BaseIR): IR = super.mapChildren(f).asInstanceOf[IR] - override def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): IR = super.mapChildrenWithIndex(f).asInstanceOf[IR] + override def mapChildrenWithIndex(f: (BaseIR, Int) => BaseIR): IR = + super.mapChildrenWithIndex(f).asInstanceOf[IR] override def deepCopy(): this.type = { @@ -59,9 +60,9 @@ sealed trait IR extends BaseIR { } lazy val size: Int = 1 + children.map { - case x: IR => x.size - case _ => 0 - }.sum + case 
x: IR => x.size + case _ => 0 + }.sum private[this] def _unwrap: IR => IR = { case node: ApplyIR => MapIR(_unwrap)(node.explicitNode) @@ -103,9 +104,8 @@ final case class Literal(_typ: Type, value: Annotation) extends IR { } object EncodedLiteral { - def apply(codec: AbstractTypedCodecSpec, value: Array[Array[Byte]]): EncodedLiteral = { + def apply(codec: AbstractTypedCodecSpec, value: Array[Array[Byte]]): EncodedLiteral = EncodedLiteral(codec, new WrappedByteArrays(value)) - } def fromPTypeAndAddress(pt: PType, addr: Long, ctx: ExecuteContext): IR = { pt match { @@ -124,22 +124,21 @@ object EncodedLiteral { } } -final case class EncodedLiteral(codec: AbstractTypedCodecSpec, value: WrappedByteArrays) extends IR { +final case class EncodedLiteral(codec: AbstractTypedCodecSpec, value: WrappedByteArrays) + extends IR { require(!CanEmit(codec.encodedVirtualType)) require(value != null) } class WrappedByteArrays(val ba: Array[Array[Byte]]) { - override def hashCode(): Int = { - ba.foldLeft(31) { (h, b) => 37 * h + java.util.Arrays.hashCode(b) } - } + override def hashCode(): Int = + ba.foldLeft(31)((h, b) => 37 * h + java.util.Arrays.hashCode(b)) override def equals(obj: Any): Boolean = { this.eq(obj.asInstanceOf[AnyRef]) || { if (!obj.isInstanceOf[WrappedByteArrays]) { false - } - else { + } else { val other = obj.asInstanceOf[WrappedByteArrays] ba.length == other.ba.length && (ba, other.ba).zipped.forall(java.util.Arrays.equals) } @@ -151,9 +150,11 @@ final case class I32(x: Int) extends IR with TrivialIR final case class I64(x: Long) extends IR with TrivialIR final case class F32(x: Float) extends IR with TrivialIR final case class F64(x: Double) extends IR with TrivialIR + final case class Str(x: String) extends IR with TrivialIR { override def toString(): String = s"""Str("${StringEscapeUtils.escapeString(x)}")""" } + final case class True() extends IR with TrivialIR final case class False() extends IR with TrivialIR final case class Void() extends IR with TrivialIR @@ -188,24 +189,53 @@ final case class Switch(x: IR, default: IR, cases: IndexedSeq[IR]) extends IR { 2 + cases.length } -final case class AggLet(name: String, value: IR, body: IR, isScan: Boolean) extends IR -final case class Let(bindings: IndexedSeq[(String, IR)], body: IR) extends IR { - override lazy val size: Int = - bindings.length + 1 +object AggLet { + def apply(name: String, value: IR, body: IR, isScan: Boolean): IR = { + val scope = if (isScan) Scope.SCAN else Scope.AGG + Block(FastSeq(Binding(name, value, scope)), body) + } } object Let { - case class Extract(p: ((String, IR)) => Boolean) { - def unapply(bindings: IndexedSeq[(String, IR)]): - Option[(IndexedSeq[(String, IR)], IndexedSeq[(String, IR)])] = { - val idx = bindings.indexWhere(p) - if (idx == -1) None else Some(bindings.splitAt(idx)) + def apply(bindings: IndexedSeq[(String, IR)], body: IR): Block = + Block( + bindings.map { case (name, value) => Binding(name, value) }, + body, + ) + + def void(bindings: IndexedSeq[(String, IR)]): IR = { + if (bindings.isEmpty) { + Void() + } else { + assert(bindings.last._2.typ == TVoid) + Let(bindings.init, bindings.last._2) } } - object Nested extends Extract(_._2.isInstanceOf[Let]) - object Insert extends Extract(_._2.isInstanceOf[InsertFields]) +} + +case class Binding(name: String, value: IR, scope: Int = Scope.EVAL) + +final case class Block(bindings: IndexedSeq[Binding], body: IR) extends IR { + override lazy val size: Int = + bindings.length + 1 +} + +object Block { + object Insert { + def unapply(bindings: 
IndexedSeq[Binding]) + : Option[(IndexedSeq[Binding], Binding, IndexedSeq[Binding])] = { + val idx = bindings.indexWhere(_.value.isInstanceOf[InsertFields]) + if (idx == -1) None else Some((bindings.take(idx), bindings(idx), bindings.drop(idx + 1))) + } + } + object Nested { + def unapply(bindings: IndexedSeq[Binding]): Option[(Int, IndexedSeq[Binding])] = { + val idx = bindings.indexWhere(_.value.isInstanceOf[Block]) + if (idx == -1) None else Some((idx, bindings)) + } + } } sealed abstract class BaseRef extends IR with TrivialIR { @@ -220,12 +250,17 @@ final case class Ref(name: String, var _typ: Type) extends BaseRef { } } - // Recur can't exist outside of loop // Loops can be nested, but we can't call outer loops in terms of inner loops so there can only be one loop "active" in a given context -final case class TailLoop(name: String, params: IndexedSeq[(String, IR)], resultType: Type, body: IR) extends IR { +final case class TailLoop( + name: String, + params: IndexedSeq[(String, IR)], + resultType: Type, + body: IR, +) extends IR { lazy val paramIdx: Map[String, Int] = params.map(_._1).zipWithIndex.toMap } + final case class Recur(name: String, args: IndexedSeq[IR], var _typ: Type) extends BaseRef final case class RelationalLet(name: String, value: IR, body: IR) extends IR @@ -244,57 +279,91 @@ object MakeArray { def unify(ctx: ExecuteContext, args: IndexedSeq[IR], requestedType: TArray = null): MakeArray = { assert(requestedType != null || args.nonEmpty) - if(args.nonEmpty) + if (args.nonEmpty) if (args.forall(_.typ == args.head.typ)) return MakeArray(args, TArray(args.head.typ)) - MakeArray(args.map { arg => - val upcast = PruneDeadFields.upcast(ctx, arg, requestedType.elementType) - assert(upcast.typ == requestedType.elementType) - upcast - }, requestedType) + MakeArray( + args.map { arg => + val upcast = PruneDeadFields.upcast(ctx, arg, requestedType.elementType) + assert(upcast.typ == requestedType.elementType) + upcast + }, + requestedType, + ) } } final case class MakeArray(args: IndexedSeq[IR], _typ: TArray) extends IR object MakeStream { - def unify(ctx: ExecuteContext, args: IndexedSeq[IR], requiresMemoryManagementPerElement: Boolean = false, requestedType: TStream = null): MakeStream = { + def unify( + ctx: ExecuteContext, + args: IndexedSeq[IR], + requiresMemoryManagementPerElement: Boolean = false, + requestedType: TStream = null, + ): MakeStream = { assert(requestedType != null || args.nonEmpty) if (args.nonEmpty) if (args.forall(_.typ == args.head.typ)) return MakeStream(args, TStream(args.head.typ), requiresMemoryManagementPerElement) - MakeStream(args.map { arg => - val upcast = PruneDeadFields.upcast(ctx, arg, requestedType.elementType) - assert(upcast.typ == requestedType.elementType) - upcast - }, requestedType, requiresMemoryManagementPerElement) + MakeStream( + args.map { arg => + val upcast = PruneDeadFields.upcast(ctx, arg, requestedType.elementType) + assert(upcast.typ == requestedType.elementType) + upcast + }, + requestedType, + requiresMemoryManagementPerElement, + ) } } -final case class MakeStream(args: IndexedSeq[IR], _typ: TStream, requiresMemoryManagementPerElement: Boolean = false) extends IR +final case class MakeStream( + args: IndexedSeq[IR], + _typ: TStream, + requiresMemoryManagementPerElement: Boolean = false, +) extends IR object ArrayRef { def apply(a: IR, i: IR): ArrayRef = ArrayRef(a, i, ErrorIDs.NO_ERROR) } final case class ArrayRef(a: IR, i: IR, errorID: Int) extends IR -final case class ArraySlice(a: IR, start: IR, stop: Option[IR], 
step:IR = I32(1), errorID: Int = ErrorIDs.NO_ERROR) extends IR + +final case class ArraySlice( + a: IR, + start: IR, + stop: Option[IR], + step: IR = I32(1), + errorID: Int = ErrorIDs.NO_ERROR, +) extends IR + final case class ArrayLen(a: IR) extends IR final case class ArrayZeros(length: IR) extends IR -final case class ArrayMaximalIndependentSet(edges: IR, tieBreaker: Option[(String, String, IR)]) extends IR -/** - * [[StreamIota]] is an infinite stream producer, whose element is an integer starting at `start`, updated by - * `step` at each iteration. The name comes from APL: +final case class ArrayMaximalIndependentSet(edges: IR, tieBreaker: Option[(String, String, IR)]) + extends IR + +/** [[StreamIota]] is an infinite stream producer, whose element is an integer starting at `start`, + * updated by `step` at each iteration. The name comes from APL: * [[https://stackoverflow.com/questions/9244879/what-does-iota-of-stdiota-stand-for]] */ -final case class StreamIota(start: IR, step: IR, requiresMemoryManagementPerElement: Boolean = false) extends IR - -final case class StreamRange(start: IR, stop: IR, step: IR, requiresMemoryManagementPerElement: Boolean = false, - errorID: Int = ErrorIDs.NO_ERROR) extends IR +final case class StreamIota( + start: IR, + step: IR, + requiresMemoryManagementPerElement: Boolean = false, +) extends IR + +final case class StreamRange( + start: IR, + stop: IR, + step: IR, + requiresMemoryManagementPerElement: Boolean = false, + errorID: Int = ErrorIDs.NO_ERROR, +) extends IR object ArraySort { def apply(a: IR, ascending: IR = True(), onKey: Boolean = false): ArraySort = { @@ -304,15 +373,27 @@ object ArraySort { val compare = if (onKey) { val elementType = atyp.elementType.asInstanceOf[TBaseStruct] elementType match { - case t: TStruct => + case _: TStruct => val elt = tcoerce[TStruct](atyp.elementType) - ApplyComparisonOp(Compare(elt.types(0)), GetField(Ref(l, elt), elt.fieldNames(0)), GetField(Ref(r, atyp.elementType), elt.fieldNames(0))) - case t: TTuple => + ApplyComparisonOp( + Compare(elt.types(0)), + GetField(Ref(l, elt), elt.fieldNames(0)), + GetField(Ref(r, atyp.elementType), elt.fieldNames(0)), + ) + case _: TTuple => val elt = tcoerce[TTuple](atyp.elementType) - ApplyComparisonOp(Compare(elt.types(0)), GetTupleElement(Ref(l, elt), elt.fields(0).index), GetTupleElement(Ref(r, atyp.elementType), elt.fields(0).index)) + ApplyComparisonOp( + Compare(elt.types(0)), + GetTupleElement(Ref(l, elt), elt.fields(0).index), + GetTupleElement(Ref(r, atyp.elementType), elt.fields(0).index), + ) } } else { - ApplyComparisonOp(Compare(atyp.elementType), Ref(l, atyp.elementType), Ref(r, atyp.elementType)) + ApplyComparisonOp( + Compare(atyp.elementType), + Ref(l, atyp.elementType), + Ref(r, atyp.elementType), + ) } ArraySort(a, l, r, If(ascending, compare < 0, compare > 0)) @@ -325,10 +406,19 @@ final case class ToDict(a: IR) extends IR final case class ToArray(a: IR) extends IR final case class CastToArray(a: IR) extends IR final case class ToStream(a: IR, requiresMemoryManagementPerElement: Boolean = false) extends IR -final case class StreamBufferedAggregate(streamChild: IR, initAggs: IR, newKey: IR, seqOps: IR, name: String, - aggSignatures: IndexedSeq[PhysicalAggSig], bufferSize: Int) extends IR -final case class LowerBoundOnOrderedCollection(orderedCollection: IR, elem: IR, onKey: Boolean) extends IR +final case class StreamBufferedAggregate( + streamChild: IR, + initAggs: IR, + newKey: IR, + seqOps: IR, + name: String, + aggSignatures: 
IndexedSeq[PhysicalAggSig], + bufferSize: Int, +) extends IR + +final case class LowerBoundOnOrderedCollection(orderedCollection: IR, elem: IR, onKey: Boolean) + extends IR final case class GroupByKey(collection: IR) extends IR @@ -353,11 +443,22 @@ final case class StreamTake(a: IR, num: IR) extends IR final case class StreamDrop(a: IR, num: IR) extends IR // Generate, in ascending order, a uniform random sample, without replacement, of numToSample integers in the range [0, totalRange) -final case class SeqSample(totalRange: IR, numToSample: IR, rngState: IR, requiresMemoryManagementPerElement: Boolean) extends IR +final case class SeqSample( + totalRange: IR, + numToSample: IR, + rngState: IR, + requiresMemoryManagementPerElement: Boolean, +) extends IR // Take the child stream and sort each element into buckets based on the provided pivots. The first and last elements of // pivots are the endpoints of the first and last interval respectively, should not be contained in the dataset. -final case class StreamDistribute(child: IR, pivots: IR, path: IR, comparisonOp: ComparisonOp[_],spec: AbstractTypedCodecSpec) extends IR +final case class StreamDistribute( + child: IR, + pivots: IR, + path: IR, + comparisonOp: ComparisonOp[_], + spec: AbstractTypedCodecSpec, +) extends IR // "Whiten" a stream of vectors by regressing out from each vector all components // in the direction of vectors in the preceding window. For efficiency, takes @@ -365,7 +466,16 @@ final case class StreamDistribute(child: IR, pivots: IR, path: IR, comparisonOp: // Takes a stream of structs, with two designated fields: `prevWindow` is the // previous window (e.g. from the previous partition), if there is one, and // `newChunk` is the new chunk to whiten. -final case class StreamWhiten(stream: IR, newChunk: String, prevWindow: String, vecSize: Int, windowSize: Int, chunkSize: Int, blockSize: Int, normalizeAfterWhiten: Boolean) extends IR +final case class StreamWhiten( + stream: IR, + newChunk: String, + prevWindow: String, + vecSize: Int, + windowSize: Int, + chunkSize: Int, + blockSize: Int, + normalizeAfterWhiten: Boolean, +) extends IR object ArrayZipBehavior extends Enumeration { type ArrayZipBehavior = Value @@ -376,44 +486,67 @@ object ArrayZipBehavior extends Enumeration { } final case class StreamZip( - as: IndexedSeq[IR], names: IndexedSeq[String], body: IR, behavior: ArrayZipBehavior, - errorID: Int = ErrorIDs.NO_ERROR + as: IndexedSeq[IR], + names: IndexedSeq[String], + body: IR, + behavior: ArrayZipBehavior, + errorID: Int = ErrorIDs.NO_ERROR, ) extends TypedIR[TStream] -final case class StreamMultiMerge(as: IndexedSeq[IR], key: IndexedSeq[String]) extends TypedIR[TStream] - +final case class StreamMultiMerge(as: IndexedSeq[IR], key: IndexedSeq[String]) + extends TypedIR[TStream] final case class StreamZipJoinProducers( - contexts: IR, ctxName: String, makeProducer: IR, key: IndexedSeq[String], - curKey: String, curVals: String, joinF: IR + contexts: IR, + ctxName: String, + makeProducer: IR, + key: IndexedSeq[String], + curKey: String, + curVals: String, + joinF: IR, ) extends TypedIR[TStream] -/** - * The StreamZipJoin node assumes that input streams have distinct keys. If input streams - * do not have distinct keys, the key that is included in the result is undefined, but - * is likely the last. - */ +/** The StreamZipJoin node assumes that input streams have distinct keys. If input streams do not + * have distinct keys, the key that is included in the result is undefined, but is likely the last. 
+ */ final case class StreamZipJoin( - as: IndexedSeq[IR], key: IndexedSeq[String], curKey: String, curVals: String, joinF: IR + as: IndexedSeq[IR], + key: IndexedSeq[String], + curKey: String, + curVals: String, + joinF: IR, ) extends TypedIR[TStream] final case class StreamFilter(a: IR, name: String, cond: IR) extends TypedIR[TStream] final case class StreamFlatMap(a: IR, name: String, body: IR) extends TypedIR[TStream] -final case class StreamFold(a: IR, zero: IR, accumName: String, valueName: String, body: IR) extends IR +final case class StreamFold(a: IR, zero: IR, accumName: String, valueName: String, body: IR) + extends IR object StreamFold2 { - def apply(a: StreamFold): StreamFold2 = { - StreamFold2(a.a, FastSeq((a.accumName, a.zero)), a.valueName, FastSeq(a.body), Ref(a.accumName, a.zero.typ)) - } + def apply(a: StreamFold): StreamFold2 = + StreamFold2( + a.a, + FastSeq((a.accumName, a.zero)), + a.valueName, + FastSeq(a.body), + Ref(a.accumName, a.zero.typ), + ) } -final case class StreamFold2(a: IR, accum: IndexedSeq[(String, IR)], valueName: String, seq: IndexedSeq[IR], result: IR) extends IR { +final case class StreamFold2( + a: IR, + accum: IndexedSeq[(String, IR)], + valueName: String, + seq: IndexedSeq[IR], + result: IR, +) extends IR { assert(accum.length == seq.length) val nameIdx: Map[String, Int] = accum.map(_._1).zipWithIndex.toMap } -final case class StreamScan(a: IR, zero: IR, accumName: String, valueName: String, body: IR) extends IR +final case class StreamScan(a: IR, zero: IR, accumName: String, valueName: String, body: IR) + extends IR final case class StreamFor(a: IR, valueName: String, body: IR) extends IR @@ -422,21 +555,24 @@ final case class StreamAggScan(a: IR, name: String, query: IR) extends IR object StreamJoin { def apply( - left: IR, right: IR, - lKey: IndexedSeq[String], rKey: IndexedSeq[String], - l: String, r: String, + left: IR, + right: IR, + lKey: IndexedSeq[String], + rKey: IndexedSeq[String], + l: String, + r: String, joinF: IR, joinType: String, requiresMemoryManagement: Boolean, - rightKeyIsDistinct: Boolean = false + rightKeyIsDistinct: Boolean = false, ): IR = { val lType = tcoerce[TStream](left.typ) val rType = tcoerce[TStream](right.typ) val lEltType = tcoerce[TStruct](lType.elementType) val rEltType = tcoerce[TStruct](rType.elementType) - assert(lEltType.typeAfterSelectNames(lKey) isIsomorphicTo rEltType.typeAfterSelectNames(rKey)) + assert(lEltType.typeAfterSelectNames(lKey) isJoinableWith rEltType.typeAfterSelectNames(rKey)) - if(!rightKeyIsDistinct) { + if (!rightKeyIsDistinct) { val rightGroupedStream = StreamGroupByKey(right, rKey, missingEqual = false) val groupField = genUID() @@ -445,7 +581,7 @@ object StreamJoin { val rightGrouped = mapIR(rightGroupedStream) { group => bindIR(ToArray(group)) { array => bindIR(ArrayRef(array, 0)) { head => - MakeStruct(rKey.map { key => key -> GetField(head, key) } :+ groupField -> array) + MakeStruct(rKey.map(key => key -> GetField(head, key)) :+ groupField -> array) } } } @@ -453,21 +589,40 @@ object StreamJoin { val rElt = Ref(genUID(), tcoerce[TStream](rightGrouped.typ).elementType) val lElt = Ref(genUID(), lEltType) val makeTupleFromJoin = MakeStruct(FastSeq("left" -> lElt, "rightGroup" -> rElt)) - val joined = StreamJoinRightDistinct(left, rightGrouped, lKey, rKey, lElt.name, rElt.name, makeTupleFromJoin, joinType) + val joined = StreamJoinRightDistinct( + left, + rightGrouped, + lKey, + rKey, + lElt.name, + rElt.name, + makeTupleFromJoin, + joinType, + ) // joined is a stream of 
{leftElement, rightGroup} bindIR(MakeArray(NA(rEltType))) { missingSingleton => flatMapIR(joined) { x => - Let(FastSeq(l -> GetField(x, "left")), bindIR(GetField(GetField(x, "rightGroup"), groupField)) { rightElts => - joinType match { - case "left" | "outer" => StreamMap(ToStream(If(IsNA(rightElts), missingSingleton, rightElts), requiresMemoryManagement), r, joinF) - case "right" | "inner" => StreamMap(ToStream(rightElts, requiresMemoryManagement), r, joinF) - } - }) + Let( + FastSeq(l -> GetField(x, "left")), + bindIR(GetField(GetField(x, "rightGroup"), groupField)) { rightElts => + joinType match { + case "left" | "outer" => StreamMap( + ToStream( + If(IsNA(rightElts), missingSingleton, rightElts), + requiresMemoryManagement, + ), + r, + joinF, + ) + case "right" | "inner" => + StreamMap(ToStream(rightElts, requiresMemoryManagement), r, joinF) + } + }, + ) } } - } - else { + } else { val rElt = Ref(r, rEltType) val lElt = Ref(l, lEltType) StreamJoinRightDistinct(left, right, lKey, rKey, lElt.name, rElt.name, joinF, joinType) @@ -475,7 +630,34 @@ object StreamJoin { } } -final case class StreamJoinRightDistinct(left: IR, right: IR, lKey: IndexedSeq[String], rKey: IndexedSeq[String], l: String, r: String, joinF: IR, joinType: String) extends IR { +final case class StreamLeftIntervalJoin( + // input streams + left: IR, + right: IR, + + // names for joiner + lKeyFieldName: String, + rIntervalFieldName: String, + + // how to combine records + lname: String, + rname: String, + body: IR, +) extends IR { + override protected lazy val childrenSeq: IndexedSeq[BaseIR] = + FastSeq(left, right, body) +} + +final case class StreamJoinRightDistinct( + left: IR, + right: IR, + lKey: IndexedSeq[String], + rKey: IndexedSeq[String], + l: String, + r: String, + joinF: IR, + joinType: String, +) extends IR { def isIntervalJoin: Boolean = { if (rKey.size != 1) return false val lKeyTyp = tcoerce[TStruct](tcoerce[TStream](left.typ).elementType).fieldType(lKey(0)) @@ -485,7 +667,13 @@ final case class StreamJoinRightDistinct(left: IR, right: IR, lKey: IndexedSeq[S } } -final case class StreamLocalLDPrune(child: IR, r2Threshold: IR, windowSize: IR, maxQueueSize: IR, nSamples: IR) extends IR +final case class StreamLocalLDPrune( + child: IR, + r2Threshold: IR, + windowSize: IR, + maxQueueSize: IR, + nSamples: IR, +) extends IR sealed trait NDArrayIR extends TypedIR[TNDArray] { def elementTyp: Type = typ.elementType @@ -494,12 +682,15 @@ sealed trait NDArrayIR extends TypedIR[TNDArray] { object MakeNDArray { def fill(elt: IR, shape: IndexedSeq[IR], rowMajor: IR): MakeNDArray = { val flatSize: IR = if (shape.nonEmpty) - shape.reduce { (l, r) => l * r } + shape.reduce((l, r) => l * r) else 0L MakeNDArray( ToArray(mapIR(rangeIR(flatSize.toI))(_ => elt)), - MakeTuple.ordered(shape), rowMajor, ErrorIDs.NO_ERROR) + MakeTuple.ordered(shape), + rowMajor, + ErrorIDs.NO_ERROR, + ) } } @@ -516,7 +707,9 @@ final case class NDArraySlice(nd: IR, slices: IR) extends NDArrayIR final case class NDArrayFilter(nd: IR, keep: IndexedSeq[IR]) extends NDArrayIR final case class NDArrayMap(nd: IR, valueName: String, body: IR) extends NDArrayIR -final case class NDArrayMap2(l: IR, r: IR, lName: String, rName: String, body: IR, errorID: Int) extends NDArrayIR + +final case class NDArrayMap2(l: IR, r: IR, lName: String, rName: String, body: IR, errorID: Int) + extends NDArrayIR final case class NDArrayReindex(nd: IR, indexExpr: IndexedSeq[Int]) extends NDArrayIR final case class NDArrayAgg(nd: IR, axes: IndexedSeq[Int]) extends IR @@ -528,9 
+721,21 @@ object NDArrayQR { def pType(mode: String, req: Boolean): PType = { mode match { case "r" => PCanonicalNDArray(PFloat64Required, 2, req) - case "raw" => PCanonicalTuple(req, PCanonicalNDArray(PFloat64Required, 2, true), PCanonicalNDArray(PFloat64Required, 1, true)) - case "reduced" => PCanonicalTuple(req, PCanonicalNDArray(PFloat64Required, 2, true), PCanonicalNDArray(PFloat64Required, 2, true)) - case "complete" => PCanonicalTuple(req, PCanonicalNDArray(PFloat64Required, 2, true), PCanonicalNDArray(PFloat64Required, 2, true)) + case "raw" => PCanonicalTuple( + req, + PCanonicalNDArray(PFloat64Required, 2, true), + PCanonicalNDArray(PFloat64Required, 1, true), + ) + case "reduced" => PCanonicalTuple( + req, + PCanonicalNDArray(PFloat64Required, 2, true), + PCanonicalNDArray(PFloat64Required, 2, true), + ) + case "complete" => PCanonicalTuple( + req, + PCanonicalNDArray(PFloat64Required, 2, true), + PCanonicalNDArray(PFloat64Required, 2, true), + ) } } } @@ -538,9 +743,13 @@ object NDArrayQR { object NDArraySVD { def pTypes(computeUV: Boolean, req: Boolean): PType = { if (computeUV) { - PCanonicalTuple(req, PCanonicalNDArray(PFloat64Required, 2, true), PCanonicalNDArray(PFloat64Required, 1, true), PCanonicalNDArray(PFloat64Required, 2, true)) - } - else { + PCanonicalTuple( + req, + PCanonicalNDArray(PFloat64Required, 2, true), + PCanonicalNDArray(PFloat64Required, 1, true), + PCanonicalNDArray(PFloat64Required, 2, true), + ) + } else { PCanonicalNDArray(PFloat64Required, 1, req) } } @@ -552,17 +761,22 @@ object NDArrayInv { final case class NDArrayQR(nd: IR, mode: String, errorID: Int) extends IR -final case class NDArraySVD(nd: IR, fullMatrices: Boolean, computeUV: Boolean, errorID: Int) extends IR +final case class NDArraySVD(nd: IR, fullMatrices: Boolean, computeUV: Boolean, errorID: Int) + extends IR object NDArrayEigh { - def pTypes(eigvalsOnly: Boolean, req: Boolean): PType = { + def pTypes(eigvalsOnly: Boolean, req: Boolean): PType = if (eigvalsOnly) { PCanonicalNDArray(PFloat64Required, 1, req) } else { - PCanonicalTuple(req, PCanonicalNDArray(PFloat64Required, 1, true), PCanonicalNDArray(PFloat64Required, 2, true)) + PCanonicalTuple( + req, + PCanonicalNDArray(PFloat64Required, 1, true), + PCanonicalNDArray(PFloat64Required, 2, true), + ) } - } } + final case class NDArrayEigh(nd: IR, eigvalsOnly: Boolean, errorID: Int) extends IR final case class NDArrayInv(nd: IR, errorID: Int) extends IR @@ -573,14 +787,29 @@ final case class AggExplode(array: IR, name: String, aggBody: IR, isScan: Boolea final case class AggGroupBy(key: IR, aggIR: IR, isScan: Boolean) extends IR -final case class AggArrayPerElement(a: IR, elementName: String, indexName: String, aggBody: IR, knownLength: Option[IR], isScan: Boolean) extends IR +final case class AggArrayPerElement( + a: IR, + elementName: String, + indexName: String, + aggBody: IR, + knownLength: Option[IR], + isScan: Boolean, +) extends IR object ApplyAggOp { def apply(op: AggOp, initOpArgs: IR*)(seqOpArgs: IR*): ApplyAggOp = - ApplyAggOp(initOpArgs.toIndexedSeq, seqOpArgs.toIndexedSeq, AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ))) + ApplyAggOp( + initOpArgs.toIndexedSeq, + seqOpArgs.toIndexedSeq, + AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ)), + ) } -final case class ApplyAggOp(initOpArgs: IndexedSeq[IR], seqOpArgs: IndexedSeq[IR], aggSig: AggSignature) extends IR { +final case class ApplyAggOp( + initOpArgs: IndexedSeq[IR], + seqOpArgs: IndexedSeq[IR], + aggSig: AggSignature, +) extends IR { def 
nSeqOpArgs = seqOpArgs.length @@ -603,11 +832,12 @@ object AggFold { minAndMaxHelper(element, keyType, StructGT(keyType, sortFields)) } - def all(element: IR): IR = { + def all(element: IR): IR = aggFoldIR(True(), element) { case (accum, element) => ApplySpecial("land", Seq.empty[Type], Seq(accum, element), TBoolean, ErrorIDs.NO_ERROR) - } { case (accum1, accum2) => ApplySpecial("land", Seq.empty[Type], Seq(accum1, accum2), TBoolean, ErrorIDs.NO_ERROR) } - } + } { case (accum1, accum2) => + ApplySpecial("land", Seq.empty[Type], Seq(accum1, accum2), TBoolean, ErrorIDs.NO_ERROR) + } private def minAndMaxHelper(element: IR, keyType: TStruct, comp: ComparisonOp[Boolean]): IR = { val keyFields = keyType.fields.map(_.name) @@ -618,29 +848,54 @@ object AggFold { val aggFoldMinAccumRef1 = Ref(aggFoldMinAccumName1, keyType) val aggFoldMinAccumRef2 = Ref(aggFoldMinAccumName2, keyType) val minSeq = bindIR(SelectFields(element, keyFields)) { keyOfCurElementRef => - If(IsNA(aggFoldMinAccumRef1), + If( + IsNA(aggFoldMinAccumRef1), keyOfCurElementRef, - If(ApplyComparisonOp(comp, aggFoldMinAccumRef1, keyOfCurElementRef), aggFoldMinAccumRef1, keyOfCurElementRef) + If( + ApplyComparisonOp(comp, aggFoldMinAccumRef1, keyOfCurElementRef), + aggFoldMinAccumRef1, + keyOfCurElementRef, + ), ) } val minComb = - If(IsNA(aggFoldMinAccumRef1), + If( + IsNA(aggFoldMinAccumRef1), aggFoldMinAccumRef2, - If (ApplyComparisonOp(comp, aggFoldMinAccumRef1, aggFoldMinAccumRef2), aggFoldMinAccumRef1, aggFoldMinAccumRef2) + If( + ApplyComparisonOp(comp, aggFoldMinAccumRef1, aggFoldMinAccumRef2), + aggFoldMinAccumRef1, + aggFoldMinAccumRef2, + ), ) AggFold(minAndMaxZero, minSeq, minComb, aggFoldMinAccumName1, aggFoldMinAccumName2, false) } } -final case class AggFold(zero: IR, seqOp: IR, combOp: IR, accumName: String, otherAccumName: String, isScan: Boolean) extends IR +final case class AggFold( + zero: IR, + seqOp: IR, + combOp: IR, + accumName: String, + otherAccumName: String, + isScan: Boolean, +) extends IR object ApplyScanOp { def apply(op: AggOp, initOpArgs: IR*)(seqOpArgs: IR*): ApplyScanOp = - ApplyScanOp(initOpArgs.toIndexedSeq, seqOpArgs.toIndexedSeq, AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ))) + ApplyScanOp( + initOpArgs.toIndexedSeq, + seqOpArgs.toIndexedSeq, + AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ)), + ) } -final case class ApplyScanOp(initOpArgs: IndexedSeq[IR], seqOpArgs: IndexedSeq[IR], aggSig: AggSignature) extends IR { +final case class ApplyScanOp( + initOpArgs: IndexedSeq[IR], + seqOpArgs: IndexedSeq[IR], + aggSig: AggSignature, +) extends IR { def nSeqOpArgs = seqOpArgs.length @@ -652,24 +907,52 @@ final case class ApplyScanOp(initOpArgs: IndexedSeq[IR], seqOpArgs: IndexedSeq[I final case class InitOp(i: Int, args: IndexedSeq[IR], aggSig: PhysicalAggSig) extends IR final case class SeqOp(i: Int, args: IndexedSeq[IR], aggSig: PhysicalAggSig) extends IR final case class CombOp(i1: Int, i2: Int, aggSig: PhysicalAggSig) extends IR + object ResultOp { - def makeTuple(aggs: IndexedSeq[PhysicalAggSig]) = { + def makeTuple(aggs: IndexedSeq[PhysicalAggSig]) = MakeTuple.ordered(aggs.zipWithIndex.map { case (aggSig, index) => ResultOp(index, aggSig) }) - } } + final case class ResultOp(idx: Int, aggSig: PhysicalAggSig) extends IR -private final case class CombOpValue(i: Int, value: IR, aggSig: PhysicalAggSig) extends IR +final private case class CombOpValue(i: Int, value: IR, aggSig: PhysicalAggSig) extends IR final case class AggStateValue(i: Int, aggSig: AggStateSig) extends 
IR final case class InitFromSerializedValue(i: Int, value: IR, aggSig: AggStateSig) extends IR -final case class SerializeAggs(startIdx: Int, serializedIdx: Int, spec: BufferSpec, aggSigs: IndexedSeq[AggStateSig]) extends IR -final case class DeserializeAggs(startIdx: Int, serializedIdx: Int, spec: BufferSpec, aggSigs: IndexedSeq[AggStateSig]) extends IR +final case class SerializeAggs( + startIdx: Int, + serializedIdx: Int, + spec: BufferSpec, + aggSigs: IndexedSeq[AggStateSig], +) extends IR + +final case class DeserializeAggs( + startIdx: Int, + serializedIdx: Int, + spec: BufferSpec, + aggSigs: IndexedSeq[AggStateSig], +) extends IR final case class RunAgg(body: IR, result: IR, signature: IndexedSeq[AggStateSig]) extends IR -final case class RunAggScan(array: IR, name: String, init: IR, seqs: IR, result: IR, signature: IndexedSeq[AggStateSig]) extends IR + +final case class RunAggScan( + array: IR, + name: String, + init: IR, + seqs: IR, + result: IR, + signature: IndexedSeq[AggStateSig], +) extends IR + +object Begin { + def apply(xs: IndexedSeq[IR]): IR = + if (xs.isEmpty) + Void() + else + Let(xs.init.map(x => ("__void", x)), xs.last) +} final case class Begin(xs: IndexedSeq[IR]) extends IR final case class MakeStruct(fields: IndexedSeq[(String, IR)]) extends IR @@ -678,36 +961,48 @@ final case class SelectFields(old: IR, fields: IndexedSeq[String]) extends IR object InsertFields { def apply(old: IR, fields: Seq[(String, IR)]): InsertFields = InsertFields(old, fields, None) } -final case class InsertFields(old: IR, fields: Seq[(String, IR)], fieldOrder: Option[IndexedSeq[String]]) extends TypedIR[TStruct] + +final case class InsertFields( + old: IR, + fields: Seq[(String, IR)], + fieldOrder: Option[IndexedSeq[String]], +) extends TypedIR[TStruct] object GetFieldByIdx { - def apply(s: IR, field: Int): IR = { + def apply(s: IR, field: Int): IR = (s.typ: @unchecked) match { case t: TStruct => GetField(s, t.fieldNames(field)) case _: TTuple => GetTupleElement(s, field) } - } } final case class GetField(o: IR, name: String) extends IR object MakeTuple { - def ordered(types: IndexedSeq[IR]): MakeTuple = MakeTuple(types.zipWithIndex.map { case (ir, i) => (i, ir) }) + def ordered(types: IndexedSeq[IR]): MakeTuple = MakeTuple(types.zipWithIndex.map { case (ir, i) => + (i, ir) + }) } final case class MakeTuple(fields: IndexedSeq[(Int, IR)]) extends IR final case class GetTupleElement(o: IR, idx: Int) extends IR object In { - def apply(i: Int, typ: Type): In = In(i, SingleCodeEmitParamType(false, typ match { - case TInt32 => Int32SingleCodeType - case TInt64 => Int64SingleCodeType - case TFloat32 => Float32SingleCodeType - case TFloat64 => Float64SingleCodeType - case TBoolean => BooleanSingleCodeType - case ts: TStream => throw new UnsupportedOperationException - case t => PTypeReferenceSingleCodeType(PType.canonical(t)) - })) + def apply(i: Int, typ: Type): In = In( + i, + SingleCodeEmitParamType( + false, + typ match { + case TInt32 => Int32SingleCodeType + case TInt64 => Int64SingleCodeType + case TFloat32 => Float32SingleCodeType + case TFloat64 => Float64SingleCodeType + case TBoolean => BooleanSingleCodeType + case _: TStream => throw new UnsupportedOperationException + case t => PTypeReferenceSingleCodeType(PType.canonical(t)) + }, + ), + ) } // Function Input @@ -719,17 +1014,22 @@ object Die { def apply(message: String, typ: Type, errorId: Int): Die = Die(Str(message), typ, errorId) } -/** - * the Trap node runs the `child` node with an exception handler. 
If the child - * throws a HailException (user exception), then we return the tuple ((msg, errorId), NA). - * If the child throws any other exception, we raise that exception. If the - * child does not throw, then we return the tuple (NA, child value). +/** the Trap node runs the `child` node with an exception handler. If the child throws a + * HailException (user exception), then we return the tuple ((msg, errorId), NA). If the child + * throws any other exception, we raise that exception. If the child does not throw, then we return + * the tuple (NA, child value). */ final case class Trap(child: IR) extends IR final case class Die(message: IR, _typ: Type, errorId: Int) extends IR final case class ConsoleLog(message: IR, result: IR) extends IR -final case class ApplyIR(function: String, typeArgs: Seq[Type], args: Seq[IR], returnType: Type, errorID: Int) extends IR { +final case class ApplyIR( + function: String, + typeArgs: Seq[Type], + args: Seq[IR], + returnType: Type, + errorID: Int, +) extends IR { var conversion: (Seq[Type], Seq[IR], Int) => IR = _ var inline: Boolean = _ @@ -750,18 +1050,38 @@ sealed abstract class AbstractApplyNode[F <: JVMFunction] extends IR { def returnType: Type def typeArgs: Seq[Type] def argTypes: Seq[Type] = args.map(_.typ) - lazy val implementation: F = IRFunctionRegistry.lookupFunctionOrFail(function, returnType, typeArgs, argTypes) - .asInstanceOf[F] -} -final case class Apply(function: String, typeArgs: Seq[Type], args: Seq[IR], returnType: Type, errorID: Int) extends AbstractApplyNode[UnseededMissingnessObliviousJVMFunction] + lazy val implementation: F = + IRFunctionRegistry.lookupFunctionOrFail(function, returnType, typeArgs, argTypes) + .asInstanceOf[F] +} -final case class ApplySeeded(function: String, _args: Seq[IR], rngState: IR, staticUID: Long, returnType: Type) extends AbstractApplyNode[UnseededMissingnessObliviousJVMFunction] { +final case class Apply( + function: String, + typeArgs: Seq[Type], + args: Seq[IR], + returnType: Type, + errorID: Int, +) extends AbstractApplyNode[UnseededMissingnessObliviousJVMFunction] + +final case class ApplySeeded( + function: String, + _args: Seq[IR], + rngState: IR, + staticUID: Long, + returnType: Type, +) extends AbstractApplyNode[UnseededMissingnessObliviousJVMFunction] { val args = rngState +: _args val typeArgs: Seq[Type] = Seq.empty[Type] } -final case class ApplySpecial(function: String, typeArgs: Seq[Type], args: Seq[IR], returnType: Type, errorID: Int) extends AbstractApplyNode[UnseededMissingnessAwareJVMFunction] +final case class ApplySpecial( + function: String, + typeArgs: Seq[Type], + args: Seq[IR], + returnType: Type, + errorID: Int, +) extends AbstractApplyNode[UnseededMissingnessAwareJVMFunction] final case class LiftMeOut(child: IR) extends IR final case class TableCount(child: TableIR) extends IR @@ -771,57 +1091,84 @@ final case class MatrixAggregate(child: MatrixIR, query: IR) extends IR final case class TableWrite(child: TableIR, writer: TableWriter) extends IR -final case class TableMultiWrite(_children: IndexedSeq[TableIR], writer: WrappedMatrixNativeMultiWriter) extends IR +final case class TableMultiWrite( + _children: IndexedSeq[TableIR], + writer: WrappedMatrixNativeMultiWriter, +) extends IR final case class TableGetGlobals(child: TableIR) extends IR final case class TableCollect(child: TableIR) extends IR final case class MatrixWrite(child: MatrixIR, writer: MatrixWriter) extends IR -final case class MatrixMultiWrite(_children: IndexedSeq[MatrixIR], writer: 
MatrixNativeMultiWriter) extends IR +final case class MatrixMultiWrite(_children: IndexedSeq[MatrixIR], writer: MatrixNativeMultiWriter) + extends IR final case class TableToValueApply(child: TableIR, function: TableToValueFunction) extends IR final case class MatrixToValueApply(child: MatrixIR, function: MatrixToValueFunction) extends IR -final case class BlockMatrixToValueApply(child: BlockMatrixIR, function: BlockMatrixToValueFunction) extends IR + +final case class BlockMatrixToValueApply(child: BlockMatrixIR, function: BlockMatrixToValueFunction) + extends IR final case class BlockMatrixCollect(child: BlockMatrixIR) extends NDArrayIR final case class BlockMatrixWrite(child: BlockMatrixIR, writer: BlockMatrixWriter) extends IR -final case class BlockMatrixMultiWrite(blockMatrices: IndexedSeq[BlockMatrixIR], writer: BlockMatrixMultiWriter) extends IR - -final case class CollectDistributedArray(contexts: IR, globals: IR, cname: String, gname: String, body: IR, dynamicID: IR, staticID: String, tsd: Option[TableStageDependency] = None) extends IR +final case class BlockMatrixMultiWrite( + blockMatrices: IndexedSeq[BlockMatrixIR], + writer: BlockMatrixMultiWriter, +) extends IR + +final case class CollectDistributedArray( + contexts: IR, + globals: IR, + cname: String, + gname: String, + body: IR, + dynamicID: IR, + staticID: String, + tsd: Option[TableStageDependency] = None, +) extends IR object PartitionReader { - implicit val formats: Formats = new DefaultFormats() { - override val typeHints = ShortTypeHints(List( - classOf[PartitionRVDReader], - classOf[PartitionNativeReader], - classOf[PartitionNativeReaderIndexed], - classOf[PartitionNativeIntervalReader], - classOf[PartitionZippedNativeReader], - classOf[PartitionZippedIndexedNativeReader], - classOf[BgenPartitionReader], - classOf[GVCFPartitionReader], - classOf[TextInputFilterAndReplace], - classOf[VCFHeaderInfo], - classOf[AbstractTypedCodecSpec], - classOf[TypedCodecSpec], - classOf[AvroPartitionReader]), - typeHintFieldName = "name") + BufferSpec.shortTypeHints - } + - new TStructSerializer + - new TypeSerializer + - new PTypeSerializer + - new ETypeSerializer + - new AvroSchemaSerializer + implicit val formats: Formats = + new DefaultFormats() { + override val typeHints = ShortTypeHints( + List( + classOf[PartitionRVDReader], + classOf[PartitionNativeReader], + classOf[PartitionNativeReaderIndexed], + classOf[PartitionNativeIntervalReader], + classOf[PartitionZippedNativeReader], + classOf[PartitionZippedIndexedNativeReader], + classOf[BgenPartitionReader], + classOf[GVCFPartitionReader], + classOf[TextInputFilterAndReplace], + classOf[VCFHeaderInfo], + classOf[AbstractTypedCodecSpec], + classOf[TypedCodecSpec], + classOf[AvroPartitionReader], + ), + typeHintFieldName = "name", + ) + BufferSpec.shortTypeHints + } + + new TStructSerializer + + new TypeSerializer + + new PTypeSerializer + + new ETypeSerializer + + new AvroSchemaSerializer def extract(ctx: ExecuteContext, jv: JValue): PartitionReader = { (jv \ "name").extract[String] match { case "PartitionNativeIntervalReader" => val path = (jv \ "path").extract[String] val spec = TableNativeReader.read(ctx.fs, path, None).spec - PartitionNativeIntervalReader(ctx.stateManager, path, spec, (jv \ "uidFieldName").extract[String]) + PartitionNativeIntervalReader( + ctx.stateManager, + path, + spec, + (jv \ "uidFieldName").extract[String], + ) case "GVCFPartitionReader" => val header = VCFHeaderInfo.fromJSON((jv \ "header")) val callFields = (jv \ "callFields").extract[Set[String]] 
@@ -836,7 +1183,8 @@ object PartitionReader { val filterAndReplace = (jv \ "filterAndReplace").extract[TextInputFilterAndReplace] val entriesFieldName = (jv \ "entriesFieldName").extract[String] val uidFieldName = (jv \ "uidFieldName").extract[String] - GVCFPartitionReader(header, callFields, entryFloatType, arrayElementsRequired, rg, contigRecoding, + GVCFPartitionReader(header, callFields, entryFloatType, arrayElementsRequired, rg, + contigRecoding, skipInvalidLoci, filterAndReplace, entriesFieldName, uidFieldName) case _ => jv.extract[PartitionReader] } @@ -844,43 +1192,50 @@ object PartitionReader { } object PartitionWriter { - implicit val formats: Formats = new DefaultFormats() { - override val typeHints = ShortTypeHints(List( - classOf[PartitionNativeWriter], - classOf[TableTextPartitionWriter], - classOf[VCFPartitionWriter], - classOf[GenSampleWriter], - classOf[GenVariantWriter], - classOf[AbstractTypedCodecSpec], - classOf[TypedCodecSpec]), typeHintFieldName = "name" - ) + BufferSpec.shortTypeHints - } + - new TStructSerializer + - new TypeSerializer + - new PTypeSerializer + - new PStructSerializer + - new ETypeSerializer + implicit val formats: Formats = + new DefaultFormats() { + override val typeHints = ShortTypeHints( + List( + classOf[PartitionNativeWriter], + classOf[TableTextPartitionWriter], + classOf[VCFPartitionWriter], + classOf[GenSampleWriter], + classOf[GenVariantWriter], + classOf[AbstractTypedCodecSpec], + classOf[TypedCodecSpec], + ), + typeHintFieldName = "name", + ) + BufferSpec.shortTypeHints + } + + new TStructSerializer + + new TypeSerializer + + new PTypeSerializer + + new PStructSerializer + + new ETypeSerializer } object MetadataWriter { - implicit val formats: Formats = new DefaultFormats() { - override val typeHints = ShortTypeHints(List( - classOf[RVDSpecWriter], - classOf[TableSpecWriter], - classOf[RelationalWriter], - classOf[TableTextFinalizer], - classOf[VCFExportFinalizer], - classOf[SimpleMetadataWriter], - classOf[RVDSpecMaker], - classOf[AbstractTypedCodecSpec], - classOf[TypedCodecSpec]), - typeHintFieldName = "name" - ) + BufferSpec.shortTypeHints - } + - new TStructSerializer + - new TypeSerializer + - new PTypeSerializer + - new ETypeSerializer + implicit val formats: Formats = + new DefaultFormats() { + override val typeHints = ShortTypeHints( + List( + classOf[RVDSpecWriter], + classOf[TableSpecWriter], + classOf[RelationalWriter], + classOf[TableTextFinalizer], + classOf[VCFExportFinalizer], + classOf[SimpleMetadataWriter], + classOf[RVDSpecMaker], + classOf[AbstractTypedCodecSpec], + classOf[TypedCodecSpec], + ), + typeHintFieldName = "name", + ) + BufferSpec.shortTypeHints + } + + new TStructSerializer + + new TypeSerializer + + new PTypeSerializer + + new ETypeSerializer } abstract class PartitionReader { @@ -899,7 +1254,7 @@ abstract class PartitionReader { cb: EmitCodeBuilder, mb: EmitMethodBuilder[_], context: EmitCode, - requestedType: TStruct + requestedType: TStruct, ): IEmitCode def toJValue: JValue @@ -911,11 +1266,17 @@ abstract class PartitionWriter { cb: EmitCodeBuilder, stream: StreamProducer, context: EmitCode, - region: Value[Region]): IEmitCode + region: Value[Region], + ): IEmitCode def ctxType: Type def returnType: Type - def unionTypeRequiredness(r: TypeWithRequiredness, ctxType: TypeWithRequiredness, streamType: RIterable): Unit + + def unionTypeRequiredness( + r: TypeWithRequiredness, + ctxType: TypeWithRequiredness, + streamType: RIterable, + ): Unit def toJValue: JValue = 
Extraction.decompose(this)(PartitionWriter.formats) } @@ -923,17 +1284,33 @@ abstract class PartitionWriter { abstract class SimplePartitionWriter extends PartitionWriter { def ctxType: Type = TString def returnType: Type = TString - def unionTypeRequiredness(r: TypeWithRequiredness, ctxType: TypeWithRequiredness, streamType: RIterable): Unit = { + + def unionTypeRequiredness( + r: TypeWithRequiredness, + ctxType: TypeWithRequiredness, + streamType: RIterable, + ): Unit = { r.union(ctxType.required) r.union(streamType.required) } - def consumeElement(cb: EmitCodeBuilder, element: EmitCode, os: Value[OutputStream], region: Value[Region]): Unit + def consumeElement( + cb: EmitCodeBuilder, + element: EmitCode, + os: Value[OutputStream], + region: Value[Region], + ): Unit + def preConsume(cb: EmitCodeBuilder, os: Value[OutputStream]): Unit = () def postConsume(cb: EmitCodeBuilder, os: Value[OutputStream]): Unit = () - final def consumeStream(ctx: ExecuteContext, cb: EmitCodeBuilder, stream: StreamProducer, - context: EmitCode, region: Value[Region]): IEmitCode = { + final def consumeStream( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + stream: StreamProducer, + context: EmitCode, + region: Value[Region], + ): IEmitCode = { context.toI(cb).map(cb) { case ctx: SStringValue => val filename = ctx.loadString(cb) val os = cb.memoize(cb.emb.create(filename)) @@ -954,17 +1331,20 @@ abstract class SimplePartitionWriter extends PartitionWriter { abstract class MetadataWriter { def annotationType: Type + def writeMetadata( writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, - region: Value[Region]): Unit + region: Value[Region], + ): Unit def toJValue: JValue = Extraction.decompose(this)(MetadataWriter.formats) } final case class SimpleMetadataWriter(val annotationType: Type) extends MetadataWriter { - def writeMetadata(writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, region: Value[Region]): Unit = - writeAnnotations.consume(cb, {}, {_ => ()}) + def writeMetadata(writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, region: Value[Region]) + : Unit = + writeAnnotations.consume(cb, {}, _ => ()) } final case class ReadPartition(context: IR, rowType: TStruct, reader: PartitionReader) extends IR @@ -972,7 +1352,13 @@ final case class WritePartition(value: IR, writeCtx: IR, writer: PartitionWriter final case class WriteMetadata(writeAnnotations: IR, writer: MetadataWriter) extends IR final case class ReadValue(path: IR, reader: ValueReader, requestedType: Type) extends IR -final case class WriteValue(value: IR, path: IR, writer: ValueWriter, stagingFile: Option[IR] = None) extends IR + +final case class WriteValue( + value: IR, + path: IR, + writer: ValueWriter, + stagingFile: Option[IR] = None, +) extends IR class PrimitiveIR(val self: IR) extends AnyVal { def +(other: IR): IR = { @@ -982,6 +1368,7 @@ class PrimitiveIR(val self: IR) extends AnyVal { else ApplyBinaryPrimOp(Add(), self, other) } + def -(other: IR): IR = ApplyBinaryPrimOp(Subtract(), self, other) def *(other: IR): IR = ApplyBinaryPrimOp(Multiply(), self, other) def /(other: IR): IR = ApplyBinaryPrimOp(FloatingPointDivide(), self, other) diff --git a/hail/src/main/scala/is/hail/expr/ir/InTailPosition.scala b/hail/src/main/scala/is/hail/expr/ir/InTailPosition.scala index 6689d2d3941..e5165c5476d 100644 --- a/hail/src/main/scala/is/hail/expr/ir/InTailPosition.scala +++ b/hail/src/main/scala/is/hail/expr/ir/InTailPosition.scala @@ -2,7 +2,7 @@ package is.hail.expr.ir object InTailPosition { def apply(x: IR, i: Int): Boolean = x match { - case 
Let(bindings, _) => i == bindings.length + case Block(bindings, _) => i == bindings.length case If(_, _, _) => i != 0 case _: Switch => i != 0 case TailLoop(_, params, _, _) => i == params.length diff --git a/hail/src/main/scala/is/hail/expr/ir/InferType.scala b/hail/src/main/scala/is/hail/expr/ir/InferType.scala index dd48deb97e2..8b49e14c258 100644 --- a/hail/src/main/scala/is/hail/expr/ir/InferType.scala +++ b/hail/src/main/scala/is/hail/expr/ir/InferType.scala @@ -33,7 +33,7 @@ object InferType { case MakeNDArray(data, shape, _, _) => TNDArray(tcoerce[TIterable](data.typ).elementType, Nat(shape.typ.asInstanceOf[TTuple].size)) case StreamBufferedAggregate(_, _, newKey, _, _, aggSignatures, _) => - val tupleFieldTypes = TTuple(aggSignatures.map(_ => TBinary):_*) + val tupleFieldTypes = TTuple(aggSignatures.map(_ => TBinary): _*) TStream(newKey.typ.asInstanceOf[TStruct].insertFields(IndexedSeq(("agg", tupleFieldTypes)))) case _: ArrayLen => TInt32 case _: StreamIota => TStream(TInt32) @@ -47,24 +47,21 @@ object InferType { case _: CombOp => TVoid case ResultOp(_, aggSig) => aggSig.resultType - case AggStateValue(i, sig) => TBinary + case AggStateValue(_, _) => TBinary case _: CombOpValue => TVoid case _: InitFromSerializedValue => TVoid case _: SerializeAggs => TVoid case _: DeserializeAggs => TVoid - case _: Begin => TVoid case Die(_, t, _) => t case Trap(child) => TTuple(TTuple(TString, TInt32), child.typ) - case ConsoleLog(message, result) => result.typ + case ConsoleLog(_, result) => result.typ case If(cond, cnsq, altr) => assert(cond.typ == TBoolean) assert(cnsq.typ == altr.typ) cnsq.typ case Switch(_, default, _) => default.typ - case Let(_, body) => - body.typ - case AggLet(name, value, body, _) => + case Block(_, body) => body.typ case TailLoop(_, _, resultType, _) => resultType @@ -107,7 +104,7 @@ object InferType { case ToDict(a) => val elt = tcoerce[TBaseStruct](tcoerce[TStream](a.typ).elementType) TDict(elt.types(0), elt.types(1)) - case ta@ToArray(a) => + case ToArray(a) => val elt = tcoerce[TStream](a.typ).elementType TArray(elt) case CastToArray(a) => @@ -120,7 +117,7 @@ object InferType { TRNGState case RNGSplit(_, _) => TRNGState - case StreamLen(a) => TInt32 + case StreamLen(_) => TInt32 case GroupByKey(collection) => val elt = tcoerce[TBaseStruct](tcoerce[TStream](collection.typ).elementType) TDict(elt.types(0), TArray(elt.types(1))) @@ -132,9 +129,9 @@ object InferType { TStream(a.typ) case StreamGroupByKey(a, _, _) => TStream(a.typ) - case StreamMap(a, name, body) => + case StreamMap(_, _, body) => TStream(body.typ) - case StreamZip(as, _, body, _, _) => + case StreamZip(_, _, body, _, _) => TStream(body.typ) case StreamZipJoin(_, _, _, _, joinF) => TStream(joinF.typ) @@ -142,24 +139,29 @@ object InferType { TStream(joinF.typ) case StreamMultiMerge(as, _) => TStream(tcoerce[TStream](as.head.typ).elementType) - case StreamFilter(a, name, cond) => + case StreamFilter(a, _, _) => a.typ - case StreamTakeWhile(a, name, cond) => + case StreamTakeWhile(a, _, _) => a.typ - case StreamDropWhile(a, name, cond) => + case StreamDropWhile(a, _, _) => a.typ - case StreamFlatMap(a, name, body) => + case StreamFlatMap(_, _, body) => TStream(tcoerce[TStream](body.typ).elementType) - case StreamFold(a, zero, accumName, valueName, body) => + case StreamFold(_, zero, _, _, body) => assert(body.typ == zero.typ) zero.typ case StreamFold2(_, _, _, _, result) => result.typ - case StreamDistribute(child, pivots, pathPrefix, _, _) => + case StreamDistribute(_, pivots, _, _, _) => val keyType = 
pivots.typ.asInstanceOf[TContainer].elementType - TArray(TStruct(("interval", TInterval(keyType)), ("fileName", TString), ("numElements", TInt32), ("numBytes", TInt64))) + TArray(TStruct( + ("interval", TInterval(keyType)), + ("fileName", TString), + ("numElements", TInt32), + ("numBytes", TInt64), + )) case StreamWhiten(stream, _, _, _, _, _, _, _) => stream.typ - case StreamScan(a, zero, accumName, valueName, body) => + case StreamScan(_, zero, _, _, body) => assert(body.typ == zero.typ) TStream(zero.typ) case StreamAgg(_, _, query) => @@ -172,12 +174,15 @@ object InferType { "locus" -> childType.fieldType("locus"), "alleles" -> childType.fieldType("alleles"), "mean" -> TFloat64, - "centered_length_rec" -> TFloat64)) - case RunAgg(body, result, _) => + "centered_length_rec" -> TFloat64, + )) + case RunAgg(_, result, _) => result.typ case RunAggScan(_, _, _, _, result, _) => TStream(result.typ) - case StreamJoinRightDistinct(left, right, lKey, rKey, l, r, join, joinType) => + case s: StreamLeftIntervalJoin => + TStream(s.body.typ) + case StreamJoinRightDistinct(_, _, _, _, _, _, join, _) => TStream(join.typ) case NDArrayShape(nd) => val ndType = nd.typ.asInstanceOf[TNDArray] @@ -201,7 +206,7 @@ object InferType { case NDArraySlice(nd, slices) => val childTyp = tcoerce[TNDArray](nd.typ) val slicesTyp = tcoerce[TTuple](slices.typ) - val tuplesOnly = slicesTyp.types.collect { case x: TTuple => x} + val tuplesOnly = slicesTyp.types.collect { case x: TTuple => x } val remainingDims = Nat(tuplesOnly.length) TNDArray(childTyp.elementType, remainingDims) case NDArrayFilter(nd, _) => @@ -210,7 +215,7 @@ object InferType { val lTyp = tcoerce[TNDArray](l.typ) val rTyp = tcoerce[TNDArray](r.typ) TNDArray(lTyp.elementType, Nat(TNDArray.matMulNDims(lTyp.nDims, rTyp.nDims))) - case NDArrayQR(nd, mode, _) => + case NDArrayQR(_, mode, _) => if (Array("complete", "reduced").contains(mode)) { TTuple(TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(2))) } else if (mode == "raw") { @@ -220,13 +225,13 @@ object InferType { } else { throw new NotImplementedError(s"Cannot infer type for mode $mode") } - case NDArraySVD(nd, _, compute_uv, _) => + case NDArraySVD(_, _, compute_uv, _) => if (compute_uv) { TTuple(TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(1)), TNDArray(TFloat64, Nat(2))) } else { TNDArray(TFloat64, Nat(1)) } - case NDArrayEigh(nd, eigvalsOnly, _) => + case NDArrayEigh(_, eigvalsOnly, _) => if (eigvalsOnly) { TNDArray(TFloat64, Nat(1)) } else { @@ -237,11 +242,11 @@ object InferType { case NDArrayWrite(_, _) => TVoid case AggFilter(_, aggIR, _) => aggIR.typ - case AggExplode(array, name, aggBody, _) => + case AggExplode(_, _, aggBody, _) => aggBody.typ case AggGroupBy(key, aggIR, _) => TDict(key.typ, aggIR.typ) - case AggArrayPerElement(a, _, _, aggBody, _, _) => TArray(aggBody.typ) + case AggArrayPerElement(_, _, _, aggBody, _, _) => TArray(aggBody.typ) case ApplyAggOp(_, _, aggSig) => aggSig.returnType case ApplyScanOp(_, _, aggSig) => @@ -259,7 +264,7 @@ object InferType { val tbs = tcoerce[TStruct](old.typ) val s = tbs.insertFields(fields.map(f => (f._1, f._2.typ))) fieldOrder.map { fds => - assert(fds.length == s.size, s"${fds} != ${s.types.toIndexedSeq}") + assert(fds.length == s.size, s"$fds != ${s.types.toIndexedSeq}") TStruct(fds.map(f => f -> s.fieldType(f)): _*) }.getOrElse(s) case GetField(o, name) => @@ -275,9 +280,9 @@ object InferType { fd case TableCount(_) => TInt64 case MatrixCount(_) => TTuple(TInt64, TInt32) - case TableAggregate(child, query) => + case TableAggregate(_, 
query) => query.typ - case MatrixAggregate(child, query) => + case MatrixAggregate(_, query) => query.typ case _: TableWrite => TVoid case _: TableMultiWrite => TVoid @@ -287,13 +292,14 @@ object InferType { case BlockMatrixWrite(_, writer) => writer.loweredTyp case _: BlockMatrixMultiWrite => TVoid case TableGetGlobals(child) => child.typ.globalType - case TableCollect(child) => TStruct("rows" -> TArray(child.typ.rowType), "global" -> child.typ.globalType) + case TableCollect(child) => + TStruct("rows" -> TArray(child.typ.rowType), "global" -> child.typ.globalType) case TableToValueApply(child, function) => function.typ(child.typ) case MatrixToValueApply(child, function) => function.typ(child.typ) case BlockMatrixToValueApply(child, function) => function.typ(child.typ) case CollectDistributedArray(_, _, _, _, body, _, _, _) => TArray(body.typ) case ReadPartition(_, rowType, _) => TStream(rowType) - case WritePartition(value, writeCtx, writer) => writer.returnType + case WritePartition(_, _, writer) => writer.returnType case _: WriteMetadata => TVoid case ReadValue(_, _, typ) => typ case _: WriteValue => TString diff --git a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala index c1d5e80b215..1b022d3a853 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Interpret.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Interpret.scala @@ -2,21 +2,22 @@ package is.hail.expr.ir import is.hail.annotations._ import is.hail.asm4s._ -import is.hail.backend.spark.SparkTaskContext import is.hail.backend.{ExecuteContext, HailTaskContext} +import is.hail.backend.spark.SparkTaskContext import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.io.BufferSpec import is.hail.linalg.BlockMatrix import is.hail.rvd.RVDContext -import is.hail.types.physical.stypes.{PTypeReferenceSingleCodeType, SingleCodeType} import is.hail.types.physical.{PTuple, PType} +import is.hail.types.physical.stypes.{PTypeReferenceSingleCodeType, SingleCodeType} import is.hail.types.tcoerce import is.hail.types.virtual._ import is.hail.utils._ -import org.apache.spark.sql.Row import scala.collection.mutable +import org.apache.spark.sql.Row + object Interpret { type Agg = (IndexedSeq[Row], TStruct) @@ -24,7 +25,10 @@ object Interpret { apply(tir, ctx, optimize = true) def apply(tir: TableIR, ctx: ExecuteContext, optimize: Boolean): TableValue = { - val lowered = LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, tir).asInstanceOf[TableIR].noSharing(ctx) + val lowered = + LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, tir).asInstanceOf[TableIR].noSharing( + ctx + ) lowered.analyzeAndExecute(ctx).asTableValue(ctx) } @@ -34,32 +38,40 @@ object Interpret { } def apply(bmir: BlockMatrixIR, ctx: ExecuteContext, optimize: Boolean): BlockMatrix = { - val lowered = LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, bmir).asInstanceOf[BlockMatrixIR] + val lowered = + LoweringPipeline.legacyRelationalLowerer(optimize)(ctx, bmir).asInstanceOf[BlockMatrixIR] lowered.execute(ctx) } - def apply[T](ctx: ExecuteContext, ir: IR): T = apply(ctx, ir, Env.empty[(Any, Type)], FastSeq[(Any, Type)]()).asInstanceOf[T] + def apply[T](ctx: ExecuteContext, ir: IR): T = + apply(ctx, ir, Env.empty[(Any, Type)], FastSeq[(Any, Type)]()).asInstanceOf[T] def apply[T]( ctx: ExecuteContext, ir0: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], - optimize: Boolean = true + optimize: Boolean = true, ): T = { - val bindings = env.m.view.map { case (k, (value, t)) => k -> Literal.coerce(t, 
value) }.toFastSeq - val lowered = LoweringPipeline.relationalLowerer(optimize).apply(ctx, Let(bindings, ir0)).asInstanceOf[IR] + val bindings = env.m.view.map { case (k, (value, t)) => + k -> Literal.coerce(t, value) + }.toFastSeq + val lowered = + LoweringPipeline.relationalLowerer(optimize).apply(ctx, Let(bindings, ir0)).asInstanceOf[IR] val result = run(ctx, lowered, Env.empty[Any], args, Memo.empty).asInstanceOf[T] result } - def alreadyLowered(ctx: ExecuteContext, ir: IR): Any = run(ctx, ir, Env.empty, FastSeq(), Memo.empty) + def alreadyLowered(ctx: ExecuteContext, ir: IR): Any = + run(ctx, ir, Env.empty, FastSeq(), Memo.empty) - private def run(ctx: ExecuteContext, + private def run( + ctx: ExecuteContext, ir: IR, env: Env[Any], args: IndexedSeq[(Any, Type)], - functionMemo: Memo[(SingleCodeType, AsmFunction2RegionLongLong)]): Any = { + functionMemo: Memo[(SingleCodeType, AsmFunction2RegionLongLong)], + ): Any = { def interpret(ir: IR, env: Env[Any] = env, args: IndexedSeq[(Any, Type)] = args): Any = run(ctx, ir, env, args, functionMemo) @@ -73,7 +85,7 @@ object Interpret { case True() => true case False() => false case Literal(_, value) => value - case x@EncodedLiteral(codec, value) => + case x @ EncodedLiteral(codec, value) => ctx.r.getPool().scopedRegion { r => val (pt, addr) = codec.decodeArrays(ctx, x.typ, value.ba, ctx.r) SafeRow.read(pt, addr) @@ -127,8 +139,8 @@ object Interpret { case null => null } - case Let(bindings, body) => - val newEnv = bindings.foldLeft(env) { case (env, (name, value)) => + case Block(bindings, body) => + val newEnv = bindings.foldLeft(env) { case (env, Binding(name, value, Scope.EVAL)) => env.bind(name -> interpret(value, env, args)) } interpret(body, newEnv, args) @@ -249,7 +261,7 @@ object Interpret { case MakeArray(elements, _) => elements.map(interpret(_, env, args)).toFastSeq case MakeStream(elements, _, _) => elements.map(interpret(_, env, args)).toFastSeq - case x@ArrayRef(a, i, errorId) => + case ArrayRef(a, i, errorId) => val aValue = interpret(a, env, args) val iValue = interpret(i, env, args) if (aValue == null || iValue == null) @@ -259,7 +271,7 @@ object Interpret { val i = iValue.asInstanceOf[Int] if (i < 0 || i >= a.length) { - fatal(s"array index out of bounds: index=$i, length=${ a.length }", errorId = errorId) + fatal(s"array index out of bounds: index=$i, length=${a.length}", errorId = errorId) } else a.apply(i) } @@ -268,8 +280,10 @@ object Interpret { val startValue = interpret(start, env, args) val stopValue = stop.map(ir => interpret(ir, env, args)) val stepValue = interpret(step, env, args) - if (startValue == null || stepValue == null || aValue == null || - stopValue.getOrElse(aValue.asInstanceOf[IndexedSeq[Any]].size) == null) + if ( + startValue == null || stepValue == null || aValue == null || + stopValue.getOrElse(aValue.asInstanceOf[IndexedSeq[Any]].size) == null + ) null else { val a = aValue.asInstanceOf[IndexedSeq[Any]] @@ -278,20 +292,20 @@ object Interpret { if (requestedStep == 0) fatal("step cannot be 0 for array slice", errorID) val noneStop = if (requestedStep < 0) -a.size - 1 - else a.size - val maxBound = if(requestedStep > 0) a.size - else a.size - 1 - val minBound = if(requestedStep > 0) 0 - else - 1 + else a.size + val maxBound = if (requestedStep > 0) a.size + else a.size - 1 + val minBound = if (requestedStep > 0) 0 + else -1 val requestedStop = stopValue.getOrElse(noneStop).asInstanceOf[Int] val realStart = if (requestedStart >= a.size) maxBound - else if (requestedStart >= 0) requestedStart - else if 
(requestedStart + a.size >= 0) requestedStart + a.size - else minBound + else if (requestedStart >= 0) requestedStart + else if (requestedStart + a.size >= 0) requestedStart + a.size + else minBound val realStop = if (requestedStop >= a.size) maxBound - else if (requestedStop >= 0) requestedStop - else if (requestedStop + a.size > 0) requestedStop + a.size - else minBound + else if (requestedStop >= 0) requestedStop + else if (requestedStop + a.size > 0) requestedStop + a.size + else minBound (realStart until realStop by requestedStep).map(idx => a(idx)) } case ArrayLen(a) => @@ -306,7 +320,8 @@ object Interpret { null else aValue.asInstanceOf[IndexedSeq[Any]].length - case StreamIota(start, step, requiresMemoryManagementPerElement) => throw new UnsupportedOperationException + case StreamIota(_, _, _) => + throw new UnsupportedOperationException case StreamRange(start, stop, step, _, errorID) => val startValue = interpret(start, env, args) val stopValue = interpret(stop, env, args) @@ -316,7 +331,9 @@ object Interpret { if (startValue == null || stopValue == null || stepValue == null) null else - startValue.asInstanceOf[Int] until stopValue.asInstanceOf[Int] by stepValue.asInstanceOf[Int] + startValue.asInstanceOf[Int] until stopValue.asInstanceOf[Int] by stepValue.asInstanceOf[ + Int + ] case ArraySort(a, l, r, lessThan) => val aValue = interpret(a, env, args) if (aValue == null) @@ -344,7 +361,9 @@ object Interpret { if (aValue == null) null else - aValue.asInstanceOf[IndexedSeq[Row]].filter(_ != null).map { case Row(k, v) => (k, v) }.toMap + aValue.asInstanceOf[IndexedSeq[Row]].filter(_ != null).map { case Row(k, v) => + (k, v) + }.toMap case _: CastToArray | _: ToArray | _: ToStream => val c = ir.children.head.asInstanceOf[IR] val cValue = interpret(c, env, args) @@ -355,7 +374,8 @@ object Interpret { cValue match { case s: Set[_] => s.asInstanceOf[Set[Any]].toFastSeq.sorted(ordering) - case d: Map[_, _] => d.iterator.map { case (k, v) => Row(k, v) }.toFastSeq.sorted(ordering) + case d: Map[_, _] => + d.iterator.map { case (k, v) => Row(k, v) }.toFastSeq.sorted(ordering) case a => a } } @@ -375,16 +395,23 @@ object Interpret { d.count { case (k, _) => elem.typ.ordering(ctx.stateManager).lt(k, eValue) } case a: IndexedSeq[_] => if (onKey) { - val (eltF, eltT) = orderedCollection.typ.asInstanceOf[TContainer].elementType match { - case t: TBaseStruct => ( { (x: Any) => - val r = x.asInstanceOf[Row] - if (r == null) null else r.get(0) - }, t.types(0)) - case i: TInterval => ( { (x: Any) => - val i = x.asInstanceOf[Interval] - if (i == null) null else i.start - }, i.pointType) - } + val (eltF, eltT) = + orderedCollection.typ.asInstanceOf[TContainer].elementType match { + case t: TBaseStruct => ( + { (x: Any) => + val r = x.asInstanceOf[Row] + if (r == null) null else r.get(0) + }, + t.types(0), + ) + case i: TInterval => ( + { (x: Any) => + val i = x.asInstanceOf[Interval] + if (i == null) null else i.start + }, + i.pointType, + ) + } val ordering = eltT.ordering(ctx.stateManager) val lb = a.count(elem => ordering.lt(eltF(elem), eValue)) lb @@ -475,7 +502,7 @@ object Interpret { case ArrayZipBehavior.AssertSameLength | ArrayZipBehavior.AssumeSameLength => val lengths = aValues.map(_.length).toSet if (lengths.size != 1) - fatal(s"zip: length mismatch: ${ lengths.mkString(", ") }", errorID) + fatal(s"zip: length mismatch: ${lengths.mkString(", ")}", errorID) lengths.head case ArrayZipBehavior.TakeMinLength => aValues.map(_.length).min @@ -483,7 +510,8 @@ object Interpret { 
aValues.map(_.length).max } (0 until len).map { i => - val e = env.bindIterable(names.zip(aValues.map(a => if (i >= a.length) null else a.apply(i)))) + val e = + env.bindIterable(names.zip(aValues.map(a => if (i >= a.length) null else a.apply(i)))) interpret(body, e, args) } } @@ -505,7 +533,7 @@ object Interpret { c < 0 || (c == 0 && li < ri) } - def advance(i: Int) { + def advance(i: Int): Unit = { heads(i) += 1 var winner = if (heads(i) < streams(i).length) i else k var j = (i + k) / 2 @@ -520,7 +548,7 @@ object Interpret { tournament(j) = winner } - for (i <- 0 until k) { advance(i) } + for (i <- 0 until k) advance(i) val builder = new BoxedArrayBuilder[Row]() while (tournament(0) != k) { @@ -546,7 +574,7 @@ object Interpret { def get(i: Int): Row = streams(i)(heads(i)) - def advance(i: Int) { + def advance(i: Int): Unit = { heads(i) += 1 var winner = if (heads(i) < streams(i).length) i else k var j = (i + k) / 2 @@ -561,7 +589,7 @@ object Interpret { tournament(j) = winner } - for (i <- 0 until k) { advance(i) } + for (i <- 0 until k) advance(i) val builder = new mutable.ArrayBuffer[Any]() while (tournament(0) != k) { @@ -576,7 +604,11 @@ object Interpret { advance(j) j = tournament(0) } - builder += interpret(joinF, env.bind(curKeyName -> curKey, curValsName -> elt.toFastSeq), args) + builder += interpret( + joinF, + env.bind(curKeyName -> curKey, curValsName -> elt.toFastSeq), + args, + ) } builder.toFastSeq } @@ -630,7 +662,8 @@ object Interpret { else { var zeroValue = interpret(zero, env, args) aValue.asInstanceOf[IndexedSeq[Any]].foreach { element => - zeroValue = interpret(body, env.bind(accumName -> zeroValue, valueName -> element), args) + zeroValue = + interpret(body, env.bind(accumName -> zeroValue, valueName -> element), args) } zeroValue } @@ -643,9 +676,7 @@ object Interpret { var e = env.bindIterable(accVals) aValue.asInstanceOf[IndexedSeq[Any]].foreach { elt => e = e.bind(valueName, elt) - accVals.indices.foreach { i => - e = e.bind(accum(i)._1, interpret(seq(i), e, args)) - } + accVals.indices.foreach(i => e = e.bind(accum(i)._1, interpret(seq(i), e, args))) } interpret(res, e.delete(valueName), args) } @@ -667,9 +698,11 @@ object Interpret { if (lValue == null || rValue == null) null else { - val (lKeyTyp, lGetKey) = tcoerce[TStruct](tcoerce[TStream](left.typ).elementType).select(lKey) - val (rKeyTyp, rGetKey) = tcoerce[TStruct](tcoerce[TStream](right.typ).elementType).select(rKey) - assert(lKeyTyp isIsomorphicTo rKeyTyp) + val (lKeyTyp, lGetKey) = + tcoerce[TStruct](tcoerce[TStream](left.typ).elementType).select(lKey) + val (rKeyTyp, rGetKey) = + tcoerce[TStruct](tcoerce[TStream](right.typ).elementType).select(rKey) + assert(lKeyTyp isJoinableWith rKeyTyp) val keyOrd = TBaseStruct.getJoinOrdering(ctx.stateManager, lKeyTyp.types) def compF(lelt: Any, relt: Any): Int = @@ -709,12 +742,16 @@ object Interpret { val outerResult = builder.result() val elts: Iterator[(Option[Int], Option[Int])] = joinType match { - case "inner" => outerResult.iterator.filter { case (l, r) => l.isDefined && r.isDefined } + case "inner" => outerResult.iterator.filter { case (l, r) => + l.isDefined && r.isDefined + } case "outer" => outerResult.iterator - case "left" => outerResult.iterator.filter { case (l, r) => l.isDefined } - case "right" => outerResult.iterator.filter { case (l, r) => r.isDefined } + case "left" => outerResult.iterator.filter { case (l, _) => l.isDefined } + case "right" => outerResult.iterator.filter { case (_, r) => r.isDefined } + } + elts.map { case (lIdx, rIdx) => + 
joinF(lIdx.map(lValue.apply).orNull, rIdx.map(rValue.apply).orNull) } - elts.map { case (lIdx, rIdx) => joinF(lIdx.map(lValue.apply).orNull, rIdx.map(rValue.apply).orNull) } .toFastSeq } @@ -726,10 +763,8 @@ object Interpret { } } () - case Begin(xs) => - xs.foreach(x => interpret(x)) case MakeStruct(fields) => - Row.fromSeq(fields.map { case (name, fieldIR) => interpret(fieldIR, env, args) }) + Row.fromSeq(fields.map { case (_, fieldIR) => interpret(fieldIR, env, args) }) case SelectFields(old, fields) => val oldt = tcoerce[TStruct](old.typ) val oldRow = interpret(old, env, args).asInstanceOf[Row] @@ -737,14 +772,17 @@ object Interpret { null else Row.fromSeq(fields.map(id => oldRow.get(oldt.fieldIdx(id)))) - case x@InsertFields(old, fields, fieldOrder) => + case InsertFields(old, fields, fieldOrder) => var struct = interpret(old, env, args) if (struct != null) fieldOrder match { case Some(fds) => val newValues = fields.toMap.mapValues(interpret(_, env, args)) - val oldIndices = old.typ.asInstanceOf[TStruct].fields.map(f => f.name -> f.index).toMap - Row.fromSeq(fds.map(name => newValues.getOrElse(name, struct.asInstanceOf[Row].get(oldIndices(name))))) + val oldIndices = + old.typ.asInstanceOf[TStruct].fields.map(f => f.name -> f.index).toMap + Row.fromSeq(fds.map(name => + newValues.getOrElse(name, struct.asInstanceOf[Row].get(oldIndices(name))) + )) case None => var t = old.typ.asInstanceOf[TStruct] fields.foreach { case (name, body) => @@ -777,20 +815,20 @@ object Interpret { case In(i, _) => val (a, _) = args(i) a - case Die(message, typ, errorId) => + case Die(message, _, errorId) => val message_ = interpret(message).asInstanceOf[String] - fatal(if (message_ != null) message_ else "", errorId) + fatal(if (message_ != null) message_ else "", errorId) case Trap(child) => - try { + try Row(null, interpret(child)) - } catch { + catch { case e: HailException => Row(Row(e.msg, e.errorId), null) } case ConsoleLog(message, result) => val message_ = interpret(message).asInstanceOf[String] info(message_) interpret(result) - case ir@ApplyIR(function, _, _, functionArgs, _) => + case ir @ ApplyIR(_, _, _, _, _) => interpret(ir.explicitNode, env, args) case ApplySpecial("lor", _, Seq(left_, right_), _, _) => val left = interpret(left_) @@ -817,25 +855,34 @@ object Interpret { else true } case ir: AbstractApplyNode[_] => - val argTuple = PType.canonical(TTuple(ir.args.map(_.typ): _*)).setRequired(true).asInstanceOf[PTuple] + val argTuple = + PType.canonical(TTuple(ir.args.map(_.typ): _*)).setRequired(true).asInstanceOf[PTuple] ctx.r.pool.scopedRegion { region => - val (rt, f) = functionMemo.getOrElseUpdate(ir, { - val wrappedArgs: IndexedSeq[BaseIR] = ir.args.zipWithIndex.map { case (x, i) => - GetTupleElement(Ref("in", argTuple.virtualType), i) - }.toFastSeq - val newChildren = ir match { - case ir: ApplySeeded => wrappedArgs :+ NA(TRNGState) - case _ => wrappedArgs - } - val wrappedIR = Copy(ir, newChildren) + val (rt, f) = functionMemo.getOrElseUpdate( + ir, { + val wrappedArgs: IndexedSeq[BaseIR] = ir.args.zipWithIndex.map { case (_, i) => + GetTupleElement(Ref("in", argTuple.virtualType), i) + }.toFastSeq + val newChildren = ir match { + case _: ApplySeeded => wrappedArgs :+ NA(TRNGState) + case _ => wrappedArgs + } + val wrappedIR = Copy(ir, newChildren) - val (rt, makeFunction) = Compile[AsmFunction2RegionLongLong](ctx, - FastSeq(("in", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(argTuple)))), - FastSeq(classInfo[Region], LongInfo), LongInfo, - 
MakeTuple.ordered(FastSeq(wrappedIR)), - optimize = false) - (rt.get, makeFunction(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, region)) - }) + val (rt, makeFunction) = Compile[AsmFunction2RegionLongLong]( + ctx, + FastSeq(( + "in", + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(argTuple)), + )), + FastSeq(classInfo[Region], LongInfo), + LongInfo, + MakeTuple.ordered(FastSeq(wrappedIR)), + optimize = false, + ) + (rt.get, makeFunction(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, region)) + }, + ) val rvb = new RegionValueBuilder(ctx.stateManager) rvb.set(region) rvb.start(argTuple) @@ -849,10 +896,13 @@ object Interpret { try { val resultOffset = f(region, offset) - SafeRow(rt.asInstanceOf[PTypeReferenceSingleCodeType].pt.asInstanceOf[PTuple], resultOffset).get(0) + SafeRow( + rt.asInstanceOf[PTypeReferenceSingleCodeType].pt.asInstanceOf[PTuple], + resultOffset, + ).get(0) } catch { case e: Exception => - fatal(s"error while calling '${ ir.implementation.name }': ${ e.getMessage }", e) + fatal(s"error while calling '${ir.implementation.name}': ${e.getMessage}", e) } } case TableCount(child) => @@ -883,45 +933,69 @@ object Interpret { val breezeMat = bm.transpose().toBreezeMatrix() val shape = IndexedSeq(bm.nRows, bm.nCols) SafeNDArray(shape, breezeMat.toArray) - case x@TableAggregate(child, query) => + case x @ TableAggregate(child, query) => val value = child.analyzeAndExecute(ctx).asTableValue(ctx) val fsBc = ctx.fsBc val globalsBc = value.globals.broadcast(ctx.theHailClassLoader) val globalsOffset = value.globals.value.offset - val res = genUID() - - val extracted = agg.Extract(query, res, Requiredness(x, ctx)) + val extracted = agg.Extract(query, Requiredness(x, ctx)) val wrapped = if (extracted.aggs.isEmpty) { - val (Some(PTypeReferenceSingleCodeType(rt: PTuple)), f) = Compile[AsmFunction2RegionLongLong](ctx, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)))), - FastSeq(classInfo[Region], LongInfo), LongInfo, - MakeTuple.ordered(FastSeq(extracted.postAggIR))) + val (Some(PTypeReferenceSingleCodeType(rt: PTuple)), f) = + Compile[AsmFunction2RegionLongLong]( + ctx, + FastSeq(( + "global", + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)), + )), + FastSeq(classInfo[Region], LongInfo), + LongInfo, + MakeTuple.ordered(FastSeq(extracted.postAggIR)), + ) // TODO Is this right? where does wrapped run? 
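// Editorial aside (not part of the patch): the TableAggregate case begun just above
// compiles the extracted aggregators into three pieces — an init op, a per-row seqOp
// (AsmFunction3RegionLongLongUnit, rows passed as region offsets), and a combine over
// the per-partition states — and then reduces partition results either linearly or as
// a tree (the useTreeAggregate / isCommutative log lines appear further down this
// hunk). A schematic sketch of that shape, with plain Scala functions standing in for
// the compiled functions and serialized agg states; AggState, perPartition and
// treeCombine are illustrative names, not Hail's:
final case class AggState[S](init: () => S, seqOp: (S, Long) => S, combOp: (S, S) => S)

object ToyAggregate {
  // One pass per partition: start from init and fold the partition's rows with seqOp.
  def perPartition[S](agg: AggState[S], rows: Iterator[Long]): S =
    rows.foldLeft(agg.init())(agg.seqOp)

  // Pairwise combine of adjacent partition states; repeating it keeps the combine
  // depth near log2(#partitions), which is the point of tree aggregation.
  def treeCombine[S](agg: AggState[S], states: IndexedSeq[S]): S = {
    require(states.nonEmpty, "expected at least one partition state")
    if (states.length == 1) states.head
    else {
      val paired = states.grouped(2).map {
        case IndexedSeq(a, b) => agg.combOp(a, b)
        case IndexedSeq(a) => a
      }.toIndexedSeq
      treeCombine(agg, paired)
    }
  }
}

// e.g. summing row offsets across three toy "partitions":
//   val sum = AggState[Long](() => 0L, _ + _, _ + _)
//   val parts = IndexedSeq(Iterator(1L, 2L), Iterator(3L), Iterator(4L, 5L))
//   ToyAggregate.treeCombine(sum, parts.map(ToyAggregate.perPartition(sum, _)))  // == 15L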
- ctx.scopedExecution((hcl, fs, htc, r) => SafeRow(rt, f(hcl, fs, htc, r).apply(r, globalsOffset))) + ctx.scopedExecution((hcl, fs, htc, r) => + SafeRow(rt, f(hcl, fs, htc, r).apply(r, globalsOffset)) + ) } else { val spec = BufferSpec.blockedUncompressed - val (_, initOp) = CompileWithAggregators[AsmFunction2RegionLongUnit](ctx, + val (_, initOp) = CompileWithAggregators[AsmFunction2RegionLongUnit]( + ctx, extracted.states, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)))), - FastSeq(classInfo[Region], LongInfo), UnitInfo, - extracted.init) + FastSeq(( + "global", + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)), + )), + FastSeq(classInfo[Region], LongInfo), + UnitInfo, + extracted.init, + ) - val (_, partitionOpSeq) = CompileWithAggregators[AsmFunction3RegionLongLongUnit](ctx, + val (_, partitionOpSeq) = CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + ctx, extracted.states, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t))), - ("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.rvd.rowPType)))), - FastSeq(classInfo[Region], LongInfo, LongInfo), UnitInfo, - extracted.seqPerElt) + FastSeq( + ( + "global", + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)), + ), + ( + "row", + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.rvd.rowPType)), + ), + ), + FastSeq(classInfo[Region], LongInfo, LongInfo), + UnitInfo, + extracted.seqPerElt, + ) val useTreeAggregate = extracted.shouldTreeAggregate val isCommutative = extracted.isCommutative - log.info(s"Aggregate: useTreeAggregate=${ useTreeAggregate }") - log.info(s"Aggregate: commutative=${ isCommutative }") + log.info(s"Aggregate: useTreeAggregate=$useTreeAggregate") + log.info(s"Aggregate: commutative=$isCommutative") // A mutable reference to a byte array. 
If someone higher up the // call stack holds a WrappedByteArray, we can set the reference @@ -929,19 +1003,18 @@ object Interpret { class WrappedByteArray(_bytes: Array[Byte]) { private var ref: Array[Byte] = _bytes def bytes: Array[Byte] = ref - def clear() { ref = null } + def clear(): Unit = ref = null } // creates a region, giving ownership to the caller val read: (HailClassLoader, HailTaskContext) => (WrappedByteArray => RegionValue) = { val deserialize = extracted.deserialize(ctx, spec) (hcl: HailClassLoader, htc: HailTaskContext) => { - (a: WrappedByteArray) => { + (a: WrappedByteArray) => val r = Region(Region.SMALL, htc.getRegionPool()) val res = deserialize(hcl, htc, r, a.bytes) a.clear() RegionValue(r, res) - } } } @@ -961,11 +1034,13 @@ object Interpret { // returns ownership of a new region holding the partition aggregation // result - def itF(theHailClassLoader: HailClassLoader, i: Int, ctx: RVDContext, it: Iterator[Long]): RegionValue = { + def itF(theHailClassLoader: HailClassLoader, i: Int, ctx: RVDContext, it: Iterator[Long]) + : RegionValue = { val partRegion = ctx.partitionRegion val globalsOffset = globalsBc.value.readRegionValue(partRegion, theHailClassLoader) val init = initOp(theHailClassLoader, fsBc.value, SparkTaskContext.get(), partRegion) - val seqOps = partitionOpSeq(theHailClassLoader, fsBc.value, SparkTaskContext.get(), partRegion) + val seqOps = + partitionOpSeq(theHailClassLoader, fsBc.value, SparkTaskContext.get(), partRegion) val aggRegion = ctx.freshRegion(Region.SMALL) init.newAggState(aggRegion) @@ -996,9 +1071,16 @@ object Interpret { CompileWithAggregators[AsmFunction2RegionLongLong]( ctx, extracted.states, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)))), - FastSeq(classInfo[Region], LongInfo), LongInfo, - Let(FastSeq(res -> extracted.results), MakeTuple.ordered(FastSeq(extracted.postAggIR))) + FastSeq(( + "global", + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(value.globals.t)), + )), + FastSeq(classInfo[Region], LongInfo), + LongInfo, + Let( + FastSeq(extracted.resultRef.name -> extracted.results), + MakeTuple.ordered(FastSeq(extracted.postAggIR)), + ), ) assert(rTyp.types(0).virtualType == query.typ) @@ -1015,16 +1097,20 @@ object Interpret { wrapped.get(0) case LiftMeOut(child) => - val (Some(PTypeReferenceSingleCodeType(rt)), makeFunction) = Compile[AsmFunction1RegionLong](ctx, - FastSeq(), - FastSeq(classInfo[Region]), LongInfo, - MakeTuple.ordered(FastSeq(child)), - optimize = false) + val (Some(PTypeReferenceSingleCodeType(rt)), makeFunction) = + Compile[AsmFunction1RegionLong]( + ctx, + FastSeq(), + FastSeq(classInfo[Region]), + LongInfo, + MakeTuple.ordered(FastSeq(child)), + optimize = false, + ) ctx.scopedExecution { (hcl, fs, htc, r) => SafeRow.read(rt, makeFunction(hcl, fs, htc, r)(r)).asInstanceOf[Row](0) } case UUID4(_) => - uuid4() + uuid4() } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/Interpretable.scala b/hail/src/main/scala/is/hail/expr/ir/Interpretable.scala index 5459daced86..af5903f04e7 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Interpretable.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Interpretable.scala @@ -1,54 +1,58 @@ package is.hail.expr.ir -import is.hail.types.virtual.{TNDArray, TStream} +import is.hail.types.virtual.TNDArray object Interpretable { def apply(ir: IR): Boolean = { !ir.typ.isInstanceOf[TNDArray] && (ir match { - case - _: EncodedLiteral | - _: RunAgg | - _: InitOp | - _: SeqOp | - _: CombOp | - _: ResultOp | - _: 
CombOpValue | - _: InitFromSerializedValue | - _: AggStateValue | - _: RunAgg | - _: RunAggScan | - _: SerializeAggs | - _: DeserializeAggs | - _: ArrayZeros | - _: MakeNDArray | - _: NDArrayShape | - _: NDArrayReshape | - _: NDArrayConcat | - _: NDArrayRef | - _: NDArraySlice | - _: NDArrayFilter | - _: NDArrayMap | - _: NDArrayMap2 | - _: NDArrayReindex | - _: NDArrayAgg | - _: NDArrayMatMul | - _: TailLoop | - _: Recur | - _: ReadPartition | - _: WritePartition | - _: WriteMetadata | - _: ReadValue | - _: WriteValue | - _: NDArrayWrite | - _: StreamZipJoinProducers | - _: ArrayMaximalIndependentSet | - _: RNGStateLiteral => false + case _: EncodedLiteral | + _: RunAgg | + _: InitOp | + _: SeqOp | + _: CombOp | + _: ResultOp | + _: CombOpValue | + _: InitFromSerializedValue | + _: AggStateValue | + _: RunAgg | + _: RunAggScan | + _: SerializeAggs | + _: DeserializeAggs | + _: ArrayZeros | + _: MakeNDArray | + _: NDArrayShape | + _: NDArrayReshape | + _: NDArrayConcat | + _: NDArrayRef | + _: NDArraySlice | + _: NDArrayFilter | + _: NDArrayMap | + _: NDArrayMap2 | + _: NDArrayReindex | + _: NDArrayAgg | + _: NDArrayMatMul | + _: TailLoop | + _: Recur | + _: ReadPartition | + _: WritePartition | + _: WriteMetadata | + _: ReadValue | + _: WriteValue | + _: NDArrayWrite | + _: StreamZipJoinProducers | + _: ArrayMaximalIndependentSet | + _: RNGStateLiteral => false + case Block(bindings, _) => + bindings.forall(_.scope == Scope.EVAL) case x: ApplyIR => - !Exists(x.body, { - case n: IR => !Interpretable(n) - case _ => false - }) + !Exists( + x.body, + { + case n: IR => !Interpretable(n) + case _ => false + }, + ) case _ => true }) } diff --git a/hail/src/main/scala/is/hail/expr/ir/IsConstant.scala b/hail/src/main/scala/is/hail/expr/ir/IsConstant.scala index 49d5b970f8e..4bbe2023817 100644 --- a/hail/src/main/scala/is/hail/expr/ir/IsConstant.scala +++ b/hail/src/main/scala/is/hail/expr/ir/IsConstant.scala @@ -3,12 +3,11 @@ package is.hail.expr.ir import is.hail.types.virtual._ object CanEmit { - def apply(t: Type): Boolean = { + def apply(t: Type): Boolean = t match { case TInt32 | TInt64 | TFloat32 | TFloat64 | TBoolean | TString => true case _ => false } - } } object Constant { @@ -17,10 +16,12 @@ object Constant { } object IsConstant { - def apply(ir: IR): Boolean = { + def apply(ir: IR): Boolean = ir match { - case I32(_) | I64(_) | F32(_) | F64(_) | True() | False() | NA(_) | Str(_) | Literal(_, _) | EncodedLiteral(_, _) => true + case I32(_) | I64(_) | F32(_) | F64(_) | True() | False() | NA(_) | Str(_) | Literal( + _, + _, + ) | EncodedLiteral(_, _) => true case _ => false } - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/IsPure.scala b/hail/src/main/scala/is/hail/expr/ir/IsPure.scala new file mode 100644 index 00000000000..b4295d2b9ef --- /dev/null +++ b/hail/src/main/scala/is/hail/expr/ir/IsPure.scala @@ -0,0 +1,11 @@ +package is.hail.expr.ir + +import is.hail.types.virtual.TVoid + +object IsPure { + def apply(x: IR): Boolean = x match { + case _ if x.typ == TVoid => false + case _: WritePartition | _: WriteValue => false + case _ => true + } +} diff --git a/hail/src/main/scala/is/hail/expr/ir/LiftRelationalValues.scala b/hail/src/main/scala/is/hail/expr/ir/LiftRelationalValues.scala index 76f1f01649b..e5ae730acd7 100644 --- a/hail/src/main/scala/is/hail/expr/ir/LiftRelationalValues.scala +++ b/hail/src/main/scala/is/hail/expr/ir/LiftRelationalValues.scala @@ -9,31 +9,56 @@ object LiftRelationalValues { def apply(ir0: BaseIR): BaseIR = { - def rewrite(ir: BaseIR, ab: 
BoxedArrayBuilder[(String, IR)], memo: mutable.Map[IR, String]): BaseIR = ir match { + def rewrite(ir: BaseIR, ab: BoxedArrayBuilder[(String, IR)], memo: mutable.Map[IR, String]) + : BaseIR = ir match { case RelationalLet(name, value, body) => val value2 = rewrite(value, ab, memo).asInstanceOf[IR] val ab2 = new BoxedArrayBuilder[(String, IR)] val memo2 = mutable.Map.empty[IR, String] val body2 = rewrite(body, ab2, memo2).asInstanceOf[IR] - RelationalLet(name, value2, ab2.result().foldRight[IR](body2) { case ((name, value), acc) => RelationalLet(name, value, acc) }) + RelationalLet( + name, + value2, + ab2.result().foldRight[IR](body2) { case ((name, value), acc) => + RelationalLet(name, value, acc) + }, + ) case RelationalLetTable(name, value, body) => val value2 = rewrite(value, ab, memo).asInstanceOf[IR] val ab2 = new BoxedArrayBuilder[(String, IR)] val memo2 = mutable.Map.empty[IR, String] val body2 = rewrite(body, ab2, memo2).asInstanceOf[TableIR] - RelationalLetTable(name, value2, ab2.result().foldRight[TableIR](body2) { case ((name, value), acc) => RelationalLetTable(name, value, acc) }) + RelationalLetTable( + name, + value2, + ab2.result().foldRight[TableIR](body2) { case ((name, value), acc) => + RelationalLetTable(name, value, acc) + }, + ) case RelationalLetMatrixTable(name, value, body) => val value2 = rewrite(value, ab, memo).asInstanceOf[IR] val ab2 = new BoxedArrayBuilder[(String, IR)] val memo2 = mutable.Map.empty[IR, String] val body2 = rewrite(body, ab2, memo2).asInstanceOf[MatrixIR] - RelationalLetMatrixTable(name, value2, ab2.result().foldRight[MatrixIR](body2) { case ((name, value), acc) => RelationalLetMatrixTable(name, value, acc) }) + RelationalLetMatrixTable( + name, + value2, + ab2.result().foldRight[MatrixIR](body2) { case ((name, value), acc) => + RelationalLetMatrixTable(name, value, acc) + }, + ) case RelationalLetBlockMatrix(name, value, body) => val value2 = rewrite(value, ab, memo).asInstanceOf[IR] val ab2 = new BoxedArrayBuilder[(String, IR)] val memo2 = mutable.Map.empty[IR, String] val body2 = rewrite(body, ab2, memo2).asInstanceOf[BlockMatrixIR] - RelationalLetBlockMatrix(name, value2, ab2.result().foldRight[BlockMatrixIR](body2) { case ((name, value), acc) => RelationalLetBlockMatrix(name, value, acc) }) + RelationalLetBlockMatrix( + name, + value2, + ab2.result().foldRight[BlockMatrixIR](body2) { case ((name, value), acc) => + RelationalLetBlockMatrix(name, value, acc) + }, + ) case LiftMeOut(child) => val name = memo.get(child) match { case Some(name) => name @@ -46,12 +71,12 @@ object LiftRelationalValues { } RelationalRef(name, child.typ) case (_: TableAggregate - | _: TableCount - | _: TableToValueApply - | _: BlockMatrixToValueApply - | _: TableCollect - | _: BlockMatrixCollect - | _: TableGetGlobals) if ir.typ != TVoid => + | _: TableCount + | _: TableToValueApply + | _: BlockMatrixToValueApply + | _: TableCollect + | _: BlockMatrixCollect + | _: TableGetGlobals) if ir.typ != TVoid => val ref = RelationalRef(genUID(), ir.asInstanceOf[IR].typ) val newChild = ir.mapChildren(rewrite(_, ab, memo)) ab += ((ref.name, newChild.asInstanceOf[IR])) @@ -63,10 +88,18 @@ object LiftRelationalValues { val ab = new BoxedArrayBuilder[(String, IR)] val memo = mutable.Map.empty[IR, String] rewrite(ir0, ab, memo) match { - case rw: IR => ab.result().foldRight[IR](rw) { case ((name, value), acc) => RelationalLet(name, value, acc) } - case rw: TableIR => ab.result().foldRight[TableIR](rw) { case ((name, value), acc) => RelationalLetTable(name, value, acc) } - case 
rw: MatrixIR => ab.result().foldRight[MatrixIR](rw) { case ((name, value), acc) => RelationalLetMatrixTable(name, value, acc) } - case rw: BlockMatrixIR => ab.result().foldRight[BlockMatrixIR](rw) { case ((name, value), acc) => RelationalLetBlockMatrix(name, value, acc) } + case rw: IR => ab.result().foldRight[IR](rw) { case ((name, value), acc) => + RelationalLet(name, value, acc) + } + case rw: TableIR => ab.result().foldRight[TableIR](rw) { case ((name, value), acc) => + RelationalLetTable(name, value, acc) + } + case rw: MatrixIR => ab.result().foldRight[MatrixIR](rw) { case ((name, value), acc) => + RelationalLetMatrixTable(name, value, acc) + } + case rw: BlockMatrixIR => ab.result().foldRight[BlockMatrixIR](rw) { + case ((name, value), acc) => RelationalLetBlockMatrix(name, value, acc) + } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala index 4d3a2ac151a..188800a4ae5 100644 --- a/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/LowerMatrixIR.scala @@ -15,39 +15,46 @@ object LowerMatrixIR { def apply(ctx: ExecuteContext, ir: IR): IR = { val ab = new BoxedArrayBuilder[(String, IR)] val l1 = lower(ctx, ir, ab) - ab.result().foldRight[IR](l1) { case ((ident, value), body) => RelationalLet(ident, value, body) } + ab.result().foldRight[IR](l1) { case ((ident, value), body) => + RelationalLet(ident, value, body) + } } def apply(ctx: ExecuteContext, tir: TableIR): TableIR = { val ab = new BoxedArrayBuilder[(String, IR)] val l1 = lower(ctx, tir, ab) - ab.result().foldRight[TableIR](l1) { case ((ident, value), body) => RelationalLetTable(ident, value, body) } + ab.result().foldRight[TableIR](l1) { case ((ident, value), body) => + RelationalLetTable(ident, value, body) + } } def apply(ctx: ExecuteContext, mir: MatrixIR): TableIR = { val ab = new BoxedArrayBuilder[(String, IR)] val l1 = lower(ctx, mir, ab) - ab.result().foldRight[TableIR](l1) { case ((ident, value), body) => RelationalLetTable(ident, value, body) } + ab.result().foldRight[TableIR](l1) { case ((ident, value), body) => + RelationalLetTable(ident, value, body) + } } def apply(ctx: ExecuteContext, bmir: BlockMatrixIR): BlockMatrixIR = { val ab = new BoxedArrayBuilder[(String, IR)] val l1 = lower(ctx, bmir, ab) - ab.result().foldRight[BlockMatrixIR](l1) { case ((ident, value), body) => RelationalLetBlockMatrix(ident, value, body) } + ab.result().foldRight[BlockMatrixIR](l1) { case ((ident, value), body) => + RelationalLetBlockMatrix(ident, value, body) + } } - private[this] def lowerChildren( ctx: ExecuteContext, ir: BaseIR, - ab: BoxedArrayBuilder[(String, IR)] + ab: BoxedArrayBuilder[(String, IR)], ): BaseIR = { ir.mapChildren { case tir: TableIR => lower(ctx, tir, ab) case mir: MatrixIR => throw new RuntimeException(s"expect specialized lowering rule for " + - s"${ ir.getClass.getName }\n Found MatrixIR child $mir") + s"${ir.getClass.getName}\n Found MatrixIR child $mir") case bmir: BlockMatrixIR => lower(ctx, bmir, ab) case vir: IR => lower(ctx, vir, ab) } @@ -59,7 +66,8 @@ object LowerMatrixIR { def globals(tir: TableIR): IR = SelectFields( Ref("global", tir.typ.globalType), - tir.typ.globalType.fieldNames.diff(FastSeq(colsFieldName))) + tir.typ.globalType.fieldNames.diff(FastSeq(colsFieldName)), + ) def nCols(tir: TableIR): IR = ArrayLen(colVals(tir)) @@ -69,8 +77,10 @@ object LowerMatrixIR { import is.hail.expr.ir.DeprecatedIRBuilder._ def matrixSubstEnv(child: MatrixIR): BindingEnv[IRProxy] = { 
- val e = Env[IRProxy]("global" -> 'global.selectFields(child.typ.globalType.fieldNames: _*), - "va" -> 'row.selectFields(child.typ.rowType.fieldNames: _*)) + val e = Env[IRProxy]( + "global" -> 'global.selectFields(child.typ.globalType.fieldNames: _*), + "va" -> 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) BindingEnv(e, agg = Some(e), scan = Some(e)) } @@ -80,121 +90,169 @@ object LowerMatrixIR { } def matrixSubstEnvIR(child: MatrixIR, lowered: TableIR): BindingEnv[IR] = { - val e = Env[IR]("global" -> SelectFields(Ref("global", lowered.typ.globalType), child.typ.globalType.fieldNames), - "va" -> SelectFields(Ref("row", lowered.typ.rowType), child.typ.rowType.fieldNames)) + val e = Env[IR]( + "global" -> SelectFields( + Ref("global", lowered.typ.globalType), + child.typ.globalType.fieldNames, + ), + "va" -> SelectFields(Ref("row", lowered.typ.rowType), child.typ.rowType.fieldNames), + ) BindingEnv(e, agg = Some(e), scan = Some(e)) } - private[this] def lower( ctx: ExecuteContext, mir: MatrixIR, - ab: BoxedArrayBuilder[(String, IR)] + liftedRelationalLets: BoxedArrayBuilder[(String, IR)], ): TableIR = { val lowered = mir match { case RelationalLetMatrixTable(name, value, body) => - RelationalLetTable(name, lower(ctx, value, ab), lower(ctx, body, ab)) + RelationalLetTable( + name, + lower(ctx, value, liftedRelationalLets), + lower(ctx, body, liftedRelationalLets), + ) - case CastTableToMatrix(child, entries, cols, colKey) => - val lc = lower(ctx, child, ab) + case CastTableToMatrix(child, entries, cols, _) => + val lc = lower(ctx, child, liftedRelationalLets) val row = Ref("row", lc.typ.rowType) val glob = Ref("global", lc.typ.globalType) TableMapRows( lc, bindIR(GetField(row, entries)) { entries => - If(IsNA(entries), + If( + IsNA(entries), Die("missing entry array unsupported in 'to_matrix_table_row_major'", row.typ), - bindIRs(ArrayLen(entries), ArrayLen(GetField(glob, cols))) { case Seq(entriesLen, colsLen) => - If(entriesLen cne colsLen, - Die( - strConcat( - Str("length mismatch between entry array and column array in 'to_matrix_table_row_major': "), - invoke("str", TString, entriesLen), - Str(" entries, "), - invoke("str", TString, colsLen), - Str(" cols, at "), - invoke("str", TString, SelectFields(row, child.typ.key)) - ), row.typ, -1), - row - ) - } + bindIRs(ArrayLen(entries), ArrayLen(GetField(glob, cols))) { + case Seq(entriesLen, colsLen) => + If( + entriesLen cne colsLen, + Die( + strConcat( + Str( + "length mismatch between entry array and column array in 'to_matrix_table_row_major': " + ), + invoke("str", TString, entriesLen), + Str(" entries, "), + invoke("str", TString, colsLen), + Str(" cols, at "), + invoke("str", TString, SelectFields(row, child.typ.key)), + ), + row.typ, + -1, + ), + row, + ) + }, ) - } + }, ).rename(Map(entries -> entriesFieldName), Map(cols -> colsFieldName)) case MatrixToMatrixApply(child, function) => - val loweredChild = lower(ctx, child, ab) + val loweredChild = lower(ctx, child, liftedRelationalLets) TableToTableApply(loweredChild, function.lower()) case MatrixRename(child, globalMap, colMap, rowMap, entryMap) => - var t = lower(ctx, child, ab).rename(rowMap, globalMap) + var t = lower(ctx, child, liftedRelationalLets).rename(rowMap, globalMap) if (colMap.nonEmpty) { val newColsType = TArray(child.typ.colType.rename(colMap)) - t = t.mapGlobals('global.castRename(t.typ.globalType.insertFields(FastSeq((colsFieldName, newColsType))))) + t = t.mapGlobals('global.castRename(t.typ.globalType.insertFields(FastSeq(( + colsFieldName, + 
newColsType, + ))))) } if (entryMap.nonEmpty) { val newEntriesType = TArray(child.typ.entryType.rename(entryMap)) - t = t.mapRows('row.castRename(t.typ.rowType.insertFields(FastSeq((entriesFieldName, newEntriesType))))) + t = t.mapRows('row.castRename(t.typ.rowType.insertFields(FastSeq(( + entriesFieldName, + newEntriesType, + ))))) } t case MatrixKeyRowsBy(child, keys, isSorted) => - lower(ctx, child, ab).keyBy(keys, isSorted) + lower(ctx, child, liftedRelationalLets).keyBy(keys, isSorted) case MatrixFilterRows(child, pred) => - lower(ctx, child, ab) - .filter(subst(lower(ctx, pred, ab), matrixSubstEnv(child))) + lower(ctx, child, liftedRelationalLets) + .filter(subst(lower(ctx, pred, liftedRelationalLets), matrixSubstEnv(child))) case MatrixFilterCols(child, pred) => - lower(ctx, child, ab) + lower(ctx, child, liftedRelationalLets) .mapGlobals('global.insertFields('newColIdx -> - irRange(0, 'global (colsField).len) + irRange(0, 'global(colsField).len) .filter('i ~> - (let(sa = 'global (colsField)('i)) - in subst(lower(ctx, pred, ab), matrixGlobalSubstEnv(child)))))) - .mapRows('row.insertFields(entriesField -> 'global ('newColIdx).map('i ~> 'row (entriesField)('i)))) + (let(sa = 'global(colsField)('i)) + in subst(lower(ctx, pred, liftedRelationalLets), matrixGlobalSubstEnv(child)))))) + .mapRows('row.insertFields( + entriesField -> 'global('newColIdx).map('i ~> 'row(entriesField)('i)) + )) .mapGlobals('global .insertFields(colsField -> - 'global ('newColIdx).map('i ~> 'global (colsField)('i))) + 'global('newColIdx).map('i ~> 'global(colsField)('i))) .dropFields('newColIdx)) case MatrixAnnotateRowsTable(child, table, root, product) => val kt = table.typ.keyType if (kt.size == 1 && kt.types(0) == TInterval(child.typ.rowKeyStruct.types(0))) - TableIntervalJoin(lower(ctx, child, ab), lower(ctx, table, ab), root, product) + TableIntervalJoin( + lower(ctx, child, liftedRelationalLets), + lower(ctx, table, liftedRelationalLets), + root, + product, + ) else - TableLeftJoinRightDistinct(lower(ctx, child, ab), lower(ctx, table, ab), root) + TableLeftJoinRightDistinct( + lower(ctx, child, liftedRelationalLets), + lower(ctx, table, liftedRelationalLets), + root, + ) case MatrixChooseCols(child, oldIndices) => - lower(ctx, child, ab) + lower(ctx, child, liftedRelationalLets) .mapGlobals('global.insertFields('newColIdx -> oldIndices.map(I32))) - .mapRows('row.insertFields(entriesField -> 'global ('newColIdx).map('i ~> 'row (entriesField)('i)))) + .mapRows('row.insertFields( + entriesField -> 'global('newColIdx).map('i ~> 'row(entriesField)('i)) + )) .mapGlobals('global - .insertFields(colsField -> 'global ('newColIdx).map('i ~> 'global (colsField)('i))) + .insertFields(colsField -> 'global('newColIdx).map('i ~> 'global(colsField)('i))) .dropFields('newColIdx)) case MatrixAnnotateColsTable(child, table, root) => val col = Symbol(genUID()) - val colKey = makeStruct(table.typ.key.zip(child.typ.colKey).map { case (tk, mck) => Symbol(tk) -> col(Symbol(mck)) }: _*) - lower(ctx, child, ab) - .mapGlobals(let(__dictfield = lower(ctx, table, ab) - .keyBy(FastSeq()) - .collect() - .apply('rows) - .arrayStructToDict(table.typ.key)) { + val colKey = makeStruct(table.typ.key.zip(child.typ.colKey).map { case (tk, mck) => + Symbol(tk) -> col(Symbol(mck)) + }: _*) + lower(ctx, child, liftedRelationalLets) + .mapGlobals(let(__dictfield = + lower(ctx, table, liftedRelationalLets) + .keyBy(FastSeq()) + .collect() + .apply('rows) + .arrayStructToDict(table.typ.key) + ) { 'global.insertFields(colsField -> - 'global 
(colsField).map(col ~> col.insertFields(Symbol(root) -> '__dictfield.invoke("get", table.typ.valueType, colKey)))) + 'global(colsField).map(col ~> col.insertFields(Symbol(root) -> '__dictfield.invoke( + "get", + table.typ.valueType, + colKey, + )))) }) case MatrixMapGlobals(child, newGlobals) => - lower(ctx, child, ab) + lower(ctx, child, liftedRelationalLets) .mapGlobals( - subst(lower(ctx, newGlobals, ab), BindingEnv(Env[IRProxy]( - "global" -> 'global.selectFields(child.typ.globalType.fieldNames: _*)))) - .insertFields(colsField -> 'global (colsField))) + subst( + lower(ctx, newGlobals, liftedRelationalLets), + BindingEnv(Env[IRProxy]( + "global" -> 'global.selectFields(child.typ.globalType.fieldNames: _*) + )), + ) + .insertFields(colsField -> 'global(colsField)) + ) case MatrixMapRows(child, newRow) => def liftScans(ir: IR): IRProxy = { @@ -204,7 +262,7 @@ object LowerMatrixIR { builder += ((s, a)) Ref(s, a.typ) - case a@AggFold(zero, seqOp, combOp, accumName, otherAccumName, true) => + case a @ AggFold(_, _, _, _, _, true) => val s = genUID() builder += ((s, a)) Ref(s, a.typ) @@ -232,7 +290,7 @@ object LowerMatrixIR { val liftedBody = lift(body, ab) val aggs = ab.result() - val aggIR = AggGroupBy(a, MakeStruct(aggs), true) + val aggIR = AggGroupBy(a, MakeStruct(aggs), true) val uid = Ref(genUID(), aggIR.typ) builder += (uid.name -> aggIR) val valueUID = genUID() @@ -242,7 +300,7 @@ object LowerMatrixIR { Let( (valueUID -> GetField(eltUID, "value")) +: aggs.map { case (name, _) => name -> GetField(Ref(valueUID, valueType), name) }, - MakeTuple.ordered(FastSeq(GetField(eltUID, "key"), liftedBody)) + MakeTuple.ordered(FastSeq(GetField(eltUID, "key"), liftedBody)), ) }) @@ -251,7 +309,8 @@ object LowerMatrixIR { val liftedBody = lift(body, ab) val aggs = ab.result() - val aggIR = AggArrayPerElement(a, elementName, indexName, MakeStruct(aggs), knownLength, true) + val aggIR = + AggArrayPerElement(a, elementName, indexName, MakeStruct(aggs), knownLength, true) val uid = Ref(genUID(), aggIR.typ) builder += (uid.name -> aggIR) @@ -259,14 +318,33 @@ object LowerMatrixIR { Let(aggs.map { case (name, _) => name -> GetField(eltUID, name) }, liftedBody) }) - case AggLet(name, value, body, true) => - val ab = new BoxedArrayBuilder[(String, IR)] - val liftedBody = lift(body, ab) - val aggs = ab.result() - val structResult = MakeStruct(aggs) - val uid = genUID() - builder += (uid -> AggLet(name, value, structResult, true)) - Let(aggs.map { case (name, _) => name -> GetField(Ref(uid, structResult.typ), name) }, liftedBody) + case Block(bindings, body) => + val newBindings = new BoxedArrayBuilder[Binding] + def go(i: Int, builder: BoxedArrayBuilder[(String, IR)]): IR = { + if (i == bindings.length) { + lift(body, builder) + } else bindings(i) match { + case Binding(name, value, Scope.SCAN) => + val ab = new BoxedArrayBuilder[(String, IR)] + val liftedBody = go(i + 1, ab) + val aggs = ab.result() + val structResult = MakeStruct(aggs) + val uid = genUID() + builder += (uid -> Block( + FastSeq(Binding(name, value, Scope.EVAL)), + structResult, + )) + newBindings ++= aggs.map { case (name, _) => + Binding(name, GetField(Ref(uid, structResult.typ), name), Scope.EVAL) + } + liftedBody + case Binding(name, value, scope) => + newBindings += Binding(name, lift(value, builder), scope) + go(i + 1, builder) + } + } + val newBody = go(0, builder) + Block(newBindings.result(), newBody) case _ => MapIR(lift(_, builder))(ir) @@ -275,7 +353,7 @@ object LowerMatrixIR { val ab = new BoxedArrayBuilder[(String, IR)] val 
b0 = lift(ir, ab) - val scans = ab.result() + val scans = ab.result().toFastSeq val scanStruct = MakeStruct(scans) val scanResultRef = Ref(genUID(), scanStruct.typ) @@ -284,27 +362,33 @@ object LowerMatrixIR { irRange(0, 'row(entriesField).len) .filter('i ~> !'row(entriesField)('i).isNA) .streamAgg('i ~> - (aggLet(sa = 'global(colsField)('i), - g = 'row(entriesField)('i)) + (aggLet(sa = 'global(colsField)('i), g = 'row(entriesField)('i)) in b0)) } else irToProxy(b0) - let.applyDynamicNamed("apply")((scanResultRef.name, scanStruct))( - scans.foldLeft[IRProxy](b1) { case (acc, (name, _)) => let.applyDynamicNamed("apply")((name, GetField(scanResultRef, name)))(acc) }) + letDyn( + ((scanResultRef.name, irToProxy(scanStruct)) + +: scans.map { case (name, _) => + name -> irToProxy(GetField(scanResultRef, name)) + }): _* + )(b1) } - - val lc = lower(ctx, child, ab) + val lc = lower(ctx, child, liftedRelationalLets) lc.mapRows(let(n_cols = 'global(colsField).len) { - liftScans(Subst(lower(ctx, newRow, ab), matrixSubstEnvIR(child, lc))) + liftScans(Subst(lower(ctx, newRow, liftedRelationalLets), matrixSubstEnvIR(child, lc))) .insertFields(entriesField -> 'row(entriesField)) }) case MatrixMapCols(child, newCol, _) => - val loweredChild = lower(ctx, child, ab) + val loweredChild = lower(ctx, child, liftedRelationalLets) - def lift(ir: IR, scanBindings: BoxedArrayBuilder[(String, IR)], aggBindings: BoxedArrayBuilder[(String, IR)]): IR = ir match { + def lift( + ir: IR, + scanBindings: BoxedArrayBuilder[(String, IR)], + aggBindings: BoxedArrayBuilder[(String, IR)], + ): IR = ir match { case a: ApplyScanOp => val s = genUID() scanBindings += ((s, a)) @@ -315,7 +399,7 @@ object LowerMatrixIR { aggBindings += ((s, a)) Ref(s, a.typ) - case a@AggFold(zero, seqOp, combOp, accumName, otherAccumName, isScan) => + case a @ AggFold(_, _, _, _, _, isScan) => val s = genUID() if (isScan) { scanBindings += ((s, a)) @@ -364,12 +448,15 @@ object LowerMatrixIR { val valueType = elementType.types(1) ToDict(mapIR(ToStream(uid)) { eltUID => MakeTuple.ordered( - FastSeq(GetField(eltUID, "key"), + FastSeq( + GetField(eltUID, "key"), Let( (valueUID -> GetField(eltUID, "value")) +: - aggs.map { case (name, _) => name -> GetField(Ref(valueUID, valueType), name) }, - liftedBody - ) + aggs.map { case (name, _) => + name -> GetField(Ref(valueUID, valueType), name) + }, + liftedBody, + ), ) ) }) @@ -381,25 +468,51 @@ object LowerMatrixIR { else (lift(body, scanBindings, ab), aggBindings) val aggs = ab.result() - val aggIR = AggArrayPerElement(a, elementName, indexName, MakeStruct(aggs), knownLength, isScan) + val aggIR = + AggArrayPerElement(a, elementName, indexName, MakeStruct(aggs), knownLength, isScan) val uid = Ref(genUID(), aggIR.typ) builder += (uid.name -> aggIR) ToArray(mapIR(ToStream(uid)) { eltUID => Let(aggs.map { case (name, _) => name -> GetField(eltUID, name) }, liftedBody) }) - case AggLet(name, value, body, isScan) => - val ab = new BoxedArrayBuilder[(String, IR)] - val (liftedBody, builder) = - if (isScan) (lift(body, ab, aggBindings), scanBindings) - else (lift(body, scanBindings, ab), aggBindings) - - val aggs = ab.result() - val structResult = MakeStruct(aggs) - - val uid = Ref(genUID(), structResult.typ) - builder += (uid.name -> AggLet(name, value, structResult, isScan)) - Let(aggs.map { case (name, _) => name -> GetField(uid, name) }, liftedBody) + case Block(bindings, body) => + var newBindings = Seq[Binding]() + def go( + i: Int, + scanBindings: BoxedArrayBuilder[(String, IR)], + aggBindings: 
BoxedArrayBuilder[(String, IR)], + ): IR = { + if (i == bindings.length) { + lift(body, scanBindings, aggBindings) + } else bindings(i) match { + case Binding(name, value, Scope.EVAL) => + val lifted = lift(value, scanBindings, aggBindings) + val liftedBody = go(i + 1, scanBindings, aggBindings) + newBindings = Binding(name, lifted, Scope.EVAL) +: newBindings + liftedBody + case Binding(name, value, scope) => + val ab = new BoxedArrayBuilder[(String, IR)] + val liftedBody = if (scope == Scope.SCAN) + go(i + 1, ab, aggBindings) + else + go(i + 1, scanBindings, ab) + + val builder = if (scope == Scope.SCAN) scanBindings else aggBindings + + val aggs = ab.result() + val structResult = MakeStruct(aggs) + + val uid = genUID() + builder += (uid -> Block(FastSeq(Binding(name, value, scope)), structResult)) + newBindings = aggs.map { case (name, _) => + Binding(name, GetField(Ref(uid, structResult.typ), name), Scope.EVAL) + } ++ newBindings + liftedBody + } + } + val newBody = go(0, scanBindings, aggBindings) + Block(newBindings.toFastSeq, newBody) case x: StreamAgg => x case x: StreamAggScan => x @@ -411,7 +524,11 @@ object LowerMatrixIR { val scanBuilder = new BoxedArrayBuilder[(String, IR)] val aggBuilder = new BoxedArrayBuilder[(String, IR)] - val b0 = lift(Subst(lower(ctx, newCol, ab), matrixSubstEnvIR(child, loweredChild)), scanBuilder, aggBuilder) + val b0 = lift( + Subst(lower(ctx, newCol, liftedRelationalLets), matrixSubstEnvIR(child, loweredChild)), + scanBuilder, + aggBuilder, + ) val aggs = aggBuilder.result() val scans = scanBuilder.result() @@ -420,7 +537,10 @@ object LowerMatrixIR { val noOp: (IRProxy => IRProxy, IRProxy => IRProxy) = (identity[IRProxy], identity[IRProxy]) - val (aggOutsideTransformer: (IRProxy => IRProxy), aggInsideTransformer: (IRProxy => IRProxy)) = if (aggs.isEmpty) + val ( + aggOutsideTransformer: (IRProxy => IRProxy), + aggInsideTransformer: (IRProxy => IRProxy), + ) = if (aggs.isEmpty) noOp else { val aggStruct = MakeStruct(aggs) @@ -429,36 +549,61 @@ object LowerMatrixIR { aggLet(va = 'row.selectFields(child.typ.rowType.fieldNames: _*)) { makeStruct( ('count, applyAggOp(Count(), FastSeq(), FastSeq())), - ('array_aggs, irRange(0, 'global(colsField).len) - .aggElements('__element_idx, '__result_idx, Some('global(colsField).len))( - let(sa = 'global(colsField)('__result_idx)) { - aggLet(sa = 'global(colsField)('__element_idx), - g = 'row(entriesField)('__element_idx)) { - aggFilter(!'g.isNA, aggStruct) + ( + 'array_aggs, + irRange(0, 'global(colsField).len) + .aggElements('__element_idx, '__result_idx, Some('global(colsField).len))( + let(sa = 'global(colsField)('__result_idx)) { + aggLet( + sa = 'global(colsField)('__element_idx), + g = 'row(entriesField)('__element_idx), + ) { + aggFilter(!'g.isNA, aggStruct) + } } - }))) - }) + ), + ), + ) + } + ) val ident = genUID() - ab += ((ident, aggResult)) + liftedRelationalLets += ((ident, aggResult)) val aggResultRef = Ref(genUID(), aggResult.typ) - val aggResultElementRef = Ref(genUID(), aggResult.typ.asInstanceOf[TStruct] - .fieldType("array_aggs") - .asInstanceOf[TArray].elementType) + val aggResultElementRef = Ref( + genUID(), + aggResult.typ.asInstanceOf[TStruct] + .fieldType("array_aggs") + .asInstanceOf[TArray].elementType, + ) - val bindResult: IRProxy => IRProxy = let.applyDynamicNamed("apply")((aggResultRef.name, irToProxy(RelationalRef(ident, aggResult.typ)))).apply(_) + val bindResult: IRProxy => IRProxy = let.applyDynamicNamed("apply")(( + aggResultRef.name, + irToProxy(RelationalRef(ident, 
aggResult.typ)), + )).apply(_) val bodyResult: IRProxy => IRProxy = (x: IRProxy) => - let.applyDynamicNamed("apply")((aggResultRef.name, irToProxy(RelationalRef(ident, aggResult.typ)))) - .apply(let(n_rows = Symbol(aggResultRef.name)('count), array_aggs = Symbol(aggResultRef.name)('array_aggs)) { + let.applyDynamicNamed("apply")(( + aggResultRef.name, + irToProxy(RelationalRef(ident, aggResult.typ)), + )) + .apply(let( + n_rows = Symbol(aggResultRef.name)('count), + array_aggs = Symbol(aggResultRef.name)('array_aggs), + ) { let.applyDynamicNamed("apply")((aggResultElementRef.name, 'array_aggs(idx))) { - aggs.foldLeft[IRProxy](x) { case (acc, (name, _)) => let.applyDynamicNamed("apply")((name, GetField(aggResultElementRef, name)))(acc) } + aggs.foldLeft[IRProxy](x) { case (acc, (name, _)) => + let.applyDynamicNamed("apply")((name, GetField(aggResultElementRef, name)))(acc) + } } }) (bindResult, bodyResult) } - val (scanOutsideTransformer: (IRProxy => IRProxy), scanInsideTransformer: (IRProxy => IRProxy)) = if (scans.isEmpty) + val ( + scanOutsideTransformer: (IRProxy => IRProxy), + scanInsideTransformer: (IRProxy => IRProxy), + ) = if (scans.isEmpty) noOp else { val scanStruct = MakeStruct(scans) @@ -466,34 +611,42 @@ object LowerMatrixIR { val scanResultArray = ToArray(StreamAggScan( ToStream(GetField(Ref("global", loweredChild.typ.globalType), colsFieldName)), "sa", - scanStruct)) + scanStruct, + )) val scanResultRef = Ref(genUID(), scanResultArray.typ) - val scanResultElementRef = Ref(genUID(), scanResultArray.typ.asInstanceOf[TArray].elementType) + val scanResultElementRef = + Ref(genUID(), scanResultArray.typ.asInstanceOf[TArray].elementType) - val bindResult: IRProxy => IRProxy = let.applyDynamicNamed("apply")((scanResultRef.name, scanResultArray)).apply(_) + val bindResult: IRProxy => IRProxy = + let.applyDynamicNamed("apply")((scanResultRef.name, scanResultArray)).apply(_) val bodyResult: IRProxy => IRProxy = (x: IRProxy) => - let.applyDynamicNamed("apply")((scanResultElementRef.name, ArrayRef(scanResultRef, idx)))( + let.applyDynamicNamed("apply")(( + scanResultElementRef.name, + ArrayRef(scanResultRef, idx), + ))( scans.foldLeft[IRProxy](x) { case (acc, (name, _)) => let.applyDynamicNamed("apply")((name, GetField(scanResultElementRef, name)))(acc) - }) + } + ) (bindResult, bodyResult) } loweredChild.mapGlobals('global.insertFields(colsField -> - aggOutsideTransformer(scanOutsideTransformer(irRange(0, 'global(colsField).len).map(idxSym ~> let(__cols_array = 'global(colsField), sa = '__cols_array(idxSym)) { - aggInsideTransformer(scanInsideTransformer(b0)) - }))) - )) + aggOutsideTransformer(scanOutsideTransformer(irRange(0, 'global(colsField).len).map( + idxSym ~> let(__cols_array = 'global(colsField), sa = '__cols_array(idxSym)) { + aggInsideTransformer(scanInsideTransformer(b0)) + } + ))))) case MatrixFilterEntries(child, pred) => - val lc = lower(ctx, child, ab) - lc.mapRows('row.insertFields(entriesField -> - irRange(0, 'global (colsField).len).map { + val lc = lower(ctx, child, liftedRelationalLets) + lc.mapRows('row.insertFields(entriesField -> + irRange(0, 'global(colsField).len).map { 'i ~> - let(g = 'row (entriesField)('i)) { - irIf(let(sa = 'global (colsField)('i)) - in !subst(lower(ctx, pred, ab), matrixSubstEnv(child))) { + let(g = 'row(entriesField)('i)) { + irIf(let(sa = 'global(colsField)('i)) + in !subst(lower(ctx, pred, liftedRelationalLets), matrixSubstEnv(child))) { NA(child.typ.entryType) } { 'g @@ -504,30 +657,34 @@ object LowerMatrixIR { case 
MatrixUnionCols(left, right, joinType) => val rightEntries = genUID() val rightCols = genUID() - val ll = lower(ctx, left, ab).distinct() + val ll = lower(ctx, left, liftedRelationalLets).distinct() def handleMissingEntriesArray(entries: Symbol, cols: Symbol): IRProxy = if (joinType == "inner") 'row(entries) else irIf('row(entries).isNA) { irRange(0, 'global(cols).len) - .map('a ~> irToProxy(MakeStruct(right.typ.entryType.fieldNames.map(f => (f, NA(right.typ.entryType.fieldType(f))))))) + .map('a ~> irToProxy(MakeStruct(right.typ.entryType.fieldNames.map(f => + (f, NA(right.typ.entryType.fieldType(f))) + )))) } { 'row(entries) } - val rr = lower(ctx, right, ab).distinct() + val rr = lower(ctx, right, liftedRelationalLets).distinct() TableJoin( ll, rr.mapRows('row.castRename(rr.typ.rowType.rename(Map(entriesFieldName -> rightEntries)))) .mapGlobals('global .insertFields(Symbol(rightCols) -> 'global(colsField)) .selectFields(rightCols)), - joinType) + joinType, + ) .mapRows('row .insertFields(entriesField -> makeArray( handleMissingEntriesArray(entriesField, colsField), - handleMissingEntriesArray(Symbol(rightEntries), Symbol(rightCols))) + handleMissingEntriesArray(Symbol(rightEntries), Symbol(rightCols)), + ) .flatMap('a ~> 'a)) .dropFields(Symbol(rightEntries))) .mapGlobals('global @@ -536,133 +693,158 @@ object LowerMatrixIR { .dropFields(Symbol(rightCols))) case MatrixMapEntries(child, newEntries) => - val loweredChild = lower(ctx, child, ab) + val loweredChild = lower(ctx, child, liftedRelationalLets) val rt = loweredChild.typ.rowType val gt = loweredChild.typ.globalType TableMapRows( loweredChild, InsertFields( Ref("row", rt), - FastSeq((entriesFieldName, ToArray(StreamZip( - FastSeq( - ToStream(GetField(Ref("row", rt), entriesFieldName)), - ToStream(GetField(Ref("global", gt), colsFieldName))), - FastSeq("g", "sa"), - Subst(lower(ctx, newEntries, ab), BindingEnv(Env( - "global" -> SelectFields(Ref("global", gt), child.typ.globalType.fieldNames), - "va" -> SelectFields(Ref("row", rt), child.typ.rowType.fieldNames)))), - ArrayZipBehavior.AssumeSameLength - ))))) + FastSeq(( + entriesFieldName, + ToArray(StreamZip( + FastSeq( + ToStream(GetField(Ref("row", rt), entriesFieldName)), + ToStream(GetField(Ref("global", gt), colsFieldName)), + ), + FastSeq("g", "sa"), + Subst( + lower(ctx, newEntries, liftedRelationalLets), + BindingEnv(Env( + "global" -> SelectFields(Ref("global", gt), child.typ.globalType.fieldNames), + "va" -> SelectFields(Ref("row", rt), child.typ.rowType.fieldNames), + )), + ), + ArrayZipBehavior.AssumeSameLength, + )), + )), + ), ) - case MatrixRepartition(child, n, shuffle) => TableRepartition(lower(ctx, child, ab), n, shuffle) + case MatrixRepartition(child, n, shuffle) => + TableRepartition(lower(ctx, child, liftedRelationalLets), n, shuffle) - case MatrixFilterIntervals(child, intervals, keep) => TableFilterIntervals(lower(ctx, child, ab), intervals, keep) + case MatrixFilterIntervals(child, intervals, keep) => + TableFilterIntervals(lower(ctx, child, liftedRelationalLets), intervals, keep) case MatrixUnionRows(children) => // FIXME: this should check that all children have the same column keys. 
- val first = lower(ctx, children.head, ab) + val first = lower(ctx, children.head, liftedRelationalLets) TableUnion(FastSeq(first) ++ - children.tail.map(lower(ctx, _, ab) + children.tail.map(lower(ctx, _, liftedRelationalLets) .mapRows('row.selectFields(first.typ.rowType.fieldNames: _*)))) - case MatrixDistinctByRow(child) => TableDistinct(lower(ctx, child, ab)) + case MatrixDistinctByRow(child) => TableDistinct(lower(ctx, child, liftedRelationalLets)) - case MatrixRowsHead(child, n) => TableHead(lower(ctx, child, ab), n) - case MatrixRowsTail(child, n) => TableTail(lower(ctx, child, ab), n) + case MatrixRowsHead(child, n) => TableHead(lower(ctx, child, liftedRelationalLets), n) + case MatrixRowsTail(child, n) => TableTail(lower(ctx, child, liftedRelationalLets), n) - case MatrixColsHead(child, n) => lower(ctx, child, ab) - .mapGlobals('global.insertFields(colsField -> 'global (colsField).arraySlice(0, Some(n), 1))) - .mapRows('row.insertFields(entriesField -> 'row (entriesField).arraySlice(0, Some(n), 1))) + case MatrixColsHead(child, n) => lower(ctx, child, liftedRelationalLets) + .mapGlobals('global.insertFields(colsField -> 'global(colsField).arraySlice( + 0, + Some(n), + 1, + ))) + .mapRows('row.insertFields(entriesField -> 'row(entriesField).arraySlice(0, Some(n), 1))) - case MatrixColsTail(child, n) => lower(ctx, child, ab) - .mapGlobals('global.insertFields(colsField -> 'global (colsField).arraySlice(-n, None, 1))) - .mapRows('row.insertFields(entriesField -> 'row (entriesField).arraySlice(-n, None, 1))) + case MatrixColsTail(child, n) => lower(ctx, child, liftedRelationalLets) + .mapGlobals('global.insertFields(colsField -> 'global(colsField).arraySlice(-n, None, 1))) + .mapRows('row.insertFields(entriesField -> 'row(entriesField).arraySlice(-n, None, 1))) case MatrixExplodeCols(child, path) => - val loweredChild = lower(ctx, child, ab) + val loweredChild = lower(ctx, child, liftedRelationalLets) val lengths = Symbol(genUID()) val colIdx = Symbol(genUID()) val nestedIdx = Symbol(genUID()) val colElementUID1 = Symbol(genUID()) - - val nestedRefs = path.init.scanLeft('global (colsField)(colIdx): IRProxy)((irp, name) => irp(Symbol(name))) + val nestedRefs = + path.init.scanLeft('global(colsField)(colIdx): IRProxy)((irp, name) => irp(Symbol(name))) val postExplodeSelector = path.zip(nestedRefs).zipWithIndex.foldRight[IRProxy](nestedIdx) { case (((field, ref), i), arg) => ref.insertFields(Symbol(field) -> (if (i == nestedRefs.length - 1) - ref(Symbol(field)).toArray(arg) - else - arg)) + ref(Symbol(field)).toArray(arg) + else + arg)) } - val arrayIR = path.foldLeft[IRProxy](colElementUID1) { case (irp, fieldName) => irp(Symbol(fieldName)) } + val arrayIR = path.foldLeft[IRProxy](colElementUID1) { case (irp, fieldName) => + irp(Symbol(fieldName)) + } loweredChild - .mapGlobals('global.insertFields(lengths -> 'global (colsField).map({ + .mapGlobals('global.insertFields(lengths -> 'global(colsField).map({ colElementUID1 ~> arrayIR.len.orElse(0) }))) .mapGlobals('global.insertFields(colsField -> - irRange(0, 'global (colsField).len, 1) + irRange(0, 'global(colsField).len, 1) .flatMap({ colIdx ~> - irRange(0, 'global (lengths)(colIdx), 1) + irRange(0, 'global(lengths)(colIdx), 1) .map({ nestedIdx ~> postExplodeSelector }) }))) .mapRows('row.insertFields(entriesField -> - irRange(0, 'row (entriesField).len, 1) + irRange(0, 'row(entriesField).len, 1) .flatMap(colIdx ~> - irRange(0, 'global (lengths)(colIdx), 1).map(Symbol(genUID()) ~> 'row (entriesField)(colIdx))))) + irRange(0, 
'global(lengths)(colIdx), 1).map( + Symbol(genUID()) ~> 'row(entriesField)(colIdx) + )))) .mapGlobals('global.dropFields(lengths)) case MatrixAggregateRowsByKey(child, entryExpr, rowExpr) => - val substEnv = matrixSubstEnv(child) - val eeSub = subst(lower(ctx, entryExpr, ab), substEnv) - val reSub = subst(lower(ctx, rowExpr, ab), substEnv) - lower(ctx, child, ab) + val eeSub = subst(lower(ctx, entryExpr, liftedRelationalLets), substEnv) + val reSub = subst(lower(ctx, rowExpr, liftedRelationalLets), substEnv) + lower(ctx, child, liftedRelationalLets) .aggregateByKey( - reSub.insertFields(entriesField -> irRange(0, 'global (colsField).len) - .aggElements('__element_idx, '__result_idx, Some('global (colsField).len))( - let(sa = 'global (colsField)('__result_idx)) { - aggLet(sa = 'global (colsField)('__element_idx), - g = 'row (entriesField)('__element_idx)) { + reSub.insertFields(entriesField -> irRange(0, 'global(colsField).len) + .aggElements('__element_idx, '__result_idx, Some('global(colsField).len))( + let(sa = 'global(colsField)('__result_idx)) { + aggLet( + sa = 'global(colsField)('__element_idx), + g = 'row(entriesField)('__element_idx), + ) { aggFilter(!'g.isNA, eeSub) } - }))) + } + )) + ) case MatrixCollectColsByKey(child) => - lower(ctx, child, ab) + lower(ctx, child, liftedRelationalLets) .mapGlobals('global.insertFields('newColIdx -> - irRange(0, 'global (colsField).len).map { + irRange(0, 'global(colsField).len).map { 'i ~> - makeTuple('global (colsField)('i).selectFields(child.typ.colKey: _*), - 'i) + makeTuple('global(colsField)('i).selectFields(child.typ.colKey: _*), 'i) }.groupByKey.toArray)) .mapRows('row.insertFields(entriesField -> - 'global ('newColIdx).map { + 'global('newColIdx).map { 'kv ~> makeStruct(child.typ.entryType.fieldNames.map { s => - (Symbol(s), 'kv ('value).map { - 'i ~> 'row (entriesField)('i)(Symbol(s)) - }) + ( + Symbol(s), + 'kv('value).map { + 'i ~> 'row(entriesField)('i)(Symbol(s)) + }, + ) }: _*) })) .mapGlobals('global .insertFields(colsField -> - 'global ('newColIdx).map { + 'global('newColIdx).map { 'kv ~> - 'kv ('key).insertFields( + 'kv('key).insertFields( child.typ.colValueStruct.fieldNames.map { s => - (Symbol(s), 'kv ('value).map('i ~> 'global (colsField)('i)(Symbol(s)))) - }: _*) + (Symbol(s), 'kv('value).map('i ~> 'global(colsField)('i)(Symbol(s)))) + }: _* + ) }) - .dropFields('newColIdx) - ) + .dropFields('newColIdx)) - case MatrixExplodeRows(child, path) => TableExplode(lower(ctx, child, ab), path) + case MatrixExplodeRows(child, path) => + TableExplode(lower(ctx, child, liftedRelationalLets), path) case mr: MatrixRead => mr.lower() @@ -676,69 +858,78 @@ object LowerMatrixIR { val keyMap = Symbol(genUID()) val aggElementIdx = Symbol(genUID()) - val e1 = Env[IRProxy]("global" -> 'global.selectFields(child.typ.globalType.fieldNames: _*), - "va" -> 'row.selectFields(child.typ.rowType.fieldNames: _*)) + val e1 = Env[IRProxy]( + "global" -> 'global.selectFields(child.typ.globalType.fieldNames: _*), + "va" -> 'row.selectFields(child.typ.rowType.fieldNames: _*), + ) val e2 = Env[IRProxy]("global" -> 'global.selectFields(child.typ.globalType.fieldNames: _*)) - val ceSub = subst(lower(ctx, colExpr, ab), BindingEnv(e2, agg = Some(e1))) - val eeSub = subst(lower(ctx, entryExpr, ab), BindingEnv(e1, agg = Some(e1))) + val ceSub = subst(lower(ctx, colExpr, liftedRelationalLets), BindingEnv(e2, agg = Some(e1))) + val eeSub = + subst(lower(ctx, entryExpr, liftedRelationalLets), BindingEnv(e1, agg = Some(e1))) - lower(ctx, child, ab) + lower(ctx, 
child, liftedRelationalLets) .mapGlobals('global.insertFields(keyMap -> - let(__cols_field = 'global (colsField)) { + let(__cols_field = 'global(colsField)) { irRange(0, '__cols_field.len) - .map(originalColIdx ~> let(__cols_field_element = '__cols_field (originalColIdx)) { - makeStruct('key -> '__cols_field_element.selectFields(colKey: _*), 'value -> originalColIdx) + .map(originalColIdx ~> let(__cols_field_element = '__cols_field(originalColIdx)) { + makeStruct( + 'key -> '__cols_field_element.selectFields(colKey: _*), + 'value -> originalColIdx, + ) }) .groupByKey .toArray })) .mapRows('row.insertFields(entriesField -> - let(__entries = 'row (entriesField), __key_map = 'global (keyMap)) { + let(__entries = 'row(entriesField), __key_map = 'global(keyMap)) { irRange(0, '__key_map.len) - .map(newColIdx1 ~> '__key_map (newColIdx1) + .map(newColIdx1 ~> '__key_map(newColIdx1) .apply('value) .streamAgg(aggElementIdx ~> - aggLet(g = '__entries (aggElementIdx), sa = 'global (colsField)(aggElementIdx)) { + aggLet(g = '__entries(aggElementIdx), sa = 'global(colsField)(aggElementIdx)) { aggFilter(!'g.isNA, eeSub) })) })) .mapGlobals( 'global.insertFields(colsField -> - let(__key_map = 'global (keyMap)) { + let(__key_map = 'global(keyMap)) { irRange(0, '__key_map.len) .map(newColIdx2 ~> concatStructs( - '__key_map (newColIdx2)('key), - '__key_map (newColIdx2)('value) - .streamAgg(colsAggIdx ~> aggLet(sa = 'global (colsField)(colsAggIdx)) { + '__key_map(newColIdx2)('key), + '__key_map(newColIdx2)('value) + .streamAgg(colsAggIdx ~> aggLet(sa = 'global(colsField)(colsAggIdx)) { ceSub - }) + }), )) - } - ).dropFields(keyMap)) + }).dropFields(keyMap) + ) case MatrixLiteral(_, tl) => tl } if (!mir.typ.isCompatibleWith(lowered.typ)) - throw new RuntimeException(s"Lowering changed type:\n BEFORE: ${ Pretty(ctx, mir) }\n ${ mir.typ }\n ${ mir.typ.canonicalTableType}\n AFTER: ${ Pretty(ctx, lowered) }\n ${ lowered.typ }") + throw new RuntimeException( + s"Lowering changed type:\n BEFORE: ${Pretty(ctx, mir)}\n ${mir.typ}\n ${mir.typ.canonicalTableType}\n AFTER: ${Pretty(ctx, lowered)}\n ${lowered.typ}" + ) lowered } - private[this] def lower( ctx: ExecuteContext, tir: TableIR, - ab: BoxedArrayBuilder[(String, IR)] + ab: BoxedArrayBuilder[(String, IR)], ): TableIR = { val lowered = tir match { case CastMatrixToTable(child, entries, cols) => lower(ctx, child, ab) .mapRows('row.selectFields(child.typ.rowType.fieldNames ++ Array(entriesFieldName): _*)) - .mapGlobals('global.selectFields(child.typ.globalType.fieldNames ++ Array(colsFieldName): _*)) + .mapGlobals('global.selectFields( + child.typ.globalType.fieldNames ++ Array(colsFieldName): _* + )) .rename(Map(entriesFieldName -> entries), Map(colsFieldName -> cols)) - case x@MatrixEntriesTable(child) => + case x @ MatrixEntriesTable(child) => val lc = lower(ctx, child, ab) if (child.typ.rowKey.nonEmpty && child.typ.colKey.nonEmpty) { @@ -750,43 +941,59 @@ object LowerMatrixIR { val values = Symbol(genUID()) lc .mapGlobals('global.insertFields(oldColIdx -> - irRange(0, 'global (colsField).len) - .map(lambdaIdx1 ~> makeStruct('key -> 'global (colsField)(lambdaIdx1).selectFields(child.typ.colKey: _*), 'value -> lambdaIdx1)) + irRange(0, 'global(colsField).len) + .map(lambdaIdx1 ~> makeStruct( + 'key -> 'global(colsField)(lambdaIdx1).selectFields(child.typ.colKey: _*), + 'value -> lambdaIdx1, + )) .sort(ascending = true, onKey = true) .map(lambdaIdx1 ~> lambdaIdx1('value)))) - .aggregateByKey(makeStruct(values -> applyAggOp(Collect(), seqOpArgs = 
FastSeq('row.selectFields(lc.typ.valueType.fieldNames: _*))))) + .aggregateByKey(makeStruct(values -> applyAggOp( + Collect(), + seqOpArgs = FastSeq('row.selectFields(lc.typ.valueType.fieldNames: _*)), + ))) .mapRows('row.dropFields(values).insertFields(toExplode -> - 'global (oldColIdx) - .flatMap(lambdaIdx1 ~> 'row (values) + 'global(oldColIdx) + .flatMap(lambdaIdx1 ~> 'row(values) .filter(lambdaIdx2 ~> !lambdaIdx2(entriesField)(lambdaIdx1).isNA) - .map(lambdaIdx3 ~> let(__col = 'global (colsField)(lambdaIdx1), __entry = lambdaIdx3(entriesField)(lambdaIdx1)) { + .map(lambdaIdx3 ~> let( + __col = 'global(colsField)(lambdaIdx1), + __entry = lambdaIdx3(entriesField)(lambdaIdx1), + ) { makeStruct( - child.typ.rowValueStruct.fieldNames.map(Symbol(_)).map(f => f -> lambdaIdx3(f)) ++ - child.typ.colType.fieldNames.map(Symbol(_)).map(f => f -> '__col (f)) ++ - child.typ.entryType.fieldNames.map(Symbol(_)).map(f => f -> '__entry (f)): _* + child.typ.rowValueStruct.fieldNames.map(Symbol(_)).map(f => + f -> lambdaIdx3(f) + ) ++ + child.typ.colType.fieldNames.map(Symbol(_)).map(f => f -> '__col(f)) ++ + child.typ.entryType.fieldNames.map(Symbol(_)).map(f => f -> '__entry(f)): _* ) })))) - .explode(toExplode) .mapRows(makeStruct(x.typ.rowType.fieldNames.map { f => val fd = Symbol(f) - (fd, if (child.typ.rowKey.contains(f)) 'row (fd) else 'row (toExplode) (fd)) }: _*)) + (fd, if (child.typ.rowKey.contains(f)) 'row(fd) else 'row(toExplode)(fd)) + }: _*)) .mapGlobals('global.dropFields(colsField, oldColIdx)) .keyBy(child.typ.rowKey ++ child.typ.colKey, isSorted = true) } else { val colIdx = Symbol(genUID()) val lambdaIdx = Symbol(genUID()) val result = lc - .mapRows('row.insertFields(colIdx -> irRange(0, 'global (colsField).len) - .filter(lambdaIdx ~> !'row (entriesField)(lambdaIdx).isNA))) + .mapRows('row.insertFields(colIdx -> irRange(0, 'global(colsField).len) + .filter(lambdaIdx ~> !'row(entriesField)(lambdaIdx).isNA))) .explode(colIdx) - .mapRows(let(__col_struct = 'global (colsField)('row (colIdx)), - __entry_struct = 'row (entriesField)('row (colIdx))) { - val newFields = child.typ.colType.fieldNames.map(Symbol(_)).map(f => f -> '__col_struct (f)) ++ - child.typ.entryType.fieldNames.map(Symbol(_)).map(f => f -> '__entry_struct (f)) - - 'row.dropFields(entriesField, colIdx).insertFieldsList(newFields, - ordering = Some(x.typ.rowType.fieldNames.toFastSeq)) + .mapRows(let( + __col_struct = 'global(colsField)('row(colIdx)), + __entry_struct = 'row(entriesField)('row(colIdx)), + ) { + val newFields = + child.typ.colType.fieldNames.map(Symbol(_)).map(f => f -> '__col_struct(f)) ++ + child.typ.entryType.fieldNames.map(Symbol(_)).map(f => f -> '__entry_struct(f)) + + 'row.dropFields(entriesField, colIdx).insertFieldsList( + newFields, + ordering = Some(x.typ.rowType.fieldNames.toFastSeq), + ) }) .mapGlobals('global.dropFields(colsField)) if (child.typ.colKey.isEmpty) @@ -799,9 +1006,16 @@ object LowerMatrixIR { case MatrixToTableApply(child, function) => val loweredChild = lower(ctx, child, ab) - TableToTableApply(loweredChild, + TableToTableApply( + loweredChild, function.lower() - .getOrElse(WrappedMatrixToTableFunction(function, colsFieldName, entriesFieldName, child.typ.colKey))) + .getOrElse(WrappedMatrixToTableFunction( + function, + colsFieldName, + entriesFieldName, + child.typ.colKey, + )), + ) case MatrixRowsTable(child) => lower(ctx, child, ab) @@ -812,16 +1026,18 @@ object LowerMatrixIR { val colKey = child.typ.colKey let(__cols_and_globals = lower(ctx, child, ab).getGlobals) { val sortedCols 
= if (colKey.isEmpty) - '__cols_and_globals (colsField) + '__cols_and_globals(colsField) else - '__cols_and_globals (colsField).map { '__cols_element ~> - makeStruct( - // key struct - '_1 -> '__cols_element.selectFields(colKey: _*), - '_2 -> '__cols_element) + '__cols_and_globals(colsField).map { + '__cols_element ~> + makeStruct( + // key struct + '_1 -> '__cols_element.selectFields(colKey: _*), + '_2 -> '__cols_element, + ) }.sort(true, onKey = true) .map { - 'elt ~> 'elt ('_2) + 'elt ~> 'elt('_2) } makeStruct('rows -> sortedCols, 'global -> '__cols_and_globals.dropFields(colsField)) }.parallelize(None).keyBy(child.typ.colKey) @@ -836,7 +1052,7 @@ object LowerMatrixIR { private[this] def lower( ctx: ExecuteContext, bmir: BlockMatrixIR, - ab: BoxedArrayBuilder[(String, IR)] + ab: BoxedArrayBuilder[(String, IR)], ): BlockMatrixIR = { val lowered = bmir match { case noMatrixChildren => lowerChildren(ctx, noMatrixChildren, ab).asInstanceOf[BlockMatrixIR] @@ -848,46 +1064,69 @@ object LowerMatrixIR { private[this] def lower( ctx: ExecuteContext, ir: IR, - ab: BoxedArrayBuilder[(String, IR)] + ab: BoxedArrayBuilder[(String, IR)], ): IR = { val lowered = ir match { - case MatrixToValueApply(child, function) => TableToValueApply(lower(ctx, child, ab), function.lower() - .getOrElse(WrappedMatrixToValueFunction(function, colsFieldName, entriesFieldName, child.typ.colKey))) + case MatrixToValueApply(child, function) => TableToValueApply( + lower(ctx, child, ab), + function.lower() + .getOrElse(WrappedMatrixToValueFunction( + function, + colsFieldName, + entriesFieldName, + child.typ.colKey, + )), + ) case MatrixWrite(child, writer) => - TableWrite(lower(ctx, child, ab), WrappedMatrixWriter(writer, colsFieldName, entriesFieldName, child.typ.colKey)) + TableWrite( + lower(ctx, child, ab), + WrappedMatrixWriter(writer, colsFieldName, entriesFieldName, child.typ.colKey), + ) case MatrixMultiWrite(children, writer) => - TableMultiWrite(children.map(lower(ctx, _, ab)), WrappedMatrixNativeMultiWriter(writer, children.head.typ.colKey)) + TableMultiWrite( + children.map(lower(ctx, _, ab)), + WrappedMatrixNativeMultiWriter(writer, children.head.typ.colKey), + ) case MatrixCount(child) => lower(ctx, child, ab) .aggregate(makeTuple(applyAggOp(Count(), FastSeq(), FastSeq()), 'global(colsField).len)) case MatrixAggregate(child, query) => val lc = lower(ctx, child, ab) - val idx = Symbol(genUID()) - TableAggregate(lc, + TableAggregate( + lc, aggExplodeIR( filterIR( zip2( ToStream(GetField(Ref("row", lc.typ.rowType), entriesFieldName)), ToStream(GetField(Ref("global", lc.typ.globalType), colsFieldName)), - ArrayZipBehavior.AssertSameLength + ArrayZipBehavior.AssertSameLength, ) { case (e, c) => MakeTuple.ordered(FastSeq(e, c)) - }) { filterTuple => - ApplyUnaryPrimOp(Bang, IsNA(GetTupleElement(filterTuple, 0))) - }) { explodedTuple => - AggLet("g", GetTupleElement(explodedTuple, 0), - AggLet("sa", GetTupleElement(explodedTuple, 1), Subst(query, matrixSubstEnvIR(child, lc)), - isScan = false), - isScan = false) - }) + } + )(filterTuple => ApplyUnaryPrimOp(Bang, IsNA(GetTupleElement(filterTuple, 0)))) + ) { explodedTuple => + AggLet( + "g", + GetTupleElement(explodedTuple, 0), + AggLet( + "sa", + GetTupleElement(explodedTuple, 1), + Subst(query, matrixSubstEnvIR(child, lc)), + isScan = false, + ), + isScan = false, + ) + }, + ) case _ => lowerChildren(ctx, ir, ab).asInstanceOf[IR] } assertTypeUnchanged(ir, lowered) lowered } - private[this] def assertTypeUnchanged(original: BaseIR, lowered: BaseIR) { + 
private[this] def assertTypeUnchanged(original: BaseIR, lowered: BaseIR): Unit = if (lowered.typ != original.typ) - fatal(s"lowering changed type:\n before: ${ original.typ }\n after: ${ lowered.typ }\n ${ original.getClass.getName } => ${ lowered.getClass.getName }") - } + fatal( + s"lowering changed type:\n before: ${original.typ}\n after: ${lowered.typ}\n ${original.getClass.getName} => ${lowered.getClass.getName}" + ) } diff --git a/hail/src/main/scala/is/hail/expr/ir/LowerOrInterpretNonCompilable.scala b/hail/src/main/scala/is/hail/expr/ir/LowerOrInterpretNonCompilable.scala index f36863d5f2c..02e7db7a9fa 100644 --- a/hail/src/main/scala/is/hail/expr/ir/LowerOrInterpretNonCompilable.scala +++ b/hail/src/main/scala/is/hail/expr/ir/LowerOrInterpretNonCompilable.scala @@ -16,20 +16,20 @@ object LowerOrInterpretNonCompilable { val result = CanLowerEfficiently(ctx, value) match { case Some(failReason) => log.info(s"LowerOrInterpretNonCompilable: cannot efficiently lower query: $failReason") - log.info(s"interpreting non-compilable result: ${ value.getClass.getSimpleName }") + log.info(s"interpreting non-compilable result: ${value.getClass.getSimpleName}") val v = Interpret.alreadyLowered(ctx, value) if (value.typ == TVoid) { Begin(FastSeq()) } else Literal.coerce(value.typ, v) case None => log.info(s"LowerOrInterpretNonCompilable: whole stage code generation is a go!") - log.info(s"lowering result: ${ value.getClass.getSimpleName }") + log.info(s"lowering result: ${value.getClass.getSimpleName}") val fullyLowered = LowerToDistributedArrayPass(DArrayLowering.All).transform(ctx, value) .asInstanceOf[IR] - log.info(s"compiling and evaluating result: ${ value.getClass.getSimpleName }") + log.info(s"compiling and evaluating result: ${value.getClass.getSimpleName}") CompileAndEvaluate.evalToIR(ctx, fullyLowered, true) } - log.info(s"took ${ formatTime(System.nanoTime() - preTime) }") + log.info(s"took ${formatTime(System.nanoTime() - preTime)}") assert(result.typ == value.typ) result } @@ -55,7 +55,8 @@ object LowerOrInterpretNonCompilable { res case None => throw new RuntimeException(name) } - case x: IR if InterpretableButNotCompilable(x) => evaluate(rewriteChildren(x, m).asInstanceOf[IR]) + case x: IR if InterpretableButNotCompilable(x) => + evaluate(rewriteChildren(x, m).asInstanceOf[IR]) case _ => rewriteChildren(x, m) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/LoweringAnalyses.scala b/hail/src/main/scala/is/hail/expr/ir/LoweringAnalyses.scala index 523d4e421a0..2b0e5afde82 100644 --- a/hail/src/main/scala/is/hail/expr/ir/LoweringAnalyses.scala +++ b/hail/src/main/scala/is/hail/expr/ir/LoweringAnalyses.scala @@ -3,10 +3,14 @@ package is.hail.expr.ir import is.hail.backend.ExecuteContext object LoweringAnalyses { - def apply(ir: BaseIR, ctx:ExecuteContext): LoweringAnalyses = { + def apply(ir: BaseIR, ctx: ExecuteContext): LoweringAnalyses = { val requirednessAnalysis = Requiredness(ir, ctx) val distinctKeyedAnalysis = DistinctlyKeyed.apply(ir) LoweringAnalyses(requirednessAnalysis, distinctKeyedAnalysis) } } -case class LoweringAnalyses(requirednessAnalysis: RequirednessAnalysis, distinctKeyedAnalysis: DistinctKeyedAnalysis) + +case class LoweringAnalyses( + requirednessAnalysis: RequirednessAnalysis, + distinctKeyedAnalysis: DistinctKeyedAnalysis, +) diff --git a/hail/src/main/scala/is/hail/expr/ir/MapIR.scala b/hail/src/main/scala/is/hail/expr/ir/MapIR.scala index 4d9e3d9a956..d04033a9b95 100644 --- a/hail/src/main/scala/is/hail/expr/ir/MapIR.scala +++ 
b/hail/src/main/scala/is/hail/expr/ir/MapIR.scala @@ -5,9 +5,9 @@ object MapIR { case ta: TableAggregate => ta case ma: MatrixAggregate => ma case _ => ir.mapChildren { - case c: IR => f(c) - case c => c - } + case c: IR => f(c) + case c => c + } } def mapBaseIR(ir: BaseIR, f: BaseIR => BaseIR): BaseIR = f(ir.mapChildren(mapBaseIR(_, f))) @@ -18,4 +18,4 @@ object VisitIR { f(ir) ir.children.foreach(apply(_)(f)) } -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala index 4b45e5bd4df..b2af1d0239b 100644 --- a/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala @@ -1,4 +1,3 @@ - package is.hail.expr.ir import is.hail.HailContext @@ -15,24 +14,37 @@ import is.hail.types._ import is.hail.types.virtual._ import is.hail.utils._ import is.hail.variant._ + import org.apache.spark.sql.Row import org.json4s._ import org.json4s.jackson.JsonMethods object MatrixIR { - def read(fs: FS, path: String, dropCols: Boolean = false, dropRows: Boolean = false, requestedType: Option[MatrixType] = None): MatrixIR = { + def read( + fs: FS, + path: String, + dropCols: Boolean = false, + dropRows: Boolean = false, + requestedType: Option[MatrixType] = None, + ): MatrixIR = { val reader = MatrixNativeReader(fs, path) MatrixRead(requestedType.getOrElse(reader.fullMatrixType), dropCols, dropRows, reader) } - def range(nRows: Int, nCols: Int, nPartitions: Option[Int], dropCols: Boolean = false, dropRows: Boolean = false): MatrixIR = { + def range( + nRows: Int, + nCols: Int, + nPartitions: Option[Int], + dropCols: Boolean = false, + dropRows: Boolean = false, + ): MatrixIR = { val reader = MatrixRangeReader(nRows, nCols, nPartitions) val requestedType = reader.fullMatrixTypeWithoutUIDs MatrixRead(requestedType, dropCols = dropCols, dropRows = dropRows, reader = reader) } } -abstract sealed class MatrixIR extends BaseIR { +sealed abstract class MatrixIR extends BaseIR { def typ: MatrixType def partitionCounts: Option[IndexedSeq[Long]] = None @@ -43,12 +55,11 @@ abstract sealed class MatrixIR extends BaseIR { override def copy(newChildren: IndexedSeq[BaseIR]): MatrixIR - def unpersist(): MatrixIR = { + def unpersist(): MatrixIR = this match { case MatrixLiteral(typ, tl) => MatrixLiteral(typ, tl.unpersist().asInstanceOf[TableLiteral]) case x => x } - } def pyUnpersist(): MatrixIR = unpersist() @@ -56,14 +67,30 @@ abstract sealed class MatrixIR extends BaseIR { } object MatrixLiteral { - def apply(ctx: ExecuteContext, typ: MatrixType, rvd: RVD, globals: Row, colValues: IndexedSeq[Row]): MatrixLiteral = { + def apply( + ctx: ExecuteContext, + typ: MatrixType, + rvd: RVD, + globals: Row, + colValues: IndexedSeq[Row], + ): MatrixLiteral = { val tt = typ.canonicalTableType - MatrixLiteral(typ, + MatrixLiteral( + typ, TableLiteral( - TableValue(ctx, tt, - BroadcastRow(ctx, Row.fromSeq(globals.toSeq :+ colValues), typ.canonicalTableType.globalType), - rvd), - ctx.theHailClassLoader)) + TableValue( + ctx, + tt, + BroadcastRow( + ctx, + Row.fromSeq(globals.toSeq :+ colValues), + typ.canonicalTableType.globalType, + ), + rvd, + ), + ctx.theHailClassLoader, + ), + ) } } @@ -132,9 +159,8 @@ trait MatrixReader { def renderShort(): String - def defaultRender(): String = { + def defaultRender(): String = StringEscapeUtils.escapeString(JsonMethods.compact(toJValue)) - } final def matrixToTableType(mt: MatrixType, includeColsArray: Boolean = true): TableType = { TableType( @@ -148,7 
+174,8 @@ trait MatrixReader { globalType = if (includeColsArray) mt.globalType.appendKey(LowerMatrixIR.colsFieldName, TArray(mt.colType)) else - mt.globalType) + mt.globalType, + ) } final def rowUIDFieldName: String = MatrixReader.rowUIDFieldName @@ -161,7 +188,9 @@ abstract class MatrixHybridReader extends TableReaderWithExtraUID with MatrixRea override def fullTypeWithoutUIDs: TableType = matrixToTableType( fullMatrixTypeWithoutUIDs.copy( - colType = fullMatrixTypeWithoutUIDs.colType.appendKey(colUIDFieldName, colUIDType))) + colType = fullMatrixTypeWithoutUIDs.colType.appendKey(colUIDFieldName, colUIDType) + ) + ) override def defaultRender(): String = super.defaultRender() @@ -173,12 +202,22 @@ abstract class MatrixHybridReader extends TableReaderWithExtraUID with MatrixRea tr, InsertFields( Ref("row", tr.typ.rowType), - FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray(FastSeq(), TArray(requestedType.entryType))))) + FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray( + FastSeq(), + TArray(requestedType.entryType), + )), + ), + ) tr = TableMapGlobals( tr, InsertFields( Ref("global", tr.typ.globalType), - FastSeq(LowerMatrixIR.colsFieldName -> MakeArray(FastSeq(), TArray(requestedType.colType))))) + FastSeq(LowerMatrixIR.colsFieldName -> MakeArray( + FastSeq(), + TArray(requestedType.colType), + )), + ), + ) } tr } @@ -192,7 +231,7 @@ object MatrixNativeReader { val spec = (RelationalSpec.read(fs, params.path): @unchecked) match { case mts: AbstractMatrixTableSpec => mts - case _: AbstractTableSpec => fatal(s"file is a Table, not a MatrixTable: '${ params.path }'") + case _: AbstractTableSpec => fatal(s"file is a Table, not a MatrixTable: '${params.path}'") } val intervals = params.options.map(_.intervals) @@ -221,22 +260,25 @@ object MatrixNativeReader { case class MatrixNativeReaderParameters( path: String, - options: Option[NativeReaderOptions]) + options: Option[NativeReaderOptions], +) class MatrixNativeReader( val params: MatrixNativeReaderParameters, - spec: AbstractMatrixTableSpec + spec: AbstractMatrixTableSpec, ) extends MatrixReader { def pathsUsed: Seq[String] = FastSeq(params.path) - override def renderShort(): String = s"(MatrixNativeReader ${ params.path } ${ params.options.map(_.renderShort()).getOrElse("") })" + override def renderShort(): String = + s"(MatrixNativeReader ${params.path} ${params.options.map(_.renderShort()).getOrElse("")})" lazy val columnCount: Option[Int] = Some(spec.colsSpec .partitionCounts .sum .toInt) - def partitionCounts: Option[IndexedSeq[Long]] = if (params.options.isEmpty) Some(spec.partitionCounts) else None + def partitionCounts: Option[IndexedSeq[Long]] = + if (params.options.isEmpty) Some(spec.partitionCounts) else None def fullMatrixTypeWithoutUIDs: MatrixType = spec.matrix_type @@ -250,18 +292,29 @@ class MatrixNativeReader( if (dropCols) { val tt = TableType(requestedType.rowType, requestedType.rowKey, requestedType.globalType) - val trdr: TableReader = new TableNativeReader(TableNativeReaderParameters(rowsPath, params.options), spec.rowsSpec) + val trdr: TableReader = + new TableNativeReader(TableNativeReaderParameters(rowsPath, params.options), spec.rowsSpec) var tr: TableIR = TableRead(tt, dropRows, trdr) tr = TableMapGlobals( tr, InsertFields( Ref("global", tr.typ.globalType), - FastSeq(LowerMatrixIR.colsFieldName -> MakeArray(FastSeq(), TArray(requestedType.colType))))) + FastSeq(LowerMatrixIR.colsFieldName -> MakeArray( + FastSeq(), + TArray(requestedType.colType), + )), + ), + ) TableMapRows( tr, InsertFields( Ref("row", 
tr.typ.rowType), - FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray(FastSeq(), TArray(requestedType.entryType))))) + FastSeq(LowerMatrixIR.entriesFieldName -> MakeArray( + FastSeq(), + TArray(requestedType.entryType), + )), + ), + ) } else { val tt = matrixToTableType(requestedType, includeColsArray = false) val trdr = TableNativeZippedReader( @@ -269,33 +322,46 @@ class MatrixNativeReader( entriesPath, params.options, spec.rowsSpec, - spec.entriesSpec) + spec.entriesSpec, + ) val tr: TableIR = TableRead(tt, dropRows, trdr) val colsRVDSpec = spec.colsSpec.rowsSpec - val partFiles = colsRVDSpec.absolutePartPaths(spec.colsSpec.rowsComponent.absolutePath(colsPath)) + val partFiles = + colsRVDSpec.absolutePartPaths(spec.colsSpec.rowsComponent.absolutePath(colsPath)) val cols = if (partFiles.length == 1) { ReadPartition( MakeStruct(Array("partitionIndex" -> I64(0), "partitionPath" -> Str(partFiles.head))), requestedType.colType, - PartitionNativeReader(colsRVDSpec.typedCodecSpec, colUIDFieldName)) + PartitionNativeReader(colsRVDSpec.typedCodecSpec, colUIDFieldName), + ) } else { val contextType = TStruct("partitionIndex" -> TInt64, "partitionPath" -> TString) val partNames = MakeArray( partFiles.zipWithIndex.map { case (path, idx) => MakeStruct(Array("partitionIndex" -> I64(idx), "partitionPath" -> Str(path))) - }, TArray(contextType)) + }, + TArray(contextType), + ) val elt = Ref(genUID(), contextType) StreamFlatMap( partNames, elt.name, - ReadPartition(elt, requestedType.colType, PartitionNativeReader(colsRVDSpec.typedCodecSpec, colUIDFieldName))) + ReadPartition( + elt, + requestedType.colType, + PartitionNativeReader(colsRVDSpec.typedCodecSpec, colUIDFieldName), + ), + ) } - TableMapGlobals(tr, InsertFields( - Ref("global", tr.typ.globalType), - FastSeq(LowerMatrixIR.colsFieldName -> ToArray(cols)) - )) + TableMapGlobals( + tr, + InsertFields( + Ref("global", tr.typ.globalType), + FastSeq(LowerMatrixIR.colsFieldName -> ToArray(cols)), + ), + ) } } @@ -326,7 +392,8 @@ object MatrixRangeReader { } def apply(params: MatrixRangeReaderParameters): MatrixRangeReader = { - val nPartitionsAdj = math.min(params.nRows, params.nPartitions.getOrElse(HailContext.backend.defaultParallelism)) + val nPartitionsAdj = + math.min(params.nRows, params.nPartitions.getOrElse(HailContext.backend.defaultParallelism)) new MatrixRangeReader(params, nPartitionsAdj) } } @@ -335,7 +402,7 @@ case class MatrixRangeReaderParameters(nRows: Int, nCols: Int, nPartitions: Opti case class MatrixRangeReader( val params: MatrixRangeReaderParameters, - nPartitionsAdj: Int + nPartitionsAdj: Int, ) extends MatrixReader { def pathsUsed: Seq[String] = FastSeq() @@ -348,33 +415,41 @@ case class MatrixRangeReader( colType = TStruct("col_idx" -> TInt32), rowKey = Array("row_idx"), rowType = TStruct("row_idx" -> TInt32), - entryType = TStruct.empty) + entryType = TStruct.empty, + ) override def renderShort(): String = s"(MatrixRangeReader $params $nPartitionsAdj)" val columnCount: Option[Int] = Some(params.nCols) - lazy val partitionCounts: Option[IndexedSeq[Long]] = Some(partition(params.nRows, nPartitionsAdj).map(_.toLong)) + lazy val partitionCounts: Option[IndexedSeq[Long]] = + Some(partition(params.nRows, nPartitionsAdj).map(_.toLong)) override def lower(requestedType: MatrixType, dropCols: Boolean, dropRows: Boolean): TableIR = { val nRowsAdj = if (dropRows) 0 else params.nRows val nColsAdj = if (dropCols) 0 else params.nCols - var ht = TableRange(nRowsAdj, params.nPartitions.getOrElse(HailContext.backend.defaultParallelism)) - 
.rename(Map("idx" -> "row_idx")) + var ht = + TableRange(nRowsAdj, params.nPartitions.getOrElse(HailContext.backend.defaultParallelism)) + .rename(Map("idx" -> "row_idx")) if (requestedType.colType.hasField(colUIDFieldName)) ht = ht.mapGlobals(makeStruct(LowerMatrixIR.colsField -> - irRange(0, nColsAdj).map('i ~> makeStruct('col_idx -> 'i, Symbol(colUIDFieldName) -> 'i.toL)))) + irRange(0, nColsAdj).map('i ~> makeStruct( + 'col_idx -> 'i, + Symbol(colUIDFieldName) -> 'i.toL, + )))) else ht = ht.mapGlobals(makeStruct(LowerMatrixIR.colsField -> irRange(0, nColsAdj).map('i ~> makeStruct('col_idx -> 'i)))) if (requestedType.rowType.hasField(rowUIDFieldName)) ht = ht.mapRows('row.insertFields( LowerMatrixIR.entriesField -> irRange(0, nColsAdj).map('i ~> makeStruct()), - Symbol(rowUIDFieldName) -> 'row('row_idx).toL)) + Symbol(rowUIDFieldName) -> 'row('row_idx).toL, + )) else ht = ht.mapRows('row.insertFields( LowerMatrixIR.entriesField -> - irRange(0, nColsAdj).map('i ~> makeStruct()))) + irRange(0, nColsAdj).map('i ~> makeStruct()) + )) ht } @@ -397,7 +472,7 @@ object MatrixRead { typ: MatrixType, dropCols: Boolean, dropRows: Boolean, - reader: MatrixReader + reader: MatrixReader, ): MatrixRead = { assert(!reader.fullMatrixTypeWithoutUIDs.rowType.hasField(MatrixReader.rowUIDFieldName) && !reader.fullMatrixTypeWithoutUIDs.colType.hasField(MatrixReader.colUIDFieldName)) @@ -408,19 +483,21 @@ object MatrixRead { typ: MatrixType, dropCols: Boolean, dropRows: Boolean, - reader: MatrixReader - ): MatrixRead = { + reader: MatrixReader, + ): MatrixRead = new MatrixRead(typ, dropCols, dropRows, reader) - } } case class MatrixRead( typ: MatrixType, dropCols: Boolean, dropRows: Boolean, - reader: MatrixReader) extends MatrixIR { - assert(PruneDeadFields.isSupertype(typ, reader.fullMatrixType), - s"\n original: ${ reader.fullMatrixType }\n requested: $typ") + reader: MatrixReader, +) extends MatrixIR { + assert( + PruneDeadFields.isSupertype(typ, reader.fullMatrixType), + s"\n original: ${reader.fullMatrixType}\n requested: $typ", + ) lazy val childrenSeq: IndexedSeq[BaseIR] = Array.empty[BaseIR] @@ -435,21 +512,19 @@ case class MatrixRead( s"dropCols = $dropCols, " + s"dropRows = $dropRows)" - override def partitionCounts: Option[IndexedSeq[Long]] = { + override def partitionCounts: Option[IndexedSeq[Long]] = if (dropRows) Some(Array.empty[Long]) else reader.partitionCounts - } lazy val rowCountUpperBound: Option[Long] = partitionCounts.map(_.sum) - override def columnCount: Option[Int] = { + override def columnCount: Option[Int] = if (dropCols) Some(0) else reader.columnCount - } final def lower(): TableIR = reader.lower(typ, dropCols, dropRows) } @@ -512,7 +587,8 @@ case class MatrixCollectColsByKey(child: MatrixIR) extends MatrixIR { } lazy val typ: MatrixType = { - val newColValueType = TStruct(child.typ.colValueStruct.fields.map(f => f.copy(typ = TArray(f.typ)))) + val newColValueType = + TStruct(child.typ.colValueStruct.fields.map(f => f.copy(typ = TArray(f.typ)))) val newColType = child.typ.colKeyStruct ++ newColValueType val newEntryType = TStruct(child.typ.entryType.fields.map(f => f.copy(typ = TArray(f.typ)))) @@ -525,9 +601,8 @@ case class MatrixCollectColsByKey(child: MatrixIR) extends MatrixIR { } case class MatrixAggregateRowsByKey(child: MatrixIR, entryExpr: IR, rowExpr: IR) extends MatrixIR { - override def typecheck(): Unit = { + override def typecheck(): Unit = assert(child.typ.rowKey.nonEmpty) - } lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child, entryExpr, rowExpr) @@ -538,7 
+613,7 @@ case class MatrixAggregateRowsByKey(child: MatrixIR, entryExpr: IR, rowExpr: IR) lazy val typ: MatrixType = child.typ.copy( rowType = child.typ.rowKeyStruct ++ tcoerce[TStruct](rowExpr.typ), - entryType = tcoerce[TStruct](entryExpr.typ) + entryType = tcoerce[TStruct](entryExpr.typ), ) override def columnCount: Option[Int] = child.columnCount @@ -547,9 +622,8 @@ case class MatrixAggregateRowsByKey(child: MatrixIR, entryExpr: IR, rowExpr: IR) } case class MatrixAggregateColsByKey(child: MatrixIR, entryExpr: IR, colExpr: IR) extends MatrixIR { - override def typecheck(): Unit = { + override def typecheck(): Unit = assert(child.typ.colKey.nonEmpty) - } lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child, entryExpr, colExpr) @@ -560,7 +634,8 @@ case class MatrixAggregateColsByKey(child: MatrixIR, entryExpr: IR, colExpr: IR) lazy val typ = child.typ.copy( entryType = tcoerce[TStruct](entryExpr.typ), - colType = child.typ.colKeyStruct ++ tcoerce[TStruct](colExpr.typ)) + colType = child.typ.colKeyStruct ++ tcoerce[TStruct](colExpr.typ), + ) override def partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts @@ -571,26 +646,40 @@ case class MatrixUnionCols(left: MatrixIR, right: MatrixIR, joinType: String) ex require(joinType == "inner" || joinType == "outer") override def typecheck(): Unit = { - assert(left.typ.rowKeyStruct == right.typ.rowKeyStruct, s"${left.typ.rowKeyStruct} != ${right.typ.rowKeyStruct}") + assert( + left.typ.rowKeyStruct == right.typ.rowKeyStruct, + s"${left.typ.rowKeyStruct} != ${right.typ.rowKeyStruct}", + ) assert(left.typ.colType == right.typ.colType, s"${left.typ.colType} != ${right.typ.colType}") - assert(left.typ.entryType == right.typ.entryType, s"${left.typ.entryType} != ${right.typ.entryType}") + assert( + left.typ.entryType == right.typ.entryType, + s"${left.typ.entryType} != ${right.typ.entryType}", + ) } lazy val childrenSeq: IndexedSeq[BaseIR] = Array(left, right) def copy(newChildren: IndexedSeq[BaseIR]): MatrixUnionCols = { assert(newChildren.length == 2) - MatrixUnionCols(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[MatrixIR], joinType) + MatrixUnionCols( + newChildren(0).asInstanceOf[MatrixIR], + newChildren(1).asInstanceOf[MatrixIR], + joinType, + ) } private def newRowType = { val leftKeyType = left.typ.rowKeyStruct val leftValueType = left.typ.rowValueStruct val rightValueType = right.typ.rowValueStruct - if (leftValueType.fieldNames.toSet - .intersect(rightValueType.fieldNames.toSet) - .nonEmpty) - throw new RuntimeException(s"invalid MatrixUnionCols: \n left value: $leftValueType\n right value: $rightValueType") + if ( + leftValueType.fieldNames.toSet + .intersect(rightValueType.fieldNames.toSet) + .nonEmpty + ) + throw new RuntimeException( + s"invalid MatrixUnionCols: \n left value: $leftValueType\n right value: $rightValueType" + ) leftKeyType ++ leftValueType ++ rightValueType } @@ -601,17 +690,21 @@ case class MatrixUnionCols(left: MatrixIR, right: MatrixIR, joinType: String) ex left.typ.copy( rowType = newRowType, colType = TStruct(left.typ.colType.fields.map(f => f.copy(typ = f.typ))), - entryType = TStruct(left.typ.entryType.fields.map(f => f.copy(typ = f.typ)))) + entryType = TStruct(left.typ.entryType.fields.map(f => f.copy(typ = f.typ))), + ) override def columnCount: Option[Int] = - left.columnCount.flatMap(leftCount => right.columnCount.map(rightCount => leftCount + rightCount)) - - lazy val rowCountUpperBound: Option[Long] = (left.rowCountUpperBound, right.rowCountUpperBound) match { - case 
(Some(l), Some(r)) => if (joinType == "inner") Some(l.min(r)) else Some(l + r) - case (Some(l), None) => if (joinType == "inner") Some(l) else None - case (None, Some(r)) => if (joinType == "inner") Some(r) else None - case (None, None) => None - } + left.columnCount.flatMap(leftCount => + right.columnCount.map(rightCount => leftCount + rightCount) + ) + + lazy val rowCountUpperBound: Option[Long] = + (left.rowCountUpperBound, right.rowCountUpperBound) match { + case (Some(l), Some(r)) => if (joinType == "inner") Some(l.min(r)) else Some(l + r) + case (Some(l), None) => if (joinType == "inner") Some(l) else None + case (None, Some(r)) => if (joinType == "inner") Some(r) else None + case (None, None) => None + } } case class MatrixMapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR { @@ -632,7 +725,8 @@ case class MatrixMapEntries(child: MatrixIR, newEntries: IR) extends MatrixIR { lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound } -case class MatrixKeyRowsBy(child: MatrixIR, keys: IndexedSeq[String], isSorted: Boolean = false) extends MatrixIR { +case class MatrixKeyRowsBy(child: MatrixIR, keys: IndexedSeq[String], isSorted: Boolean = false) + extends MatrixIR { override def typecheck(): Unit = { val fields = child.typ.rowType.fieldNames.toSet assert(keys.forall(fields.contains), s"${keys.filter(k => !fields.contains(k)).mkString(", ")}") @@ -661,9 +755,8 @@ case class MatrixMapRows(child: MatrixIR, newRow: IR) extends MatrixIR { MatrixMapRows(newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[IR]) } - lazy val typ: MatrixType = { + lazy val typ: MatrixType = child.typ.copy(rowType = newRow.typ.asInstanceOf[TStruct]) - } override def partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts @@ -672,7 +765,8 @@ case class MatrixMapRows(child: MatrixIR, newRow: IR) extends MatrixIR { lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound } -case class MatrixMapCols(child: MatrixIR, newCol: IR, newKey: Option[IndexedSeq[String]]) extends MatrixIR { +case class MatrixMapCols(child: MatrixIR, newCol: IR, newKey: Option[IndexedSeq[String]]) + extends MatrixIR { lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child, newCol) def copy(newChildren: IndexedSeq[BaseIR]): MatrixMapCols = { @@ -731,11 +825,10 @@ case class MatrixFilterEntries(child: MatrixIR, pred: IR) extends MatrixIR { case class MatrixAnnotateColsTable( child: MatrixIR, table: TableIR, - root: String + root: String, ) extends MatrixIR { - override def typecheck(): Unit = { + override def typecheck(): Unit = assert(child.typ.colType.selfField(root).isEmpty) - } lazy val childrenSeq: IndexedSeq[BaseIR] = FastSeq(child, table) @@ -744,14 +837,15 @@ case class MatrixAnnotateColsTable( override def partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts lazy val typ: MatrixType = child.typ.copy( - colType = child.typ.colType.structInsert(table.typ.valueType, FastSeq(root))) + colType = child.typ.colType.structInsert(table.typ.valueType, FastSeq(root)) + ) - def copy(newChildren: IndexedSeq[BaseIR]): MatrixAnnotateColsTable = { + def copy(newChildren: IndexedSeq[BaseIR]): MatrixAnnotateColsTable = MatrixAnnotateColsTable( newChildren(0).asInstanceOf[MatrixIR], newChildren(1).asInstanceOf[TableIR], - root) - } + root, + ) lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound } @@ -760,14 +854,16 @@ case class MatrixAnnotateRowsTable( child: MatrixIR, table: TableIR, root: String, - product: Boolean + product: Boolean, ) extends MatrixIR { - override 
def typecheck(): Unit = { + override def typecheck(): Unit = assert( (!product && table.typ.keyType.isPrefixOf(child.typ.rowKeyStruct)) || - (table.typ.keyType.size == 1 && table.typ.keyType.types(0) == TInterval(child.typ.rowKeyStruct.types(0))), - s"\n L: ${child.typ}\n R: ${table.typ}") - } + (table.typ.keyType.size == 1 && table.typ.keyType.types(0) == TInterval( + child.typ.rowKeyStruct.types(0) + )), + s"\n L: ${child.typ}\n R: ${table.typ}", + ) lazy val childrenSeq: IndexedSeq[BaseIR] = FastSeq(child, table) @@ -835,25 +931,26 @@ case class MatrixRepartition(child: MatrixIR, n: Int, strategy: Int) extends Mat case class MatrixUnionRows(childrenSeq: IndexedSeq[MatrixIR]) extends MatrixIR { require(childrenSeq.length > 1) - override def typecheck(): Unit = { - assert(childrenSeq.tail.forall(c => compatible(c.typ, childrenSeq.head.typ)), childrenSeq.map(_.typ)) - } + override def typecheck(): Unit = + assert( + childrenSeq.tail.forall(c => compatible(c.typ, childrenSeq.head.typ)), + childrenSeq.map(_.typ), + ) def typ: MatrixType = childrenSeq.head.typ - def compatible(t1: MatrixType, t2: MatrixType): Boolean = { + def compatible(t1: MatrixType, t2: MatrixType): Boolean = t1.colKeyStruct == t2.colKeyStruct && t1.rowType == t2.rowType && t1.rowKey == t2.rowKey && t1.entryType == t2.entryType - } def copy(newChildren: IndexedSeq[BaseIR]): MatrixUnionRows = MatrixUnionRows(newChildren.asInstanceOf[IndexedSeq[MatrixIR]]) override def columnCount: Option[Int] = childrenSeq.map(_.columnCount).reduce { (c1, c2) => - require(c1.forall { i1 => c2.forall(i1 == _) }) + require(c1.forall(i1 => c2.forall(i1 == _))) c1.orElse(c2) } @@ -992,24 +1089,23 @@ case class MatrixExplodeCols(child: MatrixIR, path: IndexedSeq[String]) extends } } -/** Create a MatrixTable from a Table, where the column values are stored in a - * global field 'colsFieldName', and the entry values are stored in a row - * field 'entriesFieldName'. +/** Create a MatrixTable from a Table, where the column values are stored in a global field + * 'colsFieldName', and the entry values are stored in a row field 'entriesFieldName'. 
*/ case class CastTableToMatrix( child: TableIR, entriesFieldName: String, colsFieldName: String, - colKey: IndexedSeq[String] + colKey: IndexedSeq[String], ) extends MatrixIR { - override def typecheck(): Unit = { + override def typecheck(): Unit = child.typ.rowType.fieldType(entriesFieldName) match { case TArray(TStruct(_)) => case t => fatal(s"expected entry field to be an array of structs, found $t") } - } - lazy val typ: MatrixType = MatrixType.fromTableType(child.typ, colsFieldName, entriesFieldName, colKey) + lazy val typ: MatrixType = + MatrixType.fromTableType(child.typ, colsFieldName, entriesFieldName, colKey) lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child) @@ -1019,7 +1115,8 @@ case class CastTableToMatrix( newChildren(0).asInstanceOf[TableIR], entriesFieldName, colsFieldName, - colKey) + colKey, + ) } override def partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts @@ -1040,7 +1137,8 @@ case class MatrixToMatrixApply(child: MatrixIR, function: MatrixToMatrixFunction override def partitionCounts: Option[IndexedSeq[Long]] = if (function.preservesPartitionCounts) child.partitionCounts else None - lazy val rowCountUpperBound: Option[Long] = if (function.preservesPartitionCounts) child.rowCountUpperBound else None + lazy val rowCountUpperBound: Option[Long] = + if (function.preservesPartitionCounts) child.rowCountUpperBound else None } case class MatrixRename( @@ -1048,7 +1146,7 @@ case class MatrixRename( globalMap: Map[String, String], colMap: Map[String, String], rowMap: Map[String, String], - entryMap: Map[String, String] + entryMap: Map[String, String], ) extends MatrixIR { override def typecheck(): Unit = { assert(globalMap.keys.forall(child.typ.globalType.hasField)) @@ -1063,7 +1161,8 @@ case class MatrixRename( colType = child.typ.colType.rename(colMap), rowKey = child.typ.rowKey.map(k => rowMap.getOrElse(k, k)), rowType = child.typ.rowType.rename(rowMap), - entryType = child.typ.entryType.rename(entryMap)) + entryType = child.typ.entryType.rename(entryMap), + ) lazy val childrenSeq: IndexedSeq[BaseIR] = FastSeq(child) @@ -1079,7 +1178,8 @@ case class MatrixRename( } } -case class MatrixFilterIntervals(child: MatrixIR, intervals: IndexedSeq[Interval], keep: Boolean) extends MatrixIR { +case class MatrixFilterIntervals(child: MatrixIR, intervals: IndexedSeq[Interval], keep: Boolean) + extends MatrixIR { lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child) def copy(newChildren: IndexedSeq[BaseIR]): MatrixIR = { diff --git a/hail/src/main/scala/is/hail/expr/ir/MatrixValue.scala b/hail/src/main/scala/is/hail/expr/ir/MatrixValue.scala index 994206776d0..5a9e52aa930 100644 --- a/hail/src/main/scala/is/hail/expr/ir/MatrixValue.scala +++ b/hail/src/main/scala/is/hail/expr/ir/MatrixValue.scala @@ -6,17 +6,19 @@ import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.io.{BufferSpec, FileWriteMetadata} import is.hail.linalg.RowMatrix import is.hail.rvd.{AbstractRVDSpec, RVD} +import is.hail.types.{MatrixType, TableType} import is.hail.types.physical.{PArray, PCanonicalStruct, PStruct, PType} import is.hail.types.virtual._ -import is.hail.types.{MatrixType, TableType} import is.hail.utils._ import is.hail.variant._ + import org.apache.spark.SparkContext import org.apache.spark.sql.Row case class MatrixValue( typ: MatrixType, - tv: TableValue) { + tv: TableValue, +) { val colFieldType = tv.globals.t.fieldType(LowerMatrixIR.colsFieldName).asInstanceOf[PArray] assert(colFieldType.required) assert(colFieldType.elementType.required) @@ -27,8 +29,11 
@@ case class MatrixValue( val rvb = new RegionValueBuilder(HailStateManager(Map.empty), prevGlobals.value.region) rvb.start(newT) rvb.startStruct() - rvb.addFields(prevGlobals.t, prevGlobals.value, - prevGlobals.t.fields.filter(_.name != LowerMatrixIR.colsFieldName).map(_.index).toArray) + rvb.addFields( + prevGlobals.t, + prevGlobals.value, + prevGlobals.t.fields.filter(_.name != LowerMatrixIR.colsFieldName).map(_.index).toArray, + ) rvb.endStruct() BroadcastRow(tv.ctx, RegionValue(prevGlobals.value.region, rvb.end()), newT) } @@ -37,9 +42,14 @@ case class MatrixValue( val prevGlobals = tv.globals val field = prevGlobals.t.field(LowerMatrixIR.colsFieldName) val t = field.typ.asInstanceOf[PArray] - BroadcastIndexedSeq(tv.ctx, - RegionValue(prevGlobals.value.region, prevGlobals.t.loadField(prevGlobals.value.offset, field.index)), - t) + BroadcastIndexedSeq( + tv.ctx, + RegionValue( + prevGlobals.value.region, + prevGlobals.t.loadField(prevGlobals.value.offset, field.index), + ), + t, + ) } val rvd: RVD = tv.rvd @@ -52,9 +62,13 @@ case class MatrixValue( lazy val entryType: TStruct = entryArrayType.elementType.asInstanceOf[TStruct] lazy val entriesRVType: TStruct = TStruct( - MatrixType.entriesIdentifier -> TArray(entryType)) + MatrixType.entriesIdentifier -> TArray(entryType) + ) - require(rvd.typ.key.startsWith(typ.rowKey), s"\nmat row key: ${ typ.rowKey }\nrvd key: ${ rvd.typ.key }") + require( + rvd.typ.key.startsWith(typ.rowKey), + s"\nmat row key: ${typ.rowKey}\nrvd key: ${rvd.typ.key}", + ) def sparkContext: SparkContext = rvd.sparkContext @@ -69,16 +83,25 @@ case class MatrixValue( colValues.javaValue.map(querier(_).asInstanceOf[String]) } - def requireUniqueSamples(method: String) { + def requireUniqueSamples(method: String): Unit = { val dups = stringSampleIds.counter().filter(_._2 > 1).toArray if (dups.nonEmpty) - fatal(s"Method '$method' does not support duplicate column keys. Duplicates:" + - s"\n @1", dups.sortBy(-_._2).map { case (id, count) => s"""($count) "$id"""" }.truncatable("\n ")) + fatal( + s"Method '$method' does not support duplicate column keys. 
Duplicates:" + + s"\n @1", + dups.sortBy(-_._2).map { case (id, count) => s"""($count) "$id"""" }.truncatable("\n "), + ) } private def writeCols(ctx: ExecuteContext, path: String, bufferSpec: BufferSpec): Long = { val fs = ctx.fs - val fileData = AbstractRVDSpec.writeSingle(ctx, path + "/rows", colValues.t.elementType.asInstanceOf[PStruct], bufferSpec, colValues.javaValue) + val fileData = AbstractRVDSpec.writeSingle( + ctx, + path + "/rows", + colValues.t.elementType.asInstanceOf[PStruct], + bufferSpec, + colValues.javaValue, + ) val partitionCounts = fileData.map(_.rowsWritten) val colsSpec = TableSpecParameters( @@ -86,9 +109,12 @@ case class MatrixValue( is.hail.HAIL_PRETTY_VERSION, "../references", typ.colsTableType.copy(key = FastSeq[String]()), - Map("globals" -> RVDComponentSpec("../globals/rows"), + Map( + "globals" -> RVDComponentSpec("../globals/rows"), "rows" -> RVDComponentSpec("rows"), - "partition_counts" -> PartitionCountsComponentSpec(partitionCounts))) + "partition_counts" -> PartitionCountsComponentSpec(partitionCounts), + ), + ) colsSpec.write(fs, path) using(fs.create(path + "/_SUCCESS"))(out => ()) @@ -98,19 +124,34 @@ case class MatrixValue( private def writeGlobals(ctx: ExecuteContext, path: String, bufferSpec: BufferSpec): Long = { val fs = ctx.fs - val fileData = AbstractRVDSpec.writeSingle(ctx, path + "/rows", globals.t, bufferSpec, Array(globals.javaValue)) + val fileData = AbstractRVDSpec.writeSingle( + ctx, + path + "/rows", + globals.t, + bufferSpec, + Array(globals.javaValue), + ) val partitionCounts = fileData.map(_.rowsWritten) - AbstractRVDSpec.writeSingle(ctx, path + "/globals", PCanonicalStruct.empty(required = true), bufferSpec, Array[Annotation](Row())) + AbstractRVDSpec.writeSingle( + ctx, + path + "/globals", + PCanonicalStruct.empty(required = true), + bufferSpec, + Array[Annotation](Row()), + ) val globalsSpec = TableSpecParameters( FileFormat.version.rep, is.hail.HAIL_PRETTY_VERSION, "../references", TableType(typ.globalType, FastSeq(), TStruct.empty), - Map("globals" -> RVDComponentSpec("globals"), + Map( + "globals" -> RVDComponentSpec("globals"), "rows" -> RVDComponentSpec("rows"), - "partition_counts" -> PartitionCountsComponentSpec(partitionCounts))) + "partition_counts" -> PartitionCountsComponentSpec(partitionCounts), + ), + ) globalsSpec.write(fs, path) using(fs.create(path + "/_SUCCESS"))(out => ()) @@ -122,7 +163,7 @@ case class MatrixValue( path: String, bufferSpec: BufferSpec, fileData: Array[FileWriteMetadata], - consoleInfo: Boolean + consoleInfo: Boolean, ): Unit = { val fs = ctx.fs val globalsPath = path + "/globals" @@ -136,9 +177,12 @@ case class MatrixValue( is.hail.HAIL_PRETTY_VERSION, "../references", typ.rowsTableType, - Map("globals" -> RVDComponentSpec("../globals/rows"), + Map( + "globals" -> RVDComponentSpec("../globals/rows"), "rows" -> RVDComponentSpec("rows"), - "partition_counts" -> PartitionCountsComponentSpec(partitionCounts))) + "partition_counts" -> PartitionCountsComponentSpec(partitionCounts), + ), + ) rowsSpec.write(fs, path + "/rows") using(fs.create(path + "/rows/_SUCCESS"))(out => ()) @@ -148,9 +192,12 @@ case class MatrixValue( is.hail.HAIL_PRETTY_VERSION, "../references", TableType(entriesRVType, FastSeq(), typ.globalType), - Map("globals" -> RVDComponentSpec("../globals/rows"), + Map( + "globals" -> RVDComponentSpec("../globals/rows"), "rows" -> RVDComponentSpec("rows"), - "partition_counts" -> PartitionCountsComponentSpec(partitionCounts))) + "partition_counts" -> 
PartitionCountsComponentSpec(partitionCounts), + ), + ) entriesSpec.write(fs, path + "/entries") using(fs.create(path + "/entries/_SUCCESS"))(out => ()) @@ -161,7 +208,11 @@ case class MatrixValue( val refPath = path + "/references" fs.mkDir(refPath) Array(typ.colType, typ.rowType, entryType, typ.globalType).foreach { t => - ReferenceGenome.exportReferences(fs, refPath, ReferenceGenome.getReferences(t).map(ctx.getReference(_))) + ReferenceGenome.exportReferences( + fs, + refPath, + ReferenceGenome.getReferences(t).map(ctx.getReference(_)), + ) } val spec = MatrixTableSpecParameters( @@ -169,11 +220,14 @@ case class MatrixValue( is.hail.HAIL_PRETTY_VERSION, "references", typ, - Map("globals" -> RVDComponentSpec("globals/rows"), + Map( + "globals" -> RVDComponentSpec("globals/rows"), "cols" -> RVDComponentSpec("cols/rows"), "rows" -> RVDComponentSpec("rows/rows"), "entries" -> RVDComponentSpec("entries/rows"), - "partition_counts" -> PartitionCountsComponentSpec(partitionCounts))) + "partition_counts" -> PartitionCountsComponentSpec(partitionCounts), + ), + ) spec.write(fs, path) writeNativeFileReadMe(fs, path) @@ -186,22 +240,25 @@ case class MatrixValue( val partitionBytesWritten = fileData.map(_.bytesWritten) val totalRowsEntriesBytes = partitionBytesWritten.sum val totalBytesWritten: Long = totalRowsEntriesBytes + colBytesWritten + globalBytesWritten - val (smallestStr, largestStr) = if (fileData.isEmpty) ("N/A", "N/A") else { + val (smallestStr, largestStr) = if (fileData.isEmpty) ("N/A", "N/A") + else { val smallestPartition = fileData.minBy(_.bytesWritten) val largestPartition = fileData.maxBy(_.bytesWritten) - val smallestStr = s"${ smallestPartition.rowsWritten } rows (${ formatSpace(smallestPartition.bytesWritten) })" - val largestStr = s"${ largestPartition.rowsWritten } rows (${ formatSpace(largestPartition.bytesWritten) })" + val smallestStr = + s"${smallestPartition.rowsWritten} rows (${formatSpace(smallestPartition.bytesWritten)})" + val largestStr = + s"${largestPartition.rowsWritten} rows (${formatSpace(largestPartition.bytesWritten)})" (smallestStr, largestStr) } - printer(s"wrote matrix table with $nRows ${ plural(nRows, "row") } " + - s"and $nCols ${ plural(nCols, "column") } " + - s"in ${ partitionCounts.length } ${ plural(partitionCounts.length, "partition") } " + + printer(s"wrote matrix table with $nRows ${plural(nRows, "row")} " + + s"and $nCols ${plural(nCols, "column")} " + + s"in ${partitionCounts.length} ${plural(partitionCounts.length, "partition")} " + s"to $path" + - s"\n Total size: ${ formatSpace(totalBytesWritten) }" + - s"\n * Rows/entries: ${ formatSpace(totalRowsEntriesBytes) }" + - s"\n * Columns: ${ formatSpace(colBytesWritten) }" + - s"\n * Globals: ${ formatSpace(globalBytesWritten) }" + + s"\n Total size: ${formatSpace(totalBytesWritten)}" + + s"\n * Rows/entries: ${formatSpace(totalRowsEntriesBytes)}" + + s"\n * Columns: ${formatSpace(colBytesWritten)}" + + s"\n * Globals: ${formatSpace(globalBytesWritten)}" + s"\n * Smallest partition: $smallestStr" + s"\n * Largest partition: $largestStr") } @@ -267,11 +324,14 @@ object MatrixValue { paths: IndexedSeq[String], overwrite: Boolean, stageLocally: Boolean, - bufferSpec: BufferSpec + bufferSpec: BufferSpec, ): Unit = { val first = mvs.head require(mvs.forall(_.typ == first.typ)) - require(mvs.length == paths.length, s"found ${ mvs.length } matrix tables but ${ paths.length } paths") + require( + mvs.length == paths.length, + s"found ${mvs.length} matrix tables but ${paths.length} paths", + ) val fs = 
ctx.fs paths.foreach { path => @@ -293,24 +353,29 @@ object MatrixValue { typ: MatrixType, globals: Row, colValues: IndexedSeq[Row], - rvd: RVD): MatrixValue = { + rvd: RVD, + ): MatrixValue = { val globalsType = typ.globalType.appendKey(LowerMatrixIR.colsFieldName, TArray(typ.colType)) val globalsPType = PType.canonical(globalsType).asInstanceOf[PStruct] val rvb = new RegionValueBuilder(ctx.stateManager, ctx.r) rvb.start(globalsPType) rvb.startStruct() - typ.globalType.fields.foreach { f => - rvb.addAnnotation(f.typ, globals.get(f.index)) - } + typ.globalType.fields.foreach(f => rvb.addAnnotation(f.typ, globals.get(f.index))) rvb.addAnnotation(TArray(typ.colType), colValues) - MatrixValue(typ, - TableValue(ctx, TableType( - rowType = rvd.rowType, - key = typ.rowKey, - globalType = globalsType), + MatrixValue( + typ, + TableValue( + ctx, + TableType( + rowType = rvd.rowType, + key = typ.rowKey, + globalType = globalsType, + ), BroadcastRow(ctx, RegionValue(ctx.r, rvb.end()), globalsPType), - rvd)) + rvd, + ), + ) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala b/hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala index aa4d2fdd945..fc197439f42 100644 --- a/hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala +++ b/hail/src/main/scala/is/hail/expr/ir/MatrixWriter.scala @@ -3,9 +3,9 @@ package is.hail.expr.ir import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext +import is.hail.expr.{JSONAnnotationImpex, Nat} import is.hail.expr.ir.lowering.TableStage import is.hail.expr.ir.streams.StreamProducer -import is.hail.expr.{JSONAnnotationImpex, Nat} import is.hail.io._ import is.hail.io.bgen.BgenSettings import is.hail.io.fs.FS @@ -18,36 +18,50 @@ import is.hail.rvd.{IndexSpec, RVDPartitioner, RVDSpecMaker} import is.hail.types._ import is.hail.types.encoded.{EBaseStruct, EBlockMatrixNDArray, EType} import is.hail.types.physical._ -import is.hail.types.physical.stypes.concrete.{SJavaArrayString, SJavaArrayStringValue, SJavaString, SStackStruct} +import is.hail.types.physical.stypes.{EmitType, SValue} +import is.hail.types.physical.stypes.concrete.{ + SJavaArrayString, SJavaArrayStringValue, SJavaString, SStackStruct, +} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives._ -import is.hail.types.physical.stypes.{EmitType, SValue} import is.hail.types.virtual._ import is.hail.utils._ import is.hail.utils.richUtils.ByteTrackingOutputStream import is.hail.variant.{Call, ReferenceGenome} -import org.apache.spark.sql.Row -import org.json4s.jackson.JsonMethods -import org.json4s.{DefaultFormats, Formats, ShortTypeHints} + +import scala.language.existentials import java.io.{InputStream, OutputStream} import java.nio.file.{FileSystems, Path} import java.util.UUID -import scala.language.existentials + +import org.apache.spark.sql.Row +import org.json4s.{DefaultFormats, Formats, ShortTypeHints} +import org.json4s.jackson.JsonMethods object MatrixWriter { implicit val formats: Formats = new DefaultFormats() { override val typeHints = ShortTypeHints( - List(classOf[MatrixNativeWriter], classOf[MatrixVCFWriter], classOf[MatrixGENWriter], - classOf[MatrixBGENWriter], classOf[MatrixPLINKWriter], classOf[WrappedMatrixWriter], - classOf[MatrixBlockMatrixWriter]), typeHintFieldName = "name") + List( + classOf[MatrixNativeWriter], + classOf[MatrixVCFWriter], + classOf[MatrixGENWriter], + classOf[MatrixBGENWriter], + classOf[MatrixPLINKWriter], + classOf[WrappedMatrixWriter], + 
classOf[MatrixBlockMatrixWriter], + ), + typeHintFieldName = "name", + ) } } -case class WrappedMatrixWriter(writer: MatrixWriter, +case class WrappedMatrixWriter( + writer: MatrixWriter, colsFieldName: String, entriesFieldName: String, - colKey: IndexedSeq[String]) extends TableWriter { + colKey: IndexedSeq[String], +) extends TableWriter { def path: String = writer.path override def lower(ctx: ExecuteContext, ts: TableStage, r: RTable): IR = @@ -60,13 +74,27 @@ abstract class MatrixWriter { def apply(ctx: ExecuteContext, mv: MatrixValue): Unit = { val tv = mv.toTableValue val ts = TableExecuteIntermediate(tv).asTableStage(ctx) - CompileAndEvaluate(ctx, lower(LowerMatrixIR.colsFieldName, MatrixType.entriesIdentifier, - mv.typ.colKey, ctx, ts, BaseTypeWithRequiredness(tv.typ).asInstanceOf[RTable] - )) + CompileAndEvaluate( + ctx, + lower( + LowerMatrixIR.colsFieldName, + MatrixType.entriesIdentifier, + mv.typ.colKey, + ctx, + ts, + BaseTypeWithRequiredness(tv.typ).asInstanceOf[RTable], + ), + ) } - def lower(colsFieldName: String, entriesFieldName: String, colKey: IndexedSeq[String], - ctx: ExecuteContext, ts: TableStage, r: RTable): IR + def lower( + colsFieldName: String, + entriesFieldName: String, + colKey: IndexedSeq[String], + ctx: ExecuteContext, + ts: TableStage, + r: RTable, + ): IR } sealed trait MatrixWriterComponents { @@ -90,7 +118,7 @@ object MatrixNativeWriter { stageLocally: Boolean = false, codecSpecJSONStr: String = null, partitions: String = null, - partitionsTypeStr: String = null + partitionsTypeStr: String = null, ): MatrixWriterComponents = { val bufferSpec: BufferSpec = BufferSpec.parseOrDefault(codecSpecJSONStr) val tm = MatrixType.fromTableType(tablestage.tableType, colsFieldName, entriesFieldName, colKey) @@ -102,22 +130,39 @@ object MatrixNativeWriter { val jv = JsonMethods.parse(partitions) val rangeBounds = JSONAnnotationImpex.importAnnotation(jv, partitionsType) .asInstanceOf[IndexedSeq[Interval]] - tablestage.repartitionNoShuffle(ctx, new RVDPartitioner(ctx.stateManager, tm.rowKey.toArray, tm.rowKeyStruct, rangeBounds)) + tablestage.repartitionNoShuffle( + ctx, + new RVDPartitioner(ctx.stateManager, tm.rowKey.toArray, tm.rowKeyStruct, rangeBounds), + ) } else tablestage - val rowSpec = TypedCodecSpec(EType.fromTypeAndAnalysis(tm.rowType, rm.rowType), tm.rowType, bufferSpec) - val entrySpec = TypedCodecSpec(EType.fromTypeAndAnalysis(tm.entriesRVType, rm.entriesRVType), tm.entriesRVType, bufferSpec) - val colSpec = TypedCodecSpec(EType.fromTypeAndAnalysis(tm.colType, rm.colType), tm.colType, bufferSpec) - val globalSpec = TypedCodecSpec(EType.fromTypeAndAnalysis(tm.globalType, rm.globalType), tm.globalType, bufferSpec) - val emptySpec = TypedCodecSpec(EBaseStruct(FastSeq(), required = true), TStruct.empty, bufferSpec) + val rowSpec = + TypedCodecSpec(EType.fromTypeAndAnalysis(tm.rowType, rm.rowType), tm.rowType, bufferSpec) + val entrySpec = TypedCodecSpec( + EType.fromTypeAndAnalysis(tm.entriesRVType, rm.entriesRVType), + tm.entriesRVType, + bufferSpec, + ) + val colSpec = + TypedCodecSpec(EType.fromTypeAndAnalysis(tm.colType, rm.colType), tm.colType, bufferSpec) + val globalSpec = TypedCodecSpec( + EType.fromTypeAndAnalysis(tm.globalType, rm.globalType), + tm.globalType, + bufferSpec, + ) + val emptySpec = + TypedCodecSpec(EBaseStruct(FastSeq(), required = true), TStruct.empty, bufferSpec) // write out partitioner key, which may be stricter than table key val partitioner = lowered.partitioner val pKey: PStruct = 
tcoerce[PStruct](rowSpec.decodedPType(partitioner.kType)) - val emptyWriter = PartitionNativeWriter(emptySpec, IndexedSeq(), s"$path/globals/globals/parts/", None, None) - val globalWriter = PartitionNativeWriter(globalSpec, IndexedSeq(), s"$path/globals/rows/parts/", None, None) - val colWriter = PartitionNativeWriter(colSpec, IndexedSeq(), s"$path/cols/rows/parts/", None, None) + val emptyWriter = + PartitionNativeWriter(emptySpec, IndexedSeq(), s"$path/globals/globals/parts/", None, None) + val globalWriter = + PartitionNativeWriter(globalSpec, IndexedSeq(), s"$path/globals/rows/parts/", None, None) + val colWriter = + PartitionNativeWriter(colSpec, IndexedSeq(), s"$path/cols/rows/parts/", None, None) val rowWriter = SplitPartitionNativeWriter( rowSpec, s"$path/rows/rows/parts/", @@ -125,13 +170,43 @@ object MatrixNativeWriter { s"$path/entries/rows/parts/", pKey.virtualType.fieldNames, Some(s"$path/index/" -> pKey), - if (stageLocally) Some(FileSystems.getDefault.getPath(ctx.localTmpdir, s"hail_stage_tmp_${UUID.randomUUID}")) else None + if (stageLocally) + Some(FileSystems.getDefault.getPath(ctx.localTmpdir, s"hail_stage_tmp_${UUID.randomUUID}")) + else None, ) - val globalTableWriter = TableSpecWriter(s"$path/globals", TableType(tm.globalType, FastSeq(), TStruct.empty), "rows", "globals", "../references", log = false) - val colTableWriter = TableSpecWriter(s"$path/cols", tm.colsTableType.copy(key = FastSeq[String]()), "rows", "../globals/rows", "../references", log = false) - val rowTableWriter = TableSpecWriter(s"$path/rows", tm.rowsTableType, "rows", "../globals/rows", "../references", log = false) - val entriesTableWriter = TableSpecWriter(s"$path/entries", TableType(tm.entriesRVType, FastSeq(), tm.globalType), "rows", "../globals/rows", "../references", log = false) + val globalTableWriter = TableSpecWriter( + s"$path/globals", + TableType(tm.globalType, FastSeq(), TStruct.empty), + "rows", + "globals", + "../references", + log = false, + ) + val colTableWriter = TableSpecWriter( + s"$path/cols", + tm.colsTableType.copy(key = FastSeq[String]()), + "rows", + "../globals/rows", + "../references", + log = false, + ) + val rowTableWriter = TableSpecWriter( + s"$path/rows", + tm.rowsTableType, + "rows", + "../globals/rows", + "../references", + log = false, + ) + val entriesTableWriter = TableSpecWriter( + s"$path/entries", + TableType(tm.entriesRVType, FastSeq(), tm.globalType), + "rows", + "../globals/rows", + "../references", + log = false, + ) new MatrixWriterComponents { @@ -140,18 +215,23 @@ object MatrixNativeWriter { val d = digitsNeeded(lowered.numPartitions) val partFiles = Array.tabulate(lowered.numPartitions)(i => s"${partFile(d, i)}-") - zip2(oldCtx, ToStream(Literal(TArray(TString), partFiles.toFastSeq)), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => - MakeStruct(FastSeq("oldCtx" -> ctxElt, "writeCtx" -> pf)) - } + zip2( + oldCtx, + ToStream(Literal(TArray(TString), partFiles.toFastSeq)), + ArrayZipBehavior.AssertSameLength, + )((ctxElt, pf) => MakeStruct(FastSeq("oldCtx" -> ctxElt, "writeCtx" -> pf))) }(GetField(_, "oldCtx")) override val setup: IR = Begin(FastSeq( - WriteMetadata(Void(), RelationalSetup(path, overwrite = overwrite, Some(tablestage.tableType))), + WriteMetadata( + Void(), + RelationalSetup(path, overwrite = overwrite, Some(tablestage.tableType)), + ), WriteMetadata(Void(), RelationalSetup(s"$path/globals", overwrite = false, None)), WriteMetadata(Void(), RelationalSetup(s"$path/cols", overwrite = false, None)), WriteMetadata(Void(), 
RelationalSetup(s"$path/rows", overwrite = false, None)), - WriteMetadata(Void(), RelationalSetup(s"$path/entries", overwrite = false, None)) + WriteMetadata(Void(), RelationalSetup(s"$path/entries", overwrite = false, None)), )) override def writePartitionType: Type = @@ -162,52 +242,132 @@ object MatrixNativeWriter { override def finalizeWrite(parts: IR, globals: IR): IR = { // parts is array of partition results - val writeEmpty = WritePartition(MakeStream(FastSeq(makestruct()), TStream(TStruct.empty)), Str(partFile(1, 0)), emptyWriter) - val writeCols = WritePartition(ToStream(GetField(globals, colsFieldName)), Str(partFile(1, 0)), colWriter) - val writeGlobals = WritePartition(MakeStream(FastSeq(SelectFields(globals, tm.globalType.fieldNames)), TStream(tm.globalType)), - Str(partFile(1, 0)), globalWriter) - + val writeEmpty = WritePartition( + MakeStream(FastSeq(makestruct()), TStream(TStruct.empty)), + Str(partFile(1, 0)), + emptyWriter, + ) + val writeCols = + WritePartition(ToStream(GetField(globals, colsFieldName)), Str(partFile(1, 0)), colWriter) + val writeGlobals = WritePartition( + MakeStream( + FastSeq(SelectFields(globals, tm.globalType.fieldNames)), + TStream(tm.globalType), + ), + Str(partFile(1, 0)), + globalWriter, + ) - val matrixWriter = MatrixSpecWriter(path, tm, "rows/rows", "globals/rows", "cols/rows", "entries/rows", "references", log = true) + val matrixWriter = MatrixSpecWriter(path, tm, "rows/rows", "globals/rows", "cols/rows", + "entries/rows", "references", log = true) val rowsIndexSpec = IndexSpec.defaultAnnotation("../../index", tcoerce[PStruct](pKey)) - val entriesIndexSpec = IndexSpec.defaultAnnotation("../../index", tcoerce[PStruct](pKey), withOffsetField = true) + val entriesIndexSpec = + IndexSpec.defaultAnnotation("../../index", tcoerce[PStruct](pKey), withOffsetField = true) bindIR(writeCols) { colInfo => bindIR(parts) { partInfo => Begin(FastSeq( - WriteMetadata(MakeArray(GetField(writeEmpty, "filePath")), - RVDSpecWriter(s"$path/globals/globals", RVDSpecMaker(emptySpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)))), - WriteMetadata(MakeArray(GetField(writeGlobals, "filePath")), - RVDSpecWriter(s"$path/globals/rows", RVDSpecMaker(globalSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)))), - WriteMetadata(MakeArray(MakeStruct(FastSeq("partitionCounts" -> I64(1), "distinctlyKeyed" -> True(), "firstKey" -> MakeStruct(FastSeq()), "lastKey" -> MakeStruct(FastSeq())))), globalTableWriter), - WriteMetadata(MakeArray(GetField(colInfo, "filePath")), - RVDSpecWriter(s"$path/cols/rows", RVDSpecMaker(colSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)))), - WriteMetadata(MakeArray(SelectFields(colInfo, IndexedSeq("partitionCounts", "distinctlyKeyed", "firstKey", "lastKey"))), colTableWriter), - bindIR(ToArray(mapIR(ToStream(partInfo)) { fc => GetField(fc, "filePath") })) { files => - Begin(FastSeq( - WriteMetadata(files, RVDSpecWriter(s"$path/rows/rows", RVDSpecMaker(rowSpec, lowered.partitioner, rowsIndexSpec))), - WriteMetadata(files, RVDSpecWriter(s"$path/entries/rows", RVDSpecMaker(entrySpec, RVDPartitioner.unkeyed(ctx.stateManager, lowered.numPartitions), entriesIndexSpec))))) + WriteMetadata( + MakeArray(GetField(writeEmpty, "filePath")), + RVDSpecWriter( + s"$path/globals/globals", + RVDSpecMaker(emptySpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), + ), + ), + WriteMetadata( + MakeArray(GetField(writeGlobals, "filePath")), + RVDSpecWriter( + s"$path/globals/rows", + RVDSpecMaker(globalSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), + ), + ), + 
WriteMetadata( + MakeArray(MakeStruct(FastSeq( + "partitionCounts" -> I64(1), + "distinctlyKeyed" -> True(), + "firstKey" -> MakeStruct(FastSeq()), + "lastKey" -> MakeStruct(FastSeq()), + ))), + globalTableWriter, + ), + WriteMetadata( + MakeArray(GetField(colInfo, "filePath")), + RVDSpecWriter( + s"$path/cols/rows", + RVDSpecMaker(colSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), + ), + ), + WriteMetadata( + MakeArray(SelectFields( + colInfo, + IndexedSeq("partitionCounts", "distinctlyKeyed", "firstKey", "lastKey"), + )), + colTableWriter, + ), + bindIR(ToArray(mapIR(ToStream(partInfo))(fc => GetField(fc, "filePath")))) { + files => + Begin(FastSeq( + WriteMetadata( + files, + RVDSpecWriter( + s"$path/rows/rows", + RVDSpecMaker(rowSpec, lowered.partitioner, rowsIndexSpec), + ), + ), + WriteMetadata( + files, + RVDSpecWriter( + s"$path/entries/rows", + RVDSpecMaker( + entrySpec, + RVDPartitioner.unkeyed(ctx.stateManager, lowered.numPartitions), + entriesIndexSpec, + ), + ), + ), + )) }, - bindIR(ToArray(mapIR(ToStream(partInfo)) { fc => SelectFields(fc, FastSeq("partitionCounts", "distinctlyKeyed", "firstKey", "lastKey")) })) { countsAndKeyInfo => + bindIR(ToArray(mapIR(ToStream(partInfo)) { fc => + SelectFields( + fc, + FastSeq("partitionCounts", "distinctlyKeyed", "firstKey", "lastKey"), + ) + })) { countsAndKeyInfo => Begin(FastSeq( WriteMetadata(countsAndKeyInfo, rowTableWriter), WriteMetadata( ToArray(mapIR(ToStream(countsAndKeyInfo)) { countAndKeyInfo => - InsertFields(SelectFields(countAndKeyInfo, IndexedSeq("partitionCounts", "distinctlyKeyed")), IndexedSeq("firstKey" -> MakeStruct(FastSeq()), "lastKey" -> MakeStruct(FastSeq()))) + InsertFields( + SelectFields( + countAndKeyInfo, + IndexedSeq("partitionCounts", "distinctlyKeyed"), + ), + IndexedSeq( + "firstKey" -> MakeStruct(FastSeq()), + "lastKey" -> MakeStruct(FastSeq()), + ), + ) }), - entriesTableWriter), + entriesTableWriter, + ), WriteMetadata( makestruct( "cols" -> GetField(colInfo, "partitionCounts"), - "rows" -> ToArray(mapIR(ToStream(countsAndKeyInfo)) { countAndKey => GetField(countAndKey, "partitionCounts") })), - matrixWriter))) + "rows" -> ToArray(mapIR(ToStream(countsAndKeyInfo)) { countAndKey => + GetField(countAndKey, "partitionCounts") + }), + ), + matrixWriter, + ), + )) }, WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(path)), WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(s"$path/globals")), WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(s"$path/cols")), WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(s"$path/rows")), - WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(s"$path/entries")))) + WriteMetadata(MakeStruct(FastSeq()), RelationalCommit(s"$path/entries")), + )) } } } @@ -221,11 +381,17 @@ case class MatrixNativeWriter( stageLocally: Boolean = false, codecSpecJSONStr: String = null, partitions: String = null, - partitionsTypeStr: String = null + partitionsTypeStr: String = null, ) extends MatrixWriter { - override def lower(colsFieldName: String, entriesFieldName: String, colKey: IndexedSeq[String], - ctx: ExecuteContext, tablestage: TableStage, r: RTable): IR = { + override def lower( + colsFieldName: String, + entriesFieldName: String, + colKey: IndexedSeq[String], + ctx: ExecuteContext, + tablestage: TableStage, + r: RTable, + ): IR = { val components = MatrixNativeWriter.generateComponentFunctions( colsFieldName, entriesFieldName, colKey, ctx, tablestage, r, @@ -233,20 +399,22 @@ case class MatrixNativeWriter( Begin(FastSeq( components.setup, - 
components.stage.mapCollectWithContextsAndGlobals("matrix_native_writer")(components.writePartition)(components.finalizeWrite) + components.stage.mapCollectWithContextsAndGlobals("matrix_native_writer")( + components.writePartition + )(components.finalizeWrite), )) } } -case class SplitPartitionNativeWriter(spec1: AbstractTypedCodecSpec, - partPrefix1: String, - spec2: AbstractTypedCodecSpec, - partPrefix2: String, - keyFieldNames: IndexedSeq[String], - index: Option[(String, PStruct)], - stageFolder: Option[Path] - ) - extends PartitionWriter { +case class SplitPartitionNativeWriter( + spec1: AbstractTypedCodecSpec, + partPrefix1: String, + spec2: AbstractTypedCodecSpec, + partPrefix2: String, + keyFieldNames: IndexedSeq[String], + index: Option[(String, PStruct)], + stageFolder: Option[Path], +) extends PartitionWriter { val filenameType = PCanonicalString(required = true) def pContextType = PCanonicalString() @@ -254,8 +422,20 @@ case class SplitPartitionNativeWriter(spec1: AbstractTypedCodecSpec, val keyType = spec1.encodedVirtualType.asInstanceOf[TStruct].select(keyFieldNames)._1 def ctxType: Type = TString - def returnType: Type = TStruct("filePath" -> TString, "partitionCounts" -> TInt64, "distinctlyKeyed" -> TBoolean, "firstKey" -> keyType, "lastKey" -> keyType) - def unionTypeRequiredness(r: TypeWithRequiredness, ctxType: TypeWithRequiredness, streamType: RIterable): Unit = { + + def returnType: Type = TStruct( + "filePath" -> TString, + "partitionCounts" -> TInt64, + "distinctlyKeyed" -> TBoolean, + "firstKey" -> keyType, + "lastKey" -> keyType, + ) + + def unionTypeRequiredness( + r: TypeWithRequiredness, + ctxType: TypeWithRequiredness, + streamType: RIterable, + ): Unit = { val rs = r.asInstanceOf[RStruct] val rKeyType = streamType.elementType.asInstanceOf[RStruct].select(keyFieldNames.toArray) rs.field("firstKey").union(false) @@ -266,22 +446,31 @@ case class SplitPartitionNativeWriter(spec1: AbstractTypedCodecSpec, r.union(streamType.required) } - def consumeStream(ctx: ExecuteContext, - cb: EmitCodeBuilder, - stream: StreamProducer, - context: EmitCode, - region: Value[Region] - ): IEmitCode = { + def consumeStream( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + stream: StreamProducer, + context: EmitCode, + region: Value[Region], + ): IEmitCode = { val iAnnotationType = PCanonicalStruct(required = true, "entries_offset" -> PInt64()) val mb = cb.emb val writeIndexInfo = index.map { case (name, ktype) => val bfactor = Option(mb.ctx.getFlag("index_branching_factor")).map(_.toInt).getOrElse(4096) - (name, ktype, StagedIndexWriter.withDefaults(ktype, mb.ecb, annotationType = iAnnotationType, branchingFactor = bfactor)) + ( + name, + ktype, + StagedIndexWriter.withDefaults( + ktype, + mb.ecb, + annotationType = iAnnotationType, + branchingFactor = bfactor, + ), + ) } context.toI(cb).map(cb) { pctx => - val ctxValue = pctx.asString.loadString(cb) val (filenames, stages, buffers) = FastSeq(partPrefix1, partPrefix2) @@ -298,9 +487,12 @@ case class SplitPartitionNativeWriter(spec1: AbstractTypedCodecSpec, } val ostream = mb.newLocal[ByteTrackingOutputStream](s"write_os$i") - cb.assign(ostream, Code.newInstance[ByteTrackingOutputStream, OutputStream]( - mb.createUnbuffered(stagingFile.getOrElse(filename).get) - )) + cb.assign( + ostream, + Code.newInstance[ByteTrackingOutputStream, OutputStream]( + mb.createUnbuffered(stagingFile.getOrElse(filename).get) + ), + ) val buffer = mb.newLocal[OutputBuffer](s"write_ob$i") cb.assign(buffer, 
spec1.buildCodeOutputBuffer(Code.checkcast[OutputStream](ostream))) @@ -319,12 +511,15 @@ case class SplitPartitionNativeWriter(spec1: AbstractTypedCodecSpec, cb.assign(pCount, 0L) val distinctlyKeyed = mb.newLocal[Boolean]("distinctlyKeyed") - cb.assign(distinctlyKeyed, !keyFieldNames.isEmpty) // True until proven otherwise, if there's a key to care about all. + cb.assign( + distinctlyKeyed, + !keyFieldNames.isEmpty, + ) // True until proven otherwise, if there's a key to care about all. val keyEmitType = EmitType(spec1.decodedPType(keyType).sType, false) - val firstSeenSettable = mb.newEmitLocal("pnw_firstSeen", keyEmitType) - val lastSeenSettable = mb.newEmitLocal("pnw_lastSeen", keyEmitType) + val firstSeenSettable = mb.newEmitLocal("pnw_firstSeen", keyEmitType) + val lastSeenSettable = mb.newEmitLocal("pnw_lastSeen", keyEmitType) val lastSeenRegion = cb.newLocal[Region]("last_seen_region") // Start off missing, we will use this to determine if we haven't processed any rows yet. @@ -334,48 +529,78 @@ case class SplitPartitionNativeWriter(spec1: AbstractTypedCodecSpec, val specs = FastSeq(spec1, spec2) stream.memoryManagedConsume(region, cb) { cb => - val row = stream.element.toI(cb).get(cb, "row can't be missing").asBaseStruct + val row = stream.element.toI(cb).getOrFatal(cb, "row can't be missing").asBaseStruct writeIndexInfo.foreach { case (_, keyType, writer) => - writer.add(cb, { - IEmitCode.present(cb, keyType.asInstanceOf[PCanonicalBaseStruct] - .constructFromFields(cb, stream.elementRegion, keyType.fields.map { f => - EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name)) - }, - deepCopy = false + writer.add( + cb, { + IEmitCode.present( + cb, + keyType.asInstanceOf[PCanonicalBaseStruct] + .constructFromFields( + cb, + stream.elementRegion, + keyType.fields.map { f => + EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name)) + }, + deepCopy = false, + ), ) - ) - }, + }, buffers(0).invoke[Long]("indexOffset"), { - IEmitCode.present(cb, - iAnnotationType.constructFromFields(cb, stream.elementRegion, - FastSeq(EmitCode.present(cb.emb, primitive(cb.memoize(buffers(1).invoke[Long]("indexOffset"))))), - deepCopy = false - ) + IEmitCode.present( + cb, + iAnnotationType.constructFromFields( + cb, + stream.elementRegion, + FastSeq(EmitCode.present( + cb.emb, + primitive(cb.memoize(buffers(1).invoke[Long]("indexOffset"))), + )), + deepCopy = false, + ), ) - } + }, ) } - val key = SStackStruct.constructFromArgs(cb, stream.elementRegion, keyType, keyType.fields.map { f => - EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name)) - }: _*) + val key = SStackStruct.constructFromArgs( + cb, + stream.elementRegion, + keyType, + keyType.fields.map(f => EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name))): _* + ) if (!keyFieldNames.isEmpty) { - cb.if_(distinctlyKeyed, { - lastSeenSettable.loadI(cb).consume(cb, { - // If there's no last seen, we are in the first row. - cb.assign(firstSeenSettable, EmitValue.present(key.copyToRegion(cb, region, firstSeenSettable.st))) - }, { lastSeen => - val comparator = EQ(lastSeenSettable.emitType.virtualType).codeOrdering(cb.emb.ecb, lastSeenSettable.st, key.st) - val equalToLast = comparator(cb, lastSeenSettable, EmitValue.present(key)) - cb.if_(equalToLast.asInstanceOf[Value[Boolean]], { - cb.assign(distinctlyKeyed, false) - }) - }) - }) + cb.if_( + distinctlyKeyed, { + lastSeenSettable.loadI(cb).consume( + cb, + // If there's no last seen, we are in the first row. 
+ cb.assign( + firstSeenSettable, + EmitValue.present(key.copyToRegion(cb, region, firstSeenSettable.st)), + ), + { lastSeen => + val comparator = EQ(lastSeenSettable.emitType.virtualType).codeOrdering( + cb.emb.ecb, + lastSeenSettable.st, + key.st, + ) + val equalToLast = comparator(cb, lastSeenSettable, EmitValue.present(key)) + cb.if_( + equalToLast.asInstanceOf[Value[Boolean]], + cb.assign(distinctlyKeyed, false), + ) + }, + ) + }, + ) cb += lastSeenRegion.clearRegion() - cb.assign(lastSeenSettable, IEmitCode.present(cb, key.copyToRegion(cb, lastSeenRegion, lastSeenSettable.st))) + cb.assign( + lastSeenSettable, + IEmitCode.present(cb, key.copyToRegion(cb, lastSeenRegion, lastSeenSettable.st)), + ) } buffers.zip(specs).foreach { case (buff, spec) => @@ -395,27 +620,48 @@ case class SplitPartitionNativeWriter(spec1: AbstractTypedCodecSpec, } stages.flatMap(_.toIterable).zip(filenames).foreach { case (source, destination) => - cb += mb.getFS.invoke[String, String, Boolean, Unit]("copy", source, destination, const(true)) + cb += mb.getFS.invoke[String, String, Boolean, Unit]( + "copy", + source, + destination, + const(true), + ) } - lastSeenSettable.loadI(cb).consume(cb, { /* do nothing */ }, { lastSeen => - cb.assign(lastSeenSettable, IEmitCode.present(cb, lastSeen.copyToRegion(cb, region, lastSeenSettable.st))) - }) + lastSeenSettable.loadI(cb).consume( + cb, + { /* do nothing */ }, + lastSeen => + cb.assign( + lastSeenSettable, + IEmitCode.present(cb, lastSeen.copyToRegion(cb, region, lastSeenSettable.st)), + ), + ) cb += lastSeenRegion.invalidate() - SStackStruct.constructFromArgs(cb, region, returnType.asInstanceOf[TBaseStruct], + SStackStruct.constructFromArgs( + cb, + region, + returnType.asInstanceOf[TBaseStruct], EmitCode.present(mb, pctx), EmitCode.present(mb, new SInt64Value(pCount)), EmitCode.present(mb, new SBooleanValue(distinctlyKeyed)), firstSeenSettable, - lastSeenSettable + lastSeenSettable, ) } } } class MatrixSpecHelper( - path: String, rowRelPath: String, globalRelPath: String, colRelPath: String, entryRelPath: String, refRelPath: String, typ: MatrixType, log: Boolean + path: String, + rowRelPath: String, + globalRelPath: String, + colRelPath: String, + entryRelPath: String, + refRelPath: String, + typ: MatrixType, + log: Boolean, ) extends Serializable { def write(fs: FS, nCols: Long, partCounts: Array[Long]): Unit = { val spec = MatrixTableSpecParameters( @@ -423,44 +669,65 @@ class MatrixSpecHelper( is.hail.HAIL_PRETTY_VERSION, "references", typ, - Map("globals" -> RVDComponentSpec(globalRelPath), + Map( + "globals" -> RVDComponentSpec(globalRelPath), "cols" -> RVDComponentSpec(colRelPath), "rows" -> RVDComponentSpec(rowRelPath), "entries" -> RVDComponentSpec(entryRelPath), - "partition_counts" -> PartitionCountsComponentSpec(partCounts))) + "partition_counts" -> PartitionCountsComponentSpec(partCounts), + ), + ) spec.write(fs, path) val nRows = partCounts.sum - info(s"wrote matrix table with $nRows ${ plural(nRows, "row") } " + - s"and ${ nCols } ${ plural(nCols, "column") } " + - s"in ${ partCounts.length } ${ plural(partCounts.length, "partition") } " + + info(s"wrote matrix table with $nRows ${plural(nRows, "row")} " + + s"and $nCols ${plural(nCols, "column")} " + + s"in ${partCounts.length} ${plural(partCounts.length, "partition")} " + s"to $path") } } -case class MatrixSpecWriter(path: String, typ: MatrixType, rowRelPath: String, globalRelPath: String, colRelPath: String, entryRelPath: String, refRelPath: String, log: Boolean) extends MetadataWriter { +case 
class MatrixSpecWriter( + path: String, + typ: MatrixType, + rowRelPath: String, + globalRelPath: String, + colRelPath: String, + entryRelPath: String, + refRelPath: String, + log: Boolean, +) extends MetadataWriter { def annotationType: Type = TStruct("cols" -> TInt64, "rows" -> TArray(TInt64)) def writeMetadata( writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, - region: Value[Region]): Unit = { + region: Value[Region], + ): Unit = { cb += cb.emb.getFS.invoke[String, Unit]("mkDir", path) - val c = writeAnnotations.get(cb, "write annotations can't be missing!").asBaseStruct + val c = writeAnnotations.getOrFatal(cb, "write annotations can't be missing!").asBaseStruct val partCounts = cb.newLocal[Array[Long]]("partCounts") - val a = c.loadField(cb, "rows").get(cb).asIndexable + val a = c.loadField(cb, "rows").getOrAssert(cb).asIndexable val n = cb.newLocal[Int]("n", a.loadLength()) val i = cb.newLocal[Int]("i", 0) cb.assign(partCounts, Code.newArray[Long](n)) - cb.while_(i < n, { - val count = a.loadElement(cb, i).get(cb, "part count can't be missing!") - cb += partCounts.update(i, count.asInt64.value) - cb.assign(i, i + 1) - }) - cb += cb.emb.getObject(new MatrixSpecHelper(path, rowRelPath, globalRelPath, colRelPath, entryRelPath, refRelPath, typ, log)) - .invoke[FS, Long, Array[Long], Unit]("write", cb.emb.getFS, c.loadField(cb, "cols").get(cb).asInt64.value, partCounts) + cb.while_( + i < n, { + val count = a.loadElement(cb, i).getOrFatal(cb, "part count can't be missing!") + cb += partCounts.update(i, count.asInt64.value) + cb.assign(i, i + 1) + }, + ) + cb += cb.emb.getObject(new MatrixSpecHelper(path, rowRelPath, globalRelPath, colRelPath, + entryRelPath, refRelPath, typ, log)) + .invoke[FS, Long, Array[Long], Unit]( + "write", + cb.emb.getFS, + c.loadField(cb, "cols").getOrAssert(cb).asInt64.value, + partCounts, + ) } } @@ -469,10 +736,16 @@ case class MatrixVCFWriter( append: Option[String] = None, exportType: String = ExportType.CONCATENATED, metadata: Option[VCFMetadata] = None, - tabix: Boolean = false + tabix: Boolean = false, ) extends MatrixWriter { - override def lower(colsFieldName: String, entriesFieldName: String, colKey: IndexedSeq[String], - ctx: ExecuteContext, ts: TableStage, r: RTable): IR = { + override def lower( + colsFieldName: String, + entriesFieldName: String, + colKey: IndexedSeq[String], + ctx: ExecuteContext, + ts: TableStage, + r: RTable, + ): IR = { require(exportType != ExportType.PARALLEL_COMPOSABLE) val tm = MatrixType.fromTableType(ts.tableType, colsFieldName, entriesFieldName, colKey) @@ -484,7 +757,9 @@ case class MatrixVCFWriter( case tinfo: TStruct => ExportVCF.checkInfoSignature(tinfo) case t => - warn(s"export_vcf found row field 'info' of type $t, but expected type 'tstruct'. Emitting no INFO fields.") + warn( + s"export_vcf found row field 'info' of type $t, but expected type 'tstruct'. Emitting no INFO fields." + ) } } else { warn(s"export_vcf found no row field 'info'. 
Emitting no INFO fields.") @@ -504,26 +779,39 @@ case class MatrixVCFWriter( val writeHeader = exportType == ExportType.PARALLEL_HEADER_IN_SHARD val partAppend = appendStr.filter(_ => writeHeader) val partMetadata = metadata.filter(_ => writeHeader) - val lineWriter = VCFPartitionWriter(tm, entriesFieldName, writeHeader = exportType == ExportType.PARALLEL_HEADER_IN_SHARD, - partAppend, partMetadata, tabix && exportType != ExportType.CONCATENATED) + val lineWriter = VCFPartitionWriter( + tm, + entriesFieldName, + writeHeader = exportType == ExportType.PARALLEL_HEADER_IN_SHARD, + partAppend, + partMetadata, + tabix && exportType != ExportType.CONCATENATED, + ) ts.mapContexts { oldCtx => val d = digitsNeeded(ts.numPartitions) - val partFiles = Literal(TArray(TString), Array.tabulate(ts.numPartitions)(i => s"$folder/${ partFile(d, i) }-").toFastSeq) + val partFiles = Literal( + TArray(TString), + Array.tabulate(ts.numPartitions)(i => s"$folder/${partFile(d, i)}-").toFastSeq, + ) zip2(oldCtx, ToStream(partFiles), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => MakeStruct(FastSeq( "oldCtx" -> ctxElt, - "partFile" -> pf)) + "partFile" -> pf, + )) } - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_vcf_writer") { (rows, ctxRef) => - val partFile = GetField(ctxRef, "partFile") + UUID4() + Str(ext) - val ctx = MakeStruct(FastSeq( - "cols" -> GetField(ts.globals, colsFieldName), - "partFile" -> partFile)) - WritePartition(rows, ctx, lineWriter) - }{ (parts, globals) => - val ctx = MakeStruct(FastSeq("cols" -> GetField(globals, colsFieldName), "partFiles" -> parts)) + }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_vcf_writer") { + (rows, ctxRef) => + val partFile = GetField(ctxRef, "partFile") + UUID4() + Str(ext) + val ctx = MakeStruct(FastSeq( + "cols" -> GetField(ts.globals, colsFieldName), + "partFile" -> partFile, + )) + WritePartition(rows, ctx, lineWriter) + } { (parts, globals) => + val ctx = + MakeStruct(FastSeq("cols" -> GetField(globals, colsFieldName), "partFiles" -> parts)) val commit = VCFExportFinalizer(tm, path, appendStr, metadata, exportType, tabix) Begin(FastSeq(WriteMetadata(ctx, commit))) } @@ -544,50 +832,87 @@ case class MatrixVCFWriter( } } -case class VCFPartitionWriter(typ: MatrixType, entriesFieldName: String, writeHeader: Boolean, - append: Option[String], metadata: Option[VCFMetadata], tabix: Boolean) extends PartitionWriter { +case class VCFPartitionWriter( + typ: MatrixType, + entriesFieldName: String, + writeHeader: Boolean, + append: Option[String], + metadata: Option[VCFMetadata], + tabix: Boolean, +) extends PartitionWriter { val ctxType: Type = TStruct("cols" -> TArray(typ.colType), "partFile" -> TString) val formatFieldOrder: Array[Int] = typ.entryType.fieldIdx.get("GT") match { case Some(i) => (i +: typ.entryType.fields.filter(fd => fd.name != "GT").map(_.index)).toArray case None => typ.entryType.fields.indices.toArray } + val formatFieldStr = formatFieldOrder.map(i => typ.entryType.fields(i).name).mkString(":") val locusIdx = typ.rowType.fieldIdx("locus") val allelesIdx = typ.rowType.fieldIdx("alleles") val (idExists, idIdx) = ExportVCF.lookupVAField(typ.rowType, "rsid", "ID", Some(TString)) val (qualExists, qualIdx) = ExportVCF.lookupVAField(typ.rowType, "qual", "QUAL", Some(TFloat64)) - val (filtersExists, filtersIdx) = ExportVCF.lookupVAField(typ.rowType, "filters", "FILTERS", Some(TSet(TString))) + + val (filtersExists, filtersIdx) = + ExportVCF.lookupVAField(typ.rowType, "filters", "FILTERS", Some(TSet(TString))) + 
val (infoExists, infoIdx) = ExportVCF.lookupVAField(typ.rowType, "info", "INFO", None) def returnType: Type = TString - def unionTypeRequiredness(r: TypeWithRequiredness, ctxType: TypeWithRequiredness, streamType: RIterable): Unit = { + + def unionTypeRequiredness( + r: TypeWithRequiredness, + ctxType: TypeWithRequiredness, + streamType: RIterable, + ): Unit = { r.union(ctxType.required) r.union(streamType.required) } - final def consumeStream(ctx: ExecuteContext, cb: EmitCodeBuilder, stream: StreamProducer, - context: EmitCode, region: Value[Region]): IEmitCode = { + final def consumeStream( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + stream: StreamProducer, + context: EmitCode, + region: Value[Region], + ): IEmitCode = { val mb = cb.emb context.toI(cb).map(cb) { case ctx: SBaseStructValue => val formatFieldUTF8 = cb.memoize(const(formatFieldStr).invoke[Array[Byte]]("getBytes")) - val filename = ctx.loadField(cb, "partFile").get(cb, "partFile can't be missing").asString.loadString(cb) + val filename = + ctx.loadField(cb, "partFile").getOrFatal( + cb, + "partFile can't be missing", + ).asString.loadString(cb) val os = cb.memoize(cb.emb.create(filename)) if (writeHeader) { - val sampleIds = ctx.loadField(cb, "cols").get(cb).asIndexable + val sampleIds = ctx.loadField(cb, "cols").getOrAssert(cb).asIndexable val stringSampleIds = cb.memoize(Code.newArray[String](sampleIds.loadLength())) sampleIds.forEachDefined(cb) { case (cb, i, colv: SBaseStructValue) => - val s = colv.subset(typ.colKey: _*).loadField(cb, 0).get(cb).asString + val s = colv.subset(typ.colKey: _*).loadField(cb, 0).getOrAssert(cb).asString cb += (stringSampleIds(i) = s.loadString(cb)) } - val headerStr = Code.invokeScalaObject6[TStruct, TStruct, ReferenceGenome, Option[String], Option[VCFMetadata], Array[String], String]( - ExportVCF.getClass, "makeHeader", - mb.getType[TStruct](typ.rowType), mb.getType[TStruct](typ.entryType), - mb.getReferenceGenome(typ.referenceGenomeName), mb.getObject(append), - mb.getObject(metadata), stringSampleIds) + val headerStr = Code.invokeScalaObject6[ + TStruct, + TStruct, + ReferenceGenome, + Option[String], + Option[VCFMetadata], + Array[String], + String, + ]( + ExportVCF.getClass, + "makeHeader", + mb.getType[TStruct](typ.rowType), + mb.getType[TStruct](typ.entryType), + mb.getReferenceGenome(typ.referenceGenomeName), + mb.getObject(append), + mb.getObject(metadata), + stringSampleIds, + ) cb += os.invoke[Array[Byte], Unit]("write", headerStr.invoke[Array[Byte]]("getBytes")) cb += os.invoke[Int, Unit]("write", '\n') } @@ -610,7 +935,7 @@ case class VCFPartitionWriter(typ: MatrixType, entriesFieldName: String, writeHe formatFieldUTF8, missingUnphasedDiploidGTUTF8Value, missingFormatUTF8Value, - passUTF8Value + passUTF8Value, ) } @@ -618,7 +943,12 @@ case class VCFPartitionWriter(typ: MatrixType, entriesFieldName: String, writeHe cb += os.invoke[Unit]("close") if (tabix) { - cb += Code.invokeScalaObject2[FS, String, Unit](TabixVCF.getClass, "apply", cb.emb.getFS, filename) + cb += Code.invokeScalaObject2[FS, String, Unit]( + TabixVCF.getClass, + "apply", + cb.emb.getFS, + filename, + ) } SJavaString.construct(cb, filename) @@ -633,80 +963,118 @@ case class VCFPartitionWriter(typ: MatrixType, entriesFieldName: String, writeHe formatFieldUTF8: Value[Array[Byte]], missingUnphasedDiploidGTUTF8Value: Value[Array[Byte]], missingFormatUTF8Value: Value[Array[Byte]], - passUTF8Value: Value[Array[Byte]] + passUTF8Value: Value[Array[Byte]], ): Unit = { - def _writeC(cb: EmitCodeBuilder, code: 
Code[Int]) = { cb += os.invoke[Int, Unit]("write", code) } - def _writeB(cb: EmitCodeBuilder, code: Code[Array[Byte]]) = { cb += os.invoke[Array[Byte], Unit]("write", code) } - def _writeS(cb: EmitCodeBuilder, code: Code[String]) = { _writeB(cb, code.invoke[Array[Byte]]("getBytes")) } + def _writeC(cb: EmitCodeBuilder, code: Code[Int]) = cb += os.invoke[Int, Unit]("write", code) + def _writeB(cb: EmitCodeBuilder, code: Code[Array[Byte]]) = + cb += os.invoke[Array[Byte], Unit]("write", code) + def _writeS(cb: EmitCodeBuilder, code: Code[String]) = + _writeB(cb, code.invoke[Array[Byte]]("getBytes")) def writeValue(cb: EmitCodeBuilder, value: SValue) = value match { case v: SInt32Value => _writeS(cb, v.value.toS) case v: SInt64Value => - cb.if_(v.value > Int.MaxValue || v.value < Int.MinValue, cb._fatal( - "Cannot convert Long to Int if value is greater than Int.MaxValue (2^31 - 1) ", - "or less than Int.MinValue (-2^31). Found ", v.value.toS)) + cb.if_( + v.value > Int.MaxValue || v.value < Int.MinValue, + cb._fatal( + "Cannot convert Long to Int if value is greater than Int.MaxValue (2^31 - 1) ", + "or less than Int.MinValue (-2^31). Found ", + v.value.toS, + ), + ) _writeS(cb, v.value.toS) case v: SFloat32Value => - cb.if_(Code.invokeStatic1[java.lang.Float, Float, Boolean]("isNaN", v.value), + cb.if_( + Code.invokeStatic1[java.lang.Float, Float, Boolean]("isNaN", v.value), _writeC(cb, '.'), - _writeS(cb, Code.invokeScalaObject2[String, Float, String](ExportVCF.getClass, "fmtFloat", "%.6g", v.value))) + _writeS( + cb, + Code.invokeScalaObject2[String, Float, String]( + ExportVCF.getClass, + "fmtFloat", + "%.6g", + v.value, + ), + ), + ) case v: SFloat64Value => - cb.if_(Code.invokeStatic1[java.lang.Double, Double, Boolean]("isNaN", v.value), + cb.if_( + Code.invokeStatic1[java.lang.Double, Double, Boolean]("isNaN", v.value), _writeC(cb, '.'), - _writeS(cb, Code.invokeScalaObject2[String, Double, String](ExportVCF.getClass, "fmtDouble", "%.6g", v.value))) + _writeS( + cb, + Code.invokeScalaObject2[String, Double, String]( + ExportVCF.getClass, + "fmtDouble", + "%.6g", + v.value, + ), + ), + ) case v: SStringValue => _writeB(cb, v.toBytes(cb).loadBytes(cb)) case v: SCallValue => val ploidy = v.ploidy(cb) - val phased = v.isPhased(cb) cb.if_(ploidy.ceq(0), cb._fatal("VCF spec does not support 0-ploid calls.")) - cb.if_(ploidy.ceq(1) , cb._fatal("VCF spec does not support phased haploid calls.")) + cb.if_( + ploidy.ceq(1) && v.isPhased(cb), + cb._fatal("VCF spec does not support phased haploid calls."), + ) val c = v.canonicalCall(cb) _writeB(cb, Code.invokeScalaObject1[Int, Array[Byte]](Call.getClass, "toUTF8", c)) } def writeIterable(cb: EmitCodeBuilder, it: SIndexableValue, delim: Int) = - it.forEachDefinedOrMissing(cb)({ (cb, i) => - cb.if_(i.cne(0), _writeC(cb, delim)) - _writeC(cb, '.') - }, { (cb, i, value) => - cb.if_(i.cne(0), _writeC(cb, delim)) - writeValue(cb, value) - }) + it.forEachDefinedOrMissing(cb)( + { (cb, i) => + cb.if_(i.cne(0), _writeC(cb, delim)) + _writeC(cb, '.') + }, + { (cb, i, value) => + cb.if_(i.cne(0), _writeC(cb, delim)) + writeValue(cb, value) + }, + ) def writeGenotype(cb: EmitCodeBuilder, gt: SBaseStructValue) = { val end = cb.newLocal[Int]("lastDefined", -1) val Lend = CodeLabel() formatFieldOrder.zipWithIndex.reverse.foreach { case (idx, pos) => - cb.if_(!gt.isFieldMissing(cb, idx), { - cb.assign(end, pos) - cb.goto(Lend) - }) + cb.if_( + !gt.isFieldMissing(cb, idx), { + cb.assign(end, pos) + cb.goto(Lend) + }, + ) } cb.define(Lend) val Lout = CodeLabel() 
- cb.if_(end < 0, { - _writeB(cb, missingFormatUTF8Value) - cb.goto(Lout) - }) + cb.if_( + end < 0, { + _writeB(cb, missingFormatUTF8Value) + cb.goto(Lout) + }, + ) formatFieldOrder.zipWithIndex.foreach { case (idx, pos) => if (pos != 0) _writeC(cb, ':') - gt.loadField(cb, idx).consume(cb, { + gt.loadField(cb, idx).consume( + cb, if (gt.st.fieldTypes(idx).virtualType == TCall) _writeB(cb, missingUnphasedDiploidGTUTF8Value) else - _writeC(cb, '.') - }, { - case value: SIndexableValue => - writeIterable(cb, value, ',') - case value => - writeValue(cb, value) - }) + _writeC(cb, '.'), + { + case value: SIndexableValue => + writeIterable(cb, value, ',') + case value => + writeValue(cb, value) + }, + ) cb.if_(end.ceq(pos), cb.goto(Lout)) } @@ -718,8 +1086,8 @@ case class VCFPartitionWriter(typ: MatrixType, entriesFieldName: String, writeHe def writeB(code: Code[Array[Byte]]) = _writeB(cb, code) def writeS(code: Code[String]) = _writeS(cb, code) - val elt = element.toI(cb).get(cb).asBaseStruct - val locus = elt.loadField(cb, locusIdx).get(cb).asLocus + val elt = element.toI(cb).getOrAssert(cb).asBaseStruct + val locus = elt.loadField(cb, locusIdx).getOrAssert(cb).asLocus // CHROM writeB(locus.contig(cb).toBytes(cb).loadBytes(cb)) // POS @@ -729,46 +1097,69 @@ case class VCFPartitionWriter(typ: MatrixType, entriesFieldName: String, writeHe // ID writeC('\t') if (idExists) - elt.loadField(cb, idIdx).consume(cb, writeC('.'), { case id: SStringValue => - writeB(id.toBytes(cb).loadBytes(cb)) - }) + elt.loadField(cb, idIdx).consume( + cb, + writeC('.'), + { case id: SStringValue => + writeB(id.toBytes(cb).loadBytes(cb)) + }, + ) else writeC('.') // REF writeC('\t') - val alleles = elt.loadField(cb, allelesIdx).get(cb).asIndexable - writeB(alleles.loadElement(cb, 0).get(cb).asString.toBytes(cb).loadBytes(cb)) + val alleles = elt.loadField(cb, allelesIdx).getOrAssert(cb).asIndexable + writeB(alleles.loadElement(cb, 0).getOrAssert(cb).asString.toBytes(cb).loadBytes(cb)) // ALT writeC('\t') - cb.if_(alleles.loadLength() > 1, - { + cb.if_( + alleles.loadLength() > 1, { val i = cb.newLocal[Int]("i") - cb.for_(cb.assign(i, 1), i < alleles.loadLength(), cb.assign(i, i + 1), { - cb.if_(i.cne(1), writeC(',')) - writeB(alleles.loadElement(cb, i).get(cb).asString.toBytes(cb).loadBytes(cb)) - }) + cb.for_( + cb.assign(i, 1), + i < alleles.loadLength(), + cb.assign(i, i + 1), { + cb.if_(i.cne(1), writeC(',')) + writeB(alleles.loadElement(cb, i).getOrAssert(cb).asString.toBytes(cb).loadBytes(cb)) + }, + ) }, - writeC('.')) + writeC('.'), + ) // QUAL writeC('\t') if (qualExists) - elt.loadField(cb, qualIdx).consume(cb, writeC('.'), { qual => - writeS(Code.invokeScalaObject2[String, Double, String](ExportVCF.getClass, "fmtDouble", "%.2f", qual.asDouble.value)) - }) + elt.loadField(cb, qualIdx).consume( + cb, + writeC('.'), + qual => + writeS(Code.invokeScalaObject2[String, Double, String]( + ExportVCF.getClass, + "fmtDouble", + "%.2f", + qual.asDouble.value, + )), + ) else writeC('.') // FILTER writeC('\t') if (filtersExists) - elt.loadField(cb, filtersIdx).consume(cb, writeC('.'), { case filters: SIndexableValue => - cb.if_(filters.loadLength().ceq(0), writeB(passUTF8Value), { - writeIterable(cb, filters, ';') - }) - }) + elt.loadField(cb, filtersIdx).consume( + cb, + writeC('.'), + { case filters: SIndexableValue => + cb.if_( + filters.loadLength().ceq(0), + writeB(passUTF8Value), + writeIterable(cb, filters, ';'), + ) + }, + ) else writeC('.') @@ -777,35 +1168,48 @@ case class VCFPartitionWriter(typ: MatrixType, 
entriesFieldName: String, writeHe if (infoExists) { val wroteInfo = cb.newLocal[Boolean]("wroteInfo", false) - elt.loadField(cb, infoIdx).consume(cb, { /* do nothing */ }, { case info: SBaseStructValue => - var idx = 0 - while (idx < info.st.size) { - val field = info.st.virtualType.fields(idx) - info.loadField(cb, idx).consume(cb, { /* do nothing */ }, { - case infoArray: SIndexableValue if infoArray.st.elementType.virtualType != TBoolean => - cb.if_(infoArray.loadLength() > 0, { - cb.if_(wroteInfo, writeC(';')) - writeS(field.name) - writeC('=') - writeIterable(cb, infoArray, ',') - cb.assign(wroteInfo, true) - }) - case infoFlag: SBooleanValue => - cb.if_(infoFlag.value, { - cb.if_(wroteInfo, writeC(';')) - writeS(field.name) - cb.assign(wroteInfo, true) - }) - case info => - cb.if_(wroteInfo, writeC(';')) - writeS(field.name) - writeC('=') - writeValue(cb, info) - cb.assign(wroteInfo, true) - }) - idx += 1 - } - }) + elt.loadField(cb, infoIdx).consume( + cb, + { /* do nothing */ }, + { case info: SBaseStructValue => + var idx = 0 + while (idx < info.st.size) { + val field = info.st.virtualType.fields(idx) + info.loadField(cb, idx).consume( + cb, + { /* do nothing */ }, + { + case infoArray: SIndexableValue + if infoArray.st.elementType.virtualType != TBoolean => + cb.if_( + infoArray.loadLength() > 0, { + cb.if_(wroteInfo, writeC(';')) + writeS(field.name) + writeC('=') + writeIterable(cb, infoArray, ',') + cb.assign(wroteInfo, true) + }, + ) + case infoFlag: SBooleanValue => + cb.if_( + infoFlag.value, { + cb.if_(wroteInfo, writeC(';')) + writeS(field.name) + cb.assign(wroteInfo, true) + }, + ) + case info => + cb.if_(wroteInfo, writeC(';')) + writeS(field.name) + writeC('=') + writeValue(cb, info) + cb.assign(wroteInfo, true) + }, + ) + idx += 1 + } + }, + ) cb.if_(!wroteInfo, writeC('.')) } else { @@ -813,62 +1217,103 @@ case class VCFPartitionWriter(typ: MatrixType, entriesFieldName: String, writeHe } // FORMAT - val genotypes = elt.loadField(cb, entriesFieldName).get(cb).asIndexable - cb.if_(genotypes.loadLength() > 0, { - writeC('\t') - writeB(formatFieldUTF8) - genotypes.forEachDefinedOrMissing(cb)({ (cb, _) => - _writeC(cb, '\t') - _writeB(cb, missingFormatUTF8Value) - }, { case (cb, _, gt: SBaseStructValue) => - _writeC(cb, '\t') - writeGenotype(cb, gt) - }) - }) + val genotypes = elt.loadField(cb, entriesFieldName).getOrAssert(cb).asIndexable + cb.if_( + genotypes.loadLength() > 0, { + writeC('\t') + writeB(formatFieldUTF8) + genotypes.forEachDefinedOrMissing(cb)( + { (cb, _) => + _writeC(cb, '\t') + _writeB(cb, missingFormatUTF8Value) + }, + { case (cb, _, gt: SBaseStructValue) => + _writeC(cb, '\t') + writeGenotype(cb, gt) + }, + ) + }, + ) writeC('\n') } } -case class VCFExportFinalizer(typ: MatrixType, outputPath: String, append: Option[String], - metadata: Option[VCFMetadata], exportType: String, tabix: Boolean) extends MetadataWriter { +case class VCFExportFinalizer( + typ: MatrixType, + outputPath: String, + append: Option[String], + metadata: Option[VCFMetadata], + exportType: String, + tabix: Boolean, +) extends MetadataWriter { def annotationType: Type = TStruct("cols" -> TArray(typ.colType), "partFiles" -> TArray(TString)) + private def header(cb: EmitCodeBuilder, annotations: SBaseStructValue): Code[String] = { val mb = cb.emb - val sampleIds = annotations.loadField(cb, "cols").get(cb).asIndexable + val sampleIds = annotations.loadField(cb, "cols").getOrAssert(cb).asIndexable val stringSampleIds = cb.memoize(Code.newArray[String](sampleIds.loadLength())) 
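      // The loop below copies the first col-key field of every column into
      // stringSampleIds; ExportVCF.makeHeader then uses these strings as the
      // sample names in the generated VCF header line.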
sampleIds.forEachDefined(cb) { case (cb, i, colv: SBaseStructValue) => - val s = colv.subset(typ.colKey: _*).loadField(cb, 0).get(cb).asString + val s = colv.subset(typ.colKey: _*).loadField(cb, 0).getOrAssert(cb).asString cb += (stringSampleIds(i) = s.loadString(cb)) } - Code.invokeScalaObject6[TStruct, TStruct, ReferenceGenome, Option[String], Option[VCFMetadata], Array[String], String]( - ExportVCF.getClass, "makeHeader", - mb.getType[TStruct](typ.rowType), mb.getType[TStruct](typ.entryType), - mb.getReferenceGenome(typ.referenceGenomeName), mb.getObject(append), - mb.getObject(metadata), stringSampleIds) + Code.invokeScalaObject6[ + TStruct, + TStruct, + ReferenceGenome, + Option[String], + Option[VCFMetadata], + Array[String], + String, + ]( + ExportVCF.getClass, + "makeHeader", + mb.getType[TStruct](typ.rowType), + mb.getType[TStruct](typ.entryType), + mb.getReferenceGenome(typ.referenceGenomeName), + mb.getObject(append), + mb.getObject(metadata), + stringSampleIds, + ) } - def writeMetadata(writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, region: Value[Region]): Unit = { + def writeMetadata(writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, region: Value[Region]) + : Unit = { val ctx: ExecuteContext = cb.emb.ctx val ext = ctx.fs.getCodecExtension(outputPath) - val annotations = writeAnnotations.get(cb).asBaseStruct + val annotations = writeAnnotations.getOrAssert(cb).asBaseStruct - val partPaths = annotations.loadField(cb, "partFiles").get(cb).asIndexable - val partFiles = partPaths.castTo(cb, region, SJavaArrayString(true), false).asInstanceOf[SJavaArrayStringValue].array - cb.if_(partPaths.hasMissingValues(cb), cb._fatal("matrixwriter part paths contains missing values")) + val partPaths = annotations.loadField(cb, "partFiles").getOrAssert(cb).asIndexable + val partFiles = partPaths.castTo(cb, region, SJavaArrayString(true), false).asInstanceOf[ + SJavaArrayStringValue + ].array + cb.if_( + partPaths.hasMissingValues(cb), + cb._fatal("matrixwriter part paths contains missing values"), + ) val allFiles = if (tabix && exportType != ExportType.CONCATENATED) { val len = partPaths.loadLength() val files = cb.memoize(Code.newArray[String](len * 2)) val i = cb.newLocal[Int]("i", 0) - cb.while_(i < len, { - val path = cb.memoize(partFiles(i)) - cb += files.update(i, path) - // FIXME(chrisvittal): this will put the string ".tbi" in generated code, we should just access the htsjdk value - cb += files.update(cb.memoize(i + len), Code.invokeStatic2[htsjdk.tribble.util.ParsingUtils, String, String, String]("appendToPath", path, htsjdk.samtools.util.FileExtensions.TABIX_INDEX)) - cb.assign(i, i+1) - }) + cb.while_( + i < len, { + val path = cb.memoize(partFiles(i)) + cb += files.update(i, path) + /* FIXME(chrisvittal): this will put the string ".tbi" in generated code, we should just + * access the htsjdk value */ + cb += files.update( + cb.memoize(i + len), + Code.invokeStatic2[htsjdk.tribble.util.ParsingUtils, String, String, String]( + "appendToPath", + path, + htsjdk.samtools.util.FileExtensions.TABIX_INDEX, + ), + ) + cb.assign(i, i + 1) + }, + ) files } else { partFiles @@ -886,26 +1331,63 @@ case class VCFExportFinalizer(typ: MatrixType, outputPath: String, append: Optio val jFiles = cb.memoize(Code.newArray[String](partFiles.length + 1)) cb += (jFiles(0) = const(headerFilePath)) cb += Code.invokeStatic5[System, Any, Int, Any, Int, Int, Unit]( - "arraycopy", partFiles /*src*/, 0 /*srcPos*/, jFiles /*dest*/, 1 /*destPos*/, partFiles.length /*len*/) + "arraycopy", + partFiles 
/*src*/, + 0 /*srcPos*/, + jFiles /*dest*/, + 1 /*destPos*/, + partFiles.length, /*len*/ + ) - cb += cb.emb.getFS.invoke[Array[String], String, Unit]("concatenateFiles", jFiles, const(outputPath)) + cb += cb.emb.getFS.invoke[Array[String], String, Unit]( + "concatenateFiles", + jFiles, + const(outputPath), + ) val i = cb.newLocal[Int]("i") - cb.for_(cb.assign(i, 0), i < jFiles.length, cb.assign(i, i + 1), { - cb += cb.emb.getFS.invoke[String, Boolean, Unit]("delete", jFiles(i), const(false)) - }) + cb.for_( + cb.assign(i, 0), + i < jFiles.length, + cb.assign(i, i + 1), + cb += cb.emb.getFS.invoke[String, Boolean, Unit]("delete", jFiles(i), const(false)), + ) if (tabix) { - cb += Code.invokeScalaObject2[FS, String, Unit](TabixVCF.getClass, "apply", cb.emb.getFS, const(outputPath)) + cb += Code.invokeScalaObject2[FS, String, Unit]( + TabixVCF.getClass, + "apply", + cb.emb.getFS, + const(outputPath), + ) } case ExportType.PARALLEL_HEADER_IN_SHARD => - cb += Code.invokeScalaObject3[FS, String, Array[String], Unit](TableTextFinalizer.getClass, "cleanup", cb.emb.getFS, outputPath, allFiles) - cb += Code.invokeScalaObject4[FS, String, Array[String], String, Unit](TableTextFinalizer.getClass, "writeManifest", cb.emb.getFS, outputPath, partFiles, Code._null[String]) + cb += Code.invokeScalaObject3[FS, String, Array[String], Unit]( + TableTextFinalizer.getClass, + "cleanup", + cb.emb.getFS, + outputPath, + allFiles, + ) + cb += Code.invokeScalaObject4[FS, String, Array[String], String, Unit]( + TableTextFinalizer.getClass, + "writeManifest", + cb.emb.getFS, + outputPath, + partFiles, + Code._null[String], + ) cb += cb.emb.getFS.invoke[String, Unit]("touch", const(outputPath).concat("/_SUCCESS")) case ExportType.PARALLEL_SEPARATE_HEADER => - cb += Code.invokeScalaObject3[FS, String, Array[String], Unit](TableTextFinalizer.getClass, "cleanup", cb.emb.getFS, outputPath, allFiles) + cb += Code.invokeScalaObject3[FS, String, Array[String], Unit]( + TableTextFinalizer.getClass, + "cleanup", + cb.emb.getFS, + outputPath, + allFiles, + ) val headerFilePath = s"$outputPath/header$ext" val headerStr = header(cb, annotations) @@ -913,7 +1395,14 @@ case class VCFExportFinalizer(typ: MatrixType, outputPath: String, append: Optio cb += os.invoke[Array[Byte], Unit]("write", headerStr.invoke[Array[Byte]]("getBytes")) cb += os.invoke[Int, Unit]("write", '\n') cb += os.invoke[Unit]("close") - cb += Code.invokeScalaObject4[FS, String, Array[String], String, Unit](TableTextFinalizer.getClass, "writeManifest", cb.emb.getFS, outputPath, partFiles, headerFilePath) + cb += Code.invokeScalaObject4[FS, String, Array[String], String, Unit]( + TableTextFinalizer.getClass, + "writeManifest", + cb.emb.getFS, + outputPath, + partFiles, + headerFilePath, + ) cb += cb.emb.getFS.invoke[String, Unit]("touch", const(outputPath).concat("/_SUCCESS")) } @@ -922,11 +1411,17 @@ case class VCFExportFinalizer(typ: MatrixType, outputPath: String, append: Optio case class MatrixGENWriter( path: String, - precision: Int = 4 + precision: Int = 4, ) extends MatrixWriter { - override def lower(colsFieldName: String, entriesFieldName: String, colKey: IndexedSeq[String], - ctx: ExecuteContext, ts: TableStage, r: RTable): IR = { + override def lower( + colsFieldName: String, + entriesFieldName: String, + colKey: IndexedSeq[String], + ctx: ExecuteContext, + ts: TableStage, + r: RTable, + ): IR = { val tm = MatrixType.fromTableType(ts.tableType, colsFieldName, entriesFieldName, colKey) val sampleWriter = new GenSampleWriter @@ -936,17 +1431,22 @@ case 
class MatrixGENWriter( ts.mapContexts { oldCtx => val d = digitsNeeded(ts.numPartitions) - val partFiles = Literal(TArray(TString), Array.tabulate(ts.numPartitions)(i => s"$folder/${ partFile(d, i) }-").toFastSeq) + val partFiles = Literal( + TArray(TString), + Array.tabulate(ts.numPartitions)(i => s"$folder/${partFile(d, i)}-").toFastSeq, + ) zip2(oldCtx, ToStream(partFiles), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => MakeStruct(FastSeq( "oldCtx" -> ctxElt, - "partFile" -> pf)) + "partFile" -> pf, + )) } - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_gen_writer") { (rows, ctxRef) => - val ctx = GetField(ctxRef, "partFile") + UUID4() - WritePartition(rows, ctx, lineWriter) - }{ (parts, globals) => + }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_gen_writer") { + (rows, ctxRef) => + val ctx = GetField(ctxRef, "partFile") + UUID4() + WritePartition(rows, ctx, lineWriter) + } { (parts, globals) => val cols = ToStream(GetField(globals, colsFieldName)) val sampleFileName = Str(s"$path.sample") val writeSamples = WritePartition(cols, sampleFileName, sampleWriter) @@ -958,89 +1458,168 @@ case class MatrixGENWriter( } } -final case class GenVariantWriter(typ: MatrixType, entriesFieldName: String, precision: Int) extends SimplePartitionWriter { - def consumeElement(cb: EmitCodeBuilder, element: EmitCode, os: Value[OutputStream], region: Value[Region]): Unit = { - def _writeC(cb: EmitCodeBuilder, code: Code[Int]) = { cb += os.invoke[Int, Unit]("write", code) } - def _writeB(cb: EmitCodeBuilder, code: Code[Array[Byte]]) = { cb += os.invoke[Array[Byte], Unit]("write", code) } - def _writeS(cb: EmitCodeBuilder, code: Code[String]) = { _writeB(cb, code.invoke[Array[Byte]]("getBytes")) } +final case class GenVariantWriter(typ: MatrixType, entriesFieldName: String, precision: Int) + extends SimplePartitionWriter { + def consumeElement( + cb: EmitCodeBuilder, + element: EmitCode, + os: Value[OutputStream], + region: Value[Region], + ): Unit = { + def _writeC(cb: EmitCodeBuilder, code: Code[Int]) = cb += os.invoke[Int, Unit]("write", code) + def _writeB(cb: EmitCodeBuilder, code: Code[Array[Byte]]) = + cb += os.invoke[Array[Byte], Unit]("write", code) + def _writeS(cb: EmitCodeBuilder, code: Code[String]) = + _writeB(cb, code.invoke[Array[Byte]]("getBytes")) def writeC(code: Code[Int]) = _writeC(cb, code) def writeS(code: Code[String]) = _writeS(cb, code) - require(typ.entryType.hasField("GP") && typ.entryType.fieldType("GP") == TArray(TFloat64)) - element.toI(cb).consume(cb, cb._fatal("stream element cannot be missing!"), { case sv: SBaseStructValue => - val locus = sv.loadField(cb, "locus").get(cb).asLocus - val contig = locus.contig(cb).loadString(cb) - val alleles = sv.loadField(cb, "alleles").get(cb).asIndexable - val rsid = sv.loadField(cb, "rsid").get(cb).asString.loadString(cb) - val varid = sv.loadField(cb, "varid").get(cb).asString.loadString(cb) - val a0 = alleles.loadElement(cb, 0).get(cb).asString.loadString(cb) - val a1 = alleles.loadElement(cb, 1).get(cb).asString.loadString(cb) - - cb += Code.invokeScalaObject6[String, Int, String, String, String, String, Unit](ExportGen.getClass, "checkVariant", contig, locus.position(cb), a0, a1, varid, rsid) - - writeS(contig) - writeC(' ') - writeS(varid) - writeC(' ') - writeS(rsid) - writeC(' ') - writeS(locus.position(cb).toS) - writeC(' ') - writeS(a0) - writeC(' ') - writeS(a1) - - sv.loadField(cb, entriesFieldName).get(cb).asIndexable.forEachDefinedOrMissing(cb)({ (cb, i) => - _writeS(cb, " 0 0 0") 
- }, { (cb, i, va) => - va.asBaseStruct.loadField(cb, "GP").consume(cb, _writeS(cb, " 0 0 0"), { case gp: SIndexableValue => - cb.if_(gp.loadLength().cne(3), - cb._fatal("Invalid 'gp' at variant '", locus.contig(cb).loadString(cb), ":", locus.position(cb).toS, ":", a0, ":", a1, "' and sample index ", i.toS, ". The array must have length equal to 3.")) - gp.forEachDefinedOrMissing(cb)((cb, _) => cb._fatal("GP cannot be missing"), { (cb, _, gp) => - _writeC(cb, ' ') - _writeS(cb, Code.invokeScalaObject2[Double, Int, String](utilsPackageClass, "formatDouble", gp.asDouble.value, precision)) - }) - }) - }) - writeC('\n') - }) + element.toI(cb).consume( + cb, + cb._fatal("stream element cannot be missing!"), + { case sv: SBaseStructValue => + val locus = sv.loadField(cb, "locus").getOrAssert(cb).asLocus + val contig = locus.contig(cb).loadString(cb) + val alleles = sv.loadField(cb, "alleles").getOrAssert(cb).asIndexable + val rsid = sv.loadField(cb, "rsid").getOrAssert(cb).asString.loadString(cb) + val varid = sv.loadField(cb, "varid").getOrAssert(cb).asString.loadString(cb) + val a0 = alleles.loadElement(cb, 0).getOrAssert(cb).asString.loadString(cb) + val a1 = alleles.loadElement(cb, 1).getOrAssert(cb).asString.loadString(cb) + + cb += Code.invokeScalaObject6[String, Int, String, String, String, String, Unit]( + ExportGen.getClass, + "checkVariant", + contig, + locus.position(cb), + a0, + a1, + varid, + rsid, + ) + + writeS(contig) + writeC(' ') + writeS(varid) + writeC(' ') + writeS(rsid) + writeC(' ') + writeS(locus.position(cb).toS) + writeC(' ') + writeS(a0) + writeC(' ') + writeS(a1) + + sv.loadField(cb, entriesFieldName).getOrAssert(cb).asIndexable.forEachDefinedOrMissing(cb)( + (cb, i) => _writeS(cb, " 0 0 0"), + { (cb, i, va) => + va.asBaseStruct.loadField(cb, "GP").consume( + cb, + _writeS(cb, " 0 0 0"), + { case gp: SIndexableValue => + cb.if_( + gp.loadLength().cne(3), + cb._fatal( + "Invalid 'gp' at variant '", + locus.contig(cb).loadString(cb), + ":", + locus.position(cb).toS, + ":", + a0, + ":", + a1, + "' and sample index ", + i.toS, + ". 
The array must have length equal to 3.", + ), + ) + gp.forEachDefinedOrMissing(cb)( + (cb, _) => cb._fatal("GP cannot be missing"), + { (cb, _, gp) => + _writeC(cb, ' ') + _writeS( + cb, + Code.invokeScalaObject2[Double, Int, String]( + utilsPackageClass, + "formatDouble", + gp.asDouble.value, + precision, + ), + ) + }, + ) + }, + ) + }, + ) + writeC('\n') + }, + ) } } final class GenSampleWriter extends SimplePartitionWriter { - def consumeElement(cb: EmitCodeBuilder, element: EmitCode, os: Value[OutputStream], region: Value[Region]): Unit = { - element.toI(cb).consume(cb, cb._fatal("stream element cannot be missing!"), { case sv: SBaseStructValue => - val id1 = sv.loadField(cb, 0).get(cb).asString.loadString(cb) - val id2 = sv.loadField(cb, 1).get(cb).asString.loadString(cb) - val missing = sv.loadField(cb, 2).get(cb).asDouble.value - - cb += Code.invokeScalaObject3[String, String, Double, Unit](ExportGen.getClass, "checkSample", id1, id2, missing) - - cb += os.invoke[Array[Byte], Unit]("write", id1.invoke[Array[Byte]]("getBytes")) - cb += os.invoke[Int, Unit]("write", ' ') - cb += os.invoke[Array[Byte], Unit]("write", id2.invoke[Array[Byte]]("getBytes")) - cb += os.invoke[Int, Unit]("write", ' ') - cb += os.invoke[Array[Byte], Unit]("write", missing.toS.invoke[Array[Byte]]("getBytes")) - cb += os.invoke[Int, Unit]("write", '\n') - }) + def consumeElement( + cb: EmitCodeBuilder, + element: EmitCode, + os: Value[OutputStream], + region: Value[Region], + ): Unit = { + element.toI(cb).consume( + cb, + cb._fatal("stream element cannot be missing!"), + { case sv: SBaseStructValue => + val id1 = sv.loadField(cb, 0).getOrAssert(cb).asString.loadString(cb) + val id2 = sv.loadField(cb, 1).getOrAssert(cb).asString.loadString(cb) + val missing = sv.loadField(cb, 2).getOrAssert(cb).asDouble.value + + cb += Code.invokeScalaObject3[String, String, Double, Unit]( + ExportGen.getClass, + "checkSample", + id1, + id2, + missing, + ) + + cb += os.invoke[Array[Byte], Unit]("write", id1.invoke[Array[Byte]]("getBytes")) + cb += os.invoke[Int, Unit]("write", ' ') + cb += os.invoke[Array[Byte], Unit]("write", id2.invoke[Array[Byte]]("getBytes")) + cb += os.invoke[Int, Unit]("write", ' ') + cb += os.invoke[Array[Byte], Unit]("write", missing.toS.invoke[Array[Byte]]("getBytes")) + cb += os.invoke[Int, Unit]("write", '\n') + }, + ) } override def preConsume(cb: EmitCodeBuilder, os: Value[OutputStream]): Unit = - cb += os.invoke[Array[Byte], Unit]("write", const("ID_1 ID_2 ID_3\n0 0 0\n").invoke[Array[Byte]]("getBytes")) + cb += os.invoke[Array[Byte], Unit]( + "write", + const("ID_1 ID_2 ID_3\n0 0 0\n").invoke[Array[Byte]]("getBytes"), + ) } case class MatrixBGENWriter( path: String, exportType: String, - compressionCodec: String + compressionCodec: String, ) extends MatrixWriter { - override def lower(colsFieldName: String, entriesFieldName: String, colKey: IndexedSeq[String], - ctx: ExecuteContext, ts: TableStage, r: RTable): IR = { + override def lower( + colsFieldName: String, + entriesFieldName: String, + colKey: IndexedSeq[String], + ctx: ExecuteContext, + ts: TableStage, + r: RTable, + ): IR = { - val tm = MatrixType.fromTableType(TableType(ts.rowType, ts.key, ts.globalType), colsFieldName, entriesFieldName, colKey) + val tm = MatrixType.fromTableType( + TableType(ts.rowType, ts.key, ts.globalType), + colsFieldName, + entriesFieldName, + colKey, + ) val folder = if (exportType == ExportType.CONCATENATED) ctx.createTmpPath("export-bgen-concatenated") else @@ -1056,58 +1635,97 @@ case class MatrixBGENWriter( 
ts.mapContexts { oldCtx => val d = digitsNeeded(ts.numPartitions) - val partFiles = ToStream(Literal(TArray(TString), Array.tabulate(ts.numPartitions)(i => s"$folder/${ partFile(d, i) }-").toFastSeq)) - val numVariants = if (writeHeader) ToStream(ts.countPerPartition()) else ToStream(MakeArray(Array.tabulate(ts.numPartitions)(_ => NA(TInt64)): _*)) + val partFiles = ToStream(Literal( + TArray(TString), + Array.tabulate(ts.numPartitions)(i => s"$folder/${partFile(d, i)}-").toFastSeq, + )) + val numVariants = if (writeHeader) ToStream(ts.countPerPartition()) + else ToStream(MakeArray(Array.tabulate(ts.numPartitions)(_ => NA(TInt64)): _*)) val ctxElt = Ref(genUID(), tcoerce[TStream](oldCtx.typ).elementType) val pf = Ref(genUID(), tcoerce[TStream](partFiles.typ).elementType) val nv = Ref(genUID(), tcoerce[TStream](numVariants.typ).elementType) - StreamZip(FastSeq(oldCtx, partFiles, numVariants), FastSeq(ctxElt.name, pf.name, nv.name), + StreamZip( + FastSeq(oldCtx, partFiles, numVariants), + FastSeq(ctxElt.name, pf.name, nv.name), MakeStruct(FastSeq("oldCtx" -> ctxElt, "numVariants" -> nv, "partFile" -> pf)), - ArrayZipBehavior.AssertSameLength) - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_vcf_writer") { (rows, ctxRef) => - val partFile = GetField(ctxRef, "partFile") + UUID4() - val ctx = MakeStruct(FastSeq( - "cols" -> GetField(ts.globals, colsFieldName), - "numVariants" -> GetField(ctxRef, "numVariants"), - "partFile" -> partFile)) - WritePartition(rows, ctx, partWriter) - }{ (results, globals) => - val ctx = MakeStruct(FastSeq("cols" -> GetField(globals, colsFieldName), "results" -> results)) + ArrayZipBehavior.AssertSameLength, + ) + }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_vcf_writer") { + (rows, ctxRef) => + val partFile = GetField(ctxRef, "partFile") + UUID4() + val ctx = MakeStruct(FastSeq( + "cols" -> GetField(ts.globals, colsFieldName), + "numVariants" -> GetField(ctxRef, "numVariants"), + "partFile" -> partFile, + )) + WritePartition(rows, ctx, partWriter) + } { (results, globals) => + val ctx = + MakeStruct(FastSeq("cols" -> GetField(globals, colsFieldName), "results" -> results)) val commit = BGENExportFinalizer(tm, path, exportType, compressionInt) Begin(FastSeq(WriteMetadata(ctx, commit))) } } } -case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeHeader: Boolean, compression: Int) extends PartitionWriter { +case class BGENPartitionWriter( + typ: MatrixType, + entriesFieldName: String, + writeHeader: Boolean, + compression: Int, +) extends PartitionWriter { require(typ.entryType.hasField("GP") && typ.entryType.fieldType("GP") == TArray(TFloat64)) - val ctxType: Type = TStruct("cols" -> TArray(typ.colType), "numVariants" -> TInt64, "partFile" -> TString) - override def returnType: TStruct = TStruct("partFile" -> TString, "numVariants" -> TInt64, "dropped" -> TInt64) - def unionTypeRequiredness(r: TypeWithRequiredness, ctxType: TypeWithRequiredness, streamType: RIterable): Unit = { + + val ctxType: Type = + TStruct("cols" -> TArray(typ.colType), "numVariants" -> TInt64, "partFile" -> TString) + + override def returnType: TStruct = + TStruct("partFile" -> TString, "numVariants" -> TInt64, "dropped" -> TInt64) + + def unionTypeRequiredness( + r: TypeWithRequiredness, + ctxType: TypeWithRequiredness, + streamType: RIterable, + ): Unit = { r.union(ctxType.required) r.union(streamType.required) } - final def consumeStream(ctx: ExecuteContext, cb: EmitCodeBuilder, stream: StreamProducer, - context: EmitCode, 
region: Value[Region]): IEmitCode = { + final def consumeStream( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + stream: StreamProducer, + context: EmitCode, + region: Value[Region], + ): IEmitCode = { context.toI(cb).map(cb) { case ctx: SBaseStructValue => - val filename = ctx.loadField(cb, "partFile").get(cb, "partFile can't be missing").asString.loadString(cb) + val filename = + ctx.loadField(cb, "partFile").getOrFatal( + cb, + "partFile can't be missing", + ).asString.loadString(cb) val os = cb.memoize(cb.emb.create(filename)) - val colValues = ctx.loadField(cb, "cols").get(cb).asIndexable + val colValues = ctx.loadField(cb, "cols").getOrAssert(cb).asIndexable val nSamples = colValues.loadLength() if (writeHeader) { val sampleIds = cb.memoize(Code.newArray[String](colValues.loadLength())) colValues.forEachDefined(cb) { case (cb, i, colv: SBaseStructValue) => - val s = colv.subset(typ.colKey: _*).loadField(cb, 0).get(cb).asString + val s = colv.subset(typ.colKey: _*).loadField(cb, 0).getOrAssert(cb).asString cb += (sampleIds(i) = s.loadString(cb)) } - val numVariants = ctx.loadField(cb, "numVariants").get(cb).asInt64.value - val header = Code.invokeScalaObject3[Array[String], Long, Int, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants, compression) + val numVariants = ctx.loadField(cb, "numVariants").getOrAssert(cb).asInt64.value + val header = Code.invokeScalaObject3[Array[String], Long, Int, Array[Byte]]( + BgenWriter.getClass, + "headerBlock", + sampleIds, + numVariants, + compression, + ) cb += os.invoke[Array[Byte], Unit]("write", header) } @@ -1115,47 +1733,116 @@ case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeH val buf = cb.memoize(Code.newInstance[ByteArrayBuilder, Int](16)) val uncompBuf = cb.memoize(Code.newInstance[ByteArrayBuilder, Int](16)) - val slowCount = if (writeHeader || stream.length.isDefined) None else Some(cb.newLocal[Long]("num_variants", 0)) - val fastCount = if (writeHeader) Some(ctx.loadField(cb, "numVariants").get(cb).asInt64.value) else stream.length.map(len => cb.memoize(len(cb).toL)) + val slowCount = if (writeHeader || stream.length.isDefined) None + else Some(cb.newLocal[Long]("num_variants", 0)) + val fastCount = if (writeHeader) + Some(ctx.loadField(cb, "numVariants").getOrAssert(cb).asInt64.value) + else stream.length.map(len => cb.memoize(len(cb).toL)) stream.memoryManagedConsume(region, cb) { cb => slowCount.foreach(nv => cb.assign(nv, nv + 1L)) - consumeElement(cb, stream.element, buf, uncompBuf, os, stream.elementRegion, dropped, nSamples) + consumeElement( + cb, + stream.element, + buf, + uncompBuf, + os, + stream.elementRegion, + dropped, + nSamples, + ) } cb += os.invoke[Unit]("flush") cb += os.invoke[Unit]("close") val numVariants = fastCount.getOrElse(slowCount.get) - SStackStruct.constructFromArgs(cb, region, returnType, + SStackStruct.constructFromArgs( + cb, + region, + returnType, EmitCode.present(cb.emb, SJavaString.construct(cb, filename)), EmitCode.present(cb.emb, new SInt64Value(numVariants)), - EmitCode.present(cb.emb, new SInt64Value(dropped))) + EmitCode.present(cb.emb, new SInt64Value(dropped)), + ) } } - private def consumeElement(cb: EmitCodeBuilder, element: EmitCode, buf: Value[ByteArrayBuilder], uncompBuf: Value[ByteArrayBuilder], - os: Value[OutputStream], region: Value[Region], dropped: Settable[Long], nSamples: Value[Int]): Unit = { + private def consumeElement( + cb: EmitCodeBuilder, + element: EmitCode, + buf: Value[ByteArrayBuilder], + uncompBuf: 
Value[ByteArrayBuilder], + os: Value[OutputStream], + region: Value[Region], + dropped: Settable[Long], + nSamples: Value[Int], + ): Unit = { - def stringToBytesWithShortLength(cb: EmitCodeBuilder, bb: Value[ByteArrayBuilder], str: Value[String]) = - cb += Code.toUnit(Code.invokeScalaObject2[ByteArrayBuilder, String, Int](BgenWriter.getClass, "stringToBytesWithShortLength", bb, str)) - def stringToBytesWithIntLength(cb: EmitCodeBuilder, bb: Value[ByteArrayBuilder], str: Value[String]) = - cb += Code.toUnit(Code.invokeScalaObject2[ByteArrayBuilder, String, Int](BgenWriter.getClass, "stringToBytesWithIntLength", bb, str)) - def intToBytesLE(cb: EmitCodeBuilder, bb: Value[ByteArrayBuilder], i: Value[Int]) = cb += Code.invokeScalaObject2[ByteArrayBuilder, Int, Unit](BgenWriter.getClass, "intToBytesLE", bb, i) - def shortToBytesLE(cb: EmitCodeBuilder, bb: Value[ByteArrayBuilder], i: Value[Int]) = cb += Code.invokeScalaObject2[ByteArrayBuilder, Int, Unit](BgenWriter.getClass, "shortToBytesLE", bb, i) - def updateIntToBytesLE(cb: EmitCodeBuilder, bb: Value[ByteArrayBuilder], i: Value[Int], pos: Value[Int]) = - cb += Code.invokeScalaObject3[ByteArrayBuilder, Int, Int, Unit](BgenWriter.getClass, "updateIntToBytesLE", bb, i, pos) + def stringToBytesWithShortLength( + cb: EmitCodeBuilder, + bb: Value[ByteArrayBuilder], + str: Value[String], + ) = + cb += Code.toUnit(Code.invokeScalaObject2[ByteArrayBuilder, String, Int]( + BgenWriter.getClass, + "stringToBytesWithShortLength", + bb, + str, + )) + def stringToBytesWithIntLength( + cb: EmitCodeBuilder, + bb: Value[ByteArrayBuilder], + str: Value[String], + ) = + cb += Code.toUnit(Code.invokeScalaObject2[ByteArrayBuilder, String, Int]( + BgenWriter.getClass, + "stringToBytesWithIntLength", + bb, + str, + )) + def intToBytesLE(cb: EmitCodeBuilder, bb: Value[ByteArrayBuilder], i: Value[Int]) = + cb += Code.invokeScalaObject2[ByteArrayBuilder, Int, Unit]( + BgenWriter.getClass, + "intToBytesLE", + bb, + i, + ) + def shortToBytesLE(cb: EmitCodeBuilder, bb: Value[ByteArrayBuilder], i: Value[Int]) = + cb += Code.invokeScalaObject2[ByteArrayBuilder, Int, Unit]( + BgenWriter.getClass, + "shortToBytesLE", + bb, + i, + ) + def updateIntToBytesLE( + cb: EmitCodeBuilder, + bb: Value[ByteArrayBuilder], + i: Value[Int], + pos: Value[Int], + ) = + cb += Code.invokeScalaObject3[ByteArrayBuilder, Int, Int, Unit]( + BgenWriter.getClass, + "updateIntToBytesLE", + bb, + i, + pos, + ) - def add(cb: EmitCodeBuilder, bb: Value[ByteArrayBuilder], i: Value[Int]) = cb += bb.invoke[Byte, Unit]("add", i.toB) + def add(cb: EmitCodeBuilder, bb: Value[ByteArrayBuilder], i: Value[Int]) = + cb += bb.invoke[Byte, Unit]("add", i.toB) - val elt = element.toI(cb).get(cb).asBaseStruct - val locus = elt.loadField(cb, "locus").get(cb).asLocus + val elt = element.toI(cb).getOrAssert(cb).asBaseStruct + val locus = elt.loadField(cb, "locus").getOrAssert(cb).asLocus val chr = locus.contig(cb).loadString(cb) val pos = locus.position(cb) - val varid = elt.loadField(cb, "varid").get(cb).asString.loadString(cb) - val rsid = elt.loadField(cb, "rsid").get(cb).asString.loadString(cb) - val alleles = elt.loadField(cb, "alleles").get(cb).asIndexable + val varid = elt.loadField(cb, "varid").getOrAssert(cb).asString.loadString(cb) + val rsid = elt.loadField(cb, "rsid").getOrAssert(cb).asString.loadString(cb) + val alleles = elt.loadField(cb, "alleles").getOrAssert(cb).asIndexable - cb.if_(alleles.loadLength() >= 0xffff, cb._fatal("Maximum number of alleles per variant is 65536. 
Found ", alleles.loadLength().toS)) + cb.if_( + alleles.loadLength() >= 0xffff, + cb._fatal("Maximum number of alleles per variant is 65536. Found ", alleles.loadLength().toS), + ) cb += buf.invoke[Unit]("clear") cb += uncompBuf.invoke[Unit]("clear") @@ -1186,38 +1873,80 @@ case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeH val samplePloidyStart = cb.memoize(uncompBuf.invoke[Int]("size")) val i = cb.newLocal[Int]("i") - cb.for_(cb.assign(i, 0), i < nSamples, cb.assign(i, i + 1), { - add(cb, uncompBuf, 0x82) // placeholder for sample ploidy - default is missing - }) + cb.for_( + cb.assign(i, 0), + i < nSamples, + cb.assign(i, i + 1), + add(cb, uncompBuf, 0x82), // placeholder for sample ploidy - default is missing + ) add(cb, uncompBuf, BgenWriter.phased) add(cb, uncompBuf, 8) - def emitNullGP(cb: EmitCodeBuilder): Unit = cb.for_(cb.assign(i, 0), i < nGenotypes - 1, cb.assign(i, i + 1), add(cb, uncompBuf, 0)) - - val entries = elt.loadField(cb, entriesFieldName).get(cb).asIndexable - entries.forEachDefinedOrMissing(cb)({ (cb, j) => - emitNullGP(cb) - }, { case (cb, j, entry: SBaseStructValue) => - entry.loadField(cb, "GP").consume(cb, emitNullGP(cb), { gp => - val gpSum = cb.newLocal[Double]("gpSum", 0d) - gp.asIndexable.forEachDefined(cb) { (cb, idx, x) => - val gpv = x.asDouble.value - cb.if_(gpv < 0d, - cb._fatal("found GP value less than 0: ", gpv.toS, ", at sample ", j.toS, " of variant", chr, ":", pos.toS)) - cb.assign(gpSum, gpSum + gpv) - cb += (gpResized(idx) = gpv * BgenWriter.totalProb.toDouble) - } - cb.if_(gpSum >= 0.999 && gpSum <= 1.001, { - cb += uncompBuf.invoke[Int, Byte, Unit]("update", samplePloidyStart + j, BgenWriter.ploidy) - cb += Code.invokeScalaObject6[Array[Double], Array[Double], Array[Int], Array[Int], ByteArrayBuilder, Long, Unit](BgenWriter.getClass, "roundWithConstantSum", - gpResized, fractional, index, indexInverse, uncompBuf, BgenWriter.totalProb.toLong) - }, { - cb.assign(dropped, dropped + 1l) - emitNullGP(cb) - }) - }) - }) + def emitNullGP(cb: EmitCodeBuilder): Unit = + cb.for_(cb.assign(i, 0), i < nGenotypes - 1, cb.assign(i, i + 1), add(cb, uncompBuf, 0)) + + val entries = elt.loadField(cb, entriesFieldName).getOrAssert(cb).asIndexable + entries.forEachDefinedOrMissing(cb)( + (cb, j) => emitNullGP(cb), + { case (cb, j, entry: SBaseStructValue) => + entry.loadField(cb, "GP").consume( + cb, + emitNullGP(cb), + { gp => + val gpSum = cb.newLocal[Double]("gpSum", 0d) + gp.asIndexable.forEachDefined(cb) { (cb, idx, x) => + val gpv = x.asDouble.value + cb.if_( + gpv < 0d, + cb._fatal( + "found GP value less than 0: ", + gpv.toS, + ", at sample ", + j.toS, + " of variant", + chr, + ":", + pos.toS, + ), + ) + cb.assign(gpSum, gpSum + gpv) + cb += (gpResized(idx) = gpv * BgenWriter.totalProb.toDouble) + } + cb.if_( + gpSum >= 0.999 && gpSum <= 1.001, { + cb += uncompBuf.invoke[Int, Byte, Unit]( + "update", + samplePloidyStart + j, + BgenWriter.ploidy, + ) + cb += Code.invokeScalaObject6[ + Array[Double], + Array[Double], + Array[Int], + Array[Int], + ByteArrayBuilder, + Long, + Unit, + ]( + BgenWriter.getClass, + "roundWithConstantSum", + gpResized, + fractional, + index, + indexInverse, + uncompBuf, + BgenWriter.totalProb.toLong, + ) + }, { + cb.assign(dropped, dropped + 1L) + emitNullGP(cb) + }, + ) + }, + ) + }, + ) // end emitGPData val uncompLen = cb.memoize(uncompBuf.invoke[Int]("size")) @@ -1226,7 +1955,12 @@ case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeH case 1 => "compressZlib" case 2 => 
"compressZstd" } - val compLen = cb.memoize(Code.invokeScalaObject2[ByteArrayBuilder, Array[Byte], Int](CompressionUtils.getClass, compMethod, buf, uncompBuf.invoke[Array[Byte]]("result"))) + val compLen = cb.memoize(Code.invokeScalaObject2[ByteArrayBuilder, Array[Byte], Int]( + CompressionUtils.getClass, + compMethod, + buf, + uncompBuf.invoke[Array[Byte]]("result"), + )) updateIntToBytesLE(cb, buf, cb.memoize(compLen + 4), gtDataBlockStart) updateIntToBytesLE(cb, buf, uncompLen, cb.memoize(gtDataBlockStart + 4)) @@ -1235,74 +1969,140 @@ case class BGENPartitionWriter(typ: MatrixType, entriesFieldName: String, writeH } } -case class BGENExportFinalizer(typ: MatrixType, path: String, exportType: String, compression: Int) extends MetadataWriter { - def annotationType: Type = TStruct("cols" -> TArray(typ.colType), "results" -> TArray(TStruct("partFile" -> TString, "numVariants" -> TInt64, "dropped" -> TInt64))) - - def writeMetadata(writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, region: Value[Region]): Unit = { - val annotations = writeAnnotations.get(cb).asBaseStruct - val colValues = annotations.loadField(cb, "cols").get(cb).asIndexable +case class BGENExportFinalizer(typ: MatrixType, path: String, exportType: String, compression: Int) + extends MetadataWriter { + def annotationType: Type = TStruct( + "cols" -> TArray(typ.colType), + "results" -> TArray(TStruct( + "partFile" -> TString, + "numVariants" -> TInt64, + "dropped" -> TInt64, + )), + ) + + def writeMetadata(writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, region: Value[Region]) + : Unit = { + val annotations = writeAnnotations.getOrAssert(cb).asBaseStruct + val colValues = annotations.loadField(cb, "cols").getOrAssert(cb).asIndexable val sampleIds = cb.memoize(Code.newArray[String](colValues.loadLength())) colValues.forEachDefined(cb) { case (cb, i, colv: SBaseStructValue) => - val s = colv.subset(typ.colKey: _*).loadField(cb, 0).get(cb).asString + val s = colv.subset(typ.colKey: _*).loadField(cb, 0).getOrAssert(cb).asString cb += (sampleIds(i) = s.loadString(cb)) } - val results = annotations.loadField(cb, "results").get(cb).asIndexable + val results = annotations.loadField(cb, "results").getOrAssert(cb).asIndexable val dropped = cb.newLocal[Long]("dropped", 0L) results.forEachDefined(cb) { (cb, i, res) => - res.asBaseStruct.loadField(cb, "dropped").consume(cb, {/* do nothing */}, { d => - cb.assign(dropped, dropped + d.asInt64.value) - }) + res.asBaseStruct.loadField(cb, "dropped").consume( + cb, + { /* do nothing */ }, + d => cb.assign(dropped, dropped + d.asInt64.value), + ) } - cb.if_(dropped.cne(0L), cb.warning("Set ", dropped.toS, " genotypes to missing: total GP probability did not lie in [0.999, 1.001].")) + cb.if_( + dropped.cne(0L), + cb.warning( + "Set ", + dropped.toS, + " genotypes to missing: total GP probability did not lie in [0.999, 1.001].", + ), + ) val numVariants = cb.newLocal[Long]("num_variants", 0L) if (exportType != ExportType.PARALLEL_HEADER_IN_SHARD) { results.forEachDefined(cb) { (cb, i, res) => - res.asBaseStruct.loadField(cb, "numVariants").consume(cb, {/* do nothing */}, { nv => - cb.assign(numVariants, numVariants + nv.asInt64.value) - }) + res.asBaseStruct.loadField(cb, "numVariants").consume( + cb, + { /* do nothing */ }, + nv => cb.assign(numVariants, numVariants + nv.asInt64.value), + ) } } - if (exportType == ExportType.PARALLEL_SEPARATE_HEADER || exportType == ExportType.PARALLEL_HEADER_IN_SHARD) { + if ( + exportType == ExportType.PARALLEL_SEPARATE_HEADER || exportType == 
ExportType.PARALLEL_HEADER_IN_SHARD + ) { val files = cb.memoize(Code.newArray[String](results.loadLength())) results.forEachDefined(cb) { (cb, i, res) => - cb += files.update(i, res.asBaseStruct.loadField(cb, "partFile").get(cb).asString.loadString(cb)) + cb += files.update( + i, + res.asBaseStruct.loadField(cb, "partFile").getOrAssert(cb).asString.loadString(cb), + ) } val headerStr = if (exportType == ExportType.PARALLEL_SEPARATE_HEADER) { val headerStr = cb.memoize(const(path + ".bgen").concat("/header")) val os = cb.memoize(cb.emb.create(headerStr)) - val header = Code.invokeScalaObject3[Array[String], Long, Int, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants, compression) + val header = Code.invokeScalaObject3[Array[String], Long, Int, Array[Byte]]( + BgenWriter.getClass, + "headerBlock", + sampleIds, + numVariants, + compression, + ) cb += os.invoke[Array[Byte], Unit]("write", header) cb += os.invoke[Unit]("close") headerStr } else Code._null[String] - cb += Code.invokeScalaObject3[FS, String, Array[String], Unit](TableTextFinalizer.getClass, "cleanup", cb.emb.getFS, path + ".bgen", files) - cb += Code.invokeScalaObject4[FS, String, Array[String], String, Unit](TableTextFinalizer.getClass, "writeManifest", cb.emb.getFS, path + ".bgen", files, headerStr) + cb += Code.invokeScalaObject3[FS, String, Array[String], Unit]( + TableTextFinalizer.getClass, + "cleanup", + cb.emb.getFS, + path + ".bgen", + files, + ) + cb += Code.invokeScalaObject4[FS, String, Array[String], String, Unit]( + TableTextFinalizer.getClass, + "writeManifest", + cb.emb.getFS, + path + ".bgen", + files, + headerStr, + ) } - if (exportType == ExportType.CONCATENATED) { val os = cb.memoize(cb.emb.create(const(path + ".bgen"))) - val header = Code.invokeScalaObject3[Array[String], Long, Int, Array[Byte]](BgenWriter.getClass, "headerBlock", sampleIds, numVariants, compression) + val header = Code.invokeScalaObject3[Array[String], Long, Int, Array[Byte]]( + BgenWriter.getClass, + "headerBlock", + sampleIds, + numVariants, + compression, + ) cb += os.invoke[Array[Byte], Unit]("write", header) - annotations.loadField(cb, "results").get(cb).asIndexable.forEachDefined(cb) { (cb, i, res) => - res.asBaseStruct.loadField(cb, "partFile").consume(cb, {/* do nothing */}, { case pf: SStringValue => - val f = cb.memoize(cb.emb.open(pf.loadString(cb), false)) - cb += Code.invokeStatic3[org.apache.hadoop.io.IOUtils, InputStream, OutputStream, Int, Unit]("copyBytes", f, os, 4096) - cb += f.invoke[Unit]("close") - }) + annotations.loadField(cb, "results").getOrAssert(cb).asIndexable.forEachDefined(cb) { + (cb, i, res) => + res.asBaseStruct.loadField(cb, "partFile").consume( + cb, + { /* do nothing */ }, + { case pf: SStringValue => + val f = cb.memoize(cb.emb.open(pf.loadString(cb), false)) + cb += Code.invokeStatic3[ + org.apache.hadoop.io.IOUtils, + InputStream, + OutputStream, + Int, + Unit, + ]("copyBytes", f, os, 4096) + cb += f.invoke[Unit]("close") + }, + ) } cb += os.invoke[Unit]("flush") cb += os.invoke[Unit]("close") } - cb += Code.invokeScalaObject3[FS, String, Array[String], Unit](BgenWriter.getClass, "writeSampleFile", cb.emb.getFS, path, sampleIds) + cb += Code.invokeScalaObject3[FS, String, Array[String], Unit]( + BgenWriter.getClass, + "writeSampleFile", + cb.emb.getFS, + path, + sampleIds, + ) } } @@ -1310,8 +2110,14 @@ case class MatrixPLINKWriter( path: String ) extends MatrixWriter { - override def lower(colsFieldName: String, entriesFieldName: String, colKey: IndexedSeq[String], - ctx: 
ExecuteContext, ts: TableStage, r: RTable): IR = { + override def lower( + colsFieldName: String, + entriesFieldName: String, + colKey: IndexedSeq[String], + ctx: ExecuteContext, + ts: TableStage, + r: RTable, + ): IR = { val tm = MatrixType.fromTableType(ts.tableType, colsFieldName, entriesFieldName, colKey) val tmpBedDir = ctx.createTmpPath("export-plink", "bed") val tmpBimDir = ctx.createTmpPath("export-plink", "bim") @@ -1319,28 +2125,37 @@ case class MatrixPLINKWriter( val lineWriter = PLINKPartitionWriter(tm, entriesFieldName) ts.mapContexts { oldCtx => val d = digitsNeeded(ts.numPartitions) - val files = Literal(TArray(TTuple(TString, TString)), - Array.tabulate(ts.numPartitions)(i => Row(s"$tmpBedDir/${ partFile(d, i) }-", s"$tmpBimDir/${ partFile(d, i) }-")).toFastSeq) + val files = Literal( + TArray(TTuple(TString, TString)), + Array.tabulate(ts.numPartitions)(i => + Row(s"$tmpBedDir/${partFile(d, i)}-", s"$tmpBimDir/${partFile(d, i)}-") + ).toFastSeq, + ) zip2(oldCtx, ToStream(files), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => MakeStruct(FastSeq( "oldCtx" -> ctxElt, - "file" -> pf)) + "file" -> pf, + )) } - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_plink_writer") { (rows, ctxRef) => - val id = UUID4() - val bedFile = GetTupleElement(GetField(ctxRef, "file"), 0) + id - val bimFile = GetTupleElement(GetField(ctxRef, "file"), 1) + id - val ctx = MakeStruct(FastSeq("bedFile" -> bedFile, "bimFile" -> bimFile)) - WritePartition(rows, ctx, lineWriter) - }{ (parts, globals) => + }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("matrix_plink_writer") { + (rows, ctxRef) => + val id = UUID4() + val bedFile = GetTupleElement(GetField(ctxRef, "file"), 0) + id + val bimFile = GetTupleElement(GetField(ctxRef, "file"), 1) + id + val ctx = MakeStruct(FastSeq("bedFile" -> bedFile, "bimFile" -> bimFile)) + WritePartition(rows, ctx, lineWriter) + } { (parts, globals) => val commit = PLINKExportFinalizer(tm, path, tmpBedDir + "/header") val famWriter = TableTextPartitionWriter(tm.colsTableType.rowType, "\t", writeHeader = false) val famPath = Str(path + ".fam") val cols = ToStream(GetField(globals, colsFieldName)) val writeFam = WritePartition(cols, famPath, famWriter) bindIR(writeFam) { fpath => - Begin(FastSeq(WriteMetadata(parts, commit), WriteMetadata(fpath, SimpleMetadataWriter(fpath.typ)))) + Begin(FastSeq( + WriteMetadata(parts, commit), + WriteMetadata(fpath, SimpleMetadataWriter(fpath.typ)), + )) } } } @@ -1355,16 +2170,25 @@ case class PLINKPartitionWriter(typ: MatrixType, entriesFieldName: String) exten val varidIdx = typ.rowType.fieldIdx("varid") val cmPosIdx = typ.rowType.fieldIdx("cm_position") - def unionTypeRequiredness(r: TypeWithRequiredness, ctxType: TypeWithRequiredness, streamType: RIterable): Unit = { + def unionTypeRequiredness( + r: TypeWithRequiredness, + ctxType: TypeWithRequiredness, + streamType: RIterable, + ): Unit = { r.union(ctxType.required) r.union(streamType.required) } - final def consumeStream(ctx: ExecuteContext, cb: EmitCodeBuilder, stream: StreamProducer, - context: EmitCode, region: Value[Region]): IEmitCode = { + final def consumeStream( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + stream: StreamProducer, + context: EmitCode, + region: Value[Region], + ): IEmitCode = { context.toI(cb).map(cb) { case context: SBaseStructValue => - val bedFile = context.loadField(cb, "bedFile").get(cb).asString.loadString(cb) - val bimFile = context.loadField(cb, "bimFile").get(cb).asString.loadString(cb) + val bedFile = 
context.loadField(cb, "bedFile").getOrAssert(cb).asString.loadString(cb) + val bimFile = context.loadField(cb, "bimFile").getOrAssert(cb).asString.loadString(cb) val bedOs = cb.memoize(cb.emb.create(bedFile)) val bimOs = cb.memoize(cb.emb.create(bimFile)) @@ -1384,28 +2208,49 @@ case class PLINKPartitionWriter(typ: MatrixType, entriesFieldName: String) exten } } - private def consumeElement(cb: EmitCodeBuilder, element: EmitCode, bimOs: Value[OutputStream], bp: Value[BitPacker], region: Value[Region]): Unit = { - def _writeC(cb: EmitCodeBuilder, code: Code[Int]) = { cb += bimOs.invoke[Int, Unit]("write", code) } - def _writeB(cb: EmitCodeBuilder, code: Code[Array[Byte]]) = { cb += bimOs.invoke[Array[Byte], Unit]("write", code) } - def _writeS(cb: EmitCodeBuilder, code: Code[String]) = { _writeB(cb, code.invoke[Array[Byte]]("getBytes")) } + private def consumeElement( + cb: EmitCodeBuilder, + element: EmitCode, + bimOs: Value[OutputStream], + bp: Value[BitPacker], + region: Value[Region], + ): Unit = { + def _writeC(cb: EmitCodeBuilder, code: Code[Int]) = cb += bimOs.invoke[Int, Unit]("write", code) + def _writeB(cb: EmitCodeBuilder, code: Code[Array[Byte]]) = + cb += bimOs.invoke[Array[Byte], Unit]("write", code) + def _writeS(cb: EmitCodeBuilder, code: Code[String]) = + _writeB(cb, code.invoke[Array[Byte]]("getBytes")) def writeC(code: Code[Int]) = _writeC(cb, code) def writeS(code: Code[String]) = _writeS(cb, code) - val elt = element.toI(cb).get(cb).asBaseStruct + val elt = element.toI(cb).getOrAssert(cb).asBaseStruct - val (contig, position) = elt.loadField(cb, locusIdx).get(cb) match { + val (contig, position) = elt.loadField(cb, locusIdx).getOrAssert(cb) match { case locus: SLocusValue => locus.contig(cb).loadString(cb) -> locus.position(cb) case locus: SBaseStructValue => - locus.loadField(cb, 0).get(cb).asString.loadString(cb) -> locus.loadField(cb, 1).get(cb).asInt.value + locus.loadField(cb, 0).getOrAssert(cb).asString.loadString(cb) -> locus.loadField( + cb, + 1, + ).getOrAssert( + cb + ).asInt.value } - val cmPosition = elt.loadField(cb, cmPosIdx).get(cb).asDouble - val varid = elt.loadField(cb, varidIdx).get(cb).asString.loadString(cb) - val alleles = elt.loadField(cb, allelesIdx).get(cb).asIndexable - val a0 = alleles.loadElement(cb, 0).get(cb).asString.loadString(cb) - val a1 = alleles.loadElement(cb, 1).get(cb).asString.loadString(cb) - - cb += Code.invokeScalaObject5[String, String, Int, String, String, Unit](ExportPlink.getClass, "checkVariant", contig, varid, position, a0, a1) + val cmPosition = elt.loadField(cb, cmPosIdx).getOrAssert(cb).asDouble + val varid = elt.loadField(cb, varidIdx).getOrAssert(cb).asString.loadString(cb) + val alleles = elt.loadField(cb, allelesIdx).getOrAssert(cb).asIndexable + val a0 = alleles.loadElement(cb, 0).getOrAssert(cb).asString.loadString(cb) + val a1 = alleles.loadElement(cb, 1).getOrAssert(cb).asString.loadString(cb) + + cb += Code.invokeScalaObject5[String, String, Int, String, String, Unit]( + ExportPlink.getClass, + "checkVariant", + contig, + varid, + position, + a0, + a1, + ) writeS(contig) writeC('\t') writeS(varid) @@ -1419,23 +2264,36 @@ case class PLINKPartitionWriter(typ: MatrixType, entriesFieldName: String) exten writeS(a0) writeC('\n') - elt.loadField(cb, entriesFieldName).get(cb).asIndexable.forEachDefinedOrMissing(cb)({ (cb, i) => - cb += bp.invoke[Int, Unit]("add", 1) - }, { (cb, i, va) => - va.asBaseStruct.loadField(cb, "GT").consume(cb, { - cb += bp.invoke[Int, Unit]("add", 1) - }, { case call: SCallValue => - val 
gtIx = cb.memoize(Code.invokeScalaObject1[Call, Int](Call.getClass, "unphasedDiploidGtIndex", call.canonicalCall(cb))) - val gt = (gtIx ceq 0).mux(3, (gtIx ceq 1).mux(2, 0)) - cb += bp.invoke[Int, Unit]("add", gt) - }) - }) + elt.loadField(cb, entriesFieldName).getOrAssert(cb).asIndexable.forEachDefinedOrMissing(cb)( + (cb, i) => cb += bp.invoke[Int, Unit]("add", 1), + { (cb, i, va) => + va.asBaseStruct.loadField(cb, "GT").consume( + cb, + cb += bp.invoke[Int, Unit]("add", 1), + { case call: SCallValue => + val gtIx = cb.memoize(Code.invokeScalaObject1[Call, Int]( + Call.getClass, + "unphasedDiploidGtIndex", + call.canonicalCall(cb), + )) + val gt = (gtIx ceq 0).mux(3, (gtIx ceq 1).mux(2, 0)) + cb += bp.invoke[Int, Unit]("add", gt) + }, + ) + }, + ) cb += bp.invoke[Unit]("flush") } } object PLINKExportFinalizer { - def finalize(fs: FS, path: String, headerPath: String, bedFiles: Array[String], bimFiles: Array[String]): Unit = { + def finalize( + fs: FS, + path: String, + headerPath: String, + bedFiles: Array[String], + bimFiles: Array[String], + ): Unit = { using(fs.create(headerPath))(out => out.write(ExportPlink.bedHeader)) bedFiles(0) = headerPath fs.concatenateFiles(bedFiles, path + ".bed") @@ -1443,20 +2301,30 @@ object PLINKExportFinalizer { } } -case class PLINKExportFinalizer(typ: MatrixType, path: String, headerPath: String) extends MetadataWriter { +case class PLINKExportFinalizer(typ: MatrixType, path: String, headerPath: String) + extends MetadataWriter { def annotationType: Type = TArray(TStruct("bedFile" -> TString, "bimFile" -> TString)) - def writeMetadata(writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, region: Value[Region]): Unit = { - val paths = writeAnnotations.get(cb).asIndexable + def writeMetadata(writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, region: Value[Region]) + : Unit = { + val paths = writeAnnotations.getOrAssert(cb).asIndexable val bedFiles = cb.memoize(Code.newArray[String](paths.loadLength() + 1)) // room for header val bimFiles = cb.memoize(Code.newArray[String](paths.loadLength())) paths.forEachDefined(cb) { case (cb, i, elt: SBaseStructValue) => - val bed = elt.loadField(cb, "bedFile").get(cb).asString.loadString(cb) - val bim = elt.loadField(cb, "bimFile").get(cb).asString.loadString(cb) + val bed = elt.loadField(cb, "bedFile").getOrAssert(cb).asString.loadString(cb) + val bim = elt.loadField(cb, "bimFile").getOrAssert(cb).asString.loadString(cb) cb += (bedFiles(cb.memoize(i + 1)) = bed) cb += (bimFiles(i) = bim) } - cb += Code.invokeScalaObject5[FS, String, String, Array[String], Array[String], Unit](PLINKExportFinalizer.getClass, "finalize", cb.emb.getFS, path, headerPath, bedFiles, bimFiles) + cb += Code.invokeScalaObject5[FS, String, String, Array[String], Array[String], Unit]( + PLINKExportFinalizer.getClass, + "finalize", + cb.emb.getFS, + path, + headerPath, + bedFiles, + bimFiles, + ) } } @@ -1464,11 +2332,17 @@ case class MatrixBlockMatrixWriter( path: String, overwrite: Boolean, entryField: String, - blockSize: Int + blockSize: Int, ) extends MatrixWriter { - override def lower(colsFieldName: String, entriesFieldName: String, colKey: IndexedSeq[String], - ctx: ExecuteContext, ts: TableStage, r: RTable): IR = { + override def lower( + colsFieldName: String, + entriesFieldName: String, + colKey: IndexedSeq[String], + ctx: ExecuteContext, + ts: TableStage, + r: RTable, + ): IR = { val tm = MatrixType.fromTableType(ts.tableType, colsFieldName, entriesFieldName, colKey) val rm = r.asMatrixType(colsFieldName, entriesFieldName) @@ 
-1478,8 +2352,11 @@ case class MatrixBlockMatrixWriter( val numBlockCols: Int = (numCols - 1) / blockSize + 1 val lastBlockNumCols = numCols % blockSize - val rowCountIR = ts.mapCollect("matrix_block_matrix_writer_partition_counts")(paritionIR => StreamLen(paritionIR)) - val inputRowCountPerPartition: IndexedSeq[Int] = CompileAndEvaluate(ctx, rowCountIR).asInstanceOf[IndexedSeq[Int]] + val rowCountIR = ts.mapCollect("matrix_block_matrix_writer_partition_counts")(paritionIR => + StreamLen(paritionIR) + ) + val inputRowCountPerPartition: IndexedSeq[Int] = + CompileAndEvaluate(ctx, rowCountIR).asInstanceOf[IndexedSeq[Int]] val inputPartStartsPlusLast = inputRowCountPerPartition.scanLeft(0L)(_ + _) val inputPartStarts = inputPartStartsPlusLast.dropRight(1) val inputPartStops = inputPartStartsPlusLast.tail @@ -1488,103 +2365,173 @@ case class MatrixBlockMatrixWriter( val numBlockRows: Int = (numRows.toInt - 1) / blockSize + 1 // Zip contexts with partition starts and ends - val zippedWithStarts = ts.mapContexts{oldContextsStream => zipIR(IndexedSeq(oldContextsStream, ToStream(Literal(TArray(TInt64), inputPartStarts)), ToStream(Literal(TArray(TInt64), inputPartStops))), ArrayZipBehavior.AssertSameLength){ case IndexedSeq(oldCtx, partStart, partStop) => - MakeStruct(FastSeq("mwOld" -> oldCtx, "mwStartIdx" -> Cast(partStart, TInt32), "mwStopIdx" -> Cast(partStop, TInt32))) - }}(newCtx => GetField(newCtx, "mwOld")) + val zippedWithStarts = ts.mapContexts { oldContextsStream => + zipIR( + IndexedSeq( + oldContextsStream, + ToStream(Literal(TArray(TInt64), inputPartStarts)), + ToStream(Literal(TArray(TInt64), inputPartStops)), + ), + ArrayZipBehavior.AssertSameLength, + ) { case IndexedSeq(oldCtx, partStart, partStop) => + MakeStruct(FastSeq( + "mwOld" -> oldCtx, + "mwStartIdx" -> Cast(partStart, TInt32), + "mwStopIdx" -> Cast(partStop, TInt32), + )) + } + }(newCtx => GetField(newCtx, "mwOld")) // Now label each row with its idx. val perRowIdxId = genUID() val partsZippedWithIdx = zippedWithStarts.mapPartitionWithContext { (part, ctx) => - zip2(part, rangeIR(GetField(ctx, "mwStartIdx"), GetField(ctx, "mwStopIdx")), ArrayZipBehavior.AssertSameLength) { (partRow, idx) => - insertIR(partRow, (perRowIdxId, idx)) - } + zip2( + part, + rangeIR(GetField(ctx, "mwStartIdx"), GetField(ctx, "mwStopIdx")), + ArrayZipBehavior.AssertSameLength, + )((partRow, idx) => insertIR(partRow, (perRowIdxId, idx))) } - // Two steps, make a partitioner that works currently based on row_idx splits, then resplit accordingly. - val inputRowIntervals = inputPartStarts.zip(inputPartStops).map{ case (intervalStart, intervalEnd) => - Interval(Row(intervalStart.toInt), Row(intervalEnd.toInt), true, false) - } + /* Two steps, make a partitioner that works currently based on row_idx splits, then resplit + * accordingly. 
*/ + val inputRowIntervals = + inputPartStarts.zip(inputPartStops).map { case (intervalStart, intervalEnd) => + Interval(Row(intervalStart.toInt), Row(intervalEnd.toInt), true, false) + } - val rowIdxPartitioner = new RVDPartitioner(ctx.stateManager, TStruct((perRowIdxId, TInt32)), inputRowIntervals) + val rowIdxPartitioner = + new RVDPartitioner(ctx.stateManager, TStruct((perRowIdxId, TInt32)), inputRowIntervals) val keyedByRowIdx = partsZippedWithIdx.changePartitionerNoRepartition(rowIdxPartitioner) // Now create a partitioner that makes appropriately sized blocks val desiredRowStarts = (0 until numBlockRows).map(_ * blockSize) val desiredRowStops = desiredRowStarts.drop(1) :+ numRows.toInt - val desiredRowIntervals = desiredRowStarts.zip(desiredRowStops).map{ - case (intervalStart, intervalEnd) => Interval(Row(intervalStart), Row(intervalEnd), true, false) + val desiredRowIntervals = desiredRowStarts.zip(desiredRowStops).map { + case (intervalStart, intervalEnd) => + Interval(Row(intervalStart), Row(intervalEnd), true, false) } - val blockSizeGroupsPartitioner = RVDPartitioner.generate(ctx.stateManager, TStruct((perRowIdxId, TInt32)), desiredRowIntervals) - val rowsInBlockSizeGroups: TableStage = keyedByRowIdx.repartitionNoShuffle(ctx, blockSizeGroupsPartitioner) + val blockSizeGroupsPartitioner = + RVDPartitioner.generate(ctx.stateManager, TStruct((perRowIdxId, TInt32)), desiredRowIntervals) + val rowsInBlockSizeGroups: TableStage = + keyedByRowIdx.repartitionNoShuffle(ctx, blockSizeGroupsPartitioner) def createBlockMakingContexts(tablePartsStreamIR: IR): IR = { - flatten(zip2(tablePartsStreamIR, rangeIR(numBlockRows), ArrayZipBehavior.AssertSameLength) { case (tableSinglePartCtx, blockRowIdx) => - mapIR(rangeIR(I32(numBlockCols))){ blockColIdx => - MakeStruct(FastSeq("oldTableCtx" -> tableSinglePartCtx, "blockStart" -> (blockColIdx * I32(blockSize)), - "blockSize" -> If(blockColIdx ceq I32(numBlockCols - 1), I32(lastBlockNumCols), I32(blockSize)), - "blockColIdx" -> blockColIdx, - "blockRowIdx" -> blockRowIdx)) - } + flatten(zip2(tablePartsStreamIR, rangeIR(numBlockRows), ArrayZipBehavior.AssertSameLength) { + case (tableSinglePartCtx, blockRowIdx) => + mapIR(rangeIR(I32(numBlockCols))) { blockColIdx => + MakeStruct(FastSeq( + "oldTableCtx" -> tableSinglePartCtx, + "blockStart" -> (blockColIdx * I32(blockSize)), + "blockSize" -> If( + blockColIdx ceq I32(numBlockCols - 1), + I32(lastBlockNumCols), + I32(blockSize), + ), + "blockColIdx" -> blockColIdx, + "blockRowIdx" -> blockRowIdx, + )) + } }) } - val tableOfNDArrays = rowsInBlockSizeGroups.mapContexts(createBlockMakingContexts)(ir => GetField(ir, "oldTableCtx")).mapPartitionWithContext{ (partIr, ctxRef) => - bindIR(GetField(ctxRef, "blockStart")){ blockStartRef => + val tableOfNDArrays = rowsInBlockSizeGroups.mapContexts(createBlockMakingContexts)(ir => + GetField(ir, "oldTableCtx") + ).mapPartitionWithContext { (partIr, ctxRef) => + bindIR(GetField(ctxRef, "blockStart")) { blockStartRef => val numColsOfBlock = GetField(ctxRef, "blockSize") val arrayOfSlicesAndIndices = ToArray(mapIR(partIr) { singleRow => - val mappedSlice = ToArray(mapIR(ToStream(sliceArrayIR(GetField(singleRow, entriesFieldName), blockStartRef, blockStartRef + numColsOfBlock)))(entriesStructRef => + val mappedSlice = ToArray(mapIR(ToStream(sliceArrayIR( + GetField(singleRow, entriesFieldName), + blockStartRef, + blockStartRef + numColsOfBlock, + )))(entriesStructRef => GetField(entriesStructRef, entryField) )) MakeStruct(FastSeq( perRowIdxId -> 
GetField(singleRow, perRowIdxId), - "rowOfData" -> mappedSlice + "rowOfData" -> mappedSlice, )) }) - bindIR(arrayOfSlicesAndIndices){ arrayOfSlicesAndIndicesRef => + bindIR(arrayOfSlicesAndIndices) { arrayOfSlicesAndIndicesRef => val idxOfResult = GetField(ArrayRef(arrayOfSlicesAndIndicesRef, I32(0)), perRowIdxId) - val ndarrayData = ToArray(flatMapIR(ToStream(arrayOfSlicesAndIndicesRef)){idxAndSlice => + val ndarrayData = ToArray(flatMapIR(ToStream(arrayOfSlicesAndIndicesRef)) { idxAndSlice => ToStream(GetField(idxAndSlice, "rowOfData")) }) val numRowsOfBlock = ArrayLen(arrayOfSlicesAndIndicesRef) val shape = maketuple(Cast(numRowsOfBlock, TInt64), Cast(numColsOfBlock, TInt64)) val ndarray = MakeNDArray(ndarrayData, shape, True(), ErrorIDs.NO_ERROR) - MakeStream(FastSeq(MakeStruct(FastSeq( - perRowIdxId -> idxOfResult, - "blockRowIdx" -> GetField(ctxRef, "blockRowIdx"), - "blockColIdx" -> GetField(ctxRef, "blockColIdx"), - "ndBlock" -> ndarray))), - TStream(TStruct(perRowIdxId -> TInt32, "blockRowIdx" -> TInt32, "blockColIdx" -> TInt32, "ndBlock" -> ndarray.typ))) + MakeStream( + FastSeq(MakeStruct(FastSeq( + perRowIdxId -> idxOfResult, + "blockRowIdx" -> GetField(ctxRef, "blockRowIdx"), + "blockColIdx" -> GetField(ctxRef, "blockColIdx"), + "ndBlock" -> ndarray, + ))), + TStream(TStruct( + perRowIdxId -> TInt32, + "blockRowIdx" -> TInt32, + "blockColIdx" -> TInt32, + "ndBlock" -> ndarray.typ, + )), + ) } } } val elementType = tm.entryType.fieldType(entryField) - val etype = EBlockMatrixNDArray(EType.fromTypeAndAnalysis(elementType, rm.entryType.field(entryField)), encodeRowMajor = true, required = true) - val spec = TypedCodecSpec(etype, TNDArray(tm.entryType.fieldType(entryField), Nat(2)), BlockMatrix.bufferSpec) + val etype = EBlockMatrixNDArray( + EType.fromTypeAndAnalysis(elementType, rm.entryType.field(entryField)), + encodeRowMajor = true, + required = true, + ) + val spec = TypedCodecSpec( + etype, + TNDArray(tm.entryType.fieldType(entryField), Nat(2)), + BlockMatrix.bufferSpec, + ) val writer = ETypeValueWriter(spec) - val pathsWithColMajorIndices = tableOfNDArrays.mapCollect("matrix_block_matrix_writer") { partition => - ToArray(mapIR(partition) { singleNDArrayTuple => - bindIR(GetField(singleNDArrayTuple, "blockRowIdx") + (GetField(singleNDArrayTuple, "blockColIdx") * numBlockRows)) { colMajorIndex => - val blockPath = - Str(s"$path/parts/part-") + - invoke("str", TString, colMajorIndex) + Str("-") + UUID4() - maketuple(colMajorIndex, WriteValue(GetField(singleNDArrayTuple, "ndBlock"), blockPath, writer)) - } - }) - } + val pathsWithColMajorIndices = + tableOfNDArrays.mapCollect("matrix_block_matrix_writer") { partition => + ToArray(mapIR(partition) { singleNDArrayTuple => + bindIR(GetField(singleNDArrayTuple, "blockRowIdx") + (GetField( + singleNDArrayTuple, + "blockColIdx", + ) * numBlockRows)) { colMajorIndex => + val blockPath = + Str(s"$path/parts/part-") + + invoke("str", TString, colMajorIndex) + Str("-") + UUID4() + maketuple( + colMajorIndex, + WriteValue(GetField(singleNDArrayTuple, "ndBlock"), blockPath, writer), + ) + } + }) + } val flatPathsAndIndices = flatMapIR(ToStream(pathsWithColMajorIndices))(ToStream(_)) - val sortedColMajorPairs = sortIR(flatPathsAndIndices){case (l, r) => ApplyComparisonOp(LT(TInt32), GetTupleElement(l, 0), GetTupleElement(r, 0))} + val sortedColMajorPairs = sortIR(flatPathsAndIndices) { case (l, r) => + ApplyComparisonOp(LT(TInt32), GetTupleElement(l, 0), GetTupleElement(r, 0)) + } val flatPaths = 
ToArray(mapIR(ToStream(sortedColMajorPairs))(GetTupleElement(_, 1))) - val bmt = BlockMatrixType(elementType, IndexedSeq(numRows, numCols), numRows==1, blockSize, BlockMatrixSparsity.dense) - RelationalWriter.scoped(path, overwrite, None)(WriteMetadata(flatPaths, BlockMatrixNativeMetadataWriter(path, false, bmt))) + val bmt = BlockMatrixType( + elementType, + IndexedSeq(numRows, numCols), + numRows == 1, + blockSize, + BlockMatrixSparsity.dense, + ) + RelationalWriter.scoped(path, overwrite, None)(WriteMetadata( + flatPaths, + BlockMatrixNativeMetadataWriter(path, false, bmt), + )) } } object MatrixNativeMultiWriter { implicit val formats: Formats = new DefaultFormats() { - override val typeHints = ShortTypeHints(List(classOf[MatrixNativeMultiWriter]), typeHintFieldName = "name") + override val typeHints = + ShortTypeHints(List(classOf[MatrixNativeMultiWriter]), typeHintFieldName = "name") } } @@ -1592,18 +2539,22 @@ case class MatrixNativeMultiWriter( paths: IndexedSeq[String], overwrite: Boolean = false, stageLocally: Boolean = false, - codecSpecJSONStr: String = null + codecSpecJSONStr: String = null, ) { val bufferSpec: BufferSpec = BufferSpec.parseOrDefault(codecSpecJSONStr) def apply(ctx: ExecuteContext, mvs: IndexedSeq[MatrixValue]): Unit = MatrixValue.writeMultiple(ctx, mvs, paths, overwrite, stageLocally, bufferSpec) - def lower(ctx: ExecuteContext, tables: IndexedSeq[(String, String, IndexedSeq[String], TableStage, RTable)]): IR = { - val components = paths.zip(tables).map { case (path, (colsFieldName, entriesFieldName, colKey, ts, rt)) => - MatrixNativeWriter.generateComponentFunctions(colsFieldName, entriesFieldName, colKey, - ctx, ts, rt, path, overwrite, stageLocally, codecSpecJSONStr) - } + def lower( + ctx: ExecuteContext, + tables: IndexedSeq[(String, String, IndexedSeq[String], TableStage, RTable)], + ): IR = { + val components = + paths.zip(tables).map { case (path, (colsFieldName, entriesFieldName, colKey, ts, rt)) => + MatrixNativeWriter.generateComponentFunctions(colsFieldName, entriesFieldName, colKey, + ctx, ts, rt, path, overwrite, stageLocally, codecSpecJSONStr) + } require(tables.map(_._4.tableType.keyType).distinct.length == 1) val unionType = TTuple(components.map(c => TIterable.elementType(c.stage.contexts.typ)): _*) @@ -1620,11 +2571,11 @@ case class MatrixNativeMultiWriter( ToArray(mapIR(c.stage.contexts) { ctx => MakeStruct(FastSeq( "matrixId" -> I32(matrixId), - "options" -> MakeTuple(emptyUnionIRs.updated(matrixId, matrixId -> ctx)) + "options" -> MakeTuple(emptyUnionIRs.updated(matrixId, matrixId -> ctx)), )) }) }, - TArray(TArray(contextUnionType)) + TArray(TArray(contextUnionType)), ) ) @@ -1632,31 +2583,40 @@ case class MatrixNativeMultiWriter( Begin(FastSeq( Begin(components.map(_.setup)), - Let(components.flatMap(_.stage.letBindings), - bindIR(cdaIR(concatenatedContexts, allBroadcasts, "matrix_multi_writer") { case (ctx, globals) => - bindIR(GetField(ctx, "options")) { options => - Switch(GetField(ctx, "matrixId"), - default = Die("MatrixId exceeds matrix count", components.head.writePartitionType), - cases = components.zipWithIndex.map { case (component, i) => - val binds = component.stage.broadcastVals.map { case (name, _) => - name -> GetField(globals, name) - } - - Let(binds, bindIR(GetTupleElement(options, i)) { ctxRef => - component.writePartition(component.stage.partition(ctxRef), ctxRef) - }) - } - ) - } + Let( + components.flatMap(_.stage.letBindings), + bindIR(cdaIR(concatenatedContexts, allBroadcasts, "matrix_multi_writer") { + case (ctx, 
globals) => + bindIR(GetField(ctx, "options")) { options => + Switch( + GetField(ctx, "matrixId"), + default = Die("MatrixId exceeds matrix count", components.head.writePartitionType), + cases = components.zipWithIndex.map { case (component, i) => + val binds = component.stage.broadcastVals.map { case (name, _) => + name -> GetField(globals, name) + } + + Let( + binds, + bindIR(GetTupleElement(options, i)) { ctxRef => + component.writePartition(component.stage.partition(ctxRef), ctxRef) + }, + ) + }, + ) + } }) { cdaResult => val partitionCountScan = components.map(_.stage.numPartitions).scanLeft(0)(_ + _) Begin(components.zipWithIndex.map { case (c, i) => - c.finalizeWrite(ArraySlice(cdaResult, partitionCountScan(i), Some(partitionCountScan(i + 1))), c.stage.globals) + c.finalizeWrite( + ArraySlice(cdaResult, partitionCountScan(i), Some(partitionCountScan(i + 1))), + c.stage.globals, + ) }) - } - ) + }, + ), )) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/Mentions.scala b/hail/src/main/scala/is/hail/expr/ir/Mentions.scala index 049bf3838e5..81555984a09 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Mentions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Mentions.scala @@ -10,4 +10,4 @@ object Mentions { val fv = FreeVariables(x, true, true) fv.agg.get.lookupOption(name).isDefined || fv.scan.get.lookupOption(name).isDefined } -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/expr/ir/NativeReaderOptions.scala b/hail/src/main/scala/is/hail/expr/ir/NativeReaderOptions.scala index 61d359a7b14..d0c31e07a69 100644 --- a/hail/src/main/scala/is/hail/expr/ir/NativeReaderOptions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/NativeReaderOptions.scala @@ -1,32 +1,34 @@ package is.hail.expr.ir -import is.hail.types.virtual._ import is.hail.expr.JSONAnnotationImpex +import is.hail.types.virtual._ import is.hail.utils._ + import org.json4s.{CustomSerializer, DefaultFormats, Formats, JObject, JValue} import org.json4s.JsonDSL._ -class NativeReaderOptionsSerializer() extends CustomSerializer[NativeReaderOptions]( - format => - ({ case jObj: JObject => - implicit val fmt = format - val filterIntervals = (jObj \ "filterIntervals").extract[Boolean] - val intervalPointType = IRParser.parseType((jObj \ "intervalPointType").extract[String]) - val intervals = { - val jv = jObj \ "intervals" - val ty = TArray(TInterval(intervalPointType)) - JSONAnnotationImpex.importAnnotation(jv, ty).asInstanceOf[IndexedSeq[Interval]] - } - NativeReaderOptions(intervals, intervalPointType, filterIntervals) - }, { case opts: NativeReaderOptions => - implicit val fmt = format - val ty = TArray(TInterval(opts.intervalPointType)) - (("name" -> opts.getClass.getSimpleName) ~ - ("intervals" -> JSONAnnotationImpex.exportAnnotation(opts.intervals, ty)) ~ - ("intervalPointType" -> opts.intervalPointType.parsableString()) ~ - ("filterIntervals" -> opts.filterIntervals)) - }) -) +class NativeReaderOptionsSerializer() extends CustomSerializer[NativeReaderOptions](format => + ( + { case jObj: JObject => + implicit val fmt = format + val filterIntervals = (jObj \ "filterIntervals").extract[Boolean] + val intervalPointType = IRParser.parseType((jObj \ "intervalPointType").extract[String]) + val intervals = { + val jv = jObj \ "intervals" + val ty = TArray(TInterval(intervalPointType)) + JSONAnnotationImpex.importAnnotation(jv, ty).asInstanceOf[IndexedSeq[Interval]] + } + NativeReaderOptions(intervals, intervalPointType, filterIntervals) + }, + { case opts: NativeReaderOptions => + val ty = 
TArray(TInterval(opts.intervalPointType)) + ("name" -> opts.getClass.getSimpleName) ~ + ("intervals" -> JSONAnnotationImpex.exportAnnotation(opts.intervals, ty)) ~ + ("intervalPointType" -> opts.intervalPointType.parsableString()) ~ + ("filterIntervals" -> opts.filterIntervals) + }, + ) + ) object NativeReaderOptions { def fromJValue(jv: JValue): NativeReaderOptions = { @@ -46,15 +48,18 @@ object NativeReaderOptions { case class NativeReaderOptions( intervals: IndexedSeq[Interval], intervalPointType: Type, - filterIntervals: Boolean = false) { + filterIntervals: Boolean = false, +) { def toJson: JValue = { val ty = TArray(TInterval(intervalPointType)) JObject( "name" -> "NativeReaderOptions", "intervals" -> JSONAnnotationImpex.exportAnnotation(intervals, ty), "intervalPointType" -> intervalPointType.parsableString(), - "filterIntervals" -> filterIntervals) + "filterIntervals" -> filterIntervals, + ) } - def renderShort(): String = s"(IntervalRead: ${intervals.length} intervals, filter=${filterIntervals})" + def renderShort(): String = + s"(IntervalRead: ${intervals.length} intervals, filter=$filterIntervals)" } diff --git a/hail/src/main/scala/is/hail/expr/ir/NestingDepth.scala b/hail/src/main/scala/is/hail/expr/ir/NestingDepth.scala index b3c71f57345..59a7b298b61 100644 --- a/hail/src/main/scala/is/hail/expr/ir/NestingDepth.scala +++ b/hail/src/main/scala/is/hail/expr/ir/NestingDepth.scala @@ -16,19 +16,30 @@ case class ScopedDepth(eval: Int, agg: Int, scan: Int) { def promoteScanOrAgg(isScan: Boolean): ScopedDepth = if (isScan) promoteScan else promoteAgg } +final class NestingDepth(private val memo: Memo[ScopedDepth]) { + def lookupRef(x: RefEquality[BaseRef]): Int = memo.lookup(x).eval + def lookupRef(x: BaseRef): Int = memo.lookup(x).eval + + def lookupBinding(x: Block, scope: Int): Int = scope match { + case Scope.EVAL => memo(x).eval + case Scope.AGG => memo(x).agg + case Scope.SCAN => memo(x).scan + } +} + object NestingDepth { - def apply(ir0: BaseIR): Memo[Int] = { + def apply(ir0: BaseIR): NestingDepth = { - val memo = Memo.empty[Int] + val memo = Memo.empty[ScopedDepth] def computeChildren(ir: BaseIR): Unit = { ir.children .zipWithIndex .foreach { - case (child: IR, i) => computeIR(child, ScopedDepth(0, 0, 0)) - case (tir: TableIR, i) => computeTable(tir) - case (mir: MatrixIR, i) => computeMatrix(mir) - case (bmir: BlockMatrixIR, i) => computeBlockMatrix(bmir) + case (child: IR, _) => computeIR(child, ScopedDepth(0, 0, 0)) + case (tir: TableIR, _) => computeTable(tir) + case (mir: MatrixIR, _) => computeMatrix(mir) + case (bmir: BlockMatrixIR, _) => computeBlockMatrix(bmir) } } @@ -40,21 +51,18 @@ object NestingDepth { def computeIR(ir: IR, depth: ScopedDepth): Unit = { ir match { - case x@AggLet(_, _, _, false) => - memo.bind(x, depth.agg) - case x@AggLet(_, _, _, true) => - memo.bind(x, depth.scan) + case _: Block | _: BaseRef => + memo.bind(ir, depth) case _ => - memo.bind(ir, depth.eval) } ir match { - case StreamMap(a, name, body) => + case StreamMap(a, _, body) => computeIR(a, depth) computeIR(body, depth.incrementEval) - case StreamAgg(a, name, body) => + case StreamAgg(a, _, body) => computeIR(a, depth) computeIR(body, ScopedDepth(depth.eval, depth.eval + 1, depth.scan)) - case StreamAggScan(a, name, body) => + case StreamAggScan(a, _, body) => computeIR(a, depth) computeIR(body, ScopedDepth(depth.eval, depth.agg, depth.eval + 1)) case StreamZip(as, _, body, _, _) => @@ -67,22 +75,22 @@ object NestingDepth { computeIR(contexts, depth) computeIR(makeProducer, 
depth.incrementEval) computeIR(joinF, depth.incrementEval) - case StreamFor(a, valueName, body) => + case StreamFor(a, _, body) => computeIR(a, depth) computeIR(body, depth.incrementEval) - case StreamFlatMap(a, name, body) => + case StreamFlatMap(a, _, body) => computeIR(a, depth) computeIR(body, depth.incrementEval) - case StreamFilter(a, name, cond) => + case StreamFilter(a, _, cond) => computeIR(a, depth) computeIR(cond, depth.incrementEval) - case StreamTakeWhile(a, name, cond) => + case StreamTakeWhile(a, _, cond) => computeIR(a, depth) computeIR(cond, depth.incrementEval) - case StreamDropWhile(a, name, cond) => + case StreamDropWhile(a, _, cond) => computeIR(a, depth) computeIR(cond, depth.incrementEval) - case StreamFold(a, zero, accumName, valueName, body) => + case StreamFold(a, zero, _, _, body) => computeIR(a, depth) computeIR(zero, depth) computeIR(body, depth.incrementEval) @@ -91,14 +99,18 @@ object NestingDepth { accum.foreach { case (_, value) => computeIR(value, depth) } seq.foreach(computeIR(_, depth.incrementEval)) computeIR(result, depth) - case StreamScan(a, zero, accumName, valueName, body) => + case StreamScan(a, zero, _, _, body) => computeIR(a, depth) computeIR(zero, depth) computeIR(body, depth.incrementEval) - case StreamJoinRightDistinct(left, right, lKey, rKey, l, r, joinF, joinType) => + case StreamJoinRightDistinct(left, right, _, _, _, _, joinF, _) => computeIR(left, depth) computeIR(right, depth) computeIR(joinF, depth.incrementEval) + case StreamLeftIntervalJoin(left, right, _, _, _, _, body) => + computeIR(left, depth) + computeIR(right, depth) + computeIR(body, depth.incrementEval) case TailLoop(_, params, _, body) => params.foreach { case (_, p) => computeIR(p, depth) } computeIR(body, depth.incrementEval) @@ -127,11 +139,11 @@ object NestingDepth { .zipWithIndex .foreach { case (child: IR, i) => if (UsesAggEnv(ir, i)) - computeIR(child, depth.promoteAgg) - else if (UsesScanEnv(ir, i)) - computeIR(child, depth.promoteScan) - else - computeIR(child, depth) + computeIR(child, depth.promoteAgg) + else if (UsesScanEnv(ir, i)) + computeIR(child, depth.promoteScan) + else + computeIR(child, depth) case (child: TableIR, _) => computeTable(child) case (child: MatrixIR, _) => computeMatrix(child) case (child: BlockMatrixIR, _) => computeBlockMatrix(child) @@ -146,6 +158,6 @@ object NestingDepth { case bmir: BlockMatrixIR => computeBlockMatrix(bmir) } - memo + new NestingDepth(memo) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala b/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala index 67db3bb4158..6e5e780da4a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala +++ b/hail/src/main/scala/is/hail/expr/ir/NormalizeNames.scala @@ -3,6 +3,8 @@ package is.hail.expr.ir import is.hail.backend.ExecuteContext import is.hail.utils.StackSafe._ +import scala.annotation.nowarn + class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = false) { var count: Int = 0 @@ -18,10 +20,12 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = normalizeIR(ir.noSharing(ctx), env).run().asInstanceOf[IR] def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR = - normalizeIR(ir.noSharing(ctx), BindingEnv(agg=Some(Env.empty), scan=Some(Env.empty))).run() + normalizeIR(ir.noSharing(ctx), BindingEnv(agg = Some(Env.empty), scan = Some(Env.empty))).run() - private def normalizeIR(ir: BaseIR, env: BindingEnv[String], context: Array[String] = Array()): StackFrame[BaseIR] = { + private def 
normalizeIR(ir: BaseIR, env: BindingEnv[String], context: Array[String] = Array()) + : StackFrame[BaseIR] = { + @nowarn("cat=unused-locals&msg=default argument") def normalizeBaseIR(next: BaseIR, env: BindingEnv[String] = env): StackFrame[BaseIR] = call(normalizeIR(next, env, context :+ ir.getClass().getName())) @@ -29,29 +33,32 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = call(normalizeIR(next, env, context :+ ir.getClass().getName()).asInstanceOf[StackFrame[IR]]) ir match { - case Let(bindings, body) => - val newBindings: Array[(String, IR)] = - Array.ofDim(bindings.length) + case Block(bindings, body) => + val newBindings: Array[Binding] = Array.ofDim(bindings.length) for { (env, _) <- bindings.foldLeft(done((env, 0))) { - case (get, (name, value)) => + case (get, Binding(name, value, scope)) => for { (env, idx) <- get - newValue <- normalize(value, env) - newName = gen() - _ = newBindings(idx) = newName -> newValue - } yield (env.bindEval(name, newName), idx + 1) + newValue <- normalize(value, env.promoteScope(scope)) + } yield { + val newName = gen() + newBindings(idx) = Binding(newName, newValue, scope) + (env.bindInScope(name, newName, scope), idx + 1) + } } newBody <- normalize(body, env) - } yield Let(newBindings, newBody) + } yield Block(newBindings, newBody) case Ref(name, typ) => val newName = env.eval.lookupOption(name) match { case Some(n) => n case None => if (!allowFreeVariables) - throw new RuntimeException(s"found free variable in normalize: $name, ${context.reverse.mkString(", ")}; ${env.pretty(x => x)}") + throw new RuntimeException( + s"found free variable in normalize: $name, ${context.reverse.mkString(", ")}; ${env.pretty(x => x)}" + ) else name } @@ -61,30 +68,25 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = case Some(n) => n case None => if (!allowFreeVariables) - throw new RuntimeException(s"found free loop variable in normalize: $name, ${context.reverse.mkString(", ")}; ${env.pretty(x => x)}") + throw new RuntimeException( + s"found free loop variable in normalize: $name, ${context.reverse.mkString(", ")}; ${env.pretty(x => x)}" + ) else name } for { newArgs <- args.mapRecur(v => normalize(v)) } yield Recur(newName, newArgs, typ) - case AggLet(name, value, body, isScan) => - val newName = gen() - val (valueEnv, bodyEnv) = if (isScan) - env.promoteScan -> env.bindScan(name, newName) - else - env.promoteAgg -> env.bindAgg(name, newName) - for { - newValue <- normalize(value, valueEnv) - newBody <- normalize(body, bodyEnv) - } yield AggLet(newName, newValue, newBody, isScan) case TailLoop(name, args, resultType, body) => val newFName = gen() val newNames = Array.tabulate(args.length)(i => gen()) val (names, values) = args.unzip for { newValues <- values.mapRecur(v => normalize(v)) - newBody <- normalize(body, env.copy(eval = env.eval.bind(names.zip(newNames) :+ name -> newFName: _*))) + newBody <- normalize( + body, + env.copy(eval = env.eval.bind(names.zip(newNames) :+ name -> newFName: _*)), + ) } yield TailLoop(newFName, newNames.zip(newValues), resultType, newBody) case ArraySort(a, left, right, lessThan) => val newLeft = gen() @@ -120,7 +122,16 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = newCtxs <- normalize(contexts) newMakeProducer <- normalize(makeProducer, env.bindEval(ctxName -> newCtxName)) newJoinF <- normalize(joinF, env.bindEval(curKey -> newCurKey, curVals -> newCurVals)) - } yield StreamZipJoinProducers(newCtxs, newCtxName, 
newMakeProducer, key, newCurKey, newCurVals, newJoinF) + } yield StreamZipJoinProducers(newCtxs, newCtxName, newMakeProducer, key, newCurKey, + newCurVals, newJoinF) + case StreamLeftIntervalJoin(left, right, lKeyNames, rIntrvlName, lEltName, rEltName, body) => + val newLName = gen() + val newRName = gen() + for { + newL <- normalize(left) + newR <- normalize(right) + newB <- normalize(body, env.bindEval(lEltName -> newLName, rEltName -> newRName)) + } yield StreamLeftIntervalJoin(newL, newR, lKeyNames, rIntrvlName, newLName, newRName, newB) case StreamFilter(a, name, body) => val newName = gen() for { @@ -131,13 +142,13 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = val newName = gen() for { newA <- normalize(a) - newBody <- normalize(body, env.bindEval(name, newName)) + newBody <- normalize(body, env.bindEval(name, newName)) } yield StreamTakeWhile(newA, newName, newBody) case StreamDropWhile(a, name, body) => val newName = gen() for { newA <- normalize(a) - newBody <- normalize(body, env.bindEval(name, newName)) + newBody <- normalize(body, env.bindEval(name, newName)) } yield StreamDropWhile(newA, newName, newBody) case StreamFlatMap(a, name, body) => val newName = gen() @@ -151,7 +162,8 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = for { newA <- normalize(a) newZero <- normalize(zero) - newBody <- normalize(body, env.bindEval(accumName -> newAccumName, valueName -> newValueName)) + newBody <- + normalize(body, env.bindEval(accumName -> newAccumName, valueName -> newValueName)) } yield StreamFold(newA, newZero, newAccumName, newValueName, newBody) case StreamFold2(a, accum, valueName, seq, res) => val newValueName = gen() @@ -175,7 +187,8 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = for { newA <- normalize(a) newZero <- normalize(zero) - newBody <- normalize(body, env.bindEval(accumName -> newAccumName, valueName -> newValueName)) + newBody <- + normalize(body, env.bindEval(accumName -> newAccumName, valueName -> newValueName)) } yield StreamScan(newA, newZero, newAccumName, newValueName, newBody) case StreamFor(a, valueName, body) => val newValueName = gen() @@ -190,8 +203,7 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = for { newA <- normalize(a) newBody <- normalize(body, env.copy(agg = Some(env.eval.bind(name, newName)))) - } yield - StreamAgg(newA, newName, newBody) + } yield StreamAgg(newA, newName, newBody) case RunAggScan(a, name, init, seq, result, sig) => val newName = gen() val e2 = env.bindEval(name, newName) @@ -218,7 +230,8 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = newLeft <- normalize(left) newRight <- normalize(right) newJoinF <- normalize(joinF, newEnv) - } yield StreamJoinRightDistinct(newLeft, newRight, lKey, rKey, newL, newR, newJoinF, joinType) + } yield StreamJoinRightDistinct(newLeft, newRight, lKey, rKey, newL, newR, newJoinF, + joinType) case NDArrayMap(nd, name, body) => val newName = gen() for { @@ -244,7 +257,8 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = newA <- normalize(a, aEnv) newAggBody <- normalize(aggBody, bodyEnv.bindEval(indexName, newIndexName)) newKnownLength <- knownLength.mapRecur(normalize(_, env)) - } yield AggArrayPerElement(newA, newElementName, newIndexName, newAggBody, newKnownLength, isScan) + } yield AggArrayPerElement(newA, newElementName, newIndexName, newAggBody, newKnownLength, + isScan) case 
CollectDistributedArray(ctxs, globals, cname, gname, body, dynamicID, staticID, tsd) => val newC = gen() val newG = gen() @@ -253,7 +267,8 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = newGlobals <- normalize(globals) newBody <- normalize(body, BindingEnv.eval(cname -> newC, gname -> newG)) newDynamicID <- normalize(dynamicID) - } yield CollectDistributedArray(newCtxs, newGlobals, newC, newG, newBody, newDynamicID, staticID, tsd) + } yield CollectDistributedArray(newCtxs, newGlobals, newC, newG, newBody, newDynamicID, + staticID, tsd) case RelationalLet(name, value, body) => val newName = gen() for { @@ -262,14 +277,19 @@ class NormalizeNames(normFunction: Int => String, allowFreeVariables: Boolean = } yield RelationalLet(newName, newValue, newBody) case RelationalRef(name, typ) => val newName = env.relational.lookupOption(name).getOrElse( - if (!allowFreeVariables) throw new RuntimeException(s"found free variable in normalize: $name, ${context.reverse.mkString(", ")}; ${env.pretty(x => x)}") + if (!allowFreeVariables) throw new RuntimeException( + s"found free variable in normalize: $name, ${context.reverse.mkString(", ")}; ${env.pretty(x => x)}" + ) else name ) done(RelationalRef(newName, typ)) case x => x.mapChildrenWithIndexStackSafe { (child, i) => - normalizeBaseIR(child, ChildBindings.transformed(x, i, env, { case (name, _) => name })) + normalizeBaseIR( + child, + Bindings.segregated(x, i, env).mapNewBindings((name, _) => name).unified, + ) } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/NumericPrimitives.scala b/hail/src/main/scala/is/hail/expr/ir/NumericPrimitives.scala index b44a6733a6a..20faa03ffcc 100644 --- a/hail/src/main/scala/is/hail/expr/ir/NumericPrimitives.scala +++ b/hail/src/main/scala/is/hail/expr/ir/NumericPrimitives.scala @@ -1,7 +1,7 @@ package is.hail.expr.ir -import is.hail.asm4s.{Settable, coerce} -import is.hail.types.virtual.{Type, TInt32, TInt64, TFloat32, TFloat64} +import is.hail.asm4s.{coerce, Settable} +import is.hail.types.virtual.{TFloat32, TFloat64, TInt32, TInt64, Type} object NumericPrimitives { diff --git a/hail/src/main/scala/is/hail/expr/ir/Optimize.scala b/hail/src/main/scala/is/hail/expr/ir/Optimize.scala index 49412f0d91b..ecb33cd669f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Optimize.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Optimize.scala @@ -2,7 +2,6 @@ package is.hail.expr.ir import is.hail.HailContext import is.hail.backend.ExecuteContext -import is.hail.utils._ object Optimize { def apply[T <: BaseIR](ir0: T, context: String, ctx: ExecuteContext): T = { @@ -11,9 +10,8 @@ object Optimize { var iter = 0 val maxIter = HailContext.get.optimizerIterations - def runOpt(f: BaseIR => BaseIR, iter: Int, optContext: String): Unit = { + def runOpt(f: BaseIR => BaseIR, iter: Int, optContext: String): Unit = ir = ctx.timer.time(optContext)(f(ir).asInstanceOf[T]) - } ctx.timer.time("Optimize") { val normalizeNames = new NormalizeNames(_ => genUID(), allowFreeVariables = true) @@ -33,10 +31,10 @@ object Optimize { if (ir.typ != ir0.typ) throw new RuntimeException(s"optimization changed type!" 
+ - s"\n before: ${ ir0.typ.parsableString() }" + - s"\n after: ${ ir.typ.parsableString() }" + - s"\n Before IR:\n ----------\n${ Pretty(ctx, ir0) }" + - s"\n After IR:\n ---------\n${ Pretty(ctx, ir) }") + s"\n before: ${ir0.typ.parsableString()}" + + s"\n after: ${ir.typ.parsableString()}" + + s"\n Before IR:\n ----------\n${Pretty(ctx, ir0)}" + + s"\n After IR:\n ---------\n${Pretty(ctx, ir)}") ir } diff --git a/hail/src/main/scala/is/hail/expr/ir/Param.scala b/hail/src/main/scala/is/hail/expr/ir/Param.scala index 3a574ef7e0a..f9935540e52 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Param.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Param.scala @@ -5,9 +5,7 @@ import is.hail.types.physical.stypes.{EmitType, SType, SValue, SingleCodeType} import is.hail.types.virtual.Type import is.hail.utils.FastSeq -import scala.language.existentials - -sealed trait ParamType { +sealed trait ParamType { def nCodes: Int } diff --git a/hail/src/main/scala/is/hail/expr/ir/Parser.scala b/hail/src/main/scala/is/hail/expr/ir/Parser.scala index 4edd4f9c5f8..2042db3857e 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Parser.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Parser.scala @@ -2,28 +2,30 @@ package is.hail.expr.ir import is.hail.HailContext import is.hail.backend.ExecuteContext +import is.hail.expr.{JSONAnnotationImpex, Nat, ParserUtils} import is.hail.expr.ir.agg._ import is.hail.expr.ir.functions.RelationalFunctions -import is.hail.expr.{JSONAnnotationImpex, Nat, ParserUtils} import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.rvd.{RVDPartitioner, RVDType} +import is.hail.types.{tcoerce, MatrixType, TableType, VirtualTypeWithReq} import is.hail.types.encoded.EType import is.hail.types.physical._ import is.hail.types.virtual._ -import is.hail.types.{MatrixType, TableType, VirtualTypeWithReq, tcoerce} +import is.hail.utils._ import is.hail.utils.StackSafe._ import is.hail.utils.StringEscapeUtils._ -import is.hail.utils._ -import org.apache.spark.sql.Row -import org.json4s.jackson.{JsonMethods, Serialization} -import org.json4s.{Formats, JObject} -import java.util.Base64 import scala.collection.mutable import scala.reflect.ClassTag import scala.util.parsing.combinator.JavaTokenParsers import scala.util.parsing.input.Positional +import java.util.Base64 + +import org.apache.spark.sql.Row +import org.json4s.{Formats, JObject} +import org.json4s.jackson.{JsonMethods, Serialization} + abstract class Token extends Positional { def value: Any @@ -114,31 +116,31 @@ object IRLexer extends JavaTokenParsers { def int64_literal: Parser[Long] = wholeNumber.map(_.toLong) def float64_literal: Parser[Double] = - "-inf" ^^ { _ => Double.NegativeInfinity } | // inf, neginf, and nan are parsed as identifiers + "-inf" ^^ { _ => Double.NegativeInfinity } | // inf, neginf, and nan are parsed as identifiers """[+-]?\d+(\.\d+)?[eE][+-]?\d+""".r ^^ { _.toDouble } | """[+-]?\d*\.\d+""".r ^^ { _.toDouble } - def parse(code: String): Array[Token] = { + def parse(code: String): Array[Token] = parseAll(lexer, code) match { case Success(result, _) => result case NoSuccess(msg, next) => ParserUtils.error(next.pos, msg) } - } } case class IRParserEnvironment( ctx: ExecuteContext, - irMap: Map[Int, BaseIR] = Map.empty) + irMap: Map[Int, BaseIR] = Map.empty, +) object IRParser { def error(t: Token, msg: String): Nothing = ParserUtils.error(t.pos, msg) def deserialize[T](str: String)(implicit formats: Formats, mf: Manifest[T]): T = { - try { + try Serialization.read[T](str) - } catch { + catch { case e: 
org.json4s.MappingException => - throw new RuntimeException(s"Couldn't deserialize ${str}", e) + throw new RuntimeException(s"Couldn't deserialize $str", e) } } @@ -149,37 +151,35 @@ object IRParser { it.next() } - def punctuation(it: TokenIterator, symbol: String): String = { + def punctuation(it: TokenIterator, symbol: String): String = consumeToken(it) match { case x: PunctuationToken if x.value == symbol => x.value - case x: Token => error(x, s"Expected punctuation '$symbol' but found ${ x.getName } '${ x.value }'.") + case x: Token => + error(x, s"Expected punctuation '$symbol' but found ${x.getName} '${x.value}'.") } - } - def identifier(it: TokenIterator): String = { + def identifier(it: TokenIterator): String = consumeToken(it) match { case x: IdentifierToken => x.value - case x: Token => error(x, s"Expected identifier but found ${ x.getName } '${ x.value }'.") + case x: Token => error(x, s"Expected identifier but found ${x.getName} '${x.value}'.") } - } - def identifier(it: TokenIterator, expectedId: String): String = { + def identifier(it: TokenIterator, expectedId: String): String = consumeToken(it) match { case x: IdentifierToken if x.value == expectedId => x.value - case x: Token => error(x, s"Expected identifier '$expectedId' but found ${ x.getName } '${ x.value }'.") + case x: Token => + error(x, s"Expected identifier '$expectedId' but found ${x.getName} '${x.value}'.") } - } def identifiers(it: TokenIterator): Array[String] = base_seq_parser(identifier)(it) - def boolean_literal(it: TokenIterator): Boolean = { + def boolean_literal(it: TokenIterator): Boolean = consumeToken(it) match { case IdentifierToken("True") => true case IdentifierToken("False") => false - case x: Token => error(x, s"Expected boolean but found ${ x.getName } '${ x.value }'.") + case x: Token => error(x, s"Expected boolean but found ${x.getName} '${x.value}'.") } - } def int32_literal(it: TokenIterator): Int = { consumeToken(it) match { @@ -187,17 +187,16 @@ object IRParser { if (x.value >= Int.MinValue && x.value <= Int.MaxValue) x.value.toInt else - error(x, s"Found integer '${ x.value }' that is outside the numeric range for int32.") - case x: Token => error(x, s"Expected integer but found ${ x.getName } '${ x.value }'.") + error(x, s"Found integer '${x.value}' that is outside the numeric range for int32.") + case x: Token => error(x, s"Expected integer but found ${x.getName} '${x.value}'.") } } - def int64_literal(it: TokenIterator): Long = { + def int64_literal(it: TokenIterator): Long = consumeToken(it) match { case x: IntegerToken => x.value - case x: Token => error(x, s"Expected integer but found ${ x.getName } '${ x.value }'.") + case x: Token => error(x, s"Expected integer but found ${x.getName} '${x.value}'.") } - } def float32_literal(it: TokenIterator): Float = { consumeToken(it) match { @@ -205,15 +204,15 @@ object IRParser { if (x.value >= Float.MinValue && x.value <= Float.MaxValue) x.value.toFloat else - error(x, s"Found float '${ x.value }' that is outside the numeric range for float32.") + error(x, s"Found float '${x.value}' that is outside the numeric range for float32.") case x: IntegerToken => x.value.toFloat case x: IdentifierToken => x.value match { - case "nan" => Float.NaN - case "inf" => Float.PositiveInfinity - case "neginf" => Float.NegativeInfinity - case _ => error(x, s"Expected float but found ${ x.getName } '${ x.value }'.") - } - case x: Token => error(x, s"Expected float but found ${ x.getName } '${ x.value }'.") + case "nan" => Float.NaN + case "inf" => 
Float.PositiveInfinity + case "neginf" => Float.NegativeInfinity + case _ => error(x, s"Expected float but found ${x.getName} '${x.value}'.") + } + case x: Token => error(x, s"Expected float but found ${x.getName} '${x.value}'.") } } @@ -222,42 +221,54 @@ object IRParser { case x: FloatToken => x.value case x: IntegerToken => x.value.toDouble case x: IdentifierToken => x.value match { - case "nan" => Double.NaN - case "inf" => Double.PositiveInfinity - case "neginf" => Double.NegativeInfinity - case _ => error(x, s"Expected float but found ${ x.getName } '${ x.value }'.") - } - case x: Token => error(x, s"Expected float but found ${ x.getName } '${ x.value }'.") + case "nan" => Double.NaN + case "inf" => Double.PositiveInfinity + case "neginf" => Double.NegativeInfinity + case _ => error(x, s"Expected float but found ${x.getName} '${x.value}'.") + } + case x: Token => error(x, s"Expected float but found ${x.getName} '${x.value}'.") } } - def string_literal(it: TokenIterator): String = { + def string_literal(it: TokenIterator): String = consumeToken(it) match { case x: StringToken => x.value - case x: Token => error(x, s"Expected string but found ${ x.getName } '${ x.value }'.") + case x: Token => error(x, s"Expected string but found ${x.getName} '${x.value}'.") } - } def partitioner_literal(env: IRParserEnvironment)(it: TokenIterator): RVDPartitioner = { identifier(it, "Partitioner") val keyType = type_expr(it).asInstanceOf[TStruct] val vJSON = JsonMethods.parse(string_literal(it)) val rangeBounds = JSONAnnotationImpex.importAnnotation(vJSON, TArray(TInterval(keyType))) - new RVDPartitioner(env.ctx.stateManager, keyType, rangeBounds.asInstanceOf[mutable.IndexedSeq[Interval]]) + new RVDPartitioner( + env.ctx.stateManager, + keyType, + rangeBounds.asInstanceOf[mutable.IndexedSeq[Interval]], + ) } - def literals[T](literalIdentifier: TokenIterator => T)(it: TokenIterator)(implicit tct: ClassTag[T]): Array[T] = + def literals[T]( + literalIdentifier: TokenIterator => T + )( + it: TokenIterator + )(implicit tct: ClassTag[T] + ): Array[T] = base_seq_parser(literalIdentifier)(it) - def between[A](open: TokenIterator => Any, close: TokenIterator => Any, f: TokenIterator => A) - (it: TokenIterator): A = { + def between[A]( + open: TokenIterator => Any, + close: TokenIterator => Any, + f: TokenIterator => A, + )( + it: TokenIterator + ): A = { open(it) val a = f(it) close(it) a } - def string_literals: TokenIterator => Array[String] = literals(string_literal) def int32_literals: TokenIterator => Array[Int] = literals(int32_literal) def int64_literals: TokenIterator => Array[Long] = literals(int64_literal) @@ -272,10 +283,13 @@ object IRParser { } } - def repsepUntil[T](it: TokenIterator, + def repsepUntil[T]( + it: TokenIterator, f: (TokenIterator) => T, sep: Token, - end: Token)(implicit tct: ClassTag[T]): Array[T] = { + end: Token, + )(implicit tct: ClassTag[T] + ): Array[T] = { val xs = new mutable.ArrayBuffer[T]() while (it.hasNext && it.head != end) { xs += f(it) @@ -285,18 +299,20 @@ object IRParser { xs.toArray } - def repUntil[T](it: TokenIterator, + def repUntil[T]( + it: TokenIterator, f: (TokenIterator) => StackFrame[T], - end: Token)(implicit tct: ClassTag[T]): StackFrame[Array[T]] = { + end: Token, + )(implicit tct: ClassTag[T] + ): StackFrame[Array[T]] = { val xs = new mutable.ArrayBuffer[T]() var cont: T => StackFrame[Array[T]] = null - def loop(): StackFrame[Array[T]] = { + def loop(): StackFrame[Array[T]] = if (it.hasNext && it.head != end) { f(it).flatMap(cont) } else { 
done(xs.toArray) } - } cont = { t => xs += t loop() @@ -304,17 +320,19 @@ object IRParser { loop() } - def repUntilNonStackSafe[T](it: TokenIterator, + def repUntilNonStackSafe[T]( + it: TokenIterator, f: (TokenIterator) => T, - end: Token)(implicit tct: ClassTag[T]): Array[T] = { + end: Token, + )(implicit tct: ClassTag[T] + ): Array[T] = { val xs = new mutable.ArrayBuffer[T]() - while (it.hasNext && it.head != end) { + while (it.hasNext && it.head != end) xs += f(it) - } xs.toArray } - def base_seq_parser[T : ClassTag](f: TokenIterator => T)(it: TokenIterator): Array[T] = { + def base_seq_parser[T: ClassTag](f: TokenIterator => T)(it: TokenIterator): Array[T] = { punctuation(it, "(") val r = repUntilNonStackSafe(it, f, PunctuationToken(")")) punctuation(it, ")") @@ -336,7 +354,6 @@ object IRParser { i -> t } - def tuple_subset_field(it: TokenIterator): (Int, Type) = { val i = int32_literal(it) punctuation(it, ":") @@ -348,19 +365,16 @@ object IRParser { val name = identifier(it) punctuation(it, ":") val typ = f(it) - while (it.hasNext && it.head == PunctuationToken("@")) { + while (it.hasNext && it.head == PunctuationToken("@")) decorator(it) - } (name, typ) } - def ptype_field(it: TokenIterator): (String, PType) = { + def ptype_field(it: TokenIterator): (String, PType) = struct_field(ptype_expr)(it) - } - def type_field(it: TokenIterator): (String, Type) = { + def type_field(it: TokenIterator): (String, Type) = struct_field(type_expr)(it) - } def vtwr_expr(it: TokenIterator): VirtualTypeWithReq = { val pt = ptype_expr(it) @@ -420,9 +434,10 @@ object IRParser { PCanonicalDict(keyType, valueType, req) case "PCTuple" => punctuation(it, "[") - val fields = repsepUntil(it, ptuple_subset_field, PunctuationToken(","), PunctuationToken("]")) + val fields = + repsepUntil(it, ptuple_subset_field, PunctuationToken(","), PunctuationToken("]")) punctuation(it, "]") - PCanonicalTuple(fields.map { case (idx, t) => PTupleField(idx, t)}, req) + PCanonicalTuple(fields.map { case (idx, t) => PTupleField(idx, t) }, req) case "PCStruct" => punctuation(it, "{") val args = repsepUntil(it, ptype_field, PunctuationToken(","), PunctuationToken("}")) @@ -510,9 +525,10 @@ object IRParser { TTuple(types: _*) case "TupleSubset" => punctuation(it, "[") - val fields = repsepUntil(it, tuple_subset_field, PunctuationToken(","), PunctuationToken("]")) + val fields = + repsepUntil(it, tuple_subset_field, PunctuationToken(","), PunctuationToken("]")) punctuation(it, "]") - TTuple(fields.map { case (idx, t) => TupleField(idx, t)}) + TTuple(fields.map { case (idx, t) => TupleField(idx, t) }) case "Struct" => punctuation(it, "{") val args = repsepUntil(it, type_field, PunctuationToken(","), PunctuationToken("}")) @@ -633,7 +649,14 @@ object IRParser { val entryType = tcoerce[TStruct](type_expr(it)) punctuation(it, "}") - MatrixType(tcoerce[TStruct](globalType), colKey, colType, rowPartitionKey ++ rowRestKey, rowType, entryType) + MatrixType( + tcoerce[TStruct](globalType), + colKey, + colType, + rowPartitionKey ++ rowRestKey, + rowType, + entryType, + ) } def agg_op(it: TokenIterator): AggOp = @@ -735,7 +758,8 @@ object IRParser { (typ, v) } - def named_value_irs(env: IRParserEnvironment)(it: TokenIterator): StackFrame[Array[(String, IR)]] = + def named_value_irs(env: IRParserEnvironment)(it: TokenIterator) + : StackFrame[Array[(String, IR)]] = repUntil(it, named_value_ir(env), PunctuationToken(")")) def named_value_ir(env: IRParserEnvironment)(it: TokenIterator): StackFrame[(String, IR)] = { @@ -766,14 +790,17 @@ object 
IRParser { } yield ir } - def apply_like(env: IRParserEnvironment, cons: (String, Seq[Type], Seq[IR], Type, Int) => IR)(it: TokenIterator): StackFrame[IR] = { + def apply_like( + env: IRParserEnvironment, + cons: (String, Seq[Type], Seq[IR], Type, Int) => IR, + )( + it: TokenIterator + ): StackFrame[IR] = { val errorID = int32_literal(it) val function = identifier(it) val typeArgs = type_exprs(it) val rt = type_expr(it) - ir_value_children(env)(it).map { args => - cons(function, typeArgs, args, rt, errorID) - } + ir_value_children(env)(it).map(args => cons(function, typeArgs, args, rt, errorID)) } def ir_value_expr_1(env: IRParserEnvironment)(it: TokenIterator): StackFrame[IR] = { @@ -795,7 +822,7 @@ object IRParser { val codec = TypedCodecSpec( EType.fromPythonTypeEncoding(typ), typ, - BufferSpec.unblockedUncompressed + BufferSpec.unblockedUncompressed, ) done(EncodedLiteral(codec, Array(encodedValue))) case "Void" => done(Void()) @@ -824,8 +851,9 @@ object IRParser { default <- ir_value_expr(env)(it) cases <- ir_value_children(env)(it) } yield Switch(x, default, cases) - case "Let" => - val names = repUntilNonStackSafe(it, identifier, PunctuationToken("(")) + case "Let" | "Block" => + val names = + repUntilNonStackSafe(it, it => (identifier(it), identifier(it)), PunctuationToken("(")) val values = new Array[IR](names.length) for { _ <- names.indices.foldLeft(done(())) { case (update, i) => @@ -835,7 +863,17 @@ object IRParser { } yield values.update(i, value) } body <- ir_value_expr(env)(it) - } yield Let(names.zip(values).toFastSeq, body) + } yield { + val bindings = (names, values).zipped.map { case ((bindType, name), value) => + val scope = bindType match { + case "eval" => Scope.EVAL + case "agg" => Scope.AGG + case "scan" => Scope.SCAN + } + Binding(name, value, scope) + } + Block(bindings, body) + } case "AggLet" => val name = identifier(it) val isScan = boolean_literal(it) @@ -854,9 +892,7 @@ object IRParser { } yield TailLoop(name, params, resultType, body) case "Recur" => val name = identifier(it) - ir_value_children(env)(it).map { args => - Recur(name, args, null) - } + ir_value_children(env)(it).map(args => Recur(name, args, null)) case "Ref" => val id = identifier(it) done(Ref(id, null)) @@ -887,9 +923,7 @@ object IRParser { } yield ApplyComparisonOp(ComparisonOp.fromString(opName), l, r) case "MakeArray" => val typ = opt(it, type_expr).map(_.asInstanceOf[TArray]).orNull - ir_value_children(env)(it).map { args => - MakeArray(args, typ) - } + ir_value_children(env)(it).map(args => MakeArray(args, typ)) case "MakeStream" => val typ = opt(it, type_expr).map(_.asInstanceOf[TStream]).orNull val requiresMemoryManagementPerElement = boolean_literal(it) @@ -972,9 +1006,7 @@ object IRParser { } yield NDArrayReshape(nd, shape, errorID) case "NDArrayConcat" => val axis = int32_literal(it) - ir_value_expr(env)(it).map { nds => - NDArrayConcat(nds, axis) - } + ir_value_expr(env)(it).map(nds => NDArrayConcat(nds, axis)) case "NDArrayMap" => val name = identifier(it) for { @@ -992,14 +1024,10 @@ object IRParser { } yield NDArrayMap2(l, r, lName, rName, body, errorID) case "NDArrayReindex" => val indexExpr = int32_literals(it) - ir_value_expr(env)(it).map { nd => - NDArrayReindex(nd, indexExpr) - } + ir_value_expr(env)(it).map(nd => NDArrayReindex(nd, indexExpr)) case "NDArrayAgg" => val axes = int32_literals(it) - ir_value_expr(env)(it).map { nd => - NDArrayAgg(nd, axes) - } + ir_value_expr(env)(it).map(nd => NDArrayAgg(nd, axes)) case "NDArrayRef" => val errorID = int32_literal(it) for { 
@@ -1030,34 +1058,26 @@ object IRParser { case "NDArrayQR" => val errorID = int32_literal(it) val mode = string_literal(it) - ir_value_expr(env)(it).map { nd => - NDArrayQR(nd, mode, errorID) - } + ir_value_expr(env)(it).map(nd => NDArrayQR(nd, mode, errorID)) case "NDArraySVD" => val errorID = int32_literal(it) val fullMatrices = boolean_literal(it) val computeUV = boolean_literal(it) - ir_value_expr(env)(it).map { nd => - NDArraySVD(nd, fullMatrices, computeUV, errorID) - } + ir_value_expr(env)(it).map(nd => NDArraySVD(nd, fullMatrices, computeUV, errorID)) case "NDArrayEigh" => val errorID = int32_literal(it) val eigvalsOnly = boolean_literal(it) - ir_value_expr(env)(it).map { nd => - NDArrayEigh(nd, eigvalsOnly, errorID) - } + ir_value_expr(env)(it).map(nd => NDArrayEigh(nd, eigvalsOnly, errorID)) case "NDArrayInv" => val errorID = int32_literal(it) - ir_value_expr(env)(it).map{ nd => NDArrayInv(nd, errorID) } + ir_value_expr(env)(it).map(nd => NDArrayInv(nd, errorID)) case "ToSet" => ir_value_expr(env)(it).map(ToSet) case "ToDict" => ir_value_expr(env)(it).map(ToDict) case "ToArray" => ir_value_expr(env)(it).map(ToArray) case "CastToArray" => ir_value_expr(env)(it).map(CastToArray) case "ToStream" => val requiresMemoryManagementPerElement = boolean_literal(it) - ir_value_expr(env)(it).map { a => - ToStream(a, requiresMemoryManagementPerElement) - } + ir_value_expr(env)(it).map(a => ToStream(a, requiresMemoryManagementPerElement)) case "LowerBoundOnOrderedCollection" => val onKey = boolean_literal(it) for { @@ -1102,9 +1122,8 @@ object IRParser { for { ctxs <- ir_value_expr(env)(it) makeProducer <- ir_value_expr(env)(it) - body <- { + body <- ir_value_expr(env)(it) - } } yield StreamZipJoinProducers(ctxs, ctxName, makeProducer, key, curKey, curVals, body) case "StreamZipJoin" => val nStreams = int32_literal(it) @@ -1113,9 +1132,8 @@ object IRParser { val curVals = identifier(it) for { streams <- (0 until nStreams).mapRecur(_ => ir_value_expr(env)(it)) - body <- { + body <- ir_value_expr(env)(it) - } } yield StreamZipJoin(streams, key, curKey, curVals, body) case "StreamMultiMerge" => val key = identifiers(it) @@ -1132,13 +1150,13 @@ object IRParser { val name = identifier(it) for { a <- ir_value_expr(env)(it) - body <- ir_value_expr(env)(it) + body <- ir_value_expr(env)(it) } yield StreamTakeWhile(a, name, body) case "StreamDropWhile" => val name = identifier(it) for { a <- ir_value_expr(env)(it) - body <- ir_value_expr(env)(it) + body <- ir_value_expr(env)(it) } yield StreamDropWhile(a, name, body) case "StreamFlatMap" => val name = identifier(it) @@ -1182,7 +1200,8 @@ object IRParser { val normalizeAfterWhitening = boolean_literal(it) for { stream <- ir_value_expr(env)(it) - } yield StreamWhiten(stream, newChunk, prevWindow, vecSize, windowSize, chunkSize, blockSize, normalizeAfterWhitening) + } yield StreamWhiten(stream, newChunk, prevWindow, vecSize, windowSize, chunkSize, + blockSize, normalizeAfterWhitening) case "StreamJoinRightDistinct" => val lKey = identifiers(it) val rKey = identifiers(it) @@ -1194,6 +1213,18 @@ object IRParser { right <- ir_value_expr(env)(it) join <- ir_value_expr(env)(it) } yield StreamJoinRightDistinct(left, right, lKey, rKey, l, r, join, joinType) + case "StreamLeftIntervalJoin" => + val lKeyFieldName = identifier(it) + val rIntervalName = identifier(it) + val lname = identifier(it) + val rname = identifier(it) + for { + left <- ir_value_expr(env)(it) + right <- ir_value_expr(env)(it) + body <- ir_value_expr(env)(it) + } yield 
StreamLeftIntervalJoin(left, right, lKeyFieldName, rIntervalName, lname, rname, + body) + case "StreamFor" => val name = identifier(it) for { @@ -1282,15 +1313,11 @@ object IRParser { case "InitOp" => val i = int32_literal(it) val aggSig = p_agg_sig(env)(it) - ir_value_exprs(env)(it).map { args => - InitOp(i, args, aggSig) - } + ir_value_exprs(env)(it).map(args => InitOp(i, args, aggSig)) case "SeqOp" => val i = int32_literal(it) val aggSig = p_agg_sig(env)(it) - ir_value_exprs(env)(it).map { args => - SeqOp(i, args, aggSig) - } + ir_value_exprs(env)(it).map(args => SeqOp(i, args, aggSig)) case "CombOp" => val i1 = int32_literal(it) val i2 = int32_literal(it) @@ -1307,15 +1334,11 @@ object IRParser { case "InitFromSerializedValue" => val i = int32_literal(it) val sig = agg_state_signature(env)(it) - ir_value_expr(env)(it).map { value => - InitFromSerializedValue(i, value, sig) - } + ir_value_expr(env)(it).map(value => InitFromSerializedValue(i, value, sig)) case "CombOpValue" => val i = int32_literal(it) val sig = p_agg_sig(env)(it) - ir_value_expr(env)(it).map { value => - CombOpValue(i, value, sig) - } + ir_value_expr(env)(it).map(value => CombOpValue(i, value, sig)) case "SerializeAggs" => val i = int32_literal(it) val i2 = int32_literal(it) @@ -1332,9 +1355,7 @@ object IRParser { case "MakeStruct" => named_value_irs(env)(it).map(MakeStruct(_)) case "SelectFields" => val fields = identifiers(it) - ir_value_expr(env)(it).map { old => - SelectFields(old, fields) - } + ir_value_expr(env)(it).map(old => SelectFields(old, fields)) case "InsertFields" => for { old <- ir_value_expr(env)(it) @@ -1343,29 +1364,19 @@ object IRParser { } yield InsertFields(old, fields, fieldOrder.map(_.toFastSeq)) case "GetField" => val name = identifier(it) - ir_value_expr(env)(it).map { s => - GetField(s, name) - } + ir_value_expr(env)(it).map(s => GetField(s, name)) case "MakeTuple" => val indices = int32_literals(it) - ir_value_children(env)(it).map { args => - MakeTuple(indices.zip(args)) - } + ir_value_children(env)(it).map(args => MakeTuple(indices.zip(args))) case "GetTupleElement" => val idx = int32_literal(it) - ir_value_expr(env)(it).map { tuple => - GetTupleElement(tuple, idx) - } + ir_value_expr(env)(it).map(tuple => GetTupleElement(tuple, idx)) case "Die" => val typ = type_expr(it) val errorID = int32_literal(it) - ir_value_expr(env)(it).map { msg => - Die(msg, typ, errorID) - } + ir_value_expr(env)(it).map(msg => Die(msg, typ, errorID)) case "Trap" => - ir_value_expr(env)(it).map { child => - Trap(child) - } + ir_value_expr(env)(it).map(child => Trap(child)) case "ConsoleLog" => for { msg <- ir_value_expr(env)(it) @@ -1411,16 +1422,17 @@ object IRParser { case "BlockMatrixToValueApply" => val config = string_literal(it) blockmatrix_ir(env)(it).map { child => - BlockMatrixToValueApply(child, RelationalFunctions.lookupBlockMatrixToValue(env.ctx, config)) + BlockMatrixToValueApply( + child, + RelationalFunctions.lookupBlockMatrixToValue(env.ctx, config), + ) } case "BlockMatrixCollect" => blockmatrix_ir(env)(it).map(BlockMatrixCollect) case "TableWrite" => implicit val formats = TableWriter.formats val writerStr = string_literal(it) - table_ir(env)(it).map { child => - TableWrite(child, deserialize[TableWriter](writerStr)) - } + table_ir(env)(it).map(child => TableWrite(child, deserialize[TableWriter](writerStr))) case "TableMultiWrite" => implicit val formats = WrappedMatrixNativeMultiWriter.formats val writerStr = string_literal(it) @@ -1436,23 +1448,17 @@ object IRParser { val writerStr = 
string_literal(it) implicit val formats: Formats = MatrixWriter.formats val writer = deserialize[MatrixWriter](writerStr) - matrix_ir(env)(it).map { child => - MatrixWrite(child, writer) - } + matrix_ir(env)(it).map(child => MatrixWrite(child, writer)) case "MatrixMultiWrite" => val writerStr = string_literal(it) implicit val formats = MatrixNativeMultiWriter.formats val writer = deserialize[MatrixNativeMultiWriter](writerStr) - matrix_ir_children(env)(it).map { children => - MatrixMultiWrite(children, writer) - } + matrix_ir_children(env)(it).map(children => MatrixMultiWrite(children, writer)) case "BlockMatrixWrite" => val writerStr = string_literal(it) implicit val formats: Formats = BlockMatrixWriter.formats val writer = deserialize[BlockMatrixWriter](writerStr) - blockmatrix_ir(env)(it).map { child => - BlockMatrixWrite(child, writer) - } + blockmatrix_ir(env)(it).map(child => BlockMatrixWrite(child, writer)) case "BlockMatrixMultiWrite" => val writerStr = string_literal(it) implicit val formats: Formats = BlockMatrixWriter.formats @@ -1483,11 +1489,15 @@ object IRParser { } val reader = PartitionReader.extract(env.ctx, JsonMethods.parse(string_literal(it))) ir_value_expr(env)(it).map { context => - ReadPartition(context, requestedTypeRaw match { - case Left("None") => reader.fullRowType - case Left("DropRowUIDs") => reader.fullRowType.deleteKey(reader.uidFieldName) - case Right(t) => t.asInstanceOf[TStruct] - }, reader) + ReadPartition( + context, + requestedTypeRaw match { + case Left("None") => reader.fullRowType + case Left("DropRowUIDs") => reader.fullRowType.deleteKey(reader.uidFieldName) + case Right(t) => t.asInstanceOf[TStruct] + }, + reader, + ) } case "WritePartition" => import PartitionWriter.formats @@ -1499,16 +1509,12 @@ object IRParser { case "WriteMetadata" => import MetadataWriter.formats val writer = JsonMethods.parse(string_literal(it)).extract[MetadataWriter] - ir_value_expr(env)(it).map { ctx => - WriteMetadata(ctx, writer) - } + ir_value_expr(env)(it).map(ctx => WriteMetadata(ctx, writer)) case "ReadValue" => import ValueReader.formats val reader = JsonMethods.parse(string_literal(it)).extract[ValueReader] val typ = type_expr(it) - ir_value_expr(env)(it).map { path => - ReadValue(path, reader, typ) - } + ir_value_expr(env)(it).map(path => ReadValue(path, reader, typ)) case "WriteValue" => import ValueWriter.formats val writer = JsonMethods.parse(string_literal(it)).extract[ValueWriter] @@ -1521,9 +1527,7 @@ object IRParser { val rowType = tcoerce[TStruct](type_expr(it)) import PartitionReader.formats val reader = JsonMethods.parse(string_literal(it)).extract[PartitionReader] - ir_value_expr(env)(it).map { context => - ReadPartition(context, rowType, reader) - } + ir_value_expr(env)(it).map(context => ReadPartition(context, rowType, reader)) } } @@ -1551,9 +1555,7 @@ object IRParser { case "TableKeyBy" => val keys = identifiers(it) val isSorted = boolean_literal(it) - table_ir(env)(it).map { child => - TableKeyBy(child, keys, isSorted) - } + table_ir(env)(it).map(child => TableKeyBy(child, keys, isSorted)) case "TableDistinct" => table_ir(env)(it).map(TableDistinct) case "TableFilter" => for { @@ -1570,10 +1572,12 @@ object IRParser { } val dropRows = boolean_literal(it) val readerStr = string_literal(it) - val reader = TableReader.fromJValue(env.ctx.fs, JsonMethods.parse(readerStr).asInstanceOf[JObject]) + val reader = + TableReader.fromJValue(env.ctx.fs, JsonMethods.parse(readerStr).asInstanceOf[JObject]) val requestedType = requestedTypeRaw match { case 
Left("None") => reader.fullType - case Left("DropRowUIDs") => reader.asInstanceOf[TableReaderWithExtraUID].fullTypeWithoutUIDs + case Left("DropRowUIDs") => + reader.asInstanceOf[TableReaderWithExtraUID].fullTypeWithoutUIDs case Right(t) => t } done(TableRead(requestedType, dropRows, reader)) @@ -1596,19 +1600,13 @@ object IRParser { case "TableRepartition" => val n = int32_literal(it) val strategy = int32_literal(it) - table_ir(env)(it).map { child => - TableRepartition(child, n, strategy) - } + table_ir(env)(it).map(child => TableRepartition(child, n, strategy)) case "TableHead" => val n = int64_literal(it) - table_ir(env)(it).map { child => - TableHead(child, n) - } + table_ir(env)(it).map(child => TableHead(child, n)) case "TableTail" => val n = int64_literal(it) - table_ir(env)(it).map { child => - TableTail(child, n) - } + table_ir(env)(it).map(child => TableTail(child, n)) case "TableJoin" => val joinType = identifier(it) val joinKey = int32_literal(it) @@ -1637,9 +1635,7 @@ object IRParser { } case "TableParallelize" => val nPartitions = opt(it, int32_literal) - ir_value_expr(env)(it).map { rowsAndGlobal => - TableParallelize(rowsAndGlobal, nPartitions) - } + ir_value_expr(env)(it).map(rowsAndGlobal => TableParallelize(rowsAndGlobal, nPartitions)) case "TableMapRows" => for { child <- table_ir(env)(it) @@ -1657,20 +1653,14 @@ object IRParser { case "TableUnion" => table_ir_children(env)(it).map(TableUnion(_)) case "TableOrderBy" => val sortFields = sort_fields(it) - table_ir(env)(it).map { child => - TableOrderBy(child, sortFields) - } + table_ir(env)(it).map(child => TableOrderBy(child, sortFields)) case "TableExplode" => val path = string_literals(it) - table_ir(env)(it).map { child => - TableExplode(child, path) - } + table_ir(env)(it).map(child => TableExplode(child, path)) case "CastMatrixToTable" => val entriesField = string_literal(it) val colsField = string_literal(it) - matrix_ir(env)(it).map { child => - CastMatrixToTable(child, entriesField, colsField) - } + matrix_ir(env)(it).map(child => CastMatrixToTable(child, entriesField, colsField)) case "MatrixToTableApply" => val config = string_literal(it) matrix_ir(env)(it).map { child => @@ -1686,7 +1676,11 @@ object IRParser { for { bm <- blockmatrix_ir(env)(it) aux <- ir_value_expr(env)(it) - } yield BlockMatrixToTableApply(bm, aux, RelationalFunctions.lookupBlockMatrixToTable(env.ctx, config)) + } yield BlockMatrixToTableApply( + bm, + aux, + RelationalFunctions.lookupBlockMatrixToTable(env.ctx, config), + ) case "BlockMatrixToTable" => blockmatrix_ir(env)(it).map(BlockMatrixToTable) case "TableRename" => val rowK = string_literals(it) @@ -1700,7 +1694,8 @@ object IRParser { case "TableGen" => val cname = identifier(it) val gname = identifier(it) - val partitioner = between(punctuation(_, "("), punctuation(_, ")"), partitioner_literal(env))(it) + val partitioner = + between(punctuation(_, "("), punctuation(_, ")"), partitioner_literal(env))(it) val errorId = int32_literal(it) for { contexts <- ir_value_expr(env)(it) @@ -1713,11 +1708,15 @@ object IRParser { val intervals = string_literal(it) val keep = boolean_literal(it) table_ir(env)(it).map { child => - TableFilterIntervals(child, - JSONAnnotationImpex.importAnnotation(JsonMethods.parse(intervals), + TableFilterIntervals( + child, + JSONAnnotationImpex.importAnnotation( + JsonMethods.parse(intervals), TArray(TInterval(keyType)), - padNulls = false).asInstanceOf[IndexedSeq[Interval]], - keep) + padNulls = false, + ).asInstanceOf[IndexedSeq[Interval]], + keep, + ) } case 
"TableMapPartitions" => val globalsName = identifier(it) @@ -1727,7 +1726,8 @@ object IRParser { for { child <- table_ir(env)(it) body <- ir_value_expr(env)(it) - } yield TableMapPartitions(child, globalsName, partitionStreamName, body, requestedKey, allowedOverlap) + } yield TableMapPartitions(child, globalsName, partitionStreamName, body, requestedKey, + allowedOverlap) case "RelationalLetTable" => val name = identifier(it) for { @@ -1777,9 +1777,7 @@ object IRParser { case "MatrixKeyRowsBy" => val key = identifiers(it) val isSorted = boolean_literal(it) - matrix_ir(env)(it).map { child => - MatrixKeyRowsBy(child, key, isSorted) - } + matrix_ir(env)(it).map(child => MatrixKeyRowsBy(child, key, isSorted)) case "MatrixMapRows" => for { child <- matrix_ir(env)(it) @@ -1815,7 +1813,8 @@ object IRParser { } yield MatrixAggregateRowsByKey(child, entryExpr, rowExpr) case "MatrixRead" => val requestedTypeRaw = it.head match { - case x: IdentifierToken if x.value == "None" || x.value == "DropColUIDs" || x.value == "DropRowUIDs" || x.value == "DropRowColUIDs" => + case x: IdentifierToken + if x.value == "None" || x.value == "DropColUIDs" || x.value == "DropRowUIDs" || x.value == "DropRowColUIDs" => consumeToken(it) Left(x.value) case _ => @@ -1829,12 +1828,15 @@ object IRParser { val requestedType = requestedTypeRaw match { case Left("None") => fullType case Left("DropRowUIDs") => fullType.copy( - rowType = fullType.rowType.deleteKey(reader.rowUIDFieldName)) + rowType = fullType.rowType.deleteKey(reader.rowUIDFieldName) + ) case Left("DropColUIDs") => fullType.copy( - colType = fullType.colType.deleteKey(reader.colUIDFieldName)) + colType = fullType.colType.deleteKey(reader.colUIDFieldName) + ) case Left("DropRowColUIDs") => fullType.copy( - rowType = fullType.rowType.deleteKey(reader.rowUIDFieldName), - colType = fullType.colType.deleteKey(reader.colUIDFieldName)) + rowType = fullType.rowType.deleteKey(reader.rowUIDFieldName), + colType = fullType.colType.deleteKey(reader.colUIDFieldName), + ) case Right(t) => t } done(MatrixRead(requestedType, dropCols, dropRows, reader)) @@ -1853,56 +1855,38 @@ object IRParser { } yield MatrixAnnotateColsTable(child, table, root) case "MatrixExplodeRows" => val path = identifiers(it) - matrix_ir(env)(it).map { child => - MatrixExplodeRows(child, path) - } + matrix_ir(env)(it).map(child => MatrixExplodeRows(child, path)) case "MatrixExplodeCols" => val path = identifiers(it) - matrix_ir(env)(it).map { child => - MatrixExplodeCols(child, path) - } + matrix_ir(env)(it).map(child => MatrixExplodeCols(child, path)) case "MatrixChooseCols" => val oldIndices = int32_literals(it) - matrix_ir(env)(it).map { child => - MatrixChooseCols(child, oldIndices) - } + matrix_ir(env)(it).map(child => MatrixChooseCols(child, oldIndices)) case "MatrixCollectColsByKey" => matrix_ir(env)(it).map(MatrixCollectColsByKey) case "MatrixRepartition" => val n = int32_literal(it) val strategy = int32_literal(it) - matrix_ir(env)(it).map { child => - MatrixRepartition(child, n, strategy) - } + matrix_ir(env)(it).map(child => MatrixRepartition(child, n, strategy)) case "MatrixUnionRows" => matrix_ir_children(env)(it).map(MatrixUnionRows(_)) case "MatrixDistinctByRow" => matrix_ir(env)(it).map(MatrixDistinctByRow) case "MatrixRowsHead" => val n = int64_literal(it) - matrix_ir(env)(it).map { child => - MatrixRowsHead(child, n) - } + matrix_ir(env)(it).map(child => MatrixRowsHead(child, n)) case "MatrixColsHead" => val n = int32_literal(it) - matrix_ir(env)(it).map { child => - 
MatrixColsHead(child, n) - } + matrix_ir(env)(it).map(child => MatrixColsHead(child, n)) case "MatrixRowsTail" => val n = int64_literal(it) - matrix_ir(env)(it).map { child => - MatrixRowsTail(child, n) - } + matrix_ir(env)(it).map(child => MatrixRowsTail(child, n)) case "MatrixColsTail" => val n = int32_literal(it) - matrix_ir(env)(it).map { child => - MatrixColsTail(child, n) - } + matrix_ir(env)(it).map(child => MatrixColsTail(child, n)) case "CastTableToMatrix" => val entriesField = identifier(it) val colsField = identifier(it) val colKey = identifiers(it) - table_ir(env)(it).map { child => - CastTableToMatrix(child, entriesField, colsField, colKey) - } + table_ir(env)(it).map(child => CastTableToMatrix(child, entriesField, colsField, colKey)) case "MatrixToMatrixApply" => val config = string_literal(it) matrix_ir(env)(it).map { child => @@ -1918,18 +1902,28 @@ object IRParser { val entryK = string_literals(it) val entryV = string_literals(it) matrix_ir(env)(it).map { child => - MatrixRename(child, globalK.zip(globalV).toMap, colK.zip(colV).toMap, rowK.zip(rowV).toMap, entryK.zip(entryV).toMap) + MatrixRename( + child, + globalK.zip(globalV).toMap, + colK.zip(colV).toMap, + rowK.zip(rowV).toMap, + entryK.zip(entryV).toMap, + ) } case "MatrixFilterIntervals" => val keyType = type_expr(it) val intervals = string_literal(it) val keep = boolean_literal(it) matrix_ir(env)(it).map { child => - MatrixFilterIntervals(child, - JSONAnnotationImpex.importAnnotation(JsonMethods.parse(intervals), + MatrixFilterIntervals( + child, + JSONAnnotationImpex.importAnnotation( + JsonMethods.parse(intervals), TArray(TInterval(keyType)), - padNulls = false).asInstanceOf[IndexedSeq[Interval]], - keep) + padNulls = false, + ).asInstanceOf[IndexedSeq[Interval]], + keep, + ) } case "RelationalLetMatrixTable" => val name = identifier(it) @@ -1940,7 +1934,8 @@ object IRParser { } } - def blockmatrix_sparsifier(env: IRParserEnvironment)(it: TokenIterator): StackFrame[BlockMatrixSparsifier] = { + def blockmatrix_sparsifier(env: IRParserEnvironment)(it: TokenIterator) + : StackFrame[BlockMatrixSparsifier] = { punctuation(it, "(") identifier(it) match { case "PyRowIntervalSparsifier" => @@ -2039,14 +2034,10 @@ object IRParser { } case "BlockMatrixAgg" => val outIndexExpr = int32_literals(it) - blockmatrix_ir(env)(it).map { child => - BlockMatrixAgg(child, outIndexExpr) - } + blockmatrix_ir(env)(it).map(child => BlockMatrixAgg(child, outIndexExpr)) case "BlockMatrixFilter" => val indices = literals(literals(int64_literal))(it) - blockmatrix_ir(env)(it).map { child => - BlockMatrixFilter(child, indices) - } + blockmatrix_ir(env)(it).map(child => BlockMatrixFilter(child, indices)) case "BlockMatrixDensify" => blockmatrix_ir(env)(it).map(BlockMatrixDensify) case "BlockMatrixSparsify" => @@ -2062,9 +2053,7 @@ object IRParser { case "ValueToBlockMatrix" => val shape = int64_literals(it) val blockSize = int32_literal(it) - ir_value_expr(env)(it).map { child => - ValueToBlockMatrix(child, shape, blockSize) - } + ir_value_expr(env)(it).map(child => ValueToBlockMatrix(child, shape, blockSize)) case "BlockMatrixRandom" => val staticUID = int64_literal(it) val gaussian = boolean_literal(it) @@ -2104,17 +2093,44 @@ object IRParser { x case MakeArray(args, typ) => MakeArray.unify(ctx, args, typ) - case x@InitOp(_, _, BasicPhysicalAggSig(_, FoldStateSig(t, accumName, otherAccumName, combIR))) => - run(combIR, BindingEnv.empty.bindEval(accumName -> t.virtualType, otherAccumName -> t.virtualType)) + case x @ InitOp( + _, + _, + 
BasicPhysicalAggSig(_, FoldStateSig(t, accumName, otherAccumName, combIR)), + ) => + run( + combIR, + BindingEnv.empty.bindEval(accumName -> t.virtualType, otherAccumName -> t.virtualType), + ) x - case x@SeqOp(_, _, BasicPhysicalAggSig(_, FoldStateSig(t, accumName, otherAccumName, combIR))) => - run(combIR, BindingEnv.empty.bindEval(accumName -> t.virtualType, otherAccumName -> t.virtualType)) + case x @ SeqOp( + _, + _, + BasicPhysicalAggSig(_, FoldStateSig(t, accumName, otherAccumName, combIR)), + ) => + run( + combIR, + BindingEnv.empty.bindEval(accumName -> t.virtualType, otherAccumName -> t.virtualType), + ) x - case x@CombOp(_, _, BasicPhysicalAggSig(_, FoldStateSig(t, accumName, otherAccumName, combIR))) => - run(combIR, BindingEnv.empty.bindEval(accumName -> t.virtualType, otherAccumName -> t.virtualType)) + case x @ CombOp( + _, + _, + BasicPhysicalAggSig(_, FoldStateSig(t, accumName, otherAccumName, combIR)), + ) => + run( + combIR, + BindingEnv.empty.bindEval(accumName -> t.virtualType, otherAccumName -> t.virtualType), + ) x - case x@ResultOp(_, BasicPhysicalAggSig(_, FoldStateSig(t, accumName, otherAccumName, combIR))) => - run(combIR, BindingEnv.empty.bindEval(accumName -> t.virtualType, otherAccumName -> t.virtualType)) + case x @ ResultOp( + _, + BasicPhysicalAggSig(_, FoldStateSig(t, accumName, otherAccumName, combIR)), + ) => + run( + combIR, + BindingEnv.empty.bindEval(accumName -> t.virtualType, otherAccumName -> t.virtualType), + ) x case Apply(name, typeArgs, args, rt, errorID) => invoke(name, rt, typeArgs, errorID, args: _*) @@ -2131,18 +2147,22 @@ object IRParser { f(it) } - def parse_value_ir(s: String, env: IRParserEnvironment, typeEnv: BindingEnv[Type] = BindingEnv.empty): IR = { + def parse_value_ir( + s: String, + env: IRParserEnvironment, + typeEnv: BindingEnv[Type] = BindingEnv.empty, + ): IR = { var ir = parse(s, ir_value_expr(env)(_).run()) ir = annotateTypes(env.ctx, ir, typeEnv).asInstanceOf[IR] TypeCheck(env.ctx, ir, typeEnv) ir } - def parse_value_ir(ctx: ExecuteContext, s: String): IR = { + def parse_value_ir(ctx: ExecuteContext, s: String): IR = parse_value_ir(s, IRParserEnvironment(ctx)) - } - def parse_table_ir(ctx: ExecuteContext, s: String): TableIR = parse_table_ir(s, IRParserEnvironment(ctx)) + def parse_table_ir(ctx: ExecuteContext, s: String): TableIR = + parse_table_ir(s, IRParserEnvironment(ctx)) def parse_table_ir(s: String, env: IRParserEnvironment): TableIR = { var ir = parse(s, table_ir(env)(_).run()) @@ -2158,7 +2178,8 @@ object IRParser { ir } - def parse_matrix_ir(ctx: ExecuteContext, s: String): MatrixIR = parse_matrix_ir(s, IRParserEnvironment(ctx)) + def parse_matrix_ir(ctx: ExecuteContext, s: String): MatrixIR = + parse_matrix_ir(s, IRParserEnvironment(ctx)) def parse_blockmatrix_ir(s: String, env: IRParserEnvironment): BlockMatrixIR = { var ir = parse(s, blockmatrix_ir(env)(_).run()) @@ -2167,7 +2188,8 @@ object IRParser { ir } - def parse_blockmatrix_ir(ctx: ExecuteContext, s: String): BlockMatrixIR = parse_blockmatrix_ir(s, IRParserEnvironment(ctx)) + def parse_blockmatrix_ir(ctx: ExecuteContext, s: String): BlockMatrixIR = + parse_blockmatrix_ir(s, IRParserEnvironment(ctx)) def parseType(code: String): Type = parse(code, type_expr) diff --git a/hail/src/main/scala/is/hail/expr/ir/Pretty.scala b/hail/src/main/scala/is/hail/expr/ir/Pretty.scala index 1b36a307e93..b5b2f7524c7 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Pretty.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Pretty.scala @@ -7,28 +7,55 @@ import 
is.hail.expr.ir.agg._ import is.hail.expr.ir.functions.RelationalFunctions import is.hail.types.TableType import is.hail.types.virtual.{TArray, TInterval, TStream, Type} +import is.hail.utils.{space => _, _} import is.hail.utils.prettyPrint._ import is.hail.utils.richUtils.RichIterable -import is.hail.utils.{space => _, _} -import org.json4s.DefaultFormats -import org.json4s.jackson.{JsonMethods, Serialization} import scala.collection.mutable +import org.json4s.DefaultFormats +import org.json4s.jackson.{JsonMethods, Serialization} + object Pretty { - def apply(ctx: ExecuteContext, ir: BaseIR, width: Int = 100, ribbonWidth: Int = 50, elideLiterals: Boolean = true, maxLen: Int = -1, allowUnboundRefs: Boolean = false): String = { + def apply( + ctx: ExecuteContext, + ir: BaseIR, + width: Int = 100, + ribbonWidth: Int = 50, + elideLiterals: Boolean = true, + maxLen: Int = -1, + allowUnboundRefs: Boolean = false, + preserveNames: Boolean = false, + ): String = { val useSSA = ctx != null && ctx.getFlag("use_ssa_logs") != null - val pretty = new Pretty(width, ribbonWidth, elideLiterals, maxLen, allowUnboundRefs, useSSA) + val pretty = + new Pretty(width, ribbonWidth, elideLiterals, maxLen, allowUnboundRefs, useSSA, preserveNames) pretty(ir) } - def sexprStyle(ir: BaseIR, width: Int = 100, ribbonWidth: Int = 50, elideLiterals: Boolean = true, maxLen: Int = -1, allowUnboundRefs: Boolean = false): String = { - val pretty = new Pretty(width, ribbonWidth, elideLiterals, maxLen, allowUnboundRefs, useSSA = false) + def sexprStyle( + ir: BaseIR, + width: Int = 100, + ribbonWidth: Int = 50, + elideLiterals: Boolean = true, + maxLen: Int = -1, + allowUnboundRefs: Boolean = false, + ): String = { + val pretty = + new Pretty(width, ribbonWidth, elideLiterals, maxLen, allowUnboundRefs, useSSA = false) pretty(ir) } - def ssaStyle(ir: BaseIR, width: Int = 100, ribbonWidth: Int = 50, elideLiterals: Boolean = true, maxLen: Int = -1, allowUnboundRefs: Boolean = false): String = { - val pretty = new Pretty(width, ribbonWidth, elideLiterals, maxLen, allowUnboundRefs, useSSA = true) + def ssaStyle( + ir: BaseIR, + width: Int = 100, + ribbonWidth: Int = 50, + elideLiterals: Boolean = true, + maxLen: Int = -1, + allowUnboundRefs: Boolean = false, + ): String = { + val pretty = + new Pretty(width, ribbonWidth, elideLiterals, maxLen, allowUnboundRefs, useSSA = true) pretty(ir) } @@ -42,7 +69,15 @@ object Pretty { } } -class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, allowUnboundRefs: Boolean, useSSA: Boolean) { +class Pretty( + width: Int, + ribbonWidth: Int, + elideLiterals: Boolean, + maxLen: Int, + allowUnboundRefs: Boolean, + useSSA: Boolean, + preserveNames: Boolean = false, +) { def short(ir: BaseIR): String = { val s = apply(ir) if (s.length < maxLen) s else s.substring(0, maxLen) + "..." @@ -94,7 +129,8 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, concat(docs.intersperse[Doc]( "(", softline, - if (truncate) s"... ${ x.length - MAX_VALUES_TO_LOG } more values... )" else ")")) + if (truncate) s"... ${x.length - MAX_VALUES_TO_LOG} more values... )" else ")", + )) } def prettyInts(x: IndexedSeq[Int], elideLiterals: Boolean): Doc = { @@ -104,7 +140,8 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, concat(docs.intersperse[Doc]( "(", softline, - if (truncate) s"... ${ x.length - MAX_VALUES_TO_LOG } more values... )" else ")")) + if (truncate) s"... ${x.length - MAX_VALUES_TO_LOG} more values... 
)" else ")", + )) } def prettyIdentifiers(x: IndexedSeq[String]): Doc = @@ -116,12 +153,21 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, def prettyAggStateSignature(state: AggStateSig): Doc = { state match { case FoldStateSig(resultEmitType, accumName, otherAccumName, combOpIR) => - fillList(IndexedSeq(text(Pretty.prettyClass(state)), text(resultEmitType.typeWithRequiredness.canonicalPType.toString), - text(accumName), text(otherAccumName), text(apply(combOpIR)))) + fillList(IndexedSeq( + text(Pretty.prettyClass(state)), + text(resultEmitType.typeWithRequiredness.canonicalPType.toString), + text(accumName), + text(otherAccumName), + text(apply(combOpIR)), + )) case _ => fillList(state.n match { - case None => text(Pretty.prettyClass(state)) +: state.t.view.map(typ => text(typ.canonicalPType.toString)) - case Some(nested) => text(Pretty.prettyClass(state)) +: state.t.view.map(typ => text(typ.canonicalPType.toString)) :+ prettyAggStateSignatures(nested) + case None => text(Pretty.prettyClass(state)) +: state.t.view.map(typ => + text(typ.canonicalPType.toString) + ) + case Some(nested) => text(Pretty.prettyClass(state)) +: state.t.view.map(typ => + text(typ.canonicalPType.toString) + ) :+ prettyAggStateSignatures(nested) }) } } @@ -145,54 +191,75 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, def single(d: Doc): Iterable[Doc] = RichIterable.single(d) def header(ir: BaseIR, elideBindings: Boolean = false): Iterable[Doc] = ir match { - case ApplyAggOp(initOpArgs, seqOpArgs, aggSig) => single(Pretty.prettyClass(aggSig.op)) - case ApplyScanOp(initOpArgs, seqOpArgs, aggSig) => single(Pretty.prettyClass(aggSig.op)) - case InitOp(i, args, aggSig) => FastSeq(i.toString, prettyPhysicalAggSig(aggSig)) - case SeqOp(i, args, aggSig) => FastSeq(i.toString, prettyPhysicalAggSig(aggSig)) + case ApplyAggOp(_, _, aggSig) => single(Pretty.prettyClass(aggSig.op)) + case ApplyScanOp(_, _, aggSig) => single(Pretty.prettyClass(aggSig.op)) + case InitOp(i, _, aggSig) => FastSeq(i.toString, prettyPhysicalAggSig(aggSig)) + case SeqOp(i, _, aggSig) => FastSeq(i.toString, prettyPhysicalAggSig(aggSig)) case CombOp(i1, i2, aggSig) => FastSeq(i1.toString, i2.toString, prettyPhysicalAggSig(aggSig)) case ResultOp(i, aggSig) => FastSeq(i.toString, prettyPhysicalAggSig(aggSig)) case AggStateValue(i, sig) => FastSeq(i.toString, prettyAggStateSignature(sig)) - case InitFromSerializedValue(i, value, aggSig) => + case InitFromSerializedValue(i, _, aggSig) => FastSeq(i.toString, prettyAggStateSignature(aggSig)) - case CombOpValue(i, value, sig) => FastSeq(i.toString, prettyPhysicalAggSig(sig)) + case CombOpValue(i, _, sig) => FastSeq(i.toString, prettyPhysicalAggSig(sig)) case SerializeAggs(i, i2, spec, aggSigs) => - FastSeq(i.toString, i2.toString, prettyStringLiteral(spec.toString), prettyAggStateSignatures(aggSigs)) + FastSeq( + i.toString, + i2.toString, + prettyStringLiteral(spec.toString), + prettyAggStateSignatures(aggSigs), + ) case DeserializeAggs(i, i2, spec, aggSigs) => - FastSeq(i.toString, i2.toString, prettyStringLiteral(spec.toString), prettyAggStateSignatures(aggSigs)) - case RunAgg(body, result, signature) => single(prettyAggStateSignatures(signature)) - case RunAggScan(a, name, init, seq, res, signature) => + FastSeq( + i.toString, + i2.toString, + prettyStringLiteral(spec.toString), + prettyAggStateSignatures(aggSigs), + ) + case RunAgg(_, _, signature) => single(prettyAggStateSignatures(signature)) + case RunAggScan(_, name, _, _, _, 
signature) => FastSeq(prettyIdentifier(name), prettyAggStateSignatures(signature)) case I32(x) => single(x.toString) case I64(x) => single(x.toString) case F32(x) => single(x.toString) case F64(x) => single(x.toString) - case Str(x) => single(prettyStringLiteral(if (elideLiterals && x.length > 13) x.take(10) + "..." else x)) + case Str(x) => + single(prettyStringLiteral(if (elideLiterals && x.length > 13) x.take(10) + "..." else x)) case UUID4(id) => single(prettyIdentifier(id)) case Cast(_, typ) => single(typ.parsableString()) case CastRename(_, typ) => single(typ.parsableString()) case NA(typ) => single(typ.parsableString()) case Literal(typ, value) => - FastSeq(typ.parsableString(), + FastSeq( + typ.parsableString(), if (!elideLiterals) prettyStringLiteral(JsonMethods.compact(JSONAnnotationImpex.exportAnnotation(value, typ))) else - "") + "", + ) case EncodedLiteral(codec, _) => single(codec.encodedVirtualType.parsableString()) - case Let(bindings, _) if !elideBindings => bindings.map(b => text(prettyIdentifier(b._1))) - case AggLet(name, _, _, isScan) => if (elideBindings) - single(Pretty.prettyBooleanLiteral(isScan)) - else - FastSeq(prettyIdentifier(name), Pretty.prettyBooleanLiteral(isScan)) + case Block(bindings, _) if !elideBindings => + bindings.flatMap { b => + val bindType = b.scope match { + case Scope.EVAL => "eval" + case Scope.AGG => "agg" + case Scope.SCAN => "scan" + } + FastSeq(text(bindType), text(prettyIdentifier(b.name))) + } case TailLoop(name, args, returnType, _) if !elideBindings => - FastSeq(prettyIdentifier(name), prettyIdentifiers(args.map(_._1).toFastSeq), returnType.parsableString()) - case Recur(name, _, t) if !elideBindings => + FastSeq( + prettyIdentifier(name), + prettyIdentifiers(args.map(_._1).toFastSeq), + returnType.parsableString(), + ) + case Recur(name, _, _) if !elideBindings => FastSeq(prettyIdentifier(name)) // case Ref(name, t) if t != null => FastSeq(prettyIdentifier(name), t.parsableString()) // For debug purposes case Ref(name, _) => single(prettyIdentifier(name)) case RelationalRef(name, t) => if (elideBindings) - single(t.parsableString()) - else - FastSeq(prettyIdentifier(name), t.parsableString()) + single(t.parsableString()) + else + FastSeq(prettyIdentifier(name), t.parsableString()) case RelationalLet(name, _, _) if !elideBindings => single(prettyIdentifier(name)) case ApplyBinaryPrimOp(op, _, _) => single(Pretty.prettyClass(op)) case ApplyUnaryPrimOp(op, _) => single(Pretty.prettyClass(op)) @@ -203,65 +270,127 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, case MakeArray(_, typ) => single(typ.parsableString()) case MakeStream(_, typ, requiresMemoryManagementPerElement) => FastSeq(typ.parsableString(), Pretty.prettyBooleanLiteral(requiresMemoryManagementPerElement)) - case StreamIota(_, _, requiresMemoryManagementPerElement) => FastSeq(Pretty.prettyBooleanLiteral(requiresMemoryManagementPerElement)) - case StreamRange(_, _, _, requiresMemoryManagementPerElement, errorID) => FastSeq(errorID.toString, Pretty.prettyBooleanLiteral(requiresMemoryManagementPerElement)) - case ToStream(_, requiresMemoryManagementPerElement) => single(Pretty.prettyBooleanLiteral(requiresMemoryManagementPerElement)) + case StreamIota(_, _, requiresMemoryManagementPerElement) => + FastSeq(Pretty.prettyBooleanLiteral(requiresMemoryManagementPerElement)) + case StreamRange(_, _, _, requiresMemoryManagementPerElement, errorID) => + FastSeq(errorID.toString, Pretty.prettyBooleanLiteral(requiresMemoryManagementPerElement)) + case 
ToStream(_, requiresMemoryManagementPerElement) => + single(Pretty.prettyBooleanLiteral(requiresMemoryManagementPerElement)) case StreamMap(_, name, _) if !elideBindings => single(prettyIdentifier(name)) case StreamZip(_, names, _, behavior, errorID) => if (elideBindings) - FastSeq(errorID.toString, behavior match { - case ArrayZipBehavior.AssertSameLength => "AssertSameLength" - case ArrayZipBehavior.TakeMinLength => "TakeMinLength" - case ArrayZipBehavior.ExtendNA => "ExtendNA" - case ArrayZipBehavior.AssumeSameLength => "AssumeSameLength" - }) - else - FastSeq(errorID.toString, behavior match { - case ArrayZipBehavior.AssertSameLength => "AssertSameLength" - case ArrayZipBehavior.TakeMinLength => "TakeMinLength" - case ArrayZipBehavior.ExtendNA => "ExtendNA" - case ArrayZipBehavior.AssumeSameLength => "AssumeSameLength" - }, prettyIdentifiers(names)) + FastSeq( + errorID.toString, + behavior match { + case ArrayZipBehavior.AssertSameLength => "AssertSameLength" + case ArrayZipBehavior.TakeMinLength => "TakeMinLength" + case ArrayZipBehavior.ExtendNA => "ExtendNA" + case ArrayZipBehavior.AssumeSameLength => "AssumeSameLength" + }, + ) + else + FastSeq( + errorID.toString, + behavior match { + case ArrayZipBehavior.AssertSameLength => "AssertSameLength" + case ArrayZipBehavior.TakeMinLength => "TakeMinLength" + case ArrayZipBehavior.ExtendNA => "ExtendNA" + case ArrayZipBehavior.AssumeSameLength => "AssumeSameLength" + }, + prettyIdentifiers(names), + ) case StreamZipJoin(streams, key, curKey, curVals, _) if !elideBindings => - FastSeq(streams.length.toString, prettyIdentifiers(key), prettyIdentifier(curKey), prettyIdentifier(curVals)) + FastSeq( + streams.length.toString, + prettyIdentifiers(key), + prettyIdentifier(curKey), + prettyIdentifier(curVals), + ) case StreamZipJoinProducers(_, ctxName, _, key, curKey, curVals, _) if !elideBindings => - FastSeq(prettyIdentifiers(key), prettyIdentifier(ctxName), prettyIdentifier(curKey), prettyIdentifier(curVals)) + FastSeq( + prettyIdentifiers(key), + prettyIdentifier(ctxName), + prettyIdentifier(curKey), + prettyIdentifier(curVals), + ) case StreamMultiMerge(_, key) => single(prettyIdentifiers(key)) case StreamFilter(_, name, _) if !elideBindings => single(prettyIdentifier(name)) case StreamTakeWhile(_, name, _) if !elideBindings => single(prettyIdentifier(name)) case StreamDropWhile(_, name, _) if !elideBindings => single(prettyIdentifier(name)) case StreamFlatMap(_, name, _) if !elideBindings => single(prettyIdentifier(name)) - case StreamFold(_, _, accumName, valueName, _) if !elideBindings => FastSeq(prettyIdentifier(accumName), prettyIdentifier(valueName)) - case StreamFold2(_, acc, valueName, _, _) if !elideBindings => FastSeq(prettyIdentifiers(acc.map(_._1)), prettyIdentifier(valueName)) - case StreamScan(_, _, accumName, valueName, _) if !elideBindings => FastSeq(prettyIdentifier(accumName), prettyIdentifier(valueName)) - case StreamWhiten(_, newChunk, prevWindow, vecSize, windowSize, chunkSize, blockSize, normalizeAfterWhiten) => - FastSeq(prettyIdentifier(newChunk), prettyIdentifier(prevWindow), vecSize.toString, windowSize.toString, chunkSize.toString, blockSize.toString, Pretty.prettyBooleanLiteral(normalizeAfterWhiten)) + case StreamFold(_, _, accumName, valueName, _) if !elideBindings => + FastSeq(prettyIdentifier(accumName), prettyIdentifier(valueName)) + case StreamFold2(_, acc, valueName, _, _) if !elideBindings => + FastSeq(prettyIdentifiers(acc.map(_._1)), prettyIdentifier(valueName)) + case StreamScan(_, _, accumName, 
valueName, _) if !elideBindings => + FastSeq(prettyIdentifier(accumName), prettyIdentifier(valueName)) + case StreamWhiten(_, newChunk, prevWindow, vecSize, windowSize, chunkSize, blockSize, + normalizeAfterWhiten) => + FastSeq( + prettyIdentifier(newChunk), + prettyIdentifier(prevWindow), + vecSize.toString, + windowSize.toString, + chunkSize.toString, + blockSize.toString, + Pretty.prettyBooleanLiteral(normalizeAfterWhiten), + ) case StreamJoinRightDistinct(_, _, lKey, rKey, l, r, _, joinType) => if (elideBindings) - FastSeq(prettyIdentifiers(lKey), prettyIdentifiers(rKey), joinType) - else - FastSeq(prettyIdentifiers(lKey), prettyIdentifiers(rKey), prettyIdentifier(l), prettyIdentifier(r), joinType) + FastSeq(prettyIdentifiers(lKey), prettyIdentifiers(rKey), joinType) + else + FastSeq( + prettyIdentifiers(lKey), + prettyIdentifiers(rKey), + prettyIdentifier(l), + prettyIdentifier(r), + joinType, + ) + case StreamLeftIntervalJoin(_, _, lKeyFieldName, rIntrvlName, lEltName, rEltName, _) => + val builder = new BoxedArrayBuilder[Doc](if (elideBindings) 2 else 4) + builder += prettyIdentifier(lKeyFieldName) + builder += prettyIdentifier(rIntrvlName) + + if (!elideBindings) { + builder += prettyIdentifier(lEltName) + builder += prettyIdentifier(rEltName) + } + + builder.underlying() case StreamFor(_, valueName, _) if !elideBindings => single(prettyIdentifier(valueName)) - case StreamAgg(a, name, query) if !elideBindings => single(prettyIdentifier(name)) - case StreamAggScan(a, name, query) if !elideBindings => single(prettyIdentifier(name)) - case StreamGroupByKey(a, key, missingEqual) => FastSeq(prettyIdentifiers(key), prettyBooleanLiteral(missingEqual)) + case StreamAgg(_, name, _) if !elideBindings => single(prettyIdentifier(name)) + case StreamAggScan(_, name, _) if !elideBindings => single(prettyIdentifier(name)) + case StreamGroupByKey(_, key, missingEqual) => + FastSeq(prettyIdentifiers(key), prettyBooleanLiteral(missingEqual)) case AggFold(_, _, _, accumName, otherAccumName, isScan) => if (elideBindings) - single(Pretty.prettyBooleanLiteral(isScan)) - else - FastSeq(prettyIdentifier(accumName), prettyIdentifier(otherAccumName), Pretty.prettyBooleanLiteral(isScan)) + single(Pretty.prettyBooleanLiteral(isScan)) + else + FastSeq( + prettyIdentifier(accumName), + prettyIdentifier(otherAccumName), + Pretty.prettyBooleanLiteral(isScan), + ) case AggExplode(_, name, _, isScan) => if (elideBindings) - single(Pretty.prettyBooleanLiteral(isScan)) - else - FastSeq(prettyIdentifier(name), Pretty.prettyBooleanLiteral(isScan)) + single(Pretty.prettyBooleanLiteral(isScan)) + else + FastSeq(prettyIdentifier(name), Pretty.prettyBooleanLiteral(isScan)) case AggFilter(_, _, isScan) => single(Pretty.prettyBooleanLiteral(isScan)) case AggGroupBy(_, _, isScan) => single(Pretty.prettyBooleanLiteral(isScan)) case AggArrayPerElement(_, elementName, indexName, _, knownLength, isScan) => if (elideBindings) - FastSeq(Pretty.prettyBooleanLiteral(isScan), Pretty.prettyBooleanLiteral(knownLength.isDefined)) - else - FastSeq(prettyIdentifier(elementName), prettyIdentifier(indexName), Pretty.prettyBooleanLiteral(isScan), Pretty.prettyBooleanLiteral(knownLength.isDefined)) + FastSeq( + Pretty.prettyBooleanLiteral(isScan), + Pretty.prettyBooleanLiteral(knownLength.isDefined), + ) + else + FastSeq( + prettyIdentifier(elementName), + prettyIdentifier(indexName), + Pretty.prettyBooleanLiteral(isScan), + Pretty.prettyBooleanLiteral(knownLength.isDefined), + ) case NDArrayMap(_, name, _) if !elideBindings => 
single(prettyIdentifier(name)) case NDArrayMap2(_, _, lName, rName, _, errorID) => if (elideBindings) - single(s"$errorID") - else - FastSeq(s"$errorID", prettyIdentifier(lName), prettyIdentifier(rName)) + single(s"$errorID") + else + FastSeq(s"$errorID", prettyIdentifier(lName), prettyIdentifier(rName)) case NDArrayReindex(_, indexExpr) => single(prettyInts(indexExpr, elideLiterals)) case NDArrayConcat(_, axis) => single(axis.toString) case NDArrayAgg(_, axes) => single(prettyInts(axes, elideLiterals)) @@ -269,47 +398,73 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, case NDArrayReshape(_, _, errorID) => single(s"$errorID") case NDArrayMatMul(_, _, errorID) => single(s"$errorID") case NDArrayQR(_, mode, errorID) => FastSeq(errorID.toString, mode) - case NDArraySVD(_, fullMatrices, computeUV, errorID) => FastSeq(errorID.toString, fullMatrices.toString, computeUV.toString) + case NDArraySVD(_, fullMatrices, computeUV, errorID) => + FastSeq(errorID.toString, fullMatrices.toString, computeUV.toString) case NDArrayEigh(_, eigvalsOnly, errorID) => FastSeq(errorID.toString, eigvalsOnly.toString) case NDArrayInv(_, errorID) => single(s"$errorID") - case ArraySort(_, l, r, _) if !elideBindings => FastSeq(prettyIdentifier(l), prettyIdentifier(r)) - case ArrayRef(_,_, errorID) => single(s"$errorID") - case ApplyIR(function, typeArgs, _, _, errorID) => FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), ir.typ.parsableString()) - case Apply(function, typeArgs, _, t, errorID) => FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), t.parsableString()) - case ApplySeeded(function, _, rngState, staticUID, t) => FastSeq(prettyIdentifier(function), staticUID.toString, t.parsableString()) - case ApplySpecial(function, typeArgs, _, t, errorID) => FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), t.parsableString()) - case SelectFields(_, fields) => single(fillList(fields.view.map(f => text(prettyIdentifier(f))))) + case ArraySort(_, l, r, _) if !elideBindings => + FastSeq(prettyIdentifier(l), prettyIdentifier(r)) + case ArrayRef(_, _, errorID) => single(s"$errorID") + case ApplyIR(function, typeArgs, _, _, errorID) => FastSeq( + s"$errorID", + prettyIdentifier(function), + prettyTypes(typeArgs), + ir.typ.parsableString(), + ) + case Apply(function, typeArgs, _, t, errorID) => + FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), t.parsableString()) + case ApplySeeded(function, _, _, staticUID, t) => + FastSeq(prettyIdentifier(function), staticUID.toString, t.parsableString()) + case ApplySpecial(function, typeArgs, _, t, errorID) => + FastSeq(s"$errorID", prettyIdentifier(function), prettyTypes(typeArgs), t.parsableString()) + case SelectFields(_, fields) => + single(fillList(fields.view.map(f => text(prettyIdentifier(f))))) case LowerBoundOnOrderedCollection(_, _, onKey) => single(Pretty.prettyBooleanLiteral(onKey)) case In(i, typ) => FastSeq(typ.toString, i.toString) - case Die(message, typ, errorID) => FastSeq(typ.parsableString(), errorID.toString) + case Die(_, typ, errorID) => FastSeq(typ.parsableString(), errorID.toString) case CollectDistributedArray(_, _, cname, gname, _, _, staticID, _) if !elideBindings => FastSeq(staticID, prettyIdentifier(cname), prettyIdentifier(gname)) case MatrixRead(typ, dropCols, dropRows, reader) => - FastSeq(if (typ == reader.fullMatrixType) "None" else typ.parsableString(), + FastSeq( + if (typ == reader.fullMatrixType) "None" else typ.parsableString(), 
Pretty.prettyBooleanLiteral(dropCols), Pretty.prettyBooleanLiteral(dropRows), - if (elideLiterals) reader.renderShort() else '"' + StringEscapeUtils.escapeString(JsonMethods.compact(reader.toJValue)) + '"') + if (elideLiterals) reader.renderShort() + else '"' + StringEscapeUtils.escapeString(JsonMethods.compact(reader.toJValue)) + '"', + ) case MatrixWrite(_, writer) => - single('"' + StringEscapeUtils.escapeString(Serialization.write(writer)(MatrixWriter.formats)) + '"') + single('"' + StringEscapeUtils.escapeString( + Serialization.write(writer)(MatrixWriter.formats) + ) + '"') case MatrixMultiWrite(_, writer) => - single('"' + StringEscapeUtils.escapeString(Serialization.write(writer)(MatrixNativeMultiWriter.formats)) + '"') + single('"' + StringEscapeUtils.escapeString( + Serialization.write(writer)(MatrixNativeMultiWriter.formats) + ) + '"') case BlockMatrixRead(reader) => single('"' + StringEscapeUtils.escapeString(JsonMethods.compact(reader.toJValue)) + '"') case BlockMatrixWrite(_, writer) => - single('"' + StringEscapeUtils.escapeString(Serialization.write(writer)(BlockMatrixWriter.formats)) + '"') + single('"' + StringEscapeUtils.escapeString( + Serialization.write(writer)(BlockMatrixWriter.formats) + ) + '"') case BlockMatrixMultiWrite(_, writer) => - single('"' + StringEscapeUtils.escapeString(Serialization.write(writer)(BlockMatrixWriter.formats)) + '"') + single('"' + StringEscapeUtils.escapeString( + Serialization.write(writer)(BlockMatrixWriter.formats) + ) + '"') case BlockMatrixBroadcast(_, inIndexExpr, shape, blockSize) => - FastSeq(prettyInts(inIndexExpr, elideLiterals), + FastSeq( + prettyInts(inIndexExpr, elideLiterals), prettyLongs(shape, elideLiterals), - blockSize.toString) + blockSize.toString, + ) case BlockMatrixAgg(_, outIndexExpr) => single(prettyInts(outIndexExpr, elideLiterals)) case BlockMatrixSlice(_, slices) => single(fillList(slices.view.map(slice => prettyLongs(slice, elideLiterals)))) case ValueToBlockMatrix(_, shape, blockSize) => FastSeq(prettyLongs(shape, elideLiterals), blockSize.toString) case BlockMatrixFilter(_, indicesToKeepPerDim) => - single(fillList(indicesToKeepPerDim.toSeq.view.map(indices => prettyLongs(indices, elideLiterals)))) + single(fillList(indicesToKeepPerDim.toSeq.view.map(indices => + prettyLongs(indices, elideLiterals) + ))) case BlockMatrixSparsify(_, sparsifier) => single(sparsifier.pretty()) case BlockMatrixRandom(staticUID, gaussian, shape, blockSize) => @@ -317,15 +472,20 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, staticUID.toString, Pretty.prettyBooleanLiteral(gaussian), prettyLongs(shape, elideLiterals), - blockSize.toString) + blockSize.toString, + ) case BlockMatrixMap(_, name, _, needsDense) => if (elideBindings) - single(Pretty.prettyBooleanLiteral(needsDense)) - else - FastSeq(prettyIdentifier(name), Pretty.prettyBooleanLiteral(needsDense)) + single(Pretty.prettyBooleanLiteral(needsDense)) + else + FastSeq(prettyIdentifier(name), Pretty.prettyBooleanLiteral(needsDense)) case BlockMatrixMap2(_, _, lName, rName, _, sparsityStrategy) => if (elideBindings) - single(Pretty.prettyClass(sparsityStrategy)) - else - FastSeq(prettyIdentifier(lName), prettyIdentifier(rName), Pretty.prettyClass(sparsityStrategy)) + single(Pretty.prettyClass(sparsityStrategy)) + else + FastSeq( + prettyIdentifier(lName), + prettyIdentifier(rName), + Pretty.prettyClass(sparsityStrategy), + ) case MatrixRowsHead(_, n) => single(n.toString) case MatrixColsHead(_, n) => single(n.toString) case 
MatrixRowsTail(_, n) => single(n.toString) @@ -338,17 +498,24 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, case MatrixRepartition(_, n, strategy) => single(s"$n $strategy") case MatrixChooseCols(_, oldIndices) => single(prettyInts(oldIndices, elideLiterals)) case MatrixMapCols(_, _, newKey) => single(prettyStringsOpt(newKey)) - case MatrixUnionCols(l, r, joinType) => single(joinType) + case MatrixUnionCols(_, _, joinType) => single(joinType) case MatrixKeyRowsBy(_, keys, isSorted) => FastSeq(prettyIdentifiers(keys), Pretty.prettyBooleanLiteral(isSorted)) case TableRead(typ, dropRows, tr) => - FastSeq(if (typ == tr.fullType) "None" else typ.parsableString(), + FastSeq( + if (typ == tr.fullType) "None" else typ.parsableString(), Pretty.prettyBooleanLiteral(dropRows), - if (elideLiterals) tr.renderShort() else '"' + StringEscapeUtils.escapeString(JsonMethods.compact(tr.toJValue)) + '"') + if (elideLiterals) tr.renderShort() + else '"' + StringEscapeUtils.escapeString(JsonMethods.compact(tr.toJValue)) + '"', + ) case TableWrite(_, writer) => - single('"' + StringEscapeUtils.escapeString(Serialization.write(writer)(TableWriter.formats)) + '"') + single( + '"' + StringEscapeUtils.escapeString(Serialization.write(writer)(TableWriter.formats)) + '"' + ) case TableMultiWrite(_, writer) => - single('"' + StringEscapeUtils.escapeString(Serialization.write(writer)(WrappedMatrixNativeMultiWriter.formats)) + '"') + single('"' + StringEscapeUtils.escapeString( + Serialization.write(writer)(WrappedMatrixNativeMultiWriter.formats) + ) + '"') case TableKeyBy(_, keys, isSorted) => FastSeq(prettyIdentifiers(keys), Pretty.prettyBooleanLiteral(isSorted)) case TableRange(n, nPartitions) => FastSeq(n.toString, nPartitions.toString) @@ -364,13 +531,22 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, case TableKeyByAndAggregate(_, _, _, nPartitions, bufferSize) => FastSeq(prettyIntOpt(nPartitions), bufferSize.toString) case TableExplode(_, path) => single(prettyStrings(path)) - case TableMapPartitions(_, g, p, _, requestedKey, allowedOverlap) => FastSeq(prettyIdentifier(g), prettyIdentifier(p), requestedKey.toString, allowedOverlap.toString) + case TableMapPartitions(_, g, p, _, requestedKey, allowedOverlap) => FastSeq( + prettyIdentifier(g), + prettyIdentifier(p), + requestedKey.toString, + allowedOverlap.toString, + ) case TableParallelize(_, nPartitions) => single(prettyIntOpt(nPartitions)) case TableOrderBy(_, sortFields) => single(prettySortFields(sortFields)) case CastMatrixToTable(_, entriesFieldName, colsFieldName) => FastSeq(prettyStringLiteral(entriesFieldName), prettyStringLiteral(colsFieldName)) case CastTableToMatrix(_, entriesFieldName, colsFieldName, colKey) => - FastSeq(prettyIdentifier(entriesFieldName), prettyIdentifier(colsFieldName), prettyIdentifiers(colKey)) + FastSeq( + prettyIdentifier(entriesFieldName), + prettyIdentifier(colsFieldName), + prettyIdentifiers(colKey), + ) case MatrixToMatrixApply(_, function) => single(prettyStringLiteral(Serialization.write(function)(RelationalFunctions.formats))) case MatrixToTableApply(_, function) => @@ -389,50 +565,63 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, implicit val jsonFormats = DefaultFormats FastSeq( prettyIdentifier(cname), - prettyIdentifier(gname), - { - val boundsJson = Serialization.write(partitioner.rangeBounds.map(_.toJSON(partitioner.kType.toJSON))) - list("Partitioner " + partitioner.kType.parsableString() + 
prettyStringLiteral(boundsJson)) + prettyIdentifier(gname), { + val boundsJson = + Serialization.write(partitioner.rangeBounds.map(_.toJSON(partitioner.kType.toJSON))) + list( + "Partitioner " + partitioner.kType.parsableString() + prettyStringLiteral(boundsJson) + ) }, - text(errorId.toString) + text(errorId.toString), ) case TableRename(_, rowMap, globalMap) => val rowKV = rowMap.toArray val globalKV = globalMap.toArray - FastSeq(prettyStrings(rowKV.map(_._1)), prettyStrings(rowKV.map(_._2)), - prettyStrings(globalKV.map(_._1)), prettyStrings(globalKV.map(_._2))) + FastSeq( + prettyStrings(rowKV.map(_._1)), + prettyStrings(rowKV.map(_._2)), + prettyStrings(globalKV.map(_._1)), + prettyStrings(globalKV.map(_._2)), + ) case MatrixRename(_, globalMap, colMap, rowMap, entryMap) => val globalKV = globalMap.toArray val colKV = colMap.toArray val rowKV = rowMap.toArray val entryKV = entryMap.toArray - FastSeq(prettyStrings(globalKV.map(_._1)), prettyStrings(globalKV.map(_._2)), - prettyStrings(colKV.map(_._1)), prettyStrings(colKV.map(_._2)), - prettyStrings(rowKV.map(_._1)), prettyStrings(rowKV.map(_._2)), - prettyStrings(entryKV.map(_._1)), prettyStrings(entryKV.map(_._2))) + FastSeq( + prettyStrings(globalKV.map(_._1)), + prettyStrings(globalKV.map(_._2)), + prettyStrings(colKV.map(_._1)), + prettyStrings(colKV.map(_._2)), + prettyStrings(rowKV.map(_._1)), + prettyStrings(rowKV.map(_._2)), + prettyStrings(entryKV.map(_._1)), + prettyStrings(entryKV.map(_._2)), + ) case TableFilterIntervals(child, intervals, keep) => FastSeq( child.typ.keyType.parsableString(), prettyStringLiteral(Serialization.write( JSONAnnotationImpex.exportAnnotation(intervals, TArray(TInterval(child.typ.keyType))) )(RelationalSpec.formats)), - Pretty.prettyBooleanLiteral(keep)) + Pretty.prettyBooleanLiteral(keep), + ) case MatrixFilterIntervals(child, intervals, keep) => FastSeq( child.typ.rowType.parsableString(), prettyStringLiteral(Serialization.write( JSONAnnotationImpex.exportAnnotation(intervals, TArray(TInterval(child.typ.rowKeyStruct))) )(RelationalSpec.formats)), - Pretty.prettyBooleanLiteral(keep)) + Pretty.prettyBooleanLiteral(keep), + ) case RelationalLetTable(name, _, _) => single(prettyIdentifier(name)) case RelationalLetMatrixTable(name, _, _) => single(prettyIdentifier(name)) case RelationalLetBlockMatrix(name, _, _) => single(prettyIdentifier(name)) case ReadPartition(_, rowType, reader) => - FastSeq(rowType.parsableString(), - prettyStringLiteral(JsonMethods.compact(reader.toJValue))) - case WritePartition(value, writeCtx, writer) => + FastSeq(rowType.parsableString(), prettyStringLiteral(JsonMethods.compact(reader.toJValue))) + case WritePartition(_, _, writer) => single(prettyStringLiteral(JsonMethods.compact(writer.toJValue))) - case WriteMetadata(writeAnnotations, writer) => + case WriteMetadata(_, writer) => single(prettyStringLiteral(JsonMethods.compact(writer.toJValue), elide = elideLiterals)) case ReadValue(_, reader, reqType) => FastSeq(prettyStringLiteral(JsonMethods.compact(reader.toJValue)), reqType.parsableString()) @@ -459,12 +648,12 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, fields.view.map { case (n, a) => list(n, pretty(a)) } - case ApplyAggOp(initOpArgs, seqOpArgs, aggSig) => + case ApplyAggOp(initOpArgs, seqOpArgs, _) => FastSeq(prettySeq(initOpArgs), prettySeq(seqOpArgs)) - case ApplyScanOp(initOpArgs, seqOpArgs, aggSig) => + case ApplyScanOp(initOpArgs, seqOpArgs, _) => FastSeq(prettySeq(initOpArgs), prettySeq(seqOpArgs)) - case InitOp(i, 
args, aggSig) => single(prettySeq(args)) - case SeqOp(i, args, aggSig) => single(prettySeq(args)) + case InitOp(_, args, _) => single(prettySeq(args)) + case SeqOp(_, args, _) => single(prettySeq(args)) case InsertFields(old, fields, fieldOrder) => val fieldDocs = fields.view.map { case (n, a) => list(prettyIdentifier(n), pretty(a)) @@ -473,13 +662,9 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, case _ => ir.children.map(pretty).toFastSeq } - /* - val pt = ir match{ - case ir: IR => if (ir._pType != null) single(ir.pType.toString) - case _ => Iterable.empty - } - list(fillSep(text(prettyClass(ir)) +: pt ++ header(ir, elideLiterals)) +: body) - */ + /* val pt = ir match{ case ir: IR => if (ir._pType != null) single(ir.pType.toString) case _ + * => Iterable.empty } list(fillSep(text(prettyClass(ir)) +: pt ++ header(ir, elideLiterals)) + * +: body) */ list(fillSep(text(Pretty.prettyClass(ir)) +: header(ir)) +: body) } @@ -494,102 +679,121 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, if (i > 0) Some(FastSeq()) else None case _: Switch => if (i > 0) Some(FastSeq()) else None - case TailLoop(name, args, _, body) => if (i == args.length) - Some(args.map { case (name, ir) => name -> "loopvar" } :+ - name -> "loop") else None - case StreamMap(a, name, _) => + case TailLoop(name, args, _, _) => if (i == args.length) + Some(args.map { case (name, _) => name -> "loopvar" } :+ + name -> "loop") + else None + case StreamMap(_, name, _) => if (i == 1) Some(Array(name -> "elt")) else None case StreamZip(as, names, _, _, _) => if (i == as.length) Some(names.map(_ -> "elt")) else None - case StreamZipJoin(as, key, curKey, curVals, _) => + case StreamZipJoin(as, _, curKey, curVals, _) => if (i == as.length) Some(Array(curKey -> "key", curVals -> "elts")) else None - case StreamFor(a, name, _) => + case StreamFor(_, name, _) => if (i == 1) Some(Array(name -> "elt")) else None - case StreamFlatMap(a, name, _) => + case StreamFlatMap(_, name, _) => if (i == 1) Some(Array(name -> "elt")) else None - case StreamFilter(a, name, _) => + case StreamFilter(_, name, _) => if (i == 1) Some(Array(name -> "elt")) else None - case StreamTakeWhile(a, name, _) => + case StreamTakeWhile(_, name, _) => if (i == 1) Some(Array(name -> "elt")) else None - case StreamDropWhile(a, name, _) => + case StreamDropWhile(_, name, _) => if (i == 1) Some(Array(name -> "elt")) else None - case StreamFold(a, zero, accumName, valueName, _) => + case StreamFold(_, _, accumName, valueName, _) => if (i == 2) Some(Array(accumName -> "accum", valueName -> "elt")) else None - case StreamFold2(a, accum, valueName, seq, result) => + case StreamFold2(_, accum, valueName, _, _) => if (i <= accum.length) None else if (i < 2 * accum.length + 1) - Some(Array(valueName -> "elt") ++ accum.map { case (name, value) => name -> "accum" }) + Some(Array(valueName -> "elt") ++ accum.map { case (name, _) => name -> "accum" }) else - Some(accum.map { case (name, value) => name -> "accum" }) - case RunAggScan(a, name, _, _, _, _) => + Some(accum.map { case (name, _) => name -> "accum" }) + case RunAggScan(_, name, _, _, _, _) => if (i == 2 || i == 3) Some(Array(name -> "elt")) else None - case StreamScan(a, zero, accumName, valueName, _) => + case StreamScan(_, _, accumName, valueName, _) => if (i == 2) Some(Array(accumName -> "accum", valueName -> "elt")) else None - case StreamAggScan(a, name, _) => + case StreamAggScan(_, name, _) => if (i == 1) Some(FastSeq(name -> "elt")) else None - 
case StreamJoinRightDistinct(ll, rr, _, _, l, r, _, _) => + case StreamJoinRightDistinct(_, _, _, _, l, r, _, _) => if (i == 2) Some(Array(l -> "l_elt", r -> "r_elt")) else None - case ArraySort(a, left, right, _) => + case StreamLeftIntervalJoin(_, _, _, _, l, r, _) => + if (i == 2) Some(Array(l -> "l_elt", r -> "r_elts")) else None + case ArraySort(_, left, right, _) => if (i == 1) Some(Array(left -> "l", right -> "r")) else None case AggArrayPerElement(_, elementName, indexName, _, _, _) => if (i == 1) Some(Array(elementName -> "elt", indexName -> "idx")) else None - case AggFold(zero, seqOp, combOp, accumName, otherAccumName, _) => { + case AggFold(_, _, _, accumName, otherAccumName, _) => if (i == 1) Some(Array(accumName -> "accum")) else if (i == 2) Some(Array(accumName -> "l", otherAccumName -> "r")) else None - } - case NDArrayMap(nd, name, _) => + case NDArrayMap(_, name, _) => if (i == 1) Some(Array(name -> "elt")) else None - case NDArrayMap2(l, r, lName, rName, _, _) => if (i == 2) - Some(Array(lName -> "l_elt", rName -> "r_elt")) - else - None - case CollectDistributedArray(contexts, globals, cname, gname, _, _, _, _) => + case NDArrayMap2(_, _, lName, rName, _, _) => if (i == 2) + Some(Array(lName -> "l_elt", rName -> "r_elt")) + else + None + case CollectDistributedArray(_, _, cname, gname, _, _, _, _) => if (i == 2) Some(Array(cname -> "ctx", gname -> "g")) else None - case TableAggregate(child, _) => + case TableAggregate(_, _) => if (i == 1) Some(Array("global" -> "g", "row" -> "row")) else None - case MatrixAggregate(child, _) => - if (i == 1) Some(Array("global" -> "g", "sa" -> "col", "va" -> "row", "g" -> "entry")) else None - case TableFilter(child, _) => + case MatrixAggregate(_, _) => + if (i == 1) Some(Array("global" -> "g", "sa" -> "col", "va" -> "row", "g" -> "entry")) + else None + case TableFilter(_, _) => if (i == 1) Some(Array("global" -> "g", "row" -> "row")) else None - case TableMapGlobals(child, _) => + case TableMapGlobals(_, _) => if (i == 1) Some(Array("global" -> "g")) else None - case TableMapRows(child, _) => + case TableMapRows(_, _) => if (i == 1) Some(Array("global" -> "g", "row" -> "row")) else None - case TableAggregateByKey(child, _) => + case TableAggregateByKey(_, _) => if (i == 1) Some(Array("global" -> "g", "row" -> "row")) else None - case TableKeyByAndAggregate(child, _, _, _, _) => + case TableKeyByAndAggregate(_, _, _, _, _) => if (i == 1 || i == 2) Some(Array("global" -> "g", "row" -> "row")) else None - case TableMapPartitions(child, g, p, _, _, _) => + case TableMapPartitions(_, g, p, _, _, _) => if (i == 1) Some(Array(g -> "g", p -> "part")) else None - case MatrixMapRows(child, _) => - if (i == 1) Some(Array("global" -> "g", "va" -> "row", "sa" -> "col", "g" -> "entry", "n_cols" -> "n_cols")) else None - case MatrixFilterRows(child, _) => + case MatrixMapRows(_, _) => + if (i == 1) Some(Array( + "global" -> "g", + "va" -> "row", + "sa" -> "col", + "g" -> "entry", + "n_cols" -> "n_cols", + )) + else None + case MatrixFilterRows(_, _) => if (i == 1) Some(Array("global" -> "g", "va" -> "row")) else None - case MatrixMapCols(child, _, _) => - if (i == 1) Some(Array("global" -> "g", "va" -> "row", "sa" -> "col", "g" -> "entry", "n_rows" -> "n_rows")) else None - case MatrixFilterCols(child, _) => + case MatrixMapCols(_, _, _) => + if (i == 1) Some(Array( + "global" -> "g", + "va" -> "row", + "sa" -> "col", + "g" -> "entry", + "n_rows" -> "n_rows", + )) + else None + case MatrixFilterCols(_, _) => if (i == 1) Some(Array("global" -> "g", 
"sa" -> "col")) else None - case MatrixMapEntries(child, _) => - if (i == 1) Some(Array("global" -> "g", "sa" -> "col", "va" -> "row", "g" -> "entry")) else None - case MatrixFilterEntries(child, _) => - if (i == 1) Some(Array("global" -> "g", "sa" -> "col", "va" -> "row", "g" -> "entry")) else None - case MatrixMapGlobals(child, _) => + case MatrixMapEntries(_, _) => + if (i == 1) Some(Array("global" -> "g", "sa" -> "col", "va" -> "row", "g" -> "entry")) + else None + case MatrixFilterEntries(_, _) => + if (i == 1) Some(Array("global" -> "g", "sa" -> "col", "va" -> "row", "g" -> "entry")) + else None + case MatrixMapGlobals(_, _) => if (i == 1) Some(Array("global" -> "g")) else None - case MatrixAggregateColsByKey(child, _, _) => + case MatrixAggregateColsByKey(_, _, _) => if (i == 1) Some(Array("global" -> "g", "va" -> "row", "sa" -> "col", "g" -> "entry")) else if (i == 2) Some(Array("global" -> "g", "sa" -> "col")) else None - case MatrixAggregateRowsByKey(child, _, _) => + case MatrixAggregateRowsByKey(_, _, _) => if (i == 1) Some(Array("global" -> "g", "va" -> "row", "sa" -> "col", "g" -> "entry")) else if (i == 2) @@ -600,8 +804,6 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, if (i == 1) Some(Array(eltName -> "elt")) else None case BlockMatrixMap2(_, _, lName, rName, _, _) => if (i == 2) Some(Array(lName -> "l", rName -> "r")) else None - case AggLet(name, _, _, _) => - if (i == 1) Some(Array(name -> "")) else None case AggExplode(_, name, _, _) => if (i == 1) Some(Array(name -> "elt")) else None case StreamAgg(_, name, _) => @@ -623,31 +825,47 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, case _ => "" } - def uniqueify(base: String): String = { - if (base.isEmpty) { - identCounter += 1 - identCounter.toString - } else if (idents.contains(base)) { - idents(base) += 1 - if (base.last.isDigit) - s"${base}_${idents(base)}" - else - s"${base}${idents(base)}" + def uniqueify(base: String, origName: Option[String]): String = { + if (preserveNames && origName.nonEmpty) { + origName.get } else { - idents(base) = 1 - base + if (base.isEmpty) { + identCounter += 1 + identCounter.toString + } else if (idents.contains(base)) { + idents(base) += 1 + if (base.last.isDigit) + s"${base}_${idents(base)}" + else + s"$base${idents(base)}" + } else { + idents(base) = 1 + base + } } } - def prettyWithIdent(ir: BaseIR, bindings: Env[String], prefix: String): (Doc, String) = { + def prettyWithIdent( + ir: BaseIR, + bindings: Env[String], + prefix: String, + origName: Option[String], + scope: Int = Scope.EVAL, + ): (Doc, String) = { val (pre, body) = pretty(ir, bindings) - val ident = prefix + uniqueify(getIdentBase(ir)) - val doc = vsep(pre, hsep(text(ident), "=", body)) + val ident = prefix + uniqueify(getIdentBase(ir), origName) + val assignmentSymbol = scope match { + case Scope.EVAL => "=" + case Scope.AGG => "=(agg)" + case Scope.SCAN => "=(scan)" + } + val doc = vsep(pre, hsep(text(ident), assignmentSymbol, body)) (doc, ident) } - def prettyBlock(ir: BaseIR, newBindings: IndexedSeq[(String, String)], bindings: Env[String]): Doc = { - val args = newBindings.map { case (name, base) => name -> s"%${uniqueify(base)}" } + def prettyBlock(ir: BaseIR, newBindings: IndexedSeq[(String, String)], bindings: Env[String]) + : Doc = { + val args = newBindings.map { case (name, base) => name -> s"%${uniqueify(base, Some(name))}" } val blockBindings = bindings.bindIterable(args) val openBlock = if (args.isEmpty) text("{") @@ -655,10 +873,15 
@@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, concat("{", softline, args.map(_._2).mkString("(", ", ", ") =>")) ir match { case Ref(name, _) => - val body = blockBindings.lookupOption(name).getOrElse(uniqueify("%undefined_ref")) + val body = + blockBindings.lookupOption(name).getOrElse(uniqueify("%undefined_ref", Some(name))) concat(openBlock, group(nest(2, concat(line, body, line)), "}")) case RelationalRef(name, _) => - val body = blockBindings.lookupOption(name).getOrElse(uniqueify("%undefined_relational_ref")) + val body = + blockBindings.lookupOption(name).getOrElse(uniqueify( + "%undefined_relational_ref", + Some(name), + )) concat(openBlock, group(nest(2, concat(line, body, line)), "}")) case _ => val (pre, body) = pretty(ir, blockBindings) @@ -667,28 +890,29 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, } def pretty(ir: BaseIR, bindings: Env[String]): (Doc, Doc) = ir match { - case Let(binds, body) => + case Block(binds, body) => val (valueDoc, newBindings) = - binds.foldLeft((empty, bindings)) { case ((valueDoc, bindings), (name, value)) => - val (doc, ident) = prettyWithIdent(value, bindings, "%") - (concat(valueDoc, doc), bindings.bind(name, ident)) + binds.foldLeft((empty, bindings)) { case ((valueDoc, bindings), binding) => + val (doc, ident) = + prettyWithIdent(binding.value, bindings, "%", Some(binding.name), binding.scope) + (concat(valueDoc, doc), bindings.bind(binding.name, ident)) } val (bodyPre, bodyHead) = pretty(body, newBindings) (concat(valueDoc, bodyPre), bodyHead) case RelationalLet(name, value, body) => - val (valueDoc, valueIdent) = prettyWithIdent(value, bindings, "%") + val (valueDoc, valueIdent) = prettyWithIdent(value, bindings, "%", Some(name)) val (bodyPre, bodyHead) = pretty(body, bindings.bind(name, valueIdent)) (concat(valueDoc, bodyPre), bodyHead) case RelationalLetTable(name, value, body) => - val (valueDoc, valueIdent) = prettyWithIdent(value, bindings, "%") + val (valueDoc, valueIdent) = prettyWithIdent(value, bindings, "%", Some(name)) val (bodyPre, bodyHead) = pretty(body, bindings.bind(name, valueIdent)) (concat(valueDoc, bodyPre), bodyHead) case RelationalLetMatrixTable(name, value, body) => - val (valueDoc, valueIdent) = prettyWithIdent(value, bindings, "%") + val (valueDoc, valueIdent) = prettyWithIdent(value, bindings, "%", Some(name)) val (bodyPre, bodyHead) = pretty(body, bindings.bind(name, valueIdent)) (concat(valueDoc, bodyPre), bodyHead) case RelationalLetBlockMatrix(name, value, body) => - val (valueDoc, valueIdent) = prettyWithIdent(value, bindings, "%") + val (valueDoc, valueIdent) = prettyWithIdent(value, bindings, "%", Some(name)) val (bodyPre, bodyHead) = pretty(body, bindings.bind(name, valueIdent)) (concat(valueDoc, bodyPre), bodyHead) case _ => @@ -699,11 +923,14 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, } yield { child match { case Ref(name, _) => - bindings.lookupOption(name).getOrElse(uniqueify("%undefined_ref")) + bindings.lookupOption(name).getOrElse(uniqueify("%undefined_ref", Some(name))) case RelationalRef(name, _) => - bindings.lookupOption(name).getOrElse(uniqueify("%undefined_relational_ref")) + bindings.lookupOption(name).getOrElse(uniqueify( + "%undefined_relational_ref", + Some(name), + )) case _ => - val (body, ident) = prettyWithIdent(child, bindings, "!") + val (body, ident) = prettyWithIdent(child, bindings, "!", None) strictChildBodies += body ident } @@ -715,7 +942,8 @@ class Pretty(width: Int, 
ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, } yield prettyBlock(child, blockArgs(ir, i).get, bindings)).toFastSeq val attsIterable = header(ir, elideBindings = true) - val attributes = if (attsIterable.isEmpty) Iterable.empty else + val attributes = if (attsIterable.isEmpty) Iterable.empty + else RichIterable.single(concat(attsIterable.intersperse[Doc]("[", ", ", "]"))) def standardArgs = if (strictChildIdents.isEmpty) @@ -730,8 +958,8 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, }.mkString("(", ", ", ")") hsep(text(Pretty.prettyClass(ir) + args) +: (attributes ++ nestedBlocks)) case InsertFields(_, fields, _) => - val newFields = (fields.map(_._1), strictChildIdents.tail).zipped.map { (field, value) => - s"$field: $value" + val newFields = (fields.map(_._1), strictChildIdents.tail).zipped.map { + (field, value) => s"$field: $value" }.mkString("(", ", ", ")") val args = s" ${strictChildIdents.head} $newFields" hsep(text(Pretty.prettyClass(ir) + args) +: (attributes ++ nestedBlocks)) @@ -741,7 +969,8 @@ class Pretty(width: Int, ribbonWidth: Int, elideLiterals: Boolean, maxLen: Int, text("then"), nestedBlocks(0), text("else"), - nestedBlocks(1)) + nestedBlocks(1), + ) case _ => hsep(text(Pretty.prettyClass(ir) + standardArgs) +: (attributes ++ nestedBlocks)) } diff --git a/hail/src/main/scala/is/hail/expr/ir/PrimitiveTypeToIRIntermediateClassTag.scala b/hail/src/main/scala/is/hail/expr/ir/PrimitiveTypeToIRIntermediateClassTag.scala index e070f7e837f..a26fd16dc08 100644 --- a/hail/src/main/scala/is/hail/expr/ir/PrimitiveTypeToIRIntermediateClassTag.scala +++ b/hail/src/main/scala/is/hail/expr/ir/PrimitiveTypeToIRIntermediateClassTag.scala @@ -1,9 +1,8 @@ package is.hail.expr.ir -import is.hail.types._ import is.hail.types.virtual._ -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} object PrimitiveTypeToIRIntermediateClassTag { def apply(t: Type): ClassTag[_] = t match { diff --git a/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala b/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala index 67263fb8e78..97090f33e4b 100644 --- a/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala +++ b/hail/src/main/scala/is/hail/expr/ir/PruneDeadFields.scala @@ -5,25 +5,33 @@ import is.hail.backend.ExecuteContext import is.hail.expr.Nat import is.hail.types._ import is.hail.types.virtual._ +import is.hail.types.virtual.TIterable.elementType import is.hail.utils._ import scala.collection.mutable - object PruneDeadFields { - case class ComputeMutableState(requestedType: Memo[BaseType], relationalRefs: mutable.HashMap[String, BoxedArrayBuilder[Type]]) { - def rebuildState: RebuildMutableState = RebuildMutableState(requestedType, mutable.HashMap.empty) + case class ComputeMutableState( + requestedType: Memo[BaseType], + relationalRefs: mutable.HashMap[String, BoxedArrayBuilder[Type]], + ) { + def rebuildState: RebuildMutableState = + RebuildMutableState(requestedType, mutable.HashMap.empty) } - case class RebuildMutableState(requestedType: Memo[BaseType], relationalRefs: mutable.HashMap[String, Type]) + case class RebuildMutableState( + requestedType: Memo[BaseType], + relationalRefs: mutable.HashMap[String, Type], + ) def subsetType(t: Type, path: Array[String], index: Int = 0): Type = { if (index == path.length) PruneDeadFields.minimal(t) else t match { - case ts: TStruct => TStruct(path(index) -> subsetType(ts.field(path(index)).typ, path, index + 1)) + case ts: TStruct => + TStruct(path(index) -> 
subsetType(ts.field(path(index)).typ, path, index + 1)) case ta: TArray => TArray(subsetType(ta.elementType, path, index)) case ts: TStream => TStream(subsetType(ts.elementType, path, index)) } @@ -34,23 +42,24 @@ object PruneDeadFields { (superType, subType) match { case (tt1: TableType, tt2: TableType) => isSupertype(tt1.globalType, tt2.globalType) && - isSupertype(tt1.rowType, tt2.rowType) && - tt2.key.startsWith(tt1.key) + isSupertype(tt1.rowType, tt2.rowType) && + tt2.key.startsWith(tt1.key) case (mt1: MatrixType, mt2: MatrixType) => isSupertype(mt1.globalType, mt2.globalType) && - isSupertype(mt1.rowType, mt2.rowType) && - isSupertype(mt1.colType, mt2.colType) && - isSupertype(mt1.entryType, mt2.entryType) && - mt2.rowKey.startsWith(mt1.rowKey) && - mt2.colKey.startsWith(mt1.colKey) + isSupertype(mt1.rowType, mt2.rowType) && + isSupertype(mt1.colType, mt2.colType) && + isSupertype(mt1.entryType, mt2.entryType) && + mt2.rowKey.startsWith(mt1.rowKey) && + mt2.colKey.startsWith(mt1.colKey) case (TArray(et1), TArray(et2)) => isSupertype(et1, et2) - case (TNDArray(et1, ndims1), TNDArray(et2, ndims2)) => (ndims1 == ndims2) && isSupertype(et1, et2) + case (TNDArray(et1, ndims1), TNDArray(et2, ndims2)) => + (ndims1 == ndims2) && isSupertype(et1, et2) case (TStream(et1), TStream(et2)) => isSupertype(et1, et2) case (TSet(et1), TSet(et2)) => isSupertype(et1, et2) case (TDict(kt1, vt1), TDict(kt2, vt2)) => isSupertype(kt1, kt2) && isSupertype(vt1, vt2) case (s1: TStruct, s2: TStruct) => var idx = -1 - s1.fields.forall { f => + s1.fields.forall { f => val s2field = s2.field(f.name) if (s2field.index > idx) { idx = s2field.index @@ -94,22 +103,27 @@ object PruneDeadFields { rebuild(ctx, bmir, ms.rebuildState) case vir: IR => memoizeValueIR(ctx, vir, vir.typ, ms) - rebuildIR(ctx, vir, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty)), ms.rebuildState) + rebuildIR( + ctx, + vir, + BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty)), + ms.rebuildState, + ) } } catch { - case e: Throwable => fatal(s"error trying to rebuild IR:\n${ Pretty(ctx, ir, elideLiterals = true) }", e) + case e: Throwable => + fatal(s"error trying to rebuild IR:\n${Pretty(ctx, ir, elideLiterals = true)}", e) } } def selectKey(t: TStruct, k: IndexedSeq[String]): TStruct = t.filterSet(k.toSet)._1 - def minimal(tt: TableType): TableType = { + def minimal(tt: TableType): TableType = TableType( rowType = TStruct.empty, key = FastSeq(), - globalType = TStruct.empty + globalType = TStruct.empty, ) - } def minimal(mt: MatrixType): MatrixType = { MatrixType( @@ -118,13 +132,13 @@ object PruneDeadFields { rowType = TStruct.empty, colType = TStruct.empty, globalType = TStruct.empty, - entryType = TStruct.empty + entryType = TStruct.empty, ) } def minimal[T <: Type](base: T): T = { val result = base match { - case ts: TStruct => TStruct.empty + case _: TStruct => TStruct.empty case ta: TArray => TArray(minimal(ta.elementType)) case ta: TStream => TStream(minimal(ta.elementType)) case t => t @@ -132,19 +146,20 @@ object PruneDeadFields { result.asInstanceOf[T] } - def minimalBT[T <: BaseType](base: T): T = { + def minimalBT[T <: BaseType](base: T): T = (base match { case tt: TableType => minimal(tt) case mt: MatrixType => minimal(mt) case t: Type => minimal(t) }).asInstanceOf[T] - } - def unifyKey(children: Seq[IndexedSeq[String]]): IndexedSeq[String] = { - children.foldLeft(FastSeq[String]()) { case (comb, k) => if (k.length > comb.length) k else comb } - } + def unifyKey(children: Seq[IndexedSeq[String]]): IndexedSeq[String] = + 
children.foldLeft(FastSeq[String]()) { case (comb, k) => + if (k.length > comb.length) k else comb + } - def unifyBaseType(base: BaseType, children: BaseType*): BaseType = unifyBaseTypeSeq(base, children) + def unifyBaseType(base: BaseType, children: BaseType*): BaseType = + unifyBaseTypeSeq(base, children) def unifyBaseTypeSeq(base: BaseType, _children: Seq[BaseType]): BaseType = { try { @@ -157,7 +172,7 @@ object PruneDeadFields { tt.copy( key = unifyKey(ttChildren.map(_.key)), rowType = unify(tt.rowType, ttChildren.map(_.rowType): _*), - globalType = unify(tt.globalType, ttChildren.map(_.globalType): _*) + globalType = unify(tt.globalType, ttChildren.map(_.globalType): _*), ) case mt: MatrixType => val mtChildren = children.map(_.asInstanceOf[MatrixType]) @@ -167,7 +182,7 @@ object PruneDeadFields { globalType = unifySeq(mt.globalType, mtChildren.map(_.globalType)), rowType = unifySeq(mt.rowType, mtChildren.map(_.rowType)), entryType = unifySeq(mt.entryType, mtChildren.map(_.entryType)), - colType = unifySeq(mt.colType, mtChildren.map(_.colType)) + colType = unifySeq(mt.colType, mtChildren.map(_.colType)), ) case t: Type => if (children.isEmpty) @@ -202,7 +217,8 @@ object PruneDeadFields { val ab = fieldArrays(oldIdx) if (ab.nonEmpty) { val oldField = ts.fields(oldIdx) - subFields(newIdx) = Field(oldField.name, unifySeq(oldField.typ, ab.result()), newIdx) + subFields(newIdx) = + Field(oldField.name, unifySeq(oldField.typ, ab.result()), newIdx) newIdx += 1 } oldIdx += 1 @@ -240,7 +256,8 @@ object PruneDeadFields { val ab = fieldArrays(oldIdx) if (ab.nonEmpty) { val oldField = tt._types(oldIdx) - subFields(newIdx) = TupleField(oldField.index, unifySeq(oldField.typ, ab.result())) + subFields(newIdx) = + TupleField(oldField.index, unifySeq(oldField.typ, ab.result())) newIdx += 1 } oldIdx += 1 @@ -254,22 +271,30 @@ object PruneDeadFields { if (!children.forall(_.asInstanceOf[Type] == t)) { val badChildren = children.filter(c => c.asInstanceOf[Type] != t) .map(c => "\n child: " + c.asInstanceOf[Type].parsableString()) - throw new RuntimeException(s"invalid unification:\n base: ${ t.parsableString() }${ badChildren.mkString("\n") }") + throw new RuntimeException( + s"invalid unification:\n base: ${t.parsableString()}${badChildren.mkString("\n")}" + ) } base } } } catch { case e: RuntimeException => - throw new RuntimeException(s"failed to unify children while unifying:\n base: ${ base }\n${ _children.mkString("\n") }", e) + throw new RuntimeException( + s"failed to unify children while unifying:\n base: $base\n${_children.mkString("\n")}", + e, + ) } } - def unify[T <: BaseType](base: T, children: T*): T = unifyBaseTypeSeq(base, children).asInstanceOf[T] + def unify[T <: BaseType](base: T, children: T*): T = + unifyBaseTypeSeq(base, children).asInstanceOf[T] - def unifySeq[T <: BaseType](base: T, children: Seq[T]): T = unifyBaseTypeSeq(base, children).asInstanceOf[T] + def unifySeq[T <: BaseType](base: T, children: Seq[T]): T = + unifyBaseTypeSeq(base, children).asInstanceOf[T] - def unifyEnvs(envs: BindingEnv[BoxedArrayBuilder[Type]]*): BindingEnv[BoxedArrayBuilder[Type]] = unifyEnvsSeq(envs) + def unifyEnvs(envs: BindingEnv[BoxedArrayBuilder[Type]]*): BindingEnv[BoxedArrayBuilder[Type]] = + unifyEnvsSeq(envs) def concatEnvs(envs: Seq[Env[BoxedArrayBuilder[Type]]]): Env[BoxedArrayBuilder[Type]] = { val lc = envs.lengthCompare(1) @@ -292,7 +317,8 @@ object PruneDeadFields { } } - def unifyEnvsSeq(envs: Seq[BindingEnv[BoxedArrayBuilder[Type]]]): BindingEnv[BoxedArrayBuilder[Type]] = { + def 
unifyEnvsSeq(envs: Seq[BindingEnv[BoxedArrayBuilder[Type]]]) + : BindingEnv[BoxedArrayBuilder[Type]] = { val lc = envs.lengthCompare(1) if (lc < 0) BindingEnv.empty[BoxedArrayBuilder[Type]] @@ -301,7 +327,8 @@ object PruneDeadFields { else { val evalEnv = concatEnvs(envs.map(_.eval)) val aggEnv = if (envs.exists(_.agg.isDefined)) Some(concatEnvs(envs.flatMap(_.agg))) else None - val scanEnv = if (envs.exists(_.scan.isDefined)) Some(concatEnvs(envs.flatMap(_.scan))) else None + val scanEnv = + if (envs.exists(_.scan.isDefined)) Some(concatEnvs(envs.flatMap(_.scan))) else None BindingEnv(evalEnv, aggEnv, scanEnv) } } @@ -329,36 +356,64 @@ object PruneDeadFields { ctx: ExecuteContext, tir: TableIR, requestedType: TableType, - memo: ComputeMutableState - ) { + memo: ComputeMutableState, + ): Unit = { memo.requestedType.bind(tir, requestedType) tir match { case TableRead(_, _, _) => case TableLiteral(_, _, _, _) => case TableParallelize(rowsAndGlobal, _) => - memoizeValueIR(ctx, rowsAndGlobal, TStruct("rows" -> TArray(requestedType.rowType), "global" -> requestedType.globalType), memo) + memoizeValueIR( + ctx, + rowsAndGlobal, + TStruct("rows" -> TArray(requestedType.rowType), "global" -> requestedType.globalType), + memo, + ) case TableRange(_, _) => case TableRepartition(child, _, _) => memoizeTableIR(ctx, child, requestedType, memo) - case TableHead(child, _) => memoizeTableIR(ctx, child, TableType( - key = child.typ.key, - rowType = unify(child.typ.rowType, selectKey(child.typ.rowType, child.typ.key), requestedType.rowType), - globalType = requestedType.globalType), memo) - case TableTail(child, _) => memoizeTableIR(ctx, child, TableType( - key = child.typ.key, - rowType = unify(child.typ.rowType, selectKey(child.typ.rowType, child.typ.key), requestedType.rowType), - globalType = requestedType.globalType), memo) + case TableHead(child, _) => memoizeTableIR( + ctx, + child, + TableType( + key = child.typ.key, + rowType = unify( + child.typ.rowType, + selectKey(child.typ.rowType, child.typ.key), + requestedType.rowType, + ), + globalType = requestedType.globalType, + ), + memo, + ) + case TableTail(child, _) => memoizeTableIR( + ctx, + child, + TableType( + key = child.typ.key, + rowType = unify( + child.typ.rowType, + selectKey(child.typ.rowType, child.typ.key), + requestedType.rowType, + ), + globalType = requestedType.globalType, + ), + memo, + ) case TableGen(contexts, globals, cname, gname, body, _, _) => val bodyEnv = memoizeValueIR(ctx, body, TStream(requestedType.rowType), memo) // Contexts are only used in the body so we only need to keep the fields used therein - val contextsElemType = unifySeq(TIterable.elementType(contexts.typ), uses(cname, bodyEnv.eval)) + val contextsElemType = + unifySeq(TIterable.elementType(contexts.typ), uses(cname, bodyEnv.eval)) // Globals are exported and used in body, so keep the union of the used fields - val globalsType = unifySeq(globals.typ, uses(gname, bodyEnv.eval) :+ requestedType.globalType) + val globalsType = + unifySeq(globals.typ, uses(gname, bodyEnv.eval) :+ requestedType.globalType) memoizeValueIR(ctx, contexts, TStream(contextsElemType), memo) memoizeValueIR(ctx, globals, globalsType, memo) case TableJoin(left, right, _, joinKey) => - val lk = unifyKey(FastSeq(requestedType.key.take(left.typ.key.length), left.typ.key.take(joinKey))) + val lk = + unifyKey(FastSeq(requestedType.key.take(left.typ.key.length), left.typ.key.take(joinKey))) val lkSet = lk.toSet val leftDep = TableType( key = lk, @@ -366,12 +421,16 @@ object PruneDeadFields { if 
(lkSet.contains(f)) Some(f -> left.typ.rowType.field(f).typ) else - requestedType.rowType.selfField(f).map(reqF => f -> reqF.typ)): _*), + requestedType.rowType.selfField(f).map(reqF => f -> reqF.typ) + ): _*), globalType = TStruct(left.typ.globalType.fieldNames.flatMap(f => - requestedType.globalType.selfField(f).map(reqF => f -> reqF.typ)): _*)) + requestedType.globalType.selfField(f).map(reqF => f -> reqF.typ) + ): _*), + ) memoizeTableIR(ctx, left, leftDep, memo) - val rk = right.typ.key.take(joinKey + math.max(0, requestedType.key.length - left.typ.key.length)) + val rk = + right.typ.key.take(joinKey + math.max(0, requestedType.key.length - left.typ.key.length)) val rightKeyFields = rk.toSet val rightDep = TableType( key = rk, @@ -379,9 +438,12 @@ object PruneDeadFields { if (rightKeyFields.contains(f)) Some(f -> right.typ.rowType.field(f).typ) else - requestedType.rowType.selfField(f).map(reqF => f -> reqF.typ)): _*), + requestedType.rowType.selfField(f).map(reqF => f -> reqF.typ) + ): _*), globalType = TStruct(right.typ.globalType.fieldNames.flatMap(f => - requestedType.globalType.selfField(f).map(reqF => f -> reqF.typ)): _*)) + requestedType.globalType.selfField(f).map(reqF => f -> reqF.typ) + ): _*), + ) memoizeTableIR(ctx, right, rightDep, memo) case TableLeftJoinRightDistinct(left, right, root) => val fieldDep = requestedType.rowType.selfField(root).map(_.typ.asInstanceOf[TStruct]) @@ -392,16 +454,22 @@ object PruneDeadFields { rowType = unify( right.typ.rowType, FastSeq[TStruct](right.typ.rowType.filterSet(right.typ.key.toSet, true)._1) ++ - FastSeq(struct): _*), - globalType = minimal(right.typ.globalType)) + FastSeq(struct): _* + ), + globalType = minimal(right.typ.globalType), + ) memoizeTableIR(ctx, right, rightDep, memo) val lk = unifyKey(FastSeq(left.typ.key.take(right.typ.key.length), requestedType.key)) val leftDep = TableType( key = lk, - rowType = unify(left.typ.rowType, requestedType.rowType.filterSet(Set(root), include = false)._1, - selectKey(left.typ.rowType, lk)), - globalType = requestedType.globalType) + rowType = unify( + left.typ.rowType, + requestedType.rowType.filterSet(Set(root), include = false)._1, + selectKey(left.typ.rowType, lk), + ), + globalType = requestedType.globalType, + ) memoizeTableIR(ctx, left, leftDep, memo) case None => // don't memoize right if we are going to elide it during rebuild @@ -421,16 +489,22 @@ object PruneDeadFields { rowType = unify( right.typ.rowType, FastSeq[TStruct](right.typ.rowType.filterSet(right.typ.key.toSet, true)._1) ++ - FastSeq(struct): _*), - globalType = minimal(right.typ.globalType)) + FastSeq(struct): _* + ), + globalType = minimal(right.typ.globalType), + ) memoizeTableIR(ctx, right, rightDep, memo) val lk = unifyKey(FastSeq(left.typ.key.take(right.typ.key.length), requestedType.key)) val leftDep = TableType( key = lk, - rowType = unify(left.typ.rowType, requestedType.rowType.filterSet(Set(root), include = false)._1, - selectKey(left.typ.rowType, lk)), - globalType = requestedType.globalType) + rowType = unify( + left.typ.rowType, + requestedType.rowType.filterSet(Set(root), include = false)._1, + selectKey(left.typ.rowType, lk), + ), + globalType = requestedType.globalType, + ) memoizeTableIR(ctx, left, leftDep, memo) case None => // don't memoize right if we are going to elide it during rebuild @@ -449,23 +523,29 @@ object PruneDeadFields { rowType = TStruct(child1.typ.rowType.fieldNames.flatMap(f => child1.typ.keyType.selfField(f).orElse(rType.selfField(f)).map(reqF => f -> reqF.typ) ): _*), - globalType = 
gType) + globalType = gType, + ) children.foreach(memoizeTableIR(ctx, _, dep, memo)) case TableExplode(child, path) => def getExplodedField(typ: TableType): Type = typ.rowType.queryTyped(path.toList)._1 val preExplosionFieldType = getExplodedField(child.typ) - val prunedPreExlosionFieldType = try { - val t = getExplodedField(requestedType) - preExplosionFieldType match { - case ta: TArray => TArray(t) - case ts: TSet => ts.copy(elementType = t) + val prunedPreExlosionFieldType = + try { + val t = getExplodedField(requestedType) + preExplosionFieldType match { + case _: TArray => TArray(t) + case ts: TSet => ts.copy(elementType = t) + } + } catch { + case _: AnnotationPathException => minimal(preExplosionFieldType) } - } catch { - case e: AnnotationPathException => minimal(preExplosionFieldType) - } - val dep = requestedType.copy(rowType = unify(child.typ.rowType, - requestedType.rowType.insert(prunedPreExlosionFieldType, path)._1.asInstanceOf[TStruct])) + val dep = requestedType.copy(rowType = + unify( + child.typ.rowType, + requestedType.rowType.insert(prunedPreExlosionFieldType, path)._1.asInstanceOf[TStruct], + ) + ) memoizeTableIR(ctx, child, dep, memo) case TableFilter(child, pred) => val irDep = memoizeAndGetDep(ctx, pred, pred.typ, child.typ, memo) @@ -476,48 +556,88 @@ object PruneDeadFields { val childReqKey = if (isSorted) child.typ.key else if (isPrefix) - if (reqKey.length <= child.typ.key.length) reqKey else child.typ.key + if (reqKey.length <= child.typ.key.length) reqKey else child.typ.key else FastSeq() - memoizeTableIR(ctx, child, TableType( - key = childReqKey, - rowType = unify(child.typ.rowType, selectKey(child.typ.rowType, childReqKey), requestedType.rowType), - globalType = requestedType.globalType), memo) + memoizeTableIR( + ctx, + child, + TableType( + key = childReqKey, + rowType = unify( + child.typ.rowType, + selectKey(child.typ.rowType, childReqKey), + requestedType.rowType, + ), + globalType = requestedType.globalType, + ), + memo, + ) case TableOrderBy(child, sortFields) => - val k = if (sortFields.forall(_.sortOrder == Ascending) && child.typ.key.startsWith(sortFields.map(_.field))) - child.typ.key - else - FastSeq() - memoizeTableIR(ctx, child, TableType( - key = k, - rowType = unify(child.typ.rowType, - selectKey(child.typ.rowType, sortFields.map(_.field) ++ k), - requestedType.rowType), - globalType = requestedType.globalType), memo) + val k = + if ( + sortFields.forall(_.sortOrder == Ascending) && child.typ.key.startsWith( + sortFields.map(_.field) + ) + ) + child.typ.key + else + FastSeq() + memoizeTableIR( + ctx, + child, + TableType( + key = k, + rowType = unify( + child.typ.rowType, + selectKey(child.typ.rowType, sortFields.map(_.field) ++ k), + requestedType.rowType, + ), + globalType = requestedType.globalType, + ), + memo, + ) case TableDistinct(child) => - val dep = TableType(key = child.typ.key, - rowType = unify(child.typ.rowType, requestedType.rowType, selectKey(child.typ.rowType, child.typ.key)), - globalType = requestedType.globalType) + val dep = TableType( + key = child.typ.key, + rowType = unify( + child.typ.rowType, + requestedType.rowType, + selectKey(child.typ.rowType, child.typ.key), + ), + globalType = requestedType.globalType, + ) memoizeTableIR(ctx, child, dep, memo) case TableMapPartitions(child, gName, pName, body, requestedKey, _) => - val requestedKeyStruct = child.typ.keyType.truncate(math.max(requestedType.key.length, requestedKey)) - val reqRowsType = unify(body.typ, TStream(requestedType.rowType), 
TStream(requestedKeyStruct)) + val requestedKeyStruct = + child.typ.keyType.truncate(math.max(requestedType.key.length, requestedKey)) + val reqRowsType = + unify(body.typ, TStream(requestedType.rowType), TStream(requestedKeyStruct)) val bodyDep = memoizeValueIR(ctx, body, reqRowsType, memo) val depGlobalType = unifySeq( child.typ.globalType, - uses(gName, bodyDep.eval) :+ requestedType.globalType + uses(gName, bodyDep.eval) :+ requestedType.globalType, ) val depRowType = unifySeq( child.typ.rowType, - uses(pName, bodyDep.eval).map(TIterable.elementType) :+ requestedKeyStruct) + uses(pName, bodyDep.eval).map(TIterable.elementType) :+ requestedKeyStruct, + ) val dep = TableType( key = requestedKeyStruct.fieldNames, rowType = depRowType.asInstanceOf[TStruct], - globalType = depGlobalType.asInstanceOf[TStruct]) + globalType = depGlobalType.asInstanceOf[TStruct], + ) memoizeTableIR(ctx, child, dep, memo) case TableMapRows(child, newRow) => val (reqKey, reqRowType) = if (ContainsScan(newRow)) - (child.typ.key, unify(newRow.typ, requestedType.rowType, selectKey(newRow.typ.asInstanceOf[TStruct], child.typ.key))) + ( + child.typ.key, + unify( + newRow.typ, + requestedType.rowType, + selectKey(newRow.typ.asInstanceOf[TStruct], child.typ.key), + ), + ) else (requestedType.key, requestedType.rowType) val rowDep = memoizeAndGetDep(ctx, newRow, reqRowType, child.typ, memo) @@ -525,40 +645,69 @@ object PruneDeadFields { val dep = TableType( key = reqKey, rowType = unify(child.typ.rowType, selectKey(child.typ.rowType, reqKey), rowDep.rowType), - globalType = unify(child.typ.globalType, requestedType.globalType, rowDep.globalType) + globalType = unify(child.typ.globalType, requestedType.globalType, rowDep.globalType), ) memoizeTableIR(ctx, child, dep, memo) case TableMapGlobals(child, newGlobals) => val globalDep = memoizeAndGetDep(ctx, newGlobals, requestedType.globalType, child.typ, memo) - memoizeTableIR(ctx, child, unify(child.typ, requestedType.copy(globalType = globalDep.globalType), globalDep), memo) + memoizeTableIR( + ctx, + child, + unify(child.typ, requestedType.copy(globalType = globalDep.globalType), globalDep), + memo, + ) case TableAggregateByKey(child, expr) => - val exprRequestedType = requestedType.rowType.filter(f => expr.typ.asInstanceOf[TStruct].hasField(f.name))._1 + val exprRequestedType = + requestedType.rowType.filter(f => expr.typ.asInstanceOf[TStruct].hasField(f.name))._1 val aggDep = memoizeAndGetDep(ctx, expr, exprRequestedType, child.typ, memo) - memoizeTableIR(ctx, child, TableType(key = child.typ.key, - rowType = unify(child.typ.rowType, aggDep.rowType, selectKey(child.typ.rowType, child.typ.key)), - globalType = unify(child.typ.globalType, aggDep.globalType, requestedType.globalType)), memo) + memoizeTableIR( + ctx, + child, + TableType( + key = child.typ.key, + rowType = + unify(child.typ.rowType, aggDep.rowType, selectKey(child.typ.rowType, child.typ.key)), + globalType = unify(child.typ.globalType, aggDep.globalType, requestedType.globalType), + ), + memo, + ) case TableKeyByAndAggregate(child, expr, newKey, _, _) => val keyDep = memoizeAndGetDep(ctx, newKey, newKey.typ, child.typ, memo) val exprDep = memoizeAndGetDep(ctx, expr, requestedType.valueType, child.typ, memo) - memoizeTableIR(ctx, child, + memoizeTableIR( + ctx, + child, TableType( key = FastSeq(), // note: this can deoptimize if prune runs before Simplify rowType = unify(child.typ.rowType, keyDep.rowType, exprDep.rowType), - globalType = unify(child.typ.globalType, keyDep.globalType, exprDep.globalType, 
requestedType.globalType)), - memo) + globalType = unify( + child.typ.globalType, + keyDep.globalType, + exprDep.globalType, + requestedType.globalType, + ), + ), + memo, + ) case MatrixColsTable(child) => val mtDep = minimal(child.typ).copy( globalType = requestedType.globalType, entryType = TStruct.empty, colType = requestedType.rowType, - colKey = requestedType.key) + colKey = requestedType.key, + ) memoizeMatrixIR(ctx, child, mtDep, memo) case MatrixRowsTable(child) => val minChild = minimal(child.typ) val mtDep = minChild.copy( globalType = requestedType.globalType, - rowType = unify(child.typ.rowType, selectKey(child.typ.rowType, requestedType.key), requestedType.rowType), - rowKey = requestedType.key) + rowType = unify( + child.typ.rowType, + selectKey(child.typ.rowType, requestedType.key), + requestedType.rowType, + ), + rowKey = requestedType.key, + ) memoizeMatrixIR(ctx, child, mtDep, memo) case MatrixEntriesTable(child) => val mtDep = MatrixType( @@ -566,12 +715,21 @@ object PruneDeadFields { colKey = requestedType.key.drop(child.typ.rowKey.length), globalType = requestedType.globalType, colType = TStruct( - child.typ.colType.fields.flatMap(f => requestedType.rowType.selfField(f.name).map(f2 => f.name -> f2.typ)): _*), + child.typ.colType.fields.flatMap(f => + requestedType.rowType.selfField(f.name).map(f2 => f.name -> f2.typ) + ): _* + ), rowType = TStruct( - child.typ.rowType.fields.flatMap(f => requestedType.rowType.selfField(f.name).map(f2 => f.name -> f2.typ)): _*), + child.typ.rowType.fields.flatMap(f => + requestedType.rowType.selfField(f.name).map(f2 => f.name -> f2.typ) + ): _* + ), entryType = TStruct( - child.typ.entryType.fields.flatMap(f => requestedType.rowType.selfField(f.name).map(f2 => f.name -> f2.typ)): _*) - ) + child.typ.entryType.fields.flatMap(f => + requestedType.rowType.selfField(f.name).map(f2 => f.name -> f2.typ) + ): _* + ), + ) memoizeMatrixIR(ctx, child, mtDep, memo) case TableUnion(children) => memoizeTableIR(ctx, children(0), requestedType, memo) @@ -586,17 +744,22 @@ object PruneDeadFields { else requestedType.globalType, colType = if (requestedType.globalType.hasField(colsFieldName)) - requestedType.globalType.field(colsFieldName).typ.asInstanceOf[TArray].elementType.asInstanceOf[TStruct] + requestedType.globalType.field(colsFieldName).typ.asInstanceOf[ + TArray + ].elementType.asInstanceOf[TStruct] else TStruct.empty, entryType = if (requestedType.rowType.hasField(entriesFieldName)) - requestedType.rowType.field(entriesFieldName).typ.asInstanceOf[TArray].elementType.asInstanceOf[TStruct] + requestedType.rowType.field(entriesFieldName).typ.asInstanceOf[ + TArray + ].elementType.asInstanceOf[TStruct] else TStruct.empty, rowType = if (requestedType.rowType.hasField(entriesFieldName)) requestedType.rowType.deleteKey(entriesFieldName) else - requestedType.rowType) + requestedType.rowType, + ) memoizeMatrixIR(ctx, child, childDep, memo) case TableRename(child, rowMap, globalMap) => val rowMapRev = rowMap.map { case (k, v) => (v, k) } @@ -604,14 +767,24 @@ object PruneDeadFields { val childDep = TableType( rowType = requestedType.rowType.rename(rowMapRev), globalType = requestedType.globalType.rename(globalMapRev), - key = requestedType.key.map(k => rowMapRev.getOrElse(k, k))) + key = requestedType.key.map(k => rowMapRev.getOrElse(k, k)), + ) memoizeTableIR(ctx, child, childDep, memo) case TableFilterIntervals(child, _, _) => - memoizeTableIR(ctx, child, requestedType.copy(key = child.typ.key, - rowType = PruneDeadFields.unify(child.typ.rowType, - 
requestedType.rowType, - PruneDeadFields.selectKey(child.typ.rowType, child.typ.key))), memo) - case TableToTableApply(child, f) => memoizeTableIR(ctx, child, child.typ, memo) + memoizeTableIR( + ctx, + child, + requestedType.copy( + key = child.typ.key, + rowType = PruneDeadFields.unify( + child.typ.rowType, + requestedType.rowType, + PruneDeadFields.selectKey(child.typ.rowType, child.typ.key), + ), + ), + memo, + ) + case TableToTableApply(child, _) => memoizeTableIR(ctx, child, child.typ, memo) case MatrixToTableApply(child, _) => memoizeMatrixIR(ctx, child, child.typ, memo) case BlockMatrixToTableApply(bm, aux, _) => memoizeBlockMatrixIR(ctx, bm, bm.typ, memo) @@ -628,8 +801,8 @@ object PruneDeadFields { ctx: ExecuteContext, mir: MatrixIR, requestedType: MatrixType, - memo: ComputeMutableState - ) { + memo: ComputeMutableState, + ): Unit = { memo.requestedType.bind(mir, requestedType) mir match { case MatrixFilterCols(child, pred) => @@ -641,15 +814,24 @@ object PruneDeadFields { case MatrixFilterEntries(child, pred) => val irDep = memoizeAndGetDep(ctx, pred, pred.typ, child.typ, memo) memoizeMatrixIR(ctx, child, unify(child.typ, requestedType, irDep), memo) - case MatrixUnionCols(left, right, joinType) => + case MatrixUnionCols(left, right, _) => val leftRequestedType = requestedType.copy( rowKey = left.typ.rowKey, - rowType = unify(left.typ.rowType, requestedType.rowType, selectKey(left.typ.rowType, left.typ.rowKey)) + rowType = unify( + left.typ.rowType, + requestedType.rowType, + selectKey(left.typ.rowType, left.typ.rowKey), + ), ) val rightRequestedType = requestedType.copy( globalType = TStruct.empty, rowKey = right.typ.rowKey, - rowType = unify(right.typ.rowType, requestedType.rowType, selectKey(right.typ.rowType, right.typ.rowKey))) + rowType = unify( + right.typ.rowType, + requestedType.rowType, + selectKey(right.typ.rowType, right.typ.rowKey), + ), + ) memoizeMatrixIR(ctx, left, leftRequestedType, memo) memoizeMatrixIR(ctx, right, rightRequestedType, memo) case MatrixMapEntries(child, newEntries) => @@ -665,30 +847,53 @@ object PruneDeadFields { if (reqKey.length <= child.typ.rowKey.length) reqKey else child.typ.rowKey else FastSeq() - memoizeMatrixIR(ctx, child, requestedType.copy( - rowKey = childReqKey, - rowType = unify(child.typ.rowType, requestedType.rowType, selectKey(child.typ.rowType, childReqKey))), - memo) + memoizeMatrixIR( + ctx, + child, + requestedType.copy( + rowKey = childReqKey, + rowType = unify( + child.typ.rowType, + requestedType.rowType, + selectKey(child.typ.rowType, childReqKey), + ), + ), + memo, + ) case MatrixMapRows(child, newRow) => val (reqKey, reqRowType) = if (ContainsScan(newRow)) - (child.typ.rowKey, unify(newRow.typ, requestedType.rowType, selectKey(newRow.typ.asInstanceOf[TStruct], child.typ.rowKey))) + ( + child.typ.rowKey, + unify( + newRow.typ, + requestedType.rowType, + selectKey(newRow.typ.asInstanceOf[TStruct], child.typ.rowKey), + ), + ) else (requestedType.rowKey, requestedType.rowType) val irDep = memoizeAndGetDep(ctx, newRow, reqRowType, child.typ, memo) - val depMod = requestedType.copy(rowType = selectKey(child.typ.rowType, reqKey), rowKey = reqKey) + val depMod = + requestedType.copy(rowType = selectKey(child.typ.rowType, reqKey), rowKey = reqKey) memoizeMatrixIR(ctx, child, unify(child.typ, depMod, irDep), memo) case MatrixMapCols(child, newCol, newKey) => val irDep = memoizeAndGetDep(ctx, newCol, requestedType.colType, child.typ, memo) - val reqKey = newKey match { + val reqKey = newKey match { case Some(_) => FastSeq() case 
None => requestedType.colKey } - val depMod = requestedType.copy(colType = selectKey(child.typ.colType, reqKey), colKey = reqKey) + val depMod = + requestedType.copy(colType = selectKey(child.typ.colType, reqKey), colKey = reqKey) memoizeMatrixIR(ctx, child, unify(child.typ, depMod, irDep), memo) case MatrixMapGlobals(child, newGlobals) => val irDep = memoizeAndGetDep(ctx, newGlobals, requestedType.globalType, child.typ, memo) - memoizeMatrixIR(ctx, child, unify(child.typ, requestedType.copy(globalType = irDep.globalType), irDep), memo) + memoizeMatrixIR( + ctx, + child, + unify(child.typ, requestedType.copy(globalType = irDep.globalType), irDep), + memo, + ) case MatrixRead(_, _, _, _) => case MatrixLiteral(_, _) => case MatrixChooseCols(child, _) => @@ -703,11 +908,16 @@ object PruneDeadFields { Some(f.name -> f.typ) else { requestedColType.selfField(f.name) - .map(requestedField => f.name -> requestedField.typ.asInstanceOf[TArray].elementType) + .map(requestedField => + f.name -> requestedField.typ.asInstanceOf[TArray].elementType + ) } }: _*), rowType = requestedType.rowType, - entryType = TStruct(requestedType.entryType.fields.map(f => f.copy(typ = f.typ.asInstanceOf[TArray].elementType)))) + entryType = TStruct(requestedType.entryType.fields.map(f => + f.copy(typ = f.typ.asInstanceOf[TArray].elementType) + )), + ) memoizeMatrixIR(ctx, child, explodedDep, memo) case MatrixAggregateRowsByKey(child, entryExpr, rowExpr) => val irDepEntry = memoizeAndGetDep(ctx, entryExpr, requestedType.entryType, child.typ, memo) @@ -716,9 +926,21 @@ object PruneDeadFields { rowKey = child.typ.rowKey, colKey = requestedType.colKey, entryType = irDepEntry.entryType, - rowType = unify(child.typ.rowType, selectKey(child.typ.rowType, child.typ.rowKey), irDepRow.rowType, irDepEntry.rowType), - colType = unify(child.typ.colType, requestedType.colType, irDepEntry.colType, irDepRow.colType), - globalType = unify(child.typ.globalType, requestedType.globalType, irDepEntry.globalType, irDepRow.globalType)) + rowType = unify( + child.typ.rowType, + selectKey(child.typ.rowType, child.typ.rowKey), + irDepRow.rowType, + irDepEntry.rowType, + ), + colType = + unify(child.typ.colType, requestedType.colType, irDepEntry.colType, irDepRow.colType), + globalType = unify( + child.typ.globalType, + requestedType.globalType, + irDepEntry.globalType, + irDepRow.globalType, + ), + ) memoizeMatrixIR(ctx, child, childDep, memo) case MatrixAggregateColsByKey(child, entryExpr, colExpr) => val irDepEntry = memoizeAndGetDep(ctx, entryExpr, requestedType.entryType, child.typ, memo) @@ -726,10 +948,22 @@ object PruneDeadFields { val childDep: MatrixType = MatrixType( rowKey = requestedType.rowKey, colKey = child.typ.colKey, - colType = unify(child.typ.colType, irDepCol.colType, irDepEntry.colType, selectKey(child.typ.colType, child.typ.colKey)), - globalType = unify(child.typ.globalType, requestedType.globalType, irDepEntry.globalType, irDepCol.globalType), - rowType = unify(child.typ.rowType, irDepEntry.rowType, irDepCol.rowType, requestedType.rowType), - entryType = irDepEntry.entryType) + colType = unify( + child.typ.colType, + irDepCol.colType, + irDepEntry.colType, + selectKey(child.typ.colType, child.typ.colKey), + ), + globalType = unify( + child.typ.globalType, + requestedType.globalType, + irDepEntry.globalType, + irDepCol.globalType, + ), + rowType = + unify(child.typ.rowType, irDepEntry.rowType, irDepCol.rowType, requestedType.rowType), + entryType = irDepEntry.entryType, + ) memoizeMatrixIR(ctx, child, childDep, memo) case 
MatrixAnnotateRowsTable(child, table, root, product) => val fieldDep = requestedType.rowType.selfField(root).map { field => @@ -744,16 +978,20 @@ object PruneDeadFields { val tableDep = TableType( key = tk, rowType = unify(table.typ.rowType, struct, selectKey(table.typ.rowType, tk)), - globalType = minimal(table.typ.globalType)) + globalType = minimal(table.typ.globalType), + ) memoizeTableIR(ctx, table, tableDep, memo) val mk = unifyKey(FastSeq(child.typ.rowKey.take(tk.length), requestedType.rowKey)) val matDep = requestedType.copy( rowKey = mk, rowType = - unify(child.typ.rowType, + unify( + child.typ.rowType, selectKey(child.typ.rowType, mk), - requestedType.rowType.filterSet(Set(root), include = false)._1)) + requestedType.rowType.filterSet(Set(root), include = false)._1, + ), + ) memoizeMatrixIR(ctx, child, matDep, memo) case None => // don't depend on key IR dependencies if we are going to elide the node anyway @@ -767,14 +1005,20 @@ object PruneDeadFields { val tableDep = TableType( key = tk, rowType = unify(table.typ.rowType, struct, selectKey(table.typ.rowType, tk)), - globalType = minimal(table.typ.globalType)) + globalType = minimal(table.typ.globalType), + ) memoizeTableIR(ctx, table, tableDep, memo) - val mk = unifyKey(FastSeq(child.typ.colKey.take(table.typ.key.length), requestedType.colKey)) + val mk = + unifyKey(FastSeq(child.typ.colKey.take(table.typ.key.length), requestedType.colKey)) val matDep = requestedType.copy( colKey = mk, - colType = unify(child.typ.colType, requestedType.colType.filterSet(Set(uid), include = false)._1, - selectKey(child.typ.colType, mk))) + colType = unify( + child.typ.colType, + requestedType.colType.filterSet(Set(uid), include = false)._1, + selectKey(child.typ.colType, mk), + ), + ) memoizeMatrixIR(ctx, child, matDep, memo) case None => // don't depend on key IR dependencies if we are going to elide the node anyway @@ -784,73 +1028,116 @@ object PruneDeadFields { def getExplodedField(typ: MatrixType): Type = typ.rowType.queryTyped(path.toList)._1 val preExplosionFieldType = getExplodedField(child.typ) - val prunedPreExlosionFieldType = try { - val t = getExplodedField(requestedType) - preExplosionFieldType match { - case ta: TArray => TArray(t) - case ts: TSet => ts.copy(elementType = t) + val prunedPreExlosionFieldType = + try { + val t = getExplodedField(requestedType) + preExplosionFieldType match { + case _: TArray => TArray(t) + case ts: TSet => ts.copy(elementType = t) + } + } catch { + case _: AnnotationPathException => minimal(preExplosionFieldType) } - } catch { - case e: AnnotationPathException => minimal(preExplosionFieldType) - } - val dep = requestedType.copy(rowType = unify(child.typ.rowType, - requestedType.rowType.insert(prunedPreExlosionFieldType, path)._1.asInstanceOf[TStruct])) + val dep = requestedType.copy(rowType = + unify( + child.typ.rowType, + requestedType.rowType.insert(prunedPreExlosionFieldType, path)._1.asInstanceOf[TStruct], + ) + ) memoizeMatrixIR(ctx, child, dep, memo) case MatrixExplodeCols(child, path) => def getExplodedField(typ: MatrixType): Type = typ.colType.queryTyped(path.toList)._1 val preExplosionFieldType = getExplodedField(child.typ) - val prunedPreExplosionFieldType = try { - val t = getExplodedField(requestedType) - preExplosionFieldType match { - case ta: TArray => TArray(t) - case ts: TSet => ts.copy(elementType = t) + val prunedPreExplosionFieldType = + try { + val t = getExplodedField(requestedType) + preExplosionFieldType match { + case _: TArray => TArray(t) + case ts: TSet => 
ts.copy(elementType = t) + } + } catch { + case _: AnnotationPathException => minimal(preExplosionFieldType) } - } catch { - case e: AnnotationPathException => minimal(preExplosionFieldType) - } - val dep = requestedType.copy(colType = unify(child.typ.colType, - requestedType.colType.insert(prunedPreExplosionFieldType, path)._1.asInstanceOf[TStruct])) + val dep = requestedType.copy(colType = + unify( + child.typ.colType, + requestedType.colType.insert(prunedPreExplosionFieldType, path)._1.asInstanceOf[TStruct], + ) + ) memoizeMatrixIR(ctx, child, dep, memo) case MatrixRepartition(child, _, _) => memoizeMatrixIR(ctx, child, requestedType, memo) case MatrixUnionRows(children) => memoizeMatrixIR(ctx, children.head, requestedType, memo) - children.tail.foreach(memoizeMatrixIR(ctx, _, requestedType.copy(colType = requestedType.colKeyStruct), memo)) + children.tail.foreach(memoizeMatrixIR( + ctx, + _, + requestedType.copy(colType = requestedType.colKeyStruct), + memo, + )) case MatrixDistinctByRow(child) => val dep = requestedType.copy( rowKey = child.typ.rowKey, - rowType = unify(child.typ.rowType, requestedType.rowType, selectKey(child.typ.rowType, child.typ.rowKey)) + rowType = unify( + child.typ.rowType, + requestedType.rowType, + selectKey(child.typ.rowType, child.typ.rowKey), + ), ) memoizeMatrixIR(ctx, child, dep, memo) - case MatrixRowsHead(child, n) => + case MatrixRowsHead(child, _) => val dep = requestedType.copy( rowKey = child.typ.rowKey, - rowType = unify(child.typ.rowType, requestedType.rowType, selectKey(child.typ.rowType, child.typ.rowKey)) + rowType = unify( + child.typ.rowType, + requestedType.rowType, + selectKey(child.typ.rowType, child.typ.rowKey), + ), ) memoizeMatrixIR(ctx, child, dep, memo) - case MatrixColsHead(child, n) => memoizeMatrixIR(ctx, child, requestedType, memo) - case MatrixRowsTail(child, n) => + case MatrixColsHead(child, _) => memoizeMatrixIR(ctx, child, requestedType, memo) + case MatrixRowsTail(child, _) => val dep = requestedType.copy( rowKey = child.typ.rowKey, - rowType = unify(child.typ.rowType, requestedType.rowType, selectKey(child.typ.rowType, child.typ.rowKey)) + rowType = unify( + child.typ.rowType, + requestedType.rowType, + selectKey(child.typ.rowType, child.typ.rowKey), + ), ) memoizeMatrixIR(ctx, child, dep, memo) - case MatrixColsTail(child, n) => memoizeMatrixIR(ctx, child, requestedType, memo) + case MatrixColsTail(child, _) => memoizeMatrixIR(ctx, child, requestedType, memo) case CastTableToMatrix(child, entriesFieldName, colsFieldName, _) => - val m = Map(MatrixType.entriesIdentifier -> entriesFieldName) val childDep = child.typ.copy( key = requestedType.rowKey, - globalType = unify(child.typ.globalType, requestedType.globalType, TStruct((colsFieldName, TArray(requestedType.colType)))), - rowType = unify(child.typ.rowType, requestedType.rowType, TStruct((entriesFieldName, TArray(requestedType.entryType)))) + globalType = unify( + child.typ.globalType, + requestedType.globalType, + TStruct((colsFieldName, TArray(requestedType.colType))), + ), + rowType = unify( + child.typ.rowType, + requestedType.rowType, + TStruct((entriesFieldName, TArray(requestedType.entryType))), + ), ) memoizeTableIR(ctx, child, childDep, memo) case MatrixFilterIntervals(child, _, _) => - memoizeMatrixIR(ctx, child, requestedType.copy(rowKey = child.typ.rowKey, - rowType = unify(child.typ.rowType, - requestedType.rowType, - selectKey(child.typ.rowType, child.typ.rowKey))), memo) - case MatrixToMatrixApply(child, f) => memoizeMatrixIR(ctx, child, child.typ, memo) + 
memoizeMatrixIR( + ctx, + child, + requestedType.copy( + rowKey = child.typ.rowKey, + rowType = unify( + child.typ.rowType, + requestedType.rowType, + selectKey(child.typ.rowType, child.typ.rowKey), + ), + ), + memo, + ) + case MatrixToMatrixApply(child, _) => memoizeMatrixIR(ctx, child, child.typ, memo) case MatrixRename(child, globalMap, colMap, rowMap, entryMap) => val globalMapRev = globalMap.map { case (k, v) => (v, k) } val colMapRev = colMap.map { case (k, v) => (v, k) } @@ -862,7 +1149,8 @@ object PruneDeadFields { rowKey = requestedType.rowKey.map(k => rowMapRev.getOrElse(k, k)), colKey = requestedType.colKey.map(k => colMapRev.getOrElse(k, k)), rowType = requestedType.rowType.rename(rowMapRev), - entryType = requestedType.entryType.rename(entryMapRev)) + entryType = requestedType.entryType.rename(entryMapRev), + ) memoizeMatrixIR(ctx, child, childDep, memo) case RelationalLetMatrixTable(name, value, body) => memoizeMatrixIR(ctx, body, requestedType, memo) @@ -875,7 +1163,7 @@ object PruneDeadFields { ctx: ExecuteContext, bmir: BlockMatrixIR, requestedType: BlockMatrixType, - memo: ComputeMutableState + memo: ComputeMutableState, ): Unit = { memo.requestedType.bind(bmir, requestedType) bmir match { @@ -898,7 +1186,7 @@ object PruneDeadFields { ir: IR, requestedType: Type, base: TableType, - memo: ComputeMutableState + memo: ComputeMutableState, ): TableType = { val depEnv = memoizeValueIR(ctx, ir, requestedType, memo) val depEnvUnified = concatEnvs(FastSeq(depEnv.eval) ++ FastSeq(depEnv.agg, depEnv.scan).flatten) @@ -907,20 +1195,19 @@ object PruneDeadFields { depEnvUnified.m.keys.foreach { k => if (!expectedBindingSet.contains(k)) throw new RuntimeException(s"found unexpected free variable in pruning: $k\n" + - s" ${ depEnv.pretty(_.result().mkString(",")) }\n" + - s" ${ Pretty(ctx, ir) }") + s" ${depEnv.pretty(_.result().mkString(","))}\n" + + s" ${Pretty(ctx, ir)}") } val min = minimal(base) - val rowType = unifySeq(base.rowType, - Array(min.rowType) ++ uses("row", depEnvUnified) - ) - val globalType = unifySeq(base.globalType, - Array(min.globalType) ++ uses("global", depEnvUnified) - ) - TableType(key = FastSeq(), + val rowType = unifySeq(base.rowType, Array(min.rowType) ++ uses("row", depEnvUnified)) + val globalType = + unifySeq(base.globalType, Array(min.globalType) ++ uses("global", depEnvUnified)) + TableType( + key = FastSeq(), rowType = rowType.asInstanceOf[TStruct], - globalType = globalType.asInstanceOf[TStruct]) + globalType = globalType.asInstanceOf[TStruct], + ) } def memoizeAndGetDep( @@ -928,7 +1215,7 @@ object PruneDeadFields { ir: IR, requestedType: Type, base: MatrixType, - memo: ComputeMutableState + memo: ComputeMutableState, ): MatrixType = { val depEnv = memoizeValueIR(ctx, ir, requestedType, memo) val depEnvUnified = concatEnvs(FastSeq(depEnv.eval) ++ FastSeq(depEnv.agg, depEnv.scan).flatten) @@ -936,25 +1223,26 @@ object PruneDeadFields { val expectedBindingSet = Set("va", "sa", "g", "global", "n_rows", "n_cols") depEnvUnified.m.keys.foreach { k => if (!expectedBindingSet.contains(k)) - throw new RuntimeException(s"found unexpected free variable in pruning: $k\n ${ Pretty(ctx, ir) }") + throw new RuntimeException( + s"found unexpected free variable in pruning: $k\n ${Pretty(ctx, ir)}" + ) } val min = minimal(base) - val globalType = unifySeq(base.globalType, - Array(min.globalType) ++ uses("global", depEnvUnified)) - .asInstanceOf[TStruct] - val rowType = unifySeq(base.rowType, - Array(min.rowType) ++ uses("va", depEnvUnified)) + val globalType = + 
unifySeq(base.globalType, Array(min.globalType) ++ uses("global", depEnvUnified)) + .asInstanceOf[TStruct] + val rowType = unifySeq(base.rowType, Array(min.rowType) ++ uses("va", depEnvUnified)) .asInstanceOf[TStruct] - val colType = unifySeq(base.colType, - Array(min.colType) ++ uses("sa", depEnvUnified)) + val colType = unifySeq(base.colType, Array(min.colType) ++ uses("sa", depEnvUnified)) .asInstanceOf[TStruct] - val entryType = unifySeq(base.entryType, - Array(min.entryType) ++ uses("g", depEnvUnified)) + val entryType = unifySeq(base.entryType, Array(min.entryType) ++ uses("g", depEnvUnified)) .asInstanceOf[TStruct] if (rowType.hasField(MatrixType.entriesIdentifier)) - throw new RuntimeException(s"prune: found dependence on entry array in row binding:\n${ Pretty(ctx, ir) }") + throw new RuntimeException( + s"prune: found dependence on entry array in row binding:\n${Pretty(ctx, ir)}" + ) MatrixType( rowKey = FastSeq(), @@ -962,25 +1250,24 @@ object PruneDeadFields { globalType = globalType, colType = colType, rowType = rowType, - entryType = entryType) + entryType = entryType, + ) } - /** - * This function does *not* necessarily bind each child node in `memo`. - * Known dead code is not memoized. For instance: + /** This function does *not* necessarily bind each child node in `memo`. Known dead code is not + * memoized. For instance: * - * ir = MakeStruct(Seq("a" -> (child1), "b" -> (child2))) - * requestedType = TStruct("a" -> (reqType of a)) + * ir = MakeStruct(Seq("a" -> (child1), "b" -> (child2))) requestedType = TStruct("a" -> (reqType + * of a)) * - * In the above, `child2` will not be memoized because `ir` does not require - * any of the "b" dependencies in order to create its own requested type, - * which only contains "a". + * In the above, `child2` will not be memoized because `ir` does not require any of the "b" + * dependencies in order to create its own requested type, which only contains "a". 
*/ def memoizeValueIR( ctx: ExecuteContext, ir: IR, requestedType: Type, - memo: ComputeMutableState + memo: ComputeMutableState, ): BindingEnv[BoxedArrayBuilder[Type]] = { memo.requestedType.bind(ir, requestedType) ir match { @@ -995,9 +1282,11 @@ object PruneDeadFields { }) case (TTuple(req), TTuple(cast), TTuple(base)) => assert(base.length == cast.length) - val castFields = cast.map { f => f.index -> f.typ }.toMap - val baseFields = base.map { f => f.index -> f.typ }.toMap - TTuple(req.map { f => TupleField(f.index, recur(f.typ, castFields(f.index), baseFields(f.index)))}) + val castFields = cast.map(f => f.index -> f.typ).toMap + val baseFields = base.map(f => f.index -> f.typ).toMap + TTuple(req.map { f => + TupleField(f.index, recur(f.typ, castFields(f.index), baseFields(f.index))) + }) case (TArray(req), TArray(cast), TArray(base)) => TArray(recur(req, cast, base)) case (TSet(req), TSet(cast), TSet(base)) => @@ -1015,45 +1304,41 @@ object PruneDeadFields { unifyEnvs( memoizeValueIR(ctx, cond, cond.typ, memo), memoizeValueIR(ctx, cnsq, requestedType, memo), - memoizeValueIR(ctx, alt, requestedType, memo) + memoizeValueIR(ctx, alt, requestedType, memo), ) case Switch(x, default, cases) => unifyEnvs( memoizeValueIR(ctx, x, x.typ, memo), memoizeValueIR(ctx, default, requestedType, memo), - unifyEnvsSeq(cases.map { case_ => - memoizeValueIR(ctx, case_, requestedType, memo) - }) + unifyEnvsSeq(cases.map(case_ => memoizeValueIR(ctx, case_, requestedType, memo))), ) case Coalesce(values) => unifyEnvsSeq(values.map(memoizeValueIR(ctx, _, requestedType, memo))) case Consume(value) => memoizeValueIR(ctx, value, value.typ, memo) - case Let(bindings, body) => + case Block(bindings, body) => val bodyEnv = memoizeValueIR(ctx, body, requestedType, memo) - bindings.foldRight(bodyEnv) { case ((name, value), bodyEnv) => - val valueType = unifySeq(value.typ, uses(name, bodyEnv.eval)) - unifyEnvs( - bodyEnv.deleteEval(name), - memoizeValueIR(ctx, value, valueType, memo) - ) - } - case AggLet(name, value, body, isScan) => - val bodyEnv = memoizeValueIR(ctx, body, requestedType, memo) - if (isScan) { - val valueType = unifySeq(value.typ, uses(name, bodyEnv.scanOrEmpty)) - val valueEnv = memoizeValueIR(ctx, value, valueType, memo) - unifyEnvs( - bodyEnv.copy(scan = bodyEnv.scan.map(_.delete(name))), - valueEnv.copy(eval = Env.empty, scan = Some(valueEnv.eval)) - ) - } else { - val valueType = unifySeq(value.typ, uses(name, bodyEnv.aggOrEmpty)) - val valueEnv = memoizeValueIR(ctx, value, valueType, memo) - unifyEnvs( - bodyEnv.copy(agg = bodyEnv.agg.map(_.delete(name))), - valueEnv.copy(eval = Env.empty, agg = Some(valueEnv.eval)) - ) + bindings.foldRight(bodyEnv) { + case (Binding(name, value, Scope.EVAL), bodyEnv) => + val valueType = unifySeq(value.typ, uses(name, bodyEnv.eval)) + unifyEnvs( + bodyEnv.deleteEval(name), + memoizeValueIR(ctx, value, valueType, memo), + ) + case (Binding(name, value, Scope.SCAN), bodyEnv) => + val valueType = unifySeq(value.typ, uses(name, bodyEnv.scanOrEmpty)) + val valueEnv = memoizeValueIR(ctx, value, valueType, memo) + unifyEnvs( + bodyEnv.copy(scan = bodyEnv.scan.map(_.delete(name))), + valueEnv.copy(eval = Env.empty, scan = Some(valueEnv.eval)), + ) + case (Binding(name, value, Scope.AGG), bodyEnv) => + val valueType = unifySeq(value.typ, uses(name, bodyEnv.aggOrEmpty)) + val valueEnv = memoizeValueIR(ctx, value, valueType, memo) + unifyEnvs( + bodyEnv.copy(agg = bodyEnv.agg.map(_.delete(name))), + valueEnv.copy(eval = Env.empty, agg = Some(valueEnv.eval)), + ) } - 
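// --- Editor's illustrative sketch (not part of the patch): a minimal, self-contained model of the
// --- pruning behaviour described in the doc comment above on memoizeValueIR. Every name below
// --- (PruneSketch, TInt32, requiredChildren, ...) is a hypothetical stand-in, not Hail's IR or its
// --- memoizeValueIR API; it only shows that a requested TStruct("a" -> ...) makes the "b" child dead.
object PruneSketch {
  sealed trait Type
  case object TInt32 extends Type
  final case class TStruct(fields: Map[String, Type]) extends Type

  sealed trait IR { def typ: Type }
  final case class I32(x: Int) extends IR { val typ: Type = TInt32 }
  final case class MakeStruct(fields: Seq[(String, IR)]) extends IR {
    val typ: Type = TStruct(fields.map { case (n, c) => n -> c.typ }.toMap)
  }

  // Children whose field is absent from the requested type are dead code: they are never
  // visited, which mirrors the "child2 will not be memoized" behaviour described above.
  def requiredChildren(ir: MakeStruct, requested: TStruct): Seq[(String, IR)] =
    ir.fields.filter { case (name, _) => requested.fields.contains(name) }

  def main(args: Array[String]): Unit = {
    val ir = MakeStruct(Seq("a" -> I32(1), "b" -> I32(2)))
    val requested = TStruct(Map("a" -> TInt32))
    println(requiredChildren(ir, requested).map(_._1)) // List(a)
  }
}
// --- End of editor's sketch; the patch resumes below.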
case Ref(name, t) => + case Ref(name, _) => val ab = new BoxedArrayBuilder[Type]() ab += requestedType BindingEnv.empty.bindEval(name -> ab) @@ -1075,85 +1360,105 @@ object PruneDeadFields { unifyEnvs( memoizeValueIR(ctx, a, TArray(requestedType), memo), memoizeValueIR(ctx, i, i.typ, memo), - memoizeValueIR(ctx, s, s.typ, memo) + memoizeValueIR(ctx, s, s.typ, memo), ) case ArrayLen(a) => memoizeValueIR(ctx, a, minimal(a.typ), memo) case StreamTake(a, len) => unifyEnvs( memoizeValueIR(ctx, a, requestedType, memo), - memoizeValueIR(ctx, len, len.typ, memo)) + memoizeValueIR(ctx, len, len.typ, memo), + ) case StreamDrop(a, len) => unifyEnvs( memoizeValueIR(ctx, a, requestedType, memo), - memoizeValueIR(ctx, len, len.typ, memo)) + memoizeValueIR(ctx, len, len.typ, memo), + ) case StreamWhiten(a, newChunk, prevWindow, _, _, _, _, _) => val matType = TNDArray(TFloat64, Nat(2)) val unifiedStructType = unify( a.typ.asInstanceOf[TStream].elementType, requestedType.asInstanceOf[TStream].elementType, - TStruct((newChunk, matType), (prevWindow, matType))) + TStruct((newChunk, matType), (prevWindow, matType)), + ) unifyEnvs( - memoizeValueIR(ctx, a, TStream(unifiedStructType), memo)) - case StreamMap(a, name, body) => - val bodyEnv = memoizeValueIR(ctx, body, - TIterable.elementType(requestedType), - memo + memoizeValueIR(ctx, a, TStream(unifiedStructType), memo) ) + case StreamMap(a, name, body) => + val bodyEnv = memoizeValueIR(ctx, body, TIterable.elementType(requestedType), memo) val valueType = unifySeq(TIterable.elementType(a.typ), uses(name, bodyEnv.eval)) unifyEnvs( bodyEnv.deleteEval(name), - memoizeValueIR(ctx, a, TStream(valueType), memo) + memoizeValueIR(ctx, a, TStream(valueType), memo), ) case StreamGrouped(a, size) => unifyEnvs( memoizeValueIR(ctx, a, TIterable.elementType(requestedType), memo), - memoizeValueIR(ctx, size, size.typ, memo)) + memoizeValueIR(ctx, size, size.typ, memo), + ) case StreamGroupByKey(a, key, _) => - val reqStructT = tcoerce[TStruct](tcoerce[TStream](tcoerce[TStream](requestedType).elementType).elementType) + val reqStructT = tcoerce[TStruct]( + tcoerce[TStream](tcoerce[TStream](requestedType).elementType).elementType + ) val origStructT = tcoerce[TStruct](tcoerce[TStream](a.typ).elementType) - memoizeValueIR(ctx, a, TStream(unify(origStructT, reqStructT, selectKey(origStructT, key))), memo) + memoizeValueIR( + ctx, + a, + TStream(unify(origStructT, reqStructT, selectKey(origStructT, key))), + memo, + ) case StreamZip(as, names, body, behavior, _) => - val bodyEnv = memoizeValueIR(ctx, body, - TIterable.elementType(requestedType), - memo) + val bodyEnv = memoizeValueIR(ctx, body, TIterable.elementType(requestedType), memo) val valueTypes = (names, as).zipped.map { (name, a) => - bodyEnv.eval.lookupOption(name).map(ab => unifySeq(tcoerce[TStream](a.typ).elementType, ab.result())) + bodyEnv.eval.lookupOption(name).map(ab => + unifySeq(tcoerce[TStream](a.typ).elementType, ab.result()) + ) } if (behavior == ArrayZipBehavior.AssumeSameLength && valueTypes.forall(_.isEmpty)) { - unifyEnvs(memoizeValueIR(ctx, as.head, TStream(minimal(tcoerce[TStream](as.head.typ).elementType)), memo) +: - Array(bodyEnv.deleteEval(names)): _*) + unifyEnvs(memoizeValueIR( + ctx, + as.head, + TStream(minimal(tcoerce[TStream](as.head.typ).elementType)), + memo, + ) +: + Array(bodyEnv.deleteEval(names)): _*) } else { unifyEnvs( (as, valueTypes).zipped.map { (a, vtOption) => val at = tcoerce[TStream](a.typ) if (behavior == ArrayZipBehavior.AssumeSameLength) { - vtOption.map { vt => - 
memoizeValueIR(ctx, a, TStream(vt), memo) - }.getOrElse(BindingEnv.empty) + vtOption.map(vt => memoizeValueIR(ctx, a, TStream(vt), memo)).getOrElse( + BindingEnv.empty + ) } else memoizeValueIR(ctx, a, TStream(vtOption.getOrElse(minimal(at.elementType))), memo) - } ++ Array(bodyEnv.deleteEval(names)): _*) + } ++ Array(bodyEnv.deleteEval(names)): _* + ) } - case StreamZipJoin(as, key, curKey, curVals, joinF) => + case StreamZipJoin(as, key, _, curVals, joinF) => val eltType = tcoerce[TStruct](tcoerce[TStream](as.head.typ).elementType) val requestedEltType = tcoerce[TStream](requestedType).elementType val bodyEnv = memoizeValueIR(ctx, joinF, requestedEltType, memo) val childRequestedEltType = unifySeq( eltType, - uses(curVals, bodyEnv.eval).map(TIterable.elementType) :+ selectKey(eltType, key) + uses(curVals, bodyEnv.eval).map(TIterable.elementType) :+ selectKey(eltType, key), ) unifyEnvsSeq(as.map(memoizeValueIR(ctx, _, TStream(childRequestedEltType), memo))) - case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, joinF) => + case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, _, curVals, joinF) => val baseEltType = tcoerce[TStruct](TIterable.elementType(makeProducer.typ)) val requestedEltType = tcoerce[TStream](requestedType).elementType val bodyEnv = memoizeValueIR(ctx, joinF, requestedEltType, memo) val producerRequestedEltType = unifySeq( baseEltType, - uses(curVals, bodyEnv.eval).map(TIterable.elementType) :+ selectKey(baseEltType, key) + uses(curVals, bodyEnv.eval).map(TIterable.elementType) :+ selectKey(baseEltType, key), ) val producerEnv = memoizeValueIR(ctx, makeProducer, TStream(producerRequestedEltType), memo) - val ctxEnv = memoizeValueIR(ctx, contexts, TArray(unifySeq(TIterable.elementType(contexts.typ), uses(ctxName, producerEnv.eval))), memo) + val ctxEnv = memoizeValueIR( + ctx, + contexts, + TArray(unifySeq(TIterable.elementType(contexts.typ), uses(ctxName, producerEnv.eval))), + memo, + ) unifyEnvsSeq(Array(bodyEnv, producerEnv, ctxEnv)) case StreamMultiMerge(as, key) => val eltType = tcoerce[TStruct](tcoerce[TStream](as.head.typ).elementType) @@ -1164,38 +1469,38 @@ object PruneDeadFields { val bodyEnv = memoizeValueIR(ctx, cond, cond.typ, memo) val valueType = unifySeq( TIterable.elementType(a.typ), - FastSeq(TIterable.elementType(requestedType)) ++ uses(name, bodyEnv.eval) + FastSeq(TIterable.elementType(requestedType)) ++ uses(name, bodyEnv.eval), ) unifyEnvs( bodyEnv.deleteEval(name), - memoizeValueIR(ctx, a, TStream(valueType), memo) + memoizeValueIR(ctx, a, TStream(valueType), memo), ) case StreamTakeWhile(a, name, cond) => val bodyEnv = memoizeValueIR(ctx, cond, cond.typ, memo) val valueType = unifySeq( TIterable.elementType(a.typ), - FastSeq(TIterable.elementType(requestedType)) ++ uses(name, bodyEnv.eval) + FastSeq(TIterable.elementType(requestedType)) ++ uses(name, bodyEnv.eval), ) unifyEnvs( bodyEnv.deleteEval(name), - memoizeValueIR(ctx, a, TStream(valueType), memo) + memoizeValueIR(ctx, a, TStream(valueType), memo), ) case StreamDropWhile(a, name, cond) => val bodyEnv = memoizeValueIR(ctx, cond, cond.typ, memo) val valueType = unifySeq( TIterable.elementType(a.typ), - FastSeq(TIterable.elementType(requestedType)) ++ uses(name, bodyEnv.eval) + FastSeq(TIterable.elementType(requestedType)) ++ uses(name, bodyEnv.eval), ) unifyEnvs( bodyEnv.deleteEval(name), - memoizeValueIR(ctx, a, TStream(valueType), memo) + memoizeValueIR(ctx, a, TStream(valueType), memo), ) case StreamFlatMap(a, name, body) => val bodyEnv = 
memoizeValueIR(ctx, body, requestedType, memo) val valueType = unifySeq(TIterable.elementType(a.typ), uses(name, bodyEnv.eval)) unifyEnvs( bodyEnv.deleteEval(name), - memoizeValueIR(ctx, a, TStream(valueType), memo) + memoizeValueIR(ctx, a, TStream(valueType), memo), ) case StreamFold(a, zero, accumName, valueName, body) => val zeroEnv = memoizeValueIR(ctx, zero, zero.typ, memo) @@ -1205,15 +1510,15 @@ object PruneDeadFields { unifyEnvs( zeroEnv, bodyEnv.deleteEval(valueName).deleteEval(accumName), - memoizeValueIR(ctx, a, TStream(valueType), memo) + memoizeValueIR(ctx, a, TStream(valueType), memo), ) case StreamFold2(a, accum, valueName, seq, res) => - val zeroEnvs = accum.map { case (name, zval) => memoizeValueIR(ctx, zval, zval.typ, memo) } - val seqEnvs = seq.map { seq => memoizeValueIR(ctx, seq, seq.typ, memo) } + val zeroEnvs = accum.map { case (_, zval) => memoizeValueIR(ctx, zval, zval.typ, memo) } + val seqEnvs = seq.map(seq => memoizeValueIR(ctx, seq, seq.typ, memo)) val resEnv = memoizeValueIR(ctx, res, requestedType, memo) val valueType = unifySeq( TIterable.elementType(a.typ), - uses(valueName, resEnv.eval) ++ seqEnvs.flatMap(e => uses(valueName, e.eval)) + uses(valueName, resEnv.eval) ++ seqEnvs.flatMap(e => uses(valueName, e.eval)), ) val accumNames = accum.map(_._1) @@ -1231,10 +1536,10 @@ object PruneDeadFields { unifyEnvs( zeroEnv, bodyEnv.deleteEval(valueName).deleteEval(accumName), - memoizeValueIR(ctx, a, TStream(valueType), memo) + memoizeValueIR(ctx, a, TStream(valueType), memo), ) - - case StreamJoinRightDistinct(left, right, lKey, rKey, l, r, join, joinType) => + + case StreamJoinRightDistinct(left, right, lKey, rKey, l, r, join, _) => val lElemType = TIterable.elementType(left.typ).asInstanceOf[TStruct] val rElemType = TIterable.elementType(right.typ).asInstanceOf[TStruct] @@ -1245,21 +1550,47 @@ object PruneDeadFields { unifyEnvs( joinEnv.deleteEval(l).deleteEval(r), memoizeValueIR(ctx, left, TStream(lRequested), memo), - memoizeValueIR(ctx, right, TStream(rRequested), memo) + memoizeValueIR(ctx, right, TStream(rRequested), memo), + ) + + case StreamLeftIntervalJoin(left, right, keyFieldName, intervalFieldName, lname, rname, + body) => + val joinEnv = memoizeValueIR(ctx, body, elementType(requestedType), memo) + + val lEltType = elementType(left.typ).asInstanceOf[TStruct] + val lRequestedType = unifySeq( + lEltType, + uses(lname, joinEnv.eval) :+ selectKey(lEltType, FastSeq(keyFieldName)), + ) + + val rEltType = elementType(right.typ).asInstanceOf[TStruct] + val rRequestedType = unifySeq( + TArray(rEltType), + uses(rname, joinEnv.eval) :+ TArray(selectKey(rEltType, FastSeq(intervalFieldName))), + ) + + unifyEnvs( + joinEnv.deleteEval(lname).deleteEval(rname), + memoizeValueIR(ctx, left, TStream(lRequestedType), memo), + memoizeValueIR(ctx, right, TStream(elementType(rRequestedType)), memo), ) + case ArraySort(a, left, right, lessThan) => val compEnv = memoizeValueIR(ctx, lessThan, lessThan.typ, memo) val requestedElementType = unifySeq( TIterable.elementType(a.typ), - Array(TIterable.elementType(requestedType)) ++ uses(left, compEnv.eval) ++ uses(right, compEnv.eval) + Array(TIterable.elementType(requestedType)) ++ uses(left, compEnv.eval) ++ uses( + right, + compEnv.eval, + ), ) - + unifyEnvs( compEnv.deleteEval(left).deleteEval(right), - memoizeValueIR(ctx, a, TStream(requestedElementType), memo) + memoizeValueIR(ctx, a, TStream(requestedElementType), memo), ) - + case ArrayMaximalIndependentSet(a, tiebreaker) => tiebreaker.foreach { case (_, _, tb) => 
memoizeValueIR(ctx, tb, tb.typ, memo) } memoizeValueIR(ctx, a, a.typ, memo) @@ -1268,34 +1599,38 @@ object PruneDeadFields { val bodyEnv = memoizeValueIR(ctx, body, body.typ, memo) val valueType = unifySeq( TIterable.elementType(a.typ), - uses(valueName, bodyEnv.eval)) + uses(valueName, bodyEnv.eval), + ) unifyEnvs( bodyEnv.deleteEval(valueName), - memoizeValueIR(ctx, a, TStream(valueType), memo) + memoizeValueIR(ctx, a, TStream(valueType), memo), ) - case MakeNDArray(data, shape, rowMajor, errorId) => + case MakeNDArray(data, shape, rowMajor, _) => val elementType = requestedType.asInstanceOf[TNDArray].elementType - val dataType = if (data.typ.isInstanceOf[TArray]) TArray(elementType) else TStream(elementType) + val dataType = + if (data.typ.isInstanceOf[TArray]) TArray(elementType) else TStream(elementType) unifyEnvs( memoizeValueIR(ctx, data, dataType, memo), memoizeValueIR(ctx, shape, shape.typ, memo), - memoizeValueIR(ctx, rowMajor, rowMajor.typ, memo) + memoizeValueIR(ctx, rowMajor, rowMajor.typ, memo), ) case NDArrayMap(nd, valueName, body) => val ndType = nd.typ.asInstanceOf[TNDArray] - val bodyEnv = memoizeValueIR(ctx, body, requestedType.asInstanceOf[TNDArray].elementType, memo) + val bodyEnv = + memoizeValueIR(ctx, body, requestedType.asInstanceOf[TNDArray].elementType, memo) val valueType = unifySeq( ndType.elementType, - uses(valueName, bodyEnv.eval) + uses(valueName, bodyEnv.eval), ) unifyEnvs( bodyEnv.deleteEval(valueName), - memoizeValueIR(ctx, nd, ndType.copy(elementType = valueType), memo) + memoizeValueIR(ctx, nd, ndType.copy(elementType = valueType), memo), ) case NDArrayMap2(left, right, leftName, rightName, body, _) => val leftType = left.typ.asInstanceOf[TNDArray] val rightType = right.typ.asInstanceOf[TNDArray] - val bodyEnv = memoizeValueIR(ctx, body, requestedType.asInstanceOf[TNDArray].elementType, memo) + val bodyEnv = + memoizeValueIR(ctx, body, requestedType.asInstanceOf[TNDArray].elementType, memo) val leftValueType = unifySeq(leftType.elementType, uses(leftName, bodyEnv.eval)) val rightValueType = unifySeq(rightType.elementType, uses(rightName, bodyEnv.eval)) @@ -1303,25 +1638,23 @@ object PruneDeadFields { unifyEnvs( bodyEnv.deleteEval(leftName).deleteEval(rightName), memoizeValueIR(ctx, left, leftType.copy(elementType = leftValueType), memo), - memoizeValueIR(ctx, right, rightType.copy(elementType = rightValueType), memo) + memoizeValueIR(ctx, right, rightType.copy(elementType = rightValueType), memo), ) case AggExplode(a, name, body, isScan) => - val bodyEnv = memoizeValueIR(ctx, body, - requestedType, - memo) + val bodyEnv = memoizeValueIR(ctx, body, requestedType, memo) if (isScan) { val valueType = unifySeq(TIterable.elementType(a.typ), uses(name, bodyEnv.scanOrEmpty)) val aEnv = memoizeValueIR(ctx, a, TStream(valueType), memo) unifyEnvs( BindingEnv(scan = bodyEnv.scan.map(_.delete(name))), - BindingEnv(scan = Some(aEnv.eval)) + BindingEnv(scan = Some(aEnv.eval)), ) } else { val valueType = unifySeq(TIterable.elementType(a.typ), uses(name, bodyEnv.aggOrEmpty)) val aEnv = memoizeValueIR(ctx, a, TStream(valueType), memo) unifyEnvs( BindingEnv(agg = bodyEnv.agg.map(_.delete(name))), - BindingEnv(agg = Some(aEnv.eval)) + BindingEnv(agg = Some(aEnv.eval)), ) } case AggFilter(cond, aggIR, isScan) => @@ -1331,7 +1664,7 @@ object PruneDeadFields { BindingEnv(scan = Some(condEnv.eval)) else BindingEnv(agg = Some(condEnv.eval)), - memoizeValueIR(ctx, aggIR, requestedType, memo) + memoizeValueIR(ctx, aggIR, requestedType, memo), ) case AggGroupBy(key, aggIR, 
isScan) => val keyEnv = memoizeValueIR(ctx, key, requestedType.asInstanceOf[TDict].keyType, memo) @@ -1340,39 +1673,52 @@ object PruneDeadFields { BindingEnv(scan = Some(keyEnv.eval)) else BindingEnv(agg = Some(keyEnv.eval)), - memoizeValueIR(ctx, aggIR, requestedType.asInstanceOf[TDict].valueType, memo) + memoizeValueIR(ctx, aggIR, requestedType.asInstanceOf[TDict].valueType, memo), ) case AggArrayPerElement(a, elementName, indexName, aggBody, knownLength, isScan) => - val aType = a.typ.asInstanceOf[TArray] - val bodyEnv = memoizeValueIR(ctx, aggBody, - TIterable.elementType(requestedType), - memo) + val bodyEnv = memoizeValueIR(ctx, aggBody, TIterable.elementType(requestedType), memo) if (isScan) { - val valueType = unifySeq(TIterable.elementType(a.typ), uses(elementName, bodyEnv.scanOrEmpty)) + val valueType = + unifySeq(TIterable.elementType(a.typ), uses(elementName, bodyEnv.scanOrEmpty)) val aEnv = memoizeValueIR(ctx, a, TArray(valueType), memo) unifyEnvsSeq(FastSeq( - bodyEnv.copy(eval = bodyEnv.eval.delete(indexName), scan = bodyEnv.scan.map(_.delete(elementName))), - BindingEnv(scan = Some(aEnv.eval)) + bodyEnv.copy( + eval = bodyEnv.eval.delete(indexName), + scan = bodyEnv.scan.map(_.delete(elementName)), + ), + BindingEnv(scan = Some(aEnv.eval)), ) ++ knownLength.map(x => memoizeValueIR(ctx, x, x.typ, memo))) } else { - val valueType = unifySeq(TIterable.elementType(a.typ), uses(elementName, bodyEnv.aggOrEmpty)) + val valueType = + unifySeq(TIterable.elementType(a.typ), uses(elementName, bodyEnv.aggOrEmpty)) val aEnv = memoizeValueIR(ctx, a, TArray(valueType), memo) unifyEnvsSeq(FastSeq( - bodyEnv.copy(eval = bodyEnv.eval.delete(indexName), agg = bodyEnv.agg.map(_.delete(elementName))), - BindingEnv(agg = Some(aEnv.eval)) + bodyEnv.copy( + eval = bodyEnv.eval.delete(indexName), + agg = bodyEnv.agg.map(_.delete(elementName)), + ), + BindingEnv(agg = Some(aEnv.eval)), ) ++ knownLength.map(x => memoizeValueIR(ctx, x, x.typ, memo))) } case ApplyAggOp(initOpArgs, seqOpArgs, sig) => val prunedSig = AggSignature.prune(sig, requestedType) - val initEnv = unifyEnvsSeq((initOpArgs, prunedSig.initOpArgs).zipped.map { (arg, req) => memoizeValueIR(ctx, arg, req, memo) }) - val seqOpEnv = unifyEnvsSeq((seqOpArgs, prunedSig.seqOpArgs).zipped.map { (arg, req) => memoizeValueIR(ctx, arg, req, memo) }) + val initEnv = unifyEnvsSeq((initOpArgs, prunedSig.initOpArgs).zipped.map { (arg, req) => + memoizeValueIR(ctx, arg, req, memo) + }) + val seqOpEnv = unifyEnvsSeq((seqOpArgs, prunedSig.seqOpArgs).zipped.map { (arg, req) => + memoizeValueIR(ctx, arg, req, memo) + }) BindingEnv(eval = initEnv.eval, agg = Some(seqOpEnv.eval)) case ApplyScanOp(initOpArgs, seqOpArgs, sig) => val prunedSig = AggSignature.prune(sig, requestedType) - val initEnv = unifyEnvsSeq((initOpArgs, prunedSig.initOpArgs).zipped.map { (arg, req) => memoizeValueIR(ctx, arg, req, memo) }) - val seqOpEnv = unifyEnvsSeq((seqOpArgs, prunedSig.seqOpArgs).zipped.map { (arg, req) => memoizeValueIR(ctx, arg, req, memo) }) + val initEnv = unifyEnvsSeq((initOpArgs, prunedSig.initOpArgs).zipped.map { (arg, req) => + memoizeValueIR(ctx, arg, req, memo) + }) + val seqOpEnv = unifyEnvsSeq((seqOpArgs, prunedSig.seqOpArgs).zipped.map { (arg, req) => + memoizeValueIR(ctx, arg, req, memo) + }) BindingEnv(eval = initEnv.eval, scan = Some(seqOpEnv.eval)) - case AggFold(zero, seqOp, combOp, accumName, otherAccumName, isScan) => + case AggFold(zero, seqOp, combOp, accumName, _, isScan) => val initEnv = memoizeValueIR(ctx, zero, zero.typ, memo) val seqEnv = 
memoizeValueIR(ctx, seqOp, seqOp.typ, memo) memoizeValueIR(ctx, combOp, combOp.typ, memo) @@ -1383,27 +1729,32 @@ object PruneDeadFields { BindingEnv(eval = initEnv.eval, agg = Some(seqEnv.eval.delete(accumName))) case StreamAgg(a, name, query) => val queryEnv = memoizeValueIR(ctx, query, requestedType, memo) - val requestedElemType = unifySeq(TIterable.elementType(a.typ), uses(name, queryEnv.aggOrEmpty)) + val requestedElemType = + unifySeq(TIterable.elementType(a.typ), uses(name, queryEnv.aggOrEmpty)) val aEnv = memoizeValueIR(ctx, a, TStream(requestedElemType), memo) unifyEnvs( BindingEnv(eval = concatEnvs(Array(queryEnv.eval, queryEnv.aggOrEmpty.delete(name)))), - aEnv) + aEnv, + ) case StreamAggScan(a, name, query) => val queryEnv = memoizeValueIR(ctx, query, TIterable.elementType(requestedType), memo) val requestedElemType = unifySeq( TIterable.elementType(a.typ), - uses(name, queryEnv.scanOrEmpty) ++ uses(name, queryEnv.eval) + uses(name, queryEnv.scanOrEmpty) ++ uses(name, queryEnv.eval), ) val aEnv = memoizeValueIR(ctx, a, TStream(requestedElemType), memo) unifyEnvs( - BindingEnv(eval = concatEnvs(Array(queryEnv.eval.delete(name), queryEnv.scanOrEmpty.delete(name)))), - aEnv) + BindingEnv(eval = + concatEnvs(Array(queryEnv.eval.delete(name), queryEnv.scanOrEmpty.delete(name))) + ), + aEnv, + ) case RunAgg(body, result, _) => unifyEnvs( memoizeValueIR(ctx, body, body.typ, memo), - memoizeValueIR(ctx, result, requestedType, memo) + memoizeValueIR(ctx, result, requestedType, memo), ) - case RunAggScan(array, name, init, seqs, result, signature) => + case RunAggScan(array, name, init, seqs, result, _) => val resultEnv = memoizeValueIR(ctx, result, TIterable.elementType(requestedType), memo) val seqEnv = memoizeValueIR(ctx, seqs, seqs.typ, memo) val elemEnv = unifyEnvs(resultEnv, seqEnv) @@ -1411,7 +1762,7 @@ object PruneDeadFields { unifyEnvs( elemEnv, memoizeValueIR(ctx, array, TStream(requestedElemType), memo), - memoizeValueIR(ctx, init, init.typ, memo) + memoizeValueIR(ctx, init, init.typ, memo), ) case MakeStruct(fields) => val sType = requestedType.asInstanceOf[TStruct] @@ -1423,7 +1774,6 @@ object PruneDeadFields { val sType = requestedType.asInstanceOf[TStruct] val insFieldNames = fields.map(_._1).toSet val rightDep = sType.filter(f => insFieldNames.contains(f.name))._1 - val rightDepFields = rightDep.fieldNames.toSet val leftDep = TStruct( old.typ.asInstanceOf[TStruct] .fields @@ -1432,7 +1782,8 @@ object PruneDeadFields { Some(f.name -> minimal(f.typ)) else sType.selfField(f.name).map(f.name -> _.typ) - }: _*) + }: _* + ) unifyEnvsSeq( FastSeq(memoizeValueIR(ctx, old, leftDep, memo)) ++ // ignore unreachable fields, these are eliminated on the upwards pass @@ -1440,7 +1791,7 @@ object PruneDeadFields { rightDep.selfField(fname).map(f => memoizeValueIR(ctx, fir, f.typ, memo)) } ) - case SelectFields(old, fields) => + case SelectFields(old, _) => val sType = requestedType.asInstanceOf[TStruct] val oldReqType = TStruct(old.typ.asInstanceOf[TStruct] .fieldNames @@ -1455,17 +1806,16 @@ object PruneDeadFields { fields.flatMap { case (i, value) => // ignore unreachable fields, these are eliminated on the upwards pass tType.fieldIndex.get(i) - .map { idx => - memoizeValueIR(ctx, value, tType.types(idx), memo) - }}) + .map(idx => memoizeValueIR(ctx, value, tType.types(idx), memo)) + } + ) case GetTupleElement(o, idx) => - val childTupleType = o.typ.asInstanceOf[TTuple] val tupleDep = TTuple(FastSeq(TupleField(idx, requestedType))) memoizeValueIR(ctx, o, tupleDep, memo) case 
ConsoleLog(message, result) => unifyEnvs( memoizeValueIR(ctx, message, TString, memo), - memoizeValueIR(ctx, result, result.typ, memo) + memoizeValueIR(ctx, result, result.typ, memo), ) case MatrixCount(child) => memoizeMatrixIR(ctx, child, minimal(child.typ), memo) @@ -1474,30 +1824,48 @@ object PruneDeadFields { memoizeTableIR(ctx, child, minimal(child.typ), memo) BindingEnv.empty case TableGetGlobals(child) => - memoizeTableIR(ctx, child, minimal(child.typ).copy(globalType = requestedType.asInstanceOf[TStruct]), memo) + memoizeTableIR( + ctx, + child, + minimal(child.typ).copy(globalType = requestedType.asInstanceOf[TStruct]), + memo, + ) BindingEnv.empty case TableCollect(child) => val rStruct = requestedType.asInstanceOf[TStruct] - memoizeTableIR(ctx, child, TableType( - key = child.typ.key, - rowType = unify(child.typ.rowType, - rStruct.selfField("rows").map(_.typ.asInstanceOf[TArray].elementType.asInstanceOf[TStruct]).getOrElse(TStruct.empty)), - globalType = rStruct.selfField("global").map(_.typ.asInstanceOf[TStruct]).getOrElse(TStruct.empty)), - memo) + memoizeTableIR( + ctx, + child, + TableType( + key = child.typ.key, + rowType = unify( + child.typ.rowType, + rStruct.selfField("rows").map( + _.typ.asInstanceOf[TArray].elementType.asInstanceOf[TStruct] + ).getOrElse(TStruct.empty), + ), + globalType = + rStruct.selfField("global").map(_.typ.asInstanceOf[TStruct]).getOrElse(TStruct.empty), + ), + memo, + ) BindingEnv.empty case TableToValueApply(child, _) => memoizeTableIR(ctx, child, child.typ, memo) BindingEnv.empty - case MatrixToValueApply(child, _) => memoizeMatrixIR(ctx, child, child.typ, memo) + case MatrixToValueApply(child, _) => + memoizeMatrixIR(ctx, child, child.typ, memo) BindingEnv.empty - case BlockMatrixToValueApply(child, _) => memoizeBlockMatrixIR(ctx, child, child.typ, memo) + case BlockMatrixToValueApply(child, _) => + memoizeBlockMatrixIR(ctx, child, child.typ, memo) BindingEnv.empty case TableAggregate(child, query) => val queryDep = memoizeAndGetDep(ctx, query, query.typ, child.typ, memo) val dep = TableType( key = child.typ.key, - rowType = unify(child.typ.rowType, queryDep.rowType, selectKey(child.typ.rowType, child.typ.key)), - globalType = queryDep.globalType + rowType = + unify(child.typ.rowType, queryDep.rowType, selectKey(child.typ.rowType, child.typ.key)), + globalType = queryDep.globalType, ) memoizeTableIR(ctx, child, dep, memo) BindingEnv.empty @@ -1506,25 +1874,29 @@ object PruneDeadFields { val dep = MatrixType( rowKey = child.typ.rowKey, colKey = FastSeq(), - rowType = unify(child.typ.rowType, queryDep.rowType, selectKey(child.typ.rowType, child.typ.rowKey)), + rowType = unify( + child.typ.rowType, + queryDep.rowType, + selectKey(child.typ.rowType, child.typ.rowKey), + ), entryType = queryDep.entryType, colType = queryDep.colType, - globalType = queryDep.globalType + globalType = queryDep.globalType, ) memoizeMatrixIR(ctx, child, dep, memo) BindingEnv.empty - case TailLoop(name, params, _, body) => + case TailLoop(_, params, _, body) => val bodyEnv = memoizeValueIR(ctx, body, body.typ, memo) - val paramTypes = params.map{ case (paramName, paramIR) => + val paramTypes = params.map { case (paramName, paramIR) => unifySeq(paramIR.typ, uses(paramName, bodyEnv.eval)) } unifyEnvsSeq( IndexedSeq(bodyEnv.deleteEval(params.map(_._1))) ++ - (params, paramTypes).zipped.map{ case ((paramName, paramIR), paramType) => - memoizeValueIR(ctx, paramIR, paramType, memo) - } + (params, paramTypes).zipped.map { case ((_, paramIR), paramType) => + 
memoizeValueIR(ctx, paramIR, paramType, memo) + } ) - case CollectDistributedArray(contexts, globals, cname, gname, body, dynamicID, _, tsd) => + case CollectDistributedArray(contexts, globals, cname, gname, body, dynamicID, _, _) => val rArray = requestedType.asInstanceOf[TArray] val bodyEnv = memoizeValueIR(ctx, body, rArray.elementType, memo) assert(bodyEnv.scan.isEmpty) @@ -1546,7 +1918,7 @@ object PruneDeadFields { case tir: TableIR => memoizeTableIR(ctx, tir, tir.typ, memo) None - case bmir: BlockMatrixIR => //NOTE Currently no BlockMatrixIRs would have dead fields + case _: BlockMatrixIR => // NOTE Currently no BlockMatrixIRs would have dead fields None case ir: IR => Some(memoizeValueIR(ctx, ir, ir.typ, memo)) @@ -1558,15 +1930,19 @@ object PruneDeadFields { def rebuild( ctx: ExecuteContext, tir: TableIR, - memo: RebuildMutableState + memo: RebuildMutableState, ): TableIR = { val requestedType = memo.requestedType.lookup(tir).asInstanceOf[TableType] tir match { case TableParallelize(rowsAndGlobal, nPartitions) => TableParallelize( - upcast(ctx, rebuildIR(ctx, rowsAndGlobal, BindingEnv.empty, memo), - memo.requestedType.lookup(rowsAndGlobal).asInstanceOf[TStruct]), - nPartitions) + upcast( + ctx, + rebuildIR(ctx, rowsAndGlobal, BindingEnv.empty, memo), + memo.requestedType.lookup(rowsAndGlobal).asInstanceOf[TStruct], + ), + nPartitions, + ) case TableGen(contexts, globals, cname, gname, body, partitioner, errorId) => val newContexts = rebuildIR(ctx, contexts, BindingEnv.empty, memo) @@ -1579,7 +1955,7 @@ object PruneDeadFields { gname = gname, body = rebuildIR(ctx, body, BindingEnv(bodyEnv), memo), partitioner.coarsen(requestedType.key.length), - errorId + errorId, ) case TableRead(typ, dropRows, tr) => @@ -1587,7 +1963,8 @@ object PruneDeadFields { val requestedTypeWithKey = TableType( key = typ.key, rowType = unify(typ.rowType, selectKey(typ.rowType, typ.key), requestedType.rowType), - globalType = requestedType.globalType) + globalType = requestedType.globalType, + ) TableRead(requestedTypeWithKey, dropRows, tr) case TableFilter(child, pred) => val child2 = rebuild(ctx, child, memo) @@ -1595,9 +1972,15 @@ object PruneDeadFields { TableFilter(child2, pred2) case TableMapPartitions(child, gName, pName, body, requestedKey, allowedOverlap) => val child2 = rebuild(ctx, child, memo) - val body2 = rebuildIR(ctx, body, BindingEnv(Env( - gName -> child2.typ.globalType, - pName -> TStream(child2.typ.rowType))), memo) + val body2 = rebuildIR( + ctx, + body, + BindingEnv(Env( + gName -> child2.typ.globalType, + pName -> TStream(child2.typ.rowType), + )), + memo, + ) val body2ElementType = TIterable.elementType(body2.typ).asInstanceOf[TStruct] val child2Keyed = if (child2.typ.key.exists(k => !body2ElementType.hasField(k))) TableKeyBy(child2, child2.typ.key.takeWhile(body2ElementType.hasField)) @@ -1605,10 +1988,22 @@ object PruneDeadFields { child2 val childKeyLen = child2Keyed.typ.key.length require(requestedKey <= childKeyLen) - TableMapPartitions(child2Keyed, gName, pName, body2, requestedKey, math.min(allowedOverlap, childKeyLen)) + TableMapPartitions( + child2Keyed, + gName, + pName, + body2, + requestedKey, + math.min(allowedOverlap, childKeyLen), + ) case TableMapRows(child, newRow) => val child2 = rebuild(ctx, child, memo) - val newRow2 = rebuildIR(ctx, newRow, BindingEnv(child2.typ.rowEnv, scan = Some(child2.typ.rowEnv)), memo) + val newRow2 = rebuildIR( + ctx, + newRow, + BindingEnv(child2.typ.rowEnv, scan = Some(child2.typ.rowEnv)), + memo, + ) val newRowType = 
newRow2.typ.asInstanceOf[TStruct] val child2Keyed = if (child2.typ.key.exists(k => !newRowType.hasField(k))) { val upcastKey = child2.typ.key.takeWhile(newRowType.hasField) @@ -1625,15 +2020,30 @@ object PruneDeadFields { val keys2 = requestedType.key // fully upcast before shuffle if (!isSorted && keys2.nonEmpty) - child2 = upcastTable(ctx, child2, memo.requestedType.lookup(child).asInstanceOf[TableType], upcastGlobals = false) + child2 = upcastTable( + ctx, + child2, + memo.requestedType.lookup(child).asInstanceOf[TableType], + upcastGlobals = false, + ) TableKeyBy(child2, keys2, isSorted) case TableOrderBy(child, sortFields) => - val child2 = if (sortFields.forall(_.sortOrder == Ascending) && child.typ.key.startsWith(sortFields.map(_.field))) - rebuild(ctx, child, memo) - else { - // fully upcast before shuffle - upcastTable(ctx, rebuild(ctx, child, memo), memo.requestedType.lookup(child).asInstanceOf[TableType], upcastGlobals = false) - } + val child2 = + if ( + sortFields.forall(_.sortOrder == Ascending) && child.typ.key.startsWith( + sortFields.map(_.field) + ) + ) + rebuild(ctx, child, memo) + else { + // fully upcast before shuffle + upcastTable( + ctx, + rebuild(ctx, child, memo), + memo.requestedType.lookup(child).asInstanceOf[TableType], + upcastGlobals = false, + ) + } TableOrderBy(child2, sortFields) case TableLeftJoinRightDistinct(left, right, root) => if (requestedType.rowType.hasField(root)) @@ -1646,15 +2056,30 @@ object PruneDeadFields { else rebuild(ctx, left, memo) case TableMultiWayZipJoin(children, fieldName, globalName) => - val rebuilt = children.map { c => rebuild(ctx, c, memo) } - val upcasted = rebuilt.map { t => upcastTable(ctx, t, memo.requestedType.lookup(children(0)).asInstanceOf[TableType]) } + val rebuilt = children.map(c => rebuild(ctx, c, memo)) + val upcasted = rebuilt.map { t => + upcastTable(ctx, t, memo.requestedType.lookup(children(0)).asInstanceOf[TableType]) + } TableMultiWayZipJoin(upcasted, fieldName, globalName) case TableAggregateByKey(child, expr) => val child2 = rebuild(ctx, child, memo) - TableAggregateByKey(child2, rebuildIR(ctx, expr, BindingEnv(child2.typ.globalEnv, agg = Some(child2.typ.rowEnv)), memo)) + TableAggregateByKey( + child2, + rebuildIR( + ctx, + expr, + BindingEnv(child2.typ.globalEnv, agg = Some(child2.typ.rowEnv)), + memo, + ), + ) case TableKeyByAndAggregate(child, expr, newKey, nPartitions, bufferSize) => val child2 = rebuild(ctx, child, memo) - val expr2 = rebuildIR(ctx, expr, BindingEnv(child2.typ.globalEnv, agg = Some(child2.typ.rowEnv)), memo) + val expr2 = rebuildIR( + ctx, + expr, + BindingEnv(child2.typ.globalEnv, agg = Some(child2.typ.rowEnv)), + memo, + ) val newKey2 = rebuildIR(ctx, newKey, BindingEnv(child2.typ.rowEnv), memo) TableKeyByAndAggregate(child2, expr2, newKey2, nPartitions, bufferSize) case TableRename(child, rowMap, globalMap) => @@ -1662,7 +2087,8 @@ object PruneDeadFields { TableRename( child2, rowMap.filterKeys(child2.typ.rowType.hasField), - globalMap.filterKeys(child2.typ.globalType.hasField)) + globalMap.filterKeys(child2.typ.globalType.hasField), + ) case TableUnion(children) => val requestedType = memo.requestedType.lookup(tir).asInstanceOf[TableType] val rebuilt = children.map { c => @@ -1678,22 +2104,22 @@ object PruneDeadFields { val aux2 = rebuildIR(ctx, aux, BindingEnv.empty, memo) BlockMatrixToTableApply(bmir2, aux2, function) case _ => tir.mapChildren { - // IR should be a match error - all nodes with child value IRs should have a rule - case childT: TableIR => rebuild(ctx, childT, 
memo) - case childM: MatrixIR => rebuild(ctx, childM, memo) - case childBm: BlockMatrixIR => rebuild(ctx, childBm, memo) - }.asInstanceOf[TableIR] + // IR should be a match error - all nodes with child value IRs should have a rule + case childT: TableIR => rebuild(ctx, childT, memo) + case childM: MatrixIR => rebuild(ctx, childM, memo) + case childBm: BlockMatrixIR => rebuild(ctx, childBm, memo) + }.asInstanceOf[TableIR] } } def rebuild( ctx: ExecuteContext, mir: MatrixIR, - memo: RebuildMutableState + memo: RebuildMutableState, ): MatrixIR = { val requestedType = memo.requestedType.lookup(mir).asInstanceOf[MatrixType] mir match { - case x@MatrixRead(typ, dropCols, dropRows, reader) => + case MatrixRead(typ, dropCols, dropRows, reader) => // FIXME: remove this when all readers know how to read without keys val requestedTypeWithKeys = MatrixType( rowKey = typ.rowKey, @@ -1701,7 +2127,7 @@ object PruneDeadFields { rowType = unify(typ.rowType, selectKey(typ.rowType, typ.rowKey), requestedType.rowType), entryType = requestedType.entryType, colType = unify(typ.colType, selectKey(typ.colType, typ.colKey), requestedType.colType), - globalType = requestedType.globalType + globalType = requestedType.globalType, ) MatrixRead(requestedTypeWithKeys, dropCols, dropRows, reader) case MatrixFilterCols(child, pred) => @@ -1718,8 +2144,16 @@ object PruneDeadFields { MatrixMapEntries(child2, rebuildIR(ctx, newEntries, BindingEnv(child2.typ.entryEnv), memo)) case MatrixMapRows(child, newRow) => val child2 = rebuild(ctx, child, memo) - val newRow2 = rebuildIR(ctx, newRow, - BindingEnv(child2.typ.rowEnv, agg = Some(child2.typ.entryEnv), scan = Some(child2.typ.rowEnv)), memo) + val newRow2 = rebuildIR( + ctx, + newRow, + BindingEnv( + child2.typ.rowEnv, + agg = Some(child2.typ.entryEnv), + scan = Some(child2.typ.rowEnv), + ), + memo, + ) val newRowType = newRow2.typ.asInstanceOf[TStruct] val child2Keyed = if (child2.typ.rowKey.exists(k => !newRowType.hasField(k))) MatrixKeyRowsBy(child2, child2.typ.rowKey.takeWhile(newRowType.hasField)) @@ -1728,15 +2162,23 @@ object PruneDeadFields { MatrixMapRows(child2Keyed, newRow2) case MatrixMapCols(child, newCol, newKey) => val child2 = rebuild(ctx, child, memo) - val newCol2 = rebuildIR(ctx, newCol, - BindingEnv(child2.typ.colEnv, agg = Some(child2.typ.entryEnv), scan = Some(child2.typ.colEnv)), memo) + val newCol2 = rebuildIR( + ctx, + newCol, + BindingEnv( + child2.typ.colEnv, + agg = Some(child2.typ.entryEnv), + scan = Some(child2.typ.colEnv), + ), + memo, + ) val newColType = newCol2.typ.asInstanceOf[TStruct] val newKey2 = newKey match { case Some(nk) => Some(nk.takeWhile(newColType.hasField)) case None => if (child2.typ.colKey.exists(k => !newColType.hasField(k))) - Some(child2.typ.colKey.takeWhile(newColType.hasField)) - else - None + Some(child2.typ.colKey.takeWhile(newColType.hasField)) + else + None } MatrixMapCols(child2, newCol2, newKey2) case MatrixMapGlobals(child, newGlobals) => @@ -1748,20 +2190,49 @@ object PruneDeadFields { MatrixKeyRowsBy(child2, keys2, isSorted) case MatrixAggregateRowsByKey(child, entryExpr, rowExpr) => val child2 = rebuild(ctx, child, memo) - MatrixAggregateRowsByKey(child2, - rebuildIR(ctx, entryExpr, BindingEnv(child2.typ.colEnv, agg = Some(child2.typ.entryEnv)), memo), - rebuildIR(ctx, rowExpr, BindingEnv(child2.typ.globalEnv, agg = Some(child2.typ.rowEnv)), memo)) + MatrixAggregateRowsByKey( + child2, + rebuildIR( + ctx, + entryExpr, + BindingEnv(child2.typ.colEnv, agg = Some(child2.typ.entryEnv)), + memo, + ), + rebuildIR( + 
ctx, + rowExpr, + BindingEnv(child2.typ.globalEnv, agg = Some(child2.typ.rowEnv)), + memo, + ), + ) case MatrixAggregateColsByKey(child, entryExpr, colExpr) => val child2 = rebuild(ctx, child, memo) - MatrixAggregateColsByKey(child2, - rebuildIR(ctx, entryExpr, BindingEnv(child2.typ.rowEnv, agg = Some(child2.typ.entryEnv)), memo), - rebuildIR(ctx, colExpr, BindingEnv(child2.typ.globalEnv, agg = Some(child2.typ.colEnv)), memo)) + MatrixAggregateColsByKey( + child2, + rebuildIR( + ctx, + entryExpr, + BindingEnv(child2.typ.rowEnv, agg = Some(child2.typ.entryEnv)), + memo, + ), + rebuildIR( + ctx, + colExpr, + BindingEnv(child2.typ.globalEnv, agg = Some(child2.typ.colEnv)), + memo, + ), + ) case MatrixUnionRows(children) => val requestedType = memo.requestedType.lookup(mir).asInstanceOf[MatrixType] - val firstChild = upcast(ctx, rebuild(ctx, children.head, memo), requestedType, upcastGlobals = false) + val firstChild = + upcast(ctx, rebuild(ctx, children.head, memo), requestedType, upcastGlobals = false) val remainingChildren = children.tail.map { child => - upcast(ctx, rebuild(ctx, child, memo), requestedType.copy(colType = requestedType.colKeyStruct), - upcastGlobals = false) + upcast( + ctx, + rebuild(ctx, child, memo), + requestedType.copy(colType = requestedType.colKeyStruct), + upcastGlobals = false, + ) } MatrixUnionRows(firstChild +: remainingChildren) case MatrixUnionCols(left, right, joinType) => @@ -1769,17 +2240,19 @@ object PruneDeadFields { val left2 = rebuild(ctx, left, memo) val right2 = rebuild(ctx, right, memo) - if (left2.typ.colType == right2.typ.colType && left2.typ.entryType == right2.typ.entryType) { + if ( + left2.typ.colType == right2.typ.colType && left2.typ.entryType == right2.typ.entryType + ) { MatrixUnionCols( left2, right2, - joinType + joinType, ) } else { MatrixUnionCols( - upcast(ctx, left2, requestedType, upcastRows=false, upcastGlobals = false), - upcast(ctx, right2, requestedType, upcastRows=false, upcastGlobals = false), - joinType + upcast(ctx, left2, requestedType, upcastRows = false, upcastGlobals = false), + upcast(ctx, right2, requestedType, upcastRows = false, upcastGlobals = false), + joinType, ) } case MatrixAnnotateRowsTable(child, table, root, product) => @@ -1807,25 +2280,31 @@ object PruneDeadFields { globalMap.filterKeys(child2.typ.globalType.hasField), colMap.filterKeys(child2.typ.colType.hasField), rowMap.filterKeys(child2.typ.rowType.hasField), - entryMap.filterKeys(child2.typ.entryType.hasField)) + entryMap.filterKeys(child2.typ.entryType.hasField), + ) case RelationalLetMatrixTable(name, value, body) => val value2 = rebuildIR(ctx, value, BindingEnv.empty, memo) memo.relationalRefs += name -> value2.typ RelationalLetMatrixTable(name, value2, rebuild(ctx, body, memo)) case CastTableToMatrix(child, entriesFieldName, colsFieldName, _) => - CastTableToMatrix(rebuild(ctx, child, memo), entriesFieldName, colsFieldName, requestedType.colKey) + CastTableToMatrix( + rebuild(ctx, child, memo), + entriesFieldName, + colsFieldName, + requestedType.colKey, + ) case _ => mir.mapChildren { - // IR should be a match error - all nodes with child value IRs should have a rule - case childT: TableIR => rebuild(ctx, childT, memo) - case childM: MatrixIR => rebuild(ctx, childM, memo) - }.asInstanceOf[MatrixIR] + // IR should be a match error - all nodes with child value IRs should have a rule + case childT: TableIR => rebuild(ctx, childT, memo) + case childM: MatrixIR => rebuild(ctx, childM, memo) + }.asInstanceOf[MatrixIR] } } def rebuild( ctx: ExecuteContext, 
bmir: BlockMatrixIR, - memo: RebuildMutableState + memo: RebuildMutableState, ): BlockMatrixIR = bmir match { case RelationalLetBlockMatrix(name, value, body) => val value2 = rebuildIR(ctx, value, BindingEnv.empty, memo) @@ -1844,7 +2323,7 @@ object PruneDeadFields { ctx: ExecuteContext, ir: IR, env: BindingEnv[Type], - memo: RebuildMutableState + memo: RebuildMutableState, ): IR = { val requestedType = memo.requestedType.lookup(ir).asInstanceOf[Type] ir match { @@ -1861,9 +2340,11 @@ object PruneDeadFields { }) case (TTuple(reb), TTuple(cast), TTuple(base)) => assert(base.length == cast.length) - val castFields = cast.map { f => f.index -> f.typ }.toMap - val baseFields = base.map { f => f.index -> f.typ }.toMap - TTuple(reb.map { f => TupleField(f.index, recur(f.typ, castFields(f.index), baseFields(f.index)))}) + val castFields = cast.map(f => f.index -> f.typ).toMap + val baseFields = base.map(f => f.index -> f.typ).toMap + TTuple(reb.map { f => + TupleField(f.index, recur(f.typ, castFields(f.index), baseFields(f.index))) + }) case (TArray(reb), TArray(cast), TArray(base)) => TArray(recur(reb, cast, base)) case (TSet(reb), TSet(cast), TSet(base)) => @@ -1885,10 +2366,7 @@ object PruneDeadFields { if (cnsq2.typ == alt2.typ) If(cond2, cnsq2, alt2) else - If(cond2, - upcast(ctx, cnsq2, requestedType), - upcast(ctx, alt2, requestedType) - ) + If(cond2, upcast(ctx, cnsq2, requestedType), upcast(ctx, alt2, requestedType)) case Coalesce(values) => val values2 = values.map(rebuildIR(ctx, _, env, memo)) require(values2.nonEmpty) @@ -1899,23 +2377,15 @@ object PruneDeadFields { case Consume(value) => val value2 = rebuildIR(ctx, value, env, memo) Consume(value2) - case Let(bindings, body) => - val newBindings = new Array[(String, IR)](bindings.length) - val (_, newEnv) = bindings.foldLeft((0, env)) { case ((idx, env), (name, value)) => - val newValue = rebuildIR(ctx, value, env, memo) - newBindings(idx) = (name -> newValue) - (idx + 1, env.bindEval(name -> newValue.typ)) + case Block(bindings, body) => + val newBindings = new Array[Binding](bindings.length) + val (_, newEnv) = bindings.foldLeft((0, env)) { + case ((idx, env), Binding(name, value, scope)) => + val newValue = rebuildIR(ctx, value, env.promoteScope(scope), memo) + newBindings(idx) = Binding(name, newValue, scope) + (idx + 1, env.bindInScope(name, newValue.typ, scope)) } - - Let(newBindings, rebuildIR(ctx, body, newEnv, memo)) - case AggLet(name, value, body, isScan) => - val value2 = rebuildIR(ctx, value, if (isScan) env.promoteScan else env.promoteAgg, memo) - AggLet( - name, - value2, - rebuildIR(ctx, body, if (isScan) env.bindScan(name, value2.typ) else env.bindAgg(name, value2.typ), memo), - isScan - ) + Block(newBindings, rebuildIR(ctx, body, newEnv, memo)) case Ref(name, t) => Ref(name, env.eval.lookupOption(name).getOrElse(t)) case RelationalLet(name, value, body) => @@ -1930,35 +2400,61 @@ object PruneDeadFields { case MakeStream(args, _, requiresMemoryManagementPerElement) => val dep = requestedType.asInstanceOf[TStream] val args2 = args.map(a => rebuildIR(ctx, a, env, memo)) - MakeStream.unify(ctx, args2, requiresMemoryManagementPerElement, requestedType = TStream(dep.elementType)) + MakeStream.unify( + ctx, + args2, + requiresMemoryManagementPerElement, + requestedType = TStream(dep.elementType), + ) case StreamMap(a, name, body) => val a2 = rebuildIR(ctx, a, env, memo) - StreamMap(a2, name, rebuildIR(ctx, body, env.bindEval(name, TIterable.elementType(a2.typ)), memo)) + StreamMap( + a2, + name, + rebuildIR(ctx, body, 
env.bindEval(name, TIterable.elementType(a2.typ)), memo), + ) case StreamZip(as, names, body, b, errorID) => val (newAs, newNames) = (as, names) .zipped - .flatMap { case (a, name) => if (memo.requestedType.contains(a)) Some((rebuildIR(ctx, a, env, memo), name)) else None } + .flatMap { case (a, name) => + if (memo.requestedType.contains(a)) Some((rebuildIR(ctx, a, env, memo), name)) else None + } .unzip - StreamZip(newAs, newNames, rebuildIR(ctx, body, - env.bindEval(newNames.zip(newAs.map(a => TIterable.elementType(a.typ))): _*), memo), b, errorID) + StreamZip( + newAs, + newNames, + rebuildIR( + ctx, + body, + env.bindEval(newNames.zip(newAs.map(a => TIterable.elementType(a.typ))): _*), + memo, + ), + b, + errorID, + ) case StreamZipJoin(as, key, curKey, curVals, joinF) => val newAs = as.map(a => rebuildIR(ctx, a, env, memo)) val newEltType = TIterable.elementType(as.head.typ).asInstanceOf[TStruct] - val newJoinF = rebuildIR(ctx, + val newJoinF = rebuildIR( + ctx, joinF, env.bindEval(curKey -> selectKey(newEltType, key), curVals -> TArray(newEltType)), - memo) + memo, + ) StreamZipJoin(newAs, key, curKey, curVals, newJoinF) case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, joinF) => val newContexts = rebuildIR(ctx, contexts, env, memo) val newCtxType = TIterable.elementType(newContexts.typ) val newMakeProducer = rebuildIR(ctx, makeProducer, env.bindEval(ctxName, newCtxType), memo) val newEltType = TIterable.elementType(newMakeProducer.typ).asInstanceOf[TStruct] - val newJoinF = rebuildIR(ctx, + val newJoinF = rebuildIR( + ctx, joinF, env.bindEval(curKey -> selectKey(newEltType, key), curVals -> TArray(newEltType)), - memo) - StreamZipJoinProducers(newContexts, ctxName, newMakeProducer,key, curKey, curVals, newJoinF) + memo, + ) + StreamZipJoinProducers(newContexts, ctxName, newMakeProducer, key, curKey, curVals, + newJoinF) case StreamMultiMerge(as, key) => val eltType = tcoerce[TStruct](tcoerce[TStream](as.head.typ).elementType) val requestedEltType = tcoerce[TStream](requestedType).elementType @@ -1973,16 +2469,32 @@ object PruneDeadFields { StreamMultiMerge(newAs2, key) case StreamFilter(a, name, cond) => val a2 = rebuildIR(ctx, a, env, memo) - StreamFilter(a2, name, rebuildIR(ctx, cond, env.bindEval(name, TIterable.elementType(a2.typ)), memo)) + StreamFilter( + a2, + name, + rebuildIR(ctx, cond, env.bindEval(name, TIterable.elementType(a2.typ)), memo), + ) case StreamTakeWhile(a, name, cond) => val a2 = rebuildIR(ctx, a, env, memo) - StreamTakeWhile(a2, name, rebuildIR(ctx, cond, env.bindEval(name, TIterable.elementType(a2.typ)), memo)) + StreamTakeWhile( + a2, + name, + rebuildIR(ctx, cond, env.bindEval(name, TIterable.elementType(a2.typ)), memo), + ) case StreamDropWhile(a, name, cond) => val a2 = rebuildIR(ctx, a, env, memo) - StreamDropWhile(a2, name, rebuildIR(ctx, cond, env.bindEval(name, TIterable.elementType(a2.typ)), memo)) + StreamDropWhile( + a2, + name, + rebuildIR(ctx, cond, env.bindEval(name, TIterable.elementType(a2.typ)), memo), + ) case StreamFlatMap(a, name, body) => val a2 = rebuildIR(ctx, a, env, memo) - StreamFlatMap(a2, name, rebuildIR(ctx, body, env.bindEval(name, TIterable.elementType(a2.typ)), memo)) + StreamFlatMap( + a2, + name, + rebuildIR(ctx, body, env.bindEval(name, TIterable.elementType(a2.typ)), memo), + ) case StreamFold(a, zero, accumName, valueName, body) => val a2 = rebuildIR(ctx, a, env, memo) val z2 = rebuildIR(ctx, zero, env, memo) @@ -1991,7 +2503,12 @@ object PruneDeadFields { z2, accumName, valueName, - 
rebuildIR(ctx, body, env.bindEval(accumName -> z2.typ, valueName -> TIterable.elementType(a2.typ)), memo) + rebuildIR( + ctx, + body, + env.bindEval(accumName -> z2.typ, valueName -> TIterable.elementType(a2.typ)), + memo, + ), ) case StreamFold2(a: IR, accum, valueName, seqs, result) => val a2 = rebuildIR(ctx, a, env, memo) @@ -2004,7 +2521,8 @@ object PruneDeadFields { newAccum, valueName, seqs.map(rebuildIR(ctx, _, newEnv, memo)), - rebuildIR(ctx, result, newEnv, memo)) + rebuildIR(ctx, result, newEnv, memo), + ) case StreamScan(a, zero, accumName, valueName, body) => val a2 = rebuildIR(ctx, a, env, memo) val z2 = rebuildIR(ctx, zero, env, memo) @@ -2013,7 +2531,12 @@ object PruneDeadFields { z2, accumName, valueName, - rebuildIR(ctx, body, env.bindEval(accumName -> z2.typ, valueName -> TIterable.elementType(a2.typ)), memo) + rebuildIR( + ctx, + body, + env.bindEval(accumName -> z2.typ, valueName -> TIterable.elementType(a2.typ)), + memo, + ), ) case StreamJoinRightDistinct(left, right, lKey, rKey, l, r, join, joinType) => val left2 = rebuildIR(ctx, left, env, memo) @@ -2022,12 +2545,29 @@ object PruneDeadFields { val ltyp = left2.typ.asInstanceOf[TStream] val rtyp = right2.typ.asInstanceOf[TStream] StreamJoinRightDistinct( - left2, right2, lKey, rKey, l, r, + left2, + right2, + lKey, + rKey, + l, + r, rebuildIR(ctx, join, env.bindEval(l -> ltyp.elementType, r -> rtyp.elementType), memo), - joinType) + joinType, + ) + + case StreamLeftIntervalJoin(left, right, lKFieldName, rIntrvlName, lName, rName, body) => + val newL = rebuildIR(ctx, left, env, memo) + val newR = rebuildIR(ctx, right, env, memo) + val newEnv = env.bindEval( + lName -> TIterable.elementType(newL.typ), + rName -> TArray(TIterable.elementType(newR.typ)), + ) + val newB = rebuildIR(ctx, body, newEnv, memo) + StreamLeftIntervalJoin(newL, newR, lKFieldName, rIntrvlName, lName, rName, newB) case StreamFor(a, valueName, body) => val a2 = rebuildIR(ctx, a, env, memo) - val body2 = rebuildIR(ctx, body, env.bindEval(valueName -> TIterable.elementType(a2.typ)), memo) + val body2 = + rebuildIR(ctx, body, env.bindEval(valueName -> TIterable.elementType(a2.typ)), memo) StreamFor(a2, valueName, body2) case ArraySort(a, left, right, lessThan) => val a2 = rebuildIR(ctx, a, env, memo) @@ -2041,13 +2581,28 @@ object PruneDeadFields { MakeNDArray(data2, shape2, rowMajor2, errorId) case NDArrayMap(nd, valueName, body) => val nd2 = rebuildIR(ctx, nd, env, memo) - NDArrayMap(nd2, valueName, rebuildIR(ctx, body, env.bindEval(valueName, nd2.typ.asInstanceOf[TNDArray].elementType), memo)) + NDArrayMap( + nd2, + valueName, + rebuildIR( + ctx, + body, + env.bindEval(valueName, nd2.typ.asInstanceOf[TNDArray].elementType), + memo, + ), + ) case NDArrayMap2(left, right, leftName, rightName, body, errorID) => val left2 = rebuildIR(ctx, left, env, memo) val right2 = rebuildIR(ctx, right, env, memo) - val body2 = rebuildIR(ctx, body, - env.bindEval(leftName, left2.typ.asInstanceOf[TNDArray].elementType).bindEval(rightName, right2.typ.asInstanceOf[TNDArray].elementType), - memo) + val body2 = rebuildIR( + ctx, + body, + env.bindEval(leftName, left2.typ.asInstanceOf[TNDArray].elementType).bindEval( + rightName, + right2.typ.asInstanceOf[TNDArray].elementType, + ), + memo, + ) NDArrayMap2(left2, right2, leftName, rightName, body2, errorID) case MakeStruct(fields) => val depStruct = requestedType.asInstanceOf[TStruct] @@ -2077,18 +2632,24 @@ object PruneDeadFields { val rebuiltChild = rebuildIR(ctx, old, env, memo) val preservedChildFields = 
rebuiltChild.typ.asInstanceOf[TStruct].fieldNames.toSet - val insertOverwritesUnrequestedButPreservedField = fields.exists{ case (fieldName, _) => + val insertOverwritesUnrequestedButPreservedField = fields.exists { case (fieldName, _) => preservedChildFields.contains(fieldName) && !depFields.contains(fieldName) } val wrappedChild = if (insertOverwritesUnrequestedButPreservedField) { val selectedChildFields = preservedChildFields.filter(s => depFields.contains(s)) - SelectFields(rebuiltChild, rebuiltChild.typ.asInstanceOf[TStruct].fieldNames.filter(selectedChildFields.contains(_))) + SelectFields( + rebuiltChild, + rebuiltChild.typ.asInstanceOf[TStruct].fieldNames.filter( + selectedChildFields.contains(_) + ), + ) } else { rebuiltChild } - InsertFields(wrappedChild, + InsertFields( + wrappedChild, fields.flatMap { case (f, fir) => if (depFields.contains(f)) Some(f -> rebuildIR(ctx, fir, env, memo)) @@ -2096,22 +2657,41 @@ object PruneDeadFields { log.info(s"Prune: InsertFields: eliminating field '$f'") None } - }, fieldOrder.map(fds => fds.filter(f => depFields.contains(f) || wrappedChild.typ.asInstanceOf[TStruct].hasField(f)))) + }, + fieldOrder.map(fds => + fds.filter(f => + depFields.contains(f) || wrappedChild.typ.asInstanceOf[TStruct].hasField(f) + ) + ), + ) case SelectFields(old, fields) => val depStruct = requestedType.asInstanceOf[TStruct] val old2 = rebuildIR(ctx, old, env, memo) - SelectFields(old2, fields.filter(f => old2.typ.asInstanceOf[TStruct].hasField(f) && depStruct.hasField(f))) + SelectFields( + old2, + fields.filter(f => old2.typ.asInstanceOf[TStruct].hasField(f) && depStruct.hasField(f)), + ) case ConsoleLog(message, result) => val message2 = rebuildIR(ctx, message, env, memo) val result2 = rebuildIR(ctx, result, env, memo) ConsoleLog(message2, result2) case TableAggregate(child, query) => val child2 = rebuild(ctx, child, memo) - val query2 = rebuildIR(ctx, query, BindingEnv(child2.typ.globalEnv, agg = Some(child2.typ.rowEnv)), memo) + val query2 = rebuildIR( + ctx, + query, + BindingEnv(child2.typ.globalEnv, agg = Some(child2.typ.rowEnv)), + memo, + ) TableAggregate(child2, query2) case MatrixAggregate(child, query) => val child2 = rebuild(ctx, child, memo) - val query2 = rebuildIR(ctx, query, BindingEnv(child2.typ.globalEnv, agg = Some(child2.typ.entryEnv)), memo) + val query2 = rebuildIR( + ctx, + query, + BindingEnv(child2.typ.globalEnv, agg = Some(child2.typ.entryEnv)), + memo, + ) MatrixAggregate(child2, query2) case TableCollect(child) => val rStruct = requestedType.asInstanceOf[TStruct] @@ -2122,14 +2702,25 @@ object PruneDeadFields { MakeStruct(FastSeq()) else { val rRowType = TIterable.elementType(rStruct.fieldType("rows")).asInstanceOf[TStruct] - val rGlobType = rStruct.selfField("global").map(_.typ.asInstanceOf[TStruct]).getOrElse(TStruct()) - TableCollect(upcastTable(ctx, rebuild(ctx, child, memo), TableType(rowType = rRowType, FastSeq(), rGlobType), - upcastRow = true, upcastGlobals = false)) + val rGlobType = + rStruct.selfField("global").map(_.typ.asInstanceOf[TStruct]).getOrElse(TStruct()) + TableCollect(upcastTable( + ctx, + rebuild(ctx, child, memo), + TableType(rowType = rRowType, FastSeq(), rGlobType), + upcastRow = true, + upcastGlobals = false, + )) } case AggExplode(array, name, aggBody, isScan) => val a2 = rebuildIR(ctx, array, if (isScan) env.promoteScan else env.promoteAgg, memo) val a2t = TIterable.elementType(a2.typ) - val body2 = rebuildIR(ctx, aggBody, if (isScan) env.bindScan(name, a2t) else env.bindAgg(name, a2t), memo) + val body2 = 
rebuildIR( + ctx, + aggBody, + if (isScan) env.bindScan(name, a2t) else env.bindAgg(name, a2t), + memo, + ) AggExplode(a2, name, body2, isScan) case AggFilter(cond, aggIR, isScan) => val cond2 = rebuildIR(ctx, cond, if (isScan) env.promoteScan else env.promoteAgg, memo) @@ -2144,8 +2735,20 @@ object PruneDeadFields { val a2 = rebuildIR(ctx, a, aEnv, memo) val a2t = TIterable.elementType(a2.typ) val env_ = env.bindEval(indexName -> TInt32) - val aggBody2 = rebuildIR(ctx, aggBody, if (isScan) env_.bindScan(elementName, a2t) else env_.bindAgg(elementName, a2t), memo) - AggArrayPerElement(a2, elementName, indexName, aggBody2, knownLength.map(rebuildIR(ctx, _, aEnv, memo)), isScan) + val aggBody2 = rebuildIR( + ctx, + aggBody, + if (isScan) env_.bindScan(elementName, a2t) else env_.bindAgg(elementName, a2t), + memo, + ) + AggArrayPerElement( + a2, + elementName, + indexName, + aggBody2, + knownLength.map(rebuildIR(ctx, _, aEnv, memo)), + isScan, + ) case StreamAgg(a, name, query) => val a2 = rebuildIR(ctx, a, env, memo) val newEnv = env.copy(agg = Some(env.eval.bind(name -> TIterable.elementType(a2.typ)))) @@ -2153,7 +2756,12 @@ object PruneDeadFields { StreamAgg(a2, name, query2) case StreamAggScan(a, name, query) => val a2 = rebuildIR(ctx, a, env, memo) - val query2 = rebuildIR(ctx, query, env.copy(scan = Some(env.eval.bind(name -> TIterable.elementType(a2.typ)))), memo) + val query2 = rebuildIR( + ctx, + query, + env.copy(scan = Some(env.eval.bind(name -> TIterable.elementType(a2.typ)))), + memo, + ) StreamAggScan(a2, name, query2) case RunAgg(body, result, signatures) => val body2 = rebuildIR(ctx, body, env, memo) @@ -2169,34 +2777,58 @@ object PruneDeadFields { case ApplyAggOp(initOpArgs, seqOpArgs, aggSig) => val initOpArgs2 = initOpArgs.map(rebuildIR(ctx, _, env, memo)) val seqOpArgs2 = seqOpArgs.map(rebuildIR(ctx, _, env.promoteAgg, memo)) - ApplyAggOp(initOpArgs2, seqOpArgs2, + ApplyAggOp( + initOpArgs2, + seqOpArgs2, aggSig.copy( initOpArgs = initOpArgs2.map(_.typ), - seqOpArgs = seqOpArgs2.map(_.typ))) + seqOpArgs = seqOpArgs2.map(_.typ), + ), + ) case ApplyScanOp(initOpArgs, seqOpArgs, aggSig) => val initOpArgs2 = initOpArgs.map(rebuildIR(ctx, _, env, memo)) val seqOpArgs2 = seqOpArgs.map(rebuildIR(ctx, _, env.promoteScan, memo)) - ApplyScanOp(initOpArgs2, seqOpArgs2, + ApplyScanOp( + initOpArgs2, + seqOpArgs2, aggSig.copy( initOpArgs = initOpArgs2.map(_.typ), - seqOpArgs = seqOpArgs2.map(_.typ))) + seqOpArgs = seqOpArgs2.map(_.typ), + ), + ) case AggFold(zero, seqOp, combOp, accumName, otherAccumName, isScan) => val zero2 = rebuildIR(ctx, zero, env, memo) val seqOp2 = rebuildIR(ctx, seqOp, if (isScan) env.promoteScan else env.promoteAgg, memo) val combOp2 = rebuildIR(ctx, combOp, env, memo) AggFold(zero2, seqOp2, combOp2, accumName, otherAccumName, isScan) - case CollectDistributedArray(contexts, globals, cname, gname, body, dynamicID, staticID, tsd) => - val contexts2 = upcast(ctx, rebuildIR(ctx, contexts, env, memo), memo.requestedType.lookup(contexts).asInstanceOf[Type]) - val globals2 = upcast(ctx, rebuildIR(ctx, globals, env, memo), memo.requestedType.lookup(globals).asInstanceOf[Type]) - val body2 = rebuildIR(ctx, body, BindingEnv(Env(cname -> TIterable.elementType(contexts2.typ), gname -> globals2.typ)), memo) + case CollectDistributedArray(contexts, globals, cname, gname, body, dynamicID, staticID, + tsd) => + val contexts2 = upcast( + ctx, + rebuildIR(ctx, contexts, env, memo), + memo.requestedType.lookup(contexts).asInstanceOf[Type], + ) + val globals2 = upcast( + ctx, + 
rebuildIR(ctx, globals, env, memo), + memo.requestedType.lookup(globals).asInstanceOf[Type], + ) + val body2 = rebuildIR( + ctx, + body, + BindingEnv(Env(cname -> TIterable.elementType(contexts2.typ), gname -> globals2.typ)), + memo, + ) val dynamicID2 = rebuildIR(ctx, dynamicID, env, memo) CollectDistributedArray(contexts2, globals2, cname, gname, body2, dynamicID2, staticID, tsd) case _ => ir.mapChildren { - case valueIR: IR => rebuildIR(ctx, valueIR, env, memo) // FIXME: assert IR does not bind or change env + case valueIR: IR => + rebuildIR(ctx, valueIR, env, memo) // FIXME: assert IR does not bind or change env case mir: MatrixIR => rebuild(ctx, mir, memo) case tir: TableIR => rebuild(ctx, tir, memo) - case bmir: BlockMatrixIR => bmir //NOTE Currently no BlockMatrixIRs would have dead fields + case bmir: BlockMatrixIR => + bmir // NOTE Currently no BlockMatrixIRs would have dead fields } } } @@ -2206,12 +2838,16 @@ object PruneDeadFields { ir else { val result = ir.typ match { - case _: TStruct => - bindIR(ir) { ref => - val ms = MakeStruct(rType.asInstanceOf[TStruct].fields.map { f => - f.name -> upcast(ctx, GetField(ref, f.name), f.typ) - }) - If(IsNA(ref), NA(ms.typ), ms) + case tstruct: TStruct => + if (rType.asInstanceOf[TStruct].fields.forall(f => tstruct.field(f.name).typ == f.typ)) { + SelectFields(ir, rType.asInstanceOf[TStruct].fields.map(f => f.name)) + } else { + bindIR(ir) { ref => + val ms = MakeStruct(rType.asInstanceOf[TStruct].fields.map { f => + f.name -> upcast(ctx, GetField(ref, f.name), f.typ) + }) + If(IsNA(ref), NA(ms.typ), ms) + } } case ts: TStream => val ra = rType.asInstanceOf[TStream] @@ -2226,7 +2862,11 @@ object PruneDeadFields { case _: TTuple => bindIR(ir) { ref => val mt = MakeTuple(rType.asInstanceOf[TTuple]._types.map { tupleField => - tupleField.index -> upcast(ctx, GetTupleElement(ref, tupleField.index), tupleField.typ) + tupleField.index -> upcast( + ctx, + GetTupleElement(ref, tupleField.index), + tupleField.typ, + ) }) If(IsNA(ref), NA(mt.typ), mt) } @@ -2239,7 +2879,7 @@ object PruneDeadFields { case _ => ir } - assert(result.typ == rType, s"${ Pretty(ctx, result) }, ${ result.typ }, $rType") + assert(result.typ == rType, s"${Pretty(ctx, result)}, ${result.typ}, $rType") result } } @@ -2251,7 +2891,7 @@ object PruneDeadFields { upcastRows: Boolean = true, upcastCols: Boolean = true, upcastGlobals: Boolean = true, - upcastEntries: Boolean = true + upcastEntries: Boolean = true, ): MatrixIR = { if (ir.typ == rType || !(upcastRows || upcastCols || upcastGlobals || upcastEntries)) ir @@ -2270,8 +2910,11 @@ object PruneDeadFields { mt = MatrixMapRows(mt, upcast(ctx, Ref("va", mt.typ.rowType), rType.rowType)) if (upcastCols && (mt.typ.colType != rType.colType || mt.typ.colKey != rType.colKey)) { - mt = MatrixMapCols(mt, upcast(ctx, Ref("sa", mt.typ.colType), rType.colType), - if (rType.colKey == mt.typ.colKey) None else Some(rType.colKey)) + mt = MatrixMapCols( + mt, + upcast(ctx, Ref("sa", mt.typ.colType), rType.colType), + if (rType.colKey == mt.typ.colKey) None else Some(rType.colKey), + ) } if (upcastGlobals && mt.typ.globalType != rType.globalType) @@ -2286,7 +2929,7 @@ object PruneDeadFields { ir: TableIR, rType: TableType, upcastRow: Boolean = true, - upcastGlobals: Boolean = true + upcastGlobals: Boolean = true, ): TableIR = { if (ir.typ == rType) ir @@ -2300,8 +2943,8 @@ object PruneDeadFields { table = TableMapRows(table, upcast(ctx, Ref("row", table.typ.rowType), rType.rowType)) } if (upcastGlobals && ir.typ.globalType != rType.globalType) 
{ - table = TableMapGlobals(table, - upcast(ctx, Ref("global", table.typ.globalType), rType.globalType)) + table = + TableMapGlobals(table, upcast(ctx, Ref("global", table.typ.globalType), rType.globalType)) } table } diff --git a/hail/src/main/scala/is/hail/expr/ir/Random.scala b/hail/src/main/scala/is/hail/expr/ir/Random.scala index 6a39c62340c..b3d835a9872 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Random.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Random.scala @@ -2,8 +2,9 @@ package is.hail.expr.ir import is.hail.asm4s._ import is.hail.utils.FastSeq -import net.sourceforge.jdistlib.rng.RandomEngine + import net.sourceforge.jdistlib.{Beta, Gamma, HyperGeometric, Poisson} +import net.sourceforge.jdistlib.rng.RandomEngine import org.apache.commons.math3.random.RandomGenerator object Threefry { @@ -11,22 +12,28 @@ object Threefry { val finalBlockNoPadTweak = -2L val finalBlockPaddedTweak = -3L - val keyConst = 0x1BD11BDAA9FC1A22L + val keyConst = 0x1bd11bdaa9fc1a22L val rotConsts = Array( Array(14, 16), Array(52, 57), Array(23, 40), - Array( 5, 37), + Array(5, 37), Array(25, 33), Array(46, 12), Array(58, 22), - Array(32, 32)) + Array(32, 32), + ) val defaultNumRounds = 20 val defaultKey: IndexedSeq[Long] = - expandKey(FastSeq(0x215d6dfdb7dfdf6bL, 0x045cfa043329c49fL, 0x9ec75a93692444ddL, 0x1284681663220f1cL)) + expandKey(FastSeq( + 0x215d6dfdb7dfdf6bL, + 0x045cfa043329c49fL, + 0x9ec75a93692444ddL, + 0x1284681663220f1cL, + )) def expandKey(k: IndexedSeq[Long]): IndexedSeq[Long] = { assert(k.length == 4) @@ -34,9 +41,8 @@ object Threefry { k :+ k4 } - def rotL(i: Value[Long], n: Value[Int]): Code[Long] = { + def rotL(i: Value[Long], n: Value[Int]): Code[Long] = (i << n) | (i >>> -n) - } def mix(cb: CodeBuilderLike, x0: Settable[Long], x1: Settable[Long], n: Int): Unit = { cb.assign(x0, x0 + x1) @@ -44,7 +50,8 @@ object Threefry { cb.assign(x1, x0 ^ x1) } - def injectKey(key: IndexedSeq[Long], tweak: IndexedSeq[Long], block: Array[Long], s: Int): Unit = { + def injectKey(key: IndexedSeq[Long], tweak: IndexedSeq[Long], block: Array[Long], s: Int) + : Unit = { assert(tweak.length == 3) assert(key.length == 5) assert(block.length == 4) @@ -54,11 +61,12 @@ object Threefry { block(3) += key((s + 3) % 5) + s.toLong } - def injectKey(cb: CodeBuilderLike, + def injectKey( + cb: CodeBuilderLike, key: IndexedSeq[Long], tweak: IndexedSeq[Value[Long]], block: IndexedSeq[Settable[Long]], - s: Int + s: Int, ): Unit = { cb.assign(block(0), block(0) + key(s % 5)) cb.assign(block(1), block(1) + const(key((s + 1) % 5)) + tweak(s % 3)) @@ -72,7 +80,16 @@ object Threefry { x(3) = tmp } - def encryptUnrolled(k0: Long, k1: Long, k2: Long, k3: Long, k4: Long, t0: Long, t1: Long, x: Array[Long]): Unit = { + def encryptUnrolled( + k0: Long, + k1: Long, + k2: Long, + k3: Long, + k4: Long, + t0: Long, + t1: Long, + x: Array[Long], + ): Unit = { import java.lang.Long.rotateLeft var x0 = x(0); var x1 = x(1); var x2 = x(2); var x3 = x(3) val t2 = t0 ^ t1 @@ -183,18 +200,20 @@ object Threefry { injectKey(k, t, x, rounds / 4) } - def encrypt(cb: CodeBuilderLike, + def encrypt( + cb: CodeBuilderLike, k: IndexedSeq[Long], t: IndexedSeq[Value[Long]], - x: IndexedSeq[Settable[Long]] + x: IndexedSeq[Settable[Long]], ): Unit = encrypt(cb, k, t, x, defaultNumRounds) - def encrypt(cb: CodeBuilderLike, + def encrypt( + cb: CodeBuilderLike, k: IndexedSeq[Long], _t: IndexedSeq[Value[Long]], _x: IndexedSeq[Settable[Long]], - rounds: Int + rounds: Int, ): Unit = { assert(k.length == 5) assert(_t.length == 2) @@ -207,7 +226,7 
@@ object Threefry { injectKey(cb, k, t, x, d / 4) for (j <- 0 until 2) - mix(cb, x(2*j), x(2*j+1), rotConsts(d % 8)(j)) + mix(cb, x(2 * j), x(2 * j + 1), rotConsts(d % 8)(j)) permute(x) } @@ -216,8 +235,17 @@ object Threefry { injectKey(cb, k, t, x, rounds / 4) } - def debugPrint(cb: EmitCodeBuilder, x: IndexedSeq[Settable[Long]], info: String) { - cb.println(s"[$info]=\n\t", x(0).toString, " ", x(1).toString, " ", x(2).toString, " ", x(3).toString) + def debugPrint(cb: EmitCodeBuilder, x: IndexedSeq[Settable[Long]], info: String): Unit = { + cb.println( + s"[$info]=\n\t", + x(0).toString, + " ", + x(1).toString, + " ", + x(2).toString, + " ", + x(3).toString, + ) } def pmac(sum: Array[Long], message: IndexedSeq[Long]): Array[Long] = { @@ -261,9 +289,8 @@ object Threefry { sum(3) ^= x(3) i += 4 } - for (j <- 0 until 4) { + for (j <- 0 until 4) sum(j) ^= message(i + j) - } val finalTweak = if (padded) Threefry.finalBlockPaddedTweak else Threefry.finalBlockNoPadTweak (sum, finalTweak) @@ -332,8 +359,18 @@ object ThreefryRandomEngine { def apply(): ThreefryRandomEngine = { val key = Threefry.defaultKey new ThreefryRandomEngine( - key(0), key(1), key(2), key(3), key(4), - 0, 0, 0, 0, 0, 0) + key(0), + key(1), + key(2), + key(3), + key(4), + 0, + 0, + 0, + 0, + 0, + 0, + ) } def apply(nonce: Long, staticID: Long, message: IndexedSeq[Long]): ThreefryRandomEngine = { @@ -347,9 +384,18 @@ object ThreefryRandomEngine { val rand = new java.util.Random() val key = Threefry.expandKey(Array.fill(4)(rand.nextLong())) new ThreefryRandomEngine( - key(0), key(1), key(2), key(3), key(4), - rand.nextLong(), rand.nextLong(), rand.nextLong(), rand.nextLong(), - 0, 0) + key(0), + key(1), + key(2), + key(3), + key(4), + rand.nextLong(), + rand.nextLong(), + rand.nextLong(), + rand.nextLong(), + 0, + 0, + ) } } @@ -364,7 +410,7 @@ class ThreefryRandomEngine( var state2: Long, var state3: Long, var counter: Long, - var tweak: Long + var tweak: Long, ) extends RandomEngine with RandomGenerator { val buffer: Array[Long] = Array.ofDim[Long](4) var usedInts: Int = 8 @@ -440,14 +486,15 @@ class ThreefryRandomEngine( // Uses approach from https://github.com/apple/swift/pull/39143 override def nextInt(n: Int): Int = { val nL = n.toLong - val mult = nL * (nextInt().toLong & 0xFFFFFFFFL) + val mult = nL * (nextInt().toLong & 0xffffffffL) val result = (mult >>> 32).toInt - val fraction = mult & 0xFFFFFFFFL + val fraction = mult & 0xffffffffL // optional early return, benchmark to decide if it helps if (fraction < ((1L << 32) - nL)) return result - val multHigh = (((nL * (nextInt().toLong & 0xFFFFFFFFL)) >>> 32) + (nL * (nextInt().toLong & 0xFFFFFFFFL))) >>> 32 + val multHigh = + (((nL * (nextInt().toLong & 0xffffffffL)) >>> 32) + (nL * (nextInt().toLong & 0xffffffffL))) >>> 32 val sum = fraction + multHigh val carry = (sum >>> 32).toInt result + carry @@ -540,4 +587,4 @@ class ThreefryRandomEngine( val result = (exponent << 23) | significand java.lang.Float.intBitsToFloat(result) } -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/expr/ir/RefEquality.scala b/hail/src/main/scala/is/hail/expr/ir/RefEquality.scala index 4a0e0acf5fe..0551e0b2191 100644 --- a/hail/src/main/scala/is/hail/expr/ir/RefEquality.scala +++ b/hail/src/main/scala/is/hail/expr/ir/RefEquality.scala @@ -23,12 +23,12 @@ object Memo { def empty[T]: Memo[T] = new Memo[T](new mutable.HashMap[RefEquality[BaseIR], T]) } -class Memo[T] private(val m: mutable.HashMap[RefEquality[BaseIR], T]) { +class Memo[T] private (val m: 
mutable.HashMap[RefEquality[BaseIR], T]) { def bind(ir: BaseIR, t: T): Memo[T] = bind(RefEquality(ir), t) def bind(ir: RefEquality[BaseIR], t: T): Memo[T] = { if (m.contains(ir)) - throw new RuntimeException(s"IR already in memo: ${ ir.t }") + throw new RuntimeException(s"IR already in memo: ${ir.t}") m += ir -> t this } @@ -58,10 +58,9 @@ class Memo[T] private(val m: mutable.HashMap[RefEquality[BaseIR], T]) { def delete(ir: BaseIR): Unit = delete(RefEquality(ir)) def delete(ir: RefEquality[BaseIR]): Unit = m -= ir - override def toString: String = s"Memo(${m})" + override def toString: String = s"Memo($m)" } - object HasIRSharing { def apply(ctx: ExecuteContext)(ir: BaseIR): Boolean = { val mark = ctx.irMetadata.nextFlag diff --git a/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala b/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala index 16ce778504a..a5674fe8b98 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Requiredness.scala @@ -12,7 +12,8 @@ import is.hail.utils._ import scala.collection.mutable object Requiredness { - def apply(node: BaseIR, usesAndDefs: UsesAndDefs, ctx: ExecuteContext, env: Env[PType]): RequirednessAnalysis = { + def apply(node: BaseIR, usesAndDefs: UsesAndDefs, ctx: ExecuteContext, env: Env[PType]) + : RequirednessAnalysis = { val pass = new Requiredness(usesAndDefs, ctx) pass.initialize(node, env) pass.run() @@ -23,7 +24,10 @@ object Requiredness { apply(node, ComputeUsesAndDefs(node), ctx, Env.empty) } -case class RequirednessAnalysis(r: Memo[BaseTypeWithRequiredness], states: Memo[IndexedSeq[TypeWithRequiredness]]) { +case class RequirednessAnalysis( + r: Memo[BaseTypeWithRequiredness], + states: Memo[IndexedSeq[TypeWithRequiredness]], +) { def lookup(node: BaseIR): BaseTypeWithRequiredness = r.lookup(node) def lookupState(node: BaseIR): IndexedSeq[BaseTypeWithRequiredness] = states.lookup(node) def lookupOpt(node: BaseIR): Option[BaseTypeWithRequiredness] = r.get(node) @@ -47,7 +51,8 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { def lookup(node: TableIR): RTable = tcoerce[RTable](cache(node)) def lookup(node: BlockMatrixIR): RBlockMatrix = tcoerce[RBlockMatrix](cache(node)) - def supportedType(node: BaseIR): Boolean = node.isInstanceOf[TableIR] || node.isInstanceOf[IR] || node.isInstanceOf[BlockMatrixIR] + def supportedType(node: BaseIR): Boolean = + node.isInstanceOf[TableIR] || node.isInstanceOf[IR] || node.isInstanceOf[BlockMatrixIR] private def initializeState(node: BaseIR): Unit = if (!cache.contains(node)) { assert(supportedType(node)) @@ -64,7 +69,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { case _ => } node.children.foreach { - case c: MatrixIR => fatal("Requiredness analysis only works on lowered MatrixTables. ") + case _: MatrixIR => fatal("Requiredness analysis only works on lowered MatrixTables. 
") case c if supportedType(node) => initializeState(c) if (node.typ != TVoid) @@ -72,20 +77,20 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { } if (node.typ != TVoid) { cache.bind(node, BaseTypeWithRequiredness(node.typ)) - if (usesAndDefs.free.isEmpty || !re.t.isInstanceOf[BaseRef] || !usesAndDefs.free.contains(re.asInstanceOf[RefEquality[BaseRef]])) + if ( + usesAndDefs.free.isEmpty || !re.t.isInstanceOf[BaseRef] || !usesAndDefs.free.contains( + re.asInstanceOf[RefEquality[BaseRef]] + ) + ) q += re } } def initialize(node: BaseIR, env: Env[PType]): Unit = { initializeState(node) - usesAndDefs.uses.m.keys.foreach { n => - if (supportedType(n.t)) addBindingRelations(n.t) - } + usesAndDefs.uses.m.keys.foreach(n => if (supportedType(n.t)) addBindingRelations(n.t)) - usesAndDefs.free.foreach { re => - lookup(re.t).fromPType(env.lookup(re.t.name)) - } + usesAndDefs.free.foreach(re => lookup(re.t).fromPType(env.lookup(re.t.name))) } def run(): Unit = { @@ -100,7 +105,12 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { def addBindingRelations(node: BaseIR): Unit = { val refMap: Map[String, IndexedSeq[RefEquality[BaseRef]]] = usesAndDefs.uses(node).toFastSeq.groupBy(_.t.name) - def addElementBinding(name: String, d: IR, makeOptional: Boolean = false, makeRequired: Boolean = false): Unit = { + def addElementBinding( + name: String, + d: IR, + makeOptional: Boolean = false, + makeRequired: Boolean = false, + ): Unit = { assert(!(makeOptional && makeRequired)) if (refMap.contains(name)) { val uses = refMap(name) @@ -114,12 +124,13 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { req.union(true) req } else eltReq - uses.foreach { u => defs.bind(u, Array(req)) } + uses.foreach(u => defs.bind(u, Array(req))) dependents.getOrElseUpdate(d, mutable.Set[RefEquality[BaseIR]]()) ++= uses } } - def addBlockMatrixElementBinding(name: String, d: BlockMatrixIR, makeOptional: Boolean = false): Unit = { + def addBlockMatrixElementBinding(name: String, d: BlockMatrixIR, makeOptional: Boolean = false) + : Unit = { if (refMap.contains(name)) { val uses = refMap(name) val eltReq = tcoerce[RBlockMatrix](lookup(d)).elementType @@ -128,35 +139,39 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { optional.union(false) optional } else eltReq - uses.foreach { u => defs.bind(u, Array(req)) } + uses.foreach(u => defs.bind(u, Array(req))) dependents.getOrElseUpdate(d, mutable.Set[RefEquality[BaseIR]]()) ++= uses } } - def addBindings(name: String, ds: Array[IR]): Unit = { + def addBindings(name: String, ds: Array[IR]): Unit = if (refMap.contains(name)) { val uses = refMap(name) - uses.foreach { u => defs.bind(u, ds.map(lookup).toArray[BaseTypeWithRequiredness]) } - ds.foreach { d => dependents.getOrElseUpdate(d, mutable.Set[RefEquality[BaseIR]]()) ++= uses } + uses.foreach(u => defs.bind(u, ds.map(lookup).toArray[BaseTypeWithRequiredness])) + ds.foreach { d => + dependents.getOrElseUpdate(d, mutable.Set[RefEquality[BaseIR]]()) ++= uses + } } - } def addBinding(name: String, ds: IR): Unit = addBindings(name, Array(ds)) def addTableBinding(table: TableIR): Unit = { if (refMap.contains("row")) - refMap("row").foreach { u => defs.bind(u, Array[BaseTypeWithRequiredness](lookup(table).rowType)) } + refMap("row").foreach { u => + defs.bind(u, Array[BaseTypeWithRequiredness](lookup(table).rowType)) + } if (refMap.contains("global")) - refMap("global").foreach { u => defs.bind(u, 
Array[BaseTypeWithRequiredness](lookup(table).globalType)) } + refMap("global").foreach { u => + defs.bind(u, Array[BaseTypeWithRequiredness](lookup(table).globalType)) + } val refs = refMap.getOrElse("row", FastSeq()) ++ refMap.getOrElse("global", FastSeq()) dependents.getOrElseUpdate(table, mutable.Set[RefEquality[BaseIR]]()) ++= refs } node match { - case AggLet(name, value, body, isScan) => addBinding(name, value) - case Let(bindings, _) => bindings.foreach(Function.tupled(addBinding)) - case RelationalLet(name, value, body) => addBinding(name, value) - case RelationalLetTable(name, value, body) => addBinding(name, value) + case Block(bindings, _) => bindings.foreach(b => addBinding(b.name, b.value)) + case RelationalLet(name, value, _) => addBinding(name, value) + case RelationalLetTable(name, value, _) => addBinding(name, value) case TailLoop(loopName, params, _, body) => addBinding(loopName, body) val argDefs = Array.fill(params.length)(new BoxedArrayBuilder[IR]()) @@ -167,40 +182,45 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { var i = 0 while (i < params.length) { val (name, init) = params(i) - s(i) = lookup(refMap.get(name).flatMap(refs => refs.headOption.map(_.t.asInstanceOf[IR])).getOrElse(init)) + s(i) = lookup(refMap.get(name).flatMap(refs => + refs.headOption.map(_.t.asInstanceOf[IR]) + ).getOrElse(init)) addBindings(name, argDefs(i).result() :+ init) i += 1 } states.bind(node, s) - case x@ApplyIR(_, _, args, _, _) => + case x @ ApplyIR(_, _, args, _, _) => x.refIdx.foreach { case (n, i) => addBinding(n, args(i)) } - case ArraySort(a, l, r, c) => + case ArraySort(a, l, r, _) => addElementBinding(l, a, makeRequired = true) addElementBinding(r, a, makeRequired = true) case ArrayMaximalIndependentSet(a, tiebreaker) => tiebreaker.foreach { case (left, right, _) => - val eltReq = tcoerce[TypeWithRequiredness](tcoerce[RIterable](lookup(a)).elementType.children.head) + val eltReq = + tcoerce[TypeWithRequiredness](tcoerce[RIterable](lookup(a)).elementType.children.head) val req = RTuple.fromNamesAndTypes(FastSeq("0" -> eltReq)) req.union(true) - refMap(left).foreach { u => defs.bind(u, Array(req)) } - refMap(right).foreach { u => defs.bind(u, Array(req)) } + refMap(left).foreach(u => defs.bind(u, Array(req))) + refMap(right).foreach(u => defs.bind(u, Array(req))) } - case StreamMap(a, name, body) => + case StreamMap(a, name, _) => addElementBinding(name, a) - case x@StreamZip(as, names, body, behavior, _) => + case StreamZip(as, names, _, behavior, _) => var i = 0 while (i < names.length) { - addElementBinding(names(i), as(i), - makeOptional = behavior == ArrayZipBehavior.ExtendNA) + addElementBinding(names(i), as(i), makeOptional = behavior == ArrayZipBehavior.ExtendNA) i += 1 } case StreamZipJoin(as, key, curKey, curVals, _) => val aEltTypes = as.map(a => tcoerce[RStruct](tcoerce[RIterable](lookup(a)).elementType)) if (refMap.contains(curKey)) { val uses = refMap(curKey) - val keyTypes = aEltTypes.map(t => RStruct.fromNamesAndTypes(key.map(k => k -> t.fieldType(k)))) - uses.foreach { u => defs.bind(u, keyTypes) } - as.foreach { a => dependents.getOrElseUpdate(a, mutable.Set[RefEquality[BaseIR]]()) ++= uses } + val keyTypes = + aEltTypes.map(t => RStruct.fromNamesAndTypes(key.map(k => k -> t.fieldType(k)))) + uses.foreach(u => defs.bind(u, keyTypes)) + as.foreach { a => + dependents.getOrElseUpdate(a, mutable.Set[RefEquality[BaseIR]]()) ++= uses + } } if (refMap.contains(curVals)) { val uses = refMap(curVals) @@ -209,114 +229,142 @@ class 
Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { optional.union(false) RIterable(optional) } - uses.foreach { u => defs.bind(u, valTypes) } - as.foreach { a => dependents.getOrElseUpdate(a, mutable.Set[RefEquality[BaseIR]]()) ++= uses } + uses.foreach(u => defs.bind(u, valTypes)) + as.foreach { a => + dependents.getOrElseUpdate(a, mutable.Set[RefEquality[BaseIR]]()) ++= uses + } } case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, _) => val ctxType = tcoerce[RIterable](lookup(contexts)).elementType if (refMap.contains(ctxName)) { val uses = refMap(ctxName) - uses.foreach { u => defs.bind(u, Array(ctxType)) } + uses.foreach(u => defs.bind(u, Array(ctxType))) dependents.getOrElseUpdate(contexts, mutable.Set[RefEquality[BaseIR]]()) ++= uses } - val producerElementType = tcoerce[RStruct](tcoerce[RIterable](lookup(makeProducer)).elementType) + val producerElementType = + tcoerce[RStruct](tcoerce[RIterable](lookup(makeProducer)).elementType) if (refMap.contains(curKey)) { val uses = refMap(curKey) - val keyType = RStruct.fromNamesAndTypes(key.map(k => k -> producerElementType.fieldType(k))) - uses.foreach { u => defs.bind(u, Array(keyType)) } + val keyType = + RStruct.fromNamesAndTypes(key.map(k => k -> producerElementType.fieldType(k))) + uses.foreach(u => defs.bind(u, Array(keyType))) dependents.getOrElseUpdate(makeProducer, mutable.Set[RefEquality[BaseIR]]()) ++= uses } if (refMap.contains(curVals)) { val uses = refMap(curVals) val optional = producerElementType.copy(producerElementType.children) optional.union(false) - uses.foreach { u => defs.bind(u, Array(RIterable(optional))) } + uses.foreach(u => defs.bind(u, Array(RIterable(optional)))) dependents.getOrElseUpdate(makeProducer, mutable.Set[RefEquality[BaseIR]]()) ++= uses } - case StreamFilter(a, name, cond) => addElementBinding(name, a) - case StreamTakeWhile(a, name, cond) => addElementBinding(name, a) - case StreamDropWhile(a, name, cond) => addElementBinding(name, a) - case StreamFlatMap(a, name, body) => addElementBinding(name, a) + case StreamFilter(a, name, _) => addElementBinding(name, a) + case StreamTakeWhile(a, name, _) => addElementBinding(name, a) + case StreamDropWhile(a, name, _) => addElementBinding(name, a) + case StreamFlatMap(a, name, _) => addElementBinding(name, a) case StreamFor(a, name, _) => addElementBinding(name, a) case StreamFold(a, zero, accumName, valueName, body) => addElementBinding(valueName, a) addBindings(accumName, Array[IR](zero, body)) - states.bind(node, Array[TypeWithRequiredness](lookup( - refMap.get(accumName) - .flatMap(refs => refs.headOption.map(_.t.asInstanceOf[IR])) - .getOrElse(zero)))) + states.bind( + node, + Array[TypeWithRequiredness](lookup( + refMap.get(accumName) + .flatMap(refs => refs.headOption.map(_.t.asInstanceOf[IR])) + .getOrElse(zero) + )), + ) case StreamScan(a, zero, accumName, valueName, body) => addElementBinding(valueName, a) addBindings(accumName, Array[IR](zero, body)) - states.bind(node, Array[TypeWithRequiredness](lookup( - refMap.get(accumName) - .flatMap(refs => refs.headOption.map(_.t.asInstanceOf[IR])) - .getOrElse(zero)))) - case StreamFold2(a, accums, valueName, seq, result) => + states.bind( + node, + Array[TypeWithRequiredness](lookup( + refMap.get(accumName) + .flatMap(refs => refs.headOption.map(_.t.asInstanceOf[IR])) + .getOrElse(zero) + )), + ) + case StreamFold2(a, accums, valueName, seq, _) => addElementBinding(valueName, a) val s = Array.fill[TypeWithRequiredness](accums.length)(null) var i = 0 while (i 
< accums.length) { val (n, z) = accums(i) addBindings(n, Array[IR](z, seq(i))) - s(i) = lookup(refMap.get(n).flatMap(refs => refs.headOption.map(_.t.asInstanceOf[IR])).getOrElse(z)) + s(i) = lookup(refMap.get(n).flatMap(refs => + refs.headOption.map(_.t.asInstanceOf[IR]) + ).getOrElse(z)) i += 1 } states.bind(node, s) - case StreamJoinRightDistinct(left, right, lKey, rKey, l, r, joinf, joinType) => + case StreamJoinRightDistinct(left, right, _, _, l, r, _, joinType) => addElementBinding(l, left, makeOptional = (joinType == "outer" || joinType == "right")) addElementBinding(r, right, makeOptional = (joinType == "outer" || joinType == "left")) - case StreamAgg(a, name, query) => + case StreamLeftIntervalJoin(left, right, _, _, lname, rname, _) => + addElementBinding(lname, left) + val uses = refMap(rname) + val rtypes = Array(lookup(right)) + uses.foreach(u => defs.bind(u, rtypes)) + dependents.getOrElseUpdate(right, mutable.Set[RefEquality[BaseIR]]()) ++= uses + case StreamAgg(a, name, _) => addElementBinding(name, a) - case StreamAggScan(a, name, query) => + case StreamAggScan(a, name, _) => addElementBinding(name, a) case StreamBufferedAggregate(stream, _, _, _, name, _, _) => addElementBinding(name, stream) - case RunAggScan(a, name, init, seqs, result, signature) => + case RunAggScan(a, name, _, _, _, _) => addElementBinding(name, a) case AggFold(zero, seqOp, combOp, accumName, otherAccumName, _) => addBindings(accumName, Array(zero, seqOp, combOp)) addBindings(otherAccumName, Array(zero, seqOp, combOp)) - case AggExplode(a, name, aggBody, isScan) => + case AggExplode(a, name, _, _) => addElementBinding(name, a) - case AggArrayPerElement(a, elt, idx, body, knownLength, isScan) => + case AggArrayPerElement(a, elt, idx, _, _, _) => addElementBinding(elt, a) - //idx is always required Int32 + // idx is always required Int32 if (refMap.contains(idx)) - refMap(idx).foreach { use => defs.bind(use, Array[BaseTypeWithRequiredness](RPrimitive())) } - case NDArrayMap(nd, name, body) => + refMap(idx).foreach { use => + defs.bind(use, Array[BaseTypeWithRequiredness](RPrimitive())) + } + case NDArrayMap(nd, name, _) => addElementBinding(name, nd) - case NDArrayMap2(left, right, l, r, body, _) => + case NDArrayMap2(left, right, l, r, _, _) => addElementBinding(l, left) addElementBinding(r, right) - case CollectDistributedArray(ctxs, globs, c, g, body, _, _, _) => + case CollectDistributedArray(ctxs, globs, c, g, _, _, _, _) => addElementBinding(c, ctxs) addBinding(g, globs) - case BlockMatrixMap(child, eltName, f, _) => addBlockMatrixElementBinding(eltName, child) - case BlockMatrixMap2(leftChild, rightChild, leftName, rightName, _, _) => { + case BlockMatrixMap(child, eltName, _, _) => addBlockMatrixElementBinding(eltName, child) + case BlockMatrixMap2(leftChild, rightChild, leftName, rightName, _, _) => addBlockMatrixElementBinding(leftName, leftChild) addBlockMatrixElementBinding(rightName, rightChild) - } - case TableAggregate(c, q) => + case TableAggregate(c, _) => addTableBinding(c) - case TableFilter(child, pred) => + case TableFilter(child, _) => addTableBinding(child) - case TableMapRows(child, newRow) => + case TableMapRows(child, _) => addTableBinding(child) - case TableMapGlobals(child, newGlobals) => + case TableMapGlobals(child, _) => addTableBinding(child) - case TableKeyByAndAggregate(child, expr, newKey, nPartitions, bufferSize) => + case TableKeyByAndAggregate(child, _, _, _, _) => addTableBinding(child) - case TableAggregateByKey(child, expr) => + case TableAggregateByKey(child, _) 
=> addTableBinding(child) - case TableMapPartitions(child, globalName, partitionStreamName, body, _, _) => + case TableMapPartitions(child, globalName, partitionStreamName, _, _, _) => if (refMap.contains(globalName)) - refMap(globalName).foreach { u => defs.bind(u, Array[BaseTypeWithRequiredness](lookup(child).globalType)) } + refMap(globalName).foreach { u => + defs.bind(u, Array[BaseTypeWithRequiredness](lookup(child).globalType)) + } if (refMap.contains(partitionStreamName)) - refMap(partitionStreamName).foreach { u => defs.bind(u, Array[BaseTypeWithRequiredness](RIterable(lookup(child).rowType))) } - val refs = refMap.getOrElse(globalName, FastSeq()) ++ refMap.getOrElse(partitionStreamName, FastSeq()) + refMap(partitionStreamName).foreach { u => + defs.bind(u, Array[BaseTypeWithRequiredness](RIterable(lookup(child).rowType))) + } + val refs = refMap.getOrElse(globalName, FastSeq()) ++ refMap.getOrElse( + partitionStreamName, + FastSeq(), + ) dependents.getOrElseUpdate(child, mutable.Set[RefEquality[BaseIR]]()) ++= refs case TableGen(contexts, globals, cname, gname, _, _, _) => addElementBinding(cname, contexts) @@ -336,11 +384,11 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { def analyzeTable(node: TableIR): Boolean = { val requiredness = lookup(node) node match { - //statically known - case TableLiteral(typ, rvd, enc, encodedGlobals) => + // statically known + case TableLiteral(typ, rvd, enc, _) => requiredness.rowType.fromPType(rvd.rowPType) requiredness.globalType.fromPType(enc.encodedType.decodedPType(typ.globalType)) - case TableRead(typ, dropRows, tr) => + case TableRead(typ, _, tr) => val rowReq = tr.rowRequiredness(ctx, typ) val globalReq = tr.globalRequiredness(ctx, typ) requiredness.rowType.unionFields(rowReq.r.asInstanceOf[RStruct]) @@ -352,17 +400,18 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { case TableFilter(child, _) => requiredness.unionFrom(lookup(child)) case TableHead(child, _) => requiredness.unionFrom(lookup(child)) case TableTail(child, _) => requiredness.unionFrom(lookup(child)) - case TableRepartition(child, n, strategy) => requiredness.unionFrom(lookup(child)) + case TableRepartition(child, _, _) => requiredness.unionFrom(lookup(child)) case TableDistinct(child) => requiredness.unionFrom(lookup(child)) - case TableOrderBy(child, sortFields) => requiredness.unionFrom(lookup(child)) - case TableRename(child, rMap, gMap) => requiredness.unionFrom(lookup(child)) - case TableFilterIntervals(child, intervals, keep) => requiredness.unionFrom(lookup(child)) - case RelationalLetTable(name, value, body) => requiredness.unionFrom(lookup(body)) + case TableOrderBy(child, _) => requiredness.unionFrom(lookup(child)) + case TableRename(child, _, _) => requiredness.unionFrom(lookup(child)) + case TableFilterIntervals(child, _, _) => requiredness.unionFrom(lookup(child)) + case RelationalLetTable(_, _, body) => requiredness.unionFrom(lookup(body)) case TableGen(_, globals, _, _, body, _, _) => requiredness.unionGlobals(lookupAs[RStruct](globals)) requiredness.unionRows(lookupAs[RIterable](body).elementType.asInstanceOf[RStruct]) case TableParallelize(rowsAndGlobal, _) => - val Seq(rowsReq: RIterable, globalReq: RStruct) = lookupAs[RBaseStruct](rowsAndGlobal).children + val Seq(rowsReq: RIterable, globalReq: RStruct) = + lookupAs[RBaseStruct](rowsAndGlobal).children requiredness.unionRows(tcoerce[RStruct](rowsReq.elementType)) requiredness.unionGlobals(globalReq) case TableMapRows(child, newRow) => @@ -389,7 +438,7 @@ 
class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { case TableUnion(children) => requiredness.unionFrom(lookup(children.head)) children.tail.foreach(c => requiredness.unionRows(lookup(c))) - case TableKeyByAndAggregate(child, expr, newKey, nPartitions, bufferSize) => + case TableKeyByAndAggregate(child, expr, newKey, _, _) => requiredness.unionKeys(lookupAs[RStruct](newKey)) requiredness.unionValues(lookupAs[RStruct](expr)) requiredness.unionGlobals(lookup(child)) @@ -422,7 +471,8 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { requiredness.key.take(joinKey).zipWithIndex.foreach { case (k, i) => requiredness.field(k).unionWithIntersection(FastSeq( leftReq.field(leftReq.key(i)), - rightReq.field(rightReq.key(i)))) + rightReq.field(rightReq.key(i)), + )) } requiredness.unionGlobals(leftReq.globalType) @@ -438,12 +488,14 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { requiredness.field(root).asInstanceOf[RIterable] .elementType.asInstanceOf[RStruct] else requiredness.field(root).asInstanceOf[RStruct] - rReq.valueFields.foreach { n => joinField.field(n).unionFrom(rReq.field(n)) } + rReq.valueFields.foreach(n => joinField.field(n).unionFrom(rReq.field(n))) requiredness.field(root).union(false) requiredness.unionGlobals(lReq) case TableMultiWayZipJoin(children, valueName, globalName) => - val valueStruct = tcoerce[RStruct](tcoerce[RIterable](requiredness.field(valueName)).elementType) - val globalStruct = tcoerce[RStruct](tcoerce[RIterable](requiredness.field(globalName)).elementType) + val valueStruct = + tcoerce[RStruct](tcoerce[RIterable](requiredness.field(valueName)).elementType) + val globalStruct = + tcoerce[RStruct](tcoerce[RIterable](requiredness.field(globalName)).elementType) children.foreach { c => val cReq = lookup(c) requiredness.unionKeys(cReq) @@ -459,12 +511,14 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { rReq.valueFields.foreach(n => joined.field(n).unionFrom(rReq.field(n))) joined.union(false) requiredness.unionGlobals(lReq.globalType) - case TableMapPartitions(child, globalName, partitionStreamName, body, _, _) => + case TableMapPartitions(child, _, _, body, _, _) => requiredness.unionRows(lookupAs[RIterable](body).elementType.asInstanceOf[RStruct]) requiredness.unionGlobals(lookup(child)) - case TableToTableApply(child, function) => requiredness.maximize() //FIXME: needs implementation - case BlockMatrixToTableApply(child, _, function) => requiredness.maximize() //FIXME: needs implementation - case BlockMatrixToTable(child) => //all required + case TableToTableApply(_, _) => + requiredness.maximize() // FIXME: needs implementation + case BlockMatrixToTableApply(_, _, _) => + requiredness.maximize() // FIXME: needs implementation + case BlockMatrixToTable(_) => // all required } requiredness.probeChangedAndReset() } @@ -474,27 +528,27 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { node match { // union all case _: Cast | - _: CastRename | - _: ToSet | - _: CastToArray | - _: ToArray | - _: ToStream | - _: NDArrayReindex | - _: NDArrayAgg => + _: CastRename | + _: ToSet | + _: CastToArray | + _: ToArray | + _: ToStream | + _: NDArrayReindex | + _: NDArrayAgg => node.children.foreach { case c: IR => requiredness.unionFrom(lookup(c)) } // union top-level case _: ApplyBinaryPrimOp | - _: ApplyUnaryPrimOp | - _: Consume | - _: ArrayLen | - _: StreamLen | - _: ArrayZeros | - _: StreamRange | - _: StreamIota | - _: SeqSample | - _: StreamDistribute | - 
_: WriteValue => + _: ApplyUnaryPrimOp | + _: Consume | + _: ArrayLen | + _: StreamLen | + _: ArrayZeros | + _: StreamRange | + _: StreamIota | + _: SeqSample | + _: StreamDistribute | + _: WriteValue => requiredness.union(node.children.forall { case c: IR => lookup(c).required }) case x: ApplyComparisonOp if x.op.strict => requiredness.union(node.children.forall { case c: IR => lookup(c).required }) @@ -511,23 +565,30 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { childField.union(false) childField.unionFrom(lookup(child)) - case x@ConsoleLog(message, result) => + case ConsoleLog(_, result) => requiredness.unionFrom(lookup(result)) case x if x.typ == TVoid => - case ApplyComparisonOp(EQWithNA(_, _), _, _) | ApplyComparisonOp(NEQWithNA(_, _), _, _) | ApplyComparisonOp(Compare(_, _), _, _) => - case ApplyComparisonOp(op, l, r) => + case ApplyComparisonOp(EQWithNA(_, _), _, _) | ApplyComparisonOp( + NEQWithNA(_, _), + _, + _, + ) | ApplyComparisonOp(Compare(_, _), _, _) => + case ApplyComparisonOp(op, _, _) => fatal(s"non-strict comparison op $op must have explicit case") - case TableCount(t) => - case TableToValueApply(t, ForceCountTable()) => + case TableCount(_) => + case TableToValueApply(_, ForceCountTable()) => case _: NA => requiredness.union(false) - case Literal(t, a) => requiredness.unionLiteral(a) - case EncodedLiteral(codec, value) => requiredness.fromPType(codec.decodedPType().setRequired(true)) + case Literal(_, a) => requiredness.unionLiteral(a) + case EncodedLiteral(codec, _) => + requiredness.fromPType(codec.decodedPType().setRequired(true)) case Coalesce(values) => val reqs = values.map(lookup) requiredness.union(reqs.exists(_.required)) - reqs.foreach(r => requiredness.children.zip(r.children).foreach { case (r1, r2) => r1.unionFrom(r2) }) + reqs.foreach(r => + requiredness.children.zip(r.children).foreach { case (r1, r2) => r1.unionFrom(r2) } + ) case If(cond, cnsq, altr) => requiredness.union(lookup(cond).required) requiredness.unionFrom(lookup(cnsq)) @@ -536,15 +597,13 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { requiredness.union(lookup(x).required) requiredness.unionFrom(lookup(default)) requiredness.unionFrom(cases.map(lookup)) - case AggLet(name, value, body, isScan) => - requiredness.unionFrom(lookup(body)) - case Let(_, body) => + case Block(_, body) => requiredness.unionFrom(lookup(body)) - case RelationalLet(name, value, body) => + case RelationalLet(_, _, body) => requiredness.unionFrom(lookup(body)) - case TailLoop(name, params, _, body) => + case TailLoop(_, _, _, body) => requiredness.unionFrom(lookup(body)) - case x: BaseRef => + case _: BaseRef => requiredness.unionFrom(defs(node).map(tcoerce[TypeWithRequiredness])) case MakeArray(args, _) => tcoerce[RIterable](requiredness).elementType.unionFrom(args.map(lookup)) @@ -554,12 +613,14 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { val aReq = lookupAs[RIterable](a) requiredness.unionFrom(aReq.elementType) requiredness.union(lookup(i).required && aReq.required) - case ArraySlice(a, start, stop, step, errorID) => + case ArraySlice(a, start, stop, step, _) => val aReq = lookupAs[RIterable](a) requiredness.asInstanceOf[RIterable].elementType.unionFrom(aReq.elementType) val stopReq = if (!stop.isEmpty) lookup(stop.get).required else true - requiredness.union(aReq.required && stopReq && lookup(start).required && lookup(step).required) - case ArraySort(a, l, r, c) => + requiredness.union( + aReq.required && stopReq && 
lookup(start).required && lookup(step).required + ) + case ArraySort(a, _, _, _) => requiredness.unionFrom(lookup(a)) case ArrayMaximalIndependentSet(a, _) => val aReq = lookupAs[RIterable](a) @@ -572,7 +633,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { tcoerce[RDict](requiredness).keyType.unionFrom(keyType) tcoerce[RDict](requiredness).valueType.unionFrom(valueType) requiredness.union(aReq.required) - case LowerBoundOnOrderedCollection(collection, elem, _) => + case LowerBoundOnOrderedCollection(collection, _, _) => requiredness.union(lookup(collection).required) case GroupByKey(c) => val cReq = lookupAs[RIterable](c) @@ -585,13 +646,13 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { tcoerce[RIterable](tcoerce[RIterable](requiredness).elementType).elementType .unionFrom(aReq.elementType) requiredness.union(aReq.required && lookup(size).required) - case StreamGroupByKey(a, key, _) => + case StreamGroupByKey(a, _, _) => val aReq = lookupAs[RIterable](a) val elt = tcoerce[RIterable](tcoerce[RIterable](requiredness).elementType).elementType elt.union(true) elt.children.zip(aReq.elementType.children).foreach { case (r1, r2) => r1.unionFrom(r2) } requiredness.union(aReq.required) - case StreamMap(a, name, body) => + case StreamMap(a, _, body) => requiredness.union(lookup(a).required) tcoerce[RIterable](requiredness).elementType.unionFrom(lookup(body)) case StreamTake(a, n) => @@ -600,16 +661,16 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { case StreamDrop(a, n) => requiredness.union(lookup(n).required) requiredness.unionFrom(lookup(a)) - case StreamWhiten(stream, _, _ ,_, _, _, _, _) => + case StreamWhiten(stream, _, _, _, _, _, _, _) => requiredness.unionFrom(lookup(stream)) - case StreamZip(as, names, body, behavior, _) => + case StreamZip(as, _, body, _, _) => requiredness.union(as.forall(lookup(_).required)) tcoerce[RIterable](requiredness).elementType.unionFrom(lookup(body)) - case StreamZipJoin(as, _, curKey, curVals, joinF) => + case StreamZipJoin(as, _, _, _, joinF) => requiredness.union(as.forall(lookup(_).required)) val eltType = tcoerce[RIterable](requiredness).elementType eltType.unionFrom(lookup(joinF)) - case StreamZipJoinProducers(contexts, ctxName, makeProducer, _, curKey, curVals, joinF) => + case StreamZipJoinProducers(contexts, _, _, _, _, _, joinF) => requiredness.union(lookup(contexts).required) val eltType = tcoerce[RIterable](requiredness).elementType eltType.unionFrom(lookup(joinF)) @@ -619,65 +680,78 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { as.foreach { a => elt.unionFields(tcoerce[RStruct](tcoerce[RIterable](lookup(a)).elementType)) } - case StreamFilter(a, name, cond) => + case StreamFilter(a, _, _) => requiredness.unionFrom(lookup(a)) - case StreamTakeWhile(a, name, cond) => + case StreamTakeWhile(a, _, _) => requiredness.unionFrom(lookup(a)) - case StreamDropWhile(a, name, cond) => + case StreamDropWhile(a, _, _) => requiredness.unionFrom(lookup(a)) - case StreamFlatMap(a, name, body) => + case StreamFlatMap(a, _, body) => requiredness.union(lookup(a).required) - tcoerce[RIterable](requiredness).elementType.unionFrom(lookupAs[RIterable](body).elementType) - case StreamFold(a, zero, accumName, valueName, body) => + tcoerce[RIterable](requiredness).elementType.unionFrom( + lookupAs[RIterable](body).elementType + ) + case StreamFold(a, zero, _, _, body) => requiredness.union(lookup(a).required) requiredness.unionFrom(lookup(body)) 
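Throughout `analyzeIR`, requiredness facts only ever weaken: `union(false)` or `unionFrom` on an optional child marks a node optional, and `probeChangedAndReset()` at the end of each analyze pass reports whether anything moved. The following is a tiny sketch of such a one-directional state cell, under the assumption that this is the intended semantics; `ReqSketch.Cell` is an illustrative stand-in, not Hail's `TypeWithRequiredness`.

```scala
// Sketch of a one-directional requiredness cell: it starts required, union(false)
// weakens it to optional, and probeChangedAndReset() tells a worklist whether to
// revisit dependents. Illustrative only; not Hail's TypeWithRequiredness.
object ReqSketch {
  final class Cell {
    private var required: Boolean = true
    private var changed: Boolean = false

    def union(r: Boolean): Unit =
      if (!r && required) { required = false; changed = true }

    def unionFrom(other: Cell): Unit = union(other.isRequired)

    def isRequired: Boolean = required

    def probeChangedAndReset(): Boolean = { val c = changed; changed = false; c }
  }

  def main(args: Array[String]): Unit = {
    val node = new Cell
    node.union(true)  // required child: no change
    println((node.isRequired, node.probeChangedAndReset())) // (true,false)
    node.union(false) // optional child: weaken and flag the change
    println((node.isRequired, node.probeChangedAndReset())) // (false,true)
  }
}
```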
requiredness.unionFrom(lookup(zero)) // if a is length 0 - case StreamScan(a, zero, accumName, valueName, body) => + case StreamScan(a, zero, _, _, body) => requiredness.union(lookup(a).required) tcoerce[RIterable](requiredness).elementType.unionFrom(lookup(body)) tcoerce[RIterable](requiredness).elementType.unionFrom(lookup(zero)) - case StreamFold2(a, accums, valueName, seq, result) => + case StreamFold2(a, _, _, _, result) => requiredness.union(lookup(a).required) requiredness.unionFrom(lookup(result)) - case StreamJoinRightDistinct(left, right, _, _, _, _, joinf, joinType) => + case StreamLeftIntervalJoin(left, right, _, _, _, _, body) => + requiredness.union(lookup(left).required && lookup(right).required) + tcoerce[RIterable](requiredness).elementType.unionFrom(lookup(body)) + case StreamJoinRightDistinct(left, right, _, _, _, _, joinf, _) => requiredness.union(lookup(left).required && lookup(right).required) tcoerce[RIterable](requiredness).elementType.unionFrom(lookup(joinf)) - case StreamLocalLDPrune(a, r2Threshold, windowSize, maxQueueSize, nSamples) => + case StreamLocalLDPrune(a, _, _, _, _) => // FIXME what else needs to go here? requiredness.union(lookup(a).required) - case StreamAgg(a, name, query) => + case StreamAgg(a, _, query) => requiredness.union(lookup(a).required) requiredness.unionFrom(lookup(query)) - case StreamAggScan(a, name, query) => + case StreamAggScan(a, _, query) => requiredness.union(lookup(a).required) tcoerce[RIterable](requiredness).elementType.unionFrom(lookup(query)) - case AggFilter(cond, aggIR, isScan) => + case AggFilter(_, aggIR, _) => requiredness.unionFrom(lookup(aggIR)) - case AggExplode(array, name, aggBody, isScan) => + case AggExplode(_, _, aggBody, _) => requiredness.unionFrom(lookup(aggBody)) - case AggGroupBy(key, aggIR, isScan) => + case AggGroupBy(key, aggIR, _) => val rdict = tcoerce[RDict](requiredness) rdict.keyType.unionFrom(lookup(key)) rdict.valueType.unionFrom(lookup(aggIR)) - case AggArrayPerElement(a, _, _, body, knownLength, isScan) => + case AggArrayPerElement(a, _, _, body, _, _) => val rit = tcoerce[RIterable](requiredness) rit.union(lookup(a).required) rit.elementType.unionFrom(lookup(body)) - case ApplyAggOp(initOpArgs, seqOpArgs, aggSig) => //FIXME round-tripping through ptype - val emitResult = agg.PhysicalAggSig(aggSig.op, agg.AggStateSig(aggSig.op, - initOpArgs.map(i => i -> lookup(i)), - seqOpArgs.map(s => s -> lookup(s)))).emitResultType + case ApplyAggOp(_, seqOpArgs, aggSig) => // FIXME round-tripping through ptype + val emitResult = agg.PhysicalAggSig( + aggSig.op, + agg.AggStateSig( + aggSig.op, + seqOpArgs.map(s => s -> lookup(s)), + ), + ).emitResultType requiredness.fromEmitType(emitResult) - case ApplyScanOp(initOpArgs, seqOpArgs, aggSig) => - val emitResult = agg.PhysicalAggSig(aggSig.op, agg.AggStateSig(aggSig.op, - initOpArgs.map(i => i -> lookup(i)), - seqOpArgs.map(s => s -> lookup(s)))).emitResultType + case ApplyScanOp(_, seqOpArgs, aggSig) => + val emitResult = agg.PhysicalAggSig( + aggSig.op, + agg.AggStateSig( + aggSig.op, + seqOpArgs.map(s => s -> lookup(s)), + ), + ).emitResultType requiredness.fromEmitType(emitResult) - case AggFold(zero, seqOp, combOp, elementName, accumName, _) => + case AggFold(zero, seqOp, combOp, _, _, _) => requiredness.unionFrom(lookup(zero)) requiredness.unionFrom(lookup(seqOp)) requiredness.unionFrom(lookup(combOp)) - case MakeNDArray(data, shape, rowMajor, _) => + case MakeNDArray(data, shape, _, _) => requiredness.unionFrom(lookup(data)) 
requiredness.union(lookup(shape).required) case NDArrayShape(nd) => @@ -687,7 +761,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { val ndReq = lookup(nd) requiredness.unionFrom(ndReq) requiredness.union(sReq.required && sReq.children.forall(_.required)) - case NDArrayConcat(nds, axis) => + case NDArrayConcat(nds, _) => val ndsReq = lookupAs[RIterable](nds) requiredness.unionFrom(ndsReq.elementType) requiredness.union(ndsReq.required) @@ -703,9 +777,9 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { } requiredness.unionFrom(lookup(nd)) requiredness.union(slicesReq.required && allSlicesRequired) - case NDArrayFilter(nd, keep) => + case NDArrayFilter(nd, _) => requiredness.unionFrom(lookup(nd)) - case NDArrayMap(nd, name, body) => + case NDArrayMap(nd, _, body) => requiredness.union(lookup(nd).required) tcoerce[RNDArray](requiredness).unionElement(lookup(body)) case NDArrayMap2(l, r, _, _, body, _) => @@ -714,9 +788,12 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { case NDArrayMatMul(l, r, _) => requiredness.unionFrom(lookup(l)) requiredness.union(lookup(r).required) - case NDArrayQR(child, mode, _) => requiredness.fromPType(NDArrayQR.pType(mode, lookup(child).required)) - case NDArraySVD(child, _, computeUV, _) => requiredness.fromPType(NDArraySVD.pTypes(computeUV, lookup(child).required)) - case NDArrayEigh(child, eigvalsOnly, _) => requiredness.fromPType(NDArrayEigh.pTypes(eigvalsOnly, lookup(child).required)) + case NDArrayQR(child, mode, _) => + requiredness.fromPType(NDArrayQR.pType(mode, lookup(child).required)) + case NDArraySVD(child, _, computeUV, _) => + requiredness.fromPType(NDArraySVD.pTypes(computeUV, lookup(child).required)) + case NDArrayEigh(child, eigvalsOnly, _) => + requiredness.fromPType(NDArrayEigh.pTypes(eigvalsOnly, lookup(child).required)) case NDArrayInv(child, _) => requiredness.unionFrom(lookup(child)) case MakeStruct(fields) => fields.foreach { case (n, f) => @@ -729,9 +806,7 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { case SelectFields(old, fields) => val oldReq = lookupAs[RStruct](old) requiredness.union(oldReq.required) - fields.foreach { n => - tcoerce[RStruct](requiredness).field(n).unionFrom(oldReq.field(n)) - } + fields.foreach(n => tcoerce[RStruct](requiredness).field(n).unionFrom(oldReq.field(n))) case InsertFields(old, fields, _) => lookup(old) match { case oldReq: RStruct => @@ -741,8 +816,8 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { f.typ.unionFrom(fieldMap.getOrElse(f.name, oldReq.field(f.name))) } case _ => fields.foreach { case (n, f) => - tcoerce[RStruct](requiredness).field(n).unionFrom(lookup(f)) - } + tcoerce[RStruct](requiredness).field(n).unionFrom(lookup(f)) + } } case GetField(o, name) => val oldReq = lookupAs[RStruct](o) @@ -753,13 +828,16 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { requiredness.union(oldReq.required) requiredness.unionFrom(oldReq.field(idx)) case x: ApplyIR => requiredness.unionFrom(lookup(x.body)) - case x: AbstractApplyNode[_] => //FIXME: round-tripping via PTypes. + case x: AbstractApplyNode[_] => // FIXME: round-tripping via PTypes. 
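The analysis drives those per-node facts to a fixed point: binding sites record their dependents (`dependents.getOrElseUpdate(...) ++= uses`) and a queue of nodes is drained, revisiting dependents whenever a node changes. Here is a generic sketch of the worklist fixed-point pattern suggested by that bookkeeping; the names (`fixedPoint`, `dependentsOf`, `update`) are illustrative, not Hail's API.

```scala
// Sketch of the worklist fixed-point pattern: recompute a node's fact, and if it
// changed, re-enqueue everything that depends on it. Illustrative names only.
object WorklistSketch {
  def fixedPoint[N](
    start: Iterable[N],
    dependentsOf: N => Iterable[N],
    update: N => Boolean, // returns true if the node's fact changed
  ): Unit = {
    val queue = scala.collection.mutable.Queue.empty[N]
    val queued = scala.collection.mutable.Set.empty[N]
    start.foreach { n => queue.enqueue(n); queued += n }
    while (queue.nonEmpty) {
      val n = queue.dequeue()
      queued -= n
      if (update(n))
        dependentsOf(n).foreach { d =>
          if (!queued.contains(d)) { queue.enqueue(d); queued += d }
        }
    }
  }

  def main(args: Array[String]): Unit = {
    // Toy use: propagate an "optional" flag through the chain a -> b -> c.
    val deps = Map("a" -> Seq("b"), "b" -> Seq("c"), "c" -> Seq.empty[String])
    val optional = scala.collection.mutable.Set("a")
    fixedPoint[String](
      deps.keys,
      deps(_),
      n => deps.keys.exists(p => deps(p).contains(n) && optional(p)) && optional.add(n),
    )
    println(optional.toList.sorted) // List(a, b, c)
  }
}
```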
val argP = x.args.map { a => val pt = lookup(a).canonicalPType(a.typ) EmitType(pt.sType, pt.required) } - requiredness.unionFrom(x.implementation.computeReturnEmitType(x.returnType, argP).typeWithRequiredness.r) - case CollectDistributedArray(ctxs, globs, _, _, body, _, _, _) => + requiredness.unionFrom(x.implementation.computeReturnEmitType( + x.returnType, + argP, + ).typeWithRequiredness.r) + case CollectDistributedArray(ctxs, _, _, _, body, _, _, _) => requiredness.union(lookup(ctxs).required) tcoerce[RIterable](requiredness).elementType.unionFrom(lookup(body)) case ReadPartition(context, rowType, reader) => @@ -773,37 +851,38 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { requiredness.union(lookup(path).required) reader.unionRequiredness(rt, requiredness) case In(_, t) => t match { - case SCodeEmitParamType(et) => requiredness.unionFrom(et.typeWithRequiredness.r) - case SingleCodeEmitParamType(required, StreamSingleCodeType(_, eltType, eltRequired)) => - requiredness.asInstanceOf[RIterable].elementType.fromPType(eltType.setRequired(eltRequired)) - requiredness.union(required) - case SingleCodeEmitParamType(required, PTypeReferenceSingleCodeType(pt)) => - requiredness.fromPType(pt.setRequired(required)) - case SingleCodeEmitParamType(required, _) => - requiredness.union(required) - } + case SCodeEmitParamType(et) => requiredness.unionFrom(et.typeWithRequiredness.r) + case SingleCodeEmitParamType(required, StreamSingleCodeType(_, eltType, eltRequired)) => + requiredness.asInstanceOf[RIterable].elementType.fromPType( + eltType.setRequired(eltRequired) + ) + requiredness.union(required) + case SingleCodeEmitParamType(required, PTypeReferenceSingleCodeType(pt)) => + requiredness.fromPType(pt.setRequired(required)) + case SingleCodeEmitParamType(required, _) => + requiredness.union(required) + } case LiftMeOut(f) => requiredness.unionFrom(lookup(f)) case ResultOp(_, sig) => val r = requiredness r.fromEmitType(sig.emitResultType) case RunAgg(_, result, _) => requiredness.unionFrom(lookup(result)) - case StreamBufferedAggregate(streamChild, initAggs, newKey, seqOps, _, _, _) => + case StreamBufferedAggregate(streamChild, _, newKey, _, _, _, _) => requiredness.union(lookup(streamChild).required) val rstruct = requiredness.asInstanceOf[RIterable].elementType.asInstanceOf[RStruct] lookup(newKey).asInstanceOf[RStruct] .fields - .foreach { f => - rstruct.field(f.name).unionFrom(f.typ) - } - case RunAggScan(array, name, init, seqs, result, signature) => + .foreach(f => rstruct.field(f.name).unionFrom(f.typ)) + case RunAggScan(array, _, _, _, result, _) => requiredness.union(lookup(array).required) tcoerce[RIterable](requiredness).elementType.unionFrom(lookup(result)) - case TableAggregate(c, q) => requiredness.unionFrom(lookup(q)) + case TableAggregate(_, q) => requiredness.unionFrom(lookup(q)) case TableGetGlobals(c) => requiredness.unionFrom(lookup(c).globalType) case TableCollect(c) => val cReq = lookup(c) - val row = requiredness.asInstanceOf[RStruct].fieldType("rows").asInstanceOf[RIterable].elementType + val row = + requiredness.asInstanceOf[RStruct].fieldType("rows").asInstanceOf[RIterable].elementType val global = requiredness.asInstanceOf[RStruct].fieldType("global") row.unionFrom(cReq.rowType) global.unionFrom(cReq.globalType) @@ -812,11 +891,13 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { case TableGetGlobals(c) => requiredness.unionFrom(lookup(c).globalType) case TableCollect(c) => - 
tcoerce[RIterable](tcoerce[RStruct](requiredness).field("rows")).elementType.unionFrom(lookup(c).rowType) + tcoerce[RIterable](tcoerce[RStruct](requiredness).field("rows")).elementType.unionFrom( + lookup(c).rowType + ) tcoerce[RStruct](requiredness).field("global").unionFrom(lookup(c).globalType) - case BlockMatrixToValueApply(child, GetElement(_)) => // BlockMatrix elements are all required - case BlockMatrixCollect(child) => // BlockMatrix elements are all required - case BlockMatrixWrite(child, writer) => // write result is required + case BlockMatrixToValueApply(_, GetElement(_)) => // BlockMatrix elements are all required + case BlockMatrixCollect(_) => // BlockMatrix elements are all required + case BlockMatrixWrite(_, _) => // write result is required } requiredness.probeChangedAndReset() } @@ -828,7 +909,6 @@ class Requiredness(val usesAndDefs: UsesAndDefs, ctx: ExecuteContext) { requiredness.probeChangedAndReset() } - final class Queue(val markFlag: Int) { private[this] val q = mutable.Queue[RefEquality[BaseIR]]() diff --git a/hail/src/main/scala/is/hail/expr/ir/Scope.scala b/hail/src/main/scala/is/hail/expr/ir/Scope.scala index 2e9b995e55d..65cfb65ccb4 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Scope.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Scope.scala @@ -2,7 +2,8 @@ package is.hail.expr.ir object UsesAggEnv { def apply(ir0: BaseIR, i: Int): Boolean = ir0 match { - case AggLet(_, _, _, false) => i == 0 + case Block(bindings, _) if i < bindings.length => + bindings(i).scope == Scope.AGG case AggGroupBy(_, _, false) => i == 0 case AggFilter(_, _, false) => i == 0 case AggExplode(_, _, _, false) => i == 0 @@ -15,7 +16,8 @@ object UsesAggEnv { object UsesScanEnv { def apply(ir0: BaseIR, i: Int): Boolean = ir0 match { - case AggLet(_, _, _, true) => i == 0 + case Block(bindings, _) if i < bindings.length => + bindings(i).scope == Scope.SCAN case AggGroupBy(_, _, true) => i == 0 case AggFilter(_, _, true) => i == 0 case AggExplode(_, _, _, true) => i == 0 @@ -26,7 +28,6 @@ object UsesScanEnv { } } - object Scope { val EVAL: Int = 0 val AGG: Int = 1 @@ -50,20 +51,20 @@ object Scope { compute(value, EVAL) compute(body, scope) case _ => ir.children.zipWithIndex.foreach { - case (child: IR, i) => - val usesAgg = UsesAggEnv(ir, i) - val usesScan = UsesScanEnv(ir, i) - if (usesAgg) { - assert(!usesScan) - assert(scope == EVAL) - compute(child, AGG) - } else if (usesScan) { - assert(scope == EVAL) - compute(child, SCAN) - } else - compute(child, scope) - case (child, _) => compute(child, EVAL) - } + case (child: IR, i) => + val usesAgg = UsesAggEnv(ir, i) + val usesScan = UsesScanEnv(ir, i) + if (usesAgg) { + assert(!usesScan) + assert(scope == EVAL) + compute(child, AGG) + } else if (usesScan) { + assert(scope == EVAL) + compute(child, SCAN) + } else + compute(child, scope) + case (child, _) => compute(child, EVAL) + } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/Simplify.scala b/hail/src/main/scala/is/hail/expr/ir/Simplify.scala index 68faa4e0829..b4fc473645a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Simplify.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Simplify.scala @@ -7,10 +7,11 @@ import is.hail.types.tcoerce import is.hail.types.virtual._ import is.hail.utils._ +import scala.collection.mutable + object Simplify { - /** Transform 'ir' using simplification rules until none apply. - */ + /** Transform 'ir' using simplification rules until none apply. 
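In the Scope.scala hunks above, `UsesAggEnv` and `UsesScanEnv` now key off the scope tag carried by each `Block` binding instead of matching `AggLet`. A small sketch of that shape follows, assuming a `Binding(name, value, scope)` record; the types here are illustrative stand-ins, not Hail's IR classes.

```scala
// Sketch: each binding in a Block carries a scope tag, and a child's index decides
// whether it is evaluated in the eval, agg, or scan environment.
// Illustrative names only; not Hail's actual IR.
object ScopeSketch {
  object Scope { val EVAL = 0; val AGG = 1; val SCAN = 2 }
  final case class Binding(name: String, value: String, scope: Int = Scope.EVAL)
  final case class Block(bindings: IndexedSeq[Binding], result: String)

  // Child i of a Block is bindings(i) for i < bindings.length, and the result otherwise.
  def childUsesAggEnv(b: Block, i: Int): Boolean =
    i < b.bindings.length && b.bindings(i).scope == Scope.AGG

  def main(args: Array[String]): Unit = {
    val block = Block(IndexedSeq(Binding("x", "5"), Binding("s", "sum(x)", Scope.AGG)), "s")
    println((0 to 2).map(i => childUsesAggEnv(block, i))) // Vector(false, true, false)
  }
}
```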
*/ def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR = ir match { case ir: IR => simplifyValue(ctx)(ir) case tir: TableIR => simplifyTable(ctx)(tir) @@ -21,8 +22,10 @@ object Simplify { private[this] def visitNode[T <: BaseIR]( visitChildren: BaseIR => BaseIR, transform: T => Option[T], - post: => (T => T) - )(t: T): T = { + post: => (T => T), + )( + t: T + ): T = { val t1 = t.mapChildren(visitChildren).asInstanceOf[T] transform(t1).map(post).getOrElse(t1) } @@ -31,30 +34,29 @@ object Simplify { visitNode( Simplify(ctx, _), rewriteValueNode(ctx), - simplifyValue(ctx) + simplifyValue(ctx), ) private[this] def simplifyTable(ctx: ExecuteContext)(tir: TableIR): TableIR = visitNode( Simplify(ctx, _), rewriteTableNode(ctx), - simplifyTable(ctx) + simplifyTable(ctx), )(tir) private[this] def simplifyMatrix(ctx: ExecuteContext)(mir: MatrixIR): MatrixIR = visitNode( Simplify(ctx, _), rewriteMatrixNode(), - simplifyMatrix(ctx) + simplifyMatrix(ctx), )(mir) - private[this] def simplifyBlockMatrix(ctx: ExecuteContext)(bmir: BlockMatrixIR): BlockMatrixIR = { + private[this] def simplifyBlockMatrix(ctx: ExecuteContext)(bmir: BlockMatrixIR): BlockMatrixIR = visitNode( Simplify(ctx, _), rewriteBlockMatrixNode, - simplifyBlockMatrix(ctx) + simplifyBlockMatrix(ctx), )(bmir) - } private[this] def rewriteValueNode(ctx: ExecuteContext)(ir: IR): Option[IR] = valueRules(ctx).lift(ir).orElse(numericRules(ir)) @@ -65,56 +67,37 @@ object Simplify { private[this] def rewriteMatrixNode()(mir: MatrixIR): Option[MatrixIR] = matrixRules().lift(mir) - private[this] def rewriteBlockMatrixNode: BlockMatrixIR => Option[BlockMatrixIR] = blockMatrixRules.lift - - /** Returns true if 'x' propagates missingness, meaning if any child of 'x' - * evaluates to missing, then 'x' will evaluate to missing. - */ - private[this] def isStrict(x: IR): Boolean = { - x match { - case _: Apply | - _: ApplySeeded | - _: ApplyUnaryPrimOp | - _: ApplyBinaryPrimOp | - _: ArrayRef | - _: ArrayLen | - _: GetField | - _: GetTupleElement => true - case ApplyComparisonOp(op, _, _) => op.strict - case _ => false - } - } + private[this] def rewriteBlockMatrixNode: BlockMatrixIR => Option[BlockMatrixIR] = + blockMatrixRules.lift - /** - * Returns true if any strict child of 'x' is NA. - * A child is strict if 'x' evaluates to missing whenever the child does. + /** Returns true if any strict child of 'x' is NA. A child is strict if 'x' evaluates to missing + * whenever the child does. */ private[this] def hasMissingStrictChild(x: IR): Boolean = { x match { case _: Apply | - _: ApplySeeded | - _: ApplyUnaryPrimOp | - _: ApplyBinaryPrimOp | - _: ArrayRef | - _: ArrayLen | - _: GetField | - _: GetTupleElement => x.children.exists(_.isInstanceOf[NA]) + _: ApplySeeded | + _: ApplyUnaryPrimOp | + _: ApplyBinaryPrimOp | + _: ArrayRef | + _: ArrayLen | + _: GetField | + _: GetTupleElement => x.children.exists(_.isInstanceOf[NA]) case ApplyComparisonOp(op, _, _) if op.strict => x.children.exists(_.isInstanceOf[NA]) case _ => false } } - /** Returns true if 'x' will never evaluate to missing. - */ + /** Returns true if 'x' will never evaluate to missing. 
*/ private[this] def isDefinitelyDefined(x: IR): Boolean = { x match { case _: MakeArray | - _: MakeStruct | - _: MakeTuple | - _: IsNA | - ApplyComparisonOp(EQWithNA(_, _), _, _) | - ApplyComparisonOp(NEQWithNA(_, _), _, _) | - _: I32 | _: I64 | _: F32 | _: F64 | True() | False() => true + _: MakeStruct | + _: MakeTuple | + _: IsNA | + ApplyComparisonOp(EQWithNA(_, _), _, _) | + ApplyComparisonOp(NEQWithNA(_, _), _, _) | + _: I32 | _: I64 | _: F32 | _: F64 | True() | False() => true case _ => false } } @@ -175,10 +158,10 @@ object Simplify { def hoistUnaryOp(ir: IR): Option[IR] = ir match { - case ApplyUnaryPrimOp(f@(Negate | BitNot | Bang), x) => x match { - case ApplyUnaryPrimOp(g, y) if g == f => Some(y) - case _ => None - } + case ApplyUnaryPrimOp(f @ (Negate | BitNot | Bang), x) => x match { + case ApplyUnaryPrimOp(g, y) if g == f => Some(y) + case _ => None + } case _ => None } @@ -227,36 +210,39 @@ object Simplify { case x: IR if hasMissingStrictChild(x) => NA(x.typ) - case x@If(NA(_), _, _) => NA(x.typ) + case x @ If(NA(_), _, _) => NA(x.typ) case Coalesce(values) if isDefinitelyDefined(values.head) => values.head - case Coalesce(values) if values.zipWithIndex.exists { case (ir, i) => isDefinitelyDefined(ir) && i != values.size - 1 } => + case Coalesce(values) if values.zipWithIndex.exists { case (ir, i) => + isDefinitelyDefined(ir) && i != values.size - 1 + } => val idx = values.indexWhere(isDefinitelyDefined) Coalesce(values.take(idx + 1)) case Coalesce(values) if values.size == 1 => values.head - case x@StreamMap(NA(_), _, _) => NA(x.typ) + case x @ StreamMap(NA(_), _, _) => NA(x.typ) case StreamZip(as, names, body, _, _) if as.length == 1 => StreamMap(as.head, names.head, body) case StreamMap(StreamZip(as, names, zipBody, b, errorID), name, mapBody) => StreamZip(as, names, Let(FastSeq(name -> zipBody), mapBody), b, errorID) - case StreamMap(StreamFlatMap(child, flatMapName, flatMapBody), mapName, mapBody) => StreamFlatMap(child, flatMapName, StreamMap(flatMapBody, mapName, mapBody)) + case StreamMap(StreamFlatMap(child, flatMapName, flatMapBody), mapName, mapBody) => + StreamFlatMap(child, flatMapName, StreamMap(flatMapBody, mapName, mapBody)) - case x@StreamFlatMap(NA(_), _, _) => NA(x.typ) + case x @ StreamFlatMap(NA(_), _, _) => NA(x.typ) - case x@StreamFilter(NA(_), _, _) => NA(x.typ) + case x @ StreamFilter(NA(_), _, _) => NA(x.typ) - case x@StreamFold(NA(_), _, _, _, _) => NA(x.typ) + case x @ StreamFold(NA(_), _, _, _, _) => NA(x.typ) case IsNA(NA(_)) => True() case IsNA(x) if isDefinitelyDefined(x) => False() - case x@If(True(), cnsq, _) => cnsq + case If(True(), cnsq, _) => cnsq - case x@If(False(), _, altr) => altr + case If(False(), _, altr) => altr case If(c, cnsq, altr) if cnsq == altr && cnsq.typ != TVoid => if (isDefinitelyDefined(c)) @@ -264,6 +250,8 @@ object Simplify { else If(IsNA(c), NA(cnsq.typ), cnsq) + case If(IsNA(a), NA(_), b) if a == b => b + case If(ApplyUnaryPrimOp(Bang, c), cnsq, altr) => If(c, altr, cnsq) case If(c1, If(c2, cnsq2, _), altr1) if c1 == c2 => If(c1, cnsq2, altr1) @@ -276,58 +264,67 @@ object Simplify { default case Cast(x, t) if x.typ == t => x - case Cast(Cast(x, _), t) if x.typ == t =>x + case Cast(Cast(x, _), t) if x.typ == t => x case CastRename(x, t) if x.typ == t => x case CastRename(CastRename(x, _), t) => CastRename(x, t) - case ApplyIR("indexArray", _, Seq(a, i@I32(v)), _, errorID) if v >= 0 => + case ApplyIR("indexArray", _, Seq(a, i @ I32(v)), _, errorID) if v >= 0 => ArrayRef(a, i, errorID) - case ApplyIR("contains", _, 
Seq(CastToArray(x), element), _, _) if x.typ.isInstanceOf[TSet] => invoke("contains", TBoolean, x, element) + case ApplyIR("contains", _, Seq(CastToArray(x), element), _, _) if x.typ.isInstanceOf[TSet] => + invoke("contains", TBoolean, x, element) case ApplyIR("contains", _, Seq(Literal(t, v), element), _, _) if t.isInstanceOf[TArray] => - invoke("contains", TBoolean, Literal(TSet(t.asInstanceOf[TArray].elementType), v.asInstanceOf[IndexedSeq[_]].toSet), element) + invoke( + "contains", + TBoolean, + Literal(TSet(t.asInstanceOf[TArray].elementType), v.asInstanceOf[IndexedSeq[_]].toSet), + element, + ) - case ApplyIR("contains", _, Seq(ToSet(x), element), _, _) if x.typ.isInstanceOf[TArray] => invoke("contains", TBoolean, x, element) + case ApplyIR("contains", _, Seq(ToSet(x), element), _, _) if x.typ.isInstanceOf[TArray] => + invoke("contains", TBoolean, x, element) case x: ApplyIR if x.inline || x.body.size < 10 => x.explicitNode case ArrayLen(MakeArray(args, _)) => I32(args.length) case StreamLen(MakeStream(args, _, _)) => I32(args.length) - case StreamLen(Let(bindings, body)) => Let(bindings, StreamLen(body)) + case StreamLen(Block(bindings, body)) => Block(bindings, StreamLen(body)) case StreamLen(StreamMap(s, _, _)) => StreamLen(s) case StreamLen(StreamFlatMap(a, name, body)) => streamSumIR(StreamMap(a, name, StreamLen(body))) - case StreamLen(StreamGrouped(a, groupSize)) => bindIR(groupSize)(groupSizeRef => (StreamLen(a) + groupSizeRef - 1) floorDiv groupSizeRef) + case StreamLen(StreamGrouped(a, groupSize)) => + bindIR(groupSize)(groupSizeRef => (StreamLen(a) + groupSizeRef - 1) floorDiv groupSizeRef) case ArrayLen(ToArray(s)) if s.typ.isInstanceOf[TStream] => StreamLen(s) - case ArrayLen(StreamFlatMap(a, _, MakeArray(args, _))) => ApplyBinaryPrimOp(Multiply(), I32(args.length), ArrayLen(a)) + case ArrayLen(StreamFlatMap(a, _, MakeArray(args, _))) => + ApplyBinaryPrimOp(Multiply(), I32(args.length), ArrayLen(a)) case ArrayLen(ArraySort(a, _, _, _)) => ArrayLen(ToArray(a)) case ArrayLen(ToArray(MakeStream(args, _, _))) => I32(args.length) - case ArraySlice(ToArray(s),I32(0), Some(x@I32(i)), I32(1), _) if i >= 0 => + case ArraySlice(ToArray(s), I32(0), Some(x @ I32(i)), I32(1), _) if i >= 0 => ToArray(StreamTake(s, x)) - case ArraySlice(z@ToArray(s), x@I32(i), Some(I32(j)), I32(1), _) if i > 0 && j > 0 => { + case ArraySlice(z @ ToArray(s), x @ I32(i), Some(I32(j)), I32(1), _) if i > 0 && j > 0 => if (j > i) { - ToArray(StreamTake(StreamDrop(s, x), I32(j-i))) + ToArray(StreamTake(StreamDrop(s, x), I32(j - i))) } else new MakeArray(FastSeq(), z.typ.asInstanceOf[TArray]) - } - case ArraySlice(ToArray(s), x@I32(i), None, I32(1), _) if i >= 0 => + case ArraySlice(ToArray(s), x @ I32(i), None, I32(1), _) if i >= 0 => ToArray(StreamDrop(s, x)) case ArrayRef(MakeArray(args, _), I32(i), _) if i >= 0 && i < args.length => args(i) case StreamFilter(a, _, True()) => a - case StreamFor(_, _, Begin(Seq())) => Begin(FastSeq()) + case StreamFor(_, _, Void()) => Void() // FIXME: Unqualify when StreamFold supports folding over stream of streams - case StreamFold(StreamMap(a, n1, b), zero, accumName, valueName, body) if a.typ.asInstanceOf[TStream].elementType.isRealizable => + case StreamFold(StreamMap(a, n1, b), zero, accumName, valueName, body) + if a.typ.asInstanceOf[TStream].elementType.isRealizable => StreamFold(a, zero, accumName, n1, Let(FastSeq(valueName -> b), body)) case StreamFlatMap(StreamMap(a, n1, b1), n2, b2) => @@ -338,10 +335,18 @@ object Simplify { case StreamMap(StreamMap(a, n1, b1), 
n2, b2) => StreamMap(a, n1, Let(FastSeq(n2 -> b1), b2)) - case StreamFilter(ArraySort(a, left, right, lessThan), name, cond) => ArraySort(StreamFilter(a, name, cond), left, right, lessThan) - - case StreamFilter(ToStream(ArraySort(a, left, right, lessThan), requiresMemoryManagementPerElement), name, cond) => - ToStream(ArraySort(StreamFilter(a, name, cond), left, right, lessThan), requiresMemoryManagementPerElement) + case StreamFilter(ArraySort(a, left, right, lessThan), name, cond) => + ArraySort(StreamFilter(a, name, cond), left, right, lessThan) + + case StreamFilter( + ToStream(ArraySort(a, left, right, lessThan), requiresMemoryManagementPerElement), + name, + cond, + ) => + ToStream( + ArraySort(StreamFilter(a, name, cond), left, right, lessThan), + requiresMemoryManagementPerElement, + ) case CastToArray(x) if x.typ.isInstanceOf[TArray] => x case ToArray(ToStream(a, _)) if a.typ.isInstanceOf[TArray] => a @@ -350,20 +355,20 @@ object Simplify { case ToStream(ToArray(s), false) if s.typ.isInstanceOf[TStream] => s - case ToStream(Let(bindings, ToArray(x)), false) if x.typ.isInstanceOf[TStream] => - Let(bindings, x) + case ToStream(Block(bindings, ToArray(x)), false) if x.typ.isInstanceOf[TStream] => + Block(bindings, x) - case MakeNDArray(ToArray(someStream), shape, rowMajor, errorId) => MakeNDArray(someStream, shape, rowMajor, errorId) - case MakeNDArray(ToStream(someArray, _), shape, rowMajor, errorId) => MakeNDArray(someArray, shape, rowMajor, errorId) - case NDArrayShape(MakeNDArray(data, shape, _, _)) => { + case MakeNDArray(ToArray(someStream), shape, rowMajor, errorId) => + MakeNDArray(someStream, shape, rowMajor, errorId) + case MakeNDArray(ToStream(someArray, _), shape, rowMajor, errorId) => + MakeNDArray(someArray, shape, rowMajor, errorId) + case NDArrayShape(MakeNDArray(data, shape, _, _)) => If(IsNA(data), NA(shape.typ), shape) - } case NDArrayShape(NDArrayMap(nd, _, _)) => NDArrayShape(nd) case NDArrayMap(NDArrayMap(child, innerName, innerBody), outerName, outerBody) => NDArrayMap(child, innerName, Let(FastSeq(outerName -> innerBody), outerBody)) - case GetField(MakeStruct(fields), name) => val (_, x) = fields.find { case (n, _) => n == name }.get x @@ -374,9 +379,9 @@ object Simplify { case None => GetField(old, name) } - case GetField(SelectFields(old, fields), name) => GetField(old, name) + case GetField(SelectFields(old, _), name) => GetField(old, name) - case outer@InsertFields(InsertFields(base, fields1, fieldOrder1), fields2, fieldOrder2) => + case outer @ InsertFields(InsertFields(base, fields1, fieldOrder1), fields2, fieldOrder2) => val fields2Set = fields2.map(_._1).toSet val newFields = fields1.filter { case (name, _) => !fields2Set.contains(name) } ++ fields2 (fieldOrder1, fieldOrder2) match { @@ -387,7 +392,8 @@ object Simplify { case (_, Some(_)) => InsertFields(base, newFields, fieldOrder2) case _ => - // In this case, it's important to make a field order that reflects the original insertion order + /* In this case, it's important to make a field order that reflects the original insertion + * order */ val resultFieldOrder = outer.typ.fieldNames InsertFields(base, newFields, Some(resultFieldOrder)) } @@ -399,69 +405,81 @@ object Simplify { case Some(fo) => MakeStruct(fo.map(f => f -> fields2Map.getOrElse(f, fields1Map(f)))) case None => - val finalFields = fields1.map { case (name, fieldIR) => name -> fields2Map.getOrElse(name, fieldIR) } ++ + val finalFields = fields1.map { case (name, fieldIR) => + name -> fields2Map.getOrElse(name, fieldIR) + } ++ 
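Simplify's value rules above are local pattern rewrites (`If(True(), cnsq, _)` becomes `cnsq`, a negated condition swaps the branches, and so on) applied bottom-up and re-run until nothing fires. Below is a toy sketch of that rewrite-to-fixpoint style over a stand-in expression type; none of these names are Hail's.

```scala
// Toy illustration of the rewrite-until-no-rule-applies style used by Simplify:
// a tiny expression language with a few local rules. Sketch only; not Hail's IR.
object RewriteSketch {
  sealed trait Expr
  case object True extends Expr
  case object False extends Expr
  final case class If(cond: Expr, cnsq: Expr, altr: Expr) extends Expr
  final case class Not(x: Expr) extends Expr

  private val rule: PartialFunction[Expr, Expr] = {
    case If(True, cnsq, _) => cnsq
    case If(False, _, altr) => altr
    case If(Not(c), cnsq, altr) => If(c, altr, cnsq) // analogous to the Bang rule above
    case Not(Not(x)) => x
  }

  private def mapChildren(e: Expr, f: Expr => Expr): Expr = e match {
    case If(c, t, a) => If(f(c), f(t), f(a))
    case Not(x) => Not(f(x))
    case leaf => leaf
  }

  // Simplify children first, then retry the local rules on the rebuilt node.
  def simplify(e: Expr): Expr = {
    val e1 = mapChildren(e, simplify)
    rule.lift(e1).map(simplify).getOrElse(e1)
  }

  def main(args: Array[String]): Unit =
    println(simplify(If(Not(True), False, If(True, True, False)))) // True
}
```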
fields2.filter { case (name, _) => !fields1Map.contains(name) } MakeStruct(finalFields) } case InsertFields(struct, Seq(), None) => struct - case InsertFields(SelectFields(old, _), Seq(), Some(insertFieldOrder)) => SelectFields(old, insertFieldOrder) + case InsertFields(SelectFields(old, _), Seq(), Some(insertFieldOrder)) => + SelectFields(old, insertFieldOrder) - case Let(Seq(), body) => + case Block(Seq(), body) => body - case Let(xs, Let(ys, body)) => - Let(xs ++ ys, body) + case Block(xs, Block(ys, body)) => + Block(xs ++ ys, body) - // assumes `NormalizeNames` has been run before this. - case Let(Let.Nested(before, after), body) => - def numBindings(b: (String, IR)): Int = - b._2 match { - case let: Let => 1 + let.bindings.length + // assumes `NormalizeNames` has been run before this. + case Block(Block.Nested(i, bindings), body) => + def numBindings(b: Binding): Int = + b.value match { + case let: Block => 1 + let.bindings.length case _ => 1 } val newBindings = - new BoxedArrayBuilder[(String, IR)]( - after.foldLeft(before.length) { (sum, binding) => - sum + numBindings(binding) - } + new BoxedArrayBuilder[Binding]( + bindings.foldLeft(0)((sum, binding) => sum + numBindings(binding)) ) - newBindings ++= before - - after.foreach { - case (name: String, ir: Let) => - newBindings ++= ir.bindings - newBindings += name -> ir.body - case (name, value) => - newBindings += name -> value + newBindings ++= bindings.view.take(i) + + bindings.view.drop(i).foreach { + case Binding(name, ir: Block, scope) => + newBindings ++= (if (scope == Scope.EVAL) ir.bindings + else ir.bindings.map { + case Binding(name, value, Scope.EVAL) => Binding(name, value, scope) + case _ => fatal("Simplify: found nested Agg bindings") + }) + newBindings += Binding(name, ir.body, scope) + case binding => newBindings += binding } - Let(newBindings.underlying(), body) - - case Let(Let.Insert(before, (name, x@InsertFields(old, newFields, _)) +: after), body) if x.typ.size < 500 && { - val r = Ref(name, x.typ) - val nfSet = newFields.map(_._1).toSet - - def allRefsCanBePassedThrough(ir1: IR): Boolean = ir1 match { - case GetField(`r`, _) => true - case InsertFields(`r`, inserted, _) => inserted.forall { case (_, toInsert) => allRefsCanBePassedThrough(toInsert) } - case SelectFields(`r`, fds) => fds.forall(f => !nfSet.contains(f)) - case `r` => false // if the binding is referenced in any other context, don't rewrite - case _: TableAggregate => true - case _: MatrixAggregate => true - case _ => ir1.children - .zipWithIndex - .forall { - case (child: IR, idx) => Binds(ir1, name, idx) || allRefsCanBePassedThrough(child) - case _ => true + Block(newBindings.underlying(), body) + + case Block( + Block.Insert( + before, + Binding(name, x @ InsertFields(old, newFields, _), Scope.EVAL), + after, + ), + body, + ) + if x.typ.size < 500 && { + val r = Ref(name, x.typ) + val nfSet = newFields.map(_._1).toSet + + def allRefsCanBePassedThrough(ir1: IR): Boolean = ir1 match { + case GetField(`r`, _) => true + case InsertFields(`r`, inserted, _) => + inserted.forall { case (_, toInsert) => allRefsCanBePassedThrough(toInsert) } + case SelectFields(`r`, fds) => fds.forall(f => !nfSet.contains(f)) + case `r` => false // if the binding is referenced in any other context, don't rewrite + case _: TableAggregate => true + case _: MatrixAggregate => true + case _ => ir1.children + .zipWithIndex + .forall { + case (child: IR, idx) => Binds(ir1, name, idx) || allRefsCanBePassedThrough(child) + case _ => true + } } - } - 
allRefsCanBePassedThrough(Let(after.toFastSeq, body)) - } => - val r = Ref(name, x.typ) + allRefsCanBePassedThrough(Block(after.toFastSeq, body)) + } => val fieldNames = newFields.map(_._1).toArray val newFieldMap = newFields.toMap val newFieldRefs = newFieldMap.map { case (k, ir) => @@ -473,18 +491,27 @@ object Simplify { def rewrite(ir1: IR): IR = ir1 match { case GetField(Ref(`name`, _), fd) => newFieldRefs.get(fd) match { - case Some(r) => r.deepCopy() - case None => GetField(Ref(name, old.typ), fd) - } - case ins@InsertFields(Ref(`name`, _), fields, _) => + case Some(r) => r.deepCopy() + case None => GetField(Ref(name, old.typ), fd) + } + case ins @ InsertFields(Ref(`name`, _), fields, _) => val newFieldSet = fields.map(_._1).toSet - InsertFields(Ref(name, old.typ), + InsertFields( + Ref(name, old.typ), copiedNewFieldRefs().filter { case (name, _) => !newFieldSet.contains(name) } ++ fields.map { case (name, ir) => (name, rewrite(ir)) }, - Some(ins.typ.fieldNames.toFastSeq)) + Some(ins.typ.fieldNames.toFastSeq), + ) case SelectFields(Ref(`name`, _), fds) => - SelectFields(InsertFields(Ref(name, old.typ), copiedNewFieldRefs(), Some(x.typ.fieldNames.toFastSeq)), fds) + SelectFields( + InsertFields( + Ref(name, old.typ), + copiedNewFieldRefs(), + Some(x.typ.fieldNames.toFastSeq), + ), + fds, + ) case ta: TableAggregate => ta case ma: MatrixAggregate => ma case _ => ir1.mapChildrenWithIndex { @@ -493,9 +520,13 @@ object Simplify { } } - Let( - before.toFastSeq ++ fieldNames.map(f => newFieldRefs(f).name -> newFieldMap(f)) ++ FastSeq(name -> old), - rewrite(Let(after.toFastSeq, body)) + Block( + before.toFastSeq ++ fieldNames.map(f => + Binding(newFieldRefs(f).name, newFieldMap(f)) + ) ++ FastSeq( + Binding(name, old) + ), + rewrite(Block(after.toFastSeq, body)), ) case SelectFields(old, fields) if tcoerce[TStruct](old.typ).fieldNames sameElements fields => @@ -508,27 +539,37 @@ object Simplify { val makeStructFields = fields.toMap MakeStruct(fieldNames.map(f => f -> makeStructFields(f))) - case x@SelectFields(InsertFields(struct, insertFields, _), selectFields) => + case x @ SelectFields(InsertFields(struct, insertFields, _), selectFields) => val selectSet = selectFields.toSet val insertFields2 = insertFields.filter { case (fName, _) => selectSet.contains(fName) } val structSet = struct.typ.asInstanceOf[TStruct].fieldNames.toSet val selectFields2 = selectFields.filter(structSet.contains) - val x2 = InsertFields(SelectFields(struct, selectFields2), insertFields2, Some(selectFields.toFastSeq)) + val x2 = InsertFields( + SelectFields(struct, selectFields2), + insertFields2, + Some(selectFields.toFastSeq), + ) assert(x2.typ == x.typ) x2 - case x@InsertFields(SelectFields(struct, selectFields), insertFields, _) if - insertFields.exists { case (name, f) => f == GetField(struct, name) } => + case x @ InsertFields(SelectFields(struct, selectFields), insertFields, _) if + insertFields.exists { case (name, f) => f == GetField(struct, name) } => val fields = x.typ.fieldNames val insertNames = insertFields.map(_._1).toSet val (oldFields, newFields) = - insertFields.partition { case (name, f) => f == GetField(struct, name) } - val preservedFields = selectFields.filter(f => !insertNames.contains(f)) ++ oldFields.map(_._1) + insertFields.partition { case (name, f) => f == GetField(struct, name) } + val preservedFields = + selectFields.filter(f => !insertNames.contains(f)) ++ oldFields.map(_._1) InsertFields(SelectFields(struct, preservedFields), newFields, Some(fields.toFastSeq)) + case 
MakeStructOfGetField(o, newNames) => + val select = SelectFields(o, newNames.map(_._1)) + CastRename(select, select.typ.asInstanceOf[TStruct].rename(newNames.toMap)) + case GetTupleElement(MakeTuple(xs), idx) => xs.find(_._1 == idx).get._2 - case TableCount(MatrixColsTable(child)) if child.columnCount.isDefined => I64(child.columnCount.get) + case TableCount(MatrixColsTable(child)) if child.columnCount.isDefined => + I64(child.columnCount.get) case TableCount(child) if child.partitionCounts.isDefined => I64(child.partitionCounts.get.sum) case TableCount(CastMatrixToTable(child, _, _)) => TableCount(MatrixRowsTable(child)) @@ -542,15 +583,21 @@ object Simplify { case TableCount(TableLeftJoinRightDistinct(child, _, _)) => TableCount(child) case TableCount(TableIntervalJoin(child, _, _, _)) => TableCount(child) case TableCount(TableRange(n, _)) => I64(n) - case TableCount(TableParallelize(rowsAndGlobal, _)) => Cast(ArrayLen(GetField(rowsAndGlobal, "rows")), TInt64) + case TableCount(TableParallelize(rowsAndGlobal, _)) => + Cast(ArrayLen(GetField(rowsAndGlobal, "rows")), TInt64) case TableCount(TableRename(child, _, _)) => TableCount(child) case TableCount(TableAggregateByKey(child, _)) => TableCount(TableDistinct(child)) case TableCount(TableExplode(child, path)) => - TableAggregate(child, + TableAggregate( + child, ApplyAggOp( FastSeq(), - FastSeq(ArrayLen(CastToArray(path.foldLeft[IR](Ref("row", child.typ.rowType)) { case (comb, s) => GetField(comb, s)})).toL), - AggSignature(Sum(), FastSeq(), FastSeq(TInt64)))) + FastSeq(ArrayLen(CastToArray(path.foldLeft[IR](Ref("row", child.typ.rowType)) { + case (comb, s) => GetField(comb, s) + })).toL), + AggSignature(Sum(), FastSeq(), FastSeq(TInt64)), + ), + ) case MatrixCount(child) if child.partitionCounts.isDefined || child.columnCount.isDefined => val rowCount = child.partitionCounts match { @@ -563,14 +610,15 @@ object Simplify { } MakeTuple.ordered(FastSeq(rowCount, colCount)) case MatrixCount(MatrixMapRows(child, _)) => MatrixCount(child) - case MatrixCount(MatrixMapCols(child,_, _)) => MatrixCount(child) - case MatrixCount(MatrixMapEntries(child,_)) => MatrixCount(child) - case MatrixCount(MatrixFilterEntries(child,_)) => MatrixCount(child) + case MatrixCount(MatrixMapCols(child, _, _)) => MatrixCount(child) + case MatrixCount(MatrixMapEntries(child, _)) => MatrixCount(child) + case MatrixCount(MatrixFilterEntries(child, _)) => MatrixCount(child) case MatrixCount(MatrixAnnotateColsTable(child, _, _)) => MatrixCount(child) case MatrixCount(MatrixAnnotateRowsTable(child, _, _, _)) => MatrixCount(child) case MatrixCount(MatrixRepartition(child, _, _)) => MatrixCount(child) case MatrixCount(MatrixRename(child, _, _, _, _)) => MatrixCount(child) - case TableCount(TableRead(_, false, r: MatrixBGENReader)) if r.params.includedVariants.isEmpty => + case TableCount(TableRead(_, false, r: MatrixBGENReader)) + if r.params.includedVariants.isEmpty => I64(r.nVariants) // TableGetGlobals should simplify very aggressively @@ -587,8 +635,11 @@ object Simplify { ) } - case TableGetGlobals(x@TableMultiWayZipJoin(children, _, globalName)) => - MakeStruct(FastSeq(globalName -> MakeArray(children.map(TableGetGlobals), TArray(children.head.typ.globalType)))) + case TableGetGlobals(TableMultiWayZipJoin(children, _, globalName)) => + MakeStruct(FastSeq(globalName -> MakeArray( + children.map(TableGetGlobals), + TArray(children.head.typ.globalType), + ))) case TableGetGlobals(TableLeftJoinRightDistinct(child, _, _)) => TableGetGlobals(child) case 
TableGetGlobals(TableMapRows(child, _)) => TableGetGlobals(child) case TableGetGlobals(TableMapGlobals(child, newGlobals)) => @@ -612,8 +663,9 @@ object Simplify { } case TableCollect(TableParallelize(x, _)) => x - case x@TableCollect(TableOrderBy(child, sortFields)) if sortFields.forall(_.sortOrder == Ascending) - && !child.typ.key.startsWith(sortFields.map(_.field)) => + case x @ TableCollect(TableOrderBy(child, sortFields)) + if sortFields.forall(_.sortOrder == Ascending) + && !child.typ.key.startsWith(sortFields.map(_.field)) => val uid = genUID() val uid2 = genUID() val left = genUID() @@ -623,28 +675,41 @@ object Simplify { val kvElement = MakeStruct(FastSeq( ("key", SelectFields(Ref(uid2, child.typ.rowType), sortFields.map(_.field))), - ("value", Ref(uid2, child.typ.rowType)))) + ("value", Ref(uid2, child.typ.rowType)), + )) val sorted = ArraySort( StreamMap( ToStream(GetField(Ref(uid, x.typ), "rows")), uid2, - kvElement + kvElement, ), left, right, - ApplyComparisonOp(LT(sortType), + ApplyComparisonOp( + LT(sortType), GetField(Ref(left, kvElement.typ), "key"), - GetField(Ref(right, kvElement.typ), "key"))) - Let(FastSeq(uid -> TableCollect(TableKeyBy(child, FastSeq()))), + GetField(Ref(right, kvElement.typ), "key"), + ), + ) + Let( + FastSeq(uid -> TableCollect(TableKeyBy(child, FastSeq()))), MakeStruct(FastSeq( - ("rows", ToArray(StreamMap(ToStream(sorted), - uid3, - GetField(Ref(uid3, sorted.typ.asInstanceOf[TArray].elementType), "value")))), - ("global", GetField(Ref(uid, x.typ), "global"))))) + ( + "rows", + ToArray(StreamMap( + ToStream(sorted), + uid3, + GetField(Ref(uid3, sorted.typ.asInstanceOf[TArray].elementType), "value"), + )), + ), + ("global", GetField(Ref(uid, x.typ), "global")), + )), + ) case ArrayLen(GetField(TableCollect(child), "rows")) => Cast(TableCount(child), TInt32) case GetField(TableCollect(child), "global") => TableGetGlobals(child) - case TableAggregate(child, query) if child.typ.key.nonEmpty && !ContainsNonCommutativeAgg(query) => + case TableAggregate(child, query) + if child.typ.key.nonEmpty && !ContainsNonCommutativeAgg(query) => TableAggregate(TableKeyBy(child, FastSeq(), false), query) case TableAggregate(TableOrderBy(child, _), query) if !ContainsNonCommutativeAgg(query) => if (child.typ.key.isEmpty) @@ -653,13 +718,22 @@ object Simplify { TableAggregate(TableKeyBy(child, FastSeq(), false), query) case TableAggregate(TableMapRows(child, newRow), query) if !ContainsScan(newRow) => val uid = genUID() - TableAggregate(child, - AggLet(uid, newRow, Subst(query, BindingEnv(agg = Some(Env("row" -> Ref(uid, newRow.typ))))), isScan = false)) + TableAggregate( + child, + AggLet( + uid, + newRow, + Subst(query, BindingEnv(agg = Some(Env("row" -> Ref(uid, newRow.typ))))), + isScan = false, + ), + ) - // NOTE: The below rule should be reintroduced when it is possible to put an ArrayAgg inside a TableAggregate + /* NOTE: The below rule should be reintroduced when it is possible to put an ArrayAgg inside a + * TableAggregate */ // case TableAggregate(TableParallelize(rowsAndGlobal, _), query) => // rowsAndGlobal match { - // // match because we currently don't optimize MakeStruct through Let, and this is a common pattern + /* // match because we currently don't optimize MakeStruct through Let, and this is a common + * pattern */ // case MakeStruct(Seq((_, rows), (_, global))) => // Let("global", global, ArrayAgg(rows, "row", query)) // case other => @@ -684,38 +758,47 @@ object Simplify { 
ApplyComparisonOp(ComparisonOp.negate(op.asInstanceOf[ComparisonOp[Boolean]]), l, r) case StreamAgg(_, _, query) if { - def canBeLifted(x: IR): Boolean = x match { - case _: TableAggregate => true - case _: MatrixAggregate => true - case AggLet(_, _, _, false) => false - case x if IsAggResult(x) => false - case other => other.children.forall { - case child: IR => canBeLifted(child) - case _: BaseIR => true - } - } - canBeLifted(query) - } => query + def canBeLifted(x: IR): Boolean = x match { + case _: TableAggregate => true + case _: MatrixAggregate => true + case Block(bindings, _) if bindings.exists { + case Binding(_, _, Scope.AGG) => true + case _ => false + } => false + case x if IsAggResult(x) => false + case other => other.children.forall { + case child: IR => canBeLifted(child) + case _: BaseIR => true + } + } + canBeLifted(query) + } => query case StreamAggScan(_, _, query) if { - def canBeLifted(x: IR): Boolean = x match { - case _: TableAggregate => true - case _: MatrixAggregate => true - case AggLet(_, _, _, true) => false - case x if IsScanResult(x) => false - case other => other.children.forall { - case child: IR => canBeLifted(child) - case _: BaseIR => true - } + def canBeLifted(x: IR): Boolean = x match { + case _: TableAggregate => true + case _: MatrixAggregate => true + case Block(bindings, _) if bindings.exists { + case Binding(_, _, Scope.SCAN) => true + case _ => false + } => false + case x if IsScanResult(x) => false + case other => other.children.forall { + case child: IR => canBeLifted(child) + case _: BaseIR => true + } + } + canBeLifted(query) + } => query + + case BlockMatrixToValueApply( + ValueToBlockMatrix(child, IndexedSeq(_, ncols), _), + functions.GetElement(Seq(i, j)), + ) => child.typ match { + case TArray(_) => ArrayRef(child, I32((i * ncols + j).toInt)) + case TNDArray(_, _) => NDArrayRef(child, IndexedSeq(i, j), ErrorIDs.NO_ERROR) + case TFloat64 => child } - canBeLifted(query) - } => query - - case BlockMatrixToValueApply(ValueToBlockMatrix(child, IndexedSeq(nrows, ncols), _), functions.GetElement(Seq(i, j))) => child.typ match { - case TArray(_) => ArrayRef(child, I32((i * ncols + j).toInt)) - case TNDArray(_, _) => NDArrayRef(child, IndexedSeq(i, j), ErrorIDs.NO_ERROR) - case TFloat64 => child - } case LiftMeOut(child) if IsConstant(child) => child } @@ -724,7 +807,7 @@ object Simplify { case TableRename(child, m1, m2) if m1.isTrivial && m2.isTrivial => child // TODO: Write more rules like this to bubble 'TableRename' nodes towards the root. 
- case t@TableRename(TableKeyBy(child, keys, isSorted), rowMap, globalMap) => + case t @ TableRename(TableKeyBy(child, keys, isSorted), rowMap, globalMap) => TableKeyBy(TableRename(child, rowMap, globalMap), keys.map(t.rowF), isSorted) case TableFilter(t, True()) => t @@ -733,8 +816,10 @@ object Simplify { TableRead(typ, dropRows = true, tr) case TableFilter(TableFilter(t, p1), p2) => - TableFilter(t, - ApplySpecial("land", Array.empty[Type], Array(p1, p2), TBoolean, ErrorIDs.NO_ERROR)) + TableFilter( + t, + ApplySpecial("land", Array.empty[Type], Array(p1, p2), TBoolean, ErrorIDs.NO_ERROR), + ) case TableFilter(TableKeyBy(child, key, isSorted), p) => TableKeyBy(TableFilter(child, p), key, isSorted) @@ -750,26 +835,33 @@ object Simplify { case TableFilter(TableParallelize(rowsAndGlobal, nPartitions), pred) => val newRowsAndGlobal = rowsAndGlobal match { case MakeStruct(Seq(("rows", rows), ("global", globalVal))) => - Let(FastSeq("global" -> globalVal), + Let( + FastSeq("global" -> globalVal), MakeStruct(FastSeq( ("rows", ToArray(StreamFilter(ToStream(rows), "row", pred))), - ("global", Ref("global", globalVal.typ))))) + ("global", Ref("global", globalVal.typ)), + )), + ) case _ => val uid = genUID() Let( FastSeq( uid -> rowsAndGlobal, - "global" -> GetField(Ref(uid, rowsAndGlobal.typ), "global") + "global" -> GetField(Ref(uid, rowsAndGlobal.typ), "global"), ), MakeStruct(FastSeq( - "rows" -> ToArray(StreamFilter(ToStream(GetField(Ref(uid, rowsAndGlobal.typ), "rows")), "row", pred)), - "global" -> Ref("global", rowsAndGlobal.typ.asInstanceOf[TStruct].fieldType("global")) - )) + "rows" -> ToArray(StreamFilter( + ToStream(GetField(Ref(uid, rowsAndGlobal.typ), "rows")), + "row", + pred, + )), + "global" -> Ref("global", rowsAndGlobal.typ.asInstanceOf[TStruct].fieldType("global")), + )), ) } TableParallelize(newRowsAndGlobal, nPartitions) - case TableKeyBy(TableOrderBy(child, sortFields), keys, false) => + case TableKeyBy(TableOrderBy(child, _), keys, false) => TableKeyBy(child, keys, false) case TableKeyBy(TableKeyBy(child, _, _), keys, false) => @@ -783,10 +875,10 @@ object Simplify { case TableMapRows(child, Ref("row", _)) => child case TableMapRows(child, MakeStruct(fields)) - if fields.length == child.typ.rowType.size - && fields.zip(child.typ.rowType.fields).forall { case ((_, ir), field) => - ir == GetField(Ref("row", field.typ), field.name) - } => + if fields.length == child.typ.rowType.size + && fields.zip(child.typ.rowType.fields).forall { case ((_, ir), field) => + ir == GetField(Ref("row", field.typ), field.name) + } => val renamedPairs = for { (oldName, (newName, _)) <- child.typ.rowType.fieldNames zip fields if oldName != newName @@ -822,9 +914,11 @@ object Simplify { val mrt = MatrixRowsTable(child) TableFilter( mrt, - Subst(pred, BindingEnv(Env("va" -> Ref("row", mrt.typ.rowType))))) + Subst(pred, BindingEnv(Env("va" -> Ref("row", mrt.typ.rowType)))), + ) - case MatrixRowsTable(MatrixMapGlobals(child, newGlobals)) => TableMapGlobals(MatrixRowsTable(child), newGlobals) + case MatrixRowsTable(MatrixMapGlobals(child, newGlobals)) => + TableMapGlobals(MatrixRowsTable(child), newGlobals) case MatrixRowsTable(MatrixMapCols(child, _, _)) => MatrixRowsTable(child) case MatrixRowsTable(MatrixMapEntries(child, _)) => MatrixRowsTable(child) case MatrixRowsTable(MatrixFilterEntries(child, _)) => MatrixRowsTable(child) @@ -832,18 +926,21 @@ object Simplify { case MatrixRowsTable(MatrixAggregateColsByKey(child, _, _)) => MatrixRowsTable(child) case MatrixRowsTable(MatrixChooseCols(child, _)) 
=> MatrixRowsTable(child) case MatrixRowsTable(MatrixCollectColsByKey(child)) => MatrixRowsTable(child) - case MatrixRowsTable(MatrixKeyRowsBy(child, keys, isSorted)) => TableKeyBy(MatrixRowsTable(child), keys, isSorted) + case MatrixRowsTable(MatrixKeyRowsBy(child, keys, isSorted)) => + TableKeyBy(MatrixRowsTable(child), keys, isSorted) - case MatrixColsTable(x@MatrixMapCols(child, newRow, newKey)) - if newKey.isEmpty - && !ContainsAgg(newRow) - && !ContainsScan(newRow) => + case MatrixColsTable(MatrixMapCols(child, newRow, newKey)) + if newKey.isEmpty + && !ContainsAgg(newRow) + && !ContainsScan(newRow) => val mct = MatrixColsTable(child) TableMapRows( mct, - Subst(newRow, BindingEnv(Env("sa" -> Ref("row", mct.typ.rowType))))) + Subst(newRow, BindingEnv(Env("sa" -> Ref("row", mct.typ.rowType)))), + ) - case MatrixColsTable(MatrixMapGlobals(child, newGlobals)) => TableMapGlobals(MatrixColsTable(child), newGlobals) + case MatrixColsTable(MatrixMapGlobals(child, newGlobals)) => + TableMapGlobals(MatrixColsTable(child), newGlobals) case MatrixColsTable(MatrixMapRows(child, _)) => MatrixColsTable(child) case MatrixColsTable(MatrixMapEntries(child, _)) => MatrixColsTable(child) case MatrixColsTable(MatrixFilterEntries(child, _)) => MatrixColsTable(child) @@ -854,12 +951,11 @@ object Simplify { case TableRepartition(TableRange(nRows, _), nParts, _) => TableRange(nRows, nParts) case TableMapGlobals(TableMapGlobals(child, ng1), ng2) => - TableMapGlobals(child, bindIR(ng1) { uid => - Subst(ng2, BindingEnv(Env("global" -> uid))) - }) + TableMapGlobals(child, bindIR(ng1)(uid => Subst(ng2, BindingEnv(Env("global" -> uid))))) case TableHead(MatrixColsTable(child), n) if child.typ.colKey.isEmpty => - if (n > Int.MaxValue) MatrixColsTable(child) else MatrixColsTable(MatrixColsHead(child, n.toInt)) + if (n > Int.MaxValue) MatrixColsTable(child) + else MatrixColsTable(MatrixColsHead(child, n.toInt)) case TableHead(TableMapRows(child, newRow), n) => TableMapRows(TableHead(child, n), newRow) @@ -867,7 +963,7 @@ object Simplify { case TableHead(TableRepartition(child, nPar, shuffle), n) => TableRepartition(TableHead(child, n), nPar, shuffle) - case TableHead(tr@TableRange(nRows, nPar), n) => + case TableHead(tr @ TableRange(nRows, nPar), n) => if (n < nRows) TableRange(n.toInt, (nPar.toFloat * n / nRows).toInt.max(1)) else @@ -877,43 +973,65 @@ object Simplify { TableMapGlobals(TableHead(child, n), newGlobals) case TableHead(TableOrderBy(child, sortFields), n) - if !TableOrderBy.isAlreadyOrdered(sortFields, child.typ.key) // FIXME: https://github.com/hail-is/hail/issues/6234 - && sortFields.forall(_.sortOrder == Ascending) - && n < 256 => + if !TableOrderBy.isAlreadyOrdered( + sortFields, + child.typ.key, + ) // FIXME: https://github.com/hail-is/hail/issues/6234 + && sortFields.forall(_.sortOrder == Ascending) + && n < 256 => // n < 256 is arbitrary for memory concerns val row = Ref("row", child.typ.rowType) val keyStruct = MakeStruct(sortFields.map(f => f.field -> GetField(row, f.field))) - val aggSig = AggSignature(TakeBy(), FastSeq(TInt32), FastSeq(row.typ, keyStruct.typ)) + val aggSig = AggSignature(TakeBy(), FastSeq(TInt32), FastSeq(row.typ, keyStruct.typ)) val te = TableExplode( - TableKeyByAndAggregate(child, + TableKeyByAndAggregate( + child, MakeStruct(FastSeq( "row" -> ApplyAggOp( FastSeq(I32(n.toInt)), Array(row, keyStruct), - aggSig))), + aggSig, + ) + )), MakeStruct(FastSeq()), // aggregate to one row - Some(1), 10), - FastSeq("row")) + Some(1), + 10, + ), + FastSeq("row"), + ) TableMapRows(te, 
GetField(Ref("row", te.typ.rowType), "row")) case TableDistinct(TableDistinct(child)) => TableDistinct(child) case TableDistinct(TableAggregateByKey(child, expr)) => TableAggregateByKey(child, expr) case TableDistinct(TableMapRows(child, newRow)) => TableMapRows(TableDistinct(child), newRow) - case TableDistinct(TableLeftJoinRightDistinct(child, right, root)) => TableLeftJoinRightDistinct(TableDistinct(child), right, root) - case TableDistinct(TableRepartition(child, n, strategy)) => TableRepartition(TableDistinct(child), n, strategy) + case TableDistinct(TableLeftJoinRightDistinct(child, right, root)) => + TableLeftJoinRightDistinct(TableDistinct(child), right, root) + case TableDistinct(TableRepartition(child, n, strategy)) => + TableRepartition(TableDistinct(child), n, strategy) - case TableKeyByAndAggregate(child, MakeStruct(Seq()), k@MakeStruct(keyFields), _, _) => - TableDistinct(TableKeyBy(TableMapRows(TableKeyBy(child, FastSeq()), k), k.typ.asInstanceOf[TStruct].fieldNames)) + case TableKeyByAndAggregate(child, MakeStruct(Seq()), k @ MakeStruct(_), _, _) => + TableDistinct(TableKeyBy( + TableMapRows(TableKeyBy(child, FastSeq()), k), + k.typ.asInstanceOf[TStruct].fieldNames, + )) case TableKeyByAndAggregate(child, expr, newKey, _, _) - if (newKey == MakeStruct(child.typ.key.map(k => k -> GetField(Ref("row", child.typ.rowType), k))) || - newKey == SelectFields(Ref("row", child.typ.rowType), child.typ.key)) - && child.typ.key.nonEmpty => + if (newKey == MakeStruct(child.typ.key.map(k => + k -> GetField(Ref("row", child.typ.rowType), k) + )) || + newKey == SelectFields(Ref("row", child.typ.rowType), child.typ.key)) + && child.typ.key.nonEmpty => TableAggregateByKey(child, expr) - case TableAggregateByKey(x@TableKeyBy(child, keys, false), expr) if !x.definitelyDoesNotShuffle => - TableKeyByAndAggregate(child, expr, MakeStruct(keys.map(k => k -> GetField(Ref("row", child.typ.rowType), k))), bufferSize = ctx.getFlag("grouped_aggregate_buffer_size").toInt) + case TableAggregateByKey(x @ TableKeyBy(child, keys, false), expr) + if !x.definitelyDoesNotShuffle => + TableKeyByAndAggregate( + child, + expr, + MakeStruct(keys.map(k => k -> GetField(Ref("row", child.typ.rowType), k))), + bufferSize = ctx.getFlag("grouped_aggregate_buffer_size").toInt, + ) case TableParallelize(TableCollect(child), _) => child @@ -926,7 +1044,8 @@ object Simplify { // push down filter intervals nodes case TableFilterIntervals(TableFilter(child, pred), intervals, keep) => TableFilter(TableFilterIntervals(child, intervals, keep), pred) - case TableFilterIntervals(TableMapRows(child, newRow), intervals, keep) if !ContainsScan(newRow) => + case TableFilterIntervals(TableMapRows(child, newRow), intervals, keep) + if !ContainsScan(newRow) => TableMapRows(TableFilterIntervals(child, intervals, keep), newRow) case TableFilterIntervals(TableMapGlobals(child, newRow), intervals, keep) => TableMapGlobals(TableFilterIntervals(child, intervals, keep), newRow) @@ -935,70 +1054,110 @@ object Simplify { case TableFilterIntervals(TableRepartition(child, n, strategy), intervals, keep) => TableRepartition(TableFilterIntervals(child, intervals, keep), n, strategy) case TableFilterIntervals(TableLeftJoinRightDistinct(child, right, root), intervals, true) => - TableLeftJoinRightDistinct(TableFilterIntervals(child, intervals, true), TableFilterIntervals(right, intervals, true), root) + TableLeftJoinRightDistinct( + TableFilterIntervals(child, intervals, true), + TableFilterIntervals(right, intervals, true), + root, + ) case 
TableFilterIntervals(TableIntervalJoin(child, right, root, product), intervals, keep) => TableIntervalJoin(TableFilterIntervals(child, intervals, keep), right, root, product) case TableFilterIntervals(TableJoin(left, right, jt, jk), intervals, true) => - TableJoin(TableFilterIntervals(left, intervals, true), TableFilterIntervals(right, intervals, true), jt, jk) + TableJoin( + TableFilterIntervals(left, intervals, true), + TableFilterIntervals(right, intervals, true), + jt, + jk, + ) case TableFilterIntervals(TableExplode(child, path), intervals, keep) => TableExplode(TableFilterIntervals(child, intervals, keep), path) case TableFilterIntervals(TableAggregateByKey(child, expr), intervals, keep) => TableAggregateByKey(TableFilterIntervals(child, intervals, keep), expr) - case TableFilterIntervals(TableFilterIntervals(child, _i1, keep1), _i2, keep2) if keep1 == keep2 => + case TableFilterIntervals(TableFilterIntervals(child, _i1, keep1), _i2, keep2) + if keep1 == keep2 => val ord = PartitionBoundOrdering(ctx, child.typ.keyType).intervalEndpointOrdering val i1 = Interval.union(_i1.toArray[Interval], ord) val i2 = Interval.union(_i2.toArray[Interval], ord) val intervals = if (keep1) - // keep means intersect intervals + // keep means intersect intervals Interval.intersection(i1, i2, ord) else - // remove means union intervals + // remove means union intervals Interval.union(i1 ++ i2, ord) TableFilterIntervals(child, intervals.toFastSeq, keep1) - // FIXME: Can try to serialize intervals shorter than the key - // case TableFilterIntervals(k@TableKeyBy(child, keys, isSorted), intervals, keep) if !child.typ.key.startsWith(keys) => - // val ord = k.typ.keyType.ordering.intervalEndpointOrdering - // val maybeFlip: IR => IR = if (keep) identity else !_ - // val pred = maybeFlip(invoke("sortedNonOverlappingIntervalsContain", - // TBoolean, - // Literal(TArray(TInterval(k.typ.keyType)), Interval.union(intervals.toArray, ord).toFastIndexedSeq), - // MakeStruct(k.typ.keyType.fieldNames.map { keyField => - // (keyField, GetField(Ref("row", child.typ.rowType), keyField)) - // }))) - // TableKeyBy(TableFilter(child, pred), keys, isSorted) + // FIXME: Can try to serialize intervals shorter than the key + /* case TableFilterIntervals(k@TableKeyBy(child, keys, isSorted), intervals, keep) if + * !child.typ.key.startsWith(keys) => */ + // val ord = k.typ.keyType.ordering.intervalEndpointOrdering + // val maybeFlip: IR => IR = if (keep) identity else !_ + // val pred = maybeFlip(invoke("sortedNonOverlappingIntervalsContain", + // TBoolean, + /* Literal(TArray(TInterval(k.typ.keyType)), Interval.union(intervals.toArray, + * ord).toFastIndexedSeq), */ + // MakeStruct(k.typ.keyType.fieldNames.map { keyField => + // (keyField, GetField(Ref("row", child.typ.rowType), keyField)) + // }))) + // TableKeyBy(TableFilter(child, pred), keys, isSorted) case TableFilterIntervals(TableRead(t, false, tr: TableNativeReader), intervals, true) - if tr.spec.indexed - && tr.params.options.forall(_.filterIntervals) - && SemanticVersion(tr.spec.file_version) >= SemanticVersion(1, 3, 0) => + if tr.spec.indexed + && tr.params.options.forall(_.filterIntervals) + && SemanticVersion(tr.spec.file_version) >= SemanticVersion(1, 3, 0) => val newOpts = tr.params.options match { case None => val pt = t.keyType - NativeReaderOptions(Interval.union(intervals, PartitionBoundOrdering(ctx, pt).intervalEndpointOrdering), pt, true) + NativeReaderOptions( + Interval.union(intervals, PartitionBoundOrdering(ctx, pt).intervalEndpointOrdering), + pt, + true, 
+ ) case Some(NativeReaderOptions(preIntervals, intervalPointType, _)) => val iord = PartitionBoundOrdering(ctx, intervalPointType).intervalEndpointOrdering NativeReaderOptions( - Interval.intersection(Interval.union(preIntervals, iord), Interval.union(intervals, iord), iord), - intervalPointType, true) + Interval.intersection( + Interval.union(preIntervals, iord), + Interval.union(intervals, iord), + iord, + ), + intervalPointType, + true, + ) } - TableRead(t, false, new TableNativeReader(TableNativeReaderParameters(tr.params.path, Some(newOpts)), tr.spec)) + TableRead( + t, + false, + new TableNativeReader(TableNativeReaderParameters(tr.params.path, Some(newOpts)), tr.spec), + ) case TableFilterIntervals(TableRead(t, false, tr: TableNativeZippedReader), intervals, true) - if tr.specLeft.indexed - && tr.options.forall(_.filterIntervals) - && SemanticVersion(tr.specLeft.file_version) >= SemanticVersion(1, 3, 0) => + if tr.specLeft.indexed + && tr.options.forall(_.filterIntervals) + && SemanticVersion(tr.specLeft.file_version) >= SemanticVersion(1, 3, 0) => val newOpts = tr.options match { case None => val pt = t.keyType - NativeReaderOptions(Interval.union(intervals, PartitionBoundOrdering(ctx, pt).intervalEndpointOrdering), pt, true) + NativeReaderOptions( + Interval.union(intervals, PartitionBoundOrdering(ctx, pt).intervalEndpointOrdering), + pt, + true, + ) case Some(NativeReaderOptions(preIntervals, intervalPointType, _)) => val iord = PartitionBoundOrdering(ctx, intervalPointType).intervalEndpointOrdering NativeReaderOptions( - Interval.intersection(Interval.union(preIntervals, iord), Interval.union(intervals, iord), iord), - intervalPointType, true) + Interval.intersection( + Interval.union(preIntervals, iord), + Interval.union(intervals, iord), + iord, + ), + intervalPointType, + true, + ) } - TableRead(t, false, TableNativeZippedReader(tr.pathLeft, tr.pathRight, Some(newOpts), tr.specLeft, tr.specRight)) + TableRead( + t, + false, + TableNativeZippedReader(tr.pathLeft, tr.pathRight, Some(newOpts), tr.specLeft, tr.specRight), + ) } private[this] def matrixRules(): PartialFunction[MatrixIR, MatrixIR] = { @@ -1012,14 +1171,15 @@ object Simplify { case MatrixMapCols(child, Ref("sa", _), None) => child - case x@MatrixMapEntries(child, Ref("g", _)) => + case x @ MatrixMapEntries(child, Ref("g", _)) => assert(child.typ == x.typ) child case MatrixMapEntries(MatrixMapEntries(child, newEntries1), newEntries2) => - MatrixMapEntries(child, bindIR(newEntries1) { uid => - Subst(newEntries2, BindingEnv(Env("g" -> uid))) - }) + MatrixMapEntries( + child, + bindIR(newEntries1)(uid => Subst(newEntries2, BindingEnv(Env("g" -> uid)))), + ) case MatrixMapGlobals(child, Ref("global", _)) => child @@ -1042,92 +1202,201 @@ object Simplify { case MatrixFilterCols(m, True()) => m - case MatrixFilterRows(MatrixFilterRows(child, pred1), pred2) => MatrixFilterRows(child, ApplySpecial("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR)) + case MatrixFilterRows(MatrixFilterRows(child, pred1), pred2) => MatrixFilterRows( + child, + ApplySpecial("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR), + ) - case MatrixFilterCols(MatrixFilterCols(child, pred1), pred2) => MatrixFilterCols(child, ApplySpecial("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR)) + case MatrixFilterCols(MatrixFilterCols(child, pred1), pred2) => MatrixFilterCols( + child, + ApplySpecial("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR), + ) - case 
MatrixFilterEntries(MatrixFilterEntries(child, pred1), pred2) => MatrixFilterEntries(child, ApplySpecial("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR)) + case MatrixFilterEntries(MatrixFilterEntries(child, pred1), pred2) => MatrixFilterEntries( + child, + ApplySpecial("land", FastSeq(), FastSeq(pred1, pred2), TBoolean, ErrorIDs.NO_ERROR), + ) case MatrixMapGlobals(MatrixMapGlobals(child, ng1), ng2) => - MatrixMapGlobals(child, bindIR(ng1) { uid => Subst(ng2, BindingEnv(Env("global" -> uid))) }) - - // Note: the following MMR and MMC fusing rules are much weaker than they could be. If they contain aggregations - // but those aggregations that mention "row" / "sa" but do not depend on the updated value, we should locally - // prune and fuse anyway. - case MatrixMapRows(MatrixMapRows(child, newRow1), newRow2) if !Mentions.inAggOrScan(newRow2, "va") - && !Exists.inIR(newRow2, { - case a: ApplyAggOp => a.initOpArgs.exists(Mentions(_, "va")) // Lowering produces invalid IR - case _ => false - }) => - MatrixMapRows(child, bindIR(newRow1) { uid => - Subst(newRow2, BindingEnv[IR]( - Env("va" -> uid), - agg = Some(Env.empty[IR]), - scan = Some(Env.empty[IR]) - )) - }) + MatrixMapGlobals(child, bindIR(ng1)(uid => Subst(ng2, BindingEnv(Env("global" -> uid))))) + + /* Note: the following MMR and MMC fusing rules are much weaker than they could be. If they + * contain aggregations but those aggregations that mention "row" / "sa" but do not depend on + * the updated value, we should locally prune and fuse anyway. */ + case MatrixMapRows(MatrixMapRows(child, newRow1), newRow2) + if !Mentions.inAggOrScan(newRow2, "va") + && !Exists.inIR( + newRow2, + { + case a: ApplyAggOp => + a.initOpArgs.exists(Mentions(_, "va")) // Lowering produces invalid IR + case _ => false + }, + ) => + MatrixMapRows( + child, + bindIR(newRow1) { uid => + Subst( + newRow2, + BindingEnv[IR]( + Env("va" -> uid), + agg = Some(Env.empty[IR]), + scan = Some(Env.empty[IR]), + ), + ) + }, + ) - case MatrixMapCols(MatrixMapCols(child, newCol1, nk1), newCol2, nk2) if !Mentions.inAggOrScan(newCol2, "sa") => - MatrixMapCols(child, + case MatrixMapCols(MatrixMapCols(child, newCol1, nk1), newCol2, nk2) + if !Mentions.inAggOrScan(newCol2, "sa") => + MatrixMapCols( + child, bindIR(newCol1) { uid => - Subst(newCol2, BindingEnv[IR]( - Env("sa" -> uid), - agg = Some(Env.empty[IR]), - scan = Some(Env.empty[IR])) + Subst( + newCol2, + BindingEnv[IR]( + Env("sa" -> uid), + agg = Some(Env.empty[IR]), + scan = Some(Env.empty[IR]), + ), ) }, - nk2.orElse(nk1) + nk2.orElse(nk1), ) // bubble up MatrixColsHead node - case MatrixColsHead(MatrixMapCols(child, newCol, newKey), n) => MatrixMapCols(MatrixColsHead(child, n), newCol, newKey) - case MatrixColsHead(MatrixMapEntries(child, newEntries), n) => MatrixMapEntries(MatrixColsHead(child, n), newEntries) - case MatrixColsHead(MatrixFilterEntries(child, newEntries), n) => MatrixFilterEntries(MatrixColsHead(child, n), newEntries) - case MatrixColsHead(MatrixKeyRowsBy(child, keys, isSorted), n) => MatrixKeyRowsBy(MatrixColsHead(child, n), keys, isSorted) - case MatrixColsHead(MatrixAggregateRowsByKey(child, rowExpr, entryExpr), n) => MatrixAggregateRowsByKey(MatrixColsHead(child, n), rowExpr, entryExpr) - case MatrixColsHead(MatrixChooseCols(child, oldIndices), n) => MatrixChooseCols(child, oldIndices.take(n)) + case MatrixColsHead(MatrixMapCols(child, newCol, newKey), n) => + MatrixMapCols(MatrixColsHead(child, n), newCol, newKey) + case MatrixColsHead(MatrixMapEntries(child, 
newEntries), n) => + MatrixMapEntries(MatrixColsHead(child, n), newEntries) + case MatrixColsHead(MatrixFilterEntries(child, newEntries), n) => + MatrixFilterEntries(MatrixColsHead(child, n), newEntries) + case MatrixColsHead(MatrixKeyRowsBy(child, keys, isSorted), n) => + MatrixKeyRowsBy(MatrixColsHead(child, n), keys, isSorted) + case MatrixColsHead(MatrixAggregateRowsByKey(child, rowExpr, entryExpr), n) => + MatrixAggregateRowsByKey(MatrixColsHead(child, n), rowExpr, entryExpr) + case MatrixColsHead(MatrixChooseCols(child, oldIndices), n) => + MatrixChooseCols(child, oldIndices.take(n)) case MatrixColsHead(MatrixColsHead(child, n1), n2) => MatrixColsHead(child, math.min(n1, n2)) - case MatrixColsHead(MatrixFilterRows(child, pred), n) => MatrixFilterRows(MatrixColsHead(child, n), pred) + case MatrixColsHead(MatrixFilterRows(child, pred), n) => + MatrixFilterRows(MatrixColsHead(child, n), pred) case MatrixColsHead(MatrixRead(t, dr, dc, r: MatrixRangeReader), n) => - MatrixRead(t, dr, dc, MatrixRangeReader(r.params.nRows, math.min(r.params.nCols, n), r.params.nPartitions)) + MatrixRead( + t, + dr, + dc, + MatrixRangeReader(r.params.nRows, math.min(r.params.nCols, n), r.params.nPartitions), + ) case MatrixColsHead(MatrixMapRows(child, newRow), n) if !Mentions.inAggOrScan(newRow, "sa") => MatrixMapRows(MatrixColsHead(child, n), newRow) - case MatrixColsHead(MatrixMapGlobals(child, newGlobals), n) => MatrixMapGlobals(MatrixColsHead(child, n), newGlobals) - case MatrixColsHead(MatrixAnnotateColsTable(child, table, root), n) => MatrixAnnotateColsTable(MatrixColsHead(child, n), table, root) - case MatrixColsHead(MatrixAnnotateRowsTable(child, table, root, product), n) => MatrixAnnotateRowsTable(MatrixColsHead(child, n), table, root, product) - case MatrixColsHead(MatrixRepartition(child, nPar, strategy), n) => MatrixRepartition(MatrixColsHead(child, n), nPar, strategy) - case MatrixColsHead(MatrixExplodeRows(child, path), n) => MatrixExplodeRows(MatrixColsHead(child, n), path) + case MatrixColsHead(MatrixMapGlobals(child, newGlobals), n) => + MatrixMapGlobals(MatrixColsHead(child, n), newGlobals) + case MatrixColsHead(MatrixAnnotateColsTable(child, table, root), n) => + MatrixAnnotateColsTable(MatrixColsHead(child, n), table, root) + case MatrixColsHead(MatrixAnnotateRowsTable(child, table, root, product), n) => + MatrixAnnotateRowsTable(MatrixColsHead(child, n), table, root, product) + case MatrixColsHead(MatrixRepartition(child, nPar, strategy), n) => + MatrixRepartition(MatrixColsHead(child, n), nPar, strategy) + case MatrixColsHead(MatrixExplodeRows(child, path), n) => + MatrixExplodeRows(MatrixColsHead(child, n), path) case MatrixColsHead(MatrixUnionRows(children), n) => - // could prevent a dimension mismatch error, but we view errors as undefined behavior, so this seems OK. + /* could prevent a dimension mismatch error, but we view errors as undefined behavior, so this + * seems OK. 
*/ MatrixUnionRows(children.map(MatrixColsHead(_, n))) - case MatrixColsHead(MatrixDistinctByRow(child), n) => MatrixDistinctByRow(MatrixColsHead(child, n)) - case MatrixColsHead(MatrixRename(child, glob, col, row, entry), n) => MatrixRename(MatrixColsHead(child, n), glob, col, row, entry) + case MatrixColsHead(MatrixDistinctByRow(child), n) => + MatrixDistinctByRow(MatrixColsHead(child, n)) + case MatrixColsHead(MatrixRename(child, glob, col, row, entry), n) => + MatrixRename(MatrixColsHead(child, n), glob, col, row, entry) } private[this] def blockMatrixRules: PartialFunction[BlockMatrixIR, BlockMatrixIR] = { case BlockMatrixBroadcast(child, IndexedSeq(0, 1), _, _) => child - case BlockMatrixSlice(BlockMatrixMap(child, n, f, reqDense), slices) => BlockMatrixMap(BlockMatrixSlice(child, slices), n, f, reqDense) + case BlockMatrixSlice(BlockMatrixMap(child, n, f, reqDense), slices) => + BlockMatrixMap(BlockMatrixSlice(child, slices), n, f, reqDense) case BlockMatrixSlice(BlockMatrixMap2(l, r, ln, rn, f, sparsityStrategy), slices) => - BlockMatrixMap2(BlockMatrixSlice(l, slices), BlockMatrixSlice(r, slices), ln, rn, f, sparsityStrategy) - case BlockMatrixMap2(BlockMatrixBroadcast(scalarBM, IndexedSeq(), _, _), right, leftName, rightName, f, sparsityStrategy) => + BlockMatrixMap2( + BlockMatrixSlice(l, slices), + BlockMatrixSlice(r, slices), + ln, + rn, + f, + sparsityStrategy, + ) + case BlockMatrixMap2( + BlockMatrixBroadcast(scalarBM, IndexedSeq(), _, _), + right, + leftName, + rightName, + f, + sparsityStrategy, + ) => val getElement = BlockMatrixToValueApply(scalarBM, functions.GetElement(IndexedSeq(0, 0))) - val needsDense = sparsityStrategy == NeedsDense || sparsityStrategy.exists(leftBlock = true, rightBlock = false) + val needsDense = sparsityStrategy == NeedsDense || sparsityStrategy.exists( + leftBlock = true, + rightBlock = false, + ) val maybeDense = if (needsDense) BlockMatrixDensify(right) else right - BlockMatrixMap(maybeDense, rightName, Subst(f, BindingEnv.eval(leftName -> getElement)), needsDense) - case BlockMatrixMap2(left, BlockMatrixBroadcast(scalarBM, IndexedSeq(), _, _), leftName, rightName, f, sparsityStrategy) => + BlockMatrixMap( + maybeDense, + rightName, + Subst(f, BindingEnv.eval(leftName -> getElement)), + needsDense, + ) + case BlockMatrixMap2( + left, + BlockMatrixBroadcast(scalarBM, IndexedSeq(), _, _), + leftName, + rightName, + f, + sparsityStrategy, + ) => val getElement = BlockMatrixToValueApply(scalarBM, functions.GetElement(IndexedSeq(0, 0))) - val needsDense = sparsityStrategy == NeedsDense || sparsityStrategy.exists(leftBlock = false, rightBlock = true) + val needsDense = sparsityStrategy == NeedsDense || sparsityStrategy.exists( + leftBlock = false, + rightBlock = true, + ) val maybeDense = if (needsDense) BlockMatrixDensify(left) else left - BlockMatrixMap(maybeDense, leftName, Subst(f, BindingEnv.eval(rightName -> getElement)), needsDense) + BlockMatrixMap( + maybeDense, + leftName, + Subst(f, BindingEnv.eval(rightName -> getElement)), + needsDense, + ) case BlockMatrixMap(matrix, name, Ref(x, _), _) if name == x => matrix - case BlockMatrixMap(matrix, name, ir, _) if IsConstant(ir) || (ir.isInstanceOf[Ref] && ir.asInstanceOf[Ref].name != name) => + case BlockMatrixMap(matrix, name, ir, _) + if IsConstant(ir) || (ir.isInstanceOf[Ref] && ir.asInstanceOf[Ref].name != name) => val typ = matrix.typ BlockMatrixBroadcast( ValueToBlockMatrix(ir, FastSeq(1, 1), typ.blockSize), FastSeq(), typ.shape, - typ.blockSize + typ.blockSize, ) } + + // Match on 
expressions of the form + // MakeStruct(IndexedSeq(a -> GetField(o, x) [, b -> GetField(o, y), ...])) + // where + // - all fields are extracted from the same object, `o` + // - all references to the fields in o are unique + private object MakeStructOfGetField { + def unapply(ir: IR): Option[(IR, IndexedSeq[(String, String)])] = + ir match { + case MakeStruct(fields) if fields.nonEmpty => + val names = mutable.HashSet.empty[String] + val rewrites = new BoxedArrayBuilder[(String, String)](fields.length) + + fields.view.map { + case (a, GetField(o, b)) if names.add(b) => + rewrites += (b -> a) + Some(o) + case _ => None + } + .reduce((a, b) => if (a == b) a else None) + .map(_ -> rewrites.underlying().toFastSeq) + case _ => + None + } + } } diff --git a/hail/src/main/scala/is/hail/expr/ir/SpecializedArrayBuilders.scala b/hail/src/main/scala/is/hail/expr/ir/SpecializedArrayBuilders.scala index 39a28150f5d..55b3f7f7d8a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/SpecializedArrayBuilders.scala +++ b/hail/src/main/scala/is/hail/expr/ir/SpecializedArrayBuilders.scala @@ -3,23 +3,40 @@ package is.hail.expr.ir import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.types.physical.stypes.SingleCodeType -import is.hail.types.physical.{PType, typeToTypeInfo} -import is.hail.types.virtual.Type -import is.hail.utils.BoxedArrayBuilder import scala.reflect.ClassTag -class StagedArrayBuilder(cb: EmitCodeBuilder, val elt: SingleCodeType, val eltRequired: Boolean, len: Int) { +class StagedArrayBuilder( + cb: EmitCodeBuilder, + val elt: SingleCodeType, + val eltRequired: Boolean, + len: Int, +) { def mb = cb.emb val ti: TypeInfo[_] = elt.ti val ref: Value[Any] = coerce[Any](ti match { - case BooleanInfo => mb.genLazyFieldThisRef[BooleanMissingArrayBuilder](Code.newInstance[BooleanMissingArrayBuilder, Int](len), "zab") - case IntInfo => mb.genLazyFieldThisRef[IntMissingArrayBuilder](Code.newInstance[IntMissingArrayBuilder, Int](len), "iab") - case LongInfo => mb.genLazyFieldThisRef[LongMissingArrayBuilder](Code.newInstance[LongMissingArrayBuilder, Int](len), "jab") - case FloatInfo => mb.genLazyFieldThisRef[FloatMissingArrayBuilder](Code.newInstance[FloatMissingArrayBuilder, Int](len), "fab") - case DoubleInfo => mb.genLazyFieldThisRef[DoubleMissingArrayBuilder](Code.newInstance[DoubleMissingArrayBuilder, Int](len), "dab") + case BooleanInfo => mb.genLazyFieldThisRef[BooleanMissingArrayBuilder]( + Code.newInstance[BooleanMissingArrayBuilder, Int](len), + "zab", + ) + case IntInfo => mb.genLazyFieldThisRef[IntMissingArrayBuilder]( + Code.newInstance[IntMissingArrayBuilder, Int](len), + "iab", + ) + case LongInfo => mb.genLazyFieldThisRef[LongMissingArrayBuilder]( + Code.newInstance[LongMissingArrayBuilder, Int](len), + "jab", + ) + case FloatInfo => mb.genLazyFieldThisRef[FloatMissingArrayBuilder]( + Code.newInstance[FloatMissingArrayBuilder, Int](len), + "fab", + ) + case DoubleInfo => mb.genLazyFieldThisRef[DoubleMissingArrayBuilder]( + Code.newInstance[DoubleMissingArrayBuilder, Int](len), + "dab", + ) case ti => throw new RuntimeException(s"unsupported typeinfo found: $ti") }) @@ -30,11 +47,14 @@ class StagedArrayBuilder(cb: EmitCodeBuilder, val elt: SingleCodeType, val eltRe ensureCapacity(cb, len) def add(cb: EmitCodeBuilder, x: Code[_]): Unit = cb.append(ti match { - case BooleanInfo => coerce[BooleanMissingArrayBuilder](ref).invoke[Boolean, Unit]("add", coerce[Boolean](x)) + case BooleanInfo => + coerce[BooleanMissingArrayBuilder](ref).invoke[Boolean, Unit]("add", 
coerce[Boolean](x)) case IntInfo => coerce[IntMissingArrayBuilder](ref).invoke[Int, Unit]("add", coerce[Int](x)) case LongInfo => coerce[LongMissingArrayBuilder](ref).invoke[Long, Unit]("add", coerce[Long](x)) - case FloatInfo => coerce[FloatMissingArrayBuilder](ref).invoke[Float, Unit]("add", coerce[Float](x)) - case DoubleInfo => coerce[DoubleMissingArrayBuilder](ref).invoke[Double, Unit]("add", coerce[Double](x)) + case FloatInfo => + coerce[FloatMissingArrayBuilder](ref).invoke[Float, Unit]("add", coerce[Float](x)) + case DoubleInfo => + coerce[DoubleMissingArrayBuilder](ref).invoke[Double, Unit]("add", coerce[Double](x)) }) def apply(i: Code[Int]): Code[_] = ti match { @@ -46,11 +66,22 @@ class StagedArrayBuilder(cb: EmitCodeBuilder, val elt: SingleCodeType, val eltRe } def update(cb: EmitCodeBuilder, i: Code[Int], x: Code[_]): Unit = cb.append(ti match { - case BooleanInfo => coerce[BooleanMissingArrayBuilder](ref).invoke[Int, Boolean, Unit]("update", i, coerce[Boolean](x)) - case IntInfo => coerce[IntMissingArrayBuilder](ref).invoke[Int, Int, Unit]("update", i, coerce[Int](x)) - case LongInfo => coerce[LongMissingArrayBuilder](ref).invoke[Int, Long, Unit]("update", i, coerce[Long](x)) - case FloatInfo => coerce[FloatMissingArrayBuilder](ref).invoke[Int, Float, Unit]("update", i, coerce[Float](x)) - case DoubleInfo => coerce[DoubleMissingArrayBuilder](ref).invoke[Int, Double, Unit]("update", i, coerce[Double](x)) + case BooleanInfo => coerce[BooleanMissingArrayBuilder](ref).invoke[Int, Boolean, Unit]( + "update", + i, + coerce[Boolean](x), + ) + case IntInfo => + coerce[IntMissingArrayBuilder](ref).invoke[Int, Int, Unit]("update", i, coerce[Int](x)) + case LongInfo => + coerce[LongMissingArrayBuilder](ref).invoke[Int, Long, Unit]("update", i, coerce[Long](x)) + case FloatInfo => + coerce[FloatMissingArrayBuilder](ref).invoke[Int, Float, Unit]("update", i, coerce[Float](x)) + case DoubleInfo => coerce[DoubleMissingArrayBuilder](ref).invoke[Int, Double, Unit]( + "update", + i, + coerce[Double](x), + ) }) def addMissing(cb: EmitCodeBuilder): Unit = @@ -73,9 +104,8 @@ class StagedArrayBuilder(cb: EmitCodeBuilder, val elt: SingleCodeType, val eltRe def clear(cb: EmitCodeBuilder): Unit = cb += coerce[MissingArrayBuilder](ref).invoke[Unit]("clear") - def loadFromIndex(cb: EmitCodeBuilder, r: Value[Region], i: Value[Int]): IEmitCode = { + def loadFromIndex(cb: EmitCodeBuilder, r: Value[Region], i: Value[Int]): IEmitCode = IEmitCode(cb, isMissing(i), elt.loadToSValue(cb, cb.memoizeAny(apply(i), ti))) - } } sealed abstract class MissingArrayBuilder(initialCapacity: Int) { @@ -97,21 +127,22 @@ sealed abstract class MissingArrayBuilder(initialCapacity: Int) { missing(i) = m } - def addMissing() { + def addMissing(): Unit = { ensureCapacity(size_ + 1) missing(size_) = true size_ += 1 } - def setSize(n: Int) { + def setSize(n: Int): Unit = { require(n >= 0 && n <= size) size_ = n } - def clear() { size_ = 0 } + def clear(): Unit = size_ = 0 } -final class IntMissingArrayBuilder(initialCapacity: Int) extends MissingArrayBuilder(initialCapacity) { +final class IntMissingArrayBuilder(initialCapacity: Int) + extends MissingArrayBuilder(initialCapacity) { private var b: Array[Int] = new Array[Int](initialCapacity) def apply(i: Int): Int = { @@ -170,7 +201,8 @@ final class IntMissingArrayBuilder(initialCapacity: Int) extends MissingArrayBui } } -final class LongMissingArrayBuilder(initialCapacity: Int) extends MissingArrayBuilder(initialCapacity) { +final class LongMissingArrayBuilder(initialCapacity: 
Int) + extends MissingArrayBuilder(initialCapacity) { private var b: Array[Long] = new Array[Long](initialCapacity) def apply(i: Int): Long = { @@ -229,7 +261,8 @@ final class LongMissingArrayBuilder(initialCapacity: Int) extends MissingArrayBu } } -final class FloatMissingArrayBuilder(initialCapacity: Int) extends MissingArrayBuilder(initialCapacity) { +final class FloatMissingArrayBuilder(initialCapacity: Int) + extends MissingArrayBuilder(initialCapacity) { private var b: Array[Float] = new Array[Float](initialCapacity) def apply(i: Int): Float = { @@ -288,7 +321,8 @@ final class FloatMissingArrayBuilder(initialCapacity: Int) extends MissingArrayB } } -final class DoubleMissingArrayBuilder(initialCapacity: Int) extends MissingArrayBuilder(initialCapacity) { +final class DoubleMissingArrayBuilder(initialCapacity: Int) + extends MissingArrayBuilder(initialCapacity) { private var b: Array[Double] = new Array[Double](initialCapacity) def apply(i: Int): Double = { @@ -347,7 +381,8 @@ final class DoubleMissingArrayBuilder(initialCapacity: Int) extends MissingArray } } -final class BooleanMissingArrayBuilder(initialCapacity: Int) extends MissingArrayBuilder(initialCapacity) { +final class BooleanMissingArrayBuilder(initialCapacity: Int) + extends MissingArrayBuilder(initialCapacity) { private var b: Array[Boolean] = new Array[Boolean](initialCapacity) def apply(i: Int): Boolean = { @@ -413,7 +448,7 @@ final class ByteArrayArrayBuilder(initialCapacity: Int) { def size: Int = size_ - def setSize(n: Int) { + def setSize(n: Int): Unit = { require(n >= 0 && n <= size) size_ = n } @@ -443,20 +478,19 @@ final class ByteArrayArrayBuilder(initialCapacity: Int) { b(i) = x } - def clear() { size_ = 0 } + def clear(): Unit = size_ = 0 def result(): Array[Array[Byte]] = b.slice(0, size_) } - -final class LongArrayBuilder(initialCapacity: Int= 16) { +final class LongArrayBuilder(initialCapacity: Int = 16) { var size_ : Int = 0 var b: Array[Long] = new Array[Long](initialCapacity) def size: Int = size_ - def setSize(n: Int) { + def setSize(n: Int): Unit = { require(n >= 0 && n <= size) size_ = n } @@ -502,7 +536,7 @@ final class LongArrayBuilder(initialCapacity: Int= 16) { b(i) = x } - def clear() { size_ = 0 } + def clear(): Unit = size_ = 0 def result(): Array[Long] = b.slice(0, size_) @@ -511,6 +545,7 @@ final class LongArrayBuilder(initialCapacity: Int= 16) { if (b.length > initialCapacity) b = new Array[Long](initialCapacity) } + def appendFrom(ab2: LongArrayBuilder): Unit = { ensureCapacity(size_ + ab2.size_) System.arraycopy(ab2.b, 0, b, size_, ab2.size_) @@ -530,7 +565,7 @@ final class IntArrayBuilder(initialCapacity: Int = 16) { def size: Int = size_ - def setSize(n: Int) { + def setSize(n: Int): Unit = { require(n >= 0 && n <= size) size_ = n } @@ -581,7 +616,7 @@ final class IntArrayBuilder(initialCapacity: Int = 16) { b(i) = x } - def clear() { size_ = 0 } + def clear(): Unit = size_ = 0 def result(): Array[Int] = b.slice(0, size_) @@ -590,6 +625,7 @@ final class IntArrayBuilder(initialCapacity: Int = 16) { if (b.length > initialCapacity) b = new Array[Int](initialCapacity) } + def appendFrom(ab2: IntArrayBuilder): Unit = { ensureCapacity(size_ + ab2.size_) System.arraycopy(ab2.b, 0, b, size_, ab2.size_) @@ -609,7 +645,7 @@ final class DoubleArrayBuilder(initialCapacity: Int = 16) { def size: Int = size_ - def setSize(n: Int) { + def setSize(n: Int): Unit = { require(n >= 0 && n <= size) size_ = n } @@ -660,7 +696,7 @@ final class DoubleArrayBuilder(initialCapacity: Int = 16) { b(i) = x } - def 
clear() { size_ = 0 } + def clear(): Unit = size_ = 0 def result(): Array[Double] = b.slice(0, size_) @@ -669,6 +705,7 @@ final class DoubleArrayBuilder(initialCapacity: Int = 16) { if (b.length > initialCapacity) b = new Array[Double](initialCapacity) } + def appendFrom(ab2: DoubleArrayBuilder): Unit = { ensureCapacity(size_ + ab2.size_) System.arraycopy(ab2.b, 0, b, size_, ab2.size_) @@ -688,14 +725,13 @@ final class ByteArrayBuilder(initialCapacity: Int = 16) { def size: Int = size_ - def setSize(n: Int) { + def setSize(n: Int): Unit = { require(n >= 0 && n <= size) size_ = n } - def setSizeUnchecked(n: Int) { + def setSizeUnchecked(n: Int): Unit = size_ = n - } def apply(i: Int): Byte = { require(i >= 0 && i < size) @@ -738,7 +774,7 @@ final class ByteArrayBuilder(initialCapacity: Int = 16) { b(i) = x } - def clear() { size_ = 0 } + def clear(): Unit = size_ = 0 def result(): Array[Byte] = b.slice(0, size_) @@ -747,6 +783,7 @@ final class ByteArrayBuilder(initialCapacity: Int = 16) { if (b.length > initialCapacity) b = new Array[Byte](initialCapacity) } + def appendFrom(ab2: ByteArrayBuilder): Unit = { ensureCapacity(size_ + ab2.size_) System.arraycopy(ab2.b, 0, b, size_, ab2.size_) @@ -766,7 +803,7 @@ final class BooleanArrayBuilder(initialCapacity: Int = 16) { def size: Int = size_ - def setSize(n: Int) { + def setSize(n: Int): Unit = { require(n >= 0 && n <= size) size_ = n } @@ -798,7 +835,7 @@ final class BooleanArrayBuilder(initialCapacity: Int = 16) { b(i) = x } - def clear() { size_ = 0 } + def clear(): Unit = size_ = 0 def result(): Array[Boolean] = b.slice(0, size_) @@ -807,6 +844,7 @@ final class BooleanArrayBuilder(initialCapacity: Int = 16) { if (b.length > initialCapacity) b = new Array[Boolean](initialCapacity) } + def appendFrom(ab2: BooleanArrayBuilder): Unit = { ensureCapacity(size_ + ab2.size_) System.arraycopy(ab2.b, 0, b, size_, ab2.size_) @@ -825,7 +863,7 @@ final class StringArrayBuilder(initialCapacity: Int = 16) { def size: Int = size_ - def setSize(n: Int) { + def setSize(n: Int): Unit = { require(n >= 0 && n <= size) size_ = n } @@ -857,7 +895,7 @@ final class StringArrayBuilder(initialCapacity: Int = 16) { b(i) = x } - def clear() { size_ = 0 } + def clear(): Unit = size_ = 0 def result(): Array[String] = { val a = new Array[String](size_) @@ -870,6 +908,7 @@ final class StringArrayBuilder(initialCapacity: Int = 16) { if (b.length > initialCapacity) b = new Array[String](initialCapacity) } + def appendFrom(ab2: StringArrayBuilder): Unit = { ensureCapacity(size_ + ab2.size_) System.arraycopy(ab2.b, 0, b, size_, ab2.size_) @@ -889,7 +928,7 @@ final class AnyRefArrayBuilder[T <: AnyRef](initialCapacity: Int = 16)(implicit def size: Int = size_ - def setSize(n: Int) { + def setSize(n: Int): Unit = { require(n >= 0 && n <= size) size_ = n } @@ -921,7 +960,7 @@ final class AnyRefArrayBuilder[T <: AnyRef](initialCapacity: Int = 16)(implicit b(i) = x } - def clear() { size_ = 0 } + def clear(): Unit = size_ = 0 def result(): Array[T] = b.slice(0, size_) @@ -930,6 +969,7 @@ final class AnyRefArrayBuilder[T <: AnyRef](initialCapacity: Int = 16)(implicit if (b.length > initialCapacity) b = new Array[T](initialCapacity) } + def appendFrom(ab2: AnyRefArrayBuilder[T]): Unit = { ensureCapacity(size_ + ab2.size_) System.arraycopy(ab2.b, 0, b, size_, ab2.size_) @@ -954,4 +994,4 @@ final class AnyRefArrayBuilder[T <: AnyRef](initialCapacity: Int = 16)(implicit ensureCapacity(size) size_ = size } -} \ No newline at end of file +} diff --git 
a/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala b/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala index 96f147a5e23..b15cd829731 100644 --- a/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala +++ b/hail/src/main/scala/is/hail/expr/ir/StringTableReader.scala @@ -1,4 +1,5 @@ package is.hail.expr.ir + import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext @@ -7,14 +8,15 @@ import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage, TableS import is.hail.expr.ir.streams.StreamProducer import is.hail.io.fs.{FS, FileListEntry} import is.hail.rvd.RVDPartitioner +import is.hail.types.{BaseTypeWithRequiredness, RStruct, TableType, VirtualTypeWithReq} import is.hail.types.physical._ import is.hail.types.physical.stypes.EmitType import is.hail.types.physical.stypes.concrete.{SJavaString, SStackStruct, SStackStructValue} import is.hail.types.physical.stypes.interfaces.{SBaseStructValue, SStreamValue} import is.hail.types.physical.stypes.primitives.{SInt64, SInt64Value} import is.hail.types.virtual._ -import is.hail.types.{BaseTypeWithRequiredness, RStruct, TableType, VirtualTypeWithReq} -import is.hail.utils.{FastSeq, fatal, checkGzipOfGlobbedFiles} +import is.hail.utils.{checkGzipOfGlobbedFiles, FastSeq} + import org.json4s.{Extraction, Formats, JValue} case class StringTableReaderParameters( @@ -22,7 +24,8 @@ case class StringTableReaderParameters( minPartitions: Option[Int], forceBGZ: Boolean, forceGZ: Boolean, - filePerPartition: Boolean) + filePerPartition: Boolean, +) object StringTableReader { def apply(fs: FS, params: StringTableReaderParameters): StringTableReader = { @@ -30,6 +33,7 @@ object StringTableReader { checkGzipOfGlobbedFiles(params.files, fileListEntries, params.forceGZ, params.forceBGZ) new StringTableReader(params, fileListEntries) } + def fromJValue(fs: FS, jv: JValue): StringTableReader = { implicit val formats: Formats = TableReader.formats val params = jv.extract[StringTableReaderParameters] @@ -37,10 +41,12 @@ object StringTableReader { } } -case class StringTablePartitionReader(lines: GenericLines, uidFieldName: String) extends PartitionReader{ +case class StringTablePartitionReader(lines: GenericLines, uidFieldName: String) + extends PartitionReader { override def contextType: Type = lines.contextType - override def fullRowType: TStruct = TStruct("file"-> TString, "text"-> TString, uidFieldName -> TTuple(TInt64, TInt64)) + override def fullRowType: TStruct = + TStruct("file" -> TString, "text" -> TString, uidFieldName -> TTuple(TInt64, TInt64)) override def rowRequiredness(requestedType: TStruct): RStruct = { val req = BaseTypeWithRequiredness(requestedType).asInstanceOf[RStruct] @@ -54,12 +60,13 @@ case class StringTablePartitionReader(lines: GenericLines, uidFieldName: String) cb: EmitCodeBuilder, mb: EmitMethodBuilder[_], context: EmitCode, - requestedType: TStruct + requestedType: TStruct, ): IEmitCode = { val uidSType: SStackStruct = SStackStruct( TTuple(TInt64, TInt64), - Array(EmitType(SInt64, true), EmitType(SInt64, true))) + Array(EmitType(SInt64, true), EmitType(SInt64, true)), + ) context.toI(cb).map(cb) { case partitionContext: SBaseStructValue => val iter = mb.genFieldThisRef[CloseableIterator[GenericLine]]("string_table_reader_iter") @@ -74,15 +81,27 @@ case class StringTablePartitionReader(lines: GenericLines, uidFieldName: String) override val length: Option[EmitCodeBuilder => Code[Int]] = None override def initialize(cb: EmitCodeBuilder, partitionRegion: 
Value[Region]): Unit = { - val contextAsJavaValue = coerce[Any](StringFunctions.svalueToJavaValue(cb, partitionRegion, partitionContext)) + val contextAsJavaValue = + coerce[Any](StringFunctions.svalueToJavaValue(cb, partitionRegion, partitionContext)) - cb.assign(fileName, partitionContext.loadField(cb, "file").get(cb).asString.loadString(cb)) - cb.assign(partIdx, partitionContext.loadField(cb, "partitionIndex").get(cb).asInt.value.toL) + cb.assign( + fileName, + partitionContext.loadField(cb, "file").getOrAssert(cb).asString.loadString(cb), + ) + cb.assign( + partIdx, + partitionContext.loadField(cb, "partitionIndex").getOrAssert(cb).asInt.value.toL, + ) cb.assign(rowIdx, -1L) - cb.assign(iter, + cb.assign( + iter, cb.emb.getObject[(FS, Any) => CloseableIterator[GenericLine]](lines.body) - .invoke[Any, Any, CloseableIterator[GenericLine]]("apply", cb.emb.getFS, contextAsJavaValue) + .invoke[Any, Any, CloseableIterator[GenericLine]]( + "apply", + cb.emb.getFS, + contextAsJavaValue, + ), ) } @@ -93,32 +112,43 @@ case class StringTablePartitionReader(lines: GenericLines, uidFieldName: String) override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => val hasNext = iter.invoke[Boolean]("hasNext") - cb.if_(hasNext, { - val gLine = iter.invoke[GenericLine]("next") - cb.assign(line, gLine.invoke[String]("toString")) - cb.assign(rowIdx, rowIdx + 1L) - cb.goto(LproduceElementDone) - }, { - cb.goto(LendOfStream) - }) + cb.if_( + hasNext, { + val gLine = iter.invoke[GenericLine]("next") + cb.assign(line, gLine.invoke[String]("toString")) + cb.assign(rowIdx, rowIdx + 1L) + cb.goto(LproduceElementDone) + }, + cb.goto(LendOfStream), + ) } override val element: EmitCode = EmitCode.fromI(cb.emb) { cb => val uid = EmitValue.present( - new SStackStructValue(uidSType, Array( - EmitValue.present(new SInt64Value(partIdx)), - EmitValue.present(new SInt64Value(rowIdx))))) + new SStackStructValue( + uidSType, + Array( + EmitValue.present(new SInt64Value(partIdx)), + EmitValue.present(new SInt64Value(rowIdx)), + ), + ) + ) val requestedFields = IndexedSeq[Option[EmitCode]]( - requestedType.selfField("file").map(_ => EmitCode.present(cb.emb, SJavaString.construct(cb, fileName))), - requestedType.selfField("text").map(_ => EmitCode.present(cb.emb, SJavaString.construct(cb, line))), - requestedType.selfField(uidFieldName).map(_ => uid) + requestedType.selfField("file").map(_ => + EmitCode.present(cb.emb, SJavaString.construct(cb, fileName)) + ), + requestedType.selfField("text").map(_ => + EmitCode.present(cb.emb, SJavaString.construct(cb, line)) + ), + requestedType.selfField(uidFieldName).map(_ => uid), ).flatten.toIndexedSeq - IEmitCode.present(cb, SStackStruct.constructFromArgs(cb, elementRegion, requestedType, - requestedFields: _*)) + IEmitCode.present( + cb, + SStackStruct.constructFromArgs(cb, elementRegion, requestedType, requestedFields: _*), + ) } - override def close(cb: EmitCodeBuilder): Unit = { + override def close(cb: EmitCodeBuilder): Unit = cb += iter.invoke[Unit]("close") - } }) } } @@ -128,15 +158,16 @@ case class StringTablePartitionReader(lines: GenericLines, uidFieldName: String) case class StringTableReader( val params: StringTableReaderParameters, - fileListEntries: IndexedSeq[FileListEntry] + fileListEntries: IndexedSeq[FileListEntry], ) extends TableReaderWithExtraUID { override def uidType = TTuple(TInt64, TInt64) override def fullTypeWithoutUIDs: TableType = TableType( - TStruct("file"-> TString, "text" -> TString), + TStruct("file" -> TString, "text" -> TString), 
FastSeq.empty, - TStruct()) + TStruct(), + ) override def renderShort(): String = defaultRender() @@ -144,32 +175,50 @@ case class StringTableReader( override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val fs = ctx.fs - val lines = GenericLines.read(fs, fileListEntries, None, None, params.minPartitions, params.forceBGZ, params.forceGZ, - params.filePerPartition) - TableStage(globals = MakeStruct(FastSeq()), + val lines = GenericLines.read( + fs, + fileListEntries, + None, + None, + params.minPartitions, + params.forceBGZ, + params.forceGZ, + params.filePerPartition, + ) + TableStage( + globals = MakeStruct(FastSeq()), partitioner = RVDPartitioner.unkeyed(ctx.stateManager, lines.nPartitions), dependency = TableStageDependency.none, contexts = ToStream(Literal.coerce(TArray(lines.contextType), lines.contexts)), - body = { partitionContext: Ref => ReadPartition(partitionContext, requestedType.rowType, StringTablePartitionReader(lines, uidFieldName)) - } + body = { partitionContext: Ref => + ReadPartition( + partitionContext, + requestedType.rowType, + StringTablePartitionReader(lines, uidFieldName), + ) + }, ) } override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = - throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented") + throw new LowererUnsupportedOperation(s"${getClass.getSimpleName}.lowerGlobals not implemented") override def partitionCounts: Option[IndexedSeq[Long]] = None - override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = + override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = VirtualTypeWithReq(PCanonicalStruct( - IndexedSeq(PField("file", PCanonicalString(true), 0), - PField("text", PCanonicalString(true), 1)), - true + IndexedSeq( + PField("file", PCanonicalString(true), 0), + PField("text", PCanonicalString(true), 1), + ), + true, ).subsetTo(requestedType.rowType)) override def uidRequiredness: VirtualTypeWithReq = VirtualTypeWithReq(PCanonicalTuple(true, PInt64Required, PInt64Required)) - override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = + override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = VirtualTypeWithReq(PCanonicalStruct.empty(required = true)) } diff --git a/hail/src/main/scala/is/hail/expr/ir/Subst.scala b/hail/src/main/scala/is/hail/expr/ir/Subst.scala index 178972b9e84..371064b0901 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Subst.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Subst.scala @@ -1,27 +1,23 @@ package is.hail.expr.ir -import is.hail.utils._ - object Subst { def apply(e: IR): IR = apply(e, BindingEnv.empty[IR]) - def apply(e: IR, env: BindingEnv[IR]): IR = { + def apply(e: IR, env: BindingEnv[IR]): IR = subst(e, env) - } private def subst(e: IR, env: BindingEnv[IR]): IR = { if (env.allEmpty) return e e match { - case x@Ref(name, _) => + case x @ Ref(name, _) => env.eval.lookupOption(name).getOrElse(x) case _ => e.mapChildrenWithIndex { case (child: IR, i) => - val childEnv = ChildEnvWithoutBindings(e, i, env) - val newBindings = NewBindings(e, i, childEnv) - subst(child, childEnv.subtract(newBindings)) + val bindings = Bindings.segregated(e, i, env) + subst(child, bindings.childEnvWithoutBindings.subtract(bindings.newBindings)) case (child, _) => child } } diff --git a/hail/src/main/scala/is/hail/expr/ir/Sym.scala 
b/hail/src/main/scala/is/hail/expr/ir/Sym.scala index 974bfbdda2d..863093d9d84 100644 --- a/hail/src/main/scala/is/hail/expr/ir/Sym.scala +++ b/hail/src/main/scala/is/hail/expr/ir/Sym.scala @@ -15,12 +15,11 @@ object Sym { abstract class Sym case class Identifier(name: String) extends Sym { - override def toString: String = { + override def toString: String = if (name.matches("""\p{javaJavaIdentifierStart}\p{javaJavaIdentifierPart}*""")) name else - s"`${ StringEscapeUtils.escapeString(name, backticked = true) }`" - } + s"`${StringEscapeUtils.escapeString(name, backticked = true)}`" } // lang is one of "py" or "j" diff --git a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala index b284f29344b..4c238dfd49e 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableIR.scala @@ -3,10 +3,12 @@ package is.hail.expr.ir import is.hail.HailContext import is.hail.annotations._ import is.hail.asm4s._ -import is.hail.backend.spark.{SparkBackend, SparkTaskContext} import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext, TaskFinalizer} +import is.hail.backend.spark.{SparkBackend, SparkTaskContext} import is.hail.expr.ir -import is.hail.expr.ir.functions.{BlockMatrixToTableFunction, IntervalFunctions, MatrixToTableFunction, TableToTableFunction} +import is.hail.expr.ir.functions.{ + BlockMatrixToTableFunction, IntervalFunctions, MatrixToTableFunction, TableToTableFunction, +} import is.hail.expr.ir.lowering._ import is.hail.expr.ir.streams.{StreamProducer, StreamUtils} import is.hail.io._ @@ -25,20 +27,22 @@ import is.hail.types.physical.stypes.primitives.{SInt64, SInt64Value} import is.hail.types.virtual._ import is.hail.utils._ import is.hail.utils.prettyPrint.ArrayOfByteArrayInputStream + +import scala.reflect.ClassTag + +import java.io.{Closeable, DataInputStream, DataOutputStream, InputStream} + import org.apache.spark.TaskContext import org.apache.spark.sql.Row +import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} import org.json4s.JsonAST.JString import org.json4s.jackson.JsonMethods -import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} - -import java.io.{Closeable, DataInputStream, DataOutputStream, InputStream} -import scala.reflect.ClassTag - object TableIR { - def read(fs: FS, path: String, dropRows: Boolean = false, requestedType: Option[TableType] = None): TableRead = { + def read(fs: FS, path: String, dropRows: Boolean = false, requestedType: Option[TableType] = None) + : TableRead = { val successFile = path + "/_SUCCESS" - if (!fs.exists(path + "/_SUCCESS")) + if (!fs.isFile(path + "/_SUCCESS")) fatal(s"write failed: file not found: $successFile") val tr = TableNativeReader.read(fs, path, None) @@ -46,7 +50,7 @@ object TableIR { } } -abstract sealed class TableIR extends BaseIR { +sealed abstract class TableIR extends BaseIR { def typ: TableType def partitionCounts: Option[IndexedSeq[Long]] = None @@ -64,12 +68,12 @@ abstract sealed class TableIR extends BaseIR { override def copy(newChildren: IndexedSeq[BaseIR]): TableIR - def unpersist(): TableIR = { + def unpersist(): TableIR = this match { - case TableLiteral(typ, rvd, enc, encodedGlobals) => TableLiteral(typ, rvd.unpersist(), enc, encodedGlobals) + case TableLiteral(typ, rvd, enc, encodedGlobals) => + TableLiteral(typ, rvd.unpersist(), enc, encodedGlobals) case x => x } - } def pyUnpersist(): TableIR = unpersist() @@ -77,12 +81,21 @@ abstract sealed class 
TableIR extends BaseIR { } object TableLiteral { - def apply(value: TableValue, theHailClassLoader: HailClassLoader): TableLiteral = { - TableLiteral(value.typ, value.rvd, value.globals.encoding, value.globals.encodeToByteArrays(theHailClassLoader)) - } + def apply(value: TableValue, theHailClassLoader: HailClassLoader): TableLiteral = + TableLiteral( + value.typ, + value.rvd, + value.globals.encoding, + value.globals.encodeToByteArrays(theHailClassLoader), + ) } -case class TableLiteral(typ: TableType, rvd: RVD, enc: AbstractTypedCodecSpec, encodedGlobals: Array[Array[Byte]]) extends TableIR { +case class TableLiteral( + typ: TableType, + rvd: RVD, + enc: AbstractTypedCodecSpec, + encodedGlobals: Array[Array[Byte]], +) extends TableIR { val childrenSeq: IndexedSeq[BaseIR] = Array.empty[BaseIR] lazy val rowCountUpperBound: Option[Long] = None @@ -92,12 +105,18 @@ case class TableLiteral(typ: TableType, rvd: RVD, enc: AbstractTypedCodecSpec, e TableLiteral(typ, rvd, enc, encodedGlobals) } - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val (globalPType: PStruct, dec) = enc.buildDecoder(ctx, typ.globalType) val bais = new ArrayOfByteArrayInputStream(encodedGlobals) val globalOffset = dec.apply(bais, ctx.theHailClassLoader).readRegionValue(ctx.r) - new TableValueIntermediate(TableValue(ctx, typ, BroadcastRow(ctx, RegionValue(ctx.r, globalOffset), globalPType), rvd)) + new TableValueIntermediate(TableValue( + ctx, + typ, + BroadcastRow(ctx, RegionValue(ctx.r, globalOffset), globalPType), + rvd, + )) } } @@ -134,7 +153,7 @@ object LoweredTableReader { bodyPType: (TStruct) => PStruct, keys: (TStruct) => (Region, HailClassLoader, FS, Any) => Iterator[Long], context: String, - cacheKey: Any + cacheKey: Any, ): LoweredTableReaderCoercer = { assert(key.nonEmpty) assert(contexts.nonEmpty) @@ -155,183 +174,282 @@ object LoweredTableReader { case Some(r) => r case None => info(s"scanning $context for sortedness...") - val prevkey = AggSignature(PrevNonnull(), - FastSeq(), - FastSeq(keyType)) + val prevkey = AggSignature(PrevNonnull(), FastSeq(), FastSeq(keyType)) - val count = AggSignature(Count(), - FastSeq(), - FastSeq()) + val count = AggSignature(Count(), FastSeq(), FastSeq()) val xType = TStruct( "key" -> keyType, "token" -> TFloat64, - "prevkey" -> keyType) - - val samplekey = AggSignature(TakeBy(), - FastSeq(TInt32), - FastSeq(keyType, TFloat64)) - - val sum = AggSignature(Sum(), - FastSeq(), - FastSeq(TInt64)) - - val minkey = AggSignature(TakeBy(), - FastSeq(TInt32), - FastSeq(keyType, keyType)) - - val maxkey = AggSignature(TakeBy(Descending), - FastSeq(TInt32), - FastSeq(keyType, keyType)) - - val scanBody = (ctx: IR) => StreamAgg( - StreamAggScan( - ReadPartition(ctx, keyType, new PartitionIteratorLongReader( - keyType, - uidFieldName, - contextType, - (requestedType: Type) => bodyPType(requestedType.asInstanceOf[TStruct]), - (requestedType: Type) => keys(requestedType.asInstanceOf[TStruct]))), - "key", - MakeStruct(FastSeq( - "key" -> Ref("key", keyType), - "token" -> invokeSeeded("rand_unif", 1, TFloat64, RNGStateLiteral(), F64(0.0), F64(1.0)), - "prevkey" -> ApplyScanOp(FastSeq(), FastSeq(Ref("key", keyType)), prevkey)))), - "x", - Let(FastSeq("n" -> ApplyAggOp(FastSeq(), FastSeq(), count)), - AggLet("key", GetField(Ref("x", xType), "key"), + "prevkey" -> keyType, + ) + + val samplekey = AggSignature(TakeBy(), 
FastSeq(TInt32), FastSeq(keyType, TFloat64)) + + val sum = AggSignature(Sum(), FastSeq(), FastSeq(TInt64)) + + val minkey = AggSignature(TakeBy(), FastSeq(TInt32), FastSeq(keyType, keyType)) + + val maxkey = AggSignature(TakeBy(Descending), FastSeq(TInt32), FastSeq(keyType, keyType)) + + val scanBody = (ctx: IR) => + StreamAgg( + StreamAggScan( + ReadPartition( + ctx, + keyType, + new PartitionIteratorLongReader( + keyType, + uidFieldName, + contextType, + (requestedType: Type) => bodyPType(requestedType.asInstanceOf[TStruct]), + (requestedType: Type) => keys(requestedType.asInstanceOf[TStruct]), + ), + ), + "key", MakeStruct(FastSeq( - "n" -> Ref("n", TInt64), - "minkey" -> - ApplyAggOp( - FastSeq(I32(1)), - FastSeq(Ref("key", keyType), Ref("key", keyType)), - minkey), - "maxkey" -> - ApplyAggOp( - FastSeq(I32(1)), - FastSeq(Ref("key", keyType), Ref("key", keyType)), - maxkey), - "ksorted" -> - ApplyComparisonOp(EQ(TInt64), + "key" -> Ref("key", keyType), + "token" -> invokeSeeded( + "rand_unif", + 1, + TFloat64, + RNGStateLiteral(), + F64(0.0), + F64(1.0), + ), + "prevkey" -> ApplyScanOp(FastSeq(), FastSeq(Ref("key", keyType)), prevkey), + )), + ), + "x", + Let( + FastSeq("n" -> ApplyAggOp(FastSeq(), FastSeq(), count)), + AggLet( + "key", + GetField(Ref("x", xType), "key"), + MakeStruct(FastSeq( + "n" -> Ref("n", TInt64), + "minkey" -> ApplyAggOp( - FastSeq(), - FastSeq( - invoke("toInt64", TInt64, - invoke("lor", TBoolean, - IsNA(GetField(Ref("x", xType), "prevkey")), - ApplyComparisonOp(LTEQ(keyType), - GetField(Ref("x", xType), "prevkey"), - GetField(Ref("x", xType), "key"))))), - sum), - Ref("n", TInt64)), - "pksorted" -> - ApplyComparisonOp(EQ(TInt64), + FastSeq(I32(1)), + FastSeq(Ref("key", keyType), Ref("key", keyType)), + minkey, + ), + "maxkey" -> ApplyAggOp( - FastSeq(), - FastSeq( - invoke("toInt64", TInt64, - invoke("lor", TBoolean, - IsNA(selectPK(GetField(Ref("x", xType), "prevkey"))), - ApplyComparisonOp(LTEQ(pkType), - selectPK(GetField(Ref("x", xType), "prevkey")), - selectPK(GetField(Ref("x", xType), "key")))))), - sum), - Ref("n", TInt64)), - "sample" -> ApplyAggOp( - FastSeq(I32(samplesPerPartition)), - FastSeq(GetField(Ref("x", xType), "key"), GetField(Ref("x", xType), "token")), - samplekey))), - isScan = false))) + FastSeq(I32(1)), + FastSeq(Ref("key", keyType), Ref("key", keyType)), + maxkey, + ), + "ksorted" -> + ApplyComparisonOp( + EQ(TInt64), + ApplyAggOp( + FastSeq(), + FastSeq( + invoke( + "toInt64", + TInt64, + invoke( + "lor", + TBoolean, + IsNA(GetField(Ref("x", xType), "prevkey")), + ApplyComparisonOp( + LTEQ(keyType), + GetField(Ref("x", xType), "prevkey"), + GetField(Ref("x", xType), "key"), + ), + ), + ) + ), + sum, + ), + Ref("n", TInt64), + ), + "pksorted" -> + ApplyComparisonOp( + EQ(TInt64), + ApplyAggOp( + FastSeq(), + FastSeq( + invoke( + "toInt64", + TInt64, + invoke( + "lor", + TBoolean, + IsNA(selectPK(GetField(Ref("x", xType), "prevkey"))), + ApplyComparisonOp( + LTEQ(pkType), + selectPK(GetField(Ref("x", xType), "prevkey")), + selectPK(GetField(Ref("x", xType), "key")), + ), + ), + ) + ), + sum, + ), + Ref("n", TInt64), + ), + "sample" -> ApplyAggOp( + FastSeq(I32(samplesPerPartition)), + FastSeq(GetField(Ref("x", xType), "key"), GetField(Ref("x", xType), "token")), + samplekey, + ), + )), + isScan = false, + ), + ), + ) val scanResult = CollectDistributedArray( ToStream(Literal(TArray(contextType), contexts)), MakeStruct(FastSeq()), "context", "globals", - scanBody(Ref("context", contextType)), NA(TString), "table_coerce_sortedness") + 
scanBody(Ref("context", contextType)), + NA(TString), + "table_coerce_sortedness", + ) val sortedPartDataIR = sortIR(bindIR(scanResult) { scanResult => mapIR( filterIR( mapIR( - rangeIR(I32(0), ArrayLen(scanResult))) { i => + rangeIR(I32(0), ArrayLen(scanResult)) + ) { i => InsertFields( ArrayRef(scanResult, i), - FastSeq("i" -> i)) - }) { row => ArrayLen(GetField(row, "minkey")) > 0 } + FastSeq("i" -> i), + ) + } + )(row => ArrayLen(GetField(row, "minkey")) > 0) ) { row => - InsertFields(row, FastSeq( - ("minkey", ArrayRef(GetField(row, "minkey"), I32(0))), - ("maxkey", ArrayRef(GetField(row, "maxkey"), I32(0))))) + InsertFields( + row, + FastSeq( + ("minkey", ArrayRef(GetField(row, "minkey"), I32(0))), + ("maxkey", ArrayRef(GetField(row, "maxkey"), I32(0))), + ), + ) } }) { (l, r) => - ApplyComparisonOp(LT(TStruct("minkey" -> keyType, "maxkey" -> keyType)), + ApplyComparisonOp( + LT(TStruct("minkey" -> keyType, "maxkey" -> keyType)), SelectFields(l, FastSeq("minkey", "maxkey")), - SelectFields(r, FastSeq("minkey", "maxkey"))) + SelectFields(r, FastSeq("minkey", "maxkey")), + ) } val partDataElt = tcoerce[TArray](sortedPartDataIR.typ).elementType val summary = - Let(FastSeq("sortedPartData" -> sortedPartDataIR), + Let( + FastSeq("sortedPartData" -> sortedPartDataIR), MakeStruct(FastSeq( "ksorted" -> - invoke("land", TBoolean, - StreamFold(ToStream(Ref("sortedPartData", sortedPartDataIR.typ)), + invoke( + "land", + TBoolean, + StreamFold( + ToStream(Ref("sortedPartData", sortedPartDataIR.typ)), True(), "acc", "partDataWithIndex", - invoke("land", TBoolean, + invoke( + "land", + TBoolean, Ref("acc", TBoolean), - GetField(Ref("partDataWithIndex", partDataElt), "ksorted"))), + GetField(Ref("partDataWithIndex", partDataElt), "ksorted"), + ), + ), StreamFold( StreamRange( I32(0), ArrayLen(Ref("sortedPartData", sortedPartDataIR.typ)) - I32(1), - I32(1)), + I32(1), + ), True(), - "acc", "i", - invoke("land", TBoolean, + "acc", + "i", + invoke( + "land", + TBoolean, Ref("acc", TBoolean), - ApplyComparisonOp(LTEQ(keyType), + ApplyComparisonOp( + LTEQ(keyType), GetField( ArrayRef(Ref("sortedPartData", sortedPartDataIR.typ), Ref("i", TInt32)), - "maxkey"), + "maxkey", + ), GetField( - ArrayRef(Ref("sortedPartData", sortedPartDataIR.typ), Ref("i", TInt32) + I32(1)), - "minkey"))))), + ArrayRef( + Ref("sortedPartData", sortedPartDataIR.typ), + Ref("i", TInt32) + I32(1), + ), + "minkey", + ), + ), + ), + ), + ), "pksorted" -> - invoke("land", TBoolean, - StreamFold(ToStream(Ref("sortedPartData", sortedPartDataIR.typ)), + invoke( + "land", + TBoolean, + StreamFold( + ToStream(Ref("sortedPartData", sortedPartDataIR.typ)), True(), "acc", "partDataWithIndex", - invoke("land", TBoolean, + invoke( + "land", + TBoolean, Ref("acc", TBoolean), - GetField(Ref("partDataWithIndex", partDataElt), "pksorted"))), + GetField(Ref("partDataWithIndex", partDataElt), "pksorted"), + ), + ), StreamFold( StreamRange( I32(0), ArrayLen(Ref("sortedPartData", sortedPartDataIR.typ)) - I32(1), - I32(1)), + I32(1), + ), True(), - "acc", "i", - invoke("land", TBoolean, + "acc", + "i", + invoke( + "land", + TBoolean, Ref("acc", TBoolean), - ApplyComparisonOp(LTEQ(pkType), + ApplyComparisonOp( + LTEQ(pkType), selectPK(GetField( ArrayRef(Ref("sortedPartData", sortedPartDataIR.typ), Ref("i", TInt32)), - "maxkey")), + "maxkey", + )), selectPK(GetField( - ArrayRef(Ref("sortedPartData", sortedPartDataIR.typ), Ref("i", TInt32) + I32(1)), - "minkey")))))), - "sortedPartData" -> Ref("sortedPartData", sortedPartDataIR.typ)))) - - val 
(Some(PTypeReferenceSingleCodeType(resultPType: PStruct)), f) = Compile[AsmFunction1RegionLong](ctx, - FastSeq(), - FastSeq[TypeInfo[_]](classInfo[Region]), LongInfo, - summary, - optimize = true) + ArrayRef( + Ref("sortedPartData", sortedPartDataIR.typ), + Ref("i", TInt32) + I32(1), + ), + "minkey", + )), + ), + ), + ), + ), + "sortedPartData" -> Ref("sortedPartData", sortedPartDataIR.typ), + )), + ) + + val (Some(PTypeReferenceSingleCodeType(resultPType: PStruct)), f) = + Compile[AsmFunction1RegionLong]( + ctx, + FastSeq(), + FastSeq[TypeInfo[_]](classInfo[Region]), + LongInfo, + summary, + optimize = true, + ) val s = ctx.scopedExecution { (hcl, fs, htc, r) => val a = f(hcl, fs, htc, r)(r) @@ -343,29 +461,45 @@ object LoweredTableReader { val sortedPartData = s.getAs[IndexedSeq[Row]](2) val coercer = if (ksorted) { - info(s"Coerced sorted ${ context } - no additional import work to do") + info(s"Coerced sorted $context - no additional import work to do") new LoweredTableReaderCoercer { - def coerce(ctx: ExecuteContext, + def coerce( + ctx: ExecuteContext, globals: IR, contextType: Type, contexts: IndexedSeq[Any], - body: IR => IR): TableStage = { + body: IR => IR, + ): TableStage = { val partOrigIndex = sortedPartData.map(_.getInt(6)) - val partitioner = new RVDPartitioner(ctx.stateManager, keyType, + val partitioner = new RVDPartitioner( + ctx.stateManager, + keyType, sortedPartData.map { partData => - Interval(partData.get(1), partData.get(2), includesStart = true, includesEnd = true) + Interval( + partData.get(1), + partData.get(2), + includesStart = true, + includesEnd = true, + ) }, - key.length) + key.length, + ) - TableStage(globals, partitioner, TableStageDependency.none, + TableStage( + globals, + partitioner, + TableStageDependency.none, ToStream(Literal(TArray(contextType), partOrigIndex.map(i => contexts(i)))), - body) + body, + ) } } } else if (pksorted) { - info(s"Coerced prefix-sorted $context, requiring additional sorting within data partitions on each query.") + info( + s"Coerced prefix-sorted $context, requiring additional sorting within data partitions on each query." 
+ ) new LoweredTableReaderCoercer { private[this] def selectPK(r: Row): Row = { @@ -378,58 +512,87 @@ object LoweredTableReader { Row.fromSeq(a) } - def coerce(ctx: ExecuteContext, + def coerce( + ctx: ExecuteContext, globals: IR, contextType: Type, contexts: IndexedSeq[Any], - body: IR => IR): TableStage = { + body: IR => IR, + ): TableStage = { val partOrigIndex = sortedPartData.map(_.getInt(6)) - val partitioner = new RVDPartitioner(ctx.stateManager, pkType, + val partitioner = new RVDPartitioner( + ctx.stateManager, + pkType, sortedPartData.map { partData => - Interval(selectPK(partData.getAs[Row](1)), selectPK(partData.getAs[Row](2)), includesStart = true, includesEnd = true) - }, pkType.size) + Interval( + selectPK(partData.getAs[Row](1)), + selectPK(partData.getAs[Row](2)), + includesStart = true, + includesEnd = true, + ) + }, + pkType.size, + ) - val pkPartitioned = TableStage(globals, partitioner, TableStageDependency.none, + val pkPartitioned = TableStage( + globals, + partitioner, + TableStageDependency.none, ToStream(Literal(TArray(contextType), partOrigIndex.map(i => contexts(i)))), - body) + body, + ) pkPartitioned .extendKeyPreservesPartitioning(ctx, key) .mapPartition(None) { part => - flatMapIR(StreamGroupByKey(part, pkType.fieldNames, missingEqual = true)) { inner => - ToStream(sortIR(inner) { case (l, r) => ApplyComparisonOp(LT(l.typ), l, r) }) + flatMapIR(StreamGroupByKey(part, pkType.fieldNames, missingEqual = true)) { + inner => + ToStream(sortIR(inner) { case (l, r) => ApplyComparisonOp(LT(l.typ), l, r) }) } } } } } else { - info(s"$context is out of order..." + - s"\n Write the dataset to disk before running multiple queries to avoid multiple costly data shuffles.") + info( + s"$context is out of order..." + + s"\n Write the dataset to disk before running multiple queries to avoid multiple costly data shuffles." 
+ ) new LoweredTableReaderCoercer { - def coerce(ctx: ExecuteContext, + def coerce( + ctx: ExecuteContext, globals: IR, contextType: Type, contexts: IndexedSeq[Any], - body: IR => IR): TableStage = { + body: IR => IR, + ): TableStage = { val partOrigIndex = sortedPartData.map(_.getInt(6)) val partitioner = RVDPartitioner.unkeyed(ctx.stateManager, sortedPartData.length) - val tableStage = TableStage(globals, partitioner, TableStageDependency.none, + val tableStage = TableStage( + globals, + partitioner, + TableStageDependency.none, ToStream(Literal(TArray(contextType), partOrigIndex.map(i => contexts(i)))), - body) + body, + ) - val rowRType = VirtualTypeWithReq(bodyPType(tableStage.rowType)).r.asInstanceOf[RStruct] + val rowRType = + VirtualTypeWithReq(bodyPType(tableStage.rowType)).r.asInstanceOf[RStruct] val globReq = Requiredness(globals, ctx) val globRType = globReq.lookup(globals).asInstanceOf[RStruct] - ctx.backend.lowerDistributedSort(ctx, + ctx.backend.lowerDistributedSort( + ctx, tableStage, keyType.fieldNames.map(f => SortField(f, Ascending)), - RTable(rowRType, globRType, FastSeq()) - ).lower(ctx, TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct])) + RTable(rowRType, globRType, FastSeq()), + ).lower( + ctx, + TableType(tableStage.rowType, keyType.fieldNames, globals.typ.asInstanceOf[TStruct]), + ) } } } @@ -440,7 +603,6 @@ object LoweredTableReader { } } - trait TableReaderWithExtraUID extends TableReader { def fullTypeWithoutUIDs: TableType @@ -451,17 +613,20 @@ trait TableReaderWithExtraUID extends TableReader { require(!fullTypeWithoutUIDs.rowType.hasField(uidFieldName)) fullTypeWithoutUIDs.copy( rowType = fullTypeWithoutUIDs.rowType.insertFields( - Array((uidFieldName, uidType)))) + Array((uidFieldName, uidType)) + ) + ) } def uidType: Type - - protected def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq + protected def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq protected def uidRequiredness: VirtualTypeWithReq - override def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = { + override def rowRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = { val requestedUID = requestedType.rowType.hasField(uidFieldName) val concreteRowType = if (requestedUID) requestedType.rowType.deleteKey(uidFieldName) @@ -472,12 +637,14 @@ trait TableReaderWithExtraUID extends TableReader { val concreteRFields = concreteRowReq.r.asInstanceOf[RStruct].fields VirtualTypeWithReq( requestedType.rowType, - RStruct(concreteRFields :+ RField(uidFieldName, uidRequiredness.r, concreteRFields.length))) + RStruct(concreteRFields :+ RField(uidFieldName, uidRequiredness.r, concreteRFields.length)), + ) } else { concreteRowReq } } } + abstract class TableReader { def pathsUsed: Seq[String] @@ -490,15 +657,16 @@ abstract class TableReader { RVDPartitioner.empty(ctx, requestedType.keyType), TableStageDependency.none, MakeStream(FastSeq(), TStream(TStruct.empty)), - (_: Ref) => MakeStream(FastSeq(), TStream(requestedType.rowType))) + (_: Ref) => MakeStream(FastSeq(), TStream(requestedType.rowType)), + ) } else { lower(ctx, requestedType) } } - def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { + def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean) + : TableExecuteIntermediate = TableExecuteIntermediate(lower(ctx, 
requestedType, dropRows)) - } def partitionCounts: Option[IndexedSeq[Long]] @@ -510,15 +678,13 @@ abstract class TableReader { def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq - def toJValue: JValue = { + def toJValue: JValue = Extraction.decompose(this)(TableReader.formats) - } def renderShort(): String - def defaultRender(): String = { + def defaultRender(): String = StringEscapeUtils.escapeString(JsonMethods.compact(toJValue)) - } def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR @@ -532,7 +698,8 @@ object TableNativeReader { def apply(fs: FS, params: TableNativeReaderParameters): TableNativeReader = { val spec = (RelationalSpec.read(fs, params.path): @unchecked) match { case ts: AbstractTableSpec => ts - case _: AbstractMatrixTableSpec => fatal(s"file is a MatrixTable, not a Table: '${ params.path }'") + case _: AbstractMatrixTableSpec => + fatal(s"file is a MatrixTable, not a Table: '${params.path}'") } val filterIntervals = params.options.map(_.filterIntervals).getOrElse(false) @@ -541,7 +708,8 @@ object TableNativeReader { fatal( """`intervals` specified on an unindexed table. |This table was written using an older version of hail - |rewrite the table in order to create an index to proceed""".stripMargin) + |rewrite the table in order to create an index to proceed""".stripMargin + ) new TableNativeReader(params, spec) } @@ -553,11 +721,11 @@ object TableNativeReader { } } - case class PartitionRVDReader(rvd: RVD, uidFieldName: String) extends PartitionReader { override def contextType: Type = TInt32 - override def fullRowType: TStruct = rvd.rowType.insertFields(Array(uidFieldName -> TTuple(TInt64, TInt64))) + override def fullRowType: TStruct = + rvd.rowType.insertFields(Array(uidFieldName -> TTuple(TInt64, TInt64))) override def rowRequiredness(requestedType: TStruct): RStruct = { val tr = TypeWithRequiredness(requestedType).asInstanceOf[RStruct] @@ -570,22 +738,34 @@ case class PartitionRVDReader(rvd: RVD, uidFieldName: String) extends PartitionR cb: EmitCodeBuilder, mb: EmitMethodBuilder[_], context: EmitCode, - requestedType: TStruct): IEmitCode = { - - val (Some(PTypeReferenceSingleCodeType(upcastPType: PBaseStruct)), upcast) = Compile[AsmFunction2RegionLongLong](ctx, - FastSeq(("elt", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(rvd.rowPType)))), - FastSeq(classInfo[Region], LongInfo), - LongInfo, - PruneDeadFields.upcast(ctx, Ref("elt", rvd.rowType), requestedType)) + requestedType: TStruct, + ): IEmitCode = { - val upcastCode = mb.getObject[Function4[HailClassLoader, FS, HailTaskContext, Region, AsmFunction2RegionLongLong]](upcast) + val (Some(PTypeReferenceSingleCodeType(upcastPType: PBaseStruct)), upcast) = + Compile[AsmFunction2RegionLongLong]( + ctx, + FastSeq(("elt", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(rvd.rowPType)))), + FastSeq(classInfo[Region], LongInfo), + LongInfo, + PruneDeadFields.upcast(ctx, Ref("elt", rvd.rowType), requestedType), + ) + + val upcastCode = mb.getObject[Function4[ + HailClassLoader, + FS, + HailTaskContext, + Region, + AsmFunction2RegionLongLong, + ]](upcast) val rowPType = rvd.rowPType.subsetTo(requestedType) val createUID = requestedType.hasField(uidFieldName) - assert(upcastPType == rowPType, - s"ptype mismatch:\n upcast: $upcastPType\n computed: ${ rowPType }\n inputType: ${rvd.rowPType}\n requested: ${requestedType}") + assert( + upcastPType == rowPType, + s"ptype mismatch:\n upcast: $upcastPType\n computed: $rowPType\n inputType: 
${rvd.rowPType}\n requested: $requestedType", + ) context.toI(cb).map(cb) { _partIdx => val partIdx = cb.memoizeField(_partIdx, "partIdx") @@ -596,7 +776,8 @@ case class PartitionRVDReader(rvd: RVD, uidFieldName: String) extends PartitionR val region = mb.genFieldThisRef[Region]("rvdreader_region") val upcastF = mb.genFieldThisRef[AsmFunction2RegionLongLong]("rvdreader_upcast") - val broadcastRVD = mb.getObject[BroadcastRVD](new BroadcastRVD(ctx.backend.asSpark("RVDReader"), rvd)) + val broadcastRVD = + mb.getObject[BroadcastRVD](new BroadcastRVD(ctx.backend.asSpark("RVDReader"), rvd)) val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb @@ -604,10 +785,31 @@ case class PartitionRVDReader(rvd: RVD, uidFieldName: String) extends PartitionR override def initialize(cb: EmitCodeBuilder, partitionRegion: Value[Region]): Unit = { cb.assign(curIdx, 0L) - cb.assign(iterator, broadcastRVD.invoke[Int, Region, Region, Iterator[Long]]( - "computePartition", partIdx.asInt.value, region, partitionRegion)) - cb.assign(upcastF, Code.checkcast[AsmFunction2RegionLongLong](upcastCode.invoke[AnyRef, AnyRef, AnyRef, AnyRef, AnyRef]( - "apply", cb.emb.ecb.emodb.getHailClassLoader, cb.emb.ecb.emodb.getFS, cb.emb.ecb.getTaskContext, partitionRegion))) + cb.assign( + iterator, + broadcastRVD.invoke[Int, Region, Region, Iterator[Long]]( + "computePartition", + partIdx.asInt.value, + region, + partitionRegion, + ), + ) + cb.assign( + upcastF, + Code.checkcast[AsmFunction2RegionLongLong](upcastCode.invoke[ + AnyRef, + AnyRef, + AnyRef, + AnyRef, + AnyRef, + ]( + "apply", + cb.emb.ecb.emodb.getHailClassLoader, + cb.emb.ecb.emodb.getFS, + cb.emb.ecb.getTaskContext, + partitionRegion, + )), + ) } override val elementRegion: Settable[Region] = region @@ -615,15 +817,30 @@ case class PartitionRVDReader(rvd: RVD, uidFieldName: String) extends PartitionR override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.if_(!iterator.invoke[Boolean]("hasNext"), cb.goto(LendOfStream)) cb.assign(curIdx, curIdx + 1) - cb.assign(next, upcastF.invoke[Region, Long, Long]("apply", region, Code.longValue(iterator.invoke[java.lang.Long]("next")))) + cb.assign( + next, + upcastF.invoke[Region, Long, Long]( + "apply", + region, + Code.longValue(iterator.invoke[java.lang.Long]("next")), + ), + ) cb.goto(LproduceElementDone) } override val element: EmitCode = EmitCode.fromI(mb) { cb => if (createUID) { - val uid = SStackStruct.constructFromArgs(cb, region, TTuple(TInt64, TInt64), - EmitCode.present(mb, partIdx), EmitCode.present(mb, primitive(cb.memoize(curIdx - 1)))) - IEmitCode.present(cb, upcastPType.loadCheapSCode(cb, next) - ._insert(requestedType, uidFieldName -> EmitValue.present(uid))) + val uid = SStackStruct.constructFromArgs( + cb, + region, + TTuple(TInt64, TInt64), + EmitCode.present(mb, partIdx), + EmitCode.present(mb, primitive(cb.memoize(curIdx - 1))), + ) + IEmitCode.present( + cb, + upcastPType.loadCheapSCode(cb, next) + ._insert(requestedType, uidFieldName -> EmitValue.present(uid)), + ) } else { IEmitCode.present(cb, upcastPType.loadCheapSCode(cb, next)) } @@ -636,7 +853,8 @@ case class PartitionRVDReader(rvd: RVD, uidFieldName: String) extends PartitionR } } - def toJValue: JValue = JString("") // cannot be parsed, but need a printout for Pretty + def toJValue: JValue = + JString("") // cannot be parsed, but need a printout for Pretty } trait AbstractNativeReader extends PartitionReader { @@ -662,7 +880,7 @@ trait AbstractNativeReader extends PartitionReader { } case class 
PartitionNativeReader(spec: AbstractTypedCodecSpec, uidFieldName: String) - extends AbstractNativeReader { + extends AbstractNativeReader { def contextType: Type = TStruct("partitionIndex" -> TInt64, "partitionPath" -> TString) @@ -671,9 +889,12 @@ case class PartitionNativeReader(spec: AbstractTypedCodecSpec, uidFieldName: Str cb: EmitCodeBuilder, mb: EmitMethodBuilder[_], context: EmitCode, - requestedType: TStruct): IEmitCode = { + requestedType: TStruct, + ): IEmitCode = { - val insertUID: Boolean = requestedType.hasField(uidFieldName) && !spec.encodedVirtualType.asInstanceOf[TStruct].hasField(uidFieldName) + val insertUID: Boolean = requestedType.hasField( + uidFieldName + ) && !spec.encodedVirtualType.asInstanceOf[TStruct].hasField(uidFieldName) val concreteType: TStruct = if (insertUID) requestedType.deleteKey(uidFieldName) else @@ -682,17 +903,25 @@ case class PartitionNativeReader(spec: AbstractTypedCodecSpec, uidFieldName: Str val concreteSType = spec.encodedType.decodedSType(concreteType).asInstanceOf[SBaseStruct] val uidSType: SStackStruct = SStackStruct( TTuple(TInt64, TInt64), - Array(EmitType(SInt64, true), EmitType(SInt64, true))) + Array(EmitType(SInt64, true), EmitType(SInt64, true)), + ) val elementSType = if (insertUID) - SInsertFieldsStruct(requestedType, concreteSType, - Array(uidFieldName -> EmitType(uidSType, true))) + SInsertFieldsStruct( + requestedType, + concreteSType, + Array(uidFieldName -> EmitType(uidSType, true)), + ) else concreteSType context.toI(cb).map(cb) { case ctxStruct: SBaseStructValue => - val partIdx = cb.memoizeField(ctxStruct.loadField(cb, "partitionIndex").get(cb), "partIdx") + val partIdx = + cb.memoizeField(ctxStruct.loadField(cb, "partitionIndex").getOrAssert(cb), "partIdx") val rowIdx = mb.genFieldThisRef[Long]("pnr_rowidx") - val pathString = cb.memoizeField(ctxStruct.loadField(cb, "partitionPath").get(cb).asString.loadString(cb)) + val pathString = + cb.memoizeField( + ctxStruct.loadField(cb, "partitionPath").getOrAssert(cb).asString.loadString(cb) + ) val xRowBuf = mb.genFieldThisRef[InputBuffer]("pnr_xrowbuf") val next = mb.newPSettable(mb.fieldBuilder, elementSType, "pnr_next") val region = mb.genFieldThisRef[Region]("pnr_region") @@ -702,7 +931,10 @@ case class PartitionNativeReader(spec: AbstractTypedCodecSpec, uidFieldName: Str override val length: Option[EmitCodeBuilder => Code[Int]] = None override def initialize(cb: EmitCodeBuilder, partitionRegion: Value[Region]): Unit = { - cb.assign(xRowBuf, spec.buildCodeInputBuffer(mb.openUnbuffered(pathString, checkCodec = true))) + cb.assign( + xRowBuf, + spec.buildCodeInputBuffer(mb.openUnbuffered(pathString, checkCodec = true)), + ) cb.assign(rowIdx, -1L) } @@ -711,13 +943,22 @@ case class PartitionNativeReader(spec: AbstractTypedCodecSpec, uidFieldName: Str override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.if_(!xRowBuf.readByte().toZ, cb.goto(LendOfStream)) - val base = spec.encodedType.buildDecoder(concreteType, cb.emb.ecb).apply(cb, region, xRowBuf).asBaseStruct + val base = spec.encodedType.buildDecoder(concreteType, cb.emb.ecb).apply( + cb, + region, + xRowBuf, + ).asBaseStruct if (insertUID) { cb.assign(rowIdx, rowIdx + 1) val uid = EmitValue.present( - new SStackStructValue(uidSType, Array( - EmitValue.present(partIdx), - EmitValue.present(new SInt64Value(rowIdx))))) + new SStackStructValue( + uidSType, + Array( + EmitValue.present(partIdx), + EmitValue.present(new SInt64Value(rowIdx)), + ), + ) + ) cb.assign(next, base._insert(requestedType, 
uidFieldName -> uid)) } else cb.assign(next, base) @@ -736,7 +977,12 @@ case class PartitionNativeReader(spec: AbstractTypedCodecSpec, uidFieldName: Str def toJValue: JValue = Extraction.decompose(this)(PartitionReader.formats) } -case class PartitionNativeIntervalReader(sm: HailStateManager, tablePath: String, tableSpec: AbstractTableSpec, uidFieldName: String) extends AbstractNativeReader { +case class PartitionNativeIntervalReader( + sm: HailStateManager, + tablePath: String, + tableSpec: AbstractTableSpec, + uidFieldName: String, +) extends AbstractNativeReader { require(tableSpec.indexed) lazy val rowsSpec = tableSpec.rowsSpec @@ -753,20 +999,26 @@ case class PartitionNativeIntervalReader(sm: HailStateManager, tablePath: String cb: EmitCodeBuilder, mb: EmitMethodBuilder[_], context: EmitCode, - requestedType: TStruct): IEmitCode = { + requestedType: TStruct, + ): IEmitCode = { val insertUID: Boolean = requestedType.hasField(uidFieldName) val concreteType: TStruct = if (insertUID) requestedType.deleteKey(uidFieldName) else requestedType - val concreteSType: SBaseStruct = spec.encodedType.decodedSType(concreteType).asInstanceOf[SBaseStruct] + val concreteSType: SBaseStruct = + spec.encodedType.decodedSType(concreteType).asInstanceOf[SBaseStruct] val uidSType: SStackStruct = SStackStruct( TTuple(TInt64, TInt64), - Array(EmitType(SInt64, true), EmitType(SInt64, true))) + Array(EmitType(SInt64, true), EmitType(SInt64, true)), + ) val eltSType: SBaseStruct = if (insertUID) - SInsertFieldsStruct(requestedType, concreteSType, - Array(uidFieldName -> EmitType(uidSType, true))) + SInsertFieldsStruct( + requestedType, + concreteSType, + Array(uidFieldName -> EmitType(uidSType, true)), + ) else concreteSType @@ -776,14 +1028,30 @@ case class PartitionNativeIntervalReader(sm: HailStateManager, tablePath: String val ctx = cb.memoizeField(_ctx, "ctx").asInterval val partitionerLit = partitioner.partitionBoundsIRRepresentation - val partitionerRuntime = cb.emb.addLiteral(cb, partitionerLit.value, VirtualTypeWithReq.fullyOptional(partitionerLit.typ)) + val partitionerRuntime = cb.emb.addLiteral( + cb, + partitionerLit.value, + VirtualTypeWithReq.fullyOptional(partitionerLit.typ), + ) .asIndexable val pathsType = VirtualTypeWithReq.fullyRequired(TArray(TString)) val rowsPath = tableSpec.rowsComponent.absolutePath(tablePath) - val partitionPathsRuntime = cb.memoizeField(mb.addLiteral(cb, rowsSpec.absolutePartPaths(rowsPath).toFastSeq, pathsType), "partitionPathsRuntime") + val partitionPathsRuntime = cb.memoizeField( + mb.addLiteral(cb, rowsSpec.absolutePartPaths(rowsPath).toFastSeq, pathsType), + "partitionPathsRuntime", + ) .asIndexable - val indexPathsRuntime = cb.memoizeField(mb.addLiteral(cb, rowsSpec.partFiles.map(partPath => s"${ rowsPath }/${ indexSpec.relPath }/${ partPath }.idx").toFastSeq, pathsType), "indexPathsRuntime") + val indexPathsRuntime = cb.memoizeField( + mb.addLiteral( + cb, + rowsSpec.partFiles.map(partPath => + s"$rowsPath/${indexSpec.relPath}/$partPath.idx" + ).toFastSeq, + pathsType, + ), + "indexPathsRuntime", + ) .asIndexable val currIdxInPartition = mb.genFieldThisRef[Long]("n_to_read") @@ -810,24 +1078,37 @@ case class PartitionNativeIntervalReader(sm: HailStateManager, tablePath: String override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { - val startBound = ctx.loadStart(cb).get(cb) + val startBound = ctx.loadStart(cb).getOrAssert(cb) val includesStart = ctx.includesStart - val endBound = ctx.loadEnd(cb).get(cb) + val endBound = 
ctx.loadEnd(cb).getOrAssert(cb) val includesEnd = ctx.includesEnd - val (startPart, endPart) = IntervalFunctions.partitionerFindIntervalRange(cb, + val (startPart, endPart) = IntervalFunctions.partitionerFindIntervalRange( + cb, partitionerRuntime, - SStackInterval.construct(EmitValue.present(startBound), EmitValue.present(endBound), includesStart, includesEnd), - -1) - - cb.if_(endPart < startPart, cb._fatal("invalid start/end config - startPartIdx=", - startPartitionIndex.toS, ", endPartIdx=", lastIncludedPartitionIdx.toS)) + SStackInterval.construct( + EmitValue.present(startBound), + EmitValue.present(endBound), + includesStart, + includesEnd, + ), + -1, + ) + + cb.if_( + endPart < startPart, + cb._fatal( + "invalid start/end config - startPartIdx=", + startPartitionIndex.toS, + ", endPartIdx=", + lastIncludedPartitionIdx.toS, + ), + ) cb.assign(startPartitionIndex, startPart) cb.assign(lastIncludedPartitionIdx, endPart - 1) cb.assign(currPartitionIdx, startPartitionIndex) - cb.assign(streamFirst, true) cb.assign(currIdxInPartition, 0L) cb.assign(stopIdxInPartition, 0L) @@ -840,113 +1121,163 @@ case class PartitionNativeIntervalReader(sm: HailStateManager, tablePath: String override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => val Lstart = CodeLabel() cb.define(Lstart) - cb.if_(currIdxInPartition >= stopIdxInPartition, { - cb.if_(currPartitionIdx >= partitioner.numPartitions || currPartitionIdx > lastIncludedPartitionIdx, - cb.goto(LendOfStream)) - - val requiresIndexInit = cb.newLocal[Boolean]("requiresIndexInit") - - cb.if_(streamFirst, { - // if first, reuse open index from previous time the stream was run if possible - // this is a common case if looking up nearby keys - cb.assign(requiresIndexInit, !(indexInitialized && (indexCachedIndex ceq currPartitionIdx))) - }, { - // if not first, then the index must be open to the previous partition and needs to be reinitialized - cb.assign(streamFirst, false) - cb.assign(requiresIndexInit, true) - }) - - cb.if_(requiresIndexInit, { - cb.if_(indexInitialized, { - cb += finalizer.invoke[Unit]("clear") - index.close(cb) - cb += ib.close() - }, { - cb.assign(indexInitialized, true) - }) - cb.assign(indexCachedIndex, currPartitionIdx) - val partPath = partitionPathsRuntime.loadElement(cb, currPartitionIdx).get(cb).asString.loadString(cb) - val idxPath = indexPathsRuntime.loadElement(cb, currPartitionIdx).get(cb).asString.loadString(cb) - index.initialize(cb, idxPath) - cb.assign(ib, spec.buildCodeInputBuffer( - Code.newInstance[ByteTrackingInputStream, InputStream]( - cb.emb.openUnbuffered(partPath, false)))) - index.addToFinalizer(cb, finalizer) - cb += finalizer.invoke[Closeable, Unit]("addCloseable", ib) - }) - - cb.if_(currPartitionIdx ceq lastIncludedPartitionIdx, { - cb.if_(currPartitionIdx ceq startPartitionIndex, { - // query the full interval - val indexResult = index.queryInterval(cb, ctx) - val startIdx = indexResult.loadField(cb, 0) - .get(cb) - .asInt64 - .value - cb.assign(currIdxInPartition, startIdx) - val endIdx = indexResult.loadField(cb, 1) - .get(cb) - .asInt64 - .value - cb.assign(stopIdxInPartition, endIdx) - cb.if_(endIdx > startIdx, { - val firstOffset = indexResult.loadField(cb, 2) - .get(cb) - .asBaseStruct - .loadField(cb, "offset") - .get(cb) - .asInt64 - .value - - cb += ib.seek(firstOffset) - }) - }, { - // read from start of partition to the end interval - - val indexResult = index.queryBound(cb, ctx.loadEnd(cb).get(cb).asBaseStruct, ctx.includesEnd) - val startIdx = 
indexResult.loadField(cb, 0).get(cb).asInt64.value - cb.assign(currIdxInPartition, 0L) - cb.assign(stopIdxInPartition, startIdx) - // no need to seek, starting at beginning of partition - }) - }, { - cb.if_(currPartitionIdx ceq startPartitionIndex, - { - // read from left endpoint until end of partition - val indexResult = index.queryBound(cb, ctx.loadStart(cb).get(cb).asBaseStruct, cb.memoize(!ctx.includesStart)) - val startIdx = indexResult.loadField(cb, 0).get(cb).asInt64.value - - cb.assign(currIdxInPartition, startIdx) - cb.assign(stopIdxInPartition, index.nKeys(cb)) - cb.if_(currIdxInPartition < stopIdxInPartition, { - val firstOffset = indexResult.loadField(cb, 1).get(cb).asBaseStruct - .loadField(cb, "offset").get(cb).asInt64.value - - cb += ib.seek(firstOffset) - }) + cb.if_( + currIdxInPartition >= stopIdxInPartition, { + cb.if_( + currPartitionIdx >= partitioner.numPartitions || currPartitionIdx > lastIncludedPartitionIdx, + cb.goto(LendOfStream), + ) + + val requiresIndexInit = cb.newLocal[Boolean]("requiresIndexInit") + + cb.if_( + streamFirst, + // if first, reuse open index from previous time the stream was run if possible + // this is a common case if looking up nearby keys + cb.assign( + requiresIndexInit, + !(indexInitialized && (indexCachedIndex ceq currPartitionIdx)), + ), { + /* if not first, then the index must be open to the previous partition and needs + * to be reinitialized */ + cb.assign(streamFirst, false) + cb.assign(requiresIndexInit, true) + }, + ) + + cb.if_( + requiresIndexInit, { + cb.if_( + indexInitialized, { + cb += finalizer.invoke[Unit]("clear") + index.close(cb) + cb += ib.close() + }, + cb.assign(indexInitialized, true), + ) + cb.assign(indexCachedIndex, currPartitionIdx) + val partPath = + partitionPathsRuntime.loadElement(cb, currPartitionIdx).getOrAssert( + cb + ).asString.loadString(cb) + val idxPath = indexPathsRuntime.loadElement(cb, currPartitionIdx).getOrAssert( + cb + ).asString.loadString(cb) + index.initialize(cb, idxPath) + cb.assign( + ib, + spec.buildCodeInputBuffer( + Code.newInstance[ByteTrackingInputStream, InputStream]( + cb.emb.openUnbuffered(partPath, false) + ) + ), + ) + index.addToFinalizer(cb, finalizer) + cb += finalizer.invoke[Closeable, Unit]("addCloseable", ib) + }, + ) + + cb.if_( + currPartitionIdx ceq lastIncludedPartitionIdx, { + cb.if_( + currPartitionIdx ceq startPartitionIndex, { + // query the full interval + val indexResult = index.queryInterval(cb, ctx) + val startIdx = indexResult.loadField(cb, 0) + .getOrAssert(cb) + .asInt64 + .value + cb.assign(currIdxInPartition, startIdx) + val endIdx = indexResult.loadField(cb, 1) + .getOrAssert(cb) + .asInt64 + .value + cb.assign(stopIdxInPartition, endIdx) + cb.if_( + endIdx > startIdx, { + val firstOffset = indexResult.loadField(cb, 2) + .getOrAssert(cb) + .asBaseStruct + .loadField(cb, "offset") + .getOrAssert(cb) + .asInt64 + .value + + cb += ib.seek(firstOffset) + }, + ) + }, { + // read from start of partition to the end interval + + val indexResult = + index.queryBound( + cb, + ctx.loadEnd(cb).getOrAssert(cb).asBaseStruct, + ctx.includesEnd, + ) + val startIdx = indexResult.loadField(cb, 0).getOrAssert(cb).asInt64.value + cb.assign(currIdxInPartition, 0L) + cb.assign(stopIdxInPartition, startIdx) + // no need to seek, starting at beginning of partition + }, + ) }, { - // in the middle of a partition run, so read everything - cb.assign(currIdxInPartition, 0L) - cb.assign(stopIdxInPartition, index.nKeys(cb)) - }) - }) - - cb.assign(currPartitionIdx, currPartitionIdx 
+ 1) - cb.goto(Lstart) - }) + cb.if_( + currPartitionIdx ceq startPartitionIndex, { + // read from left endpoint until end of partition + val indexResult = index.queryBound( + cb, + ctx.loadStart(cb).getOrAssert(cb).asBaseStruct, + cb.memoize(!ctx.includesStart), + ) + val startIdx = indexResult.loadField(cb, 0).getOrAssert(cb).asInt64.value + + cb.assign(currIdxInPartition, startIdx) + cb.assign(stopIdxInPartition, index.nKeys(cb)) + cb.if_( + currIdxInPartition < stopIdxInPartition, { + val firstOffset = + indexResult.loadField(cb, 1).getOrAssert(cb).asBaseStruct + .loadField(cb, "offset").getOrAssert(cb).asInt64.value + + cb += ib.seek(firstOffset) + }, + ) + }, { + // in the middle of a partition run, so read everything + cb.assign(currIdxInPartition, 0L) + cb.assign(stopIdxInPartition, index.nKeys(cb)) + }, + ) + }, + ) + + cb.assign(currPartitionIdx, currPartitionIdx + 1) + cb.goto(Lstart) + }, + ) cb.if_(ib.readByte() cne 1, cb._fatal(s"bad buffer state!")) cb.assign(currIdxInPartition, currIdxInPartition + 1L) - val decRow = spec.encodedType.buildDecoder(requestedType, cb.emb.ecb)(cb, region, ib).asBaseStruct - cb.assign(decodedRow, if (insertUID) - decRow.insert(cb, - elementRegion, - eltSType.virtualType.asInstanceOf[TStruct], - uidFieldName -> EmitValue.present(uidSType.fromEmitCodes(cb, - FastSeq( - EmitCode.present(mb, primitive(currPartitionIdx)), - EmitCode.present(mb, primitive(currIdxInPartition)))))) - else decRow) + val decRow = + spec.encodedType.buildDecoder(requestedType, cb.emb.ecb)(cb, region, ib).asBaseStruct + cb.assign( + decodedRow, + if (insertUID) + decRow.insert( + cb, + elementRegion, + eltSType.virtualType.asInstanceOf[TStruct], + uidFieldName -> EmitValue.present(uidSType.fromEmitCodes( + cb, + FastSeq( + EmitCode.present(mb, primitive(currPartitionIdx)), + EmitCode.present(mb, primitive(currIdxInPartition)), + ), + )), + ) + else decRow, + ) cb.goto(LproduceElementDone) } override val element: EmitCode = EmitCode.fromI(mb) { cb => @@ -968,39 +1299,49 @@ case class PartitionNativeReaderIndexed( spec: AbstractTypedCodecSpec, indexSpec: AbstractIndexSpec, key: IndexedSeq[String], - uidFieldName: String + uidFieldName: String, ) extends AbstractNativeReader { def contextType: Type = TStruct( "partitionIndex" -> TInt64, "partitionPath" -> TString, "indexPath" -> TString, - "interval" -> RVDPartitioner.intervalIRRepresentation(spec.encodedVirtualType.asInstanceOf[TStruct].select(key)._1)) + "interval" -> RVDPartitioner.intervalIRRepresentation( + spec.encodedVirtualType.asInstanceOf[TStruct].select(key)._1 + ), + ) def emitStream( ctx: ExecuteContext, cb: EmitCodeBuilder, mb: EmitMethodBuilder[_], context: EmitCode, - requestedType: TStruct): IEmitCode = { + requestedType: TStruct, + ): IEmitCode = { val insertUID: Boolean = requestedType.hasField(uidFieldName) val concreteType: TStruct = if (insertUID) requestedType.deleteKey(uidFieldName) else requestedType - val concreteSType: SBaseStructPointer = spec.encodedType.decodedSType(concreteType).asInstanceOf[SBaseStructPointer] + val concreteSType: SBaseStructPointer = + spec.encodedType.decodedSType(concreteType).asInstanceOf[SBaseStructPointer] val uidSType: SStackStruct = SStackStruct( TTuple(TInt64, TInt64), - Array(EmitType(SInt64, true), EmitType(SInt64, true))) + Array(EmitType(SInt64, true), EmitType(SInt64, true)), + ) val eltSType: SBaseStruct = if (insertUID) - SInsertFieldsStruct(requestedType, concreteSType, - Array(uidFieldName -> EmitType(uidSType, true))) + SInsertFieldsStruct( + requestedType, + 
concreteSType, + Array(uidFieldName -> EmitType(uidSType, true)), + ) else concreteSType val index = new StagedIndexReader(cb.emb, indexSpec.leafCodec, indexSpec.internalNodeCodec) context.toI(cb).map(cb) { case ctxStruct: SBaseStructValue => - val partIdx = cb.memoizeField(ctxStruct.loadField(cb, "partitionIndex").get(cb), "partIdx") + val partIdx = + cb.memoizeField(ctxStruct.loadField(cb, "partitionIndex").getOrAssert(cb), "partIdx") val curIdx = mb.genFieldThisRef[Long]("cur_index") val endIdx = mb.genFieldThisRef[Long]("end_index") val ib = mb.genFieldThisRef[InputBuffer]("buffer") @@ -1016,46 +1357,53 @@ case class PartitionNativeReaderIndexed( override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { val indexPath = ctxStruct .loadField(cb, "indexPath") - .get(cb) + .getOrAssert(cb) .asString .loadString(cb) val partitionPath = ctxStruct .loadField(cb, "partitionPath") - .get(cb) + .getOrAssert(cb) .asString .loadString(cb) val interval = ctxStruct .loadField(cb, "interval") - .get(cb) + .getOrAssert(cb) .asInterval index.initialize(cb, indexPath) val indexResult = index.queryInterval(cb, interval) val startIndex = indexResult.loadField(cb, 0) - .get(cb) + .getOrAssert(cb) .asInt64 .value val endIndex = indexResult.loadField(cb, 1) - .get(cb) + .getOrAssert(cb) .asInt64 .value cb.assign(curIdx, startIndex) cb.assign(endIdx, endIndex) - cb.assign(ib, spec.buildCodeInputBuffer( - Code.newInstance[ByteTrackingInputStream, InputStream]( - cb.emb.openUnbuffered(partitionPath, false)))) - cb.if_(endIndex > startIndex, { - val firstOffset = indexResult.loadField(cb, 2) - .get(cb) - .asBaseStruct - .loadField(cb, "offset") - .get(cb) - .asInt64 - .value - - cb += ib.seek(firstOffset) - }) + cb.assign( + ib, + spec.buildCodeInputBuffer( + Code.newInstance[ByteTrackingInputStream, InputStream]( + cb.emb.openUnbuffered(partitionPath, false) + ) + ), + ) + cb.if_( + endIndex > startIndex, { + val firstOffset = indexResult.loadField(cb, 2) + .getOrAssert(cb) + .asBaseStruct + .loadField(cb, "offset") + .getOrAssert(cb) + .asInt64 + .value + + cb += ib.seek(firstOffset) + }, + ) index.close(cb) } override val elementRegion: Settable[Region] = region @@ -1064,15 +1412,25 @@ case class PartitionNativeReaderIndexed( cb.if_(curIdx >= endIdx, cb.goto(LendOfStream)) val next = ib.readByte() cb.if_(next cne 1, cb._fatal(s"bad buffer state!")) - val base = spec.encodedType.buildDecoder(concreteType, cb.emb.ecb)(cb, region, ib).asBaseStruct + val base = + spec.encodedType.buildDecoder(concreteType, cb.emb.ecb)(cb, region, ib).asBaseStruct if (insertUID) - cb.assign(decodedRow, new SInsertFieldsStructValue( - eltSType.asInstanceOf[SInsertFieldsStruct], - base, - Array(EmitValue.present( - new SStackStructValue(uidSType, Array( - EmitValue.present(partIdx), - EmitValue.present(primitive(curIdx)))))))) + cb.assign( + decodedRow, + new SInsertFieldsStructValue( + eltSType.asInstanceOf[SInsertFieldsStruct], + base, + Array(EmitValue.present( + new SStackStructValue( + uidSType, + Array( + EmitValue.present(partIdx), + EmitValue.present(primitive(curIdx)), + ), + ) + )), + ), + ) else cb.assign(decodedRow, base) cb.assign(curIdx, curIdx + 1L) @@ -1093,13 +1451,14 @@ case class PartitionNativeReaderIndexed( // Result uses the uid field name and values from the right input, and ignores // uids from the left. 
case class PartitionZippedNativeReader(left: PartitionReader, right: PartitionReader) - extends PartitionReader { + extends PartitionReader { def uidFieldName = right.uidFieldName def contextType: Type = TStruct( "leftContext" -> left.contextType, - "rightContext" -> right.contextType) + "rightContext" -> right.contextType, + ) def splitRequestedType(requestedType: TStruct): (TStruct, TStruct) = { val leftStruct = left.fullRowType.deleteKey(left.uidFieldName) @@ -1116,7 +1475,9 @@ case class PartitionZippedNativeReader(left: PartitionReader, right: PartitionRe val lRequired = left.rowRequiredness(lRequested) val rRequired = right.rowRequiredness(rRequested) - RStruct.fromNamesAndTypes(requestedType.fieldNames.map(f => (f, lRequired.fieldType.getOrElse(f, rRequired.fieldType(f))))) + RStruct.fromNamesAndTypes(requestedType.fieldNames.map(f => + (f, lRequired.fieldType.getOrElse(f, rRequired.fieldType(f))) + )) } lazy val fullRowType: TStruct = { @@ -1127,11 +1488,12 @@ case class PartitionZippedNativeReader(left: PartitionReader, right: PartitionRe def toJValue: JValue = Extraction.decompose(this)(PartitionReader.formats) - override def emitStream(ctx: ExecuteContext, + override def emitStream( + ctx: ExecuteContext, cb: EmitCodeBuilder, mb: EmitMethodBuilder[_], context: EmitCode, - requestedType: TStruct + requestedType: TStruct, ): IEmitCode = { val (lRequested, rRequested) = splitRequestedType(requestedType) @@ -1140,7 +1502,6 @@ case class PartitionZippedNativeReader(left: PartitionReader, right: PartitionRe val ctx2 = EmitCode.fromI(cb.emb)(zippedContext.loadField(_, "rightContext")) left.emitStream(ctx, cb, mb, ctx1, lRequested).flatMap(cb) { sstream1 => right.emitStream(ctx, cb, mb, ctx2, rRequested).map(cb) { sstream2 => - val stream1 = sstream1.asStream.getProducer(cb.emb) val stream2 = sstream2.asStream.getProducer(cb.emb) @@ -1196,9 +1557,13 @@ case class PartitionZippedNativeReader(left: PartitionReader, right: PartitionRe } } -case class PartitionZippedIndexedNativeReader(specLeft: AbstractTypedCodecSpec, specRight: AbstractTypedCodecSpec, - indexSpecLeft: AbstractIndexSpec, indexSpecRight: AbstractIndexSpec, - key: IndexedSeq[String], uidFieldName: String +case class PartitionZippedIndexedNativeReader( + specLeft: AbstractTypedCodecSpec, + specRight: AbstractTypedCodecSpec, + indexSpecLeft: AbstractIndexSpec, + indexSpecRight: AbstractIndexSpec, + key: IndexedSeq[String], + uidFieldName: String, ) extends PartitionReader { def contextType: Type = { @@ -1207,7 +1572,9 @@ case class PartitionZippedIndexedNativeReader(specLeft: AbstractTypedCodecSpec, "leftPartitionPath" -> TString, "rightPartitionPath" -> TString, "indexPath" -> TString, - "interval" -> RVDPartitioner.intervalIRRepresentation(specLeft.encodedVirtualType.asInstanceOf[TStruct].select(key)._1) + "interval" -> RVDPartitioner.intervalIRRepresentation( + specLeft.encodedVirtualType.asInstanceOf[TStruct].select(key)._1 + ), ) } @@ -1229,17 +1596,27 @@ case class PartitionZippedIndexedNativeReader(specLeft: AbstractTypedCodecSpec, def rowRequiredness(requestedType: TStruct): RStruct = { val (leftStruct, rightStruct) = splitRequestedTypes(requestedType) val rt = TypeWithRequiredness(requestedType).asInstanceOf[RStruct] - val pt = specLeft.decodedPType(leftStruct).asInstanceOf[PStruct].insertFields(specRight.decodedPType(rightStruct).asInstanceOf[PStruct].fields.map(f => (f.name, f.typ))) + val pt = specLeft.decodedPType(leftStruct).asInstanceOf[PStruct].insertFields( + 
specRight.decodedPType(rightStruct).asInstanceOf[PStruct].fields.map(f => (f.name, f.typ)) + ) rt.fromPType(pt) rt } val uidSType: SStackStruct = SStackStruct( TTuple(TInt64, TInt64, TInt64, TInt64), - Array(EmitType(SInt64, true), EmitType(SInt64, true), EmitType(SInt64, true), EmitType(SInt64, true))) + Array( + EmitType(SInt64, true), + EmitType(SInt64, true), + EmitType(SInt64, true), + EmitType(SInt64, true), + ), + ) def fullRowType: TStruct = - (specLeft.encodedVirtualType.asInstanceOf[TStruct] ++ specRight.encodedVirtualType.asInstanceOf[TStruct]) + (specLeft.encodedVirtualType.asInstanceOf[TStruct] ++ specRight.encodedVirtualType.asInstanceOf[ + TStruct + ]) .insertFields(Array(uidFieldName -> TTuple(TInt64, TInt64, TInt64, TInt64))) def emitStream( @@ -1247,7 +1624,8 @@ case class PartitionZippedIndexedNativeReader(specLeft: AbstractTypedCodecSpec, cb: EmitCodeBuilder, mb: EmitMethodBuilder[_], context: EmitCode, - requestedType: TStruct): IEmitCode = { + requestedType: TStruct, + ): IEmitCode = { val (leftRType, rightRType) = splitRequestedTypes(requestedType) @@ -1256,7 +1634,8 @@ case class PartitionZippedIndexedNativeReader(specLeft: AbstractTypedCodecSpec, val leftOffsetFieldIndex = indexSpecLeft.offsetFieldIndex val rightOffsetFieldIndex = indexSpecRight.offsetFieldIndex - val index = new StagedIndexReader(cb.emb, indexSpecLeft.leafCodec, indexSpecLeft.internalNodeCodec) + val index = + new StagedIndexReader(cb.emb, indexSpecLeft.leafCodec, indexSpecLeft.internalNodeCodec) context.toI(cb).map(cb) { case _ctxStruct: SBaseStructValue => val ctxStruct = cb.memoizeField(_ctxStruct, "ctxStruct").asBaseStruct @@ -1282,76 +1661,97 @@ case class PartitionZippedIndexedNativeReader(specLeft: AbstractTypedCodecSpec, override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { val indexPath = ctxStruct .loadField(cb, "indexPath") - .get(cb) + .getOrAssert(cb) .asString .loadString(cb) val interval = ctxStruct .loadField(cb, "interval") - .get(cb) + .getOrAssert(cb) .asInterval index.initialize(cb, indexPath) val indexResult = index.queryInterval(cb, interval) val startIndex = indexResult.loadField(cb, 0) - .get(cb) + .getOrAssert(cb) .asInt64 .value val endIndex = indexResult.loadField(cb, 1) - .get(cb) + .getOrAssert(cb) .asInt64 .value cb.assign(curIdx, startIndex) cb.assign(endIdx, endIndex) - cb.assign(partIdx, ctxStruct.loadField(cb, "partitionIndex").get(cb).asInt64.value) - cb.assign(leftBuffer, specLeft.buildCodeInputBuffer( - Code.newInstance[ByteTrackingInputStream, InputStream]( - mb.openUnbuffered(ctxStruct.loadField(cb, "leftPartitionPath") - .get(cb) - .asString - .loadString(cb), true)))) - cb.assign(rightBuffer, specRight.buildCodeInputBuffer( - Code.newInstance[ByteTrackingInputStream, InputStream]( - mb.openUnbuffered(ctxStruct.loadField(cb, "rightPartitionPath") - .get(cb) - .asString - .loadString(cb), true)))) - - cb.if_(endIndex > startIndex, { - val leafNode = indexResult.loadField(cb, 2) - .get(cb) - .asBaseStruct - - val leftSeekAddr = leftOffsetFieldIndex match { - case Some(offsetIdx) => - leafNode - .loadField(cb, "annotation") - .get(cb) - .asBaseStruct - .loadField(cb, offsetIdx) - .get(cb) - case None => - leafNode - .loadField(cb, "offset") - .get(cb) - } - cb += leftBuffer.seek(leftSeekAddr.asInt64.value) - - val rightSeekAddr = rightOffsetFieldIndex match { - case Some(offsetIdx) => - leafNode - .loadField(cb, "annotation") - .get(cb) - .asBaseStruct - .loadField(cb, offsetIdx) - .get(cb) - case None => - leafNode - .loadField(cb, 
"offset") - .get(cb) - } - cb += rightBuffer.seek(rightSeekAddr.asInt64.value) - }) + cb.assign( + partIdx, + ctxStruct.loadField(cb, "partitionIndex").getOrAssert(cb).asInt64.value, + ) + cb.assign( + leftBuffer, + specLeft.buildCodeInputBuffer( + Code.newInstance[ByteTrackingInputStream, InputStream]( + mb.openUnbuffered( + ctxStruct.loadField(cb, "leftPartitionPath") + .getOrAssert(cb) + .asString + .loadString(cb), + true, + ) + ) + ), + ) + cb.assign( + rightBuffer, + specRight.buildCodeInputBuffer( + Code.newInstance[ByteTrackingInputStream, InputStream]( + mb.openUnbuffered( + ctxStruct.loadField(cb, "rightPartitionPath") + .getOrAssert(cb) + .asString + .loadString(cb), + true, + ) + ) + ), + ) + + cb.if_( + endIndex > startIndex, { + val leafNode = indexResult.loadField(cb, 2) + .getOrAssert(cb) + .asBaseStruct + + val leftSeekAddr = leftOffsetFieldIndex match { + case Some(offsetIdx) => + leafNode + .loadField(cb, "annotation") + .getOrAssert(cb) + .asBaseStruct + .loadField(cb, offsetIdx) + .getOrAssert(cb) + case None => + leafNode + .loadField(cb, "offset") + .getOrAssert(cb) + } + cb += leftBuffer.seek(leftSeekAddr.asInt64.value) + + val rightSeekAddr = rightOffsetFieldIndex match { + case Some(offsetIdx) => + leafNode + .loadField(cb, "annotation") + .getOrAssert(cb) + .asBaseStruct + .loadField(cb, offsetIdx) + .getOrAssert(cb) + case None => + leafNode + .loadField(cb, "offset") + .getOrAssert(cb) + } + cb += rightBuffer.seek(rightSeekAddr.asInt64.value) + }, + ) index.close(cb) } @@ -1371,13 +1771,23 @@ case class PartitionZippedIndexedNativeReader(specLeft: AbstractTypedCodecSpec, } override val element: EmitCode = EmitCode.fromI(mb) { cb => if (insertUID) { - val uid = SStackStruct.constructFromArgs(cb, region, TTuple(TInt64, TInt64), + val uid = SStackStruct.constructFromArgs( + cb, + region, + TTuple(TInt64, TInt64), EmitCode.present(mb, primitive(partIdx)), - EmitCode.present(mb, primitive(cb.memoize(curIdx.get - 1L)))) + EmitCode.present(mb, primitive(cb.memoize(curIdx.get - 1L))), + ) val merged = SBaseStruct.merge(cb, leftValue.asBaseStruct, rightValue.asBaseStruct) - IEmitCode.present(cb, merged._insert(requestedType, uidFieldName -> EmitValue.present(uid))) + IEmitCode.present( + cb, + merged._insert(requestedType, uidFieldName -> EmitValue.present(uid)), + ) } else { - IEmitCode.present(cb, SBaseStruct.merge(cb, leftValue.asBaseStruct, rightValue.asBaseStruct)) + IEmitCode.present( + cb, + SBaseStruct.merge(cb, leftValue.asBaseStruct, rightValue.asBaseStruct), + ) } } @@ -1395,17 +1805,19 @@ case class PartitionZippedIndexedNativeReader(specLeft: AbstractTypedCodecSpec, case class TableNativeReaderParameters( path: String, - options: Option[NativeReaderOptions]) + options: Option[NativeReaderOptions], +) class TableNativeReader( val params: TableNativeReaderParameters, - val spec: AbstractTableSpec + val spec: AbstractTableSpec, ) extends TableReaderWithExtraUID { def pathsUsed: Seq[String] = Array(params.path) val filterIntervals: Boolean = params.options.map(_.filterIntervals).getOrElse(false) - def partitionCounts: Option[IndexedSeq[Long]] = if (params.options.isDefined) None else Some(spec.partitionCounts) + def partitionCounts: Option[IndexedSeq[Long]] = + if (params.options.isDefined) None else Some(spec.partitionCounts) override def isDistinctlyKeyed: Boolean = spec.isDistinctlyKeyed @@ -1413,14 +1825,16 @@ class TableNativeReader( def fullTypeWithoutUIDs = spec.table_type - override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: 
TableType): VirtualTypeWithReq = + override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = VirtualTypeWithReq(tcoerce[PStruct](spec.rowsComponent.rvdSpec(ctx.fs, params.path) .typedCodecSpec.encodedType.decodedPType(requestedType.rowType))) protected def uidRequiredness: VirtualTypeWithReq = VirtualTypeWithReq(PCanonicalTuple(true, PInt64Required, PInt64Required)) - override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = + override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = VirtualTypeWithReq(tcoerce[PStruct](spec.globalsComponent.rvdSpec(ctx.fs, params.path) .typedCodecSpec.encodedType.decodedPType(requestedType.globalType))) @@ -1429,7 +1843,8 @@ class TableNativeReader( decomposeWithName(params, "TableNativeReader") } - override def renderShort(): String = s"(TableNativeReader ${ params.path } ${ params.options.map(_.renderShort()).getOrElse("") })" + override def renderShort(): String = + s"(TableNativeReader ${params.path} ${params.options.map(_.renderShort()).getOrElse("")})" override def hashCode(): Int = params.hashCode() @@ -1438,7 +1853,7 @@ class TableNativeReader( case _ => false } - override def toString(): String = s"TableNativeReader(${ params })" + override def toString(): String = s"TableNativeReader($params)" override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = { val globalsSpec = spec.globalsSpec @@ -1446,10 +1861,15 @@ class TableNativeReader( assert(!requestedGlobalsType.hasField(uidFieldName)) ArrayRef( ToArray(ReadPartition( - MakeStruct(Array("partitionIndex" -> I64(0), "partitionPath" -> Str(globalsSpec.absolutePartPaths(globalsPath).head))), + MakeStruct(Array( + "partitionIndex" -> I64(0), + "partitionPath" -> Str(globalsSpec.absolutePartPaths(globalsPath).head), + )), requestedGlobalsType, - PartitionNativeReader(globalsSpec.typedCodecSpec, uidFieldName))), - 0) + PartitionNativeReader(globalsSpec.typedCodecSpec, uidFieldName), + )), + 0, + ) } override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { @@ -1457,9 +1877,18 @@ class TableNativeReader( val rowsSpec = spec.rowsSpec val specPart = rowsSpec.partitioner(ctx.stateManager) val partitioner = if (filterIntervals) - params.options.map(opts => RVDPartitioner.union(ctx.stateManager, specPart.kType, opts.intervals, specPart.kType.size - 1)) + params.options.map(opts => + RVDPartitioner.union( + ctx.stateManager, + specPart.kType, + opts.intervals, + specPart.kType.size - 1, + ) + ) else - params.options.map(opts => new RVDPartitioner(ctx.stateManager, specPart.kType, opts.intervals)) + params.options.map(opts => + new RVDPartitioner(ctx.stateManager, specPart.kType, opts.intervals) + ) // If the data on disk already has a uidFieldName field, we should read it // as is. 
Do this by passing a dummy uidFieldName to the rows component, @@ -1469,7 +1898,14 @@ class TableNativeReader( else uidFieldName - spec.rowsSpec.readTableStage(ctx, spec.rowsComponent.absolutePath(params.path), requestedType, requestedUIDFieldName, partitioner, filterIntervals).apply(globals) + spec.rowsSpec.readTableStage( + ctx, + spec.rowsComponent.absolutePath(params.path), + requestedType, + requestedUIDFieldName, + partitioner, + filterIntervals, + ).apply(globals) } } @@ -1478,27 +1914,34 @@ case class TableNativeZippedReader( pathRight: String, options: Option[NativeReaderOptions], specLeft: AbstractTableSpec, - specRight: AbstractTableSpec + specRight: AbstractTableSpec, ) extends TableReaderWithExtraUID { def pathsUsed: Seq[String] = FastSeq(pathLeft, pathRight) - override def renderShort(): String = s"(TableNativeZippedReader $pathLeft $pathRight ${ options.map(_.renderShort()).getOrElse("") })" + override def renderShort(): String = + s"(TableNativeZippedReader $pathLeft $pathRight ${options.map(_.renderShort()).getOrElse("")})" private lazy val filterIntervals = options.exists(_.filterIntervals) private def intervals = options.map(_.intervals) - require((specLeft.table_type.rowType.fieldNames ++ specRight.table_type.rowType.fieldNames).areDistinct()) + require( + (specLeft.table_type.rowType.fieldNames ++ specRight.table_type.rowType.fieldNames).areDistinct() + ) + require(specRight.table_type.key.isEmpty) require(specLeft.partitionCounts sameElements specRight.partitionCounts) require(specLeft.version == specRight.version) - def partitionCounts: Option[IndexedSeq[Long]] = if (intervals.isEmpty) Some(specLeft.partitionCounts) else None + def partitionCounts: Option[IndexedSeq[Long]] = + if (intervals.isEmpty) Some(specLeft.partitionCounts) else None override def uidType = TTuple(TInt64, TInt64) override def fullTypeWithoutUIDs: TableType = specLeft.table_type.copy( - rowType = specLeft.table_type.rowType ++ specRight.table_type.rowType) + rowType = specLeft.table_type.rowType ++ specRight.table_type.rowType + ) + private val leftFieldSet = specLeft.table_type.rowType.fieldNames.toSet private val rightFieldSet = specRight.table_type.rowType.fieldNames.toSet @@ -1516,24 +1959,40 @@ case class TableNativeZippedReader( tcoerce[PStruct](specRight.rowsComponent.rvdSpec(ctx.fs, pathRight) .typedCodecSpec.encodedType.decodedPType(rightRType)) - override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = - VirtualTypeWithReq(fieldInserter(ctx, leftPType(ctx, leftRType(requestedType.rowType)), - rightPType(ctx, rightRType(requestedType.rowType)))._1) + override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = + VirtualTypeWithReq(fieldInserter( + ctx, + leftPType(ctx, leftRType(requestedType.rowType)), + rightPType(ctx, rightRType(requestedType.rowType)), + )._1) override def uidRequiredness: VirtualTypeWithReq = VirtualTypeWithReq(PCanonicalTuple(true, PInt64Required, PInt64Required)) - override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = + override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = VirtualTypeWithReq(specLeft.globalsComponent.rvdSpec(ctx.fs, pathLeft) .typedCodecSpec.encodedType.decodedPType(requestedType.globalType)) - def fieldInserter(ctx: ExecuteContext, pLeft: PStruct, pRight: PStruct): (PStruct, (HailClassLoader, FS, HailTaskContext, Region) => AsmFunction3RegionLongLongLong) 
= { - val (Some(PTypeReferenceSingleCodeType(t: PStruct)), mk) = ir.Compile[AsmFunction3RegionLongLongLong](ctx, - FastSeq("left" -> SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(pLeft)), "right" -> SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(pRight))), - FastSeq(typeInfo[Region], LongInfo, LongInfo), LongInfo, - InsertFields(Ref("left", pLeft.virtualType), - pRight.fieldNames.map(f => - f -> GetField(Ref("right", pRight.virtualType), f)))) + def fieldInserter(ctx: ExecuteContext, pLeft: PStruct, pRight: PStruct) + : (PStruct, (HailClassLoader, FS, HailTaskContext, Region) => AsmFunction3RegionLongLongLong) = { + val (Some(PTypeReferenceSingleCodeType(t: PStruct)), mk) = + ir.Compile[AsmFunction3RegionLongLongLong]( + ctx, + FastSeq( + "left" -> SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(pLeft)), + "right" -> SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(pRight)), + ), + FastSeq(typeInfo[Region], LongInfo, LongInfo), + LongInfo, + InsertFields( + Ref("left", pLeft.virtualType), + pRight.fieldNames.map(f => + f -> GetField(Ref("right", pRight.virtualType), f) + ), + ), + ) (t, mk) } @@ -1542,10 +2001,15 @@ case class TableNativeZippedReader( val globalsPath = specLeft.globalsComponent.absolutePath(pathLeft) ArrayRef( ToArray(ReadPartition( - MakeStruct(Array("partitionIndex" -> I64(0), "partitionPath" -> Str(globalsSpec.absolutePartPaths(globalsPath).head))), + MakeStruct(Array( + "partitionIndex" -> I64(0), + "partitionPath" -> Str(globalsSpec.absolutePartPaths(globalsPath).head), + )), requestedGlobalsType, - PartitionNativeReader(globalsSpec.typedCodecSpec, uidFieldName))), - 0) + PartitionNativeReader(globalsSpec.typedCodecSpec, uidFieldName), + )), + 0, + ) } override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { @@ -1553,29 +2017,51 @@ case class TableNativeZippedReader( val rowsSpec = specLeft.rowsSpec val specPart = rowsSpec.partitioner(ctx.stateManager) val partitioner = if (filterIntervals) - options.map(opts => RVDPartitioner.union(ctx.stateManager, specPart.kType, opts.intervals, specPart.kType.size - 1)) + options.map(opts => + RVDPartitioner.union( + ctx.stateManager, + specPart.kType, + opts.intervals, + specPart.kType.size - 1, + ) + ) else options.map(opts => new RVDPartitioner(ctx.stateManager, specPart.kType, opts.intervals)) - AbstractRVDSpec.readZippedLowered(ctx, - specLeft.rowsSpec, specRight.rowsSpec, - pathLeft + "/rows", pathRight + "/rows", - partitioner, filterIntervals, - requestedType.rowType, requestedType.key, uidFieldName + AbstractRVDSpec.readZippedLowered( + ctx, + specLeft.rowsSpec, + specRight.rowsSpec, + pathLeft + "/rows", + pathRight + "/rows", + partitioner, + filterIntervals, + requestedType.rowType, + requestedType.key, + uidFieldName, ).apply(globals) } } object TableFromBlockMatrixNativeReader { - def apply(fs: FS, params: TableFromBlockMatrixNativeReaderParameters): TableFromBlockMatrixNativeReader = { + def apply(fs: FS, params: TableFromBlockMatrixNativeReaderParameters) + : TableFromBlockMatrixNativeReader = { val metadata: BlockMatrixMetadata = BlockMatrix.readMetadata(fs, params.path) TableFromBlockMatrixNativeReader(params, metadata) } - def apply(fs: FS, path: String, nPartitions: Option[Int] = None, maximumCacheMemoryInBytes: Option[Int] = None): TableFromBlockMatrixNativeReader = - TableFromBlockMatrixNativeReader(fs, TableFromBlockMatrixNativeReaderParameters(path, nPartitions, maximumCacheMemoryInBytes)) + def apply( + fs: FS, + path: String, + 
nPartitions: Option[Int] = None, + maximumCacheMemoryInBytes: Option[Int] = None, + ): TableFromBlockMatrixNativeReader = + TableFromBlockMatrixNativeReader( + fs, + TableFromBlockMatrixNativeReaderParameters(path, nPartitions, maximumCacheMemoryInBytes), + ) def fromJValue(fs: FS, jv: JValue): TableFromBlockMatrixNativeReader = { implicit val formats: Formats = TableReader.formats @@ -1584,11 +2070,15 @@ object TableFromBlockMatrixNativeReader { } } -case class TableFromBlockMatrixNativeReaderParameters(path: String, nPartitions: Option[Int], maximumCacheMemoryInBytes: Option[Int]) +case class TableFromBlockMatrixNativeReaderParameters( + path: String, + nPartitions: Option[Int], + maximumCacheMemoryInBytes: Option[Int], +) case class TableFromBlockMatrixNativeReader( params: TableFromBlockMatrixNativeReaderParameters, - metadata: BlockMatrixMetadata + metadata: BlockMatrixMetadata, ) extends TableReaderWithExtraUID { def pathsUsed: Seq[String] = FastSeq(params.path) @@ -1601,49 +2091,62 @@ case class TableFromBlockMatrixNativeReader( start until end } - override def partitionCounts: Option[IndexedSeq[Long]] = { + override def partitionCounts: Option[IndexedSeq[Long]] = Some(partitionRanges.map(r => r.end - r.start)) - } override def uidType = TInt64 override def fullTypeWithoutUIDs: TableType = TableType( TStruct("row_idx" -> TInt64, "entries" -> TArray(TFloat64)), Array("row_idx"), - TStruct.empty) + TStruct.empty, + ) - override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = + override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = VirtualTypeWithReq(PType.canonical(requestedType.rowType).setRequired(true)) override def uidRequiredness: VirtualTypeWithReq = VirtualTypeWithReq(PInt64Required) - override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = + override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = VirtualTypeWithReq(PCanonicalStruct.empty(required = true)) - override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { + override def toExecuteIntermediate( + ctx: ExecuteContext, + requestedType: TableType, + dropRows: Boolean, + ): TableExecuteIntermediate = { assert(!dropRows) val rowsRDD = new BlockMatrixReadRowBlockedRDD( - ctx.fsBc, params.path, partitionRanges, requestedType.rowType, metadata, - maybeMaximumCacheMemoryInBytes = params.maximumCacheMemoryInBytes) + ctx.fsBc, + params.path, + partitionRanges, + requestedType.rowType, + metadata, + maybeMaximumCacheMemoryInBytes = params.maximumCacheMemoryInBytes, + ) - val partitionBounds = partitionRanges.map { r => Interval(Row(r.start), Row(r.end), true, false) } + val partitionBounds = partitionRanges.map { r => + Interval(Row(r.start), Row(r.end), true, false) + } val partitioner = new RVDPartitioner(ctx.stateManager, fullType.keyType, partitionBounds) val rowTyp = PType.canonical(requestedType.rowType, required = true).asInstanceOf[PStruct] - val rvd = RVD(RVDType(rowTyp, fullType.key.filter(rowTyp.hasField)), partitioner, ContextRDD(rowsRDD)) + val rvd = + RVD(RVDType(rowTyp, fullType.key.filter(rowTyp.hasField)), partitioner, ContextRDD(rowsRDD)) TableExecuteIntermediate(TableValue(ctx, requestedType, BroadcastRow.empty(ctx), rvd)) } override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = - throw new LowererUnsupportedOperation(s"${ 
getClass.getSimpleName }.lower not implemented") + throw new LowererUnsupportedOperation(s"${getClass.getSimpleName}.lower not implemented") override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = - throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented") + throw new LowererUnsupportedOperation(s"${getClass.getSimpleName}.lowerGlobals not implemented") - override def toJValue: JValue = { + override def toJValue: JValue = decomposeWithName(params, "TableFromBlockMatrixNativeReader")(TableReader.formats) - } def renderShort(): String = defaultRender() } @@ -1652,23 +2155,25 @@ object TableRead { def native(fs: FS, path: String, uidField: Boolean = false): TableRead = { val tr = TableNativeReader(fs, TableNativeReaderParameters(path, None)) val requestedType = if (uidField) - tr.fullType + tr.fullType else tr.fullType.copy( - rowType = tr.fullType.rowType.deleteKey(TableReader.uidFieldName)) + rowType = tr.fullType.rowType.deleteKey(TableReader.uidFieldName) + ) TableRead(requestedType, false, tr) } } case class TableRead(typ: TableType, dropRows: Boolean, tr: TableReader) extends TableIR { - try { + try assert(PruneDeadFields.isSupertype(typ, tr.fullType)) - } catch { + catch { case e: Throwable => fatal(s"bad type:\n full type: ${tr.fullType}\n requested: $typ\n reader: $tr", e) } - override def partitionCounts: Option[IndexedSeq[Long]] = if (dropRows) Some(FastSeq(0L)) else tr.partitionCounts + override def partitionCounts: Option[IndexedSeq[Long]] = + if (dropRows) Some(FastSeq(0L)) else tr.partitionCounts def isDistinctlyKeyed: Boolean = tr.isDistinctlyKeyed @@ -1681,7 +2186,8 @@ case class TableRead(typ: TableType, dropRows: Boolean, tr: TableReader) extends TableRead(typ, dropRows, tr) } - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = tr.toExecuteIntermediate(ctx, typ, dropRows) } @@ -1703,17 +2209,21 @@ case class TableParallelize(rowsAndGlobal: IR, nPartitions: Option[Int] = None) lazy val typ: TableType = { def rowsType = rowsAndGlobal.typ.asInstanceOf[TStruct].fieldType("rows").asInstanceOf[TArray] - def globalsType = rowsAndGlobal.typ.asInstanceOf[TStruct].fieldType("global").asInstanceOf[TStruct] + def globalsType = + rowsAndGlobal.typ.asInstanceOf[TStruct].fieldType("global").asInstanceOf[TStruct] TableType( rowsType.elementType.asInstanceOf[TStruct], FastSeq(), - globalsType) + globalsType, + ) } - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { - val (ptype: PStruct, res) = CompileAndEvaluate._apply(ctx, rowsAndGlobal, optimize = false) match { - case Right((t, off)) => (t.fields(0).typ, t.loadField(off, 0)) - } + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { + val (ptype: PStruct, res) = + CompileAndEvaluate._apply(ctx, rowsAndGlobal, optimize = false) match { + case Right((t, off)) => (t.fields(0).typ, t.loadField(off, 0)) + } val globalsT = ptype.types(1).setRequired(true).asInstanceOf[PStruct] if (ptype.isFieldMissing(res, 1)) @@ -1747,8 +2257,11 @@ case class TableParallelize(rowsAndGlobal: IR, nPartitions: Option[Int] = None) } val (resultRowType: PStruct, makeDec) = spec.buildDecoder(ctx, typ.rowType) - assert(resultRowType.virtualType == typ.rowType, s"typ mismatch:" + - s"\n res=${ resultRowType.virtualType }\n 
typ=${ typ.rowType }") + assert( + resultRowType.virtualType == typ.rowType, + s"typ mismatch:" + + s"\n res=${resultRowType.virtualType}\n typ=${typ.rowType}", + ) log.info(s"parallelized $nRows rows in $nSplits partitions") @@ -1758,29 +2271,26 @@ case class TableParallelize(rowsAndGlobal: IR, nPartitions: Option[Int] = None) val bais = new ByteArrayDecoder(theHailClassLoaderForSparkWorkers, makeDec) bais.set(arr) Iterator.range(0, nRowPartition) - .map { _ => - bais.readValue(ctx.region) - } + .map(_ => bais.readValue(ctx.region)) } } new TableValueIntermediate(TableValue(ctx, typ, globals, RVD.unkeyed(resultRowType, rvd))) } } -/** - * Change the table to have key 'keys'. +/** Change the table to have key 'keys'. * - * Let n be the longest common prefix of 'keys' and the old key, i.e. the - * number of key fields that are not being changed. - * - If 'isSorted', then 'child' must already be sorted by 'keys', and n must - * not be zero. Thus, if 'isSorted', TableKeyBy will not shuffle or scan. - * The new partitioner will be the old one with partition bounds truncated - * to length n. - * - If n = 'keys.length', i.e. we are simply shortening the key, do nothing - * but change the table type to the new key. 'isSorted' is ignored. - * - Otherwise, if 'isSorted' is false and n < 'keys.length', then shuffle. + * Let n be the longest common prefix of 'keys' and the old key, i.e. the number of key fields that + * are not being changed. + * - If 'isSorted', then 'child' must already be sorted by 'keys', and n must not be zero. Thus, + * if 'isSorted', TableKeyBy will not shuffle or scan. The new partitioner will be the old one + * with partition bounds truncated to length n. + * - If n = 'keys.length', i.e. we are simply shortening the key, do nothing but change the table + * type to the new key. 'isSorted' is ignored. + * - Otherwise, if 'isSorted' is false and n < 'keys.length', then shuffle. */ -case class TableKeyBy(child: TableIR, keys: IndexedSeq[String], isSorted: Boolean = false) extends TableIR { +case class TableKeyBy(child: TableIR, keys: IndexedSeq[String], isSorted: Boolean = false) + extends TableIR { override def typecheck(): Unit = { val fields = child.typ.rowType.fieldNames.toSet assert(keys.forall(fields.contains), s"${keys.filter(k => !fields.contains(k)).mkString(", ")}") @@ -1799,33 +2309,39 @@ case class TableKeyBy(child: TableIR, keys: IndexedSeq[String], isSorted: Boolea TableKeyBy(newChildren(0).asInstanceOf[TableIR], keys, isSorted) } - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val tv = child.execute(ctx, r).asTableValue(ctx) TableValueIntermediate(tv.copy(typ = typ, rvd = tv.rvd.enforceKey(ctx, keys, isSorted))) } } -/** - * Generate a table from the elementwise application of a body IR to a stream of `contexts`. - * - * @param contexts IR of type TStream[Any] whose elements are downwardly exposed to `body` as `cname`. - * @param globals IR of type TStruct, downwardly exposed to `body` as `gname`. - * @param cname Name of free variable in `body` referencing elements of `contexts`. - * @param gname Name of free variable in `body` referencing `globals`. - * @param body IR of type TStream[TStruct] that generates the rows of the table for each - * element in `contexts`, optionally referencing free variables Ref(cname) and - * Ref(gname). 
- * @param partitioner - * @param errorId Identifier tracing location in Python source that created this node - */ -case class TableGen(contexts: IR, - globals: IR, - cname: String, - gname: String, - body: IR, - partitioner: RVDPartitioner, - errorId: Int = ErrorIDs.NO_ERROR - ) extends TableIR { +/** Generate a table from the elementwise application of a body IR to a stream of `contexts`. + * + * @param contexts + * IR of type TStream[Any] whose elements are downwardly exposed to `body` as `cname`. + * @param globals + * IR of type TStruct, downwardly exposed to `body` as `gname`. + * @param cname + * Name of free variable in `body` referencing elements of `contexts`. + * @param gname + * Name of free variable in `body` referencing `globals`. + * @param body + * IR of type TStream[TStruct] that generates the rows of the table for each element in + * `contexts`, optionally referencing free variables Ref(cname) and Ref(gname). + * @param partitioner + * @param errorId + * Identifier tracing location in Python source that created this node + */ +case class TableGen( + contexts: IR, + globals: IR, + cname: String, + gname: String, + body: IR, + partitioner: RVDPartitioner, + errorId: Int = ErrorIDs.NO_ERROR, +) extends TableIR { override def typecheck(): Unit = { TypeCheck.coerce[TStream]("contexts", contexts.typ) @@ -1838,15 +2354,15 @@ case class TableGen(contexts: IR, s"""'partitioner': key type contains fields absent from row type | Key type: ${partitioner.kType} | Row type: $rowType""".stripMargin - ) + ) } private def globalType = TypeCheck.coerce[TStruct]("globals", globals.typ) private def rowType = { - val bodyType = TypeCheck.coerce[TStream]( "body", body.typ) - TypeCheck.coerce[TStruct]( "body.elementType", bodyType.elementType) + val bodyType = TypeCheck.coerce[TStream]("body", body.typ) + TypeCheck.coerce[TStruct]("body.elementType", bodyType.elementType) } override lazy val typ: TableType = @@ -1863,8 +2379,14 @@ case class TableGen(contexts: IR, override def childrenSeq: IndexedSeq[BaseIR] = FastSeq(contexts, globals, body) - override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = - new TableStageIntermediate(LowerTableIR.applyTable(this, DArrayLowering.All, ctx, LoweringAnalyses(this, ctx))) + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = + new TableStageIntermediate(LowerTableIR.applyTable( + this, + DArrayLowering.All, + ctx, + LoweringAnalyses(this, ctx), + )) } case class TableRange(n: Int, nPartitions: Int) extends TableIR { @@ -1887,22 +2409,30 @@ case class TableRange(n: Int, nPartitions: Int) extends TableIR { val typ: TableType = TableType( TStruct("idx" -> TInt32), Array("idx"), - TStruct.empty) + TStruct.empty, + ) - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val localRowType = PCanonicalStruct(true, "idx" -> PInt32Required) val localPartCounts = partCounts val partStarts = partCounts.scanLeft(0)(_ + _) - new TableValueIntermediate(TableValue(ctx, typ, + new TableValueIntermediate(TableValue( + ctx, + typ, BroadcastRow.empty(ctx), new RVD( RVDType(localRowType, Array("idx")), - new RVDPartitioner(ctx.stateManager, Array("idx"), typ.rowType, + new RVDPartitioner( + ctx.stateManager, + Array("idx"), + typ.rowType, Array.tabulate(nPartitionsAdj) { i => val start = partStarts(i) val end 
= partStarts(i + 1) Interval(Row(start), Row(end), includesStart = true, includesEnd = false) - }), + }, + ), ContextRDD.parallelize(Range(0, nPartitionsAdj), nPartitionsAdj) .cmapPartitionsWithIndex { case (i, ctx, _) => val region = ctx.region @@ -1915,7 +2445,9 @@ case class TableRange(n: Int, nPartitions: Int) extends TableIR { Region.storeInt(localRowType.fieldOffset(off, 0), j) off } - }))) + }, + ), + )) } } @@ -1931,7 +2463,8 @@ case class TableFilter(child: TableIR, pred: IR) extends TableIR { TableFilter(newChildren(0).asInstanceOf[TableIR], newChildren(1).asInstanceOf[IR]) } - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val tv = child.execute(ctx, r).asTableValue(ctx) if (pred == True()) @@ -1941,13 +2474,20 @@ case class TableFilter(child: TableIR, pred: IR) extends TableIR { val (Some(BooleanSingleCodeType), f) = ir.Compile[AsmFunction3RegionLongLongBoolean]( ctx, - FastSeq(("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.rvd.rowPType))), - ("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.globals.t)))), - FastSeq(classInfo[Region], LongInfo, LongInfo), BooleanInfo, - Coalesce(FastSeq(pred, False()))) + FastSeq( + ("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.rvd.rowPType))), + ("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.globals.t))), + ), + FastSeq(classInfo[Region], LongInfo, LongInfo), + BooleanInfo, + Coalesce(FastSeq(pred, False())), + ) new TableValueIntermediate( - tv.filterWithPartitionOp(ctx.theHailClassLoader, ctx.fsBc, f)((rowF, ctx, ptr, globalPtr) => rowF(ctx.region, ptr, globalPtr))) + tv.filterWithPartitionOp(ctx.theHailClassLoader, ctx.fsBc, f)((rowF, ctx, ptr, globalPtr) => + rowF(ctx.region, ptr, globalPtr) + ) + ) } } @@ -1976,7 +2516,8 @@ trait TableSubset extends TableIR { case None => Some(n) } - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val prev = child.execute(ctx, r).asTableValue(ctx) new TableValueIntermediate(prev.copy(rvd = subsetKind match { case TableSubset.HEAD => prev.rvd.head(n, child.partitionCounts) @@ -2023,7 +2564,8 @@ case class TableRepartition(child: TableIR, n: Int, strategy: Int) extends Table TableRepartition(newChild, n, strategy) } - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val prev = child.execute(ctx, r).asTableValue(ctx) val rvd = strategy match { case RepartitionStrategy.SHUFFLE => prev.rvd.coalesce(ctx, n, shuffle = true) @@ -2040,26 +2582,22 @@ object TableJoin { TableJoin(left, right, joinType, left.typ.key.length) } -/** - * Suppose 'left' has key [l_1, ..., l_n] and 'right' has key [r_1, ..., r_m]. - * Then [l_1, ..., l_j] and [r_1, ..., r_j] must have the same type, where - * j = 'joinKey'. TableJoin computes the join of 'left' and 'right' along this - * common prefix of their keys, returning a table with key - * [l_1, ..., l_j, l_{j+1}, ..., l_n, r_{j+1}, ..., r_m]. +/** Suppose 'left' has key [l_1, ..., l_n] and 'right' has key [r_1, ..., r_m]. 
Then [l_1, ..., l_j] + * and [r_1, ..., r_j] must have the same type, where j = 'joinKey'. TableJoin computes the join of + * 'left' and 'right' along this common prefix of their keys, returning a table with key [l_1, ..., + * l_j, l_{j+1}, ..., l_n, r_{j+1}, ..., r_m]. * - * WARNING: If 'left' has any duplicate (full) key [k_1, ..., k_n], and j < m, - * and 'right' has multiple rows with the corresponding join key - * [k_1, ..., k_j] but distinct full keys, then the resulting table will have - * out-of-order keys. To avoid this, ensure one of the following: - * * j == m - * * 'left' has distinct keys - * * 'right' has distinct join keys (length j prefix), or at least no + * WARNING: If 'left' has any duplicate (full) key [k_1, ..., k_n], and j < m, and 'right' has + * multiple rows with the corresponding join key [k_1, ..., k_j] but distinct full keys, then the + * resulting table will have out-of-order keys. To avoid this, ensure one of the following: * j == + * m * 'left' has distinct keys * 'right' has distinct join keys (length j prefix), or at least no * distinct keys with the same join key. */ case class TableJoin(left: TableIR, right: TableIR, joinType: String, joinKey: Int) - extends TableIR { + extends TableIR { require(joinKey >= 0) + require(joinType == "inner" || joinType == "left" || joinType == "right" || @@ -2068,11 +2606,12 @@ case class TableJoin(left: TableIR, right: TableIR, joinType: String, joinKey: I override def typecheck(): Unit = { assert(left.typ.key.length >= joinKey) assert(right.typ.key.length >= joinKey) - assert(left.typ.keyType.truncate(joinKey) isIsomorphicTo right.typ.keyType.truncate(joinKey)) + assert(left.typ.keyType.truncate(joinKey) isJoinableWith right.typ.keyType.truncate(joinKey)) assert( left.typ.globalType.fieldNames.toSet .intersect(right.typ.globalType.fieldNames.toSet) - .isEmpty) + .isEmpty + ) } val childrenSeq: IndexedSeq[BaseIR] = Array(left, right) @@ -2088,10 +2627,14 @@ case class TableJoin(left: TableIR, right: TableIR, joinType: String, joinKey: I val leftKeyType = TableType.keyType(leftRowType, leftKey) val leftValueType = TableType.valueType(leftRowType, leftKey) val rightValueType = TableType.valueType(rightRowType, rightKey) - if (leftValueType.fieldNames.toSet - .intersect(rightValueType.fieldNames.toSet) - .nonEmpty) - throw new RuntimeException(s"invalid join: \n left value: $leftValueType\n right value: $rightValueType") + if ( + leftValueType.fieldNames.toSet + .intersect(rightValueType.fieldNames.toSet) + .nonEmpty + ) + throw new RuntimeException( + s"invalid join: \n left value: $leftValueType\n right value: $rightValueType" + ) val newRowType = leftKeyType ++ leftValueType ++ rightValueType val newGlobalType = left.typ.globalType ++ right.typ.globalType @@ -2106,10 +2649,12 @@ case class TableJoin(left: TableIR, right: TableIR, joinType: String, joinKey: I newChildren(0).asInstanceOf[TableIR], newChildren(1).asInstanceOf[TableIR], joinType, - joinKey) + joinKey, + ) } - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val leftTV = left.execute(ctx, r).asTableStage(ctx) val rightTV = right.execute(ctx, r).asTableStage(ctx) TableExecuteIntermediate(LowerTableIRHelpers.lowerTableJoin(ctx, r, this, leftTV, rightTV)) @@ -2120,7 +2665,7 @@ case class TableIntervalJoin( left: TableIR, right: TableIR, root: String, - product: Boolean + product: Boolean, ) 
extends TableIR { lazy val childrenSeq: IndexedSeq[BaseIR] = Array(left, right) @@ -2132,11 +2677,17 @@ case class TableIntervalJoin( } override def copy(newChildren: IndexedSeq[BaseIR]): TableIR = - TableIntervalJoin(newChildren(0).asInstanceOf[TableIR], newChildren(1).asInstanceOf[TableIR], root, product) + TableIntervalJoin( + newChildren(0).asInstanceOf[TableIR], + newChildren(1).asInstanceOf[TableIR], + root, + product, + ) override def partitionCounts: Option[IndexedSeq[Long]] = left.partitionCounts - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val leftValue = left.execute(ctx, r).asTableValue(ctx) val rightValue = right.execute(ctx, r).asTableValue(ctx) @@ -2151,50 +2702,62 @@ case class TableIntervalJoin( if (product) { val joiner = (rightPType: PStruct) => { val leftRowType = leftRVDType.rowType - val newRowType = leftRowType.appendKey(localRoot, PCanonicalArray(rightPType.selectFields(rightValueFields))) - (RVDType(newRowType, localKey), (_: RVDContext, it: Iterator[Muple[RegionValue, Iterable[RegionValue]]]) => { - val rvb = new RegionValueBuilder(sm) - val rv2 = RegionValue() - it.map { case Muple(rv, is) => - rvb.set(rv.region) - rvb.start(newRowType) - rvb.startStruct() - rvb.addAllFields(leftRowType, rv) - rvb.startArray(is.size) - is.foreach(i => rvb.selectRegionValue(rightPType, rightRVDType.valueFieldIdx, i)) - rvb.endArray() - rvb.endStruct() - rv2.set(rv.region, rvb.end()) - - rv2 - } - }) + val newRowType = leftRowType.appendKey( + localRoot, + PCanonicalArray(rightPType.selectFields(rightValueFields)), + ) + ( + RVDType(newRowType, localKey), + (_: RVDContext, it: Iterator[Muple[RegionValue, Iterable[RegionValue]]]) => { + val rvb = new RegionValueBuilder(sm) + val rv2 = RegionValue() + it.map { case Muple(rv, is) => + rvb.set(rv.region) + rvb.start(newRowType) + rvb.startStruct() + rvb.addAllFields(leftRowType, rv) + rvb.startArray(is.size) + is.foreach(i => rvb.selectRegionValue(rightPType, rightRVDType.valueFieldIdx, i)) + rvb.endArray() + rvb.endStruct() + rv2.set(rv.region, rvb.end()) + + rv2 + } + }, + ) } leftValue.rvd.orderedLeftIntervalJoin(ctx, rightValue.rvd, joiner) } else { val joiner = (rightPType: PStruct) => { val leftRowType = leftRVDType.rowType - val newRowType = leftRowType.appendKey(localRoot, rightPType.selectFields(rightValueFields).setRequired(false)) - - (RVDType(newRowType, localKey), (_: RVDContext, it: Iterator[JoinedRegionValue]) => { - val rvb = new RegionValueBuilder(sm) - val rv2 = RegionValue() - it.map { case Muple(rv, i) => - rvb.set(rv.region) - rvb.start(newRowType) - rvb.startStruct() - rvb.addAllFields(leftRowType, rv) - if (i == null) - rvb.setMissing() - else - rvb.selectRegionValue(rightPType, rightRVDType.valueFieldIdx, i) - rvb.endStruct() - rv2.set(rv.region, rvb.end()) - - rv2 - } - }) + val newRowType = leftRowType.appendKey( + localRoot, + rightPType.selectFields(rightValueFields).setRequired(false), + ) + + ( + RVDType(newRowType, localKey), + (_: RVDContext, it: Iterator[JoinedRegionValue]) => { + val rvb = new RegionValueBuilder(sm) + val rv2 = RegionValue() + it.map { case Muple(rv, i) => + rvb.set(rv.region) + rvb.start(newRowType) + rvb.startStruct() + rvb.addAllFields(leftRowType, rv) + if (i == null) + rvb.setMissing() + else + rvb.selectRegionValue(rightPType, rightRVDType.valueFieldIdx, i) + rvb.endStruct() + rv2.set(rv.region, rvb.end()) 
+ + rv2 + } + }, + ) } leftValue.rvd.orderedLeftIntervalJoinDistinct(ctx, rightValue.rvd, joiner) @@ -2204,12 +2767,14 @@ case class TableIntervalJoin( } } -/** - * The TableMultiWayZipJoin node assumes that input tables have distinct keys. If inputs - * do not have distinct keys, the key that is included in the result is undefined, but - * is likely the last. +/** The TableMultiWayZipJoin node assumes that input tables have distinct keys. If inputs do not + * have distinct keys, the key that is included in the result is undefined, but is likely the last. */ -case class TableMultiWayZipJoin(childrenSeq: IndexedSeq[TableIR], fieldName: String, globalName: String) extends TableIR { +case class TableMultiWayZipJoin( + childrenSeq: IndexedSeq[TableIR], + fieldName: String, + globalName: String, +) extends TableIR { require(childrenSeq.nonEmpty, "there must be at least one table as an argument") override def typecheck(): Unit = { @@ -2219,7 +2784,8 @@ case class TableMultiWayZipJoin(childrenSeq: IndexedSeq[TableIR], fieldName: Str assert(rest.forall(e => e.typ.key == first.typ.key), "all keys must be the same") assert( rest.forall(e => e.typ.globalType == first.typ.globalType), - "all globals must have the same type") + "all globals must have the same type", + ) } private def first = childrenSeq.head @@ -2232,13 +2798,14 @@ case class TableMultiWayZipJoin(childrenSeq: IndexedSeq[TableIR], fieldName: Str lazy val typ: TableType = first.typ.copy( rowType = newRowType, - globalType = newGlobalType + globalType = newGlobalType, ) def copy(newChildren: IndexedSeq[BaseIR]): TableMultiWayZipJoin = TableMultiWayZipJoin(newChildren.asInstanceOf[IndexedSeq[TableIR]], fieldName, globalName) - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val sm = ctx.stateManager val childValues = childrenSeq.map(_.execute(ctx, r).asTableValue(ctx)) @@ -2246,8 +2813,10 @@ case class TableMultiWayZipJoin(childrenSeq: IndexedSeq[TableIR], fieldName: Str assert(childRVDs.forall(_.typ.key.startsWith(typ.key))) val repartitionedRVDs = - if (childRVDs(0).partitioner.satisfiesAllowedOverlap(typ.key.length - 1) && - childRVDs.forall(rvd => rvd.partitioner == childRVDs(0).partitioner)) + if ( + childRVDs(0).partitioner.satisfiesAllowedOverlap(typ.key.length - 1) && + childRVDs.forall(rvd => rvd.partitioner == childRVDs(0).partitioner) + ) childRVDs.map(_.truncateKey(typ.key.length)) else { info("TableMultiWayZipJoin: repartitioning children") @@ -2264,9 +2833,16 @@ case class TableMultiWayZipJoin(childrenSeq: IndexedSeq[TableIR], fieldName: Str val localRVDType = rvdType val keyFields = rvdType.kType.fields.map(f => (f.name, f.typ)) val valueFields = rvdType.valueType.fields.map(f => (f.name, f.typ)) - val localNewRowType = PCanonicalStruct(required = true, - keyFields ++ Array((fieldName, PCanonicalArray( - PCanonicalStruct(required = false, valueFields: _*), required = true))): _*) + val localNewRowType = PCanonicalStruct( + required = true, + keyFields ++ Array(( + fieldName, + PCanonicalArray( + PCanonicalStruct(required = false, valueFields: _*), + required = true, + ), + )): _* + ) val localDataLength = childrenSeq.length val rvMerger = { (ctx: RVDContext, it: Iterator[BoxedArrayBuilder[(RegionValue, Int)]]) => val rvb = new RegionValueBuilder(sm) @@ -2300,32 +2876,33 @@ case class TableMultiWayZipJoin(childrenSeq: IndexedSeq[TableIR], fieldName: Str val 
rvd = RVD( typ = RVDType(localNewRowType, typ.key), partitioner = newPartitioner, - crdd = ContextRDD.czipNPartitions(repartitionedRVDs.map(_.crdd.toCRDDRegionValue)) { (ctx, its) => - val orvIters = its.map(it => OrderedRVIterator(localRVDType, it, ctx, sm)) - rvMerger(ctx, OrderedRVIterator.multiZipJoin(sm, orvIters)) - }.toCRDDPtr) + crdd = ContextRDD.czipNPartitions(repartitionedRVDs.map(_.crdd.toCRDDRegionValue)) { + (ctx, its) => + val orvIters = its.map(it => OrderedRVIterator(localRVDType, it, ctx, sm)) + rvMerger(ctx, OrderedRVIterator.multiZipJoin(sm, orvIters)) + }.toCRDDPtr, + ) - val newGlobals = BroadcastRow(ctx, - Row(childValues.map(_.globals.javaValue)), - newGlobalType) + val newGlobals = BroadcastRow(ctx, Row(childValues.map(_.globals.javaValue)), newGlobalType) new TableValueIntermediate(TableValue(ctx, typ, newGlobals, rvd)) } } case class TableLeftJoinRightDistinct(left: TableIR, right: TableIR, root: String) extends TableIR { - override def typecheck(): Unit = { + override def typecheck(): Unit = assert( right.typ.keyType isPrefixOf left.typ.keyType, - s"\n L: ${left.typ}\n R: ${right.typ}") - } + s"\n L: ${left.typ}\n R: ${right.typ}", + ) lazy val rowCountUpperBound: Option[Long] = left.rowCountUpperBound lazy val childrenSeq: IndexedSeq[BaseIR] = Array(left, right) lazy val typ: TableType = left.typ.copy( - rowType = left.typ.rowType.structInsert(right.typ.valueType, FastSeq(root))) + rowType = left.typ.rowType.structInsert(right.typ.valueType, FastSeq(root)) + ) override def partitionCounts: Option[IndexedSeq[Long]] = left.partitionCounts @@ -2334,7 +2911,8 @@ case class TableLeftJoinRightDistinct(left: TableIR, right: TableIR, root: Strin TableLeftJoinRightDistinct(newLeft, newRight, root) } - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val leftValue = left.execute(ctx, r).asTableValue(ctx) val rightValue = right.execute(ctx, r).asTableValue(ctx) @@ -2343,22 +2921,25 @@ case class TableLeftJoinRightDistinct(left: TableIR, right: TableIR, root: Strin leftValue.copy( typ = typ, rvd = leftValue.rvd - .orderedLeftJoinDistinctAndInsert(rightValue.rvd.truncateKey(joinKey), root))) + .orderedLeftJoinDistinctAndInsert(rightValue.rvd.truncateKey(joinKey), root), + ) + ) } } object TableMapPartitions { - def apply(child: TableIR, - globalName: String, - partitionStreamName: String, - body: IR): TableMapPartitions = TableMapPartitions(child, globalName, partitionStreamName, body, 0, child.typ.key.length) + def apply(child: TableIR, globalName: String, partitionStreamName: String, body: IR) + : TableMapPartitions = + TableMapPartitions(child, globalName, partitionStreamName, body, 0, child.typ.key.length) } -case class TableMapPartitions(child: TableIR, + +case class TableMapPartitions( + child: TableIR, globalName: String, partitionStreamName: String, body: IR, requestedKey: Int, - allowedOverlap: Int + allowedOverlap: Int, ) extends TableIR { override def typecheck(): Unit = { assert(body.typ.isInstanceOf[TStream], s"${body.typ}") @@ -2366,13 +2947,20 @@ case class TableMapPartitions(child: TableIR, assert(allowedOverlap <= child.typ.key.size) assert(requestedKey >= 0) assert(requestedKey <= child.typ.key.size) - assert(StreamUtils.isIterationLinear(body, partitionStreamName), "must iterate over the partition exactly once") + assert( + StreamUtils.isIterationLinear(body, partitionStreamName), + "must 
iterate over the partition exactly once", + ) val newRowType = body.typ.asInstanceOf[TStream].elementType.asInstanceOf[TStruct] - child.typ.key.foreach { k => if (!newRowType.hasField(k)) throw new RuntimeException(s"prev key: ${child.typ.key}, new row: ${newRowType}") } + child.typ.key.foreach { k => + if (!newRowType.hasField(k)) + throw new RuntimeException(s"prev key: ${child.typ.key}, new row: $newRowType") + } } lazy val typ: TableType = child.typ.copy( - rowType = body.typ.asInstanceOf[TStream].elementType.asInstanceOf[TStruct]) + rowType = body.typ.asInstanceOf[TStream].elementType.asInstanceOf[TStruct] + ) lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child, body) @@ -2380,21 +2968,43 @@ case class TableMapPartitions(child: TableIR, override def copy(newChildren: IndexedSeq[BaseIR]): TableMapPartitions = { assert(newChildren.length == 2) - TableMapPartitions(newChildren(0).asInstanceOf[TableIR], - globalName, partitionStreamName, newChildren(1).asInstanceOf[IR], requestedKey, allowedOverlap) + TableMapPartitions( + newChildren(0).asInstanceOf[TableIR], + globalName, + partitionStreamName, + newChildren(1).asInstanceOf[IR], + requestedKey, + allowedOverlap, + ) } - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val tv = child.execute(ctx, r).asTableValue(ctx) val rowPType = tv.rvd.rowPType val globalPType = tv.globals.t val (newRowPType: PStruct, makeIterator) = CompileIterator.forTableMapPartitions( ctx, - globalPType, rowPType, - Subst(body, BindingEnv(Env( - globalName -> In(0, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globalPType))), - partitionStreamName -> In(1, SingleCodeEmitParamType(true, StreamSingleCodeType(requiresMemoryManagementPerElement = true, rowPType, true))))))) + globalPType, + rowPType, + Subst( + body, + BindingEnv(Env( + globalName -> In( + 0, + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(globalPType)), + ), + partitionStreamName -> In( + 1, + SingleCodeEmitParamType( + true, + StreamSingleCodeType(requiresMemoryManagementPerElement = true, rowPType, true), + ), + ), + )), + ), + ) val globalsBc = tv.globals.broadcast(ctx.theHailClassLoader) @@ -2403,23 +3013,28 @@ case class TableMapPartitions(child: TableIR, val boxedPartition = new NoBoxLongIterator { var eos: Boolean = false var iter: Iterator[Long] = _ - override def init(partitionRegion: Region, elementRegion: Region): Unit = { + override def init(partitionRegion: Region, elementRegion: Region): Unit = iter = partition(new RVDContext(partitionRegion, elementRegion)) - } - override def next(): Long = { + override def next(): Long = if (!iter.hasNext) { eos = true 0L } else iter.next() - } override def close(): Unit = () } - makeIterator(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), consumerCtx, - globalsBc.value.readRegionValue(consumerCtx.partitionRegion, theHailClassLoaderForSparkWorkers), - boxedPartition + makeIterator( + theHailClassLoaderForSparkWorkers, + fsBc.value, + SparkTaskContext.get(), + consumerCtx, + globalsBc.value.readRegionValue( + consumerCtx.partitionRegion, + theHailClassLoaderForSparkWorkers, + ), + boxedPartition, ).map(l => l.longValue()) } @@ -2429,7 +3044,8 @@ case class TableMapPartitions(child: TableIR, tv.copy( typ = typ, rvd = rvd - .mapPartitionsWithContextAndIndex(RVDType(newRowPType, typ.key))(itF)) + 
.mapPartitionsWithContextAndIndex(RVDType(newRowPType, typ.key))(itF), + ) ) } } @@ -2454,21 +3070,27 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { override def partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val tv = child.execute(ctx, r).asTableValue(ctx) val fsBc = ctx.fsBc - val scanRef = genUID() - val extracted = agg.Extract.apply(newRow, scanRef, Requiredness(this, ctx), isScan = true) + val extracted = agg.Extract.apply(newRow, Requiredness(this, ctx), isScan = true) if (extracted.aggs.isEmpty) { - val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = ir.Compile[AsmFunction3RegionLongLongLong]( - ctx, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.globals.t))), - ("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.rvd.rowPType)))), - FastSeq(classInfo[Region], LongInfo, LongInfo), LongInfo, - Coalesce(FastSeq( - extracted.postAggIR, - Die("Internal error: TableMapRows: row expression missing", extracted.postAggIR.typ)))) + val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = + ir.Compile[AsmFunction3RegionLongLongLong]( + ctx, + FastSeq( + ("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.globals.t))), + ("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.rvd.rowPType))), + ), + FastSeq(classInfo[Region], LongInfo, LongInfo), + LongInfo, + Coalesce(FastSeq( + extracted.postAggIR, + Die("Internal error: TableMapRows: row expression missing", extracted.postAggIR.typ), + )), + ) val rowIterationNeedsGlobals = Mentions(extracted.postAggIR, "global") val globalsBc = @@ -2485,16 +3107,17 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { else 0 - val newRow = f(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), globalRegion) - it.map { ptr => - newRow(ctx.r, globals, ptr) - } + val newRow = + f(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), globalRegion) + it.map(ptr => newRow(ctx.r, globals, ptr)) } return new TableValueIntermediate( tv.copy( typ = typ, - rvd = tv.rvd.mapPartitionsWithIndex(RVDType(rTyp.asInstanceOf[PStruct], typ.key))(itF))) + rvd = tv.rvd.mapPartitionsWithIndex(RVDType(rTyp.asInstanceOf[PStruct], typ.key))(itF), + ) + ) } val scanInitNeedsGlobals = Mentions(extracted.init, "global") @@ -2515,34 +3138,54 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { // 3. load in partition aggregations, comb op as necessary, serialize. // 4. load in partStarts, calculate newRow based on those results. 
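// A minimal, self-contained sketch of the four-step distributed scan outlined in the
// comment above. This is NOT Hail's implementation: it replaces serialized aggregator
// states, files, and Spark tasks with in-memory partitions and a generic associative
// aggregator, and it collapses the distributed comb-op ladder into a sequential prefix
// scan. All names here (ScanAgg, scanPartitions, ...) are illustrative assumptions.
object DistributedScanSketch {
  // A scan aggregator: an initial state, a per-element seq op, and an associative
  // comb op used to merge whole-partition states.
  final case class ScanAgg[S, A](init: S, seqOp: (S, A) => S, combOp: (S, S) => S)

  def scanPartitions[S, A](
    parts: IndexedSeq[IndexedSeq[A]],
    agg: ScanAgg[S, A],
  ): IndexedSeq[IndexedSeq[S]] = {
    // Steps 1-2: run the init and seq ops over every partition independently,
    // keeping only each partition's final state (the part Hail writes out per task).
    val partStates: IndexedSeq[S] = parts.map(_.foldLeft(agg.init)(agg.seqOp))
    // Step 3: comb op the partition states into "partStarts": the state each partition
    // should start from, i.e. an exclusive prefix scan of the per-partition states.
    val partStarts: IndexedSeq[S] = partStates.scanLeft(agg.init)(agg.combOp).init
    // Step 4: rerun the seq op per partition, starting from its prefix state, emitting
    // the running state for every row (the per-row scan result used to build newRow).
    parts.zip(partStarts).map { case (part, start) => part.scanLeft(start)(agg.seqOp).tail }
  }

  def main(args: Array[String]): Unit = {
    // Example: a global running count of rows across three "partitions".
    val counts = ScanAgg[Long, Int](0L, (s, _) => s + 1, _ + _)
    println(scanPartitions(IndexedSeq(IndexedSeq(10, 11), IndexedSeq(12), IndexedSeq(13, 14)), counts))
    // => Vector(Vector(1, 2), Vector(3), Vector(4, 5))
  }
}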
- val (_, initF) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit](ctx, + val (_, initF) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit]( + ctx, extracted.states, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.globals.t)))), - FastSeq(classInfo[Region], LongInfo), UnitInfo, - Begin(FastSeq(extracted.init))) + FastSeq(( + "global", + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.globals.t)), + )), + FastSeq(classInfo[Region], LongInfo), + UnitInfo, + Begin(FastSeq(extracted.init)), + ) val serializeF = extracted.serialize(ctx, spec) - val (_, eltSeqF) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit](ctx, + val (_, eltSeqF) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + ctx, extracted.states, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.globals.t))), - ("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.rvd.rowPType)))), - FastSeq(classInfo[Region], LongInfo, LongInfo), UnitInfo, - extracted.eltOp(ctx)) + FastSeq( + ("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.globals.t))), + ("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.rvd.rowPType))), + ), + FastSeq(classInfo[Region], LongInfo, LongInfo), + UnitInfo, + extracted.seqPerElt, + ) val read = extracted.deserialize(ctx, spec) val write = extracted.serialize(ctx, spec) val combOpFNeedsPool = extracted.combOpFSerializedFromRegionPool(ctx, spec) - val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = ir.CompileWithAggregators[AsmFunction3RegionLongLongLong](ctx, - extracted.states, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.globals.t))), - ("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.rvd.rowPType)))), - FastSeq(classInfo[Region], LongInfo, LongInfo), LongInfo, - Let(FastSeq(scanRef -> extracted.results), - Coalesce(FastSeq( - extracted.postAggIR, - Die("Internal error: TableMapRows: row expression missing", extracted.postAggIR.typ))))) + val (Some(PTypeReferenceSingleCodeType(rTyp)), f) = + ir.CompileWithAggregators[AsmFunction3RegionLongLongLong]( + ctx, + extracted.states, + FastSeq( + ("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.globals.t))), + ("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.rvd.rowPType))), + ), + FastSeq(classInfo[Region], LongInfo, LongInfo), + LongInfo, + Let( + FastSeq(extracted.resultRef.name -> extracted.results), + Coalesce(FastSeq( + extracted.postAggIR, + Die("Internal error: TableMapRows: row expression missing", extracted.postAggIR.typ), + )), + ), + ) assert(rTyp.virtualType == newRow.typ) // 1. 
init op on all aggs and write out to initPath @@ -2562,13 +3205,18 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { val files = tv.rvd.mapPartitionsWithIndex { (i, ctx, it) => val path = tmpBase + "/" + partFile(d, i, TaskContext.get) val globalRegion = ctx.freshRegion() - val globals = if (scanSeqNeedsGlobals) globalsBc.value.readRegionValue(globalRegion, theHailClassLoaderForSparkWorkers) else 0 + val globals = if (scanSeqNeedsGlobals) + globalsBc.value.readRegionValue(globalRegion, theHailClassLoaderForSparkWorkers) + else 0 ctx.r.pool.scopedSmallRegion { aggRegion => val tc = SparkTaskContext.get() val seq = eltSeqF(theHailClassLoaderForSparkWorkers, fsBc.value, tc, globalRegion) - seq.setAggState(aggRegion, read(theHailClassLoaderForSparkWorkers, tc, aggRegion, initAgg)) + seq.setAggState( + aggRegion, + read(theHailClassLoaderForSparkWorkers, tc, aggRegion, initAgg), + ) it.foreach { ptr => seq(ctx.region, globals, ptr) ctx.region.clear() @@ -2589,30 +3237,36 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { log.info(s"Running distributed combine stage with $nToMerge tasks") fileStack += filesToMerge - filesToMerge = ContextRDD.weaken(SparkBackend.sparkContext("TableMapRows.execute").parallelize(0 until nToMerge, nToMerge)) - .cmapPartitions { (ctx, it) => - val i = it.next() - assert(it.isEmpty) - val path = tmpBase + "/" + partFile(d, i, TaskContext.get) - val file1 = filesToMerge(i * 2) - val file2 = filesToMerge(i * 2 + 1) - - def readToBytes(is: DataInputStream): Array[Byte] = { - val len = is.readInt() - val b = new Array[Byte](len) - is.readFully(b) - b - } + filesToMerge = + ContextRDD.weaken(SparkBackend.sparkContext("TableMapRows.execute").parallelize( + 0 until nToMerge, + nToMerge, + )) + .cmapPartitions { (ctx, it) => + val i = it.next() + assert(it.isEmpty) + val path = tmpBase + "/" + partFile(d, i, TaskContext.get) + val file1 = filesToMerge(i * 2) + val file2 = filesToMerge(i * 2 + 1) + + def readToBytes(is: DataInputStream): Array[Byte] = { + val len = is.readInt() + val b = new Array[Byte](len) + is.readFully(b) + b + } - val b1 = using(new DataInputStream(fsBc.value.open(file1)))(readToBytes) - val b2 = using(new DataInputStream(fsBc.value.open(file2)))(readToBytes) - using(new DataOutputStream(fsBc.value.create(path))) { os => - val bytes = combOpFNeedsPool(() => (ctx.r.pool, theHailClassLoaderForSparkWorkers, SparkTaskContext.get()))(b1, b2) - os.writeInt(bytes.length) - os.write(bytes) - } - Iterator.single(path) - }.collect() + val b1 = using(new DataInputStream(fsBc.value.open(file1)))(readToBytes) + val b2 = using(new DataInputStream(fsBc.value.open(file2)))(readToBytes) + using(new DataOutputStream(fsBc.value.create(path))) { os => + val bytes = combOpFNeedsPool(() => + (ctx.r.pool, theHailClassLoaderForSparkWorkers, SparkTaskContext.get()) + )(b1, b2) + os.writeInt(bytes.length) + os.write(bytes) + } + Iterator.single(path) + }.collect() } fileStack += filesToMerge @@ -2646,7 +3300,9 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { b } - b = combOpFNeedsPool(() => (ctx.r.pool, theHailClassLoaderForSparkWorkers, SparkTaskContext.get()))(b, using(new DataInputStream(fsBc.value.open(path)))(readToBytes)) + b = combOpFNeedsPool(() => + (ctx.r.pool, theHailClassLoaderForSparkWorkers, SparkTaskContext.get()) + )(b, using(new DataInputStream(fsBc.value.open(path)))(readToBytes)) } b } @@ -2672,30 +3328,41 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { return new 
TableValueIntermediate( tv.copy( typ = typ, - rvd = tv.rvd.mapPartitionsWithIndex(RVDType(rTyp.asInstanceOf[PStruct], typ.key))(itF))) + rvd = tv.rvd.mapPartitionsWithIndex(RVDType(rTyp.asInstanceOf[PStruct], typ.key))(itF), + ) + ) } // 2. load in init op on each partition, seq op over partition, write out. - val scanPartitionAggs = SpillingCollectIterator(ctx.localTmpdir, ctx.fs, tv.rvd.mapPartitionsWithIndex { (i, ctx, it) => - val globalRegion = ctx.partitionRegion - val globals = if (scanSeqNeedsGlobals) globalsBc.value.readRegionValue(globalRegion, theHailClassLoaderForSparkWorkers) else 0 + val scanPartitionAggs = SpillingCollectIterator( + ctx.localTmpdir, + ctx.fs, + tv.rvd.mapPartitionsWithIndex { (i, ctx, it) => + val globalRegion = ctx.partitionRegion + val globals = if (scanSeqNeedsGlobals) + globalsBc.value.readRegionValue(globalRegion, theHailClassLoaderForSparkWorkers) + else 0 - SparkTaskContext.get().getRegionPool().scopedSmallRegion { aggRegion => - val hcl = theHailClassLoaderForSparkWorkers - val tc = SparkTaskContext.get() - val seq = eltSeqF(hcl, fsBc.value, tc, globalRegion) + SparkTaskContext.get().getRegionPool().scopedSmallRegion { aggRegion => + val hcl = theHailClassLoaderForSparkWorkers + val tc = SparkTaskContext.get() + val seq = eltSeqF(hcl, fsBc.value, tc, globalRegion) - seq.setAggState(aggRegion, read(hcl, tc, aggRegion, initAgg)) - it.foreach { ptr => - seq(ctx.region, globals, ptr) - ctx.region.clear() + seq.setAggState(aggRegion, read(hcl, tc, aggRegion, initAgg)) + it.foreach { ptr => + seq(ctx.region, globals, ptr) + ctx.region.clear() + } + Iterator.single(write(hcl, tc, aggRegion, seq.getAggOffset())) } - Iterator.single(write(hcl, tc, aggRegion, seq.getAggOffset())) - } - }, ctx.getFlag("max_leader_scans").toInt) + }, + ctx.getFlag("max_leader_scans").toInt, + ) // 3. load in partition aggregations, comb op as necessary, write back out. - val partAggs = scanPartitionAggs.scanLeft(initAgg)(combOpFNeedsPool(() => (ctx.r.pool, ctx.theHailClassLoader, ctx.taskContext))) + val partAggs = scanPartitionAggs.scanLeft(initAgg)(combOpFNeedsPool(() => + (ctx.r.pool, ctx.theHailClassLoader, ctx.taskContext) + )) val scanAggCount = tv.rvd.getNumPartitions val partitionIndices = new Array[Long](scanAggCount) val scanAggsPerPartitionFile = ctx.createTmpPath("table-map-rows-scan-aggs-part") @@ -2710,7 +3377,6 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { } } - // 4. load in partStarts, calculate newRow based on those results. 
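// An illustrative model of the pairwise merge rounds used when many spilled
// partial-aggregate states must be combined: each round merges adjacent pairs
// and carries an odd trailing element up unchanged. Plain Scala over an
// in-memory Vector rather than files on a Spark stage; `treeCombine` is a
// hypothetical name, and `comb` is assumed associative.
def treeCombine[S](states: Vector[S], comb: (S, S) => S): Option[S] = {
  var level = states
  while (level.length > 1) {
    // merge neighbours; a lone trailing element skips this round
    level = level.grouped(2).map {
      case Vector(a, b) => comb(a, b)
      case Vector(a) => a
    }.toVector
  }
  level.headOption
}
// treeCombine(Vector(1, 2, 3, 4, 5), (a: Int, b: Int) => a + b)  // Some(15), in three rounds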
val itF = { (i: Int, ctx: RVDContext, filePosition: Long, it: Iterator[Long]) => val globalRegion = ctx.partitionRegion @@ -2755,7 +3421,12 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR { new TableValueIntermediate( tv.copy( typ = typ, - rvd = tv.rvd.mapPartitionsWithIndexAndValue(RVDType(rTyp.asInstanceOf[PStruct], typ.key), partitionIndices)(itF))) + rvd = tv.rvd.mapPartitionsWithIndexAndValue( + RVDType(rTyp.asInstanceOf[PStruct], typ.key), + partitionIndices, + )(itF), + ) + ) } } @@ -2774,29 +3445,38 @@ case class TableMapGlobals(child: TableIR, newGlobals: IR) extends TableIR { override def partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val tv = child.execute(ctx, r).asTableValue(ctx) - val (Some(PTypeReferenceSingleCodeType(resultPType: PStruct)), f) = Compile[AsmFunction2RegionLongLong](ctx, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.globals.t)))), - FastSeq(classInfo[Region], LongInfo), LongInfo, - Coalesce(FastSeq( - newGlobals, - Die("Internal error: TableMapGlobals: globals missing", newGlobals.typ)))) + val (Some(PTypeReferenceSingleCodeType(resultPType: PStruct)), f) = + Compile[AsmFunction2RegionLongLong]( + ctx, + FastSeq(( + "global", + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(tv.globals.t)), + )), + FastSeq(classInfo[Region], LongInfo), + LongInfo, + Coalesce(FastSeq( + newGlobals, + Die("Internal error: TableMapGlobals: globals missing", newGlobals.typ), + )), + ) - val resultOff = f(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)(ctx.r, tv.globals.value.offset) + val resultOff = + f(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)(ctx.r, tv.globals.value.offset) new TableValueIntermediate( - tv.copy(typ = typ, - globals = BroadcastRow(ctx, RegionValue(ctx.r, resultOff), resultPType))) + tv.copy(typ = typ, globals = BroadcastRow(ctx, RegionValue(ctx.r, resultOff), resultPType)) + ) } } case class TableExplode(child: TableIR, path: IndexedSeq[String]) extends TableIR { assert(path.nonEmpty) - override def typecheck(): Unit = { + override def typecheck(): Unit = assert(!child.typ.key.contains(path.head)) - } lazy val rowCountUpperBound: Option[Long] = None @@ -2805,17 +3485,22 @@ case class TableExplode(child: TableIR, path: IndexedSeq[String]) extends TableI private def childRowType = child.typ.rowType private[this] lazy val idx = Ref(genUID(), TInt32) + private[this] lazy val newRow: InsertFields = { val refs = path.init.scanLeft(Ref("row", childRowType))((struct, name) => - Ref(genUID(), tcoerce[TStruct](struct.typ).field(name).typ)) + Ref(genUID(), tcoerce[TStruct](struct.typ).field(name).typ) + ) path.zip(refs).zipWithIndex.foldRight[IR](idx) { case (((field, ref), i), arg) => - InsertFields(ref, FastSeq(field -> - (if (i == refs.length - 1) - ArrayRef(CastToArray(GetField(ref, field)), arg) - else - Let(FastSeq(refs(i + 1).name -> GetField(ref, field)), arg)))) + InsertFields( + ref, + FastSeq(field -> + (if (i == refs.length - 1) + ArrayRef(CastToArray(GetField(ref, field)), arg) + else + Let(FastSeq(refs(i + 1).name -> GetField(ref, field)), arg))), + ) }.asInstanceOf[InsertFields] } @@ -2826,41 +3511,57 @@ case class TableExplode(child: TableIR, path: IndexedSeq[String]) extends TableI 
TableExplode(newChildren(0).asInstanceOf[TableIR], path) } - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val prev = child.execute(ctx, r).asTableValue(ctx) val length: IR = Coalesce(FastSeq( ArrayLen(CastToArray( - path.foldLeft[IR](Ref("row", childRowType)) { (struct, field) => - GetField(struct, field) - })), - 0)) - - val (len, l) = Compile[AsmFunction2RegionLongInt](ctx, - FastSeq(("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.rvd.rowPType)))), - FastSeq(classInfo[Region], LongInfo), IntInfo, - length) - val (Some(PTypeReferenceSingleCodeType(newRowType: PStruct)), f) = Compile[AsmFunction3RegionLongIntLong]( + path.foldLeft[IR](Ref("row", childRowType))((struct, field) => GetField(struct, field)) + )), + 0, + )) + + val (_, l) = Compile[AsmFunction2RegionLongInt]( ctx, - FastSeq(("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.rvd.rowPType))), - (idx.name, SingleCodeEmitParamType(true, Int32SingleCodeType))), - FastSeq(classInfo[Region], LongInfo, IntInfo), LongInfo, - newRow) + FastSeq(( + "row", + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.rvd.rowPType)), + )), + FastSeq(classInfo[Region], LongInfo), + IntInfo, + length, + ) + val (Some(PTypeReferenceSingleCodeType(newRowType: PStruct)), f) = + Compile[AsmFunction3RegionLongIntLong]( + ctx, + FastSeq( + ("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.rvd.rowPType))), + (idx.name, SingleCodeEmitParamType(true, Int32SingleCodeType)), + ), + FastSeq(classInfo[Region], LongInfo, IntInfo), + LongInfo, + newRow, + ) assert(newRowType.virtualType == typ.rowType) val rvdType: RVDType = RVDType( newRowType, - prev.rvd.typ.key.takeWhile(_ != path.head) + prev.rvd.typ.key.takeWhile(_ != path.head), ) val fsBc = ctx.fsBc TableValueIntermediate( - TableValue(ctx, typ, + TableValue( + ctx, + typ, prev.globals, prev.rvd.boundary.mapPartitionsWithIndex(rvdType) { (i, ctx, it) => val globalRegion = ctx.partitionRegion - val lenF = l(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), globalRegion) - val rowF = f(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), globalRegion) + val lenF = + l(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), globalRegion) + val rowF = + f(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), globalRegion) it.flatMap { ptr => val len = lenF(ctx.region, ptr) new Iterator[Long] { @@ -2875,7 +3576,8 @@ case class TableExplode(child: TableIR, path: IndexedSeq[String]) extends TableI } } } - }) + }, + ) ) } } @@ -2896,17 +3598,19 @@ case class TableUnion(childrenSeq: IndexedSeq[TableIR]) extends TableIR { None } - def copy(newChildren: IndexedSeq[BaseIR]): TableUnion = { + def copy(newChildren: IndexedSeq[BaseIR]): TableUnion = TableUnion(newChildren.map(_.asInstanceOf[TableIR])) - } def typ: TableType = childrenSeq(0).typ - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val tvs = childrenSeq.map(_.execute(ctx, r).asTableValue(ctx)) TableValueIntermediate( tvs(0).copy( - rvd = RVD.union(RVD.unify(ctx, tvs.map(_.rvd)), tvs(0).typ.key.length, ctx))) + rvd = RVD.union(RVD.unify(ctx, tvs.map(_.rvd)), 
tvs(0).typ.key.length, ctx) + ) + ) } } @@ -2963,9 +3667,12 @@ case class TableDistinct(child: TableIR) extends TableIR { def typ: TableType = child.typ - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val prev = child.execute(ctx, r).asTableValue(ctx) - new TableValueIntermediate(prev.copy(rvd = prev.rvd.truncateKey(prev.typ.key).distinctByKey(ctx))) + new TableValueIntermediate(prev.copy(rvd = + prev.rvd.truncateKey(prev.typ.key).distinctByKey(ctx) + )) } } @@ -2974,7 +3681,7 @@ case class TableKeyByAndAggregate( expr: IR, newKey: IR, nPartitions: Option[Int] = None, - bufferSize: Int + bufferSize: Int, ) extends TableIR { assert(bufferSize > 0) @@ -2993,52 +3700,78 @@ case class TableKeyByAndAggregate( } private lazy val keyType = newKey.typ.asInstanceOf[TStruct] + lazy val typ: TableType = TableType( rowType = keyType ++ tcoerce[TStruct](expr.typ), globalType = child.typ.globalType, - key = keyType.fieldNames + key = keyType.fieldNames, ) - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val prev = child.execute(ctx, r).asTableValue(ctx) val fsBc = ctx.fsBc val sm = ctx.stateManager val localKeyType = keyType - val (Some(PTypeReferenceSingleCodeType(localKeyPType: PStruct)), makeKeyF) = ir.Compile[AsmFunction3RegionLongLongLong](ctx, - FastSeq(("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.rvd.rowPType))), - ("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t)))), - FastSeq(classInfo[Region], LongInfo, LongInfo), LongInfo, - Coalesce(FastSeq( - newKey, - Die("Internal error: TableKeyByAndAggregate: newKey missing", newKey.typ)))) + val (Some(PTypeReferenceSingleCodeType(localKeyPType: PStruct)), makeKeyF) = + ir.Compile[AsmFunction3RegionLongLongLong]( + ctx, + FastSeq( + ("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.rvd.rowPType))), + ("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t))), + ), + FastSeq(classInfo[Region], LongInfo, LongInfo), + LongInfo, + Coalesce(FastSeq( + newKey, + Die("Internal error: TableKeyByAndAggregate: newKey missing", newKey.typ), + )), + ) val globalsBc = prev.globals.broadcast(ctx.theHailClassLoader) val spec = BufferSpec.blockedUncompressed - val res = genUID() - val extracted = agg.Extract(expr, res, Requiredness(this, ctx)) + val extracted = agg.Extract(expr, Requiredness(this, ctx)) - val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit](ctx, + val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit]( + ctx, extracted.states, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t)))), - FastSeq(classInfo[Region], LongInfo), UnitInfo, - extracted.init) + FastSeq(( + "global", + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t)), + )), + FastSeq(classInfo[Region], LongInfo), + UnitInfo, + extracted.init, + ) - val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit](ctx, + val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + ctx, extracted.states, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t))), - ("row", 
SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.rvd.rowPType)))), - FastSeq(classInfo[Region], LongInfo, LongInfo), UnitInfo, - extracted.seqPerElt) + FastSeq( + ("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t))), + ("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.rvd.rowPType))), + ), + FastSeq(classInfo[Region], LongInfo, LongInfo), + UnitInfo, + extracted.seqPerElt, + ) - val (Some(PTypeReferenceSingleCodeType(rTyp: PStruct)), makeAnnotate) = ir.CompileWithAggregators[AsmFunction2RegionLongLong](ctx, - extracted.states, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t)))), - FastSeq(classInfo[Region], LongInfo), LongInfo, - Let(FastSeq(res -> extracted.results), extracted.postAggIR)) - assert(rTyp.virtualType == typ.valueType, s"$rTyp, ${ typ.valueType }") + val (Some(PTypeReferenceSingleCodeType(rTyp: PStruct)), makeAnnotate) = + ir.CompileWithAggregators[AsmFunction2RegionLongLong]( + ctx, + extracted.states, + FastSeq(( + "global", + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t)), + )), + FastSeq(classInfo[Region], LongInfo), + LongInfo, + Let(FastSeq(extracted.resultRef.name -> extracted.results), extracted.postAggIR), + ) + assert(rTyp.virtualType == typ.valueType, s"$rTyp, ${typ.valueType}") val serialize = extracted.serialize(ctx, spec) val deserialize = extracted.deserialize(ctx, spec) @@ -3054,8 +3787,12 @@ case class TableKeyByAndAggregate( serialize(hcl, tc, aggRegion, initF.getAggOffset()) } - val newRowType = PCanonicalStruct(required = true, - localKeyPType.fields.map(f => (f.name, PType.canonical(f.typ))) ++ rTyp.fields.map(f => (f.name, f.typ)): _*) + val newRowType = PCanonicalStruct( + required = true, + localKeyPType.fields.map(f => (f.name, PType.canonical(f.typ))) ++ rTyp.fields.map(f => + (f.name, f.typ) + ): _* + ) val localBufferSize = bufferSize val rdd = prev.rvd @@ -3101,48 +3838,47 @@ case class TableKeyByAndAggregate( localBufferSize) }.aggregateByKey(initAggs, nPartitions.getOrElse(prev.rvd.getNumPartitions))(combOp, combOp) - val crdd = ContextRDD.weaken(rdd).cmapPartitionsWithIndex( - { (i, ctx, it) => - val region = ctx.region - - val rvb = new RegionValueBuilder(sm) - val partRegion = ctx.partitionRegion - val hcl = theHailClassLoaderForSparkWorkers - val tc = SparkTaskContext.get() - val globals = globalsBc.value.readRegionValue(partRegion, hcl) - val annotate = makeAnnotate(hcl, fsBc.value, tc, partRegion) + val crdd = ContextRDD.weaken(rdd).cmapPartitionsWithIndex({ (i, ctx, it) => + val region = ctx.region - it.map { case (key, aggs) => - rvb.set(region) - rvb.start(newRowType) - rvb.startStruct() - var i = 0 - while (i < localKeyType.size) { - rvb.addAnnotation(localKeyType.types(i), key.get(i)) - i += 1 - } + val rvb = new RegionValueBuilder(sm) + val partRegion = ctx.partitionRegion + val hcl = theHailClassLoaderForSparkWorkers + val tc = SparkTaskContext.get() + val globals = globalsBc.value.readRegionValue(partRegion, hcl) + val annotate = makeAnnotate(hcl, fsBc.value, tc, partRegion) - val aggOff = deserialize(hcl, tc, region, aggs) - annotate.setAggState(region, aggOff) - rvb.addAllFields(rTyp, region, annotate(region, globals)) - rvb.endStruct() - rvb.end() + it.map { case (key, aggs) => + rvb.set(region) + rvb.start(newRowType) + rvb.startStruct() + var i = 0 + while (i < localKeyType.size) { + rvb.addAnnotation(localKeyType.types(i), key.get(i)) + i += 1 } - }) + + val aggOff = 
deserialize(hcl, tc, region, aggs) + annotate.setAggState(region, aggOff) + rvb.addAllFields(rTyp, region, annotate(region, globals)) + rvb.endStruct() + rvb.end() + } + }) new TableValueIntermediate( prev.copy( typ = typ, - rvd = RVD.coerce(ctx, RVDType(newRowType, keyType.fieldNames), crdd)) + rvd = RVD.coerce(ctx, RVDType(newRowType, keyType.fieldNames), crdd), + ) ) } } // follows key_by non-empty key case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { - override def typecheck(): Unit = { + override def typecheck(): Unit = assert(child.typ.key.nonEmpty) - } lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound @@ -3154,44 +3890,64 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { TableAggregateByKey(newChild, newExpr) } - lazy val typ: TableType = child.typ.copy(rowType = child.typ.keyType ++ tcoerce[TStruct](expr.typ)) + lazy val typ: TableType = + child.typ.copy(rowType = child.typ.keyType ++ tcoerce[TStruct](expr.typ)) - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val prev = child.execute(ctx, r).asTableValue(ctx) val prevRVD = prev.rvd.truncateKey(child.typ.key) val fsBc = ctx.fsBc val sm = ctx.stateManager - val res = genUID() - val extracted = agg.Extract(expr, res, Requiredness(this, ctx)) + val extracted = agg.Extract(expr, Requiredness(this, ctx)) - val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit](ctx, + val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit]( + ctx, extracted.states, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t)))), - FastSeq(classInfo[Region], LongInfo), UnitInfo, - extracted.init) + FastSeq(( + "global", + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t)), + )), + FastSeq(classInfo[Region], LongInfo), + UnitInfo, + extracted.init, + ) - val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit](ctx, + val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit]( + ctx, extracted.states, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t))), - ("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prevRVD.rowPType)))), - FastSeq(classInfo[Region], LongInfo, LongInfo), UnitInfo, - extracted.seqPerElt) + FastSeq( + ("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t))), + ("row", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prevRVD.rowPType))), + ), + FastSeq(classInfo[Region], LongInfo, LongInfo), + UnitInfo, + extracted.seqPerElt, + ) - val valueIR = Let(FastSeq(res -> extracted.results), extracted.postAggIR) + val valueIR = Let(FastSeq(extracted.resultRef.name -> extracted.results), extracted.postAggIR) val keyType = prevRVD.typ.kType val key = Ref(genUID(), keyType.virtualType) val value = Ref(genUID(), valueIR.typ) - val (Some(PTypeReferenceSingleCodeType(rowType: PStruct)), makeRow) = ir.CompileWithAggregators[AsmFunction3RegionLongLongLong](ctx, - extracted.states, - FastSeq(("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t))), - (key.name, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(keyType)))), - FastSeq(classInfo[Region], LongInfo, LongInfo), LongInfo, - Let(FastSeq(value.name -> valueIR), - InsertFields(key, 
typ.valueType.fieldNames.map(n => n -> GetField(value, n))))) - - assert(rowType.virtualType == typ.rowType, s"$rowType, ${ typ.rowType }") + val (Some(PTypeReferenceSingleCodeType(rowType: PStruct)), makeRow) = + ir.CompileWithAggregators[AsmFunction3RegionLongLongLong]( + ctx, + extracted.states, + FastSeq( + ("global", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(prev.globals.t))), + (key.name, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(keyType))), + ), + FastSeq(classInfo[Region], LongInfo, LongInfo), + LongInfo, + Let( + FastSeq(value.name -> valueIR), + InsertFields(key, typ.valueType.fieldNames.map(n => n -> GetField(value, n))), + ), + ) + + assert(rowType.virtualType == typ.rowType, s"$rowType, ${typ.rowType}") val localChildRowType = prevRVD.rowPType val keyIndices = prevRVD.typ.kFieldIdx @@ -3205,11 +3961,19 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { .boundary .mapPartitionsWithIndex(newRVDType) { (i, ctx, it) => val partRegion = ctx.partitionRegion - val globalsOff = globalsBc.value.readRegionValue(partRegion, theHailClassLoaderForSparkWorkers) - - val initialize = makeInit(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), partRegion) - val sequence = makeSeq(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), partRegion) - val newRowF = makeRow(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), partRegion) + val globalsOff = + globalsBc.value.readRegionValue(partRegion, theHailClassLoaderForSparkWorkers) + + val initialize = makeInit( + theHailClassLoaderForSparkWorkers, + fsBc.value, + SparkTaskContext.get(), + partRegion, + ) + val sequence = + makeSeq(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), partRegion) + val newRowF = + makeRow(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), partRegion) val aggRegion = ctx.freshRegion() @@ -3218,7 +3982,6 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { var current: Long = 0 val rowKey: WritableRegionValue = WritableRegionValue(sm, keyType, ctx.freshRegion()) val consumerRegion: Region = ctx.region - val newRV = RegionValue(consumerRegion) def hasNext: Boolean = { if (isEnd || (current == 0 && !it.hasNext)) { @@ -3242,9 +4005,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { sequence.setAggState(aggRegion, initialize.getAggOffset()) do { - sequence(ctx.r, - globalsOff, - current) + sequence(ctx.r, globalsOff, current) current = 0 } while (hasNext && keyOrd.equiv(rowKey.value.offset, current)) newRowF.setAggState(aggRegion, sequence.getAggOffset()) @@ -3261,17 +4022,17 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR { } object TableOrderBy { - def isAlreadyOrdered(sortFields: IndexedSeq[SortField], prevKey: IndexedSeq[String]): Boolean = { + def isAlreadyOrdered(sortFields: IndexedSeq[SortField], prevKey: IndexedSeq[String]): Boolean = sortFields.length <= prevKey.length && sortFields.zip(prevKey).forall { case (sf, k) => sf.sortOrder == Ascending && sf.field == k } - } } case class TableOrderBy(child: TableIR, sortFields: IndexedSeq[SortField]) extends TableIR { - lazy val definitelyDoesNotShuffle: Boolean = TableOrderBy.isAlreadyOrdered(sortFields, child.typ.key) + lazy val definitelyDoesNotShuffle: Boolean = + TableOrderBy.isAlreadyOrdered(sortFields, child.typ.key) // TableOrderBy expects an unkeyed child, so that we can better optimize by // pushing these two steps around as 
needed @@ -3286,7 +4047,8 @@ case class TableOrderBy(child: TableIR, sortFields: IndexedSeq[SortField]) exten lazy val typ: TableType = child.typ.copy(key = FastSeq()) - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val prev = child.execute(ctx, r).asTableValue(ctx) val physicalKey = prev.rvd.typ.key @@ -3308,18 +4070,22 @@ case class TableOrderBy(child: TableIR, sortFields: IndexedSeq[SortField]) exten val codec = TypedCodecSpec(prev.rvd.rowPType, BufferSpec.wireSpec) val rdd = prev.rvd.keyedEncodedRDD(ctx, codec, sortFields.map(_.field)).sortBy(_._1)(ord, act) val (rowPType: PStruct, orderedCRDD) = codec.decodeRDD(ctx, rowType, rdd.map(_._2)) - new TableValueIntermediate(TableValue(ctx, typ, prev.globals, RVD.unkeyed(rowPType, orderedCRDD))) + new TableValueIntermediate(TableValue( + ctx, + typ, + prev.globals, + RVD.unkeyed(rowPType, orderedCRDD), + )) } } -/** Create a Table from a MatrixTable, storing the column values in a global - * field 'colsFieldName', and storing the entry values in a row field - * 'entriesFieldName'. +/** Create a Table from a MatrixTable, storing the column values in a global field 'colsFieldName', + * and storing the entry values in a row field 'entriesFieldName'. */ case class CastMatrixToTable( child: MatrixIR, entriesFieldName: String, - colsFieldName: String + colsFieldName: String, ) extends TableIR { lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound @@ -3336,7 +4102,8 @@ case class CastMatrixToTable( override def partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts } -case class TableRename(child: TableIR, rowMap: Map[String, String], globalMap: Map[String, String]) extends TableIR { +case class TableRename(child: TableIR, rowMap: Map[String, String], globalMap: Map[String, String]) + extends TableIR { override def typecheck(): Unit = { assert(rowMap.keys.forall(child.typ.rowType.hasField)) assert(globalMap.keys.forall(child.typ.globalType.hasField)) @@ -3349,7 +4116,7 @@ case class TableRename(child: TableIR, rowMap: Map[String, String], globalMap: M lazy val typ: TableType = child.typ.copy( rowType = child.typ.rowType.rename(rowMap), globalType = child.typ.globalType.rename(globalMap), - key = child.typ.key.map(k => rowMap.getOrElse(k, k)) + key = child.typ.key.map(k => rowMap.getOrElse(k, k)), ) override def partitionCounts: Option[IndexedSeq[Long]] = child.partitionCounts @@ -3361,12 +4128,15 @@ case class TableRename(child: TableIR, rowMap: Map[String, String], globalMap: M TableRename(newChild, rowMap, globalMap) } - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = TableValueIntermediate( - child.execute(ctx, r).asTableValue(ctx).rename(globalMap, rowMap)) + child.execute(ctx, r).asTableValue(ctx).rename(globalMap, rowMap) + ) } -case class TableFilterIntervals(child: TableIR, intervals: IndexedSeq[Interval], keep: Boolean) extends TableIR { +case class TableFilterIntervals(child: TableIR, intervals: IndexedSeq[Interval], keep: Boolean) + extends TableIR { lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child) lazy val rowCountUpperBound: Option[Long] = child.rowCountUpperBound @@ -3378,22 +4148,26 @@ case class TableFilterIntervals(child: TableIR, intervals: IndexedSeq[Interval], 
override def typ: TableType = child.typ - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val tv = child.execute(ctx, r).asTableValue(ctx) val partitioner = RVDPartitioner.union( ctx.stateManager, tv.typ.keyType, intervals, - tv.typ.keyType.size - 1) + tv.typ.keyType.size - 1, + ) new TableValueIntermediate( - TableValue(ctx, tv.typ, tv.globals, tv.rvd.filterIntervals(partitioner, keep))) + TableValue(ctx, tv.typ, tv.globals, tv.rvd.filterIntervals(partitioner, keep)) + ) } } case class MatrixToTableApply(child: MatrixIR, function: MatrixToTableFunction) extends TableIR { lazy val childrenSeq: IndexedSeq[BaseIR] = Array(child) - lazy val rowCountUpperBound: Option[Long] = if (function.preservesPartitionCounts) child.rowCountUpperBound else None + lazy val rowCountUpperBound: Option[Long] = + if (function.preservesPartitionCounts) child.rowCountUpperBound else None def copy(newChildren: IndexedSeq[BaseIR]): TableIR = { val IndexedSeq(newChild: MatrixIR) = newChildren @@ -3419,17 +4193,18 @@ case class TableToTableApply(child: TableIR, function: TableToTableFunction) ext override def partitionCounts: Option[IndexedSeq[Long]] = if (function.preservesPartitionCounts) child.partitionCounts else None - lazy val rowCountUpperBound: Option[Long] = if (function.preservesPartitionCounts) child.rowCountUpperBound else None + lazy val rowCountUpperBound: Option[Long] = + if (function.preservesPartitionCounts) child.rowCountUpperBound else None - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = new TableValueIntermediate(function.execute(ctx, child.execute(ctx, r).asTableValue(ctx))) - } } case class BlockMatrixToTableApply( bm: BlockMatrixIR, aux: IR, - function: BlockMatrixToTableFunction + function: BlockMatrixToTableFunction, ) extends TableIR { override lazy val childrenSeq: IndexedSeq[BaseIR] = Array(bm, aux) @@ -3440,11 +4215,13 @@ case class BlockMatrixToTableApply( BlockMatrixToTableApply( newChildren(0).asInstanceOf[BlockMatrixIR], newChildren(1).asInstanceOf[IR], - function) + function, + ) override lazy val typ: TableType = function.typ(bm.typ, aux.typ) - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = { val b = bm.execute(ctx) val a = CompileAndEvaluate[Any](ctx, aux, optimize = false) new TableValueIntermediate(function.execute(ctx, b, a)) @@ -3466,9 +4243,9 @@ case class BlockMatrixToTable(child: BlockMatrixIR) extends TableIR { TableType(rvType, Array[String](), TStruct.empty) } - protected[ir] override def execute(ctx: ExecuteContext, r: LoweringAnalyses): TableExecuteIntermediate = { + override protected[ir] def execute(ctx: ExecuteContext, r: LoweringAnalyses) + : TableExecuteIntermediate = TableValueIntermediate(child.execute(ctx).entriesTable(ctx)) - } } case class RelationalLetTable(name: String, value: IR, body: TableIR) extends TableIR { diff --git a/hail/src/main/scala/is/hail/expr/ir/TableValue.scala b/hail/src/main/scala/is/hail/expr/ir/TableValue.scala index 1fa2512dc17..472abcb10fc 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableValue.scala +++ 
b/hail/src/main/scala/is/hail/expr/ir/TableValue.scala @@ -2,22 +2,23 @@ package is.hail.expr.ir import is.hail.HailContext import is.hail.annotations._ -import is.hail.asm4s.{HailClassLoader, theHailClassLoaderForSparkWorkers} -import is.hail.backend.spark.SparkTaskContext +import is.hail.asm4s.{theHailClassLoaderForSparkWorkers, HailClassLoader} import is.hail.backend.{BroadcastValue, ExecuteContext, HailTaskContext} +import is.hail.backend.spark.SparkTaskContext import is.hail.expr.TableAnnotationImpex import is.hail.expr.ir.lowering.{RVDToTableStage, TableStage, TableStageToRVD} import is.hail.io.exportTypes import is.hail.io.fs.FS import is.hail.rvd.{RVD, RVDContext, RVDPartitioner, RVDType} import is.hail.sparkextras.ContextRDD +import is.hail.types.{MatrixType, TableType} import is.hail.types.physical.{PArray, PCanonicalArray, PCanonicalStruct, PStruct} import is.hail.types.virtual.{Field, TArray, TStruct} -import is.hail.types.{MatrixType, TableType} import is.hail.utils._ + import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel object TableExecuteIntermediate { @@ -35,9 +36,8 @@ sealed trait TableExecuteIntermediate { } case class TableValueIntermediate(tv: TableValue) extends TableExecuteIntermediate { - def asTableStage(ctx: ExecuteContext): TableStage = { + def asTableStage(ctx: ExecuteContext): TableStage = RVDToTableStage(tv.rvd, tv.globals.toEncodedLiteral(ctx.theHailClassLoader)) - } def asTableValue(ctx: ExecuteContext): TableValue = tv @@ -55,39 +55,55 @@ case class TableStageIntermediate(ts: TableStage) extends TableExecuteIntermedia def partitioner: RVDPartitioner = ts.partitioner } - object TableValue { - def apply(ctx: ExecuteContext, rowType: PStruct, key: IndexedSeq[String], rdd: ContextRDD[Long]): TableValue = { + def apply(ctx: ExecuteContext, rowType: PStruct, key: IndexedSeq[String], rdd: ContextRDD[Long]) + : TableValue = { assert(rowType.required) val tt = TableType(rowType.virtualType, key, TStruct.empty) - TableValue(ctx, - tt, - BroadcastRow.empty(ctx), - RVD.coerce(ctx, RVDType(rowType, key), rdd)) + TableValue(ctx, tt, BroadcastRow.empty(ctx), RVD.coerce(ctx, RVDType(rowType, key), rdd)) } - def apply(ctx: ExecuteContext, rowType: TStruct, key: IndexedSeq[String], rdd: RDD[Row], rowPType: Option[PStruct] = None): TableValue = { - val canonicalRowType = rowPType.getOrElse(PCanonicalStruct.canonical(rowType).setRequired(true).asInstanceOf[PStruct]) + def apply( + ctx: ExecuteContext, + rowType: TStruct, + key: IndexedSeq[String], + rdd: RDD[Row], + rowPType: Option[PStruct] = None, + ): TableValue = { + val canonicalRowType = rowPType.getOrElse( + PCanonicalStruct.canonical(rowType).setRequired(true).asInstanceOf[PStruct] + ) assert(canonicalRowType.required) val tt = TableType(rowType, key, TStruct.empty) - TableValue(ctx, + TableValue( + ctx, tt, BroadcastRow.empty(ctx), - RVD.coerce(ctx, + RVD.coerce( + ctx, RVDType(canonicalRowType, key), - ContextRDD.weaken(rdd).toRegionValues(canonicalRowType))) + ContextRDD.weaken(rdd).toRegionValues(canonicalRowType), + ), + ) } } case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow, rvd: RVD) { if (typ.rowType != rvd.rowType) - throw new RuntimeException(s"row mismatch:\n typ: ${ typ.rowType.parsableString() }\n rvd: ${ rvd.rowType.parsableString() }") + throw new RuntimeException( + s"row mismatch:\n typ: 
${typ.rowType.parsableString()}\n rvd: ${rvd.rowType.parsableString()}" + ) + if (!rvd.typ.key.startsWith(typ.key)) - throw new RuntimeException(s"key mismatch:\n typ: ${ typ.key }\n rvd: ${ rvd.typ.key }") + throw new RuntimeException(s"key mismatch:\n typ: ${typ.key}\n rvd: ${rvd.typ.key}") + if (typ.globalType != globals.t.virtualType) - throw new RuntimeException(s"globals mismatch:\n typ: ${ typ.globalType.parsableString() }\n val: ${ globals.t.virtualType.parsableString() }") + throw new RuntimeException( + s"globals mismatch:\n typ: ${typ.globalType.parsableString()}\n val: ${globals.t.virtualType.parsableString()}" + ) + if (!globals.t.required) - throw new RuntimeException(s"globals not required; ${ globals.t }") + throw new RuntimeException(s"globals not required; ${globals.t}") def rdd: RDD[Row] = rvd.toRows @@ -95,23 +111,50 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow def persist(ctx: ExecuteContext, level: StorageLevel) = TableValue(ctx, typ, globals, rvd.persist(ctx, level)) - def filterWithPartitionOp[P](theHailClassLoader: HailClassLoader, fs: BroadcastValue[FS], partitionOp: (HailClassLoader, FS, HailTaskContext, Region) => P)(pred: (P, RVDContext, Long, Long) => Boolean): TableValue = { + def filterWithPartitionOp[P]( + theHailClassLoader: HailClassLoader, + fs: BroadcastValue[FS], + partitionOp: (HailClassLoader, FS, HailTaskContext, Region) => P, + )( + pred: (P, RVDContext, Long, Long) => Boolean + ): TableValue = { val localGlobals = globals.broadcast(theHailClassLoader) - copy(rvd = rvd.filterWithContext[(P, Long)]( - { (partitionIdx, ctx) => - val globalRegion = ctx.partitionRegion - ( - partitionOp(theHailClassLoaderForSparkWorkers, fs.value, SparkTaskContext.get(), globalRegion), - localGlobals.value.readRegionValue(globalRegion, theHailClassLoaderForSparkWorkers) - ) - }, { case ((p, glob), ctx, ptr) => pred(p, ctx, ptr, glob) })) + copy(rvd = + rvd.filterWithContext[(P, Long)]( + { (partitionIdx, ctx) => + val globalRegion = ctx.partitionRegion + ( + partitionOp( + theHailClassLoaderForSparkWorkers, + fs.value, + SparkTaskContext.get(), + globalRegion, + ), + localGlobals.value.readRegionValue(globalRegion, theHailClassLoaderForSparkWorkers), + ) + }, + { case ((p, glob), ctx, ptr) => pred(p, ctx, ptr, glob) }, + ) + ) } - def filter(theHailClassLoader: HailClassLoader, fs: BroadcastValue[FS], p: (RVDContext, Long, Long) => Boolean): TableValue = { - filterWithPartitionOp(theHailClassLoader, fs, (_, _, _, _) => ())((_, ctx, ptr, glob) => p(ctx, ptr, glob)) - } - - def export(ctx: ExecuteContext, path: String, typesFile: String = null, header: Boolean = true, exportType: String = ExportType.CONCATENATED, delimiter: String = "\t") { + def filter( + theHailClassLoader: HailClassLoader, + fs: BroadcastValue[FS], + p: (RVDContext, Long, Long) => Boolean, + ): TableValue = + filterWithPartitionOp(theHailClassLoader, fs, (_, _, _, _) => ())((_, ctx, ptr, glob) => + p(ctx, ptr, glob) + ) + + def export( + ctx: ExecuteContext, + path: String, + typesFile: String = null, + header: Boolean = true, + exportType: String = ExportType.CONCATENATED, + delimiter: String = "\t", + ): Unit = { val fs = ctx.fs fs.delete(path, recursive = true) @@ -137,30 +180,41 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow sb.result() } - }.writeTable(ctx, path, Some(fields.map(_.name).mkString(localDelim)).filter(_ => header), exportType = exportType) + }.writeTable( + ctx, + path, + 
Some(fields.map(_.name).mkString(localDelim)).filter(_ => header), + exportType = exportType, + ) } - def toDF(): DataFrame = { + def toDF(): DataFrame = HailContext.sparkBackend("toDF").sparkSession.createDataFrame( rvd.toRows, - typ.rowType.schema.asInstanceOf[StructType]) - } + typ.rowType.schema.asInstanceOf[StructType], + ) def rename(globalMap: Map[String, String], rowMap: Map[String, String]): TableValue = { - TableValue(ctx, + TableValue( + ctx, typ.copy( rowType = typ.rowType.rename(rowMap), globalType = typ.globalType.rename(globalMap), - key = typ.key.map(k => rowMap.getOrElse(k, k))), - globals.copy(t = globals.t.rename(globalMap)), rvd = rvd.cast(rvd.rowPType.rename(rowMap))) + key = typ.key.map(k => rowMap.getOrElse(k, k)), + ), + globals.copy(t = globals.t.rename(globalMap)), + rvd = rvd.cast(rvd.rowPType.rename(rowMap)), + ) } - def toMatrixValue(colKey: IndexedSeq[String], + def toMatrixValue( + colKey: IndexedSeq[String], colsFieldName: String = LowerMatrixIR.colsFieldName, - entriesFieldName: String = LowerMatrixIR.entriesFieldName): MatrixValue = { + entriesFieldName: String = LowerMatrixIR.entriesFieldName, + ): MatrixValue = { val (colType, colsFieldIdx) = typ.globalType.field(colsFieldName) match { - case Field(_, TArray(t@TStruct(_)), idx) => (t, idx) + case Field(_, TArray(t @ TStruct(_)), idx) => (t, idx) case Field(_, t, _) => fatal(s"expected cols field to be an array of structs, found $t") } @@ -170,7 +224,10 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow colType, typ.key, typ.rowType.deleteKey(entriesFieldName), - typ.rowType.field(MatrixType.entriesIdentifier).typ.asInstanceOf[TArray].elementType.asInstanceOf[TStruct]) + typ.rowType.field(MatrixType.entriesIdentifier).typ.asInstanceOf[ + TArray + ].elementType.asInstanceOf[TStruct], + ) val globalsT = globals.t val colsT = globalsT.field(colsFieldName).typ.asInstanceOf[PArray] @@ -181,12 +238,18 @@ case class TableValue(ctx: ExecuteContext, typ: TableType, globals: BroadcastRow else globals.cast( globalsT.insertFields(FastSeq( - colsFieldName -> PCanonicalArray(colsT.elementType.setRequired(true), true)))) + colsFieldName -> PCanonicalArray(colsT.elementType.setRequired(true), true) + )) + ) val newTV = TableValue(ctx, typ, globals2, rvd) - MatrixValue(mType, newTV.rename( - Map(colsFieldName -> LowerMatrixIR.colsFieldName), - Map(entriesFieldName -> LowerMatrixIR.entriesFieldName))) + MatrixValue( + mType, + newTV.rename( + Map(colsFieldName -> LowerMatrixIR.colsFieldName), + Map(entriesFieldName -> LowerMatrixIR.entriesFieldName), + ), + ) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/TableWriter.scala b/hail/src/main/scala/is/hail/expr/ir/TableWriter.scala index c0426ea5fc3..4109948648d 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TableWriter.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TableWriter.scala @@ -8,33 +8,38 @@ import is.hail.expr.TableAnnotationImpex import is.hail.expr.ir.functions.StringFunctions import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage} import is.hail.expr.ir.streams.StreamProducer +import is.hail.io.{AbstractTypedCodecSpec, BufferSpec, OutputBuffer, TypedCodecSpec} import is.hail.io.fs.FS import is.hail.io.index.StagedIndexWriter -import is.hail.io.{AbstractTypedCodecSpec, BufferSpec, OutputBuffer, TypedCodecSpec} import is.hail.rvd.{AbstractRVDSpec, IndexSpec, RVDPartitioner, RVDSpecMaker} import is.hail.types._ import is.hail.types.encoded.EType import is.hail.types.physical._ -import 
is.hail.types.physical.stypes.concrete.{SJavaArrayString, SJavaArrayStringValue, SStackStruct} +import is.hail.types.physical.stypes.{EmitType, SCode, SValue} +import is.hail.types.physical.stypes.concrete.{ + SJavaArrayString, SJavaArrayStringValue, SStackStruct, +} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives.{SBooleanValue, SInt64, SInt64Value} -import is.hail.types.physical.stypes.{EmitType, SCode, SValue} import is.hail.types.virtual._ import is.hail.utils._ import is.hail.utils.richUtils.ByteTrackingOutputStream import is.hail.variant.ReferenceGenome -import org.json4s.{DefaultFormats, Formats, JBool, JObject, ShortTypeHints} + +import scala.language.existentials import java.io.{BufferedOutputStream, OutputStream} import java.nio.file.{FileSystems, Path} import java.util.UUID -import scala.language.existentials +import org.json4s.{DefaultFormats, Formats, JBool, JObject, ShortTypeHints} object TableWriter { - implicit val formats: Formats = new DefaultFormats() { + implicit val formats: Formats = new DefaultFormats() { override val typeHints = ShortTypeHints( - List(classOf[TableNativeFanoutWriter], classOf[TableNativeWriter], classOf[TableTextWriter]), typeHintFieldName = "name") + List(classOf[TableNativeFanoutWriter], classOf[TableNativeWriter], classOf[TableTextWriter]), + typeHintFieldName = "name", + ) } } @@ -47,50 +52,96 @@ abstract class TableWriter { } def lower(ctx: ExecuteContext, ts: TableStage, r: RTable): IR = - throw new LowererUnsupportedOperation(s"${ this.getClass } does not have defined lowering!") + throw new LowererUnsupportedOperation(s"${this.getClass} does not have defined lowering!") def canLowerEfficiently: Boolean = true } object TableNativeWriter { - def lower(ctx: ExecuteContext, ts: TableStage, path: String, overwrite: Boolean, stageLocally: Boolean, - rowSpec: TypedCodecSpec, globalSpec: TypedCodecSpec): IR = { + def lower( + ctx: ExecuteContext, + ts: TableStage, + path: String, + overwrite: Boolean, + stageLocally: Boolean, + rowSpec: TypedCodecSpec, + globalSpec: TypedCodecSpec, + ): IR = { // write out partitioner key, which may be stricter than table key val partitioner = ts.partitioner val pKey: PStruct = tcoerce[PStruct](rowSpec.decodedPType(partitioner.kType)) - val rowWriter = PartitionNativeWriter(rowSpec, pKey.fieldNames, s"$path/rows/parts/", Some(s"$path/index/" -> pKey), - if (stageLocally) Some(FileSystems.getDefault.getPath(ctx.localTmpdir, s"hail_staging_tmp_${UUID.randomUUID()}", "rows", "parts")) else None + val rowWriter = PartitionNativeWriter( + rowSpec, + pKey.fieldNames, + s"$path/rows/parts/", + Some(s"$path/index/" -> pKey), + if (stageLocally) Some(FileSystems.getDefault.getPath( + ctx.localTmpdir, + s"hail_staging_tmp_${UUID.randomUUID()}", + "rows", + "parts", + )) + else None, ) - val globalWriter = PartitionNativeWriter(globalSpec, IndexedSeq(), s"$path/globals/parts/", None, None) + val globalWriter = + PartitionNativeWriter(globalSpec, IndexedSeq(), s"$path/globals/parts/", None, None) RelationalWriter.scoped(path, overwrite, Some(ts.tableType))( ts.mapContexts { oldCtx => val d = digitsNeeded(ts.numPartitions) - val partFiles = Literal(TArray(TString), Array.tabulate(ts.numPartitions)(i => s"${ partFile(d, i) }-").toFastSeq) + val partFiles = Literal( + TArray(TString), + Array.tabulate(ts.numPartitions)(i => s"${partFile(d, i)}-").toFastSeq, + ) zip2(oldCtx, ToStream(partFiles), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => MakeStruct(FastSeq( "oldCtx" -> 
ctxElt, - "writeCtx" -> pf)) + "writeCtx" -> pf, + )) } - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals( "table_native_writer") { (rows, ctxRef) => - val file = GetField(ctxRef, "writeCtx") - WritePartition(rows, file + UUID4(), rowWriter) + }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("table_native_writer") { + (rows, ctxRef) => + val file = GetField(ctxRef, "writeCtx") + WritePartition(rows, file + UUID4(), rowWriter) } { (parts, globals) => - val writeGlobals = WritePartition(MakeStream(FastSeq(globals), TStream(globals.typ)), - Str(partFile(1, 0)), globalWriter) + val writeGlobals = WritePartition( + MakeStream(FastSeq(globals), TStream(globals.typ)), + Str(partFile(1, 0)), + globalWriter, + ) bindIR(parts) { fileCountAndDistinct => Begin(FastSeq( - WriteMetadata(MakeArray(GetField(writeGlobals, "filePath")), - RVDSpecWriter(s"$path/globals", RVDSpecMaker(globalSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)))), - WriteMetadata(ToArray(mapIR(ToStream(fileCountAndDistinct)) { fc => GetField(fc, "filePath") }), - RVDSpecWriter(s"$path/rows", RVDSpecMaker(rowSpec, partitioner, IndexSpec.emptyAnnotation("../index", tcoerce[PStruct](pKey))))), - WriteMetadata(ToArray(mapIR(ToStream(fileCountAndDistinct)) { fc => - SelectFields(fc, FastSeq("partitionCounts", "distinctlyKeyed", "firstKey", "lastKey")) - }), - TableSpecWriter(path, ts.tableType, "rows", "globals", "references", log = true)))) + WriteMetadata( + MakeArray(GetField(writeGlobals, "filePath")), + RVDSpecWriter( + s"$path/globals", + RVDSpecMaker(globalSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), + ), + ), + WriteMetadata( + ToArray(mapIR(ToStream(fileCountAndDistinct))(fc => GetField(fc, "filePath"))), + RVDSpecWriter( + s"$path/rows", + RVDSpecMaker( + rowSpec, + partitioner, + IndexSpec.emptyAnnotation("../index", tcoerce[PStruct](pKey)), + ), + ), + ), + WriteMetadata( + ToArray(mapIR(ToStream(fileCountAndDistinct)) { fc => + SelectFields( + fc, + FastSeq("partitionCounts", "distinctlyKeyed", "firstKey", "lastKey"), + ) + }), + TableSpecWriter(path, ts.tableType, "rows", "globals", "references", log = true), + ), + )) } } ) @@ -101,13 +152,18 @@ case class TableNativeWriter( path: String, overwrite: Boolean = true, stageLocally: Boolean = false, - codecSpecJSONStr: String = null + codecSpecJSONStr: String = null, ) extends TableWriter { override def lower(ctx: ExecuteContext, ts: TableStage, r: RTable): IR = { val bufferSpec: BufferSpec = BufferSpec.parseOrDefault(codecSpecJSONStr) - val rowSpec = TypedCodecSpec(EType.fromTypeAndAnalysis(ts.rowType, r.rowType), ts.rowType, bufferSpec) - val globalSpec = TypedCodecSpec(EType.fromTypeAndAnalysis(ts.globalType, r.globalType), ts.globalType, bufferSpec) + val rowSpec = + TypedCodecSpec(EType.fromTypeAndAnalysis(ts.rowType, r.rowType), ts.rowType, bufferSpec) + val globalSpec = TypedCodecSpec( + EType.fromTypeAndAnalysis(ts.globalType, r.globalType), + ts.globalType, + bufferSpec, + ) TableNativeWriter.lower(ctx, ts, path, overwrite, stageLocally, rowSpec, globalSpec) } @@ -115,34 +171,40 @@ case class TableNativeWriter( object PartitionNativeWriter { val ctxType = TString + def fullReturnType(keyType: TStruct): TStruct = TStruct( "filePath" -> TString, "partitionCounts" -> TInt64, "distinctlyKeyed" -> TBoolean, "firstKey" -> keyType, "lastKey" -> keyType, - "partitionByteSize" -> TInt64 + "partitionByteSize" -> TInt64, ) def returnType(keyType: TStruct, trackTotalBytes: Boolean): TStruct = { val t = PartitionNativeWriter.fullReturnType(keyType) - 
if (trackTotalBytes) t else t.filterSet(Set("partitionByteSize"), include=false)._1 + if (trackTotalBytes) t else t.filterSet(Set("partitionByteSize"), include = false)._1 } } -case class PartitionNativeWriter(spec: AbstractTypedCodecSpec, - keyFields: IndexedSeq[String], - partPrefix: String, - index: Option[(String, PStruct)] = None, - stagingFolder: Option[Path] = None, - trackTotalBytes: Boolean = false - ) extends PartitionWriter { +case class PartitionNativeWriter( + spec: AbstractTypedCodecSpec, + keyFields: IndexedSeq[String], + partPrefix: String, + index: Option[(String, PStruct)] = None, + stagingFolder: Option[Path] = None, + trackTotalBytes: Boolean = false, +) extends PartitionWriter { val keyType = spec.encodedVirtualType.asInstanceOf[TStruct].select(keyFields)._1 def ctxType = PartitionNativeWriter.ctxType val returnType = PartitionNativeWriter.returnType(keyType, trackTotalBytes) - def unionTypeRequiredness(r: TypeWithRequiredness, ctxType: TypeWithRequiredness, streamType: RIterable): Unit = { + def unionTypeRequiredness( + r: TypeWithRequiredness, + ctxType: TypeWithRequiredness, + streamType: RIterable, + ): Unit = { val rs = r.asInstanceOf[RStruct] val rKeyType = streamType.elementType.asInstanceOf[RStruct].select(keyFields.toArray) rs.field("firstKey").union(false) @@ -156,22 +218,29 @@ case class PartitionNativeWriter(spec: AbstractTypedCodecSpec, class StreamConsumer( ctx: SValue, private val cb: EmitCodeBuilder, - private val region: Value[Region] + private val region: Value[Region], ) { private val mb = cb.emb + private val writeIndexInfo = index.map { case (name, ktype) => - val branchingFactor = Option(mb.ctx.getFlag("index_branching_factor")).map(_.toInt).getOrElse(4096) - (name, ktype, StagedIndexWriter.withDefaults(ktype, mb.ecb, branchingFactor = branchingFactor)) + val branchingFactor = + Option(mb.ctx.getFlag("index_branching_factor")).map(_.toInt).getOrElse(4096) + ( + name, + ktype, + StagedIndexWriter.withDefaults(ktype, mb.ecb, branchingFactor = branchingFactor), + ) } private val filename = mb.newLocal[String]("filename") - private val stagingInfo = stagingFolder.map { folder => - (folder, mb.newLocal[String]("stage")) - } + private val stagingInfo = stagingFolder.map(folder => (folder, mb.newLocal[String]("stage"))) private val ob = mb.newLocal[OutputBuffer]("write_ob") private val n = mb.newLocal[Long]("partition_count") - private val byteCount = if (trackTotalBytes) Some(mb.newPLocal("partition_byte_count", SInt64)) else None + + private val byteCount = + if (trackTotalBytes) Some(mb.newPLocal("partition_byte_count", SInt64)) else None + private val distinctlyKeyed = mb.newLocal[Boolean]("distinctlyKeyed") private val keyEmitType = EmitType(spec.decodedPType(keyType).sType, false) private val firstSeenSettable = mb.newEmitLocal("pnw_firstSeen", keyEmitType) @@ -179,7 +248,10 @@ case class PartitionNativeWriter(spec: AbstractTypedCodecSpec, private val lastSeenRegion = mb.newLocal[Region]("last_key_region") def setup(): Unit = { - cb.assign(distinctlyKeyed, !keyFields.isEmpty) // True until proven otherwise, if there's a key to care about at all. + cb.assign( + distinctlyKeyed, + !keyFields.isEmpty, + ) // True until proven otherwise, if there's a key to care about at all. // Start off missing, we will use this to determine if we haven't processed any rows yet. 
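// An illustrative, plain-Scala model of the per-partition bookkeeping this
// consumer performs while streaming rows: count them, remember the first and
// last key, and clear `distinctlyKeyed` as soon as two consecutive keys
// compare equal (rows are assumed to arrive key-ordered, so equal keys are
// adjacent; with an empty table key the real writer reports
// distinctlyKeyed = false up front). `PartSummary` and `summarize` are
// hypothetical names, not part of this patch.
final case class PartSummary[K](
  count: Long,
  distinctlyKeyed: Boolean,
  firstKey: Option[K],
  lastKey: Option[K],
)

def summarize[R, K](rows: Iterator[R], key: R => K): PartSummary[K] =
  rows.foldLeft(PartSummary[K](0L, distinctlyKeyed = true, None, None)) { (acc, row) =>
    val k = key(row)
    PartSummary(
      count = acc.count + 1,
      distinctlyKeyed = acc.distinctlyKeyed && !acc.lastKey.contains(k),
      firstKey = acc.firstKey.orElse(Some(k)),
      lastKey = Some(k),
    )
  }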
cb.assign(firstSeenSettable, EmitCode.missing(cb.emb, keyEmitType.st)) cb.assign(lastSeenSettable, EmitCode.missing(cb.emb, keyEmitType.st)) @@ -194,57 +266,81 @@ case class PartitionNativeWriter(spec: AbstractTypedCodecSpec, writer.init(cb, indexFile, cb.memoize(mb.getObject[Map[String, Any]](Map.empty))) } - val stagingFile = stagingInfo.map { case (folder, fileRef) => cb.assign(fileRef, const(s"$folder/").concat(ctxValue)) fileRef } val os = mb.newLocal[ByteTrackingOutputStream]("write_os") - cb.assign(os, Code.newInstance[ByteTrackingOutputStream, OutputStream]( - mb.create(stagingFile.getOrElse(filename).get) - )) + cb.assign( + os, + Code.newInstance[ByteTrackingOutputStream, OutputStream]( + mb.create(stagingFile.getOrElse(filename).get) + ), + ) cb.assign(ob, spec.buildCodeOutputBuffer(Code.checkcast[OutputStream](os))) cb.assign(n, 0L) } - def consumeElement(cb: EmitCodeBuilder, codeRow: SValue, elementRegion: Settable[Region]): Unit = { + def consumeElement(cb: EmitCodeBuilder, codeRow: SValue, elementRegion: Settable[Region]) + : Unit = { val row = codeRow.asBaseStruct writeIndexInfo.foreach { case (_, indexKeyType, writer) => - writer.add(cb, { - IEmitCode.present(cb, indexKeyType.asInstanceOf[PCanonicalBaseStruct] - .constructFromFields(cb, elementRegion, indexKeyType.fields.map { f => - EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name)) - }, - deepCopy = true + writer.add( + cb, { + IEmitCode.present( + cb, + indexKeyType.asInstanceOf[PCanonicalBaseStruct] + .constructFromFields( + cb, + elementRegion, + indexKeyType.fields.map { f => + EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name)) + }, + deepCopy = true, + ), ) - ) - }, + }, ob.invoke[Long]("indexOffset"), - IEmitCode.present(cb, PCanonicalStruct().loadCheapSCode(cb, 0L)) + IEmitCode.present(cb, PCanonicalStruct().loadCheapSCode(cb, 0L)), ) } - val key = SStackStruct.constructFromArgs(cb, elementRegion, keyType, keyType.fields.map { f => - EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name)) - }:_*) + val key = SStackStruct.constructFromArgs( + cb, + elementRegion, + keyType, + keyType.fields.map(f => EmitCode.fromI(cb.emb)(cb => row.loadField(cb, f.name))): _* + ) if (!keyFields.isEmpty) { - cb.if_(distinctlyKeyed, { - lastSeenSettable.loadI(cb).consume(cb, { - // If there's no last seen, we are in the first row. - cb.assign(firstSeenSettable, EmitValue.present(key.copyToRegion(cb, region, firstSeenSettable.st))) - }, { lastSeen => - val comparator = EQ(lastSeenSettable.emitType.virtualType).codeOrdering(cb.emb.ecb, lastSeenSettable.st, key.st) - val equalToLast = comparator(cb, lastSeenSettable, EmitValue.present(key)) - cb.if_(equalToLast.asInstanceOf[Value[Boolean]], { - cb.assign(distinctlyKeyed, false) - }) - }) - }) + cb.if_( + distinctlyKeyed, { + lastSeenSettable.loadI(cb).consume( + cb, + // If there's no last seen, we are in the first row. 
+ cb.assign( + firstSeenSettable, + EmitValue.present(key.copyToRegion(cb, region, firstSeenSettable.st)), + ), + { lastSeen => + val comparator = EQ(lastSeenSettable.emitType.virtualType).codeOrdering( + cb.emb.ecb, + lastSeenSettable.st, + key.st, + ) + val equalToLast = comparator(cb, lastSeenSettable, EmitValue.present(key)) + cb.if_(equalToLast.asInstanceOf[Value[Boolean]], cb.assign(distinctlyKeyed, false)) + }, + ) + }, + ) cb += lastSeenRegion.clearRegion() - cb.assign(lastSeenSettable, IEmitCode.present(cb, key.copyToRegion(cb, lastSeenRegion, lastSeenSettable.st))) + cb.assign( + lastSeenSettable, + IEmitCode.present(cb, key.copyToRegion(cb, lastSeenRegion, lastSeenSettable.st)), + ) } cb += ob.writeByte(1.asInstanceOf[Byte]) @@ -253,9 +349,7 @@ case class PartitionNativeWriter(spec: AbstractTypedCodecSpec, .apply(cb, row, ob) cb.assign(n, n + 1L) - byteCount.foreach { bc => - cb.assign(bc, SCode.add(cb, bc, row.sizeToStoreInBytes(cb), true)) - } + byteCount.foreach(bc => cb.assign(bc, SCode.add(cb, bc, row.sizeToStoreInBytes(cb), true))) } def result(): SValue = { @@ -268,16 +362,22 @@ case class PartitionNativeWriter(spec: AbstractTypedCodecSpec, cb += mb.getFS.invoke[String, String, Boolean, Unit]("copy", source, filename, const(true)) } - lastSeenSettable.loadI(cb).consume(cb, { /* do nothing */ }, { lastSeen => - cb.assign(lastSeenSettable, IEmitCode.present(cb, lastSeen.copyToRegion(cb, region, lastSeenSettable.st))) - }) + lastSeenSettable.loadI(cb).consume( + cb, + { /* do nothing */ }, + lastSeen => + cb.assign( + lastSeenSettable, + IEmitCode.present(cb, lastSeen.copyToRegion(cb, region, lastSeenSettable.st)), + ), + ) cb += lastSeenRegion.invalidate() val values = Seq[EmitCode]( EmitCode.present(mb, ctx), EmitCode.present(mb, new SInt64Value(n)), EmitCode.present(mb, new SBooleanValue(distinctlyKeyed)), firstSeenSettable, - lastSeenSettable + lastSeenSettable, ) ++ byteCount.map(EmitCode.present(mb, _)) SStackStruct.constructFromArgs(cb, region, returnType.asInstanceOf[TBaseStruct], values: _*) @@ -289,13 +389,13 @@ case class PartitionNativeWriter(spec: AbstractTypedCodecSpec, cb: EmitCodeBuilder, stream: StreamProducer, context: EmitCode, - region: Value[Region] + region: Value[Region], ): IEmitCode = { - val ctx = context.toI(cb).get(cb) + val ctx = context.toI(cb).getOrAssert(cb) val consumer = new StreamConsumer(ctx, cb, region) consumer.setup() stream.memoryManagedConsume(region, cb) { cb => - val element = stream.element.toI(cb).get(cb, "row can't be missing") + val element = stream.element.toI(cb).getOrFatal(cb, "row can't be missing") consumer.consumeElement(cb, element, stream.elementRegion) } IEmitCode.present(cb, consumer.result()) @@ -308,63 +408,89 @@ case class RVDSpecWriter(path: String, spec: RVDSpecMaker) extends MetadataWrite def writeMetadata( writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, - region: Value[Region]): Unit = { + region: Value[Region], + ): Unit = { cb += cb.emb.getFS.invoke[String, Unit]("mkDir", path) - val a = writeAnnotations.get(cb, "write annotations can't be missing!").asIndexable + val a = writeAnnotations.getOrFatal(cb, "write annotations can't be missing!").asIndexable val partFiles = cb.newLocal[Array[String]]("partFiles") val n = cb.newLocal[Int]("n", a.loadLength()) val i = cb.newLocal[Int]("i", 0) cb.assign(partFiles, Code.newArray[String](n)) - cb.while_(i < n, { - val s = a.loadElement(cb, i).get(cb, "file name can't be missing!").asString - cb += partFiles.update(i, s.loadString(cb)) - cb.assign(i, i + 1) - }) + 
cb.while_( + i < n, { + val s = a.loadElement(cb, i).getOrFatal(cb, "file name can't be missing!").asString + cb += partFiles.update(i, s.loadString(cb)) + cb.assign(i, i + 1) + }, + ) cb += cb.emb.getObject(spec) .invoke[Array[String], AbstractRVDSpec]("apply", partFiles) .invoke[FS, String, Unit]("write", cb.emb.getFS, path) } } -class TableSpecHelper(path: String, rowRelPath: String, globalRelPath: String, refRelPath: String, typ: TableType, log: Boolean) extends Serializable { +class TableSpecHelper( + path: String, + rowRelPath: String, + globalRelPath: String, + refRelPath: String, + typ: TableType, + log: Boolean, +) extends Serializable { def write(fs: FS, partCounts: Array[Long], distinctlyKeyed: Boolean): Unit = { val spec = TableSpecParameters( FileFormat.version.rep, is.hail.HAIL_PRETTY_VERSION, refRelPath, typ, - Map("globals" -> RVDComponentSpec(globalRelPath), + Map( + "globals" -> RVDComponentSpec(globalRelPath), "rows" -> RVDComponentSpec(rowRelPath), "partition_counts" -> PartitionCountsComponentSpec(partCounts), "properties" -> PropertiesSpec(JObject( "distinctlyKeyed" -> JBool(distinctlyKeyed) - )) - )) + )), + ), + ) spec.write(fs, path) val nRows = partCounts.sum - if (log) info(s"wrote table with $nRows ${ plural(nRows, "row") } " + - s"in ${ partCounts.length } ${ plural(partCounts.length, "partition") } " + + if (log) info(s"wrote table with $nRows ${plural(nRows, "row")} " + + s"in ${partCounts.length} ${plural(partCounts.length, "partition")} " + s"to $path") } } -case class TableSpecWriter(path: String, typ: TableType, rowRelPath: String, globalRelPath: String, refRelPath: String, log: Boolean) extends MetadataWriter { - def annotationType: Type = TArray(TStruct("partitionCounts" -> TInt64, "distinctlyKeyed" -> TBoolean, "firstKey" -> typ.keyType, "lastKey" -> typ.keyType)) +case class TableSpecWriter( + path: String, + typ: TableType, + rowRelPath: String, + globalRelPath: String, + refRelPath: String, + log: Boolean, +) extends MetadataWriter { + def annotationType: Type = TArray(TStruct( + "partitionCounts" -> TInt64, + "distinctlyKeyed" -> TBoolean, + "firstKey" -> typ.keyType, + "lastKey" -> typ.keyType, + )) def writeMetadata( writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, - region: Value[Region]): Unit = { + region: Value[Region], + ): Unit = { cb += cb.emb.getFS.invoke[String, Unit]("mkDir", path) val hasKey = !this.typ.keyType.fields.isEmpty - val a = writeAnnotations.get(cb, "write annotations can't be missing!").asIndexable + val a = writeAnnotations.getOrFatal(cb, "write annotations can't be missing!").asIndexable val partCounts = cb.newLocal[Array[Long]]("partCounts") - val idxOfFirstKeyField = annotationType.asInstanceOf[TArray].elementType.asInstanceOf[TStruct].fieldIdx("firstKey") + val idxOfFirstKeyField = + annotationType.asInstanceOf[TArray].elementType.asInstanceOf[TStruct].fieldIdx("firstKey") val keySType = a.st.elementType.asInstanceOf[SBaseStruct].fieldTypes(idxOfFirstKeyField) val lastSeenSettable = cb.emb.newEmitLocal(EmitType(keySType, false)) @@ -374,39 +500,82 @@ case class TableSpecWriter(path: String, typ: TableType, rowRelPath: String, glo val n = cb.newLocal[Int]("n", a.loadLength()) val i = cb.newLocal[Int]("i", 0) cb.assign(partCounts, Code.newArray[Long](n)) - cb.while_(i < n, { - val curElement = a.loadElement(cb, i).get(cb, "writeMetadata annotation can't be missing").asBaseStruct - val count = curElement.asBaseStruct.loadField(cb, "partitionCounts").get(cb, "part count can't be missing!").asLong.value - - if (hasKey) 
{ - // Only nonempty partitions affect first, last, and distinctlyKeyed. - cb.if_(count cne 0L, { - val curFirst = curElement.loadField(cb, "firstKey").get(cb, const("firstKey of curElement can't be missing, part size was ") concat count.toS) - - val comparator = NEQ(lastSeenSettable.emitType.virtualType).codeOrdering(cb.emb.ecb, lastSeenSettable.st, curFirst.st) - val notEqualToLast = comparator(cb, lastSeenSettable, EmitValue.present(curFirst)).asInstanceOf[Value[Boolean]] - - val partWasDistinctlyKeyed = curElement.loadField(cb, "distinctlyKeyed").get(cb).asBoolean.value - cb.assign(distinctlyKeyed, distinctlyKeyed && partWasDistinctlyKeyed && notEqualToLast) - cb.assign(lastSeenSettable, curElement.loadField(cb, "lastKey")) - }) - } + cb.while_( + i < n, { + val curElement = + a.loadElement(cb, i).getOrFatal( + cb, + "writeMetadata annotation can't be missing", + ).asBaseStruct + val count = curElement.asBaseStruct.loadField(cb, "partitionCounts").getOrFatal( + cb, + "part count can't be missing!", + ).asLong.value + + if (hasKey) { + // Only nonempty partitions affect first, last, and distinctlyKeyed. + cb.if_( + count cne 0L, { + val curFirst = curElement.loadField(cb, "firstKey").getOrFatal( + cb, + const("firstKey of curElement can't be missing, part size was ") concat count.toS, + ) + + val comparator = NEQ(lastSeenSettable.emitType.virtualType).codeOrdering( + cb.emb.ecb, + lastSeenSettable.st, + curFirst.st, + ) + val notEqualToLast = comparator( + cb, + lastSeenSettable, + EmitValue.present(curFirst), + ).asInstanceOf[Value[Boolean]] + + val partWasDistinctlyKeyed = + curElement.loadField(cb, "distinctlyKeyed").getOrAssert(cb).asBoolean.value + cb.assign( + distinctlyKeyed, + distinctlyKeyed && partWasDistinctlyKeyed && notEqualToLast, + ) + cb.assign(lastSeenSettable, curElement.loadField(cb, "lastKey")) + }, + ) + } - cb += partCounts.update(i, count) - cb.assign(i, i + 1) - }) - cb += cb.emb.getObject(new TableSpecHelper(path, rowRelPath, globalRelPath, refRelPath, typ, log)) + cb += partCounts.update(i, count) + cb.assign(i, i + 1) + }, + ) + cb += cb.emb.getObject(new TableSpecHelper(path, rowRelPath, globalRelPath, refRelPath, typ, + log)) .invoke[FS, Array[Long], Boolean, Unit]("write", cb.emb.getFS, partCounts, distinctlyKeyed) } } object RelationalWriter { - def scoped(path: String, overwrite: Boolean, refs: Option[TableType])(write: IR): IR = WriteMetadata( - write, RelationalWriter(path, overwrite, refs.map(typ => "references" -> (ReferenceGenome.getReferences(typ.rowType) ++ ReferenceGenome.getReferences(typ.globalType))))) + def scoped(path: String, overwrite: Boolean, refs: Option[TableType])(write: IR): IR = + WriteMetadata( + write, + RelationalWriter( + path, + overwrite, + refs.map(typ => + "references" -> (ReferenceGenome.getReferences( + typ.rowType + ) ++ ReferenceGenome.getReferences(typ.globalType)) + ), + ), + ) } -case class RelationalSetup(path: String, overwrite: Boolean, refs: Option[TableType]) extends MetadataWriter { - lazy val maybeRefs = refs.map(typ => "references" -> (ReferenceGenome.getReferences(typ.rowType) ++ ReferenceGenome.getReferences(typ.globalType))) +case class RelationalSetup(path: String, overwrite: Boolean, refs: Option[TableType]) + extends MetadataWriter { + lazy val maybeRefs = refs.map(typ => + "references" -> (ReferenceGenome.getReferences(typ.rowType) ++ ReferenceGenome.getReferences( + typ.globalType + )) + ) def annotationType: Type = TVoid @@ -414,14 +583,23 @@ case class RelationalSetup(path: String, overwrite: Boolean, 
refs: Option[TableT if (overwrite) cb += cb.emb.getFS.invoke[String, Boolean, Unit]("delete", path, true) else - cb.if_(cb.emb.getFS.invoke[String, Boolean]("exists", path), cb._fatal(s"file already exists: $path")) + cb.if_( + cb.emb.getFS.invoke[String, Boolean]("exists", path), + cb._fatal(s"RelationalSetup.writeMetadata: file already exists: $path"), + ) cb += cb.emb.getFS.invoke[String, Unit]("mkDir", path) maybeRefs.foreach { case (refRelPath, refs) => val referencesFQPath = s"$path/$refRelPath" cb += cb.emb.getFS.invoke[String, Unit]("mkDir", referencesFQPath) refs.foreach { rg => - cb += Code.invokeScalaObject3[FS, String, ReferenceGenome, Unit](ReferenceGenome.getClass, "writeReference", cb.emb.getFS, referencesFQPath, cb.emb.getReferenceGenome(rg)) + cb += Code.invokeScalaObject3[FS, String, ReferenceGenome, Unit]( + ReferenceGenome.getClass, + "writeReference", + cb.emb.getFS, + referencesFQPath, + cb.emb.getReferenceGenome(rg), + ) } } } @@ -433,36 +611,65 @@ case class RelationalCommit(path: String) extends MetadataWriter { def writeMetadata( writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, - region: Value[Region]): Unit = { - cb += Code.invokeScalaObject2[FS, String, Unit](Class.forName("is.hail.utils.package$"), "writeNativeFileReadMe", cb.emb.getFS, path) + region: Value[Region], + ): Unit = { + cb += Code.invokeScalaObject2[FS, String, Unit]( + Class.forName("is.hail.utils.package$"), + "writeNativeFileReadMe", + cb.emb.getFS, + path, + ) cb += cb.emb.create(s"$path/_SUCCESS").invoke[Unit]("close") } } -case class RelationalWriter(path: String, overwrite: Boolean, maybeRefs: Option[(String, Set[String])]) extends MetadataWriter { +case class RelationalWriter( + path: String, + overwrite: Boolean, + maybeRefs: Option[(String, Set[String])], +) extends MetadataWriter { def annotationType: Type = TVoid def writeMetadata( writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, - region: Value[Region]): Unit = { + region: Value[Region], + ): Unit = { if (overwrite) cb += cb.emb.getFS.invoke[String, Boolean, Unit]("delete", path, true) else - cb.if_(cb.emb.getFS.invoke[String, Boolean]("exists", path), cb._fatal(s"file already exists: $path")) + cb.if_( + cb.emb.getFS.invoke[String, Boolean]("exists", path), + cb._fatal(s"RelationalWriter.writeMetadata: file already exists: $path"), + ) cb += cb.emb.getFS.invoke[String, Unit]("mkDir", path) maybeRefs.foreach { case (refRelPath, refs) => val referencesFQPath = s"$path/$refRelPath" cb += cb.emb.getFS.invoke[String, Unit]("mkDir", referencesFQPath) refs.foreach { rg => - cb += Code.invokeScalaObject3[FS, String, ReferenceGenome, Unit](ReferenceGenome.getClass, "writeReference", cb.emb.getFS, referencesFQPath, cb.emb.getReferenceGenome(rg)) + cb += Code.invokeScalaObject3[FS, String, ReferenceGenome, Unit]( + ReferenceGenome.getClass, + "writeReference", + cb.emb.getFS, + referencesFQPath, + cb.emb.getReferenceGenome(rg), + ) } } - writeAnnotations.consume(cb, {}, { pc => assert(pc == SVoidValue) }) // PVoidCode.code is Code._empty - - cb += Code.invokeScalaObject2[FS, String, Unit](Class.forName("is.hail.utils.package$"), "writeNativeFileReadMe", cb.emb.getFS, path) + writeAnnotations.consume( + cb, + {}, + pc => assert(pc == SVoidValue), + ) // PVoidCode.code is Code._empty + + cb += Code.invokeScalaObject2[FS, String, Unit]( + Class.forName("is.hail.utils.package$"), + "writeNativeFileReadMe", + cb.emb.getFS, + path, + ) cb += cb.emb.create(s"$path/_SUCCESS").invoke[Unit]("close") } } @@ -472,7 +679,7 @@ case class 
TableTextWriter( typesFile: String = null, header: Boolean = true, exportType: String = ExportType.CONCATENATED, - delimiter: String + delimiter: String, ) extends TableWriter { override def canLowerEfficiently: Boolean = exportType != ExportType.PARALLEL_COMPOSABLE @@ -489,20 +696,29 @@ case class TableTextWriter( ctx.createTmpPath("write-table-concatenated") else path - val lineWriter = TableTextPartitionWriter(ts.rowType, delimiter, writeHeader = exportType == ExportType.PARALLEL_HEADER_IN_SHARD) + val lineWriter = TableTextPartitionWriter( + ts.rowType, + delimiter, + writeHeader = exportType == ExportType.PARALLEL_HEADER_IN_SHARD, + ) ts.mapContexts { oldCtx => val d = digitsNeeded(ts.numPartitions) - val partFiles = Literal(TArray(TString), Array.tabulate(ts.numPartitions)(i => s"$folder/${ partFile(d, i) }-").toFastSeq) + val partFiles = Literal( + TArray(TString), + Array.tabulate(ts.numPartitions)(i => s"$folder/${partFile(d, i)}-").toFastSeq, + ) zip2(oldCtx, ToStream(partFiles), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => MakeStruct(FastSeq( "oldCtx" -> ctxElt, - "partFile" -> pf)) + "partFile" -> pf, + )) } - }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("table_text_writer") { (rows, ctxRef) => - val file = GetField(ctxRef, "partFile") + UUID4() + Str(ext) - WritePartition(rows, file, lineWriter) + }(GetField(_, "oldCtx")).mapCollectWithContextsAndGlobals("table_text_writer") { + (rows, ctxRef) => + val file = GetField(ctxRef, "partFile") + UUID4() + Str(ext) + WritePartition(rows, file, lineWriter) } { (parts, _) => val commit = TableTextFinalizer(path, ts.rowType, delimiter, header, exportType) Begin(FastSeq(WriteMetadata(parts, commit))) @@ -510,7 +726,8 @@ case class TableTextWriter( } } -case class TableTextPartitionWriter(rowType: TStruct, delimiter: String, writeHeader: Boolean) extends SimplePartitionWriter { +case class TableTextPartitionWriter(rowType: TStruct, delimiter: String, writeHeader: Boolean) + extends SimplePartitionWriter { lazy val headerStr = rowType.fields.map(_.name).mkString(delimiter) override def preConsume(cb: EmitCodeBuilder, os: Value[OutputStream]): Unit = if (writeHeader) { @@ -518,27 +735,48 @@ case class TableTextPartitionWriter(rowType: TStruct, delimiter: String, writeHe cb += os.invoke[Int, Unit]("write", '\n') } - def consumeElement(cb: EmitCodeBuilder, element: EmitCode, os: Value[OutputStream], region: Value[Region]): Unit = { + def consumeElement( + cb: EmitCodeBuilder, + element: EmitCode, + os: Value[OutputStream], + region: Value[Region], + ): Unit = { require(element.st.virtualType == rowType) val delimBytes: Value[Array[Byte]] = cb.memoize(cb.emb.getObject(delimiter.getBytes)) - element.toI(cb).consume(cb, { cb._fatal("stream element can not be missing!") }, { case sv: SBaseStructValue => - // I hope we're buffering our writes correctly! 
- (0 until sv.st.size).foreachBetween { i => - val f = sv.loadField(cb, i) - val annotation = f.consumeCode[AnyRef](cb, Code._null[AnyRef], - { sv => StringFunctions.svalueToJavaValue(cb, region, sv) }) - val str = Code.invokeScalaObject2[Any, Type, String](TableAnnotationImpex.getClass, "exportAnnotation", - annotation, cb.emb.getType(f.st.virtualType)) - cb += os.invoke[Array[Byte], Unit]("write", str.invoke[Array[Byte]]("getBytes")) - }(cb += os.invoke[Array[Byte], Unit]("write", delimBytes)) - cb += os.invoke[Int, Unit]("write", '\n') - }) + element.toI(cb).consume( + cb, + cb._fatal("stream element can not be missing!"), + { case sv: SBaseStructValue => + // I hope we're buffering our writes correctly! + (0 until sv.st.size).foreachBetween { i => + val f = sv.loadField(cb, i) + val annotation = f.consumeCode[AnyRef]( + cb, + Code._null[AnyRef], + sv => StringFunctions.svalueToJavaValue(cb, region, sv), + ) + val str = Code.invokeScalaObject2[Any, Type, String]( + TableAnnotationImpex.getClass, + "exportAnnotation", + annotation, + cb.emb.getType(f.st.virtualType), + ) + cb += os.invoke[Array[Byte], Unit]("write", str.invoke[Array[Byte]]("getBytes")) + }(cb += os.invoke[Array[Byte], Unit]("write", delimBytes)) + cb += os.invoke[Int, Unit]("write", '\n') + }, + ) } } object TableTextFinalizer { - def writeManifest(fs: FS, outputPath: String, files: Array[String], optionalAdditionalFirstPath: String): Unit = { + def writeManifest( + fs: FS, + outputPath: String, + files: Array[String], + optionalAdditionalFirstPath: String, + ): Unit = { def basename(f: String): String = (new java.io.File(f)).getName @@ -564,59 +802,113 @@ object TableTextFinalizer { } } -case class TableTextFinalizer(outputPath: String, rowType: TStruct, delimiter: String, - header: Boolean = true, exportType: String = ExportType.CONCATENATED) extends MetadataWriter { +case class TableTextFinalizer( + outputPath: String, + rowType: TStruct, + delimiter: String, + header: Boolean = true, + exportType: String = ExportType.CONCATENATED, +) extends MetadataWriter { def annotationType: Type = TArray(TString) - def writeMetadata(writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, region: Value[Region]): Unit = { + + def writeMetadata(writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, region: Value[Region]) + : Unit = { val ctx: ExecuteContext = cb.emb.ctx val ext = ctx.fs.getCodecExtension(outputPath) - val partPaths = writeAnnotations.get(cb, "write annotations cannot be missing!") - val files = partPaths.castTo(cb, region, SJavaArrayString(true), false).asInstanceOf[SJavaArrayStringValue].array + val partPaths = writeAnnotations.getOrFatal(cb, "write annotations cannot be missing!") + val files = partPaths.castTo(cb, region, SJavaArrayString(true), false).asInstanceOf[ + SJavaArrayStringValue + ].array exportType match { case ExportType.CONCATENATED => val jFiles = if (header) { val headerFilePath = ctx.createTmpPath("header", ext) val headerStr = rowType.fields.map(_.name).mkString(delimiter) val os = cb.memoize(cb.emb.create(const(headerFilePath))) - cb += os.invoke[Array[Byte], Unit]("write", const(headerStr).invoke[Array[Byte]]("getBytes")) + cb += os.invoke[Array[Byte], Unit]( + "write", + const(headerStr).invoke[Array[Byte]]("getBytes"), + ) cb += os.invoke[Int, Unit]("write", '\n') cb += os.invoke[Unit]("close") val allFiles = cb.memoize(Code.newArray[String](files.length + 1)) cb += (allFiles(0) = const(headerFilePath)) cb += Code.invokeStatic5[System, Any, Int, Any, Int, Int, Unit]( - "arraycopy", files /*src*/, 
0 /*srcPos*/, allFiles /*dest*/, 1 /*destPos*/, files.length /*len*/) + "arraycopy", + files /*src*/, + 0 /*srcPos*/, + allFiles /*dest*/, + 1 /*destPos*/, + files.length, /*len*/ + ) allFiles } else { files } - cb += cb.emb.getFS.invoke[Array[String], String, Unit]("concatenateFiles", jFiles, const(outputPath)) + cb += cb.emb.getFS.invoke[Array[String], String, Unit]( + "concatenateFiles", + jFiles, + const(outputPath), + ) val i = cb.newLocal[Int]("i") - cb.for_(cb.assign(i, 0), i < jFiles.length, cb.assign(i, i + 1), { - cb += cb.emb.getFS.invoke[String, Boolean, Unit]("delete", jFiles(i), const(false)) - }) + cb.for_( + cb.assign(i, 0), + i < jFiles.length, + cb.assign(i, i + 1), + cb += cb.emb.getFS.invoke[String, Boolean, Unit]("delete", jFiles(i), const(false)), + ) case ExportType.PARALLEL_HEADER_IN_SHARD => - cb += Code.invokeScalaObject3[FS, String, Array[String], Unit](TableTextFinalizer.getClass, "cleanup", cb.emb.getFS, outputPath, files) - cb += Code.invokeScalaObject4[FS, String, Array[String], String, Unit](TableTextFinalizer.getClass, "writeManifest", cb.emb.getFS, outputPath, files, Code._null[String]) + cb += Code.invokeScalaObject3[FS, String, Array[String], Unit]( + TableTextFinalizer.getClass, + "cleanup", + cb.emb.getFS, + outputPath, + files, + ) + cb += Code.invokeScalaObject4[FS, String, Array[String], String, Unit]( + TableTextFinalizer.getClass, + "writeManifest", + cb.emb.getFS, + outputPath, + files, + Code._null[String], + ) cb += cb.emb.getFS.invoke[String, Unit]("touch", const(outputPath).concat("/_SUCCESS")) case ExportType.PARALLEL_SEPARATE_HEADER => - cb += Code.invokeScalaObject3[FS, String, Array[String], Unit](TableTextFinalizer.getClass, "cleanup", cb.emb.getFS, outputPath, files) + cb += Code.invokeScalaObject3[FS, String, Array[String], Unit]( + TableTextFinalizer.getClass, + "cleanup", + cb.emb.getFS, + outputPath, + files, + ) val headerPath = if (header) { val headerFilePath = const(s"$outputPath/header$ext") val headerStr = rowType.fields.map(_.name).mkString(delimiter) val os = cb.memoize(cb.emb.create(headerFilePath)) - cb += os.invoke[Array[Byte], Unit]("write", const(headerStr).invoke[Array[Byte]]("getBytes")) + cb += os.invoke[Array[Byte], Unit]( + "write", + const(headerStr).invoke[Array[Byte]]("getBytes"), + ) cb += os.invoke[Int, Unit]("write", '\n') cb += os.invoke[Unit]("close") headerFilePath } else Code._null[String] - cb += Code.invokeScalaObject4[FS, String, Array[String], String, Unit](TableTextFinalizer.getClass, "writeManifest", cb.emb.getFS, outputPath, files, headerPath) + cb += Code.invokeScalaObject4[FS, String, Array[String], String, Unit]( + TableTextFinalizer.getClass, + "writeManifest", + cb.emb.getFS, + outputPath, + files, + headerPath, + ) cb += cb.emb.getFS.invoke[String, Unit]("touch", const(outputPath).concat("/_SUCCESS")) } } @@ -629,7 +921,7 @@ class FanoutWriterTarget( val keyPType: PStruct, val tableType: TableType, val rowWriter: PartitionNativeWriter, - val globalWriter: PartitionNativeWriter + val globalWriter: PartitionNativeWriter, ) case class TableNativeFanoutWriter( @@ -637,13 +929,17 @@ case class TableNativeFanoutWriter( val fields: IndexedSeq[String], overwrite: Boolean = true, stageLocally: Boolean = false, - codecSpecJSONStr: String = null + codecSpecJSONStr: String = null, ) extends TableWriter { override def lower(ctx: ExecuteContext, ts: TableStage, r: RTable): IR = { val partitioner = ts.partitioner val bufferSpec = BufferSpec.parseOrDefault(codecSpecJSONStr) - val globalSpec = 
TypedCodecSpec(EType.fromTypeAndAnalysis(ts.globalType, r.globalType), ts.globalType, bufferSpec) + val globalSpec = TypedCodecSpec( + EType.fromTypeAndAnalysis(ts.globalType, r.globalType), + ts.globalType, + bufferSpec, + ) val targets = { val rowType = ts.rowType val rowRType = r.rowType @@ -655,7 +951,11 @@ case class TableNativeFanoutWriter( val fieldAndKey = (field +: keyFields) val targetRowType = rowType.typeAfterSelectNames(fieldAndKey) val targetRowRType = rowRType.select(fieldAndKey) - val rowSpec = TypedCodecSpec(EType.fromTypeAndAnalysis(targetRowType, targetRowRType), targetRowType, bufferSpec) + val rowSpec = TypedCodecSpec( + EType.fromTypeAndAnalysis(targetRowType, targetRowRType), + targetRowType, + bufferSpec, + ) val keyPType = tcoerce[PStruct](rowSpec.decodedPType(keyType)) val tableType = TableType(targetRowType, keyFields, ts.globalType) val rowWriter = PartitionNativeWriter( @@ -663,22 +963,33 @@ case class TableNativeFanoutWriter( keyFields, s"$targetPath/rows/parts/", Some(s"$targetPath/index/" -> keyPType), - if (stageLocally) Some(FileSystems.getDefault.getPath(ctx.localTmpdir, s"hail_staging_tmp_${UUID.randomUUID()}", "rows", "parts")) else None + if (stageLocally) Some(FileSystems.getDefault.getPath( + ctx.localTmpdir, + s"hail_staging_tmp_${UUID.randomUUID()}", + "rows", + "parts", + )) + else None, ) - val globalWriter = PartitionNativeWriter(globalSpec, IndexedSeq(), s"$targetPath/globals/parts/", None, None) - new FanoutWriterTarget(field, targetPath, rowSpec, keyPType, tableType, rowWriter, globalWriter) + val globalWriter = + PartitionNativeWriter(globalSpec, IndexedSeq(), s"$targetPath/globals/parts/", None, None) + new FanoutWriterTarget(field, targetPath, rowSpec, keyPType, tableType, rowWriter, + globalWriter) }.toFastSeq } val writeTables = ts.mapContexts { oldCtx => val d = digitsNeeded(ts.numPartitions) - val partFiles = Literal(TArray(TString), Array.tabulate(ts.numPartitions)(i => s"${ partFile(d, i) }-").toFastSeq) + val partFiles = Literal( + TArray(TString), + Array.tabulate(ts.numPartitions)(i => s"${partFile(d, i)}-").toFastSeq, + ) zip2(oldCtx, ToStream(partFiles), ArrayZipBehavior.AssertSameLength) { (ctxElt, pf) => MakeStruct(FastSeq( "oldCtx" -> ctxElt, - "writeCtx" -> pf) - ) + "writeCtx" -> pf, + )) } }( GetField(_, "oldCtx") @@ -695,33 +1006,45 @@ case class TableNativeFanoutWriter( WritePartition( MakeStream(FastSeq(globals), TStream(globals.typ)), Str(partFile(1, 0)), - target.globalWriter + target.globalWriter, ), - "filePath" + "filePath", ) ), - RVDSpecWriter(s"${target.path}/globals", RVDSpecMaker(globalSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1))) + RVDSpecWriter( + s"${target.path}/globals", + RVDSpecMaker(globalSpec, RVDPartitioner.unkeyed(ctx.stateManager, 1)), + ), ), WriteMetadata( - ToArray(mapIR(ToStream(fileCountAndDistinct)) { fc => GetField(GetTupleElement(fc, index), "filePath") }), + ToArray(mapIR(ToStream(fileCountAndDistinct)) { fc => + GetField(GetTupleElement(fc, index), "filePath") + }), RVDSpecWriter( s"${target.path}/rows", RVDSpecMaker( target.rowSpec, partitioner, - IndexSpec.emptyAnnotation("../index", tcoerce[PStruct](target.keyPType)) - ) - ) + IndexSpec.emptyAnnotation("../index", tcoerce[PStruct](target.keyPType)), + ), + ), ), WriteMetadata( ToArray(mapIR(ToStream(fileCountAndDistinct)) { fc => SelectFields( GetTupleElement(fc, index), - FastSeq("partitionCounts", "distinctlyKeyed", "firstKey", "lastKey") + FastSeq("partitionCounts", "distinctlyKeyed", "firstKey", "lastKey"), ) }), - 
TableSpecWriter(target.path, target.tableType, "rows", "globals", "references", log = true) - ) + TableSpecWriter( + target.path, + target.tableType, + "rows", + "globals", + "references", + log = true, + ), + ), )) }.toFastSeq) } @@ -729,7 +1052,9 @@ case class TableNativeFanoutWriter( targets.foldLeft(writeTables) { (rest: IR, target: FanoutWriterTarget) => RelationalWriter.scoped( - target.path, overwrite, Some(target.tableType) + target.path, + overwrite, + Some(target.tableType), )( rest ) @@ -745,22 +1070,21 @@ class PartitionNativeFanoutWriter( cb: EmitCodeBuilder, stream: StreamProducer, context: EmitCode, - region: Value[Region] + region: Value[Region], ): IEmitCode = { - val ctx = context.toI(cb).get(cb) - val consumers = targets.map { target => - new target.rowWriter.StreamConsumer(ctx, cb, region) - } + val ctx = context.toI(cb).getOrAssert(cb) + val consumers = targets.map(target => new target.rowWriter.StreamConsumer(ctx, cb, region)) consumers.foreach(_.setup()) stream.memoryManagedConsume(region, cb) { cb => - val row = stream.element.toI(cb).get(cb, "row can't be missing") + val row = stream.element.toI(cb).getOrFatal(cb, "row can't be missing") (consumers zip targets).foreach { case (consumer, target) => consumer.consumeElement( cb, - row.asBaseStruct.subset((target.keyPType.fieldNames :+ target.field):_*), - stream.elementRegion) + row.asBaseStruct.subset((target.keyPType.fieldNames :+ target.field): _*), + stream.elementRegion, + ) } } IEmitCode.present( @@ -770,16 +1094,20 @@ class PartitionNativeFanoutWriter( region, returnType, consumers.map(consumer => EmitCode.present(cb.emb, consumer.result())): _* - ) + ), ) } def ctxType = TString def returnType: TTuple = - TTuple(targets.map(target => target.rowWriter.returnType):_*) + TTuple(targets.map(target => target.rowWriter.returnType): _*) - def unionTypeRequiredness(returnType: TypeWithRequiredness, ctxType: TypeWithRequiredness, streamType: RIterable): Unit = { + def unionTypeRequiredness( + returnType: TypeWithRequiredness, + ctxType: TypeWithRequiredness, + streamType: RIterable, + ): Unit = { val targetReturnTypes = returnType.asInstanceOf[RTuple].fields.map(_.typ) ((targetReturnTypes) zip targets).foreach { case (returnType, target) => @@ -796,13 +1124,18 @@ object WrappedMatrixNativeMultiWriter { case class WrappedMatrixNativeMultiWriter( writer: MatrixNativeMultiWriter, - colKey: IndexedSeq[String] + colKey: IndexedSeq[String], ) { def lower(ctx: ExecuteContext, ts: IndexedSeq[(TableStage, RTable)]): IR = - writer.lower(ctx, ts.map { case (ts, rt) => - (LowerMatrixIR.colsFieldName, LowerMatrixIR.entriesFieldName, colKey, ts, rt) - }) + writer.lower( + ctx, + ts.map { case (ts, rt) => + (LowerMatrixIR.colsFieldName, LowerMatrixIR.entriesFieldName, colKey, ts, rt) + }, + ) def apply(ctx: ExecuteContext, mvs: IndexedSeq[TableValue]): Unit = writer.apply( - ctx, mvs.map(_.toMatrixValue(colKey))) + ctx, + mvs.map(_.toMatrixValue(colKey)), + ) } diff --git a/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala b/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala index 6afb42d8fb2..3eccc1bc404 100644 --- a/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala +++ b/hail/src/main/scala/is/hail/expr/ir/TypeCheck.scala @@ -4,44 +4,43 @@ import is.hail.backend.ExecuteContext import is.hail.expr.Nat import is.hail.types.tcoerce import is.hail.types.virtual._ -import is.hail.utils.StackSafe._ import is.hail.utils._ +import is.hail.utils.StackSafe._ import scala.reflect.ClassTag object TypeCheck { - def apply(ctx: 
ExecuteContext, ir: BaseIR): Unit = { - try { + def apply(ctx: ExecuteContext, ir: BaseIR): Unit = + try check(ctx, ir, BindingEnv.empty).run() - } catch { - case e: Throwable => fatal(s"Error while typechecking IR:\n${ Pretty(ctx, ir) }", e) + catch { + case e: Throwable => + fatal(s"Error while typechecking IR:\n${Pretty(ctx, ir, preserveNames = true)}", e) } - } - def apply(ctx: ExecuteContext, ir: IR, env: BindingEnv[Type]): Unit = { - try { + def apply(ctx: ExecuteContext, ir: IR, env: BindingEnv[Type]): Unit = + try check(ctx, ir, env).run() - } catch { - case e: Throwable => fatal(s"Error while typechecking IR:\n${ Pretty(ctx, ir) }", e) + catch { + case e: Throwable => fatal(s"Error while typechecking IR:\n${Pretty(ctx, ir)}", e) } - } def check(ctx: ExecuteContext, ir: BaseIR, env: BindingEnv[Type]): StackFrame[Unit] = { for { _ <- ir.forEachChildWithEnvStackSafe(env) { (child, i, childEnv) => for { _ <- call(check(ctx, child, childEnv)) - } yield { + } yield if (child.typ == TVoid) { checkVoidTypedChild(ctx, ir, i, env) } else () - } } } yield checkSingleNode(ctx, ir, env) } - private def checkVoidTypedChild(ctx: ExecuteContext, ir: BaseIR, i: Int, env: BindingEnv[Type]): Unit = ir match { - case l: Let if i == l.bindings.length => + private def checkVoidTypedChild(ctx: ExecuteContext, ir: BaseIR, i: Int, env: BindingEnv[Type]) + : Unit = ir match { + case l: Block if i == l.bindings.length || l.body.typ == TVoid => case _: StreamFor if i == 1 => case _: RunAggScan if (i == 1 || i == 2) => case _: StreamBufferedAggregate if (i == 1 || i == 3) => @@ -50,41 +49,42 @@ object TypeCheck { case _: InitOp => // let initop checking below catch bad void arguments case _: If if i != 0 => case _: RelationalLet if i == 1 => - case _: Begin => case _: WriteMetadata => case _ => - throw new RuntimeException(s"unexpected void-typed IR at child $i of ${ ir.getClass.getSimpleName }" + - s"\n IR: ${ Pretty(ctx, ir) }") + throw new RuntimeException( + s"unexpected void-typed IR at child $i of ${ir.getClass.getSimpleName}" + + s"\n IR: ${Pretty(ctx, ir)}" + ) } private def checkSingleNode(ctx: ExecuteContext, ir: BaseIR, env: BindingEnv[Type]): Unit = { ir match { - case I32(x) => - case I64(x) => - case F32(x) => - case F64(x) => + case I32(_) => + case I64(_) => + case F32(_) => + case F64(_) => case True() => case False() => - case Str(x) => + case Str(_) => case UUID4(_) => case Literal(_, _) => case EncodedLiteral(_, _) => case Void() => case Cast(v, typ) => if (!Casts.valid(v.typ, typ)) - throw new RuntimeException(s"invalid cast:\n " + - s"child type: ${ v.typ.parsableString() }\n " + - s"cast type: ${ typ.parsableString() }") + throw new RuntimeException(s"invalid cast:\n " + + s"child type: ${v.typ.parsableString()}\n " + + s"cast type: ${typ.parsableString()}") case CastRename(v, typ) => - if (!v.typ.canCastTo(typ)) + if (!v.typ.isIsomorphicTo(typ)) throw new RuntimeException(s"invalid cast:\n " + - s"child type: ${ v.typ.parsableString() }\n " + - s"cast type: ${ typ.parsableString() }") + s"child type: ${v.typ.parsableString()}\n " + + s"cast type: ${typ.parsableString()}") case NA(t) => assert(t != null) - case IsNA(v) => + case IsNA(_) => case Coalesce(values) => assert(values.tail.forall(_.typ == values.head.typ)) - case x@If(cond, cnsq, altr) => + case x @ If(cond, cnsq, altr) => assert(cond.typ == TBoolean) assert(x.typ == cnsq.typ && x.typ == altr.typ) x.typ match { @@ -94,28 +94,30 @@ object TypeCheck { case Switch(x, default, cases) => assert(x.typ == TInt32) 
assert(cases.forall(_.typ == default.typ)) - case x@Let(_, body) => + case x @ Block(_, body) => assert(x.typ == body.typ) - case x@AggLet(_, _, body, _) => - assert(x.typ == body.typ) - case x@Ref(name, _) => + case x @ Ref(name, _) => env.eval.lookupOption(name) match { case Some(expected) => - assert(x.typ == expected, - s"type mismatch:\n name: $name\n actual: ${x.typ.parsableString()}\n expect: ${expected.parsableString()}") + assert( + x.typ == expected, + s"type mismatch:\n name: $name\n actual: ${x.typ.parsableString()}\n expect: ${expected.parsableString()}", + ) case None => - throw new NoSuchElementException(s"Ref with name ${name} could not be resolved in env ${env}") + throw new NoSuchElementException( + s"Ref with name $name could not be resolved in env $env" + ) } case RelationalRef(name, t) => env.relational.lookupOption(name) match { case Some(t2) => if (t != t2) - throw new RuntimeException(s"RelationalRef type mismatch:\n node=${t}\n env=${t2}") + throw new RuntimeException(s"RelationalRef type mismatch:\n node=$t\n env=$t2") case None => throw new RuntimeException(s"RelationalRef not found in env: $name") } - case x@TailLoop(name, _, rt, body) => + case x @ TailLoop(name, _, rt, body) => assert(x.typ == rt) assert(body.typ == rt) def recurInTail(node: IR, tailPosition: Boolean): Boolean = node match { @@ -124,20 +126,20 @@ object TypeCheck { case _ => node.children.zipWithIndex .forall { - case (c: IR, i) => recurInTail(c, tailPosition && InTailPosition(node, i)) + case (c: IR, i) => recurInTail(c, tailPosition && InTailPosition(node, i)) case _ => true - } + } } assert(recurInTail(body, tailPosition = true)) - case x@Recur(name, args, typ) => + case Recur(name, args, typ) => val TTuple(IndexedSeq(TupleField(_, argTypes), TupleField(_, rt))) = env.eval.lookup(name) - assert(argTypes.asInstanceOf[TTuple].types.zip(args).forall { case (t, ir) => t == ir.typ } ) + assert(argTypes.asInstanceOf[TTuple].types.zip(args).forall { case (t, ir) => t == ir.typ }) assert(typ == rt) - case x@ApplyBinaryPrimOp(op, l, r) => + case x @ ApplyBinaryPrimOp(op, l, r) => assert(x.typ == BinaryOp.getReturnType(op, l.typ, r.typ)) - case x@ApplyUnaryPrimOp(op, v) => + case x @ ApplyUnaryPrimOp(op, v) => assert(x.typ == UnaryOp.getReturnType(op, v.typ)) - case x@ApplyComparisonOp(op, l, r) => + case x @ ApplyComparisonOp(op, l, r) => assert(op.t1 == l.typ) assert(op.t2 == r.typ) ComparisonOp.checkCompatible(op.t1, op.t2) @@ -145,23 +147,28 @@ object TypeCheck { case _: Compare => assert(x.typ == TInt32) case _ => assert(x.typ == TBoolean) } - case x@MakeArray(args, typ) => + case MakeArray(args, typ) => assert(typ != null) args.map(_.typ).zipWithIndex.foreach { case (x, i) => - assert(x == typ.elementType && x.isRealizable, - s"at position $i type mismatch: ${ typ.parsableString() } ${ x.parsableString() }") + assert( + x == typ.elementType && x.isRealizable, + s"at position $i type mismatch: ${typ.parsableString()} ${x.parsableString()}", + ) } - case x@MakeStream(args, typ, _) => + case MakeStream(args, typ, _) => assert(typ != null) assert(typ.elementType.isRealizable) - args.map(_.typ).zipWithIndex.foreach { case (x, i) => assert(x == typ.elementType, - s"at position $i type mismatch: ${ typ.elementType.parsableString() } ${ x.parsableString() }") + args.map(_.typ).zipWithIndex.foreach { case (x, i) => + assert( + x == typ.elementType, + s"at position $i type mismatch: ${typ.elementType.parsableString()} ${x.parsableString()}", + ) } - case x@ArrayRef(a, i, _) => + case x @ ArrayRef(a, i, 
_) => assert(i.typ == TInt32) assert(x.typ == tcoerce[TArray](a.typ).elementType) - case x@ArraySlice(a, start, stop, step, _) => + case x @ ArraySlice(a, start, stop, step, _) => assert(start.typ == TInt32) stop.foreach(ir => assert(ir.typ == TInt32)) assert(step.typ == TInt32) @@ -178,7 +185,7 @@ object TypeCheck { case StreamIota(start, step, _) => assert(start.typ == TInt32) assert(step.typ == TInt32) - case x@StreamRange(a, b, c, _, _) => + case StreamRange(a, b, c, _, _) => assert(a.typ == TInt32) assert(b.typ == TInt32) assert(c.typ == TInt32) @@ -191,7 +198,8 @@ object TypeCheck { assert(child.typ.isInstanceOf[TStream]) assert(pivots.typ.isInstanceOf[TArray]) assert(pivots.typ.asInstanceOf[TArray].elementType.isInstanceOf[TStruct]) - case StreamWhiten(stream, newChunk, prevWindow, vecSize, windowSize, chunkSize, blockSize, normalizeAfterWhiten) => + case StreamWhiten(stream, newChunk, prevWindow, _, windowSize, chunkSize, _, + _) => assert(stream.typ.isInstanceOf[TStream]) val eltTyp = stream.typ.asInstanceOf[TStream].elementType assert(eltTyp.isInstanceOf[TStruct]) @@ -200,25 +208,25 @@ object TypeCheck { assert(structTyp.field(newChunk).typ == matTyp) assert(structTyp.field(prevWindow).typ == matTyp) assert(windowSize % chunkSize == 0) - case x@ArrayZeros(length) => + case ArrayZeros(length) => assert(length.typ == TInt32) - case x@MakeNDArray(data, shape, rowMajor, _) => + case MakeNDArray(data, shape, rowMajor, _) => assert(data.typ.isInstanceOf[TArray] || data.typ.isInstanceOf[TStream]) assert(shape.typ.asInstanceOf[TTuple].types.forall(t => t == TInt64)) assert(rowMajor.typ == TBoolean) - case x@NDArrayShape(nd) => + case NDArrayShape(nd) => assert(nd.typ.isInstanceOf[TNDArray]) - case x@NDArrayReshape(nd, shape, _) => + case NDArrayReshape(nd, shape, _) => assert(nd.typ.isInstanceOf[TNDArray]) assert(shape.typ.asInstanceOf[TTuple].types.forall(t => t == TInt64)) - case x@NDArrayConcat(nds, axis) => + case x @ NDArrayConcat(nds, axis) => assert(tcoerce[TArray](nds.typ).elementType.isInstanceOf[TNDArray]) assert(axis < x.typ.nDims) - case x@NDArrayRef(nd, idxs, _) => + case NDArrayRef(nd, idxs, _) => assert(nd.typ.isInstanceOf[TNDArray]) assert(nd.typ.asInstanceOf[TNDArray].nDims == idxs.length) assert(idxs.forall(_.typ == TInt64)) - case x@NDArraySlice(nd, slices) => + case NDArraySlice(nd, slices) => assert(nd.typ.isInstanceOf[TNDArray]) val childTyp = nd.typ.asInstanceOf[TNDArray] val slicesTuple = slices.typ.asInstanceOf[TTuple] @@ -230,30 +238,30 @@ object TypeCheck { val ndtyp = tcoerce[TNDArray](nd.typ) assert(ndtyp.nDims == filters.length) assert(filters.forall(f => tcoerce[TArray](f.typ).elementType == TInt64)) - case x@NDArrayMap(_, _, body) => + case x @ NDArrayMap(_, _, body) => assert(x.elementTyp == body.typ) - case x@NDArrayMap2(l, r, _, _, body, _) => + case x @ NDArrayMap2(l, r, _, _, body, _) => val lTyp = tcoerce[TNDArray](l.typ) val rTyp = tcoerce[TNDArray](r.typ) assert(lTyp.nDims == rTyp.nDims) assert(x.elementTyp == body.typ) - case x@NDArrayReindex(nd, indexExpr) => + case NDArrayReindex(nd, indexExpr) => assert(nd.typ.isInstanceOf[TNDArray]) val nInputDims = tcoerce[TNDArray](nd.typ).nDims val nOutputDims = indexExpr.length assert(nInputDims <= nOutputDims) assert(indexExpr.forall(i => i < nOutputDims)) assert((0 until nOutputDims).forall(i => indexExpr.contains(i))) - case x@NDArrayAgg(nd, axes) => + case NDArrayAgg(nd, axes) => assert(nd.typ.isInstanceOf[TNDArray]) val nInputDims = tcoerce[TNDArray](nd.typ).nDims assert(axes.length <= nInputDims) 
assert(axes.forall(i => i < nInputDims)) assert(axes.distinct.length == axes.length) - case x@NDArrayWrite(nd, path) => + case NDArrayWrite(nd, path) => assert(nd.typ.isInstanceOf[TNDArray]) assert(path.typ == TString) - case x@NDArrayMatMul(l, r, _) => + case NDArrayMatMul(l, r, _) => assert(l.typ.isInstanceOf[TNDArray]) assert(r.typ.isInstanceOf[TNDArray]) val lType = l.typ.asInstanceOf[TNDArray] @@ -262,48 +270,49 @@ object TypeCheck { assert(lType.nDims > 0) assert(rType.nDims > 0) assert(lType.nDims == 1 || rType.nDims == 1 || lType.nDims == rType.nDims) - case x@NDArrayQR(nd, mode, _) => + case NDArrayQR(nd, _, _) => val ndType = nd.typ.asInstanceOf[TNDArray] assert(ndType.elementType == TFloat64) assert(ndType.nDims == 2) - case x@NDArraySVD(nd, _, _, _) => + case NDArraySVD(nd, _, _, _) => val ndType = nd.typ.asInstanceOf[TNDArray] assert(ndType.elementType == TFloat64) assert(ndType.nDims == 2) - case x@NDArrayEigh(nd, _, _) => + case NDArrayEigh(nd, _, _) => val ndType = nd.typ.asInstanceOf[TNDArray] assert(ndType.elementType == TFloat64) assert(ndType.nDims == 2) - case x@NDArrayInv(nd, _) => + case NDArrayInv(nd, _) => val ndType = nd.typ.asInstanceOf[TNDArray] assert(ndType.elementType == TFloat64) assert(ndType.nDims == 2) - case x@ArraySort(a, l, r, lessThan) => + case ArraySort(a, _, _, lessThan) => assert(a.typ.isInstanceOf[TStream]) assert(lessThan.typ == TBoolean) - case x@ToSet(a) => + case ToSet(a) => assert(a.typ.isInstanceOf[TStream], a.typ) - case x@ToDict(a) => + case ToDict(a) => assert(a.typ.isInstanceOf[TStream]) assert(tcoerce[TBaseStruct](tcoerce[TStream](a.typ).elementType).size == 2) - case x@ToArray(a) => + case ToArray(a) => assert(a.typ.isInstanceOf[TStream]) - case x@CastToArray(a) => + case CastToArray(a) => assert(a.typ.isInstanceOf[TContainer]) - case x@ToStream(a, _) => + case ToStream(a, _) => assert(a.typ.isInstanceOf[TContainer]) - case x@LowerBoundOnOrderedCollection(orderedCollection, elem, onKey) => + case LowerBoundOnOrderedCollection(orderedCollection, elem, onKey) => val elt = tcoerce[TIterable](orderedCollection.typ).elementType assert(elem.typ == (if (onKey) elt match { - case t: TBaseStruct => t.types(0) - case t: TInterval => t.pointType - } else elt)) - case x@GroupByKey(collection) => + case t: TBaseStruct => t.types(0) + case t: TInterval => t.pointType + } + else elt)) + case x @ GroupByKey(collection) => val telt = tcoerce[TBaseStruct](tcoerce[TStream](collection.typ).elementType) val td = tcoerce[TDict](x.typ) assert(td.keyType == telt.types(0)) assert(td.valueType == TArray(telt.types(1))) - case x@RNGStateLiteral() => + case x @ RNGStateLiteral() => assert(x.typ == TRNGState) case RNGSplit(state, dynBitstring) => assert(state.typ == TRNGState) @@ -314,79 +323,80 @@ object TypeCheck { assert(isValid(dynBitstring.typ)) case StreamLen(a) => assert(a.typ.isInstanceOf[TStream]) - case x@StreamTake(a, num) => + case x @ StreamTake(a, num) => assert(a.typ.isInstanceOf[TStream]) assert(x.typ == a.typ) assert(num.typ == TInt32) - case x@StreamDrop(a, num) => + case x @ StreamDrop(a, num) => assert(a.typ.isInstanceOf[TStream]) assert(x.typ == a.typ) assert(num.typ == TInt32) - case x@StreamGrouped(a, size) => + case x @ StreamGrouped(a, size) => val ts = tcoerce[TStream](x.typ) assert(a.typ.isInstanceOf[TStream]) assert(ts.elementType == a.typ) assert(size.typ == TInt32) - case x@StreamGroupByKey(a, key, _) => + case x @ StreamGroupByKey(a, key, _) => val ts = tcoerce[TStream](x.typ) assert(ts.elementType == a.typ) val structType = 
tcoerce[TStruct](tcoerce[TStream](a.typ).elementType) assert(key.forall(structType.hasField)) - case x@StreamMap(a, name, body) => + case x @ StreamMap(a, _, body) => assert(a.typ.isInstanceOf[TStream]) assert(x.elementTyp == body.typ) - case x@StreamZip(as, names, body, _, _) => + case x @ StreamZip(as, names, body, _, _) => assert(as.length == names.length) assert(x.typ.elementType == body.typ) assert(as.forall(_.typ.isInstanceOf[TStream])) - case x@StreamZipJoin(as, key, curKey, curVals, joinF) => + case x @ StreamZipJoin(as, key, _, _, joinF) => val streamType = tcoerce[TStream](as.head.typ) assert(as.forall(_.typ == streamType)) val eltType = tcoerce[TStruct](streamType.elementType) assert(key.forall(eltType.hasField)) assert(x.typ.elementType == joinF.typ) - case x@StreamZipJoinProducers(contexts, ctxName, makeProducer, key, curKey, curVals, joinF) => + case x @ StreamZipJoinProducers(contexts, _, makeProducer, key, _, _, + joinF) => assert(contexts.typ.isInstanceOf[TArray]) val streamType = tcoerce[TStream](makeProducer.typ) val eltType = tcoerce[TStruct](streamType.elementType) assert(key.forall(eltType.hasField)) assert(x.typ.elementType == joinF.typ) - case x@StreamMultiMerge(as, key) => + case x @ StreamMultiMerge(as, key) => val streamType = tcoerce[TStream](as.head.typ) assert(as.forall(_.typ == streamType)) val eltType = tcoerce[TStruct](streamType.elementType) assert(x.typ.elementType == eltType) assert(key.forall(eltType.hasField)) - case x@StreamFilter(a, name, cond) => + case x @ StreamFilter(a, _, cond) => assert(a.typ.asInstanceOf[TStream].elementType.isRealizable) assert(cond.typ == TBoolean, cond.typ) assert(x.typ == a.typ) - case x@StreamTakeWhile(a, name, cond) => + case x @ StreamTakeWhile(a, _, cond) => assert(a.typ.asInstanceOf[TStream].elementType.isRealizable) assert(cond.typ == TBoolean) assert(x.typ == a.typ) - case x@StreamDropWhile(a, name, cond) => + case x @ StreamDropWhile(a, _, cond) => assert(a.typ.asInstanceOf[TStream].elementType.isRealizable) assert(cond.typ == TBoolean) assert(x.typ == a.typ) - case x@StreamFlatMap(a, name, body) => + case StreamFlatMap(a, _, body) => assert(a.typ.isInstanceOf[TStream]) assert(body.typ.isInstanceOf[TStream]) - case x@StreamFold(a, zero, accumName, valueName, body) => + case x @ StreamFold(a, zero, _, _, body) => assert(a.typ.isInstanceOf[TStream]) assert(a.typ.asInstanceOf[TStream].elementType.isRealizable, Pretty(ctx, x)) assert(body.typ == zero.typ) assert(x.typ == zero.typ) - case x@StreamFold2(a, accum, valueName, seq, res) => + case x @ StreamFold2(a, accum, _, seq, res) => assert(a.typ.isInstanceOf[TStream]) assert(x.typ == res.typ) assert(accum.zip(seq).forall { case ((_, z), s) => s.typ == z.typ }) - case x@StreamScan(a, zero, accumName, valueName, body) => + case x @ StreamScan(a, zero, _, _, body) => assert(a.typ.isInstanceOf[TStream]) assert(body.typ == zero.typ) assert(tcoerce[TStream](x.typ).elementType == zero.typ) assert(zero.typ.isRealizable) - case x@StreamJoinRightDistinct(left, right, lKey, rKey, l, r, join, joinType) => + case x @ StreamJoinRightDistinct(left, right, lKey, rKey, _, _, join, joinType) => val lEltTyp = tcoerce[TStruct](tcoerce[TStream](left.typ).elementType) val rEltTyp = tcoerce[TStruct](tcoerce[TStream](right.typ).elementType) assert(tcoerce[TStream](x.typ).elementType == join.typ) @@ -402,21 +412,37 @@ object TypeCheck { lEltTyp.fieldType(lk) == rEltTyp.fieldType(rk) }) } - case x@StreamFor(a, valueName, body) => + case StreamLeftIntervalJoin(left, right, lKeyFieldName, 
rIntrvlName, _, _, body) => + assert(left.typ.isInstanceOf[TStream]) + assert(right.typ.isInstanceOf[TStream]) + + val lEltTy = + TIterable.elementType(left.typ).asInstanceOf[TStruct] + + val rPointTy = + TIterable.elementType(right.typ) + .asInstanceOf[TStruct] + .fieldType(rIntrvlName) + .asInstanceOf[TInterval] + .pointType + + assert(lEltTy.fieldType(lKeyFieldName) == rPointTy) + assert(body.typ.isInstanceOf[TStruct]) + case StreamFor(a, _, body) => assert(a.typ.isInstanceOf[TStream]) assert(body.typ == TVoid) - case x@StreamAgg(a, name, query) => + case StreamAgg(a, _, _) => assert(a.typ.isInstanceOf[TStream]) - case x@StreamAggScan(a, name, query) => + case x @ StreamAggScan(a, _, query) => assert(a.typ.isInstanceOf[TStream]) assert(x.typ.asInstanceOf[TStream].elementType == query.typ) - case x@StreamBufferedAggregate(streamChild, initAggs, newKey, seqOps, _, _, _) => + case x @ StreamBufferedAggregate(streamChild, initAggs, newKey, seqOps, _, _, _) => assert(streamChild.typ.isInstanceOf[TStream]) assert(initAggs.typ == TVoid) assert(seqOps.typ == TVoid) assert(newKey.typ.isInstanceOf[TStruct]) assert(x.typ.isInstanceOf[TStream]) - case x@StreamLocalLDPrune(streamChild, r2Threshold, windowSize, maxQueueSize, nSamples) => + case x @ StreamLocalLDPrune(streamChild, r2Threshold, windowSize, maxQueueSize, nSamples) => assert(streamChild.typ.isInstanceOf[TStream]) assert(r2Threshold.typ == TFloat64) assert(windowSize.typ == TInt32) @@ -433,61 +459,60 @@ object TypeCheck { assert(gtType.isInstanceOf[TArray]) assert(gtType.asInstanceOf[TArray].elementType == TCall) assert(x.typ.isInstanceOf[TStream]) - case x@RunAgg(body, result, _) => + case x @ RunAgg(body, result, _) => assert(x.typ == result.typ) assert(body.typ == TVoid) - case x@RunAggScan(array, _, init, seqs, result, _) => + case x @ RunAggScan(array, _, init, seqs, result, _) => assert(array.typ.isInstanceOf[TStream]) assert(init.typ == TVoid) assert(seqs.typ == TVoid) assert(x.typ.asInstanceOf[TStream].elementType == result.typ) - case x@AggFilter(cond, aggIR, _) => + case x @ AggFilter(cond, aggIR, _) => assert(cond.typ == TBoolean) assert(x.typ == aggIR.typ) - case x@AggExplode(array, name, aggBody, _) => + case x @ AggExplode(array, _, aggBody, _) => assert(array.typ.isInstanceOf[TStream]) assert(x.typ == aggBody.typ) - case x@AggGroupBy(key, aggIR, _) => + case x @ AggGroupBy(key, aggIR, _) => assert(x.typ == TDict(key.typ, aggIR.typ)) - case x@AggArrayPerElement(a, _, _, aggBody, knownLength, _) => + case x @ AggArrayPerElement(_, _, _, aggBody, knownLength, _) => assert(x.typ == TArray(aggBody.typ)) assert(knownLength.forall(_.typ == TInt32)) - case x@InitOp(_, args, aggSig) => - assert(args.map(_.typ) == aggSig.initOpTypes, s"${args.map(_.typ)} != ${aggSig.initOpTypes}") - case x@SeqOp(_, args, aggSig) => + case InitOp(_, args, aggSig) => + assert( + args.map(_.typ) == aggSig.initOpTypes, + s"${args.map(_.typ)} != ${aggSig.initOpTypes}", + ) + case SeqOp(_, args, aggSig) => assert(args.map(_.typ) == aggSig.seqOpTypes) case _: CombOp => case _: ResultOp => - case AggStateValue(i, sig) => - case CombOpValue(i, value, sig) => assert(value.typ == TBinary) - case InitFromSerializedValue(i, value, sig) => assert(value.typ == TBinary) + case AggStateValue(_, _) => + case CombOpValue(_, value, _) => assert(value.typ == TBinary) + case InitFromSerializedValue(_, value, _) => assert(value.typ == TBinary) case _: SerializeAggs => case _: DeserializeAggs => - case x@Begin(xs) => - xs.foreach { x => - assert(x.typ == TVoid) - } - case 
x@ApplyAggOp(initOpArgs, seqOpArgs, aggSig) => + case x @ ApplyAggOp(initOpArgs, seqOpArgs, aggSig) => assert(x.typ == aggSig.returnType) assert(initOpArgs.map(_.typ).zip(aggSig.initOpArgs).forall { case (l, r) => l == r }) assert(seqOpArgs.map(_.typ).zip(aggSig.seqOpArgs).forall { case (l, r) => l == r }) - case x@ApplyScanOp(initOpArgs, seqOpArgs, aggSig) => + case x @ ApplyScanOp(initOpArgs, seqOpArgs, aggSig) => assert(x.typ == aggSig.returnType) assert(initOpArgs.map(_.typ).zip(aggSig.initOpArgs).forall { case (l, r) => l == r }) assert(seqOpArgs.map(_.typ).zip(aggSig.seqOpArgs).forall { case (l, r) => l == r }) - case x@AggFold(zero, seqOp, combOp, elementName, accumName, _) => + case AggFold(zero, seqOp, combOp, _, _, _) => assert(zero.typ == seqOp.typ) assert(zero.typ == combOp.typ) - case x@MakeStruct(fields) => + case x @ MakeStruct(fields) => assert(x.typ == TStruct(fields.map { case (name, a) => (name, a.typ) }: _*)) - case x@SelectFields(old, fields) => + case SelectFields(old, fields) => assert { val oldfields = tcoerce[TStruct](old.typ).fieldNames.toSet - fields.forall { id => oldfields.contains(id) } + fields.forall(id => oldfields.contains(id)) } - case x@InsertFields(old, fields, fieldOrder) => + case x @ InsertFields(old, fields, fieldOrder) => fieldOrder.foreach { fds => val newFieldSet = fields.map(_._1).toSet val oldFieldNames = old.typ.asInstanceOf[TStruct].fieldNames @@ -496,30 +521,30 @@ object TypeCheck { assert(fds.areDistinct()) assert(fds.toSet.forall(f => newFieldSet.contains(f) || oldFieldNameSet.contains(f))) } - case x@GetField(o, name) => + case x @ GetField(o, name) => val t = tcoerce[TStruct](o.typ) assert(t.index(name).nonEmpty, s"$name not in $t") assert(x.typ == t.field(name).typ) - case x@MakeTuple(fields) => + case x @ MakeTuple(fields) => val indices = fields.map(_._1) assert(indices.areDistinct()) assert(indices.isSorted) - assert(x.typ == TTuple(fields.map { case (idx, f) => TupleField(idx, f.typ)}.toFastSeq)) - case x@GetTupleElement(o, idx) => + assert(x.typ == TTuple(fields.map { case (idx, f) => TupleField(idx, f.typ) }.toFastSeq)) + case x @ GetTupleElement(o, idx) => val t = tcoerce[TTuple](o.typ) val fd = t.fields(t.fieldIndex(idx)) assert(x.typ == fd.typ) - case In(i, typ) => + case In(_, typ) => assert(typ != null) typ.virtualType match { case stream: TStream => assert(stream.elementType.isRealizable) case _ => } - case Die(msg, typ, _) => + case Die(msg, _, _) => assert(msg.typ == TString) - case Trap(child) => + case Trap(_) => case ConsoleLog(msg, _) => assert(msg.typ == TString) - case x@ApplyIR(fn, _, typeArgs, args, _) => + case ApplyIR(_, _, _, _, _) => case x: AbstractApplyNode[_] => assert(x.implementation.unify(x.typeArgs, x.args.map(_.typ), x.returnType)) case MatrixWrite(_, _) => @@ -527,12 +552,13 @@ object TypeCheck { val t = children.head.typ assert( !t.rowType.hasField(MatrixReader.rowUIDFieldName) && - !t.colType.hasField(MatrixReader.colUIDFieldName), t - ) + !t.colType.hasField(MatrixReader.colUIDFieldName), + t, + ) assert(children.forall(_.typ == t)) - case x@TableAggregate(child, query) => + case x @ TableAggregate(_, query) => assert(x.typ == query.typ) - case x@MatrixAggregate(child, query) => + case x @ MatrixAggregate(_, query) => assert(x.typ == query.typ) case RelationalLet(_, _, _) => case TableWrite(_, _) => @@ -550,28 +576,28 @@ object TypeCheck { case BlockMatrixCollect(_) => case BlockMatrixWrite(_, writer) => writer.loweredTyp case BlockMatrixMultiWrite(_, _) => - case CollectDistributedArray(ctxs, 
globals, cname, gname, body, dynamicID, _, _) => + case CollectDistributedArray(ctxs, _, _, _, _, dynamicID, _, _) => assert(ctxs.typ.isInstanceOf[TStream]) assert(dynamicID.typ == TString) - case x@ReadPartition(context, rowType, reader) => + case x @ ReadPartition(context, rowType, reader) => assert(rowType.isRealizable) assert(context.typ == reader.contextType) assert(x.typ == TStream(rowType)) assert(PruneDeadFields.isSupertype(rowType, reader.fullRowType)) - case x@WritePartition(value, writeCtx, writer) => + case x @ WritePartition(value, writeCtx, writer) => assert(value.typ.isInstanceOf[TStream]) assert(writeCtx.typ == writer.ctxType) assert(x.typ == writer.returnType) case WriteMetadata(writeAnnotations, writer) => assert(writeAnnotations.typ == writer.annotationType) - case x@ReadValue(path, reader, requestedType) => + case ReadValue(path, reader, requestedType) => assert(path.typ == TString) reader match { case reader: ETypeValueReader => assert(reader.spec.encodedType.decodedPType(requestedType).virtualType == requestedType) case _ => // do nothing, we can't in general typecheck an arbitrary value reader } - case WriteValue(_, path, writer, stagingFile) => + case WriteValue(_, path, _, stagingFile) => assert(path.typ == TString) assert(stagingFile.forall(_.typ == TString)) case LiftMeOut(_) => diff --git a/hail/src/main/scala/is/hail/expr/ir/UnaryOp.scala b/hail/src/main/scala/is/hail/expr/ir/UnaryOp.scala index ac477ed859d..f66332721c9 100644 --- a/hail/src/main/scala/is/hail/expr/ir/UnaryOp.scala +++ b/hail/src/main/scala/is/hail/expr/ir/UnaryOp.scala @@ -1,20 +1,18 @@ package is.hail.expr.ir import is.hail.asm4s._ -import is.hail.expr._ -import is.hail.types._ -import is.hail.types.physical.stypes.{SCode, SType, SValue} +import is.hail.types.physical.{typeToTypeInfo, PType} +import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.physical.stypes.interfaces._ -import is.hail.types.physical.{PType, typeToTypeInfo} import is.hail.types.virtual._ import is.hail.utils._ object UnaryOp { private val returnType: ((UnaryOp, Type)) => Option[Type] = lift { - case (Negate, t@(TInt32 | TInt64 | TFloat32 | TFloat64)) => t + case (Negate, t @ (TInt32 | TInt64 | TFloat32 | TFloat64)) => t case (Bang, TBoolean) => TBoolean - case (BitNot, t@(TInt32 | TInt64)) => t + case (BitNot, t @ (TInt32 | TInt64)) => t case (BitCount, TInt32 | TInt64) => TInt32 } diff --git a/hail/src/main/scala/is/hail/expr/ir/ValueReader.scala b/hail/src/main/scala/is/hail/expr/ir/ValueReader.scala index ba5c142d08c..0305502b818 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ValueReader.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ValueReader.scala @@ -6,28 +6,29 @@ import is.hail.io.{AbstractTypedCodecSpec, BufferSpec, TypedCodecSpec} import is.hail.types.TypeWithRequiredness import is.hail.types.encoded._ import is.hail.types.physical._ -import is.hail.types.physical.stypes.{SCode, SType, SValue} -import is.hail.types.physical.stypes.concrete.SStackStruct +import is.hail.types.physical.stypes.SValue import is.hail.types.virtual._ -import is.hail.utils._ -import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} +import java.io.InputStream -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, OutputStream} +import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} object ValueReader { - implicit val formats: Formats = new DefaultFormats() { - override val typeHints = ShortTypeHints(List( - classOf[ETypeValueReader], - 
classOf[AbstractTypedCodecSpec], - classOf[TypedCodecSpec]), - typeHintFieldName = "name" - ) + BufferSpec.shortTypeHints - } + - new TStructSerializer + - new TypeSerializer + - new PTypeSerializer + - new ETypeSerializer + implicit val formats: Formats = + new DefaultFormats() { + override val typeHints = ShortTypeHints( + List( + classOf[ETypeValueReader], + classOf[AbstractTypedCodecSpec], + classOf[TypedCodecSpec], + ), + typeHintFieldName = "name", + ) + BufferSpec.shortTypeHints + } + + new TStructSerializer + + new TypeSerializer + + new PTypeSerializer + + new ETypeSerializer } abstract class ValueReader { @@ -38,12 +39,12 @@ abstract class ValueReader { def toJValue: JValue = Extraction.decompose(this)(ValueReader.formats) } - final case class ETypeValueReader(spec: AbstractTypedCodecSpec) extends ValueReader { def unionRequiredness(requestedType: Type, requiredness: TypeWithRequiredness): Unit = requiredness.fromPType(spec.encodedType.decodedPType(requestedType)) - def readValue(cb: EmitCodeBuilder, t: Type, region: Value[Region], is: Value[InputStream]): SValue = { + def readValue(cb: EmitCodeBuilder, t: Type, region: Value[Region], is: Value[InputStream]) + : SValue = { val decoder = spec.encodedType.buildDecoder(t, cb.emb.ecb) val ib = cb.memoize(spec.buildCodeInputBuffer(is)) val ret = decoder.apply(cb, region, ib) diff --git a/hail/src/main/scala/is/hail/expr/ir/ValueWriter.scala b/hail/src/main/scala/is/hail/expr/ir/ValueWriter.scala index 5db0e45a37c..1205006008c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ValueWriter.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ValueWriter.scala @@ -4,28 +4,29 @@ import is.hail.asm4s._ import is.hail.io.{AbstractTypedCodecSpec, BufferSpec, TypedCodecSpec} import is.hail.types.encoded._ import is.hail.types.physical._ -import is.hail.types.physical.stypes.{SCode, SType, SValue} -import is.hail.types.physical.stypes.concrete.SStackStruct +import is.hail.types.physical.stypes.SValue import is.hail.types.virtual._ -import is.hail.utils._ -import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} +import java.io.OutputStream -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, OutputStream} +import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} object ValueWriter { - implicit val formats: Formats = new DefaultFormats() { - override val typeHints = ShortTypeHints(List( - classOf[ETypeValueWriter], - classOf[AbstractTypedCodecSpec], - classOf[TypedCodecSpec]), - typeHintFieldName = "name" - ) + BufferSpec.shortTypeHints - } + - new TStructSerializer + - new TypeSerializer + - new PTypeSerializer + - new ETypeSerializer + implicit val formats: Formats = + new DefaultFormats() { + override val typeHints = ShortTypeHints( + List( + classOf[ETypeValueWriter], + classOf[AbstractTypedCodecSpec], + classOf[TypedCodecSpec], + ), + typeHintFieldName = "name", + ) + BufferSpec.shortTypeHints + } + + new TStructSerializer + + new TypeSerializer + + new PTypeSerializer + + new ETypeSerializer } abstract class ValueWriter { diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/AggregatorState.scala b/hail/src/main/scala/is/hail/expr/ir/agg/AggregatorState.scala index 223545a5d74..a17942e1244 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/AggregatorState.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/AggregatorState.scala @@ -25,9 +25,17 @@ trait AggregatorState { // null to safeguard against users of off def newState(cb: EmitCodeBuilder): Unit = newState(cb, null) - def 
load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit - - def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit + def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit + + def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit @@ -36,7 +44,10 @@ trait AggregatorState { def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit def deserializeFromBytes(cb: EmitCodeBuilder, bytes: SBinaryValue): Unit = { - val lazyBuffer = kb.getOrDefineLazyField[MemoryBufferWrapper](Code.newInstance[MemoryBufferWrapper](), ("AggregatorStateBufferWrapper")) + val lazyBuffer = kb.getOrDefineLazyField[MemoryBufferWrapper]( + Code.newInstance[MemoryBufferWrapper](), + ("AggregatorStateBufferWrapper"), + ) cb += lazyBuffer.invoke[Array[Byte], Unit]("set", bytes.loadBytes(cb)) val ib = cb.memoize(lazyBuffer.invoke[InputBuffer]("buffer")) deserialize(BufferSpec.blockedUncompressed)(cb, ib) @@ -44,7 +55,10 @@ trait AggregatorState { } def serializeToRegion(cb: EmitCodeBuilder, t: PBinary, r: Value[Region]): SValue = { - val lazyBuffer = kb.getOrDefineLazyField[MemoryWriterWrapper](Code.newInstance[MemoryWriterWrapper](), ("AggregatorStateWriterWrapper")) + val lazyBuffer = kb.getOrDefineLazyField[MemoryWriterWrapper]( + Code.newInstance[MemoryWriterWrapper](), + ("AggregatorStateWriterWrapper"), + ) val addr = kb.genFieldThisRef[Long]("addr") cb += lazyBuffer.invoke[Unit]("clear") val ob = cb.memoize(lazyBuffer.invoke[OutputBuffer]("buffer")) @@ -61,20 +75,29 @@ trait RegionBackedAggState extends AggregatorState { protected val r: Settable[Region] = kb.genFieldThisRef[Region]() val region: Value[Region] = r - def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = cb += region.getNewRegion(const(regionSize)) + def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = + cb += region.getNewRegion(const(regionSize)) - def createState(cb: EmitCodeBuilder): Unit = { + def createState(cb: EmitCodeBuilder): Unit = cb.if_(region.isNull, cb.assign(r, Region.stagedCreate(regionSize, kb.pool()))) - } - def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = regionLoader(cb, r) - - def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = - cb.if_(region.isValid, - { + def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = regionLoader(cb, r) + + def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = + cb.if_( + region.isValid, { regionStorer(cb, region) cb += region.invalidate() - }) + }, + ) } trait PointerBasedRVAState extends RegionBackedAggState { @@ -83,18 +106,26 @@ trait PointerBasedRVAState extends RegionBackedAggState { override val regionSize: Int = Region.TINIER - override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = { super.load(cb, regionLoader, src) cb.assign(off, Region.loadAddress(src)) } - override def store(cb: EmitCodeBuilder, 
regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = { - cb.if_(region.isValid, - { + override def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = + cb.if_( + region.isValid, { cb += Region.storeAddress(dest, off) super.store(cb, regionStorer, dest) - }) - } + }, + ) def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit = copyFromAddress(cb, cb.memoize(Region.loadAddress(src))) @@ -102,7 +133,8 @@ trait PointerBasedRVAState extends RegionBackedAggState { def copyFromAddress(cb: EmitCodeBuilder, src: Value[Long]): Unit } -class TypedRegionBackedAggState(val typ: VirtualTypeWithReq, val kb: EmitClassBuilder[_]) extends AbstractTypedRegionBackedAggState(typ.canonicalPType) +class TypedRegionBackedAggState(val typ: VirtualTypeWithReq, val kb: EmitClassBuilder[_]) + extends AbstractTypedRegionBackedAggState(typ.canonicalPType) abstract class AbstractTypedRegionBackedAggState(val ptype: PType) extends RegionBackedAggState { @@ -120,21 +152,29 @@ abstract class AbstractTypedRegionBackedAggState(val ptype: PType) extends Regio super.newState(cb, off) } - override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { - super.load(cb, { (cb: EmitCodeBuilder, r: Value[Region]) => cb += r.invalidate() }, src) + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = { + super.load(cb, (cb: EmitCodeBuilder, r: Value[Region]) => cb += r.invalidate(), src) cb.assign(off, src) } - override def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = { - cb.if_(region.isValid, - cb.if_(dest.cne(off), - cb += Region.copyFrom(off, dest, const(storageType.byteSize)))) + override def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = { + cb.if_( + region.isValid, + cb.if_(dest.cne(off), cb += Region.copyFrom(off, dest, const(storageType.byteSize))), + ) super.store(cb, regionStorer, dest) } - def storeMissing(cb: EmitCodeBuilder): Unit = { + def storeMissing(cb: EmitCodeBuilder): Unit = storageType.setFieldMissing(cb, off, 0) - } def storeNonmissing(cb: EmitCodeBuilder, sc: SValue): Unit = { cb += region.getNewRegion(const(regionSize)) @@ -142,13 +182,22 @@ abstract class AbstractTypedRegionBackedAggState(val ptype: PType) extends Regio ptype.storeAtAddress(cb, storageType.fieldOffset(off, 0), region, sc, deepCopy = true) } - def get(cb: EmitCodeBuilder): IEmitCode = { - IEmitCode(cb, storageType.isFieldMissing(cb, off, 0), ptype.loadCheapSCode(cb, storageType.loadField(off, 0))) - } + def get(cb: EmitCodeBuilder): IEmitCode = + IEmitCode( + cb, + storageType.isFieldMissing(cb, off, 0), + ptype.loadCheapSCode(cb, storageType.loadField(off, 0)), + ) def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit = { newState(cb, off) - storageType.storeAtAddress(cb, off, region, storageType.loadCheapSCode(cb, src), deepCopy = true) + storageType.storeAtAddress( + cb, + off, + region, + storageType.loadCheapSCode(cb, src), + deepCopy = true, + ) } def serialize(codec: BufferSpec): (EmitCodeBuilder, Value[OutputBuffer]) => Unit = { @@ -161,25 +210,27 @@ abstract class AbstractTypedRegionBackedAggState(val ptype: PType) extends Regio val codecSpec = TypedCodecSpec(storageType, codec) val dec = codecSpec.encodedType.buildDecoder(storageType.virtualType, kb) - 
((cb: EmitCodeBuilder, ib: Value[InputBuffer]) => - storageType.storeAtAddress(cb, off, region, dec(cb, region, ib), deepCopy = false)) + (cb: EmitCodeBuilder, ib: Value[InputBuffer]) => + storageType.storeAtAddress(cb, off, region, dec(cb, region, ib), deepCopy = false) } } -class PrimitiveRVAState(val vtypes: Array[VirtualTypeWithReq], val kb: EmitClassBuilder[_]) extends AggregatorState { +class PrimitiveRVAState(val vtypes: Array[VirtualTypeWithReq], val kb: EmitClassBuilder[_]) + extends AggregatorState { private[this] val emitTypes = vtypes.map(_.canonicalEmitType) assert(emitTypes.forall(_.st.isPrimitive)) val nFields: Int = emitTypes.length - val fields: Array[EmitSettable] = Array.tabulate(nFields) { i => kb.newEmitField(s"primitiveRVA_${ i }_v", emitTypes(i)) } + + val fields: Array[EmitSettable] = Array.tabulate(nFields) { i => + kb.newEmitField(s"primitiveRVA_${i}_v", emitTypes(i)) + } + val storageType = PCanonicalTuple(true, emitTypes.map(_.typeWithRequiredness.canonicalPType): _*) val sStorageType = storageType.sType - def foreachField(f: (Int, EmitSettable) => Unit): Unit = { - (0 until nFields).foreach { i => - f(i, fields(i)) - } - } + def foreachField(f: (Int, EmitSettable) => Unit): Unit = + (0 until nFields).foreach(i => f(i, fields(i))) def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = {} @@ -187,22 +238,28 @@ class PrimitiveRVAState(val vtypes: Array[VirtualTypeWithReq], val kb: EmitClass private[this] def loadVarsFromRegion(cb: EmitCodeBuilder, srcc: Code[Long]): Unit = { val pv = storageType.loadCheapSCode(cb, srcc) - foreachField { (i, es) => - cb.assign(es, pv.loadField(cb, i)) - } + foreachField((i, es) => cb.assign(es, pv.loadField(cb, i))) } - def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { + def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = loadVarsFromRegion(cb, src) - } - def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = { - storageType.storeAtAddress(cb, + def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = + storageType.storeAtAddress( + cb, dest, null, SStackStruct.constructFromArgs(cb, null, storageType.virtualType, fields.map(_.load): _*), - false) - } + false, + ) def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit = loadVarsFromRegion(cb, src) @@ -212,12 +269,14 @@ class PrimitiveRVAState(val vtypes: Array[VirtualTypeWithReq], val kb: EmitClass if (es.emitType.required) { ob.writePrimitive(cb, es.get(cb)) } else { - es.toI(cb).consume(cb, + es.toI(cb).consume( + cb, cb += ob.writeBoolean(true), { sc => cb += ob.writeBoolean(false) ob.writePrimitive(cb, sc) - }) + }, + ) } } } @@ -228,9 +287,11 @@ class PrimitiveRVAState(val vtypes: Array[VirtualTypeWithReq], val kb: EmitClass if (es.emitType.required) { cb.assign(es, EmitCode.present(cb.emb, ib.readPrimitive(cb, es.st.virtualType))) } else { - cb.if_(ib.readBoolean(), + cb.if_( + ib.readBoolean(), cb.assign(es, EmitCode.missing(cb.emb, es.st)), - cb.assign(es, EmitCode.present(cb.emb, ib.readPrimitive(cb, es.st.virtualType)))) + cb.assign(es, EmitCode.present(cb.emb, ib.readPrimitive(cb, es.st.virtualType))), + ) } } } @@ -238,61 +299,68 @@ class PrimitiveRVAState(val vtypes: Array[VirtualTypeWithReq], val kb: EmitClass case class StateTuple(states: Array[AggregatorState]) { val nStates: Int = states.length - val 
storageType: PTuple = PCanonicalTuple(true, states.map { s => s.storageType }: _*) + val storageType: PTuple = PCanonicalTuple(true, states.map(s => s.storageType): _*) def apply(i: Int): AggregatorState = { if (i >= states.length) - throw new RuntimeException(s"tried to access state $i, but there are only ${ states.length } states") + throw new RuntimeException( + s"tried to access state $i, but there are only ${states.length} states" + ) states(i) } - def toCode(f: (Int, AggregatorState) => Unit): Unit = { - (0 until nStates).foreach { i => - f(i, states(i)) - } - } + def toCode(f: (Int, AggregatorState) => Unit): Unit = + (0 until nStates).foreach(i => f(i, states(i))) def toCodeWithArgs( - cb: EmitCodeBuilder, f: (EmitCodeBuilder, Int, AggregatorState) => Unit - ): Unit = { - (0 until nStates).foreach { i => - f(cb, i, states(i)) - } - } + cb: EmitCodeBuilder, + f: (EmitCodeBuilder, Int, AggregatorState) => Unit, + ): Unit = + (0 until nStates).foreach(i => f(cb, i, states(i))) def createStates(cb: EmitCodeBuilder): Unit = toCode((i, s) => s.createState(cb)) } -class TupleAggregatorState(val kb: EmitClassBuilder[_], val states: StateTuple, val topRegion: Value[Region], val off: Value[Long], val rOff: Value[Int] = const(0)) { +class TupleAggregatorState( + val kb: EmitClassBuilder[_], + val states: StateTuple, + val topRegion: Value[Region], + val off: Value[Long], + val rOff: Value[Int] = const(0), +) { val storageType: PTuple = states.storageType - private def getRegion(i: Int): (EmitCodeBuilder, Value[Region]) => Unit = { (cb: EmitCodeBuilder, r: Value[Region]) => - cb += r.setFromParentReference(topRegion, rOff + const(i), states(i).regionSize) + private def getRegion(i: Int): (EmitCodeBuilder, Value[Region]) => Unit = { + (cb: EmitCodeBuilder, r: Value[Region]) => + cb += r.setFromParentReference(topRegion, rOff + const(i), states(i).regionSize) } - private def setRegion(i: Int): (EmitCodeBuilder, Value[Region]) => Unit = { (cb: EmitCodeBuilder, r: Value[Region]) => - cb += topRegion.setParentReference(r, rOff + const(i)) + private def setRegion(i: Int): (EmitCodeBuilder, Value[Region]) => Unit = { + (cb: EmitCodeBuilder, r: Value[Region]) => + cb += topRegion.setParentReference(r, rOff + const(i)) } - private def getStateOffset(cb: EmitCodeBuilder, i: Int): Value[Long] = cb.memoize(storageType.loadField(off, i)) + private def getStateOffset(cb: EmitCodeBuilder, i: Int): Value[Long] = + cb.memoize(storageType.loadField(off, i)) def toCode(f: (Int, AggregatorState) => Unit): Unit = (0 until states.nStates).foreach(i => f(i, states(i))) def newState(cb: EmitCodeBuilder, i: Int): Unit = states(i).newState(cb, getStateOffset(cb, i)) - def newState(cb: EmitCodeBuilder): Unit = states.toCode((i, s) => s.newState(cb, getStateOffset(cb, i))) + def newState(cb: EmitCodeBuilder): Unit = + states.toCode((i, s) => s.newState(cb, getStateOffset(cb, i))) def load(cb: EmitCodeBuilder): Unit = states.toCode((i, s) => s.load(cb, getRegion(i), getStateOffset(cb, i))) - def store(cb: EmitCodeBuilder): Unit = { + def store(cb: EmitCodeBuilder): Unit = states.toCode((i, s) => s.store(cb, setRegion(i), getStateOffset(cb, i))) - } - def copyFrom(cb: EmitCodeBuilder, statesOffset: Value[Long]): Unit = { - states.toCodeWithArgs(cb, - { case (cb, i, s) => s.copyFrom(cb, cb.memoize(storageType.loadField(statesOffset, i))) }) - } + def copyFrom(cb: EmitCodeBuilder, statesOffset: Value[Long]): Unit = + states.toCodeWithArgs( + cb, + { case (cb, i, s) => s.copyFrom(cb, 
cb.memoize(storageType.loadField(statesOffset, i))) }, + ) } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/AppendOnlyBTree.scala b/hail/src/main/scala/is/hail/expr/ir/agg/AppendOnlyBTree.scala index 9e7dff252f8..a1411542107 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/AppendOnlyBTree.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/AppendOnlyBTree.scala @@ -30,19 +30,29 @@ trait BTreeKey { compKeys(cb, loadCompKey(cb, off), loadCompKey(cb, other)) } - def compWithKey(cb: EmitCodeBuilder, off: Value[Long], k: EmitValue): Value[Int] = { + def compWithKey(cb: EmitCodeBuilder, off: Value[Long], k: EmitValue): Value[Int] = compKeys(cb, loadCompKey(cb, off), k) - } } -class AppendOnlyBTree(kb: EmitClassBuilder[_], val key: BTreeKey, region: Value[Region], root: Settable[Long], maxElements: Int = 2) { +class AppendOnlyBTree( + kb: EmitClassBuilder[_], + val key: BTreeKey, + region: Value[Region], + root: Settable[Long], + maxElements: Int = 2, +) { private val splitIdx = maxElements / 2 private val eltType: PTuple = PCanonicalTuple(false, key.storageType, PInt64(true)) - private val elementsType: PTuple = PCanonicalTuple(required = true, Array.fill[PType](maxElements)(eltType): _*) - private val storageType: PStruct = PCanonicalStruct(required = true, + + private val elementsType: PTuple = + PCanonicalTuple(required = true, Array.fill[PType](maxElements)(eltType): _*) + + private val storageType: PStruct = PCanonicalStruct( + required = true, "parent" -> PInt64(), "child0" -> PInt64(), - "elements" -> elementsType) + "elements" -> elementsType, + ) private def createNode(cb: EmitCodeBuilder, nodeBucket: Settable[Long]): Unit = { cb.assign(nodeBucket, region.allocate(storageType.alignment, storageType.byteSize)) @@ -64,18 +74,17 @@ class AppendOnlyBTree(kb: EmitClassBuilder[_], val key: BTreeKey, region: Value[ private def hasKey(cb: EmitCodeBuilder, node: Code[Long], i: Int): Value[Boolean] = elementsType.isFieldDefined(cb, elements(node), i) - private def setKeyPresent(cb: EmitCodeBuilder, node: Code[Long], i: Int): Unit = { + private def setKeyPresent(cb: EmitCodeBuilder, node: Code[Long], i: Int): Unit = elementsType.setFieldPresent(cb, elements(node), i) - } - private def setKeyMissing(cb: EmitCodeBuilder, node: Code[Long], i: Int): Unit = { + private def setKeyMissing(cb: EmitCodeBuilder, node: Code[Long], i: Int): Unit = elementsType.setFieldMissing(cb, elements(node), i) - } private def isFull(cb: EmitCodeBuilder, node: Code[Long]): Value[Boolean] = hasKey(cb, node, maxElements - 1) - private def keyOffset(node: Code[Long], i: Int): Code[Long] = eltType.fieldOffset(elementsType.loadField(elements(node), i), 0) + private def keyOffset(node: Code[Long], i: Int): Code[Long] = + eltType.fieldOffset(elementsType.loadField(elements(node), i), 0) private def loadKey(cb: EmitCodeBuilder, node: Value[Long], i: Int): Value[Long] = cb.memoize(eltType.loadField(elementsType.loadField(elements(node), i), 0)) @@ -89,7 +98,13 @@ class AppendOnlyBTree(kb: EmitClassBuilder[_], val key: BTreeKey, region: Value[ private def loadChild(cb: EmitCodeBuilder, node: Code[Long], i: Int): Value[Long] = cb.memoize(Region.loadAddress(childOffset(node, i))) - private def setChild(cb: EmitCodeBuilder, parentC: Code[Long], i: Int, childC: Code[Long], context: String): Unit = { + private def setChild( + cb: EmitCodeBuilder, + parentC: Code[Long], + i: Int, + childC: Code[Long], + context: String, + ): Unit = { val parent = cb.newLocal[Long]("aobt_set_child_parent", parentC) val child = 
cb.newLocal[Long]("aobt_set_child_child", childC) @@ -100,11 +115,22 @@ class AppendOnlyBTree(kb: EmitClassBuilder[_], val key: BTreeKey, region: Value[ cb += Region.storeAddress(storageType.fieldOffset(child, 0), parent) } - private def insert(cb: EmitCodeBuilder, nodec: Value[Long], insertIdxc: Value[Int], kc: EmitCode, childC: Value[Long]): Value[Long] = { + private def insert( + cb: EmitCodeBuilder, + nodec: Value[Long], + insertIdxc: Value[Int], + kc: EmitCode, + childC: Value[Long], + ): Value[Long] = { val kt = key.compType.sType - val castKCode = EmitCode.fromI(cb.emb)(cb => kc.toI(cb).map(cb)(k => kt.coerceOrCopy(cb, region, k, false))) - val insertAt = kb.getOrGenEmitMethod("btree_insert", (this, "insert", kt), - FastSeq[ParamType](typeInfo[Long], typeInfo[Int], castKCode.emitParamType, typeInfo[Long]), typeInfo[Long]) { insertAt => + val castKCode = + EmitCode.fromI(cb.emb)(cb => kc.toI(cb).map(cb)(k => kt.coerceOrCopy(cb, region, k, false))) + val insertAt = kb.getOrGenEmitMethod( + "btree_insert", + (this, "insert", kt), + FastSeq[ParamType](typeInfo[Long], typeInfo[Int], castKCode.emitParamType, typeInfo[Long]), + typeInfo[Long], + ) { insertAt => val node: Value[Long] = insertAt.getCodeParam[Long](1) val insertIdx: Value[Int] = insertAt.getCodeParam[Int](2) val k: EmitValue = insertAt.getEmitParam(cb, 3) @@ -117,21 +143,25 @@ class AppendOnlyBTree(kb: EmitClassBuilder[_], val key: BTreeKey, region: Value[ def makeUninitialized(cb: EmitCodeBuilder, idx: Int): Value[Long] = { setKeyPresent(cb, node, idx) key.initializeEmpty(cb, keyOffset(node, idx)) - cb.if_(!isLeaf(cb, node), { - setChild(cb, node, idx, child, "makeUninitialized setChild") - }) + cb.if_(!isLeaf(cb, node), setChild(cb, node, idx, child, "makeUninitialized setChild")) loadKey(cb, node, idx) } - def copyFrom(cb: EmitCodeBuilder, destNodeC: Code[Long], destIdx: Int, srcNodeC: Code[Long], srcIdx: Int): Unit = { + def copyFrom( + cb: EmitCodeBuilder, + destNodeC: Code[Long], + destIdx: Int, + srcNodeC: Code[Long], + srcIdx: Int, + ): Unit = { val destNode = cb.newLocal("aobt_copy_from_destnode", destNodeC) val srcNode = cb.newLocal("aobt_copy_from_srcnode", srcNodeC) setKeyPresent(cb, destNode, destIdx) key.copy(cb, keyOffset(srcNode, srcIdx), keyOffset(destNode, destIdx)) - cb.if_(!isLeaf(cb, srcNode), - { - setChild(cb, destNode, destIdx, loadChild(cb, srcNode, srcIdx), "insert copyFrom") - }) + cb.if_( + !isLeaf(cb, srcNode), + setChild(cb, destNode, destIdx, loadChild(cb, srcNode, srcIdx), "insert copyFrom"), + ) } def copyToNew(cb: EmitCodeBuilder, startIdx: Int): Unit = @@ -147,26 +177,23 @@ class AppendOnlyBTree(kb: EmitClassBuilder[_], val key: BTreeKey, region: Value[ (0 until maxElements).foreach { i => val b = cb.newLocal[Boolean]("btree_insertkey_b", !hasKey(cb, parent(cb), i)) cb.if_(!b, cb.assign(b, key.compWithKey(cb, loadKey(cb, parent(cb), i), ev) >= 0)) - cb.if_(b, { - cb.assign(upperBound, i) - cb.goto(Lfound) - }) + cb.if_( + b, { + cb.assign(upperBound, i) + cb.goto(Lfound) + }, + ) } cb.define(Lfound) - cb.if_(!isLeaf(cb, node), { - setChild(cb, newNode, -1, c, "insertKey !isLeaf") - }) - cb.invokeCode(insertAt, parent(cb), upperBound, ev, newNode) + cb.if_(!isLeaf(cb, node), setChild(cb, newNode, -1, c, "insertKey !isLeaf")) + cb.invokeCode(insertAt, cb.this_, parent(cb), upperBound, ev, newNode) } - def promote(cb: EmitCodeBuilder, idx: Int): Unit = { val nikey = cb.newLocal("aobt_insert_nikey", loadKey(cb, node, idx)) - cb.if_(!isLeaf(cb, node), { - setChild(cb, newNode, -1, loadChild(cb, 
node, idx), "promote") - }) + cb.if_(!isLeaf(cb, node), setChild(cb, newNode, -1, loadChild(cb, node, idx), "promote")) val upperBound = cb.newLocal("promote_upper_bound", maxElements) val Lfound = CodeLabel() @@ -174,37 +201,65 @@ class AppendOnlyBTree(kb: EmitClassBuilder[_], val key: BTreeKey, region: Value[ (0 until maxElements).foreach { i => val b = cb.newLocal[Boolean]("btree_insert_promote_b", !hasKey(cb, parent(cb), i)) cb.if_(!b, cb.assign(b, key.compSame(cb, loadKey(cb, parent(cb), i), nikey) >= 0)) - cb.if_(b, { - cb.assign(upperBound, i) - cb.goto(Lfound) - }) + cb.if_( + b, { + cb.assign(upperBound, i) + cb.goto(Lfound) + }, + ) } cb.define(Lfound) - key.copy(cb, loadKey(cb, node, idx), cb.invokeCode(insertAt, parent(cb), upperBound, key.loadCompKey(cb, nikey), newNode)) + key.copy( + cb, + loadKey(cb, node, idx), + cb.invokeCode( + insertAt, + cb.this_, + parent(cb), + upperBound, + key.loadCompKey(cb, nikey), + newNode, + ), + ) setKeyMissing(cb, node, idx) } def splitAndInsert(cb: EmitCodeBuilder): Code[Long] = { - cb.if_(isRoot(cb, node), { - createNode(cb, root) - setChild(cb, root, -1, node, "splitAndInsert") - }) + cb.if_( + isRoot(cb, node), { + createNode(cb, root) + setChild(cb, root, -1, node, "splitAndInsert") + }, + ) createNode(cb, newNode) val out = cb.newLocal[Long]("split_and_insert_out") - cb.if_(insertIdx > splitIdx, { - copyToNew(cb, splitIdx + 1) - promote(cb, splitIdx) - cb.assign(out, cb.invokeCode(insertAt, newNode, cb.memoize(insertIdx - splitIdx - 1), k, child)) - }, { - copyToNew(cb, splitIdx) - cb.if_(insertIdx.ceq(splitIdx), { - cb.assign(out, insertKey(cb, k, child)) + cb.if_( + insertIdx > splitIdx, { + copyToNew(cb, splitIdx + 1) + promote(cb, splitIdx) + cb.assign( + out, + cb.invokeCode( + insertAt, + cb.this_, + newNode, + cb.memoize(insertIdx - splitIdx - 1), + k, + child, + ), + ) }, { - promote(cb, splitIdx - 1) - cb.assign(out, cb.invokeCode(insertAt, node, insertIdx, k, child)) - }) - }) + copyToNew(cb, splitIdx) + cb.if_( + insertIdx.ceq(splitIdx), + cb.assign(out, insertKey(cb, k, child)), { + promote(cb, splitIdx - 1) + cb.assign(out, cb.invokeCode(insertAt, cb.this_, node, insertIdx, k, child)) + }, + ) + }, + ) out } @@ -212,13 +267,13 @@ class AppendOnlyBTree(kb: EmitClassBuilder[_], val key: BTreeKey, region: Value[ val ret = cb.newLocal[Long]("shift_and_insert") val Lout = CodeLabel() (1 until maxElements).reverse.foreach { destIdx => - cb.if_(hasKey(cb, node, destIdx - 1), { - copyFrom(cb, node, destIdx, node, destIdx - 1) - }) - cb.if_(insertIdx.ceq(destIdx), { - cb.assign(ret, makeUninitialized(cb, destIdx)) - cb.goto(Lout) - }) + cb.if_(hasKey(cb, node, destIdx - 1), copyFrom(cb, node, destIdx, node, destIdx - 1)) + cb.if_( + insertIdx.ceq(destIdx), { + cb.assign(ret, makeUninitialized(cb, destIdx)) + cb.goto(Lout) + }, + ) } cb.assign(ret, makeUninitialized(cb, 0)) cb.define(Lout) @@ -227,54 +282,71 @@ class AppendOnlyBTree(kb: EmitClassBuilder[_], val key: BTreeKey, region: Value[ insertAt.emitWithBuilder { cb => val ret = cb.newLocal[Long]("btree_insert_result") - cb.if_(isFull(cb, node), + cb.if_( + isFull(cb, node), cb.assign(ret, splitAndInsert(cb)), - cb.assign(ret, shiftAndInsert(cb))) + cb.assign(ret, shiftAndInsert(cb)), + ) ret } } - cb.invokeCode[Long](insertAt, nodec, insertIdxc, castKCode, childC) + cb.invokeCode[Long](insertAt, cb.this_, nodec, insertIdxc, castKCode, childC) } private def getF(cb: EmitCodeBuilder, root: Value[Long], kc: EmitCode): Value[Long] = { - val get = kb.genEmitMethod("btree_get", 
FastSeq[ParamType](typeInfo[Long], kc.emitParamType), typeInfo[Long]) - get.emitWithBuilder { cb => - val node = get.getCodeParam[Long](1) - val k = get.getEmitParam(cb, 2) - - val cmp = cb.newLocal("btree_get_cmp", -1) - val keyV = cb.newLocal("btree_get_keyV", 0L) - - def insertOrGetAt(i: Int) = { - cb.if_(isLeaf(cb, node), { - cb.assign(keyV, insert(cb, node, const(i), k, const(0L))) - cb.assign(cmp, 0) - }, { - cb.assign(node, loadChild(cb, node, i - 1)) - }) - } - - cb.while_(cmp.cne(0), { (Lcont: CodeLabel) => - (0 until maxElements).foreach { i => - cb.if_(hasKey(cb, node, i), { - cb.assign(keyV, loadKey(cb, node, i)) - cb.assign(cmp, key.compWithKey(cb, keyV, k)) - cb.if_(cmp.ceq(0), cb.goto(Lcont)) - cb.if_(cmp > 0, { - insertOrGetAt(i) - cb.goto(Lcont) - }) - }, { - insertOrGetAt(i) - cb.goto(Lcont) - }) + val get = kb.getOrGenEmitMethod( + "btree_get", + ("btree_get", key), + FastSeq[ParamType](typeInfo[Long], kc.emitParamType), + typeInfo[Long], + ) { get => + get.emitWithBuilder { cb => + val node = get.getCodeParam[Long](1) + val k = get.getEmitParam(cb, 2) + + val cmp = cb.newLocal("btree_get_cmp", -1) + val keyV = cb.newLocal("btree_get_keyV", 0L) + + def insertOrGetAt(i: Int) = { + cb.if_( + isLeaf(cb, node), { + cb.assign(keyV, insert(cb, node, const(i), k, const(0L))) + cb.assign(cmp, 0) + }, + cb.assign(node, loadChild(cb, node, i - 1)), + ) } - insertOrGetAt(maxElements) - }) - keyV.get + + cb.while_( + cmp.cne(0), + { (Lcont: CodeLabel) => + (0 until maxElements).foreach { i => + cb.if_( + hasKey(cb, node, i), { + cb.assign(keyV, loadKey(cb, node, i)) + cb.assign(cmp, key.compWithKey(cb, keyV, k)) + cb.if_(cmp.ceq(0), cb.goto(Lcont)) + cb.if_( + cmp > 0, { + insertOrGetAt(i) + cb.goto(Lcont) + }, + ) + }, { + insertOrGetAt(i) + cb.goto(Lcont) + }, + ) + } + insertOrGetAt(maxElements) + }, + ) + keyV.get + } } - cb.invokeCode(get, root, kc) + + cb.invokeCode(get, cb.this_, root, kc) } def init(cb: EmitCodeBuilder): Unit = createNode(cb, root) @@ -292,122 +364,152 @@ class AppendOnlyBTree(kb: EmitClassBuilder[_], val key: BTreeKey, region: Value[ cb += idxStack.update(stackI, -1) } - def stackUpdateIdx(newIdx: Code[Int]) = { + def stackUpdateIdx(newIdx: Code[Int]) = cb += idxStack.update(stackI, newIdx) - } - def stackPop() = { + def stackPop() = cb.assign(stackI, stackI - 1) - } stackPush(root) - cb.while_(stackI >= 0, { (Lstart: CodeLabel) => - val node = cb.newLocal("btree_foreach_node", nodeStack(stackI)) - val idx = cb.newLocal("btree_foreach_idx", idxStack(stackI)) - val Lend = CodeLabel() - val Lminus1 = CodeLabel() - val labels = Array.fill[CodeLabel](maxElements)(CodeLabel()) - // this should probably be a switch, don't know how to make it one though - // furthermore, we should be able to do the lookups at runtime - // FIXME, clean this up once we have fixed arrays - cb.if_(idx.ceq(-1), cb.goto(Lminus1)) - (0 until maxElements).zip(labels).foreach { case (i, l) => - cb.if_(idx.ceq(i), cb.goto(l)) - } - cb.goto(Lend) - - cb.define(Lminus1) - cb.if_(!isLeaf(cb, node), { - stackUpdateIdx(0) - stackPush(loadChild(cb, node, -1)) - cb.goto(Lstart) - }) - (0 until maxElements).foreach { i => - cb.define(labels(i)) - cb.if_(hasKey(cb, node, i), { - visitor(cb, loadKey(cb, node, i)) - cb.if_(!isLeaf(cb, node), { - stackUpdateIdx(i + 1) - stackPush(loadChild(cb, node, i)) + cb.while_( + stackI >= 0, + { (Lstart: CodeLabel) => + val node = cb.newLocal("btree_foreach_node", nodeStack(stackI)) + val idx = cb.newLocal("btree_foreach_idx", idxStack(stackI)) + val Lend = 
CodeLabel() + val Lminus1 = CodeLabel() + val labels = Array.fill[CodeLabel](maxElements)(CodeLabel()) + // this should probably be a switch, don't know how to make it one though + // furthermore, we should be able to do the lookups at runtime + // FIXME, clean this up once we have fixed arrays + cb.if_(idx.ceq(-1), cb.goto(Lminus1)) + (0 until maxElements).zip(labels).foreach { case (i, l) => + cb.if_(idx.ceq(i), cb.goto(l)) + } + cb.goto(Lend) + + cb.define(Lminus1) + cb.if_( + !isLeaf(cb, node), { + stackUpdateIdx(0) + stackPush(loadChild(cb, node, -1)) cb.goto(Lstart) - }) - }, { - cb.goto(Lend) - }) - } + }, + ) + (0 until maxElements).foreach { i => + cb.define(labels(i)) + cb.if_( + hasKey(cb, node, i), { + visitor(cb, loadKey(cb, node, i)) + cb.if_( + !isLeaf(cb, node), { + stackUpdateIdx(i + 1) + stackPush(loadChild(cb, node, i)) + cb.goto(Lstart) + }, + ) + }, + cb.goto(Lend), + ) + } - cb.define(Lend) - stackPop() - }) + cb.define(Lend) + stackPop() + }, + ) } val deepCopy: (EmitCodeBuilder, Value[Long]) => Unit = { - val f = kb.genEmitMethod("btree_deepCopy", FastSeq[ParamType](typeInfo[Long], typeInfo[Long]), typeInfo[Unit]) - f.voidWithBuilder { cb => - val destNode = f.getCodeParam[Long](1) - val srcNode = f.getCodeParam[Long](2) - - val er = EmitRegion(cb.emb, region) - val newNode = cb.newLocal[Long]("new_node") - - def copyChild(i: Int): Unit = { - createNode(cb, newNode) - cb.invokeVoid(cb.emb, newNode, loadChild(cb, srcNode, i)) - } + val f = kb.getOrGenEmitMethod( + "btree_deepCopy", + ("btree_deepCopy", key), + FastSeq[ParamType](typeInfo[Long], typeInfo[Long]), + UnitInfo, + ) { f => + f.voidWithBuilder { cb => + val destNode = f.getCodeParam[Long](1) + val srcNode = f.getCodeParam[Long](2) + + val er = EmitRegion(cb.emb, region) + val newNode = cb.newLocal[Long]("new_node") + + def copyChild(i: Int): Unit = { + createNode(cb, newNode) + cb.invokeVoid(cb.emb, cb.this_, newNode, loadChild(cb, srcNode, i)) + } - cb.if_(!isLeaf(cb, srcNode), { - copyChild(-1) - setChild(cb, destNode, -1, newNode, "deepcopy1") - }) + cb.if_( + !isLeaf(cb, srcNode), { + copyChild(-1) + setChild(cb, destNode, -1, newNode, "deepcopy1") + }, + ) - (0 until maxElements).foreach { i => - cb.if_(hasKey(cb, srcNode, i), { - key.deepCopy(cb, er, destNode, srcNode) - cb.if_(!isLeaf(cb, srcNode), { - copyChild(i) - setChild(cb, destNode, i, newNode, "deepcopy2") - }) - }) + (0 until maxElements).foreach { i => + cb.if_( + hasKey(cb, srcNode, i), { + key.deepCopy(cb, er, destNode, srcNode) + cb.if_( + !isLeaf(cb, srcNode), { + copyChild(i) + setChild(cb, destNode, i, newNode, "deepcopy2") + }, + ) + }, + ) + } } } - { (cb: EmitCodeBuilder, srcRoot: Value[Long]) => cb.invokeVoid(f, root, srcRoot) } + { (cb: EmitCodeBuilder, srcRoot: Value[Long]) => cb.invokeVoid(f, cb.this_, root, srcRoot) } } - def bulkStore(cb: EmitCodeBuilder, obCode: Value[OutputBuffer] - )(keyStore: (EmitCodeBuilder, Value[OutputBuffer], Code[Long]) => Unit): Unit = { - val f = kb.genEmitMethod("btree_bulkStore", FastSeq[ParamType](typeInfo[Long], typeInfo[OutputBuffer]), - typeInfo[Unit]) + def bulkStore( + cb: EmitCodeBuilder, + obCode: Value[OutputBuffer], + )( + keyStore: (EmitCodeBuilder, Value[OutputBuffer], Code[Long]) => Unit + ): Unit = { + val f = kb.genEmitMethod( + "btree_bulkStore", + FastSeq(typeInfo[Long], typeInfo[OutputBuffer]), + typeInfo[Unit], + ) val node = f.getCodeParam[Long](1) val ob = f.getCodeParam[OutputBuffer](2) f.voidWithBuilder { cb => cb += ob.writeBoolean(!isLeaf(cb, node)) - cb.if_(!isLeaf(cb, 
node), { - cb.invokeVoid(f, loadChild(cb, node, -1), ob) - }) + cb.if_(!isLeaf(cb, node), cb.invokeVoid(f, cb.this_, loadChild(cb, node, -1), ob)) val Lexit = CodeLabel() (0 until maxElements).foreach { i => - cb.if_(hasKey(cb, node, i), { - cb += ob.writeBoolean(true) - keyStore(cb, ob, loadKey(cb, node, i)) - cb.if_(!isLeaf(cb, node), { - cb.invokeVoid(f, loadChild(cb, node, i), ob) - }) - }, { - cb += ob.writeBoolean(false) - cb.goto(Lexit) - }) + cb.if_( + hasKey(cb, node, i), { + cb += ob.writeBoolean(true) + keyStore(cb, ob, loadKey(cb, node, i)) + cb.if_(!isLeaf(cb, node), cb.invokeVoid(f, cb.this_, loadChild(cb, node, i), ob)) + }, { + cb += ob.writeBoolean(false) + cb.goto(Lexit) + }, + ) } cb.define(Lexit) } - cb.invokeVoid(f, root, obCode) + cb.invokeVoid(f, cb.this_, root, obCode) } - def bulkLoad(cb: EmitCodeBuilder, ibCode: Value[InputBuffer] - )(keyLoad: (EmitCodeBuilder, Value[InputBuffer], Code[Long]) => Unit): Unit = { - val f = kb.genEmitMethod("btree_bulkLoad", FastSeq[ParamType](typeInfo[Long], typeInfo[InputBuffer]), - typeInfo[Unit]) + def bulkLoad( + cb: EmitCodeBuilder, + ibCode: Value[InputBuffer], + )( + keyLoad: (EmitCodeBuilder, Value[InputBuffer], Code[Long]) => Unit + ): Unit = { + val f = kb.genEmitMethod( + "btree_bulkLoad", + FastSeq[ParamType](typeInfo[Long], typeInfo[InputBuffer]), + typeInfo[Unit], + ) val node = f.getCodeParam[Long](1) val ib = f.getCodeParam[InputBuffer](2) val newNode = f.newLocal[Long]() @@ -415,27 +517,32 @@ class AppendOnlyBTree(kb: EmitClassBuilder[_], val key: BTreeKey, region: Value[ f.voidWithBuilder { cb => cb.assign(isInternalNode, ib.readBoolean()) - cb.if_(isInternalNode, { - createNode(cb, newNode) - setChild(cb, node, -1, newNode, "bulkLoad1") - cb.invokeVoid(f, newNode, ib) - }) + cb.if_( + isInternalNode, { + createNode(cb, newNode) + setChild(cb, node, -1, newNode, "bulkLoad1") + cb.invokeVoid(f, cb.this_, newNode, ib) + }, + ) val Lexit = CodeLabel() (0 until maxElements).foreach { i => - cb.if_(ib.readBoolean(), { - setKeyPresent(cb, node, i) - keyLoad(cb, ib, keyOffset(node, i)) - cb.if_(isInternalNode, { - createNode(cb, newNode) - setChild(cb, node, i, newNode, "bulkLoad2") - cb.invokeVoid(f, newNode, ib) - }) - }, { - cb.goto(Lexit) - }) + cb.if_( + ib.readBoolean(), { + setKeyPresent(cb, node, i) + keyLoad(cb, ib, keyOffset(node, i)) + cb.if_( + isInternalNode, { + createNode(cb, newNode) + setChild(cb, node, i, newNode, "bulkLoad2") + cb.invokeVoid(f, cb.this_, newNode, ib) + }, + ) + }, + cb.goto(Lexit), + ) } cb.define(Lexit) } - cb.invokeVoid(f, root, ibCode) + cb.invokeVoid(f, cb.this_, root, ibCode) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/ApproxCDFAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/ApproxCDFAggregator.scala index aeb318fa7db..9d68f82e47b 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/ApproxCDFAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/ApproxCDFAggregator.scala @@ -3,11 +3,11 @@ package is.hail.expr.ir.agg import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext -import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitContext, IEmitCode} +import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, IEmitCode} import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer} -import is.hail.types.physical.stypes.concrete.{SBaseStructPointer, SBaseStructPointerValue} import is.hail.types.physical._ import is.hail.types.physical.stypes.EmitType +import 
is.hail.types.physical.stypes.concrete.{SBaseStructPointer, SBaseStructPointerValue} import is.hail.types.virtual.{TFloat64, TInt32, Type} import is.hail.utils._ @@ -17,7 +17,13 @@ class ApproxCDFState(val kb: EmitClassBuilder[_]) extends AggregatorState { private val r: Settable[Region] = kb.genFieldThisRef[Region]() val region: Value[Region] = r - val storageType: PStruct = PCanonicalStruct(true, ("id", PInt32Required), ("initialized", PBooleanRequired), ("k", PInt32Required)) + val storageType: PStruct = PCanonicalStruct( + true, + ("id", PInt32Required), + ("initialized", PBooleanRequired), + ("k", PInt32Required), + ) + private val aggr = kb.genFieldThisRef[ApproxCDFStateManager]("aggr") private val initialized = kb.genFieldThisRef[Boolean]("initialized") @@ -30,49 +36,66 @@ class ApproxCDFState(val kb: EmitClassBuilder[_]) extends AggregatorState { private val kOffset: Code[Long] => Code[Long] = storageType.loadField(_, "k") def init(cb: EmitCodeBuilder, k: Code[Int]): Unit = { - cb.assign(this.k, k) - cb.assign(aggr, Code.invokeScalaObject1[Int, ApproxCDFStateManager](ApproxCDFStateManager.getClass, "apply", this.k)) - cb.assign(id, region.storeJavaObject(aggr)) - cb.assign(this.initialized, true) + cb.assign(this.k, k) + cb.assign( + aggr, + Code.invokeScalaObject1[Int, ApproxCDFStateManager]( + ApproxCDFStateManager.getClass, + "apply", + this.k, + ), + ) + cb.assign(id, region.storeJavaObject(aggr)) + cb.assign(this.initialized, true) } - def seq(cb: EmitCodeBuilder, x: Code[Double]): Unit = { + def seq(cb: EmitCodeBuilder, x: Code[Double]): Unit = cb += aggr.invoke[Double, Unit]("seqOp", x) - } - def comb(cb: EmitCodeBuilder, other: ApproxCDFState): Unit = { + def comb(cb: EmitCodeBuilder, other: ApproxCDFState): Unit = cb += aggr.invoke[ApproxCDFStateManager, Unit]("combOp", other.aggr) - } - def result(cb: EmitCodeBuilder, region: Value[Region]): SBaseStructPointerValue = { - QuantilesAggregator.resultPType.loadCheapSCode(cb, aggr.invoke[Region, Long]("rvResult", region)) - } + def result(cb: EmitCodeBuilder, region: Value[Region]): SBaseStructPointerValue = + QuantilesAggregator.resultPType.loadCheapSCode( + cb, + aggr.invoke[Region, Long]("rvResult", region), + ) def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = cb += region.getNewRegion(regionSize) def createState(cb: EmitCodeBuilder): Unit = cb.if_(region.isNull, cb.assign(r, Region.stagedCreate(regionSize, kb.pool()))) - override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = { regionLoader(cb, r) cb.assign(id, Region.loadInt(idOffset(src))) cb.assign(initialized, Region.loadBoolean(initializedOffset(src))) - cb.if_(initialized, - { + cb.if_( + initialized, { cb.assign(aggr, Code.checkcast[ApproxCDFStateManager](region.lookupJavaObject(id))) cb.assign(k, Region.loadInt(kOffset(src))) - }) + }, + ) } - override def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = { - cb.if_(region.isValid, - { + override def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = { + cb.if_( + region.isValid, { regionStorer(cb, region) cb += region.invalidate() cb += Region.storeInt(idOffset(dest), id) cb += Region.storeInt(kOffset(dest), k) cb += Region.storeBoolean(initializedOffset(dest), initialized) - }) 
+ }, + ) } override def serialize(codec: BufferSpec): (EmitCodeBuilder, Value[OutputBuffer]) => Unit = { @@ -86,18 +109,33 @@ class ApproxCDFState(val kb: EmitClassBuilder[_]) extends AggregatorState { (cb, ib: Value[InputBuffer]) => cb.assign(initialized, ib.readBoolean()) cb.assign(k, ib.readInt()) - cb.if_(initialized, { - cb.assign(aggr, Code.invokeScalaObject2[Int, InputBuffer, ApproxCDFStateManager]( - ApproxCDFStateManager.getClass, "deserializeFrom", k, ib) - ) - - cb.assign(id, region.storeJavaObject(aggr)) - }) + cb.if_( + initialized, { + cb.assign( + aggr, + Code.invokeScalaObject2[Int, InputBuffer, ApproxCDFStateManager]( + ApproxCDFStateManager.getClass, + "deserializeFrom", + k, + ib, + ), + ) + + cb.assign(id, region.storeJavaObject(aggr)) + }, + ) } override def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit = { cb.assign(k, Region.loadInt(kOffset(src))) - cb.assign(aggr, Code.invokeScalaObject1[Int, ApproxCDFStateManager](ApproxCDFStateManager.getClass, "apply", this.k)) + cb.assign( + aggr, + Code.invokeScalaObject1[Int, ApproxCDFStateManager]( + ApproxCDFStateManager.getClass, + "apply", + this.k, + ), + ) cb.assign(id, region.storeJavaObject(aggr)) cb.assign(this.initialized, true) } @@ -113,25 +151,32 @@ class ApproxCDFAggregator extends StagedAggregator { protected def _initOp(cb: EmitCodeBuilder, state: State, init: Array[EmitCode]): Unit = { val Array(k) = init k.toI(cb) - .consume(cb, + .consume( + cb, cb += Code._fatal[Unit]("approx_cdf: 'k' may not be missing"), - pv => state.init(cb, pv.asInt.value)) + pv => state.init(cb, pv.asInt.value), + ) } protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = { val Array(x) = seq x.toI(cb) - .consume(cb, + .consume( + cb, {}, - pv => state.seq(cb, pv.asDouble.value) + pv => state.seq(cb, pv.asDouble.value), ) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: ApproxCDFState, other: ApproxCDFState): Unit = { + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: ApproxCDFState, + other: ApproxCDFState, + ): Unit = state.comb(cb, other) - } - protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { + protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = IEmitCode.present(cb, state.result(cb, region)) - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/ApproxCDFStateManager.scala b/hail/src/main/scala/is/hail/expr/ir/agg/ApproxCDFStateManager.scala index 2b80f86036d..152b1bb99d1 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/ApproxCDFStateManager.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/ApproxCDFStateManager.scala @@ -11,9 +11,14 @@ object ApproxCDFHelper { def sort(a: Array[Double], begin: Int, end: Int): Unit = java.util.Arrays.sort(a, begin, end) def merge( - left: Array[Double], lStart: Int, lEnd: Int, - right: Array[Double], rStart: Int, rEnd: Int, - out: Array[Double], outStart: Int + left: Array[Double], + lStart: Int, + lEnd: Int, + right: Array[Double], + rStart: Int, + rEnd: Int, + out: Array[Double], + outStart: Int, ): Unit = { assert((left ne out) || (outStart <= lStart - (rEnd - rStart)) || (outStart >= lEnd)) assert((right ne out) || (outStart <= rStart - (lEnd - lStart)) || (outStart >= rEnd)) @@ -73,9 +78,12 @@ object ApproxCDFHelper { } def compactBuffer( - buf: Array[Double], inStart: Int, inEnd: Int, - out: Array[Double], outStart: Int, - skipFirst: Boolean + buf: Array[Double], + inStart: Int, 
+ inEnd: Int, + out: Array[Double], + outStart: Int, + skipFirst: Boolean, ): Unit = { assert((buf ne out) || (outStart <= inStart) || (outStart >= inEnd)) var i = inStart @@ -91,9 +99,12 @@ object ApproxCDFHelper { } def compactBufferBackwards( - buf: Array[Double], inStart: Int, inEnd: Int, - out: Array[Double], outEnd: Int, - skipFirst: Boolean + buf: Array[Double], + inStart: Int, + inEnd: Int, + out: Array[Double], + outEnd: Int, + skipFirst: Boolean, ): Unit = { assert((buf ne out) || (outEnd <= inStart) || (outEnd >= inEnd)) var i = inEnd - 1 @@ -110,12 +121,14 @@ object ApproxCDFHelper { } object ApproxCDFCombiner { - def apply(numLevels: Int, capacity: Int, rand: java.util.Random): ApproxCDFCombiner = new ApproxCDFCombiner( - { val a = Array.ofDim[Int](numLevels + 1); java.util.Arrays.fill(a, capacity); a }, - Array.ofDim[Double](capacity), - Array.ofDim[Int](numLevels), - 1, - rand) + def apply(numLevels: Int, capacity: Int, rand: java.util.Random): ApproxCDFCombiner = + new ApproxCDFCombiner( + { val a = Array.ofDim[Int](numLevels + 1); java.util.Arrays.fill(a, capacity); a }, + Array.ofDim[Double](capacity), + Array.ofDim[Int](numLevels), + 1, + rand, + ) def apply(numLevels: Int, capacity: Int): ApproxCDFCombiner = apply(numLevels, capacity, new java.util.Random()) @@ -147,21 +160,19 @@ object ApproxCDFCombiner { /* Keep a collection of values, grouped into levels. * * Invariants: - * - `items` stores all levels contiguously. Each level above 0 is - * always sorted in non-decreasing order. - * - `levels` tracks the boundaries of the levels stored in `items`. It is - * always non-decreasing, and `levels(numLevels)` always equals `items.length`. - * The values in level i occupy indices from `levels(i)` (inclusive) to - * `levels(i+1)` (exclusive). - * - `numLevels` is the number of levels currently held. The top level is - * never empty, so this is also the greatest nonempty level. - */ + * - `items` stores all levels contiguously. Each level above 0 is always sorted in non-decreasing + * order. + * - `levels` tracks the boundaries of the levels stored in `items`. It is always non-decreasing, + * and `levels(numLevels)` always equals `items.length`. + * The values in level i occupy indices from `levels(i)` (inclusive) to `levels(i+1)` (exclusive). + * - `numLevels` is the number of levels currently held. The top level is never empty, so this is + * also the greatest nonempty level. */ class ApproxCDFCombiner( val levels: Array[Int], val items: Array[Double], val compactionCounts: Array[Int], var numLevels: Int, - val rand: java.util.Random + val rand: java.util.Random, ) extends Serializable { def serializeTo(ob: OutputBuffer): Unit = { @@ -215,7 +226,7 @@ class ApproxCDFCombiner( def safeLevelSize(level: Int): Int = if (level >= maxNumLevels) 0 else levels(level + 1) - levels(level) - def push(t: Double) { + def push(t: Double): Unit = { val bot = levels(0) val newBot = bot - 1 @@ -244,7 +255,7 @@ class ApproxCDFCombiner( new ApproxCDFCombiner(newLevels, newItems, newCompactionCounts, numLevels, rand) } - def clear() { + def clear(): Unit = { numLevels = 1 var i = 0 while (i < levels.length) { @@ -253,13 +264,12 @@ class ApproxCDFCombiner( } } - /* Compact level `level`, merging the compacted results into level `level+1`, - * keeping the 'keep' smallest and 'keep' largest values at 'level'. If - * 'shiftLowerLevels' is true, shift lower levels up to keep items contiguous. 
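[Editor's note — illustration only, not part of this patch.] The `levels`/`items` invariants documented for ApproxCDFCombiner above are easier to see with a concrete layout; the following is a minimal sketch with made-up values and a hypothetical helper, not Hail's API.

// Suppose items.length == 7, numLevels == 2 and levels == Array(2, 5, 7):
//   indices 0..1 -> free space at the bottom of the buffer
//   indices 2..4 -> level 0 (need not be sorted)
//   indices 5..6 -> level 1 (sorted; each sample stands for 2^1 stream items)
// levels(numLevels) == items.length, so level i occupies [levels(i), levels(i + 1)).
def levelSize(levels: Array[Int], level: Int): Int = levels(level + 1) - levels(level)

val levels = Array(2, 5, 7)
assert(levelSize(levels, 0) == 3 && levelSize(levels, 1) == 2)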
+ /* Compact level `level`, merging the compacted results into level `level+1`, keeping the 'keep' + * smallest and 'keep' largest values at 'level'. If 'shiftLowerLevels' is true, shift lower + * levels up to keep items contiguous. * - * Returns the new end of 'level'. If 'shiftLowerLevels', this is always - * equal to 'levels(level + 1)`. - */ + * Returns the new end of 'level'. If 'shiftLowerLevels', this is always equal to 'levels(level + + * 1)`. */ def compactLevel(level: Int, shiftLowerLevels: Boolean = true): Int = { val keep = if (level == 0) 1 else 0 @@ -354,9 +364,15 @@ class ApproxCDFCombiner( if (selfPop > 0 && otherPop > 0) ApproxCDFHelper.merge( - items, levels(lvl), levels(lvl + 1), - other.items, other.levels(lvl), other.levels(lvl + 1), - mergedItems, mergedLevels(lvl)) + items, + levels(lvl), + levels(lvl + 1), + other.items, + other.levels(lvl), + other.levels(lvl + 1), + mergedItems, + mergedLevels(lvl), + ) else if (selfPop > 0) System.arraycopy(items, levels(lvl), mergedItems, mergedLevels(lvl), selfPop) else if (otherPop > 0) @@ -382,10 +398,11 @@ class ApproxCDFCombiner( mergedItems, mergedCompactionCounts, math.max(numLevels, other.numLevels), - rand) + rand, + ) } - def generalCompact(minCapacity: Int, levelCapacity: (Int, Int) => Int) { + def generalCompact(minCapacity: Int, levelCapacity: (Int, Int) => Int): Unit = { var currentItemCount = levels(numLevels) - levels(0) // decreases with each compaction var targetItemCount = { // increases if we add levels var lvl = 0 @@ -430,7 +447,7 @@ class ApproxCDFCombiner( } } - def copyFrom(other: ApproxCDFCombiner) { + def copyFrom(other: ApproxCDFCombiner): Unit = { assert(capacity >= other.size) assert(maxNumLevels >= other.numLevels) @@ -495,7 +512,8 @@ object ApproxCDFStateManager { val initLevelsCapacity: Int = QuantilesAggregator.findInitialLevelsCapacity(k, m) val combiner: ApproxCDFCombiner = ApproxCDFCombiner( initLevelsCapacity, - QuantilesAggregator.computeTotalCapacity(initLevelsCapacity, k, m)) + QuantilesAggregator.computeTotalCapacity(initLevelsCapacity, k, m), + ) new ApproxCDFStateManager(k, combiner) } @@ -505,29 +523,32 @@ object ApproxCDFStateManager { a } - def fromData(k: Int, levels: Array[Int], items: Array[Double], compactionCounts: Array[Int]): ApproxCDFStateManager = { + def fromData(k: Int, levels: Array[Int], items: Array[Double], compactionCounts: Array[Int]) + : ApproxCDFStateManager = { val combiner: ApproxCDFCombiner = new ApproxCDFCombiner( - levels, items, compactionCounts, levels.length - 1, new java.util.Random) + levels, + items, + compactionCounts, + levels.length - 1, + new java.util.Random, + ) new ApproxCDFStateManager(k, combiner) } } /* Compute an approximation to the sorted sequence of values seen. * - * Let `n` be the number of non-missing values seen, and let `m` and `M` be - * respectively the minimum and maximum values seen. The result of the - * aggregator is an array "values" of samples, in increasing order, and an array - * "ranks" of integers less than `n`, in increasing order, such that: + * Let `n` be the number of non-missing values seen, and let `m` and `M` be respectively the minimum + * and maximum values seen. 
The result of the aggregator is an array "values" of samples, in + * increasing order, and an array "ranks" of integers less than `n`, in increasing order, such that: * - ranks.length = values.length + 1 * - ranks(0) = 0 * - ranks(values.length) = n * - values(0) = m - * - values(values.length - 1) = M - * These represent a summary of the sorted list of values seen by the - * aggregator. For example, values=[0,2,5,6,9] and ranks=[0,3,4,5,8,10] - * represents the approximation [0,0,0,2,5,6,6,6,9,9], with the value - * `values(i)` occupying indices `ranks(i)` to `ranks(i+1)` (again half-open). - */ + * - values(values.length - 1) = M These represent a summary of the sorted list of values seen by + * the aggregator. For example, values=[0,2,5,6,9] and ranks=[0,3,4,5,8,10] represents the + * approximation [0,0,0,2,5,6,6,6,9,9], with the value `values(i)` occupying indices `ranks(i)` to + * `ranks(i+1)` (again half-open). */ class ApproxCDFStateManager(val k: Int, var combiner: ApproxCDFCombiner) { val m: Int = ApproxCDFStateManager.defaultM private val growthRate: Int = 4 @@ -536,41 +557,35 @@ class ApproxCDFStateManager(val k: Int, var combiner: ApproxCDFCombiner) { /* The sketch maintains a sample of items seen, organized into levels. * - * Samples in level i represent 2^i items from the original stream. Whenever - * `items` fills up, we make room by "compacting" a full level. Compacting - * means sorting (if the level wasn't already sorted), throwing away every - * other sample (taking the evens or the odds with equal probability), and - * adding the remaining samples to the level above (where now each kept sample + * Samples in level i represent 2^i items from the original stream. Whenever `items` fills up, we + * make room by "compacting" a full level. Compacting means sorting (if the level wasn't already + * sorted), throwing away every other sample (taking the evens or the odds with equal + * probability), and adding the remaining samples to the level above (where now each kept sample * represents twice as many items). * - * Let `levelCapacity(i)`=k*(2/3)^(numLevels-i). A compaction operation at - * level i is correct if the level contains at least `levelCapacity(i)` - * samples at the time of compaction. As long as this holds, the analysis from - * the paper [KLL] applies. This leaves room for several compaction - * strategies, of which we implement two, with the `eager` flag choosing - * between them. + * Let `levelCapacity(i)`=k*(2/3)^(numLevels-i). A compaction operation at level i is correct if + * the level contains at least `levelCapacity(i)` samples at the time of compaction. As long as + * this holds, the analysis from the paper [KLL] applies. This leaves room for several compaction + * strategies, of which we implement two, with the `eager` flag choosing between them. * - * To keep things simple, we require that any level contains a minimum of m - * samples at the time of compaction, where `m` is a class parameter, m>=2, - * controlling the minimum size of a compaction. Because of this minimum size, - * we must (very slowly) grow the `items` buffer over time. + * To keep things simple, we require that any level contains a minimum of m samples at the time of + * compaction, where `m` is a class parameter, m>=2, controlling the minimum size of a compaction. + * Because of this minimum size, we must (very slowly) grow the `items` buffer over time. * - * To maintain the correct total weight, we only compact even numbers of - * samples. 
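[Editor's note — illustration only, not part of this patch.] The (values, ranks) result described in the ApproxCDFStateManager doc comment above expands mechanically into the approximate sorted sequence it summarizes; a minimal sketch using the comment's own example arrays:

// values(i) occupies the half-open index range [ranks(i), ranks(i + 1)).
val values = Array(0.0, 2.0, 5.0, 6.0, 9.0)
val ranks = Array(0, 3, 4, 5, 8, 10)
val approx = values.indices.flatMap(i => Seq.fill(ranks(i + 1) - ranks(i))(values(i)))
assert(approx == Seq(0.0, 0.0, 0.0, 2.0, 5.0, 6.0, 6.0, 6.0, 9.0, 9.0))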
If a level contains an odd number of samples when compacting, - * we leave one sample at the lower level. + * To maintain the correct total weight, we only compact even numbers of samples. If a level + * contains an odd number of samples when compacting, we leave one sample at the lower level. * * Invariants: * - `n` is the number of items seen. - * - `levelsCapacity` is the number of levels `items` and `levels` have room - * for before we need to reallocate. - * - `numLevels` is the number of levels currently held. The top level is - * never empty, so this is also the greatest nonempty level. - * - `items.length` is always at least the sum of all level capacities up to - * `numLevels`. Thus if `items` is full, at least one level must be full. + * - `levelsCapacity` is the number of levels `items` and `levels` have room for before we need to + * reallocate. + * - `numLevels` is the number of levels currently held. The top level is never empty, so this is + * also the greatest nonempty level. + * - `items.length` is always at least the sum of all level capacities up to `numLevels`. Thus if + * `items` is full, at least one level must be full. * * [KLL] "Optimal Quantile Approximation in Streams", Karnin, Lang, and Liberty - * https://github.com/DataSketches/sketches-core/tree/master/src/main/java/com/yahoo/sketches/kll - */ + * https://github.com/DataSketches/sketches-core/tree/master/src/main/java/com/yahoo/sketches/kll */ def levels: Array[Int] = combiner.levels @@ -597,7 +612,7 @@ class ApproxCDFStateManager(val k: Int, var combiner: ApproxCDFCombiner) { combiner.push(x) } - def combOp(other: ApproxCDFStateManager) { + def combOp(other: ApproxCDFStateManager): Unit = { assert(m == other.m) if (other.numLevels == 1) { var i = other.levels(0) @@ -651,15 +666,13 @@ class ApproxCDFStateManager(val k: Int, var combiner: ApproxCDFCombiner) { rvb.end() } - def clear() { + def clear(): Unit = combiner.clear() - } private def findFullLevel(): Int = { var level: Int = 0 - while (levels(level + 1) - levels(level) < levelCapacity(level)) { + while (levels(level + 1) - levels(level) < levelCapacity(level)) level += 1 - } level } @@ -668,10 +681,8 @@ class ApproxCDFStateManager(val k: Int, var combiner: ApproxCDFCombiner) { if (depth < capacities.length) capacities(depth) else m } - /* Compact the first over-capacity level. If that is the top level, grow the - * sketch. - */ - private def compact() { + /* Compact the first over-capacity level. If that is the top level, grow the sketch. */ + private def compact(): Unit = { assert(combiner.isFull) val level = findFullLevel() if (level == numLevels - 1) growSketch() @@ -679,11 +690,10 @@ class ApproxCDFStateManager(val k: Int, var combiner: ApproxCDFCombiner) { combiner.compactLevel(level) } - /* If we are following the eager compacting strategy, level 0 must be full - * when starting a compaction. This strategy sacrifices some accuracy, but - * avoids having to shift up items below the compacted level. - */ - private def compactEager() { + /* If we are following the eager compacting strategy, level 0 must be full when starting a + * compaction. This strategy sacrifices some accuracy, but avoids having to shift up items below + * the compacted level. 
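/* Illustrative sketch (not the implementation in this file): one KLL compaction step as
 * described above, written over small standalone arrays. The helper names are hypothetical;
 * the real code compacts in place inside the shared `items`/`levels` buffers. */
object KllCompactionExample {
  /** Sort the level, leave one sample behind if its size is odd, then keep either the even- or
    * odd-indexed samples with equal probability. The survivors merge into level + 1, where each
    * now represents twice as many stream items. */
  def compactLevel(level: Array[Double], rng: java.util.Random): (Array[Double], Array[Double]) = {
    val sorted = level.sorted
    val leftover = if (sorted.length % 2 == 1) sorted.take(1) else Array.empty[Double]
    val toCompact = sorted.drop(leftover.length)
    val offset = if (rng.nextBoolean()) 0 else 1
    val promoted = Array.tabulate(toCompact.length / 2)(i => toCompact(2 * i + offset))
    (leftover, promoted)
  }

  /** The nominal capacity k * (2/3)^(numLevels - lvl) stated above, floored at the minimum
    * compaction size m (the real code precomputes these values). */
  def levelCapacity(lvl: Int, numLevels: Int, k: Int, m: Int): Int =
    math.max(m, math.ceil(k * math.pow(2.0 / 3.0, (numLevels - lvl).toDouble)).toInt)
}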
*/ + private def compactEager(): Unit = { assert(combiner.levelSize(0) >= levelCapacity(0)) var level = 0 @@ -703,14 +713,14 @@ class ApproxCDFStateManager(val k: Int, var combiner: ApproxCDFCombiner) { } while (levels(level) < desiredFreeCapacity && !grew) } - private def growSketch() { + private def growSketch(): Unit = if (combiner.numLevels == combiner.maxNumLevels) combiner = combiner.grow( combiner.maxNumLevels + growthRate, - combiner.capacity + m * growthRate) - } + combiner.capacity + m * growthRate, + ) - private def merge(other: ApproxCDFStateManager) { + private def merge(other: ApproxCDFStateManager): Unit = { val finalN = n + other.n val ub = QuantilesAggregator.ubOnNumLevels(finalN) @@ -727,17 +737,18 @@ class ApproxCDFStateManager(val k: Int, var combiner: ApproxCDFCombiner) { private def computeTotalCapacity(numLevels: Int): Int = QuantilesAggregator.computeTotalCapacity(numLevels, k, m) - def serializeTo(ob: OutputBuffer): Unit = { + def serializeTo(ob: OutputBuffer): Unit = combiner.serializeTo(ob) - } } object QuantilesAggregator { val resultPType: PCanonicalStruct = - PCanonicalStruct(required = false, + PCanonicalStruct( + required = false, "levels" -> PCanonicalArray(PInt32(true), required = true), "items" -> PCanonicalArray(PFloat64(true), required = true), - "_compaction_counts" -> PCanonicalArray(PInt32(true), required = true)) + "_compaction_counts" -> PCanonicalArray(PInt32(true), required = true), + ) def floorOfLog2OfFraction(numer: Long, denom: Long): Int = { var count = 0 diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/ArrayElementLengthCheckAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/ArrayElementLengthCheckAggregator.scala index 67e01ea4163..a155caa383b 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/ArrayElementLengthCheckAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/ArrayElementLengthCheckAggregator.scala @@ -7,14 +7,15 @@ import is.hail.expr.ir._ import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer} import is.hail.types.physical._ import is.hail.types.physical.stypes.{EmitType, SValue} -import is.hail.types.physical.stypes.concrete.{SBaseStructPointer, SIndexablePointer} +import is.hail.types.physical.stypes.concrete.SIndexablePointer import is.hail.types.virtual.{TInt32, TVoid, Type} import is.hail.utils._ // initOp args: initOps for nestedAgg, length if knownLength = true // seqOp args: array, other non-elt args for nestedAgg -class ArrayElementState(val kb: EmitClassBuilder[_], val nested: StateTuple) extends PointerBasedRVAState { +class ArrayElementState(val kb: EmitClassBuilder[_], val nested: StateTuple) + extends PointerBasedRVAState { val arrayType: PArray = PCanonicalArray(nested.storageType) private val nStates: Int = nested.nStates override val regionSize: Int = Region.SMALL @@ -33,23 +34,36 @@ class ArrayElementState(val kb: EmitClassBuilder[_], val nested: StateTuple) ext def get: Code[Long] = arrayType.loadElement(typ.loadField(off, 1), eltIdx) } - val initContainer: TupleAggregatorState = new TupleAggregatorState(kb, nested, region, new Value[Long] { - def get: Code[Long] = typ.loadField(off, 0) - }) - val container: TupleAggregatorState = new TupleAggregatorState(kb, nested, region, statesOffset(idx), regionOffset(idx)) + val initContainer: TupleAggregatorState = new TupleAggregatorState( + kb, + nested, + region, + new Value[Long] { + def get: Code[Long] = typ.loadField(off, 0) + }, + ) + + val container: TupleAggregatorState = + new TupleAggregatorState(kb, nested, region, 
statesOffset(idx), regionOffset(idx)) override def createState(cb: EmitCodeBuilder): Unit = { super.createState(cb) nested.createStates(cb) } - override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = { super.load(cb, regionLoader, src) - cb.if_(off.cne(0L), - { - cb.assign(lenRef, typ.isFieldMissing(cb, off, 1).mux(-1, - arrayType.loadLength(typ.loadField(off, 1)))) - }) + cb.if_( + off.cne(0L), + cb.assign( + lenRef, + typ.isFieldMissing(cb, off, 1).mux(-1, arrayType.loadLength(typ.loadField(off, 1))), + ), + ) } def initArray(cb: EmitCodeBuilder): Unit = { @@ -63,28 +77,29 @@ class ArrayElementState(val kb: EmitClassBuilder[_], val nested: StateTuple) ext def seq(cb: EmitCodeBuilder, init: => Unit, initPerElt: => Unit, seqOp: => Unit): Unit = { init cb.assign(idx, 0) - cb.while_(idx < lenRef, { - initPerElt - seqOp - store(cb) - cb.assign(idx, idx + 1) - }) + cb.while_( + idx < lenRef, { + initPerElt + seqOp + store(cb) + cb.assign(idx, idx + 1) + }, + ) } - def seq(cb: EmitCodeBuilder, seqOp: => Unit): Unit = - seq(cb, { - initArray(cb) - }, container.newState(cb), seqOp) + seq(cb, initArray(cb), container.newState(cb), seqOp) def initLength(cb: EmitCodeBuilder, len: Code[Int]): Unit = { cb.assign(lenRef, len) seq(cb, container.copyFrom(cb, initContainer.off)) } - def checkLength(cb: EmitCodeBuilder, len: Code[Int]): Unit = { - cb.if_(lenRef.cne(len), cb += Code._fatal[Unit]("mismatched lengths in ArrayElementsAggregator")) - } + def checkLength(cb: EmitCodeBuilder, len: Code[Int]): Unit = + cb.if_( + lenRef.cne(len), + cb += Code._fatal[Unit]("mismatched lengths in ArrayElementsAggregator"), + ) def init(cb: EmitCodeBuilder, initOp: (EmitCodeBuilder) => Unit, initLen: Boolean): Unit = { cb += region.setNumParents(nStates) @@ -107,61 +122,65 @@ class ArrayElementState(val kb: EmitClassBuilder[_], val nested: StateTuple) ext val serializers = nested.states.map(_.serialize(codec)); { (cb: EmitCodeBuilder, ob: Value[OutputBuffer]) => loadInit(cb) - nested.toCodeWithArgs(cb, - { (cb, i, _) => - serializers(i)(cb, ob) - }) + nested.toCodeWithArgs(cb, (cb, i, _) => serializers(i)(cb, ob)) cb += ob.writeInt(lenRef) cb.assign(idx, 0) - cb.while_(idx < lenRef, { - load(cb) - nested.toCodeWithArgs(cb, - { case (cb, i, _) => - serializers(i)(cb, ob) - }) - cb.assign(idx, idx + 1) - }) + cb.while_( + idx < lenRef, { + load(cb) + nested.toCodeWithArgs( + cb, + { case (cb, i, _) => + serializers(i)(cb, ob) + }, + ) + cb.assign(idx, idx + 1) + }, + ) } } def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = { val deserializers = nested.states.map(_.deserialize(codec)); { (cb: EmitCodeBuilder, ib: Value[InputBuffer]) => - init(cb, cb => nested.toCodeWithArgs(cb, - { (cb, i, _) => - deserializers(i)(cb, ib) - }), - initLen = false) + init( + cb, + cb => nested.toCodeWithArgs(cb, (cb, i, _) => deserializers(i)(cb, ib)), + initLen = false, + ) cb.assign(lenRef, ib.readInt()) - cb.if_(lenRef < 0, { - typ.setFieldMissing(cb, off, 1) - }, { - seq(cb, { - nested.toCodeWithArgs(cb, - { (cb, i, _) => - deserializers(i)(cb, ib) - }) - }) - }) + cb.if_( + lenRef < 0, + typ.setFieldMissing(cb, off, 1), + seq(cb, nested.toCodeWithArgs(cb, (cb, i, _) => deserializers(i)(cb, ib))), + ) } } def copyFromAddress(cb: EmitCodeBuilder, src: Value[Long]): Unit = { init(cb, cb => 
initContainer.copyFrom(cb, cb.memoize(typ.loadField(src, 0))), initLen = false) - cb.if_(typ.isFieldMissing(cb, src, 1), { - typ.setFieldMissing(cb, off, 1) - cb.assign(lenRef, -1) - }, { - cb.assign(lenRef, arrayType.loadLength(typ.loadField(src, 1))) - seq(cb, container.copyFrom(cb, cb.memoize(arrayType.loadElement(typ.loadField(src, 1), idx)))) - }) + cb.if_( + typ.isFieldMissing(cb, src, 1), { + typ.setFieldMissing(cb, off, 1) + cb.assign(lenRef, -1) + }, { + cb.assign(lenRef, arrayType.loadLength(typ.loadField(src, 1))) + seq( + cb, + container.copyFrom(cb, cb.memoize(arrayType.loadElement(typ.loadField(src, 1), idx))), + ) + }, + ) } } -class ArrayElementLengthCheckAggregator(nestedAggs: Array[StagedAggregator], knownLength: Boolean) extends StagedAggregator { +class ArrayElementLengthCheckAggregator(nestedAggs: Array[StagedAggregator], knownLength: Boolean) + extends StagedAggregator { type State = ArrayElementState - val resultEltType: PCanonicalTuple = PCanonicalTuple(true, nestedAggs.map(_.resultEmitType.storageType): _*) + val resultEltType: PCanonicalTuple = + PCanonicalTuple(true, nestedAggs.map(_.resultEmitType.storageType): _*) + val resultPType: PCanonicalArray = PCanonicalArray(resultEltType) override def resultEmitType = EmitType(SIndexablePointer(resultPType), knownLength) @@ -173,8 +192,11 @@ class ArrayElementLengthCheckAggregator(nestedAggs: Array[StagedAggregator], kno if (knownLength) { val Array(len, inits) = init state.init(cb, cb => cb += inits.asVoid, initLen = false) - len.toI(cb).consume(cb, cb._fatal("Array length can't be missing"), - len => state.initLength(cb, len.asInt32.value)) + len.toI(cb).consume( + cb, + cb._fatal("Array length can't be missing"), + len => state.initLength(cb, len.asInt32.value), + ) } else { val Array(inits) = init state.init(cb, cb => cb += inits.asVoid, initLen = true) @@ -186,45 +208,55 @@ class ArrayElementLengthCheckAggregator(nestedAggs: Array[StagedAggregator], kno protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = { assert(seq.length == 1) val len = seq.head - len.toI(cb).consume(cb, { - /* do nothing */ - }, { len => - if (!knownLength) { - val v = cb.newLocal("aelca_seqop_len", len.asInt.value) - cb.if_(state.lenRef < 0, state.initLength(cb, v), state.checkLength(cb, v)) - } else { - state.checkLength(cb, len.asInt.value) - } - }) - } - - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: ArrayElementState, other: ArrayElementState): Unit = { - state.seq(cb, { - cb.if_(other.lenRef < 0, { - cb.if_(state.lenRef >= 0, { - other.initLength(cb, state.lenRef) - }) - }, { + len.toI(cb).consume( + cb, { + /* do nothing */ + }, + { len => if (!knownLength) { - cb.if_(state.lenRef < 0, { - state.initLength(cb, other.lenRef) - }, { - state.checkLength(cb, other.lenRef) - }) + val v = cb.newLocal("aelca_seqop_len", len.asInt.value) + cb.if_(state.lenRef < 0, state.initLength(cb, v), state.checkLength(cb, v)) } else { - state.checkLength(cb, other.lenRef) + state.checkLength(cb, len.asInt.value) } - }) - }, { - cb.assign(other.idx, state.idx) - other.load(cb) - state.load(cb) - }, { - state.nested.toCode((i, s) => nestedAggs(i).combOp(ctx, cb, s, other.nested(i))) - }) + }, + ) + } + + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: ArrayElementState, + other: ArrayElementState, + ): Unit = { + state.seq( + cb, { + cb.if_( + other.lenRef < 0, + cb.if_(state.lenRef >= 0, other.initLength(cb, state.lenRef)), { + if 
(!knownLength) { + cb.if_( + state.lenRef < 0, + state.initLength(cb, other.lenRef), + state.checkLength(cb, other.lenRef), + ) + } else { + state.checkLength(cb, other.lenRef) + } + }, + ) + }, { + cb.assign(other.idx, state.idx) + other.load(cb) + state.load(cb) + }, + state.nested.toCode((i, s) => nestedAggs(i).combOp(ctx, cb, region, s, other.nested(i))), + ) } - protected override def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { + override protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]) + : IEmitCode = { val len = state.lenRef def resultBody(cb: EmitCodeBuilder): SValue = { @@ -232,29 +264,38 @@ class ArrayElementLengthCheckAggregator(nestedAggs: Array[StagedAggregator], kno resultPType.stagedInitialize(cb, resultAddr, len, setMissing = false) val i = cb.newLocal[Int]("arrayagg_result_i", 0) - cb.while_(i < len, { - val addrAtI = cb.newLocal[Long]("arrayagg_result_addr_at_i", resultPType.elementOffset(resultAddr, len, i)) - resultEltType.stagedInitialize(cb, addrAtI, setMissing = false) - cb.assign(state.idx, i) - state.load(cb) - state.nested.toCode { case (nestedIdx, nestedState) => - val nestedAddr = cb.newLocal[Long](s"arrayagg_result_nested_addr_$nestedIdx", resultEltType.fieldOffset(addrAtI, nestedIdx)) - val nestedRes = nestedAggs(nestedIdx).result(cb, nestedState, region) - nestedRes.consume(cb, - { resultEltType.setFieldMissing(cb, addrAtI, nestedIdx)}, - { sv => resultEltType.types(nestedIdx).storeAtAddress(cb, nestedAddr, region, sv, true)}) - } - state.store(cb) - cb.assign(i, i + 1) - }) + cb.while_( + i < len, { + val addrAtI = cb.newLocal[Long]( + "arrayagg_result_addr_at_i", + resultPType.elementOffset(resultAddr, len, i), + ) + resultEltType.stagedInitialize(cb, addrAtI, setMissing = false) + cb.assign(state.idx, i) + state.load(cb) + state.nested.toCode { case (nestedIdx, nestedState) => + val nestedAddr = cb.newLocal[Long]( + s"arrayagg_result_nested_addr_$nestedIdx", + resultEltType.fieldOffset(addrAtI, nestedIdx), + ) + val nestedRes = nestedAggs(nestedIdx).result(cb, nestedState, region) + nestedRes.consume( + cb, + resultEltType.setFieldMissing(cb, addrAtI, nestedIdx), + sv => resultEltType.types(nestedIdx).storeAtAddress(cb, nestedAddr, region, sv, true), + ) + } + state.store(cb) + cb.assign(i, i + 1) + }, + ) // don't need to deep copy because that's done in nested aggregators resultPType.loadCheapSCode(cb, resultAddr) } if (knownLength) { IEmitCode.present(cb, resultBody(cb)) - } - else { + } else { IEmitCode(cb, len < 0, resultBody(cb)) } } @@ -266,29 +307,48 @@ class ArrayElementwiseOpAggregator(nestedAggs: Array[StagedAggregator]) extends val initOpTypes: Seq[Type] = Array[Type]() val seqOpTypes: Seq[Type] = Array[Type](TInt32, TVoid) - val resultPType = PCanonicalArray(PCanonicalTuple(false, nestedAggs.map(_.resultEmitType.storageType): _*)) + val resultPType = + PCanonicalArray(PCanonicalTuple(false, nestedAggs.map(_.resultEmitType.storageType): _*)) + override def resultEmitType = EmitType(SIndexablePointer(resultPType), false) protected def _initOp(cb: EmitCodeBuilder, state: State, init: Array[EmitCode]): Unit = - throw new UnsupportedOperationException("State must be initialized by ArrayElementLengthCheckAggregator.") + throw new UnsupportedOperationException( + "State must be initialized by ArrayElementLengthCheckAggregator." 
+ ) protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = { val Array(eltIdx, seqOps) = seq - eltIdx.toI(cb).consume(cb, {}, { idx => - cb.assign(state.idx, idx.asInt32.value) - cb.if_(state.idx > state.lenRef || state.idx < 0, { - cb._fatal("element idx out of bounds") - }, { - state.load(cb) - cb += seqOps.asVoid - state.store(cb) - }) - }) + eltIdx.toI(cb).consume( + cb, + {}, + { idx => + cb.assign(state.idx, idx.asInt32.value) + cb.if_( + state.idx > state.lenRef || state.idx < 0, + cb._fatal("element idx out of bounds"), { + state.load(cb) + cb += seqOps.asVoid + state.store(cb) + }, + ) + }, + ) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: ArrayElementState, other: ArrayElementState): Unit = - throw new UnsupportedOperationException("State must be combined by ArrayElementLengthCheckAggregator.") + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: ArrayElementState, + other: ArrayElementState, + ): Unit = + throw new UnsupportedOperationException( + "State must be combined by ArrayElementLengthCheckAggregator." + ) protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = - throw new UnsupportedOperationException("Result must be defined by ArrayElementLengthCheckAggregator.") + throw new UnsupportedOperationException( + "Result must be defined by ArrayElementLengthCheckAggregator." + ) } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/CallStatsAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/CallStatsAggregator.scala index d268994d9cf..04bf5b3a63e 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/CallStatsAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/CallStatsAggregator.scala @@ -3,27 +3,28 @@ package is.hail.expr.ir.agg import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext -import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitContext, IEmitCode} +import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, IEmitCode} import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer, TypedCodecSpec} import is.hail.types.physical._ import is.hail.types.physical.stypes.EmitType import is.hail.types.physical.stypes.concrete.SBaseStructPointer -import is.hail.types.virtual.{TCall, TInt32, Type} import is.hail.types.physical.stypes.interfaces._ +import is.hail.types.virtual.{TCall, TInt32, Type} import is.hail.utils._ -import scala.language.existentials - - object CallStatsState { val callStatsInternalArrayType = PCanonicalArray(PInt32Required, required = true) - val stateType: PCanonicalTuple = PCanonicalTuple(true, callStatsInternalArrayType, callStatsInternalArrayType) - val resultPType = PCanonicalStruct(required = false, + val stateType: PCanonicalTuple = + PCanonicalTuple(true, callStatsInternalArrayType, callStatsInternalArrayType) + + val resultPType = PCanonicalStruct( + required = false, "AC" -> PCanonicalArray(PInt32(true), required = true), "AF" -> PCanonicalArray(PFloat64(true), required = false), "AN" -> PInt32(true), - "homozygote_count" -> PCanonicalArray(PInt32(true), required = true)) + "homozygote_count" -> PCanonicalArray(PInt32(true), required = true), + ) } @@ -39,24 +40,30 @@ class CallStatsState(val kb: EmitClassBuilder[_]) extends PointerBasedRVAState { val nAlleles: Settable[Int] = kb.genFieldThisRef[Int]() private val addr = kb.genFieldThisRef[Long]() - def loadNAlleles(cb: EmitCodeBuilder): Unit = { + def 
loadNAlleles(cb: EmitCodeBuilder): Unit = cb.assign(nAlleles, CallStatsState.callStatsInternalArrayType.loadLength(alleleCounts)) - } // unused but extremely useful for debugging if something goes wrong def dump(cb: CodeBuilderLike, tag: String): Unit = { val i = cb.newLocal[Int]("i") cb += Code._println(s"at tag $tag") - cb.for_(cb.assign(i, 0), i < nAlleles, cb.assign(i, i + 1), { + cb.for_( + cb.assign(i, 0), + i < nAlleles, + cb.assign(i, i + 1), cb += Code._println( const("at i=") + i.toS + ", AC=" + alleleCountAtIndex(i, nAlleles).toS + ", HOM=" + homCountAtIndex(i, nAlleles).toS - ) - }) + ), + ) } - override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = { super.load(cb, regionLoader, src) loadNAlleles(cb) } @@ -64,16 +71,28 @@ class CallStatsState(val kb: EmitClassBuilder[_]) extends PointerBasedRVAState { def alleleCountAtIndex(idx: Code[Int], length: Code[Int]): Code[Int] = Region.loadInt(CallStatsState.callStatsInternalArrayType.loadElement(alleleCounts, length, idx)) - def updateAlleleCountAtIndex(cb: EmitCodeBuilder, idx: Code[Int], length: Code[Int], updater: Code[Int] => Code[Int]): Unit = { - cb.assign(addr, CallStatsState.callStatsInternalArrayType.loadElement(alleleCounts, length, idx)) + def updateAlleleCountAtIndex( + cb: EmitCodeBuilder, + idx: Code[Int], + length: Code[Int], + updater: Code[Int] => Code[Int], + ): Unit = { + cb.assign( + addr, + CallStatsState.callStatsInternalArrayType.loadElement(alleleCounts, length, idx), + ) cb += Region.storeInt(addr, updater(Region.loadInt(addr))) } def homCountAtIndex(idx: Code[Int], length: Code[Int]): Code[Int] = Region.loadInt(CallStatsState.callStatsInternalArrayType.loadElement(homCounts, length, idx)) - - def updateHomCountAtIndex(cb: EmitCodeBuilder, idx: Code[Int], length: Code[Int], updater: Code[Int] => Code[Int]): Unit = { + def updateHomCountAtIndex( + cb: EmitCodeBuilder, + idx: Code[Int], + length: Code[Int], + updater: Code[Int] => Code[Int], + ): Unit = { cb.assign(addr, CallStatsState.callStatsInternalArrayType.loadElement(homCounts, length, idx)) cb += Region.storeInt(addr, updater(Region.loadInt(addr))) } @@ -85,18 +104,25 @@ class CallStatsState(val kb: EmitClassBuilder[_]) extends PointerBasedRVAState { } def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = { - { (cb: EmitCodeBuilder, ib: Value[InputBuffer]) => + (cb: EmitCodeBuilder, ib: Value[InputBuffer]) => val codecSpec = TypedCodecSpec(CallStatsState.stateType, codec) val decValue = codecSpec.encodedType.buildDecoder(CallStatsState.stateType.virtualType, kb) .apply(cb, region, ib) cb.assign(off, CallStatsState.stateType.store(cb, region, decValue, deepCopy = false)) loadNAlleles(cb) - } } def copyFromAddress(cb: EmitCodeBuilder, src: Value[Long]): Unit = { - cb.assign(off, CallStatsState.stateType.store(cb, region, CallStatsState.stateType.loadCheapSCode(cb, src), deepCopy = true)) + cb.assign( + off, + CallStatsState.stateType.store( + cb, + region, + CallStatsState.stateType.loadCheapSCode(cb, src), + deepCopy = true, + ), + ) loadNAlleles(cb) } } @@ -118,12 +144,19 @@ class CallStatsAggregator extends StagedAggregator { val i = state.kb.genFieldThisRef[Int]() nAlleles.toI(cb) - .consume(cb, + .consume( + cb, cb += Code._fatal[Unit]("hl.agg.call_stats: n_alleles may not be missing"), { sc => cb.assign(n, sc.asInt.value) 
cb.assign(state.nAlleles, n) - cb.assign(state.off, state.region.allocate(CallStatsState.stateType.alignment, CallStatsState.stateType.byteSize)) + cb.assign( + state.off, + state.region.allocate( + CallStatsState.stateType.alignment, + CallStatsState.stateType.byteSize, + ), + ) cb.assign(addr, CallStatsState.callStatsInternalArrayType.allocate(state.region, n)) CallStatsState.callStatsInternalArrayType.stagedInitialize(cb, addr, n) cb += Region.storeAddress(state.alleleCountsOffset, addr) @@ -131,56 +164,78 @@ class CallStatsAggregator extends StagedAggregator { CallStatsState.callStatsInternalArrayType.stagedInitialize(cb, addr, n) cb += Region.storeAddress(state.homCountsOffset, addr) cb.assign(i, 0) - cb.while_(i < n, - { + cb.while_( + i < n, { state.updateAlleleCountAtIndex(cb, i, n, _ => 0) state.updateHomCountAtIndex(cb, i, n, _ => 0) cb.assign(i, i + 1) - }) - }) + }, + ) + }, + ) } protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = { val Array(call) = seq - call.toI(cb).consume(cb, { - /* do nothing if missing */ - }, { case call: SCallValue => - val hom = cb.newLocal[Boolean]("hom", true) - val lastAllele = cb.newLocal[Int]("lastAllele", -1) - val i = cb.newLocal[Int]("i", 0) - call.forEachAllele(cb) { allele: Value[Int] => - cb.if_(allele > state.nAlleles, - cb._fatal(const("hl.agg.call_stats: found allele outside of expected range [0, ") - .concat(state.nAlleles.toS).concat("]: ").concat(allele.toS))) - state.updateAlleleCountAtIndex(cb, allele, state.nAlleles, _ + 1) - cb.if_(i > 0, cb.assign(hom, hom && allele.ceq(lastAllele))) - cb.assign(lastAllele, allele) - cb.assign(i, i + 1) - } + call.toI(cb).consume( + cb, { + /* do nothing if missing */ + }, + { case call: SCallValue => + val hom = cb.newLocal[Boolean]("hom", true) + val lastAllele = cb.newLocal[Int]("lastAllele", -1) + val i = cb.newLocal[Int]("i", 0) + call.forEachAllele(cb) { allele: Value[Int] => + cb.if_( + allele > state.nAlleles, + cb._fatal(const("hl.agg.call_stats: found allele outside of expected range [0, ") + .concat(state.nAlleles.toS).concat("]: ").concat(allele.toS)), + ) + state.updateAlleleCountAtIndex(cb, allele, state.nAlleles, _ + 1) + cb.if_(i > 0, cb.assign(hom, hom && allele.ceq(lastAllele))) + cb.assign(lastAllele, allele) + cb.assign(i, i + 1) + } - cb.if_((i > 1) && hom, { - state.updateHomCountAtIndex(cb, lastAllele, state.nAlleles, _ + 1) - }) - }) + cb.if_((i > 1) && hom, state.updateHomCountAtIndex(cb, lastAllele, state.nAlleles, _ + 1)) + }, + ) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: CallStatsState, other: CallStatsState): Unit = { + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: CallStatsState, + other: CallStatsState, + ): Unit = { val i = state.kb.genFieldThisRef[Int]() - cb.if_(other.nAlleles.cne(state.nAlleles), - cb += Code._fatal[Unit]("hl.agg.call_stats: length mismatch"), - { + cb.if_( + other.nAlleles.cne(state.nAlleles), + cb += Code._fatal[Unit]("hl.agg.call_stats: length mismatch"), { cb.assign(i, 0) - cb.while_(i < state.nAlleles, - { - state.updateAlleleCountAtIndex(cb, i, state.nAlleles, _ + other.alleleCountAtIndex(i, state.nAlleles)) - state.updateHomCountAtIndex(cb, i, state.nAlleles, _ + other.homCountAtIndex(i, state.nAlleles)) + cb.while_( + i < state.nAlleles, { + state.updateAlleleCountAtIndex( + cb, + i, + state.nAlleles, + _ + other.alleleCountAtIndex(i, state.nAlleles), + ) + state.updateHomCountAtIndex( + cb, + i, + state.nAlleles, 
+ _ + other.homCountAtIndex(i, state.nAlleles), + ) cb.assign(i, i + 1) - }) - }) + }, + ) + }, + ) } - protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { val rt = CallStatsState.resultPType val addr = cb.memoize(rt.allocate(region), "call_stats_aggregator_result_addr") @@ -191,36 +246,51 @@ class CallStatsAggregator extends StagedAggregator { // this is a little weird - computing AC has the side effect of updating AN val ac = acType.constructFromElements(cb, region, state.nAlleles, deepCopy = true) { (cb, i) => - val acAtIndex = cb.newLocal[Int]("callstats_result_acAtIndex", state.alleleCountAtIndex(i, state.nAlleles)) + val acAtIndex = + cb.newLocal[Int]("callstats_result_acAtIndex", state.alleleCountAtIndex(i, state.nAlleles)) cb.assign(alleleNumber, alleleNumber + acAtIndex) IEmitCode.present(cb, primitive(acAtIndex)) } acType.storeAtAddress(cb, rt.fieldOffset(addr, "AC"), region, ac, deepCopy = false) - cb.if_(alleleNumber.ceq(0), - rt.setFieldMissing(cb, addr, "AF"), - { + cb.if_( + alleleNumber.ceq(0), + rt.setFieldMissing(cb, addr, "AF"), { val afType = resultStorageType.fieldType("AF").asInstanceOf[PCanonicalArray] - val af = afType.constructFromElements(cb, region, state.nAlleles, deepCopy = true) { (cb, i) => - val acAtIndex = cb.newLocal[Int]("callstats_result_acAtIndex", state.alleleCountAtIndex(i, state.nAlleles)) - IEmitCode.present(cb, primitive(cb.memoize(acAtIndex.toD / alleleNumber.toD))) - } + val af = + afType.constructFromElements(cb, region, state.nAlleles, deepCopy = true) { (cb, i) => + val acAtIndex = cb.newLocal[Int]( + "callstats_result_acAtIndex", + state.alleleCountAtIndex(i, state.nAlleles), + ) + IEmitCode.present(cb, primitive(cb.memoize(acAtIndex.toD / alleleNumber.toD))) + } afType.storeAtAddress(cb, rt.fieldOffset(addr, "AF"), region, af, deepCopy = false) - }) + }, + ) val anType = resultStorageType.fieldType("AN") val an = primitive(alleleNumber) anType.storeAtAddress(cb, rt.fieldOffset(addr, "AN"), region, an, deepCopy = false) - val homCountType = resultStorageType.fieldType("homozygote_count").asInstanceOf[PCanonicalArray] - val homCount = homCountType.constructFromElements(cb, region, state.nAlleles, deepCopy = true) { (cb, i) => - val homCountAtIndex = cb.newLocal[Int]("callstats_result_homCountAtIndex", state.homCountAtIndex(i, state.nAlleles)) - IEmitCode.present(cb, primitive(homCountAtIndex)) - } + val homCount = + homCountType.constructFromElements(cb, region, state.nAlleles, deepCopy = true) { (cb, i) => + val homCountAtIndex = cb.newLocal[Int]( + "callstats_result_homCountAtIndex", + state.homCountAtIndex(i, state.nAlleles), + ) + IEmitCode.present(cb, primitive(homCountAtIndex)) + } - homCountType.storeAtAddress(cb, rt.fieldOffset(addr, "homozygote_count"), region, homCount, deepCopy = false) + homCountType.storeAtAddress( + cb, + rt.fieldOffset(addr, "homozygote_count"), + region, + homCount, + deepCopy = false, + ) IEmitCode.present(cb, rt.loadCheapSCode(cb, addr)) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/CollectAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/CollectAggregator.scala index 35690a62ecd..907fa809639 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/CollectAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/CollectAggregator.scala @@ -12,7 +12,8 @@ import is.hail.types.physical.stypes.concrete.SIndexablePointer import is.hail.types.virtual.Type import is.hail.utils._ -class CollectAggState(val elemVType: VirtualTypeWithReq, val kb: 
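/* Illustrative sketch (not the staged implementation): the quantities CallStatsAggregator
 * computes, expressed over in-memory data. A call is modelled as an array of allele indices,
 * one entry per ploidy; all names here are hypothetical. As in the aggregator, AF is absent
 * when AN == 0, modelled here with Option. */
object CallStatsExample {
  final case class CallStatsResult(
    ac: Array[Int],
    af: Option[Array[Double]],
    an: Int,
    homozygoteCount: Array[Int],
  )

  def callStats(calls: Iterator[Array[Int]], nAlleles: Int): CallStatsResult = {
    val ac = new Array[Int](nAlleles)
    val hom = new Array[Int](nAlleles)
    for (call <- calls) {
      call.foreach { a => require(a >= 0 && a < nAlleles); ac(a) += 1 }
      // a homozygote has more than one allele and all alleles identical
      if (call.length > 1 && call.forall(_ == call(0))) hom(call(0)) += 1
    }
    val an = ac.sum
    val af = if (an == 0) None else Some(ac.map(_.toDouble / an))
    CallStatsResult(ac, af, an, hom)
  }
}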
EmitClassBuilder[_]) extends AggregatorState { +class CollectAggState(val elemVType: VirtualTypeWithReq, val kb: EmitClassBuilder[_]) + extends AggregatorState { private val elemType = elemVType.canonicalPType val r = kb.genFieldThisRef[Region]() @@ -24,25 +25,36 @@ class CollectAggState(val elemVType: VirtualTypeWithReq, val kb: EmitClassBuilde override def regionSize: Region.Size = Region.REGULAR def createState(cb: EmitCodeBuilder): Unit = - cb.if_(region.isNull, { - cb.assign(r, Region.stagedCreate(regionSize, kb.pool())) - cb += region.invalidate() - }) + cb.if_( + region.isNull, { + cb.assign(r, Region.stagedCreate(regionSize, kb.pool())) + cb += region.invalidate() + }, + ) def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = cb += region.getNewRegion(regionSize) - override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = { regionLoader(cb, region) bll.load(cb, src) } - override def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = { - cb.if_(region.isValid, - { + override def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = { + cb.if_( + region.isValid, { regionStorer(cb, region) bll.store(cb, dest) cb += region.invalidate() - }) + }, + ) } def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit = { @@ -52,14 +64,13 @@ class CollectAggState(val elemVType: VirtualTypeWithReq, val kb: EmitClassBuilde } def serialize(codec: BufferSpec): (EmitCodeBuilder, Value[OutputBuffer]) => Unit = { - { (cb, ib) => bll.serialize(cb, region, ib) } + (cb, ib) => bll.serialize(cb, region, ib) } def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = { - { (cb, ib) => + (cb, ib) => bll.init(cb, region) bll.deserialize(cb, region, ib) - } } } @@ -76,15 +87,22 @@ class CollectAggregator(val elemType: VirtualTypeWithReq) extends StagedAggregat state.bll.init(cb, state.region) } - protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = { + protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = state.bll.push(cb, state.region, seq(0)) - } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: CollectAggState, other: CollectAggState): Unit = + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: CollectAggState, + other: CollectAggState, + ): Unit = state.bll.append(cb, state.region, other.bll) - protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { + protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = // deepCopy is handled by the blocked linked list - IEmitCode.present(cb, state.bll.resultArray(cb, region, resultEmitType.storageType.asInstanceOf[PCanonicalArray])) - } + IEmitCode.present( + cb, + state.bll.resultArray(cb, region, resultEmitType.storageType.asInstanceOf[PCanonicalArray]), + ) } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/CollectAsSetAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/CollectAsSetAggregator.scala index 8659ea7a367..dacfe59890f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/CollectAsSetAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/CollectAsSetAggregator.scala @@ -3,14 
+3,16 @@ package is.hail.expr.ir.agg import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext +import is.hail.expr.ir.{ + EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitRegion, EmitValue, IEmitCode, +} import is.hail.expr.ir.orderings.CodeOrdering -import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitRegion, EmitValue, IEmitCode} import is.hail.io._ import is.hail.types.VirtualTypeWithReq import is.hail.types.encoded.EType import is.hail.types.physical._ -import is.hail.types.physical.stypes.concrete.SIndexablePointer import is.hail.types.physical.stypes.{EmitType, SValue} +import is.hail.types.physical.stypes.concrete.SIndexablePointer import is.hail.types.virtual.Type import is.hail.utils._ @@ -21,9 +23,8 @@ class TypedKey(typ: PType, kb: EmitClassBuilder[_], region: Value[Region]) exten def isKeyMissing(cb: EmitCodeBuilder, src: Code[Long]): Value[Boolean] = storageType.isFieldMissing(cb, src, 0) - def loadKey(cb: EmitCodeBuilder, src: Code[Long]): SValue = { + def loadKey(cb: EmitCodeBuilder, src: Code[Long]): SValue = typ.loadCheapSCode(cb, storageType.loadField(src, 0)) - } override def isEmpty(cb: EmitCodeBuilder, off: Code[Long]): Value[Boolean] = storageType.isFieldMissing(cb, off, 1) @@ -36,32 +37,38 @@ class TypedKey(typ: PType, kb: EmitClassBuilder[_], region: Value[Region]) exten storageType.setFieldPresent(cb, dest, 1) k.toI(cb) - .consume(cb, - { - storageType.setFieldMissing(cb, dest, 0) - }, + .consume( + cb, + storageType.setFieldMissing(cb, dest, 0), { sc => storageType.setFieldPresent(cb, dest, 0) typ.storeAtAddress(cb, storageType.fieldOffset(dest, 0), region, sc, deepCopy = true) - }) + }, + ) } override def copy(cb: EmitCodeBuilder, src: Code[Long], dest: Code[Long]): Unit = cb += Region.copyFrom(src, dest, storageType.byteSize) - override def deepCopy(cb: EmitCodeBuilder, er: EmitRegion, dest: Code[Long], src: Code[Long]): Unit = { - storageType.storeAtAddress(cb, dest, region, storageType.loadCheapSCode(cb, src), deepCopy = true) - } - - override def compKeys(cb: EmitCodeBuilder, k1: EmitValue, k2: EmitValue): Value[Int] = { + override def deepCopy(cb: EmitCodeBuilder, er: EmitRegion, dest: Code[Long], src: Code[Long]) + : Unit = + storageType.storeAtAddress( + cb, + dest, + region, + storageType.loadCheapSCode(cb, src), + deepCopy = true, + ) + + override def compKeys(cb: EmitCodeBuilder, k1: EmitValue, k2: EmitValue): Value[Int] = kb.getOrderingFunction(k1.st, k2.st, CodeOrdering.Compare())(cb, k1, k2) - } override def loadCompKey(cb: EmitCodeBuilder, off: Value[Long]): EmitValue = EmitValue(Some(isKeyMissing(cb, off)), loadKey(cb, off)) } -class AppendOnlySetState(val kb: EmitClassBuilder[_], vt: VirtualTypeWithReq) extends PointerBasedRVAState { +class AppendOnlySetState(val kb: EmitClassBuilder[_], vt: VirtualTypeWithReq) + extends PointerBasedRVAState { private val t = vt.canonicalPType val root: Settable[Long] = kb.genFieldThisRef[Long]() val size: Settable[Int] = kb.genFieldThisRef[Int]() @@ -72,18 +79,28 @@ class AppendOnlySetState(val kb: EmitClassBuilder[_], vt: VirtualTypeWithReq) ex val typ: PStruct = PCanonicalStruct( required = true, "size" -> PInt32(true), - "tree" -> PInt64(true)) - - override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { + "tree" -> PInt64(true), + ) + + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = { super.load(cb, 
regionLoader, src) - cb.if_(off.cne(0L), - { + cb.if_( + off.cne(0L), { cb.assign(size, Region.loadInt(typ.loadField(off, 0))) cb.assign(root, Region.loadAddress(typ.loadField(off, 1))) - }) + }, + ) } - override def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = { + override def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = { cb += Region.storeInt(typ.fieldOffset(off, 0), size) cb += Region.storeAddress(typ.fieldOffset(off, 1), root) super.store(cb, regionStorer, dest) @@ -100,17 +117,24 @@ class AppendOnlySetState(val kb: EmitClassBuilder[_], vt: VirtualTypeWithReq) ex def insert(cb: EmitCodeBuilder, v: EmitCode): Unit = { val _v = cb.memoize(v, "collect_as_set_insert_value") cb.assign(_elt, tree.getOrElseInitialize(cb, _v)) - cb.if_(key.isEmpty(cb, _elt), { - cb.assign(size, size + 1) - key.store(cb, _elt, _v) - }) + cb.if_( + key.isEmpty(cb, _elt), { + cb.assign(size, size + 1) + key.store(cb, _elt, _v) + }, + ) } // loads container; does not update. def foreach(cb: EmitCodeBuilder)(f: (EmitCodeBuilder, EmitCode) => Unit): Unit = tree.foreach(cb) { (cb, eoffCode) => val eoff = cb.newLocal("casa_foreach_eoff", eoffCode) - f(cb, EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, key.isKeyMissing(cb, eoff), key.loadKey(cb, eoff)))) + f( + cb, + EmitCode.fromI(cb.emb)(cb => + IEmitCode(cb, key.isKeyMissing(cb, eoff), key.loadKey(cb, eoff)) + ), + ) } def copyFromAddress(cb: EmitCodeBuilder, src: Value[Long]): Unit = { @@ -121,23 +145,22 @@ class AppendOnlySetState(val kb: EmitClassBuilder[_], vt: VirtualTypeWithReq) ex } def serialize(codec: BufferSpec): (EmitCodeBuilder, Value[OutputBuffer]) => Unit = { - { (cb: EmitCodeBuilder, ob: Value[OutputBuffer]) => + (cb: EmitCodeBuilder, ob: Value[OutputBuffer]) => tree.bulkStore(cb, ob) { (cb, ob, srcCode) => val src = cb.newLocal("aoss_ser_src", srcCode) cb += ob.writeBoolean(key.isKeyMissing(cb, src)) - cb.if_(!key.isKeyMissing(cb, src), { - val k = key.loadKey(cb, src) - et.buildEncoder(k.st, kb) + cb.if_( + !key.isKeyMissing(cb, src), { + val k = key.loadKey(cb, src) + et.buildEncoder(k.st, kb) .apply(cb, k, ob) - }) + }, + ) } - } } def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = { val kDec = et.buildDecoder(t.virtualType, kb) - val km = kb.genFieldThisRef[Boolean]("km") - val kv = kb.genFieldThisRef("kv")(typeToTypeInfo(t)) { (cb: EmitCodeBuilder, ib: Value[InputBuffer]) => init(cb) @@ -171,15 +194,19 @@ class CollectAsSetAggregator(elem: VirtualTypeWithReq) extends StagedAggregator state.insert(cb, elt) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: AppendOnlySetState, other: AppendOnlySetState): Unit = { - other.foreach(cb) { (cb, k) => state.insert(cb, k) } - } + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: AppendOnlySetState, + other: AppendOnlySetState, + ): Unit = + other.foreach(cb)((cb, k) => state.insert(cb, k)) protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { - val (pushElement, finish) = arrayRep.constructFromFunctions(cb, region, state.size, deepCopy = true) - state.foreach(cb) { (cb, elt) => - pushElement(cb, elt.toI(cb)) - } + val (pushElement, finish) = + arrayRep.constructFromFunctions(cb, region, state.size, deepCopy = true) + state.foreach(cb)((cb, elt) => pushElement(cb, elt.toI(cb))) assert(arrayRep.required) // deepCopy is 
handled by `storeElement` above IEmitCode.present(cb, setPType.construct(finish(cb))) diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/CountAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/CountAggregator.scala index 63e23df336a..c0b447de0cb 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/CountAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/CountAggregator.scala @@ -1,11 +1,9 @@ package is.hail.expr.ir.agg -import freemarker.template.utility.Execute import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext -import is.hail.expr.ir.{EmitCode, EmitCodeBuilder, EmitContext, IEmitCode} -import is.hail.types.physical._ +import is.hail.expr.ir.{EmitCode, EmitCodeBuilder, IEmitCode} import is.hail.types.physical.stypes.EmitType import is.hail.types.physical.stypes.interfaces.primitive import is.hail.types.physical.stypes.primitives.SInt64 @@ -32,11 +30,20 @@ object CountAggregator extends StagedAggregator { cb.assign(ev, EmitCode.present(cb.emb, primitive(cb.memoize(ev.pv.asInt64.value + 1L)))) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: PrimitiveRVAState, other: PrimitiveRVAState): Unit = { + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: PrimitiveRVAState, + other: PrimitiveRVAState, + ): Unit = { assert(state.vtypes.head.r.required) val v1 = state.fields(0) val v2 = other.fields(0) - cb.assign(v1, EmitCode.present(cb.emb, primitive(cb.memoize(v1.pv.asInt64.value + v2.pv.asInt64.value)))) + cb.assign( + v1, + EmitCode.present(cb.emb, primitive(cb.memoize(v1.pv.asInt64.value + v2.pv.asInt64.value))), + ) } protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/DensifyAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/DensifyAggregator.scala index 2f4c8e77b7b..2730bfc61ca 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/DensifyAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/DensifyAggregator.scala @@ -3,7 +3,7 @@ package is.hail.expr.ir.agg import is.hail.annotations.Region import is.hail.asm4s.{Code, _} import is.hail.backend.ExecuteContext -import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitContext, IEmitCode} +import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, IEmitCode} import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer, TypedCodecSpec} import is.hail.types.VirtualTypeWithReq import is.hail.types.physical._ @@ -17,11 +17,11 @@ object DensifyAggregator { val END_SERIALIZATION: Int = 0xf81ea4 } -class DensifyState(val arrayVType: VirtualTypeWithReq, val kb: EmitClassBuilder[_]) extends AggregatorState { - val eltType: PType = { +class DensifyState(val arrayVType: VirtualTypeWithReq, val kb: EmitClassBuilder[_]) + extends AggregatorState { + val eltType: PType = // FIXME: VirtualTypeWithReq needs better ergonomics arrayVType.canonicalPType.asInstanceOf[PCanonicalArray].elementType.setRequired(false) - } private val r: ThisFieldRef[Region] = kb.genFieldThisRef[Region]() val region: Value[Region] = r @@ -34,22 +34,27 @@ class DensifyState(val arrayVType: VirtualTypeWithReq, val kb: EmitClassBuilder[ private val length = kb.genFieldThisRef[Int]("densify_len") private val arrayAddr = kb.genFieldThisRef[Long]("densify_addr") - def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = { + def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = cb += 
region.getNewRegion(regionSize) - } def createState(cb: EmitCodeBuilder): Unit = - cb.if_(region.isNull, { - cb.assign(r, Region.stagedCreate(regionSize, kb.pool())) - }) + cb.if_(region.isNull, cb.assign(r, Region.stagedCreate(regionSize, kb.pool()))) - override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = { regionLoader(cb, r) cb.assign(arrayAddr, Region.loadAddress(src)) cb.assign(length, arrayStorageType.loadLength(arrayAddr)) } - override def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = { + override def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = { regionStorer(cb, region) cb += region.invalidate() cb += Region.storeAddress(dest, arrayAddr) @@ -76,8 +81,10 @@ class DensifyState(val arrayVType: VirtualTypeWithReq, val kb: EmitClassBuilder[ cb.assign(arrayAddr, arrayStorageType.store(cb, region, decValue, deepCopy = false)) cb.assign(length, arrayStorageType.loadLength(arrayAddr)) - cb.if_(ib.readInt().cne(const(DensifyAggregator.END_SERIALIZATION)), - cb += Code._fatal[Unit](s"densify serialization failed")) + cb.if_( + ib.readInt().cne(const(DensifyAggregator.END_SERIALIZATION)), + cb += Code._fatal[Unit](s"densify serialization failed"), + ) } } @@ -88,27 +95,44 @@ class DensifyState(val arrayVType: VirtualTypeWithReq, val kb: EmitClassBuilder[ } private def gc(cb: EmitCodeBuilder): Unit = { - cb.if_(region.totalManagedBytes() > maxRegionSize, { - val newRegion = cb.newLocal[Region]("densify_gc", Region.stagedCreate(regionSize, kb.pool())) - cb.assign(arrayAddr, arrayStorageType.store(cb, newRegion, arrayStorageType.loadCheapSCode(cb, arrayAddr), deepCopy = true)) - cb += region.invalidate() - cb.assign(r, newRegion) - - }) + cb.if_( + region.totalManagedBytes() > maxRegionSize, { + val newRegion = + cb.newLocal[Region]("densify_gc", Region.stagedCreate(regionSize, kb.pool())) + cb.assign( + arrayAddr, + arrayStorageType.store( + cb, + newRegion, + arrayStorageType.loadCheapSCode(cb, arrayAddr), + deepCopy = true, + ), + ) + cb += region.invalidate() + cb.assign(r, newRegion) + + }, + ) } def seqOp(cb: EmitCodeBuilder, a: EmitCode): Unit = { a.toI(cb) - .consume(cb, - { + .consume( + cb, { /* do nothing if missing */ }, - { arr => + arr => arr.asIndexable.forEachDefined(cb) { case (cb, idx, element) => arrayStorageType.setElementPresent(cb, arrayAddr, idx) - eltType.storeAtAddress(cb, arrayStorageType.elementOffset(arrayAddr, length, idx), region, element, deepCopy = true) - } - }) + eltType.storeAtAddress( + cb, + arrayStorageType.elementOffset(arrayAddr, length, idx), + region, + element, + deepCopy = true, + ) + }, + ) gc(cb) } @@ -117,34 +141,44 @@ class DensifyState(val arrayVType: VirtualTypeWithReq, val kb: EmitClassBuilder[ val arr = arrayStorageType.loadCheapSCode(cb, other.arrayAddr) arr.asInstanceOf[SIndexableValue].forEachDefined(cb) { case (cb, idx, element) => arrayStorageType.setElementPresent(cb, arrayAddr, idx) - eltType.storeAtAddress(cb, arrayStorageType.elementOffset(arrayAddr, length, idx), region, element, deepCopy = true) + eltType.storeAtAddress( + cb, + arrayStorageType.elementOffset(arrayAddr, length, idx), + region, + element, + deepCopy = true, + ) } gc(cb) } - def result(cb: EmitCodeBuilder, region: Value[Region]): 
SIndexablePointerValue = { + def result(cb: EmitCodeBuilder, region: Value[Region]): SIndexablePointerValue = arrayStorageType.loadCheapSCode(cb, arrayAddr) - } def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit = { - cb.assign(arrayAddr, - arrayStorageType.store(cb, + cb.assign( + arrayAddr, + arrayStorageType.store( + cb, region, arrayStorageType.loadCheapSCode(cb, arrayStorageType.loadFromNested(src)), - deepCopy = true)) + deepCopy = true, + ), + ) cb.assign(length, arrayStorageType.loadLength(arrayAddr)) } } - class DensifyAggregator(val arrayVType: VirtualTypeWithReq) extends StagedAggregator { type State = DensifyState private val pt = { // FIXME: VirtualTypeWithReq needs better ergonomics - val eltType = arrayVType.canonicalPType.asInstanceOf[PCanonicalArray].elementType.setRequired(false) + val eltType = + arrayVType.canonicalPType.asInstanceOf[PCanonicalArray].elementType.setRequired(false) PCanonicalArray(eltType) } + val resultEmitType: EmitType = EmitType(SIndexablePointer(pt), true) val initOpTypes: Seq[Type] = Array(TInt32) val seqOpTypes: Seq[Type] = Array(resultEmitType.virtualType) @@ -153,9 +187,10 @@ class DensifyAggregator(val arrayVType: VirtualTypeWithReq) extends StagedAggreg assert(init.length == 1) val Array(sizeTriplet) = init sizeTriplet.toI(cb) - .consume(cb, + .consume( + cb, cb += Code._fatal[Unit](s"argument 'n' for 'hl.agg.densify' may not be missing"), - sc => state.init(cb, sc.asInt.value) + sc => state.init(cb, sc.asInt.value), ) } @@ -164,7 +199,13 @@ class DensifyAggregator(val arrayVType: VirtualTypeWithReq) extends StagedAggreg state.seqOp(cb, elt) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: DensifyState, other: DensifyState): Unit = state.combine(cb, other) + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: DensifyState, + other: DensifyState, + ): Unit = state.combine(cb, other) protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { val resultInWrongRegion = state.result(cb, region) diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/DownsampleAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/DownsampleAggregator.scala index b180447fc5e..4ed3d98e669 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/DownsampleAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/DownsampleAggregator.scala @@ -3,24 +3,28 @@ package is.hail.expr.ir.agg import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext +import is.hail.expr.ir.{ + EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitRegion, EmitValue, IEmitCode, ParamType, +} import is.hail.expr.ir.orderings.CodeOrdering -import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitRegion, EmitValue, IEmitCode, ParamType} import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer} import is.hail.types.VirtualTypeWithReq import is.hail.types.encoded.EType import is.hail.types.physical._ +import is.hail.types.physical.stypes.{EmitType, SingleCodeSCode} import is.hail.types.physical.stypes.concrete.{SIndexablePointer, SIndexablePointerValue} import is.hail.types.physical.stypes.interfaces.SBaseStructValue -import is.hail.types.physical.stypes.{EmitType, SingleCodeSCode} import is.hail.types.virtual._ import is.hail.utils._ - -class DownsampleBTreeKey(binType: PBaseStruct, pointType: PBaseStruct, kb: EmitClassBuilder[_], region: Code[Region]) extends BTreeKey { - override val storageType: PCanonicalStruct = 
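/* Illustrative sketch (not the staged implementation): the effect of the densify aggregator
 * over plain Scala collections. The accumulator is a fixed-length array of optional values and
 * every defined entry of an incoming array overwrites the corresponding slot; the names here
 * are hypothetical. */
object DensifyExample {
  def init(n: Int): Array[Option[Double]] = Array.fill[Option[Double]](n)(None)

  def seqOp(acc: Array[Option[Double]], a: Array[Option[Double]]): Unit =
    a.zipWithIndex.foreach { case (elt, i) => if (elt.isDefined) acc(i) = elt }

  // combining two states has the same shape: defined slots of `other` overwrite slots of `acc`
  def combOp(acc: Array[Option[Double]], other: Array[Option[Double]]): Unit = seqOp(acc, other)
}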
PCanonicalStruct(required = true, +class DownsampleBTreeKey(binType: PBaseStruct, pointType: PBaseStruct, kb: EmitClassBuilder[_]) + extends BTreeKey { + override val storageType: PCanonicalStruct = PCanonicalStruct( + required = true, "bin" -> binType, "point" -> pointType, - "empty" -> PBooleanRequired) + "empty" -> PBooleanRequired, + ) override val compType: PType = binType private val kcomp = kb.getOrderingFunction(binType.sType, CodeOrdering.Compare()) @@ -28,15 +32,26 @@ class DownsampleBTreeKey(binType: PBaseStruct, pointType: PBaseStruct, kb: EmitC override def isEmpty(cb: EmitCodeBuilder, off: Code[Long]): Value[Boolean] = PBooleanRequired.loadCheapSCode(cb, storageType.loadField(off, "empty")).value - override def initializeEmpty(cb: EmitCodeBuilder, off: Code[Long]): Unit = cb += Region.storeBoolean(storageType.fieldOffset(off, "empty"), true) + override def initializeEmpty(cb: EmitCodeBuilder, off: Code[Long]): Unit = + cb += Region.storeBoolean(storageType.fieldOffset(off, "empty"), true) - override def copy(cb: EmitCodeBuilder, src: Code[Long], dest: Code[Long]): Unit = cb += Region.copyFrom(src, dest, storageType.byteSize) + override def copy(cb: EmitCodeBuilder, src: Code[Long], dest: Code[Long]): Unit = + cb += Region.copyFrom(src, dest, storageType.byteSize) - override def deepCopy(cb: EmitCodeBuilder, er: EmitRegion, srcc: Code[Long], dest: Code[Long]): Unit = { + override def deepCopy(cb: EmitCodeBuilder, er: EmitRegion, srcc: Code[Long], dest: Code[Long]) + : Unit = { val src = cb.newLocal[Long]("dsa_deep_copy_src", srcc) - cb.if_(Region.loadBoolean(storageType.loadField(src, "empty")), - cb += Code._fatal[Unit]("key empty!")) - storageType.storeAtAddress(cb, dest, er.region, storageType.loadCheapSCode(cb, src), deepCopy = true) + cb.if_( + Region.loadBoolean(storageType.loadField(src, "empty")), + cb += Code._fatal[Unit]("key empty!"), + ) + storageType.storeAtAddress( + cb, + dest, + er.region, + storageType.loadCheapSCode(cb, src), + deepCopy = true, + ) } override def compKeys(cb: EmitCodeBuilder, k1: EmitValue, k2: EmitValue): Value[Int] = @@ -46,12 +61,15 @@ class DownsampleBTreeKey(binType: PBaseStruct, pointType: PBaseStruct, kb: EmitC EmitValue.present(binType.loadCheapSCode(cb, storageType.loadField(off, "bin"))) } - object DownsampleState { val serializationEndMarker: Int = 883625255 } -class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq, maxBufferSize: Int = 256) extends AggregatorState { +class DownsampleState( + val kb: EmitClassBuilder[_], + labelType: VirtualTypeWithReq, + maxBufferSize: Int = 256, +) extends AggregatorState { private val labelPType = labelType.canonicalPType val r: Settable[Region] = kb.genFieldThisRef[Region]("region") val region: Value[Region] = r @@ -64,7 +82,13 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq cb.if_(region.isNull, cb.assign(r, Region.stagedCreate(regionSize, kb.pool()))) val binType = PCanonicalStruct(required = true, "x" -> PInt32Required, "y" -> PInt32Required) - val pointType = PCanonicalStruct(required = true, "x" -> PFloat64Required, "y" -> PFloat64Required, "label" -> labelPType) + + val pointType = PCanonicalStruct( + required = true, + "x" -> PFloat64Required, + "y" -> PFloat64Required, + "label" -> labelPType, + ) private val binET = EType.defaultFromPType(binType) private val pointET = EType.defaultFromPType(pointType) @@ -72,7 +96,7 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq private val root: 
Settable[Long] = kb.genFieldThisRef[Long]("root") private val oldRoot: Settable[Long] = kb.genFieldThisRef[Long]("old_root") - val key = new DownsampleBTreeKey(binType, pointType, kb, region) + val key = new DownsampleBTreeKey(binType, pointType, kb) val tree = new AppendOnlyBTree(kb, key, region, root) val buffer = new StagedArrayBuilder(pointType, kb, region, initialCapacity = maxBufferSize) val oldRootBTree = new AppendOnlyBTree(kb, key, region, oldRoot) @@ -89,7 +113,8 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq private val bufferTop: Settable[Double] = kb.genFieldThisRef[Double]("buffer_top") private val treeSize: Settable[Int] = kb.genFieldThisRef[Int]("treeSize") - val storageType = PCanonicalStruct(required = true, + val storageType = PCanonicalStruct( + required = true, "nDivisions" -> PInt32Required, "treeSize" -> PInt32Required, "left" -> PFloat64Required, @@ -103,7 +128,7 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq "buffer" -> buffer.stateType, "tree" -> PInt64Required, "binStaging" -> binType, // used as scratch space - "pointStaging" -> pointType // used as scratch space + "pointStaging" -> pointType, // used as scratch space ) override val regionSize: Int = Region.SMALL @@ -117,7 +142,12 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq allocateSpace(cb) cb.assign(this.nDivisions, mb.getCodeParam[Int](1).load()) - cb.if_(this.nDivisions < 4, cb += Code._fatal[Unit](const("downsample: require n_divisions >= 4, found ").concat(this.nDivisions.toS))) + cb.if_( + this.nDivisions < 4, + cb += Code._fatal[Unit]( + const("downsample: require n_divisions >= 4, found ").concat(this.nDivisions.toS) + ), + ) cb.assign(left, 0d) cb.assign(right, 0d) @@ -127,13 +157,16 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq tree.init(cb) buffer.initialize(cb) } - cb.invokeVoid(mb, nDivisions) + cb.invokeVoid(mb, cb.this_, nDivisions) } - override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = { val mb = kb.genEmitMethod("downsample_load", FastSeq[ParamType](), UnitInfo) mb.voidWithBuilder { cb => - cb.assign(nDivisions, Region.loadInt(storageType.loadField(off, "nDivisions"))) cb.assign(treeSize, Region.loadInt(storageType.loadField(off, "treeSize"))) cb.assign(left, Region.loadDouble(storageType.loadField(off, "left"))) @@ -149,10 +182,14 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq } cb.assign(off, src) regionLoader(cb, r) - cb.invokeVoid(mb) + cb.invokeVoid(mb, cb.this_) } - override def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = { + override def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = { val mb = kb.genEmitMethod("downsample_store", FastSeq[ParamType](), UnitInfo) mb.voidWithBuilder { cb => cb += Region.storeInt(storageType.fieldOffset(off, "nDivisions"), nDivisions) @@ -170,12 +207,13 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq } cb.assign(off, dest) - cb.invokeVoid(mb) - cb.if_(region.isValid, - { + cb.invokeVoid(mb, cb.this_) + cb.if_( + region.isValid, { regionStorer(cb, region) cb += region.invalidate() - }) + }, + ) } 
def copyFrom(cb: EmitCodeBuilder, _src: Value[Long]): Unit = { @@ -194,12 +232,16 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq tree.deepCopy(cb, cb.memoize(Region.loadAddress(storageType.loadField(src, "tree")))) buffer.copyFrom(cb, storageType.loadField(src, "buffer")) } - cb.invokeVoid(mb, _src) + cb.invokeVoid(mb, cb.this_, _src) } def serialize(codec: BufferSpec): (EmitCodeBuilder, Value[OutputBuffer]) => Unit = { - { (cb: EmitCodeBuilder, ob: Value[OutputBuffer]) => - val mb = kb.genEmitMethod("downsample_serialize", FastSeq[ParamType](typeInfo[OutputBuffer]), UnitInfo) + (cb: EmitCodeBuilder, ob: Value[OutputBuffer]) => + val mb = kb.genEmitMethod( + "downsample_serialize", + FastSeq[ParamType](typeInfo[OutputBuffer]), + UnitInfo, + ) mb.emitWithBuilder { cb => val ob = mb.getCodeParam[OutputBuffer](1) dumpBuffer(cb) @@ -223,8 +265,7 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq Code._empty } - cb.invokeVoid(mb, ob) - } + cb.invokeVoid(mb, cb.this_, ob) } def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = { @@ -232,7 +273,11 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq val pointDec = pointET.buildInplaceDecoderMethod(pointType, kb) { (cb: EmitCodeBuilder, ib: Value[InputBuffer]) => - val mb = kb.genEmitMethod("downsample_deserialize", FastSeq[ParamType](typeInfo[InputBuffer]), UnitInfo) + val mb = kb.genEmitMethod( + "downsample_deserialize", + FastSeq[ParamType](typeInfo[InputBuffer]), + UnitInfo, + ) mb.emitWithBuilder { cb => val ib = cb.emb.getCodeParam[InputBuffer](1) val serializationEndTag = cb.newLocal[Int]("de_end_tag") @@ -251,38 +296,64 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq tree.init(cb) tree.bulkLoad(cb, ib) { (cb, ib, destCode) => val dest = cb.newLocal("dss_deser_dest", destCode) - binDec.invokeCode(cb, region, cb.memoize(key.storageType.fieldOffset(dest, "bin")), ib) - pointDec.invokeCode(cb, region, cb.memoize(key.storageType.fieldOffset(dest, "point")), ib) + cb.invokeVoid( + binDec, + cb.this_, + region, + cb.memoize(key.storageType.fieldOffset(dest, "bin")), + ib, + ) + cb.invokeVoid( + pointDec, + cb.this_, + region, + cb.memoize(key.storageType.fieldOffset(dest, "point")), + ib, + ) cb += Region.storeBoolean(key.storageType.fieldOffset(dest, "empty"), false) } buffer.initialize(cb) cb.assign(serializationEndTag, ib.readInt()) - cb.if_(serializationEndTag.cne(DownsampleState.serializationEndMarker), { - cb._fatal("downsample aggregator failed to serialize!") - }) + cb.if_( + serializationEndTag.cne(DownsampleState.serializationEndMarker), + cb._fatal("downsample aggregator failed to serialize!"), + ) Code._empty } - cb.invokeVoid(mb, ib) + cb.invokeVoid(mb, cb.this_, ib) } } val xBinCoordinate: (EmitCodeBuilder, Value[Double]) => Value[Int] = { - val mb = kb.genEmitMethod("downsample_x_bin_coordinate", FastSeq[ParamType](DoubleInfo), IntInfo) + val mb = + kb.genEmitMethod("downsample_x_bin_coordinate", FastSeq[ParamType](DoubleInfo), IntInfo) val x = mb.getCodeParam[Double](1) mb.emit(right.ceq(left).mux(0, (((x - left) / (right - left)) * nDivisions.toD).toI)) - mb.invokeCode(_, _) + (cb, x) => cb.invokeCode(mb, cb.this_, x) } val yBinCoordinate: (EmitCodeBuilder, Value[Double]) => Value[Int] = { - val mb = kb.genEmitMethod("downsample_y_bin_coordinate", FastSeq[ParamType](DoubleInfo), IntInfo) + val mb = + kb.genEmitMethod("downsample_y_bin_coordinate", 
FastSeq[ParamType](DoubleInfo), IntInfo) val y = mb.getCodeParam[Double](1) mb.emit(top.ceq(bottom).mux(0, (((y - bottom) / (top - bottom)) * nDivisions.toD).toI)) - mb.invokeCode(_, _) + (cb, y) => cb.invokeCode(mb, cb.this_, y) } - def insertIntoTree(cb: EmitCodeBuilder, binX: Value[Int], binY: Value[Int], point: Value[Long], deepCopy: Boolean): Unit = { - val name = s"downsample_insert_into_tree_${ deepCopy.toString }" - val mb = kb.getOrGenEmitMethod(name, (this, name, deepCopy), FastSeq[ParamType](IntInfo, IntInfo, LongInfo), UnitInfo) { mb => + def insertIntoTree( + cb: EmitCodeBuilder, + binX: Value[Int], + binY: Value[Int], + point: Value[Long], + deepCopy: Boolean, + ): Unit = { + val name = s"downsample_insert_into_tree_${deepCopy.toString}" + val mb = kb.getOrGenEmitMethod( + name, + (this, name, deepCopy), + FastSeq[ParamType](IntInfo, IntInfo, LongInfo), + UnitInfo, + ) { mb => val binX = mb.getCodeParam[Int](1) val binY = mb.getCodeParam[Int](2) val point = mb.getCodeParam[Long](3) @@ -295,21 +366,37 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq cb.assign(binStaging, storageType.loadField(off, "binStaging")) cb += Region.storeInt(binType.fieldOffset(binStaging, "x"), binX) cb += Region.storeInt(binType.fieldOffset(binStaging, "y"), binY) - cb.assign(insertOffset, - tree.getOrElseInitialize(cb, EmitCode.present(cb.emb, storageType.fieldType("binStaging").loadCheapSCode(cb, binStaging)))) - cb.if_(key.isEmpty(cb, insertOffset), { - cb.assign(binOffset, key.storageType.loadField(insertOffset, "bin")) - cb += Region.storeInt(binType.loadField(binOffset, "x"), binX) - cb += Region.storeInt(binType.loadField(binOffset, "y"), binY) - cb.assign(insertedPointOffset, key.storageType.loadField(insertOffset, "point")) - pointType.storeAtAddress(cb, insertedPointOffset, region, pointType.loadCheapSCode(cb, point), deepCopy = deepCopy) - cb += Region.storeBoolean(key.storageType.loadField(insertOffset, "empty"), false) - cb.assign(treeSize, treeSize + 1) - }) + cb.assign( + insertOffset, + tree.getOrElseInitialize( + cb, + EmitCode.present( + cb.emb, + storageType.fieldType("binStaging").loadCheapSCode(cb, binStaging), + ), + ), + ) + cb.if_( + key.isEmpty(cb, insertOffset), { + cb.assign(binOffset, key.storageType.loadField(insertOffset, "bin")) + cb += Region.storeInt(binType.loadField(binOffset, "x"), binX) + cb += Region.storeInt(binType.loadField(binOffset, "y"), binY) + cb.assign(insertedPointOffset, key.storageType.loadField(insertOffset, "point")) + pointType.storeAtAddress( + cb, + insertedPointOffset, + region, + pointType.loadCheapSCode(cb, point), + deepCopy = deepCopy, + ) + cb += Region.storeBoolean(key.storageType.loadField(insertOffset, "empty"), false) + cb.assign(treeSize, treeSize + 1) + }, + ) } } - cb.invokeVoid(mb, binX, binY, point) + cb.invokeVoid(mb, cb.this_, binX, binY, point) } def copyFromTree(cb: EmitCodeBuilder, other: AppendOnlyBTree): Unit = { @@ -317,19 +404,29 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq mb.voidWithBuilder { cb => other.foreach(cb) { (cb, v) => - val mb = kb.genEmitMethod("downsample_copy_from_tree_foreach", FastSeq[ParamType](LongInfo), UnitInfo) + val mb = kb.genEmitMethod( + "downsample_copy_from_tree_foreach", + FastSeq[ParamType](LongInfo), + UnitInfo, + ) val value = mb.getCodeParam[Long](1) mb.voidWithBuilder { cb => val point = cb.memoize(key.storageType.loadField(value, "point")) val pointX = cb.memoize(Region.loadDouble(pointType.loadField(point, "x"))) 
val pointY = cb.memoize(Region.loadDouble(pointType.loadField(point, "y"))) - insertIntoTree(cb, xBinCoordinate(cb, pointX), yBinCoordinate(cb, pointY), point, deepCopy = true) + insertIntoTree( + cb, + xBinCoordinate(cb, pointX), + yBinCoordinate(cb, pointY), + point, + deepCopy = true, + ) } - cb.invokeVoid(mb, v) + cb.invokeVoid(mb, cb.this_, v) } } - cb.invokeVoid(mb) + cb.invokeVoid(mb, cb.this_) } def min(a: Code[Double], b: Code[Double]): Code[Double] = @@ -341,7 +438,8 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq def max(a: Code[Double], b: Code[Double]): Code[Double] = Code.invokeStatic2[java.lang.Double, Double, Double, Double]("max", a, b) - def isFinite(a: Code[Double]): Code[Boolean] = Code.invokeStatic1[java.lang.Double, Double, Boolean]("isFinite", a) + def isFinite(a: Code[Double]): Code[Boolean] = + Code.invokeStatic1[java.lang.Double, Double, Boolean]("isFinite", a) def dumpBuffer(cb: EmitCodeBuilder): Unit = { val name = "downsample_dump_buffer" @@ -360,28 +458,50 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq tree.init(cb) copyFromTree(cb, oldRootBTree) cb.assign(i, 0) - cb.while_(i < buffer.size, - { - buffer.loadElement(cb, i).toI(cb).consume(cb, {}, { case point: SBaseStructValue => - val x = point.loadField(cb, "x").get(cb).asFloat64.value - val y = point.loadField(cb, "y").get(cb).asFloat64.value - val pointc = coerce[Long](SingleCodeSCode.fromSCode(cb, point, region).code) - insertIntoTree(cb, xBinCoordinate(cb, x), yBinCoordinate(cb, y), pointc, deepCopy = true) - }) + cb.while_( + i < buffer.size, { + buffer.loadElement(cb, i).toI(cb).consume( + cb, + {}, + { case point: SBaseStructValue => + val x = point.loadField(cb, "x").getOrAssert(cb).asFloat64.value + val y = point.loadField(cb, "y").getOrAssert(cb).asFloat64.value + val pointc = coerce[Long](SingleCodeSCode.fromSCode(cb, point, region).code) + insertIntoTree( + cb, + xBinCoordinate(cb, x), + yBinCoordinate(cb, y), + pointc, + deepCopy = true, + ) + }, + ) cb.assign(i, i + 1) - }) + }, + ) buffer.initialize(cb) cb += oldRegion.invalidate() allocateSpace(cb) } } - cb.invokeVoid(mb) + cb.invokeVoid(mb, cb.this_) } - def insertPointIntoBuffer(cb: EmitCodeBuilder, x: Value[Double], y: Value[Double], point: Value[Long], deepCopy: Boolean): Unit = { + def insertPointIntoBuffer( + cb: EmitCodeBuilder, + x: Value[Double], + y: Value[Double], + point: Value[Long], + deepCopy: Boolean, + ): Unit = { val name = "downsample_insert_into_buffer" - val mb = kb.getOrGenEmitMethod(name, (this, name, deepCopy), FastSeq[ParamType](DoubleInfo, DoubleInfo, LongInfo), UnitInfo) { mb => + val mb = kb.getOrGenEmitMethod( + name, + (this, name, deepCopy), + FastSeq[ParamType](DoubleInfo, DoubleInfo, LongInfo), + UnitInfo, + ) { mb => val x = mb.getCodeParam[Double](1) val y = mb.getCodeParam[Double](2) val point = mb.getCodeParam[Long](3) @@ -396,30 +516,44 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq } } - cb.invokeVoid(mb, x, y, point) + cb.invokeVoid(mb, cb.this_, x, y, point) } def checkBounds(cb: EmitCodeBuilder, xBin: Value[Int], yBin: Value[Int]): Value[Boolean] = { val name = "downsample_check_bounds" - val mb = kb.getOrGenEmitMethod(name, (this, name), FastSeq[ParamType](IntInfo, IntInfo), BooleanInfo) { mb => - val xBin = mb.getCodeParam[Int](1) - val yBin = mb.getCodeParam[Int](2) - val factor = mb.newLocal[Int]("factor") - mb.emit(Code( - factor := nDivisions >> 2, - treeSize.ceq(0) - || (xBin < -factor) 
- || (xBin > nDivisions + factor) - || (yBin < -factor) - || (yBin > nDivisions + factor))) - } + val mb = + kb.getOrGenEmitMethod(name, (this, name), FastSeq[ParamType](IntInfo, IntInfo), BooleanInfo) { + mb => + val xBin = mb.getCodeParam[Int](1) + val yBin = mb.getCodeParam[Int](2) + val factor = mb.newLocal[Int]("factor") + mb.emit(Code( + factor := nDivisions >> 2, + treeSize.ceq(0) + || (xBin < -factor) + || (xBin > nDivisions + factor) + || (yBin < -factor) + || (yBin > nDivisions + factor), + )) + } - mb.invokeCode(cb, xBin, yBin) + cb.invokeCode(mb, cb.this_, xBin, yBin) } - def binAndInsert(cb: EmitCodeBuilder, x: Value[Double], y: Value[Double], point: Value[Long], deepCopy: Boolean): Unit = { + def binAndInsert( + cb: EmitCodeBuilder, + x: Value[Double], + y: Value[Double], + point: Value[Long], + deepCopy: Boolean, + ): Unit = { val name = "downsample_bin_and_insert" - val mb = kb.getOrGenEmitMethod(name, (this, name, deepCopy), FastSeq[ParamType](DoubleInfo, DoubleInfo, LongInfo), UnitInfo) { mb => + val mb = kb.getOrGenEmitMethod( + name, + (this, name, deepCopy), + FastSeq[ParamType](DoubleInfo, DoubleInfo, LongInfo), + UnitInfo, + ) { mb => val x = mb.getCodeParam[Double](1) val y = mb.getCodeParam[Double](2) val point = mb.getCodeParam[Long](3) @@ -430,19 +564,24 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq mb.voidWithBuilder { cb => cb.assign(binX, xBinCoordinate(cb, x)) cb.assign(binY, yBinCoordinate(cb, y)) - cb.if_(checkBounds(cb, binX, binY), + cb.if_( + checkBounds(cb, binX, binY), insertPointIntoBuffer(cb, x, y, point, deepCopy = deepCopy), - insertIntoTree(cb, binX, binY, point, deepCopy = deepCopy) + insertIntoTree(cb, binX, binY, point, deepCopy = deepCopy), ) } } - cb.invokeVoid(mb, x, y, point) + cb.invokeVoid(mb, cb.this_, x, y, point) } def insert(cb: EmitCodeBuilder, x: EmitCode, y: EmitCode, l: EmitCode): Unit = { val name = "downsample_insert" - val mb = kb.getOrGenEmitMethod(name, (this, name), FastSeq[ParamType](x.st.paramType, y.st.paramType, l.emitParamType), UnitInfo) { mb => - + val mb = kb.getOrGenEmitMethod( + name, + (this, name), + FastSeq[ParamType](x.st.paramType, y.st.paramType, l.emitParamType), + UnitInfo, + ) { mb => val pointStaging = mb.newLocal[Long]("pointStaging") mb.voidWithBuilder { cb => val x = mb.getSCodeParam(1) @@ -455,52 +594,73 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq cb.if_((!(isFinite(xx) && isFinite(yy))), cb += Code._return[Unit](Code._empty)) cb.assign(pointStaging, storageType.loadField(off, "pointStaging")) - pointType.fieldType("x").storeAtAddress(cb, pointType.fieldOffset(pointStaging, "x"), region, x, deepCopy = true) - pointType.fieldType("y").storeAtAddress(cb, pointType.fieldOffset(pointStaging, "y"), region, y, deepCopy = true) + pointType.fieldType("x").storeAtAddress( + cb, + pointType.fieldOffset(pointStaging, "x"), + region, + x, + deepCopy = true, + ) + pointType.fieldType("y").storeAtAddress( + cb, + pointType.fieldOffset(pointStaging, "y"), + region, + y, + deepCopy = true, + ) l.toI(cb) - .consume(cb, + .consume( + cb, pointType.setFieldMissing(cb, pointStaging, "label"), { sc => pointType.setFieldPresent(cb, pointStaging, "label") - pointType.fieldType("label").storeAtAddress(cb, pointType.fieldOffset(pointStaging, "label"), region, sc, deepCopy = true) - } + pointType.fieldType("label").storeAtAddress( + cb, + pointType.fieldOffset(pointStaging, "label"), + region, + sc, + deepCopy = true, + ) + }, ) binAndInsert(cb, xx, 
yy, pointStaging, deepCopy = false) } } x.toI(cb) - .consume(cb, - { + .consume( + cb, { /* do nothing if x is missing */ }, { xcode => y.toI(cb) - .consume(cb, - { + .consume( + cb, { /* do nothing if y is missing */ }, - ycode => cb.invokeVoid(mb, xcode, ycode, l) + ycode => cb.invokeVoid(mb, cb.this_, xcode, ycode, l), ) - }) + }, + ) } def deepCopyAndInsertPoint(cb: EmitCodeBuilder, point: Value[Long]): Unit = { val name = "downsample_deep_copy_insert_point" - val mb = kb.getOrGenEmitMethod(name, (this, name), IndexedSeq[ParamType](LongInfo), UnitInfo) { mb => - val point = mb.getCodeParam[Long](1) + val mb = + kb.getOrGenEmitMethod(name, (this, name), IndexedSeq[ParamType](LongInfo), UnitInfo) { mb => + val point = mb.getCodeParam[Long](1) - val x = mb.newLocal[Double]("x") - val y = mb.newLocal[Double]("y") + val x = mb.newLocal[Double]("x") + val y = mb.newLocal[Double]("y") - mb.voidWithBuilder { cb => - cb.assign(x, Region.loadDouble(pointType.loadField(point, "x"))) - cb.assign(y, Region.loadDouble(pointType.loadField(point, "y"))) - binAndInsert(cb, x, y, point, deepCopy = true) + mb.voidWithBuilder { cb => + cb.assign(x, Region.loadDouble(pointType.loadField(point, "x"))) + cb.assign(y, Region.loadDouble(pointType.loadField(point, "y"))) + binAndInsert(cb, x, y, point, deepCopy = true) + } } - } - cb.invokeVoid(mb, point) + cb.invokeVoid(mb, cb.this_, point) } def merge(cb: EmitCodeBuilder, other: DownsampleState): Unit = { @@ -509,30 +669,35 @@ class DownsampleState(val kb: EmitClassBuilder[_], labelType: VirtualTypeWithReq val i = mb.newLocal[Int]("i") mb.emitWithBuilder { cb => cb.assign(i, 0) - cb.while_(i < other.buffer.size, { - val point = SingleCodeSCode.fromSCode(cb, other.buffer.loadElement(cb, i).pv, region) - deepCopyAndInsertPoint(cb, coerce[Long](point.code)) - cb.assign(i, i + 1) - }) + cb.while_( + i < other.buffer.size, { + val point = SingleCodeSCode.fromSCode(cb, other.buffer.loadElement(cb, i).pv, region) + deepCopyAndInsertPoint(cb, coerce[Long](point.code)) + cb.assign(i, i + 1) + }, + ) other.tree.foreach(cb) { (cb, value) => deepCopyAndInsertPoint(cb, cb.memoize(key.storageType.loadField(value, "point"))) } Code._empty } - cb.invokeVoid(mb) + cb.invokeVoid(mb, cb.this_) } - def resultArray(cb: EmitCodeBuilder, region: Value[Region], resType: PCanonicalArray): SIndexablePointerValue = { + def resultArray(cb: EmitCodeBuilder, region: Value[Region], resType: PCanonicalArray) + : SIndexablePointerValue = { // dump all elements into tree for simplicity dumpBuffer(cb) - val (pushElement, finish) = resType.constructFromFunctions(cb, region, treeSize, deepCopy = true) - cb.if_(treeSize > 0, { + val (pushElement, finish) = + resType.constructFromFunctions(cb, region, treeSize, deepCopy = true) + cb.if_( + treeSize > 0, tree.foreach(cb) { (cb, tv) => val pointCode = pointType.loadCheapSCode(cb, key.storageType.loadField(tv, "point")) pushElement(cb, IEmitCode.present(cb, pointCode)) - } - }) + }, + ) finish(cb) } } @@ -544,7 +709,13 @@ object DownsampleAggregator { class DownsampleAggregator(arrayType: VirtualTypeWithReq) extends StagedAggregator { type State = DownsampleState - val resultPType: PCanonicalArray = PCanonicalArray(PCanonicalTuple(required = true, PFloat64(true), PFloat64(true), arrayType.canonicalPType)) + val resultPType: PCanonicalArray = PCanonicalArray(PCanonicalTuple( + required = true, + PFloat64(true), + PFloat64(true), + arrayType.canonicalPType, + )) + val resultEmitType = EmitType(SIndexablePointer(resultPType), true) val initOpTypes: 
Seq[Type] = Array(TInt32) @@ -553,9 +724,10 @@ class DownsampleAggregator(arrayType: VirtualTypeWithReq) extends StagedAggregat protected def _initOp(cb: EmitCodeBuilder, state: State, init: Array[EmitCode]): Unit = { val Array(nDivisions) = init nDivisions.toI(cb) - .consume(cb, + .consume( + cb, cb += Code._fatal[Unit]("downsample: n_divisions may not be missing"), - sc => state.init(cb, sc.asInt.value) + sc => state.init(cb, sc.asInt.value), ) } @@ -565,10 +737,15 @@ class DownsampleAggregator(arrayType: VirtualTypeWithReq) extends StagedAggregat state.insert(cb, x, y, label) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: DownsampleState, other: DownsampleState): Unit = state.merge(cb, other) + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: DownsampleState, + other: DownsampleState, + ): Unit = state.merge(cb, other) - protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { + protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = // deepCopy is handled by state.resultArray IEmitCode.present(cb, state.resultArray(cb, region, resultPType)) - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala b/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala index 24f15e5880e..7b8c16e968a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala @@ -2,33 +2,37 @@ package is.hail.expr.ir.agg import is.hail.annotations.{Region, RegionPool, RegionValue} import is.hail.asm4s.{HailClassLoader, _} -import is.hail.backend.spark.SparkTaskContext import is.hail.backend.{ExecuteContext, HailTaskContext} +import is.hail.backend.spark.SparkTaskContext import is.hail.expr.ir import is.hail.expr.ir._ import is.hail.io.BufferSpec +import is.hail.types.{TypeWithRequiredness, VirtualTypeWithReq} import is.hail.types.physical.stypes.EmitType import is.hail.types.virtual._ -import is.hail.types.{TypeWithRequiredness, VirtualTypeWithReq} import is.hail.utils._ -import org.apache.spark.TaskContext import scala.collection.mutable -import scala.language.{existentials, postfixOps} + +import org.apache.spark.TaskContext class UnsupportedExtraction(msg: String) extends Exception(msg) object AggStateSig { - def apply(op: AggOp, initOpArgs: Seq[IR], seqOpArgs: Seq[IR], r: RequirednessAnalysis): AggStateSig = { - val inits = initOpArgs.map { i => i -> (if (i.typ == TVoid) null else r(i)) } - val seqs = seqOpArgs.map { s => s -> (if (s.typ == TVoid) null else r(s)) } - apply(op, inits, seqs) + def apply(op: AggOp, seqOpArgs: Seq[IR], r: RequirednessAnalysis): AggStateSig = { + val seqs = seqOpArgs.map(s => s -> (if (s.typ == TVoid) null else r(s))) + apply(op, seqs) } - def apply(op: AggOp, initOpArgs: Seq[(IR, TypeWithRequiredness)], seqOpArgs: Seq[(IR, TypeWithRequiredness)]): AggStateSig = { + + // FIXME: factor out requiredness inference part + def apply( + op: AggOp, + seqOpArgs: Seq[(IR, TypeWithRequiredness)], + ): AggStateSig = { val seqVTypes = seqOpArgs.map { case (a, r) => VirtualTypeWithReq(a.typ, r) } op match { case Sum() | Product() => TypedStateSig(seqVTypes.head.setRequired(true)) - case Min() | Max() => TypedStateSig(seqVTypes.head.setRequired(false)) + case Min() | Max() => TypedStateSig(seqVTypes.head.setRequired(false)) case Count() => TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true)) case Take() => TakeStateSig(seqVTypes.head) case 
ReservoirSample() => ReservoirSampleStateSig(seqVTypes.head) @@ -46,15 +50,17 @@ object AggStateSig { val Seq(_, _, labelType) = seqVTypes DownsampleStateSig(labelType) case ImputeType() => ImputeTypeStateSig() - case NDArraySum() => NDArraySumStateSig(seqVTypes.head.setRequired(false)) // set required to false to handle empty aggs + case NDArraySum() => + NDArraySumStateSig( + seqVTypes.head.setRequired(false) + ) // set required to false to handle empty aggs case NDArrayMultiplyAdd() => NDArrayMultiplyAddStateSig(seqVTypes.head.setRequired(false)) case _ => throw new UnsupportedExtraction(op.toString) } } + def grouped(k: IR, aggs: Seq[AggStateSig], r: RequirednessAnalysis): GroupedStateSig = GroupedStateSig(VirtualTypeWithReq(k.typ, r(k)), aggs) - def arrayelements(aggs: Seq[AggStateSig]): ArrayAggStateSig = - ArrayAggStateSig(aggs) def getState(sig: AggStateSig, cb: EmitClassBuilder[_]): AggregatorState = sig match { case TypedStateSig(vt) if vt.t.isPrimitive => new PrimitiveRVAState(Array(vt), cb) @@ -69,48 +75,74 @@ object AggStateSig { case CallStatsStateSig() => new CallStatsState(cb) case ApproxCDFStateSig() => new ApproxCDFState(cb) case ImputeTypeStateSig() => new ImputeTypeState(cb) - case ArrayAggStateSig(nested) => new ArrayElementState(cb, StateTuple(nested.map(sig => AggStateSig.getState(sig, cb)).toArray)) - case GroupedStateSig(kt, nested) => new DictState(cb, kt, StateTuple(nested.map(sig => AggStateSig.getState(sig, cb)).toArray)) + case ArrayAggStateSig(nested) => new ArrayElementState( + cb, + StateTuple(nested.map(sig => AggStateSig.getState(sig, cb)).toArray), + ) + case GroupedStateSig(kt, nested) => + new DictState(cb, kt, StateTuple(nested.map(sig => AggStateSig.getState(sig, cb)).toArray)) case NDArraySumStateSig(nda) => new TypedRegionBackedAggState(nda, cb) case NDArrayMultiplyAddStateSig(nda) => new TypedRegionBackedAggState(nda, cb) - case FoldStateSig(resultEmitType, accumName, otherAccumName, combOpIR) => { + case FoldStateSig(resultEmitType, _, _, _) => val vWithReq = resultEmitType.typeWithRequiredness new TypedRegionBackedAggState(vWithReq, cb) - } case LinearRegressionStateSig() => new LinearRegressionAggregatorState(cb) } } sealed abstract class AggStateSig(val t: Seq[VirtualTypeWithReq], val n: Option[Seq[AggStateSig]]) case class TypedStateSig(pt: VirtualTypeWithReq) extends AggStateSig(Array(pt), None) -case class DownsampleStateSig(labelType: VirtualTypeWithReq) extends AggStateSig(Array(labelType), None) + +case class DownsampleStateSig(labelType: VirtualTypeWithReq) + extends AggStateSig(Array(labelType), None) + case class TakeStateSig(pt: VirtualTypeWithReq) extends AggStateSig(Array(pt), None) -case class TakeByStateSig(vt: VirtualTypeWithReq, kt: VirtualTypeWithReq, so: SortOrder) extends AggStateSig(Array(vt, kt), None) + +case class TakeByStateSig(vt: VirtualTypeWithReq, kt: VirtualTypeWithReq, so: SortOrder) + extends AggStateSig(Array(vt, kt), None) + case class ReservoirSampleStateSig(pt: VirtualTypeWithReq) extends AggStateSig(Array(pt), None) case class DensifyStateSig(vt: VirtualTypeWithReq) extends AggStateSig(Array(vt), None) case class CollectStateSig(pt: VirtualTypeWithReq) extends AggStateSig(Array(pt), None) case class CollectAsSetStateSig(pt: VirtualTypeWithReq) extends AggStateSig(Array(pt), None) case class CallStatsStateSig() extends AggStateSig(Array[VirtualTypeWithReq](), None) case class ImputeTypeStateSig() extends AggStateSig(Array[VirtualTypeWithReq](), None) -case class ArrayAggStateSig(nested: Seq[AggStateSig]) 
extends AggStateSig(Array[VirtualTypeWithReq](), Some(nested)) -case class GroupedStateSig(kt: VirtualTypeWithReq, nested: Seq[AggStateSig]) extends AggStateSig(Array(kt), Some(nested)) + +case class ArrayAggStateSig(nested: Seq[AggStateSig]) + extends AggStateSig(Array[VirtualTypeWithReq](), Some(nested)) + +case class GroupedStateSig(kt: VirtualTypeWithReq, nested: Seq[AggStateSig]) + extends AggStateSig(Array(kt), Some(nested)) + case class ApproxCDFStateSig() extends AggStateSig(Array[VirtualTypeWithReq](), None) case class LinearRegressionStateSig() extends AggStateSig(Array[VirtualTypeWithReq](), None) -case class NDArraySumStateSig(nda: VirtualTypeWithReq) extends AggStateSig(Array[VirtualTypeWithReq](nda), None) { + +case class NDArraySumStateSig(nda: VirtualTypeWithReq) + extends AggStateSig(Array[VirtualTypeWithReq](nda), None) { require(!nda.r.required) } -case class NDArrayMultiplyAddStateSig(nda: VirtualTypeWithReq) extends AggStateSig(Array[VirtualTypeWithReq](nda), None) { + +case class NDArrayMultiplyAddStateSig(nda: VirtualTypeWithReq) + extends AggStateSig(Array[VirtualTypeWithReq](nda), None) { require(!nda.r.required) } -case class FoldStateSig(resultEmitType: EmitType, accumName: String, otherAccumName: String, combOpIR: IR) extends AggStateSig(Array[VirtualTypeWithReq](resultEmitType.typeWithRequiredness), None) +case class FoldStateSig( + resultEmitType: EmitType, + accumName: String, + otherAccumName: String, + combOpIR: IR, +) extends AggStateSig(Array[VirtualTypeWithReq](resultEmitType.typeWithRequiredness), None) object PhysicalAggSig { def apply(op: AggOp, state: AggStateSig): PhysicalAggSig = BasicPhysicalAggSig(op, state) - def unapply(v: PhysicalAggSig): Option[(AggOp, AggStateSig)] = if (v.nestedOps.isEmpty) Some(v.op -> v.state) else None + + def unapply(v: PhysicalAggSig): Option[(AggOp, AggStateSig)] = + if (v.nestedOps.isEmpty) Some(v.op -> v.state) else None } +// A pair of an agg state and an op. If the state is compound, also encodes ops for nested states. 
class PhysicalAggSig(val op: AggOp, val state: AggStateSig, val nestedOps: Array[AggOp]) { val allOps: Array[AggOp] = nestedOps :+ op def initOpTypes: IndexedSeq[Type] = Extract.getAgg(this).initOpTypes.toFastSeq @@ -119,45 +151,74 @@ class PhysicalAggSig(val op: AggOp, val state: AggStateSig, val nestedOps: Array def resultType: Type = emitResultType.virtualType } -case class BasicPhysicalAggSig(override val op: AggOp, override val state: AggStateSig) extends PhysicalAggSig(op, state, Array()) - -case class GroupedAggSig(kt: VirtualTypeWithReq, nested: Seq[PhysicalAggSig]) extends - PhysicalAggSig(Group(), GroupedStateSig(kt, nested.map(_.state)), nested.flatMap(sig => sig.allOps).toArray) -case class AggElementsAggSig(nested: Seq[PhysicalAggSig]) extends - PhysicalAggSig(AggElements(), ArrayAggStateSig(nested.map(_.state)), nested.flatMap(sig => sig.allOps).toArray) - -case class ArrayLenAggSig(knownLength: Boolean, nested: Seq[PhysicalAggSig]) extends - PhysicalAggSig(AggElementsLengthCheck(), ArrayAggStateSig(nested.map(_.state)), nested.flatMap(sig => sig.allOps).toArray) - -class Aggs(original: IR, rewriteMap: Memo[IR], bindingNodesReferenced: Memo[Unit], val init: IR, val seqPerElt: IR, val aggs: Array[PhysicalAggSig]) { +case class BasicPhysicalAggSig(override val op: AggOp, override val state: AggStateSig) + extends PhysicalAggSig(op, state, Array()) + +case class GroupedAggSig(kt: VirtualTypeWithReq, nested: Seq[PhysicalAggSig]) + extends PhysicalAggSig( + Group(), + GroupedStateSig(kt, nested.map(_.state)), + nested.flatMap(sig => sig.allOps).toArray, + ) + +case class AggElementsAggSig(nested: Seq[PhysicalAggSig]) extends PhysicalAggSig( + AggElements(), + ArrayAggStateSig(nested.map(_.state)), + nested.flatMap(sig => sig.allOps).toArray, + ) + +case class ArrayLenAggSig(knownLength: Boolean, nested: Seq[PhysicalAggSig]) extends PhysicalAggSig( + AggElementsLengthCheck(), + ArrayAggStateSig(nested.map(_.state)), + nested.flatMap(sig => sig.allOps).toArray, + ) + +// The result of Extract +class Aggs( + // The pre-extract IR + original: IR, + /* map each descendant of the original ir to its extract result (all contained aggs replaced with + * their results) */ + rewriteMap: Memo[IR], + // nodes which bind variables which are referenced in init op arguments + bindingNodesReferenced: Memo[Unit], + // The extracted void-typed initialization ir + val init: IR, + // The extracted void-typed update ir + val seqPerElt: IR, + // Must be bound to raw aggregators results in postAggIR + val resultRef: Ref, + // All (top-level) aggregators used + val aggs: Array[PhysicalAggSig], +) { val states: Array[AggStateSig] = aggs.map(_.state) val nAggs: Int = aggs.length - - lazy val postAggIR: IR = { + lazy val postAggIR: IR = rewriteMap.lookup(original) - } def rewriteFromInitBindingRoot(f: IR => IR): IR = { - val irNumberMemo = Memo.empty[Int] - var i = 0 - // depth first search -- either DFS or BFS should work here given IR assumptions - VisitIR(original) { x => - irNumberMemo.bind(x, i) - i += 1 - } - if (bindingNodesReferenced.m.isEmpty) { f(rewriteMap.lookup(original)) - // find deepest binding node referenced } else { + val irNumberMemo = Memo.empty[Int] + var i = 0 + // find deepest binding node referenced + // depth first search -- either DFS or BFS should work here given IR assumptions + VisitIR(original) { x => + irNumberMemo.bind(x, i) + i += 1 + } val rewriteRoot = bindingNodesReferenced.m.keys.maxBy(irNumberMemo.lookup) // only support let nodes here -- other binders like stream 
operators are undefined behavior - RewriteTopDown.rewriteTopDown(original, { - case ir if RefEquality(ir) == rewriteRoot => - val Let(bindings, body) = ir - Let(bindings, f(rewriteMap.lookup(body))) - }).asInstanceOf[IR] + RewriteTopDown.rewriteTopDown( + original, + { + case ir if RefEquality(ir) == rewriteRoot => + val Block(bindings, body) = ir + Block(bindings, f(rewriteMap.lookup(body))) + }, + ).asInstanceOf[IR] } } @@ -184,14 +245,16 @@ class Aggs(original: IR, rewriteMap: Memo[IR], bindingNodesReferenced: Memo[Unit aggs.exists(containsBigAggregator) } - def eltOp(ctx: ExecuteContext): IR = seqPerElt - - def deserialize(ctx: ExecuteContext, spec: BufferSpec): ((HailClassLoader, HailTaskContext, Region, Array[Byte]) => Long) = { - val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit](ctx, + def deserialize(ctx: ExecuteContext, spec: BufferSpec) + : ((HailClassLoader, HailTaskContext, Region, Array[Byte]) => Long) = { + val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit]( + ctx, states, FastSeq(), - FastSeq(classInfo[Region]), UnitInfo, - ir.DeserializeAggs(0, 0, spec, states)) + FastSeq(classInfo[Region]), + UnitInfo, + ir.DeserializeAggs(0, 0, spec, states), + ) val fsBc = ctx.fsBc; { (hcl: HailClassLoader, htc: HailTaskContext, aggRegion: Region, bytes: Array[Byte]) => @@ -203,12 +266,16 @@ class Aggs(original: IR, rewriteMap: Memo[IR], bindingNodesReferenced: Memo[Unit } } - def serialize(ctx: ExecuteContext, spec: BufferSpec): (HailClassLoader, HailTaskContext, Region, Long) => Array[Byte] = { - val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit](ctx, + def serialize(ctx: ExecuteContext, spec: BufferSpec) + : (HailClassLoader, HailTaskContext, Region, Long) => Array[Byte] = { + val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit]( + ctx, states, FastSeq(), - FastSeq(classInfo[Region]), UnitInfo, - ir.SerializeAggs(0, 0, spec, states)) + FastSeq(classInfo[Region]), + UnitInfo, + ir.SerializeAggs(0, 0, spec, states), + ) val fsBc = ctx.fsBc; { (hcl: HailClassLoader, htc: HailTaskContext, aggRegion: Region, off: Long) => @@ -220,52 +287,64 @@ class Aggs(original: IR, rewriteMap: Memo[IR], bindingNodesReferenced: Memo[Unit } } - def combOpFSerializedWorkersOnly(ctx: ExecuteContext, spec: BufferSpec): (Array[Byte], Array[Byte]) => Array[Byte] = { - combOpFSerializedFromRegionPool(ctx, spec)(() => { + def combOpFSerializedWorkersOnly(ctx: ExecuteContext, spec: BufferSpec) + : (Array[Byte], Array[Byte]) => Array[Byte] = { + combOpFSerializedFromRegionPool(ctx, spec) { () => val htc = SparkTaskContext.get() val hcl = theHailClassLoaderForSparkWorkers if (htc == null) { - throw new UnsupportedOperationException(s"Can't get htc. On worker = ${TaskContext.get != null}") + throw new UnsupportedOperationException( + s"Can't get htc. 
On worker = ${TaskContext.get != null}" + ) } (htc.getRegionPool(), hcl, htc) - }) + } } - def combOpFSerializedFromRegionPool(ctx: ExecuteContext, spec: BufferSpec): (() => (RegionPool, HailClassLoader, HailTaskContext)) => ((Array[Byte], Array[Byte]) => Array[Byte]) = { - val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit](ctx, + def combOpFSerializedFromRegionPool(ctx: ExecuteContext, spec: BufferSpec) + : (() => (RegionPool, HailClassLoader, HailTaskContext)) => ( + (Array[Byte], Array[Byte]) => Array[Byte], + ) = { + val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit]( + ctx, states ++ states, FastSeq(), - FastSeq(classInfo[Region]), UnitInfo, + FastSeq(classInfo[Region]), + UnitInfo, Begin(FastSeq( ir.DeserializeAggs(0, 0, spec, states), ir.DeserializeAggs(nAggs, 1, spec, states), Begin(aggs.zipWithIndex.map { case (sig, i) => CombOp(i, i + nAggs, sig) }), - SerializeAggs(0, 0, spec, states) - ))) + SerializeAggs(0, 0, spec, states), + )), + ) val fsBc = ctx.fsBc - poolGetter: (() => (RegionPool, HailClassLoader, HailTaskContext)) => { (bytes1: Array[Byte], bytes2: Array[Byte]) => - val (pool, hcl, htc) = poolGetter() - pool.scopedSmallRegion { r => - val f2 = f(hcl, fsBc.value, htc, r) - f2.newAggState(r) - f2.setSerializedAgg(0, bytes1) - f2.setSerializedAgg(1, bytes2) - f2(r) - f2.storeAggsToRegion() - f2.getSerializedAgg(0) - } + poolGetter: (() => (RegionPool, HailClassLoader, HailTaskContext)) => { + (bytes1: Array[Byte], bytes2: Array[Byte]) => + val (pool, hcl, htc) = poolGetter() + pool.scopedSmallRegion { r => + val f2 = f(hcl, fsBc.value, htc, r) + f2.newAggState(r) + f2.setSerializedAgg(0, bytes1) + f2.setSerializedAgg(1, bytes2) + f2(r) + f2.storeAggsToRegion() + f2.getSerializedAgg(0) + } } } // Takes ownership of both input regions, and returns ownership of region in // resulting RegionValue. 
- def combOpF(ctx: ExecuteContext, spec: BufferSpec): (HailClassLoader, HailTaskContext, RegionValue, RegionValue) => RegionValue = { + def combOpF(ctx: ExecuteContext, spec: BufferSpec) + : (HailClassLoader, HailTaskContext, RegionValue, RegionValue) => RegionValue = { val fb = ir.EmitFunctionBuilder[AsmFunction4RegionLongRegionLongLong]( ctx, "combOpF3", FastSeq[ParamType](classInfo[Region], LongInfo, classInfo[Region], LongInfo), - LongInfo) + LongInfo, + ) val leftAggRegion = fb.genFieldThisRef[Region]("agg_combine_left_top_region") val leftAggOff = fb.genFieldThisRef[Long]("agg_combine_left_off") @@ -281,7 +360,8 @@ class Aggs(original: IR, rewriteMap: Memo[IR], bindingNodesReferenced: Memo[Unit val leftStates = agg.StateTuple(states.map(s => AggStateSig.getState(s, fb.ecb))) val leftAggState = new agg.TupleAggregatorState(fb.ecb, leftStates, leftAggRegion, leftAggOff) val rightStates = agg.StateTuple(states.map(s => AggStateSig.getState(s, fb.ecb))) - val rightAggState = new agg.TupleAggregatorState(fb.ecb, rightStates, rightAggRegion, rightAggOff) + val rightAggState = + new agg.TupleAggregatorState(fb.ecb, rightStates, rightAggRegion, rightAggOff) leftStates.createStates(cb) leftAggState.load(cb) @@ -291,7 +371,7 @@ class Aggs(original: IR, rewriteMap: Memo[IR], bindingNodesReferenced: Memo[Unit for (i <- 0 until nAggs) { val rvAgg = agg.Extract.getAgg(aggs(i)) - rvAgg.combOp(ctx, cb, leftAggState.states(i), rightAggState.states(i)) + rvAgg.combOp(ctx, cb, leftAggRegion, leftAggState.states(i), rightAggState.states(i)) } leftAggState.store(cb) @@ -310,60 +390,43 @@ class Aggs(original: IR, rewriteMap: Memo[IR], bindingNodesReferenced: Memo[Unit } } - def results: IR = { + def results: IR = ResultOp.makeTuple(aggs) - } } object Extract { - def partitionDependentLets(lets: Array[AggLet], name: String): (Array[AggLet], Array[AggLet]) = { + // All lets whose value depends on `name` (either directly or transitively through previous lets) + // are returned in the first array, the rest are in the second. + /* TODO: this is only being used to do ad hoc code motion. Remove when we have a real code motion + * pass. 
*/ + private def partitionDependentLets(lets: Array[(String, IR)], name: String) + : (Array[(String, IR)], Array[(String, IR)]) = { val depBindings = mutable.HashSet.empty[String] depBindings += name - val dep = new BoxedArrayBuilder[AggLet] - val indep = new BoxedArrayBuilder[AggLet] - - lets.foreach { l => - val fv = FreeVariables(l.value, supportsAgg = false, supportsScan = false) - if (fv.eval.m.keysIterator.exists(k => depBindings.contains(k))) { - dep += l - depBindings += l.name - } else - indep += l + val dep = new BoxedArrayBuilder[(String, IR)] + val indep = new BoxedArrayBuilder[(String, IR)] + + lets.foreach { case x @ (name, value) => + if (value.typ == TVoid || value.isInstanceOf[ResultOp] || value.isInstanceOf[AggStateValue]) { + /* if the value is side effecting, or implicitly reads the aggregator state, we can't lift + * it */ + dep += x + } else { + val fv = FreeVariables(value, supportsAgg = false, supportsScan = false) + if (fv.eval.m.keysIterator.exists(k => depBindings.contains(k))) { + dep += x + depBindings += name + } else { + indep += x + } + } } (dep.result(), indep.result()) } - def addLets(ir: IR, lets: Array[AggLet]): IR = { - assert(lets.areDistinct()) - Let(lets.map(al => al.name -> al.value), ir) - } - - def getResultType(aggSig: AggSignature): Type = aggSig match { - case AggSignature(Sum(), _, Seq(t)) => t - case AggSignature(Product(), _, Seq(t)) => t - case AggSignature(Min(), _, Seq(t)) => t - case AggSignature(Max(), _, Seq(t)) => t - case AggSignature(Count(), _, _) => TInt64 - case AggSignature(Take(), _, Seq(t)) => TArray(t) - case AggSignature(ReservoirSample(), _, Seq(t)) => TArray(t) - case AggSignature(CallStats(), _, _) => CallStatsState.resultPType.virtualType - case AggSignature(TakeBy(_), _, Seq(value, key)) => TArray(value) - case AggSignature(PrevNonnull(), _, Seq(t)) => t - case AggSignature(CollectAsSet(), _, Seq(t)) => TSet(t) - case AggSignature(Collect(), _, Seq(t)) => TArray(t) - case AggSignature(Densify(), _, Seq(t)) => t - case AggSignature(ImputeType(), _, _) => ImputeTypeState.resultEmitType.virtualType - case AggSignature(LinearRegression(), _, _) => - LinearRegressionAggregator.resultPType.virtualType - case AggSignature(ApproxCDF(), _, _) => QuantilesAggregator.resultPType.virtualType - case AggSignature(Downsample(), _, Seq(_, _, label)) => DownsampleAggregator.resultType - case AggSignature(NDArraySum(), _, Seq(t)) => t - case AggSignature(NDArrayMultiplyAdd(), _, Seq(a : TNDArray, _)) => a - case _ => throw new UnsupportedExtraction(aggSig.toString) - } - + // FIXME: move this to StagedAggregator? 
def getAgg(sig: PhysicalAggSig): StagedAggregator = sig match { case PhysicalAggSig(Sum(), TypedStateSig(t)) => new SumAggregator(t.t) case PhysicalAggSig(Product(), TypedStateSig(t)) => new ProductAggregator(t.t) @@ -373,16 +436,19 @@ object Extract { case PhysicalAggSig(Count(), TypedStateSig(_)) => CountAggregator case PhysicalAggSig(Take(), TakeStateSig(t)) => new TakeAggregator(t) case PhysicalAggSig(TakeBy(_), TakeByStateSig(v, k, _)) => new TakeByAggregator(v, k) - case PhysicalAggSig(ReservoirSample(), ReservoirSampleStateSig(t)) => new ReservoirSampleAggregator(t) + case PhysicalAggSig(ReservoirSample(), ReservoirSampleStateSig(t)) => + new ReservoirSampleAggregator(t) case PhysicalAggSig(Densify(), DensifyStateSig(v)) => new DensifyAggregator(v) case PhysicalAggSig(CallStats(), CallStatsStateSig()) => new CallStatsAggregator() case PhysicalAggSig(Collect(), CollectStateSig(t)) => new CollectAggregator(t) case PhysicalAggSig(CollectAsSet(), CollectAsSetStateSig(t)) => new CollectAsSetAggregator(t) - case PhysicalAggSig(LinearRegression(), LinearRegressionStateSig()) => new LinearRegressionAggregator() + case PhysicalAggSig(LinearRegression(), LinearRegressionStateSig()) => + new LinearRegressionAggregator() case PhysicalAggSig(ApproxCDF(), ApproxCDFStateSig()) => new ApproxCDFAggregator - case PhysicalAggSig(Downsample(), DownsampleStateSig(labelType)) => new DownsampleAggregator(labelType) + case PhysicalAggSig(Downsample(), DownsampleStateSig(labelType)) => + new DownsampleAggregator(labelType) case PhysicalAggSig(ImputeType(), ImputeTypeStateSig()) => new ImputeTypeAggregator() - case ArrayLenAggSig(knownLength, nested) => //FIXME nested things shouldn't need to know state + case ArrayLenAggSig(knownLength, nested) => // FIXME nested things shouldn't need to know state new ArrayElementLengthCheckAggregator(nested.map(getAgg).toArray, knownLength) case AggElementsAggSig(nested) => new ArrayElementwiseOpAggregator(nested.map(getAgg).toArray) @@ -396,103 +462,170 @@ object Extract { new FoldAggregator(res, accumName, otherAccumName, combOpIR) } - def apply(ir: IR, resultName: String, r: RequirednessAnalysis, isScan: Boolean = false): Aggs = { + def apply(ir: IR, r: RequirednessAnalysis, isScan: Boolean = false): Aggs = { val ab = new BoxedArrayBuilder[(InitOp, PhysicalAggSig)]() - val seq = new BoxedArrayBuilder[IR]() - val let = new BoxedArrayBuilder[AggLet]() - val ref = Ref(resultName, null) + val seq = new BoxedArrayBuilder[(String, IR)]() + val ref = Ref(genUID(), null) val memo = mutable.Map.empty[IR, Int] val bindingNodesReferenced = Memo.empty[Unit] val rewriteMap = Memo.empty[IR] - extract(ir, Env.empty, bindingNodesReferenced, rewriteMap, ab, seq, let, memo, ref, r, isScan) + extract( + ir, + BindingEnv.empty, + bindingNodesReferenced, + rewriteMap, + ab, + seq, + memo, + ref, + r, + isScan, + ) val (initOps, pAggSigs) = ab.result().unzip val rt = TTuple(initOps.map(_.aggSig.resultType): _*) ref._typ = rt - new Aggs(ir, rewriteMap, bindingNodesReferenced, Begin(initOps), addLets(Begin(seq.result()), let.result()), pAggSigs) + new Aggs( + ir, + rewriteMap, + bindingNodesReferenced, + Begin(initOps), + Let.void(seq.result()), + ref, + pAggSigs, + ) } - private def extract(ir: IR, env: Env[RefEquality[IR]], bindingNodesReferenced: Memo[Unit], rewriteMap: Memo[IR], ab: BoxedArrayBuilder[(InitOp, PhysicalAggSig)], seqBuilder: BoxedArrayBuilder[IR], letBuilder: BoxedArrayBuilder[AggLet], memo: mutable.Map[IR, Int], result: IR, r: RequirednessAnalysis, isScan: Boolean): IR = { 
- // the env argument here tracks variable bindings that are accessible to init op arguments - + private def extract( + ir: IR, + // bindings in scope for init op arguments + env: BindingEnv[RefEquality[IR]], + // nodes which bind variables which are referenced in init op arguments + bindingNodesReferenced: Memo[Unit], + /* map each descendant of the original ir to its extract result (all contained aggs replaced with * their results) */ + rewriteMap: Memo[IR], + // set of contained aggs, and the init op for each + ab: BoxedArrayBuilder[(InitOp, PhysicalAggSig)], + // set of updates for contained aggs + seqBuilder: BoxedArrayBuilder[(String, IR)], + /* Map each contained ApplyAggOp, ApplyScanOp, or AggFold, to the index of the corresponding agg * state, used to perform CSE on agg ops */ + memo: mutable.Map[IR, Int], + // a reference to the tuple of results of contained aggs + result: IR, + r: RequirednessAnalysis, + isScan: Boolean, + ): IR = { def newMemo: mutable.Map[IR, Int] = mutable.Map.empty[IR, Int] + // For each free variable in each init op arg, add the binding site to bindingNodesReferenced def bindInitArgRefs(initArgs: IndexedSeq[IR]): Unit = { initArgs.foreach { arg => val fv = FreeVariables(arg, false, false).eval fv.m.keys - .flatMap { k => env.lookupOption(k) } + .flatMap(k => env.eval.lookupOption(k)) .foreach(bindingNodesReferenced.bind(_, ())) } } val newNode = ir match { - case x@AggLet(name, value, body, _) => - letBuilder += x - this.extract(body, env, bindingNodesReferenced, rewriteMap, ab, seqBuilder, letBuilder, memo, result, r, isScan) + case x @ Block(bindings, body) => + var newEnv = env + val newBindings = Array.newBuilder[Binding] + newBindings.sizeHint(bindings) + for (binding <- bindings) binding match { + case Binding(name, value, Scope.EVAL) => + val newValue = this.extract(value, newEnv, bindingNodesReferenced, rewriteMap, ab, + seqBuilder, memo, result, r, isScan) + newBindings += Binding(name, newValue) + newEnv = env.bindEval(name, RefEquality(x)) + case Binding(name, value, _) => + seqBuilder += name -> value + } + val newBody = this.extract(body, newEnv, bindingNodesReferenced, rewriteMap, ab, seqBuilder, + memo, result, r, isScan) + Block(newBindings.result(), newBody) case x: ApplyAggOp if !isScan => - val idx = memo.getOrElseUpdate(x, { - val i = ab.length - val op = x.aggSig.op - bindInitArgRefs(x.initOpArgs) - val state = PhysicalAggSig(op, AggStateSig(op, x.initOpArgs, x.seqOpArgs, r)) - ab += InitOp(i, x.initOpArgs, state) -> state - seqBuilder += SeqOp(i, x.seqOpArgs, state) - i - }) + val idx = memo.getOrElseUpdate( + x, { + val i = ab.length + val op = x.aggSig.op + bindInitArgRefs(x.initOpArgs) + val state = PhysicalAggSig(op, AggStateSig(op, x.seqOpArgs, r)) + ab += InitOp(i, x.initOpArgs, state) -> state + seqBuilder += "__void" -> SeqOp(i, x.seqOpArgs, state) + i + }, + ) GetTupleElement(result, idx) case x: ApplyScanOp if isScan => - val idx = memo.getOrElseUpdate(x, { - val i = ab.length - val op = x.aggSig.op - bindInitArgRefs(x.initOpArgs) - val state = PhysicalAggSig(op, AggStateSig(op, x.initOpArgs, x.seqOpArgs, r)) - ab += InitOp(i, x.initOpArgs, state) -> state - seqBuilder += SeqOp(i, x.seqOpArgs, state) - i - }) + val idx = memo.getOrElseUpdate( + x, { + val i = ab.length + val op = x.aggSig.op + bindInitArgRefs(x.initOpArgs) + val state = PhysicalAggSig(op, AggStateSig(op, x.seqOpArgs, r)) + ab += InitOp(i, x.initOpArgs, state) -> state + seqBuilder += "__void" -> SeqOp(i, x.seqOpArgs, state) + i + }, + )
GetTupleElement(result, idx) - case x@AggFold(zero, seqOp, combOp, accumName, otherAccumName, isScan) => - val idx = memo.getOrElseUpdate(x, { - val i = ab.length - val initOpArgs = IndexedSeq(zero) - bindInitArgRefs(initOpArgs) - val seqOpArgs = IndexedSeq(seqOp) - val op = Fold() - val resultEmitType = r(x).canonicalEmitType(x.typ) - val foldStateSig = FoldStateSig(resultEmitType, accumName, otherAccumName, combOp) - val signature = PhysicalAggSig(op, foldStateSig) - ab += InitOp(i, initOpArgs, signature) -> signature - // So seqOp has to be able to reference accumName. - val seqWithLet = Let(FastSeq(accumName -> ResultOp(i, signature)), SeqOp(i, seqOpArgs, signature)) - seqBuilder += seqWithLet - i - }) + case x @ AggFold(zero, seqOp, combOp, accumName, otherAccumName, _) => + val idx = memo.getOrElseUpdate( + x, { + val i = ab.length + val initOpArgs = IndexedSeq(zero) + bindInitArgRefs(initOpArgs) + val seqOpArgs = IndexedSeq(seqOp) + val op = Fold() + val resultEmitType = r(x).canonicalEmitType(x.typ) + val foldStateSig = FoldStateSig(resultEmitType, accumName, otherAccumName, combOp) + val signature = PhysicalAggSig(op, foldStateSig) + ab += InitOp(i, initOpArgs, signature) -> signature + // So seqOp has to be able to reference accumName. + seqBuilder += accumName -> ResultOp(i, signature) + seqBuilder += "__void" -> SeqOp(i, seqOpArgs, signature) + i + }, + ) GetTupleElement(result, idx) case AggFilter(cond, aggIR, _) => - val newSeq = new BoxedArrayBuilder[IR]() - val newLet = new BoxedArrayBuilder[AggLet]() - val transformed = this.extract(aggIR, env, bindingNodesReferenced, rewriteMap, ab, newSeq, newLet, newMemo, result, r, isScan) + val newSeq = new BoxedArrayBuilder[(String, IR)]() + val transformed = this.extract(aggIR, env, bindingNodesReferenced, rewriteMap, ab, newSeq, + newMemo, result, r, isScan) - seqBuilder += If(cond, addLets(Begin(newSeq.result()), newLet.result()), Begin(FastSeq[IR]())) + seqBuilder += "__void" -> If(cond, Let.void(newSeq.result()), Void()) transformed case AggExplode(array, name, aggBody, _) => - val newSeq = new BoxedArrayBuilder[IR]() - val newLet = new BoxedArrayBuilder[AggLet]() - val transformed = this.extract(aggBody, env, bindingNodesReferenced, rewriteMap, ab, newSeq, newLet, newMemo, result, r, isScan) + val newSeq = new BoxedArrayBuilder[(String, IR)]() + val transformed = this.extract(aggBody, env, bindingNodesReferenced, rewriteMap, ab, newSeq, + newMemo, result, r, isScan) - val (dependent, independent) = partitionDependentLets(newLet.result(), name) - letBuilder ++= independent - seqBuilder += StreamFor(array, name, addLets(Begin(newSeq.result()), dependent)) + val (dependent, independent) = partitionDependentLets(newSeq.result(), name) + seqBuilder ++= independent + seqBuilder += "__void" -> StreamFor(array, name, Let.void(dependent)) transformed case AggGroupBy(key, aggIR, _) => val newAggs = new BoxedArrayBuilder[(InitOp, PhysicalAggSig)]() - val newSeq = new BoxedArrayBuilder[IR]() + val newSeq = new BoxedArrayBuilder[(String, IR)]() val newRef = Ref(genUID(), null) - val transformed = this.extract(aggIR, env, bindingNodesReferenced, rewriteMap, newAggs, newSeq, letBuilder, newMemo, GetField(newRef, "value"), r, isScan) + val transformed = this.extract( + aggIR, + env, + bindingNodesReferenced, + rewriteMap, + newAggs, + newSeq, + newMemo, + GetField(newRef, "value"), + r, + isScan, + ) val i = ab.length val (initOps, pAggSigs) = newAggs.result().unzip @@ -503,19 +636,26 @@ object Extract { val groupState = 
AggStateSig.grouped(key, pAggSigs.map(_.state), r) val groupSig = GroupedAggSig(groupState.kt, pAggSigs.toFastSeq) ab += InitOp(i, FastSeq(Begin(initOps)), groupSig) -> groupSig - seqBuilder += SeqOp(i, FastSeq(key, Begin(newSeq.result().toFastSeq)), groupSig) - - ToDict(StreamMap(ToStream(GetTupleElement(result, i)), newRef.name, MakeTuple.ordered(FastSeq(GetField(newRef, "key"), transformed)))) + seqBuilder += "__void" -> SeqOp( + i, + FastSeq(key, Let.void(newSeq.result())), + groupSig, + ) + + ToDict(StreamMap( + ToStream(GetTupleElement(result, i)), + newRef.name, + MakeTuple.ordered(FastSeq(GetField(newRef, "key"), transformed)), + )) case AggArrayPerElement(a, elementName, indexName, aggBody, knownLength, _) => val newAggs = new BoxedArrayBuilder[(InitOp, PhysicalAggSig)]() - val newSeq = new BoxedArrayBuilder[IR]() - val newLet = new BoxedArrayBuilder[AggLet]() + val newSeq = new BoxedArrayBuilder[(String, IR)]() val newRef = Ref(genUID(), null) - val transformed = this.extract(aggBody, env, bindingNodesReferenced, rewriteMap, newAggs, newSeq, newLet, newMemo, newRef, r, isScan) + val transformed = this.extract(aggBody, env, bindingNodesReferenced, rewriteMap, newAggs, + newSeq, newMemo, newRef, r, isScan) - val (dependent, independent) = partitionDependentLets(newLet.result(), elementName) - letBuilder ++= independent + val (dependent, independent) = partitionDependentLets(newSeq.result(), elementName) val i = ab.length val (initOps, pAggSigs) = newAggs.result().unzip @@ -528,20 +668,28 @@ object Extract { val aRef = Ref(genUID(), a.typ) - ab += InitOp(i, knownLength.map(FastSeq(_)).getOrElse(FastSeq[IR]()) :+ Begin(initOps), checkSig) -> checkSig + ab += InitOp( + i, + knownLength.map(FastSeq(_)).getOrElse(FastSeq[IR]()) :+ Begin(initOps), + checkSig, + ) -> checkSig + + seqBuilder ++= independent + seqBuilder += aRef.name -> a + seqBuilder += "__void" -> SeqOp(i, FastSeq(ArrayLen(aRef)), checkSig) seqBuilder += - Let( - FastSeq(aRef.name -> a), - Begin(FastSeq( - SeqOp(i, FastSeq(ArrayLen(aRef)), checkSig), - StreamFor( - StreamRange(I32(0), ArrayLen(aRef), I32(1)), - indexName, - Let( - FastSeq(elementName -> ArrayRef(aRef, Ref(indexName, TInt32))), - addLets(SeqOp(i, - FastSeq(Ref(indexName, TInt32), Begin(newSeq.result().toFastSeq)), - eltSig), dependent)))))) + "__void" -> StreamFor( + StreamRange(I32(0), ArrayLen(aRef), I32(1)), + indexName, + SeqOp( + i, + FastSeq( + Ref(indexName, TInt32), + Let.void((elementName, ArrayRef(aRef, Ref(indexName, TInt32))) +: dependent), + ), + eltSig, + ), + ) val rUID = Ref(genUID(), rt) Let( @@ -551,7 +699,10 @@ object Extract { indexName, Let( FastSeq(newRef.name -> ArrayRef(rUID, Ref(indexName, TInt32))), - transformed)))) + transformed, + ), + )), + ) case x: StreamAgg => assert(!ContainsScan(x)) @@ -560,14 +711,12 @@ object Extract { assert(!ContainsAgg(x)) x case x => - ir.mapChildrenWithIndex { case (child: IR, i) => - val nb = Bindings(x, i) - val newEnv = if (nb.nonEmpty) { - val re = RefEquality(x) - env.bindIterable(nb.map { case (name, _) => (name, re)}) - } else env - - this.extract(child, newEnv, bindingNodesReferenced, rewriteMap, ab, seqBuilder, letBuilder, memo, result, r, isScan) + x.mapChildrenWithIndex { case (child: IR, i) => + val newEnv = + Bindings.segregated(x, i, env).mapNewBindings((_, _) => RefEquality(x)).unified + + this.extract(child, newEnv, bindingNodesReferenced, rewriteMap, ab, seqBuilder, memo, + result, r, isScan) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/FoldAggregator.scala 
b/hail/src/main/scala/is/hail/expr/ir/agg/FoldAggregator.scala index c14c2cb3b5d..db22c026634 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/FoldAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/FoldAggregator.scala @@ -1,15 +1,23 @@ package is.hail.expr.ir.agg + import is.hail.annotations.Region -import is.hail.asm4s.{AsmFunction1RegionLong, AsmFunction2, LongInfo, Value} +import is.hail.asm4s.Value import is.hail.backend.ExecuteContext -import is.hail.expr.ir.{Compile, Emit, EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitContext, EmitEnv, EmitMethodBuilder, Env, IEmitCode, IR, SCodeEmitParamType} -import is.hail.types.physical.PType -import is.hail.types.physical.stypes.{EmitType, SType} +import is.hail.expr.ir.{ + Emit, EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitContext, EmitEnv, EmitMethodBuilder, Env, + IEmitCode, IR, +} +import is.hail.types.physical.stypes.EmitType import is.hail.types.virtual.Type // (IR => T), seq op (IR T => T), and comb op (IR (T,T) => T) -class FoldAggregator(val resultEmitType: EmitType, accumName: String, otherAccumName: String, combOpIR: IR) extends StagedAggregator { +class FoldAggregator( + val resultEmitType: EmitType, + accumName: String, + otherAccumName: String, + combOpIR: IR, +) extends StagedAggregator { override type State = TypedRegionBackedAggState override val initOpTypes: Seq[Type] = IndexedSeq(resultEmitType.virtualType) @@ -26,20 +34,28 @@ class FoldAggregator(val resultEmitType: EmitType, accumName: String, otherAccum elt.toI(cb).consume(cb, state.storeMissing(cb), sv => state.storeNonmissing(cb, sv)) } - override protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: TypedRegionBackedAggState, other: TypedRegionBackedAggState): Unit = { - + override protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: TypedRegionBackedAggState, + other: TypedRegionBackedAggState, + ): Unit = { val stateEV = state.get(cb).memoizeField(cb, "fold_agg_comb_op_state") val otherEV = other.get(cb).memoizeField(cb, "fold_agg_comb_op_other") val env = EmitEnv(Env.apply((accumName, stateEV), (otherAccumName, otherEV)), IndexedSeq()) - val pEnv = Env.apply((accumName, resultEmitType.storageType), (otherAccumName, resultEmitType.storageType)) + val pEnv = Env.apply( + (accumName, resultEmitType.storageType), + (otherAccumName, resultEmitType.storageType), + ) val emitCtx = EmitContext.analyze(ctx, combOpIR, pEnv) val emit = new Emit[Any](emitCtx, cb.emb.ecb.asInstanceOf[EmitClassBuilder[Any]]) - val ec = emit.emit(combOpIR, cb.emb.asInstanceOf[EmitMethodBuilder[Any]], env, None) + val ec = emit.emit(combOpIR, cb.emb.asInstanceOf[EmitMethodBuilder[Any]], region, env, None) ec.toI(cb).consume(cb, state.storeMissing(cb), sv => state.storeNonmissing(cb, sv)) } - override protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { - state.get(cb).map(cb){ sv => sv.copyToRegion(cb, region, sv.st)} - } + override protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]) + : IEmitCode = + state.get(cb).map(cb)(sv => sv.copyToRegion(cb, region, sv.st)) } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/GroupedAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/GroupedAggregator.scala index fe3f23504cd..b5cdc4c41be 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/GroupedAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/GroupedAggregator.scala @@ -3,29 +3,41 @@ package is.hail.expr.ir.agg import 
is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext +import is.hail.expr.ir.{ + EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitRegion, EmitValue, IEmitCode, ParamType, +} import is.hail.expr.ir.orderings.CodeOrdering -import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitRegion, EmitValue, IEmitCode, ParamType} import is.hail.io._ import is.hail.types.VirtualTypeWithReq import is.hail.types.encoded.EType import is.hail.types.physical._ -import is.hail.types.physical.stypes.concrete.SIndexablePointer import is.hail.types.physical.stypes.{EmitType, SValue} +import is.hail.types.physical.stypes.concrete.SIndexablePointer import is.hail.types.virtual.{TVoid, Type} import is.hail.utils._ -class GroupedBTreeKey(kt: PType, kb: EmitClassBuilder[_], region: Value[Region], val offset: Value[Long], states: StateTuple) extends BTreeKey { - override val storageType: PStruct = PCanonicalStruct(required = true, +class GroupedBTreeKey( + kt: PType, + kb: EmitClassBuilder[_], + region: Value[Region], + val offset: Value[Long], + states: StateTuple, +) extends BTreeKey { + override val storageType: PStruct = PCanonicalStruct( + required = true, "kt" -> kt, "regionIdx" -> PInt32(true), - "container" -> states.storageType) + "container" -> states.storageType, + ) + override val compType: PType = kt override def compWithKey(cb: EmitCodeBuilder, off: Value[Long], k: EmitValue): Value[Int] = { - val mb = kb.getOrGenEmitMethod("compWithKey", + val mb = kb.getOrGenEmitMethod( + "compWithKey", ("compWithKey_grouped_btree", kt, k.emitType), FastSeq[ParamType](typeInfo[Long], k.emitParamType), - typeInfo[Int] + typeInfo[Int], ) { mb => val comp = kb.getOrderingFunction(compType.sType, k.st, CodeOrdering.Compare()) mb.emitWithBuilder { cb => @@ -35,12 +47,13 @@ class GroupedBTreeKey(kt: PType, kb: EmitClassBuilder[_], region: Value[Region], comp(cb, ev1, ev2) } } - cb.invokeCode(mb, off, k) + cb.invokeCode(mb, cb.this_, off, k) } val regionIdx: Value[Int] = new Value[Int] { def get: Code[Int] = Region.loadInt(storageType.fieldOffset(offset, 1)) } + val container = new TupleAggregatorState(kb, states, region, containerOffset(offset), regionIdx) def isKeyMissing(cb: EmitCodeBuilder, off: Code[Long]): Value[Boolean] = @@ -52,15 +65,15 @@ class GroupedBTreeKey(kt: PType, kb: EmitClassBuilder[_], region: Value[Region], def initValue(cb: EmitCodeBuilder, destc: Code[Long], k: EmitCode, rIdx: Code[Int]): Unit = { val dest = cb.newLocal("ga_init_value_dest", destc) k.toI(cb) - .consume(cb, - { - storageType.setFieldMissing(cb, dest, 0) - }, + .consume( + cb, + storageType.setFieldMissing(cb, dest, 0), { sc => storageType.setFieldPresent(cb, dest, 0) storageType.fieldType("kt") .storeAtAddress(cb, storageType.fieldOffset(dest, 0), region, sc, deepCopy = true) - }) + }, + ) storeRegionIdx(cb, dest, rIdx) container.newState(cb) } @@ -69,7 +82,8 @@ class GroupedBTreeKey(kt: PType, kb: EmitClassBuilder[_], region: Value[Region], def storeStates(cb: EmitCodeBuilder): Unit = container.store(cb) - def copyStatesFrom(cb: EmitCodeBuilder, srcOff: Value[Long]): Unit = container.copyFrom(cb, srcOff) + def copyStatesFrom(cb: EmitCodeBuilder, srcOff: Value[Long]): Unit = + container.copyFrom(cb, srcOff) def storeRegionIdx(cb: EmitCodeBuilder, off: Code[Long], idx: Code[Int]): Unit = cb += Region.storeInt(storageType.fieldOffset(off, 1), idx) @@ -85,27 +99,46 @@ class GroupedBTreeKey(kt: PType, kb: EmitClassBuilder[_], region: Value[Region], cb += 
Region.storeInt(storageType.fieldOffset(off, 1), -1) override def copy(cb: EmitCodeBuilder, src: Code[Long], dest: Code[Long]): Unit = - storageType.storeAtAddress(cb, dest, region, storageType.loadCheapSCode(cb, src), deepCopy = false) - - override def deepCopy(cb: EmitCodeBuilder, er: EmitRegion, dest: Code[Long], srcCode: Code[Long]): Unit = { + storageType.storeAtAddress( + cb, + dest, + region, + storageType.loadCheapSCode(cb, src), + deepCopy = false, + ) + + override def deepCopy(cb: EmitCodeBuilder, er: EmitRegion, dest: Code[Long], srcCode: Code[Long]) + : Unit = { val src = cb.newLocal("ga_deep_copy_src", srcCode) - storageType.storeAtAddress(cb, dest, region, storageType.loadCheapSCode(cb, src), deepCopy = true) + storageType.storeAtAddress( + cb, + dest, + region, + storageType.loadCheapSCode(cb, src), + deepCopy = true, + ) container.copyFrom(cb, containerOffset(src)) container.store(cb) } - override def compKeys(cb: EmitCodeBuilder, k1: EmitValue, k2: EmitValue): Value[Int] = { + override def compKeys(cb: EmitCodeBuilder, k1: EmitValue, k2: EmitValue): Value[Int] = kb.getOrderingFunction(k1.st, k2.st, CodeOrdering.Compare())(cb, k1, k2) - } override def loadCompKey(cb: EmitCodeBuilder, off: Value[Long]): EmitValue = cb.memoize(IEmitCode(cb, isKeyMissing(cb, off), loadKey(cb, off))) } -class DictState(val kb: EmitClassBuilder[_], val keyVType: VirtualTypeWithReq, val nested: StateTuple) extends PointerBasedRVAState { +class DictState( + val kb: EmitClassBuilder[_], + val keyVType: VirtualTypeWithReq, + val nested: StateTuple, +) extends PointerBasedRVAState { private val keyType = keyVType.canonicalPType val nStates: Int = nested.nStates - val valueType: PStruct = PCanonicalStruct("regionIdx" -> PInt32(true), "states" -> nested.storageType) + + val valueType: PStruct = + PCanonicalStruct("regionIdx" -> PInt32(true), "states" -> nested.storageType) + val root: Settable[Long] = kb.genFieldThisRef[Long]("grouped_agg_root") val size: Settable[Int] = kb.genFieldThisRef[Int]("grouped_agg_size") val keyEType = EType.defaultFromPType(keyType) @@ -114,13 +147,17 @@ class DictState(val kb: EmitClassBuilder[_], val keyVType: VirtualTypeWithReq, v required = true, "inits" -> nested.storageType, "size" -> PInt32(true), - "tree" -> PInt64(true)) + "tree" -> PInt64(true), + ) private val _elt = kb.genFieldThisRef[Long]() + private val initStatesOffset: Value[Long] = new Value[Long] { def get: Code[Long] = typ.loadField(off, 0) } - val initContainer: TupleAggregatorState = new TupleAggregatorState(kb, nested, region, initStatesOffset) + + val initContainer: TupleAggregatorState = + new TupleAggregatorState(kb, nested, region, initStatesOffset) val keyed = new GroupedBTreeKey(keyType, kb, region, _elt, nested) val tree = new AppendOnlyBTree(kb, keyed, region, root, maxElements = 6) @@ -135,16 +172,17 @@ class DictState(val kb: EmitClassBuilder[_], val keyVType: VirtualTypeWithReq, v cb.assign(_elt, node) keyed.loadStates(cb) } - + def loadContainer(cb: EmitCodeBuilder, kec: EmitCode): Unit = { val kev = cb.memoize(kec, "ga_load_cont_k") cb.assign(_elt, tree.getOrElseInitialize(cb, kev)) - cb.if_(keyed.isEmpty(cb, _elt), { - initElement(cb, _elt, kev) - keyed.copyStatesFrom(cb, initStatesOffset) - }, { - keyed.loadStates(cb) - }) + cb.if_( + keyed.isEmpty(cb, _elt), { + initElement(cb, _elt, kev) + keyed.copyStatesFrom(cb, initStatesOffset) + }, + keyed.loadStates(cb), + ) } def withContainer(cb: EmitCodeBuilder, k: EmitCode, seqOps: EmitCodeBuilder => Unit): Unit = { @@ -158,16 +196,25 @@ 
class DictState(val kb: EmitClassBuilder[_], val keyVType: VirtualTypeWithReq, v nested.createStates(cb) } - override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = { super.load(cb, regionLoader, src) - cb.if_(off.cne(0L), - { + cb.if_( + off.cne(0L), { cb.assign(size, Region.loadInt(typ.loadField(off, 1))) cb.assign(root, Region.loadAddress(typ.loadField(off, 2))) - }) + }, + ) } - override def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = { + override def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = { cb += Region.storeInt(typ.fieldOffset(off, 1), size) cb += Region.storeAddress(typ.fieldOffset(off, 2), root) super.store(cb, regionStorer, dest) @@ -183,20 +230,24 @@ class DictState(val kb: EmitClassBuilder[_], val keyVType: VirtualTypeWithReq, v tree.init(cb) } - def combine(cb: EmitCodeBuilder, other: DictState, comb: EmitCodeBuilder => Unit): Unit = { - other.foreach(cb) { (cb, k) => withContainer(cb, k, comb) } - } + def combine(cb: EmitCodeBuilder, other: DictState, comb: EmitCodeBuilder => Unit): Unit = + other.foreach(cb)((cb, k) => withContainer(cb, k, comb)) // loads container; does not update. def foreach(cb: EmitCodeBuilder)(f: (EmitCodeBuilder, EmitCode) => Unit): Unit = tree.foreach(cb) { (cb, kvOff) => cb.assign(_elt, kvOff) keyed.loadStates(cb) - f(cb, EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, keyed.isKeyMissing(cb, _elt), keyed.loadKey(cb, _elt)))) + f( + cb, + EmitCode.fromI(cb.emb)(cb => + IEmitCode(cb, keyed.isKeyMissing(cb, _elt), keyed.loadKey(cb, _elt)) + ), + ) } def copyFromAddress(cb: EmitCodeBuilder, src: Value[Long]): Unit = { - init(cb, { cb => initContainer.copyFrom(cb, cb.memoize(typ.loadField(src, 0))) }) + init(cb, cb => initContainer.copyFrom(cb, cb.memoize(typ.loadField(src, 0)))) cb.assign(size, Region.loadInt(typ.loadField(src, 1))) tree.deepCopy(cb, cb.memoize(Region.loadAddress(typ.loadField(src, 2)))) } @@ -206,24 +257,20 @@ class DictState(val kb: EmitClassBuilder[_], val keyVType: VirtualTypeWithReq, v { (cb: EmitCodeBuilder, ob: Value[OutputBuffer]) => initContainer.load(cb) - nested.toCodeWithArgs(cb, - { (cb, i, _) => - serializers(i)(cb, ob) - }) + nested.toCodeWithArgs(cb, (cb, i, _) => serializers(i)(cb, ob)) tree.bulkStore(cb, ob) { (cb: EmitCodeBuilder, ob: Value[OutputBuffer], kvOff: Code[Long]) => cb.assign(_elt, kvOff) val km = keyed.isKeyMissing(cb, _elt) cb += (ob.writeBoolean(km)) - cb.if_(!km, { - val k = keyed.loadKey(cb, _elt) - keyEType.buildEncoder(k.st, kb) - .apply(cb, k, ob) - }) + cb.if_( + !km, { + val k = keyed.loadKey(cb, _elt) + keyEType.buildEncoder(k.st, kb) + .apply(cb, k, ob) + }, + ) keyed.loadStates(cb) - nested.toCodeWithArgs(cb, - { (cb, i, _) => - serializers(i)(cb, ob) - }) + nested.toCodeWithArgs(cb, (cb, i, _) => serializers(i)(cb, ob)) } } } @@ -232,33 +279,34 @@ class DictState(val kb: EmitClassBuilder[_], val keyVType: VirtualTypeWithReq, v val deserializers = nested.states.map(_.deserialize(codec)) { (cb: EmitCodeBuilder, ib: Value[InputBuffer]) => - init(cb, { cb => - nested.toCodeWithArgs(cb, - { (cb, i, _) => - deserializers(i)(cb, ib) - }) - }) + init(cb, cb => nested.toCodeWithArgs(cb, (cb, i, _) => deserializers(i)(cb, ib))) tree.bulkLoad(cb, ib) { (cb, ib, koff) 
=> cb.assign(_elt, koff) val kc = EmitCode.fromI(cb.emb)(cb => - IEmitCode(cb, ib.readBoolean(), keyEType.buildDecoder(keyType.virtualType, kb).apply(cb, region, ib))) + IEmitCode( + cb, + ib.readBoolean(), + keyEType.buildDecoder(keyType.virtualType, kb).apply(cb, region, ib), + ) + ) initElement(cb, _elt, kc) - nested.toCodeWithArgs(cb, - { (cb, i, _) => - deserializers(i)(cb, ib) - }) + nested.toCodeWithArgs(cb, (cb, i, _) => deserializers(i)(cb, ib)) keyed.storeStates(cb) } } } } -class GroupedAggregator(ktV: VirtualTypeWithReq, nestedAggs: Array[StagedAggregator]) extends StagedAggregator { +class GroupedAggregator(ktV: VirtualTypeWithReq, nestedAggs: Array[StagedAggregator]) + extends StagedAggregator { type State = DictState private val kt = ktV.canonicalPType - val resultEltType: PTuple = PCanonicalTuple(true, nestedAggs.map(_.resultEmitType.storageType): _*) + + val resultEltType: PTuple = + PCanonicalTuple(true, nestedAggs.map(_.resultEmitType.storageType): _*) + val resultPType: PCanonicalDict = PCanonicalDict(kt, resultEltType) override val resultEmitType = EmitType(SIndexablePointer(resultPType), true) private[this] val arrayRep = resultPType.arrayRep @@ -268,7 +316,7 @@ class GroupedAggregator(ktV: VirtualTypeWithReq, nestedAggs: Array[StagedAggrega protected def _initOp(cb: EmitCodeBuilder, state: State, init: Array[EmitCode]): Unit = { val Array(inits) = init - state.init(cb, { cb => cb += inits.asVoid() }) + state.init(cb, cb => cb += inits.asVoid()) } protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = { @@ -276,36 +324,57 @@ class GroupedAggregator(ktV: VirtualTypeWithReq, nestedAggs: Array[StagedAggrega state.withContainer(cb, key, (cb) => cb += seqs.asVoid()) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: DictState, other: DictState): Unit = { - state.combine(cb, other, { cb => - state.nested.toCode((i, s) => nestedAggs(i).combOp(ctx, cb, s, other.nested(i))) - }) - - } - - protected override def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: DictState, + other: DictState, + ): Unit = + state.combine( + cb, + other, + cb => state.nested.toCode((i, s) => nestedAggs(i).combOp(ctx, cb, region, s, other.nested(i))), + ) + + override protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]) + : IEmitCode = { val len = state.size val resultAddr = cb.newLocal[Long]("groupedagg_result_addr", resultPType.allocate(region, len)) arrayRep.stagedInitialize(cb, resultAddr, len, setMissing = false) val i = cb.newLocal[Int]("groupedagg_result_i", 0) state.foreach(cb) { (cb, k) => - val addrAtI = cb.newLocal[Long]("groupedagg_result_addr_at_i", arrayRep.elementOffset(resultAddr, len, i)) + val addrAtI = + cb.newLocal[Long]("groupedagg_result_addr_at_i", arrayRep.elementOffset(resultAddr, len, i)) dictElt.stagedInitialize(cb, addrAtI, setMissing = false) - k.toI(cb).consume(cb, + k.toI(cb).consume( + cb, dictElt.setFieldMissing(cb, addrAtI, "key"), - { sc => - dictElt.fieldType("key").storeAtAddress(cb, dictElt.fieldOffset(addrAtI, "key"), region, sc, deepCopy = true) - }) - - val valueAddr = cb.newLocal[Long]("groupedagg_value_addr", dictElt.fieldOffset(addrAtI, "value")) + sc => + dictElt.fieldType("key").storeAtAddress( + cb, + dictElt.fieldOffset(addrAtI, "key"), + region, + sc, + deepCopy = true, + ), + ) + + val valueAddr = + 
cb.newLocal[Long]("groupedagg_value_addr", dictElt.fieldOffset(addrAtI, "value")) resultEltType.stagedInitialize(cb, valueAddr, setMissing = false) state.nested.toCode { case (nestedIdx, nestedState) => - val nestedAddr = cb.newLocal[Long](s"groupedagg_result_nested_addr_$nestedIdx", resultEltType.fieldOffset(valueAddr, nestedIdx)) + val nestedAddr = cb.newLocal[Long]( + s"groupedagg_result_nested_addr_$nestedIdx", + resultEltType.fieldOffset(valueAddr, nestedIdx), + ) val nestedRes = nestedAggs(nestedIdx).result(cb, nestedState, region) - nestedRes.consume(cb, - { resultEltType.setFieldMissing(cb, valueAddr, nestedIdx)}, - { sv => resultEltType.types(nestedIdx).storeAtAddress(cb, nestedAddr, region, sv, true)}) // TODO: Should this be deep copied? + nestedRes.consume( + cb, + resultEltType.setFieldMissing(cb, valueAddr, nestedIdx), + sv => resultEltType.types(nestedIdx).storeAtAddress(cb, nestedAddr, region, sv, true), + ) // TODO: Should this be deep copied? } cb.assign(i, i + 1) diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/ImputeTypeAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/ImputeTypeAggregator.scala index 1530626af53..7558195808c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/ImputeTypeAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/ImputeTypeAggregator.scala @@ -4,16 +4,14 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, IEmitCode} +import is.hail.types.{RPrimitive, VirtualTypeWithReq} import is.hail.types.physical.stypes.EmitType import is.hail.types.physical.stypes.concrete.SStackStruct import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives.{SBoolean, SBooleanValue} import is.hail.types.virtual._ -import is.hail.types.{RPrimitive, VirtualTypeWithReq} import is.hail.utils._ -import scala.language.existentials - object ImputeTypeState { val resultVirtualType = TStruct( "anyNonMissing" -> TBoolean, @@ -21,51 +19,61 @@ object ImputeTypeState { "supportsBool" -> TBoolean, "supportsInt32" -> TBoolean, "supportsInt64" -> TBoolean, - "supportsFloat64" -> TBoolean + "supportsFloat64" -> TBoolean, ) - val resultSType = SStackStruct(resultVirtualType, IndexedSeq( - EmitType(SBoolean, true), - EmitType(SBoolean, true), - EmitType(SBoolean, true), - EmitType(SBoolean, true), - EmitType(SBoolean, true), - EmitType(SBoolean, true) - )) - - val resultEmitType = EmitType(resultSType, true) - - def matchBoolean(x: String): Boolean = try { - x.toBoolean - true - } catch { - case e: IllegalArgumentException => false - } - - def matchInt32(x: String): Boolean = try { - Integer.parseInt(x) - true - } catch { - case e: IllegalArgumentException => false - } + val resultSType = SStackStruct( + resultVirtualType, + IndexedSeq( + EmitType(SBoolean, true), + EmitType(SBoolean, true), + EmitType(SBoolean, true), + EmitType(SBoolean, true), + EmitType(SBoolean, true), + EmitType(SBoolean, true), + ), + ) - def matchInt64(x: String): Boolean = try { - java.lang.Long.parseLong(x) - true - } catch { - case e: IllegalArgumentException => false - } + val resultEmitType = EmitType(resultSType, true) - def matchFloat64(x: String): Boolean = try { - java.lang.Double.parseDouble(x) - true - } catch { - case e: IllegalArgumentException => false - } + def matchBoolean(x: String): Boolean = + try { + x.toBoolean + true + } catch { + case _: IllegalArgumentException => false + } + + def matchInt32(x: String): Boolean = + 
try { + Integer.parseInt(x) + true + } catch { + case _: IllegalArgumentException => false + } + + def matchInt64(x: String): Boolean = + try { + java.lang.Long.parseLong(x) + true + } catch { + case _: IllegalArgumentException => false + } + + def matchFloat64(x: String): Boolean = + try { + java.lang.Double.parseDouble(x) + true + } catch { + case _: IllegalArgumentException => false + } } -class ImputeTypeState(kb: EmitClassBuilder[_]) extends PrimitiveRVAState(Array(VirtualTypeWithReq(TInt32,RPrimitive()).setRequired(true)), kb) { +class ImputeTypeState(kb: EmitClassBuilder[_]) extends PrimitiveRVAState( + Array(VirtualTypeWithReq(TInt32, RPrimitive()).setRequired(true)), + kb, + ) { private def repr: Code[Int] = _repr.pv.asInt32.value private val _repr = fields(0) @@ -81,13 +89,14 @@ class ImputeTypeState(kb: EmitClassBuilder[_]) extends PrimitiveRVAState(Array(V def getSupportsF64: Code[Boolean] = (repr & 1 << 5).cne(0) - private def setRepr(cb: EmitCodeBuilder, + private def setRepr( + cb: EmitCodeBuilder, anyNonMissing: Code[Boolean], allDefined: Code[Boolean], supportsBool: Code[Boolean], supportsI32: Code[Boolean], supportsI64: Code[Boolean], - supportsF64: Code[Boolean] + supportsF64: Code[Boolean], ): Unit = { val value = cb.memoize(anyNonMissing.toI | (allDefined.toI << 1) @@ -98,38 +107,56 @@ class ImputeTypeState(kb: EmitClassBuilder[_]) extends PrimitiveRVAState(Array(V cb.assign(_repr, EmitCode.present(cb.emb, primitive(value))) } - def initialize(cb: EmitCodeBuilder): Unit = { + def initialize(cb: EmitCodeBuilder): Unit = setRepr(cb, false, true, true, true, true, true) - } def seqOp(cb: EmitCodeBuilder, ec: EmitCode): Unit = { ec.toI(cb) - .consume(cb, + .consume( + cb, cb.assign(_repr, EmitCode.present(cb.emb, primitive(cb.memoize(repr & (~(1 << 1)))))), { case (pc: SStringValue) => val s = cb.newLocal[String]("impute_type_agg_seq_str") cb.assign(s, pc.loadString(cb)) - setRepr(cb, + setRepr( + cb, true, getAllDefined, - getSupportsBool && Code.invokeScalaObject1[String, Boolean](ImputeTypeState.getClass, "matchBoolean", s), - getSupportsI32 && Code.invokeScalaObject1[String, Boolean](ImputeTypeState.getClass, "matchInt32", s), - getSupportsI64 && Code.invokeScalaObject1[String, Boolean](ImputeTypeState.getClass, "matchInt64", s), - getSupportsF64 && Code.invokeScalaObject1[String, Boolean](ImputeTypeState.getClass, "matchFloat64", s) + getSupportsBool && Code.invokeScalaObject1[String, Boolean]( + ImputeTypeState.getClass, + "matchBoolean", + s, + ), + getSupportsI32 && Code.invokeScalaObject1[String, Boolean]( + ImputeTypeState.getClass, + "matchInt32", + s, + ), + getSupportsI64 && Code.invokeScalaObject1[String, Boolean]( + ImputeTypeState.getClass, + "matchInt64", + s, + ), + getSupportsF64 && Code.invokeScalaObject1[String, Boolean]( + ImputeTypeState.getClass, + "matchFloat64", + s, + ), ) - } + }, ) } def combOp(cb: EmitCodeBuilder, other: ImputeTypeState): Unit = { - setRepr(cb, + setRepr( + cb, getAnyNonMissing || other.getAnyNonMissing, getAllDefined && other.getAllDefined, getSupportsBool && other.getSupportsBool, getSupportsI32 && other.getSupportsI32, getSupportsI64 && other.getSupportsI64, - getSupportsF64 && other.getSupportsF64 + getSupportsF64 && other.getSupportsF64, ) } } @@ -154,14 +181,30 @@ class ImputeTypeAggregator() extends StagedAggregator { state.seqOp(cb, s) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: ImputeTypeState, other: ImputeTypeState): Unit = { + protected def _combOp( + ctx: ExecuteContext, + cb: 
EmitCodeBuilder, + region: Value[Region], + state: ImputeTypeState, + other: ImputeTypeState, + ): Unit = state.combOp(cb, other) - } protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { - val emitCodes = Array(state.getAnyNonMissing, state.getAllDefined, state.getSupportsBool, state.getSupportsI32, state.getSupportsI64, state.getSupportsF64). - map(bool => new SBooleanValue(cb.memoize(bool))).map(sbv => EmitCode.present(cb.emb, sbv)) - val sv = SStackStruct.constructFromArgs(cb, region, resultEmitType.virtualType.asInstanceOf[TBaseStruct], emitCodes:_*) + val emitCodes = Array( + state.getAnyNonMissing, + state.getAllDefined, + state.getSupportsBool, + state.getSupportsI32, + state.getSupportsI64, + state.getSupportsF64, + ).map(bool => new SBooleanValue(cb.memoize(bool))).map(sbv => EmitCode.present(cb.emb, sbv)) + val sv = SStackStruct.constructFromArgs( + cb, + region, + resultEmitType.virtualType.asInstanceOf[TBaseStruct], + emitCodes: _* + ) IEmitCode.present(cb, sv) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/LinearRegressionAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/LinearRegressionAggregator.scala index dc6f70ca1b4..6f43d58ccb4 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/LinearRegressionAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/LinearRegressionAggregator.scala @@ -1,18 +1,21 @@ package is.hail.expr.ir.agg -import breeze.linalg.{DenseMatrix, DenseVector, diag, inv} import is.hail.annotations.{Region, RegionValueBuilder, UnsafeRow} import is.hail.asm4s._ import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, IEmitCode} import is.hail.types.physical._ import is.hail.types.physical.stypes.EmitType -import is.hail.types.physical.stypes.concrete.{SBaseStructPointer, SIndexablePointer, SIndexablePointerValue} +import is.hail.types.physical.stypes.concrete.{ + SBaseStructPointer, SIndexablePointer, SIndexablePointerValue, +} import is.hail.types.physical.stypes.interfaces.SIndexableValue import is.hail.types.virtual.{TArray, TFloat64, TInt32, Type} -import is.hail.utils.FastSeq -class LinearRegressionAggregatorState(val kb: EmitClassBuilder[_]) extends AbstractTypedRegionBackedAggState(LinearRegressionAggregator.stateType) +import breeze.linalg.{diag, inv, DenseMatrix, DenseVector} + +class LinearRegressionAggregatorState(val kb: EmitClassBuilder[_]) + extends AbstractTypedRegionBackedAggState(LinearRegressionAggregator.stateType) object LinearRegressionAggregator { @@ -22,14 +25,24 @@ object LinearRegressionAggregator { private val optVector = vector.setRequired(false) - val resultPType: PCanonicalStruct = PCanonicalStruct(required = false, "xty" -> optVector, "beta" -> optVector, "diag_inv" -> optVector, "beta0" -> optVector) + val resultPType: PCanonicalStruct = PCanonicalStruct( + required = false, + "xty" -> optVector, + "beta" -> optVector, + "diag_inv" -> optVector, + "beta0" -> optVector, + ) def computeResult(region: Region, xtyPtr: Long, xtxPtr: Long, k0: Int): Long = { val xty = DenseVector(UnsafeRow.readArray(vector, null, xtyPtr) .asInstanceOf[IndexedSeq[Double]].toArray[Double]) val k = xty.length - val xtx = DenseMatrix.create(k, k, UnsafeRow.readArray(vector, null, xtxPtr) - .asInstanceOf[IndexedSeq[Double]].toArray[Double]) + val xtx = DenseMatrix.create( + k, + k, + UnsafeRow.readArray(vector, null, xtxPtr) + .asInstanceOf[IndexedSeq[Double]].toArray[Double], + ) val rvb = new 
RegionValueBuilder(HailStateManager(Map.empty), region) rvb.start(resultPType) @@ -76,7 +89,7 @@ object LinearRegressionAggregator { rvb.endArray() } catch { case _: breeze.linalg.MatrixSingularException | - _: breeze.linalg.NotConvergedException => + _: breeze.linalg.NotConvergedException => rvb.setMissing() rvb.setMissing() rvb.setMissing() @@ -94,7 +107,8 @@ class LinearRegressionAggregator() extends StagedAggregator { type State = AbstractTypedRegionBackedAggState - override def resultEmitType: EmitType = EmitType(SBaseStructPointer(LinearRegressionAggregator.resultPType), true) + override def resultEmitType: EmitType = + EmitType(SBaseStructPointer(LinearRegressionAggregator.resultPType), true) val initOpTypes: Seq[Type] = Array(TInt32, TInt32) val seqOpTypes: Seq[Type] = Array(TFloat64, TArray(TFloat64)) @@ -102,33 +116,39 @@ class LinearRegressionAggregator() extends StagedAggregator { def initOpF(state: State)(cb: EmitCodeBuilder, kc: Code[Int], k0c: Code[Int]): Unit = { val k = cb.newLocal[Int]("lra_init_k", kc) val k0 = cb.newLocal[Int]("lra_init_k0", k0c) - cb.if_((k0 < 0) || (k0 > k), + cb.if_( + (k0 < 0) || (k0 > k), cb += Code._fatal[Unit](const("linreg: `nested_dim` must be between 0 and the number (") .concat(k.toS) - .concat(") of covariates, inclusive")) + .concat(") of covariates, inclusive")), ) cb.assign(state.off, stateType.allocate(state.region)) - cb += Region.storeAddress(stateType.fieldOffset(state.off, 0), vector.zeroes(cb, state.region, k)) - cb += Region.storeAddress(stateType.fieldOffset(state.off, 1), vector.zeroes(cb, state.region, k * k)) + cb += Region.storeAddress( + stateType.fieldOffset(state.off, 0), + vector.zeroes(cb, state.region, k), + ) + cb += Region.storeAddress( + stateType.fieldOffset(state.off, 1), + vector.zeroes(cb, state.region, k * k), + ) cb += Region.storeInt(stateType.loadField(state.off, 2), k0) } protected def _initOp(cb: EmitCodeBuilder, state: State, init: Array[EmitCode]): Unit = { val Array(kt, k0t) = init kt.toI(cb) - .consume(cb, - { - cb += Code._fatal[Unit]("linreg: init args may not be missing") - }, + .consume( + cb, + cb += Code._fatal[Unit]("linreg: init args may not be missing"), { ktCode => k0t.toI(cb) - .consume(cb, - { - cb += Code._fatal[Unit]("linreg: init args may not be missing") - }, - k0tCode => initOpF(state)(cb, ktCode.asInt.value, k0tCode.asInt.value) + .consume( + cb, + cb += Code._fatal[Unit]("linreg: init args may not be missing"), + k0tCode => initOpF(state)(cb, ktCode.asInt.value, k0tCode.asInt.value), ) - }) + }, + ) } def seqOpF(state: State)(cb: EmitCodeBuilder, y: Code[Double], x: SIndexableValue): Unit = { @@ -139,8 +159,8 @@ class LinearRegressionAggregator() extends StagedAggregator { val xty = cb.newLocal[Long]("linreg_agg_seqop_xty") val xtx = cb.newLocal[Long]("linreg_agg_seqop_xtx") - cb.if_(!x.hasMissingValues(cb), - { + cb.if_( + !x.hasMissingValues(cb), { cb.assign(xty, stateType.loadField(state.off, 0)) cb.assign(xtx, stateType.loadField(state.off, 1)) cb.assign(k, vector.loadLength(xty)) @@ -154,74 +174,101 @@ class LinearRegressionAggregator() extends StagedAggregator { val xptr = cb.newLocal[Long]("linreg_agg_seqop_xptr") val xptr2 = cb.newLocal[Long]("linreg_agg_seqop_xptr2") cb.assign(xptr, pt.firstElementOffset(xAddr, k)) - cb.while_(i < k, - { - cb += Region.storeDouble(sptr, Region.loadDouble(sptr) + (Region.loadDouble(xptr) * y)) + cb.while_( + i < k, { + cb += Region.storeDouble( + sptr, + Region.loadDouble(sptr) + (Region.loadDouble(xptr) * y), + ) cb.assign(i, i + 1) 
cb.assign(sptr, sptr + scalar.byteSize) cb.assign(xptr, xptr + scalar.byteSize) - }) + }, + ) cb.assign(i, 0) cb.assign(sptr, vector.firstElementOffset(xtx, k)) cb.assign(xptr, pt.firstElementOffset(xAddr, k)) - cb.while_(i < k, - { + cb.while_( + i < k, { cb.assign(j, 0) cb.assign(xptr2, pt.firstElementOffset(xAddr, k)) - cb.while_(j < k, - { + cb.while_( + j < k, { // add x[i] * x[j] to the value at sptr - cb += Region.storeDouble(sptr, Region.loadDouble(sptr) + (Region.loadDouble(xptr) * Region.loadDouble(xptr2))) + cb += Region.storeDouble( + sptr, + Region.loadDouble(sptr) + (Region.loadDouble(xptr) * Region.loadDouble(xptr2)), + ) cb.assign(j, j + 1) cb.assign(sptr, sptr + scalar.byteSize) cb.assign(xptr2, xptr2 + scalar.byteSize) - }) + }, + ) cb.assign(i, i + 1) cb.assign(xptr, xptr + scalar.byteSize) - }) + }, + ) case _ => - cb.while_(i < k, - { - cb += Region.storeDouble(sptr, Region.loadDouble(sptr) + x.loadElement(cb, i).get(cb).asDouble.value * y) + cb.while_( + i < k, { + cb += Region.storeDouble( + sptr, + Region.loadDouble(sptr) + x.loadElement(cb, i).getOrAssert(cb).asDouble.value * y, + ) cb.assign(i, i + 1) cb.assign(sptr, sptr + scalar.byteSize) - }) + }, + ) cb.assign(i, 0) cb.assign(sptr, vector.firstElementOffset(xtx, k)) - cb.while_(i < k, - { + cb.while_( + i < k, { cb.assign(j, 0) - cb.while_(j < k, - { + cb.while_( + j < k, { // add x[i] * x[j] to the value at sptr - cb += Region.storeDouble(sptr, Region.loadDouble(sptr) + - (x.loadElement(cb, i).get(cb).asDouble.value * x.loadElement(cb, j).get(cb).asDouble.value)) + cb += Region.storeDouble( + sptr, + Region.loadDouble(sptr) + + (x.loadElement(cb, i).getOrAssert(cb).asDouble.value * x.loadElement( + cb, + j, + ).getOrAssert( + cb + ).asDouble.value), + ) cb.assign(j, j + 1) cb.assign(sptr, sptr + scalar.byteSize) - }) + }, + ) cb.assign(i, i + 1) - }) + }, + ) } - }) + }, + ) } protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = { val Array(y, x) = seq y.toI(cb) - .consume(cb, + .consume( + cb, {}, { yCode => x.toI(cb) - .consume(cb, + .consume( + cb, {}, - xCode => seqOpF(state)(cb, yCode.asDouble.value, xCode.asIndexable) + xCode => seqOpF(state)(cb, yCode.asDouble.value, xCode.asIndexable), ) - }) + }, + ) } def combOpF(state: State, other: State)(cb: EmitCodeBuilder): Unit = { @@ -242,45 +289,59 @@ class LinearRegressionAggregator() extends StagedAggregator { n := vector.loadLength(xty), i := 0, sptr := vector.firstElementOffset(xty, n), - optr := vector.firstElementOffset(oxty, n) + optr := vector.firstElementOffset(oxty, n), ) - cb.while_(i < n, { - cb += Code( - Region.storeDouble(sptr, Region.loadDouble(sptr) + Region.loadDouble(optr)), - i := i + 1, - sptr := sptr + scalar.byteSize, - optr := optr + scalar.byteSize - ) - }) + cb.while_( + i < n, { + cb += Code( + Region.storeDouble(sptr, Region.loadDouble(sptr) + Region.loadDouble(optr)), + i := i + 1, + sptr := sptr + scalar.byteSize, + optr := optr + scalar.byteSize, + ) + }, + ) cb += Code( n := vector.loadLength(xtx), i := 0, sptr := vector.firstElementOffset(xtx, n), - optr := vector.firstElementOffset(oxtx, n) + optr := vector.firstElementOffset(oxtx, n), ) - cb.while_(i < n, { + cb.while_( + i < n, cb += Code( Region.storeDouble(sptr, Region.loadDouble(sptr) + Region.loadDouble(optr)), i := i + 1, sptr := sptr + scalar.byteSize, - optr := optr + scalar.byteSize) - }) + optr := optr + scalar.byteSize, + ), + ) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: 
AbstractTypedRegionBackedAggState, other: AbstractTypedRegionBackedAggState): Unit = { + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: AbstractTypedRegionBackedAggState, + other: AbstractTypedRegionBackedAggState, + ): Unit = combOpF(state, other)(cb) - } protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { - val resAddr = cb.newLocal[Long]("linear_regression_agg_res", Code.invokeScalaObject4[Region, Long, Long, Int, Long]( - LinearRegressionAggregator.getClass, "computeResult", - region, - stateType.loadField(state.off, 0), - stateType.loadField(state.off, 1), - Region.loadInt(stateType.loadField(state.off, 2)))) + val resAddr = cb.newLocal[Long]( + "linear_regression_agg_res", + Code.invokeScalaObject4[Region, Long, Long, Int, Long]( + LinearRegressionAggregator.getClass, + "computeResult", + region, + stateType.loadField(state.off, 0), + stateType.loadField(state.off, 1), + Region.loadInt(stateType.loadField(state.off, 2)), + ), + ) IEmitCode.present(cb, LinearRegressionAggregator.resultPType.loadCheapSCode(cb, resAddr)) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/MonoidAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/MonoidAggregator.scala index 3e036a8a879..c6162cb2ce7 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/MonoidAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/MonoidAggregator.scala @@ -3,14 +3,12 @@ package is.hail.expr.ir.agg import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext -import is.hail.expr.ir.functions.UtilFunctions import is.hail.expr.ir._ +import is.hail.expr.ir.functions.UtilFunctions import is.hail.types.physical.stypes.{EmitType, SType} import is.hail.types.physical.stypes.interfaces._ -import is.hail.types.physical.{PType, typeToTypeInfo} import is.hail.types.virtual._ -import scala.language.existentials import scala.reflect.ClassTag trait StagedMonoidSpec { @@ -34,11 +32,11 @@ class MonoidAggregator(monoid: StagedMonoidSpec) extends StagedAggregator { val stateRequired = state.vtypes.head.r.required val ev = state.fields(0) if (!ev.required) { - assert(!stateRequired, s"monoid=$monoid, stateRequired=$stateRequired") - cb.assign(ev, EmitCode.missing(cb.emb, ev.st)) + assert(!stateRequired, s"monoid=$monoid, stateRequired=$stateRequired") + cb.assign(ev, EmitCode.missing(cb.emb, ev.st)) } else { - assert(stateRequired, s"monoid=$monoid, stateRequired=$stateRequired") - cb.assign(ev, EmitCode.present(cb.emb, primitive(ev.st.virtualType, monoid.neutral.get))) + assert(stateRequired, s"monoid=$monoid, stateRequired=$stateRequired") + cb.assign(ev, EmitCode.present(cb.emb, primitive(ev.st.virtualType, monoid.neutral.get))) } } @@ -49,26 +47,35 @@ class MonoidAggregator(monoid: StagedMonoidSpec) extends StagedAggregator { combine(cb, ev, update) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: PrimitiveRVAState, other: PrimitiveRVAState): Unit = { + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: PrimitiveRVAState, + other: PrimitiveRVAState, + ): Unit = { val ev1 = state.fields(0) val ev2 = other.fields(0) combine(cb, ev1, ev2) } - protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { + protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = state.fields(0).toI(cb) - } private def combine( cb: EmitCodeBuilder, ev1: 
EmitSettable, - ev2: EmitValue + ev2: EmitValue, ): Unit = { - val combined = primitive(monoid.typ, monoid(cb, ev1.pv.asPrimitive.primitiveValue, ev2.pv.asPrimitive.primitiveValue)) - cb.if_(ev1.m, + val combined = primitive( + monoid.typ, + monoid(cb, ev1.pv.asPrimitive.primitiveValue, ev2.pv.asPrimitive.primitiveValue), + ) + cb.if_( + ev1.m, cb.if_(!ev2.m, cb.assign(ev1, ev2)), - cb.if_(!ev2.m, - cb.assign(ev1, EmitCode.present(cb.emb, combined)))) + cb.if_(!ev2.m, cb.assign(ev1, EmitCode.present(cb.emb, combined))), + ) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/NDArrayMultiplyAddAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/NDArrayMultiplyAddAggregator.scala index 1c821b689ed..bae7d64d3b4 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/NDArrayMultiplyAddAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/NDArrayMultiplyAddAggregator.scala @@ -10,13 +10,15 @@ import is.hail.types.physical.PCanonicalNDArray import is.hail.types.physical.stypes.EmitType import is.hail.types.physical.stypes.interfaces.{SNDArray, SNDArrayValue} import is.hail.types.virtual.Type -import is.hail.utils.{FastSeq, valueToRichCodeRegion} +import is.hail.utils.{valueToRichCodeRegion, FastSeq} class NDArrayMultiplyAddAggregator(ndVTyp: VirtualTypeWithReq) extends StagedAggregator { override type State = TypedRegionBackedAggState override def resultEmitType: EmitType = ndVTyp.canonicalEmitType - private val ndTyp = resultEmitType.storageType.asInstanceOf[PCanonicalNDArray] // TODO: Set required false? + + private val ndTyp = + resultEmitType.storageType.asInstanceOf[PCanonicalNDArray] // TODO: Set required false? override def initOpTypes: Seq[Type] = Array[Type]() @@ -29,64 +31,101 @@ class NDArrayMultiplyAddAggregator(ndVTyp: VirtualTypeWithReq) extends StagedAgg initMethod.voidWithBuilder(cb => state.storeMissing(cb) ) - cb.invokeVoid(initMethod) + cb.invokeVoid(initMethod, cb.this_) } override protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = { val Array(nextNDArrayACode, nextNDArrayBCode) = seq - val seqOpMethod = cb.emb.genEmitMethod("ndarray_add_multiply_aggregator_seq_op", - FastSeq(nextNDArrayACode.emitParamType, nextNDArrayBCode.emitParamType), CodeParamType(UnitInfo)) + val seqOpMethod = cb.emb.genEmitMethod( + "ndarray_add_multiply_aggregator_seq_op", + FastSeq(nextNDArrayACode.emitParamType, nextNDArrayBCode.emitParamType), + CodeParamType(UnitInfo), + ) seqOpMethod.voidWithBuilder { cb => val ndArrayAEmitCode = seqOpMethod.getEmitParam(cb, 1) - ndArrayAEmitCode.toI(cb).consume(cb, {}, { case checkA: SNDArrayValue => - val ndArrayBEmitCode = seqOpMethod.getEmitParam(cb, 2) - ndArrayBEmitCode.toI(cb).consume(cb, {}, { case checkB: SNDArrayValue => - val tempRegionForCreation = cb.newLocal[Region]("ndarray_add_multily_agg_temp_region", Region.stagedCreate(Region.REGULAR, cb.emb.ecb.pool())) - val NDArrayA = LinalgCodeUtils.checkColMajorAndCopyIfNeeded(checkA, cb, tempRegionForCreation) - val NDArrayB = LinalgCodeUtils.checkColMajorAndCopyIfNeeded(checkB, cb, tempRegionForCreation) - val statePV = state.storageType.loadCheapSCode(cb, state.off).asBaseStruct - statePV.loadField(cb, ndarrayFieldNumber).consume(cb, - { - cb += state.region.getNewRegion(Region.REGULAR) - state.storageType.setFieldPresent(cb, state.off.get, ndarrayFieldNumber) - val shape = IndexedSeq(NDArrayA.shapes(0), NDArrayB.shapes(1)) - val uninitializedNDArray = ndTyp.constructUninitialized(shape, ndTyp.makeColumnMajorStrides(shape, cb), cb, 
tempRegionForCreation) - state.storeNonmissing(cb, uninitializedNDArray) - SNDArray.gemm(cb, "N", "N", NDArrayA, NDArrayB, uninitializedNDArray) + ndArrayAEmitCode.toI(cb).consume( + cb, + {}, + { case checkA: SNDArrayValue => + val ndArrayBEmitCode = seqOpMethod.getEmitParam(cb, 2) + ndArrayBEmitCode.toI(cb).consume( + cb, + {}, + { case checkB: SNDArrayValue => + val tempRegionForCreation = cb.newLocal[Region]( + "ndarray_add_multily_agg_temp_region", + Region.stagedCreate(Region.REGULAR, cb.emb.ecb.pool()), + ) + val NDArrayA = + LinalgCodeUtils.checkColMajorAndCopyIfNeeded(checkA, cb, tempRegionForCreation) + val NDArrayB = + LinalgCodeUtils.checkColMajorAndCopyIfNeeded(checkB, cb, tempRegionForCreation) + val statePV = state.storageType.loadCheapSCode(cb, state.off).asBaseStruct + statePV.loadField(cb, ndarrayFieldNumber).consume( + cb, { + cb += state.region.getNewRegion(Region.REGULAR) + state.storageType.setFieldPresent(cb, state.off.get, ndarrayFieldNumber) + val shape = IndexedSeq(NDArrayA.shapes(0), NDArrayB.shapes(1)) + val uninitializedNDArray = ndTyp.constructUninitialized( + shape, + ndTyp.makeColumnMajorStrides(shape, cb), + cb, + tempRegionForCreation, + ) + state.storeNonmissing(cb, uninitializedNDArray) + SNDArray.gemm(cb, "N", "N", NDArrayA, NDArrayB, uninitializedNDArray) + }, + currentNDPValue => + SNDArray.gemm( + cb, + "N", + "N", + const(1.0), + NDArrayA, + NDArrayB, + const(1.0), + currentNDPValue.asNDArray, + ), + ) + cb += tempRegionForCreation.clearRegion() }, - { currentNDPValue => - SNDArray.gemm(cb, "N", "N", const(1.0), NDArrayA, NDArrayB, const(1.0), currentNDPValue.asNDArray) - } ) - cb += tempRegionForCreation.clearRegion() - }) - }) + }, + ) } - cb.invokeVoid(seqOpMethod, nextNDArrayACode, nextNDArrayBCode) + cb.invokeVoid(seqOpMethod, cb.this_, nextNDArrayACode, nextNDArrayBCode) } - override protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: State, other: State): Unit = { + override protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: TypedRegionBackedAggState, + other: TypedRegionBackedAggState, + ): Unit = { val combOpMethod = cb.emb.genEmitMethod[Unit]("ndarraymutiply_add_agg_comb_op") combOpMethod.voidWithBuilder { cb => val rightPV = other.storageType.loadCheapSCode(cb, other.off).asBaseStruct - rightPV.loadField(cb, ndarrayFieldNumber).consume(cb, {}, + rightPV.loadField(cb, ndarrayFieldNumber).consume( + cb, + {}, { case rightNdValue: SNDArrayValue => val leftPV = state.storageType.loadCheapSCode(cb, state.off).asBaseStruct - leftPV.loadField(cb, ndarrayFieldNumber).consume(cb, - { - state.storeNonmissing(cb, rightNdValue) - }, + leftPV.loadField(cb, ndarrayFieldNumber).consume( + cb, + state.storeNonmissing(cb, rightNdValue), { case leftNdValue: SNDArrayValue => NDArraySumAggregator.addValues(cb, state.region, leftNdValue, rightNdValue) - }) - } + }, + ) + }, ) } - cb.invokeVoid(combOpMethod) + cb.invokeVoid(combOpMethod, cb.this_) } - override protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { + override protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]) + : IEmitCode = state.get(cb).map(cb)(sv => sv.copyToRegion(cb, region, sv.st)) - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/NDArraySumAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/NDArraySumAggregator.scala index 91f871b57d2..7d403580a45 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/NDArraySumAggregator.scala +++ 
b/hail/src/main/scala/is/hail/expr/ir/agg/NDArraySumAggregator.scala @@ -6,8 +6,8 @@ import is.hail.backend.ExecuteContext import is.hail.expr.ir.{CodeParamType, EmitCode, EmitCodeBuilder, IEmitCode} import is.hail.types.VirtualTypeWithReq import is.hail.types.physical.PCanonicalNDArray -import is.hail.types.physical.stypes.interfaces.SNDArrayValue import is.hail.types.physical.stypes.{EmitType, SCode} +import is.hail.types.physical.stypes.interfaces.SNDArrayValue import is.hail.types.virtual.Type import is.hail.utils._ @@ -15,7 +15,9 @@ class NDArraySumAggregator(ndVTyp: VirtualTypeWithReq) extends StagedAggregator override type State = TypedRegionBackedAggState override def resultEmitType: EmitType = ndVTyp.canonicalEmitType - private val ndTyp = resultEmitType.storageType.asInstanceOf[PCanonicalNDArray] // TODO: Set required false? + + private val ndTyp = + resultEmitType.storageType.asInstanceOf[PCanonicalNDArray] // TODO: Set required false? override def initOpTypes: Seq[Type] = Array[Type]() @@ -28,66 +30,91 @@ class NDArraySumAggregator(ndVTyp: VirtualTypeWithReq) extends StagedAggregator initMethod.voidWithBuilder(cb => state.storeMissing(cb) ) - cb.invokeVoid(initMethod) + cb.invokeVoid(initMethod, cb.this_) } override protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = { val Array(nextNDCode) = seq - val seqOpMethod = cb.emb.genEmitMethod("ndarray_sum_aggregator_seq_op", FastSeq(nextNDCode.emitParamType), CodeParamType(UnitInfo)) + val seqOpMethod = cb.emb.genEmitMethod( + "ndarray_sum_aggregator_seq_op", + FastSeq(nextNDCode.emitParamType), + CodeParamType(UnitInfo), + ) seqOpMethod.voidWithBuilder { cb => val nextNDInput = seqOpMethod.getEmitParam(cb, 1) - nextNDInput.toI(cb).consume(cb, {}, { case nextNDPV: SNDArrayValue => - val statePV = state.storageType.loadCheapSCode(cb, state.off).asBaseStruct - statePV.loadField(cb, ndarrayFieldNumber).consume(cb, - { - cb += state.region.getNewRegion(Region.TINY) - state.storageType.setFieldPresent(cb, state.off, ndarrayFieldNumber) - val tempRegionForCreation = cb.newLocal[Region]("ndarray_sum_agg_temp_region", Region.stagedCreate(Region.REGULAR, cb.emb.ecb.pool())) - val fullyCopiedNDArray = ndTyp.constructByActuallyCopyingData(nextNDPV, cb, tempRegionForCreation) - state.storeNonmissing(cb, fullyCopiedNDArray) - cb += tempRegionForCreation.clearRegion() - }, - { currentND => - NDArraySumAggregator.addValues(cb, state.region, currentND.asNDArray, nextNDPV) - } - ) - }) + nextNDInput.toI(cb).consume( + cb, + {}, + { case nextNDPV: SNDArrayValue => + val statePV = state.storageType.loadCheapSCode(cb, state.off).asBaseStruct + statePV.loadField(cb, ndarrayFieldNumber).consume( + cb, { + cb += state.region.getNewRegion(Region.TINY) + state.storageType.setFieldPresent(cb, state.off, ndarrayFieldNumber) + val tempRegionForCreation = cb.newLocal[Region]( + "ndarray_sum_agg_temp_region", + Region.stagedCreate(Region.REGULAR, cb.emb.ecb.pool()), + ) + val fullyCopiedNDArray = + ndTyp.constructByActuallyCopyingData(nextNDPV, cb, tempRegionForCreation) + state.storeNonmissing(cb, fullyCopiedNDArray) + cb += tempRegionForCreation.clearRegion() + }, + currentND => + NDArraySumAggregator.addValues(cb, state.region, currentND.asNDArray, nextNDPV), + ) + }, + ) } - cb.invokeVoid(seqOpMethod, nextNDCode) + cb.invokeVoid(seqOpMethod, cb.this_, nextNDCode) } - override protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: TypedRegionBackedAggState, other: TypedRegionBackedAggState): Unit = { + override 
protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: TypedRegionBackedAggState, + other: TypedRegionBackedAggState, + ): Unit = { val combOpMethod = cb.emb.genEmitMethod[Unit]("ndarray_sum_aggregator_comb_op") combOpMethod.voidWithBuilder { cb => val rightPV = other.storageType.loadCheapSCode(cb, other.off).asBaseStruct - rightPV.loadField(cb, ndarrayFieldNumber).consume(cb, {}, + rightPV.loadField(cb, ndarrayFieldNumber).consume( + cb, + {}, { case rightNdValue: SNDArrayValue => val leftPV = state.storageType.loadCheapSCode(cb, state.off).asBaseStruct - leftPV.loadField(cb, ndarrayFieldNumber).consume(cb, - { - state.storeNonmissing(cb, rightNdValue) - }, + leftPV.loadField(cb, ndarrayFieldNumber).consume( + cb, + state.storeNonmissing(cb, rightNdValue), { case leftNdValue: SNDArrayValue => NDArraySumAggregator.addValues(cb, state.region, leftNdValue, rightNdValue) - }) - } + }, + ) + }, ) } - cb.invokeVoid(combOpMethod) + cb.invokeVoid(combOpMethod, cb.this_) } - protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { + protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = state.get(cb).map(cb)(sv => sv.copyToRegion(cb, region, sv.st)) - } } object NDArraySumAggregator { - def addValues(cb: EmitCodeBuilder, region: Value[Region], leftNdValue: SNDArrayValue, rightNdValue: SNDArrayValue): Unit = { - cb.if_(!leftNdValue.sameShape(cb, rightNdValue), - cb += Code._fatal[Unit]("Can't sum ndarrays of different shapes.")) + def addValues( + cb: EmitCodeBuilder, + region: Value[Region], + leftNdValue: SNDArrayValue, + rightNdValue: SNDArrayValue, + ): Unit = { + cb.if_( + !leftNdValue.sameShape(cb, rightNdValue), + cb += Code._fatal[Unit]("Can't sum ndarrays of different shapes."), + ) leftNdValue.coiterateMutate(cb, region, (rightNdValue, "right")) { case Seq(l, r) => diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/PrevNonNullAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/PrevNonNullAggregator.scala index 436b6d3ca25..5f124f46210 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/PrevNonNullAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/PrevNonNullAggregator.scala @@ -5,7 +5,6 @@ import is.hail.asm4s._ import is.hail.backend.ExecuteContext import is.hail.expr.ir.{EmitCode, EmitCodeBuilder, IEmitCode} import is.hail.types.VirtualTypeWithReq -import is.hail.types.physical._ import is.hail.types.physical.stypes.EmitType import is.hail.types.virtual.Type @@ -23,21 +22,27 @@ class PrevNonNullAggregator(typ: VirtualTypeWithReq) extends StagedAggregator { protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = { val Array(elt: EmitCode) = seq elt.toI(cb) - .consume(cb, + .consume( + cb, { /* do nothing if missing */ }, - sc => state.storeNonmissing(cb, sc) + sc => state.storeNonmissing(cb, sc), ) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: TypedRegionBackedAggState, other: TypedRegionBackedAggState): Unit = { + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: TypedRegionBackedAggState, + other: TypedRegionBackedAggState, + ): Unit = other.get(cb) - .consume(cb, + .consume( + cb, { /* do nothing if missing */ }, - sc => state.storeNonmissing(cb, sc) + sc => state.storeNonmissing(cb, sc), ) - } - protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { + protected def _result(cb: 
EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = state.get(cb).map(cb)(sv => sv.copyToRegion(cb, region, sv.st)) - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/ReservoirSampleAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/ReservoirSampleAggregator.scala index 105ead50fd9..c4ba1863140 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/ReservoirSampleAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/ReservoirSampleAggregator.scala @@ -3,7 +3,7 @@ package is.hail.expr.ir.agg import is.hail.annotations.Region import is.hail.asm4s.{Code, _} import is.hail.backend.ExecuteContext -import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, IEmitCode} +import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitMethodBuilder, IEmitCode} import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer} import is.hail.types.VirtualTypeWithReq import is.hail.types.physical._ @@ -12,7 +12,8 @@ import is.hail.types.physical.stypes.concrete.{SIndexablePointer, SIndexablePoin import is.hail.types.virtual.{TInt32, Type} import is.hail.utils._ -class ReservoirSampleRVAS(val eltType: VirtualTypeWithReq, val kb: EmitClassBuilder[_]) extends AggregatorState { +class ReservoirSampleRVAS(val eltType: VirtualTypeWithReq, val kb: EmitClassBuilder[_]) + extends AggregatorState { val eltPType = eltType.canonicalPType private val r: ThisFieldRef[Region] = kb.genFieldThisRef[Region]() @@ -20,7 +21,10 @@ class ReservoirSampleRVAS(val eltType: VirtualTypeWithReq, val kb: EmitClassBuil private val rand = kb.genFieldThisRef[java.util.Random]() val builder = new StagedArrayBuilder(eltPType, kb, region) - val storageType: PCanonicalTuple = PCanonicalTuple(true, PInt32Required, PInt64Required, PInt64Required, builder.stateType) + + val storageType: PCanonicalTuple = + PCanonicalTuple(true, PInt32Required, PInt64Required, PInt64Required, builder.stateType) + val maxSize = kb.genFieldThisRef[Int]() val seenSoFar = kb.genFieldThisRef[Long]() private val garbage = kb.genFieldThisRef[Long]() @@ -29,18 +33,19 @@ class ReservoirSampleRVAS(val eltType: VirtualTypeWithReq, val kb: EmitClassBuil private val garbageOffset: Code[Long] => Code[Long] = storageType.loadField(_, 2) private val builderStateOffset: Code[Long] => Code[Long] = storageType.loadField(_, 3) - def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = { + def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = cb += region.getNewRegion(regionSize) - } def createState(cb: EmitCodeBuilder): Unit = { cb.assign(rand, Code.newInstance[java.util.Random]) - cb.if_(region.isNull, { - cb.assign(r, Region.stagedCreate(regionSize, kb.pool())) - }) + cb.if_(region.isNull, cb.assign(r, Region.stagedCreate(regionSize, kb.pool()))) } - override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = { regionLoader(cb, r) cb.assign(maxSize, Region.loadInt(maxSizeOffset(src))) cb.assign(seenSoFar, Region.loadLong(elementsSeenOffset(src))) @@ -48,33 +53,36 @@ class ReservoirSampleRVAS(val eltType: VirtualTypeWithReq, val kb: EmitClassBuil builder.loadFrom(cb, builderStateOffset(src)) } - override def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = { - cb.if_(region.isValid, - { + override def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, 
Value[Region]) => Unit, + dest: Value[Long], + ): Unit = { + cb.if_( + region.isValid, { regionStorer(cb, region) cb += region.invalidate() cb += Region.storeInt(maxSizeOffset(dest), maxSize) cb += Region.storeLong(elementsSeenOffset(dest), seenSoFar) cb += Region.storeLong(garbageOffset(dest), garbage) builder.storeTo(cb, builderStateOffset(dest)) - }) + }, + ) } def serialize(codec: BufferSpec): (EmitCodeBuilder, Value[OutputBuffer]) => Unit = { - { (cb: EmitCodeBuilder, ob: Value[OutputBuffer]) => + (cb: EmitCodeBuilder, ob: Value[OutputBuffer]) => cb += ob.writeInt(maxSize) cb += ob.writeLong(seenSoFar) builder.serialize(codec)(cb, ob) - } } def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = { - { (cb: EmitCodeBuilder, ib: Value[InputBuffer]) => + (cb: EmitCodeBuilder, ib: Value[InputBuffer]) => cb.assign(maxSize, ib.readInt()) cb.assign(seenSoFar, ib.readLong()) cb.assign(garbage, 0L) builder.deserialize(codec)(cb, ib) - } } def init(cb: EmitCodeBuilder, _maxSize: Code[Int]): Unit = { @@ -84,127 +92,158 @@ class ReservoirSampleRVAS(val eltType: VirtualTypeWithReq, val kb: EmitClassBuil builder.initialize(cb) } - def gc(cb: EmitCodeBuilder): Unit = { - cb.invokeVoid(cb.emb.ecb.getOrGenEmitMethod("reservoir_sample_gc", - (this, "gc"), FastSeq(), UnitInfo) { mb => + private[this] val gc: EmitMethodBuilder[_] = + kb.defineEmitMethod(genName("m", "reservoir_sample_gc"), FastSeq(), UnitInfo) { mb => mb.voidWithBuilder { cb => - cb.if_(garbage > (maxSize.toL * 2L + 1024L), { - val oldRegion = mb.newLocal[Region]("old_region") - cb.assign(oldRegion, region) - cb.assign(r, Region.stagedCreate(regionSize, kb.pool())) - builder.reallocateData(cb) - cb.assign(garbage, 0L) - cb += oldRegion.invoke[Unit]("invalidate") - }) + cb.if_( + garbage > (maxSize.toL * 2L + 1024L), { + val oldRegion = mb.newLocal[Region]("old_region") + cb.assign(oldRegion, region) + cb.assign(r, Region.stagedCreate(regionSize, kb.pool())) + builder.reallocateData(cb) + cb.assign(garbage, 0L) + cb += oldRegion.invoke[Unit]("invalidate") + }, + ) } - }) - } + } def seqOp(cb: EmitCodeBuilder, elt: EmitCode): Unit = { val eltVal = cb.memoize(elt) cb.assign(seenSoFar, seenSoFar + 1) - cb.if_(builder.size < maxSize, + cb.if_( + builder.size < maxSize, eltVal.toI(cb) - .consume(cb, - builder.setMissing(cb), - sc => builder.append(cb, sc)), - { - // swaps the next element into the reservoir with probability (k / n), where - // k is the reservoir size and n is the number of elements seen so far (including current) - cb.if_(rand.invoke[Double]("nextDouble") * seenSoFar.toD <= maxSize.toD, { - val idxToSwap = cb.memoize(rand.invoke[Int, Int]("nextInt", maxSize)) - builder.overwrite(cb, eltVal, idxToSwap) - cb.assign(garbage, garbage + 1L) - gc(cb) - }) - }) + .consume(cb, builder.setMissing(cb), sc => builder.append(cb, sc)), { + // swaps the next element into the reservoir with probability (k / n), where + // k is the reservoir size and n is the number of elements seen so far (including current) + cb.if_( + rand.invoke[Double]("nextDouble") * seenSoFar.toD <= maxSize.toD, { + val idxToSwap = cb.memoize(rand.invoke[Int, Int]("nextInt", maxSize)) + builder.overwrite(cb, eltVal, idxToSwap) + cb.assign(garbage, garbage + 1L) + cb.invokeVoid(gc, cb.this_) + }, + ) + }, + ) } def dump(cb: EmitCodeBuilder, prefix: String): Unit = { - cb.println(s"> dumping reservoir: $prefix with size=", maxSize.toS,", seen=", seenSoFar.toS) + cb.println(s"> dumping reservoir: $prefix with size=", maxSize.toS, ", seen=", 
seenSoFar.toS) val j = cb.newLocal[Int]("j", 0) - cb.while_(j < builder.size, { - cb.println(" j=", j.toS, ", elt=", cb.strValue(builder.loadElement(cb, j))) - cb.assign(j, j + 1) - }) + cb.while_( + j < builder.size, { + cb.println(" j=", j.toS, ", elt=", cb.strValue(builder.loadElement(cb, j))) + cb.assign(j, j + 1) + }, + ) } def combine(cb: EmitCodeBuilder, other: ReservoirSampleRVAS): Unit = { val j = cb.newLocal[Int]("j") - cb.if_(other.builder.size < maxSize, { + cb.if_( + other.builder.size < maxSize, { - cb.assign(j, 0) - cb.while_(j < other.builder.size, { - seqOp(cb, cb.memoize(other.builder.loadElement(cb, j))) - cb.assign(j, j + 1) - }) - }, { - cb.if_(builder.size < maxSize, { cb.assign(j, 0) - cb.while_(j < builder.size, { - other.seqOp(cb, cb.memoize(builder.loadElement(cb, j))) - cb.assign(j, j + 1) - }) - - cb.assign(seenSoFar, other.seenSoFar) - cb.assign(garbage, other.garbage) - val tmpRegion = cb.newLocal[Region]("tmpRegion", region) - cb.assign(r, other.region) - cb.assign(other.r, tmpRegion) - cb += tmpRegion.invoke[Unit]("invalidate") - builder.cloneFrom(cb, other.builder) - + cb.while_( + j < other.builder.size, { + seqOp(cb, cb.memoize(other.builder.loadElement(cb, j))) + cb.assign(j, j + 1) + }, + ) }, { - val newBuilder = new StagedArrayBuilder(eltPType, kb, region) - newBuilder.initializeWithCapacity(cb, maxSize) - - val totalWeightLeft = cb.newLocal("totalWeightLeft", seenSoFar.toD) - val totalWeightRight = cb.newLocal("totalWeightRight", other.seenSoFar.toD) + cb.if_( + builder.size < maxSize, { + cb.assign(j, 0) + cb.while_( + j < builder.size, { + other.seqOp(cb, cb.memoize(builder.loadElement(cb, j))) + cb.assign(j, j + 1) + }, + ) - val leftSize = cb.newLocal[Int]("leftSize", builder.size) - val rightSize = cb.newLocal[Int]("rightSize", other.builder.size) + cb.assign(seenSoFar, other.seenSoFar) + cb.assign(garbage, other.garbage) + val tmpRegion = cb.newLocal[Region]("tmpRegion", region) + cb.assign(r, other.region) + cb.assign(other.r, tmpRegion) + cb += tmpRegion.invoke[Unit]("invalidate") + builder.cloneFrom(cb, other.builder) - cb.assign(j, 0) - cb.while_(j < maxSize, { - val x = cb.memoize(rand.invoke[Double]("nextDouble")) - cb.if_(x * (totalWeightLeft + totalWeightRight) <= totalWeightLeft, { - - val idxToSample = cb.memoize(rand.invoke[Int, Int]("nextInt", leftSize)) - builder.loadElement(cb, idxToSample).toI(cb).consume(cb, - newBuilder.setMissing(cb), - newBuilder.append(cb, _, false)) - cb.assign(leftSize, leftSize - 1) - cb.assign(totalWeightLeft, totalWeightLeft - 1) - cb.if_(idxToSample < leftSize, { - builder.overwrite(cb, cb.memoize(builder.loadElement(cb, leftSize)), idxToSample, false) - }) }, { - val idxToSample = cb.memoize(rand.invoke[Int, Int]("nextInt", rightSize)) - other.builder.loadElement(cb, idxToSample).toI(cb).consume(cb, - newBuilder.setMissing(cb), - newBuilder.append(cb, _, true)) - cb.assign(rightSize, rightSize - 1) - cb.assign(totalWeightRight, totalWeightRight - 1) - cb.if_(idxToSample < rightSize, { - other.builder.overwrite(cb, cb.memoize(other.builder.loadElement(cb, rightSize)), idxToSample, false) - }) - }) - cb.assign(j, j + 1) - }) - builder.cloneFrom(cb, newBuilder) - cb.assign(seenSoFar, seenSoFar + other.seenSoFar) - cb.assign(garbage, garbage + leftSize.toL) - gc(cb) - }) - }) + val newBuilder = new StagedArrayBuilder(eltPType, kb, region) + newBuilder.initializeWithCapacity(cb, maxSize) + + val totalWeightLeft = cb.newLocal("totalWeightLeft", seenSoFar.toD) + val totalWeightRight = 
cb.newLocal("totalWeightRight", other.seenSoFar.toD) + + val leftSize = cb.newLocal[Int]("leftSize", builder.size) + val rightSize = cb.newLocal[Int]("rightSize", other.builder.size) + + cb.assign(j, 0) + cb.while_( + j < maxSize, { + val x = cb.memoize(rand.invoke[Double]("nextDouble")) + cb.if_( + x * (totalWeightLeft + totalWeightRight) <= totalWeightLeft, { + + val idxToSample = cb.memoize(rand.invoke[Int, Int]("nextInt", leftSize)) + builder.loadElement(cb, idxToSample).toI(cb).consume( + cb, + newBuilder.setMissing(cb), + newBuilder.append(cb, _, false), + ) + cb.assign(leftSize, leftSize - 1) + cb.assign(totalWeightLeft, totalWeightLeft - 1) + cb.if_( + idxToSample < leftSize, + builder.overwrite( + cb, + cb.memoize(builder.loadElement(cb, leftSize)), + idxToSample, + false, + ), + ) + }, { + val idxToSample = cb.memoize(rand.invoke[Int, Int]("nextInt", rightSize)) + other.builder.loadElement(cb, idxToSample).toI(cb).consume( + cb, + newBuilder.setMissing(cb), + newBuilder.append(cb, _, true), + ) + cb.assign(rightSize, rightSize - 1) + cb.assign(totalWeightRight, totalWeightRight - 1) + cb.if_( + idxToSample < rightSize, + other.builder.overwrite( + cb, + cb.memoize(other.builder.loadElement(cb, rightSize)), + idxToSample, + false, + ), + ) + }, + ) + cb.assign(j, j + 1) + }, + ) + builder.cloneFrom(cb, newBuilder) + cb.assign(seenSoFar, seenSoFar + other.seenSoFar) + cb.assign(garbage, garbage + leftSize.toL) + cb.invokeVoid(gc, cb.this_) + }, + ) + }, + ) } - def resultArray(cb: EmitCodeBuilder, region: Value[Region], resType: PCanonicalArray): SIndexablePointerValue = { + def resultArray(cb: EmitCodeBuilder, region: Value[Region], resType: PCanonicalArray) + : SIndexablePointerValue = resType.constructFromElements(cb, region, builder.size, deepCopy = true) { (cb, idx) => builder.loadElement(cb, idx).toI(cb) } - } def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit = { cb.assign(maxSize, Region.loadInt(maxSizeOffset(src))) @@ -223,25 +262,33 @@ class ReservoirSampleAggregator(typ: VirtualTypeWithReq) extends StagedAggregato val initOpTypes: Seq[Type] = Array(TInt32) val seqOpTypes: Seq[Type] = Array(typ.t) - protected def _initOp(cb: EmitCodeBuilder, state: ReservoirSampleRVAS, init: Array[EmitCode]): Unit = { + protected def _initOp(cb: EmitCodeBuilder, state: ReservoirSampleRVAS, init: Array[EmitCode]) + : Unit = { assert(init.length == 1) val Array(sizeTriplet) = init sizeTriplet.toI(cb) - .consume(cb, + .consume( + cb, cb += Code._fatal[Unit](s"argument 'n' for 'hl.agg.reservoir_sample' may not be missing"), - sc => state.init(cb, sc.asInt.value) + sc => state.init(cb, sc.asInt.value), ) } - protected def _seqOp(cb: EmitCodeBuilder, state: ReservoirSampleRVAS, seq: Array[EmitCode]): Unit = { + protected def _seqOp(cb: EmitCodeBuilder, state: ReservoirSampleRVAS, seq: Array[EmitCode]) + : Unit = { val Array(elt: EmitCode) = seq state.seqOp(cb, elt) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: ReservoirSampleRVAS, other: ReservoirSampleRVAS): Unit = state.combine(cb, other) + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: ReservoirSampleRVAS, + other: ReservoirSampleRVAS, + ): Unit = state.combine(cb, other) - protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { + protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = // deepCopy is handled by state.resultArray IEmitCode.present(cb, state.resultArray(cb, 
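In the combine branch above where both reservoirs are already full, each of the maxSize output slots is drawn from the left or right reservoir with probability proportional to how many elements that side has seen, then sampled without replacement within the chosen side. A rough unstaged sketch of that merge (names are illustrative; missingness and region management are ignored):

```scala
import scala.reflect.ClassTag
import scala.util.Random

// Merge two full reservoirs of size k that have seen nLeft and nRight
// elements, mirroring the weighted sampling in ReservoirSampleRVAS.combine.
def mergeReservoirs[A: ClassTag](
  left: Array[A], nLeft: Long,
  right: Array[A], nRight: Long,
  rand: Random
): Array[A] = {
  require(left.length == right.length)
  val k = left.length
  val out = new Array[A](k)
  var leftSize = k
  var rightSize = k
  var wLeft = nLeft.toDouble
  var wRight = nRight.toDouble
  var j = 0
  while (j < k) {
    if (rand.nextDouble() * (wLeft + wRight) <= wLeft) {
      // sample without replacement from the left reservoir
      val i = rand.nextInt(leftSize)
      out(j) = left(i)
      leftSize -= 1; wLeft -= 1
      left(i) = left(leftSize) // move the last live element into the hole
    } else {
      val i = rand.nextInt(rightSize)
      out(j) = right(i)
      rightSize -= 1; wRight -= 1
      right(i) = right(rightSize)
    }
    j += 1
  }
  out
}
```

Because exactly k draws are made and each side starts with k live elements, neither side can be exhausted before the loop ends, which is the same invariant the staged loop relies on.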
region, resultPType)) - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/StagedAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/StagedAggregator.scala index 4d1a5330b03..02408c7bf44 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/StagedAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/StagedAggregator.scala @@ -3,8 +3,7 @@ package is.hail.expr.ir.agg import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext -import is.hail.expr.ir.{EmitCode, EmitCodeBuilder, EmitContext, IEmitCode} -import is.hail.types.physical.PType +import is.hail.expr.ir.{EmitCode, EmitCodeBuilder, IEmitCode} import is.hail.types.physical.stypes.EmitType import is.hail.types.virtual.Type @@ -15,18 +14,35 @@ abstract class StagedAggregator { def initOpTypes: Seq[Type] def seqOpTypes: Seq[Type] - protected def _initOp(cb: EmitCodeBuilder, state: State, init: Array[EmitCode]) + protected def _initOp(cb: EmitCodeBuilder, state: State, init: Array[EmitCode]): Unit - protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]) + protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: State, other: State): Unit + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: State, + other: State, + ): Unit protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode - def initOp(cb: EmitCodeBuilder, state: AggregatorState, init: Array[EmitCode]) = _initOp(cb, state.asInstanceOf[State], init) - def seqOp(cb: EmitCodeBuilder, state: AggregatorState, seq: Array[EmitCode]) = _seqOp(cb, state.asInstanceOf[State], seq) + def initOp(cb: EmitCodeBuilder, state: AggregatorState, init: Array[EmitCode]) = + _initOp(cb, state.asInstanceOf[State], init) + + def seqOp(cb: EmitCodeBuilder, state: AggregatorState, seq: Array[EmitCode]) = + _seqOp(cb, state.asInstanceOf[State], seq) + + def combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: AggregatorState, + other: AggregatorState, + ): Unit = + _combOp(ctx, cb, region, state.asInstanceOf[State], other.asInstanceOf[State]) - def combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: AggregatorState, other: AggregatorState) = _combOp(ctx, cb, state.asInstanceOf[State], other.asInstanceOf[State]) def result(cb: EmitCodeBuilder, state: AggregatorState, region: Value[Region]): IEmitCode = _result(cb, state.asInstanceOf[State], region) } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/StagedArrayBuilder.scala b/hail/src/main/scala/is/hail/expr/ir/agg/StagedArrayBuilder.scala index 54a9fdc29db..554848e3104 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/StagedArrayBuilder.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/StagedArrayBuilder.scala @@ -12,47 +12,60 @@ object StagedArrayBuilder { val END_SERIALIZATION: Int = 0x12345678 } -class StagedArrayBuilder(eltType: PType, kb: EmitClassBuilder[_], region: Value[Region], var initialCapacity: Int = 8) { - val eltArray = PCanonicalArray(eltType.setRequired(false), required = true) // element type must be optional for serialization to work +class StagedArrayBuilder( + eltType: PType, + kb: EmitClassBuilder[_], + region: Value[Region], + var initialCapacity: Int = 8, +) { + val eltArray = + PCanonicalArray( + eltType.setRequired(false), + required = true, + ) // element type must be optional for serialization to work val stateType = 
PCanonicalTuple(true, PInt32Required, PInt32Required, eltArray) - val size: Settable[Int] = kb.genFieldThisRef[Int]("size") + val size: ThisFieldRef[Int] = kb.genFieldThisRef[Int]("size") private val capacity = kb.genFieldThisRef[Int]("capacity") val data = kb.genFieldThisRef[Long]("data") - private val tmpOff = kb.genFieldThisRef[Long]("tmp_offset") private val currentSizeOffset: Code[Long] => Code[Long] = stateType.fieldOffset(_, 0) private val capacityOffset: Code[Long] => Code[Long] = stateType.fieldOffset(_, 1) private val dataOffset: Code[Long] => Code[Long] = stateType.fieldOffset(_, 2) def loadFrom(cb: EmitCodeBuilder, src: Code[Long]): Unit = { - cb.assign(tmpOff, src) + val tmpOff = cb.memoize(src) cb.assign(size, Region.loadInt(currentSizeOffset(tmpOff))) cb.assign(capacity, Region.loadInt(capacityOffset(tmpOff))) cb.assign(data, Region.loadAddress(dataOffset(tmpOff))) } - def cloneFrom(cb: EmitCodeBuilder, other: StagedArrayBuilder): Unit = { - cb.assign(tmpOff, other.tmpOff) cb.assign(size, other.size) cb.assign(data, other.data) cb.assign(capacity, other.capacity) } def copyFrom(cb: EmitCodeBuilder, src: Code[Long]): Unit = { - cb.assign(tmpOff, src) + val tmpOff = cb.memoize(src) cb.assign(size, Region.loadInt(currentSizeOffset(tmpOff))) cb.assign(capacity, Region.loadInt(capacityOffset(tmpOff))) - cb.assign(data, eltArray.store(cb, region, eltArray.loadCheapSCode(cb, Region.loadAddress(dataOffset(tmpOff))), deepCopy = true)) + cb.assign( + data, + eltArray.store( + cb, + region, + eltArray.loadCheapSCode(cb, Region.loadAddress(dataOffset(tmpOff))), + deepCopy = true, + ), + ) } - def reallocateData(cb: EmitCodeBuilder): Unit = { + def reallocateData(cb: EmitCodeBuilder): Unit = cb.assign(data, eltArray.store(cb, region, eltArray.loadCheapSCode(cb, data), deepCopy = true)) - } def storeTo(cb: EmitCodeBuilder, dest: Code[Long]): Unit = { - cb.assign(tmpOff, dest) + val tmpOff = cb.memoize(dest) cb += Region.storeInt(currentSizeOffset(tmpOff), size) cb += Region.storeInt(capacityOffset(tmpOff), capacity) cb += Region.storeAddress(dataOffset(tmpOff), data) @@ -81,8 +94,9 @@ class StagedArrayBuilder(eltType: PType, kb: EmitClassBuilder[_], region: Value[ .apply(cb, region, ib) cb.assign(data, eltArray.store(cb, region, decValue, deepCopy = false)) - cb.if_(ib.readInt() cne StagedArrayBuilder.END_SERIALIZATION, - cb._fatal(s"StagedArrayBuilder serialization failed") + cb.if_( + ib.readInt() cne StagedArrayBuilder.END_SERIALIZATION, + cb._fatal(s"StagedArrayBuilder serialization failed"), ) } } @@ -92,8 +106,8 @@ class StagedArrayBuilder(eltType: PType, kb: EmitClassBuilder[_], region: Value[ resize(cb) } - def setMissing(cb: EmitCodeBuilder): Unit = incrementSize(cb) // all elements set to missing on initialization - + def setMissing(cb: EmitCodeBuilder): Unit = + incrementSize(cb) // all elements set to missing on initialization def append(cb: EmitCodeBuilder, elt: SValue, deepCopy: Boolean = true): Unit = { eltArray.setElementPresent(cb, data, size) @@ -101,13 +115,23 @@ class StagedArrayBuilder(eltType: PType, kb: EmitClassBuilder[_], region: Value[ incrementSize(cb) } - def overwrite(cb: EmitCodeBuilder, elt: EmitValue, idx: Value[Int], deepCopy: Boolean = true): Unit = { - elt.toI(cb).consume(cb, + def overwrite(cb: EmitCodeBuilder, elt: EmitValue, idx: Value[Int], deepCopy: Boolean = true) + : Unit = + elt.toI(cb).consume( + cb, PContainer.unsafeSetElementMissing(cb, eltArray, data, idx), - value => eltType.storeAtAddress(cb, eltArray.elementOffset(data, capacity, idx), 
region, value, deepCopy)) - } - - def initializeWithCapacity(cb: EmitCodeBuilder, capacity: Code[Int]): Unit = initialize(cb, 0, capacity) + value => + eltType.storeAtAddress( + cb, + eltArray.elementOffset(data, capacity, idx), + region, + value, + deepCopy, + ), + ) + + def initializeWithCapacity(cb: EmitCodeBuilder, capacity: Code[Int]): Unit = + initialize(cb, 0, capacity) def initialize(cb: EmitCodeBuilder): Unit = initialize(cb, const(0), const(initialCapacity)) @@ -121,18 +145,25 @@ class StagedArrayBuilder(eltType: PType, kb: EmitClassBuilder[_], region: Value[ def elementOffset(cb: EmitCodeBuilder, idx: Value[Int]): Value[Long] = cb.memoize(eltArray.elementOffset(data, capacity, idx)) - def loadElement(cb: EmitCodeBuilder, idx: Value[Int]): EmitCode = { val m = eltArray.isElementMissing(data, idx) EmitCode(Code._empty, m, eltType.loadCheapSCode(cb, eltArray.loadElement(data, capacity, idx))) } - private def resize(cb: EmitCodeBuilder): Unit = { - val newDataOffset = kb.genFieldThisRef[Long]("new_data_offset") - cb.if_(size.ceq(capacity), - { + def swap(cb: EmitCodeBuilder, p: Value[Int], q: Value[Int]): Unit = { + val pOff = elementOffset(cb, p) + val qOff = elementOffset(cb, q) + val tmpOff = elementOffset(cb, size) + cb += Region.copyFrom(pOff, tmpOff, eltType.byteSize) + cb += Region.copyFrom(qOff, pOff, eltType.byteSize) + cb += Region.copyFrom(tmpOff, qOff, eltType.byteSize) + } + + private def resize(cb: EmitCodeBuilder): Unit = + cb.if_( + size.ceq(capacity), { cb.assign(capacity, capacity * 2) cb.assign(data, eltArray.padWithMissing(cb, region, size, capacity, data)) - }) - } + }, + ) } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/StagedBlockLinkedList.scala b/hail/src/main/scala/is/hail/expr/ir/agg/StagedBlockLinkedList.scala index 8468348dcfb..f2895416721 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/StagedBlockLinkedList.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/StagedBlockLinkedList.scala @@ -27,14 +27,16 @@ class StagedBlockLinkedList(val elemType: PType, val kb: EmitClassBuilder[_]) { val storageType = PCanonicalStruct( "firstNode" -> PInt64Required, "lastNode" -> PInt64Required, - "totalCount" -> PInt32Required) + "totalCount" -> PInt32Required, + ) def load(cb: EmitCodeBuilder, src: Code[Long]): Unit = { cb += Code.memoize(src, "sbll_load_src") { src => Code( firstNode := Region.loadAddress(storageType.fieldOffset(src, "firstNode")), lastNode := Region.loadAddress(storageType.fieldOffset(src, "lastNode")), - totalCount := Region.loadInt(storageType.fieldOffset(src, "totalCount"))) + totalCount := Region.loadInt(storageType.fieldOffset(src, "totalCount")), + ) } } @@ -43,7 +45,8 @@ class StagedBlockLinkedList(val elemType: PType, val kb: EmitClassBuilder[_]) { Code( Region.storeAddress(storageType.fieldOffset(dst, "firstNode"), firstNode), Region.storeAddress(storageType.fieldOffset(dst, "lastNode"), lastNode), - Region.storeInt(storageType.fieldOffset(dst, "totalCount"), totalCount)) + Region.storeInt(storageType.fieldOffset(dst, "totalCount"), totalCount), + ) } type Node = Value[Long] @@ -54,7 +57,8 @@ class StagedBlockLinkedList(val elemType: PType, val kb: EmitClassBuilder[_]) { val nodeType = PCanonicalStruct( "buf" -> bufferType, "count" -> PInt32Required, - "next" -> PInt64Required) + "next" -> PInt64Required, + ) private def buffer(n: Node): Code[Long] = Region.loadAddress(nodeType.fieldOffset(n, "buf")) @@ -71,9 +75,6 @@ class StagedBlockLinkedList(val elemType: PType, val kb: EmitClassBuilder[_]) { private def next(n: Node): 
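Two details of the StagedArrayBuilder changes above are worth spelling out: resize doubles the backing storage whenever size reaches capacity, so the slot at index size is always allocated but unused, and the new swap helper uses exactly that slot as scratch space for its three raw copies. A plain in-memory model of that behaviour (the GrowableBuilder name and the untyped Array[Any] storage are illustrative only):

```scala
// In-memory model of StagedArrayBuilder's doubling resize and swap-via-spare-slot.
final class GrowableBuilder[A](initialCapacity: Int = 8) {
  require(initialCapacity > 0)
  private var data = new Array[Any](initialCapacity)
  private var size = 0

  // Double capacity when full, so index `size` always refers to an allocated,
  // unused slot (resize runs right after every size increment, as in the diff).
  private def resizeIfFull(): Unit =
    if (size == data.length) {
      val bigger = new Array[Any](data.length * 2)
      System.arraycopy(data, 0, bigger, 0, size)
      data = bigger
    }

  def append(a: A): Unit = { data(size) = a; size += 1; resizeIfFull() }

  // Exchange elements p and q through the spare slot at index `size`,
  // like the three Region.copyFrom calls in StagedArrayBuilder.swap.
  def swap(p: Int, q: Int): Unit = {
    data(size) = data(p)
    data(p) = data(q)
    data(q) = data(size)
  }

  def apply(i: Int): A = data(i).asInstanceOf[A]
  def length: Int = size
}
```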
Code[Long] = Region.loadAddress(nodeType.fieldOffset(n, "next")) - private def hasNext(n: Node): Code[Boolean] = - next(n) cne nil - private def setNext(cb: EmitCodeBuilder, n: Node, nNext: Node): Unit = cb += Region.storeAddress(nodeType.fieldOffset(n, "next"), nNext) @@ -83,7 +84,12 @@ class StagedBlockLinkedList(val elemType: PType, val kb: EmitClassBuilder[_]) { cb += Region.storeAddress(nodeType.fieldOffset(n, "next"), nil) } - private def pushPresent(cb: EmitCodeBuilder, n: Node)(store: (EmitCodeBuilder, Code[Long]) => Unit): Unit = { + private def pushPresent( + cb: EmitCodeBuilder, + n: Node, + )( + store: (EmitCodeBuilder, Code[Long]) => Unit + ): Unit = { bufferType.setElementPresent(cb, buffer(n), count(n)) store(cb, bufferType.elementOffset(buffer(n), capacity(n), count(n))) incrCount(cb, n) @@ -91,23 +97,34 @@ class StagedBlockLinkedList(val elemType: PType, val kb: EmitClassBuilder[_]) { private def pushMissing(cb: EmitCodeBuilder, n: Node): Unit = if (elemType.required) - cb._fatal(s"Cannot insert missing element of ptype '${elemType.asIdent}' at index ", count(n).toS, ".") + cb._fatal( + s"Cannot insert missing element of ptype '${elemType.asIdent}' at index ", + count(n).toS, + ".", + ) else { bufferType.setElementMissing(cb, buffer(n), count(n)) incrCount(cb, n) } - private def allocateNode(cb: EmitCodeBuilder, dstNode: Settable[Long])(r: Value[Region], cap: Code[Int]): Unit = { + private def allocateNode( + cb: EmitCodeBuilder, + dstNode: Settable[Long], + )( + r: Value[Region], + cap: Code[Int], + ): Unit = { val capMemo = cb.memoize[Int](cap) cb.assign(dstNode, r.allocate(nodeType.alignment, nodeType.byteSize)) initNode(cb, dstNode, buf = bufferType.allocate(r, capMemo), count = 0) bufferType.stagedInitialize(cb, buffer(dstNode), capMemo) } - private def initWithCapacity(cb: EmitCodeBuilder, r: Value[Region], initialCap: Code[Int]): Unit = { - allocateNode(cb, firstNode)(r, initialCap) - cb.assign(lastNode, firstNode) - cb.assign(totalCount, 0) + private def initWithCapacity(cb: EmitCodeBuilder, r: Value[Region], initialCap: Code[Int]) + : Unit = { + allocateNode(cb, firstNode)(r, initialCap) + cb.assign(lastNode, firstNode) + cb.assign(totalCount, 0) } def init(cb: EmitCodeBuilder, r: Value[Region]): Unit = @@ -128,71 +145,77 @@ class StagedBlockLinkedList(val elemType: PType, val kb: EmitClassBuilder[_]) { private def foreach(cb: EmitCodeBuilder)(f: (EmitCodeBuilder, EmitCode) => Unit): Unit = { foreachNode(cb) { n => val i = cb.newLocal[Int]("bll_foreach_i") - cb.for_(cb.assign(i, 0), i < count(n), cb.assign(i, i + 1), { - val elt = EmitCode.fromI(cb.emb) { cb => - IEmitCode(cb, - bufferType.isElementMissing(buffer(n), i), - elemType.loadCheapSCode(cb, bufferType.loadElement(buffer(n), capacity(n), i)) - ) - } - f(cb, elt) - }) + cb.for_( + cb.assign(i, 0), + i < count(n), + cb.assign(i, i + 1), { + val elt = EmitCode.fromI(cb.emb) { cb => + IEmitCode( + cb, + bufferType.isElementMissing(buffer(n), i), + elemType.loadCheapSCode(cb, bufferType.loadElement(buffer(n), capacity(n), i)), + ) + } + f(cb, elt) + }, + ) } } private def pushImpl(cb: EmitCodeBuilder, r: Value[Region], v: EmitCode): Unit = { - cb.if_(count(lastNode) >= capacity(lastNode), - pushNewBlockNode(cb, r, defaultBlockCap)) + cb.if_(count(lastNode) >= capacity(lastNode), pushNewBlockNode(cb, r, defaultBlockCap)) v.toI(cb) - .consume(cb, + .consume( + cb, pushMissing(cb, lastNode), - { sc => + sc => pushPresent(cb, lastNode) { (cb, addr) => elemType.storeAtAddress(cb, addr, r, sc, deepCopy = true) - } - 
}) + }, + ) cb.assign(totalCount, totalCount + 1) } def push(cb: EmitCodeBuilder, region: Value[Region], elt: EmitCode): Unit = { - val pushF = kb.genEmitMethod("blockLinkedListPush", - FastSeq[ParamType](typeInfo[Region], elt.emitParamType), typeInfo[Unit]) + val pushF = cb.emb.ecb.genEmitMethod( + "blockLinkedListPush", + FastSeq(typeInfo[Region], elt.emitParamType), + typeInfo[Unit], + ) pushF.voidWithBuilder { cb => - pushImpl(cb, - pushF.getCodeParam[Region](1), - pushF.getEmitParam(cb, 2)) + pushImpl(cb, pushF.getCodeParam[Region](1), pushF.getEmitParam(cb, 2)) } - cb.invokeVoid(pushF, region, elt) + cb.invokeVoid(pushF, cb.this_, region, elt) } def append(cb: EmitCodeBuilder, region: Value[Region], bll: StagedBlockLinkedList): Unit = { // it would take additional logic to get self-append to work, but we don't need it to anyways assert(bll ne this) assert(bll.elemType.isOfType(elemType)) - val appF = kb.genEmitMethod("blockLinkedListAppend", - FastSeq[ParamType](typeInfo[Region]), - typeInfo[Unit]) + val appF = + cb.emb.ecb.genEmitMethod("blockLinkedListAppend", FastSeq(typeInfo[Region]), typeInfo[Unit]) appF.voidWithBuilder { cb => - bll.foreach(cb) { (cb, elt) => - pushImpl(cb, appF.getCodeParam[Region](1), elt) - } + bll.foreach(cb)((cb, elt) => pushImpl(cb, appF.getCodeParam[Region](1), elt)) } - cb.invokeVoid(appF, region) + cb.invokeVoid(appF, cb.this_, region) } - def resultArray(cb: EmitCodeBuilder, region: Value[Region], resType: PCanonicalArray): SIndexablePointerValue = { - val (pushElement, finish) = resType.constructFromFunctions(cb, region, totalCount, deepCopy = true) - foreach(cb) { (cb, elt) => - pushElement(cb, elt.toI(cb)) - } + def resultArray(cb: EmitCodeBuilder, region: Value[Region], resType: PCanonicalArray) + : SIndexablePointerValue = { + val (pushElement, finish) = + resType.constructFromFunctions(cb, region, totalCount, deepCopy = true) + foreach(cb)((cb, elt) => pushElement(cb, elt.toI(cb))) finish(cb) } - def serialize(cb: EmitCodeBuilder, region: Value[Region], outputBuffer: Value[OutputBuffer]): Unit = { - val serF = kb.genEmitMethod("blockLinkedListSerialize", - FastSeq[ParamType](typeInfo[Region], typeInfo[OutputBuffer]), - typeInfo[Unit]) + def serialize(cb: EmitCodeBuilder, region: Value[Region], outputBuffer: Value[OutputBuffer]) + : Unit = { + val serF = cb.emb.ecb.genEmitMethod( + "blockLinkedListSerialize", + FastSeq(typeInfo[Region], typeInfo[OutputBuffer]), + typeInfo[Unit], + ) val ob = serF.getCodeParam[OutputBuffer](2) serF.voidWithBuilder { cb => val b = cb.newLocal[Long]("bll_serialize_b") @@ -203,22 +226,21 @@ class StagedBlockLinkedList(val elemType: PType, val kb: EmitClassBuilder[_]) { } cb += ob.writeBoolean(false) } - cb.invokeVoid(serF, region, outputBuffer) + cb.invokeVoid(serF, cb.this_, region, outputBuffer) } - def deserialize(cb: EmitCodeBuilder, region: Value[Region], inputBuffer: Value[InputBuffer]): Unit = { - val desF = kb.genEmitMethod("blockLinkedListDeserialize", - FastSeq[ParamType](typeInfo[Region], typeInfo[InputBuffer]), - typeInfo[Unit]) + def deserialize(cb: EmitCodeBuilder, region: Value[Region], inputBuffer: Value[InputBuffer]) + : Unit = { + val desF = cb.emb.ecb.genEmitMethod( + "blockLinkedListDeserialize", + FastSeq(typeInfo[Region], typeInfo[InputBuffer]), + typeInfo[Unit], + ) val r = desF.getCodeParam[Region](1) val ib = desF.getCodeParam[InputBuffer](2) val dec = bufferEType.buildDecoder(bufferType.virtualType, desF.ecb) - desF.voidWithBuilder { cb => - cb.while_(ib.readBoolean(), { - appendShallow(cb, 
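StagedBlockLinkedList, whose push/append/serialize paths are reformatted above, is essentially an unrolled linked list: elements are appended into a fixed-capacity block, and a new block is chained on when the last one fills. A plain-Scala sketch of that shape (illustrative class, arbitrary block capacity, no missingness or region handling):

```scala
import scala.collection.mutable.ArrayBuffer

// Unrolled (block) linked list, the shape behind StagedBlockLinkedList.
final class BlockList[A](blockCapacity: Int = 64) {
  private final class Node {
    val buf = new ArrayBuffer[A](blockCapacity)
    var next: Node = null
  }

  private val first = new Node
  private var last = first
  private var totalCount = 0

  def push(a: A): Unit = {
    if (last.buf.size >= blockCapacity) { // last block full: chain a new one
      val n = new Node
      last.next = n
      last = n
    }
    last.buf += a
    totalCount += 1
  }

  // Visit every element in insertion order, block by block (cf. foreach/foreachNode).
  def foreach(f: A => Unit): Unit = {
    var n = first
    while (n != null) { n.buf.foreach(f); n = n.next }
  }

  def size: Int = totalCount
}
```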
r, dec(cb, r, ib)) - }) - } - cb.invokeVoid(desF, region, inputBuffer) + desF.voidWithBuilder(cb => cb.while_(ib.readBoolean(), appendShallow(cb, r, dec(cb, r, ib)))) + cb.invokeVoid(desF, cb.this_, region, inputBuffer) } private def appendShallow(cb: EmitCodeBuilder, r: Value[Region], aCode: SValue): Unit = { @@ -230,12 +252,15 @@ class StagedBlockLinkedList(val elemType: PType, val kb: EmitClassBuilder[_]) { cb.assign(totalCount, totalCount + buff.length) } - def initWithDeepCopy(cb: EmitCodeBuilder, region: Value[Region], other: StagedBlockLinkedList): Unit = { + def initWithDeepCopy(cb: EmitCodeBuilder, region: Value[Region], other: StagedBlockLinkedList) + : Unit = { assert(other ne this) assert(other.kb eq kb) - val initF = kb.genEmitMethod("blockLinkedListDeepCopy", + val initF = cb.emb.ecb.genEmitMethod( + "blockLinkedListDeepCopy", FastSeq[ParamType](typeInfo[Region]), - typeInfo[Unit]) + typeInfo[Unit], + ) val r = initF.getCodeParam[Region](1) initF.voidWithBuilder { cb => // sets firstNode @@ -244,17 +269,19 @@ class StagedBlockLinkedList(val elemType: PType, val kb: EmitClassBuilder[_]) { val buf = cb.newLocal[Long]("sbll_init_deepcopy_buf", buffer(firstNode)) other.foreach(cb) { (cb, elt) => elt.toI(cb) - .consume(cb, + .consume( + cb, PContainer.unsafeSetElementMissing(cb, bufferType, buf, i), { sc => bufferType.setElementPresent(cb, buf, i) elemType.storeAtAddress(cb, bufferType.elementOffset(buf, i), r, sc, deepCopy = true) - }) + }, + ) incrCount(cb, firstNode) cb.assign(i, i + 1) } cb.assign(totalCount, other.totalCount) } - cb.invokeVoid(initF, region) + cb.invokeVoid(initF, cb.this_, region) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/TakeAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/TakeAggregator.scala index 889d1851188..00295f2efad 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/TakeAggregator.scala +++ b/hail/src/main/scala/is/hail/expr/ir/agg/TakeAggregator.scala @@ -12,7 +12,8 @@ import is.hail.types.physical.stypes.concrete.{SIndexablePointer, SIndexablePoin import is.hail.types.virtual.{TInt32, Type} import is.hail.utils._ -class TakeRVAS(val eltType: VirtualTypeWithReq, val kb: EmitClassBuilder[_]) extends AggregatorState { +class TakeRVAS(val eltType: VirtualTypeWithReq, val kb: EmitClassBuilder[_]) + extends AggregatorState { val eltPType = eltType.canonicalPType private val r: ThisFieldRef[Region] = kb.genFieldThisRef[Region]() @@ -27,38 +28,43 @@ class TakeRVAS(val eltType: VirtualTypeWithReq, val kb: EmitClassBuilder[_]) ext def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = cb += region.getNewRegion(regionSize) def createState(cb: EmitCodeBuilder): Unit = - cb.if_(region.isNull, { - cb.assign(r, Region.stagedCreate(regionSize, kb.pool())) - }) + cb.if_(region.isNull, cb.assign(r, Region.stagedCreate(regionSize, kb.pool()))) - override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = { regionLoader(cb, r) cb.assign(maxSize, Region.loadInt(maxSizeOffset(src))) builder.loadFrom(cb, builderStateOffset(src)) } - override def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = { - cb.if_(region.isValid, - { + override def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = { + cb.if_( + 
region.isValid, { regionStorer(cb, region) cb += region.invalidate() cb += Region.storeInt(maxSizeOffset(dest), maxSize) builder.storeTo(cb, builderStateOffset(dest)) - }) + }, + ) } def serialize(codec: BufferSpec): (EmitCodeBuilder, Value[OutputBuffer]) => Unit = { - { (cb: EmitCodeBuilder, ob: Value[OutputBuffer]) => + (cb: EmitCodeBuilder, ob: Value[OutputBuffer]) => cb += ob.writeInt(maxSize) builder.serialize(codec)(cb, ob) - } } def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = { - { (cb: EmitCodeBuilder, ib: Value[InputBuffer]) => + (cb: EmitCodeBuilder, ib: Value[InputBuffer]) => cb.assign(maxSize, ib.readInt()) builder.deserialize(codec)(cb, ib) - } } def init(cb: EmitCodeBuilder, _maxSize: Code[Int]): Unit = { @@ -66,32 +72,30 @@ class TakeRVAS(val eltType: VirtualTypeWithReq, val kb: EmitClassBuilder[_]) ext builder.initialize(cb) } - def seqOp(cb: EmitCodeBuilder, elt: EmitCode): Unit = { - cb.if_(builder.size < maxSize, + def seqOp(cb: EmitCodeBuilder, elt: EmitCode): Unit = + cb.if_( + builder.size < maxSize, elt.toI(cb) - .consume(cb, - builder.setMissing(cb), - sc => builder.append(cb, sc))) - } + .consume(cb, builder.setMissing(cb), sc => builder.append(cb, sc)), + ) def combine(cb: EmitCodeBuilder, other: TakeRVAS): Unit = { val j = kb.genFieldThisRef[Int]() cb.assign(j, 0) - cb.while_((builder.size < maxSize) && (j < other.builder.size), - { + cb.while_( + (builder.size < maxSize) && (j < other.builder.size), { other.builder.loadElement(cb, j).toI(cb) - .consume(cb, - builder.setMissing(cb), - sc => builder.append(cb, sc)) + .consume(cb, builder.setMissing(cb), sc => builder.append(cb, sc)) cb.assign(j, j + 1) - }) + }, + ) } - def resultArray(cb: EmitCodeBuilder, region: Value[Region], resType: PCanonicalArray): SIndexablePointerValue = { + def resultArray(cb: EmitCodeBuilder, region: Value[Region], resType: PCanonicalArray) + : SIndexablePointerValue = resType.constructFromElements(cb, region, builder.size, deepCopy = true) { (cb, idx) => builder.loadElement(cb, idx).toI(cb) } - } def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit = { cb.assign(maxSize, Region.loadInt(maxSizeOffset(src))) @@ -112,9 +116,10 @@ class TakeAggregator(typ: VirtualTypeWithReq) extends StagedAggregator { assert(init.length == 1) val Array(sizeTriplet) = init sizeTriplet.toI(cb) - .consume(cb, + .consume( + cb, cb += Code._fatal[Unit](s"argument 'n' for 'hl.agg.take' may not be missing"), - sc => state.init(cb, sc.asInt.value) + sc => state.init(cb, sc.asInt.value), ) } @@ -123,10 +128,15 @@ class TakeAggregator(typ: VirtualTypeWithReq) extends StagedAggregator { state.seqOp(cb, elt) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: TakeRVAS, other: TakeRVAS): Unit = state.combine(cb, other) + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: TakeRVAS, + other: TakeRVAS, + ): Unit = state.combine(cb, other) - protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { + protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = // deepCopy is handled by state.resultArray IEmitCode.present(cb, state.resultArray(cb, region, resultPType)) - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/agg/TakeByAggregator.scala b/hail/src/main/scala/is/hail/expr/ir/agg/TakeByAggregator.scala index 0cac4a38eca..3fa6237ba57 100644 --- a/hail/src/main/scala/is/hail/expr/ir/agg/TakeByAggregator.scala +++ 
b/hail/src/main/scala/is/hail/expr/ir/agg/TakeByAggregator.scala @@ -3,14 +3,18 @@ package is.hail.expr.ir.agg import is.hail.annotations.Region import is.hail.asm4s.{Code, _} import is.hail.backend.ExecuteContext +import is.hail.expr.ir.{ + Ascending, EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitValue, IEmitCode, ParamType, SortOrder, +} import is.hail.expr.ir.orderings.StructOrdering -import is.hail.expr.ir.{Ascending, EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitValue, IEmitCode, ParamType, SortOrder} import is.hail.io.{BufferSpec, InputBuffer, OutputBuffer} import is.hail.types.VirtualTypeWithReq import is.hail.types.physical._ -import is.hail.types.physical.stypes.concrete.{SBaseStructPointerValue, SIndexablePointer, SIndexablePointerValue} -import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.{EmitType, SValue} +import is.hail.types.physical.stypes.concrete.{ + SBaseStructPointerValue, SIndexablePointer, SIndexablePointerValue, +} +import is.hail.types.physical.stypes.interfaces._ import is.hail.types.virtual.{TInt32, Type} import is.hail.utils._ @@ -18,7 +22,12 @@ object TakeByRVAS { val END_SERIALIZATION: Int = 0x1324 } -class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWithReq, val kb: EmitClassBuilder[_], so: SortOrder = Ascending) extends AggregatorState { +class TakeByRVAS( + val valueVType: VirtualTypeWithReq, + val keyVType: VirtualTypeWithReq, + val kb: EmitClassBuilder[_], + so: SortOrder = Ascending, +) extends AggregatorState { private val r: Settable[Region] = kb.genFieldThisRef[Region]("takeby_region") val valueType: PType = valueVType.canonicalPType @@ -37,7 +46,9 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi private val tempPtr = kb.genFieldThisRef[Long]("tmp_ptr") private val canHaveGarbage = eltTuple.containsPointers - private val (garbage, maxGarbage) = if (canHaveGarbage) (kb.genFieldThisRef[Int](), kb.genFieldThisRef[Int]()) else (null, null) + + private val (garbage, maxGarbage) = + if (canHaveGarbage) (kb.genFieldThisRef[Int](), kb.genFieldThisRef[Int]()) else (null, null) private val garbageFields: IndexedSeq[(String, PType)] = if (canHaveGarbage) FastSeq(("current_garbage", PInt32Required), ("max_garbage", PInt32Required)) @@ -45,12 +56,15 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi FastSeq() val storageType: PStruct = - PCanonicalStruct(true, - Array(("state", ab.stateType), + PCanonicalStruct( + true, + Array( + ("state", ab.stateType), ("staging", PInt64Required), ("key_stage", PInt64Required), ("max_index", PInt64Required), - ("max_size", PInt32Required)) ++ garbageFields: _* + ("max_size", PInt32Required), + ) ++ garbageFields: _* ) def compareKey(cb: EmitCodeBuilder, k1: EmitValue, k2: EmitValue): Code[Int] = { @@ -59,11 +73,23 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi } private def compareIndexedKey(cb: EmitCodeBuilder, k1: SValue, k2: SValue): Value[Int] = { - val ord = StructOrdering.make(k1.st.asInstanceOf[SBaseStruct], k2.st.asInstanceOf[SBaseStruct], cb.emb.ecb, Array(so, Ascending), true) + val ord = StructOrdering.make( + k1.st.asInstanceOf[SBaseStruct], + k2.st.asInstanceOf[SBaseStruct], + cb.emb.ecb, + Array(so, Ascending), + true, + ) ord.compareNonnull(cb, k1, k2) } - private def maybeGCCode(cb: EmitCodeBuilder, alwaysRun: EmitCodeBuilder => Unit)(runIfGarbage: EmitCodeBuilder => Unit, runBefore: Boolean = false): Unit = { + private def maybeGCCode( + cb: 
EmitCodeBuilder, + alwaysRun: EmitCodeBuilder => Unit, + )( + runIfGarbage: EmitCodeBuilder => Unit, + runBefore: Boolean = false, + ): Unit = { val gc = (if (canHaveGarbage) runIfGarbage else (cb: EmitCodeBuilder) => ()) if (runBefore) { gc(cb) @@ -77,23 +103,34 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi def newState(cb: EmitCodeBuilder, off: Value[Long]): Unit = cb += region.getNewRegion(regionSize) def createState(cb: EmitCodeBuilder): Unit = - cb.if_(region.isNull, { - cb.assign(r, Region.stagedCreate(regionSize, kb.pool())) - cb += region.invalidate() - }) + cb.if_( + region.isNull, { + cb.assign(r, Region.stagedCreate(regionSize, kb.pool())) + cb += region.invalidate() + }, + ) - override def load(cb: EmitCodeBuilder, regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, src: Value[Long]): Unit = { + override def load( + cb: EmitCodeBuilder, + regionLoader: (EmitCodeBuilder, Value[Region]) => Unit, + src: Value[Long], + ): Unit = { regionLoader(cb, r) loadFields(cb, src) } - override def store(cb: EmitCodeBuilder, regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, dest: Value[Long]): Unit = { - cb.if_(region.isValid, - { + override def store( + cb: EmitCodeBuilder, + regionStorer: (EmitCodeBuilder, Value[Region]) => Unit, + dest: Value[Long], + ): Unit = { + cb.if_( + region.isValid, { regionStorer(cb, region) cb += region.invalidate() storeFields(cb, dest) - }) + }, + ) } private def initStaging(cb: EmitCodeBuilder): Unit = { @@ -102,15 +139,21 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi } def initialize(cb: EmitCodeBuilder, _maxSize: Code[Int]): Unit = { - maybeGCCode(cb, + maybeGCCode( + cb, { cb => cb.assign(maxIndex, 0L) cb.assign(maxSize, _maxSize) - cb.if_(maxSize < 0, - cb += Code._fatal[Unit](const("'take': 'n' cannot be negative, found '").concat(maxSize.toS))) + cb.if_( + maxSize < 0, + cb += Code._fatal[Unit]( + const("'take': 'n' cannot be negative, found '").concat(maxSize.toS) + ), + ) initStaging(cb) ab.initialize(cb) - })({ cb => + }, + )({ cb => cb.assign(garbage, 0) cb.assign(maxGarbage, Code.invokeStatic2[Math, Int, Int, Int]("max", maxSize * 2, 256)) }) @@ -118,14 +161,15 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi private def storeFields(cb: EmitCodeBuilder, destc: Code[Long]): Unit = { val dest = cb.newLocal("tba_store_fields_dest", destc) - maybeGCCode(cb, + maybeGCCode( + cb, { cb => ab.storeTo(cb, storageType.fieldOffset(dest, 0)) cb += Region.storeAddress(storageType.fieldOffset(dest, 1), staging) cb += Region.storeAddress(storageType.fieldOffset(dest, 2), keyStage) cb += Region.storeLong(storageType.fieldOffset(dest, 3), maxIndex) cb += Region.storeInt(storageType.fieldOffset(dest, 4), maxSize) - } + }, )({ cb => cb += Region.storeInt(storageType.fieldOffset(dest, 5), garbage) cb += Region.storeInt(storageType.fieldOffset(dest, 6), maxGarbage) @@ -134,65 +178,67 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi private def loadFields(cb: EmitCodeBuilder, srcc: Code[Long]): Unit = { val src = cb.newLocal("takeby_rvas_load_fields_src", srcc) - maybeGCCode(cb, + maybeGCCode( + cb, { cb => ab.loadFrom(cb, storageType.fieldOffset(src, 0)) cb.assign(staging, Region.loadAddress(storageType.fieldOffset(src, 1))) cb.assign(keyStage, Region.loadAddress(storageType.fieldOffset(src, 2))) cb.assign(maxIndex, Region.loadLong(storageType.fieldOffset(src, 3))) cb.assign(maxSize, 
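The maybeGCCode helper reformatted above captures a small but useful pattern: garbage-tracking instructions are only emitted when the element tuple can actually hold pointers into the region, and they can be scheduled before or after the main body. Stripped of the EmitCodeBuilder plumbing, the control flow is roughly as follows (plain thunks stand in for the staged code blocks; a sketch, not the Hail API):

```scala
// Model of TakeByRVAS.maybeGCCode: run the garbage-tracking block only when
// the stored values can hold region pointers, either before or after the body.
def maybeGCCode(canHaveGarbage: Boolean)(alwaysRun: () => Unit)(runIfGarbage: () => Unit, runBefore: Boolean = false): Unit = {
  val gc: () => Unit = if (canHaveGarbage) runIfGarbage else (() => ())
  if (runBefore) { gc(); alwaysRun() }
  else { alwaysRun(); gc() }
}
```

In the real class, canHaveGarbage is eltTuple.containsPointers and both blocks receive the EmitCodeBuilder instead of being parameterless thunks.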
Region.loadInt(storageType.fieldOffset(src, 4))) - } + }, )({ cb => cb.assign(garbage, Region.loadInt(storageType.fieldOffset(src, 5))) cb.assign(maxGarbage, Region.loadInt(storageType.fieldOffset(src, 6))) - } - ) + }) } def copyFrom(cb: EmitCodeBuilder, src: Value[Long]): Unit = { - maybeGCCode(cb, + maybeGCCode( + cb, { cb => initStaging(cb) ab.copyFrom(cb, storageType.fieldOffset(src, 0)) cb.assign(maxIndex, Region.loadLong(storageType.fieldOffset(src, 3))) cb.assign(maxSize, Region.loadInt(storageType.fieldOffset(src, 4))) - })({ cb => - cb.assign(maxGarbage, Region.loadInt(storageType.fieldOffset(src, 4))) - }) + }, + )({ cb => cb.assign(maxGarbage, Region.loadInt(storageType.fieldOffset(src, 4))) }) } def serialize(codec: BufferSpec): (EmitCodeBuilder, Value[OutputBuffer]) => Unit = { - { (cb: EmitCodeBuilder, ob: Value[OutputBuffer]) => - maybeGCCode(cb, + (cb: EmitCodeBuilder, ob: Value[OutputBuffer]) => + maybeGCCode( + cb, { cb => cb += ob.writeLong(maxIndex) cb += ob.writeInt(maxSize) ab.serialize(codec)(cb, ob) cb += ob.writeInt(const(TakeByRVAS.END_SERIALIZATION)) - } - )({ cb => - cb += ob.writeInt(maxGarbage) - }, runBefore = true) - } + }, + )(cb => cb += ob.writeInt(maxGarbage), runBefore = true) } def deserialize(codec: BufferSpec): (EmitCodeBuilder, Value[InputBuffer]) => Unit = { - { (cb: EmitCodeBuilder, ib: Value[InputBuffer]) => - maybeGCCode(cb, + (cb: EmitCodeBuilder, ib: Value[InputBuffer]) => + maybeGCCode( + cb, { cb => cb.assign(maxIndex, ib.readLong()) cb.assign(maxSize, ib.readInt()) ab.deserialize(codec)(cb, ib) initStaging(cb) - cb.if_(ib.readInt() cne TakeByRVAS.END_SERIALIZATION, - cb._fatal(s"StagedSizedKeyValuePriorityQueue serialization failed") + cb.if_( + ib.readInt() cne TakeByRVAS.END_SERIALIZATION, + cb._fatal(s"StagedSizedKeyValuePriorityQueue serialization failed"), ) - } - )({ cb => - cb.assign(maxGarbage, ib.readInt()) - cb.assign(garbage, 0) - }, runBefore = true) - } + }, + )( + { cb => + cb.assign(maxGarbage, ib.readInt()) + cb.assign(garbage, 0) + }, + runBefore = true, + ) } private def elementOffset(cb: EmitCodeBuilder, i: Value[Int]): Value[Long] = @@ -212,11 +258,16 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi val i = mb.getCodeParam[Long](1) val j = mb.getCodeParam[Long](2) - mb.emitWithBuilder(cb => compareIndexedKey(cb, - indexedKeyType.loadCheapSCode(cb, eltTuple.fieldOffset(i, 0)), - indexedKeyType.loadCheapSCode(cb, eltTuple.fieldOffset(j, 0)))) + mb.emitWithBuilder(cb => + compareIndexedKey( + cb, + indexedKeyType.loadCheapSCode(cb, eltTuple.fieldOffset(i, 0)), + indexedKeyType.loadCheapSCode(cb, eltTuple.fieldOffset(j, 0)), + ) + ) - mb.invokeCode(_, _, _) + (cb: EmitCodeBuilder, i: Value[Long], j: Value[Long]) => + cb.invokeCode(mb, cb.this_, i, j) } private val swap: (EmitCodeBuilder, Value[Long], Value[Long]) => Unit = { @@ -230,28 +281,30 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi cb += Region.copyFrom(staging, j, eltTuple.byteSize) }) - (cb: EmitCodeBuilder, x: Value[Long], y: Value[Long]) => cb.invokeVoid(mb, x, y) + (cb: EmitCodeBuilder, x: Value[Long], y: Value[Long]) => cb.invokeVoid(mb, cb.this_, x, y) } - private val rebalanceUp: (EmitCodeBuilder, Value[Int]) => Unit = { val mb = kb.genEmitMethod("rebalance_up", FastSeq[ParamType](IntInfo), UnitInfo) val idx = mb.getCodeParam[Int](1) mb.voidWithBuilder { cb => - cb.if_(idx > 0, - { + cb.if_( + idx > 0, { val parent = cb.memoize((idx + 1) / 2 - 1) val ii = elementOffset(cb, idx) val jj = 
elementOffset(cb, parent) - cb.if_(compareElt(cb, ii, jj) > 0, { - swap(cb, ii, jj) - cb.invokeVoid(mb, parent) - }) - }) + cb.if_( + compareElt(cb, ii, jj) > 0, { + swap(cb, ii, jj) + cb.invokeVoid(mb, cb.this_, parent) + }, + ) + }, + ) } - (cb: EmitCodeBuilder, x: Value[Int]) => cb.invokeVoid(mb, x) + (cb: EmitCodeBuilder, x: Value[Int]) => cb.invokeVoid(mb, cb.this_, x) } private val rebalanceDown: (EmitCodeBuilder, Value[Int]) => Unit = { @@ -267,27 +320,29 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi mb.voidWithBuilder { cb => cb.assign(child1, (idx + 1) * 2 - 1) cb.assign(child2, child1 + 1) - cb.if_(child1 < ab.size, - { - cb.if_(child2 >= ab.size, { - cb.assign(minChild, child1) - }, { - cb.if_(compareElt(cb, elementOffset(cb, child1), elementOffset(cb, child2)) > 0, { - cb.assign(minChild, child1) - }, { - cb.assign(minChild, child2) - }) - }) + cb.if_( + child1 < ab.size, { + cb.if_( + child2 >= ab.size, + cb.assign(minChild, child1), + cb.if_( + compareElt(cb, elementOffset(cb, child1), elementOffset(cb, child2)) > 0, + cb.assign(minChild, child1), + cb.assign(minChild, child2), + ), + ) cb.assign(ii, elementOffset(cb, minChild)) cb.assign(jj, elementOffset(cb, idx)) - cb.if_(compareElt(cb, ii, jj) > 0, - { + cb.if_( + compareElt(cb, ii, jj) > 0, { swap(cb, ii, jj) - cb.invokeVoid(mb, minChild) - }) - }) + cb.invokeVoid(mb, cb.this_, minChild) + }, + ) + }, + ) } - (cb: EmitCodeBuilder, x: Value[Int]) => cb.invokeVoid(mb, x) + (cb: EmitCodeBuilder, x: Value[Int]) => cb.invokeVoid(mb, cb.this_, x) } private lazy val gc: EmitCodeBuilder => Unit = { @@ -296,59 +351,79 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi val oldRegion = mb.newLocal[Region]("old_region") mb.voidWithBuilder { cb => cb.assign(garbage, garbage + 1) - cb.if_(garbage >= maxGarbage, - { + cb.if_( + garbage >= maxGarbage, { cb.assign(oldRegion, region) cb.assign(r, Region.stagedCreate(regionSize, kb.pool())) ab.reallocateData(cb) initStaging(cb) cb.assign(garbage, 0) cb += oldRegion.invoke[Unit]("invalidate") - }) + }, + ) } - (cb: EmitCodeBuilder) => cb.invokeVoid(mb) + (cb: EmitCodeBuilder) => cb.invokeVoid(mb, cb.this_) } else (_: EmitCodeBuilder) => () } - private def stageAndIndexKey(cb: EmitCodeBuilder, k: EmitCode): Unit = { k.toI(cb) - .consume(cb, - { - indexedKeyType.setFieldMissing(cb, keyStage, 0) - }, + .consume( + cb, + indexedKeyType.setFieldMissing(cb, keyStage, 0), { sc => indexedKeyType.setFieldPresent(cb, keyStage, 0) - keyType.storeAtAddress(cb, indexedKeyType.fieldOffset(keyStage, 0), region, sc, deepCopy = false) - } + keyType.storeAtAddress( + cb, + indexedKeyType.fieldOffset(keyStage, 0), + region, + sc, + deepCopy = false, + ) + }, ) cb += Region.storeLong(indexedKeyType.fieldOffset(keyStage, 1), maxIndex) cb.assign(maxIndex, maxIndex + 1L) } - private def copyElementToStaging(cb: EmitCodeBuilder, o: Code[Long]): Unit = cb += Region.copyFrom(o, staging, eltTuple.byteSize) + private def copyElementToStaging(cb: EmitCodeBuilder, o: Code[Long]): Unit = + cb += Region.copyFrom(o, staging, eltTuple.byteSize) private def copyToStaging(cb: EmitCodeBuilder, value: EmitCode, indexedKey: Code[Long]): Unit = { cb.if_(staging.ceq(0L), cb += Code._fatal[Unit]("staging is 0")) - indexedKeyType.storeAtAddress(cb, + indexedKeyType.storeAtAddress( + cb, eltTuple.fieldOffset(staging, 0), region, indexedKeyType.loadCheapSCode(cb, indexedKey), - deepCopy = false) + deepCopy = false, + ) value.toI(cb) - .consume(cb, - { - 
eltTuple.setFieldMissing(cb, staging, 1) - }, + .consume( + cb, + eltTuple.setFieldMissing(cb, staging, 1), { v => eltTuple.setFieldPresent(cb, staging, 1) - valueType.storeAtAddress(cb, eltTuple.fieldOffset(staging, 1), region, v, deepCopy = false) - }) + valueType.storeAtAddress( + cb, + eltTuple.fieldOffset(staging, 1), + region, + v, + deepCopy = false, + ) + }, + ) } private def swapStaging(cb: EmitCodeBuilder): Unit = { - eltTuple.storeAtAddress(cb, ab.elementOffset(cb, 0), region, eltTuple.loadCheapSCode(cb, staging), true) + eltTuple.storeAtAddress( + cb, + ab.elementOffset(cb, 0), + region, + eltTuple.loadCheapSCode(cb, staging), + true, + ) rebalanceDown(cb, 0) } @@ -358,87 +433,117 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi } def seqOp(cb: EmitCodeBuilder, v: EmitCode, k: EmitCode): Unit = { - val mb = kb.genEmitMethod("take_by_seqop", + val mb = kb.genEmitMethod( + "take_by_seqop", FastSeq[ParamType](v.emitParamType, k.emitParamType), - UnitInfo) + UnitInfo, + ) mb.voidWithBuilder { cb => val value = mb.getEmitParam(cb, 1) val key = mb.getEmitParam(cb, 2) - cb.if_(maxSize > 0, { - cb.if_(ab.size < maxSize, { - stageAndIndexKey(cb, key) - copyToStaging(cb, value, keyStage) - enqueueStaging(cb) - }, { - cb.assign(tempPtr, eltTuple.loadField(elementOffset(cb, 0), 0)) - cb.if_(compareKey(cb, key, loadKey(cb, tempPtr)) < 0, { - stageAndIndexKey(cb, key) - copyToStaging(cb, value, keyStage) - swapStaging(cb) - gc(cb) - }) - }) - }) + cb.if_( + maxSize > 0, { + cb.if_( + ab.size < maxSize, { + stageAndIndexKey(cb, key) + copyToStaging(cb, value, keyStage) + enqueueStaging(cb) + }, { + cb.assign(tempPtr, eltTuple.loadField(elementOffset(cb, 0), 0)) + cb.if_( + compareKey(cb, key, loadKey(cb, tempPtr)) < 0, { + stageAndIndexKey(cb, key) + copyToStaging(cb, value, keyStage) + swapStaging(cb) + gc(cb) + }, + ) + }, + ) + }, + ) } - cb.invokeVoid(mb, v, k) + cb.invokeVoid(mb, cb.this_, v, k) } // for tests - def seqOp(cb: EmitCodeBuilder, vm: Code[Boolean], v: Value[_], km: Code[Boolean], k: Value[_]): Unit = { - val vec = EmitCode(Code._empty, vm, if (valueType.isPrimitive) primitive(valueType.virtualType, v) else valueType.loadCheapSCode(cb, coerce[Long](v))) - val kec = EmitCode(Code._empty, km, if (keyType.isPrimitive) primitive(keyType.virtualType, k) else keyType.loadCheapSCode(cb, coerce[Long](k))) + def seqOp(cb: EmitCodeBuilder, vm: Code[Boolean], v: Value[_], km: Code[Boolean], k: Value[_]) + : Unit = { + val vec = EmitCode( + Code._empty, + vm, + if (valueType.isPrimitive) primitive(valueType.virtualType, v) + else valueType.loadCheapSCode(cb, coerce[Long](v)), + ) + val kec = EmitCode( + Code._empty, + km, + if (keyType.isPrimitive) primitive(keyType.virtualType, k) + else keyType.loadCheapSCode(cb, coerce[Long](k)), + ) seqOp(cb, vec, kec) } def combine(cb: EmitCodeBuilder, other: TakeByRVAS): Unit = { val mb = kb.genEmitMethod("take_by_combop", FastSeq[ParamType](), UnitInfo) - mb.voidWithBuilder { cb => val i = cb.newLocal[Int]("combine_i") - cb.for_(cb.assign(i, 0), i < other.ab.size, cb.assign(i, i + 1), { - val offset = other.elementOffset(cb, i) - val indexOffset = cb.memoize(indexedKeyType.fieldOffset(eltTuple.loadField(offset, 0), 1)) - cb += Region.storeLong(indexOffset, Region.loadLong(indexOffset) + maxIndex) - cb.if_(maxSize > 0, - cb.if_(ab.size < maxSize, - { - copyElementToStaging(cb, offset) - enqueueStaging(cb) - }, - { - cb.assign(tempPtr, elementOffset(cb, 0)) - cb.if_(compareElt(cb, offset, tempPtr) < 0, - { - 
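The rebalanceUp/rebalanceDown methods and the seqOp above implement a bounded max-heap keyed on (key, insertion index): while fewer than n elements are held, everything is enqueued; afterwards a new element enters only if its key is strictly smaller than the key at the root, which it then replaces. For the default ascending sort order this keeps the n entries with the smallest keys. A compact unstaged model using scala.collection.mutable.PriorityQueue (the TakeBy class is illustrative; missing keys and the off-heap layout are ignored):

```scala
import scala.collection.mutable

// Model of TakeByRVAS: keep the n entries with the smallest keys, using a
// max-heap ordered by (key, insertion index) so ties resolve by arrival order.
final class TakeBy[V, K](n: Int)(implicit ord: Ordering[K]) {
  private val heapOrd: Ordering[(K, Long, V)] =
    Ordering.by[(K, Long, V), (K, Long)](t => (t._1, t._2))
  private val heap = mutable.PriorityQueue.empty[(K, Long, V)](heapOrd) // largest key on top
  private var index = 0L

  def add(value: V, key: K): Unit =
    if (n > 0) {
      if (heap.size < n) {
        heap.enqueue((key, index, value)); index += 1
      } else if (ord.lt(key, heap.head._1)) { // strictly smaller than the current largest key
        heap.dequeue()
        heap.enqueue((key, index, value)); index += 1
      }
    }

  // Sorted by (key, insertion index), like the index quicksort in TakeByRVAS.result.
  def result(): IndexedSeq[V] =
    heap.toIndexedSeq.sortBy(t => (t._1, t._2)).map(_._3)
}
```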
copyElementToStaging(cb, offset) - swapStaging(cb) - gc(cb) - }) - } - )) - }) + cb.for_( + cb.assign(i, 0), + i < other.ab.size, + cb.assign(i, i + 1), { + val offset = other.elementOffset(cb, i) + val indexOffset = cb.memoize(indexedKeyType.fieldOffset(eltTuple.loadField(offset, 0), 1)) + cb += Region.storeLong(indexOffset, Region.loadLong(indexOffset) + maxIndex) + cb.if_( + maxSize > 0, + cb.if_( + ab.size < maxSize, { + copyElementToStaging(cb, offset) + enqueueStaging(cb) + }, { + cb.assign(tempPtr, elementOffset(cb, 0)) + cb.if_( + compareElt(cb, offset, tempPtr) < 0, { + copyElementToStaging(cb, offset) + swapStaging(cb) + gc(cb) + }, + ) + }, + ), + ) + }, + ) cb.assign(maxIndex, maxIndex + other.maxIndex) } - cb.invokeVoid(mb) + cb.invokeVoid(mb, cb.this_) } - def result(cb: EmitCodeBuilder, _r: Value[Region], resultType: PCanonicalArray): SIndexablePointerValue = { + def result(cb: EmitCodeBuilder, _r: Value[Region], resultType: PCanonicalArray) + : SIndexablePointerValue = { val mb = kb.genEmitMethod("take_by_result", FastSeq[ParamType](classInfo[Region]), LongInfo) - val quickSort: (EmitCodeBuilder, Value[Long], Value[Int], Value[Int]) => Value[Unit] = { - val mb = kb.genEmitMethod("result_quicksort", FastSeq[ParamType](LongInfo, IntInfo, IntInfo), UnitInfo) + val quickSort: (EmitCodeBuilder, Value[Long], Value[Int], Value[Int]) => Unit = { + val mb = kb.genEmitMethod( + "result_quicksort", + FastSeq[ParamType](LongInfo, IntInfo, IntInfo), + UnitInfo, + ) val indices = mb.getCodeParam[Long](1) val low = mb.getCodeParam[Int](2) val high = mb.getCodeParam[Int](3) val pivotIndex = mb.newLocal[Int]("pivotIdx") - val swap: (EmitCodeBuilder, Value[Long], Value[Long]) => Value[Unit] = { - val mb = kb.genEmitMethod("quicksort_swap", FastSeq[ParamType](LongInfo, LongInfo), UnitInfo) + val swap: (EmitCodeBuilder, Value[Long], Value[Long]) => Unit = { + val mb = + kb.genEmitMethod("quicksort_swap", FastSeq[ParamType](LongInfo, LongInfo), UnitInfo) val i = mb.getCodeParam[Long](1) val j = mb.getCodeParam[Long](2) @@ -448,14 +553,18 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi Code( tmp := Region.loadInt(i), Region.storeInt(i, Region.loadInt(j)), - Region.storeInt(j, tmp) + Region.storeInt(j, tmp), ) ) - mb.invokeCode(_, _, _) + (cb, i, j) => cb.invokeVoid(mb, cb.this_, i, j) } val partition: (EmitCodeBuilder, Value[Long], Value[Int], Value[Int]) => Value[Int] = { - val mb = kb.genEmitMethod("quicksort_partition", FastSeq[ParamType](LongInfo, IntInfo, IntInfo), IntInfo) + val mb = kb.genEmitMethod( + "quicksort_partition", + FastSeq[ParamType](LongInfo, IntInfo, IntInfo), + IntInfo, + ) val indices = mb.getCodeParam[Long](1) val low = mb.getCodeParam[Int](2) @@ -479,57 +588,68 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi cb.loop { Lrecur => cb.loop { Linner => cb.assign(tmpOffset, elementOffset(cb, indexAt(cb, low))) - cb.if_(compareElt(cb, tmpOffset, pivotOffset) < 0, { - cb.assign(low, low + 1) - cb.goto(Linner) - }) + cb.if_( + compareElt(cb, tmpOffset, pivotOffset) < 0, { + cb.assign(low, low + 1) + cb.goto(Linner) + }, + ) } cb.loop { Linner => cb.assign(tmpOffset, elementOffset(cb, indexAt(cb, high))) - cb.if_(compareElt(cb, tmpOffset, pivotOffset) > 0, { - cb.assign(high, high - 1) - cb.goto(Linner) - }) + cb.if_( + compareElt(cb, tmpOffset, pivotOffset) > 0, { + cb.assign(high, high - 1) + cb.goto(Linner) + }, + ) } - cb.if_(high > low, { - swap(cb, indexOffset(cb, low), indexOffset(cb, high)) - 
cb.assign(low, low + 1) - cb.assign(high, high - 1) - cb.goto(Lrecur) - }) + cb.if_( + high > low, { + swap(cb, indexOffset(cb, low), indexOffset(cb, high)) + cb.assign(low, low + 1) + cb.assign(high, high - 1) + cb.goto(Lrecur) + }, + ) } high } - mb.invokeCode(_, _, _, _) + (cb, indices, lo, hi) => cb.invokeCode(mb, cb.this_, indices, lo, hi) } mb.voidWithBuilder { cb => - cb.if_(low < high, { - cb.assign(pivotIndex, partition(cb, indices, low, high)) - cb.invokeVoid(mb, indices, low, pivotIndex) - cb.invokeVoid(mb, indices, cb.memoize(pivotIndex + 1), high) - }) + cb.if_( + low < high, { + cb.assign(pivotIndex, partition(cb, indices, low, high)) + cb.invokeVoid(mb, cb.this_, indices, low, pivotIndex) + cb.invokeVoid(mb, cb.this_, indices, cb.memoize(pivotIndex + 1), high) + }, + ) } - mb.invokeCode(_, _, _, _) + + (cb, indices, lo, hi) => cb.invokeVoid(mb, cb.this_, indices, lo, hi) } mb.emitWithBuilder[Long] { cb => val r = mb.getCodeParam[Region](1) - val indicesToSort = cb.newLocal[Long]("indices_to_sort", - r.load().allocate(4L, ab.size.toL * 4L)) + val indicesToSort = + cb.newLocal[Long]("indices_to_sort", r.load().allocate(4L, ab.size.toL * 4L)) val i = cb.newLocal[Int]("i", 0) def indexOffset(idx: Code[Int]): Code[Long] = indicesToSort + idx.toL * 4L - cb.while_(i < ab.size, { - cb += Region.storeInt(indexOffset(i), i) - cb.assign(i, i + 1) - }) + cb.while_( + i < ab.size, { + cb += Region.storeInt(indexOffset(i), i) + cb.assign(i, i + 1) + }, + ) quickSort(cb, indicesToSort, 0, cb.memoize(ab.size - 1)) @@ -541,15 +661,18 @@ class TakeByRVAS(val valueVType: VirtualTypeWithReq, val keyVType: VirtualTypeWi } }.a } - resultType.loadCheapSCode(cb, cb.invokeCode[Long](mb, _r)) + resultType.loadCheapSCode(cb, cb.invokeCode[Long](mb, cb.this_, _r)) } } -class TakeByAggregator(valueType: VirtualTypeWithReq, keyType: VirtualTypeWithReq) extends StagedAggregator { +class TakeByAggregator(valueType: VirtualTypeWithReq, keyType: VirtualTypeWithReq) + extends StagedAggregator { type State = TakeByRVAS - val resultEmitType: EmitType = EmitType(SIndexablePointer(PCanonicalArray(valueType.canonicalPType)), true) + val resultEmitType: EmitType = + EmitType(SIndexablePointer(PCanonicalArray(valueType.canonicalPType)), true) + val initOpTypes: Seq[Type] = Array(TInt32) val seqOpTypes: Seq[Type] = Array(valueType.t, keyType.t) @@ -557,9 +680,11 @@ class TakeByAggregator(valueType: VirtualTypeWithReq, keyType: VirtualTypeWithRe assert(init.length == 1) val Array(sizeTriplet) = init sizeTriplet.toI(cb) - .consume(cb, + .consume( + cb, cb += Code._fatal[Unit](s"argument 'n' for 'hl.agg.take' may not be missing"), - sc => state.initialize(cb, sc.asInt.value)) + sc => state.initialize(cb, sc.asInt.value), + ) } protected def _seqOp(cb: EmitCodeBuilder, state: State, seq: Array[EmitCode]): Unit = { @@ -567,11 +692,18 @@ class TakeByAggregator(valueType: VirtualTypeWithReq, keyType: VirtualTypeWithRe state.seqOp(cb, value, key) } - protected def _combOp(ctx: ExecuteContext, cb: EmitCodeBuilder, state: TakeByRVAS, other: TakeByRVAS): Unit = state.combine(cb, other) - + protected def _combOp( + ctx: ExecuteContext, + cb: EmitCodeBuilder, + region: Value[Region], + state: TakeByRVAS, + other: TakeByRVAS, + ): Unit = state.combine(cb, other) - protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = { + protected def _result(cb: EmitCodeBuilder, state: State, region: Value[Region]): IEmitCode = // state.result does a deep copy - IEmitCode.present(cb, state.result(cb, region, 
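TakeByRVAS.result sorts the collected entries without moving them: it fills an off-heap array of Int positions, quicksorts that index array with a Hoare-style partition (pivoting on the element currently at the low end of the range), and then emits the output in index order. A self-contained sketch of the same idea over an ordinary Array[Int] (the sortIndices name and the comparator signature are illustrative):

```scala
// Sort an index array by a comparator over the pointed-to entries, mirroring
// the quicksort that TakeByRVAS.result generates over off-heap ints.
def sortIndices(indices: Array[Int], compare: (Int, Int) => Int): Unit = {
  def swap(i: Int, j: Int): Unit = {
    val t = indices(i); indices(i) = indices(j); indices(j) = t
  }

  // Hoare-style partition around the element currently at position `low0`.
  def partition(low0: Int, high0: Int): Int = {
    var low = low0
    var high = high0
    val pivot = indices(low0)
    var done = false
    while (!done) {
      while (compare(indices(low), pivot) < 0) low += 1
      while (compare(indices(high), pivot) > 0) high -= 1
      if (high > low) { swap(low, high); low += 1; high -= 1 }
      else done = true
    }
    high
  }

  def qsort(low: Int, high: Int): Unit =
    if (low < high) {
      val p = partition(low, high)
      qsort(low, p)
      qsort(p + 1, high)
    }

  qsort(0, indices.length - 1)
}
```

For example, given `val idx = Array.range(0, keys.length)`, calling `sortIndices(idx, (i, j) => keys(i) compare keys(j))` orders idx by keys without touching keys itself, which is the permutation the staged result loop then walks.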
resultEmitType.storageType.asInstanceOf[PCanonicalArray])) - } + IEmitCode.present( + cb, + state.result(cb, region, resultEmitType.storageType.asInstanceOf[PCanonicalArray]), + ) } diff --git a/hail/src/main/scala/is/hail/expr/ir/analyses/ComputeMethodSplits.scala b/hail/src/main/scala/is/hail/expr/ir/analyses/ComputeMethodSplits.scala index 5e26262d44d..223824354a8 100644 --- a/hail/src/main/scala/is/hail/expr/ir/analyses/ComputeMethodSplits.scala +++ b/hail/src/main/scala/is/hail/expr/ir/analyses/ComputeMethodSplits.scala @@ -1,6 +1,5 @@ package is.hail.expr.ir.analyses -import is.hail.HailContext import is.hail.backend.ExecuteContext import is.hail.expr.ir._ @@ -12,7 +11,9 @@ object ComputeMethodSplits { require(splitThreshold > 0, s"invalid method_split_ir_limit") def recurAndComputeSizeUnderneath(x: IR): Int = { - val sizeUnderneath = x.children.iterator.map { case child: IR => recurAndComputeSizeUnderneath(child) }.sum + val sizeUnderneath = x.children.iterator.map { case child: IR => + recurAndComputeSizeUnderneath(child) + }.sum val shouldSplit = !controlFlowPreventsSplit.contains(x) && (x match { case _: TailLoop => true diff --git a/hail/src/main/scala/is/hail/expr/ir/analyses/ControlFlowPreventsSplit.scala b/hail/src/main/scala/is/hail/expr/ir/analyses/ControlFlowPreventsSplit.scala index 6796e4fef72..6d06a4191e2 100644 --- a/hail/src/main/scala/is/hail/expr/ir/analyses/ControlFlowPreventsSplit.scala +++ b/hail/src/main/scala/is/hail/expr/ir/analyses/ControlFlowPreventsSplit.scala @@ -8,17 +8,19 @@ object ControlFlowPreventsSplit { def apply(x: BaseIR, parentPointers: Memo[BaseIR], usesAndDefs: UsesAndDefs): Memo[Unit] = { val m = Memo.empty[Unit] VisitIR(x) { - case r@Recur(name, _, _) => + case r @ Recur(name, _, _) => var parent: BaseIR = r - while (parent match { - case TailLoop(`name`, _, _, _) => false - case _ => true - }) { + while ( + parent match { + case TailLoop(`name`, _, _, _) => false + case _ => true + } + ) { if (!m.contains(parent)) m.bind(parent, ()) parent = parentPointers.lookup(parent) } - case r@Ref(name, t) if t.isInstanceOf[TStream] => + case r @ Ref(_, t) if t.isInstanceOf[TStream] => val declaration = usesAndDefs.defs.lookup(r) var parent: BaseIR = r while (!(parent.eq(declaration))) { diff --git a/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala b/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala index 9e66752dd19..f060c9d6eee 100644 --- a/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala +++ b/hail/src/main/scala/is/hail/expr/ir/analyses/SemanticHash.scala @@ -1,21 +1,22 @@ package is.hail.expr.ir.analyses import is.hail.backend.ExecuteContext -import is.hail.expr.ir.functions.{TableCalculateNewPartitions, WrappedMatrixToValueFunction} import is.hail.expr.ir.{MatrixRangeReader, _} +import is.hail.expr.ir.functions.{TableCalculateNewPartitions, WrappedMatrixToValueFunction} import is.hail.io.fs.FS import is.hail.io.vcf.MatrixVCFReader import is.hail.methods._ import is.hail.types.virtual._ -import is.hail.utils.{Logging, TreeTraversal, toRichBoolean} -import org.apache.commons.codec.digest.MurmurHash3 +import is.hail.utils.{toRichBoolean, Logging, TreeTraversal} -import java.io.FileNotFoundException -import java.nio.ByteBuffer import scala.collection.mutable -import scala.language.implicitConversions import scala.util.control.NonFatal +import java.io.FileNotFoundException +import java.nio.ByteBuffer + +import org.apache.commons.codec.digest.MurmurHash3 + case object SemanticHash extends Logging { type Type = 
Int @@ -23,7 +24,6 @@ case object SemanticHash extends Logging { def extend(x: Type, bytes: Array[Byte]): Type = MurmurHash3.hash32x86(bytes, 0, bytes.length, x) - def apply(ctx: ExecuteContext)(root: BaseIR): Option[Type] = ctx.timer.time("SemanticHash") { @@ -45,7 +45,7 @@ case object SemanticHash extends Logging { val bytes = encode(ctx.fs, ir, index) hash = extend(hash, bytes) } catch { - case error@(_: UnsupportedOperationException | _: FileNotFoundException) => + case error @ (_: UnsupportedOperationException | _: FileNotFoundException) => log.info(error) return None @@ -56,7 +56,7 @@ case object SemanticHash extends Logging { |INCLUDING THE STACK TRACE AT THE END OF THIS MESSAGE. |https://github.com/hail-is/hail/issues/new/choose |""".stripMargin, - error + error, ) return None } @@ -72,7 +72,6 @@ case object SemanticHash extends Logging { semhash } - private def encode(fs: FS, ir: BaseIR, index: Int): Array[Byte] = { val buffer: mutable.ArrayBuilder[Byte] = Array.newBuilder[Byte] ++= @@ -92,8 +91,8 @@ case object SemanticHash extends Logging { case a: AggGroupBy => buffer += a.isScan.toByte - case a: AggLet => - buffer += a.isScan.toByte + case Block(bindings, _) => + for (b <- bindings) buffer += b.scope.toByte case a: AggArrayPerElement => buffer += a.isScan.toByte @@ -156,7 +155,9 @@ case object SemanticHash extends Logging { buffer ++= getFileHash(fs)(path) case _ => - throw new UnsupportedOperationException(s"SemanticHash unknown: ${reader.getClass.getName}") + throw new UnsupportedOperationException( + s"SemanticHash unknown: ${reader.getClass.getName}" + ) } case Cast(_, typ) => @@ -195,7 +196,6 @@ case object SemanticHash extends Logging { params.nPartitions.fold(Array.empty[Byte])(Bytes.fromInt) ++= Bytes.fromInt(nPartitionsAdj) - case _: MatrixNativeReader => reader .pathsUsed @@ -203,11 +203,11 @@ case object SemanticHash extends Logging { .foreach(g => buffer ++= getFileHash(fs)(g.getPath)) case _: MatrixVCFReader => - reader.pathsUsed.foreach { path => - buffer ++= getFileHash(fs)(path) - } + reader.pathsUsed.foreach(path => buffer ++= getFileHash(fs)(path)) case _ => - throw new UnsupportedOperationException(s"SemanticHash unknown: ${reader.getClass.getName}") + throw new UnsupportedOperationException( + s"SemanticHash unknown: ${reader.getClass.getName}" + ) } case MatrixWrite(_, writer) => @@ -237,7 +237,6 @@ case object SemanticHash extends Logging { val getFieldIndex = table.typ.rowType.fieldIdx keys.map(getFieldIndex).foreach(buffer ++= Bytes.fromInt(_)) - case TableKeyByAndAggregate(_, _, _, nPartitions, bufferSize) => nPartitions.foreach { buffer ++= Bytes.fromInt(_) @@ -260,14 +259,16 @@ case object SemanticHash extends Logging { case StringTableReader(_, fileStatuses) => fileStatuses.foreach(s => buffer ++= getFileHash(fs)(s.getPath)) - case reader@(_: TableNativeReader | _: TableNativeZippedReader) => + case reader @ (_: TableNativeReader | _: TableNativeZippedReader) => reader .pathsUsed .flatMap(p => fs.glob(p + "/**").filter(_.isFile)) .foreach(g => buffer ++= getFileHash(fs)(g.getPath)) case _ => - throw new UnsupportedOperationException(s"SemanticHash unknown: ${reader.getClass.getName}") + throw new UnsupportedOperationException( + s"SemanticHash unknown: ${reader.getClass.getName}" + ) } case TableToValueApply(_, op) => @@ -282,7 +283,9 @@ case object SemanticHash extends Logging { buffer ++= Bytes.fromClass(op.getClass) case _: MatrixExportEntriesByCol => - throw new UnsupportedOperationException("SemanticHash unknown: MatrixExportEntriesByCol") + 
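SemanticHash folds a per-node byte encoding into a single Int by threading the running hash through MurmurHash3.hash32x86 as the seed, visiting nodes in level order together with their position among their parent's children (that is what the extend helper and levelOrder above provide). A toy version over a generic tree that keeps the same fold shape (the Node case class and the label/child-index encoding are illustrative; the real code encodes IR-specific payloads such as file hashes, paths, and field indices):

```scala
import org.apache.commons.codec.digest.MurmurHash3

// Toy model of the SemanticHash fold: hash = extend(hash, encode(node, childIndex))
// over a breadth-first traversal of the tree.
final case class Node(label: String, children: List[Node] = Nil)

def extend(hash: Int, bytes: Array[Byte]): Int =
  MurmurHash3.hash32x86(bytes, 0, bytes.length, hash)

def semanticHash(root: Node, seed: Int = 0): Int = {
  val queue = scala.collection.mutable.Queue((root, 0)) // (node, index among its parent's children)
  var hash = seed
  while (queue.nonEmpty) {
    val (node, childIndex) = queue.dequeue()
    val encoded = node.label.getBytes("UTF-8") ++
      java.nio.ByteBuffer.allocate(4).putInt(childIndex).array()
    hash = extend(hash, encoded)
    queue ++= node.children.zipWithIndex
  }
  hash
}
```

As in the diff, two trees that differ only in sibling order hash differently because the child index is folded in alongside each node's own encoding.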
throw new UnsupportedOperationException( + "SemanticHash unknown: MatrixExportEntriesByCol" + ) } case _: ForceCountTable | _: NPartitionsTable => @@ -294,76 +297,75 @@ case object SemanticHash extends Logging { // The following are parameterized entirely by the operation's input and the operation itself case _: ArrayLen | - _: ArrayRef | - _: ArraySlice | - _: ArraySort | - _: ArrayZeros | - _: Begin | - _: BlockMatrixCollect | - _: CastToArray | - _: Coalesce | - _: CollectDistributedArray | - _: ConsoleLog | - _: Consume | - _: Die | - _: GroupByKey | - _: If | - _: InsertFields | - _: IsNA | - _: Let | - _: LiftMeOut | - _: MakeArray | - _: MakeNDArray | - _: MakeStream | - _: MakeStruct | - _: MatrixAggregate | - _: MatrixColsTable | - _: MatrixCount | - _: MatrixMapGlobals | - _: MatrixMapRows | - _: MatrixFilterRows | - _: MatrixMapCols | - _: MatrixFilterCols | - _: MatrixMapEntries | - _: MatrixFilterEntries | - _: MatrixDistinctByRow | - _: NDArrayShape | - _: NDArraySlice | - _: NDArrayReshape | - _: NDArrayWrite | - _: RelationalLet | - _: RNGSplit | - _: RNGStateLiteral | - _: StreamAgg | - _: StreamDrop | - _: StreamDropWhile | - _: StreamFilter | - _: StreamFlatMap | - _: StreamFold | - _: StreamFold2 | - _: StreamFor | - _: StreamIota | - _: StreamLen | - _: StreamMap | - _: StreamRange | - _: StreamTake | - _: StreamTakeWhile | - _: Switch | - _: TableGetGlobals | - _: TableAggregate | - _: TableAggregateByKey | - _: TableCollect | - _: TableCount | - _: TableDistinct | - _: TableFilter | - _: TableMapGlobals | - _: TableMapRows | - _: TableRename | - _: ToArray | - _: ToDict | - _: ToSet | - _: ToStream | - _: Trap => + _: ArrayRef | + _: ArraySlice | + _: ArraySort | + _: ArrayZeros | + _: BlockMatrixCollect | + _: CastToArray | + _: Coalesce | + _: CollectDistributedArray | + _: ConsoleLog | + _: Consume | + _: Die | + _: GroupByKey | + _: If | + _: InsertFields | + _: IsNA | + _: Block | + _: LiftMeOut | + _: MakeArray | + _: MakeNDArray | + _: MakeStream | + _: MakeStruct | + _: MatrixAggregate | + _: MatrixColsTable | + _: MatrixCount | + _: MatrixMapGlobals | + _: MatrixMapRows | + _: MatrixFilterRows | + _: MatrixMapCols | + _: MatrixFilterCols | + _: MatrixMapEntries | + _: MatrixFilterEntries | + _: MatrixDistinctByRow | + _: NDArrayShape | + _: NDArraySlice | + _: NDArrayReshape | + _: NDArrayWrite | + _: RelationalLet | + _: RNGSplit | + _: RNGStateLiteral | + _: StreamAgg | + _: StreamDrop | + _: StreamDropWhile | + _: StreamFilter | + _: StreamFlatMap | + _: StreamFold | + _: StreamFold2 | + _: StreamFor | + _: StreamIota | + _: StreamLen | + _: StreamMap | + _: StreamRange | + _: StreamTake | + _: StreamTakeWhile | + _: Switch | + _: TableGetGlobals | + _: TableAggregate | + _: TableAggregateByKey | + _: TableCollect | + _: TableCount | + _: TableDistinct | + _: TableFilter | + _: TableMapGlobals | + _: TableMapRows | + _: TableRename | + _: ToArray | + _: ToDict | + _: ToSet | + _: ToStream | + _: Trap => () // Discrete values @@ -393,10 +395,9 @@ case object SemanticHash extends Logging { case Some(etag) => etag.getBytes case None => - path.getBytes ++ Bytes.fromLong(fs.fileListEntry(path).getModificationTime) + path.getBytes ++ Bytes.fromLong(fs.fileStatus(path).getModificationTime) } - def levelOrder(root: BaseIR): Iterator[(BaseIR, Int)] = { val adj: ((BaseIR, Int)) => Iterator[(BaseIR, Int)] = Function.tupled((ir, _) => ir.children.zipWithIndex.iterator) diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/ApproxCDFFunctions.scala 
b/hail/src/main/scala/is/hail/expr/ir/functions/ApproxCDFFunctions.scala index 4eff5c01363..657ca2a1a9a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/ApproxCDFFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/ApproxCDFFunctions.scala @@ -1,15 +1,15 @@ package is.hail.expr.ir.functions import is.hail.annotations.Region -import is.hail.asm4s.{Code, Value, valueToCodeObject} +import is.hail.asm4s.{valueToCodeObject, Code, Value} import is.hail.expr.ir.EmitCodeBuilder import is.hail.expr.ir.agg.{ApproxCDFStateManager, QuantilesAggregator} -import is.hail.types.physical._ import is.hail.types.physical.stypes.concrete.SBaseStructPointer import is.hail.types.physical.stypes.interfaces.SBaseStructValue import is.hail.types.physical.stypes.primitives.SInt32Value import is.hail.types.virtual.TInt32 import is.hail.utils.toRichIterable + import org.apache.spark.sql.Row object ApproxCDFFunctions extends RegistryFunctions { @@ -23,27 +23,44 @@ object ApproxCDFFunctions extends RegistryFunctions { ApproxCDFStateManager.fromData(k, levels, items, counts) } - def makeStateManager(cb: EmitCodeBuilder, r: Value[Region], k: Value[Int], state: SBaseStructValue): Value[ApproxCDFStateManager] = { + def makeStateManager( + cb: EmitCodeBuilder, + r: Value[Region], + k: Value[Int], + state: SBaseStructValue, + ): Value[ApproxCDFStateManager] = { val row = svalueToJavaValue(cb, r, state) cb.memoize(Code.invokeScalaObject2[Int, Row, ApproxCDFStateManager]( - ApproxCDFFunctions.getClass, "rowToStateManager", - k, Code.checkcast[Row](row))) + ApproxCDFFunctions.getClass, + "rowToStateManager", + k, + Code.checkcast[Row](row), + )) } - def stateManagerToRow(state: ApproxCDFStateManager): Row = { + def stateManagerToRow(state: ApproxCDFStateManager): Row = Row(state.levels.toFastSeq, state.items.toFastSeq, state.compactionCounts.toFastSeq) - } - def fromStateManager(cb: EmitCodeBuilder, r: Value[Region], state: Value[ApproxCDFStateManager]): SBaseStructValue = { + def fromStateManager(cb: EmitCodeBuilder, r: Value[Region], state: Value[ApproxCDFStateManager]) + : SBaseStructValue = { val row = cb.memoize(Code.invokeScalaObject1[ApproxCDFStateManager, Row]( - ApproxCDFFunctions.getClass, "stateManagerToRow", - state)) + ApproxCDFFunctions.getClass, + "stateManagerToRow", + state, + )) unwrapReturn(cb, r, SBaseStructPointer(statePType), row).asBaseStruct } def registerAll(): Unit = { - registerSCode3("approxCDFCombine", TInt32, stateType, stateType, stateType, (_, _, _, _) => SBaseStructPointer(statePType)) { - case (r, cb, rt, k: SInt32Value, left: SBaseStructValue, right: SBaseStructValue, errorID) => + registerSCode3( + "approxCDFCombine", + TInt32, + stateType, + stateType, + stateType, + (_, _, _, _) => SBaseStructPointer(statePType), + ) { + case (r, cb, _, k: SInt32Value, left: SBaseStructValue, right: SBaseStructValue, _) => val leftState = makeStateManager(cb, r.region, k.value, left) val rightState = makeStateManager(cb, r.region, k.value, right) diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/ArrayFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/ArrayFunctions.scala index 0ebacf8c8b4..e1426548f49 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/ArrayFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/ArrayFunctions.scala @@ -3,11 +3,11 @@ package is.hail.expr.ir.functions import is.hail.asm4s._ import is.hail.expr.ir._ import is.hail.expr.ir.orderings.CodeOrdering +import is.hail.types.physical.{PCanonicalArray, PType} import 
is.hail.types.physical.stypes.EmitType import is.hail.types.physical.stypes.concrete.SIndexablePointer import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives.{SBooleanValue, SFloat64, SInt32, SInt32Value} -import is.hail.types.physical.{PCanonicalArray, PType} import is.hail.types.tcoerce import is.hail.types.virtual._ import is.hail.utils._ @@ -15,19 +15,55 @@ import is.hail.utils._ object ArrayFunctions extends RegistryFunctions { val arrayOps: Array[(String, Type, Type, (IR, IR, Int) => IR)] = Array( - ("mul", tnum("T"), tv("T"), (ir1: IR, ir2: IR, _) =>ApplyBinaryPrimOp(Multiply(), ir1, ir2)), - ("div", TInt32, TFloat32, (ir1: IR, ir2: IR, _) =>ApplyBinaryPrimOp(FloatingPointDivide(), ir1, ir2)), - ("div", TInt64, TFloat32, (ir1: IR, ir2: IR, _) =>ApplyBinaryPrimOp(FloatingPointDivide(), ir1, ir2)), - ("div", TFloat32, TFloat32, (ir1: IR, ir2: IR, _) =>ApplyBinaryPrimOp(FloatingPointDivide(),ir1, ir2)), - ("div", TFloat64, TFloat64, (ir1: IR, ir2: IR, _) =>ApplyBinaryPrimOp(FloatingPointDivide(), ir1, ir2)), - ("floordiv", tnum("T"), tv("T"), (ir1: IR, ir2: IR, _) => - ApplyBinaryPrimOp(RoundToNegInfDivide(), ir1, ir2)), - ("add", tnum("T"), tv("T"), (ir1: IR, ir2: IR, _) =>ApplyBinaryPrimOp(Add(),ir1, ir2)), - ("sub", tnum("T"), tv("T"), (ir1: IR, ir2: IR, _) =>ApplyBinaryPrimOp(Subtract(), ir1, ir2)), - ("pow", tnum("T"), TFloat64, (ir1: IR, ir2: IR, errorID: Int) => - Apply("pow", Seq(), Seq(ir1, ir2), TFloat64, errorID)), - ("mod", tnum("T"), tv("T"), (ir1: IR, ir2: IR, errorID: Int) => - Apply("mod", Seq(), Seq(ir1, ir2), ir2.typ, errorID))) + ("mul", tnum("T"), tv("T"), (ir1: IR, ir2: IR, _) => ApplyBinaryPrimOp(Multiply(), ir1, ir2)), + ( + "div", + TInt32, + TFloat32, + (ir1: IR, ir2: IR, _) => ApplyBinaryPrimOp(FloatingPointDivide(), ir1, ir2), + ), + ( + "div", + TInt64, + TFloat32, + (ir1: IR, ir2: IR, _) => ApplyBinaryPrimOp(FloatingPointDivide(), ir1, ir2), + ), + ( + "div", + TFloat32, + TFloat32, + (ir1: IR, ir2: IR, _) => ApplyBinaryPrimOp(FloatingPointDivide(), ir1, ir2), + ), + ( + "div", + TFloat64, + TFloat64, + (ir1: IR, ir2: IR, _) => ApplyBinaryPrimOp(FloatingPointDivide(), ir1, ir2), + ), + ( + "floordiv", + tnum("T"), + tv("T"), + (ir1: IR, ir2: IR, _) => + ApplyBinaryPrimOp(RoundToNegInfDivide(), ir1, ir2), + ), + ("add", tnum("T"), tv("T"), (ir1: IR, ir2: IR, _) => ApplyBinaryPrimOp(Add(), ir1, ir2)), + ("sub", tnum("T"), tv("T"), (ir1: IR, ir2: IR, _) => ApplyBinaryPrimOp(Subtract(), ir1, ir2)), + ( + "pow", + tnum("T"), + TFloat64, + (ir1: IR, ir2: IR, errorID: Int) => + Apply("pow", Seq(), Seq(ir1, ir2), TFloat64, errorID), + ), + ( + "mod", + tnum("T"), + tv("T"), + (ir1: IR, ir2: IR, errorID: Int) => + Apply("mod", Seq(), Seq(ir1, ir2), ir2.typ, errorID), + ), + ) def mean(args: Seq[IR]): IR = { val Seq(a) = args @@ -40,7 +76,7 @@ object ArrayFunctions extends RegistryFunctions { FastSeq((n, I32(0)), (sum, zero(t))), elt, FastSeq(Ref(n, TInt32) + I32(1), Ref(sum, t) + Ref(elt, t)), - Cast(Ref(sum, t), TFloat64) / Cast(Ref(n, TInt32), TFloat64) + Cast(Ref(sum, t), TFloat64) / Cast(Ref(n, TInt32), TFloat64), ) } @@ -49,14 +85,19 @@ object ArrayFunctions extends RegistryFunctions { def extend(a1: IR, a2: IR): IR = { val uid = genUID() val typ = a1.typ - If(IsNA(a1), + If( + IsNA(a1), NA(typ), - If(IsNA(a2), + If( + IsNA(a2), NA(typ), ToArray(StreamFlatMap( MakeStream(FastSeq(a1, a2), TStream(typ)), uid, - ToStream(Ref(uid, a1.typ)))))) + ToStream(Ref(uid, a1.typ)), + )), + ), + ) } def exists(a: IR, cond: IR => 
IR): IR = { @@ -66,17 +107,20 @@ object ArrayFunctions extends RegistryFunctions { False(), "acc", "elt", - invoke("lor",TBoolean, - Ref("acc", TBoolean), - cond(Ref("elt", t)))) + invoke("lor", TBoolean, Ref("acc", TBoolean), cond(Ref("elt", t))), + ) } - def contains(a: IR, value: IR): IR = { - exists(a, elt => ApplyComparisonOp( - EQWithNA(elt.typ, value.typ), - elt, - value)) - } + def contains(a: IR, value: IR): IR = + exists( + a, + elt => + ApplyComparisonOp( + EQWithNA(elt.typ, value.typ), + elt, + value, + ), + ) def sum(a: IR): IR = { val t = tcoerce[TArray](a.typ).elementType @@ -91,19 +135,27 @@ object ArrayFunctions extends RegistryFunctions { val product = genUID() val v = genUID() val one = Cast(I64(1), t) - StreamFold(ToStream(a), one, product, v, ApplyBinaryPrimOp(Multiply(), Ref(product, t), Ref(v, t))) + StreamFold( + ToStream(a), + one, + product, + v, + ApplyBinaryPrimOp(Multiply(), Ref(product, t), Ref(v, t)), + ) } - def registerAll() { - registerIR1("isEmpty", TArray(tv("T")), TBoolean)((_, a,_) => isEmpty(a)) + def registerAll(): Unit = { + registerIR1("isEmpty", TArray(tv("T")), TBoolean)((_, a, _) => isEmpty(a)) - registerIR2("extend", TArray(tv("T")), TArray(tv("T")), TArray(tv("T")))((_, a, b, _) => extend(a, b)) + registerIR2("extend", TArray(tv("T")), TArray(tv("T")), TArray(tv("T")))((_, a, b, _) => + extend(a, b) + ) registerIR2("append", TArray(tv("T")), tv("T"), TArray(tv("T"))) { (_, a, c, _) => extend(a, MakeArray(FastSeq(c), TArray(c.typ))) } - registerIR2("contains", TArray(tv("T")), tv("T"), TBoolean) { (_, a, e, _) => contains(a, e) } + registerIR2("contains", TArray(tv("T")), tv("T"), TBoolean)((_, a, e, _) => contains(a, e)) for ((stringOp, argType, retType, irOp) <- arrayOps) { registerIR2(stringOp, TArray(argType), argType, TArray(retType)) { (_, a, c, errorID) => @@ -116,41 +168,54 @@ object ArrayFunctions extends RegistryFunctions { ToArray(StreamMap(ToStream(a), i, irOp(c, Ref(i, c.typ), errorID))) } - registerIR2(stringOp, TArray(argType), TArray(argType), TArray(retType)) { (_, array1, array2, errorID) => - val a1id = genUID() - val e1 = Ref(a1id, tcoerce[TArray](array1.typ).elementType) - val a2id = genUID() - val e2 = Ref(a2id, tcoerce[TArray](array2.typ).elementType) - ToArray(StreamZip(FastSeq(ToStream(array1), ToStream(array2)), FastSeq(a1id, a2id), - irOp(e1, e2, errorID), ArrayZipBehavior.AssertSameLength)) + registerIR2(stringOp, TArray(argType), TArray(argType), TArray(retType)) { + (_, array1, array2, errorID) => + val a1id = genUID() + val e1 = Ref(a1id, tcoerce[TArray](array1.typ).elementType) + val a2id = genUID() + val e2 = Ref(a2id, tcoerce[TArray](array2.typ).elementType) + ToArray(StreamZip( + FastSeq(ToStream(array1), ToStream(array2)), + FastSeq(a1id, a2id), + irOp(e1, e2, errorID), + ArrayZipBehavior.AssertSameLength, + )) } } - registerIR1("sum", TArray(tnum("T")), tv("T"))((_, a,_) => sum(a)) + registerIR1("sum", TArray(tnum("T")), tv("T"))((_, a, _) => sum(a)) registerIR1("product", TArray(tnum("T")), tv("T"))((_, a, _) => product(a)) - def makeMinMaxOp(op: String): Seq[IR] => IR = { - { case Seq(a) => - val t = tcoerce[TArray](a.typ).elementType - val value = genUID() - val first = genUID() - val acc = genUID() - StreamFold2(ToStream(a), - FastSeq((acc, NA(t)), (first, True())), - value, - FastSeq( - If(Ref(first, TBoolean), Ref(value, t), invoke(op, t, Ref(acc, t), Ref(value, t))), - False() - ), - Ref(acc, t)) - } + def makeMinMaxOp(op: String): Seq[IR] => IR = { case Seq(a) => + val t = 
tcoerce[TArray](a.typ).elementType + val value = genUID() + val first = genUID() + val acc = genUID() + StreamFold2( + ToStream(a), + FastSeq((acc, NA(t)), (first, True())), + value, + FastSeq( + If(Ref(first, TBoolean), Ref(value, t), invoke(op, t, Ref(acc, t), Ref(value, t))), + False(), + ), + Ref(acc, t), + ) } - registerIR("min", Array(TArray(tnum("T"))), tv("T"), inline = true)((_, a, _) => makeMinMaxOp("min")(a)) - registerIR("nanmin", Array(TArray(tnum("T"))), tv("T"), inline = true)((_, a, _) => makeMinMaxOp("nanmin")(a)) - registerIR("max", Array(TArray(tnum("T"))), tv("T"), inline = true)((_, a, _) => makeMinMaxOp("max")(a)) - registerIR("nanmax", Array(TArray(tnum("T"))), tv("T"), inline = true)((_, a, _) => makeMinMaxOp("nanmax")(a)) + registerIR("min", Array(TArray(tnum("T"))), tv("T"), inline = true)((_, a, _) => + makeMinMaxOp("min")(a) + ) + registerIR("nanmin", Array(TArray(tnum("T"))), tv("T"), inline = true)((_, a, _) => + makeMinMaxOp("nanmin")(a) + ) + registerIR("max", Array(TArray(tnum("T"))), tv("T"), inline = true)((_, a, _) => + makeMinMaxOp("max")(a) + ) + registerIR("nanmax", Array(TArray(tnum("T"))), tv("T"), inline = true)((_, a, _) => + makeMinMaxOp("nanmax")(a) + ) registerIR("mean", Array(TArray(tnum("T"))), TFloat64, inline = true)((_, a, _) => mean(a)) @@ -166,15 +231,23 @@ object ArrayFunctions extends RegistryFunctions { Let( FastSeq(a.name -> ArraySort(StreamFilter(ToStream(array), v.name, !IsNA(v)))), - If(IsNA(a), + If( + IsNA(a), NA(t), Let( FastSeq(size.name -> ArrayLen(a)), - If(size.ceq(0), + If( + size.ceq(0), NA(t), - If(invoke("mod", TInt32, size, 2).cne(0), + If( + invoke("mod", TInt32, size, 2).cne(0), ref(midIdx), // odd number of non-missing elements - div(ref(midIdx) + ref(midIdx + 1), Cast(2, t))))))) + div(ref(midIdx) + ref(midIdx + 1), Cast(2, t)), + ), + ), + ), + ), + ) } def argF(a: IR, op: (Type) => ComparisonOp[Boolean], errorID: Int): IR = { @@ -188,24 +261,34 @@ object ArrayFunctions extends RegistryFunctions { def updateAccum(min: IR, midx: IR): IR = MakeStruct(FastSeq("m" -> min, "midx" -> midx)) - GetField(StreamFold( - StreamRange(I32(0), ArrayLen(a), I32(1)), - NA(tAccum), - accum, - idx, - Let( - FastSeq( - value -> ArrayRef(a, Ref(idx, TInt32), errorID), - m -> GetField(Ref(accum, tAccum), "m"), - ), - If(IsNA(Ref(value, t)), - Ref(accum, tAccum), - If(IsNA(Ref(m, t)), - updateAccum(Ref(value, t), Ref(idx, TInt32)), - If(ApplyComparisonOp(op(t), Ref(value, t), Ref(m, t)), + GetField( + StreamFold( + StreamRange(I32(0), ArrayLen(a), I32(1)), + NA(tAccum), + accum, + idx, + Let( + FastSeq( + value -> ArrayRef(a, Ref(idx, TInt32), errorID), + m -> GetField(Ref(accum, tAccum), "m"), + ), + If( + IsNA(Ref(value, t)), + Ref(accum, tAccum), + If( + IsNA(Ref(m, t)), updateAccum(Ref(value, t), Ref(idx, TInt32)), - Ref(accum, tAccum))))) - ), "midx") + If( + ApplyComparisonOp(op(t), Ref(value, t), Ref(m, t)), + updateAccum(Ref(value, t), Ref(idx, TInt32)), + Ref(accum, tAccum), + ), + ), + ), + ), + ), + "midx", + ) } registerIR1("argmin", TArray(tv("T")), TInt32)((_, a, errorID) => argF(a, LT(_), errorID)) @@ -233,38 +316,56 @@ object ArrayFunctions extends RegistryFunctions { Let( FastSeq( value -> ArrayRef(a, Ref(idx, TInt32), errorID), - m -> GetField(Ref(accum, tAccum), "m") + m -> GetField(Ref(accum, tAccum), "m"), ), - If(IsNA(Ref(value, t)), + If( + IsNA(Ref(value, t)), Ref(accum, tAccum), - If(IsNA(Ref(m, t)), + If( + IsNA(Ref(m, t)), updateAccum(Ref(value, t), Ref(idx, TInt32), I32(1)), - If(ApplyComparisonOp(op(t), 
Ref(value, t), Ref(m, t)), + If( + ApplyComparisonOp(op(t), Ref(value, t), Ref(m, t)), updateAccum(Ref(value, t), Ref(idx, TInt32), I32(1)), - If(ApplyComparisonOp(EQ(t), Ref(value, t), Ref(m, t)), + If( + ApplyComparisonOp(EQ(t), Ref(value, t), Ref(m, t)), updateAccum( Ref(value, t), Ref(idx, TInt32), - ApplyBinaryPrimOp(Add(), GetField(Ref(accum, tAccum), "count"), I32(1))), - Ref(accum, tAccum)))))) + ApplyBinaryPrimOp(Add(), GetField(Ref(accum, tAccum), "count"), I32(1)), + ), + Ref(accum, tAccum), + ), + ), + ), + ), + ), ) - Let(FastSeq(result -> fold), - If(ApplyComparisonOp(EQ(TInt32), GetField(Ref(result, tAccum), "count"), I32(1)), + Let( + FastSeq(result -> fold), + If( + ApplyComparisonOp(EQ(TInt32), GetField(Ref(result, tAccum), "count"), I32(1)), GetField(Ref(result, tAccum), "midx"), - NA(TInt32))) + NA(TInt32), + ), + ) } - registerIR1("uniqueMinIndex", TArray(tv("T")), TInt32)((_, a, errorID) => uniqueIndex(a, LT(_), errorID)) + registerIR1("uniqueMinIndex", TArray(tv("T")), TInt32)((_, a, errorID) => + uniqueIndex(a, LT(_), errorID) + ) - registerIR1("uniqueMaxIndex", TArray(tv("T")), TInt32)((_, a, errorID) => uniqueIndex(a, GT(_), errorID)) + registerIR1("uniqueMaxIndex", TArray(tv("T")), TInt32)((_, a, errorID) => + uniqueIndex(a, GT(_), errorID) + ) registerIR2("indexArray", TArray(tv("T")), TInt32, tv("T")) { (_, a, i, errorID) => ArrayRef( a, - If(ApplyComparisonOp(LT(TInt32), i, I32(0)), - ApplyBinaryPrimOp(Add(), ArrayLen(a), i), - i), errorID) + If(ApplyComparisonOp(LT(TInt32), i, I32(0)), ApplyBinaryPrimOp(Add(), ArrayLen(a), i), i), + errorID, + ) } registerIR1("flatten", TArray(TArray(tv("T"))), TArray(tv("T"))) { (_, a, _) => @@ -272,162 +373,333 @@ object ArrayFunctions extends RegistryFunctions { ToArray(StreamFlatMap(ToStream(a), elt.name, ToStream(elt))) } - registerSCode4("lowerBound", TArray(tv("T")), tv("T"), TInt32, TInt32, TInt32, { - (_, _, _, _, _) => SInt32 - }) { case (r, cb, rt, array, key, begin, end, _) => - val lt = cb.emb.ecb.getOrderingFunction(key.st, array.asIndexable.st.elementType, CodeOrdering.Lt()) - primitive(BinarySearch.lowerBound(cb, array.asIndexable, { elt => - lt(cb, cb.memoize(elt), EmitValue.present(key)) - }, begin.asInt.value, end.asInt.value)) + registerSCode4( + "lowerBound", + TArray(tv("T")), + tv("T"), + TInt32, + TInt32, + TInt32, + (_, _, _, _, _) => SInt32, + ) { case (_, cb, _, array, key, begin, end, _) => + val lt = + cb.emb.ecb.getOrderingFunction(key.st, array.asIndexable.st.elementType, CodeOrdering.Lt()) + primitive(BinarySearch.lowerBound( + cb, + array.asIndexable, + elt => lt(cb, cb.memoize(elt), EmitValue.present(key)), + begin.asInt.value, + end.asInt.value, + )) } - registerIEmitCode2("corr", TArray(TFloat64), TArray(TFloat64), TFloat64, { - (_: Type, _: EmitType, _: EmitType) => EmitType(SFloat64, false) - }) { case (cb, r, rt, errorID, ec1, ec2) => + registerIEmitCode2( + "corr", + TArray(TFloat64), + TArray(TFloat64), + TFloat64, + (_: Type, _: EmitType, _: EmitType) => EmitType(SFloat64, false), + ) { case (cb, _, _, errorID, ec1, ec2) => ec1.toI(cb).flatMap(cb) { case pv1: SIndexableValue => ec2.toI(cb).flatMap(cb) { case pv2: SIndexableValue => val l1 = cb.newLocal("len1", pv1.loadLength()) val l2 = cb.newLocal("len2", pv2.loadLength()) - cb.if_(l1.cne(l2), { - cb._fatalWithError(errorID, + cb.if_( + l1.cne(l2), + cb._fatalWithError( + errorID, "'corr': cannot compute correlation between arrays of different lengths: ", - l1.toS, - ", ", - l2.toS) - }) - IEmitCode(cb, l1.ceq(0), { - val xSum = 
cb.newLocal[Double]("xSum", 0d) - val ySum = cb.newLocal[Double]("ySum", 0d) - val xSqSum = cb.newLocal[Double]("xSqSum", 0d) - val ySqSum = cb.newLocal[Double]("ySqSum", 0d) - val xySum = cb.newLocal[Double]("xySum", 0d) - val i = cb.newLocal[Int]("i") - val n = cb.newLocal[Int]("n", 0) - cb.for_(cb.assign(i, 0), i < l1, cb.assign(i, i + 1), { - pv1.loadElement(cb, i).consume(cb, {}, { xc => - pv2.loadElement(cb, i).consume(cb, {}, { yc => - val x = cb.newLocal[Double]("x", xc.asDouble.value) - val y = cb.newLocal[Double]("y", yc.asDouble.value) - cb.assign(xSum, xSum + x) - cb.assign(xSqSum, xSqSum + x * x) - cb.assign(ySum, ySum + y) - cb.assign(ySqSum, ySqSum + y * y) - cb.assign(xySum, xySum + x * y) - cb.assign(n, n + 1) - }) - }) - }) - val res = cb.memoize((n.toD * xySum - xSum * ySum) / Code.invokeScalaObject1[Double, Double]( - MathFunctions.mathPackageClass, - "sqrt", - (n.toD * xSqSum - xSum * xSum) * (n.toD * ySqSum - ySum * ySum))) - primitive(res) - }) + l1.toS, + ", ", + l2.toS, + ), + ) + IEmitCode( + cb, + l1.ceq(0), { + val xSum = cb.newLocal[Double]("xSum", 0d) + val ySum = cb.newLocal[Double]("ySum", 0d) + val xSqSum = cb.newLocal[Double]("xSqSum", 0d) + val ySqSum = cb.newLocal[Double]("ySqSum", 0d) + val xySum = cb.newLocal[Double]("xySum", 0d) + val i = cb.newLocal[Int]("i") + val n = cb.newLocal[Int]("n", 0) + cb.for_( + cb.assign(i, 0), + i < l1, + cb.assign(i, i + 1), { + pv1.loadElement(cb, i).consume( + cb, + {}, + { xc => + pv2.loadElement(cb, i).consume( + cb, + {}, + { yc => + val x = cb.newLocal[Double]("x", xc.asDouble.value) + val y = cb.newLocal[Double]("y", yc.asDouble.value) + cb.assign(xSum, xSum + x) + cb.assign(xSqSum, xSqSum + x * x) + cb.assign(ySum, ySum + y) + cb.assign(ySqSum, ySqSum + y * y) + cb.assign(xySum, xySum + x * y) + cb.assign(n, n + 1) + }, + ) + }, + ) + }, + ) + val res = + cb.memoize((n.toD * xySum - xSum * ySum) / Code.invokeScalaObject1[Double, Double]( + MathFunctions.mathPackageClass, + "sqrt", + (n.toD * xSqSum - xSum * xSum) * (n.toD * ySqSum - ySum * ySum), + )) + primitive(res) + }, + ) } } } - registerIEmitCode4("local_to_global_g", TArray(TVariable("T")), TArray(TInt32), TInt32, TVariable("T"), TArray(TVariable("T")), - { case (rt, inArrayET, la, n, _) => EmitType(PCanonicalArray(PType.canonical(inArrayET.st.asInstanceOf[SContainer].elementType.storageType())).sType, inArrayET.required && la.required && n.required) })( - { case (cb, region, rt: SIndexablePointer, err, array, localAlleles, nTotalAlleles, fillInValue) => - + registerIEmitCode4( + "local_to_global_g", + TArray(TVariable("T")), + TArray(TInt32), + TInt32, + TVariable("T"), + TArray(TVariable("T")), + { case (_, inArrayET, la, n, _) => + EmitType( + PCanonicalArray( + PType.canonical(inArrayET.st.asInstanceOf[SContainer].elementType.storageType()) + ).sType, + inArrayET.required && la.required && n.required, + ) + }, + ) { + case ( + cb, + region, + rt: SIndexablePointer, + err, + array, + localAlleles, + nTotalAlleles, + fillInValue, + ) => IEmitCode.multiMapEmitCodes(cb, FastSeq(array, localAlleles, nTotalAlleles)) { - case IndexedSeq(array: SIndexableValue, localAlleles: SIndexableValue, _nTotalAlleles: SInt32Value) => + case IndexedSeq( + array: SIndexableValue, + localAlleles: SIndexableValue, + _nTotalAlleles: SInt32Value, + ) => def triangle(x: Value[Int]): Code[Int] = (x * (x + 1)) / 2 - val nTotalAlleles =_nTotalAlleles.value + val nTotalAlleles = _nTotalAlleles.value val nGenotypes = cb.memoize(triangle(nTotalAlleles)) val pt = 
rt.pType.asInstanceOf[PCanonicalArray] - cb.if_(nTotalAlleles < 0, cb._fatalWithError(err, "local_to_global: n_total_alleles less than 0: ", nGenotypes.toS)) + cb.if_( + nTotalAlleles < 0, + cb._fatalWithError( + err, + "local_to_global: n_total_alleles less than 0: ", + nGenotypes.toS, + ), + ) val localLen = array.loadLength() val laLen = localAlleles.loadLength() - cb.if_(localLen cne triangle(laLen), cb._fatalWithError(err, "local_to_global: array should be the triangle number of local alleles: found: ", localLen.toS, " elements, and", laLen.toS, " alleles")) + cb.if_( + localLen cne triangle(laLen), + cb._fatalWithError( + err, + "local_to_global: array should be the triangle number of local alleles: found: ", + localLen.toS, + " elements, and", + laLen.toS, + " alleles", + ), + ) val fillIn = cb.memoize(fillInValue) val (push, finish) = pt.constructFromIndicesUnsafe(cb, region, nGenotypes, false) // fill in if necessary - cb.if_(localLen cne nGenotypes, { - val i = cb.newLocal[Int]("i", 0) - cb.while_(i < nGenotypes, { - push(cb, i, fillIn.toI(cb)) - cb.assign(i, i + 1) - }) - }) - + cb.if_( + localLen cne nGenotypes, { + val i = cb.newLocal[Int]("i", 0) + cb.while_( + i < nGenotypes, { + push(cb, i, fillIn.toI(cb)) + cb.assign(i, i + 1) + }, + ) + }, + ) val i = cb.newLocal[Int]("la_i", 0) val laGIndexer = cb.newLocal[Int]("g_indexer", 0) - cb.while_(i < laLen, { - val lai = localAlleles.loadElement(cb, i).get(cb, "local_to_global: local alleles elements cannot be missing", err).asInt32.value - cb.if_(lai >= nTotalAlleles, cb._fatalWithError(err, "local_to_global: local allele of ", lai.toS, " out of bounds given n_total_alleles of ", nTotalAlleles.toS)) + cb.while_( + i < laLen, { + val lai = localAlleles.loadElement(cb, i).getOrFatal( + cb, + "local_to_global: local alleles elements cannot be missing", + err, + ).asInt32.value + cb.if_( + lai >= nTotalAlleles, + cb._fatalWithError( + err, + "local_to_global: local allele of ", + lai.toS, + " out of bounds given n_total_alleles of ", + nTotalAlleles.toS, + ), + ) + + val j = cb.newLocal[Int]("la_j", 0) + cb.while_( + j <= i, { + val laj = localAlleles.loadElement(cb, j).getOrFatal( + cb, + "local_to_global: local alleles elements cannot be missing", + err, + ).asInt32.value + + val dest = cb.newLocal[Int]("dest") + cb.if_( + lai >= laj, + cb.assign(dest, triangle(lai) + laj), + cb.assign(dest, triangle(laj) + lai), + ) + + push(cb, dest, array.loadElement(cb, laGIndexer)) + cb.assign(laGIndexer, laGIndexer + 1) + cb.assign(j, j + 1) + }, + ) - val j = cb.newLocal[Int]("la_j", 0) - cb.while_(j <= i, { - val laj = localAlleles.loadElement(cb, j).get(cb, "local_to_global: local alleles elements cannot be missing", err).asInt32.value - - val dest = cb.newLocal[Int]("dest") - cb.if_(lai >= laj, { - cb.assign(dest, triangle(lai) + laj) - }, { - cb.assign(dest, triangle(laj) + lai) - }) - - push(cb, dest, array.loadElement(cb, laGIndexer)) - cb.assign(laGIndexer, laGIndexer + 1) - cb.assign(j, j+1) - }) - - cb.assign(i, i+1) - }) + cb.assign(i, i + 1) + }, + ) finish(cb) } - }) - - registerIEmitCode5("local_to_global_a_r", TArray(TVariable("T")), TArray(TInt32), TInt32, TVariable("T"), TBoolean, TArray(TVariable("T")), - {case (rt, inArrayET, la, n, _, omitFirst) => EmitType(PCanonicalArray(PType.canonical(inArrayET.st.asInstanceOf[SContainer].elementType.storageType())).sType, inArrayET.required && la.required && n.required && omitFirst.required)})( - { case (cb, region, rt: SIndexablePointer, err, array, localAlleles, nTotalAlleles, 
fillInValue, omitFirstElement) => + } - IEmitCode.multiMapEmitCodes(cb, FastSeq(array, localAlleles, nTotalAlleles, omitFirstElement)) { - case IndexedSeq(array: SIndexableValue, localAlleles: SIndexableValue, _nTotalAlleles: SInt32Value, omitFirst: SBooleanValue) => + registerIEmitCode5( + "local_to_global_a_r", + TArray(TVariable("T")), + TArray(TInt32), + TInt32, + TVariable("T"), + TBoolean, + TArray(TVariable("T")), + { case (_, inArrayET, la, n, _, omitFirst) => + EmitType( + PCanonicalArray( + PType.canonical(inArrayET.st.asInstanceOf[SContainer].elementType.storageType()) + ).sType, + inArrayET.required && la.required && n.required && omitFirst.required, + ) + }, + ) { + case ( + cb, + region, + rt: SIndexablePointer, + err, + array, + localAlleles, + nTotalAlleles, + fillInValue, + omitFirstElement, + ) => + IEmitCode.multiMapEmitCodes( + cb, + FastSeq(array, localAlleles, nTotalAlleles, omitFirstElement), + ) { + case IndexedSeq( + array: SIndexableValue, + localAlleles: SIndexableValue, + _nTotalAlleles: SInt32Value, + omitFirst: SBooleanValue, + ) => val nTotalAlleles = _nTotalAlleles.value val pt = rt.pType.asInstanceOf[PCanonicalArray] - cb.if_(nTotalAlleles < 0, cb._fatalWithError(err, "local_to_global: n_total_alleles less than 0: ", nTotalAlleles.toS)) + cb.if_( + nTotalAlleles < 0, + cb._fatalWithError( + err, + "local_to_global: n_total_alleles less than 0: ", + nTotalAlleles.toS, + ), + ) val localLen = array.loadLength() - cb.if_(localLen cne localAlleles.loadLength(), cb._fatalWithError(err,"local_to_global: array and local alleles lengths differ: ", localLen.toS, ", ", localAlleles.loadLength().toS)) + cb.if_( + localLen cne localAlleles.loadLength(), + cb._fatalWithError( + err, + "local_to_global: array and local alleles lengths differ: ", + localLen.toS, + ", ", + localAlleles.loadLength().toS, + ), + ) val fillIn = cb.memoize(fillInValue) val idxAdjustmentForOmitFirst = cb.newLocal[Int]("idxAdj") - cb.if_(omitFirst.value, + cb.if_( + omitFirst.value, cb.assign(idxAdjustmentForOmitFirst, 1), - cb.assign(idxAdjustmentForOmitFirst, 0)) + cb.assign(idxAdjustmentForOmitFirst, 0), + ) val globalLen = cb.memoize(nTotalAlleles - idxAdjustmentForOmitFirst) val (push, finish) = pt.constructFromIndicesUnsafe(cb, region, globalLen, false) // fill in if necessary - cb.if_(localLen cne globalLen, { - val i = cb.newLocal[Int]("i", 0) - cb.while_(i < globalLen, { - push(cb, i, fillIn.toI(cb)) - cb.assign(i, i + 1) - }) - }) + cb.if_( + localLen cne globalLen, { + val i = cb.newLocal[Int]("i", 0) + cb.while_( + i < globalLen, { + push(cb, i, fillIn.toI(cb)) + cb.assign(i, i + 1) + }, + ) + }, + ) val i = cb.newLocal[Int]("la_i", 0) - cb.while_(i < localLen, { - val lai = localAlleles.loadElement(cb, i + idxAdjustmentForOmitFirst).get(cb, "local_to_global: local alleles elements cannot be missing", err).asInt32.value - cb.if_(lai >= nTotalAlleles, cb._fatalWithError(err, "local_to_global: local allele of ", lai.toS, " out of bounds given n_total_alleles of ", nTotalAlleles.toS)) - push(cb, cb.memoize(lai - idxAdjustmentForOmitFirst), array.loadElement(cb, i)) + cb.while_( + i < localLen, { + val lai = localAlleles.loadElement(cb, i + idxAdjustmentForOmitFirst).getOrFatal( + cb, + "local_to_global: local alleles elements cannot be missing", + err, + ).asInt32.value + cb.if_( + lai >= nTotalAlleles, + cb._fatalWithError( + err, + "local_to_global: local allele of ", + lai.toS, + " out of bounds given n_total_alleles of ", + nTotalAlleles.toS, + ), + ) + push(cb, cb.memoize(lai - 
idxAdjustmentForOmitFirst), array.loadElement(cb, i)) - cb.assign(i, i + 1) - }) + cb.assign(i, i + 1) + }, + ) finish(cb) } - }) + } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/CallFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/CallFunctions.scala index 71ff0decaf8..427a9176189 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/CallFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/CallFunctions.scala @@ -1,119 +1,218 @@ package is.hail.expr.ir.functions import is.hail.asm4s.Code +import is.hail.types.physical.{PCanonicalArray, PInt32} import is.hail.types.physical.stypes._ import is.hail.types.physical.stypes.concrete.{SCanonicalCall, SIndexablePointer} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives.{SBoolean, SInt32} -import is.hail.types.physical.{PCanonicalArray, PInt32} import is.hail.types.virtual._ import is.hail.variant._ import scala.reflect.classTag object CallFunctions extends RegistryFunctions { - def registerAll() { - registerWrappedScalaFunction1("Call", TString, TCall, (rt: Type, st: SType) => SCanonicalCall)(Call.getClass, "parse") + def registerAll(): Unit = { + registerWrappedScalaFunction1("Call", TString, TCall, (rt: Type, st: SType) => SCanonicalCall)( + Call.getClass, + "parse", + ) registerSCode1("callFromRepr", TInt32, TCall, (rt: Type, _: SType) => SCanonicalCall) { - case (er, cb, rt, repr, _) => SCanonicalCall.constructFromIntRepr(cb, repr.asInt.value) + case (_, cb, _, repr, _) => SCanonicalCall.constructFromIntRepr(cb, repr.asInt.value) } registerSCode1("Call", TBoolean, TCall, (rt: Type, _: SType) => SCanonicalCall) { - case (er, cb, rt, phased, _) => - SCanonicalCall.constructFromIntRepr(cb, Code.invokeScalaObject[Int]( - Call0.getClass, "apply", - Array(classTag[Boolean].runtimeClass), - Array(phased.asBoolean.value))) + case (_, cb, _, phased, _) => + SCanonicalCall.constructFromIntRepr( + cb, + Code.invokeScalaObject[Int]( + Call0.getClass, + "apply", + Array(classTag[Boolean].runtimeClass), + Array(phased.asBoolean.value), + ), + ) } - registerSCode2("Call", TInt32, TBoolean, TCall, (rt: Type, _: SType, _: SType) => SCanonicalCall) { - case (er, cb, rt, a1, phased, _) => - SCanonicalCall.constructFromIntRepr(cb, Code.invokeScalaObject[Int]( - Call1.getClass, "apply", - Array(classTag[Int].runtimeClass, classTag[Boolean].runtimeClass), - Array(a1.asInt.value, phased.asBoolean.value))) + registerSCode2( + "Call", + TInt32, + TBoolean, + TCall, + (rt: Type, _: SType, _: SType) => SCanonicalCall, + ) { + case (_, cb, _, a1, phased, _) => + SCanonicalCall.constructFromIntRepr( + cb, + Code.invokeScalaObject[Int]( + Call1.getClass, + "apply", + Array(classTag[Int].runtimeClass, classTag[Boolean].runtimeClass), + Array(a1.asInt.value, phased.asBoolean.value), + ), + ) } - registerSCode3("Call", TInt32, TInt32, TBoolean, TCall, (rt: Type, _: SType, _: SType, _: SType) => SCanonicalCall) { - case (er, cb, rt, a1, a2, phased, _) => - SCanonicalCall.constructFromIntRepr(cb, Code.invokeScalaObject[Int]( - Call2.getClass, "apply", - Array(classTag[Int].runtimeClass, classTag[Int].runtimeClass, classTag[Boolean].runtimeClass), - Array(a1.asInt.value, a2.asInt.value, phased.asBoolean.value))) + registerSCode3( + "Call", + TInt32, + TInt32, + TBoolean, + TCall, + (rt: Type, _: SType, _: SType, _: SType) => SCanonicalCall, + ) { + case (_, cb, _, a1, a2, phased, _) => + SCanonicalCall.constructFromIntRepr( + cb, + Code.invokeScalaObject[Int]( + Call2.getClass, 
+ "apply", + Array( + classTag[Int].runtimeClass, + classTag[Int].runtimeClass, + classTag[Boolean].runtimeClass, + ), + Array(a1.asInt.value, a2.asInt.value, phased.asBoolean.value), + ), + ) } - registerSCode1("UnphasedDiploidGtIndexCall", TInt32, TCall, (rt: Type, _: SType) => SCanonicalCall) { - case (er, cb, rt, x, _) => - SCanonicalCall.constructFromIntRepr(cb, Code.invokeScalaObject[Int]( - Call2.getClass, "fromUnphasedDiploidGtIndex", - Array(classTag[Int].runtimeClass), - Array(x.asInt.value))) + registerSCode1( + "UnphasedDiploidGtIndexCall", + TInt32, + TCall, + (rt: Type, _: SType) => SCanonicalCall, + ) { + case (_, cb, _, x, _) => + SCanonicalCall.constructFromIntRepr( + cb, + Code.invokeScalaObject[Int]( + Call2.getClass, + "fromUnphasedDiploidGtIndex", + Array(classTag[Int].runtimeClass), + Array(x.asInt.value), + ), + ) } - - registerWrappedScalaFunction2("Call", TArray(TInt32), TBoolean, TCall, { - case (rt: Type, _: SType, _: SType) => SCanonicalCall - })(CallN.getClass, "apply") + registerWrappedScalaFunction2( + "Call", + TArray(TInt32), + TBoolean, + TCall, + { + case (_: Type, _: SType, _: SType) => SCanonicalCall + }, + )(CallN.getClass, "apply") val qualities = Array("isPhased", "isHomRef", "isHet", "isHomVar", "isNonRef", "isHetNonRef", "isHetRef") - for (q <- qualities) { + for (q <- qualities) registerSCode1(q, TCall, TBoolean, (rt: Type, _: SType) => SBoolean) { - case (er, cb, rt, call, _) => + case (_, cb, _, call, _) => primitive(cb.memoize(Code.invokeScalaObject[Boolean]( - Call.getClass, q, Array(classTag[Int].runtimeClass), Array(call.asCall.canonicalCall(cb))))) + Call.getClass, + q, + Array(classTag[Int].runtimeClass), + Array(call.asCall.canonicalCall(cb)), + ))) } - } registerSCode1("ploidy", TCall, TInt32, (rt: Type, _: SType) => SInt32) { - case (er, cb, rt, call, _) => + case (_, cb, _, call, _) => primitive(cb.memoize(Code.invokeScalaObject[Int]( - Call.getClass, "ploidy", Array(classTag[Int].runtimeClass), Array(call.asCall.canonicalCall(cb))))) + Call.getClass, + "ploidy", + Array(classTag[Int].runtimeClass), + Array(call.asCall.canonicalCall(cb)), + ))) } registerSCode1("unphase", TCall, TCall, (rt: Type, a1: SType) => a1) { - case (er, cb, rt, call, _) => + case (_, cb, _, call, _) => call.asCall.unphase(cb) } - registerSCode2("containsAllele", TCall, TInt32, TBoolean, (rt: Type, _: SType, _: SType) => SBoolean) { - case (er, cb, rt, call, allele, _) => + registerSCode2( + "containsAllele", + TCall, + TInt32, + TBoolean, + (rt: Type, _: SType, _: SType) => SBoolean, + ) { + case (_, cb, _, call, allele, _) => primitive(call.asCall.containsAllele(cb, allele.asInt.value)) } registerSCode1("nNonRefAlleles", TCall, TInt32, (rt: Type, _: SType) => SInt32) { - case (er, cb, rt, call, _) => + case (_, cb, _, call, _) => primitive(cb.memoize(Code.invokeScalaObject[Int]( - Call.getClass, "nNonRefAlleles", Array(classTag[Int].runtimeClass), Array(call.asCall.canonicalCall(cb))))) + Call.getClass, + "nNonRefAlleles", + Array(classTag[Int].runtimeClass), + Array(call.asCall.canonicalCall(cb)), + ))) } registerSCode1("unphasedDiploidGtIndex", TCall, TInt32, (rt: Type, _: SType) => SInt32) { - case (er, cb, rt, call, _) => + case (_, cb, _, call, _) => primitive(cb.memoize(Code.invokeScalaObject[Int]( - Call.getClass, "unphasedDiploidGtIndex", Array(classTag[Int].runtimeClass), Array(call.asCall.canonicalCall(cb))))) + Call.getClass, + "unphasedDiploidGtIndex", + Array(classTag[Int].runtimeClass), + Array(call.asCall.canonicalCall(cb)), + ))) } 
registerSCode2("index", TCall, TInt32, TInt32, (rt: Type, _: SType, _: SType) => SInt32) { - case (er, cb, rt, call, idx, _) => + case (_, cb, _, call, idx, _) => primitive(cb.memoize(Code.invokeScalaObject[Int]( - Call.getClass, "alleleByIndex", Array(classTag[Int].runtimeClass, classTag[Int].runtimeClass), Array(call.asCall.canonicalCall(cb), idx.asInt.value)))) + Call.getClass, + "alleleByIndex", + Array(classTag[Int].runtimeClass, classTag[Int].runtimeClass), + Array(call.asCall.canonicalCall(cb), idx.asInt.value), + ))) } - - registerSCode2("downcode", TCall, TInt32, TCall, (rt: Type, _: SType, _: SType) => SCanonicalCall) { - case (er, cb, rt, call, downcodedAllele, _) => - SCanonicalCall.constructFromIntRepr(cb, Code.invokeScalaObject[Int]( - Call.getClass, "downcode", Array(classTag[Int].runtimeClass, classTag[Int].runtimeClass), Array(call.asCall.canonicalCall(cb), downcodedAllele.asInt.value))) + registerSCode2( + "downcode", + TCall, + TInt32, + TCall, + (rt: Type, _: SType, _: SType) => SCanonicalCall, + ) { + case (_, cb, _, call, downcodedAllele, _) => + SCanonicalCall.constructFromIntRepr( + cb, + Code.invokeScalaObject[Int]( + Call.getClass, + "downcode", + Array(classTag[Int].runtimeClass, classTag[Int].runtimeClass), + Array(call.asCall.canonicalCall(cb), downcodedAllele.asInt.value), + ), + ) } - registerSCode2("lgt_to_gt", TCall, TArray(TInt32), TCall, { case (rt: Type, sc: SCall, _:SType) => sc }) { - case (er, cb, rt, call, localAlleles, errorID) => + registerSCode2( + "lgt_to_gt", + TCall, + TArray(TInt32), + TCall, + { case (_: Type, sc: SCall, _: SType) => sc }, + ) { + case (_, cb, _, call, localAlleles, errorID) => call.asCall.lgtToGT(cb, localAlleles.asIndexable, errorID) } - registerWrappedScalaFunction2("oneHotAlleles", TCall, TInt32, TArray(TInt32), { - case (rt: Type, _: SType, _: SType) => SIndexablePointer(PCanonicalArray(PInt32(true))) - })(Call.getClass, "oneHotAlleles") + registerWrappedScalaFunction2( + "oneHotAlleles", + TCall, + TInt32, + TArray(TInt32), + { + case (_: Type, _: SType, _: SType) => SIndexablePointer(PCanonicalArray(PInt32(true))) + }, + )(Call.getClass, "oneHotAlleles") } } diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/DictFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/DictFunctions.scala index 4dc8e257a76..fbcca3f1103 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/DictFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/DictFunctions.scala @@ -9,36 +9,53 @@ object DictFunctions extends RegistryFunctions { def contains(dict: IR, key: IR) = { val i = Ref(genUID(), TInt32) - If(IsNA(dict), + If( + IsNA(dict), NA(TBoolean), - Let(FastSeq(i.name -> LowerBoundOnOrderedCollection(dict, key, onKey = true)), - If(i.ceq(ArrayLen(CastToArray(dict))), + Let( + FastSeq(i.name -> LowerBoundOnOrderedCollection(dict, key, onKey = true)), + If( + i.ceq(ArrayLen(CastToArray(dict))), False(), ApplyComparisonOp( EQWithNA(key.typ), GetField(ArrayRef(CastToArray(dict), i), "key"), - key)))) + key, + ), + ), + ), + ) } def get(dict: IR, key: IR, default: IR): IR = { val i = Ref(genUID(), TInt32) - If(IsNA(dict), + If( + IsNA(dict), NA(default.typ), - Let(FastSeq(i.name -> LowerBoundOnOrderedCollection(dict, key, onKey=true)), - If(i.ceq(ArrayLen(CastToArray(dict))), + Let( + FastSeq(i.name -> LowerBoundOnOrderedCollection(dict, key, onKey = true)), + If( + i.ceq(ArrayLen(CastToArray(dict))), default, - If(ApplyComparisonOp(EQWithNA(key.typ), GetField(ArrayRef(CastToArray(dict), i), "key"), key), + 
If( + ApplyComparisonOp( + EQWithNA(key.typ), + GetField(ArrayRef(CastToArray(dict), i), "key"), + key, + ), GetField(ArrayRef(CastToArray(dict), i), "value"), - default)))) + default, + ), + ), + ), + ) } val tdict = TDict(tv("key"), tv("value")) - def registerAll() { - registerIR1("isEmpty", tdict, TBoolean) { (_, d, _) => - ArrayFunctions.isEmpty(CastToArray(d)) - } + def registerAll(): Unit = { + registerIR1("isEmpty", tdict, TBoolean)((_, d, _) => ArrayFunctions.isEmpty(CastToArray(d))) registerIR2("contains", tdict, tv("key"), TBoolean)((_, a, b, _) => contains(a, b)) @@ -49,22 +66,33 @@ object DictFunctions extends RegistryFunctions { registerIR2("index", tdict, tv("key"), tv("value")) { (_, d, k, errorID) => val vtype = types.tcoerce[TBaseStruct](types.tcoerce[TContainer](d.typ).elementType).types(1) - val errormsg = invoke("concat", TString, + val errormsg = invoke( + "concat", + TString, Str("Key "), - invoke("concat", TString, + invoke( + "concat", + TString, invoke("showStr", TString, k), - invoke("concat", TString, + invoke( + "concat", + TString, Str(" not found in dictionary. Keys: "), - invoke("str", TString, invoke("keys", TArray(k.typ), d))))) + invoke("str", TString, invoke("keys", TArray(k.typ), d)), + ), + ), + ) get(d, k, Die(errormsg, vtype, errorID)) } - registerIR1("dictToArray", tdict, TArray(TStruct("key" -> tv("key"), "value" -> tv("value")))) { (_, d, _) => - val elt = Ref(genUID(), types.tcoerce[TContainer](d.typ).elementType) - ToArray(StreamMap( - ToStream(d), - elt.name, - MakeTuple.ordered(FastSeq(GetField(elt, "key"), GetField(elt, "value"))))) + registerIR1("dictToArray", tdict, TArray(TStruct("key" -> tv("key"), "value" -> tv("value")))) { + (_, d, _) => + val elt = Ref(genUID(), types.tcoerce[TContainer](d.typ).elementType) + ToArray(StreamMap( + ToStream(d), + elt.name, + MakeTuple.ordered(FastSeq(GetField(elt, "key"), GetField(elt, "value"))), + )) } registerIR1("keySet", tdict, TSet(tv("key"))) { (_, d, _) => @@ -72,9 +100,13 @@ object DictFunctions extends RegistryFunctions { ToSet(StreamMap(ToStream(d), pairs.name, GetField(pairs, "key"))) } - registerIR1("dict", TSet(TTuple(tv("key"), tv("value"))), tdict)((_, s, _) => ToDict(ToStream(s))) + registerIR1("dict", TSet(TTuple(tv("key"), tv("value"))), tdict)((_, s, _) => + ToDict(ToStream(s)) + ) - registerIR1("dict", TArray(TTuple(tv("key"), tv("value"))), tdict)((_, a, _) => ToDict(ToStream(a))) + registerIR1("dict", TArray(TTuple(tv("key"), tv("value"))), tdict)((_, a, _) => + ToDict(ToStream(a)) + ) registerIR1("keys", tdict, TArray(tv("key"))) { (_, d, _) => val elt = Ref(genUID(), types.tcoerce[TContainer](d.typ).elementType) diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala index 1143d680843..357a84685f9 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala @@ -6,42 +6,44 @@ import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.experimental.ExperimentalFunctions import is.hail.expr.ir._ import is.hail.io.bgen.BGENFunctions -import is.hail.types._ import is.hail.types.physical._ +import is.hail.types.physical.stypes.{EmitType, SType, SValue} import is.hail.types.physical.stypes.concrete._ import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives._ -import is.hail.types.physical.stypes.{EmitType, SType, SValue} import is.hail.types.virtual._ import 
is.hail.utils._ import is.hail.variant.{Locus, ReferenceGenome} -import org.apache.spark.sql.Row import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect._ +import org.apache.spark.sql.Row + object IRFunctionRegistry { - private val userAddedFunctions: mutable.Set[(String, (Type, Seq[Type], Seq[Type]))] = mutable.HashSet.empty + private val userAddedFunctions: mutable.Set[(String, (Type, Seq[Type], Seq[Type]))] = + mutable.HashSet.empty - def clearUserFunctions() { + def clearUserFunctions(): Unit = { userAddedFunctions.foreach { case (name, (rt, typeParameters, valueParameterTypes)) => - removeIRFunction(name, rt, typeParameters, valueParameterTypes) } + removeIRFunction(name, rt, typeParameters, valueParameterTypes) + } userAddedFunctions.clear() } type IRFunctionSignature = (Seq[Type], Seq[Type], Type, Boolean) type IRFunctionImplementation = (Seq[Type], Seq[IR], Int) => IR - val irRegistry: mutable.Map[String, mutable.Map[IRFunctionSignature, IRFunctionImplementation]] = new mutable.HashMap() + val irRegistry: mutable.Map[String, mutable.Map[IRFunctionSignature, IRFunctionImplementation]] = + new mutable.HashMap() val jvmRegistry: mutable.MultiMap[String, JVMFunction] = new mutable.HashMap[String, mutable.Set[JVMFunction]] with mutable.MultiMap[String, JVMFunction] - private[this] def requireJavaIdentifier(name: String): Unit = { + private[this] def requireJavaIdentifier(name: String): Unit = if (!isJavaIdentifier(name)) - throw new IllegalArgumentException(s"Illegal function name, not Java identifier: ${ name }") - } + throw new IllegalArgumentException(s"Illegal function name, not Java identifier: $name") def addJVMFunction(f: JVMFunction): Unit = { requireJavaIdentifier(f.name) @@ -54,7 +56,7 @@ object IRFunctionRegistry { valueParameterTypes: Seq[Type], returnType: Type, alwaysInline: Boolean, - f: IRFunctionImplementation + f: IRFunctionImplementation, ): Unit = { requireJavaIdentifier(name) @@ -68,19 +70,21 @@ object IRFunctionRegistry { argNames: java.util.ArrayList[String], argTypeStrs: java.util.ArrayList[String], returnType: String, - body: IR + body: IR, ): Unit = { requireJavaIdentifier(name) val typeParameters = typeParamStrs.asScala.map(IRParser.parseType).toFastSeq val valueParameterTypes = argTypeStrs.asScala.map(IRParser.parseType).toFastSeq userAddedFunctions += ((name, (body.typ, typeParameters, valueParameterTypes))) - addIR(name, + addIR( + name, typeParameters, - valueParameterTypes, IRParser.parseType(returnType), false, { (_, args, _) => - Subst(body, - BindingEnv(Env[IR](argNames.asScala.zip(args): _*))) - }) + valueParameterTypes, + IRParser.parseType(returnType), + false, + (_, args, _) => Subst(body, BindingEnv(Env[IR](argNames.asScala.zip(args): _*))), + ) } def pyRegisterIRForServiceBackend( @@ -90,7 +94,7 @@ object IRFunctionRegistry { argNames: Array[String], argTypeStrs: Array[String], returnType: String, - bodyStr: String + bodyStr: String, ): Unit = { requireJavaIdentifier(name) @@ -100,7 +104,8 @@ object IRFunctionRegistry { val body = IRParser.parse_value_ir( bodyStr, IRParserEnvironment(ctx, Map()), - refMap) + refMap, + ) userAddedFunctions += ((name, (body.typ, typeParameters, valueParameterTypes))) addIR( @@ -109,10 +114,7 @@ object IRFunctionRegistry { valueParameterTypes, IRParser.parseType(returnType), false, - { (_, args, _) => - Subst(body, - BindingEnv(Env[IR](argNames.zip(args): _*))) - } + (_, args, _) => Subst(body, BindingEnv(Env[IR](argNames.zip(args): _*))), ) } @@ -120,7 +122,7 @@ object 
IRFunctionRegistry { name: String, returnType: Type, typeParameters: Seq[Type], - valueParameterTypes: Seq[Type] + valueParameterTypes: Seq[Type], ): Unit = { val m = irRegistry(name) m.remove((typeParameters, valueParameterTypes, returnType, false)) @@ -130,34 +132,44 @@ object IRFunctionRegistry { name: String, returnType: Type, typeParameters: Seq[Type], - valueParameterTypes: Seq[Type] - ): Option[JVMFunction] = { - jvmRegistry.lift(name).map { fs => fs.filter(t => t.unify(typeParameters, valueParameterTypes, returnType)).toSeq }.getOrElse(FastSeq()) match { + valueParameterTypes: Seq[Type], + ): Option[JVMFunction] = + jvmRegistry.get(name).map { fs => + fs.filter(t => t.unify(typeParameters, valueParameterTypes, returnType)).toSeq + }.getOrElse(FastSeq()) match { case Seq() => None case Seq(f) => Some(f) - case _ => fatal(s"Multiple functions found that satisfy $name(${ valueParameterTypes.mkString(",") }).") + case _ => + fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).") } - } def lookupFunctionOrFail( name: String, returnType: Type, typeParameters: Seq[Type], - valueParameterTypes: Seq[Type] + valueParameterTypes: Seq[Type], ): JVMFunction = { jvmRegistry.lift(name) match { case None => - fatal(s"no functions found with the signature $name(${valueParameterTypes.mkString(", ")}): $returnType") + fatal( + s"no functions found with the signature $name(${valueParameterTypes.mkString(", ")}): $returnType" + ) case Some(functions) => - functions.filter(t => t.unify(typeParameters, valueParameterTypes, returnType)).toSeq match { + functions.filter(t => + t.unify(typeParameters, valueParameterTypes, returnType) + ).toSeq match { case Seq() => - val prettyFunctionSignature = s"$name[${ typeParameters.mkString(", ") }](${ valueParameterTypes.mkString(", ") }): $returnType" + val prettyFunctionSignature = + s"$name[${typeParameters.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType" val prettyMismatchedFunctionSignatures = functions.map(x => s" $x").mkString("\n") fatal( s"No function found with the signature $prettyFunctionSignature.\n" + - s"However, there are other functions with that name:\n$prettyMismatchedFunctionSignatures") + s"However, there are other functions with that name:\n$prettyMismatchedFunctionSignatures" + ) case Seq(f) => f - case _ => fatal(s"Multiple functions found that satisfy $name(${ valueParameterTypes.mkString(", ") }).") + case _ => fatal( + s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(", ")})." 
+ ) } } } @@ -166,41 +178,50 @@ object IRFunctionRegistry { name: String, returnType: Type, typeParameters: Seq[Type], - valueParameterTypes: Seq[Type] + valueParameterTypes: Seq[Type], ): Option[(IRFunctionSignature, IRFunctionImplementation)] = { - irRegistry.getOrElse(name, Map.empty).filter { case ((typeParametersFound: Seq[Type], valueParameterTypesFound: Seq[Type], _, _), _) => - typeParametersFound.length == typeParameters.length && { - typeParametersFound.foreach(_.clear()) - (typeParametersFound, typeParameters).zipped.forall(_.unify(_)) - } && valueParameterTypesFound.length == valueParameterTypes.length && { - valueParameterTypesFound.foreach(_.clear()) - (valueParameterTypesFound, valueParameterTypes).zipped.forall(_.unify(_)) - } + irRegistry.getOrElse(name, Map.empty).filter { + case ((typeParametersFound: Seq[Type], valueParameterTypesFound: Seq[Type], _, _), _) => + typeParametersFound.length == typeParameters.length && { + typeParametersFound.foreach(_.clear()) + (typeParametersFound, typeParameters).zipped.forall(_.unify(_)) + } && valueParameterTypesFound.length == valueParameterTypes.length && { + valueParameterTypesFound.foreach(_.clear()) + (valueParameterTypesFound, valueParameterTypes).zipped.forall(_.unify(_)) + } }.toSeq match { case Seq() => None case Seq(kv) => Some(kv) - case _ => fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).") + case _ => + fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).") } } - def lookupSeeded(name: String, staticUID: Long, returnType: Type, arguments: Seq[Type]): Option[(Seq[IR], IR) => IR] = { + def lookupSeeded(name: String, staticUID: Long, returnType: Type, arguments: Seq[Type]) + : Option[(Seq[IR], IR) => IR] = lookupFunction(name, returnType, Array.empty[Type], TRNGState +: arguments) - .map { f => - (irArguments: Seq[IR], rngState: IR) => ApplySeeded(name, irArguments, rngState, staticUID, f.returnType.subst()) + .map { f => (irArguments: Seq[IR], rngState: IR) => + ApplySeeded(name, irArguments, rngState, staticUID, f.returnType.subst()) } - } - def lookupUnseeded(name: String, returnType: Type, arguments: Seq[Type]): Option[IRFunctionImplementation] = + def lookupUnseeded(name: String, returnType: Type, arguments: Seq[Type]) + : Option[IRFunctionImplementation] = lookupUnseeded(name, returnType, Array.empty[Type], arguments) - def lookupUnseeded(name: String, returnType: Type, typeParameters: Seq[Type], arguments: Seq[Type]): Option[IRFunctionImplementation] = { - val validIR: Option[IRFunctionImplementation] = lookupIR(name, returnType, typeParameters, arguments).map { - case ((_, _, _, inline), conversion) => (typeParametersPassed, args, errorID) => - val x = ApplyIR(name, typeParametersPassed, args, returnType, errorID) - x.conversion = conversion - x.inline = inline - x - } + def lookupUnseeded( + name: String, + returnType: Type, + typeParameters: Seq[Type], + arguments: Seq[Type], + ): Option[IRFunctionImplementation] = { + val validIR: Option[IRFunctionImplementation] = + lookupIR(name, returnType, typeParameters, arguments).map { + case ((_, _, _, inline), conversion) => (typeParametersPassed, args, errorID) => + val x = ApplyIR(name, typeParametersPassed, args, returnType, errorID) + x.conversion = conversion + x.inline = inline + x + } val validMethods = lookupFunction(name, returnType, typeParameters, arguments) .map { f => @@ -215,10 +236,10 @@ object IRFunctionRegistry { } (validIR, validMethods) match { - case (None , None) => None 
- case (None , Some(x)) => Some(x) - case (Some(x), None) => Some(x) - case _ => fatal(s"Multiple methods found that satisfy $name(${ arguments.mkString(",") }).") + case (None, None) => None + case (None, Some(x)) => Some(x) + case (Some(x), None) => Some(x) + case _ => fatal(s"Multiple methods found that satisfy $name(${arguments.mkString(",")}).") } } @@ -238,33 +259,45 @@ object IRFunctionRegistry { ExperimentalFunctions, ReferenceGenomeFunctions, BGENFunctions, - ApproxCDFFunctions + ApproxCDFFunctions, ).foreach(_.registerAll()) def dumpFunctions(): Unit = { - def dtype(t: Type): String = s"""dtype("${ StringEscapeUtils.escapeString(t.toString) }\")""" + def dtype(t: Type): String = s"""dtype("${StringEscapeUtils.escapeString(t.toString)}\")""" irRegistry.foreach { case (name, fns) => - fns.foreach { case ((typeParameters, valueParameterTypes, returnType, _), f) => - println(s"""register_function("${ StringEscapeUtils.escapeString(name) }", (${ typeParameters.map(dtype).mkString(",") }), (${ valueParameterTypes.map(dtype).mkString(",") }), ${ dtype(returnType) })""") - } + fns.foreach { case ((typeParameters, valueParameterTypes, returnType, _), _) => + println(s"""register_function("${StringEscapeUtils.escapeString(name)}", (${typeParameters.map( + dtype + ).mkString(",")}), (${valueParameterTypes.map(dtype).mkString(",")}), ${dtype(returnType)})""") + } } jvmRegistry.foreach { case (name, fns) => - fns.foreach { f => - println(s"""register_function("${ StringEscapeUtils.escapeString(name) }", (${ f.typeParameters.map(dtype).mkString(",") }), (${ f.valueParameterTypes.map(dtype).mkString(",") }), ${ dtype(f.returnType) })""") - } + fns.foreach { f => + println( + s"""register_function("${StringEscapeUtils.escapeString(name)}", (${f.typeParameters.map(dtype).mkString( + "," + )}), (${f.valueParameterTypes.map(dtype).mkString(",")}), ${dtype(f.returnType)})""" + ) + } } } } object RegistryHelpers { - def stupidUnwrapStruct(rgs: Map[String, ReferenceGenome], r: Region, value: Row, ptype: PType): Long = { + def stupidUnwrapStruct(rgs: Map[String, ReferenceGenome], r: Region, value: Row, ptype: PType) + : Long = { assert(value != null) ptype.unstagedStoreJavaObject(HailStateManager(rgs), value, r) } - def stupidUnwrapArray(rgs: Map[String, ReferenceGenome], r: Region, value: IndexedSeq[Annotation], ptype: PType): Long = { + def stupidUnwrapArray( + rgs: Map[String, ReferenceGenome], + r: Region, + value: IndexedSeq[Annotation], + ptype: PType, + ): Long = { assert(value != null) ptype.unstagedStoreJavaObject(HailStateManager(rgs), value, r) } @@ -274,8 +307,6 @@ abstract class RegistryFunctions { def registerAll(): Unit - private val boxes = mutable.Map[String, Box[Type]]() - def tv(name: String): TVariable = TVariable(name) @@ -297,7 +328,8 @@ abstract class RegistryFunctions { case _ => classInfo[AnyRef] } - def svalueToJavaValue(cb: EmitCodeBuilder, r: Value[Region], sc: SValue, safe: Boolean = false): Value[AnyRef] = { + def svalueToJavaValue(cb: EmitCodeBuilder, r: Value[Region], sc: SValue, safe: Boolean = false) + : Value[AnyRef] = { sc.st match { case SInt32 => cb.memoize(Code.boxInt(sc.asInt32.value)) case SInt64 => cb.memoize(Code.boxLong(sc.asInt64.value)) @@ -311,86 +343,140 @@ abstract class RegistryFunctions { val pt = PType.canonical(t.storageType()) val addr = pt.store(cb, r, sc, deepCopy = false) cb.memoize(Code.invokeScalaObject3[PType, Region, Long, AnyRef]( - if (safe) SafeRow.getClass else UnsafeRow.getClass, "readAnyRef", + if (safe) SafeRow.getClass else 
UnsafeRow.getClass, + "readAnyRef", cb.emb.getPType(pt), - r, addr)) + r, + addr, + )) } } - def unwrapReturn(cb: EmitCodeBuilder, r: Value[Region], st: SType, value: Code[_]): SValue = st.virtualType match { - case TBoolean => primitive(cb.memoize(coerce[Boolean](value))) - case TInt32 => primitive(cb.memoize(coerce[Int](value))) - case TInt64 => primitive(cb.memoize(coerce[Long](value))) - case TFloat32 => primitive(cb.memoize(coerce[Float](value))) - case TFloat64 => primitive(cb.memoize(coerce[Double](value))) - case TString => - val sst = st.asInstanceOf[SJavaString.type] - sst.constructFromString(cb, r, coerce[String](value)) - case TCall => - assert(st == SCanonicalCall) - new SCanonicalCallValue(cb.memoize(coerce[Int](value))) - case TArray(TInt32) => - val ast = st.asInstanceOf[SIndexablePointer] - val pca = ast.pType.asInstanceOf[PCanonicalArray] - val arr = cb.newLocal[IndexedSeq[Int]]("unrwrap_return_array_int32_arr", coerce[IndexedSeq[Int]](value)) - val len = cb.newLocal[Int]("unwrap_return_array_int32_len", arr.invoke[Int]("length")) - pca.constructFromElements(cb, r, len, deepCopy = false) { (cb, idx) => - val elt = cb.newLocal[java.lang.Integer]("unwrap_return_array_int32_elt", - Code.checkcast[java.lang.Integer](arr.invoke[Int, java.lang.Object]("apply", idx))) - IEmitCode(cb, elt.isNull, primitive(cb.memoize(elt.invoke[Int]("intValue")))) - } - case TArray(TInt64) => - val ast = st.asInstanceOf[SIndexablePointer] - val pca = ast.pType.asInstanceOf[PCanonicalArray] - val arr = cb.newLocal[IndexedSeq[Int]]("unrwrap_return_array_int64_arr", coerce[IndexedSeq[Int]](value)) - val len = cb.newLocal[Int]("unwrap_return_array_int64_len", arr.invoke[Int]("length")) - pca.constructFromElements(cb, r, len, deepCopy = false) { (cb, idx) => - val elt = cb.newLocal[java.lang.Long]("unwrap_return_array_int64_elt", - Code.checkcast[java.lang.Long](arr.invoke[Int, java.lang.Object]("apply", idx))) - IEmitCode(cb, elt.isNull, primitive(cb.memoize(elt.invoke[Long]("longValue")))) - } - case TArray(TFloat64) => - val ast = st.asInstanceOf[SIndexablePointer] - val pca = ast.pType.asInstanceOf[PCanonicalArray] - val arr = cb.newLocal[IndexedSeq[Double]]("unrwrap_return_array_float64_arr", coerce[IndexedSeq[Double]](value)) - val len = cb.newLocal[Int]("unwrap_return_array_float64_len", arr.invoke[Int]("length")) - pca.constructFromElements(cb, r, len, deepCopy = false) { (cb, idx) => - val elt = cb.newLocal[java.lang.Double]("unwrap_return_array_float64_elt", - Code.checkcast[java.lang.Double](arr.invoke[Int, java.lang.Object]("apply", idx))) - IEmitCode(cb, elt.isNull, primitive(cb.memoize(elt.invoke[Double]("doubleValue")))) - } - case TArray(TString) => - val ast = st.asInstanceOf[SJavaArrayString] - ast.construct(cb, coerce[Array[String]](value)) - case t: TBaseStruct => - val sst = st.asInstanceOf[SBaseStructPointer] - val pt = sst.pType.asInstanceOf[PCanonicalBaseStruct] - val addr = cb.memoize(Code.invokeScalaObject4[Map[String, ReferenceGenome], Region, Row, PType, Long]( - RegistryHelpers.getClass, "stupidUnwrapStruct", cb.emb.ecb.emodb.referenceGenomeMap, r.region, coerce[Row](value), cb.emb.ecb.getPType(pt))) - new SBaseStructPointerValue(SBaseStructPointer(pt.setRequired(false).asInstanceOf[PBaseStruct]), addr) - case TArray(t: TBaseStruct) => - val ast = st.asInstanceOf[SIndexablePointer] - val pca = ast.pType.asInstanceOf[PCanonicalArray] - val array = cb.memoize(Code.invokeScalaObject4[Map[String, ReferenceGenome], Region, IndexedSeq[Annotation], PType, Long]( - 
RegistryHelpers.getClass, "stupidUnwrapArray", cb.emb.ecb.emodb.referenceGenomeMap, r.region, coerce[IndexedSeq[Annotation]](value), cb.emb.ecb.getPType(pca))) - new SIndexablePointerValue(ast, array, cb.memoize(pca.loadLength(array)), cb.memoize(pca.firstElementOffset(array))) - } + def unwrapReturn(cb: EmitCodeBuilder, r: Value[Region], st: SType, value: Code[_]): SValue = + st.virtualType match { + case TBoolean => primitive(cb.memoize(coerce[Boolean](value))) + case TInt32 => primitive(cb.memoize(coerce[Int](value))) + case TInt64 => primitive(cb.memoize(coerce[Long](value))) + case TFloat32 => primitive(cb.memoize(coerce[Float](value))) + case TFloat64 => primitive(cb.memoize(coerce[Double](value))) + case TString => + val sst = st.asInstanceOf[SJavaString.type] + sst.constructFromString(cb, r, coerce[String](value)) + case TCall => + assert(st == SCanonicalCall) + new SCanonicalCallValue(cb.memoize(coerce[Int](value))) + case TArray(TInt32) => + val ast = st.asInstanceOf[SIndexablePointer] + val pca = ast.pType.asInstanceOf[PCanonicalArray] + val arr = cb.newLocal[IndexedSeq[Int]]( + "unrwrap_return_array_int32_arr", + coerce[IndexedSeq[Int]](value), + ) + val len = cb.newLocal[Int]("unwrap_return_array_int32_len", arr.invoke[Int]("length")) + pca.constructFromElements(cb, r, len, deepCopy = false) { (cb, idx) => + val elt = cb.newLocal[java.lang.Integer]( + "unwrap_return_array_int32_elt", + Code.checkcast[java.lang.Integer](arr.invoke[Int, java.lang.Object]("apply", idx)), + ) + IEmitCode(cb, elt.isNull, primitive(cb.memoize(elt.invoke[Int]("intValue")))) + } + case TArray(TInt64) => + val ast = st.asInstanceOf[SIndexablePointer] + val pca = ast.pType.asInstanceOf[PCanonicalArray] + val arr = cb.newLocal[IndexedSeq[Int]]( + "unrwrap_return_array_int64_arr", + coerce[IndexedSeq[Int]](value), + ) + val len = cb.newLocal[Int]("unwrap_return_array_int64_len", arr.invoke[Int]("length")) + pca.constructFromElements(cb, r, len, deepCopy = false) { (cb, idx) => + val elt = cb.newLocal[java.lang.Long]( + "unwrap_return_array_int64_elt", + Code.checkcast[java.lang.Long](arr.invoke[Int, java.lang.Object]("apply", idx)), + ) + IEmitCode(cb, elt.isNull, primitive(cb.memoize(elt.invoke[Long]("longValue")))) + } + case TArray(TFloat64) => + val ast = st.asInstanceOf[SIndexablePointer] + val pca = ast.pType.asInstanceOf[PCanonicalArray] + val arr = cb.newLocal[IndexedSeq[Double]]( + "unrwrap_return_array_float64_arr", + coerce[IndexedSeq[Double]](value), + ) + val len = cb.newLocal[Int]("unwrap_return_array_float64_len", arr.invoke[Int]("length")) + pca.constructFromElements(cb, r, len, deepCopy = false) { (cb, idx) => + val elt = cb.newLocal[java.lang.Double]( + "unwrap_return_array_float64_elt", + Code.checkcast[java.lang.Double](arr.invoke[Int, java.lang.Object]("apply", idx)), + ) + IEmitCode(cb, elt.isNull, primitive(cb.memoize(elt.invoke[Double]("doubleValue")))) + } + case TArray(TString) => + val ast = st.asInstanceOf[SJavaArrayString] + ast.construct(cb, coerce[Array[String]](value)) + case _: TBaseStruct => + val sst = st.asInstanceOf[SBaseStructPointer] + val pt = sst.pType.asInstanceOf[PCanonicalBaseStruct] + val addr = cb.memoize(Code.invokeScalaObject4[Map[ + String, + ReferenceGenome, + ], Region, Row, PType, Long]( + RegistryHelpers.getClass, + "stupidUnwrapStruct", + cb.emb.ecb.emodb.referenceGenomeMap, + r.region, + coerce[Row](value), + cb.emb.ecb.getPType(pt), + )) + new SBaseStructPointerValue( + SBaseStructPointer(pt.setRequired(false).asInstanceOf[PBaseStruct]), + addr, + ) 
+ case TArray(_: TBaseStruct) => + val ast = st.asInstanceOf[SIndexablePointer] + val pca = ast.pType.asInstanceOf[PCanonicalArray] + val array = cb.memoize(Code.invokeScalaObject4[Map[ + String, + ReferenceGenome, + ], Region, IndexedSeq[Annotation], PType, Long]( + RegistryHelpers.getClass, + "stupidUnwrapArray", + cb.emb.ecb.emodb.referenceGenomeMap, + r.region, + coerce[IndexedSeq[Annotation]](value), + cb.emb.ecb.getPType(pca), + )) + new SIndexablePointerValue( + ast, + array, + cb.memoize(pca.loadLength(array)), + cb.memoize(pca.firstElementOffset(array)), + ) + } def registerSCode( name: String, valueParameterTypes: Array[Type], returnType: Type, calculateReturnType: (Type, Seq[SType]) => SType, - typeParameters: Array[Type] = Array.empty + typeParameters: Array[Type] = Array.empty, )( impl: (EmitRegion, EmitCodeBuilder, Seq[Type], SType, Array[SValue], Value[Int]) => SValue - ) { + ): Unit = { IRFunctionRegistry.addJVMFunction( - new UnseededMissingnessObliviousJVMFunction(name, typeParameters, valueParameterTypes, returnType, calculateReturnType) { - override def apply(r: EmitRegion, cb: EmitCodeBuilder, returnSType: SType, typeParameters: Seq[Type], errorID: Value[Int], args: SValue*): SValue = + new UnseededMissingnessObliviousJVMFunction(name, typeParameters, valueParameterTypes, + returnType, calculateReturnType) { + override def apply( + r: EmitRegion, + cb: EmitCodeBuilder, + returnSType: SType, + typeParameters: Seq[Type], + errorID: Value[Int], + args: SValue* + ): SValue = impl(r, cb, typeParameters, returnSType, args.toArray, errorID) - }) + } + ) } def registerCode( @@ -398,18 +484,27 @@ abstract class RegistryFunctions { valueParameterTypes: Array[Type], returnType: Type, calculateReturnType: (Type, Seq[SType]) => SType, - typeParameters: Array[Type] = Array.empty + typeParameters: Array[Type] = Array.empty, )( impl: (EmitRegion, EmitCodeBuilder, SType, Array[Type], Array[SValue]) => Value[_] - ) { + ): Unit = { IRFunctionRegistry.addJVMFunction( - new UnseededMissingnessObliviousJVMFunction(name, typeParameters, valueParameterTypes, returnType, calculateReturnType) { - override def apply(r: EmitRegion, cb: EmitCodeBuilder, returnSType: SType, typeParameters: Seq[Type], errorID: Value[Int], args: SValue*): SValue = { + new UnseededMissingnessObliviousJVMFunction(name, typeParameters, valueParameterTypes, + returnType, calculateReturnType) { + override def apply( + r: EmitRegion, + cb: EmitCodeBuilder, + returnSType: SType, + typeParameters: Seq[Type], + errorID: Value[Int], + args: SValue* + ): SValue = { assert(unify(typeParameters, args.map(_.st.virtualType), returnSType.virtualType)) val returnValue = impl(r, cb, returnSType, typeParameters.toArray, args.toArray) returnSType.fromValues(FastSeq(returnValue)) } - }) + } + ) } def registerEmitCode( @@ -417,17 +512,25 @@ abstract class RegistryFunctions { valueParameterTypes: Array[Type], returnType: Type, calculateReturnType: (Type, Seq[EmitType]) => EmitType, - typeParameters: Array[Type] = Array.empty + typeParameters: Array[Type] = Array.empty, )( - impl: (EmitRegion, SType, Value[Int], Array[EmitCode]) => EmitCode - ) { + impl: (EmitRegion, SType, Value[Int], Array[EmitCode]) => EmitCode + ): Unit = { IRFunctionRegistry.addJVMFunction( - new UnseededMissingnessAwareJVMFunction(name, typeParameters, valueParameterTypes, returnType, calculateReturnType) { - override def apply(r: EmitRegion, rpt: SType, typeParameters: Seq[Type], errorID: Value[Int], args: EmitCode*): EmitCode = { + new 
UnseededMissingnessAwareJVMFunction(name, typeParameters, valueParameterTypes, returnType, + calculateReturnType) { + override def apply( + r: EmitRegion, + rpt: SType, + typeParameters: Seq[Type], + errorID: Value[Int], + args: EmitCode* + ): EmitCode = { assert(unify(typeParameters, args.map(_.st.virtualType), rpt.virtualType)) impl(r, rpt, errorID, args.toArray) } - }) + } + ) } def registerIEmitCode( @@ -435,12 +538,13 @@ abstract class RegistryFunctions { valueParameterTypes: Array[Type], returnType: Type, calculateReturnType: (Type, Seq[EmitType]) => EmitType, - typeParameters: Array[Type] = Array.empty + typeParameters: Array[Type] = Array.empty, )( - impl: (EmitCodeBuilder, Value[Region], SType , Value[Int], Array[EmitCode]) => IEmitCode - ) { + impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], Array[EmitCode]) => IEmitCode + ): Unit = { IRFunctionRegistry.addJVMFunction( - new UnseededMissingnessAwareJVMFunction(name, typeParameters, valueParameterTypes, returnType, calculateReturnType) { + new UnseededMissingnessAwareJVMFunction(name, typeParameters, valueParameterTypes, returnType, + calculateReturnType) { override def apply( cb: EmitCodeBuilder, r: Value[Region], @@ -451,34 +555,47 @@ abstract class RegistryFunctions { ): IEmitCode = { val res = impl(cb, r, rpt, errorID, args.toArray) if (res.emitType != calculateReturnType(rpt.virtualType, args.map(_.emitType))) - throw new RuntimeException(s"type mismatch while registering $name" + - s"\n got ${ res.emitType }, got ${ calculateReturnType(rpt.virtualType, args.map(_.emitType)) }") + throw new RuntimeException( + s"type mismatch while registering $name" + + s"\n got ${res.emitType}, got ${calculateReturnType(rpt.virtualType, args.map(_.emitType))}" + ) res } - override def apply(r: EmitRegion, rpt: SType, typeParameters: Seq[Type], errorID: Value[Int], args: EmitCode*): EmitCode = { - EmitCode.fromI(r.mb) { cb => - apply(cb, r.region, rpt, typeParameters, errorID, args: _*) - } - } - }) + override def apply( + r: EmitRegion, + rpt: SType, + typeParameters: Seq[Type], + errorID: Value[Int], + args: EmitCode* + ): EmitCode = + EmitCode.fromI(r.mb)(cb => apply(cb, r.region, rpt, typeParameters, errorID, args: _*)) + } + ) } def registerScalaFunction( name: String, valueParameterTypes: Array[Type], returnType: Type, - calculateReturnType: (Type, Seq[SType]) => SType + calculateReturnType: (Type, Seq[SType]) => SType, )( cls: Class[_], - method: String - ) { - registerSCode(name, valueParameterTypes, returnType, calculateReturnType) { case (r, cb, _, rt, args, _) => - val cts = valueParameterTypes.map(PrimitiveTypeToIRIntermediateClassTag(_).runtimeClass) - - val returnValue = cb.memoizeAny( - Code.invokeScalaObject(cls, method, cts, args.map { a => SType.extractPrimValue(cb, a).get })(PrimitiveTypeToIRIntermediateClassTag(returnType)), - rt.settableTupleTypes()(0)) - rt.fromValues(FastSeq(returnValue)) + method: String, + ): Unit = { + registerSCode(name, valueParameterTypes, returnType, calculateReturnType) { + case (_, cb, _, rt, args, _) => + val cts = valueParameterTypes.map(PrimitiveTypeToIRIntermediateClassTag(_).runtimeClass) + + val returnValue = cb.memoizeAny( + Code.invokeScalaObject( + cls, + method, + cts, + args.map(a => SType.extractPrimValue(cb, a).get), + )(PrimitiveTypeToIRIntermediateClassTag(returnType)), + rt.settableTupleTypes()(0), + ) + rt.fromValues(FastSeq(returnValue)) } } @@ -486,11 +603,11 @@ abstract class RegistryFunctions { name: String, valueParameterTypes: Array[Type], returnType: Type, - 
calculateReturnType: (Type, Seq[SType]) => SType + calculateReturnType: (Type, Seq[SType]) => SType, )( cls: Class[_], - method: String - ) { + method: String, + ): Unit = { def ct(typ: Type): ClassTag[_] = typ match { case TString => classTag[String] case TArray(TInt32) => classTag[IndexedSeq[Int]] @@ -502,182 +619,550 @@ abstract class RegistryFunctions { case t => PrimitiveTypeToIRIntermediateClassTag(t) } - def wrap(cb: EmitCodeBuilder, r: Value[Region], code: SValue): Value[_] = code.st.virtualType match { - case t if t.isPrimitive => SType.extractPrimValue(cb, code) - case TCall => code.asCall.canonicalCall(cb) - case TArray(TString) => code.st match { - case _: SJavaArrayString => cb.memoize(code.asInstanceOf[SJavaArrayStringValue].array) - case _ => - val sv = code.asIndexable - val arr = cb.newLocal[Array[String]]("scode_array_string", Code.newArray[String](sv.loadLength())) - sv.forEachDefined(cb) { case (cb, idx, elt) => - cb += (arr(idx) = elt.asString.loadString(cb)) + def wrap(cb: EmitCodeBuilder, r: Value[Region], code: SValue): Value[_] = + code.st.virtualType match { + case t if t.isPrimitive => SType.extractPrimValue(cb, code) + case TCall => code.asCall.canonicalCall(cb) + case TArray(TString) => code.st match { + case _: SJavaArrayString => cb.memoize(code.asInstanceOf[SJavaArrayStringValue].array) + case _ => + val sv = code.asIndexable + val arr = cb.newLocal[Array[String]]( + "scode_array_string", + Code.newArray[String](sv.loadLength()), + ) + sv.forEachDefined(cb) { case (cb, idx, elt) => + cb += (arr(idx) = elt.asString.loadString(cb)) + } + arr } - arr + case _ => svalueToJavaValue(cb, r, code) } - case _ => svalueToJavaValue(cb, r, code) - } - registerSCode(name, valueParameterTypes, returnType, calculateReturnType) { case (r, cb, _, rt, args, _) => - val cts = valueParameterTypes.map(ct(_).runtimeClass) - try { - unwrapReturn(cb, r.region, rt, - Code.invokeScalaObject(cls, method, cts, args.map { a => wrap(cb, r.region, a).get })(ct(returnType))) - } catch { - case e: Throwable => throw new RuntimeException(s"error while registering function $name", e) - } + registerSCode(name, valueParameterTypes, returnType, calculateReturnType) { + case (r, cb, _, rt, args, _) => + val cts = valueParameterTypes.map(ct(_).runtimeClass) + try + unwrapReturn( + cb, + r.region, + rt, + Code.invokeScalaObject(cls, method, cts, args.map(a => wrap(cb, r.region, a).get))( + ct(returnType) + ), + ) + catch { + case e: Throwable => + throw new RuntimeException(s"error while registering function $name", e) + } } } - def registerWrappedScalaFunction1(name: String, a1: Type, returnType: Type, pt: (Type, SType) => SType)(cls: Class[_], method: String): Unit = + def registerWrappedScalaFunction1( + name: String, + a1: Type, + returnType: Type, + pt: (Type, SType) => SType, + )( + cls: Class[_], + method: String, + ): Unit = registerWrappedScalaFunction(name, Array(a1), returnType, unwrappedApply(pt))(cls, method) - def registerWrappedScalaFunction2(name: String, a1: Type, a2: Type, returnType: Type, pt: (Type, SType, SType) => SType)(cls: Class[_], method: String): Unit = + def registerWrappedScalaFunction2( + name: String, + a1: Type, + a2: Type, + returnType: Type, + pt: (Type, SType, SType) => SType, + )( + cls: Class[_], + method: String, + ): Unit = registerWrappedScalaFunction(name, Array(a1, a2), returnType, unwrappedApply(pt))(cls, method) - def registerWrappedScalaFunction3(name: String, a1: Type, a2: Type, a3: Type, returnType: Type, - pt: (Type, SType, SType, SType) => 
SType)(cls: Class[_], method: String): Unit = - registerWrappedScalaFunction(name, Array(a1, a2, a3), returnType, unwrappedApply(pt))(cls, method) + def registerWrappedScalaFunction3( + name: String, + a1: Type, + a2: Type, + a3: Type, + returnType: Type, + pt: (Type, SType, SType, SType) => SType, + )( + cls: Class[_], + method: String, + ): Unit = + registerWrappedScalaFunction(name, Array(a1, a2, a3), returnType, unwrappedApply(pt))( + cls, + method, + ) - def registerWrappedScalaFunction4(name: String, a1: Type, a2: Type, a3: Type, a4: Type, returnType: Type, - pt: (Type, SType, SType, SType, SType) => SType)(cls: Class[_], method: String): Unit = - registerWrappedScalaFunction(name, Array(a1, a2, a3, a4), returnType, unwrappedApply(pt))(cls, method) + def registerWrappedScalaFunction4( + name: String, + a1: Type, + a2: Type, + a3: Type, + a4: Type, + returnType: Type, + pt: (Type, SType, SType, SType, SType) => SType, + )( + cls: Class[_], + method: String, + ): Unit = + registerWrappedScalaFunction(name, Array(a1, a2, a3, a4), returnType, unwrappedApply(pt))( + cls, + method, + ) - def registerJavaStaticFunction(name: String, valueParameterTypes: Array[Type], returnType: Type, pt: (Type, Seq[SType]) => SType)(cls: Class[_], method: String) { - registerCode(name, valueParameterTypes, returnType, pt) { case (r, cb, rt, _, args) => + def registerJavaStaticFunction( + name: String, + valueParameterTypes: Array[Type], + returnType: Type, + pt: (Type, Seq[SType]) => SType, + )( + cls: Class[_], + method: String, + ): Unit = { + registerCode(name, valueParameterTypes, returnType, pt) { case (_, cb, _, _, args) => val cts = valueParameterTypes.map(PrimitiveTypeToIRIntermediateClassTag(_).runtimeClass) val ct = PrimitiveTypeToIRIntermediateClassTag(returnType) cb.memoizeAny( Code.invokeStatic(cls, method, cts, args.map(a => SType.extractPrimValue(cb, a).get))(ct), - typeInfoFromClassTag(ct)) + typeInfoFromClassTag(ct), + ) } } - def registerIR(name: String, valueParameterTypes: Array[Type], returnType: Type, inline: Boolean = false, typeParameters: Array[Type] = Array.empty)(f: (Seq[Type], Seq[IR], Int) => IR): Unit = + def registerIR( + name: String, + valueParameterTypes: Array[Type], + returnType: Type, + inline: Boolean = false, + typeParameters: Array[Type] = Array.empty, + )( + f: (Seq[Type], Seq[IR], Int) => IR + ): Unit = IRFunctionRegistry.addIR(name, typeParameters, valueParameterTypes, returnType, inline, f) - def registerSCode1(name: String, mt1: Type, rt: Type, pt: (Type, SType) => SType)(impl: (EmitRegion, EmitCodeBuilder, SType, SValue, Value[Int]) => SValue): Unit = + def registerSCode1( + name: String, + mt1: Type, + rt: Type, + pt: (Type, SType) => SType, + )( + impl: (EmitRegion, EmitCodeBuilder, SType, SValue, Value[Int]) => SValue + ): Unit = registerSCode(name, Array(mt1), rt, unwrappedApply(pt)) { case (r, cb, _, rt, Array(a1), errorID) => impl(r, cb, rt, a1, errorID) } - def registerSCode1t(name: String, typeParams: Array[Type], mt1: Type, rt: Type, pt: (Type, SType) => SType)(impl: (EmitRegion, EmitCodeBuilder, Seq[Type], SType, SValue, Value[Int]) => SValue): Unit = + def registerSCode1t( + name: String, + typeParams: Array[Type], + mt1: Type, + rt: Type, + pt: (Type, SType) => SType, + )( + impl: (EmitRegion, EmitCodeBuilder, Seq[Type], SType, SValue, Value[Int]) => SValue + ): Unit = registerSCode(name, Array(mt1), rt, unwrappedApply(pt), typeParameters = typeParams) { case (r, cb, typeParams, rt, Array(a1), errorID) => impl(r, cb, typeParams, rt, a1, errorID) } - 
def registerSCode2(name: String, mt1: Type, mt2: Type, rt: Type, pt: (Type, SType, SType) => SType) - (impl: (EmitRegion, EmitCodeBuilder, SType, SValue, SValue, Value[Int]) => SValue): Unit = + def registerSCode2( + name: String, + mt1: Type, + mt2: Type, + rt: Type, + pt: (Type, SType, SType) => SType, + )( + impl: (EmitRegion, EmitCodeBuilder, SType, SValue, SValue, Value[Int]) => SValue + ): Unit = registerSCode(name, Array(mt1, mt2), rt, unwrappedApply(pt)) { - case (r, cb, _, rt, Array(a1, a2) , errorID) => impl(r, cb, rt, a1, a2, errorID) + case (r, cb, _, rt, Array(a1, a2), errorID) => impl(r, cb, rt, a1, a2, errorID) } - def registerSCode2t(name: String, typeParams: Array[Type], mt1: Type, mt2: Type, rt: Type, pt: (Type, SType, SType) => SType) - (impl: (EmitRegion, EmitCodeBuilder, Seq[Type], SType, SValue, SValue, Value[Int]) => SValue): Unit = + def registerSCode2t( + name: String, + typeParams: Array[Type], + mt1: Type, + mt2: Type, + rt: Type, + pt: (Type, SType, SType) => SType, + )( + impl: (EmitRegion, EmitCodeBuilder, Seq[Type], SType, SValue, SValue, Value[Int]) => SValue + ): Unit = registerSCode(name, Array(mt1, mt2), rt, unwrappedApply(pt), typeParameters = typeParams) { - case (r, cb, typeParams, rt, Array(a1, a2), errorID) => impl(r, cb, typeParams, rt, a1, a2, errorID) + case (r, cb, typeParams, rt, Array(a1, a2), errorID) => + impl(r, cb, typeParams, rt, a1, a2, errorID) } - def registerSCode3(name: String, mt1: Type, mt2: Type, mt3: Type, rt: Type, pt: (Type, SType, SType, SType) => SType) - (impl: (EmitRegion, EmitCodeBuilder, SType, SValue, SValue, SValue, Value[Int]) => SValue): Unit = + def registerSCode3( + name: String, + mt1: Type, + mt2: Type, + mt3: Type, + rt: Type, + pt: (Type, SType, SType, SType) => SType, + )( + impl: (EmitRegion, EmitCodeBuilder, SType, SValue, SValue, SValue, Value[Int]) => SValue + ): Unit = registerSCode(name, Array(mt1, mt2, mt3), rt, unwrappedApply(pt)) { case (r, cb, _, rt, Array(a1, a2, a3), errorID) => impl(r, cb, rt, a1, a2, a3, errorID) } - def registerSCode4(name: String, mt1: Type, mt2: Type, mt3: Type, mt4: Type, rt: Type, pt: (Type, SType, SType, SType, SType) => SType) - (impl: (EmitRegion, EmitCodeBuilder, SType, SValue, SValue, SValue, SValue, Value[Int]) => SValue): Unit = + def registerSCode4( + name: String, + mt1: Type, + mt2: Type, + mt3: Type, + mt4: Type, + rt: Type, + pt: (Type, SType, SType, SType, SType) => SType, + )( + impl: (EmitRegion, EmitCodeBuilder, SType, SValue, SValue, SValue, SValue, Value[Int]) => SValue + ): Unit = registerSCode(name, Array(mt1, mt2, mt3, mt4), rt, unwrappedApply(pt)) { - case (r, cb, _, rt, Array(a1, a2, a3, a4), errorID) => impl(r, cb, rt, a1, a2, a3, a4, errorID) + case (r, cb, _, rt, Array(a1, a2, a3, a4), errorID) => + impl(r, cb, rt, a1, a2, a3, a4, errorID) } - def registerSCode4t(name: String, typeParams: Array[Type], mt1: Type, mt2: Type, mt3: Type, mt4: Type, rt: Type, - pt: (Type, SType, SType, SType, SType) => SType) - (impl: (EmitRegion, EmitCodeBuilder, Seq[Type], SType, SValue, SValue, SValue, SValue, Value[Int]) => SValue): Unit = + def registerSCode4t( + name: String, + typeParams: Array[Type], + mt1: Type, + mt2: Type, + mt3: Type, + mt4: Type, + rt: Type, + pt: (Type, SType, SType, SType, SType) => SType, + )( + impl: ( + EmitRegion, + EmitCodeBuilder, + Seq[Type], + SType, + SValue, + SValue, + SValue, + SValue, + Value[Int], + ) => SValue + ): Unit = registerSCode(name, Array(mt1, mt2, mt3, mt4), rt, unwrappedApply(pt), typeParams) { - case (r, cb, 
typeParams, rt, Array(a1, a2, a3, a4), errorID) => impl(r, cb, typeParams, rt, a1, a2, a3, a4, errorID) + case (r, cb, typeParams, rt, Array(a1, a2, a3, a4), errorID) => + impl(r, cb, typeParams, rt, a1, a2, a3, a4, errorID) } - - def registerSCode5(name: String, mt1: Type, mt2: Type, mt3: Type, mt4: Type, mt5: Type, rt: Type, pt: (Type, SType, SType, SType, SType, SType) => SType) - (impl: (EmitRegion, EmitCodeBuilder, SType, SValue, SValue, SValue, SValue, SValue, Value[Int]) => SValue): Unit = + def registerSCode5( + name: String, + mt1: Type, + mt2: Type, + mt3: Type, + mt4: Type, + mt5: Type, + rt: Type, + pt: (Type, SType, SType, SType, SType, SType) => SType, + )( + impl: ( + EmitRegion, + EmitCodeBuilder, + SType, + SValue, + SValue, + SValue, + SValue, + SValue, + Value[Int], + ) => SValue + ): Unit = registerSCode(name, Array(mt1, mt2, mt3, mt4, mt5), rt, unwrappedApply(pt)) { - case (r, cb, _, rt, Array(a1, a2, a3, a4, a5), errorID) => impl(r, cb, rt, a1, a2, a3, a4, a5, errorID) + case (r, cb, _, rt, Array(a1, a2, a3, a4, a5), errorID) => + impl(r, cb, rt, a1, a2, a3, a4, a5, errorID) } - def registerSCode6(name: String, mt1: Type, mt2: Type, mt3: Type, mt4: Type, mt5: Type, mt6: Type, rt: Type, pt: (Type, SType, SType, SType, SType, SType, SType) => SType) - (impl: (EmitRegion, EmitCodeBuilder, SType, SValue, SValue, SValue, SValue, SValue, SValue, Value[Int]) => SValue): Unit = + def registerSCode6( + name: String, + mt1: Type, + mt2: Type, + mt3: Type, + mt4: Type, + mt5: Type, + mt6: Type, + rt: Type, + pt: (Type, SType, SType, SType, SType, SType, SType) => SType, + )( + impl: ( + EmitRegion, + EmitCodeBuilder, + SType, + SValue, + SValue, + SValue, + SValue, + SValue, + SValue, + Value[Int], + ) => SValue + ): Unit = registerSCode(name, Array(mt1, mt2, mt3, mt4, mt5, mt6), rt, unwrappedApply(pt)) { - case (r, cb, _, rt, Array(a1, a2, a3, a4, a5, a6), errorID) => impl(r, cb, rt, a1, a2, a3, a4, a5, a6, errorID) + case (r, cb, _, rt, Array(a1, a2, a3, a4, a5, a6), errorID) => + impl(r, cb, rt, a1, a2, a3, a4, a5, a6, errorID) } - def registerSCode7(name: String, mt1: Type, mt2: Type, mt3: Type, mt4: Type, mt5: Type, mt6: Type, mt7: Type, rt: Type, pt: (Type, SType, SType, SType, SType, SType, SType, SType) => SType) - (impl: (EmitRegion, EmitCodeBuilder, SType, SValue, SValue, SValue, SValue, SValue, SValue, SValue, Value[Int]) => SValue): Unit = + def registerSCode7( + name: String, + mt1: Type, + mt2: Type, + mt3: Type, + mt4: Type, + mt5: Type, + mt6: Type, + mt7: Type, + rt: Type, + pt: (Type, SType, SType, SType, SType, SType, SType, SType) => SType, + )( + impl: ( + EmitRegion, + EmitCodeBuilder, + SType, + SValue, + SValue, + SValue, + SValue, + SValue, + SValue, + SValue, + Value[Int], + ) => SValue + ): Unit = registerSCode(name, Array(mt1, mt2, mt3, mt4, mt5, mt6, mt7), rt, unwrappedApply(pt)) { - case (r, cb, _, rt, Array(a1, a2, a3, a4, a5, a6, a7), errorID) => impl(r, cb, rt, a1, a2, a3, a4, a5, a6, a7, errorID) + case (r, cb, _, rt, Array(a1, a2, a3, a4, a5, a6, a7), errorID) => + impl(r, cb, rt, a1, a2, a3, a4, a5, a6, a7, errorID) } - def registerCode1(name: String, mt1: Type, rt: Type, pt: (Type, SType) => SType)(impl: (EmitCodeBuilder, EmitRegion, SType, SValue) => Value[_]): Unit = + def registerCode1( + name: String, + mt1: Type, + rt: Type, + pt: (Type, SType) => SType, + )( + impl: (EmitCodeBuilder, EmitRegion, SType, SValue) => Value[_] + ): Unit = registerCode(name, Array(mt1), rt, unwrappedApply(pt)) { case (r, cb, rt, _, Array(a1)) => impl(cb, r, 
rt, a1) } - def registerCode2(name: String, mt1: Type, mt2: Type, rt: Type, pt: (Type, SType, SType) => SType) - (impl: (EmitCodeBuilder, EmitRegion, SType, SValue, SValue) => Value[_]): Unit = + def registerCode2( + name: String, + mt1: Type, + mt2: Type, + rt: Type, + pt: (Type, SType, SType) => SType, + )( + impl: (EmitCodeBuilder, EmitRegion, SType, SValue, SValue) => Value[_] + ): Unit = registerCode(name, Array(mt1, mt2), rt, unwrappedApply(pt)) { case (r, cb, rt, _, Array(a1, a2)) => impl(cb, r, rt, a1, a2) } - def registerIEmitCode1(name: String, mt1: Type, rt: Type, pt: (Type, EmitType) => EmitType) - (impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], EmitCode) => IEmitCode): Unit = - registerIEmitCode(name, Array(mt1), rt, unwrappedApply(pt)) { case (cb, r, rt, errorID, Array(a1)) => - impl(cb, r, rt, errorID, a1) + def registerIEmitCode1( + name: String, + mt1: Type, + rt: Type, + pt: (Type, EmitType) => EmitType, + )( + impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], EmitCode) => IEmitCode + ): Unit = + registerIEmitCode(name, Array(mt1), rt, unwrappedApply(pt)) { + case (cb, r, rt, errorID, Array(a1)) => + impl(cb, r, rt, errorID, a1) } - def registerIEmitCode2(name: String, mt1: Type, mt2: Type, rt: Type, pt: (Type, EmitType, EmitType) => EmitType) - (impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], EmitCode, EmitCode) => IEmitCode): Unit = - registerIEmitCode(name, Array(mt1, mt2), rt, unwrappedApply(pt)) { case (cb, r, rt, errorID, Array(a1, a2)) => - impl(cb, r, rt, errorID, a1, a2) + def registerIEmitCode2( + name: String, + mt1: Type, + mt2: Type, + rt: Type, + pt: (Type, EmitType, EmitType) => EmitType, + )( + impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], EmitCode, EmitCode) => IEmitCode + ): Unit = + registerIEmitCode(name, Array(mt1, mt2), rt, unwrappedApply(pt)) { + case (cb, r, rt, errorID, Array(a1, a2)) => + impl(cb, r, rt, errorID, a1, a2) } - def registerIEmitCode3(name: String, mt1: Type, mt2: Type, mt3: Type, rt: Type, pt: (Type, EmitType, EmitType, EmitType) => EmitType) - (impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], EmitCode, EmitCode, EmitCode) => IEmitCode): Unit = - registerIEmitCode(name, Array(mt1, mt2, mt3), rt, unwrappedApply(pt)) { case (cb, r, rt, errorID, Array(a1, a2, a3)) => - impl(cb, r, rt, errorID, a1, a2, a3) + def registerIEmitCode3( + name: String, + mt1: Type, + mt2: Type, + mt3: Type, + rt: Type, + pt: (Type, EmitType, EmitType, EmitType) => EmitType, + )( + impl: ( + EmitCodeBuilder, + Value[Region], + SType, + Value[Int], + EmitCode, + EmitCode, + EmitCode, + ) => IEmitCode + ): Unit = + registerIEmitCode(name, Array(mt1, mt2, mt3), rt, unwrappedApply(pt)) { + case (cb, r, rt, errorID, Array(a1, a2, a3)) => + impl(cb, r, rt, errorID, a1, a2, a3) } - def registerIEmitCode4(name: String, mt1: Type, mt2: Type, mt3: Type, mt4: Type, rt: Type, pt: (Type, EmitType, EmitType, EmitType, EmitType) => EmitType) - (impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], EmitCode, EmitCode, EmitCode, EmitCode) => IEmitCode): Unit = - registerIEmitCode(name, Array(mt1, mt2, mt3, mt4), rt, unwrappedApply(pt)) { case (cb, r, rt, errorID, Array(a1, a2, a3, a4)) => - impl(cb, r, rt, errorID, a1, a2, a3, a4) + def registerIEmitCode4( + name: String, + mt1: Type, + mt2: Type, + mt3: Type, + mt4: Type, + rt: Type, + pt: (Type, EmitType, EmitType, EmitType, EmitType) => EmitType, + )( + impl: ( + EmitCodeBuilder, + Value[Region], + SType, + Value[Int], + EmitCode, + EmitCode, + EmitCode, + EmitCode, + ) 
=> IEmitCode + ): Unit = + registerIEmitCode(name, Array(mt1, mt2, mt3, mt4), rt, unwrappedApply(pt)) { + case (cb, r, rt, errorID, Array(a1, a2, a3, a4)) => + impl(cb, r, rt, errorID, a1, a2, a3, a4) } - def registerIEmitCode5(name: String, mt1: Type, mt2: Type, mt3: Type, mt4: Type, mt5: Type, rt: Type, pt: (Type, EmitType, EmitType, EmitType, EmitType, EmitType) => EmitType) - (impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], EmitCode, EmitCode, EmitCode, EmitCode, EmitCode) => IEmitCode): Unit = - registerIEmitCode(name, Array(mt1, mt2, mt3, mt4, mt5), rt, unwrappedApply(pt)) { case (cb, r, rt, errorID, Array(a1, a2, a3, a4, a5)) => - impl(cb, r, rt, errorID, a1, a2, a3, a4, a5) + def registerIEmitCode5( + name: String, + mt1: Type, + mt2: Type, + mt3: Type, + mt4: Type, + mt5: Type, + rt: Type, + pt: (Type, EmitType, EmitType, EmitType, EmitType, EmitType) => EmitType, + )( + impl: ( + EmitCodeBuilder, + Value[Region], + SType, + Value[Int], + EmitCode, + EmitCode, + EmitCode, + EmitCode, + EmitCode, + ) => IEmitCode + ): Unit = + registerIEmitCode(name, Array(mt1, mt2, mt3, mt4, mt5), rt, unwrappedApply(pt)) { + case (cb, r, rt, errorID, Array(a1, a2, a3, a4, a5)) => + impl(cb, r, rt, errorID, a1, a2, a3, a4, a5) } - def registerIEmitCode6(name: String, mt1: Type, mt2: Type, mt3: Type, mt4: Type, mt5: Type, mt6: Type, rt: Type, pt: (Type, EmitType, EmitType, EmitType, EmitType, EmitType, EmitType) => EmitType) - (impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], EmitCode, EmitCode, EmitCode, EmitCode, EmitCode, EmitCode) => IEmitCode): Unit = - registerIEmitCode(name, Array(mt1, mt2, mt3, mt4, mt5, mt6), rt, unwrappedApply(pt)) { case (cb, r, rt, errorID, Array(a1, a2, a3, a4, a5, a6)) => - impl(cb, r, rt, errorID, a1, a2, a3, a4, a5, a6) + def registerIEmitCode6( + name: String, + mt1: Type, + mt2: Type, + mt3: Type, + mt4: Type, + mt5: Type, + mt6: Type, + rt: Type, + pt: (Type, EmitType, EmitType, EmitType, EmitType, EmitType, EmitType) => EmitType, + )( + impl: ( + EmitCodeBuilder, + Value[Region], + SType, + Value[Int], + EmitCode, + EmitCode, + EmitCode, + EmitCode, + EmitCode, + EmitCode, + ) => IEmitCode + ): Unit = + registerIEmitCode(name, Array(mt1, mt2, mt3, mt4, mt5, mt6), rt, unwrappedApply(pt)) { + case (cb, r, rt, errorID, Array(a1, a2, a3, a4, a5, a6)) => + impl(cb, r, rt, errorID, a1, a2, a3, a4, a5, a6) } - def registerEmitCode2(name: String, mt1: Type, mt2: Type, rt: Type, pt: (Type, EmitType, EmitType) => EmitType) - (impl: (EmitRegion, SType, Value[Int], EmitCode, EmitCode) => EmitCode): Unit = - registerEmitCode(name, Array(mt1, mt2), rt, unwrappedApply(pt)) { case (r, rt, errorID, Array(a1, a2)) => impl(r, rt, errorID, a1, a2) } + def registerEmitCode2( + name: String, + mt1: Type, + mt2: Type, + rt: Type, + pt: (Type, EmitType, EmitType) => EmitType, + )( + impl: (EmitRegion, SType, Value[Int], EmitCode, EmitCode) => EmitCode + ): Unit = + registerEmitCode(name, Array(mt1, mt2), rt, unwrappedApply(pt)) { + case (r, rt, errorID, Array(a1, a2)) => impl(r, rt, errorID, a1, a2) + } - def registerIR1(name: String, mt1: Type, returnType: Type, typeParameters: Array[Type] = Array.empty)(f: (Seq[Type], IR, Int) => IR): Unit = - registerIR(name, Array(mt1), returnType, typeParameters = typeParameters) { case (t, Seq(a1), errorID) => f(t, a1, errorID) } + def registerIR1( + name: String, + mt1: Type, + returnType: Type, + typeParameters: Array[Type] = Array.empty, + )( + f: (Seq[Type], IR, Int) => IR + ): Unit = + registerIR(name, Array(mt1), 
returnType, typeParameters = typeParameters) { + case (t, Seq(a1), errorID) => f(t, a1, errorID) + } - def registerIR2(name: String, mt1: Type, mt2: Type, returnType: Type, typeParameters: Array[Type] = Array.empty)(f: (Seq[Type], IR, IR, Int) => IR): Unit = - registerIR(name, Array(mt1, mt2), returnType, typeParameters = typeParameters) { case (t, Seq(a1, a2), errorID) => f(t, a1, a2, errorID) } + def registerIR2( + name: String, + mt1: Type, + mt2: Type, + returnType: Type, + typeParameters: Array[Type] = Array.empty, + )( + f: (Seq[Type], IR, IR, Int) => IR + ): Unit = + registerIR(name, Array(mt1, mt2), returnType, typeParameters = typeParameters) { + case (t, Seq(a1, a2), errorID) => f(t, a1, a2, errorID) + } - def registerIR3(name: String, mt1: Type, mt2: Type, mt3: Type, returnType: Type, typeParameters: Array[Type] = Array.empty)(f: (Seq[Type], IR, IR, IR, Int) => IR): Unit = - registerIR(name, Array(mt1, mt2, mt3), returnType, typeParameters = typeParameters) { case (t, Seq(a1, a2, a3), errorID) => f(t, a1, a2, a3, errorID) } + def registerIR3( + name: String, + mt1: Type, + mt2: Type, + mt3: Type, + returnType: Type, + typeParameters: Array[Type] = Array.empty, + )( + f: (Seq[Type], IR, IR, IR, Int) => IR + ): Unit = + registerIR(name, Array(mt1, mt2, mt3), returnType, typeParameters = typeParameters) { + case (t, Seq(a1, a2, a3), errorID) => f(t, a1, a2, a3, errorID) + } - def registerIR4(name: String, mt1: Type, mt2: Type, mt3: Type, mt4: Type, returnType: Type, typeParameters: Array[Type] = Array.empty)(f: (Seq[Type], IR, IR, IR, IR, Int) => IR): Unit = - registerIR(name, Array(mt1, mt2, mt3, mt4), returnType, typeParameters = typeParameters) { case (t, Seq(a1, a2, a3, a4), errorID) => f(t, a1, a2, a3, a4, errorID) } + def registerIR4( + name: String, + mt1: Type, + mt2: Type, + mt3: Type, + mt4: Type, + returnType: Type, + typeParameters: Array[Type] = Array.empty, + )( + f: (Seq[Type], IR, IR, IR, IR, Int) => IR + ): Unit = + registerIR(name, Array(mt1, mt2, mt3, mt4), returnType, typeParameters = typeParameters) { + case (t, Seq(a1, a2, a3, a4), errorID) => f(t, a1, a2, a3, a4, errorID) + } } sealed abstract class JVMFunction { @@ -691,11 +1176,19 @@ sealed abstract class JVMFunction { def computeReturnEmitType(returnType: Type, valueParameterTypes: Seq[EmitType]): EmitType - def apply(mb: EmitRegion, returnType: SType, typeParameters: Seq[Type], errorID: Value[Int], args: EmitCode*): EmitCode + def apply( + mb: EmitRegion, + returnType: SType, + typeParameters: Seq[Type], + errorID: Value[Int], + args: EmitCode* + ): EmitCode - override def toString: String = s"$name[${ typeParameters.mkString(", ") }](${ valueParameterTypes.mkString(", ") }): $returnType" + override def toString: String = + s"$name[${typeParameters.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType" - def unify(typeArguments: Seq[Type], valueArgumentTypes: Seq[Type], returnTypeIn: Type): Boolean = { + def unify(typeArguments: Seq[Type], valueArgumentTypes: Seq[Type], returnTypeIn: Type) + : Boolean = { val concrete = (typeArguments ++ valueArgumentTypes) :+ returnTypeIn val types = (typeParameters ++ valueParameterTypes) :+ returnType types.length == concrete.length && { @@ -706,78 +1199,127 @@ sealed abstract class JVMFunction { } object MissingnessObliviousJVMFunction { - def returnSType(computeStrictReturnEmitType: (Type, Seq[SType]) => SType)(returnType: Type, valueParameterTypes: Seq[SType]): SType = { + def returnSType( + computeStrictReturnEmitType: (Type, Seq[SType]) => SType + 
)( + returnType: Type, + valueParameterTypes: Seq[SType], + ): SType = if (computeStrictReturnEmitType == null) SType.canonical(returnType) else computeStrictReturnEmitType(returnType, valueParameterTypes) - } } -abstract class UnseededMissingnessObliviousJVMFunction ( +abstract class UnseededMissingnessObliviousJVMFunction( override val name: String, override val typeParameters: Seq[Type], override val valueParameterTypes: Seq[Type], override val returnType: Type, - missingnessObliviousComputeReturnType: (Type, Seq[SType]) => SType + missingnessObliviousComputeReturnType: (Type, Seq[SType]) => SType, ) extends JVMFunction { - override def computeReturnEmitType(returnType: Type, valueParameterTypes: Seq[EmitType]): EmitType = { - EmitType(computeStrictReturnEmitType(returnType, valueParameterTypes.map(_.st)), valueParameterTypes.forall(_.required)) - } + override def computeReturnEmitType(returnType: Type, valueParameterTypes: Seq[EmitType]) + : EmitType = + EmitType( + computeStrictReturnEmitType(returnType, valueParameterTypes.map(_.st)), + valueParameterTypes.forall(_.required), + ) + def computeStrictReturnEmitType(returnType: Type, valueParameterTypes: Seq[SType]): SType = - MissingnessObliviousJVMFunction.returnSType(missingnessObliviousComputeReturnType)(returnType, valueParameterTypes) + MissingnessObliviousJVMFunction.returnSType(missingnessObliviousComputeReturnType)( + returnType, + valueParameterTypes, + ) - def apply(r: EmitRegion, cb: EmitCodeBuilder, returnSType: SType, typeParameters: Seq[Type], errorID: Value[Int], args: SValue*): SValue + def apply( + r: EmitRegion, + cb: EmitCodeBuilder, + returnSType: SType, + typeParameters: Seq[Type], + errorID: Value[Int], + args: SValue* + ): SValue - def apply(r: EmitRegion, returnType: SType, typeParameters: Seq[Type], errorID: Value[Int], args: EmitCode*): EmitCode = { - EmitCode.fromI(r.mb)(cb => IEmitCode.multiMapEmitCodes(cb, args.toFastSeq) { args => - apply(r, cb, returnType, typeParameters, errorID, args: _*) - }) - } + def apply( + r: EmitRegion, + returnType: SType, + typeParameters: Seq[Type], + errorID: Value[Int], + args: EmitCode* + ): EmitCode = + EmitCode.fromI(r.mb)(cb => + IEmitCode.multiMapEmitCodes(cb, args.toFastSeq) { args => + apply(r, cb, returnType, typeParameters, errorID, args: _*) + } + ) - def applyI(r: EmitRegion, cb: EmitCodeBuilder, returnType: SType, typeParameters: Seq[Type], errorID: Value[Int], args: EmitCode*): IEmitCode = { + def applyI( + r: EmitRegion, + cb: EmitCodeBuilder, + returnType: SType, + typeParameters: Seq[Type], + errorID: Value[Int], + args: EmitCode* + ): IEmitCode = IEmitCode.multiMapEmitCodes(cb, args.toFastSeq) { args => apply(r, cb, returnType, typeParameters, errorID, args: _*) } - } - def getAsMethod[C](cb: EmitClassBuilder[C], rpt: SType, typeParameters: Seq[Type], args: SType*): EmitMethodBuilder[C] = { + def getAsMethod[C](cb: EmitClassBuilder[C], rpt: SType, typeParameters: Seq[Type], args: SType*) + : EmitMethodBuilder[C] = { val unified = unify(typeParameters, args.map(_.virtualType), rpt.virtualType) assert(unified, name) - val methodbuilder = cb.genEmitMethod(name, FastSeq[ParamType](typeInfo[Region], typeInfo[Int]) ++ args.map(_.paramType), rpt.paramType) - methodbuilder.emitSCode(cb => apply(EmitRegion.default(methodbuilder), - cb, - rpt, - typeParameters, - methodbuilder.getCodeParam[Int](2), - (0 until args.length).map(i => methodbuilder.getSCodeParam(i + 3)): _*)) + val methodbuilder = cb.genEmitMethod( + name, + FastSeq[ParamType](typeInfo[Region], 
typeInfo[Int]) ++ args.map(_.paramType), + rpt.paramType, + ) + methodbuilder.emitSCode(cb => + apply( + EmitRegion.default(methodbuilder), + cb, + rpt, + typeParameters, + methodbuilder.getCodeParam[Int](2), + (0 until args.length).map(i => methodbuilder.getSCodeParam(i + 3)): _* + ) + ) methodbuilder } } object MissingnessAwareJVMFunction { - def returnSType(calculateReturnType: (Type, Seq[EmitType]) => EmitType)(returnType: Type, valueParameterTypes: Seq[EmitType]): EmitType = + def returnSType( + calculateReturnType: (Type, Seq[EmitType]) => EmitType + )( + returnType: Type, + valueParameterTypes: Seq[EmitType], + ): EmitType = if (calculateReturnType == null) EmitType(SType.canonical(returnType), false) else calculateReturnType(returnType, valueParameterTypes) } -abstract class UnseededMissingnessAwareJVMFunction ( +abstract class UnseededMissingnessAwareJVMFunction( override val name: String, override val typeParameters: Seq[Type], override val valueParameterTypes: Seq[Type], override val returnType: Type, - missingnessAwareComputeReturnSType: (Type, Seq[EmitType]) => EmitType + missingnessAwareComputeReturnSType: (Type, Seq[EmitType]) => EmitType, ) extends JVMFunction { - override def computeReturnEmitType(returnType: Type, valueParameterTypes: Seq[EmitType]): EmitType = - MissingnessAwareJVMFunction.returnSType(missingnessAwareComputeReturnSType)(returnType, valueParameterTypes) + override def computeReturnEmitType(returnType: Type, valueParameterTypes: Seq[EmitType]) + : EmitType = + MissingnessAwareJVMFunction.returnSType(missingnessAwareComputeReturnSType)( + returnType, + valueParameterTypes, + ) - def apply(cb: EmitCodeBuilder, + def apply( + cb: EmitCodeBuilder, r: Value[Region], rpt: SType, typeParameters: Seq[Type], errorID: Value[Int], args: EmitCode* - ): IEmitCode = { + ): IEmitCode = ??? 
- } } diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/GenotypeFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/GenotypeFunctions.scala index a6110b17bc3..accf4d7862a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/GenotypeFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/GenotypeFunctions.scala @@ -1,43 +1,58 @@ package is.hail.expr.ir.functions import is.hail.asm4s.{coerce => _, _} +import is.hail.types.{tcoerce => _} +import is.hail.types.physical.stypes.{EmitType, SType} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives.{SFloat64, SInt32} -import is.hail.types.physical.stypes.{EmitType, SType} import is.hail.types.virtual.{TArray, TFloat64, TInt32, Type} -import is.hail.types.{tcoerce => _} object GenotypeFunctions extends RegistryFunctions { - def registerAll() { - registerSCode1("gqFromPL", TArray(tv("N", "int32")), TInt32, (_: Type, _: SType) => SInt32) - { case (r, cb, rt, pl: SIndexableValue, errorID) => - val m = cb.newLocal[Int]("m", 99) - val m2 = cb.newLocal[Int]("m2", 99) - val i = cb.newLocal[Int]("i", 0) + def registerAll(): Unit = { + registerSCode1("gqFromPL", TArray(tv("N", "int32")), TInt32, (_: Type, _: SType) => SInt32) { + case (_, cb, _, pl: SIndexableValue, errorID) => + val m = cb.newLocal[Int]("m", 99) + val m2 = cb.newLocal[Int]("m2", 99) + val i = cb.newLocal[Int]("i", 0) - cb.while_(i < pl.loadLength(), { - val value = pl.loadElement(cb, i).get(cb, "PL cannot have missing elements.", errorID) - val pli = cb.newLocal[Int]("pli", value.asInt.value) - cb.if_(pli < m, { - cb.assign(m2, m) - cb.assign(m, pli) - }, { - cb.if_(pli < m2, - cb.assign(m2, pli)) - }) - cb.assign(i, i + 1) - }) + cb.while_( + i < pl.loadLength(), { + val value = + pl.loadElement(cb, i).getOrFatal(cb, "PL cannot have missing elements.", errorID) + val pli = cb.newLocal[Int]("pli", value.asInt.value) + cb.if_( + pli < m, { + cb.assign(m2, m) + cb.assign(m, pli) + }, + cb.if_(pli < m2, cb.assign(m2, pli)), + ) + cb.assign(i, i + 1) + }, + ) - primitive(cb.memoize(m2 - m)) + primitive(cb.memoize(m2 - m)) } - registerIEmitCode1("dosage", TArray(tv("N", "float64")), TFloat64, - (_: Type, arrayType: EmitType) => EmitType(SFloat64, arrayType.required && arrayType.st.asInstanceOf[SContainer].elementEmitType.required) - ) { case (cb, r, rt, errorID, gp) => + registerIEmitCode1( + "dosage", + TArray(tv("N", "float64")), + TFloat64, + (_: Type, arrayType: EmitType) => + EmitType( + SFloat64, + arrayType.required && arrayType.st.asInstanceOf[SContainer].elementEmitType.required, + ), + ) { case (cb, _, _, errorID, gp) => gp.toI(cb).flatMap(cb) { case gpv: SIndexableValue => - cb.if_(gpv.loadLength().cne(3), - cb._fatalWithError(errorID, const("length of gp array must be 3, got ").concat(gpv.loadLength().toS))) + cb.if_( + gpv.loadLength().cne(3), + cb._fatalWithError( + errorID, + const("length of gp array must be 3, got ").concat(gpv.loadLength().toS), + ), + ) gpv.loadElement(cb, 1).flatMap(cb) { _1 => gpv.loadElement(cb, 2).map(cb) { _2 => diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/GetElement.scala b/hail/src/main/scala/is/hail/expr/ir/functions/GetElement.scala index a0a9becf164..94308e31e97 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/GetElement.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/GetElement.scala @@ -1,14 +1,15 @@ package is.hail.expr.ir.functions import is.hail.backend.ExecuteContext +import is.hail.linalg.BlockMatrix import 
is.hail.types.BlockMatrixType import is.hail.types.virtual.Type -import is.hail.linalg.BlockMatrix case class GetElement(index: IndexedSeq[Long]) extends BlockMatrixToValueFunction { assert(index.length == 2) override def typ(childType: BlockMatrixType): Type = childType.elementType - override def execute(ctx: ExecuteContext, bm: BlockMatrix): Any = bm.getElement(index(0), index(1)) + override def execute(ctx: ExecuteContext, bm: BlockMatrix): Any = + bm.getElement(index(0), index(1)) } diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/IntervalFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/IntervalFunctions.scala index c6091554af2..30961ef9a0f 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/IntervalFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/IntervalFunctions.scala @@ -4,108 +4,129 @@ import is.hail.asm4s._ import is.hail.expr.ir._ import is.hail.expr.ir.orderings.CodeOrdering import is.hail.types.physical._ +import is.hail.types.physical.stypes.{EmitType, SType, SValue} import is.hail.types.physical.stypes.concrete.{SIntervalPointer, SStackStruct, SStackStructValue} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives.{SBoolean, SBooleanValue, SInt32} -import is.hail.types.physical.stypes.{EmitType, SType, SValue} import is.hail.types.virtual._ import is.hail.utils.FastSeq object IntervalFunctions extends RegistryFunctions { - def pointLTIntervalEndpoint(cb: EmitCodeBuilder, - point: SValue, endpoint: SValue, leansRight: Code[Boolean] + def pointLTIntervalEndpoint( + cb: EmitCodeBuilder, + point: SValue, + endpoint: SValue, + leansRight: Code[Boolean], ): Code[Boolean] = { val ord = cb.emb.ecb.getOrdering(point.st, endpoint.st) val result = ord.compareNonnull(cb, point, endpoint) (result < 0) || (result.ceq(0) && leansRight) } - def pointGTIntervalEndpoint(cb: EmitCodeBuilder, - point: SValue, endpoint: SValue, leansRight: Code[Boolean] + def pointGTIntervalEndpoint( + cb: EmitCodeBuilder, + point: SValue, + endpoint: SValue, + leansRight: Code[Boolean], ): Code[Boolean] = { val ord = cb.emb.ecb.getOrdering(point.st, endpoint.st) val result = ord.compareNonnull(cb, point, endpoint) (result > 0) || (result.ceq(0) && !leansRight) } - def intervalEndpointCompare(cb: EmitCodeBuilder, - lhs: SValue, lhsLeansRight: Code[Boolean], - rhs: SValue, rhsLeansRight: Code[Boolean] + def intervalEndpointCompare( + cb: EmitCodeBuilder, + lhs: SValue, + lhsLeansRight: Code[Boolean], + rhs: SValue, + rhsLeansRight: Code[Boolean], ): Value[Int] = { val ord = cb.emb.ecb.getOrdering(lhs.st, rhs.st) val result = cb.newLocal[Int]("intervalEndpointCompare") cb.assign(result, ord.compareNonnull(cb, lhs, rhs)) - cb.if_(result.ceq(0), - cb.assign(result, lhsLeansRight.toI - rhsLeansRight.toI)) + cb.if_(result.ceq(0), cb.assign(result, lhsLeansRight.toI - rhsLeansRight.toI)) result } - def pointIntervalCompare(cb: EmitCodeBuilder, point: SValue, interval: SIntervalValue): IEmitCode = { + def pointIntervalCompare(cb: EmitCodeBuilder, point: SValue, interval: SIntervalValue) + : IEmitCode = { interval.loadStart(cb).flatMap(cb) { start => - cb.if_(pointLTIntervalEndpoint(cb, point, start, !interval.includesStart), { - IEmitCode.present(cb, primitive(const(-1))) - }, { - interval.loadEnd(cb).map(cb) { end => - cb.if_(pointLTIntervalEndpoint(cb, point, end, interval.includesEnd), { - primitive(const(0)) - }, { - primitive(const(1)) - }) - } - }) + cb.if_( + pointLTIntervalEndpoint(cb, point, start, 
!interval.includesStart), + IEmitCode.present(cb, primitive(const(-1))), { + interval.loadEnd(cb).map(cb) { end => + cb.if_( + pointLTIntervalEndpoint(cb, point, end, interval.includesEnd), + primitive(const(0)), + primitive(const(1)), + ) + } + }, + ) } } - def intervalPointCompare(cb: EmitCodeBuilder, interval: SIntervalValue, point: SValue): IEmitCode = { + def intervalPointCompare(cb: EmitCodeBuilder, interval: SIntervalValue, point: SValue) + : IEmitCode = { interval.loadStart(cb).flatMap(cb) { start => - cb.if_(pointLTIntervalEndpoint(cb, point, start, !interval.includesStart), { - IEmitCode.present(cb, primitive(const(1))) - }, { - interval.loadEnd(cb).map(cb) { end => - cb.if_(pointLTIntervalEndpoint(cb, point, end, interval.includesEnd), { - primitive(const(0)) - }, { - primitive(const(-1)) - }) - } - }) + cb.if_( + pointLTIntervalEndpoint(cb, point, start, !interval.includesStart), + IEmitCode.present(cb, primitive(const(1))), { + interval.loadEnd(cb).map(cb) { end => + cb.if_( + pointLTIntervalEndpoint(cb, point, end, interval.includesEnd), + primitive(const(0)), + primitive(const(-1)), + ) + } + }, + ) } } def intervalContains(cb: EmitCodeBuilder, interval: SIntervalValue, point: SValue): IEmitCode = { interval.loadStart(cb).flatMap(cb) { start => - cb.if_(pointGTIntervalEndpoint(cb, point, start, !interval.includesStart), + cb.if_( + pointGTIntervalEndpoint(cb, point, start, !interval.includesStart), interval.loadEnd(cb).map(cb) { end => primitive(cb.memoize(pointLTIntervalEndpoint(cb, point, end, interval.includesEnd))) }, - IEmitCode.present(cb, primitive(false))) + IEmitCode.present(cb, primitive(false)), + ) } } def intervalsOverlap(cb: EmitCodeBuilder, lhs: SIntervalValue, rhs: SIntervalValue): IEmitCode = { - IEmitCode.multiFlatMap(cb, - FastSeq(lhs.loadEnd, rhs.loadStart) - ) { case Seq(lEnd, rStart) => - cb.if_(intervalEndpointCompare(cb, lEnd, lhs.includesEnd, rStart, !rhs.includesStart) > 0, { - IEmitCode.multiMap(cb, - FastSeq(lhs.loadStart, rhs.loadEnd) - ) { case Seq(lStart, rEnd) => - primitive(cb.memoize(intervalEndpointCompare(cb, rEnd, rhs.includesEnd, lStart, !lhs.includesStart) > 0)) - } - }, { - IEmitCode.present(cb, primitive(const(false))) - }) + IEmitCode.multiFlatMap(cb, FastSeq(lhs.loadEnd, rhs.loadStart)) { case Seq(lEnd, rStart) => + cb.if_( + intervalEndpointCompare(cb, lEnd, lhs.includesEnd, rStart, !rhs.includesStart) > 0, + IEmitCode.multiMap(cb, FastSeq(lhs.loadStart, rhs.loadEnd)) { case Seq(lStart, rEnd) => + primitive(cb.memoize(intervalEndpointCompare( + cb, + rEnd, + rhs.includesEnd, + lStart, + !lhs.includesStart, + ) > 0)) + }, + IEmitCode.present(cb, primitive(const(false))), + ) } } - def _partitionIntervalEndpointCompare(cb: EmitCodeBuilder, - lStruct: SBaseStructValue, lLength: Value[Int], lSign: Value[Int], - rStruct: SBaseStructValue, rLength: Value[Int], rSign: Value[Int] + def _partitionIntervalEndpointCompare( + cb: EmitCodeBuilder, + lStruct: SBaseStructValue, + lLength: Value[Int], + lSign: Value[Int], + rStruct: SBaseStructValue, + rLength: Value[Int], + rSign: Value[Int], ): Value[Int] = { val structType = lStruct.st - assert(rStruct.st.virtualType.isIsomorphicTo(structType.virtualType)) + assert(rStruct.st.virtualType.isJoinableWith(structType.virtualType)) val prefixLength = cb.memoize(lLength.min(rLength)) val result = cb.newLocal[Int]("partitionIntervalEndpointCompare") @@ -115,9 +136,11 @@ object IntervalFunctions extends RegistryFunctions { (0 until (lStruct.st.size min rStruct.st.size)).foreach { idx => val lField = 
cb.memoize(lStruct.loadField(cb, idx)) val rField = cb.memoize(rStruct.loadField(cb, idx)) - cb.assign(result, + cb.assign( + result, cb.emb.ecb.getOrderingFunction(lField.st, rField.st, CodeOrdering.Compare()) - .apply(cb, lField, rField)) + .apply(cb, lField, rField), + ) cb.if_(result.cne(0), cb.goto(Lafter)) if (idx < (lStruct.st.size min rStruct.st.size)) { cb.if_(prefixLength.ceq(idx + 1), cb.goto(Leq)) @@ -134,80 +157,123 @@ object IntervalFunctions extends RegistryFunctions { result } - def partitionIntervalEndpointCompare(cb: EmitCodeBuilder, - lhs: SBaseStructValue, lSign: Value[Int], - rhs: SBaseStructValue, rSign: Value[Int] + def partitionIntervalEndpointCompare( + cb: EmitCodeBuilder, + lhs: SBaseStructValue, + lSign: Value[Int], + rhs: SBaseStructValue, + rSign: Value[Int], ): Value[Int] = { - val lStruct = lhs.loadField(cb, 0).get(cb).asBaseStruct - val lLength = lhs.loadField(cb, 1).get(cb).asInt.value - val rStruct = rhs.loadField(cb, 0).get(cb).asBaseStruct - val rLength = rhs.loadField(cb, 1).get(cb).asInt.value + val lStruct = lhs.loadField(cb, 0).getOrAssert(cb).asBaseStruct + val lLength = lhs.loadField(cb, 1).getOrAssert(cb).asInt.value + val rStruct = rhs.loadField(cb, 0).getOrAssert(cb).asBaseStruct + val rLength = rhs.loadField(cb, 1).getOrAssert(cb).asInt.value _partitionIntervalEndpointCompare(cb, lStruct, lLength, lSign, rStruct, rLength, rSign) } - def compareStructWithPartitionIntervalEndpoint(cb: EmitCodeBuilder, + def compareStructWithPartitionIntervalEndpoint( + cb: EmitCodeBuilder, point: SBaseStructValue, intervalEndpoint: SBaseStructValue, - leansRight: Code[Boolean] + leansRight: Code[Boolean], ): Value[Int] = { - val endpoint = intervalEndpoint.loadField(cb, 0).get(cb).asBaseStruct - val endpointLength = intervalEndpoint.loadField(cb, 1).get(cb).asInt.value + val endpoint = intervalEndpoint.loadField(cb, 0).getOrAssert(cb).asBaseStruct + val endpointLength = intervalEndpoint.loadField(cb, 1).getOrAssert(cb).asInt.value val sign = cb.memoize((leansRight.toI << 1) - 1) _partitionIntervalEndpointCompare(cb, point, point.st.size, 0, endpoint, endpointLength, sign) } - def compareStructWithPartitionInterval(cb: EmitCodeBuilder, + def compareStructWithPartitionInterval( + cb: EmitCodeBuilder, point: SBaseStructValue, - interval: SIntervalValue + interval: SIntervalValue, ): Value[Int] = { val start = interval.loadStart(cb) - .get(cb, "partition intervals cannot have missing endpoints") + .getOrFatal(cb, "partition intervals cannot have missing endpoints") .asBaseStruct - cb.if_(compareStructWithPartitionIntervalEndpoint(cb, point, start, !interval.includesStart) < 0, { - primitive(const(-1)) - }, { - val end = interval.loadEnd(cb) - .get(cb, "partition intervals cannot have missing endpoints") - .asBaseStruct - cb.if_(compareStructWithPartitionIntervalEndpoint(cb, point, end, interval.includesEnd) < 0, { - primitive(const(0)) - }, { - primitive(const(1)) - }) - }).asInt.value + cb.if_( + compareStructWithPartitionIntervalEndpoint(cb, point, start, !interval.includesStart) < 0, + primitive(const(-1)), { + val end = interval.loadEnd(cb) + .getOrFatal(cb, "partition intervals cannot have missing endpoints") + .asBaseStruct + cb.if_( + compareStructWithPartitionIntervalEndpoint(cb, point, end, interval.includesEnd) < 0, + primitive(const(0)), + primitive(const(1)), + ) + }, + ).asInt.value } - def partitionerFindIntervalRange(cb: EmitCodeBuilder, intervals: SIndexableValue, query: SIntervalValue, errorID: Value[Int]): (Value[Int], Value[Int]) = { + def 
partitionerFindIntervalRange( + cb: EmitCodeBuilder, + intervals: SIndexableValue, + query: SIntervalValue, + errorID: Value[Int], + ): (Value[Int], Value[Int]) = { val needleStart = query.loadStart(cb) - .get(cb, "partitionerFindIntervalRange assumes non-missing interval endpoints", errorID) + .getOrFatal( + cb, + "partitionerFindIntervalRange assumes non-missing interval endpoints", + errorID, + ) .asBaseStruct val needleEnd = query.loadEnd(cb) - .get(cb, "partitionerFindIntervalRange assumes non-missing interval endpoints", errorID) + .getOrFatal( + cb, + "partitionerFindIntervalRange assumes non-missing interval endpoints", + errorID, + ) .asBaseStruct def ltNeedle(interval: IEmitCode): Code[Boolean] = { val intervalVal = interval - .get(cb, "partitionerFindIntervalRange: partition intervals cannot be missing", errorID) + .getOrFatal( + cb, + "partitionerFindIntervalRange: partition intervals cannot be missing", + errorID, + ) .asInterval val intervalEnd = intervalVal.loadEnd(cb) - .get(cb, "partitionerFindIntervalRange assumes non-missing interval endpoints", errorID) + .getOrFatal( + cb, + "partitionerFindIntervalRange assumes non-missing interval endpoints", + errorID, + ) .asBaseStruct - val c = partitionIntervalEndpointCompare(cb, - intervalEnd, cb.memoize((intervalVal.includesEnd.toI << 1) - 1), - needleStart, cb.memoize(const(1) - (query.includesStart.toI << 1))) + val c = partitionIntervalEndpointCompare( + cb, + intervalEnd, + cb.memoize((intervalVal.includesEnd.toI << 1) - 1), + needleStart, + cb.memoize(const(1) - (query.includesStart.toI << 1)), + ) c <= 0 } def gtNeedle(interval: IEmitCode): Code[Boolean] = { val intervalVal = interval - .get(cb, "partitionerFindIntervalRange: partition intervals cannot be missing", errorID) + .getOrFatal( + cb, + "partitionerFindIntervalRange: partition intervals cannot be missing", + errorID, + ) .asInterval val intervalStart = intervalVal.loadStart(cb) - .get(cb, "partitionerFindIntervalRange assumes non-missing interval endpoints", errorID) + .getOrFatal( + cb, + "partitionerFindIntervalRange assumes non-missing interval endpoints", + errorID, + ) .asBaseStruct - val c = partitionIntervalEndpointCompare(cb, - intervalStart, cb.memoize(const(1) - (intervalVal.includesStart.toI << 1)), - needleEnd, cb.memoize((query.includesEnd.toI << 1) - 1)) + val c = partitionIntervalEndpointCompare( + cb, + intervalStart, + cb.memoize(const(1) - (intervalVal.includesStart.toI << 1)), + needleEnd, + cb.memoize((query.includesEnd.toI << 1) - 1), + ) c >= 0 } @@ -216,11 +282,14 @@ object IntervalFunctions extends RegistryFunctions { BinarySearch.equalRange(cb, intervals, compare, ltNeedle, gtNeedle, 0, intervals.loadLength()) } - def arrayOfStructFindIntervalRange(cb: EmitCodeBuilder, + def arrayOfStructFindIntervalRange( + cb: EmitCodeBuilder, array: SIndexableValue, - startKey: SBaseStructValue, startLeansRight: Value[Boolean], - endKey: SBaseStructValue, endLeansRight: Value[Boolean], - key: IEmitCode => IEmitCode + startKey: SBaseStructValue, + startLeansRight: Value[Boolean], + endKey: SBaseStructValue, + endLeansRight: Value[Boolean], + key: IEmitCode => IEmitCode, ): (Value[Int], Value[Int]) = { def ltNeedle(elt: IEmitCode): Code[Boolean] = { val eltKey = cb.memoize(key(elt)).get(cb).asBaseStruct @@ -238,95 +307,166 @@ object IntervalFunctions extends RegistryFunctions { } def registerAll(): Unit = { - registerIEmitCode4("Interval", tv("T"), tv("T"), TBoolean, TBoolean, TInterval(tv("T")), + registerIEmitCode4( + "Interval", + tv("T"), + tv("T"), 
+ TBoolean, + TBoolean, + TInterval(tv("T")), { case (_: Type, startpt, endpt, includesStartET, includesEndET) => - EmitType(PCanonicalInterval( - InferPType.getCompatiblePType(Seq(startpt.typeWithRequiredness.canonicalPType, endpt.typeWithRequiredness.canonicalPType)), - required = includesStartET.required && includesEndET.required - ).sType, includesStartET.required && includesEndET.required) - }) { - case (cb, r, SIntervalPointer(pt: PCanonicalInterval), _, start, end, includesStart, includesEnd) => - + EmitType( + PCanonicalInterval( + InferPType.getCompatiblePType(Seq( + startpt.typeWithRequiredness.canonicalPType, + endpt.typeWithRequiredness.canonicalPType, + )), + required = includesStartET.required && includesEndET.required, + ).sType, + includesStartET.required && includesEndET.required, + ) + }, + ) { + case ( + cb, + r, + SIntervalPointer(pt: PCanonicalInterval), + _, + start, + end, + includesStart, + includesEnd, + ) => includesStart.toI(cb).flatMap(cb) { includesStart => includesEnd.toI(cb).map(cb) { includesEnd => - - pt.constructFromCodes(cb, r, + pt.constructFromCodes( + cb, + r, start, end, includesStart.asBoolean.value, - includesEnd.asBoolean.value) + includesEnd.asBoolean.value, + ) } } } - registerIEmitCode1("start", TInterval(tv("T")), tv("T"), - (_: Type, x: EmitType) => EmitType(x.st.asInstanceOf[SInterval].pointType, x.required && x.st.asInstanceOf[SInterval].pointEmitType.required)) { - case (cb, r, rt, _, interval) => + registerIEmitCode1( + "start", + TInterval(tv("T")), + tv("T"), + (_: Type, x: EmitType) => + EmitType( + x.st.asInstanceOf[SInterval].pointType, + x.required && x.st.asInstanceOf[SInterval].pointEmitType.required, + ), + ) { + case (cb, _, _, _, interval) => interval.toI(cb).flatMap(cb) { case pv: SIntervalValue => pv.loadStart(cb) } } - registerIEmitCode1("end", TInterval(tv("T")), tv("T"), - (_: Type, x: EmitType) => EmitType(x.st.asInstanceOf[SInterval].pointType, x.required && x.st.asInstanceOf[SInterval].pointEmitType.required)) { - case (cb, r, rt, _, interval) => + registerIEmitCode1( + "end", + TInterval(tv("T")), + tv("T"), + (_: Type, x: EmitType) => + EmitType( + x.st.asInstanceOf[SInterval].pointType, + x.required && x.st.asInstanceOf[SInterval].pointEmitType.required, + ), + ) { + case (cb, _, _, _, interval) => interval.toI(cb).flatMap(cb) { case pv: SIntervalValue => pv.loadEnd(cb) } } - registerSCode1("includesStart", TInterval(tv("T")), TBoolean, (_: Type, x: SType) => - SBoolean + registerSCode1( + "includesStart", + TInterval(tv("T")), + TBoolean, + (_: Type, x: SType) => + SBoolean, ) { - case (r, cb, rt, interval: SIntervalValue, _) => primitive(interval.includesStart) + case (_, _, _, interval: SIntervalValue, _) => primitive(interval.includesStart) } - registerSCode1("includesEnd", TInterval(tv("T")), TBoolean, (_: Type, x: SType) => - SBoolean + registerSCode1( + "includesEnd", + TInterval(tv("T")), + TBoolean, + (_: Type, x: SType) => + SBoolean, ) { - case (r, cb, rt, interval: SIntervalValue, _) => primitive(interval.includesEnd) + case (_, _, _, interval: SIntervalValue, _) => primitive(interval.includesEnd) } - registerIEmitCode2("contains", TInterval(tv("T")), tv("T"), TBoolean, { - case(_: Type, intervalT: EmitType, pointT: EmitType) => - val intervalST = intervalT.st.asInstanceOf[SInterval] - val required = intervalT.required && intervalST.pointEmitType.required && pointT.required - EmitType(SBoolean, required) - }) { case (cb, r, rt, _, int, point) => - IEmitCode.multiFlatMap(cb, - FastSeq(int.toI, point.toI) - 
) { case Seq(interval: SIntervalValue, point) => - intervalContains(cb, interval, point) + registerIEmitCode2( + "contains", + TInterval(tv("T")), + tv("T"), + TBoolean, + { + case (_: Type, intervalT: EmitType, pointT: EmitType) => + val intervalST = intervalT.st.asInstanceOf[SInterval] + val required = intervalT.required && intervalST.pointEmitType.required && pointT.required + EmitType(SBoolean, required) + }, + ) { case (cb, _, _, _, int, point) => + IEmitCode.multiFlatMap(cb, FastSeq(int.toI, point.toI)) { + case Seq(interval: SIntervalValue, point) => + intervalContains(cb, interval, point) } } registerSCode1("isEmpty", TInterval(tv("T")), TBoolean, (_: Type, pt: SType) => SBoolean) { - case (r, cb, rt, interval: SIntervalValue, _) => + case (_, cb, _, interval: SIntervalValue, _) => primitive(interval.isEmpty(cb)) } - registerIEmitCode2("overlaps", TInterval(tv("T")), TInterval(tv("T")), TBoolean, { - (_: Type, i1t: EmitType, i2t: EmitType) => - val i1ST = i1t.st.asInstanceOf[SInterval] - val i2ST = i2t.st.asInstanceOf[SInterval] - val required = i1t.required && i2t.required && i1ST.pointEmitType.required && i2ST.pointEmitType.required - EmitType(SBoolean, required) - }) { case (cb, r, rt, _, interval1: EmitCode, interval2: EmitCode) => + registerIEmitCode2( + "overlaps", + TInterval(tv("T")), + TInterval(tv("T")), + TBoolean, + { + (_: Type, i1t: EmitType, i2t: EmitType) => + val i1ST = i1t.st.asInstanceOf[SInterval] + val i2ST = i2t.st.asInstanceOf[SInterval] + val required = + i1t.required && i2t.required && i1ST.pointEmitType.required && i2ST.pointEmitType.required + EmitType(SBoolean, required) + }, + ) { case (cb, _, _, _, interval1: EmitCode, interval2: EmitCode) => IEmitCode.multiFlatMap(cb, FastSeq(interval1.toI, interval2.toI)) { case Seq(interval1: SIntervalValue, interval2: SIntervalValue) => - intervalsOverlap(cb, interval1, interval2) + intervalsOverlap(cb, interval1, interval2) } } - registerSCode2("sortedNonOverlappingIntervalsContain", - TArray(TInterval(tv("T"))), tv("T"), TBoolean, (_, _, _) => SBoolean - ) { case (_, cb, rt, intervals, point, errorID) => + registerSCode2( + "sortedNonOverlappingIntervalsContain", + TArray(TInterval(tv("T"))), + tv("T"), + TBoolean, + (_, _, _) => SBoolean, + ) { case (_, cb, _, intervals, point, errorID) => val compare = BinarySearch.Comparator.fromCompare { intervalEC => val interval = intervalEC - .get(cb, "sortedNonOverlappingIntervalsContain assumes non-missing intervals", errorID) + .getOrFatal( + cb, + "sortedNonOverlappingIntervalsContain assumes non-missing intervals", + errorID, + ) .asInterval intervalPointCompare(cb, interval, point) - .get(cb, "sortedNonOverlappingIntervalsContain assumes non-missing interval endpoints", errorID) + .getOrFatal( + cb, + "sortedNonOverlappingIntervalsContain assumes non-missing interval endpoints", + errorID, + ) .asInt.value } @@ -335,32 +475,57 @@ object IntervalFunctions extends RegistryFunctions { val partitionEndpointType = TTuple(tv("T"), TInt32) val partitionIntervalType = TInterval(partitionEndpointType) - registerSCode2("partitionerContains", - TArray(partitionIntervalType), tv("T"), TBoolean, - (_, _, _) => SBoolean - ) { case (_, cb, rt, intervals: SIndexableValue, point: SBaseStructValue, errorID) => + registerSCode2( + "partitionerContains", + TArray(partitionIntervalType), + tv("T"), + TBoolean, + (_, _, _) => SBoolean, + ) { case (_, cb, _, intervals: SIndexableValue, point: SBaseStructValue, errorID) => def ltNeedle(interval: IEmitCode): Code[Boolean] = { val intervalVal 
= interval - .get(cb, "partitionerFindIntervalRange: partition intervals cannot be missing", errorID) + .getOrFatal( + cb, + "partitionerFindIntervalRange: partition intervals cannot be missing", + errorID, + ) .asInterval val intervalEnd = intervalVal.loadEnd(cb) - .get(cb, "partitionerFindIntervalRange assumes non-missing interval endpoints", errorID) + .getOrFatal( + cb, + "partitionerFindIntervalRange assumes non-missing interval endpoints", + errorID, + ) .asBaseStruct - val c = compareStructWithPartitionIntervalEndpoint(cb, + val c = compareStructWithPartitionIntervalEndpoint( + cb, point, - intervalEnd, intervalVal.includesEnd) + intervalEnd, + intervalVal.includesEnd, + ) c > 0 } def gtNeedle(interval: IEmitCode): Code[Boolean] = { val intervalVal = interval - .get(cb, "partitionerFindIntervalRange: partition intervals cannot be missing", errorID) + .getOrFatal( + cb, + "partitionerFindIntervalRange: partition intervals cannot be missing", + errorID, + ) .asInterval val intervalStart = intervalVal.loadStart(cb) - .get(cb, "partitionerFindIntervalRange assumes non-missing interval endpoints", errorID) + .getOrFatal( + cb, + "partitionerFindIntervalRange assumes non-missing interval endpoints", + errorID, + ) .asBaseStruct - val c = compareStructWithPartitionIntervalEndpoint(cb, + val c = compareStructWithPartitionIntervalEndpoint( + cb, point, - intervalStart, !intervalVal.includesStart) + intervalStart, + !intervalVal.includesStart, + ) c < 0 } primitive(BinarySearch.containsOrdered(cb, intervals, ltNeedle, gtNeedle)) @@ -368,51 +533,108 @@ object IntervalFunctions extends RegistryFunctions { val requiredInt = EmitType(SInt32, true) val equalRangeResultType = TTuple(TInt32, TInt32) - val equalRangeResultSType = SStackStruct(equalRangeResultType, FastSeq(requiredInt, requiredInt)) + val equalRangeResultSType = + SStackStruct(equalRangeResultType, FastSeq(requiredInt, requiredInt)) - registerSCode2("partitionerFindIntervalRange", - TArray(partitionIntervalType), partitionIntervalType, equalRangeResultType, - (_, _, _) => equalRangeResultSType - ) { case (_, cb, rt, intervals: SIndexableValue, query: SIntervalValue, errorID) => + registerSCode2( + "partitionerFindIntervalRange", + TArray(partitionIntervalType), + partitionIntervalType, + equalRangeResultType, + (_, _, _) => equalRangeResultSType, + ) { case (_, cb, _, intervals: SIndexableValue, query: SIntervalValue, errorID) => val (start, end) = partitionerFindIntervalRange(cb, intervals, query, errorID) - new SStackStructValue(equalRangeResultSType, + new SStackStructValue( + equalRangeResultSType, FastSeq( EmitValue.present(primitive(start)), - EmitValue.present(primitive(end)))) + EmitValue.present(primitive(end)), + ), + ) } val endpointT = TTuple(tv("T"), TInt32) - registerSCode3("pointLessThanPartitionIntervalLeftEndpoint", tv("T"), endpointT, TBoolean, TBoolean, (_, _, _, _) => SBoolean) { - case (_, cb, _, point: SBaseStructValue, leftPartitionEndpoint: SBaseStructValue, containsStart: SBooleanValue, _) => + registerSCode3( + "pointLessThanPartitionIntervalLeftEndpoint", + tv("T"), + endpointT, + TBoolean, + TBoolean, + (_, _, _, _) => SBoolean, + ) { + case ( + _, + cb, + _, + point: SBaseStructValue, + leftPartitionEndpoint: SBaseStructValue, + containsStart: SBooleanValue, + _, + ) => primitive(cb.memoize( - compareStructWithPartitionIntervalEndpoint(cb, point, leftPartitionEndpoint, !containsStart.value) < 0)) + compareStructWithPartitionIntervalEndpoint( + cb, + point, + leftPartitionEndpoint, + !containsStart.value, + 
) < 0 + )) } - registerSCode3("pointLessThanPartitionIntervalRightEndpoint", tv("T"), endpointT, TBoolean, TBoolean, (_, _, _, _) => SBoolean) { - case (_, cb, _, point: SBaseStructValue, rightPartitionEndpoint: SBaseStructValue, containsEnd: SBooleanValue, _) => + registerSCode3( + "pointLessThanPartitionIntervalRightEndpoint", + tv("T"), + endpointT, + TBoolean, + TBoolean, + (_, _, _, _) => SBoolean, + ) { + case ( + _, + cb, + _, + point: SBaseStructValue, + rightPartitionEndpoint: SBaseStructValue, + containsEnd: SBooleanValue, + _, + ) => primitive(cb.memoize( - compareStructWithPartitionIntervalEndpoint(cb, point, rightPartitionEndpoint, containsEnd.value) < 0)) + compareStructWithPartitionIntervalEndpoint( + cb, + point, + rightPartitionEndpoint, + containsEnd.value, + ) < 0 + )) } - registerSCode2("partitionIntervalContains", + registerSCode2( + "partitionIntervalContains", partitionIntervalType, - tv("T"), TBoolean, (_, _, _) => SBoolean) { + tv("T"), + TBoolean, + (_, _, _) => SBoolean, + ) { case (_, cb, _, interval: SIntervalValue, point: SBaseStructValue, _) => - val leftTuple = interval.loadStart(cb).get(cb).asBaseStruct + val leftTuple = interval.loadStart(cb).getOrAssert(cb).asBaseStruct val includesLeft = interval.includesStart - val pointGTLeft = compareStructWithPartitionIntervalEndpoint(cb, point, leftTuple, !includesLeft) > 0 + val pointGTLeft = + compareStructWithPartitionIntervalEndpoint(cb, point, leftTuple, !includesLeft) > 0 val isContained = cb.newLocal[Boolean]("partitionInterval_b", pointGTLeft) - cb.if_(isContained, { - // check right endpoint - val rightTuple = interval.loadEnd(cb).get(cb).asBaseStruct - - val includesRight = interval.includesEnd - val pointLTRight = compareStructWithPartitionIntervalEndpoint(cb, point, rightTuple, includesRight) < 0 - cb.assign(isContained, pointLTRight) - }) + cb.if_( + isContained, { + // check right endpoint + val rightTuple = interval.loadEnd(cb).getOrAssert(cb).asBaseStruct + + val includesRight = interval.includesEnd + val pointLTRight = + compareStructWithPartitionIntervalEndpoint(cb, point, rightTuple, includesRight) < 0 + cb.assign(isContained, pointLTRight) + }, + ) primitive(isContained) } diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/LocusFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/LocusFunctions.scala index f920938841c..1742c051ff6 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/LocusFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/LocusFunctions.scala @@ -4,10 +4,10 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.{EmitMethodBuilder, _} import is.hail.types.physical._ +import is.hail.types.physical.stypes.{EmitType, SType} import is.hail.types.physical.stypes.concrete._ import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives._ -import is.hail.types.physical.stypes.{EmitType, SType} import is.hail.types.virtual._ import is.hail.utils._ import is.hail.variant._ @@ -19,83 +19,128 @@ object LocusFunctions extends RegistryFunctions { def tlocus(name: String): Type = tv(name, "locus") - def tvariant(name: String): TStruct = TStruct("locus" -> tlocus(name), "alleles" -> TArray(TString)) + def tvariant(name: String): TStruct = + TStruct("locus" -> tlocus(name), "alleles" -> TArray(TString)) def tinterval(name: String): TInterval = TInterval(tlocus(name)) - def emitLocus(cb: EmitCodeBuilder, r: Value[Region], locus: Code[Locus], rt: PCanonicalLocus): SCanonicalLocusPointerValue = { 
+ def emitLocus(cb: EmitCodeBuilder, r: Value[Region], locus: Code[Locus], rt: PCanonicalLocus) + : SCanonicalLocusPointerValue = { val loc = cb.newLocal[Locus]("emit_locus_memo", locus) - rt.constructFromPositionAndString(cb, r, + rt.constructFromContigAndPosition( + cb, + r, loc.invoke[String]("contig"), - loc.invoke[Int]("position")) + loc.invoke[Int]("position"), + ) } - def emitVariant(cb: EmitCodeBuilder, r: Value[Region], variantCode: Code[(Locus, IndexedSeq[String])], rt: PCanonicalStruct): SBaseStructPointerValue = { + def emitVariant( + cb: EmitCodeBuilder, + r: Value[Region], + variantCode: Code[(Locus, IndexedSeq[String])], + rt: PCanonicalStruct, + ): SBaseStructPointerValue = { val variant = cb.newLocal[(Locus, IndexedSeq[String])]("emit_variant_variant", variantCode) - val locus = EmitCode.fromI(cb.emb) { cb => IEmitCode.present(cb, emitLocus(cb, r, variant.invoke[Locus]("_1"), rt.types(0).asInstanceOf[PCanonicalLocus])) } + val locus = EmitCode.fromI(cb.emb) { cb => + IEmitCode.present( + cb, + emitLocus(cb, r, variant.invoke[Locus]("_1"), rt.types(0).asInstanceOf[PCanonicalLocus]), + ) + } val alleles = EmitCode.fromI(cb.emb) { cb => val pAlleles = rt.types(1).asInstanceOf[PCanonicalArray] - val all = cb.newLocal[IndexedSeq[String]]("locus_alleles_parsed_alleles", variant.invoke[IndexedSeq[String]]("_2")) + val all = cb.newLocal[IndexedSeq[String]]( + "locus_alleles_parsed_alleles", + variant.invoke[IndexedSeq[String]]("_2"), + ) val len = cb.newLocal[Int]("locus_alleles_n_alleles", all.invoke[Int]("length")) val ps = pAlleles.elementType.setRequired(false).asInstanceOf[PCanonicalString] val ss = SStringPointer(ps) val (push, finish) = pAlleles.constructFromFunctions(cb, r, len, deepCopy = false) val i = cb.newLocal[Int]("locus_alleles_i", 0) - cb.while_(i < len, { - push(cb, IEmitCode.present(cb, ss.constructFromString(cb, r, all.invoke[Int, String]("apply", i)))) - cb.assign(i, i + 1) - }) + cb.while_( + i < len, { + push( + cb, + IEmitCode.present( + cb, + ss.constructFromString(cb, r, all.invoke[Int, String]("apply", i)), + ), + ) + cb.assign(i, i + 1) + }, + ) IEmitCode.present(cb, finish(cb)) } rt.constructFromFields(cb, r, FastSeq(locus, alleles), deepCopy = false) } - def emitLocusInterval(cb: EmitCodeBuilder, r: Value[Region], intervalCode: Code[Interval], pt: PCanonicalInterval): SIntervalPointerValue = { + def emitLocusInterval( + cb: EmitCodeBuilder, + r: Value[Region], + intervalCode: Code[Interval], + pt: PCanonicalInterval, + ): SIntervalPointerValue = { val interval = cb.newLocal[Interval]("emit_interval_interval", intervalCode) val pointType = pt.pointType.asInstanceOf[PCanonicalLocus] - pt.constructFromCodes(cb, + pt.constructFromCodes( + cb, r, - EmitCode.fromI(cb.emb)(cb => IEmitCode.present(cb, emitLocus(cb, r, interval.invoke[Locus]("start"), pointType))), - EmitCode.fromI(cb.emb)(cb => IEmitCode.present(cb, emitLocus(cb, r, interval.invoke[Locus]("end"), pointType))), + EmitCode.fromI(cb.emb)(cb => + IEmitCode.present(cb, emitLocus(cb, r, interval.invoke[Locus]("start"), pointType)) + ), + EmitCode.fromI(cb.emb)(cb => + IEmitCode.present(cb, emitLocus(cb, r, interval.invoke[Locus]("end"), pointType)) + ), cb.memoize(interval.invoke[Boolean]("includesStart")), - cb.memoize(interval.invoke[Boolean]("includesEnd")) + cb.memoize(interval.invoke[Boolean]("includesEnd")), ) } def registerLocusCodePredicate(methodName: String): Unit = - registerSCode1(methodName, tlocus("T"), TBoolean, - (_: Type, x: SType) => SBoolean) { - case (r, cb, rt, locus: 
SLocusValue, _) => - primitive(cb.memoize(cb.emb.getReferenceGenome(locus.st.rg).invoke[Locus, Boolean](methodName, locus.getLocusObj(cb)))) + registerSCode1(methodName, tlocus("T"), TBoolean, (_: Type, x: SType) => SBoolean) { + case (_, cb, _, locus: SLocusValue, _) => + primitive(cb.memoize(cb.emb.getReferenceGenome(locus.st.rg).invoke[Locus, Boolean]( + methodName, + locus.getLocusObj(cb), + ))) } def registerLocusCodeContigPredicate(methodName: String): Unit = - registerSCode1(methodName, tlocus("T"), TBoolean, - (_: Type, x: SType) => SBoolean) { - case (r, cb, rt, locus: SLocusValue, _) => - primitive(cb.memoize(cb.emb.getReferenceGenome(locus.st.rg).invoke[String, Boolean](methodName, locus.contig(cb).loadString(cb)))) + registerSCode1(methodName, tlocus("T"), TBoolean, (_: Type, x: SType) => SBoolean) { + case (_, cb, _, locus: SLocusValue, _) => + primitive(cb.memoize(cb.emb.getReferenceGenome(locus.st.rg).invoke[String, Boolean]( + methodName, + locus.contig(cb).loadString(cb), + ))) } - def registerAll() { + def registerAll(): Unit = { val locusClass = Locus.getClass - registerSCode1("contig", tlocus("T"), TString, - (_: Type, x: SType) => x.asInstanceOf[SLocus].contigType) { - case (r, cb, rt, locus: SLocusValue, _) => + registerSCode1( + "contig", + tlocus("T"), + TString, + (_: Type, x: SType) => x.asInstanceOf[SLocus].contigType, + ) { + case (_, cb, _, locus: SLocusValue, _) => locus.contig(cb) } - registerSCode1("contig_idx", tlocus("T"), TInt32, - (_: Type, x: SType) => SInt32) { - case (r, cb, rt, locus: SLocusValue, _) => - primitive(cb.memoize(cb.emb.getReferenceGenome(locus.st.rg).invoke[String, Int]("getContigIndex", locus.contig(cb).loadString(cb)))) + registerSCode1("contig_idx", tlocus("T"), TInt32, (_: Type, x: SType) => SInt32) { + case (_, cb, _, locus: SLocusValue, _) => + primitive(cb.memoize(cb.emb.getReferenceGenome(locus.st.rg).invoke[String, Int]( + "getContigIndex", + locus.contig(cb).loadString(cb), + ))) } - registerSCode1("position", tlocus("T"), TInt32, (_: Type, x: SType) => SInt32) { - case (r, cb, rt, locus: SLocusValue, _) => + case (_, cb, _, locus: SLocusValue, _) => primitive(locus.position(cb)) } registerLocusCodePredicate("isAutosomalOrPseudoAutosomal") @@ -106,105 +151,194 @@ object LocusFunctions extends RegistryFunctions { registerLocusCodePredicate("inXNonPar") registerLocusCodePredicate("inYNonPar") - registerSCode2("add_on_contig", tlocus("T"), TInt32, tlocus("T"), (tl: Type, _:SType, _: SType) => SCanonicalLocusPointer(PCanonicalLocus(tl.asInstanceOf[TLocus].rg))) { - case (r: EmitRegion, cb: EmitCodeBuilder, rt: SCanonicalLocusPointer, inputLocus: SLocusValue, basePairsToAdd: SInt32Value, errorID) => - + registerSCode2( + "add_on_contig", + tlocus("T"), + TInt32, + tlocus("T"), + (tl: Type, _: SType, _: SType) => + SCanonicalLocusPointer(PCanonicalLocus(tl.asInstanceOf[TLocus].rg)), + ) { + case ( + r: EmitRegion, + cb: EmitCodeBuilder, + rt: SCanonicalLocusPointer, + inputLocus: SLocusValue, + basePairsToAdd: SInt32Value, + _, + ) => val contig = inputLocus.contig(cb).loadString(cb) val basePos = inputLocus.position(cb) val bps = basePairsToAdd.value val newPos = cb.newLocal[Int]("newPos") - cb.if_(bps <= 0, + cb.if_( + bps <= 0, cb.assign(newPos, (basePos + bps).max(1)), - cb.assign(newPos, (basePos + bps).min(cb.emb.getReferenceGenome(rt.rg).invoke[String, Int]("contigLength", contig))) + cb.assign( + newPos, + (basePos + bps).min(cb.emb.getReferenceGenome(rt.rg).invoke[String, Int]( + "contigLength", + contig, + )), + ), ) - 
rt.pType.constructFromPositionAndString(cb, r.region, contig, newPos) + rt.pType.constructFromContigAndPosition(cb, r.region, contig, newPos) } - registerSCode2("min_rep", tlocus("T"), TArray(TString), TStruct("locus" -> tv("T"), "alleles" -> TArray(TString)), { - (returnType: Type, _: SType, _: SType) => { - val locusPT = PCanonicalLocus(returnType.asInstanceOf[TStruct].field("locus").typ.asInstanceOf[TLocus].rg, true) - PCanonicalStruct("locus" -> locusPT, "alleles" -> PCanonicalArray(PCanonicalString(true), true)).sType - } - }) { - case (r, cb, SBaseStructPointer(rt: PCanonicalStruct), locus: SLocusValue, alleles: SIndexableValue, _) => - val variantTuple = Code.invokeScalaObject2[Locus, IndexedSeq[String], (Locus, IndexedSeq[String])]( - VariantMethods.getClass, "minRep", - locus.getLocusObj(cb), - Code.checkcast[IndexedSeq[String]](svalueToJavaValue(cb, r.region, alleles))) + registerSCode2( + "min_rep", + tlocus("T"), + TArray(TString), + TStruct("locus" -> tv("T"), "alleles" -> TArray(TString)), + { + (returnType: Type, _: SType, _: SType) => + val locusPT = PCanonicalLocus( + returnType.asInstanceOf[TStruct].field("locus").typ.asInstanceOf[TLocus].rg, + true, + ) + PCanonicalStruct( + "locus" -> locusPT, + "alleles" -> PCanonicalArray(PCanonicalString(true), true), + ).sType + }, + ) { + case ( + r, + cb, + SBaseStructPointer(rt: PCanonicalStruct), + locus: SLocusValue, + alleles: SIndexableValue, + _, + ) => + val variantTuple = + Code.invokeScalaObject2[Locus, IndexedSeq[String], (Locus, IndexedSeq[String])]( + VariantMethods.getClass, + "minRep", + locus.getLocusObj(cb), + Code.checkcast[IndexedSeq[String]](svalueToJavaValue(cb, r.region, alleles)), + ) emitVariant(cb, r.region, variantTuple, rt) } - registerSCode2("locus_windows_per_contig", TArray(TArray(TFloat64)), TFloat64, TTuple(TArray(TInt32), TArray(TInt32)), { + registerSCode2( + "locus_windows_per_contig", + TArray(TArray(TFloat64)), + TFloat64, + TTuple(TArray(TInt32), TArray(TInt32)), (_: Type, _: SType, _: SType) => - PCanonicalTuple(false, PCanonicalArray(PInt32(true), true), PCanonicalArray(PInt32(true), true)).sType - }) { - case (r: EmitRegion, cb: EmitCodeBuilder, SBaseStructPointer(rt: PCanonicalTuple), grouped: SIndexableValue, radiusVal: SFloat64Value, errorID) => + PCanonicalTuple( + false, + PCanonicalArray(PInt32(true), true), + PCanonicalArray(PInt32(true), true), + ).sType, + ) { + case ( + r: EmitRegion, + cb: EmitCodeBuilder, + SBaseStructPointer(rt: PCanonicalTuple), + grouped: SIndexableValue, + radiusVal: SFloat64Value, + errorID, + ) => val radius = radiusVal.value val ncontigs = grouped.loadLength() val totalLen = cb.newLocal[Int]("locuswindows_totallen", 0) - def forAllContigs(cb: EmitCodeBuilder)(f: (EmitCodeBuilder, Value[Int], SIndexableValue) => Unit): Unit = { + def forAllContigs( + cb: EmitCodeBuilder + )( + f: (EmitCodeBuilder, Value[Int], SIndexableValue) => Unit + ): Unit = { val iContig = cb.newLocal[Int]("locuswindows_icontig", 0) - cb.while_(iContig < ncontigs, { - val coordPerContig = grouped.loadElement(cb, iContig).get(cb, "locus_windows group cannot be missing") - .asIndexable - f(cb, iContig, coordPerContig) - cb.assign(iContig, iContig + 1) - }) + cb.while_( + iContig < ncontigs, { + val coordPerContig = + grouped.loadElement(cb, iContig).getOrFatal( + cb, + "locus_windows group cannot be missing", + ) + .asIndexable + f(cb, iContig, coordPerContig) + cb.assign(iContig, iContig + 1) + }, + ) } val arrayType = PCanonicalArray(PInt32(true), true) assert(rt.types(0) == arrayType) 
assert(rt.types(1) == arrayType) - def addIdxWithCondition(cb: EmitCodeBuilder)(cond: (EmitCodeBuilder, Value[Int], Value[Int], SIndexableValue) => Code[Boolean]): IEmitCode = { + def addIdxWithCondition( + cb: EmitCodeBuilder + )( + cond: (EmitCodeBuilder, Value[Int], Value[Int], SIndexableValue) => Code[Boolean] + ): IEmitCode = { - val (pushElement, finish) = arrayType.constructFromFunctions(cb, r.region, totalLen, deepCopy = false) + val (pushElement, finish) = + arrayType.constructFromFunctions(cb, r.region, totalLen, deepCopy = false) val offset = cb.newLocal[Int]("locuswindows_offset", 0) val lastCoord = cb.newLocal[Double]("locuswindows_coord") - forAllContigs(cb) { case (cb, contigIdx, coords) => + forAllContigs(cb) { case (cb, _, coords) => val i = cb.newLocal[Int]("locuswindows_i", 0) val idx = cb.newLocal[Int]("locuswindows_idx", 0) val len = coords.loadLength() - cb.if_(len.ceq(0), + cb.if_( + len.ceq(0), cb.assign(lastCoord, 0.0), - cb.assign(lastCoord, coords.loadElement(cb, 0).get(cb, "locus_windows: missing value for 'coord_expr'").asDouble.value)) - cb.while_(i < len, { - - coords.loadElement(cb, i).consume(cb, - cb._fatalWithError(errorID, const("locus_windows: missing value for 'coord_expr' at row ") - .concat((offset + i).toS)), - { sc => - val currentCoord = cb.newLocal[Double]("locuswindows_coord_i", sc.asDouble.value) - cb.if_(lastCoord > currentCoord, - cb._fatalWithError(errorID, "locus_windows: 'coord_expr' must be in ascending order within each contig."), - cb.assign(lastCoord, currentCoord) - ) - }) - - val Lstart = CodeLabel() - val Lbreak = CodeLabel() - - cb.define(Lstart) - cb.if_(idx >= len, - cb.goto(Lbreak) - ) - cb.if_(cond(cb, i, idx, coords), - { - cb.assign(idx, idx + 1) - cb.goto(Lstart) - }, - cb.goto(Lbreak) - ) - cb.define(Lbreak) - - pushElement(cb, IEmitCode.present(cb, primitive(cb.memoize(offset + idx)))) - - cb.assign(i, i + 1) - }) + cb.assign( + lastCoord, + coords.loadElement(cb, 0).getOrFatal( + cb, + "locus_windows: missing value for 'coord_expr'", + ).asDouble.value, + ), + ) + cb.while_( + i < len, { + + coords.loadElement(cb, i).consume( + cb, + cb._fatalWithError( + errorID, + const("locus_windows: missing value for 'coord_expr' at row ") + .concat((offset + i).toS), + ), + { sc => + val currentCoord = + cb.newLocal[Double]("locuswindows_coord_i", sc.asDouble.value) + cb.if_( + lastCoord > currentCoord, + cb._fatalWithError( + errorID, + "locus_windows: 'coord_expr' must be in ascending order within each contig.", + ), + cb.assign(lastCoord, currentCoord), + ) + }, + ) + + val Lstart = CodeLabel() + val Lbreak = CodeLabel() + + cb.define(Lstart) + cb.if_(idx >= len, cb.goto(Lbreak)) + cb.if_( + cond(cb, i, idx, coords), { + cb.assign(idx, idx + 1) + cb.goto(Lstart) + }, + cb.goto(Lbreak), + ) + cb.define(Lbreak) + + pushElement(cb, IEmitCode.present(cb, primitive(cb.memoize(offset + idx)))) + + cb.assign(i, i + 1) + }, + ) cb.assign(offset, offset + len) } IEmitCode.present(cb, finish(cb)) @@ -214,102 +348,187 @@ object LocusFunctions extends RegistryFunctions { cb.assign(totalLen, totalLen + coordsPerContig.loadLength()) } - rt.constructFromFields(cb, r.region, + rt.constructFromFields( + cb, + r.region, FastSeq[EmitCode]( - EmitCode.fromI(cb.emb)(cb => addIdxWithCondition(cb) { case (cb, i, idx, coords) => coords.loadElement(cb, i) - .get(cb, "locus_windows: missing value for 'coord_expr'") - .asDouble.value > (coords.loadElement(cb, idx) - .get(cb, "locus_windows: missing value for 'coord_expr'").asDouble.value + radius) - }), - 
EmitCode.fromI(cb.emb)(cb => addIdxWithCondition(cb) { case (cb, i, idx, coords) => coords.loadElement(cb, i) - .get(cb, "locus_windows: missing value for 'coord_expr'") - .asDouble.value >= (coords.loadElement(cb, idx) - .get(cb, "locus_windows: missing value for 'coord_expr'").asDouble.value - radius) - }) - ), deepCopy = false) + EmitCode.fromI(cb.emb)(cb => + addIdxWithCondition(cb) { case (cb, i, idx, coords) => + coords.loadElement(cb, i) + .getOrFatal(cb, "locus_windows: missing value for 'coord_expr'") + .asDouble.value > (coords.loadElement(cb, idx) + .getOrFatal( + cb, + "locus_windows: missing value for 'coord_expr'", + ).asDouble.value + radius) + } + ), + EmitCode.fromI(cb.emb)(cb => + addIdxWithCondition(cb) { case (cb, i, idx, coords) => + coords.loadElement(cb, i) + .getOrFatal(cb, "locus_windows: missing value for 'coord_expr'") + .asDouble.value >= (coords.loadElement(cb, idx) + .getOrFatal( + cb, + "locus_windows: missing value for 'coord_expr'", + ).asDouble.value - radius) + } + ), + ), + deepCopy = false, + ) } - registerSCode1("Locus", TString, tlocus("T"), { - (returnType: Type, _: SType) => PCanonicalLocus(returnType.asInstanceOf[TLocus].rg).sType - }) { + registerSCode1( + "Locus", + TString, + tlocus("T"), + (returnType: Type, _: SType) => PCanonicalLocus(returnType.asInstanceOf[TLocus].rg).sType, + ) { case (r, cb, SCanonicalLocusPointer(rt: PCanonicalLocus), str: SStringValue, _) => val slocus = str.loadString(cb) - emitLocus(cb, + emitLocus( + cb, r.region, - Code.invokeScalaObject2[String, ReferenceGenome, Locus](locusClass, "parse", slocus, rgCode(r.mb, rt.rg)), - rt) + Code.invokeScalaObject2[String, ReferenceGenome, Locus]( + locusClass, + "parse", + slocus, + rgCode(r.mb, rt.rg), + ), + rt, + ) } - registerSCode2("Locus", TString, TInt32, tlocus("T"), { - (returnType: Type, _: SType, _: SType) => PCanonicalLocus(returnType.asInstanceOf[TLocus].rg).sType - }) { + registerSCode2( + "Locus", + TString, + TInt32, + tlocus("T"), + (returnType: Type, _: SType, _: SType) => + PCanonicalLocus(returnType.asInstanceOf[TLocus].rg).sType, + ) { case (r, cb, SCanonicalLocusPointer(rt: PCanonicalLocus), contig, pos, _) => - cb += rgCode(r.mb, rt.rg).invoke[String, Int, Unit]("checkLocus", contig.asString.loadString(cb), pos.asInt.value) - rt.constructFromPositionAndString(cb, r.region, contig.asString.loadString(cb), pos.asInt.value) + cb += rgCode(r.mb, rt.rg).invoke[String, Int, Unit]( + "checkLocus", + contig.asString.loadString(cb), + pos.asInt.value, + ) + rt.constructFromContigAndPosition( + cb, + r.region, + contig.asString.loadString(cb), + pos.asInt.value, + ) } - registerSCode1("LocusAlleles", TString, tvariant("T"), { - (returnType: Type, _: SType) => { - val lTyp = returnType.asInstanceOf[TStruct].field("locus").typ.asInstanceOf[TLocus] - PCanonicalStruct("locus" -> PCanonicalLocus(lTyp.rg, true), "alleles" -> PCanonicalArray(PCanonicalString(true), true)).sType - } - }) { + registerSCode1( + "LocusAlleles", + TString, + tvariant("T"), + { + (returnType: Type, _: SType) => + val lTyp = returnType.asInstanceOf[TStruct].field("locus").typ.asInstanceOf[TLocus] + PCanonicalStruct( + "locus" -> PCanonicalLocus(lTyp.rg, true), + "alleles" -> PCanonicalArray(PCanonicalString(true), true), + ).sType + }, + ) { case (r, cb, SBaseStructPointer(rt: PCanonicalStruct), variantStr: SStringValue, _) => - val svar = variantStr.loadString(cb) val plocus = rt.types(0).asInstanceOf[PCanonicalLocus] val variant = Code .invokeScalaObject2[String, ReferenceGenome, (Locus, 
IndexedSeq[String])]( - VariantMethods.getClass, "parse", svar, rgCode(r.mb, plocus.rg)) + VariantMethods.getClass, + "parse", + svar, + rgCode(r.mb, plocus.rg), + ) emitVariant(cb, r.region, variant, rt) } - registerIEmitCode2("LocusInterval", TString, TBoolean, tinterval("T"), { - (returnType: Type, _: EmitType, _: EmitType) => { - val lPTyp = returnType.asInstanceOf[TInterval].pointType.asInstanceOf[TLocus] - EmitType(PCanonicalInterval(PCanonicalLocus(lPTyp.asInstanceOf[TLocus].rg)).sType, false) - } - }) { case (cb: EmitCodeBuilder, r: Value[Region], SIntervalPointer(rt: PCanonicalInterval), _, locusStrEC: EmitCode, invalidMissingEC: EmitCode) => - val plocus = rt.pointType.asInstanceOf[PLocus] - - - locusStrEC.toI(cb).flatMap(cb) { locusStr => - invalidMissingEC.toI(cb).flatMap(cb) { invalidMissing => + registerIEmitCode2( + "LocusInterval", + TString, + TBoolean, + tinterval("T"), + { + (returnType: Type, _: EmitType, _: EmitType) => + val lPTyp = returnType.asInstanceOf[TInterval].pointType.asInstanceOf[TLocus] + EmitType(PCanonicalInterval(PCanonicalLocus(lPTyp.asInstanceOf[TLocus].rg)).sType, false) + }, + ) { + case ( + cb: EmitCodeBuilder, + r: Value[Region], + SIntervalPointer(rt: PCanonicalInterval), + _, + locusStrEC: EmitCode, + invalidMissingEC: EmitCode, + ) => + val plocus = rt.pointType.asInstanceOf[PLocus] - val Lmissing = CodeLabel() - val Ldefined = CodeLabel() + locusStrEC.toI(cb).flatMap(cb) { locusStr => + invalidMissingEC.toI(cb).flatMap(cb) { invalidMissing => + val Lmissing = CodeLabel() + val Ldefined = CodeLabel() - val interval = cb.newLocal[Interval]("locus_interval_interval", - Code.invokeScalaObject3[String, ReferenceGenome, Boolean, Interval]( - locusClass, "parseInterval", - locusStr.asString.loadString(cb), - rgCode(cb.emb, plocus.rg), - invalidMissing.asBoolean.value)) + val interval = cb.newLocal[Interval]( + "locus_interval_interval", + Code.invokeScalaObject3[String, ReferenceGenome, Boolean, Interval]( + locusClass, + "parseInterval", + locusStr.asString.loadString(cb), + rgCode(cb.emb, plocus.rg), + invalidMissing.asBoolean.value, + ), + ) - cb.if_(interval.isNull, cb.goto(Lmissing)) + cb.if_(interval.isNull, cb.goto(Lmissing)) - val intervalCode = emitLocusInterval(cb, r, interval, rt) - cb.goto(Ldefined) - IEmitCode(Lmissing, Ldefined, intervalCode, false) + val intervalCode = emitLocusInterval(cb, r, interval, rt) + cb.goto(Ldefined) + IEmitCode(Lmissing, Ldefined, intervalCode, false) + } } - } } - registerIEmitCode6("LocusInterval", TString, TInt32, TInt32, TBoolean, TBoolean, TBoolean, tinterval("T"), { - (returnType: Type, _: EmitType, _: EmitType, _: EmitType, _: EmitType, _: EmitType, _: EmitType) => { - val lPTyp = returnType.asInstanceOf[TInterval].pointType.asInstanceOf[TLocus] - EmitType(PCanonicalInterval(PCanonicalLocus(lPTyp.rg)).sType, false) - } - }) { - case (cb: EmitCodeBuilder, r: Value[Region], - SIntervalPointer(rt: PCanonicalInterval), - errorID: Value[Int], - locusString: EmitCode, - pos1: EmitCode, - pos2: EmitCode, - include1: EmitCode, - include2: EmitCode, - invalidMissing: EmitCode) => + registerIEmitCode6( + "LocusInterval", + TString, + TInt32, + TInt32, + TBoolean, + TBoolean, + TBoolean, + tinterval("T"), + { + ( + returnType: Type, + _: EmitType, + _: EmitType, + _: EmitType, + _: EmitType, + _: EmitType, + _: EmitType, + ) => + val lPTyp = returnType.asInstanceOf[TInterval].pointType.asInstanceOf[TLocus] + EmitType(PCanonicalInterval(PCanonicalLocus(lPTyp.rg)).sType, false) + }, + ) { + case ( + cb: 
EmitCodeBuilder, + r: Value[Region], + SIntervalPointer(rt: PCanonicalInterval), + _: Value[Int], + locusString: EmitCode, + pos1: EmitCode, + pos2: EmitCode, + include1: EmitCode, + include2: EmitCode, + invalidMissing: EmitCode, + ) => val plocus = rt.pointType.asInstanceOf[PLocus] locusString.toI(cb).flatMap(cb) { locusString => @@ -318,20 +537,32 @@ object LocusFunctions extends RegistryFunctions { include1.toI(cb).flatMap(cb) { include1 => include2.toI(cb).flatMap(cb) { include2 => invalidMissing.toI(cb).flatMap(cb) { invalidMissing => - val Lmissing = CodeLabel() val Ldefined = CodeLabel() - val interval = cb.newLocal[Interval]("locus_interval_interval", - Code.invokeScalaObject7[String, Int, Int, Boolean, Boolean, ReferenceGenome, Boolean, Interval]( - locusClass, "makeInterval", + val interval = cb.newLocal[Interval]( + "locus_interval_interval", + Code.invokeScalaObject7[ + String, + Int, + Int, + Boolean, + Boolean, + ReferenceGenome, + Boolean, + Interval, + ]( + locusClass, + "makeInterval", locusString.asString.loadString(cb), pos1.asInt.value, pos2.asInt.value, include1.asBoolean.value, include2.asBoolean.value, rgCode(cb.emb, plocus.rg), - invalidMissing.asBoolean.value)) + invalidMissing.asBoolean.value, + ), + ) cb.if_(interval.isNull, cb.goto(Lmissing)) @@ -346,33 +577,53 @@ object LocusFunctions extends RegistryFunctions { } } - registerSCode1("globalPosToLocus", TInt64, tlocus("T"), { - (returnType: Type, _: SType) => - PCanonicalLocus(returnType.asInstanceOf[TLocus].rg).sType - }) { + registerSCode1( + "globalPosToLocus", + TInt64, + tlocus("T"), + (returnType: Type, _: SType) => PCanonicalLocus(returnType.asInstanceOf[TLocus].rg).sType, + ) { case (r, cb, SCanonicalLocusPointer(rt: PCanonicalLocus), globalPos, _) => - val locus = cb.newLocal[Locus]("global_pos_locus", - rgCode(r.mb, rt.rg).invoke[Long, Locus]("globalPosToLocus", globalPos.asLong.value)) - rt.constructFromPositionAndString(cb, r.region, locus.invoke[String]("contig"), locus.invoke[Int]("position")) + val locus = cb.newLocal[Locus]( + "global_pos_locus", + rgCode(r.mb, rt.rg).invoke[Long, Locus]("globalPosToLocus", globalPos.asLong.value), + ) + rt.constructFromContigAndPosition( + cb, + r.region, + locus.invoke[String]("contig"), + locus.invoke[Int]("position"), + ) } registerSCode1("locusToGlobalPos", tlocus("T"), TInt64, (_: Type, _: SType) => SInt64) { - case (r, cb, rt, locus: SLocusValue, _) => + case (r, cb, _, locus: SLocusValue, _) => val locusObject = locus.getLocusObj(cb) - val globalPos = cb.memoize(rgCode(r.mb, locus.st.rg).invoke[Locus, Long]("locusToGlobalPos", locusObject)) + val globalPos = + cb.memoize(rgCode(r.mb, locus.st.rg).invoke[Locus, Long]("locusToGlobalPos", locusObject)) primitive(globalPos) } - registerIEmitCode2("liftoverLocus", tlocus("T"), TFloat64, TStruct("result" -> tv("U", "locus"), "is_negative_strand" -> TBoolean), { - (returnType: Type, _: EmitType, _: EmitType) => { - val lTyp = returnType.asInstanceOf[TStruct].field("result").typ.asInstanceOf[TLocus] - EmitType(PCanonicalStruct("result" -> PCanonicalLocus(lTyp.rg, true), "is_negative_strand" -> PBoolean(true)).sType, false) - } - }) { + registerIEmitCode2( + "liftoverLocus", + tlocus("T"), + TFloat64, + TStruct("result" -> tv("U", "locus"), "is_negative_strand" -> TBoolean), + { + (returnType: Type, _: EmitType, _: EmitType) => + val lTyp = returnType.asInstanceOf[TStruct].field("result").typ.asInstanceOf[TLocus] + EmitType( + PCanonicalStruct( + "result" -> PCanonicalLocus(lTyp.rg, true), + "is_negative_strand" 
-> PBoolean(true), + ).sType, + false, + ) + }, + ) { case (cb, r, SBaseStructPointer(rt: PCanonicalStruct), _, loc, minMatch) => loc.toI(cb).flatMap(cb) { case loc: SLocusValue => minMatch.toI(cb).flatMap(cb) { minMatch => - val Lmissing = CodeLabel() val Ldefined = CodeLabel() @@ -381,20 +632,43 @@ object LocusFunctions extends RegistryFunctions { val destRG = rt.types(0).asInstanceOf[PLocus].rg val locusObj = loc.getLocusObj(cb) - val lifted = cb.newLocal[(Locus, Boolean)]("lifterover_locus_ lifted", - rgCode(cb.emb, srcRG).invoke[String, Locus, Double, (Locus, Boolean)]("liftoverLocus", - destRG, locusObj, minMatch.asDouble.value)) + val lifted = cb.newLocal[(Locus, Boolean)]( + "lifterover_locus_ lifted", + rgCode(cb.emb, srcRG).invoke[String, Locus, Double, (Locus, Boolean)]( + "liftoverLocus", + destRG, + locusObj, + minMatch.asDouble.value, + ), + ) cb.if_(lifted.isNull, cb.goto(Lmissing)) val locType = rt.types(0).asInstanceOf[PCanonicalLocus] - val locusCode = EmitCode.present(cb.emb, emitLocus(cb, r, Code.checkcast[Locus](lifted.getField[java.lang.Object]("_1")), locType)) - - val negativeStrandCode = EmitCode.present(cb.emb, - primitive(cb.memoize(Code.checkcast[java.lang.Boolean](lifted.getField[java.lang.Object]("_2")) - .invoke[Boolean]("booleanValue")))) - - val structCode = rt.constructFromFields(cb, r, FastSeq(locusCode, negativeStrandCode), deepCopy = false) + val locusCode = EmitCode.present( + cb.emb, + emitLocus( + cb, + r, + Code.checkcast[Locus](lifted.getField[java.lang.Object]("_1")), + locType, + ), + ) + + val negativeStrandCode = EmitCode.present( + cb.emb, + primitive(cb.memoize( + Code.checkcast[java.lang.Boolean](lifted.getField[java.lang.Object]("_2")) + .invoke[Boolean]("booleanValue") + )), + ) + + val structCode = rt.constructFromFields( + cb, + r, + FastSeq(locusCode, negativeStrandCode), + deepCopy = false, + ) cb.goto(Ldefined) IEmitCode(Lmissing, Ldefined, structCode, false) @@ -402,41 +676,72 @@ object LocusFunctions extends RegistryFunctions { } } - registerIEmitCode2("liftoverLocusInterval", tinterval("T"), TFloat64, TStruct("result" -> tinterval("U"), "is_negative_strand" -> TBoolean), { - (returnType: Type, _: EmitType, _: EmitType) => { - val lTyp = returnType.asInstanceOf[TStruct].field("result").typ.asInstanceOf[TInterval].pointType.asInstanceOf[TLocus] - EmitType(PCanonicalStruct("result" -> PCanonicalInterval(PCanonicalLocus(lTyp.rg, true), true), "is_negative_strand" -> PBoolean(true)).sType, false) - } - }) { + registerIEmitCode2( + "liftoverLocusInterval", + tinterval("T"), + TFloat64, + TStruct("result" -> tinterval("U"), "is_negative_strand" -> TBoolean), + { + (returnType: Type, _: EmitType, _: EmitType) => + val lTyp = returnType.asInstanceOf[TStruct].field("result").typ.asInstanceOf[ + TInterval + ].pointType.asInstanceOf[TLocus] + EmitType( + PCanonicalStruct( + "result" -> PCanonicalInterval(PCanonicalLocus(lTyp.rg, true), true), + "is_negative_strand" -> PBoolean(true), + ).sType, + false, + ) + }, + ) { case (cb, r, SBaseStructPointer(rt: PCanonicalStruct), _, interval, minMatch) => interval.toI(cb).flatMap(cb) { interval => minMatch.toI(cb).flatMap(cb) { minMatch => - val Lmissing = CodeLabel() val Ldefined = CodeLabel() - val iT = interval.st.asInstanceOf[SInterval] val srcRG = iT.pointType.asInstanceOf[SLocus].rg val destRG = rt.types(0).asInstanceOf[PInterval].pointType.asInstanceOf[PLocus].rg - val er = EmitRegion(cb.emb, r) val intervalObj = Code.checkcast[Interval](svalueToJavaValue(cb, r, interval)) - val lifted = 
cb.newLocal[(Interval, Boolean)]("liftover_locus_interval_lifted", - rgCode(cb.emb, srcRG).invoke[String, Interval, Double, (Interval, Boolean)]("liftoverLocusInterval", - destRG, intervalObj, minMatch.asDouble.value)) - + val lifted = cb.newLocal[(Interval, Boolean)]( + "liftover_locus_interval_lifted", + rgCode(cb.emb, srcRG).invoke[String, Interval, Double, (Interval, Boolean)]( + "liftoverLocusInterval", + destRG, + intervalObj, + minMatch.asDouble.value, + ), + ) cb.if_(lifted.isNull, cb.goto(Lmissing)) val iType = rt.types(0).asInstanceOf[PCanonicalInterval] - val intervalCode = EmitCode.present(cb.emb, emitLocusInterval(cb, r, Code.checkcast[Interval](lifted.getField[java.lang.Object]("_1")), iType)) - - val negativeStrandCode = EmitCode.present(cb.emb, - primitive(cb.memoize(Code.checkcast[java.lang.Boolean](lifted.getField[java.lang.Object]("_2")) - .invoke[Boolean]("booleanValue")))) - - val structCode = rt.constructFromFields(cb, r, FastSeq(intervalCode, negativeStrandCode), deepCopy = false) - + val intervalCode = EmitCode.present( + cb.emb, + emitLocusInterval( + cb, + r, + Code.checkcast[Interval](lifted.getField[java.lang.Object]("_1")), + iType, + ), + ) + + val negativeStrandCode = EmitCode.present( + cb.emb, + primitive(cb.memoize( + Code.checkcast[java.lang.Boolean](lifted.getField[java.lang.Object]("_2")) + .invoke[Boolean]("booleanValue") + )), + ) + + val structCode = rt.constructFromFields( + cb, + r, + FastSeq(intervalCode, negativeStrandCode), + deepCopy = false, + ) cb.goto(Ldefined) IEmitCode(Lmissing, Ldefined, structCode, false) diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/MathFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/MathFunctions.scala index 28db3df70ee..bacc2c50ac4 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/MathFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/MathFunctions.scala @@ -8,6 +8,7 @@ import is.hail.types.physical.stypes.interfaces.primitive import is.hail.types.physical.stypes.primitives._ import is.hail.types.virtual._ import is.hail.utils._ + import org.apache.commons.math3.special.Gamma object MathFunctions extends RegistryFunctions { @@ -16,10 +17,10 @@ object MathFunctions extends RegistryFunctions { // This does a truncating log2, always rounnds down def log2(x: Int): Int = { var v = x - var r = if (v > 0xFFFF) 16 else 0 + var r = if (v > 0xffff) 16 else 0 v >>= r - if (v > 0xFF) { v >>= 8; r |= 8 } - if (v > 0xF) { v >>= 4; r |= 4 } + if (v > 0xff) { v >>= 8; r |= 8 } + if (v > 0xf) { v >>= 4; r |= 4 } if (v > 0x3) { v >>= 2; r |= 2 } r |= v >> 1 r @@ -35,7 +36,6 @@ object MathFunctions extends RegistryFunctions { v + 1 } - def gamma(x: Double): Double = Gamma.gamma(x) def floor(x: Float): Float = math.floor(x).toFloat @@ -86,7 +86,6 @@ object MathFunctions extends RegistryFunctions { java.lang.Math.floorDiv(x, y) } - def floorDiv(x: Long, y: Long): Long = { if (y == 0L) fatal(s"$x // 0: integer division by zero", ErrorIDs.NO_ERROR) @@ -97,7 +96,8 @@ object MathFunctions extends RegistryFunctions { def floorDiv(x: Double, y: Double): Double = math.floor(x / y) - def approxEqual(x: Double, y: Double, tolerance: Double, absolute: Boolean, nanSame: Boolean): Boolean = { + def approxEqual(x: Double, y: Double, tolerance: Double, absolute: Boolean, nanSame: Boolean) + : Boolean = { val withinTol = if (absolute) math.abs(x - y) <= tolerance @@ -110,7 +110,7 @@ object MathFunctions extends RegistryFunctions { val mathPackageClass: Class[_] = Class.forName("scala.math.package$") 
- def registerAll() { + def registerAll(): Unit = { val thisClass = getClass val statsPackageClass = Class.forName("is.hail.stats.package$") val jMathClass = classOf[java.lang.Math] @@ -140,50 +140,142 @@ object MathFunctions extends RegistryFunctions { registerScalaFunction("log", Array(TFloat64), TFloat64, null)(mathPackageClass, "log") registerScalaFunction("log", Array(TFloat64, TFloat64), TFloat64, null)(thisClass, "log") registerScalaFunction("log2", Array(TInt32), TInt32, null)(thisClass, "log2") - registerScalaFunction("roundToNextPowerOf2", Array(TInt32), TInt32, null)(thisClass, "roundToNextPowerOf2") + registerScalaFunction("roundToNextPowerOf2", Array(TInt32), TInt32, null)( + thisClass, + "roundToNextPowerOf2", + ) registerScalaFunction("gamma", Array(TFloat64), TFloat64, null)(thisClass, "gamma") - registerScalaFunction("binomTest", Array(TInt32, TInt32, TFloat64, TInt32), TFloat64, null)(statsPackageClass, "binomTest") + registerScalaFunction("binomTest", Array(TInt32, TInt32, TFloat64, TInt32), TFloat64, null)( + statsPackageClass, + "binomTest", + ) - registerScalaFunction("dbeta", Array(TFloat64, TFloat64, TFloat64), TFloat64, null)(statsPackageClass, "dbeta") + registerScalaFunction("dbeta", Array(TFloat64, TFloat64, TFloat64), TFloat64, null)( + statsPackageClass, + "dbeta", + ) registerScalaFunction("dnorm", Array(TFloat64), TFloat64, null)(statsPackageClass, "dnorm") - registerScalaFunction("dnorm", Array(TFloat64, TFloat64, TFloat64, TBoolean), TFloat64, null)(statsPackageClass, "dnorm") + registerScalaFunction("dnorm", Array(TFloat64, TFloat64, TFloat64, TBoolean), TFloat64, null)( + statsPackageClass, + "dnorm", + ) registerScalaFunction("pnorm", Array(TFloat64), TFloat64, null)(statsPackageClass, "pnorm") - registerScalaFunction("pnorm", Array(TFloat64, TFloat64, TFloat64, TBoolean, TBoolean), TFloat64, null)(statsPackageClass, "pnorm") + registerScalaFunction( + "pnorm", + Array(TFloat64, TFloat64, TFloat64, TBoolean, TBoolean), + TFloat64, + null, + )(statsPackageClass, "pnorm") registerScalaFunction("qnorm", Array(TFloat64), TFloat64, null)(statsPackageClass, "qnorm") - registerScalaFunction("qnorm", Array(TFloat64, TFloat64, TFloat64, TBoolean, TBoolean), TFloat64, null)(statsPackageClass, "qnorm") - - registerScalaFunction("pT", Array(TFloat64, TFloat64, TBoolean, TBoolean), TFloat64, null)(statsPackageClass, "pT") - registerScalaFunction("pF", Array(TFloat64, TFloat64, TFloat64, TBoolean, TBoolean), TFloat64, null)(statsPackageClass, "pF") - - registerScalaFunction("dpois", Array(TFloat64, TFloat64), TFloat64, null)(statsPackageClass, "dpois") - registerScalaFunction("dpois", Array(TFloat64, TFloat64, TBoolean), TFloat64, null)(statsPackageClass, "dpois") - - registerScalaFunction("ppois", Array(TFloat64, TFloat64), TFloat64, null)(statsPackageClass, "ppois") - registerScalaFunction("ppois", Array(TFloat64, TFloat64, TBoolean, TBoolean), TFloat64, null)(statsPackageClass, "ppois") - - registerScalaFunction("qpois", Array(TFloat64, TFloat64), TInt32, null)(statsPackageClass, "qpois") - registerScalaFunction("qpois", Array(TFloat64, TFloat64, TBoolean, TBoolean), TInt32, null)(statsPackageClass, "qpois") - - registerScalaFunction("dchisq", Array(TFloat64, TFloat64), TFloat64, null)(statsPackageClass, "dchisq") - registerScalaFunction("dchisq", Array(TFloat64, TFloat64, TBoolean), TFloat64, null)(statsPackageClass, "dchisq") - - registerScalaFunction("dnchisq", Array(TFloat64, TFloat64, TFloat64), TFloat64, null)(statsPackageClass, "dnchisq") - 
registerScalaFunction("dnchisq", Array(TFloat64, TFloat64, TFloat64, TBoolean), TFloat64, null)(statsPackageClass, "dnchisq") - - registerScalaFunction("pchisqtail", Array(TFloat64, TFloat64), TFloat64, null)(statsPackageClass, "pchisqtail") - registerScalaFunction("pchisqtail", Array(TFloat64, TFloat64, TBoolean, TBoolean), TFloat64, null)(statsPackageClass, "pchisqtail") - - registerScalaFunction("pnchisqtail", Array(TFloat64, TFloat64, TFloat64), TFloat64, null)(statsPackageClass, "pnchisqtail") - registerScalaFunction("pnchisqtail", Array(TFloat64, TFloat64, TFloat64, TBoolean, TBoolean), TFloat64, null)(statsPackageClass, "pnchisqtail") - - registerScalaFunction("qchisqtail", Array(TFloat64, TFloat64), TFloat64, null)(statsPackageClass, "qchisqtail") - registerScalaFunction("qchisqtail", Array(TFloat64, TFloat64, TBoolean, TBoolean), TFloat64, null)(statsPackageClass, "qchisqtail") - - registerScalaFunction("qnchisqtail", Array(TFloat64, TFloat64, TFloat64), TFloat64, null)(statsPackageClass, "qnchisqtail") - registerScalaFunction("qnchisqtail", Array(TFloat64, TFloat64, TFloat64, TBoolean, TBoolean), TFloat64, null)(statsPackageClass, "qnchisqtail") + registerScalaFunction( + "qnorm", + Array(TFloat64, TFloat64, TFloat64, TBoolean, TBoolean), + TFloat64, + null, + )(statsPackageClass, "qnorm") + + registerScalaFunction("pT", Array(TFloat64, TFloat64, TBoolean, TBoolean), TFloat64, null)( + statsPackageClass, + "pT", + ) + registerScalaFunction( + "pF", + Array(TFloat64, TFloat64, TFloat64, TBoolean, TBoolean), + TFloat64, + null, + )(statsPackageClass, "pF") + + registerScalaFunction("dpois", Array(TFloat64, TFloat64), TFloat64, null)( + statsPackageClass, + "dpois", + ) + registerScalaFunction("dpois", Array(TFloat64, TFloat64, TBoolean), TFloat64, null)( + statsPackageClass, + "dpois", + ) + + registerScalaFunction("ppois", Array(TFloat64, TFloat64), TFloat64, null)( + statsPackageClass, + "ppois", + ) + registerScalaFunction("ppois", Array(TFloat64, TFloat64, TBoolean, TBoolean), TFloat64, null)( + statsPackageClass, + "ppois", + ) + + registerScalaFunction("qpois", Array(TFloat64, TFloat64), TInt32, null)( + statsPackageClass, + "qpois", + ) + registerScalaFunction("qpois", Array(TFloat64, TFloat64, TBoolean, TBoolean), TInt32, null)( + statsPackageClass, + "qpois", + ) + + registerScalaFunction("dchisq", Array(TFloat64, TFloat64), TFloat64, null)( + statsPackageClass, + "dchisq", + ) + registerScalaFunction("dchisq", Array(TFloat64, TFloat64, TBoolean), TFloat64, null)( + statsPackageClass, + "dchisq", + ) + + registerScalaFunction("dnchisq", Array(TFloat64, TFloat64, TFloat64), TFloat64, null)( + statsPackageClass, + "dnchisq", + ) + registerScalaFunction("dnchisq", Array(TFloat64, TFloat64, TFloat64, TBoolean), TFloat64, null)( + statsPackageClass, + "dnchisq", + ) + + registerScalaFunction("pchisqtail", Array(TFloat64, TFloat64), TFloat64, null)( + statsPackageClass, + "pchisqtail", + ) + registerScalaFunction( + "pchisqtail", + Array(TFloat64, TFloat64, TBoolean, TBoolean), + TFloat64, + null, + )(statsPackageClass, "pchisqtail") + + registerScalaFunction("pnchisqtail", Array(TFloat64, TFloat64, TFloat64), TFloat64, null)( + statsPackageClass, + "pnchisqtail", + ) + registerScalaFunction( + "pnchisqtail", + Array(TFloat64, TFloat64, TFloat64, TBoolean, TBoolean), + TFloat64, + null, + )(statsPackageClass, "pnchisqtail") + + registerScalaFunction("qchisqtail", Array(TFloat64, TFloat64), TFloat64, null)( + statsPackageClass, + "qchisqtail", + ) + registerScalaFunction( + 
"qchisqtail", + Array(TFloat64, TFloat64, TBoolean, TBoolean), + TFloat64, + null, + )(statsPackageClass, "qchisqtail") + + registerScalaFunction("qnchisqtail", Array(TFloat64, TFloat64, TFloat64), TFloat64, null)( + statsPackageClass, + "qnchisqtail", + ) + registerScalaFunction( + "qnchisqtail", + Array(TFloat64, TFloat64, TFloat64, TBoolean, TBoolean), + TFloat64, + null, + )(statsPackageClass, "qnchisqtail") registerSCode7( "pgenchisq", @@ -195,41 +287,60 @@ object MathFunctions extends RegistryFunctions { TInt32, TFloat64, DaviesAlgorithm.pType.virtualType, - (_, _, _, _, _, _, _, _) => DaviesAlgorithm.pType.sType + (_, _, _, _, _, _, _, _) => DaviesAlgorithm.pType.sType, ) { - case (r, cb, rt, - x: SFloat64Value, - _w: SIndexablePointerValue, - _k: SIndexablePointerValue, - _lam: SIndexablePointerValue, - sigma: SFloat64Value, - maxIterations: SInt32Value, - minAccuracy: SFloat64Value, - _) => - + case ( + r, + cb, + _, + x: SFloat64Value, + _w: SIndexablePointerValue, + _k: SIndexablePointerValue, + _lam: SIndexablePointerValue, + sigma: SFloat64Value, + maxIterations: SInt32Value, + minAccuracy: SFloat64Value, + _, + ) => val w = _w.castToArray(cb) val k = _k.castToArray(cb) val lam = _lam.castToArray(cb) - val res = cb.newLocal[DaviesResultForPython]("pgenchisq_result", + val res = cb.newLocal[DaviesResultForPython]( + "pgenchisq_result", Code.invokeScalaObject7[ - Double, IndexedSeq[Double], IndexedSeq[Int], IndexedSeq[Double], Double, Int, Double, DaviesResultForPython - ](statsPackageClass, "pgenchisq", + Double, + IndexedSeq[Double], + IndexedSeq[Int], + IndexedSeq[Double], + Double, + Int, + Double, + DaviesResultForPython, + ]( + statsPackageClass, + "pgenchisq", x.value, Code.checkcast[IndexedSeq[Double]](svalueToJavaValue(cb, r.region, w)), Code.checkcast[IndexedSeq[Int]](svalueToJavaValue(cb, r.region, k)), Code.checkcast[IndexedSeq[Double]](svalueToJavaValue(cb, r.region, lam)), sigma.value, maxIterations.value, - minAccuracy.value) + minAccuracy.value, + ), ) - DaviesAlgorithm.pType.constructFromFields(cb, r.region, FastSeq( - EmitValue.present(primitive(cb.memoize(res.invoke[Double]("value")))), - EmitValue.present(primitive(cb.memoize(res.invoke[Int]("nIterations")))), - EmitValue.present(primitive(cb.memoize(res.invoke[Boolean]("converged")))), - EmitValue.present(primitive(cb.memoize(res.invoke[Int]("fault")))) - ), deepCopy = false) + DaviesAlgorithm.pType.constructFromFields( + cb, + r.region, + FastSeq( + EmitValue.present(primitive(cb.memoize(res.invoke[Double]("value")))), + EmitValue.present(primitive(cb.memoize(res.invoke[Int]("nIterations")))), + EmitValue.present(primitive(cb.memoize(res.invoke[Boolean]("converged")))), + EmitValue.present(primitive(cb.memoize(res.invoke[Int]("fault")))), + ), + deepCopy = false, + ) } registerScalaFunction("floor", Array(TFloat32), TFloat32, null)(thisClass, "floor") @@ -246,86 +357,189 @@ object MathFunctions extends RegistryFunctions { registerJavaStaticFunction("isnan", Array(TFloat32), TBoolean, null)(jFloatClass, "isNaN") registerJavaStaticFunction("isnan", Array(TFloat64), TBoolean, null)(jDoubleClass, "isNaN") - registerJavaStaticFunction("is_finite", Array(TFloat32), TBoolean, null)(jFloatClass, "isFinite") - registerJavaStaticFunction("is_finite", Array(TFloat64), TBoolean, null)(jDoubleClass, "isFinite") - - registerJavaStaticFunction("is_infinite", Array(TFloat32), TBoolean, null)(jFloatClass, "isInfinite") - registerJavaStaticFunction("is_infinite", Array(TFloat64), TBoolean, null)(jDoubleClass, "isInfinite") + 
registerJavaStaticFunction("is_finite", Array(TFloat32), TBoolean, null)( + jFloatClass, + "isFinite", + ) + registerJavaStaticFunction("is_finite", Array(TFloat64), TBoolean, null)( + jDoubleClass, + "isFinite", + ) + + registerJavaStaticFunction("is_infinite", Array(TFloat32), TBoolean, null)( + jFloatClass, + "isInfinite", + ) + registerJavaStaticFunction("is_infinite", Array(TFloat64), TBoolean, null)( + jDoubleClass, + "isInfinite", + ) registerJavaStaticFunction("sign", Array(TInt32), TInt32, null)(jIntegerClass, "signum") registerScalaFunction("sign", Array(TInt64), TInt64, null)(mathPackageClass, "signum") registerJavaStaticFunction("sign", Array(TFloat32), TFloat32, null)(jMathClass, "signum") registerJavaStaticFunction("sign", Array(TFloat64), TFloat64, null)(jMathClass, "signum") - registerScalaFunction("approxEqual", Array(TFloat64, TFloat64, TFloat64, TBoolean, TBoolean), TBoolean, null)(thisClass, "approxEqual") + registerScalaFunction( + "approxEqual", + Array(TFloat64, TFloat64, TFloat64, TBoolean, TBoolean), + TBoolean, + null, + )(thisClass, "approxEqual") registerWrappedScalaFunction1("entropy", TString, TFloat64, null)(thisClass, "irentropy") - registerSCode4("fisher_exact_test", TInt32, TInt32, TInt32, TInt32, fetStruct.virtualType, - (_, _, _, _, _) => fetStruct.sType - ) { case (r, cb, rt, a: SInt32Value, b: SInt32Value, c: SInt32Value, d: SInt32Value, _) => - val res = cb.newLocal[Array[Double]]("fisher_exact_test_res", - Code.invokeScalaObject4[Int, Int, Int, Int, Array[Double]](statsPackageClass, "fisherExactTest", + registerSCode4( + "fisher_exact_test", + TInt32, + TInt32, + TInt32, + TInt32, + fetStruct.virtualType, + (_, _, _, _, _) => fetStruct.sType, + ) { case (r, cb, _, a: SInt32Value, b: SInt32Value, c: SInt32Value, d: SInt32Value, _) => + val res = cb.newLocal[Array[Double]]( + "fisher_exact_test_res", + Code.invokeScalaObject4[Int, Int, Int, Int, Array[Double]]( + statsPackageClass, + "fisherExactTest", a.value, b.value, c.value, - d.value)) - - fetStruct.constructFromFields(cb, r.region, FastSeq( - EmitValue.present(primitive(cb.memoize(res(0)))), - EmitValue.present(primitive(cb.memoize(res(1)))), - EmitValue.present(primitive(cb.memoize(res(2)))), - EmitValue.present(primitive(cb.memoize(res(3)))) - ), deepCopy = false) + d.value, + ), + ) + + fetStruct.constructFromFields( + cb, + r.region, + FastSeq( + EmitValue.present(primitive(cb.memoize(res(0)))), + EmitValue.present(primitive(cb.memoize(res(1)))), + EmitValue.present(primitive(cb.memoize(res(2)))), + EmitValue.present(primitive(cb.memoize(res(3)))), + ), + deepCopy = false, + ) } - registerSCode4("chi_squared_test", TInt32, TInt32, TInt32, TInt32, chisqStruct.virtualType, - (_, _, _, _, _) => chisqStruct.sType - ) { case (r, cb, rt, a: SInt32Value, b: SInt32Value, c: SInt32Value, d: SInt32Value, _) => - val res = cb.newLocal[Array[Double]]("chi_squared_test_res", - Code.invokeScalaObject4[Int, Int, Int, Int, Array[Double]](statsPackageClass, "chiSquaredTest", + registerSCode4( + "chi_squared_test", + TInt32, + TInt32, + TInt32, + TInt32, + chisqStruct.virtualType, + (_, _, _, _, _) => chisqStruct.sType, + ) { case (r, cb, _, a: SInt32Value, b: SInt32Value, c: SInt32Value, d: SInt32Value, _) => + val res = cb.newLocal[Array[Double]]( + "chi_squared_test_res", + Code.invokeScalaObject4[Int, Int, Int, Int, Array[Double]]( + statsPackageClass, + "chiSquaredTest", a.value, b.value, c.value, - d.value)) - - chisqStruct.constructFromFields(cb, r.region, FastSeq( - 
EmitValue.present(primitive(cb.memoize(res(0)))), - EmitValue.present(primitive(cb.memoize(res(1)))) - ), deepCopy = false) + d.value, + ), + ) + + chisqStruct.constructFromFields( + cb, + r.region, + FastSeq( + EmitValue.present(primitive(cb.memoize(res(0)))), + EmitValue.present(primitive(cb.memoize(res(1)))), + ), + deepCopy = false, + ) } - registerSCode5("contingency_table_test", TInt32, TInt32, TInt32, TInt32, TInt32, chisqStruct.virtualType, - (_, _, _, _, _, _) => chisqStruct.sType - ) { case (r, cb, rt, a: SInt32Value, b: SInt32Value, c: SInt32Value, d: SInt32Value, mcc: SInt32Value, _) => - val res = cb.newLocal[Array[Double]]("contingency_table_test_res", - Code.invokeScalaObject5[Int, Int, Int, Int, Int, Array[Double]](statsPackageClass, "contingencyTableTest", - a.value, - b.value, - c.value, - d.value, - mcc.value)) + registerSCode5( + "contingency_table_test", + TInt32, + TInt32, + TInt32, + TInt32, + TInt32, + chisqStruct.virtualType, + (_, _, _, _, _, _) => chisqStruct.sType, + ) { + case ( + r, + cb, + _, + a: SInt32Value, + b: SInt32Value, + c: SInt32Value, + d: SInt32Value, + mcc: SInt32Value, + _, + ) => + val res = cb.newLocal[Array[Double]]( + "contingency_table_test_res", + Code.invokeScalaObject5[Int, Int, Int, Int, Int, Array[Double]]( + statsPackageClass, + "contingencyTableTest", + a.value, + b.value, + c.value, + d.value, + mcc.value, + ), + ) - chisqStruct.constructFromFields(cb, r.region, FastSeq( - EmitValue.present(primitive(cb.memoize(res(0)))), - EmitValue.present(primitive(cb.memoize(res(1)))) - ), deepCopy = false) + chisqStruct.constructFromFields( + cb, + r.region, + FastSeq( + EmitValue.present(primitive(cb.memoize(res(0)))), + EmitValue.present(primitive(cb.memoize(res(1)))), + ), + deepCopy = false, + ) } - registerSCode4("hardy_weinberg_test", TInt32, TInt32, TInt32, TBoolean, hweStruct.virtualType, - (_, _, _, _, _) => hweStruct.sType - ) { case (r, cb, rt, nHomRef: SInt32Value, nHet: SInt32Value, nHomVar: SInt32Value, oneSided: SBooleanValue, _) => - val res = cb.newLocal[Array[Double]]("hardy_weinberg_test_res", - Code.invokeScalaObject4[Int, Int, Int, Boolean, Array[Double]](statsPackageClass, "hardyWeinbergTest", - nHomRef.value, - nHet.value, - nHomVar.value, - oneSided.value)) - - hweStruct.constructFromFields(cb, r.region, FastSeq( - EmitValue.present(primitive(cb.memoize(res(0)))), - EmitValue.present(primitive(cb.memoize(res(1)))) - ), deepCopy = false) + registerSCode4( + "hardy_weinberg_test", + TInt32, + TInt32, + TInt32, + TBoolean, + hweStruct.virtualType, + (_, _, _, _, _) => hweStruct.sType, + ) { + case ( + r, + cb, + _, + nHomRef: SInt32Value, + nHet: SInt32Value, + nHomVar: SInt32Value, + oneSided: SBooleanValue, + _, + ) => + val res = cb.newLocal[Array[Double]]( + "hardy_weinberg_test_res", + Code.invokeScalaObject4[Int, Int, Int, Boolean, Array[Double]]( + statsPackageClass, + "hardyWeinbergTest", + nHomRef.value, + nHet.value, + nHomVar.value, + oneSided.value, + ), + ) + + hweStruct.constructFromFields( + cb, + r.region, + FastSeq( + EmitValue.present(primitive(cb.memoize(res(0)))), + EmitValue.present(primitive(cb.memoize(res(1)))), + ), + deepCopy = false, + ) } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/MatrixWriteBlockMatrix.scala b/hail/src/main/scala/is/hail/expr/ir/functions/MatrixWriteBlockMatrix.scala index 056e8d6bb0f..d57a58f52d8 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/MatrixWriteBlockMatrix.scala +++ 
b/hail/src/main/scala/is/hail/expr/ir/functions/MatrixWriteBlockMatrix.scala
@@ -1,17 +1,23 @@
 package is.hail.expr.ir.functions
 
-import java.io.DataOutputStream
-import is.hail.HailContext
 import is.hail.backend.ExecuteContext
 import is.hail.expr.ir.MatrixValue
-import is.hail.types.{MatrixType, RPrimitive, RTable, TypeWithRequiredness}
-import is.hail.types.virtual.{TVoid, Type}
 import is.hail.linalg.{BlockMatrix, BlockMatrixMetadata, GridPartitioner, WriteBlocksRDD}
 import is.hail.utils._
+
+import java.io.DataOutputStream
+
 import org.json4s.jackson
 
 object MatrixWriteBlockMatrix {
-  def apply(ctx: ExecuteContext, mv: MatrixValue, entryField: String, path: String, overwrite: Boolean, blockSize: Int): Unit = {
+  def apply(
+    ctx: ExecuteContext,
+    mv: MatrixValue,
+    entryField: String,
+    path: String,
+    overwrite: Boolean,
+    blockSize: Int,
+  ): Unit = {
     val rvd = mv.rvd
 
     // FIXME
@@ -48,7 +54,8 @@ object MatrixWriteBlockMatrix {
       implicit val formats = defaultJSONFormats
       jackson.Serialization.write(
         BlockMatrixMetadata(blockSize, nRows, localNCols, gp.partitionIndexToBlockIndex, partFiles),
-        os)
+        os,
+      )
     }
 
     assert(blockCount == gp.numPartitions)
@@ -56,4 +63,4 @@ object MatrixWriteBlockMatrix {
 
     using(fs.create(path + "/_SUCCESS"))(out => ())
   }
-}
\ No newline at end of file
+}
diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/NDArrayFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/NDArrayFunctions.scala
index 765cfa0d3cb..a9361125f1f 100644
--- a/hail/src/main/scala/is/hail/expr/ir/functions/NDArrayFunctions.scala
+++ b/hail/src/main/scala/is/hail/expr/ir/functions/NDArrayFunctions.scala
@@ -2,33 +2,41 @@ package is.hail.expr.ir.functions
 
 import is.hail.annotations.{Memory, Region}
 import is.hail.asm4s._
-import is.hail.expr.ir._
 import is.hail.expr.{Nat, NatVariable}
+import is.hail.expr.ir._
 import is.hail.linalg.{LAPACK, LinalgCodeUtils}
-import is.hail.types.tcoerce
+import is.hail.types.physical._
 import is.hail.types.physical.stypes.EmitType
-import is.hail.types.physical.stypes.concrete.{SBaseStructPointer, SNDArrayPointer, SNDArrayPointerValue}
+import is.hail.types.physical.stypes.concrete.{
+  SBaseStructPointer, SNDArrayPointer, SNDArrayPointerValue,
+}
 import is.hail.types.physical.stypes.interfaces._
 import is.hail.types.physical.stypes.primitives._
-import is.hail.types.physical._
+import is.hail.types.tcoerce
 import is.hail.types.virtual._
-import is.hail.utils._
 
-object NDArrayFunctions extends RegistryFunctions {
-  override def registerAll() {
+object NDArrayFunctions extends RegistryFunctions {
+  override def registerAll(): Unit = {
     for ((stringOp, argType, retType, irOp) <- ArrayFunctions.arrayOps) {
       val nDimVar = NatVariable()
-      registerIR2(stringOp, TNDArray(argType, nDimVar), argType, TNDArray(retType, nDimVar)) { (_, a, c, errorID) =>
-        val i = genUID()
-        NDArrayMap(a, i, irOp(Ref(i, c.typ), c, errorID))
+      registerIR2(stringOp, TNDArray(argType, nDimVar), argType, TNDArray(retType, nDimVar)) {
+        (_, a, c, errorID) =>
+          val i = genUID()
+          NDArrayMap(a, i, irOp(Ref(i, c.typ), c, errorID))
       }
-      registerIR2(stringOp, argType, TNDArray(argType, nDimVar), TNDArray(retType, nDimVar)) { (_, c, a, errorID) =>
-        val i = genUID()
-        NDArrayMap(a, i, irOp(c, Ref(i, c.typ), errorID))
+      registerIR2(stringOp, argType, TNDArray(argType, nDimVar), TNDArray(retType, nDimVar)) {
+        (_, c, a, errorID) =>
+          val i = genUID()
+          NDArrayMap(a, i, irOp(c, Ref(i, c.typ), errorID))
       }
-      registerIR2(stringOp, TNDArray(argType, nDimVar), TNDArray(argType, nDimVar),
TNDArray(retType, nDimVar)) { (_, l, r, errorID) => + registerIR2( + stringOp, + TNDArray(argType, nDimVar), + TNDArray(argType, nDimVar), + TNDArray(retType, nDimVar), + ) { (_, l, r, errorID) => val lid = genUID() val rid = genUID() val lElemRef = Ref(lid, tcoerce[TNDArray](l.typ).elementType) @@ -38,15 +46,29 @@ object NDArrayFunctions extends RegistryFunctions { } } - def linear_triangular_solve(ndCoef: SNDArrayValue, ndDep: SNDArrayValue, lower: SBooleanValue, outputPt: PType, cb: EmitCodeBuilder, region: Value[Region], errorID: Value[Int]): (SNDArrayValue, Value[Int]) = { + def linear_triangular_solve( + ndCoef: SNDArrayValue, + ndDep: SNDArrayValue, + lower: SBooleanValue, + outputPt: PType, + cb: EmitCodeBuilder, + region: Value[Region], + errorID: Value[Int], + ): (SNDArrayValue, Value[Int]) = { val ndCoefColMajor = LinalgCodeUtils.checkColMajorAndCopyIfNeeded(ndCoef, cb, region) val ndDepColMajor = LinalgCodeUtils.checkColMajorAndCopyIfNeeded(ndDep, cb, region) val IndexedSeq(ndCoefRow, ndCoefCol) = ndCoefColMajor.shapes - cb.if_(ndCoefRow cne ndCoefCol, cb._fatalWithError(errorID, "hail.nd.solve_triangular: matrix a must be square.")) + cb.if_( + ndCoefRow cne ndCoefCol, + cb._fatalWithError(errorID, "hail.nd.solve_triangular: matrix a must be square."), + ) val IndexedSeq(ndDepRow, ndDepCol) = ndDepColMajor.shapes - cb.if_(ndCoefRow cne ndDepRow, cb._fatalWithError(errorID,"hail.nd.solve_triangular: Solve dimensions incompatible")) + cb.if_( + ndCoefRow cne ndDepRow, + cb._fatalWithError(errorID, "hail.nd.solve_triangular: Solve dimensions incompatible"), + ) val uplo = cb.newLocal[String]("dtrtrs_uplo") cb.if_(lower.value, cb.assign(uplo, const("L")), cb.assign(uplo, const("U"))) @@ -56,22 +78,34 @@ object NDArrayFunctions extends RegistryFunctions { val outputPType = tcoerce[PCanonicalNDArray](outputPt) val output = outputPType.constructByActuallyCopyingData(ndDepColMajor, cb, region) - cb.assign(infoDTRTRSResult, Code.invokeScalaObject9[String, String, String, Int, Int, Long, Int, Long, Int, Int](LAPACK.getClass, "dtrtrs", - uplo, - const("N"), - const("N"), - ndDepRow.toI, - ndDepCol.toI, - ndCoefColMajor.firstDataAddress, - ndDepRow.toI, - output.firstDataAddress, - ndDepRow.toI - )) + cb.assign( + infoDTRTRSResult, + Code.invokeScalaObject9[String, String, String, Int, Int, Long, Int, Long, Int, Int]( + LAPACK.getClass, + "dtrtrs", + uplo, + const("N"), + const("N"), + ndDepRow.toI, + ndDepCol.toI, + ndCoefColMajor.firstDataAddress, + ndDepRow.toI, + output.firstDataAddress, + ndDepRow.toI, + ), + ) (output, infoDTRTRSResult) } - def linear_solve(a: SNDArrayValue, b: SNDArrayValue, outputPt: PType, cb: EmitCodeBuilder, region: Value[Region], errorID: Value[Int]): (SNDArrayValue, Value[Int]) = { + def linear_solve( + a: SNDArrayValue, + b: SNDArrayValue, + outputPt: PType, + cb: EmitCodeBuilder, + region: Value[Region], + errorID: Value[Int], + ): (SNDArrayValue, Value[Int]) = { val aColMajor = LinalgCodeUtils.checkColMajorAndCopyIfNeeded(a, cb, region) val bColMajor = LinalgCodeUtils.checkColMajorAndCopyIfNeeded(b, cb, region) @@ -98,19 +132,29 @@ object NDArrayFunctions extends RegistryFunctions { val outputPType = tcoerce[PCanonicalNDArray](outputPt) val outputShape = IndexedSeq(n, nrhs) - val (outputAddress, outputFinisher) = outputPType.constructDataFunction(outputShape, outputPType.makeColumnMajorStrides(outputShape, cb), cb, region) + val (outputAddress, outputFinisher) = outputPType.constructDataFunction( + outputShape, + 
outputPType.makeColumnMajorStrides(outputShape, cb), + cb, + region, + ) cb.append(Region.copyFrom(bColMajor.firstDataAddress, outputAddress, n * nrhs * 8L)) - cb.assign(infoDGESVResult, Code.invokeScalaObject7[Int, Int, Long, Int, Long, Long, Int, Int](LAPACK.getClass, "dgesv", - n.toI, - nrhs.toI, - aCopy, - n.toI, - ipiv, - outputAddress, - n.toI - )) + cb.assign( + infoDGESVResult, + Code.invokeScalaObject7[Int, Int, Long, Int, Long, Long, Int, Int]( + LAPACK.getClass, + "dgesv", + n.toI, + nrhs.toI, + aCopy, + n.toI, + ipiv, + outputAddress, + n.toI, + ), + ) cb.append(Code.invokeStatic1[Memory, Long, Unit]("free", ipiv.load())) cb.append(Code.invokeStatic1[Memory, Long, Unit]("free", aCopy.load())) @@ -118,129 +162,274 @@ object NDArrayFunctions extends RegistryFunctions { (outputFinisher(cb), infoDGESVResult) } - registerIEmitCode2("linear_solve_no_crash", TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(2)), TStruct(("solution", TNDArray(TFloat64, Nat(2))), ("failed", TBoolean)), - { (t, p1, p2) => EmitType(PCanonicalStruct(false, ("solution", PCanonicalNDArray(PFloat64Required, 2, false)), ("failed", PBooleanRequired)).sType, false) }) { - case (cb, region, SBaseStructPointer(outputStructType: PCanonicalStruct), errorID, aec, bec) => + registerIEmitCode2( + "linear_solve_no_crash", + TNDArray(TFloat64, Nat(2)), + TNDArray(TFloat64, Nat(2)), + TStruct(("solution", TNDArray(TFloat64, Nat(2))), ("failed", TBoolean)), + (t, p1, p2) => + EmitType( + PCanonicalStruct( + false, + ("solution", PCanonicalNDArray(PFloat64Required, 2, false)), + ("failed", PBooleanRequired), + ).sType, + false, + ), + ) { + case ( + cb, + region, + SBaseStructPointer(outputStructType: PCanonicalStruct), + errorID, + aec, + bec, + ) => aec.toI(cb).flatMap(cb) { apc => bec.toI(cb).map(cb) { bpc => val outputNDArrayPType = outputStructType.fieldType("solution") - val (resNDPCode, info) = linear_solve(apc.asNDArray, bpc.asNDArray, outputNDArrayPType, cb, region, errorID) + val (resNDPCode, info) = + linear_solve(apc.asNDArray, bpc.asNDArray, outputNDArrayPType, cb, region, errorID) val ndEmitCode = EmitCode(Code._empty, info cne 0, resNDPCode) - outputStructType.constructFromFields(cb, region, IndexedSeq[EmitCode](ndEmitCode, EmitCode(Code._empty, false, primitive(cb.memoize(info cne 0)))), false) + outputStructType.constructFromFields( + cb, + region, + IndexedSeq[EmitCode]( + ndEmitCode, + EmitCode(Code._empty, false, primitive(cb.memoize(info cne 0))), + ), + false, + ) } } } - registerSCode2("linear_solve", TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(2)), - { (t, p1, p2) => PCanonicalNDArray(PFloat64Required, 2, true).sType }) { + registerSCode2( + "linear_solve", + TNDArray(TFloat64, Nat(2)), + TNDArray(TFloat64, Nat(2)), + TNDArray(TFloat64, Nat(2)), + (t, p1, p2) => PCanonicalNDArray(PFloat64Required, 2, true).sType, + ) { case (er, cb, SNDArrayPointer(pt), apc, bpc, errorID) => - val (resPCode, info) = linear_solve(apc.asNDArray, bpc.asNDArray, pt, cb, er.region, errorID) - cb.if_(info cne 0, cb._fatalWithError(errorID,s"hl.nd.solve: Could not solve, matrix was singular. dgesv error code ", info.toS)) + val (resPCode, info) = + linear_solve(apc.asNDArray, bpc.asNDArray, pt, cb, er.region, errorID) + cb.if_( + info cne 0, + cb._fatalWithError( + errorID, + s"hl.nd.solve: Could not solve, matrix was singular. 
dgesv error code ", + info.toS, + ), + ) resPCode } - registerIEmitCode3("linear_triangular_solve_no_crash", TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(2)), TBoolean, TStruct(("solution", TNDArray(TFloat64, Nat(2))), ("failed", TBoolean)), - { (t, p1, p2, p3) => EmitType(PCanonicalStruct(false, ("solution", PCanonicalNDArray(PFloat64Required, 2, false)), ("failed", PBooleanRequired)).sType, false) }) { - case (cb, region, SBaseStructPointer(outputStructType: PCanonicalStruct), errorID, aec, bec, lowerec) => + registerIEmitCode3( + "linear_triangular_solve_no_crash", + TNDArray(TFloat64, Nat(2)), + TNDArray(TFloat64, Nat(2)), + TBoolean, + TStruct(("solution", TNDArray(TFloat64, Nat(2))), ("failed", TBoolean)), + (t, p1, p2, p3) => + EmitType( + PCanonicalStruct( + false, + ("solution", PCanonicalNDArray(PFloat64Required, 2, false)), + ("failed", PBooleanRequired), + ).sType, + false, + ), + ) { + case ( + cb, + region, + SBaseStructPointer(outputStructType: PCanonicalStruct), + errorID, + aec, + bec, + lowerec, + ) => aec.toI(cb).flatMap(cb) { apc => bec.toI(cb).flatMap(cb) { bpc => lowerec.toI(cb).map(cb) { lowerpc => val outputNDArrayPType = outputStructType.fieldType("solution") - val (resNDPCode, info) = linear_triangular_solve(apc.asNDArray, bpc.asNDArray, lowerpc.asBoolean, outputNDArrayPType, cb, region, errorID) + val (resNDPCode, info) = linear_triangular_solve( + apc.asNDArray, + bpc.asNDArray, + lowerpc.asBoolean, + outputNDArrayPType, + cb, + region, + errorID, + ) val ndEmitCode = EmitCode(Code._empty, info cne 0, resNDPCode) - outputStructType.constructFromFields(cb, region, IndexedSeq[EmitCode](ndEmitCode, EmitCode(Code._empty, false, primitive(cb.memoize(info cne 0)))), false) + outputStructType.constructFromFields( + cb, + region, + IndexedSeq[EmitCode]( + ndEmitCode, + EmitCode(Code._empty, false, primitive(cb.memoize(info cne 0))), + ), + false, + ) } } } } - registerSCode3("linear_triangular_solve", TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(2)), TBoolean, TNDArray(TFloat64, Nat(2)), - { (t, p1, p2, p3) => PCanonicalNDArray(PFloat64Required, 2, true).sType }) { + registerSCode3( + "linear_triangular_solve", + TNDArray(TFloat64, Nat(2)), + TNDArray(TFloat64, Nat(2)), + TBoolean, + TNDArray(TFloat64, Nat(2)), + (t, p1, p2, p3) => PCanonicalNDArray(PFloat64Required, 2, true).sType, + ) { case (er, cb, SNDArrayPointer(pt), apc, bpc, lower, errorID) => - val (resPCode, info) = linear_triangular_solve(apc.asNDArray, bpc.asNDArray, lower.asBoolean, pt, cb, er.region, errorID) - cb.if_(info cne 0, cb._fatalWithError(errorID,s"hl.nd.solve: Could not solve, matrix was singular. dtrtrs error code ", info.toS)) + val (resPCode, info) = linear_triangular_solve( + apc.asNDArray, + bpc.asNDArray, + lower.asBoolean, + pt, + cb, + er.region, + errorID, + ) + cb.if_( + info cne 0, + cb._fatalWithError( + errorID, + s"hl.nd.solve: Could not solve, matrix was singular. 
dtrtrs error code ", + info.toS, + ), + ) resPCode } - registerSCode3("zero_band", TNDArray(TFloat64, Nat(2)), TInt64, TInt64, TNDArray(TFloat64, Nat(2)), - { (_, _, _, _) => PCanonicalNDArray(PFloat64Required, 2, true).sType }) { - case (er, cb, rst: SNDArrayPointer, block: SNDArrayValue, lower: SInt64Value, upper: SInt64Value, errorID) => - val newBlock = rst.coerceOrCopy(cb, er.region, block, deepCopy = false).asInstanceOf[SNDArrayPointerValue] + registerSCode3( + "zero_band", + TNDArray(TFloat64, Nat(2)), + TInt64, + TInt64, + TNDArray(TFloat64, Nat(2)), + (_, _, _, _) => PCanonicalNDArray(PFloat64Required, 2, true).sType, + ) { + case ( + er, + cb, + rst: SNDArrayPointer, + block: SNDArrayValue, + lower: SInt64Value, + upper: SInt64Value, + _, + ) => + val newBlock = rst.coerceOrCopy(cb, er.region, block, deepCopy = false).asInstanceOf[ + SNDArrayPointerValue + ] val IndexedSeq(nRows, nCols) = newBlock.shapes - val lowestDiagIndex = cb.memoize(- (nRows.get - 1L)) + val lowestDiagIndex = cb.memoize(-(nRows.get - 1L)) val highestDiagIndex = cb.memoize(nCols.get - 1L) val iLeft = cb.newLocal[Long]("iLeft") val iRight = cb.newLocal[Long]("iRight") val i = cb.newLocal[Long]("i") val j = cb.newLocal[Long]("j") - cb.if_(lower.value > lowestDiagIndex, { - cb.assign(iLeft, (-lower.value).max(0L)) - cb.assign(iRight, (nCols.get - lower.value).min(nRows.get)) - - cb.for_({ - cb.assign(i, iLeft) - cb.assign(j, lower.value.max(0L)) - }, i < iRight, { - cb.assign(i, i + 1L) - cb.assign(j, j + 1L) - }, { - // block(i to i, 0 until j) := 0.0 - newBlock.slice(cb, i, (null, j)).coiterateMutate(cb, er.region) { _ => + cb.if_( + lower.value > lowestDiagIndex, { + cb.assign(iLeft, (-lower.value).max(0L)) + cb.assign(iRight, (nCols.get - lower.value).min(nRows.get)) + + cb.for_( + { + cb.assign(i, iLeft) + cb.assign(j, lower.value.max(0L)) + }, + i < iRight, { + cb.assign(i, i + 1L) + cb.assign(j, j + 1L) + }, + // block(i to i, 0 until j) := 0.0 + newBlock.slice(cb, i, (null, j)).coiterateMutate(cb, er.region) { _ => + primitive(0.0d) + }, + ) + + // block(iRight until nRows, ::) := 0.0 + newBlock.slice(cb, (iRight, null), ColonIndex).coiterateMutate(cb, er.region) { _ => primitive(0.0d) } - }) - - // block(iRight until nRows, ::) := 0.0 - newBlock.slice(cb, (iRight, null), ColonIndex).coiterateMutate(cb, er.region) { _ => - primitive(0.0d) - } - }) + }, + ) - cb.if_(upper.value < highestDiagIndex, { - cb.assign(iLeft, (-upper.value).max(0L)) - cb.assign(iRight, (nCols.get - upper.value).min(nRows.get)) - - // block(0 util iLeft, ::) := 0.0 - newBlock.slice(cb, (null, iLeft), ColonIndex).coiterateMutate(cb, er.region) { _ => - primitive(0.0d) - } + cb.if_( + upper.value < highestDiagIndex, { + cb.assign(iLeft, (-upper.value).max(0L)) + cb.assign(iRight, (nCols.get - upper.value).min(nRows.get)) - cb.for_({ - cb.assign(i, iLeft) - cb.assign(j, upper.value.max(0L) + 1) - }, i < iRight, { - cb.assign(i, i + 1) - cb.assign(j, j + 1) - }, { - // block(i to i, j to nCols) := 0.0 - newBlock.slice(cb, i, (j, null)).coiterateMutate(cb, er.region) { _ => + // block(0 util iLeft, ::) := 0.0 + newBlock.slice(cb, (null, iLeft), ColonIndex).coiterateMutate(cb, er.region) { _ => primitive(0.0d) } - }) - }) + + cb.for_( + { + cb.assign(i, iLeft) + cb.assign(j, upper.value.max(0L) + 1) + }, + i < iRight, { + cb.assign(i, i + 1) + cb.assign(j, j + 1) + }, + // block(i to i, j to nCols) := 0.0 + newBlock.slice(cb, i, (j, null)).coiterateMutate(cb, er.region) { _ => + primitive(0.0d) + }, + ) + }, + ) newBlock } - 
registerSCode3("zero_row_intervals", TNDArray(TFloat64, Nat(2)), TArray(TInt64), TArray(TInt64), TNDArray(TFloat64, Nat(2)), - { (_, _, _, _) => PCanonicalNDArray(PFloat64Required, 2, true).sType }) { - case (er, cb, rst: SNDArrayPointer, block: SNDArrayValue, starts: SIndexableValue, stops: SIndexableValue, errorID) => - val newBlock = rst.coerceOrCopy(cb, er.region, block, deepCopy = false).asInstanceOf[SNDArrayPointerValue] + registerSCode3( + "zero_row_intervals", + TNDArray(TFloat64, Nat(2)), + TArray(TInt64), + TArray(TInt64), + TNDArray(TFloat64, Nat(2)), + (_, _, _, _) => PCanonicalNDArray(PFloat64Required, 2, true).sType, + ) { + case ( + er, + cb, + rst: SNDArrayPointer, + block: SNDArrayValue, + starts: SIndexableValue, + stops: SIndexableValue, + _, + ) => + val newBlock = rst.coerceOrCopy(cb, er.region, block, deepCopy = false).asInstanceOf[ + SNDArrayPointerValue + ] val row = cb.newLocal[Long]("rowIdx") - val IndexedSeq(nRows, nCols) = newBlock.shapes - cb.for_(cb.assign(row, 0L), row < nRows.get, cb.assign(row, row + 1L), { - val start = starts.loadElement(cb, row.toI).get(cb).asInt64.value - val stop = stops.loadElement(cb, row.toI).get(cb).asInt64.value - newBlock.slice(cb, row, (null, start)).coiterateMutate(cb, er.region) { _ => - primitive(0.0d) - } - newBlock.slice(cb, row, (stop, null)).coiterateMutate(cb, er.region) { _ => - primitive(0.0d) - } - }) + val IndexedSeq(nRows, _) = newBlock.shapes + cb.for_( + cb.assign(row, 0L), + row < nRows.get, + cb.assign(row, row + 1L), { + val start = starts.loadElement(cb, row.toI).getOrAssert(cb).asInt64.value + val stop = stops.loadElement(cb, row.toI).getOrAssert(cb).asInt64.value + newBlock.slice(cb, row, (null, start)).coiterateMutate(cb, er.region) { _ => + primitive(0.0d) + } + newBlock.slice(cb, row, (stop, null)).coiterateMutate(cb, er.region) { _ => + primitive(0.0d) + } + }, + ) newBlock - } + } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/RandomSeededFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/RandomSeededFunctions.scala index 5275ef53319..9ba56ad97b6 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/RandomSeededFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/RandomSeededFunctions.scala @@ -3,19 +3,21 @@ package is.hail.expr.ir.functions import is.hail.asm4s._ import is.hail.expr.Nat import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode} +import is.hail.types.physical.{PCanonicalArray, PCanonicalNDArray, PFloat64, PInt32} import is.hail.types.physical.stypes._ import is.hail.types.physical.stypes.concrete.{SIndexablePointer, SNDArrayPointer, SRNGStateValue} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives._ -import is.hail.types.physical.{PCanonicalArray, PCanonicalNDArray, PFloat64, PInt32} import is.hail.types.virtual._ import is.hail.utils.FastSeq -import net.sourceforge.jdistlib.rng.MersenneTwister + import net.sourceforge.jdistlib.{Beta, Gamma, HyperGeometric, Poisson} +import net.sourceforge.jdistlib.rng.MersenneTwister class IRRandomness(seed: Long) { - // org.apache.commons has no way to statically sample from distributions without creating objects :( + /* org.apache.commons has no way to statically sample from distributions without creating objects + * :( */ private[this] val random = new MersenneTwister() private[this] val poisState = Poisson.create_random_state() @@ -23,7 +25,7 @@ class IRRandomness(seed: Long) { private[this] def hash(pidx: Int): Long = seed ^ java.lang.Math.floorMod(pidx * 
11399L, 2147483647L) - def reset(partitionIdx: Int) { + def reset(partitionIdx: Int): Unit = { val combinedSeed = hash(partitionIdx) random.setSeed(combinedSeed) } @@ -72,8 +74,13 @@ object RandomSeededFunctions extends RegistryFunctions { def rand_unif(cb: EmitCodeBuilder, rand_longs: IndexedSeq[Value[Long]]): Code[Double] = { assert(rand_longs.length == 4) - Code.invokeScalaObject4[Long, Long, Long, Long, Double](RandomSeededFunctions.getClass, "_rand_unif", - rand_longs(0), rand_longs(1), rand_longs(2), rand_longs(3) + Code.invokeScalaObject4[Long, Long, Long, Long, Double]( + RandomSeededFunctions.getClass, + "_rand_unif", + rand_longs(0), + rand_longs(1), + rand_longs(2), + rand_longs(3), ) } @@ -104,177 +111,409 @@ object RandomSeededFunctions extends RegistryFunctions { ) } - def registerAll() { - registerSCode3("rand_unif", TRNGState, TFloat64, TFloat64, TFloat64, { - case (_: Type, _: SType, _: SType, _: SType) => SFloat64 - }) { case (_, cb, rt, rngState: SRNGStateValue, min: SFloat64Value, max: SFloat64Value, errorID) => - primitive(cb.memoize(rand_unif(cb, rngState.rand(cb)) * (max.value - min.value) + min.value)) + def registerAll(): Unit = { + registerSCode3( + "rand_unif", + TRNGState, + TFloat64, + TFloat64, + TFloat64, + { + case (_: Type, _: SType, _: SType, _: SType) => SFloat64 + }, + ) { + case (_, cb, _, rngState: SRNGStateValue, min: SFloat64Value, max: SFloat64Value, _) => + primitive(cb.memoize(rand_unif( + cb, + rngState.rand(cb), + ) * (max.value - min.value) + min.value)) } - registerSCode5("rand_unif_nd", TRNGState, TInt64, TInt64, TFloat64, TFloat64, TNDArray(TFloat64, Nat(2)), { - case (_: Type, _: SType, _: SType, _: SType, _: SType, _: SType) => PCanonicalNDArray(PFloat64(true), 2, true).sType - }) { case (r, cb, rt: SNDArrayPointer, rngState: SRNGStateValue, nRows: SInt64Value, nCols: SInt64Value, min, max, errorID) => - val result = rt.pType.constructUninitialized(FastSeq(SizeValueDyn(nRows.value), SizeValueDyn(nCols.value)), cb, r.region) - val rng = cb.emb.getThreefryRNG() - rngState.copyIntoEngine(cb, rng) - result.coiterateMutate(cb, r.region) { _ => - primitive(cb.memoize(rng.invoke[Double, Double, Double]("runif", min.asDouble.value, max.asDouble.value))) - } - result + registerSCode5( + "rand_unif_nd", + TRNGState, + TInt64, + TInt64, + TFloat64, + TFloat64, + TNDArray(TFloat64, Nat(2)), + { + case (_: Type, _: SType, _: SType, _: SType, _: SType, _: SType) => + PCanonicalNDArray(PFloat64(true), 2, true).sType + }, + ) { + case ( + r, + cb, + rt: SNDArrayPointer, + rngState: SRNGStateValue, + nRows: SInt64Value, + nCols: SInt64Value, + min, + max, + _, + ) => + val result = rt.pType.constructUninitialized( + FastSeq(SizeValueDyn(nRows.value), SizeValueDyn(nCols.value)), + cb, + r.region, + ) + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + result.coiterateMutate(cb, r.region) { _ => + primitive(cb.memoize(rng.invoke[Double, Double, Double]( + "runif", + min.asDouble.value, + max.asDouble.value, + ))) + } + result } - registerSCode2("rand_int32", TRNGState, TInt32, TInt32, { - case (_: Type, _: SType, _: SType) => SInt32 - }) { case (r, cb, rt, rngState: SRNGStateValue, n: SInt32Value, errorID) => + registerSCode2( + "rand_int32", + TRNGState, + TInt32, + TInt32, + { + case (_: Type, _: SType, _: SType) => SInt32 + }, + ) { case (_, cb, _, rngState: SRNGStateValue, n: SInt32Value, _) => val rng = cb.emb.getThreefryRNG() rngState.copyIntoEngine(cb, rng) primitive(cb.memoize(rng.invoke[Int, Int]("nextInt", n.value))) } - 
registerSCode2("rand_int64", TRNGState, TInt64, TInt64, { - case (_: Type, _: SType, _: SType) => SInt64 - }) { case (r, cb, rt, rngState: SRNGStateValue, n: SInt64Value, errorID) => + registerSCode2( + "rand_int64", + TRNGState, + TInt64, + TInt64, + { + case (_: Type, _: SType, _: SType) => SInt64 + }, + ) { case (_, cb, _, rngState: SRNGStateValue, n: SInt64Value, _) => val rng = cb.emb.getThreefryRNG() rngState.copyIntoEngine(cb, rng) primitive(cb.memoize(rng.invoke[Long, Long]("nextLong", n.value))) } - registerSCode1("rand_int64", TRNGState, TInt64, { - case (_: Type, _: SType) => SInt64 - }) { case (r, cb, rt, rngState: SRNGStateValue, errorID) => + registerSCode1( + "rand_int64", + TRNGState, + TInt64, + { + case (_: Type, _: SType) => SInt64 + }, + ) { case (_, cb, _, rngState: SRNGStateValue, _) => primitive(rngState.rand(cb)(0)) } - registerSCode5("rand_norm_nd", TRNGState, TInt64, TInt64, TFloat64, TFloat64, TNDArray(TFloat64, Nat(2)), { - case (_: Type, _: SType, _: SType, _: SType, _: SType, _: SType) => PCanonicalNDArray(PFloat64(true), 2, true).sType - }) { case (r, cb, rt: SNDArrayPointer, rngState: SRNGStateValue, nRows: SInt64Value, nCols: SInt64Value, mean, sd, errorID) => - val result = rt.pType.constructUninitialized(FastSeq(SizeValueDyn(nRows.value), SizeValueDyn(nCols.value)), cb, r.region) - val rng = cb.emb.getThreefryRNG() - rngState.copyIntoEngine(cb, rng) - result.coiterateMutate(cb, r.region) { _ => - primitive(cb.memoize(rng.invoke[Double, Double, Double]("rnorm", mean.asDouble.value, sd.asDouble.value))) - } - result + registerSCode5( + "rand_norm_nd", + TRNGState, + TInt64, + TInt64, + TFloat64, + TFloat64, + TNDArray(TFloat64, Nat(2)), + { + case (_: Type, _: SType, _: SType, _: SType, _: SType, _: SType) => + PCanonicalNDArray(PFloat64(true), 2, true).sType + }, + ) { + case ( + r, + cb, + rt: SNDArrayPointer, + rngState: SRNGStateValue, + nRows: SInt64Value, + nCols: SInt64Value, + mean, + sd, + _, + ) => + val result = rt.pType.constructUninitialized( + FastSeq(SizeValueDyn(nRows.value), SizeValueDyn(nCols.value)), + cb, + r.region, + ) + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + result.coiterateMutate(cb, r.region) { _ => + primitive(cb.memoize(rng.invoke[Double, Double, Double]( + "rnorm", + mean.asDouble.value, + sd.asDouble.value, + ))) + } + result } - registerSCode3("rand_norm", TRNGState, TFloat64, TFloat64, TFloat64, { - case (_: Type, _: SType, _: SType, _: SType) => SFloat64 - }) { case (_, cb, rt, rngState: SRNGStateValue, mean: SFloat64Value, sd: SFloat64Value, errorID) => - val rng = cb.emb.getThreefryRNG() - rngState.copyIntoEngine(cb, rng) - primitive(cb.memoize(rng.invoke[Double, Double, Double]("rnorm", mean.value, sd.value))) + registerSCode3( + "rand_norm", + TRNGState, + TFloat64, + TFloat64, + TFloat64, + { + case (_: Type, _: SType, _: SType, _: SType) => SFloat64 + }, + ) { + case (_, cb, _, rngState: SRNGStateValue, mean: SFloat64Value, sd: SFloat64Value, _) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + primitive(cb.memoize(rng.invoke[Double, Double, Double]("rnorm", mean.value, sd.value))) } - registerSCode2("rand_bool", TRNGState, TFloat64, TBoolean, { - case (_: Type, _: SType, _: SType) => SBoolean - }) { case (_, cb, rt, rngState: SRNGStateValue, p: SFloat64Value, errorID) => + registerSCode2( + "rand_bool", + TRNGState, + TFloat64, + TBoolean, + { + case (_: Type, _: SType, _: SType) => SBoolean + }, + ) { case (_, cb, _, rngState: SRNGStateValue, p: SFloat64Value, 
_) => val u = rand_unif(cb, rngState.rand(cb)) primitive(cb.memoize(u < p.value)) } - registerSCode2("rand_pois", TRNGState, TFloat64, TFloat64, { - case (_: Type, _: SType, _: SType) => SFloat64 - }) { case (_, cb, rt, rngState: SRNGStateValue, lambda: SFloat64Value, errorID) => + registerSCode2( + "rand_pois", + TRNGState, + TFloat64, + TFloat64, + { + case (_: Type, _: SType, _: SType) => SFloat64 + }, + ) { case (_, cb, _, rngState: SRNGStateValue, lambda: SFloat64Value, _) => val rng = cb.emb.getThreefryRNG() rngState.copyIntoEngine(cb, rng) primitive(cb.memoize(rng.invoke[Double, Double]("rpois", lambda.value))) } - registerSCode3("rand_pois", TRNGState, TInt32, TFloat64, TArray(TFloat64), { - case (_: Type, _: SType, _: SType, _: SType) => PCanonicalArray(PFloat64(true)).sType - }) { case (r, cb, SIndexablePointer(rt: PCanonicalArray), rngState: SRNGStateValue, n: SInt32Value, lambda: SFloat64Value, errorID) => - val rng = cb.emb.getThreefryRNG() - rngState.copyIntoEngine(cb, rng) - rt.constructFromElements(cb, r.region, n.value, deepCopy = false) { case (cb, _) => - IEmitCode.present(cb, - primitive(cb.memoize(rng.invoke[Double, Double]("rpois", lambda.value))) - ) - } + registerSCode3( + "rand_pois", + TRNGState, + TInt32, + TFloat64, + TArray(TFloat64), + { + case (_: Type, _: SType, _: SType, _: SType) => PCanonicalArray(PFloat64(true)).sType + }, + ) { + case ( + r, + cb, + SIndexablePointer(rt: PCanonicalArray), + rngState: SRNGStateValue, + n: SInt32Value, + lambda: SFloat64Value, + _, + ) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + rt.constructFromElements(cb, r.region, n.value, deepCopy = false) { case (cb, _) => + IEmitCode.present( + cb, + primitive(cb.memoize(rng.invoke[Double, Double]("rpois", lambda.value))), + ) + } } - registerSCode3("rand_beta", TRNGState, TFloat64, TFloat64, TFloat64, { - case (_: Type, _: SType, _: SType, _: SType) => SFloat64 - }) { case (_, cb, rt, rngState: SRNGStateValue, a: SFloat64Value, b: SFloat64Value, errorID) => + registerSCode3( + "rand_beta", + TRNGState, + TFloat64, + TFloat64, + TFloat64, + { + case (_: Type, _: SType, _: SType, _: SType) => SFloat64 + }, + ) { case (_, cb, _, rngState: SRNGStateValue, a: SFloat64Value, b: SFloat64Value, _) => val rng = cb.emb.getThreefryRNG() rngState.copyIntoEngine(cb, rng) primitive(cb.memoize(rng.invoke[Double, Double, Double]("rbeta", a.value, b.value))) } - registerSCode5("rand_beta", TRNGState, TFloat64, TFloat64, TFloat64, TFloat64, TFloat64, { - case (_: Type, _: SType, _: SType, _: SType, _: SType, _: SType) => SFloat64 - }) { case (_, cb, rt, rngState: SRNGStateValue, a: SFloat64Value, b: SFloat64Value, min: SFloat64Value, max: SFloat64Value, errorID) => - val rng = cb.emb.getThreefryRNG() - rngState.copyIntoEngine(cb, rng) - val value = cb.newLocal[Double]("value", rng.invoke[Double, Double, Double]("rbeta", a.value, b.value)) - cb.while_(value < min.value || value > max.value, { - cb.assign(value, rng.invoke[Double, Double, Double]("rbeta", a.value, b.value)) - }) - primitive(value) + registerSCode5( + "rand_beta", + TRNGState, + TFloat64, + TFloat64, + TFloat64, + TFloat64, + TFloat64, + { + case (_: Type, _: SType, _: SType, _: SType, _: SType, _: SType) => SFloat64 + }, + ) { + case ( + _, + cb, + _, + rngState: SRNGStateValue, + a: SFloat64Value, + b: SFloat64Value, + min: SFloat64Value, + max: SFloat64Value, + _, + ) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + val value = cb.newLocal[Double]( + "value", + 
rng.invoke[Double, Double, Double]("rbeta", a.value, b.value), + ) + cb.while_( + value < min.value || value > max.value, + cb.assign(value, rng.invoke[Double, Double, Double]("rbeta", a.value, b.value)), + ) + primitive(value) } - registerSCode3("rand_gamma", TRNGState, TFloat64, TFloat64, TFloat64, { - case (_: Type, _: SType, _: SType, _: SType) => SFloat64 - }) { case (_, cb, rt, rngState: SRNGStateValue, a: SFloat64Value, scale: SFloat64Value, errorID) => - val rng = cb.emb.getThreefryRNG() - rngState.copyIntoEngine(cb, rng) - primitive(cb.memoize(rng.invoke[Double, Double, Double]("rgamma", a.value, scale.value))) + registerSCode3( + "rand_gamma", + TRNGState, + TFloat64, + TFloat64, + TFloat64, + { + case (_: Type, _: SType, _: SType, _: SType) => SFloat64 + }, + ) { + case (_, cb, _, rngState: SRNGStateValue, a: SFloat64Value, scale: SFloat64Value, _) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + primitive(cb.memoize(rng.invoke[Double, Double, Double]("rgamma", a.value, scale.value))) } - registerSCode2("rand_cat", TRNGState, TArray(TFloat64), TInt32, { - case (_: Type, _: SType, _: SType) => SInt32 - }) { case (_, cb, rt, rngState: SRNGStateValue, weights: SIndexableValue, errorID) => + registerSCode2( + "rand_cat", + TRNGState, + TArray(TFloat64), + TInt32, + { + case (_: Type, _: SType, _: SType) => SInt32 + }, + ) { case (_, cb, _, rngState: SRNGStateValue, weights: SIndexableValue, _) => val len = weights.loadLength() val i = cb.newLocal[Int]("i", 0) val s = cb.newLocal[Double]("sum", 0.0) - cb.while_(i < len, { - cb.assign(s, s + weights.loadElement(cb, i).get(cb, "rand_cat requires all elements of input array to be present").asFloat64.value) - cb.assign(i, i + 1) - }) + cb.while_( + i < len, { + cb.assign( + s, + s + weights.loadElement(cb, i).getOrFatal( + cb, + "rand_cat requires all elements of input array to be present", + ).asFloat64.value, + ) + cb.assign(i, i + 1) + }, + ) val r = cb.newLocal[Double]("r", rand_unif(cb, rngState.rand(cb)) * s) cb.assign(i, 0) val elt = cb.newLocal[Double]("elt") cb.loop { start => - cb.assign(elt, weights.loadElement(cb, i).get(cb, "rand_cat requires all elements of input array to be present").asFloat64.value) - cb.if_(r > elt && i < len, { - cb.assign(r, r - elt) - cb.assign(i, i + 1) - cb.goto(start) - }) + cb.assign( + elt, + weights.loadElement(cb, i).getOrFatal( + cb, + "rand_cat requires all elements of input array to be present", + ).asFloat64.value, + ) + cb.if_( + r > elt && i < len, { + cb.assign(r, r - elt) + cb.assign(i, i + 1) + cb.goto(start) + }, + ) } primitive(i) } - registerSCode3("shuffle_compute_num_samples_per_partition", TRNGState, TInt32, TArray(TInt32), TArray(TInt32), - (_, _, _, _) => SIndexablePointer(PCanonicalArray(PInt32(true), false)) - ) { case (r, cb, rt, rngState: SRNGStateValue, initalNumSamplesToSelect: SInt32Value, partitionCounts: SIndexableValue, errorID) => - val rng = cb.emb.getThreefryRNG() - rngState.copyIntoEngine(cb, rng) - - val totalNumberOfRecords = cb.newLocal[Int]("scnspp_total_number_of_records", 0) - val resultSize: Value[Int] = partitionCounts.loadLength() - val i = cb.newLocal[Int]("scnspp_index", 0) - cb.for_(cb.assign(i, 0), i < resultSize, cb.assign(i, i + 1), { - cb.assign(totalNumberOfRecords, totalNumberOfRecords + partitionCounts.loadElement(cb, i).get(cb).asInt32.value) - }) - - cb.if_(initalNumSamplesToSelect.value > totalNumberOfRecords, cb._fatal("Requested selection of ", initalNumSamplesToSelect.value.toS, - " samples from ", 
totalNumberOfRecords.toS, " records")) + registerSCode3( + "shuffle_compute_num_samples_per_partition", + TRNGState, + TInt32, + TArray(TInt32), + TArray(TInt32), + (_, _, _, _) => SIndexablePointer(PCanonicalArray(PInt32(true), false)), + ) { + case ( + r, + cb, + rt, + rngState: SRNGStateValue, + initalNumSamplesToSelect: SInt32Value, + partitionCounts: SIndexableValue, + _, + ) => + val rng = cb.emb.getThreefryRNG() + rngState.copyIntoEngine(cb, rng) + + val totalNumberOfRecords = cb.newLocal[Int]("scnspp_total_number_of_records", 0) + val resultSize: Value[Int] = partitionCounts.loadLength() + val i = cb.newLocal[Int]("scnspp_index", 0) + cb.for_( + cb.assign(i, 0), + i < resultSize, + cb.assign(i, i + 1), + cb.assign( + totalNumberOfRecords, + totalNumberOfRecords + partitionCounts.loadElement(cb, i).getOrAssert(cb).asInt32.value, + ), + ) - val successStatesRemaining = cb.newLocal[Int]("scnspp_success", initalNumSamplesToSelect.value) - val failureStatesRemaining = cb.newLocal[Int]("scnspp_failure", totalNumberOfRecords - successStatesRemaining) + cb.if_( + initalNumSamplesToSelect.value > totalNumberOfRecords, + cb._fatal( + "Requested selection of ", + initalNumSamplesToSelect.value.toS, + " samples from ", + totalNumberOfRecords.toS, + " records", + ), + ) - val arrayRt = rt.asInstanceOf[SIndexablePointer] - val (push, finish) = arrayRt.pType.asInstanceOf[PCanonicalArray].constructFromFunctions(cb, r.region, resultSize, false) + val successStatesRemaining = + cb.newLocal[Int]("scnspp_success", initalNumSamplesToSelect.value) + val failureStatesRemaining = + cb.newLocal[Int]("scnspp_failure", totalNumberOfRecords - successStatesRemaining) + + val arrayRt = rt.asInstanceOf[SIndexablePointer] + val (push, finish) = arrayRt.pType.asInstanceOf[PCanonicalArray].constructFromFunctions( + cb, + r.region, + resultSize, + false, + ) - cb.for_(cb.assign(i, 0), i < resultSize, cb.assign(i, i + 1), { - val numSuccesses = cb.memoize(rng.invoke[Double, Double, Double, Double]("rhyper", - successStatesRemaining.toD, failureStatesRemaining.toD, partitionCounts.loadElement(cb, i).get(cb).asInt32.value.toD).toI) - cb.assign(successStatesRemaining, successStatesRemaining - numSuccesses) - cb.assign(failureStatesRemaining, failureStatesRemaining - (partitionCounts.loadElement(cb, i).get(cb).asInt32.value - numSuccesses)) - push(cb, IEmitCode.present(cb, new SInt32Value(numSuccesses))) - }) + cb.for_( + cb.assign(i, 0), + i < resultSize, + cb.assign(i, i + 1), { + val numSuccesses = cb.memoize(rng.invoke[Double, Double, Double, Double]( + "rhyper", + successStatesRemaining.toD, + failureStatesRemaining.toD, + partitionCounts.loadElement(cb, i).getOrAssert(cb).asInt32.value.toD, + ).toI) + cb.assign(successStatesRemaining, successStatesRemaining - numSuccesses) + cb.assign( + failureStatesRemaining, + failureStatesRemaining - (partitionCounts.loadElement(cb, i).getOrAssert( + cb + ).asInt32.value - numSuccesses), + ) + push(cb, IEmitCode.present(cb, new SInt32Value(numSuccesses))) + }, + ) - finish(cb) + finish(cb) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/ReferenceGenomeFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/ReferenceGenomeFunctions.scala index 8768b3af931..18955c9c346 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/ReferenceGenomeFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/ReferenceGenomeFunctions.scala @@ -1,64 +1,108 @@ package is.hail.expr.ir.functions -import is.hail.asm4s import is.hail.asm4s._ import 
is.hail.expr.ir._ import is.hail.types.physical.stypes.SType -import is.hail.types.physical.stypes.concrete.{SJavaString, SStringPointer} -import is.hail.types.physical.stypes.primitives.{SBoolean, SInt32} +import is.hail.types.physical.stypes.concrete.SJavaString import is.hail.types.physical.stypes.interfaces._ -import is.hail.types.physical.{PBoolean, PCanonicalString, PInt32, PLocus, PString, PType} +import is.hail.types.physical.stypes.primitives.{SBoolean, SInt32} import is.hail.types.virtual._ import is.hail.variant.ReferenceGenome object ReferenceGenomeFunctions extends RegistryFunctions { - def rgCode(mb: EmitMethodBuilder[_], rg: String): Code[ReferenceGenome] = mb.getReferenceGenome(rg) + def rgCode(mb: EmitMethodBuilder[_], rg: String): Code[ReferenceGenome] = + mb.getReferenceGenome(rg) - def registerAll() { - registerSCode1t("isValidContig", Array(LocusFunctions.tlocus("R")), TString, TBoolean, (_: Type, _: SType) => SBoolean) { + def registerAll(): Unit = { + registerSCode1t( + "isValidContig", + Array(LocusFunctions.tlocus("R")), + TString, + TBoolean, + (_: Type, _: SType) => SBoolean, + ) { case (r, cb, Seq(tlocus: TLocus), _, contig, _) => val scontig = contig.asString.loadString(cb) - primitive(cb.memoize(rgCode(r.mb, tlocus.asInstanceOf[TLocus].rg).invoke[String, Boolean]("isValidContig", scontig))) + primitive(cb.memoize(rgCode(r.mb, tlocus.asInstanceOf[TLocus].rg).invoke[String, Boolean]( + "isValidContig", + scontig, + ))) } - registerSCode2t("isValidLocus", Array(LocusFunctions.tlocus("R")), TString, TInt32, TBoolean, (_: Type, _: SType, _: SType) => SBoolean) { + registerSCode2t( + "isValidLocus", + Array(LocusFunctions.tlocus("R")), + TString, + TInt32, + TBoolean, + (_: Type, _: SType, _: SType) => SBoolean, + ) { case (r, cb, Seq(tlocus: TLocus), _, contig, pos, _) => val scontig = contig.asString.loadString(cb) - primitive(cb.memoize(rgCode(r.mb, tlocus.rg).invoke[String, Int, Boolean]("isValidLocus", scontig, pos.asInt.value))) + primitive(cb.memoize(rgCode(r.mb, tlocus.rg).invoke[String, Int, Boolean]( + "isValidLocus", + scontig, + pos.asInt.value, + ))) } - registerSCode4t("getReferenceSequenceFromValidLocus", + registerSCode4t( + "getReferenceSequenceFromValidLocus", Array(LocusFunctions.tlocus("R")), - TString, TInt32, TInt32, TInt32, TString, - (_: Type, _: SType, _: SType, _: SType, _: SType) => SJavaString) { + TString, + TInt32, + TInt32, + TInt32, + TString, + (_: Type, _: SType, _: SType, _: SType, _: SType) => SJavaString, + ) { case (r, cb, Seq(typeParam: TLocus), st, contig, pos, before, after, _) => val scontig = contig.asString.loadString(cb) - unwrapReturn(cb, r.region, st, - rgCode(cb.emb, typeParam.rg).invoke[String, Int, Int, Int, String]("getSequence", + unwrapReturn( + cb, + r.region, + st, + rgCode(cb.emb, typeParam.rg).invoke[String, Int, Int, Int, String]( + "getSequence", scontig, pos.asInt.value, before.asInt.value, - after.asInt.value)) + after.asInt.value, + ), + ) } - registerSCode1t("contigLength", Array(LocusFunctions.tlocus("R")), TString, TInt32, (_: Type, _: SType) => SInt32) { + registerSCode1t( + "contigLength", + Array(LocusFunctions.tlocus("R")), + TString, + TInt32, + (_: Type, _: SType) => SInt32, + ) { case (r, cb, Seq(tlocus: TLocus), _, contig, _) => val scontig = contig.asString.loadString(cb) primitive(cb.memoize(rgCode(r.mb, tlocus.rg).invoke[String, Int]("contigLength", scontig))) } - registerIR("getReferenceSequence", Array(TString, TInt32, TInt32, TInt32), TString, typeParameters = 
Array(LocusFunctions.tlocus("R"))) { + registerIR( + "getReferenceSequence", + Array(TString, TInt32, TInt32, TInt32), + TString, + typeParameters = Array(LocusFunctions.tlocus("R")), + ) { case (tl, Seq(contig, pos, before, after), _) => val getRef = IRFunctionRegistry.lookupUnseeded( name = "getReferenceSequenceFromValidLocus", returnType = TString, typeParameters = tl, - arguments = Seq(TString, TInt32, TInt32, TInt32)).get + arguments = Seq(TString, TInt32, TInt32, TInt32), + ).get val isValid = IRFunctionRegistry.lookupUnseeded( "isValidLocus", TBoolean, typeParameters = tl, - Seq(TString, TInt32)).get + Seq(TString, TInt32), + ).get val r = isValid(tl, Seq(contig, pos), ErrorIDs.NO_ERROR) val p = getRef(tl, Seq(contig, pos, before, after), ErrorIDs.NO_ERROR) diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/RelationalFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/RelationalFunctions.scala index da7274f51b3..65e8f6ca13b 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/RelationalFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/RelationalFunctions.scala @@ -1,17 +1,15 @@ package is.hail.expr.ir.functions import is.hail.backend.ExecuteContext -import is.hail.expr.ir.lowering.TableStage -import is.hail.expr.ir.{LowerMatrixIR, MatrixValue, RelationalSpec, TableReader, TableValue} -import is.hail.types.virtual.Type -import is.hail.types.{BlockMatrixType, MatrixType, RTable, TableType, TypeWithRequiredness} +import is.hail.expr.ir.{MatrixValue, RelationalSpec, TableValue} import is.hail.linalg.BlockMatrix import is.hail.methods._ +import is.hail.types.{BlockMatrixType, MatrixType, RTable, TableType, TypeWithRequiredness} +import is.hail.types.virtual.Type import is.hail.utils._ -import is.hail.rvd.RVDType -import org.json4s.{Extraction, JValue, ShortTypeHints} -import org.json4s.jackson.{JsonMethods, Serialization} +import org.json4s.{Extraction, JValue, ShortTypeHints} +import org.json4s.jackson.JsonMethods abstract class MatrixToMatrixFunction { def typ(childType: MatrixType): MatrixType @@ -43,13 +41,15 @@ case class WrappedMatrixToTableFunction( function: MatrixToTableFunction, colsFieldName: String, entriesFieldName: String, - colKey: IndexedSeq[String]) extends TableToTableFunction { + colKey: IndexedSeq[String], +) extends TableToTableFunction { override def typ(childType: TableType): TableType = { val mType = MatrixType.fromTableType(childType, colsFieldName, entriesFieldName, colKey) function.typ(mType) // MatrixType RVDTypes will go away } - def execute(ctx: ExecuteContext, tv: TableValue): TableValue = function.execute(ctx, tv.toMatrixValue(colKey, colsFieldName, entriesFieldName)) + def execute(ctx: ExecuteContext, tv: TableValue): TableValue = + function.execute(ctx, tv.toMatrixValue(colKey, colsFieldName, entriesFieldName)) override def preservesPartitionCounts: Boolean = function.preservesPartitionCounts } @@ -63,9 +63,8 @@ abstract class TableToTableFunction { def requestType(requestedType: TableType, childBaseType: TableType): TableType = childBaseType - def toJValue: JValue = { + def toJValue: JValue = Extraction.decompose(this)(RelationalFunctions.formats) - } } abstract class TableToValueFunction { @@ -80,15 +79,17 @@ case class WrappedMatrixToValueFunction( function: MatrixToValueFunction, colsFieldName: String, entriesFieldName: String, - colKey: IndexedSeq[String]) extends TableToValueFunction { + colKey: IndexedSeq[String], +) extends TableToValueFunction { - def typ(childType: TableType): Type = { + def 
typ(childType: TableType): Type = function.typ(MatrixType.fromTableType(childType, colsFieldName, entriesFieldName, colKey)) - } - def unionRequiredness(childType: RTable, resultType: TypeWithRequiredness): Unit = function.unionRequiredness(childType, resultType) + def unionRequiredness(childType: RTable, resultType: TypeWithRequiredness): Unit = + function.unionRequiredness(childType, resultType) - def execute(ctx: ExecuteContext, tv: TableValue): Any = function.execute(ctx, tv.toMatrixValue(colKey, colsFieldName, entriesFieldName)) + def execute(ctx: ExecuteContext, tv: TableValue): Any = + function.execute(ctx, tv.toMatrixValue(colKey, colsFieldName, entriesFieldName)) } abstract class MatrixToValueFunction { @@ -108,53 +109,62 @@ abstract class BlockMatrixToValueFunction { } object RelationalFunctions { - implicit val formats = RelationalSpec.formats + ShortTypeHints(List( - classOf[LinearRegressionRowsSingle], - classOf[LinearRegressionRowsChained], - classOf[TableFilterPartitions], - classOf[MatrixFilterPartitions], - classOf[TableCalculateNewPartitions], - classOf[ForceCountTable], - classOf[ForceCountMatrixTable], - classOf[NPartitionsTable], - classOf[NPartitionsMatrixTable], - classOf[LogisticRegression], - classOf[PoissonRegression], - classOf[Skat], - classOf[LocalLDPrune], - classOf[MatrixExportEntriesByCol], - classOf[PCA], - classOf[VEP], - classOf[IBD], - classOf[Nirvana], - classOf[GetElement], - classOf[WrappedMatrixToTableFunction], - classOf[WrappedMatrixToValueFunction], - classOf[PCRelate] - ), typeHintFieldName = "name") + implicit val formats = RelationalSpec.formats + ShortTypeHints( + List( + classOf[LinearRegressionRowsSingle], + classOf[LinearRegressionRowsChained], + classOf[TableFilterPartitions], + classOf[MatrixFilterPartitions], + classOf[TableCalculateNewPartitions], + classOf[ForceCountTable], + classOf[ForceCountMatrixTable], + classOf[NPartitionsTable], + classOf[NPartitionsMatrixTable], + classOf[LogisticRegression], + classOf[PoissonRegression], + classOf[Skat], + classOf[LocalLDPrune], + classOf[MatrixExportEntriesByCol], + classOf[PCA], + classOf[VEP], + classOf[IBD], + classOf[Nirvana], + classOf[GetElement], + classOf[WrappedMatrixToTableFunction], + classOf[WrappedMatrixToValueFunction], + classOf[PCRelate], + ), + typeHintFieldName = "name", + ) def extractTo[T: Manifest](ctx: ExecuteContext, config: String): T = { val jv = JsonMethods.parse(config) (jv \ "name").extract[String] match { case "VEP" => VEP.fromJValue(ctx.fs, jv).asInstanceOf[T] - case _ => { + case _ => log.info("JSON: " + jv.toString) jv.extract[T] - } } } - def lookupMatrixToMatrix(ctx: ExecuteContext, config: String): MatrixToMatrixFunction = extractTo[MatrixToMatrixFunction](ctx, config) + def lookupMatrixToMatrix(ctx: ExecuteContext, config: String): MatrixToMatrixFunction = + extractTo[MatrixToMatrixFunction](ctx, config) - def lookupMatrixToTable(ctx: ExecuteContext, config: String): MatrixToTableFunction = extractTo[MatrixToTableFunction](ctx, config) + def lookupMatrixToTable(ctx: ExecuteContext, config: String): MatrixToTableFunction = + extractTo[MatrixToTableFunction](ctx, config) - def lookupTableToTable(ctx: ExecuteContext, config: String): TableToTableFunction = extractTo[TableToTableFunction](ctx, config) + def lookupTableToTable(ctx: ExecuteContext, config: String): TableToTableFunction = + extractTo[TableToTableFunction](ctx, config) - def lookupBlockMatrixToTable(ctx: ExecuteContext, config: String): BlockMatrixToTableFunction = 
extractTo[BlockMatrixToTableFunction](ctx, config) + def lookupBlockMatrixToTable(ctx: ExecuteContext, config: String): BlockMatrixToTableFunction = + extractTo[BlockMatrixToTableFunction](ctx, config) - def lookupTableToValue(ctx: ExecuteContext, config: String): TableToValueFunction = extractTo[TableToValueFunction](ctx, config) + def lookupTableToValue(ctx: ExecuteContext, config: String): TableToValueFunction = + extractTo[TableToValueFunction](ctx, config) - def lookupMatrixToValue(ctx: ExecuteContext, config: String): MatrixToValueFunction = extractTo[MatrixToValueFunction](ctx, config) + def lookupMatrixToValue(ctx: ExecuteContext, config: String): MatrixToValueFunction = + extractTo[MatrixToValueFunction](ctx, config) - def lookupBlockMatrixToValue(ctx: ExecuteContext, config: String): BlockMatrixToValueFunction = extractTo[BlockMatrixToValueFunction](ctx, config) + def lookupBlockMatrixToValue(ctx: ExecuteContext, config: String): BlockMatrixToValueFunction = + extractTo[BlockMatrixToValueFunction](ctx, config) } diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/SetFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/SetFunctions.scala index 90d8188d5c3..61c86cee77a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/SetFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/SetFunctions.scala @@ -8,18 +8,22 @@ object SetFunctions extends RegistryFunctions { def contains(set: IR, elem: IR) = { val i = Ref(genUID(), TInt32) - If(IsNA(set), + If( + IsNA(set), NA(TBoolean), - Let(FastSeq(i.name -> LowerBoundOnOrderedCollection(set, elem, onKey = false)), - If(i.ceq(ArrayLen(CastToArray(set))), + Let( + FastSeq(i.name -> LowerBoundOnOrderedCollection(set, elem, onKey = false)), + If( + i.ceq(ArrayLen(CastToArray(set))), False(), - ApplyComparisonOp(EQWithNA(elem.typ), ArrayRef(CastToArray(set), i), elem)))) + ApplyComparisonOp(EQWithNA(elem.typ), ArrayRef(CastToArray(set), i), elem), + ), + ), + ) } - def registerAll() { - registerIR1("toSet", TArray(tv("T")), TSet(tv("T"))) { (_, a, _) => - ToSet(ToStream(a)) - } + def registerAll(): Unit = { + registerIR1("toSet", TArray(tv("T")), TSet(tv("T")))((_, a, _) => ToSet(ToStream(a))) registerIR1("isEmpty", TSet(tv("T")), TBoolean) { (_, s, _) => ArrayFunctions.isEmpty(CastToArray(s)) @@ -34,7 +38,9 @@ object SetFunctions extends RegistryFunctions { StreamFilter( ToStream(s), x, - ApplyComparisonOp(NEQWithNA(t), Ref(x, t), v))) + ApplyComparisonOp(NEQWithNA(t), Ref(x, t), v), + ) + ) } registerIR2("add", TSet(tv("T")), tv("T"), TSet(tv("T"))) { (_, s, v, _) => @@ -44,7 +50,9 @@ object SetFunctions extends RegistryFunctions { StreamFlatMap( MakeStream(FastSeq(CastToArray(s), MakeArray(FastSeq(v), TArray(t))), TStream(TArray(t))), x, - ToStream(Ref(x, TArray(t))))) + ToStream(Ref(x, TArray(t))), + ) + ) } registerIR2("union", TSet(tv("T")), TSet(tv("T")), TSet(tv("T"))) { (_, s1, s2, _) => @@ -54,32 +62,45 @@ object SetFunctions extends RegistryFunctions { StreamFlatMap( MakeStream(FastSeq(CastToArray(s1), CastToArray(s2)), TStream(TArray(t))), x, - ToStream(Ref(x, TArray(t))))) + ToStream(Ref(x, TArray(t))), + ) + ) } registerIR2("intersection", TSet(tv("T")), TSet(tv("T")), TSet(tv("T"))) { (_, s1, s2, _) => val t = s1.typ.asInstanceOf[TSet].elementType val x = genUID() ToSet( - StreamFilter(ToStream(s1), x, - contains(s2, Ref(x, t)))) + StreamFilter(ToStream(s1), x, contains(s2, Ref(x, t))) + ) } registerIR2("difference", TSet(tv("T")), TSet(tv("T")), TSet(tv("T"))) { (_, s1, s2, _) => val t = 
s1.typ.asInstanceOf[TSet].elementType val x = genUID() ToSet( - StreamFilter(ToStream(s1), x, - ApplyUnaryPrimOp(Bang, contains(s2, Ref(x, t))))) + StreamFilter(ToStream(s1), x, ApplyUnaryPrimOp(Bang, contains(s2, Ref(x, t)))) + ) } registerIR2("isSubset", TSet(tv("T")), TSet(tv("T")), TBoolean) { (_, s, w, errorID) => val t = s.typ.asInstanceOf[TSet].elementType val a = genUID() val x = genUID() - StreamFold(ToStream(s), True(), a, x, + StreamFold( + ToStream(s), + True(), + a, + x, // FIXME short circuit - ApplySpecial("land", FastSeq(), FastSeq(Ref(a, TBoolean), contains(w, Ref(x, t))), TBoolean, errorID)) + ApplySpecial( + "land", + FastSeq(), + FastSeq(Ref(a, TBoolean), contains(w, Ref(x, t))), + TBoolean, + errorID, + ), + ) } registerIR1("median", TSet(tnum("T")), tv("T")) { (_, s, _) => @@ -92,15 +113,25 @@ object SetFunctions extends RegistryFunctions { val len: IR = ArrayLen(a) def div(a: IR, b: IR): IR = ApplyBinaryPrimOp(BinaryOp.defaultDivideOp(t), a, b) - Let(FastSeq(a.name -> CastToArray(s)), - If(IsNA(a), + Let( + FastSeq(a.name -> CastToArray(s)), + If( + IsNA(a), NA(t), - Let(FastSeq(size.name -> If(len.ceq(0), len, If(IsNA(ref(len - 1)), len - 1, len))), - If(size.ceq(0), + Let( + FastSeq(size.name -> If(len.ceq(0), len, If(IsNA(ref(len - 1)), len - 1, len))), + If( + size.ceq(0), NA(t), - If(invoke("mod", TInt32, size, 2).cne(0), + If( + invoke("mod", TInt32, size, 2).cne(0), ref(midIdx), // odd number of non-missing elements - div(ref(midIdx) + ref(midIdx + 1), Cast(2, t))))))) + div(ref(midIdx) + ref(midIdx + 1), Cast(2, t)), + ), + ), + ), + ), + ) } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/StringFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/StringFunctions.scala index 087d769b0d0..2141459f8db 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/StringFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/StringFunctions.scala @@ -5,20 +5,24 @@ import is.hail.asm4s._ import is.hail.expr.JSONAnnotationImpex import is.hail.expr.ir._ import is.hail.types.physical.stypes._ -import is.hail.types.physical.stypes.concrete.{SJavaArrayString, SJavaArrayStringSettable, SJavaArrayStringValue, SJavaString} +import is.hail.types.physical.stypes.concrete.{ + SJavaArrayString, SJavaArrayStringSettable, SJavaArrayStringValue, SJavaString, +} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives.{SBoolean, SInt32, SInt64} import is.hail.types.virtual._ import is.hail.utils._ -import org.apache.spark.sql.Row -import org.json4s.JValue -import org.json4s.jackson.JsonMethods -import java.time.temporal.ChronoField +import scala.collection.mutable + import java.time.{Instant, ZoneId} +import java.time.temporal.ChronoField import java.util.Locale import java.util.regex.{Matcher, Pattern} -import scala.collection.mutable + +import org.apache.spark.sql.Row +import org.json4s.JValue +import org.json4s.jackson.JsonMethods object StringFunctions extends RegistryFunctions { def reverse(s: String): String = { @@ -39,9 +43,8 @@ object StringFunctions extends RegistryFunctions { def endswith(s: String, t: String): Boolean = s.endsWith(t) - def firstMatchIn(s: String, regex: String): Array[String] = { + def firstMatchIn(s: String, regex: String): Array[String] = regex.r.findFirstMatchIn(s).map(_.subgroups.toArray).orNull - } def regexMatch(regex: String, s: String): Boolean = regex.r.findFirstIn(s).isDefined @@ -51,7 +54,7 @@ object StringFunctions extends RegistryFunctions { def replace(str: 
String, pattern1: String, pattern2: String): String = str.replaceAll(pattern1, pattern2) - + def split(s: String, p: String): Array[String] = s.split(p, -1) def translate(s: String, d: Map[String, String]): String = { @@ -109,27 +112,39 @@ object StringFunctions extends RegistryFunctions { separator: Either[Value[Char], Value[String]], quoteChar: Option[Value[Char]], missingSV: SIndexableValue, - errorID: Value[Int] + errorID: Value[Int], ): Value[Array[String]] = { - // note: it will be inefficient to convert a SIndexablePointer to SJavaArrayString to split each line. - // We should really choose SJavaArrayString as the stype for a literal if used in a place like this, + /* note: it will be inefficient to convert a SIndexablePointer to SJavaArrayString to split each + * line. */ + /* We should really choose SJavaArrayString as the stype for a literal if used in a place like + * this, */ // but this is a non-local stype decision that is hard in the current system val missing: Value[Array[String]] = missingSV.st match { - case SJavaArrayString(elementRequired) => missingSV.asInstanceOf[SJavaArrayStringSettable].array + case SJavaArrayString(_) => + missingSV.asInstanceOf[SJavaArrayStringSettable].array case _ => - val mb = cb.emb.ecb.newEmitMethod("convert_region_to_str_array", FastSeq(missingSV.st.paramType), arrayInfo[String]) + val mb = cb.emb.ecb.newEmitMethod( + "convert_region_to_str_array", + FastSeq(missingSV.st.paramType), + arrayInfo[String], + ) mb.emitWithBuilder[Array[String]] { cb => val sv = mb.getSCodeParam(1).asIndexable val m = cb.newLocal[Array[String]]("missingvals", Code.newArray[String](sv.loadLength())) - sv.forEachDefined(cb) { case (cb, idx, sc) => cb += (m(idx) = sc.asString.loadString(cb)) } + sv.forEachDefined(cb) { case (cb, idx, sc) => + cb += (m(idx) = sc.asString.loadString(cb)) + } m } - cb.newLocal[Array[String]]("missing_arr", cb.invokeCode(mb, missingSV)) + cb.newLocal[Array[String]]("missing_arr", cb.invokeCode(mb, cb.this_, missingSV)) } // lazy field reused across calls to split functions - val ab = cb.emb.getOrDefineLazyField[StringArrayBuilder](Code.newInstance[StringArrayBuilder, Int](16), "generate_split_quoted_regex_ab") + val ab = cb.emb.getOrDefineLazyField[StringArrayBuilder]( + Code.newInstance[StringArrayBuilder, Int](16), + "generate_split_quoted_regex_ab", + ) cb += ab.invoke[Unit]("clear") // takes the current position and current char value, returns the number of matching chars @@ -143,12 +158,22 @@ object StringFunctions extends RegistryFunctions { x } case Right(regex) => - val m = cb.newLocal[Matcher]("matcher", + val m = cb.newLocal[Matcher]( + "matcher", Code.invokeStatic1[Pattern, String, Pattern]("compile", regex) - .invoke[CharSequence, Matcher]("matcher", string)); + .invoke[CharSequence, Matcher]("matcher", string), + ); (idx: Value[Int], _: Value[Char]) => { - cb.assign(x, Code.invokeScalaObject3[String, Int, Matcher, Int]( - StringFunctions.getClass, "matchPattern", string, idx, m)); + cb.assign( + x, + Code.invokeScalaObject3[String, Int, Matcher, Int]( + StringFunctions.getClass, + "matchPattern", + string, + idx, + m, + ), + ); x } } @@ -157,59 +182,87 @@ object StringFunctions extends RegistryFunctions { val i = cb.newLocal[Int]("i", 0) val lastFieldStart = cb.newLocal[Int]("lastfieldstart", 0) - def addValueOrNA(cb: EmitCodeBuilder, endIdx: Code[Int]): Unit = { + def addValueOrNA(cb: EmitCodeBuilder, endIdx: Code[Int]): Unit = cb += Code.invokeScalaObject3[StringArrayBuilder, String, Array[String], Unit]( - 
StringFunctions.getClass, "addValueOrNull", ab, string.invoke[Int, Int, String]("substring", lastFieldStart, endIdx), missing) - } + StringFunctions.getClass, + "addValueOrNull", + ab, + string.invoke[Int, Int, String]("substring", lastFieldStart, endIdx), + missing, + ) val LreturnWithoutAppending = CodeLabel() - cb.while_(i < string.length(), { - val c = cb.newLocal[Char]("c", string(i)) - - val l = getPatternMatch(i, c) - cb.if_(l.cne(-1), { - addValueOrNA(cb, i) - cb.assign(i, i + l) // skip delim - cb.assign(lastFieldStart, i) - }, { - quoteChar match { - case Some(qc) => - cb.if_(c.ceq(qc), { - cb.if_(i.cne(lastFieldStart), - cb._fatalWithError(errorID, "opening quote character '", qc.toS, "' not at start of field")) - cb.assign(i, i + 1) // skip quote - cb.assign(lastFieldStart, i) - - cb.while_(i < string.length() && string(i).cne(qc), { + cb.while_( + i < string.length(), { + val c = cb.newLocal[Char]("c", string(i)) + + val l = getPatternMatch(i, c) + cb.if_( + l.cne(-1), { + addValueOrNA(cb, i) + cb.assign(i, i + l) // skip delim + cb.assign(lastFieldStart, i) + }, { + quoteChar match { + case Some(qc) => + cb.if_( + c.ceq(qc), { + cb.if_( + i.cne(lastFieldStart), + cb._fatalWithError( + errorID, + "opening quote character '", + qc.toS, + "' not at start of field", + ), + ) + cb.assign(i, i + 1) // skip quote + cb.assign(lastFieldStart, i) + + cb.while_(i < string.length() && string(i).cne(qc), cb.assign(i, i + 1)) + + addValueOrNA(cb, i) + + cb.if_( + i.ceq(string.length()), + cb._fatalWithError( + errorID, + "missing terminating quote character '", + qc.toS, + "'", + ), + ) + cb.assign(i, i + 1) // skip quote + + cb.if_( + i < string.length, { + cb.assign(c, string(i)) + val l = getPatternMatch(i, c) + cb.if_( + l.ceq(-1), + cb._fatalWithError( + errorID, + "terminating quote character '", + qc.toS, + "' not at end of field", + ), + ) + cb.assign(i, i + l) // skip delim + cb.assign(lastFieldStart, i) + }, + cb.goto(LreturnWithoutAppending), + ) + }, + cb.assign(i, i + 1), + ) + case None => cb.assign(i, i + 1) - }) - - addValueOrNA(cb, i) - - cb.if_(i.ceq(string.length()), - cb._fatalWithError(errorID, "missing terminating quote character '", qc.toS, "'")) - cb.assign(i, i + 1) // skip quote - - cb.if_(i < string.length, { - cb.assign(c, string(i)) - val l = getPatternMatch(i, c) - cb.if_(l.ceq(-1), { - cb._fatalWithError(errorID, "terminating quote character '", qc.toS, "' not at end of field") - }) - cb.assign(i, i + l) // skip delim - cb.assign(lastFieldStart, i) - }, { - cb.goto(LreturnWithoutAppending) - }) - }, { - cb.assign(i, i + 1) - }) - case None => - cb.assign(i, i + 1) - } - }) - }) + } + }, + ) + }, + ) addValueOrNA(cb, string.length()) cb.define(LreturnWithoutAppending) @@ -233,15 +286,25 @@ object StringFunctions extends RegistryFunctions { def registerAll(): Unit = { val thisClass = getClass - registerSCode1("length", TString, TInt32, (_: Type, _: SType) => SInt32) { case (r: EmitRegion, cb, _, s: SStringValue, _) => - primitive(cb.memoize(s.loadString(cb).invoke[Int]("length"))) + registerSCode1("length", TString, TInt32, (_: Type, _: SType) => SInt32) { + case (_: EmitRegion, cb, _, s: SStringValue, _) => + primitive(cb.memoize(s.loadString(cb).invoke[Int]("length"))) } - registerSCode3("substring", TString, TInt32, TInt32, TString, { - (_: Type, _: SType, _: SType, _: SType) => SJavaString - }) { - case (r: EmitRegion, cb, st: SJavaString.type, s, start, end, _) => - val str = s.asString.loadString(cb).invoke[Int, Int, String]("substring", 
start.asInt.value, end.asInt.value) + registerSCode3( + "substring", + TString, + TInt32, + TInt32, + TString, + (_: Type, _: SType, _: SType, _: SType) => SJavaString, + ) { + case (_: EmitRegion, cb, st: SJavaString.type, s, start, end, _) => + val str = s.asString.loadString(cb).invoke[Int, Int, String]( + "substring", + start.asInt.value, + end.asInt.value, + ) st.construct(cb, str) } @@ -255,7 +318,7 @@ object StringFunctions extends RegistryFunctions { s.name -> softBounds(start, len), e.name -> softBounds(end, len), ), - invoke("substring", TString, str, s, If(e < s, s, e)) + invoke("substring", TString, str, s, If(e < s, s, e)), ) } @@ -265,191 +328,393 @@ object StringFunctions extends RegistryFunctions { Let( FastSeq( len.name -> invoke("length", TInt32, s), - idx.name -> If((i < -len) || (i >= len), - Die(invoke("concat", TString, - Str("string index out of bounds: "), - invoke("concat", TString, - invoke("str", TString, i), - invoke("concat", TString, Str(" / "), invoke("str", TString, len)))), TInt32, errorID), - If(i < 0, i + len, i)) + idx.name -> If( + (i < -len) || (i >= len), + Die( + invoke( + "concat", + TString, + Str("string index out of bounds: "), + invoke( + "concat", + TString, + invoke("str", TString, i), + invoke("concat", TString, Str(" / "), invoke("str", TString, len)), + ), + ), + TInt32, + errorID, + ), + If(i < 0, i + len, i), + ), ), - invoke("substring", TString, s, idx, idx + 1) + invoke("substring", TString, s, idx, idx + 1), ) } - registerIR2("sliceRight", TString, TInt32, TString) { (_, s, start, _) => invoke("slice", TString, s, start, invoke("length", TInt32, s)) } - registerIR2("sliceLeft", TString, TInt32, TString) { (_, s, end, _) => invoke("slice", TString, s, I32(0), end) } + registerIR2("sliceRight", TString, TInt32, TString) { (_, s, start, _) => + invoke("slice", TString, s, start, invoke("length", TInt32, s)) + } + registerIR2("sliceLeft", TString, TInt32, TString) { (_, s, end, _) => + invoke("slice", TString, s, I32(0), end) + } - registerSCode1("str", tv("T"), TString, (_: Type, _: SType) => SJavaString) { case (r, cb, st: SJavaString.type, a, _) => - val annotation = svalueToJavaValue(cb, r.region, a) - val str = cb.emb.getType(a.st.virtualType).invoke[Any, String]("str", annotation) - st.construct(cb, str) + registerSCode1("str", tv("T"), TString, (_: Type, _: SType) => SJavaString) { + case (r, cb, st: SJavaString.type, a, _) => + val annotation = svalueToJavaValue(cb, r.region, a) + val str = cb.emb.getType(a.st.virtualType).invoke[Any, String]("str", annotation) + st.construct(cb, str) } - registerIEmitCode1("showStr", tv("T"), TString, { - (_: Type, _: EmitType) => EmitType(SJavaString, true) - }) { case (cb, r, st: SJavaString.type, _, a) => + registerIEmitCode1( + "showStr", + tv("T"), + TString, + (_: Type, _: EmitType) => EmitType(SJavaString, true), + ) { case (cb, r, st: SJavaString.type, _, a) => val jObj = cb.newLocalAny("showstr_java_obj", boxedTypeInfo(a.st.virtualType)) - a.toI(cb).consume(cb, + a.toI(cb).consume( + cb, cb.assignAny(jObj, Code._null(boxedTypeInfo(a.st.virtualType))), - sc => cb.assignAny(jObj, svalueToJavaValue(cb, r, sc))) + sc => cb.assignAny(jObj, svalueToJavaValue(cb, r, sc)), + ) val str = cb.emb.getType(a.st.virtualType).invoke[Any, String]("showStr", jObj) IEmitCode.present(cb, st.construct(cb, str)) } - registerIEmitCode2("showStr", tv("T"), TInt32, TString, { - (_: Type, _: EmitType, truncType: EmitType) => EmitType(SJavaString, truncType.required) - }) { case (cb, r, st: SJavaString.type, _, 
a, trunc) => + registerIEmitCode2( + "showStr", + tv("T"), + TInt32, + TString, + (_: Type, _: EmitType, truncType: EmitType) => EmitType(SJavaString, truncType.required), + ) { case (cb, r, st: SJavaString.type, _, a, trunc) => val jObj = cb.newLocalAny("showstr_java_obj", boxedTypeInfo(a.st.virtualType)) trunc.toI(cb).map(cb) { trunc => - - a.toI(cb).consume(cb, + a.toI(cb).consume( + cb, cb.assignAny(jObj, Code._null(boxedTypeInfo(a.st.virtualType))), - sc => cb.assignAny(jObj, svalueToJavaValue(cb, r, sc))) - - val str = cb.emb.getType(a.st.virtualType).invoke[Any, Int, String]("showStr", jObj, trunc.asInt.value) + sc => cb.assignAny(jObj, svalueToJavaValue(cb, r, sc)), + ) + + val str = cb.emb.getType(a.st.virtualType).invoke[Any, Int, String]( + "showStr", + jObj, + trunc.asInt.value, + ) st.construct(cb, str) } } - registerIEmitCode1("json", tv("T"), TString, (_: Type, _: EmitType) => EmitType(SJavaString, true)) { + registerIEmitCode1( + "json", + tv("T"), + TString, + (_: Type, _: EmitType) => EmitType(SJavaString, true), + ) { case (cb, r, st: SJavaString.type, _, a) => val ti = boxedTypeInfo(a.st.virtualType) val inputJavaValue = cb.newLocalAny("json_func_input_jv", ti) - a.toI(cb).consume(cb, + a.toI(cb).consume( + cb, cb.assignAny(inputJavaValue, Code._null(ti)), { sc => val jv = svalueToJavaValue(cb, r, sc) cb.assignAny(inputJavaValue, jv) - }) + }, + ) val json = cb.emb.getType(a.st.virtualType).invoke[Any, JValue]("toJSON", inputJavaValue) val str = Code.invokeScalaObject1[JValue, String](JsonMethods.getClass, "compact", json) IEmitCode.present(cb, st.construct(cb, str)) } - - registerWrappedScalaFunction1("reverse", TString, TString, (_: Type, _: SType) => SJavaString)(thisClass, "reverse") - registerWrappedScalaFunction1("upper", TString, TString, (_: Type, _: SType) => SJavaString)(thisClass, "upper") - registerWrappedScalaFunction1("lower", TString, TString, (_: Type, _: SType) => SJavaString)(thisClass, "lower") - registerWrappedScalaFunction1("strip", TString, TString, (_: Type, _: SType) => SJavaString)(thisClass, "strip") - registerWrappedScalaFunction2("contains", TString, TString, TBoolean, { - case (_: Type, _: SType, _: SType) => SBoolean - })(thisClass, "contains") - registerWrappedScalaFunction2("translate", TString, TDict(TString, TString), TString, { - case (_: Type, _: SType, _: SType) => SJavaString - })(thisClass, "translate") - registerWrappedScalaFunction2("startswith", TString, TString, TBoolean, { - case (_: Type, _: SType, _: SType) => SBoolean - })(thisClass, "startswith") - registerWrappedScalaFunction2("endswith", TString, TString, TBoolean, { - case (_: Type, _: SType, _: SType) => SBoolean - })(thisClass, "endswith") - registerWrappedScalaFunction2("regexMatch", TString, TString, TBoolean, { - case (_: Type, _: SType, _: SType) => SBoolean - })(thisClass, "regexMatch") - registerWrappedScalaFunction2("regexFullMatch", TString, TString, TBoolean, { - case (_: Type, _: SType, _: SType) => SBoolean - })(thisClass, "regexFullMatch") - registerWrappedScalaFunction2("concat", TString, TString, TString, { - case (_: Type, _: SType, _: SType) => SJavaString - })(thisClass, "concat") - - registerWrappedScalaFunction2("split", TString, TString, TArray(TString), { - case (_: Type, _: SType, _: SType) => - SJavaArrayString(true) - })(thisClass, "split") - - registerWrappedScalaFunction3("split", TString, TString, TInt32, TArray(TString), { - case (_: Type, _: SType, _: SType, _: SType) => - SJavaArrayString(true) - })(thisClass, "splitLimited") - - 
registerWrappedScalaFunction3("replace", TString, TString, TString, TString, { - case (_: Type, _: SType, _: SType, _: SType) => SJavaString - })(thisClass, "replace") - - registerWrappedScalaFunction2("mkString", TSet(TString), TString, TString, { - case (_: Type, _: SType, _: SType) => SJavaString - })(thisClass, "setMkString") - - registerSCode4("splitQuotedRegex", TString, TString, TArray(TString), TString, TArray(TString), { - case (_: Type, _: SType, _: SType, _: SType, _: SType) => SJavaArrayString(false) - }) { case (r, cb, st: SJavaArrayString, s, separator, missing, quote, errorID) => + registerWrappedScalaFunction1("reverse", TString, TString, (_: Type, _: SType) => SJavaString)( + thisClass, + "reverse", + ) + registerWrappedScalaFunction1("upper", TString, TString, (_: Type, _: SType) => SJavaString)( + thisClass, + "upper", + ) + registerWrappedScalaFunction1("lower", TString, TString, (_: Type, _: SType) => SJavaString)( + thisClass, + "lower", + ) + registerWrappedScalaFunction1("strip", TString, TString, (_: Type, _: SType) => SJavaString)( + thisClass, + "strip", + ) + registerWrappedScalaFunction2( + "contains", + TString, + TString, + TBoolean, + { + case (_: Type, _: SType, _: SType) => SBoolean + }, + )(thisClass, "contains") + registerWrappedScalaFunction2( + "translate", + TString, + TDict(TString, TString), + TString, + { + case (_: Type, _: SType, _: SType) => SJavaString + }, + )(thisClass, "translate") + registerWrappedScalaFunction2( + "startswith", + TString, + TString, + TBoolean, + { + case (_: Type, _: SType, _: SType) => SBoolean + }, + )(thisClass, "startswith") + registerWrappedScalaFunction2( + "endswith", + TString, + TString, + TBoolean, + { + case (_: Type, _: SType, _: SType) => SBoolean + }, + )(thisClass, "endswith") + registerWrappedScalaFunction2( + "regexMatch", + TString, + TString, + TBoolean, + { + case (_: Type, _: SType, _: SType) => SBoolean + }, + )(thisClass, "regexMatch") + registerWrappedScalaFunction2( + "regexFullMatch", + TString, + TString, + TBoolean, + { + case (_: Type, _: SType, _: SType) => SBoolean + }, + )(thisClass, "regexFullMatch") + registerWrappedScalaFunction2( + "concat", + TString, + TString, + TString, + { + case (_: Type, _: SType, _: SType) => SJavaString + }, + )(thisClass, "concat") + + registerWrappedScalaFunction2( + "split", + TString, + TString, + TArray(TString), + { + case (_: Type, _: SType, _: SType) => + SJavaArrayString(true) + }, + )(thisClass, "split") + + registerWrappedScalaFunction3( + "split", + TString, + TString, + TInt32, + TArray(TString), + { + case (_: Type, _: SType, _: SType, _: SType) => + SJavaArrayString(true) + }, + )(thisClass, "splitLimited") + + registerWrappedScalaFunction3( + "replace", + TString, + TString, + TString, + TString, + { + case (_: Type, _: SType, _: SType, _: SType) => SJavaString + }, + )(thisClass, "replace") + + registerWrappedScalaFunction2( + "mkString", + TSet(TString), + TString, + TString, + { + case (_: Type, _: SType, _: SType) => SJavaString + }, + )(thisClass, "setMkString") + + registerSCode4( + "splitQuotedRegex", + TString, + TString, + TArray(TString), + TString, + TArray(TString), + { + case (_: Type, _: SType, _: SType, _: SType, _: SType) => SJavaArrayString(false) + }, + ) { case (_, cb, st: SJavaArrayString, s, separator, missing, quote, errorID) => val quoteStr = cb.newLocal[String]("quoteStr", quote.asString.loadString(cb)) val quoteChar = cb.newLocal[Char]("quoteChar") - cb.if_(quoteStr.length().cne(1), cb._fatalWithError(errorID, "quote 
must be a single character")) + cb.if_( + quoteStr.length().cne(1), + cb._fatalWithError(errorID, "quote must be a single character"), + ) cb.assign(quoteChar, quoteStr(0)) val string = cb.newLocal[String]("string", s.asString.loadString(cb)) val sep = cb.newLocal[String]("sep", separator.asString.loadString(cb)) val mv = missing.asIndexable - new SJavaArrayStringValue(st, generateSplitQuotedRegex(cb, string, Right(sep), Some(quoteChar), mv, errorID)) + new SJavaArrayStringValue( + st, + generateSplitQuotedRegex(cb, string, Right(sep), Some(quoteChar), mv, errorID), + ) } - registerSCode4("splitQuotedChar", TString, TString, TArray(TString), TString, TArray(TString), { - case (_: Type, _: SType, _: SType, _: SType, _: SType) => SJavaArrayString(false) - }) { case (r, cb, st: SJavaArrayString, s, separator, missing, quote, errorID) => + registerSCode4( + "splitQuotedChar", + TString, + TString, + TArray(TString), + TString, + TArray(TString), + { + case (_: Type, _: SType, _: SType, _: SType, _: SType) => SJavaArrayString(false) + }, + ) { case (_, cb, st: SJavaArrayString, s, separator, missing, quote, errorID) => val quoteStr = cb.newLocal[String]("quoteStr", quote.asString.loadString(cb)) val quoteChar = cb.newLocal[Char]("quoteChar") - cb.if_(quoteStr.length().cne(1), cb._fatalWithError(errorID, "quote must be a single character")) + cb.if_( + quoteStr.length().cne(1), + cb._fatalWithError(errorID, "quote must be a single character"), + ) cb.assign(quoteChar, quoteStr(0)) val string = cb.newLocal[String]("string", s.asString.loadString(cb)) val sep = cb.newLocal[String]("sep", separator.asString.loadString(cb)) val sepChar = cb.newLocal[Char]("sepChar") - cb.if_(sep.length().cne(1), cb._fatalWithError(errorID, "splitQuotedChar expected a single character for separator")) + cb.if_( + sep.length().cne(1), + cb._fatalWithError(errorID, "splitQuotedChar expected a single character for separator"), + ) cb.assign(sepChar, sep(0)) val mv = missing.asIndexable - new SJavaArrayStringValue(st, generateSplitQuotedRegex(cb, string, Left(sepChar), Some(quoteChar), mv, errorID)) + new SJavaArrayStringValue( + st, + generateSplitQuotedRegex(cb, string, Left(sepChar), Some(quoteChar), mv, errorID), + ) } - registerSCode3("splitRegex", TString, TString, TArray(TString), TArray(TString), { - case (_: Type, _: SType, _: SType, _: SType) => SJavaArrayString(false) - }) { case (r, cb, st: SJavaArrayString, s, separator, missing, errorID) => + registerSCode3( + "splitRegex", + TString, + TString, + TArray(TString), + TArray(TString), + { + case (_: Type, _: SType, _: SType, _: SType) => SJavaArrayString(false) + }, + ) { case (_, cb, st: SJavaArrayString, s, separator, missing, errorID) => val string = cb.newLocal[String]("string", s.asString.loadString(cb)) val sep = cb.newLocal[String]("sep", separator.asString.loadString(cb)) val mv = missing.asIndexable - new SJavaArrayStringValue(st, generateSplitQuotedRegex(cb, string, Right(sep), None, mv, errorID)) + new SJavaArrayStringValue( + st, + generateSplitQuotedRegex(cb, string, Right(sep), None, mv, errorID), + ) } - registerSCode3("splitChar", TString, TString, TArray(TString), TArray(TString), { - case (_: Type, _: SType, _: SType, _: SType) => SJavaArrayString(false) - }) { case (r, cb, st: SJavaArrayString, s, separator, missing, errorID) => + registerSCode3( + "splitChar", + TString, + TString, + TArray(TString), + TArray(TString), + { + case (_: Type, _: SType, _: SType, _: SType) => SJavaArrayString(false) + }, + ) { case (_, cb, st: 
SJavaArrayString, s, separator, missing, errorID) => val string = cb.newLocal[String]("string", s.asString.loadString(cb)) val sep = cb.newLocal[String]("sep", separator.asString.loadString(cb)) val sepChar = cb.newLocal[Char]("sepChar") - cb.if_(sep.length().cne(1), cb._fatalWithError(errorID, "splitChar expected a single character for separator")) + cb.if_( + sep.length().cne(1), + cb._fatalWithError(errorID, "splitChar expected a single character for separator"), + ) cb.assign(sepChar, sep(0)) val mv = missing.asIndexable - new SJavaArrayStringValue(st, generateSplitQuotedRegex(cb, string, Left(sepChar), None, mv, errorID)) + new SJavaArrayStringValue( + st, + generateSplitQuotedRegex(cb, string, Left(sepChar), None, mv, errorID), + ) } - registerWrappedScalaFunction2("mkString", TArray(TString), TString, TString, { - case (_: Type, _: SType, _: SType) => SJavaString - })(thisClass, "arrayMkString") - - registerIEmitCode2("firstMatchIn", TString, TString, TArray(TString), { - case (_: Type, _: EmitType, _: EmitType) => EmitType(SJavaArrayString(true), false) - }) { case (cb: EmitCodeBuilder, region: Value[Region], st: SJavaArrayString, _, - s: EmitCode, r: EmitCode) => - s.toI(cb).flatMap(cb) { case sc: SStringValue => - r.toI(cb).flatMap(cb) { case rc: SStringValue => - val out = cb.newLocal[Array[String]]("out", - Code.invokeScalaObject2[String, String, Array[String]]( - thisClass, "firstMatchIn", sc.loadString(cb), rc.loadString(cb))) - IEmitCode(cb, out.isNull, st.construct(cb, out)) + registerWrappedScalaFunction2( + "mkString", + TArray(TString), + TString, + TString, + { + case (_: Type, _: SType, _: SType) => SJavaString + }, + )(thisClass, "arrayMkString") + + registerIEmitCode2( + "firstMatchIn", + TString, + TString, + TArray(TString), + { + case (_: Type, _: EmitType, _: EmitType) => EmitType(SJavaArrayString(true), false) + }, + ) { + case ( + cb: EmitCodeBuilder, + _: Value[Region], + st: SJavaArrayString, + _, + s: EmitCode, + r: EmitCode, + ) => + s.toI(cb).flatMap(cb) { case sc: SStringValue => + r.toI(cb).flatMap(cb) { case rc: SStringValue => + val out = cb.newLocal[Array[String]]( + "out", + Code.invokeScalaObject2[String, String, Array[String]]( + thisClass, + "firstMatchIn", + sc.loadString(cb), + rc.loadString(cb), + ), + ) + IEmitCode(cb, out.isNull, st.construct(cb, out)) + } } - } } - registerEmitCode2("hamming", TString, TString, TInt32, { - case (_: Type, _: EmitType, _: EmitType) => EmitType(SInt32, false) - }) { case (r: EmitRegion, rt, _, e1: EmitCode, e2: EmitCode) => + registerEmitCode2( + "hamming", + TString, + TString, + TInt32, + { + case (_: Type, _: EmitType, _: EmitType) => EmitType(SInt32, false) + }, + ) { case (r: EmitRegion, _, _, e1: EmitCode, e2: EmitCode) => EmitCode.fromI(r.mb) { cb => e1.toI(cb).flatMap(cb) { case sc1: SStringValue => e2.toI(cb).flatMap(cb) { case sc2: SStringValue => @@ -463,36 +728,73 @@ object StringFunctions extends RegistryFunctions { val l2 = cb.newLocal[Int]("hamming_len_2", v2.invoke[Int]("length")) val m = l1.cne(l2) - IEmitCode(cb, m, { - cb.while_(i < l1, { - cb.if_(v1.invoke[Int, Char]("charAt", i).toI.cne(v2.invoke[Int, Char]("charAt", i).toI), - cb.assign(n, n + 1)) - cb.assign(i, i + 1) - }) - primitive(n) - }) + IEmitCode( + cb, + m, { + cb.while_( + i < l1, { + cb.if_( + v1.invoke[Int, Char]("charAt", i).toI.cne(v2.invoke[Int, Char]( + "charAt", + i, + ).toI), + cb.assign(n, n + 1), + ) + cb.assign(i, i + 1) + }, + ) + primitive(n) + }, + ) } } } } - registerWrappedScalaFunction1("escapeString", TString, 
TString, (_: Type, _: SType) => SJavaString)(thisClass, "escapeString") - registerWrappedScalaFunction3("strftime", TString, TInt64, TString, TString, { - case (_: Type, _: SType, _: SType, _: SType) => SJavaString - })(thisClass, "strftime") - registerWrappedScalaFunction3("strptime", TString, TString, TString, TInt64, { - case (_: Type, _: SType, _: SType, _: SType) => SInt64 - })(thisClass, "strptime") - - registerSCode("parse_json", Array(TString), TTuple(tv("T")), - (rType: Type, _: Seq[SType]) => SType.canonical(rType), typeParameters = Array(tv("T")) + registerWrappedScalaFunction1( + "escapeString", + TString, + TString, + (_: Type, _: SType) => SJavaString, + )(thisClass, "escapeString") + registerWrappedScalaFunction3( + "strftime", + TString, + TInt64, + TString, + TString, + { + case (_: Type, _: SType, _: SType, _: SType) => SJavaString + }, + )(thisClass, "strftime") + registerWrappedScalaFunction3( + "strptime", + TString, + TString, + TString, + TInt64, + { + case (_: Type, _: SType, _: SType, _: SType) => SInt64 + }, + )(thisClass, "strptime") + + registerSCode( + "parse_json", + Array(TString), + TTuple(tv("T")), + (rType: Type, _: Seq[SType]) => SType.canonical(rType), + typeParameters = Array(tv("T")), ) { case (er, cb, _, resultType, Array(s: SStringValue), _) => - val warnCtx = cb.emb.genFieldThisRef[mutable.HashSet[String]]("parse_json_context") cb.if_(warnCtx.load().isNull, cb.assign(warnCtx, Code.newInstance[mutable.HashSet[String]]())) - val row = Code.invokeScalaObject3[String, Type, mutable.HashSet[String], Row](JSONAnnotationImpex.getClass, "irImportAnnotation", - s.loadString(cb), er.mb.ecb.getType(resultType.virtualType.asInstanceOf[TTuple].types(0)), warnCtx) + val row = Code.invokeScalaObject3[String, Type, mutable.HashSet[String], Row]( + JSONAnnotationImpex.getClass, + "irImportAnnotation", + s.loadString(cb), + er.mb.ecb.getType(resultType.virtualType.asInstanceOf[TTuple].types(0)), + warnCtx, + ) unwrapReturn(cb, er.region, resultType, row) } diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/TableCalculateNewPartitions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/TableCalculateNewPartitions.scala index d47e69fe0ad..55c5aa5b096 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/TableCalculateNewPartitions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/TableCalculateNewPartitions.scala @@ -14,7 +14,8 @@ case class TableCalculateNewPartitions( def unionRequiredness(childType: types.RTable, resultType: types.TypeWithRequiredness): Unit = { val rinterval = types.tcoerce[types.RInterval]( - types.tcoerce[types.RIterable](resultType).elementType) + types.tcoerce[types.RIterable](resultType).elementType + ) val rstart = types.tcoerce[types.RStruct](rinterval.startType) val rend = types.tcoerce[types.RStruct](rinterval.endType) childType.keyFields.foreach { k => @@ -32,7 +33,13 @@ case class TableCalculateNewPartitions( if (ki.isEmpty) FastSeq() else - RVD.calculateKeyRanges(ctx, rvd.typ, ki, nPartitions, rvd.typ.key.length).rangeBounds.toIndexedSeq + RVD.calculateKeyRanges( + ctx, + rvd.typ, + ki, + nPartitions, + rvd.typ.key.length, + ).rangeBounds.toIndexedSeq } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/UtilFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/UtilFunctions.scala index 00dd36e1059..a1508903394 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/UtilFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/UtilFunctions.scala @@ -1,8 +1,6 @@ package 
is.hail.expr.ir.functions import is.hail.annotations.Region - -import java.util.IllegalFormatConversionException import is.hail.asm4s.{coerce => _, _} import is.hail.backend.HailStateManager import is.hail.expr.ir._ @@ -14,29 +12,40 @@ import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives._ import is.hail.types.virtual._ import is.hail.utils._ -import org.apache.spark.sql.Row import scala.reflect.ClassTag +import java.util.IllegalFormatConversionException + +import org.apache.spark.sql.Row + object UtilFunctions extends RegistryFunctions { - def parseBoolean(s: String, errID: Int): Boolean = try { - s.toBoolean - } catch { - case _: IllegalArgumentException => fatal(s"cannot parse boolean from input string '${StringEscapeUtils.escapeString(s)}'", errID) - } + def parseBoolean(s: String, errID: Int): Boolean = + try + s.toBoolean + catch { + case _: IllegalArgumentException => fatal( + s"cannot parse boolean from input string '${StringEscapeUtils.escapeString(s)}'", + errID, + ) + } - def parseInt32(s: String, errID: Int): Int = try { - s.toInt - } catch { - case _: IllegalArgumentException => fatal(s"cannot parse int32 from input string '${StringEscapeUtils.escapeString(s)}'", errID) - } + def parseInt32(s: String, errID: Int): Int = + try + s.toInt + catch { + case _: IllegalArgumentException => + fatal(s"cannot parse int32 from input string '${StringEscapeUtils.escapeString(s)}'", errID) + } - def parseInt64(s: String, errID: Int): Long = try { - s.toLong - } catch { - case _: IllegalArgumentException => fatal(s"cannot parse int64 from input string '${StringEscapeUtils.escapeString(s)}'", errID) - } + def parseInt64(s: String, errID: Int): Long = + try + s.toLong + catch { + case _: IllegalArgumentException => + fatal(s"cannot parse int64 from input string '${StringEscapeUtils.escapeString(s)}'", errID) + } def parseSpecialNum32(s: String, errID: Int): Float = { s.length match { @@ -77,18 +86,18 @@ object UtilFunctions extends RegistryFunctions { } def parseFloat32(s: String, errID: Int): Float = { - try { + try s.toFloat - } catch { + catch { case _: NumberFormatException => parseSpecialNum32(s, errID) } } def parseFloat64(s: String, errID: Int): Double = { - try { + try s.toDouble - } catch { + catch { case _: NumberFormatException => parseSpecialNum64(s, errID) } @@ -111,21 +120,23 @@ object UtilFunctions extends RegistryFunctions { case _: NumberFormatException => false } - def isValidFloat32(s: String): Boolean = try { - parseFloat32(s, -1) - true - } catch { - case _: NumberFormatException => false - case _: HailException => false - } + def isValidFloat32(s: String): Boolean = + try { + parseFloat32(s, -1) + true + } catch { + case _: NumberFormatException => false + case _: HailException => false + } - def isValidFloat64(s: String): Boolean = try { - parseFloat64(s, -1) - true - } catch { - case _: NumberFormatException => false - case _: HailException => false - } + def isValidFloat64(s: String): Boolean = + try { + parseFloat64(s, -1) + true + } catch { + case _: NumberFormatException => false + case _: HailException => false + } def min_ignore_missing(l: Int, lMissing: Boolean, r: Int, rMissing: Boolean): Int = if (lMissing) r else if (rMissing) l else Math.min(l, r) @@ -179,60 +190,118 @@ object UtilFunctions extends RegistryFunctions { def intMax(a: IR, b: IR): IR = If(ApplyComparisonOp(GT(a.typ), a, b), a, b) - def format(f: String, args: Row): String = try { - String.format(f, args.toSeq.map(_.asInstanceOf[java.lang.Object]): _*) 
- } catch { - case e: IllegalFormatConversionException => - fatal(s"Encountered invalid type for format string $f: format specifier ${e.getConversion} does not accept type ${e.getArgumentClass.getCanonicalName}") - } + def format(f: String, args: Row): String = + try + String.format(f, args.toSeq.map(_.asInstanceOf[java.lang.Object]): _*) + catch { + case e: IllegalFormatConversionException => + fatal( + s"Encountered invalid type for format string $f: format specifier ${e.getConversion} does not accept type ${e.getArgumentClass.getCanonicalName}" + ) + } - def registerAll() { + def registerAll(): Unit = { val thisClass = getClass - registerSCode4("valuesSimilar", tv("T"), tv("U"), TFloat64, TBoolean, TBoolean, { - case (_: Type, _: SType, _: SType, _: SType, _: SType) => SBoolean - }) { - case (er, cb, rt, l, r, tol, abs, _) => - assert(l.st.virtualType == r.st.virtualType, s"\n lt=${ l.st.virtualType }\n rt=${ r.st.virtualType }") + registerSCode4( + "valuesSimilar", + tv("T"), + tv("U"), + TFloat64, + TBoolean, + TBoolean, + { + case (_: Type, _: SType, _: SType, _: SType, _: SType) => SBoolean + }, + ) { + case (er, cb, _, l, r, tol, abs, _) => + assert( + l.st.virtualType == r.st.virtualType, + s"\n lt=${l.st.virtualType}\n rt=${r.st.virtualType}", + ) val lb = svalueToJavaValue(cb, er.region, l) val rb = svalueToJavaValue(cb, er.region, r) - primitive(cb.memoize(er.mb.getType(l.st.virtualType).invoke[Any, Any, Double, Boolean, Boolean]("valuesSimilar", lb, rb, tol.asDouble.value, abs.asBoolean.value))) + primitive(cb.memoize(er.mb.getType(l.st.virtualType).invoke[ + Any, + Any, + Double, + Boolean, + Boolean, + ]("valuesSimilar", lb, rb, tol.asDouble.value, abs.asBoolean.value))) + } + + registerCode1("triangle", TInt32, TInt32, (_: Type, _: SType) => SInt32) { + case (cb, _, _, nn) => + val n = nn.asInt.value + cb.memoize((n * (n + 1)) / 2) } - registerCode1("triangle", TInt32, TInt32, (_: Type, _: SType) => SInt32) { case (cb, _, rt, nn) => - val n = nn.asInt.value - cb.memoize((n * (n + 1)) / 2) + registerSCode1("toInt32", TBoolean, TInt32, (_: Type, _: SType) => SInt32) { + case (_, cb, _, x, _) => + primitive(cb.memoize(x.asBoolean.value.toI)) + } + registerSCode1("toInt64", TBoolean, TInt64, (_: Type, _: SType) => SInt64) { + case (_, cb, _, x, _) => + primitive(cb.memoize(x.asBoolean.value.toI.toL)) + } + registerSCode1("toFloat32", TBoolean, TFloat32, (_: Type, _: SType) => SFloat32) { + case (_, cb, _, x, _) => + primitive(cb.memoize(x.asBoolean.value.toI.toF)) + } + registerSCode1("toFloat64", TBoolean, TFloat64, (_: Type, _: SType) => SFloat64) { + case (_, cb, _, x, _) => + primitive(cb.memoize(x.asBoolean.value.toI.toD)) } - registerSCode1("toInt32", TBoolean, TInt32, (_: Type, _: SType) => SInt32) { case (_, cb, _, x, _) => - primitive(cb.memoize(x.asBoolean.value.toI)) } - registerSCode1("toInt64", TBoolean, TInt64, (_: Type, _: SType) => SInt64) { case (_, cb, _, x, _) => - primitive(cb.memoize(x.asBoolean.value.toI.toL)) } - registerSCode1("toFloat32", TBoolean, TFloat32, (_: Type, _: SType) => SFloat32) { case (_, cb, _, x, _) => - primitive(cb.memoize(x.asBoolean.value.toI.toF)) } - registerSCode1("toFloat64", TBoolean, TFloat64, (_: Type, _: SType) => SFloat64) { case (_, cb, _, x, _) => - primitive(cb.memoize(x.asBoolean.value.toI.toD)) } - - for ((name, t, rpt, ct) <- Seq[(String, Type, SType, ClassTag[_])]( - ("Boolean", TBoolean, SBoolean, implicitly[ClassTag[Boolean]]), - ("Int32", TInt32, SInt32, implicitly[ClassTag[Int]]), - ("Int64", TInt64, SInt64, 
implicitly[ClassTag[Long]]), - ("Float64", TFloat64, SFloat64, implicitly[ClassTag[Double]]), - ("Float32", TFloat32, SFloat32, implicitly[ClassTag[Float]]) - )) { + for ( + (name, t, rpt, ct) <- Seq[(String, Type, SType, ClassTag[_])]( + ("Boolean", TBoolean, SBoolean, implicitly[ClassTag[Boolean]]), + ("Int32", TInt32, SInt32, implicitly[ClassTag[Int]]), + ("Int64", TInt64, SInt64, implicitly[ClassTag[Long]]), + ("Float64", TFloat64, SFloat64, implicitly[ClassTag[Double]]), + ("Float32", TFloat32, SFloat32, implicitly[ClassTag[Float]]), + ) + ) { val ctString: ClassTag[String] = implicitly[ClassTag[String]] registerSCode1(s"to$name", TString, t, (_: Type, _: SType) => rpt) { - case (r, cb, rt, x: SStringValue, err) => + case (_, cb, rt, x: SStringValue, err) => val s = x.loadString(cb) - primitive(rt.virtualType, cb.memoizeAny(Code.invokeScalaObject2(thisClass, s"parse$name", s, err)(ctString, implicitly[ClassTag[Int]], ct), typeInfoFromClassTag(ct))) + primitive( + rt.virtualType, + cb.memoizeAny( + Code.invokeScalaObject2(thisClass, s"parse$name", s, err)( + ctString, + implicitly[ClassTag[Int]], + ct, + ), + typeInfoFromClassTag(ct), + ), + ) } - registerIEmitCode1(s"to${name}OrMissing", TString, t, (_: Type, _: EmitType) => EmitType(rpt, false)) { - case (cb, r, rt, err, x) => + registerIEmitCode1( + s"to${name}OrMissing", + TString, + t, + (_: Type, _: EmitType) => EmitType(rpt, false), + ) { + case (cb, _, rt, err, x) => x.toI(cb).flatMap(cb) { case sc: SStringValue => val sv = cb.newLocal[String]("s", sc.loadString(cb)) - IEmitCode(cb, + IEmitCode( + cb, !Code.invokeScalaObject1[String, Boolean](thisClass, s"isValid$name", sv), - primitive(rt.virtualType, cb.memoizeAny(Code.invokeScalaObject2(thisClass, s"parse$name", sv, err)(ctString, implicitly[ClassTag[Int]], ct), typeInfoFromClassTag(ct)))) + primitive( + rt.virtualType, + cb.memoizeAny( + Code.invokeScalaObject2(thisClass, s"parse$name", sv, err)( + ctString, + implicitly[ClassTag[Int]], + ct, + ), + typeInfoFromClassTag(ct), + ), + ), + ) } } } @@ -244,97 +313,234 @@ object UtilFunctions extends RegistryFunctions { Array("min", "max").foreach { name => registerCode2(name, TFloat32, TFloat32, TFloat32, (_: Type, _: SType, _: SType) => SFloat32) { - case (cb, r, rt, v1, v2) => - cb.memoize(Code.invokeStatic2[Math, Float, Float, Float](name, v1.asFloat.value, v2.asFloat.value)) + case (cb, _, _, v1, v2) => + cb.memoize(Code.invokeStatic2[Math, Float, Float, Float]( + name, + v1.asFloat.value, + v2.asFloat.value, + )) } registerCode2(name, TFloat64, TFloat64, TFloat64, (_: Type, _: SType, _: SType) => SFloat64) { - case (cb, r, rt, v1, v2) => - cb.memoize(Code.invokeStatic2[Math, Double, Double, Double](name, v1.asDouble.value, v2.asDouble.value)) + case (cb, _, _, v1, v2) => + cb.memoize(Code.invokeStatic2[Math, Double, Double, Double]( + name, + v1.asDouble.value, + v2.asDouble.value, + )) } val ignoreMissingName = name + "_ignore_missing" val ignoreNanName = "nan" + name val ignoreBothName = ignoreNanName + "_ignore_missing" - registerCode2(ignoreNanName, TFloat32, TFloat32, TFloat32, (_: Type, _: SType, _: SType) => SFloat32) { - case (cb, r, rt, v1, v2) => - cb.memoize(Code.invokeScalaObject2[Float, Float, Float](thisClass, ignoreNanName, v1.asFloat.value, v2.asFloat.value)) + registerCode2( + ignoreNanName, + TFloat32, + TFloat32, + TFloat32, + (_: Type, _: SType, _: SType) => SFloat32, + ) { + case (cb, _, _, v1, v2) => + cb.memoize(Code.invokeScalaObject2[Float, Float, Float]( + thisClass, + ignoreNanName, + 
v1.asFloat.value, + v2.asFloat.value, + )) } - registerCode2(ignoreNanName, TFloat64, TFloat64, TFloat64, (_: Type, _: SType, _: SType) => SFloat64) { - case (cb, r, rt, v1, v2) => - cb.memoize(Code.invokeScalaObject2[Double, Double, Double](thisClass, ignoreNanName, v1.asDouble.value, v2.asDouble.value)) + registerCode2( + ignoreNanName, + TFloat64, + TFloat64, + TFloat64, + (_: Type, _: SType, _: SType) => SFloat64, + ) { + case (cb, _, _, v1, v2) => + cb.memoize(Code.invokeScalaObject2[Double, Double, Double]( + thisClass, + ignoreNanName, + v1.asDouble.value, + v2.asDouble.value, + )) } - def ignoreMissingTriplet[T](cb: EmitCodeBuilder, rt: SType, v1: EmitCode, v2: EmitCode, name: String, f: (Code[T], Code[T]) => Code[T])(implicit ct: ClassTag[T], ti: TypeInfo[T]): IEmitCode = { - val value = cb.newLocal[T](s"ignore_missing_${ name }_value") + def ignoreMissingTriplet[T]( + cb: EmitCodeBuilder, + rt: SType, + v1: EmitCode, + v2: EmitCode, + name: String, + f: (Code[T], Code[T]) => Code[T], + )(implicit + ct: ClassTag[T], + ti: TypeInfo[T], + ): IEmitCode = { + val value = cb.newLocal[T](s"ignore_missing_${name}_value") val v1Value = v1.toI(cb).memoize(cb, "ignore_missing_v1") val v2Value = v2.toI(cb).memoize(cb, "ignore_missing_v2") val Lmissing = CodeLabel() val Ldefined = CodeLabel() v1Value.toI(cb) - .consume(cb, - { - v2Value.toI(cb).consume(cb, - cb.goto(Lmissing), - sc2 => cb.assignAny(value, sc2.asPrimitive.primitiveValue[T]) - ) - }, + .consume( + cb, + v2Value.toI(cb).consume( + cb, + cb.goto(Lmissing), + sc2 => cb.assignAny(value, sc2.asPrimitive.primitiveValue[T]), + ), { sc1 => cb.assign(value, sc1.asPrimitive.primitiveValue[T]) - v2Value.toI(cb).consume(cb, + v2Value.toI(cb).consume( + cb, {}, - sc2 => cb.assignAny(value, f(value, sc2.asPrimitive.primitiveValue[T])) + sc2 => cb.assignAny(value, f(value, sc2.asPrimitive.primitiveValue[T])), ) - }) + }, + ) cb.goto(Ldefined) IEmitCode(Lmissing, Ldefined, primitive(rt.virtualType, value), v1.required || v2.required) } - registerIEmitCode2(ignoreMissingName, TInt32, TInt32, TInt32, (_: Type, t1: EmitType, t2: EmitType) => EmitType(SInt32, t1.required || t2.required)) { - case (cb, r, rt, _, v1, v2) => ignoreMissingTriplet[Int](cb, rt, v1, v2, name, Code.invokeStatic2[Math, Int, Int, Int](name, _, _)) + registerIEmitCode2( + ignoreMissingName, + TInt32, + TInt32, + TInt32, + (_: Type, t1: EmitType, t2: EmitType) => EmitType(SInt32, t1.required || t2.required), + ) { + case (cb, _, rt, _, v1, v2) => ignoreMissingTriplet[Int]( + cb, + rt, + v1, + v2, + name, + Code.invokeStatic2[Math, Int, Int, Int](name, _, _), + ) } - registerIEmitCode2(ignoreMissingName, TInt64, TInt64, TInt64, (_: Type, t1: EmitType, t2: EmitType) => EmitType(SInt64, t1.required || t2.required)) { - case (cb, r, rt, _, v1, v2) => ignoreMissingTriplet[Long](cb, rt, v1, v2, name, Code.invokeStatic2[Math, Long, Long, Long](name, _, _)) + registerIEmitCode2( + ignoreMissingName, + TInt64, + TInt64, + TInt64, + (_: Type, t1: EmitType, t2: EmitType) => EmitType(SInt64, t1.required || t2.required), + ) { + case (cb, _, rt, _, v1, v2) => ignoreMissingTriplet[Long]( + cb, + rt, + v1, + v2, + name, + Code.invokeStatic2[Math, Long, Long, Long](name, _, _), + ) } - registerIEmitCode2(ignoreMissingName, TFloat32, TFloat32, TFloat32, (_: Type, t1: EmitType, t2: EmitType) => EmitType(SFloat32, t1.required || t2.required)) { - case (cb, r, rt, _, v1, v2) => ignoreMissingTriplet[Float](cb, rt, v1, v2, name, Code.invokeStatic2[Math, Float, Float, Float](name, _, _)) + 
registerIEmitCode2( + ignoreMissingName, + TFloat32, + TFloat32, + TFloat32, + (_: Type, t1: EmitType, t2: EmitType) => EmitType(SFloat32, t1.required || t2.required), + ) { + case (cb, _, rt, _, v1, v2) => ignoreMissingTriplet[Float]( + cb, + rt, + v1, + v2, + name, + Code.invokeStatic2[Math, Float, Float, Float](name, _, _), + ) } - registerIEmitCode2(ignoreMissingName, TFloat64, TFloat64, TFloat64, (_: Type, t1: EmitType, t2: EmitType) => EmitType(SFloat64, t1.required || t2.required)) { - case (cb, r, rt, _, v1, v2) => ignoreMissingTriplet[Double](cb, rt, v1, v2, name, Code.invokeStatic2[Math, Double, Double, Double](name, _, _)) + registerIEmitCode2( + ignoreMissingName, + TFloat64, + TFloat64, + TFloat64, + (_: Type, t1: EmitType, t2: EmitType) => EmitType(SFloat64, t1.required || t2.required), + ) { + case (cb, _, rt, _, v1, v2) => ignoreMissingTriplet[Double]( + cb, + rt, + v1, + v2, + name, + Code.invokeStatic2[Math, Double, Double, Double](name, _, _), + ) } - registerIEmitCode2(ignoreBothName, TFloat32, TFloat32, TFloat32, (_: Type, t1: EmitType, t2: EmitType) => EmitType(SFloat32, t1.required || t2.required)) { - case (cb, r, rt, _, v1, v2) => ignoreMissingTriplet[Float](cb, rt, v1, v2, ignoreNanName, Code.invokeScalaObject2[Float, Float, Float](thisClass, ignoreNanName, _, _)) + registerIEmitCode2( + ignoreBothName, + TFloat32, + TFloat32, + TFloat32, + (_: Type, t1: EmitType, t2: EmitType) => EmitType(SFloat32, t1.required || t2.required), + ) { + case (cb, _, rt, _, v1, v2) => ignoreMissingTriplet[Float]( + cb, + rt, + v1, + v2, + ignoreNanName, + Code.invokeScalaObject2[Float, Float, Float](thisClass, ignoreNanName, _, _), + ) } - registerIEmitCode2(ignoreBothName, TFloat64, TFloat64, TFloat64, (_: Type, t1: EmitType, t2: EmitType) => EmitType(SFloat64, t1.required || t2.required)) { - case (cb, r, rt, _, v1, v2) => ignoreMissingTriplet[Double](cb, rt, v1, v2, ignoreNanName, Code.invokeScalaObject2[Double, Double, Double](thisClass, ignoreNanName, _, _)) + registerIEmitCode2( + ignoreBothName, + TFloat64, + TFloat64, + TFloat64, + (_: Type, t1: EmitType, t2: EmitType) => EmitType(SFloat64, t1.required || t2.required), + ) { + case (cb, _, rt, _, v1, v2) => ignoreMissingTriplet[Double]( + cb, + rt, + v1, + v2, + ignoreNanName, + Code.invokeScalaObject2[Double, Double, Double](thisClass, ignoreNanName, _, _), + ) } } - registerSCode2("format", TString, tv("T", "tuple"), TString, (_: Type, _: SType, _: SType) => SJavaString) { + registerSCode2( + "format", + TString, + tv("T", "tuple"), + TString, + (_: Type, _: SType, _: SType) => SJavaString, + ) { case (r, cb, st: SJavaString.type, format, args, _) => val javaObjArgs = Code.checkcast[Row](svalueToJavaValue(cb, r.region, args)) - val formatted = Code.invokeScalaObject2[String, Row, String](thisClass, "format", format.asString.loadString(cb), javaObjArgs) + val formatted = Code.invokeScalaObject2[String, Row, String]( + thisClass, + "format", + format.asString.loadString(cb), + javaObjArgs, + ) st.construct(cb, formatted) } - registerIEmitCode2("land", TBoolean, TBoolean, TBoolean, (_: Type, tl: EmitType, tr: EmitType) => EmitType(SBoolean, tl.required && tr.required)) { - case (cb, _, rt,_ , l, r) => + registerIEmitCode2( + "land", + TBoolean, + TBoolean, + TBoolean, + (_: Type, tl: EmitType, tr: EmitType) => EmitType(SBoolean, tl.required && tr.required), + ) { + case (cb, _, _, _, l, r) => if (l.required && r.required) { val result = cb.newLocal[Boolean]("land_result") - cb.if_(l.toI(cb).get(cb).asBoolean.value, { - 
cb.assign(result, r.toI(cb).get(cb).asBoolean.value) - }, { - cb.assign(result, const(false)) - }) + cb.if_( + l.toI(cb).getOrAssert(cb).asBoolean.value, + cb.assign(result, r.toI(cb).getOrAssert(cb).asBoolean.value), + cb.assign(result, const(false)), + ) IEmitCode.present(cb, primitive(result)) } else { @@ -345,34 +551,40 @@ object UtilFunctions extends RegistryFunctions { val M = const((1 << 5) | (1 << 6) | (1 << 9)) l.toI(cb) - .consume(cb, + .consume( + cb, cb.assign(w, 1), - b1 => cb.assign(w, b1.asBoolean.value.mux(const(2), const(0))) + b1 => cb.assign(w, b1.asBoolean.value.mux(const(2), const(0))), ) - cb.if_(w.cne(0), - { - r.toI(cb).consume(cb, - cb.assign(w, w | const(4)), - { b2 => - cb.assign(w, w | b2.asBoolean.value.mux(const(8), const(0))) - } - ) - }) + cb.if_( + w.cne(0), + r.toI(cb).consume( + cb, + cb.assign(w, w | const(4)), + b2 => cb.assign(w, w | b2.asBoolean.value.mux(const(8), const(0))), + ), + ) IEmitCode(cb, ((M >> w) & 1).cne(0), primitive(cb.memoize(w.ceq(10)))) } } - registerIEmitCode2("lor", TBoolean, TBoolean, TBoolean, (_: Type, tl: EmitType, tr: EmitType) => EmitType(SBoolean, tl.required && tr.required)) { - case (cb, _, rt,_, l, r) => + registerIEmitCode2( + "lor", + TBoolean, + TBoolean, + TBoolean, + (_: Type, tl: EmitType, tr: EmitType) => EmitType(SBoolean, tl.required && tr.required), + ) { + case (cb, _, _, _, l, r) => if (l.required && r.required) { val result = cb.newLocal[Boolean]("land_result") - cb.if_(l.toI(cb).get(cb).asBoolean.value, { - cb.assign(result, const(true)) - }, { - cb.assign(result, r.toI(cb).get(cb).asBoolean.value) - }) + cb.if_( + l.toI(cb).getOrAssert(cb).asBoolean.value, + cb.assign(result, const(true)), + cb.assign(result, r.toI(cb).getOrAssert(cb).asBoolean.value), + ) IEmitCode.present(cb, primitive(result)) } else { @@ -383,52 +595,69 @@ object UtilFunctions extends RegistryFunctions { val M = const((1 << 5) | (1 << 1) | (1 << 4)) l.toI(cb) - .consume(cb, + .consume( + cb, cb.assign(w, 1), - b1 => cb.assign(w, b1.asBoolean.value.mux(const(2), const(0))) + b1 => cb.assign(w, b1.asBoolean.value.mux(const(2), const(0))), ) - cb.if_(w.cne(2), - { - r.toI(cb).consume(cb, - cb.assign(w, w | const(4)), - { b2 => - cb.assign(w, w | b2.asBoolean.value.mux(const(8), const(0))) - } - ) - }) + cb.if_( + w.cne(2), + r.toI(cb).consume( + cb, + cb.assign(w, w | const(4)), + b2 => cb.assign(w, w | b2.asBoolean.value.mux(const(8), const(0))), + ), + ) IEmitCode(cb, ((M >> w) & 1).cne(0), primitive(cb.memoize(w.cne(0)))) } } - registerIEmitCode4("getVCFHeader", TString, TString, TString, TString, - VCFHeaderInfo.headerType, (_, fileET, _, _, _) => EmitType(VCFHeaderInfo.headerTypePType.sType, fileET.required)) { - case (cb, r, rt, errID, file, filter, find, replace) => + registerIEmitCode4( + "getVCFHeader", + TString, + TString, + TString, + TString, + VCFHeaderInfo.headerType, + (_, fileET, _, _, _) => EmitType(VCFHeaderInfo.headerTypePType.sType, fileET.required), + ) { + case (cb, r, _, _, file, filter, find, replace) => file.toI(cb).map(cb) { case filePath: SStringValue => val filterVar = cb.newLocal[String]("filterVar") val findVar = cb.newLocal[String]("findVar") val replaceVar = cb.newLocal[String]("replaceVar") - filter.toI(cb).consume(cb, { - cb.assign(filterVar, Code._null) - }, { filt => - cb.assign(filterVar, filt.asString.loadString(cb)) - }) - find.toI(cb).consume(cb, { - cb.assign(findVar, Code._null) - }, { find => - cb.assign(findVar, find.asString.loadString(cb)) - }) - replace.toI(cb).consume(cb, { - 
cb.assign(replaceVar, Code._null) - }, { replace => - cb.assign(replaceVar, replace.asString.loadString(cb)) - }) + filter.toI(cb).consume( + cb, + cb.assign(filterVar, Code._null), + filt => cb.assign(filterVar, filt.asString.loadString(cb)), + ) + find.toI(cb).consume( + cb, + cb.assign(findVar, Code._null), + find => cb.assign(findVar, find.asString.loadString(cb)), + ) + replace.toI(cb).consume( + cb, + cb.assign(replaceVar, Code._null), + replace => cb.assign(replaceVar, replace.asString.loadString(cb)), + ) val hd = Code.invokeScalaObject5[FS, String, String, String, String, VCFHeaderInfo]( - LoadVCF.getClass, "getVCFHeaderInfo", cb.emb.getFS, - filePath.loadString(cb), filterVar, findVar, replaceVar) - val addr = cb.memoize(hd.invoke[HailStateManager, Region, Boolean, Long]("writeToRegion", - cb.emb.getObject(cb.emb.ecb.ctx.stateManager), r, const(false))) + LoadVCF.getClass, + "getVCFHeaderInfo", + cb.emb.getFS, + filePath.loadString(cb), + filterVar, + findVar, + replaceVar, + ) + val addr = cb.memoize(hd.invoke[HailStateManager, Region, Boolean, Long]( + "writeToRegion", + cb.emb.getObject(cb.emb.ecb.ctx.stateManager), + r, + const(false), + )) VCFHeaderInfo.headerTypePType.loadCheapSCode(cb, addr) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/CanLowerEfficiently.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/CanLowerEfficiently.scala index 122df59567a..f2755d5d7aa 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/CanLowerEfficiently.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/CanLowerEfficiently.scala @@ -2,7 +2,9 @@ package is.hail.expr.ir.lowering import is.hail.backend.ExecuteContext import is.hail.expr.ir._ -import is.hail.expr.ir.functions.{TableCalculateNewPartitions, TableToValueFunction, WrappedMatrixToTableFunction} +import is.hail.expr.ir.functions.{ + TableCalculateNewPartitions, TableToValueFunction, WrappedMatrixToTableFunction, +} import is.hail.expr.ir.lowering.LowerDistributedSort.LocalSortReader import is.hail.io.avro.AvroTableReader import is.hail.io.bgen.MatrixBGENReader @@ -39,45 +41,48 @@ object CanLowerEfficiently { case TableRead(_, _, _: TableFromBlockMatrixNativeReader) => fail(s"no lowering for TableFromBlockMatrixNativeReader") - case t: TableLiteral => + case _: TableLiteral => case TableRepartition(_, _, RepartitionStrategy.NAIVE_COALESCE) => - case t: TableRepartition => fail(s"TableRepartition has no lowered implementation") - case t: TableParallelize => - case t: TableRange => - case TableKeyBy(child, keys, isSorted) => - case t: TableOrderBy => - case t: TableFilter => - case t: TableHead => - case t: TableTail => - case t: TableJoin => - case TableIntervalJoin(_, _, _, true) => fail("TableIntervalJoin with \"product=true\" has no lowered implementation") - case TableIntervalJoin(_, _, _, false) => - case t: TableLeftJoinRightDistinct => - case t: TableMapPartitions => - case t: TableMapRows => - case t: TableMapGlobals => - case t: TableExplode => - case t: TableUnion if t.childrenSeq.length > 16 => fail(s"TableUnion lowering generates deeply nested IR if it has many children") - case t: TableUnion => - case t: TableMultiWayZipJoin => fail(s"TableMultiWayZipJoin is not passing tests due to problems in ptype inference in StreamZipJoin") - case t: TableDistinct => - case t: TableKeyByAndAggregate => - case t: TableAggregateByKey => - case t: TableRename => - case t: TableFilterIntervals => - case t: TableGen => + case _: TableRepartition => fail(s"TableRepartition has no lowered implementation") + case _: 
TableParallelize => + case _: TableRange => + case TableKeyBy(_, _, _) => + case _: TableOrderBy => + case _: TableFilter => + case _: TableHead => + case _: TableTail => + case _: TableJoin => + case _: TableIntervalJoin => + case _: TableLeftJoinRightDistinct => + case _: TableMapPartitions => + case _: TableMapRows => + case _: TableMapGlobals => + case _: TableExplode => + case t: TableUnion if t.childrenSeq.length > 16 => + fail(s"TableUnion lowering generates deeply nested IR if it has many children") + case _: TableUnion => + case _: TableMultiWayZipJoin => fail( + s"TableMultiWayZipJoin is not passing tests due to problems in ptype inference in StreamZipJoin" + ) + case _: TableDistinct => + case _: TableKeyByAndAggregate => + case _: TableAggregateByKey => + case _: TableRename => + case _: TableFilterIntervals => + case _: TableGen => case TableToTableApply(_, TableFilterPartitions(_, _)) => case TableToTableApply(_, WrappedMatrixToTableFunction(_: LocalLDPrune, _, _, _)) => - case t: TableToTableApply => fail(s"TableToTableApply") - case t: BlockMatrixToTableApply => fail(s"BlockMatrixToTableApply") - case t: BlockMatrixToTable => fail(s"BlockMatrixToTable has no lowered implementation") - - case x: BlockMatrixAgg => fail(s"BlockMatrixAgg needs to do tree aggregation") - case x: BlockMatrixIR => fail(s"BlockMatrixIR lowering not yet efficient/scalable") - case x: BlockMatrixWrite => fail(s"BlockMatrixIR lowering not yet efficient/scalable") - case x: BlockMatrixMultiWrite => fail(s"BlockMatrixIR lowering not yet efficient/scalable") - case x: BlockMatrixCollect => fail(s"BlockMatrixIR lowering not yet efficient/scalable") - case x: BlockMatrixToValueApply => fail(s"BlockMatrixIR lowering not yet efficient/scalable") + case _: TableToTableApply => fail(s"TableToTableApply") + case _: BlockMatrixToTableApply => fail(s"BlockMatrixToTableApply") + case _: BlockMatrixToTable => fail(s"BlockMatrixToTable has no lowered implementation") + + case _: BlockMatrixAgg => fail(s"BlockMatrixAgg needs to do tree aggregation") + case _: BlockMatrixIR => fail(s"BlockMatrixIR lowering not yet efficient/scalable") + case _: BlockMatrixWrite => fail(s"BlockMatrixIR lowering not yet efficient/scalable") + case _: BlockMatrixMultiWrite => fail(s"BlockMatrixIR lowering not yet efficient/scalable") + case _: BlockMatrixCollect => fail(s"BlockMatrixIR lowering not yet efficient/scalable") + case _: BlockMatrixToValueApply => + fail(s"BlockMatrixIR lowering not yet efficient/scalable") case _: MatrixMultiWrite => @@ -85,12 +90,14 @@ object CanLowerEfficiently { case TableToValueApply(_, ForceCountTable()) => case TableToValueApply(_, NPartitionsTable()) => case TableToValueApply(_, TableCalculateNewPartitions(_)) => - case TableToValueApply(_, f: TableToValueFunction) => fail(s"TableToValueApply: no lowering for ${ f.getClass.getName }") + case TableToValueApply(_, f: TableToValueFunction) => + fail(s"TableToValueApply: no lowering for ${f.getClass.getName}") case TableAggregate(_, _) => case TableCollect(_) => case TableGetGlobals(_) => - case TableWrite(_, writer) => if (!writer.canLowerEfficiently) fail(s"writer has no efficient lowering: ${ writer.getClass.getSimpleName }") + case TableWrite(_, writer) => if (!writer.canLowerEfficiently) + fail(s"writer has no efficient lowering: ${writer.getClass.getSimpleName}") case TableMultiWrite(_, _) => case RelationalRef(_, _) => throw new RuntimeException(s"unexpected relational ref") @@ -98,7 +105,7 @@ object CanLowerEfficiently { case x: IR => // nodes 
with relational children should be enumerated above explicitly if (!x.children.forall(_.isInstanceOf[IR])) { - throw new RuntimeException(s"IR must be enumerated explicitly: ${ x.getClass.getName }") + throw new RuntimeException(s"IR must be enumerated explicitly: ${x.getClass.getName}") } } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/EvalRelationalLets.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/EvalRelationalLets.scala index 59c05de4b85..a9e9c25d9eb 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/EvalRelationalLets.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/EvalRelationalLets.scala @@ -1,7 +1,10 @@ package is.hail.expr.ir.lowering import is.hail.backend.ExecuteContext -import is.hail.expr.ir.{BaseIR, CompileAndEvaluate, IR, RelationalLet, RelationalLetMatrixTable, RelationalLetTable, RelationalRef, RewriteBottomUp} +import is.hail.expr.ir.{ + BaseIR, CompileAndEvaluate, IR, RelationalLet, RelationalLetMatrixTable, RelationalLetTable, + RelationalRef, +} object EvalRelationalLets { // need to run the rest of lowerings to eval. diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/IRState.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/IRState.scala index b786dcd9b3c..f0f55751a4a 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/IRState.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/IRState.scala @@ -1,6 +1,8 @@ package is.hail.expr.ir.lowering -import is.hail.expr.ir.{BaseIR, RelationalLet, RelationalRef, TableKeyBy, TableKeyByAndAggregate, TableOrderBy} +import is.hail.expr.ir.{ + BaseIR, RelationalLet, RelationalRef, TableKeyBy, TableKeyByAndAggregate, TableOrderBy, +} trait IRState { @@ -10,7 +12,7 @@ trait IRState { final def verify(ir: BaseIR): Unit = { if (!rules.forall(_.allows(ir))) - throw new RuntimeException(s"lowered state ${ this.getClass.getCanonicalName } forbids IR $ir") + throw new RuntimeException(s"lowered state ${this.getClass.getCanonicalName} forbids IR $ir") ir.children.foreach(verify) } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerAndExecuteShuffles.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerAndExecuteShuffles.scala index 22313ba27b1..30070440ec5 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerAndExecuteShuffles.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerAndExecuteShuffles.scala @@ -1,121 +1,201 @@ package is.hail.expr.ir.lowering import is.hail.backend.ExecuteContext -import is.hail.expr.ir.agg.{Extract, PhysicalAggSig, TakeStateSig} import is.hail.expr.ir.{Requiredness, _} +import is.hail.expr.ir.agg.{Extract, PhysicalAggSig, TakeStateSig} import is.hail.types._ import is.hail.types.virtual._ import is.hail.utils.FastSeq - object LowerAndExecuteShuffles { def apply(ir: BaseIR, ctx: ExecuteContext, passesBelow: LoweringPipeline): BaseIR = { - RewriteBottomUp(ir, { - case t@TableKeyBy(child, key, isSorted) if !t.definitelyDoesNotShuffle => - val r = Requiredness(child, ctx) - val reader = ctx.backend.lowerDistributedSort(ctx, child, key.map(k => SortField(k, Ascending)), r.lookup(child).asInstanceOf[RTable]) - Some(TableRead(t.typ, false, reader)) - - case t@TableOrderBy(child, sortFields) if !t.definitelyDoesNotShuffle => - val r = Requiredness(child, ctx) - val reader = ctx.backend.lowerDistributedSort(ctx, child, sortFields, r.lookup(child).asInstanceOf[RTable]) - Some(TableRead(t.typ, false, reader)) - - case t@TableKeyByAndAggregate(child, expr, newKey, nPartitions, bufferSize) => - val newKeyType = 
newKey.typ.asInstanceOf[TStruct] - val resultUID = genUID() - - val req = Requiredness(t, ctx) - - val aggs = Extract(expr, resultUID, req) - val postAggIR = aggs.postAggIR - val init = aggs.init - val seq = aggs.seqPerElt - val aggSigs = aggs.aggs - - val streamName = genUID() - val streamTyp = TStream(child.typ.rowType) - var ts = child - - val origGlobalTyp = ts.typ.globalType - ts = TableKeyBy(child, IndexedSeq()) - ts = TableMapGlobals(ts, MakeStruct(FastSeq( - ("oldGlobals", Ref("global", origGlobalTyp)), - ("__initState", - RunAgg(init, MakeTuple.ordered(aggSigs.indices.map { aIdx => AggStateValue(aIdx, aggSigs(aIdx).state) }), - aggSigs.map(_.state)))))) - - val insGlobName = genUID() - def insGlob = Ref(insGlobName, ts.typ.globalType) - val partiallyAggregated = - TableMapPartitions(ts, insGlob.name, streamName, - Let(FastSeq("global" -> GetField(insGlob, "oldGlobals")), - StreamBufferedAggregate(Ref(streamName, streamTyp), bindIR(GetField(insGlob, "__initState")) { states => - Begin(aggSigs.indices.map { aIdx => - InitFromSerializedValue(aIdx, GetTupleElement(states, aIdx), aggSigs(aIdx).state) - }) - }, newKey, seq, "row", aggSigs, bufferSize) - ), - 0, 0).noSharing(ctx) - - - val analyses = LoweringAnalyses(partiallyAggregated, ctx) - val preShuffleStage = ctx.backend.tableToTableStage(ctx, partiallyAggregated, analyses) - // annoying but no better alternative right now - val rt = analyses.requirednessAnalysis.lookup(partiallyAggregated).asInstanceOf[RTable] - val partiallyAggregatedReader = ctx.backend.lowerDistributedSort(ctx, - preShuffleStage, - newKeyType.fieldNames.map(k => SortField(k, Ascending)), - rt, - nPartitions) - - val takeVirtualSig = TakeStateSig(VirtualTypeWithReq(newKeyType, rt.rowType.select(newKeyType.fieldNames))) - val takeAggSig = PhysicalAggSig(Take(), takeVirtualSig) - val aggStateSigsPlusTake = aggs.states ++ Array(takeVirtualSig) - - val postAggUID = genUID() - val resultFromTakeUID = genUID() - val result = ResultOp(aggs.aggs.length, takeAggSig) - - val shuffleRead = TableRead(partiallyAggregatedReader.fullType, false, partiallyAggregatedReader) - - val partStream = Ref(genUID(), TStream(shuffleRead.typ.rowType)) - val tmp = TableMapPartitions(shuffleRead, insGlob.name, partStream.name, - Let(FastSeq("global" -> GetField(insGlob, "oldGlobals")), - mapIR(StreamGroupByKey(partStream, newKeyType.fieldNames.toIndexedSeq, missingEqual = true)) { groupRef => - RunAgg(Begin(FastSeq( - bindIR(GetField(insGlob, "__initState")) { states => - Begin(aggSigs.indices.map { aIdx => InitFromSerializedValue(aIdx, GetTupleElement(states, aIdx), aggSigs(aIdx).state) }) - }, - InitOp(aggSigs.length, IndexedSeq(I32(1)), PhysicalAggSig(Take(), takeVirtualSig)), - forIR(groupRef) { elem => - Begin(FastSeq( - SeqOp(aggSigs.length, IndexedSeq(SelectFields(elem, newKeyType.fieldNames)), PhysicalAggSig(Take(), takeVirtualSig)), - Begin((0 until aggSigs.length).map { aIdx => - CombOpValue(aIdx, GetTupleElement(GetField(elem, "agg"), aIdx), aggSigs(aIdx)) - }))) - })), - Let( - FastSeq( - resultUID -> ResultOp.makeTuple(aggs.aggs), - postAggUID -> postAggIR, - resultFromTakeUID -> result - ), { - val keyIRs: IndexedSeq[(String, IR)] = - newKeyType.fieldNames.map(keyName => keyName -> GetField(ArrayRef(Ref(resultFromTakeUID, result.typ), 0), keyName)) - - MakeStruct(keyIRs ++ expr.typ.asInstanceOf[TStruct].fieldNames.map { f => - (f, GetField(Ref(postAggUID, postAggIR.typ), f)) - }) + RewriteBottomUp( + ir, + { + case t @ TableKeyBy(child, key, _) if !t.definitelyDoesNotShuffle 
=> + val r = Requiredness(child, ctx) + val reader = ctx.backend.lowerDistributedSort( + ctx, + child, + key.map(k => SortField(k, Ascending)), + r.lookup(child).asInstanceOf[RTable], + ) + Some(TableRead(t.typ, false, reader)) + + case t @ TableOrderBy(child, sortFields) if !t.definitelyDoesNotShuffle => + val r = Requiredness(child, ctx) + val reader = ctx.backend.lowerDistributedSort( + ctx, + child, + sortFields, + r.lookup(child).asInstanceOf[RTable], + ) + Some(TableRead(t.typ, false, reader)) + + case t @ TableKeyByAndAggregate(child, expr, newKey, nPartitions, bufferSize) => + val newKeyType = newKey.typ.asInstanceOf[TStruct] + + val req = Requiredness(t, ctx) + + val aggs = Extract(expr, req) + val postAggIR = aggs.postAggIR + val init = aggs.init + val seq = aggs.seqPerElt + val aggSigs = aggs.aggs + + val streamName = genUID() + val streamTyp = TStream(child.typ.rowType) + var ts = child + + val origGlobalTyp = ts.typ.globalType + ts = TableKeyBy(child, IndexedSeq()) + ts = TableMapGlobals( + ts, + MakeStruct(FastSeq( + ("oldGlobals", Ref("global", origGlobalTyp)), + ( + "__initState", + RunAgg( + init, + MakeTuple.ordered(aggSigs.indices.map { aIdx => + AggStateValue(aIdx, aggSigs(aIdx).state) }), - aggStateSigsPlusTake) - } - ), - newKeyType.size, newKeyType.size - 1 - ) - Some(TableMapGlobals(tmp, GetField(Ref("global", insGlob.typ), "oldGlobals"))) - case _ => None - }) + aggSigs.map(_.state), + ), + ), + )), + ) + + val insGlobName = genUID() + def insGlob = Ref(insGlobName, ts.typ.globalType) + val partiallyAggregated = + TableMapPartitions( + ts, + insGlob.name, + streamName, + Let( + FastSeq("global" -> GetField(insGlob, "oldGlobals")), + StreamBufferedAggregate( + Ref(streamName, streamTyp), + bindIR(GetField(insGlob, "__initState")) { states => + Begin(aggSigs.indices.map { aIdx => + InitFromSerializedValue( + aIdx, + GetTupleElement(states, aIdx), + aggSigs(aIdx).state, + ) + }) + }, + newKey, + seq, + "row", + aggSigs, + bufferSize, + ), + ), + 0, + 0, + ).noSharing(ctx) + + val analyses = LoweringAnalyses(partiallyAggregated, ctx) + val preShuffleStage = ctx.backend.tableToTableStage(ctx, partiallyAggregated, analyses) + // annoying but no better alternative right now + val rt = analyses.requirednessAnalysis.lookup(partiallyAggregated).asInstanceOf[RTable] + val partiallyAggregatedReader = ctx.backend.lowerDistributedSort( + ctx, + preShuffleStage, + newKeyType.fieldNames.map(k => SortField(k, Ascending)), + rt, + nPartitions, + ) + + val takeVirtualSig = + TakeStateSig(VirtualTypeWithReq(newKeyType, rt.rowType.select(newKeyType.fieldNames))) + val takeAggSig = PhysicalAggSig(Take(), takeVirtualSig) + val aggStateSigsPlusTake = aggs.states ++ Array(takeVirtualSig) + + val postAggUID = genUID() + val resultFromTakeUID = genUID() + val result = ResultOp(aggs.aggs.length, takeAggSig) + + val shuffleRead = + TableRead(partiallyAggregatedReader.fullType, false, partiallyAggregatedReader) + + val partStream = Ref(genUID(), TStream(shuffleRead.typ.rowType)) + val tmp = TableMapPartitions( + shuffleRead, + insGlob.name, + partStream.name, + Let( + FastSeq("global" -> GetField(insGlob, "oldGlobals")), + mapIR(StreamGroupByKey( + partStream, + newKeyType.fieldNames.toIndexedSeq, + missingEqual = true, + )) { groupRef => + RunAgg( + Begin(FastSeq( + bindIR(GetField(insGlob, "__initState")) { states => + Begin(aggSigs.indices.map { aIdx => + InitFromSerializedValue( + aIdx, + GetTupleElement(states, aIdx), + aggSigs(aIdx).state, + ) + }) + }, + InitOp( + aggSigs.length, + 
IndexedSeq(I32(1)), + PhysicalAggSig(Take(), takeVirtualSig), + ), + forIR(groupRef) { elem => + Begin(FastSeq( + SeqOp( + aggSigs.length, + IndexedSeq(SelectFields(elem, newKeyType.fieldNames)), + PhysicalAggSig(Take(), takeVirtualSig), + ), + Begin((0 until aggSigs.length).map { aIdx => + CombOpValue( + aIdx, + GetTupleElement(GetField(elem, "agg"), aIdx), + aggSigs(aIdx), + ) + }), + )) + }, + )), + Let( + FastSeq( + aggs.resultRef.name -> ResultOp.makeTuple(aggs.aggs), + postAggUID -> postAggIR, + resultFromTakeUID -> result, + ), { + val keyIRs: IndexedSeq[(String, IR)] = + newKeyType.fieldNames.map(keyName => + keyName -> GetField( + ArrayRef(Ref(resultFromTakeUID, result.typ), 0), + keyName, + ) + ) + + MakeStruct(keyIRs ++ expr.typ.asInstanceOf[TStruct].fieldNames.map { f => + (f, GetField(Ref(postAggUID, postAggIR.typ), f)) + }) + }, + ), + aggStateSigsPlusTake, + ) + }, + ), + newKeyType.size, + newKeyType.size - 1, + ) + Some(TableMapGlobals(tmp, GetField(Ref("global", insGlob.typ), "oldGlobals"))) + case _ => None + }, + ) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala index c6141dd21cb..91f7bf02caf 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerBlockMatrixIR.scala @@ -5,8 +5,8 @@ import is.hail.expr.Nat import is.hail.expr.ir._ import is.hail.expr.ir.functions.GetElement import is.hail.rvd.RVDPartitioner +import is.hail.types.{tcoerce, BlockMatrixSparsity, BlockMatrixType, TypeWithRequiredness} import is.hail.types.virtual._ -import is.hail.types.{BlockMatrixSparsity, BlockMatrixType, TypeWithRequiredness, tcoerce} import is.hail.utils._ object BlockMatrixStage { @@ -20,7 +20,13 @@ case class EmptyBlockMatrixStage(eltType: Type) extends BlockMatrixStage(FastSeq def blockBody(ctxRef: Ref): IR = NA(TNDArray(eltType, Nat(2))) - override def collectBlocks(staticID: String, dynamicID: IR = NA(TString))(f: (IR, IR) => IR, blocksToCollect: Array[(Int, Int)]): IR = { + override def collectBlocks( + staticID: String, + dynamicID: IR = NA(TString), + )( + f: (IR, IR) => IR, + blocksToCollect: Array[(Int, Int)], + ): IR = { assert(blocksToCollect.isEmpty) MakeArray(FastSeq(), TArray(f(Ref("x", ctxType), blockBody(Ref("x", ctxType))).typ)) } @@ -31,13 +37,21 @@ abstract class BlockMatrixStage(val broadcastVals: IndexedSeq[Ref], val ctxType: def blockBody(ctxRef: Ref): IR - def collectBlocks(staticID: String, dynamicID: IR = NA(TString))(f: (IR, IR) => IR, blocksToCollect: Array[(Int, Int)]): IR = { + def collectBlocks( + staticID: String, + dynamicID: IR = NA(TString), + )( + f: (IR, IR) => IR, + blocksToCollect: Array[(Int, Int)], + ): IR = { val ctxRef = Ref(genUID(), ctxType) val body = f(ctxRef, blockBody(ctxRef)) val ctxs = MakeStream(blocksToCollect.map(idx => blockContext(idx)), TStream(ctxRef.typ)) val bodyFreeVars = FreeVariables(body, supportsAgg = false, supportsScan = false) - val bcFields = broadcastVals.filter { ref => bodyFreeVars.eval.lookupOption(ref.name).isDefined } - val bcVals = MakeStruct(bcFields.map { ref => ref.name -> ref }) + val bcFields = broadcastVals.filter { ref => + bodyFreeVars.eval.lookupOption(ref.name).isDefined + } + val bcVals = MakeStruct(bcFields.map(ref => ref.name -> ref)) val bcRef = Ref(genUID(), bcVals.typ) val wrappedBody = Let(bcFields.map(ref => ref.name -> GetField(bcRef, ref.name)), body) CollectDistributedArray(ctxs, bcVals, ctxRef.name, 
bcRef.name, wrappedBody, dynamicID, staticID) @@ -52,18 +66,27 @@ abstract class BlockMatrixStage(val broadcastVals: IndexedSeq[Ref], val ctxType: val rows = if (typ.isSparse) { val blockMap = blocksRowMajor.zipWithIndex.toMap - MakeArray(Array.tabulate[IR](typ.nRowBlocks) { i => - NDArrayConcat(MakeArray(Array.tabulate[IR](typ.nColBlocks) { j => - if (blockMap.contains(i -> j)) - ArrayRef(blockResults, i * typ.nColBlocks + j) - else { - val (nRows, nCols) = typ.blockShape(i, j) - MakeNDArray.fill(zero(typ.elementType), FastSeq(nRows, nCols), True()) - } - }, tcoerce[TArray](cda.typ)), 1) - }, tcoerce[TArray](cda.typ)) + MakeArray( + Array.tabulate[IR](typ.nRowBlocks) { i => + NDArrayConcat( + MakeArray( + Array.tabulate[IR](typ.nColBlocks) { j => + if (blockMap.contains(i -> j)) + ArrayRef(blockResults, i * typ.nColBlocks + j) + else { + val (nRows, nCols) = typ.blockShape(i, j) + MakeNDArray.fill(zero(typ.elementType), FastSeq(nRows, nCols), True()) + } + }, + tcoerce[TArray](cda.typ), + ), + 1, + ) + }, + tcoerce[TArray](cda.typ), + ) } else { - ToArray(mapIR(rangeIR(I32(typ.nRowBlocks))){ rowIdxRef => + ToArray(mapIR(rangeIR(I32(typ.nRowBlocks))) { rowIdxRef => val blocksInOneRow = ToArray(mapIR(rangeIR(I32(typ.nColBlocks))) { colIdxRef => ArrayRef(blockResults, rowIdxRef * typ.nColBlocks + colIdxRef) }) @@ -84,6 +107,7 @@ abstract class BlockMatrixStage(val broadcastVals: IndexedSeq[Ref], val ctxType: def blockBody(ctxRef: Ref): IR = bindIR(GetField(ctxRef, "old"))(outer.blockBody) } } + def mapBody(f: (IR, IR) => IR): BlockMatrixStage = { val outer = this new BlockMatrixStage(broadcastVals, outer.ctxType) { @@ -93,7 +117,11 @@ abstract class BlockMatrixStage(val broadcastVals: IndexedSeq[Ref], val ctxType: } } - def condenseBlocks(typ: BlockMatrixType, rowBlocks: Array[Array[Int]], colBlocks: Array[Array[Int]]): BlockMatrixStage = { + def condenseBlocks( + typ: BlockMatrixType, + rowBlocks: Array[Array[Int]], + colBlocks: Array[Array[Int]], + ): BlockMatrixStage = { val outer = this val ctxType = TArray(TArray(TTuple(TTuple(TInt64, TInt64), outer.ctxType))) new BlockMatrixStage(outer.broadcastVals, ctxType) { @@ -107,28 +135,42 @@ abstract class BlockMatrixStage(val broadcastVals: IndexedSeq[Ref], val ctxType: MakeTuple.ordered(FastSeq(NA(TTuple(TInt64, TInt64)), outer.blockContext(idx2))) else { val (nRows, nCols) = typ.blockShape(ii, jj) - MakeTuple.ordered(FastSeq(MakeTuple.ordered(FastSeq(nRows, nCols)), NA(outer.ctxType))) + MakeTuple.ordered(FastSeq( + MakeTuple.ordered(FastSeq(nRows, nCols)), + NA(outer.ctxType), + )) } }: _*) }: _*) } def blockBody(ctxRef: Ref): IR = { - NDArrayConcat(ToArray(mapIR(ToStream(ctxRef)) { ctxRows => - NDArrayConcat(ToArray(mapIR(ToStream(ctxRows)) { shapeOrCtx => - bindIR(GetTupleElement(shapeOrCtx, 1)) { ctx => - If(IsNA(ctx), - bindIR(GetTupleElement(shapeOrCtx, 0)) { shape => - MakeNDArray( - ToArray(mapIR( - rangeIR((GetTupleElement(shape, 0) * GetTupleElement(shape, 1)).toI) - )(_ => zero(typ.elementType))), - shape, False(), ErrorIDs.NO_ERROR) - }, - outer.blockBody(ctx)) - } - }), 1) - }), 0) + NDArrayConcat( + ToArray(mapIR(ToStream(ctxRef)) { ctxRows => + NDArrayConcat( + ToArray(mapIR(ToStream(ctxRows)) { shapeOrCtx => + bindIR(GetTupleElement(shapeOrCtx, 1)) { ctx => + If( + IsNA(ctx), + bindIR(GetTupleElement(shapeOrCtx, 0)) { shape => + MakeNDArray( + ToArray(mapIR( + rangeIR((GetTupleElement(shape, 0) * GetTupleElement(shape, 1)).toI) + )(_ => zero(typ.elementType))), + shape, + False(), + ErrorIDs.NO_ERROR, + ) + }, + 
outer.blockBody(ctx), + ) + } + }), + 1, + ) + }), + 0, + ) } } } @@ -148,10 +190,12 @@ object BlockMatrixStage2 { IndexedSeq(), BlockMatrixType.dense(eltType, 0, 0, 0), DenseContexts(0, 0, ib.memoize(MakeArray(FastSeq(), TArray(ctxType)))), - _ => NA(ctxType)) + _ => NA(ctxType), + ) } - def broadcastVector(vector: IR, ib: IRBuilder, typ: BlockMatrixType, asRowVector: Boolean): BlockMatrixStage2 = { + def broadcastVector(vector: IR, ib: IRBuilder, typ: BlockMatrixType, asRowVector: Boolean) + : BlockMatrixStage2 = { val v: Ref = ib.strictMemoize(vector) val contexts = BMSContexts.tabulate(typ, ib) { (i, j) => val (m, n) = typ.blockShapeIR(i, j) @@ -166,26 +210,31 @@ object BlockMatrixStage2 { bindIRs(GetField(ctx, "shape"), GetField(ctx, "start")) { case Seq(shape, start) => bindIRs( if (asRowVector) GetTupleElement(shape, 1) else GetTupleElement(shape, 0), - if (asRowVector) GetTupleElement(shape, 0) else GetTupleElement(shape, 1) + if (asRowVector) GetTupleElement(shape, 0) else GetTupleElement(shape, 1), ) { case Seq(len, nRep) => bindIR( NDArrayReshape( NDArraySlice(v, maketuple(maketuple(start, start + len, 1L))), if (asRowVector) maketuple(1L, len) else maketuple(len.toL, 1L), - ErrorIDs.NO_ERROR) + ErrorIDs.NO_ERROR, + ) ) { sliced => - NDArrayConcat(ToArray(mapIR(rangeIR(nRep.toI))(_ => sliced)), if (asRowVector) 0 else 1) + NDArrayConcat( + ToArray(mapIR(rangeIR(nRep.toI))(_ => sliced)), + if (asRowVector) 0 else 1, + ) } } } - }) - } + }, + ) + } def apply( broadcastVals: IndexedSeq[Ref], typ: BlockMatrixType, contexts: BMSContexts, - _blockIR: Ref => IR + _blockIR: Ref => IR, ): BlockMatrixStage2 = { val ctxRef = Ref(genUID(), contexts.elementType) val blockIR = _blockIR(ctxRef) @@ -207,11 +256,13 @@ object BMSContexts { } def tabulate(typ: BlockMatrixType, ib: IRBuilder)(f: (IR, IR) => IR): BMSContexts = { - val contexts = ib.memoize(ToArray(mapIR(typ.sparsity.allBlocksColMajorIR(typ.nRowBlocks, typ.nColBlocks)) { coords => - bindIRs(GetTupleElement(coords, 0), GetTupleElement(coords, 1)) { case Seq(i, j) => - f(i, j) - } - })) + val contexts = + ib.memoize(ToArray(mapIR(typ.sparsity.allBlocksColMajorIR(typ.nRowBlocks, typ.nColBlocks)) { + coords => + bindIRs(GetTupleElement(coords, 0), GetTupleElement(coords, 1)) { case Seq(i, j) => + f(i, j) + } + })) typ.sparsity.definedBlocksCSCIR(typ.nColBlocks) match { case Some((pos, idx)) => SparseContexts(typ.nRowBlocks, typ.nColBlocks, ib.memoize(pos), ib.memoize(idx), contexts) @@ -220,23 +271,34 @@ object BMSContexts { } } - def transpose(contexts: BMSContexts, ib: IRBuilder, typ: BlockMatrixType): BMSContexts = contexts match { - case dense: DenseContexts => DenseContexts( - dense.nCols, dense.nRows, - ib.memoize(ToArray(flatMapIR(rangeIR(dense.nRows)) { i => - mapIR(rangeIR(dense.nCols)) { j => - ArrayRef(dense.contexts, (j * dense.nRows) + i) - } - }))) - case sparse: SparseContexts => - val (staticRowPos, staticRowIdx) = typ.sparsity.definedBlocksCSC(typ.nColBlocks).get - val (newRowPos, newRowIdx, newToOldPos) = - BlockMatrixSparsity.transposeCSCSparsityIR(typ.nRowBlocks, typ.nColBlocks, staticRowPos, staticRowIdx) - val newContexts = ToArray(mapIR(ToStream(newToOldPos)) { oldPos => - ArrayRef(contexts.contexts, oldPos) - }) - SparseContexts(sparse.nCols, sparse.nRows, ib.memoize(newRowPos), ib.memoize(newRowIdx), ib.memoize(newContexts)) - } + def transpose(contexts: BMSContexts, ib: IRBuilder, typ: BlockMatrixType): BMSContexts = + contexts match { + case dense: DenseContexts => DenseContexts( + dense.nCols, + dense.nRows, + 
ib.memoize(ToArray(flatMapIR(rangeIR(dense.nRows)) { i => + mapIR(rangeIR(dense.nCols))(j => ArrayRef(dense.contexts, (j * dense.nRows) + i)) + })), + ) + case sparse: SparseContexts => + val (staticRowPos, staticRowIdx) = typ.sparsity.definedBlocksCSC(typ.nColBlocks).get + val (newRowPos, newRowIdx, newToOldPos) = + BlockMatrixSparsity.transposeCSCSparsityIR( + typ.nRowBlocks, + typ.nColBlocks, + staticRowPos, + staticRowIdx, + ) + val newContexts = + ToArray(mapIR(ToStream(newToOldPos))(oldPos => ArrayRef(contexts.contexts, oldPos))) + SparseContexts( + sparse.nCols, + sparse.nRows, + ib.memoize(newRowPos), + ib.memoize(newRowIdx), + ib.memoize(newContexts), + ) + } } abstract class BMSContexts { @@ -257,7 +319,12 @@ abstract class BMSContexts { def zip(other: BMSContexts, ib: IRBuilder): BMSContexts - def grouped(rowDeps: IndexedSeq[IndexedSeq[Int]], colDeps: IndexedSeq[IndexedSeq[Int]], typ: BlockMatrixType, ib: IRBuilder): BMSContexts + def grouped( + rowDeps: IndexedSeq[IndexedSeq[Int]], + colDeps: IndexedSeq[IndexedSeq[Int]], + typ: BlockMatrixType, + ib: IRBuilder, + ): BMSContexts def collect(makeBlock: (Ref, Ref, Ref) => IR): IR @@ -275,51 +342,60 @@ object DenseContexts { } } -case class DenseContexts(nRows: TrivialIR, nCols: TrivialIR, contexts: TrivialIR) extends BMSContexts { +case class DenseContexts(nRows: TrivialIR, nCols: TrivialIR, contexts: TrivialIR) + extends BMSContexts { val elementType = contexts.typ.asInstanceOf[TArray].elementType def irValue: IR = makestruct("nRows" -> nRows, "nCols" -> nCols, "contexts" -> contexts) - def print(ctx: ExecuteContext): Unit = { - println(s"DenseContexts:\n nRows = ${Pretty(ctx, nRows)}\n nCols = ${Pretty(ctx, nCols)}\n contexts = ${Pretty(ctx, contexts)}") - } + def print(ctx: ExecuteContext): Unit = + println( + s"DenseContexts:\n nRows = ${Pretty(ctx, nRows)}\n nCols = ${Pretty(ctx, nCols)}\n contexts = ${Pretty(ctx, contexts)}" + ) def apply(row: IR, col: IR): IR = ArrayRef(contexts, (col * nRows) + row) def map(ib: IRBuilder)(f: (IR, IR, IR, IR) => IR): DenseContexts = { - DenseContexts(nRows, nCols, + DenseContexts( + nRows, + nCols, ib.memoize(ToArray(flatMapIR(rangeIR(nCols)) { j => mapIR(rangeIR(nRows)) { i => bindIR((j * nRows) + i) { pos => - bindIR(ArrayRef(contexts, pos)) { old => - f(i, j, pos, old) - } + bindIR(ArrayRef(contexts, pos))(old => f(i, j, pos, old)) } } - }))) + })), + ) } def zip(other: BMSContexts, ib: IRBuilder): BMSContexts = { val newContexts = ib.memoize(ToArray(zip2( - ToStream(this.contexts), ToStream(other.contexts), ArrayZipBehavior.AssertSameLength - ) { (l, r) => maketuple(l, r) })) + ToStream(this.contexts), + ToStream(other.contexts), + ArrayZipBehavior.AssertSameLength, + )((l, r) => maketuple(l, r)))) DenseContexts(nRows, nCols, newContexts) } def sparsify(rowPos: TrivialIR, rowIdx: TrivialIR, ib: IRBuilder): BMSContexts = { val newContexts = ib.memoize(ToArray(flatMapIR(rangeIR(nCols)) { j => - bindIRs(ArrayRef(rowPos, j), ArrayRef(rowPos, j + 1), j * nRows) { case Seq(start, end, basePos) => - mapIR(rangeIR(start, end)) { pos => - bindIR(ArrayRef(rowIdx, pos)) { i => - ArrayRef(contexts, basePos + i) + bindIRs(ArrayRef(rowPos, j), ArrayRef(rowPos, j + 1), j * nRows) { + case Seq(start, end, basePos) => + mapIR(rangeIR(start, end)) { pos => + bindIR(ArrayRef(rowIdx, pos))(i => ArrayRef(contexts, basePos + i)) } - } } })) SparseContexts(nRows, nCols, rowPos, rowIdx, newContexts) } - def grouped(rowDeps: IndexedSeq[IndexedSeq[Int]], colDeps: IndexedSeq[IndexedSeq[Int]], typ: 
BlockMatrixType, ib: IRBuilder): DenseContexts = { + def grouped( + rowDeps: IndexedSeq[IndexedSeq[Int]], + colDeps: IndexedSeq[IndexedSeq[Int]], + typ: BlockMatrixType, + ib: IRBuilder, + ): DenseContexts = { val rowDepsLit = Literal(TArray(TArray(TInt32)), rowDeps) val colDepsLit = Literal(TArray(TArray(TInt32)), colDeps) assert(rowDeps.nonEmpty || colDeps.nonEmpty) @@ -327,9 +403,7 @@ case class DenseContexts(nRows: TrivialIR, nCols: TrivialIR, contexts: TrivialIR val newContexts = ToArray(flatMapIR(ToStream(colDepsLit)) { localColDeps => mapIR(rangeIR(nRows)) { i => IRBuilder.scoped { ib => - val localContexts = ToArray(mapIR(ToStream(localColDeps)) { jl => - this (i, jl) - }) + val localContexts = ToArray(mapIR(ToStream(localColDeps))(jl => this(i, jl))) DenseContexts(1, ib.memoize(ArrayLen(localColDeps)), ib.memoize(localContexts)).irValue } } @@ -340,9 +414,7 @@ case class DenseContexts(nRows: TrivialIR, nCols: TrivialIR, contexts: TrivialIR val newContexts = ToArray(flatMapIR(rangeIR(nCols)) { j => mapIR(ToStream(rowDepsLit)) { localRowDeps => IRBuilder.scoped { ib => - val localContexts = ToArray(mapIR(ToStream(localRowDeps)) { il => - this (il, j) - }) + val localContexts = ToArray(mapIR(ToStream(localRowDeps))(il => this(il, j))) DenseContexts(ib.memoize(ArrayLen(localRowDeps)), 1, ib.memoize(localContexts)).irValue } } @@ -353,11 +425,13 @@ case class DenseContexts(nRows: TrivialIR, nCols: TrivialIR, contexts: TrivialIR mapIR(ToStream(rowDepsLit)) { localRowDeps => IRBuilder.scoped { ib => val localContexts = ToArray(flatMapIR(ToStream(localColDeps)) { jl => - mapIR(ToStream(localRowDeps)) { il => - this(il, jl) - } + mapIR(ToStream(localRowDeps))(il => this(il, jl)) }) - DenseContexts(ib.memoize(ArrayLen(localRowDeps)), ib.memoize(ArrayLen(localColDeps)), ib.memoize(localContexts)).irValue + DenseContexts( + ib.memoize(ArrayLen(localRowDeps)), + ib.memoize(ArrayLen(localColDeps)), + ib.memoize(localContexts), + ).irValue } } }) @@ -365,14 +439,15 @@ case class DenseContexts(nRows: TrivialIR, nCols: TrivialIR, contexts: TrivialIR } def collect(makeBlock: (Ref, Ref, Ref) => IR): IR = { - NDArrayConcat(ToArray(mapIR(rangeIR(nCols)) { j => - val colBlocks = mapIR(rangeIR(nRows)) { i => - bindIR(ArrayRef(contexts, j * nRows + i)) { ctx => - makeBlock(i, j, ctx) + NDArrayConcat( + ToArray(mapIR(rangeIR(nCols)) { j => + val colBlocks = mapIR(rangeIR(nRows)) { i => + bindIR(ArrayRef(contexts, j * nRows + i))(ctx => makeBlock(i, j, ctx)) } - } - NDArrayConcat(ToArray(colBlocks), 0) - }), 1) + NDArrayConcat(ToArray(colBlocks), 0) + }), + 1, + ) } } @@ -389,14 +464,21 @@ object SparseContexts { } } -case class SparseContexts(nRows: TrivialIR, nCols: TrivialIR, rowPos: TrivialIR, rowIdx: TrivialIR, contexts: TrivialIR) extends BMSContexts { +case class SparseContexts( + nRows: TrivialIR, + nCols: TrivialIR, + rowPos: TrivialIR, + rowIdx: TrivialIR, + contexts: TrivialIR, +) extends BMSContexts { val elementType = contexts.typ.asInstanceOf[TArray].elementType def irValue: IR = makestruct("nRows" -> nRows, "nCols" -> nCols, "contexts" -> contexts) - def print(ctx: ExecuteContext): Unit = { - println(s"SparseContexts:\n nRows = ${ Pretty(ctx, nRows) }\n nCols = ${ Pretty(ctx, nCols) }\n contexts = ${ Pretty(ctx, contexts) }") - } + def print(ctx: ExecuteContext): Unit = + println( + s"SparseContexts:\n nRows = ${Pretty(ctx, nRows)}\n nCols = ${Pretty(ctx, nCols)}\n contexts = ${Pretty(ctx, contexts)}" + ) def apply(row: IR, col: IR): IR = { val startPos = ArrayRef(rowPos, col) @@ -404,17 
+486,26 @@ case class SparseContexts(nRows: TrivialIR, nCols: TrivialIR, rowPos: TrivialIR, bindIR( Apply("lowerBound", Seq(), FastSeq(rowIdx, row, startPos, endPos), TInt32, ErrorIDs.NO_ERROR) ) { pos => - If(ArrayRef(rowIdx, pos).ceq(row), + If( + ArrayRef(rowIdx, pos).ceq(row), ArrayRef(contexts, pos), - Die(strConcat("Internal Error, tried to load missing BlockMatrix context: (row = ", row, ", col = ", - col, ", pos = ", pos, ", rowPos = ", rowPos, ", rowIdx = ", rowIdx, ")"), + Die( + strConcat("Internal Error, tried to load missing BlockMatrix context: (row = ", row, + ", col = ", + col, ", pos = ", pos, ", rowPos = ", rowPos, ", rowIdx = ", rowIdx, ")"), elementType, - ErrorIDs.NO_ERROR)) + ErrorIDs.NO_ERROR, + ), + ) } } def map(ib: IRBuilder)(f: (IR, IR, IR, IR) => IR): SparseContexts = { - SparseContexts(nRows, nCols, rowPos, rowIdx, + SparseContexts( + nRows, + nCols, + rowPos, + rowIdx, ib.memoize(ToArray(flatMapIR(rangeIR(nCols)) { j => bindIRs(ArrayRef(rowPos, j), ArrayRef(rowPos, j + 1)) { case Seq(start, end) => mapIR(rangeIR(start, end)) { pos => @@ -423,31 +514,41 @@ case class SparseContexts(nRows: TrivialIR, nCols: TrivialIR, rowPos: TrivialIR, } } } - }))) + })), + ) } def mapDense(ib: IRBuilder)(f: (IR, IR, IR) => IR): DenseContexts = { val newContexts = ib.memoize(flatMapIR(rangeIR(nCols)) { j => bindIRs(ArrayRef(rowPos, j), ArrayRef(rowPos, j + 1)) { case Seq(start, end) => - val allIdxs = mapIR(rangeIR(nRows)) { i => makestruct("idx" -> i) } + val allIdxs = mapIR(rangeIR(nRows))(i => makestruct("idx" -> i)) val idxedExisting = mapIR(rangeIR(start, end)) { pos => makestruct("idx" -> ArrayRef(rowIdx, pos), "context" -> ArrayRef(contexts, pos)) } - joinRightDistinctIR(allIdxs, idxedExisting, FastSeq("idx"), FastSeq("idx"), "left") { (idx, struct) => - val i = GetField(idx, "idx") - val context = GetField(struct, "context") - f(i, j, context) + joinRightDistinctIR(allIdxs, idxedExisting, FastSeq("idx"), FastSeq("idx"), "left") { + (idx, struct) => + val i = GetField(idx, "idx") + val context = GetField(struct, "context") + f(i, j, context) } } }) DenseContexts(nRows, nCols, newContexts) } - def mapWithNewSparsity(newRowPos: TrivialIR, newRowIdx: TrivialIR, ib: IRBuilder)(f: (IR, IR, IR) => IR): SparseContexts = { + def mapWithNewSparsity( + newRowPos: TrivialIR, + newRowIdx: TrivialIR, + ib: IRBuilder, + )( + f: (IR, IR, IR) => IR + ): SparseContexts = { val newContexts = ib.memoize(ToArray(flatMapIR(rangeIR(nCols)) { j => bindIRs( - ArrayRef(rowPos, j), ArrayRef(rowPos, j + 1), - ArrayRef(newRowPos, j), ArrayRef(newRowPos, j + 1) + ArrayRef(rowPos, j), + ArrayRef(rowPos, j + 1), + ArrayRef(newRowPos, j), + ArrayRef(newRowPos, j + 1), ) { case Seq(oldStart, oldEnd, newStart, newEnd) => val newIdxs = mapIR(rangeIR(newStart, newEnd)) { pos => makestruct("idx" -> ArrayRef(newRowIdx, pos)) @@ -455,10 +556,11 @@ case class SparseContexts(nRows: TrivialIR, nCols: TrivialIR, rowPos: TrivialIR, val idxedExisting = mapIR(rangeIR(oldStart, oldEnd)) { pos => makestruct("idx" -> ArrayRef(rowIdx, pos), "context" -> ArrayRef(contexts, pos)) } - joinRightDistinctIR(newIdxs, idxedExisting, FastSeq("idx"), FastSeq("idx"), "left") { (idx, struct) => - val i = GetField(idx, "idx") - val context = GetField(struct, "context") - f(i, j, context) + joinRightDistinctIR(newIdxs, idxedExisting, FastSeq("idx"), FastSeq("idx"), "left") { + (idx, struct) => + val i = GetField(idx, "idx") + val context = GetField(struct, "context") + f(i, j, context) } } })) @@ -467,12 +569,19 @@ case class 
SparseContexts(nRows: TrivialIR, nCols: TrivialIR, rowPos: TrivialIR, def zip(other: BMSContexts, ib: IRBuilder): BMSContexts = { val newContexts = ib.memoize(zip2( - this.contexts, other.contexts, ArrayZipBehavior.AssertSameLength - ) { (l, r) => maketuple(l, r) }) + this.contexts, + other.contexts, + ArrayZipBehavior.AssertSameLength, + )((l, r) => maketuple(l, r))) SparseContexts(nRows, nCols, rowPos, rowIdx, newContexts) } - def grouped(rowDeps: IndexedSeq[IndexedSeq[Int]], colDeps: IndexedSeq[IndexedSeq[Int]], typ: BlockMatrixType, ib: IRBuilder): SparseContexts = { + def grouped( + rowDeps: IndexedSeq[IndexedSeq[Int]], + colDeps: IndexedSeq[IndexedSeq[Int]], + typ: BlockMatrixType, + ib: IRBuilder, + ): SparseContexts = { val newNRows = rowDeps.length val newNCols = colDeps.length val rowBlockSizes = Literal(TArray(TInt32), rowDeps.map(_.length)) @@ -499,24 +608,35 @@ case class SparseContexts(nRows: TrivialIR, nCols: TrivialIR, rowPos: TrivialIR, } }) - SparseContexts(newNRows, newNCols, ib.memoize(newRowPos), ib.memoize(newRowIdx), ib.memoize(newContexts)) + SparseContexts( + newNRows, + newNCols, + ib.memoize(newRowPos), + ib.memoize(newRowIdx), + ib.memoize(newContexts), + ) } def collect(makeBlock: (Ref, Ref, Ref) => IR): IR = { - NDArrayConcat(ToArray(mapIR(rangeIR(nCols)) { j => - val allIdxs = mapIR(rangeIR(nRows)) { i => makestruct("idx" -> i) } - val startPos = ArrayRef(rowPos, j) - val endPos = ArrayRef(rowPos, j + 1) - val idxedExisting = mapIR(rangeIR(startPos, endPos)) { pos => - makestruct("idx" -> ArrayRef(rowIdx, pos), "ctx" -> ArrayRef(contexts, pos)) - } - val colBlocks = joinRightDistinctIR(allIdxs, idxedExisting, FastSeq("idx"), FastSeq("idx"), "left") { (l, struct) => - bindIRs(GetField(l, "idx"), GetField(struct, "ctx")) { case Seq(i, ctx) => - makeBlock(i, j, ctx) + NDArrayConcat( + ToArray(mapIR(rangeIR(nCols)) { j => + val allIdxs = mapIR(rangeIR(nRows))(i => makestruct("idx" -> i)) + val startPos = ArrayRef(rowPos, j) + val endPos = ArrayRef(rowPos, j + 1) + val idxedExisting = mapIR(rangeIR(startPos, endPos)) { pos => + makestruct("idx" -> ArrayRef(rowIdx, pos), "ctx" -> ArrayRef(contexts, pos)) } - } - NDArrayConcat(ToArray(colBlocks), 0) - }), 1) + val colBlocks = + joinRightDistinctIR(allIdxs, idxedExisting, FastSeq("idx"), FastSeq("idx"), "left") { + (l, struct) => + bindIRs(GetField(l, "idx"), GetField(struct, "ctx")) { case Seq(i, ctx) => + makeBlock(i, j, ctx) + } + } + NDArrayConcat(ToArray(colBlocks), 0) + }), + 1, + ) } } @@ -525,7 +645,7 @@ class BlockMatrixStage2 private ( val typ: BlockMatrixType, val contexts: BMSContexts, private val ctxRefName: String, - private val _blockIR: IR + private val _blockIR: IR, ) { assert { def literalOrRef(x: IR) = x.isInstanceOf[Literal] || x.isInstanceOf[Ref] @@ -535,16 +655,14 @@ class BlockMatrixStage2 private ( } } - def print(ctx: ExecuteContext): Unit = { - println(s"contexts:\n${contexts.print(ctx)}\nbody(${ctxRefName}) = ${Pretty(ctx, _blockIR)}") - } + def print(ctx: ExecuteContext): Unit = + println(s"contexts:\n${contexts.print(ctx)}\nbody($ctxRefName) = ${Pretty(ctx, _blockIR)}") - def blockIR(ctx: Ref): IR = { + def blockIR(ctx: Ref): IR = if (ctx.name == ctxRefName) _blockIR else Let(FastSeq(ctxRefName -> ctx), _blockIR) - } def ctxType: Type = contexts.elementType @@ -578,7 +696,13 @@ class BlockMatrixStage2 private ( def transposed(ib: IRBuilder): BlockMatrixStage2 = { val newBlockIR = NDArrayReindex(_blockIR, FastSeq(1, 0)) - new BlockMatrixStage2(broadcastVals, typ.transpose, 
BMSContexts.transpose(contexts, ib, typ), ctxRefName, newBlockIR) + new BlockMatrixStage2( + broadcastVals, + typ.transpose, + BMSContexts.transpose(contexts, ib, typ), + ctxRefName, + newBlockIR, + ) } def densify(ib: IRBuilder): BlockMatrixStage2 = contexts match { @@ -590,18 +714,27 @@ class BlockMatrixStage2 private ( } def newBlock(context: Ref): IR = { bindIR(GetField(context, "oldContext")) { oldContext => - If(IsNA(oldContext), + If( + IsNA(oldContext), MakeNDArray.fill( zero(typ.elementType), FastSeq(GetField(oldContext, "nRows"), GetField(oldContext, "nCols")), - False()), - blockIR(oldContext)) + False(), + ), + blockIR(oldContext), + ) } } BlockMatrixStage2(broadcastVals, typ, newContexts, newBlock) } - def withSparsity(rowPos: TrivialIR, rowIdx: TrivialIR, ib: IRBuilder, newType: BlockMatrixType, isSubset: Boolean = false): BlockMatrixStage2 = { + def withSparsity( + rowPos: TrivialIR, + rowIdx: TrivialIR, + ib: IRBuilder, + newType: BlockMatrixType, + isSubset: Boolean = false, + ): BlockMatrixStage2 = { if (newType.sparsity.definedBlocksColMajor == typ.sparsity.definedBlocksColMajor) return this contexts match { @@ -619,12 +752,15 @@ class BlockMatrixStage2 private ( def newBlock(context: Ref): IR = { bindIR(GetField(context, "oldContext")) { oldContext => - If(IsNA(oldContext), + If( + IsNA(oldContext), MakeNDArray.fill( zero(typ.elementType), FastSeq(GetField(context, "nRows"), GetField(context, "nCols")), - False()), - blockIR(oldContext)) + False(), + ), + blockIR(oldContext), + ) } } @@ -646,8 +782,9 @@ class BlockMatrixStage2 private ( def mapBody2( other: BlockMatrixStage2, ib: IRBuilder, - sparsityStrategy: SparsityStrategy - )(f: (IR, IR) => IR + sparsityStrategy: SparsityStrategy, + )( + f: (IR, IR) => IR ): BlockMatrixStage2 = { val (alignedLeft, alignedRight) = (contexts, other.contexts, sparsityStrategy) match { case (_: DenseContexts, _: DenseContexts, _) => @@ -657,27 +794,36 @@ class BlockMatrixStage2 private ( case (_: SparseContexts, _: DenseContexts, UnionBlocks) => (this.densify(ib), other) case (_: SparseContexts, _: SparseContexts, UnionBlocks) => - val newType = typ.copy(sparsity = UnionBlocks.mergeSparsity(typ.sparsity, other.typ.sparsity)) + val newType = + typ.copy(sparsity = UnionBlocks.mergeSparsity(typ.sparsity, other.typ.sparsity)) val (unionPos, unionIdx) = newType.sparsity.definedBlocksCSCIR(newType.nColBlocks).get val unionPosRef = ib.memoize(unionPos) val unionIdxRef = ib.memoize(unionIdx) - (this.withSparsity(unionPosRef, unionIdxRef, ib, newType), other.withSparsity(unionPosRef, unionIdxRef, ib, newType)) + ( + this.withSparsity(unionPosRef, unionIdxRef, ib, newType), + other.withSparsity(unionPosRef, unionIdxRef, ib, newType), + ) case (_: DenseContexts, sparse: SparseContexts, IntersectionBlocks) => (this.withSparsity(sparse.rowPos, sparse.rowIdx, ib, other.typ), other) case (sparse: SparseContexts, _: DenseContexts, IntersectionBlocks) => (this, other.withSparsity(sparse.rowPos, sparse.rowIdx, ib, this.typ)) case (_: SparseContexts, _: SparseContexts, IntersectionBlocks) => - val newType = typ.copy(sparsity = IntersectionBlocks.mergeSparsity(typ.sparsity, other.typ.sparsity)) + val newType = + typ.copy(sparsity = IntersectionBlocks.mergeSparsity(typ.sparsity, other.typ.sparsity)) val (unionPos, unionIdx) = newType.sparsity.definedBlocksCSCIR(newType.nColBlocks).get val unionPosRef = ib.memoize(unionPos) val unionIdxRef = ib.memoize(unionIdx) - (this.withSparsity(unionPosRef, unionIdxRef, ib, newType), other.withSparsity(unionPosRef, 
unionIdxRef, ib, newType)) + ( + this.withSparsity(unionPosRef, unionIdxRef, ib, newType), + other.withSparsity(unionPosRef, unionIdxRef, ib, newType), + ) } alignedLeft.mapBody2Aligned(alignedRight, ib)(f) } - private def mapBody2Aligned(other: BlockMatrixStage2, ib: IRBuilder)(f: (IR, IR) => IR): BlockMatrixStage2 = { + private def mapBody2Aligned(other: BlockMatrixStage2, ib: IRBuilder)(f: (IR, IR) => IR) + : BlockMatrixStage2 = { val newContexts = contexts.zip(other.contexts, ib) val ctxRef = Ref(genUID(), newContexts.elementType) val newBlockIR = @@ -687,12 +833,23 @@ class BlockMatrixStage2 private ( val newType = typ.copy(elementType = newBlockIR.typ.asInstanceOf[TNDArray].elementType) new BlockMatrixStage2( broadcastVals ++ other.broadcastVals, - newType, newContexts, ctxRef.name, newBlockIR) + newType, + newContexts, + ctxRef.name, + newBlockIR, + ) } - def filter(keepRows: IndexedSeq[Long], keepCols: IndexedSeq[Long], typ: BlockMatrixType, ib: IRBuilder): BlockMatrixStage2 = { - val rowBlockDependents = keepRows.grouped(typ.blockSize).map(_.map(i => (i / typ.blockSize).toInt).distinct).toFastSeq - val colBlockDependents = keepCols.grouped(typ.blockSize).map(_.map(i => (i / typ.blockSize).toInt).distinct).toFastSeq + def filter( + keepRows: IndexedSeq[Long], + keepCols: IndexedSeq[Long], + typ: BlockMatrixType, + ib: IRBuilder, + ): BlockMatrixStage2 = { + val rowBlockDependents = + keepRows.grouped(typ.blockSize).map(_.map(i => (i / typ.blockSize).toInt).distinct).toFastSeq + val colBlockDependents = + keepCols.grouped(typ.blockSize).map(_.map(i => (i / typ.blockSize).toInt).distinct).toFastSeq def localIndices(idxs: IndexedSeq[Long]): IndexedSeq[IndexedSeq[Long]] = { val result = new AnyRefArrayBuilder[IndexedSeq[Long]]() @@ -718,7 +875,8 @@ class BlockMatrixStage2 private ( val t = TArray(TArray(TArray(TInt64))) val groupedKeepRowsLit = if (keepRows.isEmpty) NA(t) else Literal(t, groupedKeepRows) val groupedKeepColsLit = if (keepCols.isEmpty) NA(t) else Literal(t, groupedKeepCols) - val groupedContexts: BMSContexts = contexts.grouped(rowBlockDependents, colBlockDependents, typ, ib) + val groupedContexts: BMSContexts = + contexts.grouped(rowBlockDependents, colBlockDependents, typ, ib) val groupedContextsWithIndices = groupedContexts.map(ib) { (i, j, pos, context) => maketuple(context, ArrayRef(groupedKeepRowsLit, i), ArrayRef(groupedKeepColsLit, j)) } @@ -735,7 +893,12 @@ class BlockMatrixStage2 private ( bindIRs(ArrayRef(localKeepRows, i), ArrayRef(localKeepCols, j)) { case Seq(rows, cols) => Coalesce(FastSeq( NDArrayFilter(blockIR(localContext), FastSeq(rows, cols)), - MakeNDArray.fill(zero(typ.elementType), FastSeq(ArrayLen(rows).toL, ArrayLen(cols).toL), False()))) + MakeNDArray.fill( + zero(typ.elementType), + FastSeq(ArrayLen(rows).toL, ArrayLen(cols).toL), + False(), + ), + )) } } } @@ -745,9 +908,7 @@ class BlockMatrixStage2 private ( } def zeroBand(lower: Long, upper: Long, typ: BlockMatrixType, ib: IRBuilder): BlockMatrixStage2 = { - val ctxs = contexts.map(ib) { (i, j, _, context) => - maketuple(context, i, j) - } + val ctxs = contexts.map(ib)((i, j, _, context) => maketuple(context, i, j)) def newBody(ctx: Ref): IR = IRBuilder.scoped { ib => val oldCtx = GetTupleElement(ctx, 0) @@ -759,9 +920,10 @@ class BlockMatrixStage2 private ( val localUpper = I64(upper) - diagIndex val (nRowsInBlock, nColsInBlock) = typ.blockShapeIR(i, j) val block = blockIR(oldCtx) - If(-localLower >= (nRowsInBlock - 1L) && localUpper >= (nColsInBlock - 1L), + If( + -localLower >= 
(nRowsInBlock - 1L) && localUpper >= (nColsInBlock - 1L), block, - invoke("zero_band", TNDArray(TFloat64, Nat(2)), block, localLower, localUpper) + invoke("zero_band", TNDArray(TFloat64, Nat(2)), block, localLower, localUpper), ) } } @@ -769,7 +931,12 @@ class BlockMatrixStage2 private ( BlockMatrixStage2(broadcastVals, typ, ctxs, newBody) } - def zeroRowIntervals(starts: IndexedSeq[Long], stops: IndexedSeq[Long], typ: BlockMatrixType, ib: IRBuilder): BlockMatrixStage2 = { + def zeroRowIntervals( + starts: IndexedSeq[Long], + stops: IndexedSeq[Long], + typ: BlockMatrixType, + ib: IRBuilder, + ): BlockMatrixStage2 = { val t = TArray(TArray(TInt64)) val startsGrouped = Literal(t, starts.grouped(typ.blockSize).toIndexedSeq) val stopsGrouped = Literal(t, stops.grouped(typ.blockSize).toIndexedSeq) @@ -783,8 +950,12 @@ class BlockMatrixStage2 private ( val i = GetTupleElement(ctx, 1) val j = GetTupleElement(ctx, 2) val (_, nCols) = typ.blockShapeIR(i, j) - val starts = ToArray(mapIR(ToStream(GetTupleElement(ctx, 3))) { s => minIR(maxIR(s - j.toL * typ.blockSize.toLong, 0L), nCols) }) - val stops = ToArray(mapIR(ToStream(GetTupleElement(ctx, 4))) { s => minIR(maxIR(s - j.toL * typ.blockSize.toLong, 0L), nCols) }) + val starts = ToArray(mapIR(ToStream(GetTupleElement(ctx, 3))) { s => + minIR(maxIR(s - j.toL * typ.blockSize.toLong, 0L), nCols) + }) + val stops = ToArray(mapIR(ToStream(GetTupleElement(ctx, 4))) { s => + minIR(maxIR(s - j.toL * typ.blockSize.toLong, 0L), nCols) + }) bindIRs(oldCtx) { case Seq(oldCtx) => invoke("zero_row_intervals", TNDArray(TFloat64, Nat(2)), blockIR(oldCtx), starts, stops) } @@ -795,7 +966,9 @@ class BlockMatrixStage2 private ( def toTableStage(ib: IRBuilder, ctx: ExecuteContext, bmTyp: BlockMatrixType): TableStage = { val bodyFreeVars = FreeVariables(_blockIR, supportsAgg = false, supportsScan = false) - val bcFields = broadcastVals.filter { case Ref(f, _) => bodyFreeVars.eval.lookupOption(f).isDefined } + val bcFields = broadcastVals.filter { case Ref(f, _) => + bodyFreeVars.eval.lookupOption(f).isDefined + } val contextsIR = ToStream(contexts.map(ib) { (rowIdx, colIdx, pos, oldContext) => maketuple(rowIdx, colIdx, oldContext) @@ -809,7 +982,8 @@ class BlockMatrixStage2 private ( val s = makestruct( "blockRow" -> GetTupleElement(newCtxRef, 0), "blockCol" -> GetTupleElement(newCtxRef, 1), - "block" -> Let(FastSeq(ctxRefName -> GetTupleElement(newCtxRef, 2)), _blockIR)) + "block" -> Let(FastSeq(ctxRefName -> GetTupleElement(newCtxRef, 2)), _blockIR), + ) MakeStream(FastSeq(s), TStream(s.typ)) } @@ -820,14 +994,16 @@ class BlockMatrixStage2 private ( RVDPartitioner.unkeyed(ctx.stateManager, bmTyp.nDefinedBlocks), TableStageDependency.none, contextsIR, - tsPartitionFunction) + tsPartitionFunction, + ) } def collectBlocks( ib: IRBuilder, staticID: String, - dynamicID: IR = NA(TString) - )(f: (IR, IR, IR) => IR // (ctx, pos, block) + dynamicID: IR = NA(TString), + )( + f: (IR, IR, IR) => IR // (ctx, pos, block) ): IR = { val posRef = Ref(genUID(), TInt32) val newCtxRef = Ref(genUID(), TTuple(TInt32, ctxType)) @@ -836,12 +1012,14 @@ class BlockMatrixStage2 private ( posRef.name -> GetTupleElement(newCtxRef, 0), ctxRefName -> GetTupleElement(newCtxRef, 1), ), - f(ctxRef, posRef, _blockIR) + f(ctxRef, posRef, _blockIR), ) val bodyFreeVars = FreeVariables(body, supportsAgg = false, supportsScan = false) - val bcFields = broadcastVals.filter { case Ref(f, _) => bodyFreeVars.eval.lookupOption(f).isDefined } - val bcVals = MakeStruct(bcFields.map { ref => ref.name -> ref }) + 
val bcFields = broadcastVals.filter { case Ref(f, _) => + bodyFreeVars.eval.lookupOption(f).isDefined + } + val bcVals = MakeStruct(bcFields.map(ref => ref.name -> ref)) val bcRef = Ref(genUID(), bcVals.typ) val wrappedBody = Let(bcFields.map(ref => ref.name -> GetField(bcRef, ref.name)), body) @@ -849,7 +1027,15 @@ class BlockMatrixStage2 private ( maketuple(pos, oldContext) }.contexts) - CollectDistributedArray(cdaContexts, bcVals, newCtxRef.name, bcRef.name, wrappedBody, dynamicID, staticID) + CollectDistributedArray( + cdaContexts, + bcVals, + newCtxRef.name, + bcRef.name, + wrappedBody, + dynamicID, + staticID, + ) } def collectLocal(ib: IRBuilder, staticID: String, dynamicID: IR = NA(TString)): IR = { @@ -868,7 +1054,12 @@ class BlockMatrixStage2 private ( } object LowerBlockMatrixIR { - def apply(node: IR, typesToLower: DArrayLowering.Type, ctx: ExecuteContext, analyses: LoweringAnalyses): IR = { + def apply( + node: IR, + typesToLower: DArrayLowering.Type, + ctx: ExecuteContext, + analyses: LoweringAnalyses, + ): IR = { def lower(bmir: BlockMatrixIR, ib: IRBuilder) = LowerBlockMatrixIR.lower(bmir, ib, typesToLower, ctx, analyses) @@ -880,21 +1071,32 @@ object LowerBlockMatrixIR { case BlockMatrixToValueApply(child, GetElement(IndexedSeq(i, j))) => lower(child, ib).getElement(i, j) case BlockMatrixWrite(child, writer) => - writer.lower(ctx, lower(child, ib), ib, TypeWithRequiredness(child.typ.elementType)) //FIXME: BlockMatrixIR is currently ignored in Requiredness inference since all eltTypes are +TFloat64 - case BlockMatrixMultiWrite(blockMatrices, writer) => unimplemented(ctx, node) + writer.lower( + ctx, + lower(child, ib), + ib, + TypeWithRequiredness(child.typ.elementType), + ) // FIXME: BlockMatrixIR is currently ignored in Requiredness inference since all eltTypes are +TFloat64 + case BlockMatrixMultiWrite(_, _) => unimplemented(ctx, node) case node if node.children.exists(_.isInstanceOf[BlockMatrixIR]) => - throw new LowererUnsupportedOperation(s"IR nodes with BlockMatrixIR children need explicit rules: \n${ Pretty(ctx, node) }") + throw new LowererUnsupportedOperation( + s"IR nodes with BlockMatrixIR children need explicit rules: \n${Pretty(ctx, node)}" + ) case node => - throw new LowererUnsupportedOperation(s"Value IRs with no BlockMatrixIR children must be lowered through LowerIR: \n${ Pretty(ctx, node) }") + throw new LowererUnsupportedOperation( + s"Value IRs with no BlockMatrixIR children must be lowered through LowerIR: \n${Pretty(ctx, node)}" + ) } } } // This lowers a BlockMatrixIR to an unkeyed TableStage with rows of (blockRow, blockCol, block) def lowerToTableStage( - bmir: BlockMatrixIR, typesToLower: DArrayLowering.Type, ctx: ExecuteContext, - analyses: LoweringAnalyses + bmir: BlockMatrixIR, + typesToLower: DArrayLowering.Type, + ctx: ExecuteContext, + analyses: LoweringAnalyses, ): TableStage = { val ib = new IRBuilder() val bms = lower(bmir, ib, typesToLower, ctx, analyses) @@ -902,14 +1104,24 @@ object LowerBlockMatrixIR { } private def unimplemented[T](ctx: ExecuteContext, node: BaseIR): T = - throw new LowererUnsupportedOperation(s"unimplemented: \n${ Pretty(ctx, node) }") + throw new LowererUnsupportedOperation(s"unimplemented: \n${Pretty(ctx, node)}") - def lower(bmir: BlockMatrixIR, ib: IRBuilder, typesToLower: DArrayLowering.Type, ctx: ExecuteContext, analyses: LoweringAnalyses): BlockMatrixStage2 = { + def lower( + bmir: BlockMatrixIR, + ib: IRBuilder, + typesToLower: DArrayLowering.Type, + ctx: ExecuteContext, + analyses: LoweringAnalyses, + ): 
BlockMatrixStage2 = { if (!DArrayLowering.lowerBM(typesToLower)) - throw new LowererUnsupportedOperation("found BlockMatrixIR in lowering; lowering only TableIRs.") + throw new LowererUnsupportedOperation( + "found BlockMatrixIR in lowering; lowering only TableIRs." + ) bmir.children.foreach { case c: BlockMatrixIR if c.typ.blockSize != bmir.typ.blockSize => - throw new LowererUnsupportedOperation(s"Can't lower node with mismatched block sizes: ${ bmir.typ.blockSize } vs child ${ c.typ.blockSize }\n\n ${ Pretty(ctx, bmir) }") + throw new LowererUnsupportedOperation( + s"Can't lower node with mismatched block sizes: ${bmir.typ.blockSize} vs child ${c.typ.blockSize}\n\n ${Pretty(ctx, bmir)}" + ) case _ => } if (bmir.typ.matrixShape == 0L -> 0L) @@ -922,7 +1134,7 @@ object LowerBlockMatrixIR { ib: IRBuilder, typesToLower: DArrayLowering.Type, ctx: ExecuteContext, - analyses: LoweringAnalyses + analyses: LoweringAnalyses, ): BlockMatrixStage2 = { def lower(ir: BlockMatrixIR, ib: IRBuilder = ib): BlockMatrixStage2 = @@ -931,7 +1143,7 @@ object LowerBlockMatrixIR { bmir match { case BlockMatrixRead(reader) => reader.lower(ctx, ib) - case x@BlockMatrixRandom(staticUID, gaussian, shape, blockSize) => + case x @ BlockMatrixRandom(staticUID, gaussian, _, _) => val contexts = BMSContexts.tabulate(x.typ, ib) { (rowIdx, colIdx) => val (m, n) = x.typ.blockShapeIR(rowIdx, colIdx) MakeTuple.ordered(FastSeq(m, n, rowIdx * x.typ.nColBlocks + colIdx)) @@ -949,9 +1161,7 @@ object LowerBlockMatrixIR { BlockMatrixStage2(FastSeq(), x.typ, contexts, bodyIR) case BlockMatrixMap(child, eltName, f, _) => - lower(child).mapBody { body => - NDArrayMap(body, eltName, f) - } + lower(child).mapBody(body => NDArrayMap(body, eltName, f)) case BlockMatrixMap2(left, right, lname, rname, f, sparsityStrategy) => val loweredLeft = lower(left) @@ -960,7 +1170,7 @@ object LowerBlockMatrixIR { NDArrayMap2(lBody, rBody, lname, rname, f, ErrorIDs.NO_ERROR) } - case x@BlockMatrixBroadcast(child, IndexedSeq(), _, _) => + case x @ BlockMatrixBroadcast(child, IndexedSeq(), _, _) => val elt = ib.strictMemoize(IRBuilder.scoped { ib => val lowered = lower(child, ib) lowered.getElement(0L, 0L) @@ -976,21 +1186,30 @@ object LowerBlockMatrixIR { x.typ, contexts, (ctxRef: Ref) => - MakeNDArray.fill(elt, FastSeq(GetTupleElement(ctxRef, 0), GetTupleElement(ctxRef, 1)), True())) + MakeNDArray.fill( + elt, + FastSeq(GetTupleElement(ctxRef, 0), GetTupleElement(ctxRef, 1)), + True(), + ), + ) - case x@BlockMatrixBroadcast(child, IndexedSeq(axis), _, _) => + case x @ BlockMatrixBroadcast(child, IndexedSeq(axis), _, _) => val len = child.typ.shape.max val vector = NDArrayReshape( IRBuilder.scoped { ib => lower(child, ib).collectLocal(ib, "block_matrix_broadcast_single_axis") }, MakeTuple.ordered(FastSeq(I64(len))), - ErrorIDs.NO_ERROR) + ErrorIDs.NO_ERROR, + ) BlockMatrixStage2.broadcastVector(vector, ib, x.typ, asRowVector = axis == 1) - case x@BlockMatrixBroadcast(child, IndexedSeq(axis, axis2), _, _) if (axis == axis2) => // diagonal as row/col vector + case x @ BlockMatrixBroadcast(child, IndexedSeq(axis, axis2), _, _) + if (axis == axis2) => // diagonal as row/col vector val diagLen = math.min(child.typ.nRowBlocks, child.typ.nColBlocks) - val diagType = x.typ.copy(sparsity = BlockMatrixSparsity(Some(IndexedSeq.tabulate(diagLen)(i => (i, i))))) + val diagType = x.typ.copy(sparsity = + BlockMatrixSparsity(Some(IndexedSeq.tabulate(diagLen)(i => (i, i)))) + ) val rowPos = if (child.typ.nColBlocks > diagLen) ToArray(mapIR(rangeIR(child.typ.nColBlocks 
+ 1))(i => minIR(i, diagLen))) else @@ -1009,24 +1228,35 @@ object LowerBlockMatrixIR { } } - val diagVector = MakeNDArray(ToArray(flatten(diagArray)), maketuple(math.min(child.typ.nRows, child.typ.nCols)), true, ErrorIDs.NO_ERROR) + val diagVector = MakeNDArray( + ToArray(flatten(diagArray)), + maketuple(math.min(child.typ.nRows, child.typ.nCols)), + true, + ErrorIDs.NO_ERROR, + ) BlockMatrixStage2.broadcastVector(diagVector, ib, x.typ, asRowVector = axis == 0) - case x@BlockMatrixBroadcast(child, IndexedSeq(1, 0), _, _) => //transpose + case BlockMatrixBroadcast(child, IndexedSeq(1, 0), _, _) => // transpose lower(child).transposed(ib) case BlockMatrixBroadcast(child, IndexedSeq(0, 1), _, _) => lower(child) - case x@BlockMatrixFilter(child, keep) => + case x @ BlockMatrixFilter(child, keep) => val Array(keepRow, keepCol) = keep lower(child).filter(keepRow, keepCol, x.typ, ib) case BlockMatrixDensify(child) => lower(child).densify(ib) - case x@BlockMatrixSparsify(child, sparsifier) => + case x @ BlockMatrixSparsify(child, sparsifier) => val Some((rowPos, rowIdx)) = x.typ.sparsity.definedBlocksCSCIR(x.typ.nColBlocks) - val loweredChild = lower(child).withSparsity(ib.memoize(rowPos), ib.memoize(rowIdx), ib, x.typ, isSubset = true) + val loweredChild = lower(child).withSparsity( + ib.memoize(rowPos), + ib.memoize(rowIdx), + ib, + x.typ, + isSubset = true, + ) sparsifier match { // these cases are all handled at the type level case BandSparsifier(blocksOnly, _, _) if (blocksOnly) => loweredChild @@ -1034,14 +1264,25 @@ object LowerBlockMatrixIR { case PerBlockSparsifier(_) | RectangleSparsifier(_) => loweredChild case BandSparsifier(_, l, u) => loweredChild.zeroBand(l, u, x.typ, ib) - case RowIntervalSparsifier(_, starts, stops) => loweredChild.zeroRowIntervals(starts, stops, x.typ, ib) + case RowIntervalSparsifier(_, starts, stops) => + loweredChild.zeroRowIntervals(starts, stops, x.typ, ib) } case _ => - BlockMatrixStage2.fromOldBMS(lowerNonEmpty(bmir, ib, typesToLower, ctx, analyses), bmir.typ, ib) + BlockMatrixStage2.fromOldBMS( + lowerNonEmpty(bmir, ib, typesToLower, ctx, analyses), + bmir.typ, + ib, + ) } } - def lowerNonEmpty(bmir: BlockMatrixIR, ib: IRBuilder, typesToLower: DArrayLowering.Type, ctx: ExecuteContext, analyses: LoweringAnalyses): BlockMatrixStage = { + def lowerNonEmpty( + bmir: BlockMatrixIR, + ib: IRBuilder, + typesToLower: DArrayLowering.Type, + ctx: ExecuteContext, + analyses: LoweringAnalyses, + ): BlockMatrixStage = { def lower(ir: BlockMatrixIR, ib: IRBuilder = ib) = LowerBlockMatrixIR.lower(ir, ib, typesToLower, ctx, analyses).toOldBMS @@ -1049,150 +1290,259 @@ object LowerBlockMatrixIR { bmir match { - case a@BlockMatrixAgg(child, axesToSumOut) => + case BlockMatrixAgg(child, axesToSumOut) => val loweredChild = lower(child) axesToSumOut match { - case IndexedSeq(0, 1) => + case IndexedSeq(0, 1) => val summedChild = loweredChild.mapBody { (ctx, body) => - NDArrayReshape(NDArrayAgg(body, IndexedSeq(0, 1)), MakeTuple.ordered(FastSeq(I64(1), I64(1))), ErrorIDs.NO_ERROR) + NDArrayReshape( + NDArrayAgg(body, IndexedSeq(0, 1)), + MakeTuple.ordered(FastSeq(I64(1), I64(1))), + ErrorIDs.NO_ERROR, + ) } - val summedChildType = BlockMatrixType(child.typ.elementType, IndexedSeq[Long](child.typ.nRowBlocks, child.typ.nColBlocks), child.typ.nRowBlocks == 1, 1, BlockMatrixSparsity.dense) - val res = NDArrayAgg(summedChild.collectLocal(summedChildType, "block_matrix_agg"), IndexedSeq[Int](0, 1)) + val summedChildType = BlockMatrixType( + child.typ.elementType, + 
IndexedSeq[Long](child.typ.nRowBlocks, child.typ.nColBlocks), + child.typ.nRowBlocks == 1, + 1, + BlockMatrixSparsity.dense, + ) + val res = NDArrayAgg( + summedChild.collectLocal(summedChildType, "block_matrix_agg"), + IndexedSeq[Int](0, 1), + ) new BlockMatrixStage(summedChild.broadcastVals, TStruct.empty) { override def blockContext(idx: (Int, Int)): IR = makestruct() - override def blockBody(ctxRef: Ref): IR = NDArrayReshape(res, MakeTuple.ordered(FastSeq(I64(1L), I64(1L))), ErrorIDs.NO_ERROR) + override def blockBody(ctxRef: Ref): IR = + NDArrayReshape(res, MakeTuple.ordered(FastSeq(I64(1L), I64(1L))), ErrorIDs.NO_ERROR) } - case IndexedSeq(0) => { // Number of rows goes to 1. Number of cols remains the same. + case IndexedSeq(0) => // Number of rows goes to 1. Number of cols remains the same. new BlockMatrixStage(loweredChild.broadcastVals, TArray(loweredChild.ctxType)) { override def blockContext(idx: (Int, Int)): IR = { val (row, col) = idx - assert(row == 0, s"Asked for idx ${idx}") + assert(row == 0, s"Asked for idx $idx") MakeArray( - (0 until child.typ.nRowBlocks).map(childRow => loweredChild.blockContext((childRow, col))), - TArray(loweredChild.ctxType) + (0 until child.typ.nRowBlocks).map(childRow => + loweredChild.blockContext((childRow, col)) + ), + TArray(loweredChild.ctxType), ) } override def blockBody(ctxRef: Ref): IR = { - val summedChildBlocks = mapIR(ToStream(ctxRef))(singleChildCtx => { - bindIR(NDArrayAgg(loweredChild.blockBody(singleChildCtx), axesToSumOut))(aggedND => NDArrayReshape(aggedND, MakeTuple.ordered(FastSeq(I64(1), GetTupleElement(NDArrayShape(aggedND), 0))), ErrorIDs.NO_ERROR)) - }) + val summedChildBlocks = mapIR(ToStream(ctxRef)) { singleChildCtx => + bindIR(NDArrayAgg(loweredChild.blockBody(singleChildCtx), axesToSumOut))( + aggedND => + NDArrayReshape( + aggedND, + MakeTuple.ordered(FastSeq( + I64(1), + GetTupleElement(NDArrayShape(aggedND), 0), + )), + ErrorIDs.NO_ERROR, + ) + ) + } val aggVar = genUID() - StreamAgg(summedChildBlocks, aggVar, ApplyAggOp(NDArraySum())(Ref(aggVar, summedChildBlocks.typ.asInstanceOf[TStream].elementType))) + StreamAgg( + summedChildBlocks, + aggVar, + ApplyAggOp(NDArraySum())(Ref( + aggVar, + summedChildBlocks.typ.asInstanceOf[TStream].elementType, + )), + ) } } - } - case IndexedSeq(1) => { // Number of cols goes to 1. Number of rows remains the same. + case IndexedSeq(1) => // Number of cols goes to 1. Number of rows remains the same. 
new BlockMatrixStage(loweredChild.broadcastVals, TArray(loweredChild.ctxType)) { override def blockContext(idx: (Int, Int)): IR = { val (row, col) = idx - assert(col == 0, s"Asked for idx ${idx}") + assert(col == 0, s"Asked for idx $idx") MakeArray( - (0 until child.typ.nColBlocks).map(childCol => loweredChild.blockContext((row, childCol))), - TArray(loweredChild.ctxType) + (0 until child.typ.nColBlocks).map(childCol => + loweredChild.blockContext((row, childCol)) + ), + TArray(loweredChild.ctxType), ) } override def blockBody(ctxRef: Ref): IR = { - val summedChildBlocks = mapIR(ToStream(ctxRef))(singleChildCtx => { + val summedChildBlocks = mapIR(ToStream(ctxRef)) { singleChildCtx => bindIR(NDArrayAgg(loweredChild.blockBody(singleChildCtx), axesToSumOut)) { - aggedND => NDArrayReshape(aggedND, MakeTuple(FastSeq(0 -> GetTupleElement(NDArrayShape(aggedND), 0), 1 -> I64(1))), ErrorIDs.NO_ERROR) + aggedND => + NDArrayReshape( + aggedND, + MakeTuple(FastSeq( + 0 -> GetTupleElement(NDArrayShape(aggedND), 0), + 1 -> I64(1), + )), + ErrorIDs.NO_ERROR, + ) } - }) + } val aggVar = genUID() - StreamAgg(summedChildBlocks, aggVar, ApplyAggOp(NDArraySum())(Ref(aggVar, summedChildBlocks.typ.asInstanceOf[TStream].elementType))) + StreamAgg( + summedChildBlocks, + aggVar, + ApplyAggOp(NDArraySum())(Ref( + aggVar, + summedChildBlocks.typ.asInstanceOf[TStream].elementType, + )), + ) } } - } } - case x@BlockMatrixSlice(child, IndexedSeq(IndexedSeq(rStart, rEnd, rStep), IndexedSeq(cStart, cEnd, cStep))) => + case x @ BlockMatrixSlice( + child, + IndexedSeq(IndexedSeq(rStart, rEnd, rStep), IndexedSeq(cStart, cEnd, cStep)), + ) => val rowDependents = x.rowBlockDependents val colDependents = x.colBlockDependents lower(child).condenseBlocks(child.typ, rowDependents, colDependents) - .addContext(TTuple(TTuple(TInt64, TInt64, TInt64), TTuple(TInt64, TInt64, TInt64))) { idx => - val (i, j) = idx - - // Aligned with the edges of blocks in child BM. - val blockAlignedRowStartIdx = rowDependents(i).head.toLong * x.typ.blockSize - val blockAlignedColStartIdx = colDependents(j).head.toLong * x.typ.blockSize - val blockAlignedRowEndIdx = math.min(child.typ.nRows, (rowDependents(i).last + 1L) * x.typ.blockSize * rStep) - val blockAlignedColEndIdx = math.min(child.typ.nCols, (colDependents(j).last + 1L) * x.typ.blockSize * cStep) - - // condenseBlocks can give the same data to multiple partitions. Need to make sure we don't use data - // that's already included in an earlier block. 
- val rStartPlusSeenAlready = rStart + i * x.typ.blockSize * rStep - val cStartPlusSeenAlready = cStart + j * x.typ.blockSize * cStep - - val rowTrueStart = rStartPlusSeenAlready - blockAlignedRowStartIdx - val rowTrueEnd = math.min(math.min(rEnd, blockAlignedRowEndIdx) - blockAlignedRowStartIdx, rowTrueStart + x.typ.blockSize * rStep) - val rows = MakeTuple.ordered(FastSeq[IR]( - rowTrueStart, - rowTrueEnd, - rStep)) - - val colTrueStart = cStartPlusSeenAlready - blockAlignedColStartIdx - val colTrueEnd = math.min(java.lang.Math.min(cEnd, blockAlignedColEndIdx) - blockAlignedColStartIdx, colTrueStart + x.typ.blockSize * cStep) - val cols = MakeTuple.ordered(FastSeq[IR]( - colTrueStart, - colTrueEnd, - cStep)) - MakeTuple.ordered(FastSeq(rows, cols)) - }.mapBody { (ctx, body) => NDArraySlice(body, GetField(ctx, "new")) } - - case RelationalLetBlockMatrix(name, value, body) => unimplemented(ctx, bmir) - - case ValueToBlockMatrix(child, shape, blockSize) if !child.typ.isInstanceOf[TArray] && !child.typ.isInstanceOf[TNDArray] => { + .addContext(TTuple(TTuple(TInt64, TInt64, TInt64), TTuple(TInt64, TInt64, TInt64))) { + idx => + val (i, j) = idx + + // Aligned with the edges of blocks in child BM. + val blockAlignedRowStartIdx = rowDependents(i).head.toLong * x.typ.blockSize + val blockAlignedColStartIdx = colDependents(j).head.toLong * x.typ.blockSize + val blockAlignedRowEndIdx = + math.min(child.typ.nRows, (rowDependents(i).last + 1L) * x.typ.blockSize * rStep) + val blockAlignedColEndIdx = + math.min(child.typ.nCols, (colDependents(j).last + 1L) * x.typ.blockSize * cStep) + + /* condenseBlocks can give the same data to multiple partitions. Need to make sure we + * don't use data */ + // that's already included in an earlier block. + val rStartPlusSeenAlready = rStart + i * x.typ.blockSize * rStep + val cStartPlusSeenAlready = cStart + j * x.typ.blockSize * cStep + + val rowTrueStart = rStartPlusSeenAlready - blockAlignedRowStartIdx + val rowTrueEnd = math.min( + math.min(rEnd, blockAlignedRowEndIdx) - blockAlignedRowStartIdx, + rowTrueStart + x.typ.blockSize * rStep, + ) + val rows = MakeTuple.ordered(FastSeq[IR]( + rowTrueStart, + rowTrueEnd, + rStep, + )) + + val colTrueStart = cStartPlusSeenAlready - blockAlignedColStartIdx + val colTrueEnd = math.min( + java.lang.Math.min(cEnd, blockAlignedColEndIdx) - blockAlignedColStartIdx, + colTrueStart + x.typ.blockSize * cStep, + ) + val cols = MakeTuple.ordered(FastSeq[IR]( + colTrueStart, + colTrueEnd, + cStep, + )) + MakeTuple.ordered(FastSeq(rows, cols)) + }.mapBody((ctx, body) => NDArraySlice(body, GetField(ctx, "new"))) + + case RelationalLetBlockMatrix(_, _, _) => unimplemented(ctx, bmir) + + case ValueToBlockMatrix(child, _, _) + if !child.typ.isInstanceOf[TArray] && !child.typ.isInstanceOf[TNDArray] => val element = lowerIR(child) new BlockMatrixStage(FastSeq(), TStruct()) { override def blockContext(idx: (Int, Int)): IR = MakeStruct(FastSeq()) - override def blockBody(ctxRef: Ref): IR = MakeNDArray(MakeArray(element), MakeTuple(FastSeq((0, I64(1)), (1, I64(1)))), False(), ErrorIDs.NO_ERROR) + override def blockBody(ctxRef: Ref): IR = MakeNDArray( + MakeArray(element), + MakeTuple(FastSeq((0, I64(1)), (1, I64(1)))), + False(), + ErrorIDs.NO_ERROR, + ) } - } - case x@ValueToBlockMatrix(child, _, blockSize) => + case x @ ValueToBlockMatrix(child, _, blockSize) => val nd = ib.memoize(child.typ match { - case _: TArray => MakeNDArray(lowerIR(child), MakeTuple.ordered(FastSeq(I64(x.typ.nRows), I64(x.typ.nCols))), True(), ErrorIDs.NO_ERROR) + 
case _: TArray => MakeNDArray( + lowerIR(child), + MakeTuple.ordered(FastSeq(I64(x.typ.nRows), I64(x.typ.nCols))), + True(), + ErrorIDs.NO_ERROR, + ) case _: TNDArray => lowerIR(child) }) new BlockMatrixStage(FastSeq(), nd.typ) { def blockContext(idx: (Int, Int)): IR = { val (r, c) = idx - NDArraySlice(nd, MakeTuple.ordered(FastSeq( - MakeTuple.ordered(FastSeq(I64(r.toLong * blockSize), I64(java.lang.Math.min((r.toLong + 1) * blockSize, x.typ.nRows)), I64(1))), - MakeTuple.ordered(FastSeq(I64(c.toLong * blockSize), I64(java.lang.Math.min((c.toLong + 1) * blockSize, x.typ.nCols)), I64(1)))))) + NDArraySlice( + nd, + MakeTuple.ordered(FastSeq( + MakeTuple.ordered(FastSeq( + I64(r.toLong * blockSize), + I64(java.lang.Math.min((r.toLong + 1) * blockSize, x.typ.nRows)), + I64(1), + )), + MakeTuple.ordered(FastSeq( + I64(c.toLong * blockSize), + I64(java.lang.Math.min((c.toLong + 1) * blockSize, x.typ.nCols)), + I64(1), + )), + )), + ) } def blockBody(ctxRef: Ref): IR = ctxRef } - case x@BlockMatrixDot(leftIR, rightIR) => + case BlockMatrixDot(leftIR, rightIR) => val left = lower(leftIR) val right = lower(rightIR) val newCtxType = TArray(TTuple(left.ctxType, right.ctxType)) new BlockMatrixStage(left.broadcastVals ++ right.broadcastVals, newCtxType) { def blockContext(idx: (Int, Int)): IR = { val (i, j) = idx - MakeArray(Array.tabulate[Option[IR]](leftIR.typ.nColBlocks) { k => - if (leftIR.typ.hasBlock(i -> k) && rightIR.typ.hasBlock(k -> j)) - Some(MakeTuple.ordered(FastSeq( - left.blockContext(i -> k), right.blockContext(k -> j)))) - else None - }.flatten[IR], newCtxType) + MakeArray( + Array.tabulate[Option[IR]](leftIR.typ.nColBlocks) { k => + if (leftIR.typ.hasBlock(i -> k) && rightIR.typ.hasBlock(k -> j)) + Some(MakeTuple.ordered(FastSeq( + left.blockContext(i -> k), + right.blockContext(k -> j), + ))) + else None + }.flatten[IR], + newCtxType, + ) } def blockBody(ctxRef: Ref): IR = { val tupleNDArrayStream = ToStream(ctxRef) val streamElementName = genUID() - val streamElementRef = Ref(streamElementName, tupleNDArrayStream.typ.asInstanceOf[TStream].elementType) + val streamElementRef = + Ref(streamElementName, tupleNDArrayStream.typ.asInstanceOf[TStream].elementType) val leftName = genUID() val rightName = genUID() - val leftRef = Ref(leftName, tupleNDArrayStream.typ.asInstanceOf[TStream].elementType.asInstanceOf[TTuple].types(0)) - val rightRef = Ref(rightName, tupleNDArrayStream.typ.asInstanceOf[TStream].elementType.asInstanceOf[TTuple].types(1)) - StreamAgg(tupleNDArrayStream, streamElementName, { - AggLet(leftName, GetTupleElement(streamElementRef, 0), - AggLet(rightName, GetTupleElement(streamElementRef, 1), - ApplyAggOp(NDArrayMultiplyAdd())(left.blockBody(leftRef), - right.blockBody(rightRef)), isScan=false), isScan=false) - }) + val leftRef = Ref( + leftName, + tupleNDArrayStream.typ.asInstanceOf[TStream].elementType.asInstanceOf[TTuple].types(0), + ) + val rightRef = Ref( + rightName, + tupleNDArrayStream.typ.asInstanceOf[TStream].elementType.asInstanceOf[TTuple].types(1), + ) + StreamAgg( + tupleNDArrayStream, + streamElementName, + AggLet( + leftName, + GetTupleElement(streamElementRef, 0), + AggLet( + rightName, + GetTupleElement(streamElementRef, 1), + ApplyAggOp(NDArrayMultiplyAdd())( + left.blockBody(leftRef), + right.blockBody(rightRef), + ), + isScan = false, + ), + isScan = false, + ), + ) } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala index 
728aee0c64e..8330fa7c6dc 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerDistributedSort.scala @@ -1,46 +1,59 @@ package is.hail.expr.ir.lowering import is.hail.annotations.{Annotation, ExtendedOrdering, Region, SafeRow} -import is.hail.asm4s.{AsmFunction1RegionLong, LongInfo, classInfo} +import is.hail.asm4s.{classInfo, AsmFunction1RegionLong, LongInfo} import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.expr.ir._ import is.hail.expr.ir.functions.{ArrayFunctions, IRRandomness, UtilFunctions} import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.rvd.RVDPartitioner -import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType +import is.hail.types.{tcoerce, RTable, TableType, VirtualTypeWithReq} import is.hail.types.physical.{PArray, PStruct} +import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual._ -import is.hail.types.{RTable, TableType, VirtualTypeWithReq, tcoerce} import is.hail.utils._ + +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.Row import org.json4s.JValue import org.json4s.JsonAST.JString -import scala.collection.mutable.ArrayBuffer - object LowerDistributedSort { - def localSort(ctx: ExecuteContext, inputStage: TableStage, sortFields: IndexedSeq[SortField], rt: RTable): TableReader = { + def localSort( + ctx: ExecuteContext, + inputStage: TableStage, + sortFields: IndexedSeq[SortField], + rt: RTable, + ): TableReader = { val numPartitions = inputStage.partitioner.numPartitions - val collected = inputStage.collectWithGlobals( "shuffle_local_sort") - - val (Some(PTypeReferenceSingleCodeType(resultPType: PStruct)), f) = ctx.timer.time("LowerDistributedSort.localSort.compile")(Compile[AsmFunction1RegionLong](ctx, - FastSeq(), - FastSeq(classInfo[Region]), LongInfo, - collected, - print = None, - optimize = true)) - - val rowsAndGlobal = ctx.scopedExecution{ (hcl, fs, htc, r) => - val fRunnable = ctx.timer.time("LowerDistributedSort.localSort.initialize")(f(hcl, fs, htc, r)) + val collected = inputStage.collectWithGlobals("shuffle_local_sort") + + val (Some(PTypeReferenceSingleCodeType(resultPType: PStruct)), f) = + ctx.timer.time("LowerDistributedSort.localSort.compile")(Compile[AsmFunction1RegionLong]( + ctx, + FastSeq(), + FastSeq(classInfo[Region]), + LongInfo, + collected, + print = None, + optimize = true, + )) + + val rowsAndGlobal = ctx.scopedExecution { (hcl, fs, htc, r) => + val fRunnable = + ctx.timer.time("LowerDistributedSort.localSort.initialize")(f(hcl, fs, htc, r)) val resultAddress = ctx.timer.time("LowerDistributedSort.localSort.run")(fRunnable(ctx.r)) - ctx.timer.time("LowerDistributedSort.localSort.toJavaObject")(SafeRow.read(resultPType, resultAddress)).asInstanceOf[Row] + ctx.timer.time("LowerDistributedSort.localSort.toJavaObject")(SafeRow.read( + resultPType, + resultAddress, + )).asInstanceOf[Row] } val rowsType = resultPType.fieldType("rows").asInstanceOf[PArray] val rowType = rowsType.elementType.asInstanceOf[PStruct] val rows = rowsAndGlobal.getAs[IndexedSeq[Annotation]](0) - val kType = TStruct(sortFields.map(f => (f.field, rowType.virtualType.fieldType(f.field))): _*) val sortedRows = localAnnotationSort(ctx, rows, sortFields, rowType.virtualType) @@ -48,23 +61,36 @@ object LowerDistributedSort { val itemsPerPartition = math.max((sortedRows.length.toDouble / nPartitionsAdj).ceil.toInt, 1) // partitioner needs keys to be ascending - val partitionerKeyType = 
TStruct(sortFields.takeWhile(_.sortOrder == Ascending).map(f => (f.field, rowType.virtualType.fieldType(f.field))): _*) + val partitionerKeyType = TStruct(sortFields.takeWhile(_.sortOrder == Ascending).map(f => + (f.field, rowType.virtualType.fieldType(f.field)) + ): _*) val partitionerKeyIndex = partitionerKeyType.fieldNames.map(f => rowType.fieldIdx(f)) - val partitioner = new RVDPartitioner(ctx.stateManager, partitionerKeyType, + val partitioner = new RVDPartitioner( + ctx.stateManager, + partitionerKeyType, sortedRows.grouped(itemsPerPartition).map { group => val first = group.head.asInstanceOf[Row].select(partitionerKeyIndex) val last = group.last.asInstanceOf[Row].select(partitionerKeyIndex) Interval(first, last, includesStart = true, includesEnd = true) - }.toIndexedSeq) + }.toIndexedSeq, + ) val globalsIR = Literal(resultPType.fieldType("global").virtualType, rowsAndGlobal.get(1)) LocalSortReader(sortedRows, rowType.virtualType, globalsIR, partitioner, itemsPerPartition, rt) } - case class LocalSortReader(sortedRows: IndexedSeq[Annotation], rowType: TStruct, globals: IR, partitioner: RVDPartitioner, itemsPerPartition: Int, rt: RTable) extends TableReader { - lazy val fullType: TableType = TableType(rowType, partitioner.kType.fieldNames, globals.typ.asInstanceOf[TStruct]) + case class LocalSortReader( + sortedRows: IndexedSeq[Annotation], + rowType: TStruct, + globals: IR, + partitioner: RVDPartitioner, + itemsPerPartition: Int, + rt: RTable, + ) extends TableReader { + lazy val fullType: TableType = + TableType(rowType, partitioner.kType.fieldNames, globals.typ.asInstanceOf[TStruct]) override def pathsUsed: Seq[String] = FastSeq() @@ -72,13 +98,11 @@ object LowerDistributedSort { override def isDistinctlyKeyed: Boolean = false // FIXME: No default value - def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = { + def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = VirtualTypeWithReq.subset(requestedType.rowType, rt.rowType) - } - def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = { + def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = VirtualTypeWithReq.subset(requestedType.globalType, rt.globalType) - } override def toJValue: JValue = JString("LocalSortReader") @@ -97,9 +121,11 @@ object LowerDistributedSort { contexts = mapIR( StreamGrouped( ToStream(Literal(TArray(rowType), sortedRows)), - I32(itemsPerPartition)) + I32(itemsPerPartition), + ) )(ToArray(_)), - ctxRef => ToStream(ctxRef)) + ctxRef => ToStream(ctxRef), + ) .upcast(ctx, requestedType) } } @@ -108,7 +134,7 @@ object LowerDistributedSort { ctx: ExecuteContext, annotations: IndexedSeq[Annotation], sortFields: IndexedSeq[SortField], - rowType: TStruct + rowType: TStruct, ): IndexedSeq[Annotation] = { val sortColIndexOrd = sortFields.map { case SortField(n, so) => val i = rowType.fieldIdx(n) @@ -121,18 +147,17 @@ object LowerDistributedSort { val kType = TStruct(sortFields.map(f => (f.field, rowType.fieldType(f.field))): _*) val kIndex = kType.fieldNames.map(f => rowType.fieldIdx(f)) - ctx.timer.time("LowerDistributedSort.localSort.sort")(annotations.sortBy{ a: Annotation => + ctx.timer.time("LowerDistributedSort.localSort.sort")(annotations.sortBy { a: Annotation => a.asInstanceOf[Row].select(kIndex).asInstanceOf[Annotation] }(ord)) } - def distributedSort( ctx: ExecuteContext, inputStage: TableStage, sortFields: IndexedSeq[SortField], tableRequiredness: RTable, - 
optTargetNumPartitions: Option[Int] = None + optTargetNumPartitions: Option[Int] = None, ): TableReader = { val oversamplingNum = 3 @@ -149,140 +174,259 @@ object LowerDistributedSort { val (keyToSortBy, _) = inputStage.rowType.select(sortFields.map(sf => sf.field)) - val spec = TypedCodecSpec(rowTypeRequiredness.canonicalPType(inputStage.rowType), BufferSpec.wireSpec) + val spec = + TypedCodecSpec(rowTypeRequiredness.canonicalPType(inputStage.rowType), BufferSpec.wireSpec) val reader = PartitionNativeReader(spec, "__dummy_uid") val initialTmpPath = ctx.createTmpPath("hail_shuffle_temp_initial") - val writer = PartitionNativeWriter(spec, keyToSortBy.fieldNames, initialTmpPath, None, None, trackTotalBytes = true) + val writer = PartitionNativeWriter( + spec, + keyToSortBy.fieldNames, + initialTmpPath, + None, + None, + trackTotalBytes = true, + ) log.info("DISTRIBUTED SORT: PHASE 1: WRITE DATA") - val initialStageDataRow = CompileAndEvaluate[Annotation](ctx, inputStage.mapCollectWithGlobals("shuffle_initial_write") { part => - WritePartition(part, UUID4(), writer) - }{ case (part, globals) => - val streamElement = Ref(genUID(), part.typ.asInstanceOf[TArray].elementType) - bindIR(StreamAgg(ToStream(part), streamElement.name, - MakeStruct(FastSeq( - "min" -> AggFold.min(GetField(streamElement, "firstKey"), sortFields), - "max" -> AggFold.max(GetField(streamElement, "lastKey"), sortFields) - )) - )) { intervalRange => MakeTuple.ordered(FastSeq(part, globals, intervalRange)) } - }).asInstanceOf[Row] - val (initialPartInfo, initialGlobals, intervalRange) = (initialStageDataRow(0).asInstanceOf[IndexedSeq[Row]], initialStageDataRow(1).asInstanceOf[Row], initialStageDataRow(2).asInstanceOf[Row]) + val initialStageDataRow = CompileAndEvaluate[Annotation]( + ctx, + inputStage.mapCollectWithGlobals("shuffle_initial_write") { part => + WritePartition(part, UUID4(), writer) + } { case (part, globals) => + val streamElement = Ref(genUID(), part.typ.asInstanceOf[TArray].elementType) + bindIR(StreamAgg( + ToStream(part), + streamElement.name, + MakeStruct(FastSeq( + "min" -> AggFold.min(GetField(streamElement, "firstKey"), sortFields), + "max" -> AggFold.max(GetField(streamElement, "lastKey"), sortFields), + )), + ))(intervalRange => MakeTuple.ordered(FastSeq(part, globals, intervalRange))) + }, + ).asInstanceOf[Row] + val (initialPartInfo, initialGlobals, intervalRange) = ( + initialStageDataRow(0).asInstanceOf[IndexedSeq[Row]], + initialStageDataRow(1).asInstanceOf[Row], + initialStageDataRow(2).asInstanceOf[Row], + ) val initialGlobalsLiteral = Literal(inputStage.globalType, initialGlobals) - val initialChunks = initialPartInfo.map(row => Chunk(initialTmpPath + row(0).asInstanceOf[String], row(1).asInstanceOf[Long], row.getLong(5))) + val initialChunks = initialPartInfo.map(row => + Chunk(initialTmpPath + row(0).asInstanceOf[String], row(1).asInstanceOf[Long], row.getLong(5)) + ) val initialInterval = Interval(intervalRange(0), intervalRange(1), true, true) val initialSegment = SegmentResult(IndexedSeq(0), initialInterval, initialChunks) val totalNumberOfRows = initialChunks.map(_.size).sum - optTargetNumPartitions.foreach(i => assert(i >= 1, s"Must request positive number of partitions. Requested ${i}")) + optTargetNumPartitions.foreach(i => + assert(i >= 1, s"Must request positive number of partitions. 
Requested $i") + ) val targetNumPartitions = optTargetNumPartitions.getOrElse(inputStage.numPartitions) - val idealNumberOfRowsPerPart: Long = if (targetNumPartitions == 0) 1 else { + val idealNumberOfRowsPerPart: Long = if (targetNumPartitions == 0) 1 + else { Math.max(1L, totalNumberOfRows / targetNumPartitions) } - var loopState = LoopState(IndexedSeq(initialSegment), IndexedSeq.empty[SegmentResult], IndexedSeq.empty[OutputPartition]) + var loopState = LoopState( + IndexedSeq(initialSegment), + IndexedSeq.empty[SegmentResult], + IndexedSeq.empty[OutputPartition], + ) var i = 0 val rand = new IRRandomness(seed) - /* - Loop state keeps track of three things. largeSegments are too big to sort locally so have to broken up. - smallSegments are small enough to be sorted locally. readyOutputParts are any partitions that we noticed were - sorted already during course of the recursion. Loop continues until there are no largeSegments left. Then we - sort the small segments and combine them with readyOutputParts to get the final table. - */ + /* Loop state keeps track of three things. largeSegments are too big to sort locally so have to + * broken up. + * smallSegments are small enough to be sorted locally. readyOutputParts are any partitions that + * we noticed were sorted already during course of the recursion. Loop continues until there are + * no largeSegments left. Then we sort the small segments and combine them with readyOutputParts + * to get the final table. */ while (!loopState.largeSegments.isEmpty) { - val partitionDataPerSegment = segmentsToPartitionData(loopState.largeSegments, idealNumberOfRowsPerPart) + val partitionDataPerSegment = + segmentsToPartitionData(loopState.largeSegments, idealNumberOfRowsPerPart) assert(partitionDataPerSegment.size == loopState.largeSegments.size) val numSamplesPerPartitionPerSegment = partitionDataPerSegment.map { partData => val partitionCountsForOneSegment = partData.map(_.currentPartSize) val recordsInSegment = partitionCountsForOneSegment.sum val branchingFactor = math.min(recordsInSegment, defaultBranchingFactor) - howManySamplesPerPartition(rand, recordsInSegment, Math.min(recordsInSegment, (branchingFactor * oversamplingNum) - 1).toInt, partitionCountsForOneSegment) + howManySamplesPerPartition( + rand, + recordsInSegment, + Math.min(recordsInSegment, (branchingFactor * oversamplingNum) - 1).toInt, + partitionCountsForOneSegment, + ) } val numSamplesPerPartition = numSamplesPerPartitionPerSegment.flatten - val perPartStatsCDAContextData = partitionDataPerSegment.flatten.zip(numSamplesPerPartition).map { case (partData, numSamples) => - Row(partData.indices.last, partData.files, coerceToInt(partData.currentPartSize), numSamples, partData.currentPartByteSize) - } - val perPartStatsCDAContexts = ToStream(Literal(TArray(TStruct( - "segmentIdx" -> TInt32, - "files" -> TArray(TString), - "sizeOfPartition" -> TInt32, - "numSamples" -> TInt32, - "byteSize" -> TInt64)), perPartStatsCDAContextData)) - val perPartStatsIR = cdaIR(perPartStatsCDAContexts, MakeStruct(FastSeq()), s"shuffle_part_stats_iteration_$i"){ (ctxRef, _) => - val filenames = GetField(ctxRef, "files") - val samples = SeqSample(GetField(ctxRef, "sizeOfPartition"), GetField(ctxRef, "numSamples"), NA(TRNGState), false) - val partitionStream = flatMapIR(ToStream(filenames)) { fileName => - mapIR( - ReadPartition( - MakeStruct(Array("partitionIndex" -> I64(0), "partitionPath" -> fileName)), - tcoerce[TStruct](spec._vType), reader) - ) { partitionElement => - SelectFields(partitionElement, 
keyToSortBy.fields.map(_.name)) - } + val perPartStatsCDAContextData = + partitionDataPerSegment.flatten.zip(numSamplesPerPartition).map { + case (partData, numSamples) => + Row( + partData.indices.last, + partData.files, + coerceToInt(partData.currentPartSize), + numSamples, + partData.currentPartByteSize, + ) + } + val perPartStatsCDAContexts = ToStream(Literal( + TArray(TStruct( + "segmentIdx" -> TInt32, + "files" -> TArray(TString), + "sizeOfPartition" -> TInt32, + "numSamples" -> TInt32, + "byteSize" -> TInt64, + )), + perPartStatsCDAContextData, + )) + val perPartStatsIR = + cdaIR(perPartStatsCDAContexts, MakeStruct(FastSeq()), s"shuffle_part_stats_iteration_$i") { + (ctxRef, _) => + val filenames = GetField(ctxRef, "files") + val samples = SeqSample( + GetField(ctxRef, "sizeOfPartition"), + GetField(ctxRef, "numSamples"), + NA(TRNGState), + false, + ) + val partitionStream = flatMapIR(ToStream(filenames)) { fileName => + mapIR( + ReadPartition( + MakeStruct(Array("partitionIndex" -> I64(0), "partitionPath" -> fileName)), + tcoerce[TStruct](spec._vType), + reader, + ) + ) { partitionElement => + SelectFields(partitionElement, keyToSortBy.fields.map(_.name)) + } + } + MakeStruct(IndexedSeq( + "segmentIdx" -> GetField(ctxRef, "segmentIdx"), + "byteSize" -> GetField(ctxRef, "byteSize"), + "partData" -> samplePartition(partitionStream, samples, sortFields), + )) } - MakeStruct(IndexedSeq( - "segmentIdx" -> GetField(ctxRef, "segmentIdx"), - "byteSize" -> GetField(ctxRef, "byteSize"), - "partData" -> samplePartition(partitionStream, samples, sortFields))) - } - /* - Aggregate over the segments, to compute the pivots, whether it's already sorted, and what key interval is contained in that segment. - Also get the min and max of each individual partition. That way if it's sorted already, we know the partitioning to use. - */ + /* Aggregate over the segments, to compute the pivots, whether it's already sorted, and what + * key interval is contained in that segment. + * Also get the min and max of each individual partition. That way if it's sorted already, we + * know the partitioning to use. 
*/ val pivotsPerSegmentAndSortedCheck = ToArray(bindIR(perPartStatsIR) { perPartStats => - mapIR(StreamGroupByKey(ToStream(perPartStats), IndexedSeq("segmentIdx"), missingEqual = true)) { oneGroup => + mapIR(StreamGroupByKey( + ToStream(perPartStats), + IndexedSeq("segmentIdx"), + missingEqual = true, + )) { oneGroup => val streamElementRef = Ref(genUID(), oneGroup.typ.asInstanceOf[TIterable].elementType) - val dataRef = Ref(genUID(), streamElementRef.typ.asInstanceOf[TStruct].fieldType("partData")) - val sizeRef = Ref(genUID(), streamElementRef.typ.asInstanceOf[TStruct].fieldType("byteSize")) - bindIR(StreamAgg(oneGroup, streamElementRef.name, { - AggLet(dataRef.name, GetField(streamElementRef, "partData"), AggLet(sizeRef.name, GetField(streamElementRef, "byteSize"), - MakeStruct(FastSeq( - ("byteSize", ApplyAggOp(Sum())(sizeRef)), - ("min", AggFold.min(GetField(dataRef, "min"), sortFields)), // Min of the mins - ("max", AggFold.max(GetField(dataRef, "max"), sortFields)), // Max of the maxes - ("perPartMins", ApplyAggOp(Collect())(GetField(dataRef, "min"))), // All the mins - ("perPartMaxes", ApplyAggOp(Collect())(GetField(dataRef, "max"))), // All the maxes - ("samples", ApplyAggOp(Collect())(GetField(dataRef, "samples"))), - ("eachPartSorted", AggFold.all(GetField(dataRef, "isSorted"))), - ("perPartIntervalTuples", ApplyAggOp(Collect())(MakeTuple.ordered(FastSeq(GetField(dataRef, "min"), GetField(dataRef, "max"))))) - )), false), false) - })) { aggResults => - val sortedOversampling = sortIR(flatMapIR(ToStream(GetField(aggResults, "samples"))) { onePartCollectedArray => ToStream(onePartCollectedArray)}) { case (left, right) => + val dataRef = + Ref(genUID(), streamElementRef.typ.asInstanceOf[TStruct].fieldType("partData")) + val sizeRef = + Ref(genUID(), streamElementRef.typ.asInstanceOf[TStruct].fieldType("byteSize")) + bindIR(StreamAgg( + oneGroup, + streamElementRef.name, { + AggLet( + dataRef.name, + GetField(streamElementRef, "partData"), + AggLet( + sizeRef.name, + GetField(streamElementRef, "byteSize"), + MakeStruct(FastSeq( + ("byteSize", ApplyAggOp(Sum())(sizeRef)), + ("min", AggFold.min(GetField(dataRef, "min"), sortFields)), // Min of the mins + ("max", AggFold.max(GetField(dataRef, "max"), sortFields)), // Max of the maxes + ( + "perPartMins", + ApplyAggOp(Collect())(GetField(dataRef, "min")), + ), // All the mins + ( + "perPartMaxes", + ApplyAggOp(Collect())(GetField(dataRef, "max")), + ), // All the maxes + ("samples", ApplyAggOp(Collect())(GetField(dataRef, "samples"))), + ("eachPartSorted", AggFold.all(GetField(dataRef, "isSorted"))), + ( + "perPartIntervalTuples", + ApplyAggOp(Collect())(MakeTuple.ordered(FastSeq( + GetField(dataRef, "min"), + GetField(dataRef, "max"), + ))), + ), + )), + false, + ), + false, + ) + }, + )) { aggResults => + val sortedOversampling = sortIR(flatMapIR(ToStream(GetField(aggResults, "samples"))) { + onePartCollectedArray => ToStream(onePartCollectedArray) + }) { case (left, right) => ApplyComparisonOp(StructLT(keyToSortBy, sortFields), left, right) } val minArray = MakeArray(GetField(aggResults, "min")) val maxArray = MakeArray(GetField(aggResults, "max")) - val tuplesInSortedOrder = tuplesAreSorted(GetField(aggResults, "perPartIntervalTuples"), sortFields) + val tuplesInSortedOrder = + tuplesAreSorted(GetField(aggResults, "perPartIntervalTuples"), sortFields) bindIR(sortedOversampling) { sortedOversampling => bindIR(ArrayLen(sortedOversampling)) { numSamples => val sortedSampling = bindIR( /* calculate a 'good' branch factor based on part 
sizes */ - UtilFunctions.intMax(I32(2), + UtilFunctions.intMax( + I32(2), UtilFunctions.intMin( UtilFunctions.intMin(numSamples, I32(defaultBranchingFactor)), - (I64(2L) * (GetField(aggResults, "byteSize").floorDiv(I64(sizeCutoff)))).toI)) + (I64(2L) * (GetField(aggResults, "byteSize").floorDiv(I64(sizeCutoff)))).toI, + ), + ) ) { branchingFactor => ToArray(mapIR(StreamRange(I32(1), branchingFactor, I32(1))) { idx => - If(ArrayLen(sortedOversampling) ceq 0, - Die(strConcat("aggresults=", aggResults, ", idx=", idx, ", sortedOversampling=", sortedOversampling, ", numSamples=", numSamples, ", branchingFactor=", branchingFactor), sortedOversampling.typ.asInstanceOf[TArray].elementType, -1), - ArrayRef(sortedOversampling, Apply("floor", FastSeq(), IndexedSeq(idx.toD * ((numSamples + 1) / branchingFactor)), TFloat64, ErrorIDs.NO_ERROR).toI - 1)) + If( + ArrayLen(sortedOversampling) ceq 0, + Die( + strConcat("aggresults=", aggResults, ", idx=", idx, ", sortedOversampling=", + sortedOversampling, ", numSamples=", numSamples, ", branchingFactor=", + branchingFactor), + sortedOversampling.typ.asInstanceOf[TArray].elementType, + -1, + ), + ArrayRef( + sortedOversampling, + Apply( + "floor", + FastSeq(), + IndexedSeq(idx.toD * ((numSamples + 1) / branchingFactor)), + TFloat64, + ErrorIDs.NO_ERROR, + ).toI - 1, + ), + ) }) } MakeStruct(FastSeq( - "pivotsWithEndpoints" -> ArrayFunctions.extend(ArrayFunctions.extend(minArray, sortedSampling), maxArray), - "isSorted" -> ApplySpecial("land", Seq.empty[Type], FastSeq(GetField(aggResults, "eachPartSorted"), tuplesInSortedOrder), TBoolean, ErrorIDs.NO_ERROR), - "intervalTuple" -> MakeTuple.ordered(FastSeq(GetField(aggResults, "min"), GetField(aggResults, "max"))), + "pivotsWithEndpoints" -> ArrayFunctions.extend( + ArrayFunctions.extend(minArray, sortedSampling), + maxArray, + ), + "isSorted" -> ApplySpecial( + "land", + Seq.empty[Type], + FastSeq(GetField(aggResults, "eachPartSorted"), tuplesInSortedOrder), + TBoolean, + ErrorIDs.NO_ERROR, + ), + "intervalTuple" -> MakeTuple.ordered(FastSeq( + GetField(aggResults, "min"), + GetField(aggResults, "max"), + )), "perPartMins" -> GetField(aggResults, "perPartMins"), - "perPartMaxes" -> GetField(aggResults, "perPartMaxes") + "perPartMaxes" -> GetField(aggResults, "perPartMaxes"), )) } } @@ -290,10 +434,21 @@ object LowerDistributedSort { } }) - log.info(s"DISTRIBUTED SORT: PHASE ${i+1}: STAGE 1: SAMPLE VALUES FROM PARTITIONS") + log.info(s"DISTRIBUTED SORT: PHASE ${i + 1}: STAGE 1: SAMPLE VALUES FROM PARTITIONS") // Going to check now if it's fully sorted, as well as collect and sort all the samples. 
- val pivotsWithEndpointsAndInfoGroupedBySegmentNumber = CompileAndEvaluate[Annotation](ctx, pivotsPerSegmentAndSortedCheck) - .asInstanceOf[IndexedSeq[Row]].map(x => (x(0).asInstanceOf[IndexedSeq[Row]], x(1).asInstanceOf[Boolean], x(2).asInstanceOf[Row], x(3).asInstanceOf[IndexedSeq[Row]], x(4).asInstanceOf[IndexedSeq[Row]])) + val pivotsWithEndpointsAndInfoGroupedBySegmentNumber = CompileAndEvaluate[Annotation]( + ctx, + pivotsPerSegmentAndSortedCheck, + ) + .asInstanceOf[IndexedSeq[Row]].map(x => + ( + x(0).asInstanceOf[IndexedSeq[Row]], + x(1).asInstanceOf[Boolean], + x(2).asInstanceOf[Row], + x(3).asInstanceOf[IndexedSeq[Row]], + x(4).asInstanceOf[IndexedSeq[Row]], + ) + ) val pivotCounts = pivotsWithEndpointsAndInfoGroupedBySegmentNumber .map(_._1.length) @@ -303,121 +458,228 @@ object LowerDistributedSort { .sortBy(_._1) .map { case (nPivots, nSegments) => s"$nPivots pivots: $nSegments" } - log.info(s"DISTRIBUTED SORT: PHASE ${i + 1}: pivot counts:\n ${pivotCounts.mkString("\n ")}") + log.info( + s"DISTRIBUTED SORT: PHASE ${i + 1}: pivot counts:\n ${pivotCounts.mkString("\n ")}" + ) - val (sortedSegmentsTuples, unsortedPivotsWithEndpointsAndInfoGroupedBySegmentNumber) = pivotsWithEndpointsAndInfoGroupedBySegmentNumber.zipWithIndex.partition { case ((_, isSorted, _, _, _), _) => isSorted} - - val outputPartitions = sortedSegmentsTuples.flatMap { case ((_, _, _, partMins, partMaxes), originalSegmentIdx) => - val segmentToBreakUp = loopState.largeSegments(originalSegmentIdx) - val currentSegmentPartitionData = partitionDataPerSegment(originalSegmentIdx) - val partRanges = partMins.zip(partMaxes) - assert(partRanges.size == currentSegmentPartitionData.size) + val (sortedSegmentsTuples, unsortedPivotsWithEndpointsAndInfoGroupedBySegmentNumber) = + pivotsWithEndpointsAndInfoGroupedBySegmentNumber.zipWithIndex.partition { + case ((_, isSorted, _, _, _), _) => isSorted + } - currentSegmentPartitionData.zip(partRanges).zipWithIndex.map { case ((pi, (intervalStart, intervalEnd)), idx) => - OutputPartition(segmentToBreakUp.indices :+ idx, Interval(intervalStart, intervalEnd, true, true), pi.files) + val outputPartitions = + sortedSegmentsTuples.flatMap { case ((_, _, _, partMins, partMaxes), originalSegmentIdx) => + val segmentToBreakUp = loopState.largeSegments(originalSegmentIdx) + val currentSegmentPartitionData = partitionDataPerSegment(originalSegmentIdx) + val partRanges = partMins.zip(partMaxes) + assert(partRanges.size == currentSegmentPartitionData.size) + + currentSegmentPartitionData.zip(partRanges).zipWithIndex.map { + case ((pi, (intervalStart, intervalEnd)), idx) => + OutputPartition( + segmentToBreakUp.indices :+ idx, + Interval(intervalStart, intervalEnd, true, true), + pi.files, + ) + } } - } - val remainingUnsortedSegments = unsortedPivotsWithEndpointsAndInfoGroupedBySegmentNumber.map {case (_, idx) => loopState.largeSegments(idx)} + val remainingUnsortedSegments = unsortedPivotsWithEndpointsAndInfoGroupedBySegmentNumber.map { + case (_, idx) => loopState.largeSegments(idx) + } - val (newBigUnsortedSegments, newSmallSegments) = if (unsortedPivotsWithEndpointsAndInfoGroupedBySegmentNumber.size > 0) { + val (newBigUnsortedSegments, newSmallSegments) = + if (unsortedPivotsWithEndpointsAndInfoGroupedBySegmentNumber.size > 0) { - val pivotsWithEndpointsGroupedBySegmentNumber = unsortedPivotsWithEndpointsAndInfoGroupedBySegmentNumber.map{ case (r, _) => r._1 } + val pivotsWithEndpointsGroupedBySegmentNumber = + unsortedPivotsWithEndpointsAndInfoGroupedBySegmentNumber.map { case (r, 
_) => r._1 } - val pivotsWithEndpointsGroupedBySegmentNumberLiteral = Literal(TArray(TArray(keyToSortBy)), pivotsWithEndpointsGroupedBySegmentNumber) + val pivotsWithEndpointsGroupedBySegmentNumberLiteral = + Literal(TArray(TArray(keyToSortBy)), pivotsWithEndpointsGroupedBySegmentNumber) - val tmpPath = ctx.createTmpPath("hail_shuffle_temp") - val unsortedPartitionDataPerSegment = unsortedPivotsWithEndpointsAndInfoGroupedBySegmentNumber.map { case (_, idx) => partitionDataPerSegment(idx)} + val tmpPath = ctx.createTmpPath("hail_shuffle_temp") + val unsortedPartitionDataPerSegment = + unsortedPivotsWithEndpointsAndInfoGroupedBySegmentNumber.map { case (_, idx) => + partitionDataPerSegment(idx) + } - val partitionDataPerSegmentWithPivotIndex = unsortedPartitionDataPerSegment.zipWithIndex.map { case (partitionDataForOneSegment, indexIntoPivotsArray) => - partitionDataForOneSegment.map(x => (x.indices, x.files, x.currentPartSize, indexIntoPivotsArray)) - } + val partitionDataPerSegmentWithPivotIndex = + unsortedPartitionDataPerSegment.zipWithIndex.map { + case (partitionDataForOneSegment, indexIntoPivotsArray) => + partitionDataForOneSegment.map(x => + (x.indices, x.files, x.currentPartSize, indexIntoPivotsArray) + ) + } - val distributeContextsData = partitionDataPerSegmentWithPivotIndex.flatten.zipWithIndex.map { case (part, partIdx) => Row(part._1.last, part._2, partIdx, part._4) } - val distributeContexts = ToStream(Literal(TArray(TStruct("segmentIdx" -> TInt32, "files" -> TArray(TString), "partIdx" -> TInt32, "indexIntoPivotsArray" -> TInt32)), distributeContextsData)) - val distributeGlobals = MakeStruct(IndexedSeq("pivotsWithEndpointsGroupedBySegmentIdx" -> pivotsWithEndpointsGroupedBySegmentNumberLiteral)) - - val distribute = cdaIR(distributeContexts, distributeGlobals, s"shuffle_distribute_iteration_$i") { (ctxRef, globalsRef) => - val segmentIdx = GetField(ctxRef, "segmentIdx") - val indexIntoPivotsArray = GetField(ctxRef, "indexIntoPivotsArray") - val pivotsWithEndpointsGroupedBySegmentIdx = GetField(globalsRef, "pivotsWithEndpointsGroupedBySegmentIdx") - val path = invoke("concat", TString, Str(tmpPath + "_"), invoke("str", TString, GetField(ctxRef, "partIdx"))) - val filenames = GetField(ctxRef, "files") - val partitionStream = flatMapIR(ToStream(filenames)) { fileName => - ReadPartition(MakeStruct(Array("partitionIndex" -> I64(0), "partitionPath" -> fileName)), tcoerce[TStruct](spec._vType), reader) - } - MakeTuple.ordered(IndexedSeq(segmentIdx, StreamDistribute(partitionStream, ArrayRef(pivotsWithEndpointsGroupedBySegmentIdx, indexIntoPivotsArray), path, StructCompare(keyToSortBy, keyToSortBy, sortFields.toArray), spec))) - } + val distributeContextsData = + partitionDataPerSegmentWithPivotIndex.flatten.zipWithIndex.map { case (part, partIdx) => + Row(part._1.last, part._2, partIdx, part._4) + } + val distributeContexts = ToStream(Literal( + TArray(TStruct( + "segmentIdx" -> TInt32, + "files" -> TArray(TString), + "partIdx" -> TInt32, + "indexIntoPivotsArray" -> TInt32, + )), + distributeContextsData, + )) + val distributeGlobals = MakeStruct(IndexedSeq( + "pivotsWithEndpointsGroupedBySegmentIdx" -> pivotsWithEndpointsGroupedBySegmentNumberLiteral + )) + + val distribute = + cdaIR(distributeContexts, distributeGlobals, s"shuffle_distribute_iteration_$i") { + (ctxRef, globalsRef) => + val segmentIdx = GetField(ctxRef, "segmentIdx") + val indexIntoPivotsArray = GetField(ctxRef, "indexIntoPivotsArray") + val pivotsWithEndpointsGroupedBySegmentIdx = + GetField(globalsRef, 
"pivotsWithEndpointsGroupedBySegmentIdx") + val path = invoke( + "concat", + TString, + Str(tmpPath + "_"), + invoke("str", TString, GetField(ctxRef, "partIdx")), + ) + val filenames = GetField(ctxRef, "files") + val partitionStream = flatMapIR(ToStream(filenames)) { fileName => + ReadPartition( + MakeStruct(Array("partitionIndex" -> I64(0), "partitionPath" -> fileName)), + tcoerce[TStruct](spec._vType), + reader, + ) + } + MakeTuple.ordered(IndexedSeq( + segmentIdx, + StreamDistribute( + partitionStream, + ArrayRef(pivotsWithEndpointsGroupedBySegmentIdx, indexIntoPivotsArray), + path, + StructCompare(keyToSortBy, keyToSortBy, sortFields.toArray), + spec, + ), + )) + } - log.info(s"DISTRIBUTED SORT: PHASE ${i+1}: STAGE 2: DISTRIBUTE") - val distributeResult = CompileAndEvaluate[Annotation](ctx, distribute) - .asInstanceOf[IndexedSeq[Row]].map(row => ( - row(0).asInstanceOf[Int], - row(1).asInstanceOf[IndexedSeq[Row]].map(innerRow => ( - innerRow(0).asInstanceOf[Interval], - innerRow(1).asInstanceOf[String], - innerRow(2).asInstanceOf[Int], - innerRow(3).asInstanceOf[Long]) - ))) - - // distributeResult is a numPartitions length array of arrays, where each inner array tells me what - // files were written to for each partition, as well as the number of entries in that file. - val protoDataPerSegment = orderedGroupBy[(Int, IndexedSeq[(Interval, String, Int, Long)]), Int](distributeResult, x => x._1).map { case (_, seqOfChunkData) => seqOfChunkData.map(_._2) } - - val transposedIntoNewSegments = protoDataPerSegment.zip(remainingUnsortedSegments.map(_.indices)).flatMap { case (oneOldSegment, priorIndices) => - val headLen = oneOldSegment.head.length - assert(oneOldSegment.forall(x => x.length == headLen)) - (0 until headLen).map(colIdx => (oneOldSegment.map(row => row(colIdx)), priorIndices)) - } + log.info(s"DISTRIBUTED SORT: PHASE ${i + 1}: STAGE 2: DISTRIBUTE") + val distributeResult = CompileAndEvaluate[Annotation](ctx, distribute) + .asInstanceOf[IndexedSeq[Row]].map(row => + ( + row(0).asInstanceOf[Int], + row(1).asInstanceOf[IndexedSeq[Row]].map(innerRow => + ( + innerRow(0).asInstanceOf[Interval], + innerRow(1).asInstanceOf[String], + innerRow(2).asInstanceOf[Int], + innerRow(3).asInstanceOf[Long], + ) + ), + ) + ) + + /* distributeResult is a numPartitions length array of arrays, where each inner array + * tells me what */ + /* files were written to for each partition, as well as the number of entries in that + * file. 
*/ + val protoDataPerSegment = + orderedGroupBy[(Int, IndexedSeq[(Interval, String, Int, Long)]), Int]( + distributeResult, + x => x._1, + ).map { case (_, seqOfChunkData) => seqOfChunkData.map(_._2) } + + val transposedIntoNewSegments = + protoDataPerSegment.zip(remainingUnsortedSegments.map(_.indices)).flatMap { + case (oneOldSegment, priorIndices) => + val headLen = oneOldSegment.head.length + assert(oneOldSegment.forall(x => x.length == headLen)) + (0 until headLen).map(colIdx => + (oneOldSegment.map(row => row(colIdx)), priorIndices) + ) + } - val dataPerSegment = transposedIntoNewSegments.zipWithIndex.map { case ((chunksWithSameInterval, priorIndices), newIndex) => - val interval = chunksWithSameInterval.head._1 - val chunks = chunksWithSameInterval.map{ case (_, filename, numRows, numBytes) => Chunk(filename, numRows, numBytes)} - val newSegmentIndices = priorIndices :+ newIndex - SegmentResult(newSegmentIndices, interval, chunks) - } + val dataPerSegment = transposedIntoNewSegments.zipWithIndex.map { + case ((chunksWithSameInterval, priorIndices), newIndex) => + val interval = chunksWithSameInterval.head._1 + val chunks = chunksWithSameInterval.map { case (_, filename, numRows, numBytes) => + Chunk(filename, numRows, numBytes) + } + val newSegmentIndices = priorIndices :+ newIndex + SegmentResult(newSegmentIndices, interval, chunks) + } - // Decide whether a segment is small enough to be removed from consideration. - dataPerSegment.partition { sr => - val isBig = sr.chunks.map(_.byteSize).sum > sizeCutoff - // Need to call it "small" if it can't be further subdivided. - isBig && (sr.interval.left.point != sr.interval.right.point) && (sr.chunks.map(_.size).sum > 1) - } - } else { (IndexedSeq.empty[SegmentResult], IndexedSeq.empty[SegmentResult]) } + // Decide whether a segment is small enough to be removed from consideration. + dataPerSegment.partition { sr => + val isBig = sr.chunks.map(_.byteSize).sum > sizeCutoff + // Need to call it "small" if it can't be further subdivided. 
+ isBig && (sr.interval.left.point != sr.interval.right.point) && (sr.chunks.map( + _.size + ).sum > 1) + } + } else { (IndexedSeq.empty[SegmentResult], IndexedSeq.empty[SegmentResult]) } - log.info(s"DISTRIBUTED SORT: PHASE ${i + 1}: ${newSmallSegments.length}/${newSmallSegments.length + newBigUnsortedSegments.length} segments can be locally sorted") + log.info( + s"DISTRIBUTED SORT: PHASE ${i + 1}: ${newSmallSegments.length}/${newSmallSegments.length + newBigUnsortedSegments.length} segments can be locally sorted" + ) - loopState = LoopState(newBigUnsortedSegments, loopState.smallSegments ++ newSmallSegments, loopState.readyOutputParts ++ outputPartitions) + loopState = LoopState( + newBigUnsortedSegments, + loopState.smallSegments ++ newSmallSegments, + loopState.readyOutputParts ++ outputPartitions, + ) i = i + 1 } val needSortingFilenames = loopState.smallSegments.map(_.chunks.map(_.filename)) val needSortingFilenamesContext = Literal(TArray(TArray(TString)), needSortingFilenames) - val sortedFilenamesIR = cdaIR(ToStream(needSortingFilenamesContext), MakeStruct(FastSeq()), "shuffle_local_sort") { case (ctxRef, _) => - val filenames = ctxRef - val partitionInputStream = flatMapIR(ToStream(filenames)) { fileName => - ReadPartition(MakeStruct(Array("partitionIndex" -> I64(0), "partitionPath" -> fileName)), tcoerce[TStruct](spec._vType), reader) + val sortedFilenamesIR = + cdaIR(ToStream(needSortingFilenamesContext), MakeStruct(FastSeq()), "shuffle_local_sort") { + case (ctxRef, _) => + val filenames = ctxRef + val partitionInputStream = flatMapIR(ToStream(filenames)) { fileName => + ReadPartition( + MakeStruct(Array("partitionIndex" -> I64(0), "partitionPath" -> fileName)), + tcoerce[TStruct](spec._vType), + reader, + ) + } + val newKeyFieldNames = keyToSortBy.fields.map(_.name) + val sortedStream = ToStream(sortIR(partitionInputStream) { (refLeft, refRight) => + ApplyComparisonOp( + StructLT(keyToSortBy, sortFields), + SelectFields(refLeft, newKeyFieldNames), + SelectFields(refRight, newKeyFieldNames), + ) + }) + WritePartition(sortedStream, UUID4(), writer) } - val newKeyFieldNames = keyToSortBy.fields.map(_.name) - val sortedStream = ToStream(sortIR(partitionInputStream) { (refLeft, refRight) => - ApplyComparisonOp(StructLT(keyToSortBy, sortFields), SelectFields(refLeft, newKeyFieldNames), SelectFields(refRight, newKeyFieldNames)) - }) - WritePartition(sortedStream, UUID4(), writer) - } - log.info(s"DISTRIBUTED SORT: PHASE ${i+1}: LOCALLY SORT FILES") + log.info(s"DISTRIBUTED SORT: PHASE ${i + 1}: LOCALLY SORT FILES") log.info(s"DISTRIBUTED_SORT: ${needSortingFilenames.length} segments to sort") - val sortedFilenames = CompileAndEvaluate[Annotation](ctx, sortedFilenamesIR).asInstanceOf[IndexedSeq[Row]].map(_(0).asInstanceOf[String]) - val newlySortedSegments = loopState.smallSegments.zip(sortedFilenames).map { case (sr, newFilename) => - OutputPartition(sr.indices, sr.interval, IndexedSeq(initialTmpPath + newFilename)) - } + val sortedFilenames = + CompileAndEvaluate[Annotation](ctx, sortedFilenamesIR).asInstanceOf[IndexedSeq[Row]].map( + _(0).asInstanceOf[String] + ) + val newlySortedSegments = + loopState.smallSegments.zip(sortedFilenames).map { case (sr, newFilename) => + OutputPartition(sr.indices, sr.interval, IndexedSeq(initialTmpPath + newFilename)) + } val unorderedOutputPartitions = newlySortedSegments ++ loopState.readyOutputParts - val orderedOutputPartitions = unorderedOutputPartitions.sortWith{ (srt1, srt2) => lessThanForSegmentIndices(srt1.indices, srt2.indices)} + val 
orderedOutputPartitions = unorderedOutputPartitions.sortWith { (srt1, srt2) => + lessThanForSegmentIndices(srt1.indices, srt2.indices) + } val keyed = sortFields.forall(sf => sf.sortOrder == Ascending) - DistributionSortReader(keyToSortBy, keyed, spec, orderedOutputPartitions, initialGlobalsLiteral, spec._vType.asInstanceOf[TStruct], tableRequiredness) + DistributionSortReader( + keyToSortBy, + keyed, + spec, + orderedOutputPartitions, + initialGlobalsLiteral, + spec._vType.asInstanceOf[TStruct], + tableRequiredness, + ) } def orderedGroupBy[T, U](is: IndexedSeq[T], func: T => U): IndexedSeq[(U, IndexedSeq[T])] = { @@ -453,18 +715,26 @@ object LowerDistributedSort { } idx += 1 } - // For there to be no difference at this point, they had to be equal whole way. Assert that they're same length. + /* For there to be no difference at this point, they had to be equal whole way. Assert that + * they're same length. */ assert(i1.length == i2.length) false } - case class PartitionInfo(indices: IndexedSeq[Int], files: IndexedSeq[String], currentPartSize: Long, currentPartByteSize: Long) + case class PartitionInfo( + indices: IndexedSeq[Int], + files: IndexedSeq[String], + currentPartSize: Long, + currentPartByteSize: Long, + ) - def segmentsToPartitionData(segments: IndexedSeq[SegmentResult], idealNumberOfRowsPerPart: Long): IndexedSeq[IndexedSeq[PartitionInfo]] = { + def segmentsToPartitionData(segments: IndexedSeq[SegmentResult], idealNumberOfRowsPerPart: Long) + : IndexedSeq[IndexedSeq[PartitionInfo]] = { segments.map { sr => val chunkDataSizes = sr.chunks.map(_.size) val segmentSize = chunkDataSizes.sum - val numParts = coerceToInt((segmentSize + idealNumberOfRowsPerPart - 1) / idealNumberOfRowsPerPart) + val numParts = + coerceToInt((segmentSize + idealNumberOfRowsPerPart - 1) / idealNumberOfRowsPerPart) var currentPartSize = 0L var currentPartByteSize = 0L val groupedIntoParts = new ArrayBuffer[PartitionInfo](numParts) @@ -475,7 +745,12 @@ object LowerDistributedSort { currentPartSize += chunk.size currentPartByteSize += chunk.byteSize if (currentPartSize >= idealNumberOfRowsPerPart) { - groupedIntoParts.append(PartitionInfo(sr.indices, currentFiles.result().toIndexedSeq, currentPartSize, currentPartByteSize)) + groupedIntoParts.append(PartitionInfo( + sr.indices, + currentFiles.result().toIndexedSeq, + currentPartSize, + currentPartByteSize, + )) currentFiles.clear() currentPartSize = 0 currentPartByteSize = 0L @@ -483,13 +758,23 @@ object LowerDistributedSort { } } if (!currentFiles.isEmpty) { - groupedIntoParts.append(PartitionInfo(sr.indices, currentFiles.result().toIndexedSeq, currentPartSize, currentPartByteSize)) + groupedIntoParts.append(PartitionInfo( + sr.indices, + currentFiles.result().toIndexedSeq, + currentPartSize, + currentPartByteSize, + )) } groupedIntoParts.result() } } - def howManySamplesPerPartition(rand: IRRandomness, totalNumberOfRecords: Long, initialNumSamplesToSelect: Int, partitionCounts: IndexedSeq[Long]): IndexedSeq[Int] = { + def howManySamplesPerPartition( + rand: IRRandomness, + totalNumberOfRecords: Long, + initialNumSamplesToSelect: Int, + partitionCounts: IndexedSeq[Long], + ): IndexedSeq[Int] = { var successStatesRemaining = initialNumSamplesToSelect var failureStatesRemaining = totalNumberOfRecords - successStatesRemaining @@ -497,7 +782,8 @@ object LowerDistributedSort { var i = 0 while (i < partitionCounts.size) { - val numSuccesses = rand.rhyper(successStatesRemaining, failureStatesRemaining, partitionCounts(i)).toInt + val numSuccesses = + 
rand.rhyper(successStatesRemaining, failureStatesRemaining, partitionCounts(i)).toInt successStatesRemaining -= numSuccesses failureStatesRemaining -= (partitionCounts(i) - numSuccesses) ans(i) = numSuccesses @@ -511,7 +797,8 @@ object LowerDistributedSort { // Step 1: Join the dataStream zippedWithIdx on sampleIndices? // That requires sampleIndices to be a stream of structs val samplingIndexName = "samplingPartitionIndex" - val structSampleIndices = mapIR(sampleIndices)(sampleIndex => MakeStruct(FastSeq((samplingIndexName, sampleIndex)))) + val structSampleIndices = + mapIR(sampleIndices)(sampleIndex => MakeStruct(FastSeq((samplingIndexName, sampleIndex)))) val dataWithIdx = zipWithIndex(dataStream) val leftName = genUID() @@ -519,11 +806,23 @@ object LowerDistributedSort { val leftRef = Ref(leftName, dataWithIdx.typ.asInstanceOf[TStream].elementType) val rightRef = Ref(rightName, structSampleIndices.typ.asInstanceOf[TStream].elementType) - val joined = StreamJoin(dataWithIdx, structSampleIndices, IndexedSeq("idx"), IndexedSeq(samplingIndexName), leftName, rightName, - MakeStruct(FastSeq(("elt", GetField(leftRef, "elt")), ("shouldKeep", ApplyUnaryPrimOp(Bang, IsNA(rightRef))))), - "left", requiresMemoryManagement = true) - - // Step 2: Aggregate over joined, figure out how to collect only the rows that are marked "shouldKeep" + val joined = StreamJoin( + dataWithIdx, + structSampleIndices, + IndexedSeq("idx"), + IndexedSeq(samplingIndexName), + leftName, + rightName, + MakeStruct(FastSeq( + ("elt", GetField(leftRef, "elt")), + ("shouldKeep", ApplyUnaryPrimOp(Bang, IsNA(rightRef))), + )), + "left", + requiresMemoryManagement = true, + ) + + /* Step 2: Aggregate over joined, figure out how to collect only the rows that are marked + * "shouldKeep" */ val streamElementType = joined.typ.asInstanceOf[TStream].elementType.asInstanceOf[TStruct] val streamElementName = genUID() val streamElementRef = Ref(streamElementName, streamElementType) @@ -532,74 +831,135 @@ object LowerDistributedSort { val eltRef = Ref(eltName, eltType) // Folding for isInternallySorted - val aggFoldSortedZero = MakeStruct(FastSeq("lastKeySeen" -> NA(eltType), "sortedSoFar" -> true, "haveSeenAny" -> false)) + val aggFoldSortedZero = MakeStruct(FastSeq( + "lastKeySeen" -> NA(eltType), + "sortedSoFar" -> true, + "haveSeenAny" -> false, + )) val aggFoldSortedAccumName1 = genUID() val aggFoldSortedAccumName2 = genUID() - val isSortedStateType = TStruct("lastKeySeen" -> eltType, "sortedSoFar" -> TBoolean, "haveSeenAny" -> TBoolean) + val isSortedStateType = + TStruct("lastKeySeen" -> eltType, "sortedSoFar" -> TBoolean, "haveSeenAny" -> TBoolean) val aggFoldSortedAccumRef1 = Ref(aggFoldSortedAccumName1, isSortedStateType) val isSortedSeq = bindIR(GetField(aggFoldSortedAccumRef1, "lastKeySeen")) { lastKeySeenRef => - If(!GetField(aggFoldSortedAccumRef1, "haveSeenAny"), - MakeStruct(FastSeq("lastKeySeen" -> eltRef, "sortedSoFar" -> true, "haveSeenAny" -> true)), - If (ApplyComparisonOp(StructLTEQ(eltType, sortFields), lastKeySeenRef, eltRef), - MakeStruct(FastSeq("lastKeySeen" -> eltRef, "sortedSoFar" -> GetField(aggFoldSortedAccumRef1, "sortedSoFar"), "haveSeenAny" -> true)), - MakeStruct(FastSeq("lastKeySeen" -> eltRef, "sortedSoFar" -> false, "haveSeenAny" -> true)) - ) + If( + !GetField(aggFoldSortedAccumRef1, "haveSeenAny"), + MakeStruct(FastSeq( + "lastKeySeen" -> eltRef, + "sortedSoFar" -> true, + "haveSeenAny" -> true, + )), + If( + ApplyComparisonOp(StructLTEQ(eltType, sortFields), lastKeySeenRef, eltRef), + 
MakeStruct(FastSeq( + "lastKeySeen" -> eltRef, + "sortedSoFar" -> GetField(aggFoldSortedAccumRef1, "sortedSoFar"), + "haveSeenAny" -> true, + )), + MakeStruct(FastSeq( + "lastKeySeen" -> eltRef, + "sortedSoFar" -> false, + "haveSeenAny" -> true, + )), + ), ) } - val isSortedComb = aggFoldSortedAccumRef1 // Do nothing, as this will never be called in a StreamAgg - - - StreamAgg(joined, streamElementName, { - AggLet(eltName, GetField(streamElementRef, "elt"), - MakeStruct(FastSeq( - ("min", AggFold.min(eltRef, sortFields)), - ("max", AggFold.max(eltRef, sortFields)), - ("samples", AggFilter(GetField(streamElementRef, "shouldKeep"), ApplyAggOp(Collect())(eltRef), false)), - ("isSorted", GetField(AggFold(aggFoldSortedZero, isSortedSeq, isSortedComb, aggFoldSortedAccumName1, aggFoldSortedAccumName2, false), "sortedSoFar")) - )), false) - }) + val isSortedComb = + aggFoldSortedAccumRef1 // Do nothing, as this will never be called in a StreamAgg + + StreamAgg( + joined, + streamElementName, { + AggLet( + eltName, + GetField(streamElementRef, "elt"), + MakeStruct(FastSeq( + ("min", AggFold.min(eltRef, sortFields)), + ("max", AggFold.max(eltRef, sortFields)), + ( + "samples", + AggFilter( + GetField(streamElementRef, "shouldKeep"), + ApplyAggOp(Collect())(eltRef), + false, + ), + ), + ( + "isSorted", + GetField( + AggFold(aggFoldSortedZero, isSortedSeq, isSortedComb, aggFoldSortedAccumName1, + aggFoldSortedAccumName2, false), + "sortedSoFar", + ), + ), + )), + false, + ) + }, + ) } - // Given an IR of type TArray(TTuple(minKey, maxKey)), determine if there's any overlap between these closed intervals. + /* Given an IR of type TArray(TTuple(minKey, maxKey)), determine if there's any overlap between + * these closed intervals. */ def tuplesAreSorted(arrayOfTuples: IR, sortFields: IndexedSeq[SortField]): IR = { - val intervalElementType = arrayOfTuples.typ.asInstanceOf[TArray].elementType.asInstanceOf[TTuple].types(0) - - foldIR(mapIR(rangeIR(1, ArrayLen(arrayOfTuples))) { idxOfTuple => - ApplyComparisonOp(StructLTEQ(intervalElementType, sortFields), GetTupleElement(ArrayRef(arrayOfTuples, idxOfTuple - 1), 1), GetTupleElement(ArrayRef(arrayOfTuples, idxOfTuple), 0)) - }, True()) { case (accum, elt) => + val intervalElementType = + arrayOfTuples.typ.asInstanceOf[TArray].elementType.asInstanceOf[TTuple].types(0) + + foldIR( + mapIR(rangeIR(1, ArrayLen(arrayOfTuples))) { idxOfTuple => + ApplyComparisonOp( + StructLTEQ(intervalElementType, sortFields), + GetTupleElement(ArrayRef(arrayOfTuples, idxOfTuple - 1), 1), + GetTupleElement(ArrayRef(arrayOfTuples, idxOfTuple), 0), + ) + }, + True(), + ) { case (accum, elt) => ApplySpecial("land", Seq.empty[Type], FastSeq(accum, elt), TBoolean, ErrorIDs.NO_ERROR) } } } -/** - * a "Chunk" is a file resulting from any StreamDistribute. Chunks are internally unsorted but contain - * data between two pivots. +/** a "Chunk" is a file resulting from any StreamDistribute. Chunks are internally unsorted but + * contain data between two pivots. */ case class Chunk(filename: String, size: Long, byteSize: Long) -/** - * A SegmentResult is the set of chunks from various StreamDistribute tasks working on the same segment - * of a previous iteration. + +/** A SegmentResult is the set of chunks from various StreamDistribute tasks working on the same + * segment of a previous iteration. 
*/ case class SegmentResult(indices: IndexedSeq[Int], interval: Interval, chunks: IndexedSeq[Chunk]) case class OutputPartition(indices: IndexedSeq[Int], interval: Interval, files: IndexedSeq[String]) -case class LoopState(largeSegments: IndexedSeq[SegmentResult], smallSegments: IndexedSeq[SegmentResult], readyOutputParts: IndexedSeq[OutputPartition]) -case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodecSpec, orderedOutputPartitions: IndexedSeq[OutputPartition], globals: IR, rowType: TStruct, rt: RTable) extends TableReader { +case class LoopState( + largeSegments: IndexedSeq[SegmentResult], + smallSegments: IndexedSeq[SegmentResult], + readyOutputParts: IndexedSeq[OutputPartition], +) + +case class DistributionSortReader( + key: TStruct, + keyed: Boolean, + spec: TypedCodecSpec, + orderedOutputPartitions: IndexedSeq[OutputPartition], + globals: IR, + rowType: TStruct, + rt: RTable, +) extends TableReader { lazy val fullType: TableType = TableType( rowType, if (keyed) key.fieldNames else FastSeq(), - globals.typ.asInstanceOf[TStruct] + globals.typ.asInstanceOf[TStruct], ) override def pathsUsed: Seq[String] = FastSeq() def defaultPartitioning(sm: HailStateManager): RVDPartitioner = { val (partitionerKey, intervals) = if (keyed) { - (key, orderedOutputPartitions.map { segment => segment.interval }) + (key, orderedOutputPartitions.map(segment => segment.interval)) } else { - (TStruct(), orderedOutputPartitions.map { _ => Interval(Row(), Row(), true, false) }) + (TStruct(), orderedOutputPartitions.map(_ => Interval(Row(), Row(), true, false))) } new RVDPartitioner(sm, partitionerKey, intervals) @@ -609,13 +969,11 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec override def isDistinctlyKeyed: Boolean = false // FIXME: No default value - def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = { + def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = VirtualTypeWithReq.subset(requestedType.rowType, rt.rowType) - } - def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = { + def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = VirtualTypeWithReq.subset(requestedType.globalType, rt.globalType) - } override def toJValue: JValue = JString("DistributionSortReader") @@ -638,7 +996,13 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec Row(filesWithNums) } } - val contexts = ToStream(Literal(TArray(TStruct("files" -> TArray(TStruct("partitionIndex" -> TInt64, "partitionPath" -> TString)))), contextData)) + val contexts = ToStream(Literal( + TArray(TStruct("files" -> TArray(TStruct( + "partitionIndex" -> TInt64, + "partitionPath" -> TString, + )))), + contextData, + )) val partitioner = defaultPartitioning(ctx.stateManager) @@ -653,6 +1017,7 @@ case class DistributionSortReader(key: TStruct, keyed: Boolean, spec: TypedCodec ReadPartition(fileInfo, requestedType.rowType, PartitionNativeReader(spec, "__dummy_uid")) } partitionInputStream - }) + }, + ) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala index 8d1b801d2f6..c7f0ac78b3e 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIR.scala @@ -2,9 +2,9 @@ package is.hail.expr.ir.lowering import is.hail.HailContext import 
is.hail.backend.ExecuteContext +import is.hail.expr.ir.{agg, TableNativeWriter, _} import is.hail.expr.ir.ArrayZipBehavior.AssertSameLength import is.hail.expr.ir.functions.{TableCalculateNewPartitions, WrappedMatrixToTableFunction} -import is.hail.expr.ir.{TableNativeWriter, agg, _} import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.methods.{ForceCountTable, LocalLDPrune, NPartitionsTable, TableFilterPartitions} import is.hail.rvd.{PartitionBoundOrdering, RVDPartitioner} @@ -12,6 +12,7 @@ import is.hail.types._ import is.hail.types.physical.{PCanonicalBinary, PCanonicalTuple} import is.hail.types.virtual._ import is.hail.utils._ + import org.apache.spark.sql.Row class LowererUnsupportedOperation(msg: String = null) extends Exception(msg) @@ -22,7 +23,7 @@ object TableStage { partitioner: RVDPartitioner, dependency: TableStageDependency, contexts: IR, - body: (Ref) => IR + body: (Ref) => IR, ): TableStage = { val globalsRef = Ref(genUID(), globals.typ) TableStage( @@ -32,7 +33,8 @@ object TableStage { partitioner, dependency, contexts, - body) + body, + ) } def apply( @@ -42,12 +44,21 @@ object TableStage { partitioner: RVDPartitioner, dependency: TableStageDependency, contexts: IR, - partition: Ref => IR + partition: Ref => IR, ): TableStage = { val ctxType = contexts.typ.asInstanceOf[TStream].elementType val ctxRef = Ref(genUID(), ctxType) - new TableStage(letBindings, broadcastVals, globals, partitioner, dependency, contexts, ctxRef.name, partition(ctxRef)) + new TableStage( + letBindings, + broadcastVals, + globals, + partitioner, + dependency, + contexts, + ctxRef.name, + partition(ctxRef), + ) } def concatenate(ctx: ExecuteContext, children: IndexedSeq[TableStage]): TableStage = { @@ -69,7 +80,8 @@ object TableStage { val newGlobals = children.head.globals val globalsRef = Ref(genUID(), newGlobals.typ) - val newPartitioner = new RVDPartitioner(ctx.stateManager, keyType, children.flatMap(_.partitioner.rangeBounds)) + val newPartitioner = + new RVDPartitioner(ctx.stateManager, keyType, children.flatMap(_.partitioner.rangeBounds)) TableStage( children.flatMap(_.letBindings) :+ globalsRef.name -> newGlobals, @@ -82,13 +94,17 @@ object TableStage { StreamMultiMerge( children.indices.map { i => bindIR(GetTupleElement(ctxRef, i)) { ctx => - If(IsNA(ctx), - MakeStream(IndexedSeq(), TStream(children(i).rowType)), - children(i).partition(ctx)) + If( + IsNA(ctx), + MakeStream(IndexedSeq(), TStream(children(i).rowType)), + children(i).partition(ctx), + ) } }, - IndexedSeq()) - }) + IndexedSeq(), + ) + }, + ) } } @@ -107,7 +123,8 @@ class TableStage( val dependency: TableStageDependency, val contexts: IR, val ctxRefName: String, - val partitionIR: IR) { + val partitionIR: IR, +) { self => // useful for debugging, but should be disabled in production code due to N^2 complexity @@ -115,17 +132,19 @@ class TableStage( contexts.typ match { case TStream(t) if t.isRealizable => - case t => throw new IllegalArgumentException(s"TableStage constructed with illegal context type $t") + case t => + throw new IllegalArgumentException(s"TableStage constructed with illegal context type $t") } - def typecheckPartition(ctx: ExecuteContext): Unit = { + def typecheckPartition(ctx: ExecuteContext): Unit = TypeCheck( ctx, partitionIR, BindingEnv(Env[Type](((letBindings ++ broadcastVals).map { case (s, x) => (s, x.typ) }) - ++ FastSeq[(String, Type)]((ctxRefName, contexts.typ.asInstanceOf[TStream].elementType)): _*))) - - } + ++ FastSeq[(String, Type)]( + (ctxRefName, 
contexts.typ.asInstanceOf[TStream].elementType) + ): _*)), + ) def upcast(ctx: ExecuteContext, newType: TableType): TableStage = { val newRowType = newType.rowType @@ -144,11 +163,12 @@ class TableStage( def kType: TStruct = partitioner.kType def key: IndexedSeq[String] = kType.fieldNames def globalType: TStruct = globals.typ.asInstanceOf[TStruct] + def tableType: TableType = TableType(rowType, key, globalType) assert(kType.isSubsetOf(rowType), s"Key type $kType is not a subset of $rowType") - assert(broadcastVals.exists { case (name, value) => name == globals.name && value == globals}) + assert(broadcastVals.exists { case (name, value) => name == globals.name && value == globals }) def copy( letBindings: IndexedSeq[(String, IR)] = letBindings, @@ -158,9 +178,10 @@ class TableStage( dependency: TableStageDependency = dependency, contexts: IR = contexts, ctxRefName: String = ctxRefName, - partitionIR: IR = partitionIR + partitionIR: IR = partitionIR, ): TableStage = - new TableStage(letBindings, broadcastVals, globals, partitioner, dependency, contexts, ctxRefName, partitionIR) + new TableStage(letBindings, broadcastVals, globals, partitioner, dependency, contexts, + ctxRefName, partitionIR) def partition(ctx: IR): IR = { require(ctx.typ == ctxType) @@ -174,15 +195,16 @@ class TableStage( case Some(k) => if (!partitioner.kType.fieldNames.startsWith(k)) throw new RuntimeException(s"cannot map partitions to new key!" + - s"\n prev key: ${ partitioner.kType.fieldNames.toSeq }" + - s"\n new key: ${ k }") + s"\n prev key: ${partitioner.kType.fieldNames.toSeq}" + + s"\n new key: $k") partitioner.coarsen(k.length) case None => partitioner } copy(partitionIR = f(partitionIR), partitioner = part) } - def zipPartitions(right: TableStage, newGlobals: (IR, IR) => IR, body: (IR, IR) => IR): TableStage = { + def zipPartitions(right: TableStage, newGlobals: (IR, IR) => IR, body: (IR, IR) => IR) + : TableStage = { val left = this val leftCtxTyp = left.ctxType val rightCtxTyp = right.ctxType @@ -196,9 +218,9 @@ class TableStage( val zippedCtxs = StreamZip( FastSeq(left.contexts, right.contexts), FastSeq(leftCtxRef.name, rightCtxRef.name), - MakeStruct(FastSeq(leftCtxStructField -> leftCtxRef, - rightCtxStructField -> rightCtxRef)), - ArrayZipBehavior.AssertSameLength) + MakeStruct(FastSeq(leftCtxStructField -> leftCtxRef, rightCtxStructField -> rightCtxRef)), + ArrayZipBehavior.AssertSameLength, + ) val globals = newGlobals(left.globals, right.globals) val globalsRef = Ref(genUID(), globals.typ) @@ -210,13 +232,13 @@ class TableStage( left.partitioner, left.dependency.union(right.dependency), zippedCtxs, - (ctxRef: Ref) => { + (ctxRef: Ref) => bindIR(left.partition(GetField(ctxRef, leftCtxStructField))) { lPart => bindIR(right.partition(GetField(ctxRef, rightCtxStructField))) { rPart => body(lPart, rPart) } - } - }) + }, + ) } def mapPartitionWithContext(f: (IR, Ref) => IR): TableStage = @@ -224,7 +246,15 @@ class TableStage( def mapContexts(f: IR => IR)(getOldContext: IR => IR): TableStage = { val newContexts = f(contexts) - TableStage(letBindings, broadcastVals, globals, partitioner, dependency, newContexts, ctxRef => bindIR(getOldContext(ctxRef))(partition)) + TableStage( + letBindings, + broadcastVals, + globals, + partitioner, + dependency, + newContexts, + ctxRef => bindIR(getOldContext(ctxRef))(partition), + ) } def zipContextsWithIdx(): TableStage = { @@ -239,39 +269,62 @@ class TableStage( copy( letBindings = letBindings :+ globalsRef.name -> newGlobals, broadcastVals = broadcastVals :+ 
globalsRef.name -> globalsRef, - globals = globalsRef) + globals = globalsRef, + ) } - def mapCollect(staticID: String, dynamicID: IR = NA(TString))(f: IR => IR): IR = { - mapCollectWithGlobals(staticID, dynamicID)(f) { (parts, globals) => parts } - } - - def mapCollectWithGlobals(staticID: String, dynamicID: IR = NA(TString))(mapF: IR => IR)(body: (IR, IR) => IR): IR = + def mapCollect(staticID: String, dynamicID: IR = NA(TString))(f: IR => IR): IR = + mapCollectWithGlobals(staticID, dynamicID)(f)((parts, globals) => parts) + + def mapCollectWithGlobals( + staticID: String, + dynamicID: IR = NA(TString), + )( + mapF: IR => IR + )( + body: (IR, IR) => IR + ): IR = mapCollectWithContextsAndGlobals(staticID, dynamicID)((part, ctx) => mapF(part))(body) // mapf is (part, ctx) => ???, body is (parts, globals) => ??? - def mapCollectWithContextsAndGlobals(staticID: String, dynamicID: IR = NA(TString))(mapF: (IR, Ref) => IR)(body: (IR, IR) => IR): IR = { + def mapCollectWithContextsAndGlobals( + staticID: String, + dynamicID: IR = NA(TString), + )( + mapF: (IR, Ref) => IR + )( + body: (IR, IR) => IR + ): IR = { val broadcastRefs = MakeStruct(broadcastVals) val glob = Ref(genUID(), broadcastRefs.typ) val cda = CollectDistributedArray( - contexts, broadcastRefs, - ctxRefName, glob.name, - Let(broadcastVals.map { case (name, _) => name -> GetField(glob, name) }, - mapF(partitionIR, Ref(ctxRefName, ctxType)) - ), dynamicID, staticID, Some(dependency)) - - Let(letBindings, bindIR(cda) { cdaRef => body(cdaRef, globals) }) + contexts, + broadcastRefs, + ctxRefName, + glob.name, + Let( + broadcastVals.map { case (name, _) => name -> GetField(glob, name) }, + mapF(partitionIR, Ref(ctxRefName, ctxType)), + ), + dynamicID, + staticID, + Some(dependency), + ) + + Let(letBindings, bindIR(cda)(cdaRef => body(cdaRef, globals))) } def collectWithGlobals(staticID: String, dynamicID: IR = NA(TString)): IR = mapCollectWithGlobals(staticID, dynamicID)(ToArray) { (parts, globals) => MakeStruct(FastSeq( "rows" -> ToArray(flatMapIR(ToStream(parts))(ToStream(_))), - "global" -> globals)) + "global" -> globals, + )) } - def countPerPartition(): IR = mapCollect("count_per_partition")(part => Cast(StreamLen(part), TInt64)) + def countPerPartition(): IR = + mapCollect("count_per_partition")(part => Cast(StreamLen(part), TInt64)) def getGlobals(): IR = Let(letBindings, globals) @@ -289,11 +342,12 @@ class TableStage( repartitionNoShuffle(ec, newPart) } - def repartitionNoShuffle(ec: ExecuteContext, - newPartitioner: RVDPartitioner, - allowDuplication: Boolean = false, - dropEmptyPartitions: Boolean = false - ): TableStage = { + def repartitionNoShuffle( + ec: ExecuteContext, + newPartitioner: RVDPartitioner, + allowDuplication: Boolean = false, + dropEmptyPartitions: Boolean = false, + ): TableStage = { if (newPartitioner == this.partitioner) { return this @@ -306,9 +360,15 @@ class TableStage( val newStage = if (LowerTableIR.isRepartitioningCheap(partitioner, newPartitioner)) { val startAndEnd = partitioner.rangeBounds.map(newPartitioner.intervalRange).zipWithIndex - if (startAndEnd.forall { case ((start, end), i) => start + 1 == end && - newPartitioner.rangeBounds(start).includes(newPartitioner.kord, partitioner.rangeBounds(i)) - }) { + if ( + startAndEnd.forall { case ((start, end), i) => + start + 1 == end && + newPartitioner.rangeBounds(start).includes( + newPartitioner.kord, + partitioner.rangeBounds(i), + ) + } + ) { val newToOld = startAndEnd.groupBy(_._1._1).map { case (newIdx, values) => (newIdx, 
values.map(_._2).sorted.toIndexedSeq) } @@ -316,9 +376,15 @@ class TableStage( val (oldPartIndices, newPartitionerFilt) = if (dropEmptyPartitions) { val indices = (0 until newPartitioner.numPartitions).filter(newToOld.contains) - (indices.map(newToOld), newPartitioner.copy(rangeBounds = indices.map(newPartitioner.rangeBounds))) + ( + indices.map(newToOld), + newPartitioner.copy(rangeBounds = indices.map(newPartitioner.rangeBounds)), + ) } else - ((0 until newPartitioner.numPartitions).map(i => newToOld.getOrElse(i, FastSeq())), newPartitioner) + ( + (0 until newPartitioner.numPartitions).map(i => newToOld.getOrElse(i, FastSeq())), + newPartitioner, + ) log.info( "repartitionNoShuffle - fast path," + @@ -328,32 +394,45 @@ class TableStage( val newContexts = bindIR(ToArray(contexts)) { oldCtxs => mapIR(ToStream(Literal(TArray(TArray(TInt32)), oldPartIndices))) { inds => - ToArray(mapIR(ToStream(inds)) { i => ArrayRef(oldCtxs, i) }) + ToArray(mapIR(ToStream(inds))(i => ArrayRef(oldCtxs, i))) } } - return TableStage(letBindings, broadcastVals, globals, newPartitionerFilt, dependency, newContexts, - (ctx: Ref) => flatMapIR(ToStream(ctx, true)) { oldCtx => partition(oldCtx) }) + return TableStage( + letBindings, + broadcastVals, + globals, + newPartitionerFilt, + dependency, + newContexts, + (ctx: Ref) => flatMapIR(ToStream(ctx, true))(oldCtx => partition(oldCtx)), + ) } val boundType = RVDPartitioner.intervalIRRepresentation(newPartitioner.kType) val partitionMapping: IndexedSeq[Row] = newPartitioner.rangeBounds.map { i => - Row(RVDPartitioner.intervalToIRRepresentation(i, newPartitioner.kType.size), partitioner.queryInterval(i)) + Row( + RVDPartitioner.intervalToIRRepresentation(i, newPartitioner.kType.size), + partitioner.queryInterval(i), + ) } val partitionMappingType = TStruct( "partitionBound" -> boundType, - "parentPartitions" -> TArray(TInt32) + "parentPartitions" -> TArray(TInt32), ) val prevContextUID = genUID() val mappingUID = genUID() val idxUID = genUID() - val newContexts = Let(FastSeq(prevContextUID -> ToArray(contexts)), + val newContexts = Let( + FastSeq(prevContextUID -> ToArray(contexts)), StreamMap( ToStream( Literal( TArray(partitionMappingType), - partitionMapping)), + partitionMapping, + ) + ), mappingUID, MakeStruct( FastSeq( @@ -362,57 +441,87 @@ class TableStage( StreamMap( ToStream(GetField(Ref(mappingUID, partitionMappingType), "parentPartitions")), idxUID, - ArrayRef(Ref(prevContextUID, TArray(contexts.typ.asInstanceOf[TStream].elementType)), Ref(idxUID, TInt32)) - )) + ArrayRef( + Ref(prevContextUID, TArray(contexts.typ.asInstanceOf[TStream].elementType)), + Ref(idxUID, TInt32), + ), + ) + ), ) - ) - ) + ), + ), ) val prevContextUIDPartition = genUID() - TableStage(letBindings, broadcastVals, globals, newPartitioner, dependency, newContexts, + TableStage( + letBindings, + broadcastVals, + globals, + newPartitioner, + dependency, + newContexts, (ctxRef: Ref) => { - val body = self.partition(Ref(prevContextUIDPartition, self.contexts.typ.asInstanceOf[TStream].elementType)) + val body = self.partition(Ref( + prevContextUIDPartition, + self.contexts.typ.asInstanceOf[TStream].elementType, + )) bindIR(GetField(ctxRef, "partitionBound")) { interval => takeWhile( dropWhile( StreamFlatMap( ToStream(GetField(ctxRef, "oldContexts"), true), prevContextUIDPartition, - body)) { elt => - invoke("pointLessThanPartitionIntervalLeftEndpoint", TBoolean, + body, + ) + ) { elt => + invoke( + "pointLessThanPartitionIntervalLeftEndpoint", + TBoolean, SelectFields(elt, 
newPartitioner.kType.fieldNames), invoke("start", boundType.pointType, interval), - invoke("includesStart", TBoolean, interval)) + invoke("includesStart", TBoolean, interval), + ) - }) { elt => - invoke("pointLessThanPartitionIntervalRightEndpoint", TBoolean, + } + ) { elt => + invoke( + "pointLessThanPartitionIntervalRightEndpoint", + TBoolean, SelectFields(elt, newPartitioner.kType.fieldNames), invoke("end", boundType.pointType, interval), - invoke("includesEnd", TBoolean, interval)) + invoke("includesEnd", TBoolean, interval), + ) } } - }) + }, + ) } else { val location = ec.createTmpPath(genUID()) - CompileAndEvaluate(ec, - TableNativeWriter(location).lower(ec, this, RTable.fromTableStage(ec, this)) + CompileAndEvaluate( + ec, + TableNativeWriter(location).lower(ec, this, RTable.fromTableStage(ec, this)), ) val newTableType = TableType(rowType, newPartitioner.kType.fieldNames, globalType) - val reader = TableNativeReader.read(ec.fs, location, Some(NativeReaderOptions( - intervals = newPartitioner.rangeBounds, - intervalPointType = newPartitioner.kType, - filterIntervals = dropEmptyPartitions - ))) + val reader = TableNativeReader.read( + ec.fs, + location, + Some(NativeReaderOptions( + intervals = newPartitioner.rangeBounds, + intervalPointType = newPartitioner.kType, + filterIntervals = dropEmptyPartitions, + )), + ) val table = TableRead(newTableType, dropRows = false, tr = reader) LowerTableIR.applyTable(table, DArrayLowering.All, ec, LoweringAnalyses.apply(table, ec)) } - assert(newStage.rowType == rowType, + assert( + newStage.rowType == rowType, s"repartitioned row type: ${newStage.rowType}\n" + - s" old row type: $rowType") + s" old row type: $rowType", + ) newStage } @@ -437,9 +546,9 @@ class TableStage( joinType: String, globalJoiner: (IR, IR) => IR, joiner: (Ref, Ref) => IR, - rightKeyIsDistinct: Boolean = false + rightKeyIsDistinct: Boolean = false, ): TableStage = { - assert(this.kType.truncate(joinKey).isIsomorphicTo(right.kType.truncate(joinKey))) + assert(this.kType.truncate(joinKey).isJoinableWith(right.kType.truncate(joinKey))) val newPartitioner = { def leftPart: RVDPartitioner = this.partitioner.strictify() @@ -449,10 +558,11 @@ class TableStage( case "right" => rightPart case "inner" => leftPart.intersect(rightPart) case "outer" => RVDPartitioner.generate( - partitioner.sm, - this.kType.fieldNames.take(joinKey), - this.kType, - leftPart.rangeBounds ++ rightPart.rangeBounds) + partitioner.sm, + this.kType.fieldNames.take(joinKey), + this.kType, + leftPart.rangeBounds ++ rightPart.rangeBounds, + ) } } val repartitionedLeft: TableStage = repartitionNoShuffle(ec, newPartitioner) @@ -467,8 +577,18 @@ class TableStage( val lEltRef = Ref(genUID(), lEltType) val rEltRef = Ref(genUID(), rEltType) - StreamJoin(lPart, rPart, lKey, rKey, lEltRef.name, rEltRef.name, joiner(lEltRef, rEltRef), joinType, - requiresMemoryManagement = true, rightKeyIsDistinct = rightKeyIsDistinct) + StreamJoin( + lPart, + rPart, + lKey, + rKey, + lEltRef.name, + rEltRef.name, + joiner(lEltRef, rEltRef), + joinType, + requiresMemoryManagement = true, + rightKeyIsDistinct = rightKeyIsDistinct, + ) } val newKey = kType.fieldNames ++ right.kType.fieldNames.drop(joinKey) @@ -490,14 +610,16 @@ class TableStage( right: TableStage, joinKey: Int, globalJoiner: (IR, IR) => IR, - joiner: (IR, IR) => IR + joiner: (IR, IR) => IR, ): TableStage = { require(joinKey <= kType.size) require(joinKey <= right.kType.size) - val leftKeyToRightKeyMap = (kType.fieldNames.take(joinKey), 
right.kType.fieldNames.take(joinKey)).zipped.toMap + val leftKeyToRightKeyMap = + (kType.fieldNames.take(joinKey), right.kType.fieldNames.take(joinKey)).zipped.toMap val newRightPartitioner = partitioner.coarsen(joinKey).rename(leftKeyToRightKeyMap) - val repartitionedRight = right.repartitionNoShuffle(ec, newRightPartitioner, allowDuplication = true) + val repartitionedRight = + right.repartitionNoShuffle(ec, newRightPartitioner, allowDuplication = true) zipPartitions(repartitionedRight, globalJoiner, joiner) } @@ -511,7 +633,7 @@ class TableStage( right: TableStage, rightRowRType: RStruct, globalJoiner: (IR, IR) => IR, - joiner: (IR, IR) => IR + joiner: (IR, IR) => IR, ): TableStage = { require(right.kType.size == 1) val rightKeyType = right.kType.fields.head.typ @@ -523,120 +645,171 @@ class TableStage( val rightWithPartNums = right.mapPartition(None) { partStream => flatMapIR(partStream) { row => val interval = bindIR(GetField(row, right.key.head)) { interval => - invoke("Interval", TInterval(TTuple(kType.typeAfterSelect(Array(0)), TInt32)), - MakeTuple.ordered(FastSeq(MakeStruct(FastSeq(kType.fieldNames.head -> invoke("start", kType.types.head, interval))), I32(1))), - MakeTuple.ordered(FastSeq(MakeStruct(FastSeq(kType.fieldNames.head -> invoke("end", kType.types.head, interval))), I32(1))), + invoke( + "Interval", + TInterval(TTuple(kType.typeAfterSelect(Array(0)), TInt32)), + MakeTuple.ordered(FastSeq( + MakeStruct(FastSeq(kType.fieldNames.head -> invoke( + "start", + kType.types.head, + interval, + ))), + I32(1), + )), + MakeTuple.ordered(FastSeq( + MakeStruct(FastSeq(kType.fieldNames.head -> invoke( + "end", + kType.types.head, + interval, + ))), + I32(1), + )), invoke("includesStart", TBoolean, interval), - invoke("includesEnd", TBoolean, interval) + invoke("includesEnd", TBoolean, interval), ) } - bindIR(invoke("partitionerFindIntervalRange", TTuple(TInt32, TInt32), irPartitioner, interval)) { range => - val rangeStream = StreamRange(GetTupleElement(range, 0), GetTupleElement(range, 1), I32(1), requiresMemoryManagementPerElement = true) - mapIR(rangeStream) { partNum => - InsertFields(row, FastSeq("__partNum" -> partNum)) - } + bindIR(invoke( + "partitionerFindIntervalRange", + TTuple(TInt32, TInt32), + irPartitioner, + interval, + )) { range => + val rangeStream = StreamRange( + GetTupleElement(range, 0), + GetTupleElement(range, 1), + I32(1), + requiresMemoryManagementPerElement = true, + ) + mapIR(rangeStream)(partNum => InsertFields(row, FastSeq("__partNum" -> partNum))) } } } - val rightRowRTypeWithPartNum = IndexedSeq("__partNum" -> TypeWithRequiredness(TInt32)) ++ rightRowRType.fields.map(rField => rField.name -> rField.typ) + val rightRowRTypeWithPartNum = + IndexedSeq("__partNum" -> TypeWithRequiredness(TInt32)) ++ rightRowRType.fields.map(rField => + rField.name -> rField.typ + ) val rightTableRType = RTable(rightRowRTypeWithPartNum, FastSeq(), right.key) - val sortedReader = ctx.backend.lowerDistributedSort(ctx, + val sortedReader = ctx.backend.lowerDistributedSort( + ctx, rightWithPartNums, SortField("__partNum", Ascending) +: right.key.map(k => SortField(k, Ascending)), - rightTableRType) + rightTableRType, + ) val sorted = sortedReader.lower(ctx, sortedReader.fullType) assert(sorted.kType.fieldNames.sameElements("__partNum" +: right.key)) val newRightPartitioner = new RVDPartitioner( ctx.stateManager, Some(1), TStruct.concat(TStruct("__partNum" -> TInt32), right.kType), - Array.tabulate[Interval](partitioner.numPartitions)(i => Interval(Row(i), Row(i), true, true)) 
- ) + Array.tabulate[Interval](partitioner.numPartitions)(i => Interval(Row(i), Row(i), true, true)), + ) val repartitioned = sorted.repartitionNoShuffle(ctx, newRightPartitioner) - .changePartitionerNoRepartition(RVDPartitioner.unkeyed(ctx.stateManager, newRightPartitioner.numPartitions)) + .changePartitionerNoRepartition(RVDPartitioner.unkeyed( + ctx.stateManager, + newRightPartitioner.numPartitions, + )) .mapPartition(None) { part => - mapIR(part) { row => - SelectFields(row, right.rowType.fieldNames) - } + mapIR(part)(row => SelectFields(row, right.rowType.fieldNames)) } zipPartitions(repartitioned, globalJoiner, joiner) } } object LowerTableIR { - def apply(ir: IR, typesToLower: DArrayLowering.Type, ctx: ExecuteContext, analyses: LoweringAnalyses): IR = { - def lower(tir: TableIR): TableStage = { + def apply( + ir: IR, + typesToLower: DArrayLowering.Type, + ctx: ExecuteContext, + analyses: LoweringAnalyses, + ): IR = { + def lower(tir: TableIR): TableStage = this.applyTable(tir, typesToLower, ctx, analyses) - } val lowered = ir match { case TableCount(tableIR) => val stage = lower(tableIR) - invoke("sum", TInt64, - stage.countPerPartition()) + invoke("sum", TInt64, stage.countPerPartition()) case TableToValueApply(child, ForceCountTable()) => val stage = lower(child) - invoke("sum", TInt64, - stage.mapCollect("table_force_count")(rows => foldIR(mapIR(rows)(row => Consume(row)), 0L)(_ + _))) + invoke( + "sum", + TInt64, + stage.mapCollect("table_force_count")(rows => + foldIR(mapIR(rows)(row => Consume(row)), 0L)(_ + _) + ), + ) case TableToValueApply(child, TableCalculateNewPartitions(nPartitions)) => val stage = lower(child) val sampleSize = math.min((nPartitions * 20 + 256), 1000000) val samplesPerPartition = sampleSize / math.max(1, stage.numPartitions) val keyType = child.typ.keyType - val samplekey = AggSignature(ReservoirSample(), - FastSeq(TInt32), - FastSeq(keyType)) + val samplekey = AggSignature(ReservoirSample(), FastSeq(TInt32), FastSeq(keyType)) - val minkey = AggSignature(TakeBy(), - FastSeq(TInt32), - FastSeq(keyType, keyType)) - - val maxkey = AggSignature(TakeBy(Descending), - FastSeq(TInt32), - FastSeq(keyType, keyType)) + val minkey = AggSignature(TakeBy(), FastSeq(TInt32), FastSeq(keyType, keyType)) + val maxkey = AggSignature(TakeBy(Descending), FastSeq(TInt32), FastSeq(keyType, keyType)) bindIR(flatten(stage.mapCollect("table_calculate_new_partitions") { rows => - streamAggIR(mapIR(rows) { row => SelectFields(row, keyType.fieldNames)}) { elt => + streamAggIR(mapIR(rows)(row => SelectFields(row, keyType.fieldNames))) { elt => ToArray(flatMapIR(ToStream( MakeArray( ApplyAggOp( FastSeq(I32(samplesPerPartition)), FastSeq(elt), - samplekey), + samplekey, + ), ApplyAggOp( FastSeq(I32(1)), FastSeq(elt, elt), - minkey), + minkey, + ), ApplyAggOp( FastSeq(I32(1)), FastSeq(elt, elt), - maxkey) - ) - )) { inner => ToStream(inner) }) + maxkey, + ), + ) + ))(inner => ToStream(inner))) } })) { partData => - - val sorted = sortIR(partData) { (l, r) => ApplyComparisonOp(LT(keyType, keyType), l, r) } - bindIR(ToArray(flatMapIR(StreamGroupByKey(ToStream(sorted), keyType.fieldNames, missingEqual = true)) { groupRef => - StreamTake(groupRef, 1) - })) { boundsArray => - + val sorted = sortIR(partData)((l, r) => ApplyComparisonOp(LT(keyType, keyType), l, r)) + bindIR(ToArray(flatMapIR(StreamGroupByKey( + ToStream(sorted), + keyType.fieldNames, + missingEqual = true, + ))(groupRef => StreamTake(groupRef, 1)))) { boundsArray => bindIR(ArrayLen(boundsArray)) { nBounds => 
bindIR(minIR(nBounds, nPartitions)) { nParts => - If(nParts.ceq(0), + If( + nParts.ceq(0), MakeArray(FastSeq(), TArray(TInterval(keyType))), bindIR((nBounds + (nParts - 1)) floorDiv nParts) { stepSize => ToArray(mapIR(StreamRange(0, nBounds, stepSize)) { i => - If((i + stepSize) < (nBounds - 1), - invoke("Interval", TInterval(keyType), ArrayRef(boundsArray, i), ArrayRef(boundsArray, i + stepSize), True(), False()), - invoke("Interval", TInterval(keyType), ArrayRef(boundsArray, i), ArrayRef(boundsArray, nBounds - 1), True(), True()) - )}) - } + If( + (i + stepSize) < (nBounds - 1), + invoke( + "Interval", + TInterval(keyType), + ArrayRef(boundsArray, i), + ArrayRef(boundsArray, i + stepSize), + True(), + False(), + ), + invoke( + "Interval", + TInterval(keyType), + ArrayRef(boundsArray, i), + ArrayRef(boundsArray, nBounds - 1), + True(), + True(), + ), + ) + }) + }, ) } } @@ -650,52 +823,65 @@ object LowerTableIR { lower(child).collectWithGlobals("table_collect") case TableAggregate(child, query) => - val resultUID = genUID() - val aggs = agg.Extract(query, resultUID, analyses.requirednessAnalysis, false) + val aggs = agg.Extract(query, analyses.requirednessAnalysis, false) def results: IR = ResultOp.makeTuple(aggs.aggs) val lc = lower(child) - val initState = Let(FastSeq("global" -> lc.globals), + val initState = Let( + FastSeq("global" -> lc.globals), RunAgg( aggs.init, - MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), - aggs.states - )) + MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => + AggStateValue(i, sig.state) + }), + aggs.states, + ), + ) val initStateRef = Ref(genUID(), initState.typ) val lcWithInitBinding = lc.copy( letBindings = lc.letBindings ++ FastSeq((initStateRef.name, initState)), - broadcastVals = lc.broadcastVals ++ FastSeq((initStateRef.name, initStateRef))) + broadcastVals = lc.broadcastVals ++ FastSeq((initStateRef.name, initStateRef)), + ) val initFromSerializedStates = Begin(aggs.aggs.zipWithIndex.map { case (agg, i) => - InitFromSerializedValue(i, GetTupleElement(initStateRef, i), agg.state )}) + InitFromSerializedValue(i, GetTupleElement(initStateRef, i), agg.state) + }) val branchFactor = HailContext.get.branchingFactor val useTreeAggregate = aggs.shouldTreeAggregate && branchFactor < lc.numPartitions val isCommutative = aggs.isCommutative - log.info(s"Aggregate: useTreeAggregate=${ useTreeAggregate }") - log.info(s"Aggregate: commutative=${ isCommutative }") + log.info(s"Aggregate: useTreeAggregate=$useTreeAggregate") + log.info(s"Aggregate: commutative=$isCommutative") if (useTreeAggregate) { val tmpDir = ctx.createTmpPath("aggregate_intermediates/") - val codecSpec = TypedCodecSpec(PCanonicalTuple(true, aggs.aggs.map(_ => PCanonicalBinary(true)): _*), BufferSpec.wireSpec) + val codecSpec = TypedCodecSpec( + PCanonicalTuple(true, aggs.aggs.map(_ => PCanonicalBinary(true)): _*), + BufferSpec.wireSpec, + ) val writer = ETypeValueWriter(codecSpec) val reader = ETypeValueReader(codecSpec) lcWithInitBinding.mapCollectWithGlobals("table_aggregate")({ part: IR => - Let(FastSeq("global" -> lc.globals), + Let( + FastSeq("global" -> lc.globals), RunAgg( Begin(FastSeq( initFromSerializedStates, - StreamFor(part, - "row", - aggs.seqPerElt - ) + StreamFor(part, "row", aggs.seqPerElt), )), - WriteValue(MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), Str(tmpDir) + UUID4(), writer), - aggs.states - )) + WriteValue( + MakeTuple.ordered(aggs.aggs.zipWithIndex.map { 
case (sig, i) => + AggStateValue(i, sig.state) + }), + Str(tmpDir) + UUID4(), + writer, + ), + aggs.states, + ), + ) }) { case (collected, globals) => val treeAggFunction = genUID() val currentAggStates = Ref(genUID(), TArray(TString)) @@ -708,22 +894,37 @@ object LowerTableIR { if (useInitStates) { initFromSerializedStates } else { - bindIR(ReadValue(ArrayRef(partArrayRef, 0), reader, reader.spec.encodedVirtualType)) { serializedTuple => + bindIR(ReadValue( + ArrayRef(partArrayRef, 0), + reader, + reader.spec.encodedVirtualType, + )) { serializedTuple => Begin( aggs.aggs.zipWithIndex.map { case (sig, i) => InitFromSerializedValue(i, GetTupleElement(serializedTuple, i), sig.state) - }) + } + ) } }, - forIR(StreamRange(if (useInitStates) 0 else 1, ArrayLen(partArrayRef), 1, requiresMemoryManagementPerElement = true)) { fileIdx => - - bindIR(ReadValue(ArrayRef(partArrayRef, fileIdx), reader, reader.spec.encodedVirtualType)) { serializedTuple => + forIR(StreamRange( + if (useInitStates) 0 else 1, + ArrayLen(partArrayRef), + 1, + requiresMemoryManagementPerElement = true, + )) { fileIdx => + bindIR(ReadValue( + ArrayRef(partArrayRef, fileIdx), + reader, + reader.spec.encodedVirtualType, + )) { serializedTuple => Begin( aggs.aggs.zipWithIndex.map { case (sig, i) => CombOpValue(i, GetTupleElement(serializedTuple, i), sig) - }) + } + ) } - })) + }, + )) } val loopBody = If( @@ -733,58 +934,79 @@ object LowerTableIR { treeAggFunction, FastSeq( CollectDistributedArray( - mapIR(StreamGrouped(ToStream(currentAggStates), I32(branchFactor)))(x => ToArray(x)), + mapIR(StreamGrouped(ToStream(currentAggStates), I32(branchFactor)))(x => + ToArray(x) + ), MakeStruct(FastSeq()), distAggStatesRef.name, genUID(), RunAgg( combineGroup(distAggStatesRef, false), - WriteValue(MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), Str(tmpDir) + UUID4(), writer), - aggs.states), - strConcat(Str("iteration="), invoke("str", TString, iterNumber), Str(", n_states="), invoke("str", TString, ArrayLen(currentAggStates))), - "table_tree_aggregate"), - iterNumber + 1), - currentAggStates.typ)) + WriteValue( + MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => + AggStateValue(i, sig.state) + }), + Str(tmpDir) + UUID4(), + writer, + ), + aggs.states, + ), + strConcat( + Str("iteration="), + invoke("str", TString, iterNumber), + Str(", n_states="), + invoke("str", TString, ArrayLen(currentAggStates)), + ), + "table_tree_aggregate", + ), + iterNumber + 1, + ), + currentAggStates.typ, + ), + ) bindIR(TailLoop( treeAggFunction, FastSeq[(String, IR)](currentAggStates.name -> collected, iterNumber.name -> I32(0)), loopBody.typ, - loopBody + loopBody, )) { finalParts => RunAgg( combineGroup(finalParts, true), - Let(FastSeq("global" -> globals, resultUID -> results), aggs.postAggIR), - aggs.states + Let(FastSeq("global" -> globals, aggs.resultRef.name -> results), aggs.postAggIR), + aggs.states, ) } } - } - else { + } else { lcWithInitBinding.mapCollectWithGlobals("table_aggregate_singlestage")({ part: IR => - Let(FastSeq("global" -> lc.globals), + Let( + FastSeq("global" -> lc.globals), RunAgg( Begin(FastSeq( initFromSerializedStates, - StreamFor(part, - "row", - aggs.seqPerElt - ) + StreamFor(part, "row", aggs.seqPerElt), )), - MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), - aggs.states - )) + MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => + AggStateValue(i, sig.state) + }), + aggs.states, + ), + ) }) { case 
(collected, globals) => - Let(FastSeq("global" -> globals), + Let( + FastSeq("global" -> globals), RunAgg( Begin(FastSeq( initFromSerializedStates, forIR(ToStream(collected, requiresMemoryManagementPerElement = true)) { state => - Begin(aggs.aggs.zipWithIndex.map { case (sig, i) => CombOpValue(i, GetTupleElement(state, i), sig) }) - } + Begin(aggs.aggs.zipWithIndex.map { case (sig, i) => + CombOpValue(i, GetTupleElement(state, i), sig) + }) + }, )), - Let(FastSeq(resultUID -> results), aggs.postAggIR), - aggs.states - ) + Let(FastSeq(aggs.resultRef.name -> results), aggs.postAggIR), + aggs.states, + ), ) } } @@ -793,28 +1015,44 @@ object LowerTableIR { lower(child).getNumPartitions() case TableWrite(child, writer) => - writer.lower(ctx, lower(child), tcoerce[RTable](analyses.requirednessAnalysis.lookup(child))) + writer.lower( + ctx, + lower(child), + tcoerce[RTable](analyses.requirednessAnalysis.lookup(child)), + ) case TableMultiWrite(children, writer) => - writer.lower(ctx, children.map(child => (lower(child), tcoerce[RTable](analyses.requirednessAnalysis.lookup(child))))) + writer.lower( + ctx, + children.map(child => + (lower(child), tcoerce[RTable](analyses.requirednessAnalysis.lookup(child))) + ), + ) case node if node.children.exists(_.isInstanceOf[TableIR]) => - throw new LowererUnsupportedOperation(s"IR nodes with TableIR children must be defined explicitly: \n${ Pretty(ctx, node) }") + throw new LowererUnsupportedOperation( + s"IR nodes with TableIR children must be defined explicitly: \n${Pretty(ctx, node)}" + ) } lowered } - def applyTable(tir: TableIR, typesToLower: DArrayLowering.Type, ctx: ExecuteContext, analyses: LoweringAnalyses): TableStage = { - def lowerIR(ir: IR): IR = { + def applyTable( + tir: TableIR, + typesToLower: DArrayLowering.Type, + ctx: ExecuteContext, + analyses: LoweringAnalyses, + ): TableStage = { + def lowerIR(ir: IR): IR = LowerToCDA.lower(ir, typesToLower, ctx, analyses) - } - def lower(tir: TableIR): TableStage = { + def lower(tir: TableIR): TableStage = this.applyTable(tir, typesToLower, ctx, analyses) - } if (typesToLower == DArrayLowering.BMOnly) - throw new LowererUnsupportedOperation("found TableIR in lowering; lowering only BlockMatrixIRs.") + throw new LowererUnsupportedOperation( + "found TableIR in lowering; lowering only BlockMatrixIRs." 
+ ) val typ: TableType = tir.typ @@ -829,15 +1067,20 @@ object LowerTableIR { val loweredRowsAndGlobalRef = Ref(genUID(), loweredRowsAndGlobal.typ) val context = bindIR(ArrayLen(GetField(loweredRowsAndGlobalRef, "rows"))) { numRowsRef => - bindIR(invoke("extend", TArray(TInt32), ToArray(mapIR(rangeIR(nPartitionsAdj)) { partIdx => - (partIdx * numRowsRef) floorDiv nPartitionsAdj - }), - MakeArray((numRowsRef)))) { indicesArray => + bindIR(invoke( + "extend", + TArray(TInt32), + ToArray(mapIR(rangeIR(nPartitionsAdj)) { partIdx => + (partIdx * numRowsRef) floorDiv nPartitionsAdj + }), + MakeArray((numRowsRef)), + )) { indicesArray => bindIR(GetField(loweredRowsAndGlobalRef, "rows")) { rows => mapIR(rangeIR(nPartitionsAdj)) { partIdx => - ToArray(mapIR(rangeIR(ArrayRef(indicesArray, partIdx), ArrayRef(indicesArray, partIdx + 1))) { rowIdx => - ArrayRef(rows, rowIdx) - }) + ToArray(mapIR(rangeIR( + ArrayRef(indicesArray, partIdx), + ArrayRef(indicesArray, partIdx + 1), + ))(rowIdx => ArrayRef(rows, rowIdx))) } } } @@ -845,14 +1088,17 @@ object LowerTableIR { val globalsRef = Ref(genUID(), typ.globalType) TableStage( - FastSeq(loweredRowsAndGlobalRef.name -> loweredRowsAndGlobal, - globalsRef.name -> GetField(loweredRowsAndGlobalRef, "global")), + FastSeq( + loweredRowsAndGlobalRef.name -> loweredRowsAndGlobal, + globalsRef.name -> GetField(loweredRowsAndGlobalRef, "global"), + ), FastSeq(globalsRef.name -> globalsRef), globalsRef, RVDPartitioner.unkeyed(ctx.stateManager, nPartitionsAdj), TableStageDependency.none, context, - ctxRef => ToStream(ctxRef, true)) + ctxRef => ToStream(ctxRef, true), + ) case TableGen(contexts, globals, cname, gname, body, partitioner, errorId) => val loweredGlobals = lowerIR(globals) @@ -864,42 +1110,57 @@ object LowerTableIR { bindIR(ToArray(contexts)) { ref => bindIR(ArrayLen(ref)) { len => // Assert at runtime that the number of contexts matches the number of partitions - val ctxs = ToStream(If(len ceq partitioner.numPartitions, ref, { - val dieMsg = strConcat( - s"TableGen: partitioner contains ${partitioner.numPartitions} partitions,", - " got ", len, " contexts." - ) - Die(dieMsg, ref.typ, errorId) - })) + val ctxs = ToStream(If( + len ceq partitioner.numPartitions, + ref, { + val dieMsg = strConcat( + s"TableGen: partitioner contains ${partitioner.numPartitions} partitions,", + " got ", + len, + " contexts.", + ) + Die(dieMsg, ref.typ, errorId) + }, + )) // [FOR KEYED TABLES ONLY] // AFAIK, there's no way to guarantee that the rows generated in the // body conform to their partition's range bounds at compile time so // assert this at runtime in the body before it wreaks havoc upon the world. 
val partitionIdx = StreamRange(I32(0), I32(partitioner.numPartitions), I32(1)) - val bounds = Literal(TArray(TInterval(partitioner.kType)), partitioner.rangeBounds.toIndexedSeq) - zipIR(FastSeq(partitionIdx, ToStream(bounds), ctxs), AssertSameLength, errorId)(MakeTuple.ordered) + val bounds = Literal( + TArray(TInterval(partitioner.kType)), + partitioner.rangeBounds.toIndexedSeq, + ) + zipIR(FastSeq(partitionIdx, ToStream(bounds), ctxs), AssertSameLength, errorId)( + MakeTuple.ordered + ) } } }, - body = in => lowerIR { - val rows = Let(FastSeq(cname -> GetTupleElement(in, 2), gname -> loweredGlobals), body) - if (partitioner.kType.fields.isEmpty) rows - else bindIR(GetTupleElement(in, 1)) { interval => - mapIR(rows) { row => - val key = SelectFields(row, partitioner.kType.fieldNames) - If(invoke("contains", TBoolean, interval, key), row, { - val idx = GetTupleElement(in, 0) - val msg = strConcat( - "TableGen: Unexpected key in partition ", idx, - "\n\tRange bounds for partition ", idx, ": ", interval, - "\n\tInvalid key: ", key + body = in => + lowerIR { + val rows = + Let(FastSeq(cname -> GetTupleElement(in, 2), gname -> loweredGlobals), body) + if (partitioner.kType.fields.isEmpty) rows + else bindIR(GetTupleElement(in, 1)) { interval => + mapIR(rows) { row => + val key = SelectFields(row, partitioner.kType.fieldNames) + If( + invoke("contains", TBoolean, interval, key), + row, { + val idx = GetTupleElement(in, 0) + val msg = strConcat( + "TableGen: Unexpected key in partition ", idx, + "\n\tRange bounds for partition ", idx, ": ", interval, + "\n\tInvalid key: ", key, + ) + Die(msg, row.typ, errorId) + }, ) - Die(msg, row.typ, errorId) - }) + } } - } - } + }, ) case TableRange(n, nPartitions) => @@ -913,39 +1174,61 @@ object LowerTableIR { TableStage( MakeStruct(FastSeq()), - new RVDPartitioner(ctx.stateManager, Array("idx"), tir.typ.rowType, ranges.map { - case (start, end) => Interval(Row(start), Row(end), includesStart = true, includesEnd = false) - }), + new RVDPartitioner( + ctx.stateManager, + Array("idx"), + tir.typ.rowType, + ranges.map { + case (start, end) => + Interval(Row(start), Row(end), includesStart = true, includesEnd = false) + }, + ), TableStageDependency.none, ToStream(Literal(TArray(contextType), ranges.map(Row.fromTuple).toFastSeq)), - (ctxRef: Ref) => mapIR(StreamRange(GetField(ctxRef, "start"), GetField(ctxRef, "end"), I32(1), true)) { i => - MakeStruct(FastSeq("idx" -> i)) - }) + (ctxRef: Ref) => + mapIR(StreamRange(GetField(ctxRef, "start"), GetField(ctxRef, "end"), I32(1), true)) { + i => MakeStruct(FastSeq("idx" -> i)) + }, + ) case TableMapGlobals(child, newGlobals) => lower(child).mapGlobals(old => Let(FastSeq("global" -> old), newGlobals)) case TableAggregateByKey(child, expr) => val loweredChild = lower(child) + val repartitioned = loweredChild.repartitionNoShuffle( + ctx, + loweredChild.partitioner.coarsen(child.typ.key.length).strictify(), + ) - loweredChild.repartitionNoShuffle(ctx, loweredChild.partitioner.coarsen(child.typ.key.length).strictify()) - .mapPartition(Some(child.typ.key)) { partition => - - Let(FastSeq("global" -> loweredChild.globals), - mapIR(StreamGroupByKey(partition, child.typ.key, missingEqual = true)) { groupRef => - StreamAgg( - groupRef, - "row", - bindIRs(ArrayRef(ApplyAggOp(FastSeq(I32(1)), FastSeq(SelectFields(Ref("row", child.typ.rowType), child.typ.key)), - AggSignature(Take(), FastSeq(TInt32), FastSeq(child.typ.keyType))), I32(0)), // FIXME: would prefer a First() agg op - expr) { case Seq(key, value) => - 
MakeStruct(child.typ.key.map(k => (k, GetField(key, k))) ++ expr.typ.asInstanceOf[TStruct].fieldNames.map { f => - (f, GetField(value, f)) - }) - } - ) - }) - } + repartitioned.mapPartition(Some(child.typ.key)) { partition => + Let( + FastSeq("global" -> repartitioned.globals), + mapIR(StreamGroupByKey(partition, child.typ.key, missingEqual = true)) { groupRef => + StreamAgg( + groupRef, + "row", + bindIRs( + ArrayRef( + ApplyAggOp( + FastSeq(I32(1)), + FastSeq(SelectFields(Ref("row", child.typ.rowType), child.typ.key)), + AggSignature(Take(), FastSeq(TInt32), FastSeq(child.typ.keyType)), + ), + I32(0), + ), // FIXME: would prefer a First() agg op + expr, + ) { case Seq(key, value) => + MakeStruct(child.typ.key.map(k => + (k, GetField(key, k)) + ) ++ expr.typ.asInstanceOf[TStruct].fieldNames.map { f => + (f, GetField(value, f)) + }) + }, + ) + }, + ) + } case TableDistinct(child) => val loweredChild = lower(child) @@ -953,18 +1236,20 @@ object LowerTableIR { if (analyses.distinctKeyedAnalysis.contains(child)) loweredChild else - loweredChild.repartitionNoShuffle(ctx, loweredChild.partitioner.coarsen(child.typ.key.length).strictify()) + loweredChild.repartitionNoShuffle( + ctx, + loweredChild.partitioner.coarsen(child.typ.key.length).strictify(), + ) .mapPartition(None) { partition => - flatMapIR(StreamGroupByKey(partition, child.typ.key, missingEqual = true)) { groupRef => - StreamTake(groupRef, 1) + flatMapIR(StreamGroupByKey(partition, child.typ.key, missingEqual = true)) { + groupRef => StreamTake(groupRef, 1) } } case TableFilter(child, cond) => val loweredChild = lower(child) loweredChild.mapPartition(None) { rows => - Let(FastSeq("global" -> loweredChild.globals), - StreamFilter(rows, "row", cond)) + Let(FastSeq("global" -> loweredChild.globals), StreamFilter(rows, "row", cond)) } case TableFilterIntervals(child, intervals, keep) => @@ -974,7 +1259,11 @@ object LowerTableIR { val ord = PartitionBoundOrdering(ctx.stateManager, kt) val iord = ord.intervalEndpointOrdering - val filterPartitioner = new RVDPartitioner(ctx.stateManager, kt, Interval.union(intervals, ord.intervalEndpointOrdering)) + val filterPartitioner = new RVDPartitioner( + ctx.stateManager, + kt, + Interval.union(intervals, ord.intervalEndpointOrdering), + ) val boundsType = TArray(RVDPartitioner.intervalIRRepresentation(kt)) val filterIntervalsRef = Ref(genUID(), boundsType) val filterIntervals: IndexedSeq[Interval] = filterPartitioner.rangeBounds.map { i => @@ -982,11 +1271,19 @@ object LowerTableIR { } val (newRangeBounds, includedIndices, startAndEndInterval, f) = if (keep) { - val (newRangeBounds, includedIndices, startAndEndInterval) = part.rangeBounds.zipWithIndex.flatMap { case (interval, i) => - if (filterPartitioner.overlaps(interval)) { - Some((interval, i, (filterPartitioner.lowerBoundInterval(interval), filterPartitioner.upperBoundInterval(interval)))) - } else None - }.unzip3 + val (newRangeBounds, includedIndices, startAndEndInterval) = + part.rangeBounds.zipWithIndex.flatMap { case (interval, i) => + if (filterPartitioner.overlaps(interval)) { + Some(( + interval, + i, + ( + filterPartitioner.lowerBoundInterval(interval), + filterPartitioner.upperBoundInterval(interval), + ), + )) + } else None + }.unzip3 def f(partitionIntervals: IR, key: IR): IR = invoke("partitionerContains", TBoolean, partitionIntervals, key) @@ -994,15 +1291,22 @@ object LowerTableIR { (newRangeBounds, includedIndices, startAndEndInterval, f _) } else { // keep = False - val (newRangeBounds, includedIndices, startAndEndInterval) = 
part.rangeBounds.zipWithIndex.flatMap { case (interval, i) => - val lowerBound = filterPartitioner.lowerBoundInterval(interval) - val upperBound = filterPartitioner.upperBoundInterval(interval) - if ((lowerBound until upperBound).map(filterPartitioner.rangeBounds).exists { filterInterval => - iord.compareNonnull(filterInterval.left, interval.left) <= 0 && iord.compareNonnull(filterInterval.right, interval.right) >= 0 - }) - None - else Some((interval, i, (lowerBound, upperBound))) - }.unzip3 + val (newRangeBounds, includedIndices, startAndEndInterval) = + part.rangeBounds.zipWithIndex.flatMap { case (interval, i) => + val lowerBound = filterPartitioner.lowerBoundInterval(interval) + val upperBound = filterPartitioner.upperBoundInterval(interval) + if ( + (lowerBound until upperBound).map(filterPartitioner.rangeBounds).exists { + filterInterval => + iord.compareNonnull( + filterInterval.left, + interval.left, + ) <= 0 && iord.compareNonnull(filterInterval.right, interval.right) >= 0 + } + ) + None + else Some((interval, i, (lowerBound, upperBound))) + }.unzip3 def f(partitionIntervals: IR, key: IR): IR = !invoke("partitionerContains", TBoolean, partitionIntervals, key) @@ -1014,32 +1318,43 @@ object LowerTableIR { TableStage( letBindings = loweredChild.letBindings, - broadcastVals = loweredChild.broadcastVals ++ FastSeq((filterIntervalsRef.name, Literal(boundsType, filterIntervals))), + broadcastVals = loweredChild.broadcastVals ++ FastSeq(( + filterIntervalsRef.name, + Literal(boundsType, filterIntervals), + )), loweredChild.globals, newPart, loweredChild.dependency, contexts = bindIRs( ToArray(loweredChild.contexts), - Literal(TArray(TTuple(TInt32, TInt32)), startAndEndInterval.map(Row.fromTuple).toFastSeq) + Literal( + TArray(TTuple(TInt32, TInt32)), + startAndEndInterval.map(Row.fromTuple).toFastSeq, + ), ) { case Seq(prevContexts, bounds) => - zip2(ToStream(Literal(TArray(TInt32), includedIndices.toFastSeq)), ToStream(bounds), ArrayZipBehavior.AssumeSameLength) { (idx, bound) => + zip2( + ToStream(Literal(TArray(TInt32), includedIndices.toFastSeq)), + ToStream(bounds), + ArrayZipBehavior.AssumeSameLength, + ) { (idx, bound) => MakeStruct(FastSeq(("prevContext", ArrayRef(prevContexts, idx)), ("bounds", bound))) } }, { (part: Ref) => val oldPart = loweredChild.partition(GetField(part, "prevContext")) bindIR(GetField(part, "bounds")) { bounds => - bindIRs(GetTupleElement(bounds, 0), GetTupleElement(bounds, 1)) { case Seq(startIntervalIdx, endIntervalIdx) => - bindIR(ToArray(mapIR(rangeIR(startIntervalIdx, endIntervalIdx)) { i => ArrayRef(filterIntervalsRef, i) })) { partitionIntervals => - filterIR(oldPart) { row => - bindIR(SelectFields(row, child.typ.key)) { key => - f(partitionIntervals, key) + bindIRs(GetTupleElement(bounds, 0), GetTupleElement(bounds, 1)) { + case Seq(startIntervalIdx, endIntervalIdx) => + bindIR(ToArray(mapIR(rangeIR(startIntervalIdx, endIntervalIdx)) { i => + ArrayRef(filterIntervalsRef, i) + })) { partitionIntervals => + filterIR(oldPart) { row => + bindIR(SelectFields(row, child.typ.key))(key => f(partitionIntervals, key)) } } - } } } - } + }, ) case TableHead(child, targetNumRows) => @@ -1071,23 +1386,36 @@ object LowerTableIR { val loopBody = bindIR( loweredChild - .mapContexts(_ => StreamTake(ToStream(childContexts), howManyPartsToTryRef)) { ctx: IR => ctx } + .mapContexts(_ => StreamTake(ToStream(childContexts), howManyPartsToTryRef)) { + ctx: IR => ctx + } .mapCollect( "table_head_recursive_count", - strConcat(Str("iteration="), invoke("str", TString, 
iteration), Str(",nParts="), invoke("str", TString, howManyPartsToTryRef)) - )(streamLenOrMax) - ) { counts => - If( - (Cast(streamSumIR(ToStream(counts)), TInt64) >= targetNumRows) - || (ArrayLen(childContexts) <= ArrayLen(counts)), - counts, - Recur(partitionSizeArrayFunc, FastSeq(howManyPartsToTryRef * 4, iteration + 1), TArray(TInt32))) + strConcat( + Str("iteration="), + invoke("str", TString, iteration), + Str(",nParts="), + invoke("str", TString, howManyPartsToTryRef), + ), + )(streamLenOrMax) + ) { counts => + If( + (Cast(streamSumIR(ToStream(counts)), TInt64) >= targetNumRows) + || (ArrayLen(childContexts) <= ArrayLen(counts)), + counts, + Recur( + partitionSizeArrayFunc, + FastSeq(howManyPartsToTryRef * 4, iteration + 1), + TArray(TInt32), + ), + ) } TailLoop( partitionSizeArrayFunc, FastSeq(howManyPartsToTryRef.name -> howManyPartsToTry, iteration.name -> 0), loopBody.typ, - loopBody) + loopBody, + ) } } @@ -1096,24 +1424,34 @@ object LowerTableIR { val howManyPartsToKeep = genUID() val i = Ref(genUID(), TInt32) val numLeft = Ref(genUID(), TInt64) - def makeAnswer(howManyParts: IR, howManyFromLast: IR) = MakeTuple(FastSeq((0, howManyParts), (1, howManyFromLast))) + def makeAnswer(howManyParts: IR, howManyFromLast: IR) = + MakeTuple(FastSeq((0, howManyParts), (1, howManyFromLast))) val loopBody = If( - (i ceq numPartitions - 1) || ((numLeft - Cast(ArrayRef(partitionSizeArrayRef, i), TInt64)) <= 0L), + (i ceq numPartitions - 1) || ((numLeft - Cast( + ArrayRef(partitionSizeArrayRef, i), + TInt64, + )) <= 0L), makeAnswer(i + 1, numLeft), Recur( howManyPartsToKeep, FastSeq( i + 1, - numLeft - Cast(ArrayRef(partitionSizeArrayRef, i), TInt64)), - TTuple(TInt32, TInt64))) - If(numPartitions ceq 0, + numLeft - Cast(ArrayRef(partitionSizeArrayRef, i), TInt64), + ), + TTuple(TInt32, TInt64), + ), + ) + If( + numPartitions ceq 0, makeAnswer(0, 0L), TailLoop( howManyPartsToKeep, FastSeq(i.name -> 0, numLeft.name -> targetNumRows), loopBody.typ, - loopBody)) + loopBody, + ), + ) } val newCtxs = bindIR(ToArray(loweredChild.contexts)) { childContexts => @@ -1123,23 +1461,29 @@ object LowerTableIR { val numElementsFromLastPart = GetTupleElement(answerTupleRef, 1) val onlyNeededPartitions = StreamTake(ToStream(childContexts), numParts) val howManyFromEachPart = mapIR(rangeIR(numParts)) { idxRef => - If(idxRef ceq (numParts - 1), + If( + idxRef ceq (numParts - 1), Cast(numElementsFromLastPart, TInt32), - ArrayRef(partitionSizeArrayRef, idxRef)) + ArrayRef(partitionSizeArrayRef, idxRef), + ) } StreamZip( FastSeq(onlyNeededPartitions, howManyFromEachPart), FastSeq("part", "howMany"), - MakeStruct(FastSeq("numberToTake" -> Ref("howMany", TInt32), - "old" -> Ref("part", loweredChild.ctxType))), - ArrayZipBehavior.AssumeSameLength) + MakeStruct(FastSeq( + "numberToTake" -> Ref("howMany", TInt32), + "old" -> Ref("part", loweredChild.ctxType), + )), + ArrayZipBehavior.AssumeSameLength, + ) } } } val bindRelationLetsNewCtx = Let(loweredChild.letBindings, ToArray(newCtxs)) - val newCtxSeq = CompileAndEvaluate(ctx, bindRelationLetsNewCtx).asInstanceOf[IndexedSeq[Any]] + val newCtxSeq = + CompileAndEvaluate(ctx, bindRelationLetsNewCtx).asInstanceOf[IndexedSeq[Any]] val numNewParts = newCtxSeq.length val newIntervals = loweredChild.partitioner.rangeBounds.slice(0, numNewParts) val newPartitioner = loweredChild.partitioner.copy(rangeBounds = newIntervals) @@ -1151,9 +1495,12 @@ object LowerTableIR { newPartitioner, loweredChild.dependency, ToStream(Literal(bindRelationLetsNewCtx.typ, newCtxSeq)), - (ctxRef: 
Ref) => StreamTake( - loweredChild.partition(GetField(ctxRef, "old")), - GetField(ctxRef, "numberToTake"))) + (ctxRef: Ref) => + StreamTake( + loweredChild.partition(GetField(ctxRef, "old")), + GetField(ctxRef, "numberToTake"), + ), + ) case TableTail(child, targetNumRows) => val loweredChild = lower(child) @@ -1167,7 +1514,9 @@ object LowerTableIR { sumSoFar += partCounts(idx) idx -= 1 } - val finalParts = (idx + 1 until partCounts.length).map { partIdx => partCounts(partIdx).toInt }.toFastSeq + val finalParts = (idx + 1 until partCounts.length).map { partIdx => + partCounts(partIdx).toInt + }.toFastSeq Literal(TArray(TInt32), finalParts) case None => @@ -1179,49 +1528,87 @@ object LowerTableIR { val loopBody = bindIR( loweredChild - .mapContexts(_ => StreamDrop(ToStream(childContexts), maxIR(totalNumPartitions - howManyPartsToTryRef, 0))) { ctx: IR => ctx } + .mapContexts(_ => + StreamDrop( + ToStream(childContexts), + maxIR(totalNumPartitions - howManyPartsToTryRef, 0), + ) + ) { ctx: IR => ctx } .mapCollect( "table_tail_recursive_count", - strConcat(Str("iteration="), invoke("str", TString, iteration), Str(", nParts="), invoke("str", TString, howManyPartsToTryRef)) - )(StreamLen) - ) { counts => + strConcat( + Str("iteration="), + invoke("str", TString, iteration), + Str(", nParts="), + invoke("str", TString, howManyPartsToTryRef), + ), + )(StreamLen) + ) { counts => If( - (Cast(streamSumIR(ToStream(counts)), TInt64) >= targetNumRows) || (totalNumPartitions <= ArrayLen(counts)), + (Cast( + streamSumIR(ToStream(counts)), + TInt64, + ) >= targetNumRows) || (totalNumPartitions <= ArrayLen(counts)), counts, - Recur(partitionSizeArrayFunc, FastSeq(howManyPartsToTryRef * 4, iteration + 1), TArray(TInt32))) + Recur( + partitionSizeArrayFunc, + FastSeq(howManyPartsToTryRef * 4, iteration + 1), + TArray(TInt32), + ), + ) } TailLoop( partitionSizeArrayFunc, FastSeq(howManyPartsToTryRef.name -> howManyPartsToTry, iteration.name -> 0), loopBody.typ, - loopBody) + loopBody, + ) } } - // First element is how many partitions to keep from the right partitionSizeArrayRef, second is how many to keep from first kept element. + /* First element is how many partitions to keep from the right partitionSizeArrayRef, second + * is how many to keep from first kept element. 
*/ def answerTuple(partitionSizeArrayRef: Ref): IR = { bindIR(ArrayLen(partitionSizeArrayRef)) { numPartitions => val howManyPartsToDrop = genUID() val i = Ref(genUID(), TInt32) val nRowsToRight = Ref(genUID(), TInt64) - def makeAnswer(howManyParts: IR, howManyFromLast: IR) = MakeTuple.ordered(FastSeq(howManyParts, howManyFromLast)) + def makeAnswer(howManyParts: IR, howManyFromLast: IR) = + MakeTuple.ordered(FastSeq(howManyParts, howManyFromLast)) val loopBody = If( - (i ceq numPartitions) || ((nRowsToRight + Cast(ArrayRef(partitionSizeArrayRef, numPartitions - i), TInt64)) >= targetNumRows), - makeAnswer(i, maxIR(0L, Cast(ArrayRef(partitionSizeArrayRef, numPartitions - i), TInt64) - (I64(targetNumRows) - nRowsToRight)).toI), + (i ceq numPartitions) || ((nRowsToRight + Cast( + ArrayRef(partitionSizeArrayRef, numPartitions - i), + TInt64, + )) >= targetNumRows), + makeAnswer( + i, + maxIR( + 0L, + Cast(ArrayRef(partitionSizeArrayRef, numPartitions - i), TInt64) - (I64( + targetNumRows + ) - nRowsToRight), + ).toI, + ), Recur( howManyPartsToDrop, FastSeq( i + 1, - nRowsToRight + Cast(ArrayRef(partitionSizeArrayRef, numPartitions - i), TInt64)), - TTuple(TInt32, TInt32))) - If(numPartitions ceq 0, + nRowsToRight + Cast(ArrayRef(partitionSizeArrayRef, numPartitions - i), TInt64), + ), + TTuple(TInt32, TInt32), + ), + ) + If( + numPartitions ceq 0, makeAnswer(0, 0), TailLoop( howManyPartsToDrop, FastSeq(i.name -> 1, nRowsToRight.name -> 0L), loopBody.typ, - loopBody)) + loopBody, + ), + ) } } @@ -1235,7 +1622,8 @@ object LowerTableIR { mapIR(rangeIR(numPartsToKeepFromRight)) { idx => MakeStruct(FastSeq( "numberToDrop" -> If(idx ceq 0, nToDropFromFirst, 0), - "old" -> ArrayRef(childContexts, idx + startIdx))) + "old" -> ArrayRef(childContexts, idx + startIdx), + )) } } } @@ -1256,9 +1644,11 @@ object LowerTableIR { newPartitioner, loweredChild.dependency, ToStream(Literal(letBindNewCtx.typ, newCtxSeq)), - (ctxRef: Ref) => bindIR(GetField(ctxRef, "old")) { oldRef => - StreamDrop(loweredChild.partition(oldRef), GetField(ctxRef, "numberToDrop")) - }) + (ctxRef: Ref) => + bindIR(GetField(ctxRef, "old")) { oldRef => + StreamDrop(loweredChild.partition(oldRef), GetField(ctxRef, "numberToDrop")) + }, + ) case TableMapRows(child, newRow) => val lc = lower(child) @@ -1266,25 +1656,25 @@ object LowerTableIR { lc.mapPartition(Some(child.typ.key)) { rows => Let( FastSeq("global" -> lc.globals), - mapIR(rows) { row => - Let(FastSeq("row" -> row), newRow) - } + mapIR(rows)(row => Let(FastSeq("row" -> row), newRow)), ) } } else { - val resultUID = genUID() - val aggs = agg.Extract(newRow, resultUID, analyses.requirednessAnalysis, isScan = true) + val aggs = agg.Extract(newRow, analyses.requirednessAnalysis, isScan = true) val results: IR = ResultOp.makeTuple(aggs.aggs) val initState = RunAgg( Let(FastSeq("global" -> lc.globals), aggs.init), - MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), - aggs.states + MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => + AggStateValue(i, sig.state) + }), + aggs.states, ) val initStateRef = Ref(genUID(), initState.typ) val lcWithInitBinding = lc.copy( letBindings = lc.letBindings ++ FastSeq((initStateRef.name, initState)), - broadcastVals = lc.broadcastVals ++ FastSeq((initStateRef.name, initStateRef))) + broadcastVals = lc.broadcastVals ++ FastSeq((initStateRef.name, initStateRef)), + ) val initFromSerializedStates = Begin(aggs.aggs.zipWithIndex.map { case (agg, i) => InitFromSerializedValue(i, 
GetTupleElement(initStateRef, i), agg.state) @@ -1294,198 +1684,307 @@ object LowerTableIR { val (partitionPrefixSumValues, transformPrefixSum): (IR, IR => IR) = if (big) { val tmpDir = ctx.createTmpPath("aggregate_intermediates/") - val codecSpec = TypedCodecSpec(PCanonicalTuple(true, aggs.aggs.map(_ => PCanonicalBinary(true)): _*), BufferSpec.wireSpec) + val codecSpec = TypedCodecSpec( + PCanonicalTuple(true, aggs.aggs.map(_ => PCanonicalBinary(true)): _*), + BufferSpec.wireSpec, + ) val writer = ETypeValueWriter(codecSpec) val reader = ETypeValueReader(codecSpec) - val partitionPrefixSumFiles = lcWithInitBinding.mapCollectWithGlobals("table_scan_write_prefix_sums")({ part: IR => - Let(FastSeq("global" -> lcWithInitBinding.globals), - RunAgg( + val partitionPrefixSumFiles = + lcWithInitBinding.mapCollectWithGlobals("table_scan_write_prefix_sums")({ part: IR => + Let( + FastSeq("global" -> lcWithInitBinding.globals), + RunAgg( + Begin(FastSeq( + initFromSerializedStates, + StreamFor(part, "row", aggs.seqPerElt), + )), + WriteValue( + MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => + AggStateValue(i, sig.state) + }), + Str(tmpDir) + UUID4(), + writer, + ), + aggs.states, + ), + ) + // Collected is TArray of TString + }) { case (collected, _) => + def combineGroup(partArrayRef: IR): IR = { Begin(FastSeq( - initFromSerializedStates, - StreamFor(part, - "row", - aggs.seqPerElt - ) - )), - WriteValue(MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), Str(tmpDir) + UUID4(), writer), - aggs.states - )) - // Collected is TArray of TString - }) { case (collected, _) => - - def combineGroup(partArrayRef: IR): IR = { - Begin(FastSeq( - bindIR(ReadValue(ArrayRef(partArrayRef, 0), reader, reader.spec.encodedVirtualType)) { serializedTuple => - Begin( - aggs.aggs.zipWithIndex.map { case (sig, i) => - InitFromSerializedValue(i, GetTupleElement(serializedTuple, i), sig.state) - }) - }, - forIR(StreamRange(1, ArrayLen(partArrayRef), 1, requiresMemoryManagementPerElement = true)) { fileIdx => - - bindIR(ReadValue(ArrayRef(partArrayRef, fileIdx), reader, reader.spec.encodedVirtualType)) { serializedTuple => + bindIR(ReadValue( + ArrayRef(partArrayRef, 0), + reader, + reader.spec.encodedVirtualType, + )) { serializedTuple => Begin( aggs.aggs.zipWithIndex.map { case (sig, i) => - CombOpValue(i, GetTupleElement(serializedTuple, i), sig) - }) - } - })) - } - - // Return Array[Array[String]], length is log_b(num_partitions) - // The upward pass starts with partial aggregations from each partition, - // and aggregates these in a tree parameterized by the branching factor. - // The tree ends when the number of partial aggregations is less than or - // equal to the branching factor. - - // The upward pass returns the full tree of results as an array of arrays, - // where the first element is partial aggregations per partition of the - // input. 
- def upPass(): IR = { - val aggStack = Ref(genUID(), TArray(TArray(TString))) - val iteration = Ref(genUID(), TInt32) - val loopName = genUID() - - val loopBody = bindIR(ArrayRef(aggStack, (ArrayLen(aggStack) - 1))) { states => - bindIR(ArrayLen(states)) { statesLen => - If( - statesLen > branchFactor, - bindIR((statesLen + branchFactor - 1) floorDiv branchFactor) { nCombines => - val contexts = mapIR(rangeIR(nCombines)) { outerIdxRef => - sliceArrayIR(states, outerIdxRef * branchFactor, (outerIdxRef + 1) * branchFactor) - } - val cdaResult = cdaIR( - contexts, MakeStruct(FastSeq()), "table_scan_up_pass", - strConcat(Str("iteration="), invoke("str", TString, iteration), Str(", nStates="), invoke("str", TString, statesLen)) - ) { case (contexts, _) => - RunAgg( - combineGroup(contexts), - WriteValue(MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), Str(tmpDir) + UUID4(), writer), - aggs.states) + InitFromSerializedValue(i, GetTupleElement(serializedTuple, i), sig.state) } - Recur(loopName, IndexedSeq(invoke("extend", TArray(TArray(TString)), aggStack, MakeArray(cdaResult)), iteration + 1), TArray(TArray(TString))) - }, - aggStack) + ) + }, + forIR(StreamRange( + 1, + ArrayLen(partArrayRef), + 1, + requiresMemoryManagementPerElement = true, + )) { fileIdx => + bindIR(ReadValue( + ArrayRef(partArrayRef, fileIdx), + reader, + reader.spec.encodedVirtualType, + )) { serializedTuple => + Begin( + aggs.aggs.zipWithIndex.map { case (sig, i) => + CombOpValue(i, GetTupleElement(serializedTuple, i), sig) + } + ) + } + }, + )) + } + + // Return Array[Array[String]], length is log_b(num_partitions) + // The upward pass starts with partial aggregations from each partition, + // and aggregates these in a tree parameterized by the branching factor. + // The tree ends when the number of partial aggregations is less than or + // equal to the branching factor. + + // The upward pass returns the full tree of results as an array of arrays, + // where the first element is partial aggregations per partition of the + // input. 
+ def upPass(): IR = { + val aggStack = Ref(genUID(), TArray(TArray(TString))) + val iteration = Ref(genUID(), TInt32) + val loopName = genUID() + + val loopBody = bindIR(ArrayRef(aggStack, (ArrayLen(aggStack) - 1))) { states => + bindIR(ArrayLen(states)) { statesLen => + If( + statesLen > branchFactor, + bindIR((statesLen + branchFactor - 1) floorDiv branchFactor) { nCombines => + val contexts = mapIR(rangeIR(nCombines)) { outerIdxRef => + sliceArrayIR( + states, + outerIdxRef * branchFactor, + (outerIdxRef + 1) * branchFactor, + ) + } + val cdaResult = cdaIR( + contexts, + MakeStruct(FastSeq()), + "table_scan_up_pass", + strConcat( + Str("iteration="), + invoke("str", TString, iteration), + Str(", nStates="), + invoke("str", TString, statesLen), + ), + ) { case (contexts, _) => + RunAgg( + combineGroup(contexts), + WriteValue( + MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => + AggStateValue(i, sig.state) + }), + Str(tmpDir) + UUID4(), + writer, + ), + aggs.states, + ) + } + Recur( + loopName, + IndexedSeq( + invoke( + "extend", + TArray(TArray(TString)), + aggStack, + MakeArray(cdaResult), + ), + iteration + 1, + ), + TArray(TArray(TString)), + ) + }, + aggStack, + ) + } } + TailLoop( + loopName, + IndexedSeq((aggStack.name, MakeArray(collected)), (iteration.name, I32(0))), + loopBody.typ, + loopBody, + ) } - TailLoop( - loopName, - IndexedSeq((aggStack.name, MakeArray(collected)), (iteration.name, I32(0))), - loopBody.typ, - loopBody) - } - // The downward pass traverses the tree from root to leaves, computing partial scan - // sums as it goes. The two pieces of state transmitted between iterations are the - // level (an integer) referring to a position in the array `aggStack`, and `last`, - // the partial sums from the last iteration. The starting state for `last` is an - // array of a single empty aggregation state. - bindIR(upPass()) { aggStack => - val downPassLoopName = genUID() - val iteration = Ref(genUID(), TInt32) - - val level = Ref(genUID(), TInt32) - val last = Ref(genUID(), TArray(TString)) - - - bindIR(WriteValue(initState, Str(tmpDir) + UUID4(), writer)) { freshState => - val loopBody = If( - level < 0, - last, - bindIR(ArrayRef(aggStack, level)) { aggsArray => - val groups = mapIR(zipWithIndex(mapIR(StreamGrouped(ToStream(aggsArray), I32(branchFactor)))(x => ToArray(x)))) { eltAndIdx => - MakeStruct(FastSeq( - ("prev", ArrayRef(last, GetField(eltAndIdx, "idx"))), - ("partialSums", GetField(eltAndIdx, "elt")))) - } + // The downward pass traverses the tree from root to leaves, computing partial scan + // sums as it goes. The two pieces of state transmitted between iterations are the + // level (an integer) referring to a position in the array `aggStack`, and `last`, + // the partial sums from the last iteration. The starting state for `last` is an + // array of a single empty aggregation state. 
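// Editor's illustration (not Hail code): a minimal, self-contained sketch of the
// two-pass scan described in the comments above, using plain Ints in place of
// serialized aggregator states and a local Vector in place of distributed
// partitions. Names such as TreeScanSketch, upPass, and downPass are invented here.
object TreeScanSketch {
  // Exclusive prefix sums of per-partition partials, combined in a tree with the
  // given branching factor: the upward pass builds roughly log_b(n) levels of
  // grouped sums, and the downward pass pushes each group's left-of-group offset
  // back down to the leaves (the role of `last` in the comment above).
  def scanPartials(perPartition: Vector[Int], branchFactor: Int): Vector[Int] = {
    def upPass(levels: Vector[Vector[Int]]): Vector[Vector[Int]] = {
      val top = levels.last
      if (top.length <= branchFactor) levels
      else upPass(levels :+ top.grouped(branchFactor).map(_.sum).toVector)
    }
    def downPass(levels: Vector[Vector[Int]]): Vector[Int] =
      levels.foldRight(Vector(0)) { (level, prev) =>
        level.grouped(branchFactor).toVector.zip(prev).flatMap { case (group, offset) =>
          group.scanLeft(offset)(_ + _).init
        }
      }
    downPass(upPass(Vector(perPartition)))
  }
  // e.g. scanPartials(Vector(1, 2, 3, 4, 5), branchFactor = 2) == Vector(0, 1, 3, 6, 10)
}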
+ bindIR(upPass()) { aggStack => + val downPassLoopName = genUID() + val iteration = Ref(genUID(), TInt32) + + val level = Ref(genUID(), TInt32) + val last = Ref(genUID(), TArray(TString)) + + bindIR(WriteValue(initState, Str(tmpDir) + UUID4(), writer)) { freshState => + val loopBody = If( + level < 0, + last, + bindIR(ArrayRef(aggStack, level)) { aggsArray => + val groups = mapIR(zipWithIndex(mapIR(StreamGrouped( + ToStream(aggsArray), + I32(branchFactor), + ))(x => ToArray(x)))) { eltAndIdx => + MakeStruct(FastSeq( + ("prev", ArrayRef(last, GetField(eltAndIdx, "idx"))), + ("partialSums", GetField(eltAndIdx, "elt")), + )) + } - val results = cdaIR( - groups, MakeTuple.ordered(FastSeq()), "table_scan_down_pass", - strConcat(Str("iteration="), invoke("str", TString, iteration), Str(", level="), invoke("str", TString, level)) - ) { case (context, _) => - bindIR(GetField(context, "prev")) { prev => - val elt = Ref(genUID(), TString) - ToArray(RunAggScan( - ToStream(GetField(context, "partialSums"), requiresMemoryManagementPerElement = true), - elt.name, - bindIR(ReadValue(prev, reader, reader.spec.encodedVirtualType)) { serializedTuple => - Begin( - aggs.aggs.zipWithIndex.map { case (sig, i) => - InitFromSerializedValue(i, GetTupleElement(serializedTuple, i), sig.state) - }) - }, - bindIR(ReadValue(elt, reader, reader.spec.encodedVirtualType)) { serializedTuple => - Begin( - aggs.aggs.zipWithIndex.map { case (sig, i) => - CombOpValue(i, GetTupleElement(serializedTuple, i), sig) - }) - }, - WriteValue(MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), Str(tmpDir) + UUID4(), writer), - aggs.states)) + val results = cdaIR( + groups, + MakeTuple.ordered(FastSeq()), + "table_scan_down_pass", + strConcat( + Str("iteration="), + invoke("str", TString, iteration), + Str(", level="), + invoke("str", TString, level), + ), + ) { case (context, _) => + bindIR(GetField(context, "prev")) { prev => + val elt = Ref(genUID(), TString) + ToArray(RunAggScan( + ToStream( + GetField(context, "partialSums"), + requiresMemoryManagementPerElement = true, + ), + elt.name, + bindIR(ReadValue(prev, reader, reader.spec.encodedVirtualType)) { + serializedTuple => + Begin( + aggs.aggs.zipWithIndex.map { case (sig, i) => + InitFromSerializedValue( + i, + GetTupleElement(serializedTuple, i), + sig.state, + ) + } + ) + }, + bindIR(ReadValue(elt, reader, reader.spec.encodedVirtualType)) { + serializedTuple => + Begin( + aggs.aggs.zipWithIndex.map { case (sig, i) => + CombOpValue(i, GetTupleElement(serializedTuple, i), sig) + } + ) + }, + WriteValue( + MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => + AggStateValue(i, sig.state) + }), + Str(tmpDir) + UUID4(), + writer, + ), + aggs.states, + )) + } } - } - Recur( - downPassLoopName, - IndexedSeq( - level - 1, - ToArray(flatten(ToStream(results))), - iteration + 1), - TArray(TString)) - }) - TailLoop( - downPassLoopName, - IndexedSeq((level.name, ArrayLen(aggStack) - 1), (last.name, MakeArray(freshState)), (iteration.name, I32(0))), - loopBody.typ, - loopBody) + Recur( + downPassLoopName, + IndexedSeq( + level - 1, + ToArray(flatten(ToStream(results))), + iteration + 1, + ), + TArray(TString), + ) + }, + ) + TailLoop( + downPassLoopName, + IndexedSeq( + (level.name, ArrayLen(aggStack) - 1), + (last.name, MakeArray(freshState)), + (iteration.name, I32(0)), + ), + loopBody.typ, + loopBody, + ) + } } } - } - (partitionPrefixSumFiles, { (file: IR) => ReadValue(file, reader, reader.spec.encodedVirtualType) }) + ( + 
partitionPrefixSumFiles, + { (file: IR) => ReadValue(file, reader, reader.spec.encodedVirtualType) }, + ) } else { - val partitionAggs = lcWithInitBinding.mapCollectWithGlobals("table_scan_prefix_sums_singlestage")({ part: IR => - Let(FastSeq("global" -> lc.globals), - RunAgg( - Begin(FastSeq( - initFromSerializedStates, - StreamFor(part, - "row", - aggs.seqPerElt - ) - )), - MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), - aggs.states - )) - }) { case (collected, globals) => - Let(FastSeq("global" -> globals), - ToArray(StreamTake({ - val acc = Ref(genUID(), initStateRef.typ) - val value = Ref(genUID(), collected.typ.asInstanceOf[TArray].elementType) - StreamScan( - ToStream(collected, requiresMemoryManagementPerElement = true), - initStateRef, - acc.name, - value.name, + val partitionAggs = + lcWithInitBinding.mapCollectWithGlobals("table_scan_prefix_sums_singlestage")({ + part: IR => + Let( + FastSeq("global" -> lc.globals), RunAgg( Begin(FastSeq( - Begin(aggs.aggs.zipWithIndex.map { case (agg, i) => - InitFromSerializedValue(i, GetTupleElement(acc, i), agg.state) - }), - Begin(aggs.aggs.zipWithIndex.map { case (sig, i) => CombOpValue(i, GetTupleElement(value, i), sig) }))), - MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => AggStateValue(i, sig.state) }), - aggs.states - ) + initFromSerializedStates, + StreamFor(part, "row", aggs.seqPerElt), + )), + MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => + AggStateValue(i, sig.state) + }), + aggs.states, + ), ) - }, ArrayLen(collected)))) - } + }) { case (collected, globals) => + Let( + FastSeq("global" -> globals), + ToArray(StreamTake( + { + val acc = Ref(genUID(), initStateRef.typ) + val value = Ref(genUID(), collected.typ.asInstanceOf[TArray].elementType) + StreamScan( + ToStream(collected, requiresMemoryManagementPerElement = true), + initStateRef, + acc.name, + value.name, + RunAgg( + Begin(FastSeq( + Begin(aggs.aggs.zipWithIndex.map { case (agg, i) => + InitFromSerializedValue(i, GetTupleElement(acc, i), agg.state) + }), + Begin(aggs.aggs.zipWithIndex.map { case (sig, i) => + CombOpValue(i, GetTupleElement(value, i), sig) + }), + )), + MakeTuple.ordered(aggs.aggs.zipWithIndex.map { case (sig, i) => + AggStateValue(i, sig.state) + }), + aggs.states, + ), + ) + }, + ArrayLen(collected), + )), + ) + } (partitionAggs, identity[IR]) } val partitionPrefixSumsRef = Ref(genUID(), partitionPrefixSumValues.typ) val zipOldContextRef = Ref(genUID(), lc.contexts.typ.asInstanceOf[TStream].elementType) - val zipPartAggUID = Ref(genUID(), partitionPrefixSumValues.typ.asInstanceOf[TArray].elementType) + val zipPartAggUID = + Ref(genUID(), partitionPrefixSumValues.typ.asInstanceOf[TArray].elementType) TableStage.apply( - letBindings = lc.letBindings ++ FastSeq((partitionPrefixSumsRef.name, partitionPrefixSumValues)), + letBindings = + lc.letBindings ++ FastSeq((partitionPrefixSumsRef.name, partitionPrefixSumValues)), broadcastVals = lc.broadcastVals, partitioner = lc.partitioner, dependency = lc.dependency, @@ -1494,30 +1993,32 @@ object LowerTableIR { FastSeq(lc.contexts, ToStream(partitionPrefixSumsRef)), FastSeq(zipOldContextRef.name, zipPartAggUID.name), MakeStruct(FastSeq(("oldContext", zipOldContextRef), ("scanState", zipPartAggUID))), - ArrayZipBehavior.AssertSameLength + ArrayZipBehavior.AssertSameLength, ), partition = { (partitionRef: Ref) => - bindIRs(GetField(partitionRef, "oldContext"), GetField(partitionRef, "scanState")) { case Seq(oldContext, rawPrefixSum) => 
- bindIR(transformPrefixSum(rawPrefixSum)) { scanState => - Let(FastSeq("global" -> lc.globals), - RunAggScan( - lc.partition(oldContext), - "row", - Begin(aggs.aggs.zipWithIndex.map { case (agg, i) => - InitFromSerializedValue(i, GetTupleElement(scanState, i), agg.state) - }), - aggs.seqPerElt, - Let(FastSeq(resultUID -> results), aggs.postAggIR), - aggs.states + bindIRs(GetField(partitionRef, "oldContext"), GetField(partitionRef, "scanState")) { + case Seq(oldContext, rawPrefixSum) => + bindIR(transformPrefixSum(rawPrefixSum)) { scanState => + Let( + FastSeq("global" -> lc.globals), + RunAggScan( + lc.partition(oldContext), + "row", + Begin(aggs.aggs.zipWithIndex.map { case (agg, i) => + InitFromSerializedValue(i, GetTupleElement(scanState, i), agg.state) + }), + aggs.seqPerElt, + Let(FastSeq(aggs.resultRef.name -> results), aggs.postAggIR), + aggs.states, + ), ) - ) - } + } } - } + }, ) } - case t@TableKeyBy(child, newKey, isSorted: Boolean) => + case t @ TableKeyBy(child, newKey, _: Boolean) => require(t.definitelyDoesNotShuffle) val loweredChild = lower(child) @@ -1526,7 +2027,9 @@ object LowerTableIR { .takeWhile { case (l, r) => l == r } .length - loweredChild.changePartitionerNoRepartition(loweredChild.partitioner.coarsen(nPreservedFields)) + loweredChild.changePartitionerNoRepartition( + loweredChild.partitioner.coarsen(nPreservedFields) + ) .extendKeyPreservesPartitioning(ctx, newKey) case TableLeftJoinRightDistinct(left, right, root) => @@ -1547,55 +2050,78 @@ object LowerTableIR { val rootStruct = SelectFields(rightElementRef, typeOfRootStruct.fieldNames.toIndexedSeq) val joiningOp = InsertFields(leftElementRef, FastSeq(root -> rootStruct)) StreamJoinRightDistinct( - leftPart, rightPart, - left.typ.key.take(commonKeyLength), right.typ.key, - leftElementRef.name, rightElementRef.name, - joiningOp, "left") - }) + leftPart, + rightPart, + left.typ.key.take(commonKeyLength), + right.typ.key, + leftElementRef.name, + rightElementRef.name, + joiningOp, + "left", + ) + }, + ) case TableIntervalJoin(left, right, root, product) => - assert(!product) - val loweredLeft = lower(left) - val loweredRight = lower(right) - - def partitionJoiner(lPart: IR, rPart: IR): IR = { - val lEltType = lPart.typ.asInstanceOf[TStream].elementType.asInstanceOf[TStruct] - val rEltType = rPart.typ.asInstanceOf[TStream].elementType.asInstanceOf[TStruct] - - val lKey = left.typ.key - val rKey = right.typ.key - - val lEltRef = Ref(genUID(), lEltType) - val rEltRef = Ref(genUID(), rEltType) - - StreamJoinRightDistinct( - lPart, rPart, - lKey, rKey, - lEltRef.name, rEltRef.name, - InsertFields(lEltRef, FastSeq( - root -> SelectFields(rEltRef, right.typ.valueType.fieldNames))), - "left") - } - - loweredLeft.intervalAlignAndZipPartitions(ctx, - loweredRight, + lower(left).intervalAlignAndZipPartitions( + ctx, + lower(right), analyses.requirednessAnalysis.lookup(right).asInstanceOf[RTable].rowType, (lGlobals, _) => lGlobals, - partitionJoiner) + { (lstream, rstream) => + val lref = Ref(genUID(), left.typ.rowType) + if (product) { + val rref = Ref(genUID(), TArray(right.typ.rowType)) + StreamLeftIntervalJoin( + lstream, + rstream, + left.typ.key.head, + right.typ.keyType.fields(0).name, + lref.name, + rref.name, + InsertFields( + lref, + FastSeq( + root -> mapArray(rref)(SelectFields(_, right.typ.valueType.fieldNames)) + ), + ), + ) + } else { + val rref = Ref(genUID(), right.typ.rowType) + StreamJoinRightDistinct( + lstream, + rstream, + left.typ.key, + right.typ.key, + lref.name, + rref.name, + InsertFields( + lref, 
+ FastSeq(root -> SelectFields(rref, right.typ.valueType.fieldNames)), + ), + "left", + ) + } + }, + ) - case tj@TableJoin(left, right, joinType, joinKey) => + case tj @ TableJoin(left, right, _, _) => val loweredLeft = lower(left) val loweredRight = lower(right) LowerTableIRHelpers.lowerTableJoin(ctx, analyses, tj, loweredLeft, loweredRight) - case x@TableUnion(children) => + case x @ TableUnion(children) => val lowered = children.map(lower) val keyType = x.typ.keyType if (keyType.size == 0) { TableStage.concatenate(ctx, lowered) } else { - val newPartitioner = RVDPartitioner.generate(ctx.stateManager, keyType, lowered.flatMap(_.partitioner.rangeBounds)) + val newPartitioner = RVDPartitioner.generate( + ctx.stateManager, + keyType, + lowered.flatMap(_.partitioner.rangeBounds), + ) val repartitioned = lowered.map(_.repartitionNoShuffle(ctx, newPartitioner)) TableStage( @@ -1608,17 +2134,30 @@ object LowerTableIR { MakeTuple.ordered(ctxRefs) }, ctxRef => - StreamMultiMerge(repartitioned.indices.map(i => repartitioned(i).partition(GetTupleElement(ctxRef, i))), keyType.fieldNames) - ) + StreamMultiMerge( + repartitioned.indices.map(i => + repartitioned(i).partition(GetTupleElement(ctxRef, i)) + ), + keyType.fieldNames, + ), + ) } - case x@TableMultiWayZipJoin(children, fieldName, globalName) => + case x @ TableMultiWayZipJoin(children, fieldName, globalName) => val lowered = children.map(lower) val keyType = x.typ.keyType - val newPartitioner = RVDPartitioner.generate(ctx.stateManager, keyType, lowered.flatMap(_.partitioner.rangeBounds)) + val newPartitioner = RVDPartitioner.generate( + ctx.stateManager, + keyType, + lowered.flatMap(_.partitioner.rangeBounds), + ) val repartitioned = lowered.map(_.repartitionNoShuffle(ctx, newPartitioner)) val newGlobals = MakeStruct(FastSeq( - globalName -> MakeArray(lowered.map(_.globals), TArray(lowered.head.globalType)))) + globalName -> MakeArray( + repartitioned.map(_.globals), + TArray(repartitioned.head.globalType), + ) + )) val globalsRef = Ref(genUID(), newGlobals.typ) val keyRef = Ref(genUID(), keyType) @@ -1638,17 +2177,23 @@ object LowerTableIR { }, ctxRef => StreamZipJoin( - repartitioned.indices.map(i => repartitioned(i).partition(GetTupleElement(ctxRef, i))), + repartitioned.indices.map(i => + repartitioned(i).partition(GetTupleElement(ctxRef, i)) + ), keyType.fieldNames, keyRef.name, valsRef.name, - InsertFields(keyRef, FastSeq(fieldName -> projectedVals))) + InsertFields(keyRef, FastSeq(fieldName -> projectedVals)), + ), ) - case t@TableOrderBy(child, sortFields) => + case t @ TableOrderBy(child, _) => require(t.definitelyDoesNotShuffle) val loweredChild = lower(child) - loweredChild.changePartitionerNoRepartition(RVDPartitioner.unkeyed(ctx.stateManager, loweredChild.partitioner.numPartitions)) + loweredChild.changePartitionerNoRepartition(RVDPartitioner.unkeyed( + ctx.stateManager, + loweredChild.partitioner.numPartitions, + )) case TableExplode(child, path) => lower(child).mapPartition(Some(child.typ.key.takeWhile(k => k != path(0)))) { rows => @@ -1662,12 +2207,14 @@ object LowerTableIR { refs(i + 1) = Ref(genUID(), roots(i).typ) i += 1 } - Let(refs.tail.zip(roots).map { case (ref, root) => ref.name -> root }, + Let( + refs.tail.zip(roots).map { case (ref, root) => ref.name -> root }, mapIR(ToStream(refs.last, true)) { elt => path.zip(refs.init).foldRight[IR](elt) { case ((p, ref), inserted) => InsertFields(ref, FastSeq(p -> inserted)) } - }) + }, + ) } } @@ -1679,20 +2226,25 @@ object LowerTableIR { letBindings = lc.letBindings, 
broadcastVals = lc.broadcastVals, globals = lc.globals, - partitioner = lc.partitioner.copy(rangeBounds = lc.partitioner - .rangeBounds - .grouped(groupSize) - .toArray - .map(arr => Interval(arr.head.left, arr.last.right))), + partitioner = lc.partitioner.copy(rangeBounds = + lc.partitioner + .rangeBounds + .grouped(groupSize) + .toArray + .map(arr => Interval(arr.head.left, arr.last.right)) + ), dependency = lc.dependency, - contexts = mapIR(StreamGrouped(lc.contexts, groupSize)) { group => ToArray(group) }, - partition = (r: Ref) => flatMapIR(ToStream(r)) { prevCtx => lc.partition(prevCtx) } + contexts = mapIR(StreamGrouped(lc.contexts, groupSize))(group => ToArray(group)), + partition = (r: Ref) => flatMapIR(ToStream(r))(prevCtx => lc.partition(prevCtx)), ) case TableRename(child, rowMap, globalMap) => val loweredChild = lower(child) val newGlobals = - CastRename(loweredChild.globals, loweredChild.globals.typ.asInstanceOf[TStruct].rename(globalMap)) + CastRename( + loweredChild.globals, + loweredChild.globals.typ.asInstanceOf[TStruct].rename(globalMap), + ) val newGlobalsRef = Ref(genUID(), newGlobals.typ) TableStage( @@ -1702,9 +2254,11 @@ object LowerTableIR { loweredChild.partitioner.copy(kType = loweredChild.kType.rename(rowMap)), loweredChild.dependency, loweredChild.contexts, - (ctxRef: Ref) => mapIR(loweredChild.partition(ctxRef)) { row => - CastRename(row, row.typ.asInstanceOf[TStruct].rename(rowMap)) - }) + (ctxRef: Ref) => + mapIR(loweredChild.partition(ctxRef)) { row => + CastRename(row, row.typ.asInstanceOf[TStruct].rename(rowMap)) + }, + ) case TableMapPartitions(child, globalName, partitionStreamName, body, _, allowedOverlap) => val loweredChild = lower(child).strictify(ctx, allowedOverlap) @@ -1713,7 +2267,7 @@ object LowerTableIR { Let(FastSeq(globalName -> loweredChild.globals, partitionStreamName -> part), body) } - case TableLiteral(typ, rvd, enc, encodedGlobals) => + case TableLiteral(_, rvd, enc, encodedGlobals) => RVDToTableStage(rvd, EncodedLiteral(enc, encodedGlobals)) case TableToTableApply(child, TableFilterPartitions(seq, keep)) => @@ -1724,9 +2278,9 @@ object LowerTableIR { val lit = Literal(TSet(TInt32), keptSet) if (keep) { def lookupRangeBound(idx: Int): Interval = { - try { + try lc.partitioner.rangeBounds(idx) - } catch { + catch { case exc: ArrayIndexOutOfBoundsException => fatal(s"_filter_partitions: no partition with index $idx", exc) } @@ -1736,41 +2290,66 @@ object LowerTableIR { partitioner = lc.partitioner.copy(rangeBounds = arr.map(lookupRangeBound)), contexts = mapIR( filterIR( - zipWithIndex(lc.contexts)) { t => - invoke("contains", TBoolean, lit, GetField(t, "idx")) }) { t => - GetField(t, "elt") } + zipWithIndex(lc.contexts) + )(t => invoke("contains", TBoolean, lit, GetField(t, "idx"))) + )(t => GetField(t, "elt")), ) } else { lc.copy( - partitioner = lc.partitioner.copy(rangeBounds = lc.partitioner.rangeBounds.zipWithIndex.filter { case (_, idx) => !keptSet.contains(idx) }.map(_._1)), + partitioner = + lc.partitioner.copy(rangeBounds = lc.partitioner.rangeBounds.zipWithIndex.filter { + case (_, idx) => !keptSet.contains(idx) + }.map(_._1)), contexts = mapIR( filterIR( - zipWithIndex(lc.contexts)) { t => - !invoke("contains", TBoolean, lit, GetField(t, "idx")) }) { t => - GetField(t, "elt") } + zipWithIndex(lc.contexts) + )(t => !invoke("contains", TBoolean, lit, GetField(t, "idx"))) + )(t => GetField(t, "elt")), ) } - case TableToTableApply(child, WrappedMatrixToTableFunction(localLDPrune: LocalLDPrune, colsFieldName, entriesFieldName, _)) 
=> + case TableToTableApply( + child, + WrappedMatrixToTableFunction( + localLDPrune: LocalLDPrune, + colsFieldName, + entriesFieldName, + _, + ), + ) => val lc = lower(child) lc.mapPartition(Some(child.typ.key)) { rows => - localLDPrune.makeStream(rows, entriesFieldName, ArrayLen(GetField(lc.globals, colsFieldName))) + localLDPrune.makeStream( + rows, + entriesFieldName, + ArrayLen(GetField(lc.globals, colsFieldName)), + ) }.mapGlobals(_ => makestruct()) - case bmtt@BlockMatrixToTable(bmir) => + case BlockMatrixToTable(bmir) => val ts = LowerBlockMatrixIR.lowerToTableStage(bmir, typesToLower, ctx, analyses) // I now have an unkeyed table of (blockRow, blockCol, block). ts.mapPartitionWithContext { (partition, ctxRef) => flatMapIR(partition)(singleRowRef => bindIR(GetField(singleRowRef, "block")) { singleNDRef => bindIR(NDArrayShape(singleNDRef)) { shapeTupleRef => - flatMapIR(rangeIR(Cast(GetTupleElement(shapeTupleRef, 0), TInt32))) { withinNDRowIdx => - mapIR(rangeIR(Cast(GetTupleElement(shapeTupleRef, 1), TInt32))) { withinNDColIdx => - val entry = NDArrayRef(singleNDRef, IndexedSeq(Cast(withinNDRowIdx, TInt64), Cast(withinNDColIdx, TInt64)), ErrorIDs.NO_ERROR) - val blockStartRow = GetField(singleRowRef, "blockRow") * bmir.typ.blockSize - val blockStartCol = GetField(singleRowRef, "blockCol") * bmir.typ.blockSize - makestruct("i" -> Cast(withinNDRowIdx + blockStartRow, TInt64), "j" -> Cast(withinNDColIdx + blockStartCol, TInt64), "entry" -> entry) - } + flatMapIR(rangeIR(Cast(GetTupleElement(shapeTupleRef, 0), TInt32))) { + withinNDRowIdx => + mapIR(rangeIR(Cast(GetTupleElement(shapeTupleRef, 1), TInt32))) { + withinNDColIdx => + val entry = NDArrayRef( + singleNDRef, + IndexedSeq(Cast(withinNDRowIdx, TInt64), Cast(withinNDColIdx, TInt64)), + ErrorIDs.NO_ERROR, + ) + val blockStartRow = GetField(singleRowRef, "blockRow") * bmir.typ.blockSize + val blockStartCol = GetField(singleRowRef, "blockCol") * bmir.typ.blockSize + makestruct( + "i" -> Cast(withinNDRowIdx + blockStartRow, TInt64), + "j" -> Cast(withinNDColIdx + blockStartCol, TInt64), + "entry" -> entry, + ) + } } } } @@ -1781,13 +2360,24 @@ object LowerTableIR { throw new LowererUnsupportedOperation(s"undefined: \n${Pretty(ctx, node)}") } - assert(tir.typ.globalType == lowered.globalType, s"\n ir global: ${tir.typ.globalType}\n lowered global: ${lowered.globalType}") - assert(tir.typ.rowType == lowered.rowType, s"\n ir row: ${tir.typ.rowType}\n lowered row: ${lowered.rowType}") - assert(tir.typ.keyType.isPrefixOf(lowered.kType), s"\n ir key: ${tir.typ.key}\n lowered key: ${lowered.key}") + assert( + tir.typ.globalType == lowered.globalType, + s"\n ir global: ${tir.typ.globalType}\n lowered global: ${lowered.globalType}", + ) + assert( + tir.typ.rowType == lowered.rowType, + s"\n ir row: ${tir.typ.rowType}\n lowered row: ${lowered.rowType}", + ) + assert( + tir.typ.keyType.isPrefixOf(lowered.kType), + s"\n ir key: ${tir.typ.key}\n lowered key: ${lowered.key}", + ) lowered } + // format: off + /* We have a couple of options when repartitioning a table: * 1. Send only the contexts needed to compute each new partition and * take/drop the rows that fall in that partition. 
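The comment that closes the hunk above lists strategies for repartitioning a table; only option 1 is visible here, the rest is elided between hunks. As a hedged illustration of option 1 (an editor's sketch, not Hail code; `KeyRange` and `contextsForNewPartitions` are invented names), the snippet below computes, for each new partition's key range, which old partitions' contexts must be shipped to it, with out-of-range rows taken or dropped at execution time.

// Editor's illustration of option 1, using integer key ranges.
object RepartitionSketch {
  case class KeyRange(start: Int, end: Int) { // end is exclusive
    def overlaps(other: KeyRange): Boolean = start < other.end && other.start < end
  }

  // For each new partition, the indices of the old partitions whose ranges overlap
  // it; these are the only contexts that need to be sent to compute that partition.
  def contextsForNewPartitions(
    oldBounds: IndexedSeq[KeyRange],
    newBounds: IndexedSeq[KeyRange],
  ): IndexedSeq[IndexedSeq[Int]] =
    newBounds.map(nb => oldBounds.indices.filter(i => oldBounds(i).overlaps(nb)))
}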
@@ -1828,4 +2418,6 @@ object LowerTableIR { log.info(s"repartition cost: $cost") cost <= 1.0 } + + // format: on } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIRHelpers.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIRHelpers.scala index a2751ee4312..dd9298d119c 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIRHelpers.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerTableIRHelpers.scala @@ -8,7 +8,13 @@ import is.hail.utils.FastSeq object LowerTableIRHelpers { - def lowerTableJoin(ctx: ExecuteContext, analyses: LoweringAnalyses, tj: TableJoin, loweredLeft: TableStage, loweredRight: TableStage): TableStage = { + def lowerTableJoin( + ctx: ExecuteContext, + analyses: LoweringAnalyses, + tj: TableJoin, + loweredLeft: TableStage, + loweredRight: TableStage, + ): TableStage = { val TableJoin(left, right, joinType, joinKey) = tj val lKeyFields = left.typ.key.take(joinKey) val lValueFields = left.typ.rowType.fieldNames.filter(f => !lKeyFields.contains(f)) @@ -18,11 +24,14 @@ object LowerTableIRHelpers { val rReq = analyses.requirednessAnalysis.lookup(right).asInstanceOf[RTable] val rightKeyIsDistinct = analyses.distinctKeyedAnalysis.contains(right) - val joinedStage = loweredLeft.orderedJoin(ctx, - loweredRight, joinKey, joinType, + val joinedStage = loweredLeft.orderedJoin( + ctx, + loweredRight, + joinKey, + joinType, (lGlobals, rGlobals) => { val rGlobalType = rGlobals.typ.asInstanceOf[TStruct] - bindIR(rGlobals) { rGlobalRef => + bindIR(rGlobals) { rGlobalRef => InsertFields(lGlobals, rGlobalType.fieldNames.map(f => f -> GetField(rGlobalRef, f))) } }, @@ -30,13 +39,20 @@ object LowerTableIRHelpers { MakeStruct( (lKeyFields, rKeyFields).zipped.map { (lKey, rKey) => if (joinType == "outer" && lReq.field(lKey).required && rReq.field(rKey).required) - lKey -> Coalesce(FastSeq(GetField(lEltRef, lKey), GetField(rEltRef, rKey), Die("TableJoin expected non-missing key", left.typ.rowType.fieldType(lKey), -1))) + lKey -> Coalesce(FastSeq( + GetField(lEltRef, lKey), + GetField(rEltRef, rKey), + Die("TableJoin expected non-missing key", left.typ.rowType.fieldType(lKey), -1), + )) else lKey -> Coalesce(FastSeq(GetField(lEltRef, lKey), GetField(rEltRef, rKey))) } ++ lValueFields.map(f => f -> GetField(lEltRef, f)) - ++ rValueFields.map(f => f -> GetField(rEltRef, f))) - }, rightKeyIsDistinct) + ++ rValueFields.map(f => f -> GetField(rEltRef, f)) + ) + }, + rightKeyIsDistinct, + ) assert(joinedStage.rowType == tj.typ.rowType) joinedStage diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerToCDA.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerToCDA.scala index 1866b8de77e..8f485275641 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LowerToCDA.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LowerToCDA.scala @@ -11,21 +11,34 @@ object LowerToCDA { lower(ir, typesToLower, ctx, analyses) } - def lower(ir: IR, typesToLower: DArrayLowering.Type, ctx: ExecuteContext, analyses: LoweringAnalyses): IR = ir match { + def lower( + ir: IR, + typesToLower: DArrayLowering.Type, + ctx: ExecuteContext, + analyses: LoweringAnalyses, + ): IR = ir match { case node if node.children.forall(_.isInstanceOf[IR]) => ir.mapChildren { case c: IR => lower(c, typesToLower, ctx, analyses) } - case node if node.children.exists(n => n.isInstanceOf[TableIR]) && node.children.forall(n => n.isInstanceOf[TableIR] || n.isInstanceOf[IR]) => + case node + if node.children.exists(n => n.isInstanceOf[TableIR]) && node.children.forall(n 
=> + n.isInstanceOf[TableIR] || n.isInstanceOf[IR] + ) => LowerTableIR(ir, typesToLower, ctx, analyses) - case node if node.children.exists(n => n.isInstanceOf[BlockMatrixIR]) && node.children.forall(n => n.isInstanceOf[BlockMatrixIR] || n.isInstanceOf[IR]) => + case node + if node.children.exists(n => n.isInstanceOf[BlockMatrixIR]) && node.children.forall(n => + n.isInstanceOf[BlockMatrixIR] || n.isInstanceOf[IR] + ) => LowerBlockMatrixIR(ir, typesToLower, ctx, analyses) case node if node.children.exists(_.isInstanceOf[MatrixIR]) => - throw new LowererUnsupportedOperation(s"MatrixIR nodes must be lowered to TableIR nodes separately: \n${ Pretty(ctx, node) }") + throw new LowererUnsupportedOperation( + s"MatrixIR nodes must be lowered to TableIR nodes separately: \n${Pretty(ctx, node)}" + ) case node => - throw new LowererUnsupportedOperation(s"Cannot lower: \n${ Pretty(ctx, node) }") + throw new LowererUnsupportedOperation(s"Cannot lower: \n${Pretty(ctx, node)}") } } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala index ad84438972c..83ac87554ab 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPass.scala @@ -1,14 +1,15 @@ package is.hail.expr.ir.lowering import is.hail.backend.ExecuteContext -import is.hail.expr.ir.agg.Extract import is.hail.expr.ir._ +import is.hail.expr.ir.agg.Extract import is.hail.expr.ir.analyses.SemanticHash import is.hail.utils._ final case class IrMetadata(semhash: Option[SemanticHash.Type]) { private[this] var hashCounter: Int = 0 private[this] var markCounter: Int = 0 + def nextHash: Option[SemanticHash.Type] = { hashCounter += 1 semhash.map(SemanticHash.extend(_, SemanticHash.Bytes.fromInt(hashCounter))) @@ -95,10 +96,13 @@ case object InlineApplyIR extends LoweringPass { val after: IRState = CompilableIRNoApply val context: String = "InlineApplyIR" - override def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = RewriteBottomUp(ir, { - case x: ApplyIR => Some(x.explicitNode) - case _ => None - }) + override def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = RewriteBottomUp( + ir, + { + case x: ApplyIR => Some(x.explicitNode) + case _ => None + }, + ) } case object LowerArrayAggsToRunAggsPass extends LoweringPass { @@ -109,48 +113,49 @@ case object LowerArrayAggsToRunAggsPass extends LoweringPass { def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = { val x = ir.noSharing(ctx) val r = Requiredness(x, ctx) - RewriteBottomUp(x, { - case x@StreamAgg(a, name, query) => - val res = genUID() - val aggs = Extract(query, res, r) - - val newNode = aggs.rewriteFromInitBindingRoot { root => - Let( - FastSeq( - res -> RunAgg( - Begin(FastSeq( - aggs.init, - StreamFor(a, name, aggs.seqPerElt) - )), - aggs.results, - aggs.states - ) - ), - root - ) - } - - if (newNode.typ != x.typ) - throw new RuntimeException(s"types differ:\n new: ${newNode.typ}\n old: ${x.typ}") - Some(newNode.noSharing(ctx)) - case x@StreamAggScan(a, name, query) => - val res = genUID() - val aggs = Extract(query, res, r, isScan=true) - val newNode = aggs.rewriteFromInitBindingRoot { root => - RunAggScan( - a, - name, - aggs.init, - aggs.seqPerElt, - Let(FastSeq(res -> aggs.results), root), - aggs.states - ) - } - if (newNode.typ != x.typ) - throw new RuntimeException(s"types differ:\n new: ${ newNode.typ }\n old: ${ x.typ }") - Some(newNode.noSharing(ctx)) - case _ => None - }) + RewriteBottomUp( + x, + { + case 
x @ StreamAgg(a, name, query) => + val aggs = Extract(query, r) + + val newNode = aggs.rewriteFromInitBindingRoot { root => + Let( + FastSeq( + aggs.resultRef.name -> RunAgg( + Begin(FastSeq( + aggs.init, + StreamFor(a, name, aggs.seqPerElt), + )), + aggs.results, + aggs.states, + ) + ), + root, + ) + } + + if (newNode.typ != x.typ) + throw new RuntimeException(s"types differ:\n new: ${newNode.typ}\n old: ${x.typ}") + Some(newNode.noSharing(ctx)) + case x @ StreamAggScan(a, name, query) => + val aggs = Extract(query, r, isScan = true) + val newNode = aggs.rewriteFromInitBindingRoot { root => + RunAggScan( + a, + name, + aggs.init, + aggs.seqPerElt, + Let(FastSeq(aggs.resultRef.name -> aggs.results), root), + aggs.states, + ) + } + if (newNode.typ != x.typ) + throw new RuntimeException(s"types differ:\n new: ${newNode.typ}\n old: ${x.typ}") + Some(newNode.noSharing(ctx)) + case _ => None + }, + ) } } @@ -159,9 +164,8 @@ case class EvalRelationalLetsPass(passesBelow: LoweringPipeline) extends Lowerin val after: IRState = before + NoRelationalLetsState val context: String = "EvalRelationalLets" - override def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = { + override def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = EvalRelationalLets(ir, ctx, passesBelow) - } } case class LowerAndExecuteShufflesPass(passesBelow: LoweringPipeline) extends LoweringPass { @@ -169,7 +173,6 @@ case class LowerAndExecuteShufflesPass(passesBelow: LoweringPipeline) extends Lo val after: IRState = before + LoweredShuffles val context: String = "LowerAndExecuteShuffles" - override def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = { + override def transform(ctx: ExecuteContext, ir: BaseIR): BaseIR = LowerAndExecuteShuffles(ir, ctx, passesBelow) - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPipeline.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPipeline.scala index 483c303bbf9..2df9a548d37 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPipeline.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/LoweringPipeline.scala @@ -10,36 +10,37 @@ case class LoweringPipeline(lowerings: LoweringPass*) { final def apply(ctx: ExecuteContext, ir: BaseIR): BaseIR = { var x = ir - def render(context: String): Unit = { + def render(context: String): Unit = if (ctx.shouldLogIR()) - log.info(s"$context: IR size ${ IRSize(x) }: \n" + Pretty(ctx, x, elideLiterals = true)) - } + log.info(s"$context: IR size ${IRSize(x)}: \n" + Pretty(ctx, x, elideLiterals = true)) render(s"initial IR") lowerings.foreach { l => try { x = l.apply(ctx, x) - render(s"after ${ l.context }") + render(s"after ${l.context}") } catch { case e: Throwable => - log.error(s"error while applying lowering '${ l.context }'") + log.error(s"error while applying lowering '${l.context}'") throw e } - try { + try TypeCheck(ctx, x) - } catch { + catch { case e: Throwable => - fatal(s"error after applying ${ l.context }", e) + fatal(s"error after applying ${l.context}", e) } } x } - def noOptimization(): LoweringPipeline = LoweringPipeline(lowerings.filter(l => !l.isInstanceOf[OptimizePass]): _*) + def noOptimization(): LoweringPipeline = + LoweringPipeline(lowerings.filter(l => !l.isInstanceOf[OptimizePass]): _*) - def +(suffix: LoweringPipeline): LoweringPipeline = LoweringPipeline((lowerings ++ suffix.lowerings): _*) + def +(suffix: LoweringPipeline): LoweringPipeline = + LoweringPipeline((lowerings ++ suffix.lowerings): _*) } object LoweringPipeline { @@ -48,12 +49,15 @@ object 
LoweringPipeline { val base = LoweringPipeline( baseTransformer, - OptimizePass(s"$context, after ${ baseTransformer.context }")) + OptimizePass(s"$context, after ${baseTransformer.context}"), + ) // recursively lowers and executes val withShuffleRewrite = - LoweringPipeline(LowerAndExecuteShufflesPass(base), - OptimizePass(s"$context, after LowerAndExecuteShuffles")) + base + LoweringPipeline( + LowerAndExecuteShufflesPass(base), + OptimizePass(s"$context, after LowerAndExecuteShuffles"), + ) + base // recursively lowers and executes val withLetEvaluation = @@ -65,16 +69,20 @@ object LoweringPipeline { LoweringPipeline( OptimizePass(s"$context, initial IR"), LowerMatrixToTablePass, - OptimizePass(s"$context, after LowerMatrixToTable")) + withLetEvaluation + OptimizePass(s"$context, after LowerMatrixToTable"), + ) + withLetEvaluation } - private val _relationalLowerer = fullLoweringPipeline("relationalLowerer", LowerOrInterpretNonCompilablePass) - private val _relationalLowererNoOpt = _relationalLowerer.noOptimization() + private val _relationalLowerer = + fullLoweringPipeline("relationalLowerer", LowerOrInterpretNonCompilablePass) + private val _relationalLowererNoOpt = _relationalLowerer.noOptimization() // legacy lowers can run partial optimization on a TableIR/MatrixIR that gets interpreted to a // TableValue for spark compatibility - private val _relationalLowererLegacy = fullLoweringPipeline("relationalLowererLegacy", LegacyInterpretNonCompilablePass) + private val _relationalLowererLegacy = + fullLoweringPipeline("relationalLowererLegacy", LegacyInterpretNonCompilablePass) + private val _relationalLowererNoOptLegacy = _relationalLowererLegacy.noOptimization() private val _compileLowerer = LoweringPipeline( @@ -82,25 +90,29 @@ object LoweringPipeline { InlineApplyIR, OptimizePass("compileLowerer, after InlineApplyIR"), LowerArrayAggsToRunAggsPass, - OptimizePass("compileLowerer, after LowerArrayAggsToRunAggs") + OptimizePass("compileLowerer, after LowerArrayAggsToRunAggs"), ) + private val _compileLowererNoOpt = _compileLowerer.noOptimization() private val _dArrayLowerers = Array( DArrayLowering.All, DArrayLowering.TableOnly, - DArrayLowering.BMOnly).map { lv => + DArrayLowering.BMOnly, + ).map { lv => (lv -> fullLoweringPipeline("darrayLowerer", LowerToDistributedArrayPass(lv))) }.toMap private val _dArrayLowerersNoOpt = _dArrayLowerers.mapValues(_.noOptimization()).toMap - def relationalLowerer(optimize: Boolean): LoweringPipeline = if (optimize) _relationalLowerer else _relationalLowererNoOpt + def relationalLowerer(optimize: Boolean): LoweringPipeline = + if (optimize) _relationalLowerer else _relationalLowererNoOpt def legacyRelationalLowerer(optimize: Boolean): LoweringPipeline = if (optimize) _relationalLowererLegacy else _relationalLowererNoOptLegacy - def darrayLowerer(optimize: Boolean): Map[DArrayLowering.Type, LoweringPipeline] = if (optimize) _dArrayLowerers else _dArrayLowerersNoOpt + def darrayLowerer(optimize: Boolean): Map[DArrayLowering.Type, LoweringPipeline] = + if (optimize) _dArrayLowerers else _dArrayLowerersNoOpt def compileLowerer(optimize: Boolean): LoweringPipeline = if (optimize) _compileLowerer else _compileLowererNoOpt diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala index 2686b497313..82a925aff2b 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/RVDToTableStage.scala @@ -2,62 
+2,86 @@ package is.hail.expr.ir.lowering import is.hail.annotations.{BroadcastRow, Region, RegionValue} import is.hail.asm4s._ -import is.hail.backend.spark.{AnonymousDependency, SparkTaskContext} import is.hail.backend.{BroadcastValue, ExecuteContext} +import is.hail.backend.spark.{AnonymousDependency, SparkTaskContext} import is.hail.expr.ir._ -import is.hail.io.fs.FS import is.hail.io.{BufferSpec, TypedCodecSpec} +import is.hail.io.fs.FS import is.hail.rvd.{RVD, RVDType} import is.hail.sparkextras.ContextRDD -import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType +import is.hail.types.{RTable, TableType, VirtualTypeWithReq} import is.hail.types.physical.{PArray, PStruct} +import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual.TStruct -import is.hail.types.{RTable, TableType, VirtualTypeWithReq} import is.hail.utils.FastSeq -import org.apache.spark.rdd.RDD + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + import org.apache.spark.{Dependency, Partition, SparkContext, TaskContext} +import org.apache.spark.rdd.RDD import org.json4s.JValue import org.json4s.JsonAST.JString -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} - case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader { - lazy val fullType: TableType = TableType(rvd.rowType, rvd.typ.key, globals.typ.asInstanceOf[TStruct]) + lazy val fullType: TableType = + TableType(rvd.rowType, rvd.typ.key, globals.typ.asInstanceOf[TStruct]) override def pathsUsed: Seq[String] = Seq() override def partitionCounts: Option[IndexedSeq[Long]] = None - override def toExecuteIntermediate(ctx: ExecuteContext, requestedType: TableType, dropRows: Boolean): TableExecuteIntermediate = { + override def toExecuteIntermediate( + ctx: ExecuteContext, + requestedType: TableType, + dropRows: Boolean, + ): TableExecuteIntermediate = { assert(!dropRows) - val (Some(PTypeReferenceSingleCodeType(globType: PStruct)), f) = Compile[AsmFunction1RegionLong]( - ctx, FastSeq(), FastSeq(classInfo[Region]), LongInfo, PruneDeadFields.upcast(ctx, globals, requestedType.globalType)) + val (Some(PTypeReferenceSingleCodeType(globType: PStruct)), f) = + Compile[AsmFunction1RegionLong]( + ctx, + FastSeq(), + FastSeq(classInfo[Region]), + LongInfo, + PruneDeadFields.upcast(ctx, globals, requestedType.globalType), + ) val gbAddr = f(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)(ctx.r) val globRow = BroadcastRow(ctx, RegionValue(ctx.r, gbAddr), globType) val rowEmitType = SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(rvd.rowPType)) - val (Some(PTypeReferenceSingleCodeType(newRowType: PStruct)), rowF) = Compile[AsmFunction2RegionLongLong]( - ctx, FastSeq(("row", rowEmitType)), FastSeq(classInfo[Region], LongInfo), LongInfo, - PruneDeadFields.upcast(ctx, In(0, rowEmitType), - requestedType.rowType)) + val (Some(PTypeReferenceSingleCodeType(newRowType: PStruct)), rowF) = + Compile[AsmFunction2RegionLongLong]( + ctx, + FastSeq(("row", rowEmitType)), + FastSeq(classInfo[Region], LongInfo), + LongInfo, + PruneDeadFields.upcast(ctx, In(0, rowEmitType), requestedType.rowType), + ) val fsBc = ctx.fsBc - TableExecuteIntermediate(TableValue(ctx, requestedType, globRow, rvd.mapPartitionsWithIndex(RVDType(newRowType, requestedType.key)) { case (i, ctx, it) => - val partF = rowF(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), ctx.partitionRegion) - it.map { elt => partF(ctx.r, elt) } - })) + TableExecuteIntermediate(TableValue( + ctx, + requestedType, 
+ globRow, + rvd.mapPartitionsWithIndex(RVDType(newRowType, requestedType.key)) { case (_, ctx, it) => + val partF = rowF( + theHailClassLoaderForSparkWorkers, + fsBc.value, + SparkTaskContext.get(), + ctx.partitionRegion, + ) + it.map(elt => partF(ctx.r, elt)) + }, + )) } override def isDistinctlyKeyed: Boolean = false - def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = { + def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = VirtualTypeWithReq.subset(requestedType.rowType, rt.rowType) - } - def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = { + def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = VirtualTypeWithReq.subset(requestedType.globalType, rt.globalType) - } override def toJValue: JValue = JString("RVDTableReader") @@ -68,10 +92,9 @@ case class RVDTableReader(rvd: RVD, globals: IR, rt: RTable) extends TableReader override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = PruneDeadFields.upcast(ctx, globals, requestedGlobalsType) - override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { + override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = RVDToTableStage(rvd, globals) .upcast(ctx, requestedType) - } } object RVDToTableStage { @@ -81,7 +104,7 @@ object RVDToTableStage { partitioner = rvd.partitioner, dependency = TableStageDependency.fromRVD(rvd), contexts = StreamRange(0, rvd.getNumPartitions, 1), - body = ReadPartition(_, rvd.rowType, PartitionRVDReader(rvd, "__dummy_uid")) + body = ReadPartition(_, rvd.rowType, PartitionRVDReader(rvd, "__dummy_uid")), ) } } @@ -89,26 +112,37 @@ object RVDToTableStage { object TableStageToRVD { def apply(ctx: ExecuteContext, _ts: TableStage): (BroadcastRow, RVD) = { - val ts = TableStage(letBindings = _ts.letBindings, + val ts = TableStage( + letBindings = _ts.letBindings, broadcastVals = _ts.broadcastVals, globals = _ts.globals, partitioner = _ts.partitioner, dependency = _ts.dependency, - contexts = mapIR(_ts.contexts) { c => MakeStruct(FastSeq("context" -> c)) }, - partition = { ctx: Ref => _ts.partition(GetField(ctx, "context")) }) + contexts = mapIR(_ts.contexts)(c => MakeStruct(FastSeq("context" -> c))), + partition = { ctx: Ref => _ts.partition(GetField(ctx, "context")) }, + ) val sparkContext = ctx.backend .asSpark("TableStageToRVD") .sc val globalsAndBroadcastVals = - Let(ts.letBindings, MakeStruct(FastSeq( - "globals" -> ts.globals, - "broadcastVals" -> MakeStruct(ts.broadcastVals), - "contexts" -> ToArray(ts.contexts)) - )) + Let( + ts.letBindings, + MakeStruct(FastSeq( + "globals" -> ts.globals, + "broadcastVals" -> MakeStruct(ts.broadcastVals), + "contexts" -> ToArray(ts.contexts), + )), + ) - val (Some(PTypeReferenceSingleCodeType(gbPType: PStruct)), f) = Compile[AsmFunction1RegionLong](ctx, FastSeq(), FastSeq(classInfo[Region]), LongInfo, globalsAndBroadcastVals) + val (Some(PTypeReferenceSingleCodeType(gbPType: PStruct)), f) = Compile[AsmFunction1RegionLong]( + ctx, + FastSeq(), + FastSeq(classInfo[Region]), + LongInfo, + globalsAndBroadcastVals, + ) val gbAddr = f(ctx.theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)(ctx.r) val globPType = gbPType.fieldType("globals").asInstanceOf[PStruct] @@ -117,8 +151,10 @@ object TableStageToRVD { val bcValsPType = gbPType.fieldType("broadcastVals") val bcValsSpec = TypedCodecSpec(bcValsPType, BufferSpec.wireSpec) - val encodedBcVals = 
sparkContext.broadcast(bcValsSpec.encodeValue(ctx, bcValsPType, gbPType.loadField(gbAddr, 1))) - val (decodedBcValsPType: PStruct, makeBcDec) = bcValsSpec.buildDecoder(ctx, bcValsPType.virtualType) + val encodedBcVals = + sparkContext.broadcast(bcValsSpec.encodeValue(ctx, bcValsPType, gbPType.loadField(gbAddr, 1))) + val (decodedBcValsPType: PStruct, makeBcDec) = + bcValsSpec.buildDecoder(ctx, bcValsPType.virtualType) val contextsPType = gbPType.fieldType("contexts").asInstanceOf[PArray] val contextPType = contextsPType.elementType @@ -126,7 +162,8 @@ object TableStageToRVD { val contextsAddr = gbPType.loadField(gbAddr, 2) val nContexts = contextsPType.loadLength(contextsAddr) - val (decodedContextPType: PStruct, makeContextDec) = contextSpec.buildDecoder(ctx, contextPType.virtualType) + val (decodedContextPType: PStruct, makeContextDec) = + contextSpec.buildDecoder(ctx, contextPType.virtualType) val makeContextEnc = contextSpec.buildEncoder(ctx, contextPType) val encodedContexts = Array.tabulate(nContexts) { i => @@ -140,11 +177,20 @@ object TableStageToRVD { val (newRowPType: PStruct, makeIterator) = CompileIterator.forTableStageToRVD( ctx, - decodedContextPType, decodedBcValsPType, + decodedContextPType, + decodedBcValsPType, Let( - ts.broadcastVals.map(_._1).map(bcVal => bcVal -> GetField(In(1, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(decodedBcValsPType))), bcVal)), - ts.partition(In(0, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(decodedContextPType)))) - ) + ts.broadcastVals.map(_._1).map(bcVal => + bcVal -> GetField( + In(1, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(decodedBcValsPType))), + bcVal, + ) + ), + ts.partition(In( + 0, + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(decodedContextPType)), + )), + ), ) val fsBc = ctx.fsBc @@ -156,12 +202,25 @@ object TableStageToRVD { val rdd = new TableStageToRDD(fsBc, sparkContext, encodedContexts, sparkDeps) val crdd = ContextRDD.weaken(rdd) - .cflatMap { case (rvdContext, (encodedContext, idx)) => - val decodedContext = makeContextDec(new ByteArrayInputStream(encodedContext), theHailClassLoaderForSparkWorkers) + .cflatMap { case (rvdContext, (encodedContext, _)) => + val decodedContext = makeContextDec( + new ByteArrayInputStream(encodedContext), + theHailClassLoaderForSparkWorkers, + ) .readRegionValue(rvdContext.partitionRegion) - val decodedBroadcastVals = makeBcDec(new ByteArrayInputStream(encodedBcVals.value), theHailClassLoaderForSparkWorkers) + val decodedBroadcastVals = makeBcDec( + new ByteArrayInputStream(encodedBcVals.value), + theHailClassLoaderForSparkWorkers, + ) .readRegionValue(rvdContext.partitionRegion) - makeIterator(theHailClassLoaderForSparkWorkers, fsBc.value, SparkTaskContext.get(), rvdContext, decodedContext, decodedBroadcastVals) + makeIterator( + theHailClassLoaderForSparkWorkers, + fsBc.value, + SparkTaskContext.get(), + rvdContext, + decodedContext, + decodedBroadcastVals, + ) .map(_.longValue()) } @@ -175,12 +234,11 @@ class TableStageToRDD( fsBc: BroadcastValue[FS], sc: SparkContext, @transient private val collection: Array[Array[Byte]], - deps: Seq[Dependency[_]]) - extends RDD[(Array[Byte], Int)](sc, deps) { + deps: Seq[Dependency[_]], +) extends RDD[(Array[Byte], Int)](sc, deps) { - override def getPartitions: Array[Partition] = { + override def getPartitions: Array[Partition] = Array.tabulate(collection.length)(i => TableStageToRDDPartition(collection(i), i)) - } override def compute(partition: Partition, context: TaskContext): 
Iterator[(Array[Byte], Int)] = { val sp = partition.asInstanceOf[TableStageToRDDPartition] diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/Rule.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/Rule.scala index 9a39fe46cb0..6274eeefcd8 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/Rule.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/Rule.scala @@ -1,7 +1,6 @@ package is.hail.expr.ir.lowering import is.hail.expr.ir._ -import is.hail.types.virtual.TStream trait Rule { def allows(ir: BaseIR): Boolean @@ -39,7 +38,7 @@ case object CompilableValueIRs extends Rule { case object NoApplyIR extends Rule { override def allows(ir: BaseIR): Boolean = ir match { - case x: ApplyIR => false + case _: ApplyIR => false case _ => true } } diff --git a/hail/src/main/scala/is/hail/expr/ir/lowering/TableStageDependency.scala b/hail/src/main/scala/is/hail/expr/ir/lowering/TableStageDependency.scala index 182befb7543..d6da5b5ba97 100644 --- a/hail/src/main/scala/is/hail/expr/ir/lowering/TableStageDependency.scala +++ b/hail/src/main/scala/is/hail/expr/ir/lowering/TableStageDependency.scala @@ -21,5 +21,6 @@ object TableStageDependency { } case class TableStageDependency(deps: IndexedSeq[DependencySource]) { - def union(other: TableStageDependency): TableStageDependency = TableStageDependency.union(FastSeq(this, other)) + def union(other: TableStageDependency): TableStageDependency = + TableStageDependency.union(FastSeq(this, other)) } diff --git a/hail/src/main/scala/is/hail/expr/ir/ndarrays/EmitNDArray.scala b/hail/src/main/scala/is/hail/expr/ir/ndarrays/EmitNDArray.scala index 901516068cf..c4b4ec82143 100644 --- a/hail/src/main/scala/is/hail/expr/ir/ndarrays/EmitNDArray.scala +++ b/hail/src/main/scala/is/hail/expr/ir/ndarrays/EmitNDArray.scala @@ -4,8 +4,8 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir._ import is.hail.types.physical._ -import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.interfaces._ import is.hail.types.virtual._ import is.hail.utils._ @@ -26,7 +26,7 @@ abstract class NDArrayProducer { aShape: IndexedSeq[Value[Long]] = shape, ainitAll: EmitCodeBuilder => Unit = initAll, ainitAxis: IndexedSeq[(EmitCodeBuilder) => Unit] = initAxis, - astepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = stepAxis + astepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = stepAxis, ): NDArrayProducer = { new NDArrayProducer() { override def elementType: PType = aElementType @@ -36,24 +36,44 @@ abstract class NDArrayProducer { override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = ainitAxis override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = astepAxis - override def loadElementAtCurrentAddr(cb: EmitCodeBuilder): SValue = outer.loadElementAtCurrentAddr(cb) + override def loadElementAtCurrentAddr(cb: EmitCodeBuilder): SValue = + outer.loadElementAtCurrentAddr(cb) } } - def toSCode(cb: EmitCodeBuilder, targetType: PCanonicalNDArray, region: Value[Region], rowMajor: Boolean = false): SNDArrayValue = { + def toSCode( + cb: EmitCodeBuilder, + targetType: PCanonicalNDArray, + region: Value[Region], + rowMajor: Boolean = false, + ): SNDArrayValue = { val (firstElementAddress, finish) = targetType.constructDataFunction( shape, targetType.makeColumnMajorStrides(shape, cb), cb, - region) + region, + ) val currentWriteAddr = cb.newLocal[Long]("ndarray_producer_to_scode_cur_write_addr") cb.assign(currentWriteAddr, 
firstElementAddress) initAll(cb) - val idxGenerator = if (rowMajor) SNDArray.forEachIndexWithInitAndIncRowMajor _ else SNDArray.forEachIndexWithInitAndIncColMajor _ - idxGenerator(cb, shape, initAxis, stepAxis.map(stepper => (cb: EmitCodeBuilder) => stepper(cb, 1L)), "ndarray_producer_toSCode"){ (cb, indices) => - targetType.elementType.storeAtAddress(cb, currentWriteAddr, region, loadElementAtCurrentAddr(cb), true) + val idxGenerator = if (rowMajor) SNDArray.forEachIndexWithInitAndIncRowMajor _ + else SNDArray.forEachIndexWithInitAndIncColMajor _ + idxGenerator( + cb, + shape, + initAxis, + stepAxis.map(stepper => (cb: EmitCodeBuilder) => stepper(cb, 1L)), + "ndarray_producer_toSCode", + ) { (cb, indices) => + targetType.elementType.storeAtAddress( + cb, + currentWriteAddr, + region, + loadElementAtCurrentAddr(cb), + true, + ) cb.assign(currentWriteAddr, currentWriteAddr + targetType.elementType.byteSize) } @@ -70,10 +90,18 @@ object EmitNDArray { region: Value[Region], env: EmitEnv, container: Option[AggContainer], - loopEnv: Option[Env[LoopRef]] + loopEnv: Option[Env[LoopRef]], ): IEmitCode = { - def emitNDInSeparateMethod(context: String, cb: EmitCodeBuilder, ir: IR, region: Value[Region], env: EmitEnv, container: Option[AggContainer], loopEnv: Option[Env[LoopRef]]): IEmitCode = { + def emitNDInSeparateMethod( + context: String, + cb: EmitCodeBuilder, + ir: IR, + region: Value[Region], + env: EmitEnv, + container: Option[AggContainer], + loopEnv: Option[Env[LoopRef]], + ): IEmitCode = { assert(!emitter.ctx.inLoopCriticalPath.contains(ir)) val mb = cb.emb.genEmitMethod(context, FastSeq[ParamType](), UnitInfo) @@ -82,26 +110,52 @@ object EmitNDArray { var ev: EmitSettable = null mb.voidWithBuilder { cb => emitter.ctx.tryingToSplit.update(ir, ()) - val result: IEmitCode = deforest(ir, cb, r, env, container, loopEnv).map(cb)(ndap => ndap.toSCode(cb, PCanonicalNDArray(ndap.elementType.setRequired(true), ndap.nDims), r)) + val result: IEmitCode = deforest(ir, cb, r, env, container, loopEnv).map(cb)(ndap => + ndap.toSCode(cb, PCanonicalNDArray(ndap.elementType.setRequired(true), ndap.nDims), r) + ) ev = cb.emb.ecb.newEmitField(s"${context}_result", result.emitType) cb.assign(ev, result) } - cb.invokeVoid(mb) + cb.invokeVoid(mb, cb.this_) ev.toI(cb) } - def deforest(x: IR, cb: EmitCodeBuilder, region: Value[Region], env: EmitEnv, container: Option[AggContainer], loopEnv: Option[Env[LoopRef]]): IEmitCodeGen[NDArrayProducer] = { - def deforestRecur(x: IR, cb: EmitCodeBuilder = cb, region: Value[Region] = region, env: EmitEnv = env, container: Option[AggContainer] = container, loopEnv: Option[Env[LoopRef]] = loopEnv): IEmitCodeGen[NDArrayProducer] = { - - def emitI(ir: IR, cb: EmitCodeBuilder, region: Value[Region] = region, env: EmitEnv = env, container: Option[AggContainer] = container, loopEnv: Option[Env[LoopRef]] = loopEnv): IEmitCode = { + def deforest( + x: IR, + cb: EmitCodeBuilder, + region: Value[Region], + env: EmitEnv, + container: Option[AggContainer], + loopEnv: Option[Env[LoopRef]], + ): IEmitCodeGen[NDArrayProducer] = { + def deforestRecur( + x: IR, + cb: EmitCodeBuilder = cb, + region: Value[Region] = region, + env: EmitEnv = env, + container: Option[AggContainer] = container, + loopEnv: Option[Env[LoopRef]] = loopEnv, + ): IEmitCodeGen[NDArrayProducer] = { + + def emitI( + ir: IR, + cb: EmitCodeBuilder, + region: Value[Region] = region, + env: EmitEnv = env, + container: Option[AggContainer] = container, + loopEnv: Option[Env[LoopRef]] = loopEnv, + ): IEmitCode = 
emitter.emitI(ir, cb, region, env, container, loopEnv) - } x match { - case NDArrayMap(child, elemName, body) => { + case NDArrayMap(child, elemName, body) => deforestRecur(child, cb).map(cb) { childProducer => - val elemRef = cb.emb.newEmitField("ndarray_map_element_name", childProducer.elementType.sType, required = true) + val elemRef = cb.emb.newEmitField( + "ndarray_map_element_name", + childProducer.elementType.sType, + required = true, + ) val bodyEnv = env.bind(elemName, elemRef) val bodyEC = EmitCode.fromI(cb.emb)(cb => emitI(body, cb, env = bodyEnv)) @@ -111,25 +165,31 @@ object EmitNDArray { override val shape: IndexedSeq[Value[Long]] = childProducer.shape override val initAll: EmitCodeBuilder => Unit = childProducer.initAll override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = childProducer.initAxis - override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = childProducer.stepAxis + override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = + childProducer.stepAxis override def loadElementAtCurrentAddr(cb: EmitCodeBuilder): SValue = { - cb.assign(elemRef, EmitCode.present(cb.emb, childProducer.loadElementAtCurrentAddr(cb))) - bodyEC.toI(cb).get(cb, "NDArray map body cannot be missing") + cb.assign( + elemRef, + EmitCode.present(cb.emb, childProducer.loadElementAtCurrentAddr(cb)), + ) + bodyEC.toI(cb).getOrFatal(cb, "NDArray map body cannot be missing") } } } - } - case NDArrayMap2(lChild, rChild, lName, rName, body, errorID) => { + case NDArrayMap2(lChild, rChild, lName, rName, body, errorID) => deforestRecur(lChild, cb).flatMap(cb) { leftProducer => deforestRecur(rChild, cb).map(cb) { rightProducer => val leftShapeValues = leftProducer.shape val rightShapeValues = rightProducer.shape - val shapeArray = NDArrayEmitter.unifyShapes2(cb, leftShapeValues, rightShapeValues, errorID) + val shapeArray = + NDArrayEmitter.unifyShapes2(cb, leftShapeValues, rightShapeValues, errorID) - val lElemRef = cb.emb.newEmitField(lName, leftProducer.elementType.sType, required = true) - val rElemRef = cb.emb.newEmitField(rName, rightProducer.elementType.sType, required = true) + val lElemRef = + cb.emb.newEmitField(lName, leftProducer.elementType.sType, required = true) + val rElemRef = + cb.emb.newEmitField(rName, rightProducer.elementType.sType, required = true) val bodyEnv = env.bind(lName, lElemRef) .bind(rName, rElemRef) val bodyEC = EmitCode.fromI(cb.emb)(cb => emitI(body, cb, env = bodyEnv)) @@ -142,35 +202,38 @@ object EmitNDArray { override val shape: IndexedSeq[Value[Long]] = shapeArray override val initAll: EmitCodeBuilder => Unit = { - cb => { + cb => leftBroadcasted.initAll(cb) rightBroadcasted.initAll(cb) - } - } - override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = shape.indices.map { idx => { cb: EmitCodeBuilder => - leftBroadcasted.initAxis(idx)(cb) - rightBroadcasted.initAxis(idx)(cb) - } - } - override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = shape.indices.map { idx => { (cb: EmitCodeBuilder, axis: Value[Long]) => - leftBroadcasted.stepAxis(idx)(cb, axis) - rightBroadcasted.stepAxis(idx)(cb, axis) - } } + override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = + shape.indices.map { idx => cb: EmitCodeBuilder => + leftBroadcasted.initAxis(idx)(cb) + rightBroadcasted.initAxis(idx)(cb) + } + override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = + shape.indices.map { idx => (cb: EmitCodeBuilder, axis: Value[Long]) => + leftBroadcasted.stepAxis(idx)(cb, axis) + 
rightBroadcasted.stepAxis(idx)(cb, axis) + } override def loadElementAtCurrentAddr(cb: EmitCodeBuilder): SValue = { - cb.assign(lElemRef, EmitCode.present(cb.emb, leftBroadcasted.loadElementAtCurrentAddr(cb))) - cb.assign(rElemRef, EmitCode.present(cb.emb, rightBroadcasted.loadElementAtCurrentAddr(cb))) - - bodyEC.toI(cb).get(cb, "NDArrayMap2 body cannot be missing", errorID) + cb.assign( + lElemRef, + EmitCode.present(cb.emb, leftBroadcasted.loadElementAtCurrentAddr(cb)), + ) + cb.assign( + rElemRef, + EmitCode.present(cb.emb, rightBroadcasted.loadElementAtCurrentAddr(cb)), + ) + + bodyEC.toI(cb).getOrFatal(cb, "NDArrayMap2 body cannot be missing", errorID) } } } } - } case NDArrayReindex(child, indexExpr) => deforestRecur(child, cb).map(cb) { childProducer => - new NDArrayProducer { override def elementType: PType = childProducer.elementType @@ -181,34 +244,33 @@ object EmitNDArray { const(1L) } override val initAll: EmitCodeBuilder => Unit = childProducer.initAll - override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = { - indexExpr.map { childIndex => - (cb: EmitCodeBuilder) => - if (childIndex < childProducer.nDims) { - childProducer.initAxis(childIndex)(cb) - } + override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = + indexExpr.map { childIndex => (cb: EmitCodeBuilder) => + if (childIndex < childProducer.nDims) { + childProducer.initAxis(childIndex)(cb) + } } - } - override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = { - indexExpr.map { childIndex => - (cb: EmitCodeBuilder, step: Value[Long]) => - if (childIndex < childProducer.nDims) { - childProducer.stepAxis(childIndex)(cb, step) - } + override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = + indexExpr.map { childIndex => (cb: EmitCodeBuilder, step: Value[Long]) => + if (childIndex < childProducer.nDims) { + childProducer.stepAxis(childIndex)(cb, step) + } } - } override def loadElementAtCurrentAddr(cb: EmitCodeBuilder): SValue = childProducer.loadElementAtCurrentAddr(cb) } } - case x@NDArrayReshape(childND, shape, errorID) => + case x @ NDArrayReshape(childND, shape, errorID) => emitI(childND, cb).flatMap(cb) { case childND: SNDArrayValue => - // Plan: Run through the child row major, make an array. Then jump around it as needed. + /* Plan: Run through the child row major, make an array. Then jump around it as + * needed. 
*/ val childShapeValues = childND.shapes val outputNDims = x.typ.nDims - val requestedShapeValues = Array.tabulate(outputNDims)(i => cb.newLocal[Long](s"ndarray_reindex_request_shape_$i")).toIndexedSeq + val requestedShapeValues = Array.tabulate(outputNDims)(i => + cb.newLocal[Long](s"ndarray_reindex_request_shape_$i") + ).toIndexedSeq emitI(shape, cb, env = env).map(cb) { case tupleValue: SBaseStructValue => val hasNegativeOne = cb.newLocal[Boolean]("ndarray_reshape_has_neg_one") @@ -220,202 +282,289 @@ object EmitNDArray { cb.assign(runningProduct, 1L) (0 until outputNDims).foreach { i => - cb.assign(tempShapeElement, tupleValue.loadField(cb, i).get(cb, "Can't reshape if elements of reshape tuple are missing.", errorID).asLong.value) - cb.if_(tempShapeElement < 0L, - { - cb.if_(tempShapeElement ceq -1L, - { - cb.if_(hasNegativeOne, { - cb._fatalWithError(errorID, "Can't infer shape, more than one -1") - }, { - cb.assign(hasNegativeOne, true) - }) - }, - { - cb._fatalWithError(errorID,"Can't reshape, new shape must contain only nonnegative numbers or -1") - } + cb.assign( + tempShapeElement, + tupleValue.loadField(cb, i).getOrFatal( + cb, + "Can't reshape if elements of reshape tuple are missing.", + errorID, + ).asLong.value, + ) + cb.if_( + tempShapeElement < 0L, { + cb.if_( + tempShapeElement ceq -1L, + cb.if_( + hasNegativeOne, + cb._fatalWithError(errorID, "Can't infer shape, more than one -1"), + cb.assign(hasNegativeOne, true), + ), + cb._fatalWithError( + errorID, + "Can't reshape, new shape must contain only nonnegative numbers or -1", + ), ) }, - { - cb.assign(runningProduct, runningProduct * tempShapeElement) - } + cb.assign(runningProduct, runningProduct * tempShapeElement), ) } val numElements = cb.newLocal[Long]("ndarray_reshape_child_num_elements") cb.assign(numElements, SNDArray.numElements(childShapeValues)) - cb.if_(hasNegativeOne.mux( - (runningProduct ceq 0L) || (numElements % runningProduct) > 0L, - numElements cne runningProduct - ), { - cb._fatalWithError(errorID,"Can't reshape since requested shape is incompatible with number of elements") - }) - cb.assign(replacesNegativeOne, (runningProduct ceq 0L).mux(0L, numElements / runningProduct)) + cb.if_( + hasNegativeOne.mux( + (runningProduct ceq 0L) || (numElements % runningProduct) > 0L, + numElements cne runningProduct, + ), + cb._fatalWithError( + errorID, + "Can't reshape since requested shape is incompatible with number of elements", + ), + ) + cb.assign( + replacesNegativeOne, + (runningProduct ceq 0L).mux(0L, numElements / runningProduct), + ) (0 until outputNDims).foreach { i => - cb.assign(tempShapeElement, tupleValue.loadField(cb, i).get(cb, "Can't reshape if elements of reshape tuple are missing.", errorID).asLong.value) - cb.assign(requestedShapeValues(i), (tempShapeElement ceq -1L).mux(replacesNegativeOne, tempShapeElement)) + cb.assign( + tempShapeElement, + tupleValue.loadField(cb, i).getOrFatal( + cb, + "Can't reshape if elements of reshape tuple are missing.", + errorID, + ).asLong.value, + ) + cb.assign( + requestedShapeValues(i), + (tempShapeElement ceq -1L).mux(replacesNegativeOne, tempShapeElement), + ) } val childPType = childND.st.storageType().asInstanceOf[PCanonicalNDArray] val rowMajor = fromSValue(childND, cb).toSCode(cb, childPType, region, true) - // The canonical row major thing is now in the order we want. We just need to read this with the row major striding that + /* The canonical row major thing is now in the order we want. 
We just need to read + * this with the row major striding that */ // would be generated for something of the new shape. - val outputPType = PCanonicalNDArray(rowMajor.st.elementPType.setRequired(true), x.typ.nDims, true) // TODO Should it be required? + val outputPType = PCanonicalNDArray( + rowMajor.st.elementPType.setRequired(true), + x.typ.nDims, + true, + ) // TODO Should it be required? val rowMajorStriding = outputPType.makeRowMajorStrides(requestedShapeValues, cb) - fromShapeStridesFirstAddress(rowMajor.st.elementPType, requestedShapeValues, rowMajorStriding, rowMajor.firstDataAddress, cb) + fromShapeStridesFirstAddress( + rowMajor.st.elementPType, + requestedShapeValues, + rowMajorStriding, + rowMajor.firstDataAddress, + cb, + ) } } - case x@NDArrayConcat(nds, axis) => + case x @ NDArrayConcat(nds, axis) => emitI(nds, cb).flatMap(cb) { case ndsArraySValue: SIndexableValue => val arrLength = ndsArraySValue.loadLength() - cb.if_(arrLength ceq 0, { - cb._fatal("need at least one ndarray to concatenate") - }) + cb.if_(arrLength ceq 0, cb._fatal("need at least one ndarray to concatenate")) - IEmitCode(cb, ndsArraySValue.hasMissingValues(cb), { - val firstND = ndsArraySValue.loadElement(cb, 0).get(cb).asNDArray + IEmitCode( + cb, + ndsArraySValue.hasMissingValues(cb), { + val firstND = ndsArraySValue.loadElement(cb, 0).getOrAssert(cb).asNDArray - // compute array of sizes along concat axis, and total size of concat axis - val arrayLongPType = PCanonicalArray(PInt64(), true) - val (pushSize, finishSizes) = arrayLongPType.constructFromFunctions(cb, region, arrLength, false) + // compute array of sizes along concat axis, and total size of concat axis + val arrayLongPType = PCanonicalArray(PInt64(), true) + val (pushSize, finishSizes) = + arrayLongPType.constructFromFunctions(cb, region, arrLength, false) - val concatAxisSize = cb.newLocal[Long](s"ndarray_concat_axis_size", 0L) + val concatAxisSize = cb.newLocal[Long](s"ndarray_concat_axis_size", 0L) - ndsArraySValue.forEachDefined(cb) { (cb, i, nd) => - val dimLength = nd.asNDArray.shapes(axis) - pushSize(cb, EmitCode(Code._empty, false, primitive(dimLength)).toI(cb)) - cb.assign(concatAxisSize, concatAxisSize + dimLength) - } - val stagedArrayOfSizes = finishSizes(cb) - - // compute index of first input which has non-zero concat axis size - val firstNonEmpty = cb.newLocal[Int]("ndarray_concat_first_nonempty", 0) - cb.while_(stagedArrayOfSizes.loadElement(cb, firstNonEmpty).get(cb).asInt64.value.ceq(0L), { - cb.assign(firstNonEmpty, firstNonEmpty + 1) - }) - - // check that all sizes along other axes are consistent - ndsArraySValue.forEachDefined(cb) { case (cb, _, nd: SNDArrayValue) => - val mismatchedDim = cb.newLocal[Int]("ndarray_concat_mismatched_dim", -1) - val expected = cb.newLocal[Long]("ndarray_concat_expected_size") - val found = cb.newLocal[Long]("ndarray_concat_found_size") - for (i <- (0 until firstND.st.nDims).reverse if i != axis) { - cb.if_(firstND.shapes(i).cne(nd.shapes(i)), { - cb.assign(mismatchedDim, i) - cb.assign(expected, firstND.shapes(i)) - cb.assign(found, nd.shapes(i)) - }) + ndsArraySValue.forEachDefined(cb) { (cb, i, nd) => + val dimLength = nd.asNDArray.shapes(axis) + pushSize(cb, EmitCode(Code._empty, false, primitive(dimLength)).toI(cb)) + cb.assign(concatAxisSize, concatAxisSize + dimLength) } - cb.if_(mismatchedDim.cne(-1), - cb._fatal(const(s"NDArrayConcat: mismatched dimensions of input NDArrays along axis ") - .concat(mismatchedDim.toS) - .concat(": expected ").concat(expected.toS) - .concat(", got 
").concat(found.toS))) - } + val stagedArrayOfSizes = finishSizes(cb) + + // compute index of first input which has non-zero concat axis size + val firstNonEmpty = cb.newLocal[Int]("ndarray_concat_first_nonempty", 0) + cb.while_( + stagedArrayOfSizes.loadElement(cb, firstNonEmpty).getOrAssert( + cb + ).asInt64.value.ceq(0L), + cb.assign(firstNonEmpty, firstNonEmpty + 1), + ) - // compute shape of result - val newShape = (0 until x.typ.nDims).map { i => - if (i == axis) concatAxisSize else firstND.shapes(i) - } + // check that all sizes along other axes are consistent + ndsArraySValue.forEachDefined(cb) { case (cb, _, nd: SNDArrayValue) => + val mismatchedDim = cb.newLocal[Int]("ndarray_concat_mismatched_dim", -1) + val expected = cb.newLocal[Long]("ndarray_concat_expected_size") + val found = cb.newLocal[Long]("ndarray_concat_found_size") + for (i <- (0 until firstND.st.nDims).reverse if i != axis) + cb.if_( + firstND.shapes(i).cne(nd.shapes(i)), { + cb.assign(mismatchedDim, i) + cb.assign(expected, firstND.shapes(i)) + cb.assign(found, nd.shapes(i)) + }, + ) + cb.if_( + mismatchedDim.cne(-1), + cb._fatal( + const(s"NDArrayConcat: mismatched dimensions of input NDArrays along axis ") + .concat(mismatchedDim.toS) + .concat(": expected ").concat(expected.toS) + .concat(", got ").concat(found.toS) + ), + ) + } - new NDArrayProducer { - override def elementType: PType = firstND.st.elementPType + // compute shape of result + val newShape = (0 until x.typ.nDims).map { i => + if (i == axis) concatAxisSize else firstND.shapes(i) + } - override val shape: IndexedSeq[Value[Long]] = newShape + new NDArrayProducer { + override def elementType: PType = firstND.st.elementPType - val idxVars = shape.indices.map(i => cb.newLocal[Long](s"ndarray_produceer_fall_through_idx_${i}")) - // Need to keep track of the current ndarray being read from. - val currentNDArrayIdx = cb.newLocal[Int]("ndarray_concat_current_active_ndarray_idx") + override val shape: IndexedSeq[Value[Long]] = newShape - override val initAll: EmitCodeBuilder => Unit = { cb => - idxVars.foreach(idxVar => cb.assign(idxVar, 0L)) - cb.assign(currentNDArrayIdx, firstNonEmpty) - } - override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = - shape.indices.map(i => (cb: EmitCodeBuilder) => { - cb.assign(idxVars(i), 0L) - if (i == axis) { - cb.assign(currentNDArrayIdx, firstNonEmpty) - } - }) - override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = { - // For all boring axes, just add to corresponding indexVar. For the single interesting axis, - // also consider updating the currently tracked ndarray. - shape.indices.map(idx => (cb: EmitCodeBuilder, step: Value[Long]) => { - // Start by updating the idxVar by the step - val curIdxVar = idxVars(idx) - cb.assign(curIdxVar, curIdxVar + step) - if (idx == axis) { - // If bigger than current ndarray, then we need to subtract out the size of this ndarray, increment to the next ndarray, and see if we are happy yet. 
- val shouldLoop = cb.newLocal[Boolean]("should_loop", curIdxVar >= stagedArrayOfSizes.loadElement(cb, currentNDArrayIdx).get(cb).asInt64.value) - cb.while_(shouldLoop, - { - cb.assign(curIdxVar, curIdxVar - stagedArrayOfSizes.loadElement(cb, currentNDArrayIdx).get(cb).asInt64.value) - cb.assign(currentNDArrayIdx, currentNDArrayIdx + 1) - cb.if_(currentNDArrayIdx < stagedArrayOfSizes.loadLength(), { - cb.assign(shouldLoop, curIdxVar >= stagedArrayOfSizes.loadElement(cb, currentNDArrayIdx).get(cb).asInt64.value) - }, { - cb.assign(shouldLoop, false) - }) + val idxVars = shape.indices.map(i => + cb.newLocal[Long](s"ndarray_produceer_fall_through_idx_$i") + ) + // Need to keep track of the current ndarray being read from. + val currentNDArrayIdx = + cb.newLocal[Int]("ndarray_concat_current_active_ndarray_idx") + + override val initAll: EmitCodeBuilder => Unit = { cb => + idxVars.foreach(idxVar => cb.assign(idxVar, 0L)) + cb.assign(currentNDArrayIdx, firstNonEmpty) + } + override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = + shape.indices.map(i => + (cb: EmitCodeBuilder) => { + cb.assign(idxVars(i), 0L) + if (i == axis) { + cb.assign(currentNDArrayIdx, firstNonEmpty) } - ) - } - }) - } + } + ) + override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = { + /* For all boring axes, just add to corresponding indexVar. For the single + * interesting axis, */ + // also consider updating the currently tracked ndarray. + shape.indices.map(idx => + (cb: EmitCodeBuilder, step: Value[Long]) => { + // Start by updating the idxVar by the step + val curIdxVar = idxVars(idx) + cb.assign(curIdxVar, curIdxVar + step) + if (idx == axis) { + /* If bigger than current ndarray, then we need to subtract out the size + * of this ndarray, increment to the next ndarray, and see if we are + * happy yet. 
*/ + val shouldLoop = cb.newLocal[Boolean]( + "should_loop", + curIdxVar >= stagedArrayOfSizes.loadElement( + cb, + currentNDArrayIdx, + ).getOrAssert(cb).asInt64.value, + ) + cb.while_( + shouldLoop, { + cb.assign( + curIdxVar, + curIdxVar - stagedArrayOfSizes.loadElement( + cb, + currentNDArrayIdx, + ).getOrAssert(cb).asInt64.value, + ) + cb.assign(currentNDArrayIdx, currentNDArrayIdx + 1) + cb.if_( + currentNDArrayIdx < stagedArrayOfSizes.loadLength(), + cb.assign( + shouldLoop, + curIdxVar >= stagedArrayOfSizes.loadElement( + cb, + currentNDArrayIdx, + ).getOrAssert(cb).asInt64.value, + ), + cb.assign(shouldLoop, false), + ) + }, + ) + } + } + ) + } - override def loadElementAtCurrentAddr(cb: EmitCodeBuilder): SValue = { - val currentNDArray = ndsArraySValue.loadElement(cb, currentNDArrayIdx).get(cb).asNDArray - currentNDArray.loadElement(idxVars, cb) + override def loadElementAtCurrentAddr(cb: EmitCodeBuilder): SValue = { + val currentNDArray = + ndsArraySValue.loadElement(cb, currentNDArrayIdx).getOrAssert(cb).asNDArray + currentNDArray.loadElement(idxVars, cb) + } } - } - }) + }, + ) } case NDArraySlice(child, slicesIR) => deforestRecur(child, cb).flatMap(cb) { childProducer => emitI(slicesIR, cb).flatMap(cb) { case slicesValue: SBaseStructValue => - val (indexingIndices, slicingIndices) = slicesValue.st.fieldTypes.zipWithIndex.partition { case (pFieldType, idx) => - pFieldType.isPrimitive - } match { - case (a, b) => (a.map(_._2), b.map(_._2)) - } + val (indexingIndices, slicingIndices) = + slicesValue.st.fieldTypes.zipWithIndex.partition { case (pFieldType, _) => + pFieldType.isPrimitive + } match { + case (a, b) => (a.map(_._2), b.map(_._2)) + } - IEmitCode.multiFlatMap[Int, SValue, NDArrayProducer](indexingIndices, indexingIndex => slicesValue.loadField(cb, indexingIndex), cb) { indexingSCodes => - val indexingValues = indexingSCodes.map(sCode => cb.newLocal("ndarray_slice_indexer", sCode.asInt64.value)) - val slicingValueTriplesBuilder = new BoxedArrayBuilder[(Value[Long], Value[Long], Value[Long])]() + IEmitCode.multiFlatMap[Int, SValue, NDArrayProducer]( + indexingIndices, + indexingIndex => slicesValue.loadField(cb, indexingIndex), + cb, + ) { indexingSCodes => + val indexingValues = indexingSCodes.map(sCode => + cb.newLocal("ndarray_slice_indexer", sCode.asInt64.value) + ) + val slicingValueTriplesBuilder = + new BoxedArrayBuilder[(Value[Long], Value[Long], Value[Long])]() val outputShape = { - IEmitCode.multiFlatMap[Int, SValue, IndexedSeq[Value[Long]]](slicingIndices, - valueIdx => slicesValue.loadField(cb, valueIdx), cb) { sCodeSlices: IndexedSeq[SValue] => - IEmitCode.multiFlatMap[SValue, Value[Long], IndexedSeq[Value[Long]]](sCodeSlices, { case sValueSlice: SBaseStructValue => - // I know I have a tuple of three elements here, start, stop, step - - val newDimSizeI = sValueSlice.loadField(cb, 0).flatMap(cb) { startC => - sValueSlice.loadField(cb, 1).flatMap(cb) { stopC => - sValueSlice.loadField(cb, 2).map(cb) { stepC => - val start = startC.asLong.value - val stop = stopC.asLong.value - val step = stepC.asLong.value - - slicingValueTriplesBuilder.push((start, stop, step)) - - val newDimSize = cb.newLocal[Long]("new_dim_size") - cb.if_(step >= 0L && start <= stop, { - cb.assign(newDimSize, const(1L) + ((stop - start) - 1L) / step) - }, { - cb.if_(step < 0L && start >= stop, { - cb.assign(newDimSize, (((stop - start) + 1L) / step) + 1L) - }, { - cb.assign(newDimSize, 0L) - }) - }) - newDimSize + IEmitCode.multiFlatMap[Int, SValue, IndexedSeq[Value[Long]]]( + 
slicingIndices, + valueIdx => slicesValue.loadField(cb, valueIdx), + cb, + ) { sCodeSlices: IndexedSeq[SValue] => + IEmitCode.multiFlatMap[SValue, Value[Long], IndexedSeq[Value[Long]]]( + sCodeSlices, + { case sValueSlice: SBaseStructValue => + // I know I have a tuple of three elements here, start, stop, step + + val newDimSizeI = sValueSlice.loadField(cb, 0).flatMap(cb) { startC => + sValueSlice.loadField(cb, 1).flatMap(cb) { stopC => + sValueSlice.loadField(cb, 2).map(cb) { stepC => + val start = startC.asLong.value + val stop = stopC.asLong.value + val step = stepC.asLong.value + + slicingValueTriplesBuilder.push((start, stop, step)) + + val newDimSize = cb.newLocal[Long]("new_dim_size") + cb.if_( + step >= 0L && start <= stop, + cb.assign(newDimSize, const(1L) + ((stop - start) - 1L) / step), + cb.if_( + step < 0L && start >= stop, + cb.assign(newDimSize, (((stop - start) + 1L) / step) + 1L), + cb.assign(newDimSize, 0L), + ), + ) + newDimSize + } } } - } - newDimSizeI - }, cb)(x => IEmitCode(cb, false, x)) + newDimSizeI + }, + cb, + )(x => IEmitCode(cb, false, x)) } } val slicingValueTriples = slicingValueTriplesBuilder.result() @@ -435,19 +584,24 @@ object EmitNDArray { } } - override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = shape.indices.map(idx => { (cb: EmitCodeBuilder) => - val whichSlicingAxis = slicingIndices(idx) - val slicingValue = slicingValueTriples(idx) - childProducer.initAxis(whichSlicingAxis)(cb) - childProducer.stepAxis(whichSlicingAxis)(cb, slicingValue._1) - }) - override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = shape.indices.map(idx => { (cb: EmitCodeBuilder, outerStep: Value[Long]) => - // SlicingIndices is a map from my coordinates to my child's coordinates. - val whichSlicingAxis = slicingIndices(idx) - val (start, stop, sliceStep) = slicingValueTriples(idx) - val innerStep = cb.newLocal[Long]("ndarray_producer_slice_child_step", sliceStep * outerStep) - childProducer.stepAxis(whichSlicingAxis)(cb, innerStep) - }) + override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = + shape.indices.map { idx => (cb: EmitCodeBuilder) => + val whichSlicingAxis = slicingIndices(idx) + val slicingValue = slicingValueTriples(idx) + childProducer.initAxis(whichSlicingAxis)(cb) + childProducer.stepAxis(whichSlicingAxis)(cb, slicingValue._1) + } + override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = + shape.indices.map(idx => { (cb: EmitCodeBuilder, outerStep: Value[Long]) => + // SlicingIndices is a map from my coordinates to my child's coordinates. 
+ val whichSlicingAxis = slicingIndices(idx) + val (_, _, sliceStep) = slicingValueTriples(idx) + val innerStep = cb.newLocal[Long]( + "ndarray_producer_slice_child_step", + sliceStep * outerStep, + ) + childProducer.stepAxis(whichSlicingAxis)(cb, innerStep) + }) override def loadElementAtCurrentAddr(cb: EmitCodeBuilder): SValue = childProducer.loadElementAtCurrentAddr(cb) @@ -458,25 +612,28 @@ object EmitNDArray { } case NDArrayFilter(child, filters) => deforestRecur(child, cb).map(cb) { childProducer: NDArrayProducer => - - val filterWasMissing = (0 until filters.size).map(i => cb.newField[Boolean](s"ndarray_filter_${i}_was_missing")) + val filterWasMissing = (0 until filters.size).map(i => + cb.newField[Boolean](s"ndarray_filter_${i}_was_missing") + ) val filtPValues = new Array[SIndexableValue](filters.size) - val outputShape = childProducer.shape.indices.map(idx => cb.newField[Long](s"ndarray_filter_output_shapes_${idx}")) + val outputShape = childProducer.shape.indices.map(idx => + cb.newField[Long](s"ndarray_filter_output_shapes_$idx") + ) filters.zipWithIndex.foreach { case (filt, i) => - // Each filt is a sequence that may be missing with elements that may not be missing. - emitI(filt, cb).consume(cb, - { + /* Each filt is a sequence that may be missing with elements that may not be + * missing. */ + emitI(filt, cb).consume( + cb, { cb.assign(outputShape(i), childProducer.shape(i)) cb.assign(filterWasMissing(i), true) }, { - case filtArrayPValue: SIndexableValue => { + case filtArrayPValue: SIndexableValue => filtPValues(i) = filtArrayPValue cb.assign(outputShape(i), filtArrayPValue.loadLength().toL) cb.assign(filterWasMissing(i), false) - } - } + }, ) } @@ -485,41 +642,65 @@ object EmitNDArray { override val shape: IndexedSeq[Value[Long]] = outputShape - // Plan: Keep track of current indices on each axis, use them to step through filtered + /* Plan: Keep track of current indices on each axis, use them to step through + * filtered */ // dimensions accordingly. 
- val idxVars = shape.indices.map(idx => cb.newLocal[Long](s"ndarray_producer_filter_index_${idx}")) + val idxVars = + shape.indices.map(idx => cb.newLocal[Long](s"ndarray_producer_filter_index_$idx")) override val initAll: EmitCodeBuilder => Unit = cb => { idxVars.foreach(idxVar => cb.assign(idxVar, 0L)) childProducer.initAll(cb) } - override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = shape.indices.map { idx => - (cb: EmitCodeBuilder) => { - cb.assign(idxVars(idx), 0L) - childProducer.initAxis(idx)(cb) - cb.if_(filterWasMissing(idx), { - /* pass */ - }, { - val startPoint = cb.newLocal[Long]("ndarray_producer_filter_init_axis", filtPValues(idx).loadElement(cb, idxVars(idx).toI).get( - cb, s"NDArrayFilter: can't filter on missing index (axis=$idx)").asLong.value) - childProducer.stepAxis(idx)(cb, startPoint) - }) + override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = + shape.indices.map { idx => (cb: EmitCodeBuilder) => + { + cb.assign(idxVars(idx), 0L) + childProducer.initAxis(idx)(cb) + cb.if_( + filterWasMissing(idx), { + /* pass */ + }, { + val startPoint = cb.newLocal[Long]( + "ndarray_producer_filter_init_axis", + filtPValues(idx).loadElement(cb, idxVars(idx).toI).getOrFatal( + cb, + s"NDArrayFilter: can't filter on missing index (axis=$idx)", + ).asLong.value, + ) + childProducer.stepAxis(idx)(cb, startPoint) + }, + ) + } } - } - override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = shape.indices.map { idx => - (cb: EmitCodeBuilder, step: Value[Long]) => { - cb.if_(filterWasMissing(idx), { - childProducer.stepAxis(idx)(cb, step) - cb.assign(idxVars(idx), idxVars(idx) + step) - }, { - val currentPos = filtPValues(idx).loadElement(cb, idxVars(idx).toI).get(cb, s"NDArrayFilter: can't filter on missing index (axis=$idx)").asLong.value - cb.assign(idxVars(idx), idxVars(idx) + step) - val newPos = filtPValues(idx).loadElement(cb, idxVars(idx).toI).get(cb, s"NDArrayFilter: can't filter on missing index (axis=$idx)").asLong.value - val stepSize = cb.newLocal[Long]("ndarray_producer_filter_step_size", newPos - currentPos) - childProducer.stepAxis(idx)(cb, stepSize) - }) + override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = + shape.indices.map { idx => (cb: EmitCodeBuilder, step: Value[Long]) => + { + cb.if_( + filterWasMissing(idx), { + childProducer.stepAxis(idx)(cb, step) + cb.assign(idxVars(idx), idxVars(idx) + step) + }, { + val currentPos = + filtPValues(idx).loadElement(cb, idxVars(idx).toI).getOrFatal( + cb, + s"NDArrayFilter: can't filter on missing index (axis=$idx)", + ).asLong.value + cb.assign(idxVars(idx), idxVars(idx) + step) + val newPos = + filtPValues(idx).loadElement(cb, idxVars(idx).toI).getOrFatal( + cb, + s"NDArrayFilter: can't filter on missing index (axis=$idx)", + ).asLong.value + val stepSize = cb.newLocal[Long]( + "ndarray_producer_filter_step_size", + newPos - currentPos, + ) + childProducer.stepAxis(idx)(cb, stepSize) + }, + ) + } } - } override def loadElementAtCurrentAddr(cb: EmitCodeBuilder): SValue = childProducer.loadElementAtCurrentAddr(cb) @@ -544,38 +725,48 @@ object EmitNDArray { override val shape: IndexedSeq[Value[Long]] = newOutputShape override val initAll: EmitCodeBuilder => Unit = childProducer.initAll - // Important part here is that NDArrayAgg has less axes then its child. We need to map + /* Important part here is that NDArrayAgg has less axes then its child. We need to + * map */ // between them. 
- override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = { + override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = axesToKeep.map(idx => childProducer.initAxis(idx)) - } - override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = { + override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = axesToKeep.map(idx => childProducer.stepAxis(idx)) - } override def loadElementAtCurrentAddr(cb: EmitCodeBuilder): SValue = { - // Idea: For each axis that is being summed over, step through and keep a running sum. + /* Idea: For each axis that is being summed over, step through and keep a running + * sum. */ val numericElementType = elementType.asInstanceOf[PNumeric] - val runningSum = NumericPrimitives.newLocal(cb, "ndarray_agg_running_sum", numericElementType.virtualType) + val runningSum = NumericPrimitives.newLocal( + cb, + "ndarray_agg_running_sum", + numericElementType.virtualType, + ) cb.assign(runningSum, numericElementType.zero) val initsToSumOut = axesToSumOut.map(idx => childProducer.initAxis(idx)) - val stepsToSumOut = axesToSumOut.map(idx => (cb: EmitCodeBuilder) => childProducer.stepAxis(idx)(cb, 1L)) + val stepsToSumOut = axesToSumOut.map(idx => + (cb: EmitCodeBuilder) => childProducer.stepAxis(idx)(cb, 1L) + ) - SNDArray.forEachIndexWithInitAndIncColMajor(cb, newOutputShapeComplement, initsToSumOut, stepsToSumOut, "ndarray_producer_ndarray_agg") { (cb, _) => - cb.assign(runningSum, numericElementType.add(runningSum, SType.extractPrimValue(cb, childProducer.loadElementAtCurrentAddr(cb)))) + SNDArray.forEachIndexWithInitAndIncColMajor(cb, newOutputShapeComplement, + initsToSumOut, stepsToSumOut, "ndarray_producer_ndarray_agg") { (cb, _) => + cb.assign( + runningSum, + numericElementType.add( + runningSum, + SType.extractPrimValue(cb, childProducer.loadElementAtCurrentAddr(cb)), + ), + ) } primitive(numericElementType.virtualType, runningSum) } } } - case _ => { + case _ => val ndI = emitI(x, cb) - ndI.map(cb) { ndPv => - fromSValue(ndPv.asNDArray, cb) - } - } + ndI.map(cb)(ndPv => fromSValue(ndPv.asNDArray, cb)) } } @@ -589,54 +780,68 @@ object EmitNDArray { val ndSvShape = ndSv.shapes val strides = ndSv.strides - fromShapeStridesFirstAddress(ndSv.st.elementPType, ndSvShape, strides, ndSv.firstDataAddress, cb) + fromShapeStridesFirstAddress( + ndSv.st.elementPType, + ndSvShape, + strides, + ndSv.firstDataAddress, + cb, + ) } - def fromShapeStridesFirstAddress(newElementType: PType, ndSvShape: IndexedSeq[Value[Long]], strides: IndexedSeq[Value[Long]], firstDataAddress: Value[Long], cb: EmitCodeBuilder): NDArrayProducer = { - val counters = ndSvShape.indices.map(i => cb.newLocal[Long](s"ndarray_producer_fall_through_idx_${i}")) + def fromShapeStridesFirstAddress( + newElementType: PType, + ndSvShape: IndexedSeq[Value[Long]], + strides: IndexedSeq[Value[Long]], + firstDataAddress: Value[Long], + cb: EmitCodeBuilder, + ): NDArrayProducer = { + val counters = + ndSvShape.indices.map(i => cb.newLocal[Long](s"ndarray_producer_fall_through_idx_$i")) - assert(ndSvShape.size == strides.size, s"shape.size = ${ndSvShape.size} != strides.size = ${strides.size}") + assert( + ndSvShape.size == strides.size, + s"shape.size = ${ndSvShape.size} != strides.size = ${strides.size}", + ) new NDArrayProducer { override def elementType: PType = newElementType override val shape: IndexedSeq[Value[Long]] = ndSvShape - override val initAll: EmitCodeBuilder => Unit = cb => { + override val initAll: EmitCodeBuilder => Unit = cb => counters.foreach(ctr => cb.assign(ctr, 
0L)) - } - override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = { - shape.indices.map(i => (cb: EmitCodeBuilder) => { - cb.assign(counters(i), 0L) - }) - } - override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = { - shape.indices.map{ i => - (cb: EmitCodeBuilder, step: Value[Long]) => { - cb.assign(counters(i), counters(i) + step * strides(i)) - } + override val initAxis: IndexedSeq[EmitCodeBuilder => Unit] = + shape.indices.map(i => + (cb: EmitCodeBuilder) => + cb.assign(counters(i), 0L) + ) + override val stepAxis: IndexedSeq[(EmitCodeBuilder, Value[Long]) => Unit] = + shape.indices.map { i => (cb: EmitCodeBuilder, step: Value[Long]) => + cb.assign(counters(i), counters(i) + step * strides(i)) } - } override def loadElementAtCurrentAddr(cb: EmitCodeBuilder): SValue = { - val offset = counters.foldLeft[Code[Long]](const(0L)){ (a, b) => a + b} + val offset = counters.foldLeft[Code[Long]](const(0L))((a, b) => a + b) val loaded = elementType.loadCheapSCode(cb, firstDataAddress + offset) loaded } } } - def createBroadcastMask(cb: EmitCodeBuilder, shape: IndexedSeq[Value[Long]]): IndexedSeq[Value[Long]] = { - val ffff = 0xFFFFFFFFFFFFFFFFL + def createBroadcastMask(cb: EmitCodeBuilder, shape: IndexedSeq[Value[Long]]) + : IndexedSeq[Value[Long]] = { + val ffff = 0xffffffffffffffffL shape.indices.map { idx => - cb.newLocal[Long](s"ndarray_producer_broadcast_mask_${idx}", (shape(idx) ceq 1L).mux(0L, ffff)) + cb.newLocal[Long](s"ndarray_producer_broadcast_mask_$idx", (shape(idx) ceq 1L).mux(0L, ffff)) } } - def broadcast(cb: EmitCodeBuilder, prod: NDArrayProducer,ctx: String): NDArrayProducer = { + def broadcast(cb: EmitCodeBuilder, prod: NDArrayProducer, ctx: String): NDArrayProducer = { val broadcastMask = createBroadcastMask(cb, prod.shape) - val newSteps = prod.stepAxis.indices.map { idx => - (cb: EmitCodeBuilder, step: Value[Long]) => { - val maskedStep = cb.newLocal[Long]("ndarray_producer_masked_step", step & broadcastMask(idx)) + val newSteps = prod.stepAxis.indices.map { idx => (cb: EmitCodeBuilder, step: Value[Long]) => + { + val maskedStep = + cb.newLocal[Long]("ndarray_producer_masked_step", step & broadcastMask(idx)) prod.stepAxis(idx)(cb, maskedStep) } } diff --git a/hail/src/main/scala/is/hail/expr/ir/orderings/BinaryOrdering.scala b/hail/src/main/scala/is/hail/expr/ir/orderings/BinaryOrdering.scala index d685cd8a2c5..6436b8f67d7 100644 --- a/hail/src/main/scala/is/hail/expr/ir/orderings/BinaryOrdering.scala +++ b/hail/src/main/scala/is/hail/expr/ir/orderings/BinaryOrdering.scala @@ -23,18 +23,28 @@ object BinaryOrdering { val cmp = cb.newLocal[Int]("cmp", 0) val Lbreak = CodeLabel() - cb.for_(cb.assign(i, 0), i < lim, cb.assign(i, i + 1), { - val compval = Code.invokeStatic2[java.lang.Integer, Int, Int, Int]("compare", - Code.invokeStatic1[java.lang.Byte, Byte, Int]("toUnsignedInt", xv.loadByte(cb, i)), - Code.invokeStatic1[java.lang.Byte, Byte, Int]("toUnsignedInt", yv.loadByte(cb, i))) - cb.assign(cmp, compval) - cb.if_(cmp.cne(0), cb.goto(Lbreak)) - }) + cb.for_( + cb.assign(i, 0), + i < lim, + cb.assign(i, i + 1), { + val compval = Code.invokeStatic2[java.lang.Integer, Int, Int, Int]( + "compare", + Code.invokeStatic1[java.lang.Byte, Byte, Int]("toUnsignedInt", xv.loadByte(cb, i)), + Code.invokeStatic1[java.lang.Byte, Byte, Int]("toUnsignedInt", yv.loadByte(cb, i)), + ) + cb.assign(cmp, compval) + cb.if_(cmp.cne(0), cb.goto(Lbreak)) + }, + ) cb.define(Lbreak) - cb.if_(cmp.ceq(0), { - cb.assign(cmp, Code.invokeStatic2[java.lang.Integer, Int, Int, 
Int]("compare", xlen, ylen)) - }) + cb.if_( + cmp.ceq(0), + cb.assign( + cmp, + Code.invokeStatic2[java.lang.Integer, Int, Int, Int]("compare", xlen, ylen), + ), + ) cmp } diff --git a/hail/src/main/scala/is/hail/expr/ir/orderings/CallOrdering.scala b/hail/src/main/scala/is/hail/expr/ir/orderings/CallOrdering.scala index 6d74a278e63..4a0e1a02c52 100644 --- a/hail/src/main/scala/is/hail/expr/ir/orderings/CallOrdering.scala +++ b/hail/src/main/scala/is/hail/expr/ir/orderings/CallOrdering.scala @@ -2,8 +2,8 @@ package is.hail.expr.ir.orderings import is.hail.asm4s.{Code, Value} import is.hail.expr.ir.{EmitClassBuilder, EmitCodeBuilder} -import is.hail.types.physical.stypes.interfaces.SCall import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.interfaces.SCall object CallOrdering { def make(t1: SCall, t2: SCall, ecb: EmitClassBuilder[_]): CodeOrdering = { @@ -14,10 +14,12 @@ object CallOrdering { override val type1: SType = t1 override val type2: SType = t2 - override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = { - cb.memoize(Code.invokeStatic2[java.lang.Integer, Int, Int, Int]("compare", - x.asCall.canonicalCall(cb), y.asCall.canonicalCall(cb))) - } + override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = + cb.memoize(Code.invokeStatic2[java.lang.Integer, Int, Int, Int]( + "compare", + x.asCall.canonicalCall(cb), + y.asCall.canonicalCall(cb), + )) } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/orderings/CodeOrdering.scala b/hail/src/main/scala/is/hail/expr/ir/orderings/CodeOrdering.scala index 8e2a6a41051..25dba617b20 100644 --- a/hail/src/main/scala/is/hail/expr/ir/orderings/CodeOrdering.scala +++ b/hail/src/main/scala/is/hail/expr/ir/orderings/CodeOrdering.scala @@ -2,9 +2,9 @@ package is.hail.expr.ir.orderings import is.hail.asm4s._ import is.hail.expr.ir.{EmitClassBuilder, EmitCodeBuilder, EmitValue} +import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives._ -import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.virtual._ import is.hail.utils.FastSeq @@ -52,25 +52,34 @@ object CodeOrdering { def makeOrdering(t1: SType, t2: SType, ecb: EmitClassBuilder[_]): CodeOrdering = { val canCompare = (t1.virtualType, t2.virtualType) match { - case (t1: TStruct, t2: TStruct) => t1.isIsomorphicTo(t2) + case (t1: TStruct, t2: TStruct) => t1.isJoinableWith(t2) case (t1, t2) if t1 == t2 => t1 == t2 } if (!canCompare) { - throw new RuntimeException(s"ordering: type mismatch:\n left: ${ t1.virtualType }\n right: ${ t2.virtualType }") + throw new RuntimeException( + s"ordering: type mismatch:\n left: ${t1.virtualType}\n right: ${t2.virtualType}" + ) } t1.virtualType match { - case TInt32 => Int32Ordering.make(t1.asInstanceOf[SInt32.type], t2.asInstanceOf[SInt32.type], ecb) - case TInt64 => Int64Ordering.make(t1.asInstanceOf[SInt64.type], t2.asInstanceOf[SInt64.type], ecb) - case TFloat32 => Float32Ordering.make(t1.asInstanceOf[SFloat32.type], t2.asInstanceOf[SFloat32.type], ecb) - case TFloat64 => Float64Ordering.make(t1.asInstanceOf[SFloat64.type], t2.asInstanceOf[SFloat64.type], ecb) - case TBoolean => BooleanOrdering.make(t1.asInstanceOf[SBoolean.type], t2.asInstanceOf[SBoolean.type], ecb) + case TInt32 => + Int32Ordering.make(t1.asInstanceOf[SInt32.type], t2.asInstanceOf[SInt32.type], ecb) + case TInt64 => + Int64Ordering.make(t1.asInstanceOf[SInt64.type], 
t2.asInstanceOf[SInt64.type], ecb) + case TFloat32 => + Float32Ordering.make(t1.asInstanceOf[SFloat32.type], t2.asInstanceOf[SFloat32.type], ecb) + case TFloat64 => + Float64Ordering.make(t1.asInstanceOf[SFloat64.type], t2.asInstanceOf[SFloat64.type], ecb) + case TBoolean => + BooleanOrdering.make(t1.asInstanceOf[SBoolean.type], t2.asInstanceOf[SBoolean.type], ecb) case TCall => CallOrdering.make(t1.asInstanceOf[SCall], t2.asInstanceOf[SCall], ecb) case TString => StringOrdering.make(t1.asInstanceOf[SString], t2.asInstanceOf[SString], ecb) case TBinary => BinaryOrdering.make(t1.asInstanceOf[SBinary], t2.asInstanceOf[SBinary], ecb) - case _: TBaseStruct => StructOrdering.make(t1.asInstanceOf[SBaseStruct], t2.asInstanceOf[SBaseStruct], ecb) + case _: TBaseStruct => + StructOrdering.make(t1.asInstanceOf[SBaseStruct], t2.asInstanceOf[SBaseStruct], ecb) case _: TLocus => LocusOrdering.make(t1.asInstanceOf[SLocus], t2.asInstanceOf[SLocus], ecb) - case _: TInterval => IntervalOrdering.make(t1.asInstanceOf[SInterval], t2.asInstanceOf[SInterval], ecb) + case _: TInterval => + IntervalOrdering.make(t1.asInstanceOf[SInterval], t2.asInstanceOf[SInterval], ecb) case _: TSet | _: TArray | _: TDict => IterableOrdering.make(t1.asInstanceOf[SContainer], t2.asInstanceOf[SContainer], ecb) } @@ -85,94 +94,114 @@ abstract class CodeOrdering { def reversed: Boolean = false - final def checkedSCode[T](cb: EmitCodeBuilder, arg1: SValue, arg2: SValue, context: String, - f: (EmitCodeBuilder, SValue, SValue) => Value[T])(implicit ti: TypeInfo[T]): Value[T] = { + final def checkedSCode[T]( + cb: EmitCodeBuilder, + arg1: SValue, + arg2: SValue, + context: String, + f: (EmitCodeBuilder, SValue, SValue) => Value[T], + )(implicit ti: TypeInfo[T] + ): Value[T] = { if (arg1.st != type1) - throw new RuntimeException(s"CodeOrdering: $context: type mismatch (left)\n generated: $type1\n argument: ${ arg1.st }") + throw new RuntimeException( + s"CodeOrdering: $context: type mismatch (left)\n generated: $type1\n argument: ${arg1.st}" + ) if (arg2.st != type2) - throw new RuntimeException(s"CodeOrdering: $context: type mismatch (right)\n generated: $type2\n argument: ${ arg2.st }") + throw new RuntimeException( + s"CodeOrdering: $context: type mismatch (right)\n generated: $type2\n argument: ${arg2.st}" + ) val cacheKey = ("ordering", reversed, type1, type2, context) - val mb = cb.emb.ecb.getOrGenEmitMethod(s"ord_$context", cacheKey, - FastSeq(arg1.st.paramType, arg2.st.paramType), ti) { mb => - + val mb = cb.emb.ecb.getOrGenEmitMethod( + s"ord_$context", + cacheKey, + FastSeq(arg1.st.paramType, arg2.st.paramType), + ti, + ) { mb => mb.emitWithBuilder[T] { cb => val arg1 = mb.getSCodeParam(1) val arg2 = mb.getSCodeParam(2) f(cb, arg1, arg2) } } - cb.memoize(cb.invokeCode[T](mb, arg1, arg2)) + cb.invokeCode[T](mb, cb.this_, arg1, arg2) } - final def checkedEmitCode[T](cb: EmitCodeBuilder, arg1: EmitValue, arg2: EmitValue, missingEqual: Boolean, context: String, - f: (EmitCodeBuilder, EmitValue, EmitValue, Boolean) => Value[T])(implicit ti: TypeInfo[T]): Value[T] = { + final def checkedEmitCode[T]( + cb: EmitCodeBuilder, + arg1: EmitValue, + arg2: EmitValue, + missingEqual: Boolean, + context: String, + f: (EmitCodeBuilder, EmitValue, EmitValue, Boolean) => Value[T], + )(implicit ti: TypeInfo[T] + ): Value[T] = { if (arg1.st != type1) - throw new RuntimeException(s"CodeOrdering: $context: type mismatch (left)\n generated: $type1\n argument: ${ arg1.st }") + throw new RuntimeException( + s"CodeOrdering: $context: type mismatch 
(left)\n generated: $type1\n argument: ${arg1.st}" + ) if (arg2.st != type2) - throw new RuntimeException(s"CodeOrdering: $context: type mismatch (right)\n generated: $type2\n argument: ${ arg2.st }") + throw new RuntimeException( + s"CodeOrdering: $context: type mismatch (right)\n generated: $type2\n argument: ${arg2.st}" + ) val cacheKey = ("ordering", reversed, arg1.emitType, arg2.emitType, context, missingEqual) - val mb = cb.emb.ecb.getOrGenEmitMethod(s"ord_$context", cacheKey, - FastSeq(arg1.emitParamType, arg2.emitParamType), ti) { mb => - + val mb = cb.emb.ecb.getOrGenEmitMethod( + s"ord_$context", + cacheKey, + FastSeq(arg1.emitParamType, arg2.emitParamType), + ti, + ) { mb => mb.emitWithBuilder[T] { cb => val arg1 = mb.getEmitParam(cb, 1) val arg2 = mb.getEmitParam(cb, 2) f(cb, arg1, arg2, missingEqual) } } - cb.memoize(cb.invokeCode[T](mb, arg1, arg2)) + cb.invokeCode[T](mb, cb.this_, arg1, arg2) } - - final def compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = { + final def compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = checkedSCode(cb, x, y, "compareNonnull", _compareNonnull) - } - final def ltNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + final def ltNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = checkedSCode(cb, x, y, "ltNonnull", _ltNonnull) - } - final def lteqNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + final def lteqNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = checkedSCode(cb, x, y, "lteqNonnull", _lteqNonnull) - } - final def gtNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + final def gtNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = checkedSCode(cb, x, y, "gtNonnull", _gtNonnull) - } - final def gteqNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + final def gteqNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = checkedSCode(cb, x, y, "gteqNonnull", _gteqNonnull) - } - final def equivNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + final def equivNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = checkedSCode(cb, x, y, "equivNonnull", _equivNonnull) - } - final def lt(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean): Value[Boolean] = { + final def lt(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean) + : Value[Boolean] = checkedEmitCode(cb, x, y, missingEqual, "lt", _lt) - } - final def lteq(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean): Value[Boolean] = { + final def lteq(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean) + : Value[Boolean] = checkedEmitCode(cb, x, y, missingEqual, "lteq", _lteq) - } - final def gt(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean): Value[Boolean] = { + final def gt(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean) + : Value[Boolean] = checkedEmitCode(cb, x, y, missingEqual, "gt", _gt) - } - final def gteq(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean): Value[Boolean] = { + final def gteq(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean) + : Value[Boolean] = checkedEmitCode(cb, x, y, missingEqual, "gteq", _gteq) - } - final def equiv(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean): Value[Boolean] = { + final def equiv(cb: EmitCodeBuilder, x: EmitValue, y: 
EmitValue, missingEqual: Boolean) + : Value[Boolean] = checkedEmitCode(cb, x, y, missingEqual, "equiv", _equiv) - } - final def compare(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean): Value[Int] = { + final def compare(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean) + : Value[Int] = checkedEmitCode(cb, x, y, missingEqual, "compare", _compare) - } def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] @@ -186,78 +215,86 @@ abstract class CodeOrdering { def _equivNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] - def _compare(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean = true): Value[Int] = { + def _compare(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean = true) + : Value[Int] = { val cmp = cb.newLocal[Int]("cmp") - cb.if_(x.m, + cb.if_( + x.m, cb.if_(y.m, cb.assign(cmp, if (missingEqual) 0 else -1), cb.assign(cmp, 1)), - cb.if_(y.m, cb.assign(cmp, -1), cb.assign(cmp, compareNonnull(cb, x.v, y.v)))) + cb.if_(y.m, cb.assign(cmp, -1), cb.assign(cmp, compareNonnull(cb, x.v, y.v))), + ) cmp } - def _lt(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean): Value[Boolean] = { + def _lt(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean) + : Value[Boolean] = { val ret = cb.newLocal[Boolean]("lt") if (missingEqual) { - cb.if_(x.m, + cb.if_( + x.m, cb.assign(ret, false), - cb.if_(y.m, - cb.assign(ret, true), - cb.assign(ret, ltNonnull(cb, x.v, y.v)))) + cb.if_(y.m, cb.assign(ret, true), cb.assign(ret, ltNonnull(cb, x.v, y.v))), + ) } else { - cb.if_(y.m, + cb.if_( + y.m, cb.assign(ret, true), - cb.if_(x.m, - cb.assign(ret, false), - cb.assign(ret, ltNonnull(cb, x.v, y.v)))) + cb.if_(x.m, cb.assign(ret, false), cb.assign(ret, ltNonnull(cb, x.v, y.v))), + ) } ret } - def _lteq(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean): Value[Boolean] = { + def _lteq(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean) + : Value[Boolean] = { val ret = cb.newLocal[Boolean]("lteq") - cb.if_(y.m, + cb.if_( + y.m, cb.assign(ret, true), - cb.if_(x.m, - cb.assign(ret, false), - cb.assign(ret, lteqNonnull(cb, x.v, y.v)))) + cb.if_(x.m, cb.assign(ret, false), cb.assign(ret, lteqNonnull(cb, x.v, y.v))), + ) ret } - def _gt(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean): Value[Boolean] = { + def _gt(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean) + : Value[Boolean] = { val ret = cb.newLocal[Boolean]("gt") - cb.if_(y.m, + cb.if_( + y.m, cb.assign(ret, false), - cb.if_(x.m, - cb.assign(ret, true), - cb.assign(ret, gtNonnull(cb, x.v, y.v)))) + cb.if_(x.m, cb.assign(ret, true), cb.assign(ret, gtNonnull(cb, x.v, y.v))), + ) ret } - def _gteq(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean): Value[Boolean] = { + def _gteq(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean) + : Value[Boolean] = { val ret = cb.newLocal[Boolean]("gteq") if (missingEqual) { - cb.if_(x.m, + cb.if_( + x.m, cb.assign(ret, true), - cb.if_(y.m, - cb.assign(ret, false), - cb.assign(ret, gteqNonnull(cb, x.v, y.v)))) + cb.if_(y.m, cb.assign(ret, false), cb.assign(ret, gteqNonnull(cb, x.v, y.v))), + ) } else { - cb.if_(y.m, + cb.if_( + y.m, cb.assign(ret, false), - cb.if_(x.m, - cb.assign(ret, true), - cb.assign(ret, gteqNonnull(cb, x.v, y.v)))) + cb.if_(x.m, cb.assign(ret, true), cb.assign(ret, gteqNonnull(cb, x.v, y.v))), + ) } ret } - def _equiv(cb: 
EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean): Value[Boolean] = { + def _equiv(cb: EmitCodeBuilder, x: EmitValue, y: EmitValue, missingEqual: Boolean) + : Value[Boolean] = { val ret = cb.newLocal[Boolean]("eq") if (missingEqual) { - cb.if_(x.m && y.m, + cb.if_( + x.m && y.m, cb.assign(ret, true), - cb.if_(!x.m && !y.m, - cb.assign(ret, equivNonnull(cb, x.v, y.v)), - cb.assign(ret, false))) + cb.if_(!x.m && !y.m, cb.assign(ret, equivNonnull(cb, x.v, y.v)), cb.assign(ret, false)), + ) } else { cb.if_(!x.m && !y.m, cb.assign(ret, equivNonnull(cb, x.v, y.v)), cb.assign(ret, false)) } diff --git a/hail/src/main/scala/is/hail/expr/ir/orderings/IntervalOrdering.scala b/hail/src/main/scala/is/hail/expr/ir/orderings/IntervalOrdering.scala index 36e710854c8..04b545e64da 100644 --- a/hail/src/main/scala/is/hail/expr/ir/orderings/IntervalOrdering.scala +++ b/hail/src/main/scala/is/hail/expr/ir/orderings/IntervalOrdering.scala @@ -7,189 +7,197 @@ import is.hail.types.physical.stypes.interfaces.SInterval object IntervalOrdering { - def make(t1: SInterval, t2: SInterval, ecb: EmitClassBuilder[_]): CodeOrdering = new CodeOrdering { - - override val type1: SInterval = t1 - override val type2: SInterval = t2 - - override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = { - val pointCompare = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Compare()) - val cmp = cb.newLocal[Int]("intervalord_cmp", 0) - - val lhs = x.asInterval - val rhs = y.asInterval - val lstart = cb.memoize(lhs.loadStart(cb)) - val rstart = cb.memoize(rhs.loadStart(cb)) - cb.assign(cmp, pointCompare(cb, lstart, rstart)) - cb.if_(cmp.ceq(0), { - cb.if_(lhs.includesStart.cne(rhs.includesStart), { - cb.assign(cmp, lhs.includesStart.mux(-1, 1)) - }, { - val lend = cb.memoize(lhs.loadEnd(cb)) - val rend = cb.memoize(rhs.loadEnd(cb)) - cb.assign(cmp, pointCompare(cb, lend, rend)) - cb.if_(cmp.ceq(0), { - cb.if_(lhs.includesEnd.cne(rhs.includesEnd), { - cb.assign(cmp, lhs.includesEnd.mux(1, -1)) - }) - }) - }) - }) - - cmp - } - - override def _equivNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { - val pointEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Equiv()) - - val Lout = CodeLabel() - val ret = cb.newLocal[Boolean]("interval_eq", true) - val exitWith = (value: Code[Boolean]) => { - cb.assign(ret, value) - cb.goto(Lout) + def make(t1: SInterval, t2: SInterval, ecb: EmitClassBuilder[_]): CodeOrdering = + new CodeOrdering { + + override val type1: SInterval = t1 + override val type2: SInterval = t2 + + override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = { + val pointCompare = + ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Compare()) + val cmp = cb.newLocal[Int]("intervalord_cmp", 0) + + val lhs = x.asInterval + val rhs = y.asInterval + val lstart = cb.memoize(lhs.loadStart(cb)) + val rstart = cb.memoize(rhs.loadStart(cb)) + cb.assign(cmp, pointCompare(cb, lstart, rstart)) + cb.if_( + cmp.ceq(0), { + cb.if_( + lhs.includesStart.cne(rhs.includesStart), + cb.assign(cmp, lhs.includesStart.mux(-1, 1)), { + val lend = cb.memoize(lhs.loadEnd(cb)) + val rend = cb.memoize(rhs.loadEnd(cb)) + cb.assign(cmp, pointCompare(cb, lend, rend)) + cb.if_( + cmp.ceq(0), + cb.if_( + lhs.includesEnd.cne(rhs.includesEnd), + cb.assign(cmp, lhs.includesEnd.mux(1, -1)), + ), + ) + }, + ) + }, + ) + + cmp } - val lhs = x.asInterval - val rhs = y.asInterval + override def _equivNonnull(cb: EmitCodeBuilder, 
x: SValue, y: SValue): Value[Boolean] = { + val pointEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Equiv()) - cb.if_(lhs.includesStart.cne(rhs.includesStart) || - lhs.includesEnd.cne(rhs.includesEnd), { - exitWith(false) - }) + val Lout = CodeLabel() + val ret = cb.newLocal[Boolean]("interval_eq", true) + val exitWith = (value: Code[Boolean]) => { + cb.assign(ret, value) + cb.goto(Lout) + } - val lstart = cb.memoize(lhs.loadStart(cb)) - val rstart = cb.memoize(rhs.loadStart(cb)) - cb.if_(!pointEq(cb, lstart, rstart), exitWith(false)) + val lhs = x.asInterval + val rhs = y.asInterval - val lend = cb.memoize(lhs.loadEnd(cb)) - val rend = cb.memoize(rhs.loadEnd(cb)) - cb.if_(!pointEq(cb, lend, rend), exitWith(false)) + cb.if_( + lhs.includesStart.cne(rhs.includesStart) || + lhs.includesEnd.cne(rhs.includesEnd), + exitWith(false), + ) - cb.define(Lout) - ret - } + val lstart = cb.memoize(lhs.loadStart(cb)) + val rstart = cb.memoize(rhs.loadStart(cb)) + cb.if_(!pointEq(cb, lstart, rstart), exitWith(false)) - override def _ltNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { - val pointLt = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Lt()) - val pointEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Equiv()) + val lend = cb.memoize(lhs.loadEnd(cb)) + val rend = cb.memoize(rhs.loadEnd(cb)) + cb.if_(!pointEq(cb, lend, rend), exitWith(false)) - val Lout = CodeLabel() - val ret = cb.newLocal[Boolean]("interval_lt") - val exitWith = (value: Code[Boolean]) => { - cb.assign(ret, value) - cb.goto(Lout) + cb.define(Lout) + ret } - val lhs = x.asInterval - val rhs = y.asInterval - val lstart = cb.memoize(lhs.loadStart(cb)) - val rstart = cb.memoize(rhs.loadStart(cb)) + override def _ltNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + val pointLt = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Lt()) + val pointEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Equiv()) - cb.if_(pointLt(cb, lstart, rstart), exitWith(true)) - cb.if_(!pointEq(cb, lstart, rstart), exitWith(false)) - cb.if_(lhs.includesStart && !rhs.includesStart, exitWith(true)) - cb.if_(lhs.includesStart.cne(rhs.includesStart), exitWith(false)) + val Lout = CodeLabel() + val ret = cb.newLocal[Boolean]("interval_lt") + val exitWith = (value: Code[Boolean]) => { + cb.assign(ret, value) + cb.goto(Lout) + } - val lend = cb.memoize(lhs.loadEnd(cb)) - val rend = cb.memoize(rhs.loadEnd(cb)) + val lhs = x.asInterval + val rhs = y.asInterval + val lstart = cb.memoize(lhs.loadStart(cb)) + val rstart = cb.memoize(rhs.loadStart(cb)) - cb.if_(pointLt(cb, lend, rend), exitWith(true)) - cb.assign(ret, pointEq(cb, lend, rend) && !lhs.includesEnd && rhs.includesEnd) + cb.if_(pointLt(cb, lstart, rstart), exitWith(true)) + cb.if_(!pointEq(cb, lstart, rstart), exitWith(false)) + cb.if_(lhs.includesStart && !rhs.includesStart, exitWith(true)) + cb.if_(lhs.includesStart.cne(rhs.includesStart), exitWith(false)) - cb.define(Lout) - ret - } + val lend = cb.memoize(lhs.loadEnd(cb)) + val rend = cb.memoize(rhs.loadEnd(cb)) - override def _lteqNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { - val pointLtEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Lteq()) - val pointEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Equiv()) + cb.if_(pointLt(cb, lend, rend), exitWith(true)) + cb.assign(ret, pointEq(cb, lend, rend) && !lhs.includesEnd && rhs.includesEnd) - val 
Lout = CodeLabel() - val ret = cb.newLocal[Boolean]("interval_lteq") - val exitWith = (value: Code[Boolean]) => { - cb.assign(ret, value) - cb.goto(Lout) + cb.define(Lout) + ret } - val lhs = x.asInterval - val rhs = y.asInterval - val lstart = cb.memoize(lhs.loadStart(cb)) - val rstart = cb.memoize(rhs.loadStart(cb)) - - cb.if_(!pointLtEq(cb, lstart, rstart), exitWith(false)) - cb.if_(!pointEq(cb, lstart, rstart), exitWith(true)) - cb.if_(lhs.includesStart && !rhs.includesStart, exitWith(true)) - cb.if_(lhs.includesStart.cne(rhs.includesStart), exitWith(false)) - - val lend = cb.memoize(lhs.loadEnd(cb)) - val rend = cb.memoize(rhs.loadEnd(cb)) - cb.if_(!pointLtEq(cb, lend, rend), exitWith(false)) - cb.assign(ret, !pointEq(cb, lend, rend) || !lhs.includesEnd || rhs.includesEnd) - - cb.define(Lout) - ret - } - - override def _gtNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { - val pointGt = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Gt()) - val pointEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Equiv()) - - val Lout = CodeLabel() - val ret = cb.newLocal[Boolean]("interval_gt") - val exitWith = (value: Code[Boolean]) => { - cb.assign(ret, value) - cb.goto(Lout) + override def _lteqNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + val pointLtEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Lteq()) + val pointEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Equiv()) + + val Lout = CodeLabel() + val ret = cb.newLocal[Boolean]("interval_lteq") + val exitWith = (value: Code[Boolean]) => { + cb.assign(ret, value) + cb.goto(Lout) + } + + val lhs = x.asInterval + val rhs = y.asInterval + val lstart = cb.memoize(lhs.loadStart(cb)) + val rstart = cb.memoize(rhs.loadStart(cb)) + + cb.if_(!pointLtEq(cb, lstart, rstart), exitWith(false)) + cb.if_(!pointEq(cb, lstart, rstart), exitWith(true)) + cb.if_(lhs.includesStart && !rhs.includesStart, exitWith(true)) + cb.if_(lhs.includesStart.cne(rhs.includesStart), exitWith(false)) + + val lend = cb.memoize(lhs.loadEnd(cb)) + val rend = cb.memoize(rhs.loadEnd(cb)) + cb.if_(!pointLtEq(cb, lend, rend), exitWith(false)) + cb.assign(ret, !pointEq(cb, lend, rend) || !lhs.includesEnd || rhs.includesEnd) + + cb.define(Lout) + ret } - val lhs = x.asInterval - val rhs = y.asInterval - val lstart = cb.memoize(lhs.loadStart(cb)) - val rstart = cb.memoize(rhs.loadStart(cb)) + override def _gtNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + val pointGt = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Gt()) + val pointEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Equiv()) - cb.if_(pointGt(cb, lstart, rstart), exitWith(true)) - cb.if_(!pointEq(cb, lstart, rstart), exitWith(false)) - cb.if_(!lhs.includesStart && rhs.includesStart, exitWith(true)) - cb.if_(lhs.includesStart.cne(rhs.includesStart), exitWith(false)) + val Lout = CodeLabel() + val ret = cb.newLocal[Boolean]("interval_gt") + val exitWith = (value: Code[Boolean]) => { + cb.assign(ret, value) + cb.goto(Lout) + } - val lend = cb.memoize(lhs.loadEnd(cb)) - val rend = cb.memoize(rhs.loadEnd(cb)) + val lhs = x.asInterval + val rhs = y.asInterval + val lstart = cb.memoize(lhs.loadStart(cb)) + val rstart = cb.memoize(rhs.loadStart(cb)) - cb.if_(pointGt(cb, lend, rend), exitWith(true)) - cb.assign(ret, pointEq(cb, lend, rend) && lhs.includesEnd && !rhs.includesEnd) + cb.if_(pointGt(cb, lstart, rstart), exitWith(true)) + 
cb.if_(!pointEq(cb, lstart, rstart), exitWith(false)) + cb.if_(!lhs.includesStart && rhs.includesStart, exitWith(true)) + cb.if_(lhs.includesStart.cne(rhs.includesStart), exitWith(false)) - cb.define(Lout) - ret - } + val lend = cb.memoize(lhs.loadEnd(cb)) + val rend = cb.memoize(rhs.loadEnd(cb)) - override def _gteqNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { - val pointGtEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Gteq()) - val pointEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Equiv()) + cb.if_(pointGt(cb, lend, rend), exitWith(true)) + cb.assign(ret, pointEq(cb, lend, rend) && lhs.includesEnd && !rhs.includesEnd) - val Lout = CodeLabel() - val ret = cb.newLocal[Boolean]("interval_gteq") - val exitWith = (value: Code[Boolean]) => { - cb.assign(ret, value) - cb.goto(Lout) + cb.define(Lout) + ret } - val lhs = x.asInterval - val rhs = y.asInterval - val lstart = cb.memoize(lhs.loadStart(cb)) - val rstart = cb.memoize(rhs.loadStart(cb)) - - cb.if_(!pointGtEq(cb, lstart, rstart), exitWith(false)) - cb.if_(!pointEq(cb, lstart, rstart), exitWith(true)) - cb.if_(!lhs.includesStart && rhs.includesStart, exitWith(true)) - cb.if_(lhs.includesStart.cne(rhs.includesStart), exitWith(false)) - - val lend = cb.memoize(lhs.loadEnd(cb)) - val rend = cb.memoize(rhs.loadEnd(cb)) - cb.if_(!pointGtEq(cb, lend, rend), exitWith(false)) - cb.assign(ret, !pointEq(cb, lend, rend) || lhs.includesEnd || !rhs.includesEnd) - - cb.define(Lout) - ret + override def _gteqNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + val pointGtEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Gteq()) + val pointEq = ecb.getOrderingFunction(t1.pointType, t2.pointType, CodeOrdering.Equiv()) + + val Lout = CodeLabel() + val ret = cb.newLocal[Boolean]("interval_gteq") + val exitWith = (value: Code[Boolean]) => { + cb.assign(ret, value) + cb.goto(Lout) + } + + val lhs = x.asInterval + val rhs = y.asInterval + val lstart = cb.memoize(lhs.loadStart(cb)) + val rstart = cb.memoize(rhs.loadStart(cb)) + + cb.if_(!pointGtEq(cb, lstart, rstart), exitWith(false)) + cb.if_(!pointEq(cb, lstart, rstart), exitWith(true)) + cb.if_(!lhs.includesStart && rhs.includesStart, exitWith(true)) + cb.if_(lhs.includesStart.cne(rhs.includesStart), exitWith(false)) + + val lend = cb.memoize(lhs.loadEnd(cb)) + val rend = cb.memoize(rhs.loadEnd(cb)) + cb.if_(!pointGtEq(cb, lend, rend), exitWith(false)) + cb.assign(ret, !pointEq(cb, lend, rend) || lhs.includesEnd || !rhs.includesEnd) + + cb.define(Lout) + ret + } } - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/orderings/IterableOrdering.scala b/hail/src/main/scala/is/hail/expr/ir/orderings/IterableOrdering.scala index d87fc78890d..af4950b160e 100644 --- a/hail/src/main/scala/is/hail/expr/ir/orderings/IterableOrdering.scala +++ b/hail/src/main/scala/is/hail/expr/ir/orderings/IterableOrdering.scala @@ -1,171 +1,192 @@ package is.hail.expr.ir.orderings import is.hail.asm4s._ -import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitValue} +import is.hail.expr.ir.{EmitClassBuilder, EmitCodeBuilder, EmitValue} import is.hail.types.physical.stypes.SValue import is.hail.types.physical.stypes.interfaces.{SContainer, SIndexableValue} object IterableOrdering { - def make(t1: SContainer, t2: SContainer, ecb: EmitClassBuilder[_]): CodeOrdering = new CodeOrdering { - - override val type1: SContainer = t1 - override val type2: SContainer = t2 - - private[this] def loop(cb: 
EmitCodeBuilder, lhs: SIndexableValue, rhs: SIndexableValue)( - f: (EmitValue, EmitValue) => Unit - ): Unit = { - val i = cb.newLocal[Int]("i") - val lim = cb.newLocal("lim", lhs.loadLength().min(rhs.loadLength())) - cb.for_(cb.assign(i, 0), i < lim, cb.assign(i, i + 1), { - val left = cb.memoize(lhs.loadElement(cb, i)) - val right = cb.memoize(rhs.loadElement(cb, i)) - f(left, right) - }) - } - - override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = { - val elemCmp = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Compare()) - - val Lout = CodeLabel() - val cmp = cb.newLocal[Int]("iterable_cmp", 0) + def make(t1: SContainer, t2: SContainer, ecb: EmitClassBuilder[_]): CodeOrdering = + new CodeOrdering { + + override val type1: SContainer = t1 + override val type2: SContainer = t2 + + private[this] def loop( + cb: EmitCodeBuilder, + lhs: SIndexableValue, + rhs: SIndexableValue, + )( + f: (EmitValue, EmitValue) => Unit + ): Unit = { + val i = cb.newLocal[Int]("i") + val lim = cb.newLocal("lim", lhs.loadLength().min(rhs.loadLength())) + cb.for_( + cb.assign(i, 0), + i < lim, + cb.assign(i, i + 1), { + val left = cb.memoize(lhs.loadElement(cb, i)) + val right = cb.memoize(rhs.loadElement(cb, i)) + f(left, right) + }, + ) + } - val lhs = x.asIndexable - val rhs = y.asIndexable - loop(cb, lhs, rhs) { (lhs, rhs) => - cb.assign(cmp, elemCmp(cb, lhs, rhs)) - cb.if_(cmp.cne(0), cb.goto(Lout)) + override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = { + val elemCmp = + ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Compare()) + + val Lout = CodeLabel() + val cmp = cb.newLocal[Int]("iterable_cmp", 0) + + val lhs = x.asIndexable + val rhs = y.asIndexable + loop(cb, lhs, rhs) { (lhs, rhs) => + cb.assign(cmp, elemCmp(cb, lhs, rhs)) + cb.if_(cmp.cne(0), cb.goto(Lout)) + } + + // if we get here, cmp is 0 + cb.assign( + cmp, + Code.invokeStatic2[java.lang.Integer, Int, Int, Int]( + "compare", + lhs.loadLength(), + rhs.loadLength(), + ), + ) + cb.define(Lout) + cmp } - // if we get here, cmp is 0 - cb.assign(cmp, - Code.invokeStatic2[java.lang.Integer, Int, Int, Int]( - "compare", lhs.loadLength(), rhs.loadLength())) - cb.define(Lout) - cmp - } + override def _ltNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + val elemLt = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Lt()) + val elemEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Equiv()) - override def _ltNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { - val elemLt = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Lt()) - val elemEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Equiv()) + val ret = cb.newLocal[Boolean]("iterable_lt") + val Lout = CodeLabel() - val ret = cb.newLocal[Boolean]("iterable_lt") - val Lout = CodeLabel() + val lhs = x.asIndexable + val rhs = y.asIndexable - val lhs = x.asIndexable - val rhs = y.asIndexable + loop(cb, lhs, rhs) { (lhs, rhs) => + val lt = elemLt(cb, lhs, rhs) + val eq = !lt && elemEq(cb, lhs, rhs) - loop(cb, lhs, rhs) { (lhs, rhs) => - val lt = elemLt(cb, lhs, rhs) - val eq = !lt && elemEq(cb, lhs, rhs) + cb.if_( + !eq, { + cb.assign(ret, lt) + cb.goto(Lout) + }, + ) + } - cb.if_(!eq, { - cb.assign(ret, lt) - cb.goto(Lout) - }) + cb.assign(ret, lhs.loadLength() < rhs.loadLength()) + cb.define(Lout) + ret } - cb.assign(ret, lhs.loadLength() < rhs.loadLength()) - 
cb.define(Lout) - ret - } + override def _lteqNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + val elemLtEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Lteq()) + val elemEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Equiv()) - override def _lteqNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { - val elemLtEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Lteq()) - val elemEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Equiv()) + val ret = cb.newLocal[Boolean]("iterable_lteq") + val Lout = CodeLabel() - val ret = cb.newLocal[Boolean]("iterable_lteq") - val Lout = CodeLabel() + val lhs = x.asIndexable + val rhs = y.asIndexable - val lhs = x.asIndexable - val rhs = y.asIndexable + loop(cb, lhs, rhs) { (lhs, rhs) => + val lteq = elemLtEq(cb, lhs, rhs) + val eq = elemEq(cb, lhs, rhs) - loop(cb, lhs, rhs) { (lhs, rhs) => - val lteq = elemLtEq(cb, lhs, rhs) - val eq = elemEq(cb, lhs, rhs) + cb.if_( + !eq, { + cb.assign(ret, lteq) + cb.goto(Lout) + }, + ) + } - cb.if_(!eq, { - cb.assign(ret, lteq) - cb.goto(Lout) - }) + cb.assign(ret, lhs.loadLength() <= rhs.loadLength) + cb.define(Lout) + ret } - cb.assign(ret, lhs.loadLength() <= rhs.loadLength) - cb.define(Lout) - ret - } + override def _gtNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + val elemGt = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Gt()) + val elemEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Equiv()) - override def _gtNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { - val elemGt = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Gt()) - val elemEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Equiv()) + val ret = cb.newLocal[Boolean]("iterable_gt") + val Lout = CodeLabel() - val ret = cb.newLocal[Boolean]("iterable_gt") - val Lout = CodeLabel() + val lhs = x.asIndexable + val rhs = y.asIndexable - val lhs = x.asIndexable - val rhs = y.asIndexable - val gt = cb.newLocal("gt", false) - val eq = cb.newLocal("eq", true) + loop(cb, lhs, rhs) { (lhs, rhs) => + val gt = elemGt(cb, lhs, rhs) + val eq = !gt && elemEq(cb, lhs, rhs) - loop(cb, lhs, rhs) { (lhs, rhs) => - val gt = elemGt(cb, lhs, rhs) - val eq = !gt && elemEq(cb, lhs, rhs) + cb.if_( + !eq, { + cb.assign(ret, gt) + cb.goto(Lout) + }, + ) + } - cb.if_(!eq, { - cb.assign(ret, gt) - cb.goto(Lout) - }) + cb.assign(ret, lhs.loadLength() > rhs.loadLength()) + cb.define(Lout) + ret } - cb.assign(ret, lhs.loadLength() > rhs.loadLength()) - cb.define(Lout) - ret - } + override def _gteqNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + val elemGtEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Gteq()) + val elemEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Equiv()) - override def _gteqNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { - val elemGtEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Gteq()) - val elemEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Equiv()) + val ret = cb.newLocal[Boolean]("iterable_gteq") + val Lout = CodeLabel() - val ret = cb.newLocal[Boolean]("iterable_gteq") - val Lout = CodeLabel() + val lhs = x.asIndexable + val rhs = y.asIndexable - val lhs = x.asIndexable - val rhs = y.asIndexable + loop(cb, lhs, rhs) { (lhs, rhs) => 
+ val gteq = elemGtEq(cb, lhs, rhs) + val eq = elemEq(cb, lhs, rhs) - loop(cb, lhs, rhs) { (lhs, rhs) => - val gteq = elemGtEq(cb, lhs, rhs) - val eq = elemEq(cb, lhs, rhs) + cb.if_( + !eq, { + cb.assign(ret, gteq) + cb.goto(Lout) + }, + ) + } - cb.if_(!eq, { - cb.assign(ret, gteq) - cb.goto(Lout) - }) + cb.assign(ret, lhs.loadLength() >= rhs.loadLength) + cb.define(Lout) + ret } - cb.assign(ret, lhs.loadLength() >= rhs.loadLength) - cb.define(Lout) - ret - } - - override def _equivNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { - val elemEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Equiv()) - val ret = cb.newLocal[Boolean]("iterable_eq", true) - val Lout = CodeLabel() - val exitWith = (value: Code[Boolean]) => { - cb.assign(ret, value) - cb.goto(Lout) - } - - val lhs = x.asIndexable - val rhs = y.asIndexable - cb.if_(lhs.loadLength().cne(rhs.loadLength()), exitWith(false)) - loop(cb, lhs, rhs) { (lhs, rhs) => - cb.assign(ret, elemEq(cb, lhs, rhs)) - cb.if_(!ret, cb.goto(Lout)) + override def _equivNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = { + val elemEq = ecb.getOrderingFunction(t1.elementType, t2.elementType, CodeOrdering.Equiv()) + val ret = cb.newLocal[Boolean]("iterable_eq", true) + val Lout = CodeLabel() + val exitWith = (value: Code[Boolean]) => { + cb.assign(ret, value) + cb.goto(Lout) + } + + val lhs = x.asIndexable + val rhs = y.asIndexable + cb.if_(lhs.loadLength().cne(rhs.loadLength()), exitWith(false)) + loop(cb, lhs, rhs) { (lhs, rhs) => + cb.assign(ret, elemEq(cb, lhs, rhs)) + cb.if_(!ret, cb.goto(Lout)) + } + + cb.define(Lout) + ret } - - cb.define(Lout) - ret } - } } diff --git a/hail/src/main/scala/is/hail/expr/ir/orderings/LocusOrdering.scala b/hail/src/main/scala/is/hail/expr/ir/orderings/LocusOrdering.scala index ae30e163d4d..c8c98180e68 100644 --- a/hail/src/main/scala/is/hail/expr/ir/orderings/LocusOrdering.scala +++ b/hail/src/main/scala/is/hail/expr/ir/orderings/LocusOrdering.scala @@ -17,7 +17,8 @@ object LocusOrdering { require(t1.rg == t2.rg) - override def _compareNonnull(cb: EmitCodeBuilder, lhsc: SValue, rhsc: SValue): Value[Int] = { + override def _compareNonnull(cb: EmitCodeBuilder, lhsc: SValue, rhsc: SValue) + : Value[Int] = { val codeRG = cb.emb.getReferenceGenome(t1.rg) val lhs: SLocusValue = lhsc.asLocus val rhs: SLocusValue = rhsc.asLocus @@ -30,13 +31,25 @@ object LocusOrdering { val strcmp = CodeOrdering.makeOrdering(lhsContigType, rhsContigType, ecb) val ret = cb.newLocal[Int]("locus_cmp_ret", 0) - cb.if_(strcmp.compareNonnull(cb, lhsContig, rhsContig).ceq(0), { - cb.assign(ret, Code.invokeStatic2[java.lang.Integer, Int, Int, Int]( - "compare", lhs.position(cb), rhs.position(cb))) - }, { - cb.assign(ret, codeRG.invoke[String, String, Int]( - "compare", lhsContig.loadString(cb).get, rhsContig.loadString(cb).get)) - }) + cb.if_( + strcmp.compareNonnull(cb, lhsContig, rhsContig).ceq(0), + cb.assign( + ret, + Code.invokeStatic2[java.lang.Integer, Int, Int, Int]( + "compare", + lhs.position(cb), + rhs.position(cb), + ), + ), + cb.assign( + ret, + codeRG.invoke[String, String, Int]( + "compare", + lhsContig.loadString(cb).get, + rhsContig.loadString(cb).get, + ), + ), + ) ret } } diff --git a/hail/src/main/scala/is/hail/expr/ir/orderings/PrimitiveOrdering.scala b/hail/src/main/scala/is/hail/expr/ir/orderings/PrimitiveOrdering.scala index 492617b4dd7..28ff57e6435 100644 --- a/hail/src/main/scala/is/hail/expr/ir/orderings/PrimitiveOrdering.scala +++ 
b/hail/src/main/scala/is/hail/expr/ir/orderings/PrimitiveOrdering.scala @@ -2,7 +2,7 @@ package is.hail.expr.ir.orderings import is.hail.asm4s.{Code, Value} import is.hail.expr.ir.{EmitClassBuilder, EmitCodeBuilder} -import is.hail.types.physical.stypes.{SCode, SValue} +import is.hail.types.physical.stypes.SValue import is.hail.types.physical.stypes.primitives._ object Int32Ordering { @@ -13,7 +13,11 @@ object Int32Ordering { override val type2: SInt32.type = t2 override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = - cb.memoize(Code.invokeStatic2[java.lang.Integer, Int, Int, Int]("compare", x.asInt.value, y.asInt.value)) + cb.memoize(Code.invokeStatic2[java.lang.Integer, Int, Int, Int]( + "compare", + x.asInt.value, + y.asInt.value, + )) override def _ltNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = cb.memoize(x.asInt.value < y.asInt.value) @@ -33,7 +37,6 @@ object Int32Ordering { } } - object Int64Ordering { def make(t1: SInt64.type, t2: SInt64.type, ecb: EmitClassBuilder[_]): CodeOrdering = { new CodeOrdering { @@ -42,7 +45,11 @@ object Int64Ordering { override val type2: SInt64.type = t2 override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = - cb.memoize(Code.invokeStatic2[java.lang.Long, Long, Long, Int]("compare", x.asLong.value, y.asLong.value)) + cb.memoize(Code.invokeStatic2[java.lang.Long, Long, Long, Int]( + "compare", + x.asLong.value, + y.asLong.value, + )) override def _ltNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = cb.memoize(x.asLong.value < y.asLong.value) @@ -70,7 +77,11 @@ object Float32Ordering { override val type2: SFloat32.type = t2 override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = - cb.memoize(Code.invokeStatic2[java.lang.Float, Float, Float, Int]("compare", x.asFloat.value, y.asFloat.value)) + cb.memoize(Code.invokeStatic2[java.lang.Float, Float, Float, Int]( + "compare", + x.asFloat.value, + y.asFloat.value, + )) override def _ltNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = cb.memoize(x.asFloat.value < y.asFloat.value) @@ -98,7 +109,11 @@ object Float64Ordering { override val type2: SFloat64.type = t2 override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = - cb.memoize(Code.invokeStatic2[java.lang.Double, Double, Double, Int]("compare", x.asDouble.value, y.asDouble.value)) + cb.memoize(Code.invokeStatic2[java.lang.Double, Double, Double, Int]( + "compare", + x.asDouble.value, + y.asDouble.value, + )) override def _ltNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Boolean] = cb.memoize(x.asDouble.value < y.asDouble.value) @@ -126,7 +141,11 @@ object BooleanOrdering { override val type2: SBoolean.type = t2 override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = - cb.memoize(Code.invokeStatic2[java.lang.Boolean, Boolean, Boolean, Int]("compare", x.asBoolean.value, y.asBoolean.value)) + cb.memoize(Code.invokeStatic2[java.lang.Boolean, Boolean, Boolean, Int]( + "compare", + x.asBoolean.value, + y.asBoolean.value, + )) } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/orderings/StringOrdering.scala b/hail/src/main/scala/is/hail/expr/ir/orderings/StringOrdering.scala index 792b3a32c76..d8cf55a6aac 100644 --- a/hail/src/main/scala/is/hail/expr/ir/orderings/StringOrdering.scala +++ b/hail/src/main/scala/is/hail/expr/ir/orderings/StringOrdering.scala @@ -19,7 +19,7 @@ object StringOrdering { val bcode1 = 
x.asInstanceOf[SStringPointerValue] val bcode2 = y.asInstanceOf[SStringPointerValue] val ord = BinaryOrdering.make(bcode1.binaryRepr.st, bcode2.binaryRepr.st, ecb) - ord.compareNonnull(cb, bcode1.binaryRepr, bcode2.binaryRepr) + ord._compareNonnull(cb, bcode1.binaryRepr, bcode2.binaryRepr) } } @@ -29,9 +29,11 @@ object StringOrdering { override val type1: SString = t1 override val type2: SString = t2 - override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = { - cb.memoize(x.asString.loadString(cb).invoke[String, Int]("compareTo", y.asString.loadString(cb))) - } + override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = + cb.memoize(x.asString.loadString(cb).invoke[String, Int]( + "compareTo", + y.asString.loadString(cb), + )) } } } diff --git a/hail/src/main/scala/is/hail/expr/ir/orderings/StructOrdering.scala b/hail/src/main/scala/is/hail/expr/ir/orderings/StructOrdering.scala index 3ec08d41b58..2419417d505 100644 --- a/hail/src/main/scala/is/hail/expr/ir/orderings/StructOrdering.scala +++ b/hail/src/main/scala/is/hail/expr/ir/orderings/StructOrdering.scala @@ -1,9 +1,9 @@ package is.hail.expr.ir.orderings -import is.hail.asm4s.{Code, CodeLabel, Value} -import is.hail.expr.ir.{Ascending, EmitClassBuilder, EmitCode, EmitCodeBuilder, SortOrder} -import is.hail.types.physical.stypes.{SCode, SValue} -import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructValue} +import is.hail.asm4s.{CodeLabel, Value} +import is.hail.expr.ir.{Ascending, EmitClassBuilder, EmitCodeBuilder, SortOrder} +import is.hail.types.physical.stypes.SValue +import is.hail.types.physical.stypes.interfaces.SBaseStruct object StructOrdering { def make( @@ -11,7 +11,7 @@ object StructOrdering { t2: SBaseStruct, ecb: EmitClassBuilder[_], sortOrders: Array[SortOrder] = null, - missingFieldsEqual: Boolean = true + missingFieldsEqual: Boolean = true, ): CodeOrdering = new CodeOrdering { override val type1: SBaseStruct = t1 @@ -20,9 +20,12 @@ object StructOrdering { require(sortOrders == null || sortOrders.size == t1.size) private[this] def fieldOrdering(i: Int, op: CodeOrdering.Op): CodeOrdering.F[op.ReturnType] = - ecb.getOrderingFunction(t1.fieldTypes(i), t2.fieldTypes(i), + ecb.getOrderingFunction( + t1.fieldTypes(i), + t2.fieldTypes(i), if (sortOrders == null) Ascending else sortOrders(i), - op) + op, + ) override def _compareNonnull(cb: EmitCodeBuilder, x: SValue, y: SValue): Value[Int] = { val lhs = x.asBaseStruct diff --git a/hail/src/main/scala/is/hail/expr/ir/package.scala b/hail/src/main/scala/is/hail/expr/ir/package.scala index 398a15a94fd..1d4d50553ea 100644 --- a/hail/src/main/scala/is/hail/expr/ir/package.scala +++ b/hail/src/main/scala/is/hail/expr/ir/package.scala @@ -7,9 +7,10 @@ import is.hail.types.tcoerce import is.hail.types.virtual._ import is.hail.utils._ -import java.util.UUID import scala.language.implicitConversions +import java.util.UUID + package object ir { type TokenIterator = BufferedIterator[Token] type IEmitCode = IEmitCodeGen[SValue] @@ -39,22 +40,25 @@ package object ir { def invoke(name: String, rt: Type, typeArgs: Seq[Type], errorID: Int, args: IR*): IR = IRFunctionRegistry.lookupUnseeded(name, rt, typeArgs, args.map(_.typ)) match { case Some(f) => f(typeArgs, args, errorID) - case None => fatal(s"no conversion found for $name(${typeArgs.mkString(", ")}, ${args.map(_.typ).mkString(", ")}) => $rt") + case None => fatal( + s"no conversion found for $name(${typeArgs.mkString(", ")}, ${args.map(_.typ).mkString(", ")}) => 
$rt" + ) } def invoke(name: String, rt: Type, typeArgs: Array[Type], args: IR*): IR = - invoke(name, rt, typeArgs, ErrorIDs.NO_ERROR, args:_*) + invoke(name, rt, typeArgs, ErrorIDs.NO_ERROR, args: _*) def invoke(name: String, rt: Type, args: IR*): IR = - invoke(name, rt, Array.empty[Type], ErrorIDs.NO_ERROR, args:_*) + invoke(name, rt, Array.empty[Type], ErrorIDs.NO_ERROR, args: _*) def invoke(name: String, rt: Type, errorID: Int, args: IR*): IR = - invoke(name, rt, Array.empty[Type], errorID, args:_*) + invoke(name, rt, Array.empty[Type], errorID, args: _*) def invokeSeeded(name: String, staticUID: Long, rt: Type, rngState: IR, args: IR*): IR = IRFunctionRegistry.lookupSeeded(name, staticUID, rt, args.map(_.typ)) match { case Some(f) => f(args, rngState) - case None => fatal(s"no seeded function found for $name(${args.map(_.typ).mkString(", ")}) => $rt") + case None => + fatal(s"no seeded function found for $name(${args.map(_.typ).mkString(", ")}) => $rt") } implicit def irToPrimitiveIR(ir: IR): PrimitiveIR = new PrimitiveIR(ir) @@ -96,13 +100,11 @@ package object ir { StreamTakeWhile(v, ref.name, f(ref)) } - def maxIR(a: IR, b: IR): IR = { + def maxIR(a: IR, b: IR): IR = If(a > b, a, b) - } - def minIR(a: IR, b: IR): IR = { + def minIR(a: IR, b: IR): IR = If(a < b, a, b) - } def streamAggIR(stream: IR)(f: Ref => IR): IR = { val ref = Ref(genUID(), tcoerce[TStream](stream.typ).elementType) @@ -124,14 +126,18 @@ package object ir { StreamMap(stream, ref.name, f(ref)) } + def mapArray(array: IR)(f: Ref => IR): IR = + ToArray(mapIR(ToStream(array))(f)) + def flatMapIR(stream: IR)(f: Ref => IR): IR = { val ref = Ref(genUID(), tcoerce[TStream](stream.typ).elementType) StreamFlatMap(stream, ref.name, f(ref)) } - def flatten(stream: IR): IR = flatMapIR(if (stream.typ.isInstanceOf[TStream]) stream else ToStream(stream)) { elt => - if (elt.typ.isInstanceOf[TStream]) elt else ToStream(elt) - } + def flatten(stream: IR): IR = + flatMapIR(if (stream.typ.isInstanceOf[TStream]) stream else ToStream(stream)) { elt => + if (elt.typ.isInstanceOf[TStream]) elt else ToStream(elt) + } def foldIR(stream: IR, zero: IR)(f: (Ref, Ref) => IR): IR = { val elt = Ref(genUID(), tcoerce[TStream](stream.typ).elementType) @@ -146,25 +152,50 @@ package object ir { ArraySort(stream, l.name, r.name, f(l, r)) } - def sliceArrayIR(arrayIR: IR, startIR: IR, stopIR: IR): IR = { + def sliceArrayIR(arrayIR: IR, startIR: IR, stopIR: IR): IR = ArraySlice(arrayIR, startIR, Some(stopIR)) - } - def joinIR(left: IR, right: IR, lkey: IndexedSeq[String], rkey: IndexedSeq[String], joinType: String, requiresMemoryManagement: Boolean)(f: (Ref, Ref) => IR): IR = { + def joinIR( + left: IR, + right: IR, + lkey: IndexedSeq[String], + rkey: IndexedSeq[String], + joinType: String, + requiresMemoryManagement: Boolean, + )( + f: (Ref, Ref) => IR + ): IR = { val lRef = Ref(genUID(), left.typ.asInstanceOf[TStream].elementType) val rRef = Ref(genUID(), right.typ.asInstanceOf[TStream].elementType) - StreamJoin(left, right, lkey, rkey, lRef.name, rRef.name, f(lRef, rRef), joinType, requiresMemoryManagement) + StreamJoin( + left, + right, + lkey, + rkey, + lRef.name, + rRef.name, + f(lRef, rRef), + joinType, + requiresMemoryManagement, + ) } - def joinRightDistinctIR(left: IR, right: IR, lkey: IndexedSeq[String], rkey: IndexedSeq[String], joinType: String)(f: (Ref, Ref) => IR): IR = { + def joinRightDistinctIR( + left: IR, + right: IR, + lkey: IndexedSeq[String], + rkey: IndexedSeq[String], + joinType: String, + )( + f: (Ref, Ref) => IR + ): IR = { val 
lRef = Ref(genUID(), left.typ.asInstanceOf[TStream].elementType) val rRef = Ref(genUID(), right.typ.asInstanceOf[TStream].elementType) StreamJoinRightDistinct(left, right, lkey, rkey, lRef.name, rRef.name, f(lRef, rRef), joinType) } - def streamSumIR(stream: IR): IR = { - foldIR(stream, 0){ case (accum, elt) => accum + elt} - } + def streamSumIR(stream: IR): IR = + foldIR(stream, 0) { case (accum, elt) => accum + elt } def streamForceCount(stream: IR): IR = streamSumIR(mapIR(stream)(_ => I32(1))) @@ -173,7 +204,9 @@ package object ir { def rangeIR(start: IR, stop: IR): IR = StreamRange(start, stop, 1) - def insertIR(old: IR, fields: (String, IR)*): InsertFields = InsertFields(old, fields.toArray[(String, IR)]) + def insertIR(old: IR, fields: (String, IR)*): InsertFields = + InsertFields(old, fields.toArray[(String, IR)]) + def selectIR(old: IR, fields: String*): SelectFields = SelectFields(old, fields.toArray[String]) def zip2(s1: IR, s2: IR, behavior: ArrayZipBehavior.ArrayZipBehavior)(f: (Ref, Ref) => IR): IR = { @@ -189,17 +222,25 @@ package object ir { FastSeq(s, StreamIota(I32(0), I32(1))), FastSeq(r1.name, r2.name), MakeStruct(FastSeq(("elt", r1), ("idx", r2))), - ArrayZipBehavior.TakeMinLength + ArrayZipBehavior.TakeMinLength, ) } - def zipIR(ss: IndexedSeq[IR], behavior: ArrayZipBehavior.ArrayZipBehavior, errorId: Int = ErrorIDs.NO_ERROR)(f: IndexedSeq[Ref] => IR): IR = { + def zipIR( + ss: IndexedSeq[IR], + behavior: ArrayZipBehavior.ArrayZipBehavior, + errorId: Int = ErrorIDs.NO_ERROR, + )( + f: IndexedSeq[Ref] => IR + ): IR = { val refs = ss.map(s => Ref(genUID(), tcoerce[TStream](s.typ).elementType)) StreamZip(ss, refs.map(_.name), f(refs), behavior, errorId) } def makestruct(fields: (String, IR)*): MakeStruct = MakeStruct(fields.toArray[(String, IR)]) - def maketuple(fields: IR*): MakeTuple = MakeTuple(fields.toArray.zipWithIndex.map { case (field, idx) => (idx, field) }) + + def maketuple(fields: IR*): MakeTuple = + MakeTuple(fields.toArray.zipWithIndex.map { case (field, idx) => (idx, field) }) def aggBindIR(v: IR, isScan: Boolean = false)(body: Ref => IR): IR = { val ref = Ref(genUID(), v.typ) @@ -211,17 +252,34 @@ package object ir { AggExplode(v, r.name, body(r), isScan) } - def aggFoldIR(zero: IR, element: IR)(seqOp: (Ref, IR) => IR)(combOp: (Ref, Ref) => IR) : AggFold = { + def aggFoldIR(zero: IR, element: IR)(seqOp: (Ref, IR) => IR)(combOp: (Ref, Ref) => IR) + : AggFold = { val accum1 = Ref(genUID(), zero.typ) val accum2 = Ref(genUID(), zero.typ) AggFold(zero, seqOp(accum1, element), combOp(accum1, accum2), accum1.name, accum2.name, false) } - def cdaIR(contexts: IR, globals: IR, staticID: String, dynamicID: IR = NA(TString))(body: (Ref, Ref) => IR): CollectDistributedArray = { + def cdaIR( + contexts: IR, + globals: IR, + staticID: String, + dynamicID: IR = NA(TString), + )( + body: (Ref, Ref) => IR + ): CollectDistributedArray = { val contextRef = Ref(genUID(), contexts.typ.asInstanceOf[TStream].elementType) val globalRef = Ref(genUID(), globals.typ) - CollectDistributedArray(contexts, globals, contextRef.name, globalRef.name, body(contextRef, globalRef), dynamicID, staticID, None) + CollectDistributedArray( + contexts, + globals, + contextRef.name, + globalRef.name, + body(contextRef, globalRef), + dynamicID, + staticID, + None, + ) } def strConcat(irs: AnyRef*): IR = { @@ -248,7 +306,8 @@ package object ir { def logIR(result: IR, messages: AnyRef*): IR = ConsoleLog(strConcat(messages: _*), result) - implicit def toRichIndexedSeqEmitSettable(s: 
IndexedSeq[EmitSettable]): RichIndexedSeqEmitSettable = new RichIndexedSeqEmitSettable(s) + implicit def toRichIndexedSeqEmitSettable(s: IndexedSeq[EmitSettable]) + : RichIndexedSeqEmitSettable = new RichIndexedSeqEmitSettable(s) implicit def emitValueToCode(ev: EmitValue): EmitCode = ev.load diff --git a/hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala b/hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala index 28034169641..6b32b2b9465 100644 --- a/hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala +++ b/hail/src/main/scala/is/hail/expr/ir/streams/EmitStream.scala @@ -1,104 +1,104 @@ package is.hail.expr.ir.streams import is.hail.annotations.{Region, RegionPool} +import is.hail.annotations.Region.REGULAR import is.hail.asm4s._ import is.hail.expr.ir._ import is.hail.expr.ir.agg.{AggStateSig, DictState, PhysicalAggSig, StateTuple} import is.hail.expr.ir.functions.IntervalFunctions +import is.hail.expr.ir.functions.IntervalFunctions.{ + pointGTIntervalEndpoint, pointLTIntervalEndpoint, +} import is.hail.expr.ir.orderings.StructOrdering import is.hail.linalg.LinalgCodeUtils import is.hail.lir import is.hail.methods.{BitPackedVector, BitPackedVectorBuilder, LocalLDPrune, LocalWhitening} -import is.hail.types.physical.stypes.concrete.{SBinaryPointer, SStackStruct, SUnreachable} +import is.hail.types.{RIterable, TypeWithRequiredness, VirtualTypeWithReq} +import is.hail.types.physical._ +import is.hail.types.physical.stypes.{EmitType, SSettable, SValue} +import is.hail.types.physical.stypes.concrete.{ + SBaseStructPointer, SBinaryPointer, SStackStruct, SUnreachable, +} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives.{SFloat64Value, SInt32Value} -import is.hail.types.physical.stypes.{EmitType, SSettable} -import is.hail.types.physical.{PCanonicalArray, PCanonicalBinary, PCanonicalStruct, PType} import is.hail.types.virtual._ -import is.hail.types.{RIterable, TypeWithRequiredness, VirtualTypeWithReq} import is.hail.utils._ import is.hail.variant.Locus -import org.objectweb.asm.Opcodes._ + +import scala.annotation.nowarn import java.util +import org.objectweb.asm.Opcodes._ abstract class StreamProducer { // method builder where this stream is valid def method: EmitMethodBuilder[_] - /** - * Stream length, which is present if it can be computed (somewhat) cheaply without - * consuming the stream. + /** Stream length, which is present if it can be computed (somewhat) cheaply without consuming the + * stream. * * In order for `length` to be valid, the stream must have been initialized with `initialize`. */ val length: Option[EmitCodeBuilder => Code[Int]] - /** - * Stream producer setup method. If `initialize` is called, then the `close` method - * must be called as well to properly handle owned resources like files. + /** Stream producer setup method. If `initialize` is called, then the `close` method must be + * called as well to properly handle owned resources like files. * - * The stream's element region must be assigned by a consumer before initialize - * is called. + * The stream's element region must be assigned by a consumer before initialize is called. * * This block cannot jump away, e.g. to `LendOfStream`. - * */ def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit - /** - * Stream element region, into which the `element` is emitted. The assignment, clearing, - * and freeing of the element region is the responsibility of the stream consumer. 
+ /** Stream element region, into which the `element` is emitted. The assignment, clearing, and + * freeing of the element region is the responsibility of the stream consumer. */ val elementRegion: Settable[Region] - /** - * This boolean parameter indicates whether the producer's elements should be allocated in + /** This boolean parameter indicates whether the producer's elements should be allocated in * separate regions (by clearing when elements leave a consumer's scope). This parameter - * propagates bottom-up from producers like [[ReadPartition]] and [[StreamRange]], but - * it is the responsibility of consumers to implement the right memory management semantics - * based on this flag. + * propagates bottom-up from producers like [[ReadPartition]] and [[StreamRange]], but it is the + * responsibility of consumers to implement the right memory management semantics based on this + * flag. */ val requiresMemoryManagementPerElement: Boolean - /** - * The `LproduceElement` label is the mechanism by which consumers drive iteration. A consumer + /** The `LproduceElement` label is the mechanism by which consumers drive iteration. A consumer * jumps to `LproduceElement` when it is ready for an element. The code block at this label, * defined by the producer, jumps to either `LproduceElementDone` or `LendOfStream`, both of * which the consumer must define. */ val LproduceElement: CodeLabel - /** - * The `LproduceElementDone` label is jumped to by the code block at `LproduceElement` if - * the stream has produced a valid `element`. The immediate stream consumer must define - * this label. + /** The `LproduceElementDone` label is jumped to by the code block at `LproduceElement` if the + * stream has produced a valid `element`. The immediate stream consumer must define this label. */ final val LproduceElementDone: CodeLabel = CodeLabel() - /** - * The `LendOfStream` label is jumped to by the code block at `LproduceElement` if - * the stream has no more elements to return. The immediate stream consumer must - * define this label. + /** The `LendOfStream` label is jumped to by the code block at `LproduceElement` if the stream has + * no more elements to return. The immediate stream consumer must define this label. */ final val LendOfStream: CodeLabel = CodeLabel() - - /** - * Stream element. This value is valid after the producer jumps to `LproduceElementDone`, - * until a consumer jumps to `LproduceElement` again, or calls `close()`. + /** Stream element. This value is valid after the producer jumps to `LproduceElementDone`, until a + * consumer jumps to `LproduceElement` again, or calls `close()`. */ val element: EmitCode - /** - * Stream producer cleanup method. If `initialize` is called, then the `close` method - * must be called as well to properly handle owned resources like files. + /** Stream producer cleanup method. If `initialize` is called, then the `close` method must be + * called as well to properly handle owned resources like files. 
*/ def close(cb: EmitCodeBuilder): Unit - final def unmanagedConsume(cb: EmitCodeBuilder, outerRegion: Value[Region], setup: EmitCodeBuilder => Unit = _ => ())(perElement: EmitCodeBuilder => Unit): Unit = { + final def unmanagedConsume( + cb: EmitCodeBuilder, + outerRegion: Value[Region], + setup: EmitCodeBuilder => Unit = _ => (), + )( + perElement: EmitCodeBuilder => Unit + ): Unit = { this.initialize(cb, outerRegion) setup(cb) @@ -111,8 +111,15 @@ abstract class StreamProducer { this.close(cb) } - // only valid if `perElement` does not retain pointers into the element region after it returns (or adds region references) - final def memoryManagedConsume(outerRegion: Value[Region], cb: EmitCodeBuilder, setup: EmitCodeBuilder => Unit = _ => ())(perElement: EmitCodeBuilder => Unit): Unit = { + /* only valid if `perElement` does not retain pointers into the element region after it returns + * (or adds region references) */ + final def memoryManagedConsume( + outerRegion: Value[Region], + cb: EmitCodeBuilder, + setup: EmitCodeBuilder => Unit = _ => (), + )( + perElement: EmitCodeBuilder => Unit + ): Unit = { if (requiresMemoryManagementPerElement) { cb.assign(elementRegion, Region.stagedCreate(Region.REGULAR, outerRegion.getPool())) @@ -136,21 +143,38 @@ object EmitStream { mb: EmitMethodBuilder[_], outerRegion: Value[Region], env: EmitEnv, - container: Option[AggContainer] + container: Option[AggContainer], ): IEmitCode = { - def emitVoid(ir: IR, cb: EmitCodeBuilder, region: Value[Region] = outerRegion, env: EmitEnv = env, container: Option[AggContainer] = container): Unit = + @nowarn("cat=unused-locals&msg=local default argument") + def emitVoid( + ir: IR, + cb: EmitCodeBuilder, + region: Value[Region] = outerRegion, + env: EmitEnv = env, + container: Option[AggContainer] = container, + ): Unit = emitter.emitVoid(cb, ir, region, env, container, None) - def emit(ir: IR, cb: EmitCodeBuilder, region: Value[Region] = outerRegion, env: EmitEnv = env, container: Option[AggContainer] = container): IEmitCode = { + def emit( + ir: IR, + cb: EmitCodeBuilder, + region: Value[Region] = outerRegion, + env: EmitEnv = env, + container: Option[AggContainer] = container, + ): IEmitCode = ir.typ match { case _: TStream => produce(ir, cb, cb.emb, region, env, container) case _ => emitter.emitI(ir, cb, region, env, container, None) } - } // returns IEmitCode of SStreamConcrete - def produceIterator(streamIR: IR, elementPType: PType, cb: EmitCodeBuilder, outerRegion: Value[Region] = outerRegion, env: EmitEnv = env): IEmitCode = { + def produceIterator( + streamIR: IR, + elementPType: PType, + cb: EmitCodeBuilder, + env: EmitEnv, + ): IEmitCode = { val ecb = cb.emb.genEmitClass[NoBoxLongIterator]("stream_to_iter") ecb.cb.addInterface(typeInfo[MissingnessAsMethod].iname) @@ -167,30 +191,48 @@ object EmitStream { var producerRequired: Boolean = false val next = ecb.newEmitMethod("next", FastSeq[ParamType](), LongInfo) - val ctor = ecb.newEmitMethod("<init>", FastSeq[ParamType](typeInfo[Region], arrayInfo[Long]) ++ envParamTypes, UnitInfo) + val ctor = ecb.newEmitMethod( + "<init>", + FastSeq[ParamType](typeInfo[Region], arrayInfo[Long]) ++ envParamTypes, + UnitInfo, + ) ctor.voidWithBuilder { cb => val L = new lir.Block() L.append( - lir.methodStmt(INVOKESPECIAL, + lir.methodStmt( + INVOKESPECIAL, "java/lang/Object", "<init>", "()V", false, UnitInfo, - FastSeq(lir.load(ctor.mb._this.asInstanceOf[LocalRef[_]].l)))) + FastSeq(lir.load(ctor.mb.this_.asInstanceOf[LocalRef[_]].l)), + ) + ) cb += new VCode(L, L, null) val newEnv = 
restoreEnv(cb, 3) - val s = EmitStream.produce(new Emit(emitter.ctx, ecb), streamIR, cb, next, outerRegionField, newEnv, None) + val s = EmitStream.produce( + new Emit(emitter.ctx, ecb), + streamIR, + cb, + next, + outerRegionField, + newEnv, + None, + ) producerRequired = s.required - s.consume(cb, { - if (!producerRequired) cb.assign(isMissing, true) - }, { stream => - if (!producerRequired) cb.assign(isMissing, false) - producer = stream.asStream.getProducer(next) - }) - - val self = cb.memoize(Code.checkcast[FunctionWithPartitionRegion]((ctor.getCodeParam(0)(ecb.cb.ti)))) + s.consume( + cb, + if (!producerRequired) cb.assign(isMissing, true), + { stream => + if (!producerRequired) cb.assign(isMissing, false) + producer = stream.asStream.getProducer(next) + }, + ) + + val self = + cb.memoize(Code.checkcast[FunctionWithPartitionRegion]((ctor.getCodeParam(0)(ecb.cb.ti)))) ecb.setLiteralsArray(cb, ctor.getCodeParam[Array[Long]](2)) val partitionRegion = cb.memoize(ctor.getCodeParam[Region](1)) @@ -204,11 +246,11 @@ object EmitStream { cb.goto(producer.LproduceElement) cb.define(producer.LproduceElementDone) producer.element.toI(cb) - .consume(cb, { - cb.assign(ret, 0L) - }, { value => - cb.assign(ret, elementPType.store(cb, producer.elementRegion, value, false)) - }) + .consume( + cb, + cb.assign(ret, 0L), + value => cb.assign(ret, elementPType.store(cb, producer.elementRegion, value, false)), + ) cb.goto(Lret) cb.define(producer.LendOfStream) cb.assign(eosField, true) @@ -217,7 +259,8 @@ object EmitStream { ret } - val init = ecb.newEmitMethod("init", FastSeq[ParamType](typeInfo[Region], typeInfo[Region]), UnitInfo) + val init = + ecb.newEmitMethod("init", FastSeq[ParamType](typeInfo[Region], typeInfo[Region]), UnitInfo) init.voidWithBuilder { cb => val outerRegion = init.getCodeParam[Region](1) val eltRegion = init.getCodeParam[Region](2) @@ -228,44 +271,63 @@ object EmitStream { cb.assign(eosField, false) } - val isEOS = ecb.newEmitMethod("eos", FastSeq[ParamType](), BooleanInfo) isEOS.emitWithBuilder[Boolean](cb => eosField) val isMissingMethod = ecb.newEmitMethod("isMissing", FastSeq[ParamType](), BooleanInfo) isMissingMethod.emitWithBuilder[Boolean](cb => isMissing) - val close = ecb.newEmitMethod("close", FastSeq[ParamType](), UnitInfo) close.voidWithBuilder(cb => producer.close(cb)) - val obj = cb.memoize(Code.newInstance(ecb.cb, ctor.mb, - FastSeq(cb.emb.partitionRegion.get, cb.emb.ecb.literalsArray().get) ++ envParams.map(_.get))) + val obj = cb.memoize(Code.newInstance( + ecb.cb, + ctor.mb, + FastSeq(cb.emb.partitionRegion.get, cb.emb.ecb.literalsArray().get) ++ envParams.map(_.get), + )) val iter = cb.emb.genFieldThisRef[NoBoxLongIterator]("iter") cb.assign(iter, Code.checkcast[NoBoxLongIterator](obj)) - IEmitCode(cb, - if (producerRequired) false else Code.checkcast[MissingnessAsMethod](obj).invoke[Boolean]("isMissing"), - new SStreamConcrete(SStreamIteratorLong(producer.element.required, elementPType, producer.requiresMemoryManagementPerElement), - iter)) + IEmitCode( + cb, + if (producerRequired) false + else Code.checkcast[MissingnessAsMethod](obj).invoke[Boolean]("isMissing"), + new SStreamConcrete( + SStreamIteratorLong( + producer.element.required, + elementPType, + producer.requiresMemoryManagementPerElement, + ), + iter, + ), + ) } - def produce(streamIR: IR, cb: EmitCodeBuilder, mb: EmitMethodBuilder[_] = mb, region: Value[Region] = outerRegion, env: EmitEnv = env, container: Option[AggContainer] = container): IEmitCode = + def produce( + streamIR: IR, + cb: 
EmitCodeBuilder, + mb: EmitMethodBuilder[_] = mb, + region: Value[Region] = outerRegion, + env: EmitEnv = env, + container: Option[AggContainer] = container, + ): IEmitCode = EmitStream.produce(emitter, streamIR, cb, mb, region, env, container) - def typeWithReqx(node: IR): VirtualTypeWithReq = VirtualTypeWithReq(node.typ, emitter.ctx.req.lookup(node).asInstanceOf[TypeWithRequiredness]) + def typeWithReqx(node: IR): VirtualTypeWithReq = + VirtualTypeWithReq(node.typ, emitter.ctx.req.lookup(node).asInstanceOf[TypeWithRequiredness]) def typeWithReq: VirtualTypeWithReq = typeWithReqx(streamIR) streamIR match { - case x@NA(_typ: TStream) => + case NA(_typ: TStream) => val st = SStream(EmitType(SUnreachable.fromVirtualType(_typ.elementType), true)) val region = mb.genFieldThisRef[Region]("na_region") val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = {} - override val length: Option[EmitCodeBuilder => Code[Int]] = Some(_ => Code._fatal[Int]("tried to get NA stream length")) + override val length: Option[EmitCodeBuilder => Code[Int]] = + Some(_ => Code._fatal[Int]("tried to get NA stream length")) override val elementRegion: Settable[Region] = region override val requiresMemoryManagementPerElement: Boolean = false override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => @@ -284,11 +346,13 @@ object EmitStream { val childProducer = stream.getProducer(mb) val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb - override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = childProducer.initialize(cb, outerRegion) + override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = + childProducer.initialize(cb, outerRegion) override val length: Option[EmitCodeBuilder => Code[Int]] = childProducer.length override val elementRegion: Settable[Region] = childProducer.elementRegion - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.goto(childProducer.LproduceElement) cb.define(childProducer.LproduceElementDone) @@ -298,51 +362,43 @@ object EmitStream { override def close(cb: EmitCodeBuilder): Unit = childProducer.close(cb) } - mb.implementLabel(childProducer.LendOfStream) { cb => - cb.goto(producer.LendOfStream) - } + mb.implementLabel(childProducer.LendOfStream)(cb => cb.goto(producer.LendOfStream)) SStreamValue(producer) } - case Let(bindings, body) => - def go(env: EmitEnv): IndexedSeq[(String, IR)] => IEmitCode = { - case (name, value) +: rest => - cb.withScopedMaybeStreamValue(EmitCode.fromI(cb.emb)(cb => emit(value, cb, env = env)), s"let_$name") { ev => - go(env.bind(name, ev))(rest) - } - case Seq() => - produce(body, cb, env = env) - } - go(env)(bindings) + case let: Block => + val newEnv = emitter.emitBlock(let, cb, env, outerRegion, container, None) + produce(let.body, cb, env = newEnv) case In(n, _) => // this, Code[Region], ... val param = env.inputValues(n).toI(cb) if (!param.st.isInstanceOf[SStream]) - throw new RuntimeException(s"parameter ${ 2 + n } is not a stream! t=${ param.st } }, params=${ mb.emitParamTypes }") + throw new RuntimeException( + s"parameter ${2 + n} is not a stream! 
t=${param.st} }, params=${mb.emitParamTypes}" + ) param case ToStream(a, _requiresMemoryManagementPerElement) => - emit(a, cb).map(cb) { case _ind: SIndexableValue => val containerField = cb.memoizeField(_ind, "indexable").asIndexable val container = containerField.asInstanceOf[SIndexableValue] val idx = mb.genFieldThisRef[Int]("tostream_idx") val regionVar = mb.genFieldThisRef[Region]("tostream_region") - SStreamValue( new StreamProducer { override def method: EmitMethodBuilder[_] = mb - override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { + override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = cb.assign(idx, -1) - } - override val length: Option[EmitCodeBuilder => Code[Int]] = Some(_ => container.loadLength()) + override val length: Option[EmitCodeBuilder => Code[Int]] = + Some(_ => container.loadLength()) override val elementRegion: Settable[Region] = regionVar - override val requiresMemoryManagementPerElement: Boolean = _requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + _requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.assign(idx, idx + 1) @@ -350,16 +406,23 @@ object EmitStream { cb.goto(LproduceElementDone) } - val element: EmitCode = EmitCode.fromI(mb) { cb => - container.loadElement(cb, idx) } + val element: EmitCode = EmitCode.fromI(mb)(cb => container.loadElement(cb, idx)) def close(cb: EmitCodeBuilder): Unit = {} - }) + } + ) } - case x@StreamBufferedAggregate(streamChild, initAggs, newKey, seqOps, name, - aggSignatures: IndexedSeq[PhysicalAggSig], bufferSize: Int) => + case x @ StreamBufferedAggregate( + streamChild, + initAggs, + newKey, + seqOps, + name, + aggSignatures: IndexedSeq[PhysicalAggSig], + bufferSize: Int, + ) => val region = mb.genFieldThisRef[Region]("stream_buff_agg_region") produce(streamChild, cb) .map(cb) { case childStream: SStreamValue => @@ -373,24 +436,39 @@ object EmitStream { val maxSize = mb.genFieldThisRef[Int]("stream_buff_agg_max_size") val nodeArray = mb.genFieldThisRef[Array[Long]]("stream_buff_agg_element_array") val idx = mb.genFieldThisRef[Int]("stream_buff_agg_idx") - val returnStreamType= x.typ.asInstanceOf[TStream] + val returnStreamType = x.typ.asInstanceOf[TStream] val returnElemType = returnStreamType.elementType val tupleFieldTypes = aggSignatures.map(_ => TBinary) - val tupleFields = (0 to tupleFieldTypes.length).zip(tupleFieldTypes).map { case (fieldIdx, fieldType) => TupleField(fieldIdx, fieldType) } - val serializedAggSType = SStackStruct(TTuple(tupleFields), tupleFieldTypes.map(_ => EmitType(SBinaryPointer(PCanonicalBinary()), true)).toIndexedSeq) - val keyAndAggFields = newKeyVType.canonicalPType.asInstanceOf[PCanonicalStruct].sType.fieldEmitTypes :+ EmitType(serializedAggSType, true) - val returnElemSType = SStackStruct(returnElemType.asInstanceOf[TBaseStruct], keyAndAggFields) - val newStreamElem = mb.newEmitField("stream_buff_agg_new_stream_elem", EmitType(returnElemSType, true)) + val tupleFields = (0 to tupleFieldTypes.length).zip(tupleFieldTypes).map { + case (fieldIdx, fieldType) => TupleField(fieldIdx, fieldType) + } + val serializedAggSType = SStackStruct( + TTuple(tupleFields), + tupleFieldTypes.map(_ => + EmitType(SBinaryPointer(PCanonicalBinary()), true) + ).toIndexedSeq, + ) + val keyAndAggFields = newKeyVType.canonicalPType.asInstanceOf[ + PCanonicalStruct + ].sType.fieldEmitTypes :+ EmitType(serializedAggSType, true) + val returnElemSType = + 
SStackStruct(returnElemType.asInstanceOf[TBaseStruct], keyAndAggFields) + val newStreamElem = + mb.newEmitField("stream_buff_agg_new_stream_elem", EmitType(returnElemSType, true)) val numElemInArray = mb.genFieldThisRef[Int]("stream_buff_agg_num_elem_in_size") val childStreamEnded = mb.genFieldThisRef[Boolean]("stream_buff_agg_child_stream_ended") - val produceElementMode = mb.genFieldThisRef[Boolean]("stream_buff_agg_child_produce_elt_mode") + val produceElementMode = + mb.genFieldThisRef[Boolean]("stream_buff_agg_child_produce_elt_mode") val producer: StreamProducer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb override val length: Option[EmitCodeBuilder => Code[Int]] = None override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { if (childProducer.requiresMemoryManagementPerElement) - cb.assign(childProducer.elementRegion, Region.stagedCreate(Region.REGULAR, outerRegion.getPool())) + cb.assign( + childProducer.elementRegion, + Region.stagedCreate(Region.REGULAR, outerRegion.getPool()), + ) else cb.assign(childProducer.elementRegion, region) @@ -406,7 +484,8 @@ object EmitStream { override val elementRegion: Settable[Region] = region - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => val elementProduceLabel = CodeLabel() @@ -414,14 +493,16 @@ object EmitStream { val startLabel = CodeLabel() cb.define(startLabel) - cb.if_(produceElementMode, { - cb.goto(elementProduceLabel) - }) + cb.if_(produceElementMode, cb.goto(elementProduceLabel)) // Garbage collects old aggregator state if moving onto new group dictState.newState(cb) - val initContainer = AggContainer(aggSignatures.toArray.map(sig => sig.state), dictState.initContainer, cleanup = () => ()) - dictState.init(cb, { cb => emitVoid(initAggs, cb, container = Some(initContainer)) }) + val initContainer = AggContainer( + aggSignatures.toArray.map(sig => sig.state), + dictState.initContainer, + cleanup = () => (), + ) + dictState.init(cb, cb => emitVoid(initAggs, cb, container = Some(initContainer))) cb.define(getElemLabel) if (childProducer.requiresMemoryManagementPerElement) @@ -431,63 +512,84 @@ object EmitStream { cb.define(childProducer.LproduceElementDone) cb.assign(eltField, childProducer.element) val newKeyResultCode = EmitCode.fromI(mb) { cb => - emit(newKey, - cb = cb, - env = env.bind(name, eltField), - region = region) + emit(newKey, cb = cb, env = env.bind(name, eltField), region = region) } val resultKeyValue = newKeyResultCode.memoize(cb, "buff_agg_stream_result_key") - val keyedContainer = AggContainer(aggSignatures.toArray.map(sig => sig.state), dictState.keyed.container, cleanup = () => ()) - dictState.withContainer(cb, resultKeyValue, { cb => - emitVoid(seqOps, cb, container = Some(keyedContainer), env = env.bind(name, eltField)) - }) - cb.if_(dictState.size >= maxSize,{ - cb.assign(produceElementMode, true) - }) - - cb.if_(produceElementMode, { - cb.goto(elementProduceLabel)}, - { - cb.goto(getElemLabel) - } + val keyedContainer = AggContainer( + aggSignatures.toArray.map(sig => sig.state), + dictState.keyed.container, + cleanup = () => (), ) + dictState.withContainer( + cb, + resultKeyValue, + cb => + emitVoid( + seqOps, + cb, + container = Some(keyedContainer), + env = env.bind(name, eltField), + ), + ) + cb.if_(dictState.size 
>= maxSize, cb.assign(produceElementMode, true)) + + cb.if_(produceElementMode, cb.goto(elementProduceLabel), cb.goto(getElemLabel)) cb.define(childProducer.LendOfStream) cb.assign(childStreamEnded, true) cb.assign(produceElementMode, true) cb.define(elementProduceLabel) - cb.if_(numElemInArray ceq 0, { + cb.if_( + numElemInArray ceq 0, dictState.tree.foreach(cb) { (cb, elementOff) => cb += nodeArray.update(numElemInArray, elementOff) cb.assign(numElemInArray, numElemInArray + 1) - } - }) + }, + ) - cb.if_(numElemInArray <= idx, { - cb.assign(idx, 0) - cb.assign(numElemInArray, 0) - cb.assign(produceElementMode, false) - cb.if_(childStreamEnded , { - cb.goto(LendOfStream) - }, { - cb.goto(startLabel) - }) - }) + cb.if_( + numElemInArray <= idx, { + cb.assign(idx, 0) + cb.assign(numElemInArray, 0) + cb.assign(produceElementMode, false) + cb.if_(childStreamEnded, cb.goto(LendOfStream), cb.goto(startLabel)) + }, + ) val nodeAddress = cb.memoize(nodeArray(idx)) cb.assign(idx, idx + 1) dictState.loadNode(cb, nodeAddress) val keyInWrongRegion = dictState.keyed.storageType.loadCheapSCode(cb, nodeAddress) - val addrOfKeyInRightRegion = dictState.keyed.storageType.store(cb, region, keyInWrongRegion, true) - val key = dictState.keyed.storageType.loadCheapSCode(cb, addrOfKeyInRightRegion).loadField(cb, "kt").memoize(cb, "stream_buff_agg_key_right_region") - - val serializedAggValue = keyedContainer.container.states.states.map(state => state.serializeToRegion(cb, PCanonicalBinary(), region)) - val serializedAggEmitCodes = serializedAggValue.map(aggValue => EmitCode.present(mb, aggValue)) - val serializedAggTupleSValue = SStackStruct.constructFromArgs(cb, region, serializedAggSType.virtualType, serializedAggEmitCodes: _*) + val addrOfKeyInRightRegion = + dictState.keyed.storageType.store(cb, region, keyInWrongRegion, true) + val key = dictState.keyed.storageType.loadCheapSCode( + cb, + addrOfKeyInRightRegion, + ).loadField(cb, "kt").memoize(cb, "stream_buff_agg_key_right_region") + + val serializedAggValue = keyedContainer.container.states.states.map(state => + state.serializeToRegion(cb, PCanonicalBinary(), region) + ) + val serializedAggEmitCodes = + serializedAggValue.map(aggValue => EmitCode.present(mb, aggValue)) + val serializedAggTupleSValue = SStackStruct.constructFromArgs( + cb, + region, + serializedAggSType.virtualType, + serializedAggEmitCodes: _* + ) val keyValue = key.get(cb).asInstanceOf[SBaseStructValue] - val sStructToReturn = keyValue.insert(cb, region, returnElemType.asInstanceOf[TStruct], ("agg", EmitCode.present(mb, serializedAggTupleSValue) - .memoize(cb, "stream_buff_agg_return_val"))) + val sStructToReturn = keyValue.insert( + cb, + region, + returnElemType.asInstanceOf[TStruct], + ( + "agg", + EmitCode.present(mb, serializedAggTupleSValue) + .memoize(cb, "stream_buff_agg_return_val"), + ), + ) assert(returnElemSType.virtualType == sStructToReturn.st.virtualType) val casted = sStructToReturn.castTo(cb, region, returnElemSType) cb.assign(newStreamElem, EmitCode.present(mb, casted).toI(cb)) @@ -506,7 +608,7 @@ object EmitStream { SStreamValue(producer) } - case x@MakeStream(args, _, _requiresMemoryManagementPerElement) => + case MakeStream(args, _, _requiresMemoryManagementPerElement) => val region = mb.genFieldThisRef[Region]("makestream_region") // FIXME use SType.chooseCompatibleType @@ -516,40 +618,47 @@ object EmitStream { val staticLen = args.size val current = mb.genFieldThisRef[Int]("makestream_current") - IEmitCode.present(cb, SStreamValue( - new StreamProducer { - 
override def method: EmitMethodBuilder[_] = mb - override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { - cb.assign(current, 0) // switches on 1..N - } + IEmitCode.present( + cb, + SStreamValue( + new StreamProducer { + override def method: EmitMethodBuilder[_] = mb + override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = + cb.assign(current, 0) // switches on 1..N - override val length: Option[EmitCodeBuilder => Code[Int]] = Some(_ => staticLen) + override val length: Option[EmitCodeBuilder => Code[Int]] = Some(_ => staticLen) - override val elementRegion: Settable[Region] = region + override val elementRegion: Settable[Region] = region - override val requiresMemoryManagementPerElement: Boolean = _requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + _requiresMemoryManagementPerElement - override val LproduceElement: CodeLabel = - mb.defineAndImplementLabel { cb => - cb.switch(current, - cb.goto(LendOfStream), - args.map { a => - () => + override val LproduceElement: CodeLabel = + mb.defineAndImplementLabel { cb => + cb.switch( + current, + cb.goto(LendOfStream), + args.map { a => () => val elem = emit(a, cb, region) - cb.assign(eltField, elem.map(cb)(pc => pc.castTo(cb, region, unifiedType.st, false))) - } - ) + cb.assign( + eltField, + elem.map(cb)(pc => pc.castTo(cb, region, unifiedType.st, false)), + ) + }, + ) - cb.assign(current, current + 1) - cb.goto(LproduceElementDone) - } + cb.assign(current, current + 1) + cb.goto(LproduceElementDone) + } - val element: EmitCode = eltField.load + val element: EmitCode = eltField.load - def close(cb: EmitCodeBuilder): Unit = {} - })) + def close(cb: EmitCodeBuilder): Unit = {} + } + ), + ) - case x@If(cond, cnsq, altr) => + case If(cond, cnsq, altr) => emit(cond, cb).flatMap(cb) { cond => val xCond = mb.genFieldThisRef[Boolean]("stream_if_cond") cb.assign(xCond, cond.asBoolean.value) @@ -569,35 +678,44 @@ object EmitStream { val xElt = mb.newEmitField(unifiedElementType) val region = mb.genFieldThisRef[Region]("streamif_region") - cb.if_(xCond, + cb.if_( + xCond, leftEC.toI(cb).consume(cb, cb.goto(Lmissing), _ => cb.goto(Lpresent)), - rightEC.toI(cb).consume(cb, cb.goto(Lmissing), _ => cb.goto(Lpresent))) + rightEC.toI(cb).consume(cb, cb.goto(Lmissing), _ => cb.goto(Lpresent)), + ) val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb override val length: Option[EmitCodeBuilder => Code[Int]] = leftProducer.length .liftedZip(rightProducer.length).map { case (computeL1, computeL2) => - cb: EmitCodeBuilder => { - val len = cb.newLocal[Int]("if_len") - cb.if_(xCond, cb.assign(len, computeL1(cb)), cb.assign(len, computeL2(cb))) - len.get + cb: EmitCodeBuilder => { + val len = cb.newLocal[Int]("if_len") + cb.if_(xCond, cb.assign(len, computeL1(cb)), cb.assign(len, computeL2(cb))) + len.get + } } - } override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { - cb.if_(xCond, { - cb.assign(leftProducer.elementRegion, region) - leftProducer.initialize(cb, outerRegion) - }, { - cb.assign(rightProducer.elementRegion, region) - rightProducer.initialize(cb, outerRegion) - }) + cb.if_( + xCond, { + cb.assign(leftProducer.elementRegion, region) + leftProducer.initialize(cb, outerRegion) + }, { + cb.assign(rightProducer.elementRegion, region) + rightProducer.initialize(cb, outerRegion) + }, + ) } override val elementRegion: Settable[Region] = region - override val requiresMemoryManagementPerElement: Boolean = 
leftProducer.requiresMemoryManagementPerElement || rightProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + leftProducer.requiresMemoryManagementPerElement || rightProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - cb.if_(xCond, cb.goto(leftProducer.LproduceElement), cb.goto(rightProducer.LproduceElement)) + cb.if_( + xCond, + cb.goto(leftProducer.LproduceElement), + cb.goto(rightProducer.LproduceElement), + ) cb.define(leftProducer.LproduceElementDone) cb.assign(xElt, leftProducer.element.toI(cb).map(cb)(_.castTo(cb, region, xElt.st))) @@ -616,14 +734,11 @@ object EmitStream { override val element: EmitCode = xElt.load - override def close(cb: EmitCodeBuilder): Unit = { + override def close(cb: EmitCodeBuilder): Unit = cb.if_(xCond, leftProducer.close(cb), rightProducer.close(cb)) - } } - IEmitCode(Lmissing, Lpresent, - SStreamValue(producer), - leftEC.required && rightEC.required) + IEmitCode(Lmissing, Lpresent, SStreamValue(producer), leftEC.required && rightEC.required) } case StreamIota(start, step, _requiresMemoryManagementPerElement) => @@ -644,7 +759,8 @@ object EmitStream { override val elementRegion: Settable[Region] = regionVar - override val requiresMemoryManagementPerElement: Boolean = _requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + _requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.assign(curr, curr + stepVar) @@ -659,7 +775,8 @@ object EmitStream { } } - case StreamRange(start, stop, I32(step), _requiresMemoryManagementPerElement, errorID) if (step != 0) => + case StreamRange(start, stop, I32(step), _requiresMemoryManagementPerElement, errorID) + if (step != 0) => emit(start, cb).flatMap(cb) { startCode => emit(stop, cb).map(cb) { stopCode => val curr = mb.genFieldThisRef[Int]("streamrange_curr") @@ -675,30 +792,42 @@ object EmitStream { override val length: Option[EmitCodeBuilder => Code[Int]] = Some({ cb => val len = cb.newLocal[Int]("streamrange_len") if (step > 0) - cb.if_(startVar >= stopVar, + cb.if_( + startVar >= stopVar, cb.assign(len, 0), - cb.assign(len, ((stopVar.toL - startVar.toL - 1L) / step.toLong + 1L).toI)) + cb.assign(len, ((stopVar.toL - startVar.toL - 1L) / step.toLong + 1L).toI), + ) else - cb.if_(startVar <= stopVar, + cb.if_( + startVar <= stopVar, cb.assign(len, 0), - cb.assign(len, ((startVar.toL - stopVar.toL - 1L) / (-step.toLong) + 1L).toI)) + cb.assign(len, ((startVar.toL - stopVar.toL - 1L) / (-step.toLong) + 1L).toI), + ) len }) override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { start match { - case I32(x) if step < 0 && ((x.toLong - Int.MinValue.toLong) / step.toLong + 1) < Int.MaxValue => - case I32(x) if step > 0 && ((Int.MaxValue.toLong - x.toLong) / step.toLong + 1) < Int.MaxValue => + case I32(x) + if step < 0 && ((x.toLong - Int.MinValue.toLong) / step.toLong + 1) < Int.MaxValue => + case I32(x) + if step > 0 && ((Int.MaxValue.toLong - x.toLong) / step.toLong + 1) < Int.MaxValue => case _ => - cb.if_((stopVar.toL - startVar.toL) / step.toLong > const(Int.MaxValue.toLong), - cb._fatalWithError(errorID, "Array range cannot have more than MAXINT elements.")) + cb.if_( + (stopVar.toL - startVar.toL) / step.toLong > const(Int.MaxValue.toLong), + cb._fatalWithError( + errorID, + "Array range cannot have more than MAXINT elements.", + ), + ) } cb.assign(curr, startVar - 
step) } override val elementRegion: Settable[Region] = regionVar - override val requiresMemoryManagementPerElement: Boolean = _requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + _requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.assign(curr, curr + step) @@ -719,7 +848,6 @@ object EmitStream { } case StreamRange(startIR, stopIR, stepIR, _requiresMemoryManagementPerElement, errorID) => - emit(startIR, cb).flatMap(cb) { startc => emit(stopIR, cb).flatMap(cb) { stopc => emit(stepIR, cb).map(cb) { stepc => @@ -744,23 +872,30 @@ object EmitStream { override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { val llen = cb.newLocal[Long]("streamrange_llen") - cb.if_(step ceq const(0), cb._fatalWithError(errorID, "Array range cannot have step size 0.")) - cb.if_(step < const(0), { - cb.if_(start.toL <= stop.toL, { - cb.assign(llen, 0L) - }, { - cb.assign(llen, (start.toL - stop.toL - 1L) / (-step.toL) + 1L) - }) - }, { - cb.if_(start.toL >= stop.toL, { - cb.assign(llen, 0L) - }, { - cb.assign(llen, (stop.toL - start.toL - 1L) / step.toL + 1L) - }) - }) - cb.if_(llen > const(Int.MaxValue.toLong), { - cb._fatalWithError(errorID, "Array range cannot have more than MAXINT elements.") - }) + cb.if_( + step ceq const(0), + cb._fatalWithError(errorID, "Array range cannot have step size 0."), + ) + cb.if_( + step < const(0), + cb.if_( + start.toL <= stop.toL, + cb.assign(llen, 0L), + cb.assign(llen, (start.toL - stop.toL - 1L) / (-step.toL) + 1L), + ), + cb.if_( + start.toL >= stop.toL, + cb.assign(llen, 0L), + cb.assign(llen, (stop.toL - start.toL - 1L) / step.toL + 1L), + ), + ) + cb.if_( + llen > const(Int.MaxValue.toLong), + cb._fatalWithError( + errorID, + "Array range cannot have more than MAXINT elements.", + ), + ) cb.assign(len, llen.toI) cb.assign(curr, start - step) @@ -769,7 +904,8 @@ object EmitStream { override val elementRegion: Settable[Region] = regionVar - override val requiresMemoryManagementPerElement: Boolean = _requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + _requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.if_(idx >= len, cb.goto(LendOfStream)) @@ -788,8 +924,7 @@ object EmitStream { } } - - case SeqSample(totalSize, numToSample, rngState, _requiresMemoryManagementPerElement) => + case SeqSample(totalSize, numToSample, _, _requiresMemoryManagementPerElement) => // Implemented based on http://www.ittc.ku.edu/~jsv/Papers/Vit84.sampling.pdf Algorithm A emit(totalSize, cb).flatMap(cb) { case totalSizeVal: SInt32Value => emit(numToSample, cb).map(cb) { case numToSampleVal: SInt32Value => @@ -798,7 +933,8 @@ object EmitStream { val nRemaining = cb.newField[Int]("seq_sample_num_remaining", numToSampleVal.value) val candidate = cb.newField[Int]("seq_sample_candidate", 0) - val elementToReturn = cb.newField[Int]("seq_sample_element_to_return", -1) // -1 should never be returned. + val elementToReturn = + cb.newField[Int]("seq_sample_element_to_return", -1) // -1 should never be returned. 
val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb @@ -813,25 +949,38 @@ object EmitStream { override val elementRegion: Settable[Region] = regionVar - override val requiresMemoryManagementPerElement: Boolean = _requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + _requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.if_(nRemaining <= 0, cb.goto(LendOfStream)) - val u = cb.newLocal[Double]("seq_sample_rand_unif", Code.invokeStatic0[Math, Double]("random")) - val fC = cb.newLocal[Double]("seq_sample_Fc", (totalSizeVal.value - candidate - nRemaining).toD / (totalSizeVal.value - candidate).toD) + val u = cb.newLocal[Double]( + "seq_sample_rand_unif", + Code.invokeStatic0[Math, Double]("random"), + ) + val fC = cb.newLocal[Double]( + "seq_sample_Fc", + (totalSizeVal.value - candidate - nRemaining).toD / (totalSizeVal.value - candidate).toD, + ) - cb.while_(fC > u, { - cb.assign(candidate, candidate + 1) - cb.assign(fC, fC * (const(1.0) - (nRemaining.toD / (totalSizeVal.value - candidate).toD))) - }) + cb.while_( + fC > u, { + cb.assign(candidate, candidate + 1) + cb.assign( + fC, + fC * (const(1.0) - (nRemaining.toD / (totalSizeVal.value - candidate).toD)), + ) + }, + ) cb.assign(nRemaining, nRemaining - 1) cb.assign(elementToReturn, candidate) cb.assign(candidate, candidate + 1) cb.goto(LproduceElementDone) } - override val element: EmitCode = EmitCode.present(mb, new SInt32Value(elementToReturn)) + override val element: EmitCode = + EmitCode.present(mb, new SInt32Value(elementToReturn)) override def close(cb: EmitCodeBuilder): Unit = {} } @@ -846,7 +995,8 @@ object EmitStream { val filterEltRegion = mb.genFieldThisRef[Region]("streamfilter_filter_region") - val elementField = cb.emb.newEmitField("streamfilter_cond", childProducer.element.emitType) + val elementField = + cb.emb.newEmitField("streamfilter_cond", childProducer.element.emitType) val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb @@ -854,7 +1004,10 @@ object EmitStream { override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { if (childProducer.requiresMemoryManagementPerElement) - cb.assign(childProducer.elementRegion, Region.stagedCreate(Region.REGULAR, outerRegion.getPool())) + cb.assign( + childProducer.elementRegion, + Region.stagedCreate(Region.REGULAR, outerRegion.getPool()), + ) else cb.assign(childProducer.elementRegion, outerRegion) childProducer.initialize(cb, outerRegion) @@ -862,7 +1015,8 @@ object EmitStream { override val elementRegion: Settable[Region] = filterEltRegion - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => val Lfiltered = CodeLabel() @@ -872,12 +1026,17 @@ object EmitStream { cb.define(childProducer.LproduceElementDone) cb.assign(elementField, childProducer.element) // false and NA both fail the filter - emit(cond, cb = cb, env = env.bind(name, elementField), region = childProducer.elementRegion) - .consume(cb, + emit( + cond, + cb = cb, + env = env.bind(name, elementField), + region = childProducer.elementRegion, + ) + .consume( + cb, cb.goto(Lfiltered), - { sc => - cb.if_(!sc.asBoolean.value, cb.goto(Lfiltered)) - }) + sc => cb.if_(!sc.asBoolean.value, 
cb.goto(Lfiltered)), + ) if (requiresMemoryManagementPerElement) cb += filterEltRegion.takeOwnershipOfAndClear(childProducer.elementRegion) @@ -897,9 +1056,7 @@ object EmitStream { cb += childProducer.elementRegion.invalidate() } } - mb.implementLabel(childProducer.LendOfStream) { cb => - cb.goto(producer.LendOfStream) - } + mb.implementLabel(childProducer.LendOfStream)(cb => cb.goto(producer.LendOfStream)) SStreamValue(producer) } @@ -919,13 +1076,17 @@ object EmitStream { override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { cb.assign(n, num.value) - cb.if_(n < 0, cb._fatal(s"stream take: negative number of elements to take: ", n.toS)) + cb.if_( + n < 0, + cb._fatal(s"stream take: negative number of elements to take: ", n.toS), + ) cb.assign(idx, 0) childProducer.initialize(cb, outerRegion) } override val elementRegion: Settable[Region] = childProducer.elementRegion - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.if_(idx >= n, cb.goto(LendOfStream)) cb.assign(idx, idx + 1) @@ -939,9 +1100,8 @@ object EmitStream { } override val element: EmitCode = childProducer.element - override def close(cb: EmitCodeBuilder): Unit = { + override def close(cb: EmitCodeBuilder): Unit = childProducer.close(cb) - } } SStreamValue(producer) @@ -958,28 +1118,35 @@ object EmitStream { val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb - override val length: Option[EmitCodeBuilder => Code[Int]] = childProducer.length.map { computeL => - (cb: EmitCodeBuilder) => (computeL(cb) - n).max(0) - } + override val length: Option[EmitCodeBuilder => Code[Int]] = + childProducer.length.map { computeL => (cb: EmitCodeBuilder) => + (computeL(cb) - n).max(0) + } override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { cb.assign(n, num.value) - cb.if_(n < 0, cb._fatal(s"stream drop: negative number of elements to drop: ", n.toS)) + cb.if_( + n < 0, + cb._fatal(s"stream drop: negative number of elements to drop: ", n.toS), + ) cb.assign(idx, 0) childProducer.initialize(cb, outerRegion) } override val elementRegion: Settable[Region] = childProducer.elementRegion - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.goto(childProducer.LproduceElement) cb.define(childProducer.LproduceElementDone) cb.assign(idx, idx + 1) - cb.if_(idx <= n, { - if (childProducer.requiresMemoryManagementPerElement) - cb += childProducer.elementRegion.clearRegion() - cb.goto(childProducer.LproduceElement) - }) + cb.if_( + idx <= n, { + if (childProducer.requiresMemoryManagementPerElement) + cb += childProducer.elementRegion.clearRegion() + cb.goto(childProducer.LproduceElement) + }, + ) cb.goto(LproduceElementDone) cb.define(childProducer.LendOfStream) @@ -987,9 +1154,8 @@ object EmitStream { } override val element: EmitCode = childProducer.element - override def close(cb: EmitCodeBuilder): Unit = { + override def close(cb: EmitCodeBuilder): Unit = childProducer.close(cb) - } } SStreamValue(producer) @@ -1001,29 +1167,40 @@ object EmitStream { .map(cb) { case childStream: 
SStreamValue => val childProducer = childStream.getProducer(mb) - val eltSettable = mb.newEmitField("stream_take_while_elt", childProducer.element.emitType) + val eltSettable = + mb.newEmitField("stream_take_while_elt", childProducer.element.emitType) val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb override val length: Option[EmitCodeBuilder => Code[Int]] = None - override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { + override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = childProducer.initialize(cb, outerRegion) - } override val elementRegion: Settable[Region] = childProducer.elementRegion - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.goto(childProducer.LproduceElement) cb.define(childProducer.LproduceElementDone) cb.assign(eltSettable, childProducer.element) - emit(condIR, cb, region = childProducer.elementRegion, env = env.bind(elt, eltSettable)) - .consume(cb, + emit( + condIR, + cb, + region = childProducer.elementRegion, + env = env.bind(elt, eltSettable), + ) + .consume( + cb, cb.goto(LendOfStream), - code => cb.if_(code.asBoolean.value, - cb.goto(LproduceElementDone), - cb.goto(LendOfStream))) + code => + cb.if_( + code.asBoolean.value, + cb.goto(LproduceElementDone), + cb.goto(LendOfStream), + ), + ) cb.define(childProducer.LendOfStream) cb.goto(LendOfStream) @@ -1031,9 +1208,8 @@ object EmitStream { override val element: EmitCode = eltSettable - override def close(cb: EmitCodeBuilder): Unit = { + override def close(cb: EmitCodeBuilder): Unit = childProducer.close(cb) - } } SStreamValue(producer) @@ -1043,7 +1219,8 @@ object EmitStream { produce(a, cb) .map(cb) { case childStream: SStreamValue => val childProducer = childStream.getProducer(mb) - val eltSettable = mb.newEmitField("stream_drop_while_elt", childProducer.element.emitType) + val eltSettable = + mb.newEmitField("stream_drop_while_elt", childProducer.element.emitType) val doneComparisons = mb.genFieldThisRef[Boolean]("stream_drop_while_donecomparisons") val producer = new StreamProducer { @@ -1056,9 +1233,9 @@ object EmitStream { } override val elementRegion: Settable[Region] = childProducer.elementRegion - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - cb.goto(childProducer.LproduceElement) cb.define(childProducer.LproduceElementDone) cb.assign(eltSettable, childProducer.element) @@ -1067,12 +1244,17 @@ object EmitStream { val LdropThis = CodeLabel() val LdoneDropping = CodeLabel() - emit(condIR, cb, region = childProducer.elementRegion, env = env.bind(elt, eltSettable)) - .consume(cb, + emit( + condIR, + cb, + region = childProducer.elementRegion, + env = env.bind(elt, eltSettable), + ) + .consume( + cb, cb.goto(LdoneDropping), - code => cb.if_(code.asBoolean.value, - cb.goto(LdropThis), - cb.goto(LdoneDropping))) + code => cb.if_(code.asBoolean.value, cb.goto(LdropThis), cb.goto(LdoneDropping)), + ) cb.define(LdropThis) if (childProducer.requiresMemoryManagementPerElement) @@ -1088,9 +1270,8 @@ object EmitStream { } override val 
element: EmitCode = eltSettable - override def close(cb: EmitCodeBuilder): Unit = { + override def close(cb: EmitCodeBuilder): Unit = childProducer.close(cb) - } } SStreamValue(producer) @@ -1102,11 +1283,14 @@ object EmitStream { val childProducer = childStream.getProducer(mb) val bodyResult = EmitCode.fromI(mb) { cb => - cb.withScopedMaybeStreamValue(childProducer.element, "streammap_element") { childProducerElement => - emit(body, - cb = cb, - env = env.bind(name, childProducerElement), - region = childProducer.elementRegion) + cb.withScopedMaybeStreamValue(childProducer.element, "streammap_element") { + childProducerElement => + emit( + body, + cb = cb, + env = env.bind(name, childProducerElement), + region = childProducer.elementRegion, + ) } } @@ -1114,13 +1298,13 @@ object EmitStream { override def method: EmitMethodBuilder[_] = mb override val length: Option[EmitCodeBuilder => Code[Int]] = childProducer.length - override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { + override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = childProducer.initialize(cb, outerRegion) - } override val elementRegion: Settable[Region] = childProducer.elementRegion - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.goto(childProducer.LproduceElement) @@ -1133,24 +1317,27 @@ object EmitStream { def close(cb: EmitCodeBuilder): Unit = childProducer.close(cb) } - mb.implementLabel(childProducer.LendOfStream) { cb => - cb.goto(producer.LendOfStream) - } + mb.implementLabel(childProducer.LendOfStream)(cb => cb.goto(producer.LendOfStream)) SStreamValue(producer) } - case x@StreamScan(childIR, zeroIR, accName, eltName, bodyIR) => + case x @ StreamScan(childIR, zeroIR, accName, eltName, bodyIR) => produce(childIR, cb).map(cb) { case childStream: SStreamValue => val childProducer = childStream.getProducer(mb) - val accEmitType = VirtualTypeWithReq(zeroIR.typ, emitter.ctx.req.lookupState(x).head.asInstanceOf[TypeWithRequiredness]).canonicalEmitType + val accEmitType = VirtualTypeWithReq( + zeroIR.typ, + emitter.ctx.req.lookupState(x).head.asInstanceOf[TypeWithRequiredness], + ).canonicalEmitType val accValueAccRegion = mb.newEmitField(accEmitType) val accValueEltRegion = mb.newEmitField(accEmitType) // accRegion is unused if requiresMemoryManagementPerElement is false - val accRegion: Settable[Region] = if (childProducer.requiresMemoryManagementPerElement) mb.genFieldThisRef[Region]("streamscan_acc_region") else null + val accRegion: Settable[Region] = if (childProducer.requiresMemoryManagementPerElement) + mb.genFieldThisRef[Region]("streamscan_acc_region") + else null val first = mb.genFieldThisRef[Boolean]("streamscan_first") val producer = new StreamProducer { @@ -1169,41 +1356,68 @@ object EmitStream { override val elementRegion: Settable[Region] = childProducer.elementRegion - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - val LcopyAndReturn = CodeLabel() - cb.if_(first, { - - cb.assign(first, false) - cb.assign(accValueEltRegion, emit(zeroIR, cb, region = 
elementRegion).map(cb)(sc => sc.castTo(cb, elementRegion, accValueAccRegion.st))) + cb.if_( + first, { - cb.goto(LcopyAndReturn) - }) + cb.assign(first, false) + cb.assign( + accValueEltRegion, + emit(zeroIR, cb, region = elementRegion).map(cb)(sc => + sc.castTo(cb, elementRegion, accValueAccRegion.st) + ), + ) + cb.goto(LcopyAndReturn) + }, + ) cb.goto(childProducer.LproduceElement) cb.define(childProducer.LproduceElementDone) if (requiresMemoryManagementPerElement) { // deep copy accumulator into element region, then clear accumulator region - cb.assign(accValueEltRegion, accValueAccRegion.toI(cb).map(cb)(_.castTo(cb, childProducer.elementRegion, accEmitType.st, deepCopy = true))) + cb.assign( + accValueEltRegion, + accValueAccRegion.toI(cb).map(cb)(_.castTo( + cb, + childProducer.elementRegion, + accEmitType.st, + deepCopy = true, + )), + ) cb += accRegion.clearRegion() } - val bodyCode = cb.withScopedMaybeStreamValue(childProducer.element, "scan_child_elt") { ev => - emit(bodyIR, cb, env = env.bind((accName, accValueEltRegion), (eltName, ev)), region = childProducer.elementRegion) - .map(cb)(pc => pc.castTo(cb, childProducer.elementRegion, accEmitType.st, deepCopy = false)) - } + val bodyCode = + cb.withScopedMaybeStreamValue(childProducer.element, "scan_child_elt") { ev => + emit( + bodyIR, + cb, + env = env.bind((accName, accValueEltRegion), (eltName, ev)), + region = childProducer.elementRegion, + ) + .map(cb)(pc => + pc.castTo(cb, childProducer.elementRegion, accEmitType.st, deepCopy = false) + ) + } cb.assign(accValueEltRegion, bodyCode) cb.define(LcopyAndReturn) if (requiresMemoryManagementPerElement) { - cb.assign(accValueAccRegion, accValueEltRegion.toI(cb).map(cb)(pc => pc.castTo(cb, accRegion, accEmitType.st, deepCopy = true))) + cb.assign( + accValueAccRegion, + accValueEltRegion.toI(cb).map(cb)(pc => + pc.castTo(cb, accRegion, accEmitType.st, deepCopy = true) + ), + ) } cb.goto(LproduceElementDone) @@ -1218,23 +1432,30 @@ object EmitStream { } } - mb.implementLabel(childProducer.LendOfStream) { cb => - cb.goto(producer.LendOfStream) - } + mb.implementLabel(childProducer.LendOfStream)(cb => cb.goto(producer.LendOfStream)) SStreamValue(producer) } case RunAggScan(child, name, init, seqs, result, states) => - val (newContainer, aggSetup, aggCleanup) = AggContainer.fromMethodBuilder(states.toArray, mb, "run_agg_scan") + val (newContainer, aggSetup, aggCleanup) = + AggContainer.fromMethodBuilder(states.toArray, mb, "run_agg_scan") produce(child, cb).map(cb) { case childStream: SStreamValue => val childProducer = childStream.getProducer(mb) - val childEltField = mb.newEmitField("runaggscan_child_elt", childProducer.element.emitType) + val childEltField = + mb.newEmitField("runaggscan_child_elt", childProducer.element.emitType) val bodyEnv = env.bind(name -> childEltField) - val bodyResult = EmitCode.fromI(mb)(cb => emit(result, cb = cb, region = childProducer.elementRegion, - env = bodyEnv, container = Some(newContainer))) + val bodyResult = EmitCode.fromI(mb)(cb => + emit( + result, + cb = cb, + region = childProducer.elementRegion, + env = bodyEnv, + container = Some(newContainer), + ) + ) val bodyResultField = mb.newEmitField("runaggscan_result_elt", bodyResult.emitType) val producer = new StreamProducer { @@ -1248,13 +1469,20 @@ object EmitStream { } override val elementRegion: Settable[Region] = childProducer.elementRegion - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val 
requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.goto(childProducer.LproduceElement) cb.define(childProducer.LproduceElementDone) cb.assign(childEltField, childProducer.element) cb.assign(bodyResultField, bodyResult.toI(cb)) - emitVoid(seqs, cb, region = elementRegion, env = bodyEnv, container = Some(newContainer)) + emitVoid( + seqs, + cb, + region = elementRegion, + env = bodyEnv, + container = Some(newContainer), + ) cb.goto(LproduceElementDone) } override val element: EmitCode = bodyResultField.load @@ -1265,16 +1493,23 @@ object EmitStream { } } - mb.implementLabel(childProducer.LendOfStream) { cb => - cb.goto(producer.LendOfStream) - } + mb.implementLabel(childProducer.LendOfStream)(cb => cb.goto(producer.LendOfStream)) SStreamValue(producer) } - case StreamWhiten(stream, newChunkName, prevWindowName, vecSize, windowSize, chunkSize, blockSize, normalizeAfterWhiten) => + case StreamWhiten(stream, newChunkName, prevWindowName, vecSize, windowSize, chunkSize, + blockSize, normalizeAfterWhiten) => produce(stream, cb).map(cb) { case blocks: SStreamValue => - val state = new LocalWhitening(cb, SizeValueStatic(vecSize.toLong), windowSize.toLong, chunkSize.toLong, blockSize.toLong, outerRegion, normalizeAfterWhiten) + val state = new LocalWhitening( + cb, + SizeValueStatic(vecSize.toLong), + windowSize.toLong, + chunkSize.toLong, + blockSize.toLong, + outerRegion, + normalizeAfterWhiten, + ) val eltType = blocks.st.elementType.asInstanceOf[SBaseStruct] var resultField: SSettable = null @@ -1294,21 +1529,38 @@ object EmitStream { } override val elementRegion: Settable[Region] = blocksProducer.elementRegion - override val requiresMemoryManagementPerElement: Boolean = blocksProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + blocksProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => cb.goto(blocksProducer.LproduceElement) cb.define(blocksProducer.LproduceElementDone) - val row = blocksProducer.element.toI(cb).get(cb, "StreamWhiten: missing tuple").asBaseStruct - row.loadField(cb, prevWindowName).consume(cb, {}, { prevWindow => - state.initializeWindow(cb, prevWindow.asNDArray) - }) - val block = row.loadField(cb, newChunkName).get(cb, "StreamWhiten: missing chunk").asNDArray - val whitenedBlock = LinalgCodeUtils.checkColMajorAndCopyIfNeeded(block, cb, elementRegion) + val row = + blocksProducer.element.toI(cb).getOrFatal( + cb, + "StreamWhiten: missing tuple", + ).asBaseStruct + row.loadField(cb, prevWindowName).consume( + cb, + {}, + prevWindow => state.initializeWindow(cb, prevWindow.asNDArray), + ) + val block = + row.loadField(cb, newChunkName).getOrFatal( + cb, + "StreamWhiten: missing chunk", + ).asNDArray + val whitenedBlock = + LinalgCodeUtils.checkColMajorAndCopyIfNeeded(block, cb, elementRegion) state.whitenBlock(cb, whitenedBlock) // the 'newChunkName' field of 'row' is mutated in place and given // to the consumer - val result = row.insert(cb, elementRegion, eltType.virtualType.asInstanceOf[TStruct], newChunkName -> EmitValue.present(whitenedBlock)) + val result = row.insert( + cb, + elementRegion, + eltType.virtualType.asInstanceOf[TStruct], + newChunkName -> EmitValue.present(whitenedBlock), + ) resultField = mb.newPField("StreamWhiten_result", result.st) cb.assign(resultField, result) cb.goto(LproduceElementDone) @@ -1316,14 +1568,11 
@@ object EmitStream { override val element: EmitCode = EmitCode.present(mb, resultField) - override def close(cb: EmitCodeBuilder): Unit = { + override def close(cb: EmitCodeBuilder): Unit = blocksProducer.close(cb) - } } - mb.implementLabel(blocksProducer.LendOfStream) { cb => - cb.goto(producer.LendOfStream) - } + mb.implementLabel(blocksProducer.LendOfStream)(cb => cb.goto(producer.LendOfStream)) SStreamValue(producer) } @@ -1337,11 +1586,14 @@ object EmitStream { val innerUnclosed = mb.genFieldThisRef[Boolean]("flatmap_inner_unclosed") val innerStreamEmitCode = EmitCode.fromI(mb) { cb => - cb.withScopedMaybeStreamValue(outerProducer.element, "flatmap_outer_value") { outerProducerValue => - emit(body, - cb = cb, - env = env.bind(name, outerProducerValue), - region = outerProducer.elementRegion) + cb.withScopedMaybeStreamValue(outerProducer.element, "flatmap_outer_value") { + outerProducerValue => + emit( + body, + cb = cb, + env = env.bind(name, outerProducerValue), + region = outerProducer.elementRegion, + ) } } @@ -1358,7 +1610,10 @@ object EmitStream { cb.assign(innerUnclosed, false) if (outerProducer.requiresMemoryManagementPerElement) - cb.assign(outerProducer.elementRegion, Region.stagedCreate(Region.REGULAR, cb.emb.ecb.pool())) + cb.assign( + outerProducer.elementRegion, + Region.stagedCreate(Region.REGULAR, cb.emb.ecb.pool()), + ) else cb.assign(outerProducer.elementRegion, outerRegion) @@ -1367,62 +1622,77 @@ object EmitStream { override val elementRegion: Settable[Region] = resultElementRegion - override val requiresMemoryManagementPerElement: Boolean = innerProducer.requiresMemoryManagementPerElement || outerProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + innerProducer.requiresMemoryManagementPerElement || outerProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => val LnextOuter = CodeLabel() val LnextInner = CodeLabel() - cb.if_(first, { - cb.assign(first, false) - - cb.define(LnextOuter) - cb.define(innerProducer.LendOfStream) - - if (outerProducer.requiresMemoryManagementPerElement) - cb += outerProducer.elementRegion.clearRegion() + cb.if_( + first, { + cb.assign(first, false) + + cb.define(LnextOuter) + cb.define(innerProducer.LendOfStream) + + if (outerProducer.requiresMemoryManagementPerElement) + cb += outerProducer.elementRegion.clearRegion() + + cb.if_( + innerUnclosed, { + cb.assign(innerUnclosed, false) + innerProducer.close(cb) + if (innerProducer.requiresMemoryManagementPerElement) { + cb += innerProducer.elementRegion.invalidate() + } + }, + ) + cb.goto(outerProducer.LproduceElement) + cb.define(outerProducer.LproduceElementDone) - cb.if_(innerUnclosed, { - cb.assign(innerUnclosed, false) - innerProducer.close(cb) - if (innerProducer.requiresMemoryManagementPerElement) { - cb += innerProducer.elementRegion.invalidate() - } - }) - - cb.goto(outerProducer.LproduceElement) - cb.define(outerProducer.LproduceElementDone) - - innerStreamEmitCode.toI(cb).consume(cb, - // missing inner streams mean we should go to the next outer element - cb.goto(LnextOuter), - { - _ => - // the inner stream/producer is bound to a variable above - cb.assign(innerUnclosed, true) - if (innerProducer.requiresMemoryManagementPerElement) - cb.assign(innerProducer.elementRegion, Region.stagedCreate(Region.REGULAR, outerProducer.elementRegion.getPool())) - else - cb.assign(innerProducer.elementRegion, outerProducer.elementRegion) + innerStreamEmitCode.toI(cb).consume( 
+ cb, + // missing inner streams mean we should go to the next outer element + cb.goto(LnextOuter), + { + _ => + // the inner stream/producer is bound to a variable above + cb.assign(innerUnclosed, true) + if (innerProducer.requiresMemoryManagementPerElement) + cb.assign( + innerProducer.elementRegion, + Region.stagedCreate( + Region.REGULAR, + outerProducer.elementRegion.getPool(), + ), + ) + else + cb.assign(innerProducer.elementRegion, outerProducer.elementRegion) - innerProducer.initialize(cb, outerRegion) - cb.goto(LnextInner) - } - ) - }) + innerProducer.initialize(cb, outerRegion) + cb.goto(LnextInner) + }, + ) + }, + ) cb.define(LnextInner) cb.goto(innerProducer.LproduceElement) cb.define(innerProducer.LproduceElementDone) if (requiresMemoryManagementPerElement) { - cb += resultElementRegion.trackAndIncrementReferenceCountOf(innerProducer.elementRegion) + cb += resultElementRegion.trackAndIncrementReferenceCountOf( + innerProducer.elementRegion + ) // if outer requires memory management and inner doesn't, // then innerProducer.elementRegion is outerProducer.elementRegion // and we shouldn't clear it. if (innerProducer.requiresMemoryManagementPerElement) { - cb += resultElementRegion.trackAndIncrementReferenceCountOf(outerProducer.elementRegion) + cb += resultElementRegion.trackAndIncrementReferenceCountOf( + outerProducer.elementRegion + ) cb += innerProducer.elementRegion.clearRegion() } } @@ -1431,13 +1701,15 @@ object EmitStream { val element: EmitCode = innerProducer.element def close(cb: EmitCodeBuilder): Unit = { - cb.if_(innerUnclosed, { - cb.assign(innerUnclosed, false) - if (innerProducer.requiresMemoryManagementPerElement) { - cb += innerProducer.elementRegion.invalidate() - } - innerProducer.close(cb) - }) + cb.if_( + innerUnclosed, { + cb.assign(innerUnclosed, false) + if (innerProducer.requiresMemoryManagementPerElement) { + cb += innerProducer.elementRegion.invalidate() + } + innerProducer.close(cb) + }, + ) outerProducer.close(cb) if (outerProducer.requiresMemoryManagementPerElement) @@ -1445,17 +1717,206 @@ object EmitStream { } } - mb.implementLabel(outerProducer.LendOfStream) { cb => - cb.goto(producer.LendOfStream) - } + mb.implementLabel(outerProducer.LendOfStream)(cb => cb.goto(producer.LendOfStream)) SStreamValue(producer) } - case x@StreamJoinRightDistinct(leftIR, rightIR, lKey, rKey, leftName, rightName, joinIR, joinType) => + case StreamLeftIntervalJoin(left, right, lKeyField, rIntrvlName, lName, rName, body) => + produce(left, cb).flatMap(cb) { case lStream: SStreamValue => + produce(right, cb).map(cb) { case rStream: SStreamValue => + // map over the keyStream + val lProd = lStream.getProducer(mb) + val rProd = rStream.getProducer(mb) + + val rElemSTy = + SBaseStructPointer(rProd.element.st.storageType().asInstanceOf[PBaseStruct]) + + def loadInterval(cb: EmitCodeBuilder, rElem: SValue): SIntervalValue = + rElem.asBaseStruct.loadField(cb, rIntrvlName).getOrAssert(cb).asInterval + + val q: StagedMinHeap = + StagedMinHeap(mb.emodb, rElemSTy) { + (cb: EmitCodeBuilder, a: SValue, b: SValue) => + val l = loadInterval(cb, a) + val r = loadInterval(cb, b) + IntervalFunctions.intervalEndpointCompare( + cb, + l.loadEnd(cb).getOrAssert(cb), + l.includesEnd, + r.loadEnd(cb).getOrAssert(cb), + r.includesEnd, + ) + }(mb.ecb) + + val lElement: SBaseStructSettable = + mb.newPField("LeftElement", lProd.element.st).asInstanceOf[SBaseStructSettable] + + val rElements: SSettable = + mb.newPField("RightElements", q.arraySType) + + var jElement: EmitSettable = + null + + val 
rEOS: ThisFieldRef[Boolean] = + mb.genFieldThisRef[Boolean]("RightEOS") + + val rPulled: ThisFieldRef[Boolean] = + mb.genFieldThisRef[Boolean]("RightPulled") + + SStreamValue { + new StreamProducer { + override def method: EmitMethodBuilder[_] = + mb + + override val length: Option[EmitCodeBuilder => Code[Int]] = + lProd.length + + override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { + cb.assign(rEOS, false) + cb.assign(rPulled, false) + + for (p <- FastSeq(lProd, rProd)) { + p.initialize(cb, outerRegion) + cb.assign( + p.elementRegion, + if (p.requiresMemoryManagementPerElement) + Region.stagedCreate(REGULAR, mb.ecb.pool()) + else outerRegion.get, + ) + } + + q.init(cb, mb.ecb.pool()) + } + + override val elementRegion: Settable[Region] = + mb.genFieldThisRef[Region]("IntervalJoinRegion") + + override val requiresMemoryManagementPerElement: Boolean = + lProd.requiresMemoryManagementPerElement || rProd.requiresMemoryManagementPerElement + + override val LproduceElement: CodeLabel = + mb.defineAndImplementLabel { cb => + if (lProd.requiresMemoryManagementPerElement) { + cb += lProd.elementRegion.clearRegion() + } + + cb.goto(lProd.LproduceElement) + cb.define(lProd.LproduceElementDone) + + cb.assign(lElement, lProd.element.toI(cb).getOrAssert(cb).asBaseStruct) + val point = lElement.loadField(cb, lKeyField).getOrAssert(cb) + + /* Drop rows from the priority queue if their interval's right endpoint is + * before the current key. */ + cb.loop { Lrecur => + cb.if_( + q.nonEmpty(cb), { + val interval = loadInterval(cb, q.peek(cb)) + val end = interval.loadEnd(cb).getOrAssert(cb) + cb.if_( + pointGTIntervalEndpoint(cb, point, end, interval.includesEnd), { + q.pop(cb) + cb.goto(Lrecur) + }, + ) + }, + ) + } + + q.realloc(cb) + + val LallIntervalsFound = CodeLabel() + cb.if_(rEOS, cb.goto(LallIntervalsFound)) + + val LproduceRightElement = CodeLabel() + cb.if_(!rPulled, cb.goto(LproduceRightElement)) + + cb.loop { Lrecur => + val rElement = rElemSTy.coerceOrCopy( + cb, + elementRegion, + rProd.element.toI(cb).getOrAssert(cb), + deepCopy = false, + ) + val interval = loadInterval(cb, rElement) + + // Drop intervals whose right endpoint is before the key + val end = interval.loadEnd(cb).getOrAssert(cb) + cb.if_( + pointGTIntervalEndpoint(cb, point, end, interval.includesEnd), + cb.goto(LproduceRightElement), + ) + + // Stop consuming intervals if the left endpoint is after the key + val start = interval.loadStart(cb).getOrAssert(cb) + cb.if_( + pointLTIntervalEndpoint( + cb, + point, + start, + leansRight = !interval.includesStart, + ), + cb.goto(LallIntervalsFound), + ) + + q.push(cb, rElement) + + cb.define(LproduceRightElement) + if (rProd.requiresMemoryManagementPerElement) { + cb += rProd.elementRegion.clearRegion() + } + + cb.goto(rProd.LproduceElement) + cb.define(rProd.LproduceElementDone) + cb.assign(rPulled, true) + cb.goto(Lrecur) + } + + cb.define(rProd.LendOfStream) + cb.assign(rEOS, true) + + cb.define(LallIntervalsFound) + cb.assign(rElements, q.toArray(cb, elementRegion)) + val result = emit( + body, + cb, + region = elementRegion, + env = env.bind( + lName -> EmitValue.present(lElement), + rName -> EmitValue.present(rElements), + ), + ) + + jElement = mb.newEmitField("IntervalJoinResult", result.emitType) + cb.assign(jElement, result) + cb.goto(LproduceElementDone) + + cb.define(lProd.LendOfStream) + cb.goto(LendOfStream) + } + + override val element: EmitCode = + jElement + + override def close(cb: EmitCodeBuilder): Unit = { + q.close(cb) + for (p <- 
FastSeq(rProd, lProd)) { + p.close(cb) + if (p.requiresMemoryManagementPerElement) { + cb += p.elementRegion.invalidate() + } + } + } + } + } + } + } + + case x @ StreamJoinRightDistinct(leftIR, rightIR, lKey, rKey, leftName, rightName, joinIR, + joinType) => produce(leftIR, cb).flatMap(cb) { case leftStream: SStreamValue => produce(rightIR, cb).map(cb) { case rightStream: SStreamValue => - val leftProducer = leftStream.getProducer(mb) val rightProducer = rightStream.getProducer(mb) @@ -1468,30 +1929,38 @@ object EmitStream { if (x.isIntervalJoin) { val rhs = relt.toI(cb).flatMap(cb)(_.asBaseStruct.loadField(cb, rKey(0))) val result = cb.newLocal[Int]("SJRD-interval-compare-result") - rhs.consume(cb, { - cb.assign(result, -1) - }, { case interval: SIntervalValue => - val lhs = lelt.toI(cb).flatMap(cb)(_.asBaseStruct.loadField(cb, lKey(0))) - lhs.consume(cb, { - cb.assign(result, 1) - }, { point => - val c = IntervalFunctions.pointIntervalCompare(cb, point, interval) - c.consume(cb, { - // One of the interval endpoints is missing. In this case, - // consider the point greater, so that the join advances - // past the bad interval, keeping the point. - cb.assign(result, 1) - }, { c => - cb.assign(result, c.asInt.value) - }) - }) - }) + rhs.consume( + cb, + cb.assign(result, -1), + { case interval: SIntervalValue => + val lhs = lelt.toI(cb).flatMap(cb)(_.asBaseStruct.loadField(cb, lKey(0))) + lhs.consume( + cb, + cb.assign(result, 1), + { point => + val c = IntervalFunctions.pointIntervalCompare(cb, point, interval) + c.consume( + cb, + // One of the interval endpoints is missing. In this case, + // consider the point greater, so that the join advances + // past the bad interval, keeping the point. + cb.assign(result, 1), + c => cb.assign(result, c.asInt.value), + ) + }, + ) + }, + ) result } else { val lhs = lelt.map(cb)(_.asBaseStruct.subset(lKey: _*)) val rhs = relt.map(cb)(_.asBaseStruct.subset(rKey: _*)) - StructOrdering.make(lhs.st.asInstanceOf[SBaseStruct], rhs.st.asInstanceOf[SBaseStruct], - cb.emb.ecb, missingFieldsEqual = false) + StructOrdering.make( + lhs.st.asInstanceOf[SBaseStruct], + rhs.st.asInstanceOf[SBaseStruct], + cb.emb.ecb, + missingFieldsEqual = false, + ) .compare(cb, lhs, rhs, missingEqual = false) } } @@ -1502,27 +1971,41 @@ object EmitStream { val lxOut: EmitSettable = joinType match { case "inner" | "left" => lx - case "outer" | "right" => mb.newEmitField("streamjoin_lxout", lx.emitType.copy(required = false)) + case "outer" | "right" => + mb.newEmitField("streamjoin_lxout", lx.emitType.copy(required = false)) } val rxOut: EmitSettable = joinType match { case "inner" | "right" => rx - case "outer" | "left" => mb.newEmitField("streamjoin_rxout", rx.emitType.copy(required = false)) + case "outer" | "left" => + mb.newEmitField("streamjoin_rxout", rx.emitType.copy(required = false)) } val _elementRegion = mb.genFieldThisRef[Region]("join_right_distinct_element_region") - val _requiresMemoryManagementPerElement = leftProducer.requiresMemoryManagementPerElement || rightProducer.requiresMemoryManagementPerElement - - val joinResult = EmitCode.fromI(mb)(cb => emit(joinIR, cb, - region = _elementRegion, - env = env.bind(leftName -> lxOut, rightName -> rxOut))) + val _requiresMemoryManagementPerElement = + leftProducer.requiresMemoryManagementPerElement || rightProducer.requiresMemoryManagementPerElement + + val joinResult = EmitCode.fromI(mb)(cb => + emit( + joinIR, + cb, + region = _elementRegion, + env = env.bind(leftName -> lxOut, rightName -> rxOut), + ) + ) def 
sharedInit(cb: EmitCodeBuilder): Unit = { if (rightProducer.requiresMemoryManagementPerElement) - cb.assign(rightProducer.elementRegion, Region.stagedCreate(Region.REGULAR, outerRegion.getPool())) + cb.assign( + rightProducer.elementRegion, + Region.stagedCreate(Region.REGULAR, outerRegion.getPool()), + ) else cb.assign(rightProducer.elementRegion, outerRegion) if (leftProducer.requiresMemoryManagementPerElement) - cb.assign(leftProducer.elementRegion, Region.stagedCreate(Region.REGULAR, outerRegion.getPool())) + cb.assign( + leftProducer.elementRegion, + Region.stagedCreate(Region.REGULAR, outerRegion.getPool()), + ) else cb.assign(leftProducer.elementRegion, outerRegion) @@ -1543,7 +2026,8 @@ object EmitStream { joinType match { case "left" => val rightEOS = mb.genFieldThisRef[Boolean]("left_join_right_distinct_rightEOS") - val pulledRight = mb.genFieldThisRef[Boolean]("left_join_right_distinct_pulledRight]") + val pulledRight = + mb.genFieldThisRef[Boolean]("left_join_right_distinct_pulledRight]") val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb @@ -1557,9 +2041,9 @@ object EmitStream { } override val elementRegion: Settable[Region] = _elementRegion - override val requiresMemoryManagementPerElement: Boolean = _requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + _requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - if (leftProducer.requiresMemoryManagementPerElement) cb += leftProducer.elementRegion.clearRegion() cb.goto(leftProducer.LproduceElement) @@ -1572,31 +2056,38 @@ object EmitStream { val Lpush = CodeLabel() val LpullRight = CodeLabel() - cb.if_(!pulledRight, { - cb.assign(pulledRight, true) - cb.goto(LpullRight) - }) + cb.if_( + !pulledRight, { + cb.assign(pulledRight, true) + cb.goto(LpullRight) + }, + ) val Lcompare = CodeLabel() cb.define(Lcompare) val c = cb.newLocal[Int]("left_join_right_distinct_c", compare(cb, lx, rx)) cb.if_(c > 0, cb.goto(LpullRight)) - cb.if_(c < 0, { - cb.assign(rxOut, EmitCode.missing(mb, rxOut.st)) - }, { - // c == 0 - if (rightProducer.requiresMemoryManagementPerElement) { - cb += elementRegion.trackAndIncrementReferenceCountOf(rightProducer.elementRegion) - } - cb.assign(rxOut, rx) - }) + cb.if_( + c < 0, + cb.assign(rxOut, EmitCode.missing(mb, rxOut.st)), { + // c == 0 + if (rightProducer.requiresMemoryManagementPerElement) { + cb += elementRegion.trackAndIncrementReferenceCountOf( + rightProducer.elementRegion + ) + } + cb.assign(rxOut, rx) + }, + ) cb.goto(Lpush) mb.implementLabel(Lpush) { cb => if (leftProducer.requiresMemoryManagementPerElement) - cb += elementRegion.trackAndIncrementReferenceCountOf(leftProducer.elementRegion) + cb += elementRegion.trackAndIncrementReferenceCountOf( + leftProducer.elementRegion + ) cb.goto(LproduceElementDone) } @@ -1618,21 +2109,22 @@ object EmitStream { cb.goto(Lpush) } - mb.implementLabel(leftProducer.LendOfStream) { cb => cb.goto(LendOfStream) } + mb.implementLabel(leftProducer.LendOfStream)(cb => cb.goto(LendOfStream)) } override val element: EmitCode = joinResult - override def close(cb: EmitCodeBuilder): Unit = { + override def close(cb: EmitCodeBuilder): Unit = sharedClose(cb) - } } SStreamValue(producer) case "right" => val leftEOS = mb.genFieldThisRef[Boolean]("left_join_right_distinct_leftEOS") - val pulledRight = mb.genFieldThisRef[Boolean]("left_join_right_distinct_pulledRight]") - val pushedRight = 
mb.genFieldThisRef[Boolean]("left_join_right_distinct_pulledRight]") + val pulledRight = + mb.genFieldThisRef[Boolean]("left_join_right_distinct_pulledRight]") + val pushedRight = + mb.genFieldThisRef[Boolean]("left_join_right_distinct_pulledRight]") val c = mb.genFieldThisRef[Int]("join_right_distinct_compResult") val producer = new StreamProducer { @@ -1648,7 +2140,8 @@ object EmitStream { } override val elementRegion: Settable[Region] = _elementRegion - override val requiresMemoryManagementPerElement: Boolean = _requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + _requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => val Lpush = CodeLabel() val LpullRight = CodeLabel() @@ -1660,33 +2153,35 @@ object EmitStream { cb.if_(c <= 0, cb.goto(LpullLeft), cb.goto(LpullRight)) cb.define(Lcompare) - cb.if_(leftEOS, { - cb.if_(pushedRight, - cb.goto(LpullRight), - cb.goto(Lpush)) - }) + cb.if_(leftEOS, cb.if_(pushedRight, cb.goto(LpullRight), cb.goto(Lpush))) cb.assign(c, compare(cb, lx, rx)) cb.if_(c < 0, cb.goto(LpullLeft)) - cb.if_(c > 0, { - cb.if_(pushedRight, cb.goto(LpullRight)) - cb.assign(lxOut, EmitCode.missing(mb, lxOut.st)) - }, { - // c == 0 - if (leftProducer.requiresMemoryManagementPerElement) - cb += elementRegion.trackAndIncrementReferenceCountOf(leftProducer.elementRegion) - cb.assign(lxOut, lx) - }) + cb.if_( + c > 0, { + cb.if_(pushedRight, cb.goto(LpullRight)) + cb.assign(lxOut, EmitCode.missing(mb, lxOut.st)) + }, { + // c == 0 + if (leftProducer.requiresMemoryManagementPerElement) + cb += elementRegion.trackAndIncrementReferenceCountOf( + leftProducer.elementRegion + ) + cb.assign(lxOut, lx) + }, + ) cb.goto(Lpush) mb.implementLabel(LmaybePullRight) { cb => - cb.if_(!pulledRight, { - cb.assign(pulledRight, true) - cb.goto(LpullRight) - }, - cb.goto(Lcompare)) + cb.if_( + !pulledRight, { + cb.assign(pulledRight, true) + cb.goto(LpullRight) + }, + cb.goto(Lcompare), + ) } mb.implementLabel(LpullLeft) { cb => @@ -1710,7 +2205,9 @@ object EmitStream { mb.implementLabel(Lpush) { cb => if (rightProducer.requiresMemoryManagementPerElement) - cb += elementRegion.trackAndIncrementReferenceCountOf(rightProducer.elementRegion) + cb += elementRegion.trackAndIncrementReferenceCountOf( + rightProducer.elementRegion + ) cb.assign(pushedRight, true) cb.goto(LproduceElementDone) @@ -1724,19 +2221,19 @@ object EmitStream { } // end if right stream ends - mb.implementLabel(rightProducer.LendOfStream) { cb => cb.goto(LendOfStream) } + mb.implementLabel(rightProducer.LendOfStream)(cb => cb.goto(LendOfStream)) } override val element: EmitCode = joinResult - override def close(cb: EmitCodeBuilder): Unit = { + override def close(cb: EmitCodeBuilder): Unit = sharedClose(cb) - } } SStreamValue(producer) case "inner" => - val pulledRight = mb.genFieldThisRef[Boolean]("left_join_right_distinct_pulledRight]") + val pulledRight = + mb.genFieldThisRef[Boolean]("left_join_right_distinct_pulledRight]") val producer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb @@ -1748,9 +2245,9 @@ object EmitStream { } override val elementRegion: Settable[Region] = _elementRegion - override val requiresMemoryManagementPerElement: Boolean = _requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + _requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - if 
(leftProducer.requiresMemoryManagementPerElement) cb += leftProducer.elementRegion.clearRegion() cb.goto(leftProducer.LproduceElement) @@ -1758,21 +2255,25 @@ object EmitStream { cb.assign(lx, leftProducer.element) val LpullRight = CodeLabel() - cb.if_(!pulledRight, { - cb.assign(pulledRight, true) - cb.goto(LpullRight) - }) + cb.if_( + !pulledRight, { + cb.assign(pulledRight, true) + cb.goto(LpullRight) + }, + ) val Lcompare = CodeLabel() cb.define(Lcompare) val c = cb.newLocal[Int]("left_join_right_distinct_c", compare(cb, lx, rx)) cb.if_(c > 0, cb.goto(LpullRight)) - cb.if_(c < 0, { - if (leftProducer.requiresMemoryManagementPerElement) - cb += leftProducer.elementRegion.clearRegion() - cb.goto(leftProducer.LproduceElement) - }) + cb.if_( + c < 0, { + if (leftProducer.requiresMemoryManagementPerElement) + cb += leftProducer.elementRegion.clearRegion() + cb.goto(leftProducer.LproduceElement) + }, + ) cb.goto(LproduceElementDone) @@ -1788,14 +2289,13 @@ object EmitStream { } // Both producer EOS labels should jump directly to EOS - mb.implementLabel(rightProducer.LendOfStream) { cb => cb.goto(LendOfStream) } - mb.implementLabel(leftProducer.LendOfStream) { cb => cb.goto(LendOfStream) } + mb.implementLabel(rightProducer.LendOfStream)(cb => cb.goto(LendOfStream)) + mb.implementLabel(leftProducer.LendOfStream)(cb => cb.goto(LendOfStream)) } override val element: EmitCode = joinResult - override def close(cb: EmitCodeBuilder): Unit = { + override def close(cb: EmitCodeBuilder): Unit = sharedClose(cb) - } } SStreamValue(producer) @@ -1823,20 +2323,22 @@ object EmitStream { } override val elementRegion: Settable[Region] = _elementRegion - override val requiresMemoryManagementPerElement: Boolean = _requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + _requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - val LpullRight = CodeLabel() val LpullLeft = CodeLabel() val Lpush = CodeLabel() - cb.if_(leftEOS, + cb.if_( + leftEOS, cb.goto(LpullRight), - cb.if_(rightEOS, + cb.if_( + rightEOS, cb.goto(LpullLeft), - cb.if_(c <= 0, - cb.goto(LpullLeft), - cb.goto(LpullRight)))) + cb.if_(c <= 0, cb.goto(LpullLeft), cb.goto(LpullRight)), + ), + ) cb.define(LpullRight) if (rightProducer.requiresMemoryManagementPerElement) @@ -1854,55 +2356,65 @@ object EmitStream { cb.assign(c, compare(cb, lx, rx)) cb.assign(lOutMissing, false) cb.assign(rOutMissing, false) - cb.if_(c > 0, - { - cb.if_(pulledRight && !pushedRight, { - cb.assign(lOutMissing, true) - if (rightProducer.requiresMemoryManagementPerElement) { - cb += elementRegion.trackAndIncrementReferenceCountOf(rightProducer.elementRegion) - } - cb.goto(Lpush) - }, - cb.goto(LpullRight) + cb.if_( + c > 0, { + cb.if_( + pulledRight && !pushedRight, { + cb.assign(lOutMissing, true) + if (rightProducer.requiresMemoryManagementPerElement) { + cb += elementRegion.trackAndIncrementReferenceCountOf( + rightProducer.elementRegion + ) + } + cb.goto(Lpush) + }, + cb.goto(LpullRight), ) - }, - { - cb.if_(c < 0, - { + }, { + cb.if_( + c < 0, { cb.assign(rOutMissing, true) if (leftProducer.requiresMemoryManagementPerElement) { - cb += elementRegion.trackAndIncrementReferenceCountOf(leftProducer.elementRegion) + cb += elementRegion.trackAndIncrementReferenceCountOf( + leftProducer.elementRegion + ) } cb.goto(Lpush) - }, - { + }, { // c == 0 if (leftProducer.requiresMemoryManagementPerElement) { - cb += 
elementRegion.trackAndIncrementReferenceCountOf(leftProducer.elementRegion) + cb += elementRegion.trackAndIncrementReferenceCountOf( + leftProducer.elementRegion + ) } if (rightProducer.requiresMemoryManagementPerElement) { - cb += elementRegion.trackAndIncrementReferenceCountOf(rightProducer.elementRegion) + cb += elementRegion.trackAndIncrementReferenceCountOf( + rightProducer.elementRegion + ) } cb.goto(Lpush) - }) - }) + }, + ) + }, + ) } mb.implementLabel(Lpush) { cb => - cb.if_(lOutMissing, + cb.if_( + lOutMissing, cb.assign(lxOut, EmitCode.missing(mb, lxOut.st)), - cb.assign(lxOut, lx) + cb.assign(lxOut, lx), ) - cb.if_(rOutMissing, - cb.assign(rxOut, EmitCode.missing(mb, rxOut.st)), - { + cb.if_( + rOutMissing, + cb.assign(rxOut, EmitCode.missing(mb, rxOut.st)), { cb.assign(rxOut, rx) cb.assign(pushedRight, true) - }) + }, + ) cb.goto(LproduceElementDone) } - mb.implementLabel(rightProducer.LproduceElementDone) { cb => cb.assign(rx, rightProducer.element) cb.assign(pushedRight, false) @@ -1911,43 +2423,47 @@ object EmitStream { mb.implementLabel(leftProducer.LproduceElementDone) { cb => cb.assign(lx, leftProducer.element) - cb.if_(pulledRight, - cb.if_(rightEOS, - { + cb.if_( + pulledRight, + cb.if_( + rightEOS, { if (leftProducer.requiresMemoryManagementPerElement) { - cb += elementRegion.trackAndIncrementReferenceCountOf(leftProducer.elementRegion) + cb += elementRegion.trackAndIncrementReferenceCountOf( + leftProducer.elementRegion + ) } cb.goto(Lpush) }, - { - cb.goto(Lcompare) - } + cb.goto(Lcompare), ), - cb.goto(LpullRight) + cb.goto(LpullRight), ) } mb.implementLabel(leftProducer.LendOfStream) { cb => - cb.if_(rightEOS, - cb.goto(LendOfStream), - { + cb.if_( + rightEOS, + cb.goto(LendOfStream), { cb.assign(leftEOS, true) cb.assign(lOutMissing, true) cb.assign(rOutMissing, false) - cb.if_(pulledRight && !pushedRight, - { + cb.if_( + pulledRight && !pushedRight, { if (rightProducer.requiresMemoryManagementPerElement) { - cb += elementRegion.trackAndIncrementReferenceCountOf(rightProducer.elementRegion) + cb += elementRegion.trackAndIncrementReferenceCountOf( + rightProducer.elementRegion + ) } cb.goto(Lpush) - }, - { + }, { if (rightProducer.requiresMemoryManagementPerElement) { cb += rightProducer.elementRegion.clearRegion() } cb.goto(LpullRight) - }) - }) + }, + ) + }, + ) } mb.implementLabel(rightProducer.LendOfStream) { cb => @@ -1957,7 +2473,9 @@ object EmitStream { cb.assign(rOutMissing, true) if (leftProducer.requiresMemoryManagementPerElement) { - cb += elementRegion.trackAndIncrementReferenceCountOf(leftProducer.elementRegion) + cb += elementRegion.trackAndIncrementReferenceCountOf( + leftProducer.elementRegion + ) cb += leftProducer.elementRegion.clearRegion() } cb.goto(Lpush) @@ -1965,9 +2483,8 @@ object EmitStream { } override val element: EmitCode = joinResult - override def close(cb: EmitCodeBuilder): Unit = { + override def close(cb: EmitCodeBuilder): Unit = sharedClose(cb) - } } SStreamValue(producer) @@ -1977,7 +2494,6 @@ object EmitStream { case StreamGroupByKey(a, key, missingEqual) => produce(a, cb).map(cb) { case childStream: SStreamValue => - val childProducer = childStream.getProducer(mb) val xCurElt = mb.newPField("st_grpby_curelt", childProducer.element.st) @@ -1995,16 +2511,24 @@ object EmitStream { val inOuter = mb.genFieldThisRef[Boolean]("streamgroupbykey_inouter") val first = mb.genFieldThisRef[Boolean]("streamgroupbykey_first") - // cannot reuse childProducer.elementRegion because consumers might free the region, even though - // the outer 
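// --- Editor's sketch (illustrative, not part of this diff) ---
// The label-based code above ("left", "inner", "outer" cases of join-right-distinct)
// is a merge join between a sorted left stream and a sorted right stream whose keys
// are distinct. Below is a minimal collections-level analogue of the "left" and
// "outer" semantics (the "inner" case simply keeps only the matching pairs),
// assuming Int keys and already-sorted inputs. Every name is illustrative; the
// generated code is single-pass and region-managed, not list-based.
object JoinRightDistinctSketch {
  // Left join: each left element is emitted once, with the matching right element
  // if one exists (right keys are assumed distinct, so a Map lookup is lossless).
  def leftJoin[A, B](left: List[(Int, A)], right: List[(Int, B)]): List[((Int, A), Option[B])] = {
    val rightByKey = right.toMap
    left.map { case (k, a) => ((k, a), rightByKey.get(k)) }
  }

  // Outer join: unmatched elements from either side are emitted with the other side
  // missing, mirroring the lOutMissing / rOutMissing flags in the emitted code.
  def outerJoin[A, B](left: List[(Int, A)], right: List[(Int, B)])
    : List[(Option[(Int, A)], Option[(Int, B)])] = {
    val leftKeys = left.map(_._1).toSet
    val matchedOrLeftOnly = left.map { case (k, a) => (Some((k, a)), right.find(_._1 == k)) }
    val rightOnly = right.filterNot { case (k, _) => leftKeys.contains(k) }.map(r => (None, Some(r)))
    matchedOrLeftOnly ++ rightOnly
  }

  def main(args: Array[String]): Unit = {
    val l = List(1 -> "a", 2 -> "b", 2 -> "b2", 4 -> "d")
    val r = List(2 -> "x", 3 -> "y")
    println(leftJoin(l, r))  // key 2 matches twice on the left; keys 1 and 4 have no match
    println(outerJoin(l, r)) // key 3 appears with a missing left side
  }
}
// --- end of editor's sketch ---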
producer needs to continue pulling. We could add more control flow that sets some - // boolean flag when the inner stream is closed, and the outer producer reassigns a region if + /* cannot reuse childProducer.elementRegion because consumers might free the region, even + * though */ + /* the outer producer needs to continue pulling. We could add more control flow that sets + * some */ + /* boolean flag when the inner stream is closed, and the outer producer reassigns a region + * if */ // that flag is set, but that design seems more complicated val innerResultRegion = mb.genFieldThisRef[Region]("streamgroupbykey_inner_result_region") val outerElementRegion = mb.genFieldThisRef[Region]("streamgroupbykey_outer_elt_region") def equiv(cb: EmitCodeBuilder, l: SBaseStructValue, r: SBaseStructValue): Value[Boolean] = - StructOrdering.make(l.st, r.st, cb.emb.ecb, missingFieldsEqual = missingEqual).equivNonnull(cb, l, r) + StructOrdering.make( + l.st, + r.st, + cb.emb.ecb, + missingFieldsEqual = missingEqual, + ).equivNonnull(cb, l, r) val LchildProduceDoneInner = CodeLabel() val LchildProduceDoneOuter = CodeLabel() @@ -2015,15 +2539,19 @@ object EmitStream { override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = {} override val elementRegion: Settable[Region] = innerResultRegion - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => val LelementReady = CodeLabel() - // the first pull from the inner stream has the next record ready to go from the outer stream - cb.if_(inOuter, { - cb.assign(inOuter, false) - cb.goto(LelementReady) - }) + /* the first pull from the inner stream has the next record ready to go from the outer + * stream */ + cb.if_( + inOuter, { + cb.assign(inOuter, false) + cb.goto(LelementReady) + }, + ) if (childProducer.requiresMemoryManagementPerElement) cb += childProducer.elementRegion.clearRegion() @@ -2032,20 +2560,24 @@ object EmitStream { cb.define(LchildProduceDoneInner) // if not equivalent, end inner stream and prepare for next outer iteration - cb.if_(!equiv(cb, curKey.asBaseStruct, lastKey.asBaseStruct), { - if (requiresMemoryManagementPerElement) - cb += keyRegion.clearRegion() + cb.if_( + !equiv(cb, curKey.asBaseStruct, lastKey.asBaseStruct), { + if (requiresMemoryManagementPerElement) + cb += keyRegion.clearRegion() - cb.assign(lastKey, subsetCode.castTo(cb, keyRegion, lastKey.st, deepCopy = true)) - cb.assign(nextGroupReady, true) - cb.assign(inOuter, true) - cb.goto(LendOfStream) - }) + cb.assign(lastKey, subsetCode.castTo(cb, keyRegion, lastKey.st, deepCopy = true)) + cb.assign(nextGroupReady, true) + cb.assign(inOuter, true) + cb.goto(LendOfStream) + }, + ) cb.define(LelementReady) if (requiresMemoryManagementPerElement) { - cb += innerResultRegion.trackAndIncrementReferenceCountOf(childProducer.elementRegion) + cb += innerResultRegion.trackAndIncrementReferenceCountOf( + childProducer.elementRegion + ) } cb.goto(LproduceElementDone) @@ -2067,7 +2599,10 @@ object EmitStream { if (childProducer.requiresMemoryManagementPerElement) { cb.assign(keyRegion, Region.stagedCreate(Region.REGULAR, outerRegion.getPool())) - cb.assign(childProducer.elementRegion, Region.stagedCreate(Region.REGULAR, outerRegion.getPool())) + cb.assign( + childProducer.elementRegion, + Region.stagedCreate(Region.REGULAR, 
outerRegion.getPool()), + ) } else { cb.assign(keyRegion, outerRegion) cb.assign(childProducer.elementRegion, outerRegion) @@ -2077,11 +2612,10 @@ object EmitStream { } override val elementRegion: Settable[Region] = outerElementRegion - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - cb.if_(eos, { - cb.goto(LendOfStream) - }) + cb.if_(eos, cb.goto(LendOfStream)) val LinnerStreamReady = CodeLabel() @@ -2097,17 +2631,21 @@ object EmitStream { val LdifferentKey = CodeLabel() - cb.if_(first, { - cb.assign(first, false) - cb.goto(LdifferentKey) - }) + cb.if_( + first, { + cb.assign(first, false) + cb.goto(LdifferentKey) + }, + ) // if equiv, go to next element. Otherwise, fall through to next group - cb.if_(equiv(cb, curKey.asBaseStruct, lastKey.asBaseStruct), { - if (childProducer.requiresMemoryManagementPerElement) - cb += childProducer.elementRegion.clearRegion() - cb.goto(childProducer.LproduceElement) - }) + cb.if_( + equiv(cb, curKey.asBaseStruct, lastKey.asBaseStruct), { + if (childProducer.requiresMemoryManagementPerElement) + cb += childProducer.elementRegion.clearRegion() + cb.goto(childProducer.LproduceElement) + }, + ) cb.define(LdifferentKey) if (requiresMemoryManagementPerElement) @@ -2133,14 +2671,15 @@ object EmitStream { mb.implementLabel(childProducer.LendOfStream) { cb => cb.assign(eos, true) - cb.if_(inOuter, + cb.if_( + inOuter, cb.goto(outerProducer.LendOfStream), - cb.goto(innerProducer.LendOfStream) + cb.goto(innerProducer.LendOfStream), ) } mb.implementLabel(childProducer.LproduceElementDone) { cb => - cb.assign(xCurElt, childProducer.element.toI(cb).get(cb)) + cb.assign(xCurElt, childProducer.element.toI(cb).getOrAssert(cb)) cb.assign(curKey, subsetCode) cb.if_(inOuter, cb.goto(LchildProduceDoneOuter), cb.goto(LchildProduceDoneInner)) } @@ -2150,9 +2689,7 @@ object EmitStream { case StreamGrouped(a, groupSize) => produce(a, cb).flatMap(cb) { case childStream: SStreamValue => - emit(groupSize, cb).map(cb) { case groupSize: SInt32Value => - val n = mb.genFieldThisRef[Int]("streamgrouped_n") val childProducer = childStream.getProducer(mb) @@ -2163,9 +2700,12 @@ object EmitStream { val outerElementRegion = mb.genFieldThisRef[Region]("streamgrouped_outer_elt_region") - // cannot reuse childProducer.elementRegion because consumers might free the region, even though - // the outer producer needs to continue pulling. We could add more control flow that sets some - // boolean flag when the inner stream is closed, and the outer producer reassigns a region if + /* cannot reuse childProducer.elementRegion because consumers might free the region, + * even though */ + /* the outer producer needs to continue pulling. 
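// --- Editor's sketch (illustrative, not part of this diff) ---
// StreamGroupByKey above emits a stream of streams: one inner stream per maximal
// run of consecutive elements whose key fields compare equal (the equiv check on
// curKey vs lastKey). A plain-Scala analogue over an iterator, assuming the input
// is already clustered by key; all names below are illustrative only.
object GroupByKeySketch {
  import scala.collection.mutable.ListBuffer

  def groupConsecutive[A, K](it: Iterator[A])(key: A => K): Iterator[List[A]] =
    new Iterator[List[A]] {
      private val in = it.buffered
      def hasNext: Boolean = in.hasNext
      def next(): List[A] = {
        val k = key(in.head)
        val group = ListBuffer.empty[A]
        while (in.hasNext && key(in.head) == k) group += in.next()
        group.toList
      }
    }

  def main(args: Array[String]): Unit = {
    val rows = Iterator(("a", 1), ("a", 2), ("b", 3), ("c", 4), ("c", 5))
    groupConsecutive(rows)(_._1).foreach(println)
    // prints: List((a,1), (a,2)) then List((b,3)) then List((c,4), (c,5))
  }
}
// --- end of editor's sketch ---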
We could add more control flow that + * sets some */ + /* boolean flag when the inner stream is closed, and the outer producer reassigns a + * region if */ // that flag is set, but that design seems more complicated val innerResultRegion = mb.genFieldThisRef[Region]("streamgrouped_inner_result_region") @@ -2178,23 +2718,33 @@ object EmitStream { override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = {} override val elementRegion: Settable[Region] = innerResultRegion - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - cb.if_(inOuter, { - cb.assign(inOuter, false) - cb.if_(xCounter.cne(1), cb._fatal(s"streamgrouped inner producer error, xCounter=", xCounter.toS)) - cb.goto(LchildProduceDoneInner) - }) - cb.if_(xCounter >= n, { - cb.assign(inOuter, true) - cb.goto(LendOfStream) - }) + cb.if_( + inOuter, { + cb.assign(inOuter, false) + cb.if_( + xCounter.cne(1), + cb._fatal(s"streamgrouped inner producer error, xCounter=", xCounter.toS), + ) + cb.goto(LchildProduceDoneInner) + }, + ) + cb.if_( + xCounter >= n, { + cb.assign(inOuter, true) + cb.goto(LendOfStream) + }, + ) cb.goto(childProducer.LproduceElement) cb.define(LchildProduceDoneInner) if (childProducer.requiresMemoryManagementPerElement) { - cb += innerResultRegion.trackAndIncrementReferenceCountOf(childProducer.elementRegion) + cb += innerResultRegion.trackAndIncrementReferenceCountOf( + childProducer.elementRegion + ) cb += childProducer.elementRegion.clearRegion() } @@ -2209,7 +2759,9 @@ object EmitStream { val outerProducer = new StreamProducer { override def method: EmitMethodBuilder[_] = mb override val length: Option[EmitCodeBuilder => Code[Int]] = - childProducer.length.map(compL => (cb: EmitCodeBuilder) => ((compL(cb).toL + n.toL - 1L) / n.toL).toI) + childProducer.length.map(compL => + (cb: EmitCodeBuilder) => ((compL(cb).toL + n.toL - 1L) / n.toL).toI + ) override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { cb.assign(n, groupSize.value) @@ -2218,7 +2770,10 @@ object EmitStream { cb.assign(xCounter, n) if (childProducer.requiresMemoryManagementPerElement) { - cb.assign(childProducer.elementRegion, Region.stagedCreate(Region.REGULAR, outerRegion.getPool())) + cb.assign( + childProducer.elementRegion, + Region.stagedCreate(Region.REGULAR, outerRegion.getPool()), + ) } else { cb.assign(childProducer.elementRegion, outerRegion) } @@ -2227,22 +2782,21 @@ object EmitStream { } override val elementRegion: Settable[Region] = outerElementRegion - override val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - cb.if_(eos, { - cb.goto(LendOfStream) - }) + cb.if_(eos, cb.goto(LendOfStream)) cb.assign(inOuter, true) cb.define(LchildProduceDoneOuter) - - cb.if_(xCounter <= n, - { + cb.if_( + xCounter <= n, { if (childProducer.requiresMemoryManagementPerElement) cb += childProducer.elementRegion.clearRegion() cb.goto(childProducer.LproduceElement) - }) + }, + ) cb.assign(xCounter, 1) cb.goto(LproduceElementDone) } @@ -2257,9 +2811,10 @@ object EmitStream { mb.implementLabel(childProducer.LendOfStream) { cb => 
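// --- Editor's sketch (illustrative, not part of this diff) ---
// StreamGrouped above chunks a stream into groups of at most `groupSize` elements,
// and the outer producer's length is the ceiling division (len + n - 1) / n that
// appears in childProducer.length.map(...). A tiny analogue with plain iterators;
// names are illustrative only.
object StreamGroupedSketch {
  def ceilDiv(len: Long, n: Long): Long = (len + n - 1L) / n

  def main(args: Array[String]): Unit = {
    val n = 3
    val groups = (1 to 10).iterator.grouped(n).toList
    println(groups)                 // 4 groups: (1,2,3), (4,5,6), (7,8,9), (10)
    println(ceilDiv(10L, n.toLong)) // 4, matching groups.length
  }
}
// --- end of editor's sketch ---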
cb.assign(eos, true) - cb.if_(inOuter, + cb.if_( + inOuter, cb.goto(outerProducer.LendOfStream), - cb.goto(innerProducer.LendOfStream) + cb.goto(innerProducer.LendOfStream), ) } @@ -2273,503 +2828,557 @@ object EmitStream { } case StreamZip(as, names, body, behavior, errorID) => - IEmitCode.multiMapEmitCodes(cb, as.map(a => EmitCode.fromI(cb.emb)(cb => produce(a, cb)))) { childStreams => - - val producers = childStreams.map(_.asStream.getProducer(mb)) + IEmitCode.multiMapEmitCodes(cb, as.map(a => EmitCode.fromI(cb.emb)(cb => produce(a, cb)))) { + childStreams => + val producers = childStreams.map(_.asStream.getProducer(mb)) - assert(names.length == producers.length) + assert(names.length == producers.length) - val producer: StreamProducer = behavior match { - case behavior@(ArrayZipBehavior.TakeMinLength | ArrayZipBehavior.AssumeSameLength) => - val vars = names.zip(producers).map { case (name, p) => mb.newEmitField(name, p.element.emitType) } + val producer: StreamProducer = behavior match { + case behavior @ (ArrayZipBehavior.TakeMinLength | ArrayZipBehavior.AssumeSameLength) => + val vars = names.zip(producers).map { case (name, p) => + mb.newEmitField(name, p.element.emitType) + } - val eltRegion = mb.genFieldThisRef[Region]("streamzip_eltregion") - val bodyCode = EmitCode.fromI(mb)(cb => emit(body, cb, region = eltRegion, env = env.bind(names.zip(vars): _*))) + val eltRegion = mb.genFieldThisRef[Region]("streamzip_eltregion") + val bodyCode = EmitCode.fromI(mb)(cb => + emit(body, cb, region = eltRegion, env = env.bind(names.zip(vars): _*)) + ) - new StreamProducer { - override def method: EmitMethodBuilder[_] = mb - override val length: Option[EmitCodeBuilder => Code[Int]] = { - behavior match { - case ArrayZipBehavior.AssumeSameLength => - producers.flatMap(_.length).headOption - case ArrayZipBehavior.TakeMinLength => - anyFailAllFail((producers, as).zipped.flatMap { (producer, child) => - child match { - case _: StreamIota => None - case _ => Some(producer.length) - } - }).map { compLens => - (cb: EmitCodeBuilder) => { + new StreamProducer { + override def method: EmitMethodBuilder[_] = mb + override val length: Option[EmitCodeBuilder => Code[Int]] = { + behavior match { + case ArrayZipBehavior.AssumeSameLength => + producers.flatMap(_.length).headOption + case ArrayZipBehavior.TakeMinLength => + anyFailAllFail((producers, as).zipped.flatMap { (producer, child) => + child match { + case _: StreamIota => None + case _ => Some(producer.length) + } + }).map { compLens => (cb: EmitCodeBuilder) => compLens.map(_.apply(cb)).reduce(_.min(_)) } - } + } } - } - override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { - producers.foreach { p => - if (p.requiresMemoryManagementPerElement) - cb.assign(p.elementRegion, eltRegion) - else - cb.assign(p.elementRegion, outerRegion) - p.initialize(cb, outerRegion) + override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { + producers.foreach { p => + if (p.requiresMemoryManagementPerElement) + cb.assign(p.elementRegion, eltRegion) + else + cb.assign(p.elementRegion, outerRegion) + p.initialize(cb, outerRegion) + } } - } - override val elementRegion: Settable[Region] = eltRegion + override val elementRegion: Settable[Region] = eltRegion - override val requiresMemoryManagementPerElement: Boolean = producers.exists(_.requiresMemoryManagementPerElement) + override val requiresMemoryManagementPerElement: Boolean = + producers.exists(_.requiresMemoryManagementPerElement) - override val LproduceElement: 
CodeLabel = mb.defineAndImplementLabel { cb => - - producers.zipWithIndex.foreach { case (p, i) => - cb.goto(p.LproduceElement) - cb.define(p.LproduceElementDone) - cb.assign(vars(i), p.element) - } + override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => + producers.zipWithIndex.foreach { case (p, i) => + cb.goto(p.LproduceElement) + cb.define(p.LproduceElementDone) + cb.assign(vars(i), p.element) + } - cb.goto(LproduceElementDone) + cb.goto(LproduceElementDone) - // all producer EOS jumps should immediately jump to zipped EOS - producers.foreach { p => - cb.define(p.LendOfStream) - cb.goto(LendOfStream) + // all producer EOS jumps should immediately jump to zipped EOS + producers.foreach { p => + cb.define(p.LendOfStream) + cb.goto(LendOfStream) + } } - } - val element: EmitCode = bodyCode + val element: EmitCode = bodyCode - def close(cb: EmitCodeBuilder): Unit = { - producers.foreach(_.close(cb)) + def close(cb: EmitCodeBuilder): Unit = + producers.foreach(_.close(cb)) } - } - case ArrayZipBehavior.AssertSameLength => - - val vars = names.zip(producers).map { case (name, p) => mb.newEmitField(name, p.element.emitType) } - - val eltRegion = mb.genFieldThisRef[Region]("streamzip_eltregion") - val bodyCode = EmitCode.fromI(mb)(cb => emit(body, cb, region = eltRegion, env = env.bind(names.zip(vars): _*))) + case ArrayZipBehavior.AssertSameLength => + val vars = names.zip(producers).map { case (name, p) => + mb.newEmitField(name, p.element.emitType) + } - val anyEOS = mb.genFieldThisRef[Boolean]("zip_any_eos") - val allEOS = mb.genFieldThisRef[Boolean]("zip_all_eos") + val eltRegion = mb.genFieldThisRef[Region]("streamzip_eltregion") + val bodyCode = EmitCode.fromI(mb)(cb => + emit(body, cb, region = eltRegion, env = env.bind(names.zip(vars): _*)) + ) + val anyEOS = mb.genFieldThisRef[Boolean]("zip_any_eos") + val allEOS = mb.genFieldThisRef[Boolean]("zip_all_eos") - new StreamProducer { - override def method: EmitMethodBuilder[_] = mb - override val length: Option[EmitCodeBuilder => Code[Int]] = producers.flatMap(_.length) match { - case Seq() => None - case ls => - val len = mb.genFieldThisRef[Int]("zip_asl_len") - val lenTemp = mb.genFieldThisRef[Int]("zip_asl_len_temp") - Some({cb: EmitCodeBuilder => - val len = cb.newLocal[Int]("zip_len", ls.head(cb)) - ls.tail.foreach { compL => - val lenTemp = cb.newLocal[Int]("lenTemp", compL(cb)) - cb.if_(len.cne(lenTemp), cb._fatalWithError(errorID, "zip: length mismatch: ", len.toS, ", ", lenTemp.toS)) - } - len - }) - } + new StreamProducer { + override def method: EmitMethodBuilder[_] = mb + override val length: Option[EmitCodeBuilder => Code[Int]] = + producers.flatMap(_.length) match { + case Seq() => None + case ls => + Some({ cb: EmitCodeBuilder => + val len = cb.newLocal[Int]("zip_len", ls.head(cb)) + ls.tail.foreach { compL => + val lenTemp = cb.newLocal[Int]("lenTemp", compL(cb)) + cb.if_( + len.cne(lenTemp), + cb._fatalWithError( + errorID, + "zip: length mismatch: ", + len.toS, + ", ", + lenTemp.toS, + ), + ) + } + len + }) + } - override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { - cb.assign(anyEOS, false) + override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { + cb.assign(anyEOS, false) - producers.foreach { p => - if (p.requiresMemoryManagementPerElement) - cb.assign(p.elementRegion, eltRegion) - else - cb.assign(p.elementRegion, outerRegion) - p.initialize(cb, outerRegion) + producers.foreach { p => + if (p.requiresMemoryManagementPerElement) + 
cb.assign(p.elementRegion, eltRegion) + else + cb.assign(p.elementRegion, outerRegion) + p.initialize(cb, outerRegion) + } } - } - override val elementRegion: Settable[Region] = eltRegion + override val elementRegion: Settable[Region] = eltRegion - override val requiresMemoryManagementPerElement: Boolean = producers.exists(_.requiresMemoryManagementPerElement) + override val requiresMemoryManagementPerElement: Boolean = + producers.exists(_.requiresMemoryManagementPerElement) - override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - cb.assign(allEOS, true) + override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => + cb.assign(allEOS, true) - producers.zipWithIndex.foreach { case (p, i) => + producers.zipWithIndex.foreach { case (p, i) => + val fallThrough = CodeLabel() - val fallThrough = CodeLabel() + cb.goto(p.LproduceElement) - cb.goto(p.LproduceElement) + cb.define(p.LendOfStream) + cb.assign(anyEOS, true) + cb.goto(fallThrough) - cb.define(p.LendOfStream) - cb.assign(anyEOS, true) - cb.goto(fallThrough) + cb.define(p.LproduceElementDone) + cb.assign(vars(i), p.element) + cb.assign(allEOS, false) - cb.define(p.LproduceElementDone) - cb.assign(vars(i), p.element) - cb.assign(allEOS, false) + cb.define(fallThrough) + } - cb.define(fallThrough) + cb.if_( + anyEOS, + cb.if_( + allEOS, + cb.goto(LendOfStream), + cb._fatalWithError(errorID, "zip: length mismatch"), + ), + ) + + cb.goto(LproduceElementDone) } - cb.if_(anyEOS, - cb.if_(allEOS, - cb.goto(LendOfStream), - cb._fatalWithError(errorID, "zip: length mismatch")) - ) + val element: EmitCode = bodyCode - cb.goto(LproduceElementDone) + def close(cb: EmitCodeBuilder): Unit = + producers.foreach(_.close(cb)) } - val element: EmitCode = bodyCode - - def close(cb: EmitCodeBuilder): Unit = { - producers.foreach(_.close(cb)) + case ArrayZipBehavior.ExtendNA => + val vars = names.zip(producers).map { case (name, p) => + mb.newEmitField(name, p.element.emitType.copy(required = false)) } - } - - case ArrayZipBehavior.ExtendNA => - val vars = names.zip(producers).map { case (name, p) => mb.newEmitField(name, p.element.emitType.copy(required = false)) } - - val eltRegion = mb.genFieldThisRef[Region]("streamzip_eltregion") - val bodyCode = EmitCode.fromI(mb)(cb => emit(body, cb, region = eltRegion, env = env.bind(names.zip(vars): _*))) + val eltRegion = mb.genFieldThisRef[Region]("streamzip_eltregion") + val bodyCode = EmitCode.fromI(mb)(cb => + emit(body, cb, region = eltRegion, env = env.bind(names.zip(vars): _*)) + ) - val eosPerStream = producers.indices.map(i => mb.genFieldThisRef[Boolean](s"zip_eos_$i")) - val nEOS = mb.genFieldThisRef[Int]("zip_n_eos") + val eosPerStream = + producers.indices.map(i => mb.genFieldThisRef[Boolean](s"zip_eos_$i")) + val nEOS = mb.genFieldThisRef[Int]("zip_n_eos") - new StreamProducer { - override def method: EmitMethodBuilder[_] = mb - override val length: Option[EmitCodeBuilder => Code[Int]] = - anyFailAllFail(producers.map(_.length)) - .map { compLens => - (cb: EmitCodeBuilder) => { + new StreamProducer { + override def method: EmitMethodBuilder[_] = mb + override val length: Option[EmitCodeBuilder => Code[Int]] = + anyFailAllFail(producers.map(_.length)) + .map { compLens => (cb: EmitCodeBuilder) => compLens.map(_.apply(cb)).reduce(_.max(_)) } - } - override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { - producers.foreach { p => - if (p.requiresMemoryManagementPerElement) - cb.assign(p.elementRegion, eltRegion) - else - 
cb.assign(p.elementRegion, outerRegion) - p.initialize(cb, outerRegion) - } + override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { + producers.foreach { p => + if (p.requiresMemoryManagementPerElement) + cb.assign(p.elementRegion, eltRegion) + else + cb.assign(p.elementRegion, outerRegion) + p.initialize(cb, outerRegion) + } - eosPerStream.foreach { eos => - cb.assign(eos, false) + eosPerStream.foreach(eos => cb.assign(eos, false)) + cb.assign(nEOS, 0) } - cb.assign(nEOS, 0) - } - override val elementRegion: Settable[Region] = eltRegion + override val elementRegion: Settable[Region] = eltRegion - override val requiresMemoryManagementPerElement: Boolean = producers.exists(_.requiresMemoryManagementPerElement) + override val requiresMemoryManagementPerElement: Boolean = + producers.exists(_.requiresMemoryManagementPerElement) - override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => + override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => + producers.zipWithIndex.foreach { case (p, i) => + // label at the end of processing this element + val endProduce = CodeLabel() - producers.zipWithIndex.foreach { case (p, i) => + cb.if_(eosPerStream(i), cb.goto(endProduce)) - // label at the end of processing this element - val endProduce = CodeLabel() + cb.goto(p.LproduceElement) - cb.if_(eosPerStream(i), cb.goto(endProduce)) + /* after an EOS we set the EOS boolean for that stream, and check if all + * streams have ended */ + cb.define(p.LendOfStream) + cb.assign(nEOS, nEOS + 1) - cb.goto(p.LproduceElement) + cb.if_(nEOS.ceq(const(producers.length)), cb.goto(LendOfStream)) - // after an EOS we set the EOS boolean for that stream, and check if all streams have ended - cb.define(p.LendOfStream) - cb.assign(nEOS, nEOS + 1) + /* this stream has ended before each other, so we set the eos flag and the + * element EmitSettable */ + cb.assign(eosPerStream(i), true) + cb.assign(vars(i), EmitCode.missing(mb, vars(i).st)) - cb.if_(nEOS.ceq(const(producers.length)), cb.goto(LendOfStream)) + cb.goto(endProduce) - // this stream has ended before each other, so we set the eos flag and the element EmitSettable - cb.assign(eosPerStream(i), true) - cb.assign(vars(i), EmitCode.missing(mb, vars(i).st)) + cb.define(p.LproduceElementDone) + cb.assign(vars(i), p.element) - cb.goto(endProduce) - - cb.define(p.LproduceElementDone) - cb.assign(vars(i), p.element) + cb.define(endProduce) + } - cb.define(endProduce) + cb.goto(LproduceElementDone) } - cb.goto(LproduceElementDone) - } - - val element: EmitCode = bodyCode + val element: EmitCode = bodyCode - def close(cb: EmitCodeBuilder): Unit = { - producers.foreach(_.close(cb)) + def close(cb: EmitCodeBuilder): Unit = + producers.foreach(_.close(cb)) } - } - } + } - SStreamValue(producer) + SStreamValue(producer) } - case x@StreamZipJoin(as, key, keyRef, valsRef, joinIR) => - IEmitCode.multiMapEmitCodes(cb, as.map(a => EmitCode.fromI(cb.emb)(cb => emit(a, cb)))) { children => - val producers = children.map(_.asStream.getProducer(mb)) - - val eltType = VirtualTypeWithReq.union(as.map(a => typeWithReqx(a))).canonicalEmitType - .st - .asInstanceOf[SStream] - .elementType - .storageType() - .setRequired(false) - .asInstanceOf[PCanonicalStruct] - - val keyType = eltType.selectFields(key) - - val curValsType = PCanonicalArray(eltType) - - val _elementRegion = mb.genFieldThisRef[Region]("szj_region") - val regionArray = mb.genFieldThisRef[Array[Region]]("szj_region_array") - - val staticMemManagementArray = 
producers.map(_.requiresMemoryManagementPerElement).toArray - val allMatch = staticMemManagementArray.toSet.size == 1 - val memoryManagementBooleansArray = if (allMatch) null else mb.genFieldThisRef[Array[Int]]("smm_separate_region_array") - - def initMemoryManagementPerElementArray(cb: EmitCodeBuilder): Unit = { - if (!allMatch) - cb.assign(memoryManagementBooleansArray, mb.getObject[Array[Int]](producers.map(_.requiresMemoryManagementPerElement.toInt).toArray)) - } - - def lookupMemoryManagementByIndex(cb: EmitCodeBuilder, idx: Code[Int]): Code[Boolean] = { - if (allMatch) - const(staticMemManagementArray.head) - else - memoryManagementBooleansArray.apply(idx).toZ - } - - // The algorithm maintains a tournament tree of comparisons between the - // current values of the k streams. The tournament tree is a complete - // binary tree with k leaves. The leaves of the tree are the streams, - // and each internal node represents the "contest" between the "winners" - // of the two subtrees, where the winner is the stream with the smaller - // current key. Each internal node stores the index of the stream which - // *lost* that contest. - // Each time we remove the overall winner, and replace that stream's - // leaf with its next value, we only need to rerun the contests on the - // path from that leaf to the root, comparing the new value with what - // previously lost that contest to the previous overall winner. - - val k = producers.length - // The leaf nodes of the tournament tree, each of which holds a pointer - // to the current value of that stream. - val heads = mb.genFieldThisRef[Array[Long]]("merge_heads") - // The internal nodes of the tournament tree, laid out in breadth-first - // order, each of which holds the index of the stream which lost that - // contest. - val bracket = mb.genFieldThisRef[Array[Int]]("merge_bracket") - // When updating the tournament tree, holds the winner of the subtree - // containing the updated leaf. Otherwise, holds the overall winner, i.e. - // the current least element. 
- val winner = mb.genFieldThisRef[Int]("merge_winner") - val result = mb.genFieldThisRef[Array[Long]]("merge_result") - val i = mb.genFieldThisRef[Int]("merge_i") - - val curKey = mb.newPField("st_grpby_curkey", keyType.sType) - - val xKey = mb.newEmitField("zipjoin_key", keyType.sType, required = true) - val xElts = mb.newEmitField("zipjoin_elts", curValsType.sType, required = true) - - val joinResult: EmitCode = EmitCode.fromI(mb) { cb => - val newEnv = env.bind((keyRef -> xKey), (valsRef -> xElts)) - emit(joinIR, cb, env = newEnv) - } - - val producer = new StreamProducer { - override def method: EmitMethodBuilder[_] = mb - override val length: Option[EmitCodeBuilder => Code[Int]] = None + case StreamZipJoin(as, key, keyRef, valsRef, joinIR) => + IEmitCode.multiMapEmitCodes(cb, as.map(a => EmitCode.fromI(cb.emb)(cb => emit(a, cb)))) { + children => + val producers = children.map(_.asStream.getProducer(mb)) + + val eltType = VirtualTypeWithReq.union(as.map(a => typeWithReqx(a))).canonicalEmitType + .st + .asInstanceOf[SStream] + .elementType + .storageType() + .setRequired(false) + .asInstanceOf[PCanonicalStruct] + + val keyType = eltType.selectFields(key) + + val curValsType = PCanonicalArray(eltType) + + val _elementRegion = mb.genFieldThisRef[Region]("szj_region") + val regionArray = mb.genFieldThisRef[Array[Region]]("szj_region_array") + + val staticMemManagementArray = + producers.map(_.requiresMemoryManagementPerElement).toArray + val allMatch = staticMemManagementArray.toSet.size == 1 + val memoryManagementBooleansArray = + if (allMatch) null else mb.genFieldThisRef[Array[Int]]("smm_separate_region_array") + + def initMemoryManagementPerElementArray(cb: EmitCodeBuilder): Unit = + if (!allMatch) + cb.assign( + memoryManagementBooleansArray, + mb.getObject[Array[Int]]( + producers.map(_.requiresMemoryManagementPerElement.toInt).toArray + ), + ) - override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { - cb.assign(regionArray, Code.newArray[Region](k)) - producers.zipWithIndex.foreach { case (p, idx) => - if (p.requiresMemoryManagementPerElement) { - cb.assign(p.elementRegion, Region.stagedCreate(Region.REGULAR, outerRegion.getPool())) - } else - cb.assign(p.elementRegion, outerRegion) - cb += (regionArray(idx) = p.elementRegion) - p.initialize(cb, outerRegion) - } - initMemoryManagementPerElementArray(cb) - cb.assign(bracket, Code.newArray[Int](k)) - cb.assign(heads, Code.newArray[Long](k)) - cb.for_(cb.assign(i, 0), i < k, cb.assign(i, i + 1), { - cb += (bracket(i) = -1) - }) - cb.assign(result, Code._null) - cb.assign(i, 0) - cb.assign(winner, 0) + def lookupMemoryManagementByIndex(cb: EmitCodeBuilder, idx: Code[Int]): Code[Boolean] = + if (allMatch) + const(staticMemManagementArray.head) + else + memoryManagementBooleansArray.apply(idx).toZ + + // The algorithm maintains a tournament tree of comparisons between the + // current values of the k streams. The tournament tree is a complete + // binary tree with k leaves. The leaves of the tree are the streams, + // and each internal node represents the "contest" between the "winners" + // of the two subtrees, where the winner is the stream with the smaller + // current key. Each internal node stores the index of the stream which + // *lost* that contest. 
+ // Each time we remove the overall winner, and replace that stream's + // leaf with its next value, we only need to rerun the contests on the + // path from that leaf to the root, comparing the new value with what + // previously lost that contest to the previous overall winner. + + val k = producers.length + // The leaf nodes of the tournament tree, each of which holds a pointer + // to the current value of that stream. + val heads = mb.genFieldThisRef[Array[Long]]("merge_heads") + // The internal nodes of the tournament tree, laid out in breadth-first + // order, each of which holds the index of the stream which lost that + // contest. + val bracket = mb.genFieldThisRef[Array[Int]]("merge_bracket") + // When updating the tournament tree, holds the winner of the subtree + // containing the updated leaf. Otherwise, holds the overall winner, i.e. + // the current least element. + val winner = mb.genFieldThisRef[Int]("merge_winner") + val result = mb.genFieldThisRef[Array[Long]]("merge_result") + val i = mb.genFieldThisRef[Int]("merge_i") + + val curKey = mb.newPField("st_grpby_curkey", keyType.sType) + + val xKey = mb.newEmitField("zipjoin_key", keyType.sType, required = true) + val xElts = mb.newEmitField("zipjoin_elts", curValsType.sType, required = true) + + val joinResult: EmitCode = EmitCode.fromI(mb) { cb => + val newEnv = env.bind((keyRef -> xKey), (valsRef -> xElts)) + emit(joinIR, cb, env = newEnv) } - override val elementRegion: Settable[Region] = _elementRegion - override val requiresMemoryManagementPerElement: Boolean = producers.exists(_.requiresMemoryManagementPerElement) - override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - val LrunMatch = CodeLabel() - val LpullChild = CodeLabel() - val LloopEnd = CodeLabel() - val LaddToResult = CodeLabel() - val LstartNewKey = CodeLabel() - val Lpush = CodeLabel() - - def inSetup: Code[Boolean] = result.isNull + val producer = new StreamProducer { + override def method: EmitMethodBuilder[_] = mb + override val length: Option[EmitCodeBuilder => Code[Int]] = None - cb.if_(inSetup, { + override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { + cb.assign(regionArray, Code.newArray[Region](k)) + producers.zipWithIndex.foreach { case (p, idx) => + if (p.requiresMemoryManagementPerElement) { + cb.assign( + p.elementRegion, + Region.stagedCreate(Region.REGULAR, outerRegion.getPool()), + ) + } else + cb.assign(p.elementRegion, outerRegion) + cb += (regionArray(idx) = p.elementRegion) + p.initialize(cb, outerRegion) + } + initMemoryManagementPerElementArray(cb) + cb.assign(bracket, Code.newArray[Int](k)) + cb.assign(heads, Code.newArray[Long](k)) + cb.for_(cb.assign(i, 0), i < k, cb.assign(i, i + 1), cb += (bracket(i) = -1)) + cb.assign(result, Code._null) cb.assign(i, 0) - cb.goto(LpullChild) - }, { - cb.if_(winner.ceq(k), cb.goto(LendOfStream), cb.goto(LstartNewKey)) - }) - - cb.define(Lpush) - cb.assign(xKey, EmitCode.present(cb.emb, curKey)) - cb.assign(xElts, EmitCode.present(cb.emb, curValsType.constructFromElements(cb, elementRegion, k, false) { (cb, i) => - IEmitCode(cb, result(i).ceq(0L), eltType.loadCheapSCode(cb, result(i))) - })) - cb.goto(LproduceElementDone) - - cb.define(LstartNewKey) - cb.for_(cb.assign(i, 0), i < k, cb.assign(i, i + 1), { - cb += (result(i) = 0L) - }) - cb.assign(curKey, eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*) - .castTo(cb, elementRegion, curKey.st, true)) - cb.goto(LaddToResult) - - cb.define(LaddToResult) - cb += (result(winner) = 
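// --- Editor's sketch (illustrative, not part of this diff) ---
// The bracket/heads/winner fields above implement a tournament ("loser") tree, so
// replacing the winning stream's head costs O(log k) comparisons. This sketch gets
// the same output order from k sorted iterators with a priority queue instead of a
// loser tree: a deliberate simplification, not the emitted algorithm. Ties are
// broken by stream index so the merge stays stable; all names are illustrative.
object KWayMergeSketch {
  import scala.collection.mutable.PriorityQueue

  def merge[A](streams: IndexedSeq[Iterator[A]])(implicit ord: Ordering[A]): Iterator[A] =
    new Iterator[A] {
      // entries are (head value, stream index); smallest value wins, lower index wins
      // ties; PriorityQueue is a max-heap, so the ordering is reversed
      private implicit val entryOrd: Ordering[(A, Int)] =
        Ordering.Tuple2(ord, Ordering.Int).reverse
      private val pq = PriorityQueue.empty[(A, Int)]
      streams.zipWithIndex.foreach { case (it, i) => if (it.hasNext) pq.enqueue((it.next(), i)) }

      def hasNext: Boolean = pq.nonEmpty
      def next(): A = {
        val (v, i) = pq.dequeue()
        if (streams(i).hasNext) pq.enqueue((streams(i).next(), i)) // refill from the winner
        v
      }
    }

  def main(args: Array[String]): Unit = {
    val merged = merge(IndexedSeq(Iterator(1, 4, 7), Iterator(2, 4, 8), Iterator(3, 6, 9)))
    println(merged.toList) // List(1, 2, 3, 4, 4, 6, 7, 8, 9)
  }
}
// --- end of editor's sketch ---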
heads(winner)) - cb.if_(lookupMemoryManagementByIndex(cb, winner), { - val r = cb.newLocal[Region]("tzj_winner_region", regionArray(winner)) - cb += elementRegion.trackAndIncrementReferenceCountOf(r) - cb += r.clearRegion() - }) - cb.goto(LpullChild) - - val matchIdx = mb.genFieldThisRef[Int]("merge_match_idx") - val challenger = mb.genFieldThisRef[Int]("merge_challenger") - // Compare 'winner' with value in 'matchIdx', loser goes in 'matchIdx', - // winner goes on to next round. A contestant '-1' beats everything - // (negative infinity), a contestant 'k' loses to everything - // (positive infinity), and values in between are indices into 'heads'. - - cb.define(LrunMatch) - cb.assign(challenger, bracket(matchIdx)) - cb.if_(matchIdx.ceq(0) || challenger.ceq(-1), cb.goto(LloopEnd)) - - val LafterChallenge = CodeLabel() - - cb.if_(challenger.cne(k), { - val LchallengerWins = CodeLabel() - - cb.if_(winner.ceq(k), cb.goto(LchallengerWins)) - - val left = eltType.loadCheapSCode(cb, heads(challenger)).subset(key: _*) - val right = eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*) - val ord = StructOrdering.make(left.st, right.st, cb.emb.ecb, missingFieldsEqual = false) - cb.if_(ord.lteqNonnull(cb, left, right), - cb.goto(LchallengerWins), - cb.goto(LafterChallenge)) + cb.assign(winner, 0) + } - cb.define(LchallengerWins) - cb += (bracket(matchIdx) = winner) - cb.assign(winner, challenger) - }) - cb.define(LafterChallenge) - cb.assign(matchIdx, matchIdx >>> 1) - cb.goto(LrunMatch) + override val elementRegion: Settable[Region] = _elementRegion + override val requiresMemoryManagementPerElement: Boolean = + producers.exists(_.requiresMemoryManagementPerElement) + override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => + val LrunMatch = CodeLabel() + val LpullChild = CodeLabel() + val LloopEnd = CodeLabel() + val LaddToResult = CodeLabel() + val LstartNewKey = CodeLabel() + val Lpush = CodeLabel() + + def inSetup: Code[Boolean] = result.isNull + + cb.if_( + inSetup, { + cb.assign(i, 0) + cb.goto(LpullChild) + }, + cb.if_(winner.ceq(k), cb.goto(LendOfStream), cb.goto(LstartNewKey)), + ) - cb.define(LloopEnd) - cb.if_(matchIdx.ceq(0), { - // 'winner' is smallest of all k heads. If 'winner' = k, all heads - // must be k, and all streams are exhausted. 
+ cb.define(Lpush) + cb.assign(xKey, EmitCode.present(cb.emb, curKey)) + cb.assign( + xElts, + EmitCode.present( + cb.emb, + curValsType.constructFromElements(cb, elementRegion, k, false) { (cb, i) => + IEmitCode(cb, result(i).ceq(0L), eltType.loadCheapSCode(cb, result(i))) + }, + ), + ) + cb.goto(LproduceElementDone) - cb.if_(inSetup, { - cb.if_(winner.ceq(k), - cb.goto(LendOfStream), - { - cb.assign(result, Code.newArray[Long](k)) - cb.goto(LstartNewKey) - }) - }, { - cb.if_(!winner.cne(k), cb.goto(Lpush)) - val left = eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*) - val right = curKey - val ord = StructOrdering.make(left.st, right.st.asInstanceOf[SBaseStruct], - cb.emb.ecb, missingFieldsEqual = false) - cb.if_(ord.equivNonnull(cb, left, right), cb.goto(LaddToResult), cb.goto(Lpush)) - }) - }, { - // We're still in the setup phase - cb += (bracket(matchIdx) = winner) - cb.assign(i, i + 1) - cb.assign(winner, i) + cb.define(LstartNewKey) + cb.for_(cb.assign(i, 0), i < k, cb.assign(i, i + 1), cb += (result(i) = 0L)) + cb.assign( + curKey, + eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*) + .castTo(cb, elementRegion, curKey.st, true), + ) + cb.goto(LaddToResult) + + cb.define(LaddToResult) + cb += (result(winner) = heads(winner)) + cb.if_( + lookupMemoryManagementByIndex(cb, winner), { + val r = cb.newLocal[Region]("tzj_winner_region", regionArray(winner)) + cb += elementRegion.trackAndIncrementReferenceCountOf(r) + cb += r.clearRegion() + }, + ) cb.goto(LpullChild) - }) - producers.zipWithIndex.foreach { case (p, idx) => - cb.define(p.LendOfStream) - cb.assign(winner, k) - cb.assign(matchIdx, (idx + k) >>> 1) + val matchIdx = mb.genFieldThisRef[Int]("merge_match_idx") + val challenger = mb.genFieldThisRef[Int]("merge_challenger") + // Compare 'winner' with value in 'matchIdx', loser goes in 'matchIdx', + // winner goes on to next round. A contestant '-1' beats everything + // (negative infinity), a contestant 'k' loses to everything + // (positive infinity), and values in between are indices into 'heads'. + + cb.define(LrunMatch) + cb.assign(challenger, bracket(matchIdx)) + cb.if_(matchIdx.ceq(0) || challenger.ceq(-1), cb.goto(LloopEnd)) + + val LafterChallenge = CodeLabel() + + cb.if_( + challenger.cne(k), { + val LchallengerWins = CodeLabel() + + cb.if_(winner.ceq(k), cb.goto(LchallengerWins)) + + val left = eltType.loadCheapSCode(cb, heads(challenger)).subset(key: _*) + val right = eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*) + val ord = + StructOrdering.make(left.st, right.st, cb.emb.ecb, missingFieldsEqual = false) + cb.if_( + ord.lteqNonnull(cb, left, right), + cb.goto(LchallengerWins), + cb.goto(LafterChallenge), + ) + + cb.define(LchallengerWins) + cb += (bracket(matchIdx) = winner) + cb.assign(winner, challenger) + }, + ) + cb.define(LafterChallenge) + cb.assign(matchIdx, matchIdx >>> 1) cb.goto(LrunMatch) - cb.define(p.LproduceElementDone) - val storedElt = eltType.store(cb, p.elementRegion, p.element.toI(cb).get(cb), false) - cb += (heads(idx) = storedElt) - cb.assign(matchIdx, (idx + k) >>> 1) - cb.goto(LrunMatch) - } + cb.define(LloopEnd) + cb.if_( + matchIdx.ceq(0), { + // 'winner' is smallest of all k heads. If 'winner' = k, all heads + // must be k, and all streams are exhausted. 
+ + cb.if_( + inSetup, { + cb.if_( + winner.ceq(k), + cb.goto(LendOfStream), { + cb.assign(result, Code.newArray[Long](k)) + cb.goto(LstartNewKey) + }, + ) + }, { + cb.if_(!winner.cne(k), cb.goto(Lpush)) + val left = eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*) + val right = curKey + val ord = StructOrdering.make( + left.st, + right.st.asInstanceOf[SBaseStruct], + cb.emb.ecb, + missingFieldsEqual = false, + ) + cb.if_( + ord.equivNonnull(cb, left, right), + cb.goto(LaddToResult), + cb.goto(Lpush), + ) + }, + ) + }, { + // We're still in the setup phase + cb += (bracket(matchIdx) = winner) + cb.assign(i, i + 1) + cb.assign(winner, i) + cb.goto(LpullChild) + }, + ) - cb.define(LpullChild) - cb.switch( - winner, - cb.goto(LendOfStream), // can only happen if k=0 - producers.map { p => - () => cb.goto(p.LproduceElement) + producers.zipWithIndex.foreach { case (p, idx) => + cb.define(p.LendOfStream) + cb.assign(winner, k) + cb.assign(matchIdx, (idx + k) >>> 1) + cb.goto(LrunMatch) + + cb.define(p.LproduceElementDone) + val storedElt = + eltType.store(cb, p.elementRegion, p.element.toI(cb).getOrAssert(cb), false) + cb += (heads(idx) = storedElt) + cb.assign(matchIdx, (idx + k) >>> 1) + cb.goto(LrunMatch) } - ) - } - override val element: EmitCode = joinResult + cb.define(LpullChild) + cb.switch( + winner, + cb.goto(LendOfStream), // can only happen if k=0 + producers.map(p => () => cb.goto(p.LproduceElement)), + ) + } - override def close(cb: EmitCodeBuilder): Unit = { - producers.foreach { p => - if (p.requiresMemoryManagementPerElement) - cb += p.elementRegion.invalidate() - p.close(cb) + override val element: EmitCode = joinResult + + override def close(cb: EmitCodeBuilder): Unit = { + producers.foreach { p => + if (p.requiresMemoryManagementPerElement) + cb += p.elementRegion.invalidate() + p.close(cb) + } + cb.assign(bracket, Code._null) + cb.assign(heads, Code._null) + cb.assign(result, Code._null) } - cb.assign(bracket, Code._null) - cb.assign(heads, Code._null) - cb.assign(result, Code._null) } - } - SStreamValue(producer) + SStreamValue(producer) } - case x@StreamZipJoinProducers(contexts, ctxName, makeProducer, key, keyRef, valsRef, joinIR) => + case StreamZipJoinProducers(contexts, ctxName, makeProducer, key, keyRef, valsRef, + joinIR) => emit(contexts, cb).map(cb) { case contextsArray: SIndexableValue => val nStreams = cb.memoizeField(contextsArray.loadLength()) val iterArray = cb.memoizeField(Code.newArray[NoBoxLongIterator](nStreams), "iterArray") val idx = cb.newLocal[Int]("i", 0) - val eltType = VirtualTypeWithReq(TIterable.elementType(makeProducer.typ), - emitter.ctx.req.lookup(makeProducer).asInstanceOf[RIterable].elementType).canonicalPType + val eltType = VirtualTypeWithReq( + TIterable.elementType(makeProducer.typ), + emitter.ctx.req.lookup(makeProducer).asInstanceOf[RIterable].elementType, + ).canonicalPType .asInstanceOf[PCanonicalStruct] .setRequired(false) var streamRequiresMemoryManagement = false - cb.while_(idx < nStreams, { - val iter = produceIterator(makeProducer, - eltType, - cb, - outerRegion, - env.bind(ctxName, cb.memoize(contextsArray.loadElement(cb, idx)))) - .get(cb, "streams in zipJoinProducers cannot be missing") - .asInstanceOf[SStreamConcrete] - streamRequiresMemoryManagement = iter.st.requiresMemoryManagement - cb += iterArray.update(idx, iter.it) - cb.assign(idx, idx + 1) - }) + cb.while_( + idx < nStreams, { + val iter = produceIterator( + makeProducer, + eltType, + cb, + env.bind(ctxName, cb.memoize(contextsArray.loadElement(cb, idx))), 
+ ) + .getOrFatal(cb, "streams in zipJoinProducers cannot be missing") + .asInstanceOf[SStreamConcrete] + streamRequiresMemoryManagement = iter.st.requiresMemoryManagement + cb += iterArray.update(idx, iter.it) + cb.assign(idx, idx + 1) + }, + ) val keyType = eltType.selectFields(key) @@ -2829,22 +3438,27 @@ object EmitStream { cb.assign(regionArray, Code.newArray[Region](nStreams)) cb.assign(bracket, Code.newArray[Int](k)) cb.assign(heads, Code.newArray[Long](k)) - cb.for_(cb.assign(i, 0), i < k, cb.assign(i, i + 1), { - cb.updateArray(bracket, i, -1) - val eltRegion: Value[Region] = if (streamRequiresMemoryManagement) { - val r = cb.memoize(Region.stagedCreate(Region.REGULAR, outerRegion.getPool())) - cb.updateArray(regionArray, i, r) - r - } else outerRegion - cb += iterArray(i).invoke[Region, Region, Unit]("init", outerRegion, eltRegion) - }) + cb.for_( + cb.assign(i, 0), + i < k, + cb.assign(i, i + 1), { + cb.updateArray(bracket, i, -1) + val eltRegion: Value[Region] = if (streamRequiresMemoryManagement) { + val r = cb.memoize(Region.stagedCreate(Region.REGULAR, outerRegion.getPool())) + cb.updateArray(regionArray, i, r) + r + } else outerRegion + cb += iterArray(i).invoke[Region, Region, Unit]("init", outerRegion, eltRegion) + }, + ) cb.assign(result, Code._null) cb.assign(i, 0) cb.assign(winner, 0) } override val elementRegion: Settable[Region] = _elementRegion - override val requiresMemoryManagementPerElement: Boolean = streamRequiresMemoryManagement + override val requiresMemoryManagementPerElement: Boolean = + streamRequiresMemoryManagement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => val LrunMatch = CodeLabel() val LpullChild = CodeLabel() @@ -2855,26 +3469,34 @@ object EmitStream { def inSetup: Code[Boolean] = result.isNull - cb.if_(inSetup, { - cb.assign(i, 0) - cb.goto(LpullChild) - }, { - cb.if_(winner.ceq(k), cb.goto(LendOfStream), cb.goto(LstartNewKey)) - }) + cb.if_( + inSetup, { + cb.assign(i, 0) + cb.goto(LpullChild) + }, + cb.if_(winner.ceq(k), cb.goto(LendOfStream), cb.goto(LstartNewKey)), + ) cb.define(Lpush) cb.assign(xKey, EmitCode.present(cb.emb, curKey)) - cb.assign(xElts, EmitCode.present(cb.emb, curValsType.constructFromElements(cb, elementRegion, k, false) { (cb, i) => - IEmitCode(cb, result(i).ceq(0L), eltType.loadCheapSCode(cb, result(i))) - })) + cb.assign( + xElts, + EmitCode.present( + cb.emb, + curValsType.constructFromElements(cb, elementRegion, k, false) { (cb, i) => + IEmitCode(cb, result(i).ceq(0L), eltType.loadCheapSCode(cb, result(i))) + }, + ), + ) cb.goto(LproduceElementDone) cb.define(LstartNewKey) - cb.for_(cb.assign(i, 0), i < k, cb.assign(i, i + 1), { - cb.updateArray(result, i, 0L) - }) - cb.assign(curKey, eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*) - .castTo(cb, elementRegion, curKey.st, true)) + cb.for_(cb.assign(i, 0), i < k, cb.assign(i, i + 1), cb.updateArray(result, i, 0L)) + cb.assign( + curKey, + eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*) + .castTo(cb, elementRegion, curKey.st, true), + ) cb.goto(LaddToResult) cb.define(LaddToResult) @@ -2899,65 +3521,85 @@ object EmitStream { val LafterChallenge = CodeLabel() - cb.if_(challenger.cne(k), { - val LchallengerWins = CodeLabel() + cb.if_( + challenger.cne(k), { + val LchallengerWins = CodeLabel() - cb.if_(winner.ceq(k), cb.goto(LchallengerWins)) + cb.if_(winner.ceq(k), cb.goto(LchallengerWins)) - val left = eltType.loadCheapSCode(cb, heads(challenger)).subset(key: _*) - val right = eltType.loadCheapSCode(cb, 
heads(winner)).subset(key: _*) - val ord = StructOrdering.make(left.st, right.st, cb.emb.ecb, missingFieldsEqual = false) - cb.if_(ord.lteqNonnull(cb, left, right), - cb.goto(LchallengerWins), - cb.goto(LafterChallenge)) + val left = eltType.loadCheapSCode(cb, heads(challenger)).subset(key: _*) + val right = eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*) + val ord = + StructOrdering.make(left.st, right.st, cb.emb.ecb, missingFieldsEqual = false) + cb.if_( + ord.lteqNonnull(cb, left, right), + cb.goto(LchallengerWins), + cb.goto(LafterChallenge), + ) - cb.define(LchallengerWins) - cb.updateArray(bracket, matchIdx, winner) - cb.assign(winner, challenger) - }) + cb.define(LchallengerWins) + cb.updateArray(bracket, matchIdx, winner) + cb.assign(winner, challenger) + }, + ) cb.define(LafterChallenge) cb.assign(matchIdx, matchIdx >>> 1) cb.goto(LrunMatch) cb.define(LloopEnd) - cb.if_(matchIdx.ceq(0), { - // 'winner' is smallest of all k heads. If 'winner' = k, all heads - // must be k, and all streams are exhausted. - - cb.if_(inSetup, { - cb.if_(winner.ceq(k), - cb.goto(LendOfStream), - { - cb.assign(result, Code.newArray[Long](k)) - cb.goto(LstartNewKey) - }) + cb.if_( + matchIdx.ceq(0), { + // 'winner' is smallest of all k heads. If 'winner' = k, all heads + // must be k, and all streams are exhausted. + + cb.if_( + inSetup, { + cb.if_( + winner.ceq(k), + cb.goto(LendOfStream), { + cb.assign(result, Code.newArray[Long](k)) + cb.goto(LstartNewKey) + }, + ) + }, { + cb.if_(!winner.cne(k), cb.goto(Lpush)) + val left = eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*) + val right = curKey + val ord = StructOrdering.make( + left.st, + right.st.asInstanceOf[SBaseStruct], + cb.emb.ecb, + missingFieldsEqual = false, + ) + cb.if_( + ord.equivNonnull(cb, left, right), + cb.goto(LaddToResult), + cb.goto(Lpush), + ) + }, + ) }, { - cb.if_(!winner.cne(k), cb.goto(Lpush)) - val left = eltType.loadCheapSCode(cb, heads(winner)).subset(key: _*) - val right = curKey - val ord = StructOrdering.make(left.st, right.st.asInstanceOf[SBaseStruct], - cb.emb.ecb, missingFieldsEqual = false) - cb.if_(ord.equivNonnull(cb, left, right), cb.goto(LaddToResult), cb.goto(Lpush)) - }) - }, { - // We're still in the setup phase - cb.updateArray(bracket, matchIdx, winner) - cb.assign(i, i + 1) - cb.assign(winner, i) - cb.goto(LpullChild) - }) + // We're still in the setup phase + cb.updateArray(bracket, matchIdx, winner) + cb.assign(i, i + 1) + cb.assign(winner, i) + cb.goto(LpullChild) + }, + ) cb.define(LpullChild) cb.if_(winner >= nStreams, cb.goto(LendOfStream)) // can only happen if k=0 val winnerIter = cb.memoize(iterArray(winner)) val winnerNextElt = cb.memoize(winnerIter.invoke[Long]("next")) - cb.if_(winnerIter.invoke[Boolean]("eos"), { - cb.assign(matchIdx, (winner + k) >>> 1) - cb.assign(winner, k) - }, { - cb.assign(matchIdx, (winner + k) >>> 1) - cb.updateArray(heads, winner, winnerNextElt) - }) + cb.if_( + winnerIter.invoke[Boolean]("eos"), { + cb.assign(matchIdx, (winner + k) >>> 1) + cb.assign(winner, k) + }, { + cb.assign(matchIdx, (winner + k) >>> 1) + cb.updateArray(heads, winner, winnerNextElt) + }, + ) cb.goto(LrunMatch) } @@ -2965,12 +3607,14 @@ object EmitStream { override def close(cb: EmitCodeBuilder): Unit = { cb.assign(i, 0) - cb.while_(i < nStreams, { - cb += iterArray(i).invoke[Unit]("close") - if (requiresMemoryManagementPerElement) - cb += regionArray(i).invoke[Unit]("invalidate") - cb.assign(i, i + 1) - }) + cb.while_( + i < nStreams, { + cb += 
iterArray(i).invoke[Unit]("close") + if (requiresMemoryManagementPerElement) + cb += regionArray(i).invoke[Unit]("invalidate") + cb.assign(i, i + 1) + }, + ) if (requiresMemoryManagementPerElement) cb.assign(regionArray, Code._null) cb.assign(bracket, Code._null) @@ -2982,141 +3626,165 @@ object EmitStream { SStreamValue(producer) } + case StreamMultiMerge(as, key) => + IEmitCode.multiMapEmitCodes(cb, as.map(a => EmitCode.fromI(mb)(cb => emit(a, cb)))) { + children => + val producers = children.map(_.asStream.getProducer(mb)) + + val unifiedType = + VirtualTypeWithReq.union(as.map(a => typeWithReqx(a))).canonicalEmitType + .st + .asInstanceOf[SStream] + .elementEmitType + .storageType + .asInstanceOf[PCanonicalStruct] + + val regionArray = mb.genFieldThisRef[Array[Region]]("smm_region_array") + + val staticMemManagementArray = + producers.map(_.requiresMemoryManagementPerElement).toArray + val allMatch = staticMemManagementArray.toSet.size == 1 + val memoryManagementBooleansArray = + if (allMatch) null else mb.genFieldThisRef[Array[Int]]("smm_separate_region_array") + + def initMemoryManagementPerElementArray(cb: EmitCodeBuilder): Unit = + if (!allMatch) + cb.assign( + memoryManagementBooleansArray, + mb.getObject[Array[Int]]( + producers.map(_.requiresMemoryManagementPerElement.toInt).toArray + ), + ) - case x@StreamMultiMerge(as, key) => - IEmitCode.multiMapEmitCodes(cb, as.map(a => EmitCode.fromI(mb)(cb => emit(a, cb)))) { children => - val producers = children.map(_.asStream.getProducer(mb)) - - val unifiedType = VirtualTypeWithReq.union(as.map(a => typeWithReqx(a))).canonicalEmitType - .st - .asInstanceOf[SStream] - .elementEmitType - .storageType - .asInstanceOf[PCanonicalStruct] - - val region = mb.genFieldThisRef[Region]("smm_region") - val regionArray = mb.genFieldThisRef[Array[Region]]("smm_region_array") - - val staticMemManagementArray = producers.map(_.requiresMemoryManagementPerElement).toArray - val allMatch = staticMemManagementArray.toSet.size == 1 - val memoryManagementBooleansArray = if (allMatch) null else mb.genFieldThisRef[Array[Int]]("smm_separate_region_array") - - def initMemoryManagementPerElementArray(cb: EmitCodeBuilder): Unit = { - if (!allMatch) - cb.assign(memoryManagementBooleansArray, mb.getObject[Array[Int]](producers.map(_.requiresMemoryManagementPerElement.toInt).toArray)) - } + def lookupMemoryManagementByIndex(cb: EmitCodeBuilder, idx: Code[Int]): Code[Boolean] = + if (allMatch) + const(staticMemManagementArray.head) + else + memoryManagementBooleansArray.apply(idx).toZ - def lookupMemoryManagementByIndex(cb: EmitCodeBuilder, idx: Code[Int]): Code[Boolean] = { - if (allMatch) - const(staticMemManagementArray.head) - else - memoryManagementBooleansArray.apply(idx).toZ - } + val producer = + new StreamUtils.StreamMultiMergeBase(key, unifiedType, const(producers.length), mb) { + override def method: EmitMethodBuilder[_] = mb - val producer = new StreamUtils.StreamMultiMergeBase(key, unifiedType, const(producers.length), mb) { - override def method: EmitMethodBuilder[_] = mb + override val length: Option[EmitCodeBuilder => Code[Int]] = + anyFailAllFail(producers.map(_.length)) + .map { compLens => (cb: EmitCodeBuilder) => + compLens.map(_.apply(cb)).reduce(_ + _) + } - override val length: Option[EmitCodeBuilder => Code[Int]] = - anyFailAllFail(producers.map(_.length)) - .map { compLens => - (cb: EmitCodeBuilder) => { - compLens.map(_.apply(cb)).reduce(_ + _) + override def implInit(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { + 
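// --- Editor's sketch (illustrative, not part of this diff) ---
// lookupMemoryManagementByIndex above folds the per-stream memory-management flags
// into a constant when every child stream agrees, and otherwise captures them as an
// Int array of 0/1 values consulted at runtime (the toInt / toZ round trip). A
// plain-Scala analogue of that decision; names are illustrative only.
object PerStreamFlagSketch {
  def makeLookup(flags: Array[Boolean]): Int => Boolean = {
    val allMatch = flags.toSet.size == 1
    if (allMatch) {
      val constant = flags.head // allMatch implies flags is nonempty
      _ => constant             // no per-element lookup needed
    } else {
      val asInts = flags.map(f => if (f) 1 else 0)
      idx => asInts(idx) == 1
    }
  }

  def main(args: Array[String]): Unit = {
    val uniform = makeLookup(Array(true, true, true))
    val mixed = makeLookup(Array(true, false, true))
    println((0 until 3).map(uniform)) // Vector(true, true, true)
    println((0 until 3).map(mixed))   // Vector(true, false, true)
  }
}
// --- end of editor's sketch ---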
cb.assign(regionArray, Code.newArray[Region](k)) + producers.zipWithIndex.foreach { case (p, i) => + if (p.requiresMemoryManagementPerElement) { + cb.assign( + p.elementRegion, + Region.stagedCreate(Region.REGULAR, outerRegion.getPool()), + ) + } else + cb.assign(p.elementRegion, outerRegion) + cb += (regionArray(i) = p.elementRegion) + p.initialize(cb, outerRegion) } + initMemoryManagementPerElementArray(cb) } - override def implInit(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { - cb.assign(regionArray, Code.newArray[Region](k)) - producers.zipWithIndex.foreach { case (p, i) => - if (p.requiresMemoryManagementPerElement) { - cb.assign(p.elementRegion, Region.stagedCreate(Region.REGULAR, outerRegion.getPool())) - } else - cb.assign(p.elementRegion, outerRegion) - cb += (regionArray(i) = p.elementRegion) - p.initialize(cb, outerRegion) - } - initMemoryManagementPerElementArray(cb) - } - - override val requiresMemoryManagementPerElement: Boolean = producers.exists(_.requiresMemoryManagementPerElement) - override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => - val LrunMatch = CodeLabel() - val LpullChild = CodeLabel() - val LloopEnd = CodeLabel() - - cb.define(LpullChild) - cb.switch(winner, - cb.goto(LendOfStream), // can only happen if k=0 - producers.map { p => - () => cb.goto(p.LproduceElement) - } - ) - - - cb.define(LrunMatch) - cb.assign(challenger, bracket(matchIdx)) - cb.if_(matchIdx.ceq(0) || challenger.ceq(-1), cb.goto(LloopEnd)) - - val LafterChallenge = CodeLabel() - cb.if_(challenger.cne(k), { - val Lwon = CodeLabel() - cb.if_(winner.ceq(k), cb.goto(Lwon)) - cb.if_(comp(cb, challenger, heads(challenger), winner, heads(winner)), cb.goto(Lwon), cb.goto(LafterChallenge)) - - cb.define(Lwon) - cb += (bracket(matchIdx) = winner) - cb.assign(winner, challenger) - }) - cb.define(LafterChallenge) - - cb.assign(matchIdx, matchIdx >>> 1) - cb.goto(LrunMatch) + override val requiresMemoryManagementPerElement: Boolean = + producers.exists(_.requiresMemoryManagementPerElement) + override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => + val LrunMatch = CodeLabel() + val LpullChild = CodeLabel() + val LloopEnd = CodeLabel() + + cb.define(LpullChild) + cb.switch( + winner, + cb.goto(LendOfStream), // can only happen if k=0 + producers.map(p => () => cb.goto(p.LproduceElement)), + ) - cb.define(LloopEnd) + cb.define(LrunMatch) + cb.assign(challenger, bracket(matchIdx)) + cb.if_(matchIdx.ceq(0) || challenger.ceq(-1), cb.goto(LloopEnd)) + + val LafterChallenge = CodeLabel() + cb.if_( + challenger.cne(k), { + val Lwon = CodeLabel() + cb.if_(winner.ceq(k), cb.goto(Lwon)) + cb.if_( + comp(cb, challenger, heads(challenger), winner, heads(winner)), + cb.goto(Lwon), + cb.goto(LafterChallenge), + ) - cb.if_(matchIdx.ceq(0), { - // 'winner' is smallest of all k heads. If 'winner' = k, all heads - // must be k, and all streams are exhausted. 
- cb.if_(winner.ceq(k), - cb.goto(LendOfStream), - { - // we have a winner - cb.if_(lookupMemoryManagementByIndex(cb, winner), { - val winnerRegion = cb.newLocal[Region]("smm_winner_region", regionArray(winner)) - cb += elementRegion.trackAndIncrementReferenceCountOf(winnerRegion) - cb += winnerRegion.clearRegion() - }) - cb.goto(LproduceElementDone) - }) - }, { - cb += (bracket(matchIdx) = winner) - cb.assign(i, i + 1) - cb.assign(winner, i) - cb.goto(LpullChild) - }) + cb.define(Lwon) + cb += (bracket(matchIdx) = winner) + cb.assign(winner, challenger) + }, + ) + cb.define(LafterChallenge) + + cb.assign(matchIdx, matchIdx >>> 1) + cb.goto(LrunMatch) + + cb.define(LloopEnd) + + cb.if_( + matchIdx.ceq(0), { + // 'winner' is smallest of all k heads. If 'winner' = k, all heads + // must be k, and all streams are exhausted. + cb.if_( + winner.ceq(k), + cb.goto(LendOfStream), { + // we have a winner + cb.if_( + lookupMemoryManagementByIndex(cb, winner), { + val winnerRegion = + cb.newLocal[Region]("smm_winner_region", regionArray(winner)) + cb += elementRegion.trackAndIncrementReferenceCountOf(winnerRegion) + cb += winnerRegion.clearRegion() + }, + ) + cb.goto(LproduceElementDone) + }, + ) + }, { + cb += (bracket(matchIdx) = winner) + cb.assign(i, i + 1) + cb.assign(winner, i) + cb.goto(LpullChild) + }, + ) - // define producer labels - producers.zipWithIndex.foreach { case (p, idx) => - cb.define(p.LendOfStream) - cb.assign(winner, k) - cb.assign(matchIdx, (const(idx) + k) >>> 1) - cb.goto(LrunMatch) + // define producer labels + producers.zipWithIndex.foreach { case (p, idx) => + cb.define(p.LendOfStream) + cb.assign(winner, k) + cb.assign(matchIdx, (const(idx) + k) >>> 1) + cb.goto(LrunMatch) - cb.define(p.LproduceElementDone) - cb += (heads(idx) = unifiedType.store(cb, p.elementRegion, p.element.toI(cb).get(cb), false)) - cb.assign(matchIdx, (const(idx) + k) >>> 1) - cb.goto(LrunMatch) - } - } + cb.define(p.LproduceElementDone) + cb += (heads(idx) = + unifiedType.store( + cb, + p.elementRegion, + p.element.toI(cb).getOrAssert(cb), + false, + ) + ) + cb.assign(matchIdx, (const(idx) + k) >>> 1) + cb.goto(LrunMatch) + } + } - override def implClose(cb: EmitCodeBuilder): Unit = { - producers.foreach { p => - if (p.requiresMemoryManagementPerElement) - cb += p.elementRegion.invalidate() - p.close(cb) + override def implClose(cb: EmitCodeBuilder): Unit = + producers.foreach { p => + if (p.requiresMemoryManagementPerElement) + cb += p.elementRegion.invalidate() + p.close(cb) + } } - } - } - SStreamValue(producer) + SStreamValue(producer) } case StreamLocalLDPrune(a, r2Threshold, winSize, maxQueueSize, nSamples) => @@ -3140,14 +3808,20 @@ object EmitStream { val element: EmitCode = elementField - val requiresMemoryManagementPerElement: Boolean = childProducer.requiresMemoryManagementPerElement + val requiresMemoryManagementPerElement: Boolean = + childProducer.requiresMemoryManagementPerElement def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { - cb.assign(queueSize, emit(maxQueueSize, cb).get(cb).asInt32.value) + cb.assign(queueSize, emit(maxQueueSize, cb).getOrAssert(cb).asInt32.value) cb.assign(queue, Code.newInstance[util.ArrayDeque[BitPackedVector], Int](queueSize)) - cb.assign(threshold, emit(r2Threshold, cb).get(cb).asFloat64.value) - cb.assign(windowSize, emit(winSize, cb).get(cb).asInt32.value) - cb.assign(builder, Code.newInstance[BitPackedVectorBuilder, Int](emit(nSamples, cb).get(cb).asInt32.value)) + cb.assign(threshold, emit(r2Threshold, 
cb).getOrAssert(cb).asFloat64.value) + cb.assign(windowSize, emit(winSize, cb).getOrAssert(cb).asInt32.value) + cb.assign( + builder, + Code.newInstance[BitPackedVectorBuilder, Int]( + emit(nSamples, cb).getOrAssert(cb).asInt32.value + ), + ) childProducer.initialize(cb, outerRegion) } @@ -3157,30 +3831,61 @@ object EmitStream { cb.goto(childProducer.LproduceElement) cb.define(childProducer.LproduceElementDone) - childProducer.element.toI(cb).consume(cb, + childProducer.element.toI(cb).consume( + cb, cb.goto(Lpruned), { case sc: SBaseStructValue => - val locus = sc.loadField(cb, "locus").get(cb).asLocus + val locus = sc.loadField(cb, "locus").getOrAssert(cb).asLocus val locusObj = locus.getLocusObj(cb) - val genotypes = sc.loadField(cb, "genotypes").get(cb).asIndexable + val genotypes = sc.loadField(cb, "genotypes").getOrAssert(cb).asIndexable cb += builder.invoke[Unit]("reset") - genotypes.forEachDefinedOrMissing(cb)({ (cb, _) => - cb += builder.invoke[Unit]("addMissing") - }, { (cb, _, gt) => - cb += builder.invoke[Int, Unit]("addGT", gt.asCall.canonicalCall(cb)) - }) - val bpv = cb.memoize(builder.invoke[Locus, Array[String], BitPackedVector]("finish", locusObj, Code._null[Array[String]])) + genotypes.forEachDefinedOrMissing(cb)( + (cb, _) => cb += builder.invoke[Unit]("addMissing"), + (cb, _, gt) => + cb += builder.invoke[Int, Unit]("addGT", gt.asCall.canonicalCall(cb)), + ) + val bpv = cb.memoize(builder.invoke[Locus, Array[String], BitPackedVector]( + "finish", + locusObj, + Code._null[Array[String]], + )) cb.if_(bpv.isNull, cb.goto(Lpruned)) - val keepVariant = Code.invokeScalaObject5[util.ArrayDeque[BitPackedVector], BitPackedVector, Double, Int, Int, Boolean](LocalLDPrune.getClass, "pruneLocal", - queue, bpv, threshold, windowSize, queueSize) + val keepVariant = Code.invokeScalaObject5[ + util.ArrayDeque[BitPackedVector], + BitPackedVector, + Double, + Int, + Int, + Boolean, + ]( + LocalLDPrune.getClass, + "pruneLocal", + queue, + bpv, + threshold, + windowSize, + queueSize, + ) cb.if_(!keepVariant, cb.goto(Lpruned)) val mean = SFloat64Value(cb.memoize(bpv.invoke[Double]("mean"))) - val centeredLengthRec = SFloat64Value(cb.memoize(bpv.invoke[Double]("centeredLengthRec"))) - val elt = SStackStruct.constructFromArgs(cb, elementRegion, elementType.virtualType.asInstanceOf[TBaseStruct], - EmitCode.present(mb, locus), EmitCode.fromI(mb)(cb => sc.loadField(cb, "alleles")), EmitCode.present(mb, mean), EmitCode.present(mb, centeredLengthRec)) - cb.assign(elementField, EmitCode.present(mb, elt.castTo(cb, elementRegion, elementField.emitType.st))) - }) + val centeredLengthRec = + SFloat64Value(cb.memoize(bpv.invoke[Double]("centeredLengthRec"))) + val elt = SStackStruct.constructFromArgs( + cb, + elementRegion, + elementType.virtualType.asInstanceOf[TBaseStruct], + EmitCode.present(mb, locus), + EmitCode.fromI(mb)(cb => sc.loadField(cb, "alleles")), + EmitCode.present(mb, mean), + EmitCode.present(mb, centeredLengthRec), + ) + cb.assign( + elementField, + EmitCode.present(mb, elt.castTo(cb, elementRegion, elementField.emitType.st)), + ) + }, + ) cb.goto(LproduceElementDone) @@ -3188,14 +3893,12 @@ object EmitStream { if (requiresMemoryManagementPerElement) cb += childProducer.elementRegion.clearRegion() cb.goto(childProducer.LproduceElement) - } + } def close(cb: EmitCodeBuilder): Unit = childProducer.close(cb) } - mb.implementLabel(childProducer.LendOfStream) { cb => - cb.goto(producer.LendOfStream) - } + mb.implementLabel(childProducer.LendOfStream)(cb => cb.goto(producer.LendOfStream)) 
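Aside (not part of the patch): the comparison semantics that the staged StreamMultiMerge code above generates can be illustrated with a plain, unstaged Scala sketch. The object and method names below are hypothetical and purely illustrative; the point is the tie-breaking rule the tournament-tree code enforces — when two heads compare equal, the element from the earlier child stream is taken first.

    // Illustrative sketch only; not generated code and not part of this diff.
    object MultiMergeSemanticsSketch {
      // Merge already-sorted children; on equal keys the earlier child wins,
      // mirroring the earlier-index-wins ordering used by the staged comparator.
      def multiMerge[A](children: IndexedSeq[Seq[A]])(implicit ord: Ordering[A]): Seq[A] = {
        val iters = children.map(_.iterator.buffered)
        val out = Seq.newBuilder[A]
        while (iters.exists(_.hasNext)) {
          // filter preserves child order and minBy returns the first minimum,
          // so ties go to the lowest child index
          val next = iters.filter(_.hasNext).minBy(_.head)
          out += next.next()
        }
        out.result()
      }
    }

    // e.g. multiMerge(IndexedSeq(Seq(1, 3, 5), Seq(1, 2, 6))) == Seq(1, 1, 2, 3, 5, 6),
    // with the first 1 coming from child 0.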
SStreamValue(producer) } diff --git a/hail/src/main/scala/is/hail/expr/ir/streams/StagedMinHeap.scala b/hail/src/main/scala/is/hail/expr/ir/streams/StagedMinHeap.scala new file mode 100644 index 00000000000..9f811bf717e --- /dev/null +++ b/hail/src/main/scala/is/hail/expr/ir/streams/StagedMinHeap.scala @@ -0,0 +1,260 @@ +package is.hail.expr.ir.streams + +import is.hail.annotations.{Region, RegionPool} +import is.hail.asm4s._ +import is.hail.expr.ir.{EmitClassBuilder, EmitCodeBuilder, EmitMethodBuilder, EmitModuleBuilder} +import is.hail.expr.ir.agg.StagedArrayBuilder +import is.hail.types.physical.PCanonicalArray +import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.interfaces.SIndexableValue +import is.hail.utils.FastSeq + +sealed trait StagedMinHeap { + def arraySType: SType + + def init(cb: EmitCodeBuilder, pool: Value[RegionPool]): Unit + def realloc(cb: EmitCodeBuilder): Unit + def close(cb: EmitCodeBuilder): Unit + + def push(cb: EmitCodeBuilder, a: SValue): Unit + def peek(cb: EmitCodeBuilder): SValue + def pop(cb: EmitCodeBuilder): Unit + def nonEmpty(cb: EmitCodeBuilder): Value[Boolean] + + def toArray(cb: EmitCodeBuilder, region: Value[Region]): SIndexableValue +} + +object StagedMinHeap { + def apply( + modb: EmitModuleBuilder, + elemSType: SType, + )( + comparator: (EmitCodeBuilder, SValue, SValue) => Value[Int] + ): EmitClassBuilder[_] => StagedMinHeap = { + + val elemPType = elemSType.storageType().setRequired(true) + val elemParamType = elemPType.sType.paramType + + val classBuilder: EmitClassBuilder[Unit] = + modb.genEmitClass[Unit](s"MinHeap${elemPType.asIdent}") + + val pool: ThisFieldRef[RegionPool] = + classBuilder.genFieldThisRef[RegionPool]("pool") + + val region: ThisFieldRef[Region] = + classBuilder.genFieldThisRef[Region]("region") + + val garbage: ThisFieldRef[Long] = + classBuilder.genFieldThisRef[Long]("n_garbage_points") + + val heap = new StagedArrayBuilder(elemPType, classBuilder, region) + val ctor: EmitMethodBuilder[Unit] = + classBuilder.defineEmitMethod("", FastSeq(typeInfo[RegionPool]), UnitInfo) { mb => + val poolRef = mb.getCodeParam[RegionPool](1) + + mb.voidWithBuilder { cb => + cb += classBuilder.cb.super_.invoke(coerce[Object](cb.this_), Array()) + cb.assign(pool, poolRef) + cb.assign(region, Region.stagedCreate(Region.REGULAR, poolRef)) + cb.assign(garbage, 0L) + heap.initialize(cb) + } + } + + val load: EmitMethodBuilder[_] = + classBuilder.defineEmitMethod("load", FastSeq(IntInfo), elemParamType) { mb => + mb.emitSCode { cb => + val idx = mb.getCodeParam[Int](1) + heap.loadElement(cb, idx).toI(cb).getOrAssert(cb, debugMsg = idx.toS) + } + } + + val compareAtIndex: EmitMethodBuilder[_] = + classBuilder.defineEmitMethod("compareAtIndex", FastSeq(IntInfo, IntInfo), IntInfo) { mb => + mb.emitWithBuilder[Int] { cb => + val l = cb.invokeSCode(load, cb.this_, mb.getCodeParam[Int](1)) + val r = cb.invokeSCode(load, cb.this_, mb.getCodeParam[Int](2)) + comparator(cb, l, r) + } + } + + val realloc_ : EmitMethodBuilder[_] = + classBuilder.defineEmitMethod("realloc", FastSeq(), UnitInfo) { mb => + mb.voidWithBuilder { cb => + cb.if_( + garbage > heap.size.toL * 2L + 1024L, { + val oldRegion = cb.memoize(region, "tmp") + cb.assign(region, Region.stagedCreate(Region.REGULAR, pool)) + heap.reallocateData(cb) + cb.assign(garbage, 0L) + cb += oldRegion.invoke[Unit]("invalidate") + }, + ) + } + } + + val close_ : EmitMethodBuilder[_] = + classBuilder.defineEmitMethod("close", FastSeq(), UnitInfo) { mb => + mb.emit { + 
region.invoke[Unit]("invalidate") + } + } + + def thisNonEmpty: Code[Boolean] = + heap.size > 0 + + val peek_ : EmitMethodBuilder[_] = + classBuilder.defineEmitMethod("peek", FastSeq(), elemParamType) { mb => + mb.emitSCode { cb => + cb._assert(thisNonEmpty, s"${classBuilder.className}: peek empty") + cb.invokeSCode(load, cb.this_, cb.memoize(0)) + } + } + + val swap: EmitMethodBuilder[_] = + classBuilder.defineEmitMethod("swap", FastSeq(IntInfo, IntInfo), UnitInfo) { mb => + mb.voidWithBuilder { cb => + val x = mb.getCodeParam[Int](1) + val y = mb.getCodeParam[Int](2) + heap.swap(cb, x, y) + } + } + + val heapify: EmitMethodBuilder[_] = + classBuilder.defineEmitMethod("heapify", FastSeq(), UnitInfo) { mb => + mb.voidWithBuilder { cb => + val Ldone = CodeLabel() + cb.if_(heap.size <= 1, cb.goto(Ldone)) + + val index = cb.newLocal[Int]("index", 0) + val smallest = cb.newLocal[Int]("smallest", index) + + val child = cb.newLocal[Int]("child") + cb.loop { Lrecur => + // left child + cb.assign(child, index * 2 + 1) + cb.if_( + child < heap.size, + cb.if_( + cb.invokeCode[Int](compareAtIndex, cb.this_, child, index) < 0, + cb.assign(smallest, child), + ), + ) + + // right child + cb.assign(child, index * 2 + 2) + cb.if_( + child < heap.size, + cb.if_( + cb.invokeCode[Int](compareAtIndex, cb.this_, child, smallest) < 0, + cb.assign(smallest, child), + ), + ) + + cb.if_(smallest ceq index, cb.goto(Ldone)) + + cb.invokeVoid(swap, cb.this_, index, smallest) + cb.assign(index, smallest) + cb.goto(Lrecur) + } + + cb.define(Ldone) + } + } + + val pop_ : EmitMethodBuilder[_] = + classBuilder.defineEmitMethod("pop", FastSeq(), UnitInfo) { mb => + mb.voidWithBuilder { cb => + cb._assert(thisNonEmpty, s"${classBuilder.className}: poll empty") + + cb.assign(garbage, garbage + 1L) + val newSize = cb.memoize(heap.size - 1) + cb.if_( + newSize ceq 0, + cb.assign(heap.size, newSize), { + cb.invokeVoid(swap, cb.this_, const(0), newSize) + cb.assign(heap.size, newSize) + cb.invokeVoid(heapify, cb.this_) + }, + ) + } + } + + val append: EmitMethodBuilder[_] = + classBuilder.defineEmitMethod("append", FastSeq(elemParamType), UnitInfo) { mb => + mb.voidWithBuilder(cb => heap.append(cb, mb.getSCodeParam(1))) + } + + val push_ : EmitMethodBuilder[_] = + classBuilder.defineEmitMethod("push", FastSeq(elemParamType), UnitInfo) { mb => + mb.voidWithBuilder { cb => + cb.invokeVoid(append, cb.this_, mb.getSCodeParam(1)) + val Ldone = CodeLabel() + val current = cb.newLocal[Int]("index", heap.size - 1) + val parent = cb.newLocal[Int]("parent") + + cb.while_( + current > 0, { + cb.assign(parent, (current - 1) / 2) + val cmp = cb.invokeCode[Int](compareAtIndex, cb.this_, parent, current) + cb.if_(cmp <= 0, cb.goto(Ldone)) + + cb.invokeVoid(swap, cb.this_, parent, current) + cb.assign(current, parent) + }, + ) + + cb.define(Ldone) + } + } + + val arrayPType = PCanonicalArray(elemPType, required = true) + + val toArray_ : EmitMethodBuilder[_] = + classBuilder.defineEmitMethod( + "toArray", + FastSeq(typeInfo[Region]), + arrayPType.sType.paramType, + ) { mb => + val region = mb.getCodeParam[Region](1) + mb.emitSCode { cb => + arrayPType.constructFromElements(cb, region, heap.size, true) { + case (cb, idx) => heap.loadElement(cb, idx).toI(cb) + } + } + } + + ecb => + new StagedMinHeap { + private[this] val this_ : ThisFieldRef[_] = + ecb.genFieldThisRef("minheap")(classBuilder.cb.ti) + + override def arraySType: SType = + arrayPType.sType + + override def init(cb: EmitCodeBuilder, pool: Value[RegionPool]): Unit = + 
cb.assignAny(this_, Code.newInstance(classBuilder.cb, ctor.mb, FastSeq(pool))) + + override def realloc(cb: EmitCodeBuilder): Unit = + cb.invokeVoid(realloc_, this_) + + override def close(cb: EmitCodeBuilder): Unit = + cb.invokeVoid(close_, this_) + + override def push(cb: EmitCodeBuilder, a: SValue): Unit = + cb.invokeVoid(push_, this_, a) + + override def peek(cb: EmitCodeBuilder): SValue = + cb.invokeSCode(peek_, this_) + + override def pop(cb: EmitCodeBuilder): Unit = + cb.invokeVoid(pop_, this_) + + override def nonEmpty(cb: EmitCodeBuilder): Value[Boolean] = + cb.memoize(classBuilder.getField[Int](heap.size.name).get(this_) > 0) + + override def toArray(cb: EmitCodeBuilder, region: Value[Region]): SIndexableValue = + cb.invokeSCode(toArray_, this_, region).asIndexable + } + } +} diff --git a/hail/src/main/scala/is/hail/expr/ir/streams/StreamUtils.scala b/hail/src/main/scala/is/hail/expr/ir/streams/StreamUtils.scala index a3dfba0fe2f..1592a330edb 100644 --- a/hail/src/main/scala/is/hail/expr/ir/streams/StreamUtils.scala +++ b/hail/src/main/scala/is/hail/expr/ir/streams/StreamUtils.scala @@ -2,12 +2,15 @@ package is.hail.expr.ir.streams import is.hail.annotations.Region import is.hail.asm4s._ +import is.hail.expr.ir.{ + EmitCode, EmitCodeBuilder, EmitMethodBuilder, IEmitCode, IR, NDArrayMap, NDArrayMap2, Ref, + RunAggScan, StagedArrayBuilder, StreamFilter, StreamFlatMap, StreamFold, StreamFold2, StreamFor, + StreamJoinRightDistinct, StreamMap, StreamScan, StreamZip, StreamZipJoin, +} import is.hail.expr.ir.orderings.StructOrdering -import is.hail.expr.ir.{EmitClassBuilder, EmitCode, EmitCodeBuilder, EmitEnv, EmitMethodBuilder, Env, IEmitCode, IR, NDArrayMap, NDArrayMap2, Param, ParamType, Ref, RunAggScan, StagedArrayBuilder, StreamFilter, StreamFlatMap, StreamFold, StreamFold2, StreamFor, StreamJoinRightDistinct, StreamMap, StreamScan, StreamZip, StreamZipJoin} -import is.hail.types.VirtualTypeWithReq -import is.hail.types.physical.{PCanonicalArray, PCanonicalStruct, PType} +import is.hail.types.physical.{PCanonicalArray, PCanonicalStruct} import is.hail.types.physical.stypes.SingleCodeType -import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SIndexableValue, SStream, SStreamIteratorLong, SStreamValue} +import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SIndexableValue} import is.hail.utils._ object StreamUtils { @@ -17,35 +20,40 @@ object StreamUtils { stream: StreamProducer, destRegion: Value[Region], addr: Value[Long], - errorId: Int + errorId: Int, ): Unit = { val currentElementIndex = cb.newLocal[Long]("store_ndarray_elements_stream_current_index", 0) - val currentElementAddress = cb.newLocal[Long]("store_ndarray_elements_stream_current_addr", addr) + val currentElementAddress = + cb.newLocal[Long]("store_ndarray_elements_stream_current_addr", addr) val elementType = stream.element.emitType.storageType val elementByteSize = elementType.byteSize var push: (EmitCodeBuilder, IEmitCode) => Unit = null - stream.memoryManagedConsume(destRegion, cb, setup = { cb => - push = { case (cb, iec) => - iec.consume(cb, - cb._throw(Code.newInstance[HailException, String, Int]( - "Cannot construct an ndarray with missing values.", errorId - )), - { sc => - elementType.storeAtAddress(cb, currentElementAddress, destRegion, sc, deepCopy = true) - }) - cb.assign(currentElementIndex, currentElementIndex + 1) - cb.assign(currentElementAddress, currentElementAddress + elementByteSize) - } - }) { cb => - push(cb, stream.element.toI(cb)) - } + stream.memoryManagedConsume( 
+ destRegion, + cb, + setup = { cb => + push = { case (cb, iec) => + iec.consume( + cb, + cb._throw(Code.newInstance[HailException, String, Int]( + "Cannot construct an ndarray with missing values.", + errorId, + )), + sc => + elementType.storeAtAddress(cb, currentElementAddress, destRegion, sc, deepCopy = true), + ) + cb.assign(currentElementIndex, currentElementIndex + 1) + cb.assign(currentElementAddress, currentElementAddress + elementByteSize) + } + }, + )(cb => push(cb, stream.element.toI(cb))) } def toArray( cb: EmitCodeBuilder, stream: StreamProducer, - destRegion: Value[Region] + destRegion: Value[Region], ): SIndexableValue = { val mb = cb.emb @@ -53,7 +61,12 @@ object StreamUtils { val aTyp = PCanonicalArray(stream.element.emitType.storageType, true) stream.length match { case None => - val vab = new StagedArrayBuilder(cb, SingleCodeType.fromSType(stream.element.st), stream.element.required, 0) + val vab = new StagedArrayBuilder( + cb, + SingleCodeType.fromSType(stream.element.st), + stream.element.required, + 0, + ) writeToArrayBuilder(cb, stream, vab, destRegion) cb.assign(xLen, vab.size) @@ -62,18 +75,24 @@ object StreamUtils { } case Some(computeLen) => - var pushElem: (EmitCodeBuilder, IEmitCode) => Unit = null var finish: (EmitCodeBuilder) => SIndexableValue = null - stream.memoryManagedConsume(destRegion, cb, setup = { cb => - cb.assign(xLen, computeLen(cb)) - val (_pushElem, _finish) = aTyp.constructFromFunctions(cb, destRegion, xLen, deepCopy = stream.requiresMemoryManagementPerElement) - pushElem = _pushElem - finish = _finish - }) { cb => - pushElem(cb, stream.element.toI(cb)) - } + stream.memoryManagedConsume( + destRegion, + cb, + setup = { cb => + cb.assign(xLen, computeLen(cb)) + val (_pushElem, _finish) = aTyp.constructFromFunctions( + cb, + destRegion, + xLen, + deepCopy = stream.requiresMemoryManagementPerElement, + ) + pushElem = _pushElem + finish = _finish + }, + )(cb => pushElem(cb, stream.element.toI(cb))) finish(cb) } @@ -83,20 +102,33 @@ object StreamUtils { cb: EmitCodeBuilder, stream: StreamProducer, ab: StagedArrayBuilder, - destRegion: Value[Region] + destRegion: Value[Region], ): Unit = { - stream.memoryManagedConsume(destRegion, cb, setup = { cb => - ab.clear(cb) - stream.length match { - case Some(computeLen) => ab.ensureCapacity(cb, computeLen(cb)) - case None => ab.ensureCapacity(cb, 16) - } - + stream.memoryManagedConsume( + destRegion, + cb, + setup = { cb => + ab.clear(cb) + stream.length match { + case Some(computeLen) => ab.ensureCapacity(cb, computeLen(cb)) + case None => ab.ensureCapacity(cb, 16) + } - }) { cb => - stream.element.toI(cb).consume(cb, + }, + ) { cb => + stream.element.toI(cb).consume( + cb, ab.addMissing(cb), - sc => ab.add(cb, ab.elt.coerceSCode(cb, sc, destRegion, deepCopy = stream.requiresMemoryManagementPerElement).code) + sc => + ab.add( + cb, + ab.elt.coerceSCode( + cb, + sc, + destRegion, + deepCopy = stream.requiresMemoryManagementPerElement, + ).code, + ), ) } } @@ -110,7 +142,7 @@ object StreamUtils { case StreamMap(a, _, b) => traverse(a, mult); traverse(b, 2) case StreamFilter(a, _, b) => traverse(a, mult); traverse(b, 2) case StreamFlatMap(a, _, b) => traverse(a, mult); traverse(b, 2) - case StreamJoinRightDistinct(l, r, _, _, _, c, j, _) => + case StreamJoinRightDistinct(l, r, _, _, _, _, j, _) => traverse(l, mult); traverse(r, mult); traverse(j, 2) case StreamScan(a, z, _, _, b) => traverse(a, mult); traverse(z, 2); traverse(b, 2) @@ -135,9 +167,9 @@ object StreamUtils { traverse(l, mult); traverse(r, mult); 
traverse(body, 2) case _ => ir.children.foreach { - case child: IR => traverse(child, mult) - case _ => - } + case child: IR => traverse(child, mult) + case _ => + } } traverse(root, 1) @@ -147,11 +179,12 @@ object StreamUtils { def isIterationLinear(ir: IR, refName: String): Boolean = multiplicity(ir, refName) <= 1 - - abstract class StreamMultiMergeBase(key: IndexedSeq[String], + abstract class StreamMultiMergeBase( + key: IndexedSeq[String], unifiedType: PCanonicalStruct, val k: Value[Int], - mb: EmitMethodBuilder[_]) extends StreamProducer { + mb: EmitMethodBuilder[_], + ) extends StreamProducer { // The algorithm maintains a tournament tree of comparisons between the // current values of the k streams. The tournament tree is a complete // binary tree with k leaves. The leaves of the tree are the streams, @@ -186,55 +219,70 @@ object StreamUtils { val region = mb.genFieldThisRef[Region]("smm_region") - /** - * The ordering function in StreamMultiMerge should use missingFieldsEqual=false to be consistent - * with other nodes that deal with struct keys. When keys compare equal, the earlier index (in - * the list of stream children) should win. These semantics extend to missing key fields, which - * requires us to compile two orderings (l/r and r/l) to maintain the abilty to take from the - * left when key fields are missing. - */ - def comp(cb: EmitCodeBuilder, li: Code[Int], lv: Code[Long], ri: Code[Int], rv: Code[Long]): Code[Boolean] = { + /** The ordering function in StreamMultiMerge should use missingFieldsEqual=false to be + * consistent with other nodes that deal with struct keys. When keys compare equal, the earlier + * index (in the list of stream children) should win. These semantics extend to missing key + * fields, which requires us to compile two orderings (l/r and r/l) to maintain the abilty to + * take from the left when key fields are missing. 
+ */ + def comp(cb: EmitCodeBuilder, li: Code[Int], lv: Code[Long], ri: Code[Int], rv: Code[Long]) + : Code[Boolean] = { val l = unifiedType.loadCheapSCode(cb, lv).asBaseStruct.subset(key: _*) val r = unifiedType.loadCheapSCode(cb, rv).asBaseStruct.subset(key: _*) - val ord1 = StructOrdering.make(l.asBaseStruct.st, r.asBaseStruct.st, cb.emb.ecb, missingFieldsEqual = false) - val ord2 = StructOrdering.make(r.asBaseStruct.st, l.asBaseStruct.st, cb.emb.ecb, missingFieldsEqual = false) + val ord1 = StructOrdering.make( + l.asBaseStruct.st, + r.asBaseStruct.st, + cb.emb.ecb, + missingFieldsEqual = false, + ) + val ord2 = StructOrdering.make( + r.asBaseStruct.st, + l.asBaseStruct.st, + cb.emb.ecb, + missingFieldsEqual = false, + ) val b = cb.newLocal[Boolean]("stream_merge_comp_result") - cb.if_(li < ri, + cb.if_( + li < ri, cb.assign(b, ord1.compareNonnull(cb, l, r) <= 0), - cb.assign(b, ord2.compareNonnull(cb, r, l) > 0)) + cb.assign(b, ord2.compareNonnull(cb, r, l) > 0), + ) b } def implInit(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit + final def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { implInit(cb, outerRegion) cb.assign(bracket, Code.newArray[Int](k)) cb.assign(heads, Code.newArray[Long](k)) - cb.for_(cb.assign(i, 0), i < k, cb.assign(i, i + 1), { - cb += (bracket(i) = -1) - }) + cb.for_(cb.assign(i, 0), i < k, cb.assign(i, i + 1), cb += (bracket(i) = -1)) cb.assign(i, 0) cb.assign(winner, 0) } def implClose(cb: EmitCodeBuilder): Unit + final def close(cb: EmitCodeBuilder): Unit = { implClose(cb) cb.assign(bracket, Code._null) cb.assign(heads, Code._null) } - override final val elementRegion: Settable[Region] = region + final override val elementRegion: Settable[Region] = region - override final val element: EmitCode = EmitCode.fromI(mb)(cb => IEmitCode.present(cb, unifiedType.loadCheapSCode(cb, heads(winner)))) + final override val element: EmitCode = + EmitCode.fromI(mb)(cb => IEmitCode.present(cb, unifiedType.loadCheapSCode(cb, heads(winner)))) } - def multiMergeIterators(cb: EmitCodeBuilder, + def multiMergeIterators( + cb: EmitCodeBuilder, reqMemManagementArray: Either[Array[Boolean], Boolean], iterators: Value[Array[NoBoxLongIterator]], key: IndexedSeq[String], - unifiedType: PCanonicalStruct): StreamProducer = { + unifiedType: PCanonicalStruct, + ): StreamProducer = { val mb = cb.emb @@ -244,26 +292,30 @@ object StreamUtils { case Left(arr) => val fd = mb.genFieldThisRef[Array[Boolean]]("memManagement") fd -> ((cb: EmitCodeBuilder) => (cb.assign(fd, mb.getObject[Array[Boolean]](arr)))) - case Right(b) => (null, ((cb: EmitCodeBuilder) => ())) + case Right(_) => (null, ((cb: EmitCodeBuilder) => ())) } - def lookupMemoryManagementByIndex(cb: EmitCodeBuilder, idx: Code[Int]): Value[Boolean] = { + def lookupMemoryManagementByIndex(cb: EmitCodeBuilder, idx: Code[Int]): Value[Boolean] = reqMemManagementArray match { - case Left(arr) => cb.memoize(memManagementArrayField(idx)) + case Left(_) => cb.memoize(memManagementArrayField(idx)) case Right(b) => b } - } - new StreamMultiMergeBase(key, unifiedType, cb.memoizeField(iterators.length()), mb) { - def forEachIterator(cb: EmitCodeBuilder)(f: (EmitCodeBuilder, Value[Int], Value[NoBoxLongIterator]) => Unit) = { + def forEachIterator( + cb: EmitCodeBuilder + )( + f: (EmitCodeBuilder, Value[Int], Value[NoBoxLongIterator]) => Unit + ) = { val idx = cb.newLocal[Int]("idx", 0) - cb.while_(idx < k, { - val iter = cb.memoize(iterators(idx)) - f(cb, idx, iter) - cb.assign(idx, idx + 1) - }) + cb.while_( + idx < 
k, { + val iter = cb.memoize(iterators(idx)) + f(cb, idx, iter) + cb.assign(idx, idx + 1) + }, + ) } override def method: EmitMethodBuilder[_] = mb @@ -275,9 +327,11 @@ object StreamUtils { forEachIterator(cb) { case (cb, idx, iter) => val reqMM = lookupMemoryManagementByIndex(cb, idx) val eltRegion = cb.newLocal[Region]("eltRegion") - cb.if_(reqMM, + cb.if_( + reqMM, cb.assign(eltRegion, Region.stagedCreate(Region.REGULAR, outerRegion.getPool())), - cb.assign(eltRegion, outerRegion)) + cb.assign(eltRegion, outerRegion), + ) cb += iter.invoke[Region, Region, Unit]("init", outerRegion, eltRegion) cb += regionArray.update(idx, eltRegion) } @@ -299,11 +353,13 @@ object StreamUtils { cb.if_(winner >= k, cb.goto(LendOfStream)) val winnerIter = cb.memoize(iterators(winner)) val next = cb.memoize(winnerIter.invoke[Long]("next")) - cb.if_(winnerIter.invoke[Boolean]("eos"), { - cb.assign(matchIdx, (winner + k) >>> 1) - cb.assign(winner, k) - cb.goto(LrunMatch) - }) + cb.if_( + winnerIter.invoke[Boolean]("eos"), { + cb.assign(matchIdx, (winner + k) >>> 1) + cb.assign(winner, k) + cb.goto(LrunMatch) + }, + ) cb.if_(next ceq 0L, cb._fatal("stream multi merge: elements cannot be missing")) cb += heads.update(winner, next) @@ -315,15 +371,21 @@ object StreamUtils { cb.if_(matchIdx.ceq(0) || challenger.ceq(-1), cb.goto(LloopEnd)) val LafterChallenge = CodeLabel() - cb.if_(challenger.cne(k), { - val Lwon = CodeLabel() - cb.if_(winner.ceq(k), cb.goto(Lwon)) - cb.if_(comp(cb, challenger, heads(challenger), winner, heads(winner)), cb.goto(Lwon), cb.goto(LafterChallenge)) - - cb.define(Lwon) - cb += (bracket(matchIdx) = winner) - cb.assign(winner, challenger) - }) + cb.if_( + challenger.cne(k), { + val Lwon = CodeLabel() + cb.if_(winner.ceq(k), cb.goto(Lwon)) + cb.if_( + comp(cb, challenger, heads(challenger), winner, heads(winner)), + cb.goto(Lwon), + cb.goto(LafterChallenge), + ) + + cb.define(Lwon) + cb += (bracket(matchIdx) = winner) + cb.assign(winner, challenger) + }, + ) cb.define(LafterChallenge) cb.assign(matchIdx, matchIdx >>> 1) @@ -331,34 +393,38 @@ object StreamUtils { cb.define(LloopEnd) - cb.if_(matchIdx.ceq(0), { - // 'winner' is smallest of all k heads. If 'winner' = k, all heads - // must be k, and all streams are exhausted. - cb.if_(winner.ceq(k), - cb.goto(LendOfStream), - { - // we have a winner - cb.if_(lookupMemoryManagementByIndex(cb, winner), { - val winnerRegion = cb.newLocal[Region]("smm_winner_region", regionArray(winner)) - cb += elementRegion.trackAndIncrementReferenceCountOf(winnerRegion) - cb += winnerRegion.clearRegion() - }) - cb.goto(LproduceElementDone) - }) - }, { - cb += (bracket(matchIdx) = winner) - cb.assign(i, i + 1) - cb.assign(winner, i) - cb.goto(LpullChild) - }) + cb.if_( + matchIdx.ceq(0), { + // 'winner' is smallest of all k heads. If 'winner' = k, all heads + // must be k, and all streams are exhausted. 
+ cb.if_( + winner.ceq(k), + cb.goto(LendOfStream), { + // we have a winner + cb.if_( + lookupMemoryManagementByIndex(cb, winner), { + val winnerRegion = cb.newLocal[Region]("smm_winner_region", regionArray(winner)) + cb += elementRegion.trackAndIncrementReferenceCountOf(winnerRegion) + cb += winnerRegion.clearRegion() + }, + ) + cb.goto(LproduceElementDone) + }, + ) + }, { + cb += (bracket(matchIdx) = winner) + cb.assign(i, i + 1) + cb.assign(winner, i) + cb.goto(LpullChild) + }, + ) } - override def implClose(cb: EmitCodeBuilder): Unit = { + override def implClose(cb: EmitCodeBuilder): Unit = forEachIterator(cb) { case (cb, idx, iter) => cb.if_(lookupMemoryManagementByIndex(cb, idx), cb += regionArray(idx).invalidate()) cb += iter.invoke[Unit]("close") } - } } } } diff --git a/hail/src/main/scala/is/hail/io/AbstractBinaryReader.scala b/hail/src/main/scala/is/hail/io/AbstractBinaryReader.scala index 1e7f180ea55..63899181ee6 100644 --- a/hail/src/main/scala/is/hail/io/AbstractBinaryReader.scala +++ b/hail/src/main/scala/is/hail/io/AbstractBinaryReader.scala @@ -2,7 +2,6 @@ package is.hail.io import java.io._ - abstract class AbstractBinaryReader { def read(): Int @@ -29,20 +28,19 @@ abstract class AbstractBinaryReader { def readLong(): Long = (read() & 0xff).asInstanceOf[Long] | - ((read() & 0xff).asInstanceOf[Long] << 8) | - ((read() & 0xff).asInstanceOf[Long] << 16) | - ((read() & 0xff).asInstanceOf[Long] << 24) | - ((read() & 0xff).asInstanceOf[Long] << 32) | - ((read() & 0xff).asInstanceOf[Long] << 40) | - ((read() & 0xff).asInstanceOf[Long] << 48) | - ((read() & 0xff).asInstanceOf[Long] << 56) + ((read() & 0xff).asInstanceOf[Long] << 8) | + ((read() & 0xff).asInstanceOf[Long] << 16) | + ((read() & 0xff).asInstanceOf[Long] << 24) | + ((read() & 0xff).asInstanceOf[Long] << 32) | + ((read() & 0xff).asInstanceOf[Long] << 40) | + ((read() & 0xff).asInstanceOf[Long] << 48) | + ((read() & 0xff).asInstanceOf[Long] << 56) def readInt(): Int = (read() & 0xff) | ((read() & 0xff) << 8) | ((read() & 0xff) << 16) | ((read() & 0xff) << 24) def readShort(): Int = (read() & 0xff) | ((read() & 0xff) << 8) - def readString(length: Int): String = { require(length >= 0) val byteArray = new Array[Byte](length) @@ -50,7 +48,10 @@ abstract class AbstractBinaryReader { if (result < 0) throw new EOFException() - new String(byteArray, "iso-8859-1") //FIXME figure out what BGENs are actually encoding; UTF-8 also works + new String( + byteArray, + "iso-8859-1", + ) // FIXME figure out what BGENs are actually encoding; UTF-8 also works } def readLengthAndString(lengthBytes: Int): String = { diff --git a/hail/src/main/scala/is/hail/io/BufferSpecs.scala b/hail/src/main/scala/is/hail/io/BufferSpecs.scala index 2ca8876fc89..76561cfcafb 100644 --- a/hail/src/main/scala/is/hail/io/BufferSpecs.scala +++ b/hail/src/main/scala/is/hail/io/BufferSpecs.scala @@ -1,54 +1,50 @@ package is.hail.io +import is.hail.asm4s._ import is.hail.compatibility.LZ4BlockBufferSpec +import is.hail.io.compress.LZ4 import is.hail.rvd.AbstractRVDSpec -import java.io._ -import is.hail.asm4s._ -import is.hail.io.compress.LZ4 -import org.json4s.{ DefaultFormats, Formats, ShortTypeHints } +import java.io._ +import org.json4s.{JValue, ShortTypeHints} import org.json4s.jackson.JsonMethods -import org.json4s.{Extraction, JValue} object BufferSpec { val zstdCompressionLEB: BufferSpec = LEB128BufferSpec( - BlockingBufferSpec(64 * 1024, - ZstdBlockBufferSpec(64 * 1024, - new StreamBlockBufferSpec))) + BlockingBufferSpec(64 * 1024, ZstdBlockBufferSpec(64 * 
1024, new StreamBlockBufferSpec)) + ) val default: BufferSpec = zstdCompressionLEB - val blockedUncompressed: BufferSpec = BlockingBufferSpec(32 * 1024, - new StreamBlockBufferSpec) + val blockedUncompressed: BufferSpec = BlockingBufferSpec(32 * 1024, new StreamBlockBufferSpec) val unblockedUncompressed: BufferSpec = new StreamBufferSpec val wireSpec: BufferSpec = LEB128BufferSpec( - BlockingBufferSpec(64 * 1024, - ZstdSizedBasedBlockBufferSpec(64 * 1024, - /*minCompressionSize=*/256, - new StreamBlockBufferSpec))) + BlockingBufferSpec( + 64 * 1024, + ZstdSizedBasedBlockBufferSpec( + 64 * 1024, + /*minCompressionSize=*/ 256, + new StreamBlockBufferSpec, + ), + ) + ) + val memorySpec: BufferSpec = wireSpec // longtime default spec val lz4HCCompressionLEB: BufferSpec = LEB128BufferSpec( - BlockingBufferSpec(32 * 1024, - LZ4HCBlockBufferSpec(32 * 1024, - new StreamBlockBufferSpec))) + BlockingBufferSpec(32 * 1024, LZ4HCBlockBufferSpec(32 * 1024, new StreamBlockBufferSpec)) + ) val blockSpecs: Array[BufferSpec] = Array( - BlockingBufferSpec(64 * 1024, - new StreamBlockBufferSpec), - BlockingBufferSpec(32 * 1024, - LZ4HCBlockBufferSpec(32 * 1024, - new StreamBlockBufferSpec)), - BlockingBufferSpec(32 * 1024, - LZ4FastBlockBufferSpec(32 * 1024, - new StreamBlockBufferSpec)), - BlockingBufferSpec(64 * 1024, - ZstdBlockBufferSpec(64 * 1024, - new StreamBlockBufferSpec)), - new StreamBufferSpec) + BlockingBufferSpec(64 * 1024, new StreamBlockBufferSpec), + BlockingBufferSpec(32 * 1024, LZ4HCBlockBufferSpec(32 * 1024, new StreamBlockBufferSpec)), + BlockingBufferSpec(32 * 1024, LZ4FastBlockBufferSpec(32 * 1024, new StreamBlockBufferSpec)), + BlockingBufferSpec(64 * 1024, ZstdBlockBufferSpec(64 * 1024, new StreamBlockBufferSpec)), + new StreamBufferSpec, + ) val specs: Array[BufferSpec] = blockSpecs.flatMap { blockSpec => Array(blockSpec, LEB128BufferSpec(blockSpec)) @@ -61,10 +57,11 @@ object BufferSpec { def parseOrDefault( s: String, - default: BufferSpec = BufferSpec.default + default: BufferSpec = BufferSpec.default, ): BufferSpec = if (s == null) default else parse(s) - val shortTypeHints = ShortTypeHints(List( + val shortTypeHints = ShortTypeHints( + List( classOf[BlockBufferSpec], classOf[LZ4BlockBufferSpec], classOf[LZ4HCBlockBufferSpec], @@ -75,8 +72,10 @@ object BufferSpec { classOf[BufferSpec], classOf[LEB128BufferSpec], classOf[BlockingBufferSpec], - classOf[StreamBufferSpec] - ), typeHintFieldName = "name") + classOf[StreamBufferSpec], + ), + typeHintFieldName = "name", + ) } trait BufferSpec extends Spec { @@ -90,9 +89,11 @@ trait BufferSpec extends Spec { } final case class LEB128BufferSpec(child: BufferSpec) extends BufferSpec { - def buildInputBuffer(in: InputStream): InputBuffer = new LEB128InputBuffer(child.buildInputBuffer(in)) + def buildInputBuffer(in: InputStream): InputBuffer = + new LEB128InputBuffer(child.buildInputBuffer(in)) - def buildOutputBuffer(out: OutputStream): OutputBuffer = new LEB128OutputBuffer(child.buildOutputBuffer(out)) + def buildOutputBuffer(out: OutputStream): OutputBuffer = + new LEB128OutputBuffer(child.buildOutputBuffer(out)) def buildCodeInputBuffer(in: Code[InputStream]): Code[InputBuffer] = Code.newInstance[LEB128InputBuffer, InputBuffer](child.buildCodeInputBuffer(in)) @@ -104,15 +105,23 @@ final case class LEB128BufferSpec(child: BufferSpec) extends BufferSpec { final case class BlockingBufferSpec(blockSize: Int, child: BlockBufferSpec) extends BufferSpec { require(blockSize <= (1 << 16)) - def buildInputBuffer(in: InputStream): InputBuffer = new 
BlockingInputBuffer(blockSize, child.buildInputBuffer(in)) + def buildInputBuffer(in: InputStream): InputBuffer = + new BlockingInputBuffer(blockSize, child.buildInputBuffer(in)) - def buildOutputBuffer(out: OutputStream): OutputBuffer = new BlockingOutputBuffer(blockSize, child.buildOutputBuffer(out)) + def buildOutputBuffer(out: OutputStream): OutputBuffer = + new BlockingOutputBuffer(blockSize, child.buildOutputBuffer(out)) def buildCodeInputBuffer(in: Code[InputStream]): Code[InputBuffer] = - Code.newInstance[BlockingInputBuffer, Int, InputBlockBuffer](blockSize, child.buildCodeInputBuffer(in)) + Code.newInstance[BlockingInputBuffer, Int, InputBlockBuffer]( + blockSize, + child.buildCodeInputBuffer(in), + ) def buildCodeOutputBuffer(out: Code[OutputStream]): Code[OutputBuffer] = - Code.newInstance[BlockingOutputBuffer, Int, OutputBlockBuffer](blockSize, child.buildCodeOutputBuffer(out)) + Code.newInstance[BlockingOutputBuffer, Int, OutputBlockBuffer]( + blockSize, + child.buildCodeOutputBuffer(out), + ) } trait BlockBufferSpec extends Spec { @@ -138,15 +147,25 @@ abstract class LZ4BlockBufferSpecCommon extends BlockBufferSpec { def child: BlockBufferSpec - def buildInputBuffer(in: InputStream): InputBlockBuffer = new LZ4InputBlockBuffer(lz4, blockSize, child.buildInputBuffer(in)) + def buildInputBuffer(in: InputStream): InputBlockBuffer = + new LZ4InputBlockBuffer(lz4, blockSize, child.buildInputBuffer(in)) - def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = new LZ4OutputBlockBuffer(lz4, blockSize, child.buildOutputBuffer(out)) + def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = + new LZ4OutputBlockBuffer(lz4, blockSize, child.buildOutputBuffer(out)) def buildCodeInputBuffer(in: Code[InputStream]): Code[InputBlockBuffer] = - Code.newInstance[LZ4InputBlockBuffer, LZ4, Int, InputBlockBuffer](stagedlz4, blockSize, child.buildCodeInputBuffer(in)) + Code.newInstance[LZ4InputBlockBuffer, LZ4, Int, InputBlockBuffer]( + stagedlz4, + blockSize, + child.buildCodeInputBuffer(in), + ) def buildCodeOutputBuffer(out: Code[OutputStream]): Code[OutputBlockBuffer] = - Code.newInstance[LZ4OutputBlockBuffer, LZ4, Int, OutputBlockBuffer](stagedlz4, blockSize, child.buildCodeOutputBuffer(out)) + Code.newInstance[LZ4OutputBlockBuffer, LZ4, Int, OutputBlockBuffer]( + stagedlz4, + blockSize, + child.buildCodeOutputBuffer(out), + ) } final case class LZ4HCBlockBufferSpec(blockSize: Int, child: BlockBufferSpec) @@ -163,8 +182,12 @@ final case class LZ4FastBlockBufferSpec(blockSize: Int, child: BlockBufferSpec) def typeName = "LZ4FastBlockBufferSpec" } -final case class LZ4SizeBasedBlockBufferSpec(compressorType: String, blockSize: Int, minCompressionSize: Int, child: BlockBufferSpec) - extends BlockBufferSpec { +final case class LZ4SizeBasedBlockBufferSpec( + compressorType: String, + blockSize: Int, + minCompressionSize: Int, + child: BlockBufferSpec, +) extends BlockBufferSpec { def lz4: LZ4 = compressorType match { case "hc" => LZ4.hc case "fast" => LZ4.fast @@ -173,43 +196,81 @@ final case class LZ4SizeBasedBlockBufferSpec(compressorType: String, blockSize: def stagedlz4: Code[LZ4] = Code.invokeScalaObject0[LZ4](LZ4.getClass, "fast") def typeName = "LZ4SizeBasedBlockBufferSpec" - def buildInputBuffer(in: InputStream): InputBlockBuffer = new LZ4SizeBasedCompressingInputBlockBuffer(lz4, blockSize, child.buildInputBuffer(in)) + def buildInputBuffer(in: InputStream): InputBlockBuffer = + new LZ4SizeBasedCompressingInputBlockBuffer(lz4, blockSize, child.buildInputBuffer(in)) - def 
buildOutputBuffer(out: OutputStream): OutputBlockBuffer = new LZ4SizeBasedCompressingOutputBlockBuffer(lz4, blockSize, minCompressionSize, child.buildOutputBuffer(out)) + def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = + new LZ4SizeBasedCompressingOutputBlockBuffer( + lz4, + blockSize, + minCompressionSize, + child.buildOutputBuffer(out), + ) def buildCodeInputBuffer(in: Code[InputStream]): Code[InputBlockBuffer] = - Code.newInstance[LZ4SizeBasedCompressingInputBlockBuffer, LZ4, Int, InputBlockBuffer](stagedlz4, blockSize, child.buildCodeInputBuffer(in)) + Code.newInstance[LZ4SizeBasedCompressingInputBlockBuffer, LZ4, Int, InputBlockBuffer]( + stagedlz4, + blockSize, + child.buildCodeInputBuffer(in), + ) def buildCodeOutputBuffer(out: Code[OutputStream]): Code[OutputBlockBuffer] = - Code.newInstance[LZ4SizeBasedCompressingOutputBlockBuffer, LZ4, Int, Int, OutputBlockBuffer](stagedlz4, blockSize, minCompressionSize, child.buildCodeOutputBuffer(out)) + Code.newInstance[LZ4SizeBasedCompressingOutputBlockBuffer, LZ4, Int, Int, OutputBlockBuffer]( + stagedlz4, + blockSize, + minCompressionSize, + child.buildCodeOutputBuffer(out), + ) } -final case class ZstdBlockBufferSpec(blockSize: Int, child: BlockBufferSpec) extends BlockBufferSpec { +final case class ZstdBlockBufferSpec(blockSize: Int, child: BlockBufferSpec) + extends BlockBufferSpec { require(blockSize <= (1 << 16)) - def buildInputBuffer(in: InputStream): InputBlockBuffer = new ZstdInputBlockBuffer(blockSize, child.buildInputBuffer(in)) + def buildInputBuffer(in: InputStream): InputBlockBuffer = + new ZstdInputBlockBuffer(blockSize, child.buildInputBuffer(in)) - def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = new ZstdOutputBlockBuffer(blockSize, child.buildOutputBuffer(out)) + def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = + new ZstdOutputBlockBuffer(blockSize, child.buildOutputBuffer(out)) def buildCodeInputBuffer(in: Code[InputStream]): Code[InputBlockBuffer] = - Code.newInstance[ZstdInputBlockBuffer, Int, InputBlockBuffer](blockSize, child.buildCodeInputBuffer(in)) + Code.newInstance[ZstdInputBlockBuffer, Int, InputBlockBuffer]( + blockSize, + child.buildCodeInputBuffer(in), + ) def buildCodeOutputBuffer(out: Code[OutputStream]): Code[OutputBlockBuffer] = - Code.newInstance[ZstdOutputBlockBuffer, Int, OutputBlockBuffer](blockSize, child.buildCodeOutputBuffer(out)) + Code.newInstance[ZstdOutputBlockBuffer, Int, OutputBlockBuffer]( + blockSize, + child.buildCodeOutputBuffer(out), + ) } -final case class ZstdSizedBasedBlockBufferSpec(blockSize: Int, minCompressionSize: Int, child: BlockBufferSpec) extends BlockBufferSpec { +final case class ZstdSizedBasedBlockBufferSpec( + blockSize: Int, + minCompressionSize: Int, + child: BlockBufferSpec, +) extends BlockBufferSpec { require(blockSize <= (1 << 16)) - def buildInputBuffer(in: InputStream): InputBlockBuffer = new ZstdSizedBasedInputBlockBuffer(blockSize, child.buildInputBuffer(in)) + def buildInputBuffer(in: InputStream): InputBlockBuffer = + new ZstdSizedBasedInputBlockBuffer(blockSize, child.buildInputBuffer(in)) - def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = new ZstdSizedBasedOutputBlockBuffer(blockSize, minCompressionSize, child.buildOutputBuffer(out)) + def buildOutputBuffer(out: OutputStream): OutputBlockBuffer = + new ZstdSizedBasedOutputBlockBuffer(blockSize, minCompressionSize, child.buildOutputBuffer(out)) def buildCodeInputBuffer(in: Code[InputStream]): Code[InputBlockBuffer] = - 
Code.newInstance[ZstdSizedBasedInputBlockBuffer, Int, InputBlockBuffer](blockSize, child.buildCodeInputBuffer(in)) + Code.newInstance[ZstdSizedBasedInputBlockBuffer, Int, InputBlockBuffer]( + blockSize, + child.buildCodeInputBuffer(in), + ) def buildCodeOutputBuffer(out: Code[OutputStream]): Code[OutputBlockBuffer] = - Code.newInstance[ZstdSizedBasedOutputBlockBuffer, Int, Int, OutputBlockBuffer](blockSize, minCompressionSize, child.buildCodeOutputBuffer(out)) + Code.newInstance[ZstdSizedBasedOutputBlockBuffer, Int, Int, OutputBlockBuffer]( + blockSize, + minCompressionSize, + child.buildCodeOutputBuffer(out), + ) } object StreamBlockBufferSpec { diff --git a/hail/src/main/scala/is/hail/io/ByteArrayReader.scala b/hail/src/main/scala/is/hail/io/ByteArrayReader.scala index 039cc1eb855..eba31a437c3 100644 --- a/hail/src/main/scala/is/hail/io/ByteArrayReader.scala +++ b/hail/src/main/scala/is/hail/io/ByteArrayReader.scala @@ -24,9 +24,8 @@ class ByteArrayReader(val arr: Array[Byte]) extends AbstractBinaryReader { } } - def seek(pos: Int) { + def seek(pos: Int): Unit = position = pos - } def skipBytes(bytes: Long): Long = { require(bytes < Integer.MAX_VALUE) @@ -38,4 +37,4 @@ class ByteArrayReader(val arr: Array[Byte]) extends AbstractBinaryReader { } def hasNext(): Boolean = position < length -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/io/CodecSpec.scala b/hail/src/main/scala/is/hail/io/CodecSpec.scala index 671959eaf83..5f1011ddbc9 100644 --- a/hail/src/main/scala/is/hail/io/CodecSpec.scala +++ b/hail/src/main/scala/is/hail/io/CodecSpec.scala @@ -1,15 +1,17 @@ package is.hail.io -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, OutputStream} import is.hail.annotations.{Region, RegionValue} -import is.hail.asm4s.{Code, HailClassLoader, theHailClassLoaderForSparkWorkers} +import is.hail.asm4s.{theHailClassLoaderForSparkWorkers, Code, HailClassLoader} import is.hail.backend.ExecuteContext +import is.hail.sparkextras.ContextRDD import is.hail.types.encoded.EType import is.hail.types.physical.PType import is.hail.types.virtual.Type -import is.hail.sparkextras.ContextRDD +import is.hail.utils.{using, ArrayOfByteArrayOutputStream} import is.hail.utils.prettyPrint.ArrayOfByteArrayInputStream -import is.hail.utils.{ArrayOfByteArrayOutputStream, using} + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, OutputStream} + import org.apache.spark.rdd.RDD trait AbstractTypedCodecSpec extends Spec { @@ -31,7 +33,8 @@ trait AbstractTypedCodecSpec extends Spec { def decodedPType(): PType = encodedType.decodedPType(encodedVirtualType) - def buildDecoder(ctx: ExecuteContext, requestedType: Type): (PType, (InputStream, HailClassLoader) => Decoder) + def buildDecoder(ctx: ExecuteContext, requestedType: Type) + : (PType, (InputStream, HailClassLoader) => Decoder) def encode(ctx: ExecuteContext, t: PType, offset: Long): Array[Byte] = { val baos = new ByteArrayOutputStream() @@ -39,9 +42,8 @@ trait AbstractTypedCodecSpec extends Spec { baos.toByteArray } - def encode(ctx: ExecuteContext, t: PType, offset: Long, os: OutputStream): Unit = { + def encode(ctx: ExecuteContext, t: PType, offset: Long, os: OutputStream): Unit = using(buildEncoder(ctx, t)(os, ctx.theHailClassLoader))(_.writeRegionValue(offset)) - } def encodeArrays(ctx: ExecuteContext, t: PType, offset: Long): Array[Array[Byte]] = { val baos = new ArrayOfByteArrayOutputStream() @@ -49,13 +51,19 @@ trait AbstractTypedCodecSpec extends Spec { baos.toByteArrays() } - def 
decode(ctx: ExecuteContext, requestedType: Type, bytes: Array[Byte], region: Region): (PType, Long) = { + def decode(ctx: ExecuteContext, requestedType: Type, bytes: Array[Byte], region: Region) + : (PType, Long) = { val bais = new ByteArrayInputStream(bytes) val (pt, dec) = buildDecoder(ctx, requestedType) (pt, dec(bais, ctx.theHailClassLoader).readRegionValue(region)) } - def decodeArrays(ctx: ExecuteContext, requestedType: Type, bytes: Array[Array[Byte]], region: Region): (PType, Long) = { + def decodeArrays( + ctx: ExecuteContext, + requestedType: Type, + bytes: Array[Array[Byte]], + region: Region, + ): (PType, Long) = { val bais = new ArrayOfByteArrayInputStream(bytes) val (pt, dec) = buildDecoder(ctx, requestedType) (pt, dec(bais, ctx.theHailClassLoader).readRegionValue(region)) @@ -66,11 +74,15 @@ trait AbstractTypedCodecSpec extends Spec { def buildCodeOutputBuffer(os: Code[OutputStream]): Code[OutputBuffer] // FIXME: is there a better place for this to live? - def decodeRDD(ctx: ExecuteContext, requestedType: Type, bytes: RDD[Array[Byte]]): (PType, ContextRDD[Long]) = { + def decodeRDD(ctx: ExecuteContext, requestedType: Type, bytes: RDD[Array[Byte]]) + : (PType, ContextRDD[Long]) = { val (pt, dec) = buildDecoder(ctx, requestedType) - (pt, ContextRDD.weaken(bytes).cmapPartitions { (ctx, it) => - RegionValue.fromBytes(theHailClassLoaderForSparkWorkers, dec, ctx.region, it) - }) + ( + pt, + ContextRDD.weaken(bytes).cmapPartitions { (ctx, it) => + RegionValue.fromBytes(theHailClassLoaderForSparkWorkers, dec, ctx.region, it) + }, + ) } override def toString: String = super[Spec].toString diff --git a/hail/src/main/scala/is/hail/io/Decoder.scala b/hail/src/main/scala/is/hail/io/Decoder.scala index db0e41bb0ec..25a5d9452fa 100644 --- a/hail/src/main/scala/is/hail/io/Decoder.scala +++ b/hail/src/main/scala/is/hail/io/Decoder.scala @@ -1,14 +1,15 @@ package is.hail.io -import java.io._ import is.hail.annotations.Region import is.hail.asm4s._ -import is.hail.utils.RestartableByteArrayInputStream import is.hail.types.encoded.DecoderAsmFunction import is.hail.types.physical.PType +import is.hail.utils.RestartableByteArrayInputStream + +import java.io._ trait Decoder extends Closeable { - def close() + def close(): Unit def ptype: PType @@ -19,24 +20,28 @@ trait Decoder extends Closeable { def seek(offset: Long): Unit } -final class CompiledDecoder(in: InputBuffer, val ptype: PType, theHailClassLoader: HailClassLoader, f: (HailClassLoader) => DecoderAsmFunction) extends Decoder { - def close() { +final class CompiledDecoder( + in: InputBuffer, + val ptype: PType, + theHailClassLoader: HailClassLoader, + f: (HailClassLoader) => DecoderAsmFunction, +) extends Decoder { + def close(): Unit = in.close() - } def readByte(): Byte = in.readByte() private[this] val compiled = f(theHailClassLoader) - def readRegionValue(r: Region): Long = { + + def readRegionValue(r: Region): Long = compiled(r, in) - } def seek(offset: Long): Unit = in.seek(offset) } final class ByteArrayDecoder( theHailClassLoader: HailClassLoader, - makeDec: (InputStream, HailClassLoader) => Decoder + makeDec: (InputStream, HailClassLoader) => Decoder, ) extends Closeable { private[this] val bais = new RestartableByteArrayInputStream() private[this] val dec = makeDec(bais, theHailClassLoader) @@ -53,7 +58,6 @@ final class ByteArrayDecoder( def readValue(region: Region): Long = dec.readRegionValue(region) - def set(bytes: Array[Byte]) { + def set(bytes: Array[Byte]): Unit = bais.restart(bytes) - } } diff --git 
a/hail/src/main/scala/is/hail/io/DoubleInputBuffer.scala b/hail/src/main/scala/is/hail/io/DoubleInputBuffer.scala index e1b77ce4302..48f3d9c200b 100644 --- a/hail/src/main/scala/is/hail/io/DoubleInputBuffer.scala +++ b/hail/src/main/scala/is/hail/io/DoubleInputBuffer.scala @@ -1,22 +1,21 @@ package is.hail.io -import java.io.{Closeable, InputStream, OutputStream} - import is.hail.annotations.Memory import is.hail.utils._ +import java.io.{Closeable, InputStream, OutputStream} + final class DoubleInputBuffer(in: InputStream, bufSize: Int) extends Closeable { private val buf = new Array[Byte](bufSize) private var end: Int = 0 private var off: Int = 0 - def close() { + def close(): Unit = in.close() - } def readDoubles(to: Array[Double]): Unit = readDoubles(to, 0, to.length) - def readDoubles(to: Array[Double], toOff0: Int, n0: Int) { + def readDoubles(to: Array[Double], toOff0: Int, n0: Int): Unit = { assert(toOff0 >= 0) assert(n0 >= 0) assert(toOff0 <= to.length - n0) @@ -45,18 +44,17 @@ final class DoubleOutputBuffer(out: OutputStream, bufSize: Int) extends Closeabl private val buf: Array[Byte] = new Array[Byte](bufSize) private var off: Int = 0 - def close() { + def close(): Unit = { flush() out.close() } - def flush() { + def flush(): Unit = out.write(buf, 0, off) - } def writeDoubles(from: Array[Double]): Unit = writeDoubles(from, 0, from.length) - def writeDoubles(from: Array[Double], fromOff0: Int, n0: Int) { + def writeDoubles(from: Array[Double], fromOff0: Int, n0: Int): Unit = { assert(n0 >= 0) assert(fromOff0 >= 0) assert(fromOff0 <= from.length - n0) @@ -75,4 +73,4 @@ final class DoubleOutputBuffer(out: OutputStream, bufSize: Int) extends Closeabl Memory.memcpy(buf, off, from, fromOff, n) off += (n.toInt << 3) } -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/io/ElasticsearchConnector.scala b/hail/src/main/scala/is/hail/io/ElasticsearchConnector.scala index eeceda4e1bd..60c39657b4b 100644 --- a/hail/src/main/scala/is/hail/io/ElasticsearchConnector.scala +++ b/hail/src/main/scala/is/hail/io/ElasticsearchConnector.scala @@ -1,11 +1,11 @@ package is.hail.io -import org.apache.spark -import org.elasticsearch.spark.sql._ - import scala.collection.JavaConverters._ import scala.collection.Map +import org.apache.spark +import org.elasticsearch.spark.sql._ + object ElasticsearchConnector { def export( @@ -16,14 +16,30 @@ object ElasticsearchConnector { indexType: String, blockSize: Int, config: java.util.HashMap[String, String], - verbose: Boolean) { - export(df, host, port, index, indexType, blockSize, - Option(config).map(_.asScala.toMap).getOrElse(Map.empty[String, String]), verbose) + verbose: Boolean, + ): Unit = { + export( + df, + host, + port, + index, + indexType, + blockSize, + Option(config).map(_.asScala.toMap).getOrElse(Map.empty[String, String]), + verbose, + ) } - def export(df: spark.sql.DataFrame, host: String = "localhost", port: Int = 9200, - index: String, indexType: String, blockSize: Int = 1000, - config: Map[String, String], verbose: Boolean = true) { + def export( + df: spark.sql.DataFrame, + host: String = "localhost", + port: Int = 9200, + index: String, + indexType: String, + blockSize: Int = 1000, + config: Map[String, String], + verbose: Boolean = true, + ): Unit = { // config docs: https://www.elastic.co/guide/en/elasticsearch/hadoop/master/configuration.html @@ -31,7 +47,8 @@ object ElasticsearchConnector { "es.nodes" -> host, "es.port" -> port.toString, "es.batch.size.entries" -> blockSize.toString, - "es.index.auto.create" -> 
"true") + "es.index.auto.create" -> "true", + ) val mergedConfig = if (config == null) defaultConfig @@ -39,8 +56,8 @@ object ElasticsearchConnector { defaultConfig ++ config if (verbose) - println(s"Config ${ mergedConfig }") + println(s"Config $mergedConfig") - df.saveToEs(s"${ index }/${ indexType }", mergedConfig) + df.saveToEs(s"$index/$indexType", mergedConfig) } } diff --git a/hail/src/main/scala/is/hail/io/Encoder.scala b/hail/src/main/scala/is/hail/io/Encoder.scala index 9141cc5ec9e..d76a581cf9d 100644 --- a/hail/src/main/scala/is/hail/io/Encoder.scala +++ b/hail/src/main/scala/is/hail/io/Encoder.scala @@ -1,11 +1,11 @@ package is.hail.io -import java.io._ - import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.types.encoded.EncoderAsmFunction +import java.io._ + trait Encoder extends Closeable { def flush(): Unit @@ -18,30 +18,31 @@ trait Encoder extends Closeable { def indexOffset(): Long } -final class CompiledEncoder(out: OutputBuffer, theHailClassLoader: HailClassLoader, f: (HailClassLoader) => EncoderAsmFunction) extends Encoder { - def flush() { +final class CompiledEncoder( + out: OutputBuffer, + theHailClassLoader: HailClassLoader, + f: (HailClassLoader) => EncoderAsmFunction, +) extends Encoder { + def flush(): Unit = out.flush() - } - def close() { + def close(): Unit = out.close() - } private[this] val compiled = f(theHailClassLoader) - def writeRegionValue(offset: Long) { + + def writeRegionValue(offset: Long): Unit = compiled(offset, out) - } - def writeByte(b: Byte) { + def writeByte(b: Byte): Unit = out.writeByte(b) - } def indexOffset(): Long = out.indexOffset() } final class ByteArrayEncoder( theHailClassLoader: HailClassLoader, - makeEnc: (OutputStream, HailClassLoader) => Encoder + makeEnc: (OutputStream, HailClassLoader) => Encoder, ) extends Closeable { private[this] val baos = new ByteArrayOutputStream() private[this] val enc = makeEnc(baos, theHailClassLoader) diff --git a/hail/src/main/scala/is/hail/io/FileWriteMetadata.scala b/hail/src/main/scala/is/hail/io/FileWriteMetadata.scala index 55d2d2eb3dc..9c5cb92a06a 100644 --- a/hail/src/main/scala/is/hail/io/FileWriteMetadata.scala +++ b/hail/src/main/scala/is/hail/io/FileWriteMetadata.scala @@ -1,6 +1,5 @@ package is.hail.io - case class FileWriteMetadata(path: String, rowsWritten: Long, bytesWritten: Long) { def render(): String = s"$path\t$rowsWritten\t$bytesWritten" } diff --git a/hail/src/main/scala/is/hail/io/HadoopFSDataBinaryReader.scala b/hail/src/main/scala/is/hail/io/HadoopFSDataBinaryReader.scala index 26118260cf8..0a556280f1f 100644 --- a/hail/src/main/scala/is/hail/io/HadoopFSDataBinaryReader.scala +++ b/hail/src/main/scala/is/hail/io/HadoopFSDataBinaryReader.scala @@ -2,11 +2,13 @@ package is.hail.io import is.hail.io.fs.SeekableDataInputStream -class HadoopFSDataBinaryReader(fis: SeekableDataInputStream) extends AbstractBinaryReader with AutoCloseable { +class HadoopFSDataBinaryReader(fis: SeekableDataInputStream) + extends AbstractBinaryReader with AutoCloseable { override def read(): Int = fis.read() - override def read(byteArray: Array[Byte], hasRead: Int, toRead: Int): Int = fis.read(byteArray, hasRead, toRead) + override def read(byteArray: Array[Byte], hasRead: Int, toRead: Int): Int = + fis.read(byteArray, hasRead, toRead) def close(): Unit = fis.close() diff --git a/hail/src/main/scala/is/hail/io/IndexBTree.scala b/hail/src/main/scala/is/hail/io/IndexBTree.scala index 455a83f853e..11cf0d4fece 100644 --- a/hail/src/main/scala/is/hail/io/IndexBTree.scala +++ 
b/hail/src/main/scala/is/hail/io/IndexBTree.scala @@ -1,13 +1,13 @@ package is.hail.io -import java.io.{Closeable, DataOutputStream} -import java.util.Arrays - -import is.hail.utils._ import is.hail.io.fs.FS +import is.hail.utils._ import scala.collection.mutable +import java.io.{Closeable, DataOutputStream} +import java.util.Arrays + object IndexBTree { private[io] def calcDepth(internalAndExternalNodeCount: Long, branchingFactor: Int): Int = { var depth = 1 @@ -21,12 +21,12 @@ object IndexBTree { } private[io] def calcDepth(arr: Array[Long], branchingFactor: Int) = - //max necessary for array of length 1 becomes depth=0 + // max necessary for array of length 1 becomes depth=0 math.max(1, (math.log10(arr.length) / math.log10(branchingFactor)).ceil.toInt) private[io] def btreeLayers( arr: Array[Long], - branchingFactor: Int = 1024 + branchingFactor: Int = 1024, ): Array[Array[Long]] = { require(arr.length > 0) @@ -49,7 +49,7 @@ object IndexBTree { val paddingRequired = if (danglingElements == 0) 0 else branchingFactor - danglingElements - val padding = (0 until paddingRequired).map { _ => -1L } + val padding = (0 until paddingRequired).map(_ => -1L) // Write last layer layers.append(arr ++ padding) @@ -58,56 +58,66 @@ object IndexBTree { private[io] def btreeBytes( arr: Array[Long], - branchingFactor: Int = 1024 + branchingFactor: Int = 1024, ): Array[Byte] = btreeLayers(arr, branchingFactor) .flatten - .flatMap(l => Array[Byte]( - (l >>> 56).toByte, - (l >>> 48).toByte, - (l >>> 40).toByte, - (l >>> 32).toByte, - (l >>> 24).toByte, - (l >>> 16).toByte, - (l >>> 8).toByte, - (l >>> 0).toByte)) + .flatMap(l => + Array[Byte]( + (l >>> 56).toByte, + (l >>> 48).toByte, + (l >>> 40).toByte, + (l >>> 32).toByte, + (l >>> 24).toByte, + (l >>> 16).toByte, + (l >>> 8).toByte, + (l >>> 0).toByte, + ) + ) .toArray def write( arr: Array[Long], fileName: String, fs: FS, - branchingFactor: Int = 1024 + branchingFactor: Int = 1024, ): Unit = using(new DataOutputStream(fs.create(fileName))) { w => w.write(btreeBytes(arr, branchingFactor)) } def toString( arr: Array[Long], - branchingFactor: Int = 1024 + branchingFactor: Int = 1024, ): String = - btreeLayers(arr, branchingFactor).map(_.mkString("[", " ", "]")).mkString("(BTREE\n", "\n", "\n)") + btreeLayers(arr, branchingFactor).map(_.mkString("[", " ", "]")).mkString( + "(BTREE\n", + "\n", + "\n)", + ) } class IndexBTree(indexFileName: String, fs: FS, branchingFactor: Int = 1024) extends Closeable { val maxDepth = calcDepth() - private val is = try { - fs.openNoCompression(indexFileName) - } catch { - case e: Exception => fatal(s"Could not find a BGEN .idx file at $indexFileName. Try running HailContext.index_bgen().", e) - } + + private val is = + try + fs.openNoCompression(indexFileName) + catch { + case e: Exception => fatal( + s"Could not find a BGEN .idx file at $indexFileName. 
Try running HailContext.index_bgen().", + e, + ) + } def close(): Unit = is.close() def calcDepth(): Int = IndexBTree.calcDepth(fs.getFileSize(indexFileName) / 8, branchingFactor) - private def getOffset(depth: Int): Long = { + private def getOffset(depth: Int): Long = (1 until depth).map(math.pow(branchingFactor, _).toLong * 8).sum - } - private def getOffset(depth: Int, blockIndex: Long): Long = { + private def getOffset(depth: Int, blockIndex: Long): Long = getOffset(depth) + blockIndex * 8 * branchingFactor - } private def traverseTree(query: Long, startIndex: Long, currentDepth: Int): (Long, Long) = { @@ -115,7 +125,9 @@ class IndexBTree(indexFileName: String, fs: FS, branchingFactor: Int = 1024) ext def read(prevValue: Long, prevPos: Long): Long = { val currValue = is.readLong() - if (currentDepth != maxDepth && query >= prevValue && (query < currValue || currValue == -1L)) + if ( + currentDepth != maxDepth && query >= prevValue && (query < currValue || currValue == -1L) + ) prevPos else if (currentDepth == maxDepth && query <= currValue || currValue == -1L) currValue @@ -170,7 +182,7 @@ class IndexBTree(indexFileName: String, fs: FS, branchingFactor: Int = 1024) ext def queryIndex(query: Long): Option[Long] = { require(query >= 0) - val (index, result) = traverseTree(query, 0L, 1) + val (_, result) = traverseTree(query, 0L, 1) if (result != -1L) Option(result) @@ -190,21 +202,18 @@ class IndexBTree(indexFileName: String, fs: FS, branchingFactor: Int = 1024) ext } } -/** - * A BTree file of N elements is a sequence of layers containing 8-byte values. +/** A BTree file of N elements is a sequence of layers containing 8-byte values. * - * The size of layer i is {@code math.pow(branchingFactor, i + 1).toInt}. The - * last layer is the first layer whose size is large enough to contain N - * elements. The final layer contains all N elements, in their given order, - * followed by {@code branchingFactor - N} {@code -1}'s. - * - **/ + * The size of layer i is {@code math.pow(branchingFactor, i + 1).toInt} . The last layer is the + * first layer whose size is large enough to contain N elements. The final layer contains all N + * elements, in their given order, followed by {@code branchingFactor - N} {@code -1} 's. + */ // IndexBTree maps from a value to the next largest value, this treats the BTree // like an on-disk array and looks up values by index class OnDiskBTreeIndexToValue( path: String, fs: FS, - branchingFactor: Int = 1024 + branchingFactor: Int = 1024, ) extends AutoCloseable { private[this] def numLayers(size: Long): Int = IndexBTree.calcDepth(size, branchingFactor) @@ -221,13 +230,15 @@ class OnDiskBTreeIndexToValue( private[this] val layers = numLayers(fs.getFileSize(path) / 8) private[this] val junk = leadingElements(layers - 1) - private[this] var is = try { - log.info("reading index file: " + path) - fs.openNoCompression(path) - } catch { - case e: Exception => - fatal(s"Could not find a BGEN .idx file at $path. Try running HailContext.index_bgen().", e) - } + + private[this] var is = + try { + log.info("reading index file: " + path) + fs.openNoCompression(path) + } catch { + case e: Exception => + fatal(s"Could not find a BGEN .idx file at $path. 
Try running HailContext.index_bgen().", e) + } // WARNING: mutatively sorts the provided array def positionOfVariants(indices: Array[Int]): Array[Long] = { diff --git a/hail/src/main/scala/is/hail/io/IndexedBinaryBlockReader.scala b/hail/src/main/scala/is/hail/io/IndexedBinaryBlockReader.scala index 3501f89ba4d..beef7d2c48d 100644 --- a/hail/src/main/scala/is/hail/io/IndexedBinaryBlockReader.scala +++ b/hail/src/main/scala/is/hail/io/IndexedBinaryBlockReader.scala @@ -2,6 +2,7 @@ package is.hail.io import is.hail.annotations.RegionValueBuilder import is.hail.io.fs.{HadoopFS, WrappedSeekableDataInputStream} + import org.apache.commons.logging.{Log, LogFactory} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} @@ -12,21 +13,19 @@ abstract class KeySerializedValueRecord[K] extends Serializable { var input: Array[Byte] = _ var key: K = _ - def setSerializedValue(arr: Array[Byte]) { + def setSerializedValue(arr: Array[Byte]): Unit = this.input = arr - } def getValue(rvb: RegionValueBuilder, includeGT: Boolean): Unit - def setKey(k: K) { + def setKey(k: K): Unit = this.key = k - } def getKey: K = key } abstract class IndexedBinaryBlockReader[T](job: Configuration, split: FileSplit) - extends RecordReader[LongWritable, T] { + extends RecordReader[LongWritable, T] { val LOG: Log = LogFactory.getLog(classOf[IndexedBinaryBlockReader[T]].getName) val partitionStart: Long = split.getStart @@ -40,7 +39,9 @@ abstract class IndexedBinaryBlockReader[T](job: Configuration, split: FileSplit) val is = fs.open(file) new HadoopFSDataBinaryReader( new WrappedSeekableDataInputStream( - HadoopFS.toSeekableInputStream(is))) + HadoopFS.toSeekableInputStream(is) + ) + ) } def createKey(): LongWritable = new LongWritable() @@ -49,12 +50,11 @@ abstract class IndexedBinaryBlockReader[T](job: Configuration, split: FileSplit) def getPos: Long = pos - def getProgress: Float = { + def getProgress: Float = if (partitionStart == end) 0.0f else Math.min(1.0f, (pos - partitionStart) / (end - partitionStart).toFloat) - } def close() = bfis.close() diff --git a/hail/src/main/scala/is/hail/io/IndexedBinaryInputFormat.scala b/hail/src/main/scala/is/hail/io/IndexedBinaryInputFormat.scala index 5c51bde1788..c07d6608111 100644 --- a/hail/src/main/scala/is/hail/io/IndexedBinaryInputFormat.scala +++ b/hail/src/main/scala/is/hail/io/IndexedBinaryInputFormat.scala @@ -3,8 +3,8 @@ package is.hail.io import org.apache.hadoop.io.LongWritable import org.apache.hadoop.mapred._ - abstract class IndexedBinaryInputFormat[T] extends FileInputFormat[LongWritable, T] { - def getRecordReader(split: InputSplit, job: JobConf, reporter: Reporter): RecordReader[LongWritable, T] + def getRecordReader(split: InputSplit, job: JobConf, reporter: Reporter) + : RecordReader[LongWritable, T] } diff --git a/hail/src/main/scala/is/hail/io/InputBuffers.scala b/hail/src/main/scala/is/hail/io/InputBuffers.scala index ad8b00811b2..d04c032f9d9 100644 --- a/hail/src/main/scala/is/hail/io/InputBuffers.scala +++ b/hail/src/main/scala/is/hail/io/InputBuffers.scala @@ -1,14 +1,13 @@ package is.hail.io -import java.io._ -import java.util -import java.util.UUID -import java.util.function.Supplier - import is.hail.annotations.{Memory, Region} import is.hail.io.compress.LZ4 import is.hail.utils._ +import java.io._ +import java.util.UUID +import java.util.function.Supplier + import com.github.luben.zstd.{Zstd, ZstdDecompressCtx} trait InputBuffer extends Closeable { @@ -57,18 +56,12 @@ trait InputBuffer extends Closeable { def 
readDoubles(to: Array[Double]): Unit = readDoubles(to, 0, to.length) def readBoolean(): Boolean = readByte() != 0 - - def readUTF(): String = { - val n = readInt() - val a = readBytesArray(n) - new String(a, utfCharset) - } } trait InputBlockBuffer extends Spec with Closeable { def close(): Unit - def seek(offset: Long) + def seek(offset: Long): Unit def skipBytesReadRemainder(n0: Int, buf: Array[Byte]): Int = { var n = n0 @@ -100,9 +93,8 @@ final class StreamInputBuffer(in: InputStream) extends InputBuffer { Memory.loadByte(buff, 0) } - override def read(buf: Array[Byte], toOff: Int, n: Int): Unit = { + override def read(buf: Array[Byte], toOff: Int, n: Int): Unit = in.readFully(buf, toOff, n) - } def readInt(): Int = { in.readFully(buff, 0, 4) @@ -124,13 +116,11 @@ final class StreamInputBuffer(in: InputStream) extends InputBuffer { Memory.loadDouble(buff, 0) } - def readBytes(toRegion: Region, toOff: Long, n: Int): Unit = { + def readBytes(toRegion: Region, toOff: Long, n: Int): Unit = Region.storeBytes(toOff, readBytesArray(n)) - } - def readBytesArray(n: Int): Array[Byte] = { + def readBytesArray(n: Int): Array[Byte] = Array.tabulate(n)(_ => readByte()) - } def skipByte(): Unit = { val bytesRead = in.skip(1) @@ -172,7 +162,7 @@ final class StreamInputBuffer(in: InputStream) extends InputBuffer { } final class MemoryInputBuffer(mb: MemoryBuffer) extends InputBuffer { - def close() {} + def close(): Unit = {} def seek(offset: Long) = ??? @@ -189,7 +179,7 @@ final class MemoryInputBuffer(mb: MemoryBuffer) extends InputBuffer { def readBytes(toRegion: Region, toOff: Long, n: Int): Unit = mb.readBytes(toOff, n) def readBytesArray(n: Int): Array[Byte] = { - var arr = new Array[Byte](n) + val arr = new Array[Byte](n) mb.readBytesArray(arr, n) arr } @@ -210,15 +200,13 @@ final class MemoryInputBuffer(mb: MemoryBuffer) extends InputBuffer { } final class LEB128InputBuffer(in: InputBuffer) extends InputBuffer { - def close() { + def close(): Unit = in.close() - } def seek(offset: Long): Unit = in.seek(offset) - def readByte(): Byte = { + def readByte(): Byte = in.readByte() - } override def read(buf: Array[Byte], toOff: Int, n: Int) = in.read(buf, toOff, n) @@ -256,13 +244,13 @@ final class LEB128InputBuffer(in: InputBuffer) extends InputBuffer { def skipByte(): Unit = in.skipByte() - def skipInt() { + def skipInt(): Unit = { var b: Byte = readByte() while ((b & 0x80) != 0) b = readByte() } - def skipLong() { + def skipLong(): Unit = { var b: Byte = readByte() while ((b & 0x80) != 0) b = readByte() @@ -322,39 +310,31 @@ final class TracingInputBuffer( Memory.loadDouble(bytes, 0) } - def readBytes(toRegion: Region, toOff: Long, n: Int): Unit = { + def readBytes(toRegion: Region, toOff: Long, n: Int): Unit = Region.storeBytes(toOff, readBytesArray(n)) - } - def readBytesArray(n: Int): Array[Byte] = { + def readBytesArray(n: Int): Array[Byte] = Array.tabulate(n)(_ => readByte()) - } override def skipBoolean(): Unit = skipByte() - def skipByte(): Unit = { + def skipByte(): Unit = readBytesArray(1) - } - def skipInt(): Unit = { + def skipInt(): Unit = readBytesArray(4) - } - def skipLong(): Unit = { + def skipLong(): Unit = readBytesArray(8) - } - def skipFloat(): Unit = { + def skipFloat(): Unit = readBytesArray(4) - } - def skipDouble(): Unit = { + def skipDouble(): Unit = readBytesArray(8) - } - def skipBytes(n: Int): Unit = { + def skipBytes(n: Int): Unit = readBytesArray(n) - } def readDoubles(to: Array[Double], off: Int, n: Int): Unit = { var i = 0 @@ -368,11 +348,6 @@ final class 
TracingInputBuffer( override def readBoolean(): Boolean = readByte() != 0 - override def readUTF(): String = { - val s = in.readUTF() - logfile.write(s.getBytes(utfCharset)) - s - } } final class BlockingInputBuffer(blockSize: Int, in: InputBlockBuffer) extends InputBuffer { @@ -380,7 +355,7 @@ final class BlockingInputBuffer(blockSize: Int, in: InputBlockBuffer) extends In private[this] var end: Int = 0 private[this] var off: Int = 0 - private[this] def ensure(n: Int) { + private[this] def ensure(n: Int): Unit = { if (off == end) { end = in.readBlock(buf) off = 0 @@ -388,14 +363,13 @@ final class BlockingInputBuffer(blockSize: Int, in: InputBlockBuffer) extends In assert(off + n <= end) } - def close() { + def close(): Unit = in.close() - } def seek(offset: Long): Unit = { in.seek(offset) end = in.readBlock(buf) - off = (offset & 0xFFFF).asInstanceOf[Int] + off = (offset & 0xffff).asInstanceOf[Int] assert(off <= end) } @@ -434,7 +408,7 @@ final class BlockingInputBuffer(blockSize: Int, in: InputBlockBuffer) extends In d } - def readBytes(toRegion: Region, toOff0: Long, n0: Int) { + def readBytes(toRegion: Region, toOff0: Long, n0: Int): Unit = { assert(n0 >= 0) var toOff = toOff0 var n = n0 @@ -453,7 +427,7 @@ final class BlockingInputBuffer(blockSize: Int, in: InputBlockBuffer) extends In } } - override def read(arr: Array[Byte], toOff0: Int, n0: Int) { + override def read(arr: Array[Byte], toOff0: Int, n0: Int): Unit = { var toOff = toOff0; var n = n0 @@ -472,37 +446,37 @@ final class BlockingInputBuffer(blockSize: Int, in: InputBlockBuffer) extends In } def readBytesArray(n: Int): Array[Byte] = { - var arr = new Array[Byte](n) + val arr = new Array[Byte](n) read(arr, 0, n) arr } - def skipByte() { + def skipByte(): Unit = { ensure(1) off += 1 } - def skipInt() { + def skipInt(): Unit = { ensure(4) off += 4 } - def skipLong() { + def skipLong(): Unit = { ensure(8) off += 8 } - def skipFloat() { + def skipFloat(): Unit = { ensure(4) off += 4 } - def skipDouble() { + def skipDouble(): Unit = { ensure(8) off += 8 } - def skipBytes(n0: Int) { + def skipBytes(n0: Int): Unit = { var n = n0 if (off + n > end) { n -= (end - off) @@ -514,7 +488,7 @@ final class BlockingInputBuffer(blockSize: Int, in: InputBlockBuffer) extends In } } - def readDoubles(to: Array[Double], toOff0: Int, n0: Int) { + def readDoubles(to: Array[Double], toOff0: Int, n0: Int): Unit = { assert(toOff0 >= 0) assert(n0 >= 0) assert(toOff0 <= to.length - n0) @@ -539,9 +513,8 @@ final class BlockingInputBuffer(blockSize: Int, in: InputBlockBuffer) extends In final class StreamBlockInputBuffer(in: InputStream) extends InputBlockBuffer { private[this] val lenBuf = new Array[Byte](4) - def close() { + def close(): Unit = in.close() - } // this takes a virtual offset and will seek the underlying stream to offset >> 16 def seek(offset: Long): Unit = in.asInstanceOf[ByteTrackingInputStream].seek(offset >> 16) @@ -556,12 +529,12 @@ final class StreamBlockInputBuffer(in: InputStream) extends InputBlockBuffer { } } -final class LZ4InputBlockBuffer(lz4: LZ4, blockSize: Int, in: InputBlockBuffer) extends InputBlockBuffer { +final class LZ4InputBlockBuffer(lz4: LZ4, blockSize: Int, in: InputBlockBuffer) + extends InputBlockBuffer { private[this] val comp = new Array[Byte](4 + lz4.maxCompressedLength(blockSize)) - def close() { + def close(): Unit = in.close() - } def seek(offset: Long): Unit = in.seek(offset) @@ -598,13 +571,13 @@ final class LZ4InputBlockBuffer(lz4: LZ4, blockSize: Int, in: InputBlockBuffer) } } -final class 
LZ4SizeBasedCompressingInputBlockBuffer(lz4: LZ4, blockSize: Int, in: InputBlockBuffer) extends InputBlockBuffer { +final class LZ4SizeBasedCompressingInputBlockBuffer(lz4: LZ4, blockSize: Int, in: InputBlockBuffer) + extends InputBlockBuffer { private[this] val comp = new Array[Byte](8 + lz4.maxCompressedLength(blockSize)) private[this] var lim = 0 - def close() { + def close(): Unit = in.close() - } def seek(offset: Long): Unit = in.seek(offset) @@ -632,16 +605,17 @@ final class LZ4SizeBasedCompressingInputBlockBuffer(lz4: LZ4, blockSize: Int, in } object ZstdDecompressLib { - val instance = ThreadLocal.withInitial(new Supplier[ZstdDecompressCtx]() { def get: ZstdDecompressCtx = new ZstdDecompressCtx() }) + val instance = ThreadLocal.withInitial(new Supplier[ZstdDecompressCtx]() { + def get: ZstdDecompressCtx = new ZstdDecompressCtx() + }) } final class ZstdInputBlockBuffer(blockSize: Int, in: InputBlockBuffer) extends InputBlockBuffer { private[this] val zstd = ZstdDecompressLib.instance.get private[this] val comp = new Array[Byte](4 + Zstd.compressBound(blockSize).toInt) - def close(): Unit = { + def close(): Unit = in.close() - } def seek(offset: Long): Unit = in.seek(offset) @@ -658,13 +632,13 @@ final class ZstdInputBlockBuffer(blockSize: Int, in: InputBlockBuffer) extends I } } -final class ZstdSizedBasedInputBlockBuffer(blockSize: Int, in: InputBlockBuffer) extends InputBlockBuffer { +final class ZstdSizedBasedInputBlockBuffer(blockSize: Int, in: InputBlockBuffer) + extends InputBlockBuffer { private[this] val zstd = ZstdDecompressLib.instance.get private[this] val comp = new Array[Byte](4 + Zstd.compressBound(blockSize).toInt) - def close(): Unit = { + def close(): Unit = in.close() - } def seek(offset: Long): Unit = in.seek(offset) diff --git a/hail/src/main/scala/is/hail/io/MemoryBuffer.scala b/hail/src/main/scala/is/hail/io/MemoryBuffer.scala index 7af09d8b363..13da2d1a0a7 100644 --- a/hail/src/main/scala/is/hail/io/MemoryBuffer.scala +++ b/hail/src/main/scala/is/hail/io/MemoryBuffer.scala @@ -1,8 +1,8 @@ package is.hail.io -import java.util +import is.hail.annotations.Memory -import is.hail.annotations.{Memory, Region} +import java.util final class MemoryBuffer extends Serializable { var mem: Array[Byte] = new Array[Byte](8) @@ -11,18 +11,16 @@ final class MemoryBuffer extends Serializable { def capacity: Int = mem.length - def invalidate(): Unit = { + def invalidate(): Unit = mem = null - } - def clear() { + def clear(): Unit = { pos = 0 end = 0 } - def clearPos() { + def clearPos(): Unit = pos = 0 - } def set(bytes: Array[Byte]): Unit = { mem = bytes @@ -36,52 +34,51 @@ final class MemoryBuffer extends Serializable { dst } - def grow(n: Int) { + def grow(n: Int): Unit = mem = util.Arrays.copyOf(mem, math.max(capacity * 2, end + n)) - } - def copyFrom(src: MemoryBuffer) { + def copyFrom(src: MemoryBuffer): Unit = { mem = util.Arrays.copyOf(src.mem, src.capacity) end = src.end pos = src.pos } - def writeByte(b: Byte) { + def writeByte(b: Byte): Unit = { if (end + 1 > capacity) grow(1) Memory.storeByte(mem, end, b) end += 1 } - def writeInt(i: Int) { + def writeInt(i: Int): Unit = { if (end + 4 > capacity) grow(4) Memory.storeInt(mem, end, i) end += 4 } - def writeLong(i: Long) { + def writeLong(i: Long): Unit = { if (end + 8 > capacity) grow(8) Memory.storeLong(mem, end, i) end += 8 } - def writeFloat(i: Float) { + def writeFloat(i: Float): Unit = { if (end + 4 > capacity) grow(4) Memory.storeFloat(mem, end, i) end += 4 } - def writeDouble(i: Double) { + def writeDouble(i: 
Double): Unit = { if (end + 8 > capacity) grow(8) Memory.storeDouble(mem, end, i) end += 8 } - def writeBytes(off: Long, n: Int) { + def writeBytes(off: Long, n: Int): Unit = { if (end + n > capacity) grow(n) Memory.memcpy(mem, end, off, n) @@ -123,44 +120,44 @@ final class MemoryBuffer extends Serializable { d } - def readBytes(toOff: Long, n: Int) { + def readBytes(toOff: Long, n: Int): Unit = { assert(pos + n <= end) Memory.memcpy(toOff, mem, pos, n) pos += n } - def readBytesArray(dst: Array[Byte], n: Int) { + def readBytesArray(dst: Array[Byte], n: Int): Unit = { assert(pos + n <= end) System.arraycopy(mem, pos, dst, 0, n); pos += n } - def skipByte() { + def skipByte(): Unit = { assert(pos + 1 <= end) pos += 1 } - def skipInt() { + def skipInt(): Unit = { assert(pos + 4 <= end) pos += 4 } - def skipLong() { + def skipLong(): Unit = { assert(pos + 8 <= end) pos += 8 } - def skipFloat() { + def skipFloat(): Unit = { assert(pos + 4 <= end) pos += 4 } - def skipDouble() { + def skipDouble(): Unit = { assert(pos + 8 <= end) pos += 8 } - def skipBytes(n: Int) { + def skipBytes(n: Int): Unit = { assert(pos + n <= end) pos += n } @@ -170,11 +167,10 @@ final class MemoryBuffer extends Serializable { val x = (mem(i).toInt & 0xff).toHexString if (x.length == 1) "0" + x else x - } .mkString(" ") + }.mkString(" ") val index = (from until to by 4).map(i => String.format("%1$-12s", i.toString)).mkString("") println(s"bytes: $bytes") println(s"index: $index") } } - diff --git a/hail/src/main/scala/is/hail/io/OutputBuffers.scala b/hail/src/main/scala/is/hail/io/OutputBuffers.scala index 4c5546f0d7e..39bc09338b0 100644 --- a/hail/src/main/scala/is/hail/io/OutputBuffers.scala +++ b/hail/src/main/scala/is/hail/io/OutputBuffers.scala @@ -1,14 +1,13 @@ package is.hail.io -import java.io._ -import java.util -import java.util.function.Supplier - import is.hail.annotations.{Memory, Region} import is.hail.io.compress.LZ4 import is.hail.utils._ import is.hail.utils.richUtils.ByteTrackingOutputStream +import java.io._ +import java.util.function.Supplier + import com.github.luben.zstd.{Zstd, ZstdCompressCtx} trait OutputBuffer extends Closeable { @@ -46,15 +45,8 @@ trait OutputBuffer extends Closeable { def writeDoubles(from: Array[Double]): Unit = writeDoubles(from, 0, from.length) - def writeBoolean(b: Boolean) { + def writeBoolean(b: Boolean): Unit = writeByte(b.toByte) - } - - def writeUTF(s: String): Unit = { - val bytes = s.getBytes(utfCharset) - writeInt(bytes.length) - write(bytes) - } } trait OutputBlockBuffer extends Spec with Closeable { @@ -76,22 +68,22 @@ final class StreamOutputBuffer(out: OutputStream) extends OutputBuffer { override def writeByte(b: Byte): Unit = out.write(b.toInt) - override def writeInt(i: Int) { + override def writeInt(i: Int): Unit = { Memory.storeInt(buf, 0, i) out.write(buf, 0, 4) } - def writeLong(l: Long) { + def writeLong(l: Long): Unit = { Memory.storeLong(buf, 0, l) out.write(buf, 0, 8) } - def writeFloat(f: Float) { + def writeFloat(f: Float): Unit = { Memory.storeFloat(buf, 0, f) out.write(buf, 0, 4) } - def writeDouble(d: Double) { + def writeDouble(d: Double): Unit = { Memory.storeDouble(buf, 0, d) out.write(buf, 0, 8) } @@ -102,7 +94,7 @@ final class StreamOutputBuffer(out: OutputStream) extends OutputBuffer { def writeBytes(addr: Long, n: Int): Unit = out.write(Region.loadBytes(addr, n)) - def writeDoubles(from: Array[Double], fromOff: Int, n: Int) { + def writeDoubles(from: Array[Double], fromOff: Int, n: Int): Unit = { var i = 0 while (i < n) { 
writeDouble(from(fromOff + i)) @@ -112,9 +104,9 @@ final class StreamOutputBuffer(out: OutputStream) extends OutputBuffer { } final class MemoryOutputBuffer(mb: MemoryBuffer) extends OutputBuffer { - def flush() {} + def flush(): Unit = {} - def close() {} + def close(): Unit = {} def indexOffset(): Long = ??? @@ -138,9 +130,8 @@ final class MemoryOutputBuffer(mb: MemoryBuffer) extends OutputBuffer { final class LEB128OutputBuffer(out: OutputBuffer) extends OutputBuffer { def flush(): Unit = out.flush() - def close() { + def close(): Unit = out.close() - } def indexOffset(): Long = out.indexOffset() @@ -176,7 +167,8 @@ final class LEB128OutputBuffer(out: OutputBuffer) extends OutputBuffer { def writeBytes(addr: Long, n: Int): Unit = out.writeBytes(addr, n) - def writeDoubles(from: Array[Double], fromOff: Int, n: Int): Unit = out.writeDoubles(from, fromOff, n) + def writeDoubles(from: Array[Double], fromOff: Int, n: Int): Unit = + out.writeDoubles(from, fromOff, n) } final class BlockingOutputBuffer(blockSize: Int, out: OutputBlockBuffer) extends OutputBuffer { @@ -189,50 +181,50 @@ final class BlockingOutputBuffer(blockSize: Int, out: OutputBlockBuffer) extends (out.getPos() << 16) | off } - private def writeBlock() { + private def writeBlock(): Unit = { out.writeBlock(buf, off) off = 0 } - def flush() { + def flush(): Unit = { writeBlock() out.flush() } - def close() { + def close(): Unit = { flush() out.close() } - def writeByte(b: Byte) { + def writeByte(b: Byte): Unit = { if (off + 1 > buf.length) writeBlock() Memory.storeByte(buf, off, b) off += 1 } - def writeInt(i: Int) { + def writeInt(i: Int): Unit = { if (off + 4 > buf.length) writeBlock() Memory.storeInt(buf, off, i) off += 4 } - def writeLong(l: Long) { + def writeLong(l: Long): Unit = { if (off + 8 > buf.length) writeBlock() Memory.storeLong(buf, off, l) off += 8 } - def writeFloat(f: Float) { + def writeFloat(f: Float): Unit = { if (off + 4 > buf.length) writeBlock() Memory.storeFloat(buf, off, f) off += 4 } - def writeDouble(d: Double) { + def writeDouble(d: Double): Unit = { if (off + 8 > buf.length) writeBlock() Memory.storeDouble(buf, off, d) @@ -241,7 +233,7 @@ final class BlockingOutputBuffer(blockSize: Int, out: OutputBlockBuffer) extends def writeBytes(fromRegion: Region, fromOff0: Long, n0: Int) = writeBytes(fromOff0, n0) - def writeBytes(addr0: Long, n0: Int) { + def writeBytes(addr0: Long, n0: Int): Unit = { assert(n0 >= 0) var addr = addr0 var n = n0 @@ -259,7 +251,7 @@ final class BlockingOutputBuffer(blockSize: Int, out: OutputBlockBuffer) extends off += n } - def writeDoubles(from: Array[Double], fromOff0: Int, n0: Int) { + def writeDoubles(from: Array[Double], fromOff0: Int, n0: Int): Unit = { assert(n0 >= 0) assert(fromOff0 >= 0) assert(fromOff0 <= from.length - n0) @@ -282,13 +274,11 @@ final class BlockingOutputBuffer(blockSize: Int, out: OutputBlockBuffer) extends final class StreamBlockOutputBuffer(out: OutputStream) extends OutputBlockBuffer { private val lenBuf = new Array[Byte](4) - def flush() { + def flush(): Unit = out.flush() - } - def close() { + def close(): Unit = out.close() - } def writeBlock(buf: Array[Byte], len: Int): Unit = { Memory.storeInt(lenBuf, 0, len) @@ -299,16 +289,15 @@ final class StreamBlockOutputBuffer(out: OutputStream) extends OutputBlockBuffer def getPos(): Long = out.asInstanceOf[ByteTrackingOutputStream].bytesWritten } -final class LZ4OutputBlockBuffer(lz4: LZ4, blockSize: Int, out: OutputBlockBuffer) extends OutputBlockBuffer { +final class LZ4OutputBlockBuffer(lz4: LZ4, 
blockSize: Int, out: OutputBlockBuffer) + extends OutputBlockBuffer { private val comp = new Array[Byte](4 + lz4.maxCompressedLength(blockSize)) - def flush() { + def flush(): Unit = out.flush() - } - def close() { + def close(): Unit = out.close() - } def writeBlock(buf: Array[Byte], decompLen: Int): Unit = { val compLen = lz4.compress(comp, 4, buf, decompLen) @@ -319,16 +308,19 @@ final class LZ4OutputBlockBuffer(lz4: LZ4, blockSize: Int, out: OutputBlockBuffe def getPos(): Long = out.getPos() } -final class LZ4SizeBasedCompressingOutputBlockBuffer(lz4: LZ4, blockSize: Int, minCompressionSize: Int, out: OutputBlockBuffer) extends OutputBlockBuffer { +final class LZ4SizeBasedCompressingOutputBlockBuffer( + lz4: LZ4, + blockSize: Int, + minCompressionSize: Int, + out: OutputBlockBuffer, +) extends OutputBlockBuffer { private val comp = new Array[Byte](8 + lz4.maxCompressedLength(blockSize)) - def flush() { + def flush(): Unit = out.flush() - } - def close() { + def close(): Unit = out.close() - } def writeBlock(buf: Array[Byte], decompLen: Int): Unit = { if (decompLen < minCompressionSize) { @@ -357,7 +349,8 @@ object ZstdCompressLib { }) } -final class ZstdOutputBlockBuffer(blockSize: Int, out: OutputBlockBuffer) extends OutputBlockBuffer { +final class ZstdOutputBlockBuffer(blockSize: Int, out: OutputBlockBuffer) + extends OutputBlockBuffer { private[this] val zstd = ZstdCompressLib.instance.get private[this] val comp = new Array[Byte](4 + Zstd.compressBound(blockSize).toInt) @@ -374,7 +367,11 @@ final class ZstdOutputBlockBuffer(blockSize: Int, out: OutputBlockBuffer) extend def getPos(): Long = out.getPos() } -final class ZstdSizedBasedOutputBlockBuffer(blockSize: Int, minCompressionSize: Int, out: OutputBlockBuffer) extends OutputBlockBuffer { +final class ZstdSizedBasedOutputBlockBuffer( + blockSize: Int, + minCompressionSize: Int, + out: OutputBlockBuffer, +) extends OutputBlockBuffer { private[this] val zstd = ZstdCompressLib.instance.get private[this] val comp = new Array[Byte](4 + Zstd.compressBound(blockSize).toInt) diff --git a/hail/src/main/scala/is/hail/io/RichContextRDDRegionValue.scala b/hail/src/main/scala/is/hail/io/RichContextRDDRegionValue.scala index beb64b15ee7..466ae1b594d 100644 --- a/hail/src/main/scala/is/hail/io/RichContextRDDRegionValue.scala +++ b/hail/src/main/scala/is/hail/io/RichContextRDDRegionValue.scala @@ -1,29 +1,34 @@ package is.hail.io -import java.io._ -import is.hail.asm4s.{HailClassLoader, theHailClassLoaderForSparkWorkers} import is.hail.annotations._ +import is.hail.asm4s.{theHailClassLoaderForSparkWorkers, HailClassLoader} import is.hail.backend.ExecuteContext import is.hail.backend.spark.SparkTaskContext -import is.hail.types.physical._ import is.hail.io.fs.FS import is.hail.io.index.IndexWriter -import is.hail.rvd.{AbstractIndexSpec, IndexSpec, MakeRVDSpec, RVDContext, RVDPartitioner, RVDType} +import is.hail.rvd.{AbstractIndexSpec, MakeRVDSpec, RVDContext, RVDPartitioner, RVDType} import is.hail.sparkextras._ +import is.hail.types.physical._ import is.hail.utils._ import is.hail.utils.richUtils.ByteTrackingOutputStream + +import java.io._ + +import org.apache.spark.{ExposedMetrics, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -import org.apache.spark.{ExposedMetrics, TaskContext} - -import scala.reflect.ClassTag object RichContextRDDRegionValue { def writeRowsPartition( makeEnc: (OutputStream, HailClassLoader) => Encoder, indexKeyFieldIndices: Array[Int] = null, - rowType: PStruct = null - )(ctx: RVDContext, 
it: Iterator[Long], os: OutputStream, iw: IndexWriter): (Long, Long) = { + rowType: PStruct = null, + )( + ctx: RVDContext, + it: Iterator[Long], + os: OutputStream, + iw: IndexWriter, + ): (Long, Long) = { val context = TaskContext.get val outputMetrics = if (context != null) @@ -80,7 +85,7 @@ object RichContextRDDRegionValue { stageLocally: Boolean, makeIndexWriter: (String, RegionPool) => IndexWriter, makeRowsEnc: (OutputStream) => Encoder, - makeEntriesEnc: (OutputStream) => Encoder + makeEntriesEnc: (OutputStream) => Encoder, ): FileWriteMetadata = { val fullRowType = t.rowType @@ -92,8 +97,10 @@ object RichContextRDDRegionValue { val finalIdxPath = path + "/index/" + f + ".idx" val (rowsPartPath, entriesPartPath, idxPath) = if (stageLocally) { - val rowsPartPath = ExecuteContext.createTmpPathNoCleanup(localTmpdir, "write-split-staged-rows-part") - val entriesPartPath = ExecuteContext.createTmpPathNoCleanup(localTmpdir, "write-split-staged-entries-part") + val rowsPartPath = + ExecuteContext.createTmpPathNoCleanup(localTmpdir, "write-split-staged-rows-part") + val entriesPartPath = + ExecuteContext.createTmpPathNoCleanup(localTmpdir, "write-split-staged-entries-part") val idxPath = rowsPartPath + ".idx" context.addTaskCompletionListener[Unit] { (context: TaskContext) => fs.delete(rowsPartPath, recursive = false) @@ -107,7 +114,6 @@ object RichContextRDDRegionValue { val (rowCount, totalBytesWritten) = using(fs.create(rowsPartPath)) { rowsOS => val trackedRowsOS = new ByteTrackingOutputStream(rowsOS) using(makeRowsEnc(trackedRowsOS)) { rowsEN => - using(fs.create(entriesPartPath)) { entriesOS => val trackedEntriesOS = new ByteTrackingOutputStream(entriesOS) using(makeEntriesEnc(trackedEntriesOS)) { entriesEN => @@ -130,7 +136,10 @@ object RichContextRDDRegionValue { rowCount += 1 - ExposedMetrics.setBytes(outputMetrics, trackedRowsOS.bytesWritten + trackedEntriesOS.bytesWritten) + ExposedMetrics.setBytes( + outputMetrics, + trackedRowsOS.bytesWritten + trackedEntriesOS.bytesWritten, + ) ExposedMetrics.setRecords(outputMetrics, 2 * rowCount) } @@ -140,7 +149,8 @@ object RichContextRDDRegionValue { rowsEN.flush() entriesEN.flush() - val totalBytesWritten = trackedRowsOS.bytesWritten + trackedEntriesOS.bytesWritten + iw.trackedOS().bytesWritten + val totalBytesWritten = + trackedRowsOS.bytesWritten + trackedEntriesOS.bytesWritten + iw.trackedOS().bytesWritten ExposedMetrics.setBytes(outputMetrics, totalBytesWritten) (rowCount, totalBytesWritten) @@ -171,13 +181,17 @@ object RichContextRDDRegionValue { rowsRVType: PStruct, entriesRVType: PStruct, partFiles: Array[String], - partitioner: RVDPartitioner - ) { + partitioner: RVDPartitioner, + ): Unit = { val rowsSpec = MakeRVDSpec(rowsCodecSpec, partFiles, partitioner, rowsIndexSpec) rowsSpec.write(fs, path + "/rows/rows") - val entriesSpec = MakeRVDSpec(entriesCodecSpec, partFiles, - RVDPartitioner.unkeyed(partitioner.sm, partitioner.numPartitions), entriesIndexSpec) + val entriesSpec = MakeRVDSpec( + entriesCodecSpec, + partFiles, + RVDPartitioner.unkeyed(partitioner.sm, partitioner.numPartitions), + entriesIndexSpec, + ) entriesSpec.write(fs, path + "/entries/rows") } } @@ -186,7 +200,7 @@ class RichContextRDDLong(val crdd: ContextRDD[Long]) extends AnyVal { def boundary: ContextRDD[Long] = crdd.cmapPartitionsAndContext { (consumerCtx, part) => val producerCtx = consumerCtx.freshContext - val it = part.flatMap(_ (producerCtx)) + val it = part.flatMap(_(producerCtx)) new Iterator[Long]() { private[this] var cleared: Boolean = false @@ 
-209,10 +223,10 @@ class RichContextRDDLong(val crdd: ContextRDD[Long]) extends AnyVal { } def toCRDDRegionValue: ContextRDD[RegionValue] = - boundary.cmapPartitionsWithContext((ctx, part) => { + boundary.cmapPartitionsWithContext { (ctx, part) => val rv = RegionValue(ctx.r) - part(ctx).map(ptr => { rv.setOffset(ptr); rv }) - }) + part(ctx).map { ptr => rv.setOffset(ptr); rv } + } def writeRows( ctx: ExecuteContext, @@ -220,33 +234,33 @@ class RichContextRDDLong(val crdd: ContextRDD[Long]) extends AnyVal { idxRelPath: String, t: RVDType, stageLocally: Boolean, - encoding: AbstractTypedCodecSpec + encoding: AbstractTypedCodecSpec, ): Array[FileWriteMetadata] = { crdd.writePartitions( ctx, path, idxRelPath, - stageLocally, - { - val f1= IndexWriter.builder(ctx, t.kType, +PCanonicalStruct()) + stageLocally, { + val f1 = IndexWriter.builder(ctx, t.kType, +PCanonicalStruct()) f1(_, theHailClassLoaderForSparkWorkers, SparkTaskContext.get(), _) }, RichContextRDDRegionValue.writeRowsPartition( encoding.buildEncoder(ctx, t.rowType), t.kFieldIdx, - t.rowType) _) + t.rowType, + ) _, + ) } - def toRows(rowType: PStruct): RDD[Row] = { + def toRows(rowType: PStruct): RDD[Row] = crdd.cmap((ctx, ptr) => SafeRow(rowType, ptr)).run - } } class RichContextRDDRegionValue(val crdd: ContextRDD[RegionValue]) extends AnyVal { def boundary: ContextRDD[RegionValue] = crdd.cmapPartitionsAndContext { (consumerCtx, part) => val producerCtx = consumerCtx.freshContext - val it = part.flatMap(_ (producerCtx)) + val it = part.flatMap(_(producerCtx)) new Iterator[RegionValue]() { private[this] var cleared: Boolean = false @@ -279,7 +293,7 @@ class RichContextRDDRegionValue(val crdd: ContextRDD[RegionValue]) extends AnyVa def cleanupRegions: ContextRDD[RegionValue] = { crdd.cmapPartitionsAndContext { (ctx, part) => - val it = part.flatMap(_ (ctx)) + val it = part.flatMap(_(ctx)) new Iterator[RegionValue]() { private[this] var cleared: Boolean = false @@ -302,8 +316,6 @@ class RichContextRDDRegionValue(val crdd: ContextRDD[RegionValue]) extends AnyVa } } - - def toRows(rowType: PStruct): RDD[Row] = { + def toRows(rowType: PStruct): RDD[Row] = crdd.run.map(rv => SafeRow(rowType, rv.offset)) - } } diff --git a/hail/src/main/scala/is/hail/io/Spec.scala b/hail/src/main/scala/is/hail/io/Spec.scala index d2245c21de6..8f31f4b8247 100644 --- a/hail/src/main/scala/is/hail/io/Spec.scala +++ b/hail/src/main/scala/is/hail/io/Spec.scala @@ -1,6 +1,7 @@ package is.hail.io import is.hail.rvd.AbstractRVDSpec + import org.json4s.Extraction import org.json4s.jackson.JsonMethods diff --git a/hail/src/main/scala/is/hail/io/TypedCodecSpec.scala b/hail/src/main/scala/is/hail/io/TypedCodecSpec.scala index d55d7f84716..c777a28b131 100644 --- a/hail/src/main/scala/is/hail/io/TypedCodecSpec.scala +++ b/hail/src/main/scala/is/hail/io/TypedCodecSpec.scala @@ -1,14 +1,13 @@ package is.hail.io -import java.io._ -import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend.ExecuteContext -import is.hail.expr.ir.{EmitClassBuilder, EmitFunctionBuilder} import is.hail.types.encoded._ import is.hail.types.physical._ import is.hail.types.virtual._ +import java.io._ + object TypedCodecSpec { def apply(pt: PType, bufferSpec: BufferSpec): TypedCodecSpec = { val eType = EType.defaultFromPType(pt) @@ -16,30 +15,39 @@ object TypedCodecSpec { } } -final case class TypedCodecSpec(_eType: EType, _vType: Type, _bufferSpec: BufferSpec) extends AbstractTypedCodecSpec { +final case class TypedCodecSpec(_eType: EType, _vType: Type, _bufferSpec: BufferSpec) + 
extends AbstractTypedCodecSpec { def encodedType: EType = _eType def encodedVirtualType: Type = _vType def buildEncoder(ctx: ExecuteContext, t: PType): (OutputStream, HailClassLoader) => Encoder = { val bufferToEncoder = encodedType.buildEncoder(ctx, t) - (out: OutputStream, theHailClassLoader: HailClassLoader) => bufferToEncoder(_bufferSpec.buildOutputBuffer(out), theHailClassLoader) + (out: OutputStream, theHailClassLoader: HailClassLoader) => + bufferToEncoder(_bufferSpec.buildOutputBuffer(out), theHailClassLoader) } - def decodedPType(requestedType: Type): PType = { + def decodedPType(requestedType: Type): PType = encodedType.decodedPType(requestedType) - } - def buildDecoder(ctx: ExecuteContext, requestedType: Type): (PType, (InputStream, HailClassLoader) => Decoder) = { + def buildDecoder(ctx: ExecuteContext, requestedType: Type) + : (PType, (InputStream, HailClassLoader) => Decoder) = { val (rt, bufferToDecoder) = encodedType.buildDecoder(ctx, requestedType) - (rt, (in: InputStream, theHailClassLoader: HailClassLoader) => bufferToDecoder(_bufferSpec.buildInputBuffer(in), theHailClassLoader)) + ( + rt, + (in: InputStream, theHailClassLoader: HailClassLoader) => + bufferToDecoder(_bufferSpec.buildInputBuffer(in), theHailClassLoader), + ) } - def buildStructDecoder(ctx: ExecuteContext, requestedType: TStruct): (PStruct, (InputStream, HailClassLoader) => Decoder) = { + def buildStructDecoder(ctx: ExecuteContext, requestedType: TStruct) + : (PStruct, (InputStream, HailClassLoader) => Decoder) = { val (pType: PStruct, makeDec) = buildDecoder(ctx, requestedType) pType -> makeDec } - def buildCodeInputBuffer(is: Code[InputStream]): Code[InputBuffer] = _bufferSpec.buildCodeInputBuffer(is) + def buildCodeInputBuffer(is: Code[InputStream]): Code[InputBuffer] = + _bufferSpec.buildCodeInputBuffer(is) - def buildCodeOutputBuffer(os: Code[OutputStream]): Code[OutputBuffer] = _bufferSpec.buildCodeOutputBuffer(os) + def buildCodeOutputBuffer(os: Code[OutputStream]): Code[OutputBuffer] = + _bufferSpec.buildCodeOutputBuffer(os) } diff --git a/hail/src/main/scala/is/hail/io/avro/AvroPartitionReader.scala b/hail/src/main/scala/is/hail/io/avro/AvroPartitionReader.scala index 67bdf36dac8..ecd559df445 100644 --- a/hail/src/main/scala/is/hail/io/avro/AvroPartitionReader.scala +++ b/hail/src/main/scala/is/hail/io/avro/AvroPartitionReader.scala @@ -1,26 +1,28 @@ package is.hail.io.avro import is.hail.annotations.Region -import is.hail.asm4s._ +import is.hail.asm4s.{Field => _, _} import is.hail.backend.ExecuteContext +import is.hail.expr.ir.{ + EmitCode, EmitCodeBuilder, EmitMethodBuilder, EmitValue, IEmitCode, PartitionReader, +} import is.hail.expr.ir.streams.StreamProducer -import is.hail.expr.ir.{EmitCode, EmitCodeBuilder, EmitMethodBuilder, EmitValue, IEmitCode, PartitionReader} +import is.hail.types.{RField, RStruct, TypeWithRequiredness} import is.hail.types.physical.{PCanonicalTuple, PInt64Required} -import is.hail.types.physical.stypes.EmitType import is.hail.types.physical.stypes.concrete._ -import is.hail.types.physical.stypes.interfaces.{SBaseStructValue, SStreamValue, primitive} -import is.hail.types.physical.stypes.primitives.{SInt64, SInt64Value} +import is.hail.types.physical.stypes.interfaces.{primitive, SBaseStructValue, SStreamValue} import is.hail.types.virtual._ -import is.hail.types.{RField, RStruct, TypeWithRequiredness} + +import scala.collection.JavaConverters._ + +import java.io.InputStream + import org.apache.avro.Schema import org.apache.avro.file.DataFileStream import 
org.apache.avro.generic.{GenericData, GenericDatumReader, GenericRecord} import org.apache.avro.io.DatumReader import org.json4s.{Extraction, JValue} -import java.io.InputStream -import scala.collection.JavaConverters._ - case class AvroPartitionReader(schema: Schema, uidFieldName: String) extends PartitionReader { def contextType: Type = TStruct("partitionPath" -> TString, "partitionIndex" -> TInt64) @@ -52,11 +54,13 @@ case class AvroPartitionReader(schema: Schema, uidFieldName: String) extends Par cb: EmitCodeBuilder, mb: EmitMethodBuilder[_], context: EmitCode, - requestedType: TStruct + requestedType: TStruct, ): IEmitCode = { context.toI(cb).map(cb) { case ctxStruct: SBaseStructValue => - val partIdx = cb.memoizeField(ctxStruct.loadField(cb, "partitionIndex").get(cb), "partIdx") - val pathString = ctxStruct.loadField(cb, "partitionPath").get(cb).asString.loadString(cb) + val partIdx = + cb.memoizeField(ctxStruct.loadField(cb, "partitionIndex").getOrAssert(cb), "partIdx") + val pathString = + ctxStruct.loadField(cb, "partitionPath").getOrAssert(cb).asString.loadString(cb) val makeUID = requestedType.hasField(uidFieldName) val concreteRequestedType = if (makeUID) @@ -79,7 +83,11 @@ case class AvroPartitionReader(schema: Schema, uidFieldName: String) extends Par cb.assign(record, Code.newInstance[GenericData.Record, Schema](codeSchema)) val is = mb.open(pathString, false) val datumReader = Code.newInstance[GenericDatumReader[GenericRecord], Schema](codeSchema) - val dataFileStream = Code.newInstance[DataFileStream[GenericRecord], InputStream, DatumReader[GenericRecord]](is, datumReader) + val dataFileStream = Code.newInstance[ + DataFileStream[GenericRecord], + InputStream, + DatumReader[GenericRecord], + ](is, datumReader) cb.assign(it, dataFileStream) cb.assign(rowIdx, -1L) @@ -98,9 +106,14 @@ case class AvroPartitionReader(schema: Schema, uidFieldName: String) extends Par val baseStruct = AvroReader.recordToHail(cb, region, record, concreteRequestedType) if (makeUID) { val uid = EmitValue.present( - SStackStruct.constructFromArgs(cb, region, TTuple(TInt64, TInt64), + SStackStruct.constructFromArgs( + cb, + region, + TTuple(TInt64, TInt64), EmitCode.present(mb, partIdx), - EmitCode.present(mb, primitive(rowIdx)))) + EmitCode.present(mb, primitive(rowIdx)), + ) + ) EmitCode.present(mb, baseStruct._insert(requestedType, uidFieldName -> uid)) } else { EmitCode.present(mb, baseStruct) @@ -121,11 +134,14 @@ object AvroReader { private[avro] def schemaToType(schema: Schema): TStruct = { try { if (schema.getType != Schema.Type.RECORD) { - throw new UnsupportedOperationException("hail conversion from avro is only supported for top level record types") + throw new UnsupportedOperationException( + "hail conversion from avro is only supported for top level record types" + ) } _schemaToType(schema).asInstanceOf[TStruct] } catch { - case e: UnsupportedOperationException => throw new UnsupportedOperationException(s"hail conversion from $schema is unsupported", e) + case e: UnsupportedOperationException => + throw new UnsupportedOperationException(s"hail conversion from $schema is unsupported", e) } } @@ -148,31 +164,80 @@ object AvroReader { throw new UnsupportedOperationException(s"hail conversion from avro $schema is unsupported") _schemaToType(types.get(1 - nullIndex)) - case _ => throw new UnsupportedOperationException(s"hail conversion from avro $schema is unsupported") + case _ => + throw new UnsupportedOperationException(s"hail conversion from avro $schema is unsupported") } - private[avro] 
def recordToHail(cb: EmitCodeBuilder, region: Value[Region], record: Value[GenericRecord], requestedType: TBaseStruct): SBaseStructValue = { + private[avro] def recordToHail( + cb: EmitCodeBuilder, + region: Value[Region], + record: Value[GenericRecord], + requestedType: TBaseStruct, + ): SBaseStructValue = { val codes = requestedType.fields.map { case Field(name, typ, _) => val v = cb.newLocal[AnyRef]("avro_value") cb.assign(v, record.invoke[String, AnyRef]("get", name)) typ match { case TBoolean => - EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, v.isNull, primitive(cb.memoize(Code.booleanValue(Code.checkcast[java.lang.Boolean](v)))))) + EmitCode.fromI(cb.emb)(cb => + IEmitCode( + cb, + v.isNull, + primitive(cb.memoize(Code.booleanValue(Code.checkcast[java.lang.Boolean](v)))), + ) + ) case TInt32 => - EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, v.isNull, primitive(cb.memoize(Code.intValue(Code.checkcast[java.lang.Number](v)))))) + EmitCode.fromI(cb.emb)(cb => + IEmitCode( + cb, + v.isNull, + primitive(cb.memoize(Code.intValue(Code.checkcast[java.lang.Number](v)))), + ) + ) case TInt64 => - EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, v.isNull, primitive(cb.memoize(Code.longValue(Code.checkcast[java.lang.Number](v)))))) + EmitCode.fromI(cb.emb)(cb => + IEmitCode( + cb, + v.isNull, + primitive(cb.memoize(Code.longValue(Code.checkcast[java.lang.Number](v)))), + ) + ) case TFloat32 => - EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, v.isNull, primitive(cb.memoize(Code.floatValue(Code.checkcast[java.lang.Number](v)))))) + EmitCode.fromI(cb.emb)(cb => + IEmitCode( + cb, + v.isNull, + primitive(cb.memoize(Code.floatValue(Code.checkcast[java.lang.Number](v)))), + ) + ) case TFloat64 => - EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, v.isNull, primitive(cb.memoize(Code.doubleValue(Code.checkcast[java.lang.Number](v)))))) + EmitCode.fromI(cb.emb)(cb => + IEmitCode( + cb, + v.isNull, + primitive(cb.memoize(Code.doubleValue(Code.checkcast[java.lang.Number](v)))), + ) + ) case TString => - EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, v.isNull, new SJavaStringValue(cb.memoize(Code.checkcast[org.apache.avro.util.Utf8](v).invoke[String]("toString"))))) + EmitCode.fromI(cb.emb)(cb => + IEmitCode( + cb, + v.isNull, + new SJavaStringValue( + cb.memoize(Code.checkcast[org.apache.avro.util.Utf8](v).invoke[String]("toString")) + ), + ) + ) case TBinary => - EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, v.isNull, new SJavaBytesValue(cb.memoize(Code.checkcast[Array[Byte]](v))))) + EmitCode.fromI(cb.emb)(cb => + IEmitCode(cb, v.isNull, new SJavaBytesValue(cb.memoize(Code.checkcast[Array[Byte]](v)))) + ) case typ: TBaseStruct => - val record = cb.newLocal[GenericRecord]("avro_subrecord", Code.checkcast[GenericRecord](v)) - EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, v.isNull, recordToHail(cb, region, record, typ))) + val record = + cb.newLocal[GenericRecord]("avro_subrecord", Code.checkcast[GenericRecord](v)) + EmitCode.fromI(cb.emb)(cb => + IEmitCode(cb, v.isNull, recordToHail(cb, region, record, typ)) + ) } } @@ -192,8 +257,8 @@ object AvroReader { typ match { case t: RStruct => t.fields.foreach { case RField(name, typ, _) => - setRequiredness(realSchema.getField(name).schema, typ) - } + setRequiredness(realSchema.getField(name).schema, typ) + } case _ => // do nothing } } diff --git a/hail/src/main/scala/is/hail/io/avro/AvroSchemaSerializer.scala b/hail/src/main/scala/is/hail/io/avro/AvroSchemaSerializer.scala index ffcce83802a..182143c0d0c 100644 --- a/hail/src/main/scala/is/hail/io/avro/AvroSchemaSerializer.scala +++ 
b/hail/src/main/scala/is/hail/io/avro/AvroSchemaSerializer.scala @@ -4,11 +4,13 @@ import org.apache.avro.Schema import org.json4s.CustomSerializer import org.json4s.jackson.JsonMethods -class AvroSchemaSerializer extends CustomSerializer[Schema](_ => ( - { case jv => - new Schema.Parser().parse(JsonMethods.compact(jv)) - }, - { case schema: Schema => - JsonMethods.parse(schema.toString) - } -)) +class AvroSchemaSerializer extends CustomSerializer[Schema](_ => + ( + { case jv => + new Schema.Parser().parse(JsonMethods.compact(jv)) + }, + { case schema: Schema => + JsonMethods.parse(schema.toString) + }, + ) + ) diff --git a/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala b/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala index 37cbc41ac68..9e1969214ab 100644 --- a/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala +++ b/hail/src/main/scala/is/hail/io/avro/AvroTableReader.scala @@ -1,28 +1,39 @@ package is.hail.io.avro -import is.hail.backend.ExecuteContext +import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.expr.ir._ import is.hail.expr.ir.lowering.{LowererUnsupportedOperation, TableStage, TableStageDependency} import is.hail.rvd.RVDPartitioner +import is.hail.types.{TableType, VirtualTypeWithReq} import is.hail.types.physical.{PCanonicalStruct, PCanonicalTuple, PInt64Required} import is.hail.types.virtual._ -import is.hail.types.{TableType, VirtualTypeWithReq} -import is.hail.utils.{FastSeq, plural} +import is.hail.utils.{plural, FastSeq} + import org.json4s.{Formats, JValue} class AvroTableReader( partitionReader: AvroPartitionReader, paths: IndexedSeq[String], - unsafeOptions: Option[UnsafeAvroTableReaderOptions] = None + unsafeOptions: Option[UnsafeAvroTableReaderOptions] = None, ) extends TableReaderWithExtraUID { - private val partitioner: RVDPartitioner = unsafeOptions.map { case UnsafeAvroTableReaderOptions(key, intervals, _) => - require(intervals.length == paths.length, - s"There must be one partition interval per avro file, have ${paths.length} ${plural(paths.length, "file")} and ${intervals.length} ${plural(intervals.length, "interval")}") - RVDPartitioner.generate(null, partitionReader.fullRowType.typeAfterSelectNames(key), intervals) - }.getOrElse { - RVDPartitioner.unkeyed(null, paths.length) - } + private def partitioner(stateManager: HailStateManager): RVDPartitioner = + unsafeOptions.map { case UnsafeAvroTableReaderOptions(key, intervals, _) => + require( + intervals.length == paths.length, + s"There must be one partition interval per avro file, have ${paths.length} ${plural( + paths.length, + "file", + )} and ${intervals.length} ${plural(intervals.length, "interval")}", + ) + RVDPartitioner.generate( + stateManager, + partitionReader.fullRowType.typeAfterSelectNames(key), + intervals, + ) + }.getOrElse { + RVDPartitioner.unkeyed(stateManager, paths.length) + } def pathsUsed: Seq[String] = paths @@ -31,42 +42,54 @@ class AvroTableReader( override def uidType = TTuple(TInt64, TInt64) override def fullTypeWithoutUIDs: TableType = - TableType(partitionReader.fullRowTypeWithoutUIDs, unsafeOptions.map(_.key).getOrElse(IndexedSeq()), TStruct()) + TableType( + partitionReader.fullRowTypeWithoutUIDs, + unsafeOptions.map(_.key).getOrElse(IndexedSeq()), + TStruct(), + ) - override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = - VirtualTypeWithReq(requestedType.rowType, partitionReader.rowRequiredness(requestedType.rowType)) + override def concreteRowRequiredness(ctx: ExecuteContext, 
requestedType: TableType) + : VirtualTypeWithReq = + VirtualTypeWithReq( + requestedType.rowType, + partitionReader.rowRequiredness(requestedType.rowType), + ) override def uidRequiredness: VirtualTypeWithReq = VirtualTypeWithReq(PCanonicalTuple(true, PInt64Required, PInt64Required)) - override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = + override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = VirtualTypeWithReq(PCanonicalStruct(required = true)) def renderShort(): String = defaultRender() override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = { val globals = MakeStruct(FastSeq()) - val contexts = zip2(ToStream(Literal(TArray(TString), paths)), StreamIota(I32(0), I32(1)), ArrayZipBehavior.TakeMinLength) { (path, idx) => + val contexts = zip2( + ToStream(Literal(TArray(TString), paths)), + StreamIota(I32(0), I32(1)), + ArrayZipBehavior.TakeMinLength, + ) { (path, idx) => MakeStruct(Array("partitionPath" -> path, "partitionIndex" -> Cast(idx, TInt64))) } TableStage( globals, - partitioner, + partitioner(ctx.stateManager), TableStageDependency.none, contexts, - { ctx => - ReadPartition(ctx, requestedType.rowType, partitionReader) - } + ctx => ReadPartition(ctx, requestedType.rowType, partitionReader), ) } override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = - throw new LowererUnsupportedOperation(s"${ getClass.getSimpleName }.lowerGlobals not implemented") + throw new LowererUnsupportedOperation(s"${getClass.getSimpleName}.lowerGlobals not implemented") } object AvroTableReader { def fromJValue(jv: JValue): AvroTableReader = { - implicit val formats: Formats = PartitionReader.formats + new UnsafeAvroTableReaderOptionsSerializer + implicit val formats: Formats = + PartitionReader.formats + new UnsafeAvroTableReaderOptionsSerializer val paths = (jv \ "paths").extract[IndexedSeq[String]] val partitionReader = (jv \ "partitionReader").extract[AvroPartitionReader] val unsafeOptions = (jv \ "unsafeOptions").extract[Option[UnsafeAvroTableReaderOptions]] diff --git a/hail/src/main/scala/is/hail/io/avro/UnsafeAvroTableReaderOptions.scala b/hail/src/main/scala/is/hail/io/avro/UnsafeAvroTableReaderOptions.scala index 61f9fb0956e..14a2c5b8175 100644 --- a/hail/src/main/scala/is/hail/io/avro/UnsafeAvroTableReaderOptions.scala +++ b/hail/src/main/scala/is/hail/io/avro/UnsafeAvroTableReaderOptions.scala @@ -4,29 +4,36 @@ import is.hail.expr.JSONAnnotationImpex import is.hail.expr.ir.IRParser import is.hail.types.virtual.{TArray, TInterval, Type} import is.hail.utils.Interval + import org.json4s.{CustomSerializer, Formats, JObject} import org.json4s.JsonDSL._ -case class UnsafeAvroTableReaderOptions(key: IndexedSeq[String], intervals: IndexedSeq[Interval], intervalPointType: Type) +case class UnsafeAvroTableReaderOptions( + key: IndexedSeq[String], + intervals: IndexedSeq[Interval], + intervalPointType: Type, +) -class UnsafeAvroTableReaderOptionsSerializer extends CustomSerializer[UnsafeAvroTableReaderOptions](format => ( - { case jv: JObject => - implicit val fmt: Formats = format - val key = (jv \ "key").extract[IndexedSeq[String]] - val intervalPointType = IRParser.parseType((jv \ "intervalPointType").extract[String]) - val intervals = { - val jIntervals = jv \ "intervals" - val ty = TArray(TInterval(intervalPointType)) - JSONAnnotationImpex.importAnnotation(jIntervals, ty).asInstanceOf[IndexedSeq[Interval]] - } - UnsafeAvroTableReaderOptions(key, 
intervals, intervalPointType) - }, - { case UnsafeAvroTableReaderOptions(key, intervals, intervalPointType) => - implicit val fmt: Formats = format - val ty = TArray(TInterval(intervalPointType)) - ("name" -> UnsafeAvroTableReaderOptions.getClass.getSimpleName) ~ - ("key" -> key) ~ - ("intervals" -> JSONAnnotationImpex.exportAnnotation(intervals, ty)) ~ - ("intervalPointType" -> intervalPointType.parsableString()) - } -)) \ No newline at end of file +class UnsafeAvroTableReaderOptionsSerializer + extends CustomSerializer[UnsafeAvroTableReaderOptions](format => + ( + { case jv: JObject => + implicit val fmt: Formats = format + val key = (jv \ "key").extract[IndexedSeq[String]] + val intervalPointType = IRParser.parseType((jv \ "intervalPointType").extract[String]) + val intervals = { + val jIntervals = jv \ "intervals" + val ty = TArray(TInterval(intervalPointType)) + JSONAnnotationImpex.importAnnotation(jIntervals, ty).asInstanceOf[IndexedSeq[Interval]] + } + UnsafeAvroTableReaderOptions(key, intervals, intervalPointType) + }, + { case UnsafeAvroTableReaderOptions(key, intervals, intervalPointType) => + val ty = TArray(TInterval(intervalPointType)) + ("name" -> UnsafeAvroTableReaderOptions.getClass.getSimpleName) ~ + ("key" -> key) ~ + ("intervals" -> JSONAnnotationImpex.exportAnnotation(intervals, ty)) ~ + ("intervalPointType" -> intervalPointType.parsableString()) + }, + ) + ) diff --git a/hail/src/main/scala/is/hail/io/bgen/BgenRDDPartitions.scala b/hail/src/main/scala/is/hail/io/bgen/BgenRDDPartitions.scala index 514af6b9e12..0241e824747 100644 --- a/hail/src/main/scala/is/hail/io/bgen/BgenRDDPartitions.scala +++ b/hail/src/main/scala/is/hail/io/bgen/BgenRDDPartitions.scala @@ -1,23 +1,19 @@ package is.hail.io.bgen -import is.hail.annotations.Region -import is.hail.asm4s._ -import is.hail.backend.{ExecuteContext, HailTaskContext} -import is.hail.expr.ir.{EmitCode, EmitFunctionBuilder, IEmitCode, ParamType, TableReader} -import is.hail.io.fs.FS +import is.hail.backend.ExecuteContext import is.hail.types.virtual._ import is.hail.utils._ -import is.hail.variant.{Call2, ReferenceGenome} case class FilePartitionInfo( metadata: BgenFileMetadata, intervals: Array[Interval], partStarts: Array[Long], - partN: Array[Long] + partN: Array[Long], ) object BgenRDDPartitions extends Logging { - def checkFilesDisjoint(ctx: ExecuteContext, fileMetadata: Seq[BgenFileMetadata], keyType: Type): Array[Interval] = { + def checkFilesDisjoint(ctx: ExecuteContext, fileMetadata: Seq[BgenFileMetadata], keyType: Type) + : Array[Interval] = { assert(fileMetadata.nonEmpty) val pord = keyType.ordering(ctx.stateManager) val bounds = fileMetadata.map(md => (md.path, md.rangeBounds)) @@ -39,11 +35,10 @@ object BgenRDDPartitions extends Logging { if (!overlappingBounds.isEmpty) fatal( s"""Each BGEN file must contain a region of the genome disjoint from other files. 
Found the following overlapping files: - | ${ - overlappingBounds.result().map { case (f1, i1, f2, i2) => + | ${overlappingBounds.result().map { case (f1, i1, f2, i2) => s"file1: $f1\trangeBounds1: $i1\tfile2: $f2\trangeBounds2: $i2" - }.mkString("\n ") - })""".stripMargin) + }.mkString("\n ")})""".stripMargin + ) bounds.map(_._2).toArray } @@ -54,10 +49,8 @@ object BgenRDDPartitions extends Logging { files: IndexedSeq[BgenFileMetadata], blockSizeInMB: Option[Int], nPartitions: Option[Int], - keyType: Type + keyType: Type, ): IndexedSeq[FilePartitionInfo] = { - val fs = ctx.fs - val fileRangeBounds = checkFilesDisjoint(ctx, files, keyType) val intervalOrdering = TInterval(keyType).ordering(ctx.stateManager) @@ -91,7 +84,9 @@ object BgenRDDPartitions extends Logging { val nPartitions = math.min(fileNPartitions(fileIndex), file.nVariants).toInt val partNVariants: Array[Long] = partition(file.nVariants, nPartitions) val partFirstVariantIndex = partNVariants.scan(0L)(_ + _).init - val partLastVariantIndex = (partFirstVariantIndex, partNVariants).zipped.map { (idx, n) => idx + n } + val partLastVariantIndex = (partFirstVariantIndex, partNVariants).zipped.map { (idx, n) => + idx + n + } val allPositions = partFirstVariantIndex ++ partLastVariantIndex.map(_ - 1L) val keys = getKeysFromFile(file.indexPath, allPositions) @@ -100,7 +95,8 @@ object BgenRDDPartitions extends Logging { keys(i), keys(i + nPartitions), true, - true) // this must be true -- otherwise boundaries with duplicates will have the wrong range bounds + true, + ) // this must be true -- otherwise boundaries with duplicates will have the wrong range bounds }.toArray FilePartitionInfo(file, rangeBounds, partFirstVariantIndex, partNVariants) diff --git a/hail/src/main/scala/is/hail/io/bgen/BgenSettings.scala b/hail/src/main/scala/is/hail/io/bgen/BgenSettings.scala index 327a3dc1965..a92cdfeee6e 100644 --- a/hail/src/main/scala/is/hail/io/bgen/BgenSettings.scala +++ b/hail/src/main/scala/is/hail/io/bgen/BgenSettings.scala @@ -2,13 +2,12 @@ package is.hail.io.bgen import is.hail.expr.ir.PruneDeadFields import is.hail.io._ +import is.hail.types.{MatrixType, TableType} import is.hail.types.encoded._ import is.hail.types.physical._ import is.hail.types.virtual._ -import is.hail.types.{MatrixType, TableType} import is.hail.utils._ - object BgenSettings { val UNCOMPRESSED: Int = 0 val ZLIB_COMPRESSION: Int = 1 @@ -16,7 +15,8 @@ object BgenSettings { def indexKeyType(rg: Option[String]): TStruct = TStruct( "locus" -> rg.map(TLocus(_)).getOrElse(TLocus.representation), - "alleles" -> TArray(TString)) + "alleles" -> TArray(TString), + ) val indexAnnotationType: Type = TStruct.empty @@ -27,58 +27,98 @@ object BgenSettings { BufferSpec.lz4HCCompressionLEB } - - def indexCodecSpecs(indexVersion: SemanticVersion, rg: Option[String]): (AbstractTypedCodecSpec, AbstractTypedCodecSpec) = { + def indexCodecSpecs(indexVersion: SemanticVersion, rg: Option[String]) + : (AbstractTypedCodecSpec, AbstractTypedCodecSpec) = { val bufferSpec = specFromVersion(indexVersion) val keyVType = indexKeyType(rg) - val keyEType = EBaseStruct(FastSeq( - EField("locus", EBaseStruct(FastSeq( - EField("contig", EBinaryRequired, 0), - EField("position", EInt32Required, 1))), 0), - EField("alleles", EArray(EBinaryOptional, required = false), 1)), - required = false) + val keyEType = EBaseStruct( + FastSeq( + EField( + "locus", + EBaseStruct(FastSeq( + EField("contig", EBinaryRequired, 0), + EField("position", EInt32Required, 1), + )), + 0, + ), + EField("alleles", 
EArray(EBinaryOptional, required = false), 1), + ), + required = false, + ) val annotationVType = TStruct.empty val annotationEType = EBaseStruct(FastSeq(), required = true) val leafEType = EBaseStruct(FastSeq( EField("first_idx", EInt64Required, 0), - EField("keys", EArray(EBaseStruct(FastSeq( - EField("key", keyEType, 0), - EField("offset", EInt64Required, 1), - EField("annotation", annotationEType, 2) - ), required = true), required = true), 1) + EField( + "keys", + EArray( + EBaseStruct( + FastSeq( + EField("key", keyEType, 0), + EField("offset", EInt64Required, 1), + EField("annotation", annotationEType, 2), + ), + required = true, + ), + required = true, + ), + 1, + ), )) val leafVType = TStruct(FastSeq( Field("first_idx", TInt64, 0), - Field("keys", TArray(TStruct(FastSeq( - Field("key", keyVType, 0), - Field("offset", TInt64, 1), - Field("annotation", annotationVType, 2) - ))), 1))) + Field( + "keys", + TArray(TStruct(FastSeq( + Field("key", keyVType, 0), + Field("offset", TInt64, 1), + Field("annotation", annotationVType, 2), + ))), + 1, + ), + )) val internalNodeEType = EBaseStruct(FastSeq( - EField("children", EArray(EBaseStruct(FastSeq( - EField("index_file_offset", EInt64Required, 0), - EField("first_idx", EInt64Required, 1), - EField("first_key", keyEType, 2), - EField("first_record_offset", EInt64Required, 3), - EField("first_annotation", annotationEType, 4) - ), required = true), required = true), 0) + EField( + "children", + EArray( + EBaseStruct( + FastSeq( + EField("index_file_offset", EInt64Required, 0), + EField("first_idx", EInt64Required, 1), + EField("first_key", keyEType, 2), + EField("first_record_offset", EInt64Required, 3), + EField("first_annotation", annotationEType, 4), + ), + required = true, + ), + required = true, + ), + 0, + ) )) val internalNodeVType = TStruct(FastSeq( - Field("children", TArray(TStruct(FastSeq( - Field("index_file_offset", TInt64, 0), - Field("first_idx", TInt64, 1), - Field("first_key", keyVType, 2), - Field("first_record_offset", TInt64, 3), - Field("first_annotation", annotationVType, 4) - ))), 0) + Field( + "children", + TArray(TStruct(FastSeq( + Field("index_file_offset", TInt64, 0), + Field("first_idx", TInt64, 1), + Field("first_key", keyVType, 2), + Field("first_record_offset", TInt64, 3), + Field("first_annotation", annotationVType, 4), + ))), + 0, + ) )) - (TypedCodecSpec(leafEType, leafVType, bufferSpec), (TypedCodecSpec(internalNodeEType, internalNodeVType, bufferSpec))) + ( + TypedCodecSpec(leafEType, leafVType, bufferSpec), + (TypedCodecSpec(internalNodeEType, internalNodeVType, bufferSpec)), + ) } } @@ -86,15 +126,19 @@ case class BgenSettings( nSamples: Int, requestedType: TableType, rg: Option[String], - indexAnnotationType: Type + indexAnnotationType: Type, ) { - require(PruneDeadFields.isSupertype(requestedType, MatrixBGENReader.fullMatrixType(rg).canonicalTableType)) + require(PruneDeadFields.isSupertype( + requestedType, + MatrixBGENReader.fullMatrixType(rg).canonicalTableType, + )) val entryType: Option[TStruct] = requestedType.rowType .selfField(MatrixType.entriesIdentifier) .map(f => f.typ.asInstanceOf[TArray].elementType.asInstanceOf[TStruct]) - val rowPType: PCanonicalStruct = PCanonicalStruct(required = true, + val rowPType: PCanonicalStruct = PCanonicalStruct( + required = true, Array( "locus" -> PCanonicalLocus.schemaFromRG(rg, required = false), "alleles" -> PCanonicalArray(PCanonicalString(false), false), @@ -106,14 +150,20 @@ case class BgenSettings( Array( "GT" -> PCanonicalCall(), "GP" -> 
PCanonicalArray(PFloat64Required, required = true), - "dosage" -> PFloat64Required - ).filter { case (name, _) => entryType.exists(t => t.hasField(name)) - }: _*))) - .filter { case (name, _) => requestedType.rowType.hasField(name) }: _*) - - assert(rowPType.virtualType == requestedType.rowType, s"${ rowPType.virtualType.parsableString() } vs ${ requestedType.rowType.parsableString() }") - - val indexKeyType: PStruct = rowPType.selectFields(Array("locus", "alleles")).setRequired(false).asInstanceOf[PStruct] + "dosage" -> PFloat64Required, + ).filter { case (name, _) => entryType.exists(t => t.hasField(name)) }: _* + )), + ) + .filter { case (name, _) => requestedType.rowType.hasField(name) }: _* + ) + + assert( + rowPType.virtualType == requestedType.rowType, + s"${rowPType.virtualType.parsableString()} vs ${requestedType.rowType.parsableString()}", + ) + + val indexKeyType: PStruct = + rowPType.selectFields(Array("locus", "alleles")).setRequired(false).asInstanceOf[PStruct] def hasField(name: String): Boolean = requestedType.rowType.hasField(name) diff --git a/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala b/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala index 825fe702a4e..7f756fa5aac 100644 --- a/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala +++ b/hail/src/main/scala/is/hail/io/bgen/LoadBgen.scala @@ -3,9 +3,14 @@ package is.hail.io.bgen import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext +import is.hail.expr.ir.{ + EmitCode, EmitCodeBuilder, EmitMethodBuilder, EmitSettable, EmitValue, IEmitCode, IR, + IRParserEnvironment, Literal, LowerMatrixIR, MakeStruct, MatrixHybridReader, MatrixReader, + PartitionNativeIntervalReader, PartitionReader, ReadPartition, Ref, TableNativeReader, + TableReader, ToStream, +} import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} import is.hail.expr.ir.streams.StreamProducer -import is.hail.expr.ir.{EmitCode, EmitCodeBuilder, EmitMethodBuilder, EmitSettable, EmitValue, IEmitCode, IR, IRParserEnvironment, Literal, LowerMatrixIR, MakeStruct, MatrixHybridReader, MatrixReader, PartitionNativeIntervalReader, PartitionReader, ReadPartition, Ref, TableNativeReader, TableReader, ToStream} import is.hail.io._ import is.hail.io.fs.{FS, FileListEntry, SeekableDataInputStream} import is.hail.io.index.{IndexReader, StagedIndexReader} @@ -17,13 +22,14 @@ import is.hail.types.physical.stypes.concrete.{SJavaArrayString, SStackStruct} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.virtual._ import is.hail.utils._ -import org.apache.spark.sql.Row -import org.json4s.JsonAST.{JArray, JInt, JNull, JString} -import org.json4s.{DefaultFormats, Extraction, Formats, JObject, JValue} import scala.collection.mutable import scala.io.Source +import org.apache.spark.sql.Row +import org.json4s.{DefaultFormats, Extraction, Formats, JObject, JValue} +import org.json4s.JsonAST.{JArray, JInt, JNull, JString} + case class BgenHeader( compression: Int, // 0 uncompressed, 1 zlib, 2 zstd nSamples: Int, @@ -33,7 +39,7 @@ case class BgenHeader( hasIds: Boolean, version: Int, fileByteSize: Long, - path: String + path: String, ) case class BgenFileMetadata( @@ -46,7 +52,8 @@ case class BgenFileMetadata( nVariants: Long, @transient indexKeyType: Type, @transient indexAnnotationType: Type, - @transient rangeBounds: Interval) { + @transient rangeBounds: Interval, +) { def nSamples: Int = header.nSamples def compression: Int = header.compression def path: String = header.path @@ -62,14 +69,14 @@ object LoadBgen { 
val nSamples = is.readInt() if (nSamples != bState.nSamples) - fatal("BGEN file is malformed -- number of sample IDs in header does not equal number in file") + fatal( + "BGEN file is malformed -- number of sample IDs in header does not equal number in file" + ) if (sampleIdSize + bState.headerLength > bState.dataStart - 4) fatal("BGEN file is malformed -- offset is smaller than length of header") - (0 until nSamples).map { i => - is.readLengthAndString(2) - }.toArray + (0 until nSamples).map(i => is.readLengthAndString(2)).toArray } } else { warn(s"BGEN file '$file' contains no sample ID block and no sample ID file given.\n" + @@ -92,11 +99,10 @@ object LoadBgen { } } - def readState(fs: FS, file: String): BgenHeader = { + def readState(fs: FS, file: String): BgenHeader = using(new HadoopFSDataBinaryReader(fs.openNoCompression(file))) { is => readState(is, file, fs.getFileSize(file)) } - } def readState(is: HadoopFSDataBinaryReader, path: String, byteSize: Long): BgenHeader = { is.seek(0) @@ -111,7 +117,7 @@ object LoadBgen { val magicNumber = is.readBytes(4).map(_.toInt).toFastSeq if (magicNumber != FastSeq(0, 0, 0, 0) && magicNumber != FastSeq(98, 103, 101, 110)) - fatal(s"expected magic number [0000] or [bgen], got [${ magicNumber.mkString }]") + fatal(s"expected magic number [0000] or [bgen], got [${magicNumber.mkString}]") if (headerLength > 20) is.skipBytes(headerLength - 20) @@ -136,16 +142,17 @@ object LoadBgen { hasIds, version, byteSize, - path + path, ) } - def checkVersionTwo(headers: Array[BgenHeader]) { + def checkVersionTwo(headers: Array[BgenHeader]): Unit = { val notVersionTwo = headers.filter(_.version != 2).map(x => x.path -> x.version) if (notVersionTwo.length > 0) fatal( s"""The following BGEN files are not BGENv2: - | ${ notVersionTwo.mkString("\n ") }""".stripMargin) + | ${notVersionTwo.mkString("\n ")}""".stripMargin + ) } def getAllFileListEntries(fs: FS, files: Array[String]): Array[FileListEntry] = { @@ -157,13 +164,17 @@ object LoadBgen { badFiles += file matches.flatMap { fileListEntry => - val file = fileListEntry.getPath.toString + val file = fileListEntry.getPath if (!file.endsWith(".bgen")) warn(s"input file does not have .bgen extension: $file") - if (fs.isDir(file)) + if (fileListEntry.isDirectory) fs.listDirectory(file) - .filter(fileListEntry => ".*part-[0-9]+(-[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})?".r.matches(fileListEntry.getPath.toString)) + .filter(fileListEntry => + ".*part-[0-9]+(-[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})?".r.matches( + fileListEntry.getPath.toString + ) + ) else Array(fileListEntry) } @@ -172,7 +183,8 @@ object LoadBgen { if (!badFiles.isEmpty) fatal( s"""The following paths refer to no files: - | ${ badFiles.result().mkString("\n ") }""".stripMargin) + | ${badFiles.result().mkString("\n ")}""".stripMargin + ) fileListEntries } @@ -180,38 +192,48 @@ object LoadBgen { def getAllFilePaths(fs: FS, files: Array[String]): Array[String] = getAllFileListEntries(fs, files).map(_.getPath.toString) - - def getBgenFileMetadata(ctx: ExecuteContext, files: Array[String], indexFiles: Array[String]): Array[BgenFileMetadata] = { + def getBgenFileMetadata( + ctx: ExecuteContext, + files: Array[FileListEntry], + indexFilePaths: Array[String], + ): Array[BgenFileMetadata] = { val fs = ctx.fs - require(files.length == indexFiles.length) - val headers = getFileHeaders(fs, files) + require(files.length == indexFilePaths.length) + val headers = getFileHeaders(fs, files.map(_.getPath)) - val cacheByRG: 
mutable.Map[Option[String], (String, Array[Long]) => Array[AnyRef]] = mutable.Map.empty + val cacheByRG: mutable.Map[Option[String], (String, Array[Long]) => Array[AnyRef]] = + mutable.Map.empty - headers.zip(indexFiles).map { case (h, indexFile) => - val (keyType, annotationType) = IndexReader.readTypes(fs, indexFile) + headers.zip(indexFilePaths).map { case (h, indexFilePath) => + val (keyType, annotationType) = IndexReader.readTypes(fs, indexFilePath) val rg = keyType.asInstanceOf[TStruct].field("locus").typ match { case TLocus(rg) => Some(rg) case _ => None } - val metadata = IndexReader.readMetadata(fs, indexFile, keyType, annotationType) + val metadata = IndexReader.readMetadata(fs, indexFilePath, keyType, annotationType) val indexVersion = SemanticVersion(metadata.fileVersion) val (leafSpec, internalSpec) = BgenSettings.indexCodecSpecs(indexVersion, rg) - val getKeys = cacheByRG.getOrElseUpdate(rg, StagedBGENReader.queryIndexByPosition(ctx, leafSpec, internalSpec)) + val getKeys = cacheByRG.getOrElseUpdate( + rg, + StagedBGENReader.queryIndexByPosition(ctx, leafSpec, internalSpec), + ) val attributes = metadata.attributes val skipInvalidLoci = attributes("skip_invalid_loci").asInstanceOf[Boolean] - val contigRecoding = Option(attributes("contig_recoding")).map(_.asInstanceOf[Map[String, String]]).getOrElse(Map.empty[String, String]) + val contigRecoding = Option(attributes("contig_recoding")).map(_.asInstanceOf[Map[ + String, + String, + ]]).getOrElse(Map.empty[String, String]) val nVariants = metadata.nKeys val rangeBounds = if (nVariants > 0) { - val Array(start, end) = getKeys(indexFile, Array[Long](0L, nVariants - 1)) + val Array(start, end) = getKeys(indexFilePath, Array[Long](0L, nVariants - 1)) Interval(start, end, includesStart = true, includesEnd = true) } else null BgenFileMetadata( - indexFile, + indexFilePath, indexVersion, h, rg, @@ -220,12 +242,14 @@ object LoadBgen { metadata.nKeys, keyType, annotationType, - rangeBounds) + rangeBounds, + ) } } - def getIndexFileNames(fs: FS, files: Array[String], indexFileMap: Map[String, String]): Array[String] = { - def absolutePath(rel: String): String = fs.fileListEntry(rel).getPath.toString + def getIndexFileNames(fs: FS, files: Array[FileListEntry], indexFileMap: Map[String, String]) + : Array[String] = { + def absolutePath(rel: String): String = fs.fileStatus(rel).getPath val fileMapping = Option(indexFileMap) .getOrElse(Map.empty[String, String]) @@ -235,18 +259,27 @@ object LoadBgen { if (badExtensions.nonEmpty) fatal( s"""The following index file paths defined by 'index_file_map' are missing a .idx2 file extension: - | ${ badExtensions.mkString("\n ") })""".stripMargin) + | ${badExtensions.mkString("\n ")})""".stripMargin + ) - files.map(absolutePath).map(f => fileMapping.getOrElse(f, f + ".idx2")) + files.map(f => fileMapping.getOrElse(f.getPath, f.getPath + ".idx2")) } - def getIndexFiles(fs: FS, files: Array[String], indexFileMap: Map[String, String]): Array[String] = { + def getIndexFiles(fs: FS, files: Array[FileListEntry], indexFileMap: Map[String, String]) + : Array[String] = { val indexFiles = getIndexFileNames(fs, files, indexFileMap) - val missingIdxFiles = files.zip(indexFiles).filterNot { case (f, index) => fs.exists(index) && index.endsWith("idx2") }.map(_._1) - if (missingIdxFiles.nonEmpty) + + val bgenFilesWhichAreMisssingIdx2Files = files.zip(indexFiles).filterNot { + case (_, index) => index.endsWith("idx2") && fs.isFile(index + "/index") && fs.isFile( + index + "/metadata.json.gz" + ) + }.map(_._1.getPath) 
+ + if (bgenFilesWhichAreMisssingIdx2Files.nonEmpty) fatal( s"""The following BGEN files have no .idx2 index file. Use 'index_bgen' to create the index file once before calling 'import_bgen': - | ${ missingIdxFiles.mkString("\n ") })""".stripMargin) + | ${bgenFilesWhichAreMisssingIdx2Files.mkString("\n ")}""".stripMargin + ) indexFiles } @@ -260,7 +293,8 @@ object LoadBgen { if (rgs.distinct.length != 1) fatal( s"""Found multiple reference genomes were specified in the BGEN index files: - | ${ rgs.distinct.map(_.getOrElse("None")).mkString("\n ") }""".stripMargin) + | ${rgs.distinct.map(_.getOrElse("None")).mkString("\n ")}""".stripMargin + ) rgs.head } @@ -271,12 +305,14 @@ object LoadBgen { if (indexKeyTypes.length != 1) fatal( s"""Found more than one BGEN index key type: - | ${ indexKeyTypes.mkString("\n ") })""".stripMargin) + | ${indexKeyTypes.mkString("\n ")})""".stripMargin + ) if (indexAnnotationTypes.length != 1) fatal( s"""Found more than one BGEN index annotation type: - | ${ indexAnnotationTypes.mkString("\n ") })""".stripMargin) + | ${indexAnnotationTypes.mkString("\n ")})""".stripMargin + ) (indexKeyTypes.head, indexAnnotationTypes.head) } @@ -287,7 +323,8 @@ object MatrixBGENReader { MatrixType( globalType = TStruct.empty, colType = TStruct( - "s" -> TString), + "s" -> TString + ), colKey = Array("s"), rowType = TStruct( "locus" -> TLocus.schemaFromRG(rg), @@ -295,12 +332,15 @@ object MatrixBGENReader { "rsid" -> TString, "varid" -> TString, "offset" -> TInt64, - "file_idx" -> TInt32), + "file_idx" -> TInt32, + ), rowKey = Array("locus", "alleles"), entryType = TStruct( "GT" -> TCall, "GP" -> TArray(TFloat64), - "dosage" -> TFloat64)) + "dosage" -> TFloat64, + ), + ) } def fullMatrixType(rg: Option[String]): MatrixType = { @@ -316,34 +356,40 @@ object MatrixBGENReader { val ttNoUID = mt.copy(mt.colType.appendKey(MatrixReader.colUIDFieldName, TInt64)) .toTableType(LowerMatrixIR.entriesFieldName, LowerMatrixIR.colsFieldName) - ttNoUID.copy(rowType = ttNoUID.rowType.appendKey(MatrixReader.rowUIDFieldName, TTuple(TInt64, TInt64))) + ttNoUID.copy(rowType = + ttNoUID.rowType.appendKey(MatrixReader.rowUIDFieldName, TTuple(TInt64, TInt64)) + ) } - - def fromJValue(env: IRParserEnvironment, jv: JValue): MatrixBGENReader = { + def fromJValue(env: IRParserEnvironment, jv: JValue): MatrixBGENReader = MatrixBGENReader(env.ctx, MatrixBGENReaderParameters.fromJValue(jv)) - } - def apply(ctx: ExecuteContext, + def apply( + ctx: ExecuteContext, files: Seq[String], sampleFile: Option[String], indexFileMap: Map[String, String], nPartitions: Option[Int], blockSizeInMB: Option[Int], - includedVariants: Option[String]): MatrixBGENReader = { - MatrixBGENReader(ctx, - MatrixBGENReaderParameters(files, sampleFile, indexFileMap, nPartitions, blockSizeInMB, includedVariants)) - } + includedVariants: Option[String], + ): MatrixBGENReader = + MatrixBGENReader( + ctx, + MatrixBGENReaderParameters(files, sampleFile, indexFileMap, nPartitions, blockSizeInMB, + includedVariants), + ) def apply(ctx: ExecuteContext, params: MatrixBGENReaderParameters): MatrixBGENReader = { val fs = ctx.fs - val allFiles = LoadBgen.getAllFilePaths(fs, params.files.toArray) - val indexFiles = LoadBgen.getIndexFiles(fs, allFiles, params.indexFileMap) - val fileMetadata = LoadBgen.getBgenFileMetadata(ctx, allFiles, indexFiles) + val allFiles = LoadBgen.getAllFileListEntries(fs, params.files.toArray) + val indexFilePaths = LoadBgen.getIndexFiles(fs, allFiles, params.indexFileMap) + val fileMetadata = 
LoadBgen.getBgenFileMetadata(ctx, allFiles, indexFilePaths) assert(fileMetadata.nonEmpty) if (fileMetadata.exists(md => md.indexVersion != fileMetadata.head.indexVersion)) { - fatal("BGEN index version mismatch. The index versions of all files must be the same, use 'index_bgen' to reindex all files to ensure that all index versions match before calling 'import_bgen' again") + fatal( + "BGEN index version mismatch. The index versions of all files must be the same, use 'index_bgen' to reindex all files to ensure that all index versions match before calling 'import_bgen' again" + ) } val sampleIds = params.sampleFile.map(file => LoadBgen.readSampleFile(fs, file)) @@ -353,26 +399,29 @@ object MatrixBGENReader { val nSamples = sampleIds.length - val unequalSamples = fileMetadata.filter(_.header.nSamples != nSamples).map(x => (x.path, x.header.nSamples)) + val unequalSamples = + fileMetadata.filter(_.header.nSamples != nSamples).map(x => (x.path, x.header.nSamples)) if (unequalSamples.length > 0) { val unequalSamplesString = - unequalSamples.map(x => s"""(${ x._2 } ${ x._1 }""").mkString("\n ") + unequalSamples.map(x => s"""(${x._2} ${x._1}""").mkString("\n ") fatal( s"""The following BGEN files did not contain the expected number of samples $nSamples: - | $unequalSamplesString""".stripMargin) + | $unequalSamplesString""".stripMargin + ) } val noVariants = fileMetadata.filter(_.nVariants == 0).map(_.path) if (noVariants.length > 0) fatal( s"""The following BGEN files did not contain at least 1 variant: - | ${ noVariants.mkString("\n ") })""".stripMargin) + | ${noVariants.mkString("\n ")})""".stripMargin + ) LoadBgen.checkVersionTwo(fileMetadata.map(_.header)) val nVariants = fileMetadata.map(_.nVariants).sum - info(s"Number of BGEN files parsed: ${ fileMetadata.length }") + info(s"Number of BGEN files parsed: ${fileMetadata.length}") info(s"Number of samples in BGEN files: $nSamples") info(s"Number of variants across all BGEN files: $nVariants") @@ -382,14 +431,28 @@ object MatrixBGENReader { val (indexKeyType, indexAnnotationType) = LoadBgen.getIndexTypes(fileMetadata) - val filePartInfo = BgenRDDPartitions(ctx, referenceGenome, fileMetadata, + val filePartInfo = BgenRDDPartitions( + ctx, + referenceGenome, + fileMetadata, if (params.nPartitions.isEmpty && params.blockSizeInMB.isEmpty) Some(128) else - params.blockSizeInMB, params.nPartitions, indexKeyType) + params.blockSizeInMB, + params.nPartitions, + indexKeyType, + ) new MatrixBGENReader( - params, referenceGenome, fullMatrixType, indexKeyType, indexAnnotationType, sampleIds, filePartInfo, params.includedVariants) + params, + referenceGenome, + fullMatrixType, + indexKeyType, + indexAnnotationType, + sampleIds, + filePartInfo, + params.includedVariants, + ) } } @@ -405,7 +468,8 @@ object MatrixBGENReaderParameters { case JNull => None case JString(s) => Some(s) } - new MatrixBGENReaderParameters(files, sampleFile, indexFileMap, nPartitions, blockSizeInMB, includedVariants) + new MatrixBGENReaderParameters(files, sampleFile, indexFileMap, nPartitions, blockSizeInMB, + includedVariants) } } @@ -415,7 +479,8 @@ case class MatrixBGENReaderParameters( indexFileMap: Map[String, String], nPartitions: Option[Int], blockSizeInMB: Option[Int], - includedVariants: Option[String]) { + includedVariants: Option[String], +) { def toJValue: JValue = { JObject(List( @@ -427,7 +492,8 @@ case class MatrixBGENReaderParameters( }.toList), "nPartitions" -> nPartitions.map(JInt(_)).getOrElse(JNull), "blockSizeInMB" -> blockSizeInMB.map(JInt(_)).getOrElse(JNull), 
- "includedVariants" -> includedVariants.map(t => JString(t)).getOrElse(JNull))) + "includedVariants" -> includedVariants.map(t => JString(t)).getOrElse(JNull), + )) } } @@ -439,7 +505,8 @@ class MatrixBGENReader( indexAnnotationType: Type, sampleIds: Array[String], filePartitionInfo: IndexedSeq[FilePartitionInfo], - variants: Option[String]) extends MatrixHybridReader { + variants: Option[String], +) extends MatrixHybridReader { def pathsUsed: Seq[String] = filePartitionInfo.map(_.metadata.path) lazy val nVariants: Long = filePartitionInfo.map(_.metadata.nVariants).sum @@ -462,12 +529,14 @@ class MatrixBGENReader( nSamples, requestedType, referenceGenome, - indexAnnotationType) + indexAnnotationType, + ) } _settings } - override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = { + override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = { val settings = getSettings(requestedType) VirtualTypeWithReq(settings.rowPType) } @@ -475,24 +544,27 @@ class MatrixBGENReader( override def uidRequiredness: VirtualTypeWithReq = VirtualTypeWithReq.fullyRequired(TTuple(TInt64, TInt64)) - override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = + override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = VirtualTypeWithReq(PType.canonical(requestedType.globalType, required = true)) override def lowerGlobals(ctx: ExecuteContext, requestedGlobalType: TStruct): IR = { requestedGlobalType.selfField(LowerMatrixIR.colsFieldName) match { case Some(f) => val ta = f.typ.asInstanceOf[TArray] - MakeStruct(FastSeq((LowerMatrixIR.colsFieldName, { - val arraysToZip = new BoxedArrayBuilder[IndexedSeq[Any]]() - val colType = ta.elementType.asInstanceOf[TStruct] - if (colType.hasField("s")) - arraysToZip += sampleIds - if (colType.hasField(colUIDFieldName)) - arraysToZip += sampleIds.indices.map(_.toLong) - - val fields = arraysToZip.result() - Literal(ta, sampleIds.indices.map(i => Row.fromSeq(fields.map(_.apply(i))))) - }))) + MakeStruct(FastSeq(( + LowerMatrixIR.colsFieldName, { + val arraysToZip = new BoxedArrayBuilder[IndexedSeq[Any]]() + val colType = ta.elementType.asInstanceOf[TStruct] + if (colType.hasField("s")) + arraysToZip += sampleIds + if (colType.hasField(colUIDFieldName)) + arraysToZip += sampleIds.indices.map(_.toLong) + + val fields = arraysToZip.result() + Literal(ta, sampleIds.indices.map(i => Row.fromSeq(fields.map(_.apply(i))))) + }, + ))) case None => MakeStruct(FastSeq()) } } @@ -518,7 +590,8 @@ class MatrixBGENReader( val contexts = new BoxedArrayBuilder[Row]() val rangeBounds = new BoxedArrayBuilder[Interval]() filePartitionInfo.zipWithIndex.foreach { case (file, fileIdx) => - val filePartitioner = new RVDPartitioner(ctx.stateManager, tcoerce[TStruct](indexKeyType), file.intervals) + val filePartitioner = + new RVDPartitioner(ctx.stateManager, tcoerce[TStruct](indexKeyType), file.intervals) val filterKeyLen = t0.spec.table_type.key.length val strictShortKey = filePartitioner.coarsen(filterKeyLen).strictify() @@ -526,36 +599,45 @@ class MatrixBGENReader( rangeBounds ++= strictBgenKey.rangeBounds strictShortKey.partitionBoundsIRRepresentation.value.asInstanceOf[IndexedSeq[_]] - .foreach { interval => - contexts += Row(fileIdx, interval) - } + .foreach(interval => contexts += Row(fileIdx, interval)) } - val partitioner = new RVDPartitioner(ctx.stateManager, tcoerce[TStruct](indexKeyType), rangeBounds.result()) + 
val partitioner = + new RVDPartitioner(ctx.stateManager, tcoerce[TStruct](indexKeyType), rangeBounds.result()) val reader = BgenPartitionReaderWithVariantFilter( filePartitionInfo.map(_.metadata).toArray, referenceGenome, - PartitionNativeIntervalReader(ctx.stateManager, v, t0.spec, "__dummy")) + PartitionNativeIntervalReader(ctx.stateManager, v, t0.spec, "__dummy"), + ) TableStage( globals = globals, partitioner = partitioner, dependency = TableStageDependency.none, contexts = ToStream(Literal(TArray(reader.contextType), contexts.result().toFastSeq)), - (ref: Ref) => ReadPartition(ref, requestedType.rowType, reader) + (ref: Ref) => ReadPartition(ref, requestedType.rowType, reader), ) case None => - val partitioner = new RVDPartitioner(ctx.stateManager, tcoerce[TStruct](indexKeyType), filePartitionInfo.flatMap(_.intervals)) - val reader = BgenPartitionReader(fileMetadata = filePartitionInfo.map(_.metadata).toArray, referenceGenome) + val partitioner = new RVDPartitioner( + ctx.stateManager, + tcoerce[TStruct](indexKeyType), + filePartitionInfo.flatMap(_.intervals), + ) + val reader = BgenPartitionReader( + fileMetadata = filePartitionInfo.map(_.metadata).toArray, + referenceGenome, + ) val contexts = new BoxedArrayBuilder[Row]() var partIdx = 0 var fileIdx = 0 filePartitionInfo.foreach { file => - assert(file.intervals.length == file.partN.length && file.intervals.length == file.partStarts.length) + assert( + file.intervals.length == file.partN.length && file.intervals.length == file.partStarts.length + ) file.intervals.indices.foreach { idxInFile => contexts += Row(fileIdx, file.partStarts(idxInFile), file.partN(idxInFile), partIdx) partIdx += 1 @@ -568,21 +650,27 @@ class MatrixBGENReader( partitioner = partitioner, dependency = TableStageDependency.none, contexts = ToStream(Literal(TArray(reader.contextType), contexts.result().toFastSeq)), - (ref: Ref) => ReadPartition(ref, requestedType.rowType, reader) + (ref: Ref) => ReadPartition(ref, requestedType.rowType, reader), ) } } } -case class BgenPartitionReaderWithVariantFilter(fileMetadata: Array[BgenFileMetadata], rg: Option[String], child: PartitionNativeIntervalReader) extends PartitionReader { +case class BgenPartitionReaderWithVariantFilter( + fileMetadata: Array[BgenFileMetadata], + rg: Option[String], + child: PartitionNativeIntervalReader, +) extends PartitionReader { lazy val contextType: TStruct = TStruct( "file_index" -> TInt32, - "interval" -> RVDPartitioner.intervalIRRepresentation(child.tableSpec.table_type.keyType)) + "interval" -> RVDPartitioner.intervalIRRepresentation(child.tableSpec.table_type.keyType), + ) lazy val uidType = TTuple(TInt64, TInt64) lazy val fullRowType: TStruct = MatrixBGENReader.fullTableType(rg).rowType - def rowRequiredness(requestedType: TStruct): RStruct = StagedBGENReader.rowRequiredness(requestedType) + def rowRequiredness(requestedType: TStruct): RStruct = + StagedBGENReader.rowRequiredness(requestedType) def uidFieldName: String = TableReader.uidFieldName @@ -591,7 +679,8 @@ case class BgenPartitionReaderWithVariantFilter(fileMetadata: Array[BgenFileMeta cb: EmitCodeBuilder, mb: EmitMethodBuilder[_], context: EmitCode, - requestedType: TStruct): IEmitCode = { + requestedType: TStruct, + ): IEmitCode = { val cbfis = mb.genFieldThisRef[HadoopFSDataBinaryReader]("bgen_cbfis") val nSamples = mb.genFieldThisRef[Int]("bgen_nsamples") @@ -610,7 +699,6 @@ case class BgenPartitionReaderWithVariantFilter(fileMetadata: Array[BgenFileMeta var out: EmitSettable = null // filled in later 
context.toI(cb).flatMap(cb) { case context: SBaseStructValue => - val rangeBound = EmitCode.fromI(mb)(cb => context.loadField(cb, "interval")) child.emitStream(ctx, cb, mb, rangeBound, child.fullRowType.deleteKey(child.uidFieldName)) @@ -625,9 +713,10 @@ case class BgenPartitionReaderWithVariantFilter(fileMetadata: Array[BgenFileMeta override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { vs.initialize(cb, outerRegion) - cb.assign(fileIdx, context.loadField(cb, "file_index").get(cb).asInt.value) - val metadata = cb.memoize(mb.getObject[IndexedSeq[BgenFileMetadata]](fileMetadata.toFastSeq) - .invoke[Int, BgenFileMetadata]("apply", fileIdx)) + cb.assign(fileIdx, context.loadField(cb, "file_index").getOrAssert(cb).asInt.value) + val metadata = + cb.memoize(mb.getObject[IndexedSeq[BgenFileMetadata]](fileMetadata.toFastSeq) + .invoke[Int, BgenFileMetadata]("apply", fileIdx)) val fileName = cb.memoize(metadata.invoke[String]("path")) val indexName = cb.memoize(metadata.invoke[String]("indexPath")) cb.assign(nSamples, metadata.invoke[Int]("nSamples")) @@ -635,58 +724,92 @@ case class BgenPartitionReaderWithVariantFilter(fileMetadata: Array[BgenFileMeta cb.assign(compression, metadata.invoke[Int]("compression")) cb.assign(skipInvalidLoci, metadata.invoke[Boolean]("skipInvalidLoci")) - cb.assign(cbfis, Code.newInstance[HadoopFSDataBinaryReader, SeekableDataInputStream]( - mb.getFS.invoke[String, SeekableDataInputStream]("openNoCompression", fileName))) + cb.assign( + cbfis, + Code.newInstance[HadoopFSDataBinaryReader, SeekableDataInputStream]( + mb.getFS.invoke[String, SeekableDataInputStream]("openNoCompression", fileName) + ), + ) index.initialize(cb, indexName) cb.assign(indexNKeys, index.nKeys(cb)) } override val elementRegion: Settable[Region] = vs.elementRegion - override val requiresMemoryManagementPerElement: Boolean = vs.requiresMemoryManagementPerElement + override val requiresMemoryManagementPerElement: Boolean = + vs.requiresMemoryManagementPerElement override val LproduceElement: CodeLabel = mb.defineAndImplementLabel { cb => val Lstart = CodeLabel() cb.define(Lstart) - cb.if_(currVariantIndex < stopVariantIndex, { - val addr = index.queryIndex(cb, vs.elementRegion, currVariantIndex) - .loadField(cb, "offset") - .get(cb).asLong.value - cb += cbfis.invoke[Long, Unit]("seek", addr) - - val reqTypeNoUID = if (requestedType.hasField(uidFieldName)) requestedType.deleteKey(uidFieldName) else requestedType - val sc = StagedBGENReader.decodeRow(cb, elementRegion, cbfis, nSamples, fileIdx, compression, skipInvalidLoci, contigRecoding, reqTypeNoUID, rg) - .toI(cb).get(cb) - val scUID = if (requestedType.hasField(uidFieldName)) - sc.asBaseStruct.insert(cb, elementRegion, requestedType, - (uidFieldName, EmitValue.present(SStackStruct.constructFromArgs(cb, elementRegion, uidType, - EmitValue.present(primitive(cb.memoize(fileIdx.toL))), - EmitValue.present(primitive(currVariantIndex)))))) - else - sc - out = mb.newEmitField(scUID.st, true) - cb.assign(out, EmitCode.present(mb, scUID)) - - cb.assign(currVariantIndex, currVariantIndex + 1) - cb.goto(LproduceElementDone) - }) - + cb.if_( + currVariantIndex < stopVariantIndex, { + val addr = index.queryIndex(cb, vs.elementRegion, currVariantIndex) + .loadField(cb, "offset") + .getOrAssert(cb).asLong.value + cb += cbfis.invoke[Long, Unit]("seek", addr) + + val reqTypeNoUID = if (requestedType.hasField(uidFieldName)) + requestedType.deleteKey(uidFieldName) + else requestedType + val sc = StagedBGENReader.decodeRow(cb, 
elementRegion, cbfis, nSamples, fileIdx, + compression, skipInvalidLoci, contigRecoding, reqTypeNoUID, rg) + .toI(cb).getOrAssert(cb) + val scUID = if (requestedType.hasField(uidFieldName)) + sc.asBaseStruct.insert( + cb, + elementRegion, + requestedType, + ( + uidFieldName, + EmitValue.present(SStackStruct.constructFromArgs( + cb, + elementRegion, + uidType, + EmitValue.present(primitive(cb.memoize(fileIdx.toL))), + EmitValue.present(primitive(currVariantIndex)), + )), + ), + ) + else + sc + out = mb.newEmitField(scUID.st, true) + cb.assign(out, EmitCode.present(mb, scUID)) + + cb.assign(currVariantIndex, currVariantIndex + 1) + cb.goto(LproduceElementDone) + }, + ) cb.goto(vs.LproduceElement) cb.define(vs.LproduceElementDone) - val nextVariant = vs.element.toI(cb).get(cb).asBaseStruct - val bound = SStackStruct.constructFromArgs(cb, vs.elementRegion, TTuple(nextVariant.st.virtualType, TInt32), + val nextVariant = vs.element.toI(cb).getOrAssert(cb).asBaseStruct + val bound = SStackStruct.constructFromArgs( + cb, + vs.elementRegion, + TTuple(nextVariant.st.virtualType, TInt32), EmitValue.present(if (nextVariant.st.size == 1) - nextVariant.insert(cb, elementRegion, - nextVariant.st.virtualType.asInstanceOf[TStruct].structInsert(TArray(TString), FastSeq("alleles")), - ("alleles", EmitValue.missing(SJavaArrayString(true))) + nextVariant.insert( + cb, + elementRegion, + nextVariant.st.virtualType.asInstanceOf[TStruct].structInsert( + TArray(TString), + FastSeq("alleles"), + ), + ("alleles", EmitValue.missing(SJavaArrayString(true))), ) else nextVariant), - EmitValue.present(primitive(const(nextVariant.st.size))) + EmitValue.present(primitive(const(nextVariant.st.size))), ) - cb.assign(currVariantIndex, index.queryBound(cb, bound, false).loadField(cb, 0).get(cb).asLong.value) - cb.assign(stopVariantIndex, index.queryBound(cb, bound, true).loadField(cb, 0).get(cb).asLong.value) + cb.assign( + currVariantIndex, + index.queryBound(cb, bound, false).loadField(cb, 0).getOrAssert(cb).asLong.value, + ) + cb.assign( + stopVariantIndex, + index.queryBound(cb, bound, true).loadField(cb, 0).getOrAssert(cb).asLong.value, + ) cb.goto(Lstart) cb.define(vs.LendOfStream) @@ -708,19 +831,21 @@ case class BgenPartitionReaderWithVariantFilter(fileMetadata: Array[BgenFileMeta def toJValue: JValue = Extraction.decompose(this)(PartitionReader.formats) } - -case class BgenPartitionReader(fileMetadata: Array[BgenFileMetadata], rg: Option[String]) extends PartitionReader { +case class BgenPartitionReader(fileMetadata: Array[BgenFileMetadata], rg: Option[String]) + extends PartitionReader { lazy val contextType: TStruct = TStruct( "file_index" -> TInt32, "first_variant_index" -> TInt64, "n_variants" -> TInt64, - "partition_index" -> TInt32) + "partition_index" -> TInt32, + ) lazy val uidType = TTuple(TInt64, TInt64) lazy val fullRowType: TStruct = MatrixBGENReader.fullTableType(rg).rowType - def rowRequiredness(requestedType: TStruct): RStruct = StagedBGENReader.rowRequiredness(requestedType) + def rowRequiredness(requestedType: TStruct): RStruct = + StagedBGENReader.rowRequiredness(requestedType) def uidFieldName: String = TableReader.uidFieldName @@ -729,7 +854,8 @@ case class BgenPartitionReader(fileMetadata: Array[BgenFileMetadata], rg: Option cb: EmitCodeBuilder, mb: EmitMethodBuilder[_], context: EmitCode, - requestedType: TStruct): IEmitCode = { + requestedType: TStruct, + ): IEmitCode = { val eltRegion = mb.genFieldThisRef[Region]("bgen_region") val cbfis = 
mb.genFieldThisRef[HadoopFSDataBinaryReader]("bgen_cbfis") @@ -747,18 +873,21 @@ case class BgenPartitionReader(fileMetadata: Array[BgenFileMetadata], rg: Option var out: EmitSettable = null // filled in later context.toI(cb).map(cb) { case context: SBaseStructValue => - val ctxField = cb.memoizeField(context, "ctxField") SStreamValue(new StreamProducer { override def method: EmitMethodBuilder[_] = mb - override val length: Option[EmitCodeBuilder => Code[Int]] = Some(cb => ctxField.asBaseStruct.loadField(cb, "n_variants").get(cb).asLong.value.toI) + override val length: Option[EmitCodeBuilder => Code[Int]] = + Some(cb => + ctxField.asBaseStruct.loadField(cb, "n_variants").getOrAssert(cb).asLong.value.toI + ) override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { - cb.assign(fileIdx, context.loadField(cb, "file_index").get(cb).asInt.value) - val metadata = cb.memoize(mb.getObject[IndexedSeq[BgenFileMetadata]](fileMetadata.toFastSeq) - .invoke[Int, BgenFileMetadata]("apply", fileIdx)) + cb.assign(fileIdx, context.loadField(cb, "file_index").getOrAssert(cb).asInt.value) + val metadata = + cb.memoize(mb.getObject[IndexedSeq[BgenFileMetadata]](fileMetadata.toFastSeq) + .invoke[Int, BgenFileMetadata]("apply", fileIdx)) val fileName = cb.memoize(metadata.invoke[String]("path")) val indexName = cb.memoize(metadata.invoke[String]("indexPath")) cb.assign(nSamples, metadata.invoke[Int]("nSamples")) @@ -766,12 +895,22 @@ case class BgenPartitionReader(fileMetadata: Array[BgenFileMetadata], rg: Option cb.assign(compression, metadata.invoke[Int]("compression")) cb.assign(skipInvalidLoci, metadata.invoke[Boolean]("skipInvalidLoci")) - cb.assign(cbfis, Code.newInstance[HadoopFSDataBinaryReader, SeekableDataInputStream]( - mb.getFS.invoke[String, SeekableDataInputStream]("openNoCompression", fileName))) + cb.assign( + cbfis, + Code.newInstance[HadoopFSDataBinaryReader, SeekableDataInputStream]( + mb.getFS.invoke[String, SeekableDataInputStream]("openNoCompression", fileName) + ), + ) index.initialize(cb, indexName) - cb.assign(currVariantIndex, context.loadField(cb, "first_variant_index").get(cb).asLong.value) - cb.assign(endVariantIndex, currVariantIndex + context.loadField(cb, "n_variants").get(cb).asLong.value) + cb.assign( + currVariantIndex, + context.loadField(cb, "first_variant_index").getOrAssert(cb).asLong.value, + ) + cb.assign( + endVariantIndex, + currVariantIndex + context.loadField(cb, "n_variants").getOrAssert(cb).asLong.value, + ) } override val elementRegion: Settable[Region] = eltRegion @@ -783,25 +922,42 @@ case class BgenPartitionReader(fileMetadata: Array[BgenFileMetadata], rg: Option val addr = index.queryIndex(cb, eltRegion, currVariantIndex) .loadField(cb, "offset") - .get(cb).asLong.value + .getOrAssert(cb).asLong.value cb += cbfis.invoke[Long, Unit]("seek", addr) - val reqTypeNoUID = if (requestedType.hasField(uidFieldName)) requestedType.deleteKey(uidFieldName) else requestedType - val e = StagedBGENReader.decodeRow(cb, elementRegion, cbfis, nSamples, fileIdx, compression, skipInvalidLoci, contigRecoding, reqTypeNoUID, rg) - e.toI(cb).consume(cb, { - cb += elementRegion.clearRegion() - cb.goto(Lstart) - }, { sc => - val scUID = if (requestedType.hasField(uidFieldName)) - sc.asBaseStruct.insert(cb, eltRegion, requestedType, - (uidFieldName, EmitValue.present(SStackStruct.constructFromArgs(cb, eltRegion, uidType, - EmitValue.present(primitive(cb.memoize(fileIdx.toL))), - EmitValue.present(primitive(currVariantIndex)))))) - else - sc - out = 
mb.newEmitField(scUID.st, true) - cb.assign(out, EmitCode.present(mb, scUID)) - }) + val reqTypeNoUID = if (requestedType.hasField(uidFieldName)) + requestedType.deleteKey(uidFieldName) + else requestedType + val e = StagedBGENReader.decodeRow(cb, elementRegion, cbfis, nSamples, fileIdx, + compression, skipInvalidLoci, contigRecoding, reqTypeNoUID, rg) + e.toI(cb).consume( + cb, { + cb += elementRegion.clearRegion() + cb.goto(Lstart) + }, + { sc => + val scUID = if (requestedType.hasField(uidFieldName)) + sc.asBaseStruct.insert( + cb, + eltRegion, + requestedType, + ( + uidFieldName, + EmitValue.present(SStackStruct.constructFromArgs( + cb, + eltRegion, + uidType, + EmitValue.present(primitive(cb.memoize(fileIdx.toL))), + EmitValue.present(primitive(currVariantIndex)), + )), + ), + ) + else + sc + out = mb.newEmitField(scUID.st, true) + cb.assign(out, EmitCode.present(mb, scUID)) + }, + ) cb.assign(currVariantIndex, currVariantIndex + 1) cb.goto(LproduceElementDone) diff --git a/hail/src/main/scala/is/hail/io/bgen/StagedBGENReader.scala b/hail/src/main/scala/is/hail/io/bgen/StagedBGENReader.scala index 27f73478732..787b6d70d9a 100644 --- a/hail/src/main/scala/is/hail/io/bgen/StagedBGENReader.scala +++ b/hail/src/main/scala/is/hail/io/bgen/StagedBGENReader.scala @@ -3,37 +3,42 @@ package is.hail.io.bgen import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.ExecuteContext +import is.hail.expr.ir.{ + uuid4, ArraySorter, EmitCode, EmitCodeBuilder, EmitFunctionBuilder, EmitSettable, IEmitCode, + LowerMatrixIR, ParamType, StagedArrayBuilder, +} import is.hail.expr.ir.functions.{RegistryFunctions, StringFunctions} import is.hail.expr.ir.streams.StreamUtils -import is.hail.expr.ir.{ArraySorter, EmitCode, EmitCodeBuilder, EmitFunctionBuilder, EmitSettable, IEmitCode, LowerMatrixIR, ParamType, StagedArrayBuilder, uuid4} import is.hail.io._ import is.hail.io.fs.SeekableDataInputStream import is.hail.io.index.{StagedIndexReader, StagedIndexWriter} import is.hail.lir +import is.hail.types.{RStruct, TableType, TypeWithRequiredness} import is.hail.types.physical._ import is.hail.types.physical.stypes.SingleCodeType import is.hail.types.physical.stypes.concrete._ -import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SBaseStructValue, primitive} +import is.hail.types.physical.stypes.interfaces.{primitive, NoBoxLongIterator, SBaseStructValue} import is.hail.types.physical.stypes.primitives.SInt64 import is.hail.types.virtual._ -import is.hail.types.{RStruct, TableType, TypeWithRequiredness} import is.hail.utils.{BoxedArrayBuilder, CompressionUtils, FastSeq} import is.hail.variant.Call2 + import org.objectweb.asm.Opcodes._ object StagedBGENReader { def decompress( input: Array[Byte], - uncompressedSize: Int + uncompressedSize: Int, ): Array[Byte] = is.hail.utils.decompress(input, uncompressedSize) - - def recodeContig(contig: String, contigMap: Map[String, String]): String = contigMap.getOrElse(contig, contig) + def recodeContig(contig: String, contigMap: Map[String, String]): String = + contigMap.getOrElse(contig, contig) def rowRequiredness(requested: TStruct): RStruct = { val t = TypeWithRequiredness(requested).asInstanceOf[RStruct] t.fieldOption(LowerMatrixIR.entriesFieldName) - .foreach { t => t.fromPType(entryArrayPType(requested.field(LowerMatrixIR.entriesFieldName).typ)) + .foreach { t => + t.fromPType(entryArrayPType(requested.field(LowerMatrixIR.entriesFieldName).typ)) } t } @@ -42,16 +47,21 @@ object StagedBGENReader { val entryType = 
requested.asInstanceOf[TArray].elementType.asInstanceOf[TStruct] - PCanonicalArray(PCanonicalStruct(false, - Array( - "GT" -> PCanonicalCall(), - "GP" -> PCanonicalArray(PFloat64Required, required = true), - "dosage" -> PFloat64Required - ).filter { case (name, _) => entryType.hasField(name) - }: _*), true) + PCanonicalArray( + PCanonicalStruct( + false, + Array( + "GT" -> PCanonicalCall(), + "GP" -> PCanonicalArray(PFloat64Required, required = true), + "dosage" -> PFloat64Required, + ).filter { case (name, _) => entryType.hasField(name) }: _* + ), + true, + ) } - def decodeRow(cb: EmitCodeBuilder, + def decodeRow( + cb: EmitCodeBuilder, region: Value[Region], cbfis: Value[HadoopFSDataBinaryReader], nSamples: Value[Int], @@ -60,13 +70,24 @@ object StagedBGENReader { skipInvalidLoci: Value[Boolean], contigRecoding: Value[Map[String, String]], requestedType: TStruct, - rg: Option[String] + rg: Option[String], ): EmitCode = { var out: EmitSettable = null // defined and assigned inside method - val emb = cb.emb.ecb.genEmitMethod("decode_bgen_row", IndexedSeq[ParamType](classInfo[Region], classInfo[HadoopFSDataBinaryReader], IntInfo, IntInfo, IntInfo, BooleanInfo, classInfo[Map[String, String]]), UnitInfo) + val emb = cb.emb.ecb.genEmitMethod( + "decode_bgen_row", + IndexedSeq[ParamType]( + classInfo[Region], + classInfo[HadoopFSDataBinaryReader], + IntInfo, + IntInfo, + IntInfo, + BooleanInfo, + classInfo[Map[String, String]], + ), + UnitInfo, + ) emb.voidWithBuilder { cb => - - val rgBc = rg.map { rg => cb.memoize(emb.getReferenceGenome(rg)) } + val rgBc = rg.map(rg => cb.memoize(emb.getReferenceGenome(rg))) val region = emb.getCodeParam[Region](1) val cbfis = emb.getCodeParam[HadoopFSDataBinaryReader](2) val nSamples = emb.getCodeParam[Int](3) @@ -94,7 +115,6 @@ object StagedBGENReader { val nAlleles2 = cb.newLocal[Int]("nAlleles2") val minPloidy = cb.newLocal[Int]("minPloidy") val maxPloidy = cb.newLocal[Int]("maxPloidy") - val longPloidy = cb.newLocal[Long]("longPloidy") val ploidy = cb.newLocal[Int]("ploidy") val phase = cb.newLocal[Int]("phase") val nBitsPerProb = cb.newLocal[Int]("nBitsPerProb") @@ -107,7 +127,6 @@ object StagedBGENReader { cb.assign(c1, Call2.fromUnphasedDiploidGtIndex(1)) cb.assign(c2, Call2.fromUnphasedDiploidGtIndex(2)) - cb.assign(offset, cbfis.invoke[Long]("getPosition")) if (requestedType.hasField("varid")) @@ -121,31 +140,41 @@ object StagedBGENReader { cb += cbfis.invoke[Int, Unit]("readLengthAndSkipString", 2) cb.assign(contig, cbfis.invoke[Int, String]("readLengthAndString", 2)) - cb.assign(contigRecoded, Code.invokeScalaObject2[String, Map[String, String], String](StagedBGENReader.getClass, "recodeContig", contig, contigRecoding)) + cb.assign( + contigRecoded, + Code.invokeScalaObject2[String, Map[String, String], String]( + StagedBGENReader.getClass, + "recodeContig", + contig, + contigRecoding, + ), + ) cb.assign(position, cbfis.invoke[Int]("readInt")) - - cb.if_(skipInvalidLoci, { - rgBc.foreach { rg => - cb.if_(!rg.invoke[String, Int, Boolean]("isValidLocus", contigRecoded, position), - { - cb.assign(nAlleles, cbfis.invoke[Int]("readShort")) - cb.assign(i, 0) - cb.while_(i < nAlleles, - { - cb += cbfis.invoke[Int, Unit]("readLengthAndSkipString", 4) - cb.assign(i, i + 1) - }) - cb.assign(dataSize, cbfis.invoke[Int]("readInt")) - cb += Code.toUnit(cbfis.invoke[Long, Long]("skipBytes", dataSize.toL)) - cb.goto(LreturnMissing) - }) - } - }, { + cb.if_( + skipInvalidLoci, { + rgBc.foreach { rg => + cb.if_( + !rg.invoke[String, Int, Boolean]("isValidLocus", 
contigRecoded, position), { + cb.assign(nAlleles, cbfis.invoke[Int]("readShort")) + cb.assign(i, 0) + cb.while_( + i < nAlleles, { + cb += cbfis.invoke[Int, Unit]("readLengthAndSkipString", 4) + cb.assign(i, i + 1) + }, + ) + cb.assign(dataSize, cbfis.invoke[Int]("readInt")) + cb += Code.toUnit(cbfis.invoke[Long, Long]("skipBytes", dataSize.toL)) + cb.goto(LreturnMissing) + }, + ) + } + }, rgBc.foreach { rg => cb += rg.invoke[String, Int, Unit]("checkLocus", contigRecoded, position) - } - }) + }, + ) val structFieldCodes = new BoxedArrayBuilder[EmitCode]() @@ -155,38 +184,58 @@ object StagedBGENReader { val pc = requestedType.field("locus").typ match { case TLocus(rg) => val pt = SCanonicalLocusPointer(PCanonicalLocus(rg)) - pt.pType.constructFromPositionAndString(cb, region, contigRecoded, position) + pt.pType.constructFromContigAndPosition(cb, region, contigRecoded, position) case t: TStruct => val contig = SJavaString.constructFromString(cb, region, contigRecoded) - SStackStruct.constructFromArgs(cb, region, t, - EmitCode.present(cb.emb, contig), EmitCode.present(cb.emb, primitive(position))) + SStackStruct.constructFromArgs( + cb, + region, + t, + EmitCode.present(cb.emb, contig), + EmitCode.present(cb.emb, primitive(position)), + ) } structFieldCodes += EmitCode.present(cb.emb, pc) } cb.assign(nAlleles, cbfis.invoke[Int]("readShort")) - cb.if_(nAlleles.cne(2), - cb._fatal("Only biallelic variants supported, found variant with ", nAlleles.toS, " alleles: ", - contigRecoded, ":", position.toS)) + cb.if_( + nAlleles.cne(2), + cb._fatal( + "Only biallelic variants supported, found variant with ", + nAlleles.toS, + " alleles: ", + contigRecoded, + ":", + position.toS, + ), + ) if (requestedType.hasField("alleles")) { val allelesType = SJavaArrayString(true) val a = cb.newLocal[Array[String]]("alleles", Code.newArray[String](nAlleles)) - cb.while_(i < nAlleles, { - cb += a.update(i, cbfis.invoke[Int, String]("readLengthAndString", 4)) - cb.assign(i, i + 1) - }) - + cb.while_( + i < nAlleles, { + cb += a.update(i, cbfis.invoke[Int, String]("readLengthAndString", 4)) + cb.assign(i, i + 1) + }, + ) structFieldCodes += EmitCode.present(cb.emb, allelesType.construct(cb, a)) } if (requestedType.hasField("rsid")) - structFieldCodes += EmitCode.present(cb.emb, SStringPointer(PCanonicalString(false)).constructFromString(cb, region, rsid)) + structFieldCodes += EmitCode.present( + cb.emb, + SStringPointer(PCanonicalString(false)).constructFromString(cb, region, rsid), + ) if (requestedType.hasField("varid")) - structFieldCodes += EmitCode.present(cb.emb, SStringPointer(PCanonicalString(false)).constructFromString(cb, region, varid)) + structFieldCodes += EmitCode.present( + cb.emb, + SStringPointer(PCanonicalString(false)).constructFromString(cb, region, varid), + ) if (requestedType.hasField("offset")) structFieldCodes += EmitCode.present(cb.emb, primitive(offset)) if (requestedType.hasField("file_idx")) @@ -212,7 +261,6 @@ object StagedBGENReader { val memoMB = emb.genEmitMethod("memoizeEntries", FastSeq[ParamType](), UnitInfo) memoMB.voidWithBuilder { cb => - val partRegion = emb.partitionRegion val LnoOp = CodeLabel() @@ -221,78 +269,106 @@ object StagedBGENReader { val (push, finish) = memoTyp.constructFromFunctions(cb, partRegion, 1 << 16, false) val d0 = cb.newLocal[Int]("memoize_entries_d0", 0) - cb.while_(d0 < 256, { - val d1 = cb.newLocal[Int]("memoize_entries_d1", 0) - cb.while_(d1 < 256, { - val d2 = cb.newLocal[Int]("memoize_entries_d2", const(255) - d0 - d1) - - val entryFieldCodes = 
new BoxedArrayBuilder[EmitCode]() - - if (includeGT) - entryFieldCodes += EmitCode.fromI(cb.emb) { cb => - val Lmissing = CodeLabel() - val Lpresent = CodeLabel() - val value = cb.newLocal[Int]("bgen_gt_value") - - cb.if_(d0 > d1, - cb.if_(d0 > d2, - { - cb.assign(value, c0) - cb.goto(Lpresent) - }, - cb.if_(d2 > d0, - { - cb.assign(value, c2) - cb.goto(Lpresent) - }, - // d0 == d2 - cb.goto(Lmissing))), - // d0 <= d1 - cb.if_(d2 > d1, - { - cb.assign(value, c2) - cb.goto(Lpresent) - }, - // d2 <= d1 - cb.if_(d1.ceq(d0) || d1.ceq(d2), - cb.goto(Lmissing), - { - cb.assign(value, c1) - cb.goto(Lpresent) - }))) - - IEmitCode(Lmissing, Lpresent, new SCanonicalCallValue(value), false) - } - - if (includeGP) - entryFieldCodes += EmitCode.fromI(cb.emb) { cb => - - val divisor = cb.newLocal[Double]("divisor", 255.0) - - val gpType = entryType.field("GP").typ.asInstanceOf[PCanonicalArray] - - val (pushElement, finish) = gpType.constructFromFunctions(cb, partRegion, 3, deepCopy = false) - pushElement(cb, IEmitCode.present(cb, primitive(cb.memoize(d0.toD / divisor)))) - pushElement(cb, IEmitCode.present(cb, primitive(cb.memoize(d1.toD / divisor)))) - pushElement(cb, IEmitCode.present(cb, primitive(cb.memoize(d2.toD / divisor)))) - - IEmitCode.present(cb, finish(cb)) - } - - - if (includeDosage) - entryFieldCodes += EmitCode.fromI(cb.emb) { cb => - IEmitCode.present(cb, primitive(cb.memoize((d1 + (d2 << 1)).toD / 255.0))) - } - - push(cb, IEmitCode.present(cb, - SStackStruct.constructFromArgs(cb, partRegion, entryType.virtualType, entryFieldCodes.result(): _*))) - - cb.assign(d1, d1 + 1) - }) - - cb.assign(d0, d0 + 1) - }) + cb.while_( + d0 < 256, { + val d1 = cb.newLocal[Int]("memoize_entries_d1", 0) + cb.while_( + d1 < 256, { + val d2 = cb.newLocal[Int]("memoize_entries_d2", const(255) - d0 - d1) + + val entryFieldCodes = new BoxedArrayBuilder[EmitCode]() + + if (includeGT) + entryFieldCodes += EmitCode.fromI(cb.emb) { cb => + val Lmissing = CodeLabel() + val Lpresent = CodeLabel() + val value = cb.newLocal[Int]("bgen_gt_value") + + cb.if_( + d0 > d1, + cb.if_( + d0 > d2, { + cb.assign(value, c0) + cb.goto(Lpresent) + }, + cb.if_( + d2 > d0, { + cb.assign(value, c2) + cb.goto(Lpresent) + }, + // d0 == d2 + cb.goto(Lmissing), + ), + ), + // d0 <= d1 + cb.if_( + d2 > d1, { + cb.assign(value, c2) + cb.goto(Lpresent) + }, + // d2 <= d1 + cb.if_( + d1.ceq(d0) || d1.ceq(d2), + cb.goto(Lmissing), { + cb.assign(value, c1) + cb.goto(Lpresent) + }, + ), + ), + ) + + IEmitCode(Lmissing, Lpresent, new SCanonicalCallValue(value), false) + } + + if (includeGP) + entryFieldCodes += EmitCode.fromI(cb.emb) { cb => + val divisor = cb.newLocal[Double]("divisor", 255.0) + + val gpType = entryType.field("GP").typ.asInstanceOf[PCanonicalArray] + + val (pushElement, finish) = + gpType.constructFromFunctions(cb, partRegion, 3, deepCopy = false) + pushElement( + cb, + IEmitCode.present(cb, primitive(cb.memoize(d0.toD / divisor))), + ) + pushElement( + cb, + IEmitCode.present(cb, primitive(cb.memoize(d1.toD / divisor))), + ) + pushElement( + cb, + IEmitCode.present(cb, primitive(cb.memoize(d2.toD / divisor))), + ) + + IEmitCode.present(cb, finish(cb)) + } + + if (includeDosage) + entryFieldCodes += EmitCode.fromI(cb.emb) { cb => + IEmitCode.present(cb, primitive(cb.memoize((d1 + (d2 << 1)).toD / 255.0))) + } + + push( + cb, + IEmitCode.present( + cb, + SStackStruct.constructFromArgs( + cb, + partRegion, + entryType.virtualType, + entryFieldCodes.result(): _* + ), + ), + ) + + cb.assign(d1, d1 + 1) + }, + ) + + cb.assign(d0, 
d0 + 1) + }, + ) cb.assign(memoizedEntryData, finish(cb).a) cb.assign(alreadyMemoized, true) @@ -300,36 +376,55 @@ object StagedBGENReader { cb.define(LnoOp) } - cb.if_(compression ceq BgenSettings.UNCOMPRESSED, { - cb.assign(data, cbfis.invoke[Int, Array[Byte]]("readBytes", dataSize)) - }, { - cb.assign(uncompressedSize, cbfis.invoke[Int]("readInt")) - cb.assign(input, cbfis.invoke[Int, Array[Byte]]("readBytes", dataSize - 4)) - cb.if_(compression ceq BgenSettings.ZLIB_COMPRESSION, { - cb.assign(data, - Code.invokeScalaObject2[Array[Byte], Int, Array[Byte]]( - CompressionUtils.getClass, "decompressZlib", input, uncompressedSize)) - }, { - // zstd - cb.assign(data, Code.invokeScalaObject2[Array[Byte], Int, Array[Byte]]( - CompressionUtils.getClass, "decompressZstd", input, uncompressedSize)) - }) - }) + cb.if_( + compression ceq BgenSettings.UNCOMPRESSED, + cb.assign(data, cbfis.invoke[Int, Array[Byte]]("readBytes", dataSize)), { + cb.assign(uncompressedSize, cbfis.invoke[Int]("readInt")) + cb.assign(input, cbfis.invoke[Int, Array[Byte]]("readBytes", dataSize - 4)) + cb.if_( + compression ceq BgenSettings.ZLIB_COMPRESSION, + cb.assign( + data, + Code.invokeScalaObject2[Array[Byte], Int, Array[Byte]]( + CompressionUtils.getClass, + "decompressZlib", + input, + uncompressedSize, + ), + ), + // zstd + cb.assign( + data, + Code.invokeScalaObject2[Array[Byte], Int, Array[Byte]]( + CompressionUtils.getClass, + "decompressZstd", + input, + uncompressedSize, + ), + ), + ) + }, + ) cb.assign(reader, Code.newInstance[ByteArrayReader, Array[Byte]](data)) cb.assign(nRow, reader.invoke[Int]("readInt")) - cb.if_(nRow.cne(nSamples), cb._fatal( - const("Row nSamples is not equal to header nSamples: ") - .concat(nRow.toS) - .concat(", ") - .concat(nSamples.toString) - )) + cb.if_( + nRow.cne(nSamples), + cb._fatal( + const("Row nSamples is not equal to header nSamples: ") + .concat(nRow.toS) + .concat(", ") + .concat(nSamples.toString) + ), + ) cb.assign(nAlleles2, reader.invoke[Int]("readShort")) - cb.if_(nAlleles.cne(nAlleles2), + cb.if_( + nAlleles.cne(nAlleles2), cb._fatal(const( """Value for 'nAlleles' in genotype probability data storage is - |not equal to value in variant identifying data. Expected""".stripMargin) + |not equal to value in variant identifying data. Expected""".stripMargin + ) .concat(nAlleles.toS) .concat(" but found ") .concat(nAlleles2.toS) @@ -337,82 +432,108 @@ object StagedBGENReader { .concat(contig) .concat(":") .concat(position.toS) - .concat("."))) + .concat(".")), + ) cb.assign(minPloidy, reader.invoke[Int]("read")) cb.assign(maxPloidy, reader.invoke[Int]("read")) - cb.if_(minPloidy.cne(2) || maxPloidy.cne(2), + cb.if_( + minPloidy.cne(2) || maxPloidy.cne(2), cb._fatal(const("Hail only supports diploid genotypes. Found min ploidy '") .concat(minPloidy.toS) .concat("' and max ploidy '") .concat(maxPloidy.toS) - .concat("'."))) + .concat("'.")), + ) cb.assign(i, 0) - cb.while_(i < nSamples, { - cb.assign(ploidy, reader.invoke[Int]("read")) - cb.if_((ploidy & 0x3f).cne(2), - cb._fatal(const("Ploidy value must equal to 2. Found ") - .concat(ploidy.toS) - .concat("."))) - cb.assign(i, i + 1) - }) + cb.while_( + i < nSamples, { + cb.assign(ploidy, reader.invoke[Int]("read")) + cb.if_( + (ploidy & 0x3f).cne(2), + cb._fatal(const("Ploidy value must equal to 2. 
Found ") + .concat(ploidy.toS) + .concat(".")), + ) + cb.assign(i, i + 1) + }, + ) cb.assign(phase, reader.invoke[Int]("read")) - cb.if_(phase.cne(0) && (phase.cne(1)), + cb.if_( + phase.cne(0) && (phase.cne(1)), cb._fatal(const("Phase value must be 0 or 1. Found ") .concat(phase.toS) - .concat("."))) + .concat(".")), + ) - cb.if_(phase.ceq(1), cb._fatal("Hail does not support phased genotypes in 'import_bgen'.")) + cb.if_( + phase.ceq(1), + cb._fatal("Hail does not support phased genotypes in 'import_bgen'."), + ) cb.assign(nBitsPerProb, reader.invoke[Int]("read")) - cb.if_(nBitsPerProb < 1 || nBitsPerProb > 32, + cb.if_( + nBitsPerProb < 1 || nBitsPerProb > 32, cb._fatal(const("nBits value must be between 1 and 32 inclusive. Found ") .concat(nBitsPerProb.toS) - .concat("."))) - cb.if_(nBitsPerProb.cne(8), + .concat(".")), + ) + cb.if_( + nBitsPerProb.cne(8), cb._fatal(const("Hail only supports 8-bit probabilities, found ") .concat(nBitsPerProb.toS) - .concat("."))) + .concat(".")), + ) cb.assign(nExpectedBytesProbs, nSamples * 2) - cb.if_(reader.invoke[Int]("length").cne(nExpectedBytesProbs + nSamples.get + 10), + cb.if_( + reader.invoke[Int]("length").cne(nExpectedBytesProbs + nSamples.get + 10), cb._fatal(const("Number of uncompressed bytes '") .concat(reader.invoke[Int]("length").toS) .concat("' does not match the expected size '") .concat(nExpectedBytesProbs.toS) - .concat("'."))) + .concat("'.")), + ) - cb.invokeVoid(memoMB) + cb.invokeVoid(memoMB, cb.this_) - val (pushElement, finish) = entriesArrayType.constructFromFunctions(cb, region, nSamples, deepCopy = false) + val (pushElement, finish) = + entriesArrayType.constructFromFunctions(cb, region, nSamples, deepCopy = false) cb.assign(i, 0) - cb.while_(i < nSamples, { - - val Lmissing = CodeLabel() - val Lpresent = CodeLabel() - - cb.if_((data(i + 8) & 0x80).cne(0), cb.goto(Lmissing)) - val dataOffset = cb.newLocal[Int]("bgen_add_entries_offset", (nSamples.get + const(10).get) + i * 2) - val d0 = data(dataOffset) & 0xff - val d1 = data(dataOffset + 1) & 0xff - val pc = entryType.loadCheapSCode(cb, memoTyp.loadElement(memoizedEntryData, nSamples, (d0 << 8) | d1)) - cb.goto(Lpresent) - val iec = IEmitCode(Lmissing, Lpresent, pc, false) - pushElement(cb, iec) - - cb.assign(i, i + 1) - }) + cb.while_( + i < nSamples, { + + val Lmissing = CodeLabel() + val Lpresent = CodeLabel() + + cb.if_((data(i + 8) & 0x80).cne(0), cb.goto(Lmissing)) + val dataOffset = + cb.newLocal[Int]("bgen_add_entries_offset", (nSamples.get + const(10).get) + i * 2) + val d0 = data(dataOffset) & 0xff + val d1 = data(dataOffset + 1) & 0xff + val pc = entryType.loadCheapSCode( + cb, + memoTyp.loadElement(memoizedEntryData, nSamples, (d0 << 8) | d1), + ) + cb.goto(Lpresent) + val iec = IEmitCode(Lmissing, Lpresent, pc, false) + pushElement(cb, iec) + + cb.assign(i, i + 1) + }, + ) val pc = finish(cb) structFieldCodes += EmitCode.fromI(cb.emb)(cb => IEmitCode.present(cb, pc)) } - val ss = SStackStruct.constructFromArgs(cb, region, requestedType, structFieldCodes.result(): _*) + val ss = + SStackStruct.constructFromArgs(cb, region, requestedType, structFieldCodes.result(): _*) out = emb.ecb.newEmitField("bgen_row", ss.st, false) cb.assign(out, EmitCode.present(emb, ss)) @@ -424,11 +545,25 @@ object StagedBGENReader { cb.define(Lfinish) } - cb.invokeVoid(emb, region, cbfis, nSamples, fileIdx, compression, skipInvalidLoci, contigRecoding) + cb.invokeVoid( + emb, + cb.this_, + region, + cbfis, + nSamples, + fileIdx, + compression, + skipInvalidLoci, + contigRecoding, + ) 
out } - def queryIndexByPosition(ctx: ExecuteContext, leafSpec: AbstractTypedCodecSpec, internalSpec: AbstractTypedCodecSpec): (String, Array[Long]) => Array[AnyRef] = { + def queryIndexByPosition( + ctx: ExecuteContext, + leafSpec: AbstractTypedCodecSpec, + internalSpec: AbstractTypedCodecSpec, + ): (String, Array[Long]) => Array[AnyRef] = { val fb = EmitFunctionBuilder[String, Array[Long], Array[AnyRef]](ctx, "bgen_query_index") fb.emitWithBuilder { cb => @@ -441,12 +576,20 @@ object StagedBGENReader { val len = cb.memoize(indices.length()) val boxed = cb.memoize(Code.newArray[AnyRef](len)) val i = cb.newLocal[Int]("i", 0) - cb.while_(i < len, { - - val r = index.queryIndex(cb, mb.partitionRegion, cb.memoize(indices(i))).loadField(cb, "key").get(cb) - cb += boxed.update(i, StringFunctions.svalueToJavaValue(cb, mb.partitionRegion, r, safe = true)) - cb.assign(i, i + 1) - }) + cb.while_( + i < len, { + + val r = index.queryIndex(cb, mb.partitionRegion, cb.memoize(indices(i))).loadField( + cb, + "key", + ).getOrAssert(cb) + cb += boxed.update( + i, + StringFunctions.svalueToJavaValue(cb, mb.partitionRegion, r, safe = true), + ) + cb.assign(i, i + 1) + }, + ) index.close(cb) boxed } @@ -467,8 +610,21 @@ object BGENFunctions extends RegistryFunctions { def uuid(): String = uuid4() override def registerAll(): Unit = { - registerSCode("index_bgen", Array(TString, TString, TDict(TString, TString), TBoolean, TInt32), TInt64, (_, _) => SInt64, Array(TVariable("locusType"))) { - case (er, cb, Seq(locType), _, Array(_path, _idxPath, _recoding, _skipInvalidLoci, _bufferSize), err) => + registerSCode( + "index_bgen", + Array(TString, TString, TDict(TString, TString), TBoolean, TInt32), + TInt64, + (_, _) => SInt64, + Array(TVariable("locusType")), + ) { + case ( + er, + cb, + Seq(locType), + _, + Array(_path, _idxPath, _recoding, _skipInvalidLoci, _bufferSize), + err, + ) => val mb = cb.emb val ctx = cb.emb.ecb.ctx @@ -476,19 +632,34 @@ object BGENFunctions extends RegistryFunctions { val path = _path.asString.loadString(cb) val idxPath = _idxPath.asString.loadString(cb) - val recoding = cb.memoize(coerce[Map[String, String]](svalueToJavaValue(cb, er.region, _recoding))) + val recoding = + cb.memoize(coerce[Map[String, String]](svalueToJavaValue(cb, er.region, _recoding))) val skipInvalidLoci = _skipInvalidLoci.asBoolean.value val bufferSize = _bufferSize.asInt.value val cbfis = cb.memoize(Code.newInstance[HadoopFSDataBinaryReader, SeekableDataInputStream]( - mb.getFS.invoke[String, SeekableDataInputStream]("openNoCompression", path))) - - val header = cb.memoize(Code.invokeScalaObject3[HadoopFSDataBinaryReader, String, Long, BgenHeader]( - LoadBgen.getClass, "readState", cbfis, path, mb.getFS.invoke[String, Long]("getFileSize", path))) + mb.getFS.invoke[String, SeekableDataInputStream]("openNoCompression", path) + )) + + val header = + cb.memoize(Code.invokeScalaObject3[HadoopFSDataBinaryReader, String, Long, BgenHeader]( + LoadBgen.getClass, + "readState", + cbfis, + path, + mb.getFS.invoke[String, Long]("getFileSize", path), + )) - cb.if_(header.invoke[Int]("version") cne 2, { - cb._fatalWithError(err, "BGEN not version 2: ", path, ", version=", header.invoke[Int]("version").toS) - }) + cb.if_( + header.invoke[Int]("version") cne 2, + cb._fatalWithError( + err, + "BGEN not version 2: ", + path, + ", version=", + header.invoke[Int]("version").toS, + ), + ) val nSamples = cb.memoize(header.invoke[Int]("nSamples")) val fileIdx = const(-1) // unused @@ -503,14 +674,17 @@ object BGENFunctions extends 
RegistryFunctions { val settings: BgenSettings = BgenSettings( 0, // nSamples not used if there are no entries - TableType(rowType = TStruct( - "locus" -> TLocus.schemaFromRG(rg), - "alleles" -> TArray(TString), - "offset" -> TInt64), + TableType( + rowType = TStruct( + "locus" -> TLocus.schemaFromRG(rg), + "alleles" -> TArray(TString), + "offset" -> TInt64, + ), key = Array("locus", "alleles"), - globalType = TStruct.empty), + globalType = TStruct.empty, + ), rg, - TStruct() + TStruct(), ) val nFilesMax = cb.memoize((nVariants / bufferSize.toL + 1L).toI) @@ -518,36 +692,52 @@ object BGENFunctions extends RegistryFunctions { val paths = cb.memoize(Code.newArray[String](nFilesMax), "paths") val fileSizes = cb.memoize(Code.newArray[Int](nFilesMax), "fileSizes") - val rowPType = PCanonicalStruct("locus" -> PType.canonical(locType, true, true), + val rowPType = PCanonicalStruct( + "locus" -> PType.canonical(locType, true, true), "alleles" -> PCanonicalArray(PCanonicalString(true), true), - "offset" -> PInt64Required) + "offset" -> PInt64Required, + ) val bufferSct = SingleCodeType.fromSType(rowPType.sType) val buffer = new StagedArrayBuilder(cb, bufferSct, true, 8) val currSize = cb.newLocal[Int]("currSize", 0) val spec = TypedCodecSpec( - PType.canonical(TStruct("locus" -> locType, "alleles" -> TArray(TString), "offset" -> TInt64)), - BufferSpec.wireSpec + PType.canonical(TStruct( + "locus" -> locType, + "alleles" -> TArray(TString), + "offset" -> TInt64, + )), + BufferSpec.wireSpec, ) def dumpBuffer(cb: EmitCodeBuilder) = { val sorter = new ArraySorter(er, buffer) - sorter.sort(cb, er.region, { case (cb, region, l, r) => - val lv = bufferSct.loadToSValue(cb, l).asBaseStruct.subset("locus", "alleles") - val rv = bufferSct.loadToSValue(cb, r).asBaseStruct.subset("locus", "alleles") - cb.emb.ecb.getOrdering(lv.st, rv.st).ltNonnull(cb, lv, rv) - }) - - val path = cb.newLocal[String]("currFile", const(localTmpBase).concat(groupIndex.toS) - .concat("-").concat(Code.invokeScalaObject0[String](BGENFunctions.getClass, "uuid"))) - val ob = cb.newLocal[OutputBuffer]("currFile", spec.buildCodeOutputBuffer(mb.create(path))) + sorter.sort( + cb, + er.region, + { case (cb, _, l, r) => + val lv = bufferSct.loadToSValue(cb, l).asBaseStruct.subset("locus", "alleles") + val rv = bufferSct.loadToSValue(cb, r).asBaseStruct.subset("locus", "alleles") + cb.emb.ecb.getOrdering(lv.st, rv.st).ltNonnull(cb, lv, rv) + }, + ) + + val path = cb.newLocal[String]( + "currFile", + const(localTmpBase).concat(groupIndex.toS) + .concat("-").concat(Code.invokeScalaObject0[String](BGENFunctions.getClass, "uuid")), + ) + val ob = + cb.newLocal[OutputBuffer]("currFile", spec.buildCodeOutputBuffer(mb.create(path))) val i = cb.newLocal[Int]("i", 0) - cb.while_(i < currSize, { - val k = bufferSct.loadToSValue(cb, cb.memoizeAny(buffer.apply(i), buffer.ti)) - spec.encodedType.buildEncoder(k.st, mb.ecb).apply(cb, k, ob) - cb.assign(i, i + 1) - }) + cb.while_( + i < currSize, { + val k = bufferSct.loadToSValue(cb, cb.memoizeAny(buffer.apply(i), buffer.ti)) + spec.encodedType.buildEncoder(k.st, mb.ecb).apply(cb, k, ob) + cb.assign(i, i + 1) + }, + ) cb += paths.update(groupIndex, path) cb += fileSizes.update(groupIndex, currSize) cb += ob.invoke[Unit]("close") @@ -561,27 +751,40 @@ object BGENFunctions extends RegistryFunctions { val nRead = cb.newLocal[Long]("nRead", 0L) val nWritten = cb.newLocal[Long]("nWritten", 0L) - cb.while_(nRead < nVariants, { - StagedBGENReader.decodeRow(cb, er.region, cbfis, nSamples, fileIdx, compression, 
skipInvalidLoci, recoding, - TStruct("locus" -> locType, "alleles" -> TArray(TString), "offset" -> TInt64), rg).toI(cb).consume(cb, { - // do nothing if missing (invalid locus) - }, { case row: SBaseStructValue => - cb.if_(currSize ceq bufferSize, { - dumpBuffer(cb) - }) - buffer.add(cb, bufferSct.coerceSCode(cb, row, er.region, false).code) - cb.assign(currSize, currSize + 1) - cb.assign(nWritten, nWritten + 1) - }) - cb.assign(nRead, nRead + 1) - }) + cb.while_( + nRead < nVariants, { + StagedBGENReader.decodeRow( + cb, + er.region, + cbfis, + nSamples, + fileIdx, + compression, + skipInvalidLoci, + recoding, + TStruct("locus" -> locType, "alleles" -> TArray(TString), "offset" -> TInt64), + rg, + ).toI(cb).consume( + cb, { + // do nothing if missing (invalid locus) + }, + { case row: SBaseStructValue => + cb.if_(currSize ceq bufferSize, dumpBuffer(cb)) + buffer.add(cb, bufferSct.coerceSCode(cb, row, er.region, false).code) + cb.assign(currSize, currSize + 1) + cb.assign(nWritten, nWritten + 1) + }, + ) + cb.assign(nRead, nRead + 1) + }, + ) cb.if_(currSize > 0, dumpBuffer(cb)) - val ecb = cb.emb.genEmitClass[Unit]("buffer_stream") ecb.cb.addInterface(typeInfo[NoBoxLongIterator].iname) - val ctor = ecb.newEmitMethod("", FastSeq[ParamType](typeInfo[String], typeInfo[Int]), UnitInfo) + val ctor = + ecb.newEmitMethod("", FastSeq[ParamType](typeInfo[String], typeInfo[Int]), UnitInfo) val ib = ecb.genFieldThisRef[InputBuffer]("ib") val iterSize = ecb.genFieldThisRef[Int]("size") val iterCurrIdx = ecb.genFieldThisRef[Int]("currIdx") @@ -590,13 +793,16 @@ object BGENFunctions extends RegistryFunctions { ctor.voidWithBuilder { cb => val L = new lir.Block() L.append( - lir.methodStmt(INVOKESPECIAL, + lir.methodStmt( + INVOKESPECIAL, "java/lang/Object", "", "()V", false, UnitInfo, - FastSeq(lir.load(ctor.mb._this.asInstanceOf[LocalRef[_]].l)))) + FastSeq(lir.load(ctor.mb.this_.asInstanceOf[LocalRef[_]].l)), + ) + ) cb += new VCode(L, L, null) val path = cb.memoize(ctor.getCodeParam[String](1)) @@ -609,7 +815,11 @@ object BGENFunctions extends RegistryFunctions { val next = ecb.newEmitMethod("next", FastSeq[ParamType](), LongInfo) - val init = ecb.newEmitMethod("init", FastSeq[ParamType](typeInfo[Region], typeInfo[Region]), UnitInfo) + val init = ecb.newEmitMethod( + "init", + FastSeq[ParamType](typeInfo[Region], typeInfo[Region]), + UnitInfo, + ) init.voidWithBuilder { cb => val eltRegion = init.getCodeParam[Region](2) @@ -618,14 +828,27 @@ object BGENFunctions extends RegistryFunctions { next.emitWithBuilder { cb => val ret = cb.newLocal[Long]("ret") - cb.if_(iterCurrIdx < iterSize, { - cb.assign(ret, rowPType.store(cb, iterEltRegion, - spec.encodedType.buildDecoder(rowPType.virtualType, ecb).apply(cb, iterEltRegion, ib), false)) - cb.assign(iterCurrIdx, iterCurrIdx + 1) - }, { - cb.assign(iterEOS, true) - cb.assign(ret, 0L) - }) + cb.if_( + iterCurrIdx < iterSize, { + cb.assign( + ret, + rowPType.store( + cb, + iterEltRegion, + spec.encodedType.buildDecoder(rowPType.virtualType, ecb).apply( + cb, + iterEltRegion, + ib, + ), + false, + ), + ) + cb.assign(iterCurrIdx, iterCurrIdx + 1) + }, { + cb.assign(iterEOS, true) + cb.assign(ret, 0L) + }, + ) ret } @@ -638,26 +861,65 @@ object BGENFunctions extends RegistryFunctions { val iters = mb.genFieldThisRef[Array[NoBoxLongIterator]]("iters") cb.assign(iters, Code.newArray[NoBoxLongIterator](groupIndex)) val i = cb.newLocal[Int]("i") - cb.while_(i < groupIndex, { - cb += iters.update(i, coerce[NoBoxLongIterator](Code.newInstance(ecb.cb, ctor.mb, 
FastSeq(paths(i), fileSizes(i))))) - cb.assign(i, i + 1) - }) + cb.while_( + i < groupIndex, { + cb += iters.update( + i, + coerce[NoBoxLongIterator](Code.newInstance( + ecb.cb, + ctor.mb, + FastSeq(paths(i), fileSizes(i)), + )), + ) + cb.assign(i, i + 1) + }, + ) - val mergedStream = StreamUtils.multiMergeIterators(cb, Right(true), iters, FastSeq("locus", "alleles"), rowPType) + val mergedStream = StreamUtils.multiMergeIterators( + cb, + Right(true), + iters, + FastSeq("locus", "alleles"), + rowPType, + ) - val iw = StagedIndexWriter.withDefaults(settings.indexKeyType, mb.ecb, annotationType = +PCanonicalStruct()) - iw.init(cb, idxPath, cb.memoize(Code.invokeScalaObject3[String, Map[String, String], Boolean, Map[String, Any]]( - BGENFunctions.getClass, "wrapAttrs", mb.getObject(rg.orNull), recoding, skipInvalidLoci))) + val iw = StagedIndexWriter.withDefaults( + settings.indexKeyType, + mb.ecb, + annotationType = +PCanonicalStruct(), + ) + iw.init( + cb, + idxPath, + cb.memoize(Code.invokeScalaObject3[String, Map[String, String], Boolean, Map[ + String, + Any, + ]]( + BGENFunctions.getClass, + "wrapAttrs", + mb.getObject(rg.orNull), + recoding, + skipInvalidLoci, + )), + ) val nAdded = cb.newLocal[Long]("nAdded", 0) mergedStream.memoryManagedConsume(er.region, cb) { cb => - val row = mergedStream.element.toI(cb).get(cb).asBaseStruct + val row = mergedStream.element.toI(cb).getOrAssert(cb).asBaseStruct val key = row.subset("locus", "alleles") - val offset = row.loadField(cb, "offset").get(cb).asInt64.value + val offset = row.loadField(cb, "offset").getOrAssert(cb).asInt64.value cb.assign(nAdded, nAdded + 1) - iw.add(cb, IEmitCode.present(cb, key), offset, IEmitCode.present(cb, SStackStruct.constructFromArgs(cb, er.region, TStruct()))) + iw.add( + cb, + IEmitCode.present(cb, key), + offset, + IEmitCode.present(cb, SStackStruct.constructFromArgs(cb, er.region, TStruct())), + ) } - cb.if_(nWritten cne nAdded, cb._fatal(s"nWritten != nAdded - ", nWritten.toS, ", ", nAdded.toS)) + cb.if_( + nWritten cne nAdded, + cb._fatal(s"nWritten != nAdded - ", nWritten.toS, ", ", nAdded.toS), + ) iw.close(cb) cb += cbfis.invoke[Unit]("close") @@ -665,9 +927,11 @@ object BGENFunctions extends RegistryFunctions { } } - def wrapAttrs(rg: String, recoding: Map[String, String], skipInvalidLoci: Boolean): Map[String, Any] = { - Map("reference_genome" -> rg, + def wrapAttrs(rg: String, recoding: Map[String, String], skipInvalidLoci: Boolean) + : Map[String, Any] = + Map( + "reference_genome" -> rg, "contig_recoding" -> recoding, - "skip_invalid_loci" -> skipInvalidLoci) - } + "skip_invalid_loci" -> skipInvalidLoci, + ) } diff --git a/hail/src/main/scala/is/hail/io/compress/BGzipLineReader.scala b/hail/src/main/scala/is/hail/io/compress/BGzipLineReader.scala index fd19fbc71e2..8b1416a808c 100644 --- a/hail/src/main/scala/is/hail/io/compress/BGzipLineReader.scala +++ b/hail/src/main/scala/is/hail/io/compress/BGzipLineReader.scala @@ -1,15 +1,13 @@ package is.hail.io.compress -import java.nio.charset.StandardCharsets - import is.hail.io.fs.FS +import java.nio.charset.StandardCharsets + final class BGzipLineReader( private val fs: FS, - private val filePath: String -) - extends java.lang.AutoCloseable -{ + private val filePath: String, +) extends java.lang.AutoCloseable { private var is = new BGzipInputStream(fs.openNoCompression(filePath)) // The line iterator buffer and associated state, we use this to avoid making @@ -34,9 +32,8 @@ final class BGzipLineReader( private var virtualFileOffsetAtLastRead = 0L - def 
getVirtualOffset: Long = { + def getVirtualOffset: Long = virtualFileOffsetAtLastRead + (bufferCursor - bufferPositionAtLastRead) - } def virtualSeek(l: Long): Unit = { is.virtualSeek(l) @@ -74,9 +71,8 @@ final class BGzipLineReader( while (true) { var str: String = null - while (bufferCursor < bufferLen && buffer(bufferCursor) != '\n') { + while (bufferCursor < bufferLen && buffer(bufferCursor) != '\n') bufferCursor += 1 - } if (bufferCursor == bufferLen) { // no newline before end of buffer if (bufferEOF) { @@ -121,10 +117,9 @@ final class BGzipLineReader( throw new AssertionError() } - override def close(): Unit = { + override def close(): Unit = if (is != null) { is.close() is = null } - } } diff --git a/hail/src/main/scala/is/hail/io/compress/BGzipOutputStream.scala b/hail/src/main/scala/is/hail/io/compress/BGzipOutputStream.scala deleted file mode 100644 index 5cf82c6446f..00000000000 --- a/hail/src/main/scala/is/hail/io/compress/BGzipOutputStream.scala +++ /dev/null @@ -1,157 +0,0 @@ -package is.hail.io.compress - -import java.io.OutputStream -import java.util.zip.{CRC32, Deflater} - -import org.apache.hadoop.io.compress.CompressionOutputStream - -class BGzipConstants { - val blockHeaderLength = 18 // Number of bytes in the gzip block before the deflated data. - val blockLengthOffset = 16 // Location in the gzip block of the total block size (actually total block size - 1) - val blockFooterLength = 8 // Number of bytes that follow the deflated data - val maxCompressedBlockSize = 64 * 1024 // We require that a compressed block (including header and footer, be <= this) - val gzipOverhead = blockHeaderLength + blockFooterLength + 2 // Gzip overhead is the header, the footer, and the block size (encoded as a short). - val noCompressionOverhead = 10 // If Deflater has compression level == NO_COMPRESSION, 10 bytes of overhead (determined experimentally). - val defaultUncompressedBlockSize = 64 * 1024 - (gzipOverhead + noCompressionOverhead) // Push out a gzip block when this many uncompressed bytes have been accumulated. 
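// Worked check (not part of this patch) of the block-size arithmetic encoded by the constants
// above: an 18-byte header, an 8-byte footer, the 2-byte block-length field and 10 bytes of
// NO_COMPRESSION deflater overhead leave 65498 uncompressed bytes per 64 KiB BGZF block.
object BGzipBlockSizeSketch extends App {
  val blockHeaderLength = 18
  val blockFooterLength = 8
  val gzipOverhead = blockHeaderLength + blockFooterLength + 2
  val noCompressionOverhead = 10
  val defaultUncompressedBlockSize = 64 * 1024 - (gzipOverhead + noCompressionOverhead)
  assert(defaultUncompressedBlockSize == 65498)
  println(s"uncompressed bytes per BGZF block: $defaultUncompressedBlockSize")
}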
- - // gzip magic numbers - val gzipId1 = 31 - val gzipId2 = 139 - - val gzipModificationTime = 0 - val gzipFlag = 4 // set extra fields to true - val gzipXFL = 0 // extra flags - val gzipXLEN = 6 // length of extra subfield - val gzipCMDeflate = 8 // The deflate compression, which is customarily used by gzip - val defaultCompressionLevel = 5 - val gzipOsUnknown = 255 - val bgzfId1 = 66 - val bgzfId2 = 67 - val bgzfLen = 2 - val emptyGzipBlock = Array(0x1f,0x8b,0x08,0x04,0x00,0x00,0x00,0x00, - 0x00,0xff,0x06,0x00,0x42,0x43,0x02,0x00, - 0x1b,0x00,0x03,0x00,0x00,0x00,0x00,0x00, - 0x00,0x00,0x00,0x00).map(_.toByte) -} - -class BGzipOutputStream(out: OutputStream) extends CompressionOutputStream(out) { - private[this] var finished: Boolean = false - - val constants = new BGzipConstants - var numUncompressedBytes = 0 - var uncompressedBuffer = new Array[Byte](constants.defaultUncompressedBlockSize) - var compressedBuffer = new Array[Byte](constants.maxCompressedBlockSize - constants.blockHeaderLength) - - val deflater = new Deflater(constants.defaultCompressionLevel,true) - val noCompressionDeflater = new Deflater(Deflater.NO_COMPRESSION,true) - val crc32 = new CRC32 - - def write(b:Int) { - require(numUncompressedBytes < uncompressedBuffer.length) - uncompressedBuffer(numUncompressedBytes) = b.toByte - numUncompressedBytes += 1 - - if (numUncompressedBytes == uncompressedBuffer.length) - deflateBlock() - } - - override def write(bytes: Array[Byte], offset:Int, length:Int) { - require(numUncompressedBytes < uncompressedBuffer.length) - - var currentPosition = offset - var numBytesRemaining = length - - while (numBytesRemaining > 0) { - var bytesToWrite = math.min(uncompressedBuffer.length - numUncompressedBytes, numBytesRemaining) - System.arraycopy(bytes, currentPosition, uncompressedBuffer, numUncompressedBytes, bytesToWrite) - numUncompressedBytes += bytesToWrite - currentPosition += bytesToWrite - numBytesRemaining -= bytesToWrite - require(numBytesRemaining >= 0) - - if (numUncompressedBytes == uncompressedBuffer.length) - deflateBlock() - } - } - - final protected def deflateBlock(): Unit = { - require(numUncompressedBytes != 0) - assert(!finished) - - deflater.reset() - deflater.setInput(uncompressedBuffer, 0, numUncompressedBytes) - deflater.finish() - var compressedSize: Int = deflater.deflate(compressedBuffer, 0, compressedBuffer.length) - - // If it didn't all fit in compressedBuffer.length, set compression level to NO_COMPRESSION - // and try again. This should always fit. - if (!deflater.finished) { - noCompressionDeflater.reset() - noCompressionDeflater.setInput(uncompressedBuffer, 0, numUncompressedBytes) - noCompressionDeflater.finish() - compressedSize = noCompressionDeflater.deflate(compressedBuffer, 0, compressedBuffer.length) - require(noCompressionDeflater.finished) - } - // Data compressed small enough, so write it out. 
- crc32.reset() - crc32.update(uncompressedBuffer, 0, numUncompressedBytes) - - val totalBlockSize: Int = writeGzipBlock(compressedSize, numUncompressedBytes, crc32.getValue) - - numUncompressedBytes = 0 // reset variable - } - - def writeInt8(i: Int) = { - out.write(i & 0xff) - } - - def writeInt16(i: Int) = { - out.write(i & 0xff) - out.write((i >> 8) & 0xff) - } - - def writeInt32(i:Int) = { - out.write(i & 0xff) - out.write((i >> 8) & 0xff) - out.write((i >> 16) & 0xff) - out.write((i >> 24) & 0xff) - } - - def writeGzipBlock(compressedSize:Int,bytesToCompress:Int,crc32val:Long): Int = { - val totalBlockSize = compressedSize + constants.blockHeaderLength + constants.blockFooterLength - - writeInt8(constants.gzipId1) - writeInt8(constants.gzipId2) - writeInt8(constants.gzipCMDeflate) - writeInt8(constants.gzipFlag) - writeInt32(constants.gzipModificationTime) - writeInt8(constants.gzipXFL) - writeInt8(constants.gzipOsUnknown) - writeInt16(constants.gzipXLEN) - writeInt8(constants.bgzfId1) - writeInt8(constants.bgzfId2) - writeInt16(constants.bgzfLen) - writeInt16(totalBlockSize - 1) - out.write(compressedBuffer, 0, compressedSize) - writeInt32(crc32val.toInt) - writeInt32(bytesToCompress) - totalBlockSize - } - - def resetState() = throw new UnsupportedOperationException - - override def finish(): Unit = { - if (numUncompressedBytes != 0) - deflateBlock() - if (!finished) { - out.write(constants.emptyGzipBlock) - finished = true - } - } -} - -class ComposableBGzipOutputStream(out: OutputStream) extends BGzipOutputStream(out) { - override def finish() = if (numUncompressedBytes != 0) { - deflateBlock() - } -} diff --git a/hail/src/main/scala/is/hail/io/compress/LZ4.scala b/hail/src/main/scala/is/hail/io/compress/LZ4.scala index fbcd17ad9fd..1180c49f0dc 100644 --- a/hail/src/main/scala/is/hail/io/compress/LZ4.scala +++ b/hail/src/main/scala/is/hail/io/compress/LZ4.scala @@ -1,6 +1,6 @@ package is.hail.io.compress -import net.jpountz.lz4.{ LZ4Compressor, LZ4Factory, LZ4FastDecompressor } +import net.jpountz.lz4.{LZ4Compressor, LZ4Factory, LZ4FastDecompressor} object LZ4 { val factory = LZ4Factory.fastestInstance() @@ -11,7 +11,7 @@ object LZ4 { class LZ4 private ( compressor: LZ4Compressor, - decompressor: LZ4FastDecompressor + decompressor: LZ4FastDecompressor, ) { def maxCompressedLength(decompLen: Int): Int = compressor.maxCompressedLength(decompLen) @@ -23,7 +23,14 @@ class LZ4 private ( compressedLen } - def decompress(decomp: Array[Byte], decompOff: Int, decompLen: Int, comp: Array[Byte], compOff: Int, compLen: Int) { + def decompress( + decomp: Array[Byte], + decompOff: Int, + decompLen: Int, + comp: Array[Byte], + compOff: Int, + compLen: Int, + ): Unit = { val compLen2 = decompressor.decompress(comp, compOff, decomp, decompOff, decompLen) assert(compLen2 == compLen) } diff --git a/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala index e84cf839ede..d9b5c7c542e 100644 --- a/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/AzureStorageFS.scala @@ -1,128 +1,84 @@ package is.hail.io.fs -import is.hail.shadedazure.com.azure.core.credential.{AzureSasCredential, TokenCredential} -import is.hail.shadedazure.com.azure.identity.{ClientSecretCredential, ClientSecretCredentialBuilder, DefaultAzureCredential, DefaultAzureCredentialBuilder, ManagedIdentityCredentialBuilder} -import is.hail.shadedazure.com.azure.storage.blob.models.{BlobItem, BlobProperties, BlobRange, 
BlobStorageException, ListBlobsOptions} -import is.hail.shadedazure.com.azure.storage.blob.specialized.BlockBlobClient -import is.hail.shadedazure.com.azure.storage.blob.{BlobClient, BlobContainerClient, BlobServiceClient, BlobServiceClientBuilder} -import is.hail.shadedazure.com.azure.core.http.HttpClient -import is.hail.shadedazure.com.azure.core.util.HttpClientOptions +import is.hail.io.fs.FSUtil.dropTrailingSlash import is.hail.services.retryTransientErrors -import is.hail.io.fs.FSUtil.{containsWildcard, dropTrailingSlash} -import is.hail.services.Requester.httpClient -import org.apache.log4j.Logger -import org.apache.commons.io.IOUtils - -import java.net.URI +import is.hail.shadedazure.com.azure.core.credential.AzureSasCredential +import is.hail.shadedazure.com.azure.core.util.HttpClientOptions +import is.hail.shadedazure.com.azure.identity.{ + ClientSecretCredentialBuilder, DefaultAzureCredentialBuilder, +} +import is.hail.shadedazure.com.azure.storage.blob.{ + BlobClient, BlobContainerClient, BlobServiceClient, BlobServiceClientBuilder, +} +import is.hail.shadedazure.com.azure.storage.blob.models.{ + BlobItem, BlobRange, BlobStorageException, ListBlobsOptions, +} +import is.hail.shadedazure.com.azure.storage.blob.specialized.BlockBlobClient import is.hail.utils._ -import org.json4s -import org.json4s.jackson.JsonMethods -import org.json4s.Formats -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileNotFoundException, OutputStream} -import java.nio.file.Paths -import java.time.Duration +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.json4s.{DefaultFormats, Formats, JInt, JObject, JString, JValue} -abstract class AzureStorageFSURL( +import java.io.{ByteArrayOutputStream, FileNotFoundException, OutputStream} +import java.nio.file.Paths +import java.time.Duration + +import org.json4s.Formats +import org.json4s.jackson.JsonMethods + +class AzureStorageFSURL( val account: String, val container: String, val path: String, - val sasToken: Option[String] + val sasToken: Option[String], ) extends FSURL { - def addPathComponent(c: String): AzureStorageFSURL = { + def addPathComponent(c: String): AzureStorageFSURL = if (path == "") withPath(c) else withPath(s"$path/$c") - } - def withPath(newPath: String): AzureStorageFSURL + def fromString(s: String): AzureStorageFSURL = AzureStorageFS.parseUrl(s) - def prefix: String - def getPath: String = path + def withPath(newPath: String): AzureStorageFSURL = + new AzureStorageFSURL(account, container, newPath, sasToken) - override def toString(): String = { - val pathPart = if (path == "") "" else s"/$path" - val sasTokenPart = sasToken.getOrElse("") + def prefix: String = s"https://$account.blob.core.windows.net/$container" - prefix + pathPart + sasTokenPart - } -} - -class AzureStorageFSHailAzURL( - account: String, - container: String, - path: String, - sasToken: Option[String] -) extends AzureStorageFSURL(account, container, path, sasToken) { + def getPath: String = path - override def withPath(newPath: String): AzureStorageFSHailAzURL = { - new AzureStorageFSHailAzURL(account, container, newPath, sasToken) + def base: String = { + val pathPart = if (path == "") "" else s"/$path" + prefix + pathPart } - override def prefix: String = s"hail-az://$account/$container" -} - -class AzureStorageFSHttpsURL( - account: String, - container: String, - path: String, - sasToken: Option[String] -) extends AzureStorageFSURL(account, container, path, sasToken) { - - 
override def withPath(newPath: String): AzureStorageFSHttpsURL = { - new AzureStorageFSHttpsURL(account, container, newPath, sasToken) + override def toString(): String = { + val sasTokenPart = sasToken.getOrElse("") + this.base + sasTokenPart } - - override def prefix: String = s"https://$account.blob.core.windows.net/$container" } - object AzureStorageFS { - private val HAIL_AZ_URI_REGEX = "^hail-az:\\/\\/([a-z0-9_\\-\\.]+)\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r - private val AZURE_HTTPS_URI_REGEX = "^https:\\/\\/([a-z0-9_\\-\\.]+)\\.blob\\.core\\.windows\\.net\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r - - private val log = Logger.getLogger(getClass.getName) - - val schemes: Array[String] = Array("hail-az", "https") + private val AZURE_HTTPS_URI_REGEX = + "^https:\\/\\/([a-z0-9_\\-\\.]+)\\.blob\\.core\\.windows\\.net\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r def parseUrl(filename: String): AzureStorageFSURL = { - val scheme = new URI(filename).getScheme - if (scheme == "hail-az") { - parseHailAzUrl(filename) - } else if (scheme == "https") { - parseHttpsUrl(filename) - } else { - throw new IllegalArgumentException(s"Invalid scheme, expected hail-az or https: $scheme") - } - } - - private[this] def parseHttpsUrl(filename: String): AzureStorageFSHttpsURL = { AZURE_HTTPS_URI_REGEX .findFirstMatchIn(filename) - .map(m => { - val (path, sasToken) = parsePathAndQuery(m.group(3)) - new AzureStorageFSHttpsURL(m.group(1), m.group(2), path, sasToken) - }) - .getOrElse(throw new IllegalArgumentException("ABS URI must be of the form https://.blob.core.windows.net//")) - } - - private[this] def parseHailAzUrl(filename: String): AzureStorageFSHailAzURL = { - HAIL_AZ_URI_REGEX - .findFirstMatchIn(filename) - .map(m => { + .map { m => val (path, sasToken) = parsePathAndQuery(m.group(3)) - new AzureStorageFSHailAzURL(m.group(1), m.group(2), path, sasToken) - }) - .getOrElse(throw new IllegalArgumentException("hail-az URI must be of the form hail-az:////")) + new AzureStorageFSURL(m.group(1), m.group(2), path, sasToken) + } + .getOrElse(throw new IllegalArgumentException( + "ABS URI must be of the form https://.blob.core.windows.net//" + )) } private[this] def parsePathAndQuery(maybeNullPath: String): (String, Option[String]) = { - val pathAndMaybeQuery = Paths.get(if (maybeNullPath == null) "" else maybeNullPath.stripPrefix("/")).normalize.toString + val pathAndMaybeQuery = Paths.get(if (maybeNullPath == null) "" + else maybeNullPath.stripPrefix("/")).normalize.toString // Unfortunately it is difficult to tell the difference between a glob pattern and a SAS token, // so we make the imperfect assumption that if the query string starts with at least one @@ -133,7 +89,7 @@ object AzureStorageFS { } else { val (path, queryString) = pathAndMaybeQuery.splitAt(indexOfLastQuestionMark) queryString.split("&")(0).split("=") match { - case Array(k, v) => (path, Some(queryString)) + case Array(_, _) => (path, Some(queryString)) case _ => (pathAndMaybeQuery, None) } } @@ -141,35 +97,57 @@ object AzureStorageFS { } object AzureStorageFileListEntry { - def apply(path: String, isDir: Boolean, blobProperties: BlobProperties): BlobStorageFileListEntry = { - if (isDir) { - new BlobStorageFileListEntry(path, null, 0, true) - } else { - new BlobStorageFileListEntry(path, blobProperties.getLastModified.toEpochSecond, blobProperties.getBlobSize, false) - } - } - - def apply(blobPath: String, blobItem: BlobItem): BlobStorageFileListEntry = { + def apply(rootUrl: AzureStorageFSURL, blobItem: BlobItem): BlobStorageFileListEntry = { + val url = 
rootUrl.withPath(blobItem.getName) if (blobItem.isPrefix) { - new BlobStorageFileListEntry(blobPath, null, 0, true) + dir(url) } else { val properties = blobItem.getProperties - new BlobStorageFileListEntry(blobPath, properties.getLastModified.toEpochSecond, properties.getContentLength, false) + new BlobStorageFileListEntry( + url.toString, + properties.getLastModified.toEpochSecond, + properties.getContentLength, + false, + ) } } + + def dir(url: AzureStorageFSURL): BlobStorageFileListEntry = + new BlobStorageFileListEntry(url.toString, null, 0, true) } -class AzureBlobServiceClientCache(credential: TokenCredential, val httpClientOptions: HttpClientOptions) { - private[this] lazy val clients = mutable.Map[(String, String, Option[String]), BlobServiceClient]() +class AzureStorageFS(val credentialsJSON: Option[String] = None) extends FS { + type URL = AzureStorageFSURL + + private[this] lazy val clients = + mutable.Map[(String, String, Option[String]), BlobServiceClient]() - def getServiceClient(url: AzureStorageFSURL): BlobServiceClient = { + private lazy val credential = credentialsJSON match { + case None => + new DefaultAzureCredentialBuilder().build() + case Some(keyData) => + implicit val formats: Formats = defaultJSONFormats + val kvs = JsonMethods.parse(keyData) + val appId = (kvs \ "appId").extract[String] + val password = (kvs \ "password").extract[String] + val tenant = (kvs \ "tenant").extract[String] + + new ClientSecretCredentialBuilder() + .clientId(appId) + .clientSecret(password) + .tenantId(tenant) + .build() + } + + def getServiceClient(url: URL): BlobServiceClient = { val k = (url.account, url.container, url.sasToken) clients.get(k) match { case Some(client) => client case None => val clientBuilder = url.sasToken match { - case Some(sasToken) => new BlobServiceClientBuilder().credential(new AzureSasCredential(sasToken)) + case Some(sasToken) => + new BlobServiceClientBuilder().credential(new AzureSasCredential(sasToken)) case None => new BlobServiceClientBuilder().credential(credential) } @@ -189,13 +167,6 @@ class AzureBlobServiceClientCache(credential: TokenCredential, val httpClientOpt .buildClient() clients += ((url.account, url.container, url.sasToken) -> blobServiceClient) } -} - - -class AzureStorageFS(val credentialsJSON: Option[String] = None) extends FS { - type URL = AzureStorageFSURL - - import AzureStorageFS.log override def parseUrl(filename: String): URL = AzureStorageFS.parseUrl(filename) @@ -211,18 +182,18 @@ class AzureStorageFS(val credentialsJSON: Option[String] = None) extends FS { def getConfiguration(): Unit = () - def setConfiguration(config: Any): Unit = { } + def setConfiguration(config: Any): Unit = {} // ABS errors if you attempt credentialed access for a public container, // so we try once with credentials, if that fails use anonymous access for // that container going forward. 
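// Illustrative sketch (not part of this patch) of the fallback described in the comment above:
// attempt credentialed access once and, on a 401 from Blob Storage, switch that container to an
// anonymous client and retry. `Http401Error` and `useAnonymousClient` are hypothetical
// stand-ins; the real implementation is the handlePublicAccessError method that follows.
object PublicAccessFallbackSketch {
  class Http401Error(message: String) extends RuntimeException(message)

  def withPublicAccessFallback[T](useAnonymousClient: () => Unit)(f: => T): T =
    try f
    catch {
      case _: Http401Error =>
        useAnonymousClient() // anonymous access for this container going forward
        f
    }
}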
def handlePublicAccessError[T](url: URL)(f: => T): T = { retryTransientErrors { - try { + try f - } catch { + catch { case e: BlobStorageException if e.getStatusCode == 401 => - serviceClientCache.setPublicAccessServiceClient(url) + setPublicAccessServiceClient(url) f } } @@ -234,44 +205,24 @@ class AzureStorageFS(val credentialsJSON: Option[String] = None) extends FS { .setConnectionIdleTimeout(Duration.ofSeconds(5)) .setWriteTimeout(Duration.ofSeconds(5)) - private lazy val serviceClientCache = credentialsJSON match { - case None => - val credential: DefaultAzureCredential = new DefaultAzureCredentialBuilder().build() - new AzureBlobServiceClientCache(credential, httpClientOptions) - case Some(keyData) => - implicit val formats: Formats = defaultJSONFormats - val kvs = JsonMethods.parse(keyData) - val appId = (kvs \ "appId").extract[String] - val password = (kvs \ "password").extract[String] - val tenant = (kvs \ "tenant").extract[String] - - val clientSecretCredential: ClientSecretCredential = new ClientSecretCredentialBuilder() - .clientId(appId) - .clientSecret(password) - .tenantId(tenant) - .build() - new AzureBlobServiceClientCache(clientSecretCredential, httpClientOptions) - } - // Set to max timeout for blob storage of 30 seconds - // https://docs.microsoft.com/en-us/rest/api/storageservices/setting-timeouts-for-blob-service-operations + /* https://docs.microsoft.com/en-us/rest/api/storageservices/setting-timeouts-for-blob-service-operations */ private val timeout = Duration.ofSeconds(30) def getBlobClient(url: URL): BlobClient = retryTransientErrors { - serviceClientCache.getServiceClient(url).getBlobContainerClient(url.container).getBlobClient(url.path) + getServiceClient(url).getBlobContainerClient(url.container).getBlobClient( + url.path + ) } def getContainerClient(url: URL): BlobContainerClient = retryTransientErrors { - serviceClientCache.getServiceClient(url).getBlobContainerClient(url.container) + getServiceClient(url).getBlobContainerClient(url.container) } def openNoCompression(url: URL): SeekableDataInputStream = handlePublicAccessError(url) { - val blobClient: BlobClient = getBlobClient(url) - val blobSize = blobClient.getProperties.getBlobSize + val blobSize = getBlobClient(url).getProperties.getBlobSize val is: SeekableInputStream = new FSSeekableInputStream { - private[this] val client: BlobClient = blobClient - val bbOS = new OutputStream { override def write(b: Array[Byte]): Unit = bb.put(b) override def write(b: Int): Unit = bb.put(b.toByte) @@ -289,9 +240,15 @@ class AzureStorageFS(val credentialsJSON: Option[String] = None) extends FS { val response = retryTransientErrors { bb.clear() - client.downloadStreamWithResponse( - bbOS, new BlobRange(pos, count), - null, null, false, timeout, null) + getBlobClient(url).downloadStreamWithResponse( + bbOS, + new BlobRange(pos, count), + null, + null, + false, + timeout, + null, + ) } if (response.getStatusCode >= 200 && response.getStatusCode < 300) { @@ -365,17 +322,17 @@ class AzureStorageFS(val credentialsJSON: Option[String] = None) extends FS { options.setPrefix(prefix) val prefixMatches = blobContainerClient.listBlobs(options, timeout) - prefixMatches.forEach(blobItem => { + prefixMatches.forEach { blobItem => assert(!blobItem.isPrefix) getBlobClient(url.withPath(blobItem.getName)).delete() - }) + } } else { - try { + try if (fileListEntry(url).isFile) { blobClient.delete() } - } catch { - case e: FileNotFoundException => + catch { + case _: FileNotFoundException => } } } @@ -388,10 +345,7 @@ class 
AzureStorageFS(val credentialsJSON: Option[String] = None) extends FS { // collect all children of this directory (blobs and subdirectories) val prefixMatches = blobContainerClient.listBlobsByHierarchy(prefix) - prefixMatches.forEach(blobItem => { - val blobPath = dropTrailingSlash(url.withPath(blobItem.getName).toString()) - statList += AzureStorageFileListEntry(blobPath, blobItem) - }) + prefixMatches.forEach(blobItem => statList += AzureStorageFileListEntry(url, blobItem)) statList.toArray } @@ -400,35 +354,39 @@ class AzureStorageFS(val credentialsJSON: Option[String] = None) extends FS { globWithPrefix(prefix = url.withPath(""), path = dropTrailingSlash(url.path)) } - override def fileListEntry(url: URL): FileListEntry = retryTransientErrors { + override def fileStatus(url: AzureStorageFSURL): FileStatus = retryTransientErrors { if (url.path == "") { - return new BlobStorageFileListEntry(url.toString, null, 0, true) + return AzureStorageFileListEntry.dir(url) } - val blobClient: BlobClient = getBlobClient(url) - val blobContainerClient: BlobContainerClient = getContainerClient(url) + val blobClient = getBlobClient(url) + val blobProperties = + try + blobClient.getProperties + catch { + case e: BlobStorageException if e.getStatusCode == 404 => + throw new FileNotFoundException(url.toString) + } - val prefix = dropTrailingSlash(url.path) + "/" - val options: ListBlobsOptions = new ListBlobsOptions().setPrefix(prefix).setMaxResultsPerPage(1) - val prefixMatches = blobContainerClient.listBlobs(options, timeout) - val isDir = prefixMatches.iterator().hasNext + new BlobStorageFileStatus( + url.toString, + blobProperties.getLastModified.toEpochSecond, + blobProperties.getBlobSize, + ) + } - val filename = dropTrailingSlash(url.toString) + override def fileListEntry(url: URL): FileListEntry = { + if (url.getPath == "") + return AzureStorageFileListEntry.dir(url) - val blobProperties = if (!isDir) { - try { - blobClient.getProperties - } catch { - case e: BlobStorageException => - if (e.getStatusCode == 404) - throw new FileNotFoundException(s"File not found: $filename") - else - throw e - } - } else - null + val it = { + val containerClient = getContainerClient(url) + val options = new ListBlobsOptions().setPrefix(dropTrailingSlash(url.getPath)) + val prefixMatches = containerClient.listBlobsByHierarchy("/", options, timeout) + prefixMatches.iterator() + }.asScala.map(AzureStorageFileListEntry.apply(url, _)) - AzureStorageFileListEntry(filename, isDir, blobProperties) + fileListEntryFromIterator(url, it) } override def eTag(url: URL): Some[String] = @@ -436,9 +394,8 @@ class AzureStorageFS(val credentialsJSON: Option[String] = None) extends FS { Some(getBlobClient(url).getProperties.getETag) } - def makeQualified(filename: String): String = { - AzureStorageFS.parseUrl(filename) + parseUrl(filename) filename } } diff --git a/hail/src/main/scala/is/hail/io/fs/FS.scala b/hail/src/main/scala/is/hail/io/fs/FS.scala index 158334ff58a..aa9b00cccf4 100644 --- a/hail/src/main/scala/is/hail/io/fs/FS.scala +++ b/hail/src/main/scala/is/hail/io/fs/FS.scala @@ -1,38 +1,35 @@ package is.hail.io.fs +import is.hail.{HailContext, HailFeatureFlags} import is.hail.backend.BroadcastValue import is.hail.io.compress.{BGzipInputStream, BGzipOutputStream} import is.hail.io.fs.FSUtil.{containsWildcard, dropTrailingSlash} import is.hail.services._ import is.hail.utils._ -import is.hail.{HailContext, HailFeatureFlags} -import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream -import 
org.apache.commons.io.IOUtils -import org.apache.hadoop + +import scala.collection.mutable +import scala.io.Source import java.io._ import java.nio.ByteBuffer import java.nio.charset._ import java.nio.file.FileSystems import java.util.zip.GZIPOutputStream -import scala.collection.mutable -import scala.io.Source - -trait Positioned { - def getPosition: Long -} -trait Seekable extends Positioned { - def seek(pos: Long): Unit -} +import org.apache.commons.compress.compressors.gzip.GzipCompressorInputStream +import org.apache.commons.io.IOUtils +import org.apache.hadoop +import org.apache.log4j.Logger -class WrappedSeekableDataInputStream(is: SeekableInputStream) extends DataInputStream(is) with Seekable { +class WrappedSeekableDataInputStream(is: SeekableInputStream) + extends DataInputStream(is) with Seekable { def getPosition: Long = is.getPosition def seek(pos: Long): Unit = is.seek(pos) } -class WrappedPositionedDataOutputStream(os: PositionedOutputStream) extends DataOutputStream(os) with Positioned { +class WrappedPositionedDataOutputStream(os: PositionedOutputStream) + extends DataOutputStream(os) with Positioned { def getPosition: Long = os.getPosition } @@ -46,13 +43,11 @@ class WrappedPositionOutputStream(os: OutputStream) extends OutputStream with Po count += 1 } - override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { + override def write(bytes: Array[Byte], off: Int, len: Int): Unit = os.write(bytes, off, len) - } - override def close(): Unit = { + override def close(): Unit = os.close() - } def getPosition: Long = count } @@ -61,26 +56,56 @@ trait FSURL { def getPath: String } -trait FileListEntry { +trait FileStatus { def getPath: String + def getActualUrl: String def getModificationTime: java.lang.Long def getLen: Long - def isDirectory: Boolean def isSymlink: Boolean - def isFile: Boolean def getOwner: String + def isFileOrFileAndDirectory: Boolean = true } -class BlobStorageFileListEntry(path: String, modificationTime: java.lang.Long, size: Long, isDir: Boolean) extends FileListEntry { - def getPath: String = path +trait FileListEntry extends FileStatus { + def isFile: Boolean + def isDirectory: Boolean + override def isFileOrFileAndDirectory: Boolean = isFile +} + +class BlobStorageFileStatus( + actualUrl: String, + modificationTime: java.lang.Long, + size: Long, +) extends FileStatus { + // NB: it is called getPath but it *must* return the URL *with* the scheme. 
+ def getPath: String = + dropTrailingSlash( + actualUrl + ) // getPath is a backwards compatible method: in the past, Hail dropped trailing slashes + def getActualUrl: String = actualUrl def getModificationTime: java.lang.Long = modificationTime def getLen: Long = size - def isDirectory: Boolean = isDir - def isFile: Boolean = !isDir def isSymlink: Boolean = false def getOwner: String = null } +class BlobStorageFileListEntry( + actualUrl: String, + modificationTime: java.lang.Long, + size: Long, + isDir: Boolean, +) extends BlobStorageFileStatus( + actualUrl, + modificationTime, + size, + ) with FileListEntry { + def isDirectory: Boolean = isDir + def isFile: Boolean = !isDir + override def isFileOrFileAndDirectory = isFile + override def toString: String = s"BSFLE($actualUrl $modificationTime $size $isDir)" + +} + trait CompressionCodec { def makeInputStream(is: InputStream): InputStream @@ -100,6 +125,8 @@ object BGZipCompressionCodec extends CompressionCodec { def makeOutputStream(os: OutputStream): OutputStream = new BGzipOutputStream(os) } +class FileAndDirectoryException(message: String) extends RuntimeException(message) + object FSUtil { def dropTrailingSlash(path: String): String = { if (path.isEmpty) @@ -203,16 +230,16 @@ abstract class FSPositionedOutputStream(val capacity: Int) extends OutputStream protected[this] val bb: ByteBuffer = ByteBuffer.allocate(capacity) protected[this] var pos: Long = 0 - def flush(): Unit + def flush(): Unit - def write(i: Int): Unit = { + def write(i: Int): Unit = { if (bb.remaining() == 0) flush() bb.put(i.toByte) pos += 1 } - override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { + override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { var i = off var remaining = len while (remaining > 0) { @@ -232,7 +259,7 @@ abstract class FSPositionedOutputStream(val capacity: Int) extends OutputStream object FS { def cloudSpecificFS( credentialsPath: String, - flags: Option[HailFeatureFlags] + flags: Option[HailFeatureFlags], ): FS = retryTransientErrors { val cloudSpecificFS = using(new FileInputStream(credentialsPath)) { is => val credentialsStr = Some(IOUtils.toString(is, Charset.defaultCharset())) @@ -240,12 +267,16 @@ object FS { case Some("gcp") => val requesterPaysConfiguration = flags.flatMap { flags => RequesterPaysConfiguration.fromFlags( - flags.get("gcs_requester_pays_project"), flags.get("gcs_requester_pays_buckets") + flags.get("gcs_requester_pays_project"), + flags.get("gcs_requester_pays_buckets"), ) } new GoogleStorageFS(credentialsStr, requesterPaysConfiguration) case Some("azure") => - new AzureStorageFS(credentialsStr) + sys.env.get("HAIL_TERRA") match { + case Some(_) => new TerraAzureStorageFS() + case None => new AzureStorageFS(credentialsStr) + } case Some(cloud) => throw new IllegalArgumentException(s"Bad cloud: $cloud") case None => @@ -253,28 +284,38 @@ object FS { } } - new RouterFS(Array(cloudSpecificFS, new HadoopFS(new SerializableHadoopConfiguration(new hadoop.conf.Configuration())))) + new RouterFS(Array( + cloudSpecificFS, + new HadoopFS(new SerializableHadoopConfiguration(new hadoop.conf.Configuration())), + )) } + + private val log = Logger.getLogger(getClass.getName()) } trait FS extends Serializable { type URL <: FSURL + import FS.log + def parseUrl(filename: String): URL def validUrl(filename: String): Boolean def urlAddPathComponent(url: URL, component: String): URL - final def openCachedNoCompression(filename: String): SeekableDataInputStream = openNoCompression(filename) + final def 
openCachedNoCompression(filename: String): SeekableDataInputStream = + openNoCompression(filename) def openCachedNoCompression(url: URL): SeekableDataInputStream = openNoCompression(url) - final def createCachedNoCompression(filename: String): PositionedDataOutputStream = createNoCompression(filename) + final def createCachedNoCompression(filename: String): PositionedDataOutputStream = + createNoCompression(filename) def createCachedNoCompression(url: URL): PositionedDataOutputStream = createNoCompression(url) - final def writeCached(filename: String)(writer: PositionedDataOutputStream => Unit) = writePDOS(filename)(writer) + final def writeCached(filename: String)(writer: PositionedDataOutputStream => Unit) = + writePDOS(filename)(writer) def writeCached(url: URL)(writer: PositionedDataOutputStream => Unit) = writePDOS(url)(writer) @@ -331,19 +372,19 @@ trait FS extends Serializable { "" } - final def openNoCompression(filename: String): SeekableDataInputStream = openNoCompression(parseUrl(filename)) + final def openNoCompression(filename: String): SeekableDataInputStream = + openNoCompression(parseUrl(filename)) def openNoCompression(url: URL): SeekableDataInputStream final def readNoCompression(filename: String): Array[Byte] = readNoCompression(parseUrl(filename)) def readNoCompression(url: URL): Array[Byte] = retryTransientErrors { - using(openNoCompression(url)) { is => - IOUtils.toByteArray(is) - } + using(openNoCompression(url))(is => IOUtils.toByteArray(is)) } - final def createNoCompression(filename: String): PositionedDataOutputStream = createNoCompression(parseUrl(filename)) + final def createNoCompression(filename: String): PositionedDataOutputStream = + createNoCompression(parseUrl(filename)) def createNoCompression(url: URL): PositionedDataOutputStream @@ -351,11 +392,13 @@ trait FS extends Serializable { def mkDir(url: URL): Unit = () - final def listDirectory(filename: String): Array[FileListEntry] = listDirectory(parseUrl(filename)) + final def listDirectory(filename: String): Array[FileListEntry] = + listDirectory(parseUrl(filename)) def listDirectory(url: URL): Array[FileListEntry] - final def delete(filename: String, recursive: Boolean): Unit = delete(parseUrl(filename), recursive) + final def delete(filename: String, recursive: Boolean): Unit = + delete(parseUrl(filename), recursive) def delete(url: URL, recursive: Boolean): Unit @@ -363,7 +406,7 @@ trait FS extends Serializable { def glob(url: URL): Array[FileListEntry] - def globWithPrefix(prefix: URL, path: String) = { + def globWithPrefix(prefix: URL, path: String): Array[FileListEntry] = { val components = if (path == "") Array.empty[String] @@ -377,9 +420,9 @@ trait FS extends Serializable { if (i == components.length) { var t = fs if (t == null) { - try { + try t = fileListEntry(prefix) - } catch { + catch { case _: FileNotFoundException => } } @@ -407,13 +450,92 @@ trait FS extends Serializable { ab.toArray } - def globAll(filenames: Iterable[String]): Array[FileListEntry] = filenames.flatMap((x: String) => glob(x)).toArray + def globAll(filenames: Iterable[String]): Array[FileListEntry] = + filenames.flatMap((x: String) => glob(x)).toArray final def eTag(filename: String): Option[String] = eTag(parseUrl(filename)) /** Return the file's HTTP etag, if the underlying file system supports etags. 
*/ def eTag(url: URL): Option[String] + final def fileStatus(filename: String): FileStatus = fileStatus(parseUrl(filename)) + + def fileStatus(url: URL): FileStatus + + protected def fileListEntryFromIterator( + url: URL, + it: Iterator[FileListEntry], + ): FileListEntry = { + val urlStr = url.toString + val noSlash = dropTrailingSlash(urlStr) + val withSlash = noSlash + "/" + + var continue = it.hasNext + var fileFle: FileListEntry = null + var trailingSlashFle: FileListEntry = null + var dirFle: FileListEntry = null + while (continue) { + val fle = it.next() + + if (fle.isFile) { + if (fle.getActualUrl == noSlash) { + fileFle = fle + } else if (fle.getActualUrl == withSlash) { + // This is a *blob* whose name has a trailing slash e.g. "gs://bucket/object/". Users + // really ought to avoid creating these. + trailingSlashFle = fle + } + } else if (fle.isDirectory && dropTrailingSlash(fle.getActualUrl) == noSlash) { + // In Google, "directory" entries always have a trailing slash. + // + // In Azure, "directory" entries never have a trailing slash. + dirFle = fle + } + + continue = + it.hasNext && (fle.getActualUrl <= withSlash) // cloud storage APIs return blobs in alphabetical order, so we need not keep searching after withSlash + } + + if (fileFle != null) { + if (dirFle != null) { + if (trailingSlashFle != null) { + throw new FileAndDirectoryException( + s"${url.toString} appears twice as a file (once with and once without a trailing slash) and once as a directory." + ) + } else { + throw new FileAndDirectoryException( + s"${url.toString} appears as both file ${fileFle.getActualUrl} and directory ${dirFle.getActualUrl}." + ) + } + } else { + if (trailingSlashFle != null) { + log.warn( + s"Two blobs exist matching ${url.toString}: once with and once without a trailing slash. We will return the one without a trailing slash." + ) + } + fileFle + } + } else { + if (dirFle != null) { + if (trailingSlashFle != null) { + log.warn( + s"A blob with a literal trailing slash exists as well as blobs with that prefix. We will treat this as a directory. ${url.toString}" + ) + } + dirFle + } else { + if (trailingSlashFle != null) { + throw new FileNotFoundException( + s"A blob with a literal trailing slash exists. These are sometimes used to indicate empty directories. " + + s"Hail does not support this behavior. This folder is treated as if it does not exist. 
${url.toString}" + ) + } else { + throw new FileNotFoundException(url.toString) + } + } + } + } + final def fileListEntry(filename: String): FileListEntry = fileListEntry(parseUrl(filename)) def fileListEntry(url: URL): FileListEntry @@ -422,12 +544,13 @@ trait FS extends Serializable { final def deleteOnExit(filename: String): Unit = deleteOnExit(parseUrl(filename)) - def deleteOnExit(url: URL): Unit = { + def deleteOnExit(url: URL): Unit = Runtime.getRuntime.addShutdownHook( - new Thread(() => delete(url, recursive = false))) - } + new Thread(() => delete(url, recursive = false)) + ) - final def open(filename: String, codec: CompressionCodec): InputStream = open(parseUrl(filename), codec) + final def open(filename: String, codec: CompressionCodec): InputStream = + open(parseUrl(filename), codec) def open(url: URL, codec: CompressionCodec): InputStream = { val is = openNoCompression(url) @@ -443,7 +566,8 @@ trait FS extends Serializable { def open(url: URL): InputStream = open(url, gzAsBGZ = false) - final def open(filename: String, gzAsBGZ: Boolean): InputStream = open(parseUrl(filename), gzAsBGZ) + final def open(filename: String, gzAsBGZ: Boolean): InputStream = + open(parseUrl(filename), gzAsBGZ) def open(url: URL, gzAsBGZ: Boolean): InputStream = open(url, getCodecFromPath(url.getPath, gzAsBGZ)) @@ -460,39 +584,39 @@ trait FS extends Serializable { os } - final def write(filename: String)(writer: OutputStream => Unit): Unit = write(parseUrl(filename))(writer) + final def write(filename: String)(writer: OutputStream => Unit): Unit = + write(parseUrl(filename))(writer) def write(url: URL)(writer: OutputStream => Unit): Unit = using(create(url))(writer) - final def writePDOS(filename: String)(writer: PositionedDataOutputStream => Unit): Unit = writePDOS(parseUrl(filename))(writer) + final def writePDOS(filename: String)(writer: PositionedDataOutputStream => Unit): Unit = + writePDOS(parseUrl(filename))(writer) def writePDOS(url: URL)(writer: PositionedDataOutputStream => Unit): Unit = using(create(url))(os => writer(outputStreamToPositionedDataOutputStream(os))) final def getFileSize(filename: String): Long = getFileSize(parseUrl(filename)) - def getFileSize(url: URL): Long = fileListEntry(url).getLen + def getFileSize(url: URL): Long = fileStatus(url).getLen final def isFile(filename: String): Boolean = isFile(parseUrl(filename)) - final def isFile(url: URL): Boolean = { - try { - fileListEntry(url).isFile - } catch { + final def isFile(url: URL): Boolean = + try + fileStatus(url).isFileOrFileAndDirectory + catch { case _: FileNotFoundException => false } - } final def isDir(filename: String): Boolean = isDir(parseUrl(filename)) - final def isDir(url: URL): Boolean = { - try { + final def isDir(url: URL): Boolean = + try fileListEntry(url).isDirectory - } catch { + catch { case _: FileNotFoundException => false } - } final def exists(filename: String): Boolean = exists(parseUrl(filename)) @@ -507,13 +631,12 @@ trait FS extends Serializable { final def copy(src: String, dst: String): Unit = copy(src, dst, false) - final def copy(src: String, dst: String, deleteSource: Boolean): Unit = copy(parseUrl(src), parseUrl(dst), deleteSource) + final def copy(src: String, dst: String, deleteSource: Boolean): Unit = + copy(parseUrl(src), parseUrl(dst), deleteSource) def copy(src: URL, dst: URL, deleteSource: Boolean = false): Unit = { using(openNoCompression(src)) { is => - using(createNoCompression(dst)) { os => - IOUtils.copy(is, os) - } + using(createNoCompression(dst))(os => IOUtils.copy(is, 
os)) } if (deleteSource) delete(src, recursive = false) @@ -521,19 +644,21 @@ trait FS extends Serializable { final def copyRecode(src: String, dst: String): Unit = copyRecode(src, dst, false) - final def copyRecode(src: String, dst: String, deleteSource: Boolean): Unit = copyRecode(parseUrl(src), parseUrl(dst), deleteSource) + final def copyRecode(src: String, dst: String, deleteSource: Boolean): Unit = + copyRecode(parseUrl(src), parseUrl(dst), deleteSource) def copyRecode(src: URL, dst: URL, deleteSource: Boolean = false): Unit = { - using(open(src)) { is => - using(create(dst)) { os => - IOUtils.copy(is, os) - } - } + using(open(src))(is => using(create(dst))(os => IOUtils.copy(is, os))) if (deleteSource) delete(src, recursive = false) } - def readLines[T](filename: String, filtAndReplace: TextInputFilterAndReplace = TextInputFilterAndReplace())(reader: Iterator[WithContext[String]] => T): T = { + def readLines[T]( + filename: String, + filtAndReplace: TextInputFilterAndReplace = TextInputFilterAndReplace(), + )( + reader: Iterator[WithContext[String]] => T + ): T = { using(open(filename)) { is => val lines = Source.fromInputStream(is) @@ -548,7 +673,8 @@ trait FS extends Serializable { } } - def writeTable(filename: String, lines: Traversable[String], header: Option[String] = None): Unit = { + def writeTable(filename: String, lines: Traversable[String], header: Option[String] = None) + : Unit = { using(new OutputStreamWriter(create(filename))) { fw => header.foreach { h => fw.write(h) @@ -567,8 +693,8 @@ trait FS extends Serializable { numPartFilesExpected: Int, deleteSource: Boolean = true, header: Boolean = true, - partFilesOpt: Option[IndexedSeq[String]] = None - ) { + partFilesOpt: Option[IndexedSeq[String]] = None, + ): Unit = { if (!exists(sourceFolder + "/_SUCCESS")) fatal("write failed: no success indicator found") @@ -581,24 +707,28 @@ trait FS extends Serializable { else if (!header && headerFileListEntry.nonEmpty) fatal(s"Found unexpected header file") - val partFileListEntries = partFilesOpt match { + val partFileStatuses: Array[_ <: FileStatus] = partFilesOpt match { case None => glob(sourceFolder + "/part-*") - case Some(files) => files.map(f => fileListEntry(sourceFolder + "/" + f)).toArray + case Some(files) => files.map(f => fileStatus(sourceFolder + "/" + f)).toArray + } + + val sortedPartFileStatuses = partFileStatuses.sortBy { fileStatus => + getPartNumber(fileStatus.getPath) } - val sortedPartFileListEntries = partFileListEntries.sortBy(fs => getPartNumber(new hadoop.fs.Path(fs.getPath).getName)) - if (sortedPartFileListEntries.length != numPartFilesExpected) - fatal(s"Expected $numPartFilesExpected part files but found ${ sortedPartFileListEntries.length }") - val filesToMerge = headerFileListEntry ++ sortedPartFileListEntries + if (sortedPartFileStatuses.length != numPartFilesExpected) + fatal(s"Expected $numPartFilesExpected part files but found ${sortedPartFileStatuses.length}") + + val filesToMerge: Array[FileStatus] = headerFileListEntry ++ sortedPartFileStatuses - info(s"merging ${ filesToMerge.length } files totalling " + - s"${ readableBytes(sortedPartFileListEntries.map(_.getLen).sum) }...") + info(s"merging ${filesToMerge.length} files totalling " + + s"${readableBytes(filesToMerge.map(_.getLen).sum)}...") val (_, dt) = time { copyMergeList(filesToMerge, destinationFile, deleteSource) } - info(s"while writing:\n $destinationFile\n merge time: ${ formatTime(dt) }") + info(s"while writing:\n $destinationFile\n merge time: ${formatTime(dt)}") if 
(deleteSource) { delete(sourceFolder, recursive = true) @@ -607,55 +737,53 @@ trait FS extends Serializable { } } - def copyMergeList(srcFileListEntries: Array[FileListEntry], destFilename: String, deleteSource: Boolean = true) { + def copyMergeList( + srcFileStatuses: Array[_ <: FileStatus], + destFilename: String, + deleteSource: Boolean = true, + ): Unit = { val codec = Option(getCodecFromPath(destFilename)) val isBGzip = codec.exists(_ == BGZipCompressionCodec) - require(srcFileListEntries.forall { - fileListEntry => fileListEntry.getPath != destFilename && fileListEntry.isFile + require(srcFileStatuses.forall { + fileStatus => fileStatus.getPath != destFilename && fileStatus.isFileOrFileAndDirectory }) using(createNoCompression(destFilename)) { os => - var i = 0 - while (i < srcFileListEntries.length) { - val fileListEntry = srcFileListEntries(i) - val lenAdjust: Long = if (isBGzip && i < srcFileListEntries.length - 1) + var i = 0 + while (i < srcFileStatuses.length) { + val fileListEntry = srcFileStatuses(i) + val lenAdjust: Long = if (isBGzip && i < srcFileStatuses.length - 1) -28 else 0 using(openNoCompression(fileListEntry.getPath)) { is => - hadoop.io.IOUtils.copyBytes(is, os, - fileListEntry.getLen + lenAdjust, - false) + hadoop.io.IOUtils.copyBytes(is, os, fileListEntry.getLen + lenAdjust, false) } i += 1 } } if (deleteSource) { - srcFileListEntries.foreach { fileListEntry => - delete(fileListEntry.getPath.toString, recursive = true) - } + srcFileStatuses.foreach(fileStatus => delete(fileStatus.getPath, recursive = true)) } } def concatenateFiles(sourceNames: Array[String], destFilename: String): Unit = { - val fileListEntries = sourceNames.map(fileListEntry(_)) + val fileStatuses = sourceNames.map(fileStatus(_)) - info(s"merging ${ fileListEntries.length } files totalling " + - s"${ readableBytes(fileListEntries.map(_.getLen).sum) }...") + info(s"merging ${fileStatuses.length} files totalling " + + s"${readableBytes(fileStatuses.map(_.getLen).sum)}...") - val (_, timing) = time(copyMergeList(fileListEntries, destFilename, deleteSource = false)) + val (_, timing) = time(copyMergeList(fileStatuses, destFilename, deleteSource = false)) - info(s"while writing:\n $destFilename\n merge time: ${ formatTime(timing) }") + info(s"while writing:\n $destFilename\n merge time: ${formatTime(timing)}") } final def touch(filename: String): Unit = touch(parseUrl(filename)) - def touch(url: URL): Unit = { + def touch(url: URL): Unit = using(createNoCompression(url))(_ => ()) - } lazy val broadcast: BroadcastValue[FS] = HailContext.backend.broadcast(this) diff --git a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala index f04b243667a..a7760f92e18 100644 --- a/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/GoogleStorageFS.scala @@ -1,31 +1,32 @@ package is.hail.io.fs - -import com.google.api.client.googleapis.json.GoogleJsonResponseException -import com.google.auth.oauth2.ServiceAccountCredentials -import com.google.cloud.http.HttpTransportOptions -import com.google.cloud.storage.Storage.{BlobGetOption, BlobListOption, BlobWriteOption, BlobSourceOption} -import com.google.cloud.storage.{Blob, BlobId, BlobInfo, Storage, StorageException, StorageOptions} -import com.google.cloud.{ReadChannel, WriteChannel} import is.hail.io.fs.FSUtil.dropTrailingSlash -import is.hail.services.{retryTransientErrors, isTransientError} +import is.hail.services.{isTransientError, retryTransientErrors} import 
is.hail.utils._ -import org.apache.log4j.Logger + +import scala.jdk.CollectionConverters._ import java.io.{ByteArrayInputStream, FileNotFoundException, IOException} -import java.net.URI import java.nio.ByteBuffer import java.nio.file.Paths -import scala.jdk.CollectionConverters.{asJavaIterableConverter, asScalaIteratorConverter, iterableAsScalaIterableConverter} +import com.google.api.client.googleapis.json.GoogleJsonResponseException +import com.google.auth.oauth2.ServiceAccountCredentials +import com.google.cloud.{ReadChannel, WriteChannel} +import com.google.cloud.http.HttpTransportOptions +import com.google.cloud.storage.{Blob, BlobId, BlobInfo, Storage, StorageException, StorageOptions} +import com.google.cloud.storage.Storage.{ + BlobGetOption, BlobListOption, BlobSourceOption, BlobWriteOption, +} +import org.apache.log4j.Logger case class GoogleStorageFSURL(bucket: String, path: String) extends FSURL { - def addPathComponent(c: String): GoogleStorageFSURL = { + def addPathComponent(c: String): GoogleStorageFSURL = if (path == "") withPath(c) else withPath(s"$path/$c") - } + def withPath(newPath: String): GoogleStorageFSURL = GoogleStorageFSURL(bucket, newPath) def fromString(s: String): GoogleStorageFSURL = GoogleStorageFS.parseUrl(s) @@ -38,13 +39,12 @@ case class GoogleStorageFSURL(bucket: String, path: String) extends FSURL { } } - object GoogleStorageFS { private val log = Logger.getLogger(getClass.getName()) private[this] val GCS_URI_REGEX = "^gs:\\/\\/([a-z0-9_\\-\\.]+)(\\/.*)?".r def parseUrl(filename: String): GoogleStorageFSURL = { - val scheme = new URI(filename).getScheme + val scheme = filename.split(":")(0) if (scheme == null || scheme != "gs") { throw new IllegalArgumentException(s"Invalid scheme, expected gs: $scheme") } @@ -55,7 +55,9 @@ object GoogleStorageFS { val maybePath = m.group(2) val path = Paths.get(if (maybePath == null) "" else maybePath.stripPrefix("/")) GoogleStorageFSURL(bucket, path.normalize().toString) - case None => throw new IllegalArgumentException(s"GCS URI must be of the form: gs://bucket/path, found $filename") + case None => throw new IllegalArgumentException( + s"GCS URI must be of the form: gs://bucket/path, found $filename" + ) } } } @@ -64,26 +66,31 @@ object GoogleStorageFileListEntry { def apply(blob: Blob): BlobStorageFileListEntry = { val isDir = blob.isDirectory - val name = dropTrailingSlash(blob.getName) - new BlobStorageFileListEntry( - s"gs://${ blob.getBucket }/$name", + s"gs://${blob.getBucket}/${blob.getName}", if (isDir) null else blob.getUpdateTimeOffsetDateTime.toInstant().toEpochMilli(), blob.getSize, - isDir) + isDir, + ) } + + def dir(url: GoogleStorageFSURL): BlobStorageFileListEntry = + return new BlobStorageFileListEntry(url.toString, null, 0, true) } object RequesterPaysConfiguration { - def fromFlags(requesterPaysProject: String, requesterPaysBuckets: String): Option[RequesterPaysConfiguration] = { + def fromFlags(requesterPaysProject: String, requesterPaysBuckets: String) + : Option[RequesterPaysConfiguration] = { if (requesterPaysProject == null) { if (requesterPaysBuckets == null) { None } else { - fatal(s"Expected gcs_requester_pays_buckets flag to be unset when gcs_requester_pays_project is unset, but instead found: $requesterPaysBuckets") + fatal( + s"Expected gcs_requester_pays_buckets flag to be unset when gcs_requester_pays_project is unset, but instead found: $requesterPaysBuckets" + ) } } else { val buckets = if (requesterPaysBuckets == null) { @@ -96,15 +103,14 @@ object RequesterPaysConfiguration { } } 
- case class RequesterPaysConfiguration( val project: String, - val buckets: Option[Set[String]] = None + val buckets: Option[Set[String]] = None, ) extends Serializable class GoogleStorageFS( private[this] val serviceAccountKey: Option[String] = None, - private[this] var requesterPaysConfiguration: Option[RequesterPaysConfiguration] = None + private[this] var requesterPaysConfiguration: Option[RequesterPaysConfiguration] = None, ) extends FS { type URL = GoogleStorageFSURL @@ -117,15 +123,14 @@ class GoogleStorageFS( def urlAddPathComponent(url: URL, component: String): URL = url.addPathComponent(component) - def getConfiguration(): Option[RequesterPaysConfiguration] = { + def getConfiguration(): Option[RequesterPaysConfiguration] = requesterPaysConfiguration - } - def setConfiguration(config: Any): Unit = { + def setConfiguration(config: Any): Unit = requesterPaysConfiguration = config.asInstanceOf[Option[RequesterPaysConfiguration]] - } - private[this] def requesterPaysOptions[T](bucket: String, makeUserProjectOption: String => T): Seq[T] = { + private[this] def requesterPaysOptions[T](bucket: String, makeUserProjectOption: String => T) + : Seq[T] = { requesterPaysConfiguration match { case None => Seq() @@ -144,34 +149,37 @@ class GoogleStorageFS( exc: Throwable, makeRequest: Seq[U] => T, makeUserProjectOption: String => U, - bucket: String - ): T = { + bucket: String, + ): T = if (isRequesterPaysException(exc)) { makeRequest(requesterPaysOptions(bucket, makeUserProjectOption)) } else { throw exc } - } def isRequesterPaysException(exc: Throwable): Boolean = exc match { case exc: IOException if exc.getCause() != null => isRequesterPaysException(exc.getCause()) case exc: StorageException => - exc.getMessage != null && (exc.getMessage.equals("userProjectMissing") || (exc.getCode == 400 && exc.getMessage.contains("requester pays"))) + exc.getMessage != null && (exc.getMessage.equals( + "userProjectMissing" + ) || (exc.getCode == 400 && exc.getMessage.contains("requester pays"))) case exc: GoogleJsonResponseException => - exc.getMessage != null && (exc.getMessage.equals("userProjectMissing") || (exc.getStatusCode == 400 && exc.getMessage.contains("requester pays"))) - case exc: Throwable => + exc.getMessage != null && (exc.getMessage.equals( + "userProjectMissing" + ) || (exc.getStatusCode == 400 && exc.getMessage.contains("requester pays"))) + case _: Throwable => false } private[this] def handleRequesterPays[T, U]( makeRequest: Seq[U] => T, makeUserProjectOption: String => U, - bucket: String + bucket: String, ): T = { - try { + try makeRequest(Seq()) - } catch { + catch { case exc: Throwable => retryIfRequesterPays(exc, makeRequest, makeUserProjectOption, bucket) } @@ -193,7 +201,8 @@ class GoogleStorageFS( log.info("Initializing google storage client from service account key") StorageOptions.newBuilder() .setCredentials( - ServiceAccountCredentials.fromStream(new ByteArrayInputStream(keyData.getBytes))) + ServiceAccountCredentials.fromStream(new ByteArrayInputStream(keyData.getBytes)) + ) .setTransportOptions(transportOptions) .build() .getService @@ -210,7 +219,7 @@ class GoogleStorageFS( try { if (reader == null) { val opts = options.getOrElse(FastSeq()) - reader = storage.reader(url.bucket, url.path, opts:_*) + reader = storage.reader(url.bucket, url.path, opts: _*) reader.seek(getPosition) } return reader.read(bb) @@ -253,11 +262,10 @@ class GoogleStorageFS( return n } - override def physicalSeek(newPos: Long): Unit = { + override def physicalSeek(newPos: Long): Unit = if (reader != 
null) { reader.seek(newPos) } - } } new WrappedSeekableDataInputStream(is) @@ -268,7 +276,7 @@ class GoogleStorageFS( } def createNoCompression(url: URL): PositionedDataOutputStream = retryTransientErrors { - log.info(f"createNoCompression: ${url}") + log.info(f"createNoCompression: $url") val blobId = BlobId.of(url.bucket, url.path) val blobInfo = BlobInfo.newBuilder(blobId) @@ -283,11 +291,11 @@ class GoogleStorageFS( } else { handleRequesterPays( { (options: Seq[BlobWriteOption]) => - writer = retryTransientErrors { storage.writer(blobInfo, options:_*) } + writer = retryTransientErrors(storage.writer(blobInfo, options: _*)) f }, BlobWriteOption.userProject, - url.bucket + url.bucket, ) } } @@ -304,7 +312,7 @@ class GoogleStorageFS( } override def close(): Unit = { - log.info(f"close: ${url}") + log.info(f"close: $url") if (!closed) { flush() retryTransientErrors { @@ -314,7 +322,7 @@ class GoogleStorageFS( } closed = true } - log.info(f"closed: ${url}") + log.info(f"closed: $url") } } @@ -325,14 +333,16 @@ class GoogleStorageFS( val srcId = BlobId.of(src.bucket, src.path) val dstId = BlobId.of(dst.bucket, dst.path) - // There is only one userProject for the whole request, the source takes precedence over the target. - // https://github.com/googleapis/java-storage/blob/0bd17b1f70e47081941a44f018e3098b37ba2c47/google-cloud-storage/src/main/java/com/google/cloud/storage/spi/v1/HttpStorageRpc.java#L1016-L1019 + /* There is only one userProject for the whole request, the source takes precedence over the + * target. */ + /* https://github.com/googleapis/java-storage/blob/0bd17b1f70e47081941a44f018e3098b37ba2c47/google-cloud-storage/src/main/java/com/google/cloud/storage/spi/v1/HttpStorageRpc.java#L1016-L1019 */ def retryCopyIfRequesterPays(exc: Exception, message: String, code: Int): Unit = { if (message == null) { throw exc } - val probablyNeedsRequesterPays = message.equals("userProjectMissing") || (code == 400 && message.contains("requester pays")) + val probablyNeedsRequesterPays = + message.equals("userProjectMissing") || (code == 400 && message.contains("requester pays")) if (!probablyNeedsRequesterPays) { throw exc } @@ -354,7 +364,10 @@ class GoogleStorageFS( .setTarget(dstId) .build() } else if (buckets.contains(src.bucket) || buckets.contains(dst.bucket)) { - throw new RuntimeException(s"both ${src.bucket} and ${dst.bucket} must be specified in the requester_pays_buckets to copy between these buckets", exc) + throw new RuntimeException( + s"both ${src.bucket} and ${dst.bucket} must be specified in the requester_pays_buckets to copy between these buckets", + exc, + ) } else { throw exc } @@ -373,7 +386,6 @@ class GoogleStorageFS( throw exc } - try { storage.copy( Storage.CopyRequest.newBuilder() @@ -394,9 +406,10 @@ class GoogleStorageFS( if (recursive) { var page = retryTransientErrors { handleRequesterPays( - (options: Seq[BlobListOption]) => storage.list(url.bucket, (BlobListOption.prefix(url.path) +: options):_*), + (options: Seq[BlobListOption]) => + storage.list(url.bucket, (BlobListOption.prefix(url.path) +: options): _*), BlobListOption.userProject, - url.bucket + url.bucket, ) } while (page != null) { @@ -408,11 +421,11 @@ class GoogleStorageFS( if (options.isEmpty) { storage.delete(blobs) } else { - blobs.asScala.foreach(storage.delete(_, options:_*)) + blobs.asScala.foreach(storage.delete(_, options: _*)) } }, BlobSourceOption.userProject, - url.bucket + url.bucket, ) } } @@ -421,9 +434,9 @@ class GoogleStorageFS( } else { // Storage.delete is idempotent. 
it returns a Boolean which is false if the file did not exist handleRequesterPays( - (options: Seq[BlobSourceOption]) => storage.delete(url.bucket, url.path, options:_*), + (options: Seq[BlobSourceOption]) => storage.delete(url.bucket, url.path, options: _*), BlobSourceOption.userProject, - url.bucket + url.bucket, ) } } @@ -437,9 +450,13 @@ class GoogleStorageFS( val blobs = retryTransientErrors { handleRequesterPays( - (options: Seq[BlobListOption]) => storage.list(url.bucket, (BlobListOption.prefix(path) +: BlobListOption.currentDirectory() +: options):_*), + (options: Seq[BlobListOption]) => + storage.list( + url.bucket, + (BlobListOption.prefix(path) +: BlobListOption.currentDirectory() +: options): _* + ), BlobListOption.userProject, - url.bucket + url.bucket, ) } @@ -449,31 +466,51 @@ class GoogleStorageFS( .toArray } - override def fileListEntry(url: URL): FileListEntry = retryTransientErrors { - val path = dropTrailingSlash(url.path) + private[this] def getBlob(url: URL) = retryTransientErrors { + handleRequesterPays( + (options: Seq[BlobGetOption]) => + storage.get(url.bucket, url.path, options: _*), + BlobGetOption.userProject _, + url.bucket, + ) + } + override def fileStatus(url: URL): FileStatus = retryTransientErrors { if (url.path == "") - return new BlobStorageFileListEntry(s"gs://${url.bucket}", null, 0, true) + return GoogleStorageFileListEntry.dir(url) - val blobs = retryTransientErrors { - handleRequesterPays( - (options: Seq[BlobListOption]) => storage.list(url.bucket, (BlobListOption.prefix(path) +: BlobListOption.currentDirectory() +: options):_*), - BlobListOption.userProject, - url.bucket - ) + val blob = getBlob(url) + + if (blob == null) { + throw new FileNotFoundException(url.toString) } - val it = blobs.getValues.iterator.asScala - while (it.hasNext) { - val b = it.next() - var name = b.getName - while (name.endsWith("/")) - name = name.dropRight(1) - if (name == path) - return GoogleStorageFileListEntry(b) + new BlobStorageFileStatus( + url.toString, + blob.getUpdateTimeOffsetDateTime.toInstant().toEpochMilli(), + blob.getSize, + ) + } + + override def fileListEntry(url: URL): FileListEntry = { + if (url.getPath == "") { + return GoogleStorageFileListEntry.dir(url) } - throw new FileNotFoundException(url.toString()) + val prefix = dropTrailingSlash(url.path) + val it = retryTransientErrors { + handleRequesterPays( + (options: Seq[BlobListOption]) => + storage.list( + url.bucket, + (BlobListOption.prefix(prefix) +: BlobListOption.currentDirectory() +: options): _* + ), + BlobListOption.userProject _, + url.bucket, + ) + }.iterateAll().asScala.map(GoogleStorageFileListEntry.apply(_)).iterator + + fileListEntryFromIterator(url, it) } override def eTag(url: URL): Some[String] = { @@ -481,10 +518,10 @@ class GoogleStorageFS( handleRequesterPays( (options: Seq[BlobGetOption]) => retryTransientErrors { - Some(storage.get(bucket, blob, options:_*).getEtag) + Some(storage.get(bucket, blob, options: _*).getEtag) }, BlobGetOption.userProject, - bucket + bucket, ) } diff --git a/hail/src/main/scala/is/hail/io/fs/HadoopFS.scala b/hail/src/main/scala/is/hail/io/fs/HadoopFS.scala index 545f2ff0f58..285cfd578c1 100644 --- a/hail/src/main/scala/is/hail/io/fs/HadoopFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/HadoopFS.scala @@ -1,20 +1,21 @@ package is.hail.io.fs import is.hail.utils._ -import org.apache.hadoop -import org.apache.hadoop.fs.{EtagSource, FSDataInputStream, FSDataOutputStream} -import org.apache.hadoop.io.MD5Hash -import java.io._ -import 
java.security.MessageDigest -import java.util.Base64 import scala.util.Try +import java.io._ + +import org.apache.hadoop +import org.apache.hadoop.fs.{EtagSource, FSDataInputStream, FSDataOutputStream} + class HadoopFileListEntry(fs: hadoop.fs.FileStatus) extends FileListEntry { val normalizedPath = fs.getPath def getPath: String = fs.getPath.toString + def getActualUrl: String = fs.getPath.toString + def getModificationTime: java.lang.Long = fs.getModificationTime def getLen: Long = fs.getLen @@ -39,12 +40,11 @@ object HadoopFS { override def flush(): Unit = if (!closed) os.flush() - override def close(): Unit = { + override def close(): Unit = if (!closed) { os.close() closed = true } - } def getPosition: Long = os.getPos } @@ -59,12 +59,11 @@ object HadoopFS { override def skip(n: Long): Long = is.skip(n) - override def close(): Unit = { + override def close(): Unit = if (!closed) { is.close() closed = true } - } def seek(pos: Long): Unit = is.seek(pos) @@ -72,18 +71,17 @@ object HadoopFS { } } +case class HadoopFSURL(path: String, conf: SerializableHadoopConfiguration) extends FSURL { + private[this] val unqualifiedHadoopPath = new hadoop.fs.Path(path) + val hadoopFs = unqualifiedHadoopPath.getFileSystem(conf.value) + val hadoopPath = hadoopFs.makeQualified(unqualifiedHadoopPath) -case class HadoopFSURL(val path: String, conf: SerializableHadoopConfiguration) extends FSURL { - val hadoopPath = new hadoop.fs.Path(path) - val hadoopFs = hadoopPath.getFileSystem(conf.value) - - def addPathComponent(c: String): HadoopFSURL = HadoopFSURL(s"$path/$c", conf) - def getPath: String = path + def addPathComponent(c: String): HadoopFSURL = HadoopFSURL(s"${hadoopPath.toString}/$c", conf) + def getPath: String = hadoopPath.toString def fromString(s: String): HadoopFSURL = HadoopFSURL(s, conf) - override def toString(): String = path + override def toString(): String = hadoopPath.toString } - class HadoopFS(private[this] var conf: SerializableHadoopConfiguration) extends FS { type URL = HadoopFSURL @@ -96,38 +94,39 @@ class HadoopFS(private[this] var conf: SerializableHadoopConfiguration) extends def getConfiguration(): SerializableHadoopConfiguration = conf - def setConfiguration(_conf: Any): Unit = { + def setConfiguration(_conf: Any): Unit = conf = _conf.asInstanceOf[SerializableHadoopConfiguration] - } def createNoCompression(url: URL): PositionedDataOutputStream = { val os = url.hadoopFs.create(url.hadoopPath) new WrappedPositionedDataOutputStream( - HadoopFS.toPositionedOutputStream(os)) + HadoopFS.toPositionedOutputStream(os) + ) } def openNoCompression(url: URL): SeekableDataInputStream = { - val is = try { - url.hadoopFs.open(url.hadoopPath) - } catch { - case e: FileNotFoundException => - if (isDir(url)) - throw new FileNotFoundException(s"'$url' is a directory (or native Table/MatrixTable)") - else - throw e - } + val is = + try + url.hadoopFs.open(url.hadoopPath) + catch { + case e: FileNotFoundException => + if (isDir(url)) + throw new FileNotFoundException(s"'$url' is a directory (or native Table/MatrixTable)") + else + throw e + } new WrappedSeekableDataInputStream( - HadoopFS.toSeekableInputStream(is)) + HadoopFS.toSeekableInputStream(is) + ) } - def getFileSystem(filename: String): hadoop.fs.FileSystem = { + def getFileSystem(filename: String): hadoop.fs.FileSystem = new hadoop.fs.Path(filename).getFileSystem(conf.value) - } def listDirectory(url: URL): Array[FileListEntry] = { - var statuses = url.hadoopFs.globStatus(url.hadoopPath) + val statuses = 
url.hadoopFs.globStatus(url.hadoopPath) if (statuses == null) { throw new FileNotFoundException(url.toString) } else { @@ -138,28 +137,24 @@ class HadoopFS(private[this] var conf: SerializableHadoopConfiguration) extends } } - override def mkDir(url: URL): Unit = { + override def mkDir(url: URL): Unit = url.hadoopFs.mkdirs(url.hadoopPath) - } - def remove(fname: String): Unit = { + def remove(fname: String): Unit = getFileSystem(fname).delete(new hadoop.fs.Path(fname), false) - } - def rmtree(dirname: String): Unit = { + def rmtree(dirname: String): Unit = getFileSystem(dirname).delete(new hadoop.fs.Path(dirname), true) - } - def delete(url: URL, recursive: Boolean) { + def delete(url: URL, recursive: Boolean): Unit = url.hadoopFs.delete(url.hadoopPath, recursive) - } override def globAll(filenames: Iterable[String]): Array[FileListEntry] = { filenames.flatMap { filename => - val statuses = glob(filename) - if (statuses.isEmpty) + val fles = glob(filename) + if (fles.isEmpty) warn(s"'$filename' refers to no files") - statuses + fles }.toArray } @@ -167,20 +162,28 @@ class HadoopFS(private[this] var conf: SerializableHadoopConfiguration) extends var files = url.hadoopFs.globStatus(url.hadoopPath) if (files == null) files = Array.empty - log.info(s"globbing path $url returned ${ files.length } files: ${ files.map(_.getPath.getName).mkString(",") }") + log.info( + s"globbing path $url returned ${files.length} files: ${files.map(_.getPath.getName).mkString(",")}" + ) files.map(fileListEntry => new HadoopFileListEntry(fileListEntry)) } - def fileListEntry(url: URL): FileListEntry = { - new HadoopFileListEntry(url.hadoopFs.getFileStatus(url.hadoopPath)) + override def fileStatus(url: URL): FileStatus = { + val fle = fileListEntry(url) + if (fle.isDirectory) { + throw new FileNotFoundException(url.getPath) + } + fle } - override def eTag(url: URL): Option[String] = { + def fileListEntry(url: URL): FileListEntry = + new HadoopFileListEntry(url.hadoopFs.getFileStatus(url.hadoopPath)) + + override def eTag(url: URL): Option[String] = if (url.hadoopFs.hasPathCapability(url.hadoopPath, "fs.capability.etags.available")) Some(url.hadoopFs.getFileStatus(url.hadoopPath).asInstanceOf[EtagSource].getEtag) else None - } def makeQualified(path: String): String = { val ppath = new hadoop.fs.Path(path) @@ -188,9 +191,8 @@ class HadoopFS(private[this] var conf: SerializableHadoopConfiguration) extends pathFS.makeQualified(ppath).toString } - override def deleteOnExit(url: URL): Unit = { + override def deleteOnExit(url: URL): Unit = url.hadoopFs.deleteOnExit(url.hadoopPath) - } def supportsScheme(scheme: String): Boolean = { if (scheme == "") { @@ -200,7 +202,7 @@ class HadoopFS(private[this] var conf: SerializableHadoopConfiguration) extends hadoop.fs.FileSystem.getFileSystemClass(scheme, conf.value) true } catch { - case e: hadoop.fs.UnsupportedFileSystemException => false + case _: hadoop.fs.UnsupportedFileSystemException => false case e: Exception => throw e } } diff --git a/hail/src/main/scala/is/hail/io/fs/RouterFS.scala b/hail/src/main/scala/is/hail/io/fs/RouterFS.scala index c6ad3d6c13c..7d8e9df3c52 100644 --- a/hail/src/main/scala/is/hail/io/fs/RouterFS.scala +++ b/hail/src/main/scala/is/hail/io/fs/RouterFS.scala @@ -8,21 +8,21 @@ case class RouterFSURL private (_url: FSURL, val fs: FS) extends FSURL { val url = _url.asInstanceOf[fs.URL] def getPath: String = url.getPath - def addPathComponent(component: String): RouterFSURL = { + + def addPathComponent(component: String): RouterFSURL = 
RouterFSURL(fs)(fs.urlAddPathComponent(url, component)) - } + override def toString(): String = url.toString } class RouterFS(fss: IndexedSeq[FS]) extends FS { type URL = RouterFSURL - def lookupFS(path: String): FS = { + def lookupFS(path: String): FS = fss.find(_.validUrl(path)) match { case Some(fs) => fs case None => throw new IllegalArgumentException(s"Unsupported URI: $path") } - } override def parseUrl(filename: String): URL = { val fs = lookupFS(filename) @@ -35,13 +35,16 @@ class RouterFS(fss: IndexedSeq[FS]) extends FS { def urlAddPathComponent(url: URL, component: String): URL = url.addPathComponent(component) - override def openCachedNoCompression(url: URL): SeekableDataInputStream = url.fs.openCachedNoCompression(url.url) + override def openCachedNoCompression(url: URL): SeekableDataInputStream = + url.fs.openCachedNoCompression(url.url) - override def createCachedNoCompression(url: URL): PositionedDataOutputStream = url.fs.createCachedNoCompression(url.url) + override def createCachedNoCompression(url: URL): PositionedDataOutputStream = + url.fs.createCachedNoCompression(url.url) def openNoCompression(url: URL): SeekableDataInputStream = url.fs.openNoCompression(url.url) - def createNoCompression(url: URL): PositionedDataOutputStream = url.fs.createNoCompression(url.url) + def createNoCompression(url: URL): PositionedDataOutputStream = + url.fs.createNoCompression(url.url) override def readNoCompression(url: URL): Array[Byte] = url.fs.readNoCompression(url.url) @@ -53,6 +56,8 @@ class RouterFS(fss: IndexedSeq[FS]) extends FS { def glob(url: URL): Array[FileListEntry] = url.fs.glob(url.url) + def fileStatus(url: URL): FileStatus = url.fs.fileStatus(url.url) + def fileListEntry(url: URL): FileListEntry = url.fs.fileListEntry(url.url) override def eTag(url: URL): Option[String] = url.fs.eTag(url.url) @@ -61,7 +66,8 @@ class RouterFS(fss: IndexedSeq[FS]) extends FS { def getConfiguration(): Any = fss.map(_.getConfiguration()) - def setConfiguration(config: Any): Unit = { - fss.zip(config.asInstanceOf[IndexedSeq[_]]).foreach { case (fs: FS, config: Any) => fs.setConfiguration(config) } - } + def setConfiguration(config: Any): Unit = + fss.zip(config.asInstanceOf[IndexedSeq[_]]).foreach { case (fs: FS, config: Any) => + fs.setConfiguration(config) + } } diff --git a/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala b/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala new file mode 100644 index 00000000000..4078bf30603 --- /dev/null +++ b/hail/src/main/scala/is/hail/io/fs/TerraAzureStorageFS.scala @@ -0,0 +1,87 @@ +package is.hail.io.fs + +import is.hail.shadedazure.com.azure.core.credential.TokenRequestContext +import is.hail.shadedazure.com.azure.identity.{ + DefaultAzureCredential, DefaultAzureCredentialBuilder, +} +import is.hail.shadedazure.com.azure.storage.blob.BlobServiceClient +import is.hail.utils._ + +import scala.collection.mutable + +import org.apache.http.client.methods.HttpPost +import org.apache.http.client.utils.URIBuilder +import org.apache.http.impl.client.HttpClients +import org.apache.http.util.EntityUtils +import org.apache.log4j.Logger +import org.json4s.{DefaultFormats, Formats} +import org.json4s.jackson.JsonMethods + +object TerraAzureStorageFS { + private val log = Logger.getLogger(getClass.getName) + private val TEN_MINUTES_IN_MS = 10 * 60 * 1000 +} + +class TerraAzureStorageFS extends AzureStorageFS() { + import TerraAzureStorageFS.{log, TEN_MINUTES_IN_MS} + + private[this] val httpClient = HttpClients.custom().build() + private[this] 
val sasTokenCache = mutable.Map[String, (URL, Long)]() + + private[this] val workspaceManagerUrl = sys.env("WORKSPACE_MANAGER_URL") + private[this] val workspaceId = sys.env("WORKSPACE_ID") + private[this] val containerResourceId = sys.env("WORKSPACE_STORAGE_CONTAINER_ID") + private[this] val storageContainerUrl = parseUrl(sys.env("WORKSPACE_STORAGE_CONTAINER_URL")) + + private[this] val credential: DefaultAzureCredential = new DefaultAzureCredentialBuilder().build() + + override def getServiceClient(url: URL): BlobServiceClient = + if (blobInWorkspaceStorageContainer(url)) { + super.getServiceClient(getTerraSasToken(url)) + } else { + super.getServiceClient(url) + } + + def getTerraSasToken(url: URL): URL = { + sasTokenCache.get(url.base) match { + case Some((sasTokenUrl, expiration)) + if expiration > System.currentTimeMillis + TEN_MINUTES_IN_MS => sasTokenUrl + case None => + val (sasTokenUrl, expiration) = createTerraSasToken() + sasTokenCache += (url.base -> (sasTokenUrl -> expiration)) + sasTokenUrl + } + } + + private def blobInWorkspaceStorageContainer(url: URL): Boolean = + storageContainerUrl.account == url.account && storageContainerUrl.container == url.container + + private def createTerraSasToken(): (URL, Long) = { + implicit val formats: Formats = DefaultFormats + + val context = new TokenRequestContext() + context.addScopes("https://management.azure.com/.default") + val token = credential.getToken(context).block().getToken() + + val url = + s"$workspaceManagerUrl/api/workspaces/v1/$workspaceId/resources/controlled/azure/storageContainer/$containerResourceId/getSasToken" + val req = new HttpPost(url) + req.addHeader("Authorization", s"Bearer $token") + + val tenHoursInSeconds = 10 * 3600 + val expiration = System.currentTimeMillis() + tenHoursInSeconds * 1000 + val uri = new URIBuilder(req.getURI()) + .addParameter("sasPermissions", "racwdl") + .addParameter("sasExpirationDuration", tenHoursInSeconds.toString) + .build() + req.setURI(uri) + + val sasTokenUrl = using(httpClient.execute(req)) { resp => + val json = JsonMethods.parse(new String(EntityUtils.toString(resp.getEntity))) + log.info(s"Created sas token client for $containerResourceId") + (json \ "url").extract[String] + } + + (parseUrl(sasTokenUrl), expiration) + } +} diff --git a/hail/src/main/scala/is/hail/io/fs/package.scala b/hail/src/main/scala/is/hail/io/fs/package.scala index dc1ddda5f2f..39aaa8a54d4 100644 --- a/hail/src/main/scala/is/hail/io/fs/package.scala +++ b/hail/src/main/scala/is/hail/io/fs/package.scala @@ -16,5 +16,7 @@ package object fs { def outputStreamToPositionedDataOutputStream(os: OutputStream): PositionedDataOutputStream = new WrappedPositionedDataOutputStream( new WrappedPositionOutputStream( - os)) + os + ) + ) } diff --git a/hail/src/main/scala/is/hail/io/gen/ExportBGEN.scala b/hail/src/main/scala/is/hail/io/gen/ExportBGEN.scala index 17c7794dc88..bd478bf873c 100644 --- a/hail/src/main/scala/is/hail/io/gen/ExportBGEN.scala +++ b/hail/src/main/scala/is/hail/io/gen/ExportBGEN.scala @@ -1,29 +1,19 @@ package is.hail.io.gen -import is.hail.HailContext -import is.hail.annotations.{RegionValue, UnsafeRow} -import is.hail.backend.ExecuteContext -import is.hail.expr.ir.{ByteArrayBuilder, MatrixValue} -import is.hail.types.physical.PStruct +import is.hail.expr.ir.ByteArrayBuilder import is.hail.io.fs.FS -import is.hail.utils.BoxedArrayBuilder -import is.hail.variant.{ArrayGenotypeView, RegionValueVariant, View} -import is.hail.utils._ -import org.apache.hadoop.io.IOUtils -import 
org.apache.spark.TaskContext -import org.apache.spark.sql.Row object BgenWriter { val ploidy: Byte = 2 val phased: Byte = 0 val totalProb: Int = 255 - def shortToBytesLE(bb: ByteArrayBuilder, i: Int) { + def shortToBytesLE(bb: ByteArrayBuilder, i: Int): Unit = { bb += (i & 0xff).toByte bb += ((i >>> 8) & 0xff).toByte } - def intToBytesLE(bb: ByteArrayBuilder, i: Int) { + def intToBytesLE(bb: ByteArrayBuilder, i: Int): Unit = { bb += (i & 0xff).toByte bb += ((i >>> 8) & 0xff).toByte bb += ((i >>> 16) & 0xff).toByte @@ -46,7 +36,7 @@ object BgenWriter { 4 + l } - def updateIntToBytesLE(bb: ByteArrayBuilder, i: Int, pos: Int) { + def updateIntToBytesLE(bb: ByteArrayBuilder, i: Int, pos: Int): Unit = { bb(pos) = (i & 0xff).toByte bb(pos + 1) = ((i >>> 8) & 0xff).toByte bb(pos + 2) = ((i >>> 16) & 0xff).toByte @@ -85,13 +75,20 @@ object BgenWriter { bb.result() } - def writeSampleFile(fs: FS, path: String, sampleIds: Array[String]) { - fs.writeTable(path + ".sample", - "ID_1 ID_2 missing" :: "0 0 0" :: sampleIds.map(s => s"$s $s 0").toList) - } - - def roundWithConstantSum(input: Array[Double], fractional: Array[Double], index: Array[Int], - indexInverse: Array[Int], output: ByteArrayBuilder, expectedSize: Long) { + def writeSampleFile(fs: FS, path: String, sampleIds: Array[String]): Unit = + fs.writeTable( + path + ".sample", + "ID_1 ID_2 missing" :: "0 0 0" :: sampleIds.map(s => s"$s $s 0").toList, + ) + + def roundWithConstantSum( + input: Array[Double], + fractional: Array[Double], + index: Array[Int], + indexInverse: Array[Int], + output: ByteArrayBuilder, + expectedSize: Long, + ): Unit = { val n = input.length assert(fractional.length == n && index.length == n && indexInverse.length == n) @@ -130,7 +127,7 @@ object BgenWriter { assert(newSize == expectedSize) } - private def resetIndex(index: Array[Int]) { + private def resetIndex(index: Array[Int]): Unit = { var i = 0 while (i < index.length) { index(i) = i @@ -138,8 +135,8 @@ object BgenWriter { } } - private def quickSortWithIndex(a: Array[Double], idx: Array[Int], start: Int, n: Int) { - def swap(i: Int, j: Int) { + private def quickSortWithIndex(a: Array[Double], idx: Array[Int], start: Int, n: Int): Unit = { + def swap(i: Int, j: Int): Unit = { val tmp = idx(i) idx(i) = idx(j) idx(j) = tmp diff --git a/hail/src/main/scala/is/hail/io/gen/ExportGen.scala b/hail/src/main/scala/is/hail/io/gen/ExportGen.scala index f2b36c80023..bd62fffe52e 100644 --- a/hail/src/main/scala/is/hail/io/gen/ExportGen.scala +++ b/hail/src/main/scala/is/hail/io/gen/ExportGen.scala @@ -1,37 +1,49 @@ package is.hail.io.gen -import is.hail.HailContext -import is.hail.annotations.Region -import is.hail.backend.ExecuteContext -import is.hail.expr.ir.MatrixValue import is.hail.types.physical.{PString, PStruct} -import is.hail.variant.{ArrayGenotypeView, Locus, RegionValueVariant, VariantMethods, View} import is.hail.utils._ -import org.apache.spark.sql.Row +import is.hail.variant.{Locus, VariantMethods, View} object ExportGen { val spaceRegex = """\s+""".r def checkSample(id1: String, id2: String, missing: Double): Unit = { - if (spaceRegex.findFirstIn(id1).isDefined) - fatal(s"Invalid 'id1' found -- no white space allowed: '$id1'") - if (spaceRegex.findFirstIn(id2).isDefined) - fatal(s"Invalid 'id2' found -- no white space allowed: '$id2'") - if (missing < 0 || missing > 1) - fatal(s"'missing' values must be in the range [0, 1]. 
Found $missing for ($id1, $id2).") + if (spaceRegex.findFirstIn(id1).isDefined) + fatal(s"Invalid 'id1' found -- no white space allowed: '$id1'") + if (spaceRegex.findFirstIn(id2).isDefined) + fatal(s"Invalid 'id2' found -- no white space allowed: '$id2'") + if (missing < 0 || missing > 1) + fatal(s"'missing' values must be in the range [0, 1]. Found $missing for ($id1, $id2).") } - def checkVariant(contig: String, position: Int, a0: String, a1: String, varid: String, rsid: String): Unit = { + def checkVariant( + contig: String, + position: Int, + a0: String, + a1: String, + varid: String, + rsid: String, + ): Unit = { if (spaceRegex.findFirstIn(contig).isDefined) - fatal(s"Invalid contig found at '${ VariantMethods.locusAllelesToString(Locus(contig, position), Array(a0, a1)) }' -- no white space allowed: '$contig'") + fatal( + s"Invalid contig found at '${VariantMethods.locusAllelesToString(Locus(contig, position), Array(a0, a1))}' -- no white space allowed: '$contig'" + ) if (spaceRegex.findFirstIn(a0).isDefined) - fatal(s"Invalid allele found at '${ VariantMethods.locusAllelesToString(Locus(contig, position), Array(a0, a1)) }' -- no white space allowed: '$a0'") + fatal( + s"Invalid allele found at '${VariantMethods.locusAllelesToString(Locus(contig, position), Array(a0, a1))}' -- no white space allowed: '$a0'" + ) if (spaceRegex.findFirstIn(a1).isDefined) - fatal(s"Invalid allele found at '${ VariantMethods.locusAllelesToString(Locus(contig, position), Array(a0, a1)) }' -- no white space allowed: '$a1'") + fatal( + s"Invalid allele found at '${VariantMethods.locusAllelesToString(Locus(contig, position), Array(a0, a1))}' -- no white space allowed: '$a1'" + ) if (spaceRegex.findFirstIn(varid).isDefined) - fatal(s"Invalid 'varid' found at '${ VariantMethods.locusAllelesToString(Locus(contig, position), Array(a0, a1)) }' -- no white space allowed: '$varid'") + fatal( + s"Invalid 'varid' found at '${VariantMethods.locusAllelesToString(Locus(contig, position), Array(a0, a1))}' -- no white space allowed: '$varid'" + ) if (spaceRegex.findFirstIn(rsid).isDefined) - fatal(s"Invalid 'rsid' found at '${ VariantMethods.locusAllelesToString(Locus(contig, position), Array(a0, a1)) }' -- no white space allowed: '$rsid'") + fatal( + s"Invalid 'rsid' found at '${VariantMethods.locusAllelesToString(Locus(contig, position), Array(a0, a1))}' -- no white space allowed: '$rsid'" + ) } } @@ -48,7 +60,7 @@ class GenAnnotationView(rowType: PStruct) extends View { private var cachedVarid: String = _ private var cachedRsid: String = _ - def set(offset: Long) { + def set(offset: Long): Unit = { assert(rowType.isFieldDefined(offset, varidIdx)) assert(rowType.isFieldDefined(offset, rsidIdx)) this.rsidOffset = rowType.loadField(offset, rsidIdx) diff --git a/hail/src/main/scala/is/hail/io/hadoop/ByteArrayOutputFormat.scala b/hail/src/main/scala/is/hail/io/hadoop/ByteArrayOutputFormat.scala index 621bf706452..a24e4f6d283 100644 --- a/hail/src/main/scala/is/hail/io/hadoop/ByteArrayOutputFormat.scala +++ b/hail/src/main/scala/is/hail/io/hadoop/ByteArrayOutputFormat.scala @@ -9,24 +9,26 @@ import org.apache.hadoop.util.Progressable class ByteArrayOutputFormat extends FileOutputFormat[NullWritable, BytesOnlyWritable] { - class ByteArrayRecordWriter(out: DataOutputStream) extends RecordWriter[NullWritable, BytesOnlyWritable] { + class ByteArrayRecordWriter(out: DataOutputStream) + extends RecordWriter[NullWritable, BytesOnlyWritable] { - def write(key: NullWritable, value: BytesOnlyWritable) { + def write(key: NullWritable, 
value: BytesOnlyWritable): Unit = if (value != null) value.write(out) - } - def close(reporter: Reporter) { + def close(reporter: Reporter): Unit = out.close() - } } - override def getRecordWriter(ignored: FileSystem, job: JobConf, - name: String, progress: Progressable): RecordWriter[NullWritable, BytesOnlyWritable] = { + override def getRecordWriter( + ignored: FileSystem, + job: JobConf, + name: String, + progress: Progressable, + ): RecordWriter[NullWritable, BytesOnlyWritable] = { val file: Path = FileOutputFormat.getTaskOutputPath(job, name) val fs: FileSystem = file.getFileSystem(job) val fileOut: FSDataOutputStream = fs.create(file, progress) new ByteArrayRecordWriter(fileOut) } } - diff --git a/hail/src/main/scala/is/hail/io/hadoop/BytesOnlyWritable.scala b/hail/src/main/scala/is/hail/io/hadoop/BytesOnlyWritable.scala index 1707bcdad91..0b01b6c1f2b 100644 --- a/hail/src/main/scala/is/hail/io/hadoop/BytesOnlyWritable.scala +++ b/hail/src/main/scala/is/hail/io/hadoop/BytesOnlyWritable.scala @@ -1,6 +1,6 @@ package is.hail.io.hadoop -import java.io.{DataOutput, DataInput} +import java.io.{DataInput, DataOutput} import org.apache.hadoop.io.Writable @@ -8,16 +8,14 @@ class BytesOnlyWritable(var bytes: Array[Byte]) extends Writable { def this() = this(null) - def set(bytes: Array[Byte]) { + def set(bytes: Array[Byte]): Unit = this.bytes = bytes - } - override def write(out: DataOutput) { + override def write(out: DataOutput): Unit = { assert(bytes != null) out.write(bytes, 0, bytes.length) } - override def readFields(in: DataInput) { + override def readFields(in: DataInput): Unit = throw new UnsupportedOperationException() - } } diff --git a/hail/src/main/scala/is/hail/io/index/IndexReader.scala b/hail/src/main/scala/is/hail/io/index/IndexReader.scala index 413fb251024..4ae410403cd 100644 --- a/hail/src/main/scala/is/hail/io/index/IndexReader.scala +++ b/hail/src/main/scala/is/hail/io/index/IndexReader.scala @@ -1,41 +1,48 @@ package is.hail.io.index -import java.io.InputStream -import java.util -import java.util.Map.Entry -import is.hail.asm4s.HailClassLoader import is.hail.annotations._ +import is.hail.asm4s.HailClassLoader import is.hail.backend.{ExecuteContext, HailStateManager} -import is.hail.types.virtual.{TStruct, Type, TypeSerializer} -import is.hail.expr.ir.IRParser -import is.hail.types.physical.{PStruct, PType} import is.hail.io._ -import is.hail.io.bgen.BgenSettings -import is.hail.utils._ import is.hail.io.fs.FS -import is.hail.rvd.{AbstractIndexSpec, AbstractRVDSpec, PartitionBoundOrdering} -import org.apache.hadoop.fs.FSDataInputStream +import is.hail.rvd.{AbstractIndexSpec, PartitionBoundOrdering} +import is.hail.types.physical.PStruct +import is.hail.types.virtual.{TStruct, Type, TypeSerializer} +import is.hail.utils._ + +import java.io.InputStream +import java.util +import java.util.Map.Entry + import org.apache.spark.sql.Row -import org.json4s.{Formats, NoTypeHints} -import org.json4s.jackson.{JsonMethods, Serialization} +import org.json4s.Formats +import org.json4s.jackson.JsonMethods object IndexReaderBuilder { - def fromSpec(ctx: ExecuteContext, spec: AbstractIndexSpec): (HailClassLoader, FS, String, Int, RegionPool) => IndexReader = { + def fromSpec(ctx: ExecuteContext, spec: AbstractIndexSpec) + : (HailClassLoader, FS, String, Int, RegionPool) => IndexReader = { val (keyType, annotationType) = spec.types - val (leafPType: PStruct, leafDec) = spec.leafCodec.buildDecoder(ctx, spec.leafCodec.encodedVirtualType) - val (intPType: PStruct, intDec) = 
spec.internalNodeCodec.buildDecoder(ctx, spec.internalNodeCodec.encodedVirtualType) + val (leafPType: PStruct, leafDec) = + spec.leafCodec.buildDecoder(ctx, spec.leafCodec.encodedVirtualType) + val (intPType: PStruct, intDec) = + spec.internalNodeCodec.buildDecoder(ctx, spec.internalNodeCodec.encodedVirtualType) withDecoders(ctx, leafDec, intDec, keyType, annotationType, leafPType, intPType) } def withDecoders( ctx: ExecuteContext, - leafDec: (InputStream, HailClassLoader) => Decoder, intDec: (InputStream, HailClassLoader) => Decoder, - keyType: Type, annotationType: Type, - leafPType: PStruct, intPType: PStruct + leafDec: (InputStream, HailClassLoader) => Decoder, + intDec: (InputStream, HailClassLoader) => Decoder, + keyType: Type, + annotationType: Type, + leafPType: PStruct, + intPType: PStruct, ): (HailClassLoader, FS, String, Int, RegionPool) => IndexReader = { val sm = ctx.stateManager - (theHailClassLoader, fs, path, cacheCapacity, pool) => new IndexReader( - theHailClassLoader, fs, path, cacheCapacity, leafDec, intDec, keyType, annotationType, leafPType, intPType, pool, sm) + (theHailClassLoader, fs, path, cacheCapacity, pool) => + new IndexReader( + theHailClassLoader, fs, path, cacheCapacity, leafDec, intDec, keyType, annotationType, + leafPType, intPType, pool, sm) } } @@ -43,7 +50,7 @@ object IndexReader { def readUntyped(fs: FS, path: String): IndexMetadataUntypedJSON = { val jv = using(fs.open(path + "/metadata.json.gz")) { in => JsonMethods.parse(in) - .removeField{ case (f, _) => f == "keyType" || f == "annotationType" } + .removeField { case (f, _) => f == "keyType" || f == "annotationType" } } implicit val formats: Formats = defaultJSONFormats jv.extract[IndexMetadataUntypedJSON] @@ -55,14 +62,13 @@ object IndexReader { } def readTypes(fs: FS, path: String): (Type, Type) = { - val jv = using(fs.open(path + "/metadata.json.gz")) { in => JsonMethods.parse(in) } + val jv = using(fs.open(path + "/metadata.json.gz"))(in => JsonMethods.parse(in)) implicit val formats: Formats = defaultJSONFormats + new TypeSerializer val metadata = jv.extract[IndexMetadata] metadata.keyType -> metadata.annotationType } } - class IndexReader( theHailClassLoader: HailClassLoader, fs: FS, @@ -75,7 +81,7 @@ class IndexReader( val leafPType: PStruct, val internalPType: PStruct, val pool: RegionPool, - val sm: HailStateManager + val sm: HailStateManager, ) extends AutoCloseable { private[io] val metadata = IndexReader.readMetadata(fs, path, keyType, annotationType) val branchingFactor = metadata.branchingFactor @@ -83,6 +89,7 @@ class IndexReader( val nKeys = metadata.nKeys val attributes = metadata.attributes val indexRelativePath = metadata.indexPath + val ordering = keyType match { case ts: TStruct => PartitionBoundOrdering(sm, ts) case t => t.ordering(sm) @@ -92,15 +99,17 @@ class IndexReader( private val leafDecoder = leafDecoderBuilder(is, theHailClassLoader) private val internalDecoder = internalDecoderBuilder(is, theHailClassLoader) - private val region = Region(pool=pool) + private val region = Region(pool = pool) private val rv = RegionValue(region) private var cacheHits = 0L private var cacheMisses = 0L - @transient private[this] lazy val cache = new util.LinkedHashMap[Long, IndexNode](cacheCapacity, 0.75f, true) { - override def removeEldestEntry(eldest: Entry[Long, IndexNode]): Boolean = size() > cacheCapacity - } + @transient private[this] lazy val cache = + new util.LinkedHashMap[Long, IndexNode](cacheCapacity, 0.75f, true) { + override def removeEldestEntry(eldest: Entry[Long, 
IndexNode]): Boolean = + size() > cacheCapacity + } private[io] def readInternalNode(offset: Long): InternalNode = { if (cache.containsKey(offset)) { @@ -147,12 +156,13 @@ class IndexReader( } } - private[io] def lowerBound(key: Annotation): Long = { - if (nKeys == 0 || ordering.lteq(key, readInternalNode(metadata.rootOffset).children.head.firstKey)) + private[io] def lowerBound(key: Annotation): Long = + if ( + nKeys == 0 || ordering.lteq(key, readInternalNode(metadata.rootOffset).children.head.firstKey) + ) 0 else lowerBound(key, height - 1, metadata.rootOffset) - } private def upperBound(key: Annotation, level: Int, offset: Long): Long = { if (level == 0) { @@ -162,18 +172,18 @@ class IndexReader( } else { val node = readInternalNode(offset) val children = node.children - val n = children.length val idx = children.upperBound(key, ordering.lt, _.firstKey) upperBound(key, level - 1, children(idx - 1).indexFileOffset) } } - private[io] def upperBound(key: Annotation): Long = { - if (nKeys == 0 || ordering.lt(key, readInternalNode(metadata.rootOffset).children.head.firstKey)) + private[io] def upperBound(key: Annotation): Long = + if ( + nKeys == 0 || ordering.lt(key, readInternalNode(metadata.rootOffset).children.head.firstKey) + ) 0 else upperBound(key, height - 1, metadata.rootOffset) - } private def getLeafNode(index: Long, level: Int, offset: Long): LeafNode = { if (level == 0) { @@ -206,11 +216,15 @@ class IndexReader( node.children(localIdx.toInt) } - def boundsByInterval(interval: Interval): (Long, Long) = { + def boundsByInterval(interval: Interval): (Long, Long) = boundsByInterval(interval.start, interval.end, interval.includesStart, interval.includesEnd) - } - def boundsByInterval(start: Annotation, end: Annotation, includesStart: Boolean, includesEnd: Boolean): (Long, Long) = { + def boundsByInterval( + start: Annotation, + end: Annotation, + includesStart: Boolean, + includesEnd: Boolean, + ): (Long, Long) = { require(Interval.isValid(ordering, start, end, includesStart, includesEnd)) val startIdx = if (includesStart) lowerBound(start) else upperBound(start) val endIdx = if (includesEnd) upperBound(end) else lowerBound(end) @@ -220,7 +234,12 @@ class IndexReader( def queryByInterval(interval: Interval): Iterator[LeafChild] = queryByInterval(interval.start, interval.end, interval.includesStart, interval.includesEnd) - def queryByInterval(start: Annotation, end: Annotation, includesStart: Boolean, includesEnd: Boolean): Iterator[LeafChild] = { + def queryByInterval( + start: Annotation, + end: Annotation, + includesStart: Boolean, + includesEnd: Boolean, + ): Iterator[LeafChild] = { val (startIdx, endIdx) = boundsByInterval(start, end, includesStart, includesEnd) iterator(startIdx, endIdx) } @@ -250,7 +269,7 @@ class IndexReader( def hasNext: Boolean = pos < end - def seek(key: Annotation) { + def seek(key: Annotation): Unit = { val newPos = lowerBound(key) assert(newPos >= pos) localPos += (newPos - pos).toInt @@ -264,11 +283,11 @@ class IndexReader( def iterateUntil(key: Annotation): Iterator[LeafChild] = iterator(0, lowerBound(key)) - def close() { + def close(): Unit = { leafDecoder.close() internalDecoder.close() - log.info(s"Index reader cache queries: ${ cacheHits + cacheMisses }") - log.info(s"Index reader cache hit rate: ${ cacheHits.toDouble / (cacheHits + cacheMisses) }") + log.info(s"Index reader cache queries: ${cacheHits + cacheMisses}") + log.info(s"Index reader cache hit rate: ${cacheHits.toDouble / (cacheHits + cacheMisses)}") } } @@ -277,11 +296,14 @@ final case 
class InternalChild( firstIndex: Long, firstKey: Annotation, firstRecordOffset: Long, - firstAnnotation: Annotation) + firstAnnotation: Annotation, +) object InternalNode { def apply(r: Row): InternalNode = { - val children = r.get(0).asInstanceOf[IndexedSeq[Row]].map(r => InternalChild(r.getLong(0), r.getLong(1), r.get(2), r.getLong(3), r.get(4))) + val children = r.get(0).asInstanceOf[IndexedSeq[Row]].map(r => + InternalChild(r.getLong(0), r.getLong(1), r.get(2), r.getLong(3), r.get(4)) + ) InternalNode(children) } } @@ -296,7 +318,8 @@ final case class InternalNode(children: IndexedSeq[InternalChild]) extends Index final case class LeafChild( key: Annotation, recordOffset: Long, - annotation: Annotation) { + annotation: Annotation, +) { def longChild(j: Int): Long = annotation.asInstanceOf[Row].getAs[Long](j) } @@ -304,13 +327,15 @@ final case class LeafChild( object LeafNode { def apply(r: Row): LeafNode = { val firstKeyIndex = r.getLong(0) - val keys = r.get(1).asInstanceOf[IndexedSeq[Row]].map(r => LeafChild(r.get(0), r.getLong(1), r.get(2))) + val keys = + r.get(1).asInstanceOf[IndexedSeq[Row]].map(r => LeafChild(r.get(0), r.getLong(1), r.get(2))) LeafNode(firstKeyIndex, keys) } } final case class LeafNode( firstIndex: Long, - children: IndexedSeq[LeafChild]) extends IndexNode + children: IndexedSeq[LeafChild], +) extends IndexNode sealed trait IndexNode diff --git a/hail/src/main/scala/is/hail/io/index/IndexWriter.scala b/hail/src/main/scala/is/hail/io/index/IndexWriter.scala index 4e28df19241..20a2e974240 100644 --- a/hail/src/main/scala/is/hail/io/index/IndexWriter.scala +++ b/hail/src/main/scala/is/hail/io/index/IndexWriter.scala @@ -1,24 +1,28 @@ package is.hail.io.index -import is.hail.annotations.{Annotation, Region, RegionPool, RegionValueBuilder} +import is.hail.annotations.{Annotation, Region, RegionPool} import is.hail.asm4s.{HailClassLoader, _} import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext} -import is.hail.expr.ir.{CodeParam, EmitClassBuilder, EmitCodeBuilder, EmitFunctionBuilder, EmitMethodBuilder, IEmitCode, IntArrayBuilder, LongArrayBuilder, ParamType} +import is.hail.expr.ir.{ + CodeParam, EmitClassBuilder, EmitCodeBuilder, EmitFunctionBuilder, EmitMethodBuilder, IEmitCode, + IntArrayBuilder, LongArrayBuilder, ParamType, +} import is.hail.io._ import is.hail.io.fs.FS import is.hail.rvd.AbstractRVDSpec import is.hail.types +import is.hail.types.physical.{PCanonicalArray, PCanonicalStruct, PType} import is.hail.types.physical.stypes.SValue import is.hail.types.physical.stypes.concrete.{SBaseStructPointer, SBaseStructPointerSettable} import is.hail.types.physical.stypes.interfaces.SBaseStructValue -import is.hail.types.physical.{PCanonicalArray, PCanonicalStruct, PType} import is.hail.types.virtual.Type import is.hail.utils._ import is.hail.utils.richUtils.ByteTrackingOutputStream -import org.json4s.jackson.Serialization import java.io.OutputStream +import org.json4s.jackson.Serialization + trait AbstractIndexMetadata { def fileVersion: Int @@ -46,7 +50,7 @@ case class IndexMetadataUntypedJSON( nKeys: Long, indexPath: String, rootOffset: Long, - attributes: Map[String, Any] + attributes: Map[String, Any], ) { def toMetadata(keyType: Type, annotationType: Type): IndexMetadata = IndexMetadata( fileVersion, branchingFactor, @@ -54,7 +58,7 @@ case class IndexMetadataUntypedJSON( nKeys, indexPath, rootOffset, attributes) def toFileMetadata: VariableMetadata = VariableMetadata( - branchingFactor, height, nKeys, rootOffset, attributes + 
branchingFactor, height, nKeys, rootOffset, attributes, ) } @@ -67,31 +71,46 @@ case class IndexMetadata( nKeys: Long, indexPath: String, rootOffset: Long, - attributes: Map[String, Any] + attributes: Map[String, Any], ) extends AbstractIndexMetadata object IndexWriter { val version: SemanticVersion = SemanticVersion(1, 2, 0) val spec: BufferSpec = BufferSpec.default + def builder( ctx: ExecuteContext, keyType: PType, annotationType: PType, branchingFactor: Int = 4096, - attributes: Map[String, Any] = Map.empty[String, Any] + attributes: Map[String, Any] = Map.empty[String, Any], ): (String, HailClassLoader, HailTaskContext, RegionPool) => IndexWriter = { - val sm = ctx.stateManager; - val f = StagedIndexWriter.build(ctx, keyType, annotationType, branchingFactor); + val sm = ctx.stateManager; + val f = StagedIndexWriter.build(ctx, keyType, annotationType, branchingFactor); { (path: String, hcl: HailClassLoader, htc: HailTaskContext, pool: RegionPool) => - new IndexWriter(sm, keyType, annotationType, f(path, hcl, htc, pool, attributes), pool, attributes) + new IndexWriter( + sm, + keyType, + annotationType, + f(path, hcl, htc, pool, attributes), + pool, + attributes, + ) } } } -class IndexWriter(sm: HailStateManager, keyType: PType, valueType: PType, comp: CompiledIndexWriter, pool: RegionPool, attributes: Map[String, Any]) extends AutoCloseable { - private val region = Region(pool=pool) - private val rvb = new RegionValueBuilder(sm, region) +class IndexWriter( + sm: HailStateManager, + keyType: PType, + valueType: PType, + comp: CompiledIndexWriter, + pool: RegionPool, + attributes: Map[String, Any], +) extends AutoCloseable { + private val region = Region(pool = pool) + def appendRow(x: Annotation, offset: Long, annotation: Annotation): Unit = { val koff = keyType.unstagedStoreJavaObject(sm, x, region) val voff = valueType.unstagedStoreJavaObject(sm, annotation, region) @@ -106,12 +125,23 @@ class IndexWriter(sm: HailStateManager, keyType: PType, valueType: PType, comp: } } -class IndexWriterArrayBuilder(name: String, maxSize: Int, sb: SettableBuilder, region: Value[Region], arrayType: PCanonicalArray) { +class IndexWriterArrayBuilder( + name: String, + maxSize: Int, + sb: SettableBuilder, + region: Value[Region], + arrayType: PCanonicalArray, +) { private val aoff = sb.newSettable[Long](s"${name}_aoff") private val len = sb.newSettable[Int](s"${name}_len") - val eltType: PCanonicalStruct = types.tcoerce[PCanonicalStruct](arrayType.elementType.setRequired((false))) - private val elt = new SBaseStructPointerSettable(SBaseStructPointer(eltType), sb.newSettable[Long](s"${name}_elt_off")) + val eltType: PCanonicalStruct = + types.tcoerce[PCanonicalStruct](arrayType.elementType.setRequired((false))) + + private val elt = new SBaseStructPointerSettable( + SBaseStructPointer(eltType), + sb.newSettable[Long](s"${name}_elt_off"), + ) def length: Code[Int] = len @@ -131,30 +161,44 @@ class IndexWriterArrayBuilder(name: String, maxSize: Int, sb: SettableBuilder, r def setFieldValue(cb: EmitCodeBuilder, name: String, field: SValue): Unit = { eltType.setFieldPresent(cb, elt.a, name) - eltType.fieldType(name).storeAtAddress(cb, eltType.fieldOffset(elt.a, name), region, field, deepCopy = true) + eltType.fieldType(name).storeAtAddress( + cb, + eltType.fieldOffset(elt.a, name), + region, + field, + deepCopy = true, + ) } def setField(cb: EmitCodeBuilder, name: String, v: => IEmitCode): Unit = - v.consume(cb, - eltType.setFieldMissing(cb, elt.a, name), - sv => setFieldValue(cb, name, sv)) + v.consume(cb, 
eltType.setFieldMissing(cb, elt.a, name), sv => setFieldValue(cb, name, sv)) def addChild(cb: EmitCodeBuilder): Unit = { loadChild(cb, len) cb.assign(len, len + 1) } - def loadChild(cb: EmitCodeBuilder, idx: Code[Int]): Unit = elt.store(cb, eltType.loadCheapSCode(cb, arrayType.loadElement(aoff, idx))) + + def loadChild(cb: EmitCodeBuilder, idx: Code[Int]): Unit = + elt.store(cb, eltType.loadCheapSCode(cb, arrayType.loadElement(aoff, idx))) + def getLoadedChild: SBaseStructValue = elt } class StagedIndexWriterUtils(ib: Settable[IndexWriterUtils]) { - def create(cb: EmitCodeBuilder, path: Code[String], fs: Code[FS], meta: Code[StagedIndexMetadata]): Unit = - cb.assign(ib, Code.newInstance[IndexWriterUtils, String, FS, StagedIndexMetadata](path, fs, meta)) + def create(cb: EmitCodeBuilder, path: Code[String], fs: Code[FS], meta: Code[StagedIndexMetadata]) + : Unit = + cb.assign( + ib, + Code.newInstance[IndexWriterUtils, String, FS, StagedIndexMetadata](path, fs, meta), + ) + def size: Code[Int] = ib.invoke[Int]("size") + def add(cb: EmitCodeBuilder, r: Code[Region], aoff: Code[Long], len: Code[Int]): Unit = cb += ib.invoke[Region, Long, Int, Unit]("add", r, aoff, len) - def update(cb: EmitCodeBuilder, idx: Code[Int], r: Code[Region], aoff: Code[Long], len: Code[Int]): Unit = + def update(cb: EmitCodeBuilder, idx: Code[Int], r: Code[Region], aoff: Code[Long], len: Code[Int]) + : Unit = cb += ib.invoke[Int, Region, Long, Int, Unit]("update", idx, r, aoff, len) def getRegion(idx: Code[Int]): Code[Region] = ib.invoke[Int, Region]("getRegion", idx) @@ -165,7 +209,12 @@ class StagedIndexWriterUtils(ib: Settable[IndexWriterUtils]) { def bytesWritten: Code[Long] = ib.invoke[Long]("bytesWritten") def os: Code[OutputStream] = ib.invoke[OutputStream]("os") - def writeMetadata(cb: EmitCodeBuilder, height: Code[Int], rootOffset: Code[Long], nKeys: Code[Long]): Unit = + def writeMetadata( + cb: EmitCodeBuilder, + height: Code[Int], + rootOffset: Code[Long], + nKeys: Code[Long], + ): Unit = cb += ib.invoke[Int, Long, Long, Unit]("writeMetadata", height, rootOffset, nKeys) } @@ -173,11 +222,21 @@ case class StagedIndexMetadata( branchingFactor: Int, keyType: Type, annotationType: Type, - attributes: Map[String, Any] + attributes: Map[String, Any], ) { - def serialize(out: OutputStream, height: Int, rootOffset: Long, nKeys: Long) { + def serialize(out: OutputStream, height: Int, rootOffset: Long, nKeys: Long): Unit = { import AbstractRVDSpec.formats - val metadata = IndexMetadata(IndexWriter.version.rep, branchingFactor, height, keyType, annotationType, nKeys, "index", rootOffset, attributes) + val metadata = IndexMetadata( + IndexWriter.version.rep, + branchingFactor, + height, + keyType, + annotationType, + nKeys, + "index", + rootOffset, + attributes, + ) Serialization.write(metadata, out) } } @@ -190,9 +249,8 @@ class IndexWriterUtils(path: String, fs: FS, meta: StagedIndexMetadata) { def bytesWritten: Long = trackedOS.bytesWritten def os: OutputStream = trackedOS - def writeMetadata(height: Int, rootOffset: Long, nKeys: Long): Unit = { - using(fs.create(metadataPath)) { os => meta.serialize(os, height, rootOffset, nKeys) } - } + def writeMetadata(height: Int, rootOffset: Long, nKeys: Long): Unit = + using(fs.create(metadataPath))(os => meta.serialize(os, height, rootOffset, nKeys)) val rBuilder = new BoxedArrayBuilder[Region]() val aBuilder = new LongArrayBuilder() @@ -221,7 +279,7 @@ class IndexWriterUtils(path: String, fs: FS, meta: StagedIndexMetadata) { def getLength(idx: Int): Int = lBuilder(idx) 
def close(): Unit = { - rBuilder.result().foreach { r => r.close() } + rBuilder.result().foreach(r => r.close()) trackedOS.close() } } @@ -238,145 +296,194 @@ object StagedIndexWriter { ctx: ExecuteContext, keyType: PType, annotationType: PType, - branchingFactor: Int = 4096 + branchingFactor: Int = 4096, ): (String, HailClassLoader, HailTaskContext, RegionPool, Map[String, Any]) => CompiledIndexWriter = { - val fb = EmitFunctionBuilder[CompiledIndexWriter](ctx, "indexwriter", + val fb = EmitFunctionBuilder[CompiledIndexWriter]( + ctx, + "indexwriter", FastSeq[ParamType](typeInfo[Long], typeInfo[Long], typeInfo[Long]), - typeInfo[Unit]) + typeInfo[Unit], + ) val cb = fb.ecb val siw = new StagedIndexWriter(branchingFactor, keyType, annotationType, cb) - cb.newEmitMethod("init", FastSeq[ParamType](typeInfo[String], classInfo[Map[String, Any]]), typeInfo[Unit]) - .voidWithBuilder(cb => siw.init(cb, cb.emb.getCodeParam[String](1), cb.emb.getCodeParam[Map[String, Any]](2))) + cb.newEmitMethod( + "init", + FastSeq[ParamType](typeInfo[String], classInfo[Map[String, Any]]), + typeInfo[Unit], + ) + .voidWithBuilder(cb => + siw.init(cb, cb.emb.getCodeParam[String](1), cb.emb.getCodeParam[Map[String, Any]](2)) + ) fb.emb.voidWithBuilder { cb => - siw.add(cb, + siw.add( + cb, IEmitCode(cb, false, keyType.loadCheapSCode(cb, fb.getCodeParam[Long](1))), fb.getCodeParam[Long](2), - IEmitCode(cb, false, annotationType.loadCheapSCode(cb, fb.getCodeParam[Long](3)))) + IEmitCode(cb, false, annotationType.loadCheapSCode(cb, fb.getCodeParam[Long](3))), + ) } cb.newEmitMethod("close", FastSeq[ParamType](), typeInfo[Unit]) .voidWithBuilder(siw.close) cb.newEmitMethod("trackedOS", FastSeq[ParamType](), typeInfo[ByteTrackingOutputStream]) - .emitWithBuilder[ByteTrackingOutputStream] { _ => Code.checkcast[ByteTrackingOutputStream](siw.utils.os) } + .emitWithBuilder[ByteTrackingOutputStream] { _ => + Code.checkcast[ByteTrackingOutputStream](siw.utils.os) + } val makeFB = fb.resultWithIndex() val fsBc = ctx.fsBc - { (path: String, hcl: HailClassLoader, htc: HailTaskContext, pool: RegionPool, attributes: Map[String, Any]) => - pool.scopedRegion { r => - // FIXME: This seems wrong? But also, anywhere we use broadcasting for the FS is wrong. - val f = makeFB(hcl, fsBc.value, htc, r) - f.init(path, attributes) - f - } + { + ( + path: String, + hcl: HailClassLoader, + htc: HailTaskContext, + pool: RegionPool, + attributes: Map[String, Any], + ) => + pool.scopedRegion { r => + // FIXME: This seems wrong? But also, anywhere we use broadcasting for the FS is wrong. 
+ val f = makeFB(hcl, fsBc.value, htc, r) + f.init(path, attributes) + f + } } } - def withDefaults(keyType: PType, cb: EmitClassBuilder[_], + def withDefaults( + keyType: PType, + cb: EmitClassBuilder[_], branchingFactor: Int = 4096, - annotationType: PType = +PCanonicalStruct()): StagedIndexWriter = + annotationType: PType = +PCanonicalStruct(), + ): StagedIndexWriter = new StagedIndexWriter(branchingFactor, keyType, annotationType, cb) } -class StagedIndexWriter(branchingFactor: Int, keyType: PType, annotationType: PType, cb: EmitClassBuilder[_]) { +class StagedIndexWriter( + branchingFactor: Int, + keyType: PType, + annotationType: PType, + cb: EmitClassBuilder[_], +) { require(branchingFactor > 1) - private var elementIdx = cb.genFieldThisRef[Long]() + private val elementIdx = cb.genFieldThisRef[Long]() private val ob = cb.genFieldThisRef[OutputBuffer]() private val utils = new StagedIndexWriterUtils(cb.genFieldThisRef[IndexWriterUtils]()) - private val leafBuilder = new StagedLeafNodeBuilder(branchingFactor, keyType, annotationType, cb.fieldBuilder) - private val writeInternalNode: EmitMethodBuilder[_] = { - val m = cb.genEmitMethod[Int, Boolean, Unit]("writeInternalNode") + private val leafBuilder = + new StagedLeafNodeBuilder(branchingFactor, keyType, annotationType, cb.fieldBuilder) - val internalBuilder = new StagedInternalNodeBuilder(branchingFactor, keyType, annotationType, m.localBuilder) - val parentBuilder = new StagedInternalNodeBuilder(branchingFactor, keyType, annotationType, m.localBuilder) + private val writeInternalNode: EmitMethodBuilder[_] = + cb.defineEmitMethod( + genName("m", "writeInternalNode"), + FastSeq(IntInfo, BooleanInfo), + UnitInfo, + ) { m => + val internalBuilder = + new StagedInternalNodeBuilder(branchingFactor, keyType, annotationType, m.localBuilder) + val parentBuilder = + new StagedInternalNodeBuilder(branchingFactor, keyType, annotationType, m.localBuilder) - m.emitWithBuilder { cb => val level = m.getCodeParam[Int](1) val isRoot = m.getCodeParam[Boolean](2) - val idxOff = cb.newLocal[Long]("indexOff") - cb.assign(idxOff, utils.bytesWritten) - internalBuilder.loadFrom(cb, utils, level) - cb += ob.writeByte(1.toByte) - internalBuilder.encode(cb, ob) - cb += ob.flush() - - val next = m.newLocal[Int]("next") - cb.assign(next, level + 1) - cb.if_(!isRoot, { - cb.if_(utils.size.ceq(next), - parentBuilder.create(cb), { - cb.if_(utils.getLength(next).ceq(branchingFactor), - cb.invokeVoid(m, CodeParam(next), CodeParam(false)) - ) - parentBuilder.loadFrom(cb, utils, next) - }) - internalBuilder.loadChild(cb, 0) - parentBuilder.add(cb, idxOff, internalBuilder.getLoadedChild) - parentBuilder.store(cb, utils, next) - }) - - internalBuilder.reset(cb) - internalBuilder.store(cb, utils, level) - Code._empty - } - m - } - - private val writeLeafNode: EmitMethodBuilder[_] = { - val m = cb.genEmitMethod[Unit]("writeLeafNode") - val parentBuilder = new StagedInternalNodeBuilder(branchingFactor, keyType, annotationType, m.localBuilder) - m.voidWithBuilder { cb => - val idxOff = cb.newLocal[Long]("indexOff") - cb.assign(idxOff, utils.bytesWritten) - cb += ob.writeByte(0.toByte) - leafBuilder.encode(cb, ob) - cb += ob.flush() + m.emitWithBuilder { cb => + val idxOff = cb.newLocal[Long]("indexOff") + cb.assign(idxOff, utils.bytesWritten) + internalBuilder.loadFrom(cb, utils, level) + cb += ob.writeByte(1.toByte) + internalBuilder.encode(cb, ob) + cb += ob.flush() + + val next = m.newLocal[Int]("next") + cb.assign(next, level + 1) + cb.if_( + !isRoot, { + cb.if_( + 
utils.size.ceq(next), + parentBuilder.create(cb), { + cb.if_( + utils.getLength(next).ceq(branchingFactor), + cb.invokeVoid(m, cb.this_, CodeParam(next), CodeParam(false)), + ) + parentBuilder.loadFrom(cb, utils, next) + }, + ) + internalBuilder.loadChild(cb, 0) + parentBuilder.add(cb, idxOff, internalBuilder.getLoadedChild) + parentBuilder.store(cb, utils, next) + }, + ) - cb.if_(utils.getLength(0).ceq(branchingFactor), - cb.invokeVoid(writeInternalNode, CodeParam(0), CodeParam(false)) - ) - parentBuilder.loadFrom(cb, utils, 0) + internalBuilder.reset(cb) + internalBuilder.store(cb, utils, level) + Code._empty + } + } - leafBuilder.loadChild(cb, 0) - parentBuilder.add(cb, idxOff, leafBuilder.firstIdx(cb).asLong.value, leafBuilder.getLoadedChild) - parentBuilder.store(cb, utils, 0) - leafBuilder.reset(cb, elementIdx) + private val writeLeafNode: EmitMethodBuilder[_] = + cb.defineEmitMethod(genName("m", "writeLeafNode"), FastSeq(), UnitInfo) { m => + val parentBuilder = + new StagedInternalNodeBuilder(branchingFactor, keyType, annotationType, m.localBuilder) + m.voidWithBuilder { cb => + val idxOff = cb.newLocal[Long]("indexOff") + cb.assign(idxOff, utils.bytesWritten) + cb += ob.writeByte(0.toByte) + leafBuilder.encode(cb, ob) + cb += ob.flush() + + cb.if_( + utils.getLength(0).ceq(branchingFactor), + cb.invokeVoid(writeInternalNode, cb.this_, CodeParam(0), CodeParam(false)), + ) + parentBuilder.loadFrom(cb, utils, 0) + + leafBuilder.loadChild(cb, 0) + parentBuilder.add( + cb, + idxOff, + leafBuilder.firstIdx(cb).asLong.value, + leafBuilder.getLoadedChild, + ) + parentBuilder.store(cb, utils, 0) + leafBuilder.reset(cb, elementIdx) + } } - m - } - private val flush: EmitMethodBuilder[_] = { - val m = cb.genEmitMethod[Long]("flush") - m.emitWithBuilder { cb => - val idxOff = cb.newLocal[Long]("indexOff") - val level = m.newLocal[Int]("level") - cb.if_(leafBuilder.ab.length > 0, cb.invokeVoid(writeLeafNode)) - cb.assign(level, const(0)) - cb.while_(level < utils.size - 1, { - cb.if_(utils.getLength(level) > 0, - cb.invokeVoid(writeInternalNode, CodeParam(level), CodeParam(false)) + private val flush: EmitMethodBuilder[_] = + cb.defineEmitMethod(genName("m", "flush"), FastSeq(), LongInfo) { m => + m.emitWithBuilder { cb => + val idxOff = cb.newLocal[Long]("indexOff") + val level = m.newLocal[Int]("level") + cb.if_(leafBuilder.ab.length > 0, cb.invokeVoid(writeLeafNode, cb.this_)) + cb.assign(level, 0) + cb.while_( + level < utils.size - 1, { + cb.if_( + utils.getLength(level) > 0, + cb.invokeVoid(writeInternalNode, cb.this_, CodeParam(level), CodeParam(false)), + ) + cb.assign(level, level + 1) + }, ) - cb.assign(level, level + 1) - }) - cb.assign(idxOff, utils.bytesWritten) - writeInternalNode.invokeCode[Unit](cb, CodeParam(level), CodeParam(true)) - idxOff.load() + cb.assign(idxOff, utils.bytesWritten) + cb.invokeVoid(writeInternalNode, cb.this_, CodeParam(level), CodeParam(true)) + idxOff.load() + } } - m - } - def add(cb: EmitCodeBuilder, key: => IEmitCode, offset: Code[Long], annotation: => IEmitCode) { - cb.if_(leafBuilder.ab.length.ceq(branchingFactor), cb.invokeVoid(writeLeafNode)) + def add(cb: EmitCodeBuilder, key: => IEmitCode, offset: Code[Long], annotation: => IEmitCode) + : Unit = { + cb.if_(leafBuilder.ab.length.ceq(branchingFactor), cb.invokeVoid(writeLeafNode, cb.this_)) leafBuilder.add(cb, key, offset, annotation) cb.assign(elementIdx, elementIdx + 1L) } + def close(cb: EmitCodeBuilder): Unit = { - val off = flush.invokeCode[Long](cb) + val off = cb.invokeCode[Long](flush, 
cb.this_) leafBuilder.close(cb) utils.close(cb) utils.writeMetadata(cb, utils.size + 1, off, elementIdx) @@ -387,8 +494,10 @@ class StagedIndexWriter(branchingFactor: Int, keyType: PType, annotationType: PT branchingFactor, cb.emb.getObject(keyType.virtualType), cb.emb.getObject(annotationType.virtualType), - attributes) - val internalBuilder = new StagedInternalNodeBuilder(branchingFactor, keyType, annotationType, cb.localBuilder) + attributes, + ) + val internalBuilder = + new StagedInternalNodeBuilder(branchingFactor, keyType, annotationType, cb.localBuilder) cb.assign(elementIdx, 0L) utils.create(cb, path, cb.emb.getFS, metadata) cb.assign(ob, IndexWriter.spec.buildCodeOutputBuffer(utils.os)) diff --git a/hail/src/main/scala/is/hail/io/index/InternalNodeBuilder.scala b/hail/src/main/scala/is/hail/io/index/InternalNodeBuilder.scala index 9c7233a6d12..c3d6585d593 100644 --- a/hail/src/main/scala/is/hail/io/index/InternalNodeBuilder.scala +++ b/hail/src/main/scala/is/hail/io/index/InternalNodeBuilder.scala @@ -6,50 +6,74 @@ import is.hail.expr.ir.EmitCodeBuilder import is.hail.io.OutputBuffer import is.hail.types.encoded.EType import is.hail.types.physical._ -import is.hail.types.physical.stypes.concrete.{SBaseStructPointer, SBaseStructPointerSettable, SIndexablePointerValue} +import is.hail.types.physical.stypes.concrete.{ + SBaseStructPointer, SBaseStructPointerSettable, SIndexablePointerValue, +} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.virtual.{TStruct, Type} object InternalNodeBuilder { - def virtualType(keyType: Type, annotationType: Type): TStruct = typ(PType.canonical(keyType), PType.canonical(annotationType)).virtualType + def virtualType(keyType: Type, annotationType: Type): TStruct = + typ(PType.canonical(keyType), PType.canonical(annotationType)).virtualType def legacyTyp(keyType: PType, annotationType: PType) = PCanonicalStruct( - "children" -> +PCanonicalArray(+PCanonicalStruct( - "index_file_offset" -> +PInt64(), - "first_idx" -> +PInt64(), - "first_key" -> keyType, - "first_record_offset" -> +PInt64(), - "first_annotation" -> annotationType - ), required = true) + "children" -> +PCanonicalArray( + +PCanonicalStruct( + "index_file_offset" -> +PInt64(), + "first_idx" -> +PInt64(), + "first_key" -> keyType, + "first_record_offset" -> +PInt64(), + "first_annotation" -> annotationType, + ), + required = true, + ) ) def arrayType(keyType: PType, annotationType: PType) = - PCanonicalArray(PCanonicalStruct(required = true, - "index_file_offset" -> +PInt64(), - "first_idx" -> +PInt64(), - "first_key" -> keyType, - "first_record_offset" -> +PInt64(), - "first_annotation" -> annotationType - ), required = true) + PCanonicalArray( + PCanonicalStruct( + required = true, + "index_file_offset" -> +PInt64(), + "first_idx" -> +PInt64(), + "first_key" -> keyType, + "first_record_offset" -> +PInt64(), + "first_annotation" -> annotationType, + ), + required = true, + ) def typ(keyType: PType, annotationType: PType) = PCanonicalStruct( "children" -> arrayType(keyType, annotationType) ) } -class StagedInternalNodeBuilder(maxSize: Int, keyType: PType, annotationType: PType, sb: SettableBuilder) { +class StagedInternalNodeBuilder( + maxSize: Int, + keyType: PType, + annotationType: PType, + sb: SettableBuilder, +) { private val region = sb.newSettable[Region]("internal_node_region") - val ab = new IndexWriterArrayBuilder("internal_node", maxSize, - sb, region, - InternalNodeBuilder.arrayType(keyType, annotationType)) + + val ab = new IndexWriterArrayBuilder( + 
"internal_node", + maxSize, + sb, + region, + InternalNodeBuilder.arrayType(keyType, annotationType), + ) val pType: PCanonicalStruct = InternalNodeBuilder.typ(keyType, annotationType) - private val node = new SBaseStructPointerSettable(SBaseStructPointer(pType), sb.newSettable[Long]("internal_node_node")) + + private val node = new SBaseStructPointerSettable( + SBaseStructPointer(pType), + sb.newSettable[Long]("internal_node_node"), + ) def loadFrom(cb: EmitCodeBuilder, ib: StagedIndexWriterUtils, idx: Value[Int]): Unit = { cb.assign(region, ib.getRegion(idx)) cb.assign(node.a, ib.getArrayOffset(idx)) - val aoff = node.loadField(cb, 0).get(cb).asInstanceOf[SIndexablePointerValue].a + val aoff = node.loadField(cb, 0).getOrAssert(cb).asInstanceOf[SIndexablePointerValue].a ab.loadFrom(cb, aoff, ib.getLength(idx)) } @@ -79,7 +103,12 @@ class StagedInternalNodeBuilder(maxSize: Int, keyType: PType, annotationType: PT def nodeAddress: SBaseStructValue = node - def add(cb: EmitCodeBuilder, indexFileOffset: Code[Long], firstIndex: Code[Long], firstChild: SBaseStructValue): Unit = { + def add( + cb: EmitCodeBuilder, + indexFileOffset: Code[Long], + firstIndex: Code[Long], + firstChild: SBaseStructValue, + ): Unit = { ab.addChild(cb) ab.setFieldValue(cb, "index_file_offset", primitive(cb.memoize(indexFileOffset))) ab.setFieldValue(cb, "first_idx", primitive(cb.memoize(firstIndex))) diff --git a/hail/src/main/scala/is/hail/io/index/LeafNodeBuilder.scala b/hail/src/main/scala/is/hail/io/index/LeafNodeBuilder.scala index d5ffa1242a9..b27dafdcc9e 100644 --- a/hail/src/main/scala/is/hail/io/index/LeafNodeBuilder.scala +++ b/hail/src/main/scala/is/hail/io/index/LeafNodeBuilder.scala @@ -6,58 +6,88 @@ import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode} import is.hail.io.OutputBuffer import is.hail.types.encoded.EType import is.hail.types.physical._ -import is.hail.types.physical.stypes.{SCode, SValue} +import is.hail.types.physical.stypes.SValue import is.hail.types.physical.stypes.concrete.{SBaseStructPointer, SBaseStructPointerSettable} -import is.hail.types.physical.stypes.interfaces.{SBaseStructValue, primitive} +import is.hail.types.physical.stypes.interfaces.{primitive, SBaseStructValue} import is.hail.types.virtual.{TStruct, Type} import is.hail.utils._ object LeafNodeBuilder { - def virtualType(keyType: Type, annotationType: Type): TStruct = typ(PType.canonical(keyType), PType.canonical(annotationType)).virtualType + def virtualType(keyType: Type, annotationType: Type): TStruct = + typ(PType.canonical(keyType), PType.canonical(annotationType)).virtualType def legacyTyp(keyType: PType, annotationType: PType) = PCanonicalStruct( "first_idx" -> +PInt64(), - "keys" -> +PCanonicalArray(+PCanonicalStruct( - "key" -> keyType, - "offset" -> +PInt64(), - "annotation" -> annotationType - ), required = true)) + "keys" -> +PCanonicalArray( + +PCanonicalStruct( + "key" -> keyType, + "offset" -> +PInt64(), + "annotation" -> annotationType, + ), + required = true, + ), + ) def arrayType(keyType: PType, annotationType: PType) = - PCanonicalArray(PCanonicalStruct(required = true, - "key" -> keyType, - "offset" -> +PInt64(), - "annotation" -> annotationType), required = true) + PCanonicalArray( + PCanonicalStruct( + required = true, + "key" -> keyType, + "offset" -> +PInt64(), + "annotation" -> annotationType, + ), + required = true, + ) def typ(keyType: PType, annotationType: PType) = PCanonicalStruct( "first_idx" -> +PInt64(), - "keys" -> arrayType(keyType, annotationType)) + "keys" -> arrayType(keyType, 
annotationType), + ) } - -class StagedLeafNodeBuilder(maxSize: Int, keyType: PType, annotationType: PType, sb: SettableBuilder) { +class StagedLeafNodeBuilder( + maxSize: Int, + keyType: PType, + annotationType: PType, + sb: SettableBuilder, +) { private val region = sb.newSettable[Region]("leaf_node_region") - val ab = new IndexWriterArrayBuilder("leaf_node", maxSize, - sb, region, - LeafNodeBuilder.arrayType(keyType, annotationType)) + + val ab = new IndexWriterArrayBuilder( + "leaf_node", + maxSize, + sb, + region, + LeafNodeBuilder.arrayType(keyType, annotationType), + ) private[this] val pType: PCanonicalStruct = LeafNodeBuilder.typ(keyType, annotationType) private[this] val idxType = pType.fieldType("first_idx").asInstanceOf[PInt64] - private[this] val node = new SBaseStructPointerSettable(SBaseStructPointer(pType), sb.newSettable[Long]("lef_node_addr")) + + private[this] val node = + new SBaseStructPointerSettable(SBaseStructPointer(pType), sb.newSettable[Long]("lef_node_addr")) def close(cb: EmitCodeBuilder): Unit = cb.if_(!region.isNull, cb += region.invalidate()) def reset(cb: EmitCodeBuilder, firstIdx: Code[Long]): Unit = { cb += region.invoke[Unit]("clear") node.store(cb, pType.loadCheapSCode(cb, pType.allocate(region))) - idxType.storePrimitiveAtAddress(cb, pType.fieldOffset(node.a, "first_idx"), primitive(cb.memoize(firstIdx))) + idxType.storePrimitiveAtAddress( + cb, + pType.fieldOffset(node.a, "first_idx"), + primitive(cb.memoize(firstIdx)), + ) ab.create(cb, pType.fieldOffset(node.a, "keys")) } def create(cb: EmitCodeBuilder, firstIdx: Code[Long]): Unit = { cb.assign(region, Region.stagedCreate(Region.REGULAR, cb.emb.ecb.pool())) node.store(cb, pType.loadCheapSCode(cb, pType.allocate(region))) - idxType.storePrimitiveAtAddress(cb, pType.fieldOffset(node.a, "first_idx"), primitive(cb.memoize(firstIdx))) + idxType.storePrimitiveAtAddress( + cb, + pType.fieldOffset(node.a, "first_idx"), + primitive(cb.memoize(firstIdx)), + ) ab.create(cb, pType.fieldOffset(node.a, "keys")) } @@ -69,7 +99,8 @@ class StagedLeafNodeBuilder(maxSize: Int, keyType: PType, annotationType: PType, def nodeAddress: SBaseStructValue = node - def add(cb: EmitCodeBuilder, key: => IEmitCode, offset: Code[Long], annotation: => IEmitCode): Unit = { + def add(cb: EmitCodeBuilder, key: => IEmitCode, offset: Code[Long], annotation: => IEmitCode) + : Unit = { ab.addChild(cb) ab.setField(cb, "key", key) ab.setFieldValue(cb, "offset", primitive(cb.memoize(offset))) @@ -78,5 +109,7 @@ class StagedLeafNodeBuilder(maxSize: Int, keyType: PType, annotationType: PType, def loadChild(cb: EmitCodeBuilder, idx: Code[Int]): Unit = ab.loadChild(cb, idx) def getLoadedChild: SBaseStructValue = ab.getLoadedChild - def firstIdx(cb: EmitCodeBuilder): SValue = idxType.loadCheapSCode(cb, pType.fieldOffset(node.a, "first_idx")) -} \ No newline at end of file + + def firstIdx(cb: EmitCodeBuilder): SValue = + idxType.loadCheapSCode(cb, pType.fieldOffset(node.a, "first_idx")) +} diff --git a/hail/src/main/scala/is/hail/io/index/StagedIndexReader.scala b/hail/src/main/scala/is/hail/io/index/StagedIndexReader.scala index 8527f07231a..0e4f01343ce 100644 --- a/hail/src/main/scala/is/hail/io/index/StagedIndexReader.scala +++ b/hail/src/main/scala/is/hail/io/index/StagedIndexReader.scala @@ -3,14 +3,18 @@ package is.hail.io.index import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend.TaskFinalizer -import is.hail.expr.ir.functions.IntervalFunctions.{arrayOfStructFindIntervalRange, 
compareStructWithPartitionIntervalEndpoint} -import is.hail.expr.ir.{BinarySearch, EmitCode, EmitCodeBuilder, EmitMethodBuilder, EmitValue, IEmitCode} +import is.hail.expr.ir.{ + BinarySearch, EmitCode, EmitCodeBuilder, EmitMethodBuilder, EmitValue, IEmitCode, +} +import is.hail.expr.ir.functions.IntervalFunctions.{ + arrayOfStructFindIntervalRange, compareStructWithPartitionIntervalEndpoint, +} import is.hail.io.AbstractTypedCodecSpec import is.hail.io.fs.FS +import is.hail.types.physical.{PCanonicalArray, PCanonicalBaseStruct} +import is.hail.types.physical.stypes.{SSettable, SValue} import is.hail.types.physical.stypes.concrete._ import is.hail.types.physical.stypes.interfaces._ -import is.hail.types.physical.stypes.{SSettable, SValue} -import is.hail.types.physical.{PCanonicalArray, PCanonicalBaseStruct} import is.hail.types.virtual.{TInt64, TTuple} import is.hail.utils._ @@ -21,40 +25,71 @@ case class VariableMetadata( height: Int, nKeys: Long, rootOffset: Long, - attributes: Map[String, Any] + attributes: Map[String, Any], ) -class StagedIndexReader(emb: EmitMethodBuilder[_], leafCodec: AbstractTypedCodecSpec, internalCodec: AbstractTypedCodecSpec) { - private[this] val cache: Settable[LongToRegionValueCache] = emb.genFieldThisRef[LongToRegionValueCache]("index_cache") - private[this] val metadata: Settable[VariableMetadata] = emb.genFieldThisRef[VariableMetadata]("index_file_metadata") +class StagedIndexReader( + emb: EmitMethodBuilder[_], + leafCodec: AbstractTypedCodecSpec, + internalCodec: AbstractTypedCodecSpec, +) { + private[this] val cache: Settable[LongToRegionValueCache] = + emb.genFieldThisRef[LongToRegionValueCache]("index_cache") - private[this] val is: Settable[ByteTrackingInputStream] = emb.genFieldThisRef[ByteTrackingInputStream]("index_is") + private[this] val metadata: Settable[VariableMetadata] = + emb.genFieldThisRef[VariableMetadata]("index_file_metadata") + + private[this] val is: Settable[ByteTrackingInputStream] = + emb.genFieldThisRef[ByteTrackingInputStream]("index_is") private[this] val leafPType = leafCodec.encodedType.decodedPType(leafCodec.encodedVirtualType) - private[this] val internalPType = internalCodec.encodedType.decodedPType(internalCodec.encodedVirtualType) - private[this] val leafChildType = leafPType.asInstanceOf[PCanonicalBaseStruct].types(1).asInstanceOf[PCanonicalArray].elementType.sType.asInstanceOf[SBaseStruct] - private[this] val leafChildLocalType = SStackStruct(leafChildType.virtualType, leafChildType.fieldEmitTypes) + private[this] val internalPType = + internalCodec.encodedType.decodedPType(internalCodec.encodedVirtualType) + + private[this] val leafChildType = leafPType.asInstanceOf[PCanonicalBaseStruct].types( + 1 + ).asInstanceOf[PCanonicalArray].elementType.sType.asInstanceOf[SBaseStruct] - private[this] val queryResultStartIndex: Settable[Long] = emb.genFieldThisRef[Long]("index_resultIndex") - private[this] val queryResultStartLeaf: SSettable = emb.newPField("index_resultOffset", leafChildLocalType) + private[this] val leafChildLocalType = + SStackStruct(leafChildType.virtualType, leafChildType.fieldEmitTypes) - private[this] val leafDec = leafCodec.encodedType.buildDecoder(leafCodec.encodedVirtualType, emb.ecb) - private[this]val internalDec = internalCodec.encodedType.buildDecoder(internalCodec.encodedVirtualType, emb.ecb) + private[this] val queryResultStartIndex: Settable[Long] = + emb.genFieldThisRef[Long]("index_resultIndex") + + private[this] val queryResultStartLeaf: SSettable = + emb.newPField("index_resultOffset", 
leafChildLocalType) + + private[this] val leafDec = + leafCodec.encodedType.buildDecoder(leafCodec.encodedVirtualType, emb.ecb) + + private[this] val internalDec = + internalCodec.encodedType.buildDecoder(internalCodec.encodedVirtualType, emb.ecb) def nKeys(cb: EmitCodeBuilder): Value[Long] = cb.memoize(metadata.invoke[Long]("nKeys")) - def initialize(cb: EmitCodeBuilder, - indexPath: Value[String] - ): Unit = { + def initialize(cb: EmitCodeBuilder, indexPath: Value[String]): Unit = { val fs = cb.emb.getFS cb.assign(cache, Code.newInstance[LongToRegionValueCache, Int](16)) - cb.assign(metadata, Code.invokeScalaObject2[FS, String, IndexMetadataUntypedJSON]( - IndexReader.getClass, "readUntyped", fs, indexPath - ).invoke[VariableMetadata]("toFileMetadata")) - - // FIXME: hardcoded. Will break if we change spec -- assumption not introduced with this code, but propagated. - cb.assign(is, Code.newInstance[ByteTrackingInputStream, InputStream](cb.emb.openUnbuffered(indexPath.concat("/index"), false))) + cb.assign( + metadata, + Code.invokeScalaObject2[FS, String, IndexMetadataUntypedJSON]( + IndexReader.getClass, + "readUntyped", + fs, + indexPath, + ).invoke[VariableMetadata]("toFileMetadata"), + ) + + /* FIXME: hardcoded. Will break if we change spec -- assumption not introduced with this code, + * but propagated. */ + cb.assign( + is, + Code.newInstance[ByteTrackingInputStream, InputStream](cb.emb.openUnbuffered( + indexPath.concat("/index"), + false, + )), + ) } @@ -71,10 +106,8 @@ class StagedIndexReader(emb: EmitMethodBuilder[_], leafCodec: AbstractTypedCodec cb.assign(metadata, Code._null) } - def queryBound(cb: EmitCodeBuilder, - endpoint: SBaseStructValue, - leansRight: Value[Boolean] - ): SBaseStructValue = { + def queryBound(cb: EmitCodeBuilder, endpoint: SBaseStructValue, leansRight: Value[Boolean]) + : SBaseStructValue = { val rootLevel = cb.memoize(metadata.invoke[Int]("height") - 1) val rootOffset = cb.memoize(metadata.invoke[Long]("rootOffset")) val nKeys = this.nKeys(cb) @@ -87,21 +120,26 @@ class StagedIndexReader(emb: EmitMethodBuilder[_], leafCodec: AbstractTypedCodec // handle the cases where the query is less than all keys (including the // empty index case), to establish the precondition of runQuery, and as a // fast path for a common case - cb.if_(nKeys.ceq(0), { - cb.assign(index, 0L) - cb.goto(LReturn) - }) - - val rootChildren = readInternalNode(cb, rootOffset).loadField(cb, "children").get(cb).asIndexable - val firstChild = rootChildren.loadElement(cb, 0).get(cb).asBaseStruct - val firstKey = firstChild.loadField(cb, "first_key").get(cb).asBaseStruct + cb.if_( + nKeys.ceq(0), { + cb.assign(index, 0L) + cb.goto(LReturn) + }, + ) + + val rootChildren = + readInternalNode(cb, rootOffset).loadField(cb, "children").getOrAssert(cb).asIndexable + val firstChild = rootChildren.loadElement(cb, 0).getOrAssert(cb).asBaseStruct + val firstKey = firstChild.loadField(cb, "first_key").getOrAssert(cb).asBaseStruct val compEndpointWithFirstKey = compareStructWithPartitionIntervalEndpoint(cb, firstKey, endpoint, leansRight) - cb.if_(compEndpointWithFirstKey > 0, { - cb.assign(index, firstChild.loadField(cb, "first_idx").get(cb).asLong.value) - cb.assign(leaf, getFirstLeaf(cb, firstChild)) - cb.goto(LReturn) - }) + cb.if_( + compEndpointWithFirstKey > 0, { + cb.assign(index, firstChild.loadField(cb, "first_idx").getOrAssert(cb).asLong.value) + cb.assign(leaf, getFirstLeaf(cb, firstChild)) + cb.goto(LReturn) + }, + ) queryBound(cb, endpoint, leansRight, rootLevel, rootOffset, nKeys, 
leaf) cb.assign(index, queryResultStartIndex) @@ -109,16 +147,19 @@ class StagedIndexReader(emb: EmitMethodBuilder[_], leafCodec: AbstractTypedCodec cb.define(LReturn) - SStackStruct.constructFromArgs(cb, null, TTuple(TInt64, leafChildType.virtualType), + SStackStruct.constructFromArgs( + cb, + null, + TTuple(TInt64, leafChildType.virtualType), EmitCode.present(cb.emb, primitive(index)), - EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, index ceq nKeys, leaf))) + EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, index ceq nKeys, leaf)), + ) } - /** - * returns tuple of (start index, end index, starting leaf) - * memory of starting leaf is not owned by `region`, consumers should deep copy if necessary - * starting leaf IS MISSING if (end_idx - start_idx == 0) - */ + /** returns tuple of (start index, end index, starting leaf) memory of starting leaf is not owned + * by `region`, consumers should deep copy if necessary starting leaf IS MISSING if (end_idx - + * start_idx == 0) + */ def queryInterval(cb: EmitCodeBuilder, interval: SIntervalValue): SBaseStructValue = { val rootLevel = cb.memoize(metadata.invoke[Int]("height") - 1) val rootOffset = cb.memoize(metadata.invoke[Long]("rootOffset")) @@ -128,9 +169,9 @@ class StagedIndexReader(emb: EmitMethodBuilder[_], leafCodec: AbstractTypedCodec val startLeaf = cb.newSLocal(leafChildLocalType, "queryInterval_startOffset") val endIdx = cb.newLocal[Long]("queryInterval_endIdx") - val startKey = interval.loadStart(cb).get(cb).asBaseStruct + val startKey = interval.loadStart(cb).getOrAssert(cb).asBaseStruct val startLeansRight = cb.memoize(!interval.includesStart) - val endKey = interval.loadEnd(cb).get(cb).asBaseStruct + val endKey = interval.loadEnd(cb).getOrAssert(cb).asBaseStruct val endLeansRight = interval.includesEnd val LReturn = CodeLabel() @@ -138,33 +179,44 @@ class StagedIndexReader(emb: EmitMethodBuilder[_], leafCodec: AbstractTypedCodec // handle the cases where the query is less than all keys (including the // empty index case), to establish the precondition of runQuery, and as a // fast path for a common case - cb.if_(nKeys.ceq(0), { - cb.assign(startIdx, 0L) - cb.assign(endIdx, 0L) - cb.goto(LReturn) - }) - - val rootChildren = readInternalNode(cb, rootOffset).loadField(cb, "children").get(cb).asIndexable - val firstChild = rootChildren.loadElement(cb, 0).get(cb).asBaseStruct - val firstKey = firstChild.loadField(cb, "first_key").get(cb).asBaseStruct + cb.if_( + nKeys.ceq(0), { + cb.assign(startIdx, 0L) + cb.assign(endIdx, 0L) + cb.goto(LReturn) + }, + ) + + val rootChildren = + readInternalNode(cb, rootOffset).loadField(cb, "children").getOrAssert(cb).asIndexable + val firstChild = rootChildren.loadElement(cb, 0).getOrAssert(cb).asBaseStruct + val firstKey = firstChild.loadField(cb, "first_key").getOrAssert(cb).asBaseStruct val compStartWithFirstKey = compareStructWithPartitionIntervalEndpoint(cb, firstKey, startKey, startLeansRight) - cb.if_(compStartWithFirstKey > 0, { - cb.assign(startIdx, firstChild.loadField(cb, "first_idx").get(cb).asLong.value) - cb.assign(startLeaf, getFirstLeaf(cb, firstChild)) - - val compEndWithFirstKey = - compareStructWithPartitionIntervalEndpoint(cb, firstKey, endKey, endLeansRight) - cb.if_(compEndWithFirstKey > 0, { - cb.assign(endIdx, startIdx) - }, { - queryBound(cb, endKey, endLeansRight, rootLevel, rootOffset, nKeys, startLeaf) - cb.assign(endIdx, queryResultStartIndex) - }) - cb.goto(LReturn) - }) - - val stackInterval = SStackInterval.construct(EmitValue.present(startKey), EmitValue.present(endKey), 
cb.memoize(!startLeansRight), endLeansRight) + cb.if_( + compStartWithFirstKey > 0, { + cb.assign(startIdx, firstChild.loadField(cb, "first_idx").getOrAssert(cb).asLong.value) + cb.assign(startLeaf, getFirstLeaf(cb, firstChild)) + + val compEndWithFirstKey = + compareStructWithPartitionIntervalEndpoint(cb, firstKey, endKey, endLeansRight) + cb.if_( + compEndWithFirstKey > 0, + cb.assign(endIdx, startIdx), { + queryBound(cb, endKey, endLeansRight, rootLevel, rootOffset, nKeys, startLeaf) + cb.assign(endIdx, queryResultStartIndex) + }, + ) + cb.goto(LReturn) + }, + ) + + val stackInterval = SStackInterval.construct( + EmitValue.present(startKey), + EmitValue.present(endKey), + cb.memoize(!startLeansRight), + endLeansRight, + ) val (_startIdx, _startLeaf, _endIdx) = runQuery(cb, stackInterval, rootLevel, rootOffset, nKeys, startLeaf, isPointQuery = false) cb.assign(startIdx, _startIdx) @@ -174,28 +226,57 @@ class StagedIndexReader(emb: EmitMethodBuilder[_], leafCodec: AbstractTypedCodec cb.define(LReturn) val n = cb.memoize(endIdx - startIdx) - cb.if_(n < 0L, cb._fatal("n less than 0: ", n.toS, ", startIdx=", startIdx.toS, ", endIdx=", endIdx.toS, ", query=", cb.strValue(interval))) - cb.if_(n > 0L && startIdx >= nKeys, cb._fatal("bad start idx: ", startIdx.toS, ", nKeys=", nKeys.toS)) - - SStackStruct.constructFromArgs(cb, null, TTuple(TInt64, TInt64, leafChildType.virtualType), + cb.if_( + n < 0L, + cb._fatal( + "n less than 0: ", + n.toS, + ", startIdx=", + startIdx.toS, + ", endIdx=", + endIdx.toS, + ", query=", + cb.strValue(interval), + ), + ) + cb.if_( + n > 0L && startIdx >= nKeys, + cb._fatal("bad start idx: ", startIdx.toS, ", nKeys=", nKeys.toS), + ) + + SStackStruct.constructFromArgs( + cb, + null, + TTuple(TInt64, TInt64, leafChildType.virtualType), EmitCode.present(cb.emb, primitive(startIdx)), EmitCode.present(cb.emb, primitive(endIdx)), - EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, n ceq 0L, startLeaf))) + EmitCode.fromI(cb.emb)(cb => IEmitCode(cb, n ceq 0L, startLeaf)), + ) } - private[this] def queryBound(cb: EmitCodeBuilder, + private[this] def queryBound( + cb: EmitCodeBuilder, endpoint: SBaseStructValue, leansRight: Value[Boolean], rootLevel: Value[Int], rootOffset: Value[Long], rootSuccessorIndex: Value[Long], - rootSuccessorLeaf: SValue + rootSuccessorLeaf: SValue, ): Unit = { cb.invokeVoid( - cb.emb.ecb.getOrGenEmitMethod("queryBound", + cb.emb.ecb.getOrGenEmitMethod( + "queryBound", ("queryBound", this), - FastSeq(endpoint.st.paramType, typeInfo[Boolean], typeInfo[Int], typeInfo[Long], typeInfo[Long], leafChildLocalType.paramType), - UnitInfo) { emb => + FastSeq( + endpoint.st.paramType, + typeInfo[Boolean], + typeInfo[Int], + typeInfo[Long], + typeInfo[Long], + leafChildLocalType.paramType, + ), + UnitInfo, + ) { emb => emb.emitWithBuilder { cb => val endpoint = emb.getSCodeParam(1).asBaseStruct val leansRight = emb.getCodeParam[Boolean](2) @@ -203,13 +284,27 @@ class StagedIndexReader(emb: EmitMethodBuilder[_], leafCodec: AbstractTypedCodec val rootOffset = emb.getCodeParam[Long](4) val rootSuccessorIndex = emb.getCodeParam[Long](5) val rootSuccessorLeaf = emb.getSCodeParam(6) - val interval = SStackInterval.construct(EmitValue.present(endpoint), EmitValue.present(endpoint), cb.memoize(!leansRight), leansRight) - val (startIndex, startLeaf, _) = runQuery(cb, interval, rootLevel, rootOffset, rootSuccessorIndex, rootSuccessorLeaf, isPointQuery = true) + val interval = SStackInterval.construct( + EmitValue.present(endpoint), + EmitValue.present(endpoint), + 
cb.memoize(!leansRight), + leansRight, + ) + val (startIndex, startLeaf, _) = runQuery(cb, interval, rootLevel, rootOffset, + rootSuccessorIndex, rootSuccessorLeaf, isPointQuery = true) cb.assign(queryResultStartIndex, startIndex) cb.assign(queryResultStartLeaf, startLeaf) Code._empty } - }, endpoint, leansRight, rootLevel, rootOffset, rootSuccessorIndex, rootSuccessorLeaf) + }, + cb.this_, + endpoint, + leansRight, + rootLevel, + rootOffset, + rootSuccessorIndex, + rootSuccessorLeaf, + ) } // Supports both point and interval queries. If `isPointQuery`, end key @@ -219,41 +314,55 @@ class StagedIndexReader(emb: EmitMethodBuilder[_], leafCodec: AbstractTypedCodec // If this is the root of the index, so there is no following record, // `rootSuccessorIndex` must be `nKeys`, and `rootSuccessorLeaf` can be anything, // as it will never be accessed. - private[this] def runQuery(cb: EmitCodeBuilder, + private[this] def runQuery( + cb: EmitCodeBuilder, interval: SStackIntervalValue, rootLevel: Value[Int], rootOffset: Value[Long], rootSuccessorIndex: Value[Long], rootSuccessorLeaf: SValue, - isPointQuery: Boolean + isPointQuery: Boolean, ): (Value[Long], SStackStructValue, Value[Long]) = { - val startKey = interval.loadStart(cb).get(cb).asBaseStruct + val startKey = interval.loadStart(cb).getOrAssert(cb).asBaseStruct val startLeansRight = cb.memoize(!interval.includesStart) - val endKey = interval.loadEnd(cb).get(cb).asBaseStruct + val endKey = interval.loadEnd(cb).getOrAssert(cb).asBaseStruct val endLeansRight = interval.includesEnd - def searchChildren(children: SIndexableValue, isInternalNode: Boolean): (Value[Int], Value[Int]) = { + def searchChildren(children: SIndexableValue, isInternalNode: Boolean) + : (Value[Int], Value[Int]) = { val keyFieldName = if (isInternalNode) "first_key" else "key" if (isPointQuery) { def ltNeedle(child: IEmitCode): Code[Boolean] = { - val key = child.get(cb).asBaseStruct.loadField(cb, keyFieldName).get(cb).asBaseStruct + val key = child.getOrAssert(cb).asBaseStruct.loadField(cb, keyFieldName).getOrAssert( + cb + ).asBaseStruct val c = compareStructWithPartitionIntervalEndpoint(cb, key, startKey, startLeansRight) c < 0 } val idx = BinarySearch.lowerBound(cb, children, ltNeedle) (idx, idx) } else { - arrayOfStructFindIntervalRange(cb, children, startKey, startLeansRight, endKey, endLeansRight, - _.get(cb).asBaseStruct.loadField(cb, keyFieldName)) + arrayOfStructFindIntervalRange( + cb, + children, + startKey, + startLeansRight, + endKey, + endLeansRight, + _.getOrAssert(cb).asBaseStruct.loadField(cb, keyFieldName), + ) } } val startIndex: Settable[Long] = cb.newLocal[Long]("startIndex") - val startLeaf: SStackStructSettable = cb.newSLocal(leafChildLocalType, "startOffset").asInstanceOf[SStackStructSettable] + val startLeaf: SStackStructSettable = + cb.newSLocal(leafChildLocalType, "startOffset").asInstanceOf[SStackStructSettable] val endIndex: Settable[Long] = cb.newLocal[Long]("endIndex") - val successorIndex: Settable[Long] = cb.newLocal[Long]("queryInterval_successorIndex", rootSuccessorIndex) - val successorLeaf: SStackStructSettable = cb.newSLocal(leafChildLocalType, "successorLeaf").asInstanceOf[SStackStructSettable] + val successorIndex: Settable[Long] = + cb.newLocal[Long]("queryInterval_successorIndex", rootSuccessorIndex) + val successorLeaf: SStackStructSettable = + cb.newSLocal(leafChildLocalType, "successorLeaf").asInstanceOf[SStackStructSettable] cb.assign(successorLeaf, rootSuccessorLeaf) val level = cb.newLocal[Int]("queryInterval_level", rootLevel) 
@@ -266,87 +375,112 @@ class StagedIndexReader(emb: EmitMethodBuilder[_], leafCodec: AbstractTypedCodec val Lstart = CodeLabel() cb.define(Lstart) - def updateSuccessor(children: SIndexableValue, idx: Value[Int]): Unit = { - cb.if_(idx < children.loadLength(), { - val successorChild = children.loadElement(cb, idx).get(cb).asBaseStruct - cb.assign(successorIndex, successorChild.loadField(cb, "first_idx").get(cb).asLong.value) - cb.assign(successorLeaf, getFirstLeaf(cb, successorChild)) - }) - } + def updateSuccessor(children: SIndexableValue, idx: Value[Int]): Unit = + cb.if_( + idx < children.loadLength(), { + val successorChild = children.loadElement(cb, idx).getOrAssert(cb).asBaseStruct + cb.assign( + successorIndex, + successorChild.loadField(cb, "first_idx").getOrAssert(cb).asLong.value, + ) + cb.assign(successorLeaf, getFirstLeaf(cb, successorChild)) + }, + ) + + cb.if_( + level > 0, { + /* InternalNode( children: IndexedSeq[InternalChild]) InternalChild( index_file_offset: + * Long, first_idx: Long, first_key: Annotation, first_record_offset: Long, + * first_annotation: Annotation) */ + val children = + readInternalNode(cb, nodeOffset).loadField(cb, "children").getOrAssert(cb).asIndexable + + val (start, end) = searchChildren(children, isInternalNode = true) + + cb.assign(level, level - 1) + cb.if_(start.ceq(0) || end.ceq(0), cb._fatal("queryInterval broken invariant")) + + cb.if_( + if (isPointQuery) const(true).get else start.ceq(end), { + updateSuccessor(children, start) + cb.assign( + nodeOffset, + children.loadElement(cb, start - 1).getOrAssert(cb).asBaseStruct.loadField( + cb, + "index_file_offset", + ).getOrAssert(cb).asLong.value, + ) + cb.goto(Lstart) + }, + ) + + cb.if_(!(start < children.loadLength()), cb._fatal("unreachable")) + + // continue with separate point queries for each endpoint + updateSuccessor(children, end) + cb.assign( + nodeOffset, + children.loadElement(cb, end - 1).getOrAssert(cb).asBaseStruct.loadField( + cb, + "index_file_offset", + ).getOrAssert(cb).asLong.value, + ) + queryBound(cb, endKey, endLeansRight, level, nodeOffset, successorIndex, successorLeaf) + cb.assign(endIndex, queryResultStartIndex) - cb.if_(level > 0, { - /* - InternalNode( - children: IndexedSeq[InternalChild]) - InternalChild( - index_file_offset: Long, - first_idx: Long, - first_key: Annotation, - first_record_offset: Long, - first_annotation: Annotation) - */ - val children = readInternalNode(cb, nodeOffset).loadField(cb, "children").get(cb).asIndexable - - val (start, end) = searchChildren(children, isInternalNode = true) - - cb.assign(level, level-1) - cb.if_(start.ceq(0) || end.ceq(0), cb._fatal("queryInterval broken invariant")) - - cb.if_(if (isPointQuery) const(true).get else start.ceq(end), { updateSuccessor(children, start) - cb.assign(nodeOffset, children.loadElement(cb, start-1).get(cb).asBaseStruct.loadField(cb, "index_file_offset").get(cb).asLong.value) - cb.goto(Lstart) - }) - - cb.if_(!(start < children.loadLength()), cb._fatal("unreachable")) - - // continue with separate point queries for each endpoint - updateSuccessor(children, end) - cb.assign(nodeOffset, children.loadElement(cb, end-1).get(cb).asBaseStruct.loadField(cb, "index_file_offset").get(cb).asLong.value) - queryBound(cb, endKey, endLeansRight, level, nodeOffset, successorIndex, successorLeaf) - cb.assign(endIndex, queryResultStartIndex) - - updateSuccessor(children, start) - cb.assign(nodeOffset, children.loadElement(cb, start-1).get(cb).asBaseStruct.loadField(cb, 
"index_file_offset").get(cb).asLong.value) - queryBound(cb, startKey, startLeansRight, level, nodeOffset, successorIndex, successorLeaf) - cb.assign(startIndex, queryResultStartIndex) - cb.assign(startLeaf, queryResultStartLeaf) - }, { - /* - LeafNode( - first_idx: Long, - keys: IndexedSeq[LeafChild]) - LeafChild( - key: Annotation, - offset: Long, - annotation: Annotation) - */ - val node = readLeafNode(cb, nodeOffset).asBaseStruct - val children = node.asBaseStruct.loadField(cb, "keys").get(cb).asIndexable - - val (start, end) = searchChildren(children, isInternalNode = false) - - val firstIndex = cb.memoize(node.asBaseStruct.loadField(cb, "first_idx")).get(cb).asInt64.value - cb.if_(start < children.loadLength(), { - cb.assign(startIndex, firstIndex + start.toL) - cb.assign(startLeaf, children.loadElement(cb, start).get(cb).asBaseStruct.toStackStruct(cb)) + cb.assign( + nodeOffset, + children.loadElement(cb, start - 1).getOrAssert(cb).asBaseStruct.loadField( + cb, + "index_file_offset", + ).getOrAssert(cb).asLong.value, + ) + queryBound(cb, startKey, startLeansRight, level, nodeOffset, successorIndex, successorLeaf) + cb.assign(startIndex, queryResultStartIndex) + cb.assign(startLeaf, queryResultStartLeaf) }, { - cb.if_(successorIndex.cne(firstIndex + start.toL), cb._fatal("queryInterval broken invariant")) - cb.assign(startIndex, successorIndex) - cb.assign(startLeaf, successorLeaf) - }) - cb.assign(endIndex, firstIndex + end.toL) - }) + /* LeafNode( first_idx: Long, keys: IndexedSeq[LeafChild]) LeafChild( key: Annotation, + * offset: Long, annotation: Annotation) */ + val node = readLeafNode(cb, nodeOffset).asBaseStruct + val children = node.asBaseStruct.loadField(cb, "keys").getOrAssert(cb).asIndexable + + val (start, end) = searchChildren(children, isInternalNode = false) + + val firstIndex = + cb.memoize(node.asBaseStruct.loadField(cb, "first_idx")).get(cb).asInt64.value + cb.if_( + start < children.loadLength(), { + cb.assign(startIndex, firstIndex + start.toL) + cb.assign( + startLeaf, + children.loadElement(cb, start).getOrAssert(cb).asBaseStruct.toStackStruct(cb), + ) + }, { + cb.if_( + successorIndex.cne(firstIndex + start.toL), + cb._fatal("queryInterval broken invariant"), + ) + cb.assign(startIndex, successorIndex) + cb.assign(startLeaf, successorLeaf) + }, + ) + cb.assign(endIndex, firstIndex + end.toL) + }, + ) (startIndex, startLeaf, endIndex) } - private[this] def getFirstLeaf(cb: EmitCodeBuilder, internalChild: SBaseStructValue): SValue = { - new SStackStructValue(leafChildLocalType, Array( - EmitValue.present(internalChild.loadField(cb, "first_key").get(cb)), - EmitValue.present(internalChild.loadField(cb, "first_record_offset").get(cb)), - EmitValue.present(internalChild.loadField(cb, "first_annotation").get(cb)))) - } + private[this] def getFirstLeaf(cb: EmitCodeBuilder, internalChild: SBaseStructValue): SValue = + new SStackStructValue( + leafChildLocalType, + Array( + EmitValue.present(internalChild.loadField(cb, "first_key").getOrAssert(cb)), + EmitValue.present(internalChild.loadField(cb, "first_record_offset").getOrAssert(cb)), + EmitValue.present(internalChild.loadField(cb, "first_annotation").getOrAssert(cb)), + ), + ) // internal node is an array of children private[io] def readInternalNode(cb: EmitCodeBuilder, offset: Value[Long]): SBaseStructValue = { @@ -355,23 +489,44 @@ class StagedIndexReader(emb: EmitMethodBuilder[_], leafCodec: AbstractTypedCodec // returns an address if cached, or 0L if not found val cached = cb.memoize(cache.invoke[Long, 
Long]("get", offset)) - cb.if_(cached cne 0L, { - cb.assign(ret, internalPType.loadCheapSCode(cb, cached)) - }, { - cb.assign(ret, cb.invokeSCode(cb.emb.ecb.getOrGenEmitMethod("readInternalNode", ("readInternalNode", this), FastSeq(LongInfo), ret.st.paramType) { emb => - emb.emitSCode { cb => - val offset = emb.getCodeParam[Long](1) - cb += is.invoke[Long, Unit]("seek", offset) - val ib = cb.memoize(internalCodec.buildCodeInputBuffer(is)) - cb.if_(ib.readByte() cne 1, cb._fatal("bad buffer at internal!")) - val region = cb.memoize(cb.emb.ecb.pool().invoke[Region.Size, Region]("getRegion", Region.TINIER)) - val internalNode = internalDec.apply(cb, region, ib) - val internalNodeAddr = internalPType.store(cb, region, internalNode, false) - cb += cache.invoke[Long, Region, Long, Unit]("put", offset, region, internalNodeAddr) - internalNode - } - }, offset)) - }) + cb.if_( + cached cne 0L, + cb.assign(ret, internalPType.loadCheapSCode(cb, cached)), { + cb.assign( + ret, + cb.invokeSCode( + cb.emb.ecb.getOrGenEmitMethod( + "readInternalNode", + ("readInternalNode", this), + FastSeq(LongInfo), + ret.st.paramType, + ) { emb => + emb.emitSCode { cb => + val offset = emb.getCodeParam[Long](1) + cb += is.invoke[Long, Unit]("seek", offset) + val ib = cb.memoize(internalCodec.buildCodeInputBuffer(is)) + cb.if_(ib.readByte() cne 1, cb._fatal("bad buffer at internal!")) + val region = cb.memoize(cb.emb.ecb.pool().invoke[Region.Size, Region]( + "getRegion", + Region.TINIER, + )) + val internalNode = internalDec.apply(cb, region, ib) + val internalNodeAddr = internalPType.store(cb, region, internalNode, false) + cb += cache.invoke[Long, Region, Long, Unit]( + "put", + offset, + region, + internalNodeAddr, + ) + internalNode + } + }, + cb.this_, + offset, + ), + ) + }, + ) ret.asBaseStruct } @@ -383,34 +538,52 @@ class StagedIndexReader(emb: EmitMethodBuilder[_], leafCodec: AbstractTypedCodec // returns an address if cached, or 0L if not found val cached = cb.memoize(cache.invoke[Long, Long]("get", offset)) - cb.if_(cached cne 0L, { - cb.assign(ret, leafPType.loadCheapSCode(cb, cached)) - }, { - cb.assign(ret, cb.invokeSCode(cb.emb.ecb.getOrGenEmitMethod("readLeafNode", ("readLeafNode", this), FastSeq(LongInfo), ret.st.paramType) { emb => - emb.emitSCode { cb => - val offset = emb.getCodeParam[Long](1) - cb += is.invoke[Long, Unit]("seek", offset) - val ib = cb.memoize(leafCodec.buildCodeInputBuffer(is)) - cb.if_(ib.readByte() cne 0, cb._fatal("bad buffer at leaf!")) - val region = cb.memoize(cb.emb.ecb.pool().invoke[Region.Size, Region]("getRegion", Region.TINIER)) - val leafNode = leafDec.apply(cb, region, ib) - val leafNodeAddr = leafPType.store(cb, region, leafNode, false) - cb += cache.invoke[Long, Region, Long, Unit]("put", offset, region, leafNodeAddr) - leafPType.loadCheapSCode(cb, leafNodeAddr) - } - }, offset)) - }) + cb.if_( + cached cne 0L, + cb.assign(ret, leafPType.loadCheapSCode(cb, cached)), { + cb.assign( + ret, + cb.invokeSCode( + cb.emb.ecb.getOrGenEmitMethod( + "readLeafNode", + ("readLeafNode", this), + FastSeq(LongInfo), + ret.st.paramType, + ) { emb => + emb.emitSCode { cb => + val offset = emb.getCodeParam[Long](1) + cb += is.invoke[Long, Unit]("seek", offset) + val ib = cb.memoize(leafCodec.buildCodeInputBuffer(is)) + cb.if_(ib.readByte() cne 0, cb._fatal("bad buffer at leaf!")) + val region = cb.memoize(cb.emb.ecb.pool().invoke[Region.Size, Region]( + "getRegion", + Region.TINIER, + )) + val leafNode = leafDec.apply(cb, region, ib) + val leafNodeAddr = leafPType.store(cb, region, 
leafNode, false) + cb += cache.invoke[Long, Region, Long, Unit]("put", offset, region, leafNodeAddr) + leafPType.loadCheapSCode(cb, leafNodeAddr) + } + }, + cb.this_, + offset, + ), + ) + }, + ) ret.asBaseStruct } - def queryIndex(cb: EmitCodeBuilder, region: Value[Region], absIndex: Value[Long]): SBaseStructValue = { + def queryIndex(cb: EmitCodeBuilder, region: Value[Region], absIndex: Value[Long]) + : SBaseStructValue = { cb.invokeSCode( - cb.emb.ecb.getOrGenEmitMethod("queryIndex", + cb.emb.ecb.getOrGenEmitMethod( + "queryIndex", ("queryIndex", this), FastSeq(classInfo[Region], typeInfo[Long]), - leafChildType.paramType) { emb => + leafChildType.paramType, + ) { emb => emb.emitSCode { cb => - val region = emb.getCodeParam[Region](1) val absIndex = emb.getCodeParam[Long](2) val level = cb.newLocal[Int]("lowerBound_level", metadata.invoke[Int]("height") - 1) @@ -422,25 +595,53 @@ class StagedIndexReader(emb: EmitMethodBuilder[_], leafCodec: AbstractTypedCodec val Lstart = CodeLabel() cb.define(Lstart) - cb.if_(level ceq 0, { - val leafNode = readLeafNode(cb, offset) - val localIdx = cb.memoize((absIndex - leafNode.loadField(cb, "first_idx").get(cb).asInt64.value.toL).toI) - cb.assign(result, leafNode.loadField(cb, "keys").get(cb).asIndexable.loadElement(cb, localIdx).get(cb)) - }, { - val internalNode = readInternalNode(cb, offset) - val children = internalNode.loadField(cb, "children").get(cb).asIndexable - val firstIdx = children.loadElement(cb, 0).get(cb).asBaseStruct.loadField(cb, "first_idx").get(cb).asInt64.value - val nKeysPerChild = cb.memoize(Code.invokeStatic2[java.lang.Math, Double, Double, Double]("pow", - branchingFactor.toD, - level.toD).toL) - val localIdx = cb.memoize((absIndex - firstIdx) / nKeysPerChild) - cb.assign(level, level - 1) - cb.assign(offset, children.loadElement(cb, localIdx.toI).get(cb).asBaseStruct.loadField(cb, "index_file_offset").get(cb).asInt64.value) - cb.goto(Lstart) - }) + cb.if_( + level ceq 0, { + val leafNode = readLeafNode(cb, offset) + val localIdx = cb.memoize( + (absIndex - leafNode.loadField(cb, "first_idx").getOrAssert( + cb + ).asInt64.value.toL).toI + ) + cb.assign( + result, + leafNode.loadField(cb, "keys").getOrAssert(cb).asIndexable.loadElement( + cb, + localIdx, + ).getOrAssert(cb), + ) + }, { + val internalNode = readInternalNode(cb, offset) + val children = internalNode.loadField(cb, "children").getOrAssert(cb).asIndexable + val firstIdx = children.loadElement(cb, 0).getOrAssert(cb).asBaseStruct.loadField( + cb, + "first_idx", + ).getOrAssert(cb).asInt64.value + val nKeysPerChild = + cb.memoize(Code.invokeStatic2[java.lang.Math, Double, Double, Double]( + "pow", + branchingFactor.toD, + level.toD, + ).toL) + val localIdx = cb.memoize((absIndex - firstIdx) / nKeysPerChild) + cb.assign(level, level - 1) + cb.assign( + offset, + children.loadElement(cb, localIdx.toI).getOrAssert(cb).asBaseStruct.loadField( + cb, + "index_file_offset", + ).getOrAssert(cb).asInt64.value, + ) + cb.goto(Lstart) + }, + ) leafChildType.coerceOrCopy(cb, region, result, false) } - }, region, absIndex).asBaseStruct + }, + cb.this_, + region, + absIndex, + ).asBaseStruct } } diff --git a/hail/src/main/scala/is/hail/io/package.scala b/hail/src/main/scala/is/hail/io/package.scala index c17e4691307..51724ed5dff 100644 --- a/hail/src/main/scala/is/hail/io/package.scala +++ b/hail/src/main/scala/is/hail/io/package.scala @@ -1,12 +1,11 @@ package is.hail -import java.io.OutputStreamWriter -import java.nio.charset._ - -import is.hail.asm4s._ +import is.hail.io.fs.FS 
import is.hail.types.virtual.Type import is.hail.utils._ -import is.hail.io.fs.FS + +import java.io.OutputStreamWriter +import java.nio.charset._ package object io { type VCFFieldAttributes = Map[String, String] @@ -15,14 +14,14 @@ package object io { val utfCharset = Charset.forName("UTF-8") - def exportTypes(filename: String, fs: FS, info: Array[(String, Type)]) { + def exportTypes(filename: String, fs: FS, info: Array[(String, Type)]): Unit = { val sb = new StringBuilder using(new OutputStreamWriter(fs.create(filename))) { out => info.foreachBetween { case (name, t) => sb.append(prettyIdentifier(name)) sb.append(":") t.pretty(sb, 0, compact = true) - } { sb += ',' } + }(sb += ',') out.write(sb.result()) } diff --git a/hail/src/main/scala/is/hail/io/plink/ExportPlink.scala b/hail/src/main/scala/is/hail/io/plink/ExportPlink.scala index d39606a5246..8a2a2a3b68e 100644 --- a/hail/src/main/scala/is/hail/io/plink/ExportPlink.scala +++ b/hail/src/main/scala/is/hail/io/plink/ExportPlink.scala @@ -1,15 +1,9 @@ package is.hail.io.plink -import java.io.{OutputStream, OutputStreamWriter} -import is.hail.HailContext -import is.hail.annotations.Region -import is.hail.backend.ExecuteContext -import is.hail.expr.ir.MatrixValue -import is.hail.types._ -import is.hail.types.physical.{PString, PStruct} -import is.hail.variant._ import is.hail.utils._ -import org.apache.spark.TaskContext +import is.hail.variant._ + +import java.io.OutputStream object ExportPlink { val bedHeader = Array[Byte](108, 27, 1) @@ -20,13 +14,21 @@ object ExportPlink { def alleles: Array[String] = Array(a0, a1) if (spaceRegex.findFirstIn(contig).isDefined) - fatal(s"Invalid contig found at '${ VariantMethods.locusAllelesToString(locus, alleles) }' -- no white space allowed: '$contig'") + fatal( + s"Invalid contig found at '${VariantMethods.locusAllelesToString(locus, alleles)}' -- no white space allowed: '$contig'" + ) if (spaceRegex.findFirstIn(a0).isDefined) - fatal(s"Invalid allele found at '${ VariantMethods.locusAllelesToString(locus, alleles) }' -- no white space allowed: '$a0'") + fatal( + s"Invalid allele found at '${VariantMethods.locusAllelesToString(locus, alleles)}' -- no white space allowed: '$a0'" + ) if (spaceRegex.findFirstIn(a1).isDefined) - fatal(s"Invalid allele found at '${ VariantMethods.locusAllelesToString(locus, alleles) }' -- no white space allowed: '$a1'") + fatal( + s"Invalid allele found at '${VariantMethods.locusAllelesToString(locus, alleles)}' -- no white space allowed: '$a1'" + ) if (spaceRegex.findFirstIn(varid).isDefined) - fatal(s"Invalid 'varid' found at '${ VariantMethods.locusAllelesToString(locus, alleles) }' -- no white space allowed: '$varid'") + fatal( + s"Invalid 'varid' found at '${VariantMethods.locusAllelesToString(locus, alleles)}' -- no white space allowed: '$varid'" + ) } } @@ -45,15 +47,14 @@ class BitPacker(nBitsPerItem: Int, os: OutputStream) extends Serializable { def +=(i: Int) = add(i) - private def write() { + private def write(): Unit = while (nBitsStaged >= 8) { os.write(data.toByte) data = data >>> 8 nBitsStaged -= 8 } - } - def flush() { + def flush(): Unit = { if (nBitsStaged > 0) os.write(data.toByte) data = 0L diff --git a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala index 13a9c339954..9c4cd0d5f43 100644 --- a/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala +++ b/hail/src/main/scala/is/hail/io/plink/LoadPlink.scala @@ -12,24 +12,34 @@ import is.hail.rvd.RVDPartitioner import is.hail.types._ import 
is.hail.types.physical._ import is.hail.types.virtual._ -import is.hail.utils.StringEscapeUtils._ import is.hail.utils._ +import is.hail.utils.StringEscapeUtils._ import is.hail.variant._ + import org.apache.spark.TaskContext import org.apache.spark.sql.Row -import org.json4s.jackson.JsonMethods import org.json4s.{DefaultFormats, Formats, JValue} +import org.json4s.jackson.JsonMethods -case class FamFileConfig(isQuantPheno: Boolean = false, +case class FamFileConfig( + isQuantPheno: Boolean = false, delimiter: String = "\\t", - missingValue: String = "NA") + missingValue: String = "NA", +) object LoadPlink { def expectedBedSize(nSamples: Int, nVariants: Long): Long = 3 + nVariants * ((nSamples + 3) / 4) - def parseBim(ctx: ExecuteContext, fs: FS, bimPath: String, a2Reference: Boolean, - contigRecoding: Map[String, String], rg: Option[ReferenceGenome], locusAllelesType: TStruct, - skipInvalidLoci: Boolean): (Int, Array[PlinkVariant]) = { + def parseBim( + ctx: ExecuteContext, + fs: FS, + bimPath: String, + a2Reference: Boolean, + contigRecoding: Map[String, String], + rg: Option[ReferenceGenome], + locusAllelesType: TStruct, + skipInvalidLoci: Boolean, + ): (Int, Array[PlinkVariant]) = { val vs = new BoxedArrayBuilder[PlinkVariant]() var n = 0 fs.readLines(bimPath) { lines => @@ -51,7 +61,9 @@ object LoadPlink { } case _ => - fatal(s"Invalid .bim line. Expected 6 fields, found ${ line.length } ${ plural(line.length, "field") }") + fatal( + s"Invalid .bim line. Expected 6 fields, found ${line.length} ${plural(line.length, "field")}" + ) } } n += 1 @@ -64,33 +76,47 @@ object LoadPlink { val numericRegex = """^-?(?:\d+|\d*\.\d+)(?:[eE]-?\d+)?$""".r - def importFamJSON(fs: FS, path: String, isQuantPheno: Boolean, delimiter: String, missingValue: String): String = { + def importFamJSON( + fs: FS, + path: String, + isQuantPheno: Boolean, + delimiter: String, + missingValue: String, + ): String = { val ffConfig = FamFileConfig(isQuantPheno, delimiter, missingValue) val (data, ptyp) = LoadPlink.parseFam(fs, path, ffConfig) val jv = JSONAnnotationImpex.exportAnnotation( Row(ptyp.virtualType.toString, data), - TStruct("type" -> TString, "data" -> TArray(ptyp.virtualType))) + TStruct("type" -> TString, "data" -> TArray(ptyp.virtualType)), + ) JsonMethods.compact(jv) } - - def parseFam(fs: FS, filename: String, ffConfig: FamFileConfig): (IndexedSeq[Row], PCanonicalStruct) = { + def parseFam(fs: FS, filename: String, ffConfig: FamFileConfig) + : (IndexedSeq[Row], PCanonicalStruct) = { val delimiter = unescapeString(ffConfig.delimiter) - val phenoSig = if (ffConfig.isQuantPheno) ("quant_pheno", PFloat64()) else ("is_case", PBoolean()) + val phenoSig = + if (ffConfig.isQuantPheno) ("quant_pheno", PFloat64()) else ("is_case", PBoolean()) - val signature = PCanonicalStruct(("id", PCanonicalString()), ("fam_id", PCanonicalString()), ("pat_id", PCanonicalString()), - ("mat_id", PCanonicalString()), ("is_female", PBoolean()), phenoSig) + val signature = PCanonicalStruct( + ("id", PCanonicalString()), + ("fam_id", PCanonicalString()), + ("pat_id", PCanonicalString()), + ("mat_id", PCanonicalString()), + ("is_female", PBoolean()), + phenoSig, + ) val idBuilder = new BoxedArrayBuilder[String] val structBuilder = new BoxedArrayBuilder[Row] - val m = fs.readLines(filename) { + fs.readLines(filename) { _.foreachLine { line => val split = line.split(delimiter) if (split.length != 6) - fatal(s"expected 6 fields, but found ${ split.length }") + fatal(s"expected 6 fields, but found ${split.length}") val Array(fam, kid, 
dad, mom, isFemale, pheno) = split val fam1 = if (fam != "0") fam else null @@ -115,12 +141,15 @@ object LoadPlink { if (!warnedAbout9) { warn( s"""Interpreting value '-9' as a valid quantitative phenotype, which differs from default PLINK behavior. - | Use missing='-9' to interpret '-9' as a missing value.""".stripMargin) + | Use missing='-9' to interpret '-9' as a missing value.""".stripMargin + ) warnedAbout9 = true } -9d case numericRegex() => pheno.toDouble - case _ => fatal(s"Invalid quantitative phenotype: '$pheno'. Value must be numeric or '${ ffConfig.missingValue }'") + case _ => fatal( + s"Invalid quantitative phenotype: '$pheno'. Value must be numeric or '${ffConfig.missingValue}'" + ) } else pheno match { @@ -130,7 +159,9 @@ object LoadPlink { case "0" => null case "-9" => null case "N/A" => null - case numericRegex() => fatal(s"Invalid case-control phenotype: '$pheno'. Control is '1', case is '2', missing is '0', '-9', '${ ffConfig.missingValue }', or non-numeric.") + case numericRegex() => fatal( + s"Invalid case-control phenotype: '$pheno'. Control is '1', case is '2', missing is '0', '-9', '${ffConfig.missingValue}', or non-numeric." + ) case _ => null } idBuilder += kid @@ -150,7 +181,6 @@ object LoadPlink { object MatrixPLINKReader { def fromJValue(ctx: ExecuteContext, jv: JValue): MatrixPLINKReader = { - val backend = ctx.backend val fs = ctx.fs implicit val formats: Formats = DefaultFormats @@ -162,21 +192,32 @@ object MatrixPLINKReader { val locusType = TLocus.schemaFromRG(params.rg) val locusAllelesType = TStruct( "locus" -> locusType, - "alleles" -> TArray(TString)) + "alleles" -> TArray(TString), + ) val ffConfig = FamFileConfig(params.quantPheno, params.delimiter, params.missing) val (sampleInfo, signature) = LoadPlink.parseFam(fs, params.fam, ffConfig) val nameMap = Map("id" -> "s") - val saSignature = signature.copy(fields = signature.fields.map(f => f.copy(name = nameMap.getOrElse(f.name, f.name)))) + val saSignature = signature.copy(fields = + signature.fields.map(f => f.copy(name = nameMap.getOrElse(f.name, f.name))) + ) val nSamples = sampleInfo.length if (nSamples <= 0) fatal("FAM file does not contain any samples") - val (nTotalVariants, variants) = LoadPlink.parseBim(ctx, fs, params.bim, params.a2Reference, params.contigRecoding, - referenceGenome, locusAllelesType, params.skipInvalidLoci) + val (nTotalVariants, variants) = LoadPlink.parseBim( + ctx, + fs, + params.bim, + params.a2Reference, + params.contigRecoding, + referenceGenome, + locusAllelesType, + params.skipInvalidLoci, + ) val nVariants = variants.length if (nTotalVariants <= 0) fatal("BIM file does not contain any variants") @@ -193,7 +234,9 @@ object MatrixPLINKReader { fatal("First two bytes of BED file do not match PLINK magic numbers 108 & 27") if (b3 == 0) - fatal("BED file is in individual major mode. First use plink with --make-bed to convert file to snp major mode before using Hail") + fatal( + "BED file is in individual major mode. 
First use plink with --make-bed to convert file to snp major mode before using Hail" + ) } val bedSize = fs.getFileSize(params.bed) @@ -230,9 +273,13 @@ object MatrixPLINKReader { var end = partScan(p + 1) if (start < end) { - while (end < nVariants - && lOrd.equiv(variants(end - 1).locusAlleles.asInstanceOf[Row].get(0), - variants(end).locusAlleles.asInstanceOf[Row].get(0))) + while ( + end < nVariants + && lOrd.equiv( + variants(end - 1).locusAlleles.asInstanceOf[Row].get(0), + variants(end).locusAlleles.asInstanceOf[Row].get(0), + ) + ) end += 1 cb += Row(params.bed, start, end) @@ -240,7 +287,9 @@ object MatrixPLINKReader { ib += Interval( variants(start).locusAlleles, variants(end - 1).locusAlleles, - includesStart = true, includesEnd = true) + includesStart = true, + includesEnd = true, + ) prevEnd = end } @@ -261,12 +310,15 @@ object MatrixPLINKReader { "locus" -> locusType, "alleles" -> TArray(TString), "rsid" -> TString, - "cm_position" -> TFloat64), + "cm_position" -> TFloat64, + ), rowKey = Array("locus", "alleles"), - entryType = TStruct("GT" -> TCall)) + entryType = TStruct("GT" -> TCall), + ) assert(locusAllelesType == fullMatrixType.rowKeyStruct) - new MatrixPLINKReader(params, referenceGenome, fullMatrixType, sampleInfo, variants, contexts, partitioner) + new MatrixPLINKReader(params, referenceGenome, fullMatrixType, sampleInfo, variants, contexts, + partitioner) } } @@ -283,13 +335,14 @@ case class MatrixPLINKReaderParameters( a2Reference: Boolean, rg: Option[String], contigRecoding: Map[String, String], - skipInvalidLoci: Boolean) + skipInvalidLoci: Boolean, +) class PlinkVariant( val index: Int, val locusAlleles: Any, val cmPos: Double, - val rsid: String + val rsid: String, ) extends Serializable class MatrixPLINKReader( @@ -299,7 +352,7 @@ class MatrixPLINKReader( sampleInfo: IndexedSeq[Row], variants: Array[PlinkVariant], contexts: Array[Any], - partitioner: RVDPartitioner + partitioner: RVDPartitioner, ) extends MatrixHybridReader { def rowUIDType = TInt64 @@ -316,16 +369,18 @@ class MatrixPLINKReader( val partitionCounts: Option[IndexedSeq[Long]] = None val globals = Row(sampleInfo.zipWithIndex.map { case (s, idx) => - Row((0 until s.length).map(s.apply) :+ idx.toLong :_*) + Row((0 until s.length).map(s.apply) :+ idx.toLong: _*) }) - override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = + override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = VirtualTypeWithReq(PType.canonical(requestedType.rowType).setRequired(true)) override def uidRequiredness: VirtualTypeWithReq = VirtualTypeWithReq(PInt64Required) - override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = + override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = VirtualTypeWithReq(PType.canonical(requestedType.globalType)) def executeGeneric(ctx: ExecuteContext): GenericTableValue = { @@ -340,21 +395,28 @@ class MatrixPLINKReader( "bed" -> TString, "start" -> TInt32, "end" -> TInt32, - "partitionIndex" -> TInt32) + "partitionIndex" -> TInt32, + ) val contextsWithPartIdx = contexts.zipWithIndex.map { case (row: Row, partIdx: Int) => Row(row(0), row(1), row(2), partIdx) } - val fullRowPType = PCanonicalStruct(true, + val fullRowPType = PCanonicalStruct( + true, "locus" -> PCanonicalLocus.schemaFromRG(referenceGenome.map(_.name), true), "alleles" -> PCanonicalArray(PCanonicalString(true), true), "rsid" -> 
PCanonicalString(true), "cm_position" -> PFloat64(true), - LowerMatrixIR.entriesFieldName -> PCanonicalArray(PCanonicalStruct(true, "GT" -> PCanonicalCall()), true), - rowUIDFieldName -> PInt64Required) + LowerMatrixIR.entriesFieldName -> PCanonicalArray( + PCanonicalStruct(true, "GT" -> PCanonicalCall()), + true, + ), + rowUIDFieldName -> PInt64Required, + ) - val bodyPType = (requestedRowType: TStruct) => fullRowPType.subsetTo(requestedRowType).asInstanceOf[PStruct] + val bodyPType = + (requestedRowType: TStruct) => fullRowPType.subsetTo(requestedRowType).asInstanceOf[PStruct] val body = { (requestedType: TStruct) => val hasLocus = requestedType.hasField("locus") @@ -364,8 +426,9 @@ class MatrixPLINKReader( val hasRowUID = requestedType.hasField(rowUIDFieldName) val hasEntries = requestedType.hasField(LowerMatrixIR.entriesFieldName) - val hasGT = hasEntries && (requestedType.fieldType(LowerMatrixIR.entriesFieldName).asInstanceOf[TArray] - .elementType.asInstanceOf[TStruct].hasField("GT")) + val hasGT = + hasEntries && (requestedType.fieldType(LowerMatrixIR.entriesFieldName).asInstanceOf[TArray] + .elementType.asInstanceOf[TStruct].hasField("GT")) val requestedPType = bodyPType(requestedType) @@ -382,19 +445,19 @@ class MatrixPLINKReader( val is = fs.open(bed) if (TaskContext.get != null) { // FIXME: need to close InputStream for other backends too - TaskContext.get.addTaskCompletionListener[Unit] { (context: TaskContext) => - is.close() - } + TaskContext.get.addTaskCompletionListener[Unit]((context: TaskContext) => is.close()) } var offset: Long = 0 val input = new Array[Byte](blockLength) val table = new Array[Int](4) - table(0) = if (localA2Reference) Call2.fromUnphasedDiploidGtIndex(2) else Call2.fromUnphasedDiploidGtIndex(0) + table(0) = if (localA2Reference) Call2.fromUnphasedDiploidGtIndex(2) + else Call2.fromUnphasedDiploidGtIndex(0) // 1 missing table(2) = Call2.fromUnphasedDiploidGtIndex(1) - table(3) = if (localA2Reference) Call2.fromUnphasedDiploidGtIndex(0) else Call2.fromUnphasedDiploidGtIndex(2) + table(3) = if (localA2Reference) Call2.fromUnphasedDiploidGtIndex(0) + else Call2.fromUnphasedDiploidGtIndex(2) Iterator.range(start, end).flatMap { i => val variant = variantsBc.value(i) @@ -482,7 +545,8 @@ class MatrixPLINKReader( contextType, contextsWithPartIdx, bodyPType, - body) + body, + ) } override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = { diff --git a/hail/src/main/scala/is/hail/io/reference/FASTAReader.scala b/hail/src/main/scala/is/hail/io/reference/FASTAReader.scala index 193a661680f..d3e2cf65473 100644 --- a/hail/src/main/scala/is/hail/io/reference/FASTAReader.scala +++ b/hail/src/main/scala/is/hail/io/reference/FASTAReader.scala @@ -1,22 +1,30 @@ package is.hail.io.reference -import java.util -import java.util.Map.Entry -import java.util.concurrent.locks.{Lock, ReentrantLock} -import htsjdk.samtools.reference.{ReferenceSequenceFile, ReferenceSequenceFileFactory} -import is.hail.backend.{BroadcastValue, ExecuteContext} +import is.hail.backend.ExecuteContext +import is.hail.io.fs.FS import is.hail.utils._ import is.hail.variant.{Locus, ReferenceGenome} -import is.hail.io.fs.FS -import scala.language.postfixOps import scala.collection.concurrent -case class FASTAReaderConfig(tmpdir: String, fs: FS, rg: ReferenceGenome, - fastaFile: String, indexFile: String, blockSize: Int = 4096, capacity: Int = 100 +import java.util +import java.util.Map.Entry +import java.util.concurrent.locks.{Lock, ReentrantLock} + +import 
htsjdk.samtools.reference.{ReferenceSequenceFile, ReferenceSequenceFileFactory} + +case class FASTAReaderConfig( + tmpdir: String, + fs: FS, + rg: ReferenceGenome, + fastaFile: String, + indexFile: String, + blockSize: Int = 4096, + capacity: Int = 100, ) { if (blockSize <= 0) fatal(s"'blockSize' must be greater than 0. Found $blockSize.") + if (capacity <= 0) fatal(s"'capacity' must be greater than 0. Found $capacity.") @@ -29,11 +37,13 @@ object FASTAReader { def getLocalFastaFile(tmpdir: String, fs: FS, fastaFile: String, indexFile: String): String = { localFastaLock.lock() - try { - localFastaFiles.getOrElseUpdate(fastaFile, FASTAReader.setup(tmpdir, fs, fastaFile, indexFile)) - } finally { + try + localFastaFiles.getOrElseUpdate( + fastaFile, + FASTAReader.setup(tmpdir, fs, fastaFile, indexFile), + ) + finally localFastaLock.unlock() - } } def setup(tmpdir: String, fs: FS, fastaFile: String, indexFile: String): String = { @@ -51,10 +61,12 @@ object FASTAReader { fs.copyRecode(indexFile, localIndexFile) } - if (!fs.exists(localFastaFile)) + if (!fs.isFile(localFastaFile)) fatal(s"Error while copying FASTA file to local file system. Did not find '$localFastaFile'.") - if (!fs.exists(localIndexFile)) - fatal(s"Error while copying FASTA index file to local file system. Did not find '$localIndexFile'.") + if (!fs.isFile(localIndexFile)) + fatal( + s"Error while copying FASTA index file to local file system. Did not find '$localIndexFile'." + ) localFastaFile } @@ -70,19 +82,20 @@ class FASTAReader(val cfg: FASTAReaderConfig) { private[this] var reader: ReferenceSequenceFile = newReader() - @transient private[this] lazy val cache = new util.LinkedHashMap[Int, String](capacity, 0.75f, true) { - override def removeEldestEntry(eldest: Entry[Int, String]): Boolean = size() > capacity - } + @transient private[this] lazy val cache = + new util.LinkedHashMap[Int, String](capacity, 0.75f, true) { + override def removeEldestEntry(eldest: Entry[Int, String]): Boolean = size() > capacity + } private def hash(pos: Long): Int = (pos / blockSize).toInt private def getSequence(contig: String, start: Int, end: Int): String = { val maxEnd = rg.contigLength(contig) - try { + try reader.getSubsequenceAt(contig, start, if (end > maxEnd) maxEnd else end).getBaseString - } catch { + catch { // One retry, to refresh the file - case e: htsjdk.samtools.SAMException => + case _: htsjdk.samtools.SAMException => reader = newReader() reader.getSubsequenceAt(contig, start, if (end > maxEnd) maxEnd else end).getBaseString } diff --git a/hail/src/main/scala/is/hail/io/reference/LiftOver.scala b/hail/src/main/scala/is/hail/io/reference/LiftOver.scala index 7dc2f403f62..c24f81f6d5b 100644 --- a/hail/src/main/scala/is/hail/io/reference/LiftOver.scala +++ b/hail/src/main/scala/is/hail/io/reference/LiftOver.scala @@ -1,16 +1,10 @@ package is.hail.io.reference -import java.io.File -import java.net.URI - -import is.hail.backend.ExecuteContext -import is.hail.variant.{Locus, ReferenceGenome} -import is.hail.utils._ import is.hail.io.fs.FS +import is.hail.utils._ +import is.hail.variant.{Locus, ReferenceGenome} import scala.collection.JavaConverters._ -import scala.collection.concurrent -import scala.language.implicitConversions object LiftOver { def apply(fs: FS, chainFile: String): LiftOver = new LiftOver(fs, chainFile) @@ -19,7 +13,10 @@ object LiftOver { class LiftOver(fs: FS, val chainFile: String) { val lo = using(fs.open(chainFile))(new htsjdk.samtools.liftover.LiftOver(_, chainFile)) - def queryInterval(interval: 
is.hail.utils.Interval, minMatch: Double = htsjdk.samtools.liftover.LiftOver.DEFAULT_LIFTOVER_MINMATCH): (is.hail.utils.Interval, Boolean) = { + def queryInterval( + interval: is.hail.utils.Interval, + minMatch: Double = htsjdk.samtools.liftover.LiftOver.DEFAULT_LIFTOVER_MINMATCH, + ): (is.hail.utils.Interval, Boolean) = { val start = interval.start.asInstanceOf[Locus] val end = interval.end.asInstanceOf[Locus] @@ -31,29 +28,38 @@ class LiftOver(fs: FS, val chainFile: String) { val endPos = if (interval.includesEnd) end.position else end.position - 1 if (startPos == endPos) - fatal(s"Cannot liftover a 0-length interval: ${ interval.toString }.\nDid you mean to use 'liftover_locus'?") + fatal( + s"Cannot liftover a 0-length interval: ${interval.toString}.\nDid you mean to use 'liftover_locus'?" + ) val result = lo.liftOver(new htsjdk.samtools.util.Interval(contig, startPos, endPos), minMatch) if (result != null) - (Interval( - Locus(result.getContig, result.getStart), - Locus(result.getContig, result.getEnd), - includesStart = true, - includesEnd = true), - result.isNegativeStrand) + ( + Interval( + Locus(result.getContig, result.getStart), + Locus(result.getContig, result.getEnd), + includesStart = true, + includesEnd = true, + ), + result.isNegativeStrand, + ) else null } - def queryLocus(l: Locus, minMatch: Double = htsjdk.samtools.liftover.LiftOver.DEFAULT_LIFTOVER_MINMATCH): (Locus, Boolean) = { - val result = lo.liftOver(new htsjdk.samtools.util.Interval(l.contig, l.position, l.position), minMatch) + def queryLocus( + l: Locus, + minMatch: Double = htsjdk.samtools.liftover.LiftOver.DEFAULT_LIFTOVER_MINMATCH, + ): (Locus, Boolean) = { + val result = + lo.liftOver(new htsjdk.samtools.util.Interval(l.contig, l.position, l.position), minMatch) if (result != null) (Locus(result.getContig, result.getStart), result.isNegativeStrand) else null } - def checkChainFile(srcRG: ReferenceGenome, destRG: ReferenceGenome) { + def checkChainFile(srcRG: ReferenceGenome, destRG: ReferenceGenome): Unit = { val cMap = lo.getContigMap.asScala cMap.foreach { case (srcContig, destContigs) => srcRG.checkContig(srcContig) diff --git a/hail/src/main/scala/is/hail/io/tabix/TabixReader.scala b/hail/src/main/scala/is/hail/io/tabix/TabixReader.scala index d6e28608713..ee1f09d56fd 100644 --- a/hail/src/main/scala/is/hail/io/tabix/TabixReader.scala +++ b/hail/src/main/scala/is/hail/io/tabix/TabixReader.scala @@ -1,9 +1,5 @@ package is.hail.io.tabix -import java.io.InputStream - -import htsjdk.samtools.util.FileExtensions -import htsjdk.tribble.util.ParsingUtils import is.hail.expr.ir.IntArrayBuilder import is.hail.io.compress.BGzipLineReader import is.hail.io.fs.FS @@ -12,6 +8,11 @@ import is.hail.utils._ import scala.collection.mutable import scala.language.implicitConversions +import java.io.InputStream + +import htsjdk.samtools.util.FileExtensions +import htsjdk.tribble.util.ParsingUtils + // Helper data classes class Tabix( @@ -21,7 +22,7 @@ class Tabix( val meta: Int, val seqs: Array[String], val chr2tid: mutable.HashMap[String, Int], - val indices: Array[(mutable.HashMap[Int, Array[TbiPair]], Array[Long])] + val indices: Array[(mutable.HashMap[Int, Array[TbiPair]], Array[Long])], ) case class TbiPair(var _1: Long, var _2: Long) extends java.lang.Comparable[TbiPair] { @@ -56,19 +57,19 @@ object TabixReader { def readInt(is: InputStream): Int = (is.read() & 0xff) | - ((is.read() & 0xff) << 8) | - ((is.read() & 0xff) << 16) | - ((is.read() & 0xff) << 24) + ((is.read() & 0xff) << 8) | + ((is.read() & 0xff) << 16) | 
+ ((is.read() & 0xff) << 24) def readLong(is: InputStream): Long = (is.read() & 0xff).asInstanceOf[Long] | - ((is.read() & 0xff).asInstanceOf[Long] << 8) | - ((is.read() & 0xff).asInstanceOf[Long] << 16) | - ((is.read() & 0xff).asInstanceOf[Long] << 24) | - ((is.read() & 0xff).asInstanceOf[Long] << 32) | - ((is.read() & 0xff).asInstanceOf[Long] << 40) | - ((is.read() & 0xff).asInstanceOf[Long] << 48) | - ((is.read() & 0xff).asInstanceOf[Long] << 56) + ((is.read() & 0xff).asInstanceOf[Long] << 8) | + ((is.read() & 0xff).asInstanceOf[Long] << 16) | + ((is.read() & 0xff).asInstanceOf[Long] << 24) | + ((is.read() & 0xff).asInstanceOf[Long] << 32) | + ((is.read() & 0xff).asInstanceOf[Long] << 40) | + ((is.read() & 0xff).asInstanceOf[Long] << 48) | + ((is.read() & 0xff).asInstanceOf[Long] << 56) } class TabixReader(val filePath: String, fs: FS, idxFilePath: Option[String] = None) { @@ -88,20 +89,20 @@ class TabixReader(val filePath: String, fs: FS, idxFilePath: Option[String] = No is.read(buf, 0, 4) // read magic bytes "TBI\1" if (!(Magic sameElements buf)) fatal(s"""magic number failed validation - |magic: ${ Magic.mkString("[", ",", "]") } - |data : ${ buf.mkString("[", ",", "]") }""".stripMargin) + |magic: ${Magic.mkString("[", ",", "]")} + |data : ${buf.mkString("[", ",", "]")}""".stripMargin) val seqs = new Array[String](readInt(is)) val format = readInt(is) // Require VCF for now if (format != 2) - fatal(s"Hail only supports tabix indexing for VCF, found format code ${ format }") + fatal(s"Hail only supports tabix indexing for VCF, found format code $format") val colSeq = readInt(is) val colBeg = readInt(is) - val colEnd = readInt(is) + readInt(is) // colEnd val meta = readInt(is) // meta char for VCF is '#' if (meta != '#') - fatal(s"Meta character was ${ meta }, should be '#' for VCF") + fatal(s"Meta character was $meta, should be '#' for VCF") val chr2tid = new mutable.HashMap[String, Int]() readInt(is) // unused, need to consume @@ -121,7 +122,8 @@ class TabixReader(val filePath: String, fs: FS, idxFilePath: Option[String] = No } // read the index - val indices = new BoxedArrayBuilder[(mutable.HashMap[Int, Array[TbiPair]], Array[Long])](seqs.length) + val indices = + new BoxedArrayBuilder[(mutable.HashMap[Int, Array[TbiPair]], Array[Long])](seqs.length) i = 0 while (i < seqs.length) { // binning index @@ -154,9 +156,9 @@ class TabixReader(val filePath: String, fs: FS, idxFilePath: Option[String] = No } def chr2tid(chr: String): Int = index.chr2tid.get(chr) match { - case Some(i) => i - case _ => -1 - } + case Some(i) => i + case _ => -1 + } // This method returns an array of tuples suitable to be passed to the constructor of // TabixLineIterator. The arguments beg and end are endpoints to an interval of loci within tid. 
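Aside (not part of the patch): the readInt/readLong helpers in the hunk above assemble little-endian integers one byte at a time from the tabix index stream, relying on left-to-right evaluation of the `|` operands. A minimal standalone sketch of the same decoding, under the assumption of a made-up object name and example input, could look like this:

import java.io.{ByteArrayInputStream, InputStream}

object LittleEndianSketch {
  // Same byte assembly as TabixReader.readInt above: least-significant byte first.
  def readInt(is: InputStream): Int =
    (is.read() & 0xff) |
      ((is.read() & 0xff) << 8) |
      ((is.read() & 0xff) << 16) |
      ((is.read() & 0xff) << 24)

  def main(args: Array[String]): Unit = {
    // Bytes 0x01 0x02 0x03 0x04 decode little-endian to 0x04030201.
    val is = new ByteArrayInputStream(Array[Byte](1, 2, 3, 4))
    assert(readInt(is) == 0x04030201)
  }
}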
@@ -169,11 +171,11 @@ class TabixReader(val filePath: String, fs: FS, idxFilePath: Option[String] = No val idx = index.indices(tid) val bins = reg2bins(beg, end) val minOff = if (idx._2.length > 0 && (beg >> TadLidxShift) >= idx._2.length) - idx._2(idx._2.length - 1) - else if (idx._2.length > 0) - idx._2(beg >> TadLidxShift) - else - 0L + idx._2(idx._2.length - 1) + else if (idx._2.length > 0) + idx._2(beg >> TadLidxShift) + else + 0L var i = 0 var nOff = 0 @@ -189,7 +191,8 @@ class TabixReader(val filePath: String, fs: FS, idxFilePath: Option[String] = No i = 0 while (i < bins.length) { val c = idx._1.getOrElse(bins(i), null) - val len = if (c == null) { 0 } else { c.length } + val len = if (c == null) { 0 } + else { c.length } var j = 0 while (j < len) { if (TbiOrd.less64(minOff, c(j)._2)) { @@ -295,10 +298,8 @@ class TabixReader(val filePath: String, fs: FS, idxFilePath: Option[String] = No final class TabixLineIterator( private val fs: FS, private val filePath: String, - private val offsets: Array[TbiPair] -) - extends java.lang.AutoCloseable -{ + private val offsets: Array[TbiPair], +) extends java.lang.AutoCloseable { private var i: Int = -1 private var isEof = false private var lines = new BGzipLineReader(fs, filePath) @@ -331,10 +332,9 @@ final class TabixLineIterator( def getCurIdx(): Long = offsetOfPreviousLine - override def close() { + override def close(): Unit = if (lines != null) { lines.close() lines = null } - } } diff --git a/hail/src/main/scala/is/hail/io/vcf/ExportVCF.scala b/hail/src/main/scala/is/hail/io/vcf/ExportVCF.scala index 68e9c0e5c21..b1f63499d26 100644 --- a/hail/src/main/scala/is/hail/io/vcf/ExportVCF.scala +++ b/hail/src/main/scala/is/hail/io/vcf/ExportVCF.scala @@ -1,20 +1,16 @@ package is.hail.io.vcf -import htsjdk.samtools.util.FileExtensions -import htsjdk.tribble.SimpleFeature -import htsjdk.tribble.index.tabix.{TabixFormat, TabixIndexCreator} -import is.hail -import is.hail.annotations.Region -import is.hail.backend.ExecuteContext -import is.hail.expr.ir.MatrixValue +import is.hail.io.{VCFAttributes, VCFFieldAttributes, VCFMetadata} import is.hail.io.compress.{BGzipLineReader, BGzipOutputStream} import is.hail.io.fs.FS -import is.hail.io.{VCFAttributes, VCFFieldAttributes, VCFMetadata} -import is.hail.types.MatrixType -import is.hail.types.physical._ import is.hail.types.virtual._ import is.hail.utils._ -import is.hail.variant.{Call, ReferenceGenome, RegionValueVariant} +import is.hail.variant.ReferenceGenome + +import htsjdk.samtools.util.FileExtensions +import htsjdk.tribble.SimpleFeature +import htsjdk.tribble.index.tabix.{TabixFormat, TabixIndexCreator} +import is.hail object ExportVCF { def infoNumber(t: Type): String = t match { @@ -44,7 +40,7 @@ object ExportVCF { } tOption match { case Some(s) => s - case _ => fatal(s"INFO field '${ f.name }': VCF does not support type '${ f.typ }'.") + case _ => fatal(s"INFO field '${f.name}': VCF does not support type '${f.typ}'.") } } @@ -81,7 +77,7 @@ object ExportVCF { } } - def checkInfoSignature(ti: TStruct) { + def checkInfoSignature(ti: TStruct): Unit = { val invalid = ti.fields.flatMap { fd => val valid = fd.typ match { case it: TContainer if it.elementType != TBoolean => validInfoType(it.elementType) @@ -90,11 +86,15 @@ object ExportVCF { if (valid) { None } else { - Some(s"\t'${ fd.name }': '${ fd.typ }'.") + Some(s"\t'${fd.name}': '${fd.typ}'.") } } if (!invalid.isEmpty) { - fatal("VCF does not support the type(s) for the following INFO field(s):\n" + invalid.mkString("\n")) + fatal( + "VCF does 
not support the type(s) for the following INFO field(s):\n" + invalid.mkString( + "\n" + ) + ) } } @@ -110,7 +110,7 @@ object ExportVCF { } } - def checkFormatSignature(tg: TStruct) { + def checkFormatSignature(tg: TStruct): Unit = { val invalid = tg.fields.flatMap { fd => val valid = fd.typ match { case it: TContainer => validFormatType(it.elementType) @@ -119,26 +119,37 @@ object ExportVCF { if (valid) { None } else { - Some(s"\t'${ fd.name }': '${ fd.typ }'.") + Some(s"\t'${fd.name}': '${fd.typ}'.") } } if (!invalid.isEmpty) { - fatal("VCF does not support the type(s) for the following FORMAT field(s):\n" + invalid.mkString("\n")) + fatal( + "VCF does not support the type(s) for the following FORMAT field(s):\n" + invalid.mkString( + "\n" + ) + ) } } def getAttributes(k1: String, attributes: Option[VCFMetadata]): Option[VCFAttributes] = attributes.flatMap(_.get(k1)) - def getAttributes(k1: String, k2: String, attributes: Option[VCFMetadata]): Option[VCFFieldAttributes] = + def getAttributes(k1: String, k2: String, attributes: Option[VCFMetadata]) + : Option[VCFFieldAttributes] = getAttributes(k1, attributes).flatMap(_.get(k2)) - def makeHeader(rowType: TStruct, entryType: TStruct, rg: ReferenceGenome, append: Option[String], - metadata: Option[VCFMetadata], sampleIds: Array[String]): String = { + def makeHeader( + rowType: TStruct, + entryType: TStruct, + rg: ReferenceGenome, + append: Option[String], + metadata: Option[VCFMetadata], + sampleIds: Array[String], + ): String = { val sb = new StringBuilder() sb.append("##fileformat=VCFv4.2\n") - sb.append(s"##hailversion=${ hail.HAIL_PRETTY_VERSION }\n") + sb.append(s"##hailversion=${hail.HAIL_PRETTY_VERSION}\n") entryType.fields.foreach { f => val attrs = getAttributes("format", f.name, metadata).getOrElse(Map.empty[String, String]) @@ -153,7 +164,8 @@ object ExportVCF { sb.append("\">\n") } - val filters = getAttributes("filter", metadata).getOrElse(Map.empty[String, Any]).keys.toArray.sorted + val filters = + getAttributes("filter", metadata).getOrElse(Map.empty[String, Any]).keys.toArray.sorted filters.foreach { id => val attrs = getAttributes("filter", id, metadata).getOrElse(Map.empty[String, String]) sb.append("##FILTER=\n") } - append.foreach { append => - sb.append(append) - } + append.foreach(append => sb.append(append)) val assembly = rg.name rg.contigs.foreachBetween { c => @@ -211,14 +221,19 @@ object ExportVCF { sb.result() } - def lookupVAField(rowType: TStruct, fieldName: String, vcfColName: String, expectedTypeOpt: Option[Type]): (Boolean, Int) = { + def lookupVAField( + rowType: TStruct, + fieldName: String, + vcfColName: String, + expectedTypeOpt: Option[Type], + ): (Boolean, Int) = { rowType.fieldIdx.get(fieldName) match { case Some(idx) => val t = rowType.types(idx) if (expectedTypeOpt.forall(t == _)) // FIXME: make sure this is right (true, idx) else { - warn(s"export_vcf found row field $fieldName with type '$t', but expected type ${ expectedTypeOpt.get }. " + + warn(s"export_vcf found row field $fieldName with type '$t', but expected type ${expectedTypeOpt.get}. 
" + s"Emitting missing $vcfColName.") (false, 0) } @@ -228,28 +243,29 @@ object ExportVCF { } object TabixVCF { - def apply(fs: FS, filePath: String) { - val idx = using (new BGzipLineReader(fs, filePath)) { lines => - val tabix = new TabixIndexCreator(TabixFormat.VCF) - var fileOffset = lines.getVirtualOffset - var s = lines.readLine() - while (s != null) { - if (s.nonEmpty && s.charAt(0) != '#') { - val Array(chrom, posStr, _*) = s.split("\t", 3) - val pos = posStr.toInt - val feature = new SimpleFeature(chrom, pos, pos) - tabix.addFeature(feature, fileOffset) - } - - fileOffset = lines.getVirtualOffset - s = lines.readLine() - } - - tabix.finalizeIndex(fileOffset) - } - val tabixPath = htsjdk.tribble.util.ParsingUtils.appendToPath(filePath, FileExtensions.TABIX_INDEX) - using (new BGzipOutputStream(fs.createNoCompression(tabixPath))) { bgzos => - using (new htsjdk.tribble.util.LittleEndianOutputStream(bgzos)) { os => idx.write(os) } - } - } + def apply(fs: FS, filePath: String): Unit = { + val idx = using(new BGzipLineReader(fs, filePath)) { lines => + val tabix = new TabixIndexCreator(TabixFormat.VCF) + var fileOffset = lines.getVirtualOffset + var s = lines.readLine() + while (s != null) { + if (s.nonEmpty && s.charAt(0) != '#') { + val Array(chrom, posStr, _*) = s.split("\t", 3) + val pos = posStr.toInt + val feature = new SimpleFeature(chrom, pos, pos) + tabix.addFeature(feature, fileOffset) + } + + fileOffset = lines.getVirtualOffset + s = lines.readLine() + } + + tabix.finalizeIndex(fileOffset) + } + val tabixPath = + htsjdk.tribble.util.ParsingUtils.appendToPath(filePath, FileExtensions.TABIX_INDEX) + using(new BGzipOutputStream(fs.createNoCompression(tabixPath))) { bgzos => + using(new htsjdk.tribble.util.LittleEndianOutputStream(bgzos))(os => idx.write(os)) + } + } } diff --git a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala index 73a4c3ccf3f..2751cf58485 100644 --- a/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala +++ b/hail/src/main/scala/is/hail/io/vcf/LoadVCF.scala @@ -1,18 +1,21 @@ package is.hail.io.vcf -import htsjdk.variant.vcf._ import is.hail.annotations._ import is.hail.asm4s._ -import is.hail.backend.spark.SparkBackend import is.hail.backend.{BroadcastValue, ExecuteContext, HailStateManager} +import is.hail.backend.spark.SparkBackend import is.hail.expr.JSONAnnotationImpex +import is.hail.expr.ir.{ + CloseableIterator, EmitCode, EmitCodeBuilder, EmitMethodBuilder, GenericLine, GenericLines, + GenericTableValue, IEmitCode, IR, IRParser, Literal, LowerMatrixIR, MatrixHybridReader, + MatrixReader, PartitionReader, +} import is.hail.expr.ir.lowering.TableStage import is.hail.expr.ir.streams.StreamProducer -import is.hail.expr.ir.{CloseableIterator, EmitCode, EmitCodeBuilder, EmitMethodBuilder, GenericLine, GenericLines, GenericTableValue, IEmitCode, IR, IRParser, Literal, LowerMatrixIR, MatrixHybridReader, MatrixReader, PartitionReader} +import is.hail.io.{VCFAttributes, VCFMetadata} import is.hail.io.fs.{FS, FileListEntry} import is.hail.io.tabix._ import is.hail.io.vcf.LoadVCF.{getHeaderLines, parseHeader} -import is.hail.io.{VCFAttributes, VCFMetadata} import is.hail.rvd.{RVDPartitioner, RVDType} import is.hail.sparkextras.ContextRDD import is.hail.types._ @@ -21,28 +24,29 @@ import is.hail.types.physical.stypes.interfaces.{SBaseStructValue, SStreamValue} import is.hail.types.virtual._ import is.hail.utils._ import is.hail.variant._ -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import 
org.apache.spark.{Partition, TaskContext} -import org.json4s.JsonAST.{JArray, JObject, JString} -import org.json4s.jackson.JsonMethods -import org.json4s.{DefaultFormats, Formats, JValue} import scala.annotation.meta.param import scala.annotation.switch import scala.collection.JavaConverters._ -import scala.language.implicitConversions -class BufferedLineIterator(bit: BufferedIterator[String]) extends htsjdk.tribble.readers.LineIterator { +import htsjdk.variant.vcf._ +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.json4s.{DefaultFormats, Formats, JValue} +import org.json4s.JsonAST.{JArray, JObject, JString} +import org.json4s.jackson.JsonMethods + +class BufferedLineIterator(bit: BufferedIterator[String]) + extends htsjdk.tribble.readers.LineIterator { override def peek(): String = bit.head override def hasNext: Boolean = bit.hasNext override def next(): String = bit.next() - override def remove() { + override def remove(): Unit = throw new UnsupportedOperationException - } } object VCFHeaderInfo { @@ -53,14 +57,16 @@ object VCFHeaderInfo { "filterAttrs" -> TDict(TString, TDict(TString, TString)), "infoAttrs" -> TDict(TString, TDict(TString, TString)), "formatAttrs" -> TDict(TString, TDict(TString, TString)), - "infoFlagFields" -> TArray(TString) + "infoFlagFields" -> TArray(TString), ) val headerTypePType: PType = PType.canonical(headerType, required = false, innerRequired = true) def fromJSON(jv: JValue): VCFHeaderInfo = { - val sampleIDs = (jv \ "sampleIDs").asInstanceOf[JArray].arr.map(_.asInstanceOf[JString].s).toArray - val infoFlagFields = (jv \ "infoFlagFields").asInstanceOf[JArray].arr.map(_.asInstanceOf[JString].s).toSet + val sampleIDs = + (jv \ "sampleIDs").asInstanceOf[JArray].arr.map(_.asInstanceOf[JString].s).toArray + val infoFlagFields = + (jv \ "infoFlagFields").asInstanceOf[JArray].arr.map(_.asInstanceOf[JString].s).toSet def lookupFields(name: String) = (jv \ name).asInstanceOf[JArray].arr.map { case elt: JArray => val List(name: JString, typeStr: JString) = elt.arr @@ -78,16 +84,24 @@ object VCFHeaderInfo { val filterAttrs = lookupAttrs("filterAttrs") val infoAttrs = lookupAttrs("infoAttrs") val formatAttrs = lookupAttrs("formatAttrs") - VCFHeaderInfo(sampleIDs, infoFields, formatFields, filterAttrs, infoAttrs, formatAttrs, infoFlagFields) + VCFHeaderInfo(sampleIDs, infoFields, formatFields, filterAttrs, infoAttrs, formatAttrs, + infoFlagFields) } } -case class VCFHeaderInfo(sampleIds: Array[String], infoFields: Array[(String, Type)], formatFields: Array[(String, Type)], - filtersAttrs: VCFAttributes, infoAttrs: VCFAttributes, formatAttrs: VCFAttributes, infoFlagFields: Set[String]) { +case class VCFHeaderInfo( + sampleIds: Array[String], + infoFields: Array[(String, Type)], + formatFields: Array[(String, Type)], + filtersAttrs: VCFAttributes, + infoAttrs: VCFAttributes, + formatAttrs: VCFAttributes, + infoFlagFields: Set[String], +) { def formatCompatible(other: VCFHeaderInfo): Boolean = { val m = formatFields.toMap - other.formatFields.forall { case (name, t) => m(name) == t} + other.formatFields.forall { case (name, t) => m(name) == t } } def infoCompatible(other: VCFHeaderInfo): Boolean = { @@ -98,10 +112,16 @@ case class VCFHeaderInfo(sampleIds: Array[String], infoFields: Array[(String, Ty def genotypeSignature: TStruct = TStruct(formatFields: _*) def infoSignature: TStruct = TStruct(infoFields: _*) + def getPTypes(arrayElementsRequired: Boolean, entryFloatType: Type, callFields: 
Set[String]) + : (PStruct, PStruct, PStruct) = { - def getPTypes(arrayElementsRequired: Boolean, entryFloatType: Type, callFields: Set[String]): (PStruct, PStruct, PStruct) = { - - def typeToPType(fdName: String, t: Type, floatType: Type, required: Boolean, isCallField: Boolean): PType = { + def typeToPType( + fdName: String, + t: Type, + floatType: Type, + required: Boolean, + isCallField: Boolean, + ): PType = { t match { case TString if isCallField => PCanonicalCall(required) case t if isCallField => @@ -115,22 +135,38 @@ case class VCFHeaderInfo(sampleIds: Array[String], infoFields: Array[(String, Ty case TBoolean => PBooleanRequired case TArray(t2) => - PCanonicalArray(typeToPType(fdName, t2, floatType, arrayElementsRequired, false), required) + PCanonicalArray( + typeToPType(fdName, t2, floatType, arrayElementsRequired, false), + required, + ) } } - val infoType = PCanonicalStruct(true, infoFields.map { case (name, t) => - (name, typeToPType(name, t, TFloat64, false, false)) - }: _*) - val formatType = PCanonicalStruct(true, formatFields.map { case (name, t) => - (name, typeToPType(name, t, entryFloatType, false, name == "GT" || callFields.contains(name))) - }: _*) - - val vaSignature = PCanonicalStruct(Array( - PField("rsid", PCanonicalString(), 0), - PField("qual", PFloat64(), 1), - PField("filters", PCanonicalSet(PCanonicalString(true)), 2), - PField("info", infoType, 3)), true) + val infoType = PCanonicalStruct( + true, + infoFields.map { case (name, t) => + (name, typeToPType(name, t, TFloat64, false, false)) + }: _* + ) + val formatType = PCanonicalStruct( + true, + formatFields.map { case (name, t) => + ( + name, + typeToPType(name, t, entryFloatType, false, name == "GT" || callFields.contains(name)), + ) + }: _* + ) + + val vaSignature = PCanonicalStruct( + Array( + PField("rsid", PCanonicalString(), 0), + PField("qual", PFloat64(), 1), + PField("filters", PCanonicalSet(PCanonicalString(true)), 2), + PField("info", infoType, 3), + ), + true, + ) (infoType, vaSignature, formatType) } @@ -139,8 +175,14 @@ case class VCFHeaderInfo(sampleIds: Array[String], infoFields: Array[(String, Ty rvb.start(VCFHeaderInfo.headerTypePType) rvb.startStruct() rvb.addAnnotation(rvb.currentType().virtualType, sampleIds.toFastSeq) - rvb.addAnnotation(rvb.currentType().virtualType, infoFields.map { case (x1, x2) => Row(x1, x2.parsableString()) }.toFastSeq) - rvb.addAnnotation(rvb.currentType().virtualType, formatFields.map { case (x1, x2) => Row(x1, x2.parsableString()) }.toFastSeq) + rvb.addAnnotation( + rvb.currentType().virtualType, + infoFields.map { case (x1, x2) => Row(x1, x2.parsableString()) }.toFastSeq, + ) + rvb.addAnnotation( + rvb.currentType().virtualType, + formatFields.map { case (x1, x2) => Row(x1, x2.parsableString()) }.toFastSeq, + ) rvb.addAnnotation(rvb.currentType().virtualType, if (dropAttrs) Map.empty else filtersAttrs) rvb.addAnnotation(rvb.currentType().virtualType, if (dropAttrs) Map.empty else infoAttrs) rvb.addAnnotation(rvb.currentType().virtualType, if (dropAttrs) Map.empty else formatAttrs) @@ -153,8 +195,9 @@ case class VCFHeaderInfo(sampleIds: Array[String], infoFields: Array[(String, Ty JArray(List(JString(name), JString(t.parsableString()))) }.toList) - def attrsJson(attrs: Map[String, Map[String, String]]): JValue = JObject(attrs.map { case (name, m) => - (name, JObject(name -> JObject(m.map { case (k, v) => (k, JString(v)) }.toList))) + def attrsJson(attrs: Map[String, Map[String, String]]): JValue = JObject(attrs.map { + case (name, m) => + (name, JObject(name 
-> JObject(m.map { case (k, v) => (k, JString(v)) }.toList))) }.toList) JObject( @@ -164,13 +207,13 @@ case class VCFHeaderInfo(sampleIds: Array[String], infoFields: Array[(String, Ty "filtersAttrs" -> attrsJson(filtersAttrs), "infoAttrs" -> attrsJson(infoAttrs), "formatAttrs" -> attrsJson(formatAttrs), - "infoFlagFields" -> JArray(infoFlagFields.map(JString).toList)) + "infoFlagFields" -> JArray(infoFlagFields.map(JString).toList), + ) } } class VCFParseError(val msg: String, val pos: Int) extends RuntimeException(msg) - final class VCFLine( val line: String, val fileNum: Long, @@ -179,22 +222,25 @@ final class VCFLine( val abs: MissingArrayBuilder[String], val abi: MissingArrayBuilder[Int], val abf: MissingArrayBuilder[Float], - val abd: MissingArrayBuilder[Double]) { + val abd: MissingArrayBuilder[Double], +) { var pos: Int = 0 def parseError(msg: String): Unit = throw new VCFParseError(msg, pos) def numericValue(c: Char): Int = { if (c < '0' || c > '9') - parseError(s"invalid character '${StringEscapeUtils.escapeString(c.toString)}' in integer literal") + parseError( + s"invalid character '${StringEscapeUtils.escapeString(c.toString)}' in integer literal" + ) c - '0' } - // field contexts: field, array field, format field, call field, format array field, filter array field + /* field contexts: field, array field, format field, call field, format array field, filter array + * field */ - def endField(p: Int): Boolean = { + def endField(p: Int): Boolean = p == line.length || line(p) == '\t' - } def endArrayElement(p: Int): Boolean = { if (p == line.length) @@ -280,58 +326,49 @@ final class VCFLine( def endFilterArrayElement(): Boolean = endFilterArrayElement(pos) - def skipInfoField(): Unit = { + def skipInfoField(): Unit = while (!endInfoField()) pos += 1 - } - def skipFormatField(): Unit = { + def skipFormatField(): Unit = while (!endFormatField()) pos += 1 - } - def fieldMissing(): Boolean = { + def fieldMissing(): Boolean = pos < line.length && line(pos) == '.' && endField(pos + 1) - } - def arrayFieldMissing(): Boolean = { + def arrayFieldMissing(): Boolean = pos < line.length && line(pos) == '.' && endArrayElement(pos + 1) - } - def infoFieldMissing(): Boolean = { + def infoFieldMissing(): Boolean = pos < line.length && (line(pos) == '.' && - endInfoField(pos + 1) || - endInfoField(pos)) - } + endInfoField(pos + 1) || + endInfoField(pos)) - def formatFieldMissing(): Boolean = { + def formatFieldMissing(): Boolean = pos < line.length && line(pos) == '.' && endFormatField(pos + 1) - } - def callFieldMissing(): Boolean = { + def callFieldMissing(): Boolean = pos < line.length && line(pos) == '.' && endCallField(pos + 1) - } - def infoArrayElementMissing(): Boolean = { + def infoArrayElementMissing(): Boolean = pos < line.length && line(pos) == '.' && endInfoArrayElement(pos + 1) - } - def formatArrayElementMissing(): Boolean = { + def formatArrayElementMissing(): Boolean = pos < line.length && line(pos) == '.' && endFormatArrayElement(pos + 1) - } def parseString(): String = { val start = pos @@ -358,10 +395,9 @@ final class VCFLine( v * mul } - def skipField(): Unit = { + def skipField(): Unit = while (!endField()) pos += 1 - } def parseStringInArray(): String = { val start = pos @@ -376,7 +412,7 @@ final class VCFLine( assert(abs.size == 0) // . means no alternate alleles - if (fieldMissing()) { + if (fieldMissing()) { pos += 1 // . 
return } @@ -462,7 +498,8 @@ final class VCFLine( hasLocus: Boolean, hasAlleles: Boolean, hasRSID: Boolean, - skipInvalidLoci: Boolean): Boolean = { + skipInvalidLoci: Boolean, + ): Boolean = { assert(pos == 0) if (line.isEmpty || line(0) == '#') @@ -500,12 +537,11 @@ final class VCFLine( if (hasLocus) { rg match { case Some(_) => rvb.addLocus(recodedContig, start) - case None => { // Without a reference genome, we use a struct of two fields rather than a PLocus + case None => // Without a reference genome, we use a struct of two fields rather than a PLocus rvb.startStruct() // pk: Locus rvb.addString(recodedContig) rvb.addInt(start) rvb.endStruct() - } } } @@ -545,7 +581,7 @@ final class VCFLine( v } - def parseAddCall(rvb: RegionValueBuilder) { + def parseAddCall(rvb: RegionValueBuilder): Unit = { if (pos == line.length) parseError("empty call") @@ -611,13 +647,12 @@ final class VCFLine( v * mul } - def parseAddFormatInt(rvb: RegionValueBuilder) { + def parseAddFormatInt(rvb: RegionValueBuilder): Unit = if (formatFieldMissing()) { rvb.setMissing() pos += 1 } else rvb.addInt(parseFormatInt()) - } def parseFormatString(): String = { val start = pos @@ -627,20 +662,19 @@ final class VCFLine( line.substring(start, end) } - def parseAddFormatString(rvb: RegionValueBuilder) { + def parseAddFormatString(rvb: RegionValueBuilder): Unit = if (formatFieldMissing()) { rvb.setMissing() pos += 1 } else rvb.addString(parseFormatString()) - } def parseFormatFloat(): Float = { val s = parseFormatString() VCFUtils.parseVcfDouble(s).toFloat } - def parseAddFormatFloat(rvb: RegionValueBuilder) { + def parseAddFormatFloat(rvb: RegionValueBuilder): Unit = { if (formatFieldMissing()) { rvb.setMissing() pos += 1 @@ -654,13 +688,12 @@ final class VCFLine( VCFUtils.parseVcfDouble(s) } - def parseAddFormatDouble(rvb: RegionValueBuilder) { + def parseAddFormatDouble(rvb: RegionValueBuilder): Unit = if (formatFieldMissing()) { rvb.setMissing() pos += 1 } else rvb.addDouble(parseFormatDouble()) - } def parseIntInFormatArray(): Int = { if (endFormatArrayElement()) @@ -697,10 +730,12 @@ final class VCFLine( s.toDouble } - def parseArrayElement[T](ab: MissingArrayBuilder[T], eltParser: () => T) { + def parseArrayElement[T](ab: MissingArrayBuilder[T], eltParser: () => T): Unit = { if (formatArrayElementMissing()) { if (arrayElementsRequired) - parseError(s"missing value in FORMAT array. Import with argument 'array_elements_required=False'") + parseError( + "Missing value in FORMAT array. Use 'hl.import_vcf(..., array_elements_required=False)'." + ) ab.addMissing() pos += 1 } else { @@ -708,10 +743,12 @@ final class VCFLine( } } - def parseIntArrayElement() { + def parseArrayIntElement(): Unit = { if (formatArrayElementMissing()) { if (arrayElementsRequired) - parseError(s"missing value in FORMAT array. Import with argument 'array_elements_required=False'") + parseError( + "Missing value in FORMAT array. Use 'hl.import_vcf(..., array_elements_required=False)'." + ) abi.addMissing() pos += 1 } else { @@ -719,10 +756,12 @@ final class VCFLine( } } - def parseFloatArrayElement() { + def parseFloatArrayElement(): Unit = { if (formatArrayElementMissing()) { if (arrayElementsRequired) - parseError(s"missing value in FORMAT array. Import with argument 'array_elements_required=False'") + parseError( + "Missing value in FORMAT array. Use 'hl.import_vcf(..., array_elements_required=False)'." 
+ ) abf.addMissing() pos += 1 } else { @@ -730,10 +769,12 @@ final class VCFLine( } } - def parseDoubleArrayElement() { + def parseArrayDoubleElement(): Unit = { if (formatArrayElementMissing()) { if (arrayElementsRequired) - parseError(s"missing value in FORMAT array. Import with argument 'array_elements_required=False'") + parseError( + "Missing value in FORMAT array. Use 'hl.import_vcf(..., array_elements_required=False)'." + ) abd.addMissing() pos += 1 } else { @@ -741,10 +782,12 @@ final class VCFLine( } } - def parseStringArrayElement() { + def parseArrayStringElement(): Unit = { if (formatArrayElementMissing()) { if (arrayElementsRequired) - parseError(s"missing value in FORMAT array. Import with argument 'array_elements_required=False'") + parseError( + "Missing value in FORMAT array. Use 'hl.import_vcf(..., array_elements_required=False)'." + ) abs.addMissing() pos += 1 } else { @@ -752,18 +795,18 @@ final class VCFLine( } } - def parseAddFormatArrayInt(rvb: RegionValueBuilder) { + def parseAddFormatArrayInt(rvb: RegionValueBuilder): Unit = { if (formatFieldMissing()) { rvb.setMissing() pos += 1 } else { assert(abi.length == 0) - parseIntArrayElement() + parseArrayIntElement() while (!endFormatField()) { pos += 1 // comma - parseIntArrayElement() + parseArrayIntElement() } rvb.startArray(abi.length) @@ -781,17 +824,17 @@ final class VCFLine( } } - def parseAddFormatArrayString(rvb: RegionValueBuilder) { + def parseAddFormatArrayString(rvb: RegionValueBuilder): Unit = { if (formatFieldMissing()) { rvb.setMissing() pos += 1 } else { assert(abs.length == 0) - parseStringArrayElement() + parseArrayStringElement() while (!endFormatField()) { pos += 1 // comma - parseStringArrayElement() + parseArrayStringElement() } rvb.startArray(abs.length) @@ -806,7 +849,7 @@ final class VCFLine( } } - def parseAddFormatArrayFloat(rvb: RegionValueBuilder) { + def parseAddFormatArrayFloat(rvb: RegionValueBuilder): Unit = { if (formatFieldMissing()) { rvb.setMissing() pos += 1 @@ -834,17 +877,17 @@ final class VCFLine( } } - def parseAddFormatArrayDouble(rvb: RegionValueBuilder) { + def parseAddFormatArrayDouble(rvb: RegionValueBuilder): Unit = { if (formatFieldMissing()) { rvb.setMissing() pos += 1 } else { assert(abd.length == 0) - parseDoubleArrayElement() + parseArrayDoubleElement() while (!endFormatField()) { pos += 1 // comma - parseDoubleArrayElement() + parseArrayDoubleElement() } rvb.startArray(abd.length) @@ -889,12 +932,11 @@ final class VCFLine( v * mul } - def parseAddInfoInt(rvb: RegionValueBuilder) { + def parseAddInfoInt(rvb: RegionValueBuilder): Unit = if (!infoFieldMissing()) { rvb.setPresent() rvb.addInt(parseInfoInt()) } - } def parseInfoString(): String = { val start = pos @@ -904,19 +946,17 @@ final class VCFLine( line.substring(start, end) } - def parseAddInfoString(rvb: RegionValueBuilder) { + def parseAddInfoString(rvb: RegionValueBuilder): Unit = if (!infoFieldMissing()) { rvb.setPresent() rvb.addString(parseInfoString()) } - } - def parseAddInfoDouble(rvb: RegionValueBuilder) { + def parseAddInfoDouble(rvb: RegionValueBuilder): Unit = if (!infoFieldMissing()) { rvb.setPresent() rvb.addDouble(VCFUtils.parseVcfDouble(parseInfoString())) } - } def parseIntInInfoArray(): Int = { if (endInfoArrayElement()) @@ -944,24 +984,36 @@ final class VCFLine( def parseDoubleInInfoArray(): Double = VCFUtils.parseVcfDouble(parseStringInInfoArray()) - def parseIntInfoArrayElement() { + def parseInfoArrayIntElement(): Unit = { if (infoArrayElementMissing()) { + if (arrayElementsRequired) + 
parseError( + "Missing value in INFO array. Use 'hl.import_vcf(..., array_elements_required=False)'." + ) abi.addMissing() - pos += 1 // dot + pos += 1 // dot } else abi += parseIntInInfoArray() } - def parseStringInfoArrayElement() { + def parseInfoArrayStringElement(): Unit = { if (infoArrayElementMissing()) { + if (arrayElementsRequired) + parseError( + "Missing value in INFO array. Use 'hl.import_vcf(..., array_elements_required=False)'." + ) abs.addMissing() - pos += 1 // dot + pos += 1 // dot } else abs += parseStringInInfoArray() } - def parseDoubleInfoArrayElement() { + def parseInfoArrayDoubleElement(): Unit = { if (infoArrayElementMissing()) { + if (arrayElementsRequired) + parseError( + "Missing value in INFO array. Use 'hl.import_vcf(..., array_elements_required=False)'." + ) abd.addMissing() pos += 1 } else { @@ -969,14 +1021,14 @@ final class VCFLine( } } - def parseAddInfoArrayInt(rvb: RegionValueBuilder) { + def parseAddInfoArrayInt(rvb: RegionValueBuilder): Unit = { if (!infoFieldMissing()) { rvb.setPresent() assert(abi.length == 0) - parseIntInfoArrayElement() + parseInfoArrayIntElement() while (!endInfoField()) { - pos += 1 // comma - parseIntInfoArrayElement() + pos += 1 // comma + parseInfoArrayIntElement() } rvb.startArray(abi.length) @@ -993,14 +1045,14 @@ final class VCFLine( } } - def parseAddInfoArrayString(rvb: RegionValueBuilder) { + def parseAddInfoArrayString(rvb: RegionValueBuilder): Unit = { if (!infoFieldMissing()) { rvb.setPresent() assert(abs.length == 0) - parseStringInfoArrayElement() + parseInfoArrayStringElement() while (!endInfoField()) { - pos += 1 // comma - parseStringInfoArrayElement() + pos += 1 // comma + parseInfoArrayStringElement() } rvb.startArray(abs.length) @@ -1017,14 +1069,14 @@ final class VCFLine( } } - def parseAddInfoArrayDouble(rvb: RegionValueBuilder) { + def parseAddInfoArrayDouble(rvb: RegionValueBuilder): Unit = { if (!infoFieldMissing()) { rvb.setPresent() assert(abd.length == 0) - parseDoubleInfoArrayElement() + parseInfoArrayDoubleElement() while (!endInfoField()) { - pos += 1 // comma - parseDoubleInfoArrayElement() + pos += 1 // comma + parseInfoArrayDoubleElement() } rvb.startArray(abd.length) @@ -1041,7 +1093,7 @@ final class VCFLine( } } - def parseAddInfoField(rvb: RegionValueBuilder, typ: Type) { + def parseAddInfoField(rvb: RegionValueBuilder, typ: Type): Unit = { val c = line(pos) if (c != ';' && c != '\t') { if (c != '=') @@ -1071,16 +1123,16 @@ final class VCFLine( } else rvb.addBoolean(true) } else { - try { + try parseAddInfoField(rvb, c.infoFieldTypes(idx)) - } catch { - case e: VCFParseError => parseError(s"error while parsing info field '$key': ${ e.msg }") + catch { + case e: VCFParseError => parseError(s"error while parsing info field '$key': ${e.msg}") } } } } - def parseAddInfo(rvb: RegionValueBuilder, c: ParseLineContext) { + def parseAddInfo(rvb: RegionValueBuilder, c: ParseLineContext): Unit = { rvb.startStruct(init = true, setMissing = true) var i = 0 while (i < c.infoFieldFlagIndices.length) { @@ -1125,16 +1177,18 @@ object FormatParser { new FormatParser( gType, formatFields.map(f => gType.fieldIdx.getOrElse(f, -1)), // -1 means field has been pruned - gType.fields.filter(f => !formatFieldsSet.contains(f.name)).map(_.index).toArray) + gType.fields.filter(f => !formatFieldsSet.contains(f.name)).map(_.index).toArray, + ) } } final class FormatParser( gType: TStruct, formatFieldGIndex: Array[Int], - missingGIndices: Array[Int]) { + missingGIndices: Array[Int], +) { - def parseAddField(l: VCFLine, rvb: 
RegionValueBuilder, i: Int) { + def parseAddField(l: VCFLine, rvb: RegionValueBuilder, i: Int): Unit = { // negative j values indicate field is pruned val j = formatFieldGIndex(i) if (j == -1) @@ -1164,7 +1218,7 @@ final class FormatParser( } } - def setMissing(rvb: RegionValueBuilder, i: Int) { + def setMissing(rvb: RegionValueBuilder, i: Int): Unit = { val idx = formatFieldGIndex(i) if (idx >= 0) { rvb.setFieldIndex(idx) @@ -1172,7 +1226,7 @@ final class FormatParser( } } - def parse(l: VCFLine, rvb: RegionValueBuilder) { + def parse(l: VCFLine, rvb: RegionValueBuilder): Unit = { rvb.startStruct() // g // FIXME do in bulk, add setDefinedIndex @@ -1210,19 +1264,24 @@ class ParseLineContext( val infoFlagFieldNames: java.util.HashSet[String], val nSamples: Int, val fileNum: Int, - val entriesName: String + val entriesName: String, ) { val entryType: TStruct = rowType.selfField(entriesName) match { - case Some(entriesArray) => entriesArray.typ.asInstanceOf[TArray].elementType.asInstanceOf[TStruct] + case Some(entriesArray) => + entriesArray.typ.asInstanceOf[TArray].elementType.asInstanceOf[TStruct] case None => TStruct.empty } + val infoSignature = rowType.selfField("info").map(_.typ.asInstanceOf[TStruct]).orNull val hasQual = rowType.hasField("qual") val hasFilters = rowType.hasField("filters") val hasEntryFields = entryType.size > 0 - val infoFields: java.util.HashMap[String, Int] = if (infoSignature != null) makeJavaMap(infoSignature.fieldIdx) else null + val infoFields: java.util.HashMap[String, Int] = + if (infoSignature != null) makeJavaMap(infoSignature.fieldIdx) else null + val infoFieldTypes: Array[Type] = if (infoSignature != null) infoSignature.types else null + val infoFieldFlagIndices: Array[Int] = if (infoSignature != null) { infoSignature.fields .iterator @@ -1246,11 +1305,15 @@ class ParseLineContext( } object LoadVCF { - def warnDuplicates(ids: Array[String]) { + def warnDuplicates(ids: Array[String]): Unit = { val duplicates = ids.counter().filter(_._2 > 1) if (duplicates.nonEmpty) { - warn(s"Found ${ duplicates.size } duplicate ${ plural(duplicates.size, "sample ID") }:\n @1", - duplicates.toArray.sortBy(-_._2).map { case (id, count) => s"""($count) "$id"""" }.truncatable("\n ")) + warn( + s"Found ${duplicates.size} duplicate ${plural(duplicates.size, "sample ID")}:\n @1", + duplicates.toArray.sortBy(-_._2).map { case (id, count) => + s"""($count) "$id"""" + }.truncatable("\n "), + ) } } @@ -1259,9 +1322,9 @@ object LoadVCF { case TFloat32 => TFloat32 case TFloat64 => TFloat64 case _ => fatal( - s"""invalid floating point type: - | expected ${TFloat32._toPretty} or ${TFloat64._toPretty}, got ${entryFloatTypeName}""" - ) + s"""invalid floating point type: + | expected ${TFloat32._toPretty} or ${TFloat64._toPretty}, got $entryFloatTypeName""" + ) } } @@ -1294,19 +1357,25 @@ object LoadVCF { case VCFHeaderLineType.Flag => TBoolean } - val attrs = Map("Description" -> line.getDescription, + val attrs = Map( + "Description" -> line.getDescription, "Number" -> headerNumberToString(line), - "Type" -> headerTypeToString(line)) + "Type" -> headerTypeToString(line), + ) val isFlag = line.getType == VCFHeaderLineType.Flag - if (line.isFixedCount && + if ( + line.isFixedCount && (line.getCount == 1 || - (isFlag && line.getCount == 0))) + (isFlag && line.getCount == 0)) + ) ((id, baseType), (id, attrs), isFlag) else if (isFlag) { - warn(s"invalid VCF header: at INFO field '$id' of type 'Flag', expected 'Number=0', got 'Number=${headerNumberToString(line)}''" + - s"\n Interpreting as 
'Number=0' regardless.") + warn( + s"invalid VCF header: at INFO field '$id' of type 'Flag', expected 'Number=0', got 'Number=${headerNumberToString(line)}''" + + s"\n Interpreting as 'Number=0' regardless." + ) ((id, baseType), (id, attrs), isFlag) } else if (baseType.isInstanceOf[PCall]) fatal("fields in 'call_fields' must have 'Number' equal to 1.") @@ -1315,10 +1384,10 @@ object LoadVCF { } def headerSignature[T <: VCFCompoundHeaderLine]( - lines: java.util.Collection[T], + lines: java.util.Collection[T] ): (Array[(String, Type)], VCFAttributes, Set[String]) = { val (fields, attrs, flags) = lines.asScala - .map { line => headerField(line) } + .map(line => headerField(line)) .unzip3 val flagFieldNames = fields.zip(flags) @@ -1356,7 +1425,9 @@ object LoadVCF { if (!(headerLine(0) == '#' && headerLine(1) != '#')) fatal( s"""corrupt VCF: expected final header line of format '#CHROM\tPOS\tID...' - | found: @1""".stripMargin, headerLine) + | found: @1""".stripMargin, + headerLine, + ) val sampleIds: Array[String] = headerLine.split("\t").drop(9) @@ -1373,16 +1444,21 @@ object LoadVCF { def getHeaderLines[T]( fs: FS, file: String, - filterAndReplace: TextInputFilterAndReplace): Array[String] = fs.readLines(file, filterAndReplace) { lines => - lines - .takeWhile { line => line.value(0) == '#' } + filterAndReplace: TextInputFilterAndReplace, + ): Array[String] = fs.readLines(file, filterAndReplace) { lines => + lines + .takeWhile(line => line.value(0) == '#') .map(_.value) .toArray } - def getVCFHeaderInfo(fs: FS, file: String, filter: String, find: String, replace: String): VCFHeaderInfo = { - parseHeader(getHeaderLines(fs, file, TextInputFilterAndReplace(Option(filter), Option(find), Option(replace)))) - } + def getVCFHeaderInfo(fs: FS, file: String, filter: String, find: String, replace: String) + : VCFHeaderInfo = + parseHeader(getHeaderLines( + fs, + file, + TextInputFilterAndReplace(Option(filter), Option(find), Option(replace)), + )) def parseLine( rg: Option[ReferenceGenome], @@ -1393,7 +1469,8 @@ object LoadVCF { parseLineContext: ParseLineContext, vcfLine: VCFLine, entriesFieldName: String = LowerMatrixIR.entriesFieldName, - uidFieldName: String = MatrixReader.rowUIDFieldName): Boolean = { + uidFieldName: String = MatrixReader.rowUIDFieldName, + ): Boolean = { val hasLocus = rowPType.hasField("locus") val hasAlleles = rowPType.hasField("alleles") val hasRSID = rowPType.hasField("rsid") @@ -1402,7 +1479,8 @@ object LoadVCF { rvb.start(rowPType) rvb.startStruct() - val present = vcfLine.parseAddVariant(rvb, rg, contigRecoding, hasLocus, hasAlleles, hasRSID, skipInvalidLoci) + val present = vcfLine.parseAddVariant(rvb, rg, contigRecoding, hasLocus, hasAlleles, hasRSID, + skipInvalidLoci) if (!present) return present @@ -1421,13 +1499,15 @@ object LoadVCF { // parses the Variant (key), and ID if necessary, leaves the rest to f def parseLines( makeContext: () => ParseLineContext - )(f: (ParseLineContext, VCFLine, RegionValueBuilder) => Unit - )(lines: ContextRDD[WithContext[String]], + )( + f: (ParseLineContext, VCFLine, RegionValueBuilder) => Unit + )( + lines: ContextRDD[WithContext[String]], rowPType: PStruct, rgBc: Option[BroadcastValue[ReferenceGenome]], contigRecoding: Map[String, String], arrayElementsRequired: Boolean, - skipInvalidLoci: Boolean + skipInvalidLoci: Boolean, ): ContextRDD[Long] = { val hasRSID = rowPType.hasField("rsid") lines.cmapPartitions { (ctx, it) => @@ -1449,10 +1529,27 @@ object LoadVCF { val lwc = it.next() val line = lwc.value try { - val vcfLine = new 
VCFLine(line, context.fileNum, lwc.source.position.get, arrayElementsRequired, abs, abi, abf, abd) + val vcfLine = new VCFLine( + line, + context.fileNum, + lwc.source.position.get, + arrayElementsRequired, + abs, + abi, + abf, + abd, + ) rvb.start(rowPType) rvb.startStruct() - present = vcfLine.parseAddVariant(rvb, rgBc.map(_.value), contigRecoding, hasRSID, true, true, skipInvalidLoci) + present = vcfLine.parseAddVariant( + rvb, + rgBc.map(_.value), + contigRecoding, + hasRSID, + true, + true, + skipInvalidLoci, + ) if (present) { f(context, vcfLine, rvb) @@ -1468,15 +1565,19 @@ object LoadVCF { val excerptStart = math.max(0, pos - 36) val excerptEnd = math.min(line.length, pos + 36) val excerpt = line.substring(excerptStart, excerptEnd) - .map { c => if (c == '\t') ' ' else c } + .map(c => if (c == '\t') ' ' else c) val prefix = if (excerptStart > 0) "... " else "" val suffix = if (excerptEnd < line.length) " ..." else "" - var caretPad = prefix.length + pos - excerptStart - var pad = " " * caretPad + val caretPad = prefix.length + pos - excerptStart + val pad = " " * caretPad - fatal(s"${ source.locationString(pos) }: ${ e.msg }\n$prefix$excerpt$suffix\n$pad^\noffending line: @1\nsee the Hail log for the full offending line", line, e) + fatal( + s"${source.locationString(pos)}: ${e.msg}\n$prefix$excerpt$suffix\n$pad^\noffending line: @1\nsee the Hail log for the full offending line", + line, + e, + ) case e: Throwable => lwc.source.wrapException(e) } @@ -1495,7 +1596,12 @@ object LoadVCF { } } - def parseHeaderMetadata(fs: FS, callFields: Set[String], entryFloatType: TNumeric, headerFile: String): VCFMetadata = { + def parseHeaderMetadata( + fs: FS, + callFields: Set[String], + entryFloatType: TNumeric, + headerFile: String, + ): VCFMetadata = { val headerLines = getHeaderLines(fs, headerFile, TextInputFilterAndReplace()) val header = parseHeader(headerLines) @@ -1506,7 +1612,7 @@ object LoadVCF { c: ParseLineContext, l: VCFLine, rvb: RegionValueBuilder, - dropSamples: Boolean = false + dropSamples: Boolean = false, ): Unit = { // QUAL if (c.hasQual) { @@ -1571,23 +1677,27 @@ object LoadVCF { } } -case class PartitionedVCFPartition(index: Int, chrom: String, start: Int, end: Int) extends Partition +case class PartitionedVCFPartition(index: Int, chrom: String, start: Int, end: Int) + extends Partition class PartitionedVCFRDD( fsBc: BroadcastValue[FS], file: String, - @(transient@param) reverseContigMapping: Map[String, String], - @(transient@param) _partitions: Array[Partition] + @(transient @param) reverseContigMapping: Map[String, String], + @(transient @param) _partitions: Array[Partition], ) extends RDD[WithContext[String]](SparkBackend.sparkContext("PartitionedVCFRDD"), Seq()) { - val contigRemappingBc = if (reverseContigMapping.size != 0) sparkContext.broadcast(reverseContigMapping) else null + val contigRemappingBc = + if (reverseContigMapping.size != 0) sparkContext.broadcast(reverseContigMapping) else null protected def getPartitions: Array[Partition] = _partitions def compute(split: Partition, context: TaskContext): Iterator[WithContext[String]] = { val p = split.asInstanceOf[PartitionedVCFPartition] - val chromToQuery = if (contigRemappingBc != null) contigRemappingBc.value.getOrElse(p.chrom, p.chrom) else p.chrom + val chromToQuery = if (contigRemappingBc != null) + contigRemappingBc.value.getOrElse(p.chrom, p.chrom) + else p.chrom val reg = { val r = new TabixReader(file, fsBc.value) @@ -1601,9 +1711,7 @@ class PartitionedVCFRDD( // clean up val context = TaskContext.get - 
context.addTaskCompletionListener[Unit] { (context: TaskContext) => - lines.close() - } + context.addTaskCompletionListener[Unit]((context: TaskContext) => lines.close()) val it: Iterator[WithContext[String]] = new Iterator[WithContext[String]] { private var l = lines.next() @@ -1630,7 +1738,7 @@ class PartitionedVCFRDD( val pos = l.value.substring(t1 + 1, t2).toInt if (chrom != chromToQuery) { - throw new RuntimeException(s"bad chromosome! ${chromToQuery}, $l") + throw new RuntimeException(s"bad chromosome! $chromToQuery, $l") } p.start <= pos && pos <= p.end } @@ -1638,7 +1746,8 @@ class PartitionedVCFRDD( } object MatrixVCFReader { - def apply(ctx: ExecuteContext, + def apply( + ctx: ExecuteContext, files: Seq[String], callFields: Set[String], entryFloatTypeName: String, @@ -1655,12 +1764,16 @@ object MatrixVCFReader { forceGZ: Boolean, filterAndReplace: TextInputFilterAndReplace, partitionsJSON: Option[String], - partitionsTypeStr: Option[String]): MatrixVCFReader = { - MatrixVCFReader(ctx, MatrixVCFReaderParameters( - files, callFields, entryFloatTypeName, headerFile, sampleIDs, nPartitions, blockSizeInMB, minPartitions, rg, - contigRecoding, arrayElementsRequired, skipInvalidLoci, gzAsBGZ, forceGZ, filterAndReplace, - partitionsJSON, partitionsTypeStr)) - } + partitionsTypeStr: Option[String], + ): MatrixVCFReader = + MatrixVCFReader( + ctx, + MatrixVCFReaderParameters( + files, callFields, entryFloatTypeName, headerFile, sampleIDs, nPartitions, blockSizeInMB, + minPartitions, rg, + contigRecoding, arrayElementsRequired, skipInvalidLoci, gzAsBGZ, forceGZ, filterAndReplace, + partitionsJSON, partitionsTypeStr), + ) def apply(ctx: ExecuteContext, params: MatrixVCFReaderParameters): MatrixVCFReader = { val backend = ctx.backend @@ -1677,23 +1790,28 @@ object MatrixVCFReader { } checkGzipOfGlobbedFiles(params.files, fileListEntries, params.forceGZ, params.gzAsBGZ) - val entryFloatType = LoadVCF.getEntryFloatType(params.entryFloatTypeName) - - val headerLines1 = getHeaderLines(fs, params.headerFile.getOrElse(fileListEntries.head.getPath), params.filterAndReplace) + val headerLines1 = getHeaderLines( + fs, + params.headerFile.getOrElse(fileListEntries.head.getPath), + params.filterAndReplace, + ) val header1 = parseHeader(headerLines1) if (fileListEntries.length > 1) { if (params.headerFile.isEmpty) { val header1Bc = backend.broadcast(header1) - val localCallFields = params.callFields - val localFloatType = entryFloatType val files = fileListEntries.map(_.getPath) - val localArrayElementsRequired = params.arrayElementsRequired val localFilterAndReplace = params.filterAndReplace val fsConfigBC = backend.broadcast(fs.getConfiguration()) - backend.parallelizeAndComputeWithIndex(ctx.backendContext, fs, files.tail.map(_.getBytes), "load_vcf_parse_header", None) { (bytes, htc, _, fs) => + backend.parallelizeAndComputeWithIndex( + ctx.backendContext, + fs, + files.tail.map(_.getBytes), + "load_vcf_parse_header", + None, + ) { (bytes, htc, _, fs) => val fsConfig = fsConfigBC.value fs.setConfiguration(fsConfig) val file = new String(bytes) @@ -1704,9 +1822,10 @@ object MatrixVCFReader { if (params.sampleIDs.isEmpty && hd1.sampleIds.length != hd.sampleIds.length) { fatal( s"""invalid sample IDs: expected same number of samples for all inputs. - | ${ files(0) } has ${ hd1.sampleIds.length } ids and - | ${ file } has ${ hd.sampleIds.length } ids. - """.stripMargin) + | ${files(0)} has ${hd1.sampleIds.length} ids and + | $file has ${hd.sampleIds.length} ids. 
+ """.stripMargin + ) } if (params.sampleIDs.isEmpty) { @@ -1715,8 +1834,9 @@ object MatrixVCFReader { if (s1 != s2) { fatal( s"""invalid sample IDs: expected sample ids to be identical for all inputs. Found different sample IDs at position $i. - | ${ files(0) }: $s1 - | $file: $s2""".stripMargin) + | ${files(0)}: $s1 + | $file: $s2""".stripMargin + ) } } } @@ -1724,14 +1844,16 @@ object MatrixVCFReader { if (!hd.formatCompatible(hd1)) fatal( s"""invalid genotype signature: expected signatures to be identical for all inputs. - | ${ files(0) }: ${ hd1.genotypeSignature.toString } - | $file: ${ hd.genotypeSignature.toString }""".stripMargin) + | ${files(0)}: ${hd1.genotypeSignature.toString} + | $file: ${hd.genotypeSignature.toString}""".stripMargin + ) if (!hd.infoCompatible(hd1)) fatal( s"""invalid variant annotation signature: expected signatures to be identical for all inputs. Check that all files have same INFO fields. - | ${ files(0) }: ${ hd1.infoSignature.toString } - | $file: ${ hd.infoSignature.toString }""".stripMargin) + | ${files(0)}: ${hd1.infoSignature.toString} + | $file: ${hd.infoSignature.toString}""".stripMargin + ) bytes } @@ -1742,7 +1864,12 @@ object MatrixVCFReader { LoadVCF.warnDuplicates(sampleIDs) - new MatrixVCFReader(params.copy(files = fileListEntries.map(_.getPath)), fileListEntries, referenceGenome, header1) + new MatrixVCFReader( + params.copy(files = fileListEntries.map(_.getPath)), + fileListEntries, + referenceGenome, + header1, + ) } def fromJValue(ctx: ExecuteContext, jv: JValue): MatrixVCFReader = { @@ -1770,17 +1897,24 @@ case class MatrixVCFReaderParameters( forceGZ: Boolean, filterAndReplace: TextInputFilterAndReplace, partitionsJSON: Option[String], - partitionsTypeStr: Option[String]) { - require(partitionsJSON.isEmpty == partitionsTypeStr.isEmpty, "partitions and type must either both be defined or undefined") + partitionsTypeStr: Option[String], +) { + require( + partitionsJSON.isEmpty == partitionsTypeStr.isEmpty, + "partitions and type must either both be defined or undefined", + ) } class MatrixVCFReader( val params: MatrixVCFReaderParameters, fileListEntries: IndexedSeq[FileListEntry], referenceGenome: Option[ReferenceGenome], - header: VCFHeaderInfo + header: VCFHeaderInfo, ) extends MatrixHybridReader { - require(params.partitionsJSON.isEmpty || fileListEntries.length == 1, "reading with partitions can currently only read a single path") + require( + params.partitionsJSON.isEmpty || fileListEntries.length == 1, + "reading with partitions can currently only read a single path", + ) val sampleIDs = params.sampleIDs.map(_.toArray).getOrElse(header.sampleIds) @@ -1792,7 +1926,8 @@ class MatrixVCFReader( val (infoPType, rowValuePType, formatPType) = header.getPTypes( params.arrayElementsRequired, IRParser.parseType(params.entryFloatTypeName), - params.callFields) + params.callFields, + ) def fullMatrixTypeWithoutUIDs: MatrixType = MatrixType( globalType = TStruct.empty, @@ -1801,22 +1936,30 @@ class MatrixVCFReader( rowType = TStruct( Array( "locus" -> TLocus.schemaFromRG(referenceGenome.map(_.name)), - "alleles" -> TArray(TString)) - ++ rowValuePType.fields.map(f => f.name -> f.typ.virtualType): _*), + "alleles" -> TArray(TString), + ) + ++ rowValuePType.fields.map(f => f.name -> f.typ.virtualType): _* + ), rowKey = Array("locus", "alleles"), // rowKey = Array.empty[String], - entryType = formatPType.virtualType) + entryType = formatPType.virtualType, + ) val fullRVDType = RVDType( - PCanonicalStruct(true, + PCanonicalStruct( + true, FastSeq( 
"locus" -> PCanonicalLocus.schemaFromRG(referenceGenome.map(_.name), true), - "alleles" -> PCanonicalArray(PCanonicalString(true), true)) - ++ rowValuePType.fields.map { f => f.name -> f.typ } - ++ FastSeq( - LowerMatrixIR.entriesFieldName -> PCanonicalArray(formatPType, true), - rowUIDFieldName -> PCanonicalTuple(true, PInt64Required, PInt64Required)): _*), - fullType.key) + "alleles" -> PCanonicalArray(PCanonicalString(true), true), + ) + ++ rowValuePType.fields.map(f => f.name -> f.typ) + ++ FastSeq( + LowerMatrixIR.entriesFieldName -> PCanonicalArray(formatPType, true), + rowUIDFieldName -> PCanonicalTuple(true, PInt64Required, PInt64Required), + ): _* + ), + fullType.key, + ) def pathsUsed: Seq[String] = params.files @@ -1826,35 +1969,41 @@ class MatrixVCFReader( val partitionCounts: Option[IndexedSeq[Long]] = None - def partitioner(sm: HailStateManager): Option[RVDPartitioner] = params.partitionsJSON.map { partitionsJSON => - val indexedPartitionsType = IRParser.parseType(params.partitionsTypeStr.get) - val jv = JsonMethods.parse(partitionsJSON) - val rangeBounds = JSONAnnotationImpex.importAnnotation(jv, indexedPartitionsType) - .asInstanceOf[IndexedSeq[Interval]] - - rangeBounds.foreach { bound => - if (!(bound.includesStart && bound.includesEnd)) - fatal("range bounds must be inclusive") - - val start = bound.start.asInstanceOf[Row].getAs[Locus](0) - val end = bound.end.asInstanceOf[Row].getAs[Locus](0) - if (start.contig != end.contig) - fatal(s"partition spec must not cross contig boundaries, start: ${start.contig} | end: ${end.contig}") + def partitioner(sm: HailStateManager): Option[RVDPartitioner] = + params.partitionsJSON.map { partitionsJSON => + val indexedPartitionsType = IRParser.parseType(params.partitionsTypeStr.get) + val jv = JsonMethods.parse(partitionsJSON) + val rangeBounds = JSONAnnotationImpex.importAnnotation(jv, indexedPartitionsType) + .asInstanceOf[IndexedSeq[Interval]] + + rangeBounds.foreach { bound => + if (!(bound.includesStart && bound.includesEnd)) + fatal("range bounds must be inclusive") + + val start = bound.start.asInstanceOf[Row].getAs[Locus](0) + val end = bound.end.asInstanceOf[Row].getAs[Locus](0) + if (start.contig != end.contig) + fatal( + s"partition spec must not cross contig boundaries, start: ${start.contig} | end: ${end.contig}" + ) + } + new RVDPartitioner( + sm, + Array("locus"), + fullType.keyType, + rangeBounds, + ) } - new RVDPartitioner( - sm, - Array("locus"), - fullType.keyType, - rangeBounds) - } - override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = + override def concreteRowRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = VirtualTypeWithReq(tcoerce[PStruct](fullRVDType.rowType.subsetTo(requestedType.rowType))) override def uidRequiredness: VirtualTypeWithReq = VirtualTypeWithReq(PCanonicalTuple(true, PInt64Required, PInt64Required)) - override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = + override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = VirtualTypeWithReq(PType.canonical(requestedType.globalType)) def executeGeneric(ctx: ExecuteContext, dropRows: Boolean = false): GenericTableValue = { @@ -1872,27 +2021,46 @@ class MatrixVCFReader( val part = partitioner(ctx.stateManager) val lines = part match { case Some(partitioner) => - GenericLines.readTabix(fs, fileListEntries(0).getPath, localContigRecoding, partitioner.rangeBounds) + 
GenericLines.readTabix( + fs, + fileListEntries(0).getPath, + localContigRecoding, + partitioner.rangeBounds, + ) case None => - GenericLines.read(fs, fileListEntries, params.nPartitions, params.blockSizeInMB, params.minPartitions, params.gzAsBGZ, params.forceGZ) + GenericLines.read( + fs, + fileListEntries, + params.nPartitions, + params.blockSizeInMB, + params.minPartitions, + params.gzAsBGZ, + params.forceGZ, + ) } val globals = Row(sampleIDs.zipWithIndex.map { case (s, i) => Row(s, i.toLong) }.toFastSeq) val fullRowPType: PType = fullRVDType.rowType - val bodyPType = (requestedRowType: TStruct) => fullRowPType.subsetTo(requestedRowType).asInstanceOf[PStruct] + val bodyPType = + (requestedRowType: TStruct) => fullRowPType.subsetTo(requestedRowType).asInstanceOf[PStruct] - val linesBody = if (dropRows) { (_: FS, _: Any) => - CloseableIterator.empty[GenericLine] - } else + val linesBody = if (dropRows) { (_: FS, _: Any) => CloseableIterator.empty[GenericLine] } + else lines.body val body = { (requestedType: TStruct) => val requestedPType = bodyPType(requestedType) { (region: Region, theHailClassLoader: HailClassLoader, fs: FS, context: Any) => val fileNum = context.asInstanceOf[Row].getInt(1) - val parseLineContext = new ParseLineContext(requestedType, makeJavaSet(localInfoFlagFieldNames), localNSamples, fileNum, LowerMatrixIR.entriesFieldName) + val parseLineContext = new ParseLineContext( + requestedType, + makeJavaSet(localInfoFlagFieldNames), + localNSamples, + fileNum, + LowerMatrixIR.entriesFieldName, + ) val rvb = new RegionValueBuilder(sm, region) @@ -1910,19 +2078,36 @@ class MatrixVCFReader( if (newText != null) { rvb.clear() try { - val vcfLine = new VCFLine(newText, line.fileNum, line.offset, localArrayElementsRequired, abs, abi, abf, abd) - LoadVCF.parseLine(rgBc.map(_.value), localContigRecoding, localSkipInvalidLoci, - requestedPType, rvb, parseLineContext, vcfLine) + val vcfLine = new VCFLine( + newText, + line.fileNum, + line.offset, + localArrayElementsRequired, + abs, + abi, + abf, + abd, + ) + LoadVCF.parseLine( + rgBc.map(_.value), + localContigRecoding, + localSkipInvalidLoci, + requestedPType, + rvb, + parseLineContext, + vcfLine, + ) } catch { case e: Exception => - fatal(s"${ line.file }:offset ${ line.offset }: error while parsing line\n" + - s"$newText\n", e) + fatal( + s"${line.file}:offset ${line.offset}: error while parsing line\n" + + s"$newText\n", + e, + ) } } else false - }.map { _ => - rvb.result().offset - } + }.map(_ => rvb.result().offset) } } @@ -1937,14 +2122,17 @@ class MatrixVCFReader( lines.contextType.asInstanceOf[TStruct], lines.contexts, bodyPType, - body) + body, + ) } override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = { val globals = Row(sampleIDs.zipWithIndex.map(t => Row(t._1, t._2.toLong)).toFastSeq) - Literal.coerce(requestedGlobalsType, + Literal.coerce( + requestedGlobalsType, fullType.globalType.valueSubsetter(requestedGlobalsType) - .apply(globals)) + .apply(globals), + ) } override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = @@ -1965,7 +2153,8 @@ class MatrixVCFReader( } } -case class GVCFPartitionReader(header: VCFHeaderInfo, +case class GVCFPartitionReader( + header: VCFHeaderInfo, callFields: Set[String], entryFloatType: Type, arrayElementsRequired: Boolean, @@ -1974,45 +2163,60 @@ case class GVCFPartitionReader(header: VCFHeaderInfo, skipInvalidLoci: Boolean, filterAndReplace: TextInputFilterAndReplace, entriesFieldName: String, - uidFieldName: String) extends 
PartitionReader { + uidFieldName: String, +) extends PartitionReader { lazy val contextType: TStruct = TStruct( "fileNum" -> TInt32, "path" -> TString, "contig" -> TString, "start" -> TInt32, - "end" -> TInt32) - - lazy val (infoType, rowValueType, entryType) = header.getPTypes(arrayElementsRequired, entryFloatType, callFields) + "end" -> TInt32, + ) - lazy val fullRowPType: PCanonicalStruct = PCanonicalStruct(true, - FastSeq(("locus", PCanonicalLocus.schemaFromRG(rg, true)), ("alleles", PCanonicalArray(PCanonicalString(true), true))) - ++ rowValueType.fields.map { f => (f.name, f.typ) } - ++ Array(entriesFieldName -> PCanonicalArray(entryType, true), - uidFieldName -> PCanonicalTuple(true, PInt64Required, PInt64Required)): _*) + lazy val (infoType, rowValueType, entryType) = + header.getPTypes(arrayElementsRequired, entryFloatType, callFields) + + lazy val fullRowPType: PCanonicalStruct = PCanonicalStruct( + true, + FastSeq( + ("locus", PCanonicalLocus.schemaFromRG(rg, true)), + ("alleles", PCanonicalArray(PCanonicalString(true), true)), + ) + ++ rowValueType.fields.map(f => (f.name, f.typ)) + ++ Array( + entriesFieldName -> PCanonicalArray(entryType, true), + uidFieldName -> PCanonicalTuple(true, PInt64Required, PInt64Required), + ): _* + ) lazy val fullRowType: TStruct = fullRowPType.virtualType def rowRequiredness(requestedType: TStruct): RStruct = - VirtualTypeWithReq(tcoerce[PStruct](fullRowPType.subsetTo(requestedType))).r.asInstanceOf[RStruct] + VirtualTypeWithReq(tcoerce[PStruct](fullRowPType.subsetTo(requestedType))).r.asInstanceOf[ + RStruct + ] override def toJValue: JValue = { implicit val formats: Formats = DefaultFormats decomposeWithName(this, "MatrixVCFReader") } + def emitStream( ctx: ExecuteContext, cb: EmitCodeBuilder, mb: EmitMethodBuilder[_], context: EmitCode, - requestedType: TStruct + requestedType: TStruct, ): IEmitCode = { context.toI(cb).map(cb) { case ctxValue: SBaseStructValue => - val fileNum = cb.memoizeField(ctxValue.loadField(cb, "fileNum").get(cb).asInt32.value) - val filePath = cb.memoizeField(ctxValue.loadField(cb, "path").get(cb).asString.loadString(cb)) - val contig = cb.memoizeField(ctxValue.loadField(cb, "contig").get(cb).asString.loadString(cb)) - val start = cb.memoizeField(ctxValue.loadField(cb, "start").get(cb).asInt32.value) - val end = cb.memoizeField(ctxValue.loadField(cb, "end").get(cb).asInt32.value) + val fileNum = cb.memoizeField(ctxValue.loadField(cb, "fileNum").getOrAssert(cb).asInt32.value) + val filePath = + cb.memoizeField(ctxValue.loadField(cb, "path").getOrAssert(cb).asString.loadString(cb)) + val contig = + cb.memoizeField(ctxValue.loadField(cb, "contig").getOrAssert(cb).asString.loadString(cb)) + val start = cb.memoizeField(ctxValue.loadField(cb, "start").getOrAssert(cb).asInt32.value) + val end = cb.memoizeField(ctxValue.loadField(cb, "end").getOrAssert(cb).asInt32.value) val requestedPType = fullRowPType.subsetTo(requestedType).asInstanceOf[PStruct] val eltRegion = mb.genFieldThisRef[Region]("gvcf_elt_region") @@ -2024,18 +2228,53 @@ case class GVCFPartitionReader(header: VCFHeaderInfo, override val length: Option[EmitCodeBuilder => Code[Int]] = None override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = { - cb.assign(iter, Code.newInstance[TabixReadVCFIterator]( - Array[Class[_]](classOf[FS], classOf[String], classOf[Map[String, String]], - classOf[Int], classOf[String], classOf[Int], classOf[Int], - classOf[HailStateManager], classOf[Region], classOf[Region], - classOf[PStruct], 
classOf[TextInputFilterAndReplace], classOf[Set[String]], - classOf[Int], classOf[ReferenceGenome], classOf[Boolean], classOf[Boolean], - classOf[String], classOf[String]), - Array[Code[_]](mb.getFS, filePath, mb.getObject(contigRecoding), fileNum, contig, start, end, - cb.emb.getObject(cb.emb.ecb.ctx.stateManager), outerRegion, eltRegion, mb.getPType(requestedPType), mb.getObject(filterAndReplace), - mb.getObject(header.infoFlagFields), const(header.sampleIds.length), rg.map(mb.getReferenceGenome).getOrElse(Code._null[ReferenceGenome]), - const(arrayElementsRequired), const(skipInvalidLoci), const(entriesFieldName), const(uidFieldName)) - )) + cb.assign( + iter, + Code.newInstance[TabixReadVCFIterator]( + Array[Class[_]]( + classOf[FS], + classOf[String], + classOf[Map[String, String]], + classOf[Int], + classOf[String], + classOf[Int], + classOf[Int], + classOf[HailStateManager], + classOf[Region], + classOf[Region], + classOf[PStruct], + classOf[TextInputFilterAndReplace], + classOf[Set[String]], + classOf[Int], + classOf[ReferenceGenome], + classOf[Boolean], + classOf[Boolean], + classOf[String], + classOf[String], + ), + Array[Code[_]]( + mb.getFS, + filePath, + mb.getObject(contigRecoding), + fileNum, + contig, + start, + end, + cb.emb.getObject(cb.emb.ecb.ctx.stateManager), + outerRegion, + eltRegion, + mb.getPType(requestedPType), + mb.getObject(filterAndReplace), + mb.getObject(header.infoFlagFields), + const(header.sampleIds.length), + rg.map(mb.getReferenceGenome).getOrElse(Code._null[ReferenceGenome]), + const(arrayElementsRequired), + const(skipInvalidLoci), + const(entriesFieldName), + const(uidFieldName), + ), + ), + ) } override val elementRegion: Settable[Region] = eltRegion @@ -2044,7 +2283,9 @@ case class GVCFPartitionReader(header: VCFHeaderInfo, cb.assign(currentElt, iter.invoke[Region, Long]("next", eltRegion)) cb.if_(currentElt ceq 0L, cb.goto(LendOfStream), cb.goto(LproduceElementDone)) } - override val element: EmitCode = EmitCode.fromI(mb)(cb => IEmitCode.present(cb, requestedPType.loadCheapSCode(cb, currentElt))) + override val element: EmitCode = EmitCode.fromI(mb)(cb => + IEmitCode.present(cb, requestedPType.loadCheapSCode(cb, currentElt)) + ) override def close(cb: EmitCodeBuilder): Unit = { cb += iter.invoke[Unit]("close") cb.assign(iter, Code._null) diff --git a/hail/src/main/scala/is/hail/io/vcf/TabixReadVCFIterator.scala b/hail/src/main/scala/is/hail/io/vcf/TabixReadVCFIterator.scala index 51d8939131d..ad5cf64e400 100644 --- a/hail/src/main/scala/is/hail/io/vcf/TabixReadVCFIterator.scala +++ b/hail/src/main/scala/is/hail/io/vcf/TabixReadVCFIterator.scala @@ -1,19 +1,35 @@ package is.hail.io.vcf -import is.hail.annotations.{Region, RegionValueBuilder, SafeRow} +import is.hail.annotations.{Region, RegionValueBuilder} import is.hail.backend.HailStateManager import is.hail.expr.ir.{CloseableIterator, GenericLine} -import is.hail.io.fs.{FS, Positioned} +import is.hail.io.fs.FS import is.hail.io.tabix.{TabixLineIterator, TabixReader} import is.hail.types.physical.PStruct -import is.hail.utils.{MissingArrayBuilder, TextInputFilterAndReplace, fatal, makeJavaSet} +import is.hail.utils.{fatal, makeJavaSet, MissingArrayBuilder, TextInputFilterAndReplace} import is.hail.variant.ReferenceGenome -class TabixReadVCFIterator(fs: FS, file: String, contigMapping: Map[String, String], - fileNum: Int, chrom: String, start: Int, end: Int, - sm: HailStateManager, partitionRegion: Region, elementRegion: Region, requestedPType: PStruct, - filterAndReplace: 
TextInputFilterAndReplace, infoFlagFieldNames: Set[String], nSamples: Int, _rg: ReferenceGenome, - arrayElementsRequired: Boolean, skipInvalidLoci: Boolean, entriesFieldName: String, uidFieldName: String) { +class TabixReadVCFIterator( + fs: FS, + file: String, + contigMapping: Map[String, String], + fileNum: Int, + chrom: String, + start: Int, + end: Int, + sm: HailStateManager, + partitionRegion: Region, + elementRegion: Region, + requestedPType: PStruct, + filterAndReplace: TextInputFilterAndReplace, + infoFlagFieldNames: Set[String], + nSamples: Int, + _rg: ReferenceGenome, + arrayElementsRequired: Boolean, + skipInvalidLoci: Boolean, + entriesFieldName: String, + uidFieldName: String, +) { val chromToQuery = contigMapping.iterator.find(_._2 == chrom).map(_._1).getOrElse(chrom) val rg = Option(_rg) @@ -46,7 +62,7 @@ class TabixReadVCFIterator(fs: FS, file: String, contigMapping: Map[String, Stri val bytes = n.getBytes new GenericLine(file, 0, idx, bytes, bytes.length) } catch { - case e: Exception => fatal(s"error reading file: $file at ${ lines.getCurIdx() }", e) + case e: Exception => fatal(s"error reading file: $file at ${lines.getCurIdx()}", e) } } }.filter { gl => @@ -55,14 +71,16 @@ class TabixReadVCFIterator(fs: FS, file: String, contigMapping: Map[String, Stri val t2 = s.indexOf('\t', t1 + 1) if (t1 == -1 || t2 == -1) { - fatal(s"invalid line in file ${ gl.file } no CHROM or POS column at offset ${ gl.offset }.\n$s") + fatal( + s"invalid line in file ${gl.file} no CHROM or POS column at offset ${gl.offset}.\n$s" + ) } val chr = s.substring(0, t1) val pos = s.substring(t1 + 1, t2).toInt if (chr != chrom) { - fatal(s"in file ${ gl.file } at offset ${ gl.offset }, bad chromosome! ${ chrom }, $s") + fatal(s"in file ${gl.file} at offset ${gl.offset}, bad chromosome! 
$chrom, $s") } start <= pos && pos <= end } @@ -75,9 +93,15 @@ class TabixReadVCFIterator(fs: FS, file: String, contigMapping: Map[String, Stri } } - val transformer = filterAndReplace.transformer() - val parseLineContext = new ParseLineContext(requestedPType.virtualType, makeJavaSet(infoFlagFieldNames), nSamples, fileNum, entriesFieldName) + + val parseLineContext = new ParseLineContext( + requestedPType.virtualType, + makeJavaSet(infoFlagFieldNames), + nSamples, + fileNum, + entriesFieldName, + ) val rvb = new RegionValueBuilder(sm) @@ -98,14 +122,26 @@ class TabixReadVCFIterator(fs: FS, file: String, contigMapping: Map[String, Stri rvb.clear() rvb.set(elementRegion) try { - val vcfLine = new VCFLine(newText, line.fileNum, line.offset, arrayElementsRequired, abs, abi, abf, abd) + val vcfLine = new VCFLine( + newText, + line.fileNum, + line.offset, + arrayElementsRequired, + abs, + abi, + abf, + abd, + ) val pl = LoadVCF.parseLine(rg, contigMapping, skipInvalidLoci, requestedPType, rvb, parseLineContext, vcfLine, entriesFieldName, uidFieldName) pl } catch { case e: Exception => - fatal(s"${ line.file }:offset ${ line.offset }: error while parsing line\n" + - s"$newText\n", e) + fatal( + s"${line.file}:offset ${line.offset}: error while parsing line\n" + + s"$newText\n", + e, + ) } } } @@ -116,7 +152,6 @@ class TabixReadVCFIterator(fs: FS, file: String, contigMapping: Map[String, Stri 0L } - def close(): Unit = { + def close(): Unit = linesIter.close() - } } diff --git a/hail/src/main/scala/is/hail/kryo/HailKryoRegistrator.scala b/hail/src/main/scala/is/hail/kryo/HailKryoRegistrator.scala index 253a8f9365b..9a111522e39 100644 --- a/hail/src/main/scala/is/hail/kryo/HailKryoRegistrator.scala +++ b/hail/src/main/scala/is/hail/kryo/HailKryoRegistrator.scala @@ -1,15 +1,16 @@ package is.hail.kryo -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.serializers.JavaSerializer import is.hail.annotations.{Region, UnsafeIndexedSeq, UnsafeRow} import is.hail.utils.{Interval, SerializableHadoopConfiguration} import is.hail.variant.Locus + +import com.esotericsoftware.kryo.Kryo +import com.esotericsoftware.kryo.serializers.JavaSerializer import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.sql.catalyst.expressions.GenericRow class HailKryoRegistrator extends KryoRegistrator { - override def registerClasses(kryo: Kryo) { + override def registerClasses(kryo: Kryo): Unit = { kryo.register(classOf[SerializableHadoopConfiguration], new JavaSerializer()) kryo.register(classOf[UnsafeRow]) kryo.register(classOf[GenericRow]) diff --git a/hail/src/main/scala/is/hail/linalg/BLAS.scala b/hail/src/main/scala/is/hail/linalg/BLAS.scala index 5db39d98667..bcdf5a4bae5 100644 --- a/hail/src/main/scala/is/hail/linalg/BLAS.scala +++ b/hail/src/main/scala/is/hail/linalg/BLAS.scala @@ -1,13 +1,13 @@ package is.hail.linalg -import java.util.function._ - -import com.sun.jna.{FunctionMapper, Library, Native} -import com.sun.jna.ptr.{DoubleByReference, FloatByReference, IntByReference} import is.hail.utils._ import scala.util.{Failure, Success, Try} +import java.util.function._ + +import com.sun.jna.{FunctionMapper, Library, Native} +import com.sun.jna.ptr.{DoubleByReference, FloatByReference, IntByReference} object BLAS { private[this] val libraryInstance = ThreadLocal.withInitial(new Supplier[BLASLibrary]() { @@ -18,10 +18,11 @@ object BLAS { case Success(_) => log.info("Imported BLAS with standard names") standard - case Failure(exc) => + case Failure(_) => val underscoreAfterMap = 
new java.util.HashMap[String, FunctionMapper]() underscoreAfterMap.put(Library.OPTION_FUNCTION_MAPPER, new UnderscoreFunctionMapper) - val underscoreAfter = Native.load("blas", classOf[BLASLibrary], underscoreAfterMap).asInstanceOf[BLASLibrary] + val underscoreAfter = + Native.load("blas", classOf[BLASLibrary], underscoreAfterMap).asInstanceOf[BLASLibrary] verificationTest(underscoreAfter) match { case Success(_) => log.info("Imported BLAS with underscore names") @@ -59,7 +60,19 @@ object BLAS { libraryInstance.get.dscal(nInt, alphaDouble, X, incXInt) } - def dgemv(TRANS: String, M: Int, N: Int, ALPHA: Double, A: Long, LDA: Int, X: Long, INCX: Int, BETA: Double, Y: Long, INCY: Int): Unit = { + def dgemv( + TRANS: String, + M: Int, + N: Int, + ALPHA: Double, + A: Long, + LDA: Int, + X: Long, + INCX: Int, + BETA: Double, + Y: Long, + INCY: Int, + ): Unit = { val mInt = new IntByReference(M) val nInt = new IntByReference(N) val alphaDouble = new DoubleByReference(ALPHA) @@ -68,10 +81,25 @@ object BLAS { val incxInt = new IntByReference(INCX) val incyInt = new IntByReference(INCY) - libraryInstance.get.dgemv(TRANS, mInt, nInt, alphaDouble, A, LDAInt, X, incxInt, betaDouble, Y, incyInt) + libraryInstance.get.dgemv(TRANS, mInt, nInt, alphaDouble, A, LDAInt, X, incxInt, betaDouble, Y, + incyInt) } - def sgemm(TRANSA: String, TRANSB: String, M: Int, N: Int, K: Int, ALPHA: Float, A: Long, LDA: Int, B: Long, LDB: Int, BETA: Float, C: Long, LDC: Int) = { + def sgemm( + TRANSA: String, + TRANSB: String, + M: Int, + N: Int, + K: Int, + ALPHA: Float, + A: Long, + LDA: Int, + B: Long, + LDB: Int, + BETA: Float, + C: Long, + LDC: Int, + ) = { val mInt = new IntByReference(M) val nInt = new IntByReference(N) val kInt = new IntByReference(K) @@ -81,10 +109,25 @@ object BLAS { val betaDouble = new FloatByReference(BETA) val LDCInt = new IntByReference(LDC) - libraryInstance.get.sgemm(TRANSA, TRANSB, mInt, nInt, kInt, alphaDouble, A, LDAInt, B, LDBInt, betaDouble, C, LDCInt) + libraryInstance.get.sgemm(TRANSA, TRANSB, mInt, nInt, kInt, alphaDouble, A, LDAInt, B, LDBInt, + betaDouble, C, LDCInt) } - def dgemm(TRANSA: String, TRANSB: String, M: Int, N: Int, K: Int, ALPHA: Double, A: Long, LDA: Int, B: Long, LDB: Int, BETA: Double, C: Long, LDC: Int) = { + def dgemm( + TRANSA: String, + TRANSB: String, + M: Int, + N: Int, + K: Int, + ALPHA: Double, + A: Long, + LDA: Int, + B: Long, + LDB: Int, + BETA: Double, + C: Long, + LDC: Int, + ) = { val mInt = new IntByReference(M) val nInt = new IntByReference(N) val kInt = new IntByReference(K) @@ -94,31 +137,97 @@ object BLAS { val betaDouble = new DoubleByReference(BETA) val LDCInt = new IntByReference(LDC) - libraryInstance.get.dgemm(TRANSA, TRANSB, mInt, nInt, kInt, alphaDouble, A, LDAInt, B, LDBInt, betaDouble, C, LDCInt) + libraryInstance.get.dgemm(TRANSA, TRANSB, mInt, nInt, kInt, alphaDouble, A, LDAInt, B, LDBInt, + betaDouble, C, LDCInt) } - def dtrmm(side: String, uplo: String, transA: String, diag: String, m: Int, n: Int, alpha: Double, A: Long, ldA: Int, B: Long, ldB: Int) = { + def dtrmm( + side: String, + uplo: String, + transA: String, + diag: String, + m: Int, + n: Int, + alpha: Double, + A: Long, + ldA: Int, + B: Long, + ldB: Int, + ) = { val mInt = new IntByReference(m) val nInt = new IntByReference(n) val alphaDouble = new DoubleByReference(alpha) val ldAInt = new IntByReference(ldA) val ldBInt = new IntByReference(ldB) - libraryInstance.get.dtrmm(side, uplo, transA, diag, mInt, nInt, alphaDouble, A, ldAInt, B, ldBInt) + 
libraryInstance.get.dtrmm(side, uplo, transA, diag, mInt, nInt, alphaDouble, A, ldAInt, B, + ldBInt) } } trait BLASLibrary extends Library { - def dcopy(n: IntByReference, X: Long, incX: IntByReference, Y: Long, incY: IntByReference) - def dscal(n: IntByReference, alpha: DoubleByReference, X: Long, incX: IntByReference) - def dgemv(TRANS: String, M: IntByReference, N: IntByReference, ALPHA: DoubleByReference, A: Long, LDA: IntByReference, X: Long, INCX: IntByReference, BETA: DoubleByReference, Y: Long, INCY: IntByReference) - def sgemm(TRANSA: String, TRANSB: String, M: IntByReference, N: IntByReference, K: IntByReference, - ALPHA: FloatByReference, A: Long, LDA: IntByReference, B: Long, LDB: IntByReference, - BETA: FloatByReference, C: Long, LDC: IntByReference) - def dgemm(TRANSA: String, TRANSB: String, M: IntByReference, N: IntByReference, K: IntByReference, - ALPHA: DoubleByReference, A: Long, LDA: IntByReference, B: Long, LDB: IntByReference, - BETA: DoubleByReference, C: Long, LDC: IntByReference) - def dtrmm(side: String, uplo: String, transA: String, diag: String, m: IntByReference, n: IntByReference, - alpha: DoubleByReference, A: Long, ldA: IntByReference, B: Long, ldB: IntByReference) + def dcopy(n: IntByReference, X: Long, incX: IntByReference, Y: Long, incY: IntByReference): Unit + def dscal(n: IntByReference, alpha: DoubleByReference, X: Long, incX: IntByReference): Unit + + def dgemv( + TRANS: String, + M: IntByReference, + N: IntByReference, + ALPHA: DoubleByReference, + A: Long, + LDA: IntByReference, + X: Long, + INCX: IntByReference, + BETA: DoubleByReference, + Y: Long, + INCY: IntByReference, + ): Unit + + def sgemm( + TRANSA: String, + TRANSB: String, + M: IntByReference, + N: IntByReference, + K: IntByReference, + ALPHA: FloatByReference, + A: Long, + LDA: IntByReference, + B: Long, + LDB: IntByReference, + BETA: FloatByReference, + C: Long, + LDC: IntByReference, + ): Unit + + def dgemm( + TRANSA: String, + TRANSB: String, + M: IntByReference, + N: IntByReference, + K: IntByReference, + ALPHA: DoubleByReference, + A: Long, + LDA: IntByReference, + B: Long, + LDB: IntByReference, + BETA: DoubleByReference, + C: Long, + LDC: IntByReference, + ): Unit + + def dtrmm( + side: String, + uplo: String, + transA: String, + diag: String, + m: IntByReference, + n: IntByReference, + alpha: DoubleByReference, + A: Long, + ldA: IntByReference, + B: Long, + ldB: IntByReference, + ): Unit + def dnrm2(N: IntByReference, X: Array[Double], INCX: IntByReference): Double } diff --git a/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala b/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala index a235e2abfb2..18c08349811 100644 --- a/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala +++ b/hail/src/main/scala/is/hail/linalg/BlockMatrix.scala @@ -1,12 +1,9 @@ package is.hail.linalg -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, sum => breezeSum, _} -import breeze.numerics.{abs => breezeAbs, log => breezeLog, pow => breezePow, sqrt => breezeSqrt} -import breeze.stats.distributions.RandBasis import is.hail._ import is.hail.annotations._ -import is.hail.backend.spark.{SparkBackend, SparkTaskContext} import is.hail.backend.{BroadcastValue, ExecuteContext, HailStateManager} +import is.hail.backend.spark.{SparkBackend, SparkTaskContext} import is.hail.expr.ir.{IntArrayBuilder, TableReader, TableValue, ThreefryRandomEngine} import is.hail.io._ import is.hail.io.fs.FS @@ -17,7 +14,17 @@ import is.hail.types._ import is.hail.types.physical._ import is.hail.types.virtual._ 
import is.hail.utils._ -import is.hail.utils.richUtils.{ByteTrackingOutputStream, RichArray, RichContextRDD, RichDenseMatrixDouble} +import is.hail.utils.richUtils.{ + ByteTrackingOutputStream, RichArray, RichContextRDD, RichDenseMatrixDouble, +} + +import scala.collection.immutable.NumericRange + +import java.io._ + +import breeze.linalg.{sum => breezeSum, DenseMatrix => BDM, DenseVector => BDV, _} +import breeze.numerics.{abs => breezeAbs, log => breezeLog, pow => breezePow, sqrt => breezeSqrt} +import breeze.stats.distributions.RandBasis import org.apache.commons.lang3.StringUtils import org.apache.spark._ import org.apache.spark.executor.InputMetrics @@ -26,23 +33,34 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.json4s._ -import java.io._ -import scala.collection.immutable.NumericRange - -case class CollectMatricesRDDPartition(index: Int, firstPartition: Int, blockPartitions: Array[Partition], blockSize: Int, nRows: Int, nCols: Int) extends Partition { +case class CollectMatricesRDDPartition( + index: Int, + firstPartition: Int, + blockPartitions: Array[Partition], + blockSize: Int, + nRows: Int, + nCols: Int, +) extends Partition { def nBlocks: Int = blockPartitions.length } -class CollectMatricesRDD(@transient var bms: IndexedSeq[BlockMatrix]) extends RDD[BDM[Double]](SparkBackend.sparkContext("CollectMatricesRDD"), Nil) { +class CollectMatricesRDD(@transient var bms: IndexedSeq[BlockMatrix]) + extends RDD[BDM[Double]](SparkBackend.sparkContext("CollectMatricesRDD"), Nil) { private val nBlocks = bms.map(_.blocks.getNumPartitions) private val firstPartition = nBlocks.scan(0)(_ + _).init - protected def getPartitions: Array[Partition] = { + protected def getPartitions: Array[Partition] = bms.iterator.zipWithIndex.map { case (bm, i) => - CollectMatricesRDDPartition(i, firstPartition(i), bm.blocks.partitions, bm.blockSize, bm.nRows.toInt, bm.nCols.toInt) + CollectMatricesRDDPartition( + i, + firstPartition(i), + bm.blocks.partitions, + bm.blockSize, + bm.nRows.toInt, + bm.nCols.toInt, + ) } .toArray - } override def getDependencies: Seq[Dependency[_]] = bms.zipWithIndex.map { case (bm, i) => @@ -66,7 +84,10 @@ class CollectMatricesRDD(@transient var bms: IndexedSeq[BlockMatrix]) extends RD assert(it.hasNext) val ((i, j), b) = it.next() - m((i * p.blockSize) until (i * p.blockSize + b.rows), (j * p.blockSize) until (j * p.blockSize + b.cols)) := b + m( + (i * p.blockSize) until (i * p.blockSize + b.rows), + (j * p.blockSize) until (j * p.blockSize + b.cols), + ) := b k += 1 } @@ -74,7 +95,7 @@ class CollectMatricesRDD(@transient var bms: IndexedSeq[BlockMatrix]) extends RD Iterator.single(m) } - override def clearDependencies() { + override def clearDependencies(): Unit = { super.clearDependencies() bms = null } @@ -84,22 +105,30 @@ object BlockMatrix { type M = BlockMatrix val defaultBlockSize: Int = 4096 // 32 * 1024 bytes val bufferSpecBlockSize = 32 * 1024 + val bufferSpec: BufferSpec = - new BlockingBufferSpec(bufferSpecBlockSize, - new LZ4FastBlockBufferSpec(bufferSpecBlockSize, - new StreamBlockBufferSpec)) + new BlockingBufferSpec( + bufferSpecBlockSize, + new LZ4FastBlockBufferSpec(bufferSpecBlockSize, new StreamBlockBufferSpec), + ) - def apply(gp: GridPartitioner, piBlock: (GridPartitioner, Int) => ((Int, Int), BDM[Double])): BlockMatrix = + def apply(gp: GridPartitioner, piBlock: (GridPartitioner, Int) => ((Int, Int), BDM[Double])) + : BlockMatrix = new BlockMatrix( new RDD[((Int, Int), 
BDM[Double])](SparkBackend.sparkContext("BlockMatrix.apply"), Nil) { override val partitioner = Some(gp) protected def getPartitions: Array[Partition] = Array.tabulate(gp.numPartitions)(pi => - new Partition { def index: Int = pi } ) + new Partition { def index: Int = pi } + ) def compute(split: Partition, context: TaskContext): Iterator[((Int, Int), BDM[Double])] = Iterator.single(piBlock(gp, split.index)) - }, gp.blockSize, gp.nRows, gp.nCols) + }, + gp.blockSize, + gp.nRows, + gp.nCols, + ) def fromBreezeMatrix(lm: BDM[Double]): M = fromBreezeMatrix(lm, defaultBlockSize) @@ -113,7 +142,10 @@ object BlockMatrix { val iOffset = i * blockSize val jOffset = j * blockSize - HailContext.backend.broadcast(lm(iOffset until iOffset + blockNRows, jOffset until jOffset + blockNCols).copy) + HailContext.backend.broadcast(lm( + iOffset until iOffset + blockNRows, + jOffset until jOffset + blockNCols, + ).copy) } BlockMatrix(gp, (gp, pi) => (gp.blockCoordinates(pi), localBlocksBc(pi).value)) @@ -125,23 +157,36 @@ object BlockMatrix { def fromIRM(irm: IndexedRowMatrix, blockSize: Int): M = irm.toHailBlockMatrix(blockSize) - def fill(nRows: Long, nCols: Long, value: Double, blockSize: Int = defaultBlockSize): BlockMatrix = - BlockMatrix(GridPartitioner(blockSize, nRows, nCols), (gp, pi) => { - val (i, j) = gp.blockCoordinates(pi) - ((i, j), BDM.fill[Double](gp.blockRowNRows(i), gp.blockColNCols(j))(value)) - }) + def fill(nRows: Long, nCols: Long, value: Double, blockSize: Int = defaultBlockSize) + : BlockMatrix = + BlockMatrix( + GridPartitioner(blockSize, nRows, nCols), + (gp, pi) => { + val (i, j) = gp.blockCoordinates(pi) + ((i, j), BDM.fill[Double](gp.blockRowNRows(i), gp.blockColNCols(j))(value)) + }, + ) // uniform or Gaussian - def random(nRows: Long, nCols: Long, blockSize: Int = defaultBlockSize, - nonce: Long = 0, staticUID: Long = 0, gaussian: Boolean): M = - BlockMatrix(GridPartitioner(blockSize, nRows, nCols), (gp, pi) => { - val (i, j) = gp.blockCoordinates(pi) - val generator = ThreefryRandomEngine(nonce, staticUID, Array(pi.toLong)) - val randBasis: RandBasis = new RandBasis(generator) - val rand = if (gaussian) randBasis.gaussian else randBasis.uniform - - ((i, j), BDM.rand[Double](gp.blockRowNRows(i), gp.blockColNCols(j), rand)) - }) + def random( + nRows: Long, + nCols: Long, + blockSize: Int = defaultBlockSize, + nonce: Long = 0, + staticUID: Long = 0, + gaussian: Boolean, + ): M = + BlockMatrix( + GridPartitioner(blockSize, nRows, nCols), + (gp, pi) => { + val (i, j) = gp.blockCoordinates(pi) + val generator = ThreefryRandomEngine(nonce, staticUID, Array(pi.toLong)) + val randBasis: RandBasis = new RandBasis(generator) + val rand = if (gaussian) randBasis.gaussian else randBasis.uniform + + ((i, j), BDM.rand[Double](gp.blockRowNRows(i), gp.blockColNCols(j), rand)) + }, + ) def map2(f: (Double, Double) => Double)(l: M, r: M): M = l.map2(r, f) @@ -151,26 +196,28 @@ object BlockMatrix { val metadataRelativePath = "/metadata.json" - def checkWriteSuccess(fs: FS, uri: String) { - if (!fs.exists(uri + "/_SUCCESS")) - fatal(s"Error reading block matrix. Earlier write failed: no success indicator found at uri $uri") - } + def checkWriteSuccess(fs: FS, uri: String): Unit = + if (!fs.isFile(uri + "/_SUCCESS")) + fatal( + s"Error reading block matrix. 
Earlier write failed: no success indicator found at uri $uri" + ) - def readMetadata(fs: FS, uri: String): BlockMatrixMetadata = { + def readMetadata(fs: FS, uri: String): BlockMatrixMetadata = using(fs.open(uri + metadataRelativePath)) { is => implicit val formats = defaultJSONFormats jackson.Serialization.read[BlockMatrixMetadata](is) } - } def read(fs: FS, uri: String): M = { checkWriteSuccess(fs, uri) - val BlockMatrixMetadata(blockSize, nRows, nCols, maybeFiltered, partFiles) = readMetadata(fs, uri) + val BlockMatrixMetadata(blockSize, nRows, nCols, maybeFiltered, partFiles) = + readMetadata(fs, uri) val gp = GridPartitioner(blockSize, nRows, nCols, maybeFiltered) - def readBlock(pi: Int, is: InputStream, metrics: InputMetrics): Iterator[((Int, Int), BDM[Double])] = { + def readBlock(pi: Int, is: InputStream, metrics: InputMetrics) + : Iterator[((Int, Int), BDM[Double])] = { val block = RichDenseMatrixDouble.read(is, bufferSpec) is.close() @@ -182,12 +229,17 @@ object BlockMatrix { new BlockMatrix(blocks, blockSize, nRows, nCols) } - private[linalg] def assertCompatibleLocalMatrix(lm: BDM[Double]) { + private[linalg] def assertCompatibleLocalMatrix(lm: BDM[Double]): Unit = assert(lm.isCompact) - } - private[linalg] def block(bm: BlockMatrix, parts: Array[Partition], gp: GridPartitioner, context: TaskContext, - i: Int, j: Int): Option[BDM[Double]] = { + private[linalg] def block( + bm: BlockMatrix, + parts: Array[Partition], + gp: GridPartitioner, + context: TaskContext, + i: Int, + j: Int, + ): Option[BDM[Double]] = { val pi = gp.coordinatesPart(i, j) if (pi >= 0) { val it = bm.blocks.iterator(parts(pi), context) @@ -228,7 +280,12 @@ object BlockMatrix { def collectMatrices(bms: IndexedSeq[BlockMatrix]): RDD[BDM[Double]] = new CollectMatricesRDD(bms) - def binaryWriteBlockMatrices(fs: FS, bms: IndexedSeq[BlockMatrix], prefix: String, overwrite: Boolean): Unit = { + def binaryWriteBlockMatrices( + fs: FS, + bms: IndexedSeq[BlockMatrix], + prefix: String, + overwrite: Boolean, + ): Unit = { if (overwrite) fs.delete(prefix, recursive = true) else if (fs.exists(prefix)) @@ -238,7 +295,7 @@ object BlockMatrix { val d = digitsNeeded(bms.length) val fsBc = fs.broadcast - val partitionCounts = collectMatrices(bms) + collectMatrices(bms) .mapPartitionsWithIndex { case (i, it) => assert(it.hasNext) val m = it.next() @@ -262,7 +319,7 @@ object BlockMatrix { header: Option[String], addIndex: Boolean, compression: Option[String], - customFilenames: Option[Array[String]] + customFilenames: Option[Array[String]], ): Unit = { if (overwrite) @@ -282,7 +339,7 @@ object BlockMatrix { val compressionExtension = compression.map(x => "." 
+ x).getOrElse("") - val partitionCounts = collectMatrices(bms) + collectMatrices(bms) .mapPartitionsWithIndex { case (i, it) => assert(it.hasNext) val m = it.next() @@ -292,10 +349,12 @@ object BlockMatrix { new PrintWriter( new BufferedWriter( new OutputStreamWriter( - fsBc.value.create(path))))) { f => - header.foreach { h => - f.println(h) - } + fsBc.value.create(path) + ) + ) + ) + ) { f => + header.foreach(h => f.println(h)) var i = 0 while (i < m.rows) { @@ -324,18 +383,18 @@ object BlockMatrix { } def writeBlockMatrices( - ctx: ExecuteContext, - bms: IndexedSeq[BlockMatrix], - prefix: String, - overwrite: Boolean, - forceRowMajor: Boolean - ): Unit = { + ctx: ExecuteContext, + bms: IndexedSeq[BlockMatrix], + prefix: String, + overwrite: Boolean, + forceRowMajor: Boolean, + ): Unit = { def blockMatrixURI(matrixIdx: Int): String = prefix + "_" + matrixIdx val fs = ctx.fs val tmpdir = ctx.localTmpdir - bms.zipWithIndex.foreach { case (bm, bIdx) => + bms.zipWithIndex.foreach { case (_, bIdx) => val uri = blockMatrixURI(bIdx) if (overwrite) fs.delete(uri, recursive = true) @@ -346,7 +405,12 @@ object BlockMatrix { fs.mkDir(uri + "/parts") } - def writeBlock(ctx: RVDContext, it: Iterator[((Int, Int), BDM[Double])], os: OutputStream, iw: IndexWriter): (Long, Long) = { + def writeBlock( + ctx: RVDContext, + it: Iterator[((Int, Int), BDM[Double])], + os: OutputStream, + iw: IndexWriter, + ): (Long, Long) = { val btos = new ByteTrackingOutputStream(os) assert(it.hasNext) val (_, lm) = it.next() @@ -364,7 +428,8 @@ object BlockMatrix { } val rdds = bms.map(bm => bm.blocks) - val blockMatrixMetadataFields = bms.map(bm => (bm.blockSize, bm.nRows, bm.nCols, bm.gp.partitionIndexToBlockIndex)) + val blockMatrixMetadataFields = + bms.map(bm => (bm.blockSize, bm.nRows, bm.nCols, bm.gp.partitionIndexToBlockIndex)) val first = rdds(0) val nPartitions = rdds.map(_.getNumPartitions).sum val numDigits = digitsNeeded(nPartitions) @@ -372,19 +437,31 @@ object BlockMatrix { val ordd = new OriginUnionRDD[((Int, Int), BDM[Double]), ((Int, Int), BDM[Double])]( first.sparkContext, rdds, - (_, _, it) => it + (_, _, it) => it, ) - val partMap = ordd.partitions.map(part => part.asInstanceOf[OriginUnionPartition]). 
- map(oup => (oup.index, (oup.originIdx, oup.originPart.index))).toMap + val partMap = ordd.partitions.map(part => part.asInstanceOf[OriginUnionPartition]).map(oup => + (oup.index, (oup.originIdx, oup.originPart.index)) + ).toMap val writerRDD = ContextRDD.weaken(ordd).cmapPartitionsWithContextAndIndex { (i, ctx, it) => val (rddIndex, partIndex) = partMap(i) val trueIt = it(ctx) val rootPath = blockMatrixURI(rddIndex) val fileName = partFile(numDigits, partIndex, TaskContext.get) - val fileDataIterator = RichContextRDD.writeParts(ctx, rootPath, fileName, null, (_, _) => null, false, fs, tmpdir, trueIt, writeBlock) - fileDataIterator.map { fd => (fd, rddIndex) } + val fileDataIterator = RichContextRDD.writeParts( + ctx, + rootPath, + fileName, + null, + (_, _) => null, + false, + fs, + tmpdir, + trueIt, + writeBlock, + ) + fileDataIterator.map(fd => (fd, rddIndex)) } val rddNumberAndPartFiles = writerRDD.collect() @@ -396,7 +473,8 @@ object BlockMatrix { implicit val formats = defaultJSONFormats val (blockSize, nRows, nCols, maybeBlocks) = blockMatrixMetadataFields(rddIndex.toInt) jackson.Serialization.write( - BlockMatrixMetadata(blockSize, nRows, nCols, maybeBlocks, fileData.map(_.path)), os + BlockMatrixMetadata(blockSize, nRows, nCols, maybeBlocks, fileData.map(_.path)), + os, ) } @@ -406,24 +484,32 @@ object BlockMatrix { } // must be top-level for Jackson to serialize correctly -case class BlockMatrixMetadata(blockSize: Int, nRows: Long, nCols: Long, maybeFiltered: Option[IndexedSeq[Int]], partFiles: IndexedSeq[String]) - -class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], +case class BlockMatrixMetadata( + blockSize: Int, + nRows: Long, + nCols: Long, + maybeFiltered: Option[IndexedSeq[Int]], + partFiles: IndexedSeq[String], +) + +class BlockMatrix( + val blocks: RDD[((Int, Int), BDM[Double])], val blockSize: Int, val nRows: Long, - val nCols: Long) extends Serializable { + val nCols: Long, +) extends Serializable { import BlockMatrix._ - + require(blocks.partitioner.isDefined) require(blocks.partitioner.get.isInstanceOf[GridPartitioner]) val gp: GridPartitioner = blocks.partitioner.get.asInstanceOf[GridPartitioner] - + require(gp.blockSize == blockSize && gp.nRows == nRows && gp.nCols == nCols) - + val isSparse: Boolean = gp.partitionIndexToBlockIndex.isDefined - + def requireDense(name: String): Unit = if (isSparse) fatal(s"$name is not supported for block-sparse matrices.") @@ -434,15 +520,16 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], realizeBlocks(None) } else this - + // if Some(bis), unrealized blocks in bis are replaced with zero blocks // if None, all unrealized blocks are replaced with zero blocks def realizeBlocks(maybeBlocksToRealize: Option[IndexedSeq[Int]]): BlockMatrix = { val realizeGP = gp.copy(partitionIndexToBlockIndex = - if (maybeBlocksToRealize.exists(_.length == gp.maxNBlocks)) None else maybeBlocksToRealize) + if (maybeBlocksToRealize.exists(_.length == gp.maxNBlocks)) None else maybeBlocksToRealize + ) val newGP = gp.union(realizeGP) - + if (newGP.numPartitions == gp.numPartitions) this else { @@ -452,7 +539,8 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], Iterator.single((newGP.blockCoordinates(bi), lm)) } val oldToNewPI = gp.partitionIndexToBlockIndex.get.map(newGP.blockToPartition) - val newBlocks = blocks.supersetPartitions(oldToNewPI, newGP.numPartitions, newPIPartition, Some(newGP)) + val newBlocks = + blocks.supersetPartitions(oldToNewPI, newGP.numPartitions, newPIPartition, Some(newGP)) new 
BlockMatrix(newBlocks, blockSize, nRows, nCols) } @@ -463,18 +551,25 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], this else subsetBlocks(gp.intersect(gp.copy(partitionIndexToBlockIndex = Some(blocksToKeep)))) - + // assumes subsetGP blocks are subset of gp blocks, as with subsetGP = gp.intersect(gp2) def subsetBlocks(subsetGP: GridPartitioner): BlockMatrix = { if (subsetGP.numPartitions == gp.numPartitions) this else { assert(subsetGP.partitionIndexToBlockIndex.isDefined) - new BlockMatrix(blocks.subsetPartitions(subsetGP.partitionIndexToBlockIndex.get.map(gp.blockToPartition), Some(subsetGP)), - blockSize, nRows, nCols) + new BlockMatrix( + blocks.subsetPartitions( + subsetGP.partitionIndexToBlockIndex.get.map(gp.blockToPartition), + Some(subsetGP), + ), + blockSize, + nRows, + nCols, + ) } } - + // filter to blocks overlapping diagonal band of all elements with lower <= jj - ii <= upper // if not blocksOnly, also zero out all remaining elements outside band def filterBand(lower: Long, upper: Long, blocksOnly: Boolean): BlockMatrix = { @@ -487,65 +582,69 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], else filteredBM.zeroBand(lower, upper) } - - def zeroBand(lower: Long, upper: Long): BlockMatrix = { - val zeroedBlocks = blocks.mapPartitions( { it => - assert(it.hasNext) - val ((i, j), lm0) = it.next() - assert(!it.hasNext) - - val nRowsInBlock = lm0.rows - val nColsInBlock = lm0.cols - - val diagIndex = (j - i).toLong * blockSize - val lowestDiagIndex = diagIndex - (nRowsInBlock - 1) - val highestDiagIndex = diagIndex + (nColsInBlock - 1) - - if (lowestDiagIndex >= lower && highestDiagIndex <= upper) - Iterator.single(((i, j), lm0)) - else { - val lm = lm0.copy // avoidable? - - if (lower > lowestDiagIndex) { - val iiLeft = math.max(diagIndex - lower, 0).toInt - val iiRight = math.min(diagIndex - lower + nColsInBlock, nRowsInBlock).toInt - - var ii = iiLeft - var jj = math.max(lower - diagIndex, 0).toInt - while (ii < iiRight) { - lm(ii to ii, 0 until jj) := 0.0 - ii += 1 - jj += 1 + + def zeroBand(lower: Long, upper: Long): BlockMatrix = { + val zeroedBlocks = blocks.mapPartitions( + { it => + assert(it.hasNext) + val ((i, j), lm0) = it.next() + assert(!it.hasNext) + + val nRowsInBlock = lm0.rows + val nColsInBlock = lm0.cols + + val diagIndex = (j - i).toLong * blockSize + val lowestDiagIndex = diagIndex - (nRowsInBlock - 1) + val highestDiagIndex = diagIndex + (nColsInBlock - 1) + + if (lowestDiagIndex >= lower && highestDiagIndex <= upper) + Iterator.single(((i, j), lm0)) + else { + val lm = lm0.copy // avoidable? 
+ + if (lower > lowestDiagIndex) { + val iiLeft = math.max(diagIndex - lower, 0).toInt + val iiRight = math.min(diagIndex - lower + nColsInBlock, nRowsInBlock).toInt + + var ii = iiLeft + var jj = math.max(lower - diagIndex, 0).toInt + while (ii < iiRight) { + lm(ii to ii, 0 until jj) := 0.0 + ii += 1 + jj += 1 + } + + lm(iiRight until nRowsInBlock, ::) := 0.0 } - - lm(iiRight until nRowsInBlock, ::) := 0.0 - } - - if (upper < highestDiagIndex) { - val iiLeft = math.max(diagIndex - upper, 0).toInt - val iiRight = math.min(diagIndex - upper + nColsInBlock, nRowsInBlock).toInt - - lm(0 until iiLeft, ::) := 0.0 - - var ii = iiLeft - var jj = math.max(upper - diagIndex, 0).toInt + 1 - while (ii < iiRight) { - lm(ii to ii, jj until nColsInBlock) := 0.0 - ii += 1 - jj += 1 + + if (upper < highestDiagIndex) { + val iiLeft = math.max(diagIndex - upper, 0).toInt + val iiRight = math.min(diagIndex - upper + nColsInBlock, nRowsInBlock).toInt + + lm(0 until iiLeft, ::) := 0.0 + + var ii = iiLeft + var jj = math.max(upper - diagIndex, 0).toInt + 1 + while (ii < iiRight) { + lm(ii to ii, jj until nColsInBlock) := 0.0 + ii += 1 + jj += 1 + } } + Iterator.single(((i, j), lm)) } - Iterator.single(((i, j), lm)) - } - }, preservesPartitioning = true) - + }, + preservesPartitioning = true, + ) + new BlockMatrix(zeroedBlocks, blockSize, nRows, nCols) } - + // for row i, filter to indices [starts[i], stops[i]) by dropping non-overlapping blocks // if not blocksOnly, also zero out elements outside ranges in overlapping blocks // checked in Python: start >= 0 && start <= stop && stop <= nCols - def filterRowIntervals(starts: Array[Long], stops: Array[Long], blocksOnly: Boolean): BlockMatrix = { + def filterRowIntervals(starts: Array[Long], stops: Array[Long], blocksOnly: Boolean) + : BlockMatrix = { require(nRows <= Int.MaxValue) require(starts.length == nRows) require(stops.length == nRows) @@ -557,56 +656,59 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], else filteredBM.zeroRowIntervals(starts, stops) } - - def zeroRowIntervals(starts: Array[Long], stops: Array[Long]): BlockMatrix = { + + def zeroRowIntervals(starts: Array[Long], stops: Array[Long]): BlockMatrix = { val backend = HailContext.backend val startBlockIndexBc = backend.broadcast(starts.map(gp.indexBlockIndex)) val stopBlockIndexBc = backend.broadcast(stops.map(stop => (stop / blockSize).toInt)) val startBlockOffsetBc = backend.broadcast(starts.map(gp.indexBlockOffset)) val stopBlockOffsetsBc = backend.broadcast(stops.map(gp.indexBlockOffset)) - val zeroedBlocks = blocks.mapPartitions( { it => - assert(it.hasNext) - val ((i, j), lm0) = it.next() - assert(!it.hasNext) + val zeroedBlocks = blocks.mapPartitions( + { it => + assert(it.hasNext) + val ((i, j), lm0) = it.next() + assert(!it.hasNext) + + val lm = lm0.copy // avoidable? 
+ + val startBlockIndex = startBlockIndexBc.value + val stopBlockIndex = stopBlockIndexBc.value + val startBlockOffset = startBlockOffsetBc.value + val stopBlockOffset = stopBlockOffsetsBc.value + + val nRowsInBlock = lm.rows + val nColsInBlock = lm.cols + + var row = i * blockSize + var ii = 0 + while (ii < nRowsInBlock) { + val startBlock = startBlockIndex(row) + if (startBlock == j) + lm(ii to ii, 0 until startBlockOffset(row)) := 0.0 + else if (startBlock > j) + lm(ii to ii, ::) := 0.0 + val stopBlock = stopBlockIndex(row) + if (stopBlock == j) + lm(ii to ii, stopBlockOffset(row) until nColsInBlock) := 0.0 + else if (stopBlock < j) + lm(ii to ii, ::) := 0.0 + row += 1 + ii += 1 + } + + Iterator.single(((i, j), lm)) + }, + preservesPartitioning = true, + ) - val lm = lm0.copy // avoidable? - - val startBlockIndex = startBlockIndexBc.value - val stopBlockIndex = stopBlockIndexBc.value - val startBlockOffset = startBlockOffsetBc.value - val stopBlockOffset = stopBlockOffsetsBc.value - - val nRowsInBlock = lm.rows - val nColsInBlock = lm.cols - - var row = i * blockSize - var ii = 0 - while (ii < nRowsInBlock) { - val startBlock = startBlockIndex(row) - if (startBlock == j) - lm(ii to ii, 0 until startBlockOffset(row)) := 0.0 - else if (startBlock > j) - lm(ii to ii, ::) := 0.0 - val stopBlock = stopBlockIndex(row) - if (stopBlock == j) - lm(ii to ii, stopBlockOffset(row) until nColsInBlock) := 0.0 - else if (stopBlock < j) - lm(ii to ii, ::) := 0.0 - row += 1 - ii += 1 - } - - Iterator.single(((i, j), lm)) - }, preservesPartitioning = true) - new BlockMatrix(zeroedBlocks, blockSize, nRows, nCols) } - + def filterRectangles(flattenedRectangles: Array[Long]): BlockMatrix = { require(flattenedRectangles.length % 4 == 0) val rectangles = flattenedRectangles.grouped(4).toArray - + filterBlocks(gp.rectanglesBlocks(rectangles)) } @@ -615,7 +717,8 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], output: String, rectangles: Array[Array[Long]], delimiter: String, - binary: Boolean) { + binary: Boolean, + ): Unit = { val writeRectangleBinary = (uos: OutputStream, dm: BDM[Double]) => { val os = new DoubleOutputBuffer(uos, RichArray.defaultBufSize) @@ -662,149 +765,154 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], // element-wise ops def unary_+(): M = this - - def unary_-(): M = blockMap(-_, - "negation", - reqDense = false) + + def unary_-(): M = blockMap(-_, "negation", reqDense = false) def add(that: M): M = if (sameBlocks(that)) { - blockMap2(that, _ + _, - "addition", - reqDense = false) + blockMap2(that, _ + _, "addition", reqDense = false) } else { - val addBlocks = new BlockMatrixUnionOpRDD(this, that, + val addBlocks = new BlockMatrixUnionOpRDD( + this, + that, _ match { case (Some(a), Some(b)) => a + b case (Some(a), None) => a case (None, Some(b)) => b case (None, None) => fatal("not possible for union") - } + }, ) new BlockMatrix(addBlocks, blockSize, nRows, nCols) } - + def sub(that: M): M = if (sameBlocks(that)) { - blockMap2(that, _ - _, - "subtraction", - reqDense = false) + blockMap2(that, _ - _, "subtraction", reqDense = false) } else { - val subBlocks = new BlockMatrixUnionOpRDD(this, that, + val subBlocks = new BlockMatrixUnionOpRDD( + this, + that, _ match { case (Some(a), Some(b)) => a - b case (Some(a), None) => a case (None, Some(b)) => -b case (None, None) => fatal("not possible for union") - } + }, ) new BlockMatrix(subBlocks, blockSize, nRows, nCols) } - + def mul(that: M): M = { val newGP = gp.intersect(that.gp) 
subsetBlocks(newGP).blockMap2( - that.subsetBlocks(newGP), _ *:* _, + that.subsetBlocks(newGP), + _ *:* _, "element-wise multiplication", - reqDense = false) + reqDense = false, + ) } - - def div(that: M): M = blockMap2(that, _ /:/ _, - "element-wise division") - + + def div(that: M): M = blockMap2(that, _ /:/ _, "element-wise division") + // row broadcast - def rowVectorAdd(a: Array[Double]): M = densify().rowVectorOp((lm, lv) => lm(*, ::) + lv, - "broadcasted addition of row-vector")(a) - - def rowVectorSub(a: Array[Double]): M = densify().rowVectorOp((lm, lv) => lm(*, ::) - lv, - "broadcasted subtraction of row-vector")(a) - - def rowVectorMul(a: Array[Double]): M = rowVectorOp((lm, lv) => lm(*, ::) *:* lv, + def rowVectorAdd(a: Array[Double]): M = + densify().rowVectorOp((lm, lv) => lm(*, ::) + lv, "broadcasted addition of row-vector")(a) + + def rowVectorSub(a: Array[Double]): M = + densify().rowVectorOp((lm, lv) => lm(*, ::) - lv, "broadcasted subtraction of row-vector")(a) + + def rowVectorMul(a: Array[Double]): M = rowVectorOp( + (lm, lv) => lm(*, ::) *:* lv, "broadcasted multiplication by row-vector containing nan, or infinity", - reqDense = a.exists(i => i.isNaN | i.isInfinity))(a) - - def rowVectorDiv(a: Array[Double]): M = rowVectorOp((lm, lv) => lm(*, ::) /:/ lv, + reqDense = a.exists(i => i.isNaN | i.isInfinity), + )(a) + + def rowVectorDiv(a: Array[Double]): M = rowVectorOp( + (lm, lv) => lm(*, ::) /:/ lv, "broadcasted division by row-vector containing zero, nan, or infinity", - reqDense = a.exists(i => i == 0.0 | i.isNaN | i.isInfinity))(a) - - def reverseRowVectorSub(a: Array[Double]): M = densify().rowVectorOp((lm, lv) => lm(*, ::).map(lv - _), - "broadcasted row-vector minus block matrix")(a) - - def reverseRowVectorDiv(a: Array[Double]): M = rowVectorOp((lm, lv) => lm(*, ::).map(lv /:/ _), - "broadcasted row-vector divided by block matrix")(a) - + reqDense = a.exists(i => i == 0.0 | i.isNaN | i.isInfinity), + )(a) + + def reverseRowVectorSub(a: Array[Double]): M = densify().rowVectorOp( + (lm, lv) => lm(*, ::).map(lv - _), + "broadcasted row-vector minus block matrix", + )(a) + + def reverseRowVectorDiv(a: Array[Double]): M = rowVectorOp( + (lm, lv) => lm(*, ::).map(lv /:/ _), + "broadcasted row-vector divided by block matrix", + )(a) + // column broadcast - def colVectorAdd(a: Array[Double]): M = densify().colVectorOp((lm, lv) => lm(::, *) + lv, - "broadcasted addition of column-vector")(a) - - def colVectorSub(a: Array[Double]): M = densify().colVectorOp((lm, lv) => lm(::, *) - lv, - "broadcasted subtraction of column-vector")(a) - - def colVectorMul(a: Array[Double]): M = colVectorOp((lm, lv) => lm(::, *) *:* lv, + def colVectorAdd(a: Array[Double]): M = + densify().colVectorOp((lm, lv) => lm(::, *) + lv, "broadcasted addition of column-vector")(a) + + def colVectorSub(a: Array[Double]): M = + densify().colVectorOp((lm, lv) => lm(::, *) - lv, "broadcasted subtraction of column-vector")(a) + + def colVectorMul(a: Array[Double]): M = colVectorOp( + (lm, lv) => lm(::, *) *:* lv, "broadcasted multiplication column-vector containing nan or infinity", - reqDense = a.exists(i => i.isNaN | i.isInfinity))(a) - - def colVectorDiv(a: Array[Double]): M = colVectorOp((lm, lv) => lm(::, *) /:/ lv, + reqDense = a.exists(i => i.isNaN | i.isInfinity), + )(a) + + def colVectorDiv(a: Array[Double]): M = colVectorOp( + (lm, lv) => lm(::, *) /:/ lv, "broadcasted division by column-vector containing zero, nan, or infinity", - reqDense = a.exists(i => i == 0.0 | i.isNaN | i.isInfinity))(a) 
+ reqDense = a.exists(i => i == 0.0 | i.isNaN | i.isInfinity), + )(a) - def reverseColVectorSub(a: Array[Double]): M = densify().colVectorOp((lm, lv) => lm(::, *).map(lv - _), - "broadcasted column-vector minus block matrix")(a) + def reverseColVectorSub(a: Array[Double]): M = densify().colVectorOp( + (lm, lv) => lm(::, *).map(lv - _), + "broadcasted column-vector minus block matrix", + )(a) - def reverseColVectorDiv(a: Array[Double]): M = colVectorOp((lm, lv) => lm(::, *).map(lv /:/ _), - "broadcasted column-vector divided by block matrix")(a) + def reverseColVectorDiv(a: Array[Double]): M = colVectorOp( + (lm, lv) => lm(::, *).map(lv /:/ _), + "broadcasted column-vector divided by block matrix", + )(a) // scalar - def scalarAdd(i: Double): M = densify().blockMap(_ + i, - "scalar addition") - - def scalarSub(i: Double): M = densify().blockMap(_ - i, - "scalar subtraction") - - def scalarMul(i: Double): M = blockMap(_ *:* i, - s"multiplication by scalar $i", - reqDense = i.isNaN | i.isInfinity) - - def scalarDiv(i: Double): M = blockMap(_ /:/ i, - s"division by scalar $i", - reqDense = i == 0.0 | i.isNaN | i.isInfinity) - - def reverseScalarSub(i: Double): M = densify().blockMap(i - _, - s"scalar minus block matrix") - - def reverseScalarDiv(i: Double): M = blockMap(i /:/ _, - s"scalar divided by block matrix") + def scalarAdd(i: Double): M = densify().blockMap(_ + i, "scalar addition") + + def scalarSub(i: Double): M = densify().blockMap(_ - i, "scalar subtraction") + + def scalarMul(i: Double): M = + blockMap(_ *:* i, s"multiplication by scalar $i", reqDense = i.isNaN | i.isInfinity) + + def scalarDiv(i: Double): M = + blockMap(_ /:/ i, s"division by scalar $i", reqDense = i == 0.0 | i.isNaN | i.isInfinity) + + def reverseScalarSub(i: Double): M = densify().blockMap(i - _, s"scalar minus block matrix") + + def reverseScalarDiv(i: Double): M = blockMap(i /:/ _, s"scalar divided by block matrix") // other element-wise ops - def sqrt(): M = blockMap(breezeSqrt(_), - "sqrt", - reqDense = false) + def sqrt(): M = blockMap(breezeSqrt(_), "sqrt", reqDense = false) - def ceil(): M = blockMap(breeze.numerics.ceil(_), - "ceil", - reqDense = false) + def ceil(): M = blockMap(breeze.numerics.ceil(_), "ceil", reqDense = false) - def floor(): M = blockMap(breeze.numerics.floor(_), - "floor", - reqDense = false) + def floor(): M = blockMap(breeze.numerics.floor(_), "floor", reqDense = false) - def pow(exponent: Double): M = blockMap(breezePow(_, exponent), + def pow(exponent: Double): M = blockMap( + breezePow(_, exponent), s"exponentiation by negative power $exponent", - reqDense = exponent < 0) - - def log(): M = blockMap(breezeLog(_), - "natural logarithm") - - def abs(): M = blockMap(breezeAbs(_), - "absolute value", - reqDense = false) - + reqDense = exponent < 0, + ) + + def log(): M = blockMap(breezeLog(_), "natural logarithm") + + def abs(): M = blockMap(breezeAbs(_), "absolute value", reqDense = false) + // matrix ops - def dot(that: M): M = new BlockMatrix(new BlockMatrixMultiplyRDD(this, that), blockSize, nRows, that.nCols) + def dot(that: M): M = + new BlockMatrix(new BlockMatrixMultiplyRDD(this, that), blockSize, nRows, that.nCols) def dot(lm: BDM[Double]): M = { - require(nCols == lm.rows, - s"incompatible matrix dimensions: ${ nRows } x ${ nCols } and ${ lm.rows } x ${ lm.cols }") + require( + nCols == lm.rows, + s"incompatible matrix dimensions: $nRows x $nCols and ${lm.rows} x ${lm.cols}", + ) dot(BlockMatrix.fromBreezeMatrix(lm, blockSize)) } @@ -825,10 +933,10 @@ class 
BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], } val result = new Array[Double](nDiagElements) - + val nDiagBlocks = math.min(gp.nBlockRows, gp.nBlockCols) val diagBlocks = Array.tabulate(nDiagBlocks)(i => gp.coordinatesBlock(i, i)) - + filterBlocks(diagBlocks).blocks .map { case ((i, j), lm) => assert(i == j) @@ -836,11 +944,17 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], } .collect() .foreach { case (i, a) => System.arraycopy(a, 0, result, i * blockSize, a.length) } - + result } - def write(ctx: ExecuteContext, uri: String, overwrite: Boolean = false, forceRowMajor: Boolean = false, stageLocally: Boolean = false) { + def write( + ctx: ExecuteContext, + uri: String, + overwrite: Boolean = false, + forceRowMajor: Boolean = false, + stageLocally: Boolean = false, + ): Unit = { val fs = ctx.fs if (overwrite) fs.delete(uri, recursive = true) @@ -867,17 +981,24 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], using(new DataOutputStream(fs.create(uri + metadataRelativePath))) { os => implicit val formats = defaultJSONFormats jackson.Serialization.write( - BlockMatrixMetadata(blockSize, nRows, nCols, gp.partitionIndexToBlockIndex, fileData.map(_.path)), - os) + BlockMatrixMetadata( + blockSize, + nRows, + nCols, + gp.partitionIndexToBlockIndex, + fileData.map(_.path), + ), + os, + ) } using(fs.create(uri + "/_SUCCESS"))(out => ()) val nBlocks = fileData.length assert(nBlocks == fileData.map(_.rowsWritten).sum) - info(s"wrote matrix with $nRows ${ plural(nRows, "row") } " + - s"and $nCols ${ plural(nCols, "column") } " + - s"as $nBlocks ${ plural(nBlocks, "block") } " + + info(s"wrote matrix with $nRows ${plural(nRows, "row")} " + + s"and $nCols ${plural(nCols, "column")} " + + s"as $nBlocks ${plural(nBlocks, "block")} " + s"of size $blockSize to $uri") } @@ -892,12 +1013,13 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], } def persist(storageLevel: String): this.type = { - val level = try { - StorageLevel.fromString(storageLevel) - } catch { - case e: IllegalArgumentException => - fatal(s"unknown StorageLevel '$storageLevel'") - } + val level = + try + StorageLevel.fromString(storageLevel) + catch { + case _: IllegalArgumentException => + fatal(s"unknown StorageLevel '$storageLevel'") + } persist(level) } @@ -907,12 +1029,21 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], } def toBreezeMatrix(): BDM[Double] = { - require(nRows <= Int.MaxValue, "The number of rows of this matrix should be less than or equal to " + - s"Int.MaxValue. Currently nRows: $nRows") - require(nCols <= Int.MaxValue, "The number of columns of this matrix should be less than or equal to " + - s"Int.MaxValue. Currently nCols: $nCols") - require(nRows * nCols <= Int.MaxValue, "The length of the values array must be " + - s"less than or equal to Int.MaxValue. Currently nRows * nCols: ${ nRows * nCols }") + require( + nRows <= Int.MaxValue, + "The number of rows of this matrix should be less than or equal to " + + s"Int.MaxValue. Currently nRows: $nRows", + ) + require( + nCols <= Int.MaxValue, + "The number of columns of this matrix should be less than or equal to " + + s"Int.MaxValue. Currently nCols: $nCols", + ) + require( + nRows * nCols <= Int.MaxValue, + "The length of the values array must be " + + s"less than or equal to Int.MaxValue. 
Currently nRows * nCols: ${nRows * nCols}", + ) val nRowsInt = nRows.toInt val nColsInt = nCols.toInt val localBlocks = blocks.collect() @@ -934,77 +1065,84 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], new BDM(nRowsInt, nColsInt, data) } - private def requireZippable(that: M, name: String = "operation") { - require(nRows == that.nRows, - s"$name requires same number of rows, but actually: ${ nRows }x${ nCols }, ${ that.nRows }x${ that.nCols }") - require(nCols == that.nCols, - s"$name requires same number of cols, but actually: ${ nRows }x${ nCols }, ${ that.nRows }x${ that.nCols }") - require(blockSize == that.blockSize, - s"$name requires same block size, but actually: $blockSize and ${ that.blockSize }") + private def requireZippable(that: M, name: String = "operation"): Unit = { + require( + nRows == that.nRows, + s"$name requires same number of rows, but actually: ${nRows}x$nCols, ${that.nRows}x${that.nCols}", + ) + require( + nCols == that.nCols, + s"$name requires same number of cols, but actually: ${nRows}x$nCols, ${that.nRows}x${that.nCols}", + ) + require( + blockSize == that.blockSize, + s"$name requires same block size, but actually: $blockSize and ${that.blockSize}", + ) if (!sameBlocks(that)) fatal(s"$name requires block matrices to have the same set of blocks present") } - - private def sameBlocks(that: M): Boolean = { + + private def sameBlocks(that: M): Boolean = (gp.partitionIndexToBlockIndex, that.gp.partitionIndexToBlockIndex) match { case (Some(bis), Some(bis2)) => bis sameElements bis2 case (None, None) => true case _ => false } - } - - def blockMap(op: BDM[Double] => BDM[Double], - name: String = "operation", - reqDense: Boolean = true): M = { + + def blockMap(op: BDM[Double] => BDM[Double], name: String = "operation", reqDense: Boolean = true) + : M = { if (reqDense) requireDense(name) new BlockMatrix(blocks.mapValues(op), blockSize, nRows, nCols) } - - def blockMapWithIndex(op: ((Int, Int), BDM[Double]) => BDM[Double], + + def blockMapWithIndex( + op: ((Int, Int), BDM[Double]) => BDM[Double], name: String = "operation", - reqDense: Boolean = true): M = { + reqDense: Boolean = true, + ): M = { if (reqDense) requireDense(name) new BlockMatrix(blocks.mapValuesWithKey(op), blockSize, nRows, nCols) } - def blockMap2(that: M, + def blockMap2( + that: M, op: (BDM[Double], BDM[Double]) => BDM[Double], name: String = "operation", - reqDense: Boolean = true): M = { + reqDense: Boolean = true, + ): M = { if (reqDense) { requireDense(name) that.requireDense(name) } requireZippable(that) - val newBlocks = blocks.zipPartitions(that.blocks, preservesPartitioning = true) { (thisIter, thatIter) => - new Iterator[((Int, Int), BDM[Double])] { - def hasNext: Boolean = { - assert(thisIter.hasNext == thatIter.hasNext) - thisIter.hasNext - } + val newBlocks = + blocks.zipPartitions(that.blocks, preservesPartitioning = true) { (thisIter, thatIter) => + new Iterator[((Int, Int), BDM[Double])] { + def hasNext: Boolean = { + assert(thisIter.hasNext == thatIter.hasNext) + thisIter.hasNext + } - def next(): ((Int, Int), BDM[Double]) = { - val ((i1, j1), lm1) = thisIter.next() - val ((i2, j2), lm2) = thatIter.next() - assertCompatibleLocalMatrix(lm1) - assertCompatibleLocalMatrix(lm2) - assert(i1 == i2, s"$i1 $i2") - assert(j1 == j2, s"$j1 $j2") - val lm = op(lm1, lm2) - assert(lm.rows == lm1.rows) - assert(lm.cols == lm1.cols) - ((i1, j1), lm) + def next(): ((Int, Int), BDM[Double]) = { + val ((i1, j1), lm1) = thisIter.next() + val ((i2, j2), lm2) = thatIter.next() + 
assertCompatibleLocalMatrix(lm1) + assertCompatibleLocalMatrix(lm2) + assert(i1 == i2, s"$i1 $i2") + assert(j1 == j2, s"$j1 $j2") + val lm = op(lm1, lm2) + assert(lm.rows == lm1.rows) + assert(lm.cols == lm1.cols) + ((i1, j1), lm) + } } } - } new BlockMatrix(newBlocks, blockSize, nRows, nCols) } - def map(op: Double => Double, - name: String = "operation", - reqDense: Boolean = true): M = { + def map(op: Double => Double, name: String = "operation", reqDense: Boolean = true): M = { if (reqDense) requireDense(name) val newBlocks = blocks.mapValues { lm => @@ -1021,64 +1159,71 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], new BlockMatrix(newBlocks, blockSize, nRows, nCols) } - def map2(that: M, + def map2( + that: M, op: (Double, Double) => Double, name: String = "operation", - reqDense: Boolean = true): M = { + reqDense: Boolean = true, + ): M = { if (reqDense) { requireDense(name) that.requireDense(name) } requireZippable(that) - val newBlocks = blocks.zipPartitions(that.blocks, preservesPartitioning = true) { (thisIter, thatIter) => - new Iterator[((Int, Int), BDM[Double])] { - def hasNext: Boolean = { - assert(thisIter.hasNext == thatIter.hasNext) - thisIter.hasNext - } + val newBlocks = + blocks.zipPartitions(that.blocks, preservesPartitioning = true) { (thisIter, thatIter) => + new Iterator[((Int, Int), BDM[Double])] { + def hasNext: Boolean = { + assert(thisIter.hasNext == thatIter.hasNext) + thisIter.hasNext + } - def next(): ((Int, Int), BDM[Double]) = { - val ((i1, j1), lm1) = thisIter.next() - val ((i2, j2), lm2) = thatIter.next() - assertCompatibleLocalMatrix(lm1) - assertCompatibleLocalMatrix(lm2) - assert(i1 == i2, s"$i1 $i2") - assert(j1 == j2, s"$j1 $j2") - val nRows = lm1.rows - val nCols = lm1.cols - val src1 = lm1.data - val src2 = lm2.data - val dst = new Array[Double](src1.length) - if (lm1.isTranspose == lm2.isTranspose) { - var k = 0 - while (k < src1.length) { - dst(k) = op(src1(k), src2(k)) - k += 1 - } - } else { - val length = src1.length - var k1 = 0 - var k2 = 0 - while (k1 < length) { - while (k2 < length) { - dst(k1) = op(src1(k1), src2(k2)) - k1 += 1 - k2 += lm2.majorStride + def next(): ((Int, Int), BDM[Double]) = { + val ((i1, j1), lm1) = thisIter.next() + val ((i2, j2), lm2) = thatIter.next() + assertCompatibleLocalMatrix(lm1) + assertCompatibleLocalMatrix(lm2) + assert(i1 == i2, s"$i1 $i2") + assert(j1 == j2, s"$j1 $j2") + val nRows = lm1.rows + val nCols = lm1.cols + val src1 = lm1.data + val src2 = lm2.data + val dst = new Array[Double](src1.length) + if (lm1.isTranspose == lm2.isTranspose) { + var k = 0 + while (k < src1.length) { + dst(k) = op(src1(k), src2(k)) + k += 1 + } + } else { + val length = src1.length + var k1 = 0 + var k2 = 0 + while (k1 < length) { + while (k2 < length) { + dst(k1) = op(src1(k1), src2(k2)) + k1 += 1 + k2 += lm2.majorStride + } + k2 += 1 - length } - k2 += 1 - length } + ((i1, j1), new BDM(nRows, nCols, dst, 0, lm1.majorStride, lm1.isTranspose)) } - ((i1, j1), new BDM(nRows, nCols, dst, 0, lm1.majorStride, lm1.isTranspose)) } } - } new BlockMatrix(newBlocks, blockSize, nRows, nCols) } - def map4(bm2: M, bm3: M, bm4: M, + def map4( + bm2: M, + bm3: M, + bm4: M, op: (Double, Double, Double, Double) => Double, name: String = "operation", - reqDense: Boolean = true): M = { + reqDense: Boolean = true, + ): M = { if (reqDense) { requireDense(name) bm2.requireDense(name) @@ -1088,73 +1233,79 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], requireZippable(bm2) requireZippable(bm3) 
requireZippable(bm4) - val newBlocks = blocks.zipPartitions(bm2.blocks, bm3.blocks, bm4.blocks, preservesPartitioning = true) { (it1, it2, it3, it4) => - new Iterator[((Int, Int), BDM[Double])] { - def hasNext: Boolean = { - assert(it1.hasNext == it2.hasNext) - assert(it1.hasNext == it3.hasNext) - assert(it1.hasNext == it4.hasNext) - it1.hasNext - } - - def next(): ((Int, Int), BDM[Double]) = { - val ((i1, j1), lm1) = it1.next() - val ((i2, j2), lm2) = it2.next() - val ((i3, j3), lm3) = it3.next() - val ((i4, j4), lm4) = it4.next() - assertCompatibleLocalMatrix(lm1) - assertCompatibleLocalMatrix(lm2) - assertCompatibleLocalMatrix(lm3) - assertCompatibleLocalMatrix(lm4) - assert(i1 == i2, s"$i1 $i2") - assert(j1 == j2, s"$j1 $j2") - assert(i1 == i3, s"$i1 $i3") - assert(j1 == j3, s"$j1 $j3") - assert(i1 == i4, s"$i1 $i4") - assert(j1 == j4, s"$j1 $j4") - val nRows = lm1.rows - val nCols = lm1.cols - val src1 = lm1.data - val src2 = lm2.data - val src3 = lm3.data - val src4 = lm4.data - val dst = new Array[Double](src1.length) - if (lm1.isTranspose == lm2.isTranspose - && lm1.isTranspose == lm3.isTranspose - && lm1.isTranspose == lm4.isTranspose) { - var k = 0 - while (k < src1.length) { - dst(k) = op(src1(k), src2(k), src3(k), src4(k)) - k += 1 + val newBlocks = + blocks.zipPartitions(bm2.blocks, bm3.blocks, bm4.blocks, preservesPartitioning = true) { + (it1, it2, it3, it4) => + new Iterator[((Int, Int), BDM[Double])] { + def hasNext: Boolean = { + assert(it1.hasNext == it2.hasNext) + assert(it1.hasNext == it3.hasNext) + assert(it1.hasNext == it4.hasNext) + it1.hasNext } - } else { - // FIXME: code gen the optimal tree on driver? - val length = src1.length - val lm1MinorSize = length / lm1.majorStride - var k1 = 0 - var kt = 0 - while (k1 < length) { - while (kt < length) { - val v2 = if (lm1.isTranspose == lm2.isTranspose) src2(k1) else src2(kt) - val v3 = if (lm1.isTranspose == lm3.isTranspose) src3(k1) else src3(kt) - val v4 = if (lm1.isTranspose == lm4.isTranspose) src4(k1) else src4(kt) - dst(k1) = op(src1(k1), v2, v3, v4) - k1 += 1 - kt += lm1MinorSize + + def next(): ((Int, Int), BDM[Double]) = { + val ((i1, j1), lm1) = it1.next() + val ((i2, j2), lm2) = it2.next() + val ((i3, j3), lm3) = it3.next() + val ((i4, j4), lm4) = it4.next() + assertCompatibleLocalMatrix(lm1) + assertCompatibleLocalMatrix(lm2) + assertCompatibleLocalMatrix(lm3) + assertCompatibleLocalMatrix(lm4) + assert(i1 == i2, s"$i1 $i2") + assert(j1 == j2, s"$j1 $j2") + assert(i1 == i3, s"$i1 $i3") + assert(j1 == j3, s"$j1 $j3") + assert(i1 == i4, s"$i1 $i4") + assert(j1 == j4, s"$j1 $j4") + val nRows = lm1.rows + val nCols = lm1.cols + val src1 = lm1.data + val src2 = lm2.data + val src3 = lm3.data + val src4 = lm4.data + val dst = new Array[Double](src1.length) + if ( + lm1.isTranspose == lm2.isTranspose + && lm1.isTranspose == lm3.isTranspose + && lm1.isTranspose == lm4.isTranspose + ) { + var k = 0 + while (k < src1.length) { + dst(k) = op(src1(k), src2(k), src3(k), src4(k)) + k += 1 + } + } else { + // FIXME: code gen the optimal tree on driver? 
+ val length = src1.length + val lm1MinorSize = length / lm1.majorStride + var k1 = 0 + var kt = 0 + while (k1 < length) { + while (kt < length) { + val v2 = if (lm1.isTranspose == lm2.isTranspose) src2(k1) else src2(kt) + val v3 = if (lm1.isTranspose == lm3.isTranspose) src3(k1) else src3(kt) + val v4 = if (lm1.isTranspose == lm4.isTranspose) src4(k1) else src4(kt) + dst(k1) = op(src1(k1), v2, v3, v4) + k1 += 1 + kt += lm1MinorSize + } + kt += 1 - length + } } - kt += 1 - length + ((i1, j1), new BDM(nRows, nCols, dst, 0, lm1.majorStride, lm1.isTranspose)) } } - ((i1, j1), new BDM(nRows, nCols, dst, 0, lm1.majorStride, lm1.isTranspose)) - } } - } new BlockMatrix(newBlocks, blockSize, nRows, nCols) } - def mapWithIndex(op: (Long, Long, Double) => Double, + def mapWithIndex( + op: (Long, Long, Double) => Double, name: String = "operation", - reqDense: Boolean = true): M = { + reqDense: Boolean = true, + ): M = { if (reqDense) requireDense(name) val newBlocks = blocks.mapValuesWithKey { case ((i, j), lm) => @@ -1176,92 +1327,122 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], new BlockMatrix(newBlocks, blockSize, nRows, nCols) } - def map2WithIndex(that: M, + def map2WithIndex( + that: M, op: (Long, Long, Double, Double) => Double, name: String = "operation", - reqDense: Boolean = true): M = { + reqDense: Boolean = true, + ): M = { if (reqDense) { requireDense(name) that.requireDense(name) } requireZippable(that) - val newBlocks = blocks.zipPartitions(that.blocks, preservesPartitioning = true) { (thisIter, thatIter) => - new Iterator[((Int, Int), BDM[Double])] { - def hasNext: Boolean = { - assert(thisIter.hasNext == thatIter.hasNext) - thisIter.hasNext - } + val newBlocks = + blocks.zipPartitions(that.blocks, preservesPartitioning = true) { (thisIter, thatIter) => + new Iterator[((Int, Int), BDM[Double])] { + def hasNext: Boolean = { + assert(thisIter.hasNext == thatIter.hasNext) + thisIter.hasNext + } - def next(): ((Int, Int), BDM[Double]) = { - val ((i1, j1), lm1) = thisIter.next() - val ((i2, j2), lm2) = thatIter.next() - assert(i1 == i2, s"$i1 $i2") - assert(j1 == j2, s"$j1 $j2") - val iOffset = i1.toLong * blockSize - val jOffset = j1.toLong * blockSize - val size = lm1.cols * lm1.rows - val result = new Array[Double](size) - var jj = 0 - while (jj < lm1.cols) { - var ii = 0 - while (ii < lm1.rows) { - result(ii + jj * lm1.rows) = op(iOffset + ii, jOffset + jj, lm1(ii, jj), lm2(ii, jj)) - ii += 1 + def next(): ((Int, Int), BDM[Double]) = { + val ((i1, j1), lm1) = thisIter.next() + val ((i2, j2), lm2) = thatIter.next() + assert(i1 == i2, s"$i1 $i2") + assert(j1 == j2, s"$j1 $j2") + val iOffset = i1.toLong * blockSize + val jOffset = j1.toLong * blockSize + val size = lm1.cols * lm1.rows + val result = new Array[Double](size) + var jj = 0 + while (jj < lm1.cols) { + var ii = 0 + while (ii < lm1.rows) { + result(ii + jj * lm1.rows) = + op(iOffset + ii, jOffset + jj, lm1(ii, jj), lm2(ii, jj)) + ii += 1 + } + jj += 1 } - jj += 1 + ((i1, j1), new BDM(lm1.rows, lm1.cols, result)) } - ((i1, j1), new BDM(lm1.rows, lm1.cols, result)) } } - } new BlockMatrix(newBlocks, blockSize, nRows, nCols) } - def colVectorOp(op: (BDM[Double], BDV[Double]) => BDM[Double], + def colVectorOp( + op: (BDM[Double], BDV[Double]) => BDM[Double], name: String = "operation", - reqDense: Boolean = true): Array[Double] => M = { - a => val v = BDV(a) - require(v.length == nRows, s"vector length must equal nRows: ${ v.length }, $nRows") + reqDense: Boolean = true, + ): Array[Double] => M = { + a => + val 
v = BDV(a) + require(v.length == nRows, s"vector length must equal nRows: ${v.length}, $nRows") val vBc = HailContext.backend.broadcast(v) - blockMapWithIndex( { case ((i, _), lm) => - val lv = gp.vectorOnBlockRow(vBc.value, i) - op(lm, lv) - }, name, reqDense = reqDense) + blockMapWithIndex( + { case ((i, _), lm) => + val lv = gp.vectorOnBlockRow(vBc.value, i) + op(lm, lv) + }, + name, + reqDense = reqDense, + ) } - def rowVectorOp(op: (BDM[Double], BDV[Double]) => BDM[Double], + def rowVectorOp( + op: (BDM[Double], BDV[Double]) => BDM[Double], name: String = "operation", - reqDense: Boolean = true): Array[Double] => M = { - a => val v = BDV(a) - require(v.length == nCols, s"vector length must equal nCols: ${ v.length }, $nCols") + reqDense: Boolean = true, + ): Array[Double] => M = { + a => + val v = BDV(a) + require(v.length == nCols, s"vector length must equal nCols: ${v.length}, $nCols") val vBc = HailContext.backend.broadcast(v) - blockMapWithIndex( { case ((_, j), lm) => - val lv = gp.vectorOnBlockCol(vBc.value, j) - op(lm, lv) - }, name, reqDense = reqDense) + blockMapWithIndex( + { case ((_, j), lm) => + val lv = gp.vectorOnBlockCol(vBc.value, j) + op(lm, lv) + }, + name, + reqDense = reqDense, + ) } def reduce(blockOp: BDM[Double] => Double, scalarOp: (Double, Double) => Double): Double = - blocks - .map { case ((i, j), lm) => blockOp(lm) } - .fold(0.0)(scalarOp) - - def rowReduce(blockOp: BDM[Double] => BDV[Double], vectorOp: (BDV[Double], BDV[Double]) => BDV[Double]): BlockMatrix = + blocks + .map { case ((_, _), lm) => blockOp(lm) } + .fold(0.0)(scalarOp) + + def rowReduce( + blockOp: BDM[Double] => BDV[Double], + vectorOp: (BDV[Double], BDV[Double]) => BDV[Double], + ): BlockMatrix = new BlockMatrix( blocks - .map { case ((i, j), lm) => ((0, j), blockOp(lm)) } + .map { case ((_, j), lm) => ((0, j), blockOp(lm)) } .reduceByKey(GridPartitioner(blockSize, 1, nCols, gp.maybeBlockCols()), vectorOp) .mapValues(v => new BDM[Double](1, v.length, v.data)), - blockSize, 1, nCols) - - def colReduce(blockOp: BDM[Double] => BDV[Double], vectorOp: (BDV[Double], BDV[Double]) => BDV[Double]): BlockMatrix = + blockSize, + 1, + nCols, + ) + + def colReduce( + blockOp: BDM[Double] => BDV[Double], + vectorOp: (BDV[Double], BDV[Double]) => BDV[Double], + ): BlockMatrix = new BlockMatrix( blocks - .map { case ((i, j), lm) => ((i, 0), blockOp(lm)) } + .map { case ((i, _), lm) => ((i, 0), blockOp(lm)) } .reduceByKey(GridPartitioner(blockSize, nRows, 1, gp.maybeBlockRows()), vectorOp) .mapValues(v => new BDM[Double](v.length, 1, v.data)), - blockSize, nRows, 1) - + blockSize, + nRows, + 1, + ) + def toIndexedRowMatrix(): IndexedRowMatrix = { require(nCols <= Integer.MAX_VALUE) val nColsInt = nCols.toInt @@ -1282,15 +1463,18 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], l } - new IndexedRowMatrix(blocks.flatMap { case ((i, j), lm) => - val iOffset = i * blockSize - val jOffset = j * blockSize - - for (k <- 0 until lm.rows) - yield (k + iOffset, (jOffset, lm(k, ::).inner.toArray)) - }.aggregateByKey(new Array[Double](nColsInt))(seqOp, combOp) - .map { case (i, a) => IndexedRow(i, BDV(a)) }, - nRows, nColsInt) + new IndexedRowMatrix( + blocks.flatMap { case ((i, j), lm) => + val iOffset = i * blockSize + val jOffset = j * blockSize + + for (k <- 0 until lm.rows) + yield (k + iOffset, (jOffset, lm(k, ::).inner.toArray)) + }.aggregateByKey(new Array[Double](nColsInt))(seqOp, combOp) + .map { case (i, a) => IndexedRow(i, BDV(a)) }, + nRows, + nColsInt, + ) } def getElement(row: Long, 
col: Long): Double = { @@ -1309,56 +1493,66 @@ class BlockMatrix(val blocks: RDD[((Int, Int), BDM[Double])], } else 0.0 } - - def filterRows(keep: Array[Long]): BlockMatrix = { + + def filterRows(keep: Array[Long]): BlockMatrix = new BlockMatrix(new BlockMatrixFilterRowsRDD(this, keep), blockSize, keep.length, nCols) - } - def filterCols(keep: Array[Long]): BlockMatrix = { + def filterCols(keep: Array[Long]): BlockMatrix = new BlockMatrix(new BlockMatrixFilterColsRDD(this, keep), blockSize, nRows, keep.length) - } - def filter(keepRows: Array[Long], keepCols: Array[Long]): BlockMatrix = { - new BlockMatrix(new BlockMatrixFilterRDD(this, keepRows, keepCols), - blockSize, keepRows.length, keepCols.length) - } + def filter(keepRows: Array[Long], keepCols: Array[Long]): BlockMatrix = + new BlockMatrix( + new BlockMatrixFilterRDD(this, keepRows, keepCols), + blockSize, + keepRows.length, + keepCols.length, + ) def entriesTable(ctx: ExecuteContext): TableValue = { - val rowType = PCanonicalStruct(true, "i" -> PInt64Required, "j" -> PInt64Required, "entry" -> PFloat64Required) - + val rowType = PCanonicalStruct( + true, + "i" -> PInt64Required, + "j" -> PInt64Required, + "entry" -> PFloat64Required, + ) + val sm = ctx.stateManager - val entriesRDD = ContextRDD.weaken(blocks).cflatMap { case (rvdContext, ((blockRow, blockCol), block)) => - val rowOffset = blockRow * blockSize.toLong - val colOffset = blockCol * blockSize.toLong - - val rvb = new RegionValueBuilder(sm, rvdContext.region) - - block.activeIterator - .map { case ((i, j), entry) => - rvb.start(rowType) - rvb.startStruct() - rvb.addLong(rowOffset + i) - rvb.addLong(colOffset + j) - rvb.addDouble(entry) - rvb.endStruct() - rvb.end() - } - } + val entriesRDD = + ContextRDD.weaken(blocks).cflatMap { case (rvdContext, ((blockRow, blockCol), block)) => + val rowOffset = blockRow * blockSize.toLong + val colOffset = blockCol * blockSize.toLong + + val rvb = new RegionValueBuilder(sm, rvdContext.region) + + block.activeIterator + .map { case ((i, j), entry) => + rvb.start(rowType) + rvb.startStruct() + rvb.addLong(rowOffset + i) + rvb.addLong(colOffset + j) + rvb.addDouble(entry) + rvb.endStruct() + rvb.end() + } + } TableValue(ctx, rowType, FastSeq(), entriesRDD) } } -case class BlockMatrixFilterRDDPartition(index: Int, +case class BlockMatrixFilterRDDPartition( + index: Int, blockRowRanges: Array[(Int, Array[Int], Array[Int])], - blockColRanges: Array[(Int, Array[Int], Array[Int])]) extends Partition + blockColRanges: Array[(Int, Array[Int], Array[Int])], +) extends Partition object BlockMatrixFilterRDD { - // allBlockColRanges(newBlockCol) has elements of the form (blockCol, startIndices, endIndices) with blockCol increasing - // startIndices.zip(endIndices) gives all column-index ranges in blockCol to be copied to ranges in newBlockCol - def computeAllBlockColRanges(keep: Array[Long], - gp: GridPartitioner, - newGP: GridPartitioner): Array[Array[(Int, Array[Int], Array[Int])]] = { + /* allBlockColRanges(newBlockCol) has elements of the form (blockCol, startIndices, endIndices) + * with blockCol increasing */ + /* startIndices.zip(endIndices) gives all column-index ranges in blockCol to be copied to ranges + * in newBlockCol */ + def computeAllBlockColRanges(keep: Array[Long], gp: GridPartitioner, newGP: GridPartitioner) + : Array[Array[(Int, Array[Int], Array[Int])]] = { val blockSize = gp.blockSize val ab = new BoxedArrayBuilder[(Int, Array[Int], Array[Int])]() @@ -1389,7 +1583,9 @@ object BlockMatrixFilterRDD { var endCol = startCol + 1 
var k = j + 1 - while (k < newBlockNCols && colsInNewBlock(k) == endCol && endCol < finalColInBlockCol) { // extend range + while ( + k < newBlockNCols && colsInNewBlock(k) == endCol && endCol < finalColInBlockCol + ) { // extend range endCol += 1 k += 1 } @@ -1402,21 +1598,18 @@ object BlockMatrixFilterRDD { }.toArray } - def computeAllBlockRowRanges(keep: Array[Long], - gp: GridPartitioner, - newGP: GridPartitioner): Array[Array[(Int, Array[Int], Array[Int])]] = { - + def computeAllBlockRowRanges(keep: Array[Long], gp: GridPartitioner, newGP: GridPartitioner) + : Array[Array[(Int, Array[Int], Array[Int])]] = computeAllBlockColRanges(keep, gp.transpose._1, newGP.transpose._1) - } } // checked in Python: keepRows and keepCols non-empty, increasing, valid range private class BlockMatrixFilterRDD(bm: BlockMatrix, keepRows: Array[Long], keepCols: Array[Long]) - extends RDD[((Int, Int), BDM[Double])](bm.blocks.sparkContext, Nil) { + extends RDD[((Int, Int), BDM[Double])](bm.blocks.sparkContext, Nil) { log.info("Constructing BlockMatrixFilterRDD") val t0 = System.nanoTime() - + private val originalGP = bm.gp if (bm.isSparse) { @@ -1434,7 +1627,7 @@ private class BlockMatrixFilterRDD(bm: BlockMatrix, keepRows: Array[Long], keepC private val originalMaybeBlocksSet = originalGP.partitionIndexToBlockIndex.map(_.toSet) - private val blockParentMap = (0 until tempDenseGP.numPartitions).map {blockId => + private val blockParentMap = (0 until tempDenseGP.numPartitions).map { blockId => val (newBlockRow, newBlockCol) = tempDenseGP.blockCoordinates(blockId) val parents = for { @@ -1442,26 +1635,33 @@ private class BlockMatrixFilterRDD(bm: BlockMatrix, keepRows: Array[Long], keepC blockCol <- allBlockColRanges(newBlockCol).map(_._1) } yield originalGP.coordinatesBlock(blockRow, blockCol) (blockId, parents) - }.map{case (blockId, parents) => + }.map { case (blockId, parents) => val filteredParents = originalMaybeBlocksSet match { case None => parents case Some(blockIdSet) => parents.filter(id => blockIdSet.contains(id)) } (blockId, filteredParents) - }.filter{case (_, parents) => !parents.isEmpty}.toMap + }.filter { case (_, parents) => !parents.isEmpty }.toMap private val blockIndices = blockParentMap.keys.toArray.sorted - private val newGPMaybeBlocks: Option[IndexedSeq[Int]] = if (blockIndices.length == tempDenseGP.maxNBlocks) None else Some(blockIndices) + + private val newGPMaybeBlocks: Option[IndexedSeq[Int]] = + if (blockIndices.length == tempDenseGP.maxNBlocks) None else Some(blockIndices) + private val newGP = tempDenseGP.copy(partitionIndexToBlockIndex = newGPMaybeBlocks) - log.info(s"Finished constructing block matrix filter RDD. Total time ${(System.nanoTime() - t0).toDouble / 1000000000}") + log.info( + s"Finished constructing block matrix filter RDD. 
Total time ${(System.nanoTime() - t0).toDouble / 1000000000}" + ) protected def getPartitions: Array[Partition] = Array.tabulate(newGP.numPartitions) { partitionIndex => val blockIndex = newGP.partitionToBlock(partitionIndex) - BlockMatrixFilterRDDPartition(partitionIndex, + BlockMatrixFilterRDDPartition( + partitionIndex, allBlockRowRanges(newGP.blockBlockRow(blockIndex)), - allBlockColRanges(newGP.blockBlockCol(blockIndex))) + allBlockColRanges(newGP.blockBlockCol(blockIndex)), + ) } override def getDependencies: Seq[Dependency[_]] = Array[Dependency[_]]( @@ -1469,10 +1669,12 @@ private class BlockMatrixFilterRDD(bm: BlockMatrix, keepRows: Array[Long], keepC def getParents(partitionId: Int): Seq[Int] = { val blockForPartition = newGP.partitionToBlock(partitionId) val blockParents = blockParentMap(blockForPartition) - val partitionParents = blockParents.map(blockId => originalGP.blockToPartition(blockId)).toSet.toArray.sorted + val partitionParents = + blockParents.map(blockId => originalGP.blockToPartition(blockId)).toSet.toArray.sorted partitionParents } - }) + } + ) def compute(split: Partition, context: TaskContext): Iterator[((Int, Int), BDM[Double])] = { val part = split.asInstanceOf[BlockMatrixFilterRDDPartition] @@ -1483,13 +1685,13 @@ private class BlockMatrixFilterRDD(bm: BlockMatrix, keepRows: Array[Long], keepC val parentZeroBlock = BDM.zeros[Double](originalGP.blockSize, originalGP.blockSize) val newBlock = BDM.zeros[Double](newBlockNRows, newBlockNCols) - log.info(s"Computing partition for FilterRDD ${part}") + log.info(s"Computing partition for FilterRDD $part") var jCol = 0 var kCol = 0 part.blockColRanges.foreach { case (blockCol, colStartIndices, colEndIndices) => val jCol0 = jCol // record first col index in newBlock corresponding to new blockCol - var jRow = 0 + var jRow = 0 var kRow = 0 part.blockRowRanges.foreach { case (blockRow, rowStartIndices, rowEndIndices) => val jRow0 = jRow // record first row index in newBlock corresponding to new blockRow @@ -1516,7 +1718,10 @@ private class BlockMatrixFilterRDD(bm: BlockMatrix, keepRows: Array[Long], keepC val eiRow = rowEndIndices(rowRangeIndex) kRow = jRow + eiRow - siRow - newBlock(jRow until kRow, jCol until kCol) := block(siRow until eiRow, siCol until eiCol) + newBlock(jRow until kRow, jCol until kCol) := block( + siRow until eiRow, + siCol until eiCol, + ) jRow = kRow rowRangeIndex += 1 @@ -1535,11 +1740,14 @@ private class BlockMatrixFilterRDD(bm: BlockMatrix, keepRows: Array[Long], keepC @transient override val partitioner: Option[Partitioner] = Some(newGP) } -case class BlockMatrixFilterOneDimRDDPartition(index: Int, blockRanges: Array[(Int, Array[Int], Array[Int])]) extends Partition +case class BlockMatrixFilterOneDimRDDPartition( + index: Int, + blockRanges: Array[(Int, Array[Int], Array[Int])], +) extends Partition // checked in Python: keep non-empty, increasing, valid range private class BlockMatrixFilterColsRDD(bm: BlockMatrix, keep: Array[Long]) - extends RDD[((Int, Int), BDM[Double])](bm.blocks.sparkContext, Nil) { + extends RDD[((Int, Int), BDM[Double])](bm.blocks.sparkContext, Nil) { private val childPartitionsBc = bm.blocks.sparkContext.broadcast(bm.blocks.partitions) @@ -1552,44 +1760,52 @@ private class BlockMatrixFilterColsRDD(bm: BlockMatrix, keep: Array[Long]) @transient private val originalMaybeBlocksSet = originalGP.partitionIndexToBlockIndex.map(_.toSet) - //Map the denseGP blocks to the blocks of parents they depend on, temporarily pretending they are all there. 
- //Then delete the parents that aren't in originalGP.maybeBlocks, then delete the pairs - //without parents at all. + /* Map the denseGP blocks to the blocks of parents they depend on, temporarily pretending they are + * all there. */ + // Then delete the parents that aren't in originalGP.maybeBlocks, then delete the pairs + // without parents at all. @transient private val blockParentMap = (0 until tempDenseGP.numPartitions).map { blockId => val (blockRow, newBlockCol) = tempDenseGP.blockCoordinates(blockId) blockId -> allBlockColRanges(newBlockCol).map { case (blockCol, _, _) => originalGP.coordinatesBlock(blockRow, blockCol) } - }.map{case (blockId, parents) => - val filteredParents = originalMaybeBlocksSet match { - case None => parents - case Some(blockIdSet) => parents.filter(id => blockIdSet.contains(id)) - } - (blockId, filteredParents) - }.filter{case (_, parents) => !parents.isEmpty}.toMap + }.map { case (blockId, parents) => + val filteredParents = originalMaybeBlocksSet match { + case None => parents + case Some(blockIdSet) => parents.filter(id => blockIdSet.contains(id)) + } + (blockId, filteredParents) + }.filter { case (_, parents) => !parents.isEmpty }.toMap private val blockParentMapBc = bm.blocks.sparkContext.broadcast(blockParentMap) @transient private val blockIndices = blockParentMap.keys.toFastSeq.sorted - @transient private val newGPMaybeBlocks = if (blockIndices.length == tempDenseGP.maxNBlocks) None else Some(blockIndices) + + @transient private val newGPMaybeBlocks = + if (blockIndices.length == tempDenseGP.maxNBlocks) None else Some(blockIndices) + private val newGP = tempDenseGP.copy(partitionIndexToBlockIndex = newGPMaybeBlocks) - protected def getPartitions: Array[Partition] = { + protected def getPartitions: Array[Partition] = Array.tabulate(newGP.numPartitions) { partitionIndex: Int => val blockIndex = newGP.partitionToBlock(partitionIndex) - BlockMatrixFilterOneDimRDDPartition(partitionIndex, allBlockColRanges(newGP.blockBlockCol(blockIndex))) + BlockMatrixFilterOneDimRDDPartition( + partitionIndex, + allBlockColRanges(newGP.blockBlockCol(blockIndex)), + ) } - } override def getDependencies: Seq[Dependency[_]] = Array[Dependency[_]]( new NarrowDependency(bm.blocks) { def getParents(partitionId: Int): Seq[Int] = { - val blockForPartition = newGP.partitionToBlock(partitionId) - val blockParents = blockParentMap(blockForPartition) - val partitionParents = blockParents.map(blockId => originalGP.blockToPartition(blockId)).toSet.toArray.sorted - partitionParents + val blockForPartition = newGP.partitionToBlock(partitionId) + val blockParents = blockParentMap(blockForPartition) + val partitionParents = + blockParents.map(blockId => originalGP.blockToPartition(blockId)).toSet.toArray.sorted + partitionParents } - }) + } + ) def compute(split: Partition, context: TaskContext): Iterator[((Int, Int), BDM[Double])] = { val blockIndex = newGP.partitionToBlock(split.index) @@ -1634,7 +1850,7 @@ private class BlockMatrixFilterColsRDD(bm: BlockMatrix, keep: Array[Long]) // checked in Python: keep non-empty, increasing, valid range private class BlockMatrixFilterRowsRDD(bm: BlockMatrix, keep: Array[Long]) - extends RDD[((Int, Int), BDM[Double])](bm.blocks.sparkContext, Nil) { + extends RDD[((Int, Int), BDM[Double])](bm.blocks.sparkContext, Nil) { private val childPartitionsBc = bm.blocks.sparkContext.broadcast(bm.blocks.partitions) @@ -1647,44 +1863,52 @@ private class BlockMatrixFilterRowsRDD(bm: BlockMatrix, keep: Array[Long]) @transient private val 
originalMaybeBlocksSet = originalGP.partitionIndexToBlockIndex.map(_.toSet) - //Map the denseGP blocks to the blocks of parents they depend on, temporarily pretending they are all there. - //Then delete the parents that aren't in originalGP.maybeBlocks, then delete the pairs - //without parents at all. + /* Map the denseGP blocks to the blocks of parents they depend on, temporarily pretending they are + * all there. */ + // Then delete the parents that aren't in originalGP.maybeBlocks, then delete the pairs + // without parents at all. @transient private val blockParentMap = (0 until tempDenseGP.numPartitions).map { blockId => val (newBlockRow, blockCol) = tempDenseGP.blockCoordinates(blockId) blockId -> allBlockRowRanges(newBlockRow).map { case (blockRow, _, _) => originalGP.coordinatesBlock(blockRow, blockCol) } - }.map{case (blockId, parents) => + }.map { case (blockId, parents) => val filteredParents = originalMaybeBlocksSet match { case None => parents case Some(blockIdSet) => parents.filter(id => blockIdSet.contains(id)) } (blockId, filteredParents) - }.filter{case (_, parents) => !parents.isEmpty}.toMap + }.filter { case (_, parents) => !parents.isEmpty }.toMap private val blockParentMapBc = bm.blocks.sparkContext.broadcast(blockParentMap) @transient private val blockIndices = blockParentMap.keys.toFastSeq.sorted - @transient private val newGPMaybeBlocks = if (blockIndices.length == tempDenseGP.maxNBlocks) None else Some(blockIndices) + + @transient private val newGPMaybeBlocks = + if (blockIndices.length == tempDenseGP.maxNBlocks) None else Some(blockIndices) + private val newGP = tempDenseGP.copy(partitionIndexToBlockIndex = newGPMaybeBlocks) - protected def getPartitions: Array[Partition] = { + protected def getPartitions: Array[Partition] = Array.tabulate(newGP.numPartitions) { partitionIndex: Int => val blockIndex = newGP.partitionToBlock(partitionIndex) - BlockMatrixFilterOneDimRDDPartition(partitionIndex, allBlockRowRanges(newGP.blockBlockRow(blockIndex))) + BlockMatrixFilterOneDimRDDPartition( + partitionIndex, + allBlockRowRanges(newGP.blockBlockRow(blockIndex)), + ) } - } override def getDependencies: Seq[Dependency[_]] = Array[Dependency[_]]( new NarrowDependency(bm.blocks) { def getParents(partitionId: Int): Seq[Int] = { val blockForPartition = newGP.partitionToBlock(partitionId) val blockParents = blockParentMap(blockForPartition) - val partitionParents = blockParents.map(blockId => originalGP.blockToPartition(blockId)).toSet.toArray.sorted + val partitionParents = + blockParents.map(blockId => originalGP.blockToPartition(blockId)).toSet.toArray.sorted partitionParents } - }) + } + ) def compute(split: Partition, context: TaskContext): Iterator[((Int, Int), BDM[Double])] = { val blockIndex = newGP.partitionToBlock(split.index) @@ -1730,7 +1954,7 @@ private class BlockMatrixFilterRowsRDD(bm: BlockMatrix, keep: Array[Long]) case class BlockMatrixTransposeRDDPartition(index: Int, prevPartition: Partition) extends Partition private class BlockMatrixTransposeRDD(bm: BlockMatrix) - extends RDD[((Int, Int), BDM[Double])](bm.blocks.sparkContext, Nil) { + extends RDD[((Int, Int), BDM[Double])](bm.blocks.sparkContext, Nil) { private val (newGP, transposedPartitionIndicesToParentPartitions) = bm.gp.transpose @@ -1743,17 +1967,20 @@ private class BlockMatrixTransposeRDD(bm: BlockMatrix) assert(newI == oldJ && newJ == oldI) Array(parent) } - }) + } + ) def compute(split: Partition, context: TaskContext): Iterator[((Int, Int), BDM[Double])] = 
bm.blocks.iterator(split.asInstanceOf[BlockMatrixTransposeRDDPartition].prevPartition, context) .map { case ((j, i), lm) => ((i, j), lm.t) } - protected def getPartitions: Array[Partition] = { + protected def getPartitions: Array[Partition] = Array.tabulate(newGP.numPartitions) { pi => - BlockMatrixTransposeRDDPartition(pi, bm.blocks.partitions(transposedPartitionIndicesToParentPartitions(pi))) + BlockMatrixTransposeRDDPartition( + pi, + bm.blocks.partitions(transposedPartitionIndicesToParentPartitions(pi)), + ) } - } @transient override val partitioner: Option[Partitioner] = Some(newGP) } @@ -1761,8 +1988,8 @@ private class BlockMatrixTransposeRDD(bm: BlockMatrix) private class BlockMatrixUnionOpRDD( l: BlockMatrix, r: BlockMatrix, - op: ((Option[BDM[Double]], Option[BDM[Double]])) => BDM[Double]) - extends RDD[((Int, Int), BDM[Double])](l.blocks.sparkContext, Nil) { + op: ((Option[BDM[Double]], Option[BDM[Double]])) => BDM[Double], +) extends RDD[((Int, Int), BDM[Double])](l.blocks.sparkContext, Nil) { import BlockMatrix.block @@ -1776,15 +2003,18 @@ private class BlockMatrixUnionOpRDD( private val lParts = l.blocks.partitions private val rParts = r.blocks.partitions - + override def getDependencies: Seq[Dependency[_]] = Array[Dependency[_]]( new NarrowDependency(l.blocks) { - def getParents(partitionId: Int): Seq[Int] = Array(lGP.blockToPartition(gp.partitionToBlock(partitionId))).filter(_ >= 0) + def getParents(partitionId: Int): Seq[Int] = + Array(lGP.blockToPartition(gp.partitionToBlock(partitionId))).filter(_ >= 0) }, new NarrowDependency(r.blocks) { - def getParents(partitionId: Int): Seq[Int] = Array(rGP.blockToPartition(gp.partitionToBlock(partitionId))).filter(_ >= 0) - }) + def getParents(partitionId: Int): Seq[Int] = + Array(rGP.blockToPartition(gp.partitionToBlock(partitionId))).filter(_ >= 0) + }, + ) def compute(split: Partition, context: TaskContext): Iterator[((Int, Int), BDM[Double])] = { val (i, j) = gp.partCoordinates(split.index) @@ -1794,25 +2024,31 @@ private class BlockMatrixUnionOpRDD( } protected def getPartitions: Array[Partition] = Array.tabulate(gp.numPartitions)(pi => - new Partition { def index: Int = pi } ) + new Partition { def index: Int = pi } + ) @transient override val partitioner: Option[Partitioner] = Some(gp) } private class BlockMatrixMultiplyRDD(l: BlockMatrix, r: BlockMatrix) - extends RDD[((Int, Int), BDM[Double])](l.blocks.sparkContext, Nil) { + extends RDD[((Int, Int), BDM[Double])](l.blocks.sparkContext, Nil) { import BlockMatrix.block - require(l.nCols == r.nRows, - s"inner dimensions must match, but given: ${ l.nRows }x${ l.nCols }, ${ r.nRows }x${ r.nCols }") - require(l.blockSize == r.blockSize, - s"blocks must be same size, but actually were ${ l.blockSize }x${ l.blockSize } and ${ r.blockSize }x${ r.blockSize }") + require( + l.nCols == r.nRows, + s"inner dimensions must match, but given: ${l.nRows}x${l.nCols}, ${r.nRows}x${r.nCols}", + ) + + require( + l.blockSize == r.blockSize, + s"blocks must be same size, but actually were ${l.blockSize}x${l.blockSize} and ${r.blockSize}x${r.blockSize}", + ) private val lGP = l.gp private val rGP = r.gp private val gp = GridPartitioner(l.blockSize, l.nRows, r.nCols) - + private val lParts = l.blocks.partitions private val rParts = r.blocks.partitions private val nProducts = lGP.nBlockCols @@ -1830,51 +2066,69 @@ private class BlockMatrixMultiplyRDD(l: BlockMatrix, r: BlockMatrix) val j = gp.blockBlockCol(partitionId) (0 until nProducts).map(k => rGP.coordinatesPart(k, j)).filter(_ >= 0).toArray } - 
}) + }, + ) - def fma(c: BDM[Double], _a: BDM[Double], _b: BDM[Double]) { + def fma(c: BDM[Double], _a: BDM[Double], _b: BDM[Double]): Unit = { assert(_a.cols == _b.rows) - val a = if (_a.majorStride < math.max(if (_a.isTranspose) _a.cols else _a.rows, 1)) _a.copy else _a - val b = if (_b.majorStride < math.max(if (_b.isTranspose) _b.cols else _b.rows, 1)) _b.copy else _b + val a = + if (_a.majorStride < math.max(if (_a.isTranspose) _a.cols else _a.rows, 1)) _a.copy else _a + val b = + if (_b.majorStride < math.max(if (_b.isTranspose) _b.cols else _b.rows, 1)) _b.copy else _b import com.github.fommil.netlib.BLAS.{getInstance => blas} blas.dgemm( if (a.isTranspose) "T" else "N", if (b.isTranspose) "T" else "N", - c.rows, c.cols, a.cols, - 1.0, a.data, a.offset, a.majorStride, - b.data, b.offset, b.majorStride, - 1.0, c.data, 0, c.rows) + c.rows, + c.cols, + a.cols, + 1.0, + a.data, + a.offset, + a.majorStride, + b.data, + b.offset, + b.majorStride, + 1.0, + c.data, + 0, + c.rows, + ) } def compute(split: Partition, context: TaskContext): Iterator[((Int, Int), BDM[Double])] = { - val (i, j) = gp.blockCoordinates(split.index) - val (blockNRows, blockNCols) = gp.blockDims(split.index) - val product = BDM.zeros[Double](blockNRows, blockNCols) - var k = 0 - while (k < nProducts) { - val left = block(l, lParts, lGP, context, i, k) - val right = block(r, rParts, rGP, context, k, j) - if (left.isDefined && right.isDefined) { - fma(product, left.get, right.get) - } - k += 1 - } - Iterator.single(((i, j), product)) - } + val (i, j) = gp.blockCoordinates(split.index) + val (blockNRows, blockNCols) = gp.blockDims(split.index) + val product = BDM.zeros[Double](blockNRows, blockNCols) + var k = 0 + while (k < nProducts) { + val left = block(l, lParts, lGP, context, i, k) + val right = block(r, rParts, rGP, context, k, j) + if (left.isDefined && right.isDefined) { + fma(product, left.get, right.get) + } + k += 1 + } + Iterator.single(((i, j), product)) + } protected def getPartitions: Array[Partition] = Array.tabulate(gp.numPartitions)(pi => - new Partition { def index: Int = pi } ) + new Partition { def index: Int = pi } + ) @transient override val partitioner: Option[Partitioner] = Some(gp) } case class BlockMatrixRectanglesRDD(rectangles: Array[Array[Long]], bm: BlockMatrix) - extends RDD[(Int, BDM[Double])](bm.blocks.sparkContext, Nil) { + extends RDD[(Int, BDM[Double])](bm.blocks.sparkContext, Nil) { assert(rectangles.forall(rect => rect.length == 4)) - assert(rectangles.forall(rect => rect(1) - rect(0) <= Int.MaxValue && rect(3) - rect(2) <= Int.MaxValue)) + + assert(rectangles.forall(rect => + rect(1) - rect(0) <= Int.MaxValue && rect(3) - rect(2) <= Int.MaxValue + )) val gp: GridPartitioner = bm.gp @@ -1882,7 +2136,8 @@ case class BlockMatrixRectanglesRDD(rectangles: Array[Array[Long]], bm: BlockMat val rect = rectangles(split.index) val Array(rectStartRow, rectEndRow, rectStartCol, rectEndCol) = rect - val rectData = new BDM[Double]((rectEndRow - rectStartRow).toInt, (rectEndCol - rectStartCol).toInt) + val rectData = + new BDM[Double]((rectEndRow - rectStartRow).toInt, (rectEndCol - rectStartCol).toInt) val blocksInRectangle = gp.rectangleBlocks(rect) blocksInRectangle.foreach { blockIdx => val (blockRowIdx, blockColIdx) = gp.blockCoordinates(blockIdx) @@ -1896,33 +2151,32 @@ case class BlockMatrixRectanglesRDD(rectangles: Array[Array[Long]], bm: BlockMat val rectRowSlice = overlapRectSlice(rectStartRow, rectEndRow, blockStartRow, blockEndRow) val rectColSlice = overlapRectSlice(rectStartCol, 
rectEndCol, blockStartCol, blockEndCol) - BlockMatrix.block(bm, bm.blocks.partitions, gp, context, blockRowIdx, blockColIdx).foreach { block => - rectData(rectRowSlice, rectColSlice) := block(blockRowSlice, blockColSlice) + BlockMatrix.block(bm, bm.blocks.partitions, gp, context, blockRowIdx, blockColIdx).foreach { + block => rectData(rectRowSlice, rectColSlice) := block(blockRowSlice, blockColSlice) } } Iterator.single((split.index, rectData)) } - private def overlapBlockSlice(rectStart: Long, rectEnd: Long, blockStart: Long, blockEnd: Long): Range = { + private def overlapBlockSlice(rectStart: Long, rectEnd: Long, blockStart: Long, blockEnd: Long) + : Range = { val (start, end) = absoluteOverlap(rectStart, rectEnd, blockStart, blockEnd) (start - blockStart).toInt until (end - blockStart).toInt } - private def overlapRectSlice(rectStart: Long, rectEnd: Long, blockStart: Long, blockEnd: Long): Range = { + private def overlapRectSlice(rectStart: Long, rectEnd: Long, blockStart: Long, blockEnd: Long) + : Range = { val (start, end) = absoluteOverlap(rectStart, rectEnd, blockStart, blockEnd) (start - rectStart).toInt until (end - rectStart).toInt } - private def absoluteOverlap(rectStart: Long, rectEnd: Long, blockStart: Long, blockEnd: Long): (Long, Long) = { + private def absoluteOverlap(rectStart: Long, rectEnd: Long, blockStart: Long, blockEnd: Long) + : (Long, Long) = (Math.max(rectStart, blockStart), Math.min(rectEnd, blockEnd)) - } - override protected def getPartitions: Array[Partition] = { - Array.tabulate(rectangles.length) { rectIndex => - new Partition { val index: Int = rectIndex } - } - } + override protected def getPartitions: Array[Partition] = + Array.tabulate(rectangles.length)(rectIndex => new Partition { val index: Int = rectIndex }) } // On compute, WriteBlocksRDDPartition writes the block row with index `index` @@ -1933,7 +2187,8 @@ case class WriteBlocksRDDPartition( start: Int, skip: Int, end: Int, - parentPartitions: Array[Partition]) extends Partition { + parentPartitions: Array[Partition], +) extends Partition { def range: Range = start to end } @@ -1944,7 +2199,8 @@ class WriteBlocksRDD( rvd: RVD, parentPartStarts: Array[Long], entryField: String, - gp: GridPartitioner) extends RDD[(Int, String)](SparkBackend.sparkContext("WriteBlocksRDD"), Nil) { + gp: GridPartitioner, +) extends RDD[(Int, String)](SparkBackend.sparkContext("WriteBlocksRDD"), Nil) { require(gp.nRows == parentPartStarts.last) @@ -1988,8 +2244,13 @@ class WriteBlocksRDD( if (parentPartStarts(pi) > firstRowInNextBlock) pi -= 1 - parts(blockRow) = WriteBlocksRDDPartition(blockRow, start, skip, end, - (start to end).map(i => parentPartitions(i)).toArray) + parts(blockRow) = WriteBlocksRDDPartition( + blockRow, + start, + skip, + end, + (start to end).map(i => parentPartitions(i)).toArray, + ) firstRowInBlock = firstRowInNextBlock blockRow += 1 @@ -2056,7 +2317,9 @@ class WriteBlocksRDD( var blockCol = 0 var colIdx = 0 val colIt = entryArrayType.elementIterator( - entryArrayOffset, entryArrayType.loadLength(entryArrayOffset)) + entryArrayOffset, + entryArrayType.loadLength(entryArrayOffset), + ) while (blockCol < gp.nBlockCols) { val n = gp.blockColNCols(blockCol) var j = 0 @@ -2088,7 +2351,9 @@ class WriteBlocksRDD( } } outPerBlockCol.foreach(_.close()) - paths.foreach { case (tempPath, finalPath) => fsBc.value.copy(tempPath, finalPath, deleteSource = true)} + paths.foreach { case (tempPath, finalPath) => + fsBc.value.copy(tempPath, finalPath, deleteSource = true) + } blockPartFiles.iterator } } @@ 
-2103,18 +2368,30 @@ class BlockMatrixReadRowBlockedRDD( partitionRanges: IndexedSeq[NumericRange.Exclusive[Long]], requestedType: TStruct, metadata: BlockMatrixMetadata, - maybeMaximumCacheMemoryInBytes: Option[Int] -) extends RDD[RVDContext => Iterator[Long]](SparkBackend.sparkContext("BlockMatrixReadRowBlockedRDD"), Nil) { + maybeMaximumCacheMemoryInBytes: Option[Int], +) extends RDD[RVDContext => Iterator[Long]]( + SparkBackend.sparkContext("BlockMatrixReadRowBlockedRDD"), + Nil, + ) { import BlockMatrixReadRowBlockedRDD._ - private[this] val BlockMatrixMetadata(blockSize, nRows, nCols, maybeFiltered, partFiles) = metadata + private[this] val BlockMatrixMetadata(blockSize, nRows, nCols, _, partFiles) = + metadata + private[this] val gp = GridPartitioner(blockSize, nRows, nCols) - private[this] val maximumCacheMemoryInBytes = maybeMaximumCacheMemoryInBytes.getOrElse(DEFAULT_MAXIMUM_CACHE_MEMORY_IN_BYTES) + + private[this] val maximumCacheMemoryInBytes = + maybeMaximumCacheMemoryInBytes.getOrElse(DEFAULT_MAXIMUM_CACHE_MEMORY_IN_BYTES) + private[this] val doublesPerFile = maximumCacheMemoryInBytes / (gp.nBlockCols * 8) - assert(doublesPerFile >= blockSize, - "BlockMatrixCachedPartFile must be able to hold at least one row of every block in memory") - override def compute(split: Partition, context: TaskContext): Iterator[RVDContext => Iterator[Long]] = { + assert( + doublesPerFile >= blockSize, + "BlockMatrixCachedPartFile must be able to hold at least one row of every block in memory", + ) + + override def compute(split: Partition, context: TaskContext) + : Iterator[RVDContext => Iterator[Long]] = { val pi = split.index val rowsForPartition = partitionRanges(pi) val createRowIdx = requestedType.fieldNames.contains("row_idx") @@ -2124,8 +2401,9 @@ class BlockMatrixReadRowBlockedRDD( Array( if (createRowIdx) Some("row_idx" -> PInt64()) else None, Some("entries" -> PCanonicalArray(PFloat64())), - if (createRowUID) Some(TableReader.uidFieldName -> PInt64()) else None - ).flatten: _*) + if (createRowUID) Some(TableReader.uidFieldName -> PInt64()) else None, + ).flatten: _* + ) if (rowsForPartition.isEmpty) { return Iterator.single(ctx => Iterator.empty) @@ -2133,7 +2411,6 @@ class BlockMatrixReadRowBlockedRDD( Iterator.single { ctx => val region = ctx.region val rvb = new RegionValueBuilder(HailStateManager(Map.empty), region) - val rv = RegionValue(region) val firstRow = rowsForPartition(0) var blockRow = (firstRow / blockSize).toInt val fs = fsBc.value @@ -2143,7 +2420,8 @@ class BlockMatrixReadRowBlockedRDD( doublesPerFile, fs, path, - partFiles(gp.coordinatesBlock(blockRow, blockCol))) + partFiles(gp.coordinatesBlock(blockRow, blockCol)), + ) } rowsForPartition.iterator.map { row => @@ -2157,7 +2435,8 @@ class BlockMatrixReadRowBlockedRDD( doublesPerFile, fs, path, - partFiles(gp.coordinatesBlock(blockRow, blockCol))) + partFiles(gp.coordinatesBlock(blockRow, blockCol)), + ) } } @@ -2180,9 +2459,8 @@ class BlockMatrixReadRowBlockedRDD( } } - override def getPartitions: Array[Partition] = { - Array.tabulate(partitionRanges.length) { pi => new Partition { val index: Int = pi } } - } + override def getPartitions: Array[Partition] = + Array.tabulate(partitionRanges.length)(pi => new Partition { val index: Int = pi }) } class BlockMatrixCachedPartFile( @@ -2190,7 +2468,7 @@ class BlockMatrixCachedPartFile( _cacheCapacity: Int, private[this] val fs: FS, path: String, - pFile: String + pFile: String, ) { private[this] val cacheCapacity = math.min(_cacheCapacity, BlockMatrix.bufferSpecBlockSize) 
private[this] val cache = new Array[Double](cacheCapacity) @@ -2236,10 +2514,10 @@ class BlockMatrixCachedPartFile( in.skipBytes(8 * fileIndex) val doublesToRead = math.min( cacheCapacity - startWritingAt, - rows * cols - fileIndex) + rows * cols - fileIndex, + ) in.readDoubles(cache, startWritingAt, doublesToRead) cacheEnd = doublesToRead + startWritingAt - var i = 0 fileIndex += doublesToRead assert(doublesToRead > 0) } diff --git a/hail/src/main/scala/is/hail/linalg/GridPartitioner.scala b/hail/src/main/scala/is/hail/linalg/GridPartitioner.scala index bf4dea034c1..343143fde72 100644 --- a/hail/src/main/scala/is/hail/linalg/GridPartitioner.scala +++ b/hail/src/main/scala/is/hail/linalg/GridPartitioner.scala @@ -1,22 +1,30 @@ package is.hail.linalg -import breeze.linalg.{DenseVector => BDV} import is.hail.utils._ -import org.apache.spark.Partitioner import scala.collection.mutable -/** - * BLOCKS ARE NUMBERED COLUMN MAJOR +import breeze.linalg.{DenseVector => BDV} +import org.apache.spark.Partitioner + +/** BLOCKS ARE NUMBERED COLUMN MAJOR * * @param blockSize * @param nRows * @param nCols - * @param partitionIndexToBlockIndex If exists, matrix is sparse and this contains a list of indices of blocks that are not all zero + * @param partitionIndexToBlockIndex + * If exists, matrix is sparse and this contains a list of indices of blocks that are not all + * zero */ -case class GridPartitioner(blockSize: Int, nRows: Long, nCols: Long, partitionIndexToBlockIndex: Option[IndexedSeq[Int]] = None) extends Partitioner { +case class GridPartitioner( + blockSize: Int, + nRows: Long, + nCols: Long, + partitionIndexToBlockIndex: Option[IndexedSeq[Int]] = None, +) extends Partitioner { if (nRows == 0) fatal("block matrix must have at least one row") + if (nCols == 0) fatal("block matrix must have at least one column") @@ -32,12 +40,19 @@ case class GridPartitioner(blockSize: Int, nRows: Long, nCols: Long, partitionIn val maxNBlocks: Long = nBlockRows.toLong * nBlockCols - if (!partitionIndexToBlockIndex.forall(bis => bis.isEmpty || - (bis.isIncreasing && bis.head >= 0 && bis.last < maxNBlocks && - bis.length < maxNBlocks))) // a block-sparse matrix cannot have all blocks present - throw new IllegalArgumentException(s"requirement failed: Sparse blocks sequence was ${partitionIndexToBlockIndex.toIndexedSeq}, max was ${maxNBlocks}") + if ( + !partitionIndexToBlockIndex.forall(bis => + bis.isEmpty || + (bis.isIncreasing && bis.head >= 0 && bis.last < maxNBlocks && + bis.length < maxNBlocks) + ) + ) // a block-sparse matrix cannot have all blocks present + throw new IllegalArgumentException( + s"requirement failed: Sparse blocks sequence was ${partitionIndexToBlockIndex.toIndexedSeq}, max was $maxNBlocks" + ) - val blockToPartitionMap = partitionIndexToBlockIndex.map(_.zipWithIndex.toMap.withDefaultValue(-1)) + val blockToPartitionMap = + partitionIndexToBlockIndex.map(_.zipWithIndex.toMap.withDefaultValue(-1)) val lastBlockRowNRows: Int = indexBlockOffset(nRows - 1) + 1 val lastBlockColNCols: Int = indexBlockOffset(nCols - 1) + 1 @@ -48,7 +63,8 @@ case class GridPartitioner(blockSize: Int, nRows: Long, nCols: Long, partitionIn def blockBlockRow(bi: Int): Int = bi % nBlockRows def blockBlockCol(bi: Int): Int = bi / nBlockRows - def blockDims(bi: Int): (Int, Int) = (blockRowNRows(blockBlockRow(bi)), blockColNCols(blockBlockCol(bi))) + def blockDims(bi: Int): (Int, Int) = + (blockRowNRows(blockBlockRow(bi)), blockColNCols(blockBlockCol(bi))) def nBlocks: Int = 
partitionIndexToBlockIndex.map(_.length).getOrElse(nBlockRows * nBlockCols) @@ -61,24 +77,28 @@ case class GridPartitioner(blockSize: Int, nRows: Long, nCols: Long, partitionIn } def intersect(that: GridPartitioner): GridPartitioner = { - copy(partitionIndexToBlockIndex = (partitionIndexToBlockIndex, that.partitionIndexToBlockIndex) match { - case (Some(bis), Some(bis2)) => Some(bis.filter(bis2.toSet)) - case (Some(bis), None) => Some(bis) - case (None, Some(bis2)) => Some(bis2) - case (None, None) => None - }) + copy(partitionIndexToBlockIndex = + (partitionIndexToBlockIndex, that.partitionIndexToBlockIndex) match { + case (Some(bis), Some(bis2)) => Some(bis.filter(bis2.toSet)) + case (Some(bis), None) => Some(bis) + case (None, Some(bis2)) => Some(bis2) + case (None, None) => None + } + ) } def union(that: GridPartitioner): GridPartitioner = { - copy(partitionIndexToBlockIndex = (partitionIndexToBlockIndex, that.partitionIndexToBlockIndex) match { - case (Some(bis), Some(bis2)) => - val union = (bis ++ bis2).distinct - if (union.length == maxNBlocks) - None - else - Some(union.sorted) - case _ => None - }) + copy(partitionIndexToBlockIndex = + (partitionIndexToBlockIndex, that.partitionIndexToBlockIndex) match { + case (Some(bis), Some(bis2)) => + val union = (bis ++ bis2).distinct + if (union.length == maxNBlocks) + None + else + Some(union.sorted) + case _ => None + } + ) } override val numPartitions: Int = partitionIndexToBlockIndex match { @@ -99,7 +119,7 @@ case class GridPartitioner(blockSize: Int, nRows: Long, nCols: Long, partitionIn def blockToPartition(blockId: Int): Int = blockToPartitionMap match { case Some(bpMap) => bpMap(blockId) - case None => blockId + case None => blockId } def partCoordinates(pi: Int): (Int, Int) = blockCoordinates(partitionToBlock(pi)) @@ -110,25 +130,29 @@ case class GridPartitioner(blockSize: Int, nRows: Long, nCols: Long, partitionIn case (i: Int, j: Int) => coordinatesPart(i, j) } - /** - * - * @return A transposed GridPartitioner and a function that maps partitions in the new transposed partitioner to - * the parent partitions in the old partitioner. + /** @return + * A transposed GridPartitioner and a function that maps partitions in the new transposed + * partitioner to the parent partitions in the old partitioner. 
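The column-major numbering that blockBlockRow, blockBlockCol and coordinatesBlock rely on reduces to two lines of integer arithmetic. A minimal sketch with hypothetical standalone helpers (nBlockRows passed explicitly rather than taken from the class):

object ColumnMajorNumberingSketch {
  // block (i, j) in an nBlockRows x nBlockCols grid gets linear index j * nBlockRows + i
  def coordinatesBlock(i: Int, j: Int, nBlockRows: Int): Int = j * nBlockRows + i

  // inverse: recover (blockRow, blockCol) from the linear index
  def blockCoordinates(bi: Int, nBlockRows: Int): (Int, Int) = (bi % nBlockRows, bi / nBlockRows)

  def main(args: Array[String]): Unit =
    // round-trip every block of a 3 x 2 grid
    for (j <- 0 until 2; i <- 0 until 3) {
      val bi = coordinatesBlock(i, j, nBlockRows = 3)
      assert(blockCoordinates(bi, nBlockRows = 3) == ((i, j)))
    }
}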
*/ def transpose: (GridPartitioner, Int => Int) = { val gpT = GridPartitioner(blockSize, nCols, nRows) partitionIndexToBlockIndex match { case Some(bis) => - def transposeBI(bi: Int): Int = gpT.coordinatesBlock(this.blockBlockCol(bi), this.blockBlockRow(bi)) + def transposeBI(bi: Int): Int = + gpT.coordinatesBlock(this.blockBlockCol(bi), this.blockBlockRow(bi)) - val (partIdxTToBlockIdxT, partIdxTToPartIdx) = bis.map(transposeBI).zipWithIndex.sortBy(_._1).unzip + val (partIdxTToBlockIdxT, partIdxTToPartIdx) = + bis.map(transposeBI).zipWithIndex.sortBy(_._1).unzip val transposedPartitionIndicesToParentPartitions = partIdxTToPartIdx.apply(_) - (GridPartitioner(blockSize, nCols, nRows, Some(partIdxTToBlockIdxT)), transposedPartitionIndicesToParentPartitions) - case None => { - def transposedBlockIndicesToParentBlocks(bi: Int) = this.coordinatesBlock(gpT.blockBlockCol(bi), gpT.blockBlockRow(bi)) + ( + GridPartitioner(blockSize, nCols, nRows, Some(partIdxTToBlockIdxT)), + transposedPartitionIndicesToParentPartitions, + ) + case None => + def transposedBlockIndicesToParentBlocks(bi: Int) = + this.coordinatesBlock(gpT.blockBlockCol(bi), gpT.blockBlockRow(bi)) (gpT, transposedBlockIndicesToParentBlocks) - } } } @@ -162,13 +186,14 @@ case class GridPartitioner(blockSize: Int, nRows: Long, nCols: Long, partitionIn // all elements with lower <= jj - ii <= upper def bandBlocks(lower: Long, upper: Long): Array[Int] = { require(lower <= upper) - + val lowerBlock = java.lang.Math.floorDiv(lower, blockSize).toInt val upperBlock = java.lang.Math.floorDiv(upper + blockSize - 1, blockSize).toInt - (for { j <- 0 until nBlockCols - i <- ((j - upperBlock) max 0) to - ((j - lowerBlock) min (nBlockRows - 1)) + (for { + j <- 0 until nBlockCols + i <- ((j - upperBlock) max 0) to + ((j - lowerBlock) min (nBlockRows - 1)) } yield (j * nBlockRows) + i).toArray } @@ -180,20 +205,21 @@ case class GridPartitioner(blockSize: Int, nRows: Long, nCols: Long, partitionIn val stopBlockRow = java.lang.Math.floorDiv(r(1) - 1, blockSize).toInt + 1 val startBlockCol = indexBlockIndex(r(2)) val stopBlockCol = java.lang.Math.floorDiv(r(3) - 1, blockSize).toInt + 1 - - (for { j <- startBlockCol until stopBlockCol - i <- startBlockRow until stopBlockRow + + (for { + j <- startBlockCol until stopBlockCol + i <- startBlockRow until stopBlockRow } yield (j * nBlockRows) + i).toArray } // returns increasing array of all blocks intersecting the union of rectangles // rectangles checked in Python def rectanglesBlocks(rectangles: Array[Array[Long]]): Array[Int] = { - val blocks = rectangles.foldLeft(mutable.Set[Int]())((s, r) => s ++= rectangleBlocks(r)).toArray + val blocks = rectangles.foldLeft(mutable.Set[Int]())((s, r) => s ++= rectangleBlocks(r)).toArray scala.util.Sorting.quickSort(blocks) blocks } - + // starts, stops checked in Python def rowIntervalsBlocks(starts: Array[Long], stops: Array[Long]): Array[Int] = { val rectangles = starts.grouped(blockSize).zip(stops.grouped(blockSize)) @@ -218,7 +244,7 @@ case class GridPartitioner(blockSize: Int, nRows: Long, nCols: Long, partitionIn } else None }.toArray - + rectanglesBlocks(rectangles) } } diff --git a/hail/src/main/scala/is/hail/linalg/LAPACK.scala b/hail/src/main/scala/is/hail/linalg/LAPACK.scala index 842e0a44b96..8c1e5a83ee8 100644 --- a/hail/src/main/scala/is/hail/linalg/LAPACK.scala +++ b/hail/src/main/scala/is/hail/linalg/LAPACK.scala @@ -1,17 +1,18 @@ package is.hail.linalg +import is.hail.utils._ + +import scala.util.{Failure, Success, Try} + import 
java.lang.reflect.Method import java.util.function._ -import com.sun.jna.{FunctionMapper, Library, Native, NativeLibrary} -import com.sun.jna.ptr.{IntByReference, DoubleByReference} -import scala.util.{Failure, Success, Try} -import is.hail.utils._ +import com.sun.jna.{FunctionMapper, Library, Native, NativeLibrary} +import com.sun.jna.ptr.{DoubleByReference, IntByReference} class UnderscoreFunctionMapper extends FunctionMapper { - override def getFunctionName(library: NativeLibrary, method: Method): String = { + override def getFunctionName(library: NativeLibrary, method: Method): String = method.getName() + "_" - } } // ALL LAPACK C function args must be passed by address, not value @@ -38,15 +39,17 @@ object LAPACK { versionTest(standard) match { case Success(version) => - log.info(s"Imported LAPACK library ${libraryName}, version ${version}, with standard names") + log.info(s"Imported LAPACK library $libraryName, version $version, with standard names") standard - case Failure(exception) => + case Failure(_) => val underscoreAfterMap = new java.util.HashMap[String, FunctionMapper]() underscoreAfterMap.put(Library.OPTION_FUNCTION_MAPPER, new UnderscoreFunctionMapper) val underscoreAfter = Native.load(libraryName, classOf[LAPACKLibrary], underscoreAfterMap) versionTest(underscoreAfter) match { case Success(version) => - log.info(s"Imported LAPACK library ${libraryName}, version ${version}, with underscore names") + log.info( + s"Imported LAPACK library $libraryName, version $version, with underscore names" + ) underscoreAfter case Failure(exception) => throw exception @@ -87,7 +90,21 @@ object LAPACK { infoInt.getValue() } - def dgemqrt(side: String, trans: String, m: Int, n: Int, k: Int, nb: Int, V: Long, ldV: Int, T: Long, ldT: Int, C: Long, ldC: Int, work: Long): Int = { + def dgemqrt( + side: String, + trans: String, + m: Int, + n: Int, + k: Int, + nb: Int, + V: Long, + ldV: Int, + T: Long, + ldT: Int, + C: Long, + ldC: Int, + work: Long, + ): Int = { val mInt = new IntByReference(m) val nInt = new IntByReference(n) val kInt = new IntByReference(k) @@ -96,7 +113,8 @@ object LAPACK { val ldTInt = new IntByReference(ldT) val ldCInt = new IntByReference(ldC) val infoInt = new IntByReference(1) - libraryInstance.get.dgemqrt(side, trans, mInt, nInt, kInt, nbInt, V, ldVInt, T, ldTInt, C, ldCInt, work, infoInt) + libraryInstance.get.dgemqrt(side, trans, mInt, nInt, kInt, nbInt, V, ldVInt, T, ldTInt, C, + ldCInt, work, infoInt) infoInt.getValue() } @@ -111,7 +129,21 @@ object LAPACK { infoInt.getValue() } - def dgemqr(side: String, trans: String, m: Int, n: Int, k: Int, A: Long, ldA: Int, T: Long, Tsize: Int, C: Long, ldC: Int, work: Long, Lwork: Int): Int = { + def dgemqr( + side: String, + trans: String, + m: Int, + n: Int, + k: Int, + A: Long, + ldA: Int, + T: Long, + Tsize: Int, + C: Long, + ldC: Int, + work: Long, + Lwork: Int, + ): Int = { val mInt = new IntByReference(m) val nInt = new IntByReference(n) val kInt = new IntByReference(k) @@ -120,11 +152,24 @@ object LAPACK { val ldCInt = new IntByReference(ldC) val LworkInt = new IntByReference(Lwork) val infoInt = new IntByReference(1) - libraryInstance.get.dgemqr(side, trans, mInt, nInt, kInt, A, ldAInt, T, TsizeInt, C, ldCInt, work, LworkInt, infoInt) + libraryInstance.get.dgemqr(side, trans, mInt, nInt, kInt, A, ldAInt, T, TsizeInt, C, ldCInt, + work, LworkInt, infoInt) infoInt.getValue() } - def dtpqrt(m: Int, n: Int, l: Int, nb: Int, A: Long, ldA: Int, B: Long, ldB: Int, T: Long, ldT: Int, work: Long): Int = { + def dtpqrt( + m: 
Int, + n: Int, + l: Int, + nb: Int, + A: Long, + ldA: Int, + B: Long, + ldB: Int, + T: Long, + ldT: Int, + work: Long, + ): Int = { val mInt = new IntByReference(m) val nInt = new IntByReference(n) val lInt = new IntByReference(l) @@ -133,11 +178,29 @@ object LAPACK { val ldBInt = new IntByReference(ldB) val ldTInt = new IntByReference(ldT) val infoInt = new IntByReference(1) - libraryInstance.get.dtpqrt(mInt, nInt, lInt, nbInt, A, ldAInt, B, ldBInt, T, ldTInt, work, infoInt) + libraryInstance.get.dtpqrt(mInt, nInt, lInt, nbInt, A, ldAInt, B, ldBInt, T, ldTInt, work, + infoInt) infoInt.getValue() } - def dtpmqrt(side: String, trans: String, m: Int, n: Int, k: Int, l: Int, nb: Int, V: Long, ldV: Int, T: Long, ldT: Int, A: Long, ldA: Int, B: Long, ldB: Int, work: Long): Int = { + def dtpmqrt( + side: String, + trans: String, + m: Int, + n: Int, + k: Int, + l: Int, + nb: Int, + V: Long, + ldV: Int, + T: Long, + ldT: Int, + A: Long, + ldA: Int, + B: Long, + ldB: Int, + work: Long, + ): Int = { val mInt = new IntByReference(m) val nInt = new IntByReference(n) val kInt = new IntByReference(k) @@ -148,13 +211,14 @@ object LAPACK { val ldAInt = new IntByReference(ldA) val ldBInt = new IntByReference(ldB) val infoInt = new IntByReference(1) - libraryInstance.get.dtpmqrt(side, trans, mInt, nInt, kInt, lInt, nbInt, V, ldVInt, T, ldTInt, A, ldAInt, B, ldBInt, work, infoInt) + libraryInstance.get.dtpmqrt(side, trans, mInt, nInt, kInt, lInt, nbInt, V, ldVInt, T, ldTInt, A, + ldAInt, B, ldBInt, work, infoInt) infoInt.getValue() } def dgetrf(M: Int, N: Int, A: Long, LDA: Int, IPIV: Long): Int = { val Mref = new IntByReference(M) - val Nref= new IntByReference(N) + val Nref = new IntByReference(N) val LDAref = new IntByReference(LDA) val INFOref = new IntByReference(1) @@ -172,17 +236,31 @@ object LAPACK { INFOref.getValue() } - - def dgesdd(JOBZ: String, M: Int, N: Int, A: Long, LDA: Int, S: Long, U: Long, LDU: Int, VT: Long, LDVT: Int, WORK: Long, LWORK: Int, IWORK: Long): Int = { + def dgesdd( + JOBZ: String, + M: Int, + N: Int, + A: Long, + LDA: Int, + S: Long, + U: Long, + LDU: Int, + VT: Long, + LDVT: Int, + WORK: Long, + LWORK: Int, + IWORK: Long, + ): Int = { val Mref = new IntByReference(M) - val Nref= new IntByReference(N) + val Nref = new IntByReference(N) val LDAref = new IntByReference(LDA) val LDUref = new IntByReference(LDU) val LDVTref = new IntByReference(LDVT) val LWORKRef = new IntByReference(LWORK) val INFOref = new IntByReference(1) - libraryInstance.get.dgesdd(JOBZ, Mref, Nref, A, LDAref, S, U, LDUref, VT, LDVTref, WORK, LWORKRef, IWORK, INFOref) + libraryInstance.get.dgesdd(JOBZ, Mref, Nref, A, LDAref, S, U, LDUref, VT, LDVTref, WORK, + LWORKRef, IWORK, INFOref) INFOref.getValue() } @@ -199,7 +277,27 @@ object LAPACK { INFOref.getValue() } - def dsyevr(jobz: String, range: String, uplo: String, n: Int, A: Long, ldA: Int, vl: Double, vu: Double, il: Int, iu: Int, abstol: Double, W: Long, Z: Long, ldZ: Int, ISuppZ: Long, Work: Long, lWork: Int, IWork: Long, lIWork: Int): Int = { + def dsyevr( + jobz: String, + range: String, + uplo: String, + n: Int, + A: Long, + ldA: Int, + vl: Double, + vu: Double, + il: Int, + iu: Int, + abstol: Double, + W: Long, + Z: Long, + ldZ: Int, + ISuppZ: Long, + Work: Long, + lWork: Int, + IWork: Long, + lIWork: Int, + ): Int = { val nRef = new IntByReference(n) val ldARef = new IntByReference(ldA) val vlRef = new DoubleByReference(vl) @@ -213,16 +311,26 @@ object LAPACK { val INFOref = new IntByReference(1) val mRef = new IntByReference(0) - 
libraryInstance.get.dsyevr(jobz, range, uplo, nRef, A, ldARef, vlRef, vuRef, ilRef, iuRef, abstolRef, mRef, W, Z, ldZRef, ISuppZ, Work, lWorkRef, IWork, lIWorkRef, INFOref) + libraryInstance.get.dsyevr(jobz, range, uplo, nRef, A, ldARef, vlRef, vuRef, ilRef, iuRef, + abstolRef, mRef, W, Z, ldZRef, ISuppZ, Work, lWorkRef, IWork, lIWorkRef, INFOref) INFOref.getValue() } - def dtrtrs(UPLO: String, TRANS: String, DIAG: String, N: Int, NRHS: Int, - A: Long, LDA: Int, B: Long, LDB: Int): Int = { - val Nref = new IntByReference(N) - val NRHSref = new IntByReference(NRHS) - val LDAref = new IntByReference(LDA) + def dtrtrs( + UPLO: String, + TRANS: String, + DIAG: String, + N: Int, + NRHS: Int, + A: Long, + LDA: Int, + B: Long, + LDB: Int, + ): Int = { + val Nref = new IntByReference(N) + val NRHSref = new IntByReference(NRHS) + val LDAref = new IntByReference(LDA) val LDBref = new IntByReference(LDB) val INFOref = new IntByReference(1) libraryInstance.get.dtrtrs(UPLO, TRANS, DIAG, Nref, NRHSref, A, LDAref, B, LDBref, INFOref) @@ -260,21 +368,225 @@ object LAPACK { } trait LAPACKLibrary extends Library { - def dgesv(N: IntByReference, NHRS: IntByReference, A: Long, LDA: IntByReference, IPIV: Long, B: Long, LDB: IntByReference, INFO: IntByReference) - def dgeqrf(M: IntByReference, N: IntByReference, A: Long, LDA: IntByReference, TAU: Long, WORK: Long, LWORK: IntByReference, INFO: IntByReference) - def dorgqr(M: IntByReference, N: IntByReference, K: IntByReference, A: Long, LDA: IntByReference, TAU: Long, WORK: Long, LWORK: IntByReference, INFO: IntByReference) - def dgeqrt(m: IntByReference, n: IntByReference, nb: IntByReference, A: Long, ldA: IntByReference, T: Long, ldT: IntByReference, work: Long, info: IntByReference) - def dgemqrt(side: String, trans: String, m: IntByReference, n: IntByReference, k: IntByReference, nb: IntByReference, V: Long, ldV: IntByReference, T: Long, ldT: IntByReference, C: Long, ldC: IntByReference, work: Long, info: IntByReference) - def dgeqr(m: IntByReference, n: IntByReference, A: Long, ldA: IntByReference, T: Long, Tsize: IntByReference, work: Long, lWork: IntByReference, info: IntByReference) - def dgemqr(side: String, trans: String, m: IntByReference, n: IntByReference, k: IntByReference, A: Long, ldA: IntByReference, T: Long, Tsize: IntByReference, C: Long, ldC: IntByReference, work: Long, Lwork: IntByReference, info: IntByReference) - def dtpqrt(M: IntByReference, N: IntByReference, L: IntByReference, NB: IntByReference, A: Long, LDA: IntByReference, B: Long, LDB: IntByReference, T: Long, LDT: IntByReference, WORK: Long, INFO: IntByReference) - def dtpmqrt(side: String, trans: String, M: IntByReference, N: IntByReference, K: IntByReference, L: IntByReference, NB: IntByReference, V: Long, LDV: IntByReference, T: Long, LDT: IntByReference, A: Long, LDA: IntByReference, B: Long, LDB: IntByReference, WORK: Long, INFO: IntByReference) - def dgetrf(M: IntByReference, N: IntByReference, A: Long, LDA: IntByReference, IPIV: Long, INFO: IntByReference) - def dgetri(N: IntByReference, A: Long, LDA: IntByReference, IPIV: Long, WORK: Long, LWORK: IntByReference, INFO: IntByReference) - def dgesdd(JOBZ: String, M: IntByReference, N: IntByReference, A: Long, LDA: IntByReference, S: Long, U: Long, LDU: IntByReference, VT: Long, LDVT: IntByReference, WORK: Long, LWORK: IntByReference, IWORK: Long, INFO: IntByReference) - def dsyevr(jobz: String, range: String, uplo: String, n: IntByReference, A: Long, ldA: IntByReference, vl: DoubleByReference, vu: DoubleByReference, il: 
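All of these wrappers follow the same calling convention: LAPACK's Fortran entry points take every argument by address, so scalar inputs are boxed in IntByReference/DoubleByReference, an INFO reference is passed along, and its value is read back after the call. A minimal sketch of that shape, with a made-up callable standing in for the JNA binding:

import com.sun.jna.ptr.IntByReference

object LapackCallShapeSketch {
  /** Box a scalar argument, invoke the (hypothetical) native routine, and unbox INFO.
    * INFO == 0 means success; a negative value flags the offending argument position.
    */
  def callWithInfo(native: (IntByReference, IntByReference) => Unit, n: Int): Int = {
    val nRef = new IntByReference(n)    // input scalar, passed by address
    val infoRef = new IntByReference(1) // status slot, overwritten by the routine
    native(nRef, infoRef)
    infoRef.getValue()
  }
}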
IntByReference, iu: IntByReference, abstol: DoubleByReference, m: IntByReference, W: Long, Z: Long, ldZ: IntByReference, ISuppZ: Long, Work: Long, lWork: IntByReference, IWork: Long, lIWork: IntByReference, info: IntByReference) - def ilaver(MAJOR: IntByReference, MINOR: IntByReference, PATCH: IntByReference) - def ilaenv(ispec: IntByReference, name: String, opts: String, n1: IntByReference, n2: IntByReference, n3: IntByReference, n4: IntByReference): Int - def dtrtrs(UPLO: String, TRANS: String, DIAG: String, N: IntByReference, NRHS: IntByReference, A: Long, LDA: IntByReference, B: Long, LDB: IntByReference, INFO:IntByReference) - def dlacpy(uplo: String, M: IntByReference, N: IntByReference, A: Long, ldA: IntByReference, B: Long, ldB: IntByReference) + def dgesv( + N: IntByReference, + NHRS: IntByReference, + A: Long, + LDA: IntByReference, + IPIV: Long, + B: Long, + LDB: IntByReference, + INFO: IntByReference, + ): Unit + + def dgeqrf( + M: IntByReference, + N: IntByReference, + A: Long, + LDA: IntByReference, + TAU: Long, + WORK: Long, + LWORK: IntByReference, + INFO: IntByReference, + ): Unit + + def dorgqr( + M: IntByReference, + N: IntByReference, + K: IntByReference, + A: Long, + LDA: IntByReference, + TAU: Long, + WORK: Long, + LWORK: IntByReference, + INFO: IntByReference, + ): Unit + + def dgeqrt( + m: IntByReference, + n: IntByReference, + nb: IntByReference, + A: Long, + ldA: IntByReference, + T: Long, + ldT: IntByReference, + work: Long, + info: IntByReference, + ): Unit + + def dgemqrt( + side: String, + trans: String, + m: IntByReference, + n: IntByReference, + k: IntByReference, + nb: IntByReference, + V: Long, + ldV: IntByReference, + T: Long, + ldT: IntByReference, + C: Long, + ldC: IntByReference, + work: Long, + info: IntByReference, + ): Unit + + def dgeqr( + m: IntByReference, + n: IntByReference, + A: Long, + ldA: IntByReference, + T: Long, + Tsize: IntByReference, + work: Long, + lWork: IntByReference, + info: IntByReference, + ): Unit + + def dgemqr( + side: String, + trans: String, + m: IntByReference, + n: IntByReference, + k: IntByReference, + A: Long, + ldA: IntByReference, + T: Long, + Tsize: IntByReference, + C: Long, + ldC: IntByReference, + work: Long, + Lwork: IntByReference, + info: IntByReference, + ): Unit + + def dtpqrt( + M: IntByReference, + N: IntByReference, + L: IntByReference, + NB: IntByReference, + A: Long, + LDA: IntByReference, + B: Long, + LDB: IntByReference, + T: Long, + LDT: IntByReference, + WORK: Long, + INFO: IntByReference, + ): Unit + + def dtpmqrt( + side: String, + trans: String, + M: IntByReference, + N: IntByReference, + K: IntByReference, + L: IntByReference, + NB: IntByReference, + V: Long, + LDV: IntByReference, + T: Long, + LDT: IntByReference, + A: Long, + LDA: IntByReference, + B: Long, + LDB: IntByReference, + WORK: Long, + INFO: IntByReference, + ): Unit + + def dgetrf( + M: IntByReference, + N: IntByReference, + A: Long, + LDA: IntByReference, + IPIV: Long, + INFO: IntByReference, + ): Unit + + def dgetri( + N: IntByReference, + A: Long, + LDA: IntByReference, + IPIV: Long, + WORK: Long, + LWORK: IntByReference, + INFO: IntByReference, + ): Unit + + def dgesdd( + JOBZ: String, + M: IntByReference, + N: IntByReference, + A: Long, + LDA: IntByReference, + S: Long, + U: Long, + LDU: IntByReference, + VT: Long, + LDVT: IntByReference, + WORK: Long, + LWORK: IntByReference, + IWORK: Long, + INFO: IntByReference, + ): Unit + + def dsyevr( + jobz: String, + range: String, + uplo: String, + n: IntByReference, + A: Long, + ldA: 
IntByReference, + vl: DoubleByReference, + vu: DoubleByReference, + il: IntByReference, + iu: IntByReference, + abstol: DoubleByReference, + m: IntByReference, + W: Long, + Z: Long, + ldZ: IntByReference, + ISuppZ: Long, + Work: Long, + lWork: IntByReference, + IWork: Long, + lIWork: IntByReference, + info: IntByReference, + ): Unit + + def ilaver(MAJOR: IntByReference, MINOR: IntByReference, PATCH: IntByReference): Unit + + def ilaenv( + ispec: IntByReference, + name: String, + opts: String, + n1: IntByReference, + n2: IntByReference, + n3: IntByReference, + n4: IntByReference, + ): Int + + def dtrtrs( + UPLO: String, + TRANS: String, + DIAG: String, + N: IntByReference, + NRHS: IntByReference, + A: Long, + LDA: IntByReference, + B: Long, + LDB: IntByReference, + INFO: IntByReference, + ): Unit + + def dlacpy( + uplo: String, + M: IntByReference, + N: IntByReference, + A: Long, + ldA: IntByReference, + B: Long, + ldB: IntByReference, + ): Unit } diff --git a/hail/src/main/scala/is/hail/linalg/LinalgCodeUtils.scala b/hail/src/main/scala/is/hail/linalg/LinalgCodeUtils.scala index 7392ee24cba..d38c9f7b7df 100644 --- a/hail/src/main/scala/is/hail/linalg/LinalgCodeUtils.scala +++ b/hail/src/main/scala/is/hail/linalg/LinalgCodeUtils.scala @@ -19,7 +19,7 @@ object LinalgCodeUtils { cb.assign(answer, true) cb.assign(runningProduct, st.elementByteSize) - (0 until nDims).foreach{ index => + (0 until nDims).foreach { index => cb.assign(answer, answer & (strides(index) ceq runningProduct)) cb.assign(runningProduct, runningProduct * (shapes(index) > 0L).mux(shapes(index), 1L)) } @@ -44,46 +44,66 @@ object LinalgCodeUtils { answer } - def createColumnMajorCode(pndv: SNDArrayValue, cb: EmitCodeBuilder, region: Value[Region]): SNDArrayValue = { + def createColumnMajorCode(pndv: SNDArrayValue, cb: EmitCodeBuilder, region: Value[Region]) + : SNDArrayValue = { val shape = pndv.shapes - val pt = PCanonicalNDArray(pndv.st.elementType.storageType().setRequired(true), pndv.st.nDims, false) + val pt = + PCanonicalNDArray(pndv.st.elementType.storageType().setRequired(true), pndv.st.nDims, false) val strides = pt.makeColumnMajorStrides(shape, cb) - val (dataFirstElementAddress, dataFinisher) = pt.constructDataFunction(shape, strides, cb, region) + val (_, dataFinisher) = + pt.constructDataFunction(shape, strides, cb, region) // construct an SNDArrayCode with undefined contents val result = dataFinisher(cb) - result.coiterateMutate(cb, region, (pndv, "pndv")) { case Seq(l, r) => r } + result.coiterateMutate(cb, region, (pndv, "pndv")) { case Seq(_, r) => r } result } - def checkColMajorAndCopyIfNeeded(aInput: SNDArrayValue, cb: EmitCodeBuilder, region: Value[Region]): SNDArrayValue = { + def checkColMajorAndCopyIfNeeded( + aInput: SNDArrayValue, + cb: EmitCodeBuilder, + region: Value[Region], + ): SNDArrayValue = { val aIsColumnMajor = LinalgCodeUtils.checkColumnMajor(aInput, cb) - val aColMajor = cb.emb.newPField("ndarray_output_column_major", aInput.st).asInstanceOf[SNDArraySettable] - cb.if_(aIsColumnMajor, {cb.assign(aColMajor, aInput)}, - { - cb.assign(aColMajor, LinalgCodeUtils.createColumnMajorCode(aInput, cb, region)) - }) + val aColMajor = + cb.emb.newPField("ndarray_output_column_major", aInput.st).asInstanceOf[SNDArraySettable] + cb.if_( + aIsColumnMajor, + cb.assign(aColMajor, aInput), + cb.assign(aColMajor, LinalgCodeUtils.createColumnMajorCode(aInput, cb, region)), + ) aColMajor } - def checkStandardStriding(aInput: SNDArrayValue, cb: EmitCodeBuilder, region: Value[Region]): (SNDArrayValue, 
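The column-major check that createColumnMajorCode and checkColMajorAndCopyIfNeeded build on compares each stride against a running product of the preceding dimension sizes, clamped to at least 1. A minimal non-codegen sketch of the same test on plain Scala values:

object ColumnMajorStridesSketch {
  /** True when `strides` describe a dense column-major layout for `shape`, given the
    * element size in bytes: stride(0) == elementByteSize and
    * stride(i) == stride(i - 1) * max(shape(i - 1), 1).
    */
  def isColumnMajor(shape: IndexedSeq[Long], strides: IndexedSeq[Long], elementByteSize: Long)
    : Boolean = {
    var expected = elementByteSize
    var ok = true
    var i = 0
    while (i < shape.length) {
      ok = ok && strides(i) == expected
      expected *= math.max(shape(i), 1L)
      i += 1
    }
    ok
  }
}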
Value[Boolean]) = { + def checkStandardStriding(aInput: SNDArrayValue, cb: EmitCodeBuilder, region: Value[Region]) + : (SNDArrayValue, Value[Boolean]) = { if (aInput.st.isInstanceOf[SUnreachableNDArray]) return (aInput, const(true)) val aIsColumnMajor = LinalgCodeUtils.checkColumnMajor(aInput, cb) - val a = cb.emb.newPField("ndarray_output_standardized", aInput.st).asInstanceOf[SNDArraySettable] - cb.if_(aIsColumnMajor, {cb.assign(a, aInput)}, { - val isRowMajor = LinalgCodeUtils.checkRowMajor(aInput, cb) - cb.if_(isRowMajor, {cb.assign(a, aInput)}, { - cb.assign(a, LinalgCodeUtils.createColumnMajorCode(aInput, cb, region)) - }) - }) + val a = + cb.emb.newPField("ndarray_output_standardized", aInput.st).asInstanceOf[SNDArraySettable] + cb.if_( + aIsColumnMajor, + cb.assign(a, aInput), { + val isRowMajor = LinalgCodeUtils.checkRowMajor(aInput, cb) + cb.if_( + isRowMajor, + cb.assign(a, aInput), + cb.assign(a, LinalgCodeUtils.createColumnMajorCode(aInput, cb, region)), + ) + }, + ) (a, aIsColumnMajor) } - def linearizeIndicesRowMajor(indices: IndexedSeq[Code[Long]], shapeArray: IndexedSeq[Value[Long]], mb: EmitMethodBuilder[_]): Code[Long] = { + def linearizeIndicesRowMajor( + indices: IndexedSeq[Code[Long]], + shapeArray: IndexedSeq[Value[Long]], + mb: EmitMethodBuilder[_], + ): Code[Long] = { val index = mb.genFieldThisRef[Long]() val elementsInProcessedDimensions = mb.genFieldThisRef[Long]() Code( @@ -92,14 +112,18 @@ object LinalgCodeUtils { Code.foreach(shapeArray.zip(indices).reverse) { case (shapeElement, currentIndex) => Code( index := index + currentIndex * elementsInProcessedDimensions, - elementsInProcessedDimensions := elementsInProcessedDimensions * shapeElement + elementsInProcessedDimensions := elementsInProcessedDimensions * shapeElement, ) }, - index + index, ) } - def unlinearizeIndexRowMajor(index: Code[Long], shapeArray: IndexedSeq[Value[Long]], mb: EmitMethodBuilder[_]): (Code[Unit], IndexedSeq[Value[Long]]) = { + def unlinearizeIndexRowMajor( + index: Code[Long], + shapeArray: IndexedSeq[Value[Long]], + mb: EmitMethodBuilder[_], + ): (Code[Unit], IndexedSeq[Value[Long]]) = { val nDim = shapeArray.length val newIndices = (0 until nDim).map(_ => mb.genFieldThisRef[Long]()) val elementsInProcessedDimensions = mb.genFieldThisRef[Long]() @@ -112,9 +136,9 @@ object LinalgCodeUtils { Code( elementsInProcessedDimensions := elementsInProcessedDimensions / shapeElement, newIndex := workRemaining / elementsInProcessedDimensions, - workRemaining := workRemaining % elementsInProcessedDimensions + workRemaining := workRemaining % elementsInProcessedDimensions, ) - } + }, ) (createShape, newIndices) } diff --git a/hail/src/main/scala/is/hail/linalg/RowMatrix.scala b/hail/src/main/scala/is/hail/linalg/RowMatrix.scala index 34a02b587a8..6adbaca171a 100644 --- a/hail/src/main/scala/is/hail/linalg/RowMatrix.scala +++ b/hail/src/main/scala/is/hail/linalg/RowMatrix.scala @@ -1,38 +1,40 @@ package is.hail.linalg -import breeze.linalg.DenseMatrix -import is.hail.HailContext import is.hail.backend.{BroadcastValue, ExecuteContext, HailStateManager} import is.hail.backend.spark.SparkBackend -import is.hail.types.virtual.{TInt64, TStruct} import is.hail.io.InputBuffer import is.hail.io.fs.FS import is.hail.rvd.RVDPartitioner +import is.hail.types.virtual.{TInt64, TStruct} import is.hail.utils._ + +import breeze.linalg.DenseMatrix +import org.apache.spark.{Partition, Partitioner, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -import org.apache.spark.{Partition, 
Partitioner, SparkContext, TaskContext} object RowMatrix { def apply(rows: RDD[(Long, Array[Double])], nCols: Int): RowMatrix = new RowMatrix(rows, nCols, None, None) - + def apply(rows: RDD[(Long, Array[Double])], nCols: Int, nRows: Long): RowMatrix = new RowMatrix(rows, nCols, Some(nRows), None) - - def apply(rows: RDD[(Long, Array[Double])], nCols: Int, nRows: Long, partitionCounts: Array[Long]): RowMatrix = + + def apply(rows: RDD[(Long, Array[Double])], nCols: Int, nRows: Long, partitionCounts: Array[Long]) + : RowMatrix = new RowMatrix(rows, nCols, Some(nRows), Some(partitionCounts)) - + def computePartitionCounts(partSize: Long, nRows: Long): Array[Long] = { val nParts = ((nRows - 1) / partSize).toInt + 1 val partitionCounts = Array.fill[Long](nParts)(partSize) partitionCounts(nParts - 1) = nRows - partSize * (nParts - 1) - + partitionCounts } def readBlockMatrix(fs: FS, uri: String, maybePartSize: java.lang.Integer): RowMatrix = { - val BlockMatrixMetadata(blockSize, nRows, nCols, maybeFiltered, partFiles) = BlockMatrix.readMetadata(fs, uri) + val BlockMatrixMetadata(blockSize, nRows, nCols, maybeFiltered, partFiles) = + BlockMatrix.readMetadata(fs, uri) if (nCols >= Int.MaxValue) { fatal(s"Number of columns must be less than 2^31, found $nCols") } @@ -43,124 +45,227 @@ object RowMatrix { new ReadBlocksAsRowsRDD(fs.broadcast, uri, partFiles, partitionCounts, gp), gp.nCols.toInt, gp.nRows, - partitionCounts) + partitionCounts, + ) } } -class RowMatrix(val rows: RDD[(Long, Array[Double])], +class RowMatrix( + val rows: RDD[(Long, Array[Double])], val nCols: Int, private var _nRows: Option[Long], - private var _partitionCounts: Option[Array[Long]]) extends Serializable { + private var _partitionCounts: Option[Array[Long]], +) extends Serializable { require(nCols > 0) - + def nRows: Long = _nRows match { case Some(nRows) => nRows case None => _nRows = Some(partitionCounts().sum) nRows - } - + } + def partitionCounts(): Array[Long] = _partitionCounts match { case Some(partitionCounts) => partitionCounts case None => _partitionCounts = Some(rows.countPerPartition()) partitionCounts() } - + // length nPartitions + 1, first element 0, last element rdd2 count def partitionStarts(): Array[Long] = partitionCounts().scanLeft(0L)(_ + _) def partitioner( partitionKey: Array[String] = Array("idx"), - kType: TStruct = TStruct("idx" -> TInt64)): RVDPartitioner = { - + kType: TStruct = TStruct("idx" -> TInt64), + ): RVDPartitioner = { + val partStarts = partitionStarts() - new RVDPartitioner(HailStateManager(Map.empty), partitionKey, kType, + new RVDPartitioner( + HailStateManager(Map.empty), + partitionKey, + kType, Array.tabulate(partStarts.length - 1) { i => val start = partStarts(i) val end = partStarts(i + 1) Interval(Row(start), Row(end), includesStart = true, includesEnd = false) - }) + }, + ) } - + def toBreezeMatrix(): DenseMatrix[Double] = { - require(_nRows.forall(_ <= Int.MaxValue), "The number of rows of this matrix should be less than or equal to " + - s"Int.MaxValue. Currently numRows: ${ _nRows.get }") - + require( + _nRows.forall(_ <= Int.MaxValue), + "The number of rows of this matrix should be less than or equal to " + + s"Int.MaxValue. Currently numRows: ${_nRows.get}", + ) + val a = rows.map(_._2).collect() val nRowsInt = a.length - - require(nRowsInt * nCols.toLong <= Int.MaxValue, "The length of the values array must be " + - s"less than or equal to Int.MaxValue. 
Currently rows * cols: ${ nRowsInt * nCols.toLong }") - + + require( + nRowsInt * nCols.toLong <= Int.MaxValue, + "The length of the values array must be " + + s"less than or equal to Int.MaxValue. Currently rows * cols: ${nRowsInt * nCols.toLong}", + ) + new DenseMatrix[Double](nRowsInt, nCols, a.flatten, 0, nCols, isTranspose = true) } - - def export(ctx: ExecuteContext, path: String, columnDelimiter: String, header: Option[String], addIndex: Boolean, exportType: String) { + + def export( + ctx: ExecuteContext, + path: String, + columnDelimiter: String, + header: Option[String], + addIndex: Boolean, + exportType: String, + ): Unit = { val localNCols = nCols - exportDelimitedRowSlices(ctx, path, columnDelimiter, header, addIndex, exportType, _ => 0, _ => localNCols) + exportDelimitedRowSlices( + ctx, + path, + columnDelimiter, + header, + addIndex, + exportType, + _ => 0, + _ => localNCols, + ) } // includes the diagonal - def exportLowerTriangle(ctx: ExecuteContext, path: String, columnDelimiter: String, header: Option[String], addIndex: Boolean, exportType: String) { + def exportLowerTriangle( + ctx: ExecuteContext, + path: String, + columnDelimiter: String, + header: Option[String], + addIndex: Boolean, + exportType: String, + ): Unit = { val localNCols = nCols - exportDelimitedRowSlices(ctx, path, columnDelimiter, header, addIndex, exportType, _ => 0, i => math.min(i + 1, localNCols.toLong).toInt) + exportDelimitedRowSlices( + ctx, + path, + columnDelimiter, + header, + addIndex, + exportType, + _ => 0, + i => math.min(i + 1, localNCols.toLong).toInt, + ) } - def exportStrictLowerTriangle(ctx: ExecuteContext, path: String, columnDelimiter: String, header: Option[String], addIndex: Boolean, exportType: String) { + def exportStrictLowerTriangle( + ctx: ExecuteContext, + path: String, + columnDelimiter: String, + header: Option[String], + addIndex: Boolean, + exportType: String, + ): Unit = { val localNCols = nCols - exportDelimitedRowSlices(ctx, path, columnDelimiter, header, addIndex, exportType, _ => 0, i => math.min(i, localNCols.toLong).toInt) + exportDelimitedRowSlices( + ctx, + path, + columnDelimiter, + header, + addIndex, + exportType, + _ => 0, + i => math.min(i, localNCols.toLong).toInt, + ) } // includes the diagonal - def exportUpperTriangle(ctx: ExecuteContext, path: String, columnDelimiter: String, header: Option[String], addIndex: Boolean, exportType: String) { + def exportUpperTriangle( + ctx: ExecuteContext, + path: String, + columnDelimiter: String, + header: Option[String], + addIndex: Boolean, + exportType: String, + ): Unit = { val localNCols = nCols - exportDelimitedRowSlices(ctx, path, columnDelimiter, header, addIndex, exportType, i => math.min(i, localNCols.toLong).toInt, _ => localNCols) - } - - def exportStrictUpperTriangle(ctx: ExecuteContext, path: String, columnDelimiter: String, header: Option[String], addIndex: Boolean, exportType: String) { + exportDelimitedRowSlices( + ctx, + path, + columnDelimiter, + header, + addIndex, + exportType, + i => math.min(i, localNCols.toLong).toInt, + _ => localNCols, + ) + } + + def exportStrictUpperTriangle( + ctx: ExecuteContext, + path: String, + columnDelimiter: String, + header: Option[String], + addIndex: Boolean, + exportType: String, + ): Unit = { val localNCols = nCols - exportDelimitedRowSlices(ctx, path, columnDelimiter, header, addIndex, exportType, i => math.min(i + 1, localNCols.toLong).toInt, _ => localNCols) + exportDelimitedRowSlices( + ctx, + path, + columnDelimiter, + header, + addIndex, + exportType, + 
i => math.min(i + 1, localNCols.toLong).toInt, + _ => localNCols, + ) } - - // convert elements in [start, end) of each array to a string, delimited by columnDelimiter, and export + + /* convert elements in [start, end) of each array to a string, delimited by columnDelimiter, and + * export */ def exportDelimitedRowSlices( ctx: ExecuteContext, - path: String, + path: String, columnDelimiter: String, header: Option[String], addIndex: Boolean, exportType: String, - start: (Long) => Int, - end: (Long) => Int) { - - genericExport(ctx, path, header, exportType, { (sb, i, v) => - if (addIndex) { - sb.append(i) - sb.append(columnDelimiter) - } - val l = start(i) - val r = end(i) - var j = l - while (j < r) { - if (j > l) + start: (Long) => Int, + end: (Long) => Int, + ): Unit = { + + genericExport( + ctx, + path, + header, + exportType, + { (sb, i, v) => + if (addIndex) { + sb.append(i) sb.append(columnDelimiter) - sb.append(v(j)) - j += 1 - } - }) + } + val l = start(i) + val r = end(i) + var j = l + while (j < r) { + if (j > l) + sb.append(columnDelimiter) + sb.append(v(j)) + j += 1 + } + }, + ) } // uses writeRow to convert each row to a string and writes that string to a file if non-empty def genericExport( ctx: ExecuteContext, - path: String, - header: Option[String], + path: String, + header: Option[String], exportType: String, - writeRow: (StringBuilder, Long, Array[Double]) => Unit) { - + writeRow: (StringBuilder, Long, Array[Double]) => Unit, + ): Unit = { + rows.mapPartitions { it => val sb = new StringBuilder() it.map { case (index, v) => @@ -180,26 +285,31 @@ class ReadBlocksAsRowsRDD( path: String, partFiles: IndexedSeq[String], partitionCounts: Array[Long], - gp: GridPartitioner) extends RDD[(Long, Array[Double])](SparkBackend.sparkContext("ReadBlocksAsRowsRDD"), Nil) { - + gp: GridPartitioner, +) extends RDD[(Long, Array[Double])](SparkBackend.sparkContext("ReadBlocksAsRowsRDD"), Nil) { + private val partitionStarts = partitionCounts.scanLeft(0L)(_ + _) - + if (partitionStarts.last != gp.nRows) - fatal(s"Error reading BlockMatrix as RowMatrix: expected ${partitionStarts.last} rows in RowMatrix, but found ${gp.nRows} rows in BlockMatrix") + fatal( + s"Error reading BlockMatrix as RowMatrix: expected ${partitionStarts.last} rows in RowMatrix, but found ${gp.nRows} rows in BlockMatrix" + ) if (gp.nCols > Int.MaxValue) - fatal(s"Cannot read BlockMatrix with ${gp.nCols} > Int.MaxValue columns as a RowMatrix") - + fatal(s"Cannot read BlockMatrix with ${gp.nCols} > Int.MaxValue columns as a RowMatrix") + private val nCols = gp.nCols.toInt private val nBlockCols = gp.nBlockCols private val blockSize = gp.blockSize protected def getPartitions: Array[Partition] = Array.tabulate(partitionStarts.length - 1)(pi => - ReadBlocksAsRowsRDDPartition(pi, partitionStarts(pi), partitionStarts(pi + 1))) + ReadBlocksAsRowsRDDPartition(pi, partitionStarts(pi), partitionStarts(pi + 1)) + ) def compute(split: Partition, context: TaskContext): Iterator[(Long, Array[Double])] = { - val ReadBlocksAsRowsRDDPartition(_, start, end) = split.asInstanceOf[ReadBlocksAsRowsRDDPartition] - + val ReadBlocksAsRowsRDDPartition(_, start, end) = + split.asInstanceOf[ReadBlocksAsRowsRDDPartition] + var inPerBlockCol: IndexedSeq[(InputBuffer, Int, Int)] = null var i = start @@ -210,7 +320,7 @@ class ReadBlocksAsRowsRDD( if (i == start || i % blockSize == 0) { val blockRow = (i / blockSize).toInt val nRowsInBlock = gp.blockRowNRows(blockRow) - + inPerBlockCol = (0 until nBlockCols) .flatMap { blockCol => val pi = 
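The four triangle exports differ only in the start and end functions they hand to exportDelimitedRowSlices, i.e. in which column range [start, end) is written for row i. A minimal sketch of those per-row ranges, with hypothetical helper names:

object TriangleRangesSketch {
  // column range [start, end) written for row i of an nCols-wide matrix

  def lowerTriangle(i: Long, nCols: Int): (Int, Int) = // includes the diagonal
    (0, math.min(i + 1, nCols.toLong).toInt)

  def strictLowerTriangle(i: Long, nCols: Int): (Int, Int) =
    (0, math.min(i, nCols.toLong).toInt)

  def upperTriangle(i: Long, nCols: Int): (Int, Int) = // includes the diagonal
    (math.min(i, nCols.toLong).toInt, nCols)

  def strictUpperTriangle(i: Long, nCols: Int): (Int, Int) =
    (math.min(i + 1, nCols.toLong).toInt, nCols)
}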
gp.coordinatesPart(blockRow, blockCol) @@ -226,7 +336,9 @@ class ReadBlocksAsRowsRDD( assert(in.readInt() == nColsInBlock) val isTranspose = in.readBoolean() if (!isTranspose) - fatal("BlockMatrix must be stored row major on disk in order to be read as a RowMatrix") + fatal( + "BlockMatrix must be stored row major on disk in order to be read as a RowMatrix" + ) if (i == start) { val skip = (start % blockSize).toInt * (nColsInBlock << 3) @@ -240,22 +352,22 @@ class ReadBlocksAsRowsRDD( } val row = new Array[Double](nCols) - + inPerBlockCol.foreach { case (in, blockCol, nColsInBlock) => in.readDoubles(row, blockCol * blockSize, nColsInBlock) } - + val iRow = (i, row) - + i += 1 - + if (i % blockSize == 0 || i == end) inPerBlockCol.foreach(_._1.close()) - + iRow } } } - + @transient override val partitioner: Option[Partitioner] = Some(RowPartitioner(partitionStarts)) } diff --git a/hail/src/main/scala/is/hail/linalg/RowPartitioner.scala b/hail/src/main/scala/is/hail/linalg/RowPartitioner.scala index fd01186e107..a016ce38ec9 100644 --- a/hail/src/main/scala/is/hail/linalg/RowPartitioner.scala +++ b/hail/src/main/scala/is/hail/linalg/RowPartitioner.scala @@ -3,20 +3,18 @@ package is.hail.linalg import org.apache.spark.Partitioner object RowPartitioner { - /** - * a represents a partitioning of the (mathematical) integers into intervals - * with interval j given by [a(j), a(j+1)). + + /** a represents a partitioning of the (mathematical) integers into intervals with interval j + * given by [a(j), a(j+1)). + * + * -infty -1 a(0) 0 a(1) 1 len-2 a(len-1) len-1 infty (-----------)[---------)[-------- ... + * ------------)[-----------------) * - * -infty -1 a(0) 0 a(1) 1 len-2 a(len-1) len-1 infty - * (-----------)[---------)[-------- ... ------------)[-----------------) - * * a must be non-decreasing; repeated values correspond to empty intervals. 
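The scaladoc reformatted here describes RowPartitioner.findInterval: given a non-decreasing array a of interval starts, return -1 for keys below a(0), otherwise the largest j with a(j) <= key (the binary search itself follows below in the diff). A minimal sketch of that contract, assuming nothing beyond the scaladoc:

object FindIntervalSketch {
  /** Largest j with a(j) <= key, or -1 when a is empty or key < a(0). */
  def findInterval(a: Array[Long], key: Long): Int = {
    var lo = -1       // invariant: lo == -1 or a(lo) <= key
    var hi = a.length // invariant: hi == a.length or key < a(hi)
    while (hi - lo > 1) {
      val mid = (lo + hi) >>> 1
      if (a(mid) <= key) lo = mid else hi = mid
    }
    lo
  }

  def main(args: Array[String]): Unit = {
    val starts = Array(0L, 4L, 4L, 10L) // a(1) == a(2), so interval 1 is empty
    assert(findInterval(starts, -1L) == -1)
    assert(findInterval(starts, 5L) == 2)
    assert(findInterval(starts, 11L) == 3)
  }
}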
* - * Returns interval containing key: - * -1 iff a is empty or key < a(0) - * j iff a(j) <= key < a(j + 1) - * len-1 iff a(len - 1) < key - **/ + * Returns interval containing key: -1 iff a is empty or key < a(0) j iff a(j) <= key < a(j + 1) + * len-1 iff a(len - 1) < key + */ def findInterval(a: Array[Long], key: Long): Int = { var lo = 0 var hi = a.length - 1 diff --git a/hail/src/main/scala/is/hail/lir/CFG.scala b/hail/src/main/scala/is/hail/lir/CFG.scala index 8a94b8c9d51..bfcd62b113f 100644 --- a/hail/src/main/scala/is/hail/lir/CFG.scala +++ b/hail/src/main/scala/is/hail/lir/CFG.scala @@ -26,8 +26,8 @@ object CFG { case x: SwitchX => edgeTo(x.Ldefault) x.Lcases.foreach(edgeTo) - case x: ReturnX => - case x: ThrowX => + case _: ReturnX => + case _: ThrowX => } } @@ -38,14 +38,15 @@ object CFG { class CFG( val entry: Int, val pred: Array[mutable.Set[Int]], - val succ: Array[mutable.Set[Int]]) { + val succ: Array[mutable.Set[Int]], +) { def nBlocks: Int = succ.length def dump(): Unit = { println(s"CFG $nBlocks:") var i = 0 while (i < nBlocks) { - println(s" $i: ${ succ(i).mkString(",") }") + println(s" $i: ${succ(i).mkString(",")}") i += 1 } } diff --git a/hail/src/main/scala/is/hail/lir/Emit.scala b/hail/src/main/scala/is/hail/lir/Emit.scala index d6737c92370..98c65e2bc3e 100644 --- a/hail/src/main/scala/is/hail/lir/Emit.scala +++ b/hail/src/main/scala/is/hail/lir/Emit.scala @@ -1,16 +1,16 @@ package is.hail.lir -import java.io.PrintWriter - import is.hail.utils._ -import org.objectweb.asm.{ClassReader, ClassVisitor, ClassWriter, Label} -import org.objectweb.asm.Opcodes._ -import org.objectweb.asm.util.{CheckClassAdapter, Textifier, TraceClassVisitor} import scala.collection.mutable -import java.io.ByteArrayOutputStream + +import java.io.{ByteArrayOutputStream, PrintWriter} import java.nio.charset.StandardCharsets +import org.objectweb.asm.{ClassReader, ClassVisitor, ClassWriter, Label} +import org.objectweb.asm.Opcodes._ +import org.objectweb.asm.util.{CheckClassAdapter, Textifier, TraceClassVisitor} + object Emit { def emitMethod(cv: ClassVisitor, m: Method, debugInformation: Boolean): Int = { val blocks = m.findBlocks() @@ -49,12 +49,11 @@ object Emit { } } - def getLocalIndex(l: Local): Int = { + def getLocalIndex(l: Local): Int = l match { case p: Parameter => parameterIndex(p.i) case _ => localIndex(l) } - } var maxStack = 0 var curLineNumber = -1 @@ -87,7 +86,12 @@ object Emit { setLineNumber(x.lineNumber) mv.visitMethodInsn( - INVOKESPECIAL, x.ctor.owner, x.ctor.name, x.ctor.desc, x.ctor.isInterface) + INVOKESPECIAL, + x.ctor.owner, + x.ctor.name, + x.ctor.desc, + x.ctor.isInterface, + ) instructionCount += 3 return case _ => @@ -114,7 +118,12 @@ object Emit { mv.visitJumpInsn(GOTO, labels(x.L)) case x: SwitchX => assert(x.Lcases.nonEmpty) - mv.visitTableSwitchInsn(0, x.Lcases.length - 1, labels(x.Ldefault), x.Lcases.map(labels): _*) + mv.visitTableSwitchInsn( + 0, + x.Lcases.length - 1, + labels(x.Ldefault), + x.Lcases.map(labels): _* + ) case x: ReturnX => if (x.children.length == 0) mv.visitInsn(RETURN) @@ -132,10 +141,20 @@ object Emit { mv.visitTypeInsn(x.op, x.ti.iname) case x: MethodX => mv.visitMethodInsn( - x.op, x.method.owner, x.method.name, x.method.desc, x.method.isInterface) + x.op, + x.method.owner, + x.method.name, + x.method.desc, + x.method.isInterface, + ) case x: MethodStmtX => mv.visitMethodInsn( - x.op, x.method.owner, x.method.name, x.method.desc, x.method.isInterface) + x.op, + x.method.owner, + x.method.name, + x.method.desc, + x.method.isInterface, + ) 
case x: LdcX => mv.visitLdcInsn(x.a) case x: GetFieldX => @@ -146,7 +165,7 @@ object Emit { mv.visitIincInsn(getLocalIndex(x.l), x.i) case x: StmtOpX => mv.visitInsn(x.op) - case x: ThrowX => + case _: ThrowX => mv.visitInsn(ATHROW) } } @@ -163,20 +182,18 @@ object Emit { mv.visitLabel(start) emitBlock(m.entry) - for (b <- blocks) { + for (b <- blocks) if (b ne m.entry) emitBlock(b) - } mv.visitLabel(end) - for (l <- locals) { + for (l <- locals) if (!l.isInstanceOf[Parameter]) { val n = localIndex(l) val name = if (l.name == null) s"local$n" else l.name mv.visitLocalVariable(name, l.ti.desc, null, start, end, n) } - } mv.visitMaxs(maxStack, nLocals) @@ -189,19 +206,18 @@ object Emit { cv.visit(V1_8, ACC_PUBLIC, c.name, null, c.superName, c.interfaces.toArray) c.sourceFile.foreach(cv.visitSource(_, null)) - for (f <- c.fields) { + for (f <- c.fields) f match { case f: Field => cv.visitField(ACC_PUBLIC, f.name, f.ti.desc, null, null) case f: StaticField => cv.visitField(ACC_PUBLIC | ACC_STATIC, f.name, f.ti.desc, null, null) } - } for (m <- c.methods) { val instructionCount = emitMethod(cv, m, c.sourceFile.isDefined) if (logMethodSizes) { - log.info(s"instruction count: $instructionCount: ${ c.name }.${ m.name }") + log.info(s"instruction count: $instructionCount: ${c.name}.${m.name}") if (instructionCount > 8000) - log.warn(s"big method: $instructionCount: ${ c.name }.${ m.name }") + log.warn(s"big method: $instructionCount: ${c.name}.${m.name}") } } @@ -209,28 +225,29 @@ object Emit { } def apply(c: Classx[_], print: Option[PrintWriter]): Array[Byte] = { - val bytes = try { - val cw = new ClassWriter(ClassWriter.COMPUTE_MAXS + ClassWriter.COMPUTE_FRAMES) + val bytes = + try { + val cw = new ClassWriter(ClassWriter.COMPUTE_MAXS + ClassWriter.COMPUTE_FRAMES) - emitClass(c, cw, logMethodSizes = true) + emitClass(c, cw, logMethodSizes = true) - val b = cw.toByteArray - // For efficiency, the ClassWriter does no checking, and may generate invalid - // bytecode. This will verify the generated class file, printing errors - // to System.out. - // This next line should always be commented out! + val b = cw.toByteArray + // For efficiency, the ClassWriter does no checking, and may generate invalid + // bytecode. This will verify the generated class file, printing errors + // to System.out. + // This next line should always be commented out! 
// CheckClassAdapter.verify(new ClassReader(b), false, new PrintWriter(System.err)) - b - } catch { - case e: Exception => - val buffer = new ByteArrayOutputStream() - val trace = new TraceClassVisitor(new PrintWriter(buffer)) - val check = new CheckClassAdapter(trace) - val classJVMByteCodeAsEscapedStr = buffer.toString(StandardCharsets.UTF_8.name()) - log.error(s"lir exception ${e}:\n" + classJVMByteCodeAsEscapedStr) - emitClass(c, check, logMethodSizes = false) - throw e - } + b + } catch { + case e: Exception => + val buffer = new ByteArrayOutputStream() + val trace = new TraceClassVisitor(new PrintWriter(buffer)) + val check = new CheckClassAdapter(trace) + val classJVMByteCodeAsEscapedStr = buffer.toString(StandardCharsets.UTF_8.name()) + log.error(s"lir exception $e:\n" + classJVMByteCodeAsEscapedStr) + emitClass(c, check, logMethodSizes = false) + throw e + } print.foreach { pw => val cr = new ClassReader(bytes) val tcv = new TraceClassVisitor(null, new Textifier, pw) diff --git a/hail/src/main/scala/is/hail/lir/InitializeLocals.scala b/hail/src/main/scala/is/hail/lir/InitializeLocals.scala index b5996eb8d27..5d9531c21ce 100644 --- a/hail/src/main/scala/is/hail/lir/InitializeLocals.scala +++ b/hail/src/main/scala/is/hail/lir/InitializeLocals.scala @@ -5,7 +5,7 @@ object InitializeLocals { m: Method, blocks: Blocks, locals: Locals, - liveness: Liveness + liveness: Liveness, ): Unit = { val entryIdx = blocks.index(m.entry) val entryUsedIn = liveness.liveIn(entryIdx) @@ -16,7 +16,8 @@ object InitializeLocals { if (!l.isInstanceOf[Parameter]) { // println(s" init $l ${l.ti}") m.entry.prepend( - store(locals(i), defaultValue(l.ti))) + store(locals(i), defaultValue(l.ti)) + ) } i = entryUsedIn.nextSetBit(i + 1) } diff --git a/hail/src/main/scala/is/hail/lir/Liveness.scala b/hail/src/main/scala/is/hail/lir/Liveness.scala index 1e063083ea9..032c8b02338 100644 --- a/hail/src/main/scala/is/hail/lir/Liveness.scala +++ b/hail/src/main/scala/is/hail/lir/Liveness.scala @@ -6,7 +6,7 @@ object Liveness { def apply( blocks: Blocks, locals: Locals, - cfg: CFG + cfg: CFG, ): Liveness = { val nBlocks = blocks.nBlocks @@ -85,4 +85,5 @@ object Liveness { } class Liveness( - val liveIn: Array[java.util.BitSet]) + val liveIn: Array[java.util.BitSet] +) diff --git a/hail/src/main/scala/is/hail/lir/PST.scala b/hail/src/main/scala/is/hail/lir/PST.scala index de57b4507b1..a8f6b685ff0 100644 --- a/hail/src/main/scala/is/hail/lir/PST.scala +++ b/hail/src/main/scala/is/hail/lir/PST.scala @@ -57,7 +57,8 @@ class PSTRegion( var end: Int, var children: Array[Int], // -1 means root, or parent not yet known - var parent: Int = -1) + var parent: Int = -1, +) object PSTResult { def unapply(result: PSTResult): Option[(Blocks, CFG, PST)] = @@ -67,13 +68,13 @@ object PSTResult { class PSTResult( val blocks: Blocks, val cfg: CFG, - val pst: PST + val pst: PST, ) class PSTBuilder( m: Method, blocks: Blocks, - cfg: CFG + cfg: CFG, ) { def nBlocks: Int = blocks.length @@ -153,10 +154,9 @@ class PSTBuilder( private def linearize(): Unit = { val pending = Array.tabulate(nBlocks) { i => var n = 0 - for (p <- cfg.pred(i)) { + for (p <- cfg.pred(i)) if (!backEdges(p -> i)) n += 1 - } n } var k = 0 @@ -249,7 +249,7 @@ class PSTBuilder( val maxTarget = maxTargetLE(end) val minSource = minSourceGE(start) (maxTarget == -1 || maxTarget <= start) && - (minSource == -1 || minSource >= end) + (minSource == -1 || minSource >= end) } private val splitBlock = new java.util.BitSet(nBlocks) @@ -285,9 +285,11 @@ class PSTBuilder( } private def 
addRoot(): Int = { - if (frontier.size == 1 && + if ( + frontier.size == 1 && regions(frontier(0)).start == 0 && - regions(frontier(0)).end == nBlocks - 1) { + regions(frontier(0)).end == nBlocks - 1 + ) { frontier(0) } else { val c = regions.length @@ -313,7 +315,7 @@ class PSTBuilder( // find regions in [start, end] // no edges from [0, start) target (start, end] private def findRegions(start: Int, end: Int): Unit = { - var regionStarts = new IntArrayBuilder() + val regionStarts = new IntArrayBuilder() regionStarts += start // find subregions of [start, end] @@ -488,7 +490,8 @@ class PSTBuilder( val pst = new PST( newSplitBlock.result(), newRegions, - newRoot) + newRoot, + ) new PSTResult(newBlocks, newCFG, pst) } } @@ -503,7 +506,7 @@ object PST { class PST( val splitBlock: Array[Boolean], val regions: Array[PSTRegion], - val root: Int + val root: Int, ) { def nBlocks: Int = splitBlock.length @@ -513,23 +516,22 @@ class PST( println(s"PST $nRegions:") def fmt(i: Int): String = - s"${ if (i > 0 && splitBlock(i - 1)) "<" else "" }$i${ if (splitBlock(i)) ">" else "" }" + s"${if (i > 0 && splitBlock(i - 1)) "<" else ""}$i${if (splitBlock(i)) ">" else ""}" println(" regions:") var i = 0 while (i < nRegions) { val r = regions(i) - println(s" $i: ${ fmt(r.start) } ${ fmt(r.end) } ${ r.parent } ${ r.children.mkString(",") }") + println(s" $i: ${fmt(r.start)} ${fmt(r.end)} ${r.parent} ${r.children.mkString(",")}") i += 1 } println(" children:") def printTree(i: Int, depth: Int): Unit = { val r = regions(i) - println(s"${ " " * depth }$i: ${ fmt(r.start) } ${ fmt(r.end) }") - for (c <- regions(i).children) { + println(s"${" " * depth}$i: ${fmt(r.start)} ${fmt(r.end)}") + for (c <- regions(i).children) printTree(c, depth + 2) - } } i = 0 diff --git a/hail/src/main/scala/is/hail/lir/Pretty.scala b/hail/src/main/scala/is/hail/lir/Pretty.scala index b9c405eaa0e..a7bcb78b3f1 100644 --- a/hail/src/main/scala/is/hail/lir/Pretty.scala +++ b/hail/src/main/scala/is/hail/lir/Pretty.scala @@ -1,11 +1,11 @@ package is.hail.lir +import is.hail.utils.StringEscapeUtils.escapeString + import java.io.{StringWriter, Writer} import org.objectweb.asm -import is.hail.utils.StringEscapeUtils.escapeString - class Builder(var n: Int, out: Writer, val printSourceLineNumbers: Boolean = false) { var lineNumber: Int = 0 @@ -28,9 +28,8 @@ class Builder(var n: Int, out: Writer, val printSourceLineNumbers: Boolean = fal out.write(s) } - def appendToLastLine(s: String): Unit = { + def appendToLastLine(s: String): Unit = out.write(s) - } } object Pretty { @@ -78,42 +77,37 @@ object Pretty { def fmt(c: Classx[_], b: Builder, saveLineNumbers: Boolean): Unit = { // FIXME interfaces if (b.printSourceLineNumbers) { - c.sourceFile.foreach { sf => - b += s"source file: ${ sf }" - } + c.sourceFile.foreach(sf => b += s"source file: $sf") } - b += s"class ${ c.name } extends ${ c.superName }" + b += s"class ${c.name} extends ${c.superName}" b.indent { - for (f <- c.fields) { - b += s"field ${ f.name } ${ f.ti.desc }" - } + for (f <- c.fields) + b += s"field ${f.name} ${f.ti.desc}" b += "" - for (m <- c.methods) { + for (m <- c.methods) fmt(m, b, saveLineNumbers) - } } } def fmt(m: Method, b: Builder, saveLineNumbers: Boolean): Unit = { val blocks = m.findBlocks() - b += s"def ${ m.name } (${ m.parameterTypeInfo.map(_.desc).mkString(",") })${ m.returnTypeInfo.desc }" + b += s"def ${m.name} (${m.parameterTypeInfo.map(_.desc).mkString(",")})${m.returnTypeInfo.desc}" - val label: Block => String = b => s"L${ blocks.index(b) }" + val label: 
Block => String = b => s"L${blocks.index(b)}" b.indent { - b += s"entry L${ blocks.index(m.entry) }" - for (ell <- blocks) { + b += s"entry L${blocks.index(m.entry)}" + for (ell <- blocks) fmt(ell, label, b, saveLineNumbers) - } } b += "" } def fmt(L: Block, label: Block => String, b: Builder, saveLineNumbers: Boolean): Unit = { - b += s"${ label(L) }:" + b += s"${label(L)}:" b.indent { var x = L.first @@ -134,45 +128,44 @@ object Pretty { if (saveLineNumbers) x.lineNumber = b.lineNumber b.indent { - for (c <- x.children) { + for (c <- x.children) if (c != null) fmt(c, label, b, saveLineNumbers) else b += "null" - } b.appendToLastLine(")") } } def header(x: X, label: Block => String): String = x match { - case x: IfX => s"${ asm.util.Printer.OPCODES(x.op) } ${ label(x.Ltrue) } ${ label(x.Lfalse) }" + case x: IfX => s"${asm.util.Printer.OPCODES(x.op)} ${label(x.Ltrue)} ${label(x.Lfalse)}" case x: GotoX => if (x.L != null) label(x.L) else "null" - case x: SwitchX => s"${ label(x.Ldefault) } (${ x.Lcases.map(label).mkString(" ") })" + case x: SwitchX => s"${label(x.Ldefault)} (${x.Lcases.map(label).mkString(" ")})" case x: LdcX => val lit = x.a match { case s: String => s""""${escapeString(s)}"""" case a => a.toString } - s"$lit ${ x.ti }" + s"$lit ${x.ti}" case x: InsnX => asm.util.Printer.OPCODES(x.op) case x: StoreX => x.l.toString - case x: IincX => s"${ x.l.toString } ${ x.i }" + case x: IincX => s"${x.l.toString} ${x.i}" case x: PutFieldX => - s"${ asm.util.Printer.OPCODES(x.op) } ${ x.f }" + s"${asm.util.Printer.OPCODES(x.op)} ${x.f}" case x: GetFieldX => - s"${ asm.util.Printer.OPCODES(x.op) } ${ x.f }" - case x: NewInstanceX => s"${ x.ti.iname } ${ x.ctor }" + s"${asm.util.Printer.OPCODES(x.op)} ${x.f}" + case x: NewInstanceX => s"${x.ti.iname} ${x.ctor}" case x: TypeInsnX => - s"${ asm.util.Printer.OPCODES(x.op) } ${ x.ti.iname }" + s"${asm.util.Printer.OPCODES(x.op)} ${x.ti.iname}" case x: NewArrayX => x.eti.desc case x: MethodX => - s"${ asm.util.Printer.OPCODES(x.op) } ${ x.method }" + s"${asm.util.Printer.OPCODES(x.op)} ${x.method}" case x: MethodStmtX => - s"${ asm.util.Printer.OPCODES(x.op) } ${ x.method }" + s"${asm.util.Printer.OPCODES(x.op)} ${x.method}" case x: LoadX => x.l.toString case x: StmtOpX => asm.util.Printer.OPCODES(x.op) case _ => diff --git a/hail/src/main/scala/is/hail/lir/SimplifyControl.scala b/hail/src/main/scala/is/hail/lir/SimplifyControl.scala index 4908a639052..5de6ec84844 100644 --- a/hail/src/main/scala/is/hail/lir/SimplifyControl.scala +++ b/hail/src/main/scala/is/hail/lir/SimplifyControl.scala @@ -5,9 +5,8 @@ import is.hail.utils.UnionFind import scala.collection.mutable object SimplifyControl { - def apply(m: Method): Unit = { + def apply(m: Method): Unit = new SimplifyControl(m).simplify() - } } class SimplifyControl(m: Method) { @@ -107,16 +106,17 @@ class SimplifyControl(m: Method) { val blocks = m.findBlocks() val u = new UnionFind(blocks.length) - blocks.indices.foreach { i => - u.makeSet(i) - } + blocks.indices.foreach(i => u.makeSet(i)) for (b <- blocks) { - if (b.first != null && - b.first.isInstanceOf[GotoX]) { + if ( + b.first != null && + b.first.isInstanceOf[GotoX] + ) { u.union( blocks.index(b), - blocks.index(b.first.asInstanceOf[GotoX].L)) + blocks.index(b.first.asInstanceOf[GotoX].L), + ) } } @@ -131,8 +131,7 @@ class SimplifyControl(m: Method) { val last = b.last.asInstanceOf[ControlX] var i = 0 while (i < last.targetArity()) { - last.setTarget(i, - rootFinalTarget(u.find(blocks.index(last.target(i))))) + last.setTarget(i, 
rootFinalTarget(u.find(blocks.index(last.target(i))))) i += 1 } } @@ -158,11 +157,11 @@ class SimplifyControl(m: Method) { assert(m.findBlocks().forall { b => !b.first.isInstanceOf[GotoX] && - (b.last match { - case i: IfX => i.Ltrue ne i.Lfalse - case g: GotoX => g.L.uses.size > 1 || (g.L eq m.entry) - case _ => true - }) + (b.last match { + case i: IfX => i.Ltrue ne i.Lfalse + case g: GotoX => g.L.uses.size > 1 || (g.L eq m.entry) + case _ => true + }) }) } } diff --git a/hail/src/main/scala/is/hail/lir/SplitLargeBlocks.scala b/hail/src/main/scala/is/hail/lir/SplitLargeBlocks.scala index b66905eccd5..b616cd632a8 100644 --- a/hail/src/main/scala/is/hail/lir/SplitLargeBlocks.scala +++ b/hail/src/main/scala/is/hail/lir/SplitLargeBlocks.scala @@ -44,9 +44,8 @@ object SplitLargeBlocks { def apply(m: Method): Unit = { val blocks = m.findBlocks() - for (b <- blocks) { + for (b <- blocks) if (b.approxByteCodeSize() > SplitMethod.TargetMethodSize) splitLargeBlock(m, b) - } } } diff --git a/hail/src/main/scala/is/hail/lir/SplitMethod.scala b/hail/src/main/scala/is/hail/lir/SplitMethod.scala index b6f67e476d9..533f094bf7c 100644 --- a/hail/src/main/scala/is/hail/lir/SplitMethod.scala +++ b/hail/src/main/scala/is/hail/lir/SplitMethod.scala @@ -2,11 +2,12 @@ package is.hail.lir import is.hail.asm4s._ import is.hail.utils._ -import org.objectweb.asm.Opcodes._ -import org.objectweb.asm._ import scala.collection.mutable +import org.objectweb.asm._ +import org.objectweb.asm.Opcodes._ + class SplitUnreachable() extends Exception() object SplitMethod { @@ -19,7 +20,7 @@ object SplitMethod { locals: Locals, cfg: CFG, liveness: Liveness, - pst: PST + pst: PST, ): Classx[_] = { val split = new SplitMethod(c, m, blocks, locals, cfg, liveness, pst) split.split() @@ -34,14 +35,14 @@ class SplitMethod( locals: Locals, cfg: CFG, liveness: Liveness, - pst: PST + pst: PST, ) { def nBlocks: Int = blocks.nBlocks def nLocals: Int = locals.nLocals private val blockPartitions = new UnionFind(nBlocks) - (0 until nBlocks).foreach { i => blockPartitions.makeSet(i) } + (0 until nBlocks).foreach(i => blockPartitions.makeSet(i)) private var methodSize = 0 @@ -58,19 +59,23 @@ class SplitMethod( } } - private val spillsClass = new Classx(genName("C", s"${ m.name }Spills"), "java/lang/Object", None) + private val spillsClass = new Classx(genName("C", s"${m.name}Spills"), "java/lang/Object", None) + private val spillsCtor = { val ctor = spillsClass.newMethod("", FastSeq(), UnitInfo) val L = new Block() ctor.setEntry(L) L.append( - methodStmt(INVOKESPECIAL, + methodStmt( + INVOKESPECIAL, "java/lang/Object", "", "()V", false, UnitInfo, - FastSeq(load(ctor.getParam(0))))) + FastSeq(load(ctor.getParam(0))), + ) + ) L.append(returnx()) ctor } @@ -99,8 +104,14 @@ class SplitMethod( val ti = classInfo[SplitUnreachable] val tcls = classOf[SplitUnreachable] val c = tcls.getDeclaredConstructor() - throwx(newInstance(ti, - Type.getInternalName(tcls), "", Type.getConstructorDescriptor(c), ti, FastSeq())) + throwx(newInstance( + ti, + Type.getInternalName(tcls), + "", + Type.getConstructorDescriptor(c), + ti, + FastSeq(), + )) } private val spills = m.newLocal("spills", spillsClass.ti) @@ -144,12 +155,11 @@ class SplitMethod( } } - def getSpills(): ValueX = { + def getSpills(): ValueX = if (method eq m) load(spills) else load(new Parameter(method, 1, spillsClass.ti)) - } def spill(x: X): Unit = { x.children.foreach(spill) @@ -167,10 +177,15 @@ class SplitMethod( val f = localField(x.l) if (f != null) { x.replace( - putField(f, getSpills(), + 
putField( + f, + getSpills(), insn2(IADD)( getField(f, getSpills()), - ldcInsn(x.i, IntInfo)))) + ldcInsn(x.i, IntInfo), + ), + ) + ) } case x: StoreX => assert(x.l ne spills) @@ -215,9 +230,8 @@ class SplitMethod( def spillLocals(): Unit = { createSpillFields() - for (splitM <- splitMethods) { + for (splitM <- splitMethods) spillLocals(splitM) - } spillLocals(m) } @@ -226,16 +240,15 @@ class SplitMethod( val method = s.containingMethod() assert(method != null) - def getSpills(): ValueX = { + def getSpills(): ValueX = if (method eq m) load(spills) else load(new Parameter(method, 1, spillsClass.ti)) - } val Lafter = new Block() Lafter.method = method - while (s.next!= null) { + while (s.next != null) { val n = s.next n.remove() Lafter.append(n) @@ -250,10 +263,8 @@ class SplitMethod( } else Lreturn.append(returnx()) s.insertAfter( - ifx(IFNE, - getField(spillReturned, getSpills()), - Lreturn, - Lafter)) + ifx(IFNE, getField(spillReturned, getSpills()), Lreturn, Lafter) + ) } } @@ -281,9 +292,8 @@ class SplitMethod( } localsToSpill.or(liveness.liveIn(start)) - for (s <- cfg.succ(end)) { + for (s <- cfg.succ(end)) localsToSpill.or(liveness.liveIn(s)) - } // replacement block for region val newL = new Block() @@ -300,9 +310,7 @@ class SplitMethod( if (m.entry == Lstart) m.setEntry(newL) - (start to end).foreach { i => - blockPartitions.union(start, i) - } + (start to end).foreach(i => blockPartitions.union(start, i)) updatedBlocks(blockPartitions.find(start)) = newL val returnTI = Lend.last match { @@ -317,7 +325,7 @@ class SplitMethod( case _: ThrowX => UnitInfo } - val splitM = c.newMethod(s"${ m.name }_region${ start }_$end", FastSeq(spillsClass.ti), returnTI) + val splitM = c.newMethod(s"${m.name}_region${start}_$end", FastSeq(spillsClass.ti), returnTI) splitMethods += splitM splitM.setEntry(Lstart) @@ -355,7 +363,8 @@ class SplitMethod( if (splitsReturn) splitReturnCalls += s } else { - splitMReturnValue = methodInsn(INVOKEVIRTUAL, splitM, Array(load(m.getParam(0)), load(spills))) + splitMReturnValue = + methodInsn(INVOKEVIRTUAL, splitM, Array(load(m.getParam(0)), load(spills))) if (splitsReturn) { val l = m.newLocal("splitMReturnValue", returnTI) val s = store(l, splitMReturnValue) @@ -397,7 +406,8 @@ class SplitMethod( x.setLtrue(Lreturn) } else { newL.append( - ifx(IFNE, splitMReturnValue, x.Ltrue, x.Lfalse)) + ifx(IFNE, splitMReturnValue, x.Ltrue, x.Lfalse) + ) val newLtrue = new Block() newLtrue.method = splitM @@ -456,8 +466,10 @@ class SplitMethod( var size = subr.iterator.map(_.size).sum var changed = true - while (changed && - size > SplitMethod.TargetMethodSize) { + while ( + changed && + size > SplitMethod.TargetMethodSize + ) { changed = false @@ -467,10 +479,12 @@ class SplitMethod( while (i < subr.size) { var s = subr(i).size var j = i + 1 - while (j < subr.size && + while ( + j < subr.size && subr(j).start == subr(i).end + 1 && pst.splitBlock(subr(i).end) && - (s + subr(j).size < SplitMethod.TargetMethodSize)) { + (s + subr(j).size < SplitMethod.TargetMethodSize) + ) { s += subr(j).size j += 1 } @@ -485,8 +499,10 @@ class SplitMethod( i = sortedsubr.length - 1 while (i >= 0) { val ri = sortedsubr(i) - if (ri.size > 20 && - size > SplitMethod.TargetMethodSize) { + if ( + ri.size > 20 && + size > SplitMethod.TargetMethodSize + ) { size -= ri.size splitSlice(ri.start, ri.end) @@ -527,12 +543,8 @@ class SplitMethod( regionSize(i) = blockSize(blockPartitions.find(r.start)) } // The PST no longer computes loop regions. See PR #13566 for the removed code. 
- /* - if (i != pst.root && pst.loopRegion.get(i)) { - splitSlice(r.start, r.end) - regionSize(i) = blockSize(blockPartitions.find(r.start)) - } - */ + /* if (i != pst.root && pst.loopRegion.get(i)) { splitSlice(r.start, r.end) regionSize(i) = + * blockSize(blockPartitions.find(r.start)) } */ i += 1 } } @@ -552,7 +564,8 @@ class SplitMethod( val putParam = putField( f, load(spills), - load(m.getParam(i + 1))) + load(m.getParam(i + 1)), + ) x.insertAfter(putParam) x = putParam } diff --git a/hail/src/main/scala/is/hail/lir/X.scala b/hail/src/main/scala/is/hail/lir/X.scala index fae1d4bda59..8b1f8739860 100644 --- a/hail/src/main/scala/is/hail/lir/X.scala +++ b/hail/src/main/scala/is/hail/lir/X.scala @@ -2,11 +2,13 @@ package is.hail.lir import is.hail.asm4s._ import is.hail.utils._ -import org.objectweb.asm.Opcodes._ -import java.io.PrintWriter import scala.collection.mutable +import java.io.PrintWriter + +import org.objectweb.asm.Opcodes._ + // FIXME move typeinfo stuff lir class Classx[C](val name: String, val superName: String, var sourceFile: Option[String]) { @@ -18,9 +20,8 @@ class Classx[C](val name: String, val superName: String, var sourceFile: Option[ val interfaces: mutable.ArrayBuffer[String] = new mutable.ArrayBuffer() - def addInterface(name: String): Unit = { + def addInterface(name: String): Unit = interfaces += name - } def newField(name: String, ti: TypeInfo[_]): Field = { val f = new Field(this, name, ti) @@ -36,10 +37,12 @@ class Classx[C](val name: String, val superName: String, var sourceFile: Option[ f } - def newMethod(name: String, + def newMethod( + name: String, parameterTypeInfo: IndexedSeq[TypeInfo[_]], returnTypeInfo: TypeInfo[_], - isStatic: Boolean = false): Method = { + isStatic: Boolean = false, + ): Method = { val method = new Method(this, name, parameterTypeInfo, returnTypeInfo, isStatic) methods += method method @@ -48,9 +51,7 @@ class Classx[C](val name: String, val superName: String, var sourceFile: Option[ def saveToFile(path: String): Unit = { val file = new java.io.File(path) file.getParentFile.mkdirs() - using (new java.io.PrintWriter(file)) { out => - Pretty(this, out, saveLineNumbers = true) - } + using(new java.io.PrintWriter(file))(out => Pretty(this, out, saveLineNumbers = true)) sourceFile = Some(path) } @@ -64,11 +65,12 @@ class Classx[C](val name: String, val superName: String, var sourceFile: Option[ } val shortName = name.take(50) - if (writeIRs) saveToFile(s"/tmp/hail/${shortName}.lir") + if (writeIRs) saveToFile(s"/tmp/hail/$shortName.lir") for (m <- methods) { - if (m.name != "" - && m.approxByteCodeSize() > SplitMethod.TargetMethodSize + if ( + m.name != "" + && m.approxByteCodeSize() > SplitMethod.TargetMethodSize ) { SplitLargeBlocks(m) @@ -99,21 +101,20 @@ class Classx[C](val name: String, val superName: String, var sourceFile: Option[ InitializeLocals(m, blocks, locals, liveness) } - if (writeIRs) saveToFile(s"/tmp/hail/${shortName}.split.lir") + if (writeIRs) saveToFile(s"/tmp/hail/$shortName.split.lir") // println(Pretty(this, saveLineNumbers = false)) classes.iterator.map { c => - val bytes = Emit(c, - print + val bytes = Emit( + c, + print, // Some(new PrintWriter(System.out)) ) if (writeIRs) { val classFile = new java.io.File(s"/tmp/hail/${c.name.take(50)}.class") classFile.getParentFile.mkdirs() - using (new java.io.FileOutputStream(classFile)) { fos => - fos.write(bytes) - } + using(new java.io.FileOutputStream(classFile))(fos => fos.write(bytes)) } (c.name.replace("/", "."), bytes) @@ -128,21 +129,24 @@ abstract class 
FieldRef { def ti: TypeInfo[_] - override def toString: String = s"$owner.$name ${ ti.desc }" + override def toString: String = s"$owner.$name ${ti.desc}" } -class Field private[lir] (classx: Classx[_], val name: String, val ti: TypeInfo[_]) extends FieldRef { +class Field private[lir] (classx: Classx[_], val name: String, val ti: TypeInfo[_]) + extends FieldRef { def owner: String = classx.name } -class StaticField private[lir] (classx: Classx[_], val name: String, val ti: TypeInfo[_]) extends FieldRef { +class StaticField private[lir] (classx: Classx[_], val name: String, val ti: TypeInfo[_]) + extends FieldRef { def owner: String = classx.name } class FieldLit( val owner: String, val name: String, - val ti: TypeInfo[_]) extends FieldRef + val ti: TypeInfo[_], +) extends FieldRef abstract class MethodRef { def owner: String @@ -156,7 +160,7 @@ abstract class MethodRef { def returnTypeInfo: TypeInfo[_] override def toString: String = - s"$owner.$name $desc${ if (isInterface) "interface" else "" }" + s"$owner.$name $desc${if (isInterface) "interface" else ""}" } class Method private[lir] ( @@ -164,7 +168,8 @@ class Method private[lir] ( val name: String, val parameterTypeInfo: IndexedSeq[TypeInfo[_]], val returnTypeInfo: TypeInfo[_], - val isStatic: Boolean) extends MethodRef { + val isStatic: Boolean, +) extends MethodRef { def nParameters: Int = parameterTypeInfo.length + (!isStatic).toInt @@ -176,19 +181,20 @@ class Method private[lir] ( private var _entry: Block = _ - def setEntry(newEntry: Block): Unit = { + def setEntry(newEntry: Block): Unit = _entry = newEntry - } def entry: Block = _entry - def getParam(i: Int): Parameter = { - new Parameter(this, i, + def getParam(i: Int): Parameter = + new Parameter( + this, + i, if (i == 0 && !isStatic) new ClassInfo(classx.name) else - parameterTypeInfo(i - (!isStatic).toInt)) - } + parameterTypeInfo(i - (!isStatic).toInt), + ) def newLocal(name: String, ti: TypeInfo[_]): Local = new Local(this, name, ti) @@ -211,12 +217,8 @@ class Method private[lir] ( if (L.method == null) L.method = this else { - /* - if (L.method ne this) { - println(s"${ L.method } $this") - // println(b.stack.mkString("\n")) - } - */ + /* if (L.method ne this) { println(s"${ L.method } $this") // + * println(b.stack.mkString("\n")) } */ assert(L.method eq this) } @@ -241,10 +243,9 @@ class Method private[lir] ( for (b <- blocks) { // don't traverse a set that's being modified val uses2 = b.uses.toArray - for ((u, i) <- uses2) { + for ((u, i) <- uses2) if (u.parent == null || !visited(u.parent)) u.setTarget(i, null) - } } new Blocks(blocks) @@ -259,7 +260,8 @@ class Method private[lir] ( if (i == 0 && !isStatic) new Parameter(this, 0, classx.ti) else - new Parameter(this, i, parameterTypeInfo(i - (!isStatic).toInt))) + new Parameter(this, i, parameterTypeInfo(i - (!isStatic).toInt)) + ) i += 1 } @@ -271,12 +273,8 @@ class Method private[lir] ( if (!verifyMethodAssignment || l.method == null) l.method = this else { - /* - if (l.method ne this) { - // println(s"$l ${l.method} ${this}\n ${l.stack.mkString(" \n")}") - println(s"$l ${l.method} ${this}") - } - */ + /* if (l.method ne this) { // println(s"$l ${l.method} ${this}\n ${l.stack.mkString(" + * \n")}") println(s"$l ${l.method} ${this}") } */ assert(l.method eq this) } @@ -309,32 +307,33 @@ class Method private[lir] ( // Verify all blocks are well-formed, all blocks and locals have correct // method set. 
- def verify(): Unit = { + def verify(): Unit = findLocals(findBlocks(), verifyMethodAssignment = true) - } def approxByteCodeSize(): Int = { val blocks = findBlocks() var size = 0 - for (b <- blocks) { + for (b <- blocks) size += b.approxByteCodeSize() - } size } } class MethodLit( - val owner: String, val name: String, val desc: String, val isInterface: Boolean, - val returnTypeInfo: TypeInfo[_] + val owner: String, + val name: String, + val desc: String, + val isInterface: Boolean, + val returnTypeInfo: TypeInfo[_], ) extends MethodRef class Local(var method: Method, val name: String, val ti: TypeInfo[_]) { - override def toString: String = f"t${ System.identityHashCode(this) }%08x/$name ${ ti.desc }" + override def toString: String = f"t${System.identityHashCode(this)}%08x/$name ${ti.desc}" // val stack = Thread.currentThread().getStackTrace } class Parameter(method: Method, val i: Int, ti: TypeInfo[_]) extends Local(method, null, ti) { - override def toString: String = s"arg:$i ${ ti.desc }" + override def toString: String = s"arg:$i ${ti.desc}" } class Block { @@ -375,9 +374,8 @@ class Block { // don't traverse a set that's being modified val uses2 = uses.toArray - for ((x, i) <- uses2) { + for ((x, i) <- uses2) x.setTarget(i, L) - } assert(uses.isEmpty) } @@ -385,9 +383,8 @@ class Block { assert(x.parent == null) if (x.isInstanceOf[ControlX]) // prepending a new control statement, so previous contents are dead code - while (last != null) { + while (last != null) last.remove() - } if (last == null) { first = x last = x @@ -431,7 +428,7 @@ class Block { last = null } - override def toString: String = f"L${ System.identityHashCode(this) }%08x" + override def toString: String = f"L${System.identityHashCode(this)}%08x" def approxByteCodeSize(): Int = { var size = 1 // for the block @@ -472,13 +469,8 @@ abstract class X { c.parent = null if (x != null) { - /* - if (x.parent != null) { - println(x.setParentStack.mkString("\n")) - println("-------") - println(x.stack.mkString("\n")) - } - */ + /* if (x.parent != null) { println(x.setParentStack.mkString("\n")) println("-------") + * println(x.stack.mkString("\n")) } */ assert(x.parent == null) x.parent = this // x.setParentStack = Thread.currentThread().getStackTrace @@ -702,26 +694,23 @@ class SwitchX(var lineNumber: Int = 0) extends ControlX { def Lcases: IndexedSeq[Block] = _Lcases def setLcases(newLcases: IndexedSeq[Block]): Unit = { - for ((block, i) <- _Lcases.zipWithIndex) { + for ((block, i) <- _Lcases.zipWithIndex) if (block != null) block.removeUse(this, i + 1) - } // don't allow sharing _Lcases = Array(newLcases: _*) - for ((block, i) <- _Lcases.zipWithIndex) { + for ((block, i) <- _Lcases.zipWithIndex) if (block != null) block.addUse(this, i + 1) - } } def targetArity(): Int = 1 + _Lcases.length - def target(i: Int): Block = { + def target(i: Int): Block = if (i == 0) _Ldefault else _Lcases(i - 1) - } def setTarget(i: Int, b: Block): Unit = { if (i == 0) { @@ -767,8 +756,7 @@ class StmtOpX(val op: Int, var lineNumber: Int = 0) extends StmtX class MethodStmtX(val op: Int, val method: MethodRef, var lineNumber: Int = 0) extends StmtX -class TypeInsnX(val op: Int, val ti: TypeInfo[_], var lineNumber: Int = 0) extends ValueX { -} +class TypeInsnX(val op: Int, val ti: TypeInfo[_], var lineNumber: Int = 0) extends ValueX {} class InsnX(val op: Int, _ti: TypeInfo[_], var lineNumber: Int = 0) extends ValueX { def ti: TypeInfo[_] = { @@ -863,8 +851,11 @@ class NewInstanceX(val ti: TypeInfo[_], val ctor: MethodRef, var lineNumber: Int 
class LdcX(val a: Any, val ti: TypeInfo[_], var lineNumber: Int = 0) extends ValueX { assert( - a.isInstanceOf[String] || a.isInstanceOf[Double] || a.isInstanceOf[Float] || a.isInstanceOf[Int] || a.isInstanceOf[Long], - s"not a string, double, float, int, or long: $a") + a.isInstanceOf[String] || a.isInstanceOf[Double] || a.isInstanceOf[Float] || a.isInstanceOf[ + Int + ] || a.isInstanceOf[Long], + s"not a string, double, float, int, or long: $a", + ) } class MethodX(val op: Int, val method: MethodRef, var lineNumber: Int = 0) extends ValueX { diff --git a/hail/src/main/scala/is/hail/lir/package.scala b/hail/src/main/scala/is/hail/lir/package.scala index a584cd0850d..ed88fa9c109 100644 --- a/hail/src/main/scala/is/hail/lir/package.scala +++ b/hail/src/main/scala/is/hail/lir/package.scala @@ -1,7 +1,10 @@ package is.hail -import is.hail.asm4s.{ArrayInfo, BooleanInfo, ClassInfo, DoubleInfo, FloatInfo, IntInfo, LongInfo, TypeInfo} +import is.hail.asm4s.{ + ArrayInfo, BooleanInfo, ClassInfo, DoubleInfo, FloatInfo, IntInfo, LongInfo, TypeInfo, +} import is.hail.utils.FastSeq + import org.objectweb.asm.Opcodes._ package object lir { @@ -14,7 +17,7 @@ package object lir { throw new RuntimeException(s"genName has invalid character(s): $baseName") s"__$tag$counter$baseName" } else - s"__$tag${ counter }null" + s"__$tag${counter}null" } def setChildren(x: X, cs: IndexedSeq[ValueX]): Unit = { @@ -132,9 +135,13 @@ package object lir { } def methodStmt( - op: Int, owner: String, name: String, desc: String, isInterface: Boolean, + op: Int, + owner: String, + name: String, + desc: String, + isInterface: Boolean, returnTypeInfo: TypeInfo[_], - args: IndexedSeq[ValueX] + args: IndexedSeq[ValueX], ): StmtX = { val x = new MethodStmtX(op, new MethodLit(owner, name, desc, isInterface, returnTypeInfo)) setChildren(x, args) @@ -142,7 +149,9 @@ package object lir { } def methodStmt( - op: Int, method: Method, args: IndexedSeq[ValueX] + op: Int, + method: Method, + args: IndexedSeq[ValueX], ): StmtX = { val x = new MethodStmtX(op, method) setChildren(x, args) @@ -150,9 +159,13 @@ package object lir { } def methodInsn( - op: Int, owner: String, name: String, desc: String, isInterface: Boolean, + op: Int, + owner: String, + name: String, + desc: String, + isInterface: Boolean, returnTypeInfo: TypeInfo[_], - args: IndexedSeq[ValueX] + args: IndexedSeq[ValueX], ): ValueX = { val x = new MethodX(op, new MethodLit(owner, name, desc, isInterface, returnTypeInfo)) setChildren(x, args) @@ -160,7 +173,9 @@ package object lir { } def methodInsn( - op: Int, m: MethodRef, args: IndexedSeq[ValueX] + op: Int, + m: MethodRef, + args: IndexedSeq[ValueX], ): ValueX = { val x = new MethodX(op, m) setChildren(x, args) @@ -234,18 +249,28 @@ package object lir { def newInstance( ti: TypeInfo[_], - owner: String, name: String, desc: String, returnTypeInfo: TypeInfo[_], - args: IndexedSeq[ValueX] + owner: String, + name: String, + desc: String, + returnTypeInfo: TypeInfo[_], + args: IndexedSeq[ValueX], ): ValueX = newInstance(ti, owner, name, desc, returnTypeInfo, args, 0) def newInstance( ti: TypeInfo[_], - owner: String, name: String, desc: String, returnTypeInfo: TypeInfo[_], + owner: String, + name: String, + desc: String, + returnTypeInfo: TypeInfo[_], args: IndexedSeq[ValueX], - lineNumber: Int + lineNumber: Int, ): ValueX = { - val x = new NewInstanceX(ti, new MethodLit(owner, name, desc, isInterface = false, returnTypeInfo), lineNumber) + val x = new NewInstanceX( + ti, + new MethodLit(owner, name, desc, isInterface = false, 
returnTypeInfo), + lineNumber, + ) setChildren(x, args) x } @@ -253,7 +278,8 @@ package object lir { def newInstance(ti: TypeInfo[_], method: Method, args: IndexedSeq[ValueX]): ValueX = newInstance(ti, method, args, 0) - def newInstance(ti: TypeInfo[_], method: Method, args: IndexedSeq[ValueX], lineNumber: Int): ValueX = { + def newInstance(ti: TypeInfo[_], method: Method, args: IndexedSeq[ValueX], lineNumber: Int) + : ValueX = { val x = new NewInstanceX(ti, method, lineNumber) setChildren(x, args) x diff --git a/hail/src/main/scala/is/hail/methods/FilterPartitions.scala b/hail/src/main/scala/is/hail/methods/FilterPartitions.scala index bf1d3da3850..a336cf87628 100644 --- a/hail/src/main/scala/is/hail/methods/FilterPartitions.scala +++ b/hail/src/main/scala/is/hail/methods/FilterPartitions.scala @@ -15,7 +15,9 @@ case class TableFilterPartitions(parts: Seq[Int], keep: Boolean) extends TableTo tv.rvd.subsetPartitions(parts.toArray) else { val subtract = parts.toSet - tv.rvd.subsetPartitions((0 until tv.rvd.getNumPartitions).filter(i => !subtract.contains(i)).toArray) + tv.rvd.subsetPartitions((0 until tv.rvd.getNumPartitions).filter(i => + !subtract.contains(i) + ).toArray) } tv.copy(rvd = newRVD) } diff --git a/hail/src/main/scala/is/hail/methods/ForceCount.scala b/hail/src/main/scala/is/hail/methods/ForceCount.scala index 36d5f21de2b..1c39e6028b4 100644 --- a/hail/src/main/scala/is/hail/methods/ForceCount.scala +++ b/hail/src/main/scala/is/hail/methods/ForceCount.scala @@ -1,10 +1,10 @@ package is.hail.methods import is.hail.backend.ExecuteContext -import is.hail.expr.ir.functions.{MatrixToValueFunction, TableToValueFunction} import is.hail.expr.ir.{MatrixValue, TableValue} -import is.hail.types.virtual.{TInt64, Type} +import is.hail.expr.ir.functions.{MatrixToValueFunction, TableToValueFunction} import is.hail.types.{MatrixType, RTable, TableType, TypeWithRequiredness} +import is.hail.types.virtual.{TInt64, Type} case class ForceCountTable() extends TableToValueFunction { override def typ(childType: TableType): Type = TInt64 @@ -19,7 +19,8 @@ case class ForceCountMatrixTable() extends MatrixToValueFunction { def unionRequiredness(childType: RTable, resultType: TypeWithRequiredness): Unit = () - override def execute(ctx: ExecuteContext, mv: MatrixValue): Any = throw new UnsupportedOperationException + override def execute(ctx: ExecuteContext, mv: MatrixValue): Any = + throw new UnsupportedOperationException override def lower(): Option[TableToValueFunction] = Some(ForceCountTable()) } diff --git a/hail/src/main/scala/is/hail/methods/IBD.scala b/hail/src/main/scala/is/hail/methods/IBD.scala index 19ebf90aed4..25c42d1e9a9 100644 --- a/hail/src/main/scala/is/hail/methods/IBD.scala +++ b/hail/src/main/scala/is/hail/methods/IBD.scala @@ -5,22 +5,25 @@ import is.hail.backend.ExecuteContext import is.hail.expr.ir._ import is.hail.expr.ir.functions.MatrixToTableFunction import is.hail.sparkextras.ContextRDD +import is.hail.types.{MatrixType, TableType} import is.hail.types.physical.{PCanonicalString, PCanonicalStruct, PFloat64, PInt64} import is.hail.types.virtual.{TFloat64, TStruct} -import is.hail.types.{MatrixType, TableType} import is.hail.utils._ import is.hail.variant.{AllelePair, Call, Genotype, HardCallView} -import org.apache.spark.sql.Row -import scala.language.higherKinds +import org.apache.spark.sql.Row object IBDInfo { - def apply(Z0: Double, Z1: Double, Z2: Double): IBDInfo = { + def apply(Z0: Double, Z1: Double, Z2: Double): IBDInfo = IBDInfo(Z0, Z1, Z2, Z1 / 2 + Z2) - } val pType 
= - PCanonicalStruct(("Z0", PFloat64()), ("Z1", PFloat64()), ("Z2", PFloat64()), ("PI_HAT", PFloat64())) + PCanonicalStruct( + ("Z0", PFloat64()), + ("Z1", PFloat64()), + ("Z2", PFloat64()), + ("PI_HAT", PFloat64()), + ) def fromRegionValue(offset: Long): IBDInfo = { val Z0 = Region.loadDouble(pType.loadField(offset, 0)) @@ -39,7 +42,7 @@ case class IBDInfo(Z0: Double, Z1: Double, Z2: Double, PI_HAT: Double) { def toAnnotation: Annotation = Annotation(Z0, Z1, Z2, PI_HAT) - def toRegionValue(rvb: RegionValueBuilder) { + def toRegionValue(rvb: RegionValueBuilder): Unit = { rvb.addDouble(Z0) rvb.addDouble(Z1) rvb.addDouble(Z2) @@ -49,7 +52,12 @@ case class IBDInfo(Z0: Double, Z1: Double, Z2: Double, PI_HAT: Double) { object ExtendedIBDInfo { val pType = - PCanonicalStruct(("ibd", IBDInfo.pType), ("ibs0", PInt64()), ("ibs1", PInt64()), ("ibs2", PInt64())) + PCanonicalStruct( + ("ibd", IBDInfo.pType), + ("ibs0", PInt64()), + ("ibs1", PInt64()), + ("ibs2", PInt64()), + ) def fromRegionValue(offset: Long): ExtendedIBDInfo = { val ibd = IBDInfo.fromRegionValue(pType.loadField(offset, 0)) @@ -62,13 +70,18 @@ object ExtendedIBDInfo { case class ExtendedIBDInfo(ibd: IBDInfo, ibs0: Long, ibs1: Long, ibs2: Long) { def pointwiseMinus(that: ExtendedIBDInfo): ExtendedIBDInfo = - ExtendedIBDInfo(ibd.pointwiseMinus(that.ibd), ibs0 - that.ibs0, ibs1 - that.ibs1, ibs2 - that.ibs2) + ExtendedIBDInfo( + ibd.pointwiseMinus(that.ibd), + ibs0 - that.ibs0, + ibs1 - that.ibs1, + ibs2 - that.ibs2, + ) def hasNaNs: Boolean = ibd.hasNaNs def makeRow(i: Any, j: Any): Row = Row(i, j, ibd.toAnnotation, ibs0, ibs1, ibs2) - def toRegionValue(rvb: RegionValueBuilder) { + def toRegionValue(rvb: RegionValueBuilder): Unit = { rvb.startStruct() ibd.toRegionValue(rvb) rvb.endStruct() @@ -79,12 +92,26 @@ case class ExtendedIBDInfo(ibd: IBDInfo, ibs0: Long, ibs1: Long, ibs2: Long) { } case class IBSExpectations( - E00: Double, E10: Double, E20: Double, - E11: Double, E21: Double, E22: Double = 1, nonNaNCount: Int = 1) { + E00: Double, + E10: Double, + E20: Double, + E11: Double, + E21: Double, + E22: Double = 1, + nonNaNCount: Int = 1, +) { def hasNaNs: Boolean = Array(E00, E10, E20, E11, E21).exists(_.isNaN) def normalized: IBSExpectations = - IBSExpectations(E00 / nonNaNCount, E10 / nonNaNCount, E20 / nonNaNCount, E11 / nonNaNCount, E21 / nonNaNCount, E22, this.nonNaNCount) + IBSExpectations( + E00 / nonNaNCount, + E10 / nonNaNCount, + E20 / nonNaNCount, + E11 / nonNaNCount, + E21 / nonNaNCount, + E22, + this.nonNaNCount, + ) def scaled(N: Long): IBSExpectations = IBSExpectations(E00 * N, E10 * N, E20 * N, E11 * N, E21 * N, E22 * N, this.nonNaNCount) @@ -95,12 +122,14 @@ case class IBSExpectations( else if (that.hasNaNs) this else - IBSExpectations(E00 + that.E00, + IBSExpectations( + E00 + that.E00, E10 + that.E10, E20 + that.E20, E11 + that.E11, E21 + that.E21, - nonNaNCount = nonNaNCount + that.nonNaNCount) + nonNaNCount = nonNaNCount + that.nonNaNCount, + ) } @@ -156,15 +185,21 @@ object IBD { maybeMaf.map(calculateCountsFromMAF).getOrElse(estimateFrequenciesFromSample) val Na = na - val a00 = 2 * p * p * q * q * ((x - 1) / x * (y - 1) / y * (Na / (Na - 1)) * (Na / (Na - 2)) * (Na / (Na - 3))) - val a10 = 4 * p * p * p * q * ((x - 1) / x * (x - 2) / x * (Na / (Na - 1)) * (Na / (Na - 2)) * (Na / (Na - 3))) + 4 * p * q * q * q * ((y - 1) / y * (y - 2) / y * (Na / (Na - 1)) * (Na / (Na - 2)) * (Na / (Na - 3))) - val a20 = q * q * q * q * ((y - 1) / y * (y - 2) / y * (y - 3) / y * (Na / (Na - 1)) * (Na / (Na - 2)) * (Na / (Na - 
3))) + p * p * p * p * ((x - 1) / x * (x - 2) / x * (x - 3) / x * (Na / (Na - 1)) * (Na / (Na - 2)) * (Na / (Na - 3))) + 4 * p * p * q * q * ((x - 1) / x * (y - 1) / y * (Na / (Na - 1)) * (Na / (Na - 2)) * (Na / (Na - 3))) - val a11 = 2 * p * p * q * ((x - 1) / x * Na / (Na - 1) * Na / (Na - 2)) + 2 * p * q * q * ((y - 1) / y * Na / (Na - 1) * Na / (Na - 2)) - val a21 = p * p * p * ((x - 1) / x * (x - 2) / x * Na / (Na - 1) * Na / (Na - 2)) + q * q * q * ((y - 1) / y * (y - 2) / y * Na / (Na - 1) * Na / (Na - 2)) + p * p * q * ((x - 1) / x * Na / (Na - 1) * Na / (Na - 2)) + p * q * q * ((y - 1) / y * Na / (Na - 1) * Na / (Na - 2)) + val a00 = + 2 * p * p * q * q * ((x - 1) / x * (y - 1) / y * (Na / (Na - 1)) * (Na / (Na - 2)) * (Na / (Na - 3))) + val a10 = + 4 * p * p * p * q * ((x - 1) / x * (x - 2) / x * (Na / (Na - 1)) * (Na / (Na - 2)) * (Na / (Na - 3))) + 4 * p * q * q * q * ((y - 1) / y * (y - 2) / y * (Na / (Na - 1)) * (Na / (Na - 2)) * (Na / (Na - 3))) + val a20 = + q * q * q * q * ((y - 1) / y * (y - 2) / y * (y - 3) / y * (Na / (Na - 1)) * (Na / (Na - 2)) * (Na / (Na - 3))) + p * p * p * p * ((x - 1) / x * (x - 2) / x * (x - 3) / x * (Na / (Na - 1)) * (Na / (Na - 2)) * (Na / (Na - 3))) + 4 * p * p * q * q * ((x - 1) / x * (y - 1) / y * (Na / (Na - 1)) * (Na / (Na - 2)) * (Na / (Na - 3))) + val a11 = + 2 * p * p * q * ((x - 1) / x * Na / (Na - 1) * Na / (Na - 2)) + 2 * p * q * q * ((y - 1) / y * Na / (Na - 1) * Na / (Na - 2)) + val a21 = + p * p * p * ((x - 1) / x * (x - 2) / x * Na / (Na - 1) * Na / (Na - 2)) + q * q * q * ((y - 1) / y * (y - 2) / y * Na / (Na - 1) * Na / (Na - 2)) + p * p * q * ((x - 1) / x * Na / (Na - 1) * Na / (Na - 2)) + p * q * q * ((y - 1) / y * Na / (Na - 1) * Na / (Na - 2)) IBSExpectations(a00, a10, a20, a11, a21) } - def calculateIBDInfo(N0: Long, N1: Long, N2: Long, ibse: IBSExpectations, bounded: Boolean): ExtendedIBDInfo = { + def calculateIBDInfo(N0: Long, N1: Long, N2: Long, ibse: IBSExpectations, bounded: Boolean) + : ExtendedIBDInfo = { val ibseN = ibse.scaled(N0 + N1 + N2) val Z0 = N0 / ibseN.E00 val Z1 = (N1 - Z0 * ibseN.E10) / ibseN.E11 @@ -197,13 +232,15 @@ object IBD { final val chunkSize = 1024 - def computeIBDMatrix(ctx: ExecuteContext, + def computeIBDMatrix( + ctx: ExecuteContext, input: MatrixValue, computeMaf: Option[(RegionValue) => Double], min: Option[Double], max: Option[Double], sampleIds: IndexedSeq[String], - bounded: Boolean): ContextRDD[Long] = { + bounded: Boolean, + ): ContextRDD[Long] = { val nSamples = input.nCols val sm = ctx.stateManager @@ -241,15 +278,18 @@ object IBD { .zipWithIndex .map { case (gtGroup, i) => ((i, variantId / chunkSize), (vid, gtGroup)) } } - .aggregateByKey(Array.fill(chunkSize * chunkSize)(IBSFFI.missingGTCRep))({ case (x, (vid, gs)) => - for (i <- gs.indices) x(vid * chunkSize + i) = gs(i) - x - }, { case (x, y) => - for (i <- y.indices) - if (x(i) == IBSFFI.missingGTCRep) - x(i) = y(i) - x - }) + .aggregateByKey(Array.fill(chunkSize * chunkSize)(IBSFFI.missingGTCRep))( + { case (x, (vid, gs)) => + for (i <- gs.indices) x(vid * chunkSize + i) = gs(i) + x + }, + { case (x, y) => + for (i <- y.indices) + if (x(i) == IBSFFI.missingGTCRep) + x(i) = y(i) + x + }, + ) .map { case ((s, v), gs) => (v, (s, IBSFFI.pack(chunkSize, chunkSize, gs))) } val joined = ContextRDD.weaken(chunkedGenotypeMatrix.join(chunkedGenotypeMatrix) @@ -278,7 +318,8 @@ object IBD { j = jChunk * chunkSize + sj if j > i && j < nSamples && i < nSamples idx = si * chunkSize + sj - eibd = calculateIBDInfo(ibses(idx * 3), ibses(idx * 
3 + 1), ibses(idx * 3 + 2), ibse, bounded) + eibd = + calculateIBDInfo(ibses(idx * 3), ibses(idx * 3 + 1), ibses(idx * 3 + 2), ibse, bounded) if min.forall(eibd.ibd.PI_HAT >= _) && max.forall(eibd.ibd.PI_HAT <= _) } yield { rvb.start(ibdPType) @@ -293,10 +334,18 @@ object IBD { } private val ibdPType = - PCanonicalStruct(required = true, Array(("i", PCanonicalString()), ("j", PCanonicalString())) ++ ExtendedIBDInfo.pType.fields.map(f => (f.name, f.typ)): _*) + PCanonicalStruct( + required = true, + Array( + ("i", PCanonicalString()), + ("j", PCanonicalString()), + ) ++ ExtendedIBDInfo.pType.fields.map(f => (f.name, f.typ)): _* + ) + private val ibdKey = FastSeq("i", "j") - private[methods] def generateComputeMaf(input: MatrixValue, fieldName: String): (RegionValue) => Double = { + private[methods] def generateComputeMaf(input: MatrixValue, fieldName: String) + : (RegionValue) => Double = { val rvRowType = input.rvRowType val rvRowPType = input.rvRowPType val field = rvRowType.field(fieldName) @@ -311,11 +360,13 @@ object IBD { val maf = Region.loadDouble(rvRowPType.loadField(rv.offset, idx)) if (!isDefined) { val row = new UnsafeRow(rvRowPType, rv).deleteField(entriesIdx) - fatal(s"The minor allele frequency expression evaluated to NA at ${ rowKeysF(row) }.") + fatal(s"The minor allele frequency expression evaluated to NA at ${rowKeysF(row)}.") } if (maf < 0.0 || maf > 1.0) { val row = new UnsafeRow(rvRowPType, rv).deleteField(entriesIdx) - fatal(s"The minor allele frequency expression for ${ rowKeysF(row) } evaluated to $maf which is not in [0,1].") + fatal( + s"The minor allele frequency expression for ${rowKeysF(row)} evaluated to $maf which is not in [0,1]." + ) } maf } @@ -326,14 +377,15 @@ case class IBD( mafFieldName: Option[String] = None, bounded: Boolean = true, min: Option[Double] = None, - max: Option[Double] = None) extends MatrixToTableFunction { + max: Option[Double] = None, +) extends MatrixToTableFunction { min.foreach(min => optionCheckInRangeInclusive(0.0, 1.0)("minimum", min)) max.foreach(max => optionCheckInRangeInclusive(0.0, 1.0)("maximum", max)) min.liftedZip(max).foreach { case (min, max) => if (min > max) { - fatal(s"minimum must be less than or equal to maximum: ${ min }, ${ max }") + fatal(s"minimum must be less than or equal to maximum: $min, $max") } } @@ -345,7 +397,8 @@ case class IBD( def execute(ctx: ExecuteContext, input: MatrixValue): TableValue = { input.requireUniqueSamples("ibd") val computeMaf = mafFieldName.map(IBD.generateComputeMaf(input, _)) - val crdd = IBD.computeIBDMatrix(ctx, input, computeMaf, min, max, input.stringSampleIds, bounded) + val crdd = + IBD.computeIBDMatrix(ctx, input, computeMaf, min, max, input.stringSampleIds, bounded) TableValue(ctx, IBD.ibdPType, IBD.ibdKey, crdd) } } diff --git a/hail/src/main/scala/is/hail/methods/IBSFFI.scala b/hail/src/main/scala/is/hail/methods/IBSFFI.scala index c979d387beb..93a0cc0cad1 100644 --- a/hail/src/main/scala/is/hail/methods/IBSFFI.scala +++ b/hail/src/main/scala/is/hail/methods/IBSFFI.scala @@ -2,15 +2,21 @@ package is.hail.methods import com.sun.jna._ -case class IBS (N0: Long, N1: Long, N2: Long) { } +case class IBS(N0: Long, N1: Long, N2: Long) {} object IBSFFI { val gtToCRep = Array[Byte](0, 1, 3) - val missingGTCRep : Byte = 2 + val missingGTCRep: Byte = 2 @native - def ibsMat(result: Array[Long], nSamples: Long, nPacks: Long, genotypes1: Array[Long], genotypes2: Array[Long]) + def ibsMat( + result: Array[Long], + nSamples: Long, + nPacks: Long, + genotypes1: Array[Long], + genotypes2: 
Array[Long], + ): Unit // NativeCode needs to control the initial loading of the libhail DLL, and // the call to getHailName() guarantees that. @@ -28,16 +34,32 @@ object IBSFFI { while (si != nSamples) { var pack = 0 while (pack != nPacks) { - val k = si + pack*genotypesPerPack*nSamples + val k = si + pack * genotypesPerPack * nSamples sampleOrientedGenotypes(si * nPacks + pack) = - gs(k).toLong << 62 | gs(k + 1 * nSamples).toLong << 60 | gs(k + 2 * nSamples).toLong << 58 | gs(k + 3 * nSamples).toLong << 56 | - gs(k + 4 * nSamples).toLong << 54 | gs(k + 5 * nSamples).toLong << 52 | gs(k + 6 * nSamples).toLong << 50 | gs(k + 7 * nSamples).toLong << 48 | - gs(k + 8 * nSamples).toLong << 46 | gs(k + 9 * nSamples).toLong << 44 | gs(k + 10 * nSamples).toLong << 42 | gs(k + 11 * nSamples).toLong << 40 | - gs(k + 12 * nSamples).toLong << 38 | gs(k + 13 * nSamples).toLong << 36 | gs(k + 14 * nSamples).toLong << 34 | gs(k + 15 * nSamples).toLong << 32 | - gs(k + 16 * nSamples).toLong << 30 | gs(k + 17 * nSamples).toLong << 28 | gs(k + 18 * nSamples).toLong << 26 | gs(k + 19 * nSamples).toLong << 24 | - gs(k + 20 * nSamples).toLong << 22 | gs(k + 21 * nSamples).toLong << 20 | gs(k + 22 * nSamples).toLong << 18 | gs(k + 23 * nSamples).toLong << 16 | - gs(k + 24 * nSamples).toLong << 14 | gs(k + 25 * nSamples).toLong << 12 | gs(k + 26 * nSamples).toLong << 10 | gs(k + 27 * nSamples).toLong << 8 | - gs(k + 28 * nSamples).toLong << 6 | gs(k + 29 * nSamples).toLong << 4 | gs(k + 30 * nSamples).toLong << 2 | gs(k + 31 * nSamples).toLong + gs(k).toLong << 62 | gs(k + 1 * nSamples).toLong << 60 | gs( + k + 2 * nSamples + ).toLong << 58 | gs(k + 3 * nSamples).toLong << 56 | + gs(k + 4 * nSamples).toLong << 54 | gs(k + 5 * nSamples).toLong << 52 | gs( + k + 6 * nSamples + ).toLong << 50 | gs(k + 7 * nSamples).toLong << 48 | + gs(k + 8 * nSamples).toLong << 46 | gs(k + 9 * nSamples).toLong << 44 | gs( + k + 10 * nSamples + ).toLong << 42 | gs(k + 11 * nSamples).toLong << 40 | + gs(k + 12 * nSamples).toLong << 38 | gs(k + 13 * nSamples).toLong << 36 | gs( + k + 14 * nSamples + ).toLong << 34 | gs(k + 15 * nSamples).toLong << 32 | + gs(k + 16 * nSamples).toLong << 30 | gs(k + 17 * nSamples).toLong << 28 | gs( + k + 18 * nSamples + ).toLong << 26 | gs(k + 19 * nSamples).toLong << 24 | + gs(k + 20 * nSamples).toLong << 22 | gs(k + 21 * nSamples).toLong << 20 | gs( + k + 22 * nSamples + ).toLong << 18 | gs(k + 23 * nSamples).toLong << 16 | + gs(k + 24 * nSamples).toLong << 14 | gs(k + 25 * nSamples).toLong << 12 | gs( + k + 26 * nSamples + ).toLong << 10 | gs(k + 27 * nSamples).toLong << 8 | + gs(k + 28 * nSamples).toLong << 6 | gs(k + 29 * nSamples).toLong << 4 | gs( + k + 30 * nSamples + ).toLong << 2 | gs(k + 31 * nSamples).toLong pack += 1 } diff --git a/hail/src/main/scala/is/hail/methods/LinearRegression.scala b/hail/src/main/scala/is/hail/methods/LinearRegression.scala index db2dee7d1f1..37c5724d796 100644 --- a/hail/src/main/scala/is/hail/methods/LinearRegression.scala +++ b/hail/src/main/scala/is/hail/methods/LinearRegression.scala @@ -1,17 +1,18 @@ package is.hail.methods -import breeze.linalg._ -import breeze.numerics.sqrt import is.hail.HailContext import is.hail.annotations._ import is.hail.backend.ExecuteContext -import is.hail.expr.ir.functions.MatrixToTableFunction import is.hail.expr.ir.{IntArrayBuilder, MatrixValue, TableValue} +import is.hail.expr.ir.functions.MatrixToTableFunction +import is.hail.stats._ import is.hail.types._ import is.hail.types.physical.PStruct import 
is.hail.types.virtual.{TArray, TFloat64, TInt32, TStruct} -import is.hail.stats._ import is.hail.utils._ + +import breeze.linalg._ +import breeze.numerics.sqrt import net.sourceforge.jdistlib.T case class LinearRegressionRowsSingle( @@ -19,7 +20,8 @@ case class LinearRegressionRowsSingle( xField: String, covFields: Seq[String], rowBlockSize: Int, - passThrough: Seq[String]) extends MatrixToTableFunction { + passThrough: Seq[String], +) extends MatrixToTableFunction { override def typ(childType: MatrixType): TableType = { val passThroughType = TStruct(passThrough.map(f => f -> childType.rowType.field(f).typ): _*) @@ -30,17 +32,20 @@ case class LinearRegressionRowsSingle( ("beta", TArray(TFloat64)), ("standard_error", TArray(TFloat64)), ("t_stat", TArray(TFloat64)), - ("p_value", TArray(TFloat64))) + ("p_value", TArray(TFloat64)), + ) TableType( childType.rowKeyStruct ++ passThroughType ++ schema, childType.rowKey, - TStruct.empty) + TStruct.empty, + ) } def preservesPartitionCounts: Boolean = true def execute(ctx: ExecuteContext, mv: MatrixValue): TableValue = { - val (y, cov, completeColIdx) = RegressionUtils.getPhenosCovCompleteSamples(mv, yFields.toArray, covFields.toArray) + val (y, cov, completeColIdx) = + RegressionUtils.getPhenosCovCompleteSamples(mv, yFields.toArray, covFields.toArray) val n = y.rows // n_complete_samples val k = cov.cols // nCovariates @@ -48,10 +53,12 @@ case class LinearRegressionRowsSingle( val dRec = 1d / d if (d < 1) - fatal(s"$n samples and ${ k + 1 } ${ plural(k, "covariate") } (including x) implies $d degrees of freedom.") + fatal( + s"$n samples and ${k + 1} ${plural(k, "covariate")} (including x) implies $d degrees of freedom." + ) - info(s"linear_regression_rows: running on $n samples for ${ y.cols } response ${ plural(y.cols, "variable") } y,\n" - + s" with input variable x, and ${ k } additional ${ plural(k, "covariate") }...") + info(s"linear_regression_rows: running on $n samples for ${y.cols} response ${plural(y.cols, "variable")} y,\n" + + s" with input variable x, and $k additional ${plural(k, "covariate")}...") val Qt = if (k > 0) @@ -84,92 +91,104 @@ case class LinearRegressionRowsSingle( val sm = ctx.stateManager val newRVD = mv.rvd.mapPartitionsWithContext( - rvdType) { (consumerCtx, it) => - val producerCtx = consumerCtx.freshContext - val rvb = new RegionValueBuilder(sm) - - val missingCompleteCols = new IntArrayBuilder() - val data = new Array[Double](n * rowBlockSize) - - val blockWRVs = new Array[WritableRegionValue](rowBlockSize) - var i = 0 - while (i < rowBlockSize) { - blockWRVs(i) = WritableRegionValue(sm, fullRowType, producerCtx.freshRegion()) - i += 1 - } + rvdType + ) { (consumerCtx, it) => + val producerCtx = consumerCtx.freshContext + val rvb = new RegionValueBuilder(sm) + + val missingCompleteCols = new IntArrayBuilder() + val data = new Array[Double](n * rowBlockSize) + + val blockWRVs = new Array[WritableRegionValue](rowBlockSize) + var i = 0 + while (i < rowBlockSize) { + blockWRVs(i) = WritableRegionValue(sm, fullRowType, producerCtx.freshRegion()) + i += 1 + } - it(producerCtx).trueGroupedIterator(rowBlockSize) - .flatMap { git => - var i = 0 - while (git.hasNext) { - val ptr = git.next() - RegressionUtils.setMeanImputedDoubles(data, i * n, completeColIdxBc.value, missingCompleteCols, - ptr, fullRowType, entryArrayType, entryType, entryArrayIdx, fieldIdx) - blockWRVs(i).set(ptr, true) - producerCtx.region.clear() - i += 1 - } - val blockLength = i + it(producerCtx).trueGroupedIterator(rowBlockSize) + .flatMap { git => + 
var i = 0 + while (git.hasNext) { + val ptr = git.next() + RegressionUtils.setMeanImputedDoubles( + data, + i * n, + completeColIdxBc.value, + missingCompleteCols, + ptr, + fullRowType, + entryArrayType, + entryType, + entryArrayIdx, + fieldIdx, + ) + blockWRVs(i).set(ptr, true) + producerCtx.region.clear() + i += 1 + } + val blockLength = i - val X = new DenseMatrix[Double](n, blockLength, data) + val X = new DenseMatrix[Double](n, blockLength, data) - val AC: DenseVector[Double] = X.t(*, ::).map(r => sum(r)) - assert(AC.length == blockLength) + val AC: DenseVector[Double] = X.t(*, ::).map(r => sum(r)) + assert(AC.length == blockLength) - val qtx: DenseMatrix[Double] = QtBc.value * X - val qty: DenseMatrix[Double] = QtyBc.value - val xxpRec: DenseVector[Double] = 1.0 / (X.t(*, ::).map(r => r dot r) - qtx.t(*, ::).map(r => r dot r)) - val ytx: DenseMatrix[Double] = yBc.value.t * X - assert(ytx.rows == yBc.value.cols && ytx.cols == blockLength) + val qtx: DenseMatrix[Double] = QtBc.value * X + val qty: DenseMatrix[Double] = QtyBc.value + val xxpRec: DenseVector[Double] = + 1.0 / (X.t(*, ::).map(r => r dot r) - qtx.t(*, ::).map(r => r dot r)) + val ytx: DenseMatrix[Double] = yBc.value.t * X + assert(ytx.rows == yBc.value.cols && ytx.cols == blockLength) - val xyp: DenseMatrix[Double] = ytx - (qty.t * qtx) - val yyp: DenseVector[Double] = yypBc.value + val xyp: DenseMatrix[Double] = ytx - (qty.t * qtx) + val yyp: DenseVector[Double] = yypBc.value - // resuse xyp - val b = xyp - i = 0 - while (i < blockLength) { - xyp(::, i) :*= xxpRec(i) - i += 1 - } + // resuse xyp + val b = xyp + i = 0 + while (i < blockLength) { + xyp(::, i) :*= xxpRec(i) + i += 1 + } - val se = sqrt(dRec * (yyp * xxpRec.t - (b *:* b))) + val se = sqrt(dRec * (yyp * xxpRec.t - (b *:* b))) - val t = b /:/ se - val p = t.map(s => 2 * T.cumulative(-math.abs(s), d, true, false)) - - (0 until blockLength).iterator.map { i => - val wrv = blockWRVs(i) - rvb.set(wrv.region) - rvb.start(rvdType.rowType) - rvb.startStruct() - rvb.addFields(fullRowType, wrv.region, wrv.offset, copiedFieldIndices) - rvb.addInt(n) - rvb.addDouble(AC(i)) - - def addSlice(dm: DenseMatrix[Double]) { - rvb.startArray(nDependentVariables) - var j = 0 - while (j < nDependentVariables) { - rvb.addDouble(dm(j, i)) - j += 1 - } - rvb.endArray() + val t = b /:/ se + val p = t.map(s => 2 * T.cumulative(-math.abs(s), d, true, false)) + + (0 until blockLength).iterator.map { i => + val wrv = blockWRVs(i) + rvb.set(wrv.region) + rvb.start(rvdType.rowType) + rvb.startStruct() + rvb.addFields(fullRowType, wrv.region, wrv.offset, copiedFieldIndices) + rvb.addInt(n) + rvb.addDouble(AC(i)) + + def addSlice(dm: DenseMatrix[Double]): Unit = { + rvb.startArray(nDependentVariables) + var j = 0 + while (j < nDependentVariables) { + rvb.addDouble(dm(j, i)) + j += 1 } + rvb.endArray() + } - addSlice(ytx) - addSlice(b) - addSlice(se) - addSlice(t) - addSlice(p) + addSlice(ytx) + addSlice(b) + addSlice(se) + addSlice(t) + addSlice(p) - rvb.endStruct() + rvb.endStruct() - producerCtx.region.addReferenceTo(wrv.region) - rvb.end() - } + producerCtx.region.addReferenceTo(wrv.region) + rvb.end() } - } + } + } TableValue(ctx, tableType, BroadcastRow.empty(ctx), newRVD) } } @@ -179,7 +198,8 @@ case class LinearRegressionRowsChained( xField: String, covFields: Seq[String], rowBlockSize: Int, - passThrough: Seq[String]) extends MatrixToTableFunction { + passThrough: Seq[String], +) extends MatrixToTableFunction { override def typ(childType: MatrixType): TableType = { val passThroughType = 
TStruct(passThrough.map(f => f -> childType.rowType.field(f).typ): _*) @@ -190,28 +210,36 @@ case class LinearRegressionRowsChained( ("beta", TArray(TArray(TFloat64))), ("standard_error", TArray(TArray(TFloat64))), ("t_stat", TArray(TArray(TFloat64))), - ("p_value", TArray(TArray(TFloat64)))) + ("p_value", TArray(TArray(TFloat64))), + ) TableType( childType.rowKeyStruct ++ passThroughType ++ chainedSchema, childType.rowKey, - TStruct.empty) + TStruct.empty, + ) } def preservesPartitionCounts: Boolean = true def execute(ctx: ExecuteContext, mv: MatrixValue): TableValue = { - val localData = yFields.map(y => RegressionUtils.getPhenosCovCompleteSamples(mv, y.toArray, covFields.toArray)) + val localData = yFields.map(y => + RegressionUtils.getPhenosCovCompleteSamples(mv, y.toArray, covFields.toArray) + ) val k = covFields.length // nCovariates val bcData = localData.zipWithIndex.map { case ((y, cov, completeColIdx), i) => val n = y.rows val d = n - k - 1 if (d < 1) - fatal(s"$n samples and ${ k + 1 } ${ plural(k, "covariate") } (including x) implies $d degrees of freedom.") + fatal( + s"$n samples and ${k + 1} ${plural(k, "covariate")} (including x) implies $d degrees of freedom." + ) - info(s"linear_regression_rows[$i]: running on $n samples for ${ y.cols } response ${ plural(y.cols, "variable") } y,\n" - + s" with input variable x, and ${ k } additional ${ plural(k, "covariate") }...") + info( + s"linear_regression_rows[$i]: running on $n samples for ${y.cols} response ${plural(y.cols, "variable")} y,\n" + + s" with input variable x, and $k additional ${plural(k, "covariate")}..." + ) val Qt = if (k > 0) @@ -242,124 +270,136 @@ case class LinearRegressionRowsChained( val sm = ctx.stateManager val newRVD = mv.rvd.mapPartitionsWithContext( - rvdType) { (consumerCtx, it) => - val producerCtx = consumerCtx.freshContext - val rvb = new RegionValueBuilder(sm) - - val inputData = bc.value - val builder = new IntArrayBuilder() - val data = inputData.map(cri => new Array[Double](cri.n * rowBlockSize)) - - val blockWRVs = new Array[WritableRegionValue](rowBlockSize) - var i = 0 - while (i < rowBlockSize) { - blockWRVs(i) = WritableRegionValue(sm, fullRowType, producerCtx.freshRegion()) - i += 1 - } + rvdType + ) { (consumerCtx, it) => + val producerCtx = consumerCtx.freshContext + val rvb = new RegionValueBuilder(sm) + + val inputData = bc.value + val builder = new IntArrayBuilder() + val data = inputData.map(cri => new Array[Double](cri.n * rowBlockSize)) + + val blockWRVs = new Array[WritableRegionValue](rowBlockSize) + var i = 0 + while (i < rowBlockSize) { + blockWRVs(i) = WritableRegionValue(sm, fullRowType, producerCtx.freshRegion()) + i += 1 + } - it(producerCtx).trueGroupedIterator(rowBlockSize) - .flatMap { git => - var i = 0 - while (git.hasNext) { - val ptr = git.next() - var j = 0 - while (j < nGroups) { - RegressionUtils.setMeanImputedDoubles(data(j), i * inputData(j).n, inputData(j).completeColIndex, builder, - ptr, fullRowType, entryArrayType, entryType, entryArrayIdx, fieldIdx) - j += 1 - } - blockWRVs(i).set(ptr, true) - producerCtx.region.clear() - i += 1 + it(producerCtx).trueGroupedIterator(rowBlockSize) + .flatMap { git => + var i = 0 + while (git.hasNext) { + val ptr = git.next() + var j = 0 + while (j < nGroups) { + RegressionUtils.setMeanImputedDoubles( + data(j), + i * inputData(j).n, + inputData(j).completeColIndex, + builder, + ptr, + fullRowType, + entryArrayType, + entryType, + entryArrayIdx, + fieldIdx, + ) + j += 1 } - val blockLength = i - - val results = 
Array.tabulate(nGroups) { j => - val cri = inputData(j) - val X = new DenseMatrix[Double](cri.n, blockLength, data(j)) - - val AC: DenseVector[Double] = X.t(*, ::).map(r => sum(r)) - assert(AC.length == blockLength) - - val qtx: DenseMatrix[Double] = cri.Qt * X - val qty: DenseMatrix[Double] = cri.Qty - val xxpRec: DenseVector[Double] = 1.0 / (X.t(*, ::).map(r => r dot r) - qtx.t(*, ::).map(r => r dot r)) - val ytx: DenseMatrix[Double] = cri.y.t * X - assert(ytx.rows == cri.y.cols && ytx.cols == blockLength) - - val xyp: DenseMatrix[Double] = ytx - (qty.t * qtx) - val yyp: DenseVector[Double] = cri.yyp - // resuse xyp - val b = xyp - i = 0 - while (i < blockLength) { - xyp(::, i) :*= xxpRec(i) - i += 1 - } - val se = sqrt((1d / cri.d) * (yyp * xxpRec.t - (b *:* b))) + blockWRVs(i).set(ptr, true) + producerCtx.region.clear() + i += 1 + } + val blockLength = i + + val results = Array.tabulate(nGroups) { j => + val cri = inputData(j) + val X = new DenseMatrix[Double](cri.n, blockLength, data(j)) + + val AC: DenseVector[Double] = X.t(*, ::).map(r => sum(r)) + assert(AC.length == blockLength) - val t = b /:/ se - val p = t.map(s => 2 * T.cumulative(-math.abs(s), cri.d, true, false)) + val qtx: DenseMatrix[Double] = cri.Qt * X + val qty: DenseMatrix[Double] = cri.Qty + val xxpRec: DenseVector[Double] = + 1.0 / (X.t(*, ::).map(r => r dot r) - qtx.t(*, ::).map(r => r dot r)) + val ytx: DenseMatrix[Double] = cri.y.t * X + assert(ytx.rows == cri.y.cols && ytx.cols == blockLength) - ChainedLinregResult(cri.n, AC, ytx, b, se, t, p) + val xyp: DenseMatrix[Double] = ytx - (qty.t * qtx) + val yyp: DenseVector[Double] = cri.yyp + // resuse xyp + val b = xyp + i = 0 + while (i < blockLength) { + xyp(::, i) :*= xxpRec(i) + i += 1 } + val se = sqrt((1d / cri.d) * (yyp * xxpRec.t - (b *:* b))) - (0 until blockLength).iterator.map { i => - val wrv = blockWRVs(i) - rvb.set(wrv.region) - rvb.start(rvdType.rowType) - rvb.startStruct() - rvb.addFields(fullRowType, wrv.region, wrv.offset, copiedFieldIndices) + val t = b /:/ se + val p = t.map(s => 2 * T.cumulative(-math.abs(s), cri.d, true, false)) - // FIXME: the below has horrible cache behavior, but hard to get around - // FIXME: it when doing a two-way in-memory transpose like this + ChainedLinregResult(cri.n, AC, ytx, b, se, t, p) + } - rvb.startArray(nGroups) - results.foreach(r => rvb.addInt(r.n)) - rvb.endArray() + (0 until blockLength).iterator.map { i => + val wrv = blockWRVs(i) + rvb.set(wrv.region) + rvb.start(rvdType.rowType) + rvb.startStruct() + rvb.addFields(fullRowType, wrv.region, wrv.offset, copiedFieldIndices) - rvb.startArray(nGroups) - results.foreach(r => rvb.addDouble(r.AC(i))) - rvb.endArray() + // FIXME: the below has horrible cache behavior, but hard to get around + // FIXME: it when doing a two-way in-memory transpose like this - def addSlice(dm: DenseMatrix[Double]) { - val size = dm.rows - rvb.startArray(size) - var j = 0 - while (j < size) { - rvb.addDouble(dm(j, i)) - j += 1 - } - rvb.endArray() - } + rvb.startArray(nGroups) + results.foreach(r => rvb.addInt(r.n)) + rvb.endArray() - rvb.startArray(nGroups) - results.foreach(r => addSlice(r.ytx)) - rvb.endArray() + rvb.startArray(nGroups) + results.foreach(r => rvb.addDouble(r.AC(i))) + rvb.endArray() - rvb.startArray(nGroups) - results.foreach(r => addSlice(r.b)) + def addSlice(dm: DenseMatrix[Double]): Unit = { + val size = dm.rows + rvb.startArray(size) + var j = 0 + while (j < size) { + rvb.addDouble(dm(j, i)) + j += 1 + } rvb.endArray() + } - rvb.startArray(nGroups) - 
results.foreach(r => addSlice(r.se)) - rvb.endArray() + rvb.startArray(nGroups) + results.foreach(r => addSlice(r.ytx)) + rvb.endArray() - rvb.startArray(nGroups) - results.foreach(r => addSlice(r.t)) - rvb.endArray() + rvb.startArray(nGroups) + results.foreach(r => addSlice(r.b)) + rvb.endArray() - rvb.startArray(nGroups) - results.foreach(r => addSlice(r.p)) - rvb.endArray() + rvb.startArray(nGroups) + results.foreach(r => addSlice(r.se)) + rvb.endArray() - rvb.endStruct() + rvb.startArray(nGroups) + results.foreach(r => addSlice(r.t)) + rvb.endArray() - producerCtx.region.addReferenceTo(wrv.region) - rvb.end() - } + rvb.startArray(nGroups) + results.foreach(r => addSlice(r.p)) + rvb.endArray() + + rvb.endStruct() + + producerCtx.region.addReferenceTo(wrv.region) + rvb.end() } - } + } + } TableValue(ctx, tableType, BroadcastRow.empty(ctx), newRVD) } } @@ -371,7 +411,8 @@ case class ChainedLinregInput( Qt: DenseMatrix[Double], Qty: DenseMatrix[Double], yyp: DenseVector[Double], - d: Int) + d: Int, +) case class ChainedLinregResult( n: Int, @@ -380,5 +421,5 @@ case class ChainedLinregResult( b: DenseMatrix[Double], se: DenseMatrix[Double], t: DenseMatrix[Double], - p: DenseMatrix[Double] + p: DenseMatrix[Double], ) diff --git a/hail/src/main/scala/is/hail/methods/LocalLDPrune.scala b/hail/src/main/scala/is/hail/methods/LocalLDPrune.scala index a8ad1e29ac3..6e22f021523 100644 --- a/hail/src/main/scala/is/hail/methods/LocalLDPrune.scala +++ b/hail/src/main/scala/is/hail/methods/LocalLDPrune.scala @@ -1,21 +1,15 @@ package is.hail.methods -import java.util -import is.hail.annotations._ import is.hail.backend.ExecuteContext -import is.hail.expr.ir.functions.MatrixToTableFunction import is.hail.expr.ir._ -import is.hail.sparkextras.ContextRDD +import is.hail.expr.ir.functions.MatrixToTableFunction +import is.hail.methods.BitPackedVector._ import is.hail.types._ -import is.hail.types.physical._ import is.hail.types.virtual._ -import is.hail.rvd.RVD import is.hail.utils._ import is.hail.variant._ -import org.apache.spark.rdd.RDD - -import BitPackedVector._ +import java.util object BitPackedVector { final val GENOTYPES_PER_PACK: Int = 32 @@ -110,7 +104,14 @@ class BitPackedVectorBuilder(nSamples: Int) { } } -case class BitPackedVector(locus: Locus, alleles: IndexedSeq[String], gs: Array[Long], nSamples: Int, mean: Double, centeredLengthRec: Double) { +case class BitPackedVector( + locus: Locus, + alleles: IndexedSeq[String], + gs: Array[Long], + nSamples: Int, + mean: Double, + centeredLengthRec: Double, +) { def nPacks: Int = gs.length def getPack(idx: Int): Long = gs(idx) @@ -162,8 +163,12 @@ object LocalLDPrune { table } - private def doubleSampleLookup(sample1VariantX: Int, sample1VariantY: Int, sample2VariantX: Int, - sample2VariantY: Int): (Int, Int, Int, Int) = { + private def doubleSampleLookup( + sample1VariantX: Int, + sample1VariantY: Int, + sample2VariantX: Int, + sample2VariantY: Int, + ): (Int, Int, Int, Int) = { val r1 = singleSampleLookup(sample1VariantX, sample1VariantY) val r2 = singleSampleLookup(sample2VariantX, sample2VariantY) (r1._1 + r2._1, r1._2 + r2._2, r1._3 + r2._3, r1._4 + r2._4) @@ -235,14 +240,22 @@ object LocalLDPrune { r2 } - def pruneLocal(queue: util.ArrayDeque[BitPackedVector], bpv: BitPackedVector, r2Threshold: Double, windowSize: Int, queueSize: Int): Boolean = { + def pruneLocal( + queue: util.ArrayDeque[BitPackedVector], + bpv: BitPackedVector, + r2Threshold: Double, + windowSize: Int, + queueSize: Int, + ): Boolean = { var keepVariant = true var done = 
false val qit = queue.descendingIterator() while (!done && qit.hasNext) { val bpvPrev = qit.next() - if (bpv.locus.contig != bpvPrev.locus.contig || bpv.locus.position - bpvPrev.locus.position > windowSize) { + if ( + bpv.locus.contig != bpvPrev.locus.contig || bpv.locus.position - bpvPrev.locus.position > windowSize + ) { done = true } else { val r2 = computeR2(bpv, bpvPrev) @@ -263,18 +276,13 @@ object LocalLDPrune { keepVariant } - private def pruneLocal(inputRDD: RDD[BitPackedVector], r2Threshold: Double, windowSize: Int, queueSize: Int): RDD[BitPackedVector] = { - inputRDD.mapPartitions({ it => - val queue = new util.ArrayDeque[BitPackedVector](queueSize) - it.filter { bpvv => - pruneLocal(queue, bpvv, r2Threshold, windowSize, queueSize) - } - }, preservesPartitioning = true) - } - - def apply(ctx: ExecuteContext, + def apply( + ctx: ExecuteContext, mt: MatrixValue, - callField: String = "GT", r2Threshold: Double = 0.2, windowSize: Int = 1000000, maxQueueSize: Int + callField: String = "GT", + r2Threshold: Double = 0.2, + windowSize: Int = 1000000, + maxQueueSize: Int, ): TableValue = { val pruner = LocalLDPrune(callField, r2Threshold, windowSize, maxQueueSize) pruner.execute(ctx, mt) @@ -282,15 +290,20 @@ object LocalLDPrune { } case class LocalLDPrune( - callField: String, r2Threshold: Double, windowSize: Int, maxQueueSize: Int + callField: String, + r2Threshold: Double, + windowSize: Int, + maxQueueSize: Int, ) extends MatrixToTableFunction { require(maxQueueSize > 0, s"Maximum queue size must be positive. Found '$maxQueueSize'.") - override def typ(childType: MatrixType): TableType = { + override def typ(childType: MatrixType): TableType = TableType( - rowType = childType.rowKeyStruct ++ TStruct("mean" -> TFloat64, "centered_length_rec" -> TFloat64), - key = childType.rowKey, globalType = TStruct.empty) - } + rowType = + childType.rowKeyStruct ++ TStruct("mean" -> TFloat64, "centered_length_rec" -> TFloat64), + key = childType.rowKey, + globalType = TStruct.empty, + ) def preservesPartitionCounts: Boolean = false @@ -307,12 +320,12 @@ case class LocalLDPrune( def execute(ctx: ExecuteContext, mv: MatrixValue): TableValue = { val nSamples = mv.nCols - val fullRowPType = mv.rvRowPType - val localCallField = callField val tableType = typ(mv.typ) - val ts = TableExecuteIntermediate(mv.toTableValue).asTableStage(ctx).mapPartition(Some(tableType.key)) { rows => - makeStream(rows, MatrixType.entriesIdentifier, nSamples) - }.mapGlobals(_ => makestruct()) + val ts = TableExecuteIntermediate(mv.toTableValue).asTableStage(ctx).mapPartition(Some( + tableType.key + ))(rows => makeStream(rows, MatrixType.entriesIdentifier, nSamples)).mapGlobals(_ => + makestruct() + ) TableExecuteIntermediate(ts).asTableValue(ctx) } } diff --git a/hail/src/main/scala/is/hail/methods/LocalWhitening.scala b/hail/src/main/scala/is/hail/methods/LocalWhitening.scala index f24ab26dc66..770748491e5 100644 --- a/hail/src/main/scala/is/hail/methods/LocalWhitening.scala +++ b/hail/src/main/scala/is/hail/methods/LocalWhitening.scala @@ -3,16 +3,24 @@ package is.hail.methods import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.interfaces.SNDArray.assertColMajor -import is.hail.types.physical.stypes.interfaces.{ColonIndex => Colon, _} import is.hail.types.physical.{PCanonicalNDArray, PFloat64Required} +import is.hail.types.physical.stypes.interfaces.{ColonIndex => Colon, _} +import is.hail.types.physical.stypes.interfaces.SNDArray.assertColMajor 
import is.hail.utils.FastSeq -class LocalWhitening(cb: EmitCodeBuilder, vecSize: SizeValue, _w: Value[Long], chunksize: Value[Long], _blocksize: Value[Long], region: Value[Region], normalizeAfterWhitening: Boolean) { +class LocalWhitening( + cb: EmitCodeBuilder, + vecSize: SizeValue, + _w: Value[Long], + chunksize: Value[Long], + _blocksize: Value[Long], + region: Value[Region], + normalizeAfterWhitening: Boolean, +) { val m = vecSize val w = SizeValueDyn(cb.memoizeField(_w)) val b = SizeValueDyn(cb.memoizeField(chunksize)) - val wpb = SizeValueDyn(cb.memoizeField(w+b)) + val wpb = SizeValueDyn(cb.memoizeField(w + b)) val curSize = cb.newField[Long]("curSize", 0) val pivot = cb.newField[Long]("pivot", 0) @@ -22,18 +30,53 @@ class LocalWhitening(cb: EmitCodeBuilder, vecSize: SizeValue, _w: Value[Long], c val (tsize, worksize) = SNDArray.geqr_query(cb, m, chunksize, region) - val Q = cb.memoizeField(matType.constructUninitialized(FastSeq(m, w), cb, region), "LW_Q").asNDArray.coerceToShape(cb, m, w) - val R = cb.memoizeField(matType.constructUninitialized(FastSeq(w, w), cb, region), "LW_R").asNDArray.coerceToShape(cb, w, w) - val work1 = cb.memoizeField(matType.constructUninitialized(FastSeq(wpb, wpb), cb, region), "LW_work1").asNDArray.coerceToShape(cb, wpb, wpb) - val work2 = cb.memoizeField(matType.constructUninitialized(FastSeq(wpb, b), cb, region), "LW_work2").asNDArray.coerceToShape(cb, wpb, b) - val Rtemp = cb.memoizeField(matType.constructUninitialized(FastSeq(wpb, wpb), cb, region), "LW_Rtemp").asNDArray.coerceToShape(cb, wpb, wpb) - val Qtemp = cb.memoizeField(matType.constructUninitialized(FastSeq(m, b), cb, region), "LW_Qtemp").asNDArray.coerceToShape(cb, m, b) - val Qtemp2 = cb.memoizeField(matType.constructUninitialized(FastSeq(m, w), cb, region), "LW_Qtemp2").asNDArray.coerceToShape(cb, m, w) + val Q = cb.memoizeField( + matType.constructUninitialized(FastSeq(m, w), cb, region), + "LW_Q", + ).asNDArray.coerceToShape(cb, m, w) + + val R = cb.memoizeField( + matType.constructUninitialized(FastSeq(w, w), cb, region), + "LW_R", + ).asNDArray.coerceToShape(cb, w, w) + + val work1 = cb.memoizeField( + matType.constructUninitialized(FastSeq(wpb, wpb), cb, region), + "LW_work1", + ).asNDArray.coerceToShape(cb, wpb, wpb) + + val work2 = cb.memoizeField( + matType.constructUninitialized(FastSeq(wpb, b), cb, region), + "LW_work2", + ).asNDArray.coerceToShape(cb, wpb, b) + + val Rtemp = cb.memoizeField( + matType.constructUninitialized(FastSeq(wpb, wpb), cb, region), + "LW_Rtemp", + ).asNDArray.coerceToShape(cb, wpb, wpb) + + val Qtemp = cb.memoizeField( + matType.constructUninitialized(FastSeq(m, b), cb, region), + "LW_Qtemp", + ).asNDArray.coerceToShape(cb, m, b) + + val Qtemp2 = cb.memoizeField( + matType.constructUninitialized(FastSeq(m, w), cb, region), + "LW_Qtemp2", + ).asNDArray.coerceToShape(cb, m, w) + val blocksize = cb.memoizeField(_blocksize.min(w)) val work3len = SizeValueDyn(cb.memoize(worksize.max(blocksize * m.max(wpb)))) - val work3: SNDArrayValue = cb.memoizeField(vecType.constructUninitialized(FastSeq(work3len), cb, region), "LW_work3").asNDArray - val Tlen = SizeValueDyn(cb.memoizeField(tsize.max(blocksize*wpb))) - val T: SNDArrayValue = cb.memoizeField(vecType.constructUninitialized(FastSeq(Tlen), cb, region), "LW_T").asNDArray + + val work3: SNDArrayValue = cb.memoizeField( + vecType.constructUninitialized(FastSeq(work3len), cb, region), + "LW_work3", + ).asNDArray + + val Tlen = SizeValueDyn(cb.memoizeField(tsize.max(blocksize * wpb))) + + val T: SNDArrayValue = + 
cb.memoizeField(vecType.constructUninitialized(FastSeq(Tlen), cb, region), "LW_T").asNDArray def reset(cb: EmitCodeBuilder): Unit = { cb.assign(curSize, 0L) @@ -42,11 +85,16 @@ class LocalWhitening(cb: EmitCodeBuilder, vecSize: SizeValue, _w: Value[Long], c // Pre: A1 is current window, A2 is next window, [Q1 Q2] R = [A1 A2] is qr fact // Post: W contains locally whitened A2, Qout R[-w:, -w:] = A2[:, -w:] is qr fact - def whitenBlockPreOrthogonalized(cb: EmitCodeBuilder, - Q1: SNDArrayValue, Q2: SNDArrayValue, Qout: SNDArrayValue, - R: SNDArrayValue, W: SNDArrayValue, - work1: SNDArrayValue, work2: SNDArrayValue, - blocksize: Value[Long] + def whitenBlockPreOrthogonalized( + cb: EmitCodeBuilder, + Q1: SNDArrayValue, + Q2: SNDArrayValue, + Qout: SNDArrayValue, + R: SNDArrayValue, + W: SNDArrayValue, + work1: SNDArrayValue, + work2: SNDArrayValue, + blocksize: Value[Long], ): Unit = { SNDArray.assertMatrix(Q1, Q2, Qout, R, work1, work2) SNDArray.assertColMajor(cb, "whitenBlockPreOrthogonalized", Q1, Q2, Qout, R, work1, work2) @@ -68,26 +116,52 @@ class LocalWhitening(cb: EmitCodeBuilder, vecSize: SizeValue, _w: Value[Long], c // set work1 to I work1.setToZero(cb) - cb.for_(cb.assign(i, 0L), i.toL < w + n, cb.assign(i, i+1), work1.setElement(FastSeq(i, i), primitive(1.0), cb)) - - cb.for_(cb.assign(i, 0L), i < n, cb.assign(i, i+1), { - // Loop invariant: - // * ([Q1 Q2] work1[:, i:w+n]) R[i:w+n, i:w+n] = [A1 A2][i:w+n] is qr fact - // * ([Q1 Q2] work2[:, 0:i]) is locally whitened A2[:, 0:i] - - // work2[:, i] = work1[:, w+i] * R[w+i, w+i] - val wpi = cb.newLocal[Long]("w_plus_i", w+i) - val w1col = work1.slice(cb, Colon, wpi) - val w2col = work2.slice(cb, Colon, i) - SNDArray.copyVector(cb, w1col, w2col) - if (!normalizeAfterWhitening) { - SNDArray.scale(cb, R.loadElement(FastSeq(wpi, wpi), cb), w2col) - } - - // work3 > blocksize * (w+n - i+1) < blocksize * (w+n) - SNDArray.tpqrt(R.slice(cb, (i+1, null), (i+1, null)), R.slice(cb, (i, i+1), (i+1, null)), T, work3, blocksize, cb) - SNDArray.tpmqrt("R", "N", R.slice(cb, (i, i+1), (i+1, null)), T, work1.slice(cb, Colon, (i+1, null)), work1.slice(cb, Colon, (i, i+1)), work3, blocksize, cb) - }) + cb.for_( + cb.assign(i, 0L), + i.toL < w + n, + cb.assign(i, i + 1), + work1.setElement(FastSeq(i, i), primitive(1.0), cb), + ) + + cb.for_( + cb.assign(i, 0L), + i < n, + cb.assign(i, i + 1), { + // Loop invariant: + // * ([Q1 Q2] work1[:, i:w+n]) R[i:w+n, i:w+n] = [A1 A2][i:w+n] is qr fact + // * ([Q1 Q2] work2[:, 0:i]) is locally whitened A2[:, 0:i] + + // work2[:, i] = work1[:, w+i] * R[w+i, w+i] + val wpi = cb.newLocal[Long]("w_plus_i", w + i) + val w1col = work1.slice(cb, Colon, wpi) + val w2col = work2.slice(cb, Colon, i) + SNDArray.copyVector(cb, w1col, w2col) + if (!normalizeAfterWhitening) { + SNDArray.scale(cb, R.loadElement(FastSeq(wpi, wpi), cb), w2col) + } + + // work3 > blocksize * (w+n - i+1) < blocksize * (w+n) + SNDArray.tpqrt( + R.slice(cb, (i + 1, null), (i + 1, null)), + R.slice(cb, (i, i + 1), (i + 1, null)), + T, + work3, + blocksize, + cb, + ) + SNDArray.tpmqrt( + "R", + "N", + R.slice(cb, (i, i + 1), (i + 1, null)), + T, + work1.slice(cb, Colon, (i + 1, null)), + work1.slice(cb, Colon, (i, i + 1)), + work3, + blocksize, + cb, + ) + }, + ) // W = [Q1 Q2] work2 is locally whitened A2 SNDArray.gemm(cb, "N", "N", 1.0, Q1, work2.slice(cb, (null, w), Colon), 0.0, W) @@ -103,9 +177,12 @@ class LocalWhitening(cb: EmitCodeBuilder, vecSize: SizeValue, _w: Value[Long], c // Pre: Let Q1 = Q[:, 0:p0], Q2 = Q[:, p0:n], R11 = R[0:p0, 0:p0], R12 
= R[0:p0, p0:n], etc. // * [Q2 Q1] [R22 R21; 0 R11] = [A2 A1] is a qr fact // Post: Same, with p1 substituted for p0 - def qrPivot(cb: EmitCodeBuilder, - Q: SNDArrayValue, R: SNDArrayValue, - p0: Value[Long], p1: Value[Long] + def qrPivot( + cb: EmitCodeBuilder, + Q: SNDArrayValue, + R: SNDArrayValue, + p0: Value[Long], + p1: Value[Long], ): Unit = { val Seq(m, w) = Q.shapes val Seq(t) = T.shapes @@ -113,7 +190,7 @@ class LocalWhitening(cb: EmitCodeBuilder, vecSize: SizeValue, _w: Value[Long], c cb.if_(R.shapes(1).cne(w), cb._fatal("qr_pivot: R ncols != w")) cb.if_(m <= w, cb._fatal("qr_pivot: m <= w, m=", m.toS, ", w=", w.toS)) cb.if_(p0 < 0 || p0 >= p1 || p1 > w, cb._fatal("qr_pivot: bad p0, p1")) - cb.if_(t < blocksize * p0.max((p1-p0).max(w-p1)), cb._fatal("qr_pivot: T too small")) + cb.if_(t < blocksize * p0.max((p1 - p0).max(w - p1)), cb._fatal("qr_pivot: T too small")) val r0 = (null, p0) val r1 = (p0, p1) @@ -121,27 +198,74 @@ class LocalWhitening(cb: EmitCodeBuilder, vecSize: SizeValue, _w: Value[Long], c val r01 = (null, p1) val b0 = cb.memoize(blocksize.min(p0)) - val b1 = cb.memoize(blocksize.min(p1-p0)) - val b2 = cb.memoize(blocksize.min(w-p1)) + val b1 = cb.memoize(blocksize.min(p1 - p0)) + val b2 = cb.memoize(blocksize.min(w - p1)) // Set lower trapezoid of R[r12, r1] to zero val j = cb.mb.newLocal[Long]("j") - cb.for_(cb.assign(j, p0), j < p1, cb.assign(j, j+1), { - R.slice(cb, (j+1, null), j).setToZero(cb) - }) + cb.for_( + cb.assign(j, p0), + j < p1, + cb.assign(j, j + 1), + R.slice(cb, (j + 1, null), j).setToZero(cb), + ) R.slice(cb, r0, r1).setToZero(cb) - cb.if_(p1 < w, { - SNDArray.tpqrt(R.slice(cb, r2, r2), R.slice(cb, r1, r2), T, work3, b2, cb) - SNDArray.tpmqrt("L", "T", R.slice(cb, r1, r2), T, R.slice(cb, r2, r01), R.slice(cb, r1, r01), work3, b2, cb) - SNDArray.tpmqrt("R", "N", R.slice(cb, r1, r2), T, Q.slice(cb, Colon, r2), Q.slice(cb, Colon, r1), work3, b2, cb) - }) - cb.if_(p0 > 0, { - SNDArray.tpqrt(R.slice(cb, r0, r0), R.slice(cb, r1, r0), T, work3, b0, cb) - SNDArray.tpmqrt("L", "T", R.slice(cb, r1, r0), T, R.slice(cb, r0, r1), R.slice(cb, r1, r1), work3, b0, cb) - SNDArray.tpmqrt("R", "N", R.slice(cb, r1, r0), T, Q.slice(cb, Colon, r0), Q.slice(cb, Colon, r1), work3, b0, cb) - }) + cb.if_( + p1 < w, { + SNDArray.tpqrt(R.slice(cb, r2, r2), R.slice(cb, r1, r2), T, work3, b2, cb) + SNDArray.tpmqrt( + "L", + "T", + R.slice(cb, r1, r2), + T, + R.slice(cb, r2, r01), + R.slice(cb, r1, r01), + work3, + b2, + cb, + ) + SNDArray.tpmqrt( + "R", + "N", + R.slice(cb, r1, r2), + T, + Q.slice(cb, Colon, r2), + Q.slice(cb, Colon, r1), + work3, + b2, + cb, + ) + }, + ) + cb.if_( + p0 > 0, { + SNDArray.tpqrt(R.slice(cb, r0, r0), R.slice(cb, r1, r0), T, work3, b0, cb) + SNDArray.tpmqrt( + "L", + "T", + R.slice(cb, r1, r0), + T, + R.slice(cb, r0, r1), + R.slice(cb, r1, r1), + work3, + b0, + cb, + ) + SNDArray.tpmqrt( + "R", + "N", + R.slice(cb, r1, r0), + T, + Q.slice(cb, Colon, r0), + Q.slice(cb, Colon, r1), + work3, + b0, + cb, + ) + }, + ) SNDArray.geqrt(R.slice(cb, r1, r1), T, work3, b1, cb) SNDArray.gemqrt("R", "N", R.slice(cb, r1, r1), T, Q.slice(cb, Colon, r1), work3, b1, cb) } @@ -152,11 +276,17 @@ class LocalWhitening(cb: EmitCodeBuilder, vecSize: SizeValue, _w: Value[Long], c // // Pre: Q R = A0 is qr fact of current window, A contains next window // Post: A contains A_orig whitened, Q R = A_orig - def whitenBlockSmallWindow(cb: EmitCodeBuilder, - Q: SNDArrayValue, R: SNDArrayValue, A: SNDArrayValue, - Qtemp: SNDArrayValue, Qtemp2: SNDArrayValue, Rtemp: 
SNDArrayValue, - work1: SNDArrayValue, work2: SNDArrayValue, - blocksize: Value[Long] + def whitenBlockSmallWindow( + cb: EmitCodeBuilder, + Q: SNDArrayValue, + R: SNDArrayValue, + A: SNDArrayValue, + Qtemp: SNDArrayValue, + Qtemp2: SNDArrayValue, + Rtemp: SNDArrayValue, + work1: SNDArrayValue, + work2: SNDArrayValue, + blocksize: Value[Long], ): Unit = { val Seq(m, w) = Q.shapes val n = A.shapes(1) @@ -195,7 +325,7 @@ class LocalWhitening(cb: EmitCodeBuilder, vecSize: SizeValue, _w: Value[Long], c whitenBlockPreOrthogonalized(cb, Q, Qtemp, Qtemp2, Rtemp, A, work1, work2, blocksize) // copy upper triangle of Rtemp[n:w+n, n:w+n] to R - SNDArray.copyMatrix(cb, "U", Rtemp.slice(cb, (n, w+n), (n, w+n)), R) + SNDArray.copyMatrix(cb, "U", Rtemp.slice(cb, (n, w + n), (n, w + n)), R) // copy Qtemp2 to Q SNDArray.copyMatrix(cb, " ", Qtemp2, Q) // now Q R = A_orig[::, n-w:n] @@ -211,15 +341,23 @@ class LocalWhitening(cb: EmitCodeBuilder, vecSize: SizeValue, _w: Value[Long], c // Post: // * [Q3 Q1 Q2] [R33 R31 R32; 0 R11 R12; 0 0 R22] = [A3 A1 A_orig] is a qr fact // * A contains whitened A_orig - def whitenBlockLargeWindow(cb: EmitCodeBuilder, - Q: SNDArrayValue, R: SNDArrayValue, p: Value[Long], A: SNDArrayValue, Qtemp: SNDArrayValue, - Qtemp2: SNDArrayValue, Rtemp: SNDArrayValue, work1: SNDArrayValue, work2: SNDArrayValue, - blocksize: Value[Long] + def whitenBlockLargeWindow( + cb: EmitCodeBuilder, + Q: SNDArrayValue, + R: SNDArrayValue, + p: Value[Long], + A: SNDArrayValue, + Qtemp: SNDArrayValue, + Qtemp2: SNDArrayValue, + Rtemp: SNDArrayValue, + work1: SNDArrayValue, + work2: SNDArrayValue, + blocksize: Value[Long], ): Unit = { val b = A.shapes(1) val bb = Rtemp.shapes(0) - cb.if_((b*2).cne(bb), cb._fatal("whitenStep: invalid dimensions")) + cb.if_((b * 2).cne(bb), cb._fatal("whitenStep: invalid dimensions")) assert(Q.hasShapeStatic(m, w)) assert(R.hasShapeStatic(w, w)) @@ -230,7 +368,7 @@ class LocalWhitening(cb: EmitCodeBuilder, vecSize: SizeValue, _w: Value[Long], c assert(work1.hasShapeStatic(bb, bb)) assert(work2.hasShapeStatic(bb, b)) - val ppb = cb.memoize(p+b) + val ppb = cb.memoize(p + b) qrPivot(cb, Q, R, p, ppb) // now [Q3 Q1 Q2] [R33 R31 R32; 0 R11 R12; 0 0 R22] = [A3 A1 A2] @@ -264,50 +402,69 @@ class LocalWhitening(cb: EmitCodeBuilder, vecSize: SizeValue, _w: Value[Long], c val b = _A.shapes(1) val A = _A.coerceToShape(cb, m, b) - cb.if_(b > chunksize, cb._fatal("whitenBlock: A too large, found ", b.toS, ", expected ", chunksize.toS)) - - cb.if_(curSize < w, { - // Orthogonalize against existing Q - val Rslice = R.slice(cb, (null, curSize), (curSize, curSize + b)) - val Qslice = Q.slice(cb, Colon, (null, curSize)) - // Rslice = Q' A - SNDArray.gemm(cb, "T", "N", Qslice, A, Rslice) - // A = A - Q Rslice - SNDArray.gemm(cb, "N", "N", -1.0, Qslice, Rslice, 1.0, A) - - // Compute QR fact of A; store R fact in Rtemp[r1, r1], Q fact in Qtemp - val Rslice2 = R.slice(cb, (curSize, curSize + b), (curSize, curSize + b)) - val Qslice2 = Q.slice(cb, Colon, (curSize, curSize + b)) - SNDArray.geqr_full(cb, A, Qslice2, Rslice2, T, work3) - - // Copy whitened A back to A - val j = cb.newLocal[Long]("j") - cb.for_(cb.assign(j, 0L), j < b, cb.assign(j, j+1), { - val Acol = A.slice(cb, Colon, j) - SNDArray.copyVector(cb, Qslice2.slice(cb, Colon, j), Acol) - SNDArray.scale(cb, Rslice2.loadElement(FastSeq(j, j), cb), Acol) - }) - - cb.assign(curSize, curSize + b) - }, { - cb.if_(curSize.cne(w), cb._fatal("whitenBlock: initial blocks didn't evenly divide window size")) - - val bb = 
SizeValueDyn(cb.memoize(b*2)) - whitenBlockLargeWindow(cb, - Q, R, pivot, A, - Qtemp.slice(cb, Colon, (null, b)), - Qtemp2.slice(cb, Colon, (null, b)), - Rtemp.slice(cb, (null, bb), (null, bb)), - work1.slice(cb, (null, bb), (null, bb)), - work2.slice(cb, (null, bb), (null, b)), - cb.memoize(blocksize.min(b))) - - cb.assign(pivot, pivot + b) - cb.if_(pivot >= w, { - cb.if_(pivot.cne(w), cb._fatal("whitenBlock, blocks didn't evenly divide window size")) - cb.assign(pivot, 0L) - }) - }) + cb.if_( + b > chunksize, + cb._fatal("whitenBlock: A too large, found ", b.toS, ", expected ", chunksize.toS), + ) + + cb.if_( + curSize < w, { + // Orthogonalize against existing Q + val Rslice = R.slice(cb, (null, curSize), (curSize, curSize + b)) + val Qslice = Q.slice(cb, Colon, (null, curSize)) + // Rslice = Q' A + SNDArray.gemm(cb, "T", "N", Qslice, A, Rslice) + // A = A - Q Rslice + SNDArray.gemm(cb, "N", "N", -1.0, Qslice, Rslice, 1.0, A) + + // Compute QR fact of A; store R fact in Rtemp[r1, r1], Q fact in Qtemp + val Rslice2 = R.slice(cb, (curSize, curSize + b), (curSize, curSize + b)) + val Qslice2 = Q.slice(cb, Colon, (curSize, curSize + b)) + SNDArray.geqr_full(cb, A, Qslice2, Rslice2, T, work3) + + // Copy whitened A back to A + val j = cb.newLocal[Long]("j") + cb.for_( + cb.assign(j, 0L), + j < b, + cb.assign(j, j + 1), { + val Acol = A.slice(cb, Colon, j) + SNDArray.copyVector(cb, Qslice2.slice(cb, Colon, j), Acol) + SNDArray.scale(cb, Rslice2.loadElement(FastSeq(j, j), cb), Acol) + }, + ) + + cb.assign(curSize, curSize + b) + }, { + cb.if_( + curSize.cne(w), + cb._fatal("whitenBlock: initial blocks didn't evenly divide window size"), + ) + + val bb = SizeValueDyn(cb.memoize(b * 2)) + whitenBlockLargeWindow( + cb, + Q, + R, + pivot, + A, + Qtemp.slice(cb, Colon, (null, b)), + Qtemp2.slice(cb, Colon, (null, b)), + Rtemp.slice(cb, (null, bb), (null, bb)), + work1.slice(cb, (null, bb), (null, bb)), + work2.slice(cb, (null, bb), (null, b)), + cb.memoize(blocksize.min(b)), + ) + + cb.assign(pivot, pivot + b) + cb.if_( + pivot >= w, { + cb.if_(pivot.cne(w), cb._fatal("whitenBlock, blocks didn't evenly divide window size")) + cb.assign(pivot, 0L) + }, + ) + }, + ) } def initializeWindow(cb: EmitCodeBuilder, _A: SNDArrayValue): Unit = { diff --git a/hail/src/main/scala/is/hail/methods/LogisticRegression.scala b/hail/src/main/scala/is/hail/methods/LogisticRegression.scala index d33d5171793..d8b79558e98 100644 --- a/hail/src/main/scala/is/hail/methods/LogisticRegression.scala +++ b/hail/src/main/scala/is/hail/methods/LogisticRegression.scala @@ -1,17 +1,17 @@ package is.hail.methods -import breeze.linalg._ import is.hail.HailContext import is.hail.annotations._ import is.hail.backend.ExecuteContext -import is.hail.expr.ir.functions.MatrixToTableFunction import is.hail.expr.ir.{IntArrayBuilder, MatrixValue, TableValue} -import is.hail.types.virtual.{TArray, TFloat64, TStruct} -import is.hail.types.{MatrixType, TableType} -import is.hail.rvd.RVDType +import is.hail.expr.ir.functions.MatrixToTableFunction import is.hail.stats._ +import is.hail.types.{MatrixType, TableType} +import is.hail.types.virtual.{TArray, TFloat64, TStruct} import is.hail.utils._ +import breeze.linalg._ + case class LogisticRegression( test: String, yFields: Seq[String], @@ -19,14 +19,18 @@ case class LogisticRegression( covFields: Seq[String], passThrough: Seq[String], maxIterations: Int, - tolerance: Double + tolerance: Double, ) extends MatrixToTableFunction { override def typ(childType: MatrixType): TableType = { val logRegTest 
= LogisticRegressionTest.tests(test) val multiPhenoSchema = TStruct(("logistic_regression", TArray(logRegTest.schema))) val passThroughType = TStruct(passThrough.map(f => f -> childType.rowType.field(f).typ): _*) - TableType(childType.rowKeyStruct ++ passThroughType ++ multiPhenoSchema, childType.rowKey, TStruct.empty) + TableType( + childType.rowKeyStruct ++ passThroughType ++ multiPhenoSchema, + childType.rowKey, + TStruct.empty, + ) } def preservesPartitionCounts: Boolean = true @@ -36,50 +40,64 @@ case class LogisticRegression( val tableType = typ(mv.typ) val newRVDType = tableType.canonicalRVDType - val multiPhenoSchema = TStruct(("logistic_regression", TArray(logRegTest.schema))) - - val (yVecs, cov, completeColIdx) = RegressionUtils.getPhenosCovCompleteSamples(mv, yFields.toArray, covFields.toArray) + val (yVecs, cov, completeColIdx) = + RegressionUtils.getPhenosCovCompleteSamples(mv, yFields.toArray, covFields.toArray) - (0 until yVecs.cols).foreach(col => { + (0 until yVecs.cols).foreach { col => if (!yVecs(::, col).forall(yi => yi == 0d || yi == 1d)) - fatal(s"For logistic regression, y at index ${col} must be bool or numeric with all present values equal to 0 or 1") - val sumY = sum(yVecs(::,col)) - if (sumY == 0d || sumY == yVecs(::,col).length) - fatal(s"For logistic regression, y at index ${col} must be non-constant") - }) + fatal( + s"For logistic regression, y at index $col must be bool or numeric with all present values equal to 0 or 1" + ) + val sumY = sum(yVecs(::, col)) + if (sumY == 0d || sumY == yVecs(::, col).length) + fatal(s"For logistic regression, y at index $col must be non-constant") + } val n = yVecs.rows val k = cov.cols val d = n - k - 1 if (d < 1) - fatal(s"$n samples and ${ k + 1 } ${ plural(k, "covariate") } (including x) implies $d degrees of freedom.") + fatal( + s"$n samples and ${k + 1} ${plural(k, "covariate")} (including x) implies $d degrees of freedom." 
+ ) info(s"logistic_regression_rows: running $test on $n samples for response variable y,\n" - + s" with input variable x, and ${ k } additional ${ plural(k, "covariate") }...") + + s" with input variable x, and $k additional ${plural(k, "covariate")}...") - val nullFits = (0 until yVecs.cols).map(col => { + val nullFits = (0 until yVecs.cols).map { col => val nullModel = new LogisticRegressionModel(cov, yVecs(::, col)) - var nullFit = nullModel.fit(maxIter=maxIterations, tol=tolerance) + var nullFit = nullModel.fit(maxIter = maxIterations, tol = tolerance) if (!nullFit.converged) if (logRegTest == LogisticFirthTest) - nullFit = GLMFit(nullModel.bInterceptOnly(), - None, None, 0, nullFit.nIter, exploded = nullFit.exploded, converged = false) + nullFit = GLMFit( + nullModel.bInterceptOnly(), + None, + None, + 0, + nullFit.nIter, + exploded = nullFit.exploded, + converged = false, + ) else - fatal("Failed to fit logistic regression null model (standard MLE with covariates only): " + ( - if (nullFit.exploded) - s"exploded at Newton iteration ${nullFit.nIter}" - else - "Newton iteration failed to converge")) + fatal( + "Failed to fit logistic regression null model (standard MLE with covariates only): " + ( + if (nullFit.exploded) + s"exploded at Newton iteration ${nullFit.nIter}" + else + "Newton iteration failed to converge" + ) + ) nullFit - }) + } val backend = HailContext.backend val completeColIdxBc = backend.broadcast(completeColIdx) val yVecsBc = backend.broadcast(yVecs) - val XBc = backend.broadcast(new DenseMatrix[Double](n, k + 1, cov.toArray ++ Array.ofDim[Double](n))) + val XBc = + backend.broadcast(new DenseMatrix[Double](n, k + 1, cov.toArray ++ Array.ofDim[Double](n))) val nullFitBc = backend.broadcast(nullFits) val logRegTestBc = backend.broadcast(logRegTest) @@ -103,22 +121,39 @@ case class LogisticRegression( val _yVecs = yVecsBc.value val X = XBc.value.copy it.map { ptr => - RegressionUtils.setMeanImputedDoubles(X.data, n * k, completeColIdxBc.value, missingCompleteCols, - ptr, fullRowType, entryArrayType, entryType, entryArrayIdx, fieldIdx) - val logregAnnotations = (0 until _yVecs.cols).map(col => { - logRegTestBc.value.test(X, _yVecs(::,col), _nullFits(col), "logistic", maxIter=maxIterations, tol=tolerance) - }) + RegressionUtils.setMeanImputedDoubles( + X.data, + n * k, + completeColIdxBc.value, + missingCompleteCols, + ptr, + fullRowType, + entryArrayType, + entryType, + entryArrayIdx, + fieldIdx, + ) + val logregAnnotations = (0 until _yVecs.cols).map { col => + logRegTestBc.value.test( + X, + _yVecs(::, col), + _nullFits(col), + "logistic", + maxIter = maxIterations, + tol = tolerance, + ) + } rvb.start(newRVDType.rowType) rvb.startStruct() rvb.addFields(fullRowType, ctx.r, ptr, copiedFieldIndices) rvb.startArray(_yVecs.cols) - logregAnnotations.foreach(stats => { + logregAnnotations.foreach { stats => rvb.startStruct() stats.addToRVB(rvb) rvb.endStruct() - }) + } rvb.endArray() rvb.endStruct() rvb.end() diff --git a/hail/src/main/scala/is/hail/methods/MatrixExportEntriesByCol.scala b/hail/src/main/scala/is/hail/methods/MatrixExportEntriesByCol.scala index 4ad95fe3cb5..cf3cdcea9d5 100644 --- a/hail/src/main/scala/is/hail/methods/MatrixExportEntriesByCol.scala +++ b/hail/src/main/scala/is/hail/methods/MatrixExportEntriesByCol.scala @@ -1,6 +1,5 @@ package is.hail.methods -import java.io.{BufferedOutputStream, OutputStreamWriter} import is.hail.HailContext import is.hail.annotations.{UnsafeIndexedSeq, UnsafeRow} import is.hail.backend.ExecuteContext @@ -11,11 +10,19 @@ 
import is.hail.expr.ir.functions.MatrixToValueFunction import is.hail.types.{MatrixType, RTable, TypeWithRequiredness} import is.hail.types.virtual.{TVoid, Type} import is.hail.utils._ + +import java.io.{BufferedOutputStream, OutputStreamWriter} + import org.apache.spark.TaskContext import org.apache.spark.sql.Row -case class MatrixExportEntriesByCol(parallelism: Int, path: String, bgzip: Boolean, - headerJsonInFile: Boolean, useStringKeyAsFileName: Boolean) extends MatrixToValueFunction { +case class MatrixExportEntriesByCol( + parallelism: Int, + path: String, + bgzip: Boolean, + headerJsonInFile: Boolean, + useStringKeyAsFileName: Boolean, +) extends MatrixToValueFunction { def typ(childType: MatrixType): Type = TVoid def unionRequiredness(childType: RTable, resultType: TypeWithRequiredness): Unit = () @@ -29,23 +36,26 @@ case class MatrixExportEntriesByCol(parallelism: Int, path: String, bgzip: Boole val fileNames: IndexedSeq[String] = if (useStringKeyAsFileName) { val ids = mv.stringSampleIds if (ids.toSet.size != ids.length) // there are duplicates - fatal("export_entries_by_col cannot export with 'use_string_key_as_file_name' with duplicate keys") + fatal( + "export_entries_by_col cannot export with 'use_string_key_as_file_name' with duplicate keys" + ) ids } else Array.tabulate(mv.nCols)(i => partFile(padding, i)) - val allColValuesJSON = mv.colValues.javaValue.map(TableAnnotationImpex.exportAnnotation(_, mv.typ.colType)).toArray + val allColValuesJSON = + mv.colValues.javaValue.map(TableAnnotationImpex.exportAnnotation(_, mv.typ.colType)).toArray val tempFolders = new BoxedArrayBuilder[String] - info(s"exporting ${ mv.nCols } files in batches of $parallelism...") + info(s"exporting ${mv.nCols} files in batches of $parallelism...") val nBatches = (mv.nCols + parallelism - 1) / parallelism val resultFiles = (0 until nBatches).flatMap { batch => val startIdx = parallelism * batch val nCols = mv.nCols val endIdx = math.min(nCols, parallelism * (batch + 1)) - info(s"on batch ${ batch + 1 } of ${ nBatches }, columns $startIdx to ${ endIdx - 1 }...") + info(s"on batch ${batch + 1} of $nBatches, columns $startIdx to ${endIdx - 1}...") val d = digitsNeeded(mv.rvd.getNumPartitions) @@ -56,24 +66,24 @@ case class MatrixExportEntriesByCol(parallelism: Int, path: String, bgzip: Boole val partFileBase = path + "/tmp/" - val extension = if (bgzip) ".tsv.bgz" else ".tsv" val localHeaderJsonInFile = headerJsonInFile val colValuesJSON = HailContext.backend.broadcast( (startIdx until endIdx) .map(allColValuesJSON) - .toArray) + .toArray + ) val fsBc = fs.broadcast val localTempDir = ctx.localTmpdir val partFolders = mv.rvd.crdd.cmapPartitionsWithIndex { (i, ctx, it) => - val partFolder = partFileBase + partFile(d, i, TaskContext.get()) val filePaths = Array.tabulate(endIdx - startIdx) { j => val finalPath = partFolder + "/" + j.toString + extension - val tempPath = ExecuteContext.createTmpPathNoCleanup(localTempDir, "EEBC", extension = extension) + val tempPath = + ExecuteContext.createTmpPathNoCleanup(localTempDir, "EEBC", extension = extension) (tempPath, finalPath) } @@ -85,7 +95,7 @@ case class MatrixExportEntriesByCol(parallelism: Int, path: String, bgzip: Boole // write headers val header = ( rvType.fieldNames.filter(_ != MatrixType.entriesIdentifier) ++ entryType.fieldNames - ).mkString("\t") + ).mkString("\t") fileHandles.zipWithIndex.foreach { case (f, jj) => if (localHeaderJsonInFile) { @@ -99,15 +109,18 @@ case class MatrixExportEntriesByCol(parallelism: Int, path: String, bgzip: Boole } 
it.foreach { ptr => - - val entriesArray = new UnsafeIndexedSeq(entryArrayType, ctx.region, rvType.loadField(ptr, entriesIdx)) + val entriesArray = + new UnsafeIndexedSeq(entryArrayType, ctx.region, rvType.loadField(ptr, entriesIdx)) val fullRow = new UnsafeRow(rvType, ctx.region, ptr) val rowFieldStrs = (0 until rvType.size) .filter(_ != entriesIdx) .map { rowFieldIdx => - TableAnnotationImpex.exportAnnotation(fullRow(rowFieldIdx), rvType.types(rowFieldIdx).virtualType) + TableAnnotationImpex.exportAnnotation( + fullRow(rowFieldIdx), + rvType.types(rowFieldIdx).virtualType, + ) }.toArray fileHandles.indices.foreach { fileIdx => @@ -121,12 +134,13 @@ case class MatrixExportEntriesByCol(parallelism: Int, path: String, bgzip: Boole entriesArray(entryIdx) match { case null => - (0 until entryType.size).foreachBetween { _ => - os.write("NA") - }(os.write('\t')) + (0 until entryType.size).foreachBetween(_ => os.write("NA"))(os.write('\t')) case r: Row => (0 until entryType.size).foreachBetween { entryFieldIdx => - os.write(TableAnnotationImpex.exportAnnotation(r.get(entryFieldIdx), entryType.types(entryFieldIdx).virtualType)) + os.write(TableAnnotationImpex.exportAnnotation( + r.get(entryFieldIdx), + entryType.types(entryFieldIdx).virtualType, + )) }(os.write('\t')) } @@ -150,8 +164,9 @@ case class MatrixExportEntriesByCol(parallelism: Int, path: String, bgzip: Boole val newFiles = mv.sparkContext.parallelize(0 until ns, numSlices = ns) .map { sampleIdx => val partFilePath = path + "/" + partFile(digitsNeeded(nCols), sampleIdx, TaskContext.get) - val fileListEntries = partFolders.map(pf => fsBc.value.fileListEntry(pf + s"/$sampleIdx" + extension)) - fsBc.value.copyMergeList(fileListEntries, partFilePath, deleteSource = false) + val fileStatuses = + partFolders.map(pf => fsBc.value.fileStatus(pf + s"/$sampleIdx" + extension)) + fsBc.value.copyMergeList(fileStatuses, partFilePath, deleteSource = false) partFilePath }.collect() @@ -162,27 +177,30 @@ case class MatrixExportEntriesByCol(parallelism: Int, path: String, bgzip: Boole val extension = if (bgzip) ".tsv.bgz" else ".tsv" - def finalPath(idx: Int): String = { + def finalPath(idx: Int): String = path + "/" + fileNames(idx) + extension - } resultFiles.zipWithIndex.foreach { case (filePath, i) => fs.copy(filePath, finalPath(i), deleteSource = true) } fs.delete(path + "/tmp", recursive = true) - fs.writeTable(path + "/index.tsv", allColValuesJSON.zipWithIndex.map { case (json, i) => - s"${ finalPath(i) }\t$json" - }) + fs.writeTable( + path + "/index.tsv", + allColValuesJSON.zipWithIndex.map { case (json, i) => + s"${finalPath(i)}\t$json" + }, + ) info("Export finished. 
Cleaning up temporary files...") // clean up temporary files val temps = tempFolders.result() val fsBc = fs.broadcast - SparkBackend.sparkContext("MatrixExportEntriesByCol.execute").parallelize(temps, (temps.length / 32).max(1)).foreach { path => - fsBc.value.delete(path, recursive = true) - } + SparkBackend.sparkContext("MatrixExportEntriesByCol.execute").parallelize( + temps, + (temps.length / 32).max(1), + ).foreach(path => fsBc.value.delete(path, recursive = true)) info("Done cleaning up temporary files.") } diff --git a/hail/src/main/scala/is/hail/methods/NPartitions.scala b/hail/src/main/scala/is/hail/methods/NPartitions.scala index 2d155160737..adfecfc3820 100644 --- a/hail/src/main/scala/is/hail/methods/NPartitions.scala +++ b/hail/src/main/scala/is/hail/methods/NPartitions.scala @@ -1,10 +1,10 @@ package is.hail.methods import is.hail.backend.ExecuteContext -import is.hail.expr.ir.functions.{MatrixToValueFunction, TableToValueFunction} import is.hail.expr.ir.{MatrixValue, TableValue} -import is.hail.types.virtual.{TInt32, Type} +import is.hail.expr.ir.functions.{MatrixToValueFunction, TableToValueFunction} import is.hail.types.{MatrixType, RTable, TableType, TypeWithRequiredness} +import is.hail.types.virtual.{TInt32, Type} case class NPartitionsTable() extends TableToValueFunction { override def typ(childType: TableType): Type = TInt32 diff --git a/hail/src/main/scala/is/hail/methods/Nirvana.scala b/hail/src/main/scala/is/hail/methods/Nirvana.scala index f40c72e5ecd..44c8e8e4ad1 100644 --- a/hail/src/main/scala/is/hail/methods/Nirvana.scala +++ b/hail/src/main/scala/is/hail/methods/Nirvana.scala @@ -12,19 +12,20 @@ import is.hail.types.physical.PType import is.hail.types.virtual._ import is.hail.utils._ import is.hail.variant.{Locus, RegionValueVariant} -import org.apache.spark.sql.Row -import org.apache.spark.storage.StorageLevel -import org.json4s.jackson.JsonMethods -import java.io.{FileInputStream, IOException} -import java.util.Properties import scala.collection.JavaConverters._ import scala.collection.mutable +import java.io.{FileInputStream, IOException} +import java.util.Properties + +import org.apache.spark.sql.Row +import org.apache.spark.storage.StorageLevel +import org.json4s.jackson.JsonMethods object Nirvana { - //For Nirnava v2.0.8 + // For Nirnava v2.0.8 val nirvanaSignature = TStruct( "chromosome" -> TString, @@ -50,7 +51,7 @@ object Nirvana { "validated" -> TBoolean, "phenotypes" -> TArray(TString), "phenotypeIds" -> TArray(TString), - "reciprocalOverlap" -> TFloat64 + "reciprocalOverlap" -> TFloat64, )), "dgv" -> TArray(TStruct( "chromosome" -> TString, @@ -62,7 +63,7 @@ object Nirvana { "sampleSize" -> TInt32, "observedGains" -> TInt32, "observedLosses" -> TInt32, - "reciprocalOverlap" -> TFloat64 + "reciprocalOverlap" -> TFloat64, )), "oneKg" -> TArray(TStruct( "chromosome" -> TString, @@ -84,7 +85,7 @@ object Nirvana { "sampleSizeSas" -> TInt32, "observedGains" -> TInt32, "observedLosses" -> TInt32, - "reciprocalOverlap" -> TFloat64 + "reciprocalOverlap" -> TFloat64, )), "cosmic" -> TArray(TStruct( "id" -> TInt32, @@ -93,9 +94,9 @@ object Nirvana { "end" -> TInt32, "variantType" -> TString, "copyNumber" -> TInt32, - "cancerTypes" -> TArray(TTuple(TString,TInt32)), - "tissues" -> TArray(TTuple(TString,TInt32)), - "reciprocalOverlap" -> TFloat64 + "cancerTypes" -> TArray(TTuple(TString, TInt32)), + "tissues" -> TArray(TTuple(TString, TInt32)), + "reciprocalOverlap" -> TFloat64, )), "variants" -> TArray(TStruct( "altAllele" -> TString, @@ -113,7 +114,7 @@ 
object Nirvana { "regulatoryRegions" -> TArray(TStruct( "id" -> TString, "type" -> TString, - "consequence" -> TSet(TString) + "consequence" -> TSet(TString), )), "clinvar" -> TArray(TStruct( "id" -> TString, @@ -128,7 +129,7 @@ object Nirvana { "orphanetIds" -> TArray(TString), "significance" -> TString, "lastUpdatedDate" -> TString, - "pubMedIds" -> TArray(TString) + "pubMedIds" -> TArray(TString), )), "cosmic" -> TArray(TStruct( "id" -> TString, @@ -140,8 +141,8 @@ object Nirvana { "studies" -> TArray(TStruct( "id" -> TInt32, "histology" -> TString, - "primarySite" -> TString - )) + "primarySite" -> TString, + )), )), "dbsnp" -> TStruct("ids" -> TArray(TString)), "gnomad" -> TStruct( @@ -178,7 +179,7 @@ object Nirvana { "asjAc" -> TInt32, "asjAn" -> TInt32, "asjHc" -> TInt32, - "failedFilter" -> TBoolean + "failedFilter" -> TBoolean, ), "gnomadExome" -> TStruct( "coverage" -> TString, @@ -218,18 +219,18 @@ object Nirvana { "sasAc" -> TInt32, "sasAn" -> TInt32, "sasHc" -> TInt32, - "failedFilter" -> TBoolean + "failedFilter" -> TBoolean, ), "topmed" -> TStruct( "failedFilter" -> TBoolean, "allAc" -> TInt32, "allAn" -> TInt32, "allAf" -> TFloat64, - "allHc" -> TInt32 + "allHc" -> TInt32, ), "globalAllele" -> TStruct( "globalMinorAllele" -> TString, - "globalMinorAlleleFrequency" -> TFloat64 + "globalMinorAlleleFrequency" -> TFloat64, ), "oneKg" -> TStruct( "ancestralAllele" -> TString, @@ -250,12 +251,12 @@ object Nirvana { "eurAn" -> TInt32, "sasAf" -> TFloat64, "sasAc" -> TInt32, - "sasAn" -> TInt32 + "sasAn" -> TInt32, ), "mitomap" -> TArray(TStruct( "refAllele" -> TString, "altAllele" -> TString, - "diseases" -> TArray(TString), + "diseases" -> TArray(TString), "hasHomoplasmy" -> TBoolean, "hasHeteroplasmy" -> TBoolean, "status" -> TString, @@ -265,7 +266,7 @@ object Nirvana { "chromosome" -> TString, "begin" -> TInt32, "end" -> TInt32, - "variantType" -> TString + "variantType" -> TString, )), "transcripts" -> TStruct( "refSeq" -> TArray(TStruct( @@ -288,7 +289,7 @@ object Nirvana { "proteinId" -> TString, "proteinPos" -> TString, "siftScore" -> TFloat64, - "siftPrediction" -> TString + "siftPrediction" -> TString, )), "ensembl" -> TArray(TStruct( "transcript" -> TString, @@ -310,10 +311,10 @@ object Nirvana { "proteinId" -> TString, "proteinPos" -> TString, "siftScore" -> TFloat64, - "siftPrediction" -> TString - )) + "siftPrediction" -> TString, + )), ), - "overlappingGenes" -> TArray(TString) + "overlappingGenes" -> TArray(TString), )), "genes" -> TArray(TStruct( "name" -> TString, @@ -326,23 +327,23 @@ object Nirvana { "phenotype" -> TString, "mapping" -> TString, "inheritance" -> TArray(TString), - "comments" -> TString - )) + "comments" -> TString, + )), )), "exac" -> TStruct( "pLi" -> TFloat64, "pRec" -> TFloat64, - "pNull" -> TFloat64 - ) - )) + "pNull" -> TFloat64, + ), + )), ) - def printContext(w: (String) => Unit) { + def printContext(w: (String) => Unit): Unit = { w("##fileformat=VCFv4.1") w("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT") } - def printElement(vaSignature: PType)(w: (String) => Unit, v: (Locus, Array[String])) { + def printElement(vaSignature: PType)(w: (String) => Unit, v: (Locus, Array[String])): Unit = { val (locus, alleles) = v val sb = new StringBuilder() @@ -362,16 +363,17 @@ object Nirvana { assert(tv.typ.key == FastSeq("locus", "alleles")) assert(tv.typ.rowType.size == 2) - val properties = try { - val p = new Properties() - val is = new FileInputStream(config) - p.load(is) - is.close() - p - } catch { - case e: IOException => - 
fatal(s"could not open file: ${ e.getMessage }") - } + val properties = + try { + val p = new Properties() + val is = new FileInputStream(config) + p.load(is) + is.close() + p + } catch { + case e: IOException => + fatal(s"could not open file: ${e.getMessage}") + } val dotnet = properties.getProperty("hail.nirvana.dotnet", "dotnet") @@ -383,9 +385,11 @@ object Nirvana { val cache = properties.getProperty("hail.nirvana.cache") - - val supplementaryAnnotationDirectoryOpt = Option(properties.getProperty("hail.nirvana.supplementaryAnnotationDirectory")) - val supplementaryAnnotationDirectory = if (supplementaryAnnotationDirectoryOpt.isEmpty) List[String]() else List("--sd", supplementaryAnnotationDirectoryOpt.get) + val supplementaryAnnotationDirectoryOpt = + Option(properties.getProperty("hail.nirvana.supplementaryAnnotationDirectory")) + val supplementaryAnnotationDirectory = if (supplementaryAnnotationDirectoryOpt.isEmpty) + List[String]() + else List("--sd", supplementaryAnnotationDirectoryOpt.get) val reference = properties.getProperty("hail.nirvana.reference") @@ -428,16 +432,19 @@ object Nirvana { } .grouped(localBlockSize) .flatMap { block => - val (jt, err, proc) = block.iterator.pipe(pb, - printContext, - printElement(localRowType), - _ => ()) + val (jt, err, proc) = + block.iterator.pipe(pb, printContext, printElement(localRowType), _ => ()) // The filter is because every other output line is a comma. val kt = jt.filter(_.startsWith("{\"chromosome")).map { s => - val a = JSONAnnotationImpex.importAnnotation(JsonMethods.parse(s), nirvanaSignature, warnContext = warnContext) - val locus = Locus(contigQuery(a).asInstanceOf[String], - startQuery(a).asInstanceOf[Int]) - val alleles = refQuery(a).asInstanceOf[String] +: altsQuery(a).asInstanceOf[IndexedSeq[String]] + val a = JSONAnnotationImpex.importAnnotation( + JsonMethods.parse(s), + nirvanaSignature, + warnContext = warnContext, + ) + val locus = + Locus(contigQuery(a).asInstanceOf[String], startQuery(a).asInstanceOf[Int]) + val alleles = + refQuery(a).asInstanceOf[String] +: altsQuery(a).asInstanceOf[IndexedSeq[String]] (Annotation(locus, alleles), a) } @@ -446,13 +453,16 @@ object Nirvana { val rc = proc.waitFor() if (rc != 0) - fatal(s"nirvana command failed with non-zero exit status $rc\n\tError:\n${err.toString}") + fatal( + s"nirvana command failed with non-zero exit status $rc\n\tError:\n${err.toString}" + ) r } } - val nirvanaRVDType = prev.typ.copy(rowType = prev.rowPType.appendKey("nirvana", PType.canonical(nirvanaSignature))) + val nirvanaRVDType = + prev.typ.copy(rowType = prev.rowPType.appendKey("nirvana", PType.canonical(nirvanaSignature))) val nirvanaRowType = nirvanaRVDType.rowType @@ -472,13 +482,15 @@ object Nirvana { rvb.end() } - }).persist(ctx, StorageLevel.MEMORY_AND_DISK) - - TableValue(ctx, - TableType(nirvanaRowType.virtualType, FastSeq("locus", "alleles"), TStruct.empty), - BroadcastRow.empty(ctx), - nirvanaRVD - ) + }, + ).persist(ctx, StorageLevel.MEMORY_AND_DISK) + + TableValue( + ctx, + TableType(nirvanaRowType.virtualType, FastSeq("locus", "alleles"), TStruct.empty), + BroadcastRow.empty(ctx), + nirvanaRVD, + ) } } @@ -486,12 +498,15 @@ case class Nirvana(config: String, blockSize: Int = 500000) extends TableToTable override def typ(childType: TableType): TableType = { assert(childType.key == FastSeq("locus", "alleles")) assert(childType.rowType.size == 2) - TableType(childType.rowType ++ TStruct("nirvana" -> Nirvana.nirvanaSignature), childType.key, childType.globalType) + TableType( + 
childType.rowType ++ TStruct("nirvana" -> Nirvana.nirvanaSignature), + childType.key, + childType.globalType, + ) } def preservesPartitionCounts: Boolean = false - def execute(ctx: ExecuteContext, tv: TableValue): TableValue = { + def execute(ctx: ExecuteContext, tv: TableValue): TableValue = Nirvana.annotate(ctx, tv, config, blockSize) - } } diff --git a/hail/src/main/scala/is/hail/methods/PCA.scala b/hail/src/main/scala/is/hail/methods/PCA.scala index 5db87c9aff2..1aa38207519 100644 --- a/hail/src/main/scala/is/hail/methods/PCA.scala +++ b/hail/src/main/scala/is/hail/methods/PCA.scala @@ -1,28 +1,32 @@ package is.hail.methods -import breeze.linalg.{*, DenseMatrix, DenseVector} import is.hail.HailContext import is.hail.annotations._ import is.hail.backend.ExecuteContext -import is.hail.expr.ir.functions.MatrixToTableFunction import is.hail.expr.ir.{MatrixValue, TableValue} +import is.hail.expr.ir.functions.MatrixToTableFunction import is.hail.rvd.{RVD, RVDType} import is.hail.sparkextras.ContextRDD import is.hail.types._ import is.hail.types.physical.{PCanonicalStruct, PStruct} import is.hail.types.virtual._ import is.hail.utils._ + +import breeze.linalg.{*, DenseMatrix, DenseVector} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix} import org.apache.spark.sql.Row case class PCA(entryField: String, k: Int, computeLoadings: Boolean) extends MatrixToTableFunction { - override def typ(childType: MatrixType): TableType = { + override def typ(childType: MatrixType): TableType = TableType( childType.rowKeyStruct ++ TStruct("loadings" -> TArray(TFloat64)), childType.rowKey, - TStruct("eigenvalues" -> TArray(TFloat64), "scores" -> TArray(childType.colKeyStruct ++ TStruct("scores" -> TArray(TFloat64))))) - } + TStruct( + "eigenvalues" -> TArray(TFloat64), + "scores" -> TArray(childType.colKeyStruct ++ TStruct("scores" -> TArray(TFloat64))), + ), + ) def preservesPartitionCounts: Boolean = false @@ -41,8 +45,9 @@ case class PCA(entryField: String, k: Int, computeLoadings: Boolean) extends Mat val svd = irm.computeSVD(k, computeLoadings) if (svd.s.size < k) fatal( - s"Found only ${ svd.s.size } non-zero (or nearly zero) eigenvalues, " + - s"but user requested ${ k } principal components.") + s"Found only ${svd.s.size} non-zero (or nearly zero) eigenvalues, " + + s"but user requested $k principal components." 
+ ) def collectRowKeys(): Array[Annotation] = { val rowKeyIdx = mv.typ.rowKeyFieldIdx @@ -54,7 +59,9 @@ case class PCA(entryField: String, k: Int, computeLoadings: Boolean) extends Mat .collect() } - val rowType = PCanonicalStruct.canonical(TStruct(mv.typ.rowKey.zip(mv.typ.rowKeyStruct.types): _*) ++ TStruct("loadings" -> TArray(TFloat64))) + val rowType = PCanonicalStruct.canonical(TStruct( + mv.typ.rowKey.zip(mv.typ.rowKeyStruct.types): _* + ) ++ TStruct("loadings" -> TArray(TFloat64))) .setRequired(true) .asInstanceOf[PStruct] val rowKeysBc = HailContext.backend.broadcast(collectRowKeys()) @@ -99,7 +106,7 @@ case class PCA(entryField: String, k: Int, computeLoadings: Boolean) extends Mat svd.V.asInstanceOf[org.apache.spark.mllib.linalg.DenseMatrix].values else svd.V.toArray - + val V = new DenseMatrix[Double](svd.V.numRows, svd.V.numCols, data) val S = DenseVector(svd.s.toArray) @@ -107,7 +114,7 @@ case class PCA(entryField: String, k: Int, computeLoadings: Boolean) extends Mat val scaledEigenvectors = V(*, ::) *:* S val scores = (0 until mv.nCols).iterator.map { i => - (0 until k).iterator.map { j => scaledEigenvectors(i, j) }.toFastSeq + (0 until k).iterator.map(j => scaledEigenvectors(i, j)).toFastSeq }.toFastSeq val g1 = f1(mv.globals.value, eigenvalues.toFastSeq) @@ -115,9 +122,12 @@ case class PCA(entryField: String, k: Int, computeLoadings: Boolean) extends Mat f3(mv.typ.extractColKey(cv.asInstanceOf[Row]), scores(i)) } val newGlobal = f2(g1, globalScores) - - TableValue(ctx, + + TableValue( + ctx, TableType(rowType.virtualType, mv.typ.rowKey, newGlobalType.asInstanceOf[TStruct]), - BroadcastRow(ctx, newGlobal.asInstanceOf[Row], newGlobalType.asInstanceOf[TStruct]), rvd) + BroadcastRow(ctx, newGlobal.asInstanceOf[Row], newGlobalType.asInstanceOf[TStruct]), + rvd, + ) } } diff --git a/hail/src/main/scala/is/hail/methods/PCRelate.scala b/hail/src/main/scala/is/hail/methods/PCRelate.scala index 545add5c575..bc6ec96f551 100644 --- a/hail/src/main/scala/is/hail/methods/PCRelate.scala +++ b/hail/src/main/scala/is/hail/methods/PCRelate.scala @@ -1,20 +1,18 @@ package is.hail.methods -import breeze.linalg.{DenseMatrix => BDM} -import is.hail.linalg.BlockMatrix -import is.hail.linalg.BlockMatrix.ops._ -import is.hail.utils._ -import is.hail.HailContext import is.hail.backend.ExecuteContext -import is.hail.expr.ir.functions.BlockMatrixToTableFunction import is.hail.expr.ir.TableValue +import is.hail.expr.ir.functions.BlockMatrixToTableFunction +import is.hail.linalg.BlockMatrix +import is.hail.linalg.BlockMatrix.ops._ import is.hail.types.{BlockMatrixType, TableType} import is.hail.types.virtual._ -import org.apache.spark.storage.StorageLevel +import is.hail.utils._ + +import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row - -import scala.language.{higherKinds, implicitConversions} +import org.apache.spark.storage.StorageLevel object PCRelate { type M = BlockMatrix @@ -34,12 +32,13 @@ object PCRelate { val defaultStorageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK private val sig = TStruct( - ("i", TInt32), - ("j", TInt32), - ("kin", TFloat64), - ("ibd0", TFloat64), - ("ibd1", TFloat64), - ("ibd2", TFloat64)) + ("i", TInt32), + ("j", TInt32), + ("kin", TFloat64), + ("ibd0", TFloat64), + ("ibd1", TFloat64), + ("ibd2", TFloat64), + ) private val keys: IndexedSeq[String] = Array("i", "j") @@ -53,7 +52,9 @@ object PCRelate { while (i < nRows) { val row = x(i) if (row.length != nCols) - fatal(s"pc_relate: column index 0 has $nCols scores 
but column index $i has ${row.length} scores.") + fatal( + s"pc_relate: column index 0 has $nCols scores but column index $i has ${row.length} scores." + ) var j = 0 while (j < nCols) { val e = row(j) @@ -71,14 +72,18 @@ object PCRelate { r: Result[M], blockSize: Int, minKinshipOptional: Option[Double], - statistics: StatisticSubset): RDD[Row] = { + statistics: StatisticSubset, + ): RDD[Row] = { val minKinship = minKinshipOptional.getOrElse(defaultMinKinship) - def fuseBlocks(i: Int, j: Int, + def fuseBlocks( + i: Int, + j: Int, lmPhi: BDM[Double], lmK0: BDM[Double], lmK1: BDM[Double], - lmK2: BDM[Double]) = { + lmK2: BDM[Double], + ) = { if (i <= j) { val iOffset = i * blockSize @@ -88,7 +93,8 @@ object PCRelate { var jj = 0 while (jj < lmPhi.cols) { var ii = 0 - val nRowsAboveDiagonal = if (i < j) lmPhi.rows else (jj + 1) // assumes square blocks on diagonal + val nRowsAboveDiagonal = + if (i < j) lmPhi.rows else (jj + 1) // assumes square blocks on diagonal while (ii < nRowsAboveDiagonal) { val kin = lmPhi(ii, jj) if (kin >= minKinship) { @@ -108,16 +114,25 @@ object PCRelate { val Result(phi, k0, k1, k2) = r - // FIXME replace join with zipPartitions, throw away lower triangular blocks first, avoid the nulls + /* FIXME replace join with zipPartitions, throw away lower triangular blocks first, avoid the + * nulls */ statistics match { case PhiOnly => phi.blocks - .flatMap { case ((blocki, blockj), phi) => fuseBlocks(blocki, blockj, phi, null, null, null) } + .flatMap { case ((blocki, blockj), phi) => + fuseBlocks(blocki, blockj, phi, null, null, null) + } case PhiK2 => (phi.blocks join k2.blocks) - .flatMap { case ((blocki, blockj), (phi, k2)) => fuseBlocks(blocki, blockj, phi, null, null, k2) } + .flatMap { case ((blocki, blockj), (phi, k2)) => + fuseBlocks(blocki, blockj, phi, null, null, k2) + } case PhiK2K0 => (phi.blocks join k0.blocks join k2.blocks) - .flatMap { case ((blocki, blockj), ((phi, k0), k2)) => fuseBlocks(blocki, blockj, phi, k0, null, k2) } + .flatMap { case ((blocki, blockj), ((phi, k0), k2)) => + fuseBlocks(blocki, blockj, phi, k0, null, k2) + } case PhiK2K0K1 => (phi.blocks join k0.blocks join k1.blocks join k2.blocks) - .flatMap { case ((blocki, blockj), (((phi, k0), k1), k2)) => fuseBlocks(blocki, blockj, phi, k0, k1, k2) } + .flatMap { case ((blocki, blockj), (((phi, k0), k1), k2)) => + fuseBlocks(blocki, blockj, phi, k0, k1, k2) + } } } @@ -128,8 +143,8 @@ case class PCRelate( maf: Double, blockSize: Int, minKinship: Option[Double] = None, - statistics: PCRelate.StatisticSubset = PCRelate.defaultStatisticSubset) - extends BlockMatrixToTableFunction with Serializable { + statistics: PCRelate.StatisticSubset = PCRelate.defaultStatisticSubset, +) extends BlockMatrixToTableFunction with Serializable { import PCRelate._ @@ -177,19 +192,25 @@ case class PCRelate( Double.NaN else mu - } (blockedG, preMu) + }(blockedG, preMu) val variance = cacheWhen(PhiK2)( - ctx, mu.map(mu => if (java.lang.Double.isNaN(mu)) 0.0 else mu * (1.0 - mu))) + ctx, + mu.map(mu => if (java.lang.Double.isNaN(mu)) 0.0 else mu * (1.0 - mu)), + ) // write phi to cache and increase parallelism of multiplies before phi.diagonal() val phi = writeRead(ctx, this.phi(ctx, mu, variance, blockedG)) if (statistics >= PhiK2) { val k2 = cacheWhen(PhiK2K0)( - ctx, this.k2(ctx, phi, mu, variance, blockedG)) + ctx, + this.k2(ctx, phi, mu, variance, blockedG), + ) if (statistics >= PhiK2K0) { val k0 = cacheWhen(PhiK2K0K1)( - ctx, this.k0(ctx, phi, mu, k2, blockedG, ibs0(ctx, blockedG, mu))) + ctx, + 
this.k0(ctx, phi, mu, k2, blockedG, ibs0(ctx, blockedG, mu)), + ) if (statistics >= PhiK2K0K1) { val k1 = 1.0 - (k2 + k0) Result(phi, k0, k1, k2) @@ -201,11 +222,7 @@ case class PCRelate( Result(phi, null, null, null) } - /** - * {@code g} is variant by sample - * {@code pcs} is sample by numPCs - * - **/ + /** {@code g} is variant by sample {@code pcs} is sample by numPCs */ private[methods] def mu(ctx: ExecuteContext, blockedG: M, pcs: BDM[Double]): M = { import breeze.linalg._ @@ -221,7 +238,7 @@ case class PCRelate( private[methods] def phi(ctx: ExecuteContext, mu: M, variance: M, g: M): M = { val centeredAF = BlockMatrix.map2 { (g, mu) => if (java.lang.Double.isNaN(mu)) 0.0 else g / 2 - mu - } (g, mu) + }(g, mu) val stddev = variance.sqrt() @@ -230,14 +247,16 @@ case class PCRelate( private[methods] def ibs0(ctx: ExecuteContext, g: M, mu: M): M = { val homalt = - BlockMatrix.map2 { (g, mu) => - if (java.lang.Double.isNaN(mu) || g != 2.0) 0.0 else 1.0 - } (g, mu) + BlockMatrix.map2((g, mu) => if (java.lang.Double.isNaN(mu) || g != 2.0) 0.0 else 1.0)( + g, + mu, + ) val homref = - BlockMatrix.map2 { (g, mu) => - if (java.lang.Double.isNaN(mu) || g != 0.0) 0.0 else 1.0 - } (g, mu) + BlockMatrix.map2((g, mu) => if (java.lang.Double.isNaN(mu) || g != 0.0) 0.0 else 1.0)( + g, + mu, + ) val temp = writeRead(ctx, homalt.T.dot(homref)) @@ -246,17 +265,20 @@ case class PCRelate( private[methods] def k2(ctx: ExecuteContext, phi: M, mu: M, variance: M, g: M): M = { val twoPhi_ii = phi.diagonal().map(2.0 * _) - val normalizedGD = g.map2WithIndex(mu, { case (_, i, g, mu) => - if (java.lang.Double.isNaN(mu)) - 0.0 // https://github.com/Bioconductor-mirror/GENESIS/blob/release-3.5/R/pcrelate.R#L391 - else { - val gd = if (g == 0.0) mu - else if (g == 1.0) 0.0 - else 1.0 - mu - - gd - mu * (1.0 - mu) * twoPhi_ii(i.toInt) - } - }) + val normalizedGD = g.map2WithIndex( + mu, + { case (_, i, g, mu) => + if (java.lang.Double.isNaN(mu)) + 0.0 // https://github.com/Bioconductor-mirror/GENESIS/blob/release-3.5/R/pcrelate.R#L391 + else { + val gd = if (g == 0.0) mu + else if (g == 1.0) 0.0 + else 1.0 - mu + + gd - mu * (1.0 - mu) * twoPhi_ii(i.toInt) + } + }, + ) gram(ctx, normalizedGD) / gram(ctx, variance) } @@ -276,6 +298,6 @@ case class PCRelate( 1.0 - 4.0 * phi + k2 else ibs0 / denom - } (phi, denom, k2, ibs0) + }(phi, denom, k2, ibs0) } } diff --git a/hail/src/main/scala/is/hail/methods/PoissonRegression.scala b/hail/src/main/scala/is/hail/methods/PoissonRegression.scala index a45e1e7f2f9..b174616d86d 100644 --- a/hail/src/main/scala/is/hail/methods/PoissonRegression.scala +++ b/hail/src/main/scala/is/hail/methods/PoissonRegression.scala @@ -1,17 +1,17 @@ package is.hail.methods -import breeze.linalg._ import is.hail.HailContext import is.hail.annotations._ import is.hail.backend.ExecuteContext -import is.hail.expr.ir.functions.MatrixToTableFunction import is.hail.expr.ir.{IntArrayBuilder, MatrixValue, TableValue} -import is.hail.types.virtual.{TFloat64, TStruct} -import is.hail.types.{MatrixType, TableType} -import is.hail.rvd.RVDType +import is.hail.expr.ir.functions.MatrixToTableFunction import is.hail.stats._ +import is.hail.types.{MatrixType, TableType} +import is.hail.types.virtual.{TFloat64, TStruct} import is.hail.utils._ +import breeze.linalg._ + case class PoissonRegression( test: String, yField: String, @@ -19,13 +19,17 @@ case class PoissonRegression( covFields: Seq[String], passThrough: Seq[String], maxIterations: Int, - tolerance: Double + tolerance: Double, ) extends MatrixToTableFunction 
{ override def typ(childType: MatrixType): TableType = { val poisRegTest = PoissonRegressionTest.tests(test) val passThroughType = TStruct(passThrough.map(f => f -> childType.rowType.field(f).typ): _*) - TableType(childType.rowKeyStruct ++ passThroughType ++ poisRegTest.schema, childType.rowKey, TStruct.empty) + TableType( + childType.rowKeyStruct ++ passThroughType ++ poisRegTest.schema, + childType.rowKey, + TStruct.empty, + ) } def preservesPartitionCounts: Boolean = true @@ -35,7 +39,8 @@ case class PoissonRegression( val tableType = typ(mv.typ) val newRVDType = tableType.canonicalRVDType - val (y, cov, completeColIdx) = RegressionUtils.getPhenoCovCompleteSamples(mv, yField, covFields.toArray) + val (y, cov, completeColIdx) = + RegressionUtils.getPhenoCovCompleteSamples(mv, yField, covFields.toArray) if (!y.forall(yi => math.floor(yi) == yi && yi >= 0)) fatal(s"For poisson regression, y must be numeric with all values non-negative integers") @@ -47,26 +52,30 @@ case class PoissonRegression( val d = n - k - 1 if (d < 1) - fatal(s"$n samples and ${ k + 1 } ${ plural(k, "covariate") } (including x) implies $d degrees of freedom.") + fatal( + s"$n samples and ${k + 1} ${plural(k, "covariate")} (including x) implies $d degrees of freedom." + ) info(s"poisson_regression_rows: running $test on $n samples for response variable y,\n" - + s" with input variable x, and ${ k } additional ${ plural(k, "covariate") }...") + + s" with input variable x, and $k additional ${plural(k, "covariate")}...") val nullModel = new PoissonRegressionModel(cov, y) - var nullFit = nullModel.fit(None, maxIter=maxIterations, tol=tolerance) + val nullFit = nullModel.fit(None, maxIter = maxIterations, tol = tolerance) if (!nullFit.converged) fatal("Failed to fit poisson regression null model (standard MLE with covariates only): " + ( if (nullFit.exploded) - s"exploded at Newton iteration ${ nullFit.nIter }" + s"exploded at Newton iteration ${nullFit.nIter}" else - "Newton iteration failed to converge")) + "Newton iteration failed to converge" + )) val backend = HailContext.backend val completeColIdxBc = backend.broadcast(completeColIdx) val yBc = backend.broadcast(y) - val XBc = backend.broadcast(new DenseMatrix[Double](n, k + 1, cov.toArray ++ Array.ofDim[Double](n))) + val XBc = + backend.broadcast(new DenseMatrix[Double](n, k + 1, cov.toArray ++ Array.ofDim[Double](n))) val nullFitBc = backend.broadcast(nullFit) val poisRegTestBc = backend.broadcast(poisRegTest) @@ -89,14 +98,24 @@ case class PoissonRegression( val X = XBc.value.copy it.map { ptr => - RegressionUtils.setMeanImputedDoubles(X.data, n * k, completeColIdxBc.value, missingCompleteCols, - ptr, fullRowType, entryArrayType, entryType, entryArrayIdx, fieldIdx) + RegressionUtils.setMeanImputedDoubles( + X.data, + n * k, + completeColIdxBc.value, + missingCompleteCols, + ptr, + fullRowType, + entryArrayType, + entryType, + entryArrayIdx, + fieldIdx, + ) rvb.start(newRVDType.rowType) rvb.startStruct() rvb.addFields(fullRowType, ctx.r, ptr, copiedFieldIndices) poisRegTestBc.value - .test(X, yBc.value, nullFitBc.value, "poisson", maxIter=maxIterations, tol=tolerance) + .test(X, yBc.value, nullFitBc.value, "poisson", maxIter = maxIterations, tol = tolerance) .addToRVB(rvb) rvb.endStruct() rvb.end() diff --git a/hail/src/main/scala/is/hail/methods/Skat.scala b/hail/src/main/scala/is/hail/methods/Skat.scala index 03a995867d6..f8c3ae4b088 100644 --- a/hail/src/main/scala/is/hail/methods/Skat.scala +++ b/hail/src/main/scala/is/hail/methods/Skat.scala @@ -1,56 +1,52 @@ 
package is.hail.methods -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, _} -import breeze.numerics._ import is.hail.HailContext import is.hail.annotations.{Annotation, Region, UnsafeRow} import is.hail.backend.ExecuteContext -import is.hail.expr.ir.functions.MatrixToTableFunction import is.hail.expr.ir.{IntArrayBuilder, MatrixValue, TableValue} -import is.hail.stats.{GeneralizedChiSquaredDistribution, LogisticRegressionModel, RegressionUtils, eigSymD} +import is.hail.expr.ir.functions.MatrixToTableFunction +import is.hail.stats.{ + eigSymD, GeneralizedChiSquaredDistribution, LogisticRegressionModel, RegressionUtils, +} import is.hail.types._ import is.hail.types.virtual.{TFloat64, TInt32, TStruct, Type} import is.hail.utils._ + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, _} +import breeze.numerics._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row -/* -Skat implements the burden test described in: - -Wu MC, Lee S, Cai T, Li Y, Boehnke M, Lin X. -Rare-Variant Association Testing for Sequencing Data with the Sequence Kernel Association Test. -American Journal of Human Genetics. 2011;89(1):82-93. doi:10.1016/j.ajhg.2011.05.029. - -For n samples and a group of m variants, we have: -y = n x 1 vector of phenotypes -X = n x k matrix of covariates including intercept = cov -mu = n x 1 vector of predictions under the null model, linear: mu = Xb, logistic: mu = sigmoid(Xb) -W = m x m diagonal matrix of variant weights -G = n x m matrix of genotypes - -The variance component score statistic in the paper is: -Q = (y - mu).t * G * W * G.t * (y - mu) - -The null distribution of Q is a mixture of independent 1 d.o.f. chi-squared random variables -with coefficients given by the non-zero eigenvalues of n x n matrix -Z * Z.t = sqrt(P_0) * G * W * G.t * sqrt(P_0) -where -P_0 = V - V * X * (X.t * V * X)^-1 * X.t * V -V = n x n diagonal matrix with diagonal elements given by sigmaSq for linear and mu_i * (1 - mu_i) for logistic - -To scale to large n, we exploit that Z * Z.t has the same non-zero eigenvalues as the m x m matrix -Z.t * Z = sqrt(W) * G.t * P_0 * G * sqrt(W) -and express the latter gramian matrix in terms of matrices A and B as follows: -linear: sigmaSq * Z.t * Z = A.t * A - B.t * B, A = G * sqrt(W) B = Q0.t * G * sqrt(W) -logistic: Z.t * Z = A.t * A - B.t * B, A = sqrt(V) * G * sqrt(W) B = C^-1 * X.t * V * G * sqrt(W) -where -Q0 = n x k matrix in QR decomposition of X = Q0 * R -C = k x k Cholesky factor of X.t * V * X = C * C.t - -For each variant, SkatTuple encodes the corresponding summand of Q and columns of A and B. -We compute and group SkatTuples by key. Then, for each key, we compute Q and A.t * A - B.t * B, -the eigenvalues of the latter, and the p-value with the Davies algorithm. -*/ +/* Skat implements the burden test described in: + * + * Wu MC, Lee S, Cai T, Li Y, Boehnke M, Lin X. + * Rare-Variant Association Testing for Sequencing Data with the Sequence Kernel Association Test. + * American Journal of Human Genetics. 2011;89(1):82-93. doi:10.1016/j.ajhg.2011.05.029. 
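For reference, the variance-component score statistic Q described in the SKAT comment above, Q = (y - mu).t * G * W * G.t * (y - mu), can be computed directly with Breeze (already a dependency of this file). The sketch below is illustrative only and is not part of this patch: the object name and the toy values are invented, and the diagonal weight matrix W is represented by its diagonal w.

import breeze.linalg._

// Toy data: n = 4 samples, m = 2 variants (values made up for illustration).
object SkatScoreStatExample {
  def main(args: Array[String]): Unit = {
    val y  = DenseVector(1.0, 0.0, 2.0, 1.0)                              // phenotypes
    val mu = DenseVector(1.0, 1.0, 1.0, 1.0)                              // null-model predictions
    val g  = DenseMatrix((0.0, 1.0), (1.0, 0.0), (2.0, 1.0), (0.0, 0.0))  // n x m genotypes
    val w  = DenseVector(0.5, 2.0)                                        // diagonal of W (variant weights)

    val res   = y - mu        // residuals under the null model
    val gtRes = g.t * res     // G.t * (y - mu), a length-m vector

    // Q = (y - mu).t * G * W * G.t * (y - mu);
    // because W is diagonal this is just a weighted sum of squares of G.t * (y - mu).
    val q = gtRes dot (w *:* gtRes)
    println(s"Q = $q")        // here 0.5 * 1^2 + 2.0 * 1^2 = 2.5
  }
}

This only illustrates the definition of Q; the A.t * A - B.t * B decomposition discussed in the comment is what the code below uses to avoid forming the n x n matrix for large n.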
+ * + * For n samples and a group of m variants, we have: + * y = n x 1 vector of phenotypes X = n x k matrix of covariates including intercept = cov mu = n x + * 1 vector of predictions under the null model, linear: mu = Xb, logistic: mu = sigmoid(Xb) W = m x + * m diagonal matrix of variant weights G = n x m matrix of genotypes + * + * The variance component score statistic in the paper is: + * Q = (y - mu).t * G * W * G.t * (y - mu) + * + * The null distribution of Q is a mixture of independent 1 d.o.f. chi-squared random variables with + * coefficients given by the non-zero eigenvalues of n x n matrix Z * Z.t = sqrt(P_0) * G * W * G.t + * * sqrt(P_0) where P_0 = V - V * X * (X.t * V * X)^-1 * X.t * V V = n x n diagonal matrix with + * diagonal elements given by sigmaSq for linear and mu_i * (1 - mu_i) for logistic + * + * To scale to large n, we exploit that Z * Z.t has the same non-zero eigenvalues as the m x m + * matrix Z.t * Z = sqrt(W) * G.t * P_0 * G * sqrt(W) and express the latter gramian matrix in terms + * of matrices A and B as follows: + * linear: sigmaSq * Z.t * Z = A.t * A - B.t * B, A = G * sqrt(W) B = Q0.t * G * sqrt(W) logistic: + * Z.t * Z = A.t * A - B.t * B, A = sqrt(V) * G * sqrt(W) B = C^-1 * X.t * V * G * sqrt(W) where Q0 + * = n x k matrix in QR decomposition of X = Q0 * R C = k x k Cholesky factor of X.t * V * X = C * + * C.t + * + * For each variant, SkatTuple encodes the corresponding summand of Q and columns of A and B. + * We compute and group SkatTuples by key. Then, for each key, we compute Q and A.t * A - B.t * B, + * the eigenvalues of the latter, and the p-value with the Davies algorithm. */ case class SkatTuple(q: Double, a: BDV[Double], b: BDV[Double]) object Skat { @@ -58,7 +54,8 @@ object Skat { require(st.nonEmpty) val st0 = st(0) - // Holds for all st(i) by construction of linearTuple and logisticTuple, checking st(0) defensively + /* Holds for all st(i) by construction of linearTuple and logisticTuple, checking st(0) + * defensively */ require(st0.a.offset == 0 && st0.a.stride == 1 && st0.b.offset == 0 && st0.b.stride == 1) val m = st.length @@ -116,9 +113,11 @@ object Skat { def computeGramian(st: Array[SkatTuple], useSmallN: Boolean): (Double, BDM[Double]) = if (useSmallN) computeGramianSmallN(st) else computeGramianLargeN(st) - // gramian is the m x m matrix (G * sqrt(W)).t * P_0 * (G * sqrt(W)) which has the same non-zero eigenvalues + /* gramian is the m x m matrix (G * sqrt(W)).t * P_0 * (G * sqrt(W)) which has the same non-zero + * eigenvalues */ // as the n x n matrix in the paper P_0^{1/2} * (G * W * G.t) * P_0^{1/2} - def computePval(q: Double, gramian: BDM[Double], accuracy: Double, iterations: Int): (Double, Int) = { + def computePval(q: Double, gramian: BDM[Double], accuracy: Double, iterations: Int) + : (Double, Int) = { val allEvals = eigSymD.justEigenvalues(gramian) // filter out those eigenvalues below the mean / 100k @@ -131,11 +130,9 @@ object Skat { val s = 0.0 val result = GeneralizedChiSquaredDistribution.cdfReturnExceptions( - q, dof, evals, noncentrality, s, iterations, accuracy + q, dof, evals, noncentrality, s, iterations, accuracy, ) val x = result.value - val nIntegrations = result.nIterations - val converged = result.converged val fault = result.fault val pval = 1 - x @@ -154,7 +151,7 @@ case class Skat( accuracy: Double, iterations: Int, logistic_max_iterations: Int, - logistic_tolerance: Double + logistic_tolerance: Double, ) extends MatrixToTableFunction { assert(logistic || logistic_max_iterations == 0 && 
logistic_tolerance == 0.0) @@ -168,7 +165,8 @@ case class Skat( ("size", TInt32), ("q_stat", TFloat64), ("p_value", TFloat64), - ("fault", TInt32)) + ("fault", TInt32), + ) TableType(skatSchema, FastSeq("id"), TStruct.empty) } @@ -186,14 +184,17 @@ case class Skat( if (iterations <= 0) fatal(s"iterations must be positive, default is 10000, got $iterations") - val (y, cov, completeColIdx) = RegressionUtils.getPhenoCovCompleteSamples(mv, yField, covFields.toArray) + val (y, cov, completeColIdx) = + RegressionUtils.getPhenoCovCompleteSamples(mv, yField, covFields.toArray) val n = y.size val k = cov.cols val d = n - k if (d < 1) - fatal(s"$n samples and $k ${ plural(k, "covariate") } (including intercept) implies $d degrees of freedom.") + fatal( + s"$n samples and $k ${plural(k, "covariate")} (including intercept) implies $d degrees of freedom." + ) if (logistic) { val badVals = y.findAll(yi => yi != 0d && yi != 1d) if (badVals.nonEmpty) @@ -201,7 +202,7 @@ case class Skat( s"sample; found ${badVals.length} ${plural(badVals.length, "violation")} starting with ${badVals(0)}") } - val (keyGsWeightRdd, keyType) = + val (keyGsWeightRdd, _) = computeKeyGsWeightRdd(mv, xField, completeColIdx, keyField, weightField) val backend = HailContext.backend @@ -235,7 +236,8 @@ case class Skat( val size = vsArray.length if (size <= maxSize) { val skatTuples = vsArray.map((linearTuple _).tupled) - val (q, gramian) = Skat.computeGramian(skatTuples, size.toLong * n <= maxEntriesForSmallN) + val (q, gramian) = + Skat.computeGramian(skatTuples, size.toLong * n <= maxEntriesForSmallN) // using q / sigmaSq since Z.t * Z = gramian / sigmaSq val (pval, fault) = Skat.computePval(q / sigmaSq, gramian, accuracy, iterations) @@ -251,26 +253,34 @@ case class Skat( def logisticSkat(): RDD[Row] = { val (sqrtV, res, cinvXtV) = if (k > 0) { - val logRegM = new LogisticRegressionModel(cov, y).fit(maxIter=logistic_max_iterations, tol=logistic_tolerance) + val logRegM = new LogisticRegressionModel(cov, y).fit( + maxIter = logistic_max_iterations, + tol = logistic_tolerance, + ) if (!logRegM.converged) fatal("Failed to fit logistic regression null model (MLE with covariates only): " + ( if (logRegM.exploded) - s"exploded at Newton iteration ${ logRegM.nIter }" + s"exploded at Newton iteration ${logRegM.nIter}" else - "Newton iteration failed to converge")) + "Newton iteration failed to converge" + )) val mu = sigmoid(cov * logRegM.b) val V = mu.map(x => x * (1 - x)) val VX = cov(::, *) *:* V val XtVX = cov.t * VX XtVX.forceSymmetry() var Cinv: BDM[Double] = null - try { + try Cinv = inv(cholesky(XtVX)) - } catch { + catch { case e: MatrixSingularException => - fatal("Singular matrix exception while computing Cholesky factor of X.t * V * X.\n" + e.getMessage) + fatal( + "Singular matrix exception while computing Cholesky factor of X.t * V * X.\n" + e.getMessage + ) case e: NotConvergedException => - fatal("Not converged exception while inverting Cholesky factor of X.t * V * X.\n" + e.getMessage) + fatal( + "Not converged exception while inverting Cholesky factor of X.t * V * X.\n" + e.getMessage + ) } (sqrt(V), y - mu, Cinv * VX.t) } else @@ -283,7 +293,7 @@ case class Skat( def logisticTuple(x: BDV[Double], w: Double): SkatTuple = { val xw = x * math.sqrt(w) val sqrt_q = resBc.value dot xw - SkatTuple(sqrt_q * sqrt_q, xw *:* sqrtVBc.value , CinvXtVBc.value * xw) + SkatTuple(sqrt_q * sqrt_q, xw *:* sqrtVBc.value, CinvXtVBc.value * xw) } keyGsWeightRdd.map { case (key, vs) => @@ -308,12 +318,14 @@ case class Skat( TableValue(ctx, 
tableType.rowType, tableType.key, skatRdd) } - def computeKeyGsWeightRdd(mv: MatrixValue, + def computeKeyGsWeightRdd( + mv: MatrixValue, xField: String, completeColIdx: Array[Int], keyField: String, // returns ((key, [(gs_v, weight_v)]), keyType) - weightField: String): (RDD[(Annotation, Iterable[(BDV[Double], Double)])], Type) = { + weightField: String, + ): (RDD[(Annotation, Iterable[(BDV[Double], Double)])], Type) = { val fullRowType = mv.rvRowPType val keyStructField = fullRowType.field(keyField) @@ -336,24 +348,42 @@ case class Skat( val n = completeColIdx.length val completeColIdxBc = HailContext.backend.broadcast(completeColIdx) - // I believe no `boundary` is needed here because `mapPartitions` calls `run` which calls `cleanupRegions`. - (mv.rvd.mapPartitions { (ctx, it) => it.flatMap { ptr => - val keyIsDefined = fullRowType.isFieldDefined(ptr, keyIndex) - val weightIsDefined = fullRowType.isFieldDefined(ptr, weightIndex) - - if (keyIsDefined && weightIsDefined) { - val weight = Region.loadDouble(fullRowType.loadField(ptr, weightIndex)) - if (weight < 0) - fatal(s"Row weights must be non-negative, got $weight") - val key = Annotation.copy(keyType.virtualType, UnsafeRow.read(keyType, ctx.r, fullRowType.loadField(ptr, keyIndex))) - val data = new Array[Double](n) - - RegressionUtils.setMeanImputedDoubles(data, 0, completeColIdxBc.value, new IntArrayBuilder(), - ptr, fullRowType, entryArrayType, entryType, entryArrayIdx, fieldIdx) - Some(key -> (BDV(data) -> weight)) - } else None - } - }.groupByKey(), keyType.virtualType) + /* I believe no `boundary` is needed here because `mapPartitions` calls `run` which calls + * `cleanupRegions`. */ + ( + mv.rvd.mapPartitions { (ctx, it) => + it.flatMap { ptr => + val keyIsDefined = fullRowType.isFieldDefined(ptr, keyIndex) + val weightIsDefined = fullRowType.isFieldDefined(ptr, weightIndex) + + if (keyIsDefined && weightIsDefined) { + val weight = Region.loadDouble(fullRowType.loadField(ptr, weightIndex)) + if (weight < 0) + fatal(s"Row weights must be non-negative, got $weight") + val key = Annotation.copy( + keyType.virtualType, + UnsafeRow.read(keyType, ctx.r, fullRowType.loadField(ptr, keyIndex)), + ) + val data = new Array[Double](n) + + RegressionUtils.setMeanImputedDoubles( + data, + 0, + completeColIdxBc.value, + new IntArrayBuilder(), + ptr, + fullRowType, + entryArrayType, + entryType, + entryArrayIdx, + fieldIdx, + ) + Some(key -> (BDV(data) -> weight)) + } else None + } + }.groupByKey(), + keyType.virtualType, + ) } } diff --git a/hail/src/main/scala/is/hail/methods/VEP.scala b/hail/src/main/scala/is/hail/methods/VEP.scala index e3c4df79eb8..2cd4d9456b2 100644 --- a/hail/src/main/scala/is/hail/methods/VEP.scala +++ b/hail/src/main/scala/is/hail/methods/VEP.scala @@ -1,6 +1,5 @@ package is.hail.methods -import com.fasterxml.jackson.core.JsonParseException import is.hail.annotations._ import is.hail.backend.ExecuteContext import is.hail.expr._ @@ -15,33 +14,34 @@ import is.hail.types.physical.PType import is.hail.types.virtual._ import is.hail.utils._ import is.hail.variant.{Locus, RegionValueVariant, VariantMethods} -import org.apache.spark.sql.Row -import org.json4s.jackson.JsonMethods -import org.json4s.{Formats, JValue} import scala.collection.JavaConverters._ import scala.collection.mutable +import com.fasterxml.jackson.core.JsonParseException +import org.apache.spark.sql.Row +import org.json4s.{Formats, JValue} +import org.json4s.jackson.JsonMethods + case class VEPConfiguration( command: Array[String], env: Map[String, 
String], - vep_json_schema: TStruct) + vep_json_schema: TStruct, +) object VEP { def readConfiguration(fs: FS, path: String): VEPConfiguration = { - val jv = using(fs.open(path)) { in => - JsonMethods.parse(in) - } + val jv = using(fs.open(path))(in => JsonMethods.parse(in)) implicit val formats: Formats = defaultJSONFormats + new TStructSerializer jv.extract[VEPConfiguration] } - def printContext(w: (String) => Unit) { + def printContext(w: (String) => Unit): Unit = { w("##fileformat=VCFv4.1") w("#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT") } - def printElement(w: (String) => Unit, v: (Locus, IndexedSeq[String])) { + def printElement(w: (String) => Unit, v: (Locus, IndexedSeq[String])): Unit = { val (locus, alleles) = v val sb = new StringBuilder() @@ -69,7 +69,7 @@ object VEP { val rc = proc.waitFor() if (rc != 0) { - fatal(s"VEP command '${ cmd.mkString(" ") }' failed with non-zero exit status $rc\n" + + fatal(s"VEP command '${cmd.mkString(" ")}' failed with non-zero exit status $rc\n" + " VEP Error output:\n" + err.toString) } } @@ -80,10 +80,12 @@ object VEP { val env = pb.environment() confEnv.foreach { case (key, value) => env.put(key, value) } - val (jt, err, proc) = List((Locus("1", 13372), FastSeq("G", "C"))).iterator.pipe(pb, + val (jt, err, proc) = List((Locus("1", 13372), FastSeq("G", "C"))).iterator.pipe( + pb, printContext, printElement, - _ => ()) + _ => (), + ) val csqHeader = jt.flatMap(s => csqHeaderRegex.findFirstMatchIn(s).map(m => m.group(1))) waitFor(proc, err, cmd) @@ -101,11 +103,12 @@ object VEP { new VEP(params, conf) } - def apply(fs: FS, config: String, csq: Boolean, blockSize: Int, tolerateParseError: Boolean): VEP = + def apply(fs: FS, config: String, csq: Boolean, blockSize: Int, tolerateParseError: Boolean) + : VEP = VEP(fs, VEPParameters(config, csq, blockSize, tolerateParseError)) def fromJValue(fs: FS, jv: JValue): VEP = { - log.info(s"vep config json: ${ jv.toString }") + log.info(s"vep config json: ${jv.toString}") implicit val formats: Formats = RelationalFunctions.formats val params = jv.extract[VEPParameters] VEP(fs, params) @@ -124,9 +127,11 @@ class VEP(val params: VEPParameters, conf: VEPConfiguration) extends TableToTabl override def typ(childType: TableType): TableType = { val vepType = if (params.csq) TArray(TString) else vepSignature val globType = if (params.csq) TStruct("vep_csq_header" -> TString) else TStruct.empty - TableType(childType.rowType ++ TStruct("vep" -> vepType, "vep_proc_id" -> procIDType), + TableType( + childType.rowType ++ TStruct("vep" -> vepType, "vep_proc_id" -> procIDType), childType.key, - globType) + globType, + ) } override def execute(ctx: ExecuteContext, tv: TableValue): TableValue = { @@ -141,7 +146,8 @@ class VEP(val params: VEPParameters, conf: VEPConfiguration) extends TableToTabl if (s == "__OUTPUT_FORMAT_FLAG__") if (csq) "--vcf" else "--json" else - s) + s + ) val csqHeader = if (csq) getCSQHeaderDefinition(cmd, localConf.env) else None @@ -175,12 +181,9 @@ class VEP(val params: VEPParameters, conf: VEPConfiguration) extends TableToTabl .zipWithIndex .flatMap { case (block, blockIdx) => val procID = Annotation(partIdx, blockIdx) - val (jt, err, proc) = block.iterator.pipe(pb, - printContext, - printElement, - _ => ()) + val (jt, err, proc) = block.iterator.pipe(pb, printContext, printElement, _ => ()) - val nonStarToOriginalVariant = block.map { case v@(locus, alleles) => + val nonStarToOriginalVariant = block.map { case v @ (locus, alleles) => (locus, alleles.filter(_ != "*")) -> v }.toMap @@ 
-188,37 +191,47 @@ class VEP(val params: VEPParameters, conf: VEPConfiguration) extends TableToTabl .filter(s => !s.isEmpty && s(0) != '#') .flatMap { s => if (csq) { - val vepv@(vepLocus, vepAlleles) = variantFromInput(s) + val vepv @ (vepLocus, vepAlleles) = variantFromInput(s) nonStarToOriginalVariant.get(vepv) match { - case Some(v@(locus, alleles)) => + case Some((locus, alleles)) => val x = csqRegex.findFirstIn(s) val a = x match { case Some(value) => value.substring(4).split(",").toFastSeq case None => - warn(s"No CSQ INFO field for VEP output variant ${ VariantMethods.locusAllelesToString(vepLocus, vepAlleles) }.\nVEP output: $s.") + warn( + s"No CSQ INFO field for VEP output variant ${VariantMethods.locusAllelesToString(vepLocus, vepAlleles)}.\nVEP output: $s." + ) null } Some((Annotation(locus, alleles), a)) case None => - fatal(s"VEP output variant ${ VariantMethods.locusAllelesToString(vepLocus, vepAlleles) } not found in original variants.\nVEP output: $s") + fatal( + s"VEP output variant ${VariantMethods.locusAllelesToString(vepLocus, vepAlleles)} not found in original variants.\nVEP output: $s" + ) } } else { try { val jv = JsonMethods.parse(s) - val a = JSONAnnotationImpex.importAnnotation(jv, localVepSignature, warnContext = warnContext) + val a = JSONAnnotationImpex.importAnnotation( + jv, + localVepSignature, + warnContext = warnContext, + ) val variantString = inputQuery(a).asInstanceOf[String] if (variantString == null) fatal(s"VEP generated null variant string" + s"\n json: $s" + s"\n parsed: $a") - val vepv@(vepLocus, vepAlleles) = variantFromInput(variantString) + val vepv @ (vepLocus, vepAlleles) = variantFromInput(variantString) nonStarToOriginalVariant.get(vepv) match { - case Some(v@(locus, alleles)) => + case Some((locus, alleles)) => Some((Annotation(locus, alleles), a)) case None => - fatal(s"VEP output variant ${ VariantMethods.locusAllelesToString(vepLocus, vepAlleles) } not found in original variants.\nVEP output: $s") + fatal( + s"VEP output variant ${VariantMethods.locusAllelesToString(vepLocus, vepAlleles)} not found in original variants.\nVEP output: $s" + ) } } catch { case e: JsonParseException if localTolerateParseError => @@ -240,9 +253,11 @@ class VEP(val params: VEPParameters, conf: VEPConfiguration) extends TableToTabl val vepType: Type = if (params.csq) TArray(TString) else vepSignature - val vepRVDType = prev.typ.copy(rowType = prev.rowPType - .appendKey("vep", PType.canonical(vepType)) - .appendKey("vep_proc_id", PType.canonical(procIDType, true, true))) + val vepRVDType = prev.typ.copy(rowType = + prev.rowPType + .appendKey("vep", PType.canonical(vepType)) + .appendKey("vep_proc_id", PType.canonical(procIDType, true, true)) + ) val vepRowType = vepRVDType.rowType @@ -263,7 +278,8 @@ class VEP(val params: VEPParameters, conf: VEPConfiguration) extends TableToTabl rvb.end() } - }) + }, + ) val globalValue = if (params.csq) @@ -272,15 +288,11 @@ class VEP(val params: VEPParameters, conf: VEPConfiguration) extends TableToTabl Row() val newTT = typ(tv.typ) - TableValue(ctx, - newTT, - BroadcastRow(ctx, globalValue, newTT.globalType), - vepRVD) + TableValue(ctx, newTT, BroadcastRow(ctx, globalValue, newTT.globalType), vepRVD) } - override def toJValue: JValue = { + override def toJValue: JValue = decomposeWithName(params, "VEP")(RelationalFunctions.formats) - } override def hashCode(): Int = params.hashCode() diff --git a/hail/src/main/scala/is/hail/misc/BGZipBlocks.scala b/hail/src/main/scala/is/hail/misc/BGZipBlocks.scala index 
d364bfce4e4..e5461cd7c11 100644 --- a/hail/src/main/scala/is/hail/misc/BGZipBlocks.scala +++ b/hail/src/main/scala/is/hail/misc/BGZipBlocks.scala @@ -1,14 +1,14 @@ package is.hail.misc -import java.io.InputStream - import is.hail.io.compress.BGzipInputStream import is.hail.io.fs.FS +import java.io.InputStream + object BGZipBlocks { - //Print block starts of block gzip (bgz) file - def apply(fs: FS, file: String) { - var buf = new Array[Byte](64 * 1024) + // Print block starts of block gzip (bgz) file + def apply(fs: FS, file: String): Unit = { + val buf = new Array[Byte](64 * 1024) // position of 'buf[0]' in input stream var bufPos = 0L @@ -16,7 +16,7 @@ object BGZipBlocks { var bufSize = 0 var posInBuf = 0 - def fillBuf(is: InputStream) { + def fillBuf(is: InputStream): Unit = { val newSize = bufSize - posInBuf assert(newSize >= 0) @@ -25,7 +25,7 @@ object BGZipBlocks { bufSize = newSize posInBuf = 0 - def f() { + def f(): Unit = { val needed = buf.length - bufSize if (needed > 0) { val result = is.read(buf, bufSize, needed) diff --git a/hail/src/main/scala/is/hail/package.scala b/hail/src/main/scala/is/hail/package.scala index 9c9de17b03d..c281141543f 100644 --- a/hail/src/main/scala/is/hail/package.scala +++ b/hail/src/main/scala/is/hail/package.scala @@ -6,43 +6,63 @@ package object hail { private object HailBuildInfo { - import java.util.Properties - import is.hail.utils._ + import java.util.Properties + val ( - hail_build_user: String, hail_revision: String, - hail_branch: String, - hail_build_date: String, hail_spark_version: String, - hail_pip_version: String) = { - - loadFromResource[(String, String, String, String, String, String)]("build-info.properties") { + hail_pip_version: String, + hail_build_configuration: BuildConfiguration, + ) = + loadFromResource[(String, String, String, BuildConfiguration)]("build-info.properties") { (is: InputStream) => val unknownProp = "" val props = new Properties() props.load(is) ( - props.getProperty("user", unknownProp), props.getProperty("revision", unknownProp), - props.getProperty("branch", unknownProp), - props.getProperty("date", unknownProp), props.getProperty("sparkVersion", unknownProp), - props.getProperty("hailPipVersion", unknownProp) - ) + props.getProperty("hailPipVersion", unknownProp), { + val c = props.getProperty("hailBuildConfiguration", "debug") + BuildConfiguration.parseString(c).getOrElse( + throw new IllegalArgumentException( + s"Illegal 'hailBuildConfiguration' entry in 'build-info.properties': '$c'." 
+ ) + ) + }, + ) } - } } - val HAIL_BUILD_USER = HailBuildInfo.hail_build_user val HAIL_REVISION = HailBuildInfo.hail_revision - val HAIL_BRANCH = HailBuildInfo.hail_branch - val HAIL_BUILD_DATE = HailBuildInfo.hail_build_date val HAIL_SPARK_VERSION = HailBuildInfo.hail_spark_version val HAIL_PIP_VERSION = HailBuildInfo.hail_pip_version // FIXME: probably should use tags or something to choose English name val HAIL_PRETTY_VERSION = HAIL_PIP_VERSION + "-" + HAIL_REVISION.substring(0, 12) + val HAIL_BUILD_CONFIGURATION = HailBuildInfo.hail_build_configuration +} + +sealed trait BuildConfiguration extends Product with Serializable { + def isDebug: Boolean +} + +object BuildConfiguration { + case object Release extends BuildConfiguration { + override def isDebug: Boolean = false + } + + case object Debug extends BuildConfiguration { + override def isDebug: Boolean = true + } + + def parseString(c: String): Option[BuildConfiguration] = + c match { + case "release" => Some(BuildConfiguration.Release) + case "debug" => Some(BuildConfiguration.Debug) + case _ => None + } } diff --git a/hail/src/main/scala/is/hail/rvd/AbstractRVDSpec.scala b/hail/src/main/scala/is/hail/rvd/AbstractRVDSpec.scala index 384784a8785..e6d83d4333d 100644 --- a/hail/src/main/scala/is/hail/rvd/AbstractRVDSpec.scala +++ b/hail/src/main/scala/is/hail/rvd/AbstractRVDSpec.scala @@ -3,9 +3,12 @@ package is.hail.rvd import is.hail.annotations._ import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.compatibility +import is.hail.expr.{ir, JSONAnnotationImpex} +import is.hail.expr.ir.{ + flatMapIR, IR, Literal, PartitionNativeReader, PartitionZippedIndexedNativeReader, + PartitionZippedNativeReader, ReadPartition, Ref, ToStream, +} import is.hail.expr.ir.lowering.{TableStage, TableStageDependency} -import is.hail.expr.ir.{IR, Literal, PartitionNativeReader, PartitionZippedIndexedNativeReader, PartitionZippedNativeReader, ReadPartition, ToStream} -import is.hail.expr.{JSONAnnotationImpex, ir} import is.hail.io._ import is.hail.io.fs.FS import is.hail.io.index.{InternalNodeBuilder, LeafNodeBuilder} @@ -14,36 +17,41 @@ import is.hail.types.encoded.ETypeSerializer import is.hail.types.physical._ import is.hail.types.virtual._ import is.hail.utils._ + import org.apache.spark.TaskContext import org.apache.spark.sql.Row -import org.json4s.jackson.{JsonMethods, Serialization} import org.json4s.{DefaultFormats, Formats, JValue, ShortTypeHints} +import org.json4s.jackson.{JsonMethods, Serialization} object AbstractRVDSpec { - implicit val formats: Formats = new DefaultFormats() { - override val typeHints = ShortTypeHints(List( - classOf[AbstractRVDSpec], - classOf[OrderedRVDSpec2], - classOf[IndexedRVDSpec2], - classOf[IndexSpec2], - classOf[compatibility.OrderedRVDSpec], - classOf[compatibility.IndexedRVDSpec], - classOf[compatibility.IndexSpec], - classOf[compatibility.UnpartitionedRVDSpec], - classOf[AbstractTypedCodecSpec], - classOf[TypedCodecSpec]), - typeHintFieldName = "name") + BufferSpec.shortTypeHints - } + - new TStructSerializer + - new TypeSerializer + - new PTypeSerializer + - new RVDTypeSerializer + - new ETypeSerializer + implicit val formats: Formats = + new DefaultFormats() { + override val typeHints = ShortTypeHints( + List( + classOf[AbstractRVDSpec], + classOf[OrderedRVDSpec2], + classOf[IndexedRVDSpec2], + classOf[IndexSpec2], + classOf[compatibility.OrderedRVDSpec], + classOf[compatibility.IndexedRVDSpec], + classOf[compatibility.IndexSpec], + classOf[compatibility.UnpartitionedRVDSpec], + 
classOf[AbstractTypedCodecSpec], + classOf[TypedCodecSpec], + ), + typeHintFieldName = "name", + ) + BufferSpec.shortTypeHints + } + + new TStructSerializer + + new TypeSerializer + + new PTypeSerializer + + new RVDTypeSerializer + + new ETypeSerializer def read(fs: FS, path: String): AbstractRVDSpec = { try { val metadataFile = path + "/metadata.json.gz" - using(fs.open(metadataFile)) { in => JsonMethods.parse(in) } + using(fs.open(metadataFile))(in => JsonMethods.parse(in)) .transformField { case ("orvdType", value) => ("rvdType", value) } // ugh .extract[AbstractRVDSpec] } catch { @@ -58,7 +66,7 @@ object AbstractRVDSpec { path: String, rowType: PStruct, bufferSpec: BufferSpec, - rows: IndexedSeq[Annotation] + rows: IndexedSeq[Annotation], ): Array[FileWriteMetadata] = { val fs = execCtx.fs val partsPath = path + "/parts" @@ -73,16 +81,19 @@ object AbstractRVDSpec { val (part0Count, bytesWritten) = using(fs.create(partsPath + "/" + filePath)) { os => using(RVDContext.default(execCtx.r.pool)) { ctx => - val rvb = ctx.rvb RichContextRDDRegionValue.writeRowsPartition(codecSpec.buildEncoder(execCtx, rowType))( ctx, rows.iterator.map { a => rowType.unstagedStoreJavaObject(execCtx.stateManager, a, ctx.r) - }, os, null) + }, + os, + null, + ) } } - val spec = MakeRVDSpec(codecSpec, Array(filePath), RVDPartitioner.unkeyed(execCtx.stateManager, 1)) + val spec = + MakeRVDSpec(codecSpec, Array(filePath), RVDPartitioner.unkeyed(execCtx.stateManager, 1)) spec.write(fs, path) Array(FileWriteMetadata(path, part0Count, bytesWritten)) @@ -98,24 +109,26 @@ object AbstractRVDSpec { filterIntervals: Boolean, requestedType: TStruct, requestedKey: IndexedSeq[String], - uidFieldName: String + uidFieldName: String, ): IR => TableStage = { require(specRight.key.isEmpty) val partitioner = specLeft.partitioner(ctx.stateManager) newPartitioner match { case None => - val reader = PartitionZippedNativeReader( PartitionNativeReader(specLeft.typedCodecSpec, uidFieldName), - PartitionNativeReader(specRight.typedCodecSpec, uidFieldName)) + PartitionNativeReader(specRight.typedCodecSpec, uidFieldName), + ) val leftParts = specLeft.absolutePartPaths(pathLeft) val rightParts = specRight.absolutePartPaths(pathRight) assert(leftParts.length == rightParts.length) val contextsValue: IndexedSeq[Any] = (leftParts, rightParts, leftParts.indices) .zipped - .map { (path1, path2, partIdx) => Row(Row(partIdx.toLong, path1), Row(partIdx.toLong, path2)) } + .map { (path1, path2, partIdx) => + Row(Row(partIdx.toLong, path1), Row(partIdx.toLong, path2)) + } val ctxIR = ToStream(Literal(TArray(reader.contextType), contextsValue)) @@ -128,12 +141,13 @@ object AbstractRVDSpec { partitioner.coarsen(requestedKey.length), TableStageDependency.none, ctxIR, - ReadPartition(_, requestedType, reader)) + ReadPartition(_, requestedType, reader), + ) } case Some(np) => val (indexSpecLeft, indexSpecRight) = (specLeft, specRight) match { - case (l: Indexed, r: Indexed) => (l.indexSpec, r.indexSpec) + case (l: Indexed, r: Indexed) => (l.indexSpec, r.indexSpec) case _ => throw new RuntimeException(s"attempted to read unindexed table as indexed") } @@ -143,32 +157,39 @@ object AbstractRVDSpec { val extendedNewPartitioner = np.extendKey(partitioner.kType) val tmpPartitioner = extendedNewPartitioner.intersect(partitioner) - val partKeyPrefix = tmpPartitioner.kType.fieldNames.slice(0, requestedKey.length).toIndexedSeq + val partKeyPrefix = + tmpPartitioner.kType.fieldNames.slice(0, requestedKey.length).toIndexedSeq assert(requestedKey == partKeyPrefix, 
s"$requestedKey != $partKeyPrefix") val reader = PartitionZippedIndexedNativeReader( - specLeft.typedCodecSpec, specRight.typedCodecSpec, - indexSpecLeft, indexSpecRight, - specLeft.key, uidFieldName) + specLeft.typedCodecSpec, + specRight.typedCodecSpec, + indexSpecLeft, + indexSpecRight, + specLeft.key, + uidFieldName, + ) val absPathLeft = pathLeft val absPathRight = pathRight val partsAndIntervals: IndexedSeq[(String, Interval)] = if (specLeft.key.isEmpty) { - specLeft.partFiles.map { p => (p, null) } + specLeft.partFiles.map(p => (p, null)) } else { val partFiles = specLeft.partFiles - tmpPartitioner.rangeBounds.map { b => (partFiles(partitioner.lowerBoundInterval(b)), b) } + tmpPartitioner.rangeBounds.map(b => (partFiles(partitioner.lowerBoundInterval(b)), b)) } val kSize = specLeft.key.size - val contextsValues: IndexedSeq[Row] = partsAndIntervals.zipWithIndex.map { case ((partPath, interval), partIdx) => - Row( - partIdx.toLong, - s"${ absPathLeft }/parts/${ partPath }", - s"${ absPathRight }/parts/${ partPath }", - s"${ absPathLeft }/${ indexSpecLeft.relPath }/${ partPath }.idx", - RVDPartitioner.intervalToIRRepresentation(interval, kSize)) - } + val contextsValues: IndexedSeq[Row] = + partsAndIntervals.zipWithIndex.map { case ((partPath, interval), partIdx) => + Row( + partIdx.toLong, + s"$absPathLeft/parts/$partPath", + s"$absPathRight/parts/$partPath", + s"$absPathLeft/${indexSpecLeft.relPath}/$partPath.idx", + RVDPartitioner.intervalToIRRepresentation(interval, kSize), + ) + } val contexts = ir.ToStream(ir.Literal(TArray(reader.contextType), contextsValues)) @@ -180,7 +201,8 @@ object AbstractRVDSpec { tmpPartitioner.coarsen(requestedKey.length), TableStageDependency.none, contexts, - body) + body, + ) if (filterIntervals) ts.repartitionNoShuffle(ctx, partitioner, dropEmptyPartitions = true) else @@ -217,13 +239,13 @@ abstract class AbstractRVDSpec { requestedType: TableType, uidFieldName: String, newPartitioner: Option[RVDPartitioner] = None, - filterIntervals: Boolean = false + filterIntervals: Boolean = false, ): IR => TableStage = newPartitioner match { case Some(_) => fatal("attempted to read unindexed data as indexed") case None => val part = partitioner(ctx.stateManager) if (!part.kType.fieldNames.startsWith(requestedType.key)) - fatal(s"Error while reading table ${ path }: legacy table written without key." + + fatal(s"Error while reading table $path: legacy table written without key." 
+ s"\n Read and write with version 0.2.70 or earlier") val rSpec = typedCodecSpec @@ -233,7 +255,8 @@ abstract class AbstractRVDSpec { TArray(ctxType), absolutePartPaths(path).zipWithIndex.map { case (x, i) => Row(i.toLong, x) - }.toFastSeq)) + }.toFastSeq, + )) val body = (ctx: IR) => ir.ReadPartition(ctx, requestedType.rowType, ir.PartitionNativeReader(rSpec, uidFieldName)) @@ -244,15 +267,15 @@ abstract class AbstractRVDSpec { part.coarsen(part.kType.fieldNames.takeWhile(requestedType.rowType.hasField).length), TableStageDependency.none, contexts, - body) + body, + ) } - def write(fs: FS, path: String) { + def write(fs: FS, path: String): Unit = using(fs.create(path + "/metadata.json.gz")) { out => import AbstractRVDSpec.formats Serialization.write(this, out) } - } } trait AbstractIndexSpec { @@ -268,17 +291,19 @@ trait AbstractIndexSpec { def offsetField: Option[String] = None - def offsetFieldIndex: Option[Int] = offsetField.map(f => annotationType.asInstanceOf[TStruct].fieldIdx(f)) + def offsetFieldIndex: Option[Int] = + offsetField.map(f => annotationType.asInstanceOf[TStruct].fieldIdx(f)) def types: (Type, Type) = (keyType, annotationType) } -case class IndexSpec2(_relPath: String, +case class IndexSpec2( + _relPath: String, _leafCodec: AbstractTypedCodecSpec, _internalNodeCodec: AbstractTypedCodecSpec, _keyType: Type, _annotationType: Type, - _offsetField: Option[String] = None + _offsetField: Option[String] = None, ) extends AbstractIndexSpec { def relPath: String = _relPath @@ -293,25 +318,40 @@ case class IndexSpec2(_relPath: String, override def offsetField: Option[String] = _offsetField } - object IndexSpec { - def fromKeyAndValuePTypes(relPath: String, keyPType: PType, annotationPType: PType, offsetFieldName: Option[String]): AbstractIndexSpec = { + def fromKeyAndValuePTypes( + relPath: String, + keyPType: PType, + annotationPType: PType, + offsetFieldName: Option[String], + ): AbstractIndexSpec = { val leafType = LeafNodeBuilder.typ(keyPType, annotationPType) val leafNodeSpec = TypedCodecSpec(leafType, BufferSpec.default) val internalType = InternalNodeBuilder.typ(keyPType, annotationPType) val internalNodeSpec = TypedCodecSpec(internalType, BufferSpec.default) - IndexSpec2(relPath, leafNodeSpec, internalNodeSpec, keyPType.virtualType, annotationPType.virtualType, offsetFieldName) + IndexSpec2( + relPath, + leafNodeSpec, + internalNodeSpec, + keyPType.virtualType, + annotationPType.virtualType, + offsetFieldName, + ) } - def emptyAnnotation(relPath: String, keyType: PStruct): AbstractIndexSpec = { + def emptyAnnotation(relPath: String, keyType: PStruct): AbstractIndexSpec = fromKeyAndValuePTypes(relPath, keyType, PCanonicalStruct(required = true), None) - } - def defaultAnnotation(relPath: String, keyType: PStruct, withOffsetField: Boolean = false): AbstractIndexSpec = { + def defaultAnnotation(relPath: String, keyType: PStruct, withOffsetField: Boolean = false) + : AbstractIndexSpec = { val name = "entries_offset" - fromKeyAndValuePTypes(relPath, keyType, PCanonicalStruct(required = true, name -> PInt64Optional), - if (withOffsetField) Some(name) else None) + fromKeyAndValuePTypes( + relPath, + keyType, + PCanonicalStruct(required = true, name -> PInt64Optional), + if (withOffsetField) Some(name) else None, + ) } } @@ -321,23 +361,27 @@ object MakeRVDSpec { partFiles: Array[String], partitioner: RVDPartitioner, indexSpec: AbstractIndexSpec = null, - attrs: Map[String, String] = Map.empty + attrs: Map[String, String] = Map.empty, ): AbstractRVDSpec = RVDSpecMaker(codecSpec, 
partitioner, indexSpec, attrs)(partFiles) } object RVDSpecMaker { - def apply(codecSpec: AbstractTypedCodecSpec, + def apply( + codecSpec: AbstractTypedCodecSpec, partitioner: RVDPartitioner, indexSpec: AbstractIndexSpec = null, - attrs: Map[String, String] = Map.empty): RVDSpecMaker = RVDSpecMaker( + attrs: Map[String, String] = Map.empty, + ): RVDSpecMaker = RVDSpecMaker( codecSpec, partitioner.kType.fieldNames, JSONAnnotationImpex.exportAnnotation( partitioner.rangeBounds.toFastSeq, - partitioner.rangeBoundsType), + partitioner.rangeBoundsType, + ), indexSpec, - attrs) + attrs, + ) } case class RVDSpecMaker( @@ -345,33 +389,35 @@ case class RVDSpecMaker( key: IndexedSeq[String], bounds: JValue, indexSpec: AbstractIndexSpec, - attrs: Map[String, String]) { + attrs: Map[String, String], +) { def apply(partFiles: Array[String]): AbstractRVDSpec = Option(indexSpec) match { case Some(ais) => IndexedRVDSpec2( - key, - codecSpec, - ais, - partFiles, - bounds, - attrs) + key, + codecSpec, + ais, + partFiles, + bounds, + attrs) case None => OrderedRVDSpec2( - key, - codecSpec, - partFiles, - bounds, - attrs - ) + key, + codecSpec, + partFiles, + bounds, + attrs, + ) } } object IndexedRVDSpec2 { - def apply(key: IndexedSeq[String], + def apply( + key: IndexedSeq[String], codecSpec: AbstractTypedCodecSpec, indexSpec: AbstractIndexSpec, partFiles: Array[String], partitioner: RVDPartitioner, - attrs: Map[String, String] + attrs: Map[String, String], ): AbstractRVDSpec = { IndexedRVDSpec2( key, @@ -380,8 +426,10 @@ object IndexedRVDSpec2 { partFiles, JSONAnnotationImpex.exportAnnotation( partitioner.rangeBounds.toFastSeq, - partitioner.rangeBoundsType), - attrs) + partitioner.rangeBoundsType, + ), + attrs, + ) } } @@ -391,7 +439,7 @@ case class IndexedRVDSpec2( _indexSpec: AbstractIndexSpec, _partFiles: Array[String], _jRangeBounds: JValue, - _attrs: Map[String, String] + _attrs: Map[String, String], ) extends AbstractRVDSpec with Indexed { // some lagacy OrderedRVDSpec2 were written out without the toplevel encoder required @@ -409,8 +457,15 @@ case class IndexedRVDSpec2( def partitioner(sm: HailStateManager): RVDPartitioner = { val keyType = codecSpec2.encodedVirtualType.asInstanceOf[TStruct].select(key)._1 val rangeBoundsType = TArray(TInterval(keyType)) - new RVDPartitioner(sm, keyType, - JSONAnnotationImpex.importAnnotation(_jRangeBounds, rangeBoundsType, padNulls = false).asInstanceOf[IndexedSeq[Interval]]) + new RVDPartitioner( + sm, + keyType, + JSONAnnotationImpex.importAnnotation( + _jRangeBounds, + rangeBoundsType, + padNulls = false, + ).asInstanceOf[IndexedSeq[Interval]], + ) } def partFiles: Array[String] = _partFiles @@ -425,51 +480,88 @@ case class IndexedRVDSpec2( requestedType: TableType, uidFieldName: String, newPartitioner: Option[RVDPartitioner] = None, - filterIntervals: Boolean = false + filterIntervals: Boolean = false, ): IR => TableStage = newPartitioner match { case Some(np) => - val part = partitioner(ctx.stateManager) + /* ensure the old and new partitioners have the same key, and ensure the new partitioner is + * strict */ val extendedNP = np.extendKey(part.kType) - val tmpPartitioner = part.intersect(extendedNP) assert(key.nonEmpty) - val rSpec = typedCodecSpec - val reader = ir.PartitionNativeReaderIndexed(rSpec, indexSpec, part.kType.fieldNames, uidFieldName) - - val absPath = path - val partPaths = tmpPartitioner.rangeBounds.map { b => partFiles(part.lowerBoundInterval(b)) } - + val reader = ir.PartitionNativeReaderIndexed( + typedCodecSpec, + indexSpec, + 
part.kType.fieldNames, + uidFieldName, + ) - val kSize = part.kType.size - absolutePartPaths(path) - assert(tmpPartitioner.rangeBounds.size == partPaths.length) - val contextsValues: IndexedSeq[Row] = tmpPartitioner.rangeBounds.map { interval => - val partIdx = part.lowerBoundInterval(interval) - val partPath = partFiles(partIdx) + def makeCtx(oldPartIdx: Int, newPartIdx: Int): Row = { + val oldInterval = part.rangeBounds(oldPartIdx) + val partFile = partFiles(oldPartIdx) + val intersectionInterval = + extendedNP.rangeBounds(newPartIdx) + .intersect(extendedNP.kord, oldInterval).get Row( - partIdx.toLong, - s"${ absPath }/parts/${ partPath }", - s"${ absPath }/${ indexSpec.relPath }/${ partPath }.idx", - RVDPartitioner.intervalToIRRepresentation(interval, kSize)) + oldPartIdx.toLong, + s"$path/parts/$partFile", + s"$path/${indexSpec.relPath}/$partFile.idx", + RVDPartitioner.intervalToIRRepresentation(intersectionInterval, part.kType.size), + ) } - assert(TArray(reader.contextType).typeCheck(contextsValues)) - - val contexts = ir.ToStream(ir.Literal(TArray(reader.contextType), contextsValues)) + val (nestedContexts, newPartitioner) = if (filterIntervals) { + /* We want to filter to intervals in newPartitioner, while preserving the old partitioning, + * but dropping any partitions we know would be empty. So we construct a map from old + * partitions to the range of overlapping new partitions, dropping any with an empty range. */ + val contextsAndBounds = for { + (oldInterval, oldPartIdx) <- part.rangeBounds.toFastSeq.zipWithIndex + overlapRange = extendedNP.queryInterval(oldInterval) + if overlapRange.nonEmpty + } yield { + val ctxs = overlapRange.map(newPartIdx => makeCtx(oldPartIdx, newPartIdx)) + // the interval spanning all overlapping filter intervals + val newInterval = Interval( + extendedNP.rangeBounds(overlapRange.head).left, + extendedNP.rangeBounds(overlapRange.last).right, + ) + ( + ctxs, + // Shrink oldInterval to the rows filtered to. + // By construction we know oldInterval and newInterval overlap + oldInterval.intersect(extendedNP.kord, newInterval).get, + ) + } + val (nestedContexts, newRangeBounds) = contextsAndBounds.unzip + + (nestedContexts, new RVDPartitioner(part.sm, part.kType, newRangeBounds)) + } else { + /* We want to use newPartitioner as the partitioner, dropping any rows not contained in any + * new partition. So we construct a map from new partitioner to the range of overlapping old + * partitions. 
*/ + val nestedContexts = + extendedNP.rangeBounds.toFastSeq.zipWithIndex.map { case (newInterval, newPartIdx) => + val overlapRange = part.queryInterval(newInterval) + overlapRange.map(oldPartIdx => makeCtx(oldPartIdx, newPartIdx)) + } + + (nestedContexts, extendedNP) + } - val body = (ctx: IR) => ir.ReadPartition(ctx, requestedType.rowType, reader) + assert(TArray(TArray(reader.contextType)).typeCheck(nestedContexts)) { (globals: IR) => - val ts = TableStage( + TableStage( globals, - tmpPartitioner, + newPartitioner, TableStageDependency.none, - contexts, - body) - if (filterIntervals) ts.repartitionNoShuffle(ctx, part, dropEmptyPartitions = true) - else ts.repartitionNoShuffle(ctx, extendedNP) + contexts = ir.ToStream(ir.Literal(TArray(TArray(reader.contextType)), nestedContexts)), + body = (ctxs: Ref) => + flatMapIR(ToStream(ctxs, true)) { ctx => + ir.ReadPartition(ctx, requestedType.rowType, reader) + }, + ) } case None => @@ -482,7 +574,7 @@ case class OrderedRVDSpec2( _codecSpec: AbstractTypedCodecSpec, _partFiles: Array[String], _jRangeBounds: JValue, - _attrs: Map[String, String] + _attrs: Map[String, String], ) extends AbstractRVDSpec { // some legacy OrderedRVDSpec2 were written out without the toplevel encoder required @@ -496,8 +588,15 @@ case class OrderedRVDSpec2( def partitioner(sm: HailStateManager): RVDPartitioner = { val keyType = codecSpec2.encodedVirtualType.asInstanceOf[TStruct].select(key)._1 val rangeBoundsType = TArray(TInterval(keyType)) - new RVDPartitioner(sm, keyType, - JSONAnnotationImpex.importAnnotation(_jRangeBounds, rangeBoundsType, padNulls = false).asInstanceOf[IndexedSeq[Interval]]) + new RVDPartitioner( + sm, + keyType, + JSONAnnotationImpex.importAnnotation( + _jRangeBounds, + rangeBoundsType, + padNulls = false, + ).asInstanceOf[IndexedSeq[Interval]], + ) } def partFiles: Array[String] = _partFiles diff --git a/hail/src/main/scala/is/hail/rvd/KeyedRVD.scala b/hail/src/main/scala/is/hail/rvd/KeyedRVD.scala index 4346c48b1bd..15063b71064 100644 --- a/hail/src/main/scala/is/hail/rvd/KeyedRVD.scala +++ b/hail/src/main/scala/is/hail/rvd/KeyedRVD.scala @@ -2,9 +2,9 @@ package is.hail.rvd import is.hail.annotations._ import is.hail.backend.ExecuteContext +import is.hail.sparkextras._ import is.hail.types.physical.PStruct import is.hail.types.virtual.TInterval -import is.hail.sparkextras._ import is.hail.utils._ import scala.collection.generic.Growable @@ -15,23 +15,27 @@ class KeyedRVD(val rvd: RVD, val key: Int) { val virtType = RVDType(realType.rowType, realType.key.take(key)) val (kType, _) = rvd.rowType.select(virtType.key) - private def checkJoinCompatability(right: KeyedRVD) { - if (!(kType isIsomorphicTo right.kType)) + private def checkJoinCompatability(right: KeyedRVD): Unit = { + if (!(kType isJoinableWith right.kType)) fatal( s"""Incompatible join keys. 
Keys must have same length and types, in order: - | Left join key type: ${ kType.toString } - | Right join key type: ${ right.kType.toString } - """.stripMargin) + | Left join key type: ${kType.toString} + | Right join key type: ${right.kType.toString} + """.stripMargin + ) } - private def checkLeftIntervalJoinCompatability(right: KeyedRVD) { - if (!(kType.size == 1 && right.kType.size == 1 - && kType.types(0) == right.kType.types(0).asInstanceOf[TInterval].pointType)) + private def checkLeftIntervalJoinCompatability(right: KeyedRVD): Unit = { + if ( + !(kType.size == 1 && right.kType.size == 1 + && kType.types(0) == right.kType.types(0).asInstanceOf[TInterval].pointType) + ) fatal( s"""Incompatible join keys in left interval join: - | Left join key type: ${ kType.toString } - | Right join key type: ${ right.kType.toString } - """.stripMargin) + | Left join key type: ${kType.toString} + | Right join key type: ${right.kType.toString} + """.stripMargin + ) } // 'joinedType.key' must be the join key, followed by the remaining left key, @@ -43,7 +47,7 @@ class KeyedRVD(val rvd: RVD, val key: Int) { joinType: String, joiner: (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue], joinedType: RVDType, - ctx: ExecuteContext + ctx: ExecuteContext, ): RVD = { checkJoinCompatability(right) @@ -56,14 +60,16 @@ class KeyedRVD(val rvd: RVD, val key: Int) { case "right" => rightPart case "inner" => leftPart.intersect(rightPart) case "outer" => RVDPartitioner.generate( - sm, - kType.fieldNames, - realType.kType.virtualType, - leftPart.rangeBounds ++ rightPart.rangeBounds) + sm, + kType.fieldNames, + realType.kType.virtualType, + leftPart.rangeBounds ++ rightPart.rangeBounds, + ) } } val repartitionedLeft = rvd.repartition(ctx, newPartitioner) - val compute: (OrderedRVIterator, OrderedRVIterator, Iterable[RegionValue] with Growable[RegionValue]) => Iterator[JoinedRegionValue] = + val compute + : (OrderedRVIterator, OrderedRVIterator, Iterable[RegionValue] with Growable[RegionValue]) => Iterator[JoinedRegionValue] = (joinType: @unchecked) match { case "inner" => _.innerJoin(_, _) case "left" => _.leftJoin(_, _) @@ -77,7 +83,7 @@ class KeyedRVD(val rvd: RVD, val key: Int) { repartitionedLeft.alignAndZipPartitions( joinedType.copy(key = joinedType.key.take(realType.key.length)), right.rvd, - key + key, ) { (ctx, leftIt, rightIt) => val sideBuffer = ctx.freshRegion() joiner( @@ -85,14 +91,19 @@ class KeyedRVD(val rvd: RVD, val key: Int) { compute( OrderedRVIterator(lTyp, leftIt, ctx, sm), OrderedRVIterator(rTyp, rightIt, ctx, sm), - new RegionValueArrayBuffer(rRowPType, sideBuffer, sm))) + new RegionValueArrayBuffer(rRowPType, sideBuffer, sm), + ), + ) }.extendKeyPreservesPartitioning(ctx, joinedType.key) } def orderedLeftIntervalJoin( executeContext: ExecuteContext, right: KeyedRVD, - joiner: PStruct => (RVDType, (RVDContext, Iterator[Muple[RegionValue, Iterable[RegionValue]]]) => Iterator[RegionValue]) + joiner: PStruct => ( + RVDType, + (RVDContext, Iterator[Muple[RegionValue, Iterable[RegionValue]]]) => Iterator[RegionValue], + ), ): RVD = { checkLeftIntervalJoinCompatability(right) @@ -101,22 +112,25 @@ class KeyedRVD(val rvd: RVD, val key: Int) { val sm = executeContext.stateManager rvd.intervalAlignAndZipPartitions(executeContext, right.rvd) { - t: PStruct => { + t: PStruct => val (newTyp, f) = joiner(t) - (newTyp, (ctx: RVDContext, it: Iterator[RegionValue], intervals: Iterator[RegionValue]) => - f( - ctx, - OrderedRVIterator(lTyp, it, ctx, sm) - .leftIntervalJoin(OrderedRVIterator(rTyp, 
intervals, ctx, sm)))) - } + ( + newTyp, + (ctx: RVDContext, it: Iterator[RegionValue], intervals: Iterator[RegionValue]) => + f( + ctx, + OrderedRVIterator(lTyp, it, ctx, sm) + .leftIntervalJoin(OrderedRVIterator(rTyp, intervals, ctx, sm)), + ), + ) } } def orderedLeftIntervalJoinDistinct( executeContext: ExecuteContext, right: KeyedRVD, - joiner: PStruct => (RVDType, (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue]) + joiner: PStruct => (RVDType, (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue]), ): RVD = { checkLeftIntervalJoinCompatability(right) @@ -125,22 +139,25 @@ class KeyedRVD(val rvd: RVD, val key: Int) { val sm = executeContext.stateManager rvd.intervalAlignAndZipPartitions(executeContext, right.rvd) { - t: PStruct => { + t: PStruct => val (newTyp, f) = joiner(t) - (newTyp, (ctx: RVDContext, it: Iterator[RegionValue], intervals: Iterator[RegionValue]) => - f( - ctx, - OrderedRVIterator(lTyp, it, ctx, sm) - .leftIntervalJoinDistinct(OrderedRVIterator(rTyp, intervals, ctx, sm)))) - } + ( + newTyp, + (ctx: RVDContext, it: Iterator[RegionValue], intervals: Iterator[RegionValue]) => + f( + ctx, + OrderedRVIterator(lTyp, it, ctx, sm) + .leftIntervalJoinDistinct(OrderedRVIterator(rTyp, intervals, ctx, sm)), + ), + ) } } def orderedLeftJoinDistinct( right: KeyedRVD, joiner: (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue], - joinedType: RVDType + joinedType: RVDType, ): RVD = { checkJoinCompatability(right) val lTyp = virtType @@ -150,17 +167,23 @@ class KeyedRVD(val rvd: RVD, val key: Int) { rvd.alignAndZipPartitions( joinedType, right.rvd, - key + key, ) { (ctx, leftIt, rightIt) => joiner( ctx, - OrderedRVIterator(lTyp, leftIt, ctx, sm).leftJoinDistinct(OrderedRVIterator(rTyp, rightIt, ctx, sm))) + OrderedRVIterator(lTyp, leftIt, ctx, sm).leftJoinDistinct(OrderedRVIterator( + rTyp, + rightIt, + ctx, + sm, + )), + ) } } def orderedMerge( right: KeyedRVD, - ctx: ExecuteContext + ctx: ExecuteContext, ): RVD = { checkJoinCompatability(right) require(this.realType.rowType == right.realType.rowType) @@ -170,7 +193,9 @@ class KeyedRVD(val rvd: RVD, val key: Int) { this.realType.rowType, ContextRDD.union( rvd.sparkContext, - Seq(this.rvd.crdd, right.rvd.crdd))) + Seq(this.rvd.crdd, right.rvd.crdd), + ), + ) val ranges = this.rvd.partitioner.coarsenedRangeBounds(key) ++ right.rvd.partitioner.coarsenedRangeBounds(key) @@ -184,7 +209,7 @@ class KeyedRVD(val rvd: RVD, val key: Int) { repartitionedLeft.alignAndZipPartitions( this.virtType, right.rvd, - key + key, ) { (ctx, leftIt, rightIt) => OrderedRVIterator(lType, leftIt, ctx, sm) .merge(OrderedRVIterator(rType, rightIt, ctx, sm)) diff --git a/hail/src/main/scala/is/hail/rvd/PartitionBoundOrdering.scala b/hail/src/main/scala/is/hail/rvd/PartitionBoundOrdering.scala index c238578155e..77a5e14c22d 100644 --- a/hail/src/main/scala/is/hail/rvd/PartitionBoundOrdering.scala +++ b/hail/src/main/scala/is/hail/rvd/PartitionBoundOrdering.scala @@ -1,16 +1,15 @@ package is.hail.rvd -import is.hail.annotations.{ExtendedOrdering, IntervalEndpointOrdering, SafeRow} +import is.hail.annotations.{ExtendedOrdering, IntervalEndpointOrdering} import is.hail.backend.{ExecuteContext, HailStateManager} -import is.hail.types.physical.{PStruct, PType} import is.hail.types.virtual._ import is.hail.utils._ + import org.apache.spark.sql.Row object PartitionBoundOrdering { - def apply(ctx: ExecuteContext, _kType: Type): ExtendedOrdering = { + def apply(ctx: ExecuteContext, _kType: Type): ExtendedOrdering = 
apply(ctx.stateManager, _kType) - } def apply(sm: HailStateManager, _kType: Type): ExtendedOrdering = { val kType = _kType.asInstanceOf[TBaseStruct] @@ -18,7 +17,6 @@ object PartitionBoundOrdering { new ExtendedOrdering { outer => - val missingEqual = true override def compareNonnull(x: T, y: T): Int = { @@ -128,7 +126,8 @@ object PartitionBoundOrdering { // Returns true if for any rows r1 and r2 with r1 < x and r2 > y, // the length of the largest common prefix of r1 and r2 is less than // or equal to 'allowedOverlap' - override def lteqWithOverlap(allowedOverlap: Int)(x: IntervalEndpoint, y: IntervalEndpoint): Boolean = { + override def lteqWithOverlap(allowedOverlap: Int)(x: IntervalEndpoint, y: IntervalEndpoint) + : Boolean = { require(allowedOverlap <= fieldOrd.length) val xp = x val yp = y @@ -147,14 +146,15 @@ object PartitionBoundOrdering { val cl = xpp.length compare ypp.length if (allowedOverlap == l) prefix == l || - (cl < 0 && xp.sign < 0) || - (cl > 0 && yp.sign > 0) || - (cl == 0 && xp.sign <= yp.sign) + (cl < 0 && xp.sign < 0) || + (cl > 0 && yp.sign > 0) || + (cl == 0 && xp.sign <= yp.sign) else (xpp.length <= allowedOverlap + 1 || ypp.length <= allowedOverlap + 1) && ( (cl < 0 && xp.sign < 0) || (cl > 0 && yp.sign > 0) || - (cl == 0 && xp.sign <= yp.sign)) + (cl == 0 && xp.sign <= yp.sign) + ) } } } diff --git a/hail/src/main/scala/is/hail/rvd/RVD.scala b/hail/src/main/scala/is/hail/rvd/RVD.scala index 3cc50270ffb..07756f86f49 100644 --- a/hail/src/main/scala/is/hail/rvd/RVD.scala +++ b/hail/src/main/scala/is/hail/rvd/RVD.scala @@ -2,9 +2,9 @@ package is.hail.rvd import is.hail.HailContext import is.hail.annotations._ -import is.hail.asm4s.{HailClassLoader, theHailClassLoaderForSparkWorkers} -import is.hail.backend.spark.{SparkBackend, SparkTaskContext} +import is.hail.asm4s.{theHailClassLoaderForSparkWorkers, HailClassLoader} import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext} +import is.hail.backend.spark.{SparkBackend, SparkTaskContext} import is.hail.expr.ir.InferPType import is.hail.expr.ir.PruneDeadFields.isSupertype import is.hail.io._ @@ -13,17 +13,17 @@ import is.hail.sparkextras._ import is.hail.types._ import is.hail.types.physical.{PCanonicalStruct, PInt64, PStruct} import is.hail.types.virtual.{TInterval, TStruct} -import is.hail.utils.PartitionCounts.{PCSubsetOffset, getPCSubsetOffset, incrementalPCSubsetOffset} import is.hail.utils._ -import org.apache.commons.lang3.StringUtils +import is.hail.utils.PartitionCounts.{getPCSubsetOffset, incrementalPCSubsetOffset, PCSubsetOffset} + +import scala.reflect.ClassTag + +import java.util + +import org.apache.spark.{Partitioner, SparkContext, TaskContext} import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.sql.Row import org.apache.spark.storage.StorageLevel -import org.apache.spark.{Partitioner, SparkContext, TaskContext} - -import java.util -import scala.language.existentials -import scala.reflect.ClassTag abstract class RVDCoercer(val fullType: RVDType) { final def coerce(typ: RVDType, crdd: ContextRDD[Long]): RVD = { @@ -38,12 +38,12 @@ abstract class RVDCoercer(val fullType: RVDType) { class RVD( val typ: RVDType, val partitioner: RVDPartitioner, - val crdd: ContextRDD[Long] + val crdd: ContextRDD[Long], ) { self => require(crdd.getNumPartitions == partitioner.numPartitions) - require(typ.kType.virtualType isIsomorphicTo partitioner.kType) + require(typ.kType.virtualType isJoinableWith partitioner.kType) // Basic accessors @@ -66,7 +66,6 @@ class RVD( new RVD(newTyp, 
newPartitioner, crdd) } - // Exporting def toRows: RDD[Row] = { @@ -81,22 +80,26 @@ class RVD( def stabilize(ctx: ExecuteContext, enc: AbstractTypedCodecSpec): RDD[Array[Byte]] = { val makeEnc = enc.buildEncoder(ctx, rowPType) - crdd.mapPartitions(it => RegionValue.toBytes(theHailClassLoaderForSparkWorkers, makeEnc, it)).run + crdd.mapPartitions(it => + RegionValue.toBytes(theHailClassLoaderForSparkWorkers, makeEnc, it) + ).run } def encodedRDD(ctx: ExecuteContext, enc: AbstractTypedCodecSpec): RDD[Array[Byte]] = stabilize(ctx, enc) - def keyedEncodedRDD(ctx: ExecuteContext, enc: AbstractTypedCodecSpec, key: IndexedSeq[String] = typ.key): RDD[(Any, Array[Byte])] = { + def keyedEncodedRDD( + ctx: ExecuteContext, + enc: AbstractTypedCodecSpec, + key: IndexedSeq[String] = typ.key, + ): RDD[(Any, Array[Byte])] = { val makeEnc = enc.buildEncoder(ctx, rowPType) val kFieldIdx = typ.copy(key = key).kFieldIdx val localRowPType = rowPType crdd.cmapPartitions { (ctx, it) => val encoder = new ByteArrayEncoder(theHailClassLoaderForSparkWorkers, makeEnc) - TaskContext.get.addTaskCompletionListener[Unit] { _ => - encoder.close() - } + TaskContext.get.addTaskCompletionListener[Unit](_ => encoder.close()) it.map { ptr => val keys: Any = SafeRow.selectFields(localRowPType, ctx.r, ptr)(kFieldIdx) val bytes = encoder.regionValueToBytes(ptr) @@ -109,12 +112,14 @@ class RVD( def enforceKey( execCtx: ExecuteContext, newKey: IndexedSeq[String], - isSorted: Boolean = false + isSorted: Boolean = false, ): RVD = { require(newKey.forall(rowType.hasField)) val sharedPrefixLength = typ.key.zip(newKey).takeWhile { case (l, r) => l == r }.length if (isSorted && sharedPrefixLength == 0 && newKey.nonEmpty) { - throw new IllegalArgumentException(s"$isSorted, $sharedPrefixLength, $newKey, ${ typ }, ${ partitioner }") + throw new IllegalArgumentException( + s"$isSorted, $sharedPrefixLength, $newKey, $typ, $partitioner" + ) } if (sharedPrefixLength == newKey.length) @@ -130,30 +135,35 @@ class RVD( // Key and partitioner manipulation def changeKey( execCtx: ExecuteContext, - newKey: IndexedSeq[String] + newKey: IndexedSeq[String], ): RVD = changeKey(execCtx, newKey, newKey.length) def changeKey( execCtx: ExecuteContext, newKey: IndexedSeq[String], - partitionKey: Int + partitionKey: Int, ): RVD = RVD.coerce(execCtx, typ.copy(key = newKey), partitionKey, this.crdd) def extendKeyPreservesPartitioning( ctx: ExecuteContext, - newKey: IndexedSeq[String] + newKey: IndexedSeq[String], ): RVD = { require(newKey startsWith typ.key) require(newKey.forall(typ.rowType.fieldNames.contains)) val rvdType = typ.copy(key = newKey) - if (RVDPartitioner.isValid(ctx.stateManager, rvdType.kType.virtualType, partitioner.rangeBounds)) + if ( + RVDPartitioner.isValid(ctx.stateManager, rvdType.kType.virtualType, partitioner.rangeBounds) + ) copy(typ = rvdType, partitioner = partitioner.copy(kType = rvdType.kType.virtualType)) else { val adjustedPartitioner = partitioner.strictify() repartition(ctx, adjustedPartitioner) - .copy(typ = rvdType, partitioner = adjustedPartitioner.copy(kType = rvdType.kType.virtualType)) + .copy( + typ = rvdType, + partitioner = adjustedPartitioner.copy(kType = rvdType.kType.virtualType), + ) } } @@ -194,7 +204,8 @@ class RVD( | Previous key: $prevKeyString |This error can occur after a split_multi if the dataset |contains both multiallelic variants and duplicated loci. 
- """.stripMargin) + """.stripMargin + ) } } @@ -204,13 +215,15 @@ class RVD( if (!partitionerBc.value.rangeBounds(i).contains(ord, kUR)) fatal( s"""RVD error! Unexpected key in partition $i - | Range bounds for partition $i: ${ partitionerBc.value.rangeBounds(i) } - | Range of partition IDs for key: [${ partitionerBc.value.lowerBound(kUR) }, ${ partitionerBc.value.upperBound(kUR) }) - | Invalid key: ${ Region.pretty(localKPType, prevK.value.offset) }""".stripMargin) + | Range bounds for partition $i: ${partitionerBc.value.rangeBounds(i)} + | Range of partition IDs for key: [${partitionerBc.value.lowerBound(kUR)}, ${partitionerBc.value.upperBound(kUR)}) + | Invalid key: ${Region.pretty(localKPType, prevK.value.offset)}""".stripMargin + ) ptr } } - }) + }, + ) } def truncateKey(n: Int): RVD = { @@ -225,7 +238,8 @@ class RVD( else copy( typ = typ.copy(key = newKey), - partitioner = partitioner.coarsen(newKey.length)) + partitioner = partitioner.coarsen(newKey.length), + ) } // WARNING: will drop any data with keys falling outside 'partitioner'. @@ -233,7 +247,7 @@ class RVD( ctx: ExecuteContext, newPartitioner: RVDPartitioner, shuffle: Boolean = false, - filter: Boolean = true + filter: Boolean = true, ): RVD = { if (newPartitioner == this.partitioner) return this @@ -250,21 +264,26 @@ class RVD( val partBc = newPartitioner.broadcast(crdd.sparkContext) val enc = TypedCodecSpec(rowPType, BufferSpec.wireSpec) - val filtered: RVD = if (filter) filterWithContext[(UnsafeRow, SelectFieldsRow)]({ case (_, _) => - val ur = new UnsafeRow(localRowPType, null, 0) - val key = new SelectFieldsRow(ur, newType.kFieldIdx) - (ur, key) - }, { case ((ur, key), ctx, ptr) => - ur.set(ctx.r, ptr) - partBc.value.contains(key) - }) else this + val filtered: RVD = if (filter) filterWithContext[(UnsafeRow, SelectFieldsRow)]( + { case (_, _) => + val ur = new UnsafeRow(localRowPType, null, 0) + val key = new SelectFieldsRow(ur, newType.kFieldIdx) + (ur, key) + }, + { case ((ur, key), ctx, ptr) => + ur.set(ctx.r, ptr) + partBc.value.contains(key) + }, + ) + else this val shuffled: RDD[(Any, Array[Byte])] = new ShuffledRDD( filtered.keyedEncodedRDD(ctx, enc, newType.key), - newPartitioner.sparkPartitioner(crdd.sparkContext) + newPartitioner.sparkPartitioner(crdd.sparkContext), ).setKeyOrdering(kOrdering.toOrdering) - val (rType: PStruct, shuffledCRDD) = enc.decodeRDD(ctx, localRowPType.virtualType, shuffled.values) + val (rType: PStruct, shuffledCRDD) = + enc.decodeRDD(ctx, localRowPType.virtualType, shuffled.values) RVD(RVDType(rType, newType.key), newPartitioner, shuffledCRDD) } else { @@ -272,7 +291,8 @@ class RVD( new RVD( typ.copy(key = typ.key.take(newPartitioner.kType.size)), newPartitioner, - RepartitionedOrderedRDD2(ctx.stateManager, this, newPartitioner.rangeBounds)) + RepartitionedOrderedRDD2(ctx.stateManager, this, newPartitioner.rangeBounds), + ) else this } @@ -280,7 +300,7 @@ class RVD( def naiveCoalesce( maxPartitions: Int, - executeContext: ExecuteContext + executeContext: ExecuteContext, ): RVD = { val n = partitioner.numPartitions if (maxPartitions >= n) @@ -298,14 +318,15 @@ class RVD( new RVD( typ, newPartitioner, - crdd.coalesceWithEnds(newPartEnd)) + crdd.coalesceWithEnds(newPartEnd), + ) } } def coalesce( ctx: ExecuteContext, maxPartitions: Int, - shuffle: Boolean + shuffle: Boolean, ): RVD = { require(maxPartitions > 0, "cannot coalesce to nPartitions <= 0") val n = crdd.partitions.length @@ -320,22 +341,29 @@ class RVD( return RVD.unkeyed(newRowPType, shuffled) val newType = 
RVDType(newRowPType, typ.key) - val keyInfo = RVD.getKeyInfo(ctx, newType, newType.key.length, RVD.getKeys(ctx, newType, shuffled)) + val keyInfo = + RVD.getKeyInfo(ctx, newType, newType.key.length, RVD.getKeys(ctx, newType, shuffled)) if (keyInfo.isEmpty) return RVD.empty(ctx, typ) val newPartitioner = RVD.calculateKeyRanges( - ctx, newType, keyInfo, shuffled.getNumPartitions, newType.key.length) + ctx, + newType, + keyInfo, + shuffled.getNumPartitions, + newType.key.length, + ) - if (newPartitioner.numPartitions< maxPartitions) - warn(s"coalesced to ${ newPartitioner.numPartitions} " + - s"${ plural(newPartitioner.numPartitions, "partition") }, less than requested $maxPartitions") + if (newPartitioner.numPartitions < maxPartitions) + warn(s"coalesced to ${newPartitioner.numPartitions} " + + s"${plural(newPartitioner.numPartitions, "partition")}, less than requested $maxPartitions") repartition(ctx, newPartitioner, shuffle) } else { val partSize = countPerPartition() - log.info(s"partSize = ${ partSize.toSeq }") + log.info(s"partSize = ${partSize.toSeq}") - val partCumulativeSize = mapAccumulate[Array, Long](partSize, 0L)((s, acc) => (s + acc, s + acc)) + val partCumulativeSize = + mapAccumulate[Array, Long](partSize, 0L)((s, acc) => (s + acc, s + acc)) val totalSize = partCumulativeSize.last var newPartEnd = (0 until maxPartitions).map { i => @@ -345,8 +373,10 @@ class RVD( var j = util.Arrays.binarySearch(partCumulativeSize, t) if (j < 0) j = -j - 1 - while (j < partCumulativeSize.length - 1 - && partCumulativeSize(j + 1) == t) + while ( + j < partCumulativeSize.length - 1 + && partCumulativeSize(j + 1) == t + ) j += 1 assert(t <= partCumulativeSize(j) && (j == partCumulativeSize.length - 1 || @@ -354,13 +384,15 @@ class RVD( j }.toArray - newPartEnd = newPartEnd.zipWithIndex.filter { case (end, i) => i == 0 || newPartEnd(i) != newPartEnd(i - 1) } + newPartEnd = newPartEnd.zipWithIndex.filter { case (_, i) => + i == 0 || newPartEnd(i) != newPartEnd(i - 1) + } .map(_._1) val newPartitioner = partitioner.coalesceRangeBounds(newPartEnd) - if (newPartitioner.numPartitions< maxPartitions) - warn(s"coalesced to ${ newPartitioner.numPartitions} " + - s"${ plural(newPartitioner.numPartitions, "partition") }, less than requested $maxPartitions") + if (newPartitioner.numPartitions < maxPartitions) + warn(s"coalesced to ${newPartitioner.numPartitions} " + + s"${plural(newPartitioner.numPartitions, "partition")}, less than requested $maxPartitions") if (newPartitioner == partitioner) { this @@ -368,7 +400,8 @@ class RVD( new RVD( typ, newPartitioner, - crdd.coalesceWithEnds(newPartEnd)) + crdd.coalesceWithEnds(newPartEnd), + ) } } } @@ -401,7 +434,8 @@ class RVD( new RVD( newType, partitioner.copy(kType = newType.kType.virtualType), - sortedRDD) + sortedRDD, + ) } // Mapping @@ -414,9 +448,7 @@ class RVD( def map(newTyp: RVDType)(f: (RVDContext, Long) => Long): RVD = { require(newTyp.kType isPrefixOf typ.kType) - RVD(newTyp, - partitioner.coarsen(newTyp.key.length), - crdd.cmap(f)) + RVD(newTyp, partitioner.coarsen(newTyp.key.length), crdd.cmap(f)) } def mapPartitions[T: ClassTag]( @@ -425,30 +457,38 @@ class RVD( def mapPartitions( newTyp: RVDType - )(f: (RVDContext, Iterator[Long]) => Iterator[Long] + )( + f: (RVDContext, Iterator[Long]) => Iterator[Long] ): RVD = { require(newTyp.kType isPrefixOf typ.kType) RVD( newTyp, partitioner.coarsen(newTyp.key.length), - crdd.cmapPartitions(f)) + crdd.cmapPartitions(f), + ) } - def mapPartitionsWithContext(newTyp: RVDType)(f: (RVDContext, RVDContext => 
Iterator[Long]) => Iterator[Long]): RVD = { + def mapPartitionsWithContext( + newTyp: RVDType + )( + f: (RVDContext, RVDContext => Iterator[Long]) => Iterator[Long] + ): RVD = RVD( newTyp, partitioner.coarsen(newTyp.key.length), - crdd.cmapPartitionsWithContext(f) + crdd.cmapPartitionsWithContext(f), ) - } - def mapPartitionsWithContextAndIndex(newTyp: RVDType)(f: (Int, RVDContext, RVDContext => Iterator[Long]) => Iterator[Long]): RVD = { + def mapPartitionsWithContextAndIndex( + newTyp: RVDType + )( + f: (Int, RVDContext, RVDContext => Iterator[Long]) => Iterator[Long] + ): RVD = RVD( newTyp, partitioner.coarsen(newTyp.key.length), - crdd.cmapPartitionsWithContextAndIndex(f) + crdd.cmapPartitionsWithContextAndIndex(f), ) - } def mapPartitionsWithIndex[T: ClassTag]( f: (Int, RVDContext, Iterator[Long]) => Iterator[T] @@ -456,25 +496,29 @@ class RVD( def mapPartitionsWithIndex( newTyp: RVDType - )(f: (Int, RVDContext, Iterator[Long]) => Iterator[Long] + )( + f: (Int, RVDContext, Iterator[Long]) => Iterator[Long] ): RVD = { require(newTyp.kType isPrefixOf typ.kType) RVD( newTyp, partitioner.coarsen(newTyp.key.length), - crdd.cmapPartitionsWithIndex(f)) + crdd.cmapPartitionsWithIndex(f), + ) } def mapPartitionsWithIndexAndValue[V]( newTyp: RVDType, - values: Array[V] - )(f: (Int, RVDContext, V, Iterator[Long]) => Iterator[Long] + values: Array[V], + )( + f: (Int, RVDContext, V, Iterator[Long]) => Iterator[Long] ): RVD = { require(newTyp.kType isPrefixOf typ.kType) RVD( newTyp, partitioner.coarsen(newTyp.key.length), - crdd.cmapPartitionsWithIndexAndValue(values, f)) + crdd.cmapPartitionsWithIndexAndValue(values, f), + ) } // Filtering @@ -500,13 +544,15 @@ class RVD( idx -> nTake } - val newRDD = crdd. - mapPartitionsWithIndex({ case (i, it) => - if (i == idxLast) - it.take(nTake.toInt) - else - it - }, preservesPartitioning = true) + val newRDD = crdd.mapPartitionsWithIndex( + { case (i, it) => + if (i == idxLast) + it.take(nTake.toInt) + else + it + }, + preservesPartitioning = true, + ) .subsetPartitions((0 to idxLast).toArray) val newNParts = newRDD.getNumPartitions @@ -540,18 +586,21 @@ class RVD( } assert(nDrop < Int.MaxValue) - val newRDD = crdd.cmapPartitionsAndContextWithIndex({ case (i, ctx, f) => - val it = f.next()(ctx) - if (i == idxFirst) { - (0 until nDrop.toInt).foreach { _ => - ctx.region.clear() - assert(it.hasNext) - it.next() - } - it - } else - it - }, preservesPartitioning = true) + val newRDD = crdd.cmapPartitionsAndContextWithIndex( + { case (i, ctx, f) => + val it = f.next()(ctx) + if (i == idxFirst) { + (0 until nDrop.toInt).foreach { _ => + ctx.region.clear() + assert(it.hasNext) + it.next() + } + it + } else + it + }, + preservesPartitioning = true, + ) .subsetPartitions(Array.range(idxFirst, getNumPartitions)) val oldNParts = crdd.getNumPartitions @@ -565,34 +614,33 @@ class RVD( RVD(typ, newPartitioner, newRDD) } - def filter(p: (RVDContext, Long) => Boolean): RVD = { + def filter(p: (RVDContext, Long) => Boolean): RVD = filterWithContext((_, _) => (), (_: Any, c, l) => p(c, l)) - } - def filterWithContext[C](makeContext: (Int, RVDContext) => C, f: (C, RVDContext, Long) => Boolean): RVD = { - val crdd: ContextRDD[Long] = this.crdd.cmapPartitionsWithContextAndIndex { (i, consumerCtx, iteratorToFilter) => - val c = makeContext(i, consumerCtx) - val producerCtx = consumerCtx.freshContext - iteratorToFilter(producerCtx).filter { ptr => - val b = f(c, consumerCtx, ptr) - if (b) { - producerCtx.region.move(consumerCtx.region) - } - else { - producerCtx.region.clear() + 
def filterWithContext[C](makeContext: (Int, RVDContext) => C, f: (C, RVDContext, Long) => Boolean) + : RVD = { + val crdd: ContextRDD[Long] = + this.crdd.cmapPartitionsWithContextAndIndex { (i, consumerCtx, iteratorToFilter) => + val c = makeContext(i, consumerCtx) + val producerCtx = consumerCtx.freshContext + iteratorToFilter(producerCtx).filter { ptr => + val b = f(c, consumerCtx, ptr) + if (b) { + producerCtx.region.move(consumerCtx.region) + } else { + producerCtx.region.clear() + } + b } - b } - } RVD(this.typ, this.partitioner, crdd) } - def filterIntervals(intervals: RVDPartitioner, keep: Boolean): RVD = { + def filterIntervals(intervals: RVDPartitioner, keep: Boolean): RVD = if (keep) filterToIntervals(intervals) else filterOutIntervals(intervals) - } def filterOutIntervals(intervals: RVDPartitioner): RVD = { val intervalsBc = intervals.broadcast(sparkContext) @@ -602,13 +650,13 @@ class RVD( val rowPType = typ.rowType filterWithContext[UnsafeRow]( - { (_, _) => new UnsafeRow(kPType) }, + (_, _) => new UnsafeRow(kPType), { case (kUR, ctx, ptr) => ctx.rvb.start(kType) ctx.rvb.selectRegionValue(rowPType, kRowFieldIdx, ctx.r, ptr) kUR.set(ctx.region, ctx.rvb.end()) !intervalsBc.value.contains(kUR) - } + }, ) } @@ -620,7 +668,8 @@ class RVD( val pred: (RVDContext, Long) => Boolean = (ctx: RVDContext, ptr: Long) => { val ur = new UnsafeRow(localRowPType, ctx.r, ptr) val key = Row.fromSeq( - kRowFieldIdx.map(i => ur.get(i))) + kRowFieldIdx.map(i => ur.get(i)) + ) intervalsBc.value.contains(key) } @@ -632,7 +681,7 @@ class RVD( .filter(i => intervals.overlaps(partitioner.rangeBounds(i))) .toArray - info(s"reading ${ newPartitionIndices.length } of $nPartitions data partitions") + info(s"reading ${newPartitionIndices.length} of $nPartitions data partitions") if (newPartitionIndices.isEmpty) RVD.empty(intervals.sm, typ) @@ -643,8 +692,10 @@ class RVD( def subsetPartitions(keep: Array[Int]): RVD = { require(keep.length <= crdd.partitions.length, "tried to subset to more partitions than exist") - require(keep.isIncreasing && (keep.isEmpty || (keep.head >= 0 && keep.last < crdd.partitions.length)), - "values not increasing or not in range [0, number of partitions)") + require( + keep.isIncreasing && (keep.isEmpty || (keep.head >= 0 && keep.last < crdd.partitions.length)), + "values not increasing or not in range [0, number of partitions)", + ) val newPartitioner = partitioner.copy(rangeBounds = keep.map(partitioner.rangeBounds)) @@ -652,26 +703,31 @@ class RVD( } def combine[U: ClassTag, T: ClassTag]( - execCtx: ExecuteContext, - mkZero: (HailClassLoader, HailTaskContext) => T, - itF: (HailClassLoader, Int, RVDContext, Iterator[Long]) => T, - deserialize: (HailClassLoader, HailTaskContext) => (U => T), - serialize: (HailClassLoader, HailTaskContext, T) => U, - combOp: (HailClassLoader, HailTaskContext, T, T) => T, - commutative: Boolean, - tree: Boolean + execCtx: ExecuteContext, + mkZero: (HailClassLoader, HailTaskContext) => T, + itF: (HailClassLoader, Int, RVDContext, Iterator[Long]) => T, + deserialize: (HailClassLoader, HailTaskContext) => (U => T), + serialize: (HailClassLoader, HailTaskContext, T) => U, + combOp: (HailClassLoader, HailTaskContext, T, T) => T, + commutative: Boolean, + tree: Boolean, ): T = { var reduced = crdd.cmapPartitionsWithIndex[U] { (i, ctx, it) => Iterator.single( - serialize(theHailClassLoaderForSparkWorkers, SparkTaskContext.get(), - itF(theHailClassLoaderForSparkWorkers, i, ctx, it))) + serialize( + theHailClassLoaderForSparkWorkers, + SparkTaskContext.get(), 
+ itF(theHailClassLoaderForSparkWorkers, i, ctx, it), + ) + ) } if (tree) { val depth = treeAggDepth(getNumPartitions, HailContext.get.branchingFactor) val scale = math.max( math.ceil(math.pow(getNumPartitions, 1.0 / depth)).toInt, - 2) + 2, + ) var i = 0 while (reduced.getNumPartitions > scale) { @@ -690,7 +746,7 @@ class RVD( val hcl = theHailClassLoaderForSparkWorkers val htc = SparkTaskContext.get() var acc = mkZero(hcl, htc) - it.foreach { case (newPart, (oldPart, v)) => + it.foreach { case (_, (_, v)) => acc = combOp(hcl, htc, acc, deserialize(hcl, htc)(v)) } Iterator.single(serialize(hcl, htc, acc)) @@ -699,28 +755,31 @@ class RVD( } } - val ac = Combiner(mkZero(execCtx.theHailClassLoader, execCtx.taskContext), { (acc1: T, acc2: T) => combOp(execCtx.theHailClassLoader, execCtx.taskContext, acc1, acc2) }, - commutative, true) - sparkContext.runJob(reduced.run, (it: Iterator[U]) => singletonElement(it), (i, x: U) => ac.combine(i, deserialize(execCtx.theHailClassLoader, execCtx.taskContext)(x))) + val ac = Combiner( + mkZero(execCtx.theHailClassLoader, execCtx.taskContext), + (acc1: T, acc2: T) => combOp(execCtx.theHailClassLoader, execCtx.taskContext, acc1, acc2), + commutative, + true, + ) + sparkContext.runJob( + reduced.run, + (it: Iterator[U]) => singletonElement(it), + (i, x: U) => ac.combine(i, deserialize(execCtx.theHailClassLoader, execCtx.taskContext)(x)), + ) ac.result() } - def count(): Long = { + def count(): Long = crdd.boundary.cmapPartitions { (ctx, it) => var count = 0L - it.foreach { _ => - count += 1 - } + it.foreach(_ => count += 1) Iterator.single(count) }.run.fold(0L)(_ + _) - } def countPerPartition(): Array[Long] = crdd.boundary.cmapPartitions { (ctx, it) => var count = 0L - it.foreach { _ => - count += 1 - } + it.foreach(_ => count += 1) Iterator.single(count) }.collect() @@ -740,7 +799,8 @@ class RVD( } } - def collectAsBytes(ctx: ExecuteContext, enc: AbstractTypedCodecSpec): Array[Array[Byte]] = stabilize(ctx, enc).collect() + def collectAsBytes(ctx: ExecuteContext, enc: AbstractTypedCodecSpec): Array[Array[Byte]] = + stabilize(ctx, enc).collect() // Persisting @@ -774,9 +834,20 @@ class RVD( def storageLevel: StorageLevel = StorageLevel.NONE - def write(ctx: ExecuteContext, path: String, idxRelPath: String, stageLocally: Boolean, codecSpec: AbstractTypedCodecSpec): Array[FileWriteMetadata] = { + def write( + ctx: ExecuteContext, + path: String, + idxRelPath: String, + stageLocally: Boolean, + codecSpec: AbstractTypedCodecSpec, + ): Array[FileWriteMetadata] = { val fileData = crdd.writeRows(ctx, path, idxRelPath, typ, stageLocally, codecSpec) - val spec = MakeRVDSpec(codecSpec, fileData.map(_.path), partitioner, IndexSpec.emptyAnnotation(idxRelPath, typ.kType)) + val spec = MakeRVDSpec( + codecSpec, + fileData.map(_.path), + partitioner, + IndexSpec.emptyAnnotation(idxRelPath, typ.kType), + ) spec.write(ctx.fs, path) fileData } @@ -785,7 +856,8 @@ class RVD( def orderedLeftJoinDistinctAndInsert( right: RVD, - root: String): RVD = { + root: String, + ): RVD = { assert(!typ.key.contains(root)) val rightRowType = right.typ.rowType @@ -818,40 +890,44 @@ class RVD( rv } } - assert(typ.key.length >= right.typ.key.length, s"$typ >= ${ right.typ }\n $this\n $right") + assert(typ.key.length >= right.typ.key.length, s"$typ >= ${right.typ}\n $this\n $right") orderedLeftJoinDistinct( right, right.typ.key.length, joiner, - typ.copy(rowType = newRowType)) + typ.copy(rowType = newRowType), + ) } def orderedLeftJoinDistinct( right: RVD, joinKey: Int, joiner: (RVDContext, 
Iterator[JoinedRegionValue]) => Iterator[RegionValue], - joinedType: RVDType + joinedType: RVDType, ): RVD = keyBy(joinKey).orderedLeftJoinDistinct(right.keyBy(joinKey), joiner, joinedType) def orderedLeftIntervalJoin( ctx: ExecuteContext, right: RVD, - joiner: PStruct => (RVDType, (RVDContext, Iterator[Muple[RegionValue, Iterable[RegionValue]]]) => Iterator[RegionValue]) + joiner: PStruct => ( + RVDType, + (RVDContext, Iterator[Muple[RegionValue, Iterable[RegionValue]]]) => Iterator[RegionValue], + ), ): RVD = keyBy(1).orderedLeftIntervalJoin(ctx, right.keyBy(1), joiner) def orderedLeftIntervalJoinDistinct( ctx: ExecuteContext, right: RVD, - joiner: PStruct => (RVDType, (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue]) + joiner: PStruct => (RVDType, (RVDContext, Iterator[JoinedRegionValue]) => Iterator[RegionValue]), ): RVD = keyBy(1).orderedLeftIntervalJoinDistinct(ctx, right.keyBy(1), joiner) def orderedMerge( right: RVD, joinKey: Int, - ctx: ExecuteContext + ctx: ExecuteContext, ): RVD = keyBy(joinKey).orderedMerge(right.keyBy(joinKey), ctx) @@ -869,8 +945,9 @@ class RVD( def alignAndZipPartitions( newTyp: RVDType, that: RVD, - joinKey: Int - )(zipper: (RVDContext, Iterator[RegionValue], Iterator[RegionValue]) => Iterator[RegionValue] + joinKey: Int, + )( + zipper: (RVDContext, Iterator[RegionValue], Iterator[RegionValue]) => Iterator[RegionValue] ): RVD = { require(newTyp.kType isPrefixOf this.typ.kType) require(joinKey <= this.typ.key.length) @@ -882,8 +959,13 @@ class RVD( typ = newTyp, partitioner = left.partitioner, crdd = left.crdd.toCRDDRegionValue.czipPartitions( - RepartitionedOrderedRDD2(sm, that, this.partitioner.coarsenedRangeBounds(joinKey)).toCRDDRegionValue - )(zipper).toCRDDPtr) + RepartitionedOrderedRDD2( + sm, + that, + this.partitioner.coarsenedRangeBounds(joinKey), + ).toCRDDRegionValue + )(zipper).toCRDDPtr, + ) } // Like alignAndZipPartitions, when 'that' is keyed by intervals. @@ -893,10 +975,16 @@ class RVD( // current partition of 'this'. 
def intervalAlignAndZipPartitions( ctx: ExecuteContext, - that: RVD - )(zipper: PStruct => (RVDType, (RVDContext, Iterator[RegionValue], Iterator[RegionValue]) => Iterator[RegionValue]) + that: RVD, + )( + zipper: PStruct => ( + RVDType, + (RVDContext, Iterator[RegionValue], Iterator[RegionValue]) => Iterator[RegionValue], + ) ): RVD = { - require(that.rowType.field(that.typ.key(0)).typ.asInstanceOf[TInterval].pointType == rowType.field(typ.key(0)).typ) + require(that.rowType.field(that.typ.key(0)).typ.asInstanceOf[ + TInterval + ].pointType == rowType.field(typ.key(0)).typ) val partBc = partitioner.broadcast(sparkContext) val rightTyp = that.typ @@ -905,16 +993,15 @@ class RVD( val sm = ctx.stateManager val partitionKeyedIntervals = that.crdd.cmapPartitions { (ctx, it) => val encoder = new ByteArrayEncoder(theHailClassLoaderForSparkWorkers, makeEnc) - TaskContext.get.addTaskCompletionListener[Unit] { _ => - encoder.close() - } + TaskContext.get.addTaskCompletionListener[Unit](_ => encoder.close()) it.flatMap { ptr => val r = SafeRow(rightTyp.rowType, ptr) val interval = r.getAs[Interval](rightTyp.kFieldIdx(0)) if (interval != null) { val wrappedInterval = interval.copy( start = Row(interval.start), - end = Row(interval.end)) + end = Row(interval.end), + ) val bytes = encoder.regionValueToBytes(ptr) partBc.value.queryInterval(wrappedInterval).map(i => ((i, interval), bytes)) } else @@ -923,14 +1010,15 @@ class RVD( }.run val nParts = getNumPartitions - val intervalOrd = rightTyp.kType.types(0).virtualType.ordering(sm).toOrdering.asInstanceOf[Ordering[Interval]] + val intervalOrd = + rightTyp.kType.types(0).virtualType.ordering(sm).toOrdering.asInstanceOf[Ordering[Interval]] val sorted: RDD[((Int, Interval), Array[Byte])] = new ShuffledRDD( partitionKeyedIntervals, new Partitioner { def getPartition(key: Any): Int = key.asInstanceOf[(Int, Interval)]._1 def numPartitions: Int = nParts - } + }, ).setKeyOrdering(Ordering.by[(Int, Interval), Interval](_._2)(intervalOrd)) val (rightPType: PStruct, rightCRDD) = codecSpec.decodeRDD(ctx, that.rowType, sorted.values) @@ -938,7 +1026,8 @@ class RVD( RVD( typ = newTyp, partitioner = partitioner, - crdd = crdd.toCRDDRegionValue.czipPartitions(rightCRDD.toCRDDRegionValue)(f).toCRDDPtr) + crdd = crdd.toCRDDRegionValue.czipPartitions(rightCRDD.toCRDDRegionValue)(f).toCRDDPtr, + ) } // Private @@ -946,19 +1035,22 @@ class RVD( private[rvd] def copy( typ: RVDType = typ, partitioner: RVDPartitioner = partitioner, - crdd: ContextRDD[Long] = crdd + crdd: ContextRDD[Long] = crdd, ): RVD = RVD(typ, partitioner, crdd) private[rvd] def destabilize( ctx: ExecuteContext, stable: RDD[Array[Byte]], - enc: AbstractTypedCodecSpec + enc: AbstractTypedCodecSpec, ): (PStruct, ContextRDD[Long]) = { val (rowPType: PStruct, dec) = enc.buildDecoder(ctx, rowType) - (rowPType, ContextRDD.weaken(stable).cmapPartitions { (ctx, it) => - RegionValue.fromBytes(theHailClassLoaderForSparkWorkers, dec, ctx.region, it) - }) + ( + rowPType, + ContextRDD.weaken(stable).cmapPartitions { (ctx, it) => + RegionValue.fromBytes(theHailClassLoaderForSparkWorkers, dec, ctx.region, it) + }, + ) } private[rvd] def crddBoundary: ContextRDD[Long] = @@ -969,26 +1061,23 @@ class RVD( } object RVD { - def empty(ctx: ExecuteContext, typ: RVDType): RVD = { + def empty(ctx: ExecuteContext, typ: RVDType): RVD = RVD.empty(ctx.stateManager, typ) - } - def empty(sm: HailStateManager, typ: RVDType): RVD = { - RVD(typ, - RVDPartitioner.empty(sm, typ.kType.virtualType), - ContextRDD.empty[Long]()) - } + def 
empty(sm: HailStateManager, typ: RVDType): RVD = + RVD(typ, RVDPartitioner.empty(sm, typ.kType.virtualType), ContextRDD.empty[Long]()) def unkeyed(rowType: PStruct, crdd: ContextRDD[Long]): RVD = new RVD( RVDType(rowType, FastSeq()), RVDPartitioner.unkeyed(null, crdd.getNumPartitions), - crdd) + crdd, + ) def getKeys( ctx: ExecuteContext, typ: RVDType, - crdd: ContextRDD[Long] + crdd: ContextRDD[Long], ): ContextRDD[Long] = { // The region values in 'crdd' are of type `typ.rowType` val localType = typ @@ -1010,7 +1099,7 @@ object RVD { // 'partitionKey' is used to check whether the rows are ordered by the first // 'partitionKey' key fields, even if they aren't ordered by the full key. partitionKey: Int, - keys: ContextRDD[Long] + keys: ContextRDD[Long], ): Array[RVDPartitionInfo] = { // the region values in 'keys' are of typ `typ.keyType` val nPartitions = keys.getNumPartitions @@ -1028,33 +1117,42 @@ object RVD { val sm = ctx.stateManager val keyInfo = keys.crunJobWithIndex { (i, rvdContext, it) => if (it.hasNext) - Some(RVDPartitionInfo(sm, localType, partitionKey, samplesPerPartition, i, it, partitionSeed(i), rvdContext)) + Some(RVDPartitionInfo( + sm, + localType, + partitionKey, + samplesPerPartition, + i, + it, + partitionSeed(i), + rvdContext, + )) else None }.flatten val kOrd = PartitionBoundOrdering(sm, typ.kType.virtualType).toOrdering - keyInfo.sortBy(_.min)(kOrd ) + keyInfo.sortBy(_.min)(kOrd) } def coerce( execCtx: ExecuteContext, typ: RVDType, - crdd: ContextRDD[Long] + crdd: ContextRDD[Long], ): RVD = coerce(execCtx, typ, typ.key.length, crdd) def coerce( execCtx: ExecuteContext, typ: RVDType, crdd: ContextRDD[Long], - fastKeys: ContextRDD[Long] + fastKeys: ContextRDD[Long], ): RVD = coerce(execCtx, typ, typ.key.length, crdd, fastKeys) def coerce( execCtx: ExecuteContext, typ: RVDType, partitionKey: Int, - crdd: ContextRDD[Long] + crdd: ContextRDD[Long], ): RVD = { val keys = getKeys(execCtx, typ, crdd) makeCoercer(execCtx, typ, partitionKey, keys).coerce(typ, crdd) @@ -1065,16 +1163,15 @@ object RVD { typ: RVDType, partitionKey: Int, crdd: ContextRDD[Long], - keys: ContextRDD[Long] - ): RVD = { + keys: ContextRDD[Long], + ): RVD = makeCoercer(execCtx, typ, partitionKey, keys).coerce(typ, crdd) - } def makeCoercer( execCtx: ExecuteContext, fullType: RVDType, // keys: RDD[RegionValue[fullType.kType]] - keys: ContextRDD[Long] + keys: ContextRDD[Long], ): RVDCoercer = makeCoercer(execCtx, fullType, fullType.key.length, keys) def makeCoercer( @@ -1082,7 +1179,7 @@ object RVD { fullType: RVDType, partitionKey: Int, // keys: RDD[RegionValue[fullType.kType]] - keys: ContextRDD[Long] + keys: ContextRDD[Long], ): RVDCoercer = { type CRDD = ContextRDD[Long] @@ -1100,7 +1197,6 @@ object RVD { def _coerce(typ: RVDType, crdd: CRDD): RVD = empty(execCtx, typ) } - val numPartitions = keys.getNumPartitions val keyInfo = getKeyInfo(execCtx, fullType, partitionKey, keys) if (keyInfo.isEmpty) @@ -1114,8 +1210,7 @@ object RVD { if (pids.isSorted && crdd.getNumPartitions == pids.length) { assert(pids.isEmpty || pids.last < crdd.getNumPartitions) crdd - } - else { + } else { assert(pids.isEmpty || pids.max < crdd.getNumPartitions) if (!pids.isSorted) info("Coerced dataset with out-of-order partitions.") @@ -1127,25 +1222,36 @@ object RVD { val intraPartitionSortedness = minInfo.sortedness val contextStr = minInfo.contextStr - if (intraPartitionSortedness == RVDPartitionInfo.KSORTED - && RVDPartitioner.isValid(execCtx.stateManager, fullType.kType.virtualType, bounds)) { + if ( + 
intraPartitionSortedness == RVDPartitionInfo.KSORTED + && RVDPartitioner.isValid(execCtx.stateManager, fullType.kType.virtualType, bounds) + ) { info("Coerced sorted dataset") new RVDCoercer(fullType) { val unfixedPartitioner = new RVDPartitioner(execCtx.stateManager, fullType.kType.virtualType, bounds) - val newPartitioner = RVDPartitioner.generate(execCtx.stateManager, - fullType.key.take(partitionKey), fullType.kType.virtualType, bounds) + val newPartitioner = RVDPartitioner.generate( + execCtx.stateManager, + fullType.key.take(partitionKey), + fullType.kType.virtualType, + bounds, + ) - def _coerce(typ: RVDType, crdd: CRDD): RVD = { + def _coerce(typ: RVDType, crdd: CRDD): RVD = RVD(typ, unfixedPartitioner, orderPartitions(crdd)) .repartition(execCtx, newPartitioner, shuffle = false) - } } - } else if (intraPartitionSortedness >= RVDPartitionInfo.TSORTED - && RVDPartitioner.isValid(execCtx.stateManager, fullType.kType.virtualType.truncate(partitionKey), pkBounds)) { + } else if ( + intraPartitionSortedness >= RVDPartitionInfo.TSORTED + && RVDPartitioner.isValid( + execCtx.stateManager, + fullType.kType.virtualType.truncate(partitionKey), + pkBounds, + ) + ) { info(s"Coerced almost-sorted dataset") log.info(s"Unsorted keys: $contextStr") @@ -1154,20 +1260,20 @@ object RVD { val unfixedPartitioner = new RVDPartitioner( execCtx.stateManager, fullType.kType.virtualType.truncate(partitionKey), - pkBounds + pkBounds, ) val newPartitioner = RVDPartitioner.generate( execCtx.stateManager, fullType.key.take(partitionKey), fullType.kType.virtualType.truncate(partitionKey), - pkBounds + pkBounds, ) def _coerce(typ: RVDType, crdd: CRDD): RVD = { RVD( typ.copy(key = typ.key.take(partitionKey)), unfixedPartitioner, - orderPartitions(crdd) + orderPartitions(crdd), ).repartition(execCtx, newPartitioner, shuffle = false) .localSort(typ.key) } @@ -1182,10 +1288,9 @@ object RVD { val newPartitioner = calculateKeyRanges(execCtx, fullType, keyInfo, keys.getNumPartitions, partitionKey) - def _coerce(typ: RVDType, crdd: CRDD): RVD = { + def _coerce(typ: RVDType, crdd: CRDD): RVD = RVD.unkeyed(typ.rowType, crdd) .repartition(execCtx, newPartitioner, shuffle = true, filter = false) - } } } } @@ -1195,7 +1300,7 @@ object RVD { typ: RVDType, pInfo: Array[RVDPartitionInfo], nPartitions: Int, - partitionKey: Int + partitionKey: Int, ): RVDPartitioner = { assert(nPartitions > 0) assert(pInfo.nonEmpty) @@ -1211,13 +1316,12 @@ object RVD { def apply( typ: RVDType, partitioner: RVDPartitioner, - crdd: ContextRDD[Long] - ): RVD = { + crdd: ContextRDD[Long], + ): RVD = if (!HailContext.get.checkRVDKeys) new RVD(typ, partitioner, crdd) else new RVD(typ, partitioner, crdd).checkKeyOrdering() - } def unify(execCtx: ExecuteContext, rvds: Seq[RVD]): Seq[RVD] = { if (rvds.length == 1 || rvds.forall(_.rowPType == rvds.head.rowPType)) @@ -1229,14 +1333,16 @@ object RVD { rvds.map { rvd => val srcRowPType = rvd.rowPType val newRVDType = rvd.typ.copy(rowType = unifiedRowPType, key = unifiedKey) - rvd.map(newRVDType)((ctx, ptr) => unifiedRowPType.copyFromAddress(sm, ctx.r, srcRowPType, ptr, false)) + rvd.map(newRVDType)((ctx, ptr) => + unifiedRowPType.copyFromAddress(sm, ctx.r, srcRowPType, ptr, false) + ) } } def union( rvds: Seq[RVD], joinKey: Int, - ctx: ExecuteContext + ctx: ExecuteContext, ): RVD = rvds match { case Seq(x) => x case first +: _ => @@ -1251,7 +1357,7 @@ object RVD { def union( rvds: Seq[RVD], - ctx: ExecuteContext + ctx: ExecuteContext, ): RVD = union(rvds, rvds.head.typ.key.length, ctx) @@ -1260,14 +1366,16 @@ 
object RVD { rvds: IndexedSeq[RVD], paths: IndexedSeq[String], bufferSpec: BufferSpec, - stageLocally: Boolean + stageLocally: Boolean, ): Array[Array[FileWriteMetadata]] = { val first = rvds.head - rvds.foreach {rvd => + rvds.foreach { rvd => if (rvd.typ != first.typ) - throw new RuntimeException(s"Type mismatch!\n head: ${ first.typ }\n altr: ${ rvd.typ }") + throw new RuntimeException(s"Type mismatch!\n head: ${first.typ}\n altr: ${rvd.typ}") if (rvd.partitioner != first.partitioner) - throw new RuntimeException(s"Partitioner mismatch!\n head:${ first.partitioner }\n altr: ${ rvd.partitioner }") + throw new RuntimeException( + s"Partitioner mismatch!\n head:${first.partitioner}\n altr: ${rvd.partitioner}" + ) } val sc = SparkBackend.sparkContext("writeRowsSplitFiles") @@ -1288,14 +1396,16 @@ object RVD { val rowsCodecSpec = TypedCodecSpec(rowsRVType, bufferSpec) val entriesCodecSpec = TypedCodecSpec(entriesRVType, bufferSpec) val rowsIndexSpec = IndexSpec.defaultAnnotation("../../index", localTyp.kType) - val entriesIndexSpec = IndexSpec.defaultAnnotation("../../index", localTyp.kType, withOffsetField = true) + val entriesIndexSpec = + IndexSpec.defaultAnnotation("../../index", localTyp.kType, withOffsetField = true) val makeRowsEnc = rowsCodecSpec.buildEncoder(execCtx, fullRowType) val makeEntriesEnc = entriesCodecSpec.buildEncoder(execCtx, fullRowType) - val _makeIndexWriter = IndexWriter.builder(execCtx, localTyp.kType, +PCanonicalStruct("entries_offset" -> PInt64())) - val makeIndexWriter: (String, RegionPool) => IndexWriter = _makeIndexWriter(_, theHailClassLoaderForSparkWorkers, SparkTaskContext.get(), _) + val _makeIndexWriter = + IndexWriter.builder(execCtx, localTyp.kType, +PCanonicalStruct("entries_offset" -> PInt64())) + val makeIndexWriter: (String, RegionPool) => IndexWriter = + _makeIndexWriter(_, theHailClassLoaderForSparkWorkers, SparkTaskContext.get(), _) val partDigits = digitsNeeded(nPartitions) - val fileDigits = digitsNeeded(rvds.length) for (i <- 0 until nRVDs) { val path = paths(i) fs.mkDir(path + "/rows/rows/parts") @@ -1303,35 +1413,38 @@ object RVD { fs.mkDir(path + "/index") } - val partF = { (originIdx: Int, originPartIdx: Int, it: Iterator[RVDContext => Iterator[Long]]) => - Iterator.single { ctx: RVDContext => - val fullPath = paths(originIdx) - val fileData = RichContextRDDRegionValue.writeSplitRegion( - localTmpdir, - fsBc.value, - fullPath, - localTyp, - singletonElement(it)(ctx), - originPartIdx, - ctx, - partDigits, - stageLocally, - makeIndexWriter, - os => makeRowsEnc(os, theHailClassLoaderForSparkWorkers), - os => makeEntriesEnc(os, theHailClassLoaderForSparkWorkers)) - Iterator.single((fileData, originIdx)) - } + val partF = { + (originIdx: Int, originPartIdx: Int, it: Iterator[RVDContext => Iterator[Long]]) => + Iterator.single { ctx: RVDContext => + val fullPath = paths(originIdx) + val fileData = RichContextRDDRegionValue.writeSplitRegion( + localTmpdir, + fsBc.value, + fullPath, + localTyp, + singletonElement(it)(ctx), + originPartIdx, + ctx, + partDigits, + stageLocally, + makeIndexWriter, + os => makeRowsEnc(os, theHailClassLoaderForSparkWorkers), + os => makeEntriesEnc(os, theHailClassLoaderForSparkWorkers), + ) + Iterator.single((fileData, originIdx)) + } } val partFilePartitionCounts = execCtx.timer.time("writeOriginUnionRDD")(new ContextRDD( - new OriginUnionRDD(first.crdd.rdd.sparkContext, rvds.map(_.crdd.rdd), partF)) + new OriginUnionRDD(first.crdd.rdd.sparkContext, rvds.map(_.crdd.rdd), partF) + ) .collect()) - val fileDataByOrigin = 
Array.fill[BoxedArrayBuilder[FileWriteMetadata]](nRVDs)(new BoxedArrayBuilder()) + val fileDataByOrigin = + Array.fill[BoxedArrayBuilder[FileWriteMetadata]](nRVDs)(new BoxedArrayBuilder()) - for ((fd, oidx) <- partFilePartitionCounts) { + for ((fd, oidx) <- partFilePartitionCounts) fileDataByOrigin(oidx) += fd - } val fileData = fileDataByOrigin.map(_.result()) @@ -1340,12 +1453,22 @@ object RVD { .par .foreach { case (partFiles, i) => val fs = fsBc.value - val s = StringUtils.leftPad(i.toString, fileDigits, '0') val basePath = paths(i) - RichContextRDDRegionValue.writeSplitSpecs(fs, basePath, - rowsCodecSpec, entriesCodecSpec, rowsIndexSpec, entriesIndexSpec, - localTyp, rowsRVType, entriesRVType, partFiles.map(_.path), partitionerBc.value) - }) + RichContextRDDRegionValue.writeSplitSpecs( + fs, + basePath, + rowsCodecSpec, + entriesCodecSpec, + rowsIndexSpec, + entriesIndexSpec, + localTyp, + rowsRVType, + entriesRVType, + partFiles.map(_.path), + partitionerBc.value, + ) + } + ) fileData } diff --git a/hail/src/main/scala/is/hail/rvd/RVDContext.scala b/hail/src/main/scala/is/hail/rvd/RVDContext.scala index c2a4364424f..a8b2faf10ec 100644 --- a/hail/src/main/scala/is/hail/rvd/RVDContext.scala +++ b/hail/src/main/scala/is/hail/rvd/RVDContext.scala @@ -2,15 +2,14 @@ package is.hail.rvd import is.hail.annotations.{Region, RegionPool, RegionValueBuilder} import is.hail.backend.HailStateManager -import is.hail.utils._ import scala.collection.mutable object RVDContext { def default(pool: RegionPool) = { - val partRegion = Region(pool=pool) - val ctx = new RVDContext(partRegion, Region(pool=pool)) + val partRegion = Region(pool = pool) + val ctx = new RVDContext(partRegion, Region(pool = pool)) ctx.own(partRegion) ctx } @@ -20,13 +19,14 @@ class RVDContext(val partitionRegion: Region, val r: Region) extends AutoCloseab private[this] val children = new mutable.HashSet[AutoCloseable]() private def own(child: AutoCloseable): Unit = children += child + private[this] def disown(child: AutoCloseable): Unit = assert(children.remove(child)) own(r) def freshContext(): RVDContext = { - val ctx = new RVDContext(partitionRegion, Region(pool=r.pool)) + val ctx = new RVDContext(partitionRegion, Region(pool = r.pool)) own(ctx) ctx } @@ -46,9 +46,9 @@ class RVDContext(val partitionRegion: Region, val r: Region) extends AutoCloseab def close(): Unit = { var e: Exception = null children.foreach { child => - try { + try child.close() - } catch { + catch { case e2: Exception => if (e == null) e = e2 diff --git a/hail/src/main/scala/is/hail/rvd/RVDPartitionInfo.scala b/hail/src/main/scala/is/hail/rvd/RVDPartitionInfo.scala index bf68f984544..a7158e19076 100644 --- a/hail/src/main/scala/is/hail/rvd/RVDPartitionInfo.scala +++ b/hail/src/main/scala/is/hail/rvd/RVDPartitionInfo.scala @@ -1,12 +1,12 @@ package is.hail.rvd -import net.sourceforge.jdistlib.rng.MersenneTwister - import is.hail.annotations.{Region, RegionValue, SafeRow, WritableRegionValue} import is.hail.backend.HailStateManager import is.hail.types.virtual.Type import is.hail.utils._ +import net.sourceforge.jdistlib.rng.MersenneTwister + case class RVDPartitionInfo( partitionIndex: Int, size: Long, @@ -15,13 +15,12 @@ case class RVDPartitionInfo( // min, max: RegionValue[kType] samples: Array[Any], sortedness: Int, - contextStr: String + contextStr: String, ) { val interval = Interval(min, max, true, true) - def pretty(t: Type): String = { + def pretty(t: Type): String = 
s"partitionIndex=$partitionIndex,size=$size,min=$min,max=$max,samples=${samples.mkString(",")},sortedness=$sortedness" - } } object RVDPartitionInfo { @@ -37,7 +36,7 @@ object RVDPartitionInfo { partitionIndex: Int, it: Iterator[Long], seed: Long, - producerContext: RVDContext + producerContext: RVDContext, ): RVDPartitionInfo = { using(RVDContext.default(producerContext.r.pool)) { localctx => val kPType = typ.kType @@ -108,11 +107,15 @@ object RVDPartitionInfo { val safe: RegionValue => Any = SafeRow(kPType, _) - RVDPartitionInfo(partitionIndex, i, - safe(minF.value), safe(maxF.value), + RVDPartitionInfo( + partitionIndex, + i, + safe(minF.value), + safe(maxF.value), Array.tabulate[Any](math.min(i, sampleSize).toInt)(i => safe(samples(i).value)), sortedness, - contextStr) + contextStr, + ) } } } diff --git a/hail/src/main/scala/is/hail/rvd/RVDPartitioner.scala b/hail/src/main/scala/is/hail/rvd/RVDPartitioner.scala index 7d73d55c57f..e8902fb8c31 100644 --- a/hail/src/main/scala/is/hail/rvd/RVDPartitioner.scala +++ b/hail/src/main/scala/is/hail/rvd/RVDPartitioner.scala @@ -5,10 +5,11 @@ import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.expr.ir.Literal import is.hail.types.virtual._ import is.hail.utils._ + import org.apache.commons.lang3.builder.HashCodeBuilder +import org.apache.spark.{Partitioner, SparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.Row -import org.apache.spark.{Partitioner, SparkContext} class RVDPartitioner( val sm: HailStateManager, @@ -16,7 +17,7 @@ class RVDPartitioner( // rangeBounds: Array[Interval[kType]] // rangeBounds is interval containing all keys within a partition val rangeBounds: Array[Interval], - allowedOverlap: Int + allowedOverlap: Int, ) { // expensive, for debugging // assert(rangeBounds.forall(SafeRow.isSafe)) @@ -28,32 +29,33 @@ class RVDPartitioner( sm: HailStateManager, kType: TStruct, rangeBounds: IndexedSeq[Interval], - allowedOverlap: Int + allowedOverlap: Int, ) = this(sm, kType, rangeBounds.toArray, allowedOverlap) def this( sm: HailStateManager, kType: TStruct, - rangeBounds: IndexedSeq[Interval] + rangeBounds: IndexedSeq[Interval], ) = this(sm, kType, rangeBounds.toArray, kType.size) def this( sm: HailStateManager, partitionKey: Array[String], kType: TStruct, - rangeBounds: IndexedSeq[Interval] + rangeBounds: IndexedSeq[Interval], ) = this(sm, kType, rangeBounds.toArray, math.max(partitionKey.length - 1, 0)) def this( sm: HailStateManager, partitionKey: Option[Int], kType: TStruct, - rangeBounds: IndexedSeq[Interval] + rangeBounds: IndexedSeq[Interval], ) = this(sm, kType, rangeBounds.toArray, partitionKey.map(_ - 1).getOrElse(kType.size)) require(rangeBounds.forall { case Interval(l, r, _, _) => kType.relaxedTypeCheck(l) && kType.relaxedTypeCheck(r) }) + require(allowedOverlap >= 0 && allowedOverlap <= kType.size) require(RVDPartitioner.isValid(sm, kType, rangeBounds, allowedOverlap)) @@ -69,7 +71,12 @@ class RVDPartitioner( Some(Interval(rangeBounds.head.left, rangeBounds.last.right)) def satisfiesAllowedOverlap(testAllowedOverlap: Int): Boolean = - (testAllowedOverlap >= kType.size) || RVDPartitioner.isValid(sm, kType, rangeBounds, testAllowedOverlap) + (testAllowedOverlap >= kType.size) || RVDPartitioner.isValid( + sm, + kType, + rangeBounds, + testAllowedOverlap, + ) def isStrict: Boolean = satisfiesAllowedOverlap(kType.size - 1) @@ -126,17 +133,17 @@ class RVDPartitioner( sm, kType.truncate(newKeyLen), coarsenedRangeBounds(newKeyLen), - math.min(allowedOverlap, newKeyLen)) + 
math.min(allowedOverlap, newKeyLen), + ) } } - def strictify(allowedOverlap: Int = kType.size - 1): RVDPartitioner = { + def strictify(allowedOverlap: Int = kType.size - 1): RVDPartitioner = if (satisfiesAllowedOverlap(allowedOverlap)) this else - coarsen(allowedOverlap+1) + coarsen(allowedOverlap + 1) .extendKey(kType) - } // Adjusts 'rangeBounds' so that 'satisfiesAllowedOverlap(kType.size - 1)' // holds, then changes key type to 'newKType'. If 'newKType' is 'kType', still @@ -152,14 +159,15 @@ class RVDPartitioner( sm, newKType, rangeBounds, - allowedOverlap) + allowedOverlap, + ) } // Operators (produce new partitioners) def subdivide( cutPoints: IndexedSeq[IntervalEndpoint], - allowedOverlap: Int = kType.size + allowedOverlap: Int = kType.size, ): RVDPartitioner = { require(cutPoints.forall { case IntervalEndpoint(row, _) => kType.relaxedTypeCheck(row) @@ -192,29 +200,33 @@ class RVDPartitioner( } def intersect(other: RVDPartitioner): RVDPartitioner = { - if (!kType.isIsomorphicTo(other.kType)) + if (!kType.isJoinableWith(other.kType)) throw new AssertionError(s"key types not isomorphic: $kType, ${other.kType}") - new RVDPartitioner(sm, kType, Interval.intersection(this.rangeBounds, other.rangeBounds, kord.intervalEndpointOrdering)) + new RVDPartitioner( + sm, + kType, + Interval.intersection(this.rangeBounds, other.rangeBounds, kord.intervalEndpointOrdering), + ) } def rename(nameMap: Map[String, String]): RVDPartitioner = new RVDPartitioner( sm, kType.rename(nameMap), rangeBounds, - allowedOverlap + allowedOverlap, ) def copy( kType: TStruct = kType, rangeBounds: IndexedSeq[Interval] = rangeBounds, - allowedOverlap: Int = allowedOverlap + allowedOverlap: Int = allowedOverlap, ): RVDPartitioner = new RVDPartitioner(sm, kType, rangeBounds, allowedOverlap) def coalesceRangeBounds(newPartEnd: IndexedSeq[Int]): RVDPartitioner = { val newRangeBounds = (-1 +: newPartEnd.init).zip(newPartEnd).map { case (s, e) => - rangeBounds(s+1).hull(kord, rangeBounds(e)) + rangeBounds(s + 1).hull(kord, rangeBounds(e)) } copy(rangeBounds = newRangeBounds) } @@ -224,36 +236,31 @@ class RVDPartitioner( def contains(index: Int, key: Any): Boolean = rangeBounds(index).contains(kord, key) - /** Returns 0 <= i <= numPartitions such that partition i is the first which - * either contains 'key' or is above 'key', returning numPartitions if 'key' - * is above all partitions. + /** Returns 0 <= i <= numPartitions such that partition i is the first which either contains 'key' + * or is above 'key', returning numPartitions if 'key' is above all partitions. * - * 'key' may be either a Row or an IntervalEndpoint. In the latter case, - * returns the ID of the first partition which overlaps the interval with - * left endpoint 'key' and unbounded right endpoint, or numPartitions if - * none do. + * 'key' may be either a Row or an IntervalEndpoint. In the latter case, returns the ID of the + * first partition which overlaps the interval with left endpoint 'key' and unbounded right + * endpoint, or numPartitions if none do. */ def lowerBound(key: Any): Int = rangeBounds.lowerBound(key, intervalKeyLT) - /** Returns 0 <= i <= numPartitions such that partition i is the first which - * is above 'key', returning numPartitions if 'key' is above all partitions. + /** Returns 0 <= i <= numPartitions such that partition i is the first which is above 'key', + * returning numPartitions if 'key' is above all partitions. * - * 'key' may be either a Row or an IntervalEndpoint. 
In the latter case, - * returns the ID of the first partition which is completely above the - * interval with right endpoint 'key' and unbounded left endpoint, or - * numPartitions if none are. + * 'key' may be either a Row or an IntervalEndpoint. In the latter case, returns the ID of the + * first partition which is completely above the interval with right endpoint 'key' and unbounded + * left endpoint, or numPartitions if none are. */ def upperBound(key: Any): Int = rangeBounds.upperBound(key, keyIntervalLT) /** Returns (lowerBound, upperBound). Interesting cases are: - * - partitioner contains 'key': - * [lowerBound, upperBound) is the range of partition IDs containing 'key'. - * - 'key' falls in the gap between two partitions: - * lowerBound = upperBound is the ID of the first partition above 'key'. - * - 'key' is below the first partition (or numPartitions = 0): - * lowerBound = upperBound = 0 - * - 'key' is above the last partition: - * lowerBound = upperBound = numPartitions + * - partitioner contains 'key': [lowerBound, upperBound) is the range of partition IDs + * containing 'key'. + * - 'key' falls in the gap between two partitions: lowerBound = upperBound is the ID of the + * first partition above 'key'. + * - 'key' is below the first partition (or numPartitions = 0): lowerBound = upperBound = 0 + * - 'key' is above the last partition: lowerBound = upperBound = numPartitions */ def keyRange(key: Any): (Int, Int) = rangeBounds.equalRange(key, intervalKeyLT, keyIntervalLT) @@ -266,27 +273,25 @@ class RVDPartitioner( // Interval queries - /** Returns 0 <= i <= numPartitions such that partition i is the first which - * either overlaps 'query' or is above 'query', returning numPartitions if - * 'query' is completely above all partitions. + /** Returns 0 <= i <= numPartitions such that partition i is the first which either overlaps + * 'query' or is above 'query', returning numPartitions if 'query' is completely above all + * partitions. */ def lowerBoundInterval(query: Interval): Int = rangeBounds.lowerBound(query, intervalLT) - /** Returns 0 <= i <= numPartitions such that partition i is the first which - * is above 'query', returning numPartitions if 'query' is completely above - * or overlaps all partitions. + /** Returns 0 <= i <= numPartitions such that partition i is the first which is above 'query', + * returning numPartitions if 'query' is completely above or overlaps all partitions. */ def upperBoundInterval(query: Interval): Int = rangeBounds.upperBound(query, intervalLT) /** Returns (lowerBound, upperBound). Interesting cases are: - * - partitioner overlaps 'query': - * [lowerBound, upperBound) is the range of partition IDs overlapping 'query'. - * - 'query' falls in the gap between two partitions: - * lowerBound = upperBound is the ID of the first partition above 'query'. - * - 'query' is completely below the first partition (or numPartitions = 0): - * lowerBound = upperBound = 0 - * - 'query' is completely above the last partition: - * lowerBound = upperBound = numPartitions + * - partitioner overlaps 'query': [lowerBound, upperBound) is the range of partition IDs + * overlapping 'query'. + * - 'query' falls in the gap between two partitions: lowerBound = upperBound is the ID of the + * first partition above 'query'. 
+ * - 'query' is completely below the first partition (or numPartitions = 0): lowerBound = + * upperBound = 0 + * - 'query' is completely above the last partition: lowerBound = upperBound = numPartitions */ def intervalRange(query: Interval): (Int, Int) = rangeBounds.equalRange(query, intervalLT) @@ -299,20 +304,19 @@ class RVDPartitioner( def isDisjointFrom(query: Interval): Boolean = !overlaps(query) - def partitionBoundsIRRepresentation: Literal = { - Literal(TArray(RVDPartitioner.intervalIRRepresentation(kType)), - rangeBounds.map(i => RVDPartitioner.intervalToIRRepresentation(i, kType.size)).toFastSeq) - } + def partitionBoundsIRRepresentation: Literal = + Literal( + TArray(RVDPartitioner.intervalIRRepresentation(kType)), + rangeBounds.map(i => RVDPartitioner.intervalToIRRepresentation(i, kType.size)).toFastSeq, + ) } object RVDPartitioner { - def empty(ctx: ExecuteContext, typ: TStruct): RVDPartitioner = { + def empty(ctx: ExecuteContext, typ: TStruct): RVDPartitioner = RVDPartitioner.empty(ctx.stateManager, typ) - } - def empty(sm: HailStateManager, typ: TStruct): RVDPartitioner = { + def empty(sm: HailStateManager, typ: TStruct): RVDPartitioner = new RVDPartitioner(sm, typ, Array.empty[Interval]) - } def unkeyed(sm: HailStateManager, numPartitions: Int): RVDPartitioner = { val unkeyedInterval = Interval(Row(), Row(), true, true) @@ -320,17 +324,19 @@ object RVDPartitioner { sm, TStruct.empty, Array.fill(numPartitions)(unkeyedInterval), - 0) + 0, + ) } - def generate(sm: HailStateManager, kType: TStruct, intervals: IndexedSeq[Interval]): RVDPartitioner = + def generate(sm: HailStateManager, kType: TStruct, intervals: IndexedSeq[Interval]) + : RVDPartitioner = generate(sm, kType.fieldNames, kType, intervals) def generate( sm: HailStateManager, partitionKey: IndexedSeq[String], kType: TStruct, - intervals: IndexedSeq[Interval] + intervals: IndexedSeq[Interval], ): RVDPartitioner = { require(intervals.forall { case Interval(l, r, _, _) => kType.relaxedTypeCheck(l) && kType.relaxedTypeCheck(r) @@ -344,7 +350,7 @@ object RVDPartitioner { sm: HailStateManager, kType: TStruct, intervals: IndexedSeq[Interval], - allowedOverlap: Int + allowedOverlap: Int, ): RVDPartitioner = { val kord = PartitionBoundOrdering(sm, kType) val eord = kord.intervalEndpointOrdering @@ -380,7 +386,7 @@ object RVDPartitioner { max: Any, keys: IndexedSeq[Any], nPartitions: Int, - partitionKey: Int + partitionKey: Int, ): RVDPartitioner = { require(nPartitions > 0) require(typ.kType.virtualType.relaxedTypeCheck(min)) @@ -398,7 +404,7 @@ object RVDPartitioner { new RVDPartitioner( ctx.stateManager, typ.kType.virtualType, - FastSeq(interval) + FastSeq(interval), ).subdivide(partitionEdges, math.max(partitionKey - 1, 0)) } @@ -412,12 +418,18 @@ object RVDPartitioner { allowedOverlap: Int, ): Boolean = { rangeBounds.isEmpty || - rangeBounds.zip(rangeBounds.tail).forall { case (left: Interval, right: Interval) => - val r = PartitionBoundOrdering(sm, kType).intervalEndpointOrdering.lteqWithOverlap(allowedOverlap)(left.right, right.left) - if (!r) - log.info(s"invalid partitioner: !lteqWithOverlap($allowedOverlap)(${ left }.right, ${ right }.left)") - r - } + rangeBounds.zip(rangeBounds.tail).forall { case (left: Interval, right: Interval) => + val r = + PartitionBoundOrdering(sm, kType).intervalEndpointOrdering.lteqWithOverlap(allowedOverlap)( + left.right, + right.left, + ) + if (!r) + log.info( + s"invalid partitioner: !lteqWithOverlap($allowedOverlap)($left.right, $right.left)" + ) + r + } } def 
intervalIRRepresentation(ts: TStruct): TInterval = @@ -426,7 +438,8 @@ object RVDPartitioner { def intervalToIRRepresentation(interval: Interval, len: Int): Interval = { def processEndpoint(p: IntervalEndpoint): IntervalEndpoint = { val r = p.point.asInstanceOf[Row] - val newr = Row(Row.fromSeq((0 until len).map(i => if (i >= r.length) null else r.get(i))), r.length) + val newr = + Row(Row.fromSeq((0 until len).map(i => if (i >= r.length) null else r.get(i))), r.length) p.copy(point = newr) } diff --git a/hail/src/main/scala/is/hail/rvd/RVDType.scala b/hail/src/main/scala/is/hail/rvd/RVDType.scala index e08ff0a97fe..9b8b9527c60 100644 --- a/hail/src/main/scala/is/hail/rvd/RVDType.scala +++ b/hail/src/main/scala/is/hail/rvd/RVDType.scala @@ -1,22 +1,26 @@ package is.hail.rvd import is.hail.annotations._ -import is.hail.backend.{ExecuteContext, HailStateManager} +import is.hail.backend.HailStateManager import is.hail.expr.ir.IRParser -import is.hail.types.physical.{PInterval, PStruct, PType} -import is.hail.types.virtual.TStruct +import is.hail.types.physical.{PInterval, PStruct} import is.hail.utils._ + import org.json4s.CustomSerializer import org.json4s.JsonAST.{JArray, JObject, JString, JValue} -class RVDTypeSerializer extends CustomSerializer[RVDType](format => ( { - case JString(s) => IRParser.parseRVDType(s) -}, { - case rvdType: RVDType => JString(rvdType.toString) -})) - -final case class RVDType(rowType: PStruct, key: IndexedSeq[String]) - extends Serializable { +class RVDTypeSerializer extends CustomSerializer[RVDType](format => + ( + { + case JString(s) => IRParser.parseRVDType(s) + }, + { + case rvdType: RVDType => JString(rvdType.toString) + }, + ) + ) + +final case class RVDType(rowType: PStruct, key: IndexedSeq[String]) extends Serializable { require(rowType.required, rowType) val keySet: Set[String] = key.toSet @@ -25,6 +29,7 @@ final case class RVDType(rowType: PStruct, key: IndexedSeq[String]) val valueType: PStruct = rowType.dropFields(keySet) val kFieldIdx: Array[Int] = key.map(n => rowType.fieldIdx(n)).toArray + val valueFieldIdx: Array[Int] = (0 until rowType.size) .filter(i => !keySet.contains(rowType.fields(i).name)) .toArray @@ -34,11 +39,14 @@ final case class RVDType(rowType: PStruct, key: IndexedSeq[String]) @transient private var _kOrd: UnsafeOrdering = _ def kInRowOrd(sm: HailStateManager): UnsafeOrdering = { - if (_kInRowOrd == null) _kInRowOrd = RVDType.selectUnsafeOrdering(sm, rowType, kFieldIdx, rowType, kFieldIdx) + if (_kInRowOrd == null) + _kInRowOrd = RVDType.selectUnsafeOrdering(sm, rowType, kFieldIdx, rowType, kFieldIdx) _kInRowOrd } + def kRowOrd(sm: HailStateManager): UnsafeOrdering = { - if (_kRowOrd == null) _kRowOrd = RVDType.selectUnsafeOrdering(sm, kType, Array.range(0, kType.size), rowType, kFieldIdx) + if (_kRowOrd == null) _kRowOrd = + RVDType.selectUnsafeOrdering(sm, kType, Array.range(0, kType.size), rowType, kFieldIdx) _kRowOrd } @@ -54,7 +62,8 @@ final case class RVDType(rowType: PStruct, key: IndexedSeq[String]) this.kFieldIdx, other.rowType, other.kFieldIdx, - true) + true, + ) def joinComp(sm: HailStateManager, other: RVDType): UnsafeOrdering = RVDType.selectUnsafeOrdering( @@ -63,14 +72,16 @@ final case class RVDType(rowType: PStruct, key: IndexedSeq[String]) this.kFieldIdx, other.rowType, other.kFieldIdx, - false) + false, + ) - /** Comparison of a point with an interval, for use in joins where one side - * is keyed by intervals. + /** Comparison of a point with an interval, for use in joins where one side is keyed by intervals. 
*/ def intervalJoinComp(sm: HailStateManager, other: RVDType): UnsafeOrdering = { require(other.key.length == 1) - require(other.rowType.field(other.key(0)).typ.asInstanceOf[PInterval].pointType.virtualType == rowType.field(key(0)).typ.virtualType) + require(other.rowType.field(other.key(0)).typ.asInstanceOf[ + PInterval + ].pointType.virtualType == rowType.field(key(0)).typ.virtualType) new UnsafeOrdering { val t1 = rowType @@ -111,13 +122,13 @@ final case class RVDType(rowType: PStruct, key: IndexedSeq[String]) } } - def kRowOrdView(sm: HailStateManager, region: Region) = new OrderingView[RegionValue] { val wrv = WritableRegionValue(sm, kType, region) val kRowOrdering = kRowOrd(sm) - def setFiniteValue(representative: RegionValue) { + + def setFiniteValue(representative: RegionValue): Unit = wrv.setSelect(rowType, kFieldIdx, representative) - } + def compareFinite(rv: RegionValue): Int = kRowOrdering.compare(wrv.value, rv) } @@ -126,7 +137,8 @@ final case class RVDType(rowType: PStruct, key: IndexedSeq[String]) JObject(List( "partitionKey" -> JArray(key.map(JString).toList), "key" -> JArray(key.map(JString).toList), - "rowType" -> JString(rowType.toString))) + "rowType" -> JString(rowType.toString), + )) override def toString: String = { val sb = new StringBuilder() @@ -144,9 +156,11 @@ final case class RVDType(rowType: PStruct, key: IndexedSeq[String]) object RVDType { def selectUnsafeOrdering( sm: HailStateManager, - t1: PStruct, fields1: Array[Int], - t2: PStruct, fields2: Array[Int], - missingEqual: Boolean=true + t1: PStruct, + fields1: Array[Int], + t2: PStruct, + fields2: Array[Int], + missingEqual: Boolean = true, ): UnsafeOrdering = { val fieldOrderings = Range(0, fields1.length).map { i => t1.types(fields1(i)).unsafeOrdering(sm, t2.types(fields2(i))) @@ -156,10 +170,12 @@ object RVDType { } def selectUnsafeOrdering( - t1: PStruct, fields1: Array[Int], - t2: PStruct, fields2: Array[Int], + t1: PStruct, + fields1: Array[Int], + t2: PStruct, + fields2: Array[Int], fieldOrderings: Array[UnsafeOrdering], - missingEqual: Boolean + missingEqual: Boolean, ): UnsafeOrdering = { require(fields1.length == fields2.length) require((fields1, fields2).zipped.forall { case (f1, f2) => @@ -171,7 +187,7 @@ object RVDType { new UnsafeOrdering { def compare(o1: Long, o2: Long): Int = { var i = 0 - var hasMissing=false + var hasMissing = false while (i < nFields) { val f1 = fields1(i) val f2 = fields2(i) diff --git a/hail/src/main/scala/is/hail/services/BatchConfig.scala b/hail/src/main/scala/is/hail/services/BatchConfig.scala index 6ce643b8282..661bc94e638 100644 --- a/hail/src/main/scala/is/hail/services/BatchConfig.scala +++ b/hail/src/main/scala/is/hail/services/BatchConfig.scala @@ -1,24 +1,19 @@ package is.hail.services +import is.hail.utils._ + import java.io.{File, FileInputStream} -import is.hail.utils._ import org.json4s._ import org.json4s.jackson.JsonMethods -import org.apache.log4j.Logger object BatchConfig { - private[this] val log = Logger.getLogger("BatchConfig") - - def fromConfigFile(file: String): Option[BatchConfig] = { + def fromConfigFile(file: String): Option[BatchConfig] = if (new File(file).exists()) { - using(new FileInputStream(file)) { in => - Some(fromConfig(JsonMethods.parse(in))) - } + using(new FileInputStream(file))(in => Some(fromConfig(JsonMethods.parse(in)))) } else { None } - } def fromConfig(config: JValue): BatchConfig = { implicit val formats: Formats = DefaultFormats diff --git a/hail/src/main/scala/is/hail/services/DeployConfig.scala 
b/hail/src/main/scala/is/hail/services/DeployConfig.scala index 0b5047b868d..95ca0e01ece 100644 --- a/hail/src/main/scala/is/hail/services/DeployConfig.scala +++ b/hail/src/main/scala/is/hail/services/DeployConfig.scala @@ -1,26 +1,18 @@ package is.hail.services +import is.hail.utils._ + import java.io.{File, FileInputStream} -import java.net._ -import is.hail.utils._ -import is.hail.services.tls._ import org.json4s._ import org.json4s.jackson.JsonMethods -import org.apache.http.client.methods._ -import org.apache.log4j.Logger - -import scala.util.Random object DeployConfig { - private[this] val log = Logger.getLogger("DeployConfig") - private[this] lazy val default: DeployConfig = fromConfigFile() private[this] var _get: DeployConfig = null - def set(x: DeployConfig) = { + def set(x: DeployConfig) = _get = x - } def get(): DeployConfig = { if (_get == null) { @@ -36,7 +28,7 @@ object DeployConfig { file = System.getenv("HAIL_DEPLOY_CONFIG_FILE") if (file == null) { - val fromHome = s"${ System.getenv("HOME") }/.hail/deploy-config.json" + val fromHome = s"${System.getenv("HOME")}/.hail/deploy-config.json" if (new File(fromHome).exists()) file = fromHome } @@ -48,11 +40,9 @@ object DeployConfig { } if (file != null) { - using(new FileInputStream(file)) { in => - fromConfig(JsonMethods.parse(in)) - } + using(new FileInputStream(file))(in => fromConfig(JsonMethods.parse(in))) } else - fromConfig("external", "default", "hail.is") + fromConfig("external", "default", "hail.is", None) } def fromConfig(config: JValue): DeployConfig = { @@ -60,65 +50,70 @@ object DeployConfig { fromConfig( (config \ "location").extract[String], (config \ "default_namespace").extract[String], - (config \ "domain").extract[Option[String]].getOrElse("hail.is")) + (config \ "domain").extract[Option[String]].getOrElse("hail.is"), + (config \ "base_path").extract[Option[String]], + ) } - def fromConfig(location: String, defaultNamespace: String, domain: String): DeployConfig = { - new DeployConfig( - sys.env.getOrElse(toEnvVarName("location"), location), - sys.env.getOrElse(toEnvVarName("default_namespace"), defaultNamespace), - sys.env.getOrElse(toEnvVarName("domain"), domain)) + def fromConfig( + locationFromConfig: String, + defaultNamespaceFromConfig: String, + domainFromConfig: String, + basePathFromConfig: Option[String], + ): DeployConfig = { + val location = sys.env.getOrElse(toEnvVarName("location"), locationFromConfig) + val defaultNamespace = + sys.env.getOrElse(toEnvVarName("default_namespace"), defaultNamespaceFromConfig) + val domain = sys.env.getOrElse(toEnvVarName("domain"), domainFromConfig) + val basePath = sys.env.get(toEnvVarName("basePath")).orElse(basePathFromConfig) + + (basePath, defaultNamespace) match { + case (None, ns) if ns != "default" => + new DeployConfig(location, ns, s"internal.$domain", Some(s"/$ns")) + case _ => new DeployConfig(location, defaultNamespace, domain, basePath) + } } - private[this] def toEnvVarName(s: String): String = { + private[this] def toEnvVarName(s: String): String = "HAIL_" + s.toUpperCase - } } class DeployConfig( val location: String, val defaultNamespace: String, - val domain: String) { - import DeployConfig._ + val domain: String, + val basePath: Option[String], +) { - def scheme(baseScheme: String = "http"): String = { + def scheme(baseScheme: String = "http"): String = if (location == "external" || location == "k8s") baseScheme + "s" else baseScheme - } - - def getServiceNamespace(service: String): String = { - defaultNamespace - } def domain(service: 
String): String = { - val ns = getServiceNamespace(service) location match { case "k8s" => - s"$service.$ns" + s"$service.$defaultNamespace" case "gce" => - if (ns == "default") + if (basePath.isEmpty) s"$service.hail" else "internal.hail" case "external" => - if (ns == "default") + if (basePath.isEmpty) s"$service.$domain" else - s"internal.$domain" + domain } } - def basePath(service: String): String = { - val ns = getServiceNamespace(service) - if (ns == "default") - "" - else - s"/$ns/$service" - } + def basePath(service: String): String = + basePath match { + case Some(base) => s"$base/$service" + case None => "" + } - def baseUrl(service: String, baseScheme: String = "http"): String = { - s"${ scheme(baseScheme) }://${ domain(service) }${ basePath(service) }" - } + def baseUrl(service: String, baseScheme: String = "http"): String = + s"${scheme(baseScheme)}://${domain(service)}${basePath(service)}" } diff --git a/hail/src/main/scala/is/hail/services/JSONLogLayout.scala b/hail/src/main/scala/is/hail/services/JSONLogLayout.scala index 177899f53a1..fd4de60f42d 100644 --- a/hail/src/main/scala/is/hail/services/JSONLogLayout.scala +++ b/hail/src/main/scala/is/hail/services/JSONLogLayout.scala @@ -1,15 +1,15 @@ package is.hail.services +import scala.collection.mutable.ArrayBuffer + +import java.io.StringWriter import java.text._ import java.util.function._ -import java.nio.charset.StandardCharsets -import java.io.StringWriter -import org.json4s._ import org.apache.log4j._ import org.apache.log4j.spi._ +import org.json4s._ import org.json4s.jackson.JsonMethods -import scala.collection.mutable.ArrayBuffer class DateFormatter { private[this] val fmt = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS") @@ -54,22 +54,28 @@ class JSONLogLayout extends Layout { fields += JField("logger_name", JString(event.getLoggerName())) val mdcFields = new ArrayBuffer[JField]() - mdc.forEach(new BiConsumer[Any, Any]() { def accept(key: Any, value: Any): Unit = { - mdcFields += JField(key.toString, JString(value.toString)) - } }) - fields += JField("mdc", JObject(mdcFields:_*)) + mdc.forEach(new BiConsumer[Any, Any]() { + def accept(key: Any, value: Any): Unit = + mdcFields += JField(key.toString, JString(value.toString)) + }) + fields += JField("mdc", JObject(mdcFields: _*)) fields += JField("ndc", JString(ndc)) fields += JField("severity", JString(event.getLevel().toString())) fields += JField("thread_name", JString(threadName)) if (throwableInfo != null) { - fields += JField("exception_class", JString(throwableInfo.getThrowable().getClass().getCanonicalName())) + fields += JField( + "exception_class", + JString(throwableInfo.getThrowable().getClass().getCanonicalName()), + ) fields += JField("exception_message", JString(throwableInfo.getThrowable().getMessage())) - fields += JField("exception_stacktrace", JString(formatException(throwableInfo.getThrowable()))) + fields += JField( + "exception_stacktrace", + JString(formatException(throwableInfo.getThrowable())), + ) } - val jsonEvent = JObject(fields:_*) - + val jsonEvent = JObject(fields: _*) val sw = new StringWriter() JsonMethods.mapper.writeValue(sw, jsonEvent) diff --git a/hail/src/main/scala/is/hail/services/NettyProxy.scala b/hail/src/main/scala/is/hail/services/NettyProxy.scala index 0d7c43ced88..c88f6754fe3 100644 --- a/hail/src/main/scala/is/hail/services/NettyProxy.scala +++ b/hail/src/main/scala/is/hail/services/NettyProxy.scala @@ -1,8 +1,7 @@ package is.hail.services -import java.io.IOException import io.netty.channel.epoll.Epoll -import 
io.netty.channel.unix.Errors // cannot be in package.scala because is.hail.io shadows top-level io +import io.netty.channel.unix.Errors // cannot be in package.scala because is.hail.io shadows top-level io object NettyProxy { val isRetryableNettyIOException: Throwable => Boolean = if (Epoll.isAvailable()) { @@ -15,23 +14,22 @@ object NettyProxy { Errors.ERRNO_EPIPE_NEGATIVE, Errors.ERRNO_ECONNRESET_NEGATIVE, Errors.ERROR_ECONNREFUSED_NEGATIVE, - Errors.ERROR_ENETUNREACH_NEGATIVE + Errors.ERROR_ENETUNREACH_NEGATIVE, ) - { case e: Errors.NativeIoException => + { + case e: Errors.NativeIoException => // NativeIoException is a subclass of IOException; therefore this case must appear before // the IOException case // // expectedErr appears to be the additive inverse of the errno returned by Linux? // - // https://github.com/netty/netty/blob/24a0ac36ea91d1aee647d738f879ac873892d829/transport-native-unix-common/src/main/java/io/netty/channel/unix/Errors.java#L49 + /* https://github.com/netty/netty/blob/24a0ac36ea91d1aee647d738f879ac873892d829/transport-native-unix-common/src/main/java/io/netty/channel/unix/Errors.java#L49 */ (nettyRetryableErrorNumbers.contains(e.expectedErr) || - // io.netty.channel.unix.Errors$NativeIoException: readAddress(..) failed: Connection reset by peer - e.getMessage.contains("Connection reset by peer") - ) - case e: Throwable => false + /* io.netty.channel.unix.Errors$NativeIoException: readAddress(..) failed: Connection reset + * by peer */ + e.getMessage.contains("Connection reset by peer")) + case _: Throwable => false } - } else { - { case e: Throwable => false } - } + } else { case _: Throwable => false } } diff --git a/hail/src/main/scala/is/hail/services/Requester.scala b/hail/src/main/scala/is/hail/services/Requester.scala index 082b49c96df..fcfbd808ba8 100644 --- a/hail/src/main/scala/is/hail/services/Requester.scala +++ b/hail/src/main/scala/is/hail/services/Requester.scala @@ -1,33 +1,26 @@ package is.hail.services -import java.io.InputStream -import java.nio.charset.StandardCharsets - -import is.hail.HailContext -import is.hail.utils._ -import is.hail.services._ -import is.hail.shadedazure.com.azure.identity.{ClientSecretCredential, ClientSecretCredentialBuilder} import is.hail.shadedazure.com.azure.core.credential.TokenRequestContext +import is.hail.shadedazure.com.azure.identity.{ + ClientSecretCredential, ClientSecretCredentialBuilder, +} +import is.hail.utils._ + +import scala.collection.JavaConverters._ + +import java.io.{FileInputStream, InputStream} import com.google.auth.oauth2.ServiceAccountCredentials import org.apache.commons.io.IOUtils import org.apache.http.{HttpEntity, HttpEntityEnclosingRequest} -import org.apache.http.client.methods.{HttpDelete, HttpGet, HttpPatch, HttpPost, HttpUriRequest} -import org.apache.http.entity.{ByteArrayEntity, ContentType, StringEntity} import org.apache.http.client.config.RequestConfig +import org.apache.http.client.methods.HttpUriRequest import org.apache.http.impl.client.{CloseableHttpClient, HttpClients} -import org.apache.http.impl.client.{CloseableHttpClient, HttpClients} -import org.apache.http.impl.conn.PoolingHttpClientConnectionManager import org.apache.http.util.EntityUtils import org.apache.log4j.{LogManager, Logger} -import org.json4s.{DefaultFormats, Formats, JObject, JValue} +import org.json4s.{Formats, JValue} import org.json4s.jackson.JsonMethods -import scala.collection.JavaConverters._ -import scala.util.Random -import java.io.FileInputStream - - abstract class CloudCredentials { def 
accessToken(): String } @@ -46,7 +39,8 @@ class GoogleCloudCredentials(gsaKeyPath: String) extends CloudCredentials { } class AzureCloudCredentials(credentialsPath: String) extends CloudCredentials { - private[this] val credentials: ClientSecretCredential = using(new FileInputStream(credentialsPath)) { is => + private[this] val credentials: ClientSecretCredential = + using(new FileInputStream(credentialsPath)) { is => implicit val formats: Formats = defaultJSONFormats val kvs = JsonMethods.parse(is) val appId = (kvs \ "appId").extract[String] @@ -58,7 +52,7 @@ class AzureCloudCredentials(credentialsPath: String) extends CloudCredentials { .clientSecret(password) .tenantId(tenant) .build() - } + } override def accessToken(): String = { val context = new TokenRequestContext() @@ -70,7 +64,7 @@ class AzureCloudCredentials(credentialsPath: String) extends CloudCredentials { class ClientResponseException( val status: Int, message: String, - cause: Throwable + cause: Throwable, ) extends Exception(message, cause) { def this(statusCode: Int) = this(statusCode, null, null) @@ -95,13 +89,14 @@ object Requester { .setMaxConnTotal(100) .setDefaultRequestConfig(requestConfig) .build() - } catch { case _: NoSSLConfigFound => - log.info("creating HttpClient with no SSL Context") - HttpClients.custom() - .setMaxConnPerRoute(20) - .setMaxConnTotal(100) - .setDefaultRequestConfig(requestConfig) - .build() + } catch { + case _: NoSSLConfigFound => + log.info("creating HttpClient with no SSL Context") + HttpClients.custom() + .setMaxConnPerRoute(20) + .setMaxConnTotal(100) + .setDefaultRequestConfig(requestConfig) + .build() } } @@ -122,8 +117,10 @@ class Requester( val credentials: CloudCredentials ) { import Requester._ - def requestWithHandler[T >: Null](req: HttpUriRequest, body: HttpEntity, f: InputStream => T): T = { - log.info(s"request ${ req.getMethod } ${ req.getURI }") + + def requestWithHandler[T >: Null](req: HttpUriRequest, body: HttpEntity, f: InputStream => T) + : T = { + log.info(s"request ${req.getMethod} ${req.getURI}") if (body != null) req.asInstanceOf[HttpEntityEnclosingRequest].setEntity(body) @@ -134,7 +131,7 @@ class Requester( retryTransientErrors { using(httpClient.execute(req)) { resp => val statusCode = resp.getStatusLine.getStatusCode - log.info(s"request ${ req.getMethod } ${ req.getURI } response $statusCode") + log.info(s"request ${req.getMethod} ${req.getURI} response $statusCode") if (statusCode < 200 || statusCode >= 300) { val entity = resp.getEntity val message = @@ -157,11 +154,15 @@ class Requester( requestWithHandler(req, body, IOUtils.toByteArray) def request(req: HttpUriRequest, body: HttpEntity = null): JValue = - requestWithHandler(req, body, { content => - val s = IOUtils.toByteArray(content) - if (s.isEmpty) - null - else - JsonMethods.parse(new String(s)) - }) + requestWithHandler( + req, + body, + { content => + val s = IOUtils.toByteArray(content) + if (s.isEmpty) + null + else + JsonMethods.parse(new String(s)) + }, + ) } diff --git a/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala b/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala index 0ff28f33193..cb23ecbf852 100644 --- a/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala +++ b/hail/src/main/scala/is/hail/services/batch_client/BatchClient.scala @@ -1,23 +1,20 @@ package is.hail.services.batch_client import is.hail.expr.ir.ByteArrayBuilder +import is.hail.services._ +import is.hail.utils._ + +import scala.util.Random import 
java.nio.charset.StandardCharsets -import is.hail.utils._ -import is.hail.services._ -import is.hail.services.DeployConfig -import org.apache.commons.io.IOUtils -import org.apache.http.{HttpEntity, HttpEntityEnclosingRequest} -import org.apache.http.client.methods.{HttpDelete, HttpGet, HttpPatch, HttpPost, HttpUriRequest} + +import org.apache.http.HttpEntity +import org.apache.http.client.methods.{HttpDelete, HttpGet, HttpPatch, HttpPost} import org.apache.http.entity.{ByteArrayEntity, ContentType, StringEntity} -import org.apache.http.impl.client.{CloseableHttpClient, HttpClients} -import org.apache.http.util.EntityUtils import org.apache.log4j.{LogManager, Logger} import org.json4s.{DefaultFormats, Formats, JInt, JObject, JString, JValue} import org.json4s.jackson.JsonMethods -import scala.util.Random - class NoBodyException(message: String, cause: Throwable) extends Exception(message, cause) { def this() = this(null, null) @@ -30,10 +27,11 @@ object BatchClient { class BatchClient( deployConfig: DeployConfig, - requester: Requester + requester: Requester, ) { - def this(credentialsPath: String) = this(DeployConfig.get, Requester.fromCredentialsFile(credentialsPath)) + def this(credentialsPath: String) = + this(DeployConfig.get, Requester.fromCredentialsFile(credentialsPath)) import BatchClient._ import requester.request @@ -47,13 +45,16 @@ class BatchClient( request(new HttpPost(s"$baseUrl$path"), body = body) def post(path: String, json: JValue = null): JValue = - post(path, + post( + path, if (json != null) new StringEntity( JsonMethods.compact(json), - ContentType.create("application/json")) + ContentType.create("application/json"), + ) else - null) + null, + ) def patch(path: String): JValue = request(new HttpPatch(s"$baseUrl$path")) @@ -76,8 +77,10 @@ class BatchClient( b += '}' val data = b.result() val resp = retryTransientErrors { - post(s"/api/v1alpha/batches/$batchID/update-fast", - new ByteArrayEntity(data, ContentType.create("application/json"))) + post( + s"/api/v1alpha/batches/$batchID/update-fast", + new ByteArrayEntity(data, ContentType.create("application/json")), + ) } b.clear() (resp \ "update_id").extract[Long] @@ -97,7 +100,9 @@ class BatchClient( s"/api/v1alpha/batches/$batchID/updates/$updateID/jobs/create", new ByteArrayEntity( data, - ContentType.create("application/json"))) + ContentType.create("application/json"), + ), + ) } b.clear() i += 1 @@ -126,14 +131,16 @@ class BatchClient( b ++= JsonMethods.compact(batchJson).getBytes(StandardCharsets.UTF_8) b += '}' val data = b.result() - val resp = retryTransientErrors{ - post("/api/v1alpha/batches/create-fast", - new ByteArrayEntity(data, ContentType.create("application/json"))) + val resp = retryTransientErrors { + post( + "/api/v1alpha/batches/create-fast", + new ByteArrayEntity(data, ContentType.create("application/json")), + ) } b.clear() (resp \ "id").extract[Long] } else { - val resp = retryTransientErrors { post("/api/v1alpha/batches/create", json = batchJson) } + val resp = retryTransientErrors(post("/api/v1alpha/batches/create", json = batchJson)) val batchID = (resp \ "id").extract[Long] val b = new ByteArrayBuilder() @@ -147,13 +154,15 @@ class BatchClient( s"/api/v1alpha/batches/$batchID/jobs/create", new ByteArrayEntity( data, - ContentType.create("application/json"))) + ContentType.create("application/json"), + ), + ) } b.clear() i += 1 } - retryTransientErrors { patch(s"/api/v1alpha/batches/$batchID/close") } + retryTransientErrors(patch(s"/api/v1alpha/batches/$batchID/close")) batchID } 
log.info(s"run: created batch $batchID") @@ -168,12 +177,12 @@ class BatchClient( def waitForBatch(batchID: Long, excludeDriverJobInBatch: Boolean): JValue = { implicit val formats: Formats = DefaultFormats - Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms + Thread.sleep(600) // it is not possible for the batch to be finished in less than 600ms val start = System.nanoTime() while (true) { - val batch = retryTransientErrors { get(s"/api/v1alpha/batches/$batchID") } + val batch = retryTransientErrors(get(s"/api/v1alpha/batches/$batchID")) val n_completed = (batch \ "n_completed").extract[Int] val n_jobs = (batch \ "n_jobs").extract[Int] if ((excludeDriverJobInBatch && n_completed == n_jobs - 1) || n_completed == n_jobs) @@ -184,11 +193,13 @@ class BatchClient( // at most, 5s val now = System.nanoTime() val elapsed = now - start - var d = math.max( + val d = math.max( math.min( (0.1 * (0.8 + Random.nextFloat() * 0.4) * (elapsed / 1000.0 / 1000)).toInt, - 5000), - 50) + 5000, + ), + 50, + ) Thread.sleep(d) } @@ -219,7 +230,7 @@ class BatchClient( bunches } - private def addBunchBytes(b: ByteArrayBuilder, bunch: Array[Array[Byte]]) { + private def addBunchBytes(b: ByteArrayBuilder, bunch: Array[Array[Byte]]): Unit = { var j = 0 b += '[' while (j < bunch.length) { diff --git a/hail/src/main/scala/is/hail/services/package.scala b/hail/src/main/scala/is/hail/services/package.scala index ca906da40df..b88c7a2e06b 100644 --- a/hail/src/main/scala/is/hail/services/package.scala +++ b/hail/src/main/scala/is/hail/services/package.scala @@ -1,23 +1,20 @@ package is.hail -import javax.net.ssl.SSLException -import java.net._ -import java.io.EOFException -import java.util.concurrent.TimeoutException +import is.hail.shadedazure.com.azure.storage.common.implementation.Constants import is.hail.utils._ -import org.apache.http.NoHttpResponseException -import org.apache.http.ConnectionClosedException -import org.apache.http.conn.HttpHostConnectException -import org.apache.log4j.{LogManager, Logger} - -import is.hail.shadedazure.reactor.core.Exceptions.ReactiveException -import is.hail.shadedazure.com.azure.storage.common.implementation.Constants import scala.util.Random + import java.io._ -import com.google.cloud.storage.StorageException +import java.net._ + import com.google.api.client.googleapis.json.GoogleJsonResponseException import com.google.api.client.http.HttpResponseException +import com.google.cloud.storage.StorageException +import javax.net.ssl.SSLException +import org.apache.http.{ConnectionClosedException, NoHttpResponseException} +import org.apache.http.conn.HttpHostConnectException +import org.apache.log4j.{LogManager, Logger} package object services { private lazy val log: Logger = LogManager.getLogger("is.hail.services") @@ -30,18 +27,20 @@ package object services { s } - private[this] val LOG_2_MAX_MULTIPLIER = 30 // do not set larger than 30 due to integer overflow calculating multiplier + private[this] val LOG_2_MAX_MULTIPLIER = + 30 // do not set larger than 30 due to integer overflow calculating multiplier private[this] val DEFAULT_MAX_DELAY_MS = 60000 private[this] val DEFAULT_BASE_DELAY_MS = 1000 def delayMsForTry( tries: Int, baseDelayMs: Int = DEFAULT_BASE_DELAY_MS, - maxDelayMs: Int = DEFAULT_MAX_DELAY_MS + maxDelayMs: Int = DEFAULT_MAX_DELAY_MS, ): Int = { // Based on AWS' recommendations: // - https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ - // - 
https://github.com/aws/aws-sdk-java/blob/master/aws-java-sdk-core/src/main/java/com/amazonaws/retry/PredefinedBackoffStrategies.java + /* - + * https://github.com/aws/aws-sdk-java/blob/master/aws-java-sdk-core/src/main/java/com/amazonaws/retry/PredefinedBackoffStrategies.java */ val multiplier = 1L << math.min(tries, LOG_2_MAX_MULTIPLIER) val ceilingForDelayMs = math.min(baseDelayMs * multiplier, maxDelayMs).toInt val proposedDelayMs = ceilingForDelayMs / 2 + Random.nextInt(ceilingForDelayMs / 2 + 1) @@ -51,10 +50,9 @@ package object services { def sleepBeforTry( tries: Int, baseDelayMs: Int = DEFAULT_BASE_DELAY_MS, - maxDelayMs: Int = DEFAULT_MAX_DELAY_MS - ) = { + maxDelayMs: Int = DEFAULT_MAX_DELAY_MS, + ) = Thread.sleep(delayMsForTry(tries, baseDelayMs, maxDelayMs)) - } def isLimitedRetriesError(_e: Throwable): Boolean = { // An exception is a "retry once error" if a rare, known bug in a dependency or in a cloud @@ -76,12 +74,12 @@ package object services { true case e: IOException if e.getMessage != null && e.getMessage.contains("Connection reset by peer") => - // java.io.IOException: Connection reset by peer - // at sun.nio.ch.FileDispatcherImpl.read0(NativeMethod) ~[?:1.8.0_362] - // at sun.nio.ch.SocketDispatcher.read(SocketDispatcher.java:39)~[?:1.8.0_362] - // at sun.nio.ch.IOUtil.readIntoNativeBuffer(IOUtil.java:223)~[?:1.8.0_362] - // at sun.nio.ch.IOUtil.read(IOUtil.java:192) ~[?:1.8.0_362] - // at sun.nio.ch.SocketChannelImpl.read(SocketChannelImpl.java:379) ~[?:1.8.0_362] + // java.io.IOException: Connection reset by peer + // at sun.nio.ch.FileDispatcherImpl.read0(NativeMethod) ~[?:1.8.0_362] + // at sun.nio.ch.SocketDispatcher.read(SocketDispatcher.java:39)~[?:1.8.0_362] + // at sun.nio.ch.IOUtil.readIntoNativeBuffer(IOUtil.java:223)~[?:1.8.0_362] + // at sun.nio.ch.IOUtil.read(IOUtil.java:192) ~[?:1.8.0_362] + // at sun.nio.ch.SocketChannelImpl.read(SocketChannelImpl.java:379) ~[?:1.8.0_362] true case e: SSLException if e.getMessage != null && e.getMessage.contains("Tag mismatch!") => @@ -102,7 +100,7 @@ package object services { // ReactiveException it returns the exception unmodified. 
val e = is.hail.shadedazure.reactor.core.Exceptions.unwrap(_e) e match { - case e: NoHttpResponseException => + case _: NoHttpResponseException => true case e: HttpResponseException if RETRYABLE_HTTP_STATUS_CODES.contains(e.getStatusCode()) => @@ -111,10 +109,10 @@ package object services { if (e.getStatusCode() == 410 && e.getMessage != null && e.getMessage.contains("\"code\": 503,") && - e.getMessage.contains("\"message\": \"Backend Error\",") - ) => + e.getMessage.contains("\"message\": \"Backend Error\",")) => // hail.utils.java.FatalError: HttpResponseException: 410 Gone - // PUT https://storage.googleapis.com/upload/storage/v1/b/hail-test-ezlis/o?name=tmp/hail/nBHPQsrxGvJ4T7Ybdp1IjQ/persist_TableObF6TwC6hv/rows/metadata.json.gz&uploadType=resumable&upload_id=ADPycdsFEtq65NC-ahk6tt6qdD3bKC3asqVSJELnirlpLG_ZDV_637Nn7NourXYTgMRKlX3bQVe9BfD_QfIP_kupTxVQyrJWQJrj + /* PUT + * https://storage.googleapis.com/upload/storage/v1/b/hail-test-ezlis/o?name=tmp/hail/nBHPQsrxGvJ4T7Ybdp1IjQ/persist_TableObF6TwC6hv/rows/metadata.json.gz&uploadType=resumable&upload_id=ADPycdsFEtq65NC-ahk6tt6qdD3bKC3asqVSJELnirlpLG_ZDV_637Nn7NourXYTgMRKlX3bQVe9BfD_QfIP_kupTxVQyrJWQJrj */ // { // "error": { // "code": 503, @@ -135,34 +133,37 @@ package object services { case e: GoogleJsonResponseException if RETRYABLE_HTTP_STATUS_CODES.contains(e.getStatusCode()) => true - case e: HttpHostConnectException => + case _: HttpHostConnectException => true - case e: NoRouteToHostException => + case _: NoRouteToHostException => true - case e: SocketTimeoutException => + case _: SocketTimeoutException => true - case e: java.util.concurrent.TimeoutException => + case _: java.util.concurrent.TimeoutException => true - case e: UnknownHostException => + case _: UnknownHostException => true - case e: ConnectionClosedException => + case _: ConnectionClosedException => true case e: SocketException if e.getMessage != null && ( e.getMessage.contains("Connection timed out (Read failed)") || e.getMessage.contains("Broken pipe") || - e.getMessage.contains("Connection refused")) => + e.getMessage.contains("Connection refused") + ) => true case e: EOFException if e.getMessage != null && e.getMessage.contains("SSL peer shut down incorrectly") => true case e: IllegalStateException - if e.getMessage.contains("Timeout on blocking read") => - // Caused by: java.lang.IllegalStateException: Timeout on blocking read for 30000000000 NANOSECONDS - // reactor.core.publisher.BlockingSingleSubscriber.blockingGet(BlockingSingleSubscriber.java:123) + if e.getMessage.contains("Timeout on blocking read") || + e.getMessage.contains("Faulted stream due to underlying sink write failure") => + /* Caused by: java.lang.IllegalStateException: Timeout on blocking read for 30000000000 + * NANOSECONDS */ + /* reactor.core.publisher.BlockingSingleSubscriber.blockingGet(BlockingSingleSubscriber.java:123) */ // reactor.core.publisher.Mono.block(Mono.java:1727) - // com.azure.storage.common.implementation.StorageImplUtils.blockWithOptionalTimeout(StorageImplUtils.java:130) - // com.azure.storage.blob.specialized.BlobClientBase.downloadStreamWithResponse(BlobClientBase.java:731) + /* com.azure.storage.common.implementation.StorageImplUtils.blockWithOptionalTimeout(StorageImplUtils.java:130) */ + /* com.azure.storage.blob.specialized.BlobClientBase.downloadStreamWithResponse(BlobClientBase.java:731) */ // is.hail.io.fs.AzureStorageFS$$anon$1.fill(AzureStorageFS.scala:152) // is.hail.io.fs.FSSeekableInputStream.read(FS.scala:141) // ... 
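// Editor's illustrative sketch (not part of the patch): the retry pattern the surrounding hunks
// implement -- classify an exception as transient, sleep with a capped, jittered exponential
// backoff, and try again. The default delays mirror DEFAULT_BASE_DELAY_MS / DEFAULT_MAX_DELAY_MS
// above; the names here are simplified stand-ins, not the Hail API.
import scala.util.Random

object RetrySketch {
  def delayMsForTry(tries: Int, baseDelayMs: Int = 1000, maxDelayMs: Int = 60000): Int = {
    val multiplier = 1L << math.min(tries, 30) // shift capped at 30 to avoid Long overflow
    val ceiling = math.min(baseDelayMs * multiplier, maxDelayMs).toInt
    ceiling / 2 + Random.nextInt(ceiling / 2 + 1) // half the ceiling plus uniform jitter
  }

  def retryTransient[T](isTransient: Throwable => Boolean)(f: => T): T = {
    var tries = 0
    while (true) {
      try return f
      catch {
        case e: Exception if isTransient(e) =>
          tries += 1
          Thread.sleep(delayMsForTry(tries)) // back off before the next attempt
      }
    }
    throw new AssertionError("unreachable")
  }
}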
@@ -182,9 +183,9 @@ package object services { def retryTransientErrors[T](f: => T, reset: Option[() => Unit] = None): T = { var tries = 0 while (true) { - try { + try return f - } catch { + catch { case e: Exception => tries += 1 val delay = delayMsForTry(tries) @@ -192,7 +193,8 @@ package object services { log.warn( s"A limited retry error has occured. We will automatically retry " + s"${5 - tries} more times. Do not be alarmed. (next delay: " + - s"$delay). The most recent error was $e.") + s"$delay). The most recent error was $e." + ) } else if (!isTransientError(e)) { throw e } else if (tries % 10 == 0) { diff --git a/hail/src/main/scala/is/hail/services/tls/package.scala b/hail/src/main/scala/is/hail/services/tls/package.scala index 5b45da87e62..504908c0940 100644 --- a/hail/src/main/scala/is/hail/services/tls/package.scala +++ b/hail/src/main/scala/is/hail/services/tls/package.scala @@ -1,18 +1,19 @@ package is.hail.services import is.hail.utils._ -import org.json4s.{DefaultFormats, Formats} + import java.io.{File, FileInputStream} import java.security.KeyStore import javax.net.ssl.{KeyManagerFactory, SSLContext, TrustManagerFactory} import org.apache.log4j.{LogManager, Logger} +import org.json4s.{DefaultFormats, Formats} import org.json4s.JsonAST.JString import org.json4s.jackson.JsonMethods class NoSSLConfigFound( message: String, - cause: Throwable + cause: Throwable, ) extends Exception(message, cause) { def this() = this(null, null) @@ -26,7 +27,8 @@ case class SSLConfig( incoming_trust_store: String, key: String, cert: String, - key_store: String) + key_store: String, +) package object tls { lazy val log: Logger = LogManager.getLogger("is.hail.tls") @@ -40,7 +42,9 @@ package object tls { using(new FileInputStream(configFile)) { is => implicit val formats: Formats = DefaultFormats - JsonMethods.parse(is).mapField { case (k, JString(v)) => (k, JString(s"$configDir/$v")) }.extract[SSLConfig] + JsonMethods.parse(is).mapField { case (k, JString(v)) => + (k, JString(s"$configDir/$v")) + }.extract[SSLConfig] } } @@ -72,16 +76,12 @@ package object tls { val pw = "dummypw".toCharArray val ks = KeyStore.getInstance("PKCS12") - using(new FileInputStream(sslConfig.key_store)) { is => - ks.load(is, pw) - } + using(new FileInputStream(sslConfig.key_store))(is => ks.load(is, pw)) val kmf = KeyManagerFactory.getInstance("SunX509") kmf.init(ks, pw) val ts = KeyStore.getInstance("JKS") - using(new FileInputStream(sslConfig.outgoing_trust_store)) { is => - ts.load(is, pw) - } + using(new FileInputStream(sslConfig.outgoing_trust_store))(is => ts.load(is, pw)) val tmf = TrustManagerFactory.getInstance("SunX509") tmf.init(ts) diff --git a/hail/src/main/scala/is/hail/sparkextras/BlockedRDD.scala b/hail/src/main/scala/is/hail/sparkextras/BlockedRDD.scala index b5ab4a2bb47..9006acbe901 100644 --- a/hail/src/main/scala/is/hail/sparkextras/BlockedRDD.scala +++ b/hail/src/main/scala/is/hail/sparkextras/BlockedRDD.scala @@ -1,16 +1,14 @@ package is.hail.sparkextras import is.hail.utils._ -import org.apache.spark.rdd.RDD -import org.apache.spark.{Dependency, NarrowDependency, Partition, TaskContext} -import scala.language.existentials import scala.reflect.ClassTag -case class BlockedRDDPartition(@transient rdd: RDD[_], - index: Int, - first: Int, - last: Int) extends Partition { +import org.apache.spark.{Dependency, NarrowDependency, Partition, TaskContext} +import org.apache.spark.rdd.RDD + +case class BlockedRDDPartition(@transient rdd: RDD[_], index: Int, first: Int, last: Int) + extends Partition { 
require(first <= last) val parentPartitions: Array[Partition] = range.map(rdd.partitions).toArray @@ -18,31 +16,33 @@ case class BlockedRDDPartition(@transient rdd: RDD[_], def range: Range = first to last } -class BlockedRDD[T](@transient var prev: RDD[T], +class BlockedRDD[T]( + @transient var prev: RDD[T], @transient val partFirst: Array[Int], - @transient val partLast: Array[Int] -)(implicit tct: ClassTag[T]) extends RDD[T](prev.sparkContext, Nil) { + @transient val partLast: Array[Int], +)(implicit tct: ClassTag[T] +) extends RDD[T](prev.sparkContext, Nil) { assert(partFirst.length == partLast.length) - override def getPartitions: Array[Partition] = { + override def getPartitions: Array[Partition] = Array.tabulate[Partition](partFirst.length)(i => - BlockedRDDPartition(prev, i, partFirst(i), partLast(i))) - } + BlockedRDDPartition(prev, i, partFirst(i), partLast(i)) + ) override def compute(split: Partition, context: TaskContext): Iterator[T] = { val parent = dependencies.head.rdd.asInstanceOf[RDD[T]] split.asInstanceOf[BlockedRDDPartition].parentPartitions.iterator.flatMap(p => - parent.iterator(p, context)) + parent.iterator(p, context) + ) } - override def getDependencies: Seq[Dependency[_]] = { + override def getDependencies: Seq[Dependency[_]] = FastSeq(new NarrowDependency(prev) { def getParents(id: Int): Seq[Int] = partitions(id).asInstanceOf[BlockedRDDPartition].range }) - } - override def clearDependencies() { + override def clearDependencies(): Unit = { super.clearDependencies() prev = null } @@ -52,7 +52,8 @@ class BlockedRDD[T](@transient var prev: RDD[T], val range = partition.asInstanceOf[BlockedRDDPartition].range val locationAvail = range.flatMap(i => - prev.preferredLocations(prevPartitions(i))) + prev.preferredLocations(prevPartitions(i)) + ) .groupBy(identity) .mapValues(_.length) diff --git a/hail/src/main/scala/is/hail/sparkextras/ContextPairRDDFunctions.scala b/hail/src/main/scala/is/hail/sparkextras/ContextPairRDDFunctions.scala index 64b18e69e06..34e7f54ef98 100644 --- a/hail/src/main/scala/is/hail/sparkextras/ContextPairRDDFunctions.scala +++ b/hail/src/main/scala/is/hail/sparkextras/ContextPairRDDFunctions.scala @@ -1,10 +1,10 @@ package is.hail.sparkextras +import scala.reflect.ClassTag + import org.apache.spark._ import org.apache.spark.rdd._ -import scala.reflect.ClassTag - class ContextPairRDDFunctions[K: ClassTag, V: ClassTag]( crdd: ContextRDD[(K, V)] ) { @@ -12,7 +12,8 @@ class ContextPairRDDFunctions[K: ClassTag, V: ClassTag]( def partitionBy(p: Partitioner): ContextRDD[(K, V)] = if (crdd.partitioner.contains(p)) crdd else ContextRDD.weaken( - new ShuffledRDD[K, V, V](crdd.run, p)) + new ShuffledRDD[K, V, V](crdd.run, p) + ) def values: ContextRDD[V] = crdd.map(_._2) } diff --git a/hail/src/main/scala/is/hail/sparkextras/ContextRDD.scala b/hail/src/main/scala/is/hail/sparkextras/ContextRDD.scala index 831180e41d4..fe9c4d4e4ac 100644 --- a/hail/src/main/scala/is/hail/sparkextras/ContextRDD.scala +++ b/hail/src/main/scala/is/hail/sparkextras/ContextRDD.scala @@ -1,20 +1,18 @@ package is.hail.sparkextras import is.hail.HailContext -import is.hail.annotations.RegionPool -import is.hail.backend.HailTaskContext import is.hail.backend.spark.{SparkBackend, SparkTaskContext} import is.hail.rvd.RVDContext import is.hail.utils._ -import is.hail.utils.PartitionCounts._ -import org.apache.spark._ -import org.apache.spark.rdd._ -import org.apache.spark.ExposedUtils import scala.reflect.ClassTag +import org.apache.spark._ +import org.apache.spark.rdd._ + object 
Combiner { - def apply[U](zero: => U, combine: (U, U) => U, commutative: Boolean, associative: Boolean): Combiner[U] = { + def apply[U](zero: => U, combine: (U, U) => U, commutative: Boolean, associative: Boolean) + : Combiner[U] = { assert(associative) if (commutative) new CommutativeAndAssociativeCombiner(zero, combine) @@ -24,7 +22,7 @@ object Combiner { } abstract class Combiner[U] { - def combine(i: Int, value0: U) + def combine(i: Int, value0: U): Unit def result(): U } @@ -45,8 +43,8 @@ class AssociativeCombiner[U](zero: => U, combine: (U, U) => U) extends Combiner[ // U it holds. private val t = new java.util.TreeMap[Int, TreeValue]() - def combine(i: Int, value0: U) { - log.info(s"at result $i, AssociativeCombiner contains ${ t.size() } queued results") + def combine(i: Int, value0: U): Unit = { + log.info(s"at result $i, AssociativeCombiner contains ${t.size()} queued results") var value = value0 var end = i @@ -87,11 +85,13 @@ object ContextRDD { ): ContextRDD[T] = new ContextRDD(rdd) def empty[T: ClassTag](): ContextRDD[T] = - new ContextRDD(SparkBackend.sparkContext("ContextRDD.empty").emptyRDD[RVDContext => Iterator[T]]) + new ContextRDD( + SparkBackend.sparkContext("ContextRDD.empty").emptyRDD[RVDContext => Iterator[T]] + ) def union[T: ClassTag]( sc: SparkContext, - xs: Seq[ContextRDD[T]] + xs: Seq[ContextRDD[T]], ): ContextRDD[T] = new ContextRDD(sc.union(xs.map(_.rdd))) @@ -103,51 +103,54 @@ object ContextRDD { def textFilesLines( files: Array[String], nPartitions: Option[Int] = None, - filterAndReplace: TextInputFilterAndReplace = TextInputFilterAndReplace() + filterAndReplace: TextInputFilterAndReplace = TextInputFilterAndReplace(), ): ContextRDD[WithContext[String]] = textFilesLines( files, nPartitions.getOrElse(HailContext.backend.defaultParallelism), - filterAndReplace) + filterAndReplace, + ) def textFilesLines( files: Array[String], nPartitions: Int, - filterAndReplace: TextInputFilterAndReplace + filterAndReplace: TextInputFilterAndReplace, ): ContextRDD[WithContext[String]] = ContextRDD.weaken( SparkBackend.sparkContext("ContxtRDD.textFilesLines").textFilesLines( files, - nPartitions) - .mapPartitions(filterAndReplace.apply)) + nPartitions, + ) + .mapPartitions(filterAndReplace.apply) + ) - def parallelize[T: ClassTag](sc: SparkContext, data: Seq[T], nPartitions: Option[Int] = None): ContextRDD[T] = - weaken(sc.parallelize(data, nPartitions.getOrElse(sc.defaultMinPartitions))).map(x => { - x - }) + def parallelize[T: ClassTag](sc: SparkContext, data: Seq[T], nPartitions: Option[Int] = None) + : ContextRDD[T] = + weaken(sc.parallelize(data, nPartitions.getOrElse(sc.defaultMinPartitions))).map(x => x) def parallelize[T: ClassTag](data: Seq[T], numSlices: Int): ContextRDD[T] = - weaken(SparkBackend.sparkContext("ContextRDD.parallelize").parallelize(data, numSlices)).map(x => { - x - }) + weaken(SparkBackend.sparkContext("ContextRDD.parallelize").parallelize(data, numSlices)).map { + x => x + } def parallelize[T: ClassTag](data: Seq[T]): ContextRDD[T] = - weaken(SparkBackend.sparkContext("ContextRDD.parallelize").parallelize(data)).map(x => { - x - }) + weaken(SparkBackend.sparkContext("ContextRDD.parallelize").parallelize(data)).map(x => x) type ElementType[T] = RVDContext => Iterator[T] def czipNPartitions[T: ClassTag, U: ClassTag]( crdds: IndexedSeq[ContextRDD[T]], - preservesPartitioning: Boolean = false - )(f: (RVDContext, Array[Iterator[T]]) => Iterator[U] + preservesPartitioning: Boolean = false, + )( + f: (RVDContext, Array[Iterator[T]]) => Iterator[U] ): 
ContextRDD[U] = { - def inCtx(f: RVDContext => Iterator[U]): Iterator[RVDContext => Iterator[U]] = Iterator.single(f) + def inCtx(f: RVDContext => Iterator[U]): Iterator[RVDContext => Iterator[U]] = + Iterator.single(f) new ContextRDD( MultiWayZipPartitionsRDD(crdds.map(_.rdd)) { its => inCtx(ctx => f(ctx, its.map(_.flatMap(_(ctx))))) - }) + } + ) } } @@ -158,23 +161,15 @@ class ContextRDD[T: ClassTag]( private[this] def sparkManagedContext[U](func: RVDContext => U): U = { val c = RVDContext.default(SparkTaskContext.get().getRegionPool()) - TaskContext.get().addTaskCompletionListener[Unit] { (_: TaskContext) => - c.close() - } + TaskContext.get().addTaskCompletionListener[Unit]((_: TaskContext) => c.close()) func(c) } - def run[U >: T : ClassTag]: RDD[U] = { - this.cleanupRegions.rdd.mapPartitions { part => - sparkManagedContext{ c => - part.flatMap(_(c)) - } - } - } + def run[U >: T: ClassTag]: RDD[U] = + this.cleanupRegions.rdd.mapPartitions(part => sparkManagedContext(c => part.flatMap(_(c)))) - def collect(): Array[T] = { + def collect(): Array[T] = run.collect() - } private[this] def inCtx[U: ClassTag]( f: RVDContext => Iterator[U] @@ -191,13 +186,13 @@ class ContextRDD[T: ClassTag]( def mapPartitions[U: ClassTag]( f: Iterator[T] => Iterator[U], - preservesPartitioning: Boolean = false + preservesPartitioning: Boolean = false, ): ContextRDD[U] = cmapPartitions((_, part) => f(part), preservesPartitioning) def mapPartitionsWithIndex[U: ClassTag]( f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false + preservesPartitioning: Boolean = false, ): ContextRDD[U] = cmapPartitionsWithIndex((i, _, part) => f(i, part), preservesPartitioning) @@ -212,25 +207,31 @@ class ContextRDD[T: ClassTag]( def cmapPartitions[U: ClassTag]( f: (RVDContext, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false + preservesPartitioning: Boolean = false, ): ContextRDD[U] = new ContextRDD( rdd.mapPartitions( part => inCtx(ctx => f(ctx, part.flatMap(_(ctx)))), - preservesPartitioning)) + preservesPartitioning, + ) + ) - def cmapPartitionsWithContext[U: ClassTag](f: (RVDContext, (RVDContext) => Iterator[T]) => Iterator[U]): ContextRDD[U] = { - new ContextRDD(rdd.mapPartitions( - part => part.flatMap { - x => inCtx(consumerCtx => f(consumerCtx, x)) - })) - } + def cmapPartitionsWithContext[U: ClassTag]( + f: (RVDContext, (RVDContext) => Iterator[T]) => Iterator[U] + ): ContextRDD[U] = + new ContextRDD(rdd.mapPartitions(part => + part.flatMap { + x => inCtx(consumerCtx => f(consumerCtx, x)) + } + )) - def cmapPartitionsWithContextAndIndex[U: ClassTag](f: (Int, RVDContext, (RVDContext) => Iterator[T]) => Iterator[U]): ContextRDD[U] = { - new ContextRDD(rdd.mapPartitionsWithIndex( - (i, part) => part.flatMap { + def cmapPartitionsWithContextAndIndex[U: ClassTag]( + f: (Int, RVDContext, (RVDContext) => Iterator[T]) => Iterator[U] + ): ContextRDD[U] = + new ContextRDD(rdd.mapPartitionsWithIndex((i, part) => + part.flatMap { x => inCtx(consumerCtx => f(i, consumerCtx, x)) - })) - } + } + )) // Gives consumer ownership of the context. Consumer is responsible for freeing // resources per element. 
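// Editor's illustrative sketch (not part of the patch): the core ContextRDD idea in miniature,
// using a plain List in place of an RDD partition. Each element is a function Ctx => Iterator[T];
// "running" it supplies one context per partition and flattens the iterators the elements produce.
// Ctx and all names below are simplified stand-ins, not Hail's RVDContext API.
object ContextCollectionSketch {
  final class Ctx // stands in for RVDContext (region pool, temporary regions, ...)

  type Elem[T] = Ctx => Iterator[T]

  // analogous to ContextRDD.run on one partition: allocate a context, let every element
  // produce its iterator against it, and flatten the results
  def run[T](partition: List[Elem[T]]): List[T] = {
    val c = new Ctx // in Hail this comes from the Spark task's region pool and is closed with the task
    partition.iterator.flatMap(f => f(c)).toList
  }

  def main(args: Array[String]): Unit = {
    val part: List[Elem[Int]] = List(_ => Iterator(1, 2), _ => Iterator(3))
    println(run(part)) // List(1, 2, 3)
  }
}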
@@ -241,47 +242,55 @@ class ContextRDD[T: ClassTag]( val c = RVDContext.default(SparkTaskContext.get().getRegionPool()) val ans = f(taskContext.partitionId(), c, it.flatMap(_(c))) ans - }) + }, + ) def cmapPartitionsAndContext[U: ClassTag]( f: (RVDContext, (Iterator[RVDContext => Iterator[T]])) => Iterator[U], - preservesPartitioning: Boolean = false + preservesPartitioning: Boolean = false, ): ContextRDD[U] = onRDD(_.mapPartitions( part => inCtx(ctx => f(ctx, part)), - preservesPartitioning)) + preservesPartitioning, + )) def cmapPartitionsWithIndex[U: ClassTag]( f: (Int, RVDContext, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false + preservesPartitioning: Boolean = false, ): ContextRDD[U] = new ContextRDD( rdd.mapPartitionsWithIndex( (i, part) => inCtx(ctx => f(i, ctx, part.flatMap(_(ctx)))), - preservesPartitioning)) + preservesPartitioning, + ) + ) def cmapPartitionsWithIndexAndValue[U: ClassTag, V]( values: Array[V], f: (Int, RVDContext, V, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false + preservesPartitioning: Boolean = false, ): ContextRDD[U] = new ContextRDD( new MapPartitionsWithValueRDD[(RVDContext) => Iterator[T], (RVDContext) => Iterator[U], V]( rdd, values, (i, v, part) => inCtx(ctx => f(i, ctx, v, part.flatMap(_(ctx)))), - preservesPartitioning)) + preservesPartitioning, + ) + ) def cmapPartitionsAndContextWithIndex[U: ClassTag]( f: (Int, RVDContext, Iterator[RVDContext => Iterator[T]]) => Iterator[U], - preservesPartitioning: Boolean = false + preservesPartitioning: Boolean = false, ): ContextRDD[U] = onRDD(_.mapPartitionsWithIndex( (i, part) => inCtx(ctx => f(i, ctx, part)), - preservesPartitioning)) + preservesPartitioning, + )) def czip[U: ClassTag, V: ClassTag]( that: ContextRDD[U], - preservesPartitioning: Boolean = false - )(f: (RVDContext, T, U) => V + preservesPartitioning: Boolean = false, + )( + f: (RVDContext, T, U) => V ): ContextRDD[V] = czipPartitions(that, preservesPartitioning) { (ctx, l, r) => new Iterator[V] { def hasNext = { @@ -290,9 +299,8 @@ class ContextRDD[T: ClassTag]( assert(lhn == rhn) lhn } - def next(): V = { + def next(): V = f(ctx, l.next(), r.next()) - } } } @@ -300,8 +308,9 @@ class ContextRDD[T: ClassTag]( // between the two producers and the one consumer def zipPartitions[U: ClassTag, V: ClassTag]( that: ContextRDD[U], - preservesPartitioning: Boolean = false - )(f: (Iterator[T], Iterator[U]) => Iterator[V] + preservesPartitioning: Boolean = false, + )( + f: (Iterator[T], Iterator[U]) => Iterator[V] ): ContextRDD[V] = czipPartitions[U, V](that, preservesPartitioning)((_, l, r) => f(l, r)) @@ -309,33 +318,47 @@ class ContextRDD[T: ClassTag]( // between the two producers and the one consumer def czipPartitions[U: ClassTag, V: ClassTag]( that: ContextRDD[U], - preservesPartitioning: Boolean = false - )(f: (RVDContext, Iterator[T], Iterator[U]) => Iterator[V] + preservesPartitioning: Boolean = false, + )( + f: (RVDContext, Iterator[T], Iterator[U]) => Iterator[V] ): ContextRDD[V] = new ContextRDD( - rdd.zipPartitions(that.rdd, preservesPartitioning)( - (l, r) => inCtx(ctx => f(ctx, l.flatMap(_(ctx)), r.flatMap(_(ctx)))))) + rdd.zipPartitions(that.rdd, preservesPartitioning)((l, r) => + inCtx(ctx => f(ctx, l.flatMap(_(ctx)), r.flatMap(_(ctx)))) + ) + ) // WARNING: this method is easy to use wrong because it shares the context // between the two producers and the one consumer def czipPartitionsWithIndex[U: ClassTag, V: ClassTag]( that: ContextRDD[U], - preservesPartitioning: Boolean = false - )(f: (Int, 
RVDContext, Iterator[T], Iterator[U]) => Iterator[V] + preservesPartitioning: Boolean = false, + )( + f: (Int, RVDContext, Iterator[T], Iterator[U]) => Iterator[V] ): ContextRDD[V] = new ContextRDD( - rdd.zipPartitions(that.rdd, preservesPartitioning)( - (l, r) => Iterator.single(l -> r)).mapPartitionsWithIndex({ case (i, it) => - it.flatMap { case (l, r) => - inCtx(ctx => f(i, ctx, l.flatMap(_(ctx)), r.flatMap(_(ctx)))) - } - }, preservesPartitioning)) + rdd.zipPartitions(that.rdd, preservesPartitioning)((l, r) => + Iterator.single(l -> r) + ).mapPartitionsWithIndex( + { case (i, it) => + it.flatMap { case (l, r) => + inCtx(ctx => f(i, ctx, l.flatMap(_(ctx)), r.flatMap(_(ctx)))) + } + }, + preservesPartitioning, + ) + ) def czipPartitionsAndContext[U: ClassTag, V: ClassTag]( that: ContextRDD[U], - preservesPartitioning: Boolean = false - )(f: (RVDContext, Iterator[RVDContext => Iterator[T]], Iterator[RVDContext => Iterator[U]]) => Iterator[V] + preservesPartitioning: Boolean = false, + )( + f: ( + RVDContext, + Iterator[RVDContext => Iterator[T]], + Iterator[RVDContext => Iterator[U]], + ) => Iterator[V] ): ContextRDD[V] = new ContextRDD( - rdd.zipPartitions(that.rdd, preservesPartitioning)( - (l, r) => inCtx(ctx => f(ctx, l, r)))) + rdd.zipPartitions(that.rdd, preservesPartitioning)((l, r) => inCtx(ctx => f(ctx, l, r))) + ) def subsetPartitions(keptPartitionIndices: Array[Int]): ContextRDD[T] = onRDD(_.subsetPartitions(keptPartitionIndices)) @@ -354,23 +377,23 @@ class ContextRDD[T: ClassTag]( // [2, 5, 7]. With this, original partion indicies 0, 1, and 2 make up the first new partition 3, // 4, and 5 make up the second, and 6 and 7 make up the third. def coalesceWithEnds(partEnds: Array[Int]): ContextRDD[T] = - onRDD(rdd => { - rdd.coalesce(partEnds.length, shuffle = false, partitionCoalescer = Some(new CRDDCoalescer(partEnds))) - }) + onRDD { rdd => + rdd.coalesce( + partEnds.length, + shuffle = false, + partitionCoalescer = Some(new CRDDCoalescer(partEnds)), + ) + } def runJob[U: ClassTag](f: Iterator[T] => U, partitions: Seq[Int]): Array[U] = sparkContext.runJob( rdd, - { (it: Iterator[ElementType]) => - sparkManagedContext { c => - f(it.flatMap(_ (c))) - } - }, - partitions) + (it: Iterator[ElementType]) => sparkManagedContext(c => f(it.flatMap(_(c)))), + partitions, + ) - def blocked(partFirst: Array[Int], partLast: Array[Int]): ContextRDD[T] = { + def blocked(partFirst: Array[Int], partLast: Array[Int]): ContextRDD[T] = new ContextRDD(new BlockedRDD(rdd, partFirst, partLast)) - } def sparkContext: SparkContext = rdd.sparkContext @@ -379,9 +402,6 @@ class ContextRDD[T: ClassTag]( def preferredLocations(partition: Partition): Seq[String] = rdd.preferredLocations(partition) - private[this] def clean[U <: AnyRef](value: U): U = - ExposedUtils.clean(value) - def partitions: Array[Partition] = rdd.partitions def partitioner: Option[Partitioner] = rdd.partitioner @@ -403,12 +423,11 @@ private class CRDDCoalescer(partEnds: Array[Int]) extends PartitionCoalescer wit val groups = Array.fill(maxPartitions)(new PartitionGroup()) val parts = prev.partitions var i = 0 - for ((end, j) <- partEnds.zipWithIndex) { + for ((end, j) <- partEnds.zipWithIndex) while (i <= end) { groups(j).partitions += parts(i) i += 1 } - } groups } } diff --git a/hail/src/main/scala/is/hail/sparkextras/IndexReadRDD.scala b/hail/src/main/scala/is/hail/sparkextras/IndexReadRDD.scala index e7c771996dc..d7316f3d736 100644 --- a/hail/src/main/scala/is/hail/sparkextras/IndexReadRDD.scala +++ 
b/hail/src/main/scala/is/hail/sparkextras/IndexReadRDD.scala @@ -2,17 +2,19 @@ package is.hail.sparkextras import is.hail.backend.spark.SparkBackend import is.hail.utils.Interval -import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag -case class IndexedFilePartition(index: Int, file: String, bounds: Option[Interval]) extends Partition +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD + +case class IndexedFilePartition(index: Int, file: String, bounds: Option[Interval]) + extends Partition class IndexReadRDD[T: ClassTag]( @transient val partFiles: Array[String], @transient val intervalBounds: Option[Array[Interval]], - f: (IndexedFilePartition, TaskContext) => T + f: (IndexedFilePartition, TaskContext) => T, ) extends RDD[T](SparkBackend.sparkContext("IndexReadRDD"), Nil) { def getPartitions: Array[Partition] = Array.tabulate(partFiles.length) { i => @@ -20,8 +22,8 @@ class IndexReadRDD[T: ClassTag]( } override def compute( - split: Partition, context: TaskContext - ): Iterator[T] = { + split: Partition, + context: TaskContext, + ): Iterator[T] = Iterator.single(f(split.asInstanceOf[IndexedFilePartition], context)) - } } diff --git a/hail/src/main/scala/is/hail/sparkextras/MapPartitionsWithValueRDD.scala b/hail/src/main/scala/is/hail/sparkextras/MapPartitionsWithValueRDD.scala index 5e9b88cab7c..8b04d55d2e4 100644 --- a/hail/src/main/scala/is/hail/sparkextras/MapPartitionsWithValueRDD.scala +++ b/hail/src/main/scala/is/hail/sparkextras/MapPartitionsWithValueRDD.scala @@ -1,14 +1,15 @@ package is.hail.sparkextras -import org.apache.spark.rdd.RDD -import org.apache.spark.{Partition, TaskContext} - import scala.annotation.meta.param import scala.reflect.ClassTag +import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.rdd.RDD + case class MapPartitionsWithValueRDDPartition[V]( parentPartition: Partition, - value: V) extends Partition { + value: V, +) extends Partition { def index: Int = parentPartition.index } @@ -16,20 +17,21 @@ class MapPartitionsWithValueRDD[T: ClassTag, U: ClassTag, V]( var prev: RDD[T], @(transient @param) values: Array[V], f: (Int, V, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean) extends RDD[U](prev) { + preservesPartitioning: Boolean, +) extends RDD[U](prev) { - @transient override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None + @transient override val partitioner = + if (preservesPartitioning) firstParent[T].partitioner else None - override def getPartitions: Array[Partition] = { + override def getPartitions: Array[Partition] = firstParent[T].partitions.map(p => MapPartitionsWithValueRDDPartition(p, values(p.index))) - } override def compute(split: Partition, context: TaskContext): Iterator[U] = { val p = split.asInstanceOf[MapPartitionsWithValueRDDPartition[V]] f(split.index, p.value, firstParent[T].iterator(p.parentPartition, context)) } - override def clearDependencies() { + override def clearDependencies(): Unit = { super.clearDependencies() prev = null } diff --git a/hail/src/main/scala/is/hail/sparkextras/MultiWayZipPartitionsRDD.scala b/hail/src/main/scala/is/hail/sparkextras/MultiWayZipPartitionsRDD.scala index 781d3f7c486..cf71f7f0246 100644 --- a/hail/src/main/scala/is/hail/sparkextras/MultiWayZipPartitionsRDD.scala +++ b/hail/src/main/scala/is/hail/sparkextras/MultiWayZipPartitionsRDD.scala @@ -1,25 +1,26 @@ package is.hail.sparkextras -import 
org.apache.spark.rdd.RDD -import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext} - import scala.reflect.ClassTag +import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext} +import org.apache.spark.rdd.RDD + object MultiWayZipPartitionsRDD { - def apply[T: ClassTag , V: ClassTag]( + def apply[T: ClassTag, V: ClassTag]( rdds: IndexedSeq[RDD[T]] - )(f: (Array[Iterator[T]]) => Iterator[V]): MultiWayZipPartitionsRDD[T, V] = { + )( + f: (Array[Iterator[T]]) => Iterator[V] + ): MultiWayZipPartitionsRDD[T, V] = new MultiWayZipPartitionsRDD(rdds.head.sparkContext, rdds, f) - } } private case class MultiWayZipPartition(val index: Int, val partitions: IndexedSeq[Partition]) - extends Partition + extends Partition class MultiWayZipPartitionsRDD[T: ClassTag, V: ClassTag]( sc: SparkContext, var rdds: IndexedSeq[RDD[T]], - var f: (Array[Iterator[T]]) => Iterator[V] + var f: (Array[Iterator[T]]) => Iterator[V], ) extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { require(rdds.length > 0) private val numParts = rdds(0).partitions.length @@ -27,11 +28,10 @@ class MultiWayZipPartitionsRDD[T: ClassTag, V: ClassTag]( override val partitioner = None - override def getPartitions: Array[Partition] = { + override def getPartitions: Array[Partition] = Array.tabulate[Partition](numParts) { i => MultiWayZipPartition(i, rdds.map(rdd => rdd.partitions(i))) } - } override def compute(s: Partition, tc: TaskContext) = { val partitions = s.asInstanceOf[MultiWayZipPartition].partitions @@ -39,7 +39,7 @@ class MultiWayZipPartitionsRDD[T: ClassTag, V: ClassTag]( f(arr) } - override def clearDependencies() { + override def clearDependencies(): Unit = { super.clearDependencies rdds = null f = null diff --git a/hail/src/main/scala/is/hail/sparkextras/OriginUnionRDD.scala b/hail/src/main/scala/is/hail/sparkextras/OriginUnionRDD.scala index ff5188e0305..f9508bf4d94 100644 --- a/hail/src/main/scala/is/hail/sparkextras/OriginUnionRDD.scala +++ b/hail/src/main/scala/is/hail/sparkextras/OriginUnionRDD.scala @@ -1,26 +1,29 @@ package is.hail.sparkextras -import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext} -import org.apache.spark.rdd.RDD - import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag +import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext} +import org.apache.spark.rdd.RDD + private[hail] class OriginUnionPartition( val index: Int, val originIdx: Int, - val originPart: Partition + val originPart: Partition, ) extends Partition class OriginUnionRDD[T: ClassTag, S: ClassTag]( sc: SparkContext, var rdds: IndexedSeq[RDD[T]], - f: (Int, Int, Iterator[T]) => Iterator[S] + f: (Int, Int, Iterator[T]) => Iterator[S], ) extends RDD[S](sc, Nil) { override def getPartitions: Array[Partition] = { val arr = new Array[Partition](rdds.map(_.partitions.length).sum) var i = 0 - for ((rdd, rddIdx) <- rdds.zipWithIndex; part <- rdd.partitions) { + for { + (rdd, rddIdx) <- rdds.zipWithIndex + part <- rdd.partitions + } { arr(i) = new OriginUnionPartition(i, rddIdx, part) i += 1 } @@ -42,7 +45,7 @@ class OriginUnionRDD[T: ClassTag, S: ClassTag]( f(p.originIdx, p.originPart.index, parent[T](p.originIdx).iterator(p.originPart, tc)) } - override def clearDependencies() { + override def clearDependencies(): Unit = { super.clearDependencies() rdds = null } diff --git a/hail/src/main/scala/is/hail/sparkextras/ReorderedPartitionsRDD.scala 
b/hail/src/main/scala/is/hail/sparkextras/ReorderedPartitionsRDD.scala index 2cf1fd90921..f8b5225c991 100644 --- a/hail/src/main/scala/is/hail/sparkextras/ReorderedPartitionsRDD.scala +++ b/hail/src/main/scala/is/hail/sparkextras/ReorderedPartitionsRDD.scala @@ -1,15 +1,19 @@ package is.hail.sparkextras import is.hail.utils.FastSeq -import org.apache.spark.rdd.RDD -import org.apache.spark.{Dependency, NarrowDependency, Partition, TaskContext} import scala.reflect.ClassTag +import org.apache.spark.{Dependency, NarrowDependency, Partition, TaskContext} +import org.apache.spark.rdd.RDD + case class ReorderedPartitionsRDDPartition(index: Int, oldPartition: Partition) extends Partition -class ReorderedPartitionsRDD[T](@transient var prev: RDD[T], @transient val oldIndices: Array[Int])(implicit tct: ClassTag[T]) - extends RDD[T](prev.sparkContext, Nil) { +class ReorderedPartitionsRDD[T]( + @transient var prev: RDD[T], + @transient val oldIndices: Array[Int], +)(implicit tct: ClassTag[T] +) extends RDD[T](prev.sparkContext, Nil) { override def getPartitions: Array[Partition] = { val parentPartitions = dependencies.head.rdd.asInstanceOf[RDD[T]].partitions @@ -29,7 +33,7 @@ class ReorderedPartitionsRDD[T](@transient var prev: RDD[T], @transient val oldI override def getParents(partitionId: Int): Seq[Int] = FastSeq(oldIndices(partitionId)) }) - override def clearDependencies() { + override def clearDependencies(): Unit = { super.clearDependencies() prev = null } diff --git a/hail/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD2.scala b/hail/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD2.scala index 250a381f731..cfd1054d297 100644 --- a/hail/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD2.scala +++ b/hail/src/main/scala/is/hail/sparkextras/RepartitionedOrderedRDD2.scala @@ -4,57 +4,65 @@ import is.hail.annotations._ import is.hail.backend.HailStateManager import is.hail.rvd.{PartitionBoundOrdering, RVD, RVDContext, RVDPartitioner, RVDType} import is.hail.utils._ -import org.apache.spark._ -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD import scala.annotation.tailrec +import org.apache.spark._ +import org.apache.spark.rdd.RDD + object OrderedDependency { - def generate[T](oldPartitioner: RVDPartitioner, newIntervals: IndexedSeq[Interval], rdd: RDD[T]): OrderedDependency[T] = { + def generate[T](oldPartitioner: RVDPartitioner, newIntervals: IndexedSeq[Interval], rdd: RDD[T]) + : OrderedDependency[T] = new OrderedDependency( newIntervals.map(oldPartitioner.queryInterval).toArray, - rdd) - } + rdd, + ) } class OrderedDependency[T]( depArray: Array[Range], - rdd: RDD[T] + rdd: RDD[T], ) extends NarrowDependency[T](rdd) { override def getParents(partitionId: Int): Seq[Int] = depArray(partitionId) } object RepartitionedOrderedRDD2 { - def apply(sm: HailStateManager, prev: RVD, newRangeBounds: IndexedSeq[Interval]): ContextRDD[Long] = + def apply(sm: HailStateManager, prev: RVD, newRangeBounds: IndexedSeq[Interval]) + : ContextRDD[Long] = ContextRDD(new RepartitionedOrderedRDD2(sm, prev, newRangeBounds)) } -/** - * Repartition 'prev' to comply with 'newRangeBounds', using narrow dependencies. - * Assumes new key type is a prefix of old key type, so no reordering is - * needed. +/** Repartition 'prev' to comply with 'newRangeBounds', using narrow dependencies. Assumes new key + * type is a prefix of old key type, so no reordering is needed. 
*/ -class RepartitionedOrderedRDD2 private (sm: HailStateManager, @transient val prev: RVD, @transient val newRangeBounds: IndexedSeq[Interval]) - extends RDD[ContextRDD.ElementType[Long]](prev.crdd.sparkContext, Nil) { // Nil since we implement getDependencies +class RepartitionedOrderedRDD2 private ( + sm: HailStateManager, + @transient val prev: RVD, + @transient val newRangeBounds: IndexedSeq[Interval], +) extends RDD[ContextRDD.ElementType[Long]](prev.crdd.sparkContext, Nil) { // Nil since we implement getDependencies val prevCRDD: ContextRDD[Long] = prev.crdd val typ: RVDType = prev.typ val kOrd: ExtendedOrdering = PartitionBoundOrdering(sm, typ.kType.virtualType) - def getPartitions: Array[Partition] = { - require(newRangeBounds.forall{i => typ.kType.virtualType.relaxedTypeCheck(i.start) && typ.kType.virtualType.relaxedTypeCheck(i.end)}) + require(newRangeBounds.forall { i => + typ.kType.virtualType.relaxedTypeCheck(i.start) && typ.kType.virtualType.relaxedTypeCheck( + i.end + ) + }) Array.tabulate[Partition](newRangeBounds.length) { i => RepartitionedOrderedRDD2Partition( i, dependency.getParents(i).toArray.map(prevCRDD.partitions), - newRangeBounds(i)) + newRangeBounds(i), + ) } } - override def compute(partition: Partition, context: TaskContext): Iterator[RVDContext => Iterator[Long]] = { + override def compute(partition: Partition, context: TaskContext) + : Iterator[RVDContext => Iterator[Long]] = { val ordPartition = partition.asInstanceOf[RepartitionedOrderedRDD2Partition] val pord = kOrd.intervalEndpointOrdering val range = ordPartition.range @@ -64,9 +72,11 @@ class RepartitionedOrderedRDD2 private (sm: HailStateManager, @transient val pre private[this] val innerCtx = outerCtx.freshContext() private[this] val outerRegion = outerCtx.region private[this] val innerRegion = innerCtx.region - private[this] val parentIterator = ordPartition.parents.iterator.flatMap(p => prevCRDD.iterator(p, context).flatMap(_.apply(innerCtx))) + private[this] val parentIterator = ordPartition.parents.iterator.flatMap(p => + prevCRDD.iterator(p, context).flatMap(_.apply(innerCtx)) + ) private[this] var pulled: Boolean = false - private[this] var current: Long = _ + private[this] var current: Long = _ private[this] val ur = new UnsafeRow(typ.rowType) private[this] val key = new SelectFieldsRow(ur, typ.kFieldIdx) @@ -83,7 +93,8 @@ class RepartitionedOrderedRDD2 private (sm: HailStateManager, @transient val pre // End the iterator if first remaining value is greater than range.right end() } else - // End the iterator if we exhausted parent iterators before finding an element greater than range.left + /* End the iterator if we exhausted parent iterators before finding an element greater + * than range.left */ end() } @@ -131,13 +142,14 @@ class RepartitionedOrderedRDD2 private (sm: HailStateManager, @transient val pre val dependency: OrderedDependency[_] = OrderedDependency.generate( prev.partitioner, newRangeBounds, - prevCRDD.rdd) + prevCRDD.rdd, + ) override def getDependencies: Seq[Dependency[_]] = FastSeq(dependency) } case class RepartitionedOrderedRDD2Partition( - index: Int, - parents: Array[Partition], - range: Interval + index: Int, + parents: Array[Partition], + range: Interval, ) extends Partition diff --git a/hail/src/main/scala/is/hail/stats/GeneralizedChiSquaredDistribution.scala b/hail/src/main/scala/is/hail/stats/GeneralizedChiSquaredDistribution.scala index ee889e4b740..c9d52cf6431 100644 --- a/hail/src/main/scala/is/hail/stats/GeneralizedChiSquaredDistribution.scala +++ 
b/hail/src/main/scala/is/hail/stats/GeneralizedChiSquaredDistribution.scala @@ -1,7 +1,7 @@ package is.hail.stats -import is.hail.utils._ import is.hail.types.physical._ +import is.hail.utils._ case class DaviesAlgorithmTrace( var absoluteSum: Double, @@ -10,14 +10,14 @@ case class DaviesAlgorithmTrace( var integrationIntervalInFinalIntegration: Double, var truncationPointInInitialIntegration: Double, var standardDeviationOfInitialConvergenceFactor: Double, - var cyclesToLocateIntegrationParameters: Int + var cyclesToLocateIntegrationParameters: Int, ) class DaviesResultForPython( val value: Double, val nIterations: Int, val converged: Boolean, - val fault: Int + val fault: Int, ) object DaviesAlgorithm { @@ -30,7 +30,7 @@ object DaviesAlgorithm { "value" -> PFloat64(required = true), "n_iterations" -> PInt32(required = true), "converged" -> PBoolean(required = true), - "fault" -> PInt32(required = true) + "fault" -> PInt32(required = true), ) } @@ -40,29 +40,28 @@ class DaviesAlgorithm( private[this] val lb: Array[Double], private[this] val nc: Array[Double], private[this] val lim: Int, - private[this] val sigma: Double + private[this] val sigma: Double, ) { - /** - * This algorithm is a direct port of Robert Davies' algorithm described in + + /** This algorithm is a direct port of Robert Davies' algorithm described in * - * Davies, Robert. "The distribution of a linear combination of chi-squared - * random variables." Applied Statistics 29 323-333. 1980. + * Davies, Robert. "The distribution of a linear combination of chi-squared random variables." + * Applied Statistics 29 323-333. 1980. * * The Fortran code was published with the aforementioned paper. A port to C is available on * Davies' website http://www.robertnz.net/download.html . At the time of retrieval (2023-01-15), * the code lacks a description of its license. On 2023-01-18 0304 ET I received personal e-mail * correspondence from Robert Davies indicating: * - * Assume it has the MIT license. That is on my todo list to say the MIT license applies to - * all the software on the website unless specified otherwise. - * - **/ - import GeneralizedChiSquaredDistribution._ + * Assume it has the MIT license. That is on my todo list to say the MIT license applies to all + * the software on the website unless specified otherwise. 
+ */ import DaviesAlgorithm._ + import GeneralizedChiSquaredDistribution._ private[this] val r: Int = lb.length private[this] var count: Int = 0 - private[this] var ndtsrt: Boolean = true // "need to sort" + private[this] var ndtsrt: Boolean = true // "need to sort" private[this] var fail: Boolean = true private[this] var th: Array[Int] = new Array[Int](r) private[this] var intl: Double = 0.0 @@ -120,8 +119,7 @@ class DaviesAlgorithm( xconst = xconst + lj * (ncj / y + nj) / y sum1 = (sum1 + ncj * square(x / y) + - nj * (square(x) / y + log1(-x, false)) - ) + nj * (square(x) / y + log1(-x, false))) j -= 1 } @@ -133,7 +131,8 @@ class DaviesAlgorithm( var u2 = _u2 var u1 = 0.0 var c1 = mean - val rb = 2.0 * (if (u2 > 0.0) { lmax } else { lmin }) + val rb = 2.0 * (if (u2 > 0.0) { lmax } + else { lmin }) var u = u2 / (1.0 + u2 * rb) @@ -176,13 +175,13 @@ class DaviesAlgorithm( def truncation(_u: Double, _tausq: Double): Double = { counter() var u = _u - var tausq = _tausq + val tausq = _tausq var sum1 = 0.0 var prod2 = 0.0 var prod3 = 0.0 var s = 0 - var sum2 = (sigsq + tausq) * square(u) + val sum2 = (sigsq + tausq) * square(u) var prod1 = 2.0 * sum2 u = 2.0 * u @@ -263,7 +262,7 @@ class DaviesAlgorithm( while (i < 4) { u = ut / divisForFindu(i) - if ( truncation(u, 0.0) <= accx ) { + if (truncation(u, 0.0) <= accx) { ut = u } @@ -278,9 +277,9 @@ class DaviesAlgorithm( var k = nterm while (k >= 0) { val u = (k + 0.5) * interv - var sum1 = - 2.0 * u * c + var sum1 = -2.0 * u * c var sum2 = Math.abs(sum1) - var sum3 = - 0.5 * sigsq * square(u) + var sum3 = -0.5 * sigsq * square(u) var j = r - 1 while (j >= 0) { @@ -311,14 +310,14 @@ class DaviesAlgorithm( } } - def cfe(x: Double): Double = { counter(); if (ndtsrt) { order() } var axl = Math.abs(x) - val sxl = if (x > 0.0) { 1.0 } else { -1.0 } + val sxl = if (x > 0.0) { 1.0 } + else { -1.0 } var sum1 = 0.0; var j = r - 1 var break = false @@ -375,10 +374,14 @@ class DaviesAlgorithm( val lj = lb(j) val ncj = nc(j) if (nj < 0) { - throw new HailException(s"Degrees of freedom parameters must all be positive, ${j}'th parameter is ${nj}.") + throw new HailException( + s"Degrees of freedom parameters must all be positive, $j'th parameter is $nj." + ) } if (ncj < 0.0) { - throw new HailException(s"Non-centrality parameters must all be positive, ${j}'th parameter is ${ncj}.") + throw new HailException( + s"Non-centrality parameters must all be positive, $j'th parameter is $ncj." + ) } sd = sd + square(lj) * (2 * nj + 4.0 * ncj) mean = mean + lj * (nj + ncj) @@ -402,7 +405,9 @@ class DaviesAlgorithm( if (lmin == 0.0 && lmax == 0.0 && sigma == 0.0) { val lbStr = lb.mkString("(", ",", ")") - throw new HailException(s"Either weights vector must be non-zero or sigma must be non-zero, found: ${lbStr} and ${sigma}.") + throw new HailException( + s"Either weights vector must be non-zero or sigma must be non-zero, found: $lbStr and $sigma." 
+ ) } sd = Math.sqrt(sd) @@ -420,7 +425,7 @@ class DaviesAlgorithm( /* truncation point with no convergence factor */ utx = findu(utx, .5 * acc1) /* does convergence factor help */ - if (c != 0.0 && (almx > 0.07 * sd)) { + if (c != 0.0 && (almx > 0.07 * sd)) { // FIXME: return the fail parameter val tausq = .25 * acc1 / cfe(c) if (fail) { @@ -455,10 +460,10 @@ class DaviesAlgorithm( throw new DaviesException() } /* find integration interval */ - val divisor = if (d1 > d2) { d1 } else { d2 } + val divisor = if (d1 > d2) { d1 } + else { d2 } intv = 2.0 * pi / divisor - /* calculate number of terms required for main and - auxillary integrations */ + /* calculate number of terms required for main and auxillary integrations */ xnt = utx / intv val xntm = 3.0 / Math.sqrt(acc1) if (xnt > xntm * 1.5) { @@ -508,8 +513,7 @@ class DaviesAlgorithm( qfval = 0.5 - intl trace.absoluteSum = ersm - /* test whether round-off error could be significant - allow for radix 8 or 16 machines */ + /* test whether round-off error could be significant allow for radix 8 or 16 machines */ up = ersm val x = up + acc / 10.0 j = 0 @@ -530,13 +534,12 @@ class DaviesAlgorithm( } object GeneralizedChiSquaredDistribution { - def exp1(x: Double): Double = { + def exp1(x: Double): Double = if (x < -50.0) { 0.0 } else { Math.exp(x) } - } def square(x: Double): Double = x * x @@ -577,23 +580,27 @@ object GeneralizedChiSquaredDistribution { nc: Array[Double], sigma: Double, lim: Int, - acc: Double + acc: Double, ): Double = { assert(n.length == lb.length) assert(lb.length == nc.length) assert(lim >= 0) assert(acc >= 0) - val (value, trace, fault) = new DaviesAlgorithm(c, n, lb, nc, lim, sigma).cdf(acc) + val (value, _, fault) = new DaviesAlgorithm(c, n, lb, nc, lim, sigma).cdf(acc) assert(fault >= 0 && fault <= 2, fault) if (fault == 1) { - throw new RuntimeException(s"Required accuracy ($acc) not achieved. Best value found was: $value.") + throw new RuntimeException( + s"Required accuracy ($acc) not achieved. Best value found was: $value." + ) } if (fault == 2) { - throw new RuntimeException(s"Round-off error is possibly significant. Best value found was: $value.") + throw new RuntimeException( + s"Round-off error is possibly significant. Best value found was: $value." 
+ ) } value @@ -606,7 +613,7 @@ object GeneralizedChiSquaredDistribution { nc: Array[Double], sigma: Double, lim: Int, - acc: Double + acc: Double, ): DaviesResultForPython = { assert(n.length == lb.length) assert(lb.length == nc.length) diff --git a/hail/src/main/scala/is/hail/stats/LeveneHaldane.scala b/hail/src/main/scala/is/hail/stats/LeveneHaldane.scala index 6aae27c468e..ec8da73603b 100644 --- a/hail/src/main/scala/is/hail/stats/LeveneHaldane.scala +++ b/hail/src/main/scala/is/hail/stats/LeveneHaldane.scala @@ -1,20 +1,22 @@ package is.hail.stats import is.hail.utils._ + import org.apache.commons.math3.distribution.AbstractIntegerDistribution import org.apache.commons.math3.random.RandomGenerator // Efficient implementation of the Levene-Haldane distribution, used in exact tests of Hardy-Weinberg equilibrium // See docs/LeveneHaldane.pdf -class LeveneHaldane(val n: Int, - val nA: Int, - val mode: Int, - pRU: Stream[Double], - pLU: Stream[Double], - pN: Double, - rng: RandomGenerator) - extends AbstractIntegerDistribution(rng) { +class LeveneHaldane( + val n: Int, + val nA: Int, + val mode: Int, + pRU: Stream[Double], + pLU: Stream[Double], + pN: Double, + rng: RandomGenerator, +) extends AbstractIntegerDistribution(rng) { // The probability mass function P(nAB), computing no more than necessary def probability(nAB: Int): Double = @@ -25,46 +27,51 @@ class LeveneHaldane(val n: Int, else pLU((mode - nAB) / 2) / pN - // P(n0 < nAB <= n1), implemented to minimize round-off error but take advantage of the sub-geometric tails + /* P(n0 < nAB <= n1), implemented to minimize round-off error but take advantage of the + * sub-geometric tails */ override def cumulativeProbability(n0: Int, n1: Int): Double = if (n0 >= n1 || n0 >= nA || n1 < nA % 2) 0.0 else if (n0 >= mode) { - val cutoff = pRU((n0 - mode) / 2 + 1) * 1.0E-16 + val cutoff = pRU((n0 - mode) / 2 + 1) * 1.0e-16 pRU.slice((n0 - mode) / 2 + 1, (n1 - mode) / 2 + 1).takeWhile(_ > cutoff).sum / pN - } - else if (n1 < mode) { - val cutoff = pLU((mode - n1 + 1) / 2) * 1.0E-16 + } else if (n1 < mode) { + val cutoff = pLU((mode - n1 + 1) / 2) * 1.0e-16 pLU.slice((mode - n1 + 1) / 2, (mode - n0 + 1) / 2).takeWhile(_ > cutoff).sum / pN - } - else { - val cutoff = 1.0E-16 + } else { + val cutoff = 1.0e-16 (pLU.slice(1, (mode - n0 + 1) / 2).takeWhile(_ > cutoff).sum + pRU.slice(0, (n1 - mode) / 2 + 1).takeWhile(_ > cutoff).sum) / pN } + // P(nAB <= n1) def cumulativeProbability(n1: Int): Double = cumulativeProbability(-1, n1) + // P(nAB > n0) def survivalFunction(n0: Int): Double = cumulativeProbability(n0, nA) // Exact tests with the mid-p-value correction: - // half the probability of the observed outcome nAB plus the probabilities of those "more extreme", i.e. + /* half the probability of the observed outcome nAB plus the probabilities of those "more + * extreme", i.e. 
*/ // greater, lesser, or of smaller probability, respectively // (in the latter case weighting outcomes of equal probability by 1/2) def rightMidP(nAB: Int) = survivalFunction(nAB) + 0.5 * probability(nAB) + def leftMidP(nAB: Int) = cumulativeProbability(nAB) - 0.5 * probability(nAB) + def exactMidP(nAB: Int) = { val p0U = probability(nAB) * pN if (D_==(p0U, 0.0)) 0.0 else { - val cutoff = p0U * 0.5E-16 + val cutoff = p0U * 0.5e-16 def mpU(s: Stream[Double]): Double = { - val (sEq, sLess) = s.dropWhile(D_>(_, p0U, tolerance = 1.0E-12)).span(D_==(_, p0U, tolerance = 1.0E-12)) + val (sEq, sLess) = + s.dropWhile(D_>(_, p0U, tolerance = 1.0e-12)).span(D_==(_, p0U, tolerance = 1.0e-12)) 0.5 * sEq.sum + sLess.takeWhile(_ > cutoff).sum } (mpU(pLU.tail) + mpU(pRU)) / pN @@ -73,15 +80,15 @@ class LeveneHaldane(val n: Int, def nB: Int = 2 * n - nA def getNumericalMean: Double = 1.0 * nA * nB / (2 * n - 1) + def getNumericalVariance: Double = 1.0 * nA * nB / (2 * n - 1) * (1 + (nA - 1.0) * (nB - 1) / (2 * n - 3) - 1.0 * nA * nB / (2 * n - 1)) - def isSupportConnected: Boolean = true // interpreted as restricted to the even or odd integers, + def isSupportConnected: Boolean = true // interpreted as restricted to the even or odd integers, def getSupportUpperBound: Int = nA def getSupportLowerBound: Int = nA % 2 } - object LeveneHaldane { def apply(n: Int, nA: Int, rng: RandomGenerator): LeveneHaldane = { @@ -89,7 +96,9 @@ object LeveneHaldane { val nB = 2 * n - nA val parity = nA % 2 - val mode = ((x: Double) => 2 * math.round((x - parity) / 2) + parity)((nA + 1.0) * (nB + 1) / (2 * n + 3)).toInt + val mode = ((x: Double) => 2 * math.round((x - parity) / 2) + parity)( + (nA + 1.0) * (nB + 1) / (2 * n + 3) + ).toInt def pRUfrom(nAB: Int, p: Double): Stream[Double] = p #:: pRUfrom(nAB + 2, p * (nA - nAB) * (nB - nAB) / ((nAB + 2.0) * (nAB + 1))) @@ -101,7 +110,7 @@ object LeveneHaldane { val pLU = pLUfrom(mode, 1.0) // Normalization constant - val pN = pRU.takeWhile(_ > 1.0E-16).sum + pLU.takeWhile(_ > 1.0E-16).sum - 1.0 + val pN = pRU.takeWhile(_ > 1.0e-16).sum + pLU.takeWhile(_ > 1.0e-16).sum - 1.0 new LeveneHaldane(n, nA, mode, pRU, pLU, pN, rng) } @@ -109,4 +118,4 @@ object LeveneHaldane { // If we ever want to sample, it seems standard to replace `null` with a // default Random Generator `new Well19937c()` def apply(n: Int, nA: Int): LeveneHaldane = LeveneHaldane(n, nA, null) -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/stats/LinearMixedModel.scala b/hail/src/main/scala/is/hail/stats/LinearMixedModel.scala index 4c86a6d15a2..53ec21e4afe 100644 --- a/hail/src/main/scala/is/hail/stats/LinearMixedModel.scala +++ b/hail/src/main/scala/is/hail/stats/LinearMixedModel.scala @@ -1,6 +1,5 @@ package is.hail.stats -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} import is.hail.annotations.{BroadcastRow, Region, RegionValue, RegionValueBuilder} import is.hail.backend.ExecuteContext import is.hail.backend.spark.SparkTaskContext @@ -12,33 +11,65 @@ import is.hail.types.TableType import is.hail.types.physical.{PCanonicalStruct, PFloat64, PInt64} import is.hail.types.virtual.TStruct import is.hail.utils._ + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} import org.apache.spark.storage.StorageLevel -case class LMMData(gamma: Double, residualSq: Double, py: BDV[Double], px: BDM[Double], d: BDV[Double], - ydy: Double, xdy: BDV[Double], xdx: BDM[Double], yOpt: Option[BDV[Double]], xOpt: Option[BDM[Double]]) +case class LMMData( + gamma: Double, + residualSq: 
Double, + py: BDV[Double], + px: BDM[Double], + d: BDV[Double], + ydy: Double, + xdy: BDV[Double], + xdx: BDM[Double], + yOpt: Option[BDV[Double]], + xOpt: Option[BDM[Double]], +) object LinearMixedModel { - def pyApply(gamma: Double, residualSq: Double, py: Array[Double], px: BDM[Double], d: Array[Double], - ydy: Double, xdy: Array[Double], xdx: BDM[Double], + def pyApply( + gamma: Double, + residualSq: Double, + py: Array[Double], + px: BDM[Double], + d: Array[Double], + ydy: Double, + xdy: Array[Double], + xdx: BDM[Double], // yOpt, xOpt can be null - yOpt: Array[Double], xOpt: BDM[Double]): LinearMixedModel = { - + yOpt: Array[Double], + xOpt: BDM[Double], + ): LinearMixedModel = new LinearMixedModel( - LMMData(gamma, residualSq, BDV(py), px, BDV(d), ydy, BDV(xdy), xdx, Option(yOpt).map(BDV(_)), Option(xOpt))) - } - - private val rowType = PCanonicalStruct(true, - "idx" -> PInt64(), - "beta" -> PFloat64(), - "sigma_sq" -> PFloat64(), - "chi_sq" -> PFloat64(), - "p_value" -> PFloat64()) + LMMData( + gamma, + residualSq, + BDV(py), + px, + BDV(d), + ydy, + BDV(xdy), + xdx, + Option(yOpt).map(BDV(_)), + Option(xOpt), + ) + ) + + private val rowType = PCanonicalStruct( + true, + "idx" -> PInt64(), + "beta" -> PFloat64(), + "sigma_sq" -> PFloat64(), + "chi_sq" -> PFloat64(), + "p_value" -> PFloat64(), + ) private val tableType = TableType(rowType.virtualType, FastSeq("idx"), TStruct.empty) - def toTableIR(ctx: ExecuteContext, rvd: RVD): TableIR = { + def toTableIR(ctx: ExecuteContext, rvd: RVD): TableIR = TableLiteral(TableValue(ctx, tableType, BroadcastRow.empty(ctx), rvd), ctx.theHailClassLoader) - } } class LinearMixedModel(lmmData: LMMData) { @@ -53,7 +84,9 @@ class LinearMixedModel(lmmData: LMMData) { def fitLowRank(ctx: ExecuteContext, pa_t: RowMatrix, a_t: RowMatrix): TableIR = { if (pa_t.nRows != a_t.nRows) - fatal(s"pa_t and a_t must have the same number of rows, but found ${pa_t.nRows} and ${a_t.nRows}") + fatal( + s"pa_t and a_t must have the same number of rows, but found ${pa_t.nRows} and ${a_t.nRows}" + ) else if (!(pa_t.partitionCounts() sameElements a_t.partitionCounts())) fatal(s"pa_t and a_t both have ${pa_t.nRows} rows, but row partitions are not aligned") @@ -61,7 +94,8 @@ class LinearMixedModel(lmmData: LMMData) { val rowType = LinearMixedModel.rowType val rdd = pa_t.rows.zipPartitions(a_t.rows) { case (itPAt, itAt) => - val LMMData(gamma, nullResidualSq, py, px, d, ydy, xdy0, xdx0, Some(y), Some(x)) = lmmDataBc.value + val LMMData(gamma, nullResidualSq, py, px, d, ydy, xdy0, xdx0, Some(y), Some(x)) = + lmmDataBc.value val xdy = xdy0.copy val xdx = xdx0.copy val n = x.rows @@ -83,7 +117,10 @@ class LinearMixedModel(lmmData: LMMData) { xdy(0) = (py dot dpa) + gamma * (y dot a) xdx(0, 0) = (pa dot dpa) + gamma * (a dot a) - xdx(r1, r0) := (dpa.t * px).t + gamma * (a.t * x).t // if px and x are not copied, the forms px.t * dpa and x.t * a result in a subtle bug + xdx( + r1, + r0, + ) := (dpa.t * px).t + gamma * (a.t * x).t // if px and x are not copied, the forms px.t * dpa and x.t * a result in a subtle bug xdx(r0, r1) := xdx(r1, r0).t region.clear() @@ -97,7 +134,7 @@ class LinearMixedModel(lmmData: LMMData) { rvb.startStruct() rvb.addLong(i) - rvb.addDouble(beta(0)) // could expand to return all coefficients + rvb.addDouble(beta(0)) // could expand to return all coefficients rvb.addDouble(sigmaSq) rvb.addDouble(chiSq) rvb.addDouble(pValue) @@ -119,7 +156,8 @@ class LinearMixedModel(lmmData: LMMData) { val rvd = RVD( RVDType(rowType, LinearMixedModel.tableType.key), 
pa_t.partitioner(), - ContextRDD.weaken(rdd).toCRDDPtr).persist(ctx, StorageLevel.MEMORY_AND_DISK) + ContextRDD.weaken(rdd).toCRDDPtr, + ).persist(ctx, StorageLevel.MEMORY_AND_DISK) LinearMixedModel.toTableIR(ctx, rvd) } @@ -148,7 +186,10 @@ class LinearMixedModel(lmmData: LMMData) { xdy(0) = py dot dpa xdx(0, 0) = pa dot dpa - xdx(r1, r0) := (dpa.t * px).t // if px is not copied, the form px.t * dpa results in a subtle bug + xdx( + r1, + r0, + ) := (dpa.t * px).t // if px is not copied, the form px.t * dpa results in a subtle bug xdx(r0, r1) := xdx(r1, r0).t region.clear() @@ -162,7 +203,9 @@ class LinearMixedModel(lmmData: LMMData) { rvb.startStruct() rvb.addLong(i) - rvb.addDouble(beta(0)) // could expand to return all coefficients, or switch to block matrix projection trick + rvb.addDouble( + beta(0) + ) // could expand to return all coefficients, or switch to block matrix projection trick rvb.addDouble(sigmaSq) rvb.addDouble(chiSq) rvb.addDouble(pValue) @@ -185,7 +228,8 @@ class LinearMixedModel(lmmData: LMMData) { val rvd = RVD( RVDType(rowType, LinearMixedModel.tableType.key), pa_t.partitioner(), - ContextRDD.weaken(rdd).toCRDDPtr).persist(ctx, StorageLevel.MEMORY_AND_DISK) + ContextRDD.weaken(rdd).toCRDDPtr, + ).persist(ctx, StorageLevel.MEMORY_AND_DISK) LinearMixedModel.toTableIR(ctx, rvd) } diff --git a/hail/src/main/scala/is/hail/stats/LinearRegressionModel.scala b/hail/src/main/scala/is/hail/stats/LinearRegressionModel.scala index b0e67cfd198..9e958d092bc 100644 --- a/hail/src/main/scala/is/hail/stats/LinearRegressionModel.scala +++ b/hail/src/main/scala/is/hail/stats/LinearRegressionModel.scala @@ -1,8 +1,9 @@ package is.hail.stats -import breeze.linalg.{Matrix, Vector} import is.hail.annotations.Annotation import is.hail.types.virtual.{TFloat64, TStruct} + +import breeze.linalg.{Matrix, Vector} import net.sourceforge.jdistlib.T object LinearRegressionModel { @@ -10,9 +11,17 @@ object LinearRegressionModel { ("beta", TFloat64), ("se", TFloat64), ("t_stat", TFloat64), - ("p_value", TFloat64)) + ("p_value", TFloat64), + ) - def fit(x: Vector[Double], y: Vector[Double], yyp: Double, qt: Matrix[Double], qty: Vector[Double], d: Int): Annotation = { + def fit( + x: Vector[Double], + y: Vector[Double], + yyp: Double, + qt: Matrix[Double], + qty: Vector[Double], + d: Int, + ): Annotation = { val qtx = qt * x val xxp = (x dot x) - (qtx dot qtx) val xyp = (x dot y) - (qtx dot qty) diff --git a/hail/src/main/scala/is/hail/stats/LogisticRegressionModel.scala b/hail/src/main/scala/is/hail/stats/LogisticRegressionModel.scala index aefa95aff99..72412733cf5 100644 --- a/hail/src/main/scala/is/hail/stats/LogisticRegressionModel.scala +++ b/hail/src/main/scala/is/hail/stats/LogisticRegressionModel.scala @@ -1,13 +1,19 @@ package is.hail.stats -import breeze.linalg._ -import breeze.numerics._ import is.hail.annotations.RegionValueBuilder import is.hail.types.virtual._ import is.hail.utils.fatal +import breeze.linalg._ +import breeze.numerics._ + object LogisticRegressionTest { - val tests = Map("wald" -> WaldTest, "lrt" -> LikelihoodRatioTest, "score" -> LogisticScoreTest, "firth" -> LogisticFirthTest) + val tests = Map( + "wald" -> WaldTest, + "lrt" -> LikelihoodRatioTest, + "score" -> LogisticScoreTest, + "firth" -> LogisticFirthTest, + ) } abstract class GLMTest extends Serializable { @@ -17,40 +23,43 @@ abstract class GLMTest extends Serializable { nullFit: GLMFit, link: String, maxIter: Int, - tol: Double + tol: Double, ): GLMTestResult[GLMStats] val schema: TStruct } abstract class 
GLMStats { - def addToRVB(rvb: RegionValueBuilder) + def addToRVB(rvb: RegionValueBuilder): Unit } class GLMTestResult[+T <: GLMStats](val stats: Option[T], private val size: Int) { - def addToRVB(rvb: RegionValueBuilder) { + def addToRVB(rvb: RegionValueBuilder): Unit = stats match { case Some(s) => s.addToRVB(rvb) case None => rvb.skipFields(size) } - } } -class GLMTestResultWithFit[T <: GLMStats](override val stats: Option[T], private val size: Int, val fitStats: GLMFit) extends GLMTestResult[T](stats, size) { - override def addToRVB(rvb: RegionValueBuilder) { +class GLMTestResultWithFit[T <: GLMStats]( + override val stats: Option[T], + private val size: Int, + val fitStats: GLMFit, +) extends GLMTestResult[T](stats, size) { + override def addToRVB(rvb: RegionValueBuilder): Unit = { super.addToRVB(rvb) fitStats.addToRVB(rvb) } } - object WaldTest extends GLMTest { val schema: TStruct = TStruct( ("beta", TFloat64), ("standard_error", TFloat64), ("z_stat", TFloat64), ("p_value", TFloat64), - ("fit", GLMFit.schema)) + ("fit", GLMFit.schema), + ) def test( X: DenseMatrix[Double], @@ -58,7 +67,7 @@ object WaldTest extends GLMTest { nullFit: GLMFit, link: String, maxIter: Int, - tol: Double + tol: Double, ): GLMTestResultWithFit[WaldStats] = { require(nullFit.fisher.isDefined) @@ -67,7 +76,7 @@ object WaldTest extends GLMTest { case "poisson" => new PoissonRegressionModel(X, y) case _ => fatal("link must be logistic or poisson") } - val fit = model.fit(Some(nullFit), maxIter=maxIter, tol=tol) + val fit = model.fit(Some(nullFit), maxIter = maxIter, tol = tol) val waldStats = if (fit.converged) { try { @@ -77,8 +86,8 @@ object WaldTest extends GLMTest { Some(WaldStats(fit.b, se, z, p)) } catch { - case e: breeze.linalg.MatrixSingularException => None - case e: breeze.linalg.NotConvergedException => None + case _: breeze.linalg.MatrixSingularException => None + case _: breeze.linalg.NotConvergedException => None } } else None @@ -87,7 +96,12 @@ object WaldTest extends GLMTest { } } -case class WaldStats(b: DenseVector[Double], se: DenseVector[Double], z: DenseVector[Double], p: DenseVector[Double]) extends GLMStats { +case class WaldStats( + b: DenseVector[Double], + se: DenseVector[Double], + z: DenseVector[Double], + p: DenseVector[Double], +) extends GLMStats { def addToRVB(rvb: RegionValueBuilder): Unit = { rvb.addDouble(b(-1)) rvb.addDouble(se(-1)) @@ -96,13 +110,13 @@ case class WaldStats(b: DenseVector[Double], se: DenseVector[Double], z: DenseVe } } - object LikelihoodRatioTest extends GLMTest { val schema = TStruct( ("beta", TFloat64), ("chi_sq_stat", TFloat64), ("p_value", TFloat64), - ("fit", GLMFit.schema)) + ("fit", GLMFit.schema), + ) def test( X: DenseMatrix[Double], @@ -110,7 +124,7 @@ object LikelihoodRatioTest extends GLMTest { nullFit: GLMFit, link: String, maxIter: Int, - tol: Double + tol: Double, ): GLMTestResultWithFit[LikelihoodRatioStats] = { val m = X.cols val m0 = nullFit.b.length @@ -119,7 +133,7 @@ object LikelihoodRatioTest extends GLMTest { case "poisson" => new PoissonRegressionModel(X, y) case _ => fatal("link must be logistic or poisson") } - val fit = model.fit(Some(nullFit), maxIter=maxIter, tol=tol) + val fit = model.fit(Some(nullFit), maxIter = maxIter, tol = tol) val lrStats = if (fit.converged) { @@ -147,7 +161,8 @@ object LogisticFirthTest extends GLMTest { ("beta", TFloat64), ("chi_sq_stat", TFloat64), ("p_value", TFloat64), - ("fit", GLMFit.schema)) + ("fit", GLMFit.schema), + ) def test( X: DenseMatrix[Double], @@ -155,20 +170,20 @@ object 
LogisticFirthTest extends GLMTest { nullFit: GLMFit, link: String, maxIter: Int, - tol: Double + tol: Double, ): GLMTestResultWithFit[FirthStats] = { require(link == "logistic") val m = X.cols val m0 = nullFit.b.length val model = new LogisticRegressionModel(X, y) - val nullFitFirth = model.fitFirth(nullFit.b, maxIter=maxIter, tol=tol) + val nullFitFirth = model.fitFirth(nullFit.b, maxIter = maxIter, tol = tol) if (nullFitFirth.converged) { val nullFitFirthb = DenseVector.zeros[Double](m) nullFitFirthb(0 until m0) := nullFitFirth.b - val fitFirth = model.fitFirth(nullFitFirthb, maxIter=maxIter, tol=tol) + val fitFirth = model.fitFirth(nullFitFirthb, maxIter = maxIter, tol = tol) val firthStats = if (fitFirth.converged) { @@ -185,7 +200,6 @@ object LogisticFirthTest extends GLMTest { } } - case class FirthStats(b: DenseVector[Double], chi2: Double, p: Double) extends GLMStats { def addToRVB(rvb: RegionValueBuilder): Unit = { rvb.addDouble(b(-1)) @@ -194,12 +208,11 @@ case class FirthStats(b: DenseVector[Double], chi2: Double, p: Double) extends G } } - object LogisticScoreTest extends GLMTest { val schema: TStruct = TStruct( ("chi_sq_stat", TFloat64), - ("p_value", TFloat64)) - + ("p_value", TFloat64), + ) def test( X: DenseMatrix[Double], @@ -207,7 +220,7 @@ object LogisticScoreTest extends GLMTest { nullFit: GLMFit, link: String, maxIter: Int, - tol: Double + tol: Double, ): GLMTestResult[ScoreStats] = { require(link == "logistic") require(nullFit.score.isDefined && nullFit.fisher.isDefined) @@ -241,8 +254,8 @@ object LogisticScoreTest extends GLMTest { Some(ScoreStats(chi2, p)) } catch { - case e: breeze.linalg.MatrixSingularException => None - case e: breeze.linalg.NotConvergedException => None + case _: breeze.linalg.MatrixSingularException => None + case _: breeze.linalg.NotConvergedException => None } } @@ -250,7 +263,6 @@ object LogisticScoreTest extends GLMTest { } } - case class ScoreStats(chi2: Double, p: Double) extends GLMStats { def addToRVB(rvb: RegionValueBuilder): Unit = { rvb.addDouble(chi2) @@ -258,15 +270,14 @@ case class ScoreStats(chi2: Double, p: Double) extends GLMStats { } } - abstract class GeneralLinearModel { def bInterceptOnly(): DenseVector[Double] def fit(optNullFit: Option[GLMFit], maxIter: Int, tol: Double): GLMFit } - -class LogisticRegressionModel(X: DenseMatrix[Double], y: DenseVector[Double]) extends GeneralLinearModel { +class LogisticRegressionModel(X: DenseMatrix[Double], y: DenseVector[Double]) + extends GeneralLinearModel { require(y.length == X.rows) val n: Int = X.rows @@ -334,8 +345,8 @@ class LogisticRegressionModel(X: DenseMatrix[Double], y: DenseVector[Double]) ex fisher := X.t * (X(::, *) *:* (mu *:* (1d - mu))) } } catch { - case e: breeze.linalg.MatrixSingularException => exploded = true - case e: breeze.linalg.NotConvergedException => exploded = true + case _: breeze.linalg.MatrixSingularException => exploded = true + case _: breeze.linalg.NotConvergedException => exploded = true } } @@ -361,19 +372,24 @@ class LogisticRegressionModel(X: DenseMatrix[Double], y: DenseVector[Double]) ex val sqrtW = sqrt(mu *:* (1d - mu)) val QR = qr.reduced(X(::, *) *:* sqrtW) val h = QR.q(*, ::).map(r => r dot r) - val deltaB = TriSolve(QR.r(0 until m0, 0 until m0), QR.q(::, 0 until m0).t * (((y - mu) + (h *:* (0.5 - mu))) /:/ sqrtW)) + val deltaB = TriSolve( + QR.r(0 until m0, 0 until m0), + QR.q(::, 0 until m0).t * (((y - mu) + (h *:* (0.5 - mu))) /:/ sqrtW), + ) if (deltaB(0).isNaN) { exploded = true } else if (max(abs(deltaB)) < tol && iter > 1) { 
converged = true - logLkhd = sum(breeze.numerics.log((y *:* mu) + ((1d - y) *:* (1d - mu)))) + sum(log(abs(diag(QR.r)))) + logLkhd = sum(breeze.numerics.log((y *:* mu) + ((1d - y) *:* (1d - mu)))) + sum( + log(abs(diag(QR.r))) + ) } else { b += deltaB } } catch { - case e: breeze.linalg.MatrixSingularException => exploded = true - case e: breeze.linalg.NotConvergedException => exploded = true + case _: breeze.linalg.MatrixSingularException => exploded = true + case _: breeze.linalg.NotConvergedException => exploded = true } } @@ -385,7 +401,8 @@ object GLMFit { val schema: Type = TStruct( ("n_iterations", TInt32), ("converged", TBoolean), - ("exploded", TBoolean)) + ("exploded", TBoolean), + ) } case class GLMFit( @@ -395,7 +412,8 @@ case class GLMFit( logLkhd: Double, nIter: Int, converged: Boolean, - exploded: Boolean) { + exploded: Boolean, +) { def addToRVB(rvb: RegionValueBuilder): Unit = { rvb.startStruct() diff --git a/hail/src/main/scala/is/hail/stats/PoissonRegressionModel.scala b/hail/src/main/scala/is/hail/stats/PoissonRegressionModel.scala index 1678ee38f94..2cc2264b79c 100644 --- a/hail/src/main/scala/is/hail/stats/PoissonRegressionModel.scala +++ b/hail/src/main/scala/is/hail/stats/PoissonRegressionModel.scala @@ -1,19 +1,19 @@ package is.hail.stats -import breeze.linalg._ -import breeze.numerics._ import is.hail.types.virtual.{TFloat64, TStruct} +import breeze.linalg._ +import breeze.numerics._ object PoissonRegressionTest { val tests = Map("wald" -> WaldTest, "lrt" -> LikelihoodRatioTest, "score" -> PoissonScoreTest) } - object PoissonScoreTest extends GLMTest { val schema: TStruct = TStruct( ("chi_sq_stat", TFloat64), - ("p_value", TFloat64)) + ("p_value", TFloat64), + ) def test( X: DenseMatrix[Double], @@ -21,7 +21,7 @@ object PoissonScoreTest extends GLMTest { nullFit: GLMFit, link: String, maxIter: Int, - tol: Double + tol: Double, ): GLMTestResult[ScoreStats] = { require(link == "poisson") require(nullFit.score.isDefined && nullFit.fisher.isDefined) @@ -55,8 +55,8 @@ object PoissonScoreTest extends GLMTest { Some(ScoreStats(chi2, p)) } catch { - case e: breeze.linalg.MatrixSingularException => None - case e: breeze.linalg.NotConvergedException => None + case _: breeze.linalg.MatrixSingularException => None + case _: breeze.linalg.NotConvergedException => None } } @@ -64,8 +64,8 @@ object PoissonScoreTest extends GLMTest { } } - -class PoissonRegressionModel(X: DenseMatrix[Double], y: DenseVector[Double]) extends GeneralLinearModel { +class PoissonRegressionModel(X: DenseMatrix[Double], y: DenseVector[Double]) + extends GeneralLinearModel { require(y.length == X.rows) val n: Int = X.rows @@ -133,12 +133,13 @@ class PoissonRegressionModel(X: DenseMatrix[Double], y: DenseVector[Double]) ext fisher := X.t * (X(::, *) *:* mu) } } catch { - case e: breeze.linalg.MatrixSingularException => exploded = true - case e: breeze.linalg.NotConvergedException => exploded = true + case _: breeze.linalg.MatrixSingularException => exploded = true + case _: breeze.linalg.NotConvergedException => exploded = true } } - // dropping constant that depends only on y: -sum(breeze.numerics.lgamma(DenseVector(y.data.filter(_ != 0.0)))) + /* dropping constant that depends only on y: + * -sum(breeze.numerics.lgamma(DenseVector(y.data.filter(_ != 0.0)))) */ val logLkhd = (y dot breeze.numerics.log(mu)) - sum(mu) GLMFit(b, Some(score), Some(fisher), logLkhd, iter, converged, exploded) diff --git a/hail/src/main/scala/is/hail/stats/RegressionUtils.scala 
b/hail/src/main/scala/is/hail/stats/RegressionUtils.scala index 2282cf94d01..81533606a65 100644 --- a/hail/src/main/scala/is/hail/stats/RegressionUtils.scala +++ b/hail/src/main/scala/is/hail/stats/RegressionUtils.scala @@ -1,15 +1,17 @@ package is.hail.stats -import breeze.linalg._ -import is.hail.annotations.{Region, RegionValue} +import is.hail.annotations.Region import is.hail.expr.ir.{IntArrayBuilder, MatrixValue} import is.hail.types.physical.{PArray, PStruct} import is.hail.types.virtual.TFloat64 import is.hail.utils._ + +import breeze.linalg._ import org.apache.spark.sql.Row -object RegressionUtils { - def setMeanImputedDoubles(data: Array[Double], +object RegressionUtils { + def setMeanImputedDoubles( + data: Array[Double], offset: Int, completeColIdx: Array[Int], missingCompleteCols: IntArrayBuilder, @@ -18,7 +20,8 @@ object RegressionUtils { entryArrayType: PArray, entryType: PStruct, entryArrayIdx: Int, - fieldIdx: Int) : Unit = { + fieldIdx: Int, + ): Unit = { missingCompleteCols.clear() val n = completeColIdx.length @@ -52,8 +55,8 @@ object RegressionUtils { } // IndexedSeq indexed by column, Array by field - def getColumnVariables(mv: MatrixValue, names: Array[String]): IndexedSeq[Array[Option[Double]]] = { - val colType = mv.typ.colType + def getColumnVariables(mv: MatrixValue, names: Array[String]) + : IndexedSeq[Array[Option[Double]]] = { assert(names.forall(name => mv.typ.colType.field(name).typ == TFloat64)) val fieldIndices = names.map { name => val field = mv.typ.colType.field(name) @@ -71,7 +74,8 @@ object RegressionUtils { def getPhenoCovCompleteSamples( mv: MatrixValue, yField: String, - covFields: Array[String]): (DenseVector[Double], DenseMatrix[Double], Array[Int]) = { + covFields: Array[String], + ): (DenseVector[Double], DenseMatrix[Double], Array[Int]) = { val (y, covs, completeSamples) = getPhenosCovCompleteSamples(mv, Array(yField), covFields) @@ -81,14 +85,15 @@ object RegressionUtils { def getPhenosCovCompleteSamples( mv: MatrixValue, yFields: Array[String], - covFields: Array[String]): (DenseMatrix[Double], DenseMatrix[Double], Array[Int]) = { + covFields: Array[String], + ): (DenseMatrix[Double], DenseMatrix[Double], Array[Int]) = { val nPhenos = yFields.length val nCovs = covFields.length if (nPhenos == 0) fatal("No phenotypes present.") - + val yIS = getColumnVariables(mv, yFields) val covIS = getColumnVariables(mv, covFields) @@ -103,13 +108,15 @@ object RegressionUtils { fatal("No complete samples: each sample is missing its phenotype or some covariate") val yArray = yForCompleteSamples.flatMap(_.map(_.get)).toArray - val y = new DenseMatrix(rows = n, cols = nPhenos, data = yArray, offset = 0, majorStride = nPhenos, isTranspose = true) + val y = new DenseMatrix(rows = n, cols = nPhenos, data = yArray, offset = 0, + majorStride = nPhenos, isTranspose = true) val covArray = covForCompleteSamples.flatMap(_.map(_.get)).toArray - val cov = new DenseMatrix(rows = n, cols = nCovs, data = covArray, offset = 0, majorStride = nCovs, isTranspose = true) + val cov = new DenseMatrix(rows = n, cols = nCovs, data = covArray, offset = 0, + majorStride = nCovs, isTranspose = true) if (n < nCols) - warn(s"${ nCols - n } of $nCols samples have a missing phenotype or covariate.") + warn(s"${nCols - n} of $nCols samples have a missing phenotype or covariate.") (y, cov, completeSamples.toArray) } diff --git a/hail/src/main/scala/is/hail/stats/eigSymD.scala b/hail/src/main/scala/is/hail/stats/eigSymD.scala index 4b5ca338219..2367330d9f3 100644 --- 
a/hail/src/main/scala/is/hail/stats/eigSymD.scala +++ b/hail/src/main/scala/is/hail/stats/eigSymD.scala @@ -5,9 +5,8 @@ import breeze.linalg._ import com.github.fommil.netlib.LAPACK.{getInstance => lapack} import org.netlib.util.intW -/** - * Computes all eigenvalues (and optionally right eigenvectors) of the given - * real symmetric matrix X using dsyevd (Divide and Conquer): +/** Computes all eigenvalues (and optionally right eigenvectors) of the given real symmetric matrix + * X using dsyevd (Divide and Conquer): * `http://www.netlib.org/lapack/explore-html/d2/d8a/group__double_s_yeigen_ga694ddc6e5527b6223748e3462013d867.html` * * Based on eigSym in breeze.linalg.eig but replaces dsyev with dsyevd for higher performance: @@ -16,33 +15,34 @@ import org.netlib.util.intW object eigSymD extends UFunc { case class EigSymD[V, M](eigenvalues: V, eigenvectors: M) type DenseEigSymD = EigSymD[DenseVector[Double], DenseMatrix[Double]] + implicit object eigSymD_DM_Impl extends Impl[DenseMatrix[Double], DenseEigSymD] { - def apply(X: DenseMatrix[Double]): DenseEigSymD = { + def apply(X: DenseMatrix[Double]): DenseEigSymD = doeigSymD(X, rightEigenvectors = true) match { case (ev, Some(rev)) => EigSymD(ev, rev) case _ => throw new RuntimeException("Shouldn't be here!") } - } } object justEigenvalues extends UFunc { implicit object eigSymD_DM_Impl extends Impl[DenseMatrix[Double], DenseVector[Double]] { - def apply(X: DenseMatrix[Double]): DenseVector[Double] = { + def apply(X: DenseMatrix[Double]): DenseVector[Double] = doeigSymD(X, rightEigenvectors = false)._1 - } } } - def doeigSymD(X: Matrix[Double], rightEigenvectors: Boolean): (DenseVector[Double], Option[DenseMatrix[Double]]) = { + def doeigSymD(X: Matrix[Double], rightEigenvectors: Boolean) + : (DenseVector[Double], Option[DenseMatrix[Double]]) = { // assumes X is non-empty and symmetric, caller should check if necessary - val JOBZ = if (rightEigenvectors) "V" else "N" /* eigenvalues N, eigenvalues & eigenvectors "V" */ + val JOBZ = + if (rightEigenvectors) "V" else "N" /* eigenvalues N, eigenvalues & eigenvectors "V" */ val UPLO = "L" val N = X.rows val A = lowerTriangular(X) - val LDA = scala.math.max(1,N) + val LDA = scala.math.max(1, N) val W = DenseVector.zeros[Double](N) val LWORK = if (N <= 1) @@ -72,9 +72,8 @@ object eigSymD extends UFunc { } } -/** - * Computes all eigenvalues (and optionally right eigenvectors) of the given - * real symmetric matrix X using dsyevr (Relatively Robust Representations): +/** Computes all eigenvalues (and optionally right eigenvectors) of the given real symmetric matrix + * X using dsyevr (Relatively Robust Representations): * `http://www.netlib.org/lapack/explore-html/d2/d8a/group__double_s_yeigen_ga2ad9f4a91cddbf67fe41b621bd158f5c.html` * * Based on eigSym in breeze.linalg.eig but replaces dsyev with dsyevr for higher performance: @@ -83,34 +82,35 @@ object eigSymD extends UFunc { object eigSymR extends UFunc { case class EigSymR[V, M](eigenvalues: V, eigenvectors: M) type DenseeigSymR = EigSymR[DenseVector[Double], DenseMatrix[Double]] + implicit object eigSymR_DM_Impl extends Impl[DenseMatrix[Double], DenseeigSymR] { - def apply(X: DenseMatrix[Double]): DenseeigSymR = { + def apply(X: DenseMatrix[Double]): DenseeigSymR = doeigSymR(X, rightEigenvectors = true) match { case (ev, Some(rev)) => EigSymR(ev, rev) case _ => throw new RuntimeException("Shouldn't be here!") } - } } object justEigenvalues extends UFunc { implicit object eigSymR_DM_Impl extends Impl[DenseMatrix[Double], DenseVector[Double]] { - 
def apply(X: DenseMatrix[Double]): DenseVector[Double] = { + def apply(X: DenseMatrix[Double]): DenseVector[Double] = doeigSymR(X, rightEigenvectors = false)._1 - } } } - private def doeigSymR(X: Matrix[Double], rightEigenvectors: Boolean): (DenseVector[Double], Option[DenseMatrix[Double]]) = { + private def doeigSymR(X: Matrix[Double], rightEigenvectors: Boolean) + : (DenseVector[Double], Option[DenseMatrix[Double]]) = { // assumes X is non-empty and symmetric, caller should check if necessary - val JOBZ = if (rightEigenvectors) "V" else "N" /* eigenvalues N, eigenvalues & eigenvectors "V" */ + val JOBZ = + if (rightEigenvectors) "V" else "N" /* eigenvalues N, eigenvalues & eigenvectors "V" */ val RANGE = "A" // supports eigenvalue range, but not implementing that interface here val UPLO = "L" val N = X.rows val A = X.toDenseMatrix.data - val LDA = scala.math.max(1,N) + val LDA = scala.math.max(1, N) val ABSTOL = -1d // default tolerance val W = DenseVector.zeros[Double](N) val M = new intW(0) @@ -123,8 +123,29 @@ object eigSymR extends UFunc { val IWORK = new Array[Int](LIWORK) val info = new intW(0) - lapack.dsyevr(JOBZ, RANGE, UPLO, N, A, LDA, 0d, 0d, 0, 0, ABSTOL, M, W.data, - Z.data, LDZ, ISUPPZ.data, WORK, LWORK, IWORK, LIWORK, info) + lapack.dsyevr( + JOBZ, + RANGE, + UPLO, + N, + A, + LDA, + 0d, + 0d, + 0, + 0, + ABSTOL, + M, + W.data, + Z.data, + LDZ, + ISUPPZ.data, + WORK, + LWORK, + IWORK, + LIWORK, + info, + ) // A value of info.`val` < 0 would tell us that the i-th argument // of the call to dsyevr was erroneous (where i == |info.`val`|). assert(info.`val` >= 0) @@ -137,10 +158,8 @@ object eigSymR extends UFunc { } object TriSolve { - /* - Solve for x in A * x = b with upper triangular A - http://www.netlib.org/lapack/explore-html/da/dba/group__double_o_t_h_e_rcomputational_ga4e87e579d3e1a56b405d572f868cd9a1.html - */ + /* Solve for x in A * x = b with upper triangular A + * http://www.netlib.org/lapack/explore-html/da/dba/group__double_o_t_h_e_rcomputational_ga4e87e579d3e1a56b405d572f868cd9a1.html */ def apply(A: DenseMatrix[Double], b: DenseVector[Double]): DenseVector[Double] = { require(A.rows == A.cols) @@ -150,7 +169,18 @@ object TriSolve { val info: Int = { val info = new intW(0) - lapack.dtrtrs("U", "N", "N", A.rows, 1, A.toArray, A.rows, x.data, x.length, info) // x := A \ x + lapack.dtrtrs( + "U", + "N", + "N", + A.rows, + 1, + A.toArray, + A.rows, + x.data, + x.length, + info, + ) // x := A \ x info.`val` } diff --git a/hail/src/main/scala/is/hail/stats/package.scala b/hail/src/main/scala/is/hail/stats/package.scala index 44c769421ce..2b87f1d320f 100644 --- a/hail/src/main/scala/is/hail/stats/package.scala +++ b/hail/src/main/scala/is/hail/stats/package.scala @@ -1,16 +1,18 @@ package is.hail -import is.hail.types.physical.{PCanonicalStruct, PFloat64, PStruct} +import is.hail.types.physical.{PCanonicalStruct, PFloat64} import is.hail.utils._ -import net.sourceforge.jdistlib.disttest.{DistributionTest, TestKind} + import net.sourceforge.jdistlib.{Beta, ChiSquare, NonCentralChiSquare, Normal, Poisson} +import net.sourceforge.jdistlib.disttest.{DistributionTest, TestKind} import org.apache.commons.math3.distribution.HypergeometricDistribution package object stats { - def uniroot(fn: Double => Double, min: Double, max: Double, tolerance: Double = 1.220703e-4): Option[Double] = { + def uniroot(fn: Double => Double, min: Double, max: Double, tolerance: Double = 1.220703e-4) + : Option[Double] = { // based on C code in R source code called zeroin.c - // 
https://github.com/wch/r-source/blob/e5b21d0397c607883ff25cca379687b86933d730/src/library/stats/src/zeroin.c + /* https://github.com/wch/r-source/blob/e5b21d0397c607883ff25cca379687b86933d730/src/library/stats/src/zeroin.c */ require(min < max, "interval start must be larger than end") @@ -72,13 +74,15 @@ package object stats { q = (q - 1d) * (t1 - 1d) * (t2 - 1d) } - if (p > 0d) //p was calculated with opposite sign + if (p > 0d) // p was calculated with opposite sign q = -q else p = -p - if (p < (0.75 * cb * q - math.abs(toleranceActual * q) / 2) && - p < math.abs(previousStep * q / 2)) + if ( + p < (0.75 * cb * q - math.abs(toleranceActual * q) / 2) && + p < math.abs(previousStep * q / 2) + ) newStep = p / q } @@ -108,7 +112,9 @@ package object stats { def hardyWeinbergTest(nHomRef: Int, nHet: Int, nHomVar: Int, oneSided: Boolean): Array[Double] = { if (nHomRef < 0 || nHet < 0 || nHomVar < 0) - fatal(s"hardy_weinberg_test: all arguments must be non-negative, got $nHomRef, $nHet, $nHomVar") + fatal( + s"hardy_weinberg_test: all arguments must be non-negative, got $nHomRef, $nHet, $nHomVar" + ) val n = nHomRef + nHet + nHomVar val nAB = nHet @@ -152,15 +158,21 @@ package object stats { "p_value" -> PFloat64(required = true), "odds_ratio" -> PFloat64(required = true), "ci_95_lower" -> PFloat64(required = true), - "ci_95_upper" -> PFloat64(required = true)) + "ci_95_upper" -> PFloat64(required = true), + ) def fisherExactTest(a: Int, b: Int, c: Int, d: Int): Array[Double] = fisherExactTest(a, b, c, d, 1.0, 0.95, "two.sided") - def fisherExactTest(a: Int, b: Int, c: Int, d: Int, + def fisherExactTest( + a: Int, + b: Int, + c: Int, + d: Int, oddsRatio: Double = 1d, confidenceLevel: Double = 0.95, - alternative: String = "two.sided"): Array[Double] = { + alternative: String = "two.sided", + ): Array[Double] = { if (!(a >= 0 && b >= 0 && c >= 0 && d >= 0)) fatal(s"fisher_exact_test: all arguments must be non-negative, got $a, $b, $c, $d") @@ -179,7 +191,9 @@ package object stats { val sampleSize = a + b val numSuccessSample = a - if (!(popSize > 0 && sampleSize > 0 && sampleSize < popSize && numSuccessPopulation > 0 && numSuccessPopulation < popSize)) + if ( + !(popSize > 0 && sampleSize > 0 && sampleSize < popSize && numSuccessPopulation > 0 && numSuccessPopulation < popSize) + ) return Array(Double.NaN, Double.NaN, Double.NaN, Double.NaN) val low = math.max(0, (a + b) - (b + d)) @@ -189,9 +203,8 @@ package object stats { val hgd = new HypergeometricDistribution(null, popSize, numSuccessPopulation, sampleSize) val epsilon = 2.220446e-16 - def dhyper(k: Int, logProb: Boolean = false): Double = { + def dhyper(k: Int, logProb: Boolean): Double = if (logProb) hgd.logProbability(k) else hgd.probability(k) - } val logdc = support.map(dhyper(_, logProb = true)) @@ -201,12 +214,11 @@ package object stats { d.map(_ / d.sum) } - def phyper(k: Int, lower_tail: Boolean = true): Double = { + def phyper(k: Int, lower_tail: Boolean): Double = if (lower_tail) hgd.cumulativeProbability(k) else hgd.upperCumulativeProbability(k) - } def pnhyper(q: Int, ncp: Double = 1d, upper_tail: Boolean = false): Double = { if (ncp == 1d) { @@ -217,16 +229,18 @@ package object stats { } else if (ncp == 0d) { if (upper_tail) if (q <= low) 1d else 0d - else if (q >= low) 1d else 0d + else if (q >= low) 1d + else 0d } else if (ncp == Double.PositiveInfinity) { if (upper_tail) if (q <= high) 1d else 0d - else if (q >= high) 1d else 0d + else if (q >= high) 1d + else 0d } else { dnhyper(ncp) .zipWithIndex - .filter { case (dbl, i) 
=> if (upper_tail) support(i) >= q else support(i) <= q } - .map { case (dbl, i) => dbl } + .filter { case (_, i) => if (upper_tail) support(i) >= q else support(i) <= q } + .map { case (dbl, _) => dbl } .sum } } @@ -242,7 +256,14 @@ package object stats { def unirootMnHyper(fn: Double => Double, x: Double)(t: Double) = mnhyper(fn(t)) - x - def unirootPnHyper(fn: Double => Double, x: Int, upper_tail: Boolean, alpha: Double)(t: Double) = + def unirootPnHyper( + fn: Double => Double, + x: Int, + upper_tail: Boolean, + alpha: Double, + )( + t: Double + ) = pnhyper(x, fn(t), upper_tail) - alpha def mle(x: Double): Double = { @@ -269,7 +290,11 @@ package object stats { if (p > alpha) uniroot(unirootPnHyper(d => d, x, upper_tail = true, alpha), 0d, 1d).getOrElse(Double.NaN) else if (p < alpha) - 1.0 / uniroot(unirootPnHyper(d => 1 / d, x, upper_tail = true, alpha), epsilon, 1d).getOrElse(Double.NaN) + 1.0 / uniroot( + unirootPnHyper(d => 1 / d, x, upper_tail = true, alpha), + epsilon, + 1d, + ).getOrElse(Double.NaN) else 1.0 } @@ -278,13 +303,18 @@ package object stats { def ncpUpper(x: Int, alpha: Double): Double = { if (x == high) { Double.PositiveInfinity - } - else { + } else { val p = pnhyper(x) if (p < alpha) - uniroot(unirootPnHyper(d => d, x, upper_tail = false, alpha), 0d, 1d).getOrElse(Double.NaN) + uniroot(unirootPnHyper(d => d, x, upper_tail = false, alpha), 0d, 1d).getOrElse( + Double.NaN + ) else if (p > alpha) - 1.0 / uniroot(unirootPnHyper(d => 1 / d, x, upper_tail = false, alpha), epsilon, 1d).getOrElse(Double.NaN) + 1.0 / uniroot( + unirootPnHyper(d => 1 / d, x, upper_tail = false, alpha), + epsilon, + 1d, + ).getOrElse(Double.NaN) else 1.0 } @@ -320,21 +350,25 @@ package object stats { Array(pvalue, oddsRatioEstimate, confInterval._1, confInterval._2) } - def dnorm(x: Double, mu: Double, sigma: Double, logP: Boolean): Double = Normal.density(x, mu, sigma, logP) + def dnorm(x: Double, mu: Double, sigma: Double, logP: Boolean): Double = + Normal.density(x, mu, sigma, logP) def dnorm(x: Double): Double = dnorm(x, mu = 0, sigma = 1, logP = false) // Returns the p for which p = Prob(Z < x) with Z a standard normal RV - def pnorm(x: Double, mu: Double, sigma: Double, lowerTail: Boolean, logP: Boolean): Double = Normal.cumulative(x, mu, sigma, lowerTail, logP) + def pnorm(x: Double, mu: Double, sigma: Double, lowerTail: Boolean, logP: Boolean): Double = + Normal.cumulative(x, mu, sigma, lowerTail, logP) def pnorm(x: Double): Double = pnorm(x, mu = 0, sigma = 1, lowerTail = true, logP = false) // Returns the x for which p = Prob(Z < x) with Z a standard normal RV - def qnorm(p: Double, mu: Double, sigma: Double, lowerTail: Boolean, logP: Boolean): Double = Normal.quantile(p, mu, sigma, lowerTail, logP) + def qnorm(p: Double, mu: Double, sigma: Double, lowerTail: Boolean, logP: Boolean): Double = + Normal.quantile(p, mu, sigma, lowerTail, logP) def qnorm(p: Double): Double = qnorm(p, mu = 0, sigma = 1, lowerTail = true, logP = false) - // Returns the p for which p = Prob(Z < x) with Z a RV having the T distribution with n degrees of freedom + /* Returns the p for which p = Prob(Z < x) with Z a RV having the T distribution with n degrees of + * freedom */ def pT(x: Double, n: Double, lower_tail: Boolean, log_p: Boolean): Double = net.sourceforge.jdistlib.T.cumulative(x, n, lower_tail, log_p) @@ -345,31 +379,53 @@ package object stats { def dchisq(x: Double, df: Double): Double = dchisq(x, df, logP = false) - def dnchisq(x: Double, df: Double, ncp: Double, logP: Boolean): Double = 
NonCentralChiSquare.density(x, df, ncp, logP) + def dnchisq(x: Double, df: Double, ncp: Double, logP: Boolean): Double = + NonCentralChiSquare.density(x, df, ncp, logP) def dnchisq(x: Double, df: Double, ncp: Double): Double = dnchisq(x, df, ncp, logP = false) // Returns the p for which p = Prob(Z^2 > x) with Z^2 a chi-squared RV with df degrees of freedom - def pchisqtail(x: Double, df: Double, lowerTail: Boolean, logP: Boolean): Double = ChiSquare.cumulative(x, df, lowerTail, logP) + def pchisqtail(x: Double, df: Double, lowerTail: Boolean, logP: Boolean): Double = + ChiSquare.cumulative(x, df, lowerTail, logP) def pchisqtail(x: Double, df: Double): Double = pchisqtail(x, df, lowerTail = false, logP = false) - def pnchisqtail(x: Double, df: Double, ncp: Double, lowerTail: Boolean, logP: Boolean): Double = NonCentralChiSquare.cumulative(x, df, ncp, lowerTail, logP) + def pnchisqtail(x: Double, df: Double, ncp: Double, lowerTail: Boolean, logP: Boolean): Double = + NonCentralChiSquare.cumulative(x, df, ncp, lowerTail, logP) - def pnchisqtail(x: Double, df: Double, ncp: Double): Double = pnchisqtail(x, df, ncp, lowerTail = false, logP = false) + def pnchisqtail(x: Double, df: Double, ncp: Double): Double = + pnchisqtail(x, df, ncp, lowerTail = false, logP = false) // Returns the x for which p = Prob(Z^2 > x) with Z^2 a chi-squared RV with df degrees of freedom - def qchisqtail(p: Double, df: Double, lowerTail: Boolean, logP: Boolean): Double = ChiSquare.quantile(p, df, lowerTail, logP) + def qchisqtail(p: Double, df: Double, lowerTail: Boolean, logP: Boolean): Double = + ChiSquare.quantile(p, df, lowerTail, logP) def qchisqtail(p: Double, df: Double): Double = qchisqtail(p, df, lowerTail = false, logP = false) - def qnchisqtail(p: Double, df: Double, ncp: Double, lowerTail: Boolean, logP: Boolean): Double = NonCentralChiSquare.quantile(p, df, ncp, lowerTail, logP) - - def qnchisqtail(p: Double, df: Double, ncp: Double): Double = qnchisqtail(p, df, ncp, lowerTail = false, logP = false) - - def pgenchisq(x: Double, w: IndexedSeq[Double], k: IndexedSeq[Int], lam: IndexedSeq[Double], sigma: Double, lim: Int, acc: Double): DaviesResultForPython = { - GeneralizedChiSquaredDistribution.cdfReturnExceptions(x, k.toArray, w.toArray, lam.toArray, sigma, lim, acc) - } + def qnchisqtail(p: Double, df: Double, ncp: Double, lowerTail: Boolean, logP: Boolean): Double = + NonCentralChiSquare.quantile(p, df, ncp, lowerTail, logP) + + def qnchisqtail(p: Double, df: Double, ncp: Double): Double = + qnchisqtail(p, df, ncp, lowerTail = false, logP = false) + + def pgenchisq( + x: Double, + w: IndexedSeq[Double], + k: IndexedSeq[Int], + lam: IndexedSeq[Double], + sigma: Double, + lim: Int, + acc: Double, + ): DaviesResultForPython = + GeneralizedChiSquaredDistribution.cdfReturnExceptions( + x, + k.toArray, + w.toArray, + lam.toArray, + sigma, + lim, + acc, + ) def dbeta(x: Double, a: Double, b: Double): Double = Beta.density(x, a, b, false) @@ -377,7 +433,8 @@ package object stats { def dpois(x: Double, lambda: Double): Double = dpois(x, lambda, logP = false) - def ppois(x: Double, lambda: Double, lowerTail: Boolean, logP: Boolean): Double = new Poisson(lambda).cumulative(x, lowerTail, logP) + def ppois(x: Double, lambda: Double, lowerTail: Boolean, logP: Boolean): Double = + new Poisson(lambda).cumulative(x, lowerTail, logP) def ppois(x: Double, lambda: Double): Double = ppois(x, lambda, lowerTail = true, logP = false) @@ -394,7 +451,8 @@ package object stats { case 0 => TestKind.TWO_SIDED case 1 => 
TestKind.LOWER case 2 => TestKind.GREATER - case _ => fatal(s"""Invalid alternative "$alternative". Must be "two-sided", "less" or "greater".""") + case _ => + fatal(s"""Invalid alternative "$alternative". Must be "two-sided", "less" or "greater".""") } DistributionTest.binomial_test(nSuccess, n, p, kind)(1) diff --git a/hail/src/main/scala/is/hail/types/BaseType.scala b/hail/src/main/scala/is/hail/types/BaseType.scala index eadf183ea88..eac535587e3 100644 --- a/hail/src/main/scala/is/hail/types/BaseType.scala +++ b/hail/src/main/scala/is/hail/types/BaseType.scala @@ -1,7 +1,7 @@ package is.hail.types abstract class BaseType { - override final def toString: String = { + final override def toString: String = { val sb = new StringBuilder pyString(sb) sb.result() @@ -13,7 +13,7 @@ abstract class BaseType { sb.result() } - def pretty(sb: StringBuilder, indent: Int, compact: Boolean) + def pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit def parsableString(): String = toPrettyString(compact = true) diff --git a/hail/src/main/scala/is/hail/types/BlockMatrixType.scala b/hail/src/main/scala/is/hail/types/BlockMatrixType.scala index 569139ab302..48f4ada6b37 100644 --- a/hail/src/main/scala/is/hail/types/BlockMatrixType.scala +++ b/hail/src/main/scala/is/hail/types/BlockMatrixType.scala @@ -4,6 +4,7 @@ import is.hail.expr.ir._ import is.hail.linalg.BlockMatrix import is.hail.types.virtual._ import is.hail.utils._ + import org.apache.spark.sql.Row object BlockMatrixSparsity { @@ -11,9 +12,11 @@ object BlockMatrixSparsity { val dense: BlockMatrixSparsity = new BlockMatrixSparsity(None: Option[IndexedSeq[(Int, Int)]]) - def apply(definedBlocks: IndexedSeq[(Int, Int)]): BlockMatrixSparsity = BlockMatrixSparsity(Some(definedBlocks)) + def apply(definedBlocks: IndexedSeq[(Int, Int)]): BlockMatrixSparsity = + BlockMatrixSparsity(Some(definedBlocks)) - def constructFromShapeAndFunction(nRows: Int, nCols: Int)(exists: (Int, Int) => Boolean): BlockMatrixSparsity = { + def constructFromShapeAndFunction(nRows: Int, nCols: Int)(exists: (Int, Int) => Boolean) + : BlockMatrixSparsity = { var i = 0 builder.clear() while (i < nRows) { @@ -28,14 +31,25 @@ object BlockMatrixSparsity { BlockMatrixSparsity(Some(builder.result().toFastSeq)) } - def fromLinearBlocks(nCols: Long, nRows: Long, blockSize: Int, definedBlocks: Option[IndexedSeq[Int]]): BlockMatrixSparsity = { + def fromLinearBlocks( + nCols: Long, + nRows: Long, + blockSize: Int, + definedBlocks: Option[IndexedSeq[Int]], + ): BlockMatrixSparsity = { val nColBlocks = BlockMatrixType.numBlocks(nCols, blockSize) definedBlocks.map { blocks => - BlockMatrixSparsity(blocks.map { linearIdx => java.lang.Math.floorDiv(linearIdx, nColBlocks) -> linearIdx % nColBlocks }) + BlockMatrixSparsity(blocks.map { linearIdx => + java.lang.Math.floorDiv(linearIdx, nColBlocks) -> linearIdx % nColBlocks + }) }.getOrElse(dense) } + def transposeCSCSparsity( - nRows: Int, nCols: Int, rowPos: IndexedSeq[Int], rowIdx: IndexedSeq[Int] + nRows: Int, + nCols: Int, + rowPos: IndexedSeq[Int], + rowIdx: IndexedSeq[Int], ): (IndexedSeq[Int], IndexedSeq[Int], IndexedSeq[Int]) = { val newRowPos = Array.ofDim[Int](nRows + 1) val newRowIdx = Array.ofDim[Int](rowIdx.length) @@ -80,7 +94,10 @@ object BlockMatrixSparsity { } def transposeCSCSparsityIR( - nRows: Int, nCols: Int, rowPos: IndexedSeq[Int], rowIdx: IndexedSeq[Int] + nRows: Int, + nCols: Int, + rowPos: IndexedSeq[Int], + rowIdx: IndexedSeq[Int], ): (IR, IR, IR) = { val (newRowPos, newRowIdx, newToOldPos) = 
transposeCSCSparsity(nRows, nCols, rowPos, rowIdx) val t = TArray(TInt32) @@ -88,8 +105,10 @@ object BlockMatrixSparsity { } def filterCSCSparsity( - rowPos: IndexedSeq[Int], rowIdx: IndexedSeq[Int], - rowDeps: IndexedSeq[Int], colDeps: IndexedSeq[Int] + rowPos: IndexedSeq[Int], + rowIdx: IndexedSeq[Int], + rowDeps: IndexedSeq[Int], + colDeps: IndexedSeq[Int], ): (IndexedSeq[Int], IndexedSeq[Int], IndexedSeq[Int]) = { val newRowPos = new IntArrayBuilder() val newRowIdx = new IntArrayBuilder() @@ -123,12 +142,19 @@ object BlockMatrixSparsity { } def groupedCSCSparsity( - rowPos: IndexedSeq[Int], rowIdx: IndexedSeq[Int], - rowDeps: IndexedSeq[IndexedSeq[Int]], colDeps: IndexedSeq[IndexedSeq[Int]] - ): (IndexedSeq[Int], IndexedSeq[Int], IndexedSeq[(IndexedSeq[Int], IndexedSeq[Int], IndexedSeq[Int])]) = { + rowPos: IndexedSeq[Int], + rowIdx: IndexedSeq[Int], + rowDeps: IndexedSeq[IndexedSeq[Int]], + colDeps: IndexedSeq[IndexedSeq[Int]], + ): ( + IndexedSeq[Int], + IndexedSeq[Int], + IndexedSeq[(IndexedSeq[Int], IndexedSeq[Int], IndexedSeq[Int])], + ) = { val newRowPos = new IntArrayBuilder() val newRowIdx = new IntArrayBuilder() - val nestedSparsities = new AnyRefArrayBuilder[(IndexedSeq[Int], IndexedSeq[Int], IndexedSeq[Int])]() + val nestedSparsities = + new AnyRefArrayBuilder[(IndexedSeq[Int], IndexedSeq[Int], IndexedSeq[Int])]() var curOutPos = 0 var j = 0 @@ -152,48 +178,57 @@ object BlockMatrixSparsity { } def groupedCSCSparsityIR( - rowPos: IndexedSeq[Int], rowIdx: IndexedSeq[Int], - rowDeps: IndexedSeq[IndexedSeq[Int]], colDeps: IndexedSeq[IndexedSeq[Int]] + rowPos: IndexedSeq[Int], + rowIdx: IndexedSeq[Int], + rowDeps: IndexedSeq[IndexedSeq[Int]], + colDeps: IndexedSeq[IndexedSeq[Int]], ): (IR, IR, IR) = { - val (newRowPos, newRowIdx, nestedSparsities) = groupedCSCSparsity(rowPos, rowIdx, rowDeps, colDeps) + val (newRowPos, newRowIdx, nestedSparsities) = + groupedCSCSparsity(rowPos, rowIdx, rowDeps, colDeps) val t = TArray(TInt32) - (Literal(t, newRowPos), Literal(t, newRowIdx), Literal(TArray(TTuple(t, t, t)), nestedSparsities.map(Row.fromTuple))) + ( + Literal(t, newRowPos), + Literal(t, newRowIdx), + Literal(TArray(TTuple(t, t, t)), nestedSparsities.map(Row.fromTuple)), + ) } } case class BlockMatrixSparsity(definedBlocks: Option[IndexedSeq[(Int, Int)]]) { lazy val definedBlocksColMajor: Option[IndexedSeq[(Int, Int)]] = definedBlocks.map { blocks => blocks.sortWith { case ((i1, j1), (i2, j2)) => - j1 < j2 || (j1 == j2 && i1 < i2) + j1 < j2 || (j1 == j2 && i1 < i2) } } - def definedBlocksCSC(nCols: Int): Option[(IndexedSeq[Int], IndexedSeq[Int])] = definedBlocksColMajor.map { blocks => - var curColIdx = 0 - var curPos = 0 - val pos = new Array[Int](nCols + 1) - val rowIdx = new IntArrayBuilder() - - pos(0) = 0 - for ((i, j) <- blocks) { - while (curColIdx < j) { + def definedBlocksCSC(nCols: Int): Option[(IndexedSeq[Int], IndexedSeq[Int])] = + definedBlocksColMajor.map { blocks => + var curColIdx = 0 + var curPos = 0 + val pos = new Array[Int](nCols + 1) + val rowIdx = new IntArrayBuilder() + + pos(0) = 0 + for ((i, j) <- blocks) { + while (curColIdx < j) { + pos(curColIdx + 1) = curPos + curColIdx += 1 + } + rowIdx += i + curPos += 1 + } + while (curColIdx < nCols) { pos(curColIdx + 1) = curPos curColIdx += 1 } - rowIdx += i - curPos += 1 - } - while (curColIdx < nCols) { - pos(curColIdx + 1) = curPos - curColIdx += 1 + (pos, rowIdx.result()) } - (pos, rowIdx.result()) - } - def definedBlocksCSCIR(nCols: Int): Option[(IR, IR)] = definedBlocksCSC(nCols).map { case (rowPos, rowIdx) 
=> - val t = TArray(TInt32) - (Literal(t, rowPos), Literal(t, rowIdx)) - } + def definedBlocksCSCIR(nCols: Int): Option[(IR, IR)] = + definedBlocksCSC(nCols).map { case (rowPos, rowIdx) => + val t = TArray(TInt32) + (Literal(t, rowPos), Literal(t, rowIdx)) + } def definedBlocksColMajorIR: Option[IR] = definedBlocksColMajor.map { blocks => ToStream(Literal(TArray(TTuple(TInt32, TInt32)), blocks.map(Row.fromTuple))) @@ -212,6 +247,7 @@ case class BlockMatrixSparsity(definedBlocks: Option[IndexedSeq[(Int, Int)]]) { def isSparse: Boolean = definedBlocks.isDefined lazy val blockSet: Set[(Int, Int)] = definedBlocks.get.toSet def hasBlock(idx: (Int, Int)): Boolean = definedBlocks.isEmpty || blockSet.contains(idx) + def condense(blockOverlaps: => (Array[Array[Int]], Array[Array[Int]])): BlockMatrixSparsity = { definedBlocks.map { _ => val (ro, co) = blockOverlaps @@ -220,6 +256,7 @@ case class BlockMatrixSparsity(definedBlocks: Option[IndexedSeq[(Int, Int)]]) { } }.getOrElse(BlockMatrixSparsity.dense) } + def allBlocksColMajor(nRowBlocks: Int, nColBlocks: Int): IndexedSeq[(Int, Int)] = { definedBlocksColMajor.getOrElse { val foo = Array.fill[(Int, Int)](nRowBlocks * nColBlocks)(null) @@ -238,13 +275,10 @@ case class BlockMatrixSparsity(definedBlocks: Option[IndexedSeq[(Int, Int)]]) { } } - def allBlocksColMajorIR(nRowBlocks: Int, nColBlocks: Int): IR = definedBlocksColMajorIR.getOrElse { - flatMapIR(rangeIR(nColBlocks)) { j => - mapIR(rangeIR(nRowBlocks)) { i => - maketuple(i, j) - } + def allBlocksColMajorIR(nRowBlocks: Int, nColBlocks: Int): IR = + definedBlocksColMajorIR.getOrElse { + flatMapIR(rangeIR(nColBlocks))(j => mapIR(rangeIR(nRowBlocks))(i => maketuple(i, j))) } - } def allBlocksRowMajor(nRowBlocks: Int, nColBlocks: Int): IndexedSeq[(Int, Int)] = { (definedBlocksRowMajor).getOrElse { @@ -264,13 +298,10 @@ case class BlockMatrixSparsity(definedBlocks: Option[IndexedSeq[(Int, Int)]]) { } } - def allBlocksRowMajorIR(nRowBlocks: Int, nColBlocks: Int): IR = definedBlocksRowMajorIR.getOrElse { - flatMapIR(rangeIR(nRowBlocks)) { i => - mapIR(rangeIR(nColBlocks)) { j => - maketuple(i, j) - } + def allBlocksRowMajorIR(nRowBlocks: Int, nColBlocks: Int): IR = + definedBlocksRowMajorIR.getOrElse { + flatMapIR(rangeIR(nRowBlocks))(i => mapIR(rangeIR(nColBlocks))(j => maketuple(i, j))) } - } def transpose: BlockMatrixSparsity = BlockMatrixSparsity(definedBlocks.map(_.map { case (i, j) => (j, i) })) @@ -282,15 +313,14 @@ case class BlockMatrixSparsity(definedBlocks: Option[IndexedSeq[(Int, Int)]]) { } object BlockMatrixType { - def tensorToMatrixShape(shape: IndexedSeq[Long], isRowVector: Boolean): (Long, Long) = { + def tensorToMatrixShape(shape: IndexedSeq[Long], isRowVector: Boolean): (Long, Long) = shape match { case IndexedSeq() => (1, 1) case IndexedSeq(vectorLength) => if (isRowVector) (1, vectorLength) else (vectorLength, 1) case IndexedSeq(numRows, numCols) => (numRows, numCols) } - } - def matrixToTensorShape(nRows: Long, nCols: Long): (IndexedSeq[Long], Boolean) = { + def matrixToTensorShape(nRows: Long, nCols: Long): (IndexedSeq[Long], Boolean) = { (nRows, nCols) match { case (1, 1) => (FastSeq(), false) case (_, 1) => (FastSeq(nRows), false) @@ -310,7 +340,12 @@ object BlockMatrixType { } def fromBlockMatrix(value: BlockMatrix): BlockMatrixType = { - val sparsity = BlockMatrixSparsity.fromLinearBlocks(value.nRows, value.nCols, value.blockSize, value.gp.partitionIndexToBlockIndex) + val sparsity = BlockMatrixSparsity.fromLinearBlocks( + value.nRows, + value.nCols, + value.blockSize, + 
value.gp.partitionIndexToBlockIndex, + ) val (shape, isRowVector) = matrixToTensorShape(value.nRows, value.nCols) BlockMatrixType(TFloat64, shape, isRowVector, value.blockSize, sparsity) } @@ -321,7 +356,7 @@ case class BlockMatrixType( shape: IndexedSeq[Long], isRowVector: Boolean, blockSize: Int, - sparsity: BlockMatrixSparsity + sparsity: BlockMatrixSparsity, ) extends BaseType { require(blockSize >= 0) lazy val (nRows: Long, nCols: Long) = BlockMatrixType.tensorToMatrixShape(shape, isRowVector) @@ -336,10 +371,13 @@ case class BlockMatrixType( def getBlockIdx(i: Long): Int = java.lang.Math.floorDiv(i, blockSize).toInt def isSparse: Boolean = sparsity.isSparse + def nDefinedBlocks: Int = if (isSparse) sparsity.definedBlocks.get.length else nRowBlocks * nColBlocks + def hasBlock(idx: (Int, Int)): Boolean = - if (isSparse) sparsity.hasBlock(idx) else idx._1 >= 0 && idx._1 < nRowBlocks && idx._2 >= 0 && idx._2 < nColBlocks + if (isSparse) sparsity.hasBlock(idx) + else idx._1 >= 0 && idx._1 < nRowBlocks && idx._2 >= 0 && idx._2 < nColBlocks def transpose: BlockMatrixType = { val newShape = shape match { @@ -356,8 +394,8 @@ case class BlockMatrixType( def allBlocksRowMajor: IndexedSeq[(Int, Int)] = sparsity.allBlocksRowMajor(nRowBlocks, nColBlocks) def allBlocksRowMajorIR: IR = sparsity.allBlocksRowMajorIR(nRowBlocks, nColBlocks) - lazy val linearizedDefinedBlocks: Option[IndexedSeq[Int]] = sparsity.definedBlocksColMajor.map { blocks => - blocks.map { case (i, j) => i + j * nRowBlocks } + lazy val linearizedDefinedBlocks: Option[IndexedSeq[Int]] = sparsity.definedBlocksColMajor.map { + blocks => blocks.map { case (i, j) => i + j * nRowBlocks } } def blockShape(i: Int, j: Int): (Long, Long) = { @@ -375,12 +413,18 @@ case class BlockMatrixType( } private[this] def getBlockDependencies(keep: Array[Array[Long]]): Array[Array[Int]] = - keep.map(keeps => Array.range(BlockMatrixType.getBlockIdx(keeps.head, blockSize), BlockMatrixType.getBlockIdx(keeps.last, blockSize) + 1)).toArray + keep.map(keeps => + Array.range( + BlockMatrixType.getBlockIdx(keeps.head, blockSize), + BlockMatrixType.getBlockIdx(keeps.last, blockSize) + 1, + ) + ).toArray def rowBlockDependents(keepRows: Array[Array[Long]]): Array[Array[Int]] = if (keepRows.isEmpty) Array.tabulate(nRowBlocks)(i => Array(i)) else getBlockDependencies(keepRows) + def colBlockDependents(keepCols: Array[Array[Long]]): Array[Array[Int]] = if (keepCols.isEmpty) Array.tabulate(nColBlocks)(i => Array(i)) else @@ -391,12 +435,11 @@ case class BlockMatrixType( val space: String = if (compact) "" else " " - def newline() { + def newline(): Unit = if (!compact) { sb += '\n' sb.append(" " * indent) } - } sb.append(s"BlockMatrix$space{") indent += 4 diff --git a/hail/src/main/scala/is/hail/types/Box.scala b/hail/src/main/scala/is/hail/types/Box.scala index a338bdf1f6f..3c516dd19df 100644 --- a/hail/src/main/scala/is/hail/types/Box.scala +++ b/hail/src/main/scala/is/hail/types/Box.scala @@ -4,8 +4,9 @@ import java.util.function._ final case class Box[T]( b: ThreadLocal[Option[T]] = ThreadLocal.withInitial( - new Supplier[Option[T]] { def get = None }), - matchCond: (T, T) => Boolean = { (a: T, b: T) => a == b } + new Supplier[Option[T]] { def get = None } + ), + matchCond: (T, T) => Boolean = { (a: T, b: T) => a == b }, ) { def unify(t: T): Boolean = b.get match { case Some(bt) => matchCond(t, bt) @@ -14,9 +15,8 @@ final case class Box[T]( true } - def clear() { + def clear(): Unit = b.set(None) - } def get: T = b.get.get diff --git 
a/hail/src/main/scala/is/hail/types/MapTypes.scala b/hail/src/main/scala/is/hail/types/MapTypes.scala index 6c7ee875e1e..c95a6cf7325 100644 --- a/hail/src/main/scala/is/hail/types/MapTypes.scala +++ b/hail/src/main/scala/is/hail/types/MapTypes.scala @@ -8,16 +8,16 @@ object MapTypes { case TArray(elt) => TArray(f(elt)) case TSet(elt) => TSet(f(elt)) case TDict(kt, vt) => TDict(f(kt), f(vt)) - case t: TStruct => TStruct(t.fields.map { field => (field.name, f(field.typ)) }: _*) + case t: TStruct => TStruct(t.fields.map(field => (field.name, f(field.typ))): _*) case t: TTuple => TTuple(t.types.map(f): _*) case _ => typ } def recur(f: Type => Type)(typ: Type): Type = { - def recurF(t: Type): Type = f(apply { t => recurF(t) }(t)) + def recurF(t: Type): Type = f(apply(t => recurF(t))(t)) recurF(typ) } def foreach(f: Type => Unit)(typ: Type): Unit = - recur{ t => f(t); t }(typ) -} \ No newline at end of file + recur { t => f(t); t }(typ) +} diff --git a/hail/src/main/scala/is/hail/types/MatrixType.scala b/hail/src/main/scala/is/hail/types/MatrixType.scala index 457d9898059..12ff236eafc 100644 --- a/hail/src/main/scala/is/hail/types/MatrixType.scala +++ b/hail/src/main/scala/is/hail/types/MatrixType.scala @@ -5,37 +5,48 @@ import is.hail.expr.ir.{Env, IRParser, LowerMatrixIR} import is.hail.types.physical.{PArray, PStruct} import is.hail.types.virtual._ import is.hail.utils._ + import org.apache.spark.sql.Row import org.json4s.CustomSerializer import org.json4s.JsonAST.{JArray, JObject, JString} - -class MatrixTypeSerializer extends CustomSerializer[MatrixType](format => ( - { case JString(s) => IRParser.parseMatrixType(s) }, - { case mt: MatrixType => JString(mt.toString) })) +class MatrixTypeSerializer extends CustomSerializer[MatrixType](format => + ( + { case JString(s) => IRParser.parseMatrixType(s) }, + { case mt: MatrixType => JString(mt.toString) }, + ) + ) object MatrixType { val entriesIdentifier = "the entries! 
[877f12a8827e18f61222c6c8c5fb04a8]" def getRowType(rvRowType: PStruct): PStruct = rvRowType.dropFields(Set(entriesIdentifier)) - def getEntryArrayType(rvRowType: PStruct): PArray = rvRowType.field(entriesIdentifier).typ.asInstanceOf[PArray] - def getSplitEntriesType(rvRowType: PStruct): PStruct = rvRowType.selectFields(Array(entriesIdentifier)) - def getEntryType(rvRowType: PStruct): PStruct = getEntryArrayType(rvRowType).elementType.asInstanceOf[PStruct] + + def getEntryArrayType(rvRowType: PStruct): PArray = + rvRowType.field(entriesIdentifier).typ.asInstanceOf[PArray] + + def getSplitEntriesType(rvRowType: PStruct): PStruct = + rvRowType.selectFields(Array(entriesIdentifier)) + + def getEntryType(rvRowType: PStruct): PStruct = + getEntryArrayType(rvRowType).elementType.asInstanceOf[PStruct] + def getEntriesIndex(rvRowType: PStruct): Int = rvRowType.fieldIdx(entriesIdentifier) def fromTableType( typ: TableType, colsFieldName: String, entriesFieldName: String, - colKey: IndexedSeq[String] + colKey: IndexedSeq[String], ): MatrixType = { val (colType, colsFieldIdx) = typ.globalType.field(colsFieldName) match { - case Field(_, TArray(t@TStruct(_)), idx) => (t, idx) + case Field(_, TArray(t @ TStruct(_)), idx) => (t, idx) case Field(_, t, _) => fatal(s"expected cols field to be an array of structs, found $t") } val newRowType = typ.rowType.deleteKey(entriesFieldName) - val entryType = typ.rowType.field(entriesFieldName).typ.asInstanceOf[TArray].elementType.asInstanceOf[TStruct] + val entryType = + typ.rowType.field(entriesFieldName).typ.asInstanceOf[TArray].elementType.asInstanceOf[TStruct] MatrixType( typ.globalType.deleteKey(colsFieldName, colsFieldIdx), @@ -43,7 +54,8 @@ object MatrixType { colType, typ.key, newRowType, - entryType) + entryType, + ) } } @@ -53,33 +65,46 @@ case class MatrixType( colType: TStruct, rowKey: IndexedSeq[String], rowType: TStruct, - entryType: TStruct + entryType: TStruct, ) extends BaseType { - assert({ - val colFields = colType.fieldNames.toSet - colKey.forall(colFields.contains) - }, s"$colKey: $colType") + assert( + { + val colFields = colType.fieldNames.toSet + colKey.forall(colFields.contains) + }, + s"$colKey: $colType", + ) lazy val entriesRVType: TStruct = TStruct( - MatrixType.entriesIdentifier -> TArray(entryType)) + MatrixType.entriesIdentifier -> TArray(entryType) + ) - assert({ - val rowFields = rowType.fieldNames.toSet - rowKey.forall(rowFields.contains) - }, s"$rowKey: $rowType") + assert( + { + val rowFields = rowType.fieldNames.toSet + rowKey.forall(rowFields.contains) + }, + s"$rowKey: $rowType", + ) lazy val (rowKeyStruct, _) = rowType.select(rowKey) def extractRowKey: Row => Row = rowType.select(rowKey)._2 lazy val rowKeyFieldIdx: Array[Int] = rowKey.toArray.map(rowType.fieldIdx) lazy val (rowValueStruct, _) = rowType.filterSet(rowKey.toSet, include = false) - def extractRowValue: Annotation => Annotation = rowType.filterSet(rowKey.toSet, include = false)._2 + + def extractRowValue: Annotation => Annotation = + rowType.filterSet(rowKey.toSet, include = false)._2 + lazy val rowValueFieldIdx: Array[Int] = rowValueStruct.fieldNames.map(rowType.fieldIdx) lazy val (colKeyStruct, _) = colType.select(colKey) def extractColKey: Row => Row = colType.select(colKey)._2 lazy val colKeyFieldIdx: Array[Int] = colKey.toArray.map(colType.fieldIdx) lazy val (colValueStruct, _) = colType.filterSet(colKey.toSet, include = false) - def extractColValue: Annotation => Annotation = colType.filterSet(colKey.toSet, include = false)._2 + + def extractColValue: 
Annotation => Annotation = + colType.filterSet(colKey.toSet, include = false)._2 + lazy val colValueFieldIdx: Array[Int] = colValueStruct.fieldNames.map(colType.fieldIdx) lazy val colsTableType: TableType = @@ -89,22 +114,27 @@ case class MatrixType( TableType(rowType, rowKey, globalType) lazy val entriesTableType: TableType = { - val resultStruct = TStruct((rowType.fields ++ colType.fields ++ entryType.fields).map(f => f.name -> f.typ): _*) + val resultStruct = + TStruct((rowType.fields ++ colType.fields ++ entryType.fields).map(f => f.name -> f.typ): _*) TableType(resultStruct, rowKey ++ colKey, globalType) } - lazy val canonicalTableType: TableType = toTableType(LowerMatrixIR.entriesFieldName, LowerMatrixIR.colsFieldName) + lazy val canonicalTableType: TableType = + toTableType(LowerMatrixIR.entriesFieldName, LowerMatrixIR.colsFieldName) def toTableType(entriesFieldName: String, colsFieldName: String): TableType = TableType( rowType = rowType.appendKey(entriesFieldName, TArray(entryType)), key = rowKey, - globalType = globalType.appendKey(colsFieldName, TArray(colType))) + globalType = globalType.appendKey(colsFieldName, TArray(colType)), + ) def isCompatibleWith(tt: TableType): Boolean = { val globalType2 = tt.globalType.deleteKey(LowerMatrixIR.colsFieldName) - val colType2 = tt.globalType.field(LowerMatrixIR.colsFieldName).typ.asInstanceOf[TArray].elementType + val colType2 = + tt.globalType.field(LowerMatrixIR.colsFieldName).typ.asInstanceOf[TArray].elementType val rowType2 = tt.rowType.deleteKey(LowerMatrixIR.entriesFieldName) - val entryType2 = tt.rowType.field(LowerMatrixIR.entriesFieldName).typ.asInstanceOf[TArray].elementType + val entryType2 = + tt.rowType.field(LowerMatrixIR.entriesFieldName).typ.asInstanceOf[TArray].elementType globalType == globalType2 && colType == colType2 && rowType == rowType2 && entryType == entryType2 && rowKey == tt.key } @@ -115,19 +145,19 @@ case class MatrixType( "global" -> globalType, "va" -> rowType, "sa" -> colType, - "g" -> entryType) + "g" -> entryType, + ) - def pretty(sb: StringBuilder, indent0: Int = 0, compact: Boolean = false) { + def pretty(sb: StringBuilder, indent0: Int = 0, compact: Boolean = false): Unit = { var indent = indent0 val space: String = if (compact) "" else " " - def newline() { + def newline(): Unit = if (!compact) { sb += '\n' sb.append(" " * indent) } - } sb.append(s"Matrix$space{") indent += 4 @@ -169,34 +199,40 @@ case class MatrixType( } @transient lazy val globalEnv: Env[Type] = Env.empty[Type] - .bind("global" -> globalType) + .bind(globalBindings: _*) + + def globalBindings: IndexedSeq[(String, Type)] = + FastSeq("global" -> globalType) @transient lazy val rowEnv: Env[Type] = Env.empty[Type] - .bind("global" -> globalType) - .bind("va" -> rowType) + .bind(rowBindings: _*) + + def rowBindings: IndexedSeq[(String, Type)] = + FastSeq("global" -> globalType, "va" -> rowType) @transient lazy val colEnv: Env[Type] = Env.empty[Type] - .bind("global" -> globalType) - .bind("sa" -> colType) + .bind(colBindings: _*) + + def colBindings: IndexedSeq[(String, Type)] = + FastSeq("global" -> globalType, "sa" -> colType) @transient lazy val entryEnv: Env[Type] = Env.empty[Type] - .bind("global" -> globalType) - .bind("sa" -> colType) - .bind("va" -> rowType) - .bind("g" -> entryType) + .bind(entryBindings: _*) - def requireRowKeyVariant() { + def entryBindings: IndexedSeq[(String, Type)] = + FastSeq("global" -> globalType, "sa" -> colType, "va" -> rowType, "g" -> entryType) + + def requireRowKeyVariant(): Unit = { val 
rowKeyTypes = rowKeyStruct.types rowKey.zip(rowKeyTypes) match { case Seq(("locus", TLocus(_)), ("alleles", TArray(TString))) => } } - def requireColKeyString() { + def requireColKeyString(): Unit = colKeyStruct.types match { case Array(TString) => } - } def referenceGenomeName: String = { val firstKeyField = rowKeyStruct.types(0) @@ -210,7 +246,7 @@ case class MatrixType( "col_type" -> JString(colType.toString), "col_key" -> JArray(colKey.toList.map(JString(_))), "entry_type" -> JString(entryType.toString), - "global_type" -> JString(globalType.toString) + "global_type" -> JString(globalType.toString), ) } } diff --git a/hail/src/main/scala/is/hail/types/TableType.scala b/hail/src/main/scala/is/hail/types/TableType.scala index 2defa934819..2007069ab03 100644 --- a/hail/src/main/scala/is/hail/types/TableType.scala +++ b/hail/src/main/scala/is/hail/types/TableType.scala @@ -1,43 +1,50 @@ package is.hail.types import is.hail.expr.ir._ +import is.hail.rvd.RVDType import is.hail.types.physical.{PStruct, PType} import is.hail.types.virtual.{TStruct, Type} -import is.hail.rvd.RVDType import is.hail.utils._ import org.json4s._ -import org.json4s.CustomSerializer import org.json4s.JsonAST.JString -class TableTypeSerializer extends CustomSerializer[TableType](format => ( - { case JString(s) => IRParser.parseTableType(s) }, - { case tt: TableType => JString(tt.toString) })) +class TableTypeSerializer extends CustomSerializer[TableType](format => + ( + { case JString(s) => IRParser.parseTableType(s) }, + { case tt: TableType => JString(tt.toString) }, + ) + ) object TableType { - def keyType(ts: TStruct, key: IndexedSeq[String]): TStruct = ts.typeAfterSelect(key.map(ts.fieldIdx)) - def valueType(ts: TStruct, key: IndexedSeq[String]): TStruct = ts.filterSet(key.toSet, include = false)._1 + def keyType(ts: TStruct, key: IndexedSeq[String]): TStruct = + ts.typeAfterSelect(key.map(ts.fieldIdx)) + + def valueType(ts: TStruct, key: IndexedSeq[String]): TStruct = + ts.filterSet(key.toSet, include = false)._1 } -case class TableType(rowType: TStruct, key: IndexedSeq[String], globalType: TStruct) extends BaseType { +case class TableType(rowType: TStruct, key: IndexedSeq[String], globalType: TStruct) + extends BaseType { lazy val canonicalRowPType = PType.canonical(rowType).setRequired(true).asInstanceOf[PStruct] lazy val canonicalRVDType = RVDType(canonicalRowPType, key) - key.foreach {k => + key.foreach { k => if (!rowType.hasField(k)) throw new RuntimeException(s"key field $k not in row type: $rowType") } - @transient lazy val globalEnv: Env[Type] = Env.empty[Type] - .bind("global" -> globalType) + @transient lazy val globalEnv: Env[Type] = + Env.empty[Type].bind(globalBindings: _*) + + def globalBindings: IndexedSeq[(String, Type)] = + FastSeq("global" -> globalType) - @transient lazy val rowEnv: Env[Type] = Env.empty[Type] - .bind("global" -> globalType) - .bind("row" -> rowType) + @transient lazy val rowEnv: Env[Type] = + Env.empty[Type].bind(rowBindings: _*) - @transient lazy val refMap: Map[String, Type] = Map( - "global" -> globalType, - "row" -> rowType) + def rowBindings: IndexedSeq[(String, Type)] = + FastSeq("global" -> globalType, "row" -> rowType) def isCanonical: Boolean = rowType.isCanonical && globalType.isCanonical @@ -46,17 +53,16 @@ case class TableType(rowType: TStruct, key: IndexedSeq[String], globalType: TStr lazy val valueType: TStruct = TableType.valueType(rowType, key) def valueFieldIdx: Array[Int] = canonicalRVDType.valueFieldIdx - def pretty(sb: StringBuilder, indent0: Int = 0, 
compact: Boolean = false) { + def pretty(sb: StringBuilder, indent0: Int = 0, compact: Boolean = false): Unit = { var indent = indent0 val space: String = if (compact) "" else " " - def newline() { + def newline(): Unit = if (!compact) { sb += '\n' sb.append(" " * indent) } - } sb.append(s"Table$space{") indent += 4 @@ -81,11 +87,10 @@ case class TableType(rowType: TStruct, key: IndexedSeq[String], globalType: TStr sb += '}' } - def toJSON: JObject = { + def toJSON: JObject = JObject( "global_type" -> JString(globalType.toString), "row_type" -> JString(rowType.toString), - "row_key" -> JArray(key.map(f => JString(f)).toList) + "row_key" -> JArray(key.map(f => JString(f)).toList), ) - } } diff --git a/hail/src/main/scala/is/hail/types/TypeWithRequiredness.scala b/hail/src/main/scala/is/hail/types/TypeWithRequiredness.scala index 1bfa5aa0a44..b8c6047f06d 100644 --- a/hail/src/main/scala/is/hail/types/TypeWithRequiredness.scala +++ b/hail/src/main/scala/is/hail/types/TypeWithRequiredness.scala @@ -2,23 +2,25 @@ package is.hail.types import is.hail.annotations.{Annotation, NDArray} import is.hail.backend.ExecuteContext -import is.hail.expr.ir.lowering.TableStage import is.hail.expr.ir.{ComputeUsesAndDefs, Env, IR} +import is.hail.expr.ir.lowering.TableStage import is.hail.types.physical._ import is.hail.types.physical.stypes.EmitType import is.hail.types.physical.stypes.concrete.SIndexablePointer import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SInterval, SNDArray, SStream} import is.hail.types.virtual._ -import is.hail.utils.{FastSeq, Interval, toMapFast} +import is.hail.utils.{toMapFast, FastSeq, Interval} + import org.apache.spark.sql.Row object BaseTypeWithRequiredness { def apply(typ: BaseType): BaseTypeWithRequiredness = typ match { case t: Type => TypeWithRequiredness(t) case t: TableType => RTable( - t.rowType.fields.map(f => f.name -> TypeWithRequiredness(f.typ)), - t.globalType.fields.map(f => f.name -> TypeWithRequiredness(f.typ)), - t.key) + t.rowType.fields.map(f => f.name -> TypeWithRequiredness(f.typ)), + t.globalType.fields.map(f => f.name -> TypeWithRequiredness(f.typ)), + t.key, + ) case t: BlockMatrixType => RBlockMatrix(TypeWithRequiredness(t.elementType)) } @@ -93,32 +95,35 @@ sealed abstract class BaseTypeWithRequiredness { } protected[this] def _maximizeChildren(): Unit = children.foreach(_.maximize()) + protected[this] def _unionChildren(newChildren: IndexedSeq[BaseTypeWithRequiredness]): Unit = { if (children.length != newChildren.length) { throw new AssertionError( - s"children lengths differed ${children.length} ${newChildren.length}. ${children} ${newChildren} ${this}") + s"children lengths differed ${children.length} ${newChildren.length}. $children $newChildren ${this}" + ) } // foreach on zipped seqs is very slow as the implementation // doesn't know that the seqs are the same length. 
- for (i <- children.indices) { + for (i <- children.indices) children(i).unionFrom(newChildren(i)) - } } protected[this] def _unionWithIntersection(ts: IndexedSeq[BaseTypeWithRequiredness]): Unit = { var i = 0 - while(i < children.length) { + while (i < children.length) { children(i).unionWithIntersection(ts.map(_.children(i))) i += 1 } } + def unionWithIntersection(ts: IndexedSeq[BaseTypeWithRequiredness]): Unit = { union(ts.exists(_.required)) _unionWithIntersection(ts) } - final def union(r: Boolean): Unit = { change |= !r && required } + final def union(r: Boolean): Unit = change |= !r && required + final def maximize(): Unit = { change |= required _maximizeChildren() @@ -135,7 +140,7 @@ sealed abstract class BaseTypeWithRequiredness { var hasChanged = change _required &= !change change = false - children.foreach { r => hasChanged |= r.probeChangedAndReset() } + children.foreach(r => hasChanged |= r.probeChangedAndReset()) hasChanged } @@ -170,7 +175,8 @@ object VirtualTypeWithReq { val twr = TypeWithRequiredness(t) twr.unionLiteral(value) VirtualTypeWithReq(t, twr) -} + } + def union(vs: IndexedSeq[VirtualTypeWithReq]): VirtualTypeWithReq = { val t = vs.head.t assert(vs.tail.forall(_.t == t)) @@ -186,12 +192,19 @@ object VirtualTypeWithReq { val r = (vt, rt) match { case (_, t: RPrimitive) => t.copy(empty) case (tt: TTuple, rt: RTuple) => - RTuple(tt.fields.map(fd => RField(fd.name, subsetRT(fd.typ, rt.fieldType(fd.name)), fd.index))) + RTuple(tt.fields.map(fd => + RField(fd.name, subsetRT(fd.typ, rt.fieldType(fd.name)), fd.index) + )) case (ts: TStruct, rt: RStruct) => - RStruct(ts.fields.map(fd => RField(fd.name, subsetRT(fd.typ, rt.field(fd.name)), fd.index))) - case (ti: TInterval, ri: RInterval) => RInterval(subsetRT(ti.pointType, ri.startType), subsetRT(ti.pointType, ri.endType)) - case (td: TDict, ri: RDict) => RDict(subsetRT(td.keyType, ri.keyType), subsetRT(td.valueType, ri.valueType)) - case (tit: TIterable, rit: RIterable) => RIterable(subsetRT(tit.elementType, rit.elementType)) + RStruct(ts.fields.map(fd => + RField(fd.name, subsetRT(fd.typ, rt.field(fd.name)), fd.index) + )) + case (ti: TInterval, ri: RInterval) => + RInterval(subsetRT(ti.pointType, ri.startType), subsetRT(ti.pointType, ri.endType)) + case (td: TDict, ri: RDict) => + RDict(subsetRT(td.keyType, ri.keyType), subsetRT(td.valueType, ri.valueType)) + case (tit: TIterable, rit: RIterable) => + RIterable(subsetRT(tit.elementType, rit.elementType)) case (tnd: TNDArray, rnd: RNDArray) => RNDArray(subsetRT(tnd.elementType, rnd.elementType)) } r.union(rt.required) @@ -203,10 +216,17 @@ object VirtualTypeWithReq { case class VirtualTypeWithReq(t: Type, r: TypeWithRequiredness) { lazy val canonicalPType: PType = r.canonicalPType(t) + lazy val canonicalEmitType: EmitType = { t match { case ts: TStream => - EmitType(SStream(VirtualTypeWithReq(ts.elementType, r.asInstanceOf[RIterable].elementType).canonicalEmitType), r.required) + EmitType( + SStream(VirtualTypeWithReq( + ts.elementType, + r.asInstanceOf[RIterable].elementType, + ).canonicalEmitType), + r.required, + ) case t => val pt = r.canonicalPType(t) EmitType(pt.sType, pt.required) @@ -227,9 +247,8 @@ case class VirtualTypeWithReq(t: Type, r: TypeWithRequiredness) { case _ => false } - override def hashCode(): Int = { + override def hashCode(): Int = canonicalPType.hashCode() + 37 - } } sealed abstract class TypeWithRequiredness extends BaseTypeWithRequiredness { @@ -237,6 +256,7 @@ sealed abstract class TypeWithRequiredness extends BaseTypeWithRequiredness { def 
_unionPType(pType: PType): Unit def _unionEmitType(emitType: EmitType): Unit def _matchesPType(pt: PType): Boolean + def unionLiteral(a: Annotation): Unit = if (a == null) union(false) else _unionLiteral(a) @@ -244,19 +264,26 @@ sealed abstract class TypeWithRequiredness extends BaseTypeWithRequiredness { union(pType.required) _unionPType(pType) } + def fromEmitType(emitType: EmitType): Unit = { union(emitType.required) _unionEmitType(emitType) } + def canonicalPType(t: Type): PType + def canonicalEmitType(t: Type): EmitType = { t match { - case TStream(element) => EmitType(SStream(this.asInstanceOf[RIterable].elementType.canonicalEmitType(element)), required) + case TStream(element) => EmitType( + SStream(this.asInstanceOf[RIterable].elementType.canonicalEmitType(element)), + required, + ) case _ => val pt = canonicalPType(t) EmitType(pt.sType, pt.required) } } + def matchesPType(pt: PType): Boolean = pt.required == required && _matchesPType(pt) def _toString: String override def toString: String = if (required) "+" + _toString else _toString @@ -264,7 +291,10 @@ sealed abstract class TypeWithRequiredness extends BaseTypeWithRequiredness { object RPrimitive { val children: IndexedSeq[TypeWithRequiredness] = FastSeq() - val supportedTypes: Set[Type] = Set(TBoolean, TInt32, TInt64, TFloat32, TFloat64, TBinary, TString, TCall, TVoid, TRNGState) + + val supportedTypes: Set[Type] = + Set(TBoolean, TInt32, TInt64, TFloat32, TFloat64, TBinary, TString, TCall, TVoid, TRNGState) + def typeSupported(t: Type): Boolean = RPrimitive.supportedTypes.contains(t) || t.isInstanceOf[TLocus] } @@ -276,39 +306,49 @@ final case class RPrimitive() extends TypeWithRequiredness { def _matchesPType(pt: PType): Boolean = RPrimitive.typeSupported(pt.virtualType) def _unionPType(pType: PType): Unit = assert(RPrimitive.typeSupported(pType.virtualType)) def _unionEmitType(emitType: EmitType) = assert(RPrimitive.typeSupported(emitType.virtualType)) + def copy(newChildren: IndexedSeq[BaseTypeWithRequiredness]): RPrimitive = { assert(newChildren.isEmpty) RPrimitive() } + def canonicalPType(t: Type): PType = { assert(RPrimitive.typeSupported(t)) PType.canonical(t, required) } + def _toString: String = "RPrimitive" } object RIterable { - def apply(elementType: TypeWithRequiredness): RIterable = new RIterable(elementType, eltRequired = false) + def apply(elementType: TypeWithRequiredness): RIterable = + new RIterable(elementType, eltRequired = false) + def unapply(r: RIterable): Option[TypeWithRequiredness] = Some(r.elementType) } -sealed class RIterable(val elementType: TypeWithRequiredness, eltRequired: Boolean) extends TypeWithRequiredness { +sealed class RIterable(val elementType: TypeWithRequiredness, eltRequired: Boolean) + extends TypeWithRequiredness { val children: IndexedSeq[TypeWithRequiredness] = FastSeq(elementType) - def _unionLiteral(a: Annotation): Unit = { + def _unionLiteral(a: Annotation): Unit = a.asInstanceOf[Iterable[_]].foreach(elt => elementType.unionLiteral(elt)) - } - def _matchesPType(pt: PType): Boolean = elementType.matchesPType(tcoerce[PIterable](pt).elementType) - def _unionPType(pType: PType): Unit = elementType.fromPType(pType.asInstanceOf[PIterable].elementType) - def _unionEmitType(emitType: EmitType): Unit = elementType.fromEmitType(emitType.st.asInstanceOf[SIndexablePointer].elementEmitType) - def _toString: String = s"RIterable[${ elementType.toString }]" + def _matchesPType(pt: PType): Boolean = + elementType.matchesPType(tcoerce[PIterable](pt).elementType) + + def _unionPType(pType: 
PType): Unit = + elementType.fromPType(pType.asInstanceOf[PIterable].elementType) + + def _unionEmitType(emitType: EmitType): Unit = + elementType.fromEmitType(emitType.st.asInstanceOf[SIndexablePointer].elementEmitType) + + def _toString: String = s"RIterable[${elementType.toString}]" - override def _maximizeChildren(): Unit = { + override def _maximizeChildren(): Unit = if (eltRequired) elementType.children.foreach(_.maximize()) else elementType.maximize() - } override def _unionChildren(newChildren: IndexedSeq[BaseTypeWithRequiredness]): Unit = { val IndexedSeq(newEltReq) = newChildren @@ -318,25 +358,27 @@ sealed class RIterable(val elementType: TypeWithRequiredness, eltRequired: Boole override def _unionWithIntersection(ts: IndexedSeq[BaseTypeWithRequiredness]): Unit = { if (eltRequired) { var i = 0 - while(i < elementType.children.length) { - elementType.children(i).unionWithIntersection(ts.map(t => tcoerce[RIterable](t).elementType.children(i))) + while (i < elementType.children.length) { + elementType.children(i).unionWithIntersection(ts.map(t => + tcoerce[RIterable](t).elementType.children(i) + )) i += 1 } } else elementType.unionWithIntersection(ts.map(t => tcoerce[RIterable](t).elementType)) } - def unionElement(newElement: BaseTypeWithRequiredness): Unit = { + def unionElement(newElement: BaseTypeWithRequiredness): Unit = if (eltRequired) - (elementType.children, newElement.children).zipped.foreach { (r1, r2) => r1.unionFrom(r2) } + (elementType.children, newElement.children).zipped.foreach((r1, r2) => r1.unionFrom(r2)) else elementType.unionFrom(newElement) - } def copy(newChildren: IndexedSeq[BaseTypeWithRequiredness]): RIterable = { val IndexedSeq(newElt: TypeWithRequiredness) = newChildren RIterable(newElt) } + def canonicalPType(t: Type): PType = { val elt = elementType.canonicalPType(tcoerce[TIterable](t).elementType) t match { @@ -345,25 +387,35 @@ sealed class RIterable(val elementType: TypeWithRequiredness, eltRequired: Boole } } } + case class RDict(keyType: TypeWithRequiredness, valueType: TypeWithRequiredness) - extends RIterable(RStruct.fromNamesAndTypes(Array("key" -> keyType, "value" -> valueType)), true) { + extends RIterable( + RStruct.fromNamesAndTypes(Array("key" -> keyType, "value" -> valueType)), + true, + ) { override def _unionLiteral(a: Annotation): Unit = - a.asInstanceOf[Map[_,_]].foreach { case (k, v) => + a.asInstanceOf[Map[_, _]].foreach { case (k, v) => keyType.unionLiteral(k) valueType.unionLiteral(v) } + override def copy(newChildren: IndexedSeq[BaseTypeWithRequiredness]): RDict = { val IndexedSeq(newElt: RStruct) = newChildren RDict(newElt.field("key"), newElt.field("value")) } + override def canonicalPType(t: Type): PType = PCanonicalDict( keyType.canonicalPType(tcoerce[TDict](t).keyType), valueType.canonicalPType(tcoerce[TDict](t).valueType), - required = required) - override def _toString: String = s"RDict[${ keyType.toString }, ${ valueType.toString }]" + required = required, + ) + + override def _toString: String = s"RDict[${keyType.toString}, ${valueType.toString}]" } -case class RNDArray(override val elementType: TypeWithRequiredness) extends RIterable(elementType, true) { + +case class RNDArray(override val elementType: TypeWithRequiredness) + extends RIterable(elementType, true) { override def _unionLiteral(a: Annotation): Unit = { val data = a.asInstanceOf[NDArray].getRowMajorElements() data.foreach { elt => @@ -371,38 +423,53 @@ case class RNDArray(override val elementType: TypeWithRequiredness) extends RIte 
elementType.unionLiteral(elt) } } - override def _matchesPType(pt: PType): Boolean = elementType.matchesPType(tcoerce[PNDArray](pt).elementType) - override def _unionPType(pType: PType): Unit = elementType.fromPType(pType.asInstanceOf[PNDArray].elementType) - override def _unionEmitType(emitType: EmitType): Unit = elementType.fromEmitType(emitType.st.asInstanceOf[SNDArray].elementEmitType) + + override def _matchesPType(pt: PType): Boolean = + elementType.matchesPType(tcoerce[PNDArray](pt).elementType) + + override def _unionPType(pType: PType): Unit = + elementType.fromPType(pType.asInstanceOf[PNDArray].elementType) + + override def _unionEmitType(emitType: EmitType): Unit = + elementType.fromEmitType(emitType.st.asInstanceOf[SNDArray].elementEmitType) + override def copy(newChildren: IndexedSeq[BaseTypeWithRequiredness]): RNDArray = { val IndexedSeq(newElt: TypeWithRequiredness) = newChildren RNDArray(newElt) } + override def canonicalPType(t: Type): PType = { val tnd = tcoerce[TNDArray](t) PCanonicalNDArray(elementType.canonicalPType(tnd.elementType), tnd.nDims, required = required) } - override def _toString: String = s"RNDArray[${ elementType.toString }]" + + override def _toString: String = s"RNDArray[${elementType.toString}]" } -case class RInterval(startType: TypeWithRequiredness, endType: TypeWithRequiredness) extends TypeWithRequiredness { +case class RInterval(startType: TypeWithRequiredness, endType: TypeWithRequiredness) + extends TypeWithRequiredness { val children: IndexedSeq[TypeWithRequiredness] = FastSeq(startType, endType) + def _unionLiteral(a: Annotation): Unit = { startType.unionLiteral(a.asInstanceOf[Interval].start) endType.unionLiteral(a.asInstanceOf[Interval].end) } + def _matchesPType(pt: PType): Boolean = startType.matchesPType(tcoerce[PInterval](pt).pointType) && endType.matchesPType(tcoerce[PInterval](pt).pointType) + def _unionPType(pType: PType): Unit = { startType.fromPType(pType.asInstanceOf[PInterval].pointType) endType.fromPType(pType.asInstanceOf[PInterval].pointType) } + def _unionEmitType(emitType: EmitType): Unit = { val sInterval = emitType.st.asInstanceOf[SInterval] startType.fromEmitType(sInterval.pointEmitType) endType.fromEmitType(sInterval.pointEmitType) } + def copy(newChildren: IndexedSeq[BaseTypeWithRequiredness]): RInterval = { val IndexedSeq(newStart: TypeWithRequiredness, newEnd: TypeWithRequiredness) = newChildren RInterval(newStart, newEnd) @@ -414,7 +481,8 @@ case class RInterval(startType: TypeWithRequiredness, endType: TypeWithRequiredn unified.unionFrom(endType) PCanonicalInterval(unified.canonicalPType(pointType), required = required) } - def _toString: String = s"RInterval[${ startType.toString }, ${ endType.toString }]" + + def _toString: String = s"RInterval[${startType.toString}, ${endType.toString}]" } case class RField(name: String, typ: TypeWithRequiredness, index: Int) @@ -423,30 +491,39 @@ sealed abstract class RBaseStruct extends TypeWithRequiredness { def fields: IndexedSeq[RField] def size: Int = fields.length val children: IndexedSeq[TypeWithRequiredness] = fields.map(_.typ) + def _unionLiteral(a: Annotation): Unit = - (children, a.asInstanceOf[Row].toSeq).zipped.foreach { (r, f) => r.unionLiteral(f) } + (children, a.asInstanceOf[Row].toSeq).zipped.foreach((r, f) => r.unionLiteral(f)) + def _matchesPType(pt: PType): Boolean = tcoerce[PBaseStruct](pt).fields.forall(f => children(f.index).matchesPType(f.typ)) - def _unionPType(pType: PType): Unit = { + + def _unionPType(pType: PType): Unit = 
pType.asInstanceOf[PBaseStruct].fields.foreach(f => children(f.index).fromPType(f.typ)) - } - def _unionEmitType(emitType: EmitType): Unit = { - emitType.st.asInstanceOf[SBaseStruct].fieldEmitTypes.zipWithIndex.foreach{ case(et, idx) => children(idx).fromEmitType(et) } - } + + def _unionEmitType(emitType: EmitType): Unit = + emitType.st.asInstanceOf[SBaseStruct].fieldEmitTypes.zipWithIndex.foreach { case (et, idx) => + children(idx).fromEmitType(et) + } def unionFields(other: RStruct): Unit = { assert(fields.length == other.fields.length) - (fields, other.fields).zipped.foreach { (fd1, fd2) => fd1.typ.unionFrom(fd2.typ) } + (fields, other.fields).zipped.foreach((fd1, fd2) => fd1.typ.unionFrom(fd2.typ)) } def canonicalPType(t: Type): PType = t match { case ts: TStruct => - PCanonicalStruct(required = required, - fields.map(f => f.name -> f.typ.canonicalPType(ts.fieldType(f.name))): _*) + PCanonicalStruct( + required = required, + fields.map(f => f.name -> f.typ.canonicalPType(ts.fieldType(f.name))): _* + ) case ts: TTuple => - PCanonicalTuple((fields, ts._types).zipped.map { case(fr, ft) => - PTupleField(ft.index, fr.typ.canonicalPType(ft.typ)) - }, required = required) + PCanonicalTuple( + (fields, ts._types).zipped.map { case (fr, ft) => + PTupleField(ft.index, fr.typ.canonicalPType(ft.typ)) + }, + required = required, + ) } } @@ -460,13 +537,19 @@ case class RStruct(fields: IndexedSeq[RField]) extends RBaseStruct { def field(name: String): TypeWithRequiredness = fieldType(name) def fieldOption(name: String): Option[TypeWithRequiredness] = fieldType.get(name) def hasField(name: String): Boolean = fieldType.contains(name) + def copy(newChildren: IndexedSeq[BaseTypeWithRequiredness]): RStruct = { assert(newChildren.length == fields.length) - RStruct.fromNamesAndTypes(Array.tabulate(fields.length)(i => fields(i).name -> tcoerce[TypeWithRequiredness](newChildren(i)))) + RStruct.fromNamesAndTypes(Array.tabulate(fields.length)(i => + fields(i).name -> tcoerce[TypeWithRequiredness](newChildren(i)) + )) } + def select(newFields: Array[String]): RStruct = RStruct(Array.tabulate(newFields.length)(i => RField(newFields(i), field(newFields(i)), i))) - def _toString: String = s"RStruct[${ fields.map(f => s"${ f.name }: ${ f.typ.toString }").mkString(",") }]" + + def _toString: String = + s"RStruct[${fields.map(f => s"${f.name}: ${f.typ.toString}").mkString(",")}]" } object RTuple { @@ -477,11 +560,16 @@ object RTuple { case class RTuple(fields: IndexedSeq[RField]) extends RBaseStruct { val fieldType: collection.Map[String, TypeWithRequiredness] = toMapFast(fields)(_.name, _.typ) def field(idx: Int): TypeWithRequiredness = fieldType(idx.toString) + def copy(newChildren: IndexedSeq[BaseTypeWithRequiredness]): RTuple = { assert(newChildren.length == fields.length) - RTuple((fields, newChildren).zipped.map { (f, c) => RField(f.name, tcoerce[TypeWithRequiredness](c), f.index) }) + RTuple((fields, newChildren).zipped.map { (f, c) => + RField(f.name, tcoerce[TypeWithRequiredness](c), f.index) + }) } - def _toString: String = s"RTuple[${ fields.map(f => s"${ f.index }: ${ f.typ.toString }").mkString(",") }]" + + def _toString: String = + s"RTuple[${fields.map(f => s"${f.index}: ${f.typ.toString}").mkString(",")}]" } case class RUnion(cases: IndexedSeq[(String, TypeWithRequiredness)]) extends TypeWithRequiredness { @@ -490,18 +578,27 @@ case class RUnion(cases: IndexedSeq[(String, TypeWithRequiredness)]) extends Typ def _matchesPType(pt: PType): Boolean = ??? def _unionPType(pType: PType): Unit = ??? 
def _unionEmitType(emitType: EmitType): Unit = ??? + def copy(newChildren: IndexedSeq[BaseTypeWithRequiredness]): RUnion = { assert(newChildren.length == cases.length) - RUnion(Array.tabulate(cases.length)(i => cases(i)._1 -> tcoerce[TypeWithRequiredness](newChildren(i)))) + RUnion(Array.tabulate(cases.length)(i => + cases(i)._1 -> tcoerce[TypeWithRequiredness](newChildren(i)) + )) } + def canonicalPType(t: Type): PType = ??? - def _toString: String = s"RStruct[${ cases.map { case (n, t) => s"${ n }: ${ t.toString }" }.mkString(",") }]" + + def _toString: String = + s"RStruct[${cases.map { case (n, t) => s"$n: ${t.toString}" }.mkString(",")}]" } object RTable { - def apply(rowStruct: RStruct, globStruct: RStruct, key: IndexedSeq[String]): RTable = { - RTable(rowStruct.fields.map(f => f.name -> f.typ), globStruct.fields.map(f => f.name -> f.typ), key) - } + def apply(rowStruct: RStruct, globStruct: RStruct, key: IndexedSeq[String]): RTable = + RTable( + rowStruct.fields.map(f => f.name -> f.typ), + globStruct.fields.map(f => f.name -> f.typ), + key, + ) def fromTableStage(ctx: ExecuteContext, s: TableStage): RTable = { def virtualTypeWithReq(ir: IR, inputs: Env[PType]): VirtualTypeWithReq = { @@ -525,21 +622,30 @@ object RTable { }) val ctxReq = - VirtualTypeWithReq(TIterable.elementType(s.contexts.typ), - virtualTypeWithReq(s.contexts, letBindingReq).r.asInstanceOf[RIterable].elementType + VirtualTypeWithReq( + TIterable.elementType(s.contexts.typ), + virtualTypeWithReq(s.contexts, letBindingReq).r.asInstanceOf[RIterable].elementType, ) val globalRType = virtualTypeWithReq(s.globals, letBindingReq).r.asInstanceOf[RStruct] val rowRType = - virtualTypeWithReq(s.partitionIR, broadcastValBindings.bind(s.ctxRefName, ctxReq.canonicalPType)) + virtualTypeWithReq( + s.partitionIR, + broadcastValBindings.bind(s.ctxRefName, ctxReq.canonicalPType), + ) .r.asInstanceOf[RIterable].elementType.asInstanceOf[RStruct] RTable(rowRType, globalRType, s.kType.fieldNames) } } -case class RTable(rowFields: IndexedSeq[(String, TypeWithRequiredness)], globalFields: IndexedSeq[(String, TypeWithRequiredness)], key: Seq[String]) extends BaseTypeWithRequiredness { + +case class RTable( + rowFields: IndexedSeq[(String, TypeWithRequiredness)], + globalFields: IndexedSeq[(String, TypeWithRequiredness)], + key: Seq[String], +) extends BaseTypeWithRequiredness { val rowTypes: IndexedSeq[TypeWithRequiredness] = rowFields.map(_._2) val globalTypes: IndexedSeq[TypeWithRequiredness] = globalFields.map(_._2) @@ -554,27 +660,41 @@ case class RTable(rowFields: IndexedSeq[(String, TypeWithRequiredness)], globalF val rowType: RStruct = RStruct.fromNamesAndTypes(rowFields) val globalType: RStruct = RStruct.fromNamesAndTypes(globalFields) - def unionRows(req: RStruct): Unit = rowFields.foreach { case (n, r) => if (req.hasField(n)) r.unionFrom(req.field(n)) } + def unionRows(req: RStruct): Unit = rowFields.foreach { case (n, r) => + if (req.hasField(n)) r.unionFrom(req.field(n)) + } + def unionRows(req: RTable): Unit = unionRows(req.rowType) - def unionGlobals(req: RStruct): Unit = globalFields.foreach { case (n, r) => if (req.hasField(n)) r.unionFrom(req.field(n)) } + def unionGlobals(req: RStruct): Unit = globalFields.foreach { case (n, r) => + if (req.hasField(n)) r.unionFrom(req.field(n)) + } + def unionGlobals(req: RTable): Unit = unionGlobals(req.globalType) - def unionKeys(req: RStruct): Unit = key.foreach { n => field(n).unionFrom(req.field(n)) } + def unionKeys(req: RStruct): Unit = key.foreach(n => 
field(n).unionFrom(req.field(n))) + def unionKeys(req: RTable): Unit = { assert(key.length <= req.key.length) - (key, req.key).zipped.foreach { (k, rk) => field(k).unionFrom(req.field(rk)) } + (key, req.key).zipped.foreach((k, rk) => field(k).unionFrom(req.field(rk))) + } + + def unionValues(req: RStruct): Unit = valueFields.foreach { n => + if (req.hasField(n)) field(n).unionFrom(req.field(n)) } - def unionValues(req: RStruct): Unit = valueFields.foreach { n => if (req.hasField(n)) field(n).unionFrom(req.field(n)) } def unionValues(req: RTable): Unit = unionValues(req.rowType) def changeKey(key: IndexedSeq[String]): RTable = RTable(rowFields, globalFields, key) def copy(newChildren: IndexedSeq[BaseTypeWithRequiredness]): RTable = { assert(newChildren.length == rowFields.length + globalFields.length) - val newRowFields = (rowFields, newChildren.take(rowFields.length)).zipped.map { case ((n, _), r: TypeWithRequiredness) => n -> r } - val newGlobalFields = (globalFields, newChildren.drop(rowFields.length)).zipped.map { case ((n, _), r: TypeWithRequiredness) => n -> r } + val newRowFields = (rowFields, newChildren.take(rowFields.length)).zipped.map { + case ((n, _), r: TypeWithRequiredness) => n -> r + } + val newGlobalFields = (globalFields, newChildren.drop(rowFields.length)).zipped.map { + case ((n, _), r: TypeWithRequiredness) => n -> r + } RTable(newRowFields, newGlobalFields, key) } @@ -587,17 +707,19 @@ case class RTable(rowFields: IndexedSeq[(String, TypeWithRequiredness)], globalF } override def toString: String = - s"RTable[\n row:${ rowType.toString }\n global:${ globalType.toString }]" + s"RTable[\n row:${rowType.toString}\n global:${globalType.toString}]" } case class RMatrix(rowType: RStruct, entryType: RStruct, colType: RStruct, globalType: RStruct) { - val entriesRVType: RStruct = RStruct.fromNamesAndTypes(FastSeq(MatrixType.entriesIdentifier -> RIterable(entryType))) + val entriesRVType: RStruct = + RStruct.fromNamesAndTypes(FastSeq(MatrixType.entriesIdentifier -> RIterable(entryType))) } case class RBlockMatrix(elementType: TypeWithRequiredness) extends BaseTypeWithRequiredness { override def children: IndexedSeq[BaseTypeWithRequiredness] = FastSeq(elementType) - override def copy(newChildren: IndexedSeq[BaseTypeWithRequiredness]): BaseTypeWithRequiredness = RBlockMatrix(newChildren(0).asInstanceOf[TypeWithRequiredness]) + override def copy(newChildren: IndexedSeq[BaseTypeWithRequiredness]): BaseTypeWithRequiredness = + RBlockMatrix(newChildren(0).asInstanceOf[TypeWithRequiredness]) - override def toString: String = s"RBlockMatrix(${elementType})" + override def toString: String = s"RBlockMatrix($elementType)" } diff --git a/hail/src/main/scala/is/hail/types/encoded/EArray.scala b/hail/src/main/scala/is/hail/types/encoded/EArray.scala index 4de969ff775..d3562c9cf0d 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EArray.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EArray.scala @@ -5,13 +5,14 @@ import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder import is.hail.io.{InputBuffer, OutputBuffer} import is.hail.types.physical._ +import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.physical.stypes.concrete.{SIndexablePointer, SIndexablePointerValue} import is.hail.types.physical.stypes.interfaces.SIndexableValue -import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.virtual._ import is.hail.utils._ -final case class EArray(val elementType: EType, override val required: Boolean = false) extends EContainer { +final 
case class EArray(val elementType: EType, override val required: Boolean = false) + extends EContainer { def _decodedSType(requestedType: Type): SType = { val elementPType = elementType.decodedPType(requestedType.asInstanceOf[TContainer].elementType) requestedType match { @@ -25,8 +26,11 @@ final case class EArray(val elementType: EType, override val required: Boolean = } } - def buildPrefixEncoder(cb: EmitCodeBuilder, value: SIndexableValue, - out: Value[OutputBuffer], prefixLength: Code[Int] + def buildPrefixEncoder( + cb: EmitCodeBuilder, + value: SIndexableValue, + out: Value[OutputBuffer], + prefixLength: Code[Int], ): Unit = { val prefixLen = cb.newLocal[Int]("prefixLen", prefixLength) val i = cb.newLocal[Int]("i", 0) @@ -34,8 +38,8 @@ final case class EArray(val elementType: EType, override val required: Boolean = cb += out.writeInt(prefixLen) value.st match { - case s@SIndexablePointer(_: PCanonicalArray | _: PCanonicalSet | _:PCanonicalDict) - if s.pType.elementType.required == elementType.required => + case s @ SIndexablePointer(_: PCanonicalArray | _: PCanonicalSet | _: PCanonicalDict) + if s.pType.elementType.required == elementType.required => val pArray = s.pType match { case t: PCanonicalArray => t case t: PCanonicalSet => t.arrayRep @@ -45,43 +49,57 @@ final case class EArray(val elementType: EType, override val required: Boolean = val array = value.asInstanceOf[SIndexablePointerValue].a if (!elementType.required) { val nMissingBytes = cb.memoize(pArray.nMissingBytes(prefixLen), "nMissingBytes") - cb.if_(nMissingBytes > 0, { - cb += out.writeBytes(array + pArray.missingBytesOffset, nMissingBytes - 1) - cb += out.writeByte((Region.loadByte(array + pArray.missingBytesOffset - + (nMissingBytes - 1).toL) & EType.lowBitMask(prefixLen)).toB) - }) + cb.if_( + nMissingBytes > 0, { + cb += out.writeBytes(array + pArray.missingBytesOffset, nMissingBytes - 1) + cb += out.writeByte((Region.loadByte(array + pArray.missingBytesOffset + + (nMissingBytes - 1).toL) & EType.lowBitMask(prefixLen)).toB) + }, + ) } case _ => if (elementType.required) { - cb.if_(value.hasMissingValues(cb), cb._fatal("cannot encode indexable with missing element(s) to required EArray!")) + cb.if_( + value.hasMissingValues(cb), + cb._fatal("cannot encode indexable with missing element(s) to required EArray!"), + ) } else { val b = Code.newLocal[Int]("b") val shift = Code.newLocal[Int]("shift") cb.assign(b, 0) cb.assign(shift, 0) - cb.while_(i < prefixLen, { - cb.if_(value.isElementMissing(cb, i), cb.assign(b, b | (const(1) << shift))) - cb.assign(shift, shift + 1) - cb.assign(i, i + 1) - cb.if_(shift.ceq(8), { - cb.assign(shift, 0) - cb += out.writeByte(b.toB) - cb.assign(b, 0) - }) - }) + cb.while_( + i < prefixLen, { + cb.if_(value.isElementMissing(cb, i), cb.assign(b, b | (const(1) << shift))) + cb.assign(shift, shift + 1) + cb.assign(i, i + 1) + cb.if_( + shift.ceq(8), { + cb.assign(shift, 0) + cb += out.writeByte(b.toB) + cb.assign(b, 0) + }, + ) + }, + ) cb.if_(shift > 0, cb += out.writeByte(b.toB)) } } - cb.for_(cb.assign(i, 0), i < prefixLen, cb.assign(i, i + 1), { - value.loadElement(cb, i).consume(cb, { - if (elementType.required) - cb._fatal(s"required array element saw missing value at index ", i.toS, " in encode") - }, { pc => - elementType.buildEncoder(pc.st, cb.emb.ecb) - .apply(cb, pc, out) - }) - }) + cb.for_( + cb.assign(i, 0), + i < prefixLen, + cb.assign(i, i + 1), { + value.loadElement(cb, i).consume( + cb, + if (elementType.required) + cb._fatal(s"required array element saw missing value at 
index ", i.toS, " in encode"), + pc => + elementType.buildEncoder(pc.st, cb.emb.ecb) + .apply(cb, pc, out), + ) + }, + ) } override def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = { @@ -89,7 +107,12 @@ final case class EArray(val elementType: EType, override val required: Boolean = buildPrefixEncoder(cb, ind, out, ind.loadLength()) } - override def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = { + override def _buildDecoder( + cb: EmitCodeBuilder, + t: Type, + region: Value[Region], + in: Value[InputBuffer], + ): SValue = { val st = decodedSType(t).asInstanceOf[SIndexablePointer] val arrayType: PCanonicalArray = st.pType match { @@ -98,7 +121,10 @@ final case class EArray(val elementType: EType, override val required: Boolean = case t: PCanonicalDict => t.arrayRep } - assert(arrayType.elementType.required == elementType.required, s"${arrayType.elementType.required} | ${elementType.required}") + assert( + arrayType.elementType.required == elementType.required, + s"${arrayType.elementType.required} | ${elementType.required}", + ) val len = cb.memoize(in.readInt(), "len") val array = cb.memoize(arrayType.allocate(region, len), "array") @@ -113,23 +139,33 @@ final case class EArray(val elementType: EType, override val required: Boolean = // elements have 0 size, so all elements have the same address // still need to read `len` elements from the input stream, as they may have non-zero size val i = cb.newLocal[Int]("i") - cb.for_(cb.assign(i, 0), i < len, cb.assign(i, i + 1), { - readElemF(cb, region, elemOff, in) - }) + cb.for_(cb.assign(i, 0), i < len, cb.assign(i, i + 1), readElemF(cb, region, elemOff, in)) } else { - cb.for_({}, elemOff < pastLastOff, cb.assign(elemOff, arrayType.nextElementAddress(elemOff)), { - readElemF(cb, region, elemOff, in) - }) + cb.for_( + {}, + elemOff < pastLastOff, + cb.assign(elemOff, arrayType.nextElementAddress(elemOff)), + readElemF(cb, region, elemOff, in), + ) } } else { - cb += in.readBytes(region, array + const(arrayType.missingBytesOffset), arrayType.nMissingBytes(len)) - - cb.if_((len % 64).cne(0), { - // ensure that the last missing block has all missing bits set past the last element - val lastMissingBlockOff = cb.memoize(UnsafeUtils.roundDownAlignment(arrayType.pastLastMissingByteOff(array, len) - 1, 8)) - val lastMissingBlock = cb.memoize(Region.loadLong(lastMissingBlockOff)) - cb += Region.storeLong(lastMissingBlockOff, lastMissingBlock | (const(-1L) << len)) - }) + cb += in.readBytes( + region, + array + const(arrayType.missingBytesOffset), + arrayType.nMissingBytes(len), + ) + + cb.if_( + (len % 64).cne(0), { + // ensure that the last missing block has all missing bits set past the last element + val lastMissingBlockOff = cb.memoize(UnsafeUtils.roundDownAlignment( + arrayType.pastLastMissingByteOff(array, len) - 1, + 8, + )) + val lastMissingBlock = cb.memoize(Region.loadLong(lastMissingBlockOff)) + cb += Region.storeLong(lastMissingBlockOff, lastMissingBlock | (const(-1L) << len)) + }, + ) def unsetRightMostBit(x: Value[Long]): Code[Long] = x & (x - 1) @@ -137,27 +173,29 @@ final case class EArray(val elementType: EType, override val required: Boolean = val presentBits = cb.newLocal[Long]("presentBits", 0L) val mbyteOffset = cb.newLocal[Long]("mbyteOffset", array + arrayType.missingBytesOffset) val blockOff = cb.newLocal[Long]("blockOff", arrayType.firstElementOffset(array, len)) - val pastLastMissingByteOff = 
cb.memoize(arrayType.pastLastMissingByteOff(array, len), "pastLastMissingByteAddr") + val pastLastMissingByteOff = + cb.memoize(arrayType.pastLastMissingByteOff(array, len), "pastLastMissingByteAddr") val inBlockIndexToPresentValue = cb.newLocal[Int]("inBlockIndexToPresentValue", 0) cb.for_( {}, - mbyteOffset < pastLastMissingByteOff, - { + mbyteOffset < pastLastMissingByteOff, { cb.assign(blockOff, arrayType.incrementElementOffset(blockOff, 64)) cb.assign(mbyteOffset, mbyteOffset + 8) - }, - { + }, { cb.assign(presentBits, ~Region.loadLong(mbyteOffset)) cb.while_( presentBits.cne(0L), { cb.assign(inBlockIndexToPresentValue, presentBits.numberOfTrailingZeros) val elemOff = cb.memoize( - arrayType.incrementElementOffset(blockOff, inBlockIndexToPresentValue)) + arrayType.incrementElementOffset(blockOff, inBlockIndexToPresentValue) + ) readElemF(cb, region, elemOff, in) cb.assign(presentBits, unsetRightMostBit(presentBits)) - }) - }) + }, + ) + }, + ) } new SIndexablePointerValue(st, array, len, cb.memoize(arrayType.firstElementOffset(array, len))) @@ -173,15 +211,19 @@ final case class EArray(val elementType: EType, override val required: Boolean = val nMissing = cb.newLocal[Int]("nMissing", UnsafeUtils.packBitsToBytes(len)) val mbytes = cb.newLocal[Long]("mbytes", r.allocate(const(1L), nMissing.toL)) cb += in.readBytes(r, mbytes, nMissing) - cb.for_(cb.assign(i, 0), i < len, cb.assign(i, i + 1), - cb.if_(!Region.loadBit(mbytes, i.toL), skip(cb, r, in))) + cb.for_( + cb.assign(i, 0), + i < len, + cb.assign(i, i + 1), + cb.if_(!Region.loadBit(mbytes, i.toL), skip(cb, r, in)), + ) } } def _asIdent = s"array_of_${elementType.asIdent}" def _toPretty = s"EArray[$elementType]" - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("EArray[") elementType.pretty(sb, indent, compact) sb.append("]") diff --git a/hail/src/main/scala/is/hail/types/encoded/EBaseStruct.scala b/hail/src/main/scala/is/hail/types/encoded/EBaseStruct.scala index 82fe3bd879d..2012175cdd4 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EBaseStruct.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EBaseStruct.scala @@ -1,19 +1,19 @@ package is.hail.types.encoded import is.hail.annotations.{Region, UnsafeUtils} -import is.hail.asm4s._ +import is.hail.asm4s.{Field => _, _} import is.hail.expr.ir.EmitCodeBuilder import is.hail.io.{InputBuffer, OutputBuffer} import is.hail.types.BaseStruct import is.hail.types.physical._ +import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.physical.stypes.concrete._ import is.hail.types.physical.stypes.interfaces.{SBaseStructValue, SLocus, SLocusValue} -import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.virtual._ import is.hail.utils._ final case class EField(name: String, typ: EType, index: Int) { - def pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + def pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = { if (compact) { sb.append(prettyIdentifier(name)) sb.append(":") @@ -26,7 +26,8 @@ final case class EField(name: String, typ: EType, index: Int) { } } -final case class EBaseStruct(fields: IndexedSeq[EField], override val required: Boolean = false) extends EType { +final case class EBaseStruct(fields: IndexedSeq[EField], override val required: Boolean = false) + extends EType { assert(fields.zipWithIndex.forall { case (f, i) => f.index == i }) val types: Array[EType] = 
fields.map(_.typ).toArray @@ -40,13 +41,18 @@ final case class EBaseStruct(fields: IndexedSeq[EField], override val required: def fieldType(name: String): EType = types(fieldIdx(name)) - val (missingIdx: Array[Int], nMissing: Int) = BaseStruct.getMissingIndexAndCount(types.map(_.required)) + val (missingIdx: Array[Int], nMissing: Int) = + BaseStruct.getMissingIndexAndCount(types.map(_.required)) + val nMissingBytes = UnsafeUtils.packBitsToBytes(nMissing) if (!fieldNames.areDistinct()) { val duplicates = fieldNames.duplicates() - fatal(s"cannot create struct with duplicate ${ plural(duplicates.size, "field") }: " + - s"${ fieldNames.map(prettyIdentifier).mkString(", ") }", fieldNames.duplicates()) + fatal( + s"cannot create struct with duplicate ${plural(duplicates.size, "field")}: " + + s"${fieldNames.map(prettyIdentifier).mkString(", ")}", + fieldNames.duplicates(), + ) } def _decodedSType(requestedType: Type): SType = requestedType match { @@ -73,21 +79,25 @@ final case class EBaseStruct(fields: IndexedSeq[EField], override val required: override def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = { val structValue = v.st match { case SIntervalPointer(t: PCanonicalInterval) => new SBaseStructPointerValue( - SBaseStructPointer(t.representation), - v.asInstanceOf[SIntervalPointerValue].a) + SBaseStructPointer(t.representation), + v.asInstanceOf[SIntervalPointerValue].a, + ) case _: SLocus => v.asInstanceOf[SLocusValue].structRepr(cb) case _ => v.asInstanceOf[SBaseStructValue] } // write missing bytes structValue.st match { - case SBaseStructPointer(st) if st.size == size && st.fieldRequired.sameElements(fields.map(_.typ.required)) => + case SBaseStructPointer(st) + if st.size == size && st.fieldRequired.sameElements(fields.map(_.typ.required)) => val missingBytes = UnsafeUtils.packBitsToBytes(st.nMissing) val addr = structValue.asInstanceOf[SBaseStructPointerValue].a if (nMissingBytes > 1) cb += out.writeBytes(addr, missingBytes - 1) if (nMissingBytes > 0) - cb += out.writeByte((Region.loadByte(addr + (missingBytes.toLong - 1)).toI & const(EType.lowBitMask(st.nMissing & 0x7))).toB) + cb += out.writeByte((Region.loadByte(addr + (missingBytes.toLong - 1)).toI & const( + EType.lowBitMask(st.nMissing & 0x7) + )).toB) case _ => var j = 0 @@ -113,35 +123,52 @@ final case class EBaseStruct(fields: IndexedSeq[EField], override val required: // Write fields fields.foreach { ef => - structValue.loadField(cb, ef.name).consume(cb, - { - if (ef.typ.required) - cb._fatal(s"required field ${ ef.name } saw missing value in encode") - }, - { pc => + structValue.loadField(cb, ef.name).consume( + cb, + if (ef.typ.required) + cb._fatal(s"required field ${ef.name} saw missing value in encode"), + pc => ef.typ.buildEncoder(pc.st, cb.emb.ecb) - .apply(cb, pc, out) - }) + .apply(cb, pc, out), + ) } } - override def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = { + override def _buildDecoder( + cb: EmitCodeBuilder, + t: Type, + region: Value[Region], + in: Value[InputBuffer], + ): SValue = { val pt = decodedPType(t) val addr = cb.newLocal[Long]("base_struct_dec_addr", region.allocate(pt.alignment, pt.byteSize)) _buildInplaceDecoder(cb, pt, region, addr, in) pt.loadCheapSCode(cb, addr) } - override def _buildInplaceDecoder(cb: EmitCodeBuilder, pt: PType, region: Value[Region], addr: Value[Long], in: Value[InputBuffer]): Unit = { + override def _buildInplaceDecoder( + cb: EmitCodeBuilder, + pt: PType, + region: Value[Region], + 
addr: Value[Long], + in: Value[InputBuffer], + ): Unit = { val structType: PBaseStruct = pt match { case t: PCanonicalLocus => t.representation case t: PCanonicalInterval => t.representation case t: PCanonicalBaseStruct => t } val mbytes = cb.newLocal[Long]("mbytes", region.allocate(const(1), const(nMissingBytes))) + var midx = 0 + var byteIdx = 0L cb += in.readBytes(region, mbytes, nMissingBytes) + val m = cb.newLocal[Int]("cached_mbyte") fields.foreach { f => + if (midx == 0 && !f.typ.required) { + cb.assign[Int](m, Region.loadByte(mbytes + byteIdx)) + byteIdx += 1 + } if (structType.hasField(f.name)) { val rf = structType.field(f.name) val readElemF = f.typ.buildInplaceDecoder(rf.typ, cb.emb.ecb) @@ -151,20 +178,22 @@ final case class EBaseStruct(fields: IndexedSeq[EField], override val required: if (!rf.typ.required) structType.setFieldPresent(cb, addr, rf.index) } else { - cb.if_(Region.loadBit(mbytes, const(missingIdx(f.index).toLong)), { - structType.setFieldMissing(cb, addr, rf.index) - }, { - structType.setFieldPresent(cb, addr, rf.index) - readElemF(cb, region, rFieldAddr, in) - }) + cb.if_( + (const(1 << midx) & m).cne(0), + structType.setFieldMissing(cb, addr, rf.index), { + structType.setFieldPresent(cb, addr, rf.index) + readElemF(cb, region, rFieldAddr, in) + }, + ) } } else { val skip = f.typ.buildSkip(cb.emb.ecb) if (f.typ.required) skip(cb, region, in) else - cb.if_(!Region.loadBit(mbytes, const(missingIdx(f.index).toLong)), skip(cb, region, in)) + cb.if_((const(1 << midx) & m).ceq(0), skip(cb, region, in)) } + midx = (midx + (!f.typ.required).toInt) & 0x7 } } @@ -183,9 +212,7 @@ final case class EBaseStruct(fields: IndexedSeq[EField], override val required: def _asIdent: String = { val sb = new StringBuilder sb.append("struct_of_") - types.foreachBetween { ty => - sb.append(ty.asIdent) - } { + types.foreachBetween(ty => sb.append(ty.asIdent)) { sb.append("AND") } sb.append("END") @@ -198,7 +225,7 @@ final case class EBaseStruct(fields: IndexedSeq[EField], override val required: sb.result() } - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = { if (compact) { sb.append("EBaseStruct{") fields.foreachBetween(_.pretty(sb, indent, compact))(sb += ',') diff --git a/hail/src/main/scala/is/hail/types/encoded/EBinary.scala b/hail/src/main/scala/is/hail/types/encoded/EBinary.scala index b1e11f4d4bf..9dce4a73744 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EBinary.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EBinary.scala @@ -5,9 +5,9 @@ import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder import is.hail.io.{InputBuffer, OutputBuffer} import is.hail.types.physical._ +import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.physical.stypes.concrete._ import is.hail.types.physical.stypes.interfaces.{SBinary, SBinaryValue, SString} -import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.virtual._ import is.hail.utils._ @@ -31,13 +31,19 @@ class EBinary(override val required: Boolean) extends EType { v.st match { case SBinaryPointer(_) => writeCanonicalBinary(v.asInstanceOf[SBinaryPointerValue]) - case SStringPointer(_) => writeCanonicalBinary(v.asInstanceOf[SStringPointerValue].binaryRepr()) + case SStringPointer(_) => + writeCanonicalBinary(v.asInstanceOf[SStringPointerValue].binaryRepr()) case _: SBinary => writeBytes(v.asInstanceOf[SBinaryValue].loadBytes(cb)) case _: SString => 
writeBytes(v.asString.toBytes(cb).loadBytes(cb)) } } - override def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = { + override def _buildDecoder( + cb: EmitCodeBuilder, + t: Type, + region: Value[Region], + in: Value[InputBuffer], + ): SValue = { val t1 = decodedSType(t) val pt = t1 match { case SStringPointer(t) => t.binaryRepresentation @@ -55,9 +61,8 @@ class EBinary(override val required: Boolean) extends EType { } } - def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = { + def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = cb += in.skipBytes(in.readInt()) - } def _decodedSType(requestedType: Type): SType = requestedType match { case TBinary => SBinaryPointer(PCanonicalBinary(false)) diff --git a/hail/src/main/scala/is/hail/types/encoded/EBlockMatrixNDArray.scala b/hail/src/main/scala/is/hail/types/encoded/EBlockMatrixNDArray.scala index 1adf5e9718d..d16963a07d8 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EBlockMatrixNDArray.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EBlockMatrixNDArray.scala @@ -5,16 +5,21 @@ import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder import is.hail.io.{InputBuffer, OutputBuffer} import is.hail.types.physical._ +import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.physical.stypes.concrete.SNDArrayPointer import is.hail.types.physical.stypes.interfaces.SNDArrayValue -import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.virtual._ import is.hail.utils._ -final case class EBlockMatrixNDArray(elementType: EType, encodeRowMajor: Boolean = false, override val required: Boolean = false) extends EType { +final case class EBlockMatrixNDArray( + elementType: EType, + encodeRowMajor: Boolean = false, + override val required: Boolean = false, +) extends EType { type DecodedPType = PCanonicalNDArray - def setRequired(newRequired: Boolean): EBlockMatrixNDArray = EBlockMatrixNDArray(elementType, newRequired) + def setRequired(newRequired: Boolean): EBlockMatrixNDArray = + EBlockMatrixNDArray(elementType, newRequired) def _decodedSType(requestedType: Type): SType = { val elementPType = elementType.decodedPType(requestedType.asInstanceOf[TNDArray].elementType) @@ -34,21 +39,38 @@ final case class EBlockMatrixNDArray(elementType: EType, encodeRowMajor: Boolean cb += out.writeInt(c.toI) cb += out.writeBoolean(encodeRowMajor) if (encodeRowMajor) { - cb.for_(cb.assign(i, 0L), i < r, cb.assign(i, i + 1L), { - cb.for_(cb.assign(j, 0L), j < c, cb.assign(j, j + 1L), { - writeElemF(cb, ndarray.loadElement(FastSeq(i, j), cb), out) - }) - }) + cb.for_( + cb.assign(i, 0L), + i < r, + cb.assign(i, i + 1L), + cb.for_( + cb.assign(j, 0L), + j < c, + cb.assign(j, j + 1L), + writeElemF(cb, ndarray.loadElement(FastSeq(i, j), cb), out), + ), + ) } else { - cb.for_(cb.assign(j, 0L), j < c, cb.assign(j, j + 1L), { - cb.for_(cb.assign(i, 0L), i < r, cb.assign(i, i + 1L), { - writeElemF(cb, ndarray.loadElement(FastSeq(i, j), cb), out) - }) - }) + cb.for_( + cb.assign(j, 0L), + j < c, + cb.assign(j, j + 1L), + cb.for_( + cb.assign(i, 0L), + i < r, + cb.assign(i, i + 1L), + writeElemF(cb, ndarray.loadElement(FastSeq(i, j), cb), out), + ), + ) } } - override def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = { + override def _buildDecoder( + cb: EmitCodeBuilder, + t: Type, + region: Value[Region], + in: Value[InputBuffer], + ): SValue = { val st = 
decodedSType(t).asInstanceOf[SNDArrayPointer] val pt = st.pType val readElemF = elementType.buildInplaceDecoder(pt.elementType, cb.emb.ecb) @@ -57,19 +79,31 @@ final case class EBlockMatrixNDArray(elementType: EType, encodeRowMajor: Boolean val nCols = cb.newLocal[Long]("cols", in.readInt().toL) val transpose = cb.newLocal[Boolean]("transpose", in.readBoolean()) - val stride0 = cb.newLocal[Long]("stride0", transpose.mux(nCols.toL * pt.elementType.byteSize, pt.elementType.byteSize)) - val stride1 = cb.newLocal[Long]("stride1", transpose.mux(pt.elementType.byteSize, nRows * pt.elementType.byteSize)) + val stride0 = cb.newLocal[Long]( + "stride0", + transpose.mux(nCols.toL * pt.elementType.byteSize, pt.elementType.byteSize), + ) + val stride1 = cb.newLocal[Long]( + "stride1", + transpose.mux(pt.elementType.byteSize, nRows * pt.elementType.byteSize), + ) val n = cb.newLocal[Int]("length", nRows.toI * nCols.toI) - val (tFirstElementAddress, tFinisher) = pt.constructDataFunction(IndexedSeq(nRows, nCols), IndexedSeq(stride0, stride1), cb, region) - val currElementAddress = cb.newLocal[Long]("eblockmatrix_ndarray_currElementAddress", tFirstElementAddress) + val (tFirstElementAddress, tFinisher) = + pt.constructDataFunction(IndexedSeq(nRows, nCols), IndexedSeq(stride0, stride1), cb, region) + val currElementAddress = + cb.newLocal[Long]("eblockmatrix_ndarray_currElementAddress", tFirstElementAddress) val i = cb.newLocal[Int]("i") - cb.for_(cb.assign(i, 0), i < n, cb.assign(i, i + 1), { - readElemF(cb, region, currElementAddress, in) - cb.assign(currElementAddress, currElementAddress + pt.elementType.byteSize) - }) + cb.for_( + cb.assign(i, 0), + i < n, + cb.assign(i, i + 1), { + readElemF(cb, region, currElementAddress, in) + cb.assign(currElementAddress, currElementAddress + pt.elementType.byteSize) + }, + ) tFinisher(cb) } @@ -83,11 +117,12 @@ final case class EBlockMatrixNDArray(elementType: EType, encodeRowMajor: Boolean cb.for_(cb.assign(i, 0), i < len, cb.assign(i, i + 1), skip(cb, r, in)) } - def _asIdent = s"ndarray_of_${ elementType.asIdent }" + override def _asIdent: String = + s"bm_ndarray_${if (encodeRowMajor) "row" else "column"}_major_of_${elementType.asIdent}" def _toPretty = s"ENDArray[$elementType]" - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("ENDArray[") elementType.pretty(sb, indent, compact) sb.append("]") diff --git a/hail/src/main/scala/is/hail/types/encoded/EBoolean.scala b/hail/src/main/scala/is/hail/types/encoded/EBoolean.scala index b74752042d5..744f9e7938d 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EBoolean.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EBoolean.scala @@ -4,8 +4,8 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder import is.hail.io.{InputBuffer, OutputBuffer} -import is.hail.types.physical.stypes.primitives.{SBoolean, SBooleanValue} import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.primitives.{SBoolean, SBooleanValue} import is.hail.types.virtual._ import is.hail.utils._ @@ -14,15 +14,19 @@ case object EBooleanOptional extends EBoolean(false) case object EBooleanRequired extends EBoolean(true) class EBoolean(override val required: Boolean) extends EType { - override def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = { + override def _buildEncoder(cb: EmitCodeBuilder, v: SValue, 
out: Value[OutputBuffer]): Unit = cb += out.writeBoolean(v.asBoolean.value) - } - override def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = { + override def _buildDecoder( + cb: EmitCodeBuilder, + t: Type, + region: Value[Region], + in: Value[InputBuffer], + ): SValue = new SBooleanValue(cb.memoize(in.readBoolean())) - } - def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = cb += in.skipBoolean() + def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = + cb += in.skipBoolean() def _decodedSType(requestedType: Type): SType = SBoolean @@ -34,5 +38,6 @@ class EBoolean(override val required: Boolean) extends EType { } object EBoolean { - def apply(required: Boolean = false): EBoolean = if (required) EBooleanRequired else EBooleanOptional + def apply(required: Boolean = false): EBoolean = + if (required) EBooleanRequired else EBooleanOptional } diff --git a/hail/src/main/scala/is/hail/types/encoded/EDictAsUnsortedArrayOfPairs.scala b/hail/src/main/scala/is/hail/types/encoded/EDictAsUnsortedArrayOfPairs.scala index 55fa15d9692..a6e77e60812 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EDictAsUnsortedArrayOfPairs.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EDictAsUnsortedArrayOfPairs.scala @@ -2,17 +2,18 @@ package is.hail.types.encoded import is.hail.annotations._ import is.hail.asm4s._ -import is.hail.expr.ir.{ArraySorter, EmitCodeBuilder, EmitMethodBuilder, EmitRegion, StagedArrayBuilder} +import is.hail.expr.ir.{ArraySorter, EmitCodeBuilder, EmitRegion, StagedArrayBuilder} import is.hail.io.{InputBuffer, OutputBuffer} -import is.hail.types.virtual._ import is.hail.types.physical._ -import is.hail.types.physical.stypes.SingleCodeType +import is.hail.types.physical.stypes.{SType, SValue, SingleCodeType} import is.hail.types.physical.stypes.concrete.{SIndexablePointer, SIndexablePointerValue} -import is.hail.types.physical.stypes.interfaces.SIndexableValue -import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.virtual._ import is.hail.utils._ -final case class EDictAsUnsortedArrayOfPairs(val elementType: EType, override val required: Boolean = false) extends EContainer { +final case class EDictAsUnsortedArrayOfPairs( + val elementType: EType, + override val required: Boolean = false, +) extends EContainer { assert(elementType.isInstanceOf[EBaseStruct]) private[this] val arrayRepr = EArray(elementType, required) @@ -26,13 +27,13 @@ final case class EDictAsUnsortedArrayOfPairs(val elementType: EType, override va } } - def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = { + def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = // Anything we have to encode from a region should already be sorted so we don't // have to do anything else arrayRepr._buildEncoder(cb, v, out) - } - def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = { + def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]) + : SValue = { val tmpRegion = cb.memoize(Region.stagedCreate(Region.REGULAR, region.getPool()), "tmp_region") val arrayDecoder = arrayRepr.buildDecoder(t, cb.emb.ecb) @@ -46,7 +47,8 @@ final case class EDictAsUnsortedArrayOfPairs(val elementType: EType, override va } val sorter = new ArraySorter(EmitRegion(cb.emb, region), ab) - def lessThan(cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: 
Value[_]): Value[Boolean] = { + def lessThan(cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]) + : Value[Boolean] = { val lk = cb.memoize(sct.loadToSValue(cb, l).asBaseStruct.loadField(cb, 0)) val rk = cb.memoize(sct.loadToSValue(cb, r).asBaseStruct.loadField(cb, 0)) @@ -55,17 +57,19 @@ final case class EDictAsUnsortedArrayOfPairs(val elementType: EType, override va } sorter.sort(cb, tmpRegion, lessThan) - // TODO Should be able to overwrite the unsorted array with sorted contents instead of allocating + /* TODO Should be able to overwrite the unsorted array with sorted contents instead of + * allocating */ val ret = sorter.toRegion(cb, t) cb.append(tmpRegion.invalidate()) ret } - def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = { + def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = arrayRepr._buildSkip(cb, r, in) - } def _asIdent = s"dict_of_${elementType.asIdent}" def _toPretty = s"EDictAsUnsortedArrayOfPairs[$elementType]" - def setRequired(newRequired: Boolean): EType = EDictAsUnsortedArrayOfPairs(elementType, newRequired) + + def setRequired(newRequired: Boolean): EType = + EDictAsUnsortedArrayOfPairs(elementType, newRequired) } diff --git a/hail/src/main/scala/is/hail/types/encoded/EFloat32.scala b/hail/src/main/scala/is/hail/types/encoded/EFloat32.scala index c582dcefa02..3442bb5840a 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EFloat32.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EFloat32.scala @@ -4,8 +4,8 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder import is.hail.io.{InputBuffer, OutputBuffer} -import is.hail.types.physical.stypes.primitives.{SFloat32, SFloat32Value} import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.primitives.{SFloat32, SFloat32Value} import is.hail.types.virtual._ import is.hail.utils._ @@ -14,15 +14,19 @@ case object EFloat32Optional extends EFloat32(false) case object EFloat32Required extends EFloat32(true) class EFloat32(override val required: Boolean) extends EType { - override def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = { + override def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = cb += out.writeFloat(v.asFloat.value) - } - override def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = { + override def _buildDecoder( + cb: EmitCodeBuilder, + t: Type, + region: Value[Region], + in: Value[InputBuffer], + ): SValue = new SFloat32Value(cb.memoize(in.readFloat())) - } - def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = cb += in.skipFloat() + def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = + cb += in.skipFloat() def _decodedSType(requestedType: Type): SType = SFloat32 @@ -34,5 +38,6 @@ class EFloat32(override val required: Boolean) extends EType { } object EFloat32 { - def apply(required: Boolean = false): EFloat32 = if (required) EFloat32Required else EFloat32Optional + def apply(required: Boolean = false): EFloat32 = + if (required) EFloat32Required else EFloat32Optional } diff --git a/hail/src/main/scala/is/hail/types/encoded/EFloat64.scala b/hail/src/main/scala/is/hail/types/encoded/EFloat64.scala index 28cdf421049..05aabe2a1bd 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EFloat64.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EFloat64.scala 
@@ -4,8 +4,8 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder import is.hail.io.{InputBuffer, OutputBuffer} -import is.hail.types.physical.stypes.primitives.{SFloat64, SFloat64Value} import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.primitives.{SFloat64, SFloat64Value} import is.hail.types.virtual._ import is.hail.utils._ @@ -14,15 +14,19 @@ case object EFloat64Optional extends EFloat64(false) case object EFloat64Required extends EFloat64(true) class EFloat64(override val required: Boolean) extends EType { - override def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = { + override def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = cb += out.writeDouble(v.asDouble.value) - } - override def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = { + override def _buildDecoder( + cb: EmitCodeBuilder, + t: Type, + region: Value[Region], + in: Value[InputBuffer], + ): SValue = new SFloat64Value(cb.memoize(in.readDouble())) - } - def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = cb += in.skipDouble() + def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = + cb += in.skipDouble() def _decodedSType(requestedType: Type): SType = SFloat64 @@ -34,5 +38,6 @@ class EFloat64(override val required: Boolean) extends EType { } object EFloat64 { - def apply(required: Boolean = false): EFloat64 = if (required) EFloat64Required else EFloat64Optional + def apply(required: Boolean = false): EFloat64 = + if (required) EFloat64Required else EFloat64Optional } diff --git a/hail/src/main/scala/is/hail/types/encoded/EInt32.scala b/hail/src/main/scala/is/hail/types/encoded/EInt32.scala index e63eed48f96..b559261b162 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EInt32.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EInt32.scala @@ -5,10 +5,10 @@ import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder import is.hail.io.{InputBuffer, OutputBuffer} import is.hail.types.physical._ +import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.physical.stypes.concrete.{SCanonicalCall, SCanonicalCallValue} import is.hail.types.physical.stypes.interfaces.{SCall, SCallValue} import is.hail.types.physical.stypes.primitives.{SInt32, SInt32Value} -import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.virtual._ import is.hail.utils._ @@ -25,7 +25,12 @@ class EInt32(override val required: Boolean) extends EType { cb += out.writeInt(x) } - override def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = { + override def _buildDecoder( + cb: EmitCodeBuilder, + t: Type, + region: Value[Region], + in: Value[InputBuffer], + ): SValue = { val x = cb.memoize(in.readInt()) t match { case TCall => new SCanonicalCallValue(x) @@ -37,11 +42,11 @@ class EInt32(override val required: Boolean) extends EType { cb: EmitCodeBuilder, pt: PType, region: Value[Region], - in: Value[InputBuffer] + in: Value[InputBuffer], ): Code[Int] = in.readInt() - def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = cb += in.skipInt() - + def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = + cb += in.skipInt() def _decodedSType(requestedType: Type): SType = requestedType match { case TCall => SCanonicalCall diff --git 
a/hail/src/main/scala/is/hail/types/encoded/EInt64.scala b/hail/src/main/scala/is/hail/types/encoded/EInt64.scala index a95034b843a..b88517bcc02 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EInt64.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EInt64.scala @@ -4,8 +4,8 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder import is.hail.io.{InputBuffer, OutputBuffer} -import is.hail.types.physical.stypes.primitives.{SInt64, SInt64Value} import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.primitives.{SInt64, SInt64Value} import is.hail.types.virtual._ import is.hail.utils._ @@ -14,15 +14,19 @@ case object EInt64Optional extends EInt64(false) case object EInt64Required extends EInt64(true) class EInt64(override val required: Boolean) extends EType { - override def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = { + override def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = cb += out.writeLong(v.asLong.value) - } - override def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = { + override def _buildDecoder( + cb: EmitCodeBuilder, + t: Type, + region: Value[Region], + in: Value[InputBuffer], + ): SValue = new SInt64Value(cb.memoize(in.readLong())) - } - def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = cb += in.skipLong() + def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = + cb += in.skipLong() def _decodedSType(requestedType: Type): SType = SInt64 diff --git a/hail/src/main/scala/is/hail/types/encoded/ENDArrayColumnMajor.scala b/hail/src/main/scala/is/hail/types/encoded/ENDArrayColumnMajor.scala index e268c36c09b..d6b505c44f4 100644 --- a/hail/src/main/scala/is/hail/types/encoded/ENDArrayColumnMajor.scala +++ b/hail/src/main/scala/is/hail/types/encoded/ENDArrayColumnMajor.scala @@ -4,14 +4,15 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder import is.hail.io.{InputBuffer, OutputBuffer} -import is.hail.types.physical.stypes.{SCode, SType, SValue} +import is.hail.types.physical.PCanonicalNDArray +import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.physical.stypes.concrete.SNDArrayPointer import is.hail.types.physical.stypes.interfaces.{SNDArray, SNDArrayValue} -import is.hail.types.physical.PCanonicalNDArray import is.hail.types.virtual.{TNDArray, Type} import is.hail.utils._ -case class ENDArrayColumnMajor(elementType: EType, nDims: Int, required: Boolean = false) extends EContainer { +case class ENDArrayColumnMajor(elementType: EType, nDims: Int, required: Boolean = false) + extends EContainer { override def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = { val ndarray = v.asInstanceOf[SNDArrayValue] @@ -19,34 +20,44 @@ case class ENDArrayColumnMajor(elementType: EType, nDims: Int, required: Boolean val shapes = ndarray.shapes shapes.foreach(s => cb += out.writeLong(s)) - SNDArray.coiterate(cb, (ndarray, "A")){ + SNDArray.coiterate(cb, (ndarray, "A")) { case Seq(elt) => elementType.buildEncoder(elt.st, cb.emb.ecb) .apply(cb, elt, out) } } - override def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = { + override def _buildDecoder( + cb: EmitCodeBuilder, + t: Type, + region: Value[Region], + in: Value[InputBuffer], + ): SValue = { val st = 
decodedSType(t).asInstanceOf[SNDArrayPointer] val pnd = st.pType val readElemF = elementType.buildInplaceDecoder(pnd.elementType, cb.emb.ecb) - val shapeVars = (0 until nDims).map(i => cb.newLocal[Long](s"ndarray_decoder_shape_$i", in.readLong())) + val shapeVars = + (0 until nDims).map(i => cb.newLocal[Long](s"ndarray_decoder_shape_$i", in.readLong())) val totalNumElements = cb.newLocal[Long]("ndarray_decoder_total_num_elements", 1L) - shapeVars.foreach { s => - cb.assign(totalNumElements, totalNumElements * s) - } + shapeVars.foreach(s => cb.assign(totalNumElements, totalNumElements * s)) val strides = pnd.makeColumnMajorStrides(shapeVars, cb) - val (pndFirstElementAddress, pndFinisher) = pnd.constructDataFunction(shapeVars, strides, cb, region) + val (pndFirstElementAddress, pndFinisher) = + pnd.constructDataFunction(shapeVars, strides, cb, region) - val currElementAddress = cb.newLocal[Long]("eblockmatrix_ndarray_currElementAddress", pndFirstElementAddress) + val currElementAddress = + cb.newLocal[Long]("eblockmatrix_ndarray_currElementAddress", pndFirstElementAddress) val dataIdx = cb.newLocal[Int]("ndarray_decoder_data_idx") - cb.for_(cb.assign(dataIdx, 0), dataIdx < totalNumElements.toI, cb.assign(dataIdx, dataIdx + 1), { - readElemF(cb, region, currElementAddress, in) - cb.assign(currElementAddress, currElementAddress + pnd.elementType.byteSize) - }) + cb.for_( + cb.assign(dataIdx, 0), + dataIdx < totalNumElements.toI, + cb.assign(dataIdx, dataIdx + 1), { + readElemF(cb, region, currElementAddress, in) + cb.assign(currElementAddress, currElementAddress + pnd.elementType.byteSize) + }, + ) pndFinisher(cb) } @@ -54,8 +65,10 @@ case class ENDArrayColumnMajor(elementType: EType, nDims: Int, required: Boolean def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = { val skip = elementType.buildSkip(cb.emb.ecb) - val numElements = cb.newLocal[Long]("ndarray_skipper_total_num_elements", - (0 until nDims).foldLeft(const(1L).get) { (p, i) => p * in.readLong() }) + val numElements = cb.newLocal[Long]( + "ndarray_skipper_total_num_elements", + (0 until nDims).foldLeft(const(1L).get)((p, i) => p * in.readLong()), + ) val i = cb.newLocal[Long]("ndarray_skipper_data_idx") cb.for_(cb.assign(i, 0L), i < numElements, cb.assign(i, i + 1L), skip(cb, r, in)) } @@ -66,9 +79,11 @@ case class ENDArrayColumnMajor(elementType: EType, nDims: Int, required: Boolean SNDArrayPointer(PCanonicalNDArray(elementPType, requestedTNDArray.nDims, false)) } - override def setRequired(required: Boolean): EType = ENDArrayColumnMajor(elementType, nDims, required) + override def setRequired(required: Boolean): EType = + ENDArrayColumnMajor(elementType, nDims, required) - override def _asIdent = s"ndarray_of_${ elementType.asIdent }" + override def _asIdent: String = + s"${nDims}d_array_column_major_of_${elementType.asIdent}" override def _toPretty = s"ENDArrayColumnMajor[$elementType,$nDims]" } diff --git a/hail/src/main/scala/is/hail/types/encoded/ENumpyBinaryNDArray.scala b/hail/src/main/scala/is/hail/types/encoded/ENumpyBinaryNDArray.scala index 70485ff0968..5b88a58ab6d 100644 --- a/hail/src/main/scala/is/hail/types/encoded/ENumpyBinaryNDArray.scala +++ b/hail/src/main/scala/is/hail/types/encoded/ENumpyBinaryNDArray.scala @@ -5,10 +5,10 @@ import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder import is.hail.io.{InputBuffer, OutputBuffer} import is.hail.types.physical.PCanonicalNDArray +import is.hail.types.physical.stypes.{SType, SValue} import 
is.hail.types.physical.stypes.concrete.SNDArrayPointer import is.hail.types.physical.stypes.interfaces.SNDArrayValue import is.hail.types.physical.stypes.primitives.SFloat64 -import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.virtual.{TNDArray, Type} import is.hail.utils.FastSeq @@ -17,7 +17,8 @@ final case class ENumpyBinaryNDArray(nRows: Long, nCols: Long, required: Boolean type DecodedPType = PCanonicalNDArray val elementType = EFloat64(true) - def setRequired(newRequired: Boolean): ENumpyBinaryNDArray = ENumpyBinaryNDArray(nRows, nCols, newRequired) + def setRequired(newRequired: Boolean): ENumpyBinaryNDArray = + ENumpyBinaryNDArray(nRows, nCols, newRequired) def _decodedSType(requestedType: Type): SType = { val elementPType = elementType.decodedPType(requestedType.asInstanceOf[TNDArray].elementType) @@ -31,15 +32,26 @@ final case class ENumpyBinaryNDArray(nRows: Long, nCols: Long, required: Boolean val j = cb.newLocal[Long]("j") val writeElemF = elementType.buildEncoder(ndarray.st.elementType, cb.emb.ecb) - cb.for_(cb.assign(i, 0L), i < nRows, cb.assign(i, i + 1L), { - cb.for_(cb.assign(j, 0L), j < nCols, cb.assign(j, j + 1L), { - writeElemF(cb, ndarray.loadElement(FastSeq(i, j), cb), out) - }) - }) + cb.for_( + cb.assign(i, 0L), + i < nRows, + cb.assign(i, i + 1L), + cb.for_( + cb.assign(j, 0L), + j < nCols, + cb.assign(j, j + 1L), + writeElemF(cb, ndarray.loadElement(FastSeq(i, j), cb), out), + ), + ) } - override def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = { + override def _buildDecoder( + cb: EmitCodeBuilder, + t: Type, + region: Value[Region], + in: Value[InputBuffer], + ): SValue = { val st = decodedSType(t).asInstanceOf[SNDArrayPointer] val pt = st.pType val readElemF = elementType.buildInplaceDecoder(pt.elementType, cb.emb.ecb) @@ -49,27 +61,33 @@ final case class ENumpyBinaryNDArray(nRows: Long, nCols: Long, required: Boolean val n = cb.newLocal[Long]("length", nRows * nCols) - val (tFirstElementAddress, tFinisher) = pt.constructDataFunction(IndexedSeq(nRows, nCols), IndexedSeq(stride0, stride1), cb, region) - val currElementAddress = cb.newLocal[Long]("eblockmatrix_ndarray_currElementAddress", tFirstElementAddress) + val (tFirstElementAddress, tFinisher) = + pt.constructDataFunction(IndexedSeq(nRows, nCols), IndexedSeq(stride0, stride1), cb, region) + val currElementAddress = + cb.newLocal[Long]("eblockmatrix_ndarray_currElementAddress", tFirstElementAddress) val i = cb.newLocal[Long]("i") - cb.for_(cb.assign(i, 0L), i < n, cb.assign(i, i + 1L), { - readElemF(cb, region, currElementAddress, in) - cb.assign(currElementAddress, currElementAddress + pt.elementType.byteSize) - }) + cb.for_( + cb.assign(i, 0L), + i < n, + cb.assign(i, i + 1L), { + readElemF(cb, region, currElementAddress, in) + cb.assign(currElementAddress, currElementAddress + pt.elementType.byteSize) + }, + ) tFinisher(cb) } - def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = { + def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = ??? 
- } - def _asIdent = s"ndarray_of_${ elementType.asIdent }" + override def _asIdent: String = + s"${nRows}by${nCols}_numpy_array_of_${elementType.asIdent}" def _toPretty = s"ENDArray[$elementType]" - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("ENDArray[") elementType.pretty(sb, indent, compact) sb.append("]") diff --git a/hail/src/main/scala/is/hail/types/encoded/ERNGState.scala b/hail/src/main/scala/is/hail/types/encoded/ERNGState.scala index 6f68bdb390d..71b7d70c8f6 100644 --- a/hail/src/main/scala/is/hail/types/encoded/ERNGState.scala +++ b/hail/src/main/scala/is/hail/types/encoded/ERNGState.scala @@ -5,7 +5,9 @@ import is.hail.asm4s.Value import is.hail.expr.ir.EmitCodeBuilder import is.hail.io.{InputBuffer, OutputBuffer} import is.hail.types.physical.stypes.SValue -import is.hail.types.physical.stypes.concrete.{SCanonicalRNGStateValue, SRNGState, SRNGStateStaticInfo, SRNGStateStaticSizeValue} +import is.hail.types.physical.stypes.concrete.{ + SCanonicalRNGStateValue, SRNGState, SRNGStateStaticInfo, SRNGStateStaticSizeValue, +} import is.hail.types.virtual.Type import is.hail.utils._ @@ -13,26 +15,34 @@ import is.hail.utils._ // //case object ERNGStateRequired extends ERNGState(true) -final case class ERNGState(override val required: Boolean, staticInfo: Option[SRNGStateStaticInfo]) extends EType { - override def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = (staticInfo, v) match { - case (Some(staticInfo), v: SRNGStateStaticSizeValue) => - assert(staticInfo == v.staticInfo) - for (x <- v.runningSum) cb += out.writeLong(x) - for (x <- v.lastDynBlock) cb += out.writeLong(x) - case (None, v: SCanonicalRNGStateValue) => - for (x <- v.runningSum) cb += out.writeLong(x) - for (x <- v.lastDynBlock) cb += out.writeLong(x) - cb += out.writeInt(v.numWordsInLastBlock) - cb += out.writeBoolean(v.hasStaticSplit) - cb += out.writeInt(v.numDynBlocks) - } +final case class ERNGState(override val required: Boolean, staticInfo: Option[SRNGStateStaticInfo]) + extends EType { + override def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = + (staticInfo, v) match { + case (Some(staticInfo), v: SRNGStateStaticSizeValue) => + assert(staticInfo == v.staticInfo) + for (x <- v.runningSum) cb += out.writeLong(x) + for (x <- v.lastDynBlock) cb += out.writeLong(x) + case (None, v: SCanonicalRNGStateValue) => + for (x <- v.runningSum) cb += out.writeLong(x) + for (x <- v.lastDynBlock) cb += out.writeLong(x) + cb += out.writeInt(v.numWordsInLastBlock) + cb += out.writeBoolean(v.hasStaticSplit) + cb += out.writeInt(v.numDynBlocks) + } - override def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = staticInfo match { + override def _buildDecoder( + cb: EmitCodeBuilder, + t: Type, + region: Value[Region], + in: Value[InputBuffer], + ): SValue = staticInfo match { case Some(staticInfo) => new SRNGStateStaticSizeValue( _decodedSType(t), Array.fill(4)(cb.memoize(in.readLong())), - Array.fill(staticInfo.numWordsInLastBlock)(cb.memoize(in.readLong()))) + Array.fill(staticInfo.numWordsInLastBlock)(cb.memoize(in.readLong())), + ) case None => new SCanonicalRNGStateValue( _decodedSType(t), @@ -40,20 +50,22 @@ final case class ERNGState(override val required: Boolean, staticInfo: Option[SR Array.fill(4)(cb.memoize(in.readLong())), cb.memoize(in.readInt()), 
cb.memoize(in.readBoolean()), - cb.memoize(in.readInt())) + cb.memoize(in.readInt()), + ) } - def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = staticInfo match { - case Some(staticInfo) => - for (_ <- 0 until (4 + staticInfo.numWordsInLastBlock)) - cb += in.skipLong() - case None => - for (_ <- 0 until 8) - cb += in.skipLong() - cb += in.skipInt() - cb += in.skipBoolean() - cb += in.skipInt() - } + def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = + staticInfo match { + case Some(staticInfo) => + for (_ <- 0 until (4 + staticInfo.numWordsInLastBlock)) + cb += in.skipLong() + case None => + for (_ <- 0 until 8) + cb += in.skipLong() + cb += in.skipInt() + cb += in.skipBoolean() + cb += in.skipInt() + } def _decodedSType(requestedType: Type): SRNGState = SRNGState(staticInfo) diff --git a/hail/src/main/scala/is/hail/types/encoded/EType.scala b/hail/src/main/scala/is/hail/types/encoded/EType.scala index 3d71c7716e8..a2d83932bd5 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EType.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EType.scala @@ -1,89 +1,108 @@ package is.hail.types.encoded + import is.hail.annotations.Region import is.hail.asm4s.{coerce => _, _} import is.hail.backend.ExecuteContext -import is.hail.expr.ir.{EmitClassBuilder, EmitCodeBuilder, EmitFunctionBuilder, EmitMethodBuilder, IRParser, ParamType, PunctuationToken, TokenIterator} +import is.hail.expr.ir.{ + EmitClassBuilder, EmitCodeBuilder, EmitFunctionBuilder, EmitMethodBuilder, IRParser, ParamType, + PunctuationToken, TokenIterator, +} import is.hail.io._ import is.hail.types._ import is.hail.types.physical._ import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.virtual._ import is.hail.utils._ -import org.json4s.CustomSerializer -import org.json4s.JsonAST.JString import java.util import java.util.Map.Entry +import org.json4s.CustomSerializer +import org.json4s.JsonAST.JString -class ETypeSerializer extends CustomSerializer[EType](format => ( { - case JString(s) => IRParser.parse[EType](s, EType.eTypeParser) -}, { - case t: EType => JString(t.parsableString()) -})) - +class ETypeSerializer extends CustomSerializer[EType](format => + ( + { + case JString(s) => IRParser.parse[EType](s, EType.eTypeParser) + }, + { + case t: EType => JString(t.parsableString()) + }, + ) + ) abstract class EType extends BaseType with Serializable with Requiredness { type StagedEncoder = (EmitCodeBuilder, SValue, Value[OutputBuffer]) => Unit type StagedDecoder = (EmitCodeBuilder, Value[Region], Value[InputBuffer]) => SValue - type StagedInplaceDecoder = (EmitCodeBuilder, Value[Region], Value[Long], Value[InputBuffer]) => Unit - final def buildEncoder(ctx: ExecuteContext, t: PType): (OutputBuffer, HailClassLoader) => Encoder = { + type StagedInplaceDecoder = + (EmitCodeBuilder, Value[Region], Value[Long], Value[InputBuffer]) => Unit + + final def buildEncoder(ctx: ExecuteContext, t: PType) + : (OutputBuffer, HailClassLoader) => Encoder = { val f = EType.buildEncoder(ctx, this, t) - (out: OutputBuffer, theHailClassLoader: HailClassLoader) => new CompiledEncoder(out, theHailClassLoader, f) + (out: OutputBuffer, theHailClassLoader: HailClassLoader) => + new CompiledEncoder(out, theHailClassLoader, f) } - final def buildDecoder(ctx: ExecuteContext, requestedType: Type): (PType, (InputBuffer, HailClassLoader) => Decoder) = { + final def buildDecoder(ctx: ExecuteContext, requestedType: Type) + : (PType, (InputBuffer, HailClassLoader) => Decoder) 
= { val (rt, f) = EType.buildDecoderToRegionValue(ctx, this, requestedType) val makeDec = (in: InputBuffer, theHailClassLoader: HailClassLoader) => new CompiledDecoder(in, rt, theHailClassLoader, f) (rt, makeDec) } - final def buildStructDecoder(ctx: ExecuteContext, requestedType: TStruct): (PStruct, (InputBuffer, HailClassLoader) => Decoder) = { + final def buildStructDecoder(ctx: ExecuteContext, requestedType: TStruct) + : (PStruct, (InputBuffer, HailClassLoader) => Decoder) = { val (pType: PStruct, makeDec) = buildDecoder(ctx, requestedType) pType -> makeDec } final def buildEncoder(st: SType, kb: EmitClassBuilder[_]): StagedEncoder = { val mb = buildEncoderMethod(st, kb); - { (cb: EmitCodeBuilder, sv: SValue, ob: Value[OutputBuffer]) => cb.invokeVoid(mb, sv, ob) } + { (cb: EmitCodeBuilder, sv: SValue, ob: Value[OutputBuffer]) => + cb.invokeVoid(mb, cb.this_, sv, ob) + } } - final def buildEncoderMethod(st: SType, kb: EmitClassBuilder[_]): EmitMethodBuilder[_] = { - kb.getOrGenEmitMethod(s"ENCODE_${ st.asIdent }_TO_${ asIdent }", + final def buildEncoderMethod(st: SType, kb: EmitClassBuilder[_]): EmitMethodBuilder[_] = + kb.getOrGenEmitMethod( + s"ENCODE_${st.asIdent}_TO_$asIdent", (st, this, "ENCODE"), FastSeq[ParamType](st.paramType, classInfo[OutputBuffer]), - UnitInfo) { mb => - + UnitInfo, + ) { mb => mb.voidWithBuilder { cb => val arg = mb.getSCodeParam(1) val out = mb.getCodeParam[OutputBuffer](2) _buildEncoder(cb, arg, out) } } - } final def buildDecoder(t: Type, kb: EmitClassBuilder[_]): StagedDecoder = { val mb = buildDecoderMethod(t: Type, kb); { (cb: EmitCodeBuilder, r: Value[Region], ib: Value[InputBuffer]) => - cb.invokeSCode(mb, r, ib) + cb.invokeSCode(mb, cb.this_, r, ib) } } final def buildDecoderMethod[T](t: Type, kb: EmitClassBuilder[_]): EmitMethodBuilder[_] = { val st = decodedSType(t) - kb.getOrGenEmitMethod(s"DECODE_${ asIdent }_TO_${ st.asIdent }", + kb.getOrGenEmitMethod( + s"DECODE_${asIdent}_TO_${st.asIdent}", (t, this, "DECODE"), FastSeq[ParamType](typeInfo[Region], classInfo[InputBuffer]), - st.paramType) { mb => - + st.paramType, + ) { mb => mb.emitSCode { cb => val region: Value[Region] = mb.getCodeParam[Region](1) val in: Value[InputBuffer] = mb.getCodeParam[InputBuffer](2) val sc = _buildDecoder(cb, t, region, in) if (sc.st != st) - throw new RuntimeException(s"decoder type mismatch:\n inferred: $st\n returned: ${ sc.st }") + throw new RuntimeException( + s"decoder type mismatch:\n inferred: $st\n returned: ${sc.st}" + ) sc } } @@ -92,50 +111,54 @@ abstract class EType extends BaseType with Serializable with Requiredness { final def buildInplaceDecoder(pt: PType, kb: EmitClassBuilder[_]): StagedInplaceDecoder = { val mb = buildInplaceDecoderMethod(pt, kb); { (cb: EmitCodeBuilder, r: Value[Region], addr: Value[Long], ib: Value[InputBuffer]) => - cb.invokeVoid(mb, r, addr, ib) + cb.invokeVoid(mb, cb.this_, r, addr, ib) } } - final def buildInplaceDecoderMethod(pt: PType, kb: EmitClassBuilder[_]): EmitMethodBuilder[_] = { - kb.getOrGenEmitMethod(s"INPLACE_DECODE_${ asIdent }_TO_${ pt.asIdent }", + final def buildInplaceDecoderMethod(pt: PType, kb: EmitClassBuilder[_]): EmitMethodBuilder[_] = + kb.getOrGenEmitMethod( + s"INPLACE_DECODE_${asIdent}_TO_${pt.asIdent}", (pt, this, "INPLACE_DECODE"), FastSeq[ParamType](typeInfo[Region], typeInfo[Long], classInfo[InputBuffer]), - UnitInfo)({ mb => - + UnitInfo, + ) { mb => mb.voidWithBuilder { cb => val region: Value[Region] = mb.getCodeParam[Region](1) val addr: Value[Long] = mb.getCodeParam[Long](2) val in: 
Value[InputBuffer] = mb.getCodeParam[InputBuffer](3) _buildInplaceDecoder(cb, pt, region, addr, in) } - }) - } + } - final def buildSkip(kb: EmitClassBuilder[_]): (EmitCodeBuilder, Value[Region], Value[InputBuffer]) => Unit = { - val mb = kb.getOrGenEmitMethod(s"SKIP_${ asIdent }", + final def buildSkip(kb: EmitClassBuilder[_]) + : (EmitCodeBuilder, Value[Region], Value[InputBuffer]) => Unit = { + val mb = kb.getOrGenEmitMethod( + s"SKIP_$asIdent", (this, "SKIP"), FastSeq[ParamType](classInfo[Region], classInfo[InputBuffer]), - UnitInfo)({ mb => + UnitInfo, + ) { mb => mb.voidWithBuilder { cb => val r: Value[Region] = mb.getCodeParam[Region](1) val in: Value[InputBuffer] = mb.getCodeParam[InputBuffer](2) _buildSkip(cb, r, in) } - }) + } - { (cb, r, in) => cb.invokeVoid(mb, r, in) } + { (cb, r, in) => cb.invokeVoid(mb, cb.this_, r, in) } } def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit - def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue + def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]) + : SValue def _buildInplaceDecoder( cb: EmitCodeBuilder, pt: PType, region: Value[Region], addr: Value[Long], - in: Value[InputBuffer] + in: Value[InputBuffer], ): Unit = { assert(!pt.isInstanceOf[PBaseStruct]) // should be overridden for structs val decoded = _buildDecoder(cb, pt.virtualType, region, in) @@ -144,7 +167,7 @@ abstract class EType extends BaseType with Serializable with Requiredness { def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit - final def pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + final def pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = { if (required) sb.append("+") _pretty(sb, indent, compact) @@ -156,17 +179,14 @@ abstract class EType extends BaseType with Serializable with Requiredness { def _toPretty: String - def _pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = sb.append(_toPretty) - } - final def decodedSType(requestedType: Type): SType = { + final def decodedSType(requestedType: Type): SType = _decodedSType(requestedType) - } - final def decodedPType(requestedType: Type): PType = { + final def decodedPType(requestedType: Type): PType = decodedSType(requestedType).storageType().setRequired(required) - } def _decodedSType(requestedType: Type): SType @@ -179,18 +199,28 @@ trait EncoderAsmFunction { def apply(off: Long, out: OutputBuffer): Unit } object EType { - protected[encoded] def lowBitMask(n: Int): Byte = (0xFF >>> ((-n) & 0x7)).toByte - protected[encoded] def lowBitMask(n: Code[Int]): Code[Byte] = (const(0xFF) >>> ((-n) & 0x7)).toB + protected[encoded] def lowBitMask(n: Int): Byte = (0xff >>> ((-n) & 0x7)).toByte + protected[encoded] def lowBitMask(n: Code[Int]): Code[Byte] = (const(0xff) >>> ((-n) & 0x7)).toB val cacheCapacity = 256 - protected val encoderCache = new util.LinkedHashMap[(EType, PType), (HailClassLoader) => EncoderAsmFunction](cacheCapacity, 0.75f, true) { - override def removeEldestEntry(eldest: Entry[(EType, PType), (HailClassLoader) => EncoderAsmFunction]): Boolean = size() > cacheCapacity - } + + protected val encoderCache = + new util.LinkedHashMap[(EType, PType), (HailClassLoader) => EncoderAsmFunction]( + cacheCapacity, + 0.75f, + true, + ) { + override def removeEldestEntry( + eldest: Entry[(EType, PType), (HailClassLoader) => EncoderAsmFunction] + ): Boolean = size() 
> cacheCapacity + } + protected var encoderCacheHits: Long = 0L protected var encoderCacheMisses: Long = 0L // The 'entry point' for building an encoder from an EType and a PType - def buildEncoder(ctx: ExecuteContext, et: EType, pt: PType): (HailClassLoader) => EncoderAsmFunction = { + def buildEncoder(ctx: ExecuteContext, et: EType, pt: PType) + : (HailClassLoader) => EncoderAsmFunction = { val k = (et, pt) if (encoderCache.containsKey(k)) { encoderCacheHits += 1 @@ -199,10 +229,13 @@ object EType { } else { encoderCacheMisses += 1 log.info(s"encoder cache miss ($encoderCacheHits hits, $encoderCacheMisses misses, " + - s"${ formatDouble(encoderCacheHits.toDouble / (encoderCacheHits + encoderCacheMisses), 3) })") - val fb = EmitFunctionBuilder[EncoderAsmFunction](ctx, "etypeEncode", + s"${formatDouble(encoderCacheHits.toDouble / (encoderCacheHits + encoderCacheMisses), 3)})") + val fb = EmitFunctionBuilder[EncoderAsmFunction]( + ctx, + "etypeEncode", Array(NotGenericTypeInfo[Long], NotGenericTypeInfo[OutputBuffer]), - NotGenericTypeInfo[Unit]) + NotGenericTypeInfo[Unit], + ) val mb = fb.apply_method mb.voidWithBuilder { cb => @@ -218,13 +251,22 @@ object EType { } } - protected val decoderCache = new util.LinkedHashMap[(EType, Type), (PType, (HailClassLoader) => DecoderAsmFunction)](cacheCapacity, 0.75f, true) { - override def removeEldestEntry(eldest: Entry[(EType, Type), (PType, (HailClassLoader) => DecoderAsmFunction)]): Boolean = size() > cacheCapacity - } + protected val decoderCache = + new util.LinkedHashMap[(EType, Type), (PType, (HailClassLoader) => DecoderAsmFunction)]( + cacheCapacity, + 0.75f, + true, + ) { + override def removeEldestEntry( + eldest: Entry[(EType, Type), (PType, (HailClassLoader) => DecoderAsmFunction)] + ): Boolean = size() > cacheCapacity + } + protected var decoderCacheHits: Long = 0L protected var decoderCacheMisses: Long = 0L - def buildDecoderToRegionValue(ctx: ExecuteContext, et: EType, t: Type): (PType, (HailClassLoader) => DecoderAsmFunction) = { + def buildDecoderToRegionValue(ctx: ExecuteContext, et: EType, t: Type) + : (PType, (HailClassLoader) => DecoderAsmFunction) = { val k = (et, t) if (decoderCache.containsKey(k)) { decoderCacheHits += 1 @@ -233,10 +275,13 @@ object EType { } else { decoderCacheMisses += 1 log.info(s"decoder cache miss ($decoderCacheHits hits, $decoderCacheMisses misses, " + - s"${ formatDouble(decoderCacheHits.toDouble / (decoderCacheHits + decoderCacheMisses), 3) }") - val fb = EmitFunctionBuilder[DecoderAsmFunction](ctx, "etypeDecode", + s"${formatDouble(decoderCacheHits.toDouble / (decoderCacheHits + decoderCacheMisses), 3)}") + val fb = EmitFunctionBuilder[DecoderAsmFunction]( + ctx, + "etypeDecode", Array(NotGenericTypeInfo[Region], NotGenericTypeInfo[InputBuffer]), - NotGenericTypeInfo[Long]) + NotGenericTypeInfo[Long], + ) val mb = fb.apply_method val pt = et.decodedPType(t) val f = et.buildDecoder(t, mb.ecb) @@ -269,10 +314,13 @@ object EType { case TBinary => EBinary(r.required) case TString => EBinary(r.required) case TLocus(_) => - EBaseStruct(Array( - EField("contig", EBinary(true), 0), - EField("position", EInt32(true), 1)), - required = r.required) + EBaseStruct( + Array( + EField("contig", EBinary(true), 0), + EField("position", EInt32(true), 1), + ), + required = r.required, + ) case TCall => EInt32(r.required) case TRNGState => ERNGState(r.required, None) case t: TInterval => @@ -282,21 +330,31 @@ object EType { EField("start", fromTypeAndAnalysis(t.pointType, rinterval.startType), 0), EField("end", 
fromTypeAndAnalysis(t.pointType, rinterval.endType), 1), EField("includesStart", EBoolean(true), 2), - EField("includesEnd", EBoolean(true), 3)), - required = rinterval.required) - case t: TIterable => EArray(fromTypeAndAnalysis(t.elementType, tcoerce[RIterable](r).elementType), r.required) + EField("includesEnd", EBoolean(true), 3), + ), + required = rinterval.required, + ) + case t: TIterable => + EArray(fromTypeAndAnalysis(t.elementType, tcoerce[RIterable](r).elementType), r.required) case t: TBaseStruct => val rstruct = tcoerce[RBaseStruct](r) - assert(t.size == rstruct.size, s"different number of fields: ${t} ${r}") - EBaseStruct(Array.tabulate(t.size) { i => - val f = rstruct.fields(i) - if (f.index != i) - throw new AssertionError(s"${t} [$i]") - EField(f.name, fromTypeAndAnalysis(t.fields(i).typ, f.typ), f.index) - }, required = r.required) + assert(t.size == rstruct.size, s"different number of fields: $t $r") + EBaseStruct( + Array.tabulate(t.size) { i => + val f = rstruct.fields(i) + if (f.index != i) + throw new AssertionError(s"$t [$i]") + EField(f.name, fromTypeAndAnalysis(t.fields(i).typ, f.typ), f.index) + }, + required = r.required, + ) case t: TNDArray => val rndarray = r.asInstanceOf[RNDArray] - ENDArrayColumnMajor(fromTypeAndAnalysis(t.elementType, rndarray.elementType), t.nDims, rndarray.required) + ENDArrayColumnMajor( + fromTypeAndAnalysis(t.elementType, rndarray.elementType), + t.nDims, + rndarray.required, + ) } def fromPythonTypeEncoding(t: Type): EType = t match { @@ -308,10 +366,13 @@ object EType { case TBinary => EBinary(false) case TString => EBinary(false) case TLocus(_) => - EBaseStruct(Array( - EField("contig", EBinary(false), 0), - EField("position", EInt32(false), 1)), - required = false) + EBaseStruct( + Array( + EField("contig", EBinary(false), 0), + EField("position", EInt32(false), 1), + ), + required = false, + ) case TCall => EInt32(false) case t: TInterval => EBaseStruct( @@ -319,18 +380,24 @@ object EType { EField("start", fromPythonTypeEncoding(t.pointType), 0), EField("end", fromPythonTypeEncoding(t.pointType), 1), EField("includesStart", EBoolean(false), 2), - EField("includesEnd", EBoolean(false), 3)), - required = false) - case t: TDict => EDictAsUnsortedArrayOfPairs(fromPythonTypeEncoding(t.elementType).setRequired(true), false) + EField("includesEnd", EBoolean(false), 3), + ), + required = false, + ) + case t: TDict => + EDictAsUnsortedArrayOfPairs(fromPythonTypeEncoding(t.elementType).setRequired(true), false) case t: TSet => EUnsortedSet(fromPythonTypeEncoding(t.elementType), false) case t: TIterable => EArray(fromPythonTypeEncoding(t.elementType), false) case t: TBaseStruct => - EBaseStruct(Array.tabulate(t.size) { i => - val f = t.fields(i) - if (f.index != i) - throw new AssertionError(s"${t} [$i]") - EField(f.name, fromPythonTypeEncoding(t.fields(i).typ), f.index) - }, required = false) + EBaseStruct( + Array.tabulate(t.size) { i => + val f = t.fields(i) + if (f.index != i) + throw new AssertionError(s"$t [$i]") + EField(f.name, fromPythonTypeEncoding(t.fields(i).typ), f.index) + }, + required = false, + ) case t: TNDArray => ENDArrayColumnMajor(fromPythonTypeEncoding(t.elementType).setRequired(true), t.nDims, false) } @@ -357,7 +424,12 @@ object EType { EArray(elementType, req) case "EBaseStruct" => IRParser.punctuation(it, "{") - val args = IRParser.repsepUntil(it, IRParser.struct_field(eTypeParser), PunctuationToken(","), PunctuationToken("}")) + val args = IRParser.repsepUntil( + it, + IRParser.struct_field(eTypeParser), + 
PunctuationToken(","), + PunctuationToken("}"), + ) IRParser.punctuation(it, "}") EBaseStruct(args.zipWithIndex.map { case ((name, t), i) => EField(name, t, i) }, req) case "ENDArrayColumnMajor" => @@ -366,7 +438,7 @@ object EType { IRParser.punctuation(it, ",") val nDims = IRParser.int32_literal(it) IRParser.punctuation(it, "]") - ENDArrayColumnMajor(elementType, nDims, req) + ENDArrayColumnMajor(elementType, nDims, req) case x => throw new UnsupportedOperationException(s"Couldn't parse $x ${it.toIndexedSeq}") } diff --git a/hail/src/main/scala/is/hail/types/encoded/EUnsortedSet.scala b/hail/src/main/scala/is/hail/types/encoded/EUnsortedSet.scala index b3be4d606af..a8c85430c78 100644 --- a/hail/src/main/scala/is/hail/types/encoded/EUnsortedSet.scala +++ b/hail/src/main/scala/is/hail/types/encoded/EUnsortedSet.scala @@ -2,17 +2,16 @@ package is.hail.types.encoded import is.hail.annotations._ import is.hail.asm4s._ -import is.hail.expr.ir.{ArraySorter, EmitCodeBuilder, EmitMethodBuilder, EmitRegion, StagedArrayBuilder} +import is.hail.expr.ir.{ArraySorter, EmitCodeBuilder, EmitRegion, StagedArrayBuilder} import is.hail.io.{InputBuffer, OutputBuffer} -import is.hail.types.virtual._ import is.hail.types.physical._ -import is.hail.types.physical.stypes.SingleCodeType +import is.hail.types.physical.stypes.{SType, SValue, SingleCodeType} import is.hail.types.physical.stypes.concrete.{SIndexablePointer, SIndexablePointerValue} -import is.hail.types.physical.stypes.interfaces.SIndexableValue -import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.virtual._ import is.hail.utils._ -final case class EUnsortedSet(val elementType: EType, override val required: Boolean = false) extends EContainer { +final case class EUnsortedSet(val elementType: EType, override val required: Boolean = false) + extends EContainer { private[this] val arrayRepr = EArray(elementType, required) def _decodedSType(requestedType: Type): SType = { @@ -23,13 +22,13 @@ final case class EUnsortedSet(val elementType: EType, override val required: Boo } } - def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = { + def _buildEncoder(cb: EmitCodeBuilder, v: SValue, out: Value[OutputBuffer]): Unit = // Anything we have to encode from a region should already be sorted so we don't // have to do anything else arrayRepr._buildEncoder(cb, v, out) - } - def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]): SValue = { + def _buildDecoder(cb: EmitCodeBuilder, t: Type, region: Value[Region], in: Value[InputBuffer]) + : SValue = { val tmpRegion = cb.memoize(Region.stagedCreate(Region.REGULAR, region.getPool()), "tmp_region") val arrayDecoder = arrayRepr.buildDecoder(t, cb.emb.ecb) @@ -43,21 +42,21 @@ final case class EUnsortedSet(val elementType: EType, override val required: Boo } val sorter = new ArraySorter(EmitRegion(cb.emb, region), ab) - def lessThan(cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]): Value[Boolean] = { + def lessThan(cb: EmitCodeBuilder, region: Value[Region], l: Value[_], r: Value[_]) + : Value[Boolean] = cb.emb.ecb.getOrdering(sct.loadedSType, sct.loadedSType) .ltNonnull(cb, sct.loadToSValue(cb, l), sct.loadToSValue(cb, r)) - } sorter.sort(cb, tmpRegion, lessThan) - // TODO Should be able to overwrite the unsorted array with sorted contents instead of allocating + /* TODO Should be able to overwrite the unsorted array with sorted contents instead of + * allocating */ val ret = sorter.toRegion(cb, t) 
cb.append(tmpRegion.invalidate()) ret } - def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = { + def _buildSkip(cb: EmitCodeBuilder, r: Value[Region], in: Value[InputBuffer]): Unit = arrayRepr._buildSkip(cb, r, in) - } def _asIdent = s"set_of_${elementType.asIdent}" def _toPretty = s"EUnsortedSet[$elementType]" diff --git a/hail/src/main/scala/is/hail/types/package.scala b/hail/src/main/scala/is/hail/types/package.scala index ed088d70098..b4de55509c0 100644 --- a/hail/src/main/scala/is/hail/types/package.scala +++ b/hail/src/main/scala/is/hail/types/package.scala @@ -1,6 +1,7 @@ package is.hail import is.hail.types.physical.PType +import is.hail.types.physical.stypes.SType import is.hail.types.virtual.Type package object types { @@ -8,5 +9,7 @@ package object types { def tcoerce[T <: PType](x: PType): T = x.asInstanceOf[T] + def tcoerce[T <: SType](x: SType): T = x.asInstanceOf[T] + def tcoerce[T <: BaseTypeWithRequiredness](x: BaseTypeWithRequiredness): T = x.asInstanceOf[T] } diff --git a/hail/src/main/scala/is/hail/types/physical/PArray.scala b/hail/src/main/scala/is/hail/types/physical/PArray.scala index e0f983e5b85..c069b741c5b 100644 --- a/hail/src/main/scala/is/hail/types/physical/PArray.scala +++ b/hail/src/main/scala/is/hail/types/physical/PArray.scala @@ -3,8 +3,6 @@ package is.hail.types.physical import is.hail.annotations.Annotation import is.hail.backend.HailStateManager import is.hail.check.Gen -import is.hail.expr.ir.EmitMethodBuilder -import is.hail.expr.ir.orderings.CodeOrdering import is.hail.types.virtual.TArray trait PArrayIterator { @@ -16,7 +14,7 @@ trait PArrayIterator { abstract class PArray extends PContainer { lazy val virtualType: TArray = TArray(elementType.virtualType) - protected[physical] final val elementRequired = elementType.required + final protected[physical] val elementRequired = elementType.required def elementIterator(aoff: Long, length: Int): PArrayIterator diff --git a/hail/src/main/scala/is/hail/types/physical/PArrayBackedContainer.scala b/hail/src/main/scala/is/hail/types/physical/PArrayBackedContainer.scala index 237853a4a68..194c99324d0 100644 --- a/hail/src/main/scala/is/hail/types/physical/PArrayBackedContainer.scala +++ b/hail/src/main/scala/is/hail/types/physical/PArrayBackedContainer.scala @@ -45,9 +45,8 @@ trait PArrayBackedContainer extends PContainer { override def setElementMissing(cb: EmitCodeBuilder, aoff: Code[Long], i: Code[Int]): Unit = arrayRep.setElementMissing(cb, aoff, i) - override def setElementPresent(aoff: Long, i: Int) { - arrayRep.setElementPresent(aoff, i) - } + override def setElementPresent(aoff: Long, i: Int): Unit = + arrayRep.setElementPresent(aoff, i) override def setElementPresent(cb: EmitCodeBuilder, aoff: Code[Long], i: Code[Int]): Unit = arrayRep.setElementPresent(cb, aoff, i) @@ -106,7 +105,12 @@ trait PArrayBackedContainer extends PContainer { override def initialize(aoff: Long, length: Int, setMissing: Boolean = false) = arrayRep.initialize(aoff, length, setMissing) - override def stagedInitialize(cb: EmitCodeBuilder, aoff: Code[Long], length: Code[Int], setMissing: Boolean = false): Unit = + override def stagedInitialize( + cb: EmitCodeBuilder, + aoff: Code[Long], + length: Code[Int], + setMissing: Boolean = false, + ): Unit = arrayRep.stagedInitialize(cb, aoff, length, setMissing) override def zeroes(region: Region, length: Int): Long = @@ -124,8 +128,20 @@ trait PArrayBackedContainer extends PContainer { override def unsafeOrdering(sm: HailStateManager, rightType: 
PType): UnsafeOrdering = arrayRep.unsafeOrdering(sm, rightType) - override def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = - arrayRep.copyFromAddress(sm, region, srcPType.asInstanceOf[PArrayBackedContainer].arrayRep, srcAddress, deepCopy) + override def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = + arrayRep.copyFromAddress( + sm, + region, + srcPType.asInstanceOf[PArrayBackedContainer].arrayRep, + srcAddress, + deepCopy, + ) override def nextElementAddress(currentOffset: Long): Long = arrayRep.nextElementAddress(currentOffset) @@ -145,10 +161,25 @@ trait PArrayBackedContainer extends PContainer { override def pastLastElementOffset(aoff: Code[Long], length: Value[Int]): Code[Long] = arrayRep.pastLastElementOffset(aoff, length) - override def unstagedStoreAtAddress(sm: HailStateManager, addr: Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit = - arrayRep.unstagedStoreAtAddress(sm, addr, region, srcPType.asInstanceOf[PArrayBackedContainer].arrayRep, srcAddress, deepCopy) - - override def sType: SIndexablePointer = SIndexablePointer(setRequired(false).asInstanceOf[PArrayBackedContainer]) + override def unstagedStoreAtAddress( + sm: HailStateManager, + addr: Long, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit = + arrayRep.unstagedStoreAtAddress( + sm, + addr, + region, + srcPType.asInstanceOf[PArrayBackedContainer].arrayRep, + srcAddress, + deepCopy, + ) + + override def sType: SIndexablePointer = + SIndexablePointer(setRequired(false).asInstanceOf[PArrayBackedContainer]) override def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SIndexablePointerValue = { val a = cb.memoize(addr) @@ -157,17 +188,28 @@ trait PArrayBackedContainer extends PContainer { new SIndexablePointerValue(sType, a, length, elementsAddr) } - override def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] = + override def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] = arrayRep.store(cb, region, value.asIndexable.castToArray(cb), deepCopy) - override def storeAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit = + override def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit = arrayRep.storeAtAddress(cb, addr, region, value.asIndexable.castToArray(cb), deepCopy) override def loadFromNested(addr: Code[Long]): Code[Long] = arrayRep.loadFromNested(addr) override def unstagedLoadFromNested(addr: Long): Long = arrayRep.unstagedLoadFromNested(addr) - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = Region.storeAddress(addr, unstagedStoreJavaObject(sm, annotation, region)) - } } diff --git a/hail/src/main/scala/is/hail/types/physical/PBaseStruct.scala b/hail/src/main/scala/is/hail/types/physical/PBaseStruct.scala index ded70b4d430..2a6c1e7ca23 100644 --- a/hail/src/main/scala/is/hail/types/physical/PBaseStruct.scala +++ b/hail/src/main/scala/is/hail/types/physical/PBaseStruct.scala @@ -9,12 +9,11 @@ import 
is.hail.types.physical.stypes.interfaces.SBaseStructValue import is.hail.utils._ object PBaseStruct { - def alignment(types: Array[PType]): Long = { + def alignment(types: Array[PType]): Long = if (types.isEmpty) 1 else types.map(_.alignment).max - } } abstract class PBaseStruct extends PType { @@ -44,9 +43,8 @@ abstract class PBaseStruct extends PType { def size: Int = fields.length - def isIsomorphicTo(other: PBaseStruct) = { + def isIsomorphicTo(other: PBaseStruct): Boolean = this.fields.size == other.fields.size && this.isCompatibleWith(other) - } def _toPretty: String = { val sb = new StringBuilder @@ -60,9 +58,7 @@ abstract class PBaseStruct extends PType { val sb = new StringBuilder sb.append(identBase) sb.append("_of_") - types.foreachBetween { ty => - sb.append(ty.asIdent) - } { + types.foreachBetween(ty => sb.append(ty.asIdent)) { sb.append("AND") } sb.append("END") @@ -73,7 +69,7 @@ abstract class PBaseStruct extends PType { size <= other.size && isCompatibleWith(other) def isCompatibleWith(other: PBaseStruct): Boolean = - fields.zip(other.fields).forall{ case (l, r) => l.typ isOfType r.typ } + fields.zip(other.fields).forall { case (l, r) => l.typ isOfType r.typ } override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = unsafeOrdering(sm, this) @@ -83,7 +79,7 @@ abstract class PBaseStruct extends PType { val right = rightType.asInstanceOf[PBaseStruct] val fieldOrderings: Array[UnsafeOrdering] = - types.zip(right.types).map { case (l, r) => l.unsafeOrdering(sm, r)} + types.zip(right.types).map { case (l, r) => l.unsafeOrdering(sm, r) } new UnsafeOrdering { def compare(o1: Long, o2: Long): Int = { @@ -117,7 +113,8 @@ abstract class PBaseStruct extends PType { def initialize(structAddress: Long, setMissing: Boolean = false): Unit - def stagedInitialize(cb: EmitCodeBuilder, structAddress: Code[Long], setMissing: Boolean = false): Unit + def stagedInitialize(cb: EmitCodeBuilder, structAddress: Code[Long], setMissing: Boolean = false) + : Unit def isFieldDefined(offset: Long, fieldIdx: Int): Boolean @@ -148,10 +145,9 @@ abstract class PBaseStruct extends PType { override lazy val containsPointers: Boolean = types.exists(_.containsPointers) - override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = { + override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = if (types.isEmpty) { Gen.const(Annotation.empty) } else Gen.uniformSequence(types.map(t => t.genValue(sm))).map(a => Annotation(a: _*)) - } } diff --git a/hail/src/main/scala/is/hail/types/physical/PBoolean.scala b/hail/src/main/scala/is/hail/types/physical/PBoolean.scala index 1410581d9da..c82c0fd0d82 100644 --- a/hail/src/main/scala/is/hail/types/physical/PBoolean.scala +++ b/hail/src/main/scala/is/hail/types/physical/PBoolean.scala @@ -13,36 +13,40 @@ case object PBooleanOptional extends PBoolean(false) case object PBooleanRequired extends PBoolean(true) class PBoolean(override val required: Boolean) extends PType with PPrimitive { - lazy val virtualType: TBoolean.type = TBoolean + lazy val virtualType: TBoolean.type = TBoolean def _asIdent = "bool" - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = sb.append("PBoolean") + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = + sb.append("PBoolean") override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = new UnsafeOrdering { - def compare(o1: Long, o2: Long): Int = { + def compare(o1: Long, o2: Long): Int = java.lang.Boolean.compare(Region.loadBoolean(o1), 
Region.loadBoolean(o2)) - } } override def byteSize: Long = 1 def sType: SBoolean.type = SBoolean - def storePrimitiveAtAddress(cb: EmitCodeBuilder, addr: Code[Long], value: SValue): Unit = { + def storePrimitiveAtAddress(cb: EmitCodeBuilder, addr: Code[Long], value: SValue): Unit = cb += Region.storeBoolean(addr, value.asBoolean.value) - } override def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SBooleanValue = new SBooleanValue(cb.memoize(Region.loadBoolean(addr))) - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = Region.storeByte(addr, annotation.asInstanceOf[Boolean].toByte) - } } object PBoolean { - def apply(required: Boolean = false): PBoolean = if (required) PBooleanRequired else PBooleanOptional + def apply(required: Boolean = false): PBoolean = + if (required) PBooleanRequired else PBooleanOptional def unapply(t: PBoolean): Option[Boolean] = Option(t.required) } diff --git a/hail/src/main/scala/is/hail/types/physical/PCall.scala b/hail/src/main/scala/is/hail/types/physical/PCall.scala index 93f6cdc5662..b1b4d06c78f 100644 --- a/hail/src/main/scala/is/hail/types/physical/PCall.scala +++ b/hail/src/main/scala/is/hail/types/physical/PCall.scala @@ -4,4 +4,4 @@ import is.hail.types.virtual.TCall abstract class PCall extends PType { lazy val virtualType: TCall.type = TCall -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/types/physical/PCanonicalArray.scala b/hail/src/main/scala/is/hail/types/physical/PCanonicalArray.scala index e0985a7895b..d532770e8be 100644 --- a/hail/src/main/scala/is/hail/types/physical/PCanonicalArray.scala +++ b/hail/src/main/scala/is/hail/types/physical/PCanonicalArray.scala @@ -16,7 +16,7 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) def _asIdent = s"array_of_${elementType.asIdent}" - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("PCArray[") elementType.pretty(sb, indent, compact) sb.append("]") @@ -26,13 +26,20 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) val a = cb.newLocal[Long]("a", addr) val l = cb.memoize(loadLength(addr)) cb.println("array header:") - cb.while_(a < firstElementOffset(addr, l), { - cb.println(" ", Code.invokeStatic1[java.lang.Long, Long, String]("toHexString", Region.loadLong(a)), - " (", Code.invokeStatic1[java.lang.Integer, Int, String]("toHexString", Region.loadInt(a)), - " ", Code.invokeStatic1[java.lang.Integer, Int, String]("toHexString", Region.loadInt(a+4)), - ")") - cb.assign(a, a + 8) - }) + cb.while_( + a < firstElementOffset(addr, l), { + cb.println( + " ", + Code.invokeStatic1[java.lang.Long, Long, String]("toHexString", Region.loadLong(a)), + " (", + Code.invokeStatic1[java.lang.Integer, Int, String]("toHexString", Region.loadInt(a)), + " ", + Code.invokeStatic1[java.lang.Integer, Int, String]("toHexString", Region.loadInt(a + 4)), + ")", + ) + cb.assign(a, a + 8) + }, + ) } private val elementByteSize: Long = UnsafeUtils.arrayElementSize(elementType) @@ -67,11 +74,10 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) def contentsByteSize(length: Int): Long = elementsOffset(length) + length * elementByteSize - def 
contentsByteSize(length: Code[Int]): Code[Long] = { + def contentsByteSize(length: Code[Int]): Code[Long] = Code.memoize(length, "contentsByteSize_arr_len") { length => elementsOffset(length) + length.toL * elementByteSize } - } private def _elementsOffset(length: Int): Long = if (elementRequired) @@ -86,18 +92,18 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) UnsafeUtils.roundUpAlignment(nMissingBytes(length).toL + lengthHeaderBytes, contentsAlignment) private lazy val lengthOffsetTable = 10 - private lazy val elementsOffsetTable: Array[Long] = Array.tabulate[Long](lengthOffsetTable)(i => _elementsOffset(i)) - def elementsOffset(length: Int): Long = { + private lazy val elementsOffsetTable: Array[Long] = + Array.tabulate[Long](lengthOffsetTable)(i => _elementsOffset(i)) + + def elementsOffset(length: Int): Long = if (length < lengthOffsetTable) elementsOffsetTable(length) else _elementsOffset(length) - } - def elementsOffset(length: Code[Int]): Code[Long] = { + def elementsOffset(length: Code[Int]): Code[Long] = _elementsOffset(length) - } def missingBytesOffset: Long = lengthHeaderBytes @@ -119,20 +125,18 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) def isElementMissing(aoff: Code[Long], i: Code[Int]): Code[Boolean] = !isElementDefined(aoff, i) - def setElementMissing(aoff: Long, i: Int) { + def setElementMissing(aoff: Long, i: Int): Unit = if (!elementRequired) Region.setBit(aoff + lengthHeaderBytes, i) - } override def setElementMissing(cb: EmitCodeBuilder, aoff: Code[Long], i: Code[Int]): Unit = { assert(!elementRequired, s"Array elements of ptype '${elementType.asIdent}' cannot be missing.") cb += Region.setBit(aoff + lengthHeaderBytes, i.toL) } - def setElementPresent(aoff: Long, i: Int) { + def setElementPresent(aoff: Long, i: Int): Unit = if (!elementRequired) Region.clearBit(aoff + lengthHeaderBytes, i.toLong) - } def setElementPresent(cb: EmitCodeBuilder, aoff: Code[Long], i: Code[Int]): Unit = if (!elementRequired) @@ -166,7 +170,8 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) firstElementOffset(aoff, loadLength(aoff)) + i.toL * const(elementByteSize) } - private def elementOffsetFromFirst(firstElementAddr: Code[Long], i: Code[Int]): Code[Long] = firstElementAddr + i.toL * const(elementByteSize) + private def elementOffsetFromFirst(firstElementAddr: Code[Long], i: Code[Int]): Code[Long] = + firstElementAddr + i.toL * const(elementByteSize) override def incrementElementOffset(currentOffset: Long, increment: Int): Long = currentOffset + increment * elementByteSize @@ -199,48 +204,47 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) } def loadElement(aoff: Code[Long], i: Code[Int]): Code[Long] = - Code.memoize(aoff, "pcarr_load_elem_aoff") { aoff => - loadElement(aoff, loadLength(aoff), i) - } + Code.memoize(aoff, "pcarr_load_elem_aoff")(aoff => loadElement(aoff, loadLength(aoff), i)) - class Iterator ( + class Iterator( private[this] val aoff: Long, private[this] val length: Int, - private[this] var i: Int = 0 + private[this] var i: Int = 0, ) extends PArrayIterator { private[this] val firstElementOffset = PCanonicalArray.this.firstElementOffset( - aoff, length) + aoff, + length, + ) + def hasNext: Boolean = i != length def isDefined: Boolean = isElementDefined(aoff, i) + def value: Long = firstElementOffset + i * elementByteSize + def iterate: Unit = i += 1 } def elementIterator(aoff: Long, length: Int): Iterator = new Iterator(aoff, 
length) - def allocate(region: Region, length: Int): Long = { + def allocate(region: Region, length: Int): Long = region.allocate(contentsAlignment, contentsByteSize(length)) - } def allocate(region: Code[Region], length: Code[Int]): Code[Long] = region.allocate(contentsAlignment, contentsByteSize(length)) - private def writeMissingness(aoff: Long, length: Int, value: Byte) { + private def writeMissingness(aoff: Long, length: Int, value: Byte): Unit = Region.setMemory(aoff + lengthHeaderBytes, nMissingBytes(length), value) - } - def setAllMissingBits(aoff: Long, length: Int) { + def setAllMissingBits(aoff: Long, length: Int): Unit = if (!elementRequired) writeMissingness(aoff, length, -1) - } - def clearMissingBits(aoff: Long, length: Int) { + def clearMissingBits(aoff: Long, length: Int): Unit = if (!elementRequired) writeMissingness(aoff, length, 0) - } - def initialize(aoff: Long, length: Int, setMissing: Boolean = false) { + def initialize(aoff: Long, length: Int, setMissing: Boolean = false): Unit = { Region.storeInt(aoff, length) if (setMissing) setAllMissingBits(aoff, length) @@ -248,14 +252,23 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) clearMissingBits(aoff, length) } - override def stagedInitialize(cb: EmitCodeBuilder, aoff: Code[Long], length: Code[Int], setMissing: Boolean = false): Unit = { + override def stagedInitialize( + cb: EmitCodeBuilder, + aoff: Code[Long], + length: Code[Int], + setMissing: Boolean = false, + ): Unit = { if (elementRequired) cb += Region.storeInt(aoff, length) else { val aoffMem = cb.memoize[Long](aoff) val lengthMem = cb.memoize[Int](length) cb += Region.storeInt(aoffMem, lengthMem) - cb += Region.setMemory(aoffMem + const(lengthHeaderBytes), nMissingBytes(lengthMem).toL, const(if (setMissing) (-1).toByte else 0.toByte)) + cb += Region.setMemory( + aoffMem + const(lengthHeaderBytes), + nMissingBytes(lengthMem).toL, + const(if (setMissing) (-1).toByte else 0.toByte), + ) } } @@ -272,7 +285,11 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) val lengthMem = cb.memoize(length) val aoff = cb.memoize[Long](allocate(region, lengthMem)) stagedInitialize(cb, aoff, lengthMem) - cb += Region.setMemory(aoff + elementsOffset(lengthMem), lengthMem.toL * elementByteSize, 0.toByte) + cb += Region.setMemory( + aoff + elementsOffset(lengthMem), + lengthMem.toL * elementByteSize, + 0.toByte, + ) aoff } @@ -319,7 +336,12 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) } } - def deepPointerCopy(cb: EmitCodeBuilder, region: Value[Region], dstAddressCode: Code[Long], len: Value[Int]): Unit = { + def deepPointerCopy( + cb: EmitCodeBuilder, + region: Value[Region], + dstAddressCode: Code[Long], + len: Value[Int], + ): Unit = { if (!elementType.containsPointers) { return } @@ -328,37 +350,52 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) cb.assign(dstAddress, dstAddressCode) val currentIdx = cb.newLocal[Int]("pcarray_deep_pointer_copy_current_idx") val currentElementAddress = cb.newLocal[Long]("pcarray_deep_pointer_copy_current_element_addr") - cb.for_(cb.assign(currentIdx, 0), currentIdx < len, cb.assign(currentIdx, currentIdx + 1), - cb.if_(isElementDefined(dstAddress, currentIdx), - { + cb.for_( + cb.assign(currentIdx, 0), + currentIdx < len, + cb.assign(currentIdx, currentIdx + 1), + cb.if_( + isElementDefined(dstAddress, currentIdx), { cb.assign(currentElementAddress, elementOffset(dstAddress, len, currentIdx)) - 
elementType.storeAtAddress(cb, currentElementAddress, region, + elementType.storeAtAddress( + cb, + currentElementAddress, + region, elementType.loadCheapSCode(cb, elementType.loadFromNested(currentElementAddress)), - deepCopy = true + deepCopy = true, ) - }) + }, + ), ) } - def deepPointerCopy(sm: HailStateManager, region: Region, dstAddress: Long) { - if(!this.elementType.containsPointers) { + def deepPointerCopy(sm: HailStateManager, region: Region, dstAddress: Long): Unit = { + if (!this.elementType.containsPointers) { return } val numberOfElements = this.loadLength(dstAddress) var currentIdx = 0 - while(currentIdx < numberOfElements) { - if(this.isElementDefined(dstAddress, currentIdx)) { + while (currentIdx < numberOfElements) { + if (this.isElementDefined(dstAddress, currentIdx)) { val currentElementAddress = this.elementOffset(dstAddress, numberOfElements, currentIdx) - val currentElementAddressFromNested = this.elementType.unstagedLoadFromNested(currentElementAddress) - this.elementType.unstagedStoreAtAddress(sm, currentElementAddress, region, this.elementType, currentElementAddressFromNested, true) + val currentElementAddressFromNested = + this.elementType.unstagedLoadFromNested(currentElementAddress) + this.elementType.unstagedStoreAtAddress(sm, currentElementAddress, region, this.elementType, + currentElementAddressFromNested, true) } currentIdx += 1 } } - def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = { + def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = { val srcArrayT = srcPType.asInstanceOf[PArray] if (equalModuloRequired(srcArrayT)) { @@ -380,7 +417,14 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) while (i < len) { if (srcArrayT.isElementDefined(srcAddress, i)) { setElementPresent(newAddr, i) - elementType.unstagedStoreAtAddress(sm, elementOffset(newAddr, len, i), region, srcElementT, srcArrayT.loadElement(srcAddress, len, i), deepCopy) + elementType.unstagedStoreAtAddress( + sm, + elementOffset(newAddr, len, i), + region, + srcElementT, + srcArrayT.loadElement(srcAddress, len, i), + deepCopy, + ) } else assert(!elementType.required) @@ -399,42 +443,73 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) new SIndexablePointerValue(sType, a, length, offset) } - def storeContentsAtAddress(cb: EmitCodeBuilder, addr: Value[Long], region: Value[Region], indexable: SIndexableValue, deepCopy: Boolean): Unit = { + def storeContentsAtAddress( + cb: EmitCodeBuilder, + addr: Value[Long], + region: Value[Region], + indexable: SIndexableValue, + deepCopy: Boolean, + ): Unit = { val length = indexable.loadLength() indexable.st match { - case SIndexablePointer(PCanonicalArray(otherElementType, _)) if otherElementType == elementType => - cb += Region.copyFrom(indexable.asInstanceOf[SIndexablePointerValue].a, addr, contentsByteSize(length)) - deepPointerCopy(cb, region, addr, length) - case SIndexablePointer(otherType@PCanonicalArray(otherElementType, _)) if otherElementType.equalModuloRequired(elementType) => + case SIndexablePointer(PCanonicalArray(otherElementType, _)) + if otherElementType == elementType => + cb += Region.copyFrom( + indexable.asInstanceOf[SIndexablePointerValue].a, + addr, + contentsByteSize(length), + ) + deepPointerCopy(cb, region, addr, length) + case SIndexablePointer(otherType @ PCanonicalArray(otherElementType, _)) + if 
otherElementType.equalModuloRequired(elementType) => // other is optional, constructing required if (elementType.required) { - cb.if_(indexable.hasMissingValues(cb), - cb._fatal("tried to copy array with missing values to array of required elements")) + cb.if_( + indexable.hasMissingValues(cb), + cb._fatal("tried to copy array with missing values to array of required elements"), + ) } stagedInitialize(cb, addr, indexable.loadLength(), setMissing = false) - cb += Region.copyFrom(otherType.firstElementOffset(indexable.asInstanceOf[SIndexablePointerValue].a), this.firstElementOffset(addr), length.toL * otherType.elementByteSize) + cb += Region.copyFrom( + otherType.firstElementOffset(indexable.asInstanceOf[SIndexablePointerValue].a), + this.firstElementOffset(addr), + length.toL * otherType.elementByteSize, + ) if (deepCopy) deepPointerCopy(cb, region, addr, length) case _ => stagedInitialize(cb, addr, length, setMissing = false) val idx = cb.newLocal[Int]("pcarray_store_at_addr_idx") - cb.for_(cb.assign(idx, 0), idx < length, cb.assign(idx, idx + 1), + cb.for_( + cb.assign(idx, 0), + idx < length, + cb.assign(idx, idx + 1), indexable .loadElement(cb, idx) - .consume(cb, + .consume( + cb, PContainer.unsafeSetElementMissing(cb, this, addr, idx), - pc => elementType.storeAtAddress(cb, elementOffset(addr, length, idx), region, pc, deepCopy) - ) + pc => + elementType.storeAtAddress( + cb, + elementOffset(addr, length, idx), + region, + pc, + deepCopy, + ), + ), ) } } - def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] = { + def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] = { assert(value.st.virtualType.isInstanceOf[TArray]) value.st match { - case SIndexablePointer(PCanonicalArray(otherElementType, _)) if otherElementType == elementType && !deepCopy => + case SIndexablePointer(PCanonicalArray(otherElementType, _)) + if otherElementType == elementType && !deepCopy => value.asInstanceOf[SIndexablePointerValue].a case _ => val idxValue = value.asIndexable @@ -444,12 +519,23 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) } } - def storeAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit = { + def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit = cb += Region.storeAddress(addr, store(cb, region, value, deepCopy)) - } - - def unstagedStoreAtAddress(sm: HailStateManager, addr: Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit = { + def unstagedStoreAtAddress( + sm: HailStateManager, + addr: Long, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit = { val srcArray = srcPType.asInstanceOf[PArray] Region.storeAddress(addr, copyFromAddress(sm, region, srcArray, srcAddress, deepCopy)) } @@ -459,61 +545,101 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) private def deepRenameArray(t: TArray): PArray = PCanonicalArray(this.elementType.deepRename(t.elementType), this.required) - def padWithMissing(cb: EmitCodeBuilder, region: Value[Region], oldLength: Value[Int], newLength: Value[Int], srcAddress: Value[Long]): Value[Long] = { + def padWithMissing( + cb: EmitCodeBuilder, + region: Value[Region], + oldLength: Value[Int], + newLength: Value[Int], + srcAddress: Value[Long], + ): Value[Long] = { val dstAddress = 
cb.memoize(allocate(region, newLength)) stagedInitialize(cb, dstAddress, newLength, setMissing = true) - cb += Region.copyFrom(srcAddress + lengthHeaderBytes, dstAddress + lengthHeaderBytes, nMissingBytes(oldLength).toL) + cb += Region.copyFrom( + srcAddress + lengthHeaderBytes, + dstAddress + lengthHeaderBytes, + nMissingBytes(oldLength).toL, + ) cb += Region.copyFrom( srcAddress + elementsOffset(oldLength), dstAddress + elementsOffset(newLength), - oldLength.toL * elementByteSize) + oldLength.toL * elementByteSize, + ) dstAddress } - def constructFromElements(cb: EmitCodeBuilder, region: Value[Region], length: Value[Int], deepCopy: Boolean) - (f: (EmitCodeBuilder, Value[Int]) => IEmitCode): SIndexablePointerValue = { + def constructFromElements( + cb: EmitCodeBuilder, + region: Value[Region], + length: Value[Int], + deepCopy: Boolean, + )( + f: (EmitCodeBuilder, Value[Int]) => IEmitCode + ): SIndexablePointerValue = { val addr = cb.newLocal[Long]("pcarray_construct1_addr", allocate(region, length)) stagedInitialize(cb, addr, length, setMissing = false) val i = cb.newLocal[Int]("pcarray_construct1_i") - val firstElementAddr = cb.newLocal[Long]("pcarray_construct1_firstelementaddr", firstElementOffset(addr, length)) - cb.for_(cb.assign(i, 0), i < length, cb.assign(i, i + 1), { - f(cb, i).consume(cb, + val firstElementAddr = + cb.newLocal[Long]("pcarray_construct1_firstelementaddr", firstElementOffset(addr, length)) + cb.for_( + cb.assign(i, 0), + i < length, + cb.assign(i, i + 1), + f(cb, i).consume( + cb, PContainer.unsafeSetElementMissing(cb, this, addr, i), - { sc => elementType.storeAtAddress(cb, elementOffsetFromFirst(firstElementAddr, i), region, sc, deepCopy = deepCopy) } - ) - }) + sc => + elementType.storeAtAddress( + cb, + elementOffsetFromFirst(firstElementAddr, i), + region, + sc, + deepCopy = deepCopy, + ), + ), + ) new SIndexablePointerValue(sType, addr, length, firstElementAddr) } - // unsafe StagedArrayBuilder-like interface that gives caller control over pushing elements and finishing - def constructFromFunctions(cb: EmitCodeBuilder, region: Value[Region], length: Value[Int], deepCopy: Boolean): - ((EmitCodeBuilder, IEmitCode) => Unit, EmitCodeBuilder => SIndexablePointerValue) = { + /* unsafe StagedArrayBuilder-like interface that gives caller control over pushing elements and + * finishing */ + def constructFromFunctions( + cb: EmitCodeBuilder, + region: Value[Region], + length: Value[Int], + deepCopy: Boolean, + ): ((EmitCodeBuilder, IEmitCode) => Unit, EmitCodeBuilder => SIndexablePointerValue) = { val addr = cb.newLocal[Long]("pcarray_construct2_addr", allocate(region, length)) stagedInitialize(cb, addr, length, setMissing = false) val currentElementIndex = cb.newLocal[Int]("pcarray_construct2_current_idx", 0) - val firstElementAddress = cb.newLocal[Long]("pcarray_construct2_first_addr", firstElementOffset(addr, length)) - val currentElementAddress = cb.newLocal[Long]("pcarray_construct2_current_addr", firstElementAddress) + val firstElementAddress = + cb.newLocal[Long]("pcarray_construct2_first_addr", firstElementOffset(addr, length)) + val currentElementAddress = + cb.newLocal[Long]("pcarray_construct2_current_addr", firstElementAddress) def push(cb: EmitCodeBuilder, iec: IEmitCode): Unit = { - iec.consume(cb, + iec.consume( + cb, PContainer.unsafeSetElementMissing(cb, this, addr, currentElementIndex), - { sc => - elementType.storeAtAddress(cb, currentElementAddress, region, sc, deepCopy = deepCopy) - }) + sc => elementType.storeAtAddress(cb, currentElementAddress, 
region, sc, deepCopy = deepCopy), + ) cb.assign(currentElementIndex, currentElementIndex + 1) cb.assign(currentElementAddress, currentElementAddress + elementByteSize) } def finish(cb: EmitCodeBuilder): SIndexablePointerValue = { - cb.if_(currentElementIndex cne length, - cb._fatal("PCanonicalArray.constructFromFunctions push was called the wrong number of times", - ": len=", length.toS, - ", calls=", currentElementIndex.toS - ) + cb.if_( + currentElementIndex cne length, + cb._fatal( + "PCanonicalArray.constructFromFunctions push was called the wrong number of times", + ": len=", + length.toS, + ", calls=", + currentElementIndex.toS, + ), ) new SIndexablePointerValue(sType, addr, length, firstElementAddress) } @@ -521,21 +647,39 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) (push, finish) } - def constructFromIndicesUnsafe(cb: EmitCodeBuilder, region: Value[Region], length: Value[Int], deepCopy: Boolean): - (((EmitCodeBuilder, Value[Int], IEmitCode) => Unit, (EmitCodeBuilder => SIndexablePointerValue))) = { + def constructFromIndicesUnsafe( + cb: EmitCodeBuilder, + region: Value[Region], + length: Value[Int], + deepCopy: Boolean, + ): ( + ( + (EmitCodeBuilder, Value[Int], IEmitCode) => Unit, + (EmitCodeBuilder => SIndexablePointerValue), + ), + ) = { val addr = cb.newLocal[Long]("pcarray_construct2_addr", allocate(region, length)) stagedInitialize(cb, addr, length, setMissing = false) - val firstElementAddress = cb.newLocal[Long]("pcarray_construct2_first_addr", firstElementOffset(addr, length)) + val firstElementAddress = + cb.newLocal[Long]("pcarray_construct2_first_addr", firstElementOffset(addr, length)) val push: (EmitCodeBuilder, Value[Int], IEmitCode) => Unit = { case (cb, idx, iec) => - iec.consume(cb, + iec.consume( + cb, PContainer.unsafeSetElementMissing(cb, this, addr, idx), { sc => setElementPresent(cb, addr, idx) - elementType.storeAtAddress(cb, firstElementAddress + idx.toL * elementByteSize, region, sc, deepCopy = deepCopy) - }) + elementType.storeAtAddress( + cb, + firstElementAddress + idx.toL * elementByteSize, + region, + sc, + deepCopy = deepCopy, + ) + }, + ) } val finish: EmitCodeBuilder => SIndexablePointerValue = { (cb: EmitCodeBuilder) => new SIndexablePointerValue(sType, addr, length, firstElementAddress) @@ -543,12 +687,12 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) (push, finish) } - def loadFromNested(addr: Code[Long]): Code[Long] = Region.loadAddress(addr) override def unstagedLoadFromNested(addr: Long): Long = Region.loadAddress(addr) - override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = { + override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region) + : Long = { val is = annotation.asInstanceOf[IndexedSeq[Annotation]] val valueAddress = allocate(region, is.length) assert(is.length >= 0) @@ -559,8 +703,7 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) while (i < is.length) { if (is(i) == null) { setElementMissing(valueAddress, i) - } - else { + } else { elementType.unstagedStoreJavaObjectAtAddress(sm, curElementAddress, is(i), region) } curElementAddress = nextElementAddress(curElementAddress) @@ -570,12 +713,18 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) valueAddress } - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { - annotation match 
{ - case uis: UnsafeIndexedSeq => this.unstagedStoreAtAddress(sm, addr, region, uis.t, uis.aoff, region.ne(uis.region)) - case is: IndexedSeq[Annotation] => Region.storeAddress(addr, unstagedStoreJavaObject(sm, annotation, region)) - } - } + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = + annotation match { + case uis: UnsafeIndexedSeq => + this.unstagedStoreAtAddress(sm, addr, region, uis.t, uis.aoff, region.ne(uis.region)) + case _: IndexedSeq[Annotation] => + Region.storeAddress(addr, unstagedStoreJavaObject(sm, annotation, region)) + } override def copiedType: PType = { val copiedElement = elementType.copiedType @@ -585,21 +734,30 @@ final case class PCanonicalArray(elementType: PType, required: Boolean = false) PCanonicalArray(copiedElement, required) } - def forEachDefined(cb: EmitCodeBuilder, aoff: Value[Long])(f: (EmitCodeBuilder, Value[Int], SValue) => Unit) { + def forEachDefined( + cb: EmitCodeBuilder, + aoff: Value[Long], + )( + f: (EmitCodeBuilder, Value[Int], SValue) => Unit + ): Unit = { val length = cb.memoize(loadLength(aoff)) val elementsAddress = cb.memoize(firstElementOffset(aoff)) val idx = cb.newLocal[Int]("foreach_pca_idx", 0) val elementPtr = cb.newLocal[Long]("foreach_pca_elt_ptr", elementsAddress) val et = elementType - cb.while_(idx < length, { - cb.if_(isElementMissing(aoff, idx), - {}, // do nothing, - { - val elt = et.loadCheapSCode(cb, et.loadFromNested(elementPtr)) - f(cb, idx, elt) - }) - cb.assign(idx, idx + 1) - cb.assign(elementPtr, elementPtr + elementByteSize) - }) + cb.while_( + idx < length, { + cb.if_( + isElementMissing(aoff, idx), + {}, // do nothing, + { + val elt = et.loadCheapSCode(cb, et.loadFromNested(elementPtr)) + f(cb, idx, elt) + }, + ) + cb.assign(idx, idx + 1) + cb.assign(elementPtr, elementPtr + elementByteSize) + }, + ) } } diff --git a/hail/src/main/scala/is/hail/types/physical/PCanonicalBaseStruct.scala b/hail/src/main/scala/is/hail/types/physical/PCanonicalBaseStruct.scala index 7d78d36a51e..9bd01589f00 100644 --- a/hail/src/main/scala/is/hail/types/physical/PCanonicalBaseStruct.scala +++ b/hail/src/main/scala/is/hail/types/physical/PCanonicalBaseStruct.scala @@ -8,24 +8,29 @@ import is.hail.types.BaseStruct import is.hail.types.physical.stypes.SValue import is.hail.types.physical.stypes.concrete.{SBaseStructPointer, SBaseStructPointerValue} import is.hail.utils._ + import org.apache.spark.sql.Row abstract class PCanonicalBaseStruct(val types: Array[PType]) extends PBaseStruct { if (!types.forall(_.isRealizable)) { throw new AssertionError( - s"found non realizable type(s) ${ types.filter(!_.isRealizable).mkString(", ") } in ${ types.mkString(", ") }") + s"found non realizable type(s) ${types.filter(!_.isRealizable).mkString(", ")} in ${types.mkString(", ")}" + ) } - override val (missingIdx: Array[Int], nMissing: Int) = BaseStruct.getMissingIndexAndCount(types.map(_.required)) + override val (missingIdx: Array[Int], nMissing: Int) = + BaseStruct.getMissingIndexAndCount(types.map(_.required)) + val nMissingBytes: Int = UnsafeUtils.packBitsToBytes(nMissing) val byteOffsets: Array[Long] = new Array[Long](size) - override val byteSize: Long = getByteSizeAndOffsets(types.map(_.byteSize), types.map(_.alignment), nMissingBytes, byteOffsets) - override val alignment: Long = PBaseStruct.alignment(types) + override val byteSize: Long = + getByteSizeAndOffsets(types.map(_.byteSize), types.map(_.alignment), nMissingBytes, byteOffsets) + + override 
val alignment: Long = PBaseStruct.alignment(types) - override def allocate(region: Region): Long = { + override def allocate(region: Region): Long = region.allocate(alignment, byteSize) - } override def allocate(region: Code[Region]): Code[Long] = region.allocate(alignment, byteSize) @@ -35,46 +40,51 @@ abstract class PCanonicalBaseStruct(val types: Array[PType]) extends PBaseStruct return } - Region.setMemory(structAddress, nMissingBytes.toLong, if (setMissing) 0xFF.toByte else 0.toByte) + Region.setMemory(structAddress, nMissingBytes.toLong, if (setMissing) 0xff.toByte else 0.toByte) } - override def stagedInitialize(cb: EmitCodeBuilder, structAddress: Code[Long], setMissing: Boolean = false): Unit = { + override def stagedInitialize( + cb: EmitCodeBuilder, + structAddress: Code[Long], + setMissing: Boolean = false, + ): Unit = if (!allFieldsRequired) { - cb += Region.setMemory(structAddress, const(nMissingBytes.toLong), const(if (setMissing) 0xFF.toByte else 0.toByte)) + cb += Region.setMemory( + structAddress, + const(nMissingBytes.toLong), + const(if (setMissing) 0xff.toByte else 0.toByte), + ) } - } override def isFieldDefined(offset: Long, fieldIdx: Int): Boolean = fieldRequired(fieldIdx) || !Region.loadBit(offset, missingIdx(fieldIdx)) - override def isFieldMissing(cb: EmitCodeBuilder, offset: Code[Long], fieldIdx: Int): Value[Boolean] = + override def isFieldMissing(cb: EmitCodeBuilder, offset: Code[Long], fieldIdx: Int) + : Value[Boolean] = if (fieldRequired(fieldIdx)) false else cb.memoize(Region.loadBit(offset, missingIdx(fieldIdx).toLong)) - override def setFieldMissing(offset: Long, fieldIdx: Int) { + override def setFieldMissing(offset: Long, fieldIdx: Int): Unit = { assert(!fieldRequired(fieldIdx)) Region.setBit(offset, missingIdx(fieldIdx)) } - override def setFieldMissing(cb: EmitCodeBuilder, offset: Code[Long], fieldIdx: Int): Unit = { + override def setFieldMissing(cb: EmitCodeBuilder, offset: Code[Long], fieldIdx: Int): Unit = if (!fieldRequired(fieldIdx)) cb += Region.setBit(offset, missingIdx(fieldIdx).toLong) else { cb._fatal(s"Required field cannot be missing.") } - } - override def setFieldPresent(offset: Long, fieldIdx: Int) { + override def setFieldPresent(offset: Long, fieldIdx: Int): Unit = if (!fieldRequired(fieldIdx)) Region.clearBit(offset, missingIdx(fieldIdx)) - } - override def setFieldPresent(cb: EmitCodeBuilder, offset: Code[Long], fieldIdx: Int): Unit = { + override def setFieldPresent(cb: EmitCodeBuilder, offset: Code[Long], fieldIdx: Int): Unit = if (!fieldRequired(fieldIdx)) cb += Region.clearBit(offset, missingIdx(fieldIdx).toLong) - } override def fieldOffset(structAddress: Long, fieldIdx: Int): Long = structAddress + byteOffsets(fieldIdx) @@ -87,49 +97,78 @@ abstract class PCanonicalBaseStruct(val types: Array[PType]) extends PBaseStruct types(fieldIdx).unstagedLoadFromNested(off) } - override def loadField(offset: Code[Long], fieldIdx: Int): Code[Long] = loadField(fieldOffset(offset, fieldIdx), types(fieldIdx)) + override def loadField(offset: Code[Long], fieldIdx: Int): Code[Long] = + loadField(fieldOffset(offset, fieldIdx), types(fieldIdx)) - private def loadField(fieldOffset: Code[Long], fieldType: PType): Code[Long] = { + private def loadField(fieldOffset: Code[Long], fieldType: PType): Code[Long] = fieldType.loadFromNested(fieldOffset) - } - def deepPointerCopy(cb: EmitCodeBuilder, region: Value[Region], dstStructAddress: Code[Long]): Unit = { + def deepPointerCopy(cb: EmitCodeBuilder, region: Value[Region], dstStructAddress: Code[Long]) + 
: Unit = { val dstAddr = cb.newLocal[Long]("pcbs_dpcopy_dst", dstStructAddress) fields.foreach { f => val dstFieldType = f.typ if (dstFieldType.containsPointers) { - cb.if_(isFieldDefined(cb, dstAddr, f.index), - { + cb.if_( + isFieldDefined(cb, dstAddr, f.index), { val fieldAddr = cb.newLocal[Long]("pcbs_dpcopy_field", fieldOffset(dstAddr, f.index)) - dstFieldType.storeAtAddress(cb, fieldAddr, region, dstFieldType.loadCheapSCode(cb, dstFieldType.loadFromNested(fieldAddr)), deepCopy = true) - }) + dstFieldType.storeAtAddress( + cb, + fieldAddr, + region, + dstFieldType.loadCheapSCode(cb, dstFieldType.loadFromNested(fieldAddr)), + deepCopy = true, + ) + }, + ) } } } - def deepPointerCopy(sm: HailStateManager, region: Region, dstStructAddress: Long) { + def deepPointerCopy(sm: HailStateManager, region: Region, dstStructAddress: Long): Unit = { var i = 0 while (i < this.size) { val dstFieldType = this.fields(i).typ if (dstFieldType.containsPointers && this.isFieldDefined(dstStructAddress, i)) { val dstFieldAddress = this.fieldOffset(dstStructAddress, i) val dstFieldAddressFromNested = dstFieldType.unstagedLoadFromNested(dstFieldAddress) - dstFieldType.unstagedStoreAtAddress(sm, dstFieldAddress, region, dstFieldType, dstFieldAddressFromNested, true) + dstFieldType.unstagedStoreAtAddress(sm, dstFieldAddress, region, dstFieldType, + dstFieldAddressFromNested, true) } i += 1 } } - override def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = { + override def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = { if (equalModuloRequired(srcPType) && !deepCopy) return srcAddress val newAddr = allocate(region) - unstagedStoreAtAddress(sm, newAddr, region, srcPType.asInstanceOf[PBaseStruct], srcAddress, deepCopy) + unstagedStoreAtAddress( + sm, + newAddr, + region, + srcPType.asInstanceOf[PBaseStruct], + srcAddress, + deepCopy, + ) newAddr } - override def unstagedStoreAtAddress(sm: HailStateManager, addr: Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit = { + override def unstagedStoreAtAddress( + sm: HailStateManager, + addr: Long, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit = { val srcStruct = srcPType.asInstanceOf[PBaseStruct] if (equalModuloRequired(srcStruct)) { Region.copyFrom(srcAddress, addr, byteSize) @@ -142,7 +181,13 @@ abstract class PCanonicalBaseStruct(val types: Array[PType]) extends PBaseStruct if (srcStruct.isFieldDefined(srcAddress, idx)) { setFieldPresent(addr, idx) types(idx).unstagedStoreAtAddress( - sm, fieldOffset(addr, idx), region, srcStruct.types(idx), srcStruct.loadField(srcAddress, idx), deepCopy) + sm, + fieldOffset(addr, idx), + region, + srcStruct.types(idx), + srcStruct.loadField(srcAddress, idx), + deepCopy, + ) } else assert(!fieldRequired(idx)) idx += 1 @@ -150,12 +195,14 @@ abstract class PCanonicalBaseStruct(val types: Array[PType]) extends PBaseStruct } } - override def sType: SBaseStructPointer = SBaseStructPointer(setRequired(false).asInstanceOf[PCanonicalBaseStruct]) + override def sType: SBaseStructPointer = + SBaseStructPointer(setRequired(false).asInstanceOf[PCanonicalBaseStruct]) override def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SBaseStructPointerValue = new SBaseStructPointerValue(sType, cb.memoize(addr)) - override def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] = 
{ + override def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] = { value.st match { case SBaseStructPointer(t) if t.equalModuloRequired(this) && !deepCopy => value.asInstanceOf[SBaseStructPointerValue].a @@ -166,7 +213,13 @@ abstract class PCanonicalBaseStruct(val types: Array[PType]) extends PBaseStruct } } - override def storeAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit = { + override def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit = { value.st match { case SBaseStructPointer(t) if t.equalModuloRequired(this) => val pcs = value.asInstanceOf[SBaseStructPointerValue] @@ -181,58 +234,63 @@ abstract class PCanonicalBaseStruct(val types: Array[PType]) extends PBaseStruct fields.foreach { f => pcs.loadField(cb, f.index) - .consume(cb, - { - setFieldMissing(cb, addrVar, f.index) - }, - { sv => - f.typ.storeAtAddress(cb, fieldOffset(addrVar, f.index), region, sv, deepCopy) - }) + .consume( + cb, + setFieldMissing(cb, addrVar, f.index), + sv => f.typ.storeAtAddress(cb, fieldOffset(addrVar, f.index), region, sv, deepCopy), + ) } } } - def constructFromFields(cb: EmitCodeBuilder, region: Value[Region], emitFields: IndexedSeq[EmitCode], deepCopy: Boolean): SBaseStructPointerValue = { + def constructFromFields( + cb: EmitCodeBuilder, + region: Value[Region], + emitFields: IndexedSeq[EmitCode], + deepCopy: Boolean, + ): SBaseStructPointerValue = { require(emitFields.length == size) val addr = cb.newLocal[Long]("pcbs_construct_fields", allocate(region)) stagedInitialize(cb, addr, setMissing = false) emitFields.zipWithIndex.foreach { case (ev, i) => ev.toI(cb) - .consume(cb, + .consume( + cb, setFieldMissing(cb, addr, i), - { sc => - types(i).storeAtAddress(cb, fieldOffset(addr, i), region, sc, deepCopy = deepCopy) - } + sc => types(i).storeAtAddress(cb, fieldOffset(addr, i), region, sc, deepCopy = deepCopy), ) } new SBaseStructPointerValue(sType, addr) } - override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = { + override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region) + : Long = { val addr = allocate(region) unstagedStoreJavaObjectAtAddress(sm, addr, annotation, region) addr } - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = { initialize(addr) val row = annotation.asInstanceOf[Row] row match { - case ur: UnsafeRow => { + case ur: UnsafeRow => this.unstagedStoreAtAddress(sm, addr, region, ur.t, ur.offset, region.ne(ur.region)) - } - case sr: Row => { + case _: Row => this.types.zipWithIndex.foreach { case (fieldPt, fieldIdx) => if (row(fieldIdx) == null) { setFieldMissing(addr, fieldIdx) - } - else { + } else { val fieldAddress = fieldOffset(addr, fieldIdx) fieldPt.unstagedStoreJavaObjectAtAddress(sm, fieldAddress, row(fieldIdx), region) } } - } } } diff --git a/hail/src/main/scala/is/hail/types/physical/PCanonicalBinary.scala b/hail/src/main/scala/is/hail/types/physical/PCanonicalBinary.scala index 488df439273..728fe9aa463 100644 --- a/hail/src/main/scala/is/hail/types/physical/PCanonicalBinary.scala +++ b/hail/src/main/scala/is/hail/types/physical/PCanonicalBinary.scala @@ -19,7 
+19,13 @@ class PCanonicalBinary(val required: Boolean) extends PBinary { override def byteSize: Long = 8 - override def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = { + override def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = { val srcBinary = srcPType.asInstanceOf[PCanonicalBinary] if (srcBinary == this) { if (deepCopy) { @@ -38,10 +44,10 @@ class PCanonicalBinary(val required: Boolean) extends PBinary { } } - override def containsPointers: Boolean = true - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = sb.append("PCBinary") + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = + sb.append("PCBinary") def contentAlignment: Long = 4 @@ -77,13 +83,14 @@ class PCanonicalBinary(val required: Boolean) extends PBinary { def storeLength(boff: Long, len: Int): Unit = Region.storeInt(boff, len) - def storeLength(cb: EmitCodeBuilder, boff: Code[Long], len: Code[Int]): Unit = cb += Region.storeInt(boff, len) + def storeLength(cb: EmitCodeBuilder, boff: Code[Long], len: Code[Int]): Unit = + cb += Region.storeInt(boff, len) def bytesAddress(boff: Long): Long = boff + lengthHeaderBytes def bytesAddress(boff: Code[Long]): Code[Long] = boff + lengthHeaderBytes - def store(addr: Long, bytes: Array[Byte]) { + def store(addr: Long, bytes: Array[Byte]): Unit = { Region.storeInt(addr, bytes.length) Region.storeBytes(bytesAddress(addr), bytes) } @@ -95,7 +102,8 @@ class PCanonicalBinary(val required: Boolean) extends PBinary { cb += Region.storeBytes(bytesAddress(addr), bytes) } - def constructFromByteArray(cb: EmitCodeBuilder, region: Value[Region], bytes: Code[Array[Byte]]): SBinaryPointerValue = { + def constructFromByteArray(cb: EmitCodeBuilder, region: Value[Region], bytes: Code[Array[Byte]]) + : SBinaryPointerValue = { val ba = cb.newLocal[Array[Byte]]("pcbin_ba", bytes) val len = cb.newLocal[Int]("pcbin_len", ba.length()) val addr = cb.newLocal[Long]("pcbin_addr", allocate(region, len)) @@ -103,12 +111,25 @@ class PCanonicalBinary(val required: Boolean) extends PBinary { loadCheapSCode(cb, addr) } - def constructAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], srcPType: PType, srcAddress: Code[Long], deepCopy: Boolean): Unit = { + def constructAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + srcPType: PType, + srcAddress: Code[Long], + deepCopy: Boolean, + ): Unit = { val srcBinary = srcPType.asInstanceOf[PBinary] cb += Region.storeAddress(addr, constructOrCopy(cb, region, srcBinary, srcAddress, deepCopy)) } - private def constructOrCopy(cb: EmitCodeBuilder, region: Value[Region], srcBinary: PBinary, srcAddress: Code[Long], deepCopy: Boolean): Code[Long] = { + private def constructOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + srcBinary: PBinary, + srcAddress: Code[Long], + deepCopy: Boolean, + ): Code[Long] = { if (srcBinary == this) { if (deepCopy) { val srcAddrVar = cb.newLocal[Long]("pcanonical_binary_construct_or_copy_src_addr") @@ -129,7 +150,11 @@ class PCanonicalBinary(val required: Boolean) extends PBinary { cb.assign(len, srcBinary.loadLength(srcAddrVar)) cb.assign(newAddr, allocate(region, len)) storeLength(cb, newAddr, len) - cb += Region.copyFrom(srcAddrVar + srcBinary.lengthHeaderBytes, newAddr + lengthHeaderBytes, len.toL) + cb += Region.copyFrom( + srcAddrVar + srcBinary.lengthHeaderBytes, + newAddr + 
lengthHeaderBytes, + len.toL, + ) newAddr } } @@ -139,7 +164,8 @@ class PCanonicalBinary(val required: Boolean) extends PBinary { def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SBinaryPointerValue = new SBinaryPointerValue(sType, cb.memoize(addr)) - def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] = { + def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] = { value.st match { case SBinaryPointer(PCanonicalBinary(_)) => if (deepCopy) { @@ -161,11 +187,23 @@ class PCanonicalBinary(val required: Boolean) extends PBinary { } } - def storeAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit = { + def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit = cb += Region.storeAddress(addr, store(cb, region, value, deepCopy)) - } - def unstagedStoreAtAddress(sm: HailStateManager, addr: Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit = { + def unstagedStoreAtAddress( + sm: HailStateManager, + addr: Long, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit = { val srcArray = srcPType.asInstanceOf[PBinary] Region.storeAddress(addr, copyFromAddress(sm, region, srcArray, srcAddress, deepCopy)) } @@ -177,11 +215,16 @@ class PCanonicalBinary(val required: Boolean) extends PBinary { override def unstagedLoadFromNested(addr: Long): Long = Region.loadAddress(addr) - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = Region.storeAddress(addr, unstagedStoreJavaObject(sm, annotation, region)) - } - override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = { + override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region) + : Long = { val bytes = annotation.asInstanceOf[Array[Byte]] val valueAddress = allocate(region, bytes.length) store(valueAddress, bytes) @@ -190,7 +233,8 @@ class PCanonicalBinary(val required: Boolean) extends PBinary { } object PCanonicalBinary { - def apply(required: Boolean = false): PCanonicalBinary = if (required) PCanonicalBinaryRequired else PCanonicalBinaryOptional + def apply(required: Boolean = false): PCanonicalBinary = + if (required) PCanonicalBinaryRequired else PCanonicalBinaryOptional def unapply(t: PBinary): Option[Boolean] = Option(t.required) } diff --git a/hail/src/main/scala/is/hail/types/physical/PCanonicalCall.scala b/hail/src/main/scala/is/hail/types/physical/PCanonicalCall.scala index c0dba64d40d..d8d3d142732 100644 --- a/hail/src/main/scala/is/hail/types/physical/PCanonicalCall.scala +++ b/hail/src/main/scala/is/hail/types/physical/PCanonicalCall.scala @@ -21,31 +21,53 @@ final case class PCanonicalCall(required: Boolean = false) extends PCall { def byteSize: Long = representation.byteSize override def alignment: Long = representation.alignment - override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = representation.unsafeOrdering(sm) // this was a terrible idea - - def setRequired(required: Boolean) = if (required == this.required) this else PCanonicalCall(required) - - override def unstagedStoreAtAddress(sm: HailStateManager, addr: 
Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit = { + override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = + representation.unsafeOrdering(sm) // this was a terrible idea + + def setRequired(required: Boolean) = + if (required == this.required) this else PCanonicalCall(required) + + override def unstagedStoreAtAddress( + sm: HailStateManager, + addr: Long, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit = srcPType match { case pt: PCanonicalCall => - representation.unstagedStoreAtAddress(sm, addr, region, pt.representation, srcAddress, deepCopy) + representation.unstagedStoreAtAddress( + sm, + addr, + region, + pt.representation, + srcAddress, + deepCopy, + ) } - } override def containsPointers: Boolean = representation.containsPointers - override def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = { + override def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = srcPType match { - case pt: PCanonicalCall => representation._copyFromAddress(sm, region, pt.representation, srcAddress, deepCopy) + case pt: PCanonicalCall => + representation._copyFromAddress(sm, region, pt.representation, srcAddress, deepCopy) } - } def sType: SCall = SCanonicalCall def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SCanonicalCallValue = new SCanonicalCallValue(cb.memoize(Region.loadInt(addr))) - def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] = { + def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] = { value.st match { case SCanonicalCall => val newAddr = cb.memoize(region.allocate(representation.alignment, representation.byteSize)) @@ -54,19 +76,29 @@ final case class PCanonicalCall(required: Boolean = false) extends PCall { } } - def storeAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit = { + def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit = cb += Region.storeInt(addr, value.asCall.canonicalCall(cb)) - } def loadFromNested(addr: Code[Long]): Code[Long] = representation.loadFromNested(addr) - override def unstagedLoadFromNested(addr: Long): Long = representation.unstagedLoadFromNested(addr) + override def unstagedLoadFromNested(addr: Long): Long = + representation.unstagedLoadFromNested(addr) - override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = { + override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region) + : Long = representation.unstagedStoreJavaObject(sm, annotation, region) - } - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = representation.unstagedStoreJavaObjectAtAddress(sm, addr, annotation, region) - } } diff --git a/hail/src/main/scala/is/hail/types/physical/PCanonicalDict.scala b/hail/src/main/scala/is/hail/types/physical/PCanonicalDict.scala index 42fee90af9d..d1f29d29631 100644 --- a/hail/src/main/scala/is/hail/types/physical/PCanonicalDict.scala +++ 
b/hail/src/main/scala/is/hail/types/physical/PCanonicalDict.scala @@ -5,28 +5,30 @@ import is.hail.backend.HailStateManager import is.hail.types.physical.stypes.concrete.{SIndexablePointer, SIndexablePointerValue} import is.hail.types.physical.stypes.interfaces.SIndexableValue import is.hail.types.virtual.{TDict, Type} + import org.apache.spark.sql.Row object PCanonicalDict { - def coerceArrayCode(contents: SIndexableValue): SIndexableValue = { + def coerceArrayCode(contents: SIndexableValue): SIndexableValue = contents.st match { case SIndexablePointer(PCanonicalArray(ps: PBaseStruct, r)) => PCanonicalDict(ps.types(0), ps.types(1), r) .construct(contents) } - } } -final case class PCanonicalDict(keyType: PType, valueType: PType, required: Boolean = false) extends PDict with PArrayBackedContainer { +final case class PCanonicalDict(keyType: PType, valueType: PType, required: Boolean = false) + extends PDict with PArrayBackedContainer { val elementType = PCanonicalStruct(required = true, "key" -> keyType, "value" -> valueType) val arrayRep: PCanonicalArray = PCanonicalArray(elementType, required) - def setRequired(required: Boolean) = if(required == this.required) this else PCanonicalDict(keyType, valueType, required) + def setRequired(required: Boolean) = + if (required == this.required) this else PCanonicalDict(keyType, valueType, required) def _asIdent = s"dict_of_${keyType.asIdent}AND${valueType.asIdent}" - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("PCDict[") keyType.pretty(sb, indent, compact) if (compact) @@ -40,11 +42,16 @@ final case class PCanonicalDict(keyType: PType, valueType: PType, required: Bool override def deepRename(t: Type) = deepRenameDict(t.asInstanceOf[TDict]) private def deepRenameDict(t: TDict) = - PCanonicalDict(this.keyType.deepRename(t.keyType), this.valueType.deepRename(t.valueType), this.required) + PCanonicalDict( + this.keyType.deepRename(t.keyType), + this.valueType.deepRename(t.valueType), + this.required, + ) - override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = { + override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region) + : Long = { val annotMap = annotation.asInstanceOf[Map[Annotation, Annotation]] - val sortedArray = annotMap.map{ case (k, v) => Row(k, v) } + val sortedArray = annotMap.map { case (k, v) => Row(k, v) } .toArray .sorted(elementType.virtualType.ordering(sm).toOrdering) .toIndexedSeq @@ -54,8 +61,8 @@ final case class PCanonicalDict(keyType: PType, valueType: PType, required: Bool def construct(contents: SIndexableValue): SIndexableValue = { contents.st match { case SIndexablePointer(PCanonicalArray(pbs: PBaseStruct, _)) - if pbs.types.size == 2 && pbs.types(0) == keyType && pbs.types(1) == valueType => - case t => throw new RuntimeException(s"PCDict.construct: contents=${t}, arrayrep=${arrayRep}") + if pbs.types.size == 2 && pbs.types(0) == keyType && pbs.types(1) == valueType => + case t => throw new RuntimeException(s"PCDict.construct: contents=$t, arrayrep=$arrayRep") } val cont = contents.asInstanceOf[SIndexablePointerValue] new SIndexablePointerValue(SIndexablePointer(this), cont.a, cont.length, cont.elementsAddress) diff --git a/hail/src/main/scala/is/hail/types/physical/PCanonicalInterval.scala b/hail/src/main/scala/is/hail/types/physical/PCanonicalInterval.scala index 277a5a99e3f..e9f03d7e62e 100644 
--- a/hail/src/main/scala/is/hail/types/physical/PCanonicalInterval.scala +++ b/hail/src/main/scala/is/hail/types/physical/PCanonicalInterval.scala @@ -5,21 +5,25 @@ import is.hail.asm4s._ import is.hail.backend.HailStateManager import is.hail.expr.ir.{EmitCode, EmitCodeBuilder} import is.hail.types.physical.stypes.SValue -import is.hail.types.physical.stypes.concrete.{SIntervalPointer, SIntervalPointerValue, SStackStruct} +import is.hail.types.physical.stypes.concrete.{ + SIntervalPointer, SIntervalPointerValue, SStackStruct, +} import is.hail.types.physical.stypes.interfaces.primitive import is.hail.types.physical.stypes.primitives.SBooleanValue import is.hail.types.virtual.{TInterval, Type} import is.hail.utils.{FastSeq, Interval} + import org.apache.spark.sql.Row -final case class PCanonicalInterval(pointType: PType, override val required: Boolean = false) extends PInterval { +final case class PCanonicalInterval(pointType: PType, override val required: Boolean = false) + extends PInterval { override def byteSize: Long = representation.byteSize override def alignment: Long = representation.alignment - override def _asIdent = s"interval_of_${ pointType.asIdent }" + override def _asIdent = s"interval_of_${pointType.asIdent}" - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("PCInterval[") pointType.pretty(sb, indent, compact) sb.append("]") @@ -30,7 +34,8 @@ final case class PCanonicalInterval(pointType: PType, override val required: Boo "start" -> pointType, "end" -> pointType, "includesStart" -> PBooleanRequired, - "includesEnd" -> PBooleanRequired) + "includesEnd" -> PBooleanRequired, + ) override def setRequired(required: Boolean): PCanonicalInterval = if (required == this.required) this else PCanonicalInterval(this.pointType, required) @@ -51,9 +56,11 @@ final case class PCanonicalInterval(pointType: PType, override val required: Boo override def endDefined(off: Long): Boolean = representation.isFieldDefined(off, 1) - override def includesStart(off: Long): Boolean = Region.loadBoolean(representation.loadField(off, 2)) + override def includesStart(off: Long): Boolean = + Region.loadBoolean(representation.loadField(off, 2)) - override def includesEnd(off: Long): Boolean = Region.loadBoolean(representation.loadField(off, 3)) + override def includesEnd(off: Long): Boolean = + Region.loadBoolean(representation.loadField(off, 3)) override def startDefined(cb: EmitCodeBuilder, off: Code[Long]): Value[Boolean] = representation.isFieldDefined(cb, off, 0) @@ -83,79 +90,154 @@ final case class PCanonicalInterval(pointType: PType, override val required: Boo new SIntervalPointerValue(sType, a, incStart, incEnd) } - override def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] = { + override def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] = { value.st match { case SIntervalPointer(t: PCanonicalInterval) => - representation.store(cb, region, t.representation.loadCheapSCode(cb, value.asInstanceOf[SIntervalPointerValue].a), deepCopy) + representation.store( + cb, + region, + t.representation.loadCheapSCode(cb, value.asInstanceOf[SIntervalPointerValue].a), + deepCopy, + ) case _ => val interval = value.asInterval val start = EmitCode.fromI(cb.emb)(cb => interval.loadStart(cb)) val stop = EmitCode.fromI(cb.emb)(cb => interval.loadEnd(cb)) val includesStart = 
EmitCode.present(cb.emb, new SBooleanValue(interval.includesStart)) val includesStop = EmitCode.present(cb.emb, new SBooleanValue(interval.includesEnd)) - representation.store(cb, region, - SStackStruct.constructFromArgs(cb, region, representation.virtualType, - start, stop, includesStart, includesStop), deepCopy) + representation.store( + cb, + region, + SStackStruct.constructFromArgs( + cb, + region, + representation.virtualType, + start, + stop, + includesStart, + includesStop, + ), + deepCopy, + ) } } - override def storeAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit = { + override def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit = { value.st match { case SIntervalPointer(t: PCanonicalInterval) => - representation.storeAtAddress(cb, addr, region, t.representation.loadCheapSCode(cb, value.asInstanceOf[SIntervalPointerValue].a), deepCopy) + representation.storeAtAddress( + cb, + addr, + region, + t.representation.loadCheapSCode(cb, value.asInstanceOf[SIntervalPointerValue].a), + deepCopy, + ) case _ => val interval = value.asInterval val start = EmitCode.fromI(cb.emb)(cb => interval.loadStart(cb)) val stop = EmitCode.fromI(cb.emb)(cb => interval.loadEnd(cb)) val includesStart = EmitCode.present(cb.emb, new SBooleanValue(interval.includesStart)) val includesStop = EmitCode.present(cb.emb, new SBooleanValue(interval.includesEnd)) - representation.storeAtAddress(cb, addr, region, - SStackStruct.constructFromArgs(cb, region, representation.virtualType, - start, stop, includesStart, includesStop), - deepCopy) + representation.storeAtAddress( + cb, + addr, + region, + SStackStruct.constructFromArgs( + cb, + region, + representation.virtualType, + start, + stop, + includesStart, + includesStop, + ), + deepCopy, + ) } } - override def unstagedStoreAtAddress(sm: HailStateManager, addr: Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit = { + override def unstagedStoreAtAddress( + sm: HailStateManager, + addr: Long, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit = srcPType match { case t: PCanonicalInterval => - representation.unstagedStoreAtAddress(sm, addr, region, t.representation, srcAddress, deepCopy) + representation.unstagedStoreAtAddress( + sm, + addr, + region, + t.representation, + srcAddress, + deepCopy, + ) } - } - override def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = { + override def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = srcPType match { case t: PCanonicalInterval => representation._copyFromAddress(sm, region, t.representation, srcAddress, deepCopy) } - } override def loadFromNested(addr: Code[Long]): Code[Long] = representation.loadFromNested(addr) - override def unstagedLoadFromNested(addr: Long): Long = representation.unstagedLoadFromNested(addr) + override def unstagedLoadFromNested(addr: Long): Long = + representation.unstagedLoadFromNested(addr) - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = { val jInterval = annotation.asInstanceOf[Interval] 
representation.unstagedStoreJavaObjectAtAddress( sm, addr, Row(jInterval.start, jInterval.end, jInterval.includesStart, jInterval.includesEnd), - region + region, ) } - override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = { + override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region) + : Long = { val addr = representation.allocate(region) unstagedStoreJavaObjectAtAddress(sm, addr, annotation, region) addr } - def constructFromCodes(cb: EmitCodeBuilder, region: Value[Region], - start: EmitCode, end: EmitCode, includesStart: Value[Boolean], includesEnd: Value[Boolean] + def constructFromCodes( + cb: EmitCodeBuilder, + region: Value[Region], + start: EmitCode, + end: EmitCode, + includesStart: Value[Boolean], + includesEnd: Value[Boolean], ): SIntervalPointerValue = { val startEC = EmitCode.present(cb.emb, primitive(includesStart)) val endEC = EmitCode.present(cb.emb, primitive(includesEnd)) - val sc = representation.constructFromFields(cb, region, FastSeq(start, end, startEC, endEC), deepCopy = false) + val sc = representation.constructFromFields( + cb, + region, + FastSeq(start, end, startEC, endEC), + deepCopy = false, + ) new SIntervalPointerValue(sType, sc.a, includesStart, includesEnd) } diff --git a/hail/src/main/scala/is/hail/types/physical/PCanonicalLocus.scala b/hail/src/main/scala/is/hail/types/physical/PCanonicalLocus.scala index 2e4d01653d7..402d34f3c0d 100644 --- a/hail/src/main/scala/is/hail/types/physical/PCanonicalLocus.scala +++ b/hail/src/main/scala/is/hail/types/physical/PCanonicalLocus.scala @@ -5,16 +5,19 @@ import is.hail.asm4s._ import is.hail.backend.HailStateManager import is.hail.expr.ir.{EmitCode, EmitCodeBuilder} import is.hail.types.physical.stypes.SValue -import is.hail.types.physical.stypes.concrete.{SCanonicalLocusPointer, SCanonicalLocusPointerValue, SStackStruct} +import is.hail.types.physical.stypes.concrete.{ + SCanonicalLocusPointer, SCanonicalLocusPointerValue, SStackStruct, +} import is.hail.types.physical.stypes.interfaces._ import is.hail.utils.FastSeq import is.hail.variant._ object PCanonicalLocus { - private def representation(required: Boolean = false): PCanonicalStruct = PCanonicalStruct( + private def representation(required: Boolean): PCanonicalStruct = PCanonicalStruct( required, "contig" -> PCanonicalString(required = true), - "position" -> PInt32(required = true)) + "position" -> PInt32(required = true), + ) def schemaFromRG(rg: Option[String], required: Boolean = false): PType = rg match { case Some(name) => PCanonicalLocus(name, required) @@ -31,23 +34,26 @@ final case class PCanonicalLocus(rgName: String, required: Boolean = false) exte def rg: String = rgName - def _asIdent = "locus" + def _asIdent = s"locus_$rgName" - override def _pretty(sb: StringBuilder, indent: Call, compact: Boolean): Unit = sb.append(s"PCLocus($rgName)") + override def _pretty(sb: StringBuilder, indent: Call, compact: Boolean): Unit = + sb.append(s"PCLocus($rgName)") def setRequired(required: Boolean): PCanonicalLocus = if (required == this.required) this else PCanonicalLocus(this.rgName, required) val representation: PCanonicalStruct = PCanonicalLocus.representation(required) - private[physical] def contigAddr(address: Code[Long]): Code[Long] = representation.loadField(address, 0) + private[physical] def contigAddr(address: Code[Long]): Code[Long] = + representation.loadField(address, 0) private[physical] def contigAddr(address: Long): Long = 
representation.loadField(address, 0) def contig(address: Long): String = contigType.loadString(contigAddr(address)) def position(address: Long): Int = Region.loadInt(representation.fieldOffset(address, 1)) - lazy val contigType: PCanonicalString = representation.field("contig").typ.asInstanceOf[PCanonicalString] + lazy val contigType: PCanonicalString = + representation.field("contig").typ.asInstanceOf[PCanonicalString] def position(off: Code[Long]): Code[Int] = Region.loadInt(representation.loadField(off, 1)) @@ -76,19 +82,38 @@ final case class PCanonicalLocus(rgName: String, required: Boolean = false) exte } } - override def unstagedStoreAtAddress(sm: HailStateManager, addr: Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit = { + override def unstagedStoreAtAddress( + sm: HailStateManager, + addr: Long, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit = srcPType match { - case pt: PCanonicalLocus => representation.unstagedStoreAtAddress(sm, addr, region, pt.representation, srcAddress, deepCopy) + case pt: PCanonicalLocus => representation.unstagedStoreAtAddress( + sm, + addr, + region, + pt.representation, + srcAddress, + deepCopy, + ) } - } override def containsPointers: Boolean = representation.containsPointers - override def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = { + override def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = srcPType match { - case pt: PCanonicalLocus => representation._copyFromAddress(sm, region, pt.representation, srcAddress, deepCopy) + case pt: PCanonicalLocus => + representation._copyFromAddress(sm, region, pt.representation, srcAddress, deepCopy) } - } def sType: SCanonicalLocusPointer = SCanonicalLocusPointer(setRequired(false)) @@ -97,11 +122,16 @@ final case class PCanonicalLocus(rgName: String, required: Boolean = false) exte new SCanonicalLocusPointerValue(sType, a, cb.memoize(contigAddr(a)), cb.memoize(position(a))) } - - def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] = { + def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] = { value.st match { case SCanonicalLocusPointer(pt) => - representation.store(cb, region, pt.representation.loadCheapSCode(cb, value.asInstanceOf[SCanonicalLocusPointerValue].a), deepCopy) + representation.store( + cb, + region, + pt.representation.loadCheapSCode(cb, value.asInstanceOf[SCanonicalLocusPointerValue].a), + deepCopy, + ) case _ => val addr = cb.memoize(representation.allocate(region)) storeAtAddress(cb, addr, region, value, deepCopy) @@ -109,16 +139,37 @@ final case class PCanonicalLocus(rgName: String, required: Boolean = false) exte } } - def storeAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit = { + def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit = { value.st match { case SCanonicalLocusPointer(pt) => - representation.storeAtAddress(cb, addr, region, pt.representation.loadCheapSCode(cb, value.asInstanceOf[SCanonicalLocusPointerValue].a), deepCopy) + representation.storeAtAddress( + cb, + addr, + region, + pt.representation.loadCheapSCode(cb, value.asInstanceOf[SCanonicalLocusPointerValue].a), + deepCopy, + ) case _ => val loc = 
value.asLocus - representation.storeAtAddress(cb, addr, region, - SStackStruct.constructFromArgs(cb, region, representation.virtualType, - EmitCode.present(cb.emb, loc.contig(cb)), EmitCode.present(cb.emb, primitive(loc.position(cb)))), - deepCopy) + representation.storeAtAddress( + cb, + addr, + region, + SStackStruct.constructFromArgs( + cb, + region, + representation.virtualType, + EmitCode.present(cb.emb, loc.contig(cb)), + EmitCode.present(cb.emb, primitive(loc.position(cb))), + ), + deepCopy, + ) } } @@ -126,27 +177,64 @@ final case class PCanonicalLocus(rgName: String, required: Boolean = false) exte override def unstagedLoadFromNested(addr: Long): Long = addr - override def unstagedStoreLocus(sm: HailStateManager, addr: Long, contig: String, position: Int, region: Region): Unit = { - contigType.unstagedStoreJavaObjectAtAddress(sm, representation.fieldOffset(addr, 0), contig, region) - positionType.unstagedStoreJavaObjectAtAddress(sm, representation.fieldOffset(addr, 1), position, region) + override def unstagedStoreLocus( + sm: HailStateManager, + addr: Long, + contig: String, + position: Int, + region: Region, + ): Unit = { + contigType.unstagedStoreJavaObjectAtAddress( + sm, + representation.fieldOffset(addr, 0), + contig, + region, + ) + positionType.unstagedStoreJavaObjectAtAddress( + sm, + representation.fieldOffset(addr, 1), + position, + region, + ) } - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = { val myLocus = annotation.asInstanceOf[Locus] unstagedStoreLocus(sm, addr, myLocus.contig, myLocus.position, region) } - override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = { + override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region) + : Long = { val addr = representation.allocate(region) unstagedStoreJavaObjectAtAddress(sm, addr, annotation, region) addr } - def constructFromPositionAndString(cb: EmitCodeBuilder, r: Value[Region], contig: Code[String], pos: Code[Int]): SCanonicalLocusPointerValue = { + def constructFromContigAndPosition( + cb: EmitCodeBuilder, + r: Value[Region], + contig: Code[String], + pos: Code[Int], + ): SCanonicalLocusPointerValue = { val position = cb.memoize(pos) val contigType = representation.fieldType("contig").asInstanceOf[PCanonicalString] val contigCode = contigType.sType.constructFromString(cb, r, contig) - val repr = representation.constructFromFields(cb, r, FastSeq(EmitCode.present(cb.emb, contigCode), EmitCode.present(cb.emb, primitive(position))), deepCopy = false) - new SCanonicalLocusPointerValue(SCanonicalLocusPointer(setRequired(false)), repr.a, contigCode.a, position) + val repr = representation.constructFromFields( + cb, + r, + FastSeq(EmitCode.present(cb.emb, contigCode), EmitCode.present(cb.emb, primitive(position))), + deepCopy = false, + ) + new SCanonicalLocusPointerValue( + SCanonicalLocusPointer(setRequired(false)), + repr.a, + contigCode.a, + position, + ) } } diff --git a/hail/src/main/scala/is/hail/types/physical/PCanonicalNDArray.scala b/hail/src/main/scala/is/hail/types/physical/PCanonicalNDArray.scala index 22a66d2da3f..a7a1c5f9dd3 100644 --- a/hail/src/main/scala/is/hail/types/physical/PCanonicalNDArray.scala +++ b/hail/src/main/scala/is/hail/types/physical/PCanonicalNDArray.scala @@ -3,29 +3,40 @@ package 
is.hail.types.physical import is.hail.annotations.{Annotation, NDArray, Region, UnsafeOrdering} import is.hail.asm4s.{Code, _} import is.hail.backend.HailStateManager -import is.hail.expr.ir.{CodeParam, CodeParamType, EmitCode, EmitCodeBuilder, Param, ParamType, SCodeParam} +import is.hail.expr.ir.{ + CodeParam, CodeParamType, EmitCode, EmitCodeBuilder, Param, ParamType, SCodeParam, +} import is.hail.types.physical.stypes.SValue import is.hail.types.physical.stypes.concrete._ import is.hail.types.physical.stypes.interfaces._ import is.hail.types.virtual.{TNDArray, Type} import is.hail.utils._ + import org.apache.spark.sql.Row -final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boolean = false) extends PNDArray { +final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boolean = false) + extends PNDArray { assert(elementType.required, "elementType must be required") - assert(!elementType.containsPointers, "ndarrays do not currently support elements which contain arrays, ndarrays, or strings") - def _asIdent: String = s"ndarray_of_${elementType.asIdent}" + assert( + !elementType.containsPointers, + "ndarrays do not currently support elements which contain arrays, ndarrays, or strings", + ) + + override def _asIdent: String = + s"${nDims}darray_of_${elementType.asIdent}" override def containsPointers: Boolean = true - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("PCNDArray[") elementType.pretty(sb, indent, compact) sb.append(s",$nDims]") } - lazy val shapeType: PCanonicalTuple = PCanonicalTuple(true, Array.tabulate(nDims)(_ => PInt64Required):_*) + lazy val shapeType: PCanonicalTuple = + PCanonicalTuple(true, Array.tabulate(nDims)(_ => PInt64Required): _*) + lazy val strideType: PCanonicalTuple = shapeType def loadShape(ndAddr: Long, idx: Int): Long = { @@ -38,93 +49,116 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo Region.loadLong(strideType.loadField(shapeTupleAddr, idx)) } - - def loadShapes(cb: EmitCodeBuilder, addr: Value[Long], settables: IndexedSeq[Settable[Long]]): Unit = { - assert(settables.length == nDims, s"got ${ settables.length } settables, expect ${ nDims } dims") + def loadShapes(cb: EmitCodeBuilder, addr: Value[Long], settables: IndexedSeq[Settable[Long]]) + : Unit = { + assert(settables.length == nDims, s"got ${settables.length} settables, expect $nDims dims") val shapeTuple = shapeType.loadCheapSCode(cb, representation.loadField(addr, "shape")) (0 until nDims).foreach { dimIdx => - cb.assign(settables(dimIdx), shapeTuple.loadField(cb, dimIdx).get(cb).asLong.value) + cb.assign(settables(dimIdx), shapeTuple.loadField(cb, dimIdx).getOrAssert(cb).asLong.value) } } - def loadStrides(cb: EmitCodeBuilder, addr: Value[Long], settables: IndexedSeq[Settable[Long]]): Unit = { + def loadStrides(cb: EmitCodeBuilder, addr: Value[Long], settables: IndexedSeq[Settable[Long]]) + : Unit = { assert(settables.length == nDims) val strideTuple = strideType.loadCheapSCode(cb, representation.loadField(addr, "strides")) (0 until nDims).foreach { dimIdx => - cb.assign(settables(dimIdx), strideTuple.loadField(cb, dimIdx).get(cb).asLong.value) + cb.assign(settables(dimIdx), strideTuple.loadField(cb, dimIdx).getOrAssert(cb).asLong.value) } } - override def unstagedLoadStrides(addr: Long): IndexedSeq[Long] = { - (0 until nDims).map { dimIdx => - this.loadStride(addr, dimIdx) - } - } + 
override def unstagedLoadStrides(addr: Long): IndexedSeq[Long] = + (0 until nDims).map(dimIdx => this.loadStride(addr, dimIdx)) - lazy val representation: PCanonicalStruct = { - PCanonicalStruct(required, + lazy val representation: PCanonicalStruct = + PCanonicalStruct( + required, ("shape", shapeType), ("strides", strideType), - ("data", PInt64Required)) - } + ("data", PInt64Required), + ) override lazy val byteSize: Long = representation.byteSize override lazy val alignment: Long = representation.alignment - override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = representation.unsafeOrdering(sm) + override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = + representation.unsafeOrdering(sm) - def numElements(shape: IndexedSeq[Value[Long]]): Code[Long] = { + def numElements(shape: IndexedSeq[Value[Long]]): Code[Long] = shape.foldLeft(1L: Code[Long])(_ * _) - } - def numElements(shape: IndexedSeq[Long]): Long = { + def numElements(shape: IndexedSeq[Long]): Long = shape.foldLeft(1L)(_ * _) - } - def makeColumnMajorStrides(sourceShapeArray: IndexedSeq[Value[Long]], cb: EmitCodeBuilder): IndexedSeq[Value[Long]] = { + def makeColumnMajorStrides(sourceShapeArray: IndexedSeq[Value[Long]], cb: EmitCodeBuilder) + : IndexedSeq[Value[Long]] = { val strides = new Array[Value[Long]](nDims) - for (i <- 0 until nDims) { + for (i <- 0 until nDims) if (i == 0) strides(i) = const(elementType.byteSize) - else strides(i) = cb.memoize(strides(i-1) * (sourceShapeArray(i-1) > 0L).mux(sourceShapeArray(i-1), 1L)) - } + else strides(i) = + cb.memoize(strides(i - 1) * (sourceShapeArray(i - 1) > 0L).mux(sourceShapeArray(i - 1), 1L)) strides } - def makeRowMajorStrides(sourceShapeArray: IndexedSeq[Value[Long]], cb: EmitCodeBuilder): IndexedSeq[Value[Long]] = { + def makeRowMajorStrides(sourceShapeArray: IndexedSeq[Value[Long]], cb: EmitCodeBuilder) + : IndexedSeq[Value[Long]] = { val strides = new Array[Value[Long]](nDims) - for (i <- (nDims - 1) to 0 by -1) { + for (i <- (nDims - 1) to 0 by -1) if (i == nDims - 1) strides(i) = const(elementType.byteSize) - else strides(i) = cb.memoize(strides(i+1) * (sourceShapeArray(i+1) > 0L).mux(sourceShapeArray(i+1), 1L)) - } + else strides(i) = + cb.memoize(strides(i + 1) * (sourceShapeArray(i + 1) > 0L).mux(sourceShapeArray(i + 1), 1L)) strides } def getElementAddress(indices: IndexedSeq[Long], nd: Long): Long = { var bytesAway = 0L - indices.zipWithIndex.foreach{case (requestedIndex: Long, strideIndex: Int) => + indices.zipWithIndex.foreach { case (requestedIndex: Long, strideIndex: Int) => bytesAway += requestedIndex * loadStride(nd, strideIndex) } bytesAway + this.unstagedDataFirstElementPointer(nd) } - private def getElementAddress(cb: EmitCodeBuilder, indices: IndexedSeq[Value[Long]], nd: Value[Long]): Value[Long] = { + private def getElementAddress( + cb: EmitCodeBuilder, + indices: IndexedSeq[Value[Long]], + nd: Value[Long], + ): Value[Long] = { val ndarrayValue = loadCheapSCode(cb, nd).asNDArray val stridesTuple = ndarrayValue.strides - cb.newLocal[Long]("pcndarray_get_element_addr", indices.zipWithIndex.map { case (requestedElementIndex, strideIndex) => - requestedElementIndex * stridesTuple(strideIndex) - }.foldLeft(const(0L).get)(_ + _) + ndarrayValue.firstDataAddress) + cb.newLocal[Long]( + "pcndarray_get_element_addr", + indices.zipWithIndex.map { case (requestedElementIndex, strideIndex) => + requestedElementIndex * stridesTuple(strideIndex) + }.foldLeft(const(0L).get)(_ + _) + ndarrayValue.firstDataAddress, + ) } - def setElement(cb: 
EmitCodeBuilder, region: Value[Region], - indices: IndexedSeq[Value[Long]], ndAddress: Value[Long], newElement: SValue, deepCopy: Boolean): Unit = { - elementType.storeAtAddress(cb, getElementAddress(cb, indices, ndAddress), region, newElement, deepCopy) - } + def setElement( + cb: EmitCodeBuilder, + region: Value[Region], + indices: IndexedSeq[Value[Long]], + ndAddress: Value[Long], + newElement: SValue, + deepCopy: Boolean, + ): Unit = + elementType.storeAtAddress( + cb, + getElementAddress(cb, indices, ndAddress), + region, + newElement, + deepCopy, + ) - private def getElementAddressFromDataPointerAndStrides(indices: IndexedSeq[Value[Long]], dataFirstElementPointer: Value[Long], strides: IndexedSeq[Value[Long]], cb: EmitCodeBuilder): Code[Long] = { + private def getElementAddressFromDataPointerAndStrides( + indices: IndexedSeq[Value[Long]], + dataFirstElementPointer: Value[Long], + strides: IndexedSeq[Value[Long]], + cb: EmitCodeBuilder, + ): Code[Long] = { val address = cb.newLocal[Long]("nd_get_element_address_bytes_away") cb.assign(address, dataFirstElementPointer) @@ -134,21 +168,26 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo address } - def loadElement(cb: EmitCodeBuilder, indices: IndexedSeq[Value[Long]], ndAddress: Value[Long]): SValue = { + def loadElement(cb: EmitCodeBuilder, indices: IndexedSeq[Value[Long]], ndAddress: Value[Long]) + : SValue = { val off = getElementAddress(cb, indices, ndAddress) elementType.loadCheapSCode(cb, elementType.loadFromNested(off)) } - def loadElementFromDataAndStrides(cb: EmitCodeBuilder, indices: IndexedSeq[Value[Long]], ndDataAddress: Value[Long], strides: IndexedSeq[Value[Long]]): Code[Long] = { + def loadElementFromDataAndStrides( + cb: EmitCodeBuilder, + indices: IndexedSeq[Value[Long]], + ndDataAddress: Value[Long], + strides: IndexedSeq[Value[Long]], + ): Code[Long] = { val off = getElementAddressFromDataPointerAndStrides(indices, ndDataAddress, strides, cb) elementType.loadFromNested(off) } def contentsByteSize(numElements: Long): Long = this.elementType.byteSize * numElements - def contentsByteSize(numElements: Code[Long]): Code[Long] = { + def contentsByteSize(numElements: Code[Long]): Code[Long] = numElements * elementType.byteSize - } def allocateData(shape: IndexedSeq[Value[Long]], region: Value[Region]): Code[Long] = { val sizeOfArray = this.contentsByteSize(this.numElements(shape).toL) @@ -164,35 +203,43 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo shape: IndexedSeq[SizeValue], strides: IndexedSeq[Value[Long]], cb: EmitCodeBuilder, - region: Value[Region] - ): SNDArrayPointerValue = { + region: Value[Region], + ): SNDArrayPointerValue = constructByCopyingDataPointer(shape, strides, this.allocateData(shape, region), cb, region) - } def constructUninitialized( shape: IndexedSeq[SizeValue], cb: EmitCodeBuilder, - region: Value[Region] - ): SNDArrayPointerValue = { - constructByCopyingDataPointer(shape, makeColumnMajorStrides(shape, cb), this.allocateData(shape, region), cb, region) - } + region: Value[Region], + ): SNDArrayPointerValue = + constructByCopyingDataPointer( + shape, + makeColumnMajorStrides(shape, cb), + this.allocateData(shape, region), + cb, + region, + ) def constructByCopyingArray( shape: IndexedSeq[Value[Long]], strides: IndexedSeq[Value[Long]], dataCode: SIndexableValue, cb: EmitCodeBuilder, - region: Value[Region] + region: Value[Region], ): SNDArrayValue = { - assert(shape.length == nDims, s"nDims = ${ nDims }, nShapeElts=${ 
shape.length }") - assert(strides.length == nDims, s"nDims = ${ nDims }, nShapeElts=${ strides.length }") + assert(shape.length == nDims, s"nDims = $nDims, nShapeElts=${shape.length}") + assert(strides.length == nDims, s"nDims = $nDims, nShapeElts=${strides.length}") val cacheKey = ("constructByCopyingArray", this, dataCode.st) - val mb = cb.emb.ecb.getOrGenEmitMethod("pcndarray_construct_by_copying_array", cacheKey, - FastSeq[ParamType](classInfo[Region], dataCode.st.paramType) ++ (0 until 2 * nDims).map(_ => CodeParamType(LongInfo)), - sType.paramType) { mb => + val mb = cb.emb.ecb.getOrGenEmitMethod( + "pcndarray_construct_by_copying_array", + cacheKey, + FastSeq[ParamType](classInfo[Region], dataCode.st.paramType) ++ (0 until 2 * nDims).map(_ => + CodeParamType(LongInfo) + ), + sType.paramType, + ) { mb => mb.emitSCode { cb => - val region = mb.getCodeParam[Region](1) val dataValue = mb.getSCodeParam(2).asIndexable val shape = (0 until nDims).map(i => SizeValueDyn(mb.getCodeParam[Long](3 + i))) @@ -201,13 +248,30 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo val result = constructUninitialized(shape, strides, cb, region) dataValue.st match { - case SIndexablePointer(PCanonicalArray(otherElementType, _)) if otherElementType == elementType => - cb += Region.copyFrom(dataValue.asInstanceOf[SIndexablePointerValue].elementsAddress, result.firstDataAddress, dataValue.loadLength().toL * elementType.byteSize) + case SIndexablePointer(PCanonicalArray(otherElementType, _)) + if otherElementType == elementType => + cb += Region.copyFrom( + dataValue.asInstanceOf[SIndexablePointerValue].elementsAddress, + result.firstDataAddress, + dataValue.loadLength().toL * elementType.byteSize, + ) case _ => val loopCtr = cb.newLocal[Long]("pcanonical_ndarray_construct_by_copying_loop_idx") - cb.for_(cb.assign(loopCtr, 0L), loopCtr < dataValue.loadLength().toL, cb.assign(loopCtr, loopCtr + 1L), { - elementType.storeAtAddress(cb, result.firstDataAddress + (loopCtr * elementType.byteSize), region, dataValue.loadElement(cb, loopCtr.toI).get(cb, "NDArray elements cannot be missing"), true) - }) + cb.for_( + cb.assign(loopCtr, 0L), + loopCtr < dataValue.loadLength().toL, + cb.assign(loopCtr, loopCtr + 1L), + elementType.storeAtAddress( + cb, + result.firstDataAddress + (loopCtr * elementType.byteSize), + region, + dataValue.loadElement(cb, loopCtr.toI).getOrFatal( + cb, + "NDArray elements cannot be missing", + ), + true, + ), + ) } result @@ -219,7 +283,12 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo case s => SizeValueDyn(s) } - cb.invokeSCode(mb, FastSeq[Param](region, SCodeParam(dataCode)) ++ (newShape.map(CodeParam(_)) ++ strides.map(CodeParam(_))): _*) + cb.invokeSCode( + mb, + FastSeq[Param](cb.this_, region, SCodeParam(dataCode)) ++ (newShape.map( + CodeParam(_) + ) ++ strides.map(CodeParam(_))): _* + ) .asNDArray .coerceToShape(cb, newShape) } @@ -228,7 +297,7 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo shape: IndexedSeq[Value[Long]], strides: IndexedSeq[Value[Long]], cb: EmitCodeBuilder, - region: Value[Region] + region: Value[Region], ): (Value[Long], EmitCodeBuilder => SNDArrayPointerValue) = { val newShape = shape.map { case s: SizeValue => s @@ -244,18 +313,34 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo strides: IndexedSeq[Value[Long]], dataPtr: Code[Long], cb: EmitCodeBuilder, - region: Value[Region] + region: Value[Region], ): 
SNDArrayPointerValue = { val ndAddr = cb.newLocal[Long]("ndarray_construct_addr") cb.assign(ndAddr, this.representation.allocate(region)) - shapeType.storeAtAddress(cb, cb.newLocal[Long]("construct_shape", this.representation.fieldOffset(ndAddr, "shape")), + shapeType.storeAtAddress( + cb, + cb.newLocal[Long]("construct_shape", this.representation.fieldOffset(ndAddr, "shape")), region, - SStackStruct.constructFromArgs(cb, region, shapeType.virtualType, shape.map(s => EmitCode.present(cb.emb, primitive(s))): _*), - false) - strideType.storeAtAddress(cb, cb.newLocal[Long]("construct_strides", this.representation.fieldOffset(ndAddr, "strides")), + SStackStruct.constructFromArgs( + cb, + region, + shapeType.virtualType, + shape.map(s => EmitCode.present(cb.emb, primitive(s))): _* + ), + false, + ) + strideType.storeAtAddress( + cb, + cb.newLocal[Long]("construct_strides", this.representation.fieldOffset(ndAddr, "strides")), region, - SStackStruct.constructFromArgs(cb, region, strideType.virtualType, strides.map(s => EmitCode.present(cb.emb, primitive(s))): _*), - false) + SStackStruct.constructFromArgs( + cb, + region, + strideType.virtualType, + strides.map(s => EmitCode.present(cb.emb, primitive(s))): _* + ), + false, + ) val newDataPointer = cb.newLocal("ndarray_construct_new_data_pointer", dataPtr) cb += Region.storeAddress(this.representation.fieldOffset(ndAddr, 2), newDataPointer) new SNDArrayPointerValue(sType, ndAddr, shape, strides, newDataPointer) @@ -264,23 +349,35 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo def constructByActuallyCopyingData( toBeCopied: SNDArrayValue, cb: EmitCodeBuilder, - region: Value[Region] + region: Value[Region], ): SNDArrayValue = { val oldDataAddr = toBeCopied.firstDataAddress - val numDataBytes = cb.newLocal("constructByActuallyCopyingData_numDataBytes", Region.getSharedChunkByteSize(oldDataAddr)) + val numDataBytes = cb.newLocal( + "constructByActuallyCopyingData_numDataBytes", + Region.getSharedChunkByteSize(oldDataAddr), + ) cb.if_(numDataBytes < 0L, cb._fatal("numDataBytes was ", numDataBytes.toS)) - val newDataAddr = cb.newLocal("constructByActuallyCopyingData_newDataAddr", region.allocateSharedChunk(numDataBytes)) + val newDataAddr = cb.newLocal( + "constructByActuallyCopyingData_newDataAddr", + region.allocateSharedChunk(numDataBytes), + ) cb += Region.copyFrom(oldDataAddr, newDataAddr, numDataBytes) constructByCopyingDataPointer( toBeCopied.shapes, toBeCopied.strides, newDataAddr, cb, - region + region, ) } - def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = { + def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = { val srcNDPType = srcPType.asInstanceOf[PCanonicalNDArray] assert(nDims == srcNDPType.nDims) @@ -299,14 +396,27 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo PCanonicalNDArray(this.elementType.deepRename(t.elementType), this.nDims, this.required) def setRequired(required: Boolean): PCanonicalNDArray = - if(required == this.required) this else PCanonicalNDArray(elementType, nDims, required) - - def unstagedStoreAtAddress(sm: HailStateManager, destAddress: Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit = { + if (required == this.required) this else PCanonicalNDArray(elementType, nDims, required) + + def unstagedStoreAtAddress( + sm: HailStateManager, + destAddress: Long, + 
region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit = { val srcNDPType = srcPType.asInstanceOf[PCanonicalNDArray] assert(nDims == srcNDPType.nDims) if (equalModuloRequired(srcPType)) { // The situation where you can just memcpy - Region.copyFrom(srcAddress, destAddress, this.representation.field("shape").typ.byteSize + this.representation.field("strides").typ.byteSize) + Region.copyFrom( + srcAddress, + destAddress, + this.representation.field("shape").typ.byteSize + this.representation.field( + "strides" + ).typ.byteSize, + ) val srcDataAddress = srcNDPType.unstagedDataFirstElementPointer(srcAddress) @@ -319,13 +429,17 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo srcDataAddress } Region.storeAddress(this.representation.fieldOffset(destAddress, 2), newDataAddress) - } - else { // The situation where maybe the structs inside the ndarray have different requiredness + } else { // The situation where maybe the structs inside the ndarray have different requiredness val srcShape = srcPType.asInstanceOf[PNDArray].unstagedLoadShapes(srcAddress) val srcStrides = srcPType.asInstanceOf[PNDArray].unstagedLoadStrides(srcAddress) - shapeType.unstagedStoreJavaObjectAtAddress(sm, destAddress, Row(srcShape:_*), region) - strideType.unstagedStoreJavaObjectAtAddress(sm, destAddress + shapeType.byteSize, Row(srcStrides:_*), region) + shapeType.unstagedStoreJavaObjectAtAddress(sm, destAddress, Row(srcShape: _*), region) + strideType.unstagedStoreJavaObjectAtAddress( + sm, + destAddress + shapeType.byteSize, + Row(srcStrides: _*), + region, + ) val newDataPointer = this.allocateData(srcShape, region) Region.storeLong(this.representation.fieldOffset(destAddress, 2), newDataPointer) @@ -336,7 +450,14 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo SNDArray.unstagedForEachIndex(srcShape) { indices => val srcElementAddress = srcNDPType.getElementAddress(indices, srcAddress) - this.elementType.unstagedStoreAtAddress(sm, currentAddressToWrite, region, srcNDPType.elementType, srcElementAddress, true) + this.elementType.unstagedStoreAtAddress( + sm, + currentAddressToWrite, + region, + srcNDPType.elementType, + srcElementAddress, + true, + ) currentAddressToWrite += elementType.byteSize } } @@ -347,33 +468,62 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SNDArrayPointerValue = { val a = cb.memoize(addr) val shapeTuple = shapeType.loadCheapSCode(cb, representation.loadField(a, "shape")) - val shape = Array.tabulate(nDims)(i => SizeValueDyn(shapeTuple.loadField(cb, i).get(cb).asLong.value)) + val shape = + Array.tabulate(nDims)(i => + SizeValueDyn(shapeTuple.loadField(cb, i).getOrAssert(cb).asLong.value) + ) val strideTuple = strideType.loadCheapSCode(cb, representation.loadField(a, "strides")) - val strides = Array.tabulate(nDims)(strideTuple.loadField(cb, _).get(cb).asLong.value) + val strides = Array.tabulate(nDims)(strideTuple.loadField(cb, _).getOrAssert(cb).asLong.value) val firstDataAddress = cb.memoize(dataFirstElementPointer(a)) new SNDArrayPointerValue(sType, a, shape, strides, firstDataAddress) } - def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] = { + def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] = { val addr = cb.memoize(this.representation.allocate(region)) storeAtAddress(cb, addr, region, value, 
deepCopy) addr } - def storeAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit = { + def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit = { val targetAddr = cb.newLocal[Long]("pcanonical_ndarray_store_at_addr_target", addr) val inputSNDValue = value.asNDArray val shape = inputSNDValue.shapes val strides = inputSNDValue.strides val dataAddr = inputSNDValue.firstDataAddress - shapeType.storeAtAddress(cb, cb.newLocal[Long]("construct_shape", this.representation.fieldOffset(targetAddr, "shape")), + shapeType.storeAtAddress( + cb, + cb.newLocal[Long]("construct_shape", this.representation.fieldOffset(targetAddr, "shape")), region, - SStackStruct.constructFromArgs(cb, region, shapeType.virtualType, shape.map(s => EmitCode.present(cb.emb, primitive(s))): _*), - false) - strideType.storeAtAddress(cb, cb.newLocal[Long]("construct_strides", this.representation.fieldOffset(targetAddr, "strides")), + SStackStruct.constructFromArgs( + cb, + region, + shapeType.virtualType, + shape.map(s => EmitCode.present(cb.emb, primitive(s))): _* + ), + false, + ) + strideType.storeAtAddress( + cb, + cb.newLocal[Long]( + "construct_strides", + this.representation.fieldOffset(targetAddr, "strides"), + ), region, - SStackStruct.constructFromArgs(cb, region, strideType.virtualType, strides.map(s => EmitCode.present(cb.emb, primitive(s))): _*), - false) + SStackStruct.constructFromArgs( + cb, + region, + strideType.virtualType, + strides.map(s => EmitCode.present(cb.emb, primitive(s))): _* + ), + false, + ) value.st match { case SNDArrayPointer(t) if t.equalModuloRequired(this) => @@ -385,8 +535,8 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo val newDataAddr = this.allocateData(shape, region) cb += Region.storeAddress(this.representation.fieldOffset(targetAddr, "data"), newDataAddr) val outputSNDValue = loadCheapSCode(cb, targetAddr) - outputSNDValue.coiterateMutate(cb, region, true, (inputSNDValue, "input")){ - case Seq(dest, elt) => + outputSNDValue.coiterateMutate(cb, region, true, (inputSNDValue, "input")) { + case Seq(_, elt) => elt } } @@ -395,19 +545,26 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo def unstagedDataFirstElementPointer(ndAddr: Long): Long = Region.loadAddress(representation.loadField(ndAddr, 2)) - override def dataFirstElementPointer(ndAddr: Code[Long]): Code[Long] = Region.loadAddress(representation.loadField(ndAddr, "data")) + override def dataFirstElementPointer(ndAddr: Code[Long]): Code[Long] = + Region.loadAddress(representation.loadField(ndAddr, "data")) def loadFromNested(addr: Code[Long]): Code[Long] = addr override def unstagedLoadFromNested(addr: Long): Long = addr - override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = { + override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region) + : Long = { val addr = this.representation.allocate(region) this.unstagedStoreJavaObjectAtAddress(sm, addr, annotation, region) addr } - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = { val aNDArray = annotation.asInstanceOf[NDArray] var runningProduct = this.elementType.byteSize 
@@ -418,16 +575,20 @@ final case class PCanonicalNDArray(elementType: PType, nDims: Int, required: Boo } val dataFirstElementAddress = this.allocateData(aNDArray.shape, region) var curElementAddress = dataFirstElementAddress - aNDArray.getRowMajorElements().foreach{ element => + aNDArray.getRowMajorElements().foreach { element => elementType.unstagedStoreJavaObjectAtAddress(sm, curElementAddress, element, region) curElementAddress += elementType.byteSize } val shapeRow = Row(aNDArray.shape: _*) val stridesRow = Row(stridesArray: _*) - this.representation.unstagedStoreJavaObjectAtAddress(sm, addr, Row(shapeRow, stridesRow, dataFirstElementAddress), region) + this.representation.unstagedStoreJavaObjectAtAddress( + sm, + addr, + Row(shapeRow, stridesRow, dataFirstElementAddress), + region, + ) } - override def copiedType: PType = { val copiedElement = elementType.copiedType if (copiedElement.eq(elementType)) diff --git a/hail/src/main/scala/is/hail/types/physical/PCanonicalSet.scala b/hail/src/main/scala/is/hail/types/physical/PCanonicalSet.scala index d1ec60b771d..0566932fc8f 100644 --- a/hail/src/main/scala/is/hail/types/physical/PCanonicalSet.scala +++ b/hail/src/main/scala/is/hail/types/physical/PCanonicalSet.scala @@ -8,22 +8,23 @@ import is.hail.types.virtual.{TSet, Type} import is.hail.utils._ object PCanonicalSet { - def coerceArrayCode(contents: SIndexableValue): SIndexableValue = { + def coerceArrayCode(contents: SIndexableValue): SIndexableValue = contents.st match { case SIndexablePointer(PCanonicalArray(elt, r)) => PCanonicalSet(elt, r).construct(contents) } - } } -final case class PCanonicalSet(elementType: PType, required: Boolean = false) extends PSet with PArrayBackedContainer { +final case class PCanonicalSet(elementType: PType, required: Boolean = false) + extends PSet with PArrayBackedContainer { val arrayRep = PCanonicalArray(elementType, required) - def setRequired(required: Boolean) = if (required == this.required) this else PCanonicalSet(elementType, required) + def setRequired(required: Boolean) = + if (required == this.required) this else PCanonicalSet(elementType, required) def _asIdent = s"set_of_${elementType.asIdent}" - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("PCSet[") elementType.pretty(sb, indent, compact) sb.append("]") @@ -34,7 +35,8 @@ final case class PCanonicalSet(elementType: PType, required: Boolean = false) e private def deepRenameSet(t: TSet) = PCanonicalSet(this.elementType.deepRename(t.elementType), this.required) - override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = { + override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region) + : Long = { val s: IndexedSeq[Annotation] = annotation.asInstanceOf[Set[Annotation]] .toFastSeq .sorted(elementType.virtualType.ordering(sm).toOrdering) @@ -43,7 +45,10 @@ final case class PCanonicalSet(elementType: PType, required: Boolean = false) e def construct(_contents: SIndexableValue): SIndexableValue = { val contents = _contents.asInstanceOf[SIndexablePointerValue] - assert(contents.pt.equalModuloRequired(arrayRep), s"\n contents: ${ contents.pt }\n arrayrep: ${ arrayRep }") + assert( + contents.pt.equalModuloRequired(arrayRep), + s"\n contents: ${contents.pt}\n arrayrep: $arrayRep", + ) val cont = contents.asInstanceOf[SIndexablePointerValue] new 
SIndexablePointerValue(SIndexablePointer(this), cont.a, cont.length, cont.elementsAddress) } diff --git a/hail/src/main/scala/is/hail/types/physical/PCanonicalString.scala b/hail/src/main/scala/is/hail/types/physical/PCanonicalString.scala index 8c49f452c19..ad595d374da 100644 --- a/hail/src/main/scala/is/hail/types/physical/PCanonicalString.scala +++ b/hail/src/main/scala/is/hail/types/physical/PCanonicalString.scala @@ -14,14 +14,27 @@ case object PCanonicalStringRequired extends PCanonicalString(true) class PCanonicalString(val required: Boolean) extends PString { def _asIdent = "string" - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = sb.append("PCString") + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = + sb.append("PCString") override def byteSize: Long = 8 lazy val binaryRepresentation: PCanonicalBinary = PCanonicalBinary(required) - override def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = - binaryRepresentation.copyFromAddress(sm, region, srcPType.asInstanceOf[PString].binaryRepresentation, srcAddress, deepCopy) + override def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = + binaryRepresentation.copyFromAddress( + sm, + region, + srcPType.asInstanceOf[PString].binaryRepresentation, + srcAddress, + deepCopy, + ) override def copiedType: PType = this @@ -46,7 +59,8 @@ class PCanonicalString(val required: Boolean) extends PString { dstAddrss } - def allocateAndStoreString(cb: EmitCodeBuilder, region: Value[Region], str: Code[String]): Value[Long] = { + def allocateAndStoreString(cb: EmitCodeBuilder, region: Value[Region], str: Code[String]) + : Value[Long] = { val dstAddress = cb.newField[Long]("pcanonical_string_alloc_dst_address") val byteRep = cb.newField[Array[Byte]]("pcanonical_string_alloc_byte_rep") cb.assign(byteRep, str.invoke[Array[Byte]]("getBytes")) @@ -55,8 +69,22 @@ class PCanonicalString(val required: Boolean) extends PString { dstAddress } - override def unstagedStoreAtAddress(sm: HailStateManager, addr: Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit = - binaryRepresentation.unstagedStoreAtAddress(sm, addr, region, srcPType.asInstanceOf[PString].binaryRepresentation, srcAddress, deepCopy) + override def unstagedStoreAtAddress( + sm: HailStateManager, + addr: Long, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit = + binaryRepresentation.unstagedStoreAtAddress( + sm, + addr, + region, + srcPType.asInstanceOf[PString].binaryRepresentation, + srcAddress, + deepCopy, + ) def setRequired(required: Boolean): PCanonicalString = if (required == this.required) this else PCanonicalString(required) @@ -66,7 +94,8 @@ class PCanonicalString(val required: Boolean) extends PString { def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SStringPointerValue = new SStringPointerValue(sType, cb.memoize(addr)) - def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] = { + def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] = { value.st match { case SStringPointer(t) if t.equalModuloRequired(this) && !deepCopy => value.asInstanceOf[SStringPointerValue].a @@ -75,25 +104,45 @@ class PCanonicalString(val required: Boolean) extends PString { } } - def storeAtAddress(cb: EmitCodeBuilder, addr: 
Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit = { + def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit = cb += Region.storeAddress(addr, store(cb, region, value, deepCopy)) - } def loadFromNested(addr: Code[Long]): Code[Long] = binaryRepresentation.loadFromNested(addr) - override def unstagedLoadFromNested(addr: Long): Long = binaryRepresentation.unstagedLoadFromNested(addr) - - override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = { - binaryRepresentation.unstagedStoreJavaObject(sm, annotation.asInstanceOf[String].getBytes(), region) - } - - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { - binaryRepresentation.unstagedStoreJavaObjectAtAddress(sm, addr, annotation.asInstanceOf[String].getBytes(), region) - } + override def unstagedLoadFromNested(addr: Long): Long = + binaryRepresentation.unstagedLoadFromNested(addr) + + override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region) + : Long = + binaryRepresentation.unstagedStoreJavaObject( + sm, + annotation.asInstanceOf[String].getBytes(), + region, + ) + + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = + binaryRepresentation.unstagedStoreJavaObjectAtAddress( + sm, + addr, + annotation.asInstanceOf[String].getBytes(), + region, + ) } object PCanonicalString { - def apply(required: Boolean = false): PCanonicalString = if (required) PCanonicalStringRequired else PCanonicalStringOptional + def apply(required: Boolean = false): PCanonicalString = + if (required) PCanonicalStringRequired else PCanonicalStringOptional def unapply(t: PString): Option[Boolean] = Option(t.required) } diff --git a/hail/src/main/scala/is/hail/types/physical/PCanonicalStruct.scala b/hail/src/main/scala/is/hail/types/physical/PCanonicalStruct.scala index 74912909e5e..48551900e98 100644 --- a/hail/src/main/scala/is/hail/types/physical/PCanonicalStruct.scala +++ b/hail/src/main/scala/is/hail/types/physical/PCanonicalStruct.scala @@ -14,39 +14,49 @@ object PCanonicalStruct { def empty(required: Boolean = false): PStruct = if (required) requiredEmpty else optionalEmpty def apply(required: Boolean, args: (String, PType)*): PCanonicalStruct = - PCanonicalStruct(args - .iterator - .zipWithIndex - .map { case ((n, t), i) => PField(n, t, i) } - .toFastSeq, - required) - - def apply(names: java.util.List[String], types: java.util.List[PType], required: Boolean): PCanonicalStruct = { + PCanonicalStruct( + args + .iterator + .zipWithIndex + .map { case ((n, t), i) => PField(n, t, i) } + .toFastSeq, + required, + ) + + def apply(names: java.util.List[String], types: java.util.List[PType], required: Boolean) + : PCanonicalStruct = { val sNames = names.asScala.toArray val sTypes = types.asScala.toArray if (sNames.length != sTypes.length) - fatal(s"number of names does not match number of types: found ${ sNames.length } names and ${ sTypes.length } types") + fatal( + s"number of names does not match number of types: found ${sNames.length} names and ${sTypes.length} types" + ) PCanonicalStruct(required, sNames.zip(sTypes): _*) } def apply(args: (String, PType)*): PCanonicalStruct = - PCanonicalStruct(false, args:_*) + PCanonicalStruct(false, args: _*) def canonical(t: Type): PCanonicalStruct = 
PType.canonical(t).asInstanceOf[PCanonicalStruct] def canonical(t: PType): PCanonicalStruct = PType.canonical(t).asInstanceOf[PCanonicalStruct] } -final case class PCanonicalStruct(fields: IndexedSeq[PField], required: Boolean = false) extends PCanonicalBaseStruct(fields.map(_.typ).toArray) with PStruct { - assert(fields.zipWithIndex.forall { case (f, i) => f.index == i }) +final case class PCanonicalStruct(fields: IndexedSeq[PField], required: Boolean = false) + extends PCanonicalBaseStruct(fields.map(_.typ).toArray) with PStruct { + assert(fields.zipWithIndex.forall { case (f, i) => f.index == i }) if (!fieldNames.areDistinct()) { val duplicates = fieldNames.duplicates() - fatal(s"cannot create struct with duplicate ${plural(duplicates.size, "field")}: " + - s"${fieldNames.map(prettyIdentifier).mkString(", ")}", fieldNames.duplicates()) + fatal( + s"cannot create struct with duplicate ${plural(duplicates.size, "field")}: " + + s"${fieldNames.map(prettyIdentifier).mkString(", ")}", + fieldNames.duplicates(), + ) } - override def setRequired(required: Boolean): PCanonicalStruct = if(required == this.required) this else PCanonicalStruct(fields, required) + override def setRequired(required: Boolean): PCanonicalStruct = + if (required == this.required) this else PCanonicalStruct(fields, required) override def rename(m: Map[String, String]): PStruct = { val newFieldsBuilder = new BoxedArrayBuilder[(String, PType)]() @@ -57,7 +67,7 @@ final case class PCanonicalStruct(fields: IndexedSeq[PField], required: Boolean PCanonicalStruct(required, newFieldsBuilder.result(): _*) } - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = { if (compact) { sb.append("PCStruct{") fields.foreachBetween(_.pretty(sb, indent, compact))(sb += ',') @@ -79,7 +89,8 @@ final case class PCanonicalStruct(fields: IndexedSeq[PField], required: Boolean override def loadField(offset: Code[Long], fieldName: String): Code[Long] = loadField(offset, fieldIdx(fieldName)) - override def isFieldMissing(cb: EmitCodeBuilder, offset: Code[Long], field: String): Value[Boolean] = + override def isFieldMissing(cb: EmitCodeBuilder, offset: Code[Long], field: String) + : Value[Boolean] = isFieldMissing(cb, offset, fieldIdx(field)) override def fieldOffset(offset: Code[Long], fieldName: String): Code[Long] = @@ -112,12 +123,14 @@ final case class PCanonicalStruct(fields: IndexedSeq[PField], required: Boolean override def deepRename(t: Type): PType = deepRenameStruct(t.asInstanceOf[TStruct]) - private def deepRenameStruct(t: TStruct): PStruct = { - PCanonicalStruct((t.fields, this.fields).zipped.map( (tfield, pfield) => { - assert(tfield.index == pfield.index) - PField(tfield.name, pfield.typ.deepRename(tfield.typ), pfield.index) - }), this.required) - } + private def deepRenameStruct(t: TStruct): PStruct = + PCanonicalStruct( + (t.fields, this.fields).zipped.map { (tfield, pfield) => + assert(tfield.index == pfield.index) + PField(tfield.name, pfield.typ.deepRename(tfield.typ), pfield.index) + }, + this.required, + ) override def copiedType: PType = { val copiedTypes = types.map(_.copiedType) diff --git a/hail/src/main/scala/is/hail/types/physical/PCanonicalTuple.scala b/hail/src/main/scala/is/hail/types/physical/PCanonicalTuple.scala index 887934895b4..1f3af38e5e3 100644 --- a/hail/src/main/scala/is/hail/types/physical/PCanonicalTuple.scala +++ b/hail/src/main/scala/is/hail/types/physical/PCanonicalTuple.scala @@ -1,19 +1,27 @@ package 
is.hail.types.physical -import is.hail.annotations.UnsafeUtils -import is.hail.types.BaseStruct + import is.hail.types.virtual.{TTuple, Type} import is.hail.utils._ object PCanonicalTuple { - def apply(required: Boolean, args: PType*): PCanonicalTuple = PCanonicalTuple(args.iterator.zipWithIndex.map { case (t, i) => PTupleField(i, t)}.toIndexedSeq, required) + def apply(required: Boolean, args: PType*): PCanonicalTuple = PCanonicalTuple( + args.iterator.zipWithIndex.map { case (t, i) => PTupleField(i, t) }.toIndexedSeq, + required, + ) } -final case class PCanonicalTuple(_types: IndexedSeq[PTupleField], override val required: Boolean = false) extends PCanonicalBaseStruct(_types.map(_.typ).toArray) with PTuple { - lazy val fieldIndex: Map[Int, Int] = _types.zipWithIndex.map { case (tf, idx) => tf.index -> idx }.toMap +final case class PCanonicalTuple( + _types: IndexedSeq[PTupleField], + override val required: Boolean = false, +) extends PCanonicalBaseStruct(_types.map(_.typ).toArray) with PTuple { + lazy val fieldIndex: Map[Int, Int] = _types.zipWithIndex.map { case (tf, idx) => + tf.index -> idx + }.toMap - def setRequired(required: Boolean) = if(required == this.required) this else PCanonicalTuple(_types, required) + def setRequired(required: Boolean) = + if (required == this.required) this else PCanonicalTuple(_types, required) - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = { sb.append("PCTuple[") _types.foreachBetween { fd => sb.append(fd.index) @@ -25,12 +33,14 @@ final case class PCanonicalTuple(_types: IndexedSeq[PTupleField], override val r override def deepRename(t: Type) = deepTupleRename(t.asInstanceOf[TTuple]) - private def deepTupleRename(t: TTuple) = { - PCanonicalTuple((t._types, this._types).zipped.map( (tfield, pfield) => { - assert(tfield.index == pfield.index) - PTupleField(pfield.index, pfield.typ.deepRename(tfield.typ)) - }), this.required) - } + private def deepTupleRename(t: TTuple) = + PCanonicalTuple( + (t._types, this._types).zipped.map { (tfield, pfield) => + assert(tfield.index == pfield.index) + PTupleField(pfield.index, pfield.typ.deepRename(tfield.typ)) + }, + this.required, + ) def copiedType: PType = { val copiedTypes = types.map(_.copiedType) diff --git a/hail/src/main/scala/is/hail/types/physical/PContainer.scala b/hail/src/main/scala/is/hail/types/physical/PContainer.scala index 21198f62b30..776a90a4d3b 100644 --- a/hail/src/main/scala/is/hail/types/physical/PContainer.scala +++ b/hail/src/main/scala/is/hail/types/physical/PContainer.scala @@ -25,11 +25,11 @@ abstract class PContainer extends PIterable { def isElementDefined(aoff: Code[Long], i: Code[Int]): Code[Boolean] - def setElementMissing(aoff: Long, i: Int) + def setElementMissing(aoff: Long, i: Int): Unit def setElementMissing(cb: EmitCodeBuilder, aoff: Code[Long], i: Code[Int]): Unit - def setElementPresent(aoff: Long, i: Int) + def setElementPresent(aoff: Long, i: Int): Unit def setElementPresent(cb: EmitCodeBuilder, aoff: Code[Long], i: Code[Int]): Unit @@ -63,13 +63,18 @@ abstract class PContainer extends PIterable { def allocate(region: Code[Region], length: Code[Int]): Code[Long] - def setAllMissingBits(aoff: Long, length: Int) + def setAllMissingBits(aoff: Long, length: Int): Unit - def clearMissingBits(aoff: Long, length: Int) + def clearMissingBits(aoff: Long, length: Int): Unit - def initialize(aoff: Long, length: Int, setMissing: Boolean = false) + def initialize(aoff: Long, 
length: Int, setMissing: Boolean = false): Unit - def stagedInitialize(cb: EmitCodeBuilder, aoff: Code[Long], length: Code[Int], setMissing: Boolean = false): Unit + def stagedInitialize( + cb: EmitCodeBuilder, + aoff: Code[Long], + length: Code[Int], + setMissing: Boolean = false, + ): Unit def zeroes(region: Region, length: Int): Long @@ -91,10 +96,10 @@ abstract class PContainer extends PIterable { } object PContainer { - def unsafeSetElementMissing(cb: EmitCodeBuilder, p: PContainer, aoff: Code[Long], i: Code[Int]): Unit = { + def unsafeSetElementMissing(cb: EmitCodeBuilder, p: PContainer, aoff: Code[Long], i: Code[Int]) + : Unit = if (p.elementType.required) cb._fatal("Missing element at index ", i.toS, s" of ptype ${p.elementType.asIdent}'.") else p.setElementMissing(cb, aoff, i) - } -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/types/physical/PField.scala b/hail/src/main/scala/is/hail/types/physical/PField.scala index bd8603c0c8a..72d702fe8fd 100644 --- a/hail/src/main/scala/is/hail/types/physical/PField.scala +++ b/hail/src/main/scala/is/hail/types/physical/PField.scala @@ -3,7 +3,7 @@ package is.hail.types.physical import is.hail.utils._ final case class PField(name: String, typ: PType, index: Int) { - def pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + def pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = { if (compact) { sb.append(prettyIdentifier(name)) sb.append(":") diff --git a/hail/src/main/scala/is/hail/types/physical/PFloat32.scala b/hail/src/main/scala/is/hail/types/physical/PFloat32.scala index 5924ef06c2a..ebd98990ec0 100644 --- a/hail/src/main/scala/is/hail/types/physical/PFloat32.scala +++ b/hail/src/main/scala/is/hail/types/physical/PFloat32.scala @@ -4,8 +4,8 @@ import is.hail.annotations._ import is.hail.asm4s.{Code, _} import is.hail.backend.HailStateManager import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.primitives.{SFloat32, SFloat32Value} import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.primitives.{SFloat32, SFloat32Value} import is.hail.types.virtual.TFloat32 case object PFloat32Optional extends PFloat32(false) @@ -18,25 +18,23 @@ class PFloat32(override val required: Boolean) extends PNumeric with PPrimitive def _asIdent = "float32" - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = sb.append("PFloat32") + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = + sb.append("PFloat32") override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = new UnsafeOrdering { - def compare(o1: Long, o2: Long): Int = { + def compare(o1: Long, o2: Long): Int = java.lang.Float.compare(Region.loadFloat(o1), Region.loadFloat(o2)) - } } override def byteSize: Long = 4 override def zero = coerce[PFloat32](const(0.0f)) - override def add(a: Code[_], b: Code[_]): Code[PFloat32] = { + override def add(a: Code[_], b: Code[_]): Code[PFloat32] = coerce[PFloat32](coerce[Float](a) + coerce[Float](b)) - } - override def multiply(a: Code[_], b: Code[_]): Code[PFloat32] = { + override def multiply(a: Code[_], b: Code[_]): Code[PFloat32] = coerce[PFloat32](coerce[Float](a) * coerce[Float](b)) - } override def sType: SType = SFloat32 @@ -46,13 +44,18 @@ class PFloat32(override val required: Boolean) extends PNumeric with PPrimitive override def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SFloat32Value = new SFloat32Value(cb.memoize(Region.loadFloat(addr))) - override def 
unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = Region.storeFloat(addr, annotation.asInstanceOf[Float]) - } } object PFloat32 { - def apply(required: Boolean = false): PFloat32 = if (required) PFloat32Required else PFloat32Optional + def apply(required: Boolean = false): PFloat32 = + if (required) PFloat32Required else PFloat32Optional def unapply(t: PFloat32): Option[Boolean] = Option(t.required) } diff --git a/hail/src/main/scala/is/hail/types/physical/PFloat64.scala b/hail/src/main/scala/is/hail/types/physical/PFloat64.scala index b748db1ce5e..9eaf21e7239 100644 --- a/hail/src/main/scala/is/hail/types/physical/PFloat64.scala +++ b/hail/src/main/scala/is/hail/types/physical/PFloat64.scala @@ -4,8 +4,8 @@ import is.hail.annotations._ import is.hail.asm4s.{Code, _} import is.hail.backend.HailStateManager import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.primitives.{SFloat64, SFloat64Value} import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.primitives.{SFloat64, SFloat64Value} import is.hail.types.virtual.TFloat64 case object PFloat64Optional extends PFloat64(false) @@ -19,25 +19,23 @@ class PFloat64(override val required: Boolean) extends PNumeric with PPrimitive def _asIdent = "float64" - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = sb.append("PFloat64") + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = + sb.append("PFloat64") override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = new UnsafeOrdering { - def compare(o1: Long, o2: Long): Int = { + def compare(o1: Long, o2: Long): Int = java.lang.Double.compare(Region.loadDouble(o1), Region.loadDouble(o2)) - } } override def byteSize: Long = 8 override def zero = coerce[PFloat64](const(0.0)) - override def add(a: Code[_], b: Code[_]): Code[PFloat64] = { + override def add(a: Code[_], b: Code[_]): Code[PFloat64] = coerce[PFloat64](coerce[Double](a) + coerce[Double](b)) - } - override def multiply(a: Code[_], b: Code[_]): Code[PFloat64] = { + override def multiply(a: Code[_], b: Code[_]): Code[PFloat64] = coerce[PFloat64](coerce[Double](a) * coerce[Double](b)) - } override def sType: SType = SFloat64 @@ -47,13 +45,18 @@ class PFloat64(override val required: Boolean) extends PNumeric with PPrimitive override def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SFloat64Value = new SFloat64Value(cb.memoize(Region.loadDouble(addr))) - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = Region.storeDouble(addr, annotation.asInstanceOf[Double]) - } } object PFloat64 { - def apply(required: Boolean = false): PFloat64 = if (required) PFloat64Required else PFloat64Optional + def apply(required: Boolean = false): PFloat64 = + if (required) PFloat64Required else PFloat64Optional def unapply(t: PFloat64): Option[Boolean] = Option(t.required) } diff --git a/hail/src/main/scala/is/hail/types/physical/PInt32.scala b/hail/src/main/scala/is/hail/types/physical/PInt32.scala index 0fdc55c271e..31f97ec1260 100644 --- a/hail/src/main/scala/is/hail/types/physical/PInt32.scala +++ 
b/hail/src/main/scala/is/hail/types/physical/PInt32.scala @@ -1,11 +1,11 @@ package is.hail.types.physical import is.hail.annotations._ -import is.hail.asm4s.{Code, coerce, const, _} +import is.hail.asm4s.{coerce, const, Code, _} import is.hail.backend.HailStateManager import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.primitives.{SInt32, SInt32Value} import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.primitives.{SInt32, SInt32Value} import is.hail.types.virtual.TInt32 case object PInt32Optional extends PInt32(false) @@ -18,22 +18,19 @@ class PInt32(override val required: Boolean) extends PNumeric with PPrimitive { override type NType = PInt32 override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = new UnsafeOrdering { - def compare(o1: Long, o2: Long): Int = { + def compare(o1: Long, o2: Long): Int = Integer.compare(Region.loadInt(o1), Region.loadInt(o2)) - } } override def byteSize: Long = 4 override def zero = coerce[PInt32](const(0)) - override def add(a: Code[_], b: Code[_]): Code[PInt32] = { + override def add(a: Code[_], b: Code[_]): Code[PInt32] = coerce[PInt32](coerce[Int](a) + coerce[Int](b)) - } - override def multiply(a: Code[_], b: Code[_]): Code[PInt32] = { + override def multiply(a: Code[_], b: Code[_]): Code[PInt32] = coerce[PInt32](coerce[Int](a) * coerce[Int](b)) - } override def sType: SType = SInt32 @@ -43,9 +40,13 @@ class PInt32(override val required: Boolean) extends PNumeric with PPrimitive { override def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SInt32Value = new SInt32Value(cb.memoize(Region.loadInt(addr))) - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = Region.storeInt(addr, annotation.asInstanceOf[Int]) - } def unstagedLoadFromAddress(addr: Long): Int = Region.loadInt(addr) } diff --git a/hail/src/main/scala/is/hail/types/physical/PInt64.scala b/hail/src/main/scala/is/hail/types/physical/PInt64.scala index 055b8dbbcfd..f84cc374d83 100644 --- a/hail/src/main/scala/is/hail/types/physical/PInt64.scala +++ b/hail/src/main/scala/is/hail/types/physical/PInt64.scala @@ -1,11 +1,11 @@ package is.hail.types.physical import is.hail.annotations._ -import is.hail.asm4s.{Code, coerce, const, _} +import is.hail.asm4s.{coerce, const, Code, _} import is.hail.backend.HailStateManager import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.primitives.{SInt64, SInt64Value} import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.primitives.{SInt64, SInt64Value} import is.hail.types.virtual.TInt64 case object PInt64Optional extends PInt64(false) @@ -19,22 +19,19 @@ class PInt64(override val required: Boolean) extends PNumeric with PPrimitive { override type NType = PInt64 override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = new UnsafeOrdering { - def compare(o1: Long, o2: Long): Int = { + def compare(o1: Long, o2: Long): Int = java.lang.Long.compare(Region.loadLong(o1), Region.loadLong(o2)) - } } override def byteSize: Long = 8 override def zero = coerce[PInt64](const(0L)) - override def add(a: Code[_], b: Code[_]): Code[PInt64] = { + override def add(a: Code[_], b: Code[_]): Code[PInt64] = coerce[PInt64](coerce[Long](a) + coerce[Long](b)) - } - override def multiply(a: Code[_], b: Code[_]): 
Code[PInt64] = { + override def multiply(a: Code[_], b: Code[_]): Code[PInt64] = coerce[PInt64](coerce[Long](a) * coerce[Long](b)) - } override def sType: SType = SInt64 @@ -44,9 +41,13 @@ class PInt64(override val required: Boolean) extends PNumeric with PPrimitive { override def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SInt64Value = new SInt64Value(cb.memoize(Region.loadLong(addr))) - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = { + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = Region.storeLong(addr, annotation.asInstanceOf[Long]) - } } object PInt64 { diff --git a/hail/src/main/scala/is/hail/types/physical/PInterval.scala b/hail/src/main/scala/is/hail/types/physical/PInterval.scala index c6772e656b9..0e9a7b1a927 100644 --- a/hail/src/main/scala/is/hail/types/physical/PInterval.scala +++ b/hail/src/main/scala/is/hail/types/physical/PInterval.scala @@ -3,10 +3,8 @@ package is.hail.types.physical import is.hail.annotations._ import is.hail.asm4s._ import is.hail.backend.HailStateManager -import is.hail.check.Gen import is.hail.expr.ir.EmitCodeBuilder import is.hail.types.virtual.TInterval -import is.hail.utils._ abstract class PInterval extends PType { val pointType: PType @@ -16,6 +14,7 @@ abstract class PInterval extends PType { override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = new UnsafeOrdering { private val pOrd = pointType.unsafeOrdering(sm) + def compare(o1: Long, o2: Long): Int = { val sdef1 = startDefined(o1) if (sdef1 == startDefined(o2)) { @@ -30,10 +29,13 @@ abstract class PInterval extends PType { val includesE1 = includesEnd(o1) if (includesE1 == includesEnd(o2)) { 0 - } else if (includesE1) 1 else -1 + } else if (includesE1) 1 + else -1 } else cmp - } else if (edef1) -1 else 1 - } else if (includesS1) -1 else 1 + } else if (edef1) -1 + else 1 + } else if (includesS1) -1 + else 1 } else cmp } else { if (sdef1) -1 else 1 @@ -44,6 +46,7 @@ abstract class PInterval extends PType { def endPrimaryUnsafeOrdering(sm: HailStateManager): UnsafeOrdering = new UnsafeOrdering { private val pOrd = pointType.unsafeOrdering(sm) + def compare(o1: Long, o2: Long): Int = { val edef1 = endDefined(o1) if (edef1 == endDefined(o2)) { @@ -58,10 +61,13 @@ abstract class PInterval extends PType { val includesS1 = includesStart(o1) if (includesS1 == includesStart(o2)) { 0 - } else if (includesS1) 1 else -1 + } else if (includesS1) 1 + else -1 } else cmp - } else if (sdef1) -1 else 1 - } else if (includesE1) -1 else 1 + } else if (sdef1) -1 + else 1 + } else if (includesE1) -1 + else 1 } else cmp } else { if (edef1) -1 else 1 diff --git a/hail/src/main/scala/is/hail/types/physical/PLocus.scala b/hail/src/main/scala/is/hail/types/physical/PLocus.scala index d57e9276ce2..6d676474f1d 100644 --- a/hail/src/main/scala/is/hail/types/physical/PLocus.scala +++ b/hail/src/main/scala/is/hail/types/physical/PLocus.scala @@ -4,7 +4,6 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.backend.HailStateManager import is.hail.types.virtual.TLocus -import is.hail.variant._ abstract class PLocus extends PType { lazy val virtualType: TLocus = TLocus(rg) @@ -21,5 +20,11 @@ abstract class PLocus extends PType { def positionType: PInt32 - def unstagedStoreLocus(sm: HailStateManager, addr: Long, contig: String, position: Int, region: Region): Unit + def unstagedStoreLocus( + sm: 
HailStateManager, + addr: Long, + contig: String, + position: Int, + region: Region, + ): Unit } diff --git a/hail/src/main/scala/is/hail/types/physical/PNDArray.scala b/hail/src/main/scala/is/hail/types/physical/PNDArray.scala index b23a4e661f9..a2693a67262 100644 --- a/hail/src/main/scala/is/hail/types/physical/PNDArray.scala +++ b/hail/src/main/scala/is/hail/types/physical/PNDArray.scala @@ -21,38 +21,43 @@ abstract class PNDArray extends PType { def dataFirstElementPointer(ndAddr: Code[Long]): Code[Long] def loadShape(off: Long, idx: Int): Long - def unstagedLoadShapes(addr: Long): IndexedSeq[Long] = { - (0 until nDims).map { dimIdx => - this.loadShape(addr, dimIdx) - } - } - - def loadShapes(cb: EmitCodeBuilder, addr: Value[Long], settables: IndexedSeq[Settable[Long]]): Unit - def loadStrides(cb: EmitCodeBuilder, addr: Value[Long], settables: IndexedSeq[Settable[Long]]): Unit + + def unstagedLoadShapes(addr: Long): IndexedSeq[Long] = + (0 until nDims).map(dimIdx => this.loadShape(addr, dimIdx)) + + def loadShapes(cb: EmitCodeBuilder, addr: Value[Long], settables: IndexedSeq[Settable[Long]]) + : Unit + + def loadStrides(cb: EmitCodeBuilder, addr: Value[Long], settables: IndexedSeq[Settable[Long]]) + : Unit + def unstagedLoadStrides(addr: Long): IndexedSeq[Long] def numElements(shape: IndexedSeq[Value[Long]]): Code[Long] - def makeRowMajorStrides(sourceShapeArray: IndexedSeq[Value[Long]], cb: EmitCodeBuilder): IndexedSeq[Value[Long]] + def makeRowMajorStrides(sourceShapeArray: IndexedSeq[Value[Long]], cb: EmitCodeBuilder) + : IndexedSeq[Value[Long]] - def makeColumnMajorStrides(sourceShapeArray: IndexedSeq[Value[Long]], cb: EmitCodeBuilder): IndexedSeq[Value[Long]] + def makeColumnMajorStrides(sourceShapeArray: IndexedSeq[Value[Long]], cb: EmitCodeBuilder) + : IndexedSeq[Value[Long]] def getElementAddress(indices: IndexedSeq[Long], nd: Long): Long - def loadElement(cb: EmitCodeBuilder, indices: IndexedSeq[Value[Long]], ndAddress: Value[Long]): SValue + def loadElement(cb: EmitCodeBuilder, indices: IndexedSeq[Value[Long]], ndAddress: Value[Long]) + : SValue def constructByCopyingArray( shape: IndexedSeq[Value[Long]], strides: IndexedSeq[Value[Long]], data: SIndexableValue, cb: EmitCodeBuilder, - region: Value[Region] + region: Value[Region], ): SNDArrayValue def constructDataFunction( shape: IndexedSeq[Value[Long]], strides: IndexedSeq[Value[Long]], cb: EmitCodeBuilder, - region: Value[Region] + region: Value[Region], ): (Value[Long], EmitCodeBuilder => SNDArrayPointerValue) } diff --git a/hail/src/main/scala/is/hail/types/physical/PPrimitive.scala b/hail/src/main/scala/is/hail/types/physical/PPrimitive.scala index 1e934c73341..5ad5b64a874 100644 --- a/hail/src/main/scala/is/hail/types/physical/PPrimitive.scala +++ b/hail/src/main/scala/is/hail/types/physical/PPrimitive.scala @@ -4,7 +4,7 @@ import is.hail.annotations.{Annotation, Region} import is.hail.asm4s.{Code, _} import is.hail.backend.HailStateManager import is.hail.expr.ir.{EmitCodeBuilder, EmitMethodBuilder} -import is.hail.types.physical.stypes.{SCode, SValue} +import is.hail.types.physical.stypes.SValue import is.hail.utils._ trait PPrimitive extends PType { @@ -14,7 +14,13 @@ trait PPrimitive extends PType { override def containsPointers: Boolean = false - def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = { + def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = { if 
(!deepCopy) return srcAddress @@ -24,26 +30,38 @@ trait PPrimitive extends PType { addr } - - def unstagedStoreAtAddress(sm: HailStateManager, addr: Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit = { + def unstagedStoreAtAddress( + sm: HailStateManager, + addr: Long, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit = { assert(srcPType.isOfType(this)) Region.copyFrom(srcAddress, addr, byteSize) } - def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] = { + def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] = { val newAddr = cb.memoize(region.allocate(alignment, byteSize)) storeAtAddress(cb, newAddr, region, value, deepCopy) newAddr } - - override def storeAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit = { + override def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit = storePrimitiveAtAddress(cb, addr, value) - } def storePrimitiveAtAddress(cb: EmitCodeBuilder, addr: Code[Long], value: SValue): Unit - override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = { + override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region) + : Long = { val addr = region.allocate(this.byteSize) unstagedStoreJavaObjectAtAddress(sm, addr, annotation, region) addr diff --git a/hail/src/main/scala/is/hail/types/physical/PSet.scala b/hail/src/main/scala/is/hail/types/physical/PSet.scala index af69359851f..810bae84829 100644 --- a/hail/src/main/scala/is/hail/types/physical/PSet.scala +++ b/hail/src/main/scala/is/hail/types/physical/PSet.scala @@ -8,5 +8,6 @@ import is.hail.types.virtual.TSet abstract class PSet extends PContainer { lazy val virtualType: TSet = TSet(elementType.virtualType) - override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = Gen.buildableOf[Set](elementType.genValue(sm)) + override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = + Gen.buildableOf[Set](elementType.genValue(sm)) } diff --git a/hail/src/main/scala/is/hail/types/physical/PString.scala b/hail/src/main/scala/is/hail/types/physical/PString.scala index f12e65ce59c..66bf63b6ed3 100644 --- a/hail/src/main/scala/is/hail/types/physical/PString.scala +++ b/hail/src/main/scala/is/hail/types/physical/PString.scala @@ -9,7 +9,8 @@ import is.hail.types.virtual.TString abstract class PString extends PType { lazy val virtualType: TString.type = TString - override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = PCanonicalBinary(required).unsafeOrdering(sm) + override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = + PCanonicalBinary(required).unsafeOrdering(sm) val binaryRepresentation: PBinary @@ -23,5 +24,6 @@ abstract class PString extends PType { def allocateAndStoreString(region: Region, str: String): Long - def allocateAndStoreString(cb: EmitCodeBuilder, region: Value[Region], str: Code[String]): Value[Long] + def allocateAndStoreString(cb: EmitCodeBuilder, region: Value[Region], str: Code[String]) + : Value[Long] } diff --git a/hail/src/main/scala/is/hail/types/physical/PStruct.scala b/hail/src/main/scala/is/hail/types/physical/PStruct.scala index e7965bd9132..14397e2f1e3 100644 --- a/hail/src/main/scala/is/hail/types/physical/PStruct.scala +++ 
b/hail/src/main/scala/is/hail/types/physical/PStruct.scala @@ -35,15 +35,19 @@ trait PStruct extends PBaseStruct { def identBase: String = "struct" - final def selectFields(names: Seq[String]): PCanonicalStruct = PCanonicalStruct(required, names.map(f => f -> field(f).typ): _*) + final def selectFields(names: Seq[String]): PCanonicalStruct = + PCanonicalStruct(required, names.map(f => f -> field(f).typ): _*) - final def dropFields(names: Set[String]): PCanonicalStruct = selectFields(fieldNames.filter(!names.contains(_))) + final def dropFields(names: Set[String]): PCanonicalStruct = + selectFields(fieldNames.filter(!names.contains(_))) - final def typeAfterSelect(keep: IndexedSeq[Int]): PCanonicalStruct = PCanonicalStruct(required, keep.map(i => fieldNames(i) -> types(i)): _*) + final def typeAfterSelect(keep: IndexedSeq[Int]): PCanonicalStruct = + PCanonicalStruct(required, keep.map(i => fieldNames(i) -> types(i)): _*) def loadField(offset: Code[Long], fieldName: String): Code[Long] - final def isFieldDefined(cb: EmitCodeBuilder, offset: Code[Long], fieldName: String): Value[Boolean] = + final def isFieldDefined(cb: EmitCodeBuilder, offset: Code[Long], fieldName: String) + : Value[Boolean] = cb.memoize(!isFieldMissing(cb, offset, fieldName)) def isFieldMissing(cb: EmitCodeBuilder, offset: Code[Long], fieldName: String): Value[Boolean] diff --git a/hail/src/main/scala/is/hail/types/physical/PSubsetStruct.scala b/hail/src/main/scala/is/hail/types/physical/PSubsetStruct.scala index 5b87ead6ab1..f575965c369 100644 --- a/hail/src/main/scala/is/hail/types/physical/PSubsetStruct.scala +++ b/hail/src/main/scala/is/hail/types/physical/PSubsetStruct.scala @@ -5,7 +5,7 @@ import is.hail.asm4s.{Code, Value} import is.hail.backend.HailStateManager import is.hail.expr.ir.EmitCodeBuilder import is.hail.types.physical.stypes.SValue -import is.hail.types.physical.stypes.concrete.SSubsetStruct +import is.hail.types.physical.stypes.concrete.SStructView import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructValue} import is.hail.types.virtual.TStruct import is.hail.utils._ @@ -20,7 +20,10 @@ object PSubsetStruct { // Semantics: PSubsetStruct is a non-constructible view of another PStruct, which is not allowed to mutate // that underlying PStruct's region data final case class PSubsetStruct(ps: PStruct, _fieldNames: IndexedSeq[String]) extends PStruct { - val fields: IndexedSeq[PField] = _fieldNames.zipWithIndex.map { case (name, i) => PField(name, ps.fieldType(name), i)} + val fields: IndexedSeq[PField] = _fieldNames.zipWithIndex.map { case (name, i) => + PField(name, ps.fieldType(name), i) + } + val required = ps.required if (fields == ps.fields) { @@ -32,12 +35,12 @@ final case class PSubsetStruct(ps: PStruct, _fieldNames: IndexedSeq[String]) ext lazy val missingIdx: Array[Int] = idxMap.map(i => ps.missingIdx(i)) lazy val nMissing: Int = missingIdx.length - override lazy val virtualType = TStruct(fields.map(f => (f.name -> f.typ.virtualType)):_*) + override lazy val virtualType = TStruct(fields.map(f => (f.name -> f.typ.virtualType)): _*) override val types: Array[PType] = fields.map(_.typ).toArray override val byteSize: Long = 8 - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = { sb.append("PSubsetStruct{") ps.pretty(sb, indent, compact) sb += '{' @@ -53,7 +56,8 @@ final case class PSubsetStruct(ps: PStruct, _fieldNames: IndexedSeq[String]) ext PSubsetStruct(newPStruct, newNames) } - 
override def isFieldMissing(cb: EmitCodeBuilder, structAddress: Code[Long], fieldName: String): Value[Boolean] = + override def isFieldMissing(cb: EmitCodeBuilder, structAddress: Code[Long], fieldName: String) + : Value[Boolean] = ps.isFieldMissing(cb, structAddress, fieldName) override def fieldOffset(structAddress: Code[Long], fieldName: String): Code[Long] = @@ -62,7 +66,8 @@ final case class PSubsetStruct(ps: PStruct, _fieldNames: IndexedSeq[String]) ext override def isFieldDefined(structAddress: Long, fieldIdx: Int): Boolean = ps.isFieldDefined(structAddress, idxMap(fieldIdx)) - override def isFieldMissing(cb: EmitCodeBuilder, structAddress: Code[Long], fieldIdx: Int): Value[Boolean] = + override def isFieldMissing(cb: EmitCodeBuilder, structAddress: Code[Long], fieldIdx: Int) + : Value[Boolean] = ps.isFieldMissing(cb, structAddress, idxMap(fieldIdx)) override def fieldOffset(structAddress: Long, fieldIdx: Int): Long = @@ -80,24 +85,29 @@ final case class PSubsetStruct(ps: PStruct, _fieldNames: IndexedSeq[String]) ext override def loadField(structAddress: Code[Long], fieldIdx: Int): Code[Long] = ps.loadField(structAddress, idxMap(fieldIdx)) - override def setFieldPresent(cb: EmitCodeBuilder, structAddress: Code[Long], fieldName: String): Unit = ??? + override def setFieldPresent(cb: EmitCodeBuilder, structAddress: Code[Long], fieldName: String) + : Unit = ??? - override def setFieldMissing(cb: EmitCodeBuilder, structAddress: Code[Long], fieldName: String): Unit = ??? + override def setFieldMissing(cb: EmitCodeBuilder, structAddress: Code[Long], fieldName: String) + : Unit = ??? override def setFieldMissing(structAddress: Long, fieldIdx: Int): Unit = ??? - override def setFieldMissing(cb: EmitCodeBuilder, structAddress: Code[Long], fieldIdx: Int): Unit = ??? + override def setFieldMissing(cb: EmitCodeBuilder, structAddress: Code[Long], fieldIdx: Int) + : Unit = ??? override def setFieldPresent(structAddress: Long, fieldIdx: Int): Unit = ??? - override def setFieldPresent(cb: EmitCodeBuilder, structAddress: Code[Long], fieldIdx: Int): Unit = ??? + override def setFieldPresent(cb: EmitCodeBuilder, structAddress: Code[Long], fieldIdx: Int) + : Unit = ??? def insertFields(fieldsToInsert: TraversableOnce[(String, PType)]): PSubsetStruct = ??? 
override def initialize(structAddress: Long, setMissing: Boolean): Unit = ps.initialize(structAddress, setMissing) - override def stagedInitialize(cb: EmitCodeBuilder, structAddress: Code[Long], setMissing: Boolean): Unit = + override def stagedInitialize(cb: EmitCodeBuilder, structAddress: Code[Long], setMissing: Boolean) + : Unit = ps.stagedInitialize(cb, structAddress, setMissing) def allocate(region: Region): Long = @@ -109,36 +119,67 @@ final case class PSubsetStruct(ps: PStruct, _fieldNames: IndexedSeq[String]) ext override def setRequired(required: Boolean): PType = PSubsetStruct(ps.setRequired(required).asInstanceOf[PStruct], _fieldNames) - override def copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = + override def copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = throw new UnsupportedOperationException - override def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = + override def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = throw new UnsupportedOperationException - def sType: SSubsetStruct = SSubsetStruct(ps.sType.asInstanceOf[SBaseStruct], _fieldNames) + def sType: SBaseStruct = + SStructView.subset(_fieldNames, ps.sType) - def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] = + def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] = throw new UnsupportedOperationException - def storeAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit = { + def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit = throw new UnsupportedOperationException - } def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SBaseStructValue = throw new UnsupportedOperationException - def unstagedStoreAtAddress(sm: HailStateManager, addr: Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit = { + def unstagedStoreAtAddress( + sm: HailStateManager, + addr: Long, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit = throw new UnsupportedOperationException - } def loadFromNested(addr: Code[Long]): Code[Long] = addr override def unstagedLoadFromNested(addr: Long): Long = addr - override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = + override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region) + : Long = throw new UnsupportedOperationException - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = throw new UnsupportedOperationException override def copiedType: PType = ??? 
// PSubsetStruct on its way out diff --git a/hail/src/main/scala/is/hail/types/physical/PTuple.scala b/hail/src/main/scala/is/hail/types/physical/PTuple.scala index a30a94d0c64..58310f736fa 100644 --- a/hail/src/main/scala/is/hail/types/physical/PTuple.scala +++ b/hail/src/main/scala/is/hail/types/physical/PTuple.scala @@ -10,7 +10,10 @@ trait PTuple extends PBaseStruct { lazy val virtualType: TTuple = TTuple(_types.map(tf => TupleField(tf.index, tf.typ.virtualType))) - lazy val fields: IndexedSeq[PField] = _types.zipWithIndex.map { case (PTupleField(tidx, t), i) => PField(s"$tidx", t, i) } + lazy val fields: IndexedSeq[PField] = _types.zipWithIndex.map { case (PTupleField(tidx, t), i) => + PField(s"$tidx", t, i) + } + lazy val nFields: Int = fields.size def identBase: String = "tuple" diff --git a/hail/src/main/scala/is/hail/types/physical/PType.scala b/hail/src/main/scala/is/hail/types/physical/PType.scala index 2ecabd1d1e5..533766fad94 100644 --- a/hail/src/main/scala/is/hail/types/physical/PType.scala +++ b/hail/src/main/scala/is/hail/types/physical/PType.scala @@ -5,28 +5,42 @@ import is.hail.asm4s._ import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.check.{Arbitrary, Gen} import is.hail.expr.ir._ -import is.hail.types.physical.stypes.concrete.SRNGState +import is.hail.types.{tcoerce, Requiredness} import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.concrete.SRNGState import is.hail.types.virtual._ -import is.hail.types.{Requiredness, tcoerce} import is.hail.utils._ import is.hail.variant.ReferenceGenome + import org.apache.spark.sql.Row import org.json4s.CustomSerializer import org.json4s.JsonAST.JString -class PTypeSerializer extends CustomSerializer[PType](format => ( - { case JString(s) => PType.canonical(IRParser.parsePType(s)) }, - { case t: PType => JString(t.toString) })) +class PTypeSerializer extends CustomSerializer[PType](format => + ( + { case JString(s) => PType.canonical(IRParser.parsePType(s)) }, + { case t: PType => JString(t.toString) }, + ) + ) -class PStructSerializer extends CustomSerializer[PStruct](format => ( - { case JString(s) => tcoerce[PStruct](IRParser.parsePType(s)) }, - { case t: PStruct => JString(t.toString) })) +class PStructSerializer extends CustomSerializer[PStruct](format => + ( + { case JString(s) => tcoerce[PStruct](IRParser.parsePType(s)) }, + { case t: PStruct => JString(t.toString) }, + ) + ) object PType { def genScalar(required: Boolean): Gen[PType] = - Gen.oneOf(PBoolean(required), PInt32(required), PInt64(required), PFloat32(required), - PFloat64(required), PCanonicalString(required), PCanonicalCall(required)) + Gen.oneOf( + PBoolean(required), + PInt32(required), + PInt64(required), + PFloat32(required), + PFloat64(required), + PCanonicalString(required), + PCanonicalCall(required), + ) val genOptionalScalar: Gen[PType] = genScalar(false) @@ -40,24 +54,24 @@ object PType { def genFields(required: Boolean, genFieldType: Gen[PType]): Gen[Array[PField]] = { Gen.buildableOf[Array]( - Gen.zip(Gen.identifier, genFieldType)) + Gen.zip(Gen.identifier, genFieldType) + ) .filter(fields => fields.map(_._1).areDistinct()) - .map(fields => fields - .iterator - .zipWithIndex - .map { case ((k, t), i) => PField(k, t, i) } - .toArray) + .map(fields => + fields + .iterator + .zipWithIndex + .map { case ((k, t), i) => PField(k, t, i) } + .toArray + ) } - def preGenStruct(required: Boolean, genFieldType: Gen[PType]): Gen[PStruct] = { - for (fields <- genFields(required, genFieldType)) yield - 
PCanonicalStruct(fields, required) - } + def preGenStruct(required: Boolean, genFieldType: Gen[PType]): Gen[PStruct] = + for (fields <- genFields(required, genFieldType)) yield PCanonicalStruct(fields, required) - def preGenTuple(required: Boolean, genFieldType: Gen[PType]): Gen[PTuple] = { - for (fields <- genFields(required, genFieldType)) yield - PCanonicalTuple(required, fields.map(_.typ): _*) - } + def preGenTuple(required: Boolean, genFieldType: Gen[PType]): Gen[PTuple] = + for (fields <- genFields(required, genFieldType)) + yield PCanonicalTuple(required, fields.map(_.typ): _*) private val defaultRequiredGenRatio = 0.2 @@ -71,7 +85,8 @@ object PType { if (required) preGenStruct(required = true, genArb) else - preGenStruct(required = false, genOptional)) + preGenStruct(required = false, genOptional) + ) def genSized(size: Int, required: Boolean, genPStruct: Gen[PStruct]): Gen[PType] = if (size < 1) @@ -82,18 +97,28 @@ object PType { Gen.frequency( (4, genScalar(required)), (1, genComplexType(required)), - (1, genArb.map { - PCanonicalArray(_) - }), - (1, genArb.map { - PCanonicalSet(_) - }), - (1, genArb.map { - PCanonicalInterval(_) - }), + ( + 1, + genArb.map { + PCanonicalArray(_) + }, + ), + ( + 1, + genArb.map { + PCanonicalSet(_) + }, + ), + ( + 1, + genArb.map { + PCanonicalInterval(_) + }, + ), (1, preGenTuple(required, genArb)), (1, Gen.zip(genRequired, genArb).map { case (k, v) => PCanonicalDict(k, v) }), - (1, genPStruct.resize(size))) + (1, genPStruct.resize(size)), + ) } def preGenArb(required: Boolean, genStruct: Gen[PStruct] = genStruct): Gen[PType] = @@ -121,13 +146,34 @@ object PType { case TCall => PCanonicalCall(required) case TRNGState => StoredSTypePType(SRNGState(None), required) case t: TLocus => PCanonicalLocus(t.rg, required) - case t: TInterval => PCanonicalInterval(canonical(t.pointType, innerRequired, innerRequired), required) - case t: TArray => PCanonicalArray(canonical(t.elementType, innerRequired, innerRequired), required) - case t: TSet => PCanonicalSet(canonical(t.elementType, innerRequired, innerRequired), required) - case t: TDict => PCanonicalDict(canonical(t.keyType, innerRequired, innerRequired), canonical(t.valueType, innerRequired, innerRequired), required) - case t: TTuple => PCanonicalTuple(t._types.map(tf => PTupleField(tf.index, canonical(tf.typ, innerRequired, innerRequired))), required) - case t: TStruct => PCanonicalStruct(t.fields.map(f => PField(f.name, canonical(f.typ, innerRequired, innerRequired), f.index)), required) - case t: TNDArray => PCanonicalNDArray(canonical(t.elementType, innerRequired, innerRequired).setRequired(true), t.nDims, required) + case t: TInterval => + PCanonicalInterval(canonical(t.pointType, innerRequired, innerRequired), required) + case t: TArray => + PCanonicalArray(canonical(t.elementType, innerRequired, innerRequired), required) + case t: TSet => + PCanonicalSet(canonical(t.elementType, innerRequired, innerRequired), required) + case t: TDict => PCanonicalDict( + canonical(t.keyType, innerRequired, innerRequired), + canonical(t.valueType, innerRequired, innerRequired), + required, + ) + case t: TTuple => PCanonicalTuple( + t._types.map(tf => + PTupleField(tf.index, canonical(tf.typ, innerRequired, innerRequired)) + ), + required, + ) + case t: TStruct => PCanonicalStruct( + t.fields.map(f => + PField(f.name, canonical(f.typ, innerRequired, innerRequired), f.index) + ), + required, + ) + case t: TNDArray => PCanonicalNDArray( + canonical(t.elementType, innerRequired, innerRequired).setRequired(true), + 
t.nDims, + required, + ) case TVoid => PVoid } } @@ -151,8 +197,10 @@ object PType { case t: PInterval => PCanonicalInterval(canonical(t.pointType), t.required) case t: PArray => PCanonicalArray(canonical(t.elementType), t.required) case t: PSet => PCanonicalSet(canonical(t.elementType), t.required) - case t: PTuple => PCanonicalTuple(t._types.map(pf => PTupleField(pf.index, canonical(pf.typ))), t.required) - case t: PStruct => PCanonicalStruct(t.fields.map(f => PField(f.name, canonical(f.typ), f.index)), t.required) + case t: PTuple => + PCanonicalTuple(t._types.map(pf => PTupleField(pf.index, canonical(pf.typ))), t.required) + case t: PStruct => + PCanonicalStruct(t.fields.map(f => PField(f.name, canonical(f.typ), f.index)), t.required) case t: PNDArray => PCanonicalNDArray(canonical(t.elementType), t.nDims, t.required) case t: PDict => PCanonicalDict(canonical(t.keyType), canonical(t.valueType), t.required) case PVoid => PVoid @@ -240,9 +288,7 @@ object PType { case t: TNDArray => val r = a.asInstanceOf[Row] val elems = r(2).asInstanceOf[IndexedSeq[_]] - elems.foreach { x => - setOptional(t.elementType, x, ri + 1, ci) - } + elems.foreach(x => setOptional(t.elementType, x, ri + 1, ci)) case t: TBaseStruct => val r = a.asInstanceOf[Row] val n = r.size @@ -269,7 +315,8 @@ object PType { PCanonicalDict( canonical(t.keyType, ri + 1, ci + 1), canonical(t.valueType, childRequiredIndex(ci), childIndex(ci)), - requiredVector(ri)) + requiredVector(ri), + ) case t: TArray => PCanonicalArray(canonical(t.elementType, ri + 1, ci), requiredVector(ri)) case t: TStream => @@ -286,22 +333,27 @@ object PType { PCanonicalNDArray(canonical(t.elementType, ri + 1, ci), t.nDims, requiredVector(ri)) case TString => PCanonicalString(requiredVector(ri)) case t: TStruct => - PCanonicalStruct(requiredVector(ri), + PCanonicalStruct( + requiredVector(ri), t.fields.zipWithIndex.map { case (f, j) => f.name -> canonical(f.typ, childRequiredIndex(ci + j), childIndex(ci + j)) - }: _*) + }: _* + ) case t: TTuple => - PCanonicalTuple(requiredVector(ri), + PCanonicalTuple( + requiredVector(ri), t.types.zipWithIndex.map { case (ft, j) => canonical(ft, childRequiredIndex(ci + j), childIndex(ci + j)) - }: _*) + }: _* + ) } } canonical(t, 0, 0) } - def canonicalize(t: PType, ctx: ExecuteContext, path: List[String]): Option[(HailClassLoader) => AsmFunction2RegionLongLong] = { + def canonicalize(t: PType, ctx: ExecuteContext, path: List[String]) + : Option[(HailClassLoader) => AsmFunction2RegionLongLong] = { def canonicalPath(pt: PType, path: List[String]): PType = { if (path.isEmpty) { PType.canonical(pt) @@ -309,9 +361,12 @@ object PType { val head :: tail = path pt match { - case t@PCanonicalStruct(fields, required) => + case t @ PCanonicalStruct(fields, required) => assert(t.hasField(head)) - PCanonicalStruct(fields.map(f => if (f.name == head) f.copy(typ = canonicalPath(f.typ, tail)) else f), required) + PCanonicalStruct( + fields.map(f => if (f.name == head) f.copy(typ = canonicalPath(f.typ, tail)) else f), + required, + ) case PCanonicalArray(element, required) => assert(head == "element") PCanonicalArray(canonicalPath(element, tail), required) @@ -324,9 +379,12 @@ object PType { if (cpt == t) None else { - val fb = EmitFunctionBuilder[AsmFunction2RegionLongLong](ctx, + val fb = EmitFunctionBuilder[AsmFunction2RegionLongLong]( + ctx, "copyFromAddr", - FastSeq[ParamType](classInfo[Region], LongInfo), LongInfo) + FastSeq[ParamType](classInfo[Region], LongInfo), + LongInfo, + ) fb.emitWithBuilder { cb => val region = 
fb.apply_method.getCodeParam[Region](1) @@ -342,7 +400,10 @@ abstract class PType extends Serializable with Requiredness { self => def genValue(sm: HailStateManager): Gen[Annotation] = - if (required) genNonmissingValue(sm) else Gen.nextCoin(0.05).flatMap(isEmpty => if (isEmpty) Gen.const(null) else genNonmissingValue(sm)) + if (required) genNonmissingValue(sm) + else Gen.nextCoin(0.05).flatMap(isEmpty => + if (isEmpty) Gen.const(null) else genNonmissingValue(sm) + ) def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = virtualType.genNonmissingValue(sm) @@ -360,7 +421,8 @@ abstract class PType extends Serializable with Requiredness { def unsafeOrdering(sm: HailStateManager): UnsafeOrdering - def isCanonical: Boolean = PType.canonical(this) == this // will recons, may need to rewrite this method + def isCanonical: Boolean = + PType.canonical(this) == this // will recons, may need to rewrite this method def unsafeOrdering(sm: HailStateManager, rightType: PType): UnsafeOrdering = { require(virtualType == rightType.virtualType, s"$this, $rightType") @@ -371,13 +433,13 @@ abstract class PType extends Serializable with Requiredness { def _asIdent: String - final def pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + final def pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = { if (required) sb.append("+") _pretty(sb, indent, compact) } - def _pretty(sb: StringBuilder, indent: Int, compact: Boolean) + def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit def byteSize: Long @@ -391,12 +453,11 @@ abstract class PType extends Serializable with Requiredness { def equalModuloRequired(that: PType): Boolean = this == that.setRequired(required) - final def orMissing(required2: Boolean): PType = { + final def orMissing(required2: Boolean): PType = if (!required2) setRequired(false) else this - } final def isOfType(t: PType): Boolean = this.virtualType == t.virtualType @@ -415,13 +476,25 @@ abstract class PType extends Serializable with Requiredness { def subsetTo(t: Type): PType = { this match { - case x@PCanonicalStruct(fields, r) => + case x @ PCanonicalStruct(fields, r) => val ts = t.asInstanceOf[TStruct] assert(ts.fieldNames.forall(x.fieldNames.contains)) - PCanonicalStruct(r, fields.flatMap { pf => ts.selfField(pf.name).map { vf => (pf.name, pf.typ.subsetTo(vf.typ)) } }: _*) + PCanonicalStruct( + r, + fields.flatMap { pf => + ts.selfField(pf.name).map(vf => (pf.name, pf.typ.subsetTo(vf.typ))) + }: _* + ) case PCanonicalTuple(fields, r) => val tt = t.asInstanceOf[TTuple] - PCanonicalTuple(fields.flatMap { pf => tt.fieldIndex.get(pf.index).map(vi => PTupleField(pf.index, pf.typ.subsetTo(tt.types(vi)))) }, r) + PCanonicalTuple( + fields.flatMap { pf => + tt.fieldIndex.get(pf.index).map(vi => + PTupleField(pf.index, pf.typ.subsetTo(tt.types(vi))) + ) + }, + r, + ) case PCanonicalArray(e, r) => val ta = t.asInstanceOf[TArray] PCanonicalArray(e.subsetTo(ta.elementType), r) @@ -440,29 +513,60 @@ abstract class PType extends Serializable with Requiredness { } } - protected[physical] def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long - - def copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = { + protected[physical] def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long + + def copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: 
PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = { // no requirement for requiredness // this can have more/less requiredness than srcPType // if value is not compatible with this, an exception will be thrown (virtualType, srcPType.virtualType) match { case (l: TBaseStruct, r: TBaseStruct) => assert(l.isCompatibleWith(r)) - case _ => assert(virtualType == srcPType.virtualType, s"virtualType: ${virtualType} != srcPType.virtualType: ${srcPType.virtualType}") + case _ => assert( + virtualType == srcPType.virtualType, + s"virtualType: $virtualType != srcPType.virtualType: ${srcPType.virtualType}", + ) } _copyFromAddress(sm, region, srcPType, srcAddress, deepCopy) } - // return a SCode that can cheaply operate on the region representation. Generally a pointer type, but not necessarily (e.g. primitives). + /* return a SCode that can cheaply operate on the region representation. Generally a pointer type, + * but not necessarily (e.g. primitives). */ def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SValue // stores a stack value as a region value of this type - def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] - - // stores a stack value inside pre-allocated memory of this type (in a nested structure, for instance). - def storeAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit - - def unstagedStoreAtAddress(sm: HailStateManager, addr: Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit + def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] + + /* stores a stack value inside pre-allocated memory of this type (in a nested structure, for + * instance). 
*/ + def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit + + def unstagedStoreAtAddress( + sm: HailStateManager, + addr: Long, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit def deepRename(t: Type): PType = this @@ -474,5 +578,10 @@ abstract class PType extends Serializable with Requiredness { def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long - def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit + def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit } diff --git a/hail/src/main/scala/is/hail/types/physical/PUnrealizable.scala b/hail/src/main/scala/is/hail/types/physical/PUnrealizable.scala index 2e0e912a751..0219ea824af 100644 --- a/hail/src/main/scala/is/hail/types/physical/PUnrealizable.scala +++ b/hail/src/main/scala/is/hail/types/physical/PUnrealizable.scala @@ -1,11 +1,10 @@ package is.hail.types.physical import is.hail.annotations.{Annotation, Region} -import is.hail.asm4s.{Code, TypeInfo, Value} +import is.hail.asm4s.{Code, Value} import is.hail.backend.HailStateManager -import is.hail.expr.ir.orderings.CodeOrdering -import is.hail.expr.ir.{Ascending, Descending, EmitCodeBuilder, EmitMethodBuilder, SortOrder} -import is.hail.types.physical.stypes.{SCode, SValue} +import is.hail.expr.ir.EmitCodeBuilder +import is.hail.types.physical.stypes.SValue trait PUnrealizable extends PType { private def unsupported: Nothing = @@ -15,30 +14,61 @@ trait PUnrealizable extends PType { override def alignment: Long = unsupported - protected[physical] def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = + protected[physical] def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = unsupported - override def copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = + override def copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = unsupported - def unstagedStoreAtAddress(sm: HailStateManager, addr: Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit = + def unstagedStoreAtAddress( + sm: HailStateManager, + addr: Long, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit = unsupported - override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = + override def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region) + : Long = unsupported - override def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = + override def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = unsupported override def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SValue = unsupported - override def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] = unsupported + override def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] = unsupported - override def 
storeAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit = unsupported + override def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit = unsupported - override def containsPointers: Boolean = { + override def containsPointers: Boolean = throw new UnsupportedOperationException("containsPointers not supported on PUnrealizable") - } override def copiedType: PType = this } diff --git a/hail/src/main/scala/is/hail/types/physical/PVoid.scala b/hail/src/main/scala/is/hail/types/physical/PVoid.scala index 414d49f05d2..d7ed81df4c0 100644 --- a/hail/src/main/scala/is/hail/types/physical/PVoid.scala +++ b/hail/src/main/scala/is/hail/types/physical/PVoid.scala @@ -1,9 +1,8 @@ package is.hail.types.physical import is.hail.annotations.UnsafeOrdering -import is.hail.asm4s.{Code, TypeInfo, UnitInfo} +import is.hail.asm4s.Code import is.hail.backend.HailStateManager -import is.hail.expr.ir.EmitCodeBuilder import is.hail.types.physical.stypes.SType import is.hail.types.physical.stypes.interfaces.SVoid import is.hail.types.virtual.{TVoid, Type} @@ -22,7 +21,8 @@ case object PVoid extends PType with PUnrealizable { def setRequired(required: Boolean) = PVoid - override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = throw new NotImplementedError() + override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = + throw new NotImplementedError() def loadFromNested(addr: Code[Long]): Code[Long] = throw new NotImplementedError() diff --git a/hail/src/main/scala/is/hail/types/physical/StoredSTypePType.scala b/hail/src/main/scala/is/hail/types/physical/StoredSTypePType.scala index 985881b6174..239d6604024 100644 --- a/hail/src/main/scala/is/hail/types/physical/StoredSTypePType.scala +++ b/hail/src/main/scala/is/hail/types/physical/StoredSTypePType.scala @@ -4,11 +4,10 @@ import is.hail.annotations.{Annotation, Region, UnsafeOrdering} import is.hail.asm4s._ import is.hail.backend.HailStateManager import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.{SCode, SType, SValue} +import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.virtual.Type import is.hail.utils._ - object StoredCodeTuple { def canStore(ti: TypeInfo[_]): Boolean = ti match { case IntInfo => true @@ -91,19 +90,28 @@ case class StoredSTypePType(sType: SType, required: Boolean) extends PType { override def virtualType: Type = sType.virtualType - override def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): Value[Long] = { + override def store(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : Value[Long] = { val addr = cb.memoize(region.allocate(ct.alignment, ct.byteSize)) ct.store(cb, addr, value.st.coerceOrCopy(cb, region, value, deepCopy).valueTuple.map(_.get)) addr } - override def storeAtAddress(cb: EmitCodeBuilder, addr: Code[Long], region: Value[Region], value: SValue, deepCopy: Boolean): Unit = { - ct.store(cb, cb.newLocal[Long]("stored_stype_ptype_addr", addr), value.st.coerceOrCopy(cb, region, value, deepCopy).valueTuple.map(_.get)) - } - - override def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SValue = { + override def storeAtAddress( + cb: EmitCodeBuilder, + addr: Code[Long], + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): Unit = + ct.store( + cb, + cb.newLocal[Long]("stored_stype_ptype_addr", addr), + value.st.coerceOrCopy(cb, region, value, 
deepCopy).valueTuple.map(_.get), + ) + + override def loadCheapSCode(cb: EmitCodeBuilder, addr: Code[Long]): SValue = sType.fromValues(ct.loadValues(cb, cb.newLocal[Long]("stored_stype_ptype_loaded_addr"))) - } override def loadFromNested(addr: Code[Long]): Code[Long] = addr @@ -115,25 +123,48 @@ case class StoredSTypePType(sType: SType, required: Boolean) extends PType { override def containsPointers: Boolean = sType.containsPointers - override def setRequired(required: Boolean): PType = if (required == this.required) this else StoredSTypePType(sType, required) + override def setRequired(required: Boolean): PType = + if (required == this.required) this else StoredSTypePType(sType, required) - def unsupportedCanonicalMethod: Nothing = throw new UnsupportedOperationException("not supported on StoredStypePType") + def unsupportedCanonicalMethod: Nothing = + throw new UnsupportedOperationException("not supported on StoredStypePType") - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = sb.append(sType.toString) + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = + sb.append(sType.toString) - override def unsafeOrdering(sm: HailStateManager, rightType: PType): UnsafeOrdering = unsupportedCanonicalMethod + override def unsafeOrdering(sm: HailStateManager, rightType: PType): UnsafeOrdering = + unsupportedCanonicalMethod override def unsafeOrdering(sm: HailStateManager): UnsafeOrdering = unsupportedCanonicalMethod def unstagedLoadFromNested(addr: Long): Long = unsupportedCanonicalMethod - def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = unsupportedCanonicalMethod - - def unstagedStoreJavaObjectAtAddress(sm: HailStateManager, addr: Long, annotation: Annotation, region: Region): Unit = unsupportedCanonicalMethod - - override def _copyFromAddress(sm: HailStateManager, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Long = unsupportedCanonicalMethod - - override def unstagedStoreAtAddress(sm: HailStateManager, addr: Long, region: Region, srcPType: PType, srcAddress: Long, deepCopy: Boolean): Unit = unsupportedCanonicalMethod + def unstagedStoreJavaObject(sm: HailStateManager, annotation: Annotation, region: Region): Long = + unsupportedCanonicalMethod + + def unstagedStoreJavaObjectAtAddress( + sm: HailStateManager, + addr: Long, + annotation: Annotation, + region: Region, + ): Unit = unsupportedCanonicalMethod + + override def _copyFromAddress( + sm: HailStateManager, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Long = unsupportedCanonicalMethod + + override def unstagedStoreAtAddress( + sm: HailStateManager, + addr: Long, + region: Region, + srcPType: PType, + srcAddress: Long, + deepCopy: Boolean, + ): Unit = unsupportedCanonicalMethod override def _asIdent: String = "stored_stype_ptype" diff --git a/hail/src/main/scala/is/hail/types/physical/package.scala b/hail/src/main/scala/is/hail/types/physical/package.scala index 5d0adf60ab4..68cd93ce0e4 100644 --- a/hail/src/main/scala/is/hail/types/physical/package.scala +++ b/hail/src/main/scala/is/hail/types/physical/package.scala @@ -2,8 +2,6 @@ package is.hail.types import is.hail.asm4s._ -import scala.language.implicitConversions - package object physical { def typeToTypeInfo(t: PType): TypeInfo[_] = t match { case _: PInt32 => typeInfo[Int] diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/SCode.scala b/hail/src/main/scala/is/hail/types/physical/stypes/SCode.scala index 
d2e0fb362a7..1f1b6040274 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/SCode.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/SCode.scala @@ -6,41 +6,50 @@ import is.hail.expr.ir.EmitCodeBuilder import is.hail.types.physical.stypes.concrete.SRNGStateValue import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives._ +import is.hail.types.virtual.Type object SCode { def add(cb: EmitCodeBuilder, left: SValue, right: SValue, required: Boolean): SValue = { (left.st, right.st) match { case (SInt32, SInt32) => new SInt32Value(cb.memoize(left.asInt.value + right.asInt.value)) - case (SFloat32, SFloat32) => new SFloat32Value(cb.memoize(left.asFloat.value + right.asFloat.value)) + case (SFloat32, SFloat32) => + new SFloat32Value(cb.memoize(left.asFloat.value + right.asFloat.value)) case (SInt64, SInt64) => new SInt64Value(cb.memoize(left.asLong.value + right.asLong.value)) - case (SFloat64, SFloat64) => new SFloat64Value(cb.memoize(left.asDouble.value + right.asDouble.value)) + case (SFloat64, SFloat64) => + new SFloat64Value(cb.memoize(left.asDouble.value + right.asDouble.value)) } } def multiply(cb: EmitCodeBuilder, left: SValue, right: SValue, required: Boolean): SValue = { (left.st, right.st) match { case (SInt32, SInt32) => new SInt32Value(cb.memoize(left.asInt.value * right.asInt.value)) - case (SFloat32, SFloat32) => new SFloat32Value(cb.memoize(left.asFloat.value * right.asFloat.value)) + case (SFloat32, SFloat32) => + new SFloat32Value(cb.memoize(left.asFloat.value * right.asFloat.value)) case (SInt64, SInt64) => new SInt64Value(cb.memoize(left.asLong.value * right.asLong.value)) - case (SFloat64, SFloat64) => new SFloat64Value(cb.memoize(left.asDouble.value * right.asDouble.value)) + case (SFloat64, SFloat64) => + new SFloat64Value(cb.memoize(left.asDouble.value * right.asDouble.value)) } } def subtract(cb: EmitCodeBuilder, left: SValue, right: SValue, required: Boolean): SValue = { (left.st, right.st) match { case (SInt32, SInt32) => new SInt32Value(cb.memoize(left.asInt.value - right.asInt.value)) - case (SFloat32, SFloat32) => new SFloat32Value(cb.memoize(left.asFloat.value - right.asFloat.value)) + case (SFloat32, SFloat32) => + new SFloat32Value(cb.memoize(left.asFloat.value - right.asFloat.value)) case (SInt64, SInt64) => new SInt64Value(cb.memoize(left.asLong.value - right.asLong.value)) - case (SFloat64, SFloat64) => new SFloat64Value(cb.memoize(left.asDouble.value - right.asDouble.value)) + case (SFloat64, SFloat64) => + new SFloat64Value(cb.memoize(left.asDouble.value - right.asDouble.value)) } } def divide(cb: EmitCodeBuilder, left: SValue, right: SValue, required: Boolean): SValue = { (left.st, right.st) match { case (SInt32, SInt32) => new SInt32Value(cb.memoize(left.asInt.value / right.asInt.value)) - case (SFloat32, SFloat32) => new SFloat32Value(cb.memoize(left.asFloat.value / right.asFloat.value)) + case (SFloat32, SFloat32) => + new SFloat32Value(cb.memoize(left.asFloat.value / right.asFloat.value)) case (SInt64, SInt64) => new SInt64Value(cb.memoize(left.asLong.value / right.asLong.value)) - case (SFloat64, SFloat64) => new SFloat64Value(cb.memoize(left.asDouble.value / right.asDouble.value)) + case (SFloat64, SFloat64) => + new SFloat64Value(cb.memoize(left.asDouble.value / right.asDouble.value)) } } @@ -95,18 +104,21 @@ trait SValue { def castTo(cb: EmitCodeBuilder, region: Value[Region], destType: SType): SValue = castTo(cb, region, destType, false) - def castTo(cb: EmitCodeBuilder, region: 
Value[Region], destType: SType, deepCopy: Boolean): SValue = { + def castTo(cb: EmitCodeBuilder, region: Value[Region], destType: SType, deepCopy: Boolean) + : SValue = destType.coerceOrCopy(cb, region, this, deepCopy) - } def copyToRegion(cb: EmitCodeBuilder, region: Value[Region], destType: SType): SValue = destType.coerceOrCopy(cb, region, this, deepCopy = true) - def hash(cb: EmitCodeBuilder): SInt32Value = throw new UnsupportedOperationException(s"Stype ${st} has no hashcode") + def hash(cb: EmitCodeBuilder): SInt32Value = + throw new UnsupportedOperationException(s"Stype $st has no hashcode") def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value -} + def castRename(t: Type): SValue = + st.castRename(t).fromValues(valueTuple) +} trait SSettable extends SValue { def store(cb: EmitCodeBuilder, v: SValue): Unit @@ -115,14 +127,13 @@ trait SSettable extends SValue { } object SSettable { - def apply(sb: SettableBuilder, st: SType, name: String): SSettable = { + def apply(sb: SettableBuilder, st: SType, name: String): SSettable = st.fromSettables(st.settableTupleTypes().zipWithIndex.map { case (ti, i) => - sb.newSettable(s"${ name }_${ st.getClass.getSimpleName }_$i")(ti) + sb.newSettable(s"${name}_${st.getClass.getSimpleName}_$i")(ti) }) - } } trait SUnrealizableValue extends SValue { override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = throw new UnsupportedOperationException(s"Unrealizable SValue has no size in bytes.") -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/SType.scala b/hail/src/main/scala/is/hail/types/physical/stypes/SType.scala index e91a9404f76..ef05d005923 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/SType.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/SType.scala @@ -2,7 +2,9 @@ package is.hail.types.physical.stypes import is.hail.annotations.Region import is.hail.asm4s._ -import is.hail.expr.ir.{EmitCode, EmitCodeBuilder, EmitSettable, EmitValue, SCodeEmitParamType, SCodeParamType} +import is.hail.expr.ir.{ + EmitCodeBuilder, EmitSettable, EmitValue, SCodeEmitParamType, SCodeParamType, +} import is.hail.types.{TypeWithRequiredness, VirtualTypeWithReq} import is.hail.types.physical.PType import is.hail.types.physical.stypes.concrete.SUnreachable @@ -10,7 +12,6 @@ import is.hail.types.physical.stypes.interfaces.SStream import is.hail.types.physical.stypes.primitives._ import is.hail.types.virtual._ - object SType { def chooseCompatibleType(req: VirtualTypeWithReq, stypes: SType*): SType = { val reachable = stypes.filter(t => !t.isInstanceOf[SUnreachable]).toSet @@ -24,9 +25,8 @@ object SType { req.canonicalEmitType.st // fall back to canonical emit type from requiredness } - def canonical(virt: Type): SType = { + def canonical(virt: Type): SType = PType.canonical(virt).sType - } def extractPrimValue(cb: EmitCodeBuilder, x: SValue): Value[_] = x.st.virtualType match { case TInt32 => x.asInt.value @@ -40,14 +40,23 @@ object SType { trait SType { def virtualType: Type - final def coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = { + final def coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value.st match { case _: SUnreachable => this.defaultValue case _ => _coerceOrCopy(cb, region, value, deepCopy) } - } - protected[stypes] def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue + protected[stypes] def 
_coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue def settableTupleTypes(): IndexedSeq[TypeInfo[_]] @@ -79,7 +88,8 @@ trait SType { protected[stypes] def _typeWithRequiredness: TypeWithRequiredness - final def typeWithRequiredness: VirtualTypeWithReq = VirtualTypeWithReq(virtualType, _typeWithRequiredness) + final def typeWithRequiredness: VirtualTypeWithReq = + VirtualTypeWithReq(virtualType, _typeWithRequiredness) def containsPointers: Boolean } @@ -107,24 +117,27 @@ case class EmitType(st: SType, required: Boolean) { def fromSettables(settables: IndexedSeq[Settable[_]]): EmitSettable = new EmitSettable( if (required) None else Some(coerce[Boolean](settables.last)), - st.fromSettables(settables.take(st.nSettables)) + st.fromSettables(settables.take(st.nSettables)), ) def fromValues(values: IndexedSeq[Value[_]]): EmitValue = EmitValue( if (required) None else Some(coerce[Boolean](values.last)), - st.fromValues(values.take(st.nSettables)) + st.fromValues(values.take(st.nSettables)), ) def nSettables: Int = settableTupleTypes.length - def coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: EmitValue, deepCopy: Boolean): EmitValue = { + def coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: EmitValue, deepCopy: Boolean) + : EmitValue = { if (value.emitType == this && (!deepCopy || !value.st.containsPointers)) value else (required, value.required) match { case (true, _) => EmitValue.present(st.coerceOrCopy(cb, region, value.get(cb), deepCopy)) - case (false, true) => EmitValue.present(st.coerceOrCopy(cb, region, value.get(cb), deepCopy)).setOptional - case (false, false) => cb.memoize(value.toI(cb).map(cb)(value => st.coerceOrCopy(cb, region, value, deepCopy))) + case (false, true) => + EmitValue.present(st.coerceOrCopy(cb, region, value.get(cb), deepCopy)).setOptional + case (false, false) => + cb.memoize(value.toI(cb).map(cb)(value => st.coerceOrCopy(cb, region, value, deepCopy))) } } -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/SingleCodeSCode.scala b/hail/src/main/scala/is/hail/types/physical/stypes/SingleCodeSCode.scala index fc75bba9ceb..334929a0b86 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/SingleCodeSCode.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/SingleCodeSCode.scala @@ -4,10 +4,11 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir._ import is.hail.types.physical.PType -import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SStream, SStreamConcrete, SStreamIteratorLong, SStreamValue} +import is.hail.types.physical.stypes.interfaces.{ + NoBoxLongIterator, SStream, SStreamConcrete, SStreamIteratorLong, +} import is.hail.types.physical.stypes.primitives._ import is.hail.types.virtual._ -import is.hail.utils._ object SingleCodeType { def typeInfoFromType(t: Type): TypeInfo[_] = t match { @@ -38,7 +39,8 @@ sealed trait SingleCodeType { def virtualType: Type - def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean): SingleCodeSCode + def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean) + : SingleCodeSCode def loadedSType: SType } @@ -52,7 +54,8 @@ case object Int32SingleCodeType extends SingleCodeType { def virtualType: Type = TInt32 - def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean): SingleCodeSCode = + def coerceSCode(cb: EmitCodeBuilder, pc: SValue, 
region: Value[Region], deepCopy: Boolean) + : SingleCodeSCode = SingleCodeSCode(this, pc.asInt.value) } @@ -65,7 +68,8 @@ case object Int64SingleCodeType extends SingleCodeType { def virtualType: Type = TInt64 - def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean): SingleCodeSCode = + def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean) + : SingleCodeSCode = SingleCodeSCode(this, pc.asLong.value) } @@ -78,7 +82,8 @@ case object Float32SingleCodeType extends SingleCodeType { def virtualType: Type = TFloat32 - def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean): SingleCodeSCode = + def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean) + : SingleCodeSCode = SingleCodeSCode(this, pc.asFloat.value) } @@ -91,7 +96,8 @@ case object Float64SingleCodeType extends SingleCodeType { def virtualType: Type = TFloat64 - def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean): SingleCodeSCode = + def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean) + : SingleCodeSCode = SingleCodeSCode(this, pc.asDouble.value) } @@ -104,11 +110,16 @@ case object BooleanSingleCodeType extends SingleCodeType { def virtualType: Type = TBoolean - def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean): SingleCodeSCode = + def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean) + : SingleCodeSCode = SingleCodeSCode(this, pc.asBoolean.value) } -case class StreamSingleCodeType(requiresMemoryManagementPerElement: Boolean, eltType: PType, eltRequired: Boolean) extends SingleCodeType { +case class StreamSingleCodeType( + requiresMemoryManagementPerElement: Boolean, + eltType: PType, + eltRequired: Boolean, +) extends SingleCodeType { self => override def loadedSType: SType = SStream(EmitType(eltType.sType, true)) @@ -117,14 +128,14 @@ case class StreamSingleCodeType(requiresMemoryManagementPerElement: Boolean, elt def ti: TypeInfo[_] = classInfo[NoBoxLongIterator] - def loadToSValue(cb: EmitCodeBuilder, c: Value[_]): SValue = { + def loadToSValue(cb: EmitCodeBuilder, c: Value[_]): SValue = new SStreamConcrete( SStreamIteratorLong(eltRequired, eltType, requiresMemoryManagementPerElement), - coerce[NoBoxLongIterator](c) + coerce[NoBoxLongIterator](c), ) - } - def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean): SingleCodeSCode = + def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean) + : SingleCodeSCode = throw new UnsupportedOperationException } @@ -138,15 +149,15 @@ case class PTypeReferenceSingleCodeType(pt: PType) extends SingleCodeType { def virtualType: Type = pt.virtualType - def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean): SingleCodeSCode = { + def coerceSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean) + : SingleCodeSCode = SingleCodeSCode(this, pt.store(cb, region, pc, deepCopy = deepCopy)) - } } object SingleCodeSCode { - def fromSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean = false): SingleCodeSCode = { + def fromSCode(cb: EmitCodeBuilder, pc: SValue, region: Value[Region], deepCopy: Boolean = false) + : SingleCodeSCode = SingleCodeType.fromSType(pc.st).coerceSCode(cb, pc, region, deepCopy) - } } -case class SingleCodeSCode(typ: SingleCodeType, code: 
Value[_]) \ No newline at end of file +case class SingleCodeSCode(typ: SingleCodeType, code: Value[_]) diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SBaseStructPointer.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SBaseStructPointer.scala index e11c681a6ea..1fbcf2f8947 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SBaseStructPointer.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SBaseStructPointer.scala @@ -3,36 +3,41 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode} -import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue} -import is.hail.types.physical.stypes.{EmitType, SType, SValue} import is.hail.types.physical.{PBaseStruct, PType} +import is.hail.types.physical.stypes.{EmitType, SType, SValue} +import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue} import is.hail.types.virtual.{TBaseStruct, Type} import is.hail.utils.FastSeq - final case class SBaseStructPointer(pType: PBaseStruct) extends SBaseStruct { require(!pType.required) override def size: Int = pType.size override lazy val virtualType: TBaseStruct = pType.virtualType.asInstanceOf[TBaseStruct] - override def castRename(t: Type): SType = SBaseStructPointer(pType.deepRename(t).asInstanceOf[PBaseStruct]) + override def castRename(t: Type): SType = + SBaseStructPointer(pType.deepRename(t).asInstanceOf[PBaseStruct]) override def fieldIdx(fieldName: String): Int = pType.fieldIdx(fieldName) - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = new SBaseStructPointerValue(this, pType.store(cb, region, value, deepCopy)) override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(LongInfo) override def fromSettables(settables: IndexedSeq[Settable[_]]): SBaseStructPointerSettable = { - val IndexedSeq(a: Settable[Long@unchecked]) = settables + val IndexedSeq(a: Settable[Long @unchecked]) = settables assert(a.ti == LongInfo) new SBaseStructPointerSettable(this, a) } override def fromValues(values: IndexedSeq[Value[_]]): SBaseStructPointerValue = { - val IndexedSeq(a: Value[Long@unchecked]) = values + val IndexedSeq(a: Value[Long @unchecked]) = values assert(a.ti == LongInfo) new SBaseStructPointerValue(this, a) } @@ -40,7 +45,9 @@ final case class SBaseStructPointer(pType: PBaseStruct) extends SBaseStruct { def canonicalPType(): PType = pType override val fieldTypes: IndexedSeq[SType] = pType.types.map(_.sType) - override val fieldEmitTypes: IndexedSeq[EmitType] = pType.types.map(t => EmitType(t.sType, t.required)) + + override val fieldEmitTypes: IndexedSeq[EmitType] = + pType.types.map(t => EmitType(t.sType, t.required)) override def containsPointers: Boolean = pType.containsPointers @@ -51,36 +58,34 @@ final case class SBaseStructPointer(pType: PBaseStruct) extends SBaseStruct { class SBaseStructPointerValue( val st: SBaseStructPointer, - val a: Value[Long] + val a: Value[Long], ) extends SBaseStructValue { val pt: PBaseStruct = st.pType override lazy val valueTuple: IndexedSeq[Value[_]] = FastSeq(a) - override def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode = { - IEmitCode(cb, + override def loadField(cb: EmitCodeBuilder, fieldIdx: 
Int): IEmitCode = + IEmitCode( + cb, pt.isFieldMissing(cb, a, fieldIdx), - pt.fields(fieldIdx).typ.loadCheapSCode(cb, pt.loadField(a, fieldIdx))) - } + pt.fields(fieldIdx).typ.loadCheapSCode(cb, pt.loadField(a, fieldIdx)), + ) - override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] = { + override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] = pt.isFieldMissing(cb, a, fieldIdx) - } } object SBaseStructPointerSettable { - def apply(sb: SettableBuilder, st: SBaseStructPointer, name: String): SBaseStructPointerSettable = { + def apply(sb: SettableBuilder, st: SBaseStructPointer, name: String): SBaseStructPointerSettable = new SBaseStructPointerSettable(st, sb.newSettable(name)) - } } final class SBaseStructPointerSettable( st: SBaseStructPointer, - override val a: Settable[Long] + override val a: Settable[Long], ) extends SBaseStructPointerValue(st, a) with SBaseStructSettable { override def settableTuple(): IndexedSeq[Settable[_]] = FastSeq(a) - override def store(cb: EmitCodeBuilder, v: SValue): Unit = { + override def store(cb: EmitCodeBuilder, v: SValue): Unit = cb.assign(a, v.asInstanceOf[SBaseStructPointerValue].a) - } } diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SBinaryPointer.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SBinaryPointer.scala index dada00520d7..5bf3ab23250 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SBinaryPointer.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SBinaryPointer.scala @@ -3,38 +3,41 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.interfaces.{SBinary, SBinaryValue} -import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.physical.{PBinary, PType} +import is.hail.types.physical.stypes.{SSettable, SType, SValue} +import is.hail.types.physical.stypes.interfaces.{SBinary, SBinaryValue} import is.hail.types.virtual.Type import is.hail.utils._ - final case class SBinaryPointer(pType: PBinary) extends SBinary { require(!pType.required) override lazy val virtualType: Type = pType.virtualType - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = { + + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = new SBinaryPointerValue(this, pType.store(cb, region, value, deepCopy)) - } override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(LongInfo) - def loadFrom(cb: EmitCodeBuilder, region: Value[Region], pt: PType, addr: Value[Long]): SValue = { + def loadFrom(cb: EmitCodeBuilder, region: Value[Region], pt: PType, addr: Value[Long]): SValue = if (pt == this.pType) new SBinaryPointerValue(this, addr) else coerceOrCopy(cb, region, pt.loadCheapSCode(cb, addr), deepCopy = false) - } override def fromSettables(settables: IndexedSeq[Settable[_]]): SBinaryPointerSettable = { - val IndexedSeq(a: Settable[Long@unchecked]) = settables + val IndexedSeq(a: Settable[Long @unchecked]) = settables assert(a.ti == LongInfo) new SBinaryPointerSettable(this, a) } override def fromValues(values: IndexedSeq[Value[_]]): SBinaryPointerValue = { - val IndexedSeq(a: Value[Long@unchecked]) = values + val IndexedSeq(a: Value[Long @unchecked]) = values assert(a.ti == LongInfo) new SBinaryPointerValue(this, a) } @@ -50,7 +53,7 @@ final 
case class SBinaryPointer(pType: PBinary) extends SBinary { class SBinaryPointerValue( val st: SBinaryPointer, - val a: Value[Long] + val a: Value[Long], ) extends SBinaryValue { private val pt: PBinary = st.pType @@ -75,7 +78,7 @@ object SBinaryPointerSettable { final class SBinaryPointerSettable( st: SBinaryPointer, - override val a: Settable[Long] + override val a: Settable[Long], ) extends SBinaryPointerValue(st, a) with SSettable { override def settableTuple(): IndexedSeq[Settable[_]] = FastSeq(a) diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SCanonicalCall.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SCanonicalCall.scala index ab82e612f2c..e4eb171522c 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SCanonicalCall.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SCanonicalCall.scala @@ -3,21 +3,20 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder +import is.hail.types.physical.{PCall, PCanonicalCall, PType} +import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.physical.stypes.interfaces.{SCall, SCallValue, SIndexableValue} import is.hail.types.physical.stypes.primitives.SInt64Value -import is.hail.types.physical.stypes.{SSettable, SType, SValue} -import is.hail.types.physical.{PCall, PCanonicalCall, PType} import is.hail.types.virtual.{TCall, Type} import is.hail.utils._ import is.hail.variant._ - case object SCanonicalCall extends SCall { - def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = { + def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean) + : SValue = value.st match { case SCanonicalCall => value } - } lazy val virtualType: Type = TCall @@ -26,13 +25,13 @@ case object SCanonicalCall extends SCall { def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(IntInfo) def fromSettables(settables: IndexedSeq[Settable[_]]): SCanonicalCallSettable = { - val IndexedSeq(call: Settable[Int@unchecked]) = settables + val IndexedSeq(call: Settable[Int @unchecked]) = settables assert(call.ti == IntInfo) new SCanonicalCallSettable(call) } def fromValues(values: IndexedSeq[Value[_]]): SCanonicalCallValue = { - val IndexedSeq(call: Value[Int@unchecked]) = values + val IndexedSeq(call: Value[Int @unchecked]) = values assert(call.ti == IntInfo) new SCanonicalCallValue(call) } @@ -52,16 +51,24 @@ class SCanonicalCallValue(val call: Value[Int]) extends SCallValue { override def unphase(cb: EmitCodeBuilder): SCanonicalCallValue = { val repr = cb.newLocal[Int]("unphase_call", call) - cb.if_(isPhased(cb), + cb.if_( + isPhased(cb), cb.assign(repr, Code.invokeScalaObject1[Int, Int](Call.getClass, "unphase", call)), - cb.assign(repr, call)) + cb.assign(repr, call), + ) new SCanonicalCallValue(repr) } - def containsAllele(cb: EmitCodeBuilder, allele: Value[Int]): Value[Boolean] = { - cb.memoize[Boolean](Code.invokeScalaObject2[Int, Int, Boolean]( - Call.getClass, "containsAllele", call, allele), "contains_allele") - } + def containsAllele(cb: EmitCodeBuilder, allele: Value[Int]): Value[Boolean] = + cb.memoize[Boolean]( + Code.invokeScalaObject2[Int, Int, Boolean]( + Call.getClass, + "containsAllele", + call, + allele, + ), + "contains_allele", + ) override def canonicalCall(cb: EmitCodeBuilder): Value[Int] = call @@ -81,76 +88,128 @@ class SCanonicalCallValue(val call: Value[Int]) 
extends SCallValue { val j = cb.newLocal[Int]("fea_j") val k = cb.newLocal[Int]("fea_k") - cb.if_(p.ceq(2), { - cb.if_(call2 < Genotype.nCachedAllelePairs, { - cb.assign(j, Code.invokeScalaObject1[Int, Int](Genotype.getClass, "cachedAlleleJ", call2)) - cb.assign(k, Code.invokeScalaObject1[Int, Int](Genotype.getClass, "cachedAlleleK", call2)) - }, { - cb.assign(k, (Code.invokeStatic1[Math, Double, Double]("sqrt", const(8d) * call2.toD + 1d) / 2d - 0.5).toI) - cb.assign(j, call2 - (k * (k + 1) / 2)) - }) - alleleCode(j) - cb.if_(isPhased(cb), cb.assign(k, k - j)) - alleleCode(k) - }, { - cb.if_(p.ceq(1), + cb.if_( + p.ceq(2), { + cb.if_( + call2 < Genotype.nCachedAllelePairs, { + cb.assign( + j, + Code.invokeScalaObject1[Int, Int](Genotype.getClass, "cachedAlleleJ", call2), + ) + cb.assign( + k, + Code.invokeScalaObject1[Int, Int](Genotype.getClass, "cachedAlleleK", call2), + ) + }, { + cb.assign( + k, + (Code.invokeStatic1[Math, Double, Double]( + "sqrt", + const(8d) * call2.toD + 1d, + ) / 2d - 0.5).toI, + ) + cb.assign(j, call2 - (k * (k + 1) / 2)) + }, + ) + alleleCode(j) + cb.if_(isPhased(cb), cb.assign(k, k - j)) + alleleCode(k) + }, + cb.if_( + p.ceq(1), alleleCode(call2), - cb.if_(p.cne(0), - cb.append(Code._fatal[Unit](const("invalid ploidy: ").concat(p.toS))))) - }) + cb.if_(p.cne(0), cb.append(Code._fatal[Unit](const("invalid ploidy: ").concat(p.toS)))), + ), + ) } - override def lgtToGT(cb: EmitCodeBuilder, localAlleles: SIndexableValue, errorID: Value[Int]): SCallValue = { + override def lgtToGT(cb: EmitCodeBuilder, localAlleles: SIndexableValue, errorID: Value[Int]) + : SCallValue = { def checkAndTranslate(cb: EmitCodeBuilder, allele: Code[Int]): Code[Int] = { val av = cb.newLocal[Int](s"allele", allele) - cb.if_(av >= localAlleles.loadLength(), - cb._fatalWithError(errorID, - s"lgt_to_gt: found allele ", av.toS, ", but there are only ", localAlleles.loadLength().toS, " local alleles")) - localAlleles.loadElement(cb, av).get(cb, const("lgt_to_gt: found missing value in local alleles at index ").concat(av.toS), errorID = errorID) - .asInt.value + cb.if_( + av >= localAlleles.loadLength(), + cb._fatalWithError( + errorID, + s"lgt_to_gt: found allele ", + av.toS, + ", but there are only ", + localAlleles.loadLength().toS, + " local alleles", + ), + ) + localAlleles.loadElement(cb, av).getOrFatal( + cb, + const("lgt_to_gt: found missing value in local alleles at index ").concat(av.toS), + errorID = errorID, + ) + .asInt.value } val repr = cb.newLocal[Int]("lgt_to_gt_repr") - cb.switch(ploidy(cb), + cb.switch( + ploidy(cb), cb._fatalWithError(errorID, s"ploidy above 2 is not currently supported"), FastSeq( - { () => cb.assign(repr, call) }, // ploidy 0 + () => cb.assign(repr, call), // ploidy 0 { () => val allele = Code.invokeScalaObject1[Int, Int](Call.getClass, "alleleRepr", call) - val newCall = Code.invokeScalaObject2[Int, Boolean, Int](Call1.getClass, "apply", - checkAndTranslate(cb, allele), isPhased(cb) + val newCall = Code.invokeScalaObject2[Int, Boolean, Int]( + Call1.getClass, + "apply", + checkAndTranslate(cb, allele), + isPhased(cb), ) cb.assign(repr, newCall) }, // ploidy 1 { () => - val allelePair = cb.newLocal[Int]("allelePair", Code.invokeScalaObject1[Int, Int](Call.getClass, "allelePairUnchecked", call)) - val j = cb.newLocal[Int]("allele_j", Code.invokeScalaObject1[Int, Int](AllelePair.getClass, "j", allelePair)) - val k = cb.newLocal[Int]("allele_k", Code.invokeScalaObject1[Int, Int](AllelePair.getClass, "k", allelePair)) + val allelePair = cb.newLocal[Int]( + 
"allelePair", + Code.invokeScalaObject1[Int, Int](Call.getClass, "allelePairUnchecked", call), + ) + val j = cb.newLocal[Int]( + "allele_j", + Code.invokeScalaObject1[Int, Int](AllelePair.getClass, "j", allelePair), + ) + val k = cb.newLocal[Int]( + "allele_k", + Code.invokeScalaObject1[Int, Int](AllelePair.getClass, "k", allelePair), + ) - cb.if_(j >= localAlleles.loadLength(), cb._fatalWithError(errorID, "invalid lgt_to_gt: allele ")) + cb.if_( + j >= localAlleles.loadLength(), + cb._fatalWithError(errorID, "invalid lgt_to_gt: allele "), + ) - cb.assign(repr, Code.invokeScalaObject4[Int, Int, Boolean, Int, Int](Call2.getClass, "withErrorID", - checkAndTranslate(cb, j), - checkAndTranslate(cb, k), - isPhased(cb), - errorID) + cb.assign( + repr, + Code.invokeScalaObject4[Int, Int, Boolean, Int, Int]( + Call2.getClass, + "withErrorID", + checkAndTranslate(cb, j), + checkAndTranslate(cb, k), + isPhased(cb), + errorID, + ), ) - } // ploidy 2 - ) + }, // ploidy 2 + ), ) new SCanonicalCallValue(repr) } - override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = new SInt64Value(this.pt.byteSize) + override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = + new SInt64Value(this.pt.byteSize) } object SCanonicalCallSettable { def apply(sb: SettableBuilder, name: String): SCanonicalCallSettable = - new SCanonicalCallSettable(sb.newSettable[Int](s"${ name }_call")) + new SCanonicalCallSettable(sb.newSettable[Int](s"${name}_call")) } -final class SCanonicalCallSettable(override val call: Settable[Int]) extends SCanonicalCallValue(call) with SSettable { +final class SCanonicalCallSettable(override val call: Settable[Int]) + extends SCanonicalCallValue(call) with SSettable { override def store(cb: EmitCodeBuilder, v: SValue): Unit = cb.assign(call, v.asInstanceOf[SCanonicalCallValue].call) diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SCanonicalLocusPointer.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SCanonicalLocusPointer.scala index 56054a80feb..05922aa2b22 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SCanonicalLocusPointer.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SCanonicalLocusPointer.scala @@ -3,13 +3,12 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.interfaces._ -import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.physical.{PCanonicalLocus, PType} +import is.hail.types.physical.stypes.{SSettable, SType, SValue} +import is.hail.types.physical.stypes.interfaces._ import is.hail.types.virtual.Type import is.hail.utils.FastSeq - final case class SCanonicalLocusPointer(pType: PCanonicalLocus) extends SLocus { require(!pType.required) @@ -21,7 +20,12 @@ final case class SCanonicalLocusPointer(pType: PCanonicalLocus) extends SLocus { override def rg: String = pType.rg - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value match { case value: SLocusValue => val locusCopy = pType.store(cb, region, value, deepCopy) @@ -35,7 +39,11 @@ final case class SCanonicalLocusPointer(pType: PCanonicalLocus) extends SLocus { override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(LongInfo, LongInfo, IntInfo) override def 
fromSettables(settables: IndexedSeq[Settable[_]]): SCanonicalLocusPointerSettable = { - val IndexedSeq(a: Settable[Long@unchecked], contig: Settable[Long@unchecked], position: Settable[Int@unchecked]) = settables + val IndexedSeq( + a: Settable[Long @unchecked], + contig: Settable[Long @unchecked], + position: Settable[Int @unchecked], + ) = settables assert(a.ti == LongInfo) assert(contig.ti == LongInfo) assert(position.ti == IntInfo) @@ -43,7 +51,11 @@ final case class SCanonicalLocusPointer(pType: PCanonicalLocus) extends SLocus { } override def fromValues(values: IndexedSeq[Value[_]]): SCanonicalLocusPointerValue = { - val IndexedSeq(a: Value[Long@unchecked], contig: Value[Long@unchecked], position: Value[Int@unchecked]) = values + val IndexedSeq( + a: Value[Long @unchecked], + contig: Value[Long @unchecked], + position: Value[Int @unchecked], + ) = values assert(a.ti == LongInfo) assert(contig.ti == LongInfo) assert(position.ti == IntInfo) @@ -52,7 +64,8 @@ final case class SCanonicalLocusPointer(pType: PCanonicalLocus) extends SLocus { override def storageType(): PType = pType - override def copiedType: SType = SCanonicalLocusPointer(pType.copiedType.asInstanceOf[PCanonicalLocus]) + override def copiedType: SType = + SCanonicalLocusPointer(pType.copiedType.asInstanceOf[PCanonicalLocus]) override def containsPointers: Boolean = pType.containsPointers } @@ -61,38 +74,41 @@ class SCanonicalLocusPointerValue( val st: SCanonicalLocusPointer, val a: Value[Long], val _contig: Value[Long], - val _position: Value[Int] + val _position: Value[Int], ) extends SLocusValue { val pt: PCanonicalLocus = st.pType override lazy val valueTuple: IndexedSeq[Value[_]] = FastSeq(a, _contig, _position) - override def contig(cb: EmitCodeBuilder): SStringValue = { + override def contig(cb: EmitCodeBuilder): SStringValue = pt.contigType.loadCheapSCode(cb, _contig).asString - } override def contigLong(cb: EmitCodeBuilder): Value[Long] = _contig override def position(cb: EmitCodeBuilder): Value[Int] = _position override def structRepr(cb: EmitCodeBuilder): SBaseStructValue = new SBaseStructPointerValue( - SBaseStructPointer(st.pType.representation), a) + SBaseStructPointer(st.pType.representation), + a, + ) } object SCanonicalLocusPointerSettable { - def apply(sb: SettableBuilder, st: SCanonicalLocusPointer, name: String): SCanonicalLocusPointerSettable = { - new SCanonicalLocusPointerSettable(st, - sb.newSettable[Long](s"${ name }_a"), - sb.newSettable[Long](s"${ name }_contig"), - sb.newSettable[Int](s"${ name }_position")) - } + def apply(sb: SettableBuilder, st: SCanonicalLocusPointer, name: String) + : SCanonicalLocusPointerSettable = + new SCanonicalLocusPointerSettable( + st, + sb.newSettable[Long](s"${name}_a"), + sb.newSettable[Long](s"${name}_contig"), + sb.newSettable[Int](s"${name}_position"), + ) } final class SCanonicalLocusPointerSettable( st: SCanonicalLocusPointer, override val a: Settable[Long], _contig: Settable[Long], - override val _position: Settable[Int] + override val _position: Settable[Int], ) extends SCanonicalLocusPointerValue(st, a, _contig, _position) with SSettable { override def settableTuple(): IndexedSeq[Settable[_]] = FastSeq(a, _contig, _position) @@ -103,6 +119,9 @@ final class SCanonicalLocusPointerSettable( cb.assign(_position, v._position) } - override def structRepr(cb: EmitCodeBuilder): SBaseStructPointerSettable = new SBaseStructPointerSettable( - SBaseStructPointer(st.pType.representation), a) + override def structRepr(cb: EmitCodeBuilder): SBaseStructPointerSettable = 
+ new SBaseStructPointerSettable( + SBaseStructPointer(st.pType.representation), + a, + ) } diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SIndexablePointer.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SIndexablePointer.scala index 850cda4c785..f96b49724ad 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SIndexablePointer.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SIndexablePointer.scala @@ -9,29 +9,43 @@ import is.hail.types.physical.stypes.interfaces.{SContainer, SIndexableValue} import is.hail.types.virtual.Type import is.hail.utils.FastSeq - final case class SIndexablePointer(pType: PContainer) extends SContainer { require(!pType.required) override lazy val virtualType: Type = pType.virtualType - override def castRename(t: Type): SType = SIndexablePointer(pType.deepRename(t).asInstanceOf[PContainer]) + override def castRename(t: Type): SType = + SIndexablePointer(pType.deepRename(t).asInstanceOf[PContainer]) override def elementType: SType = pType.elementType.sType override def elementEmitType: EmitType = EmitType(elementType, pType.elementType.required) - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value match { case value: SIndexableValue => val a = pType.store(cb, region, value, deepCopy) - new SIndexablePointerValue(this, a, value.loadLength(), cb.memoize(pType.firstElementOffset(a))) + new SIndexablePointerValue( + this, + a, + value.loadLength(), + cb.memoize(pType.firstElementOffset(a)), + ) } override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(LongInfo, IntInfo, LongInfo) override def fromSettables(settables: IndexedSeq[Settable[_]]): SIndexablePointerSettable = { - val IndexedSeq(a: Settable[Long@unchecked], length: Settable[Int@unchecked], elementsAddress: Settable[Long@unchecked]) = settables + val IndexedSeq( + a: Settable[Long @unchecked], + length: Settable[Int @unchecked], + elementsAddress: Settable[Long @unchecked], + ) = settables assert(a.ti == LongInfo) assert(length.ti == IntInfo) assert(elementsAddress.ti == LongInfo) @@ -39,7 +53,11 @@ final case class SIndexablePointer(pType: PContainer) extends SContainer { } override def fromValues(values: IndexedSeq[Value[_]]): SIndexablePointerValue = { - val IndexedSeq(a: Value[Long@unchecked], length: Value[Int@unchecked], elementsAddress: Value[Long@unchecked]) = values + val IndexedSeq( + a: Value[Long @unchecked], + length: Value[Int @unchecked], + elementsAddress: Value[Long @unchecked], + ) = values assert(a.ti == LongInfo) assert(length.ti == IntInfo) assert(elementsAddress.ti == LongInfo) @@ -57,7 +75,7 @@ class SIndexablePointerValue( override val st: SIndexablePointer, val a: Value[Long], val length: Value[Int], - val elementsAddress: Value[Long] + val elementsAddress: Value[Long], ) extends SIndexableValue { val pt: PContainer = st.pType @@ -67,9 +85,11 @@ class SIndexablePointerValue( override def loadElement(cb: EmitCodeBuilder, i: Code[Int]): IEmitCode = { val iv = cb.memoize(i) - IEmitCode(cb, + IEmitCode( + cb, isElementMissing(cb, iv), - pt.elementType.loadCheapSCode(cb, pt.loadElement(a, length, iv))) // FIXME loadElement should take elementsAddress + pt.elementType.loadCheapSCode(cb, pt.loadElement(a, length, iv)), + ) // FIXME loadElement should take elementsAddress } override def 
isElementMissing(cb: EmitCodeBuilder, i: Code[Int]): Value[Boolean] = @@ -78,36 +98,41 @@ class SIndexablePointerValue( override def hasMissingValues(cb: EmitCodeBuilder): Value[Boolean] = cb.memoize(pt.hasMissingValues(a)) - override def castToArray(cb: EmitCodeBuilder): SIndexableValue = { + override def castToArray(cb: EmitCodeBuilder): SIndexableValue = pt match { - case t: PArray => this - case t: PCanonicalDict => new SIndexablePointerValue(SIndexablePointer(t.arrayRep), a, length, elementsAddress) - case t: PCanonicalSet => new SIndexablePointerValue(SIndexablePointer(t.arrayRep), a, length, elementsAddress) + case _: PArray => this + case t: PCanonicalDict => + new SIndexablePointerValue(SIndexablePointer(t.arrayRep), a, length, elementsAddress) + case t: PCanonicalSet => + new SIndexablePointerValue(SIndexablePointer(t.arrayRep), a, length, elementsAddress) } - } - override def forEachDefined(cb: EmitCodeBuilder)(f: (EmitCodeBuilder, Value[Int], SValue) => Unit) { + override def forEachDefined( + cb: EmitCodeBuilder + )( + f: (EmitCodeBuilder, Value[Int], SValue) => Unit + ): Unit = st.pType match { case pca: PCanonicalArray => pca.forEachDefined(cb, a)(f) case _ => super.forEachDefined(cb)(f) } - } } object SIndexablePointerSettable { - def apply(sb: SettableBuilder, st: SIndexablePointer, name: String): SIndexablePointerSettable = { - new SIndexablePointerSettable(st, - sb.newSettable[Long](s"${ name }_a"), - sb.newSettable[Int](s"${ name }_length"), - sb.newSettable[Long](s"${ name }_elems_addr")) - } + def apply(sb: SettableBuilder, st: SIndexablePointer, name: String): SIndexablePointerSettable = + new SIndexablePointerSettable( + st, + sb.newSettable[Long](s"${name}_a"), + sb.newSettable[Int](s"${name}_length"), + sb.newSettable[Long](s"${name}_elems_addr"), + ) } final class SIndexablePointerSettable( st: SIndexablePointer, override val a: Settable[Long], override val length: Settable[Int], - override val elementsAddress: Settable[Long] + override val elementsAddress: Settable[Long], ) extends SIndexablePointerValue(st, a, length, elementsAddress) with SSettable { def settableTuple(): IndexedSeq[Settable[_]] = FastSeq(a, length, elementsAddress) diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SInsertFieldsStruct.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SInsertFieldsStruct.scala index f5b444ff2aa..fccd22559b4 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SInsertFieldsStruct.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SInsertFieldsStruct.scala @@ -3,13 +3,17 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s.{Settable, TypeInfo, Value} import is.hail.expr.ir.{EmitCodeBuilder, EmitSettable, EmitValue, IEmitCode} -import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue} -import is.hail.types.physical.stypes.{EmitType, SType, SValue} import is.hail.types.physical.{PCanonicalStruct, PType} +import is.hail.types.physical.stypes.{EmitType, SType, SValue} +import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue} import is.hail.types.virtual.{TStruct, Type} import is.hail.utils._ -final case class SInsertFieldsStruct(virtualType: TStruct, parent: SBaseStruct, insertedFields: IndexedSeq[(String, EmitType)]) extends SBaseStruct { +final case class SInsertFieldsStruct( + virtualType: TStruct, + parent: SBaseStruct, + insertedFields: 
IndexedSeq[(String, EmitType)], +) extends SBaseStruct { override def size: Int = virtualType.size // Maps index in result struct to index in insertedFields. @@ -18,20 +22,22 @@ final case class SInsertFieldsStruct(virtualType: TStruct, parent: SBaseStruct, .map { case ((name, _), idx) => virtualType.fieldIdx(name) -> idx } .toMap - def getFieldIndexInNewOrParent(idx: Int): Either[Int, Int] = { + def getFieldIndexInNewOrParent(idx: Int): Either[Int, Int] = insertedFieldIndices.get(idx) match { case Some(idx) => Right(idx) case None => Left(parent.fieldIdx(virtualType.fieldNames(idx))) } - } - override val fieldEmitTypes: IndexedSeq[EmitType] = virtualType.fieldNames.zipWithIndex.map { case (f, idx) => - insertedFieldIndices.get(idx) match { - case Some(idx) => insertedFields(idx)._2 - case None => parent.fieldEmitTypes(parent.fieldIdx(f)) + override val fieldEmitTypes: IndexedSeq[EmitType] = + virtualType.fieldNames.zipWithIndex.map { case (f, idx) => + insertedFieldIndices.get(idx) match { + case Some(idx) => insertedFields(idx)._2 + case None => parent.fieldEmitTypes(parent.fieldIdx(f)) + } } - } - private lazy val insertedFieldSettableStarts = insertedFields.map(_._2.nSettables).scanLeft(0)(_ + _).init + + private lazy val insertedFieldSettableStarts = + insertedFields.map(_._2.nSettables).scanLeft(0)(_ + _).init override lazy val fieldTypes: IndexedSeq[SType] = fieldEmitTypes.map(_.st) @@ -41,14 +47,20 @@ final case class SInsertFieldsStruct(virtualType: TStruct, parent: SBaseStruct, if (virtualType.size < 64) SStackStruct(virtualType, fieldEmitTypes.map(_.copiedType)) else { - val ct = SBaseStructPointer(PCanonicalStruct(false, virtualType.fieldNames.zip(fieldEmitTypes.map(_.copiedType.storageType)): _*)) + val ct = SBaseStructPointer(PCanonicalStruct( + false, + virtualType.fieldNames.zip(fieldEmitTypes.map(_.copiedType.storageType)): _* + )) assert(ct.virtualType == virtualType, s"ct=$ct, this=$this") ct } } override def storageType(): PType = { - val pt = PCanonicalStruct(false, virtualType.fieldNames.zip(fieldEmitTypes.map(_.copiedType.storageType)): _*) + val pt = PCanonicalStruct( + false, + virtualType.fieldNames.zip(fieldEmitTypes.map(_.copiedType.storageType)): _* + ) assert(pt.virtualType == virtualType, s"cp=$pt, this=$this") pt } @@ -56,34 +68,48 @@ final case class SInsertFieldsStruct(virtualType: TStruct, parent: SBaseStruct, // aspirational implementation // def storageType(): PType = StoredSTypePType(this, false) - override def containsPointers: Boolean = parent.containsPointers || insertedFields.exists(_._2.st.containsPointers) + override def containsPointers: Boolean = + parent.containsPointers || insertedFields.exists(_._2.st.containsPointers) - override lazy val settableTupleTypes: IndexedSeq[TypeInfo[_]] = parent.settableTupleTypes() ++ insertedFields.flatMap(_._2.settableTupleTypes) + override lazy val settableTupleTypes: IndexedSeq[TypeInfo[_]] = + parent.settableTupleTypes() ++ insertedFields.flatMap(_._2.settableTupleTypes) override def fromSettables(settables: IndexedSeq[Settable[_]]): SInsertFieldsStructSettable = { assert(settables.map(_.ti) == settableTupleTypes) - new SInsertFieldsStructSettable(this, parent.fromSettables(settables.take(parent.nSettables)).asInstanceOf[SBaseStructSettable], insertedFields.indices.map { i => - val et = insertedFields(i)._2 - val start = insertedFieldSettableStarts(i) + parent.nSettables - et.fromSettables(settables.slice(start, start + et.nSettables)) - }) + new SInsertFieldsStructSettable( + this, + 
parent.fromSettables(settables.take(parent.nSettables)).asInstanceOf[SBaseStructSettable], + insertedFields.indices.map { i => + val et = insertedFields(i)._2 + val start = insertedFieldSettableStarts(i) + parent.nSettables + et.fromSettables(settables.slice(start, start + et.nSettables)) + }, + ) } override def fromValues(values: IndexedSeq[Value[_]]): SInsertFieldsStructValue = { assert(values.map(_.ti) == settableTupleTypes) - new SInsertFieldsStructValue(this, parent.fromValues(values.take(parent.nSettables)).asInstanceOf[SBaseStructValue], insertedFields.indices.map { i => - val et = insertedFields(i)._2 - val start = insertedFieldSettableStarts(i) + parent.nSettables - et.fromValues(values.slice(start, start + et.nSettables)) - }) + new SInsertFieldsStructValue( + this, + parent.fromValues(values.take(parent.nSettables)).asInstanceOf[SBaseStructValue], + insertedFields.indices.map { i => + val et = insertedFields(i)._2 + val start = insertedFieldSettableStarts(i) + parent.nSettables + et.fromValues(values.slice(start, start + et.nSettables)) + }, + ) } - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = { + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value match { case ss: SInsertFieldsStructValue if ss.st == this => value case _ => throw new RuntimeException(s"copy insertfields struct") } - } override def castRename(t: Type): SType = { val ts = t.asInstanceOf[TStruct] @@ -106,28 +132,27 @@ final case class SInsertFieldsStruct(virtualType: TStruct, parent: SBaseStruct, } val parentPassThroughMap = parentPassThroughFieldBuilder.result().toMap - val parentCastType = TStruct(parentType.fieldNames.map(f => parentPassThroughMap.getOrElse(f, (f, parentType.fieldType(f)))): _*) + val parentCastType = TStruct(parentType.fieldNames.map(f => + parentPassThroughMap.getOrElse(f, (f, parentType.fieldType(f))) + ): _*) val renamedParentType = parent.castRename(parentCastType) - SInsertFieldsStruct(ts, - renamedParentType.asInstanceOf[SBaseStruct], - renamedInsertedFields - ) + SInsertFieldsStruct(ts, renamedParentType.asInstanceOf[SBaseStruct], renamedInsertedFields) } } class SInsertFieldsStructValue( val st: SInsertFieldsStruct, val parent: SBaseStructValue, - val newFields: IndexedSeq[EmitValue] + val newFields: IndexedSeq[EmitValue], ) extends SBaseStructValue { - override lazy val valueTuple: IndexedSeq[Value[_]] = parent.valueTuple ++ newFields.flatMap(_.valueTuple()) + override lazy val valueTuple: IndexedSeq[Value[_]] = + parent.valueTuple ++ newFields.flatMap(_.valueTuple()) - override def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode = { + override def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode = st.getFieldIndexInNewOrParent(fieldIdx) match { case Left(parentIdx) => parent.loadField(cb, parentIdx) case Right(newFieldsIdx) => newFields(newFieldsIdx).toI(cb) } - } override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] = st.getFieldIndexInNewOrParent(fieldIdx) match { @@ -139,7 +164,7 @@ class SInsertFieldsStructValue( val newFieldSet = fields.map(_._1).toSet val filteredNewFields = st.insertedFields.map(_._1) .zipWithIndex - .filter { case (name, idx) => !newFieldSet.contains(name) } + .filter { case (name, _) => !newFieldSet.contains(name) } .map { case (name, idx) => (name, newFields(idx)) } parent._insert(newType, filteredNewFields ++ fields: _*) } @@ -148,15 +173,14 @@ class 
SInsertFieldsStructValue( final class SInsertFieldsStructSettable( st: SInsertFieldsStruct, parent: SBaseStructSettable, - newFields: IndexedSeq[EmitSettable] + newFields: IndexedSeq[EmitSettable], ) extends SInsertFieldsStructValue(st, parent, newFields) with SBaseStructSettable { - override def settableTuple(): IndexedSeq[Settable[_]] = parent.settableTuple() ++ newFields.flatMap(_.settableTuple()) + override def settableTuple(): IndexedSeq[Settable[_]] = + parent.settableTuple() ++ newFields.flatMap(_.settableTuple()) override def store(cb: EmitCodeBuilder, sv: SValue): Unit = sv match { case sv: SInsertFieldsStructValue => parent.store(cb, sv.parent) - (newFields, sv.newFields).zipped.foreach { (settable, value) => - cb.assign(settable, value) - } + (newFields, sv.newFields).zipped.foreach((settable, value) => cb.assign(settable, value)) } -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SIntervalPointer.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SIntervalPointer.scala index 6e0e4d0e032..67d7b15aebe 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SIntervalPointer.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SIntervalPointer.scala @@ -3,31 +3,45 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s.{BooleanInfo, LongInfo, Settable, SettableBuilder, TypeInfo, Value} import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode} +import is.hail.types.physical.{PInterval, PType} import is.hail.types.physical.stypes._ import is.hail.types.physical.stypes.interfaces.{SInterval, SIntervalValue} -import is.hail.types.physical.{PInterval, PType} import is.hail.types.virtual.Type import is.hail.utils.FastSeq - final case class SIntervalPointer(pType: PInterval) extends SInterval { require(!pType.required) - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value match { case value: SIntervalValue => - new SIntervalPointerValue(this, pType.store(cb, region, value, deepCopy), value.includesStart, value.includesEnd) + new SIntervalPointerValue( + this, + pType.store(cb, region, value, deepCopy), + value.includesStart, + value.includesEnd, + ) } - - override def castRename(t: Type): SType = SIntervalPointer(pType.deepRename(t).asInstanceOf[PInterval]) + override def castRename(t: Type): SType = + SIntervalPointer(pType.deepRename(t).asInstanceOf[PInterval]) override lazy val virtualType: Type = pType.virtualType - override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(LongInfo, BooleanInfo, BooleanInfo) + override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = + FastSeq(LongInfo, BooleanInfo, BooleanInfo) override def fromSettables(settables: IndexedSeq[Settable[_]]): SIntervalPointerSettable = { - val IndexedSeq(a: Settable[Long@unchecked], includesStart: Settable[Boolean@unchecked], includesEnd: Settable[Boolean@unchecked]) = settables + val IndexedSeq( + a: Settable[Long @unchecked], + includesStart: Settable[Boolean @unchecked], + includesEnd: Settable[Boolean @unchecked], + ) = settables assert(a.ti == LongInfo) assert(includesStart.ti == BooleanInfo) assert(includesEnd.ti == BooleanInfo) @@ -35,7 +49,11 @@ final case class SIntervalPointer(pType: PInterval) extends SInterval { } override def fromValues(values: 
IndexedSeq[Value[_]]): SIntervalPointerValue = { - val IndexedSeq(a: Value[Long@unchecked], includesStart: Value[Boolean@unchecked], includesEnd: Value[Boolean@unchecked]) = values + val IndexedSeq( + a: Value[Long @unchecked], + includesStart: Value[Boolean @unchecked], + includesEnd: Value[Boolean @unchecked], + ) = values assert(a.ti == LongInfo) assert(includesStart.ti == BooleanInfo) assert(includesEnd.ti == BooleanInfo) @@ -56,43 +74,40 @@ class SIntervalPointerValue( val st: SIntervalPointer, val a: Value[Long], val includesStart: Value[Boolean], - val includesEnd: Value[Boolean] + val includesEnd: Value[Boolean], ) extends SIntervalValue { override lazy val valueTuple: IndexedSeq[Value[_]] = FastSeq(a, includesStart, includesEnd) val pt: PInterval = st.pType override def loadStart(cb: EmitCodeBuilder): IEmitCode = - IEmitCode(cb, - !pt.startDefined(cb, a), - pt.pointType.loadCheapSCode(cb, pt.loadStart(a))) + IEmitCode(cb, !pt.startDefined(cb, a), pt.pointType.loadCheapSCode(cb, pt.loadStart(a))) override def startDefined(cb: EmitCodeBuilder): Value[Boolean] = pt.startDefined(cb, a) override def loadEnd(cb: EmitCodeBuilder): IEmitCode = - IEmitCode(cb, - !pt.endDefined(cb, a), - pt.pointType.loadCheapSCode(cb, pt.loadEnd(a))) + IEmitCode(cb, !pt.endDefined(cb, a), pt.pointType.loadCheapSCode(cb, pt.loadEnd(a))) override def endDefined(cb: EmitCodeBuilder): Value[Boolean] = pt.endDefined(cb, a) } object SIntervalPointerSettable { - def apply(sb: SettableBuilder, st: SIntervalPointer, name: String): SIntervalPointerSettable = { - new SIntervalPointerSettable(st, - sb.newSettable[Long](s"${ name }_a"), - sb.newSettable[Boolean](s"${ name }_includes_start"), - sb.newSettable[Boolean](s"${ name }_includes_end")) - } + def apply(sb: SettableBuilder, st: SIntervalPointer, name: String): SIntervalPointerSettable = + new SIntervalPointerSettable( + st, + sb.newSettable[Long](s"${name}_a"), + sb.newSettable[Boolean](s"${name}_includes_start"), + sb.newSettable[Boolean](s"${name}_includes_end"), + ) } final class SIntervalPointerSettable( st: SIntervalPointer, override val a: Settable[Long], override val includesStart: Settable[Boolean], - override val includesEnd: Settable[Boolean] + override val includesEnd: Settable[Boolean], ) extends SIntervalPointerValue(st, a, includesStart, includesEnd) with SSettable { override def settableTuple(): IndexedSeq[Settable[_]] = FastSeq(a, includesStart, includesEnd) diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SJavaArray.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SJavaArray.scala index b0d46f082b4..12bdac52a5a 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SJavaArray.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SJavaArray.scala @@ -3,9 +3,9 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode} +import is.hail.types.physical.{PCanonicalArray, PCanonicalString, PString, PType} import is.hail.types.physical.stypes._ import is.hail.types.physical.stypes.interfaces.{SContainer, SIndexableValue, SStringValue} -import is.hail.types.physical.{PCanonicalArray, PCanonicalString, PString, PType} import is.hail.types.virtual.{TArray, TString, Type} import is.hail.utils.FastSeq @@ -36,11 +36,16 @@ final case class SJavaArrayString(elementRequired: Boolean) extends SContainer { override def elementEmitType: EmitType = EmitType(elementType, elementRequired) - 
override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = { + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = { value.st match { - case SJavaArrayString(_) => new SJavaArrayStringValue(this, value.asInstanceOf[SJavaArrayStringValue].array) + case SJavaArrayString(_) => + new SJavaArrayStringValue(this, value.asInstanceOf[SJavaArrayStringValue].array) case SIndexablePointer(pc) if pc.elementType.isInstanceOf[PString] => - val sv = value.asInstanceOf[SIndexableValue] val len = sv.loadLength() val array = cb.memoize[Array[String]](Code.newArray[String](len)) @@ -50,12 +55,19 @@ final case class SJavaArrayString(elementRequired: Boolean) extends SContainer { cb += (array(i) = v.loadString(cb)) } case (false, r) => - sv.forEachDefinedOrMissing(cb)({ case (cb, i) => - if (r) - cb._fatal("requiredness mismatch: found missing value at index ", i.toS, s" coercing ${ sv.st } to $this") - }, { case (cb, i, elt) => - cb += (array(i) = elt.asString.loadString(cb)) - }) + sv.forEachDefinedOrMissing(cb)( + { case (cb, i) => + if (r) + cb._fatal( + "requiredness mismatch: found missing value at index ", + i.toS, + s" coercing ${sv.st} to $this", + ) + }, + { case (cb, i, elt) => + cb += (array(i) = elt.asString.loadString(cb)) + }, + ) case (false, false) => } new SJavaArrayStringValue(this, array) @@ -65,12 +77,12 @@ final case class SJavaArrayString(elementRequired: Boolean) extends SContainer { override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(arrayInfo[String]) override def fromSettables(settables: IndexedSeq[Settable[_]]): SJavaArrayStringSettable = { - val IndexedSeq(a: Settable[Array[String]@unchecked]) = settables + val IndexedSeq(a: Settable[Array[String] @unchecked]) = settables new SJavaArrayStringSettable(this, a) } override def fromValues(values: IndexedSeq[Value[_]]): SJavaArrayStringValue = { - val IndexedSeq(a: Value[Array[String]@unchecked]) = values + val IndexedSeq(a: Value[Array[String] @unchecked]) = values new SJavaArrayStringValue(this, a) } @@ -80,7 +92,7 @@ final case class SJavaArrayString(elementRequired: Boolean) extends SContainer { class SJavaArrayStringValue( val st: SJavaArrayString, - val array: Value[Array[String]] + val array: Value[Array[String]], ) extends SIndexableValue { override lazy val valueTuple: IndexedSeq[Value[_]] = FastSeq(array) @@ -93,39 +105,37 @@ class SJavaArrayStringValue( IEmitCode.present(cb, new SJavaStringValue(cb.memoize(array(i)))) else { val iv = cb.memoize(i) - IEmitCode(cb, - isElementMissing(cb, iv), - new SJavaStringValue(cb.memoize(array(iv)))) + IEmitCode(cb, isElementMissing(cb, iv), new SJavaStringValue(cb.memoize(array(iv)))) } } override def isElementMissing(cb: EmitCodeBuilder, i: Code[Int]): Value[Boolean] = cb.memoize(array(i).isNull) - override def hasMissingValues(cb: EmitCodeBuilder): Value[Boolean] = { + override def hasMissingValues(cb: EmitCodeBuilder): Value[Boolean] = if (st.elementRequired) const(false) else - cb.memoize(Code.invokeScalaObject1[Array[String], Boolean](SJavaArrayHelpers.getClass, "hasNulls", array)) - } + cb.memoize(Code.invokeScalaObject1[Array[String], Boolean]( + SJavaArrayHelpers.getClass, + "hasNulls", + array, + )) override def castToArray(cb: EmitCodeBuilder): SIndexableValue = this } object SJavaArrayStringSettable { - def apply(sb: SettableBuilder, st: SJavaArrayString, name: String): SJavaArrayStringSettable = { - new 
SJavaArrayStringSettable(st, - sb.newSettable[Array[String]](s"${ name }_arr")) - } + def apply(sb: SettableBuilder, st: SJavaArrayString, name: String): SJavaArrayStringSettable = + new SJavaArrayStringSettable(st, sb.newSettable[Array[String]](s"${name}_arr")) } final class SJavaArrayStringSettable( st: SJavaArrayString, - override val array: Settable[Array[String]] + override val array: Settable[Array[String]], ) extends SJavaArrayStringValue(st, array) with SSettable { override def settableTuple(): IndexedSeq[Settable[_]] = FastSeq(array) - override def store(cb: EmitCodeBuilder, v: SValue): Unit = { + override def store(cb: EmitCodeBuilder, v: SValue): Unit = cb.assign(array, v.asInstanceOf[SJavaArrayStringValue].array) - } } diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SJavaBytes.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SJavaBytes.scala index fbaa41d5957..c0d422f3b22 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SJavaBytes.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SJavaBytes.scala @@ -3,9 +3,9 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.interfaces.{SBinary, SBinaryValue} -import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.physical.{PCanonicalBinary, PType} +import is.hail.types.physical.stypes.{SSettable, SType, SValue} +import is.hail.types.physical.stypes.interfaces.{SBinary, SBinaryValue} import is.hail.types.virtual._ import is.hail.utils.FastSeq @@ -20,7 +20,12 @@ case object SJavaBytes extends SBinary { override def containsPointers: Boolean = false - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SJavaBytesValue = + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SJavaBytesValue = value.st match { case SJavaBytes => value.asInstanceOf[SJavaBytesValue] case _ => new SJavaBytesValue(value.asBinary.loadBytes(cb)) @@ -29,12 +34,12 @@ case object SJavaBytes extends SBinary { override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(arrayInfo[Byte]) override def fromSettables(settables: IndexedSeq[Settable[_]]): SJavaBytesSettable = { - val IndexedSeq(b: Settable[Array[Byte]@unchecked]) = settables + val IndexedSeq(b: Settable[Array[Byte] @unchecked]) = settables new SJavaBytesSettable(b) } override def fromValues(values: IndexedSeq[Value[_]]): SJavaBytesValue = { - val IndexedSeq(b: Value[Array[Byte]@unchecked]) = values + val IndexedSeq(b: Value[Array[Byte] @unchecked]) = values new SJavaBytesValue(b) } } @@ -54,15 +59,14 @@ class SJavaBytesValue(val bytes: Value[Array[Byte]]) extends SBinaryValue { } object SJavaBytesSettable { - def apply(sb: SettableBuilder, name: String): SJavaBytesSettable = { - new SJavaBytesSettable(sb.newSettable[Array[Byte]](s"${ name }_bytes")) - } + def apply(sb: SettableBuilder, name: String): SJavaBytesSettable = + new SJavaBytesSettable(sb.newSettable[Array[Byte]](s"${name}_bytes")) } -final class SJavaBytesSettable(override val bytes: Settable[Array[Byte]]) extends SJavaBytesValue(bytes) with SSettable { +final class SJavaBytesSettable(override val bytes: Settable[Array[Byte]]) + extends SJavaBytesValue(bytes) with SSettable { override def settableTuple(): IndexedSeq[Settable[_]] = FastSeq(bytes) - override def store(cb: 
EmitCodeBuilder, v: SValue): Unit = { + override def store(cb: EmitCodeBuilder, v: SValue): Unit = cb.assign(bytes, v.asInstanceOf[SJavaBytesValue].bytes) - } } diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SJavaString.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SJavaString.scala index b2ac469f49f..5bf9a8c0c52 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SJavaString.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SJavaString.scala @@ -3,10 +3,10 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder +import is.hail.types.physical.{PCanonicalString, PType} +import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives.SInt64Value -import is.hail.types.physical.stypes.{SSettable, SType, SValue} -import is.hail.types.physical.{PCanonicalString, PType} import is.hail.types.virtual.{TString, Type} import is.hail.utils.FastSeq @@ -21,28 +21,32 @@ case object SJavaString extends SString { override def castRename(t: Type): SType = this - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SJavaStringValue = { + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SJavaStringValue = value.st match { case SJavaString => value.asInstanceOf[SJavaStringValue] case _ => new SJavaStringValue(value.asString.loadString(cb)) } - } override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(classInfo[String]) override def fromSettables(settables: IndexedSeq[Settable[_]]): SJavaStringSettable = { - val IndexedSeq(s: Settable[String@unchecked]) = settables + val IndexedSeq(s: Settable[String @unchecked]) = settables new SJavaStringSettable(s) } override def fromValues(values: IndexedSeq[Value[_]]): SJavaStringValue = { - val IndexedSeq(s: Value[String@unchecked]) = values + val IndexedSeq(s: Value[String @unchecked]) = values new SJavaStringValue(s) } - override def constructFromString(cb: EmitCodeBuilder, r: Value[Region], s: Code[String]): SJavaStringValue = { + override def constructFromString(cb: EmitCodeBuilder, r: Value[Region], s: Code[String]) + : SJavaStringValue = new SJavaStringValue(cb.memoize(s)) - } def construct(cb: EmitCodeBuilder, s: Code[String]): SJavaStringValue = new SJavaStringValue(cb.memoize(s)) @@ -70,15 +74,14 @@ class SJavaStringValue(val s: Value[String]) extends SStringValue { } object SJavaStringSettable { - def apply(sb: SettableBuilder, name: String): SJavaStringSettable = { - new SJavaStringSettable(sb.newSettable[String](s"${ name }_str")) - } + def apply(sb: SettableBuilder, name: String): SJavaStringSettable = + new SJavaStringSettable(sb.newSettable[String](s"${name}_str")) } -final class SJavaStringSettable(override val s: Settable[String]) extends SJavaStringValue(s) with SSettable { +final class SJavaStringSettable(override val s: Settable[String]) + extends SJavaStringValue(s) with SSettable { override def settableTuple(): IndexedSeq[Settable[_]] = FastSeq(s) - override def store(cb: EmitCodeBuilder, v: SValue): Unit = { + override def store(cb: EmitCodeBuilder, v: SValue): Unit = cb.assign(s, v.asInstanceOf[SJavaStringValue].s) - } } diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SNDArrayPointer.scala 
b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SNDArrayPointer.scala index 0369d2e4c4a..6712ca75b87 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SNDArrayPointer.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SNDArrayPointer.scala @@ -3,11 +3,11 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.interfaces._ -import is.hail.types.physical.stypes.{SType, SValue} import is.hail.types.physical.{PCanonicalNDArray, PType} +import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.interfaces._ import is.hail.types.virtual.Type -import is.hail.utils.{FastSeq, toRichIterable} +import is.hail.utils.{toRichIterable, FastSeq} final case class SNDArrayPointer(pType: PCanonicalNDArray) extends SNDArray { require(!pType.required) @@ -24,31 +24,53 @@ final case class SNDArrayPointer(pType: PCanonicalNDArray) extends SNDArray { override def castRename(t: Type): SType = SNDArrayPointer(pType.deepRename(t)) - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value match { case value: SNDArrayValue => val a = pType.store(cb, region, value, deepCopy) - new SNDArrayPointerValue(this, a, value.shapes, value.strides, cb.memoize(pType.dataFirstElementPointer(a))) + new SNDArrayPointerValue( + this, + a, + value.shapes, + value.strides, + cb.memoize(pType.dataFirstElementPointer(a)), + ) } override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = Array.fill(2 + nDims * 2)(LongInfo) override def fromSettables(settables: IndexedSeq[Settable[_]]): SNDArrayPointerSettable = { - val a = settables(0).asInstanceOf[Settable[Long@unchecked]] - val shape = settables.slice(1, 1 + pType.nDims).asInstanceOf[IndexedSeq[Settable[Long@unchecked]]] - val strides = settables.slice(1 + pType.nDims, 1 + 2 * pType.nDims).asInstanceOf[IndexedSeq[Settable[Long@unchecked]]] + val a = settables(0).asInstanceOf[Settable[Long @unchecked]] + val shape = + settables.slice(1, 1 + pType.nDims).asInstanceOf[IndexedSeq[Settable[Long @unchecked]]] + val strides = settables + .slice(1 + pType.nDims, 1 + 2 * pType.nDims) + .asInstanceOf[IndexedSeq[Settable[Long @unchecked]]] val dataFirstElementPointer = settables.last.asInstanceOf[Settable[Long]] assert(a.ti == LongInfo) new SNDArrayPointerSettable(this, a, shape, strides, dataFirstElementPointer) } override def fromValues(values: IndexedSeq[Value[_]]): SNDArrayPointerValue = { - val a = values(0).asInstanceOf[Value[Long@unchecked]] - val shape = values.slice(1, 1 + pType.nDims).asInstanceOf[IndexedSeq[Value[Long@unchecked]]] - val strides = values.slice(1 + pType.nDims, 1 + 2 * pType.nDims).asInstanceOf[IndexedSeq[Value[Long@unchecked]]] + val a = values(0).asInstanceOf[Value[Long @unchecked]] + val shape = values.slice(1, 1 + pType.nDims).asInstanceOf[IndexedSeq[Value[Long @unchecked]]] + val strides = values + .slice(1 + pType.nDims, 1 + 2 * pType.nDims) + .asInstanceOf[IndexedSeq[Value[Long @unchecked]]] val dataFirstElementPointer = values.last.asInstanceOf[Value[Long]] assert(a.ti == LongInfo) - new SNDArrayPointerValue(this, a, shape.map(SizeValueDyn.apply), strides, dataFirstElementPointer) + new SNDArrayPointerValue( + this, + a, + shape.map(SizeValueDyn.apply), + 
strides, + dataFirstElementPointer, + ) } override def storageType(): PType = pType @@ -63,16 +85,18 @@ class SNDArrayPointerValue( val a: Value[Long], val shapes: IndexedSeq[SizeValue], val strides: IndexedSeq[Value[Long]], - val firstDataAddress: Value[Long] + val firstDataAddress: Value[Long], ) extends SNDArrayValue { val pt: PCanonicalNDArray = st.pType - override lazy val valueTuple: IndexedSeq[Value[_]] = FastSeq(a) ++ shapes ++ strides ++ FastSeq(firstDataAddress) + override lazy val valueTuple: IndexedSeq[Value[_]] = + FastSeq(a) ++ shapes ++ strides ++ FastSeq(firstDataAddress) override def shapeStruct(cb: EmitCodeBuilder): SBaseStructValue = pt.shapeType.loadCheapSCode(cb, pt.representation.loadField(a, "shape")) - override def loadElementAddress(indices: IndexedSeq[Value[Long]], cb: EmitCodeBuilder): Code[Long] = { + override def loadElementAddress(indices: IndexedSeq[Value[Long]], cb: EmitCodeBuilder) + : Code[Long] = { assert(indices.size == pt.nDims) pt.loadElementFromDataAndStrides(cb, indices, firstDataAddress, strides) } @@ -82,7 +106,8 @@ class SNDArrayPointerValue( pt.elementType.loadCheapSCode(cb, loadElementAddress(indices, cb)) } - override def coerceToShape(cb: EmitCodeBuilder, otherShape: IndexedSeq[SizeValue]): SNDArrayValue = { + override def coerceToShape(cb: EmitCodeBuilder, otherShape: IndexedSeq[SizeValue]) + : SNDArrayValue = { cb.if_(!hasShape(cb, otherShape), cb._fatal("incompatible shapes")) new SNDArrayPointerValue(st, a, otherShape, strides, firstDataAddress) } @@ -94,7 +119,8 @@ class SNDArrayPointerValue( indexVars: IndexedSeq[String], destIndices: IndexedSeq[Int], arrays: (SNDArrayValue, IndexedSeq[Int], String)* - )(body: IndexedSeq[SValue] => SValue + )( + body: IndexedSeq[SValue] => SValue ): Unit = { SNDArray._coiterate(cb, indexVars, (this, destIndices, "dest") +: arrays: _*) { ptrs => val codes = (this +: arrays.map(_._1)).zip(ptrs).toFastSeq.map { case (array, ptr) => @@ -107,13 +133,14 @@ class SNDArrayPointerValue( } object SNDArrayPointerSettable { - def apply(sb: SettableBuilder, st: SNDArrayPointer, name: String): SNDArrayPointerSettable = { - new SNDArrayPointerSettable(st, sb.newSettable[Long](name), + def apply(sb: SettableBuilder, st: SNDArrayPointer, name: String): SNDArrayPointerSettable = + new SNDArrayPointerSettable( + st, + sb.newSettable[Long](name), Array.tabulate(st.pType.nDims)(i => sb.newSettable[Long](s"${name}_nd_shape_$i")), Array.tabulate(st.pType.nDims)(i => sb.newSettable[Long](s"${name}_nd_strides_$i")), - sb.newSettable[Long](s"${name}_nd_first_element") + sb.newSettable[Long](s"${name}_nd_first_element"), ) - } } final class SNDArrayPointerSettable( @@ -121,9 +148,11 @@ final class SNDArrayPointerSettable( override val a: Settable[Long], val shape: IndexedSeq[Settable[Long]], override val strides: IndexedSeq[Settable[Long]], - override val firstDataAddress: Settable[Long] -) extends SNDArrayPointerValue(st, a, shape.map(SizeValueDyn.apply), strides, firstDataAddress) with SNDArraySettable { - def settableTuple(): IndexedSeq[Settable[_]] = FastSeq(a) ++ shape ++ strides ++ FastSeq(firstDataAddress) + override val firstDataAddress: Settable[Long], +) extends SNDArrayPointerValue(st, a, shape.map(SizeValueDyn.apply), strides, firstDataAddress) + with SNDArraySettable { + def settableTuple(): IndexedSeq[Settable[_]] = + FastSeq(a) ++ shape ++ strides ++ FastSeq(firstDataAddress) def store(cb: EmitCodeBuilder, v: SValue): Unit = v match { case v: SNDArrayPointerValue => @@ -132,4 +161,4 @@ final class 
SNDArrayPointerSettable( (strides, v.strides).zipped.foreach(cb.assign(_, _)) cb.assign(firstDataAddress, v.firstDataAddress) } -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SNDArraySlice.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SNDArraySlice.scala index 54364485b8e..61b39e042c6 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SNDArraySlice.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SNDArraySlice.scala @@ -3,10 +3,10 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.{EmitCodeBuilder, EmitValue} +import is.hail.types.physical.{PCanonicalNDArray, PType} import is.hail.types.physical.stypes._ import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives.SInt64 -import is.hail.types.physical.{PCanonicalNDArray, PType} import is.hail.types.virtual.{TNDArray, Type} import is.hail.utils.toRichIterable @@ -27,26 +27,31 @@ final case class SNDArraySlice(pType: PCanonicalNDArray) extends SNDArray { override def castRename(t: Type): SType = SNDArrayPointer(pType.deepRename(t)) - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value.st match { case SNDArraySlice(`pType`) if !deepCopy => value } - - override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = Array.fill(2*nDims + 1)(LongInfo) + override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = Array.fill(2 * nDims + 1)(LongInfo) override def fromSettables(settables: IndexedSeq[Settable[_]]): SNDArraySliceSettable = { - assert(settables.length == 2*nDims + 1) - val shape = settables.slice(0, nDims).asInstanceOf[IndexedSeq[Settable[Long@unchecked]]] - val strides = settables.slice(nDims, 2 * nDims).asInstanceOf[IndexedSeq[Settable[Long@unchecked]]] + assert(settables.length == 2 * nDims + 1) + val shape = settables.slice(0, nDims).asInstanceOf[IndexedSeq[Settable[Long @unchecked]]] + val strides = + settables.slice(nDims, 2 * nDims).asInstanceOf[IndexedSeq[Settable[Long @unchecked]]] val dataFirstElementPointer = settables.last.asInstanceOf[Settable[Long]] new SNDArraySliceSettable(this, shape, strides, dataFirstElementPointer) } override def fromValues(settables: IndexedSeq[Value[_]]): SNDArraySliceValue = { - assert(settables.length == 2*nDims + 1) - val shape = settables.slice(0, nDims).asInstanceOf[IndexedSeq[Value[Long@unchecked]]] - val strides = settables.slice(nDims, 2 * nDims).asInstanceOf[IndexedSeq[Value[Long@unchecked]]] + assert(settables.length == 2 * nDims + 1) + val shape = settables.slice(0, nDims).asInstanceOf[IndexedSeq[Value[Long @unchecked]]] + val strides = settables.slice(nDims, 2 * nDims).asInstanceOf[IndexedSeq[Value[Long @unchecked]]] val dataFirstElementPointer = settables.last.asInstanceOf[Value[Long]] new SNDArraySliceValue(this, shape.map(SizeValueDyn.apply), strides, dataFirstElementPointer) } @@ -58,18 +63,20 @@ class SNDArraySliceValue( override val st: SNDArraySlice, override val shapes: IndexedSeq[SizeValue], override val strides: IndexedSeq[Value[Long]], - override val firstDataAddress: Value[Long] + override val firstDataAddress: Value[Long], ) extends SNDArrayValue { val pt: PCanonicalNDArray = st.pType override lazy val valueTuple: IndexedSeq[Value[_]] = shapes ++ 
strides :+ firstDataAddress override def shapeStruct(cb: EmitCodeBuilder): SStackStructValue = { - val shapeType = SStackStruct(st.virtualType.shapeType, Array.fill(st.nDims)(EmitType(SInt64, true))) + val shapeType = + SStackStruct(st.virtualType.shapeType, Array.fill(st.nDims)(EmitType(SInt64, true))) new SStackStructValue(shapeType, shapes.map(x => EmitValue.present(primitive(x)))) } - override def loadElementAddress(indices: IndexedSeq[Value[Long]], cb: EmitCodeBuilder): Code[Long] = { + override def loadElementAddress(indices: IndexedSeq[Value[Long]], cb: EmitCodeBuilder) + : Code[Long] = { assert(indices.size == pt.nDims) pt.loadElementFromDataAndStrides(cb, indices, firstDataAddress, strides) } @@ -89,7 +96,8 @@ class SNDArraySliceValue( indexVars: IndexedSeq[String], destIndices: IndexedSeq[Int], arrays: (SNDArrayValue, IndexedSeq[Int], String)* - )(body: IndexedSeq[SValue] => SValue + )( + body: IndexedSeq[SValue] => SValue ): Unit = { SNDArray._coiterate(cb, indexVars, (this, destIndices, "dest") +: arrays: _*) { ptrs => val codes = (this +: arrays.map(_._1)).zip(ptrs).toFastSeq.map { case (array, ptr) => @@ -102,27 +110,28 @@ class SNDArraySliceValue( } object SNDArraySliceSettable { - def apply(sb: SettableBuilder, st: SNDArraySlice, name: String): SNDArraySliceSettable = { - new SNDArraySliceSettable(st, + def apply(sb: SettableBuilder, st: SNDArraySlice, name: String): SNDArraySliceSettable = + new SNDArraySliceSettable( + st, Array.tabulate(st.pType.nDims)(i => sb.newSettable[Long](s"${name}_nd_shape_$i")), Array.tabulate(st.pType.nDims)(i => sb.newSettable[Long](s"${name}_nd_strides_$i")), - sb.newSettable[Long](s"${name}_nd_first_element") + sb.newSettable[Long](s"${name}_nd_first_element"), ) - } } final class SNDArraySliceSettable( st: SNDArraySlice, shape: IndexedSeq[Settable[Long]], override val strides: IndexedSeq[Settable[Long]], - override val firstDataAddress: Settable[Long] -) extends SNDArraySliceValue(st, shape.map(SizeValueDyn.apply), strides, firstDataAddress) with SSettable { + override val firstDataAddress: Settable[Long], +) extends SNDArraySliceValue(st, shape.map(SizeValueDyn.apply), strides, firstDataAddress) + with SSettable { override def settableTuple(): IndexedSeq[Settable[_]] = shape ++ strides :+ firstDataAddress override def store(cb: EmitCodeBuilder, v: SValue): Unit = v match { case v: SNDArraySliceValue => - (shape, v.shapes).zipped.foreach { (x, s) => cb.assign(x, s) } - (strides, v.strides).zipped.foreach { (x, s) => cb.assign(x, s) } + (shape, v.shapes).zipped.foreach((x, s) => cb.assign(x, s)) + (strides, v.strides).zipped.foreach((x, s) => cb.assign(x, s)) cb.assign(firstDataAddress, v.firstDataAddress) } } diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SRNGState.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SRNGState.scala index a887e57cd2a..9466d00f64d 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SRNGState.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SRNGState.scala @@ -3,14 +3,18 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.{EmitCodeBuilder, Threefry, ThreefryRandomEngine} -import is.hail.types.physical.stypes.primitives.SInt64Value -import is.hail.types.physical.stypes.{SSettable, SType, SValue} +import is.hail.types.{RPrimitive, TypeWithRequiredness} import is.hail.types.physical.{PType, StoredSTypePType} +import is.hail.types.physical.stypes.{SSettable, 
SType, SValue} +import is.hail.types.physical.stypes.primitives.SInt64Value import is.hail.types.virtual.{TRNGState, Type} -import is.hail.types.{RPrimitive, TypeWithRequiredness} import is.hail.utils.FastSeq -final case class SRNGStateStaticInfo(numWordsInLastBlock: Int, hasStaticSplit: Boolean, numDynBlocks: Int) { +final case class SRNGStateStaticInfo( + numWordsInLastBlock: Int, + hasStaticSplit: Boolean, + numDynBlocks: Int, +) { assert(numWordsInLastBlock <= 4 && numWordsInLastBlock >= 0) } @@ -22,7 +26,12 @@ final case class SRNGState(staticInfo: Option[SRNGStateStaticInfo]) extends STyp override def containsPointers: Boolean = false - override protected[stypes] def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = + override protected[stypes] def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value.st match { case SRNGState(_) => value } @@ -34,21 +43,24 @@ final case class SRNGState(staticInfo: Option[SRNGStateStaticInfo]) extends STyp Array.fill(4 + info.numWordsInLastBlock)(typeInfo[Long]) } - override def fromSettables(settables: IndexedSeq[Settable[_]]): SRNGStateSettable = staticInfo match { - case None => - new SCanonicalRNGStateSettable( - this, - settables.slice(0, 4).asInstanceOf[IndexedSeq[Settable[Long]]], - settables.slice(4, 8).asInstanceOf[IndexedSeq[Settable[Long]]], - coerce[Int](settables(8)), - coerce[Boolean](settables(9)), - coerce[Int](settables(10))) - case Some(_) => - new SRNGStateStaticSizeSettable( - this, - settables.slice(0, 4).asInstanceOf[IndexedSeq[Settable[Long]]], - settables.drop(4).asInstanceOf[IndexedSeq[Settable[Long]]]) - } + override def fromSettables(settables: IndexedSeq[Settable[_]]): SRNGStateSettable = + staticInfo match { + case None => + new SCanonicalRNGStateSettable( + this, + settables.slice(0, 4).asInstanceOf[IndexedSeq[Settable[Long]]], + settables.slice(4, 8).asInstanceOf[IndexedSeq[Settable[Long]]], + coerce[Int](settables(8)), + coerce[Boolean](settables(9)), + coerce[Int](settables(10)), + ) + case Some(_) => + new SRNGStateStaticSizeSettable( + this, + settables.slice(0, 4).asInstanceOf[IndexedSeq[Settable[Long]]], + settables.drop(4).asInstanceOf[IndexedSeq[Settable[Long]]], + ) + } override def fromValues(values: IndexedSeq[Value[_]]): SRNGStateValue = staticInfo match { case None => @@ -58,12 +70,14 @@ final case class SRNGState(staticInfo: Option[SRNGStateStaticInfo]) extends STyp values.slice(4, 8).asInstanceOf[IndexedSeq[Settable[Long]]], coerce[Int](values(8)), coerce[Boolean](values(9)), - coerce[Int](values(10))) - case Some(info) => + coerce[Int](values(10)), + ) + case Some(_) => new SRNGStateStaticSizeValue( this, values.slice(0, 4).asInstanceOf[IndexedSeq[Settable[Long]]], - values.drop(4).asInstanceOf[IndexedSeq[Settable[Long]]]) + values.drop(4).asInstanceOf[IndexedSeq[Settable[Long]]], + ) } override def storageType(): PType = StoredSTypePType(this, false) @@ -90,7 +104,10 @@ object SCanonicalRNGStateValue { typ, Array.fill[Value[Long]](4)(0), Array.fill[Value[Long]](4)(0), - 0, false, 0) + 0, + false, + 0, + ) } } @@ -100,7 +117,7 @@ class SCanonicalRNGStateValue( val lastDynBlock: IndexedSeq[Value[Long]], val numWordsInLastBlock: Value[Int], val hasStaticSplit: Value[Boolean], - val numDynBlocks: Value[Int] + val numDynBlocks: Value[Int], ) extends SRNGStateValue { override def valueTuple: IndexedSeq[Value[_]] = @@ -109,7 +126,7 @@ class SCanonicalRNGStateValue( FastSeq(numWordsInLastBlock, 
hasStaticSplit, numDynBlocks) override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = - new SInt64Value(4*8 + 4*8 + 4 + 4 + 4) + new SInt64Value(4 * 8 + 4 * 8 + 4 + 4 + 4) def splitStatic(cb: EmitCodeBuilder, idx: Long): SCanonicalRNGStateValue = { cb.if_(hasStaticSplit, cb._fatal("RNGState received two static splits")) @@ -120,49 +137,66 @@ class SCanonicalRNGStateValue( Threefry.encrypt(key, Array(Threefry.staticTweak, 0L), x) val newDynBlocksSum = Array.tabulate[Value[Long]](4)(i => cb.memoize(runningSum(i) ^ x(i))) - new SCanonicalRNGStateValue(st, newDynBlocksSum, lastDynBlock, numWordsInLastBlock, const(true), numDynBlocks) + new SCanonicalRNGStateValue( + st, + newDynBlocksSum, + lastDynBlock, + numWordsInLastBlock, + const(true), + numDynBlocks, + ) } def splitDyn(cb: EmitCodeBuilder, idx: Value[Long]): SCanonicalRNGStateValue = { - val newRunningSum = Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"splitDyn_x$i", runningSum(i))) - val newLastDynBlock = Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"splitDyn_m$i", lastDynBlock(i))) + val newRunningSum = + Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"splitDyn_x$i", runningSum(i))) + val newLastDynBlock = + Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"splitDyn_m$i", lastDynBlock(i))) val newNumWordsInLastBlock = cb.newLocal[Int](s"splitDyn_numWords", numWordsInLastBlock) val newNumDynBlocks = cb.newLocal[Int](s"splitDyn_numBlocks", numDynBlocks) - cb.if_(numWordsInLastBlock < 4, { - cb.switch(numWordsInLastBlock, - cb._fatal("invalid numWordsInLastBlock: ", numWordsInLastBlock.toS), - for {i <- 0 until 4} yield () => cb.assign(newLastDynBlock(i), idx) - ) - cb.assign(newNumWordsInLastBlock, newNumWordsInLastBlock + 1) - }, { - val key = Threefry.defaultKey - Threefry.encrypt(cb, key, Array(cb.memoize(numDynBlocks.toL), const(0L)), newLastDynBlock) - for (i <- 0 until 4) cb.assign(newRunningSum(i), newRunningSum(i) ^ newLastDynBlock(i)) - cb.assign(newLastDynBlock(0), idx) - for (i <- 1 until 4) cb.assign(newLastDynBlock(i), 0L) - cb.assign(newNumWordsInLastBlock, 1) - cb.assign(newNumDynBlocks, newNumDynBlocks + 1) - }) - - new SCanonicalRNGStateValue(st, newRunningSum, newLastDynBlock, newNumWordsInLastBlock, hasStaticSplit, newNumDynBlocks) + cb.if_( + numWordsInLastBlock < 4, { + cb.switch( + numWordsInLastBlock, + cb._fatal("invalid numWordsInLastBlock: ", numWordsInLastBlock.toS), + for { i <- 0 until 4 } yield () => cb.assign(newLastDynBlock(i), idx), + ) + cb.assign(newNumWordsInLastBlock, newNumWordsInLastBlock + 1) + }, { + val key = Threefry.defaultKey + Threefry.encrypt(cb, key, Array(cb.memoize(numDynBlocks.toL), const(0L)), newLastDynBlock) + for (i <- 0 until 4) cb.assign(newRunningSum(i), newRunningSum(i) ^ newLastDynBlock(i)) + cb.assign(newLastDynBlock(0), idx) + for (i <- 1 until 4) cb.assign(newLastDynBlock(i), 0L) + cb.assign(newNumWordsInLastBlock, 1) + cb.assign(newNumDynBlocks, newNumDynBlocks + 1) + }, + ) + + new SCanonicalRNGStateValue(st, newRunningSum, newLastDynBlock, newNumWordsInLastBlock, + hasStaticSplit, newNumDynBlocks) } def rand(cb: EmitCodeBuilder): IndexedSeq[Value[Long]] = { cb.if_(!hasStaticSplit, cb._fatal("RNGState never received static split")) val x = Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"rand_x$i", runningSum(i))) val key = Threefry.defaultKey - val finalTweak = cb.memoize((numWordsInLastBlock ceq 4).mux(Threefry.finalBlockNoPadTweak, Threefry.finalBlockPaddedTweak)) + val finalTweak = 
cb.memoize((numWordsInLastBlock ceq 4).mux( + Threefry.finalBlockNoPadTweak, + Threefry.finalBlockPaddedTweak, + )) for (i <- 0 until 4) cb.assign(x(i), x(i) ^ lastDynBlock(i)) - cb.switch(numWordsInLastBlock, + cb.switch( + numWordsInLastBlock, cb._fatal("invalid numWordsInLastBlock: ", numWordsInLastBlock.toS), FastSeq( () => cb += (x(0) := x(0) ^ 1L), () => cb += (x(1) := x(1) ^ 1L), () => cb += (x(2) := x(2) ^ 1L), () => cb += (x(3) := x(3) ^ 1L), - () => {} - ) + () => {}, + ), ) Threefry.encrypt(cb, key, Array(finalTweak, const(0L)), x) x @@ -171,19 +205,28 @@ class SCanonicalRNGStateValue( def copyIntoEngine(cb: EmitCodeBuilder, tf: Value[ThreefryRandomEngine]): Unit = { cb.if_(!hasStaticSplit, cb._fatal("RNGState never received static split")) val x = Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"cie_x$i", runningSum(i))) - val finalTweak = (numWordsInLastBlock ceq 4).mux(Threefry.finalBlockNoPadTweak, Threefry.finalBlockPaddedTweak) + val finalTweak = + (numWordsInLastBlock ceq 4).mux(Threefry.finalBlockNoPadTweak, Threefry.finalBlockPaddedTweak) for (i <- 0 until 4) cb.assign(x(i), x(i) ^ lastDynBlock(i)) - cb.switch(numWordsInLastBlock, + cb.switch( + numWordsInLastBlock, cb._fatal("invalid numWordsInLastBlock: ", numWordsInLastBlock.toS), FastSeq( () => cb += (x(0) := x(0) ^ 1L), () => cb += (x(1) := x(1) ^ 1L), () => cb += (x(2) := x(2) ^ 1L), () => cb += (x(3) := x(3) ^ 1L), - () => {} - ) + () => {}, + ), + ) + cb += tf.invoke[Long, Long, Long, Long, Long, Unit]( + "resetState", + x(0), + x(1), + x(2), + x(3), + finalTweak, ) - cb += tf.invoke[Long, Long, Long, Long, Long, Unit]("resetState", x(0), x(1), x(2), x(3), finalTweak) } } @@ -193,13 +236,13 @@ class SCanonicalRNGStateSettable( override val lastDynBlock: IndexedSeq[Settable[Long]], override val numWordsInLastBlock: Settable[Int], override val hasStaticSplit: Settable[Boolean], - override val numDynBlocks: Settable[Int] -) extends SCanonicalRNGStateValue(st, runningSum, lastDynBlock, numWordsInLastBlock, hasStaticSplit, numDynBlocks) - with SRNGStateSettable { + override val numDynBlocks: Settable[Int], +) extends SCanonicalRNGStateValue(st, runningSum, lastDynBlock, numWordsInLastBlock, hasStaticSplit, + numDynBlocks) with SRNGStateSettable { override def store(cb: EmitCodeBuilder, v: SValue): Unit = v match { case v: SCanonicalRNGStateValue => - (runningSum, v.runningSum).zipped.foreach { (x, s) => cb.assign(x, s) } - (lastDynBlock, v.lastDynBlock).zipped.foreach { (x, s) => cb.assign(x, s) } + (runningSum, v.runningSum).zipped.foreach((x, s) => cb.assign(x, s)) + (lastDynBlock, v.lastDynBlock).zipped.foreach((x, s) => cb.assign(x, s)) cb.assign(numWordsInLastBlock, v.numWordsInLastBlock) cb.assign(hasStaticSplit, v.hasStaticSplit) cb.assign(numDynBlocks, v.numDynBlocks) @@ -217,14 +260,15 @@ object SRNGStateStaticSizeValue { new SRNGStateStaticSizeValue( typ, Array.fill[Value[Long]](4)(0), - Array[Value[Long]]()) + Array[Value[Long]](), + ) } } class SRNGStateStaticSizeValue( override val st: SRNGState, val runningSum: IndexedSeq[Value[Long]], - val lastDynBlock: IndexedSeq[Value[Long]] + val lastDynBlock: IndexedSeq[Value[Long]], ) extends SRNGStateValue { val staticInfo = st.staticInfo.get assert(runningSum.length == 4) @@ -234,7 +278,7 @@ class SRNGStateStaticSizeValue( runningSum ++ lastDynBlock override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = - new SInt64Value(4*8 + staticInfo.numWordsInLastBlock*8) + new SInt64Value(4 * 8 + staticInfo.numWordsInLastBlock * 8) def 
splitStatic(cb: EmitCodeBuilder, idx: Long): SRNGStateStaticSizeValue = { assert(!staticInfo.hasStaticSplit) @@ -246,28 +290,42 @@ class SRNGStateStaticSizeValue( val newDynBlocksSum = Array.tabulate[Value[Long]](4)(i => cb.memoize(runningSum(i) ^ x(i))) new SRNGStateStaticSizeValue( - st = SRNGState(Some(SRNGStateStaticInfo(staticInfo.numWordsInLastBlock, true, staticInfo.numDynBlocks))), + st = SRNGState(Some(SRNGStateStaticInfo( + staticInfo.numWordsInLastBlock, + true, + staticInfo.numDynBlocks, + ))), runningSum = newDynBlocksSum, - lastDynBlock = lastDynBlock) + lastDynBlock = lastDynBlock, + ) } def splitDyn(cb: EmitCodeBuilder, idx: Value[Long]): SRNGStateStaticSizeValue = { if (staticInfo.numWordsInLastBlock < 4) { return new SRNGStateStaticSizeValue( - st = SRNGState(Some(SRNGStateStaticInfo(staticInfo.numWordsInLastBlock + 1, staticInfo.hasStaticSplit, staticInfo.numDynBlocks))), + st = SRNGState(Some(SRNGStateStaticInfo( + staticInfo.numWordsInLastBlock + 1, + staticInfo.hasStaticSplit, + staticInfo.numDynBlocks, + ))), runningSum = runningSum, - lastDynBlock = lastDynBlock :+ idx + lastDynBlock = lastDynBlock :+ idx, ) } - val x = Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"splitDyn_x$i", lastDynBlock(i))) + val x = + Array.tabulate[Settable[Long]](4)(i => cb.newLocal[Long](s"splitDyn_x$i", lastDynBlock(i))) val key = Threefry.defaultKey Threefry.encrypt(cb, key, Array(const(staticInfo.numDynBlocks.toLong), const(0L)), x) for (i <- 0 until 4) cb.assign(x(i), x(i) ^ runningSum(i)) new SRNGStateStaticSizeValue( - st = SRNGState(Some(SRNGStateStaticInfo(1, staticInfo.hasStaticSplit, staticInfo.numDynBlocks + 1))), + st = SRNGState(Some(SRNGStateStaticInfo( + 1, + staticInfo.hasStaticSplit, + staticInfo.numDynBlocks + 1, + ))), runningSum = x, - lastDynBlock = Array(idx) + lastDynBlock = Array(idx), ) } @@ -295,21 +353,28 @@ class SRNGStateStaticSizeValue( } else { Threefry.finalBlockNoPadTweak } - cb += tf.invoke[Long, Long, Long, Long, Long, Unit]("resetState", x(0), x(1), x(2), x(3), finalTweak) + cb += tf.invoke[Long, Long, Long, Long, Long, Unit]( + "resetState", + x(0), + x(1), + x(2), + x(3), + finalTweak, + ) } } class SRNGStateStaticSizeSettable( st: SRNGState, override val runningSum: IndexedSeq[Settable[Long]], - override val lastDynBlock: IndexedSeq[Settable[Long]] + override val lastDynBlock: IndexedSeq[Settable[Long]], ) extends SRNGStateStaticSizeValue(st, runningSum, lastDynBlock) with SRNGStateSettable { override def store(cb: EmitCodeBuilder, v: SValue): Unit = v match { case v: SRNGStateStaticSizeValue => - (runningSum, v.runningSum).zipped.foreach { (x, s) => cb.assign(x, s) } - (lastDynBlock, v.lastDynBlock).zipped.foreach { (x, s) => cb.assign(x, s) } + (runningSum, v.runningSum).zipped.foreach((x, s) => cb.assign(x, s)) + (lastDynBlock, v.lastDynBlock).zipped.foreach((x, s) => cb.assign(x, s)) } override def settableTuple(): IndexedSeq[Settable[_]] = runningSum ++ lastDynBlock -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStackInterval.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStackInterval.scala index 55ae27ef217..ff4b996b713 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStackInterval.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStackInterval.scala @@ -10,7 +10,12 @@ import is.hail.types.virtual.{TInterval, Type} import is.hail.utils.FastSeq object SStackInterval { - def construct(start: EmitValue, end: 
EmitValue, includesStart: Value[Boolean], includesEnd: Value[Boolean]): SStackIntervalValue = { + def construct( + start: EmitValue, + end: EmitValue, + includesStart: Value[Boolean], + includesEnd: Value[Boolean], + ): SStackIntervalValue = { assert(start.emitType == end.emitType) new SStackIntervalValue(SStackInterval(start.emitType), start, end, includesStart, includesEnd) } @@ -18,25 +23,33 @@ object SStackInterval { final case class SStackInterval(pointEmitType: EmitType) extends SInterval { - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value match { - case value: SStackIntervalValue => new SStackIntervalValue(this, - pointEmitType.coerceOrCopy(cb, region, value.start, deepCopy), - pointEmitType.coerceOrCopy(cb, region, value.end, deepCopy), - value.includesStart, - value.includesEnd - ) + case value: SStackIntervalValue => new SStackIntervalValue( + this, + pointEmitType.coerceOrCopy(cb, region, value.start, deepCopy), + pointEmitType.coerceOrCopy(cb, region, value.end, deepCopy), + value.includesStart, + value.includesEnd, + ) case value: SIntervalValue => - new SStackIntervalValue(this, + new SStackIntervalValue( + this, pointEmitType.coerceOrCopy(cb, region, cb.memoize(value.loadStart(cb)), deepCopy), pointEmitType.coerceOrCopy(cb, region, cb.memoize(value.loadEnd(cb)), deepCopy), value.includesStart, - value.includesEnd - ) + value.includesEnd, + ) } - - override def castRename(t: Type): SType = SStackInterval(pointEmitType.copy(st = pointType.castRename(t.asInstanceOf[TInterval].pointType))) + override def castRename(t: Type): SType = SStackInterval(pointEmitType.copy(st = + pointType.castRename(t.asInstanceOf[TInterval].pointType) + )) override lazy val virtualType: Type = TInterval(pointEmitType.virtualType) @@ -48,21 +61,25 @@ final case class SStackInterval(pointEmitType: EmitType) extends SInterval { override def fromSettables(settables: IndexedSeq[Settable[_]]): SStackIntervalSettable = { val pointNSettables = pointEmitType.nSettables assert(settables.length == 2 * pointNSettables + 2) - new SStackIntervalSettable(this, + new SStackIntervalSettable( + this, pointEmitType.fromSettables(settables.slice(0, pointNSettables)), pointEmitType.fromSettables(settables.slice(pointNSettables, 2 * pointNSettables)), coerce[Boolean](settables(pointNSettables * 2)), - coerce[Boolean](settables(pointNSettables * 2 + 1))) + coerce[Boolean](settables(pointNSettables * 2 + 1)), + ) } override def fromValues(values: IndexedSeq[Value[_]]): SStackIntervalValue = { val pointNValues = pointEmitType.nSettables assert(values.length == 2 * pointNValues + 2) - new SStackIntervalValue(this, + new SStackIntervalValue( + this, pointEmitType.fromValues(values.slice(0, pointNValues)), pointEmitType.fromValues(values.slice(pointNValues, 2 * pointNValues)), coerce[Boolean](values(pointNValues * 2)), - coerce[Boolean](values(pointNValues * 2 + 1))) + coerce[Boolean](values(pointNValues * 2 + 1)), + ) } override def pointType: SType = pointEmitType.st @@ -79,10 +96,12 @@ class SStackIntervalValue( val start: EmitValue, val end: EmitValue, val includesStart: Value[Boolean], - val includesEnd: Value[Boolean] + val includesEnd: Value[Boolean], ) extends SIntervalValue { require(start.emitType == end.emitType && start.emitType == st.pointEmitType) - override lazy val valueTuple: IndexedSeq[Value[_]] = 
start.valueTuple() ++ end.valueTuple() ++ FastSeq(includesStart, includesEnd) + + override lazy val valueTuple: IndexedSeq[Value[_]] = + start.valueTuple() ++ end.valueTuple() ++ FastSeq(includesStart, includesEnd) override def loadStart(cb: EmitCodeBuilder): IEmitCode = start.toI(cb) @@ -98,9 +117,10 @@ final class SStackIntervalSettable( override val start: EmitSettable, override val end: EmitSettable, override val includesStart: Settable[Boolean], - override val includesEnd: Settable[Boolean] + override val includesEnd: Settable[Boolean], ) extends SStackIntervalValue(st, start, end, includesStart, includesEnd) with SSettable { - override lazy val settableTuple: IndexedSeq[Settable[_]] = start.settableTuple() ++ end.settableTuple() ++ FastSeq(includesStart, includesEnd) + override lazy val settableTuple: IndexedSeq[Settable[_]] = + start.settableTuple() ++ end.settableTuple() ++ FastSeq(includesStart, includesEnd) override def store(cb: EmitCodeBuilder, v: SValue): Unit = v match { case v: SStackIntervalValue => diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStackStruct.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStackStruct.scala index f0af086f806..9b7e3bcc9f3 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStackStruct.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStackStruct.scala @@ -1,27 +1,34 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region +import is.hail.asm4s._ import is.hail.expr.ir.{EmitCode, EmitCodeBuilder, EmitSettable, EmitValue, IEmitCode} -import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue} -import is.hail.types.physical.stypes.{EmitType, SCode, SType, SValue} import is.hail.types.physical._ -import is.hail.utils._ -import is.hail.asm4s._ -import is.hail.types.physical.stypes.primitives.SInt64Value +import is.hail.types.physical.stypes.{EmitType, SType, SValue} +import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue} import is.hail.types.virtual.{TBaseStruct, TStruct, TTuple, Type} object SStackStruct { val MAX_FIELDS_FOR_CONSTRUCT: Int = 64 - def constructFromArgs(cb: EmitCodeBuilder, region: Value[Region], t: TBaseStruct, args: EmitCode*): SBaseStructValue = { + def constructFromArgs(cb: EmitCodeBuilder, region: Value[Region], t: TBaseStruct, args: EmitCode*) + : SBaseStructValue = { val as = args.toArray assert(t.size == args.size) if (region != null && as.length > MAX_FIELDS_FOR_CONSTRUCT) { val structType: PCanonicalBaseStruct = t match { case ts: TStruct => - PCanonicalStruct(false, ts.fieldNames.zip(as.map(_.emitType)).map { case (f, et) => (f, et.storageType) }: _*) + PCanonicalStruct( + false, + ts.fieldNames.zip(as.map(_.emitType)).map { case (f, et) => (f, et.storageType) }: _* + ) case tt: TTuple => - PCanonicalTuple(tt._types.zip(as.map(_.emitType)).map { case (tf, et) => PTupleField(tf.index, et.storageType) }, false) + PCanonicalTuple( + tt._types.zip(as.map(_.emitType)).map { case (tf, et) => + PTupleField(tf.index, et.storageType) + }, + false, + ) } structType.constructFromFields(cb, region, as, false) } else { @@ -31,7 +38,8 @@ object SStackStruct { } } -final case class SStackStruct(virtualType: TBaseStruct, fieldEmitTypes: IndexedSeq[EmitType]) extends SBaseStruct { +final case class SStackStruct(virtualType: TBaseStruct, fieldEmitTypes: IndexedSeq[EmitType]) + extends SBaseStruct { override def size: Int = virtualType.size 
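// --- Illustrative sketch (editor's addition, not part of the patch): several of the
// --- struct stypes touched here (e.g. SInsertFieldsStruct, SStackStruct) flatten a
// --- value into one flat tuple of JVM settables and use `scanLeft(0)(_ + _).init`
// --- to compute each field's starting offset, which `fromSettables`/`fromValues`
// --- then use to slice the flat tuple back into per-field groups. The field widths
// --- and placeholder names below are made-up assumptions for illustration only.
object SettableOffsetsSketch extends App {
  // Suppose three fields need 1, 3, and 2 settables respectively.
  val nSettables = IndexedSeq(1, 3, 2)

  // scanLeft(0)(_ + _) yields the running sums Vector(0, 1, 4, 6); `.init`
  // drops the grand total, leaving each field's start offset: Vector(0, 1, 4).
  val starts = nSettables.scanLeft(0)(_ + _).init

  // A flat tuple of six "settables" (plain strings stand in for Settable[_] here).
  val flat = IndexedSeq("a0", "b0", "b1", "b2", "c0", "c1")

  // Slice the flat tuple back into per-field groups, as fromSettables does above.
  val perField =
    nSettables.indices.map(i => flat.slice(starts(i), starts(i) + nSettables(i)))

  assert(perField == IndexedSeq(
    IndexedSeq("a0"),
    IndexedSeq("b0", "b1", "b2"),
    IndexedSeq("c0", "c1"),
  ))
  println(perField.mkString(", "))
}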
private lazy val settableStarts = fieldEmitTypes.map(_.nSettables).scanLeft(0)(_ + _).init @@ -42,33 +50,59 @@ final case class SStackStruct(virtualType: TBaseStruct, fieldEmitTypes: IndexedS override def storageType(): PType = virtualType match { case ts: TStruct => - PCanonicalStruct(false, ts.fieldNames.zip(fieldEmitTypes).map { case (f, et) => (f, et.storageType) }: _*) + PCanonicalStruct( + false, + ts.fieldNames.zip(fieldEmitTypes).map { case (f, et) => (f, et.storageType) }: _* + ) case tt: TTuple => - PCanonicalTuple(tt._types.zip(fieldEmitTypes).map { case (tf, et) => PTupleField(tf.index, et.storageType) }, false) + PCanonicalTuple( + tt._types.zip(fieldEmitTypes).map { case (tf, et) => + PTupleField(tf.index, et.storageType) + }, + false, + ) } - override def copiedType: SType = SStackStruct(virtualType, fieldEmitTypes.map(f => f.copy(st = f.st.copiedType))) + override def copiedType: SType = + SStackStruct(virtualType, fieldEmitTypes.map(f => f.copy(st = f.st.copiedType))) override def containsPointers: Boolean = fieldEmitTypes.exists(_.st.containsPointers) - override lazy val settableTupleTypes: IndexedSeq[TypeInfo[_]] = fieldEmitTypes.flatMap(_.settableTupleTypes) + override lazy val settableTupleTypes: IndexedSeq[TypeInfo[_]] = + fieldEmitTypes.flatMap(_.settableTupleTypes) override def fromSettables(settables: IndexedSeq[Settable[_]]): SStackStructSettable = { - assert(settables.length == fieldEmitTypes.map(_.nSettables).sum, s"mismatch: ${ settables.length } settables, expect ${ fieldEmitTypes.map(_.nSettables).sum }\n ${ settables.map(_.ti).mkString(",") }\n ${ fieldEmitTypes.map(_.settableTupleTypes).mkString(" | ") }") - new SStackStructSettable(this, fieldEmitTypes.indices.map { i => - val et = fieldEmitTypes(i) - val start = settableStarts(i) - et.fromSettables(settables.slice(start, start + et.nSettables)) - }) + assert( + settables.length == fieldEmitTypes.map(_.nSettables).sum, + s"mismatch: ${settables.length} settables, expect ${fieldEmitTypes.map(_.nSettables).sum}\n ${settables.map( + _.ti + ).mkString(",")}\n ${fieldEmitTypes.map(_.settableTupleTypes).mkString(" | ")}", + ) + new SStackStructSettable( + this, + fieldEmitTypes.indices.map { i => + val et = fieldEmitTypes(i) + val start = settableStarts(i) + et.fromSettables(settables.slice(start, start + et.nSettables)) + }, + ) } override def fromValues(values: IndexedSeq[Value[_]]): SStackStructValue = { - assert(values.length == fieldEmitTypes.map(_.nSettables).sum, s"mismatch: ${ values.length } settables, expect ${ fieldEmitTypes.map(_.nSettables).sum }\n ${ values.map(_.ti).mkString(",") }\n ${ fieldEmitTypes.map(_.settableTupleTypes).mkString(" | ") }") - new SStackStructValue(this, fieldEmitTypes.indices.map { i => - val et = fieldEmitTypes(i) - val start = settableStarts(i) - et.fromValues(values.slice(start, start + et.nSettables)) - }) + assert( + values.length == fieldEmitTypes.map(_.nSettables).sum, + s"mismatch: ${values.length} settables, expect ${fieldEmitTypes.map(_.nSettables).sum}\n ${values.map( + _.ti + ).mkString(",")}\n ${fieldEmitTypes.map(_.settableTupleTypes).mkString(" | ")}", + ) + new SStackStructValue( + this, + fieldEmitTypes.indices.map { i => + val et = fieldEmitTypes(i) + val start = settableStarts(i) + et.fromValues(values.slice(start, start + et.nSettables)) + }, + ) } def fromEmitCodes(cb: EmitCodeBuilder, values: IndexedSeq[EmitCode]): SStackStructValue = { @@ -76,34 +110,47 @@ final case class SStackStruct(virtualType: TBaseStruct, fieldEmitTypes: IndexedS s } - override 
def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = { + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = { value match { case ss: SStackStructValue => if (ss.st == this && !deepCopy) ss else - new SStackStructValue(this, fieldEmitTypes.zip(ss.values).map { case (newType, ev) => - val iec = ev.map(cb) { field => newType.st.coerceOrCopy(cb, region, field, deepCopy) } - (newType.required, iec.required) match { - case (true, false) => EmitValue.present(iec.get(cb)) - case (false, true) => iec.setOptional - case _ => iec - } - }) + new SStackStructValue( + this, + fieldEmitTypes.zip(ss.values).map { case (newType, ev) => + val iec = ev.map(cb)(field => newType.st.coerceOrCopy(cb, region, field, deepCopy)) + (newType.required, iec.required) match { + case (true, false) => EmitValue.present(iec.get(cb)) + case (false, true) => iec.setOptional + case _ => iec + } + }, + ) case _ => val sv = value.asBaseStruct - new SStackStructValue(this, Array.tabulate[EmitValue](size) { i => - val newType = fieldEmitTypes(i) - val ec = EmitCode.fromI(cb.emb) { cb => - sv.loadField(cb, i).map(cb) { field => newType.st.coerceOrCopy(cb, region, field, deepCopy) } - } - val ev = ec.memoize(cb, "_coerceOrCopy") - (newType.required, ev.required) match { - case (true, false) => EmitValue.present(ev.get(cb)) - case (false, true) => ev.setOptional - case _ => ev - } - }) + new SStackStructValue( + this, + Array.tabulate[EmitValue](size) { i => + val newType = fieldEmitTypes(i) + val ec = EmitCode.fromI(cb.emb) { cb => + sv.loadField(cb, i).map(cb) { field => + newType.st.coerceOrCopy(cb, region, field, deepCopy) + } + } + val ev = ec.memoize(cb, "_coerceOrCopy") + (newType.required, ev.required) match { + case (true, false) => EmitValue.present(ev.get(cb)) + case (false, true) => ev.setOptional + case _ => ev + } + }, + ) } } @@ -111,35 +158,40 @@ final case class SStackStruct(virtualType: TBaseStruct, fieldEmitTypes: IndexedS val ts = t.asInstanceOf[TBaseStruct] SStackStruct( ts, - ts.types.zip(fieldEmitTypes).map { case (v, e) => e.copy(st = e.st.castRename(v)) } + ts.types.zip(fieldEmitTypes).map { case (v, e) => e.copy(st = e.st.castRename(v)) }, ) } } -class SStackStructValue(val st: SStackStruct, val values: IndexedSeq[EmitValue]) extends SBaseStructValue { - assert((st.fieldTypes, values).zipped.forall { (st, v) => v.st == st }, - s"type mismatch!\n struct type: $st\n value types: ${values.map(_.st).mkString("[", ", ", "]")}") +class SStackStructValue(val st: SStackStruct, val values: IndexedSeq[EmitValue]) + extends SBaseStructValue { + assert( + (st.fieldTypes, values).zipped.forall((st, v) => v.st == st), + s"type mismatch!\n struct type: $st\n value types: ${values.map(_.st).mkString("[", ", ", "]")}", + ) override lazy val valueTuple: IndexedSeq[Value[_]] = values.flatMap(_.valueTuple) - override def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode = { + override def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode = values(fieldIdx).toI(cb) - } override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] = values(fieldIdx).m - override def subset(fieldNames: String*): SStackStructValue = { + override def subset(fieldNames: String*): SBaseStructValue = { val newToOld = fieldNames.map(st.fieldIdx).toArray val oldVType = st.virtualType.asInstanceOf[TStruct] val newVirtualType = TStruct(newToOld.map(i => (oldVType.fieldNames(i), 
oldVType.types(i))): _*) - new SStackStructValue(SStackStruct(newVirtualType, newToOld.map(st.fieldEmitTypes)), newToOld.map(values)) + new SStackStructValue( + SStackStruct(newVirtualType, newToOld.map(st.fieldEmitTypes)), + newToOld.map(values), + ) } } final class SStackStructSettable( st: SStackStruct, - settables: IndexedSeq[EmitSettable] + settables: IndexedSeq[EmitSettable], ) extends SStackStructValue(st, settables) with SBaseStructSettable { override def settableTuple(): IndexedSeq[Settable[_]] = settables.flatMap(_.settableTuple()) diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStringPointer.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStringPointer.scala index aec9479996d..350905dd5d8 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStringPointer.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStringPointer.scala @@ -3,14 +3,13 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s.{Code, LongInfo, Settable, SettableBuilder, TypeInfo, Value} import is.hail.expr.ir.EmitCodeBuilder +import is.hail.types.physical.{PString, PType} +import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.physical.stypes.interfaces.{SString, SStringValue} import is.hail.types.physical.stypes.primitives.SInt64Value -import is.hail.types.physical.stypes.{SSettable, SType, SValue} -import is.hail.types.physical.{PString, PType} import is.hail.types.virtual.Type import is.hail.utils.FastSeq - final case class SStringPointer(pType: PString) extends SString { require(!pType.required) @@ -18,26 +17,31 @@ final case class SStringPointer(pType: PString) extends SString { override def castRename(t: Type): SType = this - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = new SStringPointerValue(this, pType.store(cb, region, value, deepCopy)) override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(LongInfo) override def fromSettables(settables: IndexedSeq[Settable[_]]): SStringPointerSettable = { - val IndexedSeq(a: Settable[Long@unchecked]) = settables + val IndexedSeq(a: Settable[Long @unchecked]) = settables assert(a.ti == LongInfo) new SStringPointerSettable(this, a) } override def fromValues(values: IndexedSeq[Value[_]]): SStringPointerValue = { - val IndexedSeq(a: Value[Long@unchecked]) = values + val IndexedSeq(a: Value[Long @unchecked]) = values assert(a.ti == LongInfo) new SStringPointerValue(this, a) } - override def constructFromString(cb: EmitCodeBuilder, r: Value[Region], s: Code[String]): SStringPointerValue = { + override def constructFromString(cb: EmitCodeBuilder, r: Value[Region], s: Code[String]) + : SStringPointerValue = new SStringPointerValue(this, pType.allocateAndStoreString(cb, r, s)) - } override def storageType(): PType = pType @@ -51,7 +55,8 @@ class SStringPointerValue(val st: SStringPointer, val a: Value[Long]) extends SS override lazy val valueTuple: IndexedSeq[Value[_]] = FastSeq(a) - def binaryRepr(): SBinaryPointerValue = new SBinaryPointerValue(SBinaryPointer(st.pType.binaryRepresentation), a) + def binaryRepr(): SBinaryPointerValue = + new SBinaryPointerValue(SBinaryPointer(st.pType.binaryRepresentation), a) def loadLength(cb: EmitCodeBuilder): Value[Int] = cb.memoize(pt.loadLength(a)) @@ -62,22 +67,22 
@@ class SStringPointerValue(val st: SStringPointer, val a: Value[Long]) extends SS def toBytes(cb: EmitCodeBuilder): SBinaryPointerValue = new SBinaryPointerValue(SBinaryPointer(pt.binaryRepresentation), a) - override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = this.binaryRepr().sizeToStoreInBytes(cb) + override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = + this.binaryRepr().sizeToStoreInBytes(cb) } object SStringPointerSettable { - def apply(sb: SettableBuilder, st: SStringPointer, name: String): SStringPointerSettable = { - new SStringPointerSettable(st, - sb.newSettable[Long](s"${ name }_a")) - } + def apply(sb: SettableBuilder, st: SStringPointer, name: String): SStringPointerSettable = + new SStringPointerSettable(st, sb.newSettable[Long](s"${name}_a")) } -final class SStringPointerSettable(st: SStringPointer, override val a: Settable[Long]) extends SStringPointerValue(st, a) with SSettable { +final class SStringPointerSettable(st: SStringPointer, override val a: Settable[Long]) + extends SStringPointerValue(st, a) with SSettable { override def settableTuple(): IndexedSeq[Settable[_]] = FastSeq(a) - override def store(cb: EmitCodeBuilder, v: SValue): Unit = { + override def store(cb: EmitCodeBuilder, v: SValue): Unit = cb.assign(a, v.asInstanceOf[SStringPointerValue].a) - } - override def binaryRepr(): SBinaryPointerSettable = new SBinaryPointerSettable(SBinaryPointer(st.pType.binaryRepresentation), a) + override def binaryRepr(): SBinaryPointerSettable = + new SBinaryPointerSettable(SBinaryPointer(st.pType.binaryRepresentation), a) } diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStructView.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStructView.scala new file mode 100644 index 00000000000..7a970e49c8b --- /dev/null +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SStructView.scala @@ -0,0 +1,162 @@ +package is.hail.types.physical.stypes.concrete + +import is.hail.annotations.Region +import is.hail.asm4s.{Settable, TypeInfo, Value} +import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode} +import is.hail.types.physical.{PCanonicalStruct, PType} +import is.hail.types.physical.stypes.{EmitType, SType, SValue} +import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue} +import is.hail.types.virtual.{TBaseStruct, TStruct, Type} + +object SStructView { + def subset(fieldnames: IndexedSeq[String], struct: SBaseStruct): SStructView = + struct match { + case s: SStructView => + val pfields = s.parent.virtualType.fields + new SStructView( + s.parent, + fieldnames.map(f => pfields(s.newToOldFieldMapping(s.fieldIdx(f))).name), + s.rename.typeAfterSelectNames(fieldnames), + ) + + case s => + val restrict = s.virtualType.asInstanceOf[TStruct].typeAfterSelectNames(fieldnames) + new SStructView(s, fieldnames, restrict) + } +} + +// A 'view' on `SBaseStruct`s, ie one that presents an upcast and/or renamed facade on another +final class SStructView( + private val parent: SBaseStruct, + private val restrict: IndexedSeq[String], + private val rename: TStruct, +) extends SBaseStruct { + + assert( + parent.virtualType.asInstanceOf[TStruct].typeAfterSelectNames(restrict) isIsomorphicTo rename, + s"""Renamed type is not isomorphic to subsetted type + | parent: '${parent.virtualType._toPretty}' + | restrict: '${restrict.mkString("[", ",", "]")}' + | rename: '${rename._toPretty}' + |""".stripMargin, + ) + + override def size: Int = + restrict.length + + lazy val 
newToOldFieldMapping: Map[Int, Int] = + restrict.view.zipWithIndex.map { case (f, i) => i -> parent.fieldIdx(f) }.toMap + + override lazy val fieldTypes: IndexedSeq[SType] = + Array.tabulate(size) { i => + parent + .fieldTypes(newToOldFieldMapping(i)) + .castRename(rename.fields(i).typ) + } + + override lazy val fieldEmitTypes: IndexedSeq[EmitType] = + Array.tabulate(size) { i => + parent + .fieldEmitTypes(newToOldFieldMapping(i)) + .copy(st = fieldTypes(i)) + } + + override def virtualType: TBaseStruct = + rename + + override def fieldIdx(fieldName: String): Int = + rename.fieldIdx(fieldName) + + override def castRename(t: Type): SType = + new SStructView(parent, restrict, rename = t.asInstanceOf[TStruct]) + + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = { + if (deepCopy) + throw new NotImplementedError("Deep copy on struct view") + + value.st match { + case s: SStructView if this == s && !deepCopy => + value + } + } + + override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = + parent.settableTupleTypes() + + override def fromSettables(settables: IndexedSeq[Settable[_]]): SStructViewSettable = + new SStructViewSettable( + this, + parent.fromSettables(settables).asInstanceOf[SBaseStructSettable], + ) + + override def fromValues(values: IndexedSeq[Value[_]]): SStructViewValue = + new SStructViewValue(this, parent.fromValues(values).asInstanceOf[SBaseStructValue]) + + override def copiedType: SType = + if (virtualType.size < 64) + SStackStruct(virtualType, fieldEmitTypes.map(_.copiedType)) + else { + val ct = SBaseStructPointer(storageType().asInstanceOf[PCanonicalStruct]) + assert(ct.virtualType == virtualType, s"ct=$ct, this=$this") + ct + } + + def storageType(): PType = { + val pt = PCanonicalStruct( + required = false, + args = rename.fieldNames.zip(fieldEmitTypes.map(_.copiedType.storageType)): _*, + ) + assert(pt.virtualType == virtualType, s"pt=$pt, this=$this") + pt + } + + // aspirational implementation + // def storageType(): PType = StoredSTypePType(this, false) + + override def containsPointers: Boolean = + parent.containsPointers + + override def equals(obj: Any): Boolean = + obj match { + case s: SStructView => + rename == s.rename && + newToOldFieldMapping == s.newToOldFieldMapping && + parent == s.parent + case _ => + false + } +} + +class SStructViewValue(val st: SStructView, val prev: SBaseStructValue) extends SBaseStructValue { + + override lazy val valueTuple: IndexedSeq[Value[_]] = + prev.valueTuple + + override def subset(fieldNames: String*): SBaseStructValue = + new SStructViewValue(SStructView.subset(fieldNames.toIndexedSeq, st), prev) + + override def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode = + prev + .loadField(cb, st.newToOldFieldMapping(fieldIdx)) + .map(cb)(_.castRename(st.virtualType.fields(fieldIdx).typ)) + + override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] = + prev.isFieldMissing(cb, st.newToOldFieldMapping(fieldIdx)) +} + +final class SStructViewSettable(st: SStructView, prev: SBaseStructSettable) + extends SStructViewValue(st, prev) with SBaseStructSettable { + override def subset(fieldNames: String*): SBaseStructValue = + new SStructViewSettable(SStructView.subset(fieldNames.toIndexedSeq, st), prev) + + override def settableTuple(): IndexedSeq[Settable[_]] = + prev.settableTuple() + + override def store(cb: EmitCodeBuilder, pv: SValue): Unit = + prev.store(cb, pv.asInstanceOf[SStructViewValue].prev) +} diff --git 
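A minimal illustrative sketch (not part of the diff) of how the new SStructView facade is meant to be used from emit code, assuming an EmitCodeBuilder `cb` and an SBaseStructValue `row` with fields "a", "b", "c" are already in scope; `subset` builds a view that remaps field indices onto the parent value instead of copying data:

    // hypothetical emit-time usage; `row` and `cb` are assumed, not defined here
    val view: SBaseStructValue = row.subset("c", "a")
    // field 0 of the view resolves through newToOldFieldMapping to field "c" of the parent
    val cField: IEmitCode = view.loadField(cb, 0)
    // name-based lookup goes through fieldIdx on the renamed virtual type
    val aField: IEmitCode = view.loadField(cb, "a")

The view carries the parent's settables unchanged (settableTupleTypes delegates to the parent), so subsetting and renaming stay free at runtime; only castRename on the loaded fields adjusts the virtual types.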
a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SSubsetStruct.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SSubsetStruct.scala deleted file mode 100644 index 29ade6dab61..00000000000 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SSubsetStruct.scala +++ /dev/null @@ -1,102 +0,0 @@ -package is.hail.types.physical.stypes.concrete - -import is.hail.annotations.Region -import is.hail.asm4s.{Settable, TypeInfo, Value} -import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode} -import is.hail.types.physical.stypes.interfaces.{SBaseStruct, SBaseStructSettable, SBaseStructValue} -import is.hail.types.physical.stypes.{EmitType, SCode, SType, SValue} -import is.hail.types.physical.{PCanonicalStruct, PType} -import is.hail.types.virtual.{TStruct, Type} - -final case class SSubsetStruct(parent: SBaseStruct, fieldNames: IndexedSeq[String]) extends SBaseStruct { - - override val size: Int = fieldNames.size - - val _fieldIdx: Map[String, Int] = fieldNames.zipWithIndex.toMap - val newToOldFieldMapping: Map[Int, Int] = _fieldIdx - .map { case (f, i) => (i, parent.virtualType.asInstanceOf[TStruct].fieldIdx(f)) } - - override val fieldTypes: IndexedSeq[SType] = Array.tabulate(size)(i => parent.fieldTypes(newToOldFieldMapping(i))) - override val fieldEmitTypes: IndexedSeq[EmitType] = Array.tabulate(size)(i => parent.fieldEmitTypes(newToOldFieldMapping(i))) - - override lazy val virtualType: TStruct = { - val vparent = parent.virtualType.asInstanceOf[TStruct] - TStruct(fieldNames.map(f => (f, vparent.field(f).typ)): _*) - } - - override def fieldIdx(fieldName: String): Int = _fieldIdx(fieldName) - - override def castRename(t: Type): SType = { - val renamedVType = t.asInstanceOf[TStruct] - val newNames = renamedVType.fieldNames - val subsetPrevVirtualType = virtualType - val vparent = parent.virtualType.asInstanceOf[TStruct] - val newParent = TStruct(vparent.fieldNames.map(f => subsetPrevVirtualType.fieldIdx.get(f) match { - case Some(idxInSelectedFields) => - val renamed = renamedVType.fields(idxInSelectedFields) - (renamed.name, renamed.typ) - case None => (f, vparent.fieldType(f)) - }): _*) - val newType = SSubsetStruct(parent.castRename(newParent).asInstanceOf[SBaseStruct], newNames) - assert(newType.virtualType == t) - newType - } - - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = { - if (deepCopy) - throw new NotImplementedError("Deep copy on subset struct") - value.st match { - case SSubsetStruct(parent2, fd2) if parent == parent2 && fieldNames == fd2 && !deepCopy => - value - } - } - - override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = parent.settableTupleTypes() - - override def fromSettables(settables: IndexedSeq[Settable[_]]): SSubsetStructSettable = { - new SSubsetStructSettable(this, parent.fromSettables(settables).asInstanceOf[SBaseStructSettable]) - } - - override def fromValues(values: IndexedSeq[Value[_]]): SSubsetStructValue = { - new SSubsetStructValue(this, parent.fromValues(values).asInstanceOf[SBaseStructValue]) - } - - override def copiedType: SType = { - if (virtualType.size < 64) - SStackStruct(virtualType, fieldEmitTypes.map(_.copiedType)) - else { - val ct = SBaseStructPointer(PCanonicalStruct(false, virtualType.fieldNames.zip(fieldEmitTypes.map(_.copiedType.storageType)): _*)) - assert(ct.virtualType == virtualType, s"ct=$ct, this=$this") - ct - } - } - - def storageType(): PType = { - val pt = PCanonicalStruct(false, 
virtualType.fieldNames.zip(fieldEmitTypes.map(_.copiedType.storageType)): _*) - assert(pt.virtualType == virtualType, s"pt=$pt, this=$this") - pt - } - -// aspirational implementation -// def storageType(): PType = StoredSTypePType(this, false) - - override def containsPointers: Boolean = parent.containsPointers -} - -class SSubsetStructValue(val st: SSubsetStruct, val prev: SBaseStructValue) extends SBaseStructValue { - override lazy val valueTuple: IndexedSeq[Value[_]] = prev.valueTuple - - override def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode = { - prev.loadField(cb, st.newToOldFieldMapping(fieldIdx)) - } - - override def isFieldMissing(cb: EmitCodeBuilder, fieldIdx: Int): Value[Boolean] = - prev.isFieldMissing(cb, st.newToOldFieldMapping(fieldIdx)) -} - -final class SSubsetStructSettable(st: SSubsetStruct, prev: SBaseStructSettable) extends SSubsetStructValue(st, prev) with SBaseStructSettable { - override def settableTuple(): IndexedSeq[Settable[_]] = prev.settableTuple() - - override def store(cb: EmitCodeBuilder, pv: SValue): Unit = - prev.store(cb, pv.asInstanceOf[SSubsetStructValue].prev) -} diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SUnreachable.scala b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SUnreachable.scala index 007ec2533f6..5a6f47cf336 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SUnreachable.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/concrete/SUnreachable.scala @@ -3,10 +3,10 @@ package is.hail.types.physical.stypes.concrete import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.{EmitCodeBuilder, EmitValue, IEmitCode} +import is.hail.types.physical.{PCanonicalNDArray, PNDArray, PType} import is.hail.types.physical.stypes._ import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives.SInt64Value -import is.hail.types.physical.{PCanonicalNDArray, PNDArray, PType} import is.hail.types.virtual._ import is.hail.utils.FastSeq @@ -32,7 +32,8 @@ object SUnreachable { abstract class SUnreachable extends SType { override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq() - override def storageType(): PType = PType.canonical(virtualType, required = false, innerRequired = true) + override def storageType(): PType = + PType.canonical(virtualType, required = false, innerRequired = true) override def asIdent: String = s"s_unreachable" @@ -44,7 +45,12 @@ abstract class SUnreachable extends SType { override def fromValues(values: IndexedSeq[Value[_]]): SUnreachableValue = sv - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = sv + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = sv override def copiedType: SType = this @@ -72,7 +78,8 @@ case class SUnreachableStruct(virtualType: TBaseStruct) extends SUnreachable wit override val sv = new SUnreachableStructValue(this) } -class SUnreachableStructValue(override val st: SUnreachableStruct) extends SUnreachableValue with SBaseStructValue { +class SUnreachableStructValue(override val st: SUnreachableStruct) + extends SUnreachableValue with SBaseStructValue { override def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode = IEmitCode.present(cb, SUnreachable.fromVirtualType(st.virtualType.types(fieldIdx)).defaultValue) @@ -84,7 +91,12 @@ class SUnreachableStructValue(override val st: SUnreachableStruct) extends SUnre 
new SUnreachableStructValue(SUnreachableStruct(newType)) } - override def insert(cb: EmitCodeBuilder, region: Value[Region], newType: TStruct, fields: (String, EmitValue)*): SBaseStructValue = + override def insert( + cb: EmitCodeBuilder, + region: Value[Region], + newType: TStruct, + fields: (String, EmitValue)* + ): SBaseStructValue = new SUnreachableStructValue(SUnreachableStruct(newType)) override def _insert(newType: TStruct, fields: (String, EmitValue)*): SBaseStructValue = @@ -112,7 +124,8 @@ case object SUnreachableString extends SUnreachable with SString { override val sv = new SUnreachableStringValue - override def constructFromString(cb: EmitCodeBuilder, r: Value[Region], s: Code[String]): SStringValue = sv + override def constructFromString(cb: EmitCodeBuilder, r: Value[Region], s: Code[String]) + : SStringValue = sv } class SUnreachableStringValue extends SUnreachableValue with SStringValue { @@ -133,14 +146,18 @@ case class SUnreachableLocus(virtualType: TLocus) extends SUnreachable with SLoc override def rg: String = virtualType.rg } -class SUnreachableLocusValue(override val st: SUnreachableLocus) extends SUnreachableValue with SLocusValue { +class SUnreachableLocusValue(override val st: SUnreachableLocus) + extends SUnreachableValue with SLocusValue { override def position(cb: EmitCodeBuilder): Value[Int] = const(0) override def contig(cb: EmitCodeBuilder): SStringValue = SUnreachableString.sv override def contigLong(cb: EmitCodeBuilder): Value[Long] = const(0) - override def structRepr(cb: EmitCodeBuilder): SBaseStructValue = SUnreachableStruct(TStruct("contig" -> TString, "position" -> TInt32)).defaultValue.asInstanceOf[SUnreachableStructValue] + override def structRepr(cb: EmitCodeBuilder): SBaseStructValue = SUnreachableStruct(TStruct( + "contig" -> TString, + "position" -> TInt32, + )).defaultValue.asInstanceOf[SUnreachableStructValue] } case object SUnreachableCall extends SUnreachable with SCall { @@ -164,10 +181,10 @@ class SUnreachableCallValue extends SUnreachableValue with SCallValue { override def st: SUnreachableCall.type = SUnreachableCall - override def lgtToGT(cb: EmitCodeBuilder, localAlleles: SIndexableValue, errorID: Value[Int]): SCallValue = st.sv + override def lgtToGT(cb: EmitCodeBuilder, localAlleles: SIndexableValue, errorID: Value[Int]) + : SCallValue = st.sv } - case class SUnreachableInterval(virtualType: TInterval) extends SUnreachable with SInterval { override val sv = new SUnreachableIntervalValue(this) @@ -176,23 +193,25 @@ case class SUnreachableInterval(virtualType: TInterval) extends SUnreachable wit override def pointEmitType: EmitType = EmitType(pointType, true) } -class SUnreachableIntervalValue(override val st: SUnreachableInterval) extends SUnreachableValue with SIntervalValue { +class SUnreachableIntervalValue(override val st: SUnreachableInterval) + extends SUnreachableValue with SIntervalValue { override def includesStart: Value[Boolean] = const(false) override def includesEnd: Value[Boolean] = const(false) - override def loadStart(cb: EmitCodeBuilder): IEmitCode = IEmitCode.present(cb, SUnreachable.fromVirtualType(st.virtualType.pointType).defaultValue) + override def loadStart(cb: EmitCodeBuilder): IEmitCode = + IEmitCode.present(cb, SUnreachable.fromVirtualType(st.virtualType.pointType).defaultValue) override def startDefined(cb: EmitCodeBuilder): Value[Boolean] = const(false) - override def loadEnd(cb: EmitCodeBuilder): IEmitCode = IEmitCode.present(cb, SUnreachable.fromVirtualType(st.virtualType.pointType).defaultValue) + 
override def loadEnd(cb: EmitCodeBuilder): IEmitCode = + IEmitCode.present(cb, SUnreachable.fromVirtualType(st.virtualType.pointType).defaultValue) override def endDefined(cb: EmitCodeBuilder): Value[Boolean] = const(false) override def isEmpty(cb: EmitCodeBuilder): Value[Boolean] = const(false) } - case class SUnreachableNDArray(virtualType: TNDArray) extends SUnreachable with SNDArray { override val sv = new SUnreachableNDArrayValue(this) @@ -207,31 +226,51 @@ case class SUnreachableNDArray(virtualType: TNDArray) extends SUnreachable with override def elementByteSize: Long = 0L } -class SUnreachableNDArrayValue(override val st: SUnreachableNDArray) extends SUnreachableValue with SNDArraySettable { +class SUnreachableNDArrayValue(override val st: SUnreachableNDArray) + extends SUnreachableValue with SNDArraySettable { val pt = st.pType - override def loadElement(indices: IndexedSeq[Value[Long]], cb: EmitCodeBuilder): SValue = SUnreachable.fromVirtualType(st.virtualType.elementType).defaultValue + override def loadElement(indices: IndexedSeq[Value[Long]], cb: EmitCodeBuilder): SValue = + SUnreachable.fromVirtualType(st.virtualType.elementType).defaultValue - override def loadElementAddress(indices: IndexedSeq[is.hail.asm4s.Value[Long]],cb: is.hail.expr.ir.EmitCodeBuilder): is.hail.asm4s.Code[Long] = const(0L) + override def loadElementAddress( + indices: IndexedSeq[is.hail.asm4s.Value[Long]], + cb: is.hail.expr.ir.EmitCodeBuilder, + ): is.hail.asm4s.Code[Long] = const(0L) override def shapes: IndexedSeq[SizeValue] = (0 until st.nDims).map(_ => SizeValueStatic(0L)) - override def shapeStruct(cb: EmitCodeBuilder): SBaseStructValue = SUnreachableStruct(TTuple((0 until st.nDims).map(_ => TInt64): _*)).sv + override def shapeStruct(cb: EmitCodeBuilder): SBaseStructValue = + SUnreachableStruct(TTuple((0 until st.nDims).map(_ => TInt64): _*)).sv override def strides: IndexedSeq[Value[Long]] = (0 until st.nDims).map(_ => const(0L)) - override def outOfBounds(indices: IndexedSeq[Value[Long]], cb: EmitCodeBuilder): Code[Boolean] = const(false) + override def outOfBounds(indices: IndexedSeq[Value[Long]], cb: EmitCodeBuilder): Code[Boolean] = + const(false) - override def assertInBounds(indices: IndexedSeq[Value[Long]], cb: EmitCodeBuilder, errorId: Int = -1): Unit = {} + override def assertInBounds( + indices: IndexedSeq[Value[Long]], + cb: EmitCodeBuilder, + errorId: Int = -1, + ): Unit = {} override def sameShape(cb: EmitCodeBuilder, other: SNDArrayValue): Code[Boolean] = const(false) - override def coerceToShape(cb: EmitCodeBuilder, otherShape: IndexedSeq[SizeValue]): SNDArrayValue = this + override def coerceToShape(cb: EmitCodeBuilder, otherShape: IndexedSeq[SizeValue]) + : SNDArrayValue = this override def firstDataAddress: Value[Long] = const(0L) - override def coiterateMutate(cb: EmitCodeBuilder, region: Value[Region], deepCopy: Boolean, indexVars: IndexedSeq[String], - destIndices: IndexedSeq[Int], arrays: (SNDArrayValue, IndexedSeq[Int], String)*)(body: IndexedSeq[SValue] => SValue): Unit = () + override def coiterateMutate( + cb: EmitCodeBuilder, + region: Value[Region], + deepCopy: Boolean, + indexVars: IndexedSeq[String], + destIndices: IndexedSeq[Int], + arrays: (SNDArrayValue, IndexedSeq[Int], String)* + )( + body: IndexedSeq[SValue] => SValue + ): Unit = () } case class SUnreachableContainer(virtualType: TContainer) extends SUnreachable with SContainer { @@ -242,12 +281,14 @@ case class SUnreachableContainer(virtualType: TContainer) extends SUnreachable w lazy val elementEmitType: 
EmitType = EmitType(elementType, true) } -class SUnreachableContainerValue(override val st: SUnreachableContainer) extends SUnreachableValue with SIndexableValue { +class SUnreachableContainerValue(override val st: SUnreachableContainer) + extends SUnreachableValue with SIndexableValue { override def loadLength(): Value[Int] = const(0) override def isElementMissing(cb: EmitCodeBuilder, i: Code[Int]): Value[Boolean] = const(false) - override def loadElement(cb: EmitCodeBuilder, i: Code[Int]): IEmitCode = IEmitCode.present(cb, SUnreachable.fromVirtualType(st.virtualType.elementType).defaultValue) + override def loadElement(cb: EmitCodeBuilder, i: Code[Int]): IEmitCode = + IEmitCode.present(cb, SUnreachable.fromVirtualType(st.virtualType.elementType).defaultValue) override def hasMissingValues(cb: EmitCodeBuilder): Value[Boolean] = const(false) diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SBaseStruct.scala b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SBaseStruct.scala index 94d35a67f26..9f3cc00b0bd 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SBaseStruct.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SBaseStruct.scala @@ -3,12 +3,12 @@ package is.hail.types.physical.stypes.interfaces import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.{EmitCode, EmitCodeBuilder, EmitValue, IEmitCode} +import is.hail.types.{RStruct, RTuple, TypeWithRequiredness} import is.hail.types.physical.PCanonicalStruct import is.hail.types.physical.stypes._ import is.hail.types.physical.stypes.concrete._ import is.hail.types.physical.stypes.primitives.{SInt32Value, SInt64Value} import is.hail.types.virtual.{TBaseStruct, TStruct, TTuple} -import is.hail.types.{RStruct, RTuple, TypeWithRequiredness} import is.hail.utils._ object SBaseStruct { @@ -17,7 +17,6 @@ object SBaseStruct { val rt = s2.st.virtualType.asInstanceOf[TStruct] val resultVType = TStruct.concat(lt, rt) - val st1 = s1.st val st2 = s2.st (s1, s2) match { @@ -26,7 +25,8 @@ object SBaseStruct { case (s1: SStackStructValue, s2) => s2._insert(resultVType, lt.fieldNames.zip(s1.values): _*) case _ => - val newVals = (0 until st2.size).map(i => cb.memoize(s2.loadField(cb, i), "InsertFieldsStruct_merge")) + val newVals = + (0 until st2.size).map(i => cb.memoize(s2.loadField(cb, i), "InsertFieldsStruct_merge")) s1._insert(resultVType, rt.fieldNames.zip(newVals): _*) } } @@ -37,19 +37,19 @@ trait SBaseStruct extends SType { def size: Int - val fieldTypes: IndexedSeq[SType] - val fieldEmitTypes: IndexedSeq[EmitType] + def fieldTypes: IndexedSeq[SType] + def fieldEmitTypes: IndexedSeq[EmitType] def fieldIdx(fieldName: String): Int def _typeWithRequiredness: TypeWithRequiredness = { virtualType match { case ts: TStruct => RStruct.fromNamesAndTypes(ts.fieldNames.zip(fieldEmitTypes).map { - case (name, et) => (name, et.typeWithRequiredness.r) - }) + case (name, et) => (name, et.typeWithRequiredness.r) + }) case tt: TTuple => RTuple.fromNamesAndTypes(tt._types.zip(fieldEmitTypes).map { - case (f, et) => (f.index.toString, et.typeWithRequiredness.r) - }) + case (f, et) => (f.index.toString, et.typeWithRequiredness.r) + }) } } } @@ -66,62 +66,80 @@ trait SBaseStructValue extends SValue { def loadField(cb: EmitCodeBuilder, fieldIdx: Int): IEmitCode - def loadField(cb: EmitCodeBuilder, fieldName: String): IEmitCode = loadField(cb, st.fieldIdx(fieldName)) + def loadField(cb: EmitCodeBuilder, fieldName: String): IEmitCode = + loadField(cb, 
st.fieldIdx(fieldName)) - def subset(fieldNames: String*): SBaseStructValue = { - val st = SSubsetStruct(this.st, fieldNames.toIndexedSeq) - new SSubsetStructValue(st, this) - } + def subset(fieldNames: String*): SBaseStructValue = + new SStructViewValue(SStructView.subset(fieldNames.toIndexedSeq, st), this) override def hash(cb: EmitCodeBuilder): SInt32Value = { val hash_result = cb.newLocal[Int]("hash_result_struct", 1) - (0 until st.size).foreach(i => { - loadField(cb, i).consume(cb, { cb.assign(hash_result, hash_result * 31) }, - {field => cb.assign(hash_result, (hash_result * 31) + field.hash(cb).value)}) - }) + (0 until st.size).foreach { i => + loadField(cb, i).consume( + cb, + cb.assign(hash_result, hash_result * 31), + field => cb.assign(hash_result, (hash_result * 31) + field.hash(cb).value), + ) + } new SInt32Value(hash_result) } override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = { - // Size in bytes of the struct that must represent this thing, plus recursive call on any non-missing children. + /* Size in bytes of the struct that must represent this thing, plus recursive call on any + * non-missing children. */ val pStructSize = this.st.storageType().byteSize val sizeSoFar = cb.newLocal[Long]("sstackstruct_size_in_bytes", pStructSize) (0 until st.size).foreach { idx => if (this.st.fieldTypes(idx).containsPointers) { - val sizeAtThisIdx: Value[Long] = this.loadField(cb, idx).consumeCode(cb, { - const(0L) - }, { sv => - sv.sizeToStoreInBytes(cb).value - }) + val sizeAtThisIdx: Value[Long] = + this.loadField(cb, idx).consumeCode(cb, const(0L), sv => sv.sizeToStoreInBytes(cb).value) cb.assign(sizeSoFar, sizeSoFar + sizeAtThisIdx) } } new SInt64Value(sizeSoFar) } - def toStackStruct(cb: EmitCodeBuilder): SStackStructValue = { + def toStackStruct(cb: EmitCodeBuilder): SStackStructValue = new SStackStructValue( SStackStruct(st.virtualType, st.fieldEmitTypes), - Array.tabulate(st.size)( i => cb.memoize(loadField(cb, i)))) - } + Array.tabulate(st.size)(i => cb.memoize(loadField(cb, i))), + ) - def _insert(newType: TStruct, fields: (String, EmitValue)*): SBaseStructValue = { + def _insert(newType: TStruct, fields: (String, EmitValue)*): SBaseStructValue = new SInsertFieldsStructValue( - SInsertFieldsStruct(newType, st, fields.map { case (name, ec) => (name, ec.emitType) }.toFastSeq), + SInsertFieldsStruct( + newType, + st, + fields.map { case (name, ec) => (name, ec.emitType) }.toFastSeq, + ), this, - fields.map(_._2).toFastSeq + fields.map(_._2).toFastSeq, ) - } - def insert(cb: EmitCodeBuilder, region: Value[Region], newType: TStruct, fields: (String, EmitValue)*): SBaseStructValue = { - if (st.settableTupleTypes().length + fields.map(_._2.emitType.settableTupleTypes.length).sum < 64) + def insert( + cb: EmitCodeBuilder, + region: Value[Region], + newType: TStruct, + fields: (String, EmitValue)* + ): SBaseStructValue = { + if ( + st.settableTupleTypes().length + fields.map(_._2.emitType.settableTupleTypes.length).sum < 64 + ) return _insert(newType, fields: _*) val newFieldMap = fields.toMap val allFields = newType.fieldNames.map { f => - (f, newFieldMap.getOrElse(f, cb.memoize(EmitCode.fromI(cb.emb)(cb => loadField(cb, f)), "insert"))) } + ( + f, + newFieldMap.getOrElse( + f, + cb.memoize(EmitCode.fromI(cb.emb)(cb => loadField(cb, f)), "insert"), + ), + ) + } - val pcs = PCanonicalStruct(false, allFields.map { case (f, ec) => (f, ec.emitType.storageType) }: _*) + val pcs = + PCanonicalStruct(false, allFields.map { case (f, ec) => (f, ec.emitType.storageType) }: _*) 
pcs.constructFromFields(cb, region, allFields.map(_._2.load), false) } } diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SBinary.scala b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SBinary.scala index 68cda63552e..9950bafbefe 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SBinary.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SBinary.scala @@ -1,12 +1,12 @@ package is.hail.types.physical.stypes.interfaces -import is.hail.asm4s.Code.invokeStatic1 import is.hail.asm4s._ +import is.hail.asm4s.Code.invokeStatic1 import is.hail.expr.ir.EmitCodeBuilder +import is.hail.types.{RPrimitive, TypeWithRequiredness} import is.hail.types.physical.PCanonicalBinary -import is.hail.types.physical.stypes.primitives.{SInt32Value, SInt64Value} import is.hail.types.physical.stypes.{SType, SValue} -import is.hail.types.{RPrimitive, TypeWithRequiredness} +import is.hail.types.physical.stypes.primitives.{SInt32Value, SInt64Value} trait SBinary extends SType { override def _typeWithRequiredness: TypeWithRequiredness = RPrimitive() @@ -20,7 +20,10 @@ trait SBinaryValue extends SValue { def loadByte(cb: EmitCodeBuilder, i: Code[Int]): Value[Byte] override def hash(cb: EmitCodeBuilder): SInt32Value = - new SInt32Value(cb.memoize(invokeStatic1[java.util.Arrays, Array[Byte], Int]("hashCode", loadBytes(cb)))) + new SInt32Value(cb.memoize(invokeStatic1[java.util.Arrays, Array[Byte], Int]( + "hashCode", + loadBytes(cb), + ))) override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = { val binaryStorageType = this.st.storageType().asInstanceOf[PCanonicalBinary] diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SCall.scala b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SCall.scala index 73ccf2f67be..ec978b0bc9a 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SCall.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SCall.scala @@ -1,10 +1,10 @@ package is.hail.types.physical.stypes.interfaces -import is.hail.asm4s.{Code, Value} +import is.hail.asm4s.Value import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.primitives.SInt32Value -import is.hail.types.physical.stypes.{SCode, SType, SValue} import is.hail.types.{RPrimitive, TypeWithRequiredness} +import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.primitives.SInt32Value trait SCall extends SType { override def _typeWithRequiredness: TypeWithRequiredness = RPrimitive() diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SContainer.scala b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SContainer.scala index 36a3041a5bf..a98a1f45bb9 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SContainer.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SContainer.scala @@ -3,15 +3,17 @@ package is.hail.types.physical.stypes.interfaces import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode} +import is.hail.types.{RIterable, TypeWithRequiredness} import is.hail.types.physical.{PCanonicalArray, PContainer} -import is.hail.types.physical.stypes.primitives.{SInt32Value, SInt64Value} import is.hail.types.physical.stypes.{EmitType, SType, SValue} -import is.hail.types.{RIterable, TypeWithRequiredness} +import is.hail.types.physical.stypes.primitives.{SInt32Value, SInt64Value} trait SContainer extends SType { 
def elementType: SType def elementEmitType: EmitType - override def _typeWithRequiredness: TypeWithRequiredness = RIterable(elementEmitType.typeWithRequiredness.r) + + override def _typeWithRequiredness: TypeWithRequiredness = + RIterable(elementEmitType.typeWithRequiredness.r) } trait SIndexableValue extends SValue { @@ -30,58 +32,81 @@ trait SIndexableValue extends SValue { def castToArray(cb: EmitCodeBuilder): SIndexableValue - def forEachDefined(cb: EmitCodeBuilder)(f: (EmitCodeBuilder, Value[Int], SValue) => Unit): Unit = { + def forEachDefined(cb: EmitCodeBuilder)(f: (EmitCodeBuilder, Value[Int], SValue) => Unit) + : Unit = { val length = loadLength() val idx = cb.newLocal[Int]("foreach_idx", 0) - cb.while_(idx < length, { - - loadElement(cb, idx).consume(cb, - {}, /*do nothing if missing*/ - { eltCode => - f(cb, idx, eltCode) - }) - cb.assign(idx, idx + 1) - }) + cb.while_( + idx < length, { + + loadElement(cb, idx).consume( + cb, + {}, /*do nothing if missing*/ + eltCode => f(cb, idx, eltCode), + ) + cb.assign(idx, idx + 1) + }, + ) } - def forEachDefinedOrMissing(cb: EmitCodeBuilder)(missingF: (EmitCodeBuilder, Value[Int]) => Unit, presentF: (EmitCodeBuilder, Value[Int], SValue) => Unit): Unit = { + def forEachDefinedOrMissing( + cb: EmitCodeBuilder + )( + missingF: (EmitCodeBuilder, Value[Int]) => Unit, + presentF: (EmitCodeBuilder, Value[Int], SValue) => Unit, + ): Unit = { val length = loadLength() val idx = cb.newLocal[Int]("foreach_idx", 0) - cb.while_(idx < length, { - - loadElement(cb, idx).consume(cb, - { /*do function if missing*/ - missingF(cb, idx) - }, - { eltCode => - presentF(cb, idx, eltCode) - }) - cb.assign(idx, idx + 1) - }) + cb.while_( + idx < length, { + + loadElement(cb, idx).consume( + cb, + /* do function if missing */ + missingF(cb, idx), + eltCode => presentF(cb, idx, eltCode), + ) + cb.assign(idx, idx + 1) + }, + ) } override def hash(cb: EmitCodeBuilder): SInt32Value = { val hash_result = cb.newLocal[Int]("array_hash", 1) - forEachDefinedOrMissing(cb)({ - case (cb, idx) => cb.assign(hash_result, hash_result * 31) - }, { - case (cb, idx, element) => - cb.assign(hash_result, hash_result * 31 + element.hash(cb).value) - }) + forEachDefinedOrMissing(cb)( + { + case (cb, _) => cb.assign(hash_result, hash_result * 31) + }, + { + case (cb, _, element) => + cb.assign(hash_result, hash_result * 31 + element.hash(cb).value) + }, + ) new SInt32Value(hash_result) } - def sliceArray(cb: EmitCodeBuilder, region: Value[Region], pt: PCanonicalArray, start: Code[Int], end: Code[Int], deepCopy: Boolean = false): SIndexableValue = { + def sliceArray( + cb: EmitCodeBuilder, + region: Value[Region], + pt: PCanonicalArray, + start: Code[Int], + end: Code[Int], + deepCopy: Boolean = false, + ): SIndexableValue = { val startMemo = cb.newLocal[Int]("sindexable_slice_array_start_memo", start) - pt.constructFromElements(cb, region, cb.newLocal[Int]("slice_length", end - startMemo), deepCopy){ (cb, idx) => - this.loadElement(cb, idx + startMemo) - } + pt.constructFromElements( + cb, + region, + cb.newLocal[Int]("slice_length", end - startMemo), + deepCopy, + )((cb, idx) => this.loadElement(cb, idx + startMemo)) } override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = { val storageType = this.st.storageType().asInstanceOf[PContainer] val length = this.loadLength() - val totalSize = cb.newLocal[Long]("sindexableptr_size_in_bytes", storageType.elementsOffset(length).toL) + val totalSize = + cb.newLocal[Long]("sindexableptr_size_in_bytes", 
storageType.elementsOffset(length).toL) if (this.st.elementType.containsPointers) { this.forEachDefined(cb) { (cb, _, element) => cb.assign(totalSize, totalSize + element.sizeToStoreInBytes(cb).value) diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SInterval.scala b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SInterval.scala index 6c6eb11d9a4..fb67fbe0827 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SInterval.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SInterval.scala @@ -1,16 +1,16 @@ package is.hail.types.physical.stypes.interfaces -import is.hail.asm4s.Value +import is.hail.asm4s._ import is.hail.expr.ir.{EmitCodeBuilder, IEmitCode} +import is.hail.expr.ir.orderings.CodeOrdering import is.hail.types.{RInterval, TypeWithRequiredness} -import is.hail.types.physical.stypes.primitives.SInt64Value import is.hail.types.physical.stypes.{EmitType, SType, SValue} -import is.hail.asm4s._ -import is.hail.expr.ir.orderings.CodeOrdering +import is.hail.types.physical.stypes.primitives.SInt64Value trait SInterval extends SType { def pointType: SType def pointEmitType: EmitType + override def _typeWithRequiredness: TypeWithRequiredness = { val pt = pointEmitType.typeWithRequiredness.r RInterval(pt, pt) @@ -36,13 +36,17 @@ trait SIntervalValue extends SValue { val pIntervalSize = this.st.storageType().byteSize val sizeSoFar = cb.newLocal[Long]("sstackstruct_size_in_bytes", pIntervalSize) - loadStart(cb).consume(cb, {}, {sv => - cb.assign(sizeSoFar, sizeSoFar + sv.sizeToStoreInBytes(cb).value) - }) + loadStart(cb).consume( + cb, + {}, + sv => cb.assign(sizeSoFar, sizeSoFar + sv.sizeToStoreInBytes(cb).value), + ) - loadEnd(cb).consume(cb, {}, {sv => - cb.assign(sizeSoFar, sizeSoFar + sv.sizeToStoreInBytes(cb).value) - }) + loadEnd(cb).consume( + cb, + {}, + sv => cb.assign(sizeSoFar, sizeSoFar + sv.sizeToStoreInBytes(cb).value), + ) new SInt64Value(sizeSoFar) } @@ -54,9 +58,11 @@ trait SIntervalValue extends SValue { val start = cb.memoize(loadStart(cb), "start") val end = cb.memoize(loadEnd(cb), "end") val empty = cb.newLocal[Boolean]("is_empty") - cb.if_(includesStart && includesEnd, + cb.if_( + includesStart && includesEnd, cb.assign(empty, gt(cb, start, end)), - cb.assign(empty, gteq(cb, start, end))) + cb.assign(empty, gteq(cb, start, end)), + ) empty } } diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SLocus.scala b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SLocus.scala index 7c995c28f88..0f673e69594 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SLocus.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SLocus.scala @@ -2,10 +2,10 @@ package is.hail.types.physical.stypes.interfaces import is.hail.asm4s.{Code, Value} import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.primitives.{SInt32Value, SInt64Value} -import is.hail.types.physical.stypes.{SCode, SType, SValue} import is.hail.types.{RPrimitive, TypeWithRequiredness} -import is.hail.variant.{Locus, ReferenceGenome} +import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.primitives.{SInt32Value, SInt64Value} +import is.hail.variant.Locus trait SLocus extends SType { def rg: String @@ -23,13 +23,17 @@ trait SLocusValue extends SValue { def position(cb: EmitCodeBuilder): Value[Int] def getLocusObj(cb: EmitCodeBuilder): Value[Locus] = - cb.memoize(Code.invokeStatic2[Locus, String, Int, 
Locus]("apply", - contig(cb).loadString(cb), position(cb))) + cb.memoize(Code.invokeStatic2[Locus, String, Int, Locus]( + "apply", + contig(cb).loadString(cb), + position(cb), + )) def structRepr(cb: EmitCodeBuilder): SBaseStructValue override def hash(cb: EmitCodeBuilder): SInt32Value = structRepr(cb).hash(cb) - override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = structRepr(cb).sizeToStoreInBytes(cb) + override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = + structRepr(cb).sizeToStoreInBytes(cb) } diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SNDArray.scala b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SNDArray.scala index 99d964e0fad..53791e5c651 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SNDArray.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SNDArray.scala @@ -4,35 +4,50 @@ import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder import is.hail.linalg.{BLAS, LAPACK} +import is.hail.types.{RNDArray, TypeWithRequiredness} +import is.hail.types.physical._ +import is.hail.types.physical.stypes.{EmitType, SSettable, SType, SValue} import is.hail.types.physical.stypes.concrete.{SNDArraySlice, SNDArraySliceValue} import is.hail.types.physical.stypes.primitives.SInt64Value -import is.hail.types.physical.stypes.{EmitType, SSettable, SType, SValue} -import is.hail.types.physical._ import is.hail.types.virtual.TInt32 -import is.hail.types.{RNDArray, TypeWithRequiredness} -import is.hail.utils.{FastSeq, toRichIterable, valueToRichCodeRegion} +import is.hail.utils.{toRichIterable, valueToRichCodeRegion, FastSeq} import scala.collection.mutable object SNDArray { - def numElements(shape: IndexedSeq[Value[Long]]): Code[Long] = { + def numElements(shape: IndexedSeq[Value[Long]]): Code[Long] = shape.foldLeft(1L: Code[Long])(_ * _) - } // Column major order - def forEachIndexColMajor(cb: EmitCodeBuilder, shape: IndexedSeq[Value[Long]], context: String) - (f: (EmitCodeBuilder, IndexedSeq[Value[Long]]) => Unit): Unit = { - forEachIndexWithInitAndIncColMajor(cb, shape, shape.map(_ => (cb: EmitCodeBuilder) => ()), shape.map(_ => (cb: EmitCodeBuilder) => ()), context)(f) - } + def forEachIndexColMajor( + cb: EmitCodeBuilder, + shape: IndexedSeq[Value[Long]], + context: String, + )( + f: (EmitCodeBuilder, IndexedSeq[Value[Long]]) => Unit + ): Unit = + forEachIndexWithInitAndIncColMajor( + cb, + shape, + shape.map(_ => (cb: EmitCodeBuilder) => ()), + shape.map(_ => (cb: EmitCodeBuilder) => ()), + context, + )(f) - def coiterate(cb: EmitCodeBuilder, arrays: (SNDArrayValue, String)*)(body: IndexedSeq[SValue] => Unit): Unit = { + def coiterate( + cb: EmitCodeBuilder, + arrays: (SNDArrayValue, String)* + )( + body: IndexedSeq[SValue] => Unit + ): Unit = { if (arrays.isEmpty) return val indexVars = Array.tabulate(arrays(0)._1.st.nDims)(i => s"i$i").toFastSeq val indices = Array.range(0, arrays(0)._1.st.nDims).toFastSeq coiterate(cb, indexVars, arrays.map { case (array, name) => (array, indices, name) }: _*)(body) } - // Note: to iterate through an array in column major order, make sure the indices are in ascending order. E.g. + /* Note: to iterate through an array in column major order, make sure the indices are in ascending + * order. E.g. 
*/ // A.coiterate(cb, IndexedSeq("i", "j"), (A, IndexedSeq(0, 1), "A"), (B, IndexedSeq(0, 1), "B"), { // SCode.add(cb, a, b) // }) @@ -41,7 +56,8 @@ object SNDArray { cb: EmitCodeBuilder, indexVars: IndexedSeq[String], arrays: (SNDArrayValue, IndexedSeq[Int], String)* - )(body: IndexedSeq[SValue] => Unit + )( + body: IndexedSeq[SValue] => Unit ): Unit = { _coiterate(cb, indexVars, arrays: _*) { ptrs => val codes = ptrs.zip(arrays).map { case (ptr, (array, _, _)) => @@ -56,22 +72,24 @@ object SNDArray { cb: EmitCodeBuilder, indexVars: IndexedSeq[String], arrays: (SNDArrayValue, IndexedSeq[Int], String)* - )(body: IndexedSeq[Value[Long]] => Unit + )( + body: IndexedSeq[Value[Long]] => Unit ): Unit = { val indexSizes = new Array[Settable[Int]](indexVars.length) - val indexCoords = Array.tabulate(indexVars.length) { i => cb.newLocal[Int](indexVars(i)) } + val indexCoords = Array.tabulate(indexVars.length)(i => cb.newLocal[Int](indexVars(i))) case class ArrayInfo( array: SNDArrayValue, strides: IndexedSeq[Value[Long]], pos: IndexedSeq[Settable[Long]], indexToDim: Map[Int, Int], - name: String) + name: String, + ) val info = arrays.toIndexedSeq.map { case (array, indices, name) => for (idx <- indices) assert(idx < indexVars.length && idx >= 0) // FIXME: relax this assumption to handle transposing, non-column major - for (i <- 0 until indices.length - 1) assert(indices(i) < indices(i+1)) + for (i <- 0 until indices.length - 1) assert(indices(i) < indices(i + 1)) assert(indices.length == array.st.nDims) val shape = array.shapes @@ -81,11 +99,14 @@ object SNDArray { indexSizes(idx) = cb.newLocal[Int](s"${indexVars(idx)}_max") cb.assign(indexSizes(idx), shape(i).toI) } else { - cb.if_(indexSizes(idx).cne(shape(i).toI), cb._fatal(s"${indexVars(idx)} indexes incompatible dimensions")) + cb.if_( + indexSizes(idx).cne(shape(i).toI), + cb._fatal(s"${indexVars(idx)} indexes incompatible dimensions"), + ) } } val strides = array.strides - val pos = Array.tabulate(array.st.nDims + 1) { i => cb.newLocal[Long](s"$name$i") } + val pos = Array.tabulate(array.st.nDims + 1)(i => cb.newLocal[Long](s"$name$i")) val indexToDim = indices.zipWithIndex.toMap ArrayInfo(array, strides, pos, indexToDim, name) } @@ -98,60 +119,67 @@ object SNDArray { val coord = indexCoords(idx) def init(): Unit = { cb.assign(coord, 0) - for (n <- arrays.indices) { + for (n <- arrays.indices) if (info(n).indexToDim.contains(idx)) { val i = info(n).indexToDim(idx) // FIXME: assumes array's indices in ascending order - cb.assign(info(n).pos(i), info(n).pos(i+1)) + cb.assign(info(n).pos(i), info(n).pos(i + 1)) } - } } def increment(): Unit = { cb.assign(coord, coord + 1) - for (n <- arrays.indices) { + for (n <- arrays.indices) if (info(n).indexToDim.contains(idx)) { val i = info(n).indexToDim(idx) cb.assign(info(n).pos(i), info(n).pos(i) + info(n).strides(i)) } - } } cb.for_(init(), coord < indexSizes(idx), increment(), recurLoopBuilder(idx - 1)) } } - for (n <- arrays.indices) { + for (n <- arrays.indices) cb.assign(info(n).pos(info(n).array.st.nDims), info(n).array.firstDataAddress) - } recurLoopBuilder(indexVars.length - 1) } // Column major order - def forEachIndexWithInitAndIncColMajor(cb: EmitCodeBuilder, shape: IndexedSeq[Value[Long]], inits: IndexedSeq[EmitCodeBuilder => Unit], - incrementers: IndexedSeq[EmitCodeBuilder => Unit], context: String) - (f: (EmitCodeBuilder, IndexedSeq[Value[Long]]) => Unit): Unit = { + def forEachIndexWithInitAndIncColMajor( + cb: EmitCodeBuilder, + shape: IndexedSeq[Value[Long]], + inits: 
IndexedSeq[EmitCodeBuilder => Unit], + incrementers: IndexedSeq[EmitCodeBuilder => Unit], + context: String, + )( + f: (EmitCodeBuilder, IndexedSeq[Value[Long]]) => Unit + ): Unit = { - val indices = Array.tabulate(shape.length) { dimIdx => cb.newLocal[Long](s"${ context }_foreach_dim_$dimIdx", 0L) } + val indices = Array.tabulate(shape.length) { dimIdx => + cb.newLocal[Long](s"${context}_foreach_dim_$dimIdx", 0L) + } def recurLoopBuilder(dimIdx: Int, innerLambda: () => Unit): Unit = { if (dimIdx == shape.length) { innerLambda() - } - else { + } else { val dimVar = indices(dimIdx) - recurLoopBuilder(dimIdx + 1, + recurLoopBuilder( + dimIdx + 1, () => { - cb.for_({ - inits(dimIdx)(cb) - cb.assign(dimVar, 0L) - }, dimVar < shape(dimIdx), { - incrementers(dimIdx)(cb) - cb.assign(dimVar, dimVar + 1L) - }, - innerLambda() + cb.for_( + { + inits(dimIdx)(cb) + cb.assign(dimVar, 0L) + }, + dimVar < shape(dimIdx), { + incrementers(dimIdx)(cb) + cb.assign(dimVar, dimVar + 1L) + }, + innerLambda(), ) - } + }, ) } } @@ -162,37 +190,57 @@ object SNDArray { } // Row major order - def forEachIndexRowMajor(cb: EmitCodeBuilder, shape: IndexedSeq[Value[Long]], context: String) - (f: (EmitCodeBuilder, IndexedSeq[Value[Long]]) => Unit): Unit = { - forEachIndexWithInitAndIncRowMajor(cb, shape, shape.map(_ => (cb: EmitCodeBuilder) => ()), shape.map(_ => (cb: EmitCodeBuilder) => ()), context)(f) - } + def forEachIndexRowMajor( + cb: EmitCodeBuilder, + shape: IndexedSeq[Value[Long]], + context: String, + )( + f: (EmitCodeBuilder, IndexedSeq[Value[Long]]) => Unit + ): Unit = + forEachIndexWithInitAndIncRowMajor( + cb, + shape, + shape.map(_ => (cb: EmitCodeBuilder) => ()), + shape.map(_ => (cb: EmitCodeBuilder) => ()), + context, + )(f) // Row major order - def forEachIndexWithInitAndIncRowMajor(cb: EmitCodeBuilder, shape: IndexedSeq[Value[Long]], inits: IndexedSeq[EmitCodeBuilder => Unit], - incrementers: IndexedSeq[EmitCodeBuilder => Unit], context: String) - (f: (EmitCodeBuilder, IndexedSeq[Value[Long]]) => Unit): Unit = { + def forEachIndexWithInitAndIncRowMajor( + cb: EmitCodeBuilder, + shape: IndexedSeq[Value[Long]], + inits: IndexedSeq[EmitCodeBuilder => Unit], + incrementers: IndexedSeq[EmitCodeBuilder => Unit], + context: String, + )( + f: (EmitCodeBuilder, IndexedSeq[Value[Long]]) => Unit + ): Unit = { - val indices = Array.tabulate(shape.length) { dimIdx => cb.newLocal[Long](s"${ context }_foreach_dim_$dimIdx", 0L) } + val indices = Array.tabulate(shape.length) { dimIdx => + cb.newLocal[Long](s"${context}_foreach_dim_$dimIdx", 0L) + } def recurLoopBuilder(dimIdx: Int, innerLambda: () => Unit): Unit = { if (dimIdx == -1) { innerLambda() - } - else { + } else { val dimVar = indices(dimIdx) - recurLoopBuilder(dimIdx - 1, + recurLoopBuilder( + dimIdx - 1, () => { - cb.for_({ - inits(dimIdx)(cb) - cb.assign(dimVar, 0L) - }, dimVar < shape(dimIdx), { - incrementers(dimIdx)(cb) - cb.assign(dimVar, dimVar + 1L) - }, - innerLambda() + cb.for_( + { + inits(dimIdx)(cb) + cb.assign(dimVar, 0L) + }, + dimVar < shape(dimIdx), { + incrementers(dimIdx)(cb) + cb.assign(dimVar, dimVar + 1L) + }, + innerLambda(), ) - } + }, ) } } @@ -203,24 +251,22 @@ object SNDArray { } // Column major order - def unstagedForEachIndex(shape: IndexedSeq[Long]) - (f: IndexedSeq[Long] => Unit): Unit = { + def unstagedForEachIndex(shape: IndexedSeq[Long])(f: IndexedSeq[Long] => Unit): Unit = { - val indices = Array.tabulate(shape.length) {dimIdx => 0L} + val indices = Array.tabulate(shape.length)(dimIdx => 0L) def recurLoopBuilder(dimIdx: 
Int, innerLambda: () => Unit): Unit = { if (dimIdx == shape.length) { innerLambda() - } - else { + } else { - recurLoopBuilder(dimIdx + 1, - () => { - (0 until shape(dimIdx).toInt).foreach(_ => { + recurLoopBuilder( + dimIdx + 1, + () => + (0 until shape(dimIdx).toInt).foreach { _ => innerLambda() indices(dimIdx) += 1 - }) - } + }, ) } } @@ -230,31 +276,46 @@ object SNDArray { recurLoopBuilder(0, body) } - def assertMatrix(nds: SNDArrayValue*): Unit = { + def assertMatrix(nds: SNDArrayValue*): Unit = for (nd <- nds) assert(nd.st.nDims == 2) - } - def assertVector(nds: SNDArrayValue*): Unit = { + def assertVector(nds: SNDArrayValue*): Unit = for (nd <- nds) assert(nd.st.nDims == 1) - } - def assertColMajor(cb: EmitCodeBuilder, caller: String, nds: SNDArrayValue*): Unit = { - for (nd <- nds) { - cb.if_(nd.strides(0).cne(nd.st.pType.elementType.byteSize), - cb._fatal(s"$caller requires column major: found row stride ", nd.strides(0).toS, ", expected ", nd.st.pType.elementType.byteSize.toString)) - } - } + def assertColMajor(cb: EmitCodeBuilder, caller: String, nds: SNDArrayValue*): Unit = + for (nd <- nds) + cb.if_( + nd.strides(0).cne(nd.st.pType.elementType.byteSize), + cb._fatal( + s"$caller requires column major: found row stride ", + nd.strides(0).toS, + ", expected ", + nd.st.pType.elementType.byteSize.toString, + ), + ) def copyVector(cb: EmitCodeBuilder, X: SNDArrayValue, Y: SNDArrayValue): Unit = { val Seq(n) = X.shapes - Y.assertHasShape(cb, FastSeq(n), "copy: vectors have different sizes: ", Y.shapes(0).toS, ", ", n.toS) + Y.assertHasShape( + cb, + FastSeq(n), + "copy: vectors have different sizes: ", + Y.shapes(0).toS, + ", ", + n.toS, + ) val ldX = X.eltStride(0).max(1) val ldY = Y.eltStride(0).max(1) - cb += Code.invokeScalaObject5[Int, Long, Int, Long, Int, Unit](BLAS.getClass, "dcopy", + cb += Code.invokeScalaObject5[Int, Long, Int, Long, Int, Unit]( + BLAS.getClass, + "dcopy", n.toI, - X.firstDataAddress, ldX, - Y.firstDataAddress, ldY) + X.firstDataAddress, + ldX, + Y.firstDataAddress, + ldY, + ) } def copyMatrix(cb: EmitCodeBuilder, uplo: String, X: SNDArrayValue, Y: SNDArrayValue): Unit = { @@ -262,10 +323,17 @@ object SNDArray { Y.assertHasShape(cb, FastSeq(m, n), "copyMatrix: matrices have different shapes") val ldX = X.eltStride(1).max(1) val ldY = Y.eltStride(1).max(1) - cb += Code.invokeScalaObject7[String, Int, Int, Long, Int, Long, Int, Unit](LAPACK.getClass, "dlacpy", - uplo, m.toI, n.toI, - X.firstDataAddress, ldX, - Y.firstDataAddress, ldY) + cb += Code.invokeScalaObject7[String, Int, Int, Long, Int, Long, Int, Unit]( + LAPACK.getClass, + "dlacpy", + uplo, + m.toI, + n.toI, + X.firstDataAddress, + ldX, + Y.firstDataAddress, + ldY, + ) } def scale(cb: EmitCodeBuilder, alpha: SValue, X: SNDArrayValue): Unit = @@ -274,15 +342,29 @@ object SNDArray { def scale(cb: EmitCodeBuilder, alpha: Value[Double], X: SNDArrayValue): Unit = { val Seq(n) = X.shapes val ldX = X.eltStride(0).max(1) - cb += Code.invokeScalaObject4[Int, Double, Long, Int, Unit](BLAS.getClass, "dscal", - n.toI, alpha, X.firstDataAddress, ldX) + cb += Code.invokeScalaObject4[Int, Double, Long, Int, Unit]( + BLAS.getClass, + "dscal", + n.toI, + alpha, + X.firstDataAddress, + ldX, + ) } - def gemv(cb: EmitCodeBuilder, trans: String, A: SNDArrayValue, X: SNDArrayValue, Y: SNDArrayValue): Unit = { + def gemv(cb: EmitCodeBuilder, trans: String, A: SNDArrayValue, X: SNDArrayValue, Y: SNDArrayValue) + : Unit = gemv(cb, trans, 1.0, A, X, 1.0, Y) - } - def gemv(cb: EmitCodeBuilder, trans: String, alpha: 
Value[Double], A: SNDArrayValue, X: SNDArrayValue, beta: Value[Double], Y: SNDArrayValue): Unit = { + def gemv( + cb: EmitCodeBuilder, + trans: String, + alpha: Value[Double], + A: SNDArrayValue, + X: SNDArrayValue, + beta: Value[Double], + Y: SNDArrayValue, + ): Unit = { assertMatrix(A) val Seq(m, n) = A.shapes val errMsg = "gemv: incompatible dimensions" @@ -298,19 +380,56 @@ object SNDArray { val ldA = A.eltStride(1).max(1) val ldX = X.eltStride(0).max(1) val ldY = Y.eltStride(0).max(1) - cb += Code.invokeScalaObject11[String, Int, Int, Double, Long, Int, Long, Int, Double, Long, Int, Unit](BLAS.getClass, "dgemv", - trans, m.toI, n.toI, + cb += Code.invokeScalaObject11[ + String, + Int, + Int, + Double, + Long, + Int, + Long, + Int, + Double, + Long, + Int, + Unit, + ]( + BLAS.getClass, + "dgemv", + trans, + m.toI, + n.toI, alpha, - A.firstDataAddress, ldA, - X.firstDataAddress, ldX, + A.firstDataAddress, + ldA, + X.firstDataAddress, + ldX, beta, - Y.firstDataAddress, ldY) + Y.firstDataAddress, + ldY, + ) } - def gemm(cb: EmitCodeBuilder, tA: String, tB: String, A: SNDArrayValue, B: SNDArrayValue, C: SNDArrayValue): Unit = + def gemm( + cb: EmitCodeBuilder, + tA: String, + tB: String, + A: SNDArrayValue, + B: SNDArrayValue, + C: SNDArrayValue, + ): Unit = gemm(cb, tA, tB, 1.0, A, B, 0.0, C) - def gemm(cb: EmitCodeBuilder, tA: String, tB: String, alpha: Value[Double], A: SNDArrayValue, B: SNDArrayValue, beta: Value[Double], C: SNDArrayValue): Unit = { + def gemm( + cb: EmitCodeBuilder, + tA: String, + tB: String, + alpha: Value[Double], + A: SNDArrayValue, + B: SNDArrayValue, + beta: Value[Double], + C: SNDArrayValue, + ): Unit = { assertMatrix(A, B, C) val Seq(m, n) = C.shapes val k = if (tA == "N") A.shapes(1) else A.shapes(0) @@ -329,17 +448,50 @@ object SNDArray { val ldA = A.eltStride(1).max(1) val ldB = B.eltStride(1).max(1) val ldC = C.eltStride(1).max(1) - cb += Code.invokeScalaObject13[String, String, Int, Int, Int, Double, Long, Int, Long, Int, Double, Long, Int, Unit](BLAS.getClass, "dgemm", - tA, tB, m.toI, n.toI, k.toI, + cb += Code.invokeScalaObject13[ + String, + String, + Int, + Int, + Int, + Double, + Long, + Int, + Long, + Int, + Double, + Long, + Int, + Unit, + ]( + BLAS.getClass, + "dgemm", + tA, + tB, + m.toI, + n.toI, + k.toI, alpha, - A.firstDataAddress, ldA, - B.firstDataAddress, ldB, + A.firstDataAddress, + ldA, + B.firstDataAddress, + ldB, beta, - C.firstDataAddress, ldC) + C.firstDataAddress, + ldC, + ) } - def trmm(cb: EmitCodeBuilder, side: String, uplo: String, transA: String, diag: String, - alpha: Value[Double], A: SNDArrayValue, B: SNDArrayValue): Unit = { + def trmm( + cb: EmitCodeBuilder, + side: String, + uplo: String, + transA: String, + diag: String, + alpha: Value[Double], + A: SNDArrayValue, + B: SNDArrayValue, + ): Unit = { assertMatrix(A, B) assertColMajor(cb, "trmm", A, B) @@ -347,19 +499,48 @@ object SNDArray { val Seq(a0, a1) = A.shapes cb.if_(a1.cne(if (side == "left") m else n), cb._fatal("trmm: incompatible matrix dimensions")) // Elide check in the common case that we statically know A is square - if (a0 != a1) cb.if_(a0 < a1, cb._fatal("trmm: A has fewer rows than cols: ", a0.toS, ", ", a1.toS)) + if (a0 != a1) + cb.if_(a0 < a1, cb._fatal("trmm: A has fewer rows than cols: ", a0.toS, ", ", a1.toS)) val ldA = A.eltStride(1).max(1) val ldB = B.eltStride(1).max(1) - cb += Code.invokeScalaObject11[String, String, String, String, Int, Int, Double, Long, Int, Long, Int, Unit](BLAS.getClass, "dtrmm", - side, uplo, transA, diag, - m.toI, n.toI, + 
cb += Code.invokeScalaObject11[ + String, + String, + String, + String, + Int, + Int, + Double, + Long, + Int, + Long, + Int, + Unit, + ]( + BLAS.getClass, + "dtrmm", + side, + uplo, + transA, + diag, + m.toI, + n.toI, alpha, - A.firstDataAddress, ldA, - B.firstDataAddress, ldB) + A.firstDataAddress, + ldA, + B.firstDataAddress, + ldB, + ) } - def geqrt(A: SNDArrayValue, T: SNDArrayValue, work: SNDArrayValue, blocksize: Value[Long], cb: EmitCodeBuilder): Unit = { + def geqrt( + A: SNDArrayValue, + T: SNDArrayValue, + work: SNDArrayValue, + blocksize: Value[Long], + cb: EmitCodeBuilder, + ): Unit = { if (A.st.nDims == 2) assertColMajor(cb, "geqrt", A) else assertVector(A) assertVector(work, T) @@ -367,20 +548,39 @@ object SNDArray { val nb = blocksize val min = cb.memoize(m.min(n)) cb.if_((nb > min && min > 0) || nb < 1, cb._fatal("geqrt: invalid block size: ", nb.toS)) - cb.if_(T.shapes(0) < nb*(m.min(n)), cb._fatal("geqrt: T too small")) + cb.if_(T.shapes(0) < nb * (m.min(n)), cb._fatal("geqrt: T too small")) cb.if_(work.shapes(0) < nb * n, cb._fatal("geqrt: work array too small")) val error = cb.mb.newLocal[Int]() val ldA = if (A.st.nDims == 2) A.eltStride(1).max(1) else m.toI - cb.assign(error, Code.invokeScalaObject8[Int, Int, Int, Long, Int, Long, Int, Long, Int](LAPACK.getClass, "dgeqrt", - m.toI, n.toI, nb.toI, - A.firstDataAddress, ldA, - T.firstDataAddress, nb.toI.max(1), - work.firstDataAddress)) + cb.assign( + error, + Code.invokeScalaObject8[Int, Int, Int, Long, Int, Long, Int, Long, Int]( + LAPACK.getClass, + "dgeqrt", + m.toI, + n.toI, + nb.toI, + A.firstDataAddress, + ldA, + T.firstDataAddress, + nb.toI.max(1), + work.firstDataAddress, + ), + ) cb.if_(error.cne(0), cb._fatal("LAPACK error dtpqrt. Error code = ", error.toS)) } - def gemqrt(side: String, trans: String, V: SNDArrayValue, T: SNDArrayValue, C: SNDArrayValue, work: SNDArrayValue, blocksize: Value[Long], cb: EmitCodeBuilder): Unit = { + def gemqrt( + side: String, + trans: String, + V: SNDArrayValue, + T: SNDArrayValue, + C: SNDArrayValue, + work: SNDArrayValue, + blocksize: Value[Long], + cb: EmitCodeBuilder, + ): Unit = { assertMatrix(V) assertColMajor(cb, "gemqrt", V) if (C.st.nDims == 2) assertColMajor(cb, "gemqrt", C) else assertVector(C) @@ -392,7 +592,7 @@ object SNDArray { val Seq(m, n) = if (C.st.nDims == 2) C.shapes else FastSeq(C.shapes(0), SizeValueStatic(1)) val nb = blocksize cb.if_((nb > k && k > 0) || nb < 1, cb._fatal("gemqrt: invalid block size: ", nb.toS)) - cb.if_(T.shapes(0) < nb*k, cb._fatal("gemqrt: invalid T size")) + cb.if_(T.shapes(0) < nb * k, cb._fatal("gemqrt: invalid T size")) if (side == "L") { cb.if_(l.cne(m), cb._fatal("gemqrt: invalid dimensions")) cb.if_(work.shapes(0) < nb * n, cb._fatal("work array too small")) @@ -404,18 +604,55 @@ object SNDArray { val error = cb.mb.newLocal[Int]() val ldV = V.eltStride(1).max(1) val ldC = if (C.st.nDims == 2) C.eltStride(1).max(1) else m.toI - cb.assign(error, Code.invokeScalaObject13[String, String, Int, Int, Int, Int, Long, Int, Long, Int, Long, Int, Long, Int](LAPACK.getClass, "dgemqrt", - side, trans, m.toI, n.toI, k.toI, nb.toI, - V.firstDataAddress, ldV, - T.firstDataAddress, nb.toI.max(1), - C.firstDataAddress, ldC, - work.firstDataAddress)) + cb.assign( + error, + Code.invokeScalaObject13[ + String, + String, + Int, + Int, + Int, + Int, + Long, + Int, + Long, + Int, + Long, + Int, + Long, + Int, + ]( + LAPACK.getClass, + "dgemqrt", + side, + trans, + m.toI, + n.toI, + k.toI, + nb.toI, + V.firstDataAddress, + ldV, + T.firstDataAddress, 
+ nb.toI.max(1), + C.firstDataAddress, + ldC, + work.firstDataAddress, + ), + ) cb.if_(error.cne(0), cb._fatal("LAPACK error dgemqrt. Error code = ", error.toS)) } // Computes the QR factorization of A. Stores resulting factors in Q and R, overwriting A. - def geqrt_full(cb: EmitCodeBuilder, A: SNDArrayValue, Q: SNDArrayValue, R: SNDArrayValue, T: SNDArrayValue, work: SNDArrayValue, blocksize: Value[Long]): Unit = { - val Seq(m, n) = A.shapes + def geqrt_full( + cb: EmitCodeBuilder, + A: SNDArrayValue, + Q: SNDArrayValue, + R: SNDArrayValue, + T: SNDArrayValue, + work: SNDArrayValue, + blocksize: Value[Long], + ): Unit = { + val Seq(_, n) = A.shapes SNDArray.geqrt(A, T, work, blocksize, cb) // copy upper triangle of A0 to R SNDArray.copyMatrix(cb, "U", A.slice(cb, (null, n), ColonIndex), R) @@ -423,21 +660,35 @@ object SNDArray { // Set Q to I Q.setToZero(cb) val i = cb.mb.newLocal[Long]("i") - cb.for_(cb.assign(i, 0L), i < n, cb.assign(i, i+1), { - Q.setElement(FastSeq(i, i), primitive(1.0), cb) - }) + cb.for_( + cb.assign(i, 0L), + i < n, + cb.assign(i, i + 1), + Q.setElement(FastSeq(i, i), primitive(1.0), cb), + ) SNDArray.gemqrt("L", "N", A, T, Q, work, blocksize, cb) } - def geqr_query(cb: EmitCodeBuilder, m: Value[Long], n: Value[Long], region: Value[Region]): (Value[Long], Value[Long]) = { + def geqr_query(cb: EmitCodeBuilder, m: Value[Long], n: Value[Long], region: Value[Region]) + : (Value[Long], Value[Long]) = { val T = cb.memoize(region.allocate(8L * 5, 8L)) val work = cb.memoize(region.allocate(8L, 8L)) - val info = cb.memoize(Code.invokeScalaObject8[Int, Int, Long, Int, Long, Int, Long, Int, Int](LAPACK.getClass, "dgeqr", - m.toI, n.toI, - 0, m.toI, - T, -1, - work, -1)) - cb.if_(info.cne(0), cb._fatal(s"LAPACK error DGEQR. Failed size query. Error code = ", info.toS)) + val info = cb.memoize(Code.invokeScalaObject8[Int, Int, Long, Int, Long, Int, Long, Int, Int]( + LAPACK.getClass, + "dgeqr", + m.toI, + n.toI, + 0, + m.toI, + T, + -1, + work, + -1, + )) + cb.if_( + info.cne(0), + cb._fatal(s"LAPACK error DGEQR. Failed size query. Error code = ", info.toS), + ) val Tsize = cb.memoize(Region.loadDouble(T).toL) val LWork = cb.memoize(Region.loadDouble(work).toL) (cb.memoize(Tsize.max(5)), cb.memoize(LWork.max(1))) @@ -454,11 +705,21 @@ object SNDArray { val ldA = A.eltStride(1).max(1) val info = cb.newLocal[Int]("dgeqrf_info") - cb.assign(info, Code.invokeScalaObject8[Int, Int, Long, Int, Long, Int, Long, Int, Int](LAPACK.getClass, "dgeqr", - m.toI, n.toI, - A.firstDataAddress, ldA, - T.firstDataAddress, Tsize.toI, - work.firstDataAddress, lwork.toI)) + cb.assign( + info, + Code.invokeScalaObject8[Int, Int, Long, Int, Long, Int, Long, Int, Int]( + LAPACK.getClass, + "dgeqr", + m.toI, + n.toI, + A.firstDataAddress, + ldA, + T.firstDataAddress, + Tsize.toI, + work.firstDataAddress, + lwork.toI, + ), + ) val optTsize = T.loadElement(FastSeq(0), cb).asFloat64.value.toI val optLwork = work.loadElement(FastSeq(0), cb).asFloat64.value.toI cb.if_(optTsize > Tsize.toI, cb._fatal(s"dgeqr: T too small")) @@ -466,7 +727,15 @@ object SNDArray { cb.if_(info.cne(0), cb._fatal(s"LAPACK error dgeqr. 
Error code = ", info.toS)) } - def gemqr(cb: EmitCodeBuilder, side: String, trans: String, A: SNDArrayValue, T: SNDArrayValue, C: SNDArrayValue, work: SNDArrayValue): Unit = { + def gemqr( + cb: EmitCodeBuilder, + side: String, + trans: String, + A: SNDArrayValue, + T: SNDArrayValue, + C: SNDArrayValue, + work: SNDArrayValue, + ): Unit = { assertMatrix(A) assertColMajor(cb, "gemqr", A) if (C.st.nDims == 2) assertColMajor(cb, "gemqr", C) else assertVector(C) @@ -487,18 +756,54 @@ object SNDArray { val error = cb.mb.newLocal[Int]() val ldA = A.eltStride(1).max(1) val ldC = if (C.st.nDims == 2) C.eltStride(1).max(1) else m.toI - cb.assign(error, Code.invokeScalaObject13[String, String, Int, Int, Int, Long, Int, Long, Int, Long, Int, Long, Int, Int](LAPACK.getClass, "dgemqr", - side, trans, m.toI, n.toI, k.toI, - A.firstDataAddress, ldA, - T.firstDataAddress, Tsize.toI, - C.firstDataAddress, ldC, - work.firstDataAddress, Lwork.toI)) + cb.assign( + error, + Code.invokeScalaObject13[ + String, + String, + Int, + Int, + Int, + Long, + Int, + Long, + Int, + Long, + Int, + Long, + Int, + Int, + ]( + LAPACK.getClass, + "dgemqr", + side, + trans, + m.toI, + n.toI, + k.toI, + A.firstDataAddress, + ldA, + T.firstDataAddress, + Tsize.toI, + C.firstDataAddress, + ldC, + work.firstDataAddress, + Lwork.toI, + ), + ) cb.if_(error.cne(0), cb._fatal("LAPACK error dgemqr. Error code = ", error.toS)) } // Computes the QR factorization of A. Stores resulting factors in Q and R, overwriting A. - def geqr_full(cb: EmitCodeBuilder, A: SNDArrayValue, Q: SNDArrayValue, R: SNDArrayValue, T: SNDArrayValue, work: SNDArrayValue): Unit = { - val Seq(m, n) = A.shapes + def geqr_full( + cb: EmitCodeBuilder, + A: SNDArrayValue, + Q: SNDArrayValue, + R: SNDArrayValue, + T: SNDArrayValue, + work: SNDArrayValue, + ): Unit = { + val Seq(_, n) = A.shapes SNDArray.geqr(cb, A, T, work) // copy upper triangle of A0 to R SNDArray.copyMatrix(cb, "U", A.slice(cb, (null, n), ColonIndex), R) @@ -506,13 +811,23 @@ object SNDArray { // Set Q to I Q.setToZero(cb) val i = cb.mb.newLocal[Long]("i") - cb.for_(cb.assign(i, 0L), i < n, cb.assign(i, i+1), { - Q.setElement(FastSeq(i, i), primitive(1.0), cb) - }) + cb.for_( + cb.assign(i, 0L), + i < n, + cb.assign(i, i + 1), + Q.setElement(FastSeq(i, i), primitive(1.0), cb), + ) SNDArray.gemqr(cb, "L", "N", A, T, Q, work) } - def tpqrt(A: SNDArrayValue, B: SNDArrayValue, T: SNDArrayValue, work: SNDArrayValue, blocksize: Value[Long], cb: EmitCodeBuilder): Unit = { + def tpqrt( + A: SNDArrayValue, + B: SNDArrayValue, + T: SNDArrayValue, + work: SNDArrayValue, + blocksize: Value[Long], + cb: EmitCodeBuilder, + ): Unit = { assertMatrix(A, B) assertColMajor(cb, "tpqrt", A, B) assertVector(work, T) @@ -520,23 +835,45 @@ object SNDArray { val Seq(m, n) = B.shapes val nb = blocksize cb.if_(nb > n || nb < 1, cb._fatal("tpqrt: invalid block size")) - cb.if_(T.shapes(0) < nb*n, cb._fatal("tpqrt: T too small")) + cb.if_(T.shapes(0) < nb * n, cb._fatal("tpqrt: T too small")) A.assertHasShape(cb, FastSeq(n, n), "tpqrt: invalid shapes") cb.if_(work.shapes(0) < nb * n, cb._fatal("tpqrt: work array too small")) val error = cb.mb.newLocal[Int]() val ldA = A.eltStride(1).max(1) val ldB = B.eltStride(1).max(1) - cb.assign(error, Code.invokeScalaObject11[Int, Int, Int, Int, Long, Int, Long, Int, Long, Int, Long, Int](LAPACK.getClass, "dtpqrt", - m.toI, n.toI, 0, nb.toI, - A.firstDataAddress, ldA, - B.firstDataAddress, ldB, - T.firstDataAddress, nb.toI.max(1), - work.firstDataAddress)) + cb.assign( + error, + 
Code.invokeScalaObject11[Int, Int, Int, Int, Long, Int, Long, Int, Long, Int, Long, Int]( + LAPACK.getClass, + "dtpqrt", + m.toI, + n.toI, + 0, + nb.toI, + A.firstDataAddress, + ldA, + B.firstDataAddress, + ldB, + T.firstDataAddress, + nb.toI.max(1), + work.firstDataAddress, + ), + ) cb.if_(error.cne(0), cb._fatal("LAPACK error dtpqrt. Error code = ", error.toS)) } - def tpmqrt(side: String, trans: String, V: SNDArrayValue, T: SNDArrayValue, A: SNDArrayValue, B: SNDArrayValue, work: SNDArrayValue, blocksize: Value[Long], cb: EmitCodeBuilder): Unit = { + def tpmqrt( + side: String, + trans: String, + V: SNDArrayValue, + T: SNDArrayValue, + A: SNDArrayValue, + B: SNDArrayValue, + work: SNDArrayValue, + blocksize: Value[Long], + cb: EmitCodeBuilder, + ): Unit = { assertMatrix(A, B, V) assertColMajor(cb, "tpmqrt", A, B, V) assertVector(work, T) @@ -547,7 +884,7 @@ object SNDArray { val Seq(m, n) = B.shapes val nb = blocksize cb.if_(nb > k || nb < 1, cb._fatal("tpmqrt: invalid block size")) - cb.if_(T.shapes(0) < nb*k, cb._fatal("tpmqrt: T too small")) + cb.if_(T.shapes(0) < nb * k, cb._fatal("tpmqrt: T too small")) if (side == "L") { cb.if_(l.cne(m), cb._fatal("tpmqrt: invalid dimensions")) cb.if_(work.shapes(0) < nb * n, cb._fatal("tpmqrt: work array too small")) @@ -562,27 +899,74 @@ object SNDArray { val ldV = V.eltStride(1).max(1) val ldA = A.eltStride(1).max(1) val ldB = B.eltStride(1).max(1) - cb.assign(error, Code.invokeScalaObject16[String, String, Int, Int, Int, Int, Int, Long, Int, Long, Int, Long, Int, Long, Int, Long, Int](LAPACK.getClass, "dtpmqrt", - side, trans, m.toI, n.toI, k.toI, 0, nb.toI, - V.firstDataAddress, ldV, - T.firstDataAddress, nb.toI.max(1), - A.firstDataAddress, ldA, - B.firstDataAddress, ldB, - work.firstDataAddress)) + cb.assign( + error, + Code.invokeScalaObject16[ + String, + String, + Int, + Int, + Int, + Int, + Int, + Long, + Int, + Long, + Int, + Long, + Int, + Long, + Int, + Long, + Int, + ]( + LAPACK.getClass, + "dtpmqrt", + side, + trans, + m.toI, + n.toI, + k.toI, + 0, + nb.toI, + V.firstDataAddress, + ldV, + T.firstDataAddress, + nb.toI.max(1), + A.firstDataAddress, + ldA, + B.firstDataAddress, + ldB, + work.firstDataAddress, + ), + ) cb.if_(error.cne(0), cb._fatal("LAPACK error dtpqrt. Error code = ", error.toS)) } - def geqrf_query(cb: EmitCodeBuilder, m: Value[Int], n: Value[Int], region: Value[Region]): Value[Int] = { + def geqrf_query(cb: EmitCodeBuilder, m: Value[Int], n: Value[Int], region: Value[Region]) + : Value[Int] = { val LWorkAddress = cb.newLocal[Long]("dgeqrf_lwork_address") val LWork = cb.newLocal[Int]("dgeqrf_lwork") val info = cb.newLocal[Int]("dgeqrf_info") cb.assign(LWorkAddress, region.allocate(8L, 8L)) - cb.assign(info, Code.invokeScalaObject7[Int, Int, Long, Int, Long, Long, Int, Int](LAPACK.getClass, "dgeqrf", - m.toI, n.toI, - 0, m.toI, - 0, - LWorkAddress, -1)) - cb.if_(info.cne(0), cb._fatal(s"LAPACK error DGEQRF. Failed size query. Error code = ", info.toS)) + cb.assign( + info, + Code.invokeScalaObject7[Int, Int, Long, Int, Long, Long, Int, Int]( + LAPACK.getClass, + "dgeqrf", + m.toI, + n.toI, + 0, + m.toI, + 0, + LWorkAddress, + -1, + ), + ) + cb.if_( + info.cne(0), + cb._fatal(s"LAPACK error DGEQRF. Failed size query. 
Error code = ", info.toS), + ) cb.assign(LWork, Region.loadDouble(LWorkAddress).toI) cb.memoize((LWork > 0).mux(LWork, 1)) } @@ -599,15 +983,30 @@ object SNDArray { val ldA = A.eltStride(1).max(1) val info = cb.newLocal[Int]("dgeqrf_info") - cb.assign(info, Code.invokeScalaObject7[Int, Int, Long, Int, Long, Long, Int, Int](LAPACK.getClass, "dgeqrf", - m.toI, n.toI, - A.firstDataAddress, ldA, - T.firstDataAddress, - work.firstDataAddress, lwork.toI)) + cb.assign( + info, + Code.invokeScalaObject7[Int, Int, Long, Int, Long, Long, Int, Int]( + LAPACK.getClass, + "dgeqrf", + m.toI, + n.toI, + A.firstDataAddress, + ldA, + T.firstDataAddress, + work.firstDataAddress, + lwork.toI, + ), + ) cb.if_(info.cne(0), cb._fatal(s"LAPACK error DGEQRF. Error code = ", info.toS)) } - def orgqr(cb: EmitCodeBuilder, k: Value[Int], A: SNDArrayValue, T: SNDArrayValue, work: SNDArrayValue): Unit = { + def orgqr( + cb: EmitCodeBuilder, + k: Value[Int], + A: SNDArrayValue, + T: SNDArrayValue, + work: SNDArrayValue, + ): Unit = { assertMatrix(A) assertColMajor(cb, "orgqr", A) assertVector(T, work) @@ -620,33 +1019,98 @@ object SNDArray { val ldA = A.eltStride(1).max(1) val info = cb.newLocal[Int]("dgeqrf_info") - cb.assign(info, Code.invokeScalaObject8[Int, Int, Int, Long, Int, Long, Long, Int, Int](LAPACK.getClass, "dorgqr", - m.toI, n.toI, k.toI, - A.firstDataAddress, ldA, - T.firstDataAddress, - work.firstDataAddress, lwork.toI)) + cb.assign( + info, + Code.invokeScalaObject8[Int, Int, Int, Long, Int, Long, Long, Int, Int]( + LAPACK.getClass, + "dorgqr", + m.toI, + n.toI, + k.toI, + A.firstDataAddress, + ldA, + T.firstDataAddress, + work.firstDataAddress, + lwork.toI, + ), + ) cb.if_(info.cne(0), cb._fatal(s"LAPACK error DGEQRF. Error code = ", info.toS)) } - def syevr_query(cb: EmitCodeBuilder, jobz: String, uplo: String, n: Value[Int], region: Value[Region]): (SizeValue, SizeValue) = { + def syevr_query( + cb: EmitCodeBuilder, + jobz: String, + uplo: String, + n: Value[Int], + region: Value[Region], + ): (SizeValue, SizeValue) = { val WorkAddress = cb.memoize(region.allocate(8L, 8L)) val IWorkAddress = cb.memoize(region.allocate(4L, 4L)) - val info = cb.memoize(Code.invokeScalaObject19[String, String, String, Int, Long, Int, Double, Double, Int, Int, Double, Long, Long, Int, Long, Long, Int, Long, Int, Int](LAPACK.getClass, "dsyevr", - jobz, "A", uplo, - n, 0, n, - 0, 0, 0, 0, + val info = cb.memoize(Code.invokeScalaObject19[ + String, + String, + String, + Int, + Long, + Int, + Double, + Double, + Int, + Int, + Double, + Long, + Long, + Int, + Long, + Long, + Int, + Long, + Int, + Int, + ]( + LAPACK.getClass, + "dsyevr", + jobz, + "A", + uplo, + n, + 0, + n, + 0, + 0, 0, - 0, 0, n, 0, - WorkAddress, -1, - IWorkAddress, -1)) - cb.if_(info.cne(0), cb._fatal(s"LAPACK error DSYEVR. Failed size query. Error code = ", info.toS)) + 0, + 0, + 0, + n, + 0, + WorkAddress, + -1, + IWorkAddress, + -1, + )) + cb.if_( + info.cne(0), + cb._fatal(s"LAPACK error DSYEVR. Failed size query. 
Error code = ", info.toS), + ) val LWork = cb.memoize(Region.loadDouble(WorkAddress).toL) val LIWork = cb.memoize(Region.loadInt(IWorkAddress).toL) - (SizeValueDyn(cb.memoize((LWork > 0).mux(LWork, 1))), SizeValueDyn(cb.memoize((LIWork > 0).mux(LIWork, 1)))) + ( + SizeValueDyn(cb.memoize((LWork > 0).mux(LWork, 1))), + SizeValueDyn(cb.memoize((LIWork > 0).mux(LIWork, 1))), + ) } - def syevr(cb: EmitCodeBuilder, uplo: String, A: SNDArrayValue, W: SNDArrayValue, Z: Option[(SNDArrayValue, SNDArrayValue)], Work: SNDArrayValue, IWork: SNDArrayValue): Unit = { + def syevr( + cb: EmitCodeBuilder, + uplo: String, + A: SNDArrayValue, + W: SNDArrayValue, + Z: Option[(SNDArrayValue, SNDArrayValue)], + Work: SNDArrayValue, + IWork: SNDArrayValue, + ): Unit = { assertMatrix(A) assertColMajor(cb, "orgqr", A) assertVector(W, Work, IWork) @@ -666,27 +1130,65 @@ object SNDArray { assertMatrix(z) z.assertHasShape(cb, Array(n, n), "syevr: Z has wrong size") - iSuppZ.assertHasShape(cb, IndexedSeq(SizeValueDyn(cb.memoize(n * 2))), "syevr: ISuppZ has wrong size") + iSuppZ.assertHasShape( + cb, + IndexedSeq(SizeValueDyn(cb.memoize(n * 2))), + "syevr: ISuppZ has wrong size", + ) ("V", z.firstDataAddress, z.eltStride(1).max(1), iSuppZ.firstDataAddress) case None => ("N", const(0L), const(1).get, const(0L)) } - val info = cb.memoize(Code.invokeScalaObject19[String, String, String, Int, Long, Int, Double, Double, Int, Int, Double, Long, Long, Int, Long, Long, Int, Long, Int, Int](LAPACK.getClass, "dsyevr", - jobz, "A", uplo, - n.toI, A.firstDataAddress, ldA, - 0, 0, 0, 0, + val info = cb.memoize(Code.invokeScalaObject19[ + String, + String, + String, + Int, + Long, + Int, + Double, + Double, + Int, + Int, + Double, + Long, + Long, + Int, + Long, + Long, + Int, + Long, + Int, + Int, + ]( + LAPACK.getClass, + "dsyevr", + jobz, + "A", + uplo, + n.toI, + A.firstDataAddress, + ldA, + 0, + 0, + 0, + 0, 0, - W.firstDataAddress, zAddr, ldZ, + W.firstDataAddress, + zAddr, + ldZ, iSuppZAddr, - Work.firstDataAddress, lWork.toI, - IWork.firstDataAddress, lIWork.toI)) + Work.firstDataAddress, + lWork.toI, + IWork.firstDataAddress, + lIWork.toI, + )) cb.if_(info.cne(0), cb._fatal(s"LAPACK error DSYEVR. 
Error code = ", info.toS)) } } - trait SNDArray extends SType { def pType: PNDArray @@ -698,7 +1200,8 @@ trait SNDArray extends SType { def elementByteSize: Long - override def _typeWithRequiredness: TypeWithRequiredness = RNDArray(elementType.typeWithRequiredness.setRequired(true).r) + override def _typeWithRequiredness: TypeWithRequiredness = + RNDArray(elementType.typeWithRequiredness.setRequired(true).r) } sealed abstract class NDArrayIndex @@ -717,31 +1220,39 @@ sealed abstract class SizeValue extends Value[Long] { case (SizeValueStatic(l), SizeValueStatic(r)) => const(l == r) case (l, r) => if (l == r) const(true) else l.get.ceq(r.get) } + def cne(other: SizeValue): Code[Boolean] = (this, other) match { case (SizeValueStatic(l), SizeValueStatic(r)) => const(l != r) case (l, r) => if (l == r) const(false) else l.get.cne(r.get) } } + object SizeValueDyn { def apply(v: Value[Long]): SizeValueDyn = new SizeValueDyn(v) def unapply(size: SizeValueDyn): Some[Value[Long]] = Some(size.v) } + object SizeValueStatic { def apply(v: Long): SizeValueStatic = { assert(v >= 0) new SizeValueStatic(v) } + def unapply(size: SizeValueStatic): Some[Long] = Some(size.v) } + final class SizeValueDyn(val v: Value[Long]) extends SizeValue { def get: Code[Long] = v.get + override def equals(other: Any): Boolean = other match { case SizeValueDyn(v2) => v eq v2 case _ => false } } + final class SizeValueStatic(val v: Long) extends SizeValue { def get: Code[Long] = const(v) + override def equals(other: Any): Boolean = other match { case SizeValueStatic(v2) => v == v2 case _ => false @@ -784,12 +1295,16 @@ trait SNDArrayValue extends SValue { def assertInBounds(indices: IndexedSeq[Value[Long]], cb: EmitCodeBuilder, errorId: Int): Unit = { val shape = this.shapes for (dimIndex <- 0 until st.nDims) { - cb.if_(indices(dimIndex) >= shape(dimIndex) || indices(dimIndex) < 0, { - cb._fatalWithError(errorId, - "Index ", indices(dimIndex).toS, + cb.if_( + indices(dimIndex) >= shape(dimIndex) || indices(dimIndex) < 0, + cb._fatalWithError( + errorId, + "Index ", + indices(dimIndex).toS, s" is out of bounds for axis $dimIndex with size ", - shape(dimIndex).toS) - }) + shape(dimIndex).toS, + ), + ) } } @@ -801,22 +1316,22 @@ trait SNDArrayValue extends SValue { val shape = this.shapes assert(shape.length == otherShape.length) - (shape, otherShape).zipped.foreach { (s1, s2) => - b = s1.ceq(s2) - } + (shape, otherShape).zipped.foreach((s1, s2) => b = s1.ceq(s2)) b } def assertHasShape(cb: EmitCodeBuilder, otherShape: IndexedSeq[SizeValue], msg: Code[String]*) = if (!hasShapeStatic(otherShape)) - cb.if_(!hasShape(cb, otherShape), + cb.if_( + !hasShape(cb, otherShape), cb._fatal( msg ++ - (const("\nExpected shape ").get +: - shapes.map(_.toS).intersperse[Code[String]]("(", ",", ")")) ++ - (const(", found ").get +: - otherShape.map(_.toS).intersperse[Code[String]]("(", ",", ")")): _*, - )) + (const("\nExpected shape ").get +: + shapes.map(_.toS).intersperse[Code[String]]("(", ",", ")")) ++ + (const(", found ").get +: + otherShape.map(_.toS).intersperse[Code[String]]("(", ",", ")")): _* + ), + ) // True IFF shape can be proven equal to otherShape statically def hasShapeStatic(otherShape: IndexedSeq[SizeValue]): Boolean = @@ -843,12 +1358,13 @@ trait SNDArrayValue extends SValue { // Find largest prefix of dimensions which are stored contiguously. 
def contigDimsRecur(i: Int): Unit = if (i < st.nDims) { - cb.if_(tmp.ceq(eltStride(i)), { - cb.assign(tmp, tmp * shapes(i).toI) - contigDimsRecur(i+1) - }, { - cb.assign(contiguousDims, i) - }) + cb.if_( + tmp.ceq(eltStride(i)), { + cb.assign(tmp, tmp * shapes(i).toI) + contigDimsRecur(i + 1) + }, + cb.assign(contiguousDims, i), + ) } else { cb.assign(contiguousDims, st.nDims) } @@ -867,24 +1383,30 @@ trait SNDArrayValue extends SValue { def recur(startPtr: Value[Long], dim: Int, contiguousDims: Int): Unit = if (dim > 0) { if (contiguousDims == dim) - cb += Region.setMemory(startPtr, shapes(dim-1) * strides(dim-1), 0: Byte) + cb += Region.setMemory(startPtr, shapes(dim - 1) * strides(dim - 1), 0: Byte) else { val ptr = cb.mb.newLocal[Long](s"NDArray_setToZero_ptr_$dim") val end = cb.mb.newLocal[Long](s"NDArray_setToZero_end_$dim") cb.assign(ptr, startPtr) - cb.assign(end, startPtr + strides(dim-1) * shapes(dim-1)) - cb.for_({}, ptr < end, cb.assign(ptr, ptr + strides(dim-1)), recur(ptr, dim - 1, contiguousDims)) + cb.assign(end, startPtr + strides(dim - 1) * shapes(dim - 1)) + cb.for_( + {}, + ptr < end, + cb.assign(ptr, ptr + strides(dim - 1)), + recur(ptr, dim - 1, contiguousDims), + ) } } else { eltType.storePrimitiveAtAddress(cb, startPtr, primitive(eltType.virtualType, eltType.zero)) } - cb.switch(contiguousDims, + cb.switch( + contiguousDims, recur(firstDataAddress, st.nDims, 2), FastSeq( () => recur(firstDataAddress, st.nDims, 0), () => recur(firstDataAddress, st.nDims, 1), - ) + ), ) } @@ -893,20 +1415,50 @@ trait SNDArrayValue extends SValue { eltType.storePrimitiveAtAddress(cb, loadElementAddress(indices, cb), value) } - def coiterateMutate(cb: EmitCodeBuilder, region: Value[Region], arrays: (SNDArrayValue, String)*)(body: IndexedSeq[SValue] => SValue): Unit = + def coiterateMutate( + cb: EmitCodeBuilder, + region: Value[Region], + arrays: (SNDArrayValue, String)* + )( + body: IndexedSeq[SValue] => SValue + ): Unit = coiterateMutate(cb, region, false, arrays: _*)(body) - def coiterateMutate(cb: EmitCodeBuilder, region: Value[Region], deepCopy: Boolean, arrays: (SNDArrayValue, String)*)(body: IndexedSeq[SValue] => SValue): Unit = { + def coiterateMutate( + cb: EmitCodeBuilder, + region: Value[Region], + deepCopy: Boolean, + arrays: (SNDArrayValue, String)* + )( + body: IndexedSeq[SValue] => SValue + ): Unit = { val indexVars = Array.tabulate(st.nDims)(i => s"i$i").toFastSeq val indices = Array.range(0, st.nDims).toFastSeq - coiterateMutate(cb, region, deepCopy, indexVars, indices, arrays.map { case (array, name) => (array, indices, name) }: _*)(body) + coiterateMutate( + cb, + region, + deepCopy, + indexVars, + indices, + arrays.map { case (array, name) => (array, indices, name) }: _* + )(body) } - def coiterateMutate(cb: EmitCodeBuilder, region: Value[Region], indexVars: IndexedSeq[String], destIndices: IndexedSeq[Int], arrays: (SNDArrayValue, IndexedSeq[Int], String)*)(body: IndexedSeq[SValue] => SValue): Unit = + def coiterateMutate( + cb: EmitCodeBuilder, + region: Value[Region], + indexVars: IndexedSeq[String], + destIndices: IndexedSeq[Int], + arrays: (SNDArrayValue, IndexedSeq[Int], String)* + )( + body: IndexedSeq[SValue] => SValue + ): Unit = coiterateMutate(cb, region, false, indexVars, destIndices, arrays: _*)(body) - // Note: to iterate through an array in column major order, make sure the indices are in ascending order. E.g. 
- // A.coiterateMutate(cb, region, IndexedSeq("i", "j"), IndexedSeq((A, IndexedSeq(0, 1), "A"), (B, IndexedSeq(0, 1), "B")), { + /* Note: to iterate through an array in column major order, make sure the indices are in ascending + * order. E.g. */ + /* A.coiterateMutate(cb, region, IndexedSeq("i", "j"), IndexedSeq((A, IndexedSeq(0, 1), "A"), (B, + * IndexedSeq(0, 1), "B")), { */ // SCode.add(cb, a, b) // }) // computes A += B. @@ -917,7 +1469,8 @@ trait SNDArrayValue extends SValue { indexVars: IndexedSeq[String], destIndices: IndexedSeq[Int], arrays: (SNDArrayValue, IndexedSeq[Int], String)* - )(body: IndexedSeq[SValue] => SValue + )( + body: IndexedSeq[SValue] => SValue ): Unit def _slice(cb: EmitCodeBuilder, indices: IndexedSeq[NDArrayIndex]): SNDArraySliceValue = { @@ -928,10 +1481,34 @@ trait SNDArrayValue extends SValue { for (i <- indices.indices) indices(i) match { case ScalarIndex(j) => - cb.if_(j < 0 || j >= shapeX(i), cb._fatal("Scalar index out of bounds (axis ", i.toString, "): ", j.toS, " is not in [0,", shapeX(i).toS, ")")) + cb.if_( + j < 0 || j >= shapeX(i), + cb._fatal( + "Scalar index out of bounds (axis ", + i.toString, + "): ", + j.toS, + " is not in [0,", + shapeX(i).toS, + ")", + ), + ) case SliceIndex(Some(begin), Some(end)) => - cb.if_(begin > end, cb._fatal("Invalid slice index, " , begin.toS, " > ", end.toS)) - cb.if_(begin < 0 || end > shapeX(i), cb._fatal("Slice index out of bounds: (axis ", i.toString, ") range ", begin.toS, ":", end.toS, " is not contained by [0,", shapeX(i).toS, ")")) + cb.if_(begin > end, cb._fatal("Invalid slice index, ", begin.toS, " > ", end.toS)) + cb.if_( + begin < 0 || end > shapeX(i), + cb._fatal( + "Slice index out of bounds: (axis ", + i.toString, + ") range ", + begin.toS, + ":", + end.toS, + " is not contained by [0,", + shapeX(i).toS, + ")", + ), + ) val s = cb.newLocal[Long]("slice_size", end - begin) shapeBuilder += SizeValueDyn(s) stridesBuilder += stridesX(i) @@ -940,13 +1517,44 @@ trait SNDArrayValue extends SValue { shapeBuilder += end stridesBuilder += stridesX(i) case SliceIndex(None, Some(end)) => - cb.if_(end < 0, cb._fatal("Slice end index out of bounds (axis ", i.toString, "): endpoint " , end.toS, " < 0")) - cb.if_(end > shapeX(i), cb._fatal("Slice end index out of bounds: endpoint ", end.toS, " > ", shapeX(i).toS)) + cb.if_( + end < 0, + cb._fatal( + "Slice end index out of bounds (axis ", + i.toString, + "): endpoint ", + end.toS, + " < 0", + ), + ) + cb.if_( + end > shapeX(i), + cb._fatal("Slice end index out of bounds: endpoint ", end.toS, " > ", shapeX(i).toS), + ) shapeBuilder += SizeValueDyn(end) stridesBuilder += stridesX(i) case SliceIndex(Some(begin), None) => - cb.if_(begin < 0, cb._fatal("Slice start index out of bounds (axis ", i.toString, "): startpoint " , begin.toS, " < 0")) - cb.if_(begin > shapeX(i), cb._fatal("Slice start index out of bounds (axis ", i.toString, "): startpoint ", begin.toS, " > ", shapeX(i).toS)) + cb.if_( + begin < 0, + cb._fatal( + "Slice start index out of bounds (axis ", + i.toString, + "): startpoint ", + begin.toS, + " < 0", + ), + ) + cb.if_( + begin > shapeX(i), + cb._fatal( + "Slice start index out of bounds (axis ", + i.toString, + "): startpoint ", + begin.toS, + " > ", + shapeX(i).toS, + ), + ) val s = cb.newLocal[Long]("slice_size", shapeX(i) - begin) shapeBuilder += SizeValueDyn(s) stridesBuilder += stridesX(i) @@ -954,12 +1562,40 @@ trait SNDArrayValue extends SValue { shapeBuilder += shapeX(i) stridesBuilder += stridesX(i) case SliceSize(None, size) => - cb.if_(size 
> shapeX(i), cb._fatal("Slice size out of bounds (axis ", i.toString, "): size ", size.toS, " > ", shapeX(i).toS)) + cb.if_( + size > shapeX(i), + cb._fatal( + "Slice size out of bounds (axis ", + i.toString, + "): size ", + size.toS, + " > ", + shapeX(i).toS, + ), + ) shapeBuilder += size stridesBuilder += stridesX(i) case SliceSize(Some(begin), size) => - cb.if_(begin < 0, cb._fatal("Slice start out of bounds (axis ", i.toString, "): start ", begin.toS, " < 0")) - cb.if_(begin + size > shapeX(i), cb._fatal("Slice index out of bounds (axis ", i.toString, "): range ", begin.toS, ":", begin.toS, "+", size.toS, " is not contained by [0,", shapeX(i).toS, ")")) + cb.if_( + begin < 0, + cb._fatal("Slice start out of bounds (axis ", i.toString, "): start ", begin.toS, " < 0"), + ) + cb.if_( + begin + size > shapeX(i), + cb._fatal( + "Slice index out of bounds (axis ", + i.toString, + "): range ", + begin.toS, + ":", + begin.toS, + "+", + size.toS, + " is not contained by [0,", + shapeX(i).toS, + ")", + ), + ) shapeBuilder += size stridesBuilder += stridesX(i) case ColonIndex => @@ -976,9 +1612,11 @@ trait SNDArrayValue extends SValue { case ColonIndex => const(0L) } - val newFirstDataAddress = cb.newLocal[Long]("slice_ptr", loadElementAddress(firstElementIndices, cb)) + val newFirstDataAddress = + cb.newLocal[Long]("slice_ptr", loadElementAddress(firstElementIndices, cb)) - val newSType = SNDArraySlice(PCanonicalNDArray(st.pType.elementType, newShape.size, st.pType.required)) + val newSType = + SNDArraySlice(PCanonicalNDArray(st.pType.elementType, newShape.size, st.pType.required)) new SNDArraySliceValue(newSType, newShape, newStrides, newFirstDataAddress) } @@ -1015,7 +1653,7 @@ trait SNDArrayValue extends SValue { val totalSize = cb.newLocal[Long]("sindexableptr_size_in_bytes", storageType.byteSize) if (storageType.elementType.containsPointers) { - SNDArray.coiterate(cb, (this, "A")){ + SNDArray.coiterate(cb, (this, "A")) { case Seq(elt) => cb.assign(totalSize, totalSize + elt.sizeToStoreInBytes(cb).value) } diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SStream.scala b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SStream.scala index 138c182d114..cb94b3f352e 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SStream.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SStream.scala @@ -2,25 +2,26 @@ package is.hail.types.physical.stypes.interfaces import is.hail.annotations.Region import is.hail.asm4s._ -import is.hail.expr.ir.streams.StreamProducer import is.hail.expr.ir.{EmitCode, EmitCodeBuilder, EmitMethodBuilder, IEmitCode} +import is.hail.expr.ir.streams.StreamProducer +import is.hail.types.{RIterable, TypeWithRequiredness} import is.hail.types.physical.PType import is.hail.types.physical.stypes._ import is.hail.types.virtual.{TStream, Type} -import is.hail.types.{RIterable, TypeWithRequiredness} import is.hail.utils.FastSeq trait MissingnessAsMethod { def isMissing: Boolean } + trait NoBoxLongIterator { - def init(partitionRegion: Region, elementRegion: Region) + def init(partitionRegion: Region, elementRegion: Region): Unit // after next() has been called, if eos is true, stream has ended // (and value returned by next() is garbage) def eos: Boolean - def next(): Long // 0L represents missing value + def next(): Long // 0L represents missing value def close(): Unit } @@ -31,16 +32,26 @@ object SStream { final case class SimpleSStream(elementEmitType: EmitType) extends SStream { override def 
settableTupleTypes(): IndexedSeq[TypeInfo[_]] = throw new NotImplementedError() - override def fromSettables(settables: IndexedSeq[Settable[_]]): SSettable = throw new NotImplementedError() + + override def fromSettables(settables: IndexedSeq[Settable[_]]): SSettable = + throw new NotImplementedError() override def fromValues(values: IndexedSeq[Value[_]]): SValue = throw new NotImplementedError() } -final case class SStreamIteratorLong(elementRequired: Boolean, elementPType: PType, requiresMemoryManagement: Boolean) extends SStream { - override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = IndexedSeq(classInfo[NoBoxLongIterator]) - override def fromSettables(settables: IndexedSeq[Settable[_]]): SSettable = new SStreamConcreteSettable(this, coerce[NoBoxLongIterator](settables(0))) +final case class SStreamIteratorLong( + elementRequired: Boolean, + elementPType: PType, + requiresMemoryManagement: Boolean, +) extends SStream { + override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = + IndexedSeq(classInfo[NoBoxLongIterator]) - override def fromValues(values: IndexedSeq[Value[_]]): SValue = new SStreamConcrete(this, coerce[NoBoxLongIterator](values(0))) + override def fromSettables(settables: IndexedSeq[Settable[_]]): SSettable = + new SStreamConcreteSettable(this, coerce[NoBoxLongIterator](settables(0))) + + override def fromValues(values: IndexedSeq[Value[_]]): SValue = + new SStreamConcrete(this, coerce[NoBoxLongIterator](values(0))) override val elementEmitType: EmitType = EmitType(elementPType.sType, elementRequired) } @@ -51,7 +62,12 @@ sealed trait SStream extends SType { def elementType: SType = elementEmitType.st - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = { + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = { if (deepCopy) throw new NotImplementedError() assert(value.st == this) @@ -66,16 +82,18 @@ sealed trait SStream extends SType { override def virtualType: Type = TStream(elementType.virtualType) - override def castRename(t: Type): SType = throw new UnsupportedOperationException("rename on stream") + override def castRename(t: Type): SType = + throw new UnsupportedOperationException("rename on stream") - override def _typeWithRequiredness: TypeWithRequiredness = RIterable(elementEmitType.typeWithRequiredness.r) + override def _typeWithRequiredness: TypeWithRequiredness = + RIterable(elementEmitType.typeWithRequiredness.r) } object SStreamValue { - def apply(producer: StreamProducer): SStreamValue = SStreamControlFlow(SStream(producer.element.emitType), producer) + def apply(producer: StreamProducer): SStreamValue = + SStreamControlFlow(SStream(producer.element.emitType), producer) } - trait SStreamValue extends SUnrealizableValue { def st: SStream @@ -84,7 +102,8 @@ trait SStreamValue extends SUnrealizableValue { def defineUnusedLabels(mb: EmitMethodBuilder[_]): Unit } -class SStreamConcrete(val st: SStreamIteratorLong, val it: Value[NoBoxLongIterator]) extends SStreamValue { +class SStreamConcrete(val st: SStreamIteratorLong, val it: Value[NoBoxLongIterator]) + extends SStreamValue { lazy val valueTuple: IndexedSeq[Value[_]] = FastSeq(it) @@ -96,9 +115,8 @@ class SStreamConcrete(val st: SStreamIteratorLong, val it: Value[NoBoxLongIterat val next = mb.genFieldThisRef[Long]("stream_iter_next") override val length: Option[EmitCodeBuilder => Code[Int]] = None - override def initialize(cb: EmitCodeBuilder, outerRegion: 
Value[Region]): Unit = { + override def initialize(cb: EmitCodeBuilder, outerRegion: Value[Region]): Unit = cb += it.invoke[Region, Region, Unit]("init", outerRegion, elementRegion) - } override val elementRegion: Settable[Region] = elRegion override val requiresMemoryManagementPerElement: Boolean = st.requiresMemoryManagement @@ -107,12 +125,14 @@ class SStreamConcrete(val st: SStreamIteratorLong, val it: Value[NoBoxLongIterat cb.if_(it.invoke[Boolean]("eos"), cb.goto(LendOfStream)) cb.goto(LproduceElementDone) } - override val element: EmitCode = { - + override val element: EmitCode = EmitCode.fromI(mb) { cb => - IEmitCode(cb, if (st.elementRequired) const(false) else (next cne 0L), st.elementPType.loadCheapSCode(cb, next)) + IEmitCode( + cb, + if (st.elementRequired) const(false) else (next cne 0L), + st.elementPType.loadCheapSCode(cb, next), + ) } - } override def close(cb: EmitCodeBuilder): Unit = it.invoke[Unit]("close") } @@ -122,7 +142,7 @@ class SStreamConcrete(val st: SStreamIteratorLong, val it: Value[NoBoxLongIterat } class SStreamConcreteSettable(st: SStreamIteratorLong, val itSettable: Settable[NoBoxLongIterator]) - extends SStreamConcrete(st, itSettable) with SSettable { + extends SStreamConcrete(st, itSettable) with SSettable { override def store(cb: EmitCodeBuilder, v: SValue): Unit = { assert(v.st == st) cb.assign(itSettable, v.asInstanceOf[SStreamConcrete].it) @@ -135,7 +155,7 @@ case class SStreamControlFlow(st: SimpleSStream, producer: StreamProducer) exten override def getProducer(mb: EmitMethodBuilder[_]): StreamProducer = { if (mb != producer.method) throw new RuntimeException("stream used in method different from where it was generated -- " + - s"generated in ${ producer.method.mb.methodName }, used in ${ mb.mb.methodName }") + s"generated in ${producer.method.mb.methodName}, used in ${mb.mb.methodName}") producer } @@ -145,14 +165,14 @@ case class SStreamControlFlow(st: SimpleSStream, producer: StreamProducer) exten (producer.LendOfStream.isImplemented, producer.LproduceElementDone.isImplemented) match { case (true, true) => case (false, false) => - EmitCodeBuilder.scopedVoid(mb) { cb => cb.define(producer.LendOfStream) cb.define(producer.LproduceElementDone) cb._fatal("unreachable") } - case (eos, ped) => throw new RuntimeException(s"unrealizable value unused asymmetrically: eos=$eos, ped=$ped") + case (eos, ped) => + throw new RuntimeException(s"unrealizable value unused asymmetrically: eos=$eos, ped=$ped") } producer.element.pv match { case ss: SStreamValue => ss.defineUnusedLabels(mb) diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SString.scala b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SString.scala index 7622c0c6c46..a400b861092 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SString.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SString.scala @@ -3,9 +3,9 @@ package is.hail.types.physical.stypes.interfaces import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.primitives.SInt32Value -import is.hail.types.physical.stypes.{SCode, SType, SValue} import is.hail.types.{RPrimitive, TypeWithRequiredness} +import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.primitives.SInt32Value trait SString extends SType { def constructFromString(cb: EmitCodeBuilder, r: Value[Region], s: Code[String]): SStringValue diff --git 
a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SVoid.scala b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SVoid.scala index e3aeafbc434..aa975c8455b 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SVoid.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/interfaces/SVoid.scala @@ -15,13 +15,20 @@ case object SVoid extends SType { override def castRename(t: Type): SType = this - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = value + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = IndexedSeq() - override def fromSettables(settables: IndexedSeq[Settable[_]]): SSettable = throw new UnsupportedOperationException + override def fromSettables(settables: IndexedSeq[Settable[_]]): SSettable = + throw new UnsupportedOperationException - override def fromValues(values: IndexedSeq[Value[_]]): SValue = throw new UnsupportedOperationException + override def fromValues(values: IndexedSeq[Value[_]]): SValue = + throw new UnsupportedOperationException override def storageType(): PType = throw new UnsupportedOperationException diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SBoolean.scala b/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SBoolean.scala index 4eaaf1600b6..0ad1d0d46e8 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SBoolean.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SBoolean.scala @@ -3,12 +3,11 @@ package is.hail.types.physical.stypes.primitives import is.hail.annotations.Region import is.hail.asm4s.{BooleanInfo, Settable, SettableBuilder, TypeInfo, Value} import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.physical.{PBoolean, PType} +import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.virtual.{TBoolean, Type} import is.hail.utils.FastSeq - case object SBoolean extends SPrimitive { override def ti: TypeInfo[_] = BooleanInfo @@ -16,25 +15,29 @@ case object SBoolean extends SPrimitive { override def castRename(t: Type): SType = this - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = { + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value.st match { case SBoolean => value } - } override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(BooleanInfo) override def fromSettables(settables: IndexedSeq[Settable[_]]): SBooleanSettable = { - val IndexedSeq(x: Settable[Boolean@unchecked]) = settables + val IndexedSeq(x: Settable[Boolean @unchecked]) = settables assert(x.ti == BooleanInfo) - new SBooleanSettable( x) + new SBooleanSettable(x) } override def fromValues(values: IndexedSeq[Value[_]]): SBooleanValue = { - val IndexedSeq(x: Value[Boolean@unchecked]) = values + val IndexedSeq(x: Value[Boolean @unchecked]) = values assert(x.ti == BooleanInfo) - new SBooleanValue( x) + new SBooleanValue(x) } override def storageType(): PType = PBoolean() @@ -56,9 +59,8 @@ class SBooleanValue(val value: Value[Boolean]) extends SPrimitiveValue { } object SBooleanSettable { - def apply(sb: SettableBuilder, name: String): SBooleanSettable = { - new SBooleanSettable( 
sb.newSettable[Boolean](name)) - } + def apply(sb: SettableBuilder, name: String): SBooleanSettable = + new SBooleanSettable(sb.newSettable[Boolean](name)) } class SBooleanSettable(x: Settable[Boolean]) extends SBooleanValue(x) with SSettable { diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SFloat32.scala b/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SFloat32.scala index 37c3c094064..d53237a2512 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SFloat32.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SFloat32.scala @@ -3,8 +3,8 @@ package is.hail.types.physical.stypes.primitives import is.hail.annotations.Region import is.hail.asm4s.{Code, FloatInfo, Settable, SettableBuilder, TypeInfo, Value} import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.physical.{PFloat32, PType} +import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.virtual.{TFloat32, Type} import is.hail.utils.FastSeq @@ -15,22 +15,26 @@ case object SFloat32 extends SPrimitive { override def castRename(t: Type): SType = this - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = { + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value.st match { case SFloat32 => value } - } override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(FloatInfo) override def fromSettables(settables: IndexedSeq[Settable[_]]): SFloat32Settable = { - val IndexedSeq(x: Settable[Float@unchecked]) = settables + val IndexedSeq(x: Settable[Float @unchecked]) = settables assert(x.ti == FloatInfo) new SFloat32Settable(x) } override def fromValues(values: IndexedSeq[Value[_]]): SFloat32Value = { - val IndexedSeq(x: Value[Float@unchecked]) = values + val IndexedSeq(x: Value[Float @unchecked]) = values assert(x.ti == FloatInfo) new SFloat32Value(x) } @@ -48,15 +52,17 @@ class SFloat32Value(val value: Value[Float]) extends SPrimitiveValue { override def _primitiveValue: Value[_] = value override def hash(cb: EmitCodeBuilder): SInt32Value = - new SInt32Value(cb.memoize(Code.invokeStatic1[java.lang.Float, Float, Int]("floatToIntBits", value))) + new SInt32Value(cb.memoize(Code.invokeStatic1[java.lang.Float, Float, Int]( + "floatToIntBits", + value, + ))) override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = new SInt64Value(4L) } object SFloat32Settable { - def apply(sb: SettableBuilder, name: String): SFloat32Settable = { + def apply(sb: SettableBuilder, name: String): SFloat32Settable = new SFloat32Settable(sb.newSettable[Float](name)) - } } final class SFloat32Settable(x: Settable[Float]) extends SFloat32Value(x) with SSettable { diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SFloat64.scala b/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SFloat64.scala index 692028a9611..41e4c9cb555 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SFloat64.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SFloat64.scala @@ -1,11 +1,11 @@ package is.hail.types.physical.stypes.primitives import is.hail.annotations.Region -import is.hail.asm4s.Code.invokeStatic1 import is.hail.asm4s.{DoubleInfo, Settable, SettableBuilder, TypeInfo, Value} +import is.hail.asm4s.Code.invokeStatic1 import is.hail.expr.ir.EmitCodeBuilder -import 
is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.physical.{PFloat64, PType} +import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.virtual.{TFloat64, Type} import is.hail.utils.FastSeq @@ -16,22 +16,26 @@ case object SFloat64 extends SPrimitive { override def castRename(t: Type): SType = this - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = { + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value.st match { case SFloat64 => value } - } override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(DoubleInfo) override def fromSettables(settables: IndexedSeq[Settable[_]]): SFloat64Settable = { - val IndexedSeq(x: Settable[Double@unchecked]) = settables + val IndexedSeq(x: Settable[Double @unchecked]) = settables assert(x.ti == DoubleInfo) new SFloat64Settable(x) } override def fromValues(settables: IndexedSeq[Value[_]]): SFloat64Value = { - val IndexedSeq(x: Value[Double@unchecked]) = settables + val IndexedSeq(x: Value[Double @unchecked]) = settables assert(x.ti == DoubleInfo) new SFloat64Value(x) } @@ -59,9 +63,8 @@ class SFloat64Value(val value: Value[Double]) extends SPrimitiveValue { } object SFloat64Settable { - def apply(sb: SettableBuilder, name: String): SFloat64Settable = { + def apply(sb: SettableBuilder, name: String): SFloat64Settable = new SFloat64Settable(sb.newSettable[Double](name)) - } } final class SFloat64Settable(x: Settable[Double]) extends SFloat64Value(x) with SSettable { diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SInt32.scala b/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SInt32.scala index 17109c165fd..f2dcea4340c 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SInt32.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SInt32.scala @@ -3,8 +3,8 @@ package is.hail.types.physical.stypes.primitives import is.hail.annotations.Region import is.hail.asm4s.{IntInfo, Settable, SettableBuilder, TypeInfo, Value} import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.physical.{PInt32, PType} +import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.virtual.{TInt32, Type} import is.hail.utils.FastSeq @@ -15,22 +15,26 @@ case object SInt32 extends SPrimitive { override def castRename(t: Type): SType = this - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = { + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value.st match { case SInt32 => value } - } override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(IntInfo) override def fromSettables(settables: IndexedSeq[Settable[_]]): SInt32Settable = { - val IndexedSeq(x: Settable[Int@unchecked]) = settables + val IndexedSeq(x: Settable[Int @unchecked]) = settables assert(x.ti == IntInfo) new SInt32Settable(x) } override def fromValues(settables: IndexedSeq[Value[_]]): SInt32Value = { - val IndexedSeq(x: Value[Int@unchecked]) = settables + val IndexedSeq(x: Value[Int @unchecked]) = settables assert(x.ti == IntInfo) new SInt32Value(x) } @@ -54,9 +58,8 @@ class SInt32Value(val value: Value[Int]) extends SPrimitiveValue { } object SInt32Settable { - def apply(sb: SettableBuilder, 
name: String): SInt32Settable = { + def apply(sb: SettableBuilder, name: String): SInt32Settable = new SInt32Settable(sb.newSettable[Int](name)) - } } final class SInt32Settable(x: Settable[Int]) extends SInt32Value(x) with SSettable { diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SInt64.scala b/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SInt64.scala index b336f1291bf..27a32ea8fc6 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SInt64.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SInt64.scala @@ -1,11 +1,11 @@ package is.hail.types.physical.stypes.primitives import is.hail.annotations.Region -import is.hail.asm4s.Code.invokeStatic1 import is.hail.asm4s.{LongInfo, Settable, SettableBuilder, TypeInfo, Value} +import is.hail.asm4s.Code.invokeStatic1 import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.physical.{PInt64, PType} +import is.hail.types.physical.stypes.{SSettable, SType, SValue} import is.hail.types.virtual.{TInt64, Type} import is.hail.utils.FastSeq @@ -16,22 +16,26 @@ case object SInt64 extends SPrimitive { override def castRename(t: Type): SType = this - override def _coerceOrCopy(cb: EmitCodeBuilder, region: Value[Region], value: SValue, deepCopy: Boolean): SValue = { + override def _coerceOrCopy( + cb: EmitCodeBuilder, + region: Value[Region], + value: SValue, + deepCopy: Boolean, + ): SValue = value.st match { case SInt64 => value } - } override def settableTupleTypes(): IndexedSeq[TypeInfo[_]] = FastSeq(LongInfo) override def fromSettables(settables: IndexedSeq[Settable[_]]): SInt64Settable = { - val IndexedSeq(x: Settable[Long@unchecked]) = settables + val IndexedSeq(x: Settable[Long @unchecked]) = settables assert(x.ti == LongInfo) new SInt64Settable(x) } override def fromValues(settables: IndexedSeq[Value[_]]): SInt64Value = { - val IndexedSeq(x: Value[Long@unchecked]) = settables + val IndexedSeq(x: Value[Long @unchecked]) = settables assert(x.ti == LongInfo) new SInt64Value(x) } @@ -55,13 +59,13 @@ class SInt64Value(val value: Value[Long]) extends SPrimitiveValue { override def hash(cb: EmitCodeBuilder): SInt32Value = new SInt32Value(cb.memoize(invokeStatic1[java.lang.Long, Long, Int]("hashCode", value))) - override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = new SInt64Value(this.st.storageType().asInstanceOf[PInt64].byteSize) + override def sizeToStoreInBytes(cb: EmitCodeBuilder): SInt64Value = + new SInt64Value(this.st.storageType().asInstanceOf[PInt64].byteSize) } object SInt64Settable { - def apply(sb: SettableBuilder, name: String): SInt64Settable = { + def apply(sb: SettableBuilder, name: String): SInt64Settable = new SInt64Settable(sb.newSettable[Long](name)) - } } final class SInt64Settable(x: Settable[Long]) extends SInt64Value(x) with SSettable { diff --git a/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SPrimitive.scala b/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SPrimitive.scala index b2c14552479..aa05e7763d2 100644 --- a/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SPrimitive.scala +++ b/hail/src/main/scala/is/hail/types/physical/stypes/primitives/SPrimitive.scala @@ -2,7 +2,7 @@ package is.hail.types.physical.stypes.primitives import is.hail.asm4s._ import is.hail.types.{RPrimitive, TypeWithRequiredness} -import is.hail.types.physical.stypes.{SCode, SType, SValue} +import is.hail.types.physical.stypes.{SType, SValue} trait 
SPrimitive extends SType { def ti: TypeInfo[_] @@ -19,4 +19,4 @@ abstract class SPrimitiveValue extends SValue { protected[primitives] def _primitiveValue: Value[_] final def primitiveValue[T]: Value[T] = coerce[T](_primitiveValue) -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/types/virtual/Field.scala b/hail/src/main/scala/is/hail/types/virtual/Field.scala index 5af8c507dce..d6fdf5c358e 100644 --- a/hail/src/main/scala/is/hail/types/virtual/Field.scala +++ b/hail/src/main/scala/is/hail/types/virtual/Field.scala @@ -9,7 +9,7 @@ final case class Field(name: String, typ: Type, index: Int) { typ.unify(cf.typ) && index == cf.index - def pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + def pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = { if (compact) { sb.append(prettyIdentifier(name)) sb.append(":") diff --git a/hail/src/main/scala/is/hail/types/virtual/TArray.scala b/hail/src/main/scala/is/hail/types/virtual/TArray.scala index 84e5d7705c0..07ada81bf14 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TArray.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TArray.scala @@ -3,9 +3,10 @@ package is.hail.types.virtual import is.hail.annotations.{Annotation, ExtendedOrdering} import is.hail.backend.HailStateManager import is.hail.check.Gen -import org.json4s.jackson.JsonMethods -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} + +import org.json4s.jackson.JsonMethods final case class TArray(elementType: Type) extends TContainer { override def pyString(sb: StringBuilder): Unit = { @@ -28,7 +29,7 @@ final case class TArray(elementType: Type) extends TContainer { override def subst() = TArray(elementType.subst()) - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("Array[") elementType.pretty(sb, indent, compact) sb.append("]") @@ -61,4 +62,10 @@ final case class TArray(elementType: Type) extends TContainer { } override def arrayElementsRepr: TArray = this + + override def isIsomorphicTo(t: Type): Boolean = + t match { + case a: TArray => elementType isIsomorphicTo a.elementType + case _ => false + } } diff --git a/hail/src/main/scala/is/hail/types/virtual/TBaseStruct.scala b/hail/src/main/scala/is/hail/types/virtual/TBaseStruct.scala index f8ca2eaac89..052e68d78d1 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TBaseStruct.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TBaseStruct.scala @@ -4,22 +4,27 @@ import is.hail.annotations._ import is.hail.backend.HailStateManager import is.hail.check.Gen import is.hail.utils._ + +import scala.reflect.{classTag, ClassTag} + import org.apache.spark.sql.Row import org.json4s.jackson.JsonMethods -import scala.reflect.{ClassTag, classTag} - object TBaseStruct { - /** - * Define an ordering on Row objects. Works with any row r such that the list - * of types of r is a prefix of types, or types is a prefix of the list of - * types of r. + + /** Define an ordering on Row objects. Works with any row r such that the list of types of r is a + * prefix of types, or types is a prefix of the list of types of r. 
*/ - def getOrdering(sm: HailStateManager, types: Array[Type], missingEqual: Boolean = true): ExtendedOrdering = + def getOrdering(sm: HailStateManager, types: Array[Type], missingEqual: Boolean = true) + : ExtendedOrdering = ExtendedOrdering.rowOrdering(types.map(_.ordering(sm)), missingEqual) - def getJoinOrdering(sm: HailStateManager, types: Array[Type], missingEqual: Boolean = false): ExtendedOrdering = - ExtendedOrdering.rowOrdering(types.map(_.mkOrdering(sm, missingEqual = missingEqual)), _missingEqual = missingEqual) + def getJoinOrdering(sm: HailStateManager, types: Array[Type], missingEqual: Boolean = false) + : ExtendedOrdering = + ExtendedOrdering.rowOrdering( + types.map(_.mkOrdering(sm, missingEqual = missingEqual)), + _missingEqual = missingEqual, + ) } abstract class TBaseStruct extends Type { @@ -42,14 +47,14 @@ abstract class TBaseStruct extends Type { override def _typeCheck(a: Any): Boolean = a match { case row: Row => row.length == types.length && - isComparableAt(a) + isComparableAt(a) case _ => false } def relaxedTypeCheck(a: Any): Boolean = a match { case row: Row => row.length <= types.length && - isComparableAt(a) + isComparableAt(a) case _ => false } @@ -61,14 +66,23 @@ abstract class TBaseStruct extends Type { case _ => false } - def isIsomorphicTo(other: TBaseStruct): Boolean = + def isJoinableWith(other: TBaseStruct): Boolean = size == other.size && isCompatibleWith(other) def isPrefixOf(other: TBaseStruct): Boolean = size <= other.size && isCompatibleWith(other) + override def isIsomorphicTo(t: Type): Boolean = + t match { + case s: TBaseStruct => size == s.size && forallZippedFields(s)(_.typ isIsomorphicTo _.typ) + case _ => false + } + def isCompatibleWith(other: TBaseStruct): Boolean = - fields.zip(other.fields).forall{ case (l, r) => l.typ == r.typ } + forallZippedFields(other)(_.typ == _.typ) + + private def forallZippedFields(s: TBaseStruct)(p: (Field, Field) => Boolean): Boolean = + (fields, s.fields).zipped.forall(p) def truncate(newSize: Int): TBaseStruct @@ -91,16 +105,18 @@ abstract class TBaseStruct extends Type { if (types.length > fuel) Gen.uniformSequence(types.map(t => Gen.const(null))).map(a => Annotation(a: _*)) else - Gen.uniformSequence(types.map(t => t.genValue(sm))).map(a => Annotation(a: _*))) + Gen.uniformSequence(types.map(t => t.genValue(sm))).map(a => Annotation(a: _*)) + ) } - override def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double, absolute: Boolean): Boolean = + override def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double, absolute: Boolean) + : Boolean = a1 == a2 || (a1 != null && a2 != null && types.zip(a1.asInstanceOf[Row].toSeq).zip(a2.asInstanceOf[Row].toSeq) - .forall { - case ((t, x1), x2) => - t.valuesSimilar(x1, x2, tolerance, absolute) - }) + .forall { + case ((t, x1), x2) => + t.valuesSimilar(x1, x2, tolerance, absolute) + }) override def scalaClassTag: ClassTag[Row] = classTag[Row] } diff --git a/hail/src/main/scala/is/hail/types/virtual/TBinary.scala b/hail/src/main/scala/is/hail/types/virtual/TBinary.scala index c80506293f6..804da2e06c0 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TBinary.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TBinary.scala @@ -4,7 +4,6 @@ import is.hail.annotations._ import is.hail.backend.HailStateManager import is.hail.check.Arbitrary._ import is.hail.check.Gen -import is.hail.types.physical.PBinary import scala.reflect.{ClassTag, _} @@ -13,16 +12,22 @@ case object TBinary extends Type { def _typeCheck(a: Any): Boolean = 
a.isInstanceOf[Array[Byte]] - override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = Gen.buildableOf(arbitrary[Byte]) + override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = + Gen.buildableOf(arbitrary[Byte]) override def scalaClassTag: ClassTag[Array[Byte]] = classTag[Array[Byte]] - def mkOrdering(sm: HailStateManager, _missingEqual: Boolean = true): ExtendedOrdering = ExtendedOrdering.iterableOrdering(new ExtendedOrdering { - val missingEqual = _missingEqual + def mkOrdering(sm: HailStateManager, _missingEqual: Boolean = true): ExtendedOrdering = + ExtendedOrdering.iterableOrdering(new ExtendedOrdering { + val missingEqual = _missingEqual - override def compareNonnull(x: Any, y: Any): Int = - java.lang.Integer.compare( - java.lang.Byte.toUnsignedInt(x.asInstanceOf[Byte]), - java.lang.Byte.toUnsignedInt(y.asInstanceOf[Byte])) - }) + override def compareNonnull(x: Any, y: Any): Int = + java.lang.Integer.compare( + java.lang.Byte.toUnsignedInt(x.asInstanceOf[Byte]), + java.lang.Byte.toUnsignedInt(y.asInstanceOf[Byte]), + ) + }) + + override def isIsomorphicTo(t: Type): Boolean = + this == t } diff --git a/hail/src/main/scala/is/hail/types/virtual/TBoolean.scala b/hail/src/main/scala/is/hail/types/virtual/TBoolean.scala index 3032a1fec2c..5fc1e898c62 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TBoolean.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TBoolean.scala @@ -4,16 +4,14 @@ import is.hail.annotations._ import is.hail.backend.HailStateManager import is.hail.check.Arbitrary._ import is.hail.check.Gen -import is.hail.types.physical.PBoolean import scala.reflect.{ClassTag, _} case object TBoolean extends Type { def _toPretty = "Boolean" - override def pyString(sb: StringBuilder): Unit = { + override def pyString(sb: StringBuilder): Unit = sb.append("bool") - } def _typeCheck(a: Any): Boolean = a.isInstanceOf[Boolean] @@ -27,4 +25,7 @@ case object TBoolean extends Type { override def mkOrdering(sm: HailStateManager, missingEqual: Boolean): ExtendedOrdering = ExtendedOrdering.extendToNull(implicitly[Ordering[Boolean]], missingEqual) + + override def isIsomorphicTo(t: Type): Boolean = + this == t } diff --git a/hail/src/main/scala/is/hail/types/virtual/TCall.scala b/hail/src/main/scala/is/hail/types/virtual/TCall.scala index 35a1dbee23a..1db13fa683b 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TCall.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TCall.scala @@ -3,8 +3,6 @@ package is.hail.types.virtual import is.hail.annotations._ import is.hail.backend.HailStateManager import is.hail.check.Gen -import is.hail.types._ -import is.hail.types.physical.PCall import is.hail.variant.Call import scala.reflect.{ClassTag, _} @@ -12,9 +10,9 @@ import scala.reflect.{ClassTag, _} case object TCall extends Type { def _toPretty = "Call" - override def pyString(sb: StringBuilder): Unit = { + override def pyString(sb: StringBuilder): Unit = sb.append("call") - } + val representation: Type = TInt32 def _typeCheck(a: Any): Boolean = a.isInstanceOf[Int] @@ -23,8 +21,12 @@ case object TCall extends Type { override def scalaClassTag: ClassTag[java.lang.Integer] = classTag[java.lang.Integer] - override def str(a: Annotation): String = if (a == null) "NA" else Call.toString(a.asInstanceOf[Call]) + override def str(a: Annotation): String = + if (a == null) "NA" else Call.toString(a.asInstanceOf[Call]) override def mkOrdering(sm: HailStateManager, missingEqual: Boolean): ExtendedOrdering = ExtendedOrdering.extendToNull(implicitly[Ordering[Int]], 
missingEqual) + + override def isIsomorphicTo(t: Type): Boolean = + this == t } diff --git a/hail/src/main/scala/is/hail/types/virtual/TContainer.scala b/hail/src/main/scala/is/hail/types/virtual/TContainer.scala index 71afda8e8cd..0a22a5c17d1 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TContainer.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TContainer.scala @@ -1,14 +1,14 @@ package is.hail.types.virtual import is.hail.annotations.Annotation -import is.hail.types.physical.PContainer abstract class TContainer extends TIterable { - override def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double, absolute: Boolean): Boolean = + override def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double, absolute: Boolean) + : Boolean = a1 == a2 || (a1 != null && a2 != null && (a1.asInstanceOf[Iterable[_]].size == a2.asInstanceOf[Iterable[_]].size) && a1.asInstanceOf[Iterable[_]].zip(a2.asInstanceOf[Iterable[_]]) - .forall { case (e1, e2) => elementType.valuesSimilar(e1, e2, tolerance, absolute) }) + .forall { case (e1, e2) => elementType.valuesSimilar(e1, e2, tolerance, absolute) }) def arrayElementsRepr: TArray } diff --git a/hail/src/main/scala/is/hail/types/virtual/TDict.scala b/hail/src/main/scala/is/hail/types/virtual/TDict.scala index 4982be9d00f..b7b140a8861 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TDict.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TDict.scala @@ -3,14 +3,15 @@ package is.hail.types.virtual import is.hail.annotations.{Annotation, ExtendedOrdering} import is.hail.backend.HailStateManager import is.hail.check.Gen -import is.hail.types.physical.PDict import is.hail.utils._ -import org.json4s.jackson.JsonMethods -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} + +import org.json4s.jackson.JsonMethods final case class TDict(keyType: Type, valueType: Type) extends TContainer { - lazy val elementType: TBaseStruct = (TStruct("key" -> keyType, "value" -> valueType)).asInstanceOf[TBaseStruct] + lazy val elementType: TBaseStruct = + (TStruct("key" -> keyType, "value" -> valueType)).asInstanceOf[TBaseStruct] override def canCompare(other: Type): Boolean = other match { case TDict(okt, ovt) => keyType.canCompare(okt) && valueType.canCompare(ovt) @@ -19,12 +20,11 @@ final case class TDict(keyType: Type, valueType: Type) extends TContainer { override def children = FastSeq(keyType, valueType) - override def unify(concrete: Type): Boolean = { + override def unify(concrete: Type): Boolean = concrete match { case TDict(kt, vt) => keyType.unify(kt) && valueType.unify(vt) case _ => false } - } override def subst() = TDict(keyType.subst(), valueType.subst()) @@ -38,7 +38,7 @@ final case class TDict(keyType: Type, valueType: Type) extends TContainer { sb.append('>') } - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("Dict[") keyType.pretty(sb, indent, compact) if (compact) @@ -50,7 +50,9 @@ final case class TDict(keyType: Type, valueType: Type) extends TContainer { } def _typeCheck(a: Any): Boolean = a == null || (a.isInstanceOf[Map[_, _]] && - a.asInstanceOf[Map[_, _]].forall { case (k, v) => keyType.typeCheck(k) && valueType.typeCheck(v) }) + a.asInstanceOf[Map[_, _]].forall { case (k, v) => + keyType.typeCheck(k) && valueType.typeCheck(v) + }) override def _showStr(a: Annotation): String = a.asInstanceOf[Map[Annotation, Annotation]] @@ -62,11 +64,14 @@ final 
case class TDict(keyType: Type, valueType: Type) extends TContainer { override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = Gen.buildableOf2[Map](Gen.zip(keyType.genValue(sm), valueType.genValue(sm))) - override def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double, absolute: Boolean): Boolean = + override def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double, absolute: Boolean) + : Boolean = a1 == a2 || (a1 != null && a2 != null && a1.asInstanceOf[Map[Any, _]].outerJoin(a2.asInstanceOf[Map[Any, _]]) .forall { case (_, (o1, o2)) => - o1.liftedZip(o2).exists { case (v1, v2) => valueType.valuesSimilar(v1, v2, tolerance, absolute) } + o1.liftedZip(o2).exists { case (v1, v2) => + valueType.valuesSimilar(v1, v2, tolerance, absolute) + } }) override def scalaClassTag: ClassTag[Map[_, _]] = classTag[Map[_, _]] @@ -85,4 +90,13 @@ final case class TDict(keyType: Type, valueType: Type) extends TContainer { } override def arrayElementsRepr: TArray = TArray(elementType) + + override def isIsomorphicTo(t: Type): Boolean = + t match { + case d: TDict => + (keyType isIsomorphicTo d.keyType) && + (valueType isIsomorphicTo d.valueType) + case _ => + false + } } diff --git a/hail/src/main/scala/is/hail/types/virtual/TFloat32.scala b/hail/src/main/scala/is/hail/types/virtual/TFloat32.scala index d38ae8b2032..250cbfea535 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TFloat32.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TFloat32.scala @@ -4,7 +4,6 @@ import is.hail.annotations._ import is.hail.backend.HailStateManager import is.hail.check.Arbitrary._ import is.hail.check.Gen -import is.hail.types.physical.PFloat32 import is.hail.utils._ import scala.reflect.{ClassTag, _} @@ -12,19 +11,21 @@ import scala.reflect.{ClassTag, _} case object TFloat32 extends TNumeric { def _toPretty = "Float32" - override def pyString(sb: StringBuilder): Unit = { + override def pyString(sb: StringBuilder): Unit = sb.append("float32") - } def _typeCheck(a: Any): Boolean = a.isInstanceOf[Float] override def _showStr(a: Annotation): String = "%.02e".format(a.asInstanceOf[Float]) - override def str(a: Annotation): String = if (a == null) "NA" else "%.5e".format(a.asInstanceOf[Float]) + override def str(a: Annotation): String = + if (a == null) "NA" else "%.5e".format(a.asInstanceOf[Float]) - override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = arbitrary[Double].map(_.toFloat) + override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = + arbitrary[Double].map(_.toFloat) - override def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double, absolute: Boolean): Boolean = + override def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double, absolute: Boolean) + : Boolean = a1 == a2 || (a1 != null && a2 != null && { val f1 = a1.asInstanceOf[Float] val f2 = a2.asInstanceOf[Float] @@ -42,4 +43,7 @@ case object TFloat32 extends TNumeric { override def mkOrdering(sm: HailStateManager, missingEqual: Boolean): ExtendedOrdering = ExtendedOrdering.extendToNull(implicitly[Ordering[Float]], missingEqual) + + override def isIsomorphicTo(t: Type): Boolean = + this == t } diff --git a/hail/src/main/scala/is/hail/types/virtual/TFloat64.scala b/hail/src/main/scala/is/hail/types/virtual/TFloat64.scala index 3691e1a2c8a..3af52119ce6 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TFloat64.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TFloat64.scala @@ -4,7 +4,6 @@ import is.hail.annotations._ import 
is.hail.backend.HailStateManager import is.hail.check.Arbitrary._ import is.hail.check.Gen -import is.hail.types.physical.PFloat64 import is.hail.utils._ import scala.reflect.{ClassTag, _} @@ -12,19 +11,20 @@ import scala.reflect.{ClassTag, _} case object TFloat64 extends TNumeric { override def _toPretty = "Float64" - override def pyString(sb: StringBuilder): Unit = { + override def pyString(sb: StringBuilder): Unit = sb.append("float64") - } def _typeCheck(a: Any): Boolean = a.isInstanceOf[Double] override def _showStr(a: Annotation): String = "%.02e".format(a.asInstanceOf[Double]) - override def str(a: Annotation): String = if (a == null) "NA" else "%.5e".format(a.asInstanceOf[Double]) + override def str(a: Annotation): String = + if (a == null) "NA" else "%.5e".format(a.asInstanceOf[Double]) override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = arbitrary[Double] - override def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double, absolute: Boolean): Boolean = + override def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double, absolute: Boolean) + : Boolean = a1 == a2 || (a1 != null && a2 != null && { val f1 = a1.asInstanceOf[Double] val f2 = a2.asInstanceOf[Double] @@ -42,4 +42,7 @@ case object TFloat64 extends TNumeric { override def mkOrdering(sm: HailStateManager, missingEqual: Boolean): ExtendedOrdering = ExtendedOrdering.extendToNull(implicitly[Ordering[Double]], missingEqual) + + override def isIsomorphicTo(t: Type): Boolean = + this == t } diff --git a/hail/src/main/scala/is/hail/types/virtual/TInt32.scala b/hail/src/main/scala/is/hail/types/virtual/TInt32.scala index e0f088487d6..0836c64ad1c 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TInt32.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TInt32.scala @@ -1,21 +1,17 @@ package is.hail.types.virtual -import is.hail.annotations.{Region, _} +import is.hail.annotations._ import is.hail.backend.HailStateManager -import is.hail.asm4s.Code import is.hail.check.Arbitrary._ import is.hail.check.Gen -import is.hail.expr.ir.EmitMethodBuilder -import is.hail.types.physical.PInt32 import scala.reflect.{ClassTag, _} case object TInt32 extends TIntegral { def _toPretty = "Int32" - override def pyString(sb: StringBuilder): Unit = { + override def pyString(sb: StringBuilder): Unit = sb.append("int32") - } def _typeCheck(a: Any): Boolean = a.isInstanceOf[Int] @@ -25,4 +21,7 @@ case object TInt32 extends TIntegral { override def mkOrdering(sm: HailStateManager, missingEqual: Boolean): ExtendedOrdering = ExtendedOrdering.extendToNull(implicitly[Ordering[Int]], missingEqual) + + override def isIsomorphicTo(t: Type): Boolean = + this == t } diff --git a/hail/src/main/scala/is/hail/types/virtual/TInt64.scala b/hail/src/main/scala/is/hail/types/virtual/TInt64.scala index e7662b0fef4..8ae68c39f7e 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TInt64.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TInt64.scala @@ -4,16 +4,14 @@ import is.hail.annotations._ import is.hail.backend.HailStateManager import is.hail.check.Arbitrary._ import is.hail.check.Gen -import is.hail.types.physical.PInt64 import scala.reflect.{ClassTag, _} case object TInt64 extends TIntegral { def _toPretty = "Int64" - override def pyString(sb: StringBuilder): Unit = { + override def pyString(sb: StringBuilder): Unit = sb.append("int64") - } def _typeCheck(a: Any): Boolean = a.isInstanceOf[Long] @@ -23,4 +21,7 @@ case object TInt64 extends TIntegral { override def mkOrdering(sm: HailStateManager, missingEqual: 
Boolean): ExtendedOrdering = ExtendedOrdering.extendToNull(implicitly[Ordering[Long]], missingEqual) + + override def isIsomorphicTo(t: Type): Boolean = + this == t } diff --git a/hail/src/main/scala/is/hail/types/virtual/TInterval.scala b/hail/src/main/scala/is/hail/types/virtual/TInterval.scala index ea2f317f036..bc874bf72a4 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TInterval.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TInterval.scala @@ -3,11 +3,9 @@ package is.hail.types.virtual import is.hail.annotations.{Annotation, ExtendedOrdering} import is.hail.backend.HailStateManager import is.hail.check.Gen -import is.hail.types.physical.PInterval -import is.hail.types.virtual.TCall.representation import is.hail.utils.{FastSeq, Interval} -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} case class TInterval(pointType: Type) extends Type { @@ -20,7 +18,8 @@ case class TInterval(pointType: Type) extends Type { pointType.pyString(sb) sb.append('>') } - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("Interval[") pointType.pretty(sb, indent, compact) sb.append("]") @@ -31,20 +30,21 @@ case class TInterval(pointType: Type) extends Type { pointType.typeCheck(i.start) && pointType.typeCheck(i.end) } - override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = Interval.gen(pointType.ordering(sm), pointType.genValue(sm)) + override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = + Interval.gen(pointType.ordering(sm), pointType.genValue(sm)) override def scalaClassTag: ClassTag[Interval] = classTag[Interval] override def mkOrdering(sm: HailStateManager, missingEqual: Boolean): ExtendedOrdering = - Interval.ordering(pointType.ordering(sm), startPrimary=true, missingEqual) + Interval.ordering(pointType.ordering(sm), startPrimary = true, missingEqual) - lazy val structRepresentation: TStruct = { + lazy val structRepresentation: TStruct = TStruct( "start" -> pointType, "end" -> pointType, "includesStart" -> TBoolean, - "includesEnd" -> TBoolean) - } + "includesEnd" -> TBoolean, + ) override def unify(concrete: Type): Boolean = concrete match { case TInterval(cpointType) => pointType.unify(cpointType) @@ -52,4 +52,10 @@ case class TInterval(pointType: Type) extends Type { } override def subst() = TInterval(pointType.subst()) + + override def isIsomorphicTo(t: Type): Boolean = + t match { + case i: TInterval => pointType isIsomorphicTo i.pointType + case _ => false + } } diff --git a/hail/src/main/scala/is/hail/types/virtual/TLocus.scala b/hail/src/main/scala/is/hail/types/virtual/TLocus.scala index b2813a6fad7..2c52dd03de5 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TLocus.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TLocus.scala @@ -1,21 +1,19 @@ package is.hail.types.virtual import is.hail.annotations._ -import is.hail.backend.{BroadcastValue, HailStateManager} +import is.hail.backend.HailStateManager import is.hail.check._ -import is.hail.types.physical.PLocus -import is.hail.types.virtual.TCall.representation import is.hail.utils._ import is.hail.variant._ -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} object TLocus { - val representation: TStruct = { + val representation: TStruct = TStruct( "contig" -> TString, - "position" -> TInt32) - } + "position" -> TInt32, + ) def schemaFromRG(rg: Option[String], required: Boolean = 
false): Type = rg match { // must match tlocus.schema_from_rg @@ -35,9 +33,11 @@ case class TLocus(rgName: String) extends Type { sb.append(prettyIdentifier(rgName)) sb.append('>') } + def _typeCheck(a: Any): Boolean = a.isInstanceOf[Locus] - override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = Locus.gen(sm.referenceGenomes(rgName)) + override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = + Locus.gen(sm.referenceGenomes(rgName)) override def scalaClassTag: ClassTag[Locus] = classTag[Locus] @@ -46,10 +46,17 @@ case class TLocus(rgName: String) extends Type { lazy val representation: TStruct = TLocus.representation - def locusOrdering(sm: HailStateManager): Ordering[Locus] = sm.referenceGenomes(rgName).locusOrdering + def locusOrdering(sm: HailStateManager): Ordering[Locus] = + sm.referenceGenomes(rgName).locusOrdering override def unify(concrete: Type): Boolean = concrete match { case TLocus(crgName) => rgName == crgName case _ => false } + + override def isIsomorphicTo(t: Type): Boolean = + t match { + case l: TLocus => rgName == l.rgName + case _ => false + } } diff --git a/hail/src/main/scala/is/hail/types/virtual/TNDArray.scala b/hail/src/main/scala/is/hail/types/virtual/TNDArray.scala index 42a30466e68..0e53b888904 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TNDArray.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TNDArray.scala @@ -1,13 +1,13 @@ package is.hail.types.virtual -import is.hail.annotations.{Annotation, ExtendedOrdering, NDArray, UnsafeIndexedSeq} +import is.hail.annotations.{Annotation, ExtendedOrdering, NDArray} import is.hail.backend.HailStateManager -import is.hail.expr.{Nat, NatBase} import is.hail.check.Gen -import is.hail.types.physical.PNDArray -import org.apache.spark.sql.Row +import is.hail.expr.{Nat, NatBase} -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} + +import org.apache.spark.sql.Row object TNDArray { def matMulNDims(l: Int, r: Int): Int = { @@ -26,16 +26,18 @@ final case class TNDArray(elementType: Type, nDimsBase: NatBase) extends Type { nDimsBase.asInstanceOf[Nat].n } - override def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double, absolute: Boolean): Boolean = { + override def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double, absolute: Boolean) + : Boolean = { if (a1 == null || a2 == null) { a1 == a2 - } - else { + } else { val aNd1 = a1.asInstanceOf[NDArray] val aNd2 = a2.asInstanceOf[NDArray] val sameShape = aNd1.shape == aNd2.shape - val sameData = aNd1.getRowMajorElements().zip(aNd2.getRowMajorElements()).forall{ case (e1, e2) => elementType.valuesSimilar(e1, e2, tolerance, absolute)} + val sameData = aNd1.getRowMajorElements().zip(aNd2.getRowMajorElements()).forall { + case (e1, e2) => elementType.valuesSimilar(e1, e2, tolerance, absolute) + } sameShape && sameData } @@ -51,7 +53,7 @@ final case class TNDArray(elementType: Type, nDimsBase: NatBase) extends Type { def _toPretty = s"NDArray[$elementType,$nDims]" - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("NDArray[") elementType.pretty(sb, indent, compact) sb.append(",") @@ -60,16 +62,17 @@ final case class TNDArray(elementType: Type, nDimsBase: NatBase) extends Type { } override def str(a: Annotation): String = { - if (a == null) "NA" else { + if (a == null) "NA" + else { val aNd = a.asInstanceOf[NDArray] val shape = aNd.shape val data = 
aNd.getRowMajorElements() - def dataToNestedString(data: Iterator[Annotation], shape: Seq[Long], sb: StringBuilder):Unit = { + def dataToNestedString(data: Iterator[Annotation], shape: Seq[Long], sb: StringBuilder) + : Unit = { if (shape.isEmpty) { sb.append(data.next().toString) - } - else { + } else { sb.append("[") val howMany = shape.head var repeat = 0 @@ -89,16 +92,16 @@ final case class TNDArray(elementType: Type, nDimsBase: NatBase) extends Type { val prettyData = stringBuilder.result() val prettyShape = "(" + shape.mkString(", ") + ")" - s"ndarray{shape=${prettyShape}, data=${prettyData}}" + s"ndarray{shape=$prettyShape, data=$prettyData}" } } - override def unify(concrete: Type): Boolean = { + override def unify(concrete: Type): Boolean = concrete match { - case TNDArray(cElementType, cNDims) => elementType.unify(cElementType) && nDimsBase.unify(cNDims) + case TNDArray(cElementType, cNDims) => + elementType.unify(cElementType) && nDimsBase.unify(cNDims) case _ => false } - } override def clear(): Unit = { elementType.clear() @@ -109,21 +112,23 @@ final case class TNDArray(elementType: Type, nDimsBase: NatBase) extends Type { override def scalaClassTag: ClassTag[Row] = classTag[Row] - def _typeCheck(a: Annotation): Boolean = { a match { + def _typeCheck(a: Annotation): Boolean = a match { case nd: NDArray => nd.forall(e => elementType.typeCheck(e)) case _ => false } - } - override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = ??? override def mkOrdering(sm: HailStateManager, missingEqual: Boolean): ExtendedOrdering = null lazy val shapeType: TTuple = TTuple(Array.fill(nDims)(TInt64): _*) - private lazy val representation = TStruct( - ("shape", shapeType), - ("data", TArray(elementType)) - ) + override def isIsomorphicTo(t: Type): Boolean = + t match { + case nda: TNDArray => + (elementType isIsomorphicTo nda.elementType) && + nDimsBase == nda.nDimsBase + case _ => + false + } } diff --git a/hail/src/main/scala/is/hail/types/virtual/TRNGState.scala b/hail/src/main/scala/is/hail/types/virtual/TRNGState.scala index b2be498c31b..87f3b9ce160 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TRNGState.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TRNGState.scala @@ -7,12 +7,17 @@ import is.hail.check.Gen case object TRNGState extends Type { override def _toPretty = "RNGState" - override def pyString(sb: StringBuilder): Unit = { + override def pyString(sb: StringBuilder): Unit = sb.append("rng_state") - } override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = ??? def _typeCheck(a: Any): Boolean = ??? - def mkOrdering(sm: HailStateManager, missingEqual: Boolean): is.hail.annotations.ExtendedOrdering = ??? + + def mkOrdering(sm: HailStateManager, missingEqual: Boolean) + : is.hail.annotations.ExtendedOrdering = ??? + def scalaClassTag: scala.reflect.ClassTag[_ <: AnyRef] = ??? 
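The implementations above give the new isIsomorphicTo predicate a purely structural meaning: leaf types compare by identity, containers recurse into their element types, and struct-like types compare field types positionally while ignoring field names (via TBaseStruct.forallZippedFields). Together they replace the recursive canCastTo method that this change deletes from Type further down. A minimal REPL-style sketch of the expected behaviour, inferred from the code in this diff rather than taken from a test:

    import is.hail.types.virtual._

    assert(TArray(TInt32) isIsomorphicTo TArray(TInt32))                  // containers recurse into the element type
    assert(!(TArray(TInt32) isIsomorphicTo TSet(TInt32)))                 // different constructors are never isomorphic
    assert(TStruct("a" -> TInt32) isIsomorphicTo TStruct("b" -> TInt32))  // field names are ignored, only field types compared
    assert(!(TStruct("a" -> TInt32) isIsomorphicTo TStruct("a" -> TFloat64)))
    assert(TDict(TString, TInt32) isIsomorphicTo TDict(TString, TInt32))  // key and value types both recurse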
+ + override def isIsomorphicTo(t: Type): Boolean = + this == t } diff --git a/hail/src/main/scala/is/hail/types/virtual/TSet.scala b/hail/src/main/scala/is/hail/types/virtual/TSet.scala index 975eb5b45b3..1154c0d3bf1 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TSet.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TSet.scala @@ -3,11 +3,10 @@ package is.hail.types.virtual import is.hail.annotations.{Annotation, ExtendedOrdering} import is.hail.backend.HailStateManager import is.hail.check.Gen -import is.hail.types.physical.PSet -import is.hail.utils._ -import org.json4s.jackson.JsonMethods -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} + +import org.json4s.jackson.JsonMethods final case class TSet(elementType: Type) extends TContainer { def _toPretty = s"Set[$elementType]" @@ -33,7 +32,7 @@ final case class TSet(elementType: Type) extends TContainer { def _typeCheck(a: Any): Boolean = a.isInstanceOf[Set[_]] && a.asInstanceOf[Set[_]].forall(elementType.typeCheck) - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("Set[") elementType.pretty(sb, indent, compact) sb.append("]") @@ -47,10 +46,10 @@ final case class TSet(elementType: Type) extends TContainer { .map { case elt => elementType.showStr(elt) } .mkString("{", ",", "}") - override def str(a: Annotation): String = JsonMethods.compact(toJSON(a)) - override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = Gen.buildableOf[Set](elementType.genValue(sm)) + override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = + Gen.buildableOf[Set](elementType.genValue(sm)) override def scalaClassTag: ClassTag[Set[AnyRef]] = classTag[Set[AnyRef]] @@ -60,4 +59,10 @@ final case class TSet(elementType: Type) extends TContainer { } override def arrayElementsRepr: TArray = TArray(elementType) + + override def isIsomorphicTo(t: Type): Boolean = + t match { + case s: TSet => elementType isIsomorphicTo s.elementType + case _ => false + } } diff --git a/hail/src/main/scala/is/hail/types/virtual/TStream.scala b/hail/src/main/scala/is/hail/types/virtual/TStream.scala index 96d0390ffb3..e56145679b8 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TStream.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TStream.scala @@ -3,9 +3,10 @@ package is.hail.types.virtual import is.hail.annotations.{Annotation, ExtendedOrdering} import is.hail.backend.HailStateManager import is.hail.check.Gen -import org.json4s.jackson.JsonMethods -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} + +import org.json4s.jackson.JsonMethods final case class TStream(elementType: Type) extends TIterable { override def pyString(sb: StringBuilder): Unit = { @@ -26,7 +27,7 @@ final case class TStream(elementType: Type) extends TIterable { override def subst() = TStream(elementType.subst()) - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean = false): Unit = { sb.append("Stream[") elementType.pretty(sb, indent, compact) sb.append("]") @@ -46,4 +47,10 @@ final case class TStream(elementType: Type) extends TIterable { throw new UnsupportedOperationException("Stream comparison is currently undefined.") override def scalaClassTag: ClassTag[Iterator[AnyRef]] = classTag[Iterator[AnyRef]] + + override def isIsomorphicTo(t: Type): Boolean = + t 
match { + case s: TStream => elementType isIsomorphicTo s.elementType + case _ => false + } } diff --git a/hail/src/main/scala/is/hail/types/virtual/TString.scala b/hail/src/main/scala/is/hail/types/virtual/TString.scala index 2d08b383a3a..c335cba9d4b 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TString.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TString.scala @@ -4,17 +4,14 @@ import is.hail.annotations._ import is.hail.backend.HailStateManager import is.hail.check.Arbitrary._ import is.hail.check.Gen -import is.hail.types.physical.PString -import is.hail.utils._ import scala.reflect.{ClassTag, _} case object TString extends Type { def _toPretty = "String" - override def pyString(sb: StringBuilder): Unit = { + override def pyString(sb: StringBuilder): Unit = sb.append("str") - } override def _showStr(a: Annotation): String = "\"" + a.asInstanceOf[String] + "\"" @@ -26,4 +23,7 @@ case object TString extends Type { override def mkOrdering(sm: HailStateManager, missingEqual: Boolean): ExtendedOrdering = ExtendedOrdering.extendToNull(implicitly[Ordering[String]], missingEqual) + + override def isIsomorphicTo(t: Type): Boolean = + this == t } diff --git a/hail/src/main/scala/is/hail/types/virtual/TStruct.scala b/hail/src/main/scala/is/hail/types/virtual/TStruct.scala index d88a97eaa0f..81e35ff1301 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TStruct.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TStruct.scala @@ -4,16 +4,20 @@ import is.hail.annotations._ import is.hail.backend.HailStateManager import is.hail.expr.ir.{Env, IRParser, IntArrayBuilder} import is.hail.utils._ + +import scala.collection.JavaConverters._ +import scala.collection.mutable + import org.apache.spark.sql.Row import org.json4s.CustomSerializer import org.json4s.JsonAST.JString -import scala.collection.JavaConverters._ -import scala.reflect.ClassTag - -class TStructSerializer extends CustomSerializer[TStruct](format => ( - { case JString(s) => IRParser.parseStructType(s) }, - { case t: TStruct => JString(t.parsableString()) })) +class TStructSerializer extends CustomSerializer[TStruct](format => + ( + { case JString(s) => IRParser.parseStructType(s) }, + { case t: TStruct => JString(t.parsableString()) }, + ) + ) object TStruct { val empty: TStruct = TStruct() @@ -29,14 +33,18 @@ object TStruct { val sNames = names.asScala.toArray val sTypes = types.asScala.toArray if (sNames.length != sTypes.length) - fatal(s"number of names does not match number of types: found ${ sNames.length } names and ${ sTypes.length } types") + fatal( + s"number of names does not match number of types: found ${sNames.length} names and ${sTypes.length} types" + ) TStruct(sNames.zip(sTypes): _*) } def concat(struct1: TStruct, struct2: TStruct): TStruct = { - struct2.fieldNames.foreach { field => assert(!struct1.hasField(field)) } - TStruct(struct1.fields ++ struct2.fields.map(field => field.copy(index = field.index + struct1.size))) + struct2.fieldNames.foreach(field => assert(!struct1.hasField(field))) + TStruct(struct1.fields ++ struct2.fields.map(field => + field.copy(index = field.index + struct1.size) + )) } } @@ -45,7 +53,14 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { lazy val types: Array[Type] = fields.map(_.typ).toArray - lazy val fieldNames: Array[String] = fields.map(_.name).toArray + val fieldNames: Array[String] = { + val seen = mutable.Set.empty[String] + fields.toArray.map { f => + val name = f.name + assert(seen.add(name), f"duplicate name '$name' found in 
'${_toPretty}'.") + name + } + } def size: Int = fields.length @@ -56,17 +71,17 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { override def canCompare(other: Type): Boolean = other match { case t: TStruct => size == t.size && fields.zip(t.fields).forall { case (f1, f2) => - f1.name == f2.name && f1.typ.canCompare(f2.typ) - } + f1.name == f2.name && f1.typ.canCompare(f2.typ) + } case _ => false } override def unify(concrete: Type): Boolean = concrete match { case TStruct(cfields) => fields.length == cfields.length && - (fields, cfields).zipped.forall { case (f, cf) => - f.unify(cf) - } + (fields, cfields).zipped.forall { case (f, cf) => + f.unify(cf) + } case _ => false } @@ -86,9 +101,9 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { if (path.isEmpty) None else (1 until path.length).foldLeft(selfField(path.head)) { case (Some(f), i) => f.typ match { - case s: TStruct => s.selfField(path(i)) - case _ => return None - } + case s: TStruct => s.selfField(path(i)) + case _ => return None + } case _ => return None } @@ -100,12 +115,15 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { case Some(f) => val (t, q) = f.typ.queryTyped(p.tail) val localIndex = f.index - (t, (a: Any) => - if (a == null) - null - else - q(a.asInstanceOf[Row].get(localIndex))) - case None => throw new AnnotationPathException(s"struct has no field ${ p.head }") + ( + t, + (a: Any) => + if (a == null) + null + else + q(a.asInstanceOf[Row].get(localIndex)), + ) + case None => throw new AnnotationPathException(s"struct has no field ${p.head}") } } } @@ -117,7 +135,6 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { val missing: Annotation = null.asInstanceOf[Annotation] - def updateField(typ: TStruct, idx: Int)(f: Inserter)(a: Annotation, v: Any): Annotation = a match { case r: Row => @@ -132,9 +149,8 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { val arr = new Array[Any](typ.size + 1) a match { case r: Row => - for (i <- 0 until typ.size) { + for (i <- 0 until typ.size) arr.update(i, r.get(i)) - } case _ => } arr(typ.size) = f(missing, v) @@ -149,15 +165,18 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { parent.selfField(name) match { case Some(Field(name, t, idx)) => ( - t match { case s: TStruct => s case _ => TStruct.empty }, + t match { + case s: TStruct => s + case _ => TStruct.empty + }, typ => parent.updateKey(name, idx, typ), - updateField(parent, idx) + updateField(parent, idx), ) case None => ( TStruct.empty, typ => parent.appendKey(name, typ), - addField(parent) + addField(parent), ) } } @@ -170,7 +189,10 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { } def structInsert(signature: Type, p: IndexedSeq[String]): TStruct = { - require(p.nonEmpty || signature.isInstanceOf[TStruct], s"tried to remap top-level struct to non-struct $signature") + require( + p.nonEmpty || signature.isInstanceOf[TStruct], + s"tried to remap top-level struct to non-struct $signature", + ) val (t, _) = insert(signature, p) t } @@ -216,7 +238,7 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { // In fieldIdxBuilder, positive integers are field indices from the left. // Negative integers are the complement of field indices from the right. 
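One behavioural consequence of the hunk above: fieldNames is now a strict val that checks name uniqueness, so a TStruct with a repeated field name fails at construction time (unless assertions are elided at compile time) instead of slipping through and surfacing later. A small inferred example; the exact assertion message comes from _toPretty:

    import is.hail.types.virtual._

    val ok  = TStruct("a" -> TInt32, "b" -> TFloat64)  // distinct names: constructed as before
    val bad = TStruct("a" -> TInt32, "a" -> TFloat64)  // AssertionError: duplicate name 'a' found in ...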
- val rightFieldIdx = other.fields.map { f => f.name -> (f.index -> f.typ) }.toMap + val rightFieldIdx = other.fields.map(f => f.name -> (f.index -> f.typ)).toMap val leftFields = fieldNames.toSet fields.foreach { f => @@ -295,10 +317,9 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { val notFound = set.filter(name => selfField(name).isEmpty).map(prettyIdentifier) if (notFound.nonEmpty) fatal( - s"""invalid struct filter operation: ${ - plural(notFound.size, s"field ${ notFound.head }", s"fields [ ${ notFound.mkString(", ") } ]") - } not found - | Existing struct fields: [ ${ fields.map(f => prettyIdentifier(f.name)).mkString(", ") } ]""".stripMargin) + s"""invalid struct filter operation: ${plural(notFound.size, s"field ${notFound.head}", s"fields [ ${notFound.mkString(", ")} ]")} not found + | Existing struct fields: [ ${fields.map(f => prettyIdentifier(f.name)).mkString(", ")} ]""".stripMargin + ) val fn = (f: Field) => if (include) @@ -310,9 +331,10 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { def ++(that: TStruct): TStruct = { val overlapping = fields.map(_.name).toSet.intersect( - that.fields.map(_.name).toSet) + that.fields.map(_.name).toSet + ) if (overlapping.nonEmpty) - fatal(s"overlapping fields in struct concatenation: ${ overlapping.mkString(", ") }") + fatal(s"overlapping fields in struct concatenation: ${overlapping.mkString(", ")}") TStruct(fields.map(f => (f.name, f.typ)) ++ that.fields.map(f => (f.name, f.typ)): _*) } @@ -357,11 +379,11 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { sb.append(prettyIdentifier(field.name)) sb.append(": ") field.typ.pyString(sb) - }) { sb.append(", ")} + })(sb.append(", ")) sb.append('}') } - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = { if (compact) { sb.append("Struct{") fields.foreachBetween(_.pretty(sb, indent, compact))(sb += ',') @@ -381,14 +403,10 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { } def select(keep: IndexedSeq[String]): (TStruct, (Row) => Row) = { - val t = TStruct(keep.map { n => - n -> field(n).typ - }: _*) + val t = TStruct(keep.map(n => n -> field(n).typ): _*) val keepIdx = keep.map(fieldIdx) - val selectF: Row => Row = { r => - Row.fromSeq(keepIdx.map(r.get)) - } + val selectF: Row => Row = { r => Row.fromSeq(keepIdx.map(r.get)) } (t, selectF) } @@ -405,7 +423,8 @@ final case class TStruct(fields: IndexedSeq[Field]) extends TBaseStruct { return identity val subStruct = subtype.asInstanceOf[TStruct] - val subsetFields = subStruct.fields.map(f => (fieldIdx(f.name), fieldType(f.name).valueSubsetter(f.typ))) + val subsetFields = + subStruct.fields.map(f => (fieldIdx(f.name), fieldType(f.name).valueSubsetter(f.typ))) { (a: Any) => val r = a.asInstanceOf[Row] diff --git a/hail/src/main/scala/is/hail/types/virtual/TTuple.scala b/hail/src/main/scala/is/hail/types/virtual/TTuple.scala index 1d52a46d2a6..afeb4f3f8d5 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TTuple.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TTuple.scala @@ -3,12 +3,15 @@ package is.hail.types.virtual import is.hail.annotations.ExtendedOrdering import is.hail.backend.HailStateManager import is.hail.utils._ + import org.apache.spark.sql.Row object TTuple { val empty: TTuple = TTuple() - def apply(args: Type*): TTuple = TTuple(args.iterator.zipWithIndex.map { case (t, i) => TupleField(i, t) }.toArray) + def 
apply(args: Type*): TTuple = TTuple(args.iterator.zipWithIndex.map { case (t, i) => + TupleField(i, t) + }.toArray) } case class TupleField(index: Int, typ: Type) @@ -16,9 +19,13 @@ case class TupleField(index: Int, typ: Type) final case class TTuple(_types: IndexedSeq[TupleField]) extends TBaseStruct { lazy val types: Array[Type] = _types.map(_.typ).toArray - lazy val fields: IndexedSeq[Field] = _types.zipWithIndex.map { case (tf, i) => Field(s"${ tf.index }", tf.typ, i) } + lazy val fields: IndexedSeq[Field] = _types.zipWithIndex.map { case (tf, i) => + Field(s"${tf.index}", tf.typ, i) + } - lazy val fieldIndex: Map[Int, Int] = _types.zipWithIndex.map { case (tf, idx) => tf.index -> idx }.toMap + lazy val fieldIndex: Map[Int, Int] = _types.zipWithIndex.map { case (tf, idx) => + tf.index -> idx + }.toMap override def mkOrdering(sm: HailStateManager, missingEqual: Boolean): ExtendedOrdering = TBaseStruct.getOrdering(sm, types, missingEqual) @@ -31,22 +38,24 @@ final case class TTuple(_types: IndexedSeq[TupleField]) extends TBaseStruct { TTuple(_types.take(newSize)) override def canCompare(other: Type): Boolean = other match { - case t: TTuple => size == t.size && _types.zip(t._types).forall { case (t1, t2) => t1.index == t2.index && t1.typ.canCompare(t2.typ) } + case t: TTuple => size == t.size && _types.zip(t._types).forall { case (t1, t2) => + t1.index == t2.index && t1.typ.canCompare(t2.typ) + } case _ => false } override def unify(concrete: Type): Boolean = concrete match { case TTuple(ctypes) => size == ctypes.length && - (types, ctypes).zipped.forall { case (t, ct) => - t.unify(ct.typ) - } + (types, ctypes).zipped.forall { case (t, ct) => + t.unify(ct.typ) + } case _ => false } override def subst() = TTuple(_types.map(tf => tf.copy(typ = tf.typ.subst()))) - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = { if (!_isCanonical) { sb.append("TupleSubset[") fields.foreachBetween { fd => @@ -57,7 +66,7 @@ final case class TTuple(_types: IndexedSeq[TupleField]) extends TBaseStruct { sb += ']' } else { sb.append("Tuple[") - _types.foreachBetween { fd => fd.typ.pretty(sb, indent, compact) }(sb += ',') + _types.foreachBetween(fd => fd.typ.pretty(sb, indent, compact))(sb += ',') sb += ']' } } @@ -69,10 +78,10 @@ final case class TTuple(_types: IndexedSeq[TupleField]) extends TBaseStruct { sb.append(field.name) sb.append(':') field.typ.pyString(sb) - }) { sb.append(", ") } + })(sb.append(", ")) sb.append(')') } else { - fields.foreachBetween({ field => field.typ.pyString(sb) }) { sb.append(", ") } + fields.foreachBetween({ field => field.typ.pyString(sb) })(sb.append(", ")) } sb.append(')') } @@ -82,7 +91,8 @@ final case class TTuple(_types: IndexedSeq[TupleField]) extends TBaseStruct { return identity val subTuple = subtype.asInstanceOf[TTuple] - val subsetFields = subTuple.fields.map(f => (fieldIndex(f.index), fields(f.index).typ.valueSubsetter(f.typ))) + val subsetFields = + subTuple.fields.map(f => (fieldIndex(f.index), fields(f.index).typ.valueSubsetter(f.typ))) { (a: Any) => val r = a.asInstanceOf[Row] diff --git a/hail/src/main/scala/is/hail/types/virtual/TUnion.scala b/hail/src/main/scala/is/hail/types/virtual/TUnion.scala index 3520ed370d9..1f1663efdcb 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TUnion.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TUnion.scala @@ -5,11 +5,12 @@ import is.hail.backend.HailStateManager import is.hail.check.Gen import 
is.hail.expr.ir.IRParser import is.hail.utils._ -import org.json4s.CustomSerializer -import org.json4s.JsonAST.JString import scala.reflect.ClassTag +import org.json4s.CustomSerializer +import org.json4s.JsonAST.JString + final case class Case(name: String, typ: Type, index: Int) { def unify(cf: Case): Boolean = @@ -17,7 +18,7 @@ final case class Case(name: String, typ: Type, index: Int) { typ.unify(cf.typ) && index == cf.index - def pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + def pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = { if (compact) { sb.append(prettyIdentifier(name)) sb.append(":") @@ -30,9 +31,12 @@ final case class Case(name: String, typ: Type, index: Int) { } } -class TUnionSerializer extends CustomSerializer[TUnion](format => ( - { case JString(s) => IRParser.parseUnionType(s) }, - { case t: TUnion => JString(t.parsableString()) })) +class TUnionSerializer extends CustomSerializer[TUnion](format => + ( + { case JString(s) => IRParser.parseUnionType(s) }, + { case t: TUnion => JString(t.parsableString()) }, + ) + ) object TUnion { val empty: TUnion = TUnion(FastSeq()) @@ -57,9 +61,9 @@ final case class TUnion(cases: IndexedSeq[Case]) extends Type { override def unify(concrete: Type): Boolean = concrete match { case TUnion(cfields) => cases.length == cfields.length && - (cases, cfields).zipped.forall { case (f, cf) => - f.unify(cf) - } + (cases, cfields).zipped.forall { case (f, cf) => + f.unify(cf) + } case _ => false } @@ -98,11 +102,11 @@ final case class TUnion(cases: IndexedSeq[Case]) extends Type { sb.append(prettyIdentifier(field.name)) sb.append(": ") field.typ.pyString(sb) - }) { sb.append(", ")} + })(sb.append(", ")) sb.append('}') } - override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + override def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = { if (compact) { sb.append("Union{") cases.foreachBetween(_.pretty(sb, indent, compact))(sb += ',') @@ -126,4 +130,13 @@ final case class TUnion(cases: IndexedSeq[Case]) extends Type { override def scalaClassTag: ClassTag[AnyRef] = ??? override def mkOrdering(sm: HailStateManager, missingEqual: Boolean): ExtendedOrdering = ??? 
+ + override def isIsomorphicTo(t: Type): Boolean = + t match { + case u: TUnion => + size == u.size && + (cases, u.cases).zipped.forall(_.typ isIsomorphicTo _.typ) + case _ => + false + } } diff --git a/hail/src/main/scala/is/hail/types/virtual/TVariable.scala b/hail/src/main/scala/is/hail/types/virtual/TVariable.scala index fffacedd7d6..25a6a628f83 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TVariable.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TVariable.scala @@ -4,7 +4,6 @@ import is.hail.annotations.{Annotation, ExtendedOrdering} import is.hail.backend.HailStateManager import is.hail.check.Gen import is.hail.types.Box -import is.hail.types.physical.PType import scala.collection.mutable import scala.reflect.ClassTag @@ -18,7 +17,8 @@ object TVariable { "float64" -> ((t: Type) => t == TFloat64), "locus" -> ((t: Type) => t.isInstanceOf[TLocus]), "struct" -> ((t: Type) => t.isInstanceOf[TStruct]), - "tuple" -> ((t: Type) => t.isInstanceOf[TTuple])) + "tuple" -> ((t: Type) => t.isInstanceOf[TTuple]), + ) private[this] val namedBoxes: mutable.Map[String, Box[Type]] = mutable.Map() @@ -50,9 +50,8 @@ final case class TVariable(name: String, cond: String = null) extends Type { else s"?$name" - override def pyString(sb: StringBuilder): Unit = { + override def pyString(sb: StringBuilder): Unit = sb.append(_toPretty) - } override def isRealizable = false @@ -64,9 +63,8 @@ final case class TVariable(name: String, cond: String = null) extends Type { override def isBound: Boolean = b.isEmpty - override def clear() { + override def clear(): Unit = b.clear() - } override def subst(): Type = { assert(b.isDefined) @@ -75,7 +73,11 @@ final case class TVariable(name: String, cond: String = null) extends Type { override def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = ??? - override def scalaClassTag: ClassTag[AnyRef] = throw new RuntimeException("TVariable is not realizable") + override def scalaClassTag: ClassTag[AnyRef] = + throw new RuntimeException("TVariable is not realizable") override def mkOrdering(sm: HailStateManager, missingEqual: Boolean): ExtendedOrdering = null + + override def isIsomorphicTo(t: Type): Boolean = + false } diff --git a/hail/src/main/scala/is/hail/types/virtual/TVoid.scala b/hail/src/main/scala/is/hail/types/virtual/TVoid.scala index e0a527da6ba..2f1c16e23ec 100644 --- a/hail/src/main/scala/is/hail/types/virtual/TVoid.scala +++ b/hail/src/main/scala/is/hail/types/virtual/TVoid.scala @@ -3,22 +3,24 @@ package is.hail.types.virtual import is.hail.annotations.{Annotation, ExtendedOrdering} import is.hail.backend.HailStateManager import is.hail.check.Gen -import is.hail.types.physical.PVoid case object TVoid extends Type { override def _toPretty = "Void" - override def pyString(sb: StringBuilder): Unit = { + override def pyString(sb: StringBuilder): Unit = sb.append("void") - } def genNonmissingValue(sm: HailStateManager): Gen[Annotation] = ??? 
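Two boundary cases from the hunks above are worth spelling out: plain case-object types answer isIsomorphicTo by identity (this == t), while an unbound TVariable is isomorphic to nothing, itself included. A short inferred sketch:

    import is.hail.types.virtual._

    assert(TInt32 isIsomorphicTo TInt32)                     // case objects compare by identity
    assert(!(TInt32 isIsomorphicTo TInt64))
    assert(!(TVariable("T") isIsomorphicTo TVariable("T")))  // a type variable is never isomorphic to anything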
override def mkOrdering(sm: HailStateManager, missingEqual: Boolean): ExtendedOrdering = null - override def scalaClassTag: scala.reflect.ClassTag[_ <: AnyRef] = throw new UnsupportedOperationException("No ClassTag for Void") + override def scalaClassTag: scala.reflect.ClassTag[_ <: AnyRef] = + throw new UnsupportedOperationException("No ClassTag for Void") - override def _typeCheck(a: Any): Boolean = throw new UnsupportedOperationException("No elements of Void") + override def _typeCheck(a: Any): Boolean = a.isInstanceOf[Unit] override def isRealizable = false + + override def isIsomorphicTo(t: Type): Boolean = + this == t } diff --git a/hail/src/main/scala/is/hail/types/virtual/Type.scala b/hail/src/main/scala/is/hail/types/virtual/Type.scala index 285b1c0d92e..ad38299366d 100644 --- a/hail/src/main/scala/is/hail/types/virtual/Type.scala +++ b/hail/src/main/scala/is/hail/types/virtual/Type.scala @@ -3,21 +3,25 @@ package is.hail.types.virtual import is.hail.annotations._ import is.hail.backend.HailStateManager import is.hail.check.{Arbitrary, Gen} -import is.hail.expr.ir._ import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex} +import is.hail.expr.ir._ import is.hail.types._ import is.hail.utils import is.hail.utils._ import is.hail.variant.ReferenceGenome -import org.apache.spark.sql.types.DataType -import org.json4s.JsonAST.JString -import org.json4s.{CustomSerializer, JValue} import scala.reflect.ClassTag -class TypeSerializer extends CustomSerializer[Type](format => ( - { case JString(s) => IRParser.parseType(s) }, - { case t: Type => JString(t.parsableString()) })) +import org.apache.spark.sql.types.DataType +import org.json4s.{CustomSerializer, JValue} +import org.json4s.JsonAST.JString + +class TypeSerializer extends CustomSerializer[Type](format => + ( + { case JString(s) => IRParser.parseType(s) }, + { case t: Type => JString(t.parsableString()) }, + ) + ) object Type { def genScalar(): Gen[Type] = @@ -32,26 +36,23 @@ object Type { def genFields(genFieldType: Gen[Type]): Gen[Array[Field]] = { Gen.buildableOf[Array]( - Gen.zip(Gen.identifier, genFieldType)) + Gen.zip(Gen.identifier, genFieldType) + ) .filter(fields => fields.map(_._1).areDistinct()) - .map(fields => fields - .iterator - .zipWithIndex - .map { case ((k, t), i) => Field(k, t, i) } - .toArray) + .map(fields => + fields + .iterator + .zipWithIndex + .map { case ((k, t), i) => Field(k, t, i) } + .toArray + ) } - def preGenStruct(genFieldType: Gen[Type]): Gen[TStruct] = { - for (fields <- genFields(genFieldType)) yield { - TStruct(fields) - } - } + def preGenStruct(genFieldType: Gen[Type]): Gen[TStruct] = + for (fields <- genFields(genFieldType)) yield TStruct(fields) - def preGenTuple(genFieldType: Gen[Type]): Gen[TTuple] = { - for (fields <- genFields(genFieldType)) yield { - TTuple(fields.map(_.typ): _*) - } - } + def preGenTuple(genFieldType: Gen[Type]): Gen[TTuple] = + for (fields <- genFields(genFieldType)) yield TTuple(fields.map(_.typ): _*) private val defaultRequiredGenRatio = 0.2 def genStruct: Gen[TStruct] = Gen.coin(defaultRequiredGenRatio).flatMap(c => preGenStruct(genArb)) @@ -65,18 +66,28 @@ object Type { Gen.frequency( (4, genScalar()), (1, genComplexType()), - (1, genArb.map { - TArray(_) - }), - (1, genArb.map { - TSet(_) - }), - (1, genArb.map { - TInterval(_) - }), + ( + 1, + genArb.map { + TArray(_) + }, + ), + ( + 1, + genArb.map { + TSet(_) + }, + ), + ( + 1, + genArb.map { + TInterval(_) + }, + ), (1, preGenTuple(genArb)), (1, Gen.zip(genRequired, genArb).map { case (k, v) => TDict(k, 
v) }), - (1, genTStruct.resize(size))) + (1, genTStruct.resize(size)), + ) } def preGenArb(genStruct: Gen[TStruct] = genStruct): Gen[Type] = @@ -98,7 +109,8 @@ object Type { v <- t.genValue(sm).resize(y) } yield (t, v) - implicit def arbType = Arbitrary(genArb) + implicit def arbType: Arbitrary[Type] = + Arbitrary(genArb) } abstract class Type extends BaseType with Serializable { @@ -127,28 +139,25 @@ abstract class Type extends BaseType with Serializable { def query(fields: String*): Querier = query(fields.toList) def query(path: List[String]): Querier = { - val (t, q) = queryTyped(path) + val (_, q) = queryTyped(path) q } def queryTyped(fields: String*): (Type, Querier) = queryTyped(fields.toList) - def queryTyped(path: List[String]): (Type, Querier) = { + def queryTyped(path: List[String]): (Type, Querier) = if (path.nonEmpty) - throw new AnnotationPathException(s"invalid path ${ path.mkString(".") } from type ${ this }") + throw new AnnotationPathException(s"invalid path ${path.mkString(".")} from type ${this}") else (this, identity[Annotation]) - } - final def pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + final def pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = _pretty(sb, indent, compact) - } def _toPretty: String - def _pretty(sb: StringBuilder, indent: Int, compact: Boolean) { + def _pretty(sb: StringBuilder, indent: Int, compact: Boolean): Unit = sb.append(_toPretty) - } def schema: DataType = SparkAnnotationImpex.exportType(this) @@ -175,8 +184,14 @@ abstract class Type extends BaseType with Serializable { def isRealizable: Boolean = children.forall(_.isRealizable) - /* compare values for equality, but compare Float and Double values by the absolute value of their difference is within tolerance or with D_== */ - def valuesSimilar(a1: Annotation, a2: Annotation, tolerance: Double = utils.defaultTolerance, absolute: Boolean = false): Boolean = a1 == a2 + /* compare values for equality, but compare Float and Double values by the absolute value of their + * difference is within tolerance or with D_== */ + def valuesSimilar( + a1: Annotation, + a2: Annotation, + tolerance: Double = utils.defaultTolerance, + absolute: Boolean = false, + ): Boolean = a1 == a2 def scalaClassTag: ClassTag[_ <: AnyRef] @@ -185,18 +200,17 @@ abstract class Type extends BaseType with Serializable { def mkOrdering(sm: HailStateManager, missingEqual: Boolean = true): ExtendedOrdering @transient protected var ord: ExtendedOrdering = _ + def ordering(sm: HailStateManager): ExtendedOrdering = { if (ord == null) ord = mkOrdering(sm) ord } - def jsonReader: JSONReader[Annotation] = new JSONReader[Annotation] { - def fromJSON(a: JValue): Annotation = JSONAnnotationImpex.importAnnotation(a, self) - } + def jsonReader: JSONReader[Annotation] = + (a: JValue) => JSONAnnotationImpex.importAnnotation(a, self) - def jsonWriter: JSONWriter[Annotation] = new JSONWriter[Annotation] { - def toJSON(pk: Annotation): JValue = JSONAnnotationImpex.exportAnnotation(pk, self) - } + def jsonWriter: JSONWriter[Annotation] = + (pk: Annotation) => JSONAnnotationImpex.exportAnnotation(pk, self) def _typeCheck(a: Any): Boolean @@ -207,31 +221,5 @@ abstract class Type extends BaseType with Serializable { identity } - def canCastTo(t: Type): Boolean = this match { - case TInterval(tt1) => t match { - case TInterval(tt2) => tt1.canCastTo(tt2) - case _ => false - } - case TStruct(f1) => t match { - case TStruct(f2) => f1.size == f2.size && f1.indices.forall(i => f1(i).typ.canCastTo(f2(i).typ)) - case _ => false 
- } - case TTuple(f1) => t match { - case TTuple(f2) => f1.size == f2.size && f1.indices.forall(i => f1(i).typ.canCastTo(f2(i).typ)) - case _ => false - } - case TArray(t1) => t match { - case TArray(t2) => t1.canCastTo(t2) - case _ => false - } - case TSet(t1) => t match { - case TSet(t2) => t1.canCastTo(t2) - case _ => false - } - case TDict(k1, v1) => t match { - case TDict(k2, v2) => k1.canCastTo(k2) && v1.canCastTo(v2) - case _ => false - } - case _ => this == t - } + def isIsomorphicTo(t: Type): Boolean } diff --git a/hail/src/main/scala/is/hail/utils/AbsoluteFuzzyComparable.scala b/hail/src/main/scala/is/hail/utils/AbsoluteFuzzyComparable.scala index 3b1b0e7f524..3709b4f9b0e 100644 --- a/hail/src/main/scala/is/hail/utils/AbsoluteFuzzyComparable.scala +++ b/hail/src/main/scala/is/hail/utils/AbsoluteFuzzyComparable.scala @@ -1,19 +1,21 @@ package is.hail.utils trait AbsoluteFuzzyComparable[A] { - def absoluteEq(tolerance: Double, x: A, y: A) : Boolean + def absoluteEq(tolerance: Double, x: A, y: A): Boolean } object AbsoluteFuzzyComparable { - def absoluteEq[T](tolerance: Double, x: T, y: T)(implicit afc: AbsoluteFuzzyComparable[T]): Boolean = + def absoluteEq[T](tolerance: Double, x: T, y: T)(implicit afc: AbsoluteFuzzyComparable[T]) + : Boolean = afc.absoluteEq(tolerance, x, y) implicit object afcDoubles extends AbsoluteFuzzyComparable[Double] { def absoluteEq(tolerance: Double, x: Double, y: Double) = Math.abs(x - y) <= tolerance } - implicit def afcMaps[K, V](implicit vRFC: AbsoluteFuzzyComparable[V]): AbsoluteFuzzyComparable[Map[K, V]] = + implicit def afcMaps[K, V](implicit vRFC: AbsoluteFuzzyComparable[V]) + : AbsoluteFuzzyComparable[Map[K, V]] = new AbsoluteFuzzyComparable[Map[K, V]] { def absoluteEq(tolerance: Double, x: Map[K, V], y: Map[K, V]) = x.keySet == y.keySet && x.keys.forall(k => vRFC.absoluteEq(tolerance, x(k), y(k))) diff --git a/hail/src/main/scala/is/hail/utils/ArrayOfByteArrayOutputStream.scala b/hail/src/main/scala/is/hail/utils/ArrayOfByteArrayOutputStream.scala index a029f273f23..ccd55fd5b65 100644 --- a/hail/src/main/scala/is/hail/utils/ArrayOfByteArrayOutputStream.scala +++ b/hail/src/main/scala/is/hail/utils/ArrayOfByteArrayOutputStream.scala @@ -6,22 +6,18 @@ class ArrayOfByteArrayOutputStream(initialBufferCapacity: Int) extends OutputStr val MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8; - /** - * The buffer where data is stored. - */ + /** The buffer where data is stored. */ protected var buf = new BoxedArrayBuilder[ByteArrayOutputStream](1) buf += new ByteArrayOutputStream(initialBufferCapacity) protected var bytesInCurrentArray = 0 protected var currentArray = buf(0) - /** - * Creates a new byte array output stream. The buffer capacity is - * initially 32 bytes, though its size increases if necessary. + /** Creates a new byte array output stream. The buffer capacity is initially 32 bytes, though its + * size increases if necessary. 
*/ - def this() { + def this() = this(32) - } def ensureNextByte(): Unit = { if (bytesInCurrentArray == MAX_ARRAY_SIZE) { @@ -52,7 +48,6 @@ class ArrayOfByteArrayOutputStream(initialBufferCapacity: Int) extends OutputStr } } - def toByteArrays(): Array[Array[Byte]] = { + def toByteArrays(): Array[Array[Byte]] = buf.result().map(_.toByteArray) - } } diff --git a/hail/src/main/scala/is/hail/utils/ArrayStack.scala b/hail/src/main/scala/is/hail/utils/ArrayStack.scala index 6c388cae087..3ee6131745a 100644 --- a/hail/src/main/scala/is/hail/utils/ArrayStack.scala +++ b/hail/src/main/scala/is/hail/utils/ArrayStack.scala @@ -15,9 +15,8 @@ final class ObjectArrayStack[T <: AnyRef](hintSize: Int = 16)(implicit tct: Clas def nonEmpty: Boolean = size_ > 0 - def clear(): Unit = { + def clear(): Unit = size_ = 0 - } def top: T = { assert(size_ > 0) @@ -26,7 +25,7 @@ final class ObjectArrayStack[T <: AnyRef](hintSize: Int = 16)(implicit tct: Clas def topOption: Option[T] = if (size_ > 0) Some(top) else None - def push(x: T) { + def push(x: T): Unit = { if (size_ == a.length) { val newA = new Array[T](size_ * 2) System.arraycopy(a, 0, newA, 0, size_) @@ -44,7 +43,7 @@ final class ObjectArrayStack[T <: AnyRef](hintSize: Int = 16)(implicit tct: Clas x } - def update(i: Int, x: T) { + def update(i: Int, x: T): Unit = { assert(i >= 0 && i < size_) a(size_ - i - 1) = x } @@ -70,9 +69,8 @@ final class LongArrayStack(hintSize: Int = 16) { def nonEmpty: Boolean = size_ > 0 - def clear(): Unit = { + def clear(): Unit = size_ = 0 - } def top: Long = { assert(size_ > 0) @@ -81,7 +79,7 @@ final class LongArrayStack(hintSize: Int = 16) { def topOption: Option[Long] = if (size_ > 0) Some(top) else None - def push(x: Long) { + def push(x: Long): Unit = { if (size_ == a.length) { val newA = new Array[Long](size_ * 2) System.arraycopy(a, 0, newA, 0, size_) @@ -99,7 +97,7 @@ final class LongArrayStack(hintSize: Int = 16) { x } - def update(i: Int, x: Long) { + def update(i: Int, x: Long): Unit = { assert(i >= 0 && i < size_) a(size_ - i - 1) = x } @@ -125,9 +123,8 @@ final class IntArrayStack(hintSize: Int = 16) { def nonEmpty: Boolean = size_ > 0 - def clear(): Unit = { + def clear(): Unit = size_ = 0 - } def top: Int = { assert(size_ > 0) @@ -136,7 +133,7 @@ final class IntArrayStack(hintSize: Int = 16) { def topOption: Option[Int] = if (size_ > 0) Some(top) else None - def push(x: Int) { + def push(x: Int): Unit = { if (size_ == a.length) { val newA = new Array[Int](size_ * 2) System.arraycopy(a, 0, newA, 0, size_) @@ -154,7 +151,7 @@ final class IntArrayStack(hintSize: Int = 16) { x } - def update(i: Int, x: Int) { + def update(i: Int, x: Int): Unit = { assert(i >= 0 && i < size_) a(size_ - i - 1) = x } @@ -165,4 +162,4 @@ final class IntArrayStack(hintSize: Int = 16) { } def toArray: Array[Int] = (0 until size).map(apply).toArray -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/utils/BinaryHeap.scala b/hail/src/main/scala/is/hail/utils/BinaryHeap.scala index 3bb2eec6e31..c12aad44f31 100644 --- a/hail/src/main/scala/is/hail/utils/BinaryHeap.scala +++ b/hail/src/main/scala/is/hail/utils/BinaryHeap.scala @@ -1,9 +1,10 @@ package is.hail.utils -import Math.signum import scala.collection.mutable import scala.reflect.ClassTag +import java.lang.Math.signum + class BinaryHeap[T: ClassTag](minimumCapacity: Int = 32, maybeTieBreaker: (T, T) => Double = null) { private var ts: Array[T] = new Array[T](minimumCapacity) private var ranks: Array[Long] = new Array[Long](minimumCapacity) @@ -32,11 +33,13 @@ 
class BinaryHeap[T: ClassTag](minimumCapacity: Int = 32, maybeTieBreaker: (T, T) def nonEmpty: Boolean = next != 0 override def toString(): String = - s"values: ${ ts.slice(0, next): IndexedSeq[T] }; ranks: ${ ranks.slice(0, next): IndexedSeq[Long] }" + s"values: ${ts.slice(0, next): IndexedSeq[T]}; ranks: ${ranks.slice(0, next): IndexedSeq[Long]}" - def insert(t: T, r: Long) { + def insert(t: T, r: Long): Unit = { if (m.contains(t)) - throw new RuntimeException(s"key $t already exists with priority ${ ranks(m(t)) }, cannot add it again with priority $r") + throw new RuntimeException( + s"key $t already exists with priority ${ranks(m(t))}, cannot add it again with priority $r" + ) maybeGrow() put(next, t, r) bubbleUp(next) @@ -76,14 +79,14 @@ class BinaryHeap[T: ClassTag](minimumCapacity: Int = 32, maybeTieBreaker: (T, T) def getPriority(t: T): Long = ranks(m(t)) - def decreasePriorityTo(t: T, r: Long) { + def decreasePriorityTo(t: T, r: Long): Unit = { val i = m(t) assert(ranks(i) > r) ranks(i) = r bubbleDown(i) } - def decreasePriority(t: T, f: (Long) => Long) { + def decreasePriority(t: T, f: (Long) => Long): Unit = { val i = m(t) val r = f(ranks(i)) assert(ranks(i) > r) @@ -91,14 +94,14 @@ class BinaryHeap[T: ClassTag](minimumCapacity: Int = 32, maybeTieBreaker: (T, T) bubbleDown(i) } - def increasePriorityTo(t: T, r: Long) { + def increasePriorityTo(t: T, r: Long): Unit = { val i = m(t) assert(ranks(i) < r) ranks(i) = r bubbleUp(i) } - def increasePriority(t: T, f: (Long) => Long) { + def increasePriority(t: T, f: (Long) => Long): Unit = { val i = m(t) val r = f(ranks(i)) assert(ranks(i) < r) @@ -118,13 +121,13 @@ class BinaryHeap[T: ClassTag](minimumCapacity: Int = 32, maybeTieBreaker: (T, T) private def parent(i: Int) = if (i == 0) 0 else (i - 1) >>> 1 - private def put(to: Int, t: T, rank: Long) { + private def put(to: Int, t: T, rank: Long): Unit = { ts(to) = t ranks(to) = rank m(t) = to } - private def swap(i: Int, j: Int) { + private def swap(i: Int, j: Int): Unit = { val tempt = ts(i) ts(i) = ts(j) ts(j) = tempt @@ -135,7 +138,7 @@ class BinaryHeap[T: ClassTag](minimumCapacity: Int = 32, maybeTieBreaker: (T, T) m(ts(i)) = i } - private def maybeGrow() { + private def maybeGrow(): Unit = { if (next >= ts.length) { val ts2 = new Array[T](ts.length << 1) val ranks2 = new Array[Long](ts.length << 1) @@ -146,7 +149,7 @@ class BinaryHeap[T: ClassTag](minimumCapacity: Int = 32, maybeTieBreaker: (T, T) } } - private def maybeShrink() { + private def maybeShrink(): Unit = { if (next >= minimumCapacity && next < (ts.length >>> 2)) { val ts2 = new Array[T](ts.length >>> 2) val ranks2 = new Array[Long](ts.length >>> 2) @@ -157,7 +160,7 @@ class BinaryHeap[T: ClassTag](minimumCapacity: Int = 32, maybeTieBreaker: (T, T) } } - private def bubbleUp(i: Int) { + private def bubbleUp(i: Int): Unit = { var current = i var p = parent(current) while (ranks(current) > ranks(p) || isLeftFavoredTie(current, p)) { @@ -167,7 +170,7 @@ class BinaryHeap[T: ClassTag](minimumCapacity: Int = 32, maybeTieBreaker: (T, T) } } - private def bubbleDown(i: Int) { + private def bubbleDown(i: Int): Unit = { var current = i var largest = current var continue = false @@ -175,9 +178,19 @@ class BinaryHeap[T: ClassTag](minimumCapacity: Int = 32, maybeTieBreaker: (T, T) val leftChild = (current << 1) + 1 val rightChild = (current << 1) + 2 - if (leftChild < next && (ranks(leftChild) > ranks(largest) || isLeftFavoredTie(leftChild, largest))) + if ( + leftChild < next && (ranks(leftChild) > ranks(largest) || 
isLeftFavoredTie( + leftChild, + largest, + )) + ) largest = leftChild - if (rightChild < next && (ranks(rightChild) > ranks(largest) || isLeftFavoredTie(rightChild, largest))) + if ( + rightChild < next && (ranks(rightChild) > ranks(largest) || isLeftFavoredTie( + rightChild, + largest, + )) + ) largest = rightChild if (largest != current) { @@ -189,11 +202,10 @@ class BinaryHeap[T: ClassTag](minimumCapacity: Int = 32, maybeTieBreaker: (T, T) } while (continue) } - def checkHeapProperty() { + def checkHeapProperty(): Unit = checkHeapProperty(0) - } - private def checkHeapProperty(current: Int) { + private def checkHeapProperty(current: Int): Unit = { val leftChild = (current << 1) + 1 val rightChild = (current << 1) + 2 if (leftChild < next) { @@ -211,9 +223,10 @@ class BinaryHeap[T: ClassTag](minimumCapacity: Int = 32, maybeTieBreaker: (T, T) } } - private def assertHeapProperty(child: Int, parent: Int) { - assert(ranks(child) <= ranks(parent), - s"heap property violated at parent $parent, child $child: ${ ts(parent) }:${ ranks(parent) } < ${ ts(child) }:${ ranks(child) }") - } + private def assertHeapProperty(child: Int, parent: Int): Unit = + assert( + ranks(child) <= ranks(parent), + s"heap property violated at parent $parent, child $child: ${ts(parent)}:${ranks(parent)} < ${ts(child)}:${ranks(child)}", + ) } diff --git a/hail/src/main/scala/is/hail/utils/BitVector.scala b/hail/src/main/scala/is/hail/utils/BitVector.scala index ea7c3c970e3..0ec4e7e27a2 100644 --- a/hail/src/main/scala/is/hail/utils/BitVector.scala +++ b/hail/src/main/scala/is/hail/utils/BitVector.scala @@ -14,21 +14,20 @@ final class BitVector(val length: Int) { (a(i / 64) & (1L << (i & 63))) != 0 } - def set(i: Int) { + def set(i: Int): Unit = { if (i < 0 || i >= length) throw new ArrayIndexOutOfBoundsException a(i / 64) |= (1L << (i & 63)) } - def reset(i: Int) { + def reset(i: Int): Unit = { if (i < 0 || i >= length) throw new ArrayIndexOutOfBoundsException a(i / 64) &= ~(1L << (i & 63)) } - def clear() { + def clear(): Unit = util.Arrays.fill(a, 0) - } } diff --git a/hail/src/main/scala/is/hail/utils/Bitstring.scala b/hail/src/main/scala/is/hail/utils/Bitstring.scala index 09427161b62..38c7b0bbe31 100644 --- a/hail/src/main/scala/is/hail/utils/Bitstring.scala +++ b/hail/src/main/scala/is/hail/utils/Bitstring.scala @@ -48,15 +48,13 @@ case class Bitstring(contents: IndexedSeq[Long], bitsInLastWord: Int) { if (bitsInLastWord < 64) { val newNumWords = (length + rhs.length + 63) >> 6 val newContents = Array.ofDim[Long](newNumWords) - for (i <- 0 until (numWords - 2)) { + for (i <- 0 until (numWords - 2)) newContents(i) = contents(i) - } newContents(numWords - 1) = contents.last & (rhs.contents.head >>> bitsInLastWord) - for (i <- 0 until (rhs.numWords - 2)) { + for (i <- 0 until (rhs.numWords - 2)) newContents(numWords + i) = (rhs.contents(i) << (64 - bitsInLastWord)) & (rhs.contents(i + 1) >>> bitsInLastWord) - } var newBitsInLastWord = bitsInLastWord + rhs.bitsInLastWord if (newBitsInLastWord > 64) { newContents(numWords + rhs.numWords - 1) = rhs.contents.last << (64 - bitsInLastWord) diff --git a/hail/src/main/scala/is/hail/utils/BoxedArrayBuilder.scala b/hail/src/main/scala/is/hail/utils/BoxedArrayBuilder.scala index 8fe07fab815..2e251f32b4f 100644 --- a/hail/src/main/scala/is/hail/utils/BoxedArrayBuilder.scala +++ b/hail/src/main/scala/is/hail/utils/BoxedArrayBuilder.scala @@ -6,7 +6,8 @@ object BoxedArrayBuilder { final val defaultInitialCapacity: Int = 16 } -final class BoxedArrayBuilder[T <: 
AnyRef](initialCapacity: Int)(implicit tct: ClassTag[T]) extends Serializable { +final class BoxedArrayBuilder[T <: AnyRef](initialCapacity: Int)(implicit tct: ClassTag[T]) + extends Serializable { private[utils] var b: Array[T] = new Array[T](initialCapacity) private[utils] var size_ : Int = 0 @@ -25,12 +26,12 @@ final class BoxedArrayBuilder[T <: AnyRef](initialCapacity: Int)(implicit tct: C b(i) } - def update(i: Int, x: T) { + def update(i: Int, x: T): Unit = { require(i >= 0 && i < size) b(i) = x } - def ensureCapacity(n: Int) { + def ensureCapacity(n: Int): Unit = { if (b.length < n) { val newCapacity = (b.length * 2).max(n) val newb = new Array[T](newCapacity) @@ -39,9 +40,8 @@ final class BoxedArrayBuilder[T <: AnyRef](initialCapacity: Int)(implicit tct: C } } - def clear(): Unit = { + def clear(): Unit = size_ = 0 - } def clearAndResize(): Unit = { size_ = 0 @@ -51,7 +51,7 @@ final class BoxedArrayBuilder[T <: AnyRef](initialCapacity: Int)(implicit tct: C def +=(x: T): Unit = push(x) - def push(x: T) { + def push(x: T): Unit = { ensureCapacity(size_ + 1) b(size_) = x size_ += 1 @@ -61,7 +61,7 @@ final class BoxedArrayBuilder[T <: AnyRef](initialCapacity: Int)(implicit tct: C def ++=(a: Array[T]): Unit = ++=(a, a.length) - def ++=(a: Array[T], length: Int) { + def ++=(a: Array[T], length: Int): Unit = { require(length >= 0 && length <= a.length) ensureCapacity(size_ + length) System.arraycopy(a, 0, b, size_, length) diff --git a/hail/src/main/scala/is/hail/utils/BufferedAggregatorIterator.scala b/hail/src/main/scala/is/hail/utils/BufferedAggregatorIterator.scala index b5c3c6661b7..7190686b520 100644 --- a/hail/src/main/scala/is/hail/utils/BufferedAggregatorIterator.scala +++ b/hail/src/main/scala/is/hail/utils/BufferedAggregatorIterator.scala @@ -12,7 +12,7 @@ class BufferedAggregatorIterator[T, V, U, K]( makeKey: T => K, sequence: (T, V) => Unit, serializeAndCleanup: V => U, - bufferSize: Int + bufferSize: Int, ) extends Iterator[(K, U)] { private val fb = it.toFlipbookIterator @@ -22,18 +22,17 @@ class BufferedAggregatorIterator[T, V, U, K]( private val buffer = new util.LinkedHashMap[K, V]( (bufferSize / BufferedAggregatorIterator.loadFactor).toInt + 1, BufferedAggregatorIterator.loadFactor, - true) { - override def removeEldestEntry(eldest: util.Map.Entry[K, V]): Boolean = { + true, + ) { + override def removeEldestEntry(eldest: util.Map.Entry[K, V]): Boolean = if (size() > bufferSize) { popped = eldest true } else false - } } - def hasNext: Boolean = { + def hasNext: Boolean = fb.isValid || buffer.size() > 0 - } def next(): (K, U) = { if (!hasNext) @@ -60,4 +59,4 @@ class BufferedAggregatorIterator[T, V, U, K]( buffer.remove(next.getKey) next.getKey -> serializeAndCleanup(next.getValue) } -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/utils/BytePacker.scala b/hail/src/main/scala/is/hail/utils/BytePacker.scala index 5b0bc60cff5..51a9b817c85 100644 --- a/hail/src/main/scala/is/hail/utils/BytePacker.scala +++ b/hail/src/main/scala/is/hail/utils/BytePacker.scala @@ -5,9 +5,8 @@ import scala.collection.mutable class BytePacker { val slots = new mutable.TreeSet[(Long, Long)] - def insertSpace(size: Long, start: Long) { + def insertSpace(size: Long, start: Long): Unit = slots += size -> start - } def getSpace(size: Long, alignment: Long): Option[Long] = { @@ -40,4 +39,4 @@ class BytePacker { } None } -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/utils/ByteTrackingInputStream.scala 
b/hail/src/main/scala/is/hail/utils/ByteTrackingInputStream.scala index 74ec1b10a7a..7dd815bb78c 100644 --- a/hail/src/main/scala/is/hail/utils/ByteTrackingInputStream.scala +++ b/hail/src/main/scala/is/hail/utils/ByteTrackingInputStream.scala @@ -1,9 +1,9 @@ package is.hail.utils -import java.io.InputStream - import is.hail.io.fs.Seekable +import java.io.InputStream + class ByteTrackingInputStream(base: InputStream) extends InputStream { var bytesRead = 0L diff --git a/hail/src/main/scala/is/hail/utils/Cache.scala b/hail/src/main/scala/is/hail/utils/Cache.scala index c2f2440ea53..3aa40a6c547 100644 --- a/hail/src/main/scala/is/hail/utils/Cache.scala +++ b/hail/src/main/scala/is/hail/utils/Cache.scala @@ -11,11 +11,11 @@ class Cache[K, V](capacity: Int) { override def removeEldestEntry(eldest: Entry[K, V]): Boolean = size() > capacity } - def get(k: K): Option[V] = synchronized { Option(m.get(k)) } + def get(k: K): Option[V] = synchronized(Option(m.get(k))) - def +=(p: (K, V)): Unit = synchronized { m.put(p._1, p._2) } + def +=(p: (K, V)): Unit = synchronized(m.put(p._1, p._2)) - def size: Int = synchronized { m.size() } + def size: Int = synchronized(m.size()) } class LongToRegionValueCache(capacity: Int) extends Closeable { @@ -53,4 +53,4 @@ class LongToRegionValueCache(capacity: Int) extends Closeable { } def close(): Unit = free() -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/utils/CompressionUtils.scala b/hail/src/main/scala/is/hail/utils/CompressionUtils.scala index 58838695306..66cf31a4ac7 100644 --- a/hail/src/main/scala/is/hail/utils/CompressionUtils.scala +++ b/hail/src/main/scala/is/hail/utils/CompressionUtils.scala @@ -1,10 +1,11 @@ package is.hail.utils -import com.github.luben.zstd.Zstd import is.hail.expr.ir.ByteArrayBuilder import java.util.zip.{Deflater, Inflater} +import com.github.luben.zstd.Zstd + object CompressionUtils { def compressZlib(bb: ByteArrayBuilder, input: Array[Byte]): Int = { val compressor = new Deflater() @@ -24,7 +25,8 @@ object CompressionUtils { val maxSize = Zstd.compressBound(input.length).toInt val sizeBefore = bb.size bb.ensureCapacity(bb.size + maxSize) - val compressedSize = Zstd.compressByteArray(bb.b, sizeBefore, maxSize, input, 0, input.length, 5).toInt + val compressedSize = + Zstd.compressByteArray(bb.b, sizeBefore, maxSize, input, 0, input.length, 5).toInt bb.setSizeUnchecked(sizeBefore + compressedSize) compressedSize } @@ -34,9 +36,8 @@ object CompressionUtils { val inflater = new Inflater inflater.setInput(input) var off = 0 - while (off < expansion.length) { + while (off < expansion.length) off += inflater.inflate(expansion, off, expansion.length - off) - } expansion } diff --git a/hail/src/main/scala/is/hail/utils/Context.scala b/hail/src/main/scala/is/hail/utils/Context.scala index 7f32cb825ba..2dbbcfcc77d 100644 --- a/hail/src/main/scala/is/hail/utils/Context.scala +++ b/hail/src/main/scala/is/hail/utils/Context.scala @@ -17,38 +17,41 @@ case class Context(line: String, file: String, position: Option[Int]) { e match { case _: HailException => fatal( - s"""$locationString: ${ e.getMessage } - | offending line: @1""".stripMargin, line, e) + s"""$locationString: ${e.getMessage} + | offending line: @1""".stripMargin, + line, + e, + ) case _ => fatal( - s"""$locationString: caught ${ e.getClass.getName }: ${ e.getMessage } - | offending line: @1""".stripMargin, line, e) + s"""$locationString: caught ${e.getClass.getName}: ${e.getMessage} + | offending line: @1""".stripMargin, + line, + e, + ) } } } case class 
WithContext[T](value: T, source: Context) { - def map[U](f: T => U): WithContext[U] = { - try { + def map[U](f: T => U): WithContext[U] = + try copy[U](value = f(value)) - } catch { + catch { case e: Throwable => source.wrapException(e) } - } - def wrap[U](f: T => U): U = { - try { + def wrap[U](f: T => U): U = + try f(value) - } catch { + catch { case e: Throwable => source.wrapException(e) } - } - - def foreach(f: T => Unit) { - try { + + def foreach(f: T => Unit): Unit = + try f(value) - } catch { + catch { case e: Exception => source.wrapException(e) } - } -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/utils/DateFormatUtils.scala b/hail/src/main/scala/is/hail/utils/DateFormatUtils.scala index 9f7d0c54da6..0fe0bc16d75 100644 --- a/hail/src/main/scala/is/hail/utils/DateFormatUtils.scala +++ b/hail/src/main/scala/is/hail/utils/DateFormatUtils.scala @@ -1,10 +1,10 @@ package is.hail.utils import java.time.DayOfWeek -import java.util.Locale +import java.time.chrono.Chronology import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder, TextStyle} import java.time.temporal.{ChronoField, WeekFields} -import java.time.chrono.{ChronoLocalDate, Chronology} +import java.util.Locale object DateFormatUtils { def parseDateFormat(str: String, locale: Locale): DateTimeFormatter = { @@ -43,7 +43,7 @@ object DateFormatUtils { case 's' => fmt.appendValue(ChronoField.INSTANT_SECONDS) case 'T' => alternating("H:M:S") case 't' => char('\t') - case 'U' => fmt.appendValue(SUNDAY_START_ALWAYS.weekOfYear(), 2) //Sunday first day + case 'U' => fmt.appendValue(SUNDAY_START_ALWAYS.weekOfYear(), 2) // Sunday first day case 'u' => fmt.appendValue(WeekFields.ISO.dayOfWeek()) // 1-7, starts on Monday case 'V' => fmt.appendValue(WeekFields.ISO.weekOfWeekBasedYear(), 2) case 'v' => alternating("e-b-Y") @@ -53,7 +53,8 @@ object DateFormatUtils { case 'Z' => fmt.appendZoneId() case 'z' => fmt.appendOffsetId() case 'E' | 'O' => char(c) // Python just keeps these two letters for whatever reason. 
- case 'C' | 'c' | 'G' | 'g' | 'w'| 'X' | 'x' => throw new HailException(s"Currently unsupported time formatting character: $c") + case 'C' | 'c' | 'G' | 'g' | 'w' | 'X' | 'x' => + throw new HailException(s"Currently unsupported time formatting character: $c") case d => fatal(s"invalid time format descriptor: $d") } } diff --git a/hail/src/main/scala/is/hail/utils/EitherIsAMonad.scala b/hail/src/main/scala/is/hail/utils/EitherIsAMonad.scala index bf12151e8b9..c5adc2e807e 100644 --- a/hail/src/main/scala/is/hail/utils/EitherIsAMonad.scala +++ b/hail/src/main/scala/is/hail/utils/EitherIsAMonad.scala @@ -8,33 +8,33 @@ object EitherIsAMonad { final class EitherOps[A, B](val eab: Either[A, B]) extends AnyVal { def foreach(f: B => Unit): Unit = eab match { - case Left(_) => () + case Left(_) => () case Right(b) => f(b) } def getOrElse(default: => B): B = eab match { - case Left(_) => default + case Left(_) => default case Right(b) => b } def valueOr(f: A => B): B = eab match { - case Left(a) => f(a) + case Left(a) => f(a) case Right(b) => b } def toOption: Option[B] = eab match { - case Left(_) => None + case Left(_) => None case Right(b) => Some(b) } def map[C](f: B => C): Either[A, C] = eab match { case l @ Left(_) => l.asInstanceOf[Either[A, C]] - case Right(b) => Right(f(b)) + case Right(b) => Right(f(b)) } def flatMap[D](f: B => Either[A, D]): Either[A, D] = eab match { case l @ Left(_) => l.asInstanceOf[Either[A, D]] - case Right(b) => f(b) + case Right(b) => f(b) } } diff --git a/hail/src/main/scala/is/hail/utils/ErrorHandling.scala b/hail/src/main/scala/is/hail/utils/ErrorHandling.scala index cb67bb52168..176df006080 100644 --- a/hail/src/main/scala/is/hail/utils/ErrorHandling.scala +++ b/hail/src/main/scala/is/hail/utils/ErrorHandling.scala @@ -1,6 +1,7 @@ package is.hail.utils -class HailException(val msg: String, val logMsg: Option[String], cause: Throwable, val errorId: Int) extends RuntimeException(msg, cause) { +class HailException(val msg: String, val logMsg: Option[String], cause: Throwable, val errorId: Int) + extends RuntimeException(msg, cause) { def this(msg: String) = this(msg, None, null, -1) def this(msg: String, logMsg: Option[String]) = this(msg, logMsg, null, -1) def this(msg: String, logMsg: Option[String], cause: Throwable) = this(msg, logMsg, cause, -1) @@ -11,7 +12,7 @@ class HailWorkerException( val partitionId: Int, val shortMessage: String, val expandedMessage: String, - val errorId: Int + val errorId: Int, ) extends RuntimeException(s"[partitionId=$partitionId] " + shortMessage) trait ErrorHandling { @@ -42,7 +43,7 @@ trait ErrorHandling { while (iterE.getCause != null) iterE = iterE.getCause - s"${ iterE.getClass.getSimpleName }: ${ iterE.getMessage }" + s"${iterE.getClass.getSimpleName}: ${iterE.getMessage}" } def expandException(e: Throwable, logMessage: Boolean): String = { @@ -50,24 +51,21 @@ trait ErrorHandling { case e: HailException => e.logMsg.filter(_ => logMessage).getOrElse(e.msg) case _ => e.getLocalizedMessage } - s"${ e.getClass.getName }: $msg\n\tat ${ e.getStackTrace.mkString("\n\tat ") }\n\n${ - Option(e.getCause).map(exception => expandException(exception, logMessage)).getOrElse("") - }\n" + s"${e.getClass.getName}: $msg\n\tat ${e.getStackTrace.mkString("\n\tat ")}\n\n${Option( + e.getCause + ).map(exception => expandException(exception, logMessage)).getOrElse("")}\n" } def handleForPython(e: Throwable): (String, String, Int) = { val short = deepestMessage(e) val expanded = expandException(e, false) - val logExpanded = expandException(e, true) 
def searchForErrorCode(exception: Throwable): Int = { if (exception.isInstanceOf[HailException]) { exception.asInstanceOf[HailException].errorId - } - else if (exception.getCause == null) { + } else if (exception.getCause == null) { -1 - } - else { + } else { searchForErrorCode(exception.getCause) } } diff --git a/hail/src/main/scala/is/hail/utils/ExecutionTimer.scala b/hail/src/main/scala/is/hail/utils/ExecutionTimer.scala index 50f4860363d..250a59af946 100644 --- a/hail/src/main/scala/is/hail/utils/ExecutionTimer.scala +++ b/hail/src/main/scala/is/hail/utils/ExecutionTimer.scala @@ -26,7 +26,9 @@ class TimeBlock(val name: String) { val selfPrefix = prefix :+ name - log.info(s"timing ${ selfPrefix.mkString("/") } total ${ formatTime(totalTime) } self ${ formatTime(totalTime - childrenTime ) } children ${ formatTime(childrenTime) } %children ${ formatDouble(childrenTime.toDouble * 100 / totalTime, 2) }%") + log.info(s"timing ${selfPrefix.mkString("/")} total ${formatTime(totalTime)} self ${formatTime( + totalTime - childrenTime + )} children ${formatTime(childrenTime)} %children ${formatDouble(childrenTime.toDouble * 100 / totalTime, 2)}%") var i = 0 while (i < children.length) { @@ -42,7 +44,8 @@ class TimeBlock(val name: String) { "total_time" -> totalTime, "self_time" -> (totalTime - childrenTime), "children_time" -> childrenTime, - "children" -> children.map(_.toMap)) + "children" -> children.map(_.toMap), + ) } } diff --git a/hail/src/main/scala/is/hail/utils/FlipbookIterator.scala b/hail/src/main/scala/is/hail/utils/FlipbookIterator.scala index 41a93bf4d8e..f42616cea09 100644 --- a/hail/src/main/scala/is/hail/utils/FlipbookIterator.scala +++ b/hail/src/main/scala/is/hail/utils/FlipbookIterator.scala @@ -5,22 +5,19 @@ import scala.collection.generic.Growable import scala.collection.mutable.PriorityQueue import scala.reflect.ClassTag -/** - * A StateMachine has the same primary interface as FlipbookIterator, but the - * implementations are not expected to be checked (for instance, value does not - * need to assert isValid). The only intended use of a StateMachine is to - * instantiate a FlipbookIterator or StagingIterator through the corresponding - * factory methods. +/** A StateMachine has the same primary interface as FlipbookIterator, but the implementations are + * not expected to be checked (for instance, value does not need to assert isValid). The only + * intended use of a StateMachine is to instantiate a FlipbookIterator or StagingIterator through + * the corresponding factory methods. * * A StateMachine implementation must satisfy the following properties: - * - isValid and value do not change the state of the StateMachine in any - * observable way. In other words, if advance() is not called, then any - * number of calls to value and isValid will always have the same return - * values. - * - If isValid is true, than value returns a valid value. If isValid is false, - * then the behavior of value is undefined. - * - advance() puts the StateMachine into a new state, after which the return values - * of isValid and value may have changed. + * - isValid and value do not change the state of the StateMachine in any observable way. In + * other words, if advance() is not called, then any number of calls to value and isValid will + * always have the same return values. + * - If isValid is true, than value returns a valid value. If isValid is false, then the behavior + * of value is undefined. 
+ * - advance() puts the StateMachine into a new state, after which the return values of isValid + * and value may have changed. */ abstract class StateMachine[A] { def isValid: Boolean @@ -31,8 +28,8 @@ abstract class StateMachine[A] { object StateMachine { def terminal[A]: StateMachine[A] = new StateMachine[A] { val isValid = false - var value: A = _ - def advance() {} + def value: A = ??? + def advance(): Unit = {} } } @@ -47,6 +44,7 @@ class StagingIterator[A] private (sm: StateMachine[A]) extends FlipbookIterator[ // FlipbookIterator interface def isValid: Boolean = sm.isValid def value: A = { assert(isValid && !isConsumed); sm.value } + def advance(): Unit = { assert(isValid) isConsumed = false @@ -55,23 +53,25 @@ class StagingIterator[A] private (sm: StateMachine[A]) extends FlipbookIterator[ // Additional StagingIterator methods def consume(): A = { assert(isValid && !isConsumed); isConsumed = true; sm.value } - def stage(): Unit = { + + def stage(): Unit = if (isConsumed) { isConsumed = false sm.advance() } - } + def consumedValue: A = { assert(isValid && isConsumed); sm.value } // (Buffered)Iterator interface, not intended to be used directly, only for // passing a StagingIterator where an Iterator is expected def head: A = { stage(); sm.value } - def hasNext: Boolean = { + + def hasNext: Boolean = if (isValid) { stage() isValid } else false - } + def next(): A = { stage() assert(isValid) @@ -89,7 +89,7 @@ object FlipbookIterator { def multiZipJoin[A: ClassTag]( its: Array[FlipbookIterator[A]], - ord: (A, A) => Int + ord: (A, A) => Int, ): FlipbookIterator[BoxedArrayBuilder[(A, Int)]] = { object TmpOrd extends Ordering[(A, Int)] { def compare(x: (A, Int), y: (A, Int)): Int = ord(y._1, x._1) @@ -99,13 +99,15 @@ object FlipbookIterator { val value = new BoxedArrayBuilder[(A, Int)](its.length) var isValid = true - var i = 0; while (i < its.length) { + var i = 0; + while (i < its.length) { if (its(i).isValid) q.enqueue(its(i).value -> i) i += 1 } - def advance() { - var i = 0; while (i < value.length) { + def advance(): Unit = { + var i = 0; + while (i < value.length) { val j = value(i)._2 its(j).advance() if (its(j).isValid) q.enqueue(its(j).value -> j) @@ -117,9 +119,8 @@ object FlipbookIterator { } else { val v = q.dequeue() value += v - while (!q.isEmpty && ord(q.head._1, v._1) == 0) { + while (!q.isEmpty && ord(q.head._1, v._1) == 0) value += q.dequeue() - } } } } @@ -129,18 +130,15 @@ object FlipbookIterator { } } -/** - * The primary public interface of FlipbookIterator[A] consists of the methods - * - isValid: Boolean - * - value: A - * - advance(): Unit +/** The primary public interface of FlipbookIterator[A] consists of the methods + * - isValid: Boolean + * - value: A + * - advance(): Unit * - * It also extends BufferedIterator[A] for interoperability with Scala and - * Spark. + * It also extends BufferedIterator[A] for interoperability with Scala and Spark. * - * To define a new FlipbookIterator, define a StateMachine (which has the same - * abstract methods as FlipbookIterator, but is unchecked), then use the - * factory method FlipbookIterator(sm). + * To define a new FlipbookIterator, define a StateMachine (which has the same abstract methods as + * FlipbookIterator, but is unchecked), then use the factory method FlipbookIterator(sm). 
*/ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => def isValid: Boolean @@ -150,7 +148,7 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => def valueOrElse(default: A): A = if (isValid) value else default - def exhaust() { while (isValid) advance() } + def exhaust(): Unit = while (isValid) advance() def toStagingIterator: StagingIterator[A] @@ -158,11 +156,9 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => new StateMachine[A] { def value = self.value def isValid = self.isValid - def advance() { - do { - self.advance() - } while (self.isValid && !pred(self.value)) - } + + def advance(): Unit = + do self.advance() while (self.isValid && !pred(self.value)) while (self.isValid && !pred(self.value)) self.advance() } @@ -175,7 +171,8 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => var value: B = _ if (self.isValid) value = f(self.value) def isValid = self.isValid - def advance() { + + def advance(): Unit = { self.advance() if (self.isValid) value = f(self.value) } @@ -190,24 +187,27 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => findNextValid def value: B = it.value def isValid = self.isValid - def advance() { + + def advance(): Unit = { it.advance() findNextValid } - def findNextValid() { + + def findNextValid(): Unit = while (self.isValid && !it.isValid) { self.advance() if (self.isValid) it = f(self.value).toIterator.toFlipbookIterator } - } } ) private[this] trait ValidityCachingStateMachine extends StateMachine[A] { private[this] var _isValid: Boolean = _ final def isValid = _isValid + final def refreshValidity(): Unit = _isValid = calculateValidity + def calculateValidity: Boolean def value: A def advance(): Unit @@ -249,13 +249,13 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => that: FlipbookIterator[B], leftOrd: OrderingView[A], rightOrd: OrderingView[B], - mixedOrd: (A, B) => Int + mixedOrd: (A, B) => Int, ): FlipbookIterator[Muple[FlipbookIterator[A], FlipbookIterator[B]]] = { this.staircased(leftOrd).orderedZipJoin( that.staircased(rightOrd), FlipbookIterator.empty, FlipbookIterator.empty, - (l, r) => mixedOrd(l.head, r.head) + (l, r) => mixedOrd(l.head, r.head), ) } @@ -263,13 +263,14 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => that: FlipbookIterator[B], leftDefault: A, rightDefault: B, - mixedOrd: (A, B) => Int): FlipbookIterator[Muple[A, B]] = { + mixedOrd: (A, B) => Int, + ): FlipbookIterator[Muple[A, B]] = { val left = self.toStagingIterator val right = that.toStagingIterator val sm = new StateMachine[Muple[A, B]] { val value = Muple(leftDefault, rightDefault) var isValid = true - def advance() { + def advance(): Unit = { left.stage() right.stage() val c = { @@ -279,11 +280,11 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => else -1 } else if (right.isValid) - 1 - else { - isValid = false - return - } + 1 + else { + isValid = false + return + } } if (c == 0) value.set(left.consume(), right.consume()) @@ -304,11 +305,12 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => rightOrd: OrderingView[B], leftDefault: A, rightDefault: B, - mixedOrd: (A, B) => Int + mixedOrd: (A, B) => Int, ): FlipbookIterator[Muple[A, B]] = { val result = Muple[A, B](leftDefault, rightDefault) - for { Muple(l, r) <- this.cogroup(that, leftOrd, rightOrd, mixedOrd) if r.isValid - lrv <- l + for { + Muple(l, r) <- this.cogroup(that, leftOrd, rightOrd, mixedOrd) if 
r.isValid + lrv <- l } yield result.set(lrv, r.value) } @@ -316,19 +318,19 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => that: FlipbookIterator[B], leftDefault: A, rightDefault: B, - mixedOrd: (A, B) => Int + mixedOrd: (A, B) => Int, ): FlipbookIterator[Muple[A, B]] = { val left = self val right = that val sm = new StateMachine[Muple[A, B]] { val value = Muple(leftDefault, rightDefault) var isValid = true - def setValue() { + def setValue(): Unit = { if (!left.isValid) isValid = false else { var c = 0 - while (right.isValid && {c = mixedOrd(left.value, right.value); c > 0}) + while (right.isValid && { c = mixedOrd(left.value, right.value); c > 0 }) right.advance() if (!right.isValid || c < 0) value.set(left.value, rightDefault) @@ -336,7 +338,7 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => value.set(left.value, right.value) } } - def advance() { + def advance(): Unit = { left.advance() setValue() } @@ -354,7 +356,7 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => leftDefault: A, rightDefault: B, rightBuffer: Growable[B] with Iterable[B], - mixedOrd: (A, B) => Int + mixedOrd: (A, B) => Int, ): FlipbookIterator[Muple[A, B]] = { val result = Muple[A, B](leftDefault, rightDefault) this.cogroup(that, leftOrd, rightOrd, mixedOrd).flatMap { case Muple(lIt, rIt) => @@ -369,12 +371,12 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => leftDefault: A, rightDefault: B, rightBuffer: Growable[B] with Iterable[B], - mixedOrd: (A, B) => Int + mixedOrd: (A, B) => Int, ): FlipbookIterator[Muple[A, B]] = { val result = Muple[A, B](leftDefault, rightDefault) this.cogroup(that, leftOrd, rightOrd, mixedOrd).flatMap { case Muple(lIt, rIt) => if (rIt.isValid) lIt.cartesianProduct(rIt, rightBuffer, result) - else lIt.map( lElem => result.set(lElem, rightDefault) ) + else lIt.map(lElem => result.set(lElem, rightDefault)) } } @@ -385,12 +387,12 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => leftDefault: A, rightDefault: B, rightBuffer: Growable[B] with Iterable[B], - mixedOrd: (A, B) => Int + mixedOrd: (A, B) => Int, ): FlipbookIterator[Muple[A, B]] = { val result = Muple[A, B](leftDefault, rightDefault) this.cogroup(that, leftOrd, rightOrd, mixedOrd).flatMap { case Muple(lIt, rIt) => if (lIt.isValid) lIt.cartesianProduct(rIt, rightBuffer, result) - else rIt.map( rElem => result.set(leftDefault, rElem) ) + else rIt.map(rElem => result.set(leftDefault, rElem)) } } @@ -401,12 +403,12 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => leftDefault: A, rightDefault: B, rightBuffer: Growable[B] with Iterable[B], - mixedOrd: (A, B) => Int + mixedOrd: (A, B) => Int, ): FlipbookIterator[Muple[A, B]] = { val result = Muple[A, B](leftDefault, rightDefault) this.cogroup(that, leftOrd, rightOrd, mixedOrd).flatMap { case Muple(lIt, rIt) => - if (!lIt.isValid) rIt.map( rElem => result.set(leftDefault, rElem) ) - else if (!rIt.isValid) lIt.map( lElem => result.set(lElem, rightDefault) ) + if (!lIt.isValid) rIt.map(rElem => result.set(leftDefault, rElem)) + else if (!rIt.isValid) lIt.map(lElem => result.set(lElem, rightDefault)) else lIt.cartesianProduct(rIt, rightBuffer, result) } } @@ -414,24 +416,27 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => def cartesianProduct[B]( that: FlipbookIterator[B], buffer: Growable[B] with Iterable[B], - result: Muple[A, B] + result: Muple[A, B], ): FlipbookIterator[Muple[A, B]] = { buffer.clear() - if 
(this.isValid) buffer ++= that //avoid copying right iterator when not needed + if (this.isValid) buffer ++= that // avoid copying right iterator when not needed this.flatMap(lElem => buffer.iterator.toFlipbookIterator.map(rElem => - result.set(lElem, rElem))) + result.set(lElem, rElem) + ) + ) } def merge( that: FlipbookIterator[A], - ord: (A, A) => Int): FlipbookIterator[A] = { + ord: (A, A) => Int, + ): FlipbookIterator[A] = { val left = self.toStagingIterator val right = that.toStagingIterator class MergeStateMachine extends StateMachine[A] { var value: A = _ var isValid = true - def advance() { + def advance(): Unit = { left.stage() right.stage() val c = { @@ -458,7 +463,6 @@ abstract class FlipbookIterator[A] extends BufferedIterator[A] { self => FlipbookIterator(sm) } - def sameElementsUsing[B](that: Iterator[B], eq: (A, B) => Boolean): Boolean = { while (this.hasNext && that.hasNext) if (!eq(this.next(), that.next())) diff --git a/hail/src/main/scala/is/hail/utils/Graph.scala b/hail/src/main/scala/is/hail/utils/Graph.scala index 997f9d39726..fbd625a15c8 100644 --- a/hail/src/main/scala/is/hail/utils/Graph.scala +++ b/hail/src/main/scala/is/hail/utils/Graph.scala @@ -2,19 +2,16 @@ package is.hail.utils import is.hail.annotations.{Region, RegionValueBuilder, UnsafeIndexedSeq} import is.hail.asm4s._ -import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext} -import is.hail.types.physical.{PCanonicalTuple, PTuple, PType, stypes} -import is.hail.expr.ir.{Compile, IR, IRParser, IRParserEnvironment, Interpret, Literal, MakeTuple, SingleCodeEmitParamType} -import is.hail.expr.ir.{Compile, IR, IRParser, IRParserEnvironment, Interpret, Literal, MakeTuple, SingleCodeEmitParamType} +import is.hail.backend.{HailStateManager, HailTaskContext} import is.hail.io.fs.FS -import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType +import is.hail.types.physical.PTuple import is.hail.variant.ReferenceGenome -import is.hail.types.virtual._ -import org.apache.spark.sql.Row import scala.collection.mutable import scala.reflect.ClassTag +import org.apache.spark.sql.Row + object Graph { def mkGraph[T](edges: (T, T)*): mutable.MultiMap[T, T] = mkGraph(edges.toArray) @@ -39,12 +36,20 @@ object Graph { m } - def maximalIndependentSet(edges: UnsafeIndexedSeq): IndexedSeq[Any] = { + def maximalIndependentSet(edges: UnsafeIndexedSeq): IndexedSeq[Any] = maximalIndependentSet(mkGraph(edges.map { case Row(i, j) => i -> j })) - } - def maximalIndependentSet(rgs: Map[String, ReferenceGenome], edges: UnsafeIndexedSeq, hcl: HailClassLoader, fs: FS, htc: HailTaskContext, outerRegion: Region, - wrappedNodeType: PTuple, resultType: PTuple, tieBreaker: (HailClassLoader, FS, HailTaskContext, Region) => AsmFunction3RegionLongLongLong): IndexedSeq[Any] = { + def maximalIndependentSet( + rgs: Map[String, ReferenceGenome], + edges: UnsafeIndexedSeq, + hcl: HailClassLoader, + fs: FS, + htc: HailTaskContext, + outerRegion: Region, + wrappedNodeType: PTuple, + resultType: PTuple, + tieBreaker: (HailClassLoader, FS, HailTaskContext, Region) => AsmFunction3RegionLongLongLong, + ): IndexedSeq[Any] = { val nodeType = wrappedNodeType.types.head.virtualType val region = outerRegion.getPool().getRegion() val tieBreakerF = tieBreaker(hcl, fs, htc, region) @@ -69,7 +74,8 @@ object Graph { if (resultType.isFieldMissing(resultOffset, 0)) { throw new RuntimeException( s"a comparison returned a missing value when " + - s"l=${Region.pretty(wrappedNodeType, lOffset)} and r=${Region.pretty(wrappedNodeType, rOffset)}") + 
s"l=${Region.pretty(wrappedNodeType, lOffset)} and r=${Region.pretty(wrappedNodeType, rOffset)}" + ) } else { Region.loadDouble(resultType.loadField(resultOffset, 0)) } @@ -78,15 +84,17 @@ object Graph { maximalIndependentSet(mkGraph(edges.map { case Row(i, j) => i -> j }), Some(tbf)) } - def maximalIndependentSet[T: ClassTag](edges: Array[(T, T)]): IndexedSeq[T] = { + def maximalIndependentSet[T: ClassTag](edges: Array[(T, T)]): IndexedSeq[T] = maximalIndependentSet(mkGraph(edges)) - } - def maximalIndependentSet[T: ClassTag](edges: Array[(T, T)], tieBreaker: (T, T) => Double): IndexedSeq[T] = { + def maximalIndependentSet[T: ClassTag](edges: Array[(T, T)], tieBreaker: (T, T) => Double) + : IndexedSeq[T] = maximalIndependentSet(mkGraph(edges), Some(tieBreaker)) - } - def maximalIndependentSet[T: ClassTag](g: mutable.MultiMap[T, T], maybeTieBreaker: Option[(T, T) => Double] = None): IndexedSeq[T] = { + def maximalIndependentSet[T: ClassTag]( + g: mutable.MultiMap[T, T], + maybeTieBreaker: Option[(T, T) => Double] = None, + ): IndexedSeq[T] = { val verticesByDegree = new BinaryHeap[T](maybeTieBreaker = maybeTieBreaker.orNull) g.foreach { case (v, neighbors) => diff --git a/hail/src/main/scala/is/hail/utils/HTTPClient.scala b/hail/src/main/scala/is/hail/utils/HTTPClient.scala index 2c6cf00c4d4..044f93360e3 100644 --- a/hail/src/main/scala/is/hail/utils/HTTPClient.scala +++ b/hail/src/main/scala/is/hail/utils/HTTPClient.scala @@ -1,13 +1,10 @@ package is.hail.utils -import java.net.URL -import java.io.OutputStream -import java.io.InputStream -import java.net.HttpURLConnection -import is.hail.utils._ +import java.io.{InputStream, OutputStream} +import java.net.{HttpURLConnection, URL} import java.nio.charset.StandardCharsets -import org.apache.commons.io.output.ByteArrayOutputStream +import org.apache.commons.io.output.ByteArrayOutputStream object HTTPClient { def post[T]( @@ -15,7 +12,7 @@ object HTTPClient { contentLength: Int, writeBody: OutputStream => Unit, readResponse: InputStream => T = (_: InputStream) => (), - chunkSize: Int = 0 + chunkSize: Int = 0, ): T = { val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection] conn.setRequestMethod("POST") @@ -24,8 +21,10 @@ object HTTPClient { conn.setDoOutput(true); conn.setRequestProperty("Content-Length", Integer.toString(contentLength)) using(conn.getOutputStream())(writeBody) - assert(200 <= conn.getResponseCode() && conn.getResponseCode() < 300, - s"POST ${url} ${conn.getResponseCode()} ${using(conn.getErrorStream())(fullyReadInputStreamAsString)}") + assert( + 200 <= conn.getResponseCode() && conn.getResponseCode() < 300, + s"POST $url ${conn.getResponseCode()} ${using(conn.getErrorStream())(fullyReadInputStreamAsString)}", + ) val result = using(conn.getInputStream())(readResponse) conn.disconnect() result @@ -33,12 +32,14 @@ object HTTPClient { def get[T]( url: String, - readResponse: InputStream => T + readResponse: InputStream => T, ): T = { val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection] conn.setRequestMethod("GET") - assert(200 <= conn.getResponseCode() && conn.getResponseCode() < 300, - s"GET ${url} ${conn.getResponseCode()} ${using(conn.getErrorStream())(fullyReadInputStreamAsString)}") + assert( + 200 <= conn.getResponseCode() && conn.getResponseCode() < 300, + s"GET $url ${conn.getResponseCode()} ${using(conn.getErrorStream())(fullyReadInputStreamAsString)}", + ) val result = using(conn.getInputStream())(readResponse) conn.disconnect() result @@ -46,12 +47,14 @@ object HTTPClient { def 
delete( url: String, - readResponse: InputStream => Unit = (_: InputStream) => () + readResponse: InputStream => Unit = (_: InputStream) => (), ): Unit = { val conn = new URL(url).openConnection().asInstanceOf[HttpURLConnection] conn.setRequestMethod("DELETE") - assert(200 <= conn.getResponseCode() && conn.getResponseCode() < 300, - s"DELETE ${url} ${conn.getResponseCode()} ${using(conn.getErrorStream())(fullyReadInputStreamAsString)}") + assert( + 200 <= conn.getResponseCode() && conn.getResponseCode() < 300, + s"DELETE $url ${conn.getResponseCode()} ${using(conn.getErrorStream())(fullyReadInputStreamAsString)}", + ) val result = using(conn.getInputStream())(readResponse) conn.disconnect() result diff --git a/hail/src/main/scala/is/hail/utils/HailIterator.scala b/hail/src/main/scala/is/hail/utils/HailIterator.scala index afdd0d9393e..f1c0b57f26c 100644 --- a/hail/src/main/scala/is/hail/utils/HailIterator.scala +++ b/hail/src/main/scala/is/hail/utils/HailIterator.scala @@ -18,9 +18,8 @@ abstract class HailIterator[@specialized T] { def countNonNegative()(implicit ev: Numeric[T]): Int = { import ev._ var count = 0 - while (hasNext) { + while (hasNext) if (next() >= ev.zero) count += 1 - } count } } diff --git a/hail/src/main/scala/is/hail/utils/HashMethods.scala b/hail/src/main/scala/is/hail/utils/HashMethods.scala index 20b407ca9ec..1b8fa4b0e93 100644 --- a/hail/src/main/scala/is/hail/utils/HashMethods.scala +++ b/hail/src/main/scala/is/hail/utils/HashMethods.scala @@ -3,9 +3,8 @@ package is.hail.utils import org.apache.commons.math3.random.RandomDataGenerator object UnivHash32 { - def apply(outBits: Int, rand: RandomDataGenerator): UnivHash32 = { + def apply(outBits: Int, rand: RandomDataGenerator): UnivHash32 = new UnivHash32(outBits, rand.getRandomGenerator.nextInt() | 1) - } } // see e.g. 
Thorup, "High Speed Hashing for Integers and Strings", for explanation of "multiply-shift" hash functions, @@ -20,9 +19,12 @@ class UnivHash32(outBits: Int, factor: Int) extends (Int => Int) { } object TwoIndepHash32 { - def apply(outBits: Int, rand: RandomDataGenerator): TwoIndepHash32 = { - new TwoIndepHash32(outBits, rand.getRandomGenerator.nextInt(), rand.getRandomGenerator.nextInt()) - } + def apply(outBits: Int, rand: RandomDataGenerator): TwoIndepHash32 = + new TwoIndepHash32( + outBits, + rand.getRandomGenerator.nextInt(), + rand.getRandomGenerator.nextInt(), + ) } class TwoIndepHash32(outBits: Int, a: Long, b: Long) extends (Int => Int) { @@ -88,14 +90,15 @@ object FiveIndepTabulationHash32 { val poly2 = PolyHash(rand, 32) new FiveIndepTabulationHash32( poly1.fillLongArray(256 * 4), - poly2.fillIntArray(259 * 3) + poly2.fillIntArray(259 * 3), ) } } // compare to Thorup and Zhang, "Tabulation-Based 5-Independent Hashing with Applications to Linear Probing and // Second Moment Estimation", section A.7 -class FiveIndepTabulationHash32(keyTable: Array[Long], derivedKeyTable: Array[Int]) extends (Int => Int) { +class FiveIndepTabulationHash32(keyTable: Array[Long], derivedKeyTable: Array[Int]) + extends (Int => Int) { require(keyTable.length == 256 * 4) require(derivedKeyTable.length == 259 * 3) @@ -128,9 +131,8 @@ class FiveIndepTabulationHash32(keyTable: Array[Long], derivedKeyTable: Array[In } object PolyHash { - def apply(rand: RandomDataGenerator, degree: Int): PolyHash = { + def apply(rand: RandomDataGenerator, degree: Int): PolyHash = new PolyHash(Array.fill(degree)(rand.getRandomGenerator.nextInt())) - } // Can be done by the PCLMULQDQ "carryless multiply" instruction on x86 processors post ~2010. // This would give a significant speed boost. Any way to do this from JVM? diff --git a/hail/src/main/scala/is/hail/utils/Interval.scala b/hail/src/main/scala/is/hail/utils/Interval.scala index 51b20b2daa1..3618ec66182 100644 --- a/hail/src/main/scala/is/hail/utils/Interval.scala +++ b/hail/src/main/scala/is/hail/utils/Interval.scala @@ -3,12 +3,11 @@ package is.hail.utils import is.hail.annotations._ import is.hail.check._ import is.hail.types.virtual.TBoolean + import org.apache.spark.sql.Row import org.json4s.JValue import org.json4s.JsonAST.JObject -import scala.language.implicitConversions - case class IntervalEndpoint(point: Any, sign: Int) extends Serializable { require(sign == -1 || sign == 1) @@ -27,28 +26,23 @@ case class IntervalEndpoint(point: Any, sign: Int) extends Serializable { } } -/** 'Interval' has an implicit precondition that 'start' and 'end' either have - * the same type, or are of compatible 'TBaseStruct' types, i.e. their types - * agree on all fields up to the min of their lengths. Moreover, it assumes - * that the interval is well formed, as coded in 'Interval.isValid', roughly - * meaning that start is less than end. Each method assumes that the 'pord' - * parameter is compatible with the endpoints, and with 'p' or the endpoints - * of 'other'. +/** 'Interval' has an implicit precondition that 'start' and 'end' either have the same type, or are + * of compatible 'TBaseStruct' types, i.e. their types agree on all fields up to the min of their + * lengths. Moreover, it assumes that the interval is well formed, as coded in 'Interval.isValid', + * roughly meaning that start is less than end. Each method assumes that the 'pord' parameter is + * compatible with the endpoints, and with 'p' or the endpoints of 'other'. 
* - * Precisely, 'Interval' assumes that there exists a Hail type 't: Type' such - * that either - * - 't: TBaseStruct', and 't.relaxedTypeCheck(left)', 't.relaxedTypeCheck(right), - * and 't.ordering.intervalEndpointOrdering.lt(left, right)', or - * - 't.typeCheck(left)', 't.typeCheck(right)', and 't.ordering.lt(left, right)' + * Precisely, 'Interval' assumes that there exists a Hail type 't: Type' such that either + * - 't: TBaseStruct', and 't.relaxedTypeCheck(left)', 't.relaxedTypeCheck(right), and + * 't.ordering.intervalEndpointOrdering.lt(left, right)', or + * - 't.typeCheck(left)', 't.typeCheck(right)', and 't.ordering.lt(left, right)' * - * Moreover, every method on 'Interval' taking a 'pord' has the precondition - * that there exists a Hail type 't: Type' such that 'pord' was constructed by - * 't.ordering', and either - * - 't: TBaseStruct' and 't.relaxedTypeCheck(x)', or - * - 't.typeCheck(x)', - * where 'x' is each of 'left', 'right', 'p', 'other.left', and 'other.right' - * as appropriate. In the case 't: TBaseStruct', 't' could be replaced by any - * 't2' such that 't.isPrefixOf(t2)' without changing the behavior. + * Moreover, every method on 'Interval' taking a 'pord' has the precondition that there exists a + * Hail type 't: Type' such that 'pord' was constructed by 't.ordering', and either + * - 't: TBaseStruct' and 't.relaxedTypeCheck(x)', or + * - 't.typeCheck(x)', where 'x' is each of 'left', 'right', 'p', 'other.left', and 'other.right' + * as appropriate. In the case 't: TBaseStruct', 't' could be replaced by any 't2' such that + * 't.isPrefixOf(t2)' without changing the behavior. */ class Interval(val left: IntervalEndpoint, val right: IntervalEndpoint) extends Serializable { require(left != null) @@ -81,7 +75,12 @@ class Interval(val left: IntervalEndpoint, val right: IntervalEndpoint) extends def isDisjointFrom(pord: ExtendedOrdering, other: Interval): Boolean = !overlaps(pord, other) - def copy(start: Any = start, end: Any = end, includesStart: Boolean = includesStart, includesEnd: Boolean = includesEnd): Interval = + def copy( + start: Any = start, + end: Any = end, + includesStart: Boolean = includesStart, + includesEnd: Boolean = includesEnd, + ): Interval = Interval(start, end, includesStart, includesEnd) def extendLeft(newLeft: IntervalEndpoint): Interval = Interval(newLeft, right) @@ -89,10 +88,12 @@ class Interval(val left: IntervalEndpoint, val right: IntervalEndpoint) extends def extendRight(newRight: IntervalEndpoint): Interval = Interval(left, newRight) def toJSON(f: (Any) => JValue): JValue = - JObject("start" -> f(start), + JObject( + "start" -> f(start), "end" -> f(end), "includeStart" -> TBoolean.toJSON(includesStart), - "includeEnd" -> TBoolean.toJSON(includesEnd)) + "includeEnd" -> TBoolean.toJSON(includesEnd), + ) def isBelow(pord: ExtendedOrdering, other: Interval): Boolean = ext(pord).compare(this.right, other.left) <= 0 @@ -123,7 +124,8 @@ class Interval(val left: IntervalEndpoint, val right: IntervalEndpoint) extends if (ext(pord).compare(this.right, other.right) < 0) other.right else - this.right) + this.right, + ) def intersect(pord: ExtendedOrdering, other: Interval): Option[Interval] = if (overlaps(pord, other)) { @@ -137,14 +139,16 @@ class Interval(val left: IntervalEndpoint, val right: IntervalEndpoint) extends if (ext(pord).compare(this.right, other.right) < 0) this.right else - other.right)) + other.right, + )) } else None def coarsen(newKeyLen: Int): Interval = Interval(left.coarsenLeft(newKeyLen), right.coarsenRight(newKeyLen)) - 
override def toString: String = (if (includesStart) "[" else "(") + start + "-" + end + (if (includesEnd) "]" else ")") + override def toString: String = + (if (includesStart) "[" else "(") + start + "-" + end + (if (includesEnd) "]" else ")") override def equals(other: Any): Boolean = other match { case that: Interval => left == that.left && right == that.right @@ -166,32 +170,43 @@ object Interval { def unapply(interval: Interval): Option[(Any, Any, Boolean, Boolean)] = Some((interval.start, interval.end, interval.includesStart, interval.includesEnd)) - def orNone(pord: ExtendedOrdering, - start: Any, end: Any, - includesStart: Boolean, includesEnd: Boolean + def orNone( + pord: ExtendedOrdering, + start: Any, + end: Any, + includesStart: Boolean, + includesEnd: Boolean, ): Option[Interval] = if (isValid(pord, start, end, includesStart, includesEnd)) Some(Interval(start, end, includesStart, includesEnd)) else None - def orNone(pord: ExtendedOrdering, left: IntervalEndpoint, right: IntervalEndpoint): Option[Interval] = + def orNone(pord: ExtendedOrdering, left: IntervalEndpoint, right: IntervalEndpoint) + : Option[Interval] = orNone(pord, left.point, right.point, left.sign < 0, right.sign > 0) - def isValid(pord: ExtendedOrdering, - start: Any, end: Any, - includesStart: Boolean, includesEnd: Boolean + def isValid( + pord: ExtendedOrdering, + start: Any, + end: Any, + includesStart: Boolean, + includesEnd: Boolean, ): Boolean = { val (left, right) = toIntervalEndpoints(start, end, includesStart, includesEnd) pord.intervalEndpointOrdering.compare(left, right) < 0 } def toIntervalEndpoints( - start: Any, end: Any, - includesStart: Boolean, includesEnd: Boolean + start: Any, + end: Any, + includesStart: Boolean, + includesEnd: Boolean, ): (IntervalEndpoint, IntervalEndpoint) = - (IntervalEndpoint(start, if (includesStart) -1 else 1), - IntervalEndpoint(end, if (includesEnd) 1 else -1)) + ( + IntervalEndpoint(start, if (includesStart) -1 else 1), + IntervalEndpoint(end, if (includesEnd) 1 else -1), + ) def gen[P](pord: ExtendedOrdering, pgen: Gen[P]): Gen[Interval] = Gen.zip(pgen, pgen, Gen.coin(), Gen.coin()) @@ -203,7 +218,8 @@ object Interval { Interval(y, x, s, e) } - def ordering(pord: ExtendedOrdering, startPrimary: Boolean, _missingEqual: Boolean = true): ExtendedOrdering = new ExtendedOrdering { + def ordering(pord: ExtendedOrdering, startPrimary: Boolean, _missingEqual: Boolean = true) + : ExtendedOrdering = new ExtendedOrdering { val missingEqual = _missingEqual override def compareNonnull(x: Any, y: Any): Int = { @@ -240,7 +256,11 @@ object Interval { } // assumes that both `x1` and `x2` are both sorted, non-overlapping interval sequences. 
- def intersection(x1: IndexedSeq[Interval], x2: IndexedSeq[Interval], ord: IntervalEndpointOrdering): Array[Interval] = { + def intersection( + x1: IndexedSeq[Interval], + x2: IndexedSeq[Interval], + ord: IntervalEndpointOrdering, + ): Array[Interval] = { var i = 0 var j = 0 diff --git a/hail/src/main/scala/is/hail/utils/LoggerOutputStream.scala b/hail/src/main/scala/is/hail/utils/LoggerOutputStream.scala index c74d2a62972..fcd13f12de4 100644 --- a/hail/src/main/scala/is/hail/utils/LoggerOutputStream.scala +++ b/hail/src/main/scala/is/hail/utils/LoggerOutputStream.scala @@ -8,15 +8,15 @@ import org.apache.log4j.{Level, Logger} class LoggerOutputStream(logger: Logger, level: Level) extends OutputStream { private val buffer = new ByteArrayOutputStream() - override def write(b: Int) { + override def write(b: Int): Unit = { buffer.write(b) if (b == '\n') { val line = buffer.toString(StandardCharsets.UTF_8.name()) level match { case Level.TRACE => logger.trace(line) case Level.DEBUG => logger.debug(line) - case Level.INFO => logger.info(line) - case Level.WARN => logger.warn(line) + case Level.INFO => logger.info(line) + case Level.WARN => logger.warn(line) case Level.ERROR => logger.error(line) } buffer.reset() diff --git a/hail/src/main/scala/is/hail/utils/Logging.scala b/hail/src/main/scala/is/hail/utils/Logging.scala index 2ce8c67beed..8845d6cbe71 100644 --- a/hail/src/main/scala/is/hail/utils/Logging.scala +++ b/hail/src/main/scala/is/hail/utils/Logging.scala @@ -25,11 +25,10 @@ trait Logging { consoleLogger } - def info(msg: String) { + def info(msg: String): Unit = consoleLog.info(msg) - } - def info(msg: String, t: Truncatable) { + def info(msg: String, t: Truncatable): Unit = { val (screen, logged) = t.strings if (screen == logged) consoleLog.info(format(msg, screen)) @@ -40,11 +39,10 @@ trait Logging { } } - def warn(msg: String) { + def warn(msg: String): Unit = consoleLog.warn(msg) - } - def warn(msg: String, t: Truncatable) { + def warn(msg: String, t: Truncatable): Unit = { val (screen, logged) = t.strings if (screen == logged) consoleLog.warn(format(msg, screen)) @@ -55,7 +53,6 @@ trait Logging { } } - def error(msg: String) { + def error(msg: String): Unit = consoleLog.error(msg) - } } diff --git a/hail/src/main/scala/is/hail/utils/MemoryBufferWrapper.scala b/hail/src/main/scala/is/hail/utils/MemoryBufferWrapper.scala index 3591a280769..113baaf9cab 100644 --- a/hail/src/main/scala/is/hail/utils/MemoryBufferWrapper.scala +++ b/hail/src/main/scala/is/hail/utils/MemoryBufferWrapper.scala @@ -23,7 +23,6 @@ final class MemoryWriterWrapper { def length(): Int = mb.end - def copyToAddress(addr: Long): Unit = { + def copyToAddress(addr: Long): Unit = mb.readBytes(addr, mb.end) - } } diff --git a/hail/src/main/scala/is/hail/utils/MissingArrayBuilder.scala b/hail/src/main/scala/is/hail/utils/MissingArrayBuilder.scala index aed4e4fa465..d7159124e23 100644 --- a/hail/src/main/scala/is/hail/utils/MissingArrayBuilder.scala +++ b/hail/src/main/scala/is/hail/utils/MissingArrayBuilder.scala @@ -20,12 +20,12 @@ final class MissingArrayBuilder[@specialized T](initialCapacity: Int)(implicit t b(i) } - def update(i: Int, x: T) { + def update(i: Int, x: T): Unit = { require(i >= 0 && i < size) b(i) = x } - def ensureCapacity(n: Int) { + def ensureCapacity(n: Int): Unit = { if (b.length < n) { val newCapacity = (b.length * 2).max(n) val newb = new Array[T](newCapacity) @@ -37,11 +37,10 @@ final class MissingArrayBuilder[@specialized T](initialCapacity: Int)(implicit t } } - def clear(): Unit = { + 
def clear(): Unit = size_ = 0 - } - def +=(x: T) { + def +=(x: T): Unit = { ensureCapacity(size_ + 1) b(size_) = x missing(size_) = false @@ -52,7 +51,7 @@ final class MissingArrayBuilder[@specialized T](initialCapacity: Int)(implicit t def ++=(a: Array[T]): Unit = ++=(a, a.length) - def ++=(a: Array[T], length: Int) { + def ++=(a: Array[T], length: Int): Unit = { require(length >= 0 && length <= a.length) ensureCapacity(size_ + length) System.arraycopy(a, 0, b, size_, length) @@ -76,7 +75,7 @@ final class MissingArrayBuilder[@specialized T](initialCapacity: Int)(implicit t missing(i) = m } - def addMissing() { + def addMissing(): Unit = { ensureCapacity(size_ + 1) missing(size_) = true size_ += 1 diff --git a/hail/src/main/scala/is/hail/utils/MultiArray2.scala b/hail/src/main/scala/is/hail/utils/MultiArray2.scala index fb96e87c622..c72c9114071 100644 --- a/hail/src/main/scala/is/hail/utils/MultiArray2.scala +++ b/hail/src/main/scala/is/hail/utils/MultiArray2.scala @@ -1,43 +1,51 @@ package is.hail.utils -import java.io.Serializable - import scala.collection.immutable.IndexedSeq import scala.reflect.ClassTag +import java.io.Serializable -class MultiArray2[@specialized(Int, Long, Float, Double, Boolean) T](val n1: Int, - val n2: Int, - val a: Array[T]) extends Serializable with Iterable[T] { +class MultiArray2[@specialized(Int, Long, Float, Double, Boolean) T]( + val n1: Int, + val n2: Int, + val a: Array[T], +) extends Serializable with Iterable[T] { require(n1 >= 0 && n2 >= 0) - require(a.length == n1*n2) + require(a.length == n1 * n2) - class Row(val i:Int) extends IndexedSeq[T] { + class Row(val i: Int) extends IndexedSeq[T] { require(i >= 0 && i < n1) - def apply(j:Int): T = { + + def apply(j: Int): T = { if (j < 0 || j >= length) throw new ArrayIndexOutOfBoundsException - a(i*n2 + j) + a(i * n2 + j) } + def length: Int = n2 } - class Column(val j:Int) extends IndexedSeq[T] { + class Column(val j: Int) extends IndexedSeq[T] { require(j >= 0 && j < n2) - def apply(i:Int): T = { + + def apply(i: Int): T = { if (i < 0 || i >= length) throw new ArrayIndexOutOfBoundsException - a(i*n2 + j) + a(i * n2 + j) } + def length: Int = n1 } - def row(i:Int) = new Row(i) - def column(j:Int) = new Column(j) + def row(i: Int) = new Row(i) + def column(j: Int) = new Column(j) def rows: Iterable[Row] = for (i <- rowIndices) yield row(i) def columns: Iterable[Column] = for (j <- columnIndices) yield column(j) - def indices: Iterable[(Int,Int)] = for (i <- 0 until n1; j <- 0 until n2) yield (i, j) + def indices: Iterable[(Int, Int)] = for { + i <- 0 until n1 + j <- 0 until n2 + } yield (i, j) def rowIndices: Iterable[Int] = 0 until n1 @@ -45,24 +53,24 @@ class MultiArray2[@specialized(Int, Long, Float, Double, Boolean) T](val n1: Int def apply(i: Int, j: Int): T = { require(i >= 0 && i < n1 && j >= 0 && j < n2) - a(i*n2 + j) + a(i * n2 + j) } - def update(i: Int, j: Int, x:T): Unit = { + def update(i: Int, j: Int, x: T): Unit = { require(i >= 0 && i < n1 && j >= 0 && j < n2) - a.update(i*n2 + j,x) + a.update(i * n2 + j, x) } - def update(t: (Int,Int), x:T): Unit = { + def update(t: (Int, Int), x: T): Unit = { require(t._1 >= 0 && t._1 < n1 && t._2 >= 0 && t._2 < n2) - update(t._1,t._2,x) + update(t._1, t._2, x) } def array: Array[T] = a - def zip[S](other: MultiArray2[S]): MultiArray2[(T,S)] = { + def zip[S](other: MultiArray2[S]): MultiArray2[(T, S)] = { require(n1 == other.n1 && n2 == other.n2) - new MultiArray2(n1,n2,a.zip(other.a)) + new MultiArray2(n1, n2, a.zip(other.a)) } def iterator: Iterator[T] = 
a.iterator @@ -75,4 +83,3 @@ object MultiArray2 { def empty[T](implicit tct: ClassTag[T]): MultiArray2[T] = new MultiArray2[T](0, 0, Array.empty[T](tct)) } - diff --git a/hail/src/main/scala/is/hail/utils/NumericImplicits.scala b/hail/src/main/scala/is/hail/utils/NumericImplicits.scala index 1ccdaf19a19..ed9367ad660 100644 --- a/hail/src/main/scala/is/hail/utils/NumericImplicits.scala +++ b/hail/src/main/scala/is/hail/utils/NumericImplicits.scala @@ -1,17 +1,20 @@ package is.hail.utils -import breeze.linalg.operators.{OpAdd, OpSub} +import scala.language.implicitConversions + import breeze.linalg.{DenseVector => BDenseVector, SparseVector => BSparseVector, Vector => BVector} +import breeze.linalg.operators.{OpAdd, OpSub} +import org.apache.spark.mllib.linalg.{ + DenseVector => SDenseVector, SparseVector => SSparseVector, Vector => SVector, +} import org.apache.spark.mllib.linalg.distributed.IndexedRow -import org.apache.spark.mllib.linalg.{DenseVector => SDenseVector, SparseVector => SSparseVector, Vector => SVector} - -import scala.language.implicitConversions trait NumericImplicits { implicit def toBDenseVector(v: SDenseVector): BDenseVector[Double] = new BDenseVector(v.values) - implicit def toBSparseVector(v: SSparseVector): BSparseVector[Double] = new BSparseVector(v.indices, v.values, v.size) + implicit def toBSparseVector(v: SSparseVector): BSparseVector[Double] = + new BSparseVector(v.indices, v.values, v.size) implicit def toBVector(v: SVector): BVector[Double] = v match { case v: SSparseVector => v @@ -20,7 +23,8 @@ trait NumericImplicits { implicit def toSDenseVector(v: BDenseVector[Double]): SDenseVector = new SDenseVector(v.toArray) - implicit def toSSparseVector(v: BSparseVector[Double]): SSparseVector = new SSparseVector(v.length, v.array.index, v.array.data) + implicit def toSSparseVector(v: BSparseVector[Double]): SSparseVector = + new SSparseVector(v.length, v.array.index, v.array.data) implicit def toSVector(v: BVector[Double]): SVector = v match { case v: BDenseVector[Double] => v @@ -31,16 +35,20 @@ trait NumericImplicits { def apply(a: BVector[Double], b: SVector): BVector[Double] = a - toBVector(b) } - implicit object subBVectorIndexedRow extends OpSub.Impl2[BVector[Double], IndexedRow, IndexedRow] { - def apply(a: BVector[Double], b: IndexedRow): IndexedRow = IndexedRow(b.index, a - toBVector(b.vector)) + implicit object subBVectorIndexedRow + extends OpSub.Impl2[BVector[Double], IndexedRow, IndexedRow] { + def apply(a: BVector[Double], b: IndexedRow): IndexedRow = + IndexedRow(b.index, a - toBVector(b.vector)) } implicit object addBVectorSVector extends OpAdd.Impl2[BVector[Double], SVector, BVector[Double]] { def apply(a: BVector[Double], b: SVector): BVector[Double] = a + toBVector(b) } - implicit object addBVectorIndexedRow extends OpAdd.Impl2[BVector[Double], IndexedRow, IndexedRow] { - def apply(a: BVector[Double], b: IndexedRow): IndexedRow = IndexedRow(b.index, a + toBVector(b.vector)) + implicit object addBVectorIndexedRow + extends OpAdd.Impl2[BVector[Double], IndexedRow, IndexedRow] { + def apply(a: BVector[Double], b: IndexedRow): IndexedRow = + IndexedRow(b.index, a + toBVector(b.vector)) } } diff --git a/hail/src/main/scala/is/hail/utils/NumericPair.scala b/hail/src/main/scala/is/hail/utils/NumericPair.scala index 7a243254d11..ebb0876e9c5 100644 --- a/hail/src/main/scala/is/hail/utils/NumericPair.scala +++ b/hail/src/main/scala/is/hail/utils/NumericPair.scala @@ -38,7 +38,6 @@ trait NumericPairImplicits { val numeric: scala.math.Numeric[Float] = 
implicitly[Numeric[Float]] } - implicit object DoublePair extends NumericPair[Double, java.lang.Double] { def box(t: Double): java.lang.Double = Double.box(t) diff --git a/hail/src/main/scala/is/hail/utils/OrderingView.scala b/hail/src/main/scala/is/hail/utils/OrderingView.scala index ddf499a7c1c..b81987064fa 100644 --- a/hail/src/main/scala/is/hail/utils/OrderingView.scala +++ b/hail/src/main/scala/is/hail/utils/OrderingView.scala @@ -4,16 +4,17 @@ trait OrderingView[A] { protected def setFiniteValue(a: A): Unit protected def compareFinite(a: A): Int - def setValue(a: A) { + def setValue(a: A): Unit = { isInfinite = 0 setFiniteValue(a) } + def compare(a: A): Int = if (isInfinite != 0) isInfinite else compareFinite(a) def isEquivalent(a: A): Boolean = compare(a) == 0 - def setBottom() { isInfinite = -1 } - def setTop() { isInfinite = 1 } + def setBottom(): Unit = isInfinite = -1 + def setTop(): Unit = isInfinite = 1 private var isInfinite: Int = -1 } diff --git a/hail/src/main/scala/is/hail/utils/ParseTrieNode.scala b/hail/src/main/scala/is/hail/utils/ParseTrieNode.scala index ac5df43268e..70baf217bc2 100644 --- a/hail/src/main/scala/is/hail/utils/ParseTrieNode.scala +++ b/hail/src/main/scala/is/hail/utils/ParseTrieNode.scala @@ -2,18 +2,18 @@ package is.hail.utils import scala.collection.mutable.ArrayBuffer -/** - * A character trie used for parsing membership in a set literal from a character sequence. +/** A character trie used for parsing membership in a set literal from a character sequence. * - * The children of a node are represented in an unordered array; linear search is used to - * traverse the tree with the assumption that the average number of children per node is - * small (small enough that a O(n log n) binary search or O(1) hash lookup would be more - * expensive) + * The children of a node are represented in an unordered array; linear search is used to traverse + * the tree with the assumption that the average number of children per node is small (small enough + * that a O(n log n) binary search or O(1) hash lookup would be more expensive) */ -class ParseTrieNode(val value: String, +class ParseTrieNode( + val value: String, val children: Array[ParseTrieNode], - val nextChar: Array[Char]) { + val nextChar: Array[Char], +) { def search(next: Char): ParseTrieNode = { var i = 0 @@ -26,20 +26,21 @@ class ParseTrieNode(val value: String, } } - object ParseTrieNode { def generate(data: Array[String]): ParseTrieNode = { - class ParseTrieNodeBuilder(var value: String, + class ParseTrieNodeBuilder( + var value: String, var children: ArrayBuffer[ParseTrieNodeBuilder], - var nextChar: ArrayBuffer[Char]) { + var nextChar: ArrayBuffer[Char], + ) { - def result(): ParseTrieNode = { + def result(): ParseTrieNode = new ParseTrieNode(value, children.toArray.map(_.result()), nextChar.toArray) - } } - val root = new ParseTrieNodeBuilder(null, new ArrayBuffer[ParseTrieNodeBuilder], new ArrayBuffer[Char]) + val root = + new ParseTrieNodeBuilder(null, new ArrayBuffer[ParseTrieNodeBuilder], new ArrayBuffer[Char]) def insert(s: String): Unit = { var idx = 0 @@ -53,7 +54,11 @@ object ParseTrieNode { while (continue) { if (i >= charBuff.size) { node.nextChar += next - node = new ParseTrieNodeBuilder(null, new ArrayBuffer[ParseTrieNodeBuilder], new ArrayBuffer[Char]) + node = new ParseTrieNodeBuilder( + null, + new ArrayBuffer[ParseTrieNodeBuilder], + new ArrayBuffer[Char], + ) buff += node continue = false } else { @@ -76,4 +81,4 @@ object ParseTrieNode { root.result() } -} \ No newline at end of 
file +} diff --git a/hail/src/main/scala/is/hail/utils/PartitionCounts.scala b/hail/src/main/scala/is/hail/utils/PartitionCounts.scala index 611c7ba0c7e..1ddc6d990a7 100644 --- a/hail/src/main/scala/is/hail/utils/PartitionCounts.scala +++ b/hail/src/main/scala/is/hail/utils/PartitionCounts.scala @@ -5,7 +5,7 @@ object PartitionCounts { case class PCSubsetOffset( finalIndex: Int, nKeep: Long, - nDrop: Long + nDrop: Long, ) def getPCSubsetOffset(n: Long, pcs: Iterator[Long]): Option[PCSubsetOffset] = { @@ -46,8 +46,10 @@ object PartitionCounts { def incrementalPCSubsetOffset( n: Long, - partIndices: IndexedSeq[Int] - )(computePCs: IndexedSeq[Int] => IndexedSeq[Long]): PCSubsetOffset = { + partIndices: IndexedSeq[Int], + )( + computePCs: IndexedSeq[Int] => IndexedSeq[Long] + ): PCSubsetOffset = { var nLeft = n var nPartsScanned = 0 var lastIdx = -1 diff --git a/hail/src/main/scala/is/hail/utils/Py4jUtils.scala b/hail/src/main/scala/is/hail/utils/Py4jUtils.scala index 119c4a75646..ad105c88c87 100644 --- a/hail/src/main/scala/is/hail/utils/Py4jUtils.scala +++ b/hail/src/main/scala/is/hail/utils/Py4jUtils.scala @@ -2,14 +2,16 @@ package is.hail.utils import is.hail.HailContext import is.hail.expr.JSONAnnotationImpex -import is.hail.io.fs.{FS, FileListEntry, SeekableDataInputStream} +import is.hail.io.fs.{FS, FileListEntry, FileStatus, SeekableDataInputStream} import is.hail.types.virtual.Type -import org.json4s.JsonAST._ -import org.json4s.jackson.JsonMethods -import java.io.{InputStream, OutputStream} import scala.collection.JavaConverters._ +import java.io.{InputStream, OutputStream} + +import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods + trait Py4jUtils { def arrayToArrayList[T](arr: Array[T]): java.util.ArrayList[T] = { val list = new java.util.ArrayList[T]() @@ -55,31 +57,42 @@ trait Py4jUtils { JsonMethods.compact(JArray(statuses.map(fs => fileListEntryToJson(fs)).toList)) } + def fileStatus(fs: FS, path: String): String = { + val stat = fs.fileStatus(path) + JsonMethods.compact(fileStatusToJson(stat)) + } + def fileListEntry(fs: FS, path: String): String = { val stat = fs.fileListEntry(path) JsonMethods.compact(fileListEntryToJson(stat)) } - private def fileListEntryToJson(fs: FileListEntry): JObject = { + private def fileStatusToJson(fs: FileStatus): JObject = { JObject( "path" -> JString(fs.getPath.toString), "size" -> JInt(fs.getLen), - "is_dir" -> JBool(fs.isDirectory), "is_link" -> JBool(fs.isSymlink), "modification_time" -> (if (fs.getModificationTime != null) - JString( - new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSSSSZ").format( - new java.util.Date(fs.getModificationTime))) - else - JNull), + JString( + new java.text.SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSSSSZ").format( + new java.util.Date(fs.getModificationTime) + ) + ) + else + JNull), "owner" -> ( if (fs.getOwner != null) JString(fs.getOwner) else - JNull)) + JNull + ), + ) } + private def fileListEntryToJson(fs: FileListEntry): JObject = + JObject(fileStatusToJson(fs).obj :+ ("is_dir" -> JBool(fs.isDirectory))) + private val kilo: Long = 1024 private val mega: Long = kilo * 1024 private val giga: Long = mega * 1024 @@ -98,9 +111,8 @@ trait Py4jUtils { formatDigits(bytes, tera) + "T" } - private def formatDigits(n: Long, factor: Long): String = { + private def formatDigits(n: Long, factor: Long): String = "%.1f".format(n / factor.toDouble) - } def readFile(fs: FS, path: String, buffSize: Int): HadoopSeekablePyReader = new HadoopSeekablePyReader(fs.fileListEntry(path), 
fs.openNoCompression(path), buffSize) @@ -120,22 +132,18 @@ trait Py4jUtils { new HadoopPyWriter(fs.create(path)) } - def addSocketAppender(hostname: String, port: Int) { + def addSocketAppender(hostname: String, port: Int): Unit = StringSocketAppender.get() .connect(hostname, port, HailContext.logFormat) - } - def logWarn(msg: String) { + def logWarn(msg: String): Unit = warn(msg) - } - def logInfo(msg: String) { + def logInfo(msg: String): Unit = info(msg) - } - def logError(msg: String) { + def logError(msg: String): Unit = error(msg) - } def makeJSON(t: Type, value: Any): String = { val jv = JSONAnnotationImpex.exportAnnotation(value, t) @@ -158,12 +166,12 @@ class HadoopPyReader(in: InputStream, buffSize: Int) { buff.slice(0, bytesRead) } - def close() { + def close(): Unit = in.close() - } } -class HadoopSeekablePyReader(status: FileListEntry, in: SeekableDataInputStream, buffSize: Int) extends HadoopPyReader(in, buffSize) { +class HadoopSeekablePyReader(status: FileListEntry, in: SeekableDataInputStream, buffSize: Int) + extends HadoopPyReader(in, buffSize) { def seek(pos: Long, whence: Int): Long = { // whence corresponds to python arguments to seek // it is validated in python @@ -185,15 +193,13 @@ class HadoopSeekablePyReader(status: FileListEntry, in: SeekableDataInputStream, } class HadoopPyWriter(out: OutputStream) { - def write(b: Array[Byte]) { + def write(b: Array[Byte]): Unit = out.write(b) - } - def flush() { + def flush(): Unit = out.flush() - } - def close() { + def close(): Unit = { out.flush() out.close() } diff --git a/hail/src/main/scala/is/hail/utils/RelativeFuzzyComparable.scala b/hail/src/main/scala/is/hail/utils/RelativeFuzzyComparable.scala index 9742e4d2c46..5cc572e6b2c 100644 --- a/hail/src/main/scala/is/hail/utils/RelativeFuzzyComparable.scala +++ b/hail/src/main/scala/is/hail/utils/RelativeFuzzyComparable.scala @@ -1,19 +1,21 @@ package is.hail.utils trait RelativeFuzzyComparable[A] { - def relativeEq(tolerance: Double, x: A, y: A) : Boolean + def relativeEq(tolerance: Double, x: A, y: A): Boolean } object RelativeFuzzyComparable { - def relativeEq[T](tolerance: Double, x: T, y: T)(implicit afc: RelativeFuzzyComparable[T]): Boolean = + def relativeEq[T](tolerance: Double, x: T, y: T)(implicit afc: RelativeFuzzyComparable[T]) + : Boolean = afc.relativeEq(tolerance, x, y) implicit object rfcDoubles extends RelativeFuzzyComparable[Double] { def relativeEq(tolerance: Double, x: Double, y: Double) = D_==(x, y, tolerance) } - implicit def rfcMaps[K, V](implicit vRFC: RelativeFuzzyComparable[V]): RelativeFuzzyComparable[Map[K, V]] = + implicit def rfcMaps[K, V](implicit vRFC: RelativeFuzzyComparable[V]) + : RelativeFuzzyComparable[Map[K, V]] = new RelativeFuzzyComparable[Map[K, V]] { def relativeEq(tolerance: Double, x: Map[K, V], y: Map[K, V]) = x.keySet == y.keySet && x.keys.forall(k => vRFC.relativeEq(tolerance, x(k), y(k))) diff --git a/hail/src/main/scala/is/hail/utils/RestartableByteArrayInputStream.scala b/hail/src/main/scala/is/hail/utils/RestartableByteArrayInputStream.scala index 8fcaf8c69d3..f6a79b683b2 100644 --- a/hail/src/main/scala/is/hail/utils/RestartableByteArrayInputStream.scala +++ b/hail/src/main/scala/is/hail/utils/RestartableByteArrayInputStream.scala @@ -1,13 +1,14 @@ package is.hail.utils -import java.io.{ IOException, InputStream } +import java.io.{IOException, InputStream} // not thread safe class RestartableByteArrayInputStream extends InputStream { private[this] var off: Int = 0 private[this] var end: Int = 0 private[this] var buf: 
Array[Byte] = null - def this(buf: Array[Byte]) { + + def this(buf: Array[Byte]) = { this() restart(buf) } @@ -20,35 +21,46 @@ class RestartableByteArrayInputStream extends InputStream { off += 1 b } + override def read(dest: Array[Byte]): Int = read(dest, 0, dest.length) + override def read(dest: Array[Byte], destOff: Int, requestedLength: Int): Int = { val length = math.min(requestedLength, end - off) System.arraycopy(buf, off, dest, destOff, length) off += length length } + override def skip(n: Long): Long = { if (n <= 0) { return 0 } val skipped = math.min( math.min(n, Integer.MAX_VALUE).toInt, - end - off) + end - off, + ) off += skipped skipped } + override def available(): Int = end - off + override def markSupported(): Boolean = false + override def mark(readAheadLimit: Int): Unit = throw new IOException("unsupported operation") + override def reset(): Unit = throw new IOException("unsupported operation") + override def close(): Unit = buf = null + def restart(buf: Array[Byte]): Unit = restart(buf, 0, buf.length) + def restart(buf: Array[Byte], start: Int, end: Int): Unit = { require(start >= 0) require(start <= end) diff --git a/hail/src/main/scala/is/hail/utils/SemanticVersion.scala b/hail/src/main/scala/is/hail/utils/SemanticVersion.scala index e477604513b..a7827f6b5ba 100644 --- a/hail/src/main/scala/is/hail/utils/SemanticVersion.scala +++ b/hail/src/main/scala/is/hail/utils/SemanticVersion.scala @@ -5,10 +5,9 @@ case class SemanticVersion(major: Int, minor: Int, patch: Int) extends Ordered[S assert((minor & 0xff) == minor) assert((patch & 0xff) == patch) - def supports(that: SemanticVersion): Boolean = { + def supports(that: SemanticVersion): Boolean = major == that.major && that.minor <= minor - } def rep: Int = (major << 16) | (minor << 8) | patch @@ -27,4 +26,4 @@ case class SemanticVersion(major: Int, minor: Int, patch: Int) extends Ordered[S object SemanticVersion { def apply(rep: Int): SemanticVersion = SemanticVersion((rep >> 16) & 0xff, (rep >> 8) & 0xff, rep & 0xff) -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/utils/SerializableHadoopConfiguration.scala b/hail/src/main/scala/is/hail/utils/SerializableHadoopConfiguration.scala index 2340b4d81dd..28a266ca922 100644 --- a/hail/src/main/scala/is/hail/utils/SerializableHadoopConfiguration.scala +++ b/hail/src/main/scala/is/hail/utils/SerializableHadoopConfiguration.scala @@ -4,13 +4,14 @@ import java.io.{ObjectInputStream, ObjectOutputStream, Serializable} import org.apache.hadoop -class SerializableHadoopConfiguration(@transient var value: hadoop.conf.Configuration) extends Serializable { - private def writeObject(out: ObjectOutputStream) { +class SerializableHadoopConfiguration(@transient var value: hadoop.conf.Configuration) + extends Serializable { + private def writeObject(out: ObjectOutputStream): Unit = { out.defaultWriteObject() value.write(out) } - private def readObject(in: ObjectInputStream) { + private def readObject(in: ObjectInputStream): Unit = { value = new hadoop.conf.Configuration(false) value.readFields(in) } diff --git a/hail/src/main/scala/is/hail/utils/SpillingCollectIterator.scala b/hail/src/main/scala/is/hail/utils/SpillingCollectIterator.scala index 140909be905..a8bb6d8babb 100644 --- a/hail/src/main/scala/is/hail/utils/SpillingCollectIterator.scala +++ b/hail/src/main/scala/is/hail/utils/SpillingCollectIterator.scala @@ -3,13 +3,16 @@ package is.hail.utils import is.hail.backend.ExecuteContext import is.hail.backend.spark.SparkBackend import is.hail.io.fs.FS -import 
org.apache.spark.rdd.RDD + +import scala.reflect.{classTag, ClassTag} import java.io.{ObjectInputStream, ObjectOutputStream} -import scala.reflect.{ClassTag, classTag} + +import org.apache.spark.rdd.RDD object SpillingCollectIterator { - def apply[T: ClassTag](localTmpdir: String, fs: FS, rdd: RDD[T], sizeLimit: Int): SpillingCollectIterator[T] = { + def apply[T: ClassTag](localTmpdir: String, fs: FS, rdd: RDD[T], sizeLimit: Int) + : SpillingCollectIterator[T] = { val nPartitions = rdd.partitions.length val x = new SpillingCollectIterator(localTmpdir, fs, nPartitions, sizeLimit) val ctc = classTag[T] @@ -17,12 +20,18 @@ object SpillingCollectIterator { rdd, (_, it: Iterator[T]) => it.toArray(ctc), 0 until nPartitions, - x.append _) + x.append _, + ) x } } -class SpillingCollectIterator[T: ClassTag] private (localTmpdir: String, fs: FS, nPartitions: Int, sizeLimit: Int) extends Iterator[T] { +class SpillingCollectIterator[T: ClassTag] private ( + localTmpdir: String, + fs: FS, + nPartitions: Int, + sizeLimit: Int, +) extends Iterator[T] { private[this] val files: Array[(String, Long)] = new Array(nPartitions) private[this] val buf: Array[Array[T]] = new Array(nPartitions) private[this] var _size: Long = 0L @@ -34,7 +43,8 @@ class SpillingCollectIterator[T: ClassTag] private (localTmpdir: String, fs: FS, buf(partition) = a _size += a.length if (_size > sizeLimit) { - val file = ExecuteContext.createTmpPathNoCleanup(localTmpdir, s"spilling-collect-iterator-$partition") + val file = + ExecuteContext.createTmpPathNoCleanup(localTmpdir, s"spilling-collect-iterator-$partition") log.info(s"spilling partition $partition to $file") using(fs.createNoCompression(file)) { os => var k = 0 @@ -95,4 +105,3 @@ class SpillingCollectIterator[T: ClassTag] private (localTmpdir: String, fs: FS, it.next } } - diff --git a/hail/src/main/scala/is/hail/utils/StackSafe.scala b/hail/src/main/scala/is/hail/utils/StackSafe.scala index cf9dda716b4..110da9c9de3 100644 --- a/hail/src/main/scala/is/hail/utils/StackSafe.scala +++ b/hail/src/main/scala/is/hail/utils/StackSafe.scala @@ -34,14 +34,14 @@ object StackSafe { private class ContCell[A, B](val f: A => StackFrame[B], var next: ContCell[B, _] = null) - private final case class Done[A](result: A) extends StackFrame[A] + final private case class Done[A](result: A) extends StackFrame[A] - private final case class Thunk[A](force: () => StackFrame[A]) extends StackFrame[A] + final private case class Thunk[A](force: () => StackFrame[A]) extends StackFrame[A] - private final class More[A, B]( + final private class More[A, B]( _next: StackFrame[A], _contHead: ContCell[A, _], - _contTail: ContCell[_, B] + _contTail: ContCell[_, B], ) extends StackFrame[B] { // type erased locals allow mutating to different type parameters, then // casting `this` if needed @@ -57,9 +57,9 @@ object StackSafe { if (contHead == null) next.asInstanceOf[StackFrame[B]] else next match { case Done(result) => - nextTE = contHead.f(result) - contHeadTE = contHead.next - this + nextTE = contHead.f(result) + contHeadTE = contHead.next + this case thunk: Thunk[A] => nextTE = thunk.force() this @@ -83,19 +83,19 @@ object StackSafe { } } - implicit class RichIndexedSeq[A](val s: IndexedSeq[A]) extends AnyVal { - def mapRecur[B, That](f: A => StackFrame[B])(implicit bf: CanBuildFrom[IndexedSeq[A], B, That]): StackFrame[That] = { + implicit class RichIndexedSeq[A](private val s: IndexedSeq[A]) extends AnyVal { + def mapRecur[B, That](f: A => StackFrame[B])(implicit bf: CanBuildFrom[IndexedSeq[A], B, That]) + : 
StackFrame[That] = { val builder = bf(s) builder.sizeHint(s) var i = 0 var cont: B => StackFrame[That] = null - def loop(): StackFrame[That] = { + def loop(): StackFrame[That] = if (i < s.size) { f(s(i)).flatMap(cont) } else { done(builder.result) } - } cont = { b => builder += b i += 1 @@ -105,19 +105,19 @@ object StackSafe { } } - implicit class RichArray[A](val a: Array[A]) extends AnyVal { - def mapRecur[B](f: A => StackFrame[B])(implicit bf: CanBuildFrom[Array[A], B, Array[B]]): StackFrame[Array[B]] = { + implicit class RichArray[A](private val a: Array[A]) extends AnyVal { + def mapRecur[B](f: A => StackFrame[B])(implicit bf: CanBuildFrom[Array[A], B, Array[B]]) + : StackFrame[Array[B]] = { val builder = bf(a) builder.sizeHint(a) var i = 0 var cont: B => StackFrame[Array[B]] = null - def loop(): StackFrame[Array[B]] = { + def loop(): StackFrame[Array[B]] = if (i < a.size) { f(a(i)).flatMap(cont) } else { done(builder.result) } - } cont = { b => builder += b i += 1 @@ -127,40 +127,37 @@ object StackSafe { } } - implicit class RichOption[A](val o: Option[A]) extends AnyVal { - def mapRecur[B](f: A => StackFrame[B]): StackFrame[Option[B]] = { + implicit class RichOption[A](private val o: Option[A]) extends AnyVal { + def mapRecur[B](f: A => StackFrame[B]): StackFrame[Option[B]] = o match { case None => done(None) case Some(a) => call(f(a)).map(b => Some(b)) } - } } - implicit class RichIterator[A](val i: Iterable[A]) extends AnyVal { + implicit class RichIterator[A](private val i: Iterable[A]) extends AnyVal { def foreachRecur(f: A => StackFrame[Unit]): StackFrame[Unit] = { val it = i.iterator - def loop(): StackFrame[Unit] = { + def loop(): StackFrame[Unit] = if (it.hasNext) { - f(it.next()).flatMap { _ => call(loop()) } + f(it.next()).flatMap(_ => call(loop())) } else { done(()) } - } loop() } } - implicit class RichIteratorStackFrame[A](val i: Iterator[StackFrame[A]]) extends AnyVal { + implicit class RichIteratorStackFrame[A](private val i: Iterator[StackFrame[A]]) extends AnyVal { def collectRecur(implicit bf: CanBuild[A, Array[A]]): StackFrame[IndexedSeq[A]] = { val builder = bf() var cont: A => StackFrame[IndexedSeq[A]] = null - def loop(): StackFrame[IndexedSeq[A]] = { + def loop(): StackFrame[IndexedSeq[A]] = if (i.hasNext) { i.next().flatMap(cont) } else { done(builder.result()) } - } cont = { a => builder += a call(loop()) @@ -169,18 +166,18 @@ object StackSafe { } } - def fillArray[A](n: Int)(body: => StackFrame[A])(implicit bf: CanBuild[A, Array[A]]): StackFrame[Array[A]] = { + def fillArray[A](n: Int)(body: => StackFrame[A])(implicit bf: CanBuild[A, Array[A]]) + : StackFrame[Array[A]] = { val builder = bf() builder.sizeHint(n) var i = 0 var cont: A => StackFrame[Array[A]] = null - def loop(): StackFrame[Array[A]] = { + def loop(): StackFrame[Array[A]] = if (i < n) { body.flatMap(cont) } else { done(builder.result) } - } cont = { a => builder += a i += 1 diff --git a/hail/src/main/scala/is/hail/utils/StringEscapeUtils.scala b/hail/src/main/scala/is/hail/utils/StringEscapeUtils.scala index 85e0170e397..5c5c452d268 100644 --- a/hail/src/main/scala/is/hail/utils/StringEscapeUtils.scala +++ b/hail/src/main/scala/is/hail/utils/StringEscapeUtils.scala @@ -6,9 +6,12 @@ object StringEscapeUtils { def hex(ch: Char): String = Integer.toHexString(ch).toUpperCase(Locale.ENGLISH) - def escapeStringSimple(str: String, escapeChar: Char, + def escapeStringSimple( + str: String, + escapeChar: Char, escapeFirst: (Char) => Boolean, - escape: (Char) => Boolean): String = { + escape: (Char) => 
Boolean, + ): String = { val sb = new StringBuilder var i: Int = 0 while (i < str.length) { @@ -95,8 +98,7 @@ object StringEscapeUtils { case _ => if (ch > 0xf) { sb.append("\\u00" + hex(ch)) - } - else { + } else { sb.append("\\u000" + hex(ch)) } } @@ -127,12 +129,12 @@ object StringEscapeUtils { sb.result() } - def unescapeString(str: String): String = unescapeString(str, new StringBuilder(capacity = str.length)) + def unescapeString(str: String): String = + unescapeString(str, new StringBuilder(capacity = str.length)) def unescapeString(str: String, sb: StringBuilder): String = { sb.clear() - val sz = str.length() var hadSlash = false var inUnicode = false lazy val unicode = new StringBuilder(capacity = 4) @@ -154,7 +156,7 @@ object StringEscapeUtils { inUnicode = false hadSlash = false } catch { - case nfe: NumberFormatException => + case _: NumberFormatException => fatal("Unable to parse unicode value: " + unicode) } } diff --git a/hail/src/main/scala/is/hail/utils/StringSocketAppender.scala b/hail/src/main/scala/is/hail/utils/StringSocketAppender.scala index 1e054c977f3..6f05eea2f93 100644 --- a/hail/src/main/scala/is/hail/utils/StringSocketAppender.scala +++ b/hail/src/main/scala/is/hail/utils/StringSocketAppender.scala @@ -1,16 +1,13 @@ package is.hail.utils -import org.apache.log4j.helpers.LogLog -import org.apache.log4j.spi.{ErrorCode, LoggingEvent} -import org.apache.log4j.{AppenderSkeleton, PatternLayout} - import java.io.{IOException, InterruptedIOException, ObjectOutputStream, OutputStream} import java.net.{ConnectException, InetAddress, Socket} -/** - * This class was translated and streamlined from - * org.apache.log4j.net.SocketAppender - */ +import org.apache.log4j.{AppenderSkeleton, PatternLayout} +import org.apache.log4j.helpers.LogLog +import org.apache.log4j.spi.{ErrorCode, LoggingEvent} + +/** This class was translated and streamlined from org.apache.log4j.net.SocketAppender */ object StringSocketAppender { // low reconnection delay because everything is local @@ -22,13 +19,11 @@ object StringSocketAppender { } class StringSocketAppender() extends AppenderSkeleton { - private var remoteHost: String = _ private var address: InetAddress = _ private var port: Int = _ private var os: OutputStream = _ - private var reconnectionDelay = StringSocketAppender.DEFAULT_RECONNECTION_DELAY + private val reconnectionDelay = StringSocketAppender.DEFAULT_RECONNECTION_DELAY private var connector: SocketConnector = null - private var counter = 0 private var patternLayout: PatternLayout = _ private var initialized: Boolean = false @@ -37,19 +32,18 @@ class StringSocketAppender() extends AppenderSkeleton { def connect(host: String, port: Int, format: String): Unit = { this.port = port this.address = InetAddress.getByName(host) - this.remoteHost = host this.patternLayout = new PatternLayout(format) connect(address, port) initialized = true } - override def close() { + override def close(): Unit = { if (closed) return this.closed = true cleanUp() } - def cleanUp() { + def cleanUp(): Unit = { if (os != null) { try os.close() @@ -66,7 +60,7 @@ class StringSocketAppender() extends AppenderSkeleton { } } - private def connect(address: InetAddress, port: Int) { + private def connect(address: InetAddress, port: Int): Unit = { if (this.address == null) return try { // First, close the previous connection if any. 
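[editor's note] The `ch > 0xf` branch reformatted in the `StringEscapeUtils` hunk above exists only to pad the hex escape of a control character to four digits. A small self-contained sketch of that padding logic, reusing the shape of the `hex` helper shown earlier in the same file (a sketch of the visible branch, not the full Hail escaper):

```scala
// Sketch of the control-character branch visible in the diff above:
// characters are rendered as "\u00XX", padding the hex digits to two.
object ControlCharEscapeSketch {
  def hex(ch: Char): String =
    Integer.toHexString(ch).toUpperCase(java.util.Locale.ENGLISH)

  def escapeControl(ch: Char): String =
    if (ch > 0xf) "\\u00" + hex(ch) // e.g. 0x1F -> \u001F (two hex digits)
    else "\\u000" + hex(ch)         // e.g. 0x07 -> \u0007 (one hex digit, extra pad)

  def main(args: Array[String]): Unit = {
    println(escapeControl(7.toChar))    // \u0007
    println(escapeControl(0x1f.toChar)) // \u001F
  }
}
```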
cleanUp() @@ -87,29 +81,34 @@ class StringSocketAppender() extends AppenderSkeleton { } } - override def append(event: LoggingEvent) { + override def append(event: LoggingEvent): Unit = { if (!initialized) return if (event == null) return if (address == null) { errorHandler.error("No remote host is set for SocketAppender named \"" + this.name + "\".") return } - if (os != null) try { - event.getLevel - val str = patternLayout.format(event) - os.write(str.getBytes("ISO-8859-1")) - os.flush() - } catch { - case e: IOException => - if (e.isInstanceOf[InterruptedIOException]) Thread.currentThread.interrupt() - os = null - LogLog.warn("Detected problem with connection: " + e) - if (reconnectionDelay > 0) fireConnector() - else errorHandler.error("Detected problem with connection, not reconnecting.", e, ErrorCode.GENERIC_FAILURE) - } + if (os != null) + try { + event.getLevel + val str = patternLayout.format(event) + os.write(str.getBytes("ISO-8859-1")) + os.flush() + } catch { + case e: IOException => + if (e.isInstanceOf[InterruptedIOException]) Thread.currentThread.interrupt() + os = null + LogLog.warn("Detected problem with connection: " + e) + if (reconnectionDelay > 0) fireConnector() + else errorHandler.error( + "Detected problem with connection, not reconnecting.", + e, + ErrorCode.GENERIC_FAILURE, + ) + } } - private def fireConnector() { + private def fireConnector(): Unit = { if (connector == null) { LogLog.debug("Starting a new connector thread.") connector = new SocketConnector @@ -119,16 +118,13 @@ class StringSocketAppender() extends AppenderSkeleton { } } - /** - * The SocketAppender does not use a layout. Hence, this method - * returns false. - **/ + /** The SocketAppender does not use a layout. Hence, this method returns false. */ override def requiresLayout = false class SocketConnector extends Thread { var interrupted = false - override def run() { + override def run(): Unit = { var socket: Socket = null var c = true while (c && !interrupted) @@ -143,10 +139,10 @@ class StringSocketAppender() extends AppenderSkeleton { c = false } } catch { - case e: InterruptedException => + case _: InterruptedException => LogLog.debug("Connector interrupted. 
Leaving loop.") return - case e: ConnectException => + case _: ConnectException => LogLog.debug("Remote host " + address.getHostName + " refused connection.") case e: IOException => if (e.isInstanceOf[InterruptedIOException]) Thread.currentThread.interrupt() diff --git a/hail/src/main/scala/is/hail/utils/TextInputFilterAndReplace.scala b/hail/src/main/scala/is/hail/utils/TextInputFilterAndReplace.scala index f563b67e4e5..bd647fa03c9 100644 --- a/hail/src/main/scala/is/hail/utils/TextInputFilterAndReplace.scala +++ b/hail/src/main/scala/is/hail/utils/TextInputFilterAndReplace.scala @@ -1,6 +1,10 @@ package is.hail.utils -case class TextInputFilterAndReplace(filterPattern: Option[String] = None, findPattern: Option[String] = None, replacePattern: Option[String] = None) { +case class TextInputFilterAndReplace( + filterPattern: Option[String] = None, + findPattern: Option[String] = None, + replacePattern: Option[String] = None, +) { require(!(findPattern.isDefined ^ replacePattern.isDefined)) private val fpRegex = filterPattern.map(_.r).orNull diff --git a/hail/src/main/scala/is/hail/utils/TracingInputStream.scala b/hail/src/main/scala/is/hail/utils/TracingInputStream.scala index c88e41d55e9..ff54d9eb886 100644 --- a/hail/src/main/scala/is/hail/utils/TracingInputStream.scala +++ b/hail/src/main/scala/is/hail/utils/TracingInputStream.scala @@ -16,7 +16,6 @@ class TracingInputStream( b } - override def close(): Unit = { + override def close(): Unit = in.close() - } } diff --git a/hail/src/main/scala/is/hail/utils/TruncatedArrayIndexedSeq.scala b/hail/src/main/scala/is/hail/utils/TruncatedArrayIndexedSeq.scala index 03ab35119e1..3ec20537776 100644 --- a/hail/src/main/scala/is/hail/utils/TruncatedArrayIndexedSeq.scala +++ b/hail/src/main/scala/is/hail/utils/TruncatedArrayIndexedSeq.scala @@ -1,6 +1,7 @@ package is.hail.utils -class TruncatedArrayIndexedSeq[T](a: Array[T], newLength: Int) extends IndexedSeq[T] with Serializable { +class TruncatedArrayIndexedSeq[T](a: Array[T], newLength: Int) + extends IndexedSeq[T] with Serializable { def length: Int = newLength def apply(idx: Int): T = { @@ -9,4 +10,3 @@ class TruncatedArrayIndexedSeq[T](a: Array[T], newLength: Int) extends IndexedSe a(idx) } } - diff --git a/hail/src/main/scala/is/hail/utils/TryAll.scala b/hail/src/main/scala/is/hail/utils/TryAll.scala index 5bf4b793cba..92debd20fed 100644 --- a/hail/src/main/scala/is/hail/utils/TryAll.scala +++ b/hail/src/main/scala/is/hail/utils/TryAll.scala @@ -4,9 +4,9 @@ import scala.util.{Failure, Success, Try} object TryAll { def apply[K](f: => K): Try[K] = - try { + try Success(f) - } catch { + catch { case e: Throwable => Failure(e) } -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/utils/UnionFind.scala b/hail/src/main/scala/is/hail/utils/UnionFind.scala index 2050805fdfd..bf2da08dea1 100644 --- a/hail/src/main/scala/is/hail/utils/UnionFind.scala +++ b/hail/src/main/scala/is/hail/utils/UnionFind.scala @@ -7,12 +7,11 @@ class UnionFind(initialCapacity: Int = 32) { def size: Int = count - private def ensure(i: Int) { + private def ensure(i: Int): Unit = { if (i >= a.length) { var newLength = a.length << 1 - while (i >= newLength) { + while (i >= newLength) newLength = newLength << 1 - } val a2 = new Array[Int](newLength) Array.copy(a, 0, a2, 0, a.length) a = a2 @@ -22,7 +21,7 @@ class UnionFind(initialCapacity: Int = 32) { } } - def makeSet(i: Int) { + def makeSet(i: Int): Unit = { ensure(i) a(i) = i count += 1 @@ -31,9 +30,8 @@ class UnionFind(initialCapacity: Int = 32) { def 
find(x: Int): Int = { require(x < a.length) var representative = x - while (representative != a(representative)) { + while (representative != a(representative)) representative = a(representative) - } var current = x while (representative != current) { val temp = a(current) @@ -43,7 +41,7 @@ class UnionFind(initialCapacity: Int = 32) { current } - def union(x: Int, y: Int) { + def union(x: Int, y: Int): Unit = { val xroot = find(x) val yroot = find(y) diff --git a/hail/src/main/scala/is/hail/utils/package.scala b/hail/src/main/scala/is/hail/utils/package.scala index e672b4a19f0..097459b024a 100644 --- a/hail/src/main/scala/is/hail/utils/package.scala +++ b/hail/src/main/scala/is/hail/utils/package.scala @@ -4,33 +4,34 @@ import is.hail.annotations.ExtendedOrdering import is.hail.check.Gen import is.hail.expr.ir.ByteArrayBuilder import is.hail.io.fs.{FS, FileListEntry} + +import scala.collection.{mutable, GenTraversableOnce, TraversableOnce} +import scala.collection.generic.CanBuildFrom +import scala.collection.mutable.ArrayBuffer +import scala.language.higherKinds +import scala.reflect.ClassTag +import scala.util.{Failure, Success, Try} + +import java.io._ +import java.lang.reflect.Method +import java.net.{URI, URLClassLoader} +import java.security.SecureRandom +import java.text.SimpleDateFormat +import java.util.{Base64, Date} +import java.util.concurrent.ExecutorService + import org.apache.commons.io.output.TeeOutputStream import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.PathIOException import org.apache.hadoop.mapred.FileSplit import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit} import org.apache.log4j.Level -import org.apache.spark.sql.Row import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.sql.Row +import org.json4s.{Extraction, Formats, JObject, NoTypeHints, Serializer} import org.json4s.JsonAST.{JArray, JString} import org.json4s.jackson.Serialization import org.json4s.reflect.TypeInfo -import org.json4s.{Extraction, Formats, JObject, NoTypeHints, Serializer} - -import java.io._ -import java.lang.reflect.Method -import java.net.{URI, URLClassLoader} -import java.security.SecureRandom -import java.text.SimpleDateFormat -import java.util.concurrent.ExecutorService -import java.util.{Base64, Date} -import scala.collection.generic.CanBuildFrom -import scala.collection.mutable.ArrayBuffer -import scala.collection.{GenTraversableOnce, TraversableOnce, mutable} -import scala.language.{higherKinds, implicitConversions} -import scala.reflect.ClassTag -import scala.util.{Failure, Success, Try} -import org.apache.spark.sql.Row package utils { trait Truncatable { @@ -51,21 +52,28 @@ package utils { } sealed trait AnyFailAllFail[C[_]] { - def apply[T](ts: TraversableOnce[Option[T]])(implicit cbf: CanBuildFrom[Nothing, T, C[T]]): Option[C[T]] = { + def apply[T](ts: TraversableOnce[Option[T]])(implicit cbf: CanBuildFrom[Nothing, T, C[T]]) + : Option[C[T]] = { val b = cbf() - for (t <- ts) { + for (t <- ts) if (t.isEmpty) return None else b += t.get - } Some(b.result()) } } sealed trait MapAccumulate[C[_], U] { - def apply[T, S](a: Iterable[T], z: S)(f: (T, S) => (U, S)) - (implicit uct: ClassTag[U], cbf: CanBuildFrom[Nothing, U, C[U]]): C[U] = { + def apply[T, S]( + a: Iterable[T], + z: S, + )( + f: (T, S) => (U, S) + )(implicit + uct: ClassTag[U], + cbf: CanBuildFrom[Nothing, U, C[U]], + ): C[U] = { val b = cbf() var acc = z for ((x, i) <- a.zipWithIndex) { @@ -78,23 +86,19 @@ package utils { } } -package object utils extends Logging 
- with richUtils.Implicits - with NumericPairImplicits - with utils.NumericImplicits - with Py4jUtils - with ErrorHandling { +package object utils + extends Logging with richUtils.Implicits with NumericPairImplicits with utils.NumericImplicits + with Py4jUtils with ErrorHandling { def utilsPackageClass = getClass def getStderrAndLogOutputStream[T](implicit tct: ClassTag[T]): OutputStream = new TeeOutputStream(new LoggerOutputStream(log, Level.ERROR), System.err) - def format(s: String, substitutions: Any*): String = { + def format(s: String, substitutions: Any*): String = substitutions.zipWithIndex.foldLeft(s) { case (str, (value, i)) => - str.replace(s"@${ i + 1 }", value.toString) + str.replace(s"@${i + 1}", value.toString) } - } def coerceToInt(l: Long): Int = { if (l > Int.MaxValue || l < Int.MinValue) @@ -107,7 +111,7 @@ package object utils extends Logging fileListEntries: Array[FileListEntry], forceGZ: Boolean, gzAsBGZ: Boolean, - maxSizeMB: Int = 128 + maxSizeMB: Int = 128, ) = { if (fileListEntries.isEmpty) fatal(s"arguments refer to no files: ${globPaths.toIndexedSeq}.") @@ -124,8 +128,8 @@ package object utils extends Logging fileListEntry: FileListEntry, forceGZ: Boolean, gzAsBGZ: Boolean, - maxSizeMB: Int = 128 - ) { + maxSizeMB: Int = 128, + ): Unit = { if (!forceGZ && !gzAsBGZ) fatal( s"""Cannot load file '${fileListEntry.getPath}' @@ -133,15 +137,17 @@ package object utils extends Logging | If the file is actually block gzipped (even though its extension is .gz), | use the 'force_bgz' argument to treat all .gz file extensions as .bgz. | If you are sure that you want to load a non-block-gzipped file serially - | on one core, use the 'force' argument.""".stripMargin) + | on one core, use the 'force' argument.""".stripMargin + ) else if (!gzAsBGZ) { val fileSize = fileListEntry.getLen if (fileSize > 1024 * 1024 * maxSizeMB) warn( - s"""file '${fileListEntry.getPath}' is ${ readableBytes(fileSize) } + s"""file '${fileListEntry.getPath}' is ${readableBytes(fileSize)} | It will be loaded serially (on one core) due to usage of the 'force' argument. 
| If it is actually block-gzipped, either rename to .bgz or use the 'force_bgz' - | argument.""".stripMargin) + | argument.""".stripMargin + ) } } @@ -169,11 +175,11 @@ package object utils extends Logging math.ceil(math.log(nPartitions) / math.log(branchingFactor)).toInt } - def simpleAssert(p: Boolean) { + def simpleAssert(p: Boolean): Unit = if (!p) throw new AssertionError - } - def optionCheckInRangeInclusive[A](low: A, high: A)(name: String, a: A)(implicit ord: Ordering[A]): Unit = + def optionCheckInRangeInclusive[A](low: A, high: A)(name: String, a: A)(implicit ord: Ordering[A]) + : Unit = if (ord.lt(a, low) || ord.gt(a, high)) { fatal(s"$name cannot lie outside [$low, $high]: $a") } @@ -205,8 +211,7 @@ package object utils extends Logging val tMins = (tMilliseconds / msPerMinute).toInt val tSec = (tMilliseconds % msPerMinute) / 1e3 ("%d" + "m" + "%.1f" + "s").format(tMins, tSec) - } - else { + } else { val tHrs = (tMilliseconds / msPerHour).toInt val tMins = ((tMilliseconds % msPerHour) / msPerMinute).toInt val tSec = (tMilliseconds % msPerMinute) / 1e3 @@ -249,7 +254,6 @@ package object utils extends Logging else (tib, "TiB") - val num = formatDouble(absds.toDouble / div.toDouble, precision) s"$num $suffix" } @@ -260,12 +264,11 @@ package object utils extends Logging else None - def nullIfNot(p: Boolean, x: Any): Any = { + def nullIfNot(p: Boolean, x: Any): Any = if (p) x else null - } def divOption(num: Double, denom: Double): Option[Double] = someIf(denom != 0, num / denom) @@ -281,13 +284,11 @@ package object utils extends Logging def D_epsilon(a: Double, b: Double, tolerance: Double = defaultTolerance): Double = math.max(java.lang.Double.MIN_NORMAL, tolerance * math.max(math.abs(a), math.abs(b))) - def D_==(a: Double, b: Double, tolerance: Double = defaultTolerance): Boolean = { + def D_==(a: Double, b: Double, tolerance: Double = defaultTolerance): Boolean = a == b || math.abs(a - b) <= D_epsilon(a, b, tolerance) - } - def D_!=(a: Double, b: Double, tolerance: Double = defaultTolerance): Boolean = { + def D_!=(a: Double, b: Double, tolerance: Double = defaultTolerance): Boolean = !(a == b) && math.abs(a - b) > D_epsilon(a, b, tolerance) - } def D_<(a: Double, b: Double, tolerance: Double = defaultTolerance): Boolean = !(a == b) && a - b < -D_epsilon(a, b, tolerance) @@ -328,6 +329,7 @@ package object utils extends Logging def rowIterator(r: Row): Iterator[Any] = new Iterator[Any] { var idx: Int = 0 def hasNext: Boolean = idx < r.size + def next: Any = { val a = r(idx) idx += 1 @@ -340,14 +342,13 @@ package object utils extends Logging .resize(12) .filter(s => !s.isEmpty) - def prettyIdentifier(str: String): String = { + def prettyIdentifier(str: String): String = if (str.matches("""[_a-zA-Z]\w*""")) str else - s"`${ StringEscapeUtils.escapeString(str, backticked = true) }`" - } + s"`${StringEscapeUtils.escapeString(str, backticked = true)}`" - def formatDouble(d: Double, precision: Int): String = s"%.${ precision }f".format(d) + def formatDouble(d: Double, precision: Int): String = s"%.${precision}f".format(d) def uriPath(uri: String): String = new URI(uri).getPath @@ -369,17 +370,30 @@ package object utils extends Logging def mapAccumulate[C[_], U] = mapAccumulateInstance.asInstanceOf[MapAccumulate[C, U]] - /** - * An abstraction for building an {@code Array} of known size. Guarantees a left-to-right traversal + /** An abstraction for building an {@code Array} of known size. 
Guarantees a left-to-right + * traversal * - * @param xs the thing to iterate over - * @param size the size of array to allocate - * @param key given the source value and its source index, yield the target index - * @param combine given the target value, the target index, the source value, and the source index, compute the new target value + * @param xs + * the thing to iterate over + * @param size + * the size of array to allocate + * @param key + * given the source value and its source index, yield the target index + * @param combine + * given the target value, the target index, the source value, and the source index, compute + * the new target value * @tparam A * @tparam B */ - def coalesce[A, B: ClassTag](xs: GenTraversableOnce[A])(size: Int, key: (A, Int) => Int, z: B)(combine: (B, A) => B): Array[B] = { + def coalesce[A, B: ClassTag]( + xs: GenTraversableOnce[A] + )( + size: Int, + key: (A, Int) => Int, + z: B, + )( + combine: (B, A) => B + ): Array[B] = { val a = Array.fill(size)(z) for ((x, idx) <- xs.toIterator.zipWithIndex) { @@ -396,9 +410,8 @@ package object utils extends Logging val newline = System.lineSeparator() val sb = new StringBuilder sb ++= "The maps do not have the same entries:" + newline - for (failure <- failures) { - sb ++= s" At key ${ failure._1 }, the left map has ${ failure._2 } and the right map has ${ failure._3 }" + newline - } + for (failure <- failures) + sb ++= s" At key ${failure._1}, the left map has ${failure._2} and the right map has ${failure._3}" + newline sb ++= s" The left map is: $l" + newline sb ++= s" The right map is: $r" + newline sb.result() @@ -407,11 +420,12 @@ package object utils extends Logging if (l.keySet != r.keySet) { println( s"""The maps do not have the same keys. - | These keys are unique to the left-hand map: ${ l.keySet -- r.keySet } - | These keys are unique to the right-hand map: ${ r.keySet -- l.keySet } + | These keys are unique to the left-hand map: ${l.keySet -- r.keySet} + | These keys are unique to the right-hand map: ${r.keySet -- l.keySet} | The left map is: $l | The right map is: $r - """.stripMargin) + """.stripMargin + ) false } else { val fs = Array.newBuilder[(K, V, V)] @@ -450,9 +464,9 @@ package object utils extends Logging } def lookupMethod(c: Class[_], method: String): Method = { - try { + try c.getDeclaredMethod(method) - } catch { + catch { case _: Exception => assert(c != classOf[java.lang.Object]) lookupMethod(c.getSuperclass, method) @@ -464,15 +478,16 @@ package object utils extends Logging m.invoke(obj, args: _*) } - /* - * Use reflection to get the path of a partition coming from a Parquet read. This requires accessing Spark - * internal interfaces. It works with Spark 1 and 2 and doesn't depend on the location of the Parquet - * package (parquet vs org.apache.parquet) which can vary between distributions. - */ + /* Use reflection to get the path of a partition coming from a Parquet read. This requires + * accessing Spark internal interfaces. It works with Spark 1 and 2 and doesn't depend on the + * location of the Parquet package (parquet vs org.apache.parquet) which can vary between + * distributions. 
*/ def partitionPath(p: Partition): String = { p.getClass.getCanonicalName match { case "org.apache.spark.rdd.SqlNewHadoopPartition" => - val split = invokeMethod(invokeMethod(p, "serializableHadoopSplit"), "value").asInstanceOf[NewFileSplit] + val split = invokeMethod(invokeMethod(p, "serializableHadoopSplit"), "value").asInstanceOf[ + NewFileSplit + ] split.getPath.getName case "org.apache.spark.sql.execution.datasources.FilePartition" => @@ -535,11 +550,11 @@ package object utils extends Logging def roundWithConstantSum(a: Array[Double]): Array[Int] = { val withFloors = a.zipWithIndex.map { case (d, i) => (i, d, math.floor(d)) } - val totalFractional = (withFloors.map { case (i, orig, floor) => orig - floor }.sum + 0.5).toInt + val totalFractional = (withFloors.map { case (_, orig, floor) => orig - floor }.sum + 0.5).toInt withFloors .sortBy { case (_, orig, floor) => floor - orig } .zipWithIndex - .map { case ((i, orig, floor), iSort) => + .map { case ((i, orig, _), iSort) => if (iSort < totalFractional) (i, math.ceil(orig)) else @@ -608,12 +623,13 @@ package object utils extends Logging def partSuffix(ctx: TaskContext): String = { val rng = new java.security.SecureRandom() val fileUUID = new java.util.UUID(rng.nextLong(), rng.nextLong()) - s"${ ctx.stageId() }-${ ctx.partitionId() }-${ ctx.attemptNumber() }-$fileUUID" + s"${ctx.stageId()}-${ctx.partitionId()}-${ctx.attemptNumber()}-$fileUUID" } - def partFile(d: Int, i: Int, ctx: TaskContext): String = s"${ partFile(d, i) }-${ partSuffix(ctx) }" + def partFile(d: Int, i: Int, ctx: TaskContext): String = s"${partFile(d, i)}-${partSuffix(ctx)}" - def mangle(strs: Array[String], formatter: Int => String = "_%d".format(_)): (Array[String], Array[(String, String)]) = { + def mangle(strs: Array[String], formatter: Int => String = "_%d".format(_)) + : (Array[String], Array[(String, String)]) = { val b = new BoxedArrayBuilder[String] val uniques = new mutable.HashSet[String]() @@ -644,18 +660,20 @@ package object utils extends Logging def using[R <: AutoCloseable, T](r: R)(consume: (R) => T): T = { var caught = false - try { + try consume(r) - } catch { + catch { case original: Exception => caught = true - try { + try r.close() - } catch { + catch { case duringClose: Exception => if (original == duringClose) { - log.info(s"""The exact same exception object, ${original}, was thrown by both - |the consumer and the close method. I will throw the original.""".stripMargin) + log.info( + s"""The exact same exception object, $original, was thrown by both + |the consumer and the close method. 
I will throw the original.""".stripMargin + ) throw original } else { duringClose.addSuppressed(original) @@ -663,11 +681,10 @@ package object utils extends Logging } } throw original - } finally { + } finally if (!caught) { r.close() } - } } def singletonElement[T](it: Iterator[T]): T = { @@ -723,13 +740,12 @@ package object utils extends Logging parts } - def matchErrorToNone[T, U](f: (T) => U): (T) => Option[U] = (x: T) => { - try { + def matchErrorToNone[T, U](f: (T) => U): (T) => Option[U] = (x: T) => + try Some(f(x)) - } catch { + catch { case _: MatchError => None } - } def charRegex(c: Char): String = { // See: https://docs.oracle.com/javase/tutorial/essential/regex/literals.html @@ -741,19 +757,17 @@ package object utils extends Logging s } - def ordMax[T](left: T, right: T, ord: ExtendedOrdering): T = { + def ordMax[T](left: T, right: T, ord: ExtendedOrdering): T = if (ord.gt(left, right)) left else right - } - def ordMin[T](left: T, right: T, ord: ExtendedOrdering): T = { + def ordMin[T](left: T, right: T, ord: ExtendedOrdering): T = if (ord.lt(left, right)) left else right - } def makeJavaMap[K, V](x: TraversableOnce[(K, V)]): java.util.HashMap[K, V] = { val m = new java.util.HashMap[K, V] @@ -769,8 +783,9 @@ package object utils extends Logging def toMapFast[T, K, V]( ts: TraversableOnce[T] - )(key: T => K, - value: T => V + )( + key: T => K, + value: T => V, ): collection.Map[K, V] = { val it = ts.toIterator val m = mutable.Map[K, V]() @@ -783,11 +798,12 @@ package object utils extends Logging def toMapIfUnique[K, K2, V]( kvs: Traversable[(K, V)] - )(keyBy: K => K2 + )( + keyBy: K => K2 ): Either[Map[K2, Traversable[K]], Map[K2, V]] = { val grouped = kvs.groupBy(x => keyBy(x._1)) - val dupes = grouped.filter { case (k, m) => m.size != 1 } + val dupes = grouped.filter { case (_, m) => m.size != 1 } if (dupes.nonEmpty) { Left(dupes.map { case (k, m) => k -> m.map(_._1) }) @@ -798,11 +814,11 @@ package object utils extends Logging } } - def dumpClassLoader(cl: ClassLoader) { - System.err.println(s"ClassLoader ${ cl.getClass.getCanonicalName }:") + def dumpClassLoader(cl: ClassLoader): Unit = { + System.err.println(s"ClassLoader ${cl.getClass.getCanonicalName}:") cl match { case cl: URLClassLoader => - System.err.println(s" ${ cl.getURLs.mkString(" ") }") + System.err.println(s" ${cl.getURLs.mkString(" ")}") case _ => System.err.println(" non-URLClassLoader") } @@ -817,52 +833,60 @@ package object utils extends Logging using(new OutputStreamWriter(fs.create(path + "/README.txt"))) { out => out.write( s"""This folder comprises a Hail (www.hail.is) native Table or MatrixTable. 
- | Written with version ${ HailContext.get.version } - | Created at ${ dateFormat.format(new Date()) }""".stripMargin) + | Written with version ${HailContext.get.version} + | Created at ${dateFormat.format(new Date())}""".stripMargin + ) } } - def decompress(input: Array[Byte], size: Int): Array[Byte] = CompressionUtils.decompressZlib(input, size) + def decompress(input: Array[Byte], size: Int): Array[Byte] = + CompressionUtils.decompressZlib(input, size) - def compress(bb: ByteArrayBuilder, input: Array[Byte]): Int = CompressionUtils.compressZlib(bb, input) + def compress(bb: ByteArrayBuilder, input: Array[Byte]): Int = + CompressionUtils.compressZlib(bb, input) - def unwrappedApply[U, T](f: (U, T) => T): (U, Seq[T]) => T = if (f == null) null else { (s, ts) => - f(s, ts(0)) - } + def unwrappedApply[U, T](f: (U, T) => T): (U, Seq[T]) => T = + if (f == null) null else { (s, ts) => f(s, ts(0)) } - def unwrappedApply[U, T](f: (U, T, T) => T): (U, Seq[T]) => T = if (f == null) null else { (s, ts) => + def unwrappedApply[U, T](f: (U, T, T) => T): (U, Seq[T]) => T = if (f == null) null + else { (s, ts) => val Seq(t1, t2) = ts f(s, t1, t2) } - def unwrappedApply[U, T](f: (U, T, T, T) => T): (U, Seq[T]) => T = if (f == null) null else { (s, ts) => + def unwrappedApply[U, T](f: (U, T, T, T) => T): (U, Seq[T]) => T = if (f == null) null + else { (s, ts) => val Seq(t1, t2, t3) = ts f(s, t1, t2, t3) } - def unwrappedApply[U, T](f: (U, T, T, T, T) => T): (U, Seq[T]) => T = if (f == null) null else { (s, ts) => + def unwrappedApply[U, T](f: (U, T, T, T, T) => T): (U, Seq[T]) => T = if (f == null) null + else { (s, ts) => val Seq(t1, t2, t3, t4) = ts f(s, t1, t2, t3, t4) } - def unwrappedApply[U, T](f: (U, T, T, T, T, T) => T): (U, Seq[T]) => T = if (f == null) null else { (s, ts) => + def unwrappedApply[U, T](f: (U, T, T, T, T, T) => T): (U, Seq[T]) => T = if (f == null) null + else { (s, ts) => val Seq(t1, t2, t3, t4, t5) = ts f(s, t1, t2, t3, t4, t5) } - def unwrappedApply[U, T](f: (U, T, T, T, T, T, T) => T): (U, Seq[T]) => T = if (f == null) null else { (s, ts) => + def unwrappedApply[U, T](f: (U, T, T, T, T, T, T) => T): (U, Seq[T]) => T = if (f == null) null + else { (s, ts) => val Seq(arg1, arg2, arg3, arg4, arg5, arg6) = ts f(s, arg1, arg2, arg3, arg4, arg5, arg6) } - def unwrappedApply[U, T](f: (U, T, T, T, T, T, T, T) => T): (U, Seq[T]) => T = if (f == null) null else { (s, ts) => + def unwrappedApply[U, T](f: (U, T, T, T, T, T, T, T) => T): (U, Seq[T]) => T = if (f == null) null + else { (s, ts) => val Seq(arg1, arg2, arg3, arg4, arg5, arg6, arg7) = ts f(s, arg1, arg2, arg3, arg4, arg5, arg6, arg7) } def drainInputStreamToOutputStream( is: InputStream, - os: OutputStream + os: OutputStream, ): Unit = { val buffer = new Array[Byte](1024) var length = is.read(buffer) @@ -910,13 +934,11 @@ package object utils extends Logging (fileOffset << 16) | blockOffset } - def virtualOffsetBlockOffset(offset: Long): Int = { - (offset & 0xFFFF).toInt - } + def virtualOffsetBlockOffset(offset: Long): Int = + (offset & 0xffff).toInt - def virtualOffsetCompressedOffset(offset: Long): Long = { + def virtualOffsetCompressedOffset(offset: Long): Long = offset >> 16 - } def tokenUrlSafe(n: Int): String = { val bytes = new Array[Byte](32) @@ -926,7 +948,12 @@ package object utils extends Logging } // mutates byteOffsets and returns the byte size - def getByteSizeAndOffsets(byteSize: Array[Long], alignment: Array[Long], nMissingBytes: Long, byteOffsets: Array[Long]): Long = { + def getByteSizeAndOffsets( + 
byteSize: Array[Long], + alignment: Array[Long], + nMissingBytes: Long, + byteOffsets: Array[Long], + ): Long = { assert(byteSize.length == alignment.length) assert(byteOffsets.length == byteSize.length) val bp = new BytePacker() @@ -953,15 +980,12 @@ package object utils extends Logging offset } - /** - * Merge the sorted `IndexedSeq`s `xs` and `ys` using comparison function `lt`. - */ + /** Merge the sorted `IndexedSeq`s `xs` and `ys` using comparison function `lt`. */ def merge[A](xs: IndexedSeq[A], ys: IndexedSeq[A], lt: (A, A) => Boolean): IndexedSeq[A] = (xs.length, ys.length) match { case (0, _) => ys case (_, 0) => xs case (n, m) => - val res = new ArrayBuffer[A](n + m) var i = 0 @@ -976,28 +1000,27 @@ package object utils extends Logging } } - for (k <- i until n) { + for (k <- i until n) res += xs(k) - } - for (k <- j until m) { + for (k <- j until m) res += ys(k) - } res } - - /** - * Run (task, key) pairs on the `executor`, returning some `F` of the - * failures and an `IndexedSeq` of the successes with their corresponding - * key. - */ - def runAll[F[_], A](executor: ExecutorService) - (accum: (F[Throwable], (Throwable, Int)) => F[Throwable]) - (init: F[Throwable]) - (tasks: IndexedSeq[(() => A, Int)]) - : (F[Throwable], IndexedSeq[(A, Int)]) = { + /** Run (task, key) pairs on the `executor`, returning some `F` of the failures and an + * `IndexedSeq` of the successes with their corresponding key. + */ + def runAll[F[_], A]( + executor: ExecutorService + )( + accum: (F[Throwable], (Throwable, Int)) => F[Throwable] + )( + init: F[Throwable] + )( + tasks: IndexedSeq[(() => A, Int)] + ): (F[Throwable], IndexedSeq[(A, Int)]) = { var err = init val buffer = new mutable.ArrayBuffer[(A, Int)](tasks.length) @@ -1020,7 +1043,7 @@ package object utils extends Logging def runAllKeepFirstError[A]( executor: ExecutorService ): IndexedSeq[(() => A, Int)] => (Option[Throwable], IndexedSeq[(A, Int)]) = - runAll[Option, A](executor) { case (opt, (e, _)) => opt.orElse(Some(e)) } (None) + runAll[Option, A](executor) { case (opt, (e, _)) => opt.orElse(Some(e)) }(None) } // FIXME: probably resolved in 3.6 https://github.com/json4s/json4s/commit/fc96a92e1aa3e9e3f97e2e91f94907fdfff6010d @@ -1033,11 +1056,13 @@ object GenericIndexedSeqSerializer extends Serializer[IndexedSeq[_]] { override def deserialize(implicit format: Formats) = { case (TypeInfo(IndexedSeqClass, parameterizedType), JArray(xs)) => - val typeInfo = TypeInfo(parameterizedType - .map(_.getActualTypeArguments()(0)) - .getOrElse(throw new RuntimeException("No type parameter info for type IndexedSeq")) - .asInstanceOf[Class[_]], - None) + val typeInfo = TypeInfo( + parameterizedType + .map(_.getActualTypeArguments()(0)) + .getOrElse(throw new RuntimeException("No type parameter info for type IndexedSeq")) + .asInstanceOf[Class[_]], + None, + ) xs.map(x => Extraction.extract(x, typeInfo)).toArray[Any] } } diff --git a/hail/src/main/scala/is/hail/utils/prettyPrint/ArrayOfByteArrayInputStream.scala b/hail/src/main/scala/is/hail/utils/prettyPrint/ArrayOfByteArrayInputStream.scala index b04b45351ba..1526bf96b7a 100644 --- a/hail/src/main/scala/is/hail/utils/prettyPrint/ArrayOfByteArrayInputStream.scala +++ b/hail/src/main/scala/is/hail/utils/prettyPrint/ArrayOfByteArrayInputStream.scala @@ -17,8 +17,7 @@ class ArrayOfByteArrayInputStream(bytes: Array[Array[Byte]]) extends InputStream val readByte = byteInputStreams(currentInputStreamIdx).read() if (readByte == -1) { currentInputStreamIdx += 1 - } - else { + } else { foundByte = true 
byteToReturn = readByte } @@ -31,11 +30,12 @@ class ArrayOfByteArrayInputStream(bytes: Array[Array[Byte]]) extends InputStream var numBytesRead = 0 var moreToRead = true - while(numBytesRead < len && moreToRead) { + while (numBytesRead < len && moreToRead) { if (currentInputStreamIdx == byteInputStreams.length) { moreToRead = false } else { - val bytesReadInOneCall = byteInputStreams(currentInputStreamIdx).read(b, off + numBytesRead, len - numBytesRead) + val bytesReadInOneCall = + byteInputStreams(currentInputStreamIdx).read(b, off + numBytesRead, len - numBytesRead) if (bytesReadInOneCall == -1) { currentInputStreamIdx += 1 } else { diff --git a/hail/src/main/scala/is/hail/utils/prettyPrint/PrettyPrintWriter.scala b/hail/src/main/scala/is/hail/utils/prettyPrint/PrettyPrintWriter.scala index d739f746dcb..ba9fa1c1147 100644 --- a/hail/src/main/scala/is/hail/utils/prettyPrint/PrettyPrintWriter.scala +++ b/hail/src/main/scala/is/hail/utils/prettyPrint/PrettyPrintWriter.scala @@ -1,12 +1,12 @@ package is.hail.utils.prettyPrint -import java.io.{StringWriter, Writer} -import java.util.ArrayDeque - import is.hail.utils.BoxedArrayBuilder import scala.annotation.tailrec +import java.io.{StringWriter, Writer} +import java.util.ArrayDeque + object Doc { def render(doc: Doc, width: Int, ribbonWidth: Int, _maxLines: Int, out: Writer): Unit = { // All groups whose formatting is still undetermined. The innermost group is at the end. @@ -45,7 +45,9 @@ object Doc { printNode(node, false) } else { pendingGroups.getLast.contents += node - while (!pendingGroups.isEmpty && globalPos - pendingGroups.getFirst.start > remainingInLine) { + while ( + !pendingGroups.isEmpty && globalPos - pendingGroups.getFirst.start > remainingInLine + ) { val head = pendingGroups.removeFirst() head.end = globalPos printNode(head, false) @@ -119,12 +121,11 @@ object Doc { pendingCloses -= 1 } - def openGroups(): Unit = { + def openGroups(): Unit = while (pendingOpens > 0) { pendingGroups.addLast(GroupN(new BoxedArrayBuilder[ScannedNode](), globalPos, -1)) pendingOpens -= 1 } - } try { while (currentNode != null) { @@ -153,7 +154,7 @@ object Doc { closeGroups() } catch { case _: MaxLinesExceeded => - // 'maxLines' have been printed, so break out of the loop and stop printing. + // 'maxLines' have been printed, so break out of the loop and stop printing. 
} } } @@ -175,14 +176,19 @@ private[prettyPrint] case class Group(body: Doc) extends Doc private[prettyPrint] case class Indent(i: Int, body: Doc) extends Doc private[prettyPrint] case class Concat(it: Iterable[Doc]) extends Doc -private[prettyPrint] abstract class ScannedNode +abstract private[prettyPrint] class ScannedNode private[prettyPrint] case class TextN(t: String) extends ScannedNode private[prettyPrint] case class LineN(indentation: Int, ifFlat: String) extends ScannedNode -private[prettyPrint] case class GroupN(contents: BoxedArrayBuilder[ScannedNode], start: Int, var end: Int) extends ScannedNode -private[prettyPrint] abstract class KontNode +private[prettyPrint] case class GroupN( + contents: BoxedArrayBuilder[ScannedNode], + start: Int, + var end: Int, +) extends ScannedNode + +abstract private[prettyPrint] class KontNode private[prettyPrint] case object PopGroupK extends KontNode private[prettyPrint] case class UnindentK(indent: Int) extends KontNode private[prettyPrint] case class ConcatK(kont: Iterator[Doc]) extends KontNode -private[prettyPrint] class MaxLinesExceeded() extends Exception \ No newline at end of file +private[prettyPrint] class MaxLinesExceeded() extends Exception diff --git a/hail/src/main/scala/is/hail/utils/richUtils/ByteTrackingOutputStream.scala b/hail/src/main/scala/is/hail/utils/richUtils/ByteTrackingOutputStream.scala index 9a78158c06a..d8d89d4c99c 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/ByteTrackingOutputStream.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/ByteTrackingOutputStream.scala @@ -5,17 +5,17 @@ import java.io.OutputStream class ByteTrackingOutputStream(base: OutputStream) extends OutputStream { var bytesWritten = 0L - def write(c: Int) { + def write(c: Int): Unit = { bytesWritten += 1 base.write(c) } - override def write(b: Array[Byte]) { + override def write(b: Array[Byte]): Unit = { base.write(b) bytesWritten += b.length } - override def write(b: Array[Byte], off: Int, len: Int) { + override def write(b: Array[Byte], off: Int, len: Int): Unit = { base.write(b, off, len) bytesWritten += len } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/Implicits.scala b/hail/src/main/scala/is/hail/utils/richUtils/Implicits.scala index 6cf09ecff26..a6870f50b56 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/Implicits.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/Implicits.scala @@ -1,49 +1,55 @@ package is.hail.utils.richUtils -import java.io.InputStream - -import breeze.linalg.DenseMatrix -import is.hail.annotations.{JoinedRegionValue, Region, RegionValue, RegionValueBuilder} +import is.hail.annotations.{JoinedRegionValue, Region, RegionValue} import is.hail.asm4s.{Code, Value} -import is.hail.io.{InputBuffer, OutputBuffer, RichContextRDDRegionValue, RichContextRDDLong} -import is.hail.rvd.RVDContext +import is.hail.io.{InputBuffer, OutputBuffer, RichContextRDDLong, RichContextRDDRegionValue} import is.hail.sparkextras._ import is.hail.utils.{HailIterator, MultiArray2, Truncatable, WithContext} -import org.apache.spark.SparkContext -import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import scala.collection.{TraversableOnce, mutable} +import scala.collection.{mutable, TraversableOnce} import scala.language.implicitConversions import scala.reflect.ClassTag import scala.util.matching.Regex +import java.io.InputStream + +import breeze.linalg.DenseMatrix +import org.apache.spark.SparkContext +import 
org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row + trait Implicits { implicit def toRichArray[T](a: Array[T]): RichArray[T] = new RichArray(a) implicit def toRichIndexedSeq[T](s: IndexedSeq[T]): RichIndexedSeq[T] = new RichIndexedSeq(s) - implicit def toRichIndexedSeqAnyRef[T <: AnyRef](s: IndexedSeq[T]): RichIndexedSeqAnyRef[T] = new RichIndexedSeqAnyRef(s) + implicit def toRichIndexedSeqAnyRef[T <: AnyRef](s: IndexedSeq[T]): RichIndexedSeqAnyRef[T] = + new RichIndexedSeqAnyRef(s) implicit def arrayToRichIndexedSeq[T](s: Array[T]): RichIndexedSeq[T] = new RichIndexedSeq(s) implicit def toRichBoolean(b: Boolean): RichBoolean = new RichBoolean(b) - implicit def toRichDenseMatrixDouble(m: DenseMatrix[Double]): RichDenseMatrixDouble = new RichDenseMatrixDouble(m) + implicit def toRichDenseMatrixDouble(m: DenseMatrix[Double]): RichDenseMatrixDouble = + new RichDenseMatrixDouble(m) - implicit def toRichEnumeration[T <: Enumeration](e: T): RichEnumeration[T] = new RichEnumeration(e) + implicit def toRichEnumeration[T <: Enumeration](e: T): RichEnumeration[T] = + new RichEnumeration(e) - implicit def toRichIndexedRowMatrix(irm: IndexedRowMatrix): RichIndexedRowMatrix = new RichIndexedRowMatrix(irm) + implicit def toRichIndexedRowMatrix(irm: IndexedRowMatrix): RichIndexedRowMatrix = + new RichIndexedRowMatrix(irm) - implicit def toRichIntPairTraversableOnce[V](t: TraversableOnce[(Int, V)]): RichIntPairTraversableOnce[V] = + implicit def toRichIntPairTraversableOnce[V](t: TraversableOnce[(Int, V)]) + : RichIntPairTraversableOnce[V] = new RichIntPairTraversableOnce[V](t) implicit def toRichIterable[T](i: Iterable[T]): RichIterable[T] = new RichIterable(i) implicit def toRichIterable[T](a: Array[T]): RichIterable[T] = new RichIterable(a) - implicit def toRichContextIterator[T](it: Iterator[WithContext[T]]): RichContextIterator[T] = new RichContextIterator[T](it) + implicit def toRichContextIterator[T](it: Iterator[WithContext[T]]): RichContextIterator[T] = + new RichContextIterator[T](it) implicit def toRichIterator[T](it: Iterator[T]): RichIterator[T] = new RichIterator[T](it) @@ -53,28 +59,36 @@ trait Implicits { implicit def toRichMap[K, V](m: Map[K, V]): RichMap[K, V] = new RichMap(m) - implicit def toRichMultiArray2Long(ma: MultiArray2[Long]): RichMultiArray2Long = new RichMultiArray2Long(ma) + implicit def toRichMultiArray2Long(ma: MultiArray2[Long]): RichMultiArray2Long = + new RichMultiArray2Long(ma) - implicit def toRichMultiArray2Int(ma: MultiArray2[Int]): RichMultiArray2Int = new RichMultiArray2Int(ma) + implicit def toRichMultiArray2Int(ma: MultiArray2[Int]): RichMultiArray2Int = + new RichMultiArray2Int(ma) - implicit def toRichMultiArray2Double(ma: MultiArray2[Double]): RichMultiArray2Double = new RichMultiArray2Double(ma) + implicit def toRichMultiArray2Double(ma: MultiArray2[Double]): RichMultiArray2Double = + new RichMultiArray2Double(ma) - implicit def toRichMutableMap[K, V](m: mutable.Map[K, V]): RichMutableMap[K, V] = new RichMutableMap(m) + implicit def toRichMutableMap[K, V](m: mutable.Map[K, V]): RichMutableMap[K, V] = + new RichMutableMap(m) implicit def toRichOption[T](o: Option[T]): RichOption[T] = new RichOption[T](o) - implicit def toRichOrderedArray[T: Ordering](a: Array[T]): RichOrderedArray[T] = new RichOrderedArray(a) + implicit def toRichOrderedArray[T: Ordering](a: Array[T]): RichOrderedArray[T] = + new RichOrderedArray(a) - implicit def toRichOrderedSeq[T: Ordering](s: Seq[T]): 
RichOrderedSeq[T] = new RichOrderedSeq[T](s) + implicit def toRichOrderedSeq[T: Ordering](s: Seq[T]): RichOrderedSeq[T] = + new RichOrderedSeq[T](s) - implicit def toRichPairRDD[K, V](r: RDD[(K, V)])(implicit kct: ClassTag[K], - vct: ClassTag[V]): RichPairRDD[K, V] = new RichPairRDD(r) + implicit def toRichPairRDD[K, V](r: RDD[(K, V)])(implicit kct: ClassTag[K], vct: ClassTag[V]) + : RichPairRDD[K, V] = new RichPairRDD(r) implicit def toRichRDD[T](r: RDD[T])(implicit tct: ClassTag[T]): RichRDD[T] = new RichRDD(r) - implicit def toRichContextRDDRegionValue(r: ContextRDD[RegionValue]): RichContextRDDRegionValue = new RichContextRDDRegionValue(r) + implicit def toRichContextRDDRegionValue(r: ContextRDD[RegionValue]): RichContextRDDRegionValue = + new RichContextRDDRegionValue(r) - implicit def toRichContextRDDLong(r: ContextRDD[Long]): RichContextRDDLong = new RichContextRDDLong(r) + implicit def toRichContextRDDLong(r: ContextRDD[Long]): RichContextRDDLong = + new RichContextRDDLong(r) implicit def toRichRegex(r: Regex): RichRegex = new RichRegex(r) @@ -84,7 +98,8 @@ trait Implicits { implicit def toRichString(str: String): RichString = new RichString(str) - implicit def toRichStringBuilder(sb: mutable.StringBuilder): RichStringBuilder = new RichStringBuilder(sb) + implicit def toRichStringBuilder(sb: mutable.StringBuilder): RichStringBuilder = + new RichStringBuilder(sb) implicit def toTruncatable(s: String): Truncatable = s.truncatable() @@ -92,10 +107,11 @@ trait Implicits { implicit def toTruncatable(arr: Array[_]): Truncatable = toTruncatable(arr: Iterable[_]) - implicit def toHailIteratorDouble(it: HailIterator[Int]): HailIterator[Double] = new HailIterator[Double] { - override def next(): Double = it.next().toDouble - override def hasNext: Boolean = it.hasNext - } + implicit def toHailIteratorDouble(it: HailIterator[Int]): HailIterator[Double] = + new HailIterator[Double] { + override def next(): Double = it.next().toDouble + override def hasNext: Boolean = it.hasNext + } implicit def toRichInputStream(in: InputStream): RichInputStream = new RichInputStream(in) @@ -106,19 +122,26 @@ trait Implicits { implicit def toRichCodeRegion(r: Code[Region]): RichCodeRegion = new RichCodeRegion(r) - implicit def toRichPartialKleisliOptionFunction[A, B](x: PartialFunction[A, Option[B]]): RichPartialKleisliOptionFunction[A, B] = new RichPartialKleisliOptionFunction(x) + implicit def toRichPartialKleisliOptionFunction[A, B](x: PartialFunction[A, Option[B]]) + : RichPartialKleisliOptionFunction[A, B] = new RichPartialKleisliOptionFunction(x) - implicit def toContextPairRDDFunctions[K: ClassTag, V: ClassTag](x: ContextRDD[(K, V)]): ContextPairRDDFunctions[K, V] = new ContextPairRDDFunctions(x) + implicit def toContextPairRDDFunctions[K: ClassTag, V: ClassTag](x: ContextRDD[(K, V)]) + : ContextPairRDDFunctions[K, V] = new ContextPairRDDFunctions(x) - implicit def toRichContextRDD[T: ClassTag](x: ContextRDD[T]): RichContextRDD[T] = new RichContextRDD(x) + implicit def toRichContextRDD[T: ClassTag](x: ContextRDD[T]): RichContextRDD[T] = + new RichContextRDD(x) implicit def toRichContextRDDRow(x: ContextRDD[Row]): RichContextRDDRow = new RichContextRDDRow(x) - implicit def valueToRichCodeInputBuffer(in: Value[InputBuffer]): RichCodeInputBuffer = new RichCodeInputBuffer(in) + implicit def valueToRichCodeInputBuffer(in: Value[InputBuffer]): RichCodeInputBuffer = + new RichCodeInputBuffer(in) - implicit def valueToRichCodeOutputBuffer(out: Value[OutputBuffer]): RichCodeOutputBuffer = new 
RichCodeOutputBuffer(out) + implicit def valueToRichCodeOutputBuffer(out: Value[OutputBuffer]): RichCodeOutputBuffer = + new RichCodeOutputBuffer(out) - implicit def toRichCodeIterator[T](it: Code[Iterator[T]]): RichCodeIterator[T] = new RichCodeIterator[T](it) + implicit def toRichCodeIterator[T](it: Code[Iterator[T]]): RichCodeIterator[T] = + new RichCodeIterator[T](it) - implicit def valueToRichCodeIterator[T](it: Value[Iterator[T]]): RichCodeIterator[T] = new RichCodeIterator[T](it) + implicit def valueToRichCodeIterator[T](it: Value[Iterator[T]]): RichCodeIterator[T] = + new RichCodeIterator[T](it) } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichArray.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichArray.scala index f0e99fee343..7f3d93760d0 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichArray.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichArray.scala @@ -1,26 +1,24 @@ package is.hail.utils.richUtils -import is.hail.io.fs.FS -import is.hail.HailContext import is.hail.io.{DoubleInputBuffer, DoubleOutputBuffer} +import is.hail.io.fs.FS import is.hail.utils._ object RichArray { val defaultBufSize: Int = 4096 << 3 - + def importFromDoubles(fs: FS, path: String, n: Int): Array[Double] = { val a = new Array[Double](n) importFromDoubles(fs, path, a, defaultBufSize) a } - - def importFromDoubles(fs: FS, path: String, a: Array[Double], bufSize: Int): Unit = { + + def importFromDoubles(fs: FS, path: String, a: Array[Double], bufSize: Int): Unit = using(fs.open(path)) { is => val in = new DoubleInputBuffer(is, bufSize) in.readDoubles(a) } - } def exportToDoubles(fs: FS, path: String, a: Array[Double]): Unit = exportToDoubles(fs, path, a, defaultBufSize) diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichCodeInputBuffer.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichCodeInputBuffer.scala index 556f6fbbd9e..1fec9f5f73c 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichCodeInputBuffer.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichCodeInputBuffer.scala @@ -72,9 +72,6 @@ class RichCodeInputBuffer( def readBoolean(): Code[Boolean] = ib.invoke[Boolean]("readBoolean") - def readUTF(): Code[String] = - ib.invoke[String]("readUTF") - def readBytes(toRegion: Value[Region], toOff: Code[Long], n: Int): Code[Unit] = { if (n == 0) Code._empty @@ -82,7 +79,8 @@ class RichCodeInputBuffer( Code.memoize(toOff, "ib_ready_bytes_to") { toOff => Code.memoize(ib, "ib_ready_bytes_in") { ib => Code((0 until n).map(i => - Region.storeByte(toOff.get + i.toLong, ib.readByte()))) + Region.storeByte(toOff.get + i.toLong, ib.readByte()) + )) } } else diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichCodeIterator.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichCodeIterator.scala index 0455a86b62b..d091eeab5eb 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichCodeIterator.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichCodeIterator.scala @@ -4,6 +4,7 @@ import is.hail.asm4s.{Code, TypeInfo} class RichCodeIterator[T](it: Code[Iterator[T]]) { def hasNext: Code[Boolean] = it.invoke[Boolean]("hasNext") + def next()(implicit tti: TypeInfo[T]): Code[T] = Code.checkcast[T](it.invoke[java.lang.Object]("next")) } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichCodeOutputBuffer.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichCodeOutputBuffer.scala index 7f8eedd6744..539da999918 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichCodeOutputBuffer.scala +++ 
b/hail/src/main/scala/is/hail/utils/richUtils/RichCodeOutputBuffer.scala @@ -3,9 +3,8 @@ package is.hail.utils.richUtils import is.hail.annotations.Region import is.hail.asm4s._ import is.hail.expr.ir.EmitCodeBuilder -import is.hail.types.physical._ import is.hail.io.OutputBuffer -import is.hail.types.physical.stypes.{SCode, SValue} +import is.hail.types.physical.stypes.SValue import is.hail.types.virtual._ class RichCodeOutputBuffer( diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichCodeRegion.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichCodeRegion.scala index 02cedc9aaae..1a203f62611 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichCodeRegion.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichCodeRegion.scala @@ -8,16 +8,16 @@ class RichCodeRegion(val region: Code[Region]) extends AnyVal { def allocate(alignment: Code[Long], n: Code[Long]): Code[Long] = region.invoke[Long, Long, Long]("allocate", alignment, n) - def clearRegion(): Code[Unit] = { + def clearRegion(): Code[Unit] = region.invoke[Unit]("clear") - } def getMemory(): Code[RegionMemory] = region.invoke[RegionMemory]("getMemory") def trackAndIncrementReferenceCountOf(other: Code[Region]): Code[Unit] = region.invoke[Region, Unit]("addReferenceTo", other) - def takeOwnershipOfAndClear(other: Code[Region]): Code[Unit] = other.invoke[Region, Unit]("move", region) + def takeOwnershipOfAndClear(other: Code[Region]): Code[Unit] = + other.invoke[Region, Unit]("move", region) def setNumParents(n: Code[Int]): Code[Unit] = region.invoke[Int, Unit]("setNumParents", n) @@ -38,11 +38,14 @@ class RichCodeRegion(val region: Code[Region]) extends AnyVal { def invalidate(): Code[Unit] = region.invoke[Unit]("invalidate") - def getNewRegion(blockSize: Code[Int]): Code[Unit] = region.invoke[Int, Unit]("getNewRegion", blockSize) + def getNewRegion(blockSize: Code[Int]): Code[Unit] = + region.invoke[Int, Unit]("getNewRegion", blockSize) - def storeJavaObject(obj: Code[AnyRef]): Code[Int] = region.invoke[AnyRef, Int]("storeJavaObject", obj) + def storeJavaObject(obj: Code[AnyRef]): Code[Int] = + region.invoke[AnyRef, Int]("storeJavaObject", obj) - def lookupJavaObject(idx: Code[Int]): Code[AnyRef] = region.invoke[Int, AnyRef]("lookupJavaObject", idx) + def lookupJavaObject(idx: Code[Int]): Code[AnyRef] = + region.invoke[Int, AnyRef]("lookupJavaObject", idx) def getPool(): Code[RegionPool] = region.invoke[RegionPool]("getPool") diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichContextIterator.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichContextIterator.scala index ced6e21993e..3e77555f642 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichContextIterator.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichContextIterator.scala @@ -5,7 +5,6 @@ import is.hail.utils.WithContext class RichContextIterator[T](val i: Iterator[WithContext[T]]) { def mapLines[U](f: T => U): Iterator[U] = i.map(_.map(f).value) - def foreachLine(f: T => Unit) { + def foreachLine(f: T => Unit): Unit = i.foreach(_.foreach(f)) - } } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichContextRDD.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichContextRDD.scala index 27d69292413..c9c1fe310af 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichContextRDD.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichContextRDD.scala @@ -1,28 +1,36 @@ package is.hail.utils.richUtils -import java.io._ -import is.hail.HailContext -import is.hail.annotations.{Region, RegionPool} -import 
is.hail.backend.{ExecuteContext, HailTaskContext} -import is.hail.backend.spark.SparkTaskContext +import is.hail.annotations.RegionPool +import is.hail.backend.ExecuteContext import is.hail.io.FileWriteMetadata import is.hail.io.fs.FS import is.hail.io.index.IndexWriter import is.hail.rvd.RVDContext -import is.hail.utils._ import is.hail.sparkextras._ -import org.apache.hadoop.conf.{Configuration => HadoopConf} -import org.apache.spark.TaskContext -import org.apache.spark.rdd.RDD +import is.hail.utils._ import scala.reflect.ClassTag +import java.io._ + +import org.apache.spark.TaskContext + object RichContextRDD { - def writeParts[T](ctx: RVDContext, rootPath: String, f:String, idxRelPath: String, mkIdxWriter: (String, RegionPool) => IndexWriter, - stageLocally: Boolean, fs: FS, localTmpdir: String, it: Iterator[T], - write: (RVDContext, Iterator[T], OutputStream, IndexWriter) => (Long, Long)): Iterator[FileWriteMetadata] = { + def writeParts[T]( + ctx: RVDContext, + rootPath: String, + f: String, + idxRelPath: String, + mkIdxWriter: (String, RegionPool) => IndexWriter, + stageLocally: Boolean, + fs: FS, + localTmpdir: String, + it: Iterator[T], + write: (RVDContext, Iterator[T], OutputStream, IndexWriter) => (Long, Long), + ): Iterator[FileWriteMetadata] = { val finalFilename = rootPath + "/parts/" + f - val finalIdxFilename = if (idxRelPath != null) rootPath + "/" + idxRelPath + "/" + f + ".idx" else null + val finalIdxFilename = + if (idxRelPath != null) rootPath + "/" + idxRelPath + "/" + f + ".idx" else null val (filename, idxFilename) = if (stageLocally) { val context = TaskContext.get @@ -57,7 +65,7 @@ class RichContextRDD[T: ClassTag](crdd: ContextRDD[T]) { def cleanupRegions: ContextRDD[T] = { crdd.cmapPartitionsAndContext { (ctx, part) => - val it = part.flatMap(_ (ctx)) + val it = part.flatMap(_(ctx)) new Iterator[T]() { private[this] var cleared: Boolean = false @@ -80,8 +88,6 @@ class RichContextRDD[T: ClassTag](crdd: ContextRDD[T]) { } } - - // If idxPath is null, then mkIdxWriter should return null and not read its string argument def writePartitions( ctx: ExecuteContext, @@ -89,7 +95,7 @@ class RichContextRDD[T: ClassTag](crdd: ContextRDD[T]) { idxRelPath: String, stageLocally: Boolean, mkIdxWriter: (String, RegionPool) => IndexWriter, - write: (RVDContext, Iterator[T], OutputStream, IndexWriter) => (Long, Long) + write: (RVDContext, Iterator[T], OutputStream, IndexWriter) => (Long, Long), ): Array[FileWriteMetadata] = { val localTmpdir = ctx.localTmpdir val fs = ctx.fs @@ -104,7 +110,8 @@ class RichContextRDD[T: ClassTag](crdd: ContextRDD[T]) { val fileData = crdd.cmapPartitionsWithIndex { (i, ctx, it) => val f = partFile(d, i, TaskContext.get) - RichContextRDD.writeParts(ctx, path, f, idxRelPath, mkIdxWriter, stageLocally, fs, localTmpdir, it, write) + RichContextRDD.writeParts(ctx, path, f, idxRelPath, mkIdxWriter, stageLocally, fs, + localTmpdir, it, write) } .collect() diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichContextRDDRow.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichContextRDDRow.scala index 4480363049a..c075eda9a20 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichContextRDDRow.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichContextRDDRow.scala @@ -1,14 +1,12 @@ package is.hail.utils.richUtils -import is.hail.annotations.RegionValue -import is.hail.types.physical.PStruct -import is.hail.rvd.RVDContext import is.hail.sparkextras.ContextRDD +import is.hail.types.physical.PStruct import is.hail.utils._ + import 
org.apache.spark.sql.Row class RichContextRDDRow(crdd: ContextRDD[Row]) { - def toRegionValues(rowType: PStruct): ContextRDD[Long] = { + def toRegionValues(rowType: PStruct): ContextRDD[Long] = crdd.cmapPartitions((ctx, it) => it.copyToRegion(ctx.region, rowType)) - } } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichDenseMatrixDouble.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichDenseMatrixDouble.scala index 4a5516a940b..b23d1266a0d 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichDenseMatrixDouble.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichDenseMatrixDouble.scala @@ -1,17 +1,18 @@ package is.hail.utils.richUtils -import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStream} - -import breeze.linalg.{DenseMatrix => BDM} +import is.hail.io._ import is.hail.io.fs.FS -import is.hail.HailContext import is.hail.linalg.{BlockMatrix, BlockMatrixMetadata, GridPartitioner} -import is.hail.io._ import is.hail.utils._ + +import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStream} + +import breeze.linalg.{DenseMatrix => BDM} import org.json4s.jackson object RichDenseMatrixDouble { - def apply(nRows: Int, nCols: Int, data: Array[Double], isTranspose: Boolean = false): BDM[Double] = { + def apply(nRows: Int, nCols: Int, data: Array[Double], isTranspose: Boolean = false) + : BDM[Double] = { require(data.length == nRows * nCols) new BDM[Double]( @@ -20,7 +21,8 @@ object RichDenseMatrixDouble { data = data, offset = 0, majorStride = if (isTranspose) nCols else nRows, - isTranspose = isTranspose) + isTranspose = isTranspose, + ) } // assumes data isCompact, caller must close @@ -34,15 +36,21 @@ object RichDenseMatrixDouble { val data = new Array[Double](rows * cols) in.readDoubles(data) - new BDM[Double](rows, cols, data, - offset = 0, majorStride = if (isTranspose) cols else rows, isTranspose = isTranspose) + new BDM[Double]( + rows, + cols, + data, + offset = 0, + majorStride = if (isTranspose) cols else rows, + isTranspose = isTranspose, + ) } - def read(fs: FS, path: String, bufferSpec: BufferSpec): BDM[Double] = { + def read(fs: FS, path: String, bufferSpec: BufferSpec): BDM[Double] = using(new DataInputStream(fs.open(path)))(is => read(is, bufferSpec)) - } - def importFromDoubles(fs: FS, path: String, nRows: Int, nCols: Int, rowMajor: Boolean): BDM[Double] = { + def importFromDoubles(fs: FS, path: String, nRows: Int, nCols: Int, rowMajor: Boolean) + : BDM[Double] = { require(nRows * nCols.toLong <= Int.MaxValue) val data = RichArray.importFromDoubles(fs, path, nRows * nCols) @@ -52,7 +60,7 @@ object RichDenseMatrixDouble { def exportToDoubles(fs: FS, path: String, m: BDM[Double], forceRowMajor: Boolean): Boolean = { val (data, rowMajor) = m.toCompactData(forceRowMajor) assert(data.length == m.rows * m.cols) - + RichArray.exportToDoubles(fs, path, data) rowMajor @@ -62,12 +70,14 @@ object RichDenseMatrixDouble { class RichDenseMatrixDouble(val m: BDM[Double]) extends AnyVal { // dot is overloaded in Breeze def matrixMultiply(bm: BlockMatrix): BlockMatrix = { - require(m.cols == bm.nRows, - s"incompatible matrix dimensions: ${ m.rows } x ${ m.cols } and ${ bm.nRows } x ${ bm.nCols } ") + require( + m.cols == bm.nRows, + s"incompatible matrix dimensions: ${m.rows} x ${m.cols} and ${bm.nRows} x ${bm.nCols} ", + ) BlockMatrix.fromBreezeMatrix(m, bm.blockSize).dot(bm) } - def forceSymmetry() { + def forceSymmetry(): Unit = { require(m.rows == m.cols, "only square matrices can be made symmetric") var i = 0 @@ -93,7 +103,7 
@@ class RichDenseMatrixDouble(val m: BDM[Double]) extends AnyVal { } // caller must close - def write(os: OutputStream, forceRowMajor: Boolean, bufferSpec: BufferSpec) { + def write(os: OutputStream, forceRowMajor: Boolean, bufferSpec: BufferSpec): Unit = { val (data, isTranspose) = m.toCompactData(forceRowMajor) assert(data.length == m.rows * m.cols) @@ -106,16 +116,20 @@ class RichDenseMatrixDouble(val m: BDM[Double]) extends AnyVal { out.flush() } - def - write(fs: FS, path: String, forceRowMajor: Boolean = false, bufferSpec: BufferSpec) { + def write(fs: FS, path: String, forceRowMajor: Boolean = false, bufferSpec: BufferSpec): Unit = using(fs.create(path))(os => write(os, forceRowMajor, bufferSpec: BufferSpec)) - } - def writeBlockMatrix(fs: FS, path: String, blockSize: Int, forceRowMajor: Boolean = false, overwrite: Boolean = false) { + def writeBlockMatrix( + fs: FS, + path: String, + blockSize: Int, + forceRowMajor: Boolean = false, + overwrite: Boolean = false, + ): Unit = { if (overwrite) fs.delete(path, recursive = true) else if (fs.exists(path)) - fatal(s"file already exists: $path") + fatal(s"file already exists: $path") fs.mkDir(path) @@ -143,10 +157,11 @@ class RichDenseMatrixDouble(val m: BDM[Double]) extends AnyVal { implicit val formats = defaultJSONFormats jackson.Serialization.write( BlockMatrixMetadata(blockSize, m.rows, m.cols, gp.partitionIndexToBlockIndex, partFiles), - os) + os, + ) } - info(s"wrote $nParts ${ plural(nParts, "item") } in $nParts ${ plural(nParts, "partition") }") + info(s"wrote $nParts ${plural(nParts, "item")} in $nParts ${plural(nParts, "partition")}") using(fs.create(path + "/_SUCCESS"))(out => ()) } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichIndexedRowMatrix.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichIndexedRowMatrix.scala index afc27ea1e30..65cd8f16dca 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichIndexedRowMatrix.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichIndexedRowMatrix.scala @@ -1,15 +1,16 @@ package is.hail.utils.richUtils -import org.apache.spark._ -import org.apache.spark.rdd.RDD -import breeze.linalg.{DenseMatrix => BDM} import is.hail.linalg._ import is.hail.utils._ + +import breeze.linalg.{DenseMatrix => BDM} +import org.apache.spark._ import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix +import org.apache.spark.rdd.RDD object RichIndexedRowMatrix { - private def seqOp(gp: GridPartitioner) - (block: Array[Double], row: (Int, Int, Int, Array[Double])): Array[Double] = { + private def seqOp(gp: GridPartitioner)(block: Array[Double], row: (Int, Int, Int, Array[Double])) + : Array[Double] = { val (i, j, ii, rowSegment) = row val nRowsInBlock = gp.blockRowNRows(i) @@ -69,14 +70,17 @@ class RichIndexedRowMatrix(indexedRowMatrix: IndexedRowMatrix) { }.aggregateByKey(null: Array[Double], gp)(seqOp(gp), combOp) .mapValuesWithKey { case ((i, j), data) => new BDM[Double](gp.blockRowNRows(i), gp.blockColNCols(j), data) - } + } new BlockMatrix(new EmptyPartitionIsAZeroMatrixRDD(blocks), blockSize, nRows, nCols) } } private class EmptyPartitionIsAZeroMatrixRDD(blocks: RDD[((Int, Int), BDM[Double])]) - extends RDD[((Int, Int), BDM[Double])](blocks.sparkContext, Seq[Dependency[_]](new OneToOneDependency(blocks))) { + extends RDD[((Int, Int), BDM[Double])]( + blocks.sparkContext, + Seq[Dependency[_]](new OneToOneDependency(blocks)), + ) { @transient val gp: GridPartitioner = (blocks.partitioner: @unchecked) match { case Some(p: GridPartitioner) => p } @@ -99,4 +103,8 @@ 
private class EmptyPartitionIsAZeroMatrixRDD(blocks: RDD[((Int, Int), BDM[Double Some(gp) } -private class BlockPartition(val index: Int, val blockCoordinates: (Int, Int), val blockDims: (Int, Int)) extends Partition {} +private class BlockPartition( + val index: Int, + val blockCoordinates: (Int, Int), + val blockDims: (Int, Int), +) extends Partition {} diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichIndexedSeq.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichIndexedSeq.scala index 61a08aea788..4d1e7e7c4b4 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichIndexedSeq.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichIndexedSeq.scala @@ -16,21 +16,19 @@ class RichIndexedSeqAnyRef[T <: AnyRef](val a: IndexedSeq[T]) extends AnyVal { same } } + /** Rich wrapper for an indexed sequence. * * Houses the generic binary search methods. All methods taking * - a search key 'x: U', - * - a key comparison 'lt: (U, U) => Boolean' (the most generic versions - * allow the search key 'x' to be of a different type than the elements of - * the sequence, and take one or two mixed type comparison functions), + * - a key comparison 'lt: (U, U) => Boolean' (the most generic versions allow the search key 'x' + * to be of a different type than the elements of the sequence, and take one or two mixed type + * comparison functions), * - and a key projection 'k: (T) => U', - * assume the following preconditions for all 0 <= i <= j < a.size (writing < - * for 'lt'): - * 1. if 'x' < k(a(i)) then 'x' < k(a(j)) - * 2. if k(a(j)) < 'x' then k(a(i)) < 'x' - * These can be rephrased as 1: 'x' < k(_) partitions a, and 2: k(_) < 'x' - * partitions a. (Actually, upperBound only needs 1. and lowerBound only needs - * 2.) + * assume the following preconditions for all 0 <= i <= j < a.size (writing < for 'lt'): + * 1. if 'x' < k(a(i)) then 'x' < k(a(j)) 2. if k(a(j)) < 'x' then k(a(i)) < 'x' + * These can be rephrased as 1: 'x' < k(_) partitions a, and 2: k(_) < 'x' partitions a. (Actually, + * upperBound only needs 1. and lowerBound only needs 2.) */ class RichIndexedSeq[T](val a: IndexedSeq[T]) extends AnyVal { @@ -78,13 +76,18 @@ class RichIndexedSeq[T](val a: IndexedSeq[T]) extends AnyVal { x: V, ltUV: (U, V) => Boolean, ltVU: (V, U) => Boolean, - k: (T) => U + k: (T) => U, ): (Int, Int) = - runSearch(x, ltUV, ltVU, k, + runSearch( + x, + ltUV, + ltVU, + k, (l, m, u) => (lowerBound(x, l, m, ltUV, k), upperBound(x, m + 1, u, ltVU, k)), (m) => - (m, m)) + (m, m), + ) def equalRange[U >: T](x: U, lt: (U, U) => Boolean): (Int, Int) = equalRange(x, lt, lt, identity[U]) @@ -96,7 +99,7 @@ class RichIndexedSeq[T](val a: IndexedSeq[T]) extends AnyVal { x: V, ltUV: (U, V) => Boolean, ltVU: (V, U) => Boolean, - k: (T) => U + k: (T) => U, ): Boolean = runSearch(x, ltUV, ltVU, k, (_, _, _) => true, (_) => false) def containsOrdered[U](x: U, lt: (U, U) => Boolean, k: (T) => U): Boolean = @@ -108,11 +111,11 @@ class RichIndexedSeq[T](val a: IndexedSeq[T]) extends AnyVal { def containsOrdered[U >: T, V](x: V, ltUV: (U, V) => Boolean, ltVU: (V, U) => Boolean): Boolean = containsOrdered(x, ltUV, ltVU, identity[U]) - /** Returns 'start' <= i <= 'end' such that p(k(a(j))) is false for all j - * in ['start', i), and p(k(a(j))) is true for all j in [i, 'end'). + /** Returns 'start' <= i <= 'end' such that p(k(a(j))) is false for all j in ['start', i), and + * p(k(a(j))) is true for all j in [i, 'end'). * - * Assumes p(k(_)) partitions a, i.e. for all 0 <= i <= j < a.size, - * if p(k(a(i))) then p(k(a(j))). 
+ * Assumes p(k(_)) partitions a, i.e. for all 0 <= i <= j < a.size, if p(k(a(i))) then + * p(k(a(j))). */ def partitionPoint[U](p: (U) => Boolean, start: Int, end: Int, k: (T) => U): Int = { var left = start @@ -136,11 +139,10 @@ class RichIndexedSeq[T](val a: IndexedSeq[T]) extends AnyVal { def partitionPoint[U >: T](p: (U) => Boolean): Int = partitionPoint(p, identity[U]) - /** Perform binary search until either an index i is found for which k(a(i)) - * is incomparible with 'x', or it is certain that no such i exists. In the - * first case, call 'found'(l, i, u), where [l, u] is the current range of - * the search. In the second case, call 'notFound'(j), where k(a(i)) < x for - * all i in [0, j) and x < k(a(i)) for all i in [j, a.size). + /** Perform binary search until either an index i is found for which k(a(i)) is incomparible with + * 'x', or it is certain that no such i exists. In the first case, call 'found'(l, i, u), where + * [l, u] is the current range of the search. In the second case, call 'notFound'(j), where + * k(a(i)) < x for all i in [0, j) and x < k(a(i)) for all i in [j, a.size). */ private def runSearch[U, V, R]( x: V, @@ -148,7 +150,7 @@ class RichIndexedSeq[T](val a: IndexedSeq[T]) extends AnyVal { ltVU: (V, U) => Boolean, k: (T) => U, found: (Int, Int, Int) => R, - notFound: (Int) => R + notFound: (Int) => R, ): R = { var left = 0 var right = a.size @@ -172,12 +174,11 @@ class RichIndexedSeq[T](val a: IndexedSeq[T]) extends AnyVal { def treeReduce(f: (T, T) => T)(implicit tct: ClassTag[T]): T = { var is: IndexedSeq[T] = a - while (is.length > 1) { + while (is.length > 1) is = is.iterator.grouped(2).map { case Seq(x1, x2) => f(x1, x2) case Seq(x1) => x1 }.toFastSeq - } is.head } } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichInputStream.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichInputStream.scala index 516c991d826..61b5ab8f34f 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichInputStream.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichInputStream.scala @@ -1,12 +1,12 @@ package is.hail.utils.richUtils -import java.io.InputStream import is.hail.utils._ +import java.io.InputStream + class RichInputStream(val in: InputStream) extends AnyVal { - def readFully(to: Array[Byte]): Unit = { + def readFully(to: Array[Byte]): Unit = readFully(to, 0, to.length) - } def readFully(to: Array[Byte], toOff: Int, n: Int): Unit = { val nRead = readRepeatedly(to, toOff, n) diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichIntPairTraversableOnce.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichIntPairTraversableOnce.scala index 0ded8de9fb1..fc389b5f699 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichIntPairTraversableOnce.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichIntPairTraversableOnce.scala @@ -4,7 +4,8 @@ import scala.collection.TraversableOnce import scala.reflect.ClassTag class RichIntPairTraversableOnce[V](val t: TraversableOnce[(Int, V)]) extends AnyVal { - def reduceByKeyToArray(n: Int, zero: => V)(f: (V, V) => V)(implicit vct: ClassTag[V]): Array[V] = { + def reduceByKeyToArray(n: Int, zero: => V)(f: (V, V) => V)(implicit vct: ClassTag[V]) + : Array[V] = { val a = Array.fill[V](n)(zero) t.foreach { case (k, v) => a(k) = f(a(k), v) diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichIterable.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichIterable.scala index d00bf460545..ea56128989b 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichIterable.scala +++ 
b/hail/src/main/scala/is/hail/utils/richUtils/RichIterable.scala @@ -2,10 +2,11 @@ package is.hail.utils.richUtils import is.hail.utils._ -import java.io.Serializable -import scala.collection.{AbstractIterable, mutable} +import scala.collection.{mutable, AbstractIterable} import scala.reflect.ClassTag +import java.io.Serializable + object RichIterable { def single[A](a: A): Iterable[A] = new AbstractIterable[A] { override def iterator = Iterator.single(a) @@ -23,9 +24,8 @@ object RichIterable { } class RichIterable[T](val i: Iterable[T]) extends Serializable { - def foreachBetween(f: (T) => Unit)(g: => Unit) { + def foreachBetween(f: (T) => Unit)(g: => Unit): Unit = i.iterator.foreachBetween(f)(g) - } def intersperse[S >: T](sep: S): Iterable[S] = new Iterable[S] { def iterator = i.iterator.intersperse(sep) @@ -51,7 +51,8 @@ class RichIterable[T](val i: Iterable[T]) extends Serializable { } } - def lazyMapWith2[T2, T3, S](i2: Iterable[T2], i3: Iterable[T3], f: (T, T2, T3) => S): Iterable[S] = + def lazyMapWith2[T2, T3, S](i2: Iterable[T2], i3: Iterable[T3], f: (T, T2, T3) => S) + : Iterable[S] = new Iterable[S] with Serializable { def iterator: Iterator[S] = new Iterator[S] { val it: Iterator[T] = i.iterator @@ -96,16 +97,15 @@ class RichIterable[T](val i: Iterable[T]) extends Serializable { def counter(): Map[T, Int] = { val m = new mutable.HashMap[T, Int]() - i.foreach { elem => m.updateValue(elem, 0, _ + 1) } + i.foreach(elem => m.updateValue(elem, 0, _ + 1)) m.toMap } - def toFastSeq(implicit tct: ClassTag[T]): IndexedSeq[T] = { + def toFastSeq(implicit tct: ClassTag[T]): IndexedSeq[T] = i match { case i: mutable.WrappedArray[T] => i case i: mutable.ArrayBuffer[T] => i case _ => i.toArray[T] } - } } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichIterator.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichIterator.scala index c0cdb0834d5..6d2002e813e 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichIterator.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichIterator.scala @@ -3,17 +3,19 @@ package is.hail.utils.richUtils import is.hail.annotations.{Region, RegionValue} import is.hail.types.physical.PStruct import is.hail.utils.{FlipbookIterator, StagingIterator, StateMachine} -import org.apache.spark.sql.Row -import java.io.PrintWriter import scala.collection.JavaConverters._ import scala.io.Source import scala.reflect.ClassTag +import java.io.PrintWriter + +import org.apache.spark.sql.Row + class RichIteratorLong(val it: Iterator[Long]) extends AnyVal { def toIteratorRV(region: Region): Iterator[RegionValue] = { val rv = RegionValue(region) - it.map(ptr => { rv.setOffset(ptr); rv }) + it.map { ptr => rv.setOffset(ptr); rv } } } @@ -24,14 +26,14 @@ class RichIterator[T](val it: Iterator[T]) extends AnyVal { new StateMachine[T] { def value: T = bit.head def isValid = bit.hasNext - def advance() { bit.next() } + def advance(): Unit = bit.next() } ) } def toFlipbookIterator: FlipbookIterator[T] = toStagingIterator - def foreachBetween(f: (T) => Unit)(g: => Unit) { + def foreachBetween(f: (T) => Unit)(g: => Unit): Unit = { if (it.hasNext) { f(it.next()) while (it.hasNext) { @@ -44,6 +46,7 @@ class RichIterator[T](val it: Iterator[T]) extends AnyVal { def intersperse[S >: T](sep: S): Iterator[S] = new Iterator[S] { var nextIsSep = false def hasNext = it.hasNext + def next() = { val n = if (nextIsSep) sep else it.next() nextIsSep = !nextIsSep @@ -54,6 +57,7 @@ class RichIterator[T](val it: Iterator[T]) extends AnyVal { def intersperse[S >: T](start: S, 
sep: S, end: S): Iterator[S] = new Iterator[S] { var state = 0 def hasNext = state != 4 + def next() = { state match { case 0 => @@ -73,10 +77,12 @@ class RichIterator[T](val it: Iterator[T]) extends AnyVal { } } - def pipe(pb: ProcessBuilder, + def pipe( + pb: ProcessBuilder, printHeader: (String => Unit) => Unit, printElement: (String => Unit, T) => Unit, - printFooter: (String => Unit) => Unit): (Iterator[String], StringBuilder, Process) = { + printFooter: (String => Unit) => Unit, + ): (Iterator[String], StringBuilder, Process) = { val command = pb.command().asScala.mkString(" ") @@ -85,14 +91,13 @@ class RichIterator[T](val it: Iterator[T]) extends AnyVal { val error = new StringBuilder() // Start a thread capture the process stderr new Thread("stderr reader for " + command) { - override def run() { + override def run(): Unit = Source.fromInputStream(proc.getErrorStream).addString(error) - } }.start() // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + command) { - override def run() { + override def run(): Unit = { val out = new PrintWriter(proc.getOutputStream) printHeader(out.println) @@ -118,7 +123,8 @@ class RichIterator[T](val it: Iterator[T]) extends AnyVal { if (!hasNext) throw new NoSuchElementException("next on empty iterator") - // the previous element must must be fully consumed or the next block will start in the wrong place + /* the previous element must must be fully consumed or the next block will start in the + * wrong place */ assert(prev == null || !prev.hasNext) prev = new Iterator[T] { var i = 0 @@ -143,9 +149,6 @@ class RichIterator[T](val it: Iterator[T]) extends AnyVal { } class RichRowIterator(val it: Iterator[Row]) extends AnyVal { - def copyToRegion(region: Region, rowTyp: PStruct): Iterator[Long] = { - it.map { row => - rowTyp.unstagedStoreJavaObject(null, row, region) - } - } + def copyToRegion(region: Region, rowTyp: PStruct): Iterator[Long] = + it.map(row => rowTyp.unstagedStoreJavaObject(null, row, region)) } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichMap.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichMap.scala index a934fbea025..5d080063e74 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichMap.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichMap.scala @@ -1,11 +1,11 @@ package is.hail.utils.richUtils class RichMap[K, V](val m: Map[K, V]) extends AnyVal { - def force = m.map(identity) // needed to make serializable: https://issues.scala-lang.org/browse/SI-7005 + def force = + m.map(identity) // needed to make serializable: https://issues.scala-lang.org/browse/SI-7005 - def outerJoin[V2](other: Map[K, V2]): Map[K, (Option[V], Option[V2])] = { - (m.keySet ++ other.keySet).map { k => (k, (m.get(k), other.get(k))) }.toMap - } + def outerJoin[V2](other: Map[K, V2]): Map[K, (Option[V], Option[V2])] = + (m.keySet ++ other.keySet).map(k => (k, (m.get(k), other.get(k)))).toMap def isTrivial(implicit eq: K =:= V): Boolean = m.forall { case (k, v) => k == v } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichMultiArray2Numeric.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichMultiArray2Numeric.scala index d2fe87ed581..44ec8eb203a 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichMultiArray2Numeric.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichMultiArray2Numeric.scala @@ -81,4 +81,4 @@ class RichMultiArray2Int(val ma: MultiArray2[Int]) extends AnyVal { } ma } -} \ No newline at end of file +} diff --git 
a/hail/src/main/scala/is/hail/utils/richUtils/RichMutableMap.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichMutableMap.scala index ad18bf08728..11c258d142d 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichMutableMap.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichMutableMap.scala @@ -3,7 +3,6 @@ package is.hail.utils.richUtils import scala.collection.mutable class RichMutableMap[K, V](val m: mutable.Map[K, V]) extends AnyVal { - def updateValue(k: K, default: => V, f: (V) => V) { + def updateValue(k: K, default: => V, f: (V) => V): Unit = m += ((k, f(m.getOrElse(k, default)))) - } } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichOption.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichOption.scala index a4e5a2c6107..2cd23662698 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichOption.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichOption.scala @@ -5,5 +5,7 @@ class RichOption[T](val o: Option[T]) extends AnyVal { override def toString: String = o.toString - def liftedZip[U](other: Option[U]): Option[(T, U)] = o.flatMap { val1 => other.map(val2 => (val1, val2)) } + def liftedZip[U](other: Option[U]): Option[(T, U)] = o.flatMap { val1 => + other.map(val2 => (val1, val2)) + } } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichPairRDD.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichPairRDD.scala index c1366ba6be1..1f17fcdc3a7 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichPairRDD.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichPairRDD.scala @@ -3,6 +3,6 @@ package is.hail.utils.richUtils import org.apache.spark.rdd.RDD class RichPairRDD[K, V](val rdd: RDD[(K, V)]) extends AnyVal { - def mapValuesWithKey[W](f: (K, V) => W): RDD[(K, W)] = rdd.mapPartitions(_.map { case (k, v) => (k, f(k, v)) }, - preservesPartitioning = true) + def mapValuesWithKey[W](f: (K, V) => W): RDD[(K, W)] = + rdd.mapPartitions(_.map { case (k, v) => (k, f(k, v)) }, preservesPartitioning = true) } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichRDD.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichRDD.scala index ebc5d0cd216..74d01f4d4ae 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichRDD.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichRDD.scala @@ -1,25 +1,25 @@ package is.hail.utils.richUtils import is.hail.backend.ExecuteContext - -import java.io.{OutputStream, OutputStreamWriter} import is.hail.io.FileWriteMetadata -import is.hail.rvd.RVDContext +import is.hail.io.compress.{BGzipCodec, ComposableBGzipCodec, ComposableBGzipOutputStream} import is.hail.sparkextras._ import is.hail.utils._ -import is.hail.io.compress.{BGzipCodec, ComposableBGzipCodec, ComposableBGzipOutputStream} -import is.hail.io.fs.FS + +import scala.collection.mutable +import scala.reflect.ClassTag + +import java.io.{OutputStream, OutputStreamWriter} + import org.apache.hadoop import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.spark.{NarrowDependency, Partition, Partitioner, TaskContext} import org.apache.spark.rdd.RDD -import scala.reflect.ClassTag -import scala.collection.mutable - case class SubsetRDDPartition(index: Int, parentPartition: Partition) extends Partition -case class SupersetRDDPartition(index: Int, maybeParentPartition: Option[Partition]) extends Partition +case class SupersetRDDPartition(index: Int, maybeParentPartition: Option[Partition]) + extends Partition class RichRDD[T](val r: RDD[T]) extends AnyVal { def reorderPartitions(oldIndices: 
Array[Int])(implicit tct: ClassTag[T]): RDD[T] = @@ -31,12 +31,21 @@ class RichRDD[T](val r: RDD[T]) extends AnyVal { Iterator(it.exists(p)) }.fold(false)(_ || _) - def writeTable(ctx: ExecuteContext, filename: String, header: Option[String] = None, exportType: String = ExportType.CONCATENATED) { + def writeTable( + ctx: ExecuteContext, + filename: String, + header: Option[String] = None, + exportType: String = ExportType.CONCATENATED, + ): Unit = { val hConf = r.sparkContext.hadoopConfiguration val codecFactory = new CompressionCodecFactory(hConf) val codec = { val codec = codecFactory.getCodec(new hadoop.fs.Path(filename)) - if (codec != null && codec.isInstanceOf[BGzipCodec] && exportType == ExportType.PARALLEL_COMPOSABLE) + if ( + codec != null && codec.isInstanceOf[ + BGzipCodec + ] && exportType == ExportType.PARALLEL_COMPOSABLE + ) new ComposableBGzipCodec else codec @@ -68,7 +77,7 @@ class RichRDD[T](val r: RDD[T]) extends AnyVal { case ExportType.PARALLEL_COMPOSABLE => r case ExportType.PARALLEL_HEADER_IN_SHARD => - r.mapPartitions { it => Iterator(h) ++ it } + r.mapPartitions(it => Iterator(h) ++ it) case _ => fatal(s"Unknown export type: $exportType") } } @@ -106,12 +115,13 @@ class RichRDD[T](val r: RDD[T]) extends AnyVal { } // this filename should sort after every partition - using(new OutputStreamWriter(fs.create(parallelOutputPath + "/part-composable-end" + ext))) { out => - // do nothing, for bgzip, this will write the empty block + using(new OutputStreamWriter(fs.create(parallelOutputPath + "/part-composable-end" + ext))) { + out => + // do nothing, for bgzip, this will write the empty block } } - if (!fs.exists(parallelOutputPath + "/_SUCCESS")) + if (!fs.isFile(parallelOutputPath + "/_SUCCESS")) fatal("write failed: no success indicator found") if (exportType == ExportType.CONCATENATED) { @@ -127,23 +137,33 @@ class RichRDD[T](val r: RDD[T]) extends AnyVal { case _ => None } - def collectAsSet(): collection.Set[T] = { + def collectAsSet(): collection.Set[T] = r.aggregate(mutable.Set.empty[T])( { case (s, elem) => s += elem }, - { case (s1, s2) => s1 ++ s2 } + { case (s1, s2) => s1 ++ s2 }, ) - } - def subsetPartitions(keep: IndexedSeq[Int], newPartitioner: Option[Partitioner] = None)(implicit ct: ClassTag[T]): RDD[T] = { - require(keep.length <= r.partitions.length, - s"tried to subset to more partitions than exist ${keep.toSeq} ${r.partitions.toSeq}") - require(keep.isIncreasing && (keep.isEmpty || (keep.head >= 0 && keep.last < r.partitions.length)), - "values not sorted or not in range [0, number of partitions)") + def subsetPartitions( + keep: IndexedSeq[Int], + newPartitioner: Option[Partitioner] = None, + )(implicit ct: ClassTag[T] + ): RDD[T] = { + require( + keep.length <= r.partitions.length, + s"tried to subset to more partitions than exist ${keep.toSeq} ${r.partitions.toSeq}", + ) + require( + keep.isIncreasing && (keep.isEmpty || (keep.head >= 0 && keep.last < r.partitions.length)), + "values not sorted or not in range [0, number of partitions)", + ) val parentPartitions = r.partitions - new RDD[T](r.sparkContext, FastSeq(new NarrowDependency[T](r) { - def getParents(partitionId: Int): Seq[Int] = FastSeq(keep(partitionId)) - })) { + new RDD[T]( + r.sparkContext, + FastSeq(new NarrowDependency[T](r) { + def getParents(partitionId: Int): Seq[Int] = FastSeq(keep(partitionId)) + }), + ) { def getPartitions: Array[Partition] = keep.indices.map { i => SubsetRDDPartition(i, parentPartitions(keep(i))) }.toArray @@ -159,7 +179,9 @@ class RichRDD[T](val r: RDD[T]) 
extends AnyVal { oldToNewPI: IndexedSeq[Int], newNPartitions: Int, newPIPartition: Int => Iterator[T], - newPartitioner: Option[Partitioner] = None)(implicit ct: ClassTag[T]): RDD[T] = { + newPartitioner: Option[Partitioner] = None, + )(implicit ct: ClassTag[T] + ): RDD[T] = { require(oldToNewPI.length == r.partitions.length) require(oldToNewPI.forall(pi => pi >= 0 && pi < newNPartitions)) @@ -168,22 +190,24 @@ class RichRDD[T](val r: RDD[T]) extends AnyVal { val parentPartitions = r.partitions val newToOldPI = oldToNewPI.zipWithIndex.toMap - new RDD[T](r.sparkContext, FastSeq(new NarrowDependency[T](r) { - def getParents(partitionId: Int): Seq[Int] = newToOldPI.get(partitionId) match { - case Some(oldPI) => Array(oldPI) - case None => Array.empty[Int] - } - })) { + new RDD[T]( + r.sparkContext, + FastSeq(new NarrowDependency[T](r) { + def getParents(partitionId: Int): Seq[Int] = newToOldPI.get(partitionId) match { + case Some(oldPI) => Array(oldPI) + case None => Array.empty[Int] + } + }), + ) { def getPartitions: Array[Partition] = Array.tabulate(newNPartitions) { i => SupersetRDDPartition(i, newToOldPI.get(i).map(parentPartitions)) } - def compute(split: Partition, context: TaskContext): Iterator[T] = { + def compute(split: Partition, context: TaskContext): Iterator[T] = split.asInstanceOf[SupersetRDDPartition].maybeParentPartition match { case Some(part) => r.compute(part, context) case None => newPIPartition(split.index) } - } @transient override val partitioner: Option[Partitioner] = newPartitioner } @@ -198,13 +222,15 @@ class RichRDD[T](val r: RDD[T]) extends AnyVal { ctx: ExecuteContext, path: String, stageLocally: Boolean, - write: (Iterator[T], OutputStream) => (Long, Long) + write: (Iterator[T], OutputStream) => (Long, Long), )(implicit tct: ClassTag[T] ): (Array[FileWriteMetadata]) = - ContextRDD.weaken(r).writePartitions(ctx, + ContextRDD.weaken(r).writePartitions( + ctx, path, null, stageLocally, (_, _) => null, - (_, it, os, _) => write(it, os)) + (_, it, os, _) => write(it, os), + ) } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichRow.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichRow.scala index 94bb63fe33e..f88e9b3fe3c 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichRow.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichRow.scala @@ -1,10 +1,9 @@ package is.hail.utils.richUtils -import is.hail.utils.BoxedArrayBuilder -import org.apache.spark.sql.Row - import scala.collection.mutable +import org.apache.spark.sql.Row + class RichRow(r: Row) { def update(i: Int, a: Any): Row = { @@ -37,7 +36,7 @@ class RichRow(r: Row) { def truncate(newSize: Int): Row = { require(newSize <= r.size) - Row.fromSeq(Array.tabulate(newSize){ i => r.get(i) }) + Row.fromSeq(Array.tabulate(newSize)(i => r.get(i))) } } @@ -47,4 +46,4 @@ class RowWithDeletedField(parent: Row, deleteIdx: Int) extends Row { override def get(i: Int): Any = if (i < deleteIdx) parent.get(i) else parent.get(i + 1) override def copy(): Row = this -} \ No newline at end of file +} diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichSparkContext.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichSparkContext.scala index e49f00638a9..64856109dac 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichSparkContext.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichSparkContext.scala @@ -1,18 +1,17 @@ package is.hail.utils.richUtils import is.hail.utils._ + import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD class RichSparkContext(val sc: 
SparkContext) extends AnyVal { - def textFilesLines(files: Array[String], - nPartitions: Int = sc.defaultMinPartitions): RDD[WithContext[String]] = { + def textFilesLines(files: Array[String], nPartitions: Int = sc.defaultMinPartitions) + : RDD[WithContext[String]] = { - /* - * Don't use: - * sc.union(files.map(sc.textFile, nPartitions)) - * since it asks for nPartitions per file instead of nPartitions over all. - */ + /* Don't use: + * sc.union(files.map(sc.textFile, nPartitions)) since it asks for nPartitions per file instead + * of nPartitions over all. */ val rdd = sc.textFile(files.mkString(","), nPartitions) val partitionFile = rdd.partitions.map(partitionPath) @@ -20,13 +19,12 @@ class RichSparkContext(val sc: SparkContext) extends AnyVal { .mapPartitionsWithIndex { case (i, it) => // FIXME subclass TextInputFormat to return (file, line) val file = partitionFile(i) - it.map { line => - WithContext(line, Context(line, file, None)) - } + it.map(line => WithContext(line, Context(line, file, None))) } } - def textFileLines(file: String, nPartitions: Int = sc.defaultMinPartitions): RDD[WithContext[String]] = + def textFileLines(file: String, nPartitions: Int = sc.defaultMinPartitions) + : RDD[WithContext[String]] = sc.textFile(file, nPartitions) .map(l => WithContext(l, Context(l, file, None))) } diff --git a/hail/src/main/scala/is/hail/utils/richUtils/RichStringBuilder.scala b/hail/src/main/scala/is/hail/utils/richUtils/RichStringBuilder.scala index 19856dfe5d6..1b1787cf2a0 100644 --- a/hail/src/main/scala/is/hail/utils/richUtils/RichStringBuilder.scala +++ b/hail/src/main/scala/is/hail/utils/richUtils/RichStringBuilder.scala @@ -3,7 +3,7 @@ package is.hail.utils.richUtils import scala.collection.mutable class RichStringBuilder(val sb: mutable.StringBuilder) extends AnyVal { - def tsvAppend(a: Any) { + def tsvAppend(a: Any): Unit = { a match { case null | None => sb.append("NA") case Some(x) => tsvAppend(x) diff --git a/hail/src/main/scala/is/hail/variant/Call.scala b/hail/src/main/scala/is/hail/variant/Call.scala index cbe15e31c0e..400d9e63559 100644 --- a/hail/src/main/scala/is/hail/variant/Call.scala +++ b/hail/src/main/scala/is/hail/variant/Call.scala @@ -4,15 +4,14 @@ import is.hail.check.Gen import is.hail.expr.Parser import is.hail.utils._ -import java.io.Serializable import scala.annotation.switch import scala.collection.JavaConverters._ -import scala.language.implicitConversions + +import java.io.Serializable object Call0 { - def apply(phased: Boolean = false): Call = { + def apply(phased: Boolean = false): Call = Call(0, phased, ploidy = 0) - } } object Call1 { @@ -59,7 +58,8 @@ object Call2 { } object CallN { - def apply(alleles: java.util.List[Int], phased: Boolean): Call = apply(alleles.asScala.toFastSeq, phased) + def apply(alleles: java.util.List[Int], phased: Boolean): Call = + apply(alleles.asScala.toFastSeq, phased) def apply(alleles: IndexedSeq[Int], phased: Boolean = false): Call = { val ploidy = alleles.length @@ -110,7 +110,7 @@ object Call extends Serializable { def allelePair(c: Call): Int = { if (!isDiploid(c)) - fatal(s"invalid ploidy: ${ ploidy(c) }. Only support ploidy == 2") + fatal(s"invalid ploidy: ${ploidy(c)}. Only support ploidy == 2") allelePairUnchecked(c) } @@ -126,7 +126,7 @@ object Call extends Serializable { def unphasedDiploidGtIndex(c: Call): Int = { if (!isDiploid(c)) - fatal(s"unphased_diploid_gt_index only supports ploidy == 2. Found ${ Call.toString(c) }.") + fatal(s"unphased_diploid_gt_index only supports ploidy == 2. 
Found ${Call.toString(c)}.") if (isPhased(c)) { val p = Genotype.allelePair(alleleRepr(c)) val j = AllelePair.j(p) @@ -159,7 +159,7 @@ object Call extends Serializable { if (i == 0) AllelePair.j(p) else AllelePair.k(p) case _ => if (i < 0 || i >= ploidy(c)) - fatal(s"Index out of bounds for call with ploidy=${ ploidy(c) }: $i") + fatal(s"Index out of bounds for call with ploidy=${ploidy(c)}: $i") alleles(c)(i) } } @@ -171,7 +171,11 @@ object Call extends Serializable { Call1(if (Call.alleleByIndex(c, 0) == i) 1 else 0, Call.isPhased(c)) case 2 => val p = Call.allelePair(c) - Call2(if (AllelePair.j(p) == i) 1 else 0, if (AllelePair.k(p) == i) 1 else 0, Call.isPhased(c)) + Call2( + if (AllelePair.j(p) == i) 1 else 0, + if (AllelePair.k(p) == i) 1 else 0, + Call.isPhased(c), + ) case _ => CallN(Call.alleles(c).map(a => if (a == i) 1 else 0), Call.isPhased(c)) } @@ -224,7 +228,7 @@ object Call extends Serializable { case 4 => "1|1" case _ => val p = allelePair(c) - s"${ AllelePair.j(p) }|${ AllelePair.k(p) }" + s"${AllelePair.j(p)}|${AllelePair.k(p)}" } case _ => alleles(c).mkString("|") @@ -248,7 +252,7 @@ object Call extends Serializable { case 2 => "1/1" case _ => val p = allelePair(c) - s"${ AllelePair.j(p) }/${ AllelePair.k(p) }" + s"${AllelePair.j(p)}/${AllelePair.k(p)}" } case _ => alleles(c).mkString("/") @@ -297,7 +301,7 @@ object Call extends Serializable { case 4 => phased_11 case _ => val p = allelePair(c) - s"${ AllelePair.j(p) }|${ AllelePair.k(p) }".getBytes() + s"${AllelePair.j(p)}|${AllelePair.k(p)}".getBytes() } case _ => alleles(c).mkString("|").getBytes() @@ -321,7 +325,7 @@ object Call extends Serializable { case 2 => unphased_11 case _ => val p = allelePair(c) - s"${ AllelePair.j(p) }/${ AllelePair.k(p) }".getBytes() + s"${AllelePair.j(p)}/${AllelePair.k(p)}".getBytes() } case _ => alleles(c).mkString("/").getBytes() @@ -371,9 +375,9 @@ object Call extends Serializable { (ploidy(c): @switch) match { case 0 | 1 => false case 2 => alleleRepr(c) > 0 && { - val p = allelePairUnchecked(c) - AllelePair.j(p) != AllelePair.k(p) - } + val p = allelePairUnchecked(c) + AllelePair.j(p) != AllelePair.k(p) + } case _ => throw new UnsupportedOperationException } } @@ -383,30 +387,29 @@ object Call extends Serializable { case 0 => false case 1 => alleleRepr(c) > 0 case 2 => alleleRepr(c) > 0 && { - val p = allelePairUnchecked(c) - AllelePair.j(p) == AllelePair.k(p) - } + val p = allelePairUnchecked(c) + AllelePair.j(p) == AllelePair.k(p) + } case _ => throw new UnsupportedOperationException } } - def isNonRef(c: Call): Boolean = { + def isNonRef(c: Call): Boolean = (ploidy(c): @switch) match { case 0 => false case 1 | 2 => alleleRepr(c) > 0 case _ => alleles(c).exists(_ != 0) } - } def isHetNonRef(c: Call): Boolean = { (ploidy(c): @switch) match { case 0 | 1 => false case 2 => alleleRepr(c) > 0 && { - val p = allelePairUnchecked(c) - val j = AllelePair.j(p) - val k = AllelePair.k(p) - j > 0 && k > 0 && k != j - } + val p = allelePairUnchecked(c) + val j = AllelePair.j(p) + val k = AllelePair.k(p) + j > 0 && k > 0 && k != j + } case _ => throw new UnsupportedOperationException } } @@ -415,11 +418,11 @@ object Call extends Serializable { (ploidy(c): @switch) match { case 0 | 1 => false case 2 => alleleRepr(c) > 0 && { - val p = allelePairUnchecked(c) - val j = AllelePair.j(p) - val k = AllelePair.k(p) - (j == 0 && k > 0) || (k == 0 && j > 0) - } + val p = allelePairUnchecked(c) + val j = AllelePair.j(p) + val k = AllelePair.k(p) + (j == 0 && k > 0) || (k == 0 && j > 0) + } case _ => throw 
new UnsupportedOperationException } } @@ -466,7 +469,7 @@ object Call extends Serializable { } } - def check(c: Call, nAlleles: Int) { + def check(c: Call, nAlleles: Int): Unit = { (ploidy(c): @switch) match { case 0 => case 1 => @@ -480,13 +483,20 @@ object Call extends Serializable { unphasedDiploidGtIndex(Call2(AllelePair.j(p), AllelePair.k(p))) } else unphasedDiploidGtIndex(c) - assert(udtn < nGenotypes, s"Invalid call found '${ c.toString }' for number of alleles equal to '$nAlleles'.") + assert( + udtn < nGenotypes, + s"Invalid call found '${c.toString}' for number of alleles equal to '$nAlleles'.", + ) case _ => alleles(c).foreach(a => assert(a >= 0 && a < nAlleles)) } } - def gen(nAlleles: Int, ploidyGen: Gen[Int] = Gen.choose(0, 2), phasedGen: Gen[Boolean] = Gen.nextCoin(0.5)): Gen[Call] = for { + def gen( + nAlleles: Int, + ploidyGen: Gen[Int] = Gen.choose(0, 2), + phasedGen: Gen[Boolean] = Gen.nextCoin(0.5), + ): Gen[Call] = for { ploidy <- ploidyGen phased <- phasedGen alleles <- Gen.buildableOfN[Array](ploidy, Gen.choose(0, nAlleles - 1)) diff --git a/hail/src/main/scala/is/hail/variant/Genotype.scala b/hail/src/main/scala/is/hail/variant/Genotype.scala index b7e41747d93..3acf5989c6c 100644 --- a/hail/src/main/scala/is/hail/variant/Genotype.scala +++ b/hail/src/main/scala/is/hail/variant/Genotype.scala @@ -4,9 +4,8 @@ import is.hail.annotations.Annotation import is.hail.check.Gen import is.hail.types.virtual.{TArray, TCall, TInt32, TStruct} import is.hail.utils._ -import org.apache.spark.sql.Row -import scala.language.implicitConversions +import org.apache.spark.sql.Row object GenotypeType extends Enumeration { type GenotypeType = Value @@ -23,12 +22,11 @@ object AllelePair { j | (k << 16) } - def fromNonNormalized(j: Int, k: Int): Int = { + def fromNonNormalized(j: Int, k: Int): Int = if (j <= k) AllelePair(j, k) else AllelePair(k, j) - } def j(p: Int): Int = p & 0xffff def k(p: Int): Int = (p >> 16) & 0xffff @@ -45,7 +43,8 @@ object Genotype { "AD" -> TArray(TInt32), "DP" -> TInt32, "GQ" -> TInt32, - "PL" -> TArray(TInt32)) + "PL" -> TArray(TInt32), + ) def call(g: Annotation): Option[Call] = { if (g == null) @@ -61,15 +60,29 @@ object Genotype { def apply(c: BoxedCall): Annotation = Annotation(c, null, null, null, null) - def apply(c: BoxedCall, ad: Array[Int], dp: java.lang.Integer, gq: java.lang.Integer, pl: Array[Int]): Annotation = + def apply( + c: BoxedCall, + ad: Array[Int], + dp: java.lang.Integer, + gq: java.lang.Integer, + pl: Array[Int], + ): Annotation = Annotation(c, ad: IndexedSeq[Int], dp, gq, pl: IndexedSeq[Int]) - def apply(c: Option[Call] = None, + def apply( + c: Option[Call] = None, ad: Option[Array[Int]] = None, dp: Option[Int] = None, gq: Option[Int] = None, - pl: Option[Array[Int]] = None): Annotation = - Annotation(c.orNull, ad.map(adx => adx: IndexedSeq[Int]).orNull, dp.orNull, gq.orNull, pl.map(plx => plx: IndexedSeq[Int]).orNull) + pl: Option[Array[Int]] = None, + ): Annotation = + Annotation( + c.orNull, + ad.map(adx => adx: IndexedSeq[Int]).orNull, + dp.orNull, + gq.orNull, + pl.map(plx => plx: IndexedSeq[Int]).orNull, + ) def gqFromPL(pl: Array[Int]): Int = { var m = 99 @@ -127,7 +140,9 @@ object Genotype { val maxPhredInTable = 8192 - lazy val phredToLinearConversionTable: Array[Double] = (0 to maxPhredInTable).map { i => math.pow(10, i / -10.0) }.toArray + lazy val phredToLinearConversionTable: Array[Double] = (0 to maxPhredInTable).map { i => + math.pow(10, i / -10.0) + }.toArray def phredToLinear(i: Int): Double = if (i < maxPhredInTable) 
phredToLinearConversionTable(i) else math.pow(10, i / -10.0) @@ -140,15 +155,44 @@ object Genotype { (p1 + 2 * p2) / (p0 + p1 + p2) } - val smallAllelePair: Array[Int] = Array(AllelePair(0, 0), AllelePair(0, 1), AllelePair(1, 1), - AllelePair(0, 2), AllelePair(1, 2), AllelePair(2, 2), - AllelePair(0, 3), AllelePair(1, 3), AllelePair(2, 3), AllelePair(3, 3), - AllelePair(0, 4), AllelePair(1, 4), AllelePair(2, 4), AllelePair(3, 4), AllelePair(4, 4), - AllelePair(0, 5), AllelePair(1, 5), AllelePair(2, 5), AllelePair(3, 5), AllelePair(4, 5), AllelePair(5, 5), - AllelePair(0, 6), AllelePair(1, 6), AllelePair(2, 6), AllelePair(3, 6), AllelePair(4, 6), AllelePair(5, 6), + val smallAllelePair: Array[Int] = Array( + AllelePair(0, 0), + AllelePair(0, 1), + AllelePair(1, 1), + AllelePair(0, 2), + AllelePair(1, 2), + AllelePair(2, 2), + AllelePair(0, 3), + AllelePair(1, 3), + AllelePair(2, 3), + AllelePair(3, 3), + AllelePair(0, 4), + AllelePair(1, 4), + AllelePair(2, 4), + AllelePair(3, 4), + AllelePair(4, 4), + AllelePair(0, 5), + AllelePair(1, 5), + AllelePair(2, 5), + AllelePair(3, 5), + AllelePair(4, 5), + AllelePair(5, 5), + AllelePair(0, 6), + AllelePair(1, 6), + AllelePair(2, 6), + AllelePair(3, 6), + AllelePair(4, 6), + AllelePair(5, 6), AllelePair(6, 6), - AllelePair(0, 7), AllelePair(1, 7), AllelePair(2, 7), AllelePair(3, 7), AllelePair(4, 7), - AllelePair(5, 7), AllelePair(6, 7), AllelePair(7, 7)) + AllelePair(0, 7), + AllelePair(1, 7), + AllelePair(2, 7), + AllelePair(3, 7), + AllelePair(4, 7), + AllelePair(5, 7), + AllelePair(6, 7), + AllelePair(7, 7), + ) val smallAlleleJ: Array[Int] = smallAllelePair.map(AllelePair.j) val smallAlleleK: Array[Int] = smallAllelePair.map(AllelePair.k) @@ -175,12 +219,11 @@ object Genotype { AllelePair(j, k) } - def allelePair(i: Int): Int = { + def allelePair(i: Int): Int = if (i < smallAllelePair.length) smallAllelePair(i) else allelePairSqrt(i) - } def diploidGtIndex(j: Int, k: Int): Int = { if (j < 0 | j > k) { @@ -191,26 +234,26 @@ object Genotype { def diploidGtIndex(p: Int): Int = diploidGtIndex(AllelePair.j(p), AllelePair.k(p)) - def diploidGtIndexWithSwap(i: Int, j: Int): Int = { + def diploidGtIndexWithSwap(i: Int, j: Int): Int = if (j < i) diploidGtIndex(j, i) else diploidGtIndex(i, j) - } def genExtremeNonmissing(nAlleles: Int): Gen[Annotation] = { val m = Int.MaxValue / (nAlleles + 1) val nGenotypes = triangle(nAlleles) - val gg = for (c: Option[Call] <- Gen.option(Call.genUnphasedDiploid(nAlleles)); - ad <- Gen.option(Gen.buildableOfN[Array](nAlleles, Gen.choose(0, m))); - dp <- Gen.option(Gen.choose(0, m)); - gq <- Gen.option(Gen.choose(0, 10000)); + val gg = for { + c: Option[Call] <- Gen.option(Call.genUnphasedDiploid(nAlleles)) + ad <- Gen.option(Gen.buildableOfN[Array](nAlleles, Gen.choose(0, m))) + dp <- Gen.option(Gen.choose(0, m)) + gq <- Gen.option(Gen.choose(0, 10000)) pl <- Gen.oneOfGen( Gen.option(Gen.buildableOfN[Array](nGenotypes, Gen.choose(0, m))), - Gen.option(Gen.buildableOfN[Array](nGenotypes, Gen.choose(0, 100))))) yield { - c.foreach { c => - pl.foreach { pla => pla(Call.unphasedDiploidGtIndex(c)) = 0 } - } + Gen.option(Gen.buildableOfN[Array](nGenotypes, Gen.choose(0, 100))), + ) + } yield { + c.foreach(c => pl.foreach(pla => pla(Call.unphasedDiploidGtIndex(c)) = 0)) pl.foreach { pla => val m = pla.min var i = 0 @@ -219,35 +262,41 @@ object Genotype { i += 1 } } - val g = Annotation(c.orNull, + val g = Annotation( + c.orNull, ad.map(a => a: IndexedSeq[Int]).orNull, dp.map(_ + ad.map(_.sum).getOrElse(0)).orNull, 
gq.orNull, - pl.map(a => a: IndexedSeq[Int]).orNull) + pl.map(a => a: IndexedSeq[Int]).orNull, + ) g } gg } - def genExtreme(nAlleles: Int): Gen[Annotation] = { + def genExtreme(nAlleles: Int): Gen[Annotation] = Gen.frequency( (100, genExtremeNonmissing(nAlleles)), - (1, Gen.const(null))) - } + (1, Gen.const(null)), + ) def genRealisticNonmissing(nAlleles: Int): Gen[Annotation] = { val nGenotypes = triangle(nAlleles) - val gg = for (callRate <- Gen.choose(0d, 1d); - alleleFrequencies <- Gen.buildableOfN[Array](nAlleles, Gen.choose(1e-6, 1d)) // avoid divison by 0 - .map { rawWeights => - val sum = rawWeights.sum - rawWeights.map(_ / sum) - }; - c <- Gen.option(Gen.zip(Gen.chooseWithWeights(alleleFrequencies), Gen.chooseWithWeights(alleleFrequencies)) - .map { case (gti, gtj) => Call2(gti, gtj) }, callRate); - ad <- Gen.option(Gen.buildableOfN[Array](nAlleles, - Gen.choose(0, 50))); - dp <- Gen.choose(0, 30).map(d => ad.map(o => o.sum + d)); + val gg = for { + callRate <- Gen.choose(0d, 1d) + alleleFrequencies <- + Gen.buildableOfN[Array](nAlleles, Gen.choose(1e-6, 1d)) // avoid divison by 0 + .map { rawWeights => + val sum = rawWeights.sum + rawWeights.map(_ / sum) + } + c <- Gen.option( + Gen.zip(Gen.chooseWithWeights(alleleFrequencies), Gen.chooseWithWeights(alleleFrequencies)) + .map { case (gti, gtj) => Call2(gti, gtj) }, + callRate, + ) + ad <- Gen.option(Gen.buildableOfN[Array](nAlleles, Gen.choose(0, 50))) + dp <- Gen.choose(0, 30).map(d => ad.map(o => o.sum + d)) pl <- Gen.option(Gen.buildableOfN[Array](nGenotypes, Gen.choose(0, 1000)).map { arr => c match { case Some(x) => @@ -257,19 +306,17 @@ object Genotype { val min = arr.min arr.map(_ - min) } - }); + }) gq <- Gen.choose(-30, 30).map(i => pl.map(pls => math.max(0, gqFromPL(pls) + i))) - ) yield - Annotation(c.orNull, ad.map(a => a: IndexedSeq[Int]).orNull, dp.orNull, gq.orNull, pl.map(a => a: IndexedSeq[Int]).orNull) + } yield Annotation(c.orNull, ad.map(a => a: IndexedSeq[Int]).orNull, dp.orNull, gq.orNull, pl.map(a => a: IndexedSeq[Int]).orNull) gg } - def genRealistic(nAlleles: Int): Gen[Annotation] = { + def genRealistic(nAlleles: Int): Gen[Annotation] = Gen.frequency( (100, genRealisticNonmissing(nAlleles)), - (1, Gen.const(null))) - } - + (1, Gen.const(null)), + ) def genGenericCallAndProbabilitiesGenotype(nAlleles: Int): Gen[Annotation] = { val nGenotypes = triangle(nAlleles) @@ -277,10 +324,12 @@ object Genotype { val c = gp.flatMap(a => Option(uniqueMaxIndex(a))).map(Call2.fromUnphasedDiploidGtIndex(_)) Row( c.orNull, - gp.map(gpx => gpx.map(p => p.toDouble / 32768): IndexedSeq[Double]).orNull) + gp.map(gpx => gpx.map(p => p.toDouble / 32768): IndexedSeq[Double]).orNull, + ) } Gen.frequency( (100, gg), - (1, Gen.const(null))) + (1, Gen.const(null)), + ) } } diff --git a/hail/src/main/scala/is/hail/variant/HardCallView.scala b/hail/src/main/scala/is/hail/variant/HardCallView.scala index 7cb5682addd..cc5e715e847 100644 --- a/hail/src/main/scala/is/hail/variant/HardCallView.scala +++ b/hail/src/main/scala/is/hail/variant/HardCallView.scala @@ -1,6 +1,6 @@ package is.hail.variant -import is.hail.annotations.{Region, RegionValue} +import is.hail.annotations.Region import is.hail.types._ import is.hail.types.physical._ import is.hail.types.virtual.TCall @@ -22,11 +22,15 @@ final class ArrayGenotypeView(rvType: PStruct) { } } - private val (gtExists, gtIndex, gtType) = lookupField("GT", _ == PCanonicalCall()) - private val (gpExists, gpIndex, _gpType) = lookupField("GP", - pt => pt.isInstanceOf[PArray] && 
pt.asInstanceOf[PArray].elementType.isInstanceOf[PFloat64]) + private val (gtExists, gtIndex, _) = lookupField("GT", _ == PCanonicalCall()) + + private val (gpExists, gpIndex, _gpType) = lookupField( + "GP", + pt => pt.isInstanceOf[PArray] && pt.asInstanceOf[PArray].elementType.isInstanceOf[PFloat64], + ) + // Do not try to move this cast into the destructuring above - // https://stackoverflow.com/questions/27789412/scala-exception-in-for-comprehension-with-type-annotation + /* https://stackoverflow.com/questions/27789412/scala-exception-in-for-comprehension-with-type-annotation */ private[this] val gpType = _gpType.asInstanceOf[PArray] private var gsOffset: Long = _ @@ -34,12 +38,12 @@ final class ArrayGenotypeView(rvType: PStruct) { private var gOffset: Long = _ var gIsDefined: Boolean = _ - def set(offset: Long) { + def set(offset: Long): Unit = { gsOffset = rvType.loadField(offset, entriesIndex) gsLength = tgs.loadLength(gsOffset) } - def setGenotype(idx: Int) { + def setGenotype(idx: Int): Unit = { require(idx >= 0 && idx < gsLength) gIsDefined = tgs.isElementDefined(gsOffset, idx) gOffset = tgs.loadElement(gsOffset, gsLength, idx) @@ -70,11 +74,9 @@ final class ArrayGenotypeView(rvType: PStruct) { } } - object HardCallView { - def apply(rowSignature: PStruct): HardCallView = { + def apply(rowSignature: PStruct): HardCallView = new HardCallView(rowSignature, "GT") - } } final class HardCallView(rvType: PStruct, callField: String) { @@ -99,12 +101,12 @@ final class HardCallView(rvType: PStruct, callField: String) { var gsLength: Int = _ var gIsDefined: Boolean = _ - def set(offset: Long) { + def set(offset: Long): Unit = { gsOffset = rvType.loadField(offset, entriesIndex) gsLength = tgs.loadLength(gsOffset) } - def setGenotype(idx: Int) { + def setGenotype(idx: Int): Unit = { require(idx >= 0 && idx < gsLength) gIsDefined = tgs.isElementDefined(gsOffset, idx) gOffset = tgs.loadElement(gsOffset, gsLength, idx) diff --git a/hail/src/main/scala/is/hail/variant/Locus.scala b/hail/src/main/scala/is/hail/variant/Locus.scala index 9d3fe45fdc4..da7dc169e7b 100644 --- a/hail/src/main/scala/is/hail/variant/Locus.scala +++ b/hail/src/main/scala/is/hail/variant/Locus.scala @@ -4,13 +4,13 @@ import is.hail.annotations.Annotation import is.hail.check.Gen import is.hail.expr.Parser import is.hail.utils._ + +import scala.collection.JavaConverters._ + import org.apache.spark.sql.Row import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.json4s._ -import scala.collection.JavaConverters._ -import scala.language.implicitConversions - object Locus { val simpleContigs: Seq[String] = (1 to 22).map(_.toString) ++ Seq("X", "Y", "MT") @@ -19,21 +19,20 @@ object Locus { Locus(contig, position) } - def annotation(contig: String, position: Int, rg: Option[ReferenceGenome]): Annotation = { + def annotation(contig: String, position: Int, rg: Option[ReferenceGenome]): Annotation = rg match { case Some(ref) => Locus(contig, position, ref) case None => Annotation(contig, position) } - } def sparkSchema: StructType = StructType(Array( StructField("contig", StringType, nullable = false), - StructField("position", IntegerType, nullable = false))) + StructField("position", IntegerType, nullable = false), + )) - def fromRow(r: Row): Locus = { + def fromRow(r: Row): Locus = Locus(r.getAs[String](0), r.getInt(1)) - } def gen(rg: ReferenceGenome): Gen[Locus] = for { (contig, length) <- Contig.gen(rg) @@ -53,14 +52,29 @@ object Locus { def parseInterval(str: String, rg: ReferenceGenome, 
invalidMissing: Boolean = false): Interval = Parser.parseLocusInterval(str, rg, invalidMissing) - def parseIntervals(arr: Array[String], rg: ReferenceGenome, invalidMissing: Boolean): Array[Interval] = arr.map(parseInterval(_, rg, invalidMissing)) - - def parseIntervals(arr: java.util.List[String], rg: ReferenceGenome, invalidMissing: Boolean = false): Array[Interval] = parseIntervals(arr.asScala.toArray, rg, invalidMissing) - - def makeInterval(contig: String, start: Int, end: Int, includesStart: Boolean, includesEnd: Boolean, - rgBase: ReferenceGenome, invalidMissing: Boolean = false): Interval = { + def parseIntervals(arr: Array[String], rg: ReferenceGenome, invalidMissing: Boolean) + : Array[Interval] = arr.map(parseInterval(_, rg, invalidMissing)) + + def parseIntervals( + arr: java.util.List[String], + rg: ReferenceGenome, + invalidMissing: Boolean = false, + ): Array[Interval] = parseIntervals(arr.asScala.toArray, rg, invalidMissing) + + def makeInterval( + contig: String, + start: Int, + end: Int, + includesStart: Boolean, + includesEnd: Boolean, + rgBase: ReferenceGenome, + invalidMissing: Boolean = false, + ): Interval = { val rg = rgBase.asInstanceOf[ReferenceGenome] - rg.toLocusInterval(Interval(Locus(contig, start), Locus(contig, end), includesStart, includesEnd), invalidMissing) + rg.toLocusInterval( + Interval(Locus(contig, start), Locus(contig, end), includesStart, includesEnd), + invalidMissing, + ) } } @@ -69,14 +83,16 @@ case class Locus(contig: String, position: Int) { def toJSON: JValue = JObject( ("contig", JString(contig)), - ("position", JInt(position))) + ("position", JInt(position)), + ) def copyChecked(rg: ReferenceGenome, contig: String = contig, position: Int = position): Locus = { rg.checkLocus(contig, position) Locus(contig, position) } - def isAutosomalOrPseudoAutosomal(rg: ReferenceGenome): Boolean = isAutosomal(rg) || inXPar(rg) || inYPar(rg) + def isAutosomalOrPseudoAutosomal(rg: ReferenceGenome): Boolean = + isAutosomal(rg) || inXPar(rg) || inYPar(rg) def isAutosomal(rg: ReferenceGenome): Boolean = !(inX(rg) || inY(rg) || isMitochondrial(rg)) diff --git a/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala b/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala index d018b8b6a9b..7412b32ff9b 100644 --- a/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala +++ b/hail/src/main/scala/is/hail/variant/ReferenceGenome.scala @@ -1,40 +1,38 @@ package is.hail.variant -import java.io.InputStream -import htsjdk.samtools.reference.FastaSequenceIndex import is.hail.HailContext -import is.hail.asm4s.Code -import is.hail.backend.{BroadcastValue, ExecuteContext, HailStateManager} +import is.hail.annotations.ExtendedOrdering +import is.hail.backend.{BroadcastValue, ExecuteContext} import is.hail.check.Gen -import is.hail.expr.ir.{EmitClassBuilder, RelationalSpec} -import is.hail.expr.{JSONExtractContig, JSONExtractIntervalLocus, JSONExtractReferenceGenome, Parser} +import is.hail.expr.{ + JSONExtractContig, JSONExtractIntervalLocus, JSONExtractReferenceGenome, Parser, +} +import is.hail.expr.ir.RelationalSpec import is.hail.io.fs.FS -import is.hail.io.reference.LiftOver -import is.hail.io.reference.{FASTAReader, FASTAReaderConfig} +import is.hail.io.reference.{FASTAReader, FASTAReaderConfig, LiftOver} import is.hail.types._ -import is.hail.types.virtual.{TInt64, TLocus, Type} +import is.hail.types.virtual.{TLocus, Type} import is.hail.utils._ import scala.collection.JavaConverters._ import scala.collection.mutable -import 
scala.language.implicitConversions -import org.apache.spark.TaskContext -import org.json4s._ -import org.json4s.jackson.{JsonMethods, Serialization} +import java.io.{FileNotFoundException, InputStream} import java.lang.ThreadLocal -import is.hail.annotations.ExtendedOrdering +import htsjdk.samtools.reference.FastaSequenceIndex +import org.apache.spark.TaskContext +import org.json4s._ +import org.json4s.jackson.{JsonMethods, Serialization} class BroadcastRG(rgParam: ReferenceGenome) extends Serializable { @transient private[this] val rg: ReferenceGenome = rgParam - private[this] val rgBc: BroadcastValue[ReferenceGenome] = { + private[this] val rgBc: BroadcastValue[ReferenceGenome] = if (TaskContext.get != null) null else rg.broadcast - } def value: ReferenceGenome = { val t = if (rg != null) @@ -45,9 +43,15 @@ class BroadcastRG(rgParam: ReferenceGenome) extends Serializable { } } -case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[String, Int], - xContigs: Set[String] = Set.empty[String], yContigs: Set[String] = Set.empty[String], - mtContigs: Set[String] = Set.empty[String], parInput: Array[(Locus, Locus)] = Array.empty[(Locus, Locus)]) extends Serializable { +case class ReferenceGenome( + name: String, + contigs: Array[String], + lengths: Map[String, Int], + xContigs: Set[String] = Set.empty[String], + yContigs: Set[String] = Set.empty[String], + mtContigs: Set[String] = Set.empty[String], + parInput: Array[(Locus, Locus)] = Array.empty[(Locus, Locus)], +) extends Serializable { @transient lazy val broadcastRG: BroadcastRG = new BroadcastRG(this) val nContigs = contigs.length @@ -62,25 +66,35 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St val extraLengths = lengths.keySet.diff(contigs.toSet) if (missingLengths.nonEmpty) - fatal(s"No lengths given for the following contigs: ${ missingLengths.mkString(", ") }") + fatal(s"No lengths given for the following contigs: ${missingLengths.mkString(", ")}") if (extraLengths.nonEmpty) - fatal(s"Contigs found in 'lengths' that are not present in 'contigs': ${ extraLengths.mkString(", ") }") + fatal( + s"Contigs found in 'lengths' that are not present in 'contigs': ${extraLengths.mkString(", ")}" + ) if (xContigs.intersect(yContigs).nonEmpty) - fatal(s"Found the contigs '${ xContigs.intersect(yContigs).mkString(", ") }' in both X and Y contigs.") + fatal( + s"Found the contigs '${xContigs.intersect(yContigs).mkString(", ")}' in both X and Y contigs." + ) if (xContigs.intersect(mtContigs).nonEmpty) - fatal(s"Found the contigs '${ xContigs.intersect(mtContigs).mkString(", ") }' in both X and MT contigs.") + fatal( + s"Found the contigs '${xContigs.intersect(mtContigs).mkString(", ")}' in both X and MT contigs." + ) if (yContigs.intersect(mtContigs).nonEmpty) - fatal(s"Found the contigs '${ yContigs.intersect(mtContigs).mkString(", ") }' in both Y and MT contigs.") + fatal( + s"Found the contigs '${yContigs.intersect(mtContigs).mkString(", ")}' in both Y and MT contigs." 
+ ) - val contigsIndex: java.util.HashMap[String, Integer] = makeJavaMap(contigs.iterator.zipWithIndex.map { case (c, i) => (c, box(i))}) + val contigsIndex: java.util.HashMap[String, Integer] = + makeJavaMap(contigs.iterator.zipWithIndex.map { case (c, i) => (c, box(i)) }) val contigsSet: java.util.HashSet[String] = makeJavaSet(contigs) - private val jLengths: java.util.HashMap[String, java.lang.Integer] = makeJavaMap(lengths.iterator.map { case (c, i) => (c, box(i))}) + private val jLengths: java.util.HashMap[String, java.lang.Integer] = + makeJavaMap(lengths.iterator.map { case (c, i) => (c, box(i)) }) val lengthsByIndex: Array[Int] = contigs.map(lengths) @@ -94,13 +108,19 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St val mtNotInRef = mtContigs.diff(contigsSet.asScala) if (xNotInRef.nonEmpty) - fatal(s"The following X contig names are absent from the reference: '${ xNotInRef.mkString(", ") }'.") + fatal( + s"The following X contig names are absent from the reference: '${xNotInRef.mkString(", ")}'." + ) if (yNotInRef.nonEmpty) - fatal(s"The following Y contig names are absent from the reference: '${ yNotInRef.mkString(", ") }'.") + fatal( + s"The following Y contig names are absent from the reference: '${yNotInRef.mkString(", ")}'." + ) if (mtNotInRef.nonEmpty) - fatal(s"The following mitochondrial contig names are absent from the reference: '${ mtNotInRef.mkString(", ") }'.") + fatal( + s"The following mitochondrial contig names are absent from the reference: '${mtNotInRef.mkString(", ")}'." + ) val xContigIndices = xContigs.map(contigsIndex.get) val yContigIndices = yContigs.map(contigsIndex.get) @@ -130,11 +150,17 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St val par = parInput.map { case (start, end) => if (start.contig != end.contig) - fatal(s"The contigs for the 'start' and 'end' of a PAR interval must be the same. Found '$start-$end'.") + fatal( + s"The contigs for the 'start' and 'end' of a PAR interval must be the same. Found '$start-$end'." + ) - if ((!xContigs.contains(start.contig) && !yContigs.contains(start.contig)) || - (!xContigs.contains(end.contig) && !yContigs.contains(end.contig))) - fatal(s"The contig name for PAR interval '$start-$end' was not found in xContigs '${ xContigs.mkString(",") }' or in yContigs '${ yContigs.mkString(",") }'.") + if ( + (!xContigs.contains(start.contig) && !yContigs.contains(start.contig)) || + (!xContigs.contains(end.contig) && !yContigs.contains(end.contig)) + ) + fatal( + s"The contig name for PAR interval '$start-$end' was not found in xContigs '${xContigs.mkString(",")}' or in yContigs '${yContigs.mkString(",")}'." 
+ ) Interval(start, end, includesStart = true, includesEnd = false) } @@ -193,23 +219,27 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St def isValidContig(contig: String): Boolean = contigsSet.contains(contig) - def isValidLocus(contig: String, pos: Int): Boolean = isValidContig(contig) && pos > 0 && pos <= contigLength(contig) + def isValidLocus(contig: String, pos: Int): Boolean = + isValidContig(contig) && pos > 0 && pos <= contigLength(contig) def isValidLocus(l: Locus): Boolean = isValidLocus(l.contig, l.position) - def checkContig(contig: String): Unit = { + def checkContig(contig: String): Unit = if (!isValidContig(contig)) fatal(s"Contig '$contig' is not in the reference genome '$name'.") - } def checkLocus(l: Locus): Unit = checkLocus(l.contig, l.position) def checkLocus(contig: String, pos: Int): Unit = { if (!isValidLocus(contig, pos)) { if (!isValidContig(contig)) - fatal(s"Invalid locus '$contig:$pos' found. Contig '$contig' is not in the reference genome '$name'.") + fatal( + s"Invalid locus '$contig:$pos' found. Contig '$contig' is not in the reference genome '$name'." + ) else - fatal(s"Invalid locus '$contig:$pos' found. Position '$pos' is not within the range [1-${contigLength(contig)}] for reference genome '$name'.") + fatal( + s"Invalid locus '$contig:$pos' found. Position '$pos' is not within the range [1-${contigLength(contig)}] for reference genome '$name'." + ) } } @@ -224,9 +254,13 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St return null else { if (!isValidContig(start.contig)) - fatal(s"Invalid interval '$i' found. Contig '${ start.contig }' is not in the reference genome '$name'.") + fatal( + s"Invalid interval '$i' found. Contig '${start.contig}' is not in the reference genome '$name'." + ) else - fatal(s"Invalid interval '$i' found. Start '$start' is not within the range [1-${ contigLength(start.contig) }] for reference genome '$name'.") + fatal( + s"Invalid interval '$i' found. Start '$start' is not within the range [1-${contigLength(start.contig)}] for reference genome '$name'." + ) } } @@ -235,9 +269,13 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St return null else { if (!isValidContig(end.contig)) - fatal(s"Invalid interval '$i' found. Contig '${ end.contig }' is not in the reference genome '$name'.") + fatal( + s"Invalid interval '$i' found. Contig '${end.contig}' is not in the reference genome '$name'." + ) else - fatal(s"Invalid interval '$i' found. End '$end' is not within the range [1-${ contigLength(end.contig) }] for reference genome '$name'.") + fatal( + s"Invalid interval '$i' found. End '$end' is not within the range [1-${contigLength(end.contig)}] for reference genome '$name'." 
+ ) } } @@ -291,7 +329,8 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St def isMitochondrial(contig: String): Boolean = mtContigs.contains(contig) - def isAutosomal(contig: String): Boolean = !(inX(contig) || inY(contig) || isMitochondrial(contig)) + def isAutosomal(contig: String): Boolean = + !(inX(contig) || inY(contig) || isMitochondrial(contig)) def inPar(l: Locus): Boolean = par.exists(_.contains(extendedLocusOrdering, l)) @@ -303,11 +342,13 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St def inYNonPar(l: Locus): Boolean = inY(l.contig) && !inPar(l) - def isAutosomalOrPseudoAutosomal(l: Locus): Boolean = isAutosomal(l.contig) || ((inX(l.contig) || inY(l.contig)) && inPar(l)) + def isAutosomalOrPseudoAutosomal(l: Locus): Boolean = + isAutosomal(l.contig) || ((inX(l.contig) || inY(l.contig)) && inPar(l)) - def compare(contig1: String, contig2: String): Int = ReferenceGenome.compare(contigsIndex, contig1, contig2) + def compare(contig1: String, contig2: String): Int = + ReferenceGenome.compare(contigsIndex, contig1, contig2) - def validateContigRemap(contigMapping: Map[String, String]) { + def validateContigRemap(contigMapping: Map[String, String]): Unit = { val badContigs = mutable.Set[(String, String)]() contigMapping.foreach { case (oldName, newName) => @@ -316,33 +357,42 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St } if (badContigs.nonEmpty) - fatal(s"Found ${ badContigs.size } ${ plural(badContigs.size, "contig mapping that does", "contigs mapping that do") }" + - s" not have remapped contigs in reference genome '$name':\n " + - s"@1", badContigs.truncatable("\n ")) + fatal( + s"Found ${badContigs.size} ${plural(badContigs.size, "contig mapping that does", "contigs mapping that do")}" + + s" not have remapped contigs in reference genome '$name':\n " + + s"@1", + badContigs.truncatable("\n "), + ) } def hasSequence: Boolean = fastaFilePath != null - def addSequence(ctx: ExecuteContext, fastaFile: String, indexFile: String) { + def addSequence(ctx: ExecuteContext, fastaFile: String, indexFile: String): Unit = { if (hasSequence) fatal(s"FASTA sequence has already been loaded for reference genome '$name'.") val tmpdir = ctx.localTmpdir val fs = ctx.fs - if (!fs.exists(fastaFile)) - fatal(s"FASTA file '$fastaFile' does not exist or you do not have access.") - if (!fs.exists(indexFile)) - fatal(s"FASTA index file '$indexFile' does not exist or you do not have access.") + if (!fs.isFile(fastaFile)) + fatal(s"FASTA file '$fastaFile' does not exist, is not a file, or you do not have access.") + if (!fs.isFile(indexFile)) + fatal( + s"FASTA index file '$indexFile' does not exist, is not a file, or you do not have access." 
+ ) fastaFilePath = fastaFile fastaIndexPath = indexFile - // assumption, fastaFile and indexFile will not move or change for the entire duration of a hail pipeline + /* assumption, fastaFile and indexFile will not move or change for the entire duration of a hail + * pipeline */ val index = using(fs.open(indexFile))(new FastaSequenceIndex(_)) val missingContigs = contigs.filterNot(index.hasIndexEntry) if (missingContigs.nonEmpty) - fatal(s"Contigs missing in FASTA '$fastaFile' that are present in reference genome '$name':\n " + - s"@1", missingContigs.truncatable("\n ")) + fatal( + s"Contigs missing in FASTA '$fastaFile' that are present in reference genome '$name':\n " + + s"@1", + missingContigs.truncatable("\n "), + ) val invalidLengths = lengths.flatMap { case (c, l) => val fastaLength = index.getIndexEntry(c).getSize @@ -350,15 +400,19 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St Some((c, l, fastaLength)) else None - }.map { case (c, e, f) => s"$c\texpected:$e\tfound:$f"} + }.map { case (c, e, f) => s"$c\texpected:$e\tfound:$f" } if (invalidLengths.nonEmpty) - fatal(s"Contig sizes in FASTA '$fastaFile' do not match expected sizes for reference genome '$name':\n " + - s"@1", invalidLengths.truncatable("\n ")) + fatal( + s"Contig sizes in FASTA '$fastaFile' do not match expected sizes for reference genome '$name':\n " + + s"@1", + invalidLengths.truncatable("\n "), + ) heal(tmpdir, fs) } - @transient private lazy val realFastaReader: ThreadLocal[FASTAReader] = new ThreadLocal[FASTAReader] + @transient private lazy val realFastaReader: ThreadLocal[FASTAReader] = + new ThreadLocal[FASTAReader] private def fastaReader(): FASTAReader = { if (!hasSequence) @@ -370,16 +424,14 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St realFastaReader.get() } - def getSequence(contig: String, position: Int, before: Int = 0, after: Int = 0): String = { + def getSequence(contig: String, position: Int, before: Int = 0, after: Int = 0): String = fastaReader().lookup(contig, position, before, after) - } def getSequence(l: Locus, before: Int, after: Int): String = getSequence(l.contig, l.position, before, after) - def getSequence(i: Interval): String = { + def getSequence(i: Interval): String = fastaReader().lookup(i) - } def removeSequence(): Unit = { if (!hasSequence) @@ -398,15 +450,17 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St if (name == destRGName) fatal(s"Destination reference genome cannot have the same name as this reference '$name'") if (hasLiftover(destRGName)) - fatal(s"Chain file already exists for source reference '$name' and destination reference '$destRGName'.") + fatal( + s"Chain file already exists for source reference '$name' and destination reference '$destRGName'." 
+ ) val tmpdir = ctx.localTmpdir val fs = ctx.fs - if (!fs.exists(chainFile)) - fatal(s"Chain file '$chainFile' does not exist.") + if (!fs.isFile(chainFile)) + fatal(s"Chain file '$chainFile' does not exist, is not a file, or you do not have access.") - val chainFilePath = fs.fileListEntry(chainFile).getPath + val chainFilePath = fs.parseUrl(chainFile).toString val lo = LiftOver(fs, chainFilePath) val destRG = ctx.getReference(destRGName) lo.checkChainFile(this, destRG) @@ -417,7 +471,9 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St def getLiftover(destRGName: String): LiftOver = { if (!hasLiftover(destRGName)) - fatal(s"Chain file has not been loaded for source reference '$name' and destination reference '$destRGName'.") + fatal( + s"Chain file has not been loaded for source reference '$name' and destination reference '$destRGName'." + ) liftoverMap(destRGName) } @@ -433,7 +489,8 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St lo.queryLocus(l, minMatch) } - def liftoverLocusInterval(destRGName: String, interval: Interval, minMatch: Double): (Interval, Boolean) = { + def liftoverLocusInterval(destRGName: String, interval: Interval, minMatch: Double) + : (Interval, Boolean) = { val lo = getLiftover(destRGName) lo.queryInterval(interval, minMatch) } @@ -444,7 +501,7 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St // since removeLiftover updates both maps, so we don't check to see if liftoverMap has // keys that are not in chainFiles for ((destRGName, chainFile) <- chainFiles) { - val chainFilePath = fs.fileListEntry(chainFile).getPath + val chainFilePath = fs.parseUrl(chainFile).toString liftoverMap.get(destRGName) match { case Some(lo) if lo.chainFile == chainFilePath => // do nothing case _ => liftoverMap += destRGName -> LiftOver(fs, chainFilePath) @@ -453,15 +510,18 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St // add sequence if (fastaFilePath != null) { - val fastaPath = fs.fileListEntry(fastaFilePath).getPath - val indexPath = fs.fileListEntry(fastaIndexPath).getPath - if (fastaReaderCfg == null || fastaReaderCfg.fastaFile != fastaPath || fastaReaderCfg.indexFile != indexPath) { + val fastaPath = fs.parseUrl(fastaFilePath).toString + val indexPath = fs.parseUrl(fastaIndexPath).toString + if ( + fastaReaderCfg == null || fastaReaderCfg.fastaFile != fastaPath || fastaReaderCfg.indexFile != indexPath + ) { fastaReaderCfg = FASTAReaderConfig(tmpdir, fs, this, fastaPath, indexPath) } } } - @transient lazy val broadcast: BroadcastValue[ReferenceGenome] = HailContext.backend.broadcast(this) + @transient lazy val broadcast: BroadcastValue[ReferenceGenome] = + HailContext.backend.broadcast(this) override def hashCode: Int = { import org.apache.commons.lang3.builder.HashCodeBuilder @@ -481,12 +541,12 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St other match { case rg: ReferenceGenome => name == rg.name && - contigs.sameElements(rg.contigs) && - lengths == rg.lengths && - xContigs == rg.xContigs && - yContigs == rg.yContigs && - mtContigs == rg.mtContigs && - par.sameElements(rg.par) + contigs.sameElements(rg.contigs) && + lengths == rg.lengths && + xContigs == rg.xContigs && + yContigs == rg.yContigs && + mtContigs == rg.mtContigs && + par.sameElements(rg.par) case _ => false } } @@ -497,18 +557,28 @@ case class ReferenceGenome(name: String, contigs: Array[String], lengths: Map[St def write(fs: is.hail.io.fs.FS, 
file: String): Unit = using(fs.create(file)) { out => - val jrg = JSONExtractReferenceGenome(name, + val jrg = JSONExtractReferenceGenome( + name, contigs.map(contig => JSONExtractContig(contig, contigLength(contig))), - xContigs, yContigs, mtContigs, - par.map(i => JSONExtractIntervalLocus(i.start.asInstanceOf[Locus], i.end.asInstanceOf[Locus]))) + xContigs, + yContigs, + mtContigs, + par.map(i => + JSONExtractIntervalLocus(i.start.asInstanceOf[Locus], i.end.asInstanceOf[Locus]) + ), + ) implicit val formats: Formats = defaultJSONFormats Serialization.write(jrg, out) } - def toJSON: JSONExtractReferenceGenome = JSONExtractReferenceGenome(name, + def toJSON: JSONExtractReferenceGenome = JSONExtractReferenceGenome( + name, contigs.map(contig => JSONExtractContig(contig, contigLength(contig))), - xContigs, yContigs, mtContigs, - par.map(i => JSONExtractIntervalLocus(i.start.asInstanceOf[Locus], i.end.asInstanceOf[Locus]))) + xContigs, + yContigs, + mtContigs, + par.map(i => JSONExtractIntervalLocus(i.start.asInstanceOf[Locus], i.end.asInstanceOf[Locus])), + ) def toJSONString: String = { implicit val formats: Formats = defaultJSONFormats @@ -526,8 +596,10 @@ object ReferenceGenome { def builtinReferences(): Map[String, ReferenceGenome] = { var builtin: Map[String, ReferenceGenome] = Map() val files = Array( - "reference/grch37.json", "reference/grch38.json", - "reference/grcm38.json", "reference/canfam3.json" + "reference/grch37.json", + "reference/grch38.json", + "reference/grcm38.json", + "reference/canfam3.json", ) for (filename <- files) { val rg = loadFromResource[ReferenceGenome](filename)(read) @@ -546,32 +618,36 @@ object ReferenceGenome { JsonMethods.parse(str).extract[JSONExtractReferenceGenome].toReferenceGenome } - def fromResource(file: String): ReferenceGenome = { + def fromResource(file: String): ReferenceGenome = loadFromResource[ReferenceGenome](file)(read) - } - def fromFile(fs: FS, file: String): ReferenceGenome = { + def fromFile(fs: FS, file: String): ReferenceGenome = using(fs.open(file))(read) - } - def fromHailDataset(fs: FS, path: String): Array[ReferenceGenome] = { + def fromHailDataset(fs: FS, path: String): Array[ReferenceGenome] = RelationalSpec.readReferences(fs, path) - } - def fromJSON(config: String): ReferenceGenome = { + def fromJSON(config: String): ReferenceGenome = parse(config) - } - def fromFASTAFile(ctx: ExecuteContext, name: String, fastaFile: String, indexFile: String, - xContigs: Array[String] = Array.empty[String], yContigs: Array[String] = Array.empty[String], - mtContigs: Array[String] = Array.empty[String], parInput: Array[String] = Array.empty[String]): ReferenceGenome = { - val tmpdir = ctx.localTmpdir + def fromFASTAFile( + ctx: ExecuteContext, + name: String, + fastaFile: String, + indexFile: String, + xContigs: Array[String] = Array.empty[String], + yContigs: Array[String] = Array.empty[String], + mtContigs: Array[String] = Array.empty[String], + parInput: Array[String] = Array.empty[String], + ): ReferenceGenome = { val fs = ctx.fs - if (!fs.exists(fastaFile)) - fatal(s"FASTA file '$fastaFile' does not exist.") - if (!fs.exists(indexFile)) - fatal(s"FASTA index file '$indexFile' does not exist.") + if (!fs.isFile(fastaFile)) + fatal(s"FASTA file '$fastaFile' does not exist, is not a file, or you do not have access.") + if (!fs.isFile(indexFile)) + fatal( + s"FASTA index file '$indexFile' does not exist, is not a file, or you do not have access." 
+ ) val index = using(fs.open(indexFile))(new FastaSequenceIndex(_)) @@ -585,27 +661,40 @@ object ReferenceGenome { lengths += (contig -> length.toInt) } - ReferenceGenome(name, contigs.result(), lengths.result().toMap, xContigs, yContigs, mtContigs, parInput) + ReferenceGenome( + name, + contigs.result(), + lengths.result().toMap, + xContigs, + yContigs, + mtContigs, + parInput, + ) } def readReferences(fs: FS, path: String): Array[ReferenceGenome] = { - if (fs.exists(path)) { - val refs = fs.listDirectory(path) - val rgs = mutable.Set[ReferenceGenome]() - refs.foreach { fileSystem => - val rgPath = fileSystem.getPath.toString - val rg = using(fs.open(rgPath))(read) - val name = rg.name - if (!rgs.contains(rg) && !hailReferences.contains(name)) - rgs += rg + val refs = + try + fs.listDirectory(path) + catch { + case _: FileNotFoundException => + return Array() } - rgs.toArray - } else Array() + + val rgs = mutable.Set[ReferenceGenome]() + refs.foreach { fileSystem => + val rgPath = fileSystem.getPath.toString + val rg = using(fs.open(rgPath))(read) + val name = rg.name + if (!rgs.contains(rg) && !hailReferences.contains(name)) + rgs += rg + } + rgs.toArray } - def writeReference(fs: FS, path: String, rg: ReferenceGenome) { + def writeReference(fs: FS, path: String, rg: ReferenceGenome): Unit = { val rgPath = path + "/" + rg.name + ".json.gz" - if (!hailReferences.contains(rg.name) && !fs.exists(rgPath)) + if (!hailReferences.contains(rg.name) && !fs.isFile(rgPath)) rg.asInstanceOf[ReferenceGenome].write(fs, rgPath) } @@ -619,9 +708,8 @@ object ReferenceGenome { rgs } - def exportReferences(fs: FS, path: String, rgs: Set[ReferenceGenome]) { + def exportReferences(fs: FS, path: String, rgs: Set[ReferenceGenome]): Unit = rgs.foreach(writeReference(fs, path, _)) - } def compare(contigsIndex: java.util.HashMap[String, Integer], c1: String, c2: String): Int = { val i1 = contigsIndex.get(c1) @@ -639,43 +727,72 @@ object ReferenceGenome { Integer.compare(l1.position, l2.position) } - def gen: Gen[ReferenceGenome] = for { - name <- Gen.identifier.filter(!ReferenceGenome.hailReferences.contains(_)) - nContigs <- Gen.choose(3, 10) - contigs <- Gen.distinctBuildableOfN[Array](nContigs, Gen.identifier) - lengths <- Gen.buildableOfN[Array](nContigs, Gen.choose(1000000, 500000000)) - contigsIndex = contigs.zip(lengths).toMap - xContig <- Gen.oneOfSeq(contigs) - parXA <- Gen.choose(0, contigsIndex(xContig)) - parXB <- Gen.choose(0, contigsIndex(xContig)) - yContig <- Gen.oneOfSeq(contigs) if yContig != xContig - parYA <- Gen.choose(0, contigsIndex(yContig)) - parYB <- Gen.choose(0, contigsIndex(yContig)) - mtContig <- Gen.oneOfSeq(contigs) if mtContig != xContig && mtContig != yContig - } yield ReferenceGenome(name, contigs, contigs.zip(lengths).toMap, Set(xContig), Set(yContig), Set(mtContig), - Array( - (Locus(xContig, math.min(parXA, parXB)), - Locus(xContig, math.max(parXA, parXB))), - (Locus(yContig, math.min(parYA, parYB)), - Locus(yContig, math.max(parYA, parYB))))) - - def apply(name: String, contigs: Array[String], lengths: Map[String, Int], xContigs: Array[String], yContigs: Array[String], - mtContigs: Array[String], parInput: Array[String]): ReferenceGenome = { + def gen: Gen[ReferenceGenome] = + for { + name <- Gen.identifier.filter(!ReferenceGenome.hailReferences.contains(_)) + nContigs <- Gen.choose(3, 10) + contigs <- Gen.distinctBuildableOfN[Array](nContigs, Gen.identifier) + lengths <- Gen.buildableOfN[Array](nContigs, Gen.choose(1000000, 500000000)) + contigsIndex = 
contigs.zip(lengths).toMap + xContig <- Gen.oneOfSeq(contigs) + parXA <- Gen.choose(0, contigsIndex(xContig)) + parXB <- Gen.choose(0, contigsIndex(xContig)) + yContig <- Gen.oneOfSeq(contigs) if yContig != xContig + parYA <- Gen.choose(0, contigsIndex(yContig)) + parYB <- Gen.choose(0, contigsIndex(yContig)) + mtContig <- Gen.oneOfSeq(contigs) if mtContig != xContig && mtContig != yContig + } yield ReferenceGenome( + name, + contigs, + contigs.zip(lengths).toMap, + Set(xContig), + Set(yContig), + Set(mtContig), + Array( + (Locus(xContig, math.min(parXA, parXB)), Locus(xContig, math.max(parXA, parXB))), + (Locus(yContig, math.min(parYA, parYB)), Locus(yContig, math.max(parYA, parYB))), + ), + ) + + def apply( + name: String, + contigs: Array[String], + lengths: Map[String, Int], + xContigs: Array[String], + yContigs: Array[String], + mtContigs: Array[String], + parInput: Array[String], + ): ReferenceGenome = { val parRegex = """(\w+):(\d+)-(\d+)""".r val par = parInput.map { - case parRegex(contig, start, end) => (Locus(contig.toString, start.toInt), Locus(contig.toString, end.toInt)) + case parRegex(contig, start, end) => + (Locus(contig.toString, start.toInt), Locus(contig.toString, end.toInt)) case _ => fatal("expected PAR input of form contig:start-end") } ReferenceGenome(name, contigs, lengths, xContigs.toSet, yContigs.toSet, mtContigs.toSet, par) } - def apply(name: java.lang.String, contigs: java.util.List[String], lengths: java.util.Map[String, Int], - xContigs: java.util.List[String], yContigs: java.util.List[String], - mtContigs: java.util.List[String], parInput: java.util.List[String]): ReferenceGenome = - ReferenceGenome(name, contigs.asScala.toArray, lengths.asScala.toMap, xContigs.asScala.toArray, yContigs.asScala.toArray, - mtContigs.asScala.toArray, parInput.asScala.toArray) + def apply( + name: java.lang.String, + contigs: java.util.List[String], + lengths: java.util.Map[String, Int], + xContigs: java.util.List[String], + yContigs: java.util.List[String], + mtContigs: java.util.List[String], + parInput: java.util.List[String], + ): ReferenceGenome = + ReferenceGenome( + name, + contigs.asScala.toArray, + lengths.asScala.toMap, + xContigs.asScala.toArray, + yContigs.asScala.toArray, + mtContigs.asScala.toArray, + parInput.asScala.toArray, + ) - def getMapFromArray(arr: Array[ReferenceGenome]): Map[String, ReferenceGenome] = arr.map(rg => (rg.name, rg)).toMap + def getMapFromArray(arr: Array[ReferenceGenome]): Map[String, ReferenceGenome] = + arr.map(rg => (rg.name, rg)).toMap } diff --git a/hail/src/main/scala/is/hail/variant/RegionValueVariant.scala b/hail/src/main/scala/is/hail/variant/RegionValueVariant.scala index 958b997241a..fe223effb09 100644 --- a/hail/src/main/scala/is/hail/variant/RegionValueVariant.scala +++ b/hail/src/main/scala/is/hail/variant/RegionValueVariant.scala @@ -1,6 +1,5 @@ package is.hail.variant -import is.hail.annotations._ import is.hail.types.physical.{PArray, PInt32, PLocus, PString, PStruct} import is.hail.utils._ @@ -18,7 +17,7 @@ class RegionValueVariant(rowType: PStruct) extends View { private var cachedAlleles: Array[String] = null private var cachedLocus: Locus = null - def set(address: Long) { + def set(address: Long): Unit = { if (!rowType.isFieldDefined(address, locusIdx)) fatal(s"The row field 'locus' cannot have missing values.") if (!rowType.isFieldDefined(address, allelesIdx)) @@ -56,7 +55,7 @@ class RegionValueVariant(rowType: PStruct) extends View { var i = 0 while (i < nAlleles) { if (taa.isElementDefined(allelesOffset, i)) - 
cachedAlleles(i) = allelePType.loadString(taa.loadElement(allelesOffset, i)) + cachedAlleles(i) = allelePType.loadString(taa.loadElement(allelesOffset, i)) i += 1 } } diff --git a/hail/src/main/scala/is/hail/variant/VariantMethods.scala b/hail/src/main/scala/is/hail/variant/VariantMethods.scala index f01f48555fd..96f5cc016bb 100644 --- a/hail/src/main/scala/is/hail/variant/VariantMethods.scala +++ b/hail/src/main/scala/is/hail/variant/VariantMethods.scala @@ -21,7 +21,7 @@ object VariantMethods { } def locusAllelesToString(locus: Locus, alleles: IndexedSeq[String]): String = - s"$locus:${ alleles(0) }:${ alleles.tail.mkString(",") }" + s"$locus:${alleles(0)}:${alleles.tail.mkString(",")}" def minRep(locus: Locus, alleles: IndexedSeq[String]): (Locus, IndexedSeq[String]) = { if (alleles.isEmpty) @@ -44,26 +44,28 @@ object VariantMethods { val min_length = math.min(ref.length, alts.map(x => x.length).min) var ne = 0 - while (ne < min_length - 1 + while ( + ne < min_length - 1 && alts.forall(x => ref(ref.length - ne - 1) == x(x.length - ne - 1)) - ) { + ) ne += 1 - } var ns = 0 - while (ns < min_length - ne - 1 + while ( + ns < min_length - ne - 1 && alts.forall(x => ref(ns) == x(ns)) - ) { + ) ns += 1 - } if (ne + ns == 0) (locus, alleles) else { assert(ns < ref.length - ne && alts.forall(x => ns < x.length - ne)) - (Locus(locus.contig, locus.position + ns), + ( + Locus(locus.contig, locus.position + ns), ref.substring(ns, ref.length - ne) +: - altAlleles.map(a => if (a == "*") a else a.substring(ns, a.length - ne)).toArray) + altAlleles.map(a => if (a == "*") a else a.substring(ns, a.length - ne)).toArray, + ) } } } @@ -74,16 +76,17 @@ object VariantSubgen { contigGen = Contig.gen(rg), nAllelesGen = Gen.frequency((5, Gen.const(2)), (1, Gen.choose(2, 10))), refGen = genDNAString, - altGen = Gen.frequency((10, genDNAString), - (1, Gen.const("*")))) + altGen = Gen.frequency((10, genDNAString), (1, Gen.const("*"))), + ) def plinkCompatible(rg: ReferenceGenome): VariantSubgen = { val r = random(rg) val compatible = (1 until 22).map(_.toString).toSet r.copy( - contigGen = r.contigGen.filter { case (contig, len) => + contigGen = r.contigGen.filter { case (contig, _) => compatible.contains(contig) - }) + } + ) } def biallelic(rg: ReferenceGenome): VariantSubgen = random(rg).copy(nAllelesGen = Gen.const(2)) @@ -96,7 +99,8 @@ case class VariantSubgen( contigGen: Gen[(String, Int)], nAllelesGen: Gen[Int], refGen: Gen[String], - altGen: Gen[String]) { + altGen: Gen[String], +) { def genLocusAlleles: Gen[Annotation] = for { @@ -106,8 +110,8 @@ case class VariantSubgen( ref <- refGen altAlleles <- Gen.distinctBuildableOfN[Array]( nAlleles - 1, - altGen) + altGen, + ) .filter(!_.contains(ref)) - } yield - Annotation(Locus(contig, start), (ref +: altAlleles).toFastSeq) + } yield Annotation(Locus(contig, start), (ref +: altAlleles).toFastSeq) } diff --git a/hail/src/main/scala/is/hail/variant/View.scala b/hail/src/main/scala/is/hail/variant/View.scala index d08a496fd09..0ace959c67b 100644 --- a/hail/src/main/scala/is/hail/variant/View.scala +++ b/hail/src/main/scala/is/hail/variant/View.scala @@ -1,5 +1,5 @@ package is.hail.variant trait View { - def set(offset: Long) + def set(offset: Long): Unit } diff --git a/hail/src/main/scala/is/hail/variant/package.scala b/hail/src/main/scala/is/hail/variant/package.scala index 08a91788a73..6634c576497 100644 --- a/hail/src/main/scala/is/hail/variant/package.scala +++ b/hail/src/main/scala/is/hail/variant/package.scala @@ -1,7 +1,5 @@ package is.hail -import 
scala.language.implicitConversions - package object variant { type Call = Int type BoxedCall = java.lang.Integer diff --git a/hail/src/main/scala/org/apache/spark/ExposedMetrics.scala b/hail/src/main/scala/org/apache/spark/ExposedMetrics.scala index 3302a130c9f..ae707a77d82 100644 --- a/hail/src/main/scala/org/apache/spark/ExposedMetrics.scala +++ b/hail/src/main/scala/org/apache/spark/ExposedMetrics.scala @@ -3,19 +3,15 @@ package org.apache.spark import org.apache.spark.executor.{InputMetrics, OutputMetrics} object ExposedMetrics { - def incrementRecord(metrics: InputMetrics) { + def incrementRecord(metrics: InputMetrics): Unit = metrics.incRecordsRead(1) - } - def incrementBytes(metrics: InputMetrics, nBytes: Long) { + def incrementBytes(metrics: InputMetrics, nBytes: Long): Unit = metrics.incBytesRead(nBytes) - } - def setBytes(metrics: OutputMetrics, nBytes: Long) { + def setBytes(metrics: OutputMetrics, nBytes: Long): Unit = metrics.setBytesWritten(nBytes) - } - def setRecords(metrics: OutputMetrics, nRecords: Long) { + def setRecords(metrics: OutputMetrics, nRecords: Long): Unit = metrics.setRecordsWritten(nRecords) - } } diff --git a/hail/src/main/scala/org/apache/spark/ExposedUtils.scala b/hail/src/main/scala/org/apache/spark/ExposedUtils.scala index 896dbc82703..f633e0215e5 100644 --- a/hail/src/main/scala/org/apache/spark/ExposedUtils.scala +++ b/hail/src/main/scala/org/apache/spark/ExposedUtils.scala @@ -1,10 +1,10 @@ package org.apache.spark +import scala.reflect._ + import org.apache.spark.serializer._ import org.apache.spark.util._ -import scala.reflect._ - object ExposedUtils { def clean[F <: AnyRef](f: F, checkSerializable: Boolean = true): F = { ClosureCleaner.clean(f, checkSerializable) diff --git a/hail/src/main/scala/is/hail/annotations/Memory.java b/hail/src/release/java/is/hail/annotations/Memory.java similarity index 100% rename from hail/src/main/scala/is/hail/annotations/Memory.java rename to hail/src/release/java/is/hail/annotations/Memory.java diff --git a/hail/src/test/resources/balding-nichols-1024-variants-4-samples-3-populations.py b/hail/src/test/resources/balding-nichols-1024-variants-4-samples-3-populations.py index c02a49e0925..2bb48510e74 100644 --- a/hail/src/test/resources/balding-nichols-1024-variants-4-samples-3-populations.py +++ b/hail/src/test/resources/balding-nichols-1024-variants-4-samples-3-populations.py @@ -1,4 +1,5 @@ import hail as hl + hl.set_global_seed(0) mt = hl.balding_nichols_model(n_populations=3, n_variants=(1 << 10), n_samples=4) mt = mt.key_cols_by(s='s' + hl.str(mt.sample_idx)) diff --git a/hail/src/test/resources/makeTestInfoScore.py b/hail/src/test/resources/makeTestInfoScore.py deleted file mode 100644 index 871f64718a6..00000000000 --- a/hail/src/test/resources/makeTestInfoScore.py +++ /dev/null @@ -1,162 +0,0 @@ -#! 
/usr/bin/python - -import sys -import os -import random - -seed = sys.argv[1] -nSamples = int(sys.argv[2]) -nVariants = int(sys.argv[3]) -root = sys.argv[4] - -random.seed(seed) - -def homRef(maf): - return (1.0 - maf) * (1.0 - maf) -def het(maf): - return 2 * maf * (1.0 - maf) -def homAlt(maf): - return maf * maf - -def randomGen(missingRate): - gps = [] - for j in range(nSamples): - if random.random() < missingRate: - gps += [0, 0, 0] - else: - d1 = random.random() - d2 = random.uniform(0, 1.0 - d1) - gps += [d1, d2, 1.0 - d1 - d2] - return gps - -def hweGen(maf, missingRate): - bb = homAlt(maf) - aa = homRef(maf) - gps = [] - for j in range(nSamples): - gt = random.random() - missing = random.random() - if missing < missingRate: - gps += [0, 0, 0] - else: - d1 = 1.0 - random.uniform(0, 0.01) - d2 = random.uniform(0, 1.0 - d1) - d3 = 1.0 - d1 - d2 - - if gt < aa: - gps += [d1, d2, d3] - elif gt >= aa and gt <= 1.0 - bb: - gps += [d2, d1, d3] - else: - gps += [d3, d2, d1] - - return gps - -def constantGen(triple, missingRate): - gps = [] - for j in range(nSamples): - if random.random() < missingRate: - gps += [0, 0, 0] - else: - gps += triple - return gps - -variants = {} -for i in range(nVariants * 0, nVariants * 1): - variants[i] = randomGen(0.0) - -for i in range(nVariants * 1, nVariants * 2): - missingRate = random.random() - variants[i] = randomGen(missingRate) - -for i in range(nVariants * 2, nVariants * 3): - maf = random.random() - variants[i] = hweGen(maf, 0.0) - -for i in range(nVariants * 3, nVariants * 4): - maf = random.random() - missingRate = random.random() - variants[i] = hweGen(maf, missingRate) - -for i in range(nVariants * 4, nVariants * 5): - missingRate = random.random() - variants[i] = constantGen([1, 0, 0], missingRate) - -for i in range(nVariants * 5, nVariants * 6): - missingRate= random.random() - variants[i]= constantGen([0, 1, 0], missingRate) - -for i in range(nVariants * 6, nVariants * 7): - missingRate= random.random() - variants[i]= constantGen([0, 0, 1], missingRate) - -variants[i + 1] = constantGen([0, 0, 0], 0.0) -variants[i + 2]= constantGen([1, 0, 0], 0.0) -variants[i + 3]= constantGen([0, 1, 0], 0.0) -variants[i + 4]= constantGen([0, 0, 1], 0.0) - -def transformDosage(dx): - w0 = dx[0] - w1 = dx[1] - w2 = dx[2] - - sumDx = w0 + w1 + w2 - - try: - l0 = int(w0 * 32768 / sumDx + 0.5) - l1 = int((w0 + w1) * 32768 / sumDx + 0.5) - l0 - l2 = 32768 - l0 - l1 - except: - print dx - sys.exit() - return [l0 / 32768.0, l1 / 32768.0, l2 / 32768.0] - -def calcInfoScore(gps): - nIncluded = 0 - e = [] - f = [] - altAllele = 0.0 - totalDosage = 0.0 - - for i in range(0, len(gps), 3): - dx = gps[i:i + 3] - if sum(dx) != 0.0: - dxt = transformDosage(dx) - nIncluded += 1 - e.append(dxt[1] + 2 * dxt[2]) - f.append(dxt[1] + 4 * dxt[2]) - altAllele += (dxt[1] + 2 *dxt[2]) - totalDosage += sum(dxt) - - z = zip(e, f) - z = [fi - ei * ei for (ei, fi) in z] - - if totalDosage == 0.0: - infoScore = None - else: - theta = altAllele / totalDosage - if theta != 0.0 and theta != 1.0: - infoScore = 1.0 - (sum(z) / (2 * float(nIncluded) * theta * (1.0 - theta))) - else: - infoScore = 1.0 - - return (infoScore, nIncluded) - - -genOutput = open(root + ".gen", 'w') -sampleOutput = open(root + ".sample", 'w') -resultOutput = open(root + ".result", 'w') - -sampleOutput.write("ID_1 ID_2 missing\n0 0 0\n") -for j in range(nSamples): - id = "sample" + str(j) - sampleOutput.write(" ".join([id, id, "0"]) + "\n") - -for v in variants: - genOutput.write("01 SNPID_{0} RSID_{0} {0} A G 
".format(v) + " ".join([str(d) for d in variants[v]]) + "\n") - (infoScore, nIncluded) = calcInfoScore(variants[v]) - resultOutput.write(" ".join(["01:{0}:A:G SNPID_{0} RSID_{0}".format(v), str(infoScore), str(nIncluded)]) + "\n") - -genOutput.close() -sampleOutput.close() -resultOutput.close() diff --git a/hail/src/test/resources/missingInfoArray.vcf b/hail/src/test/resources/missingInfoArray.vcf index 913ffb8d1a5..b65ac6fe572 100644 --- a/hail/src/test/resources/missingInfoArray.vcf +++ b/hail/src/test/resources/missingInfoArray.vcf @@ -9,6 +9,8 @@ ##FORMAT= ##INFO= ##INFO= +##INFO= +##INFO= #CHROM POS ID REF ALT QUAL FILTER INFO FORMAT C1046::HG02024 C1046::HG02025 -X 16050036 . A C 19961.13 . FOO=1,.;BAR=2,.,. GT:GTA:GTZ:AD:DP:GQ:PL 0/0:./.:0/1:10,0:10:44:0,44,180 1:.:0:0,6:7:70:70,0 -X 16061250 . T A,C 547794.46 . FOO=.,2,.;BAR=.,1.0,. GT:GTA:GTZ:AD:DP:GQ:PL 2/2:2/1:1/1:0,0,11:11:33:396,402,411,33,33,0 2:.:1:0,0,9:9:24:24,40,0 +X 16050036 . A C 19961.13 . FOO=1,.;BAR=2,.,.;JUST_A_DOT=. GT:GTA:GTZ:AD:DP:GQ:PL 0/0:./.:0/1:10,0:10:44:0,44,180 1:.:0:0,6:7:70:70,0 +X 16061250 . T A,C 547794.46 . FOO=.,2,.;BAR=.,1.0,.;JUST_A_DOT=. GT:GTA:GTZ:AD:DP:GQ:PL 2/2:2/1:1/1:0,0,11:11:33:396,402,411,33,33,0 2:.:1:0,0,9:9:24:24,40,0 diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/README.txt b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/README.txt new file mode 100644 index 00000000000..62f13d3fe96 --- /dev/null +++ b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/README.txt @@ -0,0 +1,3 @@ +This folder comprises a Hail (www.hail.is) native Table or MatrixTable. + Written with version 0.2.128-705d4033e0c9 + Created at 2024/03/27 12:03:10 \ No newline at end of file diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/_SUCCESS b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/_SUCCESS new file mode 100644 index 00000000000..e69de29bb2d diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/README.txt b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/README.txt new file mode 100644 index 00000000000..62f13d3fe96 --- /dev/null +++ b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/README.txt @@ -0,0 +1,3 @@ +This folder comprises a Hail (www.hail.is) native Table or MatrixTable. 
+ Written with version 0.2.128-705d4033e0c9 + Created at 2024/03/27 12:03:10 \ No newline at end of file diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/_SUCCESS b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/_SUCCESS new file mode 100644 index 00000000000..e69de29bb2d diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/metadata.json.gz new file mode 100644 index 00000000000..ba853f2bd33 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/rows/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/rows/metadata.json.gz new file mode 100644 index 00000000000..4b51c50a941 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/rows/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/rows/parts/part-0 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/rows/parts/part-0 new file mode 100644 index 00000000000..5c15140d10d Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/cols/rows/parts/part-0 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/README.txt b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/README.txt new file mode 100644 index 00000000000..62f13d3fe96 --- /dev/null +++ b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/README.txt @@ -0,0 +1,3 @@ +This folder comprises a Hail (www.hail.is) native Table or MatrixTable. 
+ Written with version 0.2.128-705d4033e0c9 + Created at 2024/03/27 12:03:10 \ No newline at end of file diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/_SUCCESS b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/_SUCCESS new file mode 100644 index 00000000000..e69de29bb2d diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/metadata.json.gz new file mode 100644 index 00000000000..ca138832fdd Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/metadata.json.gz new file mode 100644 index 00000000000..5d5444d5a3a Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-00-cdb826da-6c5c-47b6-945b-3190a87a6a14 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-00-cdb826da-6c5c-47b6-945b-3190a87a6a14 new file mode 100644 index 00000000000..46ce64cb28e Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-00-cdb826da-6c5c-47b6-945b-3190a87a6a14 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-01-06f6a507-61e2-4bd1-a917-e1809270144c b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-01-06f6a507-61e2-4bd1-a917-e1809270144c new file mode 100644 index 00000000000..9770ebd72d4 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-01-06f6a507-61e2-4bd1-a917-e1809270144c differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-02-881d024c-5baf-4fe6-bc8f-53eda3845bde b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-02-881d024c-5baf-4fe6-bc8f-53eda3845bde new file mode 100644 index 00000000000..58aa83882f4 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-02-881d024c-5baf-4fe6-bc8f-53eda3845bde differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-03-1e085a57-4dcb-4131-bc79-353324ffad47 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-03-1e085a57-4dcb-4131-bc79-353324ffad47 new file mode 100644 index 00000000000..557d34f510f Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-03-1e085a57-4dcb-4131-bc79-353324ffad47 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-04-d17ed9aa-6b33-4b0b-85d5-578da32f7581 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-04-d17ed9aa-6b33-4b0b-85d5-578da32f7581 new file mode 100644 index 00000000000..d3ca64b40ba Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-04-d17ed9aa-6b33-4b0b-85d5-578da32f7581 differ diff 
--git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-05-40d512f8-23ba-485e-aefa-47eced2bfe6d b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-05-40d512f8-23ba-485e-aefa-47eced2bfe6d new file mode 100644 index 00000000000..728043e4dd0 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-05-40d512f8-23ba-485e-aefa-47eced2bfe6d differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-06-9b2dc9c7-c8b1-4ed4-9056-20142b5f6658 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-06-9b2dc9c7-c8b1-4ed4-9056-20142b5f6658 new file mode 100644 index 00000000000..ea21ec1388a Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-06-9b2dc9c7-c8b1-4ed4-9056-20142b5f6658 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-07-b9a32d97-cb10-4158-aeaa-645dcea68ca7 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-07-b9a32d97-cb10-4158-aeaa-645dcea68ca7 new file mode 100644 index 00000000000..850e93422b8 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-07-b9a32d97-cb10-4158-aeaa-645dcea68ca7 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-08-c2a0123f-a3d4-4b80-9c21-73cb2bed0b63 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-08-c2a0123f-a3d4-4b80-9c21-73cb2bed0b63 new file mode 100644 index 00000000000..a2747db6ac5 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-08-c2a0123f-a3d4-4b80-9c21-73cb2bed0b63 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-09-ca197aee-6bfd-4068-b771-e9ca63551a7c b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-09-ca197aee-6bfd-4068-b771-e9ca63551a7c new file mode 100644 index 00000000000..e20bed4dea8 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-09-ca197aee-6bfd-4068-b771-e9ca63551a7c differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-10-17048169-a98b-49ee-ae4d-62641023b3ac b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-10-17048169-a98b-49ee-ae4d-62641023b3ac new file mode 100644 index 00000000000..5020435b81b Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-10-17048169-a98b-49ee-ae4d-62641023b3ac differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-11-c89858f5-4d78-4739-af31-308a1c257ff4 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-11-c89858f5-4d78-4739-af31-308a1c257ff4 new file mode 100644 index 00000000000..f2ddc1d151a Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-11-c89858f5-4d78-4739-af31-308a1c257ff4 differ diff --git 
a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-12-3e391e78-782d-495d-a29c-cacc56e1baf8 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-12-3e391e78-782d-495d-a29c-cacc56e1baf8 new file mode 100644 index 00000000000..329bf68b7bb Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-12-3e391e78-782d-495d-a29c-cacc56e1baf8 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-13-62566d28-e496-4538-a325-b567be66accf b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-13-62566d28-e496-4538-a325-b567be66accf new file mode 100644 index 00000000000..78c76c8d04e Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-13-62566d28-e496-4538-a325-b567be66accf differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-14-8ab32ab7-15cd-4302-bb45-6b3dc02db5b6 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-14-8ab32ab7-15cd-4302-bb45-6b3dc02db5b6 new file mode 100644 index 00000000000..c5756714bb1 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-14-8ab32ab7-15cd-4302-bb45-6b3dc02db5b6 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-15-c4301966-4fd8-4ea0-b439-b49a693bf683 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-15-c4301966-4fd8-4ea0-b439-b49a693bf683 new file mode 100644 index 00000000000..b3a18a17ea0 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-15-c4301966-4fd8-4ea0-b439-b49a693bf683 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-16-8d638c2e-b1a5-4507-ba00-337a02e3f431 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-16-8d638c2e-b1a5-4507-ba00-337a02e3f431 new file mode 100644 index 00000000000..325ba793d14 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-16-8d638c2e-b1a5-4507-ba00-337a02e3f431 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-17-0c739863-b5fe-4e33-8f47-3e2751b599df b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-17-0c739863-b5fe-4e33-8f47-3e2751b599df new file mode 100644 index 00000000000..621aa82fb97 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-17-0c739863-b5fe-4e33-8f47-3e2751b599df differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-18-35d65ae7-5d1d-43f8-bb21-e6565874975e b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-18-35d65ae7-5d1d-43f8-bb21-e6565874975e new file mode 100644 index 00000000000..788d0cd093f Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-18-35d65ae7-5d1d-43f8-bb21-e6565874975e differ diff --git 
a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-19-2fd81de2-5d34-43db-809d-2f1fe1e67200 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-19-2fd81de2-5d34-43db-809d-2f1fe1e67200 new file mode 100644 index 00000000000..b58aaa00ba7 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/entries/rows/parts/part-19-2fd81de2-5d34-43db-809d-2f1fe1e67200 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/README.txt b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/README.txt new file mode 100644 index 00000000000..62f13d3fe96 --- /dev/null +++ b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/README.txt @@ -0,0 +1,3 @@ +This folder comprises a Hail (www.hail.is) native Table or MatrixTable. + Written with version 0.2.128-705d4033e0c9 + Created at 2024/03/27 12:03:10 \ No newline at end of file diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/_SUCCESS b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/_SUCCESS new file mode 100644 index 00000000000..e69de29bb2d diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/globals/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/globals/metadata.json.gz new file mode 100644 index 00000000000..369b04d91ac Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/globals/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/globals/parts/part-0 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/globals/parts/part-0 new file mode 100644 index 00000000000..89e711531de Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/globals/parts/part-0 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/metadata.json.gz new file mode 100644 index 00000000000..89b81c893f6 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/rows/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/rows/metadata.json.gz new file mode 100644 index 00000000000..369b04d91ac Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/rows/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/rows/parts/part-0 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/rows/parts/part-0 new file mode 100644 index 00000000000..89e711531de Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/globals/rows/parts/part-0 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-00-cdb826da-6c5c-47b6-945b-3190a87a6a14.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-00-cdb826da-6c5c-47b6-945b-3190a87a6a14.idx/index new file mode 100644 index 00000000000..627c0967eef Binary files /dev/null and 
b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-00-cdb826da-6c5c-47b6-945b-3190a87a6a14.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-00-cdb826da-6c5c-47b6-945b-3190a87a6a14.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-00-cdb826da-6c5c-47b6-945b-3190a87a6a14.idx/metadata.json.gz new file mode 100644 index 00000000000..7d3e3e8fc77 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-00-cdb826da-6c5c-47b6-945b-3190a87a6a14.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-01-06f6a507-61e2-4bd1-a917-e1809270144c.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-01-06f6a507-61e2-4bd1-a917-e1809270144c.idx/index new file mode 100644 index 00000000000..a43c7360c42 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-01-06f6a507-61e2-4bd1-a917-e1809270144c.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-01-06f6a507-61e2-4bd1-a917-e1809270144c.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-01-06f6a507-61e2-4bd1-a917-e1809270144c.idx/metadata.json.gz new file mode 100644 index 00000000000..685b6946a41 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-01-06f6a507-61e2-4bd1-a917-e1809270144c.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-02-881d024c-5baf-4fe6-bc8f-53eda3845bde.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-02-881d024c-5baf-4fe6-bc8f-53eda3845bde.idx/index new file mode 100644 index 00000000000..2d6ef5bfb89 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-02-881d024c-5baf-4fe6-bc8f-53eda3845bde.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-02-881d024c-5baf-4fe6-bc8f-53eda3845bde.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-02-881d024c-5baf-4fe6-bc8f-53eda3845bde.idx/metadata.json.gz new file mode 100644 index 00000000000..6441fab8db8 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-02-881d024c-5baf-4fe6-bc8f-53eda3845bde.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-03-1e085a57-4dcb-4131-bc79-353324ffad47.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-03-1e085a57-4dcb-4131-bc79-353324ffad47.idx/index new file mode 100644 index 00000000000..d42ae3f6d26 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-03-1e085a57-4dcb-4131-bc79-353324ffad47.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-03-1e085a57-4dcb-4131-bc79-353324ffad47.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-03-1e085a57-4dcb-4131-bc79-353324ffad47.idx/metadata.json.gz new file mode 100644 index 00000000000..025d6ff3782 Binary files /dev/null and 
b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-03-1e085a57-4dcb-4131-bc79-353324ffad47.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-04-d17ed9aa-6b33-4b0b-85d5-578da32f7581.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-04-d17ed9aa-6b33-4b0b-85d5-578da32f7581.idx/index new file mode 100644 index 00000000000..e77f2844cd2 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-04-d17ed9aa-6b33-4b0b-85d5-578da32f7581.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-04-d17ed9aa-6b33-4b0b-85d5-578da32f7581.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-04-d17ed9aa-6b33-4b0b-85d5-578da32f7581.idx/metadata.json.gz new file mode 100644 index 00000000000..685b6946a41 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-04-d17ed9aa-6b33-4b0b-85d5-578da32f7581.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-05-40d512f8-23ba-485e-aefa-47eced2bfe6d.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-05-40d512f8-23ba-485e-aefa-47eced2bfe6d.idx/index new file mode 100644 index 00000000000..4606b87a655 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-05-40d512f8-23ba-485e-aefa-47eced2bfe6d.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-05-40d512f8-23ba-485e-aefa-47eced2bfe6d.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-05-40d512f8-23ba-485e-aefa-47eced2bfe6d.idx/metadata.json.gz new file mode 100644 index 00000000000..dff25d235e5 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-05-40d512f8-23ba-485e-aefa-47eced2bfe6d.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-06-9b2dc9c7-c8b1-4ed4-9056-20142b5f6658.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-06-9b2dc9c7-c8b1-4ed4-9056-20142b5f6658.idx/index new file mode 100644 index 00000000000..84c9027fccb Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-06-9b2dc9c7-c8b1-4ed4-9056-20142b5f6658.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-06-9b2dc9c7-c8b1-4ed4-9056-20142b5f6658.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-06-9b2dc9c7-c8b1-4ed4-9056-20142b5f6658.idx/metadata.json.gz new file mode 100644 index 00000000000..685b6946a41 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-06-9b2dc9c7-c8b1-4ed4-9056-20142b5f6658.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-07-b9a32d97-cb10-4158-aeaa-645dcea68ca7.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-07-b9a32d97-cb10-4158-aeaa-645dcea68ca7.idx/index new file mode 100644 index 00000000000..eda3e9c9b38 Binary files /dev/null and 
b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-07-b9a32d97-cb10-4158-aeaa-645dcea68ca7.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-07-b9a32d97-cb10-4158-aeaa-645dcea68ca7.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-07-b9a32d97-cb10-4158-aeaa-645dcea68ca7.idx/metadata.json.gz new file mode 100644 index 00000000000..43a1c4fe7eb Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-07-b9a32d97-cb10-4158-aeaa-645dcea68ca7.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-08-c2a0123f-a3d4-4b80-9c21-73cb2bed0b63.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-08-c2a0123f-a3d4-4b80-9c21-73cb2bed0b63.idx/index new file mode 100644 index 00000000000..5460a0ff2a7 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-08-c2a0123f-a3d4-4b80-9c21-73cb2bed0b63.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-08-c2a0123f-a3d4-4b80-9c21-73cb2bed0b63.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-08-c2a0123f-a3d4-4b80-9c21-73cb2bed0b63.idx/metadata.json.gz new file mode 100644 index 00000000000..685b6946a41 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-08-c2a0123f-a3d4-4b80-9c21-73cb2bed0b63.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-09-ca197aee-6bfd-4068-b771-e9ca63551a7c.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-09-ca197aee-6bfd-4068-b771-e9ca63551a7c.idx/index new file mode 100644 index 00000000000..08d2e54c744 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-09-ca197aee-6bfd-4068-b771-e9ca63551a7c.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-09-ca197aee-6bfd-4068-b771-e9ca63551a7c.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-09-ca197aee-6bfd-4068-b771-e9ca63551a7c.idx/metadata.json.gz new file mode 100644 index 00000000000..6e9d0a13d7e Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-09-ca197aee-6bfd-4068-b771-e9ca63551a7c.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-10-17048169-a98b-49ee-ae4d-62641023b3ac.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-10-17048169-a98b-49ee-ae4d-62641023b3ac.idx/index new file mode 100644 index 00000000000..eb46615541c Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-10-17048169-a98b-49ee-ae4d-62641023b3ac.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-10-17048169-a98b-49ee-ae4d-62641023b3ac.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-10-17048169-a98b-49ee-ae4d-62641023b3ac.idx/metadata.json.gz new file mode 100644 index 00000000000..8141d4d180b Binary files /dev/null and 
b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-10-17048169-a98b-49ee-ae4d-62641023b3ac.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-11-c89858f5-4d78-4739-af31-308a1c257ff4.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-11-c89858f5-4d78-4739-af31-308a1c257ff4.idx/index new file mode 100644 index 00000000000..075b284619d Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-11-c89858f5-4d78-4739-af31-308a1c257ff4.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-11-c89858f5-4d78-4739-af31-308a1c257ff4.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-11-c89858f5-4d78-4739-af31-308a1c257ff4.idx/metadata.json.gz new file mode 100644 index 00000000000..43a1c4fe7eb Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-11-c89858f5-4d78-4739-af31-308a1c257ff4.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-12-3e391e78-782d-495d-a29c-cacc56e1baf8.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-12-3e391e78-782d-495d-a29c-cacc56e1baf8.idx/index new file mode 100644 index 00000000000..32dde649c5f Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-12-3e391e78-782d-495d-a29c-cacc56e1baf8.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-12-3e391e78-782d-495d-a29c-cacc56e1baf8.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-12-3e391e78-782d-495d-a29c-cacc56e1baf8.idx/metadata.json.gz new file mode 100644 index 00000000000..685b6946a41 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-12-3e391e78-782d-495d-a29c-cacc56e1baf8.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-13-62566d28-e496-4538-a325-b567be66accf.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-13-62566d28-e496-4538-a325-b567be66accf.idx/index new file mode 100644 index 00000000000..74a6b5222be Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-13-62566d28-e496-4538-a325-b567be66accf.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-13-62566d28-e496-4538-a325-b567be66accf.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-13-62566d28-e496-4538-a325-b567be66accf.idx/metadata.json.gz new file mode 100644 index 00000000000..6e9d0a13d7e Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-13-62566d28-e496-4538-a325-b567be66accf.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-14-8ab32ab7-15cd-4302-bb45-6b3dc02db5b6.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-14-8ab32ab7-15cd-4302-bb45-6b3dc02db5b6.idx/index new file mode 100644 index 00000000000..de143b10175 Binary files /dev/null and 
b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-14-8ab32ab7-15cd-4302-bb45-6b3dc02db5b6.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-14-8ab32ab7-15cd-4302-bb45-6b3dc02db5b6.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-14-8ab32ab7-15cd-4302-bb45-6b3dc02db5b6.idx/metadata.json.gz new file mode 100644 index 00000000000..685b6946a41 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-14-8ab32ab7-15cd-4302-bb45-6b3dc02db5b6.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-15-c4301966-4fd8-4ea0-b439-b49a693bf683.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-15-c4301966-4fd8-4ea0-b439-b49a693bf683.idx/index new file mode 100644 index 00000000000..443c522300a Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-15-c4301966-4fd8-4ea0-b439-b49a693bf683.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-15-c4301966-4fd8-4ea0-b439-b49a693bf683.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-15-c4301966-4fd8-4ea0-b439-b49a693bf683.idx/metadata.json.gz new file mode 100644 index 00000000000..025d6ff3782 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-15-c4301966-4fd8-4ea0-b439-b49a693bf683.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-16-8d638c2e-b1a5-4507-ba00-337a02e3f431.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-16-8d638c2e-b1a5-4507-ba00-337a02e3f431.idx/index new file mode 100644 index 00000000000..163abab71e0 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-16-8d638c2e-b1a5-4507-ba00-337a02e3f431.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-16-8d638c2e-b1a5-4507-ba00-337a02e3f431.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-16-8d638c2e-b1a5-4507-ba00-337a02e3f431.idx/metadata.json.gz new file mode 100644 index 00000000000..6e9d0a13d7e Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-16-8d638c2e-b1a5-4507-ba00-337a02e3f431.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-17-0c739863-b5fe-4e33-8f47-3e2751b599df.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-17-0c739863-b5fe-4e33-8f47-3e2751b599df.idx/index new file mode 100644 index 00000000000..e592bce7110 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-17-0c739863-b5fe-4e33-8f47-3e2751b599df.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-17-0c739863-b5fe-4e33-8f47-3e2751b599df.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-17-0c739863-b5fe-4e33-8f47-3e2751b599df.idx/metadata.json.gz new file mode 100644 index 00000000000..6e9d0a13d7e Binary files /dev/null and 
b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-17-0c739863-b5fe-4e33-8f47-3e2751b599df.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-18-35d65ae7-5d1d-43f8-bb21-e6565874975e.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-18-35d65ae7-5d1d-43f8-bb21-e6565874975e.idx/index new file mode 100644 index 00000000000..c67d666d1a5 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-18-35d65ae7-5d1d-43f8-bb21-e6565874975e.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-18-35d65ae7-5d1d-43f8-bb21-e6565874975e.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-18-35d65ae7-5d1d-43f8-bb21-e6565874975e.idx/metadata.json.gz new file mode 100644 index 00000000000..6e9d0a13d7e Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-18-35d65ae7-5d1d-43f8-bb21-e6565874975e.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-19-2fd81de2-5d34-43db-809d-2f1fe1e67200.idx/index b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-19-2fd81de2-5d34-43db-809d-2f1fe1e67200.idx/index new file mode 100644 index 00000000000..a3b1b24d5f7 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-19-2fd81de2-5d34-43db-809d-2f1fe1e67200.idx/index differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-19-2fd81de2-5d34-43db-809d-2f1fe1e67200.idx/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-19-2fd81de2-5d34-43db-809d-2f1fe1e67200.idx/metadata.json.gz new file mode 100644 index 00000000000..025d6ff3782 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/index/part-19-2fd81de2-5d34-43db-809d-2f1fe1e67200.idx/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/metadata.json.gz new file mode 100644 index 00000000000..ab3049a6bcf Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/README.txt b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/README.txt new file mode 100644 index 00000000000..62f13d3fe96 --- /dev/null +++ b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/README.txt @@ -0,0 +1,3 @@ +This folder comprises a Hail (www.hail.is) native Table or MatrixTable. 
+ Written with version 0.2.128-705d4033e0c9 + Created at 2024/03/27 12:03:10 \ No newline at end of file diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/_SUCCESS b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/_SUCCESS new file mode 100644 index 00000000000..e69de29bb2d diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/metadata.json b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/metadata.json new file mode 100644 index 00000000000..93b22d27737 --- /dev/null +++ b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/metadata.json @@ -0,0 +1 @@ +{"file_version":67328,"hail_version":"0.2.128-705d4033e0c9","references_rel_path":"../references","table_type":"Table{global:Struct{},key:[locus],row:Struct{locus:Locus(GRCh37),alleles:Array[String],rsid:String,qual:Float64,filters:Set[String],info:Struct{NEGATIVE_TRAIN_SITE:Boolean,HWP:Float64,AC:Array[Int32],culprit:String,MQ0:Int32,ReadPosRankSum:Float64,AN:Int32,InbreedingCoeff:Float64,AF:Array[Float64],GQ_STDDEV:Float64,FS:Float64,DP:Int32,GQ_MEAN:Float64,POSITIVE_TRAIN_SITE:Boolean,VQSLOD:Float64,ClippingRankSum:Float64,BaseQRankSum:Float64,MLEAF:Array[Float64],MLEAC:Array[Int32],MQ:Float64,QD:Float64,END:Int32,DB:Boolean,HaplotypeScore:Float64,MQRankSum:Float64,CCC:Int32,NCC:Int32,DS:Boolean}}}","components":{"globals":{"name":"RVDComponentSpec","rel_path":"../globals/rows"},"rows":{"name":"RVDComponentSpec","rel_path":"rows"},"partition_counts":{"name":"PartitionCountsComponentSpec","counts":[18,17,17,18,17,17,17,18,17,17,17,18,17,17,17,18,17,17,17,18]},"properties":{"name":"PropertiesSpec","properties":{"distinctlyKeyed":false}}},"name":"TableSpec"} \ No newline at end of file diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/metadata.json.gz new file mode 100644 index 00000000000..08dc5527ba2 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/metadata.json b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/metadata.json new file mode 100644 index 00000000000..ac31fd099cf --- /dev/null +++ b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/metadata.json @@ -0,0 +1 @@ 
+{"name":"IndexedRVDSpec2","_key":["locus"],"_codecSpec":{"name":"TypedCodecSpec","_eType":"+EBaseStruct{locus:EBaseStruct{contig:+EBinary,position:+EInt32},alleles:EArray[EBinary],rsid:EBinary,qual:EFloat64,filters:EArray[EBinary],info:EBaseStruct{NEGATIVE_TRAIN_SITE:EBoolean,HWP:EFloat64,AC:EArray[EInt32],culprit:EBinary,MQ0:EInt32,ReadPosRankSum:EFloat64,AN:EInt32,InbreedingCoeff:EFloat64,AF:EArray[EFloat64],GQ_STDDEV:EFloat64,FS:EFloat64,DP:EInt32,GQ_MEAN:EFloat64,POSITIVE_TRAIN_SITE:EBoolean,VQSLOD:EFloat64,ClippingRankSum:EFloat64,BaseQRankSum:EFloat64,MLEAF:EArray[EFloat64],MLEAC:EArray[EInt32],MQ:EFloat64,QD:EFloat64,END:EInt32,DB:EBoolean,HaplotypeScore:EFloat64,MQRankSum:EFloat64,CCC:EInt32,NCC:EInt32,DS:EBoolean}}","_vType":"Struct{locus:Locus(GRCh37),alleles:Array[String],rsid:String,qual:Float64,filters:Set[String],info:Struct{NEGATIVE_TRAIN_SITE:Boolean,HWP:Float64,AC:Array[Int32],culprit:String,MQ0:Int32,ReadPosRankSum:Float64,AN:Int32,InbreedingCoeff:Float64,AF:Array[Float64],GQ_STDDEV:Float64,FS:Float64,DP:Int32,GQ_MEAN:Float64,POSITIVE_TRAIN_SITE:Boolean,VQSLOD:Float64,ClippingRankSum:Float64,BaseQRankSum:Float64,MLEAF:Array[Float64],MLEAC:Array[Int32],MQ:Float64,QD:Float64,END:Int32,DB:Boolean,HaplotypeScore:Float64,MQRankSum:Float64,CCC:Int32,NCC:Int32,DS:Boolean}}","_bufferSpec":{"name":"LEB128BufferSpec","child":{"name":"BlockingBufferSpec","blockSize":65536,"child":{"name":"ZstdBlockBufferSpec","blockSize":65536,"child":{"name":"StreamBlockBufferSpec"}}}}},"_indexSpec":{"name":"IndexSpec2","_relPath":"../../index","_leafCodec":{"name":"TypedCodecSpec","_eType":"EBaseStruct{first_idx:+EInt64,keys:+EArray[+EBaseStruct{key:+EBaseStruct{locus:EBaseStruct{contig:+EBinary,position:+EInt32}},offset:+EInt64,annotation:+EBaseStruct{entries_offset:EInt64}}]}","_vType":"Struct{first_idx:Int64,keys:Array[Struct{key:Struct{locus:Locus(GRCh37)},offset:Int64,annotation:Struct{entries_offset:Int64}}]}","_bufferSpec":{"name":"LEB128BufferSpec","child":{"name":"BlockingBufferSpec","blockSize":65536,"child":{"name":"ZstdBlockBufferSpec","blockSize":65536,"child":{"name":"StreamBlockBufferSpec"}}}}},"_internalNodeCodec":{"name":"TypedCodecSpec","_eType":"EBaseStruct{children:+EArray[+EBaseStruct{index_file_offset:+EInt64,first_idx:+EInt64,first_key:+EBaseStruct{locus:EBaseStruct{contig:+EBinary,position:+EInt32}},first_record_offset:+EInt64,first_annotation:+EBaseStruct{entries_offset:EInt64}}]}","_vType":"Struct{children:Array[Struct{index_file_offset:Int64,first_idx:Int64,first_key:Struct{locus:Locus(GRCh37)},first_record_offset:Int64,first_annotation:Struct{entries_offset:Int64}}]}","_bufferSpec":{"name":"LEB128BufferSpec","child":{"name":"BlockingBufferSpec","blockSize":65536,"child":{"name":"ZstdBlockBufferSpec","blockSize":65536,"child":{"name":"StreamBlockBufferSpec"}}}}},"_keyType":"Struct{locus:Locus(GRCh37)}","_annotationType":"Struct{entries_offset:Int64}"},"_partFiles":["part-00-cdb826da-6c5c-47b6-945b-3190a87a6a14","part-01-06f6a507-61e2-4bd1-a917-e1809270144c","part-02-881d024c-5baf-4fe6-bc8f-53eda3845bde","part-03-1e085a57-4dcb-4131-bc79-353324ffad47","part-04-d17ed9aa-6b33-4b0b-85d5-578da32f7581","part-05-40d512f8-23ba-485e-aefa-47eced2bfe6d","part-06-9b2dc9c7-c8b1-4ed4-9056-20142b5f6658","part-07-b9a32d97-cb10-4158-aeaa-645dcea68ca7","part-08-c2a0123f-a3d4-4b80-9c21-73cb2bed0b63","part-09-ca197aee-6bfd-4068-b771-e9ca63551a7c","part-10-17048169-a98b-49ee-ae4d-62641023b3ac","part-11-c89858f5-4d78-4739-af31-308a1c257ff4","part-12-3e391e78-782d-495d-a29c-cacc56e1baf8","part-
13-62566d28-e496-4538-a325-b567be66accf","part-14-8ab32ab7-15cd-4302-bb45-6b3dc02db5b6","part-15-c4301966-4fd8-4ea0-b439-b49a693bf683","part-16-8d638c2e-b1a5-4507-ba00-337a02e3f431","part-17-0c739863-b5fe-4e33-8f47-3e2751b599df","part-18-35d65ae7-5d1d-43f8-bb21-e6565874975e","part-19-2fd81de2-5d34-43db-809d-2f1fe1e67200"],"_jRangeBounds":[{"start":{"locus":{"contig":"20","position":10019093}},"end":{"locus":{"contig":"20","position":10286773}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":10286773}},"end":{"locus":{"contig":"20","position":10603326}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":10603326}},"end":{"locus":{"contig":"20","position":10625804}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":10625804}},"end":{"locus":{"contig":"20","position":10653469}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":10653469}},"end":{"locus":{"contig":"20","position":13071871}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":13071871}},"end":{"locus":{"contig":"20","position":13260252}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":13260252}},"end":{"locus":{"contig":"20","position":13561632}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":13561632}},"end":{"locus":{"contig":"20","position":13709115}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":13709115}},"end":{"locus":{"contig":"20","position":13798776}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":13798776}},"end":{"locus":{"contig":"20","position":14032627}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":14032627}},"end":{"locus":{"contig":"20","position":15948325}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":15948325}},"end":{"locus":{"contig":"20","position":16347823}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":16347823}},"end":{"locus":{"contig":"20","position":16410559}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":16410559}},"end":{"locus":{"contig":"20","position":17410116}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":17410116}},"end":{"locus":{"contig":"20","position":17475217}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":17475217}},"end":{"locus":{"contig":"20","position":17595540}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":17595540}},"end":{"locus":{"contig":"20","position":17600357}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":17600357}},"end":{"locus":{"contig":"20","position":17608348}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":17608348}},"end":{"locus":{"contig":"20","position":17705709}},"includeStart":true,"includeEnd":true},{"start":{"locus":{"contig":"20","position":17705709}},"end":{"locus":{"contig":"20","position":17970876}},"includeStart":true,"includeEnd":true}],"_attrs":{}} \ No newline at end of file diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/metadata.json.gz b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/metadata.json.gz new file mode 100644 index 00000000000..94fa69a4b5e 
Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/metadata.json.gz differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-00-cdb826da-6c5c-47b6-945b-3190a87a6a14 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-00-cdb826da-6c5c-47b6-945b-3190a87a6a14 new file mode 100644 index 00000000000..bd9fde1e1a7 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-00-cdb826da-6c5c-47b6-945b-3190a87a6a14 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-01-06f6a507-61e2-4bd1-a917-e1809270144c b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-01-06f6a507-61e2-4bd1-a917-e1809270144c new file mode 100644 index 00000000000..a1cd5b98f8a Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-01-06f6a507-61e2-4bd1-a917-e1809270144c differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-02-881d024c-5baf-4fe6-bc8f-53eda3845bde b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-02-881d024c-5baf-4fe6-bc8f-53eda3845bde new file mode 100644 index 00000000000..afb8ce8cc61 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-02-881d024c-5baf-4fe6-bc8f-53eda3845bde differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-03-1e085a57-4dcb-4131-bc79-353324ffad47 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-03-1e085a57-4dcb-4131-bc79-353324ffad47 new file mode 100644 index 00000000000..768145d1894 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-03-1e085a57-4dcb-4131-bc79-353324ffad47 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-04-d17ed9aa-6b33-4b0b-85d5-578da32f7581 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-04-d17ed9aa-6b33-4b0b-85d5-578da32f7581 new file mode 100644 index 00000000000..6239f2a11bc Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-04-d17ed9aa-6b33-4b0b-85d5-578da32f7581 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-05-40d512f8-23ba-485e-aefa-47eced2bfe6d b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-05-40d512f8-23ba-485e-aefa-47eced2bfe6d new file mode 100644 index 00000000000..7e2b979ea0a Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-05-40d512f8-23ba-485e-aefa-47eced2bfe6d differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-06-9b2dc9c7-c8b1-4ed4-9056-20142b5f6658 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-06-9b2dc9c7-c8b1-4ed4-9056-20142b5f6658 new file mode 100644 index 00000000000..b45a4a1edac Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-06-9b2dc9c7-c8b1-4ed4-9056-20142b5f6658 differ diff --git 
a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-07-b9a32d97-cb10-4158-aeaa-645dcea68ca7 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-07-b9a32d97-cb10-4158-aeaa-645dcea68ca7 new file mode 100644 index 00000000000..ed86b73273c Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-07-b9a32d97-cb10-4158-aeaa-645dcea68ca7 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-08-c2a0123f-a3d4-4b80-9c21-73cb2bed0b63 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-08-c2a0123f-a3d4-4b80-9c21-73cb2bed0b63 new file mode 100644 index 00000000000..18f8b835340 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-08-c2a0123f-a3d4-4b80-9c21-73cb2bed0b63 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-09-ca197aee-6bfd-4068-b771-e9ca63551a7c b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-09-ca197aee-6bfd-4068-b771-e9ca63551a7c new file mode 100644 index 00000000000..970968a799c Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-09-ca197aee-6bfd-4068-b771-e9ca63551a7c differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-10-17048169-a98b-49ee-ae4d-62641023b3ac b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-10-17048169-a98b-49ee-ae4d-62641023b3ac new file mode 100644 index 00000000000..370ced208ec Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-10-17048169-a98b-49ee-ae4d-62641023b3ac differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-11-c89858f5-4d78-4739-af31-308a1c257ff4 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-11-c89858f5-4d78-4739-af31-308a1c257ff4 new file mode 100644 index 00000000000..1977621af1c Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-11-c89858f5-4d78-4739-af31-308a1c257ff4 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-12-3e391e78-782d-495d-a29c-cacc56e1baf8 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-12-3e391e78-782d-495d-a29c-cacc56e1baf8 new file mode 100644 index 00000000000..433d5edfe3f Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-12-3e391e78-782d-495d-a29c-cacc56e1baf8 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-13-62566d28-e496-4538-a325-b567be66accf b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-13-62566d28-e496-4538-a325-b567be66accf new file mode 100644 index 00000000000..a993f23c317 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-13-62566d28-e496-4538-a325-b567be66accf differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-14-8ab32ab7-15cd-4302-bb45-6b3dc02db5b6 
b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-14-8ab32ab7-15cd-4302-bb45-6b3dc02db5b6 new file mode 100644 index 00000000000..19c21783ca7 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-14-8ab32ab7-15cd-4302-bb45-6b3dc02db5b6 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-15-c4301966-4fd8-4ea0-b439-b49a693bf683 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-15-c4301966-4fd8-4ea0-b439-b49a693bf683 new file mode 100644 index 00000000000..b72a0a3b17e Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-15-c4301966-4fd8-4ea0-b439-b49a693bf683 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-16-8d638c2e-b1a5-4507-ba00-337a02e3f431 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-16-8d638c2e-b1a5-4507-ba00-337a02e3f431 new file mode 100644 index 00000000000..807a56fbfad Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-16-8d638c2e-b1a5-4507-ba00-337a02e3f431 differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-17-0c739863-b5fe-4e33-8f47-3e2751b599df b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-17-0c739863-b5fe-4e33-8f47-3e2751b599df new file mode 100644 index 00000000000..90e1b9a99d1 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-17-0c739863-b5fe-4e33-8f47-3e2751b599df differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-18-35d65ae7-5d1d-43f8-bb21-e6565874975e b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-18-35d65ae7-5d1d-43f8-bb21-e6565874975e new file mode 100644 index 00000000000..77cc5d4ba42 Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-18-35d65ae7-5d1d-43f8-bb21-e6565874975e differ diff --git a/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-19-2fd81de2-5d34-43db-809d-2f1fe1e67200 b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-19-2fd81de2-5d34-43db-809d-2f1fe1e67200 new file mode 100644 index 00000000000..9b37a0670af Binary files /dev/null and b/hail/src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows/rows/parts/part-19-2fd81de2-5d34-43db-809d-2f1fe1e67200 differ diff --git a/hail/src/test/resources/vep_grch38_input_req_indexed_cache.tsv b/hail/src/test/resources/vep_grch38_input_req_indexed_cache.tsv new file mode 100644 index 00000000000..11c11c7a1b3 --- /dev/null +++ b/hail/src/test/resources/vep_grch38_input_req_indexed_cache.tsv @@ -0,0 +1,21 @@ +locus alleles +chr1:1339585 ["G","A"] +chr1:24907372 ["C","T"] +chr1:36859143 ["G","T"] +chr1:37969436 ["T","C"] +chr1:40416828 ["G","A"] +chr1:41581842 ["G","A"] +chr1:43920822 ["T","C"] +chr1:45327881 ["G","A"] +chr1:46817055 ["CT","C"] +chr1:54999203 ["C","T"] +chr1:65218884 ["C","T"] +chr1:102962250 ["G","T"] +chr1:111756087 ["G","C"] +chr1:201363319 ["G","A"] +chr1:223749094 ["A","G"] +chr1:224294328 ["G","A"] +chr1:235809337 ["G","A"] +chr1:241592073 ["G","T"] +chr2:9376947 ["G","A"] +chr2:11618532 
["C","T"] diff --git a/hail/src/test/scala/is/hail/HailContextSuite.scala b/hail/src/test/scala/is/hail/HailContextSuite.scala index e1314c7394e..3701bb52d81 100644 --- a/hail/src/test/scala/is/hail/HailContextSuite.scala +++ b/hail/src/test/scala/is/hail/HailContextSuite.scala @@ -1,6 +1,7 @@ package is.hail import is.hail.backend.spark.SparkBackend + import org.testng.annotations.Test class HailContextSuite extends HailSuite { diff --git a/hail/src/test/scala/is/hail/HailSuite.scala b/hail/src/test/scala/is/hail/HailSuite.scala index fec61bbe882..5e7b85fbea3 100644 --- a/hail/src/test/scala/is/hail/HailSuite.scala +++ b/hail/src/test/scala/is/hail/HailSuite.scala @@ -1,23 +1,24 @@ package is.hail -import breeze.linalg.DenseMatrix import is.hail.ExecStrategy.ExecStrategy import is.hail.TestUtils._ import is.hail.annotations._ -import is.hail.backend.spark.SparkBackend import is.hail.backend.{BroadcastValue, ExecuteContext} +import is.hail.backend.spark.SparkBackend import is.hail.expr.ir._ import is.hail.io.fs.FS import is.hail.types.virtual._ import is.hail.utils._ + +import java.io.{File, PrintWriter} + +import breeze.linalg.DenseMatrix import org.apache.spark.SparkContext import org.apache.spark.sql.Row import org.scalatest.testng.TestNGSuite import org.testng.ITestContext import org.testng.annotations.{AfterMethod, BeforeClass, BeforeMethod} -import java.io.{File, PrintWriter} - object HailSuite { val theHailClassLoader = TestUtils.theHailClassLoader @@ -29,11 +30,14 @@ object HailSuite { appName = "Hail.TestNG", master = System.getProperty("hail.master"), local = "local[2]", - blockSize = 0) - .set("spark.unsafe.exceptionOnMemoryLeak", "true")), + blockSize = 0, + ) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + ), tmpdir = "/tmp", localTmpdir = "file:///tmp", - skipLoggingConfiguration = true) + skipLoggingConfiguration = true, + ) HailContext(backend) } @@ -50,7 +54,7 @@ class HailSuite extends TestNGSuite { def hc: HailContext = HailSuite.hc - @BeforeClass def ensureHailContextInitialized() { hc } + @BeforeClass def ensureHailContextInitialized(): Unit = hc def backend: SparkBackend = hc.sparkBackend("HailSuite.backend") @@ -72,7 +76,7 @@ class HailSuite extends TestNGSuite { timer = new ExecutionTimer("HailSuite") assert(ctx == null) pool = RegionPool() - ctx = backend.createExecuteContextForTests(timer, Region(pool=pool)) + ctx = backend.createExecuteContextForTests(timer, Region(pool = pool)) } @AfterMethod @@ -87,21 +91,19 @@ class HailSuite extends TestNGSuite { throw new RuntimeException(s"method stopped spark context!") } - def withExecuteContext[T]()(f: ExecuteContext => T): T = { + def withExecuteContext[T]()(f: ExecuteContext => T): T = ExecutionTimer.logTime("HailSuite.withExecuteContext") { timer => hc.sparkBackend("HailSuite.withExecuteContext").withExecuteContext(timer)(f) } - } def assertEvalsTo( x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], agg: Option[(IndexedSeq[Row], TStruct)], - expected: Any - )( - implicit execStrats: Set[ExecStrategy] - ) { + expected: Any, + )(implicit execStrats: Set[ExecStrategy] + ): Unit = { TypeCheck(ctx, x, BindingEnv(env.mapValues(_._2), agg = agg.map(_._2.toEnv))) @@ -128,29 +130,47 @@ class HailSuite extends TestNGSuite { Interpret[Any](ctx, x, env, args, optimize = false) case ExecStrategy.JvmCompile => assert(Forall(x, node => Compilable(node))) - eval(x, env, args, agg, bytecodePrinter = - Option(ctx.getFlag("jvm_bytecode_dump")) - .map { path => - val pw = new PrintWriter(new File(path)) - pw.print(s"/* 
JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") - pw - }, true, ctx) + eval( + x, + env, + args, + agg, + bytecodePrinter = + Option(ctx.getFlag("jvm_bytecode_dump")) + .map { path => + val pw = new PrintWriter(new File(path)) + pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") + pw + }, + true, + ctx, + ) case ExecStrategy.JvmCompileUnoptimized => assert(Forall(x, node => Compilable(node))) - eval(x, env, args, agg, bytecodePrinter = - Option(ctx.getFlag("jvm_bytecode_dump")) - .map { path => - val pw = new PrintWriter(new File(path)) - pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") - pw - }, - optimize = false, ctx) + eval( + x, + env, + args, + agg, + bytecodePrinter = + Option(ctx.getFlag("jvm_bytecode_dump")) + .map { path => + val pw = new PrintWriter(new File(path)) + pw.print(s"/* JVM bytecode dump for IR:\n${Pretty(ctx, x)}\n */\n\n") + pw + }, + optimize = false, + ctx, + ) case ExecStrategy.LoweredJVMCompile => loweredExecute(ctx, x, env, args, agg) } if (t != TVoid) { assert(t.typeCheck(res), s"\n t=$t\n result=$res\n strategy=$strat") - assert(t.valuesSimilar(res, expected), s"\n result=$res\n expect=$expected\n strategy=$strat)") + assert( + t.valuesSimilar(res, expected), + s"\n result=$res\n expect=$expected\n strategy=$strat)", + ) } } catch { case e: Exception => @@ -161,37 +181,42 @@ class HailSuite extends TestNGSuite { } } - def assertNDEvals(nd: IR, expected: Any) - (implicit execStrats: Set[ExecStrategy]) { + def assertNDEvals(nd: IR, expected: Any)(implicit execStrats: Set[ExecStrategy]): Unit = assertNDEvals(nd, Env.empty, FastSeq(), None, expected) - } - def assertNDEvals(nd: IR, expected: (Any, IndexedSeq[Long])) - (implicit execStrats: Set[ExecStrategy]) { + def assertNDEvals( + nd: IR, + expected: (Any, IndexedSeq[Long]), + )(implicit execStrats: Set[ExecStrategy] + ): Unit = if (expected == null) assertNDEvals(nd, Env.empty, FastSeq(), None, null, null) else assertNDEvals(nd, Env.empty, FastSeq(), None, expected._2, expected._1) - } - def assertNDEvals(nd: IR, args: IndexedSeq[(Any, Type)], expected: Any) - (implicit execStrats: Set[ExecStrategy]) { + def assertNDEvals( + nd: IR, + args: IndexedSeq[(Any, Type)], + expected: Any, + )(implicit execStrats: Set[ExecStrategy] + ): Unit = assertNDEvals(nd, Env.empty, args, None, expected) - } - def assertNDEvals(nd: IR, agg: (IndexedSeq[Row], TStruct), expected: Any) - (implicit execStrats: Set[ExecStrategy]) { + def assertNDEvals( + nd: IR, + agg: (IndexedSeq[Row], TStruct), + expected: Any, + )(implicit execStrats: Set[ExecStrategy] + ): Unit = assertNDEvals(nd, Env.empty, FastSeq(), Some(agg), expected) - } def assertNDEvals( nd: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], agg: Option[(IndexedSeq[Row], TStruct)], - expected: Any - )( - implicit execStrats: Set[ExecStrategy] + expected: Any, + )(implicit execStrats: Set[ExecStrategy] ): Unit = { var e: IndexedSeq[Any] = expected.asInstanceOf[IndexedSeq[Any]] val dims = Array.fill(nd.typ.asInstanceOf[TNDArray].nDims) { @@ -209,26 +234,27 @@ class HailSuite extends TestNGSuite { args: IndexedSeq[(Any, Type)], agg: Option[(IndexedSeq[Row], TStruct)], dims: IndexedSeq[Long], - expected: Any - )( - implicit execStrats: Set[ExecStrategy] + expected: Any, + )(implicit execStrats: Set[ExecStrategy] ): Unit = { - val arrayIR = if (expected == null) nd else { - val refs = Array.fill(nd.typ.asInstanceOf[TNDArray].nDims) { Ref(genUID(), TInt32) } - Let(FastSeq("nd" -> nd), + val arrayIR = if (expected == null) 
nd + else { + val refs = Array.fill(nd.typ.asInstanceOf[TNDArray].nDims)(Ref(genUID(), TInt32)) + Let( + FastSeq("nd" -> nd), dims.zip(refs).foldRight[IR](NDArrayRef(Ref("nd", nd.typ), refs.map(Cast(_, TInt64)), -1)) { case ((n, ref), accum) => ToArray(StreamMap(rangeIR(n.toInt), ref.name, accum)) - }) + }, + ) } assertEvalsTo(arrayIR, env, args, agg, expected) } def assertBMEvalsTo( bm: BlockMatrixIR, - expected: DenseMatrix[Double] - )( - implicit execStrats: Set[ExecStrategy] + expected: DenseMatrix[Double], + )(implicit execStrats: Set[ExecStrategy] ): Unit = { ExecuteContext.scoped() { ctx => val filteredExecStrats: Set[ExecStrategy] = @@ -252,45 +278,41 @@ class HailSuite extends TestNGSuite { if (execStrats.contains(strat)) throw e } } - val expectedArray = Array.tabulate(expected.rows)(i => Array.tabulate(expected.cols)(j => expected(i, j)).toFastSeq).toFastSeq - assertNDEvals(BlockMatrixCollect(bm), expectedArray)(filteredExecStrats.filterNot(ExecStrategy.interpretOnly)) + val expectedArray = Array.tabulate(expected.rows)(i => + Array.tabulate(expected.cols)(j => expected(i, j)).toFastSeq + ).toFastSeq + assertNDEvals(BlockMatrixCollect(bm), expectedArray)( + filteredExecStrats.filterNot(ExecStrategy.interpretOnly) + ) } } def assertAllEvalTo( xs: (IR, Any)* - )( - implicit execStrats: Set[ExecStrategy] - ): Unit = { + )(implicit execStrats: Set[ExecStrategy] + ): Unit = assertEvalsTo(MakeTuple.ordered(xs.toArray.map(_._1)), Row.fromSeq(xs.map(_._2))) - } def assertEvalsTo( x: IR, - expected: Any - )( - implicit execStrats: Set[ExecStrategy] - ) { + expected: Any, + )(implicit execStrats: Set[ExecStrategy] + ): Unit = assertEvalsTo(x, Env.empty, FastSeq(), None, expected) - } def assertEvalsTo( x: IR, args: IndexedSeq[(Any, Type)], - expected: Any - )( - implicit execStrats: Set[ExecStrategy] - ) { + expected: Any, + )(implicit execStrats: Set[ExecStrategy] + ): Unit = assertEvalsTo(x, Env.empty, args, None, expected) - } def assertEvalsTo( x: IR, agg: (IndexedSeq[Row], TStruct), - expected: Any - )( - implicit execStrats: Set[ExecStrategy] - ) { + expected: Any, + )(implicit execStrats: Set[ExecStrategy] + ): Unit = assertEvalsTo(x, Env.empty, FastSeq(), Some(agg), expected) - } } diff --git a/hail/src/test/scala/is/hail/LogTestListener.scala b/hail/src/test/scala/is/hail/LogTestListener.scala index 82ead74f399..aa5347480ba 100644 --- a/hail/src/test/scala/is/hail/LogTestListener.scala +++ b/hail/src/test/scala/is/hail/LogTestListener.scala @@ -1,24 +1,20 @@ package is.hail import java.io.{PrintWriter, StringWriter} -import is.hail.utils._ -import org.apache.log4j.{ConsoleAppender, PatternLayout} + import org.testng.{ITestContext, ITestListener, ITestResult} class LogTestListener extends ITestListener { - def testString(result: ITestResult): String = { - s"${ result.getTestClass.getName }.${ result.getMethod.getMethodName }" - } + def testString(result: ITestResult): String = + s"${result.getTestClass.getName}.${result.getMethod.getMethodName}" - def onTestStart(result: ITestResult) { - System.err.println(s"starting test ${ testString(result) }...") - } + override def onTestStart(result: ITestResult): Unit = + System.err.println(s"starting test ${testString(result)}...") - def onTestSuccess(result: ITestResult) { - System.err.println(s"test ${ testString(result) } SUCCESS") - } + override def onTestSuccess(result: ITestResult): Unit = + System.err.println(s"test ${testString(result)} SUCCESS") - def onTestFailure(result: ITestResult) { + override def onTestFailure(result: 
ITestResult): Unit = { val cause = result.getThrowable if (cause != null) { val sw = new StringWriter() @@ -26,22 +22,15 @@ class LogTestListener extends ITestListener { cause.printStackTrace(pw) System.err.println(s"Exception:\n$sw") } - System.err.println(s"test ${ testString(result) } FAILURE\n") + System.err.println(s"test ${testString(result)} FAILURE\n") } - def onTestSkipped(result: ITestResult) { - System.err.println(s"test ${ testString(result) } SKIPPED") - } + override def onTestSkipped(result: ITestResult): Unit = + System.err.println(s"test ${testString(result)} SKIPPED") - def onTestFailedButWithinSuccessPercentage(result: ITestResult) { + override def onTestFailedButWithinSuccessPercentage(result: ITestResult): Unit = {} - } + override def onStart(context: ITestContext): Unit = {} - def onStart(context: ITestContext) { - - } - - def onFinish(context: ITestContext) { - - } + override def onFinish(context: ITestContext): Unit = {} } diff --git a/hail/src/test/scala/is/hail/TestUtils.scala b/hail/src/test/scala/is/hail/TestUtils.scala index 1d7da451e1f..246bd9b15b8 100644 --- a/hail/src/test/scala/is/hail/TestUtils.scala +++ b/hail/src/test/scala/is/hail/TestUtils.scala @@ -1,32 +1,40 @@ package is.hail -import breeze.linalg.{DenseMatrix, Matrix, Vector} import is.hail.annotations.{Region, RegionValueBuilder, SafeRow} import is.hail.asm4s._ import is.hail.backend.ExecuteContext -import is.hail.expr.ir.lowering.LowererUnsupportedOperation import is.hail.expr.ir._ +import is.hail.expr.ir.lowering.LowererUnsupportedOperation import is.hail.io.vcf.MatrixVCFReader -import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.physical.{PBaseStruct, PCanonicalArray, PType} +import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual._ import is.hail.utils._ import is.hail.variant._ -import org.apache.spark.SparkException -import org.apache.spark.sql.Row -import java.io.PrintWriter import scala.collection.mutable +import java.io.PrintWriter + +import breeze.linalg.{DenseMatrix, Matrix, Vector} +import org.apache.spark.SparkException +import org.apache.spark.sql.Row + object ExecStrategy extends Enumeration { type ExecStrategy = Value val Interpret, InterpretUnoptimized, JvmCompile, LoweredJVMCompile, JvmCompileUnoptimized = Value val unoptimizedCompileOnly: Set[ExecStrategy] = Set(JvmCompileUnoptimized) val compileOnly: Set[ExecStrategy] = Set(JvmCompile, JvmCompileUnoptimized) - val javaOnly: Set[ExecStrategy] = Set(Interpret, InterpretUnoptimized, JvmCompile, JvmCompileUnoptimized) + + val javaOnly: Set[ExecStrategy] = + Set(Interpret, InterpretUnoptimized, JvmCompile, JvmCompileUnoptimized) + val interpretOnly: Set[ExecStrategy] = Set(Interpret, InterpretUnoptimized) - val nonLowering: Set[ExecStrategy] = Set(Interpret, InterpretUnoptimized, JvmCompile, JvmCompileUnoptimized) + + val nonLowering: Set[ExecStrategy] = + Set(Interpret, InterpretUnoptimized, JvmCompile, JvmCompileUnoptimized) + val lowering: Set[ExecStrategy] = Set(LoweredJVMCompile) val backendOnly: Set[ExecStrategy] = Set(LoweredJVMCompile) val allRelational: Set[ExecStrategy] = interpretOnly.union(lowering) @@ -37,37 +45,45 @@ object TestUtils { import org.scalatest.Assertions._ - def interceptException[E <: Throwable : Manifest](regex: String)(f: => Any) { + def interceptException[E <: Throwable: Manifest](regex: String)(f: => Any): Unit = { val thrown = intercept[E](f) val p = regex.r.findFirstIn(thrown.getMessage).isDefined val msg = s"""expected fatal 
exception with pattern '$regex' - | Found: ${ thrown.getMessage } """ + | Found: ${thrown.getMessage} """ if (!p) println(msg) assert(p, msg) } - def interceptFatal(regex: String)(f: => Any) { + + def interceptFatal(regex: String)(f: => Any): Unit = interceptException[HailException](regex)(f) - } - def interceptSpark(regex: String)(f: => Any) { + def interceptSpark(regex: String)(f: => Any): Unit = interceptException[SparkException](regex)(f) - } - def interceptAssertion(regex: String)(f: => Any) { + def interceptAssertion(regex: String)(f: => Any): Unit = interceptException[AssertionError](regex)(f) - } - def assertVectorEqualityDouble(A: Vector[Double], B: Vector[Double], tolerance: Double = utils.defaultTolerance) { + def assertVectorEqualityDouble( + A: Vector[Double], + B: Vector[Double], + tolerance: Double = utils.defaultTolerance, + ): Unit = { assert(A.size == B.size) assert((0 until A.size).forall(i => D_==(A(i), B(i), tolerance))) } - def assertMatrixEqualityDouble(A: Matrix[Double], B: Matrix[Double], tolerance: Double = utils.defaultTolerance) { + def assertMatrixEqualityDouble( + A: Matrix[Double], + B: Matrix[Double], + tolerance: Double = utils.defaultTolerance, + ): Unit = { assert(A.rows == B.rows) assert(A.cols == B.cols) - assert((0 until A.rows).forall(i => (0 until A.cols).forall(j => D_==(A(i, j), B(i, j), tolerance)))) + assert((0 until A.rows).forall(i => + (0 until A.cols).forall(j => D_==(A(i, j), B(i, j), tolerance)) + )) } def isConstant(A: Vector[Int]): Boolean = { @@ -88,151 +104,173 @@ object TestUtils { new DenseMatrix(A.rows, newCols, data) } - def unphasedDiploidGtIndicesToBoxedCall(m: DenseMatrix[Int]): DenseMatrix[BoxedCall] = { + def unphasedDiploidGtIndicesToBoxedCall(m: DenseMatrix[Int]): DenseMatrix[BoxedCall] = m.map(g => if (g == -1) null: BoxedCall else Call2.fromUnphasedDiploidGtIndex(g): BoxedCall) - } - - def loweredExecute(ctx: ExecuteContext, x: IR, env: Env[(Any, Type)], + def loweredExecute( + ctx: ExecuteContext, + x: IR, + env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], agg: Option[(IndexedSeq[Row], TStruct)], - bytecodePrinter: Option[PrintWriter] = None + bytecodePrinter: Option[PrintWriter] = None, ): Any = { if (agg.isDefined || !env.isEmpty || !args.isEmpty) throw new LowererUnsupportedOperation("can't test with aggs or user defined args/env") ExecutionTimer.logTime("TestUtils.loweredExecute") { timer => HailContext.sparkBackend("TestUtils.loweredExecute") - .jvmLowerAndExecute(ctx, timer, x, optimize = false, lowerTable = true, lowerBM = true, print = bytecodePrinter) + .jvmLowerAndExecute(ctx, timer, x, optimize = false, lowerTable = true, lowerBM = true, + print = bytecodePrinter) } } - def eval(x: IR): Any = ExecuteContext.scoped(){ ctx => + def eval(x: IR): Any = ExecuteContext.scoped() { ctx => eval(x, Env.empty, FastSeq(), None, None, true, ctx) } - def eval(x: IR, + def eval( + x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], agg: Option[(IndexedSeq[Row], TStruct)], bytecodePrinter: Option[PrintWriter] = None, optimize: Boolean = true, - ctx: ExecuteContext + ctx: ExecuteContext, ): Any = { - val inputTypesB = new BoxedArrayBuilder[Type]() - val inputsB = new mutable.ArrayBuffer[Any]() + val inputTypesB = new BoxedArrayBuilder[Type]() + val inputsB = new mutable.ArrayBuffer[Any]() - args.foreach { case (v, t) => - inputsB += v - inputTypesB += t - } + args.foreach { case (v, t) => + inputsB += v + inputTypesB += t + } - env.m.foreach { case (name, (v, t)) => - inputsB += v - inputTypesB += t - } + 
env.m.foreach { case (_, (v, t)) => + inputsB += v + inputTypesB += t + } - val argsType = TTuple(inputTypesB.result(): _*) - val resultType = TTuple(x.typ) - val argsVar = genUID() + val argsType = TTuple(inputTypesB.result(): _*) + val resultType = TTuple(x.typ) + val argsVar = genUID() - val (_, substEnv) = env.m.foldLeft((args.length, Env.empty[IR])) { case ((i, env), (name, (v, t))) => + val (_, substEnv) = + env.m.foldLeft((args.length, Env.empty[IR])) { case ((i, env), (name, (_, _))) => (i + 1, env.bind(name, GetTupleElement(Ref(argsVar, argsType), i))) } - def rewrite(x: IR): IR = { - x match { - case In(i, t) => - GetTupleElement(Ref(argsVar, argsType), i) - case _ => - MapIR(rewrite)(x) - } + def rewrite(x: IR): IR = { + x match { + case In(i, _) => + GetTupleElement(Ref(argsVar, argsType), i) + case _ => + MapIR(rewrite)(x) } + } - val argsPType = PType.canonical(argsType).setRequired(true) - agg match { - case Some((aggElements, aggType)) => - val aggElementVar = genUID() - val aggArrayVar = genUID() - val aggPType = PType.canonical(aggType) - val aggArrayPType = PCanonicalArray(aggPType, required = true) + val argsPType = PType.canonical(argsType).setRequired(true) + agg match { + case Some((aggElements, aggType)) => + val aggElementVar = genUID() + val aggArrayVar = genUID() + val aggPType = PType.canonical(aggType) + val aggArrayPType = PCanonicalArray(aggPType, required = true) - val substAggEnv = aggType.fields.foldLeft(Env.empty[IR]) { case (env, f) => - env.bind(f.name, GetField(Ref(aggElementVar, aggType), f.name)) - } - val aggIR = StreamAgg(ToStream(Ref(aggArrayVar, aggArrayPType.virtualType)), - aggElementVar, - MakeTuple.ordered(FastSeq(rewrite(Subst(x, BindingEnv(eval = substEnv, agg = Some(substAggEnv))))))) - - val (Some(PTypeReferenceSingleCodeType(resultType2)), f) = Compile[AsmFunction3RegionLongLongLong](ctx, - FastSeq((argsVar, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(argsPType))), - (aggArrayVar, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(aggArrayPType)))), - FastSeq(classInfo[Region], LongInfo, LongInfo), LongInfo, + val substAggEnv = aggType.fields.foldLeft(Env.empty[IR]) { case (env, f) => + env.bind(f.name, GetField(Ref(aggElementVar, aggType), f.name)) + } + val aggIR = StreamAgg( + ToStream(Ref(aggArrayVar, aggArrayPType.virtualType)), + aggElementVar, + MakeTuple.ordered(FastSeq(rewrite(Subst( + x, + BindingEnv(eval = substEnv, agg = Some(substAggEnv)), + )))), + ) + + val (Some(PTypeReferenceSingleCodeType(resultType2)), f) = + Compile[AsmFunction3RegionLongLongLong]( + ctx, + FastSeq( + (argsVar, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(argsPType))), + ( + aggArrayVar, + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(aggArrayPType)), + ), + ), + FastSeq(classInfo[Region], LongInfo, LongInfo), + LongInfo, aggIR, print = bytecodePrinter, - optimize = optimize) - assert(resultType2.virtualType == resultType) - - ctx.r.pool.scopedRegion { region => - val rvb = new RegionValueBuilder(ctx.stateManager, region) - rvb.start(argsPType) - rvb.startTuple() - var i = 0 - while (i < inputsB.length) { - rvb.addAnnotation(inputTypesB(i), inputsB(i)) - i += 1 - } - rvb.endTuple() - val argsOff = rvb.end() - - rvb.start(aggArrayPType) - rvb.startArray(aggElements.length) - aggElements.foreach { r => - rvb.addAnnotation(aggType, r) - } - rvb.endArray() - val aggOff = rvb.end() - - val resultOff = f(theHailClassLoader, ctx.fs, ctx.taskContext, region)(region, argsOff, aggOff) - 
SafeRow(resultType2.asInstanceOf[PBaseStruct], resultOff).get(0) + optimize = optimize, + ) + assert(resultType2.virtualType == resultType) + + ctx.r.pool.scopedRegion { region => + val rvb = new RegionValueBuilder(ctx.stateManager, region) + rvb.start(argsPType) + rvb.startTuple() + var i = 0 + while (i < inputsB.length) { + rvb.addAnnotation(inputTypesB(i), inputsB(i)) + i += 1 } + rvb.endTuple() + val argsOff = rvb.end() + + rvb.start(aggArrayPType) + rvb.startArray(aggElements.length) + aggElements.foreach(r => rvb.addAnnotation(aggType, r)) + rvb.endArray() + val aggOff = rvb.end() + + val resultOff = + f(theHailClassLoader, ctx.fs, ctx.taskContext, region)(region, argsOff, aggOff) + SafeRow(resultType2.asInstanceOf[PBaseStruct], resultOff).get(0) + } - case None => - val (Some(PTypeReferenceSingleCodeType(resultType2)), f) = Compile[AsmFunction2RegionLongLong](ctx, - FastSeq((argsVar, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(argsPType)))), - FastSeq(classInfo[Region], LongInfo), LongInfo, + case None => + val (Some(PTypeReferenceSingleCodeType(resultType2)), f) = + Compile[AsmFunction2RegionLongLong]( + ctx, + FastSeq(( + argsVar, + SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(argsPType)), + )), + FastSeq(classInfo[Region], LongInfo), + LongInfo, MakeTuple.ordered(FastSeq(rewrite(Subst(x, BindingEnv(substEnv))))), optimize = optimize, - print = bytecodePrinter) - assert(resultType2.virtualType == resultType) - - ctx.r.pool.scopedRegion { region => - val rvb = new RegionValueBuilder(ctx.stateManager, region) - rvb.start(argsPType) - rvb.startTuple() - var i = 0 - while (i < inputsB.length) { - rvb.addAnnotation(inputTypesB(i), inputsB(i)) - i += 1 - } - rvb.endTuple() - val argsOff = rvb.end() - - val resultOff = f(theHailClassLoader, ctx.fs, ctx.taskContext, region)(region, argsOff) - SafeRow(resultType2.asInstanceOf[PBaseStruct], resultOff).get(0) + print = bytecodePrinter, + ) + assert(resultType2.virtualType == resultType) + + ctx.r.pool.scopedRegion { region => + val rvb = new RegionValueBuilder(ctx.stateManager, region) + rvb.start(argsPType) + rvb.startTuple() + var i = 0 + while (i < inputsB.length) { + rvb.addAnnotation(inputTypesB(i), inputsB(i)) + i += 1 } - } + rvb.endTuple() + val argsOff = rvb.end() + + val resultOff = f(theHailClassLoader, ctx.fs, ctx.taskContext, region)(region, argsOff) + SafeRow(resultType2.asInstanceOf[PBaseStruct], resultOff).get(0) + } + } } - def assertEvalSame(x: IR) { + def assertEvalSame(x: IR): Unit = assertEvalSame(x, Env.empty, FastSeq()) - } - def assertEvalSame(x: IR, args: IndexedSeq[(Any, Type)]) { + def assertEvalSame(x: IR, args: IndexedSeq[(Any, Type)]): Unit = assertEvalSame(x, Env.empty, args) - } - def assertEvalSame(x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)]) { + def assertEvalSame(x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)]): Unit = { val t = x.typ val (i, i2, c) = ExecuteContext.scoped() { ctx => @@ -250,45 +288,51 @@ object TestUtils { assert(t.valuesSimilar(i2, c), s"interpret (optimize = false) $i vs compile $c") } - def assertThrows[E <: Throwable : Manifest](x: IR, regex: String) { + def assertThrows[E <: Throwable: Manifest](x: IR, regex: String): Unit = assertThrows[E](x, Env.empty[(Any, Type)], FastSeq.empty[(Any, Type)], regex) - } - def assertThrows[E <: Throwable : Manifest](x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], regex: String) { + def assertThrows[E <: Throwable: Manifest]( + x: IR, + env: Env[(Any, Type)], + args: 
IndexedSeq[(Any, Type)], + regex: String, + ): Unit = ExecuteContext.scoped() { ctx => interceptException[E](regex)(Interpret[Any](ctx, x, env, args)) interceptException[E](regex)(Interpret[Any](ctx, x, env, args, optimize = false)) interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) } - } - def assertFatal(x: IR, regex: String) { + def assertFatal(x: IR, regex: String): Unit = assertThrows[HailException](x, regex) - } - def assertFatal(x: IR, args: IndexedSeq[(Any, Type)], regex: String) { + def assertFatal(x: IR, args: IndexedSeq[(Any, Type)], regex: String): Unit = assertThrows[HailException](x, Env.empty[(Any, Type)], args, regex) - } - def assertFatal(x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], regex: String) { + def assertFatal(x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], regex: String) + : Unit = assertThrows[HailException](x, env, args, regex) - } - def assertCompiledThrows[E <: Throwable : Manifest](x: IR, env: Env[(Any, Type)], args: IndexedSeq[(Any, Type)], regex: String) { + def assertCompiledThrows[E <: Throwable: Manifest]( + x: IR, + env: Env[(Any, Type)], + args: IndexedSeq[(Any, Type)], + regex: String, + ): Unit = ExecuteContext.scoped() { ctx => interceptException[E](regex)(eval(x, env, args, None, None, true, ctx)) } - } - def assertCompiledThrows[E <: Throwable : Manifest](x: IR, regex: String) { + def assertCompiledThrows[E <: Throwable: Manifest](x: IR, regex: String): Unit = assertCompiledThrows[E](x, Env.empty[(Any, Type)], FastSeq.empty[(Any, Type)], regex) - } - def assertCompiledFatal(x: IR, regex: String) { + def assertCompiledFatal(x: IR, regex: String): Unit = assertCompiledThrows[HailException](x, regex) - } - def importVCF(ctx: ExecuteContext, file: String, force: Boolean = false, + def importVCF( + ctx: ExecuteContext, + file: String, + force: Boolean = false, forceBGZ: Boolean = false, headerFile: Option[String] = None, nPartitions: Option[Int] = None, @@ -301,15 +345,17 @@ object TestUtils { arrayElementsRequired: Boolean = true, skipInvalidLoci: Boolean = false, partitionsJSON: Option[String] = None, - partitionsTypeStr: Option[String] = None): MatrixIR = { + partitionsTypeStr: Option[String] = None, + ): MatrixIR = { val entryFloatType = TFloat64._toPretty - val reader = MatrixVCFReader(ctx, + val reader = MatrixVCFReader( + ctx, Array(file), callFields, entryFloatType, headerFile, - /*sampleIDs=*/None, + /*sampleIDs=*/ None, nPartitions, blockSizeInMB, minPartitions, @@ -321,7 +367,8 @@ object TestUtils { force, TextInputFilterAndReplace(), partitionsJSON, - partitionsTypeStr) + partitionsTypeStr, + ) MatrixRead(reader.fullMatrixTypeWithoutUIDs, dropSamples, false, reader) } } diff --git a/hail/src/test/scala/is/hail/TestUtilsSuite.scala b/hail/src/test/scala/is/hail/TestUtilsSuite.scala index d138b57e845..a6c91c5aef5 100644 --- a/hail/src/test/scala/is/hail/TestUtilsSuite.scala +++ b/hail/src/test/scala/is/hail/TestUtilsSuite.scala @@ -1,16 +1,11 @@ package is.hail -import java.io.File -import java.lang.reflect.Modifier -import java.net.URI - import breeze.linalg.{DenseMatrix, DenseVector} -import is.hail.utils.BoxedArrayBuilder -import org.testng.annotations.{DataProvider, Test} +import org.testng.annotations.Test class TestUtilsSuite extends HailSuite { - @Test def matrixEqualityTest() { + @Test def matrixEqualityTest(): Unit = { val M = DenseMatrix((1d, 0d), (0d, 1d)) val M1 = DenseMatrix((1d, 0d), (0d, 1.0001d)) val V = DenseVector(0d, 1d) @@ -24,7 +19,7 @@ class TestUtilsSuite extends 
HailSuite { intercept[Exception](TestUtils.assertMatrixEqualityDouble(M, M1)) } - @Test def constantVectorTest() { + @Test def constantVectorTest(): Unit = { assert(TestUtils.isConstant(DenseVector())) assert(TestUtils.isConstant(DenseVector(0))) assert(TestUtils.isConstant(DenseVector(0, 0))) @@ -34,11 +29,9 @@ class TestUtilsSuite extends HailSuite { } @Test def removeConstantColsTest(): Unit = { - val M = DenseMatrix((0, 0, 1, 1, 0), - (0, 1, 0, 1, 1)) + val M = DenseMatrix((0, 0, 1, 1, 0), (0, 1, 0, 1, 1)) - val M1 = DenseMatrix((0, 1, 0), - (1, 0, 1)) + val M1 = DenseMatrix((0, 1, 0), (1, 0, 1)) assert(TestUtils.removeConstantCols(M) == M1) } diff --git a/hail/src/test/scala/is/hail/annotations/AnnotationsSuite.scala b/hail/src/test/scala/is/hail/annotations/AnnotationsSuite.scala index 0d2637372ac..1f9a9415f3b 100644 --- a/hail/src/test/scala/is/hail/annotations/AnnotationsSuite.scala +++ b/hail/src/test/scala/is/hail/annotations/AnnotationsSuite.scala @@ -1,19 +1,12 @@ package is.hail.annotations -import is.hail.backend.ExecuteContext -import is.hail.types.virtual._ -import is.hail.testUtils._ -import is.hail.utils._ -import is.hail.{HailSuite, TestUtils} -import org.testng.annotations.Test +import is.hail.HailSuite -import scala.language.implicitConversions +import org.testng.annotations.Test -/** - * This testing suite evaluates the functionality of the [[is.hail.annotations]] package - */ +/** This testing suite evaluates the functionality of the [[is.hail.annotations]] package */ class AnnotationsSuite extends HailSuite { - @Test def testExtendedOrdering() { + @Test def testExtendedOrdering(): Unit = { val ord = ExtendedOrdering.extendToNull(implicitly[Ordering[Int]]) val rord = ord.reverse diff --git a/hail/src/test/scala/is/hail/annotations/ApproxCDFAggregatorSuite.scala b/hail/src/test/scala/is/hail/annotations/ApproxCDFAggregatorSuite.scala index 8b78e81efc9..6511d635f5c 100644 --- a/hail/src/test/scala/is/hail/annotations/ApproxCDFAggregatorSuite.scala +++ b/hail/src/test/scala/is/hail/annotations/ApproxCDFAggregatorSuite.scala @@ -1,36 +1,37 @@ package is.hail.annotations import is.hail.expr.ir.agg._ -import org.testng.annotations.Test + import org.scalatest.testng.TestNGSuite +import org.testng.annotations.Test class ApproxCDFAggregatorSuite extends TestNGSuite { @Test - def testMerge() { - val array: Array[Double] = Array(1,3,5,0,0,0,2,4,6) + def testMerge(): Unit = { + val array: Array[Double] = Array(1, 3, 5, 0, 0, 0, 2, 4, 6) ApproxCDFHelper.merge(array, 0, 3, array, 6, 9, array, 3) assert(array.view(3, 9) sameElements Range(1, 7)) } @Test - def testCompactLevelZero() { + def testCompactLevelZero(): Unit = { val rand = new java.util.Random(1) // first Boolean is `true` - val levels: Array[Int] = Array(0,4,7,10) - val items: Array[Double] = Array(7,2,6,4, 1,3,8, 0,5,9) + val levels: Array[Int] = Array(0, 4, 7, 10) + val items: Array[Double] = Array(7, 2, 6, 4, 1, 3, 8, 0, 5, 9) val compactionCounts: Array[Int] = Array(0, 0, 0) val combiner = new ApproxCDFCombiner(levels, items, compactionCounts, 3, rand) combiner.compactLevel(0) - assert(items.view(1,10) sameElements Array(2,7, 1,3,6,8, 0,5,9)) + assert(items.view(1, 10) sameElements Array(2, 7, 1, 3, 6, 8, 0, 5, 9)) } @Test - def testCompactLevel() { + def testCompactLevel(): Unit = { val rand = new java.util.Random(1) // first Boolean is `true` - val levels: Array[Int] = Array(0,3,6,9) - val items: Array[Double] = Array(7,2,4, 1,3,8, 0,5,9) + val levels: Array[Int] = Array(0, 3, 6, 9) + val items: Array[Double] = 
Array(7, 2, 4, 1, 3, 8, 0, 5, 9) val compactionCounts: Array[Int] = Array(0, 0, 0) val combiner = new ApproxCDFCombiner(levels, items, compactionCounts, 3, rand) combiner.compactLevel(1) - assert(items.view(1,9) sameElements Array(7,2,4, 1, 0,5,8,9)) + assert(items.view(1, 9) sameElements Array(7, 2, 4, 1, 0, 5, 8, 9)) } } diff --git a/hail/src/test/scala/is/hail/annotations/RegionSuite.scala b/hail/src/test/scala/is/hail/annotations/RegionSuite.scala index 7ffee1a3db9..b2976cc25d6 100644 --- a/hail/src/test/scala/is/hail/annotations/RegionSuite.scala +++ b/hail/src/test/scala/is/hail/annotations/RegionSuite.scala @@ -1,28 +1,23 @@ package is.hail.annotations import is.hail.expr.ir.LongArrayBuilder -import is.hail.utils.{info, using} -import org.scalatest.testng.TestNGSuite -import org.testng.annotations.Test +import is.hail.utils.using import scala.collection.mutable.ArrayBuffer +import org.scalatest.testng.TestNGSuite +import org.testng.annotations.Test + class RegionSuite extends TestNGSuite { - @Test def testRegionSizes() { + @Test def testRegionSizes(): Unit = RegionPool.scoped { pool => - pool.scopedSmallRegion { region => - Array.range(0, 30).foreach { _ => region.allocate(1, 500) } - } - + pool.scopedSmallRegion(region => Array.range(0, 30).foreach(_ => region.allocate(1, 500))) - pool.scopedTinyRegion { region => - Array.range(0, 30).foreach { _ => region.allocate(1, 60) } - } + pool.scopedTinyRegion(region => Array.range(0, 30).foreach(_ => region.allocate(1, 60))) } - } - @Test def testRegionAllocationSimple() { + @Test def testRegionAllocationSimple(): Unit = { using(RegionPool(strictMemoryCheck = true)) { pool => assert(pool.numFreeBlocks() == 0) assert(pool.numRegions() == 0) @@ -78,12 +73,14 @@ class RegionSuite extends TestNGSuite { } } - @Test def testRegionAllocation() { + @Test def testRegionAllocation(): Unit = { RegionPool.scoped { pool => case class Counts(regions: Int, freeRegions: Int) { def allocate(n: Int): Counts = - copy(regions = regions + math.max(0, n - freeRegions), - freeRegions = math.max(0, freeRegions - n)) + copy( + regions = regions + math.max(0, n - freeRegions), + freeRegions = math.max(0, freeRegions - n), + ) def free(nRegions: Int, nExtraBlocks: Int = 0): Counts = copy(freeRegions = freeRegions + nRegions) @@ -110,15 +107,15 @@ class RegionSuite extends TestNGSuite { assertAfterEquals(before.free(2)) pool.scopedRegion { region => - pool.scopedRegion { region2 => region.addReferenceTo(region2) } - pool.scopedRegion { region2 => region.addReferenceTo(region2) } + pool.scopedRegion(region2 => region.addReferenceTo(region2)) + pool.scopedRegion(region2 => region.addReferenceTo(region2)) assertAfterEquals(before.allocate(3)) } assertAfterEquals(before.free(3)) } } - @Test def testRegionReferences() { + @Test def testRegionReferences(): Unit = { RegionPool.scoped { pool => def offset(region: Region) = region.allocate(0) @@ -131,27 +128,21 @@ class RegionSuite extends TestNGSuite { res } - val region = Region(pool=pool) + val region = Region(pool = pool) region.setNumParents(5) val off4 = using(assertUsesRegions(1) { region.getParentReference(4, Region.SMALL) - }) { r => - offset(r) - } + })(r => offset(r)) val off2 = pool.scopedTinyRegion { r => region.setParentReference(r, 2) offset(r) } - using(region.getParentReference(2, Region.TINY)) { r => - assert(offset(r) == off2) - } + using(region.getParentReference(2, Region.TINY))(r => assert(offset(r) == off2)) - using(region.getParentReference(4, Region.SMALL)) { r => - assert(offset(r) == off4) - } + 
using(region.getParentReference(4, Region.SMALL))(r => assert(offset(r) == off4)) assertUsesRegions(-1) { region.unreferenceRegionAtIndex(2) @@ -239,7 +230,6 @@ class RegionSuite extends TestNGSuite { @Test def testChunkCache(): Unit = { RegionPool.scoped { pool => - val operations = ArrayBuffer[(String, Long)]() def allocate(numBytes: Long): Long = { @@ -258,11 +248,11 @@ class RegionSuite extends TestNGSuite { chunkCache.freeChunkToCache(ab.pop()) ab += chunkCache.getChunk(pool, 50L)._1 assert(operations(0) == (("allocate", 512))) - //512 size chunk freed from cache to not exceed peak memory + // 512 size chunk freed from cache to not exceed peak memory assert(operations(1) == (("free", 0L))) assert(operations(2) == (("allocate", 64))) chunkCache.freeChunkToCache(ab.pop()) - //No additional allocate should be made as uses cache + // No additional allocate should be made as uses cache ab += chunkCache.getChunk(pool, 50L)._1 assert(operations.length == 3) ab += chunkCache.getChunk(pool, 40L)._1 diff --git a/hail/src/test/scala/is/hail/annotations/ScalaToRegionValue.scala b/hail/src/test/scala/is/hail/annotations/ScalaToRegionValue.scala index 56ffe05cb8f..2fe8e143333 100644 --- a/hail/src/test/scala/is/hail/annotations/ScalaToRegionValue.scala +++ b/hail/src/test/scala/is/hail/annotations/ScalaToRegionValue.scala @@ -1,10 +1,9 @@ package is.hail.annotations -import is.hail.types.physical.PType import is.hail.backend.HailStateManager +import is.hail.types.physical.PType object ScalaToRegionValue { - def apply(sm: HailStateManager, region: Region, t: PType, a: Annotation): Long = { + def apply(sm: HailStateManager, region: Region, t: PType, a: Annotation): Long = t.unstagedStoreJavaObject(sm, a, region) - } } diff --git a/hail/src/test/scala/is/hail/annotations/StagedConstructorSuite.scala b/hail/src/test/scala/is/hail/annotations/StagedConstructorSuite.scala index 78a188e78ac..8acfb0089c4 100644 --- a/hail/src/test/scala/is/hail/annotations/StagedConstructorSuite.scala +++ b/hail/src/test/scala/is/hail/annotations/StagedConstructorSuite.scala @@ -10,6 +10,7 @@ import is.hail.types.physical.stypes.interfaces._ import is.hail.types.physical.stypes.primitives.SInt32Value import is.hail.types.virtual._ import is.hail.utils._ + import org.apache.spark.sql.Row import org.testng.annotations.Test @@ -20,7 +21,7 @@ class StagedConstructorSuite extends HailSuite { def sm = ctx.stateManager @Test - def testCanonicalString() { + def testCanonicalString(): Unit = { val rt = PCanonicalString() val input = "hello" val fb = EmitFunctionBuilder[Region, String, Long](ctx, "fb") @@ -28,10 +29,15 @@ class StagedConstructorSuite extends HailSuite { fb.emitWithBuilder { cb => val st = SStringPointer(rt) val region = fb.emb.getCodeParam[Region](1) - rt.store(cb, region, st.constructFromString(cb, region, fb.getCodeParam[String](2)), deepCopy = false) + rt.store( + cb, + region, + st.constructFromString(cb, region, fb.getCodeParam[String](2)), + deepCopy = false, + ) } - val region = Region(pool=pool) + val region = Region(pool = pool) val rv = RegionValue(region) rv.setOffset(fb.result()(theHailClassLoader)(region, input)) @@ -40,7 +46,7 @@ class StagedConstructorSuite extends HailSuite { println(rv.pretty(rt)) } - val region2 = Region(pool=pool) + val region2 = Region(pool = pool) val rv2 = RegionValue(region2) val bytes = input.getBytes() val bt = PCanonicalBinary() @@ -59,16 +65,21 @@ class StagedConstructorSuite extends HailSuite { } @Test - def testInt() { + def testInt(): Unit = { val rt = PInt32() val 
input = 3 val fb = EmitFunctionBuilder[Region, Int, Long](ctx, "fb") fb.emitWithBuilder { cb => - rt.store(cb, fb.emb.getCodeParam[Region](1), primitive(fb.getCodeParam[Int](2)), deepCopy = false) + rt.store( + cb, + fb.emb.getCodeParam[Region](1), + primitive(fb.getCodeParam[Int](2)), + deepCopy = false, + ) } - val region = Region(pool=pool) + val region = Region(pool = pool) val rv = RegionValue(region) rv.setOffset(fb.result()(theHailClassLoader)(region, input)) @@ -77,7 +88,7 @@ class StagedConstructorSuite extends HailSuite { println(rv.pretty(rt)) } - val region2 = Region(pool=pool) + val region2 = Region(pool = pool) val rv2 = RegionValue(region2) rv2.setOffset(region2.allocate(4, 4)) Region.storeInt(rv2.offset, input) @@ -92,7 +103,7 @@ class StagedConstructorSuite extends HailSuite { } @Test - def testArray() { + def testArray(): Unit = { val rt = PCanonicalArray(PInt32()) val input = 3 val fb = EmitFunctionBuilder[Region, Int, Long](ctx, "fb") @@ -105,7 +116,7 @@ class StagedConstructorSuite extends HailSuite { }.a } - val region = Region(pool=pool) + val region = Region(pool = pool) val rv = RegionValue(region) rv.setOffset(fb.result()(theHailClassLoader)(region, input)) @@ -114,7 +125,7 @@ class StagedConstructorSuite extends HailSuite { println(rv.pretty(rt)) } - val region2 = Region(pool=pool) + val region2 = Region(pool = pool) val rv2 = RegionValue(region2) rv2.setOffset(ScalaToRegionValue(sm, region2, rt, FastSeq(input))) @@ -130,7 +141,7 @@ class StagedConstructorSuite extends HailSuite { } @Test - def testStruct() { + def testStruct(): Unit = { val pstring = PCanonicalString() val rt = PCanonicalStruct("a" -> pstring, "b" -> PInt32()) val input = 3 @@ -138,18 +149,21 @@ class StagedConstructorSuite extends HailSuite { fb.emitWithBuilder { cb => val region = fb.emb.getCodeParam[Region](1) - rt.constructFromFields(cb, region, FastSeq( - EmitCode.fromI(cb.emb) { cb => - val st = SStringPointer(pstring) - IEmitCode.present(cb, st.constructFromString(cb, region, const("hello"))) - }, - EmitCode.fromI(cb.emb) { cb => - IEmitCode.present(cb, primitive(fb.getCodeParam[Int](2))) - } - ), deepCopy = false).a + rt.constructFromFields( + cb, + region, + FastSeq( + EmitCode.fromI(cb.emb) { cb => + val st = SStringPointer(pstring) + IEmitCode.present(cb, st.constructFromString(cb, region, const("hello"))) + }, + EmitCode.fromI(cb.emb)(cb => IEmitCode.present(cb, primitive(fb.getCodeParam[Int](2)))), + ), + deepCopy = false, + ).a } - val region = Region(pool=pool) + val region = Region(pool = pool) val rv = RegionValue(region) rv.setOffset(fb.result()(theHailClassLoader)(region, input)) @@ -158,7 +172,7 @@ class StagedConstructorSuite extends HailSuite { println(rv.pretty(rt)) } - val region2 = Region(pool=pool) + val region2 = Region(pool = pool) val rv2 = RegionValue(region2) rv2.setOffset(ScalaToRegionValue(sm, region2, rt, Annotation("hello", input))) @@ -175,7 +189,7 @@ class StagedConstructorSuite extends HailSuite { } @Test - def testArrayOfStruct() { + def testArrayOfStruct(): Unit = { val structType = PCanonicalStruct("a" -> PInt32(), "b" -> PCanonicalString()) val arrayType = PCanonicalArray(structType) val input = "hello" @@ -186,14 +200,27 @@ class StagedConstructorSuite extends HailSuite { arrayType.constructFromElements(cb, region, const(2), false) { (cb, idx) => val st = SStringPointer(PCanonicalString()) - IEmitCode.present(cb, structType.constructFromFields(cb, region, FastSeq( - EmitCode.fromI(cb.emb)(cb => IEmitCode.present(cb, primitive(cb.memoize(idx + 1)))), - 
EmitCode.fromI(cb.emb)(cb => IEmitCode.present(cb, st.constructFromString(cb, region, fb.getCodeParam[String](2)))) - ), deepCopy = false)) + IEmitCode.present( + cb, + structType.constructFromFields( + cb, + region, + FastSeq( + EmitCode.fromI(cb.emb)(cb => IEmitCode.present(cb, primitive(cb.memoize(idx + 1)))), + EmitCode.fromI(cb.emb)(cb => + IEmitCode.present( + cb, + st.constructFromString(cb, region, fb.getCodeParam[String](2)), + ) + ), + ), + deepCopy = false, + ), + ) }.a } - val region = Region(pool=pool) + val region = Region(pool = pool) val rv = RegionValue(region) rv.setOffset(fb.result()(theHailClassLoader)(region, input)) @@ -202,7 +229,7 @@ class StagedConstructorSuite extends HailSuite { println(rv.pretty(arrayType)) } - val region2 = Region(pool=pool) + val region2 = Region(pool = pool) val rv2 = RegionValue(region2) val rvb = new RegionValueBuilder(sm, region2) rvb.start(arrayType) @@ -223,16 +250,17 @@ class StagedConstructorSuite extends HailSuite { assert(rv.pretty(arrayType) == rv2.pretty(arrayType)) assert(new UnsafeIndexedSeq(arrayType, rv.region, rv.offset).sameElements( - new UnsafeIndexedSeq(arrayType, rv2.region, rv2.offset))) + new UnsafeIndexedSeq(arrayType, rv2.region, rv2.offset) + )) } @Test - def testMissingRandomAccessArray() { + def testMissingRandomAccessArray(): Unit = { val rt = PCanonicalArray(PCanonicalStruct("a" -> PInt32(), "b" -> PCanonicalString())) val intVal = 20 val strVal = "a string with a partner of 20" - val region = Region(pool=pool) - val region2 = Region(pool=pool) + val region = Region(pool = pool) + val region2 = Region(pool = pool) val rvb = new RegionValueBuilder(sm, region) val rvb2 = new RegionValueBuilder(sm, region2) val rv = RegionValue(region) @@ -264,16 +292,17 @@ class StagedConstructorSuite extends HailSuite { rv2.setOffset(rvb2.end()) assert(rv.pretty(rt) == rv2.pretty(rt)) assert(new UnsafeIndexedSeq(rt, rv.region, rv.offset).sameElements( - new UnsafeIndexedSeq(rt, rv2.region, rv2.offset))) + new UnsafeIndexedSeq(rt, rv2.region, rv2.offset) + )) } @Test - def testSetFieldPresent() { + def testSetFieldPresent(): Unit = { val rt = PCanonicalStruct("a" -> PInt32(), "b" -> PCanonicalString(), "c" -> PFloat64()) val intVal = 30 val floatVal = 39.273d - val r = Region(pool=pool) - val r2 = Region(pool=pool) + val r = Region(pool = pool) + val r2 = Region(pool = pool) val rv = RegionValue(r) val rv2 = RegionValue(r2) val rvb = new RegionValueBuilder(sm, r) @@ -306,7 +335,7 @@ class StagedConstructorSuite extends HailSuite { } @Test - def testStructWithArray() { + def testStructWithArray(): Unit = { val tArray = PCanonicalArray(PInt32()) val rt = PCanonicalStruct("a" -> PCanonicalString(), "b" -> tArray) val input = "hello" @@ -314,19 +343,34 @@ class StagedConstructorSuite extends HailSuite { fb.emitWithBuilder { cb => val region = fb.emb.getCodeParam[Region](1) - rt.constructFromFields(cb, region, FastSeq( - EmitCode.fromI(cb.emb)(cb => - IEmitCode.present(cb, - SStringPointer(PCanonicalString()).constructFromString(cb, region, fb.getCodeParam[String](2)))), - EmitCode.fromI(cb.emb)(cb => - IEmitCode.present(cb, - tArray.constructFromElements(cb, region, const(2), deepCopy = false) { (cb, idx) => - IEmitCode.present(cb, primitive(cb.memoize(idx + 1))) - })) - ), deepCopy = false).a + rt.constructFromFields( + cb, + region, + FastSeq( + EmitCode.fromI(cb.emb)(cb => + IEmitCode.present( + cb, + SStringPointer(PCanonicalString()).constructFromString( + cb, + region, + fb.getCodeParam[String](2), + ), + ) + ), + 
EmitCode.fromI(cb.emb)(cb => + IEmitCode.present( + cb, + tArray.constructFromElements(cb, region, const(2), deepCopy = false) { (cb, idx) => + IEmitCode.present(cb, primitive(cb.memoize(idx + 1))) + }, + ) + ), + ), + deepCopy = false, + ).a } - val region = Region(pool=pool) + val region = Region(pool = pool) val rv = RegionValue(region) rv.setOffset(fb.result()(theHailClassLoader)(region, input)) @@ -335,7 +379,7 @@ class StagedConstructorSuite extends HailSuite { println(rv.pretty(rt)) } - val region2 = Region(pool=pool) + val region2 = Region(pool = pool) val rv2 = RegionValue(region2) val rvb = new RegionValueBuilder(sm, region2) @@ -343,9 +387,8 @@ class StagedConstructorSuite extends HailSuite { rvb.startStruct() rvb.addString(input) rvb.startArray(2) - for (i <- 1 to 2) { + for (i <- 1 to 2) rvb.addInt(i) - } rvb.endArray() rvb.endStruct() @@ -362,7 +405,7 @@ class StagedConstructorSuite extends HailSuite { } @Test - def testMissingArray() { + def testMissingArray(): Unit = { val rt = PCanonicalArray(PInt32()) val input = 3 val fb = EmitFunctionBuilder[Region, Int, Long](ctx, "fb") @@ -374,7 +417,7 @@ class StagedConstructorSuite extends HailSuite { }.a } - val region = Region(pool=pool) + val region = Region(pool = pool) val rv = RegionValue(region) rv.setOffset(fb.result()(theHailClassLoader)(region, input)) @@ -383,7 +426,7 @@ class StagedConstructorSuite extends HailSuite { println(rv.pretty(rt)) } - val region2 = Region(pool=pool) + val region2 = Region(pool = pool) val rv2 = RegionValue(region2) rv2.setOffset(ScalaToRegionValue(sm, region2, rt, FastSeq(input, null))) @@ -394,45 +437,54 @@ class StagedConstructorSuite extends HailSuite { assert(rv.pretty(rt) == rv2.pretty(rt)) assert(new UnsafeIndexedSeq(rt, rv.region, rv.offset).sameElements( - new UnsafeIndexedSeq(rt, rv2.region, rv2.offset))) + new UnsafeIndexedSeq(rt, rv2.region, rv2.offset) + )) } - def printRegion(region: Region, string: String) { + def printRegion(region: Region, string: String): Unit = println(region.prettyBits()) - } @Test - def testAddPrimitive() { + def testAddPrimitive(): Unit = { val t = PCanonicalStruct("a" -> PInt32(), "b" -> PBoolean(), "c" -> PFloat64()) val fb = EmitFunctionBuilder[Region, Int, Boolean, Double, Long](ctx, "fb") fb.emitWithBuilder { cb => val region = fb.emb.getCodeParam[Region](1) - t.constructFromFields(cb, region, FastSeq( - EmitCode.fromI(cb.emb)(cb => IEmitCode.present(cb, primitive(fb.getCodeParam[Int](2)))), - EmitCode.fromI(cb.emb)(cb => IEmitCode.present(cb, primitive(fb.getCodeParam[Boolean](3)))), - EmitCode.fromI(cb.emb)(cb => IEmitCode.present(cb, primitive(fb.getCodeParam[Double](4)))) - ), deepCopy = false).a + t.constructFromFields( + cb, + region, + FastSeq( + EmitCode.fromI(cb.emb)(cb => IEmitCode.present(cb, primitive(fb.getCodeParam[Int](2)))), + EmitCode.fromI(cb.emb)(cb => + IEmitCode.present(cb, primitive(fb.getCodeParam[Boolean](3))) + ), + EmitCode.fromI(cb.emb)(cb => IEmitCode.present(cb, primitive(fb.getCodeParam[Double](4)))), + ), + deepCopy = false, + ).a } - val region = Region(pool=pool) + val region = Region(pool = pool) val f = fb.result()(theHailClassLoader) def run(i: Int, b: Boolean, d: Double): (Int, Boolean, Double) = { val off = f(region, i, b, d) - (Region.loadInt(t.loadField(off, 0)), + ( + Region.loadInt(t.loadField(off, 0)), Region.loadBoolean(t.loadField(off, 1)), - Region.loadDouble(t.loadField(off, 2))) + Region.loadDouble(t.loadField(off, 2)), + ) } assert(run(3, true, 42.0) == ((3, true, 42.0))) assert(run(42, false, -1.0) == 
((42, false, -1.0))) } - @Test def testDeepCopy() { + @Test def testDeepCopy(): Unit = { val g = Type.genStruct .flatMap(t => Gen.zip(Gen.const(t), t.genValue(sm))) - .filter { case (t, a) => a != null } + .filter { case (_, a) => a != null } .map { case (t, a) => (PType.canonical(t).asInstanceOf[PStruct], a) } val p = Prop.forAll(g) { case (t, a) => @@ -442,15 +494,18 @@ class StagedConstructorSuite extends HailSuite { val src = ScalaToRegionValue(sm, srcRegion, t, a) val fb = EmitFunctionBuilder[Region, Long, Long](ctx, "deep_copy") - fb.emitWithBuilder[Long](cb => t.store(cb, - fb.apply_method.getCodeParam[Region](1), - t.loadCheapSCode(cb, fb.apply_method.getCodeParam[Long](2)), - deepCopy = true)) + fb.emitWithBuilder[Long](cb => + t.store( + cb, + fb.apply_method.getCodeParam[Region](1), + t.loadCheapSCode(cb, fb.apply_method.getCodeParam[Long](2)), + deepCopy = true, + ) + ) val copyF = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, region) val newOff = copyF(region, src) - - //clear old stuff + // clear old stuff val len = srcRegion.allocate(0) - src Region.storeBytes(src, Array.fill(len.toInt)(0.toByte)) newOff @@ -462,19 +517,25 @@ class StagedConstructorSuite extends HailSuite { p.check() } - @Test def testUnstagedCopy() { - val t1 = PCanonicalArray(PCanonicalStruct( - true, - "x1" -> PInt32(), - "x2" -> PCanonicalArray(PInt32(), required = true), - "x3" -> PCanonicalArray(PInt32(true), required = true), - "x4" -> PCanonicalSet(PCanonicalStruct(true, "y" -> PCanonicalString(true)), required = false) - ), required = false) + @Test def testUnstagedCopy(): Unit = { + val t1 = PCanonicalArray( + PCanonicalStruct( + true, + "x1" -> PInt32(), + "x2" -> PCanonicalArray(PInt32(), required = true), + "x3" -> PCanonicalArray(PInt32(true), required = true), + "x4" -> PCanonicalSet( + PCanonicalStruct(true, "y" -> PCanonicalString(true)), + required = false, + ), + ), + required = false, + ) val t2 = RequirednessSuite.deepInnerRequired(t1, false) val value = IndexedSeq( - Row(1, IndexedSeq(1,2,3), IndexedSeq(0, -1), Set(Row("asdasdasd"), Row(""))), - Row(1, IndexedSeq(), IndexedSeq(-1), Set(Row("aa"))) + Row(1, IndexedSeq(1, 2, 3), IndexedSeq(0, -1), Set(Row("asdasdasd"), Row(""))), + Row(1, IndexedSeq(), IndexedSeq(-1), Set(Row("aa"))), ) pool.scopedRegion { r => @@ -490,19 +551,28 @@ class StagedConstructorSuite extends HailSuite { } } - @Test def testStagedCopy() { - val t1 = PCanonicalStruct(false, "a" -> PCanonicalArray(PCanonicalStruct( - true, - "x1" -> PInt32(), - "x2" -> PCanonicalArray(PInt32(), required = true), - "x3" -> PCanonicalArray(PInt32(true), required = true), - "x4" -> PCanonicalSet(PCanonicalStruct(true, "y" -> PCanonicalString(true)), required = false) - ), required = false)) + @Test def testStagedCopy(): Unit = { + val t1 = PCanonicalStruct( + false, + "a" -> PCanonicalArray( + PCanonicalStruct( + true, + "x1" -> PInt32(), + "x2" -> PCanonicalArray(PInt32(), required = true), + "x3" -> PCanonicalArray(PInt32(true), required = true), + "x4" -> PCanonicalSet( + PCanonicalStruct(true, "y" -> PCanonicalString(true)), + required = false, + ), + ), + required = false, + ), + ) val t2 = RequirednessSuite.deepInnerRequired(t1, false).asInstanceOf[PCanonicalStruct] val value = IndexedSeq( - Row(1, IndexedSeq(1,2,3), IndexedSeq(0, -1), Set(Row("asdasdasd"), Row(""))), - Row(1, IndexedSeq(), IndexedSeq(-1), Set(Row("aa"))) + Row(1, IndexedSeq(1, 2, 3), IndexedSeq(0, -1), Set(Row("asdasdasd"), Row(""))), + Row(1, IndexedSeq(), IndexedSeq(-1), Set(Row("aa"))), 
) val valueT2 = t2.types(0) @@ -513,7 +583,12 @@ class StagedConstructorSuite extends HailSuite { val f1 = EmitFunctionBuilder[Long](ctx, "stagedCopy1") f1.emitWithBuilder { cb => val region = f1.partitionRegion - t2.constructFromFields(cb, region, FastSeq(EmitCode.present(cb.emb, t2.types(0).loadCheapSCode(cb, v1))), deepCopy = false).a + t2.constructFromFields( + cb, + region, + FastSeq(EmitCode.present(cb.emb, t2.types(0).loadCheapSCode(cb, v1))), + deepCopy = false, + ).a } val cp1 = f1.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r)() assert(SafeRow.read(t2, cp1) == Row(value)) @@ -521,7 +596,12 @@ class StagedConstructorSuite extends HailSuite { val f2 = EmitFunctionBuilder[Long](ctx, "stagedCopy2") f2.emitWithBuilder { cb => val region = f2.partitionRegion - t1.constructFromFields(cb, region, FastSeq(EmitCode.present(cb.emb, t2.types(0).loadCheapSCode(cb, v1))), deepCopy = false).a + t1.constructFromFields( + cb, + region, + FastSeq(EmitCode.present(cb.emb, t2.types(0).loadCheapSCode(cb, v1))), + deepCopy = false, + ).a } val cp2 = f2.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r)() assert(SafeRow.read(t1, cp2) == Row(value)) diff --git a/hail/src/test/scala/is/hail/annotations/UnsafeSuite.scala b/hail/src/test/scala/is/hail/annotations/UnsafeSuite.scala index 1d34678657e..2730821034a 100644 --- a/hail/src/test/scala/is/hail/annotations/UnsafeSuite.scala +++ b/hail/src/test/scala/is/hail/annotations/UnsafeSuite.scala @@ -7,20 +7,23 @@ import is.hail.rvd.AbstractRVDSpec import is.hail.types.physical._ import is.hail.types.virtual.{TArray, TStruct, Type} import is.hail.utils._ + +import scala.util.Random + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + import org.apache.spark.sql.Row import org.json4s.jackson.Serialization import org.testng.annotations.{DataProvider, Test} -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import scala.util.Random - class UnsafeSuite extends HailSuite { def subsetType(t: Type): Type = { t match { case t: TStruct => TStruct( t.fields.filter(_ => Random.nextDouble() < 0.4) - .map(f => f.name -> f.typ): _*) + .map(f => f.name -> f.typ): _* + ) case t: TArray => TArray(subsetType(t.elementType)) @@ -50,35 +53,33 @@ class UnsafeSuite extends HailSuite { def sm = ctx.stateManager @DataProvider(name = "codecs") - def codecs(): Array[Array[Any]] = { - (BufferSpec.specs ++ Array(TypedCodecSpec(PCanonicalStruct("x" -> PInt64()), BufferSpec.default))) + def codecs(): Array[Array[Any]] = + (BufferSpec.specs ++ Array(TypedCodecSpec( + PCanonicalStruct("x" -> PInt64()), + BufferSpec.default, + ))) .map(x => Array[Any](x)) - } - @Test(dataProvider = "codecs") def testCodecSerialization(codec: Spec) { + @Test(dataProvider = "codecs") def testCodecSerialization(codec: Spec): Unit = { implicit val formats = AbstractRVDSpec.formats assert(Serialization.read[Spec](codec.toString) == codec) } - @Test def testCodec() { - val region = Region(pool=pool) - val region2 = Region(pool=pool) - val region3 = Region(pool=pool) - val region4 = Region(pool=pool) - val rvb = new RegionValueBuilder(sm, region) - - val path = ctx.createTmpPath("test-codec", "ser") + @Test def testCodec(): Unit = { + val region = Region(pool = pool) + val region2 = Region(pool = pool) + val region3 = Region(pool = pool) + val region4 = Region(pool = pool) val g = Type.genStruct .flatMap(t => Gen.zip(Gen.const(t), t.genValue(sm))) - .filter { case (t, a) => a != null } + .filter { case (_, a) => a != null } val p = Prop.forAll(g) { case (t, 
a) => assert(t.typeCheck(a)) val pt = PType.canonical(t).asInstanceOf[PStruct] val requestedType = subsetType(t).asInstanceOf[TStruct] - val prt = PType.canonical(requestedType).asInstanceOf[PStruct] val a2 = subset(t, requestedType, a) assert(requestedType.typeCheck(a2)) @@ -135,7 +136,7 @@ class UnsafeSuite extends HailSuite { p.check() } - @Test def testCodecForNonWrappedTypes() { + @Test def testCodecForNonWrappedTypes(): Unit = { val valuesAndTypes = FastSeq( 5 -> PInt32(), 6L -> PInt64(), @@ -143,7 +144,8 @@ class UnsafeSuite extends HailSuite { 5.7d -> PFloat64(), "foo" -> PCanonicalString(), Array[Byte](61, 62, 63) -> PCanonicalBinary(), - FastSeq[Int](1, 2, 3) -> PCanonicalArray(PInt32())) + FastSeq[Int](1, 2, 3) -> PCanonicalArray(PInt32()), + ) valuesAndTypes.foreach { case (v, t) => pool.scopedRegion { region => @@ -158,7 +160,8 @@ class UnsafeSuite extends HailSuite { val serialized = baos.toByteArray val (decT, dec) = cs2.buildDecoder(ctx, t.virtualType) assert(decT == t) - val res = dec((new ByteArrayInputStream(serialized)), theHailClassLoader).readRegionValue(region) + val res = + dec((new ByteArrayInputStream(serialized)), theHailClassLoader).readRegionValue(region) assert(t.unsafeOrdering(sm).equiv(res, off)) } @@ -166,7 +169,7 @@ class UnsafeSuite extends HailSuite { } } - @Test def testBufferWriteReadDoubles() { + @Test def testBufferWriteReadDoubles(): Unit = { val a = Array(1.0, -349.273, 0.0, 9925.467, 0.001) BufferSpec.specs.foreach { bufferSpec => @@ -184,15 +187,15 @@ class UnsafeSuite extends HailSuite { } } - @Test def testRegionValue() { - val region = Region(pool=pool) - val region2 = Region(pool=pool) + @Test def testRegionValue(): Unit = { + val region = Region(pool = pool) + val region2 = Region(pool = pool) val rvb = new RegionValueBuilder(sm, region) val rvb2 = new RegionValueBuilder(sm, region2) val g = Type.genArb .flatMap(t => Gen.zip(Gen.const(t), t.genValue(sm), Gen.choose(0, 100), Gen.choose(0, 100))) - .filter { case (t, a, n, n2) => a != null } + .filter { case (_, a, _, _) => a != null } val p = Prop.forAll(g) { case (t, a, n, n2) => val pt = PType.canonical(t) t.typeCheck(a) @@ -233,7 +236,8 @@ class UnsafeSuite extends HailSuite { val ps = pt.asInstanceOf[PStruct] region2.clear() region2.allocate(1, n) // preallocate - val offset4 = ps.unstagedStoreJavaObject(sm, Row.fromSeq(a.asInstanceOf[Row].toSeq), region2) + val offset4 = + ps.unstagedStoreJavaObject(sm, Row.fromSeq(a.asInstanceOf[Row].toSeq), region2) val ur4 = new UnsafeRow(ps, region2, offset4) assert(t.valuesSimilar(a, ur4)) case _ => @@ -250,7 +254,8 @@ class UnsafeSuite extends HailSuite { t match { case t: TStruct => val ps = pt.asInstanceOf[PStruct] - val offset6 = ps.unstagedStoreJavaObject(sm, Row.fromSeq(a.asInstanceOf[Row].toSeq), region) + val offset6 = + ps.unstagedStoreJavaObject(sm, Row.fromSeq(a.asInstanceOf[Row].toSeq), region) val ur6 = new UnsafeRow(ps, region, offset6) assert(t.valuesSimilar(a, ur6)) case _ => @@ -261,7 +266,6 @@ class UnsafeSuite extends HailSuite { p.check() } - val g = (for { s <- Gen.size // prefer smaller type and bigger values @@ -272,62 +276,61 @@ class UnsafeSuite extends HailSuite { v <- t.genNonmissingValue(sm).resize(y) } yield (t, v)).filter(_._2 != null) - @Test def testPacking() { + @Test def testPacking(): Unit = { - def makeStruct(types: PType*): PCanonicalStruct = { + def makeStruct(types: PType*): PCanonicalStruct = PCanonicalStruct(types.zipWithIndex.map { case (t, i) => (s"f$i", t) }: _*) - } val t1 = makeStruct( // missing byte is 
0 - PInt32(), //4-8 - PInt32(), //8-12 - PFloat64(), //16-24 - PBoolean(), //1-2 - PBoolean(), //2-3 - PBoolean(), //3-4 - PBoolean(), //12-13 - PBoolean()) //13-14 + PInt32(), // 4-8 + PInt32(), // 8-12 + PFloat64(), // 16-24 + PBoolean(), // 1-2 + PBoolean(), // 2-3 + PBoolean(), // 3-4 + PBoolean(), // 12-13 + PBoolean(), + ) // 13-14 assert(t1.byteOffsets.toSeq == Seq(4, 8, 16, 1, 2, 3, 12, 13)) assert(t1.byteSize == 24) - val t2 = makeStruct( //missing bytes 0, 1 - PBoolean(), //2-3 - PInt32(), //4-8 - PInt32(), //8-12 - PFloat64(), //16-24 - PInt32(), //12-16 - PInt32(), //24-28 - PFloat64(), //32-40 - PInt32(), //28-32 - PBoolean(), //3-4 - PFloat64(), //40-48 - PBoolean()) //48-49 + val t2 = makeStruct( // missing bytes 0, 1 + PBoolean(), // 2-3 + PInt32(), // 4-8 + PInt32(), // 8-12 + PFloat64(), // 16-24 + PInt32(), // 12-16 + PInt32(), // 24-28 + PFloat64(), // 32-40 + PInt32(), // 28-32 + PBoolean(), // 3-4 + PFloat64(), // 40-48 + PBoolean(), + ) // 48-49 assert(t2.byteOffsets.toSeq == Seq(2, 4, 8, 16, 12, 24, 32, 28, 3, 40, 48)) assert(t2.byteSize == 49) val t3 = makeStruct((0 until 512).map(_ => PFloat64()): _*) assert(t3.byteSize == (512 / 8) + 512 * 8) - val t4 = makeStruct((0 until 256).flatMap(_ => Iterator(PInt32(), PInt32(), PFloat64(), PBoolean())): _*) + val t4 = makeStruct((0 until 256).flatMap(_ => + Iterator(PInt32(), PInt32(), PFloat64(), PBoolean()) + ): _*) assert(t4.byteSize == 256 * 4 / 8 + 256 * 4 * 2 + 256 * 8 + 256) } - @Test def testEmptySize() { + @Test def testEmptySize(): Unit = assert(PCanonicalStruct().byteSize == 0) - } - @Test def testUnsafeOrdering() { - val region = Region(pool=pool) - val region2 = Region(pool=pool) - val rvb = new RegionValueBuilder(sm, region) - val rvb2 = new RegionValueBuilder(sm, region2) + @Test def testUnsafeOrdering(): Unit = { + val region = Region(pool = pool) + val region2 = Region(pool = pool) val g = PType.genStruct .flatMap(t => Gen.zip(Gen.const(t), Gen.zip(t.genValue(sm), t.genValue(sm)))) - .filter { case (t, (a1, a2)) => a1 != null && a2 != null } + .filter { case (_, (a1, a2)) => a1 != null && a2 != null } .resize(10) val p = Prop.forAll(g) { case (t, (a1, a2)) => - val tv = t.virtualType tv.typeCheck(a1) diff --git a/hail/src/test/scala/is/hail/asm4s/ASM4SSuite.scala b/hail/src/test/scala/is/hail/asm4s/ASM4SSuite.scala index 4b833314e85..ceb3bee5bf4 100644 --- a/hail/src/test/scala/is/hail/asm4s/ASM4SSuite.scala +++ b/hail/src/test/scala/is/hail/asm4s/ASM4SSuite.scala @@ -1,25 +1,27 @@ package is.hail.asm4s -import java.io.PrintWriter import is.hail.HailSuite import is.hail.asm4s.Code._ -import is.hail.asm4s.FunctionBuilder._ import is.hail.check.{Gen, Prop} -import is.hail.utils.HailException -import org.scalatest.testng.TestNGSuite -import org.testng.annotations.Test +import is.hail.utils.FastSeq -import scala.collection.mutable import scala.language.postfixOps -import scala.util.Random -trait Z2Z { def apply(z:Boolean): Boolean } +import java.io.PrintWriter + +import org.testng.annotations.Test + +trait Z2Z { def apply(z: Boolean): Boolean } class ASM4SSuite extends HailSuite { private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader()) @Test def not(): Unit = { - val notb = FunctionBuilder[Z2Z]("is/hail/asm4s/Z2Z", Array(NotGenericTypeInfo[Boolean]), NotGenericTypeInfo[Boolean]) + val notb = FunctionBuilder[Z2Z]( + "is/hail/asm4s/Z2Z", + Array(NotGenericTypeInfo[Boolean]), + NotGenericTypeInfo[Boolean], + ) notb.emit(!notb.getArg[Boolean](1)) val not = 
notb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) assert(!not(true)) @@ -44,7 +46,7 @@ class ASM4SSuite extends HailSuite { @Test def iinc(): Unit = { val fb = FunctionBuilder[Int]("F") val l = fb.newLocal[Int]() - fb.emit(Code(l := 0, l++, l += 2, l)) + fb.emit(Code(l := 0, l ++, l += 2, l)) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) assert(f() == 3) } @@ -57,7 +59,7 @@ class ASM4SSuite extends HailSuite { arr(0) = 6, arr(1) = 7, arr(2) = -6, - arr(hb.getArg[Int](1)) + arr(hb.getArg[Int](1)), )) val h = hb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) assert(h(0) == 6) @@ -66,57 +68,59 @@ class ASM4SSuite extends HailSuite { } @Test def get(): Unit = { - val fb = FunctionBuilder[A, Int]("F") - fb.emit(fb.getArg[A](1).getField[Int]("i")) + val fb = FunctionBuilder[Foo, Int]("F") + fb.emit(fb.getArg[Foo](1).getField[Int]("i")) val i = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - val a = new A + val a = new Foo assert(i(a) == 5) } @Test def invoke(): Unit = { - val fb = FunctionBuilder[A, Int]("F") - fb.emit(fb.getArg[A](1).invoke[Int]("f")) + val fb = FunctionBuilder[Foo, Int]("F") + fb.emit(fb.getArg[Foo](1).invoke[Int]("f")) val i = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - val a = new A + val a = new Foo assert(i(a) == 6) } @Test def invoke2(): Unit = { - val fb = FunctionBuilder[A, Int]("F") - fb.emit(fb.getArg[A](1).invoke[Int, Int]("g", 6)) + val fb = FunctionBuilder[Foo, Int]("F") + fb.emit(fb.getArg[Foo](1).invoke[Int, Int]("g", 6)) val j = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - val a = new A + val a = new Foo assert(j(a) == 11) } @Test def newInstance(): Unit = { val fb = FunctionBuilder[Int]("F") - fb.emit(Code.newInstance[A]().invoke[Int]("f")) + fb.emit(Code.newInstance[Foo]().invoke[Int]("f")) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) assert(f() == 6) } @Test def put(): Unit = { val fb = FunctionBuilder[Int]("F") - val inst = fb.newLocal[A]() + val inst = fb.newLocal[Foo]() fb.emit(Code( - inst.store(Code.newInstance[A]()), + inst.store(Code.newInstance[Foo]()), inst.put("i", -2), - inst.getField[Int]("i"))) + inst.getField[Int]("i"), + )) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) assert(f() == -2) } @Test def staticPut(): Unit = { val fb = FunctionBuilder[Int]("F") - val inst = fb.newLocal[A]() + val inst = fb.newLocal[Foo]() fb.emit(Code( - inst.store(Code.newInstance[A]()), + inst.store(Code.newInstance[Foo]()), inst.put("j", -2), - Code.getStatic[A, Int]("j"))) + Code.getStatic[Foo, Int]("j"), + )) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) assert(f() == -2) } @@ -143,10 +147,12 @@ class ASM4SSuite extends HailSuite { fb.emitWithBuilder[Int] { cb => val r = cb.newLocal[Int]("r") cb.assign(r, 1) - cb.while_(i > 1, { - cb.assign(r, r * i) - cb.assign(i, i - 1) - }) + cb.while_( + i > 1, { + cb.assign(r, r * i) + cb.assign(i, i - 1) + }, + ) r } @@ -168,12 +174,12 @@ class ASM4SSuite extends HailSuite { @Test def anewarray(): Unit = { val fb = FunctionBuilder[Int]("F") - val arr = fb.newLocal[Array[A]]() + val arr = fb.newLocal[Array[Foo]]() fb.emit(Code( - arr.store(newArray[A](2)), - arr(0) = Code.newInstance[A](), - arr(1) = Code.newInstance[A](), - arr(0).getField[Int]("i") + arr(1).getField[Int]("i") + arr.store(newArray[Foo](2)), + arr(0) = Code.newInstance[Foo](), + arr(1) = Code.newInstance[Foo](), + arr(0).getField[Int]("i") + arr(1).getField[Int]("i"), )) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) 
assert(f() == 10) @@ -182,7 +188,7 @@ class ASM4SSuite extends HailSuite { def fibonacciReference(i: Int): Int = i match { case 0 => 0 case 1 => 1 - case n => fibonacciReference(n-1) + fibonacciReference(n-2) + case n => fibonacciReference(n - 1) + fibonacciReference(n - 2) } @Test def fibonacci(): Unit = { @@ -191,26 +197,29 @@ class ASM4SSuite extends HailSuite { val i = fb.getArg[Int](1) fb.emitWithBuilder[Int] { cb => val n = cb.newLocal[Int]("n") - cb.if_(i < 3, cb.assign(n, 1), { - val vn_1 = cb.newLocal[Int]("vn_1") - val vn_2 = cb.newLocal[Int]("vn_2") - cb.assign(vn_1, 1) - cb.assign(vn_2, 1) - cb.while_(i > 3, { - val temp = fb.newLocal[Int]() - cb.assign(temp, vn_2 + vn_1) - cb.assign(vn_1, temp) - cb.assign(i, i - 1) - }) - cb.assign(n, vn_2 + vn_1) - }) + cb.if_( + i < 3, + cb.assign(n, 1), { + val vn_1 = cb.newLocal[Int]("vn_1") + val vn_2 = cb.newLocal[Int]("vn_2") + cb.assign(vn_1, 1) + cb.assign(vn_2, 1) + cb.while_( + i > 3, { + val temp = fb.newLocal[Int]() + cb.assign(temp, vn_2 + vn_1) + cb.assign(vn_1, temp) + cb.assign(i, i - 1) + }, + ) + cb.assign(n, vn_2 + vn_1) + }, + ) n } val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - Prop.forAll(Gen.choose(0, 100)) { i => - fibonacciReference(i) == f(i) - } + Prop.forAll(Gen.choose(0, 100))(i => fibonacciReference(i) == f(i)) } @Test def nanAlwaysComparesFalse(): Unit = { @@ -313,17 +322,21 @@ class ASM4SSuite extends HailSuite { val a = fb.getArg[Int](1) val b = fb.getArg[Int](2) val c = fb.getArg[Int](3) - val res = cb.newLocal[Int]("res") - cb.if_(a.ceq(0), { - cb.assign(res, add.invoke(cb, b, c)) - }, { - cb.if_(a.ceq(1), - cb.assign(res, sub.invoke(cb, b, c)), - cb.assign(res, mul.invoke(cb, b, c))) - }) + val res = cb.newLocal[Int]("result") + cb.switch( + a, + cb._fatal("invalid choice"), + FastSeq( + () => cb.assign(res, cb.invoke(add, cb.this_, b, c)), + () => cb.assign(res, cb.invoke(sub, cb.this_, b, c)), + () => cb.assign(res, cb.invoke(mul, cb.this_, b, c)), + ), + ) res } - val f = fb.result(ctx.shouldWriteIRFiles(), Some(new PrintWriter(System.out)))(theHailClassLoader) + + val f = + fb.result(ctx.shouldWriteIRFiles(), Some(new PrintWriter(System.out)))(theHailClassLoader) assert(f(0, 1, 1) == 2) assert(f(1, 5, 1) == 4) assert(f(2, 2, 8) == 16) @@ -336,11 +349,15 @@ class ASM4SSuite extends HailSuite { val v1 = add.newLocal[Int]() val v2 = add.newLocal[Int]() - add.emit(Code(v1 := add.getArg[Int](1), - v2 := add.getArg[Int](2), - v1 + v2)) + add.emit( + Code( + v1 := add.getArg[Int](1), + v2 := add.getArg[Int](2), + v1 + v2, + ) + ) - fb.emitWithBuilder(add.invoke(_, fb.getArg[Int](1), fb.getArg[Int](2))) + fb.emitWithBuilder(cb => cb.invoke(add, cb.this_, fb.getArg[Int](1), fb.getArg[Int](2))) val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) assert(f(1, 1) == 2) } @@ -355,7 +372,8 @@ class ASM4SSuite extends HailSuite { val c = Code( intField.store(fb.getArg[Int](1)), longField.store(fb.getArg[Long](2)), - booleanField.store(fb.getArg[Boolean](3))) + booleanField.store(fb.getArg[Boolean](3)), + ) typeInfo[T] match { case IntInfo => fb.emit(Code(c, intField.load())) @@ -381,14 +399,17 @@ class ASM4SSuite extends HailSuite { val c = Code( intField.store(fb.getArg[Int](1)), longField.store(fb.getArg[Long](2)), - booleanField.store(fb.getArg[Boolean](3))) + booleanField.store(fb.getArg[Boolean](3)), + ) typeInfo[T] match { case IntInfo => mb.emit(Code(c, intField.load())) case LongInfo => mb.emit(Code(c, longField.load())) case BooleanInfo => mb.emit(Code(c, booleanField.load())) } 
- fb.emitWithBuilder(mb.invoke(_, fb.getArg[Int](1), fb.getArg[Long](2), fb.getArg[Boolean](3))) + fb.emitWithBuilder { cb => + cb.invoke(mb, cb.this_, fb.getArg[Int](1), fb.getArg[Long](2), fb.getArg[Boolean](3)) + } val f = fb.result(ctx.shouldWriteIRFiles())(theHailClassLoader) f(arg1, arg2, arg3) } @@ -407,7 +428,7 @@ class ASM4SSuite extends HailSuite { a := 0, a := lzy, a := lzy, - lzy + lzy, )) val f = F.result(ctx.shouldWriteIRFiles())(theHailClassLoader) @@ -429,8 +450,8 @@ class ASM4SSuite extends HailSuite { @Test def testInit(): Unit = { val Main = FunctionBuilder[Int]("Main") val a = Main.genFieldThisRef[Int]("a") - Main.emitInit { a := 1 } - Main.emit { a } + Main.emitInit(a := 1) + Main.emit(a) val test = Main.result(ctx.shouldWriteIRFiles())(theHailClassLoader) assert(test() == 1) @@ -439,13 +460,39 @@ class ASM4SSuite extends HailSuite { @Test def testClinit(): Unit = { val Main = FunctionBuilder[Int]("Main") val a = Main.newStaticField[Int]("a") - Main.emitClinit { a.put(1) } - Main.emit { a.get() } + Main.emitClinit(a.put(1)) + Main.emit(a.get()) val test = Main.result(ctx.shouldWriteIRFiles())(theHailClassLoader) assert(test() == 1) } + @Test def testClassInstances(): Unit = { + val Counter = FunctionBuilder[Int]("Counter") + val x = Counter.genFieldThisRef[Int]("x") + Counter.emitInit(x := 0) + Counter.emit { + Code( + x := x + 1, + x, + ) + } + + val Main = FunctionBuilder[Int]("Main") + Main.emitWithBuilder[Int] { cb => + val a = cb.newLocal("a", Code.newInstance(Counter.cb, Counter.cb.ctor, FastSeq())) + val b = cb.newLocal("b", Code.newInstance(Counter.cb, Counter.cb.ctor, FastSeq())) + cb.invoke[Int](Counter.mb, a) + cb.invoke[Int](Counter.mb, a) + cb.invoke[Int](Counter.mb, b) + cb.invoke[Int](Counter.mb, a) * cb.invoke[Int](Counter.mb, b) + } + + Counter.result(ctx.shouldWriteIRFiles())(theHailClassLoader) + val test = Main.result(ctx.shouldWriteIRFiles())(theHailClassLoader) + assert(test() == 6) + } + @Test def testIf(): Unit = { val Main = FunctionBuilder[Int, Int]("If") Main.emitWithBuilder[Int] { cb => @@ -456,7 +503,7 @@ class ASM4SSuite extends HailSuite { } val abs = Main.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - Prop.forAll { (x: Int) => abs(x) == x.abs }.check() + Prop.forAll((x: Int) => abs(x) == x.abs).check() } @Test def testWhile(): Unit = { @@ -468,17 +515,18 @@ class ASM4SSuite extends HailSuite { val acc = cb.newLocal[Int]("signum") cb.if_(a > 0, cb.assign(acc, 1), cb.assign(acc, -1)) - cb.while_(a cne 0, { - cb.assign(a, a - acc) - cb.assign(b, b + acc) - }) + cb.while_( + a cne 0, { + cb.assign(a, a - acc) + cb.assign(b, b + acc) + }, + ) b } val add = Main.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - Prop.forAll(Gen.choose(-10, 10), Gen.choose(-10, 10)) - { (x, y) => add(x, y) == x + y } + Prop.forAll(Gen.choose(-10, 10), Gen.choose(-10, 10))((x, y) => add(x, y) == x + y) .check() } @@ -494,14 +542,14 @@ class ASM4SSuite extends HailSuite { setup = cb.if_(a > 0, cb.assign(acc, 1), cb.assign(acc, -1)), cond = a cne 0, incr = cb.assign(a, a - acc), - body = cb.assign(b, b + acc) + body = cb.assign(b, b + acc), ) b } val add = Main.result(ctx.shouldWriteIRFiles())(theHailClassLoader) - Prop.forAll(Gen.choose(-10, 10), Gen.choose(-10, 10)) { (x, y) => add(x, y) == x + y } + Prop.forAll(Gen.choose(-10, 10), Gen.choose(-10, 10))((x, y) => add(x, y) == x + y) .check() } diff --git a/hail/src/test/scala/is/hail/asm4s/CodeSuite.scala b/hail/src/test/scala/is/hail/asm4s/CodeSuite.scala index 67cde07a537..012b5e16ad1 100644 --- 
a/hail/src/test/scala/is/hail/asm4s/CodeSuite.scala +++ b/hail/src/test/scala/is/hail/asm4s/CodeSuite.scala @@ -3,17 +3,20 @@ package is.hail.asm4s import is.hail.HailSuite import is.hail.annotations.Region import is.hail.expr.ir.{EmitCodeBuilder, EmitFunctionBuilder, EmitValue, IEmitCode} +import is.hail.types.physical._ import is.hail.types.physical.stypes.{EmitType, SValue} import is.hail.types.physical.stypes.concrete._ -import is.hail.types.physical.stypes.primitives.{SFloat32Value, SFloat64Value, SInt32, SInt32Value, SInt64, SInt64Value} -import is.hail.types.physical._ +import is.hail.types.physical.stypes.primitives.{ + SFloat32Value, SFloat64Value, SInt32, SInt32Value, SInt64, SInt64Value, +} import is.hail.types.virtual.{TInt32, TInt64, TStruct} + import org.apache.spark.sql.Row import org.testng.annotations.Test class CodeSuite extends HailSuite { - @Test def testForLoop() { + @Test def testForLoop(): Unit = { val fb = EmitFunctionBuilder[Int](ctx, "foo") fb.emitWithBuilder[Int] { cb => val i = cb.newLocal[Int]("i") @@ -30,21 +33,27 @@ class CodeSuite extends HailSuite { @Test def testSizeBasic(): Unit = { val int64 = new SInt64Value(5L) val int32 = new SInt32Value(2) - val struct = new SStackStructValue(SStackStruct(TStruct("x" -> TInt64, "y" -> TInt32), IndexedSeq(EmitType(SInt64, true), EmitType(SInt32, false))), IndexedSeq(EmitValue(None, int64), EmitValue(Some(false), int32))) + val struct = new SStackStructValue( + SStackStruct( + TStruct("x" -> TInt64, "y" -> TInt32), + IndexedSeq(EmitType(SInt64, true), EmitType(SInt32, false)), + ), + IndexedSeq(EmitValue(None, int64), EmitValue(Some(false), int32)), + ) val str = new SJavaStringValue(const("cat")) def testSizeHelper(v: SValue): Long = { val fb = EmitFunctionBuilder[Long](ctx, "test_size_in_bytes") val mb = fb.apply_method - mb.emit(EmitCodeBuilder.scopedCode(mb) { cb => - v.sizeToStoreInBytes(cb).value - }) + mb.emit(EmitCodeBuilder.scopedCode(mb)(cb => v.sizeToStoreInBytes(cb).value)) fb.result()(theHailClassLoader)() } assert(testSizeHelper(int64) == 8L) assert(testSizeHelper(int32) == 4L) - assert(testSizeHelper(struct) == 16L) // 1 missing byte that gets 4 byte aligned, 8 bytes for long, 4 bytes for missing int + assert( + testSizeHelper(struct) == 16L + ) // 1 missing byte that gets 4 byte aligned, 8 bytes for long, 4 bytes for missing int assert(testSizeHelper(str) == 7L) // 4 byte header, 3 bytes for the 3 letters. } @@ -56,11 +65,17 @@ class CodeSuite extends HailSuite { mb.emit(EmitCodeBuilder.scopedCode(mb) { cb => val region = fb.emb.getCodeParam[Region](1) val sarray = ptype.constructFromElements(cb, region, 5, true) { (cb, idx) => - cb.if_(idx ceq 2, { IEmitCode.missing(cb, stype.elementType.defaultValue)}, { IEmitCode.present(cb, new SInt32Value(idx))}) + cb.if_( + idx ceq 2, + IEmitCode.missing(cb, stype.elementType.defaultValue), + IEmitCode.present(cb, new SInt32Value(idx)), + ) } sarray.sizeToStoreInBytes(cb).value }) - assert(fb.result()(theHailClassLoader)(ctx.r) == 36L) // 2 missing bytes 8 byte aligned + 8 header bytes + 5 elements * 4 bytes for ints. + assert( + fb.result()(theHailClassLoader)(ctx.r) == 36L + ) // 2 missing bytes 8 byte aligned + 8 header bytes + 5 elements * 4 bytes for ints. 
} @Test def testIntervalSizeInBytes(): Unit = { @@ -68,37 +83,65 @@ class CodeSuite extends HailSuite { val mb = fb.apply_method val structL = new SStackStructValue( - SStackStruct(TStruct("x" -> TInt64, "y" -> TInt32), IndexedSeq(EmitType(SInt64, true), EmitType(SInt32, false))), - IndexedSeq(EmitValue(None, new SInt64Value(5L)), EmitValue(Some(false), new SInt32Value(2))) + SStackStruct( + TStruct("x" -> TInt64, "y" -> TInt32), + IndexedSeq(EmitType(SInt64, true), EmitType(SInt32, false)), + ), + IndexedSeq(EmitValue(None, new SInt64Value(5L)), EmitValue(Some(false), new SInt32Value(2))), ) val structR = new SStackStructValue( - SStackStruct(TStruct("x" -> TInt64, "y" -> TInt32), IndexedSeq(EmitType(SInt64, true), EmitType(SInt32, false))), - IndexedSeq(EmitValue(None, new SInt64Value(8L)), EmitValue(Some(false), new SInt32Value(5))) + SStackStruct( + TStruct("x" -> TInt64, "y" -> TInt32), + IndexedSeq(EmitType(SInt64, true), EmitType(SInt32, false)), + ), + IndexedSeq(EmitValue(None, new SInt64Value(8L)), EmitValue(Some(false), new SInt32Value(5))), ) val pType = PCanonicalInterval(structL.st.storageType()) mb.emit(EmitCodeBuilder.scopedCode(mb) { cb => val region = fb.emb.getCodeParam[Region](1) - val sval: SValue =pType.constructFromCodes(cb, region, - EmitValue(Some(false), structL), EmitValue(Some(false), structR), - true, true) + val sval: SValue = pType.constructFromCodes( + cb, + region, + EmitValue(Some(false), structL), + EmitValue(Some(false), structR), + true, + true, + ) sval.sizeToStoreInBytes(cb).value }) assert(fb.result()(theHailClassLoader)(ctx.r) == 72L) // 2 28 byte structs, plus 2 1 byte booleans that get 8 byte for an extra 8 bytes, plus missing bytes. } - @Test def testHash() { - val fields = IndexedSeq(PField("a", PCanonicalString(), 0), PField("b", PInt32(), 1), PField("c", PFloat32(), 2)) + @Test def testHash(): Unit = { + val fields = IndexedSeq( + PField("a", PCanonicalString(), 0), + PField("b", PInt32(), 1), + PField("c", PFloat32(), 2), + ) assert(hashTestNumHelper(new SInt32Value(6)) == hashTestNumHelper(new SInt32Value(6))) - assert(hashTestNumHelper(new SInt64Value(5000000000l)) == hashTestNumHelper(new SInt64Value(5000000000l))) - assert(hashTestNumHelper(new SFloat32Value(3.14f)) == hashTestNumHelper(new SFloat32Value(3.14f))) - assert(hashTestNumHelper(new SFloat64Value(5000000000.89d)) == hashTestNumHelper(new SFloat64Value(5000000000.89d))) - assert(hashTestStringHelper("dog")== hashTestStringHelper("dog")) - assert(hashTestArrayHelper(IndexedSeq(1,2,3,4,5,6)) == hashTestArrayHelper(IndexedSeq(1,2,3,4,5,6))) - assert(hashTestArrayHelper(IndexedSeq(1,2)) != hashTestArrayHelper(IndexedSeq(3,4,5,6,7))) - assert(hashTestStructHelper(Row("wolf", 8, .009f), fields) == hashTestStructHelper(Row("wolf", 8, .009f), fields)) - assert(hashTestStructHelper(Row("w", 8, .009f), fields) != hashTestStructHelper(Row("opaque", 8, .009f), fields)) + assert(hashTestNumHelper(new SInt64Value(5000000000L)) == hashTestNumHelper( + new SInt64Value(5000000000L) + )) + assert( + hashTestNumHelper(new SFloat32Value(3.14f)) == hashTestNumHelper(new SFloat32Value(3.14f)) + ) + assert(hashTestNumHelper(new SFloat64Value(5000000000.89d)) == hashTestNumHelper( + new SFloat64Value(5000000000.89d) + )) + assert(hashTestStringHelper("dog") == hashTestStringHelper("dog")) + assert(hashTestArrayHelper(IndexedSeq(1, 2, 3, 4, 5, 6)) == hashTestArrayHelper(IndexedSeq(1, 2, + 3, 4, 5, 6))) + assert(hashTestArrayHelper(IndexedSeq(1, 2)) != hashTestArrayHelper(IndexedSeq(3, 4, 5, 6, 7))) + 
assert(hashTestStructHelper(Row("wolf", 8, .009f), fields) == hashTestStructHelper( + Row("wolf", 8, .009f), + fields, + )) + assert(hashTestStructHelper(Row("w", 8, .009f), fields) != hashTestStructHelper( + Row("opaque", 8, .009f), + fields, + )) } def hashTestNumHelper(v: SValue): Int = { @@ -123,7 +166,7 @@ class CodeSuite extends HailSuite { val hash = i.hash(cb) hash.value }) - val region = Region(pool=pool) + val region = Region(pool = pool) fb.result()(theHailClassLoader)(region) } @@ -137,12 +180,12 @@ class CodeSuite extends HailSuite { val hash = arrayToHash.hash(cb) hash.value }) - val region = Region(pool=pool) + val region = Region(pool = pool) val arrayPointer = pArray.unstagedStoreJavaObject(ctx.stateManager, toHash, region) fb.result()(theHailClassLoader)(arrayPointer) } - def hashTestStructHelper(toHash: Row, fields : IndexedSeq[PField]): Int = { + def hashTestStructHelper(toHash: Row, fields: IndexedSeq[PField]): Int = { val pStruct = PCanonicalStruct(fields) val fb = EmitFunctionBuilder[Long, Int](ctx, "test_hash") val mb = fb.apply_method @@ -152,7 +195,7 @@ class CodeSuite extends HailSuite { val hash = structToHash.hash(cb) hash.value }) - val region = Region(pool=pool) + val region = Region(pool = pool) val structPointer = pStruct.unstagedStoreJavaObject(ctx.stateManager, toHash, region) fb.result()(theHailClassLoader)(structPointer) } diff --git a/hail/src/test/scala/is/hail/asm4s/A.java b/hail/src/test/scala/is/hail/asm4s/Foo.java similarity index 89% rename from hail/src/test/scala/is/hail/asm4s/A.java rename to hail/src/test/scala/is/hail/asm4s/Foo.java index 91e5ea28600..dc44fbf412a 100644 --- a/hail/src/test/scala/is/hail/asm4s/A.java +++ b/hail/src/test/scala/is/hail/asm4s/Foo.java @@ -1,6 +1,6 @@ package is.hail.asm4s; -public class A { +public class Foo { public static int j = 11; public int i = 5; diff --git a/hail/src/test/scala/is/hail/expr/ParserSuite.scala b/hail/src/test/scala/is/hail/expr/ParserSuite.scala index 31135d815ae..959d33ff62d 100644 --- a/hail/src/test/scala/is/hail/expr/ParserSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ParserSuite.scala @@ -1,9 +1,10 @@ package is.hail.expr import is.hail.HailSuite + import org.testng.annotations.Test -class ParserSuite extends HailSuite{ +class ParserSuite extends HailSuite { @Test def testOneOfLiteral(): Unit = { val strings = Array("A", "B", "AB", "AA", "CAD", "EF") val p = Parser.oneOfLiteral(strings) diff --git a/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala b/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala index a8ce8e9e16c..c657bd9a302 100644 --- a/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/Aggregators2Suite.scala @@ -1,86 +1,109 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.annotations._ import is.hail.asm4s._ import is.hail.expr.ir.agg._ import is.hail.io.BufferSpec +import is.hail.types.{MatrixType, VirtualTypeWithReq} import is.hail.types.physical._ import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType import is.hail.types.virtual._ -import is.hail.types.{MatrixType, VirtualTypeWithReq, tcoerce} import is.hail.utils._ import is.hail.variant.{Call0, Call1, Call2} -import is.hail.{ExecStrategy, HailSuite} + import org.apache.spark.sql.Row import org.testng.annotations.Test class Aggregators2Suite extends HailSuite { def assertAggEqualsProcessed( - aggSig: PhysicalAggSig, - initOp: IR, - seqOps: IndexedSeq[IR], - expected: Any, - args: IndexedSeq[(String, 
(Type, Any))] = FastSeq(), - nPartitions: Int = 2, - expectedInit: Option[Any] = None, - transformResult: Option[Any => Any] = None + aggSig: PhysicalAggSig, + initOp: IR, + seqOps: IndexedSeq[IR], + expected: Any, + args: IndexedSeq[(String, (Type, Any))] = FastSeq(), + nPartitions: Int = 2, + expectedInit: Option[Any] = None, + transformResult: Option[Any => Any] = None, ): Unit = { assert(seqOps.length >= 2 * nPartitions, s"Test aggregators with a larger stream!") - val argT = PType.canonical(TStruct(args.map { case (n, (typ, _)) => n -> typ }: _*)).setRequired(true).asInstanceOf[PStruct] + val argT = PType.canonical( + TStruct(args.map { case (n, (typ, _)) => n -> typ }: _*) + ).setRequired(true).asInstanceOf[PStruct] val argVs = Row.fromSeq(args.map { case (_, (_, v)) => v }) val argRef = Ref(genUID(), argT.virtualType) val spec = BufferSpec.wireSpec - val (_, combAndDuplicate) = CompileWithAggregators[AsmFunction1RegionUnit](ctx, + val (_, combAndDuplicate) = CompileWithAggregators[AsmFunction1RegionUnit]( + ctx, Array.fill(nPartitions)(aggSig.state), FastSeq(), - FastSeq(classInfo[Region]), UnitInfo, + FastSeq(classInfo[Region]), + UnitInfo, Begin( Array.tabulate(nPartitions)(i => DeserializeAggs(i, i, spec, Array(aggSig.state))) ++ Array.range(1, nPartitions).map(i => CombOp(0, i, aggSig)) :+ SerializeAggs(0, 0, spec, Array(aggSig.state)) :+ - DeserializeAggs(1, 0, spec, Array(aggSig.state)))) + DeserializeAggs(1, 0, spec, Array(aggSig.state)) + ), + ) - val (Some(PTypeReferenceSingleCodeType(rt: PTuple)), resF) = CompileWithAggregators[AsmFunction1RegionLong](ctx, - Array.fill(nPartitions)(aggSig.state), - FastSeq(), - FastSeq(classInfo[Region]), LongInfo, - ResultOp.makeTuple(Array(aggSig, aggSig))) + val (Some(PTypeReferenceSingleCodeType(rt: PTuple)), resF) = + CompileWithAggregators[AsmFunction1RegionLong]( + ctx, + Array.fill(nPartitions)(aggSig.state), + FastSeq(), + FastSeq(classInfo[Region]), + LongInfo, + ResultOp.makeTuple(Array(aggSig, aggSig)), + ) assert(rt.types(0) == rt.types(1)) val resultType = rt.types(0) if (transformResult.isEmpty) - assert(resultType.virtualType.typeCheck(expected), s"expected type ${ resultType.virtualType.parsableString() }, got ${expected}") + assert( + resultType.virtualType.typeCheck(expected), + s"expected type ${resultType.virtualType.parsableString()}, got $expected", + ) pool.scopedRegion { region => val argOff = ScalaToRegionValue(ctx.stateManager, region, argT, argVs) def withArgs(foo: IR) = { - CompileWithAggregators[AsmFunction2RegionLongUnit](ctx, + CompileWithAggregators[AsmFunction2RegionLongUnit]( + ctx, Array(aggSig.state), FastSeq((argRef.name, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(argT)))), - FastSeq(classInfo[Region], LongInfo), UnitInfo, - Let(args.map { case (n, _) => n -> GetField(argRef, n) }, foo) - )._2 + FastSeq(classInfo[Region], LongInfo), + UnitInfo, + Let(args.map { case (n, _) => n -> GetField(argRef, n) }, foo), + )._2 } val serialize = SerializeAggs(0, 0, spec, Array(aggSig.state)) - val (_, writeF) = CompileWithAggregators[AsmFunction1RegionUnit](ctx, + val (_, writeF) = CompileWithAggregators[AsmFunction1RegionUnit]( + ctx, Array(aggSig.state), FastSeq(), - FastSeq(classInfo[Region]), UnitInfo, - serialize) + FastSeq(classInfo[Region]), + UnitInfo, + serialize, + ) val initF = withArgs(initOp) expectedInit.foreach { v => - val (Some(PTypeReferenceSingleCodeType(rt: PBaseStruct)), resOneF) = CompileWithAggregators[AsmFunction1RegionLong](ctx, - Array(aggSig.state), - FastSeq(), - 
FastSeq(classInfo[Region]), LongInfo, - ResultOp.makeTuple(Array(aggSig))) + val (Some(PTypeReferenceSingleCodeType(rt: PBaseStruct)), resOneF) = + CompileWithAggregators[AsmFunction1RegionLong]( + ctx, + Array(aggSig.state), + FastSeq(), + FastSeq(classInfo[Region]), + LongInfo, + ResultOp.makeTuple(Array(aggSig)), + ) val init = initF(theHailClassLoader, ctx.fs, ctx.taskContext, region) val res = resOneF(theHailClassLoader, ctx.fs, ctx.taskContext, region) @@ -94,22 +117,23 @@ class Aggregators2Suite extends HailSuite { } } - val serializedParts = seqOps.grouped(math.ceil(seqOps.length / nPartitions.toDouble).toInt).map { seqs => - val init = initF(theHailClassLoader, ctx.fs, ctx.taskContext, region) - val seq = withArgs(Begin(seqs))(theHailClassLoader, ctx.fs, ctx.taskContext, region) - val write = writeF(theHailClassLoader, ctx.fs, ctx.taskContext, region) - pool.scopedSmallRegion { aggRegion => - init.newAggState(aggRegion) - init(region, argOff) - val ioff = init.getAggOffset() - seq.setAggState(aggRegion, ioff) - seq(region, argOff) - val soff = seq.getAggOffset() - write.setAggState(aggRegion, soff) - write(region) - write.getSerializedAgg(0) - } - }.toArray + val serializedParts = + seqOps.grouped(math.ceil(seqOps.length / nPartitions.toDouble).toInt).map { seqs => + val init = initF(theHailClassLoader, ctx.fs, ctx.taskContext, region) + val seq = withArgs(Begin(seqs))(theHailClassLoader, ctx.fs, ctx.taskContext, region) + val write = writeF(theHailClassLoader, ctx.fs, ctx.taskContext, region) + pool.scopedSmallRegion { aggRegion => + init.newAggState(aggRegion) + init(region, argOff) + val ioff = init.getAggOffset() + seq.setAggState(aggRegion, ioff) + seq(region, argOff) + val soff = seq.getAggOffset() + write.setAggState(aggRegion, soff) + write(region) + write.getSerializedAgg(0) + } + }.toArray pool.scopedSmallRegion { aggRegion => val combOp = combAndDuplicate(theHailClassLoader, ctx.fs, ctx.taskContext, region) @@ -123,74 +147,139 @@ class Aggregators2Suite extends HailSuite { val double = SafeRow(rt, res(region)) transformResult match { case Some(f) => - assert(f(double.get(0)) == f(double.get(1)), - s"\nbefore: ${ f(double.get(0)) }\nafter: ${ f(double.get(1)) }") - assert(f(double.get(0)) == expected, - s"\nresult: ${ f(double.get(0)) }\nexpect: ${ expected }") + assert( + f(double.get(0)) == f(double.get(1)), + s"\nbefore: ${f(double.get(0))}\nafter: ${f(double.get(1))}", + ) + assert( + f(double.get(0)) == expected, + s"\nresult: ${f(double.get(0))}\nexpect: $expected", + ) case None => - assert(resultType.virtualType.valuesSimilar(double.get(0), double.get(1)), // state does not change through serialization - s"\nbefore: ${ double.get(0) }\nafter: ${ double.get(1) }") - assert(resultType.virtualType.valuesSimilar(double.get(0), expected), - s"\nresult: ${ double.get(0) }\nexpect: $expected") + assert( + resultType.virtualType.valuesSimilar( + double.get(0), + double.get(1), + ), // state does not change through serialization + s"\nbefore: ${double.get(0)}\nafter: ${double.get(1)}", + ) + assert( + resultType.virtualType.valuesSimilar(double.get(0), expected), + s"\nresult: ${double.get(0)}\nexpect: $expected", + ) } } } } def assertAggEquals( - aggSig: PhysicalAggSig, - initArgs: IndexedSeq[IR], - seqArgs: IndexedSeq[IndexedSeq[IR]], - expected: Any, - args: IndexedSeq[(String, (Type, Any))] = FastSeq(), - nPartitions: Int = 2, - expectedInit: Option[Any] = None, - transformResult: Option[Any => Any] = None): Unit = - assertAggEqualsProcessed(aggSig, + aggSig: 
PhysicalAggSig, + initArgs: IndexedSeq[IR], + seqArgs: IndexedSeq[IndexedSeq[IR]], + expected: Any, + args: IndexedSeq[(String, (Type, Any))] = FastSeq(), + nPartitions: Int = 2, + expectedInit: Option[Any] = None, + transformResult: Option[Any => Any] = None, + ): Unit = + assertAggEqualsProcessed( + aggSig, InitOp(0, initArgs, aggSig), seqArgs.map(s => SeqOp(0, s, aggSig)), - expected, args, nPartitions, expectedInit, - transformResult) + expected, + args, + nPartitions, + expectedInit, + transformResult, + ) val t = TStruct("a" -> TString, "b" -> TInt64) val rows = FastSeq(Row("abcd", 5L), null, Row(null, -2L), Row("abcd", 7L), null, Row("foo", null)) val arrayType = TArray(t) val pnnAggSig = PhysicalAggSig(PrevNonnull(), TypedStateSig(VirtualTypeWithReq.fullyOptional(t))) - val countAggSig = PhysicalAggSig(Count(), TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true))) - val sumAggSig = PhysicalAggSig(Sum(), TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true))) - def collectAggSig(t: Type): PhysicalAggSig = PhysicalAggSig(Collect(), CollectStateSig(VirtualTypeWithReq(PType.canonical(t)))) + val countAggSig = PhysicalAggSig( + Count(), + TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true)), + ) - @Test def TestCount() { + val sumAggSig = + PhysicalAggSig(Sum(), TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true))) + + def collectAggSig(t: Type): PhysicalAggSig = + PhysicalAggSig(Collect(), CollectStateSig(VirtualTypeWithReq(PType.canonical(t)))) + + @Test def TestCount(): Unit = { val seqOpArgs = Array.fill(rows.length)(FastSeq[IR]()) - assertAggEquals(countAggSig, FastSeq(), seqOpArgs, expected = rows.length.toLong, args = FastSeq(("rows", (arrayType, rows)))) + assertAggEquals( + countAggSig, + FastSeq(), + seqOpArgs, + expected = rows.length.toLong, + args = FastSeq(("rows", (arrayType, rows))), + ) } - @Test def testSum() { - val seqOpArgs = Array.tabulate(rows.length)(i => FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "b"))) - assertAggEquals(sumAggSig, FastSeq(), seqOpArgs, expected = 10L, args = FastSeq(("rows", (arrayType, rows)))) + @Test def testSum(): Unit = { + val seqOpArgs = Array.tabulate(rows.length)(i => + FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "b")) + ) + assertAggEquals( + sumAggSig, + FastSeq(), + seqOpArgs, + expected = 10L, + args = FastSeq(("rows", (arrayType, rows))), + ) } - @Test def testPrevNonnullStr() { - val aggSig = PhysicalAggSig(PrevNonnull(), TypedStateSig(VirtualTypeWithReq(PCanonicalString()))) - val seqOpArgs = Array.tabulate(rows.length)(i => FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "a"))) + @Test def testPrevNonnullStr(): Unit = { + val aggSig = + PhysicalAggSig(PrevNonnull(), TypedStateSig(VirtualTypeWithReq(PCanonicalString()))) + val seqOpArgs = Array.tabulate(rows.length)(i => + FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "a")) + ) - assertAggEquals(aggSig, FastSeq(), seqOpArgs, expected = rows.last.get(0), args = FastSeq(("rows", (arrayType, rows)))) + assertAggEquals( + aggSig, + FastSeq(), + seqOpArgs, + expected = rows.last.get(0), + args = FastSeq(("rows", (arrayType, rows))), + ) } - @Test def testPrevNonnull() { - val seqOpArgs = Array.tabulate(rows.length)(i => FastSeq[IR](ArrayRef(Ref("rows", TArray(t)), i))) - assertAggEquals(pnnAggSig, FastSeq(), seqOpArgs, expected = rows.last, args = FastSeq(("rows", (arrayType, rows)))) + @Test def testPrevNonnull(): Unit = { + val seqOpArgs = + 
Array.tabulate(rows.length)(i => FastSeq[IR](ArrayRef(Ref("rows", TArray(t)), i))) + assertAggEquals( + pnnAggSig, + FastSeq(), + seqOpArgs, + expected = rows.last, + args = FastSeq(("rows", (arrayType, rows))), + ) } - @Test def testProduct() { - val aggSig = PhysicalAggSig(Product(), TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true))) - val seqOpArgs = Array.tabulate(rows.length)(i => FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "b"))) - assertAggEquals(aggSig, FastSeq(), seqOpArgs, expected = -70L, args = FastSeq(("rows", (arrayType, rows)))) + @Test def testProduct(): Unit = { + val aggSig = PhysicalAggSig( + Product(), + TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true)), + ) + val seqOpArgs = Array.tabulate(rows.length)(i => + FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "b")) + ) + assertAggEquals( + aggSig, + FastSeq(), + seqOpArgs, + expected = -70L, + args = FastSeq(("rows", (arrayType, rows))), + ) } - @Test def testCallStats() { + @Test def testCallStats(): Unit = { val t = TStruct("x" -> TCall) val calls = FastSeq( @@ -210,32 +299,38 @@ class Aggregators2Suite extends HailSuite { Row(Call2(1, 3)), null, null, - Row(null)) + Row(null), + ) val aggSig = PhysicalAggSig(CallStats(), CallStatsStateSig()) def seqOpArgs(calls: IndexedSeq[Any]) = Array.tabulate(calls.length)(i => - FastSeq[IR](GetField(ArrayRef(Ref("calls", TArray(t)), i), "x"))) + FastSeq[IR](GetField(ArrayRef(Ref("calls", TArray(t)), i), "x")) + ) val an = 18 val ac = FastSeq(10, 6, 1, 1, 0) val af = ac.map(_.toDouble / an).toFastSeq val homCount = FastSeq(3, 1, 0, 0, 0) - assertAggEquals(aggSig, + assertAggEquals( + aggSig, FastSeq(I32(5)), seqOpArgs(calls), expected = Row(ac, af, an, homCount), - args = FastSeq(("calls", (TArray(t), calls)))) + args = FastSeq(("calls", (TArray(t), calls))), + ) val allMissing = calls.filter(_ == null) - assertAggEquals(aggSig, + assertAggEquals( + aggSig, FastSeq(I32(5)), seqOpArgs(allMissing), expected = Row(FastSeq(0, 0, 0, 0, 0), null, 0, FastSeq(0, 0, 0, 0, 0)), - args = FastSeq(("calls", (TArray(t), allMissing)))) + args = FastSeq(("calls", (TArray(t), allMissing))), + ) } - @Test def testTakeBy() { + @Test def testTakeBy(): Unit = { val t = TStruct( "a" -> TStruct("x" -> TInt32, "y" -> TInt64), "b" -> TInt32, @@ -244,7 +339,8 @@ class Aggregators2Suite extends HailSuite { "e" -> TFloat64, "f" -> TBoolean, "g" -> TString, - "h" -> TArray(TInt32)) + "h" -> TArray(TInt32), + ) val rows = FastSeq( Row(Row(11, 11L), 1, 1L, 1f, 1d, true, "1", FastSeq(1, 1)), @@ -260,7 +356,7 @@ class Aggregators2Suite extends HailSuite { Row(Row(1010, 1011L), 11, 11L, 11f, 11d, true, "11111111111", FastSeq()), Row(null, null, null, null, null, null, null, null), Row(null, null, null, null, null, null, null, null), - Row(null, null, null, null, null, null, null, null) + Row(null, null, null, null, null, null, null, null), ) val rowsReversed = rows.take(rows.length - 3).reverse ++ rows.takeRight(3) @@ -273,7 +369,7 @@ class Aggregators2Suite extends HailSuite { { val (a, b) = rows.zipWithIndex.partition(_._2 % 2 == 0) a.map(_._1) ++ b.map(_._1) - } // random-ish + }, // random-ish ) val valueTransformations: Array[(Type, IR => IR, Row => Any)] = Array( @@ -282,47 +378,63 @@ class Aggregators2Suite extends HailSuite { (TFloat64, GetField(_, "e"), Option(_).map(_.get(4)).orNull), (TBoolean, GetField(_, "f"), Option(_).map(_.get(5)).orNull), (TString, GetField(_, "g"), Option(_).map(_.get(6)).orNull), - (TArray(TInt32), GetField(_, 
"h"), Option(_).map(_.get(7)).orNull) + (TArray(TInt32), GetField(_, "h"), Option(_).map(_.get(7)).orNull), ) val keyTransformations: Array[(Type, IR => IR)] = Array( (TInt32, GetField(_, "b")), (TFloat64, GetField(_, "e")), (TString, GetField(_, "g")), - (TStruct("x" -> TInt32, "y" -> TInt64), GetField(_, "a")) + (TStruct("x" -> TInt32, "y" -> TInt64), GetField(_, "a")), ) - def test(n: Int, data: IndexedSeq[Row], valueType: Type, valueF: IR => IR, resultF: Row => Any, keyType: Type, keyF: IR => IR, so: SortOrder = Ascending): Unit = { - - val aggSig = PhysicalAggSig(TakeBy(), TakeByStateSig(VirtualTypeWithReq(PType.canonical(valueType)), VirtualTypeWithReq(PType.canonical(keyType)), so)) + def test( + n: Int, + data: IndexedSeq[Row], + valueType: Type, + valueF: IR => IR, + resultF: Row => Any, + keyType: Type, + keyF: IR => IR, + so: SortOrder = Ascending, + ): Unit = { + + val aggSig = PhysicalAggSig( + TakeBy(), + TakeByStateSig( + VirtualTypeWithReq(PType.canonical(valueType)), + VirtualTypeWithReq(PType.canonical(keyType)), + so, + ), + ) val seqOpArgs = Array.tabulate(rows.length) { i => val ref = ArrayRef(Ref("rows", TArray(t)), i) FastSeq[IR](valueF(ref), keyF(ref)) } - assertAggEquals(aggSig, + assertAggEquals( + aggSig, FastSeq(I32(n)), seqOpArgs, expected = (if (so == Descending) rowsReversed else rows).take(n).map(resultF), - args = FastSeq(("rows", (TArray(t), data)))) + args = FastSeq(("rows", (TArray(t), data))), + ) } // test counts and data input orderings - for ( - n <- FastSeq(0, 1, 4, 100); - perm <- permutations; + for { + n <- FastSeq(0, 1, 4, 100) + perm <- permutations so <- FastSeq(Ascending, Descending) - ) { - test(n, perm, t, identity[IR], identity[Row], TInt32, GetField(_, "b"), so) } + test(n, perm, t, identity[IR], identity[Row], TInt32, GetField(_, "b"), so) // test key and value types - for ( - (vt, valueF, resultF) <- valueTransformations; - (kt, keyF) <- keyTransformations - ) { - test(4, permutations.last, vt, valueF, resultF, kt, keyF) + for { + (vt, valueF, resultF) <- valueTransformations + (kt, keyF) <- keyTransformations } + test(4, permutations.last, vt, valueF, resultF, kt, keyF) // test stable sort test(7, rows, t, identity[IR], identity[Row], TInt64, _ => I64(5L)) @@ -330,26 +442,49 @@ class Aggregators2Suite extends HailSuite { // test GC behavior by passing a large collection val rows2 = Array.tabulate(1200)(i => Row(i, i.toString)).toFastSeq val t2 = TStruct("a" -> TInt32, "b" -> TString) - val aggSig2 = PhysicalAggSig(TakeBy(), TakeByStateSig(VirtualTypeWithReq(PType.canonical(t2)), VirtualTypeWithReq(PType.canonical(TInt32)), Ascending)) - val seqOpArgs2 = Array.tabulate(rows2.length)(i => FastSeq[IR]( - ArrayRef(Ref("rows", TArray(t2)), i), GetField(ArrayRef(Ref("rows", TArray(t2)), i), "a"))) + val aggSig2 = PhysicalAggSig( + TakeBy(), + TakeByStateSig( + VirtualTypeWithReq(PType.canonical(t2)), + VirtualTypeWithReq(PType.canonical(TInt32)), + Ascending, + ), + ) + val seqOpArgs2 = Array.tabulate(rows2.length)(i => + FastSeq[IR]( + ArrayRef(Ref("rows", TArray(t2)), i), + GetField(ArrayRef(Ref("rows", TArray(t2)), i), "a"), + ) + ) - assertAggEquals(aggSig2, + assertAggEquals( + aggSig2, FastSeq(I32(17)), seqOpArgs2, expected = rows2.take(17), - args = FastSeq(("rows", (TArray(t2), rows2.reverse)))) + args = FastSeq(("rows", (TArray(t2), rows2.reverse))), + ) // test inside of aggregation val tr = TableRange(10000, 5) - val ta = TableAggregate(tr, ApplyAggOp(FastSeq(19), - FastSeq(invoke("str", TString, GetField(Ref("row", 
tr.typ.rowType), "idx")), I32(9999) - GetField(Ref("row", tr.typ.rowType), "idx")), - AggSignature(TakeBy(), FastSeq(TInt32), FastSeq(TString, TInt32)))) + val ta = TableAggregate( + tr, + ApplyAggOp( + FastSeq(19), + FastSeq( + invoke("str", TString, GetField(Ref("row", tr.typ.rowType), "idx")), + I32(9999) - GetField(Ref("row", tr.typ.rowType), "idx"), + ), + AggSignature(TakeBy(), FastSeq(TInt32), FastSeq(TString, TInt32)), + ), + ) - assertEvalsTo(ta, (0 until 19).map(i => (9999 - i).toString).toFastSeq)(ExecStrategy.interpretOnly) + assertEvalsTo(ta, (0 until 19).map(i => (9999 - i).toString).toFastSeq)( + ExecStrategy.interpretOnly + ) } - @Test def testTake() { + @Test def testTake(): Unit = { val t = TStruct( "a" -> TStruct("x" -> TInt32, "y" -> TInt64), "b" -> TInt32, @@ -358,7 +493,8 @@ class Aggregators2Suite extends HailSuite { "e" -> TFloat64, "f" -> TBoolean, "g" -> TString, - "h" -> TArray(TInt32)) + "h" -> TArray(TInt32), + ) val rows = FastSeq( Row(Row(11, 11L), 1, 1L, 1f, 1d, true, "one", FastSeq(1, 1)), @@ -378,36 +514,40 @@ class Aggregators2Suite extends HailSuite { Row(Row(88, 88L), 8, 8L, 8f, 8d, null, "eight", null), Row(Row(99, 99L), 9, 9L, 9f, 9d, null, "nine", FastSeq(null)), Row(Row(1010, 1010L), 10, 10L, 10f, 10d, false, "ten", FastSeq()), - Row(Row(1111, 1111L), 11, 11L, 11f, 11d, true, "eleven", FastSeq()) + Row(Row(1111, 1111L), 11, 11L, 11f, 11d, true, "eleven", FastSeq()), ) val aggSig = PhysicalAggSig(Take(), TakeStateSig(VirtualTypeWithReq(PType.canonical(t)))) - val seqOpArgs = Array.tabulate(rows.length)(i => FastSeq[IR](ArrayRef(Ref("rows", TArray(t)), i))) + val seqOpArgs = + Array.tabulate(rows.length)(i => FastSeq[IR](ArrayRef(Ref("rows", TArray(t)), i))) FastSeq(0, 1, 3, 8, 10, 15, 30).foreach { i => - assertAggEquals(aggSig, + assertAggEquals( + aggSig, FastSeq(I32(i)), seqOpArgs, expected = rows.take(i), - args = FastSeq(("rows", (TArray(t), rows)))) + args = FastSeq(("rows", (TArray(t), rows))), + ) } val transformations: IndexedSeq[(IR => IR, Row => Any, Type)] = t.fields.map { f => - ((x: IR) => GetField(x, f.name), - (r: Row) => if (r == null) null else r.get(f.index), - f.typ) + ((x: IR) => GetField(x, f.name), (r: Row) => if (r == null) null else r.get(f.index), f.typ) }.filter(_._3 == TString) transformations.foreach { case (irF, rowF, subT) => val aggSig = PhysicalAggSig(Take(), TakeStateSig(VirtualTypeWithReq(PType.canonical(subT)))) - val seqOpArgs = Array.tabulate(rows.length)(i => FastSeq[IR](irF(ArrayRef(Ref("rows", TArray(t)), i)))) + val seqOpArgs = + Array.tabulate(rows.length)(i => FastSeq[IR](irF(ArrayRef(Ref("rows", TArray(t)), i)))) val expected = rows.take(10).map(rowF) - assertAggEquals(aggSig, + assertAggEquals( + aggSig, FastSeq(I32(10)), seqOpArgs, expected = expected, - args = FastSeq(("rows", (TArray(t), rows)))) + args = FastSeq(("rows", (TArray(t), rows))), + ) } } @@ -416,58 +556,101 @@ class Aggregators2Suite extends HailSuite { Begin(FastSeq( SeqOp(aggIdx, FastSeq(ArrayLen(a)), alstate), - StreamFor(StreamRange(0, ArrayLen(a), 1), idx.name, + StreamFor( + StreamRange(0, ArrayLen(a), 1), + idx.name, bindIR(ArrayRef(a, idx)) { elt => SeqOp(aggIdx, FastSeq(idx, seqOps(elt)), AggElementsAggSig(alstate.nested)) - } - ) + }, + ), )) } - @Test def testMin() { + @Test def testMin(): Unit = { val aggSig = PhysicalAggSig(Min(), TypedStateSig(VirtualTypeWithReq(PInt64(false)))) - val seqOpArgs = Array.tabulate(rows.length)(i => FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "b"))) + val seqOpArgs = 
Array.tabulate(rows.length)(i => + FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "b")) + ) val seqOpArgsNA = Array.tabulate(8)(i => FastSeq[IR](NA(TInt64))) - assertAggEquals(aggSig, FastSeq(), seqOpArgs, expected = -2L, args = FastSeq(("rows", (arrayType, rows)))) - assertAggEquals(aggSig, FastSeq(), seqOpArgsNA, expected = null, args = FastSeq(("rows", (arrayType, rows)))) + assertAggEquals( + aggSig, + FastSeq(), + seqOpArgs, + expected = -2L, + args = FastSeq(("rows", (arrayType, rows))), + ) + assertAggEquals( + aggSig, + FastSeq(), + seqOpArgsNA, + expected = null, + args = FastSeq(("rows", (arrayType, rows))), + ) } - @Test def testMax() { + @Test def testMax(): Unit = { val aggSig = PhysicalAggSig(Max(), TypedStateSig(VirtualTypeWithReq(PInt64(false)))) - val seqOpArgs = Array.tabulate(rows.length)(i => FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "b"))) + val seqOpArgs = Array.tabulate(rows.length)(i => + FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "b")) + ) val seqOpArgsNA = Array.tabulate(8)(i => FastSeq[IR](NA(TInt64))) - assertAggEquals(aggSig, FastSeq(), seqOpArgs, expected = 7L, args = FastSeq(("rows", (arrayType, rows)))) - assertAggEquals(aggSig, FastSeq(), seqOpArgsNA, expected = null, args = FastSeq(("rows", (arrayType, rows)))) + assertAggEquals( + aggSig, + FastSeq(), + seqOpArgs, + expected = 7L, + args = FastSeq(("rows", (arrayType, rows))), + ) + assertAggEquals( + aggSig, + FastSeq(), + seqOpArgsNA, + expected = null, + args = FastSeq(("rows", (arrayType, rows))), + ) } - @Test def testCollectLongs() { - val seqOpArgs = Array.tabulate(rows.length)(i => FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "b"))) - assertAggEquals(collectAggSig(TInt64), FastSeq(), seqOpArgs, + @Test def testCollectLongs(): Unit = { + val seqOpArgs = Array.tabulate(rows.length)(i => + FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "b")) + ) + assertAggEquals( + collectAggSig(TInt64), + FastSeq(), + seqOpArgs, expected = FastSeq(5L, null, -2L, 7L, null, null), - args = FastSeq(("rows", (arrayType, rows))) + args = FastSeq(("rows", (arrayType, rows))), ) } - @Test def testCollectStrs() { - val seqOpArgs = Array.tabulate(rows.length)(i => FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "a"))) + @Test def testCollectStrs(): Unit = { + val seqOpArgs = Array.tabulate(rows.length)(i => + FastSeq[IR](GetField(ArrayRef(Ref("rows", arrayType), i), "a")) + ) - assertAggEquals(collectAggSig(TString), FastSeq(), seqOpArgs, + assertAggEquals( + collectAggSig(TString), + FastSeq(), + seqOpArgs, expected = FastSeq("abcd", null, null, "abcd", null, "foo"), - args = FastSeq(("rows", (arrayType, rows))) + args = FastSeq(("rows", (arrayType, rows))), ) } - @Test def testCollectBig() { + @Test def testCollectBig(): Unit = { val seqOpArgs = Array.tabulate(100)(i => FastSeq(I64(i))) - assertAggEquals(collectAggSig(TInt64), FastSeq(), seqOpArgs, + assertAggEquals( + collectAggSig(TInt64), + FastSeq(), + seqOpArgs, expected = Array.tabulate(100)(i => i.toLong).toIndexedSeq, - args = FastSeq(("rows", (arrayType, rows))) + args = FastSeq(("rows", (arrayType, rows))), ) } - @Test def testArrayElementsAgg() { + @Test def testArrayElementsAgg(): Unit = { val alState = ArrayLenAggSig(knownLength = false, FastSeq(pnnAggSig, countAggSig, sumAggSig)) val value = FastSeq( @@ -476,163 +659,274 @@ class Aggregators2Suite extends HailSuite { FastSeq(Row("a", 2L), Row("b", 2L), null, Row("f", 2L)), FastSeq(Row("a", 3L), Row("b", 3L), Row("c", 3L), Row("f", 
3L)), FastSeq(Row("a", 4L), Row("b", 4L), Row("c", 4L), null), - FastSeq(null, null, null, Row("f", 5L))) + FastSeq(null, null, null, Row("f", 5L)), + ) val expected = FastSeq( Row(Row("a", 4L), 6L, 10L), Row(Row("b", 4L), 6L, 9L), Row(Row("c", 4L), 6L, 8L), - Row(Row("f", 5L), 6L, 10L)) + Row(Row("f", 5L), 6L, 10L), + ) - val init = InitOp(0, FastSeq(Begin(FastSeq[IR]( - InitOp(0, FastSeq(), pnnAggSig), - InitOp(1, FastSeq(), countAggSig), - InitOp(2, FastSeq(), sumAggSig) - ))), alState) + val init = InitOp( + 0, + FastSeq(Begin(FastSeq[IR]( + InitOp(0, FastSeq(), pnnAggSig), + InitOp(1, FastSeq(), countAggSig), + InitOp(2, FastSeq(), sumAggSig), + ))), + alState, + ) val stream = Ref("stream", TArray(arrayType)) val seq = Array.tabulate(value.length) { i => - seqOpOverArray(0, ArrayRef(stream, i), { elt => - Begin(FastSeq( - SeqOp(0, FastSeq(elt), pnnAggSig), - SeqOp(1, FastSeq(), countAggSig), - SeqOp(2, FastSeq(GetField(elt, "b")), sumAggSig))) - }, alState) + seqOpOverArray( + 0, + ArrayRef(stream, i), + elt => + Begin(FastSeq( + SeqOp(0, FastSeq(elt), pnnAggSig), + SeqOp(1, FastSeq(), countAggSig), + SeqOp(2, FastSeq(GetField(elt, "b")), sumAggSig), + )), + alState, + ) } - assertAggEqualsProcessed(alState, init, seq, expected, FastSeq(("stream", (stream.typ, value))), 2, None) + assertAggEqualsProcessed( + alState, + init, + seq, + expected, + FastSeq(("stream", (stream.typ, value))), + 2, + None, + ) } - @Test def testNestedArrayElementsAgg() { + @Test def testNestedArrayElementsAgg(): Unit = { val alstate1 = ArrayLenAggSig(knownLength = false, FastSeq(sumAggSig)) - val aestate1 = AggElementsAggSig(FastSeq(sumAggSig)) val alstate2 = ArrayLenAggSig(knownLength = false, FastSeq[PhysicalAggSig](alstate1)) - val init = InitOp(0, FastSeq(Begin(FastSeq[IR]( - InitOp(0, FastSeq(Begin(FastSeq[IR]( - InitOp(0, FastSeq(), sumAggSig) - ))), alstate1) - ))), alstate2) + val init = InitOp( + 0, + FastSeq(Begin(FastSeq[IR]( + InitOp( + 0, + FastSeq(Begin(FastSeq[IR]( + InitOp(0, FastSeq(), sumAggSig) + ))), + alstate1, + ) + ))), + alstate2, + ) val stream = Ref("stream", TArray(TArray(TArray(TInt64)))) val seq = Array.tabulate(10) { i => - seqOpOverArray(0, ArrayRef(stream, i), { array1 => - seqOpOverArray(0, array1, { elt => - SeqOp(0, FastSeq(elt), sumAggSig) - }, alstate1) - }, alstate2) + seqOpOverArray( + 0, + ArrayRef(stream, i), + array1 => seqOpOverArray(0, array1, elt => SeqOp(0, FastSeq(elt), sumAggSig), alstate1), + alstate2, + ) } val expected = FastSeq(Row(FastSeq(Row(45L)))) val args = Array.tabulate(10)(i => FastSeq(FastSeq(i.toLong))).toFastSeq - assertAggEqualsProcessed(alstate2, init, seq, expected, FastSeq(("stream", (stream.typ, args))), 2, None) + assertAggEqualsProcessed( + alstate2, + init, + seq, + expected, + FastSeq(("stream", (stream.typ, args))), + 2, + None, + ) } - @Test def testArrayElementsAggTake() { + @Test def testArrayElementsAggTake(): Unit = { val value = FastSeq( FastSeq(Row("a", 0L), Row("b", 0L), Row("c", 0L), Row("f", 0L)), FastSeq(Row("a", 1L), null, Row("c", 1L), null), FastSeq(Row("a", 2L), Row("b", 2L), null, Row("f", 2L)), FastSeq(Row("a", 3L), Row("b", 3L), Row("c", 3L), Row("f", 3L)), FastSeq(Row("a", 4L), Row("b", 4L), Row("c", 4L), null), - FastSeq(null, null, null, Row("f", 5L))) + FastSeq(null, null, null, Row("f", 5L)), + ) val take = PhysicalAggSig(Take(), TakeStateSig(VirtualTypeWithReq(PType.canonical(t)))) val alstate = ArrayLenAggSig(knownLength = false, FastSeq(take)) - val init = InitOp(0, FastSeq(Begin(FastSeq[IR]( - InitOp(0, 
FastSeq(I32(3)), take) - ))), alstate) + val init = InitOp( + 0, + FastSeq(Begin(FastSeq[IR]( + InitOp(0, FastSeq(I32(3)), take) + ))), + alstate, + ) val stream = Ref("stream", TArray(arrayType)) val seq = Array.tabulate(value.length) { i => - seqOpOverArray(0, ArrayRef(stream, i), { elt => - SeqOp(0, FastSeq(elt), take) - }, alstate) + seqOpOverArray(0, ArrayRef(stream, i), elt => SeqOp(0, FastSeq(elt), take), alstate) } - val expected = Array.tabulate(value(0).length)(i => Row(Array.tabulate(3)(j => value(j)(i)).toFastSeq)).toFastSeq - assertAggEqualsProcessed(alstate, init, seq, expected, FastSeq(("stream", (stream.typ, value))), 2, None) + val expected = Array.tabulate(value(0).length)(i => + Row(Array.tabulate(3)(j => value(j)(i)).toFastSeq) + ).toFastSeq + assertAggEqualsProcessed( + alstate, + init, + seq, + expected, + FastSeq(("stream", (stream.typ, value))), + 2, + None, + ) } - @Test def testGroup() { - val group = GroupedAggSig(VirtualTypeWithReq(PCanonicalString()), FastSeq(pnnAggSig, countAggSig, sumAggSig)) + @Test def testGroup(): Unit = { + val group = GroupedAggSig( + VirtualTypeWithReq(PCanonicalString()), + FastSeq(pnnAggSig, countAggSig, sumAggSig), + ) val initOpArgs = FastSeq(Begin(FastSeq( InitOp(0, FastSeq(), pnnAggSig), InitOp(1, FastSeq(), countAggSig), - InitOp(2, FastSeq(), sumAggSig)))) + InitOp(2, FastSeq(), sumAggSig), + ))) - val rows = FastSeq(Row("abcd", 5L), null, Row(null, -2L), Row("abcd", 7L), null, Row("foo", null)) + val rows = + FastSeq(Row("abcd", 5L), null, Row(null, -2L), Row("abcd", 7L), null, Row("foo", null)) val rref = Ref("rows", TArray(t)) val seqOpArgs = Array.tabulate(rows.length)(i => - FastSeq[IR](GetField(ArrayRef(rref, i), "a"), + FastSeq[IR]( + GetField(ArrayRef(rref, i), "a"), Begin(FastSeq( SeqOp(0, FastSeq(ArrayRef(rref, i)), pnnAggSig), SeqOp(1, FastSeq(), countAggSig), - SeqOp(2, FastSeq(GetField(ArrayRef(rref, i), "b")), sumAggSig))))) + SeqOp(2, FastSeq(GetField(ArrayRef(rref, i), "b")), sumAggSig), + )), + ) + ) val expected = Map( "abcd" -> Row(Row("abcd", 7L), 2L, 12L), "foo" -> Row(Row("foo", null), 1L, 0L), - (null, Row(Row(null, -2L), 3L, -2L))) + (null, Row(Row(null, -2L), 3L, -2L)), + ) - assertAggEquals(group, initOpArgs, seqOpArgs, expected = expected, args = FastSeq(("rows", (arrayType, rows)))) + assertAggEquals( + group, + initOpArgs, + seqOpArgs, + expected = expected, + args = FastSeq(("rows", (arrayType, rows))), + ) } - @Test def testNestedGroup() { + @Test def testNestedGroup(): Unit = { - val group1 = GroupedAggSig(VirtualTypeWithReq(PCanonicalString()), FastSeq(pnnAggSig, countAggSig, sumAggSig)) - val group2 = GroupedAggSig(VirtualTypeWithReq(PCanonicalString()), FastSeq[PhysicalAggSig](group1)) + val group1 = GroupedAggSig( + VirtualTypeWithReq(PCanonicalString()), + FastSeq(pnnAggSig, countAggSig, sumAggSig), + ) + val group2 = + GroupedAggSig(VirtualTypeWithReq(PCanonicalString()), FastSeq[PhysicalAggSig](group1)) val initOpArgs = FastSeq( - InitOp(0, FastSeq( - Begin(FastSeq( - InitOp(0, FastSeq(), pnnAggSig), - InitOp(1, FastSeq(), countAggSig), - InitOp(2, FastSeq(), sumAggSig))) - ), group1)) + InitOp( + 0, + FastSeq( + Begin(FastSeq( + InitOp(0, FastSeq(), pnnAggSig), + InitOp(1, FastSeq(), countAggSig), + InitOp(2, FastSeq(), sumAggSig), + )) + ), + group1, + ) + ) - val rows = FastSeq(Row("abcd", 5L), null, Row(null, -2L), Row("abcd", 7L), null, Row("foo", null)) + val rows = + FastSeq(Row("abcd", 5L), null, Row(null, -2L), Row("abcd", 7L), null, Row("foo", null)) val rref = Ref("rows", 
TArray(t)) val seqOpArgs = Array.tabulate(rows.length)(i => - FastSeq[IR](GetField(ArrayRef(rref, i), "a"), - SeqOp(0, FastSeq[IR](GetField(ArrayRef(rref, i), "a"), - Begin(FastSeq( - SeqOp(0, FastSeq(ArrayRef(rref, i)), pnnAggSig), - SeqOp(1, FastSeq(), countAggSig), - SeqOp(2, FastSeq(GetField(ArrayRef(rref, i), "b")), sumAggSig))) - ), group1))) + FastSeq[IR]( + GetField(ArrayRef(rref, i), "a"), + SeqOp( + 0, + FastSeq[IR]( + GetField(ArrayRef(rref, i), "a"), + Begin(FastSeq( + SeqOp(0, FastSeq(ArrayRef(rref, i)), pnnAggSig), + SeqOp(1, FastSeq(), countAggSig), + SeqOp(2, FastSeq(GetField(ArrayRef(rref, i), "b")), sumAggSig), + )), + ), + group1, + ), + ) + ) val expected = Map( "abcd" -> Row(Map("abcd" -> Row(Row("abcd", 7L), 2L, 12L))), "foo" -> Row(Map("foo" -> Row(Row("foo", null), 1L, 0L))), - (null, Row(Map((null, Row(Row(null, -2L), 3L, -2L)))))) + (null, Row(Map((null, Row(Row(null, -2L), 3L, -2L))))), + ) - assertAggEquals(group2, initOpArgs, seqOpArgs, expected = expected, args = FastSeq(("rows", (arrayType, rows)))) + assertAggEquals( + group2, + initOpArgs, + seqOpArgs, + expected = expected, + args = FastSeq(("rows", (arrayType, rows))), + ) } - @Test def testCollectAsSet() { - val rows = FastSeq(Row("abcd", 5L), null, Row(null, -2L), Row("abcd", 7L), null, Row("foo", null)) + @Test def testCollectAsSet(): Unit = { + val rows = + FastSeq(Row("abcd", 5L), null, Row(null, -2L), Row("abcd", 7L), null, Row("foo", null)) val rref = Ref("rows", TArray(t)) val elts = Array.tabulate(rows.length)(i => FastSeq(GetField(ArrayRef(rref, i), "a"))) val eltsPrimitive = Array.tabulate(rows.length)(i => FastSeq(GetField(ArrayRef(rref, i), "b"))) val expected = Set("abcd", "foo", null) - val expectedPrimitive = Set(5L, -2L, 7L, null) - - val aggsig = PhysicalAggSig(CollectAsSet(), CollectAsSetStateSig(VirtualTypeWithReq(PCanonicalString()))) - val aggsigPrimitive = PhysicalAggSig(CollectAsSet(), CollectAsSetStateSig(VirtualTypeWithReq(PInt64()))) - assertAggEquals(aggsig, FastSeq(), elts, expected = expected, args = FastSeq(("rows", (arrayType, rows))), expectedInit = Some(Set())) - assertAggEquals(aggsigPrimitive, FastSeq(), eltsPrimitive, expected = expectedPrimitive, args = FastSeq(("rows", (arrayType, rows))), expectedInit = Some(Set())) + val expectedPrimitive: Set[Any] = Set(5L, -2L, 7L, null) + + val aggsig = + PhysicalAggSig(CollectAsSet(), CollectAsSetStateSig(VirtualTypeWithReq(PCanonicalString()))) + val aggsigPrimitive = + PhysicalAggSig(CollectAsSet(), CollectAsSetStateSig(VirtualTypeWithReq(PInt64()))) + assertAggEquals( + aggsig, + FastSeq(), + elts, + expected = expected, + args = FastSeq(("rows", (arrayType, rows))), + expectedInit = Some(Set()), + ) + assertAggEquals( + aggsigPrimitive, + FastSeq(), + eltsPrimitive, + expected = expectedPrimitive, + args = FastSeq(("rows", (arrayType, rows))), + expectedInit = Some(Set()), + ) } - @Test def testDownsample() { - val aggSig = PhysicalAggSig(Downsample(), DownsampleStateSig(VirtualTypeWithReq(PCanonicalArray(PCanonicalString())))) + @Test def testDownsample(): Unit = { + val aggSig = PhysicalAggSig( + Downsample(), + DownsampleStateSig(VirtualTypeWithReq(PCanonicalArray(PCanonicalString()))), + ) val rows = FastSeq( Row(-1.23, 1.23, null), Row(-10d, 10d, FastSeq("foo")), @@ -650,64 +944,96 @@ class Aggregators2Suite extends HailSuite { Row(4d, 4.4d, null), Row(3d, 3.3d, null), Row(3d, 3.3d, null), - Row(3d, 3.3d, null) + Row(3d, 3.3d, null), ) val arrayType = TArray(TStruct("x" -> TFloat64, "y" -> TFloat64, "label" -> 
TArray(TString))) - val seqOpArgs = Array.tabulate(rows.length)(i => FastSeq[IR]( - GetField(ArrayRef(Ref("rows", arrayType), i), "x"), - GetField(ArrayRef(Ref("rows", arrayType), i), "y"), - GetField(ArrayRef(Ref("rows", arrayType), i), "label") - )) + val seqOpArgs = Array.tabulate(rows.length)(i => + FastSeq[IR]( + GetField(ArrayRef(Ref("rows", arrayType), i), "x"), + GetField(ArrayRef(Ref("rows", arrayType), i), "y"), + GetField(ArrayRef(Ref("rows", arrayType), i), "label"), + ) + ) - assertAggEquals(aggSig, + assertAggEquals( + aggSig, FastSeq(I32(500)), Array.fill[IndexedSeq[IR]](20)(FastSeq(NA(TFloat64), NA(TFloat64), NA(TArray(TString)))), expected = FastSeq(), - args = FastSeq(("rows", (arrayType, rows)))) + args = FastSeq(("rows", (arrayType, rows))), + ) val expected = rows.toSet - assertAggEquals(aggSig, + assertAggEquals( + aggSig, FastSeq(I32(100)), seqOpArgs, expected = expected, args = FastSeq(("rows", (arrayType, rows))), - transformResult = Some(_.asInstanceOf[IndexedSeq[_]].toSet)) + transformResult = Some(_.asInstanceOf[IndexedSeq[_]].toSet), + ) } @Test def testLoweringMatrixMapColsWithAggFilterAndLets(): Unit = { - val t = MatrixType(TStruct.empty, FastSeq("col_idx"), TStruct("col_idx" -> TInt32), FastSeq("row_idx"), TStruct("row_idx" -> TInt32), TStruct.empty) + val t = MatrixType( + TStruct.empty, + FastSeq("col_idx"), + TStruct("col_idx" -> TInt32), + FastSeq("row_idx"), + TStruct("row_idx" -> TInt32), + TStruct.empty, + ) val ir = TableCollect(MatrixColsTable(MatrixMapCols( MatrixRead(t, false, false, MatrixRangeReader(10, 10, None)), - InsertFields(Ref("sa", t.colType), FastSeq(("foo", - Let(FastSeq("bar" -> (GetField(Ref("sa", t.colType), "col_idx") + I32(1))), - AggFilter( - GetField(Ref("va", t.rowType), "row_idx") < I32(5), - Ref("bar", TInt32).toL + Ref("bar", TInt32).toL + ApplyAggOp( - FastSeq(), - FastSeq(GetField(Ref("va", t.rowType), "row_idx").toL), - AggSignature(Sum(), FastSeq(), FastSeq(TInt64))), - false))))), - Some(FastSeq())))) - assertEvalsTo(ir, Row((0 until 10).map(i => Row(i, 2L * i + 12L)), Row()))(ExecStrategy.interpretOnly) + InsertFields( + Ref("sa", t.colType), + FastSeq(( + "foo", + Let( + FastSeq("bar" -> (GetField(Ref("sa", t.colType), "col_idx") + I32(1))), + AggFilter( + GetField(Ref("va", t.rowType), "row_idx") < I32(5), + Ref("bar", TInt32).toL + Ref("bar", TInt32).toL + ApplyAggOp( + FastSeq(), + FastSeq(GetField(Ref("va", t.rowType), "row_idx").toL), + AggSignature(Sum(), FastSeq(), FastSeq(TInt64)), + ), + false, + ), + ), + )), + ), + Some(FastSeq()), + ))) + assertEvalsTo(ir, Row((0 until 10).map(i => Row(i, 2L * i + 12L)), Row()))( + ExecStrategy.interpretOnly + ) } @Test def testRunAggScan(): Unit = { implicit val execStrats = ExecStrategy.compileOnly - val sig = PhysicalAggSig(Sum(), TypedStateSig(VirtualTypeWithReq.fullyOptional(TFloat64).setRequired(true))) + val sig = PhysicalAggSig( + Sum(), + TypedStateSig(VirtualTypeWithReq.fullyOptional(TFloat64).setRequired(true)), + ) val x = ToArray(RunAggScan( StreamRange(I32(0), I32(5), I32(1)), "foo", InitOp(0, FastSeq(), sig), SeqOp(0, FastSeq(Ref("foo", TInt32).toD), sig), ResultOp(0, sig), - Array(sig.state))) + Array(sig.state), + )) assertEvalsTo(x, FastSeq(0.0, 0.0, 1.0, 3.0, 6.0)) } @Test def testNestedRunAggScan(): Unit = { implicit val execStrats = ExecStrategy.compileOnly - val sig = PhysicalAggSig(Sum(), TypedStateSig(VirtualTypeWithReq.fullyOptional(TFloat64).setRequired(true))) + val sig = PhysicalAggSig( + Sum(), + 
TypedStateSig(VirtualTypeWithReq.fullyOptional(TFloat64).setRequired(true)), + ) val x = ToArray( StreamFlatMap( @@ -719,11 +1045,17 @@ class Aggregators2Suite extends HailSuite { InitOp(0, FastSeq(), sig), SeqOp(0, FastSeq(Ref("foo", TInt32).toD), sig), ResultOp(0, sig), - Array(sig.state)))) - assertEvalsTo(x, FastSeq( - 0.0, 0.0, 1.0, - 0.0, 0.0, 1.0, 3.0, - 0.0, 0.0, 1.0, 3.0, 6.0)) + Array(sig.state), + ), + ) + ) + assertEvalsTo( + x, + FastSeq( + 0.0, 0.0, 1.0, + 0.0, 0.0, 1.0, 3.0, + 0.0, 0.0, 1.0, 3.0, 6.0), + ) } @Test def testRunAggBasic(): Unit = { @@ -733,9 +1065,11 @@ class Aggregators2Suite extends HailSuite { Begin(FastSeq( InitOp(0, FastSeq(), sig), SeqOp(0, FastSeq(F64(1.0)), sig), - SeqOp(0, FastSeq(F64(-5.0)), sig))), + SeqOp(0, FastSeq(F64(-5.0)), sig), + )), ResultOp.makeTuple(FastSeq(sig)), - FastSeq(sig.state)) + FastSeq(sig.state), + ) assertEvalsTo(x, Row(-4.0)) } @@ -749,19 +1083,26 @@ class Aggregators2Suite extends HailSuite { StreamFor( StreamRange(I32(0), I32(10), I32(1)), "foo", - SeqOp(0, FastSeq( - RunAgg( - Begin(FastSeq( - InitOp(0, FastSeq(), sumSig), - SeqOp(0, FastSeq(F64(-1.0)), sumSig), - SeqOp(0, FastSeq(Ref("foo", TInt32).toD), sumSig))), - ResultOp(0, sumSig), - FastSeq(sumSig.state)) - ), takeSig) - )) - ), + SeqOp( + 0, + FastSeq( + RunAgg( + Begin(FastSeq( + InitOp(0, FastSeq(), sumSig), + SeqOp(0, FastSeq(F64(-1.0)), sumSig), + SeqOp(0, FastSeq(Ref("foo", TInt32).toD), sumSig), + )), + ResultOp(0, sumSig), + FastSeq(sumSig.state), + ) + ), + takeSig, + ), + ), + )), ResultOp(0, takeSig), - FastSeq(takeSig.state)) + FastSeq(takeSig.state), + ) assertEvalsTo(x, FastSeq(-1d, 0d, 1d, 2d, 3d)) } @@ -773,23 +1114,25 @@ class Aggregators2Suite extends HailSuite { Begin(FastSeq( InitOp(0, FastSeq(I32(10)), takeSig), SeqOp(0, FastSeq(NA(TInt64)), takeSig), - SeqOp(0, FastSeq(I64(-1l)), takeSig), - SeqOp(0, FastSeq(I64(2l)), takeSig) - ) - ), + SeqOp(0, FastSeq(I64(-1L)), takeSig), + SeqOp(0, FastSeq(I64(2L)), takeSig), + )), AggStateValue(0, takeSig.state), - FastSeq(takeSig.state) + FastSeq(takeSig.state), )), RunAgg( Begin(FastSeq( InitOp(0, FastSeq(I32(10)), takeSig), CombOpValue(0, Ref("x", TBinary), takeSig), - SeqOp(0, FastSeq(I64(3l)), takeSig), + SeqOp(0, FastSeq(I64(3L)), takeSig), CombOpValue(0, Ref("x", TBinary), takeSig), - SeqOp(0, FastSeq(I64(0l)), takeSig))), + SeqOp(0, FastSeq(I64(0L)), takeSig), + )), ResultOp(0, takeSig), - FastSeq(takeSig.state))) + FastSeq(takeSig.state), + ), + ) - assertEvalsTo(x, FastSeq(null, -1l, 2l, 3l, null, null, -1l, 2l, 0l)) + assertEvalsTo(x, FastSeq(null, -1L, 2L, 3L, null, null, -1L, 2L, 0L)) } } diff --git a/hail/src/test/scala/is/hail/expr/ir/AggregatorsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/AggregatorsSuite.scala index ed6343ea29f..cc41b6d7df6 100644 --- a/hail/src/test/scala/is/hail/expr/ir/AggregatorsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/AggregatorsSuite.scala @@ -1,11 +1,12 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.expr.ir.DeprecatedIRBuilder._ import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR} import is.hail.types.virtual._ import is.hail.utils.{FastSeq, _} import is.hail.variant.Call2 -import is.hail.{ExecStrategy, HailSuite} + import org.apache.spark.sql.Row import org.testng.annotations.Test @@ -13,27 +14,50 @@ class AggregatorsSuite extends HailSuite { implicit val execStrats = ExecStrategy.compileOnly - def runAggregator(op: AggOp, aggType: TStruct, agg: IndexedSeq[Row], expected: Any, initOpArgs: 
IndexedSeq[IR], - seqOpArgs: IndexedSeq[IR]) { + def runAggregator( + op: AggOp, + aggType: TStruct, + agg: IndexedSeq[Row], + expected: Any, + initOpArgs: IndexedSeq[IR], + seqOpArgs: IndexedSeq[IR], + ): Unit = { val aggSig = AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ)) assertEvalsTo( ApplyAggOp(initOpArgs, seqOpArgs, aggSig), (agg, aggType), - expected) + expected, + ) } - def runAggregator(op: AggOp, t: Type, a: IndexedSeq[Any], expected: Any, - initOpArgs: IndexedSeq[IR] = FastSeq()) { - runAggregator(op, + def runAggregator( + op: AggOp, + t: Type, + a: IndexedSeq[Any], + expected: Any, + initOpArgs: IndexedSeq[IR] = FastSeq(), + ): Unit = { + runAggregator( + op, TStruct("x" -> t), a.map(i => Row(i)), expected, initOpArgs, - seqOpArgs = FastSeq(Ref("x", t))) + seqOpArgs = FastSeq(Ref("x", t)), + ) } - @Test def sumFloat64() { + @Test def nestedAgg(): Unit = { + val agg = ToArray(StreamMap(StreamRange(0, 10, 1), "elt", ApplyAggOp(Count())())) + assertEvalsTo( + agg, + (FastSeq(1, 2).map(i => Row(i)), TStruct("x" -> TInt32)), + IndexedSeq.fill(10)(2L), + ) + } + + @Test def sumFloat64(): Unit = { runAggregator(Sum(), TFloat64, (0 to 100).map(_.toDouble), 5050.0) runAggregator(Sum(), TFloat64, FastSeq(), 0.0) runAggregator(Sum(), TFloat64, FastSeq(42.0), 42.0) @@ -41,156 +65,210 @@ class AggregatorsSuite extends HailSuite { runAggregator(Sum(), TFloat64, FastSeq(null, null, null), 0.0) } - @Test def sumInt64() { + @Test def sumInt64(): Unit = runAggregator(Sum(), TInt64, FastSeq(-1L, 2L, 3L), 4L) - } - @Test def collectBoolean() { - runAggregator(Collect(), TBoolean, FastSeq(true, false, null, true, false), FastSeq(true, false, null, true, false)) + @Test def collectBoolean(): Unit = { + runAggregator( + Collect(), + TBoolean, + FastSeq(true, false, null, true, false), + FastSeq(true, false, null, true, false), + ) } - @Test def collectInt() { + @Test def collectInt(): Unit = runAggregator(Collect(), TInt32, FastSeq(10, null, 5), FastSeq(10, null, 5)) - } - @Test def collectLong() { + @Test def collectLong(): Unit = runAggregator(Collect(), TInt64, FastSeq(10L, null, 5L), FastSeq(10L, null, 5L)) - } - @Test def collectFloat() { + @Test def collectFloat(): Unit = runAggregator(Collect(), TFloat32, FastSeq(10f, null, 5f), FastSeq(10f, null, 5f)) - } - @Test def collectDouble() { + @Test def collectDouble(): Unit = runAggregator(Collect(), TFloat64, FastSeq(10d, null, 5d), FastSeq(10d, null, 5d)) - } - @Test def collectString() { + @Test def collectString(): Unit = runAggregator(Collect(), TString, FastSeq("hello", null, "foo"), FastSeq("hello", null, "foo")) - } - @Test def collectArray() { - runAggregator(Collect(), - TArray(TInt32), FastSeq(FastSeq(1, 2, 3), null, FastSeq()), FastSeq(FastSeq(1, 2, 3), null, FastSeq())) + @Test def collectArray(): Unit = { + runAggregator( + Collect(), + TArray(TInt32), + FastSeq(FastSeq(1, 2, 3), null, FastSeq()), + FastSeq(FastSeq(1, 2, 3), null, FastSeq()), + ) } - @Test def collectStruct() { - runAggregator(Collect(), + @Test def collectStruct(): Unit = { + runAggregator( + Collect(), TStruct("a" -> TInt32, "b" -> TBoolean), FastSeq(Row(5, true), Row(3, false), null, Row(0, false), null), - FastSeq(Row(5, true), Row(3, false), null, Row(0, false), null)) + FastSeq(Row(5, true), Row(3, false), null, Row(0, false), null), + ) } - @Test def count() { - runAggregator(Count(), + @Test def count(): Unit = { + runAggregator( + Count(), TStruct("x" -> TString), FastSeq(Row("hello"), Row("foo"), Row("a"), Row(null), Row("b"), Row(null), 
Row("c")), 7L, initOpArgs = FastSeq(), - seqOpArgs = FastSeq()) + seqOpArgs = FastSeq(), + ) } - @Test def collectAsSetBoolean() { - runAggregator(CollectAsSet(), TBoolean, FastSeq(true, false, null, true, false), Set(true, false, null)) + @Test def collectAsSetBoolean(): Unit = { + runAggregator( + CollectAsSet(), + TBoolean, + FastSeq(true, false, null, true, false), + Set(true, false, null), + ) runAggregator(CollectAsSet(), TBoolean, FastSeq(true, null, true), Set(true, null)) } - @Test def collectAsSetNumeric() { + @Test def collectAsSetNumeric(): Unit = { runAggregator(CollectAsSet(), TInt32, FastSeq(10, null, 5, 5, null), Set(10, null, 5)) runAggregator(CollectAsSet(), TInt64, FastSeq(10L, null, 5L, 5L, null), Set(10L, null, 5L)) runAggregator(CollectAsSet(), TFloat32, FastSeq(10f, null, 5f, 5f, null), Set(10f, null, 5f)) runAggregator(CollectAsSet(), TFloat64, FastSeq(10d, null, 5d, 5d, null), Set(10d, null, 5d)) } - @Test def collectAsSetString() { - runAggregator(CollectAsSet(), TString, FastSeq("hello", null, "foo", null, "foo"), Set("hello", null, "foo")) + @Test def collectAsSetString(): Unit = { + runAggregator( + CollectAsSet(), + TString, + FastSeq("hello", null, "foo", null, "foo"), + Set("hello", null, "foo"), + ) } - @Test def collectAsSetArray() { + @Test def collectAsSetArray(): Unit = { val inputCollection = FastSeq(FastSeq(1, 2, 3), null, FastSeq(), null, FastSeq(1, 2, 3)) val expected = Set(FastSeq(1, 2, 3), null, FastSeq()) runAggregator(CollectAsSet(), TArray(TInt32), inputCollection, expected) } - @Test def collectAsSetStruct(): Unit = { - runAggregator(CollectAsSet(), + @Test def collectAsSetStruct(): Unit = + runAggregator( + CollectAsSet(), TStruct("a" -> TInt32, "b" -> TBoolean), FastSeq(Row(5, true), Row(3, false), null, Row(0, false), null, Row(5, true)), - Set(Row(5, true), Row(3, false), null, Row(0, false))) - } + Set(Row(5, true), Row(3, false), null, Row(0, false)), + ) - @Test def callStats() { - runAggregator(CallStats(), TCall, + @Test def callStats(): Unit = { + runAggregator( + CallStats(), + TCall, FastSeq(Call2(0, 0), Call2(0, 1), null, Call2(0, 2)), Row(FastSeq(4, 1, 1), FastSeq(4.0 / 6.0, 1.0 / 6.0, 1.0 / 6.0), 6, FastSeq(1, 0, 0)), - initOpArgs = FastSeq(I32(3))) + initOpArgs = FastSeq(I32(3)), + ) } // FIXME Max Boolean not supported by old-style MaxAggregator - @Test def maxInt32() { + @Test def maxInt32(): Unit = { runAggregator(Max(), TInt32, FastSeq(), null) runAggregator(Max(), TInt32, FastSeq(null), null) runAggregator(Max(), TInt32, FastSeq(-2, null, 7), 7) } - @Test def maxInt64() { + @Test def maxInt64(): Unit = runAggregator(Max(), TInt64, FastSeq(-2L, null, 7L), 7L) - } - @Test def maxFloat32() { + @Test def maxFloat32(): Unit = runAggregator(Max(), TFloat32, FastSeq(-2.0f, null, 7.2f), 7.2f) - } - @Test def maxFloat64() { + @Test def maxFloat64(): Unit = runAggregator(Max(), TFloat64, FastSeq(-2.0, null, 7.2), 7.2) - } - @Test def takeInt32() { - runAggregator(Take(), TInt32, FastSeq(2, null, 7), FastSeq(2, null), - initOpArgs = FastSeq(I32(2))) + @Test def takeInt32(): Unit = { + runAggregator( + Take(), + TInt32, + FastSeq(2, null, 7), + FastSeq(2, null), + initOpArgs = FastSeq(I32(2)), + ) } - @Test def takeInt64() { - runAggregator(Take(), TInt64, FastSeq(2L, null, 7L), FastSeq(2L, null), - initOpArgs = FastSeq(I32(2))) + @Test def takeInt64(): Unit = { + runAggregator( + Take(), + TInt64, + FastSeq(2L, null, 7L), + FastSeq(2L, null), + initOpArgs = FastSeq(I32(2)), + ) } - @Test def takeFloat32() { - runAggregator(Take(), 
TFloat32, FastSeq(2.0f, null, 7.2f), FastSeq(2.0f, null), - initOpArgs = FastSeq(I32(2))) + @Test def takeFloat32(): Unit = { + runAggregator( + Take(), + TFloat32, + FastSeq(2.0f, null, 7.2f), + FastSeq(2.0f, null), + initOpArgs = FastSeq(I32(2)), + ) } - @Test def takeFloat64() { - runAggregator(Take(), TFloat64, FastSeq(2.0, null, 7.2), FastSeq(2.0, null), - initOpArgs = FastSeq(I32(2))) + @Test def takeFloat64(): Unit = { + runAggregator( + Take(), + TFloat64, + FastSeq(2.0, null, 7.2), + FastSeq(2.0, null), + initOpArgs = FastSeq(I32(2)), + ) } - @Test def takeCall() { - runAggregator(Take(), TCall, FastSeq(Call2(0, 0), null, Call2(1, 0)), FastSeq(Call2(0, 0), null), - initOpArgs = FastSeq(I32(2))) + @Test def takeCall(): Unit = { + runAggregator( + Take(), + TCall, + FastSeq(Call2(0, 0), null, Call2(1, 0)), + FastSeq(Call2(0, 0), null), + initOpArgs = FastSeq(I32(2)), + ) } - @Test def takeString() { - runAggregator(Take(), TString, FastSeq("a", null, "b"), FastSeq("a", null), - initOpArgs = FastSeq(I32(2))) + @Test def takeString(): Unit = { + runAggregator( + Take(), + TString, + FastSeq("a", null, "b"), + FastSeq("a", null), + initOpArgs = FastSeq(I32(2)), + ) } @Test - def sumMultivar() { + def sumMultivar(): Unit = { val aggSig = AggSignature(Sum(), FastSeq(), FastSeq(TFloat64)) - assertEvalsTo(ApplyAggOp( - FastSeq(), - FastSeq(ApplyBinaryPrimOp(Multiply(), Ref("a", TFloat64), Ref("b", TFloat64))), - aggSig), - (FastSeq(Row(1.0, 10.0), Row(10.0, 10.0), Row(null, 10.0)), TStruct("a" -> TFloat64, "b" -> TFloat64)), - 110.0) + assertEvalsTo( + ApplyAggOp( + FastSeq(), + FastSeq(ApplyBinaryPrimOp(Multiply(), Ref("a", TFloat64), Ref("b", TFloat64))), + aggSig, + ), + ( + FastSeq(Row(1.0, 10.0), Row(10.0, 10.0), Row(null, 10.0)), + TStruct("a" -> TFloat64, "b" -> TFloat64), + ), + 110.0, + ) } private[this] def assertArraySumEvalsTo[T]( eltType: Type, a: IndexedSeq[Seq[T]], - expected: Seq[T] + expected: Seq[T], ): Unit = { val aggSig = AggSignature(Sum(), FastSeq(), FastSeq(eltType)) @@ -198,10 +276,17 @@ class AggregatorsSuite extends HailSuite { val structType = TStruct("foo" -> TArray(eltType)) assertEvalsTo( - AggArrayPerElement(Ref("foo", TArray(eltType)), "elt", "_", - ApplyAggOp(FastSeq(), FastSeq(Ref("elt", eltType)), aggSig), None, isScan = false), + AggArrayPerElement( + Ref("foo", TArray(eltType)), + "elt", + "_", + ApplyAggOp(FastSeq(), FastSeq(Ref("elt", eltType)), aggSig), + None, + isScan = false, + ), (aggregable, structType), - expected) + expected, + ) } @Test @@ -209,7 +294,7 @@ class AggregatorsSuite extends HailSuite { assertArraySumEvalsTo[Double]( TFloat64, FastSeq(), - null + null, ) @Test @@ -217,7 +302,7 @@ class AggregatorsSuite extends HailSuite { assertArraySumEvalsTo[Double]( TFloat64, FastSeq(null), - null + null, ) @Test @@ -225,7 +310,7 @@ class AggregatorsSuite extends HailSuite { assertArraySumEvalsTo[Double]( TFloat64, FastSeq(null, null, null), - null + null, ) @Test @@ -233,7 +318,7 @@ class AggregatorsSuite extends HailSuite { assertArraySumEvalsTo[Long]( TInt64, FastSeq(), - null + null, ) @Test @@ -241,7 +326,7 @@ class AggregatorsSuite extends HailSuite { assertArraySumEvalsTo[Long]( TInt64, FastSeq(null), - null + null, ) @Test @@ -249,7 +334,7 @@ class AggregatorsSuite extends HailSuite { assertArraySumEvalsTo[Long]( TInt64, FastSeq(null, null, null), - null + null, ) @Test @@ -259,8 +344,9 @@ class AggregatorsSuite extends HailSuite { FastSeq( FastSeq(1.0, 2.0), FastSeq(10.0, 20.0), - null), - FastSeq(11.0, 22.0) + null, + ), + 
FastSeq(11.0, 22.0), ) @Test @@ -270,8 +356,9 @@ class AggregatorsSuite extends HailSuite { FastSeq( FastSeq(1L, 2L), FastSeq(10L, 20L), - null), - FastSeq(11L, 22L) + null, + ), + FastSeq(11L, 22L), ) @Test @@ -281,244 +368,560 @@ class AggregatorsSuite extends HailSuite { FastSeq( null, FastSeq(1L, 33L), - FastSeq(42L, 3L)), - FastSeq(43L, 36L) + FastSeq(42L, 3L), + ), + FastSeq(43L, 36L), ) - private[this] def assertTakeByEvalsTo(aggType: Type, keyType: Type, n: Int, a: IndexedSeq[Row], expected: IndexedSeq[Any]) { - runAggregator(TakeBy(), TStruct("x" -> aggType, "y" -> keyType), + private[this] def assertTakeByEvalsTo( + aggType: Type, + keyType: Type, + n: Int, + a: IndexedSeq[Row], + expected: IndexedSeq[Any], + ): Unit = { + runAggregator( + TakeBy(), + TStruct("x" -> aggType, "y" -> keyType), a, expected, initOpArgs = FastSeq(I32(n)), - seqOpArgs = FastSeq(Ref("x", aggType), Ref("y", keyType))) + seqOpArgs = FastSeq(Ref("x", aggType), Ref("y", keyType)), + ) } - @Test def takeByNGreater() { - assertTakeByEvalsTo(TInt32, TInt32, 5, - FastSeq(Row(3, 4)), - FastSeq(3)) - } + @Test def takeByNGreater(): Unit = + assertTakeByEvalsTo(TInt32, TInt32, 5, FastSeq(Row(3, 4)), FastSeq(3)) - @Test def takeByBooleanBoolean() { - assertTakeByEvalsTo(TBoolean, TBoolean, 3, + @Test def takeByBooleanBoolean(): Unit = { + assertTakeByEvalsTo( + TBoolean, + TBoolean, + 3, FastSeq(Row(false, true), Row(null, null), Row(true, false)), - FastSeq(true, false, null)) + FastSeq(true, false, null), + ) } - @Test def takeByBooleanInt() { - assertTakeByEvalsTo(TBoolean, TInt32, 3, - FastSeq(Row(false, 0), Row(null, null), Row(true, 1), Row(false, 3), Row(true, null), Row(null, 2)), - FastSeq(false, true, null)) + @Test def takeByBooleanInt(): Unit = { + assertTakeByEvalsTo( + TBoolean, + TInt32, + 3, + FastSeq( + Row(false, 0), + Row(null, null), + Row(true, 1), + Row(false, 3), + Row(true, null), + Row(null, 2), + ), + FastSeq(false, true, null), + ) } - @Test def takeByBooleanLong() { - assertTakeByEvalsTo(TBoolean, TInt64, 3, - FastSeq(Row(false, 0L), Row(null, null), Row(true, 1L), Row(false, 3L), Row(true, null), Row(null, 2L)), - FastSeq(false, true, null)) + @Test def takeByBooleanLong(): Unit = { + assertTakeByEvalsTo( + TBoolean, + TInt64, + 3, + FastSeq( + Row(false, 0L), + Row(null, null), + Row(true, 1L), + Row(false, 3L), + Row(true, null), + Row(null, 2L), + ), + FastSeq(false, true, null), + ) } - @Test def takeByBooleanFloat() { - assertTakeByEvalsTo(TBoolean, TFloat32, 3, - FastSeq(Row(false, 0F), Row(null, null), Row(true, 1F), Row(false, 3F), Row(true, null), Row(null, 2F)), - FastSeq(false, true, null)) + @Test def takeByBooleanFloat(): Unit = { + assertTakeByEvalsTo( + TBoolean, + TFloat32, + 3, + FastSeq( + Row(false, 0f), + Row(null, null), + Row(true, 1f), + Row(false, 3f), + Row(true, null), + Row(null, 2f), + ), + FastSeq(false, true, null), + ) } - @Test def takeByBooleanDouble() { - assertTakeByEvalsTo(TBoolean, TFloat64, 3, - FastSeq(Row(false, 0D), Row(null, null), Row(true, 1D), Row(false, 3D), Row(true, null), Row(null, 2D)), - FastSeq(false, true, null)) + @Test def takeByBooleanDouble(): Unit = { + assertTakeByEvalsTo( + TBoolean, + TFloat64, + 3, + FastSeq( + Row(false, 0d), + Row(null, null), + Row(true, 1d), + Row(false, 3d), + Row(true, null), + Row(null, 2d), + ), + FastSeq(false, true, null), + ) } - @Test def takeByBooleanAnnotation() { - assertTakeByEvalsTo(TBoolean, TString, 3, - FastSeq(Row(false, "a"), Row(null, null), Row(true, "b"), Row(false, "d"), Row(true, 
null), Row(null, "c")), - FastSeq(false, true, null)) + @Test def takeByBooleanAnnotation(): Unit = { + assertTakeByEvalsTo( + TBoolean, + TString, + 3, + FastSeq( + Row(false, "a"), + Row(null, null), + Row(true, "b"), + Row(false, "d"), + Row(true, null), + Row(null, "c"), + ), + FastSeq(false, true, null), + ) } - @Test def takeByIntBoolean() { - assertTakeByEvalsTo(TInt32, TBoolean, 2, + @Test def takeByIntBoolean(): Unit = { + assertTakeByEvalsTo( + TInt32, + TBoolean, + 2, FastSeq(Row(3, true), Row(null, null), Row(null, false)), - FastSeq(null, 3)) + FastSeq(null, 3), + ) } - @Test def takeByIntInt() { - assertTakeByEvalsTo(TInt32, TInt32, 3, + @Test def takeByIntInt(): Unit = { + assertTakeByEvalsTo( + TInt32, + TInt32, + 3, FastSeq(Row(3, 4), Row(null, null), Row(null, 2), Row(11, 0), Row(45, 1), Row(3, null)), - FastSeq(11, 45, null)) + FastSeq(11, 45, null), + ) } - @Test def takeByIntLong() { - assertTakeByEvalsTo(TInt32, TInt64, 3, + @Test def takeByIntLong(): Unit = { + assertTakeByEvalsTo( + TInt32, + TInt64, + 3, FastSeq(Row(3, 4L), Row(null, null), Row(null, 2L), Row(11, 0L), Row(45, 1L), Row(3, null)), - FastSeq(11, 45, null)) + FastSeq(11, 45, null), + ) } - @Test def takeByIntFloat() { - assertTakeByEvalsTo(TInt32, TFloat32, 3, - FastSeq(Row(3, 4F), Row(null, null), Row(null, 2F), Row(11, 0F), Row(45, 1F), Row(3, null)), - FastSeq(11, 45, null)) + @Test def takeByIntFloat(): Unit = { + assertTakeByEvalsTo( + TInt32, + TFloat32, + 3, + FastSeq(Row(3, 4f), Row(null, null), Row(null, 2f), Row(11, 0f), Row(45, 1f), Row(3, null)), + FastSeq(11, 45, null), + ) } - @Test def takeByIntDouble() { - assertTakeByEvalsTo(TInt32, TFloat64, 3, - FastSeq(Row(3, 4D), Row(null, null), Row(null, 2D), Row(11, 0D), Row(45, 1D), Row(3, null)), - FastSeq(11, 45, null)) + @Test def takeByIntDouble(): Unit = { + assertTakeByEvalsTo( + TInt32, + TFloat64, + 3, + FastSeq(Row(3, 4d), Row(null, null), Row(null, 2d), Row(11, 0d), Row(45, 1d), Row(3, null)), + FastSeq(11, 45, null), + ) } - @Test def takeByIntAnnotation() { - assertTakeByEvalsTo(TInt32, TString, 3, - FastSeq(Row(3, "d"), Row(null, null), Row(null, "c"), Row(11, "a"), Row(45, "b"), Row(3, null)), - FastSeq(11, 45, null)) + @Test def takeByIntAnnotation(): Unit = { + assertTakeByEvalsTo( + TInt32, + TString, + 3, + FastSeq( + Row(3, "d"), + Row(null, null), + Row(null, "c"), + Row(11, "a"), + Row(45, "b"), + Row(3, null), + ), + FastSeq(11, 45, null), + ) } - @Test def takeByLongBoolean() { - assertTakeByEvalsTo(TInt64, TBoolean, 2, + @Test def takeByLongBoolean(): Unit = { + assertTakeByEvalsTo( + TInt64, + TBoolean, + 2, FastSeq(Row(3L, true), Row(null, null), Row(null, false)), - FastSeq(null, 3L)) + FastSeq(null, 3L), + ) } - @Test def takeByLongInt() { - assertTakeByEvalsTo(TInt64, TInt32, 3, + @Test def takeByLongInt(): Unit = { + assertTakeByEvalsTo( + TInt64, + TInt32, + 3, FastSeq(Row(3L, 4), Row(null, null), Row(null, 2), Row(11L, 0), Row(45L, 1), Row(3L, null)), - FastSeq(11L, 45L, null)) + FastSeq(11L, 45L, null), + ) } - @Test def takeByLongLong() { - assertTakeByEvalsTo(TInt64, TInt64, 3, - FastSeq(Row(3L, 4L), Row(null, null), Row(null, 2L), Row(11L, 0L), Row(45L, 1L), Row(3L, null)), - FastSeq(11L, 45L, null)) + @Test def takeByLongLong(): Unit = { + assertTakeByEvalsTo( + TInt64, + TInt64, + 3, + FastSeq( + Row(3L, 4L), + Row(null, null), + Row(null, 2L), + Row(11L, 0L), + Row(45L, 1L), + Row(3L, null), + ), + FastSeq(11L, 45L, null), + ) } - @Test def takeByLongFloat() { - assertTakeByEvalsTo(TInt64, TFloat32, 3, - 
FastSeq(Row(3L, 4F), Row(null, null), Row(null, 2F), Row(11L, 0F), Row(45L, 1F), Row(3L, null)), - FastSeq(11L, 45L, null)) + @Test def takeByLongFloat(): Unit = { + assertTakeByEvalsTo( + TInt64, + TFloat32, + 3, + FastSeq( + Row(3L, 4f), + Row(null, null), + Row(null, 2f), + Row(11L, 0f), + Row(45L, 1f), + Row(3L, null), + ), + FastSeq(11L, 45L, null), + ) } - @Test def takeByLongDouble() { - assertTakeByEvalsTo(TInt64, TFloat64, 3, - FastSeq(Row(3L, 4D), Row(null, null), Row(null, 2D), Row(11L, 0D), Row(45L, 1D), Row(3L, null)), - FastSeq(11L, 45L, null)) + @Test def takeByLongDouble(): Unit = { + assertTakeByEvalsTo( + TInt64, + TFloat64, + 3, + FastSeq( + Row(3L, 4d), + Row(null, null), + Row(null, 2d), + Row(11L, 0d), + Row(45L, 1d), + Row(3L, null), + ), + FastSeq(11L, 45L, null), + ) } - @Test def takeByLongAnnotation() { - assertTakeByEvalsTo(TInt64, TString, 3, - FastSeq(Row(3L, "d"), Row(null, null), Row(null, "c"), Row(11L, "a"), Row(45L, "b"), Row(3L, null)), - FastSeq(11L, 45L, null)) + @Test def takeByLongAnnotation(): Unit = { + assertTakeByEvalsTo( + TInt64, + TString, + 3, + FastSeq( + Row(3L, "d"), + Row(null, null), + Row(null, "c"), + Row(11L, "a"), + Row(45L, "b"), + Row(3L, null), + ), + FastSeq(11L, 45L, null), + ) } - @Test def takeByFloatBoolean() { - assertTakeByEvalsTo(TFloat32, TBoolean, 2, - FastSeq(Row(3F, true), Row(null, null), Row(null, false)), - FastSeq(null, 3F)) + @Test def takeByFloatBoolean(): Unit = { + assertTakeByEvalsTo( + TFloat32, + TBoolean, + 2, + FastSeq(Row(3f, true), Row(null, null), Row(null, false)), + FastSeq(null, 3f), + ) } - @Test def takeByFloatInt() { - assertTakeByEvalsTo(TFloat32, TInt32, 3, - FastSeq(Row(3F, 4), Row(null, null), Row(null, 2), Row(11F, 0), Row(45F, 1), Row(3F, null)), - FastSeq(11F, 45F, null)) + @Test def takeByFloatInt(): Unit = { + assertTakeByEvalsTo( + TFloat32, + TInt32, + 3, + FastSeq(Row(3f, 4), Row(null, null), Row(null, 2), Row(11f, 0), Row(45f, 1), Row(3f, null)), + FastSeq(11f, 45f, null), + ) } - @Test def takeByFloatLong() { - assertTakeByEvalsTo(TFloat32, TInt64, 3, - FastSeq(Row(3F, 4L), Row(null, null), Row(null, 2L), Row(11F, 0L), Row(45F, 1L), Row(3F, null)), - FastSeq(11F, 45F, null)) + @Test def takeByFloatLong(): Unit = { + assertTakeByEvalsTo( + TFloat32, + TInt64, + 3, + FastSeq( + Row(3f, 4L), + Row(null, null), + Row(null, 2L), + Row(11f, 0L), + Row(45f, 1L), + Row(3f, null), + ), + FastSeq(11f, 45f, null), + ) } - @Test def takeByFloatFloat() { - assertTakeByEvalsTo(TFloat32, TFloat32, 3, - FastSeq(Row(3F, 4F), Row(null, null), Row(null, 2F), Row(11F, 0F), Row(45F, 1F), Row(3F, null)), - FastSeq(11F, 45F, null)) + @Test def takeByFloatFloat(): Unit = { + assertTakeByEvalsTo( + TFloat32, + TFloat32, + 3, + FastSeq( + Row(3f, 4f), + Row(null, null), + Row(null, 2f), + Row(11f, 0f), + Row(45f, 1f), + Row(3f, null), + ), + FastSeq(11f, 45f, null), + ) } - @Test def takeByFloatDouble() { - assertTakeByEvalsTo(TFloat32, TFloat64, 3, - FastSeq(Row(3F, 4D), Row(null, null), Row(null, 2D), Row(11F, 0D), Row(45F, 1D), Row(3F, null)), - FastSeq(11F, 45F, null)) + @Test def takeByFloatDouble(): Unit = { + assertTakeByEvalsTo( + TFloat32, + TFloat64, + 3, + FastSeq( + Row(3f, 4d), + Row(null, null), + Row(null, 2d), + Row(11f, 0d), + Row(45f, 1d), + Row(3f, null), + ), + FastSeq(11f, 45f, null), + ) } - @Test def takeByFloatAnnotation() { - assertTakeByEvalsTo(TFloat32, TString, 3, - FastSeq(Row(3F, "d"), Row(null, null), Row(null, "c"), Row(11F, "a"), Row(45F, "b"), Row(3F, null)), - FastSeq(11F, 
45F, null)) + @Test def takeByFloatAnnotation(): Unit = { + assertTakeByEvalsTo( + TFloat32, + TString, + 3, + FastSeq( + Row(3f, "d"), + Row(null, null), + Row(null, "c"), + Row(11f, "a"), + Row(45f, "b"), + Row(3f, null), + ), + FastSeq(11f, 45f, null), + ) } - @Test def takeByDoubleBoolean() { - assertTakeByEvalsTo(TFloat64, TBoolean, 2, - FastSeq(Row(3D, true), Row(null, null), Row(null, false)), - FastSeq(null, 3D)) + @Test def takeByDoubleBoolean(): Unit = { + assertTakeByEvalsTo( + TFloat64, + TBoolean, + 2, + FastSeq(Row(3d, true), Row(null, null), Row(null, false)), + FastSeq(null, 3d), + ) } - @Test def takeByDoubleInt() { - assertTakeByEvalsTo(TFloat64, TInt32, 3, - FastSeq(Row(3D, 4), Row(null, null), Row(null, 2), Row(11D, 0), Row(45D, 1), Row(3D, null)), - FastSeq(11D, 45D, null)) + @Test def takeByDoubleInt(): Unit = { + assertTakeByEvalsTo( + TFloat64, + TInt32, + 3, + FastSeq(Row(3d, 4), Row(null, null), Row(null, 2), Row(11d, 0), Row(45d, 1), Row(3d, null)), + FastSeq(11d, 45d, null), + ) } - @Test def takeByDoubleLong() { - assertTakeByEvalsTo(TFloat64, TInt64, 3, - FastSeq(Row(3D, 4L), Row(null, null), Row(null, 2L), Row(11D, 0L), Row(45D, 1L), Row(3D, null)), - FastSeq(11D, 45D, null)) + @Test def takeByDoubleLong(): Unit = { + assertTakeByEvalsTo( + TFloat64, + TInt64, + 3, + FastSeq( + Row(3d, 4L), + Row(null, null), + Row(null, 2L), + Row(11d, 0L), + Row(45d, 1L), + Row(3d, null), + ), + FastSeq(11d, 45d, null), + ) } - @Test def takeByDoubleFloat() { - assertTakeByEvalsTo(TFloat64, TFloat32, 3, - FastSeq(Row(3D, 4F), Row(null, null), Row(null, 2F), Row(11D, 0F), Row(45D, 1F), Row(3D, null)), - FastSeq(11D, 45D, null)) + @Test def takeByDoubleFloat(): Unit = { + assertTakeByEvalsTo( + TFloat64, + TFloat32, + 3, + FastSeq( + Row(3d, 4f), + Row(null, null), + Row(null, 2f), + Row(11d, 0f), + Row(45d, 1f), + Row(3d, null), + ), + FastSeq(11d, 45d, null), + ) } - @Test def takeByDoubleDouble() { - assertTakeByEvalsTo(TFloat64, TFloat64, 3, - FastSeq(Row(3D, 4D), Row(null, null), Row(null, 2D), Row(11D, 0D), Row(45D, 1D), Row(3D, null)), - FastSeq(11D, 45D, null)) + @Test def takeByDoubleDouble(): Unit = { + assertTakeByEvalsTo( + TFloat64, + TFloat64, + 3, + FastSeq( + Row(3d, 4d), + Row(null, null), + Row(null, 2d), + Row(11d, 0d), + Row(45d, 1d), + Row(3d, null), + ), + FastSeq(11d, 45d, null), + ) } - @Test def takeByDoubleAnnotation() { - assertTakeByEvalsTo(TFloat64, TString, 3, - FastSeq(Row(3D, "d"), Row(null, null), Row(null, "c"), Row(11D, "a"), Row(45D, "b"), Row(3D, null)), - FastSeq(11D, 45D, null)) + @Test def takeByDoubleAnnotation(): Unit = { + assertTakeByEvalsTo( + TFloat64, + TString, + 3, + FastSeq( + Row(3d, "d"), + Row(null, null), + Row(null, "c"), + Row(11d, "a"), + Row(45d, "b"), + Row(3d, null), + ), + FastSeq(11d, 45d, null), + ) } - @Test def takeByAnnotationBoolean() { - assertTakeByEvalsTo(TString, TBoolean, 2, + @Test def takeByAnnotationBoolean(): Unit = { + assertTakeByEvalsTo( + TString, + TBoolean, + 2, FastSeq(Row("hello", true), Row(null, null), Row(null, false)), - FastSeq(null, "hello")) + FastSeq(null, "hello"), + ) } - @Test def takeByAnnotationInt() { - assertTakeByEvalsTo(TString, TInt32, 3, + @Test def takeByAnnotationInt(): Unit = { + assertTakeByEvalsTo( + TString, + TInt32, + 3, FastSeq(Row("a", 4), Row(null, null), Row(null, 2), Row("b", 0), Row("c", 1), Row("d", null)), - FastSeq("b", "c", null)) + FastSeq("b", "c", null), + ) } - @Test def takeByAnnotationLong() { - assertTakeByEvalsTo(TString, TInt64, 3, - 
FastSeq(Row("a", 4L), Row(null, null), Row(null, 2L), Row("b", 0L), Row("c", 1L), Row("d", null)), - FastSeq("b", "c", null)) + @Test def takeByAnnotationLong(): Unit = { + assertTakeByEvalsTo( + TString, + TInt64, + 3, + FastSeq( + Row("a", 4L), + Row(null, null), + Row(null, 2L), + Row("b", 0L), + Row("c", 1L), + Row("d", null), + ), + FastSeq("b", "c", null), + ) } - @Test def takeByAnnotationFloat() { - assertTakeByEvalsTo(TString, TFloat32, 3, - FastSeq(Row("a", 4F), Row(null, null), Row(null, 2F), Row("b", 0F), Row("c", 1F), Row("d", null)), - FastSeq("b", "c", null)) + @Test def takeByAnnotationFloat(): Unit = { + assertTakeByEvalsTo( + TString, + TFloat32, + 3, + FastSeq( + Row("a", 4f), + Row(null, null), + Row(null, 2f), + Row("b", 0f), + Row("c", 1f), + Row("d", null), + ), + FastSeq("b", "c", null), + ) } - @Test def takeByAnnotationDouble() { - assertTakeByEvalsTo(TString, TFloat64, 3, - FastSeq(Row("a", 4D), Row(null, null), Row(null, 2D), Row("b", 0D), Row("c", 1D), Row("d", null)), - FastSeq("b", "c", null)) + @Test def takeByAnnotationDouble(): Unit = { + assertTakeByEvalsTo( + TString, + TFloat64, + 3, + FastSeq( + Row("a", 4d), + Row(null, null), + Row(null, 2d), + Row("b", 0d), + Row("c", 1d), + Row("d", null), + ), + FastSeq("b", "c", null), + ) } - @Test def takeByAnnotationAnnotation() { - assertTakeByEvalsTo(TString, TString, 3, - FastSeq(Row("a", "d"), Row(null, null), Row(null, "c"), Row("b", "a"), Row("c", "b"), Row("d", null)), - FastSeq("b", "c", null)) + @Test def takeByAnnotationAnnotation(): Unit = { + assertTakeByEvalsTo( + TString, + TString, + 3, + FastSeq( + Row("a", "d"), + Row(null, null), + Row(null, "c"), + Row("b", "a"), + Row("c", "b"), + Row("d", null), + ), + FastSeq("b", "c", null), + ) } - @Test def takeByCallLong() { - assertTakeByEvalsTo(TCall, TInt64, 3, - FastSeq(Row(Call2(0, 0), 4L), Row(null, null), Row(null, 2L), Row(Call2(0, 1), 0L), Row(Call2(1, 1), 1L), Row(Call2(0, 2), null)), - FastSeq(Call2(0, 1), Call2(1, 1), null)) + @Test def takeByCallLong(): Unit = { + assertTakeByEvalsTo( + TCall, + TInt64, + 3, + FastSeq( + Row(Call2(0, 0), 4L), + Row(null, null), + Row(null, 2L), + Row(Call2(0, 1), 0L), + Row(Call2(1, 1), 1L), + Row(Call2(0, 2), null), + ), + FastSeq(Call2(0, 1), Call2(1, 1), null), + ) } def runKeyedAggregator( @@ -528,183 +931,279 @@ class AggregatorsSuite extends HailSuite { agg: IndexedSeq[Row], expected: Any, initOpArgs: IndexedSeq[IR], - seqOpArgs: IndexedSeq[IR]) { + seqOpArgs: IndexedSeq[IR], + ): Unit = { assertEvalsTo( - AggGroupBy(key, + AggGroupBy( + key, ApplyAggOp( initOpArgs, seqOpArgs, - AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ))), - false), + AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ)), + ), + false, + ), (agg, aggType), - expected) + expected, + ) } @Test - def keyedCount() { - runKeyedAggregator(Count(), + def keyedCount(): Unit = { + runKeyedAggregator( + Count(), Ref("k", TInt32), TStruct("k" -> TInt32), FastSeq(Row(1), Row(2), Row(3), Row(1), Row(1), Row(null), Row(null)), Map(1 -> 3L, 2 -> 1L, 3 -> 1L, (null, 2L)), initOpArgs = FastSeq(), - seqOpArgs = FastSeq()) + seqOpArgs = FastSeq(), + ) - runKeyedAggregator(Count(), + runKeyedAggregator( + Count(), Ref("k", TBoolean), TStruct("k" -> TBoolean), FastSeq(Row(true), Row(true), Row(true), Row(false), Row(false), Row(null), Row(null)), Map(true -> 3L, false -> 2L, (null, 2L)), initOpArgs = FastSeq(), - seqOpArgs = FastSeq()) + seqOpArgs = FastSeq(), + ) // test struct as key - runKeyedAggregator(Count(), + 
runKeyedAggregator( + Count(), Ref("k", TStruct("a" -> TBoolean)), TStruct("k" -> TStruct("a" -> TBoolean)), - FastSeq(Row(Row(true)), Row(Row(true)), Row(Row(true)), Row(Row(false)), Row(Row(false)), Row(Row(null)), Row(Row(null))), + FastSeq( + Row(Row(true)), + Row(Row(true)), + Row(Row(true)), + Row(Row(false)), + Row(Row(false)), + Row(Row(null)), + Row(Row(null)), + ), Map(Row(true) -> 3L, Row(false) -> 2L, (Row(null), 2L)), initOpArgs = FastSeq(), - seqOpArgs = FastSeq()) + seqOpArgs = FastSeq(), + ) } @Test - def keyedCollect() { + def keyedCollect(): Unit = { runKeyedAggregator( Collect(), Ref("k", TBoolean), TStruct("k" -> TBoolean, "v" -> TInt32), - FastSeq(Row(true, 5), Row(true, 3), Row(true, null), Row(false, 0), Row(false, null), Row(null, null), Row(null, 2)), + FastSeq( + Row(true, 5), + Row(true, 3), + Row(true, null), + Row(false, 0), + Row(false, null), + Row(null, null), + Row(null, 2), + ), Map(true -> FastSeq(5, 3, null), false -> FastSeq(0, null), (null, FastSeq(null, 2))), FastSeq(), - FastSeq(Ref("v", TInt32))) + FastSeq(Ref("v", TInt32)), + ) } @Test - def keyedCallStats() { + def keyedCallStats(): Unit = { runKeyedAggregator( CallStats(), Ref("k", TBoolean), - TStruct("k" -> TBoolean, "v" ->TCall), - FastSeq(Row(true, null), Row(true, Call2(0, 1)), Row(true, Call2(0, 1)), - Row(false, null), Row(false, Call2(0, 0)), Row(false, Call2(1, 1))), - Map(true -> Row(FastSeq(2, 2), FastSeq(0.5, 0.5), 4, FastSeq(0, 0)), - false -> Row(FastSeq(2, 2), FastSeq(0.5, 0.5), 4, FastSeq(1, 1))), + TStruct("k" -> TBoolean, "v" -> TCall), + FastSeq( + Row(true, null), + Row(true, Call2(0, 1)), + Row(true, Call2(0, 1)), + Row(false, null), + Row(false, Call2(0, 0)), + Row(false, Call2(1, 1)), + ), + Map( + true -> Row(FastSeq(2, 2), FastSeq(0.5, 0.5), 4, FastSeq(0, 0)), + false -> Row(FastSeq(2, 2), FastSeq(0.5, 0.5), 4, FastSeq(1, 1)), + ), FastSeq(I32(2)), - FastSeq(Ref("v", TCall))) + FastSeq(Ref("v", TCall)), + ) } @Test - def keyedTakeBy() { - runKeyedAggregator(TakeBy(), + def keyedTakeBy(): Unit = { + runKeyedAggregator( + TakeBy(), Ref("k", TString), TStruct("k" -> TString, "x" -> TFloat64, "y" -> TInt32), - FastSeq(Row("case", 0.2, 5), Row("control", 0.4, 0), Row(null, 1.0, 3), Row("control", 0.0, 2), Row("case", 0.3, 6), Row("control", 0.5, 1)), - Map("case" -> FastSeq(0.2, 0.3), - "control" -> FastSeq(0.4, 0.5), - (null, FastSeq(1.0))), + FastSeq( + Row("case", 0.2, 5), + Row("control", 0.4, 0), + Row(null, 1.0, 3), + Row("control", 0.0, 2), + Row("case", 0.3, 6), + Row("control", 0.5, 1), + ), + Map("case" -> FastSeq(0.2, 0.3), "control" -> FastSeq(0.4, 0.5), (null, FastSeq(1.0))), FastSeq(I32(2)), - FastSeq(Ref("x", TFloat64), Ref("y", TInt32))) + FastSeq(Ref("x", TFloat64), Ref("y", TInt32)), + ) } @Test - def keyedKeyedCollect() { - val agg = FastSeq(Row("EUR", true, 1), Row("EUR", false, 2), Row("AFR", true, 3), Row("AFR", null, 4)) + def keyedKeyedCollect(): Unit = { + val agg = + FastSeq(Row("EUR", true, 1), Row("EUR", false, 2), Row("AFR", true, 3), Row("AFR", null, 4)) val aggType = TStruct("k1" -> TString, "k2" -> TBoolean, "x" -> TInt32) - val expected = Map("EUR" -> Map(true -> FastSeq(1), false -> FastSeq(2)), "AFR" -> Map(true -> FastSeq(3), (null, FastSeq(4)))) + val expected: Map[String, Map[Any, Seq[Int]]] = Map( + "EUR" -> Map(true -> FastSeq(1), false -> FastSeq(2)), + "AFR" -> Map(true -> FastSeq(3), (null, FastSeq(4))), + ) val aggSig = AggSignature(Collect(), FastSeq(), FastSeq(TInt32)) assertEvalsTo( - AggGroupBy(Ref("k1", TString), - 
AggGroupBy(Ref("k2", TBoolean), + AggGroupBy( + Ref("k1", TString), + AggGroupBy( + Ref("k2", TBoolean), ApplyAggOp( FastSeq(), FastSeq(Ref("x", TInt32)), - aggSig), - false), - false), + aggSig, + ), + false, + ), + false, + ), (agg, aggType), - expected + expected, ) } @Test - def keyedKeyedCallStats() { + def keyedKeyedCallStats(): Unit = { val agg = FastSeq( Row("EUR", "CASE", null), Row("EUR", "CONTROL", Call2(0, 1)), Row("AFR", "CASE", Call2(1, 1)), - Row("AFR", "CONTROL", null)) + Row("AFR", "CONTROL", null), + ) val aggType = TStruct("k1" -> TString, "k2" -> TString, "g" -> TCall) val expected = Map( "EUR" -> Map( "CONTROL" -> Row(FastSeq(1, 1), FastSeq(0.5, 0.5), 2, FastSeq(0, 0)), - "CASE" -> Row(FastSeq(0, 0), null, 0, FastSeq(0, 0))), + "CASE" -> Row(FastSeq(0, 0), null, 0, FastSeq(0, 0)), + ), "AFR" -> Map( "CASE" -> Row(FastSeq(0, 2), FastSeq(0.0, 1.0), 2, FastSeq(0, 1)), - "CONTROL" -> Row(FastSeq(0, 0), null, 0, FastSeq(0, 0)))) + "CONTROL" -> Row(FastSeq(0, 0), null, 0, FastSeq(0, 0)), + ), + ) val aggSig = AggSignature(CallStats(), FastSeq(TInt32), FastSeq(TCall)) assertEvalsTo( - AggGroupBy(Ref("k1", TString), - AggGroupBy(Ref("k2", TString), + AggGroupBy( + Ref("k1", TString), + AggGroupBy( + Ref("k2", TString), ApplyAggOp( FastSeq(I32(2)), FastSeq(Ref("g", TCall)), - aggSig), false), false), + aggSig, + ), + false, + ), + false, + ), (agg, aggType), - expected + expected, ) } @Test - def keyedKeyedTakeBy() { + def keyedKeyedTakeBy(): Unit = { val agg = FastSeq( - Row("case", "a", 0.2, 5), Row("control", "b", 0.4, 0), - Row(null, "c", 1.0, 3), Row("control", "b", 0.0, 2), - Row("case", "a", 0.3, 6), Row("control", "b", 0.5, 1)) + Row("case", "a", 0.2, 5), + Row("control", "b", 0.4, 0), + Row(null, "c", 1.0, 3), + Row("control", "b", 0.0, 2), + Row("case", "a", 0.3, 6), + Row("control", "b", 0.5, 1), + ) val aggType = TStruct("k1" -> TString, "k2" -> TString, "x" -> TFloat64, "y" -> TInt32) val expected = Map( "case" -> Map("a" -> FastSeq(0.2, 0.3)), "control" -> Map("b" -> FastSeq(0.4, 0.5)), - (null, Map("c" -> FastSeq(1.0)))) + (null, Map("c" -> FastSeq(1.0))), + ) val aggSig = AggSignature(TakeBy(), FastSeq(TInt32), FastSeq(TFloat64, TInt32)) assertEvalsTo( - AggGroupBy(Ref("k1", TString), - AggGroupBy(Ref("k2", TString), + AggGroupBy( + Ref("k1", TString), + AggGroupBy( + Ref("k2", TString), ApplyAggOp( FastSeq(I32(2)), FastSeq(Ref("x", TFloat64), Ref("y", TInt32)), - aggSig), false), false), + aggSig, + ), + false, + ), + false, + ), (agg, aggType), - expected + expected, ) } @Test - def keyedKeyedKeyedCollect() { - val agg = FastSeq(Row("EUR", "CASE", true, 1), Row("EUR", "CONTROL", true, 2), Row("AFR", "CASE", false, 3), Row("AFR", "CONTROL", false, 4)) + def keyedKeyedKeyedCollect(): Unit = { + val agg = FastSeq( + Row("EUR", "CASE", true, 1), + Row("EUR", "CONTROL", true, 2), + Row("AFR", "CASE", false, 3), + Row("AFR", "CONTROL", false, 4), + ) val aggType = TStruct("k1" -> TString, "k2" -> TString, "k3" -> TBoolean, "x" -> TInt32) - val expected = Map("EUR" -> Map("CASE" -> Map(true -> FastSeq(1)), "CONTROL" -> Map(true -> FastSeq(2))), "AFR" -> Map("CASE" -> Map(false -> FastSeq(3)), "CONTROL" -> Map(false -> FastSeq(4)))) + val expected = Map( + "EUR" -> Map("CASE" -> Map(true -> FastSeq(1)), "CONTROL" -> Map(true -> FastSeq(2))), + "AFR" -> Map("CASE" -> Map(false -> FastSeq(3)), "CONTROL" -> Map(false -> FastSeq(4))), + ) val aggSig = AggSignature(Collect(), FastSeq(), FastSeq(TInt32)) assertEvalsTo( - AggGroupBy(Ref("k1", TString), - AggGroupBy(Ref("k2", 
TString), - AggGroupBy(Ref("k3", TBoolean), + AggGroupBy( + Ref("k1", TString), + AggGroupBy( + Ref("k2", TString), + AggGroupBy( + Ref("k3", TBoolean), ApplyAggOp( FastSeq(), FastSeq(Ref("x", TInt32)), - aggSig), false), false), false), + aggSig, + ), + false, + ), + false, + ), + false, + ), (agg, aggType), - expected + expected, ) } @Test def downsampleWhenEmpty(): Unit = { - runAggregator(Downsample(), + runAggregator( + Downsample(), TStruct("x" -> TFloat64, "y" -> TFloat64, "label" -> TArray(TString)), FastSeq(), FastSeq(), FastSeq(10), - seqOpArgs = FastSeq(Ref("x", TFloat64), Ref("y", TFloat64), Ref("label", TArray(TString)))) + seqOpArgs = FastSeq(Ref("x", TFloat64), Ref("y", TFloat64), Ref("label", TArray(TString))), + ) } @Test def testAggFilter(): Unit = { @@ -713,12 +1212,14 @@ class AggregatorsSuite extends HailSuite { val agg = FastSeq(Row(true, -1L), Row(true, 1L), Row(false, 3L), Row(true, 5L)) assertEvalsTo( - AggFilter(Ref("x", TBoolean), - ApplyAggOp(FastSeq(), - FastSeq(Ref("y", TInt64)), - aggSig), false), + AggFilter( + Ref("x", TBoolean), + ApplyAggOp(FastSeq(), FastSeq(Ref("y", TInt64)), aggSig), + false, + ), (agg, aggType), - 5L) + 5L, + ) } @Test def testAggExplode(): Unit = { @@ -728,16 +1229,19 @@ class AggregatorsSuite extends HailSuite { Row(FastSeq[Long](1, 4)), Row(FastSeq[Long]()), Row(FastSeq[Long](-1, 3)), - Row(FastSeq[Long](4, 5, 6, -7))) + Row(FastSeq[Long](4, 5, 6, -7)), + ) assertEvalsTo( - AggExplode(ToStream(Ref("x", TArray(TInt64))), + AggExplode( + ToStream(Ref("x", TArray(TInt64))), "y", - ApplyAggOp(FastSeq(), - FastSeq(Ref("y", TInt64)), - aggSig), false), + ApplyAggOp(FastSeq(), FastSeq(Ref("y", TInt64)), aggSig), + false, + ), (agg, aggType), - 15L) + 15L, + ) } @Test def testArrayElementsAggregator(): Unit = { @@ -750,14 +1254,18 @@ class AggregatorsSuite extends HailSuite { TableAggregate( ht, - AggArrayPerElement(GetField(Ref("row", ht.typ.rowType), "aRange"), "elt", "_'", + AggArrayPerElement( + GetField(Ref("row", ht.typ.rowType), "aRange"), + "elt", + "_'", ApplyAggOp( FastSeq(), FastSeq(Cast(Ref("elt", TInt32), TInt64)), - AggSignature(Sum(), FastSeq(), FastSeq(TInt64))), + AggSignature(Sum(), FastSeq(), FastSeq(TInt64)), + ), None, - false - ) + false, + ), ) } @@ -776,40 +1284,82 @@ class AggregatorsSuite extends HailSuite { TableAggregate( ht, - AggArrayPerElement(GetField(Ref("row", ht.typ.rowType), "aRange"), "elt", "_'", + AggArrayPerElement( + GetField(Ref("row", ht.typ.rowType), "aRange"), + "elt", + "_'", ApplyAggOp( FastSeq(), FastSeq(Cast(Ref("elt", TInt32), TInt64)), - AggSignature(Sum(), FastSeq(), FastSeq(TInt64))), + AggSignature(Sum(), FastSeq(), FastSeq(TInt64)), + ), knownLength, - false - ) + false, + ), ) } assertEvalsTo(getAgg(10, 10, None), null) assertEvalsTo(getAgg(10, 10, Some(1)), FastSeq(0L)) - assertEvalsTo(getAgg(10, 10, Some(GetField(Ref("global", TStruct("m" -> TInt32)), "m"))), Array.fill(10)(0L).toFastSeq) + assertEvalsTo( + getAgg(10, 10, Some(GetField(Ref("global", TStruct("m" -> TInt32)), "m"))), + Array.fill(10)(0L).toFastSeq, + ) } @Test def testImputeTypeSimple(): Unit = { runAggregator(ImputeType(), TString, FastSeq(null), Row(false, false, true, true, true, true)) - runAggregator(ImputeType(), TString, FastSeq("1231", "1234.5", null), Row(true, false, false, false, false, true)) - runAggregator(ImputeType(), TString, FastSeq("1231", "123"), Row(true, true, false, true, true, true)) - runAggregator(ImputeType(), TString, FastSeq("true", "false"), Row(true, true, true, false, false, false)) + 
runAggregator( + ImputeType(), + TString, + FastSeq("1231", "1234.5", null), + Row(true, false, false, false, false, true), + ) + runAggregator( + ImputeType(), + TString, + FastSeq("1231", "123"), + Row(true, true, false, true, true, true), + ) + runAggregator( + ImputeType(), + TString, + FastSeq("true", "false"), + Row(true, true, true, false, false, false), + ) } @Test def testFoldAgg(): Unit = { val barRef = Ref("bar", TInt32) val bazRef = Ref("baz", TInt32) - val myIR = StreamAgg(mapIR(rangeIR(100)){ idx => makestruct(("idx", idx), ("unused", idx + idx))}, "foo", - AggFold(I32(0), Ref("bar", TInt32) + GetField(Ref("foo", TStruct("idx" -> TInt32, "unused" -> TInt32)), "idx"), barRef + bazRef, "bar", "baz", false) + val myIR = StreamAgg( + mapIR(rangeIR(100))(idx => makestruct(("idx", idx), ("unused", idx + idx))), + "foo", + AggFold( + I32(0), + Ref("bar", TInt32) + GetField( + Ref("foo", TStruct("idx" -> TInt32, "unused" -> TInt32)), + "idx", + ), + barRef + bazRef, + "bar", + "baz", + false, + ), ) assertEvalsTo(myIR, 4950) - val myTableIR = TableAggregate(TableRange(100, 5), - AggFold(I32(0), Ref("bar", TInt32) + GetField(Ref("row", TStruct("idx" -> TInt32)), "idx"), barRef + bazRef, "bar", "baz", false) + val myTableIR = TableAggregate( + TableRange(100, 5), + AggFold( + I32(0), + Ref("bar", TInt32) + GetField(Ref("row", TStruct("idx" -> TInt32)), "idx"), + barRef + bazRef, + "bar", + "baz", + false, + ), ) val analyses = LoweringAnalyses.apply(myTableIR, ctx) @@ -822,8 +1372,20 @@ class AggregatorsSuite extends HailSuite { val barRef = Ref("bar", TInt32) val bazRef = Ref("baz", TInt32) - val myIR = ToArray(StreamAggScan(mapIR(rangeIR(10)){ idx => makestruct(("idx", idx), ("unused", idx + idx))}, "foo", - AggFold(I32(0), Ref("bar", TInt32) + GetField(Ref("foo", TStruct("idx" -> TInt32, "unused" -> TInt32)), "idx"), barRef + bazRef, "bar", "baz", true) + val myIR = ToArray(StreamAggScan( + mapIR(rangeIR(10))(idx => makestruct(("idx", idx), ("unused", idx + idx))), + "foo", + AggFold( + I32(0), + Ref("bar", TInt32) + GetField( + Ref("foo", TStruct("idx" -> TInt32, "unused" -> TInt32)), + "idx", + ), + barRef + bazRef, + "bar", + "baz", + true, + ), )) assertEvalsTo(myIR, IndexedSeq(0, 0, 1, 3, 6, 10, 15, 21, 28, 36)) } diff --git a/hail/src/test/scala/is/hail/expr/ir/ArrayDeforestationSuite.scala b/hail/src/test/scala/is/hail/expr/ir/ArrayDeforestationSuite.scala index 6cd6ee4102a..d32b7035f90 100644 --- a/hail/src/test/scala/is/hail/expr/ir/ArrayDeforestationSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/ArrayDeforestationSuite.scala @@ -1,10 +1,10 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.types.virtual._ -import is.hail.TestUtils._ import is.hail.types.tcoerce +import is.hail.types.virtual._ import is.hail.utils._ + import org.apache.spark.sql.Row import org.testng.annotations.Test @@ -15,20 +15,23 @@ class ArrayDeforestationSuite extends HailSuite { ToArray(StreamMap( StreamRange(0, len, 1), "x1", - Ref("x1", TInt32) + 5)) + Ref("x1", TInt32) + 5, + )) def arrayWithRegion(len: IR): IR = ToArray(StreamMap( StreamRange(0, len, 1), "x2", - MakeStruct(FastSeq[(String, IR)]("f1" -> (Ref("x2", TInt32) + 1), "f2" -> 0)))) + MakeStruct(FastSeq[(String, IR)]("f1" -> (Ref("x2", TInt32) + 1), "f2" -> 0)), + )) def primitiveArrayWithRegion(len: IR): IR = { val array = arrayWithRegion(len) ToArray(StreamMap( ToStream(array), "x3", - GetField(Ref("x3", tcoerce[TArray](array.typ).elementType), "f1"))) + GetField(Ref("x3", 
tcoerce[TArray](array.typ).elementType), "f1"), + )) } def arrayFoldWithStructWithPrimitiveValues(len: IR, max1: Int, max2: Int): IR = { @@ -38,29 +41,38 @@ class ArrayDeforestationSuite extends HailSuite { StreamFold( ToStream(primitiveArrayWithRegion(len)), zero, - accum.name, value.name, - If(value > GetField(accum, "max1"), + accum.name, + value.name, + If( + value > GetField(accum, "max1"), MakeStruct(FastSeq("max1" -> value, "max2" -> GetField(accum, "max1"))), - If(value > GetField(accum, "max2"), + If( + value > GetField(accum, "max2"), MakeStruct(FastSeq("max1" -> GetField(accum, "max1"), "max2" -> value)), - accum))) + accum, + ), + ), + ) } def arrayFoldWithStruct(len: IR, v1: Int, v2: Int): IR = { val zero = MakeTuple.ordered(FastSeq( MakeStruct(FastSeq[(String, IR)]("f1" -> v1, "f2" -> v2)), - MakeStruct(FastSeq[(String, IR)]("f1" -> v1, "f2" -> v2)))) + MakeStruct(FastSeq[(String, IR)]("f1" -> v1, "f2" -> v2)), + )) val array = arrayWithRegion(len) val accum = Ref(genUID(), zero.typ) val value = Ref(genUID(), tcoerce[TArray](array.typ).elementType) StreamFold( ToStream(array), zero, - accum.name, value.name, - MakeTuple.ordered(FastSeq(GetTupleElement(accum, 1), value))) + accum.name, + value.name, + MakeTuple.ordered(FastSeq(GetTupleElement(accum, 1), value)), + ) } - @Test def testArrayFold() { + @Test def testArrayFold(): Unit = { assertEvalsTo(arrayFoldWithStructWithPrimitiveValues(5, -5, -6), Row(5, 4)) assertEvalsTo(arrayFoldWithStruct(5, -5, -6), Row(Row(4, 0), Row(5, 0))) } diff --git a/hail/src/test/scala/is/hail/expr/ir/ArrayFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/ArrayFunctionsSuite.scala index 8bec54ac17e..d76741cb770 100644 --- a/hail/src/test/scala/is/hail/expr/ir/ArrayFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/ArrayFunctionsSuite.scala @@ -1,10 +1,11 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual._ import is.hail.utils.FastSeq -import is.hail.{ExecStrategy, HailSuite} + import org.testng.annotations.{DataProvider, Test} class ArrayFunctionsSuite extends HailSuite { @@ -17,43 +18,51 @@ class ArrayFunctionsSuite extends HailSuite { Array(FastSeq(3, 7)), Array(null), Array(FastSeq(3, null, 7, null)), - Array(FastSeq()) + Array(FastSeq()), ) @DataProvider(name = "basicPairs") def basicPairsData(): Array[Array[Any]] = basicData().flatten.combinations(2).toArray @Test(dataProvider = "basic") - def isEmpty(a: IndexedSeq[Integer]) { - assertEvalsTo(invoke("isEmpty", TBoolean, toIRArray(a)), - Option(a).map(_.isEmpty).orNull) - } + def isEmpty(a: IndexedSeq[Integer]): Unit = + assertEvalsTo(invoke("isEmpty", TBoolean, toIRArray(a)), Option(a).map(_.isEmpty).orNull) @Test(dataProvider = "basic") - def append(a: IndexedSeq[Integer]) { - assertEvalsTo(invoke("append", TArray(TInt32), toIRArray(a), I32(1)), - Option(a).map(_ :+ 1).orNull) - } + def append(a: IndexedSeq[Integer]): Unit = + assertEvalsTo( + invoke("append", TArray(TInt32), toIRArray(a), I32(1)), + Option(a).map(_ :+ 1).orNull, + ) @Test(dataProvider = "basic") - def appendNull(a: IndexedSeq[Integer]) { - assertEvalsTo(invoke("append", TArray(TInt32), toIRArray(a), NA(TInt32)), - Option(a).map(_ :+ null).orNull) - } + def appendNull(a: IndexedSeq[Integer]): Unit = + assertEvalsTo( + invoke("append", TArray(TInt32), toIRArray(a), NA(TInt32)), + Option(a).map(_ :+ null).orNull, + ) @Test(dataProvider = "basic") - def sum(a: IndexedSeq[Integer]) { - 
assertEvalsTo(invoke("sum", TInt32, toIRArray(a)), - Option(a).flatMap(_.foldLeft[Option[Int]](Some(0))((comb, x) => comb.flatMap(c => Option(x).map(_ + c)))).orNull) + def sum(a: IndexedSeq[Integer]): Unit = { + assertEvalsTo( + invoke("sum", TInt32, toIRArray(a)), + Option(a).flatMap(_.foldLeft[Option[Int]](Some(0))((comb, x) => + comb.flatMap(c => Option(x).map(_ + c)) + )).orNull, + ) } @Test(dataProvider = "basic") - def product(a: IndexedSeq[Integer]) { - assertEvalsTo(invoke("product", TInt32, toIRArray(a)), - Option(a).flatMap(_.foldLeft[Option[Int]](Some(1))((comb, x) => comb.flatMap(c => Option(x).map(_ * c)))).orNull) + def product(a: IndexedSeq[Integer]): Unit = { + assertEvalsTo( + invoke("product", TInt32, toIRArray(a)), + Option(a).flatMap(_.foldLeft[Option[Int]](Some(1))((comb, x) => + comb.flatMap(c => Option(x).map(_ * c)) + )).orNull, + ) } - @Test def mean() { + @Test def mean(): Unit = { assertEvalsTo(invoke("mean", TFloat64, IRArray(3, 7)), 5.0) assertEvalsTo(invoke("mean", TFloat64, IRArray(3, null, 7)), null) assertEvalsTo(invoke("mean", TFloat64, IRArray(3, 7, 11)), 7.0) @@ -62,7 +71,7 @@ class ArrayFunctionsSuite extends HailSuite { assertEvalsTo(invoke("mean", TFloat64, naa), null) } - @Test def median() { + @Test def median(): Unit = { assertEvalsTo(invoke("median", TInt32, IRArray(5)), 5) assertEvalsTo(invoke("median", TInt32, IRArray(5, null, null)), 5) assertEvalsTo(invoke("median", TInt32, IRArray(3, 7)), 5) @@ -73,12 +82,13 @@ class ArrayFunctionsSuite extends HailSuite { assertEvalsTo(invoke("median", TInt32, IRArray(null)), null) assertEvalsTo(invoke("median", TInt32, naa), null) } - + @Test(dataProvider = "basicPairs") - def extend(a: IndexedSeq[Integer], b: IndexedSeq[Integer]) { - assertEvalsTo(invoke("extend", TArray(TInt32), toIRArray(a), toIRArray(b)), - Option(a).zip(Option(b)).headOption.map { case (x, y) => x ++ y}.orNull) - } + def extend(a: IndexedSeq[Integer], b: IndexedSeq[Integer]): Unit = + assertEvalsTo( + invoke("extend", TArray(TInt32), toIRArray(a), toIRArray(b)), + Option(a).zip(Option(b)).headOption.map { case (x, y) => x ++ y }.orNull, + ) @DataProvider(name = "sort") def sortData(): Array[Array[Any]] = Array( @@ -86,29 +96,71 @@ class ArrayFunctionsSuite extends HailSuite { Array(null, null, null), Array(FastSeq(3, null, 1, null, 3), FastSeq(1, 3, 3, null, null), FastSeq(3, 3, 1, null, null)), Array(FastSeq(1, null, 3, null, 1), FastSeq(1, 1, 3, null, null), FastSeq(3, 1, 1, null, null)), - Array(FastSeq(), FastSeq(), FastSeq()) + Array(FastSeq(), FastSeq(), FastSeq()), ) @Test(dataProvider = "sort") - def min(a: IndexedSeq[Integer], asc: IndexedSeq[Integer], desc: IndexedSeq[Integer]) { - assertEvalsTo(invoke("min", TInt32, toIRArray(a)), - Option(asc).filter(!_.contains(null)).flatMap(_.headOption).orNull) - } + def min(a: IndexedSeq[Integer], asc: IndexedSeq[Integer], desc: IndexedSeq[Integer]): Unit = + assertEvalsTo( + invoke("min", TInt32, toIRArray(a)), + Option(asc).filter(!_.contains(null)).flatMap(_.headOption).orNull, + ) - @Test def testMinMaxNans() { + @Test def testMinMaxNans(): Unit = { assertAllEvalTo( - (invoke("min", TFloat32, MakeArray(FastSeq(F32(Float.NaN), F32(1.0f), F32(Float.NaN), F32(111.0f)), TArray(TFloat32))), Float.NaN), - (invoke("max", TFloat32, MakeArray(FastSeq(F32(Float.NaN), F32(1.0f), F32(Float.NaN), F32(111.0f)), TArray(TFloat32))), Float.NaN), - (invoke("min", TFloat64, MakeArray(FastSeq(F64(Double.NaN), F64(1.0), F64(Double.NaN), F64(111.0)), TArray(TFloat64))), Double.NaN), - (invoke("max", 
TFloat64, MakeArray(FastSeq(F64(Double.NaN), F64(1.0), F64(Double.NaN), F64(111.0)), TArray(TFloat64))), Double.NaN) + ( + invoke( + "min", + TFloat32, + MakeArray( + FastSeq(F32(Float.NaN), F32(1.0f), F32(Float.NaN), F32(111.0f)), + TArray(TFloat32), + ), + ), + Float.NaN, + ), + ( + invoke( + "max", + TFloat32, + MakeArray( + FastSeq(F32(Float.NaN), F32(1.0f), F32(Float.NaN), F32(111.0f)), + TArray(TFloat32), + ), + ), + Float.NaN, + ), + ( + invoke( + "min", + TFloat64, + MakeArray( + FastSeq(F64(Double.NaN), F64(1.0), F64(Double.NaN), F64(111.0)), + TArray(TFloat64), + ), + ), + Double.NaN, + ), + ( + invoke( + "max", + TFloat64, + MakeArray( + FastSeq(F64(Double.NaN), F64(1.0), F64(Double.NaN), F64(111.0)), + TArray(TFloat64), + ), + ), + Double.NaN, + ), ) } @Test(dataProvider = "sort") - def max(a: IndexedSeq[Integer], asc: IndexedSeq[Integer], desc: IndexedSeq[Integer]) { - assertEvalsTo(invoke("max", TInt32, toIRArray(a)), - Option(desc).filter(!_.contains(null)).flatMap(_.headOption).orNull) - } + def max(a: IndexedSeq[Integer], asc: IndexedSeq[Integer], desc: IndexedSeq[Integer]): Unit = + assertEvalsTo( + invoke("max", TInt32, toIRArray(a)), + Option(desc).filter(!_.contains(null)).flatMap(_.headOption).orNull, + ) @DataProvider(name = "argminmax") def argMinMaxData(): Array[Array[Any]] = Array( @@ -116,18 +168,16 @@ class ArrayFunctionsSuite extends HailSuite { Array(null, null, null), Array(FastSeq(3, null, 1, null, 3), 2, 0), Array(FastSeq(1, null, 3, null, 1), 0, 2), - Array(FastSeq(), null, null) + Array(FastSeq(), null, null), ) @Test(dataProvider = "argminmax") - def argmin(a: IndexedSeq[Integer], argmin: Integer, argmax: Integer) { + def argmin(a: IndexedSeq[Integer], argmin: Integer, argmax: Integer): Unit = assertEvalsTo(invoke("argmin", TInt32, toIRArray(a)), argmin) - } @Test(dataProvider = "argminmax") - def argmax(a: IndexedSeq[Integer], argmin: Integer, argmax: Integer) { + def argmax(a: IndexedSeq[Integer], argmin: Integer, argmax: Integer): Unit = assertEvalsTo(invoke("argmax", TInt32, toIRArray(a)), argmax) - } @DataProvider(name = "uniqueMinMaxIndex") def uniqueMinMaxData(): Array[Array[Any]] = Array( @@ -135,25 +185,23 @@ class ArrayFunctionsSuite extends HailSuite { Array(null, null, null), Array(FastSeq(3, null, 1, null, 3), 2, null), Array(FastSeq(1, null, 3, null, 1), null, 2), - Array(FastSeq(), null, null) + Array(FastSeq(), null, null), ) @Test(dataProvider = "uniqueMinMaxIndex") - def uniqueMinIndex(a: IndexedSeq[Integer], argmin: Integer, argmax: Integer) { + def uniqueMinIndex(a: IndexedSeq[Integer], argmin: Integer, argmax: Integer): Unit = assertEvalsTo(invoke("uniqueMinIndex", TInt32, toIRArray(a)), argmin) - } @Test(dataProvider = "uniqueMinMaxIndex") - def uniqueMaxIndex(a: IndexedSeq[Integer], argmin: Integer, argmax: Integer) { + def uniqueMaxIndex(a: IndexedSeq[Integer], argmin: Integer, argmax: Integer): Unit = assertEvalsTo(invoke("uniqueMaxIndex", TInt32, toIRArray(a)), argmax) - } @DataProvider(name = "arrayOpsData") def arrayOpsData(): Array[Array[Any]] = Array[Any]( FastSeq(3, 9, 7, 1), FastSeq(null, 2, null, 8), FastSeq(5, 3, null, null), - null + null, ).combinations(2).toArray @DataProvider(name = "arrayOpsOperations") @@ -162,49 +210,66 @@ class ArrayFunctionsSuite extends HailSuite { ("sub", _ - _), ("mul", _ * _), ("floordiv", _ / _), - ("mod", _ % _) + ("mod", _ % _), ).map(_.productIterator.toArray) @DataProvider(name = "arrayOps") def arrayOpsPairs(): Array[Array[Any]] = - for (Array(a, b) <- arrayOpsData(); Array(s, f) <- 
arrayOpsOperations) - yield Array(a, b, s, f) + for { + Array(a, b) <- arrayOpsData() + Array(s, f) <- arrayOpsOperations + } yield Array(a, b, s, f) - def lift(f: (Int, Int) => Int): (IndexedSeq[Integer], IndexedSeq[Integer]) => IndexedSeq[Integer] = { + def lift(f: (Int, Int) => Int) + : (IndexedSeq[Integer], IndexedSeq[Integer]) => IndexedSeq[Integer] = { case (a, b) => Option(a).zip(Option(b)).headOption.map { case (a0, b0) => - a0.zip(b0).map { case (i, j) => Option(i).zip(Option(j)).headOption.map[Integer] { case (m, n) => f(m, n) }.orNull } + a0.zip(b0).map { case (i, j) => + Option(i).zip(Option(j)).headOption.map[Integer] { case (m, n) => f(m, n) }.orNull + } }.orNull } @Test(dataProvider = "arrayOps") - def arrayOps(a: IndexedSeq[Integer], b: IndexedSeq[Integer], s: String, f: (Int, Int) => Int) { + def arrayOps(a: IndexedSeq[Integer], b: IndexedSeq[Integer], s: String, f: (Int, Int) => Int) + : Unit = assertEvalsTo(invoke(s, TArray(TInt32), toIRArray(a), toIRArray(b)), lift(f)(a, b)) - } @Test(dataProvider = "arrayOpsData") - def arrayOpsFPDiv(a: IndexedSeq[Integer], b: IndexedSeq[Integer]) { - assertEvalsTo(invoke("div", TArray(TFloat64), toIRArray(a), toIRArray(b)), + def arrayOpsFPDiv(a: IndexedSeq[Integer], b: IndexedSeq[Integer]): Unit = { + assertEvalsTo( + invoke("div", TArray(TFloat64), toIRArray(a), toIRArray(b)), Option(a).zip(Option(b)).headOption.map { case (a0, b0) => - a0.zip(b0).map { case (i, j) => Option(i).zip(Option(j)).headOption.map[java.lang.Double] { case (m, n) => m.toDouble / n }.orNull } - }.orNull ) + a0.zip(b0).map { case (i, j) => + Option(i).zip(Option(j)).headOption.map[java.lang.Double] { case (m, n) => + m.toDouble / n + }.orNull + } + }.orNull, + ) } @Test(dataProvider = "arrayOpsData") - def arrayOpsPow(a: IndexedSeq[Integer], b: IndexedSeq[Integer]) { - assertEvalsTo(invoke("pow", TArray(TFloat64), toIRArray(a), toIRArray(b)), + def arrayOpsPow(a: IndexedSeq[Integer], b: IndexedSeq[Integer]): Unit = { + assertEvalsTo( + invoke("pow", TArray(TFloat64), toIRArray(a), toIRArray(b)), Option(a).zip(Option(b)).headOption.map { case (a0, b0) => - a0.zip(b0).map { case (i, j) => Option(i).zip(Option(j)).headOption.map[java.lang.Double] { case (m, n) => math.pow(m.toDouble, n.toDouble) }.orNull } - }.orNull ) + a0.zip(b0).map { case (i, j) => + Option(i).zip(Option(j)).headOption.map[java.lang.Double] { case (m, n) => + math.pow(m.toDouble, n.toDouble) + }.orNull + } + }.orNull, + ) } @Test(dataProvider = "arrayOpsOperations") - def arrayOpsDifferentLength(s: String, f: (Int, Int) => Int) { + def arrayOpsDifferentLength(s: String, f: (Int, Int) => Int): Unit = { assertFatal(invoke(s, TArray(TInt32), IRArray(1, 2, 3), IRArray(1, 2)), "length mismatch") assertFatal(invoke(s, TArray(TInt32), IRArray(1, 2), IRArray(1, 2, 3)), "length mismatch") } - @Test def indexing() { + @Test def indexing(): Unit = { val a = IRArray(0, null, 2) assertEvalsTo(invoke("indexArray", TInt32, a, I32(0)), 0) assertEvalsTo(invoke("indexArray", TInt32, a, I32(1)), null) @@ -217,7 +282,7 @@ class ArrayFunctionsSuite extends HailSuite { assertEvalsTo(invoke("indexArray", TInt32, a, NA(TInt32)), null) } - @Test def slicing() { + @Test def slicing(): Unit = { val a = IRArray(0, null, 2) assertEvalsTo(ArraySlice(a, I32(1), None), FastSeq(null, 2)) assertEvalsTo(ArraySlice(a, I32(-2), None), FastSeq(null, 2)) @@ -237,7 +302,7 @@ class ArrayFunctionsSuite extends HailSuite { assertEvalsTo(ArraySlice(a, I32(1), Some(I32(2))), FastSeq(null)) assertEvalsTo(ArraySlice(a, I32(0), 
Some(I32(2))), FastSeq(0, null)) assertEvalsTo(ArraySlice(a, I32(0), Some(I32(3))), FastSeq(0, null, 2)) - assertEvalsTo(ArraySlice(a, I32(-1),Some( I32(3))), FastSeq(2)) + assertEvalsTo(ArraySlice(a, I32(-1), Some(I32(3))), FastSeq(2)) assertEvalsTo(ArraySlice(a, I32(-4), Some(I32(4))), FastSeq(0, null, 2)) assertEvalsTo(ArraySlice(naa, I32(1), Some(I32(2))), null) assertEvalsTo(ArraySlice(a, I32(1), Some(NA(TInt32))), null) @@ -251,50 +316,59 @@ class ArrayFunctionsSuite extends HailSuite { Array(FastSeq(null, FastSeq(1)), FastSeq(1)), Array(FastSeq(null, null), FastSeq()), Array(FastSeq(FastSeq(null), FastSeq(), FastSeq(7)), FastSeq(null, 7)), - Array(FastSeq(FastSeq(), FastSeq()), FastSeq()) + Array(FastSeq(FastSeq(), FastSeq()), FastSeq()), ) @Test(dataProvider = "flatten") - def flatten(in: IndexedSeq[IndexedSeq[Integer]], expected: IndexedSeq[Int]) { - assertEvalsTo(invoke("flatten", TArray(TInt32), MakeArray(in.map(toIRArray(_)), TArray(TArray(TInt32)))), expected) - } + def flatten(in: IndexedSeq[IndexedSeq[Integer]], expected: IndexedSeq[Int]): Unit = + assertEvalsTo( + invoke("flatten", TArray(TInt32), MakeArray(in.map(toIRArray(_)), TArray(TArray(TInt32)))), + expected, + ) - @Test def testContains() { + @Test def testContains(): Unit = { val t = TArray(TString) assertEvalsTo( invoke("contains", TBoolean, In(0, t), Str("a")), args = FastSeq(FastSeq() -> t), - expected=false) + expected = false, + ) assertEvalsTo( invoke("contains", TBoolean, In(0, t), Str("a")), args = FastSeq(FastSeq(null) -> t), - expected=false) + expected = false, + ) assertEvalsTo( invoke("contains", TBoolean, In(0, t), Str("a")), args = FastSeq(FastSeq("c", "a", "b") -> t), - expected=true) + expected = true, + ) assertEvalsTo( invoke("contains", TBoolean, In(0, t), Str("a")), args = FastSeq(FastSeq("c", "a", "b", null) -> t), - expected=true) + expected = true, + ) assertEvalsTo( invoke("contains", TBoolean, In(0, t), Str("a")), args = FastSeq((null, t)), - expected=null) + expected = null, + ) assertEvalsTo( invoke("contains", TBoolean, In(0, t), NA(t.elementType)), args = FastSeq((null, t)), - expected=null) + expected = null, + ) assertEvalsTo( invoke("contains", TBoolean, In(0, t), NA(t.elementType)), args = FastSeq(FastSeq("a", null) -> t), - expected=true) + expected = true, + ) } } diff --git a/hail/src/test/scala/is/hail/expr/ir/BlockMatrixIRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/BlockMatrixIRSuite.scala index 419ea4b35a2..5d4638c39c7 100644 --- a/hail/src/test/scala/is/hail/expr/ir/BlockMatrixIRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/BlockMatrixIRSuite.scala @@ -1,6 +1,6 @@ package is.hail.expr.ir -import breeze.linalg.{DenseMatrix => BDM} +import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy import is.hail.expr.Nat import is.hail.io.TypedCodecSpec @@ -8,7 +8,8 @@ import is.hail.linalg.BlockMatrix import is.hail.types.encoded.{EBlockMatrixNDArray, EFloat64Required} import is.hail.types.virtual._ import is.hail.utils._ -import is.hail.{ExecStrategy, HailSuite} + +import breeze.linalg.{DenseMatrix => BDM} import org.testng.annotations.Test class BlockMatrixIRSuite extends HailSuite { @@ -19,8 +20,11 @@ class BlockMatrixIRSuite extends HailSuite { val shape: Array[Long] = Array[Long](N_ROWS, N_COLS) def toIR(bdm: BDM[Double], blockSize: Int = BLOCK_SIZE): BlockMatrixIR = - ValueToBlockMatrix(Literal(TArray(TFloat64), bdm.t.toArray.toFastSeq), - FastSeq(bdm.rows, bdm.cols), blockSize) + ValueToBlockMatrix( + Literal(TArray(TFloat64), 
bdm.t.toArray.toFastSeq), + FastSeq(bdm.rows, bdm.cols), + blockSize, + ) def fill(v: Double, nRows: Int = N_ROWS, nCols: Int = N_COLS, blockSize: Int = BLOCK_SIZE) = toIR(BDM.fill[Double](nRows, nCols)(v), blockSize) @@ -29,25 +33,52 @@ class BlockMatrixIRSuite extends HailSuite { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.allRelational - def makeMap2(left: BlockMatrixIR, right: BlockMatrixIR, op: BinaryOp, strategy: SparsityStrategy): - BlockMatrixMap2 = { - BlockMatrixMap2(left, right, "l", "r", ApplyBinaryPrimOp(op, Ref("l", TFloat64), Ref("r", TFloat64)), strategy) - } - - @Test def testBlockMatrixWriteRead() { + def makeMap2(left: BlockMatrixIR, right: BlockMatrixIR, op: BinaryOp, strategy: SparsityStrategy) + : BlockMatrixMap2 = + BlockMatrixMap2( + left, + right, + "l", + "r", + ApplyBinaryPrimOp(op, Ref("l", TFloat64), Ref("r", TFloat64)), + strategy, + ) + + @Test def testBlockMatrixWriteRead(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.interpretOnly val tempPath = ctx.createTmpPath("test-blockmatrix-write-read", "bm") - Interpret[Unit](ctx, BlockMatrixWrite(ones, - BlockMatrixNativeWriter(tempPath, false, false, false))) - - assertBMEvalsTo(BlockMatrixRead(BlockMatrixNativeReader(fs, tempPath)), BDM.fill[Double](N_ROWS, N_COLS)(1)) + Interpret[Unit]( + ctx, + BlockMatrixWrite(ones, BlockMatrixNativeWriter(tempPath, false, false, false)), + ) + + assertBMEvalsTo( + BlockMatrixRead(BlockMatrixNativeReader(fs, tempPath)), + BDM.fill[Double](N_ROWS, N_COLS)(1), + ) } - @Test def testBlockMatrixMap() { - val sqrtIR = BlockMatrixMap(ones, "element", Apply("sqrt", FastSeq(), FastSeq(Ref("element", TFloat64)), TFloat64, ErrorIDs.NO_ERROR), false) - val negIR = BlockMatrixMap(ones, "element", ApplyUnaryPrimOp(Negate, Ref("element", TFloat64)), false) - val logIR = BlockMatrixMap(ones, "element", Apply("log", FastSeq(), FastSeq(Ref("element", TFloat64)), TFloat64, ErrorIDs.NO_ERROR), true) - val absIR = BlockMatrixMap(ones, "element", Apply("abs", FastSeq(), FastSeq(Ref("element", TFloat64)), TFloat64, ErrorIDs.NO_ERROR), false) + @Test def testBlockMatrixMap(): Unit = { + val sqrtIR = BlockMatrixMap( + ones, + "element", + Apply("sqrt", FastSeq(), FastSeq(Ref("element", TFloat64)), TFloat64, ErrorIDs.NO_ERROR), + false, + ) + val negIR = + BlockMatrixMap(ones, "element", ApplyUnaryPrimOp(Negate, Ref("element", TFloat64)), false) + val logIR = BlockMatrixMap( + ones, + "element", + Apply("log", FastSeq(), FastSeq(Ref("element", TFloat64)), TFloat64, ErrorIDs.NO_ERROR), + true, + ) + val absIR = BlockMatrixMap( + ones, + "element", + Apply("abs", FastSeq(), FastSeq(Ref("element", TFloat64)), TFloat64, ErrorIDs.NO_ERROR), + false, + ) assertBMEvalsTo(sqrtIR, BDM.fill[Double](3, 3)(1)) assertBMEvalsTo(negIR, BDM.fill[Double](3, 3)(-1)) @@ -55,7 +86,7 @@ class BlockMatrixIRSuite extends HailSuite { assertBMEvalsTo(absIR, BDM.fill[Double](3, 3)(1)) } - @Test def testBlockMatrixMap2() { + @Test def testBlockMatrixMap2(): Unit = { val onesAddOnes = makeMap2(ones, ones, Add(), UnionBlocks) val onesSubOnes = makeMap2(ones, ones, Subtract(), UnionBlocks) val onesMulOnes = makeMap2(ones, ones, Multiply(), IntersectionBlocks) @@ -67,10 +98,17 @@ class BlockMatrixIRSuite extends HailSuite { assertBMEvalsTo(onesDivOnes, BDM.fill[Double](3, 3)(1.0 / 1.0)) } - @Test def testBlockMatrixBroadcastValue_Scalars() { + @Test def testBlockMatrixBroadcastValue_Scalars(): Unit = { val broadcastTwo = BlockMatrixBroadcast( - 
ValueToBlockMatrix(MakeArray(IndexedSeq[F64](F64(2)), TArray(TFloat64)), Array[Long](1, 1), ones.typ.blockSize), - FastSeq(), shape, ones.typ.blockSize) + ValueToBlockMatrix( + MakeArray(IndexedSeq[F64](F64(2)), TArray(TFloat64)), + Array[Long](1, 1), + ones.typ.blockSize, + ), + FastSeq(), + shape, + ones.typ.blockSize, + ) val onesAddTwo = makeMap2(ones, broadcastTwo, Add(), UnionBlocks) val onesSubTwo = makeMap2(ones, broadcastTwo, Subtract(), UnionBlocks) @@ -83,31 +121,40 @@ class BlockMatrixIRSuite extends HailSuite { assertBMEvalsTo(onesDivTwo, BDM.fill[Double](3, 3)(1.0 / 2.0)) } - @Test def testBlockMatrixBroadcastValue_Vectors() { + @Test def testBlockMatrixBroadcastValue_Vectors(): Unit = { val vectorLiteral = MakeArray(IndexedSeq[F64](F64(1), F64(2), F64(3)), TArray(TFloat64)) - val broadcastRowVector = BlockMatrixBroadcast(ValueToBlockMatrix(vectorLiteral, Array[Long](1, 3), - ones.typ.blockSize), FastSeq(1), shape, ones.typ.blockSize) - val broadcastColVector = BlockMatrixBroadcast(ValueToBlockMatrix(vectorLiteral, Array[Long](3, 1), - ones.typ.blockSize), FastSeq(0), shape, ones.typ.blockSize) + val broadcastRowVector = BlockMatrixBroadcast( + ValueToBlockMatrix(vectorLiteral, Array[Long](1, 3), ones.typ.blockSize), + FastSeq(1), + shape, + ones.typ.blockSize, + ) + val broadcastColVector = BlockMatrixBroadcast( + ValueToBlockMatrix(vectorLiteral, Array[Long](3, 1), ones.typ.blockSize), + FastSeq(0), + shape, + ones.typ.blockSize, + ) val ops = Array( (Add(), UnionBlocks, (i: Double, j: Double) => i + j), (Subtract(), UnionBlocks, (i: Double, j: Double) => i - j), (Multiply(), IntersectionBlocks, (i: Double, j: Double) => i * j), - (FloatingPointDivide(), NeedsDense, (i: Double, j: Double) => i / j)) + (FloatingPointDivide(), NeedsDense, (i: Double, j: Double) => i / j), + ) for ((op, merge, f) <- ops) { val rightRowOp = makeMap2(ones, broadcastRowVector, op, merge) val rightColOp = makeMap2(ones, broadcastColVector, op, merge) val leftRowOp = makeMap2(broadcastRowVector, ones, op, merge) val leftColOp = makeMap2(broadcastColVector, ones, op, merge) - BDM.tabulate(3, 3){ (_, j) => f(1.0, j + 1) } + BDM.tabulate(3, 3)((_, j) => f(1.0, j + 1)) - val expectedRightRowOp = BDM.tabulate(3, 3){ (_, j) => f(1.0, j + 1) } - val expectedRightColOp = BDM.tabulate(3, 3){ (i, _) => f(1.0, i + 1) } - val expectedLeftRowOp = BDM.tabulate(3, 3){ (_, j) => f(j + 1, 1.0) } - val expectedLeftColOp = BDM.tabulate(3, 3){ (i, _) => f(i + 1, 1.0) } + val expectedRightRowOp = BDM.tabulate(3, 3)((_, j) => f(1.0, j + 1)) + val expectedRightColOp = BDM.tabulate(3, 3)((i, _) => f(1.0, i + 1)) + val expectedLeftRowOp = BDM.tabulate(3, 3)((_, j) => f(j + 1, 1.0)) + val expectedLeftColOp = BDM.tabulate(3, 3)((i, _) => f(i + 1, 1.0)) assertBMEvalsTo(rightRowOp, expectedRightRowOp) assertBMEvalsTo(rightColOp, expectedRightColOp) @@ -116,7 +163,7 @@ class BlockMatrixIRSuite extends HailSuite { } } - @Test def testBlockMatrixFilter() { + @Test def testBlockMatrixFilter(): Unit = { val nRows = 5 val nCols = 8 val original = BDM.tabulate[Double](nRows, nCols)((i, j) => i * nCols + j) @@ -125,15 +172,21 @@ class BlockMatrixIRSuite extends HailSuite { val keepRows = Array(0L, 1L, 4L) val keepCols = Array(0L, 2L, 7L) - assertBMEvalsTo(BlockMatrixFilter(unfiltered, Array(keepRows, Array())), - original(keepRows.map(_.toInt).toFastSeq, ::).toDenseMatrix) - assertBMEvalsTo(BlockMatrixFilter(unfiltered, Array(Array(), keepCols)), - original(::, keepCols.map(_.toInt).toFastSeq).toDenseMatrix) - 
assertBMEvalsTo(BlockMatrixFilter(unfiltered, Array(keepRows, keepCols)), - original(keepRows.map(_.toInt).toFastSeq, keepCols.map(_.toInt).toFastSeq).toDenseMatrix) + assertBMEvalsTo( + BlockMatrixFilter(unfiltered, Array(keepRows, Array())), + original(keepRows.map(_.toInt).toFastSeq, ::).toDenseMatrix, + ) + assertBMEvalsTo( + BlockMatrixFilter(unfiltered, Array(Array(), keepCols)), + original(::, keepCols.map(_.toInt).toFastSeq).toDenseMatrix, + ) + assertBMEvalsTo( + BlockMatrixFilter(unfiltered, Array(keepRows, keepCols)), + original(keepRows.map(_.toInt).toFastSeq, keepCols.map(_.toInt).toFastSeq).toDenseMatrix, + ) } - @Test def testBlockMatrixSlice() { + @Test def testBlockMatrixSlice(): Unit = { val nRows = 12 val nCols = 8 val original = BDM.tabulate[Double](nRows, nCols)((i, j) => i * nCols + j) @@ -141,30 +194,54 @@ class BlockMatrixIRSuite extends HailSuite { val rowSlice = FastSeq(1L, 10L, 3L) val colSlice = FastSeq(4L, 8L, 2L) - assertBMEvalsTo(BlockMatrixSlice(unsliced, FastSeq(rowSlice, colSlice)), + assertBMEvalsTo( + BlockMatrixSlice(unsliced, FastSeq(rowSlice, colSlice)), original( Array.range(rowSlice(0).toInt, rowSlice(1).toInt, rowSlice(2).toInt).toFastSeq, - Array.range(colSlice(0).toInt, colSlice(1).toInt, colSlice(2).toInt).toFastSeq).toDenseMatrix) + Array.range(colSlice(0).toInt, colSlice(1).toInt, colSlice(2).toInt).toFastSeq, + ).toDenseMatrix, + ) } - @Test def testBlockMatrixDot() { + @Test def testBlockMatrixDot(): Unit = { val m1 = BDM.tabulate[Double](5, 4)((i, j) => (i + 1) * j) val m2 = BDM.tabulate[Double](4, 6)((i, j) => (i + 5) * (j - 2)) assertBMEvalsTo(BlockMatrixDot(toIR(m1), toIR(m2)), m1 * m2) } - @Test def testBlockMatrixRandom() { + @Test def testBlockMatrixRandom(): Unit = { val gaussian = BlockMatrixRandom(0, gaussian = true, shape = Array(5L, 6L), blockSize = 3) val uniform = BlockMatrixRandom(0, gaussian = false, shape = Array(5L, 6L), blockSize = 3) - assertBMEvalsTo(BlockMatrixMap2(gaussian, gaussian, "l", "r", Ref("l", TFloat64) - Ref("r", TFloat64), NeedsDense), BDM.fill(5, 6)(0.0)) - assertBMEvalsTo(BlockMatrixMap2(uniform, uniform, "l", "r", Ref("l", TFloat64) - Ref("r", TFloat64), NeedsDense), BDM.fill(5, 6)(0.0)) + assertBMEvalsTo( + BlockMatrixMap2( + gaussian, + gaussian, + "l", + "r", + Ref("l", TFloat64) - Ref("r", TFloat64), + NeedsDense, + ), + BDM.fill(5, 6)(0.0), + ) + assertBMEvalsTo( + BlockMatrixMap2( + uniform, + uniform, + "l", + "r", + Ref("l", TFloat64) - Ref("r", TFloat64), + NeedsDense, + ), + BDM.fill(5, 6)(0.0), + ) } - @Test def readBlockMatrixIR() { + @Test def readBlockMatrixIR(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val etype = EBlockMatrixNDArray(EFloat64Required, required = true) - val path = "src/test/resources/blockmatrix_example/0/parts/part-0-28-0-0-0feb7ac2-ab02-6cd4-5547-bfcb94dacb33" + val path = + "src/test/resources/blockmatrix_example/0/parts/part-0-28-0-0-0feb7ac2-ab02-6cd4-5547-bfcb94dacb33" val matrix = BlockMatrix.read(fs, "src/test/resources/blockmatrix_example/0").toBreezeMatrix() val expected = Array.tabulate(2)(i => Array.tabulate(2)(j => matrix(i, j)).toFastSeq).toFastSeq @@ -173,20 +250,33 @@ class BlockMatrixIRSuite extends HailSuite { val reader = ETypeValueReader(spec) val read = ReadValue(Str(path), reader, typ) assertNDEvals(read, expected) - assertNDEvals(ReadValue( - WriteValue(read, Str(ctx.createTmpPath("read-blockmatrix-ir", "hv")) + UUID4(), ETypeValueWriter(spec)), - reader, typ), expected) + assertNDEvals( + ReadValue( + WriteValue( 
+ read, + Str(ctx.createTmpPath("read-blockmatrix-ir", "hv")) + UUID4(), + ETypeValueWriter(spec), + ), + reader, + typ, + ), + expected, + ) } - @Test def readWriteBlockMatrix() { + @Test def readWriteBlockMatrix(): Unit = { val original = "src/test/resources/blockmatrix_example/0" val expected = BlockMatrix.read(ctx.fs, original).toBreezeMatrix() val path = ctx.createTmpPath("read-blockmatrix-ir", "bm") - assertEvalsTo(BlockMatrixWrite( - BlockMatrixRead(BlockMatrixNativeReader(ctx.fs, original)), - BlockMatrixNativeWriter(path, overwrite = true, forceRowMajor = false, stageLocally = false)), ()) + assertEvalsTo( + BlockMatrixWrite( + BlockMatrixRead(BlockMatrixNativeReader(ctx.fs, original)), + BlockMatrixNativeWriter(path, overwrite = true, forceRowMajor = false, stageLocally = false), + ), + (), + ) assertBMEvalsTo(BlockMatrixRead(BlockMatrixNativeReader(ctx.fs, path)), expected) } diff --git a/hail/src/test/scala/is/hail/expr/ir/CallFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/CallFunctionsSuite.scala index ba82403f811..e04a181b506 100644 --- a/hail/src/test/scala/is/hail/expr/ir/CallFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/CallFunctionsSuite.scala @@ -1,10 +1,10 @@ package is.hail.expr.ir -import is.hail.TestUtils._ +import is.hail.{ExecStrategy, HailSuite} import is.hail.expr.ir.TestUtils.IRCall import is.hail.types.virtual.{TArray, TBoolean, TCall, TInt32} import is.hail.variant._ -import is.hail.{ExecStrategy, HailSuite} + import org.testng.annotations.{DataProvider, Test} class CallFunctionsSuite extends HailSuite { @@ -21,7 +21,7 @@ class CallFunctionsSuite extends HailSuite { Array(Call2(1, 0, true)), Array(Call2(0, 1, false)), Array(CallN(Array(1, 1), false)), - Array(Call.parse("0|1")) + Array(Call.parse("0|1")), ) } @@ -33,7 +33,7 @@ class CallFunctionsSuite extends HailSuite { Array(Call2(1, 0, false)), Array(Call2(0, 1, false)), Array(Call2(3, 1, false)), - Array(Call2(3, 3, false)) + Array(Call2(3, 3, false)), ) } @@ -50,89 +50,85 @@ class CallFunctionsSuite extends HailSuite { Array(CallN(Array(1, 1), false), 0), Array(CallN(Array(1, 1), false), 1), Array(Call.parse("0|1"), 0), - Array(Call.parse("0|1"), 1) + Array(Call.parse("0|1"), 1), ) } - @Test def constructors() { + @Test def constructors(): Unit = { assertEvalsTo(invoke("Call", TCall, False()), Call0()) assertEvalsTo(invoke("Call", TCall, I32(0), True()), Call1(0, true)) assertEvalsTo(invoke("Call", TCall, I32(1), False()), Call1(1, false)) assertEvalsTo(invoke("Call", TCall, I32(0), I32(0), False()), Call2(0, 0, false)) - assertEvalsTo(invoke("Call", TCall, TestUtils.IRArray(0, 1), False()), CallN(Array(0, 1), false)) + assertEvalsTo( + invoke("Call", TCall, TestUtils.IRArray(0, 1), False()), + CallN(Array(0, 1), false), + ) assertEvalsTo(invoke("Call", TCall, Str("0|1")), Call2(0, 1, true)) } @Test(dataProvider = "basic") - def isPhased(c: Call) { - assertEvalsTo(invoke("isPhased", TBoolean, IRCall(c)), - Option(c).map(Call.isPhased).orNull) - } + def isPhased(c: Call): Unit = + assertEvalsTo(invoke("isPhased", TBoolean, IRCall(c)), Option(c).map(Call.isPhased).orNull) @Test(dataProvider = "basic") - def isHomRef(c: Call) { - assertEvalsTo(invoke("isHomRef", TBoolean, IRCall(c)), - Option(c).map(Call.isHomRef).orNull) - } + def isHomRef(c: Call): Unit = + assertEvalsTo(invoke("isHomRef", TBoolean, IRCall(c)), Option(c).map(Call.isHomRef).orNull) @Test(dataProvider = "basic") - def isHet(c: Call) { - assertEvalsTo(invoke("isHet", TBoolean, IRCall(c)), - 
Option(c).map(Call.isHet).orNull) - } + def isHet(c: Call): Unit = + assertEvalsTo(invoke("isHet", TBoolean, IRCall(c)), Option(c).map(Call.isHet).orNull) @Test(dataProvider = "basic") - def isHomVar(c: Call) { - assertEvalsTo(invoke("isHomVar", TBoolean,IRCall(c)), - Option(c).map(Call.isHomVar).orNull) - } + def isHomVar(c: Call): Unit = + assertEvalsTo(invoke("isHomVar", TBoolean, IRCall(c)), Option(c).map(Call.isHomVar).orNull) @Test(dataProvider = "basic") - def isNonRef(c: Call) { - assertEvalsTo(invoke("isNonRef", TBoolean, IRCall(c)), - Option(c).map(Call.isNonRef).orNull) - } + def isNonRef(c: Call): Unit = + assertEvalsTo(invoke("isNonRef", TBoolean, IRCall(c)), Option(c).map(Call.isNonRef).orNull) @Test(dataProvider = "basic") - def isHetNonRef(c: Call) { - assertEvalsTo(invoke("isHetNonRef", TBoolean, IRCall(c)), - Option(c).map(Call.isHetNonRef).orNull) - } + def isHetNonRef(c: Call): Unit = + assertEvalsTo( + invoke("isHetNonRef", TBoolean, IRCall(c)), + Option(c).map(Call.isHetNonRef).orNull, + ) @Test(dataProvider = "basic") - def isHetRef(c: Call) { - assertEvalsTo(invoke("isHetRef", TBoolean, IRCall(c)), - Option(c).map(Call.isHetRef).orNull) - } + def isHetRef(c: Call): Unit = + assertEvalsTo(invoke("isHetRef", TBoolean, IRCall(c)), Option(c).map(Call.isHetRef).orNull) @Test(dataProvider = "basic") - def nNonRefAlleles(c: Call) { - assertEvalsTo(invoke("nNonRefAlleles", TInt32, IRCall(c)), - Option(c).map(Call.nNonRefAlleles).orNull) - } + def nNonRefAlleles(c: Call): Unit = + assertEvalsTo( + invoke("nNonRefAlleles", TInt32, IRCall(c)), + Option(c).map(Call.nNonRefAlleles).orNull, + ) @Test(dataProvider = "basicWithIndex") - def alleleByIndex(c: Call, idx: Int) { - assertEvalsTo(invoke("index", TInt32, IRCall(c), I32(idx)), - Option(c).map(c => Call.alleleByIndex(c, idx)).orNull) - } + def alleleByIndex(c: Call, idx: Int): Unit = + assertEvalsTo( + invoke("index", TInt32, IRCall(c), I32(idx)), + Option(c).map(c => Call.alleleByIndex(c, idx)).orNull, + ) @Test(dataProvider = "basicWithIndex") - def downcode(c: Call, idx: Int) { - assertEvalsTo(invoke("downcode", TCall, IRCall(c), I32(idx)), - Option(c).map(c => Call.downcode(c, idx)).orNull) - } + def downcode(c: Call, idx: Int): Unit = + assertEvalsTo( + invoke("downcode", TCall, IRCall(c), I32(idx)), + Option(c).map(c => Call.downcode(c, idx)).orNull, + ) @Test(dataProvider = "diploid") - def unphasedDiploidGtIndex(c: Call) { - assertEvalsTo(invoke("unphasedDiploidGtIndex", TInt32, IRCall(c)), - Option(c).map(c => Call.unphasedDiploidGtIndex(c)).orNull) - } + def unphasedDiploidGtIndex(c: Call): Unit = + assertEvalsTo( + invoke("unphasedDiploidGtIndex", TInt32, IRCall(c)), + Option(c).map(c => Call.unphasedDiploidGtIndex(c)).orNull, + ) @Test(dataProvider = "basic") - def oneHotAlleles(c: Call) { - assertEvalsTo(invoke("oneHotAlleles", TArray(TInt32), IRCall(c), I32(2)), - Option(c).map(c => Call.oneHotAlleles(c, 2)).orNull) - } + def oneHotAlleles(c: Call): Unit = + assertEvalsTo( + invoke("oneHotAlleles", TArray(TInt32), IRCall(c), I32(2)), + Option(c).map(c => Call.oneHotAlleles(c, 2)).orNull, + ) } - diff --git a/hail/src/test/scala/is/hail/expr/ir/DictFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/DictFunctionsSuite.scala index e44b2e01198..f0e071c6aa1 100644 --- a/hail/src/test/scala/is/hail/expr/ir/DictFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/DictFunctionsSuite.scala @@ -1,10 +1,11 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.TestUtils._ 
import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual._ import is.hail.utils.FastSeq -import is.hail.{ExecStrategy, HailSuite} + import org.apache.spark.sql.Row import org.testng.annotations.{DataProvider, Test} @@ -20,76 +21,87 @@ class DictFunctionsSuite extends HailSuite { Array(IndexedSeq((1, 3), (2, null), null, (null, 1), (3, 7))), Array(IndexedSeq()), Array(IndexedSeq(null)), - Array(null) + Array(null), ) @Test(dataProvider = "basic") - def dictFromArray(a: IndexedSeq[(Integer, Integer)]) { + def dictFromArray(a: IndexedSeq[(Integer, Integer)]): Unit = { assertEvalsTo(invoke("dict", TDict(TInt32, TInt32), toIRPairArray(a)), tuplesToMap(a)) assertEvalsTo(toIRDict(a), tuplesToMap(a)) } @Test(dataProvider = "basic") - def dictFromSet(a: IndexedSeq[(Integer, Integer)]) { - assertEvalsTo(invoke("dict", TDict(TInt32, TInt32), ToSet(ToStream(toIRPairArray(a)))), tuplesToMap(a)) - } + def dictFromSet(a: IndexedSeq[(Integer, Integer)]): Unit = + assertEvalsTo( + invoke("dict", TDict(TInt32, TInt32), ToSet(ToStream(toIRPairArray(a)))), + tuplesToMap(a), + ) @Test(dataProvider = "basic") - def isEmpty(a: IndexedSeq[(Integer, Integer)]) { - assertEvalsTo(invoke("isEmpty", TBoolean, toIRDict(a)), - Option(a).map(_.forall(_ == null)).orNull) - } + def isEmpty(a: IndexedSeq[(Integer, Integer)]): Unit = + assertEvalsTo( + invoke("isEmpty", TBoolean, toIRDict(a)), + Option(a).map(_.forall(_ == null)).orNull, + ) @DataProvider(name = "dictToArray") def dictToArrayData(): Array[Array[Any]] = Array( Array(FastSeq(1 -> 3, 2 -> 7), FastSeq(Row(1, 3), Row(2, 7))), - Array(FastSeq(1 -> 3, 2 -> null, null, (null, 1), 3 -> 7), - FastSeq(Row(1, 3), Row(2, null), Row(3, 7), Row(null, 1))), + Array( + FastSeq(1 -> 3, 2 -> null, null, (null, 1), 3 -> 7), + FastSeq(Row(1, 3), Row(2, null), Row(3, 7), Row(null, 1)), + ), Array(FastSeq(), FastSeq()), Array(FastSeq(null), FastSeq()), - Array(null, null)) + Array(null, null), + ) @Test(dataProvider = "dictToArray") - def dictToArray(a: IndexedSeq[(Integer, Integer)], expected: (IndexedSeq[Row])) { + def dictToArray(a: IndexedSeq[(Integer, Integer)], expected: (IndexedSeq[Row])): Unit = assertEvalsTo(invoke("dictToArray", TArray(TTuple(TInt32, TInt32)), toIRDict(a)), expected) - } @DataProvider(name = "keysAndValues") def keysAndValuesData(): Array[Array[Any]] = Array( Array(FastSeq(1 -> 3, 2 -> 7), FastSeq(1, 2), FastSeq(3, 7)), - Array(FastSeq(1 -> 3, 2 -> null, null, (null, 1), 3 -> 7), - FastSeq(1, 2, 3, null), FastSeq(3, null, 7, 1)), + Array( + FastSeq(1 -> 3, 2 -> null, null, (null, 1), 3 -> 7), + FastSeq(1, 2, 3, null), + FastSeq(3, null, 7, 1), + ), Array(FastSeq(), FastSeq(), FastSeq()), Array(FastSeq(null), FastSeq(), FastSeq()), - Array(null, null, null)) + Array(null, null, null), + ) @Test(dataProvider = "keysAndValues") - def keySet(a: IndexedSeq[(Integer, Integer)], + def keySet( + a: IndexedSeq[(Integer, Integer)], keys: IndexedSeq[Integer], - values: IndexedSeq[Integer]) { - assertEvalsTo(invoke("keySet", TSet(TInt32), toIRDict(a)), - Option(keys).map(_.toSet).orNull) - } + values: IndexedSeq[Integer], + ): Unit = + assertEvalsTo(invoke("keySet", TSet(TInt32), toIRDict(a)), Option(keys).map(_.toSet).orNull) @Test(dataProvider = "keysAndValues") - def keys(a: IndexedSeq[(Integer, Integer)], + def keys( + a: IndexedSeq[(Integer, Integer)], keys: IndexedSeq[Integer], - values: IndexedSeq[Integer]) { + values: IndexedSeq[Integer], + ): Unit = assertEvalsTo(invoke("keys", TArray(TInt32), toIRDict(a)), keys) - } @Test(dataProvider = 
"keysAndValues") - def values(a: IndexedSeq[(Integer, Integer)], + def values( + a: IndexedSeq[(Integer, Integer)], keys: IndexedSeq[Integer], - values: IndexedSeq[Integer]) { + values: IndexedSeq[Integer], + ): Unit = assertEvalsTo(invoke("values", TArray(TInt32), toIRDict(a)), values) - } val d = IRDict((1, 3), (3, 7), (5, null), (null, 5)) val dwoutna = IRDict((1, 3), (3, 7), (5, null)) val na = NA(TInt32) - @Test def dictGet() { + @Test def dictGet(): Unit = { assertEvalsTo(invoke("get", TInt32, NA(TDict(TInt32, TInt32)), 1, na), null) assertEvalsTo(invoke("get", TInt32, d, 0, na), null) assertEvalsTo(invoke("get", TInt32, d, 1, na), 3) @@ -115,7 +127,7 @@ class DictFunctionsSuite extends HailSuite { assertFatal(invoke("index", TInt32, IRDict(), 100), "dictionary") } - @Test def dictContains() { + @Test def dictContains(): Unit = { assertEvalsTo(invoke("contains", TBoolean, d, 0), false) assertEvalsTo(invoke("contains", TBoolean, d, 1), true) assertEvalsTo(invoke("contains", TBoolean, d, 2), false) diff --git a/hail/src/test/scala/is/hail/expr/ir/DistinctlyKeyedSuite.scala b/hail/src/test/scala/is/hail/expr/ir/DistinctlyKeyedSuite.scala index 15778bccde8..644a4192676 100644 --- a/hail/src/test/scala/is/hail/expr/ir/DistinctlyKeyedSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/DistinctlyKeyedSuite.scala @@ -3,13 +3,16 @@ package is.hail.expr.ir import is.hail.HailSuite import is.hail.types.virtual.TInt32 import is.hail.utils.FastSeq + import org.testng.annotations.Test -class DistinctlyKeyedSuite extends HailSuite{ +class DistinctlyKeyedSuite extends HailSuite { @Test def distinctlyKeyedRangeTableBase(): Unit = { val tableRange = TableRange(10, 2) - val tableFilter = TableFilter(tableRange, ApplyComparisonOp(LT(TInt32), - GetField(Ref("row", tableRange.typ.rowType), "idx"), I32(5))) + val tableFilter = TableFilter( + tableRange, + ApplyComparisonOp(LT(TInt32), GetField(Ref("row", tableRange.typ.rowType), "idx"), I32(5)), + ) val tableDistinct = TableDistinct(tableFilter) val tableIRSeq = IndexedSeq(tableRange, tableFilter, tableDistinct) val distinctlyKeyedAnalysis = DistinctlyKeyed.apply(tableDistinct) @@ -18,13 +21,16 @@ class DistinctlyKeyedSuite extends HailSuite{ @Test def readTableKeyByDistinctlyKeyedAnalysis(): Unit = { val rt = TableRange(40, 4) - val idxRef = GetField(Ref("row", rt.typ.rowType), "idx") - val at = TableMapRows(rt, MakeStruct(FastSeq( - "idx" -> idxRef, - "const" -> 5, - "half" -> idxRef.floorDiv(2), - "oneRepeat" -> If(idxRef ceq I32(10), I32(9), idxRef) - ))) + val idxRef = GetField(Ref("row", rt.typ.rowType), "idx") + val at = TableMapRows( + rt, + MakeStruct(FastSeq( + "idx" -> idxRef, + "const" -> 5, + "half" -> idxRef.floorDiv(2), + "oneRepeat" -> If(idxRef ceq I32(10), I32(9), idxRef), + )), + ) val keyedByConst = TableKeyBy(at, IndexedSeq("const")) val pathConst = ctx.createTmpPath("test-table-distinctly-keyed", "ht") Interpret[Unit](ctx, TableWrite(keyedByConst, TableNativeWriter(pathConst))) @@ -52,8 +58,14 @@ class DistinctlyKeyedSuite extends HailSuite{ val tableRange1 = TableRange(10, 2) val tableRange2 = TableRange(10, 2) val row = Ref("row", tableRange2.typ.rowType) - val tableRange1Mapped = TableMapRows(tableRange1, InsertFields(row, FastSeq("x" -> ToArray(StreamRange(0, GetField(row, "idx"), 1))))) - val tableRange2Mapped = TableMapRows(tableRange2, InsertFields(row, FastSeq("x" -> ToArray(StreamRange(0, GetField(row, "idx"), 1))))) + val tableRange1Mapped = TableMapRows( + tableRange1, + InsertFields(row, FastSeq("x" -> 
ToArray(StreamRange(0, GetField(row, "idx"), 1)))), + ) + val tableRange2Mapped = TableMapRows( + tableRange2, + InsertFields(row, FastSeq("x" -> ToArray(StreamRange(0, GetField(row, "idx"), 1)))), + ) val tableUnion = TableUnion(IndexedSeq(tableRange1Mapped, tableRange2Mapped)) val tableExplode = TableExplode(tableUnion, FastSeq("x")) val notDistinctlyKeyedSeq = IndexedSeq(tableUnion, tableExplode) @@ -68,8 +80,14 @@ class DistinctlyKeyedSuite extends HailSuite{ val tableRange1 = TableRange(10, 2) val tableRange2 = TableRange(10, 2) val row = Ref("row", tableRange2.typ.rowType) - val tableRange1Mapped = TableMapRows(tableRange1, InsertFields(row, FastSeq("x" -> ToArray(StreamRange(0, GetField(row, "idx"), 1))))) - val tableRange2Mapped = TableMapRows(tableRange2, InsertFields(row, FastSeq("x" -> ToArray(StreamRange(0, GetField(row, "idx"), 1))))) + val tableRange1Mapped = TableMapRows( + tableRange1, + InsertFields(row, FastSeq("x" -> ToArray(StreamRange(0, GetField(row, "idx"), 1)))), + ) + val tableRange2Mapped = TableMapRows( + tableRange2, + InsertFields(row, FastSeq("x" -> ToArray(StreamRange(0, GetField(row, "idx"), 1)))), + ) val tableUnion = TableUnion(IndexedSeq(tableRange1Mapped, tableRange2Mapped)) val tableExplode = TableExplode(tableUnion, FastSeq("x")) val tableDistinct = TableDistinct(tableExplode) @@ -79,8 +97,10 @@ class DistinctlyKeyedSuite extends HailSuite{ @Test def iRparent(): Unit = { val tableRange = TableRange(10, 2) - val tableFilter = TableFilter(tableRange, ApplyComparisonOp(LT(TInt32), - GetField(Ref("row", tableRange.typ.rowType), "idx"), I32(5))) + val tableFilter = TableFilter( + tableRange, + ApplyComparisonOp(LT(TInt32), GetField(Ref("row", tableRange.typ.rowType), "idx"), I32(5)), + ) val tableDistinct = TableDistinct(tableFilter) val tableCollect = TableCollect(tableDistinct) val distinctlyKeyedAnalysis = DistinctlyKeyed.apply(tableCollect) diff --git a/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala b/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala index 60f12c6476f..d0abf67043a 100644 --- a/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/EmitStreamSuite.scala @@ -1,5 +1,6 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.TestUtils._ import is.hail.annotations.{Region, SafeRow, ScalaToRegionValue} import is.hail.asm4s._ @@ -9,12 +10,14 @@ import is.hail.expr.ir.lowering.LoweringPipeline import is.hail.expr.ir.streams.{EmitStream, StreamUtils} import is.hail.types.VirtualTypeWithReq import is.hail.types.physical._ +import is.hail.types.physical.stypes.{ + PTypeReferenceSingleCodeType, SingleCodeSCode, StreamSingleCodeType, +} import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SStreamValue} -import is.hail.types.physical.stypes.{PTypeReferenceSingleCodeType, SingleCodeSCode, StreamSingleCodeType} import is.hail.types.virtual._ import is.hail.utils._ import is.hail.variant.Call2 -import is.hail.{ExecStrategy, HailSuite} + import org.apache.spark.sql.Row import org.testng.annotations.Test @@ -22,38 +25,21 @@ class EmitStreamSuite extends HailSuite { implicit val execStrats = ExecStrategy.compileOnly - private def compile1[T: TypeInfo, R: TypeInfo](f: (EmitMethodBuilder[_], Value[T]) => Code[R]): T => R = { - val fb = EmitFunctionBuilder[T, R](ctx, "stream_test") - val mb = fb.apply_method - mb.emit(f(mb, mb.getCodeParam[T](1))) - val asmFn = fb.result()(theHailClassLoader) - asmFn.apply - } - - private def compile2[T: TypeInfo, U: 
TypeInfo, R: TypeInfo](f: (EmitMethodBuilder[_], Code[T], Code[U]) => Code[R]): (T, U) => R = { - val fb = EmitFunctionBuilder[T, U, R](ctx, "F") - val mb = fb.apply_method - mb.emit(f(mb, mb.getCodeParam[T](1), mb.getCodeParam[U](2))) - val asmFn = fb.result()(theHailClassLoader) - asmFn.apply - } - - private def compile3[T: TypeInfo, U: TypeInfo, V: TypeInfo, R: TypeInfo](f: (EmitMethodBuilder[_], Code[T], Code[U], Code[V]) => Code[R]): (T, U, V) => R = { - val fb = EmitFunctionBuilder[T, U, V, R](ctx, "F") - val mb = fb.apply_method - mb.emit(f(mb, mb.getCodeParam[T](1), mb.getCodeParam[U](2), mb.getCodeParam[V](3))) - val asmFn = fb.result()(theHailClassLoader) - asmFn.apply - } - def log(str: Code[String], enabled: Boolean = false): Code[Unit] = if (enabled) Code._println(str) else Code._empty private def compileStream[F: TypeInfo, T]( streamIR: IR, - inputTypes: IndexedSeq[EmitParamType] - )(call: (F, Region, T) => Long): T => IndexedSeq[Any] = { - val fb = EmitFunctionBuilder[F](ctx, "F", (classInfo[Region]: ParamType) +: inputTypes.map(pt => pt: ParamType), LongInfo) + inputTypes: IndexedSeq[EmitParamType], + )( + call: (F, Region, T) => Long + ): T => IndexedSeq[Any] = { + val fb = EmitFunctionBuilder[F]( + ctx, + "F", + (classInfo[Region]: ParamType) +: inputTypes.map(pt => pt: ParamType), + LongInfo, + ) val mb = fb.apply_method val ir = streamIR.deepCopy() @@ -67,14 +53,26 @@ class EmitStreamSuite extends HailSuite { case s => s } TypeCheck(ctx, s) - EmitStream.produce(new Emit(emitContext, fb.ecb), s, cb, cb.emb, region, EmitEnv(Env.empty, inputTypes.indices.map(i => mb.storeEmitParamAsField(cb, i + 2))), None) - .consumeCode[Long](cb, 0L, { s => - val arr = StreamUtils.toArray(cb, s.asStream.getProducer(mb), region) - val scp = SingleCodeSCode.fromSCode(cb, arr, region, false) - arrayType = scp.typ.asInstanceOf[PTypeReferenceSingleCodeType].pt - - coerce[Long](scp.code) - }) + EmitStream.produce( + new Emit(emitContext, fb.ecb), + s, + cb, + cb.emb, + region, + EmitEnv(Env.empty, inputTypes.indices.map(i => mb.storeEmitParamAsField(cb, i + 2))), + None, + ) + .consumeCode[Long]( + cb, + 0L, + { s => + val arr = StreamUtils.toArray(cb, s.asStream.getProducer(mb), region) + val scp = SingleCodeSCode.fromSCode(cb, arr, region, false) + arrayType = scp.typ.asInstanceOf[PTypeReferenceSingleCodeType].pt + + coerce[Long](scp.code) + }, + ) }) val f = fb.resultWithIndex() (arg: T) => @@ -89,7 +87,10 @@ class EmitStreamSuite extends HailSuite { private def compileStream(ir: IR, inputType: PType): Any => IndexedSeq[Any] = { type F = AsmFunction3RegionLongBooleanLong - compileStream[F, Any](ir, FastSeq(SingleCodeEmitParamType(false, PTypeReferenceSingleCodeType(inputType)))) { (f: F, r: Region, arg: Any) => + compileStream[F, Any]( + ir, + FastSeq(SingleCodeEmitParamType(false, PTypeReferenceSingleCodeType(inputType))), + ) { (f: F, r: Region, arg: Any) => if (arg == null) f(r, 0L, true) else @@ -97,25 +98,33 @@ class EmitStreamSuite extends HailSuite { } } - private def compileStreamWithIter(ir: IR, requiresMemoryManagementPerElement: Boolean, elementType: PType): Iterator[Any] => IndexedSeq[Any] = { + private def compileStreamWithIter( + ir: IR, + requiresMemoryManagementPerElement: Boolean, + elementType: PType, + ): Iterator[Any] => IndexedSeq[Any] = { trait F { def apply(o: Region, a: NoBoxLongIterator): Long } - compileStream[F, Iterator[Any]](ir, - IndexedSeq(SingleCodeEmitParamType(true, StreamSingleCodeType(requiresMemoryManagementPerElement, elementType, true)))) { (f: F, r: 
Region, it: Iterator[Any]) => - val rvi = new NoBoxLongIterator { + compileStream[F, Iterator[Any]]( + ir, + IndexedSeq(SingleCodeEmitParamType( + true, + StreamSingleCodeType(requiresMemoryManagementPerElement, elementType, true), + )), + ) { (f: F, r: Region, it: Iterator[Any]) => + val rvi = new NoBoxLongIterator { var _eltRegion: Region = _ var eos: Boolean = _ def init(outerRegion: Region, eltRegion: Region): Unit = _eltRegion = eltRegion - override def next(): Long = { + override def next(): Long = if (eos || !it.hasNext) { eos = true 0L } else ScalaToRegionValue(ctx.stateManager, _eltRegion, elementType, it.next()) - } override def close(): Unit = () } assert(it != null, "null iterators not supported") @@ -124,7 +133,7 @@ class EmitStreamSuite extends HailSuite { } private def evalStream(ir: IR): IndexedSeq[Any] = - compileStream[AsmFunction1RegionLong, Unit](ir, FastSeq()) { (f, r, _) => f(r) } + compileStream[AsmFunction1RegionLong, Unit](ir, FastSeq())((f, r, _) => f(r)) .apply(()) private def evalStreamLen(streamIR: IR): Option[Int] = { @@ -139,17 +148,36 @@ class EmitStreamSuite extends HailSuite { val len = cb.newLocal[Int]("len", 0) val len2 = cb.newLocal[Int]("len2", -1) - EmitStream.produce(new Emit(emitContext, fb.ecb), ir, cb, cb.emb, region, EmitEnv(Env.empty, FastSeq()), None) - .consume(cb, + EmitStream.produce( + new Emit(emitContext, fb.ecb), + ir, + cb, + cb.emb, + region, + EmitEnv(Env.empty, FastSeq()), + None, + ) + .consume( + cb, {}, { case stream: SStreamValue => val producer = stream.getProducer(cb.emb) - producer.memoryManagedConsume(region, cb, { cb => producer.length.foreach(computeLen => cb.assign(len2, computeLen(cb))) }) { cb => - cb.assign(len, len + 1) - } - }) - cb.if_(len2.cne(-1) && (len2.cne(len)), - cb._fatal(s"length mismatch between computed and iteration length: computed=", len2.toS, ", iter=", len.toS)) + producer.memoryManagedConsume( + region, + cb, + cb => producer.length.foreach(computeLen => cb.assign(len2, computeLen(cb))), + )(cb => cb.assign(len, len + 1)) + }, + ) + cb.if_( + len2.cne(-1) && (len2.cne(len)), + cb._fatal( + s"length mismatch between computed and iteration length: computed=", + len2.toS, + ", iter=", + len.toS, + ), + ) len2 } @@ -160,19 +188,21 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testEmitNA() { + @Test def testEmitNA(): Unit = assert(evalStream(NA(TStream(TInt32))) == null) - } - @Test def testEmitMake() { + @Test def testEmitMake(): Unit = { val typ = TStream(TInt32) val tests: Array[(IR, IndexedSeq[Any])] = Array( MakeStream(IndexedSeq[IR](1, 2, NA(TInt32), 3), typ) -> IndexedSeq(1, 2, null, 3), MakeStream(IndexedSeq[IR](), typ) -> IndexedSeq(), - MakeStream(IndexedSeq[IR](MakeTuple.ordered(IndexedSeq(4, 5))), TStream(TTuple(TInt32, TInt32))) -> + MakeStream( + IndexedSeq[IR](MakeTuple.ordered(IndexedSeq(4, 5))), + TStream(TTuple(TInt32, TInt32)), + ) -> IndexedSeq(Row(4, 5)), MakeStream(IndexedSeq[IR](Str("hi"), Str("world")), TStream(TString)) -> - IndexedSeq("hi", "world") + IndexedSeq("hi", "world"), ) for ((ir, v) <- tests) { assert(evalStream(ir) == v, Pretty(ctx, ir)) @@ -180,22 +210,35 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testEmitRange() { - val tripleType = PCanonicalStruct(false, "start" -> PInt32(), "stop" -> PInt32(), "step" -> PInt32()) + @Test def testEmitRange(): Unit = { + val tripleType = + PCanonicalStruct(false, "start" -> PInt32(), "stop" -> PInt32(), "step" -> PInt32()) val range = compileStream( StreamRange( - GetField(In(0, 
SingleCodeEmitParamType(false, PTypeReferenceSingleCodeType(tripleType))), "start"), - GetField(In(0, SingleCodeEmitParamType(false, PTypeReferenceSingleCodeType(tripleType))), "stop"), - GetField(In(0, SingleCodeEmitParamType(false, PTypeReferenceSingleCodeType(tripleType))), "step")), - tripleType) + GetField( + In(0, SingleCodeEmitParamType(false, PTypeReferenceSingleCodeType(tripleType))), + "start", + ), + GetField( + In(0, SingleCodeEmitParamType(false, PTypeReferenceSingleCodeType(tripleType))), + "stop", + ), + GetField( + In(0, SingleCodeEmitParamType(false, PTypeReferenceSingleCodeType(tripleType))), + "step", + ), + ), + tripleType, + ) for { start <- -2 to 2 - stop <- -2 to 8 - step <- 1 to 3 - } { - assert(range(Row(start, stop, step)) == Array.range(start, stop, step).toFastSeq, - s"($start, $stop, $step)") + stop <- -2 to 8 + step <- 1 to 3 } + assert( + range(Row(start, stop, step)) == Array.range(start, stop, step).toFastSeq, + s"($start, $stop, $step)", + ) assert(range(Row(null, 10, 1)) == null) assert(range(Row(0, null, 1)) == null) assert(range(Row(0, 10, null)) == null) @@ -207,16 +250,18 @@ class EmitStreamSuite extends HailSuite { val n = 2 val seqIr = SeqSample( - I32(N), - I32(n), - RNGStateLiteral(), - false - ) + I32(N), + I32(n), + RNGStateLiteral(), + false, + ) - val compiled = compileStream[AsmFunction1RegionLong, Unit](seqIr, FastSeq()) { (f, r, _) => f(r) } + val compiled = compileStream[AsmFunction1RegionLong, Unit](seqIr, FastSeq()) { (f, r, _) => + f(r) + } // Generate many pairs of numbers between 0 and N, every pair should be equally likely - val results = Array.tabulate(N, N){ case(i, j) => 0} + val results = Array.tabulate(N, N) { case (_, _) => 0 } (0 until 1000000).foreach { i => val IndexedSeq = compiled.apply(()).map(_.asInstanceOf[Int]) assert(IndexedSeq.size == n) @@ -225,7 +270,6 @@ class EmitStreamSuite extends HailSuite { assert(IndexedSeq.forall(e => e >= 0 && e < N)) } - (0 until N).foreach { i => (i + 1 until N).foreach { j => val entry = results(i)(j) @@ -235,11 +279,11 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testEmitToStream() { + @Test def testEmitToStream(): Unit = { val tests: Array[(IR, IndexedSeq[Any])] = Array( ToStream(MakeArray(IndexedSeq[IR](), TArray(TInt32))) -> IndexedSeq(), ToStream(MakeArray(IndexedSeq[IR](1, 2, 3, 4), TArray(TInt32))) -> IndexedSeq(1, 2, 3, 4), - ToStream(NA(TArray(TInt32))) -> null + ToStream(NA(TArray(TInt32))) -> null, ) for ((ir, v) <- tests) { val expectedLen = Option(v).map(_.length) @@ -250,18 +294,19 @@ class EmitStreamSuite extends HailSuite { @Test def testEmitLet(): Unit = { val ir = - Let(FastSeq("start" -> 3, "end" -> 10), + Let( + FastSeq("start" -> 3, "end" -> 10), StreamFlatMap( StreamRange(Ref("start", TInt32), Ref("end", TInt32), 1), "i", - MakeStream(IndexedSeq(Ref("i", TInt32), Ref("end", TInt32)), TStream(TInt32)) - ) + MakeStream(IndexedSeq(Ref("i", TInt32), Ref("end", TInt32)), TStream(TInt32)), + ), ) - assert(evalStream(ir) == (3 until 10).flatMap { i => IndexedSeq(i, 10) }, Pretty(ctx, ir)) + assert(evalStream(ir) == (3 until 10).flatMap(i => IndexedSeq(i, 10)), Pretty(ctx, ir)) assert(evalStreamLen(ir).isEmpty, Pretty(ctx, ir)) } - @Test def testEmitMap() { + @Test def testEmitMap(): Unit = { def ten = StreamRange(I32(0), I32(10), I32(1)) def x = Ref("x", TInt32) @@ -272,7 +317,7 @@ class EmitStreamSuite extends HailSuite { StreamMap(ten, "x", x * 2) -> (0 until 10).map(_ * 2), StreamMap(ten, "x", x.toL) -> (0 until 10).map(_.toLong), 
StreamMap(StreamMap(ten, "x", x + 1), "y", y * y) -> (0 until 10).map(i => (i + 1) * (i + 1)), - StreamMap(ten, "x", NA(TInt32)) -> IndexedSeq.tabulate(10) { _ => null } + StreamMap(ten, "x", NA(TInt32)) -> IndexedSeq.tabulate(10)(_ => null), ) for ((ir, v) <- tests) { assert(evalStream(ir) == v, Pretty(ctx, ir)) @@ -280,7 +325,7 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testEmitFilter() { + @Test def testEmitFilter(): Unit = { def ten = StreamRange(I32(0), I32(10), I32(1)) def x = Ref("x", TInt32) @@ -289,9 +334,13 @@ class EmitStreamSuite extends HailSuite { val tests: Array[(IR, IndexedSeq[Any])] = Array( StreamFilter(ten, "x", x cne 5) -> (0 until 10).filter(_ != 5), - StreamFilter(StreamMap(ten, "x", (x * 2).toL), "y", y > 5L) -> (3 until 10).map(x => (x * 2).toLong), + StreamFilter(StreamMap(ten, "x", (x * 2).toL), "y", y > 5L) -> (3 until 10).map(x => + (x * 2).toLong + ), StreamFilter(StreamMap(ten, "x", (x * 2).toL), "y", NA(TBoolean)) -> IndexedSeq(), - StreamFilter(StreamMap(ten, "x", NA(TInt32)), "z", True()) -> IndexedSeq.tabulate(10) { _ => null } + StreamFilter(StreamMap(ten, "x", NA(TInt32)), "z", True()) -> IndexedSeq.tabulate(10) { _ => + null + }, ) for ((ir, v) <- tests) { assert(evalStream(ir) == v, Pretty(ctx, ir)) @@ -299,7 +348,7 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testEmitFlatMap() { + @Test def testEmitFlatMap(): Unit = { def x = Ref("x", TInt32) def y = Ref("y", TInt32) @@ -311,16 +360,24 @@ class EmitStreamSuite extends HailSuite { IndexedSeq(), StreamFlatMap(StreamRange(0, NA(TInt32), 1), "x", StreamRange(0, x, 1)) -> null, - StreamFlatMap(StreamRange(0, 20, 1), "x", - StreamFlatMap(StreamRange(0, x, 1), "y", - StreamRange(0, (x + y), 1))) -> - (0 until 20).flatMap { x => (0 until x).flatMap { y => 0 until (x + y) } }, - StreamFlatMap(StreamFilter(StreamRange(0, 5, 1), "x", x cne 3), - "y", MakeStream(IndexedSeq(y, y), TStream(TInt32))) -> + StreamFlatMap( + StreamRange(0, 20, 1), + "x", + StreamFlatMap(StreamRange(0, x, 1), "y", StreamRange(0, (x + y), 1)), + ) -> + (0 until 20).flatMap(x => (0 until x).flatMap(y => 0 until (x + y))), + StreamFlatMap( + StreamFilter(StreamRange(0, 5, 1), "x", x cne 3), + "y", + MakeStream(IndexedSeq(y, y), TStream(TInt32)), + ) -> IndexedSeq(0, 0, 1, 1, 2, 2, 4, 4), - StreamFlatMap(StreamRange(0, 4, 1), - "x", ToStream(MakeArray(IndexedSeq[IR](x, x), TArray(TInt32)))) -> - IndexedSeq(0, 0, 1, 1, 2, 2, 3, 3) + StreamFlatMap( + StreamRange(0, 4, 1), + "x", + ToStream(MakeArray(IndexedSeq[IR](x, x), TArray(TInt32))), + ) -> + IndexedSeq(0, 0, 1, 1, 2, 2, 3, 3), ) for ((ir, v) <- tests) { assert(evalStream(ir) == v, Pretty(ctx, ir)) @@ -334,71 +391,148 @@ class EmitStreamSuite extends HailSuite { val streamType = TStream(TStruct("a" -> TInt64, "b" -> TInt64)) val numSeq = (0 until 12).map(i => IndexedSeq(I64(i), I64(i + 1))) val numTupleSeq = numSeq.map(_ => IndexedSeq("a", "b")).zip(numSeq) - val countStructSeq = numTupleSeq.map { case (s, i) => s.zip(i)}.map(is => MakeStruct(is)) + val countStructSeq = numTupleSeq.map { case (s, i) => s.zip(i) }.map(is => MakeStruct(is)) val countStructStream = MakeStream(countStructSeq, streamType, false) - val countAggSig = PhysicalAggSig(Count(), TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true))) + val countAggSig = PhysicalAggSig( + Count(), + TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true)), + ) val initOps = InitOp(0, FastSeq(), countAggSig) val seqOps = SeqOp(0, FastSeq(), countAggSig) - val 
newKey = MakeStruct(IndexedSeq("count" -> SelectFields(Ref("foo", streamType.elementType), IndexedSeq("a", "b")))) - val streamBuffAggCount = StreamBufferedAggregate(countStructStream, initOps, newKey, seqOps, "foo", IndexedSeq(countAggSig), 8) + val newKey = MakeStruct(IndexedSeq("count" -> SelectFields( + Ref("foo", streamType.elementType), + IndexedSeq("a", "b"), + ))) + val streamBuffAggCount = StreamBufferedAggregate( + countStructStream, + initOps, + newKey, + seqOps, + "foo", + IndexedSeq(countAggSig), + 8, + ) val result = mapIR(streamBuffAggCount) { elem => MakeStruct(IndexedSeq( "key" -> GetField(elem, "count"), "aggResult" -> - RunAgg(InitFromSerializedValue(0, GetTupleElement(GetField(elem, "agg"), 0), countAggSig.state), ResultOp(0, countAggSig), IndexedSeq(countAggSig.state)) - ))} + RunAgg( + InitFromSerializedValue( + 0, + GetTupleElement(GetField(elem, "agg"), 0), + countAggSig.state, + ), + ResultOp(0, countAggSig), + IndexedSeq(countAggSig.state), + ), + )) + } assert(evalStream(result).equals(resultArrayToCompare)) } + @Test def testStreamBufferedAggregatorCombine(): Unit = { val resultArrayToCompare = IndexedSeq(Row(Row(1), 2)) val streamType = TStream(TStruct("a" -> TInt64)) val elemOne = MakeStruct(IndexedSeq(("a", I64(1)))) val elemTwo = MakeStruct(IndexedSeq(("a", I64(1)))) val countStructStream = MakeStream(IndexedSeq(elemOne, elemTwo), streamType) - val countAggSig = PhysicalAggSig(Count(), TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true))) + val countAggSig = PhysicalAggSig( + Count(), + TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true)), + ) val initOps = InitOp(0, FastSeq(), countAggSig) val seqOps = SeqOp(0, FastSeq(), countAggSig) - val newKey = MakeStruct(IndexedSeq("count" -> SelectFields(Ref("foo", streamType.elementType), IndexedSeq("a")))) - val streamBuffAggCount = StreamBufferedAggregate(countStructStream, initOps, newKey, seqOps, "foo", IndexedSeq(countAggSig), 8) + val newKey = MakeStruct(IndexedSeq("count" -> SelectFields( + Ref("foo", streamType.elementType), + IndexedSeq("a"), + ))) + val streamBuffAggCount = StreamBufferedAggregate( + countStructStream, + initOps, + newKey, + seqOps, + "foo", + IndexedSeq(countAggSig), + 8, + ) val result = mapIR(streamBuffAggCount) { elem => MakeStruct(IndexedSeq( "key" -> GetField(elem, "count"), "aggResult" -> - RunAgg(InitFromSerializedValue(0, GetTupleElement(GetField(elem, "agg"), 0), countAggSig.state), ResultOp(0, countAggSig), IndexedSeq(countAggSig.state)) - ))} + RunAgg( + InitFromSerializedValue( + 0, + GetTupleElement(GetField(elem, "agg"), 0), + countAggSig.state, + ), + ResultOp(0, countAggSig), + IndexedSeq(countAggSig.state), + ), + )) + } assert(evalStream(result) == resultArrayToCompare) } @Test def testStreamBufferedAggregatorCollectAggregator(): Unit = { - val resultArrayToCompare = IndexedSeq(Row(Row(1), IndexedSeq(1, 3)), Row(Row(2), IndexedSeq(2, 4))) + val resultArrayToCompare = + IndexedSeq(Row(Row(1), IndexedSeq(1, 3)), Row(Row(2), IndexedSeq(2, 4))) val streamType = TStream(TStruct("a" -> TInt64, "b" -> TInt64)) val elemOne = MakeStruct(IndexedSeq(("a", I64(1)), ("b", I64(1)))) val elemTwo = MakeStruct(IndexedSeq(("a", I64(2)), ("b", I64(2)))) val elemThree = MakeStruct(IndexedSeq(("a", I64(1)), ("b", I64(3)))) val elemFour = MakeStruct(IndexedSeq(("a", I64(2)), ("b", I64(4)))) - val collectStructStream = MakeStream(IndexedSeq(elemOne, elemTwo, elemThree, elemFour), streamType) - val collectAggSig = PhysicalAggSig(Collect(), 
CollectStateSig(VirtualTypeWithReq(PType.canonical(TInt64)))) + val collectStructStream = + MakeStream(IndexedSeq(elemOne, elemTwo, elemThree, elemFour), streamType) + val collectAggSig = + PhysicalAggSig(Collect(), CollectStateSig(VirtualTypeWithReq(PType.canonical(TInt64)))) val initOps = InitOp(0, FastSeq(), collectAggSig) val seqOps = SeqOp(0, FastSeq(GetField(Ref("foo", streamType.elementType), "b")), collectAggSig) - val newKey = MakeStruct(IndexedSeq("collect" -> SelectFields(Ref("foo", streamType.elementType), IndexedSeq("a")))) - val streamBuffAggCollect = StreamBufferedAggregate(collectStructStream, initOps, newKey, seqOps, "foo", IndexedSeq(collectAggSig), 8) + val newKey = MakeStruct(IndexedSeq("collect" -> SelectFields( + Ref("foo", streamType.elementType), + IndexedSeq("a"), + ))) + val streamBuffAggCollect = StreamBufferedAggregate( + collectStructStream, + initOps, + newKey, + seqOps, + "foo", + IndexedSeq(collectAggSig), + 8, + ) val result = mapIR(streamBuffAggCollect) { elem => MakeStruct(IndexedSeq( "key" -> GetField(elem, "collect"), "aggResult" -> - RunAgg(InitFromSerializedValue(0, GetTupleElement(GetField(elem, "agg"), 0), collectAggSig.state), ResultOp(0, collectAggSig), IndexedSeq(collectAggSig.state)) - ))} + RunAgg( + InitFromSerializedValue( + 0, + GetTupleElement(GetField(elem, "agg"), 0), + collectAggSig.state, + ), + ResultOp(0, collectAggSig), + IndexedSeq(collectAggSig.state), + ), + )) + } assert(evalStream(result) == resultArrayToCompare) } @Test def testStreamBufferedAggregatorMultipleAggregators(): Unit = { - val resultArrayToCompare = IndexedSeq(Row(Row(1), Row(3, IndexedSeq(1L, 3L, 2L))), Row(Row(2), Row(2, IndexedSeq(2L, 4L))), - Row(Row(3), Row(3, IndexedSeq(1L, 2L, 3L))), Row(Row(4), Row(1, IndexedSeq(4L))), - Row(Row(5), Row(1, IndexedSeq(1L))), Row(Row(6), Row(1, IndexedSeq(3L))), - Row(Row(7), Row(1, IndexedSeq(4L))), Row(Row(8), Row(1, IndexedSeq(1L))), - Row(Row(8), Row(1, IndexedSeq(2L))), Row(Row(9), Row(1, IndexedSeq(3L))), - Row(Row(10), Row(2, IndexedSeq(4L, 4L)))) + val resultArrayToCompare = IndexedSeq( + Row(Row(1), Row(3, IndexedSeq(1L, 3L, 2L))), + Row(Row(2), Row(2, IndexedSeq(2L, 4L))), + Row(Row(3), Row(3, IndexedSeq(1L, 2L, 3L))), + Row(Row(4), Row(1, IndexedSeq(4L))), + Row(Row(5), Row(1, IndexedSeq(1L))), + Row(Row(6), Row(1, IndexedSeq(3L))), + Row(Row(7), Row(1, IndexedSeq(4L))), + Row(Row(8), Row(1, IndexedSeq(1L))), + Row(Row(8), Row(1, IndexedSeq(2L))), + Row(Row(9), Row(1, IndexedSeq(3L))), + Row(Row(10), Row(2, IndexedSeq(4L, 4L))), + ) val streamType = TStream(TStruct("a" -> TInt64, "b" -> TInt64)) val elemOne = MakeStruct(IndexedSeq(("a", I64(1)), ("b", I64(1)))) val elemTwo = MakeStruct(IndexedSeq(("a", I64(2)), ("b", I64(2)))) @@ -417,77 +551,119 @@ class EmitStreamSuite extends HailSuite { val elemFifteen = MakeStruct(IndexedSeq(("a", I64(9)), ("b", I64(3)))) val elemSixteen = MakeStruct(IndexedSeq(("a", I64(10)), ("b", I64(4)))) val elemSeventeen = MakeStruct(IndexedSeq(("a", I64(10)), ("b", I64(4)))) - val collectStructStream = MakeStream(IndexedSeq(elemOne, elemTwo, elemThree, elemFour, elemFive, elemSix, elemSeven, - elemEight, elemNine, elemTen, elemEleven, elemTwelve, elemThirteen, - elemFourteen, elemFifteen, elemSixteen, elemSeventeen), streamType) - val collectAggSig = PhysicalAggSig(Collect(), CollectStateSig(VirtualTypeWithReq(PType.canonical(TInt64)))) - val countAggSig = PhysicalAggSig(Count(), TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true))) + val collectStructStream = 
MakeStream( + IndexedSeq(elemOne, elemTwo, elemThree, elemFour, elemFive, elemSix, elemSeven, + elemEight, elemNine, elemTen, elemEleven, elemTwelve, elemThirteen, + elemFourteen, elemFifteen, elemSixteen, elemSeventeen), + streamType, + ) + val collectAggSig = + PhysicalAggSig(Collect(), CollectStateSig(VirtualTypeWithReq(PType.canonical(TInt64)))) + val countAggSig = PhysicalAggSig( + Count(), + TypedStateSig(VirtualTypeWithReq.fullyOptional(TInt64).setRequired(true)), + ) val initOps = Begin(IndexedSeq( InitOp(0, FastSeq(), countAggSig), - InitOp(1, FastSeq(), collectAggSig) + InitOp(1, FastSeq(), collectAggSig), )) val seqOps = Begin(IndexedSeq( SeqOp(0, FastSeq(), countAggSig), - SeqOp(1, FastSeq(GetField(Ref("foo", streamType.elementType), "b")), collectAggSig) + SeqOp(1, FastSeq(GetField(Ref("foo", streamType.elementType), "b")), collectAggSig), )) - val newKey = MakeStruct(IndexedSeq("collect" -> SelectFields(Ref("foo", streamType.elementType), IndexedSeq("a")))) - val streamBuffAggCollect = StreamBufferedAggregate(collectStructStream, initOps, newKey, seqOps, "foo", - IndexedSeq(countAggSig, collectAggSig), 8) + val newKey = MakeStruct(IndexedSeq("collect" -> SelectFields( + Ref("foo", streamType.elementType), + IndexedSeq("a"), + ))) + val streamBuffAggCollect = StreamBufferedAggregate( + collectStructStream, + initOps, + newKey, + seqOps, + "foo", + IndexedSeq(countAggSig, collectAggSig), + 8, + ) val result = mapIR(streamBuffAggCollect) { elem => MakeStruct(IndexedSeq( "key" -> GetField(elem, "collect"), "aggResult" -> RunAgg( Begin(IndexedSeq( - InitFromSerializedValue(0, GetTupleElement(GetField(elem, "agg"), 0), countAggSig.state), - InitFromSerializedValue(1, GetTupleElement(GetField(elem, "agg"), 1), collectAggSig.state)) - ), + InitFromSerializedValue( + 0, + GetTupleElement(GetField(elem, "agg"), 0), + countAggSig.state, + ), + InitFromSerializedValue( + 1, + GetTupleElement(GetField(elem, "agg"), 1), + collectAggSig.state, + ), + )), MakeTuple.ordered(IndexedSeq(ResultOp(0, countAggSig), ResultOp(1, collectAggSig))), - IndexedSeq(countAggSig.state, collectAggSig.state)) - ))} + IndexedSeq(countAggSig.state, collectAggSig.state), + ), + )) + } assert(evalStream(result) == resultArrayToCompare) } - @Test def testEmitJoinRightDistinct() { + @Test def testEmitJoinRightDistinct(): Unit = { val eltType = TStruct("k" -> TInt32, "v" -> TString) def join(lstream: IR, rstream: IR, joinType: String): IR = StreamJoinRightDistinct( - lstream, rstream, FastSeq("k"), FastSeq("k"), "l", "r", + lstream, + rstream, + FastSeq("k"), + FastSeq("k"), + "l", + "r", MakeTuple.ordered(IndexedSeq( GetField(Ref("l", eltType), "v"), - GetField(Ref("r", eltType), "v"))), - joinType) + GetField(Ref("r", eltType), "v"), + )), + joinType, + ) def leftjoin(lstream: IR, rstream: IR): IR = join(lstream, rstream, "left") def outerjoin(lstream: IR, rstream: IR): IR = join(lstream, rstream, "outer") def pairs(xs: IndexedSeq[(Int, String)]): IR = - MakeStream(xs.map { case (a, b) => MakeStruct(IndexedSeq("k" -> I32(a), "v" -> Str(b))) }, TStream(eltType)) + MakeStream( + xs.map { case (a, b) => MakeStruct(IndexedSeq("k" -> I32(a), "v" -> Str(b))) }, + TStream(eltType), + ) val tests: Array[(IR, IR, IndexedSeq[Any], IndexedSeq[Any])] = Array( (pairs(IndexedSeq()), pairs(IndexedSeq()), IndexedSeq(), IndexedSeq()), - (pairs(IndexedSeq(3 -> "A")), + ( + pairs(IndexedSeq(3 -> "A")), pairs(IndexedSeq()), IndexedSeq(Row("A", null)), - IndexedSeq(Row("A", null))), - (pairs(IndexedSeq()), - pairs(IndexedSeq(3 -> 
"B")), - IndexedSeq(), - IndexedSeq(Row(null, "B"))), - (pairs(IndexedSeq(0 -> "A")), + IndexedSeq(Row("A", null)), + ), + (pairs(IndexedSeq()), pairs(IndexedSeq(3 -> "B")), IndexedSeq(), IndexedSeq(Row(null, "B"))), + ( + pairs(IndexedSeq(0 -> "A")), pairs(IndexedSeq(0 -> "B")), IndexedSeq(Row("A", "B")), - IndexedSeq(Row("A", "B"))), - (pairs(IndexedSeq(0 -> "A", 2 -> "B", 3 -> "C")), + IndexedSeq(Row("A", "B")), + ), + ( + pairs(IndexedSeq(0 -> "A", 2 -> "B", 3 -> "C")), pairs(IndexedSeq(0 -> "a", 1 -> ".", 2 -> "b", 4 -> "..")), IndexedSeq(Row("A", "a"), Row("B", "b"), Row("C", null)), - IndexedSeq(Row("A", "a"), Row(null, "."), Row("B", "b"), Row("C", null), Row(null, ".."))), - (pairs(IndexedSeq(0 -> "A", 1 -> "B1", 1 -> "B2")), + IndexedSeq(Row("A", "a"), Row(null, "."), Row("B", "b"), Row("C", null), Row(null, "..")), + ), + ( + pairs(IndexedSeq(0 -> "A", 1 -> "B1", 1 -> "B2")), pairs(IndexedSeq(0 -> "a", 1 -> "b", 2 -> "c")), IndexedSeq(Row("A", "a"), Row("B1", "b"), Row("B2", "b")), - IndexedSeq(Row("A", "a"), Row("B1", "b"), Row("B2", "b"), Row(null, "c"))) + IndexedSeq(Row("A", "a"), Row("B1", "b"), Row("B2", "b"), Row(null, "c")), + ), ) for ((lstream, rstream, expectedLeft, expectedOuter) <- tests) { val l = leftjoin(lstream, rstream) @@ -499,62 +675,88 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testEmitJoinRightDistinctInterval() { + @Test def testEmitJoinRightDistinctInterval(): Unit = { val lEltType = TStruct("k" -> TInt32, "v" -> TString) val rEltType = TStruct("k" -> TInterval(TInt32), "v" -> TString) def join(lstream: IR, rstream: IR, joinType: String): IR = StreamJoinRightDistinct( - lstream, rstream, FastSeq("k"), FastSeq("k"), "l", "r", + lstream, + rstream, + FastSeq("k"), + FastSeq("k"), + "l", + "r", MakeTuple.ordered(IndexedSeq( GetField(Ref("l", lEltType), "v"), - GetField(Ref("r", rEltType), "v"))), - joinType) + GetField(Ref("r", rEltType), "v"), + )), + joinType, + ) def leftjoin(lstream: IR, rstream: IR): IR = join(lstream, rstream, "left") def innerjoin(lstream: IR, rstream: IR): IR = join(lstream, rstream, "inner") def lElts(xs: (Int, String)*): IR = - MakeStream(xs.toArray.map { case (a, b) => MakeStruct(IndexedSeq("k" -> I32(a), "v" -> Str(b))) }, TStream(lEltType)) + MakeStream( + xs.toArray.map { case (a, b) => MakeStruct(IndexedSeq("k" -> I32(a), "v" -> Str(b))) }, + TStream(lEltType), + ) def rElts(xs: ((Char, Any, Any, Char), String)*): IR = - MakeStream(xs.toArray.map { - case ((is, s, e, ie), v) => - val start = if (s == null) NA(TInt32) else I32(s.asInstanceOf[Int]) - val end = if (e == null) NA(TInt32) else I32(e.asInstanceOf[Int]) - val includesStart = is == '[' - val includesEnd = ie == ']' - val interval = ApplySpecial("Interval", FastSeq(), FastSeq(start, end, includesStart, includesEnd), TInterval(TInt32), 0) - MakeStruct(IndexedSeq("k" -> interval, "v" -> Str(v))) - }, TStream(rEltType)) + MakeStream( + xs.toArray.map { + case ((is, s, e, ie), v) => + val start = if (s == null) NA(TInt32) else I32(s.asInstanceOf[Int]) + val end = if (e == null) NA(TInt32) else I32(e.asInstanceOf[Int]) + val includesStart = is == '[' + val includesEnd = ie == ']' + val interval = ApplySpecial( + "Interval", + FastSeq(), + FastSeq(start, end, includesStart, includesEnd), + TInterval(TInt32), + 0, + ) + MakeStruct(IndexedSeq("k" -> interval, "v" -> Str(v))) + }, + TStream(rEltType), + ) val tests: Array[(IR, IR, IndexedSeq[Any], IndexedSeq[Any])] = Array( (lElts(), rElts(), IndexedSeq(), IndexedSeq()), - (lElts(3 -> "A"), - rElts(), - 
IndexedSeq(Row("A", null)), - IndexedSeq()), - (lElts(), - rElts(('[', 1, 2, ']') -> "B"), - IndexedSeq(), - IndexedSeq()), - (lElts(0 -> "A"), + (lElts(3 -> "A"), rElts(), IndexedSeq(Row("A", null)), IndexedSeq()), + (lElts(), rElts(('[', 1, 2, ']') -> "B"), IndexedSeq(), IndexedSeq()), + ( + lElts(0 -> "A"), rElts(('[', 0, 1, ')') -> "B"), IndexedSeq(Row("A", "B")), - IndexedSeq(Row("A", "B"))), - (lElts(0 -> "A"), - rElts(('(', 0, 1, ')') -> "B"), - IndexedSeq(Row("A", null)), - IndexedSeq()), - (lElts(0 -> "A", 2 -> "B", 3 -> "C", 4 -> "D"), - rElts(('[', 0, 2, ')') -> "a", ('(', 0, 1, ']') -> ".", ('[', 1, 4, ')') -> "b", ('[', 2, 4, ')') -> ".."), + IndexedSeq(Row("A", "B")), + ), + (lElts(0 -> "A"), rElts(('(', 0, 1, ')') -> "B"), IndexedSeq(Row("A", null)), IndexedSeq()), + ( + lElts(0 -> "A", 2 -> "B", 3 -> "C", 4 -> "D"), + rElts( + ('[', 0, 2, ')') -> "a", + ('(', 0, 1, ']') -> ".", + ('[', 1, 4, ')') -> "b", + ('[', 2, 4, ')') -> "..", + ), IndexedSeq(Row("A", "a"), Row("B", "b"), Row("C", "b"), Row("D", null)), - IndexedSeq(Row("A", "a"), Row("B", "b"), Row("C", "b"))), - (lElts(1 -> "A", 2 -> "B", 3 -> "C", 4 -> "D"), - rElts(('[', 0, null, ')') -> ".", ('(', 0, 1, ']') -> "a", ('[', 1, 4, ')') -> "b", ('[', 2, 4, ')') -> ".."), + IndexedSeq(Row("A", "a"), Row("B", "b"), Row("C", "b")), + ), + ( + lElts(1 -> "A", 2 -> "B", 3 -> "C", 4 -> "D"), + rElts( + ('[', 0, null, ')') -> ".", + ('(', 0, 1, ']') -> "a", + ('[', 1, 4, ')') -> "b", + ('[', 2, 4, ')') -> "..", + ), IndexedSeq(Row("A", "a"), Row("B", "b"), Row("C", "b"), Row("D", null)), - IndexedSeq(Row("A", "a"), Row("B", "b"), Row("C", "b"))) + IndexedSeq(Row("A", "a"), Row("B", "b"), Row("C", "b")), + ), ) for ((lstream, rstream, expectedLeft, expectedInner) <- tests) { val l = leftjoin(lstream, rstream) @@ -566,51 +768,59 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testStreamJoinOuterWithKeyRepeats() { + @Test def testStreamJoinOuterWithKeyRepeats(): Unit = { val lEltType = TStruct("k" -> TInt32, "idx_left" -> TInt32) val lRows = FastSeq( Row(1, 1), Row(1, 2), Row(1, 3), - Row(3, 4) + Row(3, 4), ) val a = ToStream( Literal( TArray(lEltType), - lRows - )) + lRows, + ) + ) val rEltType = TStruct("k" -> TInt32, "idx_right" -> TInt32) val rRows = FastSeq( Row(1, 1), Row(2, 2), - Row(4, 3) + Row(4, 3), ) val b = ToStream( Literal( TArray(rEltType), - rRows - )) + rRows, + ) + ) - val ir = StreamJoinRightDistinct(a, b, - FastSeq("k"), FastSeq("k"), - "L", "R", + val ir = StreamJoinRightDistinct( + a, + b, + FastSeq("k"), + FastSeq("k"), + "L", + "R", MakeStruct(FastSeq("left" -> Ref("L", lEltType), "right" -> Ref("R", rEltType))), - "outer") + "outer", + ) val compiled = evalStream(ir) val expected = FastSeq( - Row( Row(1, 1), Row(1, 1)), - Row( Row(1, 2), Row(1, 1)), - Row( Row(1, 3), Row(1, 1)), - Row( null, Row(2, 2)), - Row( Row(3, 4), null), - Row(null, Row(4, 3))) + Row(Row(1, 1), Row(1, 1)), + Row(Row(1, 2), Row(1, 1)), + Row(Row(1, 3), Row(1, 1)), + Row(null, Row(2, 2)), + Row(Row(3, 4), null), + Row(null, Row(4, 3)), + ) assert(compiled == expected) } - @Test def testEmitScan() { + @Test def testEmitScan(): Unit = { def a = Ref("a", TInt32) def v = Ref("v", TInt32) @@ -618,10 +828,9 @@ class EmitStreamSuite extends HailSuite { def x = Ref("x", TInt32) val tests: Array[(IR, IndexedSeq[Any])] = Array( - StreamScan(MakeStream(IndexedSeq(), TStream(TInt32)), - 9, "a", "v", a + v) -> IndexedSeq(9), - StreamScan(StreamMap(StreamRange(0, 4, 1), "x", x * x), - 1, "a", "v", a + v) -> IndexedSeq(1, 1 /*1+0*0*/ 
, 2 /*1+1*1*/ , 6 /*2+2*2*/ , 15 /*6+3*3*/) + StreamScan(MakeStream(IndexedSeq(), TStream(TInt32)), 9, "a", "v", a + v) -> IndexedSeq(9), + StreamScan(StreamMap(StreamRange(0, 4, 1), "x", x * x), 1, "a", "v", a + v) -> IndexedSeq(1, + 1 /*1+0*0*/, 2 /*1+1*1*/, 6 /*2+2*2*/, 15 /*6+3*3*/ ), ) for ((ir, v) <- tests) { assert(evalStream(ir) == v, Pretty(ctx, ir)) @@ -629,10 +838,12 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testEmitAggScan() { + @Test def testEmitAggScan(): Unit = { def assertAggScan(ir: IR, inType: Type, tests: (Any, Any)*): Unit = { - val aggregate = compileStream(LoweringPipeline.compileLowerer(false).apply(ctx, ir).asInstanceOf[IR], - PType.canonical(inType)) + val aggregate = compileStream( + LoweringPipeline.compileLowerer(false).apply(ctx, ir).asInstanceOf[IR], + PType.canonical(inType), + ) for ((inp, expected) <- tests) assert(aggregate(inp) == expected, Pretty(ctx, ir)) } @@ -641,69 +852,98 @@ class EmitStreamSuite extends HailSuite { ApplyScanOp( initArgs.toFastSeq, opArgs.toFastSeq, - AggSignature(op, - initArgs.map(_.typ), - opArgs.map(_.typ))) + AggSignature(op, initArgs.map(_.typ), opArgs.map(_.typ)), + ) val pairType = TStruct("x" -> TCall, "y" -> TInt32) val intsType = TArray(TInt32) assertAggScan( - StreamAggScan(ToStream(In(0, TArray(pairType))), + StreamAggScan( + ToStream(In(0, TArray(pairType))), "foo", GetField(Ref("foo", pairType), "y") + GetField( - scanOp(CallStats(), + scanOp( + CallStats(), IndexedSeq(I32(2)), - IndexedSeq(GetField(Ref("foo", pairType), "x")) + IndexedSeq(GetField(Ref("foo", pairType), "x")), ), - "AN") + "AN", + ), ), TArray(pairType), FastSeq( - Row(null, 1), Row(Call2(0, 0), 2), Row(Call2(0, 1), 3), Row(Call2(1, 1), 4), null, Row(null, 5) - ) -> FastSeq(1 + 0, 2 + 0, 3 + 2, 4 + 4, null, 5 + 6) + Row(null, 1), + Row(Call2(0, 0), 2), + Row(Call2(0, 1), 3), + Row(Call2(1, 1), 4), + null, + Row(null, 5), + ) -> FastSeq(1 + 0, 2 + 0, 3 + 2, 4 + 4, null, 5 + 6), ) assertAggScan( StreamAggScan( - StreamAggScan(ToStream(In(0, intsType)), + StreamAggScan( + ToStream(In(0, intsType)), "i", - scanOp(Sum(), IndexedSeq(), IndexedSeq(Ref("i", TInt32).toL))), + scanOp(Sum(), IndexedSeq(), IndexedSeq(Ref("i", TInt32).toL)), + ), "x", - scanOp(Max(), IndexedSeq(), IndexedSeq(Ref("x", TInt64))) + scanOp(Max(), IndexedSeq(), IndexedSeq(Ref("x", TInt64))), ), intsType, FastSeq(2, 5, 8, -3, 2, 2, 1, 0, 0) -> - IndexedSeq(null, 0L, 2L, 7L, 15L, 15L, 15L, 16L, 17L) + IndexedSeq(null, 0L, 2L, 7L, 15L, 15L, 15L, 16L, 17L), ) } - @Test def testEmitFromIterator() { + @Test def testEmitFromIterator(): Unit = { val intsPType = PInt32(true) val f1 = compileStreamWithIter( - StreamScan(In(0, SingleCodeEmitParamType(true, StreamSingleCodeType(true, PInt32(true), true))), + StreamScan( + In(0, SingleCodeEmitParamType(true, StreamSingleCodeType(true, PInt32(true), true))), zero = 0, - "a", "x", Ref("a", TInt32) + Ref("x", TInt32) * Ref("x", TInt32) - ), false, intsPType) + "a", + "x", + Ref("a", TInt32) + Ref("x", TInt32) * Ref("x", TInt32), + ), + false, + intsPType, + ) assert(f1((1 to 4).iterator) == IndexedSeq(0, 1, 1 + 4, 1 + 4 + 9, 1 + 4 + 9 + 16)) assert(f1(Iterator.empty) == IndexedSeq(0)) val f2 = compileStreamWithIter( StreamFlatMap( In(0, SingleCodeEmitParamType(true, StreamSingleCodeType(false, PInt32(true), true))), - "n", StreamRange(0, Ref("n", TInt32), 1) - ), false, intsPType) + "n", + StreamRange(0, Ref("n", TInt32), 1), + ), + false, + intsPType, + ) assert(f2(IndexedSeq(1, 5, 2, 9).iterator) == IndexedSeq(1, 5, 2, 
9).flatMap(0 until _)) val f3 = compileStreamWithIter( - StreamRange(0, StreamLen(In(0, SingleCodeEmitParamType(true, StreamSingleCodeType(false, PInt32(true), true)))), 1), false, intsPType) + StreamRange( + 0, + StreamLen(In( + 0, + SingleCodeEmitParamType(true, StreamSingleCodeType(false, PInt32(true), true)), + )), + 1, + ), + false, + intsPType, + ) assert(f3(IndexedSeq(1, 5, 2, 9).iterator) == IndexedSeq(0, 1, 2, 3)) assert(f3(IndexedSeq().iterator) == IndexedSeq()) } - @Test def testEmitIf() { + @Test def testEmitIf(): Unit = { def xs = MakeStream(IndexedSeq[IR](5, 3, 6), TStream(TInt32)) def ys = StreamRange(0, 4, 1) @@ -719,8 +959,9 @@ class EmitStreamSuite extends HailSuite { StreamFlatMap( MakeStream(IndexedSeq(False(), True(), False()), TStream(TBoolean)), "x", - If(Ref("x", TBoolean), xs, ys)) - -> IndexedSeq(0, 1, 2, 3, 5, 3, 6, 0, 1, 2, 3) + If(Ref("x", TBoolean), xs, ys), + ) + -> IndexedSeq(0, 1, 2, 3, 5, 3, 6, 0, 1, 2, 3), ) val lens: Array[Option[Int]] = Array(Some(3), Some(4), Some(3), None, None, None) for (((ir, v), len) <- tests zip lens) { @@ -729,71 +970,95 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testZipIfNA() { + @Test def testZipIfNA(): Unit = { - val t = PCanonicalStruct(true, "missingParam" -> PCanonicalArray(PFloat64()), + val t = PCanonicalStruct( + true, + "missingParam" -> PCanonicalArray(PFloat64()), "xs" -> PCanonicalArray(PFloat64()), - "ys" -> PCanonicalArray(PFloat64())) + "ys" -> PCanonicalArray(PFloat64()), + ) val i1 = Ref("in", t.virtualType) val ir = MakeTuple.ordered(IndexedSeq(StreamFold( StreamZip( FastSeq( - ToStream(If(IsNA(GetField(i1, "missingParam")), NA(TArray(TFloat64)), GetField(i1, "xs"))), - ToStream(GetField(i1, "ys")) + ToStream(If( + IsNA(GetField(i1, "missingParam")), + NA(TArray(TFloat64)), + GetField(i1, "xs"), + )), + ToStream(GetField(i1, "ys")), ), FastSeq("zipL", "zipR"), Ref("zipL", TFloat64) * Ref("zipR", TFloat64), - ArrayZipBehavior.AssertSameLength + ArrayZipBehavior.AssertSameLength, ), F64(0d), - "foldAcc", "foldVal", - Ref("foldAcc", TFloat64) + Ref("foldVal", TFloat64) + "foldAcc", + "foldVal", + Ref("foldAcc", TFloat64) + Ref("foldVal", TFloat64), ))) - val (Some(PTypeReferenceSingleCodeType(pt)), f) = Compile[AsmFunction2RegionLongLong](ctx, + val (Some(PTypeReferenceSingleCodeType(pt)), f) = Compile[AsmFunction2RegionLongLong]( + ctx, FastSeq(("in", SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(t)))), - FastSeq(classInfo[Region], LongInfo), LongInfo, - ir) + FastSeq(classInfo[Region], LongInfo), + LongInfo, + ir, + ) pool.scopedSmallRegion { r => - val input = t.unstagedStoreJavaObject(ctx.stateManager, Row(null, IndexedSeq(1d, 2d), IndexedSeq(3d, 4d)), r) + val input = t.unstagedStoreJavaObject( + ctx.stateManager, + Row(null, IndexedSeq(1d, 2d), IndexedSeq(3d, 4d)), + r, + ) - assert(SafeRow.read(pt, f(theHailClassLoader, ctx.fs, ctx.taskContext, r)(r, input)) == Row(null)) + assert( + SafeRow.read(pt, f(theHailClassLoader, ctx.fs, ctx.taskContext, r)(r, input)) == Row(null) + ) } } - @Test def testFold() { + @Test def testFold(): Unit = { val ints = Literal(TArray(TInt32), IndexedSeq(1, 2, 3, 4)) val strsLit = Literal(TArray(TString), IndexedSeq("one", "two", "three", "four")) - val strs = MakeStream(FastSeq(Str("one"), Str("two"), Str("three"), Str("four")), TStream(TString), true) + val strs = + MakeStream(FastSeq(Str("one"), Str("two"), Str("three"), Str("four")), TStream(TString), true) assertEvalsTo( - foldIR(ToStream(ints, requiresMemoryManagementPerElement = false), 
I32(-1)) { (acc, elt) => acc + elt }, - 9 + foldIR(ToStream(ints, requiresMemoryManagementPerElement = false), I32(-1)) { (acc, elt) => + acc + elt + }, + 9, ) assertEvalsTo( - foldIR(ToStream(strsLit, requiresMemoryManagementPerElement = false), Str("")) { (acc, elt) => invoke("concat", TString, acc, elt) }, - "onetwothreefour" + foldIR(ToStream(strsLit, requiresMemoryManagementPerElement = false), Str("")) { (acc, elt) => + invoke("concat", TString, acc, elt) + }, + "onetwothreefour", ) assertEvalsTo( - foldIR(strs, Str("")) { (acc, elt) => invoke("concat", TString, acc, elt) }, - "onetwothreefour" + foldIR(strs, Str(""))((acc, elt) => invoke("concat", TString, acc, elt)), + "onetwothreefour", ) } - @Test def testGrouped() { + @Test def testGrouped(): Unit = { // empty => empty assertEvalsTo( ToArray( mapIR( StreamGrouped( StreamRange(0, 0, 1, false), - I32(5) - )) { inner => ToArray(inner) } + I32(5), + ) + )(inner => ToArray(inner)) ), - IndexedSeq()) + IndexedSeq(), + ) // general case where stream ends in inner group assertEvalsTo( @@ -801,14 +1066,17 @@ class EmitStreamSuite extends HailSuite { mapIR( StreamGrouped( StreamRange(0, 10, 1, false), - I32(3) - )) { inner => ToArray(inner) } + I32(3), + ) + )(inner => ToArray(inner)) ), IndexedSeq( IndexedSeq(0, 1, 2), IndexedSeq(3, 4, 5), IndexedSeq(6, 7, 8), - IndexedSeq(9))) + IndexedSeq(9), + ), + ) // stream ends at end of inner group assertEvalsTo( @@ -816,12 +1084,15 @@ class EmitStreamSuite extends HailSuite { mapIR( StreamGrouped( StreamRange(0, 10, 1, false), - I32(5) - )) { inner => ToArray(inner) } + I32(5), + ) + )(inner => ToArray(inner)) ), IndexedSeq( IndexedSeq(0, 1, 2, 3, 4), - IndexedSeq(5, 6, 7, 8, 9))) + IndexedSeq(5, 6, 7, 8, 9), + ), + ) // separate regions assertEvalsTo( @@ -829,53 +1100,61 @@ class EmitStreamSuite extends HailSuite { mapIR( StreamGrouped( MakeStream((0 until 10).map(x => Str(x.toString)), TStream(TString), true), - I32(4) - )) { inner => ToArray(inner) } + I32(4), + ) + )(inner => ToArray(inner)) ), IndexedSeq( IndexedSeq("0", "1", "2", "3"), IndexedSeq("4", "5", "6", "7"), - IndexedSeq("8", "9"))) + IndexedSeq("8", "9"), + ), + ) } - @Test def testMakeStream() { + @Test def testMakeStream(): Unit = { assertEvalsTo( ToArray( MakeStream(IndexedSeq(I32(1), NA(TInt32), I32(2)), TStream(TInt32)) ), - IndexedSeq(1, null, 2) + IndexedSeq(1, null, 2), ) assertEvalsTo( ToArray( - MakeStream(IndexedSeq(Literal(TArray(TInt32), IndexedSeq(1)), NA(TArray(TInt32))), TStream(TArray(TInt32))) + MakeStream( + IndexedSeq(Literal(TArray(TInt32), IndexedSeq(1)), NA(TArray(TInt32))), + TStream(TArray(TInt32)), + ) ), - IndexedSeq(IndexedSeq(1), null) + IndexedSeq(IndexedSeq(1), null), ) } - @Test def testMultiplicity() { + @Test def testMultiplicity(): Unit = { val target = Ref("target", TStream(TInt32)) val i = Ref("i", TInt32) - for ((ir, v) <- IndexedSeq( - StreamRange(0, 10, 1) -> 0, - target -> 1, - Let(FastSeq("x" -> True()), target) -> 1, - StreamMap(target, "i", i) -> 1, - StreamMap(StreamMap(target, "i", i), "i", i * i) -> 1, - StreamFilter(target, "i", StreamFold(StreamRange(0, i, 1), 0, "a", "i", i)) -> 1, - StreamFilter(StreamRange(0, 5, 1), "i", StreamFold(target, 0, "a", "i", i)) -> 2, - StreamFlatMap(target, "i", StreamRange(0, i, 1)) -> 1, - StreamFlatMap(StreamRange(0, 5, 1), "i", target) -> 2, - StreamScan(StreamMap(target, "i", i), 0, "a", "i", i) -> 1, - StreamScan(StreamScan(target, 0, "a", "i", i), 0, "a", "i", i) -> 1 - )) { + for ( + (ir, v) <- IndexedSeq( + StreamRange(0, 10, 1) -> 0, + 
target -> 1, + Let(FastSeq("x" -> True()), target) -> 1, + StreamMap(target, "i", i) -> 1, + StreamMap(StreamMap(target, "i", i), "i", i * i) -> 1, + StreamFilter(target, "i", StreamFold(StreamRange(0, i, 1), 0, "a", "i", i)) -> 1, + StreamFilter(StreamRange(0, 5, 1), "i", StreamFold(target, 0, "a", "i", i)) -> 2, + StreamFlatMap(target, "i", StreamRange(0, i, 1)) -> 1, + StreamFlatMap(StreamRange(0, 5, 1), "i", target) -> 2, + StreamScan(StreamMap(target, "i", i), 0, "a", "i", i) -> 1, + StreamScan(StreamScan(target, 0, "a", "i", i), 0, "a", "i", i) -> 1, + ) + ) assert(StreamUtils.multiplicity(ir, "target") == v, Pretty(ctx, ir)) - } } - def assertMemoryDoesNotScaleWithStreamSize(lowSize: Int = 50, highSize: Int = 2500)(f: IR => IR): Unit = { + def assertMemoryDoesNotScaleWithStreamSize(lowSize: Int = 50, highSize: Int = 2500)(f: IR => IR) + : Unit = { val memUsed1 = ExecuteContext.scoped() { ctx => eval(f(lowSize), Env.empty, FastSeq(), None, None, false, ctx) ctx.r.pool.getHighestTotalUsage @@ -890,64 +1169,102 @@ class EmitStreamSuite extends HailSuite { throw new RuntimeException(s"memory usage scales with stream size!" + s"\n at size=$lowSize, memory=$memUsed1" + s"\n at size=$highSize, memory=$memUsed2" + - s"\n IR: ${ Pretty(ctx, f(lowSize)) }") + s"\n IR: ${Pretty(ctx, f(lowSize))}") } def sumIR(x: IR): IR = foldIR(x, 0) { case (acc, value) => acc + value } - def foldLength(x: IR): IR = sumIR(mapIR(x) { _ => I32(1) }) + def foldLength(x: IR): IR = sumIR(mapIR(x)(_ => I32(1))) def rangeStructs(size: IR): IR = mapIR(StreamRange(0, size, 1, true)) { i => - makestruct(("idx", i), ("foo", invoke("str", TString, i)), ("bigArray", ToArray(rangeIR(10000)))) + makestruct( + ("idx", i), + ("foo", invoke("str", TString, i)), + ("bigArray", ToArray(rangeIR(10000))), + ) } def filteredRangeStructs(size: IR): IR = mapIR(filterIR( StreamRange(0, size, 1, true) - ) { i => i < (size / 2).toI }) { i => - makestruct(("idx", i), ("foo2", invoke("str", TString, i)), ("bigArray2", ToArray(rangeIR(10000)))) + )(i => i < (size / 2).toI)) { i => + makestruct( + ("idx", i), + ("foo2", invoke("str", TString, i)), + ("bigArray2", ToArray(rangeIR(10000))), + ) } - @Test def testMemoryRangeFold(): Unit = { - + @Test def testMemoryRangeFold(): Unit = assertMemoryDoesNotScaleWithStreamSize() { size => - foldIR(mapIR(flatMapIR(StreamRange(0, size, 1, true)) { x => StreamRange(0, x, 1, true) }) { i => - invoke("str", TString, i) - }, I32(0)) { case (acc, value) => maxIR(acc, invoke("length", TInt32, value)) } + foldIR( + mapIR(flatMapIR(StreamRange(0, size, 1, true))(x => StreamRange(0, x, 1, true))) { i => + invoke("str", TString, i) + }, + I32(0), + ) { case (acc, value) => maxIR(acc, invoke("length", TInt32, value)) } } - } @Test def testStreamJoinMemory(): Unit = { assertMemoryDoesNotScaleWithStreamSize() { size => - sumIR(joinIR(rangeStructs(size), filteredRangeStructs(size), IndexedSeq("idx"), IndexedSeq("idx"), "inner", false) { case (l, r) => I32(1) }) + sumIR(joinIR( + rangeStructs(size), + filteredRangeStructs(size), + IndexedSeq("idx"), + IndexedSeq("idx"), + "inner", + false, + ) { case (_, _) => I32(1) }) } assertMemoryDoesNotScaleWithStreamSize() { size => - sumIR(joinIR(rangeStructs(size), filteredRangeStructs(size), IndexedSeq("idx"), IndexedSeq("idx"), "left", false) { case (l, r) => I32(1) }) + sumIR(joinIR( + rangeStructs(size), + filteredRangeStructs(size), + IndexedSeq("idx"), + IndexedSeq("idx"), + "left", + false, + ) { case (_, _) => I32(1) }) } assertMemoryDoesNotScaleWithStreamSize() { 
size => - sumIR(joinIR(rangeStructs(size), filteredRangeStructs(size), IndexedSeq("idx"), IndexedSeq("idx"), "right", false) { case (l, r) => I32(1) }) + sumIR(joinIR( + rangeStructs(size), + filteredRangeStructs(size), + IndexedSeq("idx"), + IndexedSeq("idx"), + "right", + false, + ) { case (_, _) => I32(1) }) } assertMemoryDoesNotScaleWithStreamSize() { size => - sumIR(joinIR(rangeStructs(size), filteredRangeStructs(size), IndexedSeq("idx"), IndexedSeq("idx"), "outer", false) { case (l, r) => I32(1) }) + sumIR(joinIR( + rangeStructs(size), + filteredRangeStructs(size), + IndexedSeq("idx"), + IndexedSeq("idx"), + "outer", + false, + ) { case (_, _) => I32(1) }) } } @Test def testStreamGroupedMemory(): Unit = { assertMemoryDoesNotScaleWithStreamSize() { size => - sumIR(mapIR(StreamGrouped(rangeIR(size), 100)) { stream => I32(1) }) + sumIR(mapIR(StreamGrouped(rangeIR(size), 100))(stream => I32(1))) } assertMemoryDoesNotScaleWithStreamSize() { size => - sumIR(mapIR(StreamGrouped(rangeIR(size), 100)) { stream => sumIR(stream) }) + sumIR(mapIR(StreamGrouped(rangeIR(size), 100))(stream => sumIR(stream))) } } - @Test def testStreamFilterMemory(): Unit = { + @Test def testStreamFilterMemory(): Unit = assertMemoryDoesNotScaleWithStreamSize(highSize = 100000) { size => - StreamLen(filterIR(mapIR(StreamRange(0, size, 1, true)) { i => invoke("str", TString, i) }) { str => invoke("length", TInt32, str) > (size * 9 / 10).toString.size }) + StreamLen(filterIR(mapIR(StreamRange(0, size, 1, true))(i => invoke("str", TString, i))) { + str => invoke("length", TInt32, str) > (size * 9 / 10).toString.size + }) } - } @Test def testStreamFlatMapMemory(): Unit = { assertMemoryDoesNotScaleWithStreamSize() { size => @@ -963,26 +1280,34 @@ class EmitStreamSuite extends HailSuite { } } - @Test def testGroupedFlatMapMemManagementMismatch(): Unit = { + @Test def testGroupedFlatMapMemManagementMismatch(): Unit = assertMemoryDoesNotScaleWithStreamSize() { size => - foldLength(flatMapIR(mapIR(StreamGrouped(rangeStructs(size), 16)) { x => ToArray(x) }) { a => ToStream(a, false) }) + foldLength(flatMapIR(mapIR(StreamGrouped(rangeStructs(size), 16))(x => ToArray(x))) { a => + ToStream(a, false) + }) } - } @Test def testStreamTakeWhile(): Unit = { val makestream = MakeStream(FastSeq(I32(1), I32(2), I32(0), I32(1), I32(-1)), TStream(TInt32)) - assert(evalStream(takeWhile(makestream) { r => r > 0 }) == IndexedSeq(1, 2)) + assert(evalStream(takeWhile(makestream)(r => r > 0)) == IndexedSeq(1, 2)) assert(evalStream(StreamTake(makestream, I32(3))) == IndexedSeq(1, 2, 0)) - assert(evalStream(takeWhile(makestream) { r => NA(TBoolean) }) == IndexedSeq()) - assert(evalStream(takeWhile(makestream) { r => If(r > 0, True(), NA(TBoolean)) }) == IndexedSeq(1, 2)) + assert(evalStream(takeWhile(makestream)(r => NA(TBoolean))) == IndexedSeq()) + assert(evalStream(takeWhile(makestream)(r => If(r > 0, True(), NA(TBoolean)))) == IndexedSeq( + 1, + 2, + )) } @Test def testStreamDropWhile(): Unit = { val makestream = MakeStream(FastSeq(I32(1), I32(2), I32(0), I32(1), I32(-1)), TStream(TInt32)) - assert(evalStream(dropWhile(makestream) { r => r > 0 }) == IndexedSeq(0, 1, -1)) + assert(evalStream(dropWhile(makestream)(r => r > 0)) == IndexedSeq(0, 1, -1)) assert(evalStream(StreamDrop(makestream, I32(3))) == IndexedSeq(1, -1)) - assert(evalStream(dropWhile(makestream) { r => NA(TBoolean) }) == IndexedSeq(1, 2, 0, 1, -1)) - assert(evalStream(dropWhile(makestream) { r => If(r > 0, True(), NA(TBoolean)) }) == IndexedSeq(0, 1, -1)) + 
assert(evalStream(dropWhile(makestream)(r => NA(TBoolean))) == IndexedSeq(1, 2, 0, 1, -1)) + assert(evalStream(dropWhile(makestream)(r => If(r > 0, True(), NA(TBoolean)))) == IndexedSeq( + 0, + 1, + -1, + )) } @@ -996,11 +1321,11 @@ class EmitStreamSuite extends HailSuite { } assertMemoryDoesNotScaleWithStreamSize() { size => - foldLength(dropWhile(rangeStructs(size)) { elt => GetField(elt, "idx") < (size / 2).toI }) + foldLength(dropWhile(rangeStructs(size))(elt => GetField(elt, "idx") < (size / 2).toI)) } assertMemoryDoesNotScaleWithStreamSize() { size => - foldLength(takeWhile(rangeStructs(size)) { elt => GetField(elt, "idx") < (size / 2).toI }) + foldLength(takeWhile(rangeStructs(size))(elt => GetField(elt, "idx") < (size / 2).toI)) } } @@ -1008,4 +1333,64 @@ class EmitStreamSuite extends HailSuite { assert(evalStream(takeWhile(iota(0, 2))(elt => elt < 10)) == IndexedSeq(0, 2, 4, 6, 8)) assert(evalStream(StreamTake(iota(5, -5), 3)) == IndexedSeq(5, 0, -5)) } + + @Test def testStreamIntervalJoin(): Unit = { + val keyStream = mapIR(StreamRange(0, 9, 1, requiresMemoryManagementPerElement = true)) { i => + MakeStruct(FastSeq("i" -> i)) + } + val kType = TIterable.elementType(keyStream.typ).asInstanceOf[TStruct] + val rightElemType = TStruct("interval" -> TInterval(TInt32)) + + val intervals: IndexedSeq[Interval] = + for { + (start, end, includesStart, includesEnd) <- FastSeq( + (1, 6, true, false), + (2, 2, false, false), + (3, 5, true, true), + (4, 6, true, true), + (6, 7, false, true), + ) + } yield Interval( + IntervalEndpoint(start, if (includesStart) -1 else 1), + IntervalEndpoint(end, if (includesEnd) 1 else -1), + ) + + val join = + ToArray( + StreamLeftIntervalJoin( + keyStream, + ToStream( + Literal(TArray(rightElemType), intervals.map(Row(_))), + requiresMemoryManagementPerElement = true, + ), + kType.fieldNames.head, + "interval", + "lname", + "rname", + InsertFields( + Ref("lname", kType), + FastSeq( + "intervals" -> mapArray(Ref("rname", TArray(rightElemType))) { + GetField(_, "interval") + } + ), + ), + ) + ) + + assertEvalsTo( + join, + FastSeq( + Row(0, FastSeq()), + Row(1, FastSeq(intervals(0))), + Row(2, FastSeq(intervals(0))), + Row(3, FastSeq(intervals(2), intervals(0))), + Row(4, FastSeq(intervals(2), intervals(0), intervals(3))), + Row(5, FastSeq(intervals(2), intervals(0), intervals(3))), + Row(6, FastSeq(intervals(3))), + Row(7, FastSeq(intervals(4))), + Row(8, FastSeq()), + ), + ) + } } diff --git a/hail/src/test/scala/is/hail/expr/ir/EncodedLiteralSuite.scala b/hail/src/test/scala/is/hail/expr/ir/EncodedLiteralSuite.scala index f74cceeaa6d..667f684d50b 100644 --- a/hail/src/test/scala/is/hail/expr/ir/EncodedLiteralSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/EncodedLiteralSuite.scala @@ -1,6 +1,7 @@ package is.hail.expr.ir import is.hail.HailSuite + import org.testng.annotations.Test class EncodedLiteralSuite extends HailSuite { diff --git a/hail/src/test/scala/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala b/hail/src/test/scala/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala index 9eb1c93bad7..12a76d10bda 100644 --- a/hail/src/test/scala/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/ExtractIntervalFiltersSuite.scala @@ -1,10 +1,11 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.rvd.RVDPartitioner import is.hail.types.virtual._ import is.hail.utils.{FastSeq, Interval, IntervalEndpoint} import is.hail.variant.{Locus, ReferenceGenome} -import is.hail.{ExecStrategy, 
HailSuite} + import org.apache.spark.sql.Row import org.testng.annotations.Test @@ -24,11 +25,13 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => val fullKeyRefs = Array( SelectFields(structRef, structRefKey), - MakeStruct(FastSeq("y" -> GetField(structRef, "y"), "z" -> GetField(structRef, "z")))) + MakeStruct(FastSeq("y" -> GetField(structRef, "y"), "z" -> GetField(structRef, "z"))), + ) val prefixKeyRefs = Array( SelectFields(structRef, FastSeq("y")), - MakeStruct(FastSeq("y" -> GetField(structRef, "y")))) + MakeStruct(FastSeq("y" -> GetField(structRef, "y"))), + ) def wrappedIntervalEndpoint(x: Any, sign: Int) = IntervalEndpoint(Row(x), sign) @@ -46,8 +49,20 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => def and(l: IR, r: IR): IR = invoke("land", TBoolean, l, r) def not(b: IR): IR = ApplyUnaryPrimOp(Bang, b) - def check(filter: IR, rowRef: Ref, key: IR, probes: IndexedSeq[Row], residualFilter: IR, trueIntervals: Seq[Interval]): Unit = { - val result = ExtractIntervalFilters.extractPartitionFilters(ctx, filter, rowRef, key.typ.asInstanceOf[TStruct].fieldNames) + def check( + filter: IR, + rowRef: Ref, + key: IR, + probes: IndexedSeq[Row], + residualFilter: IR, + trueIntervals: Seq[Interval], + ): Unit = { + val result = ExtractIntervalFilters.extractPartitionFilters( + ctx, + filter, + rowRef, + key.typ.asInstanceOf[TStruct].fieldNames, + ) if (result.isEmpty) { assert(trueIntervals == FastSeq(Interval(Row(), Row(), true, true))) return @@ -60,9 +75,8 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => val irIntervals: IR = Literal( TArray(RVDPartitioner.intervalIRRepresentation(keyType)), - trueIntervals.map { i => - RVDPartitioner.intervalToIRRepresentation(i, keyType.size) - }) + trueIntervals.map(i => RVDPartitioner.intervalToIRRepresentation(i, keyType.size)), + ) val filterIsTrue = Coalesce(FastSeq(filter, False())) val residualIsTrue = Coalesce(FastSeq(residualFilter, False())) @@ -74,9 +88,12 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => True(), accRef.name, rowRef.name, - ApplyComparisonOp(EQ(TBoolean), + ApplyComparisonOp( + EQ(TBoolean), filterIsTrue, - invoke("land", TBoolean, keyInIntervals, residualIsTrue))) + invoke("land", TBoolean, keyInIntervals, residualIsTrue), + ), + ) assertEvalsTo(testIR, true)(ExecStrategy.compileOnly) } @@ -91,7 +108,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => naIntervals: Seq[Interval], trueResidual: IR = True(), falseResidual: IR = True(), - naResidual: IR = True() + naResidual: IR = True(), ): Unit = { check(filter, rowRef, key, probes, trueResidual, trueIntervals) check(ApplyUnaryPrimOp(Bang, filter), rowRef, key, probes, falseResidual, falseIntervals) @@ -100,128 +117,200 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => @Test def testIsNA(): Unit = { val testRows = FastSeq( - Row(0, 0, true), - Row(0, null, true)) - checkAll(IsNA(k1), ref1, k1Full, testRows, + Row(0, 0, true), + Row(0, null, true), + ) + checkAll( + IsNA(k1), + ref1, + k1Full, + testRows, FastSeq(Interval(Row(null), Row(), true, true)), FastSeq(Interval(Row(), Row(null), true, false)), - FastSeq()) + FastSeq(), + ) } - @Test def testKeyComparison() { + @Test def testKeyComparison(): Unit = { def check( op: ComparisonOp[Boolean], point: IR, trueIntervals: IndexedSeq[Interval], falseIntervals: IndexedSeq[Interval], - naIntervals: IndexedSeq[Interval] - ) { + naIntervals: IndexedSeq[Interval], + ): Unit = { val testRows = FastSeq( Row(0, -1, true), - Row(0, 0, true), - 
Row(0, 1, true), - Row(0, null, true)) - - checkAll(ApplyComparisonOp(op, k1, point), ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals) - checkAll(ApplyComparisonOp(ComparisonOp.swap(op), point, k1), ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals) - checkAll(ApplyComparisonOp(ComparisonOp.negate(op), k1, point), ref1, k1Full, testRows, falseIntervals, trueIntervals, naIntervals) - checkAll(ApplyComparisonOp(ComparisonOp.swap(ComparisonOp.negate(op)), point, k1), ref1, k1Full, testRows, falseIntervals, trueIntervals, naIntervals) + Row(0, 0, true), + Row(0, 1, true), + Row(0, null, true), + ) + + checkAll( + ApplyComparisonOp(op, k1, point), + ref1, + k1Full, + testRows, + trueIntervals, + falseIntervals, + naIntervals, + ) + checkAll( + ApplyComparisonOp(ComparisonOp.swap(op), point, k1), + ref1, + k1Full, + testRows, + trueIntervals, + falseIntervals, + naIntervals, + ) + checkAll( + ApplyComparisonOp(ComparisonOp.negate(op), k1, point), + ref1, + k1Full, + testRows, + falseIntervals, + trueIntervals, + naIntervals, + ) + checkAll( + ApplyComparisonOp(ComparisonOp.swap(ComparisonOp.negate(op)), point, k1), + ref1, + k1Full, + testRows, + falseIntervals, + trueIntervals, + naIntervals, + ) } - check(LT(TInt32), I32(0), + check( + LT(TInt32), + I32(0), FastSeq(Interval(Row(), Row(0), true, false)), FastSeq(Interval(Row(0), Row(null), true, false)), - FastSeq(Interval(Row(null), Row(), true, true))) - check(GT(TInt32), I32(0), + FastSeq(Interval(Row(null), Row(), true, true)), + ) + check( + GT(TInt32), + I32(0), FastSeq(Interval(Row(0), Row(null), false, false)), FastSeq(Interval(Row(), Row(0), true, true)), - FastSeq(Interval(Row(null), Row(), true, true))) - check(EQ(TInt32), I32(0), + FastSeq(Interval(Row(null), Row(), true, true)), + ) + check( + EQ(TInt32), + I32(0), FastSeq(Interval(Row(0), Row(0), true, true)), FastSeq( Interval(Row(), Row(0), true, false), - Interval(Row(0), Row(null), false, false)), - FastSeq(Interval(Row(null), Row(), true, true))) + Interval(Row(0), Row(null), false, false), + ), + FastSeq(Interval(Row(null), Row(), true, true)), + ) // These are never true (always missing), extracts the empty set of intervals - check(LT(TInt32), NA(TInt32), - FastSeq(), - FastSeq(), - FastSeq(Interval(Row(), Row(), true, true))) - check(GT(TInt32), NA(TInt32), - FastSeq(), - FastSeq(), - FastSeq(Interval(Row(), Row(), true, true))) - check(EQ(TInt32), NA(TInt32), - FastSeq(), - FastSeq(), - FastSeq(Interval(Row(), Row(), true, true))) + check(LT(TInt32), NA(TInt32), FastSeq(), FastSeq(), FastSeq(Interval(Row(), Row(), true, true))) + check(GT(TInt32), NA(TInt32), FastSeq(), FastSeq(), FastSeq(Interval(Row(), Row(), true, true))) + check(EQ(TInt32), NA(TInt32), FastSeq(), FastSeq(), FastSeq(Interval(Row(), Row(), true, true))) - check(EQWithNA(TInt32), I32(0), + check( + EQWithNA(TInt32), + I32(0), FastSeq(Interval(Row(0), Row(0), true, true)), FastSeq( Interval(Row(), Row(0), true, false), - Interval(Row(0), Row(), false, true)), - FastSeq()) - check(EQWithNA(TInt32), NA(TInt32), + Interval(Row(0), Row(), false, true), + ), + FastSeq(), + ) + check( + EQWithNA(TInt32), + NA(TInt32), FastSeq(Interval(Row(null), Row(), true, true)), FastSeq(Interval(Row(), Row(null), true, false)), - FastSeq()) + FastSeq(), + ) - assert(ExtractIntervalFilters.extractPartitionFilters(ctx, ApplyComparisonOp(Compare(TInt32), I32(0), k1), ref1, ref1Key).isEmpty) + assert(ExtractIntervalFilters.extractPartitionFilters( + ctx, + ApplyComparisonOp(Compare(TInt32), 
I32(0), k1), + ref1, + ref1Key, + ).isEmpty) } - @Test def testLiteralContains() { - def check(node: IR, trueIntervals: IndexedSeq[Interval], falseIntervals: IndexedSeq[Interval], naIntervals: IndexedSeq[Interval]) { + @Test def testLiteralContains(): Unit = { + def check( + node: IR, + trueIntervals: IndexedSeq[Interval], + falseIntervals: IndexedSeq[Interval], + naIntervals: IndexedSeq[Interval], + ): Unit = { val testRows = FastSeq( Row(0, 1, true), Row(0, 5, true), Row(0, 10, true), - Row(0, null, true)) + Row(0, null, true), + ) checkAll(node, ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals) } - for { + for { lit <- Array( Literal(TSet(TInt32), Set(null, 10, 1)), Literal(TArray(TInt32), FastSeq(10, 1, null)), - Literal(TDict(TInt32, TString), Map(1 -> "foo", (null, "bar"), 10 -> "baz"))) + Literal(TDict(TInt32, TString), Map(1 -> "foo", (null, "bar"), 10 -> "baz")), + ) } { - check(invoke("contains", TBoolean, lit, k1), + check( + invoke("contains", TBoolean, lit, k1), FastSeq( Interval(Row(1), Row(1), true, true), Interval(Row(10), Row(10), true, true), - Interval(Row(null), Row(), true, true)), + Interval(Row(null), Row(), true, true), + ), FastSeq( Interval(Row(), Row(1), true, false), Interval(Row(1), Row(10), false, false), - Interval(Row(10), Row(null), false, false)), - FastSeq()) + Interval(Row(10), Row(null), false, false), + ), + FastSeq(), + ) } - for { + for { lit <- Array( Literal(TSet(TInt32), Set(10, 1)), Literal(TArray(TInt32), FastSeq(10, 1)), - Literal(TDict(TInt32, TString), Map(1 -> "foo", 10 -> "baz"))) + Literal(TDict(TInt32, TString), Map(1 -> "foo", 10 -> "baz")), + ) } { check( invoke("contains", TBoolean, lit, k1), FastSeq( Interval(Row(1), Row(1), true, true), - Interval(Row(10), Row(10), true, true)), + Interval(Row(10), Row(10), true, true), + ), FastSeq( Interval(Row(), Row(1), true, false), Interval(Row(1), Row(10), false, false), - Interval(Row(10), Row(), false, true)), - FastSeq()) + Interval(Row(10), Row(), false, true), + ), + FastSeq(), + ) } } - @Test def testLiteralContainsStruct() { + @Test def testLiteralContainsStruct(): Unit = { hc // force initialization - def check(node: IR, trueIntervals: IndexedSeq[Interval], falseIntervals: IndexedSeq[Interval], naIntervals: IndexedSeq[Interval]) { + def check( + node: IR, + trueIntervals: IndexedSeq[Interval], + falseIntervals: IndexedSeq[Interval], + naIntervals: IndexedSeq[Interval], + ): Unit = { val testRows = FastSeq( Row(0, 1, 2), Row(0, 3, 4), @@ -229,28 +318,45 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => Row(0, 5, null), Row(0, 5, 5), Row(0, null, 1), - Row(0, null, null)) - checkAll(node, structRef, fullKeyRefs(0), testRows, trueIntervals, falseIntervals, naIntervals) + Row(0, null, null), + ) + checkAll( + node, + structRef, + fullKeyRefs(0), + testRows, + trueIntervals, + falseIntervals, + naIntervals, + ) } for { lit <- Array( Literal(TSet(structT1), Set(Row(1, 2), Row(3, 4), Row(3, null))), Literal(TArray(structT1), FastSeq(Row(3, 4), Row(1, 2), Row(3, null))), - Literal(TDict(structT1, TString), Map(Row(1, 2) -> "foo", Row(3, 4) -> "bar", Row(3, null) -> "baz"))) + Literal( + TDict(structT1, TString), + Map(Row(1, 2) -> "foo", Row(3, 4) -> "bar", Row(3, null) -> "baz"), + ), + ) } { for (k <- fullKeyRefs) { - check(invoke("contains", TBoolean, lit, k), + check( + invoke("contains", TBoolean, lit, k), IndexedSeq( Interval(Row(1, 2), Row(1, 2), true, true), Interval(Row(3, 4), Row(3, 4), true, true), - Interval(Row(3, null), Row(3, null), true, 
true)), + Interval(Row(3, null), Row(3, null), true, true), + ), IndexedSeq( Interval(Row(), Row(1, 2), true, false), Interval(Row(1, 2), Row(3, 4), false, false), Interval(Row(3, 4), Row(3, null), false, false), - Interval(Row(3, null), Row(), false, true)), - IndexedSeq()) + Interval(Row(3, null), Row(), false, true), + ), + IndexedSeq(), + ) } } @@ -258,46 +364,63 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => lit <- Array( Literal(TSet(structT2), Set(Row(1), Row(3), Row(null))), Literal(TArray(structT2), FastSeq(Row(3), Row(null), Row(1))), - Literal(TDict(structT2, TString), Map(Row(1) -> "foo", Row(null) -> "baz", Row(3) -> "bar"))) + Literal(TDict(structT2, TString), Map(Row(1) -> "foo", Row(null) -> "baz", Row(3) -> "bar")), + ) } { for (k <- prefixKeyRefs) { - check(invoke("contains", TBoolean, lit, k), + check( + invoke("contains", TBoolean, lit, k), IndexedSeq( Interval(Row(1), Row(1), true, true), Interval(Row(3), Row(3), true, true), - Interval(Row(null), Row(), true, true)), + Interval(Row(null), Row(), true, true), + ), IndexedSeq( Interval(Row(), Row(1), true, false), Interval(Row(1), Row(3), false, false), - Interval(Row(3), Row(null), false, false)), - IndexedSeq()) + Interval(Row(3), Row(null), false, false), + ), + IndexedSeq(), + ) } } } - @Test def testIntervalContains() { + @Test def testIntervalContains(): Unit = { val interval = Interval(1, 5, false, true) val testRows = FastSeq( Row(0, 0, true), Row(0, 1, true), Row(0, 5, true), Row(0, 10, true), - Row(0, null, true)) + Row(0, null, true), + ) val ir = invoke("contains", TBoolean, Literal(TInterval(TInt32), interval), k1) - checkAll(ir, ref1, k1Full, testRows, + checkAll( + ir, + ref1, + k1Full, + testRows, FastSeq(Interval(Row(1), Row(5), false, true)), FastSeq( Interval(Row(), Row(1), true, true), - Interval(Row(5), Row(null), false, false)), - FastSeq(Interval(Row(null), Row(), true, true))) + Interval(Row(5), Row(null), false, false), + ), + FastSeq(Interval(Row(null), Row(), true, true)), + ) } - @Test def testIntervalContainsStruct() { + @Test def testIntervalContainsStruct(): Unit = { val fullInterval = Interval(Row(1, 1), Row(2, 2), false, true) val prefixInterval = Interval(Row(1), Row(2), false, true) - def check(node: IR, trueIntervals: IndexedSeq[Interval], falseIntervals: IndexedSeq[Interval], naIntervals: IndexedSeq[Interval]) { + def check( + node: IR, + trueIntervals: IndexedSeq[Interval], + falseIntervals: IndexedSeq[Interval], + naIntervals: IndexedSeq[Interval], + ): Unit = { val testRows = FastSeq( Row(0, null, 0), Row(0, 0, 0), @@ -309,30 +432,45 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => Row(0, 2, 2), Row(0, 2, null), Row(0, 3, 0), - Row(0, null, null)) - checkAll(node, structRef, fullKeyRefs(0), testRows, trueIntervals, falseIntervals, naIntervals) + Row(0, null, null), + ) + checkAll( + node, + structRef, + fullKeyRefs(0), + testRows, + trueIntervals, + falseIntervals, + naIntervals, + ) } for (k <- fullKeyRefs) { - check(invoke("contains", TBoolean, Literal(TInterval(structT1), fullInterval), k), + check( + invoke("contains", TBoolean, Literal(TInterval(structT1), fullInterval), k), FastSeq(fullInterval), FastSeq( Interval(Row(), Row(1, 1), true, true), - Interval(Row(2, 2), Row(), false, true)), - FastSeq()) + Interval(Row(2, 2), Row(), false, true), + ), + FastSeq(), + ) } for (k <- prefixKeyRefs) { - check(invoke("contains", TBoolean, Literal(TInterval(structT2), prefixInterval), k), + check( + invoke("contains", TBoolean, 
Literal(TInterval(structT2), prefixInterval), k), FastSeq(prefixInterval), FastSeq( Interval(Row(), Row(1), true, true), - Interval(Row(2), Row(), false, true)), - FastSeq()) + Interval(Row(2), Row(), false, true), + ), + FastSeq(), + ) } } - @Test def testLocusContigComparison() { + @Test def testLocusContigComparison(): Unit = { hc // force initialization val ref = Ref("foo", TStruct("x" -> TLocus(ReferenceGenome.GRCh38))) val k = GetField(ref, "x") @@ -345,20 +483,27 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => Row(Locus("chr2", 1)), Row(Locus("chr2", 1000)), Row(Locus("chr3", 5)), - Row(null)) + Row(null), + ) val trueIntervals = FastSeq( - Interval(Row(Locus("chr2", 1)), Row(Locus("chr2", grch38.contigLength("chr2"))), true, false)) + Interval(Row(Locus("chr2", 1)), Row(Locus("chr2", grch38.contigLength("chr2"))), true, false) + ) val falseIntervals = FastSeq( Interval(Row(), Row(Locus("chr2", 1)), true, false), - Interval(Row(Locus("chr2", grch38.contigLength("chr2"))), Row(null), true, false)) + Interval(Row(Locus("chr2", grch38.contigLength("chr2"))), Row(null), true, false), + ) val naIntervals = FastSeq(Interval(Row(null), Row(), true, true)) checkAll(ir1, ref, ref, testRows, trueIntervals, falseIntervals, naIntervals) checkAll(ir2, ref, ref, testRows, trueIntervals, falseIntervals, naIntervals) + + val ir3 = neq(Str("chr2"), invoke("contig", TString, k)) + checkAll(ir3, ref, ref, testRows, falseIntervals, trueIntervals, naIntervals) + checkAll(not(ir1), ref, ref, testRows, falseIntervals, trueIntervals, naIntervals) } - @Test def testLocusPositionComparison() { + @Test def testLocusPositionComparison(): Unit = { hc // force initialization val ref = Ref("foo", TStruct("x" -> TLocus(ReferenceGenome.GRCh38))) val k = GetField(ref, "x") @@ -369,10 +514,12 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => point: Int, truePosIntervals: IndexedSeq[Interval], falsePosIntervals: IndexedSeq[Interval], - naPosIntervals: IndexedSeq[Interval] - ) { - val trueIntervals = ExtractIntervalFilters.liftPosIntervalsToLocus(truePosIntervals, grch38, ctx) - val falseIntervals = ExtractIntervalFilters.liftPosIntervalsToLocus(falsePosIntervals, grch38, ctx) + naPosIntervals: IndexedSeq[Interval], + ): Unit = { + val trueIntervals = + ExtractIntervalFilters.liftPosIntervalsToLocus(truePosIntervals, grch38, ctx) + val falseIntervals = + ExtractIntervalFilters.liftPosIntervalsToLocus(falsePosIntervals, grch38, ctx) val naIntervals = ExtractIntervalFilters.liftPosIntervalsToLocus(naPosIntervals, grch38, ctx) val testRows = FastSeq( @@ -384,69 +531,139 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => Row(Locus("chr2", 5)), Row(Locus("chr2", 100)), Row(Locus("chr3", 105)), - Row(null)) + Row(null), + ) - checkAll(ApplyComparisonOp(op, pos, I32(point)), ref, ref, testRows, trueIntervals, falseIntervals, naIntervals) - checkAll(ApplyComparisonOp(ComparisonOp.swap(op), I32(point), pos), ref, ref, testRows, trueIntervals, falseIntervals, naIntervals) - checkAll(ApplyComparisonOp(ComparisonOp.negate(op), pos, I32(point)), ref, ref, testRows, falseIntervals, trueIntervals, naIntervals) - checkAll(ApplyComparisonOp(ComparisonOp.swap(ComparisonOp.negate(op)), I32(point), pos), ref, ref, testRows, falseIntervals, trueIntervals, naIntervals) + checkAll( + ApplyComparisonOp(op, pos, I32(point)), + ref, + ref, + testRows, + trueIntervals, + falseIntervals, + naIntervals, + ) + checkAll( + ApplyComparisonOp(ComparisonOp.swap(op), I32(point), pos), + ref, + ref, + 
testRows, + trueIntervals, + falseIntervals, + naIntervals, + ) + checkAll( + ApplyComparisonOp(ComparisonOp.negate(op), pos, I32(point)), + ref, + ref, + testRows, + falseIntervals, + trueIntervals, + naIntervals, + ) + checkAll( + ApplyComparisonOp(ComparisonOp.swap(ComparisonOp.negate(op)), I32(point), pos), + ref, + ref, + testRows, + falseIntervals, + trueIntervals, + naIntervals, + ) } - check(GT(TInt32), 100, + check( + GT(TInt32), + 100, FastSeq(Interval(Row(100), Row(null), false, false)), FastSeq(Interval(Row(), Row(100), true, true)), - FastSeq(Interval(Row(null), Row(), true, true))) - check(GT(TInt32), -1000, - FastSeq(Interval(Row(1), Row(null), true, false)), - FastSeq(), - FastSeq(Interval(Row(null), Row(), true, true))) + FastSeq(Interval(Row(null), Row(), true, true)), + ) + check( + GT(TInt32), + -1000, + FastSeq(Interval(Row(1), Row(null), true, false)), + FastSeq(), + FastSeq(Interval(Row(null), Row(), true, true)), + ) - check(LT(TInt32), 100, + check( + LT(TInt32), + 100, FastSeq(Interval(Row(), Row(100), true, false)), FastSeq(Interval(Row(100), Row(null), true, false)), - FastSeq(Interval(Row(null), Row(), true, true))) - check(LT(TInt32), -1000, + FastSeq(Interval(Row(null), Row(), true, true)), + ) + check( + LT(TInt32), + -1000, FastSeq(), FastSeq(Interval(Row(), Row(null), true, false)), - FastSeq(Interval(Row(null), Row(), true, true))) + FastSeq(Interval(Row(null), Row(), true, true)), + ) - check(EQ(TInt32), 100, + check( + EQ(TInt32), + 100, FastSeq(Interval(Row(100), Row(100), true, true)), FastSeq( Interval(Row(), Row(100), true, false), - Interval(Row(100), Row(null), false, false)), - FastSeq(Interval(Row(null), Row(), true, true))) - check(EQ(TInt32), -1000, + Interval(Row(100), Row(null), false, false), + ), + FastSeq(Interval(Row(null), Row(), true, true)), + ) + check( + EQ(TInt32), + -1000, FastSeq(), FastSeq(Interval(Row(), Row(null), true, false)), - FastSeq(Interval(Row(null), Row(), true, true))) + FastSeq(Interval(Row(null), Row(), true, true)), + ) - check(EQWithNA(TInt32), 100, + check( + EQWithNA(TInt32), + 100, FastSeq(Interval(Row(100), Row(100), true, true)), FastSeq( Interval(Row(), Row(100), true, false), - Interval(Row(100), Row(), false, true)), - FastSeq()) - check(EQWithNA(TInt32), -1000, + Interval(Row(100), Row(), false, true), + ), + FastSeq(), + ) + check( + EQWithNA(TInt32), + -1000, FastSeq(), FastSeq(Interval(Row(), Row(), true, true)), - FastSeq()) + FastSeq(), + ) - assert(ExtractIntervalFilters.extractPartitionFilters(ctx, ApplyComparisonOp(Compare(TInt32), I32(0), pos), ref, ref1Key).isEmpty) + assert(ExtractIntervalFilters.extractPartitionFilters( + ctx, + ApplyComparisonOp(Compare(TInt32), I32(0), pos), + ref, + ref1Key, + ).isEmpty) } - @Test def testLocusContigContains() { + @Test def testLocusContigContains(): Unit = { hc // force initialization val ref = Ref("foo", TStruct("x" -> TLocus(ReferenceGenome.GRCh38))) val k = GetField(ref, "x") val contig = invoke("contig", TString, k) - def check(node: IR, trueIntervals: IndexedSeq[Interval], falseIntervals: IndexedSeq[Interval], naIntervals: IndexedSeq[Interval]) { + def check( + node: IR, + trueIntervals: IndexedSeq[Interval], + falseIntervals: IndexedSeq[Interval], + naIntervals: IndexedSeq[Interval], + ): Unit = { val testRows = FastSeq( Row(Locus("chr1", 5)), Row(Locus("chr2", 1)), Row(Locus("chr10", 5)), - Row(null)) + Row(null), + ) checkAll(node, ref, ref, testRows, trueIntervals, falseIntervals, naIntervals) } @@ -454,80 +671,120 @@ class 
ExtractIntervalFiltersSuite extends HailSuite { outer => lit <- Array( Literal(TSet(TString), Set("chr10", "chr1", null, "foo")), Literal(TArray(TString), FastSeq("foo", "chr10", null, "chr1")), - Literal(TDict(TString, TString), Map("chr1" -> "foo", "chr10" -> "bar", "foo" -> "baz", (null, "quux")))) + Literal( + TDict(TString, TString), + Map("chr1" -> "foo", "chr10" -> "bar", "foo" -> "baz", (null, "quux")), + ), + ) } { - check(invoke("contains", TBoolean, lit, contig), + check( + invoke("contains", TBoolean, lit, contig), FastSeq( Interval( Row(Locus("chr1", 1)), Row(Locus("chr1", grch38.contigLength("chr1"))), - true, false), + true, + false, + ), Interval( Row(Locus("chr10", 1)), Row(Locus("chr10", grch38.contigLength("chr10"))), - true, false), - Interval(Row(null), Row(), true, true)), + true, + false, + ), + Interval(Row(null), Row(), true, true), + ), FastSeq( Interval( Row(), Row(Locus("chr1", 1)), - true, false), + true, + false, + ), Interval( Row(Locus("chr1", grch38.contigLength("chr1"))), Row(Locus("chr10", 1)), - true, false), + true, + false, + ), Interval( Row(Locus("chr10", grch38.contigLength("chr10"))), Row(null), - true, false)), - FastSeq()) + true, + false, + ), + ), + FastSeq(), + ) } for { lit <- Array( Literal(TSet(TString), Set("chr10", "chr1", "foo")), Literal(TArray(TString), FastSeq("foo", "chr10", "chr1")), - Literal(TDict(TString, TString), Map("chr1" -> "foo", "chr10" -> "bar", "foo" -> "baz"))) + Literal(TDict(TString, TString), Map("chr1" -> "foo", "chr10" -> "bar", "foo" -> "baz")), + ) } { - check(invoke("contains", TBoolean, lit, contig), + check( + invoke("contains", TBoolean, lit, contig), FastSeq( Interval( Row(Locus("chr1", 1)), Row(Locus("chr1", grch38.contigLength("chr1"))), - true, false), + true, + false, + ), Interval( Row(Locus("chr10", 1)), Row(Locus("chr10", grch38.contigLength("chr10"))), - true, false)), + true, + false, + ), + ), FastSeq( Interval( Row(), Row(Locus("chr1", 1)), - true, false), + true, + false, + ), Interval( Row(Locus("chr1", grch38.contigLength("chr1"))), Row(Locus("chr10", 1)), - true, false), + true, + false, + ), Interval( Row(Locus("chr10", grch38.contigLength("chr10"))), Row(), - true, true)), - FastSeq()) + true, + true, + ), + ), + FastSeq(), + ) } } - @Test def testIntervalListFold() { + @Test def testIntervalListFold(): Unit = { val inIntervals = FastSeq( Interval(0, 10, true, false), Interval(20, 25, true, false), - Interval(-10, 5, true, false)) + Interval(-10, 5, true, false), + ) val inIntervalsWithNull = FastSeq( Interval(0, 10, true, false), null, Interval(20, 25, true, false), - Interval(-10, 5, true, false)) + Interval(-10, 5, true, false), + ) - def check(node: IR, trueIntervals: IndexedSeq[Interval], falseIntervals: IndexedSeq[Interval], naIntervals: IndexedSeq[Interval]) { + def check( + node: IR, + trueIntervals: IndexedSeq[Interval], + falseIntervals: IndexedSeq[Interval], + naIntervals: IndexedSeq[Interval], + ): Unit = { val testRows = FastSeq( Row(0, -15, true), Row(0, -10, true), @@ -540,7 +797,8 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => Row(0, 22, true), Row(0, 25, true), Row(0, 30, true), - Row(0, null, true)) + Row(0, null, true), + ) checkAll(node, ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals) } @@ -549,35 +807,55 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => False(), "acc", "elt", - invoke("lor", TBoolean, + invoke( + "lor", + TBoolean, Ref("acc", TBoolean), - invoke("contains", TBoolean, Ref("elt", TInterval(TInt32)), 
k1))) + invoke("contains", TBoolean, Ref("elt", TInterval(TInt32)), k1), + ), + ) - check(containsKey(inIntervals), + check( + containsKey(inIntervals), FastSeq( Interval(Row(-10), Row(10), true, false), - Interval(Row(20), Row(25), true, false)), + Interval(Row(20), Row(25), true, false), + ), FastSeq( Interval(Row(), Row(-10), true, false), Interval(Row(10), Row(20), true, false), - Interval(Row(25), Row(null), true, false)), - FastSeq(Interval(Row(null), Row(), true, true))) + Interval(Row(25), Row(null), true, false), + ), + FastSeq(Interval(Row(null), Row(), true, true)), + ) // Whenever the previous would be false, this is instead missing, because of the null // In particular, it is never false, so notIR2 filters everything - check(containsKey(inIntervalsWithNull), + check( + containsKey(inIntervalsWithNull), FastSeq( Interval(Row(-10), Row(10), true, false), - Interval(Row(20), Row(25), true, false)), + Interval(Row(20), Row(25), true, false), + ), FastSeq(), FastSeq( Interval(Row(), Row(-10), true, false), Interval(Row(10), Row(20), true, false), - Interval(Row(25), Row(), true, true))) + Interval(Row(25), Row(), true, true), + ), + ) } - @Test def testDisjunction() { - def check(node: IR, trueIntervals: IndexedSeq[Interval], falseIntervals: IndexedSeq[Interval], naIntervals: IndexedSeq[Interval], trueResidual: IR = True(), falseResidual: IR = True(), naResidual: IR = True()) { + @Test def testDisjunction(): Unit = { + def check( + node: IR, + trueIntervals: IndexedSeq[Interval], + falseIntervals: IndexedSeq[Interval], + naIntervals: IndexedSeq[Interval], + trueResidual: IR = True(), + falseResidual: IR = True(), + naResidual: IR = True(), + ): Unit = { val testRows = FastSeq( Row(0, 0, true), Row(0, 0, false), @@ -590,21 +868,27 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => Row(0, 15, true), Row(0, 15, false), Row(0, null, true), - Row(0, null, false)) - checkAll(node, ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals, trueResidual, falseResidual, naResidual) + Row(0, null, false), + ) + checkAll(node, ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals, + trueResidual, falseResidual, naResidual) } val lt5 = lt(k1, I32(5)) val gt10 = gt(k1, I32(10)) - check(or(lt5, gt10), + check( + or(lt5, gt10), FastSeq( Interval(Row(), Row(5), true, false), - Interval(Row(10), Row(null), false, false)), + Interval(Row(10), Row(null), false, false), + ), FastSeq(Interval(Row(5), Row(10), true, true)), - FastSeq(Interval(Row(null), Row(), true, true))) + FastSeq(Interval(Row(null), Row(), true, true)), + ) - check(or(lt5, unknownBool), + check( + or(lt5, unknownBool), // could be true anywhere, since unknownBool might be true FastSeq(Interval(Row(), Row(), true, true)), // can only be false if lt5 is false @@ -615,23 +899,35 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => // we've filtered to the rows where lt5 is false falseResidual = not(or(False(), unknownBool)), // we've filtered to where lt5 is false or missing, so can't simplify - naResidual = IsNA(or(lt5, unknownBool))) + naResidual = IsNA(or(lt5, unknownBool)), + ) - check(and(not(or(lt5, unknownBool)), - not(or(gt10, unknownBool))), + check( + and(not(or(lt5, unknownBool)), not(or(gt10, unknownBool))), FastSeq(Interval(Row(5), Row(10), true, true)), FastSeq(Interval(Row(), Row(), true, true)), FastSeq(Interval(Row(5), Row(10), true, true), Interval(Row(null), Row(), true, true)), trueResidual = and( not(or(False(), unknownBool)), - not(or(False(), unknownBool))), + 
not(or(False(), unknownBool)), + ), naResidual = IsNA(and( not(or(lt5, unknownBool)), - not(or(gt10, unknownBool))))) + not(or(gt10, unknownBool)), + )), + ) } - @Test def testConjunction() { - def check(node: IR, trueIntervals: IndexedSeq[Interval], falseIntervals: IndexedSeq[Interval], naIntervals: IndexedSeq[Interval], trueResidual: IR = True(), falseResidual: IR = True(), naResidual: IR = True()) { + @Test def testConjunction(): Unit = { + def check( + node: IR, + trueIntervals: IndexedSeq[Interval], + falseIntervals: IndexedSeq[Interval], + naIntervals: IndexedSeq[Interval], + trueResidual: IR = True(), + falseResidual: IR = True(), + naResidual: IR = True(), + ): Unit = { val testRows = FastSeq( Row(0, 0, true), Row(0, 0, false), @@ -644,21 +940,27 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => Row(0, 15, true), Row(0, 15, false), Row(0, null, true), - Row(0, null, false)) - checkAll(node, ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals, trueResidual, falseResidual, naResidual) + Row(0, null, false), + ) + checkAll(node, ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals, + trueResidual, falseResidual, naResidual) } val gt5 = gt(k1, I32(5)) val lt10 = lt(k1, I32(10)) - check(and(gt5, lt10), + check( + and(gt5, lt10), FastSeq(Interval(Row(5), Row(10), false, false)), FastSeq( Interval(Row(), Row(5), true, true), - Interval(Row(10), Row(null), true, false)), - FastSeq(Interval(Row(null), Row(), true, true))) + Interval(Row(10), Row(null), true, false), + ), + FastSeq(Interval(Row(null), Row(), true, true)), + ) - check(and(gt5, unknownBool), + check( + and(gt5, unknownBool), // can only be true if gt5 is true FastSeq(Interval(Row(5), Row(null), false, false)), // could be false anywhere, since unknownBool might be false @@ -669,54 +971,89 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => // we've filtered to the rows where gt5 is true trueResidual = and(True(), unknownBool), // we've filtered to where gt5 is false or missing, so can't simplify - naResidual = IsNA(and(gt5, unknownBool))) + naResidual = IsNA(and(gt5, unknownBool)), + ) } @Test def testCoalesce(): Unit = { - def check(node: IR, trueIntervals: IndexedSeq[Interval], falseIntervals: IndexedSeq[Interval], naIntervals: IndexedSeq[Interval], trueResidual: IR = True(), falseResidual: IR = True(), naResidual: IR = True()) { + def check( + node: IR, + trueIntervals: IndexedSeq[Interval], + falseIntervals: IndexedSeq[Interval], + naIntervals: IndexedSeq[Interval], + trueResidual: IR = True(), + falseResidual: IR = True(), + naResidual: IR = True(), + ): Unit = { val testRows = FastSeq( Row(0, 0, true), Row(0, 5, true), Row(0, 7, true), Row(0, 10, true), Row(0, 15, true), - Row(0, null, true)) - checkAll(node, ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals, trueResidual, falseResidual, naResidual) + Row(0, null, true), + ) + checkAll(node, ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals, + trueResidual, falseResidual, naResidual) } val gt5 = gt(k1, I32(5)) val lt10 = lt(k1, I32(10)) - check(Coalesce(FastSeq(gt5, lt10, False())), + check( + Coalesce(FastSeq(gt5, lt10, False())), FastSeq(Interval(Row(5), Row(null), false, false)), FastSeq( Interval(Row(), Row(5), true, true), - Interval(Row(null), Row(), true, true)), - FastSeq()) + Interval(Row(null), Row(), true, true), + ), + FastSeq(), + ) } @Test def testIf(): Unit = { - def check(node: IR, trueIntervals: IndexedSeq[Interval], falseIntervals: IndexedSeq[Interval], 
naIntervals: IndexedSeq[Interval], trueResidual: IR = True(), falseResidual: IR = True(), naResidual: IR = True()) { + def check( + node: IR, + trueIntervals: IndexedSeq[Interval], + falseIntervals: IndexedSeq[Interval], + naIntervals: IndexedSeq[Interval], + trueResidual: IR = True(), + falseResidual: IR = True(), + naResidual: IR = True(), + ): Unit = { val testRows = FastSeq( Row(0, 0, true), Row(0, 5, true), Row(0, 7, true), Row(0, 10, true), Row(0, 15, true), - Row(0, null, true)) - checkAll(node, ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals, trueResidual, falseResidual, naResidual) + Row(0, null, true), + ) + checkAll(node, ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals, + trueResidual, falseResidual, naResidual) } - check(If(gt(k1, I32(0)), lt(k1, I32(5)), gt(k1, I32(-5))), + check( + If(gt(k1, I32(0)), lt(k1, I32(5)), gt(k1, I32(-5))), FastSeq(Interval(Row(-5), Row(5), false, false)), FastSeq( Interval(Row(), Row(-5), true, true), - Interval(Row(5), Row(null), true, false)), - FastSeq(Interval(Row(null), Row(), true, true))) + Interval(Row(5), Row(null), true, false), + ), + FastSeq(Interval(Row(null), Row(), true, true)), + ) } @Test def testSwitch(): Unit = { - def check(node: IR, trueIntervals: IndexedSeq[Interval], falseIntervals: IndexedSeq[Interval], naIntervals: IndexedSeq[Interval], trueResidual: IR = True(), falseResidual: IR = True(), naResidual: IR = True()) { + def check( + node: IR, + trueIntervals: IndexedSeq[Interval], + falseIntervals: IndexedSeq[Interval], + naIntervals: IndexedSeq[Interval], + trueResidual: IR = True(), + falseResidual: IR = True(), + naResidual: IR = True(), + ): Unit = { val testRows = FastSeq( Row(0, 0, true), Row(0, 5, true), @@ -728,21 +1065,22 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => Row(1, null, true), Row(null, null, true), ) - checkAll(node, ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals, trueResidual, falseResidual, naResidual) + checkAll(node, ref1, k1Full, testRows, trueIntervals, falseIntervals, naIntervals, + trueResidual, falseResidual, naResidual) } check( Switch(I32(0), gt(k1, I32(-5)), FastSeq(lt(k1, I32(5)))), FastSeq(Interval(Row(), Row(5), true, false)), FastSeq(Interval(Row(5), Row(null), true, false)), - FastSeq(Interval(Row(null), Row(), true, true)) + FastSeq(Interval(Row(null), Row(), true, true)), ) check( Switch(I32(-1), gt(k1, I32(-5)), FastSeq(lt(k1, I32(5)))), FastSeq(Interval(Row(-5), Row(null), false, false)), FastSeq(Interval(Row(), Row(-5), true, true)), - FastSeq(Interval(Row(null), Row(), true, true)) + FastSeq(Interval(Row(null), Row(), true, true)), ) val filter = Switch(GetField(ref1, "w"), gt(k1, I32(-5)), FastSeq(lt(k1, I32(5)))) @@ -753,7 +1091,7 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => FastSeq(Interval(Row(), Row(), true, true)), trueResidual = filter, falseResidual = ApplyUnaryPrimOp(Bang, filter), - naResidual = IsNA(filter) + naResidual = IsNA(filter), ) } @@ -762,27 +1100,40 @@ class ExtractIntervalFiltersSuite extends HailSuite { outer => Row(0, 0, true), Row(0, 10, true), Row(0, 20, true), - Row(0, null, true)) + Row(0, null, true), + ) - val count = TableAggregate(TableRange(10, 1), ApplyAggOp(FastSeq(), FastSeq(), AggSignature(Count(), FastSeq(), FastSeq()))) + val count = TableAggregate( + TableRange(10, 1), + ApplyAggOp(FastSeq(), FastSeq(), AggSignature(Count(), FastSeq(), FastSeq())), + ) print(count.typ) val filter = gt(count, Cast(k1, TInt64)) check(filter, ref1, k1Full, testRows, 
filter, FastSeq(Interval(Row(), Row(), true, true))) } - @Test def testIntegration() { + @Test def testIntegration(): Unit = { hc // force initialization val tab1 = TableRange(10, 5) def k = GetField(Ref("row", tab1.typ.rowType), "idx") - val tf = TableFilter(tab1, - Coalesce(FastSeq(invoke("land", TBoolean, - ApplyComparisonOp(GT(TInt32), k, I32(3)), - ApplyComparisonOp(LTEQ(TInt32), k, I32(9)) - ), False()))) + val tf = TableFilter( + tab1, + Coalesce(FastSeq( + invoke( + "land", + TBoolean, + ApplyComparisonOp(GT(TInt32), k, I32(3)), + ApplyComparisonOp(LTEQ(TInt32), k, I32(9)), + ), + False(), + )), + ) - assert(ExtractIntervalFilters(ctx, tf).asInstanceOf[TableFilter].child.isInstanceOf[TableFilterIntervals]) + assert(ExtractIntervalFilters(ctx, tf).asInstanceOf[TableFilter].child.isInstanceOf[ + TableFilterIntervals + ]) assertEvalsTo(TableCount(tf), 6L)(ExecStrategy.interpretOnly) } } diff --git a/hail/src/test/scala/is/hail/expr/ir/FakeTableReader.scala b/hail/src/test/scala/is/hail/expr/ir/FakeTableReader.scala index f727079d15c..76cb9b9e454 100644 --- a/hail/src/test/scala/is/hail/expr/ir/FakeTableReader.scala +++ b/hail/src/test/scala/is/hail/expr/ir/FakeTableReader.scala @@ -1,15 +1,21 @@ package is.hail.expr.ir + import is.hail.backend.ExecuteContext import is.hail.expr.ir.lowering.TableStage -import is.hail.types.virtual.TStruct import is.hail.types.{TableType, VirtualTypeWithReq} +import is.hail.types.virtual.TStruct class FakeTableReader extends TableReader { override def pathsUsed: Seq[String] = ??? override def partitionCounts: Option[IndexedSeq[Long]] = ??? override def fullType: TableType = ??? - override def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = ??? - override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = ??? + + override def rowRequiredness(ctx: ExecuteContext, requestedType: TableType): VirtualTypeWithReq = + ??? + + override def globalRequiredness(ctx: ExecuteContext, requestedType: TableType) + : VirtualTypeWithReq = ??? + override def renderShort(): String = ??? override def lowerGlobals(ctx: ExecuteContext, requestedGlobalsType: TStruct): IR = ??? override def lower(ctx: ExecuteContext, requestedType: TableType): TableStage = ??? 
diff --git a/hail/src/test/scala/is/hail/expr/ir/FoldConstantsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/FoldConstantsSuite.scala index 29321325e55..14716f2852f 100644 --- a/hail/src/test/scala/is/hail/expr/ir/FoldConstantsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/FoldConstantsSuite.scala @@ -1,18 +1,17 @@ package is.hail.expr.ir import is.hail.HailSuite -import is.hail.types.virtual.{TFloat64, TInt32, TRNGState, TTuple} -import org.apache.spark.sql.Row -import org.scalatest.testng.TestNGSuite +import is.hail.types.virtual.{TFloat64, TInt32} + import org.testng.annotations.{DataProvider, Test} class FoldConstantsSuite extends HailSuite { - @Test def testRandomBlocksFolding() { + @Test def testRandomBlocksFolding(): Unit = { val x = ApplySeeded("rand_norm", Seq(F64(0d), F64(0d)), RNGStateLiteral(), 0L, TFloat64) assert(FoldConstants(ctx, x) == x) } - @Test def testErrorCatching() { + @Test def testErrorCatching(): Unit = { val ir = invoke("toInt32", TInt32, Str("")) assert(FoldConstants(ctx, ir) == ir) } @@ -23,13 +22,12 @@ class FoldConstantsSuite extends HailSuite { AggLet("x", I32(1), I32(1), false), AggLet("x", I32(1), I32(1), true), ApplyAggOp(Sum())(I64(1)), - ApplyScanOp(Sum())(I64(1)) - ).map(x => Array[Any](x)) + ApplyScanOp(Sum())(I64(1)), + ).map(x => Array[Any](x)) } @Test def testAggNodesConstruction(): Unit = aggNodes() - @Test(dataProvider = "aggNodes") def testAggNodesDoNotFold(node: IR): Unit = { + @Test(dataProvider = "aggNodes") def testAggNodesDoNotFold(node: IR): Unit = assert(FoldConstants(ctx, node) == node) - } } diff --git a/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala index 6d01f6873fc..5d7ffcdada6 100644 --- a/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/ForwardLetsSuite.scala @@ -3,19 +3,16 @@ package is.hail.expr.ir import is.hail.HailSuite import is.hail.TestUtils._ import is.hail.expr.Nat -import is.hail.expr.ir.DeprecatedIRBuilder.{applyAggOp, let, _} import is.hail.types.virtual._ import is.hail.utils._ -import org.scalatest.AppendedClues.convertToClueful -import org.scalatest.Matchers.{be, convertToAnyShouldWrapper} + import org.testng.annotations.{BeforeMethod, DataProvider, Test} class ForwardLetsSuite extends HailSuite { @BeforeMethod - def resetUidCounter(): Unit = { + def resetUidCounter(): Unit = is.hail.expr.ir.uidCounter = 0 - } @DataProvider(name = "nonForwardingOps") def nonForwardingOps(): Array[Array[IR]] = { @@ -26,13 +23,37 @@ class ForwardLetsSuite extends HailSuite { ToArray(StreamMap(ToStream(a), "y", ApplyBinaryPrimOp(Add(), x, y))), ToArray(StreamFilter(ToStream(a), "y", ApplyComparisonOp(LT(TInt32), x, y))), ToArray(StreamFlatMap(ToStream(a), "y", StreamRange(x, y, I32(1)))), - StreamFold(ToStream(a), I32(0), "acc", "y", ApplyBinaryPrimOp(Add(), ApplyBinaryPrimOp(Add(), x, y), Ref("acc", TInt32))), - StreamFold2(ToStream(a), FastSeq(("acc", I32(0))), "y", FastSeq(x + y + Ref("acc", TInt32)), Ref("acc", TInt32)), - ToArray(StreamScan(ToStream(a), I32(0), "acc", "y", ApplyBinaryPrimOp(Add(), ApplyBinaryPrimOp(Add(), x, y), Ref("acc", TInt32)))), - MakeStruct(FastSeq("a" -> ApplyBinaryPrimOp(Add(), x, I32(1)), "b" -> ApplyBinaryPrimOp(Add(), x, I32(2)))), - MakeTuple.ordered(FastSeq(ApplyBinaryPrimOp(Add(), x, I32(1)), ApplyBinaryPrimOp(Add(), x, I32(2)))), + StreamFold( + ToStream(a), + I32(0), + "acc", + "y", + ApplyBinaryPrimOp(Add(), ApplyBinaryPrimOp(Add(), x, y), Ref("acc", TInt32)), + 
), + StreamFold2( + ToStream(a), + FastSeq(("acc", I32(0))), + "y", + FastSeq(x + y + Ref("acc", TInt32)), + Ref("acc", TInt32), + ), + ToArray(StreamScan( + ToStream(a), + I32(0), + "acc", + "y", + ApplyBinaryPrimOp(Add(), ApplyBinaryPrimOp(Add(), x, y), Ref("acc", TInt32)), + )), + MakeStruct(FastSeq( + "a" -> ApplyBinaryPrimOp(Add(), x, I32(1)), + "b" -> ApplyBinaryPrimOp(Add(), x, I32(2)), + )), + MakeTuple.ordered(FastSeq( + ApplyBinaryPrimOp(Add(), x, I32(1)), + ApplyBinaryPrimOp(Add(), x, I32(2)), + )), ApplyBinaryPrimOp(Add(), ApplyBinaryPrimOp(Add(), x, x), I32(1)), - StreamAgg(ToStream(a), "y", ApplyAggOp(Sum())(x + y)) + StreamAgg(ToStream(a), "y", ApplyAggOp(Sum())(x + y)), ).map(ir => Array[IR](Let(FastSeq("x" -> (In(0, TInt32) + In(0, TInt32))), ir))) } @@ -42,12 +63,25 @@ class ForwardLetsSuite extends HailSuite { val y = Ref("y", TInt32) Array( NDArrayMap(In(1, TNDArray(TInt32, Nat(1))), "y", x + y), - NDArrayMap2(In(1, TNDArray(TInt32, Nat(1))), In(2, TNDArray(TInt32, Nat(1))), "y", "z", x + y + Ref("z", TInt32), ErrorIDs.NO_ERROR), - TailLoop("f", FastSeq("y" -> I32(0)), TInt32, If(y < x, Recur("f", FastSeq[IR](y - I32(1)), TInt32), x)) + NDArrayMap2( + In(1, TNDArray(TInt32, Nat(1))), + In(2, TNDArray(TInt32, Nat(1))), + "y", + "z", + x + y + Ref("z", TInt32), + ErrorIDs.NO_ERROR, + ), + TailLoop( + "f", + FastSeq("y" -> I32(0)), + TInt32, + If(y < x, Recur("f", FastSeq[IR](y - I32(1)), TInt32), x), + ), ).map(ir => Array[IR](Let(FastSeq("x" -> (In(0, TInt32) + In(0, TInt32))), ir))) } - def aggMin(value: IR): ApplyAggOp = ApplyAggOp(FastSeq(), FastSeq(value), AggSignature(Min(), FastSeq(), FastSeq(value.typ))) + def aggMin(value: IR): ApplyAggOp = + ApplyAggOp(FastSeq(), FastSeq(value), AggSignature(Min(), FastSeq(), FastSeq(value.typ))) @DataProvider(name = "nonForwardingAggOps") def nonForwardingAggOps(): Array[Array[IR]] = { @@ -56,7 +90,7 @@ class ForwardLetsSuite extends HailSuite { val y = Ref("y", TInt32) Array( AggArrayPerElement(ToArray(a), "y", "_", aggMin(x + y), None, false), - AggExplode(a, "y", aggMin(y + x), false) + AggExplode(a, "y", aggMin(y + x), false), ).map(ir => Array[IR](AggLet("x", In(0, TInt32) + In(0, TInt32), ir, false))) } @@ -70,7 +104,7 @@ class ForwardLetsSuite extends HailSuite { ApplyBinaryPrimOp(Add(), ApplyBinaryPrimOp(Add(), I32(2), x), I32(1)), ApplyUnaryPrimOp(Negate, x), ToArray(StreamMap(StreamRange(I32(0), x, I32(1)), "foo", Ref("foo", TInt32))), - ToArray(StreamFilter(StreamRange(I32(0), x, I32(1)), "foo", Ref("foo", TInt32) <= I32(0))) + ToArray(StreamFilter(StreamRange(I32(0), x, I32(1)), "foo", Ref("foo", TInt32) <= I32(0))), ).map(ir => Array[IR](Let(FastSeq("x" -> (In(0, TInt32) + In(0, TInt32))), ir))) } @@ -80,17 +114,28 @@ class ForwardLetsSuite extends HailSuite { val other = Ref("other", TInt32) Array( AggFilter(x.ceq(I32(0)), aggMin(other), false), - aggMin(x + other) + aggMin(x + other), ).map(ir => Array[IR](AggLet("x", In(0, TInt32) + In(0, TInt32), ir, false))) } - @Test def assertDataProvidersWork() { + @Test def assertDataProvidersWork(): Unit = { nonForwardingOps() forwardingOps() nonForwardingAggOps() forwardingAggOps() } + @Test def testBlock(): Unit = { + val ir = Block( + FastSeq(Binding("x", I32(1), Scope.AGG), Binding("y", Ref("x", TInt32), Scope.AGG)), + ApplyAggOp(Sum())(Ref("y", TInt32)), + ) + val after: IR = ForwardLets(ctx)(ir) + val expected = ApplyAggOp(Sum())(I32(1)) + val normalize = new NormalizeNames(_.toString) + assert(normalize(ctx, after) == normalize(ctx, expected)) + } + 
@Test(dataProvider = "nonForwardingOps") def testNonForwardingOps(ir: IR): Unit = { val after = ForwardLets(ctx)(ir) @@ -102,26 +147,26 @@ class ForwardLetsSuite extends HailSuite { @Test(dataProvider = "nonForwardingNonEvalOps") def testNonForwardingNonEvalOps(ir: IR): Unit = { val after = ForwardLets(ctx)(ir) - assert(after.isInstanceOf[Let]) + assert(after.isInstanceOf[Block]) } @Test(dataProvider = "nonForwardingAggOps") def testNonForwardingAggOps(ir: IR): Unit = { val after = ForwardLets(ctx)(ir) - assert(after.isInstanceOf[AggLet]) + assert(after.isInstanceOf[Block]) } @Test(dataProvider = "forwardingOps") def testForwardingOps(ir: IR): Unit = { val after = ForwardLets(ctx)(ir) - assert(!after.isInstanceOf[Let]) + assert(!after.isInstanceOf[Block]) assertEvalSame(ir, args = Array(5 -> TInt32)) } @Test(dataProvider = "forwardingAggOps") def testForwardingAggOps(ir: IR): Unit = { val after = ForwardLets(ctx)(ir) - assert(!after.isInstanceOf[AggLet]) + assert(!after.isInstanceOf[Block]) } @DataProvider(name = "TrivialIRCases") @@ -132,30 +177,28 @@ class ForwardLetsSuite extends HailSuite { Array( Let(FastSeq("x" -> I32(0)), I32(2)), I32(2), - """"x" is unused.""" + """"x" is unused.""", ), Array( Let(FastSeq("x" -> I32(0)), Ref("x", TInt32)), I32(0), - """"x" is constant and is used once.""" + """"x" is constant and is used once.""", ), Array( Let(FastSeq("x" -> I32(2)), Ref("x", TInt32) * Ref("x", TInt32)), I32(2) * I32(2), - """"x" is a primitive constant (ForwardLets does not evaluate).""" + """"x" is a primitive constant (ForwardLets does not evaluate).""", ), Array( bindIRs(I32(2), F64(pi), Ref("r", TFloat64)) { case Seq(two, pi, r) => - ApplyBinaryPrimOp(Multiply(), - ApplyBinaryPrimOp(Multiply(), Cast(two, TFloat64), pi), - r - ) + ApplyBinaryPrimOp(Multiply(), ApplyBinaryPrimOp(Multiply(), Cast(two, TFloat64), pi), r) }, - ApplyBinaryPrimOp(Multiply(), + ApplyBinaryPrimOp( + Multiply(), ApplyBinaryPrimOp(Multiply(), Cast(I32(2), TFloat64), F64(pi)), - Ref("r", TFloat64) + Ref("r", TFloat64), ), - """Forward constant primitive values and simple use ref.""" + """Forward constant primitive values and simple use ref.""", ), Array( Let( @@ -164,52 +207,84 @@ class ForwardLetsSuite extends HailSuite { iruid(1) -> Cast(Ref(iruid(0), TInt32), TFloat64), iruid(2) -> ApplyBinaryPrimOp(FloatingPointDivide(), Ref(iruid(1), TFloat64), F64(2)), iruid(3) -> F64(pi), - iruid(4) -> ApplyBinaryPrimOp(Multiply(), Ref(iruid(3), TFloat64), Ref(iruid(1), TFloat64)), - iruid(5) -> ApplyBinaryPrimOp(Multiply(), Ref(iruid(2), TFloat64), Ref(iruid(2), TFloat64)), - iruid(6) -> ApplyBinaryPrimOp(Multiply(), Ref(iruid(3), TFloat64), Ref(iruid(5), TFloat64)) + iruid(4) -> ApplyBinaryPrimOp( + Multiply(), + Ref(iruid(3), TFloat64), + Ref(iruid(1), TFloat64), + ), + iruid(5) -> ApplyBinaryPrimOp( + Multiply(), + Ref(iruid(2), TFloat64), + Ref(iruid(2), TFloat64), + ), + iruid(6) -> ApplyBinaryPrimOp( + Multiply(), + Ref(iruid(3), TFloat64), + Ref(iruid(5), TFloat64), + ), ), MakeStruct(FastSeq( "radius" -> Ref(iruid(2), TFloat64), "circumference" -> Ref(iruid(4), TFloat64), "area" -> Ref(iruid(6), TFloat64), - )) - ), - Let(FastSeq( - iruid(1) -> Cast(I32(2), TFloat64), - iruid(2) -> ApplyBinaryPrimOp(FloatingPointDivide(), Ref(iruid(1), TFloat64), F64(2)), + )), ), + Let( + FastSeq( + iruid(1) -> Cast(I32(2), TFloat64), + iruid(2) -> ApplyBinaryPrimOp(FloatingPointDivide(), Ref(iruid(1), TFloat64), F64(2)), + ), MakeStruct(FastSeq( "radius" -> Ref(iruid(2), TFloat64), "circumference" -> 
ApplyBinaryPrimOp(Multiply(), F64(pi), Ref(iruid(1), TFloat64)), - "area" -> ApplyBinaryPrimOp(Multiply(), F64(pi), - ApplyBinaryPrimOp(Multiply(), Ref(iruid(2), TFloat64), Ref(iruid(2), TFloat64)) - ) - )) + "area" -> ApplyBinaryPrimOp( + Multiply(), + F64(pi), + ApplyBinaryPrimOp(Multiply(), Ref(iruid(2), TFloat64), Ref(iruid(2), TFloat64)), + ), + )), ), - "Cascading Let-bindings are forwarded" - ) + "Cascading Let-bindings are forwarded", + ), ) } @Test(dataProvider = "TrivialIRCases") - def testTrivialCases(input: IR, expected: IR, reason: String): Unit = - ForwardLets(ctx)(input) should be(expected) withClue reason + def testTrivialCases(input: IR, expected: IR, reason: String): Unit = { + val result = ForwardLets(ctx)(input) + assert( + result == expected, + s"\ninput:\n${Pretty.sexprStyle(input)}\nexpected:\n${Pretty.sexprStyle(expected)}\ngot:\n${Pretty.sexprStyle(result)}", + ) + } @Test def testAggregators(): Unit = { - val aggEnv = Env[Type]("row" -> TStruct("idx" -> TInt32)) - val ir0 = applyAggOp(Sum(), seqOpArgs = FastSeq(let(x = 'row('idx) - 1) { - 'x.toD - })) - .apply(aggEnv) + val row = Ref("row", TStruct("idx" -> TInt32)) + val x = Ref("x", TInt32) + val aggEnv = Env[Type](row.name -> row.typ) + + val ir0 = ApplyAggOp( + FastSeq(), + FastSeq(Let(FastSeq("x" -> (GetField(row, "idx") - 1)), Cast(x, TFloat64))), + AggSignature(Sum(), FastSeq(), FastSeq(TFloat64)), + ) TypeCheck(ctx, ForwardLets(ctx)(ir0), BindingEnv(Env.empty, agg = Some(aggEnv))) } @Test def testNestedBindingOverwrites(): Unit = { val env = Env[Type]("x" -> TInt32) - val ir = let(y = 'x.toD, x = 'x.toD) { - 'x + 'x + 'y - }(env) + def xInt = Ref("x", TInt32) + def xCast = Cast(xInt, TFloat64) + def xFloat = Ref("x", TFloat64) + def y = Ref("y", TFloat64) + val ir = Let( + FastSeq( + "y" -> xCast, + "x" -> xCast, + ), + xFloat + xFloat + y, + ) TypeCheck(ctx, ir, BindingEnv(env)) TypeCheck(ctx, ForwardLets(ctx)(ir), BindingEnv(env)) @@ -223,8 +298,8 @@ class ForwardLetsSuite extends HailSuite { StreamAgg( ToStream(In(1, TArray(TInt32))), "bar", - Ref("y", TInt32) + Ref("x", TInt32) - ) + Ref("y", TInt32) + Ref("x", TInt32), + ), ) TypeCheck(ctx, x, BindingEnv(Env("y" -> TInt32))) diff --git a/hail/src/test/scala/is/hail/expr/ir/FunctionSuite.scala b/hail/src/test/scala/is/hail/expr/ir/FunctionSuite.scala index 0da5a89a33a..75429f33c51 100644 --- a/hail/src/test/scala/is/hail/expr/ir/FunctionSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/FunctionSuite.scala @@ -1,12 +1,13 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.asm4s._ import is.hail.expr.ir.functions.{IRFunctionRegistry, RegistryFunctions} import is.hail.types.physical.stypes.interfaces._ import is.hail.types.virtual._ import is.hail.utils.FastSeq import is.hail.variant.Call2 -import is.hail.{ExecStrategy, HailSuite} + import org.testng.annotations.Test object ScalaTestObject { @@ -21,17 +22,25 @@ class ScalaTestCompanion { def testFunction(): Int = 3 } - object TestRegisterFunctions extends RegistryFunctions { - def registerAll() { + def registerAll(): Unit = { registerIR1("addone", TInt32, TInt32)((_, a, _) => ApplyBinaryPrimOp(Add(), a, I32(1))) - registerJavaStaticFunction("compare", Array(TInt32, TInt32), TInt32, null)(classOf[java.lang.Integer], "compare") - registerScalaFunction("foobar1", Array(), TInt32, null)(ScalaTestObject.getClass, "testFunction") - registerScalaFunction("foobar2", Array(), TInt32, null)(ScalaTestCompanion.getClass, "testFunction") + registerJavaStaticFunction("compare", 
Array(TInt32, TInt32), TInt32, null)( + classOf[java.lang.Integer], + "compare", + ) + registerScalaFunction("foobar1", Array(), TInt32, null)( + ScalaTestObject.getClass, + "testFunction", + ) + registerScalaFunction("foobar2", Array(), TInt32, null)( + ScalaTestCompanion.getClass, + "testFunction", + ) registerSCode2("testCodeUnification", tnum("x"), tv("x", "int32"), tv("x"), null) { - case (_, cb, rt, a, b, _) => primitive(cb.memoize(a.asInt.value + b.asInt.value)) + case (_, cb, _, a, b, _) => primitive(cb.memoize(a.asInt.value + b.asInt.value)) } - registerSCode1("testCodeUnification2", tv("x"), tv("x"), null) { case (_, cb, rt, a, _) => a } + registerSCode1("testCodeUnification2", tv("x"), tv("x"), null) { case (_, _, _, a, _) => a } } } @@ -47,53 +56,67 @@ class FunctionSuite extends HailSuite { } @Test - def testCodeFunction() { - assertEvalsTo(lookup("triangle", TInt32, TInt32)(In(0, TInt32)), + def testCodeFunction(): Unit = + assertEvalsTo( + lookup("triangle", TInt32, TInt32)(In(0, TInt32)), FastSeq(5 -> TInt32), - (5 * (5 + 1)) / 2) - } + (5 * (5 + 1)) / 2, + ) @Test - def testStaticFunction() { - assertEvalsTo(lookup("compare", TInt32, TInt32, TInt32)(In(0, TInt32), I32(0)) > 0, + def testStaticFunction(): Unit = + assertEvalsTo( + lookup("compare", TInt32, TInt32, TInt32)(In(0, TInt32), I32(0)) > 0, FastSeq(5 -> TInt32), - true) - } + true, + ) @Test - def testScalaFunction() { + def testScalaFunction(): Unit = assertEvalsTo(lookup("foobar1", TInt32)(), 1) - } @Test - def testIRConversion() { - assertEvalsTo(lookup("addone", TInt32, TInt32)(In(0, TInt32)), - FastSeq(5 -> TInt32), - 6) - } + def testIRConversion(): Unit = + assertEvalsTo(lookup("addone", TInt32, TInt32)(In(0, TInt32)), FastSeq(5 -> TInt32), 6) @Test - def testScalaFunctionCompanion() { + def testScalaFunctionCompanion(): Unit = assertEvalsTo(lookup("foobar2", TInt32)(), 2) - } @Test - def testVariableUnification() { - assert(IRFunctionRegistry.lookupUnseeded("testCodeUnification", TInt32, Seq(TInt32, TInt32)).isDefined) - assert(IRFunctionRegistry.lookupUnseeded("testCodeUnification", TInt32, Seq(TInt64, TInt32)).isEmpty) - assert(IRFunctionRegistry.lookupUnseeded("testCodeUnification", TInt64, Seq(TInt32, TInt32)).isEmpty) - assert(IRFunctionRegistry.lookupUnseeded("testCodeUnification2", TArray(TInt32), Seq(TArray(TInt32))).isDefined) + def testVariableUnification(): Unit = { + assert(IRFunctionRegistry.lookupUnseeded( + "testCodeUnification", + TInt32, + Seq(TInt32, TInt32), + ).isDefined) + assert(IRFunctionRegistry.lookupUnseeded( + "testCodeUnification", + TInt32, + Seq(TInt64, TInt32), + ).isEmpty) + assert(IRFunctionRegistry.lookupUnseeded( + "testCodeUnification", + TInt64, + Seq(TInt32, TInt32), + ).isEmpty) + assert(IRFunctionRegistry.lookupUnseeded( + "testCodeUnification2", + TArray(TInt32), + Seq(TArray(TInt32)), + ).isDefined) } @Test - def testUnphasedDiploidGtIndexCall() { - assertEvalsTo(lookup("UnphasedDiploidGtIndexCall", TCall, TInt32)(In(0, TInt32)), + def testUnphasedDiploidGtIndexCall(): Unit = + assertEvalsTo( + lookup("UnphasedDiploidGtIndexCall", TCall, TInt32)(In(0, TInt32)), FastSeq(0 -> TInt32), - Call2.fromUnphasedDiploidGtIndex(0)) - } + Call2.fromUnphasedDiploidGtIndex(0), + ) @Test - def testFunctionBuilderGetOrDefine() { + def testGetOrGenMethod(): Unit = { val fb = EmitFunctionBuilder[Int](ctx, "foo") val i = fb.genFieldThisRef[Int]() val mb1 = fb.getOrGenEmitMethod("foo", "foo", FastSeq[ParamType](), UnitInfo) { mb => @@ -102,14 +125,13 @@ class FunctionSuite extends 
HailSuite { val mb2 = fb.getOrGenEmitMethod("foo", "foo", FastSeq[ParamType](), UnitInfo) { mb => mb.emit(i := i - 100) } - fb.emitWithBuilder(cb => { + fb.emitWithBuilder { cb => cb.assign(i, 0) - mb1.invokeCode(cb) - mb2.invokeCode(cb) + cb.invokeVoid(mb1, cb.this_) + cb.invokeVoid(mb2, cb.this_) i - }) + } pool.scopedRegion { r => - assert(fb.resultWithIndex().apply(theHailClassLoader, ctx.fs, ctx.taskContext, r)() == 2) } } diff --git a/hail/src/test/scala/is/hail/expr/ir/GenotypeFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/GenotypeFunctionsSuite.scala index fb407b155cd..521922af0e7 100644 --- a/hail/src/test/scala/is/hail/expr/ir/GenotypeFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/GenotypeFunctionsSuite.scala @@ -1,17 +1,18 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual.TFloat64 import is.hail.utils.FastSeq -import is.hail.{ExecStrategy, HailSuite} + import org.testng.annotations.{DataProvider, Test} class GenotypeFunctionsSuite extends HailSuite { implicit val execStrats = ExecStrategy.javaOnly - @DataProvider(name="gps") + @DataProvider(name = "gps") def gpData(): Array[Array[Any]] = Array( Array(FastSeq(1.0, 0.0, 0.0), 0.0), Array(FastSeq(0.0, 1.0, 0.0), 1.0), @@ -22,14 +23,14 @@ class GenotypeFunctionsSuite extends HailSuite { Array(FastSeq(null, null, null), null), Array(FastSeq(null, 0.5, 0.5), 1.5), Array(FastSeq(0.0, null, 1.0), null), - Array(FastSeq(0.0, 0.5, null), null)) + Array(FastSeq(0.0, 0.5, null), null), + ) - @Test(dataProvider="gps") - def testDosage(gp: IndexedSeq[java.lang.Double], expected: java.lang.Double) { + @Test(dataProvider = "gps") + def testDosage(gp: IndexedSeq[java.lang.Double], expected: java.lang.Double): Unit = assertEvalsTo(invoke("dosage", TFloat64, toIRDoubleArray(gp)), expected) - } - @Test def testDosageLength() { + @Test def testDosageLength(): Unit = { assertFatal(invoke("dosage", TFloat64, IRDoubleArray(1.0, 1.5)), "length") assertFatal(invoke("dosage", TFloat64, IRDoubleArray(1.0, 1.5, 0.0, 0.0)), "length") } diff --git a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala index 1b4efd82c17..f10e0ead8f3 100644 --- a/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/IRSuite.scala @@ -1,5 +1,6 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy import is.hail.TestUtils._ import is.hail.annotations.{BroadcastRow, ExtendedOrdering, SafeNDArray} @@ -9,29 +10,30 @@ import is.hail.expr.ir.ArrayZipBehavior.ArrayZipBehavior import is.hail.expr.ir.DeprecatedIRBuilder._ import is.hail.expr.ir.agg._ import is.hail.expr.ir.functions._ -import is.hail.io.bgen.MatrixBGENReader import is.hail.io.{BufferSpec, TypedCodecSpec} +import is.hail.io.bgen.MatrixBGENReader import is.hail.linalg.BlockMatrix import is.hail.methods._ import is.hail.rvd.{PartitionBoundOrdering, RVD, RVDPartitioner} +import is.hail.types.{tcoerce, BlockMatrixType, TableType, VirtualTypeWithReq} import is.hail.types.physical._ import is.hail.types.physical.stypes._ import is.hail.types.physical.stypes.primitives.SInt32 import is.hail.types.virtual._ -import is.hail.types.{BlockMatrixType, TableType, VirtualTypeWithReq, tcoerce} +import is.hail.types.virtual.TIterable.elementType import is.hail.utils.{FastSeq, _} import is.hail.variant.{Call2, Locus} -import is.hail.{ExecStrategy, HailSuite} + 
+import scala.language.implicitConversions + import org.apache.spark.sql.Row import org.json4s.jackson.{JsonMethods, Serialization} import org.testng.annotations.{DataProvider, Test} -import scala.language.{dynamics, implicitConversions} - class IRSuite extends HailSuite { implicit val execStrats = ExecStrategy.nonLowering - @Test def testRandDifferentLengthUIDStrings() { + @Test def testRandDifferentLengthUIDStrings(): Unit = { implicit val execStrats = ExecStrategy.lowering val staticUID: Long = 112233 var rng: IR = RNGStateLiteral() @@ -52,37 +54,29 @@ class IRSuite extends HailSuite { assert(expected1 != expected3) } - @Test def testI32() { + @Test def testI32(): Unit = assertEvalsTo(I32(5), 5) - } - - @Test def testI64() { + @Test def testI64(): Unit = assertEvalsTo(I64(5), 5L) - } - @Test def testF32() { + @Test def testF32(): Unit = assertEvalsTo(F32(3.14f), 3.14f) - } - @Test def testF64() { + @Test def testF64(): Unit = assertEvalsTo(F64(3.14), 3.14) - } - @Test def testStr() { + @Test def testStr(): Unit = assertEvalsTo(Str("Hail"), "Hail") - } - @Test def testTrue() { + @Test def testTrue(): Unit = assertEvalsTo(True(), true) - } - @Test def testFalse() { + @Test def testFalse(): Unit = assertEvalsTo(False(), false) - } // FIXME Void() doesn't work because we can't handle a void type in a tuple - @Test def testCast() { + @Test def testCast(): Unit = { assertAllEvalTo( (Cast(I32(5), TInt32), 5), (Cast(I32(5), TInt64), 5L), @@ -102,31 +96,38 @@ class IRSuite extends HailSuite { (Cast(F64(3.99), TInt32), 3), // truncate (Cast(F64(3.14), TInt64), 3L), (Cast(F64(3.14), TFloat32), 3.14f), - (Cast(F64(3.14), TFloat64), 3.14)) + (Cast(F64(3.14), TFloat64), 3.14), + ) } - @Test def testCastRename() { + @Test def testCastRename(): Unit = { assertEvalsTo(CastRename(MakeStruct(FastSeq(("x", I32(1)))), TStruct("foo" -> TInt32)), Row(1)) - assertEvalsTo(CastRename(MakeArray(FastSeq(MakeStruct(FastSeq(("x", I32(1))))), - TArray(TStruct("x" -> TInt32))), TArray(TStruct("foo" -> TInt32))), - FastSeq(Row(1))) + assertEvalsTo( + CastRename( + MakeArray(FastSeq(MakeStruct(FastSeq(("x", I32(1))))), TArray(TStruct("x" -> TInt32))), + TArray(TStruct("foo" -> TInt32)), + ), + FastSeq(Row(1)), + ) } - @Test def testNA() { + @Test def testNA(): Unit = assertEvalsTo(NA(TInt32), null) - } - @Test def testCoalesce() { + @Test def testCoalesce(): Unit = { assertEvalsTo(Coalesce(FastSeq(In(0, TInt32))), FastSeq((null, TInt32)), null) assertEvalsTo(Coalesce(FastSeq(In(0, TInt32))), FastSeq((1, TInt32)), 1) assertEvalsTo(Coalesce(FastSeq(NA(TInt32), In(0, TInt32))), FastSeq((null, TInt32)), null) assertEvalsTo(Coalesce(FastSeq(NA(TInt32), In(0, TInt32))), FastSeq((1, TInt32)), 1) assertEvalsTo(Coalesce(FastSeq(In(0, TInt32), NA(TInt32))), FastSeq((1, TInt32)), 1) - assertEvalsTo(Coalesce(FastSeq(NA(TInt32), I32(1), I32(1), NA(TInt32), I32(1), NA(TInt32), I32(1))), 1) + assertEvalsTo( + Coalesce(FastSeq(NA(TInt32), I32(1), I32(1), NA(TInt32), I32(1), NA(TInt32), I32(1))), + 1, + ) assertEvalsTo(Coalesce(FastSeq(NA(TInt32), I32(1), Die("foo", TInt32))), 1) } - @Test def testCoalesceWithDifferentRequiredeness() { + @Test def testCoalesceWithDifferentRequiredeness(): Unit = { val t1 = In(0, TArray(TInt32)) val t2 = NA(TArray(TInt32)) val value = FastSeq(1, 2, 3, 4) @@ -140,51 +141,56 @@ class IRSuite extends HailSuite { val f64na = NA(TFloat64) val bna = NA(TBoolean) - @Test def testApplyUnaryPrimOpNegate() { + @Test def testApplyUnaryPrimOpNegate(): Unit = { assertAllEvalTo( (ApplyUnaryPrimOp(Negate, I32(5)), -5), 
(ApplyUnaryPrimOp(Negate, i32na), null), (ApplyUnaryPrimOp(Negate, I64(5)), -5L), (ApplyUnaryPrimOp(Negate, i64na), null), - (ApplyUnaryPrimOp(Negate, F32(5)), -5F), + (ApplyUnaryPrimOp(Negate, F32(5)), -5f), (ApplyUnaryPrimOp(Negate, f32na), null), - (ApplyUnaryPrimOp(Negate, F64(5)), -5D), - (ApplyUnaryPrimOp(Negate, f64na), null) + (ApplyUnaryPrimOp(Negate, F64(5)), -5d), + (ApplyUnaryPrimOp(Negate, f64na), null), ) } - @Test def testApplyUnaryPrimOpBang() { + @Test def testApplyUnaryPrimOpBang(): Unit = { assertEvalsTo(ApplyUnaryPrimOp(Bang, False()), true) assertEvalsTo(ApplyUnaryPrimOp(Bang, True()), false) assertEvalsTo(ApplyUnaryPrimOp(Bang, bna), null) } - @Test def testApplyUnaryPrimOpBitFlip() { + @Test def testApplyUnaryPrimOpBitFlip(): Unit = { assertAllEvalTo( (ApplyUnaryPrimOp(BitNot, I32(0xdeadbeef)), ~0xdeadbeef), (ApplyUnaryPrimOp(BitNot, I32(-0xdeadbeef)), ~(-0xdeadbeef)), (ApplyUnaryPrimOp(BitNot, i32na), null), (ApplyUnaryPrimOp(BitNot, I64(0xdeadbeef12345678L)), ~0xdeadbeef12345678L), (ApplyUnaryPrimOp(BitNot, I64(-0xdeadbeef12345678L)), ~(-0xdeadbeef12345678L)), - (ApplyUnaryPrimOp(BitNot, i64na), null) + (ApplyUnaryPrimOp(BitNot, i64na), null), ) } - @Test def testApplyUnaryPrimOpBitCount() { + @Test def testApplyUnaryPrimOpBitCount(): Unit = { assertAllEvalTo( (ApplyUnaryPrimOp(BitCount, I32(0xdeadbeef)), Integer.bitCount(0xdeadbeef)), (ApplyUnaryPrimOp(BitCount, I32(-0xdeadbeef)), Integer.bitCount(-0xdeadbeef)), (ApplyUnaryPrimOp(BitCount, i32na), null), - (ApplyUnaryPrimOp(BitCount, I64(0xdeadbeef12345678L)), java.lang.Long.bitCount(0xdeadbeef12345678L)), - (ApplyUnaryPrimOp(BitCount, I64(-0xdeadbeef12345678L)), java.lang.Long.bitCount(-0xdeadbeef12345678L)), - (ApplyUnaryPrimOp(BitCount, i64na), null) + ( + ApplyUnaryPrimOp(BitCount, I64(0xdeadbeef12345678L)), + java.lang.Long.bitCount(0xdeadbeef12345678L), + ), + ( + ApplyUnaryPrimOp(BitCount, I64(-0xdeadbeef12345678L)), + java.lang.Long.bitCount(-0xdeadbeef12345678L), + ), + (ApplyUnaryPrimOp(BitCount, i64na), null), ) } - @Test def testApplyBinaryPrimOpAdd() { - def assertSumsTo(t: Type, x: Any, y: Any, sum: Any) { + @Test def testApplyBinaryPrimOpAdd(): Unit = { + def assertSumsTo(t: Type, x: Any, y: Any, sum: Any): Unit = assertEvalsTo(ApplyBinaryPrimOp(Add(), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), sum) - } assertSumsTo(TInt32, 5, 3, 8) assertSumsTo(TInt32, 5, null, null) assertSumsTo(TInt32, null, 3, null) @@ -206,10 +212,13 @@ class IRSuite extends HailSuite { assertSumsTo(TFloat64, null, null, null) } - @Test def testApplyBinaryPrimOpSubtract() { - def assertExpected(t: Type, x: Any, y: Any, expected: Any) { - assertEvalsTo(ApplyBinaryPrimOp(Subtract(), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) - } + @Test def testApplyBinaryPrimOpSubtract(): Unit = { + def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = + assertEvalsTo( + ApplyBinaryPrimOp(Subtract(), In(0, t), In(1, t)), + FastSeq(x -> t, y -> t), + expected, + ) assertExpected(TInt32, 5, 2, 3) assertExpected(TInt32, 5, null, null) @@ -232,10 +241,13 @@ class IRSuite extends HailSuite { assertExpected(TFloat64, null, null, null) } - @Test def testApplyBinaryPrimOpMultiply() { - def assertExpected(t: Type, x: Any, y: Any, expected: Any) { - assertEvalsTo(ApplyBinaryPrimOp(Multiply(), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) - } + @Test def testApplyBinaryPrimOpMultiply(): Unit = { + def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = + assertEvalsTo( + ApplyBinaryPrimOp(Multiply(), In(0, 
t), In(1, t)), + FastSeq(x -> t, y -> t), + expected, + ) assertExpected(TInt32, 5, 2, 10) assertExpected(TInt32, 5, null, null) @@ -258,10 +270,13 @@ class IRSuite extends HailSuite { assertExpected(TFloat64, null, null, null) } - @Test def testApplyBinaryPrimOpFloatingPointDivide() { - def assertExpected(t: Type, x: Any, y: Any, expected: Any) { - assertEvalsTo(ApplyBinaryPrimOp(FloatingPointDivide(), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) - } + @Test def testApplyBinaryPrimOpFloatingPointDivide(): Unit = { + def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = + assertEvalsTo( + ApplyBinaryPrimOp(FloatingPointDivide(), In(0, t), In(1, t)), + FastSeq(x -> t, y -> t), + expected, + ) assertExpected(TInt32, 5, 2, 2.5) assertExpected(TInt32, 5, null, null) @@ -284,10 +299,13 @@ class IRSuite extends HailSuite { assertExpected(TFloat64, null, null, null) } - @Test def testApplyBinaryPrimOpRoundToNegInfDivide() { - def assertExpected(t: Type, x: Any, y: Any, expected: Any) { - assertEvalsTo(ApplyBinaryPrimOp(RoundToNegInfDivide(), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) - } + @Test def testApplyBinaryPrimOpRoundToNegInfDivide(): Unit = { + def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = + assertEvalsTo( + ApplyBinaryPrimOp(RoundToNegInfDivide(), In(0, t), In(1, t)), + FastSeq(x -> t, y -> t), + expected, + ) assertExpected(TInt32, 5, 2, 2) assertExpected(TInt32, 5, null, null) @@ -311,9 +329,12 @@ class IRSuite extends HailSuite { } @Test def testApplyBinaryPrimOpBitAnd(): Unit = { - def assertExpected(t: Type, x: Any, y: Any, expected: Any) { - assertEvalsTo(ApplyBinaryPrimOp(BitAnd(), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) - } + def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = + assertEvalsTo( + ApplyBinaryPrimOp(BitAnd(), In(0, t), In(1, t)), + FastSeq(x -> t, y -> t), + expected, + ) assertExpected(TInt32, 5, 2, 5 & 2) assertExpected(TInt32, -5, 2, -5 & 2) @@ -333,9 +354,12 @@ class IRSuite extends HailSuite { } @Test def testApplyBinaryPrimOpBitOr(): Unit = { - def assertExpected(t: Type, x: Any, y: Any, expected: Any) { - assertEvalsTo(ApplyBinaryPrimOp(BitOr(), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) - } + def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = + assertEvalsTo( + ApplyBinaryPrimOp(BitOr(), In(0, t), In(1, t)), + FastSeq(x -> t, y -> t), + expected, + ) assertExpected(TInt32, 5, 2, 5 | 2) assertExpected(TInt32, -5, 2, -5 | 2) @@ -355,9 +379,12 @@ class IRSuite extends HailSuite { } @Test def testApplyBinaryPrimOpBitXOr(): Unit = { - def assertExpected(t: Type, x: Any, y: Any, expected: Any) { - assertEvalsTo(ApplyBinaryPrimOp(BitXOr(), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) - } + def assertExpected(t: Type, x: Any, y: Any, expected: Any): Unit = + assertEvalsTo( + ApplyBinaryPrimOp(BitXOr(), In(0, t), In(1, t)), + FastSeq(x -> t, y -> t), + expected, + ) assertExpected(TInt32, 5, 2, 5 ^ 2) assertExpected(TInt32, -5, 2, -5 ^ 2) @@ -377,9 +404,12 @@ class IRSuite extends HailSuite { } @Test def testApplyBinaryPrimOpLeftShift(): Unit = { - def assertShiftsTo(t: Type, x: Any, y: Any, expected: Any) { - assertEvalsTo(ApplyBinaryPrimOp(LeftShift(), In(0, t), In(1, TInt32)), FastSeq(x -> t, y -> TInt32), expected) - } + def assertShiftsTo(t: Type, x: Any, y: Any, expected: Any): Unit = + assertEvalsTo( + ApplyBinaryPrimOp(LeftShift(), In(0, t), In(1, TInt32)), + FastSeq(x -> t, y -> TInt32), + expected, + ) assertShiftsTo(TInt32, 5, 
2, 5 << 2) assertShiftsTo(TInt32, -5, 2, -5 << 2) @@ -395,9 +425,12 @@ class IRSuite extends HailSuite { } @Test def testApplyBinaryPrimOpRightShift(): Unit = { - def assertShiftsTo(t: Type, x: Any, y: Any, expected: Any) { - assertEvalsTo(ApplyBinaryPrimOp(RightShift(), In(0, t), In(1, TInt32)), FastSeq(x -> t, y -> TInt32), expected) - } + def assertShiftsTo(t: Type, x: Any, y: Any, expected: Any): Unit = + assertEvalsTo( + ApplyBinaryPrimOp(RightShift(), In(0, t), In(1, TInt32)), + FastSeq(x -> t, y -> TInt32), + expected, + ) assertShiftsTo(TInt32, 0xff5, 2, 0xff5 >> 2) assertShiftsTo(TInt32, -5, 2, -5 >> 2) @@ -413,9 +446,12 @@ class IRSuite extends HailSuite { } @Test def testApplyBinaryPrimOpLogicalRightShift(): Unit = { - def assertShiftsTo(t: Type, x: Any, y: Any, expected: Any) { - assertEvalsTo(ApplyBinaryPrimOp(LogicalRightShift(), In(0, t), In(1, TInt32)), FastSeq(x -> t, y -> TInt32), expected) - } + def assertShiftsTo(t: Type, x: Any, y: Any, expected: Any): Unit = + assertEvalsTo( + ApplyBinaryPrimOp(LogicalRightShift(), In(0, t), In(1, TInt32)), + FastSeq(x -> t, y -> TInt32), + expected, + ) assertShiftsTo(TInt32, 0xff5, 2, 0xff5 >>> 2) assertShiftsTo(TInt32, -5, 2, -5 >>> 2) @@ -430,10 +466,9 @@ class IRSuite extends HailSuite { assertShiftsTo(TInt64, null, null, null) } - @Test def testApplyComparisonOpGT() { - def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean) { + @Test def testApplyComparisonOpGT(): Unit = { + def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean): Unit = assertEvalsTo(ApplyComparisonOp(GT(t), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) - } assertComparesTo(TInt32, 1, 1, false) assertComparesTo(TInt32, 0, 1, false) @@ -453,10 +488,13 @@ class IRSuite extends HailSuite { } - @Test def testApplyComparisonOpGTEQ() { - def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean) { - assertEvalsTo(ApplyComparisonOp(GTEQ(t), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) - } + @Test def testApplyComparisonOpGTEQ(): Unit = { + def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean): Unit = + assertEvalsTo( + ApplyComparisonOp(GTEQ(t), In(0, t), In(1, t)), + FastSeq(x -> t, y -> t), + expected, + ) assertComparesTo(TInt32, 1, 1, true) assertComparesTo(TInt32, 0, 1, false) @@ -475,10 +513,9 @@ class IRSuite extends HailSuite { assertComparesTo(TFloat64, 1.0, 0.0, true) } - @Test def testApplyComparisonOpLT() { - def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean) { + @Test def testApplyComparisonOpLT(): Unit = { + def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean): Unit = assertEvalsTo(ApplyComparisonOp(LT(t), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) - } assertComparesTo(TInt32, 1, 1, false) assertComparesTo(TInt32, 0, 1, true) @@ -498,10 +535,13 @@ class IRSuite extends HailSuite { } - @Test def testApplyComparisonOpLTEQ() { - def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean) { - assertEvalsTo(ApplyComparisonOp(LTEQ(t), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) - } + @Test def testApplyComparisonOpLTEQ(): Unit = { + def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean): Unit = + assertEvalsTo( + ApplyComparisonOp(LTEQ(t), In(0, t), In(1, t)), + FastSeq(x -> t, y -> t), + expected, + ) assertComparesTo(TInt32, 1, 1, true) assertComparesTo(TInt32, 0, 1, true) @@ -521,10 +561,9 @@ class IRSuite extends HailSuite { } - @Test def testApplyComparisonOpEQ() { - def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean) { + 
@Test def testApplyComparisonOpEQ(): Unit = { + def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean): Unit = assertEvalsTo(ApplyComparisonOp(EQ(t), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) - } assertComparesTo(TInt32, 1, 1, expected = true) assertComparesTo(TInt32, 0, 1, expected = false) @@ -543,10 +582,13 @@ class IRSuite extends HailSuite { assertComparesTo(TFloat64, 1.0, 0.0, expected = false) } - @Test def testApplyComparisonOpNE() { - def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean) { - assertEvalsTo(ApplyComparisonOp(NEQ(t), In(0, t), In(1, t)), FastSeq(x -> t, y -> t), expected) - } + @Test def testApplyComparisonOpNE(): Unit = { + def assertComparesTo(t: Type, x: Any, y: Any, expected: Boolean): Unit = + assertEvalsTo( + ApplyComparisonOp(NEQ(t), In(0, t), In(1, t)), + FastSeq(x -> t, y -> t), + expected, + ) assertComparesTo(TInt32, 1, 1, expected = false) assertComparesTo(TInt32, 0, 1, expected = true) @@ -565,18 +607,17 @@ class IRSuite extends HailSuite { assertComparesTo(TFloat64, 1.0, 0.0, expected = true) } - @Test def testDieCodeBUilder() { + @Test def testDieCodeBUilder(): Unit = assertFatal(Die("msg1", TInt32) + Die("msg2", TInt32), "msg1") - } - @Test def testIf() { + @Test def testIf(): Unit = { assertEvalsTo(If(True(), I32(5), I32(7)), 5) assertEvalsTo(If(False(), I32(5), I32(7)), 7) assertEvalsTo(If(NA(TBoolean), I32(5), I32(7)), null) assertEvalsTo(If(True(), NA(TInt32), I32(7)), null) } - @DataProvider(name="SwitchEval") + @DataProvider(name = "SwitchEval") def switchEvalRules: Array[Array[Any]] = Array( Array(I32(-1), I32(Int.MinValue), FastSeq(0, Int.MaxValue).map(I32), Int.MinValue), @@ -592,288 +633,363 @@ class IRSuite extends HailSuite { def testSwitch(x: IR, default: IR, cases: IndexedSeq[IR], result: Any): Unit = assertEvalsTo(Switch(x, default, cases), result) - @Test def testLet() { + @Test def testLet(): Unit = { assertEvalsTo(Let(FastSeq("v" -> I32(5)), Ref("v", TInt32)), 5) assertEvalsTo(Let(FastSeq("v" -> NA(TInt32)), Ref("v", TInt32)), null) assertEvalsTo(Let(FastSeq("v" -> I32(5)), NA(TInt32)), null) assertEvalsTo( - ToArray(mapIR(Let(FastSeq("v" -> I32(5)), StreamRange(0, Ref("v", TInt32), 1))) { x => x + I32(2) }), - FastSeq(2, 3, 4, 5, 6) + ToArray(mapIR(Let(FastSeq("v" -> I32(5)), StreamRange(0, Ref("v", TInt32), 1))) { x => + x + I32(2) + }), + FastSeq(2, 3, 4, 5, 6), ) assertEvalsTo( - ToArray(StreamMap(Let(FastSeq("q" -> I32(2)), - StreamMap(Let(FastSeq("v" -> (Ref("q", TInt32) + I32(3))), - StreamRange(0, Ref("v", TInt32), 1)), - "x", Ref("x", TInt32) + Ref("q", TInt32))), - "y", Ref("y", TInt32) + I32(3))), - FastSeq(5, 6, 7, 8, 9)) + ToArray(StreamMap( + Let( + FastSeq("q" -> I32(2)), + StreamMap( + Let(FastSeq("v" -> (Ref("q", TInt32) + I32(3))), StreamRange(0, Ref("v", TInt32), 1)), + "x", + Ref("x", TInt32) + Ref("q", TInt32), + ), + ), + "y", + Ref("y", TInt32) + I32(3), + )), + FastSeq(5, 6, 7, 8, 9), + ) // test let binding streams - assertEvalsTo(Let(FastSeq("s" -> MakeStream(IndexedSeq(I32(0), I32(5)), TStream(TInt32))), ToArray(Ref("s", TStream(TInt32)))), - FastSeq(0, 5)) - assertEvalsTo(Let(FastSeq("s" -> NA(TStream(TInt32))), ToArray(Ref("s", TStream(TInt32)))), - null) assertEvalsTo( - ToArray(Let(FastSeq("s" -> MakeStream(IndexedSeq(I32(0), I32(5)), TStream(TInt32))), - StreamTake(Ref("s", TStream(TInt32)), I32(1)))), - FastSeq(0)) + Let( + FastSeq("s" -> MakeStream(IndexedSeq(I32(0), I32(5)), TStream(TInt32))), + ToArray(Ref("s", TStream(TInt32))), + ), + FastSeq(0, 5), + ) + 
assertEvalsTo( + Let(FastSeq("s" -> NA(TStream(TInt32))), ToArray(Ref("s", TStream(TInt32)))), + null, + ) + assertEvalsTo( + ToArray(Let( + FastSeq("s" -> MakeStream(IndexedSeq(I32(0), I32(5)), TStream(TInt32))), + StreamTake(Ref("s", TStream(TInt32)), I32(1)), + )), + FastSeq(0), + ) } - @Test def testMakeArray() { - assertEvalsTo(MakeArray(FastSeq(I32(5), NA(TInt32), I32(-3)), TArray(TInt32)), FastSeq(5, null, -3)) + @Test def testMakeArray(): Unit = { + assertEvalsTo( + MakeArray(FastSeq(I32(5), NA(TInt32), I32(-3)), TArray(TInt32)), + FastSeq(5, null, -3), + ) assertEvalsTo(MakeArray(FastSeq(), TArray(TInt32)), FastSeq()) } - @Test def testGetNestedElementPTypesI32() { + @Test def testGetNestedElementPTypesI32(): Unit = { var types = IndexedSeq(PInt32(true)) - var res = InferPType.getCompatiblePType(types) + var res = InferPType.getCompatiblePType(types) assert(res == PInt32(true)) types = IndexedSeq(PInt32(false)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PInt32(false)) types = IndexedSeq(PInt32(false), PInt32(true)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PInt32(false)) types = IndexedSeq(PInt32(true), PInt32(true)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PInt32(true)) } - @Test def testGetNestedElementPTypesI64() { + @Test def testGetNestedElementPTypesI64(): Unit = { var types = IndexedSeq(PInt64(true)) - var res = InferPType.getCompatiblePType(types) + var res = InferPType.getCompatiblePType(types) assert(res == PInt64(true)) types = IndexedSeq(PInt64(false)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PInt64(false)) types = IndexedSeq(PInt64(false), PInt64(true)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PInt64(false)) types = IndexedSeq(PInt64(true), PInt64(true)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PInt64(true)) } - @Test def testGetNestedElementPFloat32() { + @Test def testGetNestedElementPFloat32(): Unit = { var types = IndexedSeq(PFloat32(true)) - var res = InferPType.getCompatiblePType(types) + var res = InferPType.getCompatiblePType(types) assert(res == PFloat32(true)) types = IndexedSeq(PFloat32(false)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PFloat32(false)) types = IndexedSeq(PFloat32(false), PFloat32(true)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PFloat32(false)) types = IndexedSeq(PFloat32(true), PFloat32(true)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PFloat32(true)) } - @Test def testGetNestedElementPFloat64() { + @Test def testGetNestedElementPFloat64(): Unit = { var types = IndexedSeq(PFloat64(true)) - var res = InferPType.getCompatiblePType(types) + var res = InferPType.getCompatiblePType(types) assert(res == PFloat64(true)) types = IndexedSeq(PFloat64(false)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PFloat64(false)) types = IndexedSeq(PFloat64(false), PFloat64(true)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PFloat64(false)) types = IndexedSeq(PFloat64(true), PFloat64(true)) - 
res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PFloat64(true)) } - @Test def testGetNestedElementPCanonicalString() { + @Test def testGetNestedElementPCanonicalString(): Unit = { var types = IndexedSeq(PCanonicalString(true)) - var res = InferPType.getCompatiblePType(types) + var res = InferPType.getCompatiblePType(types) assert(res == PCanonicalString(true)) types = IndexedSeq(PCanonicalString(false)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalString(false)) types = IndexedSeq(PCanonicalString(false), PCanonicalString(true)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalString(false)) types = IndexedSeq(PCanonicalString(true), PCanonicalString(true)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalString(true)) } - @Test def testGetNestedPCanonicalArray() { + @Test def testGetNestedPCanonicalArray(): Unit = { var types = IndexedSeq(PCanonicalArray(PInt32(true), true)) - var res = InferPType.getCompatiblePType(types) + var res = InferPType.getCompatiblePType(types) assert(res == PCanonicalArray(PInt32(true), true)) types = IndexedSeq(PCanonicalArray(PInt32(true), false)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalArray(PInt32(true), false)) types = IndexedSeq(PCanonicalArray(PInt32(false), true)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalArray(PInt32(false), true)) types = IndexedSeq(PCanonicalArray(PInt32(false), false)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalArray(PInt32(false), false)) types = IndexedSeq( PCanonicalArray(PInt32(true), true), - PCanonicalArray(PInt32(true), true) + PCanonicalArray(PInt32(true), true), ) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalArray(PInt32(true), true)) types = IndexedSeq( PCanonicalArray(PInt32(false), true), - PCanonicalArray(PInt32(true), true) + PCanonicalArray(PInt32(true), true), ) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalArray(PInt32(false), true)) types = IndexedSeq( PCanonicalArray(PInt32(false), true), - PCanonicalArray(PInt32(true), false) + PCanonicalArray(PInt32(true), false), ) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalArray(PInt32(false), false)) types = IndexedSeq( PCanonicalArray(PCanonicalArray(PInt32(true), true), true), - PCanonicalArray(PCanonicalArray(PInt32(true), true), true) + PCanonicalArray(PCanonicalArray(PInt32(true), true), true), ) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalArray(PCanonicalArray(PInt32(true), true), true)) types = IndexedSeq( PCanonicalArray(PCanonicalArray(PInt32(true), true), true), - PCanonicalArray(PCanonicalArray(PInt32(false), true), true) + PCanonicalArray(PCanonicalArray(PInt32(false), true), true), ) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalArray(PCanonicalArray(PInt32(false), true), true)) types = IndexedSeq( PCanonicalArray(PCanonicalArray(PInt32(true), false), true), - 
PCanonicalArray(PCanonicalArray(PInt32(false), true), true) + PCanonicalArray(PCanonicalArray(PInt32(false), true), true), ) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalArray(PCanonicalArray(PInt32(false), false), true)) types = IndexedSeq( PCanonicalArray(PCanonicalArray(PInt32(true), false), false), - PCanonicalArray(PCanonicalArray(PInt32(false), true), true) + PCanonicalArray(PCanonicalArray(PInt32(false), true), true), ) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalArray(PCanonicalArray(PInt32(false), false), false)) } - @Test def testGetNestedElementPCanonicalDict() { + @Test def testGetNestedElementPCanonicalDict(): Unit = { var types = IndexedSeq(PCanonicalDict(PInt32(true), PCanonicalString(true), true)) - var res = InferPType.getCompatiblePType(types) + var res = InferPType.getCompatiblePType(types) assert(res == PCanonicalDict(PInt32(true), PCanonicalString(true), true)) types = IndexedSeq(PCanonicalDict(PInt32(false), PCanonicalString(true), true)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalDict(PInt32(false), PCanonicalString(true), true)) types = IndexedSeq(PCanonicalDict(PInt32(true), PCanonicalString(false), true)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalDict(PInt32(true), PCanonicalString(false), true)) types = IndexedSeq(PCanonicalDict(PInt32(true), PCanonicalString(true), false)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalDict(PInt32(true), PCanonicalString(true), false)) types = IndexedSeq(PCanonicalDict(PInt32(false), PCanonicalString(false), false)) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalDict(PInt32(false), PCanonicalString(false), false)) types = IndexedSeq( PCanonicalDict(PInt32(true), PCanonicalString(true), true), - PCanonicalDict(PInt32(true), PCanonicalString(true), true) + PCanonicalDict(PInt32(true), PCanonicalString(true), true), ) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalDict(PInt32(true), PCanonicalString(true), true)) types = IndexedSeq( PCanonicalDict(PInt32(true), PCanonicalString(true), false), - PCanonicalDict(PInt32(true), PCanonicalString(true), false) + PCanonicalDict(PInt32(true), PCanonicalString(true), false), ) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalDict(PInt32(true), PCanonicalString(true), false)) types = IndexedSeq( PCanonicalDict(PInt32(false), PCanonicalString(true), true), - PCanonicalDict(PInt32(true), PCanonicalString(true), true) + PCanonicalDict(PInt32(true), PCanonicalString(true), true), ) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalDict(PInt32(false), PCanonicalString(true), true)) types = IndexedSeq( PCanonicalDict(PInt32(false), PCanonicalString(true), true), - PCanonicalDict(PInt32(true), PCanonicalString(false), true) + PCanonicalDict(PInt32(true), PCanonicalString(false), true), ) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalDict(PInt32(false), PCanonicalString(false), true)) types = IndexedSeq( PCanonicalDict(PInt32(false), 
PCanonicalString(true), false), - PCanonicalDict(PInt32(true), PCanonicalString(false), true) + PCanonicalDict(PInt32(true), PCanonicalString(false), true), ) - res = InferPType.getCompatiblePType(types) + res = InferPType.getCompatiblePType(types) assert(res == PCanonicalDict(PInt32(false), PCanonicalString(false), false)) types = IndexedSeq( + PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(true), PCanonicalString(true), true), + true, + ), PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(true), PCanonicalString(true), true), true), - PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(true), PCanonicalString(true), true), true) ) - res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(true), PCanonicalString(true), true), true)) + res = InferPType.getCompatiblePType(types) + assert(res == PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(true), PCanonicalString(true), true), + true, + )) types = IndexedSeq( - PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(false), PCanonicalString(true), true), true), - PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(true), PCanonicalString(true), true), true) + PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(false), PCanonicalString(true), true), + true, + ), + PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(true), PCanonicalString(true), true), true), ) - res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(false), PCanonicalString(true), true), true)) + res = InferPType.getCompatiblePType(types) + assert(res == PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(false), PCanonicalString(true), true), + true, + )) types = IndexedSeq( - PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(false), PCanonicalString(true), true), true), - PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(true), PCanonicalString(false), true), true) + PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(false), PCanonicalString(true), true), + true, + ), + PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(true), PCanonicalString(false), true), + true, + ), ) - res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(false), PCanonicalString(false), true), true)) + res = InferPType.getCompatiblePType(types) + assert(res == PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(false), PCanonicalString(false), true), + true, + )) types = IndexedSeq( - PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(false), PCanonicalString(true), true), true), - PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(true), PCanonicalString(false), true), true) + PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(false), PCanonicalString(true), true), + true, + ), + PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(true), PCanonicalString(false), true), + true, + ), ) - res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(false), PCanonicalString(false), true), true)) + res = InferPType.getCompatiblePType(types) + assert(res == PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(false), PCanonicalString(false), true), + true, + )) types = IndexedSeq( - PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(false), PCanonicalString(true), false), true), - PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(true), PCanonicalString(false), true), true) + PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(false), 
PCanonicalString(true), false), + true, + ), + PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(true), PCanonicalString(false), true), + true, + ), ) - res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalDict(PInt32(true), PCanonicalDict(PInt32(false), PCanonicalString(false), false), true)) + res = InferPType.getCompatiblePType(types) + assert(res == PCanonicalDict( + PInt32(true), + PCanonicalDict(PInt32(false), PCanonicalString(false), false), + true, + )) } - @Test def testGetNestedElementPCanonicalStruct() { + @Test def testGetNestedElementPCanonicalStruct(): Unit = { var types = IndexedSeq(PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true))) var res = InferPType.getCompatiblePType(types) assert(res == PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true))) @@ -896,51 +1012,93 @@ class IRSuite extends HailSuite { types = IndexedSeq( PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true)), - PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true)) + PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true)), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true))) types = IndexedSeq( PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32(true)), - PCanonicalStruct(true, "a" -> PInt32(false), "b" -> PInt32(false)) + PCanonicalStruct(true, "a" -> PInt32(false), "b" -> PInt32(false)), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalStruct(true, "a" -> PInt32(false), "b" -> PInt32(false))) types = IndexedSeq( PCanonicalStruct(false, "a" -> PInt32(true), "b" -> PInt32(true)), - PCanonicalStruct(true, "a" -> PInt32(false), "b" -> PInt32(false)) + PCanonicalStruct(true, "a" -> PInt32(false), "b" -> PInt32(false)), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalStruct(false, "a" -> PInt32(false), "b" -> PInt32(false))) types = IndexedSeq( - PCanonicalStruct(true, "a" -> PCanonicalStruct(true, "c" -> PInt32(true), "d" -> PInt32(true)),"b" -> PInt32(true)) + PCanonicalStruct( + true, + "a" -> PCanonicalStruct(true, "c" -> PInt32(true), "d" -> PInt32(true)), + "b" -> PInt32(true), + ) ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct(true, "a" -> PCanonicalStruct(true, "c" -> PInt32(true), "d" -> PInt32(true)), "b" -> PInt32(true))) + assert(res == PCanonicalStruct( + true, + "a" -> PCanonicalStruct(true, "c" -> PInt32(true), "d" -> PInt32(true)), + "b" -> PInt32(true), + )) types = IndexedSeq( - PCanonicalStruct(true, "a" -> PCanonicalStruct(true, "c" -> PInt32(false), "d" -> PInt32(true)),"b" -> PInt32(true)) + PCanonicalStruct( + true, + "a" -> PCanonicalStruct(true, "c" -> PInt32(false), "d" -> PInt32(true)), + "b" -> PInt32(true), + ) ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct(true, "a" -> PCanonicalStruct(true, "c" -> PInt32(false), "d" -> PInt32(true)), "b" -> PInt32(true))) + assert(res == PCanonicalStruct( + true, + "a" -> PCanonicalStruct(true, "c" -> PInt32(false), "d" -> PInt32(true)), + "b" -> PInt32(true), + )) types = IndexedSeq( - PCanonicalStruct(true, "a" -> PCanonicalStruct(true, "c" -> PInt32(false), "d" -> PInt32(false)), "b" -> PInt32(true)), - PCanonicalStruct(true, "a" -> PCanonicalStruct(true, "c" -> PInt32(true), "d" -> PInt32(true)), "b" -> PInt32(true))) + PCanonicalStruct( + true, + "a" -> PCanonicalStruct(true, "c" -> PInt32(false), "d" -> PInt32(false)), + "b" -> PInt32(true), + ), + PCanonicalStruct( + true, + "a" -> 
PCanonicalStruct(true, "c" -> PInt32(true), "d" -> PInt32(true)), + "b" -> PInt32(true), + ), + ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct(true, "a" -> PCanonicalStruct(true, "c" -> PInt32(false), "d" -> PInt32(false)), "b" -> PInt32(true))) + assert(res == PCanonicalStruct( + true, + "a" -> PCanonicalStruct(true, "c" -> PInt32(false), "d" -> PInt32(false)), + "b" -> PInt32(true), + )) types = IndexedSeq( - PCanonicalStruct(true, "a" -> PCanonicalStruct(false, "c" -> PInt32(false), "d" -> PInt32(false)), "b" -> PInt32(true)), - PCanonicalStruct(true, "a" -> PCanonicalStruct(true, "c" -> PInt32(true), "d" -> PInt32(true)), "b" -> PInt32(true))) + PCanonicalStruct( + true, + "a" -> PCanonicalStruct(false, "c" -> PInt32(false), "d" -> PInt32(false)), + "b" -> PInt32(true), + ), + PCanonicalStruct( + true, + "a" -> PCanonicalStruct(true, "c" -> PInt32(true), "d" -> PInt32(true)), + "b" -> PInt32(true), + ), + ) res = InferPType.getCompatiblePType(types) - assert(res == PCanonicalStruct(true, "a" -> PCanonicalStruct(false, "c" -> PInt32(false), "d" -> PInt32(false)), "b" -> PInt32(true))) + assert(res == PCanonicalStruct( + true, + "a" -> PCanonicalStruct(false, "c" -> PInt32(false), "d" -> PInt32(false)), + "b" -> PInt32(true), + )) } - @Test def testGetNestedElementPCanonicalTuple() { + @Test def testGetNestedElementPCanonicalTuple(): Unit = { var types = IndexedSeq(PCanonicalTuple(true, PInt32(true))) var res = InferPType.getCompatiblePType(types) assert(res == PCanonicalTuple(true, PInt32(true))) @@ -959,41 +1117,41 @@ class IRSuite extends HailSuite { types = IndexedSeq( PCanonicalTuple(true, PInt32(true)), - PCanonicalTuple(true, PInt32(true)) + PCanonicalTuple(true, PInt32(true)), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalTuple(true, PInt32(true))) types = IndexedSeq( PCanonicalTuple(true, PInt32(true)), - PCanonicalTuple(false, PInt32(true)) + PCanonicalTuple(false, PInt32(true)), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalTuple(false, PInt32(true))) types = IndexedSeq( PCanonicalTuple(true, PInt32(false)), - PCanonicalTuple(false, PInt32(true)) + PCanonicalTuple(false, PInt32(true)), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalTuple(false, PInt32(false))) types = IndexedSeq( PCanonicalTuple(true, PCanonicalTuple(true, PInt32(true))), - PCanonicalTuple(true, PCanonicalTuple(true, PInt32(false))) + PCanonicalTuple(true, PCanonicalTuple(true, PInt32(false))), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalTuple(true, PCanonicalTuple(true, PInt32(false)))) types = IndexedSeq( PCanonicalTuple(true, PCanonicalTuple(false, PInt32(true))), - PCanonicalTuple(true, PCanonicalTuple(true, PInt32(false))) + PCanonicalTuple(true, PCanonicalTuple(true, PInt32(false))), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalTuple(true, PCanonicalTuple(false, PInt32(false)))) } - @Test def testGetNestedElementPCanonicalSet() { + @Test def testGetNestedElementPCanonicalSet(): Unit = { var types = IndexedSeq(PCanonicalSet(PInt32(true), true)) var res = InferPType.getCompatiblePType(types) assert(res == PCanonicalSet(PInt32(true), true)) @@ -1012,48 +1170,48 @@ class IRSuite extends HailSuite { types = IndexedSeq( PCanonicalSet(PInt32(true), true), - PCanonicalSet(PInt32(true), true) + PCanonicalSet(PInt32(true), true), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalSet(PInt32(true), true)) types = IndexedSeq( 
PCanonicalSet(PInt32(false), true), - PCanonicalSet(PInt32(true), true) + PCanonicalSet(PInt32(true), true), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalSet(PInt32(false), true)) types = IndexedSeq( PCanonicalSet(PInt32(false), true), - PCanonicalSet(PInt32(true), false) + PCanonicalSet(PInt32(true), false), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalSet(PInt32(false), false)) types = IndexedSeq( PCanonicalSet(PCanonicalSet(PInt32(true), true), true), - PCanonicalSet(PCanonicalSet(PInt32(true), true), true) + PCanonicalSet(PCanonicalSet(PInt32(true), true), true), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalSet(PCanonicalSet(PInt32(true), true), true)) types = IndexedSeq( PCanonicalSet(PCanonicalSet(PInt32(true), true), true), - PCanonicalSet(PCanonicalSet(PInt32(false), true), true) + PCanonicalSet(PCanonicalSet(PInt32(false), true), true), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalSet(PCanonicalSet(PInt32(false), true), true)) types = IndexedSeq( PCanonicalSet(PCanonicalSet(PInt32(true), false), true), - PCanonicalSet(PCanonicalSet(PInt32(false), true), true) + PCanonicalSet(PCanonicalSet(PInt32(false), true), true), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalSet(PCanonicalSet(PInt32(false), false), true)) } - @Test def testGetNestedElementPCanonicalInterval() { + @Test def testGetNestedElementPCanonicalInterval(): Unit = { var types = IndexedSeq(PCanonicalInterval(PInt32(true), true)) var res = InferPType.getCompatiblePType(types) assert(res == PCanonicalInterval(PInt32(true), true)) @@ -1072,90 +1230,95 @@ class IRSuite extends HailSuite { types = IndexedSeq( PCanonicalInterval(PInt32(true), true), - PCanonicalInterval(PInt32(true), true) + PCanonicalInterval(PInt32(true), true), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalInterval(PInt32(true), true)) types = IndexedSeq( PCanonicalInterval(PInt32(false), true), - PCanonicalInterval(PInt32(true), true) + PCanonicalInterval(PInt32(true), true), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalInterval(PInt32(false), true)) types = IndexedSeq( PCanonicalInterval(PInt32(true), true), - PCanonicalInterval(PInt32(true), false) + PCanonicalInterval(PInt32(true), false), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalInterval(PInt32(true), false)) types = IndexedSeq( PCanonicalInterval(PInt32(false), true), - PCanonicalInterval(PInt32(true), false) + PCanonicalInterval(PInt32(true), false), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalInterval(PInt32(false), false)) types = IndexedSeq( PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true), - PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true) + PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true)) types = IndexedSeq( PCanonicalInterval(PCanonicalInterval(PInt32(true), false), true), - PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true) + PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalInterval(PCanonicalInterval(PInt32(true), false), true)) types = IndexedSeq( PCanonicalInterval(PCanonicalInterval(PInt32(false), true), true), - PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true) + 
PCanonicalInterval(PCanonicalInterval(PInt32(true), true), true), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalInterval(PCanonicalInterval(PInt32(false), true), true)) types = IndexedSeq( PCanonicalInterval(PCanonicalInterval(PInt32(true), false), true), - PCanonicalInterval(PCanonicalInterval(PInt32(false), true), true) + PCanonicalInterval(PCanonicalInterval(PInt32(false), true), true), ) res = InferPType.getCompatiblePType(types) assert(res == PCanonicalInterval(PCanonicalInterval(PInt32(false), false), true)) } - @Test def testMakeStruct() { + @Test def testMakeStruct(): Unit = { assertEvalsTo(MakeStruct(FastSeq()), Row()) assertEvalsTo(MakeStruct(FastSeq("a" -> NA(TInt32), "b" -> 4, "c" -> 0.5)), Row(null, 4, 0.5)) - //making sure wide structs get emitted without failure + // making sure wide structs get emitted without failure assertEvalsTo(GetField(MakeStruct((0 until 20000).map(i => s"foo$i" -> I32(1))), "foo1"), 1) } @Test def testMakeArrayWithDifferentRequiredness(): Unit = { val pt1 = PCanonicalArray(PCanonicalStruct("a" -> PInt32(), "b" -> PCanonicalArray(PInt32()))) - val pt2 = PCanonicalArray(PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PCanonicalArray(PInt32(), true))) + val pt2 = PCanonicalArray(PCanonicalStruct( + true, + "a" -> PInt32(true), + "b" -> PCanonicalArray(PInt32(), true), + )) val value = Row(2, FastSeq(1)) assertEvalsTo( MakeArray( In(0, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(pt1.elementType))), - In(1, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(pt2.elementType)))), + In(1, SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(pt2.elementType))), + ), FastSeq((null, pt1.virtualType.elementType), (value, pt2.virtualType.elementType)), - FastSeq(null, value) + FastSeq(null, value), ) } - @Test def testMakeTuple() { + @Test def testMakeTuple(): Unit = { assertEvalsTo(MakeTuple.ordered(FastSeq()), Row()) assertEvalsTo(MakeTuple.ordered(FastSeq(NA(TInt32), 4, 0.5)), Row(null, 4, 0.5)) - //making sure wide structs get emitted without failure + // making sure wide structs get emitted without failure assertEvalsTo(GetTupleElement(MakeTuple.ordered((0 until 20000).map(I32)), 1), 1) } - @Test def testGetTupleElement() { + @Test def testGetTupleElement(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val t = MakeTuple.ordered(FastSeq(I32(5), Str("abc"), NA(TInt32))) @@ -1167,44 +1330,56 @@ class IRSuite extends HailSuite { assertEvalsTo(GetTupleElement(na, 0), null) } - @Test def testLetBoundPrunedTuple(): Unit = { + @Test def testLetBoundPrunedTuple(): Unit = { implicit val execStrats = ExecStrategy.unoptimizedCompileOnly val t2 = MakeTuple(FastSeq((2, I32(5)))) - val letBoundTuple = bindIR(t2) { tupleRef => - GetTupleElement(tupleRef, 2) - } + val letBoundTuple = bindIR(t2)(tupleRef => GetTupleElement(tupleRef, 2)) assertEvalsTo(letBoundTuple, 5) } - @Test def testArrayRef() { - assertEvalsTo(ArrayRef(MakeArray(FastSeq(I32(5), NA(TInt32)), TArray(TInt32)), I32(0), ErrorIDs.NO_ERROR), 5) - assertEvalsTo(ArrayRef(MakeArray(FastSeq(I32(5), NA(TInt32)), TArray(TInt32)), I32(1), ErrorIDs.NO_ERROR), null) - assertEvalsTo(ArrayRef(MakeArray(FastSeq(I32(5), NA(TInt32)), TArray(TInt32)), NA(TInt32), ErrorIDs.NO_ERROR), null) + @Test def testArrayRef(): Unit = { + assertEvalsTo( + ArrayRef(MakeArray(FastSeq(I32(5), NA(TInt32)), TArray(TInt32)), I32(0), ErrorIDs.NO_ERROR), + 5, + ) + assertEvalsTo( + ArrayRef(MakeArray(FastSeq(I32(5), NA(TInt32)), TArray(TInt32)), I32(1), ErrorIDs.NO_ERROR), + null, 
+ ) + assertEvalsTo( + ArrayRef( + MakeArray(FastSeq(I32(5), NA(TInt32)), TArray(TInt32)), + NA(TInt32), + ErrorIDs.NO_ERROR, + ), + null, + ) - assertFatal(ArrayRef(MakeArray(FastSeq(I32(5)), TArray(TInt32)), I32(2)), "array index out of bounds") + assertFatal( + ArrayRef(MakeArray(FastSeq(I32(5)), TArray(TInt32)), I32(2)), + "array index out of bounds", + ) } - @Test def testArrayLen() { + @Test def testArrayLen(): Unit = { assertEvalsTo(ArrayLen(NA(TArray(TInt32))), null) assertEvalsTo(ArrayLen(MakeArray(FastSeq(), TArray(TInt32))), 0) assertEvalsTo(ArrayLen(MakeArray(FastSeq(I32(5), NA(TInt32)), TArray(TInt32))), 2) } - @Test def testArraySort() { + @Test def testArraySort(): Unit = { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(ArraySort(ToStream(NA(TArray(TInt32)))), null) val a = MakeArray(FastSeq(I32(-7), I32(2), NA(TInt32), I32(2)), TArray(TInt32)) - assertEvalsTo(ArraySort(ToStream(a)), - FastSeq(-7, 2, 2, null)) - assertEvalsTo(ArraySort(ToStream(a), False()), - FastSeq(2, 2, -7, null)) + assertEvalsTo(ArraySort(ToStream(a)), FastSeq(-7, 2, 2, null)) + assertEvalsTo(ArraySort(ToStream(a), False()), FastSeq(2, 2, -7, null)) } - @Test def testStreamZip() { + @Test def testStreamZip(): Unit = { val range12 = StreamRange(0, 12, 1) val range6 = StreamRange(0, 12, 2) val range8 = StreamRange(0, 24, 3) @@ -1215,8 +1390,10 @@ class IRSuite extends HailSuite { def zip(behavior: ArrayZipBehavior, irs: IR*): IR = StreamZip( irs.toFastSeq, irs.indices.map(_.toString), - MakeTuple.ordered(irs.toArray.zipWithIndex.map { case (ir, i) => Ref(i.toString, ir.typ.asInstanceOf[TStream].elementType) }), - behavior + MakeTuple.ordered(irs.toArray.zipWithIndex.map { case (ir, i) => + Ref(i.toString, ir.typ.asInstanceOf[TStream].elementType) + }), + behavior, ) def zipToTuple(behavior: ArrayZipBehavior, irs: IR*): IR = ToArray(zip(behavior, irs: _*)) @@ -1244,11 +1421,17 @@ class IRSuite extends HailSuite { assertEvalsTo(StreamLen(zip(ArrayZipBehavior.AssumeSameLength, range8, range8)), 8) // https://github.com/hail-is/hail/issues/8359 - is.hail.TestUtils.assertThrows[HailException](zipToTuple(ArrayZipBehavior.AssertSameLength, range6, range8): IR, "zip: length mismatch": String) - is.hail.TestUtils.assertThrows[HailException](zipToTuple(ArrayZipBehavior.AssertSameLength, range12, lit6): IR, "zip: length mismatch": String) + is.hail.TestUtils.assertThrows[HailException]( + zipToTuple(ArrayZipBehavior.AssertSameLength, range6, range8): IR, + "zip: length mismatch": String, + ) + is.hail.TestUtils.assertThrows[HailException]( + zipToTuple(ArrayZipBehavior.AssertSameLength, range12, lit6): IR, + "zip: length mismatch": String, + ) } - @Test def testToSet() { + @Test def testToSet(): Unit = { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(ToSet(ToStream(NA(TArray(TInt32)))), null) @@ -1258,91 +1441,92 @@ class IRSuite extends HailSuite { assertEvalsTo(ToSet(ToStream(a)), Set(-7, 2, null)) } - @Test def testToArrayFromSet() { + @Test def testToArrayFromSet(): Unit = { val t = TSet(TInt32) assertEvalsTo(CastToArray(NA(t)), null) - assertEvalsTo(CastToArray(In(0, t)), - FastSeq((Set(-7, 2, null), t)), - FastSeq(-7, 2, null)) + assertEvalsTo(CastToArray(In(0, t)), FastSeq((Set(-7, 2, null), t)), FastSeq(-7, 2, null)) } - @Test def testToDict() { + @Test def testToDict(): Unit = { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(ToDict(ToStream(NA(TArray(TTuple(FastSeq(TInt32, TString): _*))))), null) - val a = MakeArray(FastSeq( - MakeTuple.ordered(FastSeq(I32(5), 
Str("a"))), - MakeTuple.ordered(FastSeq(I32(5), Str("a"))), // duplicate key-value pair - MakeTuple.ordered(FastSeq(NA(TInt32), Str("b"))), - MakeTuple.ordered(FastSeq(I32(3), NA(TString))), - NA(TTuple(FastSeq(TInt32, TString): _*)) // missing value - ), TArray(TTuple(FastSeq(TInt32, TString): _*))) + val a = MakeArray( + FastSeq( + MakeTuple.ordered(FastSeq(I32(5), Str("a"))), + MakeTuple.ordered(FastSeq(I32(5), Str("a"))), // duplicate key-value pair + MakeTuple.ordered(FastSeq(NA(TInt32), Str("b"))), + MakeTuple.ordered(FastSeq(I32(3), NA(TString))), + NA(TTuple(FastSeq(TInt32, TString): _*)), // missing value + ), + TArray(TTuple(FastSeq(TInt32, TString): _*)), + ) assertEvalsTo(ToDict(ToStream(a)), Map(5 -> "a", (null, "b"), 3 -> null)) } - @Test def testToArrayFromDict() { + @Test def testToArrayFromDict(): Unit = { val t = TDict(TInt32, TString) assertEvalsTo(CastToArray(NA(t)), null) - val d = Map(1 -> "a", 2 -> null, (null, "c")) - assertEvalsTo(CastToArray(In(0, t)), + val d: Map[Any, Any] = Map(1 -> "a", 2 -> null, (null, "c")) + assertEvalsTo( + CastToArray(In(0, t)), // wtf you can't do null -> ... FastSeq((d, t)), - FastSeq(Row(1, "a"), Row(2, null), Row(null, "c"))) + FastSeq(Row(1, "a"), Row(2, null), Row(null, "c")), + ) } - @Test def testToArrayFromArray() { + @Test def testToArrayFromArray(): Unit = { val t = TArray(TInt32) assertEvalsTo(NA(t), null) - assertEvalsTo(In(0, t), - FastSeq((FastSeq(-7, 2, null, 2), t)), - FastSeq(-7, 2, null, 2)) + assertEvalsTo(In(0, t), FastSeq((FastSeq(-7, 2, null, 2), t)), FastSeq(-7, 2, null, 2)) } - @Test def testSetContains() { + @Test def testSetContains(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val t = TSet(TInt32) assertEvalsTo(invoke("contains", TBoolean, NA(t), I32(2)), null) - assertEvalsTo(invoke("contains", TBoolean, In(0, t), NA(TInt32)), + assertEvalsTo( + invoke("contains", TBoolean, In(0, t), NA(TInt32)), FastSeq((Set(-7, 2, null), t)), - true) - assertEvalsTo(invoke("contains", TBoolean, In(0, t), I32(2)), + true, + ) + assertEvalsTo( + invoke("contains", TBoolean, In(0, t), I32(2)), FastSeq((Set(-7, 2, null), t)), - true) - assertEvalsTo(invoke("contains", TBoolean, In(0, t), I32(0)), + true, + ) + assertEvalsTo( + invoke("contains", TBoolean, In(0, t), I32(0)), FastSeq((Set(-7, 2, null), t)), - false) - assertEvalsTo(invoke("contains", TBoolean, In(0, t), I32(7)), - FastSeq((Set(-7, 2), t)), - false) + false, + ) + assertEvalsTo(invoke("contains", TBoolean, In(0, t), I32(7)), FastSeq((Set(-7, 2), t)), false) } - @Test def testDictContains() { + @Test def testDictContains(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val t = TDict(TInt32, TString) assertEvalsTo(invoke("contains", TBoolean, NA(t), I32(2)), null) - val d = Map(1 -> "a", 2 -> null, (null, "c")) - assertEvalsTo(invoke("contains", TBoolean, In(0, t), NA(TInt32)), - FastSeq((d, t)), - true) - assertEvalsTo(invoke("contains", TBoolean, In(0, t), I32(2)), - FastSeq((d, t)), - true) - assertEvalsTo(invoke("contains", TBoolean, In(0, t), I32(0)), - FastSeq((d, t)), - false) - assertEvalsTo(invoke("contains", TBoolean, In(0, t), I32(3)), + val d: Map[Any, Any] = Map(1 -> "a", 2 -> null, (null, "c")) + assertEvalsTo(invoke("contains", TBoolean, In(0, t), NA(TInt32)), FastSeq((d, t)), true) + assertEvalsTo(invoke("contains", TBoolean, In(0, t), I32(2)), FastSeq((d, t)), true) + assertEvalsTo(invoke("contains", TBoolean, In(0, t), I32(0)), FastSeq((d, t)), false) + assertEvalsTo( + invoke("contains", TBoolean, In(0, t), I32(3)), 
FastSeq((Map(1 -> "a", 2 -> null), t)), - false) + false, + ) } - @Test def testLowerBoundOnOrderedCollectionArray() { + @Test def testLowerBoundOnOrderedCollectionArray(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val na = NA(TArray(TInt32)) @@ -1350,24 +1534,25 @@ class IRSuite extends HailSuite { val awoutna = MakeArray(FastSeq(I32(0), I32(2), I32(4)), TArray(TInt32)) val awna = MakeArray(FastSeq(I32(0), I32(2), I32(4), NA(TInt32)), TArray(TInt32)) - val awdups = MakeArray(FastSeq(I32(0), I32(0), I32(2), I32(4), I32(4), NA(TInt32)), TArray(TInt32)) + val awdups = + MakeArray(FastSeq(I32(0), I32(0), I32(2), I32(4), I32(4), NA(TInt32)), TArray(TInt32)) assertAllEvalTo( (LowerBoundOnOrderedCollection(awoutna, I32(-1), onKey = false), 0), - (LowerBoundOnOrderedCollection(awoutna, I32(0), onKey = false), 0), - (LowerBoundOnOrderedCollection(awoutna, I32(1), onKey = false), 1), - (LowerBoundOnOrderedCollection(awoutna, I32(2), onKey = false), 1), - (LowerBoundOnOrderedCollection(awoutna, I32(3), onKey = false), 2), - (LowerBoundOnOrderedCollection(awoutna, I32(4), onKey = false), 2), - (LowerBoundOnOrderedCollection(awoutna, I32(5), onKey = false), 3), - (LowerBoundOnOrderedCollection(awoutna, NA(TInt32), onKey = false), 3), - (LowerBoundOnOrderedCollection(awna, NA(TInt32), onKey = false), 3), - (LowerBoundOnOrderedCollection(awna, I32(5), onKey = false), 3), - (LowerBoundOnOrderedCollection(awdups, I32(0), onKey = false), 0), - (LowerBoundOnOrderedCollection(awdups, I32(4), onKey = false), 3) + (LowerBoundOnOrderedCollection(awoutna, I32(0), onKey = false), 0), + (LowerBoundOnOrderedCollection(awoutna, I32(1), onKey = false), 1), + (LowerBoundOnOrderedCollection(awoutna, I32(2), onKey = false), 1), + (LowerBoundOnOrderedCollection(awoutna, I32(3), onKey = false), 2), + (LowerBoundOnOrderedCollection(awoutna, I32(4), onKey = false), 2), + (LowerBoundOnOrderedCollection(awoutna, I32(5), onKey = false), 3), + (LowerBoundOnOrderedCollection(awoutna, NA(TInt32), onKey = false), 3), + (LowerBoundOnOrderedCollection(awna, NA(TInt32), onKey = false), 3), + (LowerBoundOnOrderedCollection(awna, I32(5), onKey = false), 3), + (LowerBoundOnOrderedCollection(awdups, I32(0), onKey = false), 0), + (LowerBoundOnOrderedCollection(awdups, I32(4), onKey = false), 3), ) } - @Test def testLowerBoundOnOrderedCollectionSet() { + @Test def testLowerBoundOnOrderedCollectionSet(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val na = NA(TSet(TInt32)) @@ -1383,12 +1568,13 @@ class IRSuite extends HailSuite { assertEvalsTo(LowerBoundOnOrderedCollection(swoutna, I32(5), onKey = false), 3) assertEvalsTo(LowerBoundOnOrderedCollection(swoutna, NA(TInt32), onKey = false), 3) - val swna = ToSet(MakeStream(FastSeq(I32(0), I32(2), I32(2), I32(4), NA(TInt32)), TStream(TInt32))) + val swna = + ToSet(MakeStream(FastSeq(I32(0), I32(2), I32(2), I32(4), NA(TInt32)), TStream(TInt32))) assertEvalsTo(LowerBoundOnOrderedCollection(swna, NA(TInt32), onKey = false), 3) assertEvalsTo(LowerBoundOnOrderedCollection(swna, I32(5), onKey = false), 3) } - @Test def testLowerBoundOnOrderedCollectionDict() { + @Test def testLowerBoundOnOrderedCollectionDict(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val na = NA(TDict(TInt32, TString)) @@ -1421,28 +1607,36 @@ class IRSuite extends HailSuite { assertEvalsTo(range2, 10) val range3 = StreamLen(StreamRange(-10, 5, 1)) assertEvalsTo(range3, 15) - val mappedRange = StreamLen(mapIR(StreamRange(2, 7, 1)) { ref => maxIR(ref, 3)}) + val mappedRange = 
StreamLen(mapIR(StreamRange(2, 7, 1))(ref => maxIR(ref, 3))) assertEvalsTo(mappedRange, 5) - val streamOfStreams = mapIR(rangeIR(5)) { elementRef => rangeIR(elementRef) } - assertEvalsTo(StreamLen(flatMapIR(streamOfStreams){ x => x}), 4 + 3 + 2 + 1) + val streamOfStreams = mapIR(rangeIR(5))(elementRef => rangeIR(elementRef)) + assertEvalsTo(StreamLen(flatMapIR(streamOfStreams)(x => x)), 4 + 3 + 2 + 1) - val filteredRange = StreamLen(StreamFilter(rangeIR(12), "x", irToPrimitiveIR(Ref("x", TInt32)) < 5)) + val filteredRange = + StreamLen(StreamFilter(rangeIR(12), "x", irToPrimitiveIR(Ref("x", TInt32)) < 5)) assertEvalsTo(filteredRange, 5) val lenOfLet = StreamLen(bindIR(I32(5))(ref => - StreamGrouped(mapIR(rangeIR(20))(range_element => - InsertFields(MakeStruct(IndexedSeq(("num", range_element + ref))), IndexedSeq(("y", 12)))), 3))) + StreamGrouped( + mapIR(rangeIR(20))(range_element => + InsertFields(MakeStruct(IndexedSeq(("num", range_element + ref))), IndexedSeq(("y", 12))) + ), + 3, + ) + )) assertEvalsTo(lenOfLet, 7) } - @Test def testStreamLenUnconsumedInnerStream(): Unit = { - assertEvalsTo(StreamLen( - mapIR(StreamGrouped(filterIR(rangeIR(10))(x => x.cne(I32(0))), 3))( group => ToArray(group)) - ), 3) - } + @Test def testStreamLenUnconsumedInnerStream(): Unit = + assertEvalsTo( + StreamLen( + mapIR(StreamGrouped(filterIR(rangeIR(10))(x => x.cne(I32(0))), 3))(group => ToArray(group)) + ), + 3, + ) - @Test def testStreamTake() { + @Test def testStreamTake(): Unit = { val naa = NA(TStream(TInt32)) val a = MakeStream(IndexedSeq(I32(3), NA(TInt32), I32(7)), TStream(TInt32)) @@ -1455,7 +1649,7 @@ class IRSuite extends HailSuite { assertEvalsTo(StreamLen(StreamTake(a, 2)), 2) } - @Test def testStreamDrop() { + @Test def testStreamDrop(): Unit = { val naa = NA(TStream(TInt32)) val a = MakeStream(IndexedSeq(I32(3), NA(TInt32), I32(7)), TStream(TInt32)) @@ -1474,14 +1668,20 @@ class IRSuite extends HailSuite { ToArray(StreamMap(stream, "inner", ToArray(Ref("inner", innerType)))) } - @Test def testStreamGrouped() { + @Test def testStreamGrouped(): Unit = { val naa = NA(TStream(TInt32)) val a = MakeStream(IndexedSeq(I32(3), NA(TInt32), I32(7)), TStream(TInt32)) assertEvalsTo(toNestedArray(StreamGrouped(naa, I32(2))), null) assertEvalsTo(toNestedArray(StreamGrouped(a, NA(TInt32))), null) - assertEvalsTo(toNestedArray(StreamGrouped(MakeStream(IndexedSeq(), TStream(TInt32)), I32(2))), FastSeq()) - assertEvalsTo(toNestedArray(StreamGrouped(a, I32(1))), FastSeq(FastSeq(3), FastSeq(null), FastSeq(7))) + assertEvalsTo( + toNestedArray(StreamGrouped(MakeStream(IndexedSeq(), TStream(TInt32)), I32(2))), + FastSeq(), + ) + assertEvalsTo( + toNestedArray(StreamGrouped(a, I32(1))), + FastSeq(FastSeq(3), FastSeq(null), FastSeq(7)), + ) assertEvalsTo(toNestedArray(StreamGrouped(a, I32(2))), FastSeq(FastSeq(3, null), FastSeq(7))) assertEvalsTo(toNestedArray(StreamGrouped(a, I32(5))), FastSeq(FastSeq(3, null, 7))) assertFatal(toNestedArray(StreamGrouped(a, I32(0))), "stream grouped: non-positive size") @@ -1498,15 +1698,18 @@ class IRSuite extends HailSuite { StreamMap(StreamGrouped(stream, fromEach), "inner", StreamTake(Ref("inner", innerType), take)) } - assertEvalsTo(toNestedArray(takeFromEach(r, I32(1), I32(3))), - FastSeq(FastSeq(0), FastSeq(3), FastSeq(6), FastSeq(9))) - assertEvalsTo(toNestedArray(takeFromEach(r, I32(2), I32(3))), - FastSeq(FastSeq(0, 1), FastSeq(3, 4), FastSeq(6, 7), FastSeq(9))) - assertEvalsTo(toNestedArray(takeFromEach(r, I32(0), I32(5))), - FastSeq(FastSeq(), FastSeq())) + 
assertEvalsTo( + toNestedArray(takeFromEach(r, I32(1), I32(3))), + FastSeq(FastSeq(0), FastSeq(3), FastSeq(6), FastSeq(9)), + ) + assertEvalsTo( + toNestedArray(takeFromEach(r, I32(2), I32(3))), + FastSeq(FastSeq(0, 1), FastSeq(3, 4), FastSeq(6, 7), FastSeq(9)), + ) + assertEvalsTo(toNestedArray(takeFromEach(r, I32(0), I32(5))), FastSeq(FastSeq(), FastSeq())) } - @Test def testStreamGroupByKey() { + @Test def testStreamGroupByKey(): Unit = { val structType = TStruct("a" -> TInt32, "b" -> TInt32) val naa = NA(TStream(structType)) val a = MakeStream( @@ -1518,17 +1721,23 @@ class IRSuite extends HailSuite { MakeStruct(IndexedSeq("a" -> I32(1), "b" -> I32(2))), MakeStruct(IndexedSeq("a" -> I32(1), "b" -> I32(4))), MakeStruct(IndexedSeq("a" -> I32(1), "b" -> I32(6))), - MakeStruct(IndexedSeq("a" -> I32(4), "b" -> NA(TInt32)))), - TStream(structType)) + MakeStruct(IndexedSeq("a" -> I32(4), "b" -> NA(TInt32))), + ), + TStream(structType), + ) def group(a: IR): IR = StreamGroupByKey(a, FastSeq("a"), false) assertEvalsTo(toNestedArray(group(naa)), null) - assertEvalsTo(toNestedArray(group(a)), - FastSeq(FastSeq(Row(3, 1), Row(3, 3)), - FastSeq(Row(null, -1)), - FastSeq(Row(null, -2)), - FastSeq(Row(1, 2), Row(1, 4), Row(1, 6)), - FastSeq(Row(4, null)))) + assertEvalsTo( + toNestedArray(group(a)), + FastSeq( + FastSeq(Row(3, 1), Row(3, 3)), + FastSeq(Row(null, -1)), + FastSeq(Row(null, -2)), + FastSeq(Row(1, 2), Row(1, 4), Row(1, 6)), + FastSeq(Row(4, null)), + ), + ) assertEvalsTo(toNestedArray(group(MakeStream(IndexedSeq(), TStream(structType)))), FastSeq()) // test when inner streams are unused @@ -1539,34 +1748,46 @@ class IRSuite extends HailSuite { StreamMap(group(stream), "inner", StreamTake(Ref("inner", innerType), take)) } - assertEvalsTo(toNestedArray(takeFromEach(a, I32(1))), - FastSeq(FastSeq(Row(3, 1)), - FastSeq(Row(null, -1)), - FastSeq(Row(null, -2)), - FastSeq(Row(1, 2)), - FastSeq(Row(4, null)))) - assertEvalsTo(toNestedArray(takeFromEach(a, I32(2))), - FastSeq(FastSeq(Row(3, 1), Row(3, 3)), - FastSeq(Row(null, -1)), - FastSeq(Row(null, -2)), - FastSeq(Row(1, 2), Row(1, 4)), - FastSeq(Row(4, null)))) - } - - @Test def testStreamMap() { + assertEvalsTo( + toNestedArray(takeFromEach(a, I32(1))), + FastSeq( + FastSeq(Row(3, 1)), + FastSeq(Row(null, -1)), + FastSeq(Row(null, -2)), + FastSeq(Row(1, 2)), + FastSeq(Row(4, null)), + ), + ) + assertEvalsTo( + toNestedArray(takeFromEach(a, I32(2))), + FastSeq( + FastSeq(Row(3, 1), Row(3, 3)), + FastSeq(Row(null, -1)), + FastSeq(Row(null, -2)), + FastSeq(Row(1, 2), Row(1, 4)), + FastSeq(Row(4, null)), + ), + ) + } + + @Test def testStreamMap(): Unit = { val naa = NA(TStream(TInt32)) val a = MakeStream(IndexedSeq(I32(3), NA(TInt32), I32(7)), TStream(TInt32)) assertEvalsTo(ToArray(StreamMap(naa, "a", I32(5))), null) - assertEvalsTo(ToArray(StreamMap(a, "a", ApplyBinaryPrimOp(Add(), Ref("a", TInt32), I32(1)))), FastSeq(4, null, 8)) + assertEvalsTo( + ToArray(StreamMap(a, "a", ApplyBinaryPrimOp(Add(), Ref("a", TInt32), I32(1)))), + FastSeq(4, null, 8), + ) - assertEvalsTo(ToArray(Let(FastSeq("a" -> I32(5)), - StreamMap(a, "a", Ref("a", TInt32)))), - FastSeq(3, null, 7)) + assertEvalsTo( + ToArray(Let(FastSeq("a" -> I32(5)), StreamMap(a, "a", Ref("a", TInt32)))), + FastSeq(3, null, 7), + ) } - @Test def testStreamFilter() { + @Test def testStreamFilter(): Unit = { val nsa = NA(TStream(TInt32)) val a = MakeStream(IndexedSeq(I32(3), NA(TInt32), I32(7)), TStream(TInt32)) @@ -1576,26 +1797,32 @@ class IRSuite extends HailSuite { 
assertEvalsTo(ToArray(StreamFilter(a, "x", False())), FastSeq()) assertEvalsTo(ToArray(StreamFilter(a, "x", True())), FastSeq(3, null, 7)) - assertEvalsTo(ToArray(StreamFilter(a, "x", - IsNA(Ref("x", TInt32)))), FastSeq(null)) - assertEvalsTo(ToArray(StreamFilter(a, "x", - ApplyUnaryPrimOp(Bang, IsNA(Ref("x", TInt32))))), FastSeq(3, 7)) + assertEvalsTo(ToArray(StreamFilter(a, "x", IsNA(Ref("x", TInt32)))), FastSeq(null)) + assertEvalsTo( + ToArray(StreamFilter(a, "x", ApplyUnaryPrimOp(Bang, IsNA(Ref("x", TInt32))))), + FastSeq(3, 7), + ) - assertEvalsTo(ToArray(StreamFilter(a, "x", - ApplyComparisonOp(LT(TInt32), Ref("x", TInt32), I32(6)))), FastSeq(3)) + assertEvalsTo( + ToArray(StreamFilter(a, "x", ApplyComparisonOp(LT(TInt32), Ref("x", TInt32), I32(6)))), + FastSeq(3), + ) } - @Test def testArrayFlatMap() { + @Test def testArrayFlatMap(): Unit = { val ta = TArray(TInt32) val ts = TStream(TInt32) val tsa = TStream(ta) val nsa = NA(tsa) val naas = MakeStream(FastSeq(NA(ta), NA(ta)), tsa) - val a = MakeStream(FastSeq( - MakeArray(FastSeq(I32(7), NA(TInt32)), ta), - NA(ta), - MakeArray(FastSeq(I32(2)), ta)), - tsa) + val a = MakeStream( + FastSeq( + MakeArray(FastSeq(I32(7), NA(TInt32)), ta), + NA(ta), + MakeArray(FastSeq(I32(2)), ta), + ), + tsa, + ) assertEvalsTo(ToArray(StreamFlatMap(nsa, "a", MakeStream(FastSeq(I32(5)), ts))), null) @@ -1603,67 +1830,129 @@ class IRSuite extends HailSuite { assertEvalsTo(ToArray(StreamFlatMap(a, "a", ToStream(Ref("a", ta)))), FastSeq(7, null, 2)) - assertEvalsTo(ToArray(StreamFlatMap(StreamRange(I32(0), I32(3), I32(1)), "i", ToStream(ArrayRef(ToArray(a), Ref("i", TInt32))))), FastSeq(7, null, 2)) + assertEvalsTo( + ToArray(StreamFlatMap( + StreamRange(I32(0), I32(3), I32(1)), + "i", + ToStream(ArrayRef(ToArray(a), Ref("i", TInt32))), + )), + FastSeq(7, null, 2), + ) - assertEvalsTo(ToArray(Let(FastSeq("a" -> I32(5)), StreamFlatMap(a, "a", ToStream(Ref("a", ta))))), FastSeq(7, null, 2)) + assertEvalsTo( + ToArray(Let(FastSeq("a" -> I32(5)), StreamFlatMap(a, "a", ToStream(Ref("a", ta))))), + FastSeq(7, null, 2), + ) - val b = MakeStream(FastSeq( - MakeArray(FastSeq(I32(7), I32(0)), ta), - NA(ta), - MakeArray(FastSeq(I32(2)), ta)), - tsa) - assertEvalsTo(ToArray(Let(FastSeq("a" -> I32(5)), StreamFlatMap(b, "b", ToStream(Ref("b", ta))))), FastSeq(7, 0, 2)) + val b = MakeStream( + FastSeq( + MakeArray(FastSeq(I32(7), I32(0)), ta), + NA(ta), + MakeArray(FastSeq(I32(2)), ta), + ), + tsa, + ) + assertEvalsTo( + ToArray(Let(FastSeq("a" -> I32(5)), StreamFlatMap(b, "b", ToStream(Ref("b", ta))))), + FastSeq(7, 0, 2), + ) val st = MakeStream(FastSeq(I32(1), I32(5), I32(2), NA(TInt32)), TStream(TInt32)) val expected = FastSeq(-1, 0, -1, 0, 1, 2, 3, 4, -1, 0, 1) - assertEvalsTo(ToArray(StreamFlatMap(st, "foo", StreamRange(I32(-1), Ref("foo", TInt32), I32(1)))), expected) + assertEvalsTo( + ToArray(StreamFlatMap(st, "foo", StreamRange(I32(-1), Ref("foo", TInt32), I32(1)))), + expected, + ) } - @Test def testStreamFold() { + @Test def testStreamFold(): Unit = { def fold(s: IR, zero: IR, f: (IR, IR) => IR): IR = StreamFold(s, zero, "_accum", "_elt", f(Ref("_accum", zero.typ), Ref("_elt", zero.typ))) assertEvalsTo(fold(StreamRange(1, 2, 1), NA(TBoolean), (accum, elt) => IsNA(accum)), true) assertEvalsTo(fold(TestUtils.IRStream(1, 2, 3), 0, (accum, elt) => accum + elt), 6) assertEvalsTo(fold(TestUtils.IRStream(1, 2, 3), NA(TInt32), (accum, elt) => accum + elt), null) - assertEvalsTo(fold(TestUtils.IRStream(1, null, 3), NA(TInt32), (accum, elt) => accum + elt), null) + 
assertEvalsTo( + fold(TestUtils.IRStream(1, null, 3), NA(TInt32), (accum, elt) => accum + elt), + null, + ) assertEvalsTo(fold(TestUtils.IRStream(1, null, 3), 0, (accum, elt) => accum + elt), null) - assertEvalsTo(fold(TestUtils.IRStream(1, null, 3), NA(TInt32), (accum, elt) => I32(5) + I32(5)), 10) + assertEvalsTo( + fold(TestUtils.IRStream(1, null, 3), NA(TInt32), (accum, elt) => I32(5) + I32(5)), + 10, + ) } - @Test def testArrayFold2() { + @Test def testArrayFold2(): Unit = { implicit val execStrats = ExecStrategy.compileOnly - val af = StreamFold2(ToStream(In(0, TArray(TInt32))), + val af = StreamFold2( + ToStream(In(0, TArray(TInt32))), FastSeq(("x", I32(0)), ("y", NA(TInt32))), "val", - FastSeq(Ref("val", TInt32) + Ref("x", TInt32), Coalesce(FastSeq(Ref("y", TInt32), Ref("val", TInt32)))), - MakeStruct(FastSeq(("x", Ref("x", TInt32)), ("y", Ref("y", TInt32)))) + FastSeq( + Ref("val", TInt32) + Ref("x", TInt32), + Coalesce(FastSeq(Ref("y", TInt32), Ref("val", TInt32))), + ), + MakeStruct(FastSeq(("x", Ref("x", TInt32)), ("y", Ref("y", TInt32)))), ) assertEvalsTo(af, FastSeq((FastSeq(1, 2, 3), TArray(TInt32))), Row(6, 1)) } - @Test def testArrayScan() { + @Test def testArrayScan(): Unit = { implicit val execStrats = ExecStrategy.javaOnly def scan(array: IR, zero: IR, f: (IR, IR) => IR): IR = - ToArray(StreamScan(array, zero, "_accum", "_elt", f(Ref("_accum", zero.typ), Ref("_elt", zero.typ)))) + ToArray(StreamScan( + array, + zero, + "_accum", + "_elt", + f(Ref("_accum", zero.typ), Ref("_elt", zero.typ)), + )) - assertEvalsTo(scan(StreamRange(1, 4, 1), NA(TBoolean), (accum, elt) => IsNA(accum)), FastSeq(null, true, false, false)) - assertEvalsTo(scan(TestUtils.IRStream(1, 2, 3), 0, (accum, elt) => accum + elt), FastSeq(0, 1, 3, 6)) - assertEvalsTo(scan(TestUtils.IRStream(1, 2, 3), NA(TInt32), (accum, elt) => accum + elt), FastSeq(null, null, null, null)) - assertEvalsTo(scan(TestUtils.IRStream(1, null, 3), NA(TInt32), (accum, elt) => accum + elt), FastSeq(null, null, null, null)) + assertEvalsTo( + scan(StreamRange(1, 4, 1), NA(TBoolean), (accum, elt) => IsNA(accum)), + FastSeq(null, true, false, false), + ) + assertEvalsTo( + scan(TestUtils.IRStream(1, 2, 3), 0, (accum, elt) => accum + elt), + FastSeq(0, 1, 3, 6), + ) + assertEvalsTo( + scan(TestUtils.IRStream(1, 2, 3), NA(TInt32), (accum, elt) => accum + elt), + FastSeq(null, null, null, null), + ) + assertEvalsTo( + scan(TestUtils.IRStream(1, null, 3), NA(TInt32), (accum, elt) => accum + elt), + FastSeq(null, null, null, null), + ) assertEvalsTo(scan(NA(TStream(TInt32)), 0, (accum, elt) => accum + elt), null) - assertEvalsTo(scan(MakeStream(IndexedSeq(), TStream(TInt32)), 99, (accum, elt) => accum + elt), FastSeq(99)) - assertEvalsTo(scan(StreamFlatMap(StreamRange(0, 5, 1), "z", MakeStream(IndexedSeq(), TStream(TInt32))), 99, (accum, elt) => accum + elt), FastSeq(99)) + assertEvalsTo( + scan(MakeStream(IndexedSeq(), TStream(TInt32)), 99, (accum, elt) => accum + elt), + FastSeq(99), + ) + assertEvalsTo( + scan( + StreamFlatMap(StreamRange(0, 5, 1), "z", MakeStream(IndexedSeq(), TStream(TInt32))), + 99, + (accum, elt) => accum + elt, + ), + FastSeq(99), + ) } - def makeNDArray(data: IndexedSeq[Double], shape: IndexedSeq[Long], rowMajor: IR): MakeNDArray = { - MakeNDArray(MakeArray(data.map(F64), TArray(TFloat64)), MakeTuple.ordered(shape.map(I64)), rowMajor, ErrorIDs.NO_ERROR) - } + def makeNDArray(data: IndexedSeq[Double], shape: IndexedSeq[Long], rowMajor: IR): MakeNDArray = + MakeNDArray( + MakeArray(data.map(F64), 
TArray(TFloat64)), + MakeTuple.ordered(shape.map(I64)), + rowMajor, + ErrorIDs.NO_ERROR, + ) - def makeNDArrayRef(nd: IR, indxs: IndexedSeq[Long]): NDArrayRef = NDArrayRef(nd, indxs.map(I64), -1) + def makeNDArrayRef(nd: IR, indxs: IndexedSeq[Long]): NDArrayRef = + NDArrayRef(nd, indxs.map(I64), -1) val scalarRowMajor = makeNDArray(FastSeq(3.0), FastSeq(), True()) val scalarColMajor = makeNDArray(FastSeq(3.0), FastSeq(), False()) @@ -1675,7 +1964,7 @@ class IRSuite extends HailSuite { val cubeRowMajor = makeNDArray((0 until 27).map(_.toDouble), FastSeq(3, 3, 3), True()) val cubeColMajor = makeNDArray((0 until 27).map(_.toDouble), FastSeq(3, 3, 3), False()) - @Test def testNDArrayShape() { + @Test def testNDArrayShape(): Unit = { implicit val execStrats = ExecStrategy.compileOnly assertEvalsTo(NDArrayShape(scalarRowMajor), Row()) @@ -1683,7 +1972,7 @@ class IRSuite extends HailSuite { assertEvalsTo(NDArrayShape(cubeRowMajor), Row(3L, 3L, 3L)) } - @Test def testNDArrayRef() { + @Test def testNDArrayRef(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly assertEvalsTo(makeNDArrayRef(scalarRowMajor, FastSeq()), 3.0) @@ -1711,7 +2000,7 @@ class IRSuite extends HailSuite { assertEvalsTo(centerColMajor, 13.0) } - @Test def testNDArrayReshape() { + @Test def testNDArrayReshape(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val v = NDArrayReshape(matrixRowMajor, MakeTuple.ordered(IndexedSeq(I64(4))), ErrorIDs.NO_ERROR) @@ -1723,29 +2012,51 @@ class IRSuite extends HailSuite { assertEvalsTo(makeNDArrayRef(mat2, FastSeq(0, 0)), 1.0) } - @Test def testNDArrayConcat() { + @Test def testNDArrayConcat(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly - def nds(ndData: (IndexedSeq[Int], Long, Long)*): IR = { - MakeArray(ndData.toArray.map { case (values, nRows, nCols) => - if (values == null) NA(TNDArray(TInt32, Nat(2))) else - MakeNDArray(Literal(TArray(TInt32), values), - Literal(TTuple(TInt64, TInt64), Row(nRows, nCols)), True(), ErrorIDs.NO_ERROR) - }, TArray(TNDArray(TInt32, Nat(2)))) - } + def nds(ndData: (IndexedSeq[Int], Long, Long)*): IR = + MakeArray( + ndData.toArray.map { case (values, nRows, nCols) => + if (values == null) NA(TNDArray(TInt32, Nat(2))) + else + MakeNDArray( + Literal(TArray(TInt32), values), + Literal(TTuple(TInt64, TInt64), Row(nRows, nCols)), + True(), + ErrorIDs.NO_ERROR, + ) + }, + TArray(TNDArray(TInt32, Nat(2))), + ) - val nd1 = (FastSeq( - 0, 1, 2, - 3, 4, 5), 2L, 3L) + val nd1 = ( + FastSeq( + 0, 1, 2, + 3, 4, 5), + 2L, + 3L, + ) - val rowwise = (FastSeq( - 6, 7, 8, - 9, 10, 11, - 12, 13, 14), 3L, 3L) + val rowwise = ( + FastSeq( + 6, 7, 8, + 9, 10, 11, + 12, 13, 14), + 3L, + 3L, + ) - val colwise = (FastSeq( - 15, 16, - 17, 18), 2L, 2L) + val colwise = ( + FastSeq( + 15, + 16, + 17, + 18, + ), + 2L, + 2L, + ) val emptyRowwise = (FastSeq(), 0L, 3L) val emptyColwise = (FastSeq(), 2L, 0L) @@ -1756,10 +2067,12 @@ class IRSuite extends HailSuite { FastSeq(3, 4, 5), FastSeq(6, 7, 8), FastSeq(9, 10, 11), - FastSeq(12, 13, 14)) + FastSeq(12, 13, 14), + ) val colwiseExpected = FastSeq( FastSeq(0, 1, 2, 15, 16), - FastSeq(3, 4, 5, 17, 18)) + FastSeq(3, 4, 5, 17, 18), + ) assertNDEvals(NDArrayConcat(nds(nd1, rowwise), 0), rowwiseExpected) assertNDEvals(NDArrayConcat(nds(nd1, rowwise, emptyRowwise), 0), rowwiseExpected) @@ -1778,25 +2091,33 @@ class IRSuite extends HailSuite { assertNDEvals(NDArrayConcat(NA(TArray(TNDArray(TInt32, Nat(2)))), 1), null) } - @Test def testNDArrayMap() { 
+ @Test def testNDArrayMap(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val data = 0 until 10 val shape = FastSeq(2L, 5L) - val nDim = 2 val positives = makeNDArray(data.map(_.toDouble), shape, True()) val negatives = NDArrayMap(positives, "e", ApplyUnaryPrimOp(Negate, Ref("e", TFloat64))) assertEvalsTo(makeNDArrayRef(positives, FastSeq(1L, 0L)), 5.0) assertEvalsTo(makeNDArrayRef(negatives, FastSeq(1L, 0L)), -5.0) - val trues = MakeNDArray(MakeArray(data.map(_ => True()), TArray(TBoolean)), MakeTuple.ordered(shape.map(I64)), True(), ErrorIDs.NO_ERROR) + val trues = MakeNDArray( + MakeArray(data.map(_ => True()), TArray(TBoolean)), + MakeTuple.ordered(shape.map(I64)), + True(), + ErrorIDs.NO_ERROR, + ) val falses = NDArrayMap(trues, "e", ApplyUnaryPrimOp(Bang, Ref("e", TBoolean))) assertEvalsTo(makeNDArrayRef(trues, FastSeq(1L, 0L)), true) assertEvalsTo(makeNDArrayRef(falses, FastSeq(1L, 0L)), false) - val bools = MakeNDArray(MakeArray(data.map(i => if (i % 2 == 0) True() else False()), TArray(TBoolean)), - MakeTuple.ordered(shape.map(I64)), False(), ErrorIDs.NO_ERROR) + val bools = MakeNDArray( + MakeArray(data.map(i => if (i % 2 == 0) True() else False()), TArray(TBoolean)), + MakeTuple.ordered(shape.map(I64)), + False(), + ErrorIDs.NO_ERROR, + ) val boolsToBinary = NDArrayMap(bools, "e", If(Ref("e", TBoolean), I64(1L), I64(0L))) val one = makeNDArrayRef(boolsToBinary, FastSeq(0L, 0L)) val zero = makeNDArrayRef(boolsToBinary, FastSeq(1L, 1L)) @@ -1804,22 +2125,38 @@ class IRSuite extends HailSuite { assertEvalsTo(zero, 0L) } - @Test def testNDArrayMap2() { + @Test def testNDArrayMap2(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val shape = MakeTuple.ordered(FastSeq(2L, 2L).map(I64)) - val numbers = MakeNDArray(MakeArray((0 until 4).map { i => F64(i.toDouble) }, TArray(TFloat64)), shape, True(), ErrorIDs.NO_ERROR) - val bools = MakeNDArray(MakeArray(IndexedSeq(True(), False(), False(), True()), TArray(TBoolean)), shape, True(), ErrorIDs.NO_ERROR) + val numbers = MakeNDArray( + MakeArray((0 until 4).map(i => F64(i.toDouble)), TArray(TFloat64)), + shape, + True(), + ErrorIDs.NO_ERROR, + ) + val bools = MakeNDArray( + MakeArray(IndexedSeq(True(), False(), False(), True()), TArray(TBoolean)), + shape, + True(), + ErrorIDs.NO_ERROR, + ) - val actual = NDArrayMap2(numbers, bools, "n", "b", - ApplyBinaryPrimOp(Add(), Ref("n", TFloat64), If(Ref("b", TBoolean), F64(10), F64(20))), ErrorIDs.NO_ERROR) + val actual = NDArrayMap2( + numbers, + bools, + "n", + "b", + ApplyBinaryPrimOp(Add(), Ref("n", TFloat64), If(Ref("b", TBoolean), F64(10), F64(20))), + ErrorIDs.NO_ERROR, + ) val ten = makeNDArrayRef(actual, FastSeq(0L, 0L)) val twentyTwo = makeNDArrayRef(actual, FastSeq(1L, 0L)) assertEvalsTo(ten, 10.0) assertEvalsTo(twentyTwo, 22.0) } - @Test def testNDArrayReindex() { + @Test def testNDArrayReindex(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val transpose = NDArrayReindex(matrixRowMajor, FastSeq(1, 0)) @@ -1842,14 +2179,17 @@ class IRSuite extends HailSuite { assertEvalsTo(makeNDArrayRef(partialTranspose, partialTranposeIdx), 3.0) } - @Test def testNDArrayBroadcasting() { + @Test def testNDArrayBroadcasting(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val scalarWithMatrix = NDArrayMap2( NDArrayReindex(scalarRowMajor, FastSeq(1, 0)), matrixRowMajor, - "s", "m", - ApplyBinaryPrimOp(Add(), Ref("s", TFloat64), Ref("m", TFloat64)), ErrorIDs.NO_ERROR) + "s", 
+ "m", + ApplyBinaryPrimOp(Add(), Ref("s", TFloat64), Ref("m", TFloat64)), + ErrorIDs.NO_ERROR, + ) val topLeft = makeNDArrayRef(scalarWithMatrix, FastSeq(0, 0)) assertEvalsTo(topLeft, 4.0) @@ -1857,16 +2197,25 @@ class IRSuite extends HailSuite { val vectorWithMatrix = NDArrayMap2( NDArrayReindex(vectorRowMajor, FastSeq(1, 0)), matrixRowMajor, - "v", "m", - ApplyBinaryPrimOp(Add(), Ref("v", TFloat64), Ref("m", TFloat64)), ErrorIDs.NO_ERROR) + "v", + "m", + ApplyBinaryPrimOp(Add(), Ref("v", TFloat64), Ref("m", TFloat64)), + ErrorIDs.NO_ERROR, + ) assertEvalsTo(makeNDArrayRef(vectorWithMatrix, FastSeq(0, 0)), 2.0) assertEvalsTo(makeNDArrayRef(vectorWithMatrix, FastSeq(0, 1)), 1.0) assertEvalsTo(makeNDArrayRef(vectorWithMatrix, FastSeq(1, 0)), 4.0) val colVector = makeNDArray(FastSeq(1.0, -1.0), FastSeq(2, 1), True()) - val colVectorWithMatrix = NDArrayMap2(colVector, matrixRowMajor, "v", "m", - ApplyBinaryPrimOp(Add(), Ref("v", TFloat64), Ref("m", TFloat64)), ErrorIDs.NO_ERROR) + val colVectorWithMatrix = NDArrayMap2( + colVector, + matrixRowMajor, + "v", + "m", + ApplyBinaryPrimOp(Add(), Ref("v", TFloat64), Ref("m", TFloat64)), + ErrorIDs.NO_ERROR, + ) assertEvalsTo(makeNDArrayRef(colVectorWithMatrix, FastSeq(0, 0)), 2.0) assertEvalsTo(makeNDArrayRef(colVectorWithMatrix, FastSeq(0, 1)), 3.0) @@ -1875,22 +2224,31 @@ class IRSuite extends HailSuite { val vectorWithEmpty = NDArrayMap2( NDArrayReindex(vectorRowMajor, FastSeq(1, 0)), makeNDArray(FastSeq(), FastSeq(0, 2), True()), - "v", "m", - ApplyBinaryPrimOp(Add(), Ref("v", TFloat64), Ref("m", TFloat64)), ErrorIDs.NO_ERROR) + "v", + "m", + ApplyBinaryPrimOp(Add(), Ref("v", TFloat64), Ref("m", TFloat64)), + ErrorIDs.NO_ERROR, + ) assertEvalsTo(NDArrayShape(vectorWithEmpty), Row(0L, 2L)) val colVectorWithEmpty = NDArrayMap2( colVector, makeNDArray(FastSeq(), FastSeq(2, 0), True()), - "v", "m", - ApplyBinaryPrimOp(Add(), Ref("v", TFloat64), Ref("m", TFloat64)), ErrorIDs.NO_ERROR) + "v", + "m", + ApplyBinaryPrimOp(Add(), Ref("v", TFloat64), Ref("m", TFloat64)), + ErrorIDs.NO_ERROR, + ) assertEvalsTo(NDArrayShape(colVectorWithEmpty), Row(2L, 0L)) } - @Test def testNDArrayAgg() { + @Test def testNDArrayAgg(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly - val empty = makeNDArrayRef(NDArrayAgg(makeNDArray(IndexedSeq(), IndexedSeq(0, 5), true), IndexedSeq(0, 1)), IndexedSeq()) + val empty = makeNDArrayRef( + NDArrayAgg(makeNDArray(IndexedSeq(), IndexedSeq(0, 5), true), IndexedSeq(0, 1)), + IndexedSeq(), + ) assertEvalsTo(empty, 0.0) val three = makeNDArrayRef(NDArrayAgg(scalarRowMajor, IndexedSeq.empty), IndexedSeq()) @@ -1908,14 +2266,17 @@ class IRSuite extends HailSuite { assertEvalsTo(twentySeven, 3.0) } - @Test def testNDArrayMatMul() { + @Test def testNDArrayMatMul(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val dotProduct = NDArrayMatMul(vectorRowMajor, vectorRowMajor, ErrorIDs.NO_ERROR) val zero = makeNDArrayRef(dotProduct, IndexedSeq()) assertEvalsTo(zero, 2.0) - val seven = makeNDArrayRef(NDArrayMatMul(matrixRowMajor, matrixRowMajor, ErrorIDs.NO_ERROR), IndexedSeq(0, 0)) + val seven = makeNDArrayRef( + NDArrayMatMul(matrixRowMajor, matrixRowMajor, ErrorIDs.NO_ERROR), + IndexedSeq(0, 0), + ) assertEvalsTo(seven, 7.0) val twoByThreeByFive = threeTensorRowMajor @@ -1925,11 +2286,15 @@ class IRSuite extends HailSuite { assertEvalsTo(thirty, 30.0) val threeByTwoByFive = NDArrayReindex(twoByThreeByFive, IndexedSeq(1, 0, 2)) - val matMulCube = 
NDArrayMatMul(NDArrayReindex(matrixRowMajor, IndexedSeq(2, 0, 1)), threeByTwoByFive, ErrorIDs.NO_ERROR) + val matMulCube = NDArrayMatMul( + NDArrayReindex(matrixRowMajor, IndexedSeq(2, 0, 1)), + threeByTwoByFive, + ErrorIDs.NO_ERROR, + ) assertEvalsTo(makeNDArrayRef(matMulCube, IndexedSeq(0, 0, 0)), 30.0) } - @Test def testNDArrayInv() { + @Test def testNDArrayInv(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly val matrixRowMajor = makeNDArray(FastSeq(1.5, 2.0, 4.0, 5.0), FastSeq(2, 2), True()) val inv = NDArrayInv(matrixRowMajor, ErrorIDs.NO_ERROR) @@ -1937,17 +2302,28 @@ class IRSuite extends HailSuite { assertNDEvals(inv, expectedInv) } - @Test def testNDArraySlice() { + @Test def testNDArraySlice(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly - val rightCol = NDArraySlice(matrixRowMajor, MakeTuple.ordered(IndexedSeq(MakeTuple.ordered(IndexedSeq(I64(0), I64(2), I64(1))), I64(1)))) + val rightCol = NDArraySlice( + matrixRowMajor, + MakeTuple.ordered(IndexedSeq(MakeTuple.ordered(IndexedSeq(I64(0), I64(2), I64(1))), I64(1))), + ) assertEvalsTo(NDArrayShape(rightCol), Row(2L)) assertEvalsTo(makeNDArrayRef(rightCol, FastSeq(0)), 2.0) assertEvalsTo(makeNDArrayRef(rightCol, FastSeq(1)), 4.0) - val topRow = NDArraySlice(matrixRowMajor, - MakeTuple.ordered(IndexedSeq(I64(0), - MakeTuple.ordered(IndexedSeq(I64(0), GetTupleElement(NDArrayShape(matrixRowMajor), 1), I64(1)))))) + val topRow = NDArraySlice( + matrixRowMajor, + MakeTuple.ordered(IndexedSeq( + I64(0), + MakeTuple.ordered(IndexedSeq( + I64(0), + GetTupleElement(NDArrayShape(matrixRowMajor), 1), + I64(1), + )), + )), + ) assertEvalsTo(makeNDArrayRef(topRow, FastSeq(0)), 1.0) assertEvalsTo(makeNDArrayRef(topRow, FastSeq(1)), 2.0) @@ -1955,68 +2331,112 @@ class IRSuite extends HailSuite { assertEvalsTo(makeNDArrayRef(scalarSlice, FastSeq()), 3.0) } - @Test def testNDArrayFilter() { + @Test def testNDArrayFilter(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.compileOnly assertNDEvals( NDArrayFilter(matrixRowMajor, FastSeq(NA(TArray(TInt64)), NA(TArray(TInt64)))), - FastSeq(FastSeq(1.0, 2.0), - FastSeq(3.0, 4.0))) + FastSeq(FastSeq(1.0, 2.0), FastSeq(3.0, 4.0)), + ) assertNDEvals( - NDArrayFilter(matrixRowMajor, FastSeq( - MakeArray(FastSeq(I64(0), I64(1)), TArray(TInt64)), - MakeArray(FastSeq(I64(0), I64(1)), TArray(TInt64)))), - FastSeq(FastSeq(1.0, 2.0), - FastSeq(3.0, 4.0))) + NDArrayFilter( + matrixRowMajor, + FastSeq( + MakeArray(FastSeq(I64(0), I64(1)), TArray(TInt64)), + MakeArray(FastSeq(I64(0), I64(1)), TArray(TInt64)), + ), + ), + FastSeq(FastSeq(1.0, 2.0), FastSeq(3.0, 4.0)), + ) assertNDEvals( - NDArrayFilter(matrixRowMajor, FastSeq( - MakeArray(FastSeq(I64(1), I64(0)), TArray(TInt64)), - MakeArray(FastSeq(I64(1), I64(0)), TArray(TInt64)))), - FastSeq(FastSeq(4.0, 3.0), - FastSeq(2.0, 1.0))) + NDArrayFilter( + matrixRowMajor, + FastSeq( + MakeArray(FastSeq(I64(1), I64(0)), TArray(TInt64)), + MakeArray(FastSeq(I64(1), I64(0)), TArray(TInt64)), + ), + ), + FastSeq(FastSeq(4.0, 3.0), FastSeq(2.0, 1.0)), + ) assertNDEvals( - NDArrayFilter(matrixRowMajor, FastSeq( - MakeArray(FastSeq(I64(0)), TArray(TInt64)), NA(TArray(TInt64)))), - FastSeq(FastSeq(1.0, 2.0))) + NDArrayFilter( + matrixRowMajor, + FastSeq( + MakeArray(FastSeq(I64(0)), TArray(TInt64)), + NA(TArray(TInt64)), + ), + ), + FastSeq(FastSeq(1.0, 2.0)), + ) assertNDEvals( - NDArrayFilter(matrixRowMajor, FastSeq( - NA(TArray(TInt64)), MakeArray(FastSeq(I64(0)), TArray(TInt64)))), - 
FastSeq(FastSeq(1.0), - FastSeq(3.0))) + NDArrayFilter( + matrixRowMajor, + FastSeq( + NA(TArray(TInt64)), + MakeArray(FastSeq(I64(0)), TArray(TInt64)), + ), + ), + FastSeq(FastSeq(1.0), FastSeq(3.0)), + ) assertNDEvals( - NDArrayFilter(matrixRowMajor, FastSeq( - MakeArray(FastSeq(I64(1)), TArray(TInt64)), - MakeArray(FastSeq(I64(1)), TArray(TInt64)))), - FastSeq(FastSeq(4.0))) + NDArrayFilter( + matrixRowMajor, + FastSeq( + MakeArray(FastSeq(I64(1)), TArray(TInt64)), + MakeArray(FastSeq(I64(1)), TArray(TInt64)), + ), + ), + FastSeq(FastSeq(4.0)), + ) } - private def join(left: IR, right: IR, lKeys: IndexedSeq[String], rKeys: IndexedSeq[String], rightDistinct: Boolean, joinType: String): IR = { + private def join( + left: IR, + right: IR, + lKeys: IndexedSeq[String], + rKeys: IndexedSeq[String], + rightDistinct: Boolean, + joinType: String, + ): IR = { val joinF = { (l: IR, r: IR) => def getL(field: String): IR = GetField(Ref("_left", l.typ), field) def getR(field: String): IR = GetField(Ref("_right", r.typ), field) - Let(FastSeq("_right" -> r, "_left" -> l), + Let( + FastSeq("_right" -> r, "_left" -> l), MakeStruct( - (lKeys, rKeys).zipped.map { (lk, rk) => lk -> Coalesce(IndexedSeq(getL(lk), getR(rk))) } + (lKeys, rKeys).zipped.map((lk, rk) => lk -> Coalesce(IndexedSeq(getL(lk), getR(rk)))) ++ tcoerce[TStruct](l.typ).fields.filter(f => !lKeys.contains(f.name)).map { f => - f.name -> GetField(Ref("_left", l.typ), f.name) - } ++ tcoerce[TStruct](r.typ).fields.filter(f => !rKeys.contains(f.name)).map { f => - f.name -> GetField(Ref("_right", r.typ), f.name) - } - ) + f.name -> GetField(Ref("_left", l.typ), f.name) + } ++ tcoerce[TStruct](r.typ).fields.filter(f => !rKeys.contains(f.name)).map { f => + f.name -> GetField(Ref("_right", r.typ), f.name) + } + ), ) } - ToArray(StreamJoin.apply(left, right, lKeys, rKeys, "_l", "_r", - joinF(Ref("_l", tcoerce[TStream](left.typ).elementType), Ref("_r", tcoerce[TStream](right.typ).elementType)), - joinType, requiresMemoryManagement = false, rightKeyIsDistinct = rightDistinct)) + ToArray(StreamJoin.apply( + left, + right, + lKeys, + rKeys, + "_l", + "_r", + joinF( + Ref("_l", tcoerce[TStream](left.typ).elementType), + Ref("_r", tcoerce[TStream](right.typ).elementType), + ), + joinType, + requiresMemoryManagement = false, + rightKeyIsDistinct = rightDistinct, + )) } - @Test def testStreamZipJoin() { + @Test def testStreamZipJoin(): Unit = { def eltType = TStruct("k1" -> TInt32, "k2" -> TString, "idx" -> TInt32) def makeStream(a: IndexedSeq[Integer]): IR = { if (a == null) @@ -2027,20 +2447,33 @@ class IRSuite extends HailSuite { MakeStruct(FastSeq( "k1" -> (if (n == null) NA(TInt32) else I32(n)), "k2" -> Str("x"), - "idx" -> I32(idx)))}, - TStream(eltType)) + "idx" -> I32(idx), + )) + }, + TStream(eltType), + ) } def zipJoin(as: IndexedSeq[IndexedSeq[Integer]], key: Int): IR = { val streams = as.map(makeStream) - val keyRef = Ref(genUID(), TStruct(FastSeq("k1", "k2").take(key).map(k => k -> eltType.fieldType(k)): _*)) + val keyRef = Ref( + genUID(), + TStruct(FastSeq("k1", "k2").take(key).map(k => k -> eltType.fieldType(k)): _*), + ) val valsRef = Ref(genUID(), TArray(eltType)) - ToArray(StreamZipJoin(streams, FastSeq("k1", "k2").take(key), keyRef.name, valsRef.name, InsertFields(keyRef, FastSeq("vals" -> valsRef)))) + ToArray(StreamZipJoin( + streams, + FastSeq("k1", "k2").take(key), + keyRef.name, + valsRef.name, + InsertFields(keyRef, FastSeq("vals" -> valsRef)), + )) } assertEvalsTo( zipJoin(FastSeq(Array[Integer](0, 1, null), null), 1), - null) + 
null, + ) assertEvalsTo( zipJoin(FastSeq(Array[Integer](0, 1, null), Array[Integer](1, 2, null)), 1), @@ -2049,34 +2482,43 @@ class IRSuite extends HailSuite { Row(1, FastSeq(Row(1, "x", 1), Row(1, "x", 0))), Row(2, FastSeq(null, Row(2, "x", 1))), Row(null, FastSeq(Row(null, "x", 2), null)), - Row(null, FastSeq(null, Row(null, "x", 2))))) + Row(null, FastSeq(null, Row(null, "x", 2))), + ), + ) assertEvalsTo( zipJoin(FastSeq(Array[Integer](0, 1), Array[Integer](1, 2), Array[Integer](0, 2)), 1), FastSeq( Row(0, FastSeq(Row(0, "x", 0), null, Row(0, "x", 0))), Row(1, FastSeq(Row(1, "x", 1), Row(1, "x", 0), null)), - Row(2, FastSeq(null, Row(2, "x", 1), Row(2, "x", 1))))) + Row(2, FastSeq(null, Row(2, "x", 1), Row(2, "x", 1))), + ), + ) assertEvalsTo( zipJoin(FastSeq(Array[Integer](0, 1), Array[Integer](), Array[Integer](0, 2)), 1), FastSeq( Row(0, FastSeq(Row(0, "x", 0), null, Row(0, "x", 0))), Row(1, FastSeq(Row(1, "x", 1), null, null)), - Row(2, FastSeq(null, null, Row(2, "x", 1))))) + Row(2, FastSeq(null, null, Row(2, "x", 1))), + ), + ) assertEvalsTo( zipJoin(FastSeq(Array[Integer](), Array[Integer]()), 1), - FastSeq()) + FastSeq(), + ) assertEvalsTo( zipJoin(FastSeq(Array[Integer](0, 1)), 1), FastSeq( Row(0, FastSeq(Row(0, "x", 0))), - Row(1, FastSeq(Row(1, "x", 1))))) + Row(1, FastSeq(Row(1, "x", 1))), + ), + ) } - @Test def testStreamMultiMerge() { + @Test def testStreamMultiMerge(): Unit = { def eltType = TStruct("k1" -> TInt32, "k2" -> TString, "idx" -> TInt32) def makeStream(a: IndexedSeq[Integer]): IR = { if (a == null) @@ -2087,8 +2529,11 @@ class IRSuite extends HailSuite { MakeStruct(FastSeq( "k1" -> (if (n == null) NA(TInt32) else I32(n)), "k2" -> Str("x"), - "idx" -> I32(idx)))}, - TStream(eltType)) + "idx" -> I32(idx), + )) + }, + TStream(eltType), + ) } def merge(as: IndexedSeq[IndexedSeq[Integer]], key: Int): IR = { @@ -2111,7 +2556,9 @@ class IRSuite extends HailSuite { Row(null, "x", 2), Row(null, "x", 3), Row(null, "x", 2), - Row(null, "x", 3))) + Row(null, "x", 3), + ), + ) assertEvalsTo( merge(FastSeq(Array[Integer](0, 1), Array[Integer](1, 2), Array[Integer](0, 2)), 1), @@ -2121,7 +2568,9 @@ class IRSuite extends HailSuite { Row(1, "x", 1), Row(1, "x", 0), Row(2, "x", 1), - Row(2, "x", 1))) + Row(2, "x", 1), + ), + ) assertEvalsTo( merge(FastSeq(Array[Integer](0, 1), Array[Integer](), Array[Integer](0, 2)), 1), @@ -2129,30 +2578,55 @@ class IRSuite extends HailSuite { Row(0, "x", 0), Row(0, "x", 0), Row(1, "x", 1), - Row(2, "x", 1))) + Row(2, "x", 1), + ), + ) assertEvalsTo( merge(FastSeq(Array[Integer](), Array[Integer]()), 1), - FastSeq()) + FastSeq(), + ) assertEvalsTo( merge(FastSeq(Array[Integer](0, 1)), 1), FastSeq( Row(0, "x", 0), - Row(1, "x", 1))) + Row(1, "x", 1), + ), + ) } - @Test def testJoinRightDistinct() { + @Test def testJoinRightDistinct(): Unit = { implicit val execStrats = ExecStrategy.javaOnly def joinRows(left: IndexedSeq[Integer], right: IndexedSeq[Integer], joinType: String): IR = { join( - MakeStream.unify(ctx, left.zipWithIndex.map { case (n, idx) => MakeStruct(FastSeq("lk1" -> (if (n == null) NA(TInt32) else I32(n)), "lk2" -> Str("x"), "a" -> I64(idx))) }), - MakeStream.unify(ctx, right.zipWithIndex.map { case (n, idx) => MakeStruct(FastSeq("b" -> I32(idx), "rk2" -> Str("x"), "rk1" -> (if (n == null) NA(TInt32) else I32(n)), "c" -> Str("foo"))) }), + MakeStream.unify( + ctx, + left.zipWithIndex.map { case (n, idx) => + MakeStruct(FastSeq( + "lk1" -> (if (n == null) NA(TInt32) else I32(n)), + "lk2" -> Str("x"), + "a" -> I64(idx), + )) + }, + 
), + MakeStream.unify( + ctx, + right.zipWithIndex.map { case (n, idx) => + MakeStruct(FastSeq( + "b" -> I32(idx), + "rk2" -> Str("x"), + "rk1" -> (if (n == null) NA(TInt32) else I32(n)), + "c" -> Str("foo"), + )) + }, + ), FastSeq("lk1", "lk2"), FastSeq("rk1", "rk2"), rightDistinct = true, - joinType) + joinType, + ) } def leftJoinRows(left: IndexedSeq[Integer], right: IndexedSeq[Integer]): IR = joinRows(left, right, "left") @@ -2162,61 +2636,107 @@ class IRSuite extends HailSuite { assertEvalsTo( join( NA(TStream(TStruct("k1" -> TInt32, "k2" -> TString, "a" -> TInt64))), - MakeStream.unify(ctx, IndexedSeq(MakeStruct(FastSeq("b" -> I32(0), "k2" -> Str("x"), "k1" -> I32(3), "c" -> Str("foo"))))), + MakeStream.unify( + ctx, + IndexedSeq(MakeStruct(FastSeq( + "b" -> I32(0), + "k2" -> Str("x"), + "k1" -> I32(3), + "c" -> Str("foo"), + ))), + ), FastSeq("k1", "k2"), FastSeq("k1", "k2"), true, - "left"), - null) + "left", + ), + null, + ) assertEvalsTo( join( - MakeStream.unify(ctx, IndexedSeq(MakeStruct(FastSeq("k1" -> I32(0), "k2" -> Str("x"), "a" -> I64(3))))), + MakeStream.unify( + ctx, + IndexedSeq(MakeStruct(FastSeq("k1" -> I32(0), "k2" -> Str("x"), "a" -> I64(3)))), + ), NA(TStream(TStruct("b" -> TInt32, "k2" -> TString, "k1" -> TInt32, "c" -> TString))), FastSeq("k1", "k2"), FastSeq("k1", "k2"), true, - "left"), - null) + "left", + ), + null, + ) - assertEvalsTo(leftJoinRows(Array[Integer](0, null), Array[Integer](1, null)), FastSeq( - Row(0, "x", 0L, null, null), - Row(null, "x", 1L, null, null))) + assertEvalsTo( + leftJoinRows(Array[Integer](0, null), Array[Integer](1, null)), + FastSeq( + Row(0, "x", 0L, null, null), + Row(null, "x", 1L, null, null), + ), + ) - assertEvalsTo(outerJoinRows(Array[Integer](0, null), Array[Integer](1, null)), FastSeq( - Row(0, "x", 0L, null, null), - Row(1, "x", null, 0, "foo"), - Row(null, "x", 1L, null, null), - Row(null, "x", null, 1, "foo"))) + assertEvalsTo( + outerJoinRows(Array[Integer](0, null), Array[Integer](1, null)), + FastSeq( + Row(0, "x", 0L, null, null), + Row(1, "x", null, 0, "foo"), + Row(null, "x", 1L, null, null), + Row(null, "x", null, 1, "foo"), + ), + ) - assertEvalsTo(leftJoinRows(Array[Integer](0, 1, 2), Array[Integer](1)), FastSeq( - Row(0, "x", 0L, null, null), - Row(1, "x", 1L, 0, "foo"), - Row(2, "x", 2L, null, null))) + assertEvalsTo( + leftJoinRows(Array[Integer](0, 1, 2), Array[Integer](1)), + FastSeq( + Row(0, "x", 0L, null, null), + Row(1, "x", 1L, 0, "foo"), + Row(2, "x", 2L, null, null), + ), + ) - assertEvalsTo(leftJoinRows(Array[Integer](0, 1, 2), Array[Integer](-1, 0, 0, 1, 1, 2, 2, 3)), FastSeq( - Row(0, "x", 0L, 1, "foo"), - Row(1, "x", 1L, 3, "foo"), - Row(2, "x", 2L, 5, "foo"))) + assertEvalsTo( + leftJoinRows(Array[Integer](0, 1, 2), Array[Integer](-1, 0, 0, 1, 1, 2, 2, 3)), + FastSeq( + Row(0, "x", 0L, 1, "foo"), + Row(1, "x", 1L, 3, "foo"), + Row(2, "x", 2L, 5, "foo"), + ), + ) - assertEvalsTo(leftJoinRows(Array[Integer](0, 1, 1, 2), Array[Integer](-1, 0, 0, 1, 1, 2, 2, 3)), FastSeq( - Row(0, "x", 0L, 1, "foo"), - Row(1, "x", 1L, 3, "foo"), - Row(1, "x", 2L, 3, "foo"), - Row(2, "x", 3L, 5, "foo"))) + assertEvalsTo( + leftJoinRows(Array[Integer](0, 1, 1, 2), Array[Integer](-1, 0, 0, 1, 1, 2, 2, 3)), + FastSeq( + Row(0, "x", 0L, 1, "foo"), + Row(1, "x", 1L, 3, "foo"), + Row(1, "x", 2L, 3, "foo"), + Row(2, "x", 3L, 5, "foo"), + ), + ) } - @Test def testStreamJoin() { + @Test def testStreamJoin(): Unit = { implicit val execStrats = ExecStrategy.javaOnly def joinRows(left: IndexedSeq[Integer], right: 
IndexedSeq[Integer], joinType: String): IR = { join( - MakeStream.unify(ctx, left.zipWithIndex.map { case (n, idx) => MakeStruct(FastSeq("lk" -> (if (n == null) NA(TInt32) else I32(n)), "l" -> I32(idx))) }), - MakeStream.unify(ctx, right.zipWithIndex.map { case (n, idx) => MakeStruct(FastSeq("rk" -> (if (n == null) NA(TInt32) else I32(n)), "r" -> I32(idx))) }), + MakeStream.unify( + ctx, + left.zipWithIndex.map { case (n, idx) => + MakeStruct(FastSeq("lk" -> (if (n == null) NA(TInt32) else I32(n)), "l" -> I32(idx))) + }, + ), + MakeStream.unify( + ctx, + right.zipWithIndex.map { case (n, idx) => + MakeStruct(FastSeq("rk" -> (if (n == null) NA(TInt32) else I32(n)), "r" -> I32(idx))) + }, + ), FastSeq("lk"), FastSeq("rk"), false, - joinType) + joinType, + ) } def leftJoinRows(left: IndexedSeq[Integer], right: IndexedSeq[Integer]): IR = joinRows(left, right, "left") @@ -2227,151 +2747,219 @@ class IRSuite extends HailSuite { def rightJoinRows(left: IndexedSeq[Integer], right: IndexedSeq[Integer]): IR = joinRows(left, right, "right") - assertEvalsTo(leftJoinRows(Array[Integer](1, 1, 2, 2, null, null), Array[Integer](0, 0, 1, 1, 3, 3, null, null)), FastSeq( - Row(1, 0, 2), - Row(1, 0, 3), - Row(1, 1, 2), - Row(1, 1, 3), - Row(2, 2, null), - Row(2, 3, null), - Row(null, 4, null), - Row(null, 5, null))) - - assertEvalsTo(outerJoinRows(Array[Integer](1, 1, 2, 2, null, null), Array[Integer](0, 0, 1, 1, 3, 3, null, null)), FastSeq( - Row(0, null, 0), - Row(0, null, 1), - Row(1, 0, 2), - Row(1, 0, 3), - Row(1, 1, 2), - Row(1, 1, 3), - Row(2, 2, null), - Row(2, 3, null), - Row(3, null, 4), - Row(3, null, 5), - Row(null, 4, null), - Row(null, 5, null), - Row(null, null, 6), - Row(null, null, 7))) - - assertEvalsTo(innerJoinRows(Array[Integer](1, 1, 2, 2, null, null), Array[Integer](0, 0, 1, 1, 3, 3, null, null)), FastSeq( - Row(1, 0, 2), - Row(1, 0, 3), - Row(1, 1, 2), - Row(1, 1, 3))) - - assertEvalsTo(rightJoinRows(Array[Integer](1, 1, 2, 2, null, null), Array[Integer](0, 0, 1, 1, 3, 3, null, null)), FastSeq( - Row(0, null, 0), - Row(0, null, 1), - Row(1, 0, 2), - Row(1, 0, 3), - Row(1, 1, 2), - Row(1, 1, 3), - Row(3, null, 4), - Row(3, null, 5), - Row(null, null, 6), - Row(null, null, 7))) - } - - @Test def testStreamMerge() { + assertEvalsTo( + leftJoinRows( + Array[Integer](1, 1, 2, 2, null, null), + Array[Integer](0, 0, 1, 1, 3, 3, null, null), + ), + FastSeq( + Row(1, 0, 2), + Row(1, 0, 3), + Row(1, 1, 2), + Row(1, 1, 3), + Row(2, 2, null), + Row(2, 3, null), + Row(null, 4, null), + Row(null, 5, null), + ), + ) + + assertEvalsTo( + outerJoinRows( + Array[Integer](1, 1, 2, 2, null, null), + Array[Integer](0, 0, 1, 1, 3, 3, null, null), + ), + FastSeq( + Row(0, null, 0), + Row(0, null, 1), + Row(1, 0, 2), + Row(1, 0, 3), + Row(1, 1, 2), + Row(1, 1, 3), + Row(2, 2, null), + Row(2, 3, null), + Row(3, null, 4), + Row(3, null, 5), + Row(null, 4, null), + Row(null, 5, null), + Row(null, null, 6), + Row(null, null, 7), + ), + ) + + assertEvalsTo( + innerJoinRows( + Array[Integer](1, 1, 2, 2, null, null), + Array[Integer](0, 0, 1, 1, 3, 3, null, null), + ), + FastSeq( + Row(1, 0, 2), + Row(1, 0, 3), + Row(1, 1, 2), + Row(1, 1, 3), + ), + ) + + assertEvalsTo( + rightJoinRows( + Array[Integer](1, 1, 2, 2, null, null), + Array[Integer](0, 0, 1, 1, 3, 3, null, null), + ), + FastSeq( + Row(0, null, 0), + Row(0, null, 1), + Row(1, 0, 2), + Row(1, 0, 3), + Row(1, 1, 2), + Row(1, 1, 3), + Row(3, null, 4), + Row(3, null, 5), + Row(null, null, 6), + Row(null, null, 7), + ), + ) + } + + @Test def testStreamMerge(): 
Unit = { implicit val execStrats = ExecStrategy.compileOnly def mergeRows(left: IndexedSeq[Integer], right: IndexedSeq[Integer], key: Int): IR = { val typ = TStream(TStruct("k" -> TInt32, "sign" -> TInt32, "idx" -> TInt32)) - ToArray(StreamMultiMerge(FastSeq( - if (left == null) - NA(typ) - else - MakeStream(left.zipWithIndex.map { case (n, idx) => - MakeStruct(FastSeq( - "k" -> (if (n == null) NA(TInt32) else I32(n)), - "sign" -> I32(1), - "idx" -> I32(idx))) - }, typ), - if (right == null) - NA(typ) - else - MakeStream(right.zipWithIndex.map { case (n, idx) => - MakeStruct(FastSeq( - "k" -> (if (n == null) NA(TInt32) else I32(n)), - "sign" -> I32(-1), - "idx" -> I32(idx))) - }, typ)), - FastSeq("k", "sign").take(key))) + ToArray(StreamMultiMerge( + FastSeq( + if (left == null) + NA(typ) + else + MakeStream( + left.zipWithIndex.map { case (n, idx) => + MakeStruct(FastSeq( + "k" -> (if (n == null) NA(TInt32) else I32(n)), + "sign" -> I32(1), + "idx" -> I32(idx), + )) + }, + typ, + ), + if (right == null) + NA(typ) + else + MakeStream( + right.zipWithIndex.map { case (n, idx) => + MakeStruct(FastSeq( + "k" -> (if (n == null) NA(TInt32) else I32(n)), + "sign" -> I32(-1), + "idx" -> I32(idx), + )) + }, + typ, + ), + ), + FastSeq("k", "sign").take(key), + )) } - assertEvalsTo(mergeRows(Array[Integer](1, 1, 2, 2, null, null), Array[Integer](0, 0, 1, 1, 3, 3, null, null), 1), FastSeq( - Row(0, -1, 0), - Row(0, -1, 1), - Row(1, 1, 0), - Row(1, 1, 1), - Row(1, -1, 2), - Row(1, -1, 3), - Row(2, 1, 2), - Row(2, 1, 3), - Row(3, -1, 4), - Row(3, -1, 5), - Row(null, 1, 4), - Row(null, 1, 5), - Row(null, -1, 6), - Row(null, -1, 7) - )) + assertEvalsTo( + mergeRows( + Array[Integer](1, 1, 2, 2, null, null), + Array[Integer](0, 0, 1, 1, 3, 3, null, null), + 1, + ), + FastSeq( + Row(0, -1, 0), + Row(0, -1, 1), + Row(1, 1, 0), + Row(1, 1, 1), + Row(1, -1, 2), + Row(1, -1, 3), + Row(2, 1, 2), + Row(2, 1, 3), + Row(3, -1, 4), + Row(3, -1, 5), + Row(null, 1, 4), + Row(null, 1, 5), + Row(null, -1, 6), + Row(null, -1, 7), + ), + ) // right stream ends first - assertEvalsTo(mergeRows(Array[Integer](1, 1, 2, 2), Array[Integer](0, 0, 1, 1), 1), FastSeq( - Row(0, -1, 0), - Row(0, -1, 1), - Row(1, 1, 0), - Row(1, 1, 1), - Row(1, -1, 2), - Row(1, -1, 3), - Row(2, 1, 2), - Row(2, 1, 3))) + assertEvalsTo( + mergeRows(Array[Integer](1, 1, 2, 2), Array[Integer](0, 0, 1, 1), 1), + FastSeq( + Row(0, -1, 0), + Row(0, -1, 1), + Row(1, 1, 0), + Row(1, 1, 1), + Row(1, -1, 2), + Row(1, -1, 3), + Row(2, 1, 2), + Row(2, 1, 3), + ), + ) // compare on two key fields - assertEvalsTo(mergeRows(Array[Integer](1, 1, 2, 2, null, null), Array[Integer](0, 0, 1, 1, 3, 3, null, null), 2), FastSeq( - Row(0, -1, 0), - Row(0, -1, 1), - Row(1, -1, 2), - Row(1, -1, 3), - Row(1, 1, 0), - Row(1, 1, 1), - Row(2, 1, 2), - Row(2, 1, 3), - Row(3, -1, 4), - Row(3, -1, 5), - Row(null, 1, 4), - Row(null, 1, 5), - Row(null, -1, 6), - Row(null, -1, 7))) + assertEvalsTo( + mergeRows( + Array[Integer](1, 1, 2, 2, null, null), + Array[Integer](0, 0, 1, 1, 3, 3, null, null), + 2, + ), + FastSeq( + Row(0, -1, 0), + Row(0, -1, 1), + Row(1, -1, 2), + Row(1, -1, 3), + Row(1, 1, 0), + Row(1, 1, 1), + Row(2, 1, 2), + Row(2, 1, 3), + Row(3, -1, 4), + Row(3, -1, 5), + Row(null, 1, 4), + Row(null, 1, 5), + Row(null, -1, 6), + Row(null, -1, 7), + ), + ) // right stream empty - assertEvalsTo(mergeRows(Array[Integer](1, 2, null), Array[Integer](), 1), FastSeq( - Row(1, 1, 0), - Row(2, 1, 1), - Row(null, 1, 2))) + assertEvalsTo( + mergeRows(Array[Integer](1, 2, null), 
Array[Integer](), 1), + FastSeq( + Row(1, 1, 0), + Row(2, 1, 1), + Row(null, 1, 2), + ), + ) // left stream empty - assertEvalsTo(mergeRows(Array[Integer](), Array[Integer](1, 2, null), 1), FastSeq( - Row(1, -1, 0), - Row(2, -1, 1), - Row(null, -1, 2))) + assertEvalsTo( + mergeRows(Array[Integer](), Array[Integer](1, 2, null), 1), + FastSeq( + Row(1, -1, 0), + Row(2, -1, 1), + Row(null, -1, 2), + ), + ) // one stream missing assertEvalsTo(mergeRows(null, Array[Integer](1, 2, null), 1), null) assertEvalsTo(mergeRows(Array[Integer](1, 2, null), null, 1), null) } - @Test def testDie() { + @Test def testDie(): Unit = { assertFatal(Die("mumblefoo", TFloat64), "mble") assertFatal(Die(NA(TString), TFloat64, -1), "message missing") } - @Test def testStreamRange() { - def assertEquals(start: Integer, stop: Integer, step: Integer, expected: IndexedSeq[Int]) { - assertEvalsTo(ToArray(StreamRange(In(0, TInt32), In(1, TInt32), In(2, TInt32))), + @Test def testStreamRange(): Unit = { + def assertEquals(start: Integer, stop: Integer, step: Integer, expected: IndexedSeq[Int]) + : Unit = + assertEvalsTo( + ToArray(StreamRange(In(0, TInt32), In(1, TInt32), In(2, TInt32))), args = FastSeq(start -> TInt32, stop -> TInt32, step -> TInt32), - expected = expected) - } + expected = expected, + ) assertEquals(0, 5, null, null) assertEquals(0, null, 1, null) assertEquals(null, 5, 1, null) @@ -2391,7 +2979,7 @@ class IRSuite extends HailSuite { assertEquals(Int.MinValue, Int.MaxValue, Int.MaxValue / 5, expected) } - @Test def testArrayAgg() { + @Test def testArrayAgg(): Unit = { implicit val execStrats = ExecStrategy.compileOnly val sumSig = AggSignature(Sum(), IndexedSeq(), IndexedSeq(TInt64)) @@ -2399,48 +2987,76 @@ class IRSuite extends HailSuite { StreamAgg( StreamMap(StreamRange(I32(0), I32(4), I32(1)), "x", Cast(Ref("x", TInt32), TInt64)), "x", - ApplyAggOp(FastSeq.empty, FastSeq(Ref("x", TInt64)), sumSig)), - 6L) + ApplyAggOp(FastSeq.empty, FastSeq(Ref("x", TInt64)), sumSig), + ), + 6L, + ) } - @Test def testArrayAggContexts() { + @Test def testArrayAggContexts(): Unit = { implicit val execStrats = ExecStrategy.compileOnly - val ir = Let(FastSeq("x" -> (In(0, TInt32) * In(0, TInt32))), // multiply to prevent forwarding + val ir = Let( + FastSeq("x" -> (In(0, TInt32) * In(0, TInt32))), // multiply to prevent forwarding StreamAgg( StreamRange(I32(0), I32(10), I32(1)), "elt", - AggLet("y", - Cast(Ref("x", TInt32) * Ref("x", TInt32) * Ref("elt", TInt32), TInt64), // different type to trigger validation errors - invoke("append", TArray(TArray(TInt32)), - ApplyAggOp(FastSeq(), FastSeq( - MakeArray(FastSeq( - Ref("x", TInt32), - Ref("elt", TInt32), - Cast(Ref("y", TInt64), TInt32), - Cast(Ref("y", TInt64), TInt32)), // reference y twice to prevent forwarding - TArray(TInt32))), - AggSignature(Collect(), FastSeq(), FastSeq(TArray(TInt32)))), - MakeArray(FastSeq(Ref("x", TInt32)), TArray(TInt32))), - isScan = false))) - - assertEvalsTo(ir, FastSeq(1 -> TInt32), - (0 until 10).map(i => FastSeq(1, i, i, i)) ++ FastSeq(FastSeq(1))) - } - - @Test def testStreamAggScan() { + AggLet( + "y", + Cast( + Ref("x", TInt32) * Ref("x", TInt32) * Ref("elt", TInt32), + TInt64, + ), // different type to trigger validation errors + invoke( + "append", + TArray(TArray(TInt32)), + ApplyAggOp( + FastSeq(), + FastSeq( + MakeArray( + FastSeq( + Ref("x", TInt32), + Ref("elt", TInt32), + Cast(Ref("y", TInt64), TInt32), + Cast(Ref("y", TInt64), TInt32), + ), // reference y twice to prevent forwarding + TArray(TInt32), + ) + ), + 
AggSignature(Collect(), FastSeq(), FastSeq(TArray(TInt32))), + ), + MakeArray(FastSeq(Ref("x", TInt32)), TArray(TInt32)), + ), + isScan = false, + ), + ), + ) + + assertEvalsTo( + ir, + FastSeq(1 -> TInt32), + (0 until 10).map(i => FastSeq(1, i, i, i)) ++ FastSeq(FastSeq(1)), + ) + } + + @Test def testStreamAggScan(): Unit = { implicit val execStrats = ExecStrategy.compileOnly val eltType = TStruct("x" -> TCall, "y" -> TInt32) - val ir = (StreamAggScan(ToStream(In(0, TArray(eltType))), + val ir = (StreamAggScan( + ToStream(In(0, TArray(eltType))), "foo", GetField(Ref("foo", eltType), "y") + - GetField(ApplyScanOp( - FastSeq(I32(2)), - FastSeq(GetField(Ref("foo", eltType), "x")), - AggSignature(CallStats(), FastSeq(TInt32), FastSeq(TCall)) - ), "AN"))) + GetField( + ApplyScanOp( + FastSeq(I32(2)), + FastSeq(GetField(Ref("foo", eltType), "x")), + AggSignature(CallStats(), FastSeq(TInt32), FastSeq(TCall)), + ), + "AN", + ), + )) val input = FastSeq( Row(null, 1), @@ -2448,16 +3064,19 @@ class IRSuite extends HailSuite { Row(Call2(0, 1), 3), Row(Call2(1, 1), 4), null, - Row(null, 5)) -> TArray(eltType) + Row(null, 5), + ) -> TArray(eltType) - assertEvalsTo(ToArray(ir), + assertEvalsTo( + ToArray(ir), args = FastSeq(input), - expected = FastSeq(1 + 0, 2 + 0, 3 + 2, 4 + 4, null, 5 + 6)) + expected = FastSeq(1 + 0, 2 + 0, 3 + 2, 4 + 4, null, 5 + 6), + ) - assertEvalsTo(StreamLen(ir), args=FastSeq(input), 6) + assertEvalsTo(StreamLen(ir), args = FastSeq(input), 6) } - @Test def testInsertFields() { + @Test def testInsertFields(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val s = TStruct("a" -> TInt64, "b" -> TString) @@ -2466,97 +3085,123 @@ class IRSuite extends HailSuite { assertEvalsTo( InsertFields( NA(s), - IndexedSeq()), - null) + IndexedSeq(), + ), + null, + ) assertEvalsTo( InsertFields( emptyStruct, - IndexedSeq("a" -> I64(5))), - Row(5L, null)) + IndexedSeq("a" -> I64(5)), + ), + Row(5L, null), + ) assertEvalsTo( InsertFields( emptyStruct, - IndexedSeq("c" -> F64(3.2))), - Row(null, null, 3.2)) + IndexedSeq("c" -> F64(3.2)), + ), + Row(null, null, 3.2), + ) assertEvalsTo( InsertFields( emptyStruct, - IndexedSeq("c" -> NA(TFloat64))), - Row(null, null, null)) + IndexedSeq("c" -> NA(TFloat64)), + ), + Row(null, null, null), + ) assertEvalsTo( InsertFields( MakeStruct(IndexedSeq("a" -> NA(TInt64), "b" -> Str("abc"))), - IndexedSeq()), - Row(null, "abc")) + IndexedSeq(), + ), + Row(null, "abc"), + ) assertEvalsTo( InsertFields( MakeStruct(IndexedSeq("a" -> NA(TInt64), "b" -> Str("abc"))), - IndexedSeq("a" -> I64(5))), - Row(5L, "abc")) + IndexedSeq("a" -> I64(5)), + ), + Row(5L, "abc"), + ) assertEvalsTo( InsertFields( MakeStruct(IndexedSeq("a" -> NA(TInt64), "b" -> Str("abc"))), - IndexedSeq("c" -> F64(3.2))), - Row(null, "abc", 3.2)) + IndexedSeq("c" -> F64(3.2)), + ), + Row(null, "abc", 3.2), + ) assertEvalsTo( InsertFields(NA(TStruct("a" -> TInt32)), IndexedSeq("foo" -> I32(5))), - null + null, ) assertEvalsTo( InsertFields( In(0, s), IndexedSeq("c" -> F64(3.2), "d" -> F64(5.5), "e" -> F64(6.6)), - Some(FastSeq("c", "d", "e", "a", "b"))), + Some(FastSeq("c", "d", "e", "a", "b")), + ), FastSeq(Row(null, "abc") -> s), - Row(3.2, 5.5, 6.6, null, "abc")) + Row(3.2, 5.5, 6.6, null, "abc"), + ) assertEvalsTo( InsertFields( In(0, s), IndexedSeq("c" -> F64(3.2), "d" -> F64(5.5), "e" -> F64(6.6)), - Some(FastSeq("a", "b", "c", "d", "e"))), + Some(FastSeq("a", "b", "c", "d", "e")), + ), FastSeq(Row(null, "abc") -> s), - Row(null, "abc", 3.2, 5.5, 6.6)) + Row(null, "abc", 3.2, 5.5, 
6.6), + ) assertEvalsTo( InsertFields( In(0, s), IndexedSeq("c" -> F64(3.2), "d" -> F64(5.5), "e" -> F64(6.6)), - Some(FastSeq("c", "a", "d", "b", "e"))), + Some(FastSeq("c", "a", "d", "b", "e")), + ), FastSeq(Row(null, "abc") -> s), - Row(3.2, null, 5.5, "abc", 6.6)) + Row(3.2, null, 5.5, "abc", 6.6), + ) } - @Test def testSelectFields() { + @Test def testSelectFields(): Unit = { assertEvalsTo( SelectFields( NA(TStruct("foo" -> TInt32, "bar" -> TFloat64)), - FastSeq("foo")), - null) + FastSeq("foo"), + ), + null, + ) assertEvalsTo( SelectFields( MakeStruct(FastSeq("foo" -> 6, "bar" -> 0.0)), - FastSeq("foo")), - Row(6)) + FastSeq("foo"), + ), + Row(6), + ) assertEvalsTo( SelectFields( MakeStruct(FastSeq("a" -> 6, "b" -> 0.0, "c" -> 3L)), - FastSeq("b", "a")), - Row(0.0, 6)) + FastSeq("b", "a"), + ), + Row(0.0, 6), + ) } - @Test def testGetField() { + @Test def testGetField(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val s = MakeStruct(IndexedSeq("a" -> NA(TInt64), "b" -> Str("abc"))) @@ -2567,43 +3212,59 @@ class IRSuite extends HailSuite { assertEvalsTo(GetField(na, "a"), null) } - @Test def testLiteral() { - implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized, ExecStrategy.JvmCompile) - val poopEmoji = new String(Array[Char](0xD83D, 0xDCA9)) + @Test def testLiteral(): Unit = { + implicit val execStrats = + Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized, ExecStrategy.JvmCompile) + val poopEmoji = new String(Array[Char](0xd83d, 0xdca9)) val types = Array( TTuple(TInt32, TString, TArray(TInt32)), TArray(TString), - TDict(TInt32, TString) + TDict(TInt32, TString), ) val values = Array( - Row(400, "foo"+poopEmoji, FastSeq(4, 6, 8)), + Row(400, "foo" + poopEmoji, FastSeq(4, 6, 8)), FastSeq(poopEmoji, "", "foo"), - Map[Int, String](1 -> "", 5 -> "foo", -4 -> poopEmoji) + Map[Int, String](1 -> "", 5 -> "foo", -4 -> poopEmoji), ) assertEvalsTo(Literal(types(0), values(0)), values(0)) - assertEvalsTo(MakeTuple.ordered(types.zip(values).map { case (t, v) => Literal(t, v) }), Row.fromSeq(values.toFastSeq)) - assertEvalsTo(Str("hello"+poopEmoji), "hello"+poopEmoji) + assertEvalsTo( + MakeTuple.ordered(types.zip(values).map { case (t, v) => Literal(t, v) }), + Row.fromSeq(values.toFastSeq), + ) + assertEvalsTo(Str("hello" + poopEmoji), "hello" + poopEmoji) } - @Test def testSameLiteralsWithDifferentTypes() { - assertEvalsTo(ApplyComparisonOp(EQ(TArray(TInt32)), - ToArray(StreamMap(ToStream(Literal(TArray(TFloat64), FastSeq(1.0, 2.0))), "elt", Cast(Ref("elt", TFloat64), TInt32))), - Literal(TArray(TInt32), FastSeq(1, 2))), true) + @Test def testSameLiteralsWithDifferentTypes(): Unit = { + assertEvalsTo( + ApplyComparisonOp( + EQ(TArray(TInt32)), + ToArray(StreamMap( + ToStream(Literal(TArray(TFloat64), FastSeq(1.0, 2.0))), + "elt", + Cast(Ref("elt", TFloat64), TInt32), + )), + Literal(TArray(TInt32), FastSeq(1, 2)), + ), + true, + ) } - @Test def testTableCount() { + @Test def testTableCount(): Unit = { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) assertEvalsTo(TableCount(TableRange(0, 4)), 0L) assertEvalsTo(TableCount(TableRange(7, 4)), 7L) } - @Test def testTableGetGlobals() { + @Test def testTableGetGlobals(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly - assertEvalsTo(TableGetGlobals(TableMapGlobals(TableRange(0, 1), Literal(TStruct("a" -> TInt32), Row(1)))), Row(1)) + assertEvalsTo( + TableGetGlobals(TableMapGlobals(TableRange(0, 1), Literal(TStruct("a" -> TInt32), 
Row(1)))), + Row(1), + ) } - @Test def testTableAggregate() { + @Test def testTableAggregate(): Unit = { implicit val execStrats = ExecStrategy.allRelational val table = TableRange(3, 2) @@ -2612,7 +3273,7 @@ class IRSuite extends HailSuite { assertEvalsTo(TableAggregate(table, MakeStruct(IndexedSeq("foo" -> count))), Row(3L)) } - @Test def testMatrixAggregate() { + @Test def testMatrixAggregate(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly val matrix = MatrixIR.range(5, 5, None) @@ -2621,15 +3282,31 @@ class IRSuite extends HailSuite { assertEvalsTo(MatrixAggregate(matrix, MakeStruct(IndexedSeq("foo" -> count))), Row(25L)) } - @Test def testGroupByKey() { - implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized, ExecStrategy.JvmCompile, ExecStrategy.JvmCompileUnoptimized) + @Test def testGroupByKey(): Unit = { + implicit val execStrats = Set( + ExecStrategy.Interpret, + ExecStrategy.InterpretUnoptimized, + ExecStrategy.JvmCompile, + ExecStrategy.JvmCompileUnoptimized, + ) def tuple(k: String, v: Int): IR = MakeTuple.ordered(IndexedSeq(Str(k), I32(v))) - def groupby(tuples: IR*): IR = GroupByKey(MakeStream(tuples.toArray[IR], TStream(TTuple(TString, TInt32)))) + def groupby(tuples: IR*): IR = + GroupByKey(MakeStream(tuples.toArray[IR], TStream(TTuple(TString, TInt32)))) - val collection1 = groupby(tuple("foo", 0), tuple("bar", 4), tuple("foo", -1), tuple("bar", 0), tuple("foo", 10), tuple("", 0)) - assertEvalsTo(collection1, Map("" -> FastSeq(0), "bar" -> FastSeq(4, 0), "foo" -> FastSeq(0, -1, 10))) + val collection1 = groupby( + tuple("foo", 0), + tuple("bar", 4), + tuple("foo", -1), + tuple("bar", 0), + tuple("foo", 10), + tuple("", 0), + ) + assertEvalsTo( + collection1, + Map("" -> FastSeq(0), "bar" -> FastSeq(4, 0), "foo" -> FastSeq(0, -1, 10)), + ) assertEvalsTo(groupby(), Map()) } @@ -2639,44 +3316,95 @@ class IRSuite extends HailSuite { Array(FastSeq(0.0, 0.0), TArray(TFloat64), TArray(TFloat64)), Array(Set(0, 1), TSet(TInt32), TSet(TInt32)), Array(Map(0L -> 5, 3L -> 20), TDict(TInt64, TInt32), TDict(TInt64, TInt32)), - Array(Interval(1, 2, includesStart = false, includesEnd = true), TInterval(TInt32), TInterval(TInt32)), - Array(Row("foo", 0.0), TStruct("a" -> TString, "b" -> TFloat64), TStruct("a" -> TString, "b" -> TFloat64)), + Array( + Interval(1, 2, includesStart = false, includesEnd = true), + TInterval(TInt32), + TInterval(TInt32), + ), + Array( + Row("foo", 0.0), + TStruct("a" -> TString, "b" -> TFloat64), + TStruct("a" -> TString, "b" -> TFloat64), + ), Array(Row("foo", 0.0), TTuple(TString, TFloat64), TTuple(TString, TFloat64)), - Array(Row(FastSeq("foo"), 0.0), TTuple(TArray(TString), TFloat64), TTuple(TArray(TString), TFloat64)) + Array( + Row(FastSeq("foo"), 0.0), + TTuple(TArray(TString), TFloat64), + TTuple(TArray(TString), TFloat64), + ), ) @Test(dataProvider = "compareDifferentTypes") - def testComparisonOpDifferentTypes(a: Any, t1: Type, t2: Type) { + def testComparisonOpDifferentTypes(a: Any, t1: Type, t2: Type): Unit = { implicit val execStrats = ExecStrategy.javaOnly - assertEvalsTo(ApplyComparisonOp(EQ(t1, t2), In(0, t1), In(1, t2)), FastSeq(a -> t1, a -> t2), true) - assertEvalsTo(ApplyComparisonOp(LT(t1, t2), In(0, t1), In(1, t2)), FastSeq(a -> t1, a -> t2), false) - assertEvalsTo(ApplyComparisonOp(GT(t1, t2), In(0, t1), In(1, t2)), FastSeq(a -> t1, a -> t2), false) - assertEvalsTo(ApplyComparisonOp(LTEQ(t1, t2), In(0, t1), In(1, t2)), FastSeq(a -> t1, a -> t2), true) - 
assertEvalsTo(ApplyComparisonOp(GTEQ(t1, t2), In(0, t1), In(1, t2)), FastSeq(a -> t1, a -> t2), true) - assertEvalsTo(ApplyComparisonOp(NEQ(t1, t2), In(0, t1), In(1, t2)), FastSeq(a -> t1, a -> t2), false) - assertEvalsTo(ApplyComparisonOp(EQWithNA(t1, t2), In(0, t1), In(1, t2)), FastSeq(a -> t1, a -> t2), true) - assertEvalsTo(ApplyComparisonOp(NEQWithNA(t1, t2), In(0, t1), In(1, t2)), FastSeq(a -> t1, a -> t2), false) - assertEvalsTo(ApplyComparisonOp(Compare(t1, t2), In(0, t1), In(1, t2)), FastSeq(a -> t1, a -> t2), 0) + assertEvalsTo( + ApplyComparisonOp(EQ(t1, t2), In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + true, + ) + assertEvalsTo( + ApplyComparisonOp(LT(t1, t2), In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + false, + ) + assertEvalsTo( + ApplyComparisonOp(GT(t1, t2), In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + false, + ) + assertEvalsTo( + ApplyComparisonOp(LTEQ(t1, t2), In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + true, + ) + assertEvalsTo( + ApplyComparisonOp(GTEQ(t1, t2), In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + true, + ) + assertEvalsTo( + ApplyComparisonOp(NEQ(t1, t2), In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + false, + ) + assertEvalsTo( + ApplyComparisonOp(EQWithNA(t1, t2), In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + true, + ) + assertEvalsTo( + ApplyComparisonOp(NEQWithNA(t1, t2), In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + false, + ) + assertEvalsTo( + ApplyComparisonOp(Compare(t1, t2), In(0, t1), In(1, t2)), + FastSeq(a -> t1, a -> t2), + 0, + ) } @DataProvider(name = "valueIRs") - def valueIRs(): Array[Array[Object]] = { - withExecuteContext() { ctx => - valueIRs(ctx) - } - } + def valueIRs(): Array[Array[Object]] = + withExecuteContext()(ctx => valueIRs(ctx)) def valueIRs(ctx: ExecuteContext): Array[Array[Object]] = { val fs = ctx.fs - CompileAndEvaluate(ctx, invoke("index_bgen", TInt64, - Array[Type](TLocus("GRCh37")), - Str("src/test/resources/example.8bits.bgen"), - Str("src/test/resources/example.8bits.bgen.idx2"), - Literal(TDict(TString, TString), Map("01" -> "1")), - False(), - I32(1000000))) + CompileAndEvaluate( + ctx, + invoke( + "index_bgen", + TInt64, + Array[Type](TLocus("GRCh37")), + Str("src/test/resources/example.8bits.bgen"), + Str("src/test/resources/example.8bits.bgen.idx2"), + Literal(TDict(TString, TString), Map("01" -> "1")), + False(), + I32(1000000), + ), + ) val b = True() val bin = Ref("bin", TBinary) @@ -2686,11 +3414,14 @@ class IRSuite extends HailSuite { val str = Str("Hail") val a = Ref("a", TArray(TInt32)) val st = Ref("st", TStream(TInt32)) - val whitenStream = Ref("whitenStream", TStream(TStruct("prevWindow" -> TNDArray(TFloat64, Nat(2)), "newChunk" -> TNDArray(TFloat64, Nat(2))))) - val mat = Ref("mat", TNDArray(TFloat64, Nat(2))) - val aa = Ref("aa", TArray(TArray(TInt32))) + val whitenStream = Ref( + "whitenStream", + TStream(TStruct( + "prevWindow" -> TNDArray(TFloat64, Nat(2)), + "newChunk" -> TNDArray(TFloat64, Nat(2)), + )), + ) val sta = Ref("sta", TStream(TArray(TInt32))) - val da = Ref("da", TArray(TTuple(TInt32, TString))) val std = Ref("std", TStream(TTuple(TInt32, TString))) val v = Ref("v", TInt32) val s = Ref("s", TStruct("x" -> TInt32, "y" -> TInt64, "z" -> TFloat64)) @@ -2721,15 +3452,27 @@ class IRSuite extends HailSuite { val mt = MatrixIR.range(20, 2, Some(3)) val vcf = is.hail.TestUtils.importVCF(ctx, "src/test/resources/sample.vcf") - val bgenReader = MatrixBGENReader(ctx, FastSeq("src/test/resources/example.8bits.bgen"), None, 
Map.empty[String, String], None, None, None) + val bgenReader = MatrixBGENReader( + ctx, + FastSeq("src/test/resources/example.8bits.bgen"), + None, + Map.empty[String, String], + None, + None, + None, + ) val bgen = MatrixRead(bgenReader.fullMatrixType, false, false, bgenReader) - val blockMatrix = BlockMatrixRead(BlockMatrixNativeReader(fs, "src/test/resources/blockmatrix_example/0")) + val blockMatrix = + BlockMatrixRead(BlockMatrixNativeReader(fs, "src/test/resources/blockmatrix_example/0")) val blockMatrixWriter = BlockMatrixNativeWriter("/path/to/file.bm", false, false, false) val blockMatrixMultiWriter = BlockMatrixBinaryMultiWriter("/path/to/prefix", false) - val nd = MakeNDArray(MakeArray(FastSeq(I32(-1), I32(1)), TArray(TInt32)), + val nd = MakeNDArray( + MakeArray(FastSeq(I32(-1), I32(1)), TArray(TInt32)), MakeTuple.ordered(FastSeq(I64(1), I64(2))), - True(), ErrorIDs.NO_ERROR) + True(), + ErrorIDs.NO_ERROR, + ) val rngState = RNGStateLiteral() def collect(ir: IR): IR = @@ -2741,11 +3484,19 @@ class IRSuite extends HailSuite { env => env.bindEval(refs.map(r => r.name -> r.typ): _*) val irs = Array[(IR, BindingEnv[Type] => BindingEnv[Type])]( - i, I64(5), F32(3.14f), F64(3.14), str, True(), False(), Void(), + i, + I64(5), + F32(3.14f), + F64(3.14), + str, + True(), + False(), + Void(), UUID4(), Cast(i, TFloat64), CastRename(NA(TStruct("a" -> TInt32)), TStruct("b" -> TInt32)), - NA(TInt32), IsNA(i), + NA(TInt32), + IsNA(i), If(b, i, j), Switch(i, j, 0 until 7 map I32), Coalesce(FastSeq(i, I32(1))), @@ -2767,8 +3518,13 @@ class IRSuite extends HailSuite { NDArrayAgg(nd, FastSeq(0)), NDArrayWrite(nd, Str("/path/to/ndarray")), NDArrayMatMul(nd, nd, ErrorIDs.NO_ERROR), - NDArraySlice(nd, MakeTuple.ordered(FastSeq(MakeTuple.ordered(FastSeq(I64(0), I64(2), I64(1))), - MakeTuple.ordered(FastSeq(I64(0), I64(2), I64(1)))))), + NDArraySlice( + nd, + MakeTuple.ordered(FastSeq( + MakeTuple.ordered(FastSeq(I64(0), I64(2), I64(1))), + MakeTuple.ordered(FastSeq(I64(0), I64(2), I64(1))), + )), + ), NDArrayFilter(nd, FastSeq(NA(TArray(TInt64)), NA(TArray(TInt64)))), ArrayRef(a, i) -> Array(a), ArrayLen(a) -> Array(a), @@ -2789,42 +3545,96 @@ class IRSuite extends HailSuite { StreamTakeWhile(st, "v", v < I32(5)) -> Array(st), StreamDropWhile(st, "v", v < I32(5)) -> Array(st), StreamMap(st, "v", v) -> Array(st), - StreamZip(FastSeq(st, st), FastSeq("foo", "bar"), True(), ArrayZipBehavior.TakeMinLength) -> Array(st), + StreamZip( + FastSeq(st, st), + FastSeq("foo", "bar"), + True(), + ArrayZipBehavior.TakeMinLength, + ) -> Array(st), StreamFilter(st, "v", b) -> Array(st), StreamFlatMap(sta, "a", ToStream(a)) -> Array(sta), StreamFold(st, I32(0), "x", "v", v) -> Array(st), StreamFold2(StreamFold(st, I32(0), "x", "v", v)) -> Array(st), StreamScan(st, I32(0), "x", "v", v) -> Array(st), - StreamWhiten(whitenStream, "newChunk", "prevWindow", 1, 1, 1, 1, false) -> Array(whitenStream), + StreamWhiten(whitenStream, "newChunk", "prevWindow", 1, 1, 1, 1, false) -> Array( + whitenStream + ), StreamJoinRightDistinct( StreamMap(StreamRange(0, 2, 1), "x", MakeStruct(FastSeq("x" -> Ref("x", TInt32)))), StreamMap(StreamRange(0, 3, 1), "x", MakeStruct(FastSeq("x" -> Ref("x", TInt32)))), - FastSeq("x"), FastSeq("x"), "l", "r", I32(1), "left"), + FastSeq("x"), + FastSeq("x"), + "l", + "r", + I32(1), + "left", + ), { + val left = + StreamMap(StreamRange(0, 2, 1), "x", MakeStruct(FastSeq("x" -> Ref("x", TInt32)))) + val right = ToStream(Literal( + TArray(TStruct("a" -> TInterval(TInt32))), + 
FastSeq(Row(Interval(IntervalEndpoint(0, -1), IntervalEndpoint(1, 1)))), + )) + val lref = Ref("lname", elementType(left.typ)) + val rref = Ref("rname", TArray(elementType(right.typ))) + StreamLeftIntervalJoin( + left, + right, + "x", + "a", + lref.name, + rref.name, + InsertFields(lref, FastSeq("join" -> rref)), + ) + }, StreamFor(st, "v", Void()) -> Array(st), - StreamAgg(st, "x", ApplyAggOp(FastSeq.empty, FastSeq(Cast(Ref("x", TInt32), TInt64)), sumSig)) -> Array(st), - StreamAggScan(st, "x", ApplyScanOp(FastSeq.empty, FastSeq(Cast(Ref("x", TInt32), TInt64)), sumSig)) -> Array(st), - RunAgg(Begin(FastSeq( - InitOp(0, FastSeq(Begin(FastSeq(InitOp(0, FastSeq(), pSumSig)))), groupSignature), - SeqOp(0, FastSeq(I32(1), SeqOp(0, FastSeq(I64(1)), pSumSig)), groupSignature))), - AggStateValue(0, groupSignature.state), FastSeq(groupSignature.state)), - RunAggScan(StreamRange(I32(0), I32(1), I32(1)), + StreamAgg( + st, + "x", + ApplyAggOp(FastSeq.empty, FastSeq(Cast(Ref("x", TInt32), TInt64)), sumSig), + ) -> Array(st), + StreamAggScan( + st, + "x", + ApplyScanOp(FastSeq.empty, FastSeq(Cast(Ref("x", TInt32), TInt64)), sumSig), + ) -> Array(st), + RunAgg( + Begin(FastSeq( + InitOp(0, FastSeq(Begin(FastSeq(InitOp(0, FastSeq(), pSumSig)))), groupSignature), + SeqOp(0, FastSeq(I32(1), SeqOp(0, FastSeq(I64(1)), pSumSig)), groupSignature), + )), + AggStateValue(0, groupSignature.state), + FastSeq(groupSignature.state), + ), + RunAggScan( + StreamRange(I32(0), I32(1), I32(1)), "foo", InitOp(0, FastSeq(Begin(FastSeq(InitOp(0, FastSeq(), pSumSig)))), groupSignature), SeqOp(0, FastSeq(Ref("foo", TInt32), SeqOp(0, FastSeq(I64(1)), pSumSig)), groupSignature), AggStateValue(0, groupSignature.state), - FastSeq(groupSignature.state)), + FastSeq(groupSignature.state), + ), AggFilter(True(), I32(0), false) -> (_.createAgg), AggExplode(NA(TStream(TInt32)), "x", I32(0), false) -> (_.createAgg), AggGroupBy(True(), I32(0), false) -> (_.createAgg), ApplyAggOp(FastSeq.empty, FastSeq(I32(0)), collectSig) -> (_.createAgg), - ApplyAggOp(FastSeq(I32(2)), FastSeq(call), callStatsSig) -> (_.createAgg.bindAgg(call.name, call.typ)), + ApplyAggOp(FastSeq(I32(2)), FastSeq(call), callStatsSig) -> (_.createAgg.bindAgg( + call.name, + call.typ, + )), ApplyAggOp(FastSeq(I32(10)), FastSeq(F64(-2.11), I32(4)), takeBySig) -> (_.createAgg), AggFold(I32(0), l + I32(1), l + r, l.name, r.name, false) -> (_.createAgg), InitOp(0, FastSeq(I32(2)), pCallStatsSig), SeqOp(0, FastSeq(i), pCollectSig), CombOp(0, 1, pCollectSig), ResultOp(0, pCollectSig), - ResultOp(0, PhysicalAggSig(Fold(), FoldStateSig(EmitType(SInt32, true), "accum", "other", Ref("accum", TInt32)))), + ResultOp( + 0, + PhysicalAggSig( + Fold(), + FoldStateSig(EmitType(SInt32, true), "accum", "other", Ref("accum", TInt32)), + ), + ), SerializeAggs(0, 0, BufferSpec.default, FastSeq(pCollectSig.state)), DeserializeAggs(0, 0, BufferSpec.default, FastSeq(pCollectSig.state)), CombOpValue(0, bin, pCollectSig) -> Array(bin), @@ -2856,49 +3666,95 @@ class IRSuite extends HailSuite { MatrixWrite(vcf, MatrixPLINKWriter("/path/to/base")), MatrixWrite(bgen, MatrixGENWriter("/path/to/base")), MatrixWrite(mt, MatrixBlockMatrixWriter("path/to/data/bm", true, "a", 4096)), - MatrixMultiWrite(Array(mt, mt), MatrixNativeMultiWriter(IndexedSeq("/path/to/mt1", "/path/to/mt2"))), - TableMultiWrite(Array(table, table), WrappedMatrixNativeMultiWriter(MatrixNativeMultiWriter(IndexedSeq("/path/to/mt1", "/path/to/mt2")), FastSeq("foo"))), + MatrixMultiWrite( + Array(mt, mt), + 
MatrixNativeMultiWriter(IndexedSeq("/path/to/mt1", "/path/to/mt2")), + ), + TableMultiWrite( + Array(table, table), + WrappedMatrixNativeMultiWriter( + MatrixNativeMultiWriter(IndexedSeq("/path/to/mt1", "/path/to/mt2")), + FastSeq("foo"), + ), + ), MatrixAggregate(mt, MakeStruct(IndexedSeq("foo" -> count))), BlockMatrixCollect(blockMatrix), BlockMatrixWrite(blockMatrix, blockMatrixWriter), BlockMatrixMultiWrite(IndexedSeq(blockMatrix, blockMatrix), blockMatrixMultiWriter), BlockMatrixWrite(blockMatrix, BlockMatrixPersistWriter("x", "MEMORY_ONLY")), - CollectDistributedArray(StreamRange(0, 3, 1), 1, "x", "y", Ref("x", TInt32), NA(TString), "test"), - ReadPartition(MakeStruct(Array("partitionIndex" -> I64(0), "partitionPath" -> Str("foo"))), + CollectDistributedArray( + StreamRange(0, 3, 1), + 1, + "x", + "y", + Ref("x", TInt32), + NA(TString), + "test", + ), + ReadPartition( + MakeStruct(Array("partitionIndex" -> I64(0), "partitionPath" -> Str("foo"))), TStruct("foo" -> TInt32), PartitionNativeReader( - TypedCodecSpec(PCanonicalStruct("foo" -> PInt32(), "bar" -> PCanonicalString()), BufferSpec.default), - "rowUID")), + TypedCodecSpec( + PCanonicalStruct("foo" -> PInt32(), "bar" -> PCanonicalString()), + BufferSpec.default, + ), + "rowUID", + ), + ), WritePartition( - MakeStream(FastSeq(), TStream(TStruct())), NA(TString), - PartitionNativeWriter(TypedCodecSpec(PType.canonical(TStruct()), BufferSpec.default), IndexedSeq(), "path", None, None)), + MakeStream(FastSeq(), TStream(TStruct())), + NA(TString), + PartitionNativeWriter( + TypedCodecSpec(PType.canonical(TStruct()), BufferSpec.default), + IndexedSeq(), + "path", + None, + None, + ), + ), WriteMetadata( Begin(FastSeq()), - RelationalWriter("path", overwrite = false, None)), - ReadValue(Str("foo"), ETypeValueReader(TypedCodecSpec(PCanonicalStruct("foo" -> PInt32(), "bar" -> PCanonicalString()), BufferSpec.default)), TStruct("foo" -> TInt32)), - WriteValue(I32(1), Str("foo"), ETypeValueWriter(TypedCodecSpec(PInt32(), BufferSpec.default))), - WriteValue(I32(1), Str("foo"), ETypeValueWriter(TypedCodecSpec(PInt32(), BufferSpec.default)), Some(Str("/tmp/uid/part"))), + RelationalWriter("path", overwrite = false, None), + ), + ReadValue( + Str("foo"), + ETypeValueReader(TypedCodecSpec( + PCanonicalStruct("foo" -> PInt32(), "bar" -> PCanonicalString()), + BufferSpec.default, + )), + TStruct("foo" -> TInt32), + ), + WriteValue( + I32(1), + Str("foo"), + ETypeValueWriter(TypedCodecSpec(PInt32(), BufferSpec.default)), + ), + WriteValue( + I32(1), + Str("foo"), + ETypeValueWriter(TypedCodecSpec(PInt32(), BufferSpec.default)), + Some(Str("/tmp/uid/part")), + ), LiftMeOut(I32(1)), RelationalLet("x", I32(0), I32(0)), - TailLoop("y", IndexedSeq("x" -> I32(0)), TInt32, Recur("y", FastSeq(I32(4)), TInt32)) - ) + TailLoop("y", IndexedSeq("x" -> I32(0)), TInt32, Recur("y", FastSeq(I32(4)), TInt32)), + ) val emptyEnv = BindingEnv.empty[Type] irs.map { case (ir, bind) => Array(ir, bind(emptyEnv)) } } @DataProvider(name = "tableIRs") - def tableIRs(): Array[Array[TableIR]] = { - withExecuteContext() { ctx => - tableIRs(ctx) - } - } + def tableIRs(): Array[Array[TableIR]] = + withExecuteContext()(ctx => tableIRs(ctx)) def tableIRs(ctx: ExecuteContext): Array[Array[TableIR]] = { try { val fs = ctx.fs val read = TableIR.read(fs, "src/test/resources/backward_compatability/1.1.0/table/0.ht") - val mtRead = MatrixIR.read(fs, "src/test/resources/backward_compatability/1.0.0/matrix_table/0.hmt") + val mtRead = + MatrixIR.read(fs, 
"src/test/resources/backward_compatability/1.0.0/matrix_table/0.hmt") val b = True() val xs: Array[TableIR] = Array( @@ -2907,13 +3763,14 @@ class IRSuite extends HailSuite { TableFilter(read, b), read, MatrixColsTable(mtRead), - TableAggregateByKey(read, + TableAggregateByKey( + read, MakeStruct(FastSeq( - "a" -> I32(5)))), - TableKeyByAndAggregate(read, - NA(TStruct.empty), NA(TStruct.empty), Some(1), 2), - TableJoin(read, - TableRange(100, 10), "inner", 1), + "a" -> I32(5) + )), + ), + TableKeyByAndAggregate(read, NA(TStruct.empty), NA(TStruct.empty), Some(1), 2), + TableJoin(read, TableRange(100, 10), "inner", 1), TableLeftJoinRightDistinct(read, TableRange(100, 10), "root"), TableMultiWayZipJoin(FastSeq(read, read), " * data * ", "globals"), MatrixEntriesTable(mtRead), @@ -2923,33 +3780,67 @@ class IRSuite extends HailSuite { TableTail(read, 10), TableParallelize( MakeStruct(FastSeq( - "rows" -> MakeArray(FastSeq( - MakeStruct(FastSeq("a" -> NA(TInt32))), - MakeStruct(FastSeq("a" -> I32(1))) - ), TArray(TStruct("a" -> TInt32))), - "global" -> MakeStruct(FastSeq()))), None), - TableMapRows(TableKeyBy(read, FastSeq()), + "rows" -> MakeArray( + FastSeq( + MakeStruct(FastSeq("a" -> NA(TInt32))), + MakeStruct(FastSeq("a" -> I32(1))), + ), + TArray(TStruct("a" -> TInt32)), + ), + "global" -> MakeStruct(FastSeq()), + )), + None, + ), + TableMapRows( + TableKeyBy(read, FastSeq()), MakeStruct(FastSeq( "a" -> GetField(Ref("row", read.typ.rowType), "f32"), - "b" -> F64(-2.11)))), - TableMapPartitions(TableKeyBy(read, FastSeq()), "g", "rs", StreamTake(Ref("rs", TStream(read.typ.rowType)), 1), 0, 0), - TableMapGlobals(read, + "b" -> F64(-2.11), + )), + ), + TableMapPartitions( + TableKeyBy(read, FastSeq()), + "g", + "rs", + StreamTake(Ref("rs", TStream(read.typ.rowType)), 1), + 0, + 0, + ), + TableMapGlobals( + read, MakeStruct(FastSeq( - "foo" -> NA(TArray(TInt32))))), + "foo" -> NA(TArray(TInt32)) + )), + ), TableRange(100, 10), TableUnion( - FastSeq(TableRange(100, 10), TableRange(50, 10))), + FastSeq(TableRange(100, 10), TableRange(50, 10)) + ), TableExplode(read, Array("mset")), - TableOrderBy(TableKeyBy(read, FastSeq()), FastSeq(SortField("m", Ascending), SortField("m", Descending))), + TableOrderBy( + TableKeyBy(read, FastSeq()), + FastSeq(SortField("m", Ascending), SortField("m", Descending)), + ), CastMatrixToTable(mtRead, " # entries", " # cols"), TableRename(read, Map("idx" -> "idx_foo"), Map("global_f32" -> "global_foo")), - TableFilterIntervals(read, FastSeq(Interval(IntervalEndpoint(Row(0), -1), IntervalEndpoint(Row(10), 1))), keep = false), - RelationalLetTable("x", I32(0), read), - { + TableFilterIntervals( + read, + FastSeq(Interval(IntervalEndpoint(Row(0), -1), IntervalEndpoint(Row(10), 1))), + keep = false, + ), + RelationalLetTable("x", I32(0), read), { val structs = MakeStream(FastSeq(), TStream(TStruct())) val partitioner = RVDPartitioner.empty(ctx.stateManager, TStruct()) - TableGen(structs, MakeStruct(FastSeq()), "cname", "gname", structs, partitioner, errorId = 180) - } + TableGen( + structs, + MakeStruct(FastSeq()), + "cname", + "gname", + structs, + partitioner, + errorId = 180, + ) + }, ) xs.map(x => Array(x)) } catch { @@ -2961,30 +3852,42 @@ class IRSuite extends HailSuite { } @DataProvider(name = "matrixIRs") - def matrixIRs(): Array[Array[MatrixIR]] = { - withExecuteContext() { ctx => - matrixIRs(ctx) - } - } + def matrixIRs(): Array[Array[MatrixIR]] = + withExecuteContext()(ctx => matrixIRs(ctx)) def matrixIRs(ctx: ExecuteContext): Array[Array[MatrixIR]] = { try { 
val fs = ctx.fs - CompileAndEvaluate(ctx, invoke("index_bgen", TInt64, - Array[Type](TLocus("GRCh37")), - Str("src/test/resources/example.8bits.bgen"), - Str("src/test/resources/example.8bits.bgen.idx2"), - Literal(TDict(TString, TString), Map("01" -> "1")), - False(), - I32(1000000))) + CompileAndEvaluate( + ctx, + invoke( + "index_bgen", + TInt64, + Array[Type](TLocus("GRCh37")), + Str("src/test/resources/example.8bits.bgen"), + Str("src/test/resources/example.8bits.bgen.idx2"), + Literal(TDict(TString, TString), Map("01" -> "1")), + False(), + I32(1000000), + ), + ) val tableRead = TableIR.read(fs, "src/test/resources/backward_compatability/1.1.0/table/0.ht") - val read = MatrixIR.read(fs, "src/test/resources/backward_compatability/1.0.0/matrix_table/0.hmt") + val read = + MatrixIR.read(fs, "src/test/resources/backward_compatability/1.0.0/matrix_table/0.hmt") val range = MatrixIR.range(3, 7, None) val vcf = is.hail.TestUtils.importVCF(ctx, "src/test/resources/sample.vcf") - val bgenReader = MatrixBGENReader(ctx, FastSeq("src/test/resources/example.8bits.bgen"), None, Map.empty[String, String], None, None, None) + val bgenReader = MatrixBGENReader( + ctx, + FastSeq("src/test/resources/example.8bits.bgen"), + None, + Map.empty[String, String], + None, + None, + None, + ) val bgen = MatrixRead(bgenReader.fullMatrixType, false, false, bgenReader) val range1 = MatrixIR.range(20, 2, Some(3)) @@ -2994,15 +3897,20 @@ class IRSuite extends HailSuite { val newCol = MakeStruct(FastSeq( "col_idx" -> GetField(Ref("sa", read.typ.colType), "col_idx"), - "new_f32" -> ApplyBinaryPrimOp(Add(), + "new_f32" -> ApplyBinaryPrimOp( + Add(), GetField(Ref("sa", read.typ.colType), "col_f32"), - F32(-5.2f)))) + F32(-5.2f), + ), + )) val newRow = MakeStruct(FastSeq( "row_idx" -> GetField(Ref("va", read.typ.rowType), "row_idx"), - "new_f32" -> ApplyBinaryPrimOp(Add(), + "new_f32" -> ApplyBinaryPrimOp( + Add(), GetField(Ref("va", read.typ.rowType), "row_f32"), - F32(-5.2f))) - ) + F32(-5.2f), + ), + )) val collectSig = AggSignature(Collect(), IndexedSeq(), IndexedSeq(TInt32)) val collect = ApplyAggOp(FastSeq.empty, FastSeq(I32(0)), collectSig) @@ -3021,10 +3929,16 @@ class IRSuite extends HailSuite { MatrixKeyRowsBy(read, FastSeq("row_m", "row_d"), false), MatrixMapRows(read, newRow), MatrixRepartition(read, 10, 0), - MatrixMapEntries(read, MakeStruct(FastSeq( - "global_f32" -> ApplyBinaryPrimOp(Add(), - GetField(Ref("global", read.typ.globalType), "global_f32"), - F32(-5.2f))))), + MatrixMapEntries( + read, + MakeStruct(FastSeq( + "global_f32" -> ApplyBinaryPrimOp( + Add(), + GetField(Ref("global", read.typ.globalType), "global_f32"), + F32(-5.2f), + ) + )), + ), MatrixCollectColsByKey(read), MatrixAggregateColsByKey(read, newEntryAnn, newColAnn), MatrixAggregateRowsByKey(read, newEntryAnn, newRowAnn), @@ -3043,12 +3957,24 @@ class IRSuite extends HailSuite { CastMatrixToTable(read, " # entries", " # cols"), " # entries", " # cols", - read.typ.colKey), + read.typ.colKey, + ), MatrixAnnotateColsTable(read, tableRead, "uid_123"), - MatrixAnnotateRowsTable(read, tableRead, "uid_123", product=false), - MatrixRename(read, Map("global_i64" -> "foo"), Map("col_i64" -> "bar"), Map("row_i64" -> "baz"), Map("entry_i64" -> "quam")), - MatrixFilterIntervals(read, FastSeq(Interval(IntervalEndpoint(Row(0), -1), IntervalEndpoint(Row(10), 1))), keep = false), - RelationalLetMatrixTable("x", I32(0), read)) + MatrixAnnotateRowsTable(read, tableRead, "uid_123", product = false), + MatrixRename( + read, + Map("global_i64" -> "foo"), + 
Map("col_i64" -> "bar"), + Map("row_i64" -> "baz"), + Map("entry_i64" -> "quam"), + ), + MatrixFilterIntervals( + read, + FastSeq(Interval(IntervalEndpoint(Row(0), -1), IntervalEndpoint(Row(10), 1))), + keep = false, + ), + RelationalLetMatrixTable("x", I32(0), read), + ) xs.map(x => Array(x)) } catch { @@ -3061,17 +3987,22 @@ class IRSuite extends HailSuite { @DataProvider(name = "blockMatrixIRs") def blockMatrixIRs(): Array[Array[BlockMatrixIR]] = { - val read = BlockMatrixRead(BlockMatrixNativeReader(fs, "src/test/resources/blockmatrix_example/0")) + val read = + BlockMatrixRead(BlockMatrixNativeReader(fs, "src/test/resources/blockmatrix_example/0")) val transpose = BlockMatrixBroadcast(read, FastSeq(1, 0), FastSeq(2, 2), 2) val dot = BlockMatrixDot(read, transpose) val slice = BlockMatrixSlice(read, FastSeq(FastSeq(0, 2, 1), FastSeq(0, 1, 1))) val sparsify1 = BlockMatrixSparsify(read, RectangleSparsifier(FastSeq(FastSeq(0L, 1L, 5L, 6L)))) val sparsify2 = BlockMatrixSparsify(read, BandSparsifier(true, -1L, 1L)) - val sparsify3 = BlockMatrixSparsify(read, RowIntervalSparsifier(true, FastSeq(0L, 1L, 5L, 6L), FastSeq(5L, 6L, 8L, 9L))) + val sparsify3 = BlockMatrixSparsify( + read, + RowIntervalSparsifier(true, FastSeq(0L, 1L, 5L, 6L), FastSeq(5L, 6L, 8L, 9L)), + ) val densify = BlockMatrixDensify(read) - val blockMatrixIRs = Array[BlockMatrixIR](read, + val blockMatrixIRs = Array[BlockMatrixIR]( + read, transpose, dot, sparsify1, @@ -3079,7 +4010,8 @@ class IRSuite extends HailSuite { sparsify3, densify, RelationalLetBlockMatrix("x", I32(0), read), - slice) + slice, + ) blockMatrixIRs.map(ir => Array(ir)) } @@ -3092,7 +4024,7 @@ class IRSuite extends HailSuite { } @Test(dataProvider = "valueIRs") - def testValueIRParser(x: IR, refMap: BindingEnv[Type]) { + def testValueIRParser(x: IR, refMap: BindingEnv[Type]): Unit = { val env = IRParserEnvironment(ctx) val s = Pretty.sexprStyle(x, elideLiterals = false) @@ -3103,30 +4035,31 @@ class IRSuite extends HailSuite { } @Test(dataProvider = "tableIRs") - def testTableIRParser(x: TableIR) { + def testTableIRParser(x: TableIR): Unit = { val s = Pretty.sexprStyle(x, elideLiterals = false) val x2 = IRParser.parse_table_ir(ctx, s) assert(x2 == x) } @Test(dataProvider = "matrixIRs") - def testMatrixIRParser(x: MatrixIR) { + def testMatrixIRParser(x: MatrixIR): Unit = { val s = Pretty.sexprStyle(x, elideLiterals = false) val x2 = IRParser.parse_matrix_ir(ctx, s) assert(x2 == x) } @Test(dataProvider = "blockMatrixIRs") - def testBlockMatrixIRParser(x: BlockMatrixIR) { + def testBlockMatrixIRParser(x: BlockMatrixIR): Unit = { val s = Pretty.sexprStyle(x, elideLiterals = false) val x2 = IRParser.parse_blockmatrix_ir(ctx, s) assert(x2 == x) } - def testBlockMatrixIRParserPersist() { + def testBlockMatrixIRParserPersist(): Unit = { val bm = BlockMatrix.fill(1, 1, 0.0, 5) backend.persist(ctx.backendContext, "x", bm, "MEMORY_ONLY") - val persist = BlockMatrixRead(BlockMatrixPersistReader("x", BlockMatrixType.fromBlockMatrix(bm))) + val persist = + BlockMatrixRead(BlockMatrixPersistReader("x", BlockMatrixType.fromBlockMatrix(bm))) val s = Pretty.sexprStyle(persist, elideLiterals = false) val x2 = IRParser.parse_blockmatrix_ir(ctx, s) @@ -3134,7 +4067,7 @@ class IRSuite extends HailSuite { backend.unpersist(ctx.backendContext, "x") } - @Test def testCachedIR() { + @Test def testCachedIR(): Unit = { val cached = Literal(TSet(TInt32), Set(1)) val s = s"(JavaIR 1)" val x2 = ExecuteContext.scoped() { ctx => @@ -3143,7 +4076,7 @@ class IRSuite extends HailSuite { 
assert(x2 eq cached) } - @Test def testCachedTableIR() { + @Test def testCachedTableIR(): Unit = { val cached = TableRange(1, 1) val s = s"(JavaTable 1)" val x2 = ExecuteContext.scoped() { ctx => @@ -3152,32 +4085,54 @@ class IRSuite extends HailSuite { assert(x2 eq cached) } - @Test def testArrayContinuationDealsWithIfCorrectly() { + @Test def testArrayContinuationDealsWithIfCorrectly(): Unit = { val ir = ToArray(StreamMap( - If(IsNA(In(0, TBoolean)), - NA(TStream(TInt32)), - ToStream(In(1, TArray(TInt32)))), - "x", Cast(Ref("x", TInt32), TInt64))) + If(IsNA(In(0, TBoolean)), NA(TStream(TInt32)), ToStream(In(1, TArray(TInt32)))), + "x", + Cast(Ref("x", TInt32), TInt64), + )) assertEvalsTo(ir, FastSeq(true -> TBoolean, FastSeq(0) -> TArray(TInt32)), FastSeq(0L)) } - @Test def testTableGetGlobalsSimplifyRules() { + @Test def testTableGetGlobalsSimplifyRules(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly - val t1 = TableType(TStruct("a" -> TInt32), FastSeq("a"), TStruct("g1" -> TInt32, "g2" -> TFloat64)) - val t2 = TableType(TStruct("a" -> TInt32), FastSeq("a"), TStruct("g3" -> TInt32, "g4" -> TFloat64)) - val tab1 = TableLiteral(TableValue(ctx, t1, BroadcastRow(ctx, Row(1, 1.1), t1.globalType), RVD.empty(ctx, t1.canonicalRVDType)), theHailClassLoader) - val tab2 = TableLiteral(TableValue(ctx, t2, BroadcastRow(ctx, Row(2, 2.2), t2.globalType), RVD.empty(ctx, t2.canonicalRVDType)), theHailClassLoader) + val t1 = + TableType(TStruct("a" -> TInt32), FastSeq("a"), TStruct("g1" -> TInt32, "g2" -> TFloat64)) + val t2 = + TableType(TStruct("a" -> TInt32), FastSeq("a"), TStruct("g3" -> TInt32, "g4" -> TFloat64)) + val tab1 = TableLiteral( + TableValue( + ctx, + t1, + BroadcastRow(ctx, Row(1, 1.1), t1.globalType), + RVD.empty(ctx, t1.canonicalRVDType), + ), + theHailClassLoader, + ) + val tab2 = TableLiteral( + TableValue( + ctx, + t2, + BroadcastRow(ctx, Row(2, 2.2), t2.globalType), + RVD.empty(ctx, t2.canonicalRVDType), + ), + theHailClassLoader, + ) assertEvalsTo(TableGetGlobals(TableJoin(tab1, tab2, "left")), Row(1, 1.1, 2, 2.2)) - assertEvalsTo(TableGetGlobals(TableMapGlobals(tab1, InsertFields(Ref("global", t1.globalType), IndexedSeq("g1" -> I32(3))))), Row(3, 1.1)) + assertEvalsTo( + TableGetGlobals(TableMapGlobals( + tab1, + InsertFields(Ref("global", t1.globalType), IndexedSeq("g1" -> I32(3))), + )), + Row(3, 1.1), + ) assertEvalsTo(TableGetGlobals(TableRename(tab1, Map.empty, Map("g2" -> "g3"))), Row(1, 1.1)) } - - - @Test def testAggLet() { + @Test def testAggLet(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly val ir = TableRange(2, 2) .aggregate( @@ -3193,53 +4148,87 @@ class IRSuite extends HailSuite { assertEvalsTo(ir, 61L) } - @Test def testRelationalLet() { + @Test def testRelationalLet(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly val ir = RelationalLet("x", NA(TInt32), RelationalRef("x", TInt32)) assertEvalsTo(ir, null) } - - @Test def testRelationalLetTable() { + @Test def testRelationalLetTable(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly val t = TArray(TStruct("x" -> TInt32)) - val ir = TableAggregate(RelationalLetTable("x", - Literal(t, FastSeq(Row(1))), - TableParallelize(MakeStruct(FastSeq("rows" -> RelationalRef("x", t), "global" -> MakeStruct(FastSeq()))))), - ApplyAggOp(FastSeq(), FastSeq(), AggSignature(Count(), FastSeq(), FastSeq()))) + val ir = TableAggregate( + RelationalLetTable( + "x", + Literal(t, FastSeq(Row(1))), + TableParallelize(MakeStruct(FastSeq( + "rows" -> RelationalRef("x", t), + 
"global" -> MakeStruct(FastSeq()), + ))), + ), + ApplyAggOp(FastSeq(), FastSeq(), AggSignature(Count(), FastSeq(), FastSeq())), + ) assertEvalsTo(ir, 1L) } - @Test def testRelationalLetMatrixTable() { + @Test def testRelationalLetMatrixTable(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly val t = TArray(TStruct("x" -> TInt32)) val m = CastTableToMatrix( TableMapGlobals( TableMapRows( - TableRange(1, 1), InsertFields(Ref("row", TStruct("idx" -> TInt32)), FastSeq("entries" -> RelationalRef("x", t)))), - MakeStruct(FastSeq("cols" -> MakeArray(FastSeq(MakeStruct(FastSeq("s" -> I32(0)))), TArray(TStruct("s" -> TInt32)))))), + TableRange(1, 1), + InsertFields( + Ref("row", TStruct("idx" -> TInt32)), + FastSeq("entries" -> RelationalRef("x", t)), + ), + ), + MakeStruct(FastSeq("cols" -> MakeArray( + FastSeq(MakeStruct(FastSeq("s" -> I32(0)))), + TArray(TStruct("s" -> TInt32)), + ))), + ), "entries", "cols", - FastSeq()) - val ir = MatrixAggregate(RelationalLetMatrixTable("x", - Literal(t, FastSeq(Row(1))), - m), - ApplyAggOp(FastSeq(), FastSeq(), AggSignature(Count(), FastSeq(), FastSeq()))) + FastSeq(), + ) + val ir = MatrixAggregate( + RelationalLetMatrixTable("x", Literal(t, FastSeq(Row(1))), m), + ApplyAggOp(FastSeq(), FastSeq(), AggSignature(Count(), FastSeq(), FastSeq())), + ) assertEvalsTo(ir, 1L) } - @DataProvider(name = "relationalFunctions") def relationalFunctionsData(): Array[Array[Any]] = Array( Array(TableFilterPartitions(Array(1, 2, 3), keep = true)), Array(VEP(fs, "src/test/resources/dummy_vep_config.json", false, 1, true)), - Array(WrappedMatrixToTableFunction(LinearRegressionRowsSingle(Array("foo"), "bar", Array("baz"), 1, Array("a", "b")), "foo", "bar", FastSeq("ck"))), + Array(WrappedMatrixToTableFunction( + LinearRegressionRowsSingle(Array("foo"), "bar", Array("baz"), 1, Array("a", "b")), + "foo", + "bar", + FastSeq("ck"), + )), Array(LinearRegressionRowsSingle(Array("foo"), "bar", Array("baz"), 1, Array("a", "b"))), - Array(LinearRegressionRowsChained(FastSeq(FastSeq("foo")), "bar", Array("baz"), 1, Array("a", "b"))), - Array(LogisticRegression("firth", Array("a", "b"), "c", Array("d", "e"), Array("f", "g"), 25, 1e-6)), + Array(LinearRegressionRowsChained( + FastSeq(FastSeq("foo")), + "bar", + Array("baz"), + 1, + Array("a", "b"), + )), + Array(LogisticRegression( + "firth", + Array("a", "b"), + "c", + Array("d", "e"), + Array("f", "g"), + 25, + 1e-6, + )), Array(PoissonRegression("firth", "a", "c", Array("d", "e"), Array("f", "g"), 25, 1e-6)), Array(Skat("a", "b", "c", "d", Array("e", "f"), false, 1, 0.1, 100, 0, 0.0)), Array(LocalLDPrune("x", 0.95, 123, 456)), @@ -3252,34 +4241,46 @@ class IRSuite extends HailSuite { Array(NPartitionsMatrixTable()), Array(WrappedMatrixToValueFunction(NPartitionsMatrixTable(), "foo", "bar", FastSeq("a", "c"))), Array(MatrixExportEntriesByCol(1, "asd", false, true, false)), - Array(GetElement(FastSeq(1, 2))) + Array(GetElement(FastSeq(1, 2))), ) - @Test def relationalFunctionsRun(): Unit = { + @Test def relationalFunctionsRun(): Unit = relationalFunctionsData() - } @Test(dataProvider = "relationalFunctions") def testRelationalFunctionsSerialize(x: Any): Unit = { implicit val formats = RelationalFunctions.formats x match { - case x: MatrixToMatrixFunction => assert(RelationalFunctions.lookupMatrixToMatrix(ctx, Serialization.write(x)) == x) - case x: MatrixToTableFunction => assert(RelationalFunctions.lookupMatrixToTable(ctx, Serialization.write(x)) == x) - case x: MatrixToValueFunction => 
assert(RelationalFunctions.lookupMatrixToValue(ctx, Serialization.write(x)) == x) - case x: TableToTableFunction => assert(RelationalFunctions.lookupTableToTable(ctx, JsonMethods.compact(x.toJValue)) == x) - case x: TableToValueFunction => assert(RelationalFunctions.lookupTableToValue(ctx, Serialization.write(x)) == x) - case x: BlockMatrixToTableFunction => assert(RelationalFunctions.lookupBlockMatrixToTable(ctx, Serialization.write(x)) == x) - case x: BlockMatrixToValueFunction => assert(RelationalFunctions.lookupBlockMatrixToValue(ctx, Serialization.write(x)) == x) + case x: MatrixToMatrixFunction => + assert(RelationalFunctions.lookupMatrixToMatrix(ctx, Serialization.write(x)) == x) + case x: MatrixToTableFunction => + assert(RelationalFunctions.lookupMatrixToTable(ctx, Serialization.write(x)) == x) + case x: MatrixToValueFunction => + assert(RelationalFunctions.lookupMatrixToValue(ctx, Serialization.write(x)) == x) + case x: TableToTableFunction => + assert(RelationalFunctions.lookupTableToTable(ctx, JsonMethods.compact(x.toJValue)) == x) + case x: TableToValueFunction => + assert(RelationalFunctions.lookupTableToValue(ctx, Serialization.write(x)) == x) + case x: BlockMatrixToTableFunction => + assert(RelationalFunctions.lookupBlockMatrixToTable(ctx, Serialization.write(x)) == x) + case x: BlockMatrixToValueFunction => + assert(RelationalFunctions.lookupBlockMatrixToValue(ctx, Serialization.write(x)) == x) } } - @Test def testFoldWithSetup() { + @Test def testFoldWithSetup(): Unit = { val v = In(0, TInt32) - val cond1 = If(v.ceq(I32(3)), + val cond1 = If( + v.ceq(I32(3)), MakeStream(FastSeq(I32(1), I32(2), I32(3)), TStream(TInt32)), - MakeStream(FastSeq(I32(4), I32(5), I32(6)), TStream(TInt32))) - assertEvalsTo(StreamFold(cond1, True(), "accum", "i", Ref("i", TInt32).ceq(v)), FastSeq(0 -> TInt32), false) + MakeStream(FastSeq(I32(4), I32(5), I32(6)), TStream(TInt32)), + ) + assertEvalsTo( + StreamFold(cond1, True(), "accum", "i", Ref("i", TInt32).ceq(v)), + FastSeq(0 -> TInt32), + false, + ) } @Test def testNonCanonicalTypeParsing(): Unit = { @@ -3292,14 +4293,16 @@ class IRSuite extends HailSuite { def regressionTestUnifyBug(): Unit = { // failed due to misuse of Type.unify - val ir = IRParser.parse_value_ir(ctx, + val ir = IRParser.parse_value_ir( + ctx, """ |(ToArray (StreamMap __uid_3 | (ToStream (Literal Array[Interval[Locus(GRCh37)]] "[{\"start\": {\"contig\": \"20\", \"position\": 10277621}, \"end\": {\"contig\": \"20\", \"position\": 11898992}, \"includeStart\": true, \"includeEnd\": false}]")) | (Apply Interval Interval[Struct{locus:Locus(GRCh37)}] | (MakeStruct (locus (Apply start Locus(GRCh37) (Ref __uid_3)))) | (MakeStruct (locus (Apply end Locus(GRCh37) (Ref __uid_3)))) (True) (False)))) - |""".stripMargin) + |""".stripMargin, + ) val v = ExecutionTimer.logTime("IRSuite.regressionTestUnifyBug") { timer => backend.execute(timer, ir, optimize = true) } @@ -3307,68 +4310,105 @@ class IRSuite extends HailSuite { ir.typ.ordering(ctx.stateManager).equiv( FastSeq( Interval( - Row(Locus("20", 10277621)), Row(Locus("20", 11898992)), includesStart = true, includesEnd = false)), - v)) + Row(Locus("20", 10277621)), + Row(Locus("20", 11898992)), + includesStart = true, + includesEnd = false, + ) + ), + v, + ) + ) } @Test def testSimpleTailLoop(): Unit = { implicit val execStrats = ExecStrategy.compileOnly - val triangleSum: IR = TailLoop("f", + val triangleSum: IR = TailLoop( + "f", FastSeq("x" -> In(0, TInt32), "accum" -> In(1, TInt32)), TInt32, - If(Ref("x", TInt32) <= I32(0), + If( 
+ Ref("x", TInt32) <= I32(0), Ref("accum", TInt32), - Recur("f", + Recur( + "f", FastSeq( Ref("x", TInt32) - I32(1), - Ref("accum", TInt32) + Ref("x", TInt32)), - TInt32))) + Ref("accum", TInt32) + Ref("x", TInt32), + ), + TInt32, + ), + ), + ) assertEvalsTo(triangleSum, FastSeq(5 -> TInt32, 0 -> TInt32), 15) assertEvalsTo(triangleSum, FastSeq(5 -> TInt32, (null, TInt32)), null) - assertEvalsTo(triangleSum, FastSeq((null, TInt32), 0 -> TInt32), null) + assertEvalsTo(triangleSum, FastSeq((null, TInt32), 0 -> TInt32), null) } @Test def testNestedTailLoop(): Unit = { implicit val execStrats = ExecStrategy.compileOnly - val triangleSum: IR = TailLoop("f1", + val triangleSum: IR = TailLoop( + "f1", FastSeq("x" -> In(0, TInt32), "accum" -> I32(0)), TInt32, - If(Ref("x", TInt32) <= I32(0), - TailLoop("f2", + If( + Ref("x", TInt32) <= I32(0), + TailLoop( + "f2", FastSeq("x2" -> Ref("accum", TInt32), "accum2" -> I32(0)), TInt32, - If(Ref("x2", TInt32) <= I32(0), + If( + Ref("x2", TInt32) <= I32(0), Ref("accum2", TInt32), - Recur("f2", + Recur( + "f2", FastSeq( Ref("x2", TInt32) - I32(5), - Ref("accum2", TInt32) + Ref("x2", TInt32)), - TInt32))), - Recur("f1", + Ref("accum2", TInt32) + Ref("x2", TInt32), + ), + TInt32, + ), + ), + ), + Recur( + "f1", FastSeq( Ref("x", TInt32) - I32(1), - Ref("accum", TInt32) + Ref("x", TInt32)), - TInt32))) + Ref("accum", TInt32) + Ref("x", TInt32), + ), + TInt32, + ), + ), + ) assertEvalsTo(triangleSum, FastSeq(5 -> TInt32), 15 + 10 + 5) } @Test def testTailLoopNDMemory(): Unit = { - implicit val execStrats = ExecStrategy.compileOnly - val ndType = TNDArray(TInt32, Nat(2)) - val ndSum: IR = TailLoop("f", + val ndSum: IR = TailLoop( + "f", FastSeq("x" -> In(0, TInt32), "accum" -> In(1, ndType)), ndType, - If(Ref("x", TInt32) <= I32(0), + If( + Ref("x", TInt32) <= I32(0), Ref("accum", ndType), - Recur("f", + Recur( + "f", FastSeq( Ref("x", TInt32) - I32(1), - NDArrayMap(Ref("accum", ndType), "ndElement", Ref("ndElement", ndType.elementType) + Ref("x", TInt32))), - ndType))) + NDArrayMap( + Ref("accum", ndType), + "ndElement", + Ref("ndElement", ndType.elementType) + Ref("x", TInt32), + ), + ), + ndType, + ), + ), + ) val startingArg = SafeNDArray(IndexedSeq[Long](4L, 4L), (0 until 16).toFastSeq) @@ -3405,13 +4445,18 @@ class IRSuite extends HailSuite { testFreeVarsHelper(liftIR) val sumSig = AggSignature(Sum(), IndexedSeq(), IndexedSeq(TInt64)) - val streamAggIR = StreamAgg( + val streamAggIR = StreamAgg( StreamMap(StreamRange(I32(0), I32(4), I32(1)), "x", Cast(Ref("x", TInt32), TInt64)), "x", - ApplyAggOp(FastSeq.empty, FastSeq(Ref("x", TInt64)), sumSig)) + ApplyAggOp(FastSeq.empty, FastSeq(Ref("x", TInt64)), sumSig), + ) testFreeVarsHelper(streamAggIR) - val streamScanIR = StreamAggScan(Ref("st", TStream(TInt32)), "x", ApplyScanOp(FastSeq.empty, FastSeq(Cast(Ref("x", TInt32), TInt64)), sumSig)) + val streamScanIR = StreamAggScan( + Ref("st", TStream(TInt32)), + "x", + ApplyScanOp(FastSeq.empty, FastSeq(Cast(Ref("x", TInt32), TInt64)), sumSig), + ) testFreeVarsHelper(streamScanIR) } @@ -3423,7 +4468,15 @@ class IRSuite extends HailSuite { Array(Float64SingleCodeType, 1.2), Array(PTypeReferenceSingleCodeType(PCanonicalString()), "foo"), Array(PTypeReferenceSingleCodeType(PCanonicalArray(PInt32())), FastSeq(5, 7, null, 3)), - Array(PTypeReferenceSingleCodeType(PCanonicalTuple(false, PInt32(), PCanonicalString(), PCanonicalStruct())), Row(3, "bar", Row())) + Array( + PTypeReferenceSingleCodeType(PCanonicalTuple( + false, + PInt32(), + PCanonicalString(), + 
PCanonicalStruct(), + )), + Row(3, "bar", Row()), + ), ) @Test(dataProvider = "nonNullTypesAndValues") @@ -3435,12 +4488,11 @@ class IRSuite extends HailSuite { val reader = ETypeValueReader(spec) val prefix = ctx.createTmpPath("test-read-write-values") val filename = WriteValue(node, Str(prefix) + UUID4(), writer) - for (v <- Array(value, null)) { + for (v <- Array(value, null)) assertEvalsTo(ReadValue(filename, reader, pt.virtualType), FastSeq(v -> pt.virtualType), v) - } } - @Test(dataProvider="nonNullTypesAndValues") + @Test(dataProvider = "nonNullTypesAndValues") def testReadWriteValueDistributed(pt: SingleCodeType, value: Any): Unit = { implicit val execStrats = ExecStrategy.compileOnly val node = In(0, SingleCodeEmitParamType(true, pt)) @@ -3448,33 +4500,48 @@ class IRSuite extends HailSuite { val writer = ETypeValueWriter(spec) val reader = ETypeValueReader(spec) val prefix = ctx.createTmpPath("test-read-write-value-dist") - val readArray = Let(FastSeq("files" -> - CollectDistributedArray(StreamMap(StreamRange(0, 10, 1), "x", node), MakeStruct(FastSeq()), - "ctx", "globals", - WriteValue(Ref("ctx", node.typ), Str(prefix) + UUID4(), writer), NA(TString), "test")), - StreamMap(ToStream(Ref("files", TArray(TString))), "filename", - ReadValue(Ref("filename", TString), reader, pt.virtualType))) - for (v <- Array(value, null)) { + val readArray = Let( + FastSeq("files" -> + CollectDistributedArray( + StreamMap(StreamRange(0, 10, 1), "x", node), + MakeStruct(FastSeq()), + "ctx", + "globals", + WriteValue(Ref("ctx", node.typ), Str(prefix) + UUID4(), writer), + NA(TString), + "test", + )), + StreamMap( + ToStream(Ref("files", TArray(TString))), + "filename", + ReadValue(Ref("filename", TString), reader, pt.virtualType), + ), + ) + for (v <- Array(value, null)) assertEvalsTo(ToArray(readArray), FastSeq(v -> pt.virtualType), Array.fill(10)(v).toFastSeq) - } } - @Test def testUUID4() { + @Test def testUUID4(): Unit = { val single = UUID4() val hex = "[0-9a-f]" val format = s"$hex{8}-$hex{4}-$hex{4}-$hex{4}-$hex{12}" // 12345678-1234-5678-1234-567812345678 assertEvalsTo( - bindIR(single){ s => + bindIR(single) { s => invoke("regexMatch", TBoolean, Str(format), s) && - invoke("length", TInt32, s).ceq(I32(36)) - }, true) + invoke("length", TInt32, s).ceq(I32(36)) + }, + true, + ) - val stream = mapIR(rangeIR(5)) { _ => single } + val stream = mapIR(rangeIR(5))(_ => single) - def selfZip(s: IR, n: Int) = StreamZip(Array.fill(n)(s), Array.tabulate(n)(i => s"$i"), + def selfZip(s: IR, n: Int) = StreamZip( + Array.fill(n)(s), + Array.tabulate(n)(i => s"$i"), MakeArray(Array.tabulate(n)(i => Ref(s"$i", TString)), TArray(TString)), - ArrayZipBehavior.AssumeSameLength) + ArrayZipBehavior.AssumeSameLength, + ) def assertNumDistinct(s: IR, expected: Int) = assertEvalsTo(ArrayLen(CastToArray(ToSet(s))), expected) @@ -3485,21 +4552,26 @@ class IRSuite extends HailSuite { } @Test def testZipDoesntPruneLengthInfo(): Unit = { - for (behavior <- Array(ArrayZipBehavior.AssumeSameLength, - ArrayZipBehavior.AssertSameLength, - ArrayZipBehavior.TakeMinLength, - ArrayZipBehavior.ExtendNA)) { + for ( + behavior <- Array( + ArrayZipBehavior.AssumeSameLength, + ArrayZipBehavior.AssertSameLength, + ArrayZipBehavior.TakeMinLength, + ArrayZipBehavior.ExtendNA, + ) + ) { val zip = StreamZip( FastSeq(StreamRange(0, 10, 1), StreamRange(0, 10, 1)), FastSeq("x", "y"), makestruct("x" -> Str("foo"), "y" -> Str("bar")), - behavior) + behavior, + ) assertEvalsTo(ToArray(zip), Array.fill(10)(Row("foo", "bar")).toFastSeq) } } - 
@Test def testStreamDistribute(): Unit = { + @Test def testStreamDistribute(): Unit = { val data1 = IndexedSeq(0, 1, 1, 2, 4, 7, 7, 7, 9, 11, 15, 20, 22, 28, 50, 100) val pivots1 = IndexedSeq(-10, 1, 7, 7, 15, 22, 50, 200) val pivots2 = IndexedSeq(-10, 1, 1, 7, 9, 28, 50, 200) @@ -3522,23 +4594,42 @@ class IRSuite extends HailSuite { } def runStreamDistTest(data: IndexedSeq[Int], splitters: IndexedSeq[Int]): Unit = { - def makeRowStruct(i: Int) = MakeStruct(IndexedSeq(("rowIdx", I32(i)), ("extraInfo", I32(i * i)))) + def makeRowStruct(i: Int) = + MakeStruct(IndexedSeq(("rowIdx", I32(i)), ("extraInfo", I32(i * i)))) def makeKeyStruct(i: Int) = MakeStruct(IndexedSeq(("rowIdx", I32(i)))) - val child = ToStream(MakeArray(data.map(makeRowStruct):_*)) - val pivots = MakeArray(splitters.map(makeKeyStruct):_*) - val spec = TypedCodecSpec(PCanonicalStruct(("rowIdx", PInt32Required), ("extraInfo", PInt32Required)), BufferSpec.default) - val dist = StreamDistribute(child, pivots, Str(ctx.localTmpdir), Compare(pivots.typ.asInstanceOf[TArray].elementType), spec) - val result = eval(dist).asInstanceOf[IndexedSeq[Row]].map(row => (row(0).asInstanceOf[Interval], row(1).asInstanceOf[String], row(2).asInstanceOf[Int], row(3).asInstanceOf[Long])) - val kord: ExtendedOrdering = PartitionBoundOrdering(ctx, pivots.typ.asInstanceOf[TArray].elementType) + val child = ToStream(MakeArray(data.map(makeRowStruct): _*)) + val pivots = MakeArray(splitters.map(makeKeyStruct): _*) + val spec = TypedCodecSpec( + PCanonicalStruct(("rowIdx", PInt32Required), ("extraInfo", PInt32Required)), + BufferSpec.default, + ) + val dist = StreamDistribute( + child, + pivots, + Str(ctx.localTmpdir), + Compare(pivots.typ.asInstanceOf[TArray].elementType), + spec, + ) + val result = eval(dist).asInstanceOf[IndexedSeq[Row]].map(row => + ( + row(0).asInstanceOf[Interval], + row(1).asInstanceOf[String], + row(2).asInstanceOf[Int], + row(3).asInstanceOf[Long], + ) + ) + val kord: ExtendedOrdering = + PartitionBoundOrdering(ctx, pivots.typ.asInstanceOf[TArray].elementType) var dataIdx = 0 - result.foreach { case (interval, path, elementCount, numBytes) => + result.foreach { case (interval, path, elementCount, _) => val reader = PartitionNativeReader(spec, "rowUID") val read = ToArray(ReadPartition( MakeStruct(Array("partitionIndex" -> I64(0), "partitionPath" -> Str(path))), tcoerce[TStruct](spec._vType), - reader)) + reader, + )) val rowsFromDisk = eval(read).asInstanceOf[IndexedSeq[Row]] assert(rowsFromDisk.size == elementCount) assert(rowsFromDisk.forall(interval.contains(kord, _))) diff --git a/hail/src/test/scala/is/hail/expr/ir/IntervalSuite.scala b/hail/src/test/scala/is/hail/expr/ir/IntervalSuite.scala index 6cd1c93dae5..2c803273d55 100644 --- a/hail/src/test/scala/is/hail/expr/ir/IntervalSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/IntervalSuite.scala @@ -1,10 +1,11 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.TestUtils._ import is.hail.rvd.RVDPartitioner import is.hail.types.virtual._ import is.hail.utils._ -import is.hail.{ExecStrategy, HailSuite} + import org.apache.spark.sql.Row import org.testng.ITestContext import org.testng.annotations.{BeforeMethod, Test} @@ -18,9 +19,17 @@ class IntervalSuite extends HailSuite { val na = NA(tinterval1) def point(i: Int): IR = MakeTuple.ordered(FastSeq(I32(i))) - def interval(start: IR, end: IR, includeStart: java.lang.Boolean, includeEnd: java.lang.Boolean): IR = { - invoke("Interval", TInterval(start.typ), start, end, Literal.coerce(TBoolean, 
includeStart), Literal.coerce(TBoolean, includeEnd)) - } + + def interval(start: IR, end: IR, includeStart: java.lang.Boolean, includeEnd: java.lang.Boolean) + : IR = + invoke( + "Interval", + TInterval(start.typ), + start, + end, + Literal.coerce(TBoolean, includeStart), + Literal.coerce(TBoolean, includeEnd), + ) val i1 = interval(point(1), point(2), true, false) val i2 = interval(point(1), NA(tpoint1), true, false) @@ -28,7 +37,7 @@ class IntervalSuite extends HailSuite { val i4 = interval(NA(tpoint1), point(2), null, false) val i5 = interval(NA(tpoint1), point(2), true, null) - @Test def constructor() { + @Test def constructor(): Unit = { assertEvalsTo(i1, Interval(Row(1), Row(2), true, false)) assertEvalsTo(i2, Interval(Row(1), null, true, false)) assertEvalsTo(i3, Interval(null, Row(2), true, false)) @@ -36,33 +45,33 @@ class IntervalSuite extends HailSuite { assertEvalsTo(i5, null) } - @Test def start() { + @Test def start(): Unit = { assertEvalsTo(invoke("start", tpoint1, i1), Row(1)) assertEvalsTo(invoke("start", tpoint1, i2), Row(1)) assertEvalsTo(invoke("start", tpoint1, i3), null) assertEvalsTo(invoke("start", tpoint1, na), null) } - @Test def defaultValueCorrectlyStored() { + @Test def defaultValueCorrectlyStored(): Unit = { assertEvalsTo(If(GetTupleElement(invoke("start", tpoint1, i1), 0).ceq(1), true, false), true) assertEvalsTo(If(GetTupleElement(invoke("end", tpoint1, i1), 0).ceq(2), true, false), true) } - @Test def end() { + @Test def end(): Unit = { assertEvalsTo(invoke("end", tpoint1, i1), Row(2)) assertEvalsTo(invoke("end", tpoint1, i2), null) assertEvalsTo(invoke("end", tpoint1, i3), Row(2)) assertEvalsTo(invoke("end", tpoint1, na), null) } - @Test def includeStart() { + @Test def includeStart(): Unit = { assertEvalsTo(invoke("includesStart", TBoolean, i1), true) assertEvalsTo(invoke("includesStart", TBoolean, i2), true) assertEvalsTo(invoke("includesStart", TBoolean, i3), true) assertEvalsTo(invoke("includesStart", TBoolean, na), null) } - @Test def includeEnd() { + @Test def includeEnd(): Unit = { assertEvalsTo(invoke("includesEnd", TBoolean, i1), false) assertEvalsTo(invoke("includesEnd", TBoolean, i2), false) assertEvalsTo(invoke("includesEnd", TBoolean, i3), false) @@ -78,62 +87,84 @@ class IntervalSuite extends HailSuite { SetInterval(1, 3, false, true), SetInterval(2, 3, false, false), SetInterval(1, 2, true, true), - SetInterval(3, 1, true, false)) + SetInterval(3, 1, true, false), + ) def toIRInterval(i: SetInterval): IR = - invoke("Interval", TInterval(TInt32), ErrorIDs.NO_ERROR, i.start, i.end, i.includesStart, i.includesEnd) + invoke( + "Interval", + TInterval(TInt32), + ErrorIDs.NO_ERROR, + i.start, + i.end, + i.includesStart, + i.includesEnd, + ) - @Test def contains() { - for (setInterval <- testIntervals; p <- points) { + @Test def contains(): Unit = { + for { + setInterval <- testIntervals + p <- points + } { val interval = toIRInterval(setInterval) assert(eval(invoke("contains", TBoolean, interval, p)) == setInterval.contains(p)) } } - @Test def isEmpty() { + @Test def isEmpty(): Unit = { for (setInterval <- testIntervals) { val interval = toIRInterval(setInterval) - assert(eval(invoke("isEmpty", TBoolean, ErrorIDs.NO_ERROR, interval)) == setInterval.definitelyEmpty()) + assert(eval( + invoke("isEmpty", TBoolean, ErrorIDs.NO_ERROR, interval) + ) == setInterval.definitelyEmpty()) } } - @Test def overlaps() { - for (setInterval1 <- testIntervals; setInterval2 <- testIntervals) { + @Test def overlaps(): Unit = { + for { + setInterval1 <- testIntervals 
+ setInterval2 <- testIntervals + } { val interval1 = toIRInterval(setInterval1) val interval2 = toIRInterval(setInterval2) - assert(eval(invoke("overlaps", TBoolean, interval1, interval2)) == setInterval1.probablyOverlaps(setInterval2)) + assert(eval( + invoke("overlaps", TBoolean, interval1, interval2) + ) == setInterval1.probablyOverlaps(setInterval2)) } } - - def intInterval(start: Int, end: Int, includesStart: Boolean = true, includesEnd: Boolean = false): Interval = + def intInterval(start: Int, end: Int, includesStart: Boolean = true, includesEnd: Boolean = false) + : Interval = Interval(start, end, includesStart, includesEnd) - @Test def testIntervalSortAndReduce() { + @Test def testIntervalSortAndReduce(): Unit = { val ord = TInt32.ordering(ctx.stateManager).intervalEndpointOrdering assert(Interval.union(Array[Interval](), ord).sameElements(Array[Interval]())) assert(Interval.union(Array(intInterval(0, 10)), ord) .sameElements(Array(intInterval(0, 10)))) - assert(Interval.union(Array( - intInterval(0, 10), - intInterval(0, 20, includesEnd = true), - intInterval(20, 30), - intInterval(40, 50) - ).reverse, ord).toSeq == FastSeq( + assert(Interval.union( + Array( + intInterval(0, 10), + intInterval(0, 20, includesEnd = true), + intInterval(20, 30), + intInterval(40, 50), + ).reverse, + ord, + ).toSeq == FastSeq( intInterval(0, 30), - intInterval(40, 50) + intInterval(40, 50), )) } - @Test def testIntervalIntersection() { + @Test def testIntervalIntersection(): Unit = { val ord = TInt32.ordering(ctx.stateManager).intervalEndpointOrdering val x1 = Array[Interval]( intInterval(5, 10), intInterval(15, 20), - intInterval(25, 26) + intInterval(25, 26), ) val x2 = Array[Interval]( intInterval(0, 1), @@ -141,7 +172,7 @@ class IntervalSuite extends HailSuite { intInterval(23, 24), intInterval(24, 25), intInterval(25, 26), - intInterval(26, 27) + intInterval(26, 27), ) val x3 = Array[Interval]( @@ -154,28 +185,59 @@ class IntervalSuite extends HailSuite { assert(Interval.intersection(x1, x2, ord).toSeq == x1.toSeq) assert(Interval.intersection(x1, x3, ord).toSeq == FastSeq[Interval]( intInterval(7, 10), - intInterval(15, 19, includesEnd = true))) + intInterval(15, 19, includesEnd = true), + )) } - @Test def testsortedNonOverlappingIntervalsContain() { - val intervals = Literal(TArray(TInterval(TInt32)), FastSeq( - Interval(0, 1, includesStart = true, includesEnd = true), - Interval(10, 20, includesStart = true, includesEnd = true), - Interval(30, 32, includesStart = false, includesEnd = false), - Interval(32, 32, includesStart = true, includesEnd = true) - )) + @Test def testsortedNonOverlappingIntervalsContain(): Unit = { + val intervals = Literal( + TArray(TInterval(TInt32)), + FastSeq( + Interval(0, 1, includesStart = true, includesEnd = true), + Interval(10, 20, includesStart = true, includesEnd = true), + Interval(30, 32, includesStart = false, includesEnd = false), + Interval(32, 32, includesStart = true, includesEnd = true), + ), + ) - assertEvalsTo(invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(-1)), false) - assertEvalsTo(invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(7)), false) - assertEvalsTo(invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(27)), false) - assertEvalsTo(invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(30)), false) - assertEvalsTo(invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(300)), false) + assertEvalsTo( + 
invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(-1)), + false, + ) + assertEvalsTo( + invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(7)), + false, + ) + assertEvalsTo( + invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(27)), + false, + ) + assertEvalsTo( + invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(30)), + false, + ) + assertEvalsTo( + invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(300)), + false, + ) assertEvalsTo(invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(0)), true) assertEvalsTo(invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(1)), true) - assertEvalsTo(invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(10)), true) - assertEvalsTo(invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(11)), true) - assertEvalsTo(invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(31)), true) - assertEvalsTo(invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(32)), true) + assertEvalsTo( + invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(10)), + true, + ) + assertEvalsTo( + invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(11)), + true, + ) + assertEvalsTo( + invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(31)), + true, + ) + assertEvalsTo( + invoke("sortedNonOverlappingIntervalsContain", TBoolean, intervals, I32(32)), + true, + ) } val partitionerKType = TStruct("k1" -> TInt32, "k2" -> TInt32, "k3" -> TInt32) @@ -183,22 +245,28 @@ class IntervalSuite extends HailSuite { @BeforeMethod def setupRVDPartitioner(context: ITestContext): Unit = { - partitioner = new RVDPartitioner(ctx.stateManager, partitionerKType, + partitioner = new RVDPartitioner( + ctx.stateManager, + partitionerKType, Array( Interval(Row(1, 0), Row(4, 3), true, false), Interval(Row(4, 3), Row(7, 9), true, false), - Interval(Row(7, 11), Row(10, 0), true, true)) + Interval(Row(7, 11), Row(10, 0), true, true), + ), ).partitionBoundsIRRepresentation } - @Test def testsortedNonOverlappingPartitionIntervalsEqualRange() { - def assertRange(interval: Interval, startIdx: Int, endIdx: Int) { + @Test def testsortedNonOverlappingPartitionIntervalsEqualRange(): Unit = { + def assertRange(interval: Interval, startIdx: Int, endIdx: Int): Unit = { val resultType = TTuple(TInt32, TInt32) - val irInterval = Literal(RVDPartitioner.intervalIRRepresentation(partitionerKType), - RVDPartitioner.intervalToIRRepresentation(interval, 3)) + val irInterval = Literal( + RVDPartitioner.intervalIRRepresentation(partitionerKType), + RVDPartitioner.intervalToIRRepresentation(interval, 3), + ) assertEvalsTo( invoke("partitionerFindIntervalRange", resultType, partitioner, irInterval), - Row(startIdx, endIdx)) + Row(startIdx, endIdx), + ) } assertRange(Interval(Row(3, 4, 0), Row(7, 11), true, true), 0, 3) assertRange(Interval(Row(3, 4), Row(7, 9), true, false), 0, 2) @@ -207,24 +275,38 @@ class IntervalSuite extends HailSuite { assertRange(Interval(Row(-1, 7), Row(0, 9), true, false), 0, 0) } - @Test def testPointPartitionIntervalEndpointComparison() { - def assertComp(point: IndexedSeq[Int], intervalEndpoint: IndexedSeq[Int], leansRight: Boolean, function: String, expected: Boolean) { + @Test def testPointPartitionIntervalEndpointComparison(): Unit = { + def assertComp( + point: IndexedSeq[Int], + intervalEndpoint: IndexedSeq[Int], + leansRight: 
Boolean, + function: String, + expected: Boolean, + ): Unit = { val pointIR = MakeTuple.ordered(point.map(I32)) val endpointIR = MakeTuple.ordered(FastSeq( MakeTuple.ordered(Array.tabulate(3)(i => - if (i < intervalEndpoint.length) I32(intervalEndpoint(i)) else NA(TInt32))), - I32(intervalEndpoint.length))) + if (i < intervalEndpoint.length) I32(intervalEndpoint(i)) else NA(TInt32) + )), + I32(intervalEndpoint.length), + )) val leansRightIR = if (leansRight) True() else False() assertEvalsTo( invoke(function, TBoolean, pointIR, endpointIR, leansRightIR), - expected) - } - def assertLT(point: IndexedSeq[Int], intervalEndpoint: IndexedSeq[Int], leansRight: Boolean) { - assertComp(point, intervalEndpoint, leansRight, "pointLessThanPartitionIntervalRightEndpoint", true) - } - def assertNotLT(point: IndexedSeq[Int], intervalEndpoint: IndexedSeq[Int], leansRight: Boolean) { - assertComp(point, intervalEndpoint, leansRight, "pointLessThanPartitionIntervalRightEndpoint", false) + expected, + ) } + def assertLT(point: IndexedSeq[Int], intervalEndpoint: IndexedSeq[Int], leansRight: Boolean) + : Unit = + assertComp(point, intervalEndpoint, leansRight, "pointLessThanPartitionIntervalRightEndpoint", + true) + def assertNotLT( + point: IndexedSeq[Int], + intervalEndpoint: IndexedSeq[Int], + leansRight: Boolean, + ): Unit = + assertComp(point, intervalEndpoint, leansRight, "pointLessThanPartitionIntervalRightEndpoint", + false) assertLT(Array(1, 3, 2), Array(1, 3, 2), true) assertNotLT(Array(1, 3, 2), Array(1, 3, 2), false) assertLT(Array(1, 3, 2), Array(1, 3, 4), true) diff --git a/hail/src/test/scala/is/hail/expr/ir/LiftLiteralsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/LiftLiteralsSuite.scala index 0fad01d0c62..1fbadd00741 100644 --- a/hail/src/test/scala/is/hail/expr/ir/LiftLiteralsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/LiftLiteralsSuite.scala @@ -3,25 +3,33 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.types.virtual.TInt64 import is.hail.utils.FastSeq -import is.hail.TestUtils._ + import org.apache.spark.sql.Row import org.testng.annotations.Test class LiftLiteralsSuite extends HailSuite { implicit val execStrats = ExecStrategy.interpretOnly - @Test def testNestedGlobalsRewrite() { - val tab = TableLiteral(TableRange(10, 1).analyzeAndExecute(ctx).asTableValue(ctx), theHailClassLoader) + @Test def testNestedGlobalsRewrite(): Unit = { + val tab = + TableLiteral(TableRange(10, 1).analyzeAndExecute(ctx).asTableValue(ctx), theHailClassLoader) val ir = TableGetGlobals( TableMapGlobals( tab, - Let(FastSeq("global" -> I64(1)), + Let( + FastSeq("global" -> I64(1)), MakeStruct( FastSeq( "x" -> ApplyBinaryPrimOp( Add(), TableCount(tab), - Ref("global", TInt64))))))) + Ref("global", TInt64), + ) + ) + ), + ), + ) + ) assertEvalsTo(ir, Row(11L)) } diff --git a/hail/src/test/scala/is/hail/expr/ir/LocusFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/LocusFunctionsSuite.scala index bab55004e09..848c956f6e2 100644 --- a/hail/src/test/scala/is/hail/expr/ir/LocusFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/LocusFunctionsSuite.scala @@ -4,6 +4,7 @@ import is.hail.{ExecStrategy, HailSuite} import is.hail.types.virtual._ import is.hail.utils.{FastSeq, Interval} import is.hail.variant.{Locus, ReferenceGenome} + import org.apache.spark.sql.Row import org.testng.annotations.Test @@ -15,85 +16,165 @@ class LocusFunctionsSuite extends HailSuite { private def tlocus = TLocus(grch38.name) private def tvariant = TStruct("locus" -> tlocus, 
"alleles" -> TArray(TString)) - def locusIR: Apply = Apply("Locus", FastSeq(), FastSeq(Str("chr22"), I32(1)), tlocus, ErrorIDs.NO_ERROR) + def locusIR: Apply = + Apply("Locus", FastSeq(), FastSeq(Str("chr22"), I32(1)), tlocus, ErrorIDs.NO_ERROR) def locus = Locus("chr22", 1, grch38) - @Test def contig() { + @Test def contig(): Unit = assertEvalsTo(invoke("contig", TString, locusIR), locus.contig) - } - @Test def position() { + @Test def position(): Unit = assertEvalsTo(invoke("position", TInt32, locusIR), locus.position) - } - @Test def isAutosomalOrPseudoAutosomal() { - assertEvalsTo(invoke("isAutosomalOrPseudoAutosomal", TBoolean, locusIR), locus.isAutosomalOrPseudoAutosomal(grch38)) - } + @Test def isAutosomalOrPseudoAutosomal(): Unit = + assertEvalsTo( + invoke("isAutosomalOrPseudoAutosomal", TBoolean, locusIR), + locus.isAutosomalOrPseudoAutosomal(grch38), + ) - @Test def isAutosomal() { + @Test def isAutosomal(): Unit = assertEvalsTo(invoke("isAutosomal", TBoolean, locusIR), locus.isAutosomal(grch38)) - } - @Test def inYNonPar() { + @Test def inYNonPar(): Unit = assertEvalsTo(invoke("inYNonPar", TBoolean, locusIR), locus.inYNonPar(grch38)) - } - @Test def inXPar() { + @Test def inXPar(): Unit = assertEvalsTo(invoke("inXPar", TBoolean, locusIR), locus.inXPar(grch38)) - } - @Test def isMitochondrial() { + @Test def isMitochondrial(): Unit = assertEvalsTo(invoke("isMitochondrial", TBoolean, locusIR), locus.isMitochondrial(grch38)) - } - @Test def inXNonPar() { + @Test def inXNonPar(): Unit = assertEvalsTo(invoke("inXNonPar", TBoolean, locusIR), locus.inXNonPar(grch38)) - } - @Test def inYPar() { + @Test def inYPar(): Unit = assertEvalsTo(invoke("inYPar", TBoolean, locusIR), locus.inYPar(grch38)) - } - @Test def minRep() { + @Test def minRep(): Unit = { val alleles = MakeArray(FastSeq(Str("AA"), Str("AT")), TArray(TString)) - assertEvalsTo(invoke("min_rep", tvariant, locusIR, alleles), Row(Locus("chr22", 2), FastSeq("A", "T"))) + assertEvalsTo( + invoke("min_rep", tvariant, locusIR, alleles), + Row(Locus("chr22", 2), FastSeq("A", "T")), + ) assertEvalsTo(invoke("min_rep", tvariant, locusIR, NA(TArray(TString))), null) } - @Test def globalPosition() { + @Test def globalPosition(): Unit = assertEvalsTo(invoke("locusToGlobalPos", TInt64, locusIR), grch38.locusToGlobalPos(locus)) - } - @Test def reverseGlobalPosition() { + @Test def reverseGlobalPosition(): Unit = { val globalPosition = 2824183054L - assertEvalsTo(invoke("globalPosToLocus", tlocus, I64(globalPosition)), grch38.globalPosToLocus(globalPosition)) + assertEvalsTo( + invoke("globalPosToLocus", tlocus, I64(globalPosition)), + grch38.globalPosToLocus(globalPosition), + ) } - @Test def testMultipleReferenceGenomes() { + @Test def testMultipleReferenceGenomes(): Unit = { implicit val execStrats = ExecStrategy.compileOnly val ir = MakeTuple.ordered(FastSeq( invoke("Locus", TLocus(ReferenceGenome.GRCh37), Str("1"), I32(1)), - invoke("Locus", TLocus(ReferenceGenome.GRCh38), Str("chr1"), I32(1)))) + invoke("Locus", TLocus(ReferenceGenome.GRCh38), Str("chr1"), I32(1)), + )) - assertEvalsTo(ir, Row(Locus("1", 1, ctx.getReference(ReferenceGenome.GRCh37)), Locus("chr1", 1, ctx.getReference(ReferenceGenome.GRCh38)))) + assertEvalsTo( + ir, + Row( + Locus("1", 1, ctx.getReference(ReferenceGenome.GRCh37)), + Locus("chr1", 1, ctx.getReference(ReferenceGenome.GRCh38)), + ), + ) } - @Test def testMakeInterval() { + @Test def testMakeInterval(): Unit = { // TString, TInt32, TInt32, TBoolean, TBoolean, TBoolean val ir = MakeTuple.ordered(FastSeq( - 
invoke("LocusInterval", TInterval(TLocus(grch38.name)), NA(TString), I32(1), I32(100), True(), True(), False()), - invoke("LocusInterval", TInterval(TLocus(grch38.name)), Str("chr1"), NA(TInt32), I32(100), True(), True(), False()), - invoke("LocusInterval", TInterval(TLocus(grch38.name)), Str("chr1"), I32(1), NA(TInt32), True(), True(), False()), - invoke("LocusInterval", TInterval(TLocus(grch38.name)), Str("chr1"), I32(1), I32(100), NA(TBoolean), True(), False()), - invoke("LocusInterval", TInterval(TLocus(grch38.name)), Str("chr1"), I32(1), I32(100), True(), NA(TBoolean), False()), - invoke("LocusInterval", TInterval(TLocus(grch38.name)), Str("chr1"), I32(1), I32(100), True(), True(), NA(TBoolean)), - invoke("LocusInterval", TInterval(TLocus(grch38.name)), Str("chr1"), I32(-1), I32(0), True(), True(), True()), - invoke("LocusInterval", TInterval(TLocus(grch38.name)), Str("chr1"), I32(1), I32(100), True(), True(), True()) + invoke( + "LocusInterval", + TInterval(TLocus(grch38.name)), + NA(TString), + I32(1), + I32(100), + True(), + True(), + False(), + ), + invoke( + "LocusInterval", + TInterval(TLocus(grch38.name)), + Str("chr1"), + NA(TInt32), + I32(100), + True(), + True(), + False(), + ), + invoke( + "LocusInterval", + TInterval(TLocus(grch38.name)), + Str("chr1"), + I32(1), + NA(TInt32), + True(), + True(), + False(), + ), + invoke( + "LocusInterval", + TInterval(TLocus(grch38.name)), + Str("chr1"), + I32(1), + I32(100), + NA(TBoolean), + True(), + False(), + ), + invoke( + "LocusInterval", + TInterval(TLocus(grch38.name)), + Str("chr1"), + I32(1), + I32(100), + True(), + NA(TBoolean), + False(), + ), + invoke( + "LocusInterval", + TInterval(TLocus(grch38.name)), + Str("chr1"), + I32(1), + I32(100), + True(), + True(), + NA(TBoolean), + ), + invoke( + "LocusInterval", + TInterval(TLocus(grch38.name)), + Str("chr1"), + I32(-1), + I32(0), + True(), + True(), + True(), + ), + invoke( + "LocusInterval", + TInterval(TLocus(grch38.name)), + Str("chr1"), + I32(1), + I32(100), + True(), + True(), + True(), + ), )) - assertEvalsTo(ir, + assertEvalsTo( + ir, Row( null, null, @@ -102,8 +183,8 @@ class LocusFunctionsSuite extends HailSuite { null, null, null, - Interval(Locus("chr1", 1, grch38), Locus("chr1", 100, grch38), true, true) - ) + Interval(Locus("chr1", 1, grch38), Locus("chr1", 100, grch38), true, true), + ), ) } } diff --git a/hail/src/test/scala/is/hail/expr/ir/LoweringPipelineSuite.scala b/hail/src/test/scala/is/hail/expr/ir/LoweringPipelineSuite.scala index db8883735c0..0ed0a7d6189 100644 --- a/hail/src/test/scala/is/hail/expr/ir/LoweringPipelineSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/LoweringPipelineSuite.scala @@ -2,6 +2,7 @@ package is.hail.expr.ir import is.hail.HailSuite import is.hail.expr.ir.lowering.{LowerMatrixToTablePass, LoweringPipeline, OptimizePass} + import org.testng.annotations.Test class LoweringPipelineSuite extends HailSuite { diff --git a/hail/src/test/scala/is/hail/expr/ir/MathFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/MathFunctionsSuite.scala index 99697747342..320d8d19e47 100644 --- a/hail/src/test/scala/is/hail/expr/ir/MathFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/MathFunctionsSuite.scala @@ -1,15 +1,13 @@ package is.hail.expr.ir -import is.hail.{ExecStrategy, HailSuite, stats} -import is.hail.types._ -import is.hail.utils._ +import is.hail.{stats, ExecStrategy, HailSuite} import is.hail.TestUtils._ -import is.hail.expr.ir.functions.MathFunctions import is.hail.types.virtual._ +import is.hail.utils._ + 
import org.apache.spark.sql.Row import org.testng.annotations.{DataProvider, Test} - class MathFunctionsSuite extends HailSuite { hc implicit val execStrats = ExecStrategy.values @@ -32,7 +30,7 @@ class MathFunctionsSuite extends HailSuite { assertEvalsTo(invoke("roundToNextPowerOf2", TInt32, I32(64)), 64) } - @Test def isnan() { + @Test def isnan(): Unit = { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(invoke("isnan", TBoolean, F32(0)), false) @@ -42,7 +40,7 @@ class MathFunctionsSuite extends HailSuite { assertEvalsTo(invoke("isnan", TBoolean, F64(Double.NaN)), true) } - @Test def is_finite() { + @Test def is_finite(): Unit = { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(invoke("is_finite", TBoolean, F32(0)), expected = true) @@ -58,7 +56,7 @@ class MathFunctionsSuite extends HailSuite { assertEvalsTo(invoke("is_finite", TBoolean, F64(Double.NegativeInfinity)), expected = false) } - @Test def is_infinite() { + @Test def is_infinite(): Unit = { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(invoke("is_infinite", TBoolean, F32(0)), expected = false) @@ -74,17 +72,16 @@ class MathFunctionsSuite extends HailSuite { assertEvalsTo(invoke("is_infinite", TBoolean, F64(Double.NegativeInfinity)), expected = true) } - - @Test def sign() { + @Test def sign(): Unit = { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(invoke("sign", TInt32, I32(2)), 1) assertEvalsTo(invoke("sign", TInt32, I32(0)), 0) assertEvalsTo(invoke("sign", TInt32, I32(-2)), -1) - assertEvalsTo(invoke("sign", TInt64, I64(2)), 1l) - assertEvalsTo(invoke("sign", TInt64, I64(0)), 0l) - assertEvalsTo(invoke("sign", TInt64, I64(-2)), -1l) + assertEvalsTo(invoke("sign", TInt64, I64(2)), 1L) + assertEvalsTo(invoke("sign", TInt64, I64(0)), 0L) + assertEvalsTo(invoke("sign", TInt64, I64(-2)), -1L) assertEvalsTo(invoke("sign", TFloat32, F32(2)), 1.0f) assertEvalsTo(invoke("sign", TFloat32, F32(0)), 0.0f) @@ -99,20 +96,68 @@ class MathFunctionsSuite extends HailSuite { assertEvalsTo(invoke("sign", TFloat64, F64(Double.NegativeInfinity)), -1.0) } - @Test def approxEqual() { + @Test def approxEqual(): Unit = { implicit val execStrats = ExecStrategy.javaOnly - assertEvalsTo(invoke("approxEqual", TBoolean, F64(0.025), F64(0.0250000001), F64(1e-4), False(), False()), true) - assertEvalsTo(invoke("approxEqual", TBoolean, F64(0.0154), F64(0.0156), F64(1e-4), True(), False()), false) - assertEvalsTo(invoke("approxEqual", TBoolean, F64(0.0154), F64(0.0156), F64(1e-3), True(), False()), true) - assertEvalsTo(invoke("approxEqual", TBoolean, F64(Double.NaN), F64(Double.NaN), F64(1e-3), True(), False()), false) - assertEvalsTo(invoke("approxEqual", TBoolean, F64(Double.NaN), F64(Double.NaN), F64(1e-3), True(), True()), true) - assertEvalsTo(invoke("approxEqual", TBoolean, F64(Double.PositiveInfinity), F64(Double.PositiveInfinity), F64(1e-3), True(), False()), true) - assertEvalsTo(invoke("approxEqual", TBoolean, F64(Double.NegativeInfinity), F64(Double.NegativeInfinity), F64(1e-3), True(), False()), true) - assertEvalsTo(invoke("approxEqual", TBoolean, F64(Double.PositiveInfinity), F64(Double.NegativeInfinity), F64(1e-3), True(), False()), false) + assertEvalsTo( + invoke("approxEqual", TBoolean, F64(0.025), F64(0.0250000001), F64(1e-4), False(), False()), + true, + ) + assertEvalsTo( + invoke("approxEqual", TBoolean, F64(0.0154), F64(0.0156), F64(1e-4), True(), False()), + false, + ) + assertEvalsTo( + invoke("approxEqual", TBoolean, F64(0.0154), F64(0.0156), F64(1e-3), True(), False()), + 
true, + ) + assertEvalsTo( + invoke("approxEqual", TBoolean, F64(Double.NaN), F64(Double.NaN), F64(1e-3), True(), False()), + false, + ) + assertEvalsTo( + invoke("approxEqual", TBoolean, F64(Double.NaN), F64(Double.NaN), F64(1e-3), True(), True()), + true, + ) + assertEvalsTo( + invoke( + "approxEqual", + TBoolean, + F64(Double.PositiveInfinity), + F64(Double.PositiveInfinity), + F64(1e-3), + True(), + False(), + ), + true, + ) + assertEvalsTo( + invoke( + "approxEqual", + TBoolean, + F64(Double.NegativeInfinity), + F64(Double.NegativeInfinity), + F64(1e-3), + True(), + False(), + ), + true, + ) + assertEvalsTo( + invoke( + "approxEqual", + TBoolean, + F64(Double.PositiveInfinity), + F64(Double.NegativeInfinity), + F64(1e-3), + True(), + False(), + ), + false, + ) } - @Test def entropy() { + @Test def entropy(): Unit = { implicit val execStrats = ExecStrategy.javaOnly assertEvalsTo(invoke("entropy", TFloat64, Str("")), 0.0) @@ -130,26 +175,51 @@ class MathFunctionsSuite extends HailSuite { Array(1, 1, 0, 1, 0.38647623077123266, Double.PositiveInfinity), Array(1, 1, 1, 0, 0.38647623077123266, 0.0), Array(10, 10, 10, 10, 1.0, 1.0), - Array(51, 43, 22, 92, 1.462626e-7, (51.0 * 92) / (22 * 43)) + Array(51, 43, 22, 92, 1.462626e-7, (51.0 * 92) / (22 * 43)), ) @Test(dataProvider = "chi_squared_test") - def chiSquaredTest(a: Int, b: Int, c: Int, d: Int, pValue: Double, oddsRatio: Double) { - val r = eval(invoke("chi_squared_test", stats.chisqStruct.virtualType, ErrorIDs.NO_ERROR, a, b, c, d)).asInstanceOf[Row] - assert(D0_==(pValue, r.getDouble(0))) - assert(D0_==(oddsRatio, r.getDouble(1))) + def chiSquaredTest(a: Int, b: Int, c: Int, d: Int, pValue: Double, oddsRatio: Double): Unit = { + val r = eval(invoke( + "chi_squared_test", + stats.chisqStruct.virtualType, + ErrorIDs.NO_ERROR, + a, + b, + c, + d, + )).asInstanceOf[Row] + assert(D0_==(pValue, r.getDouble(0))) + assert(D0_==(oddsRatio, r.getDouble(1))) } @DataProvider(name = "fisher_exact_test") def fisherExactData(): Array[Array[Any]] = Array( Array(0, 0, 0, 0, Double.NaN, Double.NaN, Double.NaN, Double.NaN), Array(10, 10, 10, 10, 1.0, 1.0, 0.243858, 4.100748), - Array(51, 43, 22, 92, 2.1565e-7, 4.918058, 2.565937, 9.677930) + Array(51, 43, 22, 92, 2.1565e-7, 4.918058, 2.565937, 9.677930), ) @Test(dataProvider = "fisher_exact_test") - def fisherExactTest(a: Int, b: Int, c: Int, d: Int, pValue: Double, oddsRatio: Double, confLower: Double, confUpper: Double) { - val r = eval(invoke("fisher_exact_test", stats.fetStruct.virtualType, ErrorIDs.NO_ERROR, a, b, c, d)).asInstanceOf[Row] + def fisherExactTest( + a: Int, + b: Int, + c: Int, + d: Int, + pValue: Double, + oddsRatio: Double, + confLower: Double, + confUpper: Double, + ): Unit = { + val r = eval(invoke( + "fisher_exact_test", + stats.fetStruct.virtualType, + ErrorIDs.NO_ERROR, + a, + b, + c, + d, + )).asInstanceOf[Row] assert(D0_==(pValue, r.getDouble(0))) assert(D0_==(oddsRatio, r.getDouble(1))) assert(D0_==(confLower, r.getDouble(2))) @@ -159,12 +229,29 @@ class MathFunctionsSuite extends HailSuite { @DataProvider(name = "contingency_table_test") def contingencyTableData(): Array[Array[Any]] = Array( Array(51, 43, 22, 92, 22, 1.462626e-7, 4.95983087), - Array(51, 43, 22, 92, 23, 2.1565e-7, 4.91805817) + Array(51, 43, 22, 92, 23, 2.1565e-7, 4.91805817), ) @Test(dataProvider = "contingency_table_test") - def contingencyTableTest(a: Int, b: Int, c: Int, d: Int, minCellCount: Int, pValue: Double, oddsRatio: Double) { - val r = eval(invoke("contingency_table_test", 
stats.chisqStruct.virtualType, ErrorIDs.NO_ERROR, a, b, c, d, minCellCount)).asInstanceOf[Row] + def contingencyTableTest( + a: Int, + b: Int, + c: Int, + d: Int, + minCellCount: Int, + pValue: Double, + oddsRatio: Double, + ): Unit = { + val r = eval(invoke( + "contingency_table_test", + stats.chisqStruct.virtualType, + ErrorIDs.NO_ERROR, + a, + b, + c, + d, + minCellCount, + )).asInstanceOf[Row] assert(D0_==(pValue, r.getDouble(0))) assert(D0_==(oddsRatio, r.getDouble(1))) } @@ -174,24 +261,45 @@ class MathFunctionsSuite extends HailSuite { Array(0, 0, 0, Double.NaN, 0.5), Array(1, 2, 1, 0.57142857, 0.65714285), Array(0, 1, 0, 1.0, 0.5), - Array(100, 200, 100, 0.50062578, 0.96016808) + Array(100, 200, 100, 0.50062578, 0.96016808), ) @Test(dataProvider = "hardy_weinberg_test") - def hardyWeinbergTest(nHomRef: Int, nHet: Int, nHomVar: Int, pValue: Double, hetFreq: Double) { - val r = eval(invoke("hardy_weinberg_test", stats.hweStruct.virtualType, ErrorIDs.NO_ERROR, nHomRef, nHet, nHomVar, false)).asInstanceOf[Row] + def hardyWeinbergTest(nHomRef: Int, nHet: Int, nHomVar: Int, pValue: Double, hetFreq: Double) + : Unit = { + val r = eval(invoke( + "hardy_weinberg_test", + stats.hweStruct.virtualType, + ErrorIDs.NO_ERROR, + nHomRef, + nHet, + nHomVar, + false, + )).asInstanceOf[Row] assert(D0_==(pValue, r.getDouble(0))) assert(D0_==(hetFreq, r.getDouble(1))) } - @Test def modulusTest() { - assertFatal(invoke("mod", TInt32, I32(1), I32(0)), "(modulo by zero)|(error while calling 'mod')") - assertFatal(invoke("mod", TInt64, I64(1), I64(0)), "(modulo by zero)|(error while calling 'mod')") - assertFatal(invoke("mod", TFloat32, F32(1), F32(0)), "(modulo by zero)|(error while calling 'mod')") - assertFatal(invoke("mod", TFloat64, F64(1), F64(0)), "(modulo by zero)|(error while calling 'mod')") + @Test def modulusTest(): Unit = { + assertFatal( + invoke("mod", TInt32, I32(1), I32(0)), + "(modulo by zero)|(error while calling 'mod')", + ) + assertFatal( + invoke("mod", TInt64, I64(1), I64(0)), + "(modulo by zero)|(error while calling 'mod')", + ) + assertFatal( + invoke("mod", TFloat32, F32(1), F32(0)), + "(modulo by zero)|(error while calling 'mod')", + ) + assertFatal( + invoke("mod", TFloat64, F64(1), F64(0)), + "(modulo by zero)|(error while calling 'mod')", + ) } - @Test def testMinMax() { + @Test def testMinMax(): Unit = { implicit val execStrats = ExecStrategy.javaOnly assertAllEvalTo( (invoke("min", TFloat32, F32(1.0f), F32(2.0f)), 1.0f), @@ -209,7 +317,7 @@ class MathFunctionsSuite extends HailSuite { (invoke("min", TFloat32, F32(1.0f), F32(Float.NaN)), Float.NaN), (invoke("min", TFloat64, F64(1.0), F64(Double.NaN)), Double.NaN), (invoke("max", TFloat32, F32(1.0f), F32(Float.NaN)), Float.NaN), - (invoke("max", TFloat64, F64(1.0), F64(Double.NaN)), Double.NaN) + (invoke("max", TFloat64, F64(1.0), F64(Double.NaN)), Double.NaN), ) } } diff --git a/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala index b747c351247..9ef28b85a72 100644 --- a/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/MatrixIRSuite.scala @@ -1,5 +1,6 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy import is.hail.TestUtils._ import is.hail.annotations.BroadcastRow @@ -7,27 +8,40 @@ import is.hail.expr.JSONAnnotationImpex import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual._ import is.hail.utils._ -import is.hail.{ExecStrategy, HailSuite} + 
import org.apache.spark.sql.Row import org.json4s.jackson.JsonMethods import org.testng.annotations.{DataProvider, Test} class MatrixIRSuite extends HailSuite { - implicit val execStrats: Set[ExecStrategy] = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized, ExecStrategy.LoweredJVMCompile) + implicit val execStrats: Set[ExecStrategy] = + Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized, ExecStrategy.LoweredJVMCompile) @Test def testMatrixWriteRead(): Unit = { val range = MatrixIR.range(10, 10, Some(3)) - val withEntries = MatrixMapEntries(range, makestruct( - "i" -> GetField(Ref("va", range.typ.rowType), "row_idx"), - "j" -> GetField(Ref("sa", range.typ.colType), "col_idx"))) + val withEntries = MatrixMapEntries( + range, + makestruct( + "i" -> GetField(Ref("va", range.typ.rowType), "row_idx"), + "j" -> GetField(Ref("sa", range.typ.colType), "col_idx"), + ), + ) val original = MatrixMapGlobals(withEntries, makestruct("foo" -> I32(0))) val path = ctx.createTmpPath("test-range-read", "mt") val writer1 = MatrixNativeWriter(path, overwrite = true) val partType = TArray(TInterval(TStruct("row_idx" -> TInt32))) - val parts = JsonMethods.compact(JSONAnnotationImpex.exportAnnotation(FastSeq(Interval(Row(0), Row(10), true, false)), partType)) - val writer2 = MatrixNativeWriter(path, overwrite = true, partitions = parts.toString, partitionsTypeStr = partType.parsableString()) + val parts = JsonMethods.compact(JSONAnnotationImpex.exportAnnotation( + FastSeq(Interval(Row(0), Row(10), true, false)), + partType, + )) + val writer2 = MatrixNativeWriter( + path, + overwrite = true, + partitions = parts.toString, + partitionsTypeStr = partType.parsableString(), + ) for (writer <- Array(writer1, writer2)) { assertEvalsTo(MatrixWrite(original, writer), ()) @@ -41,21 +55,35 @@ class MatrixIRSuite extends HailSuite { (partSize, partIndex) <- partition(10, 3).zipWithIndex i <- 0 until partSize } yield Row(partIndex.toLong, i.toLong) - (0 until 10, uids).zipped.map { (i, uid) => Row(i, uid, expectedCols.map { case Row(j, _) => Row(i, j) }) } + (0 until 10, uids).zipped.map { (i, uid) => + Row(i, uid, expectedCols.map { case Row(j, _) => Row(i, j) }) + } } else - Array.tabulate(10)(i => Row(i, Row(0L, i.toLong), expectedCols.map { case Row(j, _) => Row(i, j) })).toFastSeq + Array.tabulate(10)(i => + Row(i, Row(0L, i.toLong), expectedCols.map { case Row(j, _) => Row(i, j) }) + ).toFastSeq val expectedGlobals = Row(0, expectedCols); { - implicit val execStrats: Set[ExecStrategy] = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) + implicit val execStrats: Set[ExecStrategy] = + Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) assertEvalsTo( TableCollect(TableKeyBy(CastMatrixToTable(read, "entries", "cols"), FastSeq())), - Row(expectedRows, expectedGlobals)) - assertEvalsTo(TableCollect(TableKeyBy(CastMatrixToTable(droppedRows, "entries", "cols"), FastSeq())), Row(FastSeq(), expectedGlobals)) + Row(expectedRows, expectedGlobals), + ) + assertEvalsTo( + TableCollect(TableKeyBy(CastMatrixToTable(droppedRows, "entries", "cols"), FastSeq())), + Row(FastSeq(), expectedGlobals), + ) } } } - def rangeMatrix(nRows: Int = 20, nCols: Int = 20, nPartitions: Option[Int] = Some(4), uids: Boolean = false): MatrixIR = { + def rangeMatrix( + nRows: Int = 20, + nCols: Int = 20, + nPartitions: Option[Int] = Some(4), + uids: Boolean = false, + ): MatrixIR = { val reader = MatrixRangeReader(nRows, nCols, nPartitions) val requestedType = if (uids) reader.fullMatrixType @@ -70,7 +98,7 
@@ class MatrixIRSuite extends HailSuite { def getCols(mir: MatrixIR): Array[Row] = Interpret(MatrixColsTable(mir), ctx).rdd.collect() - @Test def testScanCountBehavesLikeIndexOnRows() { + @Test def testScanCountBehavesLikeIndexOnRows(): Unit = { val mt = rangeMatrix() val oldRow = Ref("va", mt.typ.rowType) @@ -81,7 +109,7 @@ class MatrixIRSuite extends HailSuite { assert(rows.forall { case Row(row_idx, idx) => row_idx == idx }, rows.toSeq) } - @Test def testScanCollectBehavesLikeRangeOnRows() { + @Test def testScanCollectBehavesLikeRangeOnRows(): Unit = { val mt = rangeMatrix() val oldRow = Ref("va", mt.typ.rowType) @@ -89,21 +117,28 @@ class MatrixIRSuite extends HailSuite { val newMatrix = MatrixMapRows(mt, newRow) val rows = getRows(newMatrix) - assert(rows.forall { case Row(row_idx: Int, range: IndexedSeq[_]) => range sameElements Array.range(0, row_idx) }) + assert(rows.forall { case Row(row_idx: Int, range: IndexedSeq[_]) => + range sameElements Array.range(0, row_idx) + }) } - @Test def testScanCollectBehavesLikeRangeWithAggregationOnRows() { + @Test def testScanCollectBehavesLikeRangeWithAggregationOnRows(): Unit = { val mt = rangeMatrix() val oldRow = Ref("va", mt.typ.rowType) - val newRow = InsertFields(oldRow, Seq("n" -> IRAggCount, "range" -> IRScanCollect(GetField(oldRow, "row_idx").toL))) + val newRow = InsertFields( + oldRow, + Seq("n" -> IRAggCount, "range" -> IRScanCollect(GetField(oldRow, "row_idx").toL)), + ) val newMatrix = MatrixMapRows(mt, newRow) val rows = getRows(newMatrix) - assert(rows.forall { case Row(row_idx: Int, n: Long, range: IndexedSeq[_]) => (n == 20) && (range sameElements Array.range(0, row_idx)) }) + assert(rows.forall { case Row(row_idx: Int, n: Long, range: IndexedSeq[_]) => + (n == 20) && (range sameElements Array.range(0, row_idx)) + }) } - @Test def testScanCountBehavesLikeIndexOnCols() { + @Test def testScanCountBehavesLikeIndexOnCols(): Unit = { val mt = rangeMatrix() val oldCol = Ref("sa", mt.typ.colType) @@ -114,7 +149,7 @@ class MatrixIRSuite extends HailSuite { assert(cols.forall { case Row(col_idx, idx) => col_idx == idx }) } - @Test def testScanCollectBehavesLikeRangeOnCols() { + @Test def testScanCollectBehavesLikeRangeOnCols(): Unit = { val mt = rangeMatrix() val oldCol = Ref("sa", mt.typ.colType) @@ -122,18 +157,25 @@ class MatrixIRSuite extends HailSuite { val newMatrix = MatrixMapCols(mt, newCol, None) val cols = getCols(newMatrix) - assert(cols.forall { case Row(col_idx: Int, range: IndexedSeq[_]) => range sameElements Array.range(0, col_idx) }) + assert(cols.forall { case Row(col_idx: Int, range: IndexedSeq[_]) => + range sameElements Array.range(0, col_idx) + }) } - @Test def testScanCollectBehavesLikeRangeWithAggregationOnCols() { + @Test def testScanCollectBehavesLikeRangeWithAggregationOnCols(): Unit = { val mt = rangeMatrix() val oldCol = Ref("sa", mt.typ.colType) - val newCol = InsertFields(oldCol, Seq("n" -> IRAggCount, "range" -> IRScanCollect(GetField(oldCol, "col_idx").toL))) + val newCol = InsertFields( + oldCol, + Seq("n" -> IRAggCount, "range" -> IRScanCollect(GetField(oldCol, "col_idx").toL)), + ) val newMatrix = MatrixMapCols(mt, newCol, None) val cols = getCols(newMatrix) - assert(cols.forall { case Row(col_idx: Int, n: Long, range: IndexedSeq[_]) => (n == 20) && (range sameElements Array.range(0, col_idx)) }) + assert(cols.forall { case Row(col_idx: Int, n: Long, range: IndexedSeq[_]) => + (n == 20) && (range sameElements Array.range(0, col_idx)) + }) } def rangeRowMatrix(start: Int, end: Int): MatrixIR = { @@ 
-145,8 +187,11 @@ class MatrixIRSuite extends HailSuite { MatrixKeyRowsBy(baseRange, FastSeq()), InsertFields( row, - FastSeq("row_idx" -> (GetField(row, "row_idx") + start)))), - FastSeq("row_idx")) + FastSeq("row_idx" -> (GetField(row, "row_idx") + start)), + ), + ), + FastSeq("row_idx"), + ) } @DataProvider(name = "unionRowsData") @@ -156,10 +201,11 @@ class MatrixIRSuite extends HailSuite { Array(FastSeq(0 -> 6, 5 -> 7)), Array(FastSeq(2 -> 3, 0 -> 1, 5 -> 7)), Array(FastSeq(2 -> 4, 0 -> 3, 5 -> 7)), - Array(FastSeq(3 -> 6, 0 -> 1, 5 -> 7))) + Array(FastSeq(3 -> 6, 0 -> 1, 5 -> 7)), + ) @Test(dataProvider = "unionRowsData") - def testMatrixUnionRows(ranges: IndexedSeq[(Int, Int)]) { + def testMatrixUnionRows(ranges: IndexedSeq[(Int, Int)]): Unit = { val expectedOrdering = ranges.flatMap { case (start, end) => Array.range(start, end) }.sorted @@ -181,18 +227,19 @@ class MatrixIRSuite extends HailSuite { Array(FastSeq("na"), FastSeq(null)), Array(FastSeq("x", "y"), FastSeq(3)), Array(FastSeq("foo", "bar"), FastSeq(1, 3)), - Array(FastSeq("a", "b", "c"), FastSeq())) + Array(FastSeq("a", "b", "c"), FastSeq()), + ) @Test(dataProvider = "explodeRowsData") - def testMatrixExplode(path: IndexedSeq[String], collection: IndexedSeq[Integer]) { - val tarray = TArray(TInt32) + def testMatrixExplode(path: IndexedSeq[String], collection: IndexedSeq[Integer]): Unit = { val range = rangeMatrix(5, 2, None) val field = path.init.foldRight(path.last -> toIRArray(collection))(_ -> IRStruct(_)) val annotated = MatrixMapRows(range, InsertFields(Ref("va", range.typ.rowType), FastSeq(field))) val q = annotated.typ.rowType.query(path: _*) - val exploded = getRows(MatrixExplodeRows(annotated, path.toFastSeq)).map(q(_).asInstanceOf[Integer]) + val exploded = + getRows(MatrixExplodeRows(annotated, path.toFastSeq)).map(q(_).asInstanceOf[Integer]) val expected = if (collection == null) Array[Integer]() else Array.fill(5)(collection).flatten assert(exploded sameElements expected) @@ -204,26 +251,29 @@ class MatrixIRSuite extends HailSuite { val rowSig = TStruct( "row_idx" -> TInt32, "animal" -> TString, - "__entries" -> TArray(TStruct("ent1" -> TString, "ent2" -> TFloat64)) + "__entries" -> TArray(TStruct("ent1" -> TString, "ent2" -> TFloat64)), ) val keyNames = FastSeq("row_idx") val colSig = TStruct("col_idx" -> TInt32, "tag" -> TString) val globalType = TStruct(("__cols", TArray(colSig))) var tv = TableValue(ctx, rowSig, keyNames, rowRdd) - tv = tv.copy(typ = tv.typ.copy(globalType = globalType), globals = BroadcastRow(ctx, Row(cdata.toFastSeq), globalType)) + tv = tv.copy( + typ = tv.typ.copy(globalType = globalType), + globals = BroadcastRow(ctx, Row(cdata.toFastSeq), globalType), + ) TableLiteral(tv, theHailClassLoader) } - @Test def testCastTableToMatrix() { + @Test def testCastTableToMatrix(): Unit = { val rdata = Array( Row(1, "fish", FastSeq(Row("a", 1.0), Row("x", 2.0))), Row(2, "cat", FastSeq(Row("b", 0.0), Row("y", 0.1))), - Row(3, "dog", FastSeq(Row("c", -1.0), Row("z", 30.0))) + Row(3, "dog", FastSeq(Row("c", -1.0), Row("z", 30.0))), ) val cdata = Array( Row(1, "atag"), - Row(2, "btag") + Row(2, "btag"), ) val rowTab = makeLocalizedTable(rdata, cdata) @@ -244,15 +294,15 @@ class MatrixIRSuite extends HailSuite { assert(localCols sameElements cdata) } - @Test def testCastTableToMatrixErrors() { + @Test def testCastTableToMatrixErrors(): Unit = { val rdata = Array( Row(1, "fish", FastSeq(Row("x", 2.0))), Row(2, "cat", FastSeq(Row("b", 0.0), Row("y", 0.1))), - Row(3, "dog", FastSeq(Row("c", -1.0), 
Row("z", 30.0))) + Row(3, "dog", FastSeq(Row("c", -1.0), Row("z", 30.0))), ) val cdata = Array( Row(1, "atag"), - Row(2, "btag") + Row(2, "btag"), ) val rowTab = makeLocalizedTable(rdata, cdata) @@ -271,34 +321,40 @@ class MatrixIRSuite extends HailSuite { val rdata2 = Array( Row(1, "fish", null), Row(2, "cat", FastSeq(Row("b", 0.0), Row("y", 0.1))), - Row(3, "dog", FastSeq(Row("c", -1.0), Row("z", 30.0))) + Row(3, "dog", FastSeq(Row("c", -1.0), Row("z", 30.0))), ) val rowTab2 = makeLocalizedTable(rdata2, cdata) val mir2 = CastTableToMatrix(rowTab2, "__entries", "__cols", Array("col_idx")) - interceptSpark("missing") { Interpret(mir2, ctx, optimize = true).rvd.count() } + interceptSpark("missing")(Interpret(mir2, ctx, optimize = true).rvd.count()) } - @Test def testMatrixFiltersWorkWithRandomness() { + @Test def testMatrixFiltersWorkWithRandomness(): Unit = { val range = rangeMatrix(20, 20, Some(4), uids = true) def rand(rng: IR): IR = ApplySeeded("rand_bool", FastSeq(0.5), rng, 0, TBoolean) val colUID = GetField(Ref("sa", range.typ.colType), MatrixReader.colUIDFieldName) val colRNG = RNGSplit(RNGStateLiteral(), colUID) - val cols = Interpret(MatrixFilterCols(range, rand(colRNG)), ctx, optimize = true).toMatrixValue(range.typ.colKey).nCols + val cols = Interpret(MatrixFilterCols(range, rand(colRNG)), ctx, optimize = true).toMatrixValue( + range.typ.colKey + ).nCols val rowUID = GetField(Ref("va", range.typ.rowType), MatrixReader.rowUIDFieldName) val rowRNG = RNGSplit(RNGStateLiteral(), rowUID) val rows = Interpret(MatrixFilterRows(range, rand(rowRNG)), ctx, optimize = true).rvd.count() val entryRNG = RNGSplit(RNGStateLiteral(), MakeTuple.ordered(FastSeq(rowUID, colUID))) - val entries = Interpret(MatrixEntriesTable(MatrixFilterEntries(range, rand(entryRNG))), ctx, optimize = true).rvd.count() + val entries = Interpret( + MatrixEntriesTable(MatrixFilterEntries(range, rand(entryRNG))), + ctx, + optimize = true, + ).rvd.count() assert(cols < 20 && cols > 0) assert(rows < 20 && rows > 0) assert(entries < 400 && entries > 0) } - @Test def testMatrixRepartition() { + @Test def testMatrixRepartition(): Unit = { val range = rangeMatrix(11, 3, Some(10)) val params = Array( @@ -307,7 +363,7 @@ class MatrixIRSuite extends HailSuite { 5 -> RepartitionStrategy.SHUFFLE, 5 -> RepartitionStrategy.NAIVE_COALESCE, 10 -> RepartitionStrategy.SHUFFLE, - 10 -> RepartitionStrategy.COALESCE + 10 -> RepartitionStrategy.COALESCE, ) params.foreach { case (n, strat) => val rvd = Interpret(MatrixRepartition(range, n, strat), ctx, optimize = false).rvd @@ -317,13 +373,16 @@ class MatrixIRSuite extends HailSuite { } } - @Test def testMatrixMultiWriteDifferentTypesRaisesError() { + @Test def testMatrixMultiWriteDifferentTypesRaisesError(): Unit = { val vcf = is.hail.TestUtils.importVCF(ctx, "src/test/resources/sample.vcf") val range = rangeMatrix(10, 2, None) val path1 = ctx.createTmpPath("test1") val path2 = ctx.createTmpPath("test2") intercept[HailException] { - TypeCheck(ctx, MatrixMultiWrite(FastSeq(vcf, range), MatrixNativeMultiWriter(IndexedSeq(path1, path2)))) + TypeCheck( + ctx, + MatrixMultiWrite(FastSeq(vcf, range), MatrixNativeMultiWriter(IndexedSeq(path1, path2))), + ) } } } diff --git a/hail/src/test/scala/is/hail/expr/ir/MemoryLeakSuite.scala b/hail/src/test/scala/is/hail/expr/ir/MemoryLeakSuite.scala index efbddd02e37..07727bdc187 100644 --- a/hail/src/test/scala/is/hail/expr/ir/MemoryLeakSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/MemoryLeakSuite.scala @@ -6,6 +6,7 @@ import 
is.hail.backend.ExecuteContext import is.hail.expr.ir import is.hail.types.virtual.{TArray, TBoolean, TSet, TString} import is.hail.utils._ + import org.testng.annotations.Test class MemoryLeakSuite extends HailSuite { @@ -13,14 +14,21 @@ class MemoryLeakSuite extends HailSuite { val litSize = 32000 - def run(size: Int): Long = { + def run(size: Int): Long = { val lit = Literal(TSet(TString), (0 until litSize).map(_.toString).toSet) val queries = Literal(TArray(TString), (0 until size).map(_.toString).toFastSeq) ExecuteContext.scoped() { ctx => - val r = eval( + eval( ToArray( - mapIR(ToStream(queries)) { r => ir.invoke("contains", TBoolean, lit, r) } - ), Env.empty, FastSeq(), None, None, false, ctx) + mapIR(ToStream(queries))(r => ir.invoke("contains", TBoolean, lit, r)) + ), + Env.empty, + FastSeq(), + None, + None, + false, + ctx, + ) ctx.r.pool.getHighestTotalUsage } } diff --git a/hail/src/test/scala/is/hail/expr/ir/MissingArrayBuilderSuite.scala b/hail/src/test/scala/is/hail/expr/ir/MissingArrayBuilderSuite.scala index 0fbc6a5cfdd..e4137672c0a 100644 --- a/hail/src/test/scala/is/hail/expr/ir/MissingArrayBuilderSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/MissingArrayBuilderSuite.scala @@ -2,11 +2,12 @@ package is.hail.expr.ir import is.hail.asm4s.AsmFunction2 import is.hail.utils.FastSeq -import org.scalatest.testng.TestNGSuite -import org.testng.annotations.{DataProvider, Test} import scala.reflect.ClassTag +import org.scalatest.testng.TestNGSuite +import org.testng.annotations.{DataProvider, Test} + class MissingArrayBuilderSuite extends TestNGSuite { def ordering[T <: AnyVal](f: (T, T) => Boolean): AsmFunction2[T, T, Boolean] = new AsmFunction2[T, T, Boolean] { @@ -14,8 +15,11 @@ class MissingArrayBuilderSuite extends TestNGSuite { } def addToArrayBuilder[B <: MissingArrayBuilder, T]( - ab: B, array: IndexedSeq[T] - )(f: (B, T) => Unit): Unit = { + ab: B, + array: IndexedSeq[T], + )( + f: (B, T) => Unit + ): Unit = { array.foreach { i => if (i == null) ab.addMissing() @@ -24,7 +28,8 @@ class MissingArrayBuilderSuite extends TestNGSuite { } } - def getResult[B <: MissingArrayBuilder, T >: Null : ClassTag](ab: B)(f: (B, Int) => T): Array[T] = { + def getResult[B <: MissingArrayBuilder, T >: Null: ClassTag](ab: B)(f: (B, Int) => T) + : Array[T] = { Array.tabulate[T](ab.size) { i => if (ab.isMissing(i)) null @@ -37,16 +42,16 @@ class MissingArrayBuilderSuite extends TestNGSuite { def integerData(): Array[Array[Any]] = Array( Array(FastSeq(3, null, 3, 7, null), FastSeq(3, 3, 7, null, null)), Array(FastSeq(null, null, null, null), FastSeq(null, null, null, null)), - Array(FastSeq(), FastSeq()) + Array(FastSeq(), FastSeq()), ) - @Test(dataProvider="sortInt") - def testSortOnIntArrayBuilder(array: IndexedSeq[Integer], expected: IndexedSeq[Integer]) { + @Test(dataProvider = "sortInt") + def testSortOnIntArrayBuilder(array: IndexedSeq[Integer], expected: IndexedSeq[Integer]): Unit = { val ab = new IntMissingArrayBuilder(16) - addToArrayBuilder(ab, array) { (iab, i) => iab.add(i) } + addToArrayBuilder(ab, array)((iab, i) => iab.add(i)) - ab.sort(ordering[Int] { (i, j) => i < j}) - val result = getResult[IntMissingArrayBuilder, Integer](ab) { (iab, i) => Int.box(iab(i)) } + ab.sort(ordering[Int]((i, j) => i < j)) + val result = getResult[IntMissingArrayBuilder, Integer](ab)((iab, i) => Int.box(iab(i))) assert(result sameElements expected) } @@ -54,16 +59,21 @@ class MissingArrayBuilderSuite extends TestNGSuite { def longData(): Array[Array[Any]] = Array( Array(FastSeq(3L, null, 3L, 
7L, null), FastSeq(3L, 3L, 7L, null, null)), Array(FastSeq(null, null, null, null), FastSeq(null, null, null, null)), - Array(FastSeq(), FastSeq()) + Array(FastSeq(), FastSeq()), ) - @Test(dataProvider="sortLong") - def testSortOnLongArrayBuilder(array: IndexedSeq[java.lang.Long], expected: IndexedSeq[java.lang.Long]) { + @Test(dataProvider = "sortLong") + def testSortOnLongArrayBuilder( + array: IndexedSeq[java.lang.Long], + expected: IndexedSeq[java.lang.Long], + ): Unit = { val ab = new LongMissingArrayBuilder(16) - addToArrayBuilder(ab, array) { (jab, j) => jab.add(j) } + addToArrayBuilder(ab, array)((jab, j) => jab.add(j)) - ab.sort(ordering[Long] { (i, j) => i < j}) - val result = getResult[LongMissingArrayBuilder, java.lang.Long](ab) { (jab, j) => Long.box(jab(j)) } + ab.sort(ordering[Long]((i, j) => i < j)) + val result = getResult[LongMissingArrayBuilder, java.lang.Long](ab) { (jab, j) => + Long.box(jab(j)) + } assert(result sameElements expected) } @@ -71,16 +81,21 @@ class MissingArrayBuilderSuite extends TestNGSuite { def floatData(): Array[Array[Any]] = Array( Array(FastSeq(3f, null, 3f, 7f, null), FastSeq(3f, 3f, 7f, null, null)), Array(FastSeq(null, null, null, null), FastSeq(null, null, null, null)), - Array(FastSeq(), FastSeq()) + Array(FastSeq(), FastSeq()), ) - @Test(dataProvider="sortFloat") - def testSortOnFloatArrayBuilder(array: IndexedSeq[java.lang.Float], expected: IndexedSeq[java.lang.Float]) { + @Test(dataProvider = "sortFloat") + def testSortOnFloatArrayBuilder( + array: IndexedSeq[java.lang.Float], + expected: IndexedSeq[java.lang.Float], + ): Unit = { val ab = new FloatMissingArrayBuilder(16) - addToArrayBuilder(ab, array) { (fab, f) => fab.add(f) } + addToArrayBuilder(ab, array)((fab, f) => fab.add(f)) - ab.sort(ordering[Float] { (i, j) => i < j }) - val result = getResult[FloatMissingArrayBuilder, java.lang.Float](ab) { (fab, f) => Float.box(fab(f)) } + ab.sort(ordering[Float]((i, j) => i < j)) + val result = getResult[FloatMissingArrayBuilder, java.lang.Float](ab) { (fab, f) => + Float.box(fab(f)) + } assert(result sameElements expected) } @@ -88,16 +103,21 @@ class MissingArrayBuilderSuite extends TestNGSuite { def doubleData(): Array[Array[Any]] = Array( Array(FastSeq(3d, null, 3d, 7d, null), FastSeq(3d, 3d, 7d, null, null)), Array(FastSeq(null, null, null, null), FastSeq(null, null, null, null)), - Array(FastSeq(), FastSeq()) + Array(FastSeq(), FastSeq()), ) - @Test(dataProvider="sortDouble") - def testSortOnDoubleArrayBuilder(array: IndexedSeq[java.lang.Double], expected: IndexedSeq[java.lang.Double]) { + @Test(dataProvider = "sortDouble") + def testSortOnDoubleArrayBuilder( + array: IndexedSeq[java.lang.Double], + expected: IndexedSeq[java.lang.Double], + ): Unit = { val ab = new DoubleMissingArrayBuilder(16) - addToArrayBuilder(ab, array) { (dab, d) => dab.add(d) } + addToArrayBuilder(ab, array)((dab, d) => dab.add(d)) - ab.sort(ordering[Double] { (i, j) => i < j }) - val result = getResult[DoubleMissingArrayBuilder, java.lang.Double](ab) { (dab, d) => Double.box(dab(d)) } + ab.sort(ordering[Double]((i, j) => i < j)) + val result = getResult[DoubleMissingArrayBuilder, java.lang.Double](ab) { (dab, d) => + Double.box(dab(d)) + } assert(result sameElements expected) } @@ -105,16 +125,21 @@ class MissingArrayBuilderSuite extends TestNGSuite { def booleanData(): Array[Array[Any]] = Array( Array(FastSeq(true, null, true, false, null), FastSeq(false, true, true, null, null)), Array(FastSeq(null, null, null, null), FastSeq(null, null, null, null)), - 
Array(FastSeq(), FastSeq()) + Array(FastSeq(), FastSeq()), ) - @Test(dataProvider="sortBoolean") - def testSortOnBooleanArrayBuilder(array: IndexedSeq[java.lang.Boolean], expected: IndexedSeq[java.lang.Boolean]) { + @Test(dataProvider = "sortBoolean") + def testSortOnBooleanArrayBuilder( + array: IndexedSeq[java.lang.Boolean], + expected: IndexedSeq[java.lang.Boolean], + ): Unit = { val ab = new BooleanMissingArrayBuilder(16) - addToArrayBuilder(ab, array) { (bab, b) => bab.add(b) } + addToArrayBuilder(ab, array)((bab, b) => bab.add(b)) - ab.sort(ordering[Boolean] { (i, j) => i < j }) - val result = getResult[BooleanMissingArrayBuilder, java.lang.Boolean](ab) { (bab, b) => Boolean.box(bab(b)) } + ab.sort(ordering[Boolean]((i, j) => i < j)) + val result = getResult[BooleanMissingArrayBuilder, java.lang.Boolean](ab) { (bab, b) => + Boolean.box(bab(b)) + } assert(result sameElements expected) } } diff --git a/hail/src/test/scala/is/hail/expr/ir/OrderingSuite.scala b/hail/src/test/scala/is/hail/expr/ir/OrderingSuite.scala index 8b50d9029a1..57d38090233 100644 --- a/hail/src/test/scala/is/hail/expr/ir/OrderingSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/OrderingSuite.scala @@ -11,6 +11,7 @@ import is.hail.types.physical.stypes.EmitType import is.hail.types.physical.stypes.interfaces.SBaseStructValue import is.hail.types.virtual._ import is.hail.utils._ + import org.apache.spark.sql.Row import org.testng.annotations.{DataProvider, Test} @@ -25,7 +26,7 @@ class OrderingSuite extends HailSuite { case ti: TInterval => recursiveSize(ti.pointType) case tc: TIterable => recursiveSize(tc.elementType) case tbs: TBaseStruct => - tbs.types.map { t => recursiveSize(t) }.sum + tbs.types.map(t => recursiveSize(t)).sum case _ => 0 } inner + 1 @@ -35,7 +36,7 @@ class OrderingSuite extends HailSuite { t: PType, op: CodeOrdering.Op, r: Region, - sortOrder: SortOrder = Ascending + sortOrder: SortOrder = Ascending, ): AsmFunction3[Region, Long, Long, op.ReturnType] = { implicit val x = op.rtti val fb = EmitFunctionBuilder[Region, Long, Long, op.ReturnType](ctx, "lifted") @@ -43,20 +44,21 @@ class OrderingSuite extends HailSuite { val cv1 = t.loadCheapSCode(cb, fb.getCodeParam[Long](2)) val cv2 = t.loadCheapSCode(cb, fb.getCodeParam[Long](3)) fb.ecb.getOrderingFunction(cv1.st, cv2.st, op) - .apply(cb, EmitValue.present(cv1), EmitValue.present(cv2)) + .apply(cb, EmitValue.present(cv1), EmitValue.present(cv2)) } fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r) } - @Test def testMissingNonequalComparisons() { + @Test def testMissingNonequalComparisons(): Unit = { def getStagedOrderingFunctionWithMissingness( t: PType, op: CodeOrdering.Op, r: Region, - sortOrder: SortOrder = Ascending + sortOrder: SortOrder = Ascending, ): AsmFunction5[Region, Boolean, Long, Boolean, Long, op.ReturnType] = { implicit val x = op.rtti - val fb = EmitFunctionBuilder[Region, Boolean, Long, Boolean, Long, op.ReturnType](ctx, "lifted") + val fb = + EmitFunctionBuilder[Region, Boolean, Long, Boolean, Long, op.ReturnType](ctx, "lifted") fb.emitWithBuilder { cb => val m1 = fb.getCodeParam[Boolean](2) val cv1 = t.loadCheapSCode(cb, fb.getCodeParam[Long](3)) @@ -74,151 +76,157 @@ class OrderingSuite extends HailSuite { t <- Type.genStruct a <- t.genNonmissingValue(sm) } yield (t, a) - val p = Prop.forAll(compareGen) { case (t, a) => pool.scopedRegion { region => - val pType = PType.canonical(t).asInstanceOf[PStruct] - val rvb = new RegionValueBuilder(sm, region) + val p = Prop.forAll(compareGen) { case (t, a) => + 
pool.scopedRegion { region => + val pType = PType.canonical(t).asInstanceOf[PStruct] - val v = pType.unstagedStoreJavaObject(sm, a, region) + val v = pType.unstagedStoreJavaObject(sm, a, region) - val eordME = t.mkOrdering(sm) - val eordMNE = t.mkOrdering(sm, missingEqual = false) + val eordME = t.mkOrdering(sm) + val eordMNE = t.mkOrdering(sm, missingEqual = false) - def checkCompare(compResult: Int, expected: Int) { - assert(java.lang.Integer.signum(compResult) == expected, - s"compare expected: $expected vs $compResult\n t=${t.parsableString()}\n v=$a") - } + def checkCompare(compResult: Int, expected: Int): Unit = + assert( + java.lang.Integer.signum(compResult) == expected, + s"compare expected: $expected vs $compResult\n t=${t.parsableString()}\n v=$a", + ) - val fcompareME = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Compare(), region) + val fcompareME = + getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Compare(), region) - checkCompare(fcompareME(region, true, v, true, v), 0) - checkCompare(fcompareME(region, true, v, false, v), 1) - checkCompare(fcompareME(region, false, v, true, v), -1) + checkCompare(fcompareME(region, true, v, true, v), 0) + checkCompare(fcompareME(region, true, v, false, v), 1) + checkCompare(fcompareME(region, false, v, true, v), -1) - checkCompare(eordME.compare(null, null), 0) - checkCompare(eordME.compare(null, a), 1) - checkCompare(eordME.compare(a, null), -1) + checkCompare(eordME.compare(null, null), 0) + checkCompare(eordME.compare(null, a), 1) + checkCompare(eordME.compare(a, null), -1) - val fcompareMNE = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Compare(false), region) + val fcompareMNE = + getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Compare(false), region) - checkCompare(fcompareMNE(region, true, v, true, v), -1) - checkCompare(fcompareMNE(region, true, v, false, v), 1) - checkCompare(fcompareMNE(region, false, v, true, v), -1) + checkCompare(fcompareMNE(region, true, v, true, v), -1) + checkCompare(fcompareMNE(region, true, v, false, v), 1) + checkCompare(fcompareMNE(region, false, v, true, v), -1) - checkCompare(eordMNE.compare(null, null), -1) - checkCompare(eordMNE.compare(null, a), 1) - checkCompare(eordMNE.compare(a, null), -1) + checkCompare(eordMNE.compare(null, null), -1) + checkCompare(eordMNE.compare(null, a), 1) + checkCompare(eordMNE.compare(a, null), -1) - def check(result: Boolean, expected: Boolean) { - assert(result == expected, s"t=${t.parsableString()}\n v=$a") - } + def check(result: Boolean, expected: Boolean): Unit = + assert(result == expected, s"t=${t.parsableString()}\n v=$a") - val fequivME = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Equiv(), region) + val fequivME = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Equiv(), region) - check(fequivME(region, true, v, true, v), true) - check(fequivME(region, true, v, false, v), false) - check(fequivME(region, false, v, true, v), false) + check(fequivME(region, true, v, true, v), true) + check(fequivME(region, true, v, false, v), false) + check(fequivME(region, false, v, true, v), false) - check(eordME.equiv(null, null), true) - check(eordME.equiv(null, a), false) - check(eordME.equiv(a, null), false) + check(eordME.equiv(null, null), true) + check(eordME.equiv(null, a), false) + check(eordME.equiv(a, null), false) - val fequivMNE = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Equiv(false), region) + val fequivMNE = + getStagedOrderingFunctionWithMissingness(pType, 
CodeOrdering.Equiv(false), region) - check(fequivMNE(region, true, v, true, v), false) - check(fequivMNE(region, true, v, false, v), false) - check(fequivMNE(region, false, v, true, v), false) + check(fequivMNE(region, true, v, true, v), false) + check(fequivMNE(region, true, v, false, v), false) + check(fequivMNE(region, false, v, true, v), false) - check(eordMNE.equiv(null, null), false) - check(eordMNE.equiv(null, a), false) - check(eordMNE.equiv(a, null), false) + check(eordMNE.equiv(null, null), false) + check(eordMNE.equiv(null, a), false) + check(eordMNE.equiv(a, null), false) - val fltME = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Lt(), region) + val fltME = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Lt(), region) - check(fltME(region, true, v, true, v), false) - check(fltME(region, true, v, false, v), false) - check(fltME(region, false, v, true, v), true) + check(fltME(region, true, v, true, v), false) + check(fltME(region, true, v, false, v), false) + check(fltME(region, false, v, true, v), true) - check(eordME.lt(null, null), false) - check(eordME.lt(null, a), false) - check(eordME.lt(a, null), true) + check(eordME.lt(null, null), false) + check(eordME.lt(null, a), false) + check(eordME.lt(a, null), true) - val fltMNE = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Lt(false), region) + val fltMNE = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Lt(false), region) - check(fltMNE(region, true, v, true, v), true) - check(fltMNE(region, true, v, false, v), false) - check(fltMNE(region, false, v, true, v), true) + check(fltMNE(region, true, v, true, v), true) + check(fltMNE(region, true, v, false, v), false) + check(fltMNE(region, false, v, true, v), true) - check(eordMNE.lt(null, null), true) - check(eordMNE.lt(null, a), false) - check(eordMNE.lt(a, null), true) + check(eordMNE.lt(null, null), true) + check(eordMNE.lt(null, a), false) + check(eordMNE.lt(a, null), true) - val flteqME = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Lteq(), region) + val flteqME = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Lteq(), region) - check(flteqME(region, true, v, true, v), true) - check(flteqME(region, true, v, false, v), false) - check(flteqME(region, false, v, true, v), true) + check(flteqME(region, true, v, true, v), true) + check(flteqME(region, true, v, false, v), false) + check(flteqME(region, false, v, true, v), true) - check(eordME.lteq(null, null), true) - check(eordME.lteq(null, a), false) - check(eordME.lteq(a, null), true) + check(eordME.lteq(null, null), true) + check(eordME.lteq(null, a), false) + check(eordME.lteq(a, null), true) - val flteqMNE = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Lteq(false), region) + val flteqMNE = + getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Lteq(false), region) - check(flteqMNE(region, true, v, true, v), true) - check(flteqMNE(region, true, v, false, v), false) - check(flteqMNE(region, false, v, true, v), true) + check(flteqMNE(region, true, v, true, v), true) + check(flteqMNE(region, true, v, false, v), false) + check(flteqMNE(region, false, v, true, v), true) - check(eordMNE.lteq(null, null), true) - check(eordMNE.lteq(null, a), false) - check(eordMNE.lteq(a, null), true) + check(eordMNE.lteq(null, null), true) + check(eordMNE.lteq(null, a), false) + check(eordMNE.lteq(a, null), true) - val fgtME = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Gt(), region) + val fgtME = 
getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Gt(), region) - check(fgtME(region, true, v, true, v), false) - check(fgtME(region, true, v, false, v), true) - check(fgtME(region, false, v, true, v), false) + check(fgtME(region, true, v, true, v), false) + check(fgtME(region, true, v, false, v), true) + check(fgtME(region, false, v, true, v), false) - check(eordME.gt(null, null), false) - check(eordME.gt(null, a), true) - check(eordME.gt(a, null), false) + check(eordME.gt(null, null), false) + check(eordME.gt(null, a), true) + check(eordME.gt(a, null), false) - val fgtMNE = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Gt(false), region) + val fgtMNE = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Gt(false), region) - check(fgtMNE(region, true, v, true, v), false) - check(fgtMNE(region, true, v, false, v), true) - check(fgtMNE(region, false, v, true, v), false) + check(fgtMNE(region, true, v, true, v), false) + check(fgtMNE(region, true, v, false, v), true) + check(fgtMNE(region, false, v, true, v), false) - check(eordMNE.gt(null, null), false) - check(eordMNE.gt(null, a), true) - check(eordMNE.gt(a, null), false) + check(eordMNE.gt(null, null), false) + check(eordMNE.gt(null, a), true) + check(eordMNE.gt(a, null), false) - val fgteqME = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Gteq(), region) + val fgteqME = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Gteq(), region) - check(fgteqME(region, true, v, true, v), true) - check(fgteqME(region, true, v, false, v), true) - check(fgteqME(region, false, v, true, v), false) + check(fgteqME(region, true, v, true, v), true) + check(fgteqME(region, true, v, false, v), true) + check(fgteqME(region, false, v, true, v), false) - check(eordME.gteq(null, null), true) - check(eordME.gteq(null, a), true) - check(eordME.gteq(a, null), false) + check(eordME.gteq(null, null), true) + check(eordME.gteq(null, a), true) + check(eordME.gteq(a, null), false) - val fgteqMNE = getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Gt(false), region) + val fgteqMNE = + getStagedOrderingFunctionWithMissingness(pType, CodeOrdering.Gt(false), region) - check(fgteqMNE(region, true, v, true, v), false) - check(fgteqMNE(region, true, v, false, v), true) - check(fgteqMNE(region, false, v, true, v), false) + check(fgteqMNE(region, true, v, true, v), false) + check(fgteqMNE(region, true, v, false, v), true) + check(fgteqMNE(region, false, v, true, v), false) - check(eordMNE.gteq(null, null), false) - check(eordMNE.gteq(null, a), true) - check(eordMNE.gteq(a, null), false) + check(eordMNE.gteq(null, null), false) + check(eordMNE.gteq(null, a), true) + check(eordMNE.gteq(a, null), false) - true - }} + true + } + } p.check() } - @Test def testRandomOpsAgainstExtended() { + @Test def testRandomOpsAgainstExtended(): Unit = { val compareGen = for { t <- Type.genArb a1 <- t.genNonmissingValue(sm) @@ -227,7 +235,6 @@ class OrderingSuite extends HailSuite { val p = Prop.forAll(compareGen) { case (t, a1, a2) => pool.scopedRegion { region => val pType = PType.canonical(t) - val rvb = new RegionValueBuilder(sm, region) val v1 = pType.unstagedStoreJavaObject(sm, a1, region) @@ -237,7 +244,10 @@ class OrderingSuite extends HailSuite { val fcompare = getStagedOrderingFunction(pType, CodeOrdering.Compare(), region) val result = java.lang.Integer.signum(fcompare(region, v1, v2)) - assert(result == compare, s"compare expected: $compare vs $result\n t=${t.parsableString()}\n v1=${a1}\n v2=$a2") + assert( + 
result == compare, + s"compare expected: $compare vs $result\n t=${t.parsableString()}\n v1=$a1\n v2=$a2", + ) val equiv = t.ordering(sm).equiv(a1, a2) val fequiv = getStagedOrderingFunction(pType, CodeOrdering.Equiv(), region) @@ -270,7 +280,7 @@ class OrderingSuite extends HailSuite { p.check() } - @Test def testReverseIsSwappedArgumentsOfExtendedOrdering() { + @Test def testReverseIsSwappedArgumentsOfExtendedOrdering(): Unit = { val compareGen = for { t <- Type.genArb a1 <- t.genNonmissingValue(sm) @@ -279,7 +289,6 @@ class OrderingSuite extends HailSuite { val p = Prop.forAll(compareGen) { case (t, a1, a2) => pool.scopedRegion { region => val pType = PType.canonical(t) - val rvb = new RegionValueBuilder(sm, region) val v1 = pType.unstagedStoreJavaObject(sm, a1, region) @@ -293,7 +302,6 @@ class OrderingSuite extends HailSuite { assert(result == compare, s"compare expected: $compare vs $result") - val equiv = reversedExtendedOrdering.equiv(a2, a1) val fequiv = getStagedOrderingFunction(pType, CodeOrdering.Equiv(), region, Descending) @@ -325,7 +333,7 @@ class OrderingSuite extends HailSuite { p.check() } - @Test def testSortOnRandomArray() { + @Test def testSortOnRandomArray(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val compareGen = for { elt <- Type.genArb @@ -334,15 +342,17 @@ class OrderingSuite extends HailSuite { } yield (elt, a, asc) val p = Prop.forAll(compareGen) { case (t, a: IndexedSeq[Any], asc: Boolean) => val ord = if (asc) t.ordering(sm).toOrdering else t.ordering(sm).reverse.toOrdering - assertEvalsTo(ArraySort(ToStream(In(0, TArray(t))), Literal.coerce(TBoolean, asc)), + assertEvalsTo( + ArraySort(ToStream(In(0, TArray(t))), Literal.coerce(TBoolean, asc)), FastSeq(a -> TArray(t)), - expected = a.sorted(ord)) + expected = a.sorted(ord), + ) true } p.check() } - def testToSetOnRandomDuplicatedArray() { + def testToSetOnRandomDuplicatedArray(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val compareGen = for { elt <- Type.genArb @@ -350,15 +360,17 @@ class OrderingSuite extends HailSuite { } yield (elt, a) val p = Prop.forAll(compareGen) { case (t, a: IndexedSeq[Any]) => val array = a ++ a - assertEvalsTo(ToArray(ToSet(In(0, TArray(t)))), + assertEvalsTo( + ToArray(ToSet(In(0, TArray(t)))), FastSeq(array -> TArray(t)), - expected = array.sorted(t.ordering(sm).toOrdering).distinct) + expected = array.sorted(t.ordering(sm).toOrdering).distinct, + ) true } p.check() } - def testToDictOnRandomDuplicatedArray() { + def testToDictOnRandomDuplicatedArray(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val compareGen = for { kt <- Type.genArb @@ -366,80 +378,97 @@ class OrderingSuite extends HailSuite { telt = TTuple(kt, vt) a <- TArray(telt).genNonmissingValue(sm) } yield (telt, a) - val p = Prop.forAll(compareGen) { case (telt: TTuple, a: IndexedSeq[Row]@unchecked) => + val p = Prop.forAll(compareGen) { case (telt: TTuple, a: IndexedSeq[Row] @unchecked) => val tdict = TDict(telt.types(0), telt.types(1)) val array: IndexedSeq[Row] = a ++ a val expectedMap = array.filter(_ != null).map { case Row(k, v) => (k, v) }.toMap assertEvalsTo( - ToArray(StreamMap(ToStream(In(0, TArray(telt))), - "x", GetField(Ref("x", tdict.elementType), "key"))), + ToArray(StreamMap( + ToStream(In(0, TArray(telt))), + "x", + GetField(Ref("x", tdict.elementType), "key"), + )), FastSeq(array -> TArray(telt)), - expected = expectedMap.keys.toFastSeq.sorted(telt.types(0).ordering(sm).toOrdering)) + expected = 
expectedMap.keys.toFastSeq.sorted(telt.types(0).ordering(sm).toOrdering), + ) true } p.check() } - @Test def testSortOnMissingArray() { + @Test def testSortOnMissingArray(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val ts = TStream(TStruct("key" -> TInt32, "value" -> TInt32)) val irs: Array[IR => IR] = Array(ArraySort(_, True()), ToSet(_), ToDict(_)) - for (irF <- irs) { assertEvalsTo(IsNA(irF(NA(ts))), true) } + for (irF <- irs) assertEvalsTo(IsNA(irF(NA(ts))), true) } - @Test def testSetContainsOnRandomSet() { + @Test def testSetContainsOnRandomSet(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val compareGen = Type.genArb .flatMap(t => Gen.zip(Gen.const(TSet(t)), TSet(t).genNonmissingValue(sm), t.genValue(sm))) - val p = Prop.forAll(compareGen) { case (tset: TSet, set: Set[Any]@unchecked, test1) => + val p = Prop.forAll(compareGen) { case (tset: TSet, set: Set[Any] @unchecked, test1) => val telt = tset.elementType if (set.nonEmpty) { assertEvalsTo( invoke("contains", TBoolean, In(0, tset), In(1, telt)), FastSeq(set -> tset, set.head -> telt), - expected = true) + expected = true, + ) } assertEvalsTo( invoke("contains", TBoolean, In(0, tset), In(1, telt)), FastSeq(set -> tset, test1 -> telt), - expected = set.contains(test1)) + expected = set.contains(test1), + ) true } p.check() } - def testDictGetOnRandomDict() { + def testDictGetOnRandomDict(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val compareGen = Gen.zip(Type.genArb, Type.genArb).flatMap { case (k, v) => - Gen.zip(Gen.const(TDict(k, v)), TDict(k, v).genNonmissingValue(sm), k.genNonmissingValue(sm)) + Gen.zip( + Gen.const(TDict(k, v)), + TDict(k, v).genNonmissingValue(sm), + k.genNonmissingValue(sm), + ) } - val p = Prop.forAll(compareGen) { case (tdict: TDict, dict: Map[Any, Any]@unchecked, testKey1) => - assertEvalsTo(invoke("get", tdict.valueType, In(0, tdict), In(1, tdict.keyType)), - FastSeq(dict -> tdict, - testKey1 -> tdict.keyType), - dict.getOrElse(testKey1, null)) - - if (dict.nonEmpty) { - val testKey2 = dict.keys.toSeq.head - val expected2 = dict(testKey2) - assertEvalsTo(invoke("get", tdict.valueType, In(0, tdict), In(1, tdict.keyType)), - FastSeq(dict -> tdict, - testKey2 -> tdict.keyType), - expected2) + val p = + Prop.forAll(compareGen) { case (tdict: TDict, dict: Map[Any, Any] @unchecked, testKey1) => + assertEvalsTo( + invoke("get", tdict.valueType, In(0, tdict), In(1, tdict.keyType)), + FastSeq(dict -> tdict, testKey1 -> tdict.keyType), + dict.getOrElse(testKey1, null), + ) + + if (dict.nonEmpty) { + val testKey2 = dict.keys.toSeq.head + val expected2 = dict(testKey2) + assertEvalsTo( + invoke("get", tdict.valueType, In(0, tdict), In(1, tdict.keyType)), + FastSeq(dict -> tdict, testKey2 -> tdict.keyType), + expected2, + ) + } + true } - true - } p.check() } - def testBinarySearchOnSet() { - val compareGen = Type.genArb.flatMap(t => Gen.zip(Gen.const(t), TSet(t).genNonmissingValue(sm), t.genNonmissingValue(sm))) - val p = Prop.forAll(compareGen.filter { case (t, a, elem) => a.asInstanceOf[Set[Any]].nonEmpty }) { case (t, a, elem) => + def testBinarySearchOnSet(): Unit = { + val compareGen = Type.genArb.flatMap(t => + Gen.zip(Gen.const(t), TSet(t).genNonmissingValue(sm), t.genNonmissingValue(sm)) + ) + val p = Prop.forAll(compareGen.filter { case (_, a, _) => + a.asInstanceOf[Set[Any]].nonEmpty + }) { case (t, a, elem) => val set = a.asInstanceOf[Set[Any]] val pt = PType.canonical(t) val pset = PCanonicalSet(pt) @@ -448,23 +477,31 @@ class OrderingSuite extends 
HailSuite { val pArray = PCanonicalArray(pt) pool.scopedRegion { region => - val rvb = new RegionValueBuilder(sm, region) - val soff = pset.unstagedStoreJavaObject(sm, set, region) val eoff = pTuple.unstagedStoreJavaObject(sm, Row(elem), region) val fb = EmitFunctionBuilder[Region, Long, Long, Int](ctx, "binary_search") - val cregion = fb.getCodeParam[Region](1).load() val cset = fb.getCodeParam[Long](2) val cetuple = fb.getCodeParam[Long](3) - val bs = new BinarySearch(fb.apply_method, pset.sType, EmitType(pset.elementType.sType, true), { - (cb, elt) => elt - }) + val bs = new BinarySearch( + fb.apply_method, + pset.sType, + EmitType(pset.elementType.sType, true), + { + (cb, elt) => elt + }, + ) fb.emitWithBuilder(cb => - bs.search(cb, pset.loadCheapSCode(cb, cset), - EmitCode.fromI(fb.apply_method)(cb => IEmitCode.present(cb, pt.loadCheapSCode(cb, pTuple.loadField(cetuple, 0)))))) + bs.search( + cb, + pset.loadCheapSCode(cb, cset), + EmitCode.fromI(fb.apply_method)(cb => + IEmitCode.present(cb, pt.loadCheapSCode(cb, pTuple.loadField(cetuple, 0))) + ), + ) + ) val asArray = SafeIndexedSeq(pArray, soff) @@ -473,16 +510,20 @@ class OrderingSuite extends HailSuite { val maybeEqual = asArray(closestI) set.contains(elem) ==> (elem == maybeEqual) && - (t.ordering(sm).compare(elem, maybeEqual) <= 0 || (closestI == set.size - 1)) + (t.ordering(sm).compare(elem, maybeEqual) <= 0 || (closestI == set.size - 1)) } } p.check() } - @Test def testBinarySearchOnDict() { + @Test def testBinarySearchOnDict(): Unit = { val compareGen = Gen.zip(Type.genArb, Type.genArb) - .flatMap { case (k, v) => Gen.zip(Gen.const(TDict(k, v)), TDict(k, v).genNonmissingValue(sm), k.genValue(sm)) } - val p = Prop.forAll(compareGen.filter { case (tdict, a, key) => a.asInstanceOf[Map[Any, Any]].nonEmpty }) { case (tDict, a, key) => + .flatMap { case (k, v) => + Gen.zip(Gen.const(TDict(k, v)), TDict(k, v).genNonmissingValue(sm), k.genValue(sm)) + } + val p = Prop.forAll(compareGen.filter { case (_, a, _) => + a.asInstanceOf[Map[Any, Any]].nonEmpty + }) { case (tDict, a, key) => val dict = a.asInstanceOf[Map[Any, Any]] val pDict = PType.canonical(tDict).asInstanceOf[PDict] @@ -496,16 +537,27 @@ class OrderingSuite extends HailSuite { val cdict = fb.getCodeParam[Long](2) val cktuple = fb.getCodeParam[Long](3) - val bs = new BinarySearch(fb.apply_method, pDict.sType, EmitType(pDict.keyType.sType, false), { (cb, elt) => - cb.memoize(elt.toI(cb).flatMap(cb) { - case x: SBaseStructValue => - x.loadField(cb, 0) - }) - }) + val bs = new BinarySearch( + fb.apply_method, + pDict.sType, + EmitType(pDict.keyType.sType, false), + { (cb, elt) => + cb.memoize(elt.toI(cb).flatMap(cb) { + case x: SBaseStructValue => + x.loadField(cb, 0) + }) + }, + ) fb.emitWithBuilder(cb => - bs.search(cb, pDict.loadCheapSCode(cb, cdict), - EmitCode.fromI(fb.apply_method)(cb => IEmitCode.present(cb, pDict.keyType.loadCheapSCode(cb, ptuple.loadField(cktuple, 0)))))) + bs.search( + cb, + pDict.loadCheapSCode(cb, cdict), + EmitCode.fromI(fb.apply_method)(cb => + IEmitCode.present(cb, pDict.keyType.loadCheapSCode(cb, ptuple.loadField(cktuple, 0))) + ), + ) + ) val asArray = SafeIndexedSeq(PCanonicalArray(pDict.elementType), soff) @@ -521,44 +573,71 @@ class OrderingSuite extends HailSuite { def getKey(i: Int) = asArray(i).asInstanceOf[Row].get(0) val maybeEqual = getKey(closestI) val closestIIsClosest = - (pDict.keyType.virtualType.ordering(sm).compare(key, maybeEqual) <= 0 || closestI == dict.size - 1) && - (closestI == 0 || 
pDict.keyType.virtualType.ordering(sm).compare(key, getKey(closestI - 1)) > 0) - - // FIXME: -0.0 and 0.0 count as the same in scala Map, but not off-heap Hail data structures + (pDict.keyType.virtualType.ordering(sm).compare( + key, + maybeEqual, + ) <= 0 || closestI == dict.size - 1) && + (closestI == 0 || pDict.keyType.virtualType.ordering(sm).compare( + key, + getKey(closestI - 1), + ) > 0) + + /* FIXME: -0.0 and 0.0 count as the same in scala Map, but not off-heap Hail data + * structures */ val kord = tDict.keyType.ordering(sm) - (dict.contains(key) && dict.keysIterator.exists(kord.compare(_, key) == 0)) ==> (key == maybeEqual) && closestIIsClosest + (dict.contains(key) && dict.keysIterator.exists( + kord.compare(_, key) == 0 + )) ==> (key == maybeEqual) && closestIIsClosest } } } p.check() } - @Test def testContainsWithArrayFold() { + @Test def testContainsWithArrayFold(): Unit = { implicit val execStrats = ExecStrategy.javaOnly val set1 = ToSet(MakeStream(IndexedSeq(I32(1), I32(4)), TStream(TInt32))) val set2 = ToSet(MakeStream(IndexedSeq(I32(9), I32(1), I32(4)), TStream(TInt32))) - assertEvalsTo(StreamFold(ToStream(set1), True(), "accumulator", "setelt", - ApplySpecial("land", + assertEvalsTo( + StreamFold( + ToStream(set1), + True(), + "accumulator", + "setelt", + ApplySpecial( + "land", FastSeq(), FastSeq( Ref("accumulator", TBoolean), - invoke("contains", TBoolean, set2, Ref("setelt", TInt32))), TBoolean, ErrorIDs.NO_ERROR)), true) + invoke("contains", TBoolean, set2, Ref("setelt", TInt32)), + ), + TBoolean, + ErrorIDs.NO_ERROR, + ), + ), + true, + ) } @DataProvider(name = "arrayDoubleOrderingData") def arrayDoubleOrderingData(): Array[Array[Any]] = { - val xs = Array[Any](null, Double.NegativeInfinity, -0.0, 0.0, 1.0, Double.PositiveInfinity, Double.NaN) + val xs = + Array[Any](null, Double.NegativeInfinity, -0.0, 0.0, 1.0, Double.PositiveInfinity, Double.NaN) val as = Array(null: IndexedSeq[Any]) ++ (for (x <- xs) yield IndexedSeq[Any](x)) - for (a <- as; a2 <- as) - yield Array[Any](a, a2) + for { + a <- as + a2 <- as + } yield Array[Any](a, a2) } @Test(dataProvider = "arrayDoubleOrderingData") def testOrderingArrayDouble( - a: IndexedSeq[Any], a2: IndexedSeq[Any]) { + a: IndexedSeq[Any], + a2: IndexedSeq[Any], + ): Unit = { val t = TArray(TFloat64) val args = FastSeq(a -> t, a2 -> t) @@ -576,7 +655,9 @@ class OrderingSuite extends HailSuite { @Test(dataProvider = "arrayDoubleOrderingData") def testOrderingSetDouble( - a: IndexedSeq[Any], a2: IndexedSeq[Any]) { + a: IndexedSeq[Any], + a2: IndexedSeq[Any], + ): Unit = { val t = TSet(TFloat64) val s = if (a != null) a.toSet else null @@ -596,21 +677,26 @@ class OrderingSuite extends HailSuite { @DataProvider(name = "rowDoubleOrderingData") def rowDoubleOrderingData(): Array[Array[Any]] = { - val xs = Array[Any](null, Double.NegativeInfinity, -0.0, 0.0, 1.0, Double.PositiveInfinity, Double.NaN) - val as = Array(null: IndexedSeq[Any]) ++ - (for (x <- xs) yield FastSeq[Any](x)) + val xs = + Array[Any](null, Double.NegativeInfinity, -0.0, 0.0, 1.0, Double.PositiveInfinity, Double.NaN) val ss = Array[Any](null, "a", "aa") - val rs = for (x <- xs; s <- ss) - yield Row(x, s) + val rs = for { + x <- xs + s <- ss + } yield Row(x, s) - for (r <- rs; r2 <- rs) - yield Array[Any](r, r2) + for { + r <- rs + r2 <- rs + } yield Array[Any](r, r2) } @Test(dataProvider = "rowDoubleOrderingData") def testOrderingRowDouble( - r: Row, r2: Row) { + r: Row, + r2: Row, + ): Unit = { val t = TStruct("x" -> TFloat64, "s" -> TString) val args = 
FastSeq(r -> t, r2 -> t) diff --git a/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala b/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala index 97186fdddda..b2d1fdb6bb0 100644 --- a/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/PruneSuite.scala @@ -7,32 +7,47 @@ import is.hail.rvd.RVD import is.hail.types._ import is.hail.types.virtual._ import is.hail.utils._ -import org.apache.spark.sql.Row -import org.json4s.JValue -import org.testng.annotations.{DataProvider, Test} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.sql.Row +import org.json4s.JValue +import org.testng.annotations.{DataProvider, Test} + class PruneSuite extends HailSuite { - @Test def testUnionType() { + @Test def testUnionType(): Unit = { val base = TStruct( "a" -> TStruct( "aa" -> TInt32, "ab" -> TStruct( - "aaa" -> TString)), + "aaa" -> TString + ), + ), "b" -> TInt32, "c" -> TArray(TStruct( - "ca" -> TInt32))) + "ca" -> TInt32 + )), + ) assert(PruneDeadFields.unify(base, TStruct.empty) == TStruct.empty) assert(PruneDeadFields.unify(base, TStruct("b" -> TInt32)) == TStruct("b" -> TInt32)) - assert(PruneDeadFields.unify(base, TStruct("a" -> TStruct.empty)) == TStruct("a" -> TStruct.empty)) - assert(PruneDeadFields.unify(base, TStruct("a" -> TStruct.empty), - TStruct("b" -> TInt32)) == TStruct("a" -> TStruct.empty, "b" -> TInt32)) - assert(PruneDeadFields.unify(base, TStruct("c" -> TArray(TStruct.empty))) == TStruct("c" -> TArray(TStruct.empty))) - assert(PruneDeadFields.unify(base, TStruct("a" -> TStruct("ab" -> TStruct.empty)), - TStruct("c" -> TArray(TStruct.empty))) == TStruct("a" -> TStruct("ab" -> TStruct.empty), "c" -> TArray(TStruct.empty))) + assert( + PruneDeadFields.unify(base, TStruct("a" -> TStruct.empty)) == TStruct("a" -> TStruct.empty) + ) + assert(PruneDeadFields.unify( + base, + TStruct("a" -> TStruct.empty), + TStruct("b" -> TInt32), + ) == TStruct("a" -> TStruct.empty, "b" -> TInt32)) + assert(PruneDeadFields.unify(base, TStruct("c" -> TArray(TStruct.empty))) == TStruct( + "c" -> TArray(TStruct.empty) + )) + assert(PruneDeadFields.unify( + base, + TStruct("a" -> TStruct("ab" -> TStruct.empty)), + TStruct("c" -> TArray(TStruct.empty)), + ) == TStruct("a" -> TStruct("ab" -> TStruct.empty), "c" -> TArray(TStruct.empty))) } @Test def testIsSupertype(): Unit = { @@ -53,19 +68,25 @@ class PruneSuite extends HailSuite { assert(PruneDeadFields.isSupertype(tuple2IntsFirstRemoved, tuple2Ints)) } - def checkMemo(ir: BaseIR, requestedType: BaseType, expected: Array[BaseType]) { + def checkMemo(ir: BaseIR, requestedType: BaseType, expected: Array[BaseType]): Unit = { val irCopy = ir.deepCopy() - assert(PruneDeadFields.isSupertype(requestedType, irCopy.typ), - s"not supertype:\n super: ${ requestedType.parsableString() }\n sub: ${ irCopy.typ.parsableString() }") + assert( + PruneDeadFields.isSupertype(requestedType, irCopy.typ), + s"not supertype:\n super: ${requestedType.parsableString()}\n sub: ${irCopy.typ.parsableString()}", + ) val ms = PruneDeadFields.ComputeMutableState(Memo.empty[BaseType], mutable.HashMap.empty) irCopy match { - case mir: MatrixIR => PruneDeadFields.memoizeMatrixIR(ctx, mir, requestedType.asInstanceOf[MatrixType], ms) - case tir: TableIR => PruneDeadFields.memoizeTableIR(ctx, tir, requestedType.asInstanceOf[TableType], ms) + case mir: MatrixIR => + PruneDeadFields.memoizeMatrixIR(ctx, mir, requestedType.asInstanceOf[MatrixType], ms) + case tir: TableIR => + 
PruneDeadFields.memoizeTableIR(ctx, tir, requestedType.asInstanceOf[TableType], ms) case ir: IR => PruneDeadFields.memoizeValueIR(ctx, ir, requestedType.asInstanceOf[Type], ms) } irCopy.children.zipWithIndex.foreach { case (child, i) => if (expected(i) != null && expected(i) != ms.requestedType.lookup(child)) { - fatal(s"For base IR $ir\n Child $i with IR $child\n Expected: ${ expected(i) }\n Actual: ${ ms.requestedType.lookup(child) }") + fatal( + s"For base IR $ir\n Child $i with IR $child\n Expected: ${expected(i)}\n Actual: ${ms.requestedType.lookup(child)}" + ) } } } @@ -73,7 +94,8 @@ class PruneSuite extends HailSuite { def checkRebuild[T <: BaseIR]( ir: T, requestedType: BaseType, - f: (T, T) => Boolean = (left: T, right: T) => left == right) { + f: (T, T) => Boolean = (left: T, right: T) => left == right, + ): Unit = { val irCopy = ir.deepCopy() val ms = PruneDeadFields.ComputeMutableState(Memo.empty[BaseType], mutable.HashMap.empty) val rebuilt = (irCopy match { @@ -85,64 +107,97 @@ class PruneSuite extends HailSuite { PruneDeadFields.rebuild(ctx, tir, ms.rebuildState) case ir: IR => PruneDeadFields.memoizeValueIR(ctx, ir, requestedType.asInstanceOf[Type], ms) - PruneDeadFields.rebuildIR(ctx, ir, BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty)), ms.rebuildState) + PruneDeadFields.rebuildIR( + ctx, + ir, + BindingEnv(Env.empty, Some(Env.empty), Some(Env.empty)), + ms.rebuildState, + ) }).asInstanceOf[T] if (!f(ir, rebuilt)) fatal(s"IR did not rebuild the same:\n Base: $ir\n Rebuilt: $rebuilt") } - lazy val tab = TableLiteral(TableKeyBy( - TableParallelize( - Literal( - TStruct( - "rows" -> TArray(TStruct("1" -> TString, - "2" -> TArray(TStruct("2A" -> TInt32)), - "3" -> TString, - "4" -> TStruct("A" -> TInt32, "B" -> TArray(TStruct("i" -> TString))), - "5" -> TString)), - "global" -> TStruct("g1" -> TInt32, "g2" -> TInt32)), - Row(FastSeq(Row("hi", FastSeq(Row(1)), "bye", Row(2, FastSeq(Row("bar"))), "foo")), Row(5, 10))), - None), - FastSeq("3"), - false).analyzeAndExecute(ctx).asTableValue(ctx), - theHailClassLoader) - - lazy val tr = TableRead(tab.typ, false, new FakeTableReader { - override def pathsUsed: Seq[String] = Seq.empty - override def fullType: TableType = tab.typ - }) + lazy val tab = TableLiteral( + TableKeyBy( + TableParallelize( + Literal( + TStruct( + "rows" -> TArray(TStruct( + "1" -> TString, + "2" -> TArray(TStruct("2A" -> TInt32)), + "3" -> TString, + "4" -> TStruct("A" -> TInt32, "B" -> TArray(TStruct("i" -> TString))), + "5" -> TString, + )), + "global" -> TStruct("g1" -> TInt32, "g2" -> TInt32), + ), + Row( + FastSeq(Row("hi", FastSeq(Row(1)), "bye", Row(2, FastSeq(Row("bar"))), "foo")), + Row(5, 10), + ), + ), + None, + ), + FastSeq("3"), + false, + ).analyzeAndExecute(ctx).asTableValue(ctx), + theHailClassLoader, + ) + + lazy val tr = TableRead( + tab.typ, + false, + new FakeTableReader { + override def pathsUsed: Seq[String] = Seq.empty + override def fullType: TableType = tab.typ + }, + ) lazy val mType = MatrixType( TStruct("g1" -> TInt32, "g2" -> TFloat64), FastSeq("ck"), TStruct("ck" -> TString, "c2" -> TInt32, "c3" -> TArray(TStruct("cc" -> TInt32))), FastSeq("rk"), - TStruct("rk" -> TInt32, "r2" -> TStruct("x" -> TInt32), "r3" -> TArray(TStruct("rr" -> TInt32))), - TStruct("e1" -> TFloat64, "e2" -> TFloat64)) - lazy val mat = MatrixLiteral(ctx, + TStruct( + "rk" -> TInt32, + "r2" -> TStruct("x" -> TInt32), + "r3" -> TArray(TStruct("rr" -> TInt32)), + ), + TStruct("e1" -> TFloat64, "e2" -> TFloat64), + ) + + lazy val mat = MatrixLiteral( + 
ctx, mType, RVD.empty(ctx, mType.canonicalTableType.canonicalRVDType), Row(1, 1.0), - FastSeq(Row("1", 2, FastSeq(Row(3))))) + FastSeq(Row("1", 2, FastSeq(Row(3)))), + ) - lazy val mr = MatrixRead(mat.typ, false, false, new MatrixReader { - def pathsUsed: IndexedSeq[String] = FastSeq() + lazy val mr = MatrixRead( + mat.typ, + false, + false, + new MatrixReader { + def pathsUsed: IndexedSeq[String] = FastSeq() - override def columnCount: Option[Int] = None + override def columnCount: Option[Int] = None - def partitionCounts: Option[IndexedSeq[Long]] = None + def partitionCounts: Option[IndexedSeq[Long]] = None - def rowUIDType = TTuple(TInt64, TInt64) - def colUIDType = TTuple(TInt64, TInt64) + def rowUIDType = TTuple(TInt64, TInt64) + def colUIDType = TTuple(TInt64, TInt64) - def fullMatrixTypeWithoutUIDs: MatrixType = mat.typ + def fullMatrixTypeWithoutUIDs: MatrixType = mat.typ - def lower(requestedType: MatrixType, dropCols: Boolean, dropRows: Boolean): TableIR = ??? + def lower(requestedType: MatrixType, dropCols: Boolean, dropRows: Boolean): TableIR = ??? - def toJValue: JValue = ??? + def toJValue: JValue = ??? - override def renderShort(): String = ??? - }) + override def renderShort(): String = ??? + }, + ) lazy val emptyTableDep = TableType(TStruct.empty, FastSeq(), TStruct.empty) @@ -155,17 +210,16 @@ class PruneSuite extends HailSuite { case "global" => Ref("global", tt.globalType) } - split.tail.foreach { field => - ir = GetField(ir, field) - } + split.tail.foreach(field => ir = GetField(ir, field)) let = Let(FastSeq(genUID() -> ir), let) } let } - def tableRefStruct(tt: TableType, fields: String*): IR = { - MakeStruct(tt.key.map(k => k -> GetField(Ref("row", tt.rowType), k)) ++ FastSeq("foo" -> tableRefBoolean(tt, fields: _*))) - } + def tableRefStruct(tt: TableType, fields: String*): IR = + MakeStruct(tt.key.map(k => k -> GetField(Ref("row", tt.rowType), k)) ++ FastSeq( + "foo" -> tableRefBoolean(tt, fields: _*) + )) def matrixRefBoolean(mt: MatrixType, fields: String*): IR = { var let: IR = True() @@ -178,17 +232,14 @@ class PruneSuite extends HailSuite { case "global" => Ref("global", mt.globalType) } - split.tail.foreach { field => - ir = GetField(ir, field) - } + split.tail.foreach(field => ir = GetField(ir, field)) let = Let(FastSeq(genUID() -> ir), let) } let } - def matrixRefStruct(mt: MatrixType, fields: String*): IR = { + def matrixRefStruct(mt: MatrixType, fields: String*): IR = MakeStruct(FastSeq("foo" -> matrixRefBoolean(mt, fields: _*))) - } def subsetTable(tt: TableType, fields: String*): TableType = { val rowFields = new BoxedArrayBuilder[TStruct]() @@ -208,8 +259,11 @@ class PruneSuite extends HailSuite { val k = if (noKey) FastSeq() else tt.key tt.copy( key = k, - rowType = PruneDeadFields.unify(tt.rowType, Array(PruneDeadFields.selectKey(tt.rowType, k)) ++ rowFields.result(): _*), - globalType = PruneDeadFields.unify(tt.globalType, globalFields.result(): _*) + rowType = PruneDeadFields.unify( + tt.rowType, + Array(PruneDeadFields.selectKey(tt.rowType, k)) ++ rowFields.result(): _* + ), + globalType = PruneDeadFields.unify(tt.globalType, globalFields.result(): _*), ) } @@ -243,172 +297,224 @@ class PruneSuite extends HailSuite { rowKey = rk, colKey = ck, globalType = PruneDeadFields.unify(mt.globalType, globalFields.result(): _*), - colType = PruneDeadFields.unify(mt.colType, Array(PruneDeadFields.selectKey(mt.colType, ck)) ++ colFields.result(): _*), - rowType = PruneDeadFields.unify(mt.rowType, Array(PruneDeadFields.selectKey(mt.rowType, rk)) ++ 
rowFields.result(): _*), - entryType = PruneDeadFields.unify(mt.entryType, entryFields.result(): _*)) + colType = PruneDeadFields.unify( + mt.colType, + Array(PruneDeadFields.selectKey(mt.colType, ck)) ++ colFields.result(): _* + ), + rowType = PruneDeadFields.unify( + mt.rowType, + Array(PruneDeadFields.selectKey(mt.rowType, rk)) ++ rowFields.result(): _* + ), + entryType = PruneDeadFields.unify(mt.entryType, entryFields.result(): _*), + ) } - def mangle(t: TableIR): TableIR = { + def mangle(t: TableIR): TableIR = TableRename( t, t.typ.rowType.fieldNames.map(x => x -> (x + "_")).toMap, - t.typ.globalType.fieldNames.map(x => x -> (x + "_")).toMap + t.typ.globalType.fieldNames.map(x => x -> (x + "_")).toMap, ) - } - @Test def testTableJoinMemo() { + @Test def testTableJoinMemo(): Unit = { val tk1 = TableKeyBy(tab, Array("1")) val tk2 = mangle(TableKeyBy(tab, Array("3"))) val tj = TableJoin(tk1, tk2, "inner", 1) - checkMemo(tj, + checkMemo( + tj, subsetTable(tj.typ, "row.1", "row.4", "row.1_"), Array( subsetTable(tk1.typ, "row.1", "row.4"), - subsetTable(tk2.typ, "row.1_", "row.3_") - ) + subsetTable(tk2.typ, "row.1_", "row.3_"), + ), ) val tk3 = TableKeyBy(tab, Array("1", "2")) val tk4 = mangle(TableKeyBy(tab, Array("1", "2"))) val tj2 = TableJoin(tk3, tk4, "inner", 1) - checkMemo(tj2, + checkMemo( + tj2, subsetTable(tj2.typ, "row.3_"), Array( subsetTable(tk3.typ, "row.1", "row.2"), - subsetTable(tk4.typ, "row.1_", "row.2_", "row.3_") - )) + subsetTable(tk4.typ, "row.1_", "row.2_", "row.3_"), + ), + ) - checkMemo(tj2, + checkMemo( + tj2, subsetTable(tj2.typ, "row.3_", "NO_KEY"), Array( TableType(globalType = TStruct.empty, key = Array("1"), rowType = TStruct("1" -> TString)), - TableType(globalType = TStruct.empty, key = Array("1_"), rowType = TStruct("1_" -> TString, "3_" -> TString)) - )) + TableType( + globalType = TStruct.empty, + key = Array("1_"), + rowType = TStruct("1_" -> TString, "3_" -> TString), + ), + ), + ) } - @Test def testTableLeftJoinRightDistinctMemo() { + @Test def testTableLeftJoinRightDistinctMemo(): Unit = { val tk1 = TableKeyBy(tab, Array("1")) val tk2 = TableKeyBy(tab, Array("3")) val tj = TableLeftJoinRightDistinct(tk1, tk2, "foo") - checkMemo(tj, + checkMemo( + tj, subsetTable(tj.typ, "row.1", "row.4", "row.foo"), Array( subsetTable(tk1.typ, "row.1", "row.4"), - subsetTable(tk2.typ) - ) + subsetTable(tk2.typ), + ), ) } - @Test def testTableIntervalJoinMemo() { + @Test def testTableIntervalJoinMemo(): Unit = { val tk1 = TableKeyBy(tab, Array("1")) val tk2 = TableKeyBy(tab, Array("3")) - val tj = TableIntervalJoin(tk1, tk2, "foo", product=false) - checkMemo(tj, + val tj = TableIntervalJoin(tk1, tk2, "foo", product = false) + checkMemo( + tj, subsetTable(tj.typ, "row.1", "row.4", "row.foo"), Array( subsetTable(tk1.typ, "row.1", "row.4"), - subsetTable(tk2.typ) - ) + subsetTable(tk2.typ), + ), ) } - @Test def testTableMultiWayZipJoinMemo() { + @Test def testTableMultiWayZipJoinMemo(): Unit = { val tk1 = TableKeyBy(tab, Array("1")) val ts = Array(tk1, tk1, tk1) val tmwzj = TableMultiWayZipJoin(ts, "data", "gbls") - checkMemo(tmwzj, subsetTable(tmwzj.typ, "row.data.2", "global.gbls.g1"), ts.map { t => - subsetTable(t.typ, "row.2", "global.g1") - }) + checkMemo( + tmwzj, + subsetTable(tmwzj.typ, "row.data.2", "global.gbls.g1"), + ts.map(t => subsetTable(t.typ, "row.2", "global.g1")), + ) } - @Test def testTableExplodeMemo() { + @Test def testTableExplodeMemo(): Unit = { val te = TableExplode(tab, Array("2")) checkMemo(te, subsetTable(te.typ), Array(subsetTable(tab.typ, 
"row.2"))) } - @Test def testTableFilterMemo() { - checkMemo(TableFilter(tab, tableRefBoolean(tab.typ, "row.2")), + @Test def testTableFilterMemo(): Unit = { + checkMemo( + TableFilter(tab, tableRefBoolean(tab.typ, "row.2")), subsetTable(tab.typ, "row.3"), - Array(subsetTable(tab.typ, "row.2", "row.3"), null)) - checkMemo(TableFilter(tab, False()), + Array(subsetTable(tab.typ, "row.2", "row.3"), null), + ) + checkMemo( + TableFilter(tab, False()), subsetTable(tab.typ, "row.1"), - Array(subsetTable(tab.typ, "row.1"), TBoolean)) + Array(subsetTable(tab.typ, "row.1"), TBoolean), + ) } - @Test def testTableKeyByMemo() { + @Test def testTableKeyByMemo(): Unit = { val tk = TableKeyBy(tab, Array("1")) - checkMemo(tk, subsetTable(tk.typ, "row.2"), Array(subsetTable(tab.typ, "row.1", "row.2", "NO_KEY"))) + checkMemo( + tk, + subsetTable(tk.typ, "row.2"), + Array(subsetTable(tab.typ, "row.1", "row.2", "NO_KEY")), + ) val tk2 = TableKeyBy(tab, Array("3"), isSorted = true) checkMemo(tk2, subsetTable(tk2.typ, "row.2"), Array(subsetTable(tab.typ, "row.2"))) } - @Test def testTableMapRowsMemo() { + @Test def testTableMapRowsMemo(): Unit = { val tmr = TableMapRows(tab, tableRefStruct(tab.typ, "row.1", "row.2")) - checkMemo(tmr, subsetTable(tmr.typ, "row.foo"), Array(subsetTable(tab.typ, "row.1", "row.2"), null)) + checkMemo( + tmr, + subsetTable(tmr.typ, "row.foo"), + Array(subsetTable(tab.typ, "row.1", "row.2"), null), + ) val tmr2 = TableMapRows(tab, tableRefStruct(tab.typ, "row.1", "row.2")) - checkMemo(tmr2, subsetTable(tmr2.typ, "row.foo", "NO_KEY"), Array(subsetTable(tab.typ, "row.1", "row.2", "NO_KEY"), null)) + checkMemo( + tmr2, + subsetTable(tmr2.typ, "row.foo", "NO_KEY"), + Array(subsetTable(tab.typ, "row.1", "row.2", "NO_KEY"), null), + ) } - @Test def testTableMapGlobalsMemo() { + @Test def testTableMapGlobalsMemo(): Unit = { val tmg = TableMapGlobals(tab, tableRefStruct(tab.typ, "global.g1")) - checkMemo(tmg, subsetTable(tmg.typ, "global.foo"), Array(subsetTable(tab.typ, "global.g1"), null)) + checkMemo( + tmg, + subsetTable(tmg.typ, "global.foo"), + Array(subsetTable(tab.typ, "global.g1"), null), + ) } - @Test def testMatrixColsTableMemo() { + @Test def testMatrixColsTableMemo(): Unit = { val mct = MatrixColsTable(mat) - checkMemo(mct, + checkMemo( + mct, subsetTable(mct.typ, "global.g1", "row.c2"), - Array(subsetMatrixTable(mat.typ, "global.g1", "sa.c2", "NO_ROW_KEY"))) + Array(subsetMatrixTable(mat.typ, "global.g1", "sa.c2", "NO_ROW_KEY")), + ) } - @Test def testMatrixRowsTableMemo() { + @Test def testMatrixRowsTableMemo(): Unit = { val mrt = MatrixRowsTable(mat) - checkMemo(mrt, + checkMemo( + mrt, subsetTable(mrt.typ, "global.g1", "row.r2"), - Array(subsetMatrixTable(mat.typ, "global.g1", "va.r2", "NO_COL_KEY"))) + Array(subsetMatrixTable(mat.typ, "global.g1", "va.r2", "NO_COL_KEY")), + ) } - @Test def testMatrixEntriesTableMemo() { + @Test def testMatrixEntriesTableMemo(): Unit = { val met = MatrixEntriesTable(mat) - checkMemo(met, + checkMemo( + met, subsetTable(met.typ, "global.g1", "row.r2", "row.c2", "row.e2"), - Array(subsetMatrixTable(mat.typ, "global.g1", "va.r2", "sa.c2", "g.e2"))) + Array(subsetMatrixTable(mat.typ, "global.g1", "va.r2", "sa.c2", "g.e2")), + ) } - @Test def testTableKeyByAndAggregateMemo() { - val tka = TableKeyByAndAggregate(tab, + @Test def testTableKeyByAndAggregateMemo(): Unit = { + val tka = TableKeyByAndAggregate( + tab, tableRefStruct(tab.typ, "row.2"), MakeStruct(FastSeq("bar" -> tableRefBoolean(tab.typ, "row.3"))), None, - 1) + 1, + ) - checkMemo(tka, 
subsetTable(tka.typ, "row.foo"), Array(subsetTable(tab.typ, "row.2", "row.3", "NO_KEY"), null, null)) + checkMemo( + tka, + subsetTable(tka.typ, "row.foo"), + Array(subsetTable(tab.typ, "row.2", "row.3", "NO_KEY"), null, null), + ) checkMemo(tka, subsetTable(tka.typ), Array(subsetTable(tab.typ, "row.3", "NO_KEY"), null, null)) } @Test def testTableAggregateByKeyMemo(): Unit = { val tabk = TableAggregateByKey( tab, - SelectFields(Ref("row", tab.typ.rowType), IndexedSeq("5")) + SelectFields(Ref("row", tab.typ.rowType), IndexedSeq("5")), + ) + checkMemo( + tabk, + requestedType = subsetTable(tabk.typ, "row.3", "row.5"), + Array(subsetTable(tabk.typ, "row.3", "row.5"), TStruct(("5", TString))), ) - checkMemo(tabk, requestedType = subsetTable(tabk.typ, "row.3", "row.5"), - Array(subsetTable(tabk.typ, "row.3", "row.5"), TStruct(("5", TString)))) } - @Test def testTableUnionMemo() { + @Test def testTableUnionMemo(): Unit = checkMemo( TableUnion(FastSeq(tab, tab)), subsetTable(tab.typ, "row.1", "global.g1"), - Array(subsetTable(tab.typ, "row.1", "global.g1"), - subsetTable(tab.typ, "row.1")) + Array(subsetTable(tab.typ, "row.1", "global.g1"), subsetTable(tab.typ, "row.1")), ) - } - @Test def testTableOrderByMemo() { + @Test def testTableOrderByMemo(): Unit = { val tob = TableOrderBy(tab, Array(SortField("2", Ascending))) checkMemo(tob, subsetTable(tob.typ), Array(subsetTable(tab.typ, "row.2", "row.2.2A", "NO_KEY"))) @@ -416,152 +522,236 @@ class PruneSuite extends HailSuite { checkMemo(tob2, subsetTable(tob2.typ), Array(subsetTable(tab.typ))) } - @Test def testCastMatrixToTableMemo() { + @Test def testCastMatrixToTableMemo(): Unit = { val m2t = CastMatrixToTable(mat, "__entries", "__cols") - checkMemo(m2t, + checkMemo( + m2t, subsetTable(m2t.typ, "row.r2", "global.__cols.c2", "global.g2", "row.__entries.e2"), - Array(subsetMatrixTable(mat.typ, "va.r2", "global.g2", "sa.c2", "g.e2", "NO_COL_KEY")) + Array(subsetMatrixTable(mat.typ, "va.r2", "global.g2", "sa.c2", "g.e2", "NO_COL_KEY")), ) } - @Test def testMatrixFilterColsMemo() { + @Test def testMatrixFilterColsMemo(): Unit = { val mfc = MatrixFilterCols(mat, matrixRefBoolean(mat.typ, "global.g1", "sa.c2")) - checkMemo(mfc, + checkMemo( + mfc, subsetMatrixTable(mfc.typ, "sa.c3"), - Array(subsetMatrixTable(mat.typ, "global.g1", "sa.c2", "sa.c3"), null)) + Array(subsetMatrixTable(mat.typ, "global.g1", "sa.c2", "sa.c3"), null), + ) } - @Test def testMatrixFilterRowsMemo() { + @Test def testMatrixFilterRowsMemo(): Unit = { val mfr = MatrixFilterRows(mat, matrixRefBoolean(mat.typ, "global.g1", "va.r2")) - checkMemo(mfr, + checkMemo( + mfr, subsetMatrixTable(mfr.typ, "sa.c3", "va.r3"), - Array(subsetMatrixTable(mat.typ, "global.g1", "va.r2", "sa.c3", "va.r3"), null)) + Array(subsetMatrixTable(mat.typ, "global.g1", "va.r2", "sa.c3", "va.r3"), null), + ) } - @Test def testMatrixFilterEntriesMemo() { - val mfe = MatrixFilterEntries(mat, matrixRefBoolean(mat.typ, "global.g1", "va.r2", "sa.c2", "g.e2")) - checkMemo(mfe, + @Test def testMatrixFilterEntriesMemo(): Unit = { + val mfe = + MatrixFilterEntries(mat, matrixRefBoolean(mat.typ, "global.g1", "va.r2", "sa.c2", "g.e2")) + checkMemo( + mfe, subsetMatrixTable(mfe.typ, "sa.c3", "va.r3"), - Array(subsetMatrixTable(mat.typ, "global.g1", "va.r2", "sa.c3", "sa.c2", "va.r3", "g.e2"), null)) + Array( + subsetMatrixTable(mat.typ, "global.g1", "va.r2", "sa.c3", "sa.c2", "va.r3", "g.e2"), + null, + ), + ) } - @Test def testMatrixMapColsMemo() { - val mmc = MatrixMapCols(mat, matrixRefStruct(mat.typ, "global.g1", 
"sa.c2", "va.r2", "g.e2"), Some(FastSeq())) - checkMemo(mmc, subsetMatrixTable(mmc.typ, "va.r3", "sa.foo"), - Array(subsetMatrixTable(mat.typ, "global.g1", "sa.c2", "va.r2", "g.e2", "va.r3", "NO_COL_KEY"), null)) - val mmc2 = MatrixMapCols(mat, MakeStruct(FastSeq( - ("ck" -> GetField(Ref("sa", mat.typ.colType), "ck")), - ("foo",matrixRefStruct(mat.typ, "global.g1", "sa.c2", "va.r2", "g.e2")))), None) - checkMemo(mmc2, subsetMatrixTable(mmc2.typ, "va.r3", "sa.foo.foo"), - Array(subsetMatrixTable(mat.typ, "global.g1", "sa.c2", "va.r2", "g.e2", "va.r3"), null)) + @Test def testMatrixMapColsMemo(): Unit = { + val mmc = MatrixMapCols( + mat, + matrixRefStruct(mat.typ, "global.g1", "sa.c2", "va.r2", "g.e2"), + Some(FastSeq()), + ) + checkMemo( + mmc, + subsetMatrixTable(mmc.typ, "va.r3", "sa.foo"), + Array( + subsetMatrixTable(mat.typ, "global.g1", "sa.c2", "va.r2", "g.e2", "va.r3", "NO_COL_KEY"), + null, + ), + ) + val mmc2 = MatrixMapCols( + mat, + MakeStruct(FastSeq( + ("ck" -> GetField(Ref("sa", mat.typ.colType), "ck")), + ("foo", matrixRefStruct(mat.typ, "global.g1", "sa.c2", "va.r2", "g.e2")), + )), + None, + ) + checkMemo( + mmc2, + subsetMatrixTable(mmc2.typ, "va.r3", "sa.foo.foo"), + Array(subsetMatrixTable(mat.typ, "global.g1", "sa.c2", "va.r2", "g.e2", "va.r3"), null), + ) } - @Test def testMatrixKeyRowsByMemo() { + @Test def testMatrixKeyRowsByMemo(): Unit = { val mkr = MatrixKeyRowsBy(mat, FastSeq("rk")) checkMemo(mkr, subsetMatrixTable(mkr.typ, "va.rk"), Array(subsetMatrixTable(mat.typ, "va.rk"))) } - @Test def testMatrixMapRowsMemo() { + @Test def testMatrixMapRowsMemo(): Unit = { val mmr = MatrixMapRows( MatrixKeyRowsBy(mat, IndexedSeq.empty), - matrixRefStruct(mat.typ, "global.g1", "sa.c2", "va.r2", "g.e2")) - checkMemo(mmr, subsetMatrixTable(mmr.typ, "sa.c3", "va.foo"), - Array(subsetMatrixTable(mat.typ.copy(rowKey = IndexedSeq.empty), "global.g1", "sa.c2", "va.r2", "g.e2", "sa.c3"), null)) + matrixRefStruct(mat.typ, "global.g1", "sa.c2", "va.r2", "g.e2"), + ) + checkMemo( + mmr, + subsetMatrixTable(mmr.typ, "sa.c3", "va.foo"), + Array( + subsetMatrixTable( + mat.typ.copy(rowKey = IndexedSeq.empty), + "global.g1", + "sa.c2", + "va.r2", + "g.e2", + "sa.c3", + ), + null, + ), + ) } - @Test def testMatrixMapGlobalsMemo() { + @Test def testMatrixMapGlobalsMemo(): Unit = { val mmg = MatrixMapGlobals(mat, matrixRefStruct(mat.typ, "global.g1")) - checkMemo(mmg, subsetMatrixTable(mmg.typ, "global.foo", "va.r3", "sa.c3"), - Array(subsetMatrixTable(mat.typ, "global.g1", "va.r3", "sa.c3"), null)) + checkMemo( + mmg, + subsetMatrixTable(mmg.typ, "global.foo", "va.r3", "sa.c3"), + Array(subsetMatrixTable(mat.typ, "global.g1", "va.r3", "sa.c3"), null), + ) } - @Test def testMatrixAnnotateRowsTableMemo() { + @Test def testMatrixAnnotateRowsTableMemo(): Unit = { val tl = TableLiteral(Interpret(MatrixRowsTable(mat), ctx), theHailClassLoader) - val mart = MatrixAnnotateRowsTable(mat, tl, "foo", product=false) - checkMemo(mart, subsetMatrixTable(mart.typ, "va.foo.r3", "va.r3"), - Array(subsetMatrixTable(mat.typ, "va.r3"), subsetTable(tl.typ, "row.r3"))) + val mart = MatrixAnnotateRowsTable(mat, tl, "foo", product = false) + checkMemo( + mart, + subsetMatrixTable(mart.typ, "va.foo.r3", "va.r3"), + Array(subsetMatrixTable(mat.typ, "va.r3"), subsetTable(tl.typ, "row.r3")), + ) } - @Test def testCollectColsByKeyMemo() { + @Test def testCollectColsByKeyMemo(): Unit = { val ccbk = MatrixCollectColsByKey(mat) - checkMemo(ccbk, + checkMemo( + ccbk, subsetMatrixTable(ccbk.typ, "g.e2", "sa.c2", 
"NO_COL_KEY"), - Array(subsetMatrixTable(mat.typ, "g.e2", "sa.c2"))) + Array(subsetMatrixTable(mat.typ, "g.e2", "sa.c2")), + ) } - @Test def testMatrixExplodeRowsMemo() { + @Test def testMatrixExplodeRowsMemo(): Unit = { val mer = MatrixExplodeRows(mat, FastSeq("r3")) - checkMemo(mer, + checkMemo( + mer, subsetMatrixTable(mer.typ, "va.r2"), - Array(subsetMatrixTable(mat.typ, "va.r2", "va.r3"))) + Array(subsetMatrixTable(mat.typ, "va.r2", "va.r3")), + ) } - @Test def testMatrixRepartitionMemo() { + @Test def testMatrixRepartitionMemo(): Unit = { checkMemo( MatrixRepartition(mat, 10, RepartitionStrategy.SHUFFLE), subsetMatrixTable(mat.typ, "va.r2", "global.g1"), - Array(subsetMatrixTable(mat.typ, "va.r2", "global.g1"), - subsetMatrixTable(mat.typ, "va.r2", "global.g1")) + Array( + subsetMatrixTable(mat.typ, "va.r2", "global.g1"), + subsetMatrixTable(mat.typ, "va.r2", "global.g1"), + ), ) } - @Test def testMatrixUnionRowsMemo() { + @Test def testMatrixUnionRowsMemo(): Unit = { checkMemo( MatrixUnionRows(FastSeq(mat, mat)), subsetMatrixTable(mat.typ, "va.r2", "global.g1"), - Array(subsetMatrixTable(mat.typ, "va.r2", "global.g1"), - subsetMatrixTable(mat.typ, "va.r2", "global.g1")) + Array( + subsetMatrixTable(mat.typ, "va.r2", "global.g1"), + subsetMatrixTable(mat.typ, "va.r2", "global.g1"), + ), ) } - @Test def testMatrixDistinctByRowMemo() { + @Test def testMatrixDistinctByRowMemo(): Unit = { checkMemo( MatrixDistinctByRow(mat), subsetMatrixTable(mat.typ, "va.r2", "global.g1"), - Array(subsetMatrixTable(mat.typ, "va.r2", "global.g1"), - subsetMatrixTable(mat.typ, "va.r2", "global.g1")) + Array( + subsetMatrixTable(mat.typ, "va.r2", "global.g1"), + subsetMatrixTable(mat.typ, "va.r2", "global.g1"), + ), ) } - @Test def testMatrixExplodeColsMemo() { + @Test def testMatrixExplodeColsMemo(): Unit = { val mer = MatrixExplodeCols(mat, FastSeq("c3")) - checkMemo(mer, + checkMemo( + mer, subsetMatrixTable(mer.typ, "va.r2"), - Array(subsetMatrixTable(mat.typ, "va.r2", "sa.c3"))) + Array(subsetMatrixTable(mat.typ, "va.r2", "sa.c3")), + ) } - @Test def testCastTableToMatrixMemo() { + @Test def testCastTableToMatrixMemo(): Unit = { val m2t = CastMatrixToTable(mat, "__entries", "__cols") val t2m = CastTableToMatrix(m2t, "__entries", "__cols", FastSeq("ck")) - checkMemo(t2m, + checkMemo( + t2m, subsetMatrixTable(mat.typ, "va.r2", "sa.c2", "global.g2", "g.e2"), - Array(subsetTable(m2t.typ, "row.r2", "global.g2", "global.__cols.ck", "global.__cols.c2", "row.__entries.e2")) + Array(subsetTable( + m2t.typ, + "row.r2", + "global.g2", + "global.__cols.ck", + "global.__cols.c2", + "row.__entries.e2", + )), ) } - @Test def testMatrixAggregateRowsByKeyMemo() { - val magg = MatrixAggregateRowsByKey(mat, + @Test def testMatrixAggregateRowsByKeyMemo(): Unit = { + val magg = MatrixAggregateRowsByKey( + mat, matrixRefStruct(mat.typ, "g.e2", "va.r2", "sa.c2"), - matrixRefStruct(mat.typ, "va.r3", "global.g1")) - checkMemo(magg, + matrixRefStruct(mat.typ, "va.r3", "global.g1"), + ) + checkMemo( + magg, subsetMatrixTable(magg.typ, "sa.c3", "g.foo", "va.foo"), - Array(subsetMatrixTable(mat.typ, "sa.c3", "g.e2", "va.r2", "sa.c2", "global.g1", "va.r3"), null, null) + Array( + subsetMatrixTable(mat.typ, "sa.c3", "g.e2", "va.r2", "sa.c2", "global.g1", "va.r3"), + null, + null, + ), ) } - @Test def testMatrixAggregateColsByKeyMemo() { - val magg = MatrixAggregateColsByKey(mat, + @Test def testMatrixAggregateColsByKeyMemo(): Unit = { + val magg = MatrixAggregateColsByKey( + mat, matrixRefStruct(mat.typ, "g.e2", "va.r2", "sa.c2"), 
- matrixRefStruct(mat.typ, "sa.c3", "global.g1")) - checkMemo(magg, + matrixRefStruct(mat.typ, "sa.c3", "global.g1"), + ) + checkMemo( + magg, subsetMatrixTable(magg.typ, "va.r3", "g.foo", "sa.foo"), - Array(subsetMatrixTable(mat.typ, "sa.c2", "va.r2", "va.r3", "g.e2", "global.g1", "sa.c3"), null, null)) + Array( + subsetMatrixTable(mat.typ, "sa.c2", "va.r2", "va.r3", "g.e2", "global.g1", "sa.c3"), + null, + null, + ), + ) } val ref = Ref("x", TStruct("a" -> TInt32, "b" -> TInt32, "c" -> TInt32)) val arr = MakeArray(FastSeq(ref, ref), TArray(ref.typ)) val st = MakeStream(FastSeq(ref, ref), TStream(ref.typ)) - val ndArr = MakeNDArray(arr, MakeTuple(IndexedSeq((0, I64(2l)))), True(), ErrorIDs.NO_ERROR) + val ndArr = MakeNDArray(arr, MakeTuple(IndexedSeq((0, I64(2L)))), True(), ErrorIDs.NO_ERROR) val empty = TStruct.empty val justA = TStruct("a" -> TInt32) val justB = TStruct("b" -> TInt32) @@ -570,273 +760,396 @@ class PruneSuite extends HailSuite { val justARequired = TStruct("a" -> TInt32) val justBRequired = TStruct("b" -> TInt32) - @Test def testIfMemo() { - checkMemo(If(True(), ref, ref), - justA, - Array(TBoolean, justA, justA)) - } + @Test def testIfMemo(): Unit = + checkMemo(If(True(), ref, ref), justA, Array(TBoolean, justA, justA)) @Test def testSwitchMemo(): Unit = checkMemo( Switch(I32(0), ref, FastSeq(ref)), justA, - Array(TInt32, justA, justA) + Array(TInt32, justA, justA), ) - @Test def testCoalesceMemo() { - checkMemo(Coalesce(FastSeq(ref, ref)), - justA, - Array(justA, justA)) - } + @Test def testCoalesceMemo(): Unit = + checkMemo(Coalesce(FastSeq(ref, ref)), justA, Array(justA, justA)) - @Test def testLetMemo() { + @Test def testLetMemo(): Unit = { checkMemo(Let(FastSeq("foo" -> ref), Ref("foo", ref.typ)), justA, Array(justA, null)) checkMemo(Let(FastSeq("foo" -> ref), True()), TBoolean, Array(empty, null)) } - @Test def testAggLetMemo() { - checkMemo(AggLet("foo", ref, - ApplyAggOp(FastSeq(), FastSeq( - SelectFields(Ref("foo", ref.typ), IndexedSeq("a"))), - AggSignature(Collect(), FastSeq(), FastSeq(ref.typ))), false), - TArray(justA), Array(justA, null)) + @Test def testAggLetMemo(): Unit = { + checkMemo( + AggLet( + "foo", + ref, + ApplyAggOp( + FastSeq(), + FastSeq( + SelectFields(Ref("foo", ref.typ), IndexedSeq("a")) + ), + AggSignature(Collect(), FastSeq(), FastSeq(ref.typ)), + ), + false, + ), + TArray(justA), + Array(justA, null), + ) checkMemo(AggLet("foo", ref, True(), false), TBoolean, Array(empty, null)) } - @Test def testMakeArrayMemo() { + @Test def testMakeArrayMemo(): Unit = checkMemo(arr, TArray(justB), Array(justB, justB)) - } - @Test def testArrayRefMemo() { + @Test def testArrayRefMemo(): Unit = checkMemo(ArrayRef(arr, I32(0)), justB, Array(TArray(justB), null, null)) - } - @Test def testArrayLenMemo() { + @Test def testArrayLenMemo(): Unit = checkMemo(ArrayLen(arr), TInt32, Array(TArray(empty))) - } - @Test def testStreamTakeMemo() { + @Test def testStreamTakeMemo(): Unit = checkMemo(StreamTake(st, I32(2)), TStream(justA), Array(TStream(justA), null)) - } - @Test def testStreamDropMemo() { + @Test def testStreamDropMemo(): Unit = checkMemo(StreamDrop(st, I32(2)), TStream(justA), Array(TStream(justA), null)) - } - @Test def testStreamMapMemo() { - checkMemo(StreamMap(st, "foo", Ref("foo", ref.typ)), - TStream(justB), Array(TStream(justB), null)) - } + @Test def testStreamMapMemo(): Unit = + checkMemo( + StreamMap(st, "foo", Ref("foo", ref.typ)), + TStream(justB), + Array(TStream(justB), null), + ) - @Test def testStreamGroupedMemo() { - 
checkMemo(StreamGrouped(st, I32(2)), - TStream(TStream(justB)), Array(TStream(justB), null)) - } + @Test def testStreamGroupedMemo(): Unit = + checkMemo(StreamGrouped(st, I32(2)), TStream(TStream(justB)), Array(TStream(justB), null)) - @Test def testStreamGroupByKeyMemo() { - checkMemo(StreamGroupByKey(st, FastSeq("a"), false), - TStream(TStream(justB)), Array(TStream(TStruct("a" -> TInt32, "b" -> TInt32)), null)) - } + @Test def testStreamGroupByKeyMemo(): Unit = + checkMemo( + StreamGroupByKey(st, FastSeq("a"), false), + TStream(TStream(justB)), + Array(TStream(TStruct("a" -> TInt32, "b" -> TInt32)), null), + ) - @Test def testStreamMergeMemo() { + @Test def testStreamMergeMemo(): Unit = { val st2 = st.deepCopy() checkMemo( StreamMultiMerge( IndexedSeq(st, st2), - FastSeq("a")), - TStream(justB), Array(TStream(aAndB), TStream(aAndB))) + FastSeq("a"), + ), + TStream(justB), + Array(TStream(aAndB), TStream(aAndB)), + ) } - @Test def testStreamZipMemo() { + @Test def testStreamZipMemo(): Unit = { val a2 = st.deepCopy() val a3 = st.deepCopy() - for (b <- Array(ArrayZipBehavior.ExtendNA, ArrayZipBehavior.TakeMinLength, ArrayZipBehavior.AssertSameLength)) { - checkMemo(StreamZip( + for ( + b <- Array( + ArrayZipBehavior.ExtendNA, + ArrayZipBehavior.TakeMinLength, + ArrayZipBehavior.AssertSameLength, + ) + ) { + checkMemo( + StreamZip( + FastSeq(st, a2, a3), + FastSeq("foo", "bar", "baz"), + Let( + FastSeq( + "foo1" -> GetField(Ref("foo", ref.typ), "b"), + "bar2" -> GetField(Ref("bar", ref.typ), "a"), + ), + False(), + ), + b, + ), + TStream(TBoolean), + Array(TStream(justB), TStream(justA), TStream(empty), null), + ) + } + + checkMemo( + StreamZip( FastSeq(st, a2, a3), FastSeq("foo", "bar", "baz"), Let( FastSeq( "foo1" -> GetField(Ref("foo", ref.typ), "b"), - "bar2" -> GetField(Ref("bar", ref.typ), "a") + "bar2" -> GetField(Ref("bar", ref.typ), "a"), ), - False() - ), b - ), - TStream(TBoolean), Array(TStream(justB), TStream(justA), TStream(empty), null)) - } - - checkMemo(StreamZip( - FastSeq(st, a2, a3), - FastSeq("foo", "bar", "baz"), - Let( - FastSeq( - "foo1" -> GetField(Ref("foo", ref.typ), "b"), - "bar2" -> GetField(Ref("bar", ref.typ), "a") + False(), ), - False() + ArrayZipBehavior.AssumeSameLength, ), - ArrayZipBehavior.AssumeSameLength), - TStream(TBoolean), Array(TStream(justB), TStream(justA), null, null)) + TStream(TBoolean), + Array(TStream(justB), TStream(justA), null, null), + ) } - @Test def testStreamFilterMemo() { - checkMemo(StreamFilter(st, "foo", Let(FastSeq("foo2" -> GetField(Ref("foo", ref.typ), "b")), False())), - TStream(empty), Array(TStream(justB), null)) - checkMemo(StreamFilter(st, "foo", False()), - TStream(empty), Array(TStream(empty), null)) - checkMemo(StreamFilter(st, "foo", False()), - TStream(justB), Array(TStream(justB), null)) + @Test def testStreamFilterMemo(): Unit = { + checkMemo( + StreamFilter(st, "foo", Let(FastSeq("foo2" -> GetField(Ref("foo", ref.typ), "b")), False())), + TStream(empty), + Array(TStream(justB), null), + ) + checkMemo(StreamFilter(st, "foo", False()), TStream(empty), Array(TStream(empty), null)) + checkMemo(StreamFilter(st, "foo", False()), TStream(justB), Array(TStream(justB), null)) } - @Test def testStreamFlatMapMemo() { - checkMemo(StreamFlatMap(st, "foo", MakeStream(FastSeq(Ref("foo", ref.typ)), TStream(ref.typ))), + @Test def testStreamFlatMapMemo(): Unit = + checkMemo( + StreamFlatMap(st, "foo", MakeStream(FastSeq(Ref("foo", ref.typ)), TStream(ref.typ))), TStream(justA), - Array(TStream(justA), null)) - } + 
Array(TStream(justA), null), + ) - @Test def testStreamFoldMemo() { - checkMemo(StreamFold(st, I32(0), "comb", "foo", GetField(Ref("foo", ref.typ), "a")), + @Test def testStreamFoldMemo(): Unit = + checkMemo( + StreamFold(st, I32(0), "comb", "foo", GetField(Ref("foo", ref.typ), "a")), TInt32, - Array(TStream(justA), null, null)) - } + Array(TStream(justA), null, null), + ) - @Test def testStreamScanMemo() { - checkMemo(StreamScan(st, I32(0), "comb", "foo", GetField(Ref("foo", ref.typ), "a")), + @Test def testStreamScanMemo(): Unit = + checkMemo( + StreamScan(st, I32(0), "comb", "foo", GetField(Ref("foo", ref.typ), "a")), TStream(TInt32), - Array(TStream(justA), null, null)) - } + Array(TStream(justA), null, null), + ) - @Test def testStreamJoinRightDistinct() { + @Test def testStreamJoinRightDistinct(): Unit = { val l = Ref("l", ref.typ) val r = Ref("r", ref.typ) checkMemo( - StreamJoinRightDistinct(st, st, FastSeq("a"), FastSeq("a"), "l", "r", - MakeStruct(FastSeq("a" -> GetField(l, "a"), "b" -> GetField(l, "b"), "c" -> GetField(l, "c"), "d" -> GetField(r, "b"), "e" -> GetField(r, "c"))), - "left"), + StreamJoinRightDistinct( + st, + st, + FastSeq("a"), + FastSeq("a"), + "l", + "r", + MakeStruct(FastSeq( + "a" -> GetField(l, "a"), + "b" -> GetField(l, "b"), + "c" -> GetField(l, "c"), + "d" -> GetField(r, "b"), + "e" -> GetField(r, "c"), + )), + "left", + ), TStream(TStruct("b" -> TInt32, "d" -> TInt32)), Array( TStream(TStruct("a" -> TInt32, "b" -> TInt32)), TStream(TStruct("a" -> TInt32, "b" -> TInt32)), - TStruct("b" -> TInt32, "d" -> TInt32))) + TStruct("b" -> TInt32, "d" -> TInt32), + ), + ) } - @Test def testStreamForMemo() { - checkMemo(StreamFor(st, "foo", Begin(FastSeq(GetField(Ref("foo", ref.typ), "a")))), - TVoid, - Array(TStream(justA), null)) + @Test def testStreamLeftIntervalJoin(): Unit = { + val leftElemType = TStruct("a" -> TInt32, "b" -> TInt32, "c" -> TInt32) + val rightElemType = TStruct("interval" -> TInterval(TInt32), "ignored" -> TVoid) + + val join = + StreamLeftIntervalJoin( + MakeStream(FastSeq(), TStream(leftElemType)), + MakeStream(FastSeq(), TStream(rightElemType)), + leftElemType.fieldNames.head, + "interval", + "lname", + "rname", + InsertFields( + Ref("lname", leftElemType), + FastSeq("intervals" -> Ref("rname", TArray(rightElemType))), + ), + ) + + val prunedLElemType = leftElemType.deleteKey("b") + val prunedRElemType = rightElemType.deleteKey("ignored") + val requestedElemType = prunedLElemType.insertFields( + FastSeq("intervals" -> TArray(prunedRElemType)) + ) + + checkMemo( + join, + TStream(requestedElemType), + Array( + TStream(prunedLElemType), + TStream(prunedRElemType), + requestedElemType, + ), + ) + + checkRebuild[StreamLeftIntervalJoin]( + join, + TStream(requestedElemType), + (_, pruned) => + pruned.left.typ == TStream(prunedLElemType) && + pruned.right.typ == TStream(prunedRElemType) && + pruned.body.typ == requestedElemType, + ) } + @Test def testStreamForMemo(): Unit = + checkMemo( + StreamFor(st, "foo", Begin(FastSeq(GetField(Ref("foo", ref.typ), "a")))), + TVoid, + Array(TStream(justA), null), + ) + @Test def testMakeNDArrayMemo(): Unit = { checkMemo( MakeNDArray( Ref("x", TArray(TStruct("a" -> TInt32, "b" -> TInt64))), Ref("y", TTuple(TInt32, TInt32)), - True(), ErrorIDs.NO_ERROR), + True(), + ErrorIDs.NO_ERROR, + ), TNDArray(TStruct("a" -> TInt32), Nat(2)), Array( TArray(TStruct("a" -> TInt32)), TTuple(TInt32, TInt32), - TBoolean - ) + TBoolean, + ), ) } - @Test def testNDArrayMapMemo(): Unit = { - checkMemo(NDArrayMap(ndArr, "foo", 
Ref("foo", ref.typ)), - TNDArray(justBRequired, Nat(1)), Array(TNDArray(justBRequired, Nat(1)), null)) - } + @Test def testNDArrayMapMemo(): Unit = + checkMemo( + NDArrayMap(ndArr, "foo", Ref("foo", ref.typ)), + TNDArray(justBRequired, Nat(1)), + Array(TNDArray(justBRequired, Nat(1)), null), + ) @Test def testNDArrayMap2Memo(): Unit = { - checkMemo(NDArrayMap2(ndArr, ndArr, "left", "right", Ref("left", ref.typ), ErrorIDs.NO_ERROR), - TNDArray(justBRequired, Nat(1)), Array(TNDArray(justBRequired, Nat(1)), TNDArray(TStruct.empty, Nat(1)), null)) - checkMemo(NDArrayMap2(ndArr, ndArr, "left", "right", Ref("right", ref.typ), ErrorIDs.NO_ERROR), - TNDArray(justBRequired, Nat(1)), Array(TNDArray(TStruct.empty, Nat(1)), TNDArray(justBRequired, Nat(1)), null)) - val addFieldsIR = ApplyBinaryPrimOp(Add(), GetField(Ref("left", ref.typ), "a"), GetField(Ref("right", ref.typ), "b")) - checkMemo(NDArrayMap2(ndArr, ndArr, "left", "right", addFieldsIR, ErrorIDs.NO_ERROR), - TNDArray(TInt32, Nat(1)), Array(TNDArray(justARequired, Nat(1)), TNDArray(justBRequired, Nat(1)), null)) + checkMemo( + NDArrayMap2(ndArr, ndArr, "left", "right", Ref("left", ref.typ), ErrorIDs.NO_ERROR), + TNDArray(justBRequired, Nat(1)), + Array(TNDArray(justBRequired, Nat(1)), TNDArray(TStruct.empty, Nat(1)), null), + ) + checkMemo( + NDArrayMap2(ndArr, ndArr, "left", "right", Ref("right", ref.typ), ErrorIDs.NO_ERROR), + TNDArray(justBRequired, Nat(1)), + Array(TNDArray(TStruct.empty, Nat(1)), TNDArray(justBRequired, Nat(1)), null), + ) + val addFieldsIR = ApplyBinaryPrimOp( + Add(), + GetField(Ref("left", ref.typ), "a"), + GetField(Ref("right", ref.typ), "b"), + ) + checkMemo( + NDArrayMap2(ndArr, ndArr, "left", "right", addFieldsIR, ErrorIDs.NO_ERROR), + TNDArray(TInt32, Nat(1)), + Array(TNDArray(justARequired, Nat(1)), TNDArray(justBRequired, Nat(1)), null), + ) } - @Test def testMakeStructMemo() { - checkMemo(MakeStruct(IndexedSeq("a" -> ref, "b" -> I32(10))), - TStruct("a" -> justA), Array(justA, null)) - checkMemo(MakeStruct(IndexedSeq("a" -> ref, "b" -> I32(10))), - TStruct.empty, Array(null, null)) + @Test def testMakeStructMemo(): Unit = { + checkMemo( + MakeStruct(IndexedSeq("a" -> ref, "b" -> I32(10))), + TStruct("a" -> justA), + Array(justA, null), + ) + checkMemo(MakeStruct(IndexedSeq("a" -> ref, "b" -> I32(10))), TStruct.empty, Array(null, null)) } - @Test def testInsertFieldsMemo() { - checkMemo(InsertFields(ref, IndexedSeq("d" -> ref)), + @Test def testInsertFieldsMemo(): Unit = + checkMemo( + InsertFields(ref, IndexedSeq("d" -> ref)), justA ++ TStruct("d" -> justB), - Array(justA, justB)) - } + Array(justA, justB), + ) - @Test def testSelectFieldsMemo() { + @Test def testSelectFieldsMemo(): Unit = { checkMemo(SelectFields(ref, IndexedSeq("a", "b")), justA, Array(justA)) checkMemo(SelectFields(ref, IndexedSeq("b", "a")), bAndA, Array(aAndB)) } - @Test def testGetFieldMemo() { + @Test def testGetFieldMemo(): Unit = checkMemo(GetField(ref, "a"), TInt32, Array(justA)) - } - @Test def testMakeTupleMemo() { + @Test def testMakeTupleMemo(): Unit = checkMemo(MakeTuple(IndexedSeq(0 -> ref)), TTuple(justA), Array(justA)) - } - @Test def testGetTupleElementMemo() { - checkMemo(GetTupleElement(MakeTuple.ordered(IndexedSeq(ref, ref)), 1), justB, Array(TTuple(FastSeq(TupleField(1, justB))))) - } + @Test def testGetTupleElementMemo(): Unit = + checkMemo( + GetTupleElement(MakeTuple.ordered(IndexedSeq(ref, ref)), 1), + justB, + Array(TTuple(FastSeq(TupleField(1, justB)))), + ) - @Test def testCastRenameMemo() { + @Test def 
testCastRenameMemo(): Unit = { checkMemo( CastRename( Ref("x", TArray(TStruct("x" -> TInt32, "y" -> TString))), - TArray(TStruct("y" -> TInt32, "z" -> TString))), + TArray(TStruct("y" -> TInt32, "z" -> TString)), + ), TArray(TStruct("z" -> TString)), - Array(TArray(TStruct("y" -> TString))) + Array(TArray(TStruct("y" -> TString))), ) } @Test def testAggFilterMemo(): Unit = { val t = TStruct("a" -> TInt32, "b" -> TInt64, "c" -> TString) val select = SelectFields(Ref("x", t), IndexedSeq("c")) - checkMemo(AggFilter( - ApplyComparisonOp(LT(TInt32, TInt32), GetField(Ref("x", t), "a"), I32(0)), - ApplyAggOp(FastSeq(), FastSeq(select), - AggSignature(Collect(), FastSeq(), FastSeq(select.typ))), - false), + checkMemo( + AggFilter( + ApplyComparisonOp(LT(TInt32, TInt32), GetField(Ref("x", t), "a"), I32(0)), + ApplyAggOp( + FastSeq(), + FastSeq(select), + AggSignature(Collect(), FastSeq(), FastSeq(select.typ)), + ), + false, + ), TArray(TStruct("c" -> TString)), - Array(null, TArray(TStruct("c" -> TString)))) + Array(null, TArray(TStruct("c" -> TString))), + ) } @Test def testAggExplodeMemo(): Unit = { val t = TStream(TStruct("a" -> TInt32, "b" -> TInt64)) val select = SelectFields(Ref("foo", t.elementType), IndexedSeq("a")) - checkMemo(AggExplode(Ref("x", t), - "foo", - ApplyAggOp(FastSeq(), FastSeq(select), - AggSignature(Collect(), FastSeq(), FastSeq(select.typ))), - false), + checkMemo( + AggExplode( + Ref("x", t), + "foo", + ApplyAggOp( + FastSeq(), + FastSeq(select), + AggSignature(Collect(), FastSeq(), FastSeq(select.typ)), + ), + false, + ), TArray(TStruct("a" -> TInt32)), - Array(TStream(TStruct("a" -> TInt32)), - TArray(TStruct("a" -> TInt32)))) + Array(TStream(TStruct("a" -> TInt32)), TArray(TStruct("a" -> TInt32))), + ) } @Test def testAggArrayPerElementMemo(): Unit = { val t = TArray(TStruct("a" -> TInt32, "b" -> TInt64)) val select = SelectFields(Ref("foo", t.elementType), IndexedSeq("a")) - checkMemo(AggArrayPerElement(Ref("x", t), - "foo", - "bar", - ApplyAggOp(FastSeq(), FastSeq(select), - AggSignature(Collect(), FastSeq(), FastSeq(select.typ))), - None, - false), + checkMemo( + AggArrayPerElement( + Ref("x", t), + "foo", + "bar", + ApplyAggOp( + FastSeq(), + FastSeq(select), + AggSignature(Collect(), FastSeq(), FastSeq(select.typ)), + ), + None, + false, + ), TArray(TArray(TStruct("a" -> TInt32))), - Array(TArray(TStruct("a" -> TInt32)), - TArray(TStruct("a" -> TInt32)))) + Array(TArray(TStruct("a" -> TInt32)), TArray(TStruct("a" -> TInt32))), + ) } - @Test def testCDAMemo() { + @Test def testCDAMemo(): Unit = { val ctxT = TStruct("a" -> TInt32, "b" -> TString) val globT = TStruct("c" -> TInt64, "d" -> TFloat64) val x = CollectDistributedArray( @@ -844,554 +1157,763 @@ class PruneSuite extends HailSuite { NA(globT), "ctx", "glob", - MakeTuple.ordered(FastSeq(Ref("ctx", ctxT), Ref("glob", globT))), NA(TString), "test") + MakeTuple.ordered(FastSeq(Ref("ctx", ctxT), Ref("glob", globT))), + NA(TString), + "test", + ) - checkMemo(x, TArray(TTuple(ctxT.typeAfterSelectNames(Array("a")), globT.typeAfterSelectNames(Array("c")))), - Array(TStream(ctxT.typeAfterSelectNames(Array("a"))), globT.typeAfterSelectNames(Array("c")), null, TString)) + checkMemo( + x, + TArray(TTuple(ctxT.typeAfterSelectNames(Array("a")), globT.typeAfterSelectNames(Array("c")))), + Array( + TStream(ctxT.typeAfterSelectNames(Array("a"))), + globT.typeAfterSelectNames(Array("c")), + null, + TString, + ), + ) } - @Test def testTableCountMemo() { + @Test def testTableCountMemo(): Unit = checkMemo(TableCount(tab), TInt64, 
Array(subsetTable(tab.typ, "NO_KEY"))) - } - @Test def testTableGetGlobalsMemo() { - checkMemo(TableGetGlobals(tab), TStruct("g1" -> TInt32), Array(subsetTable(tab.typ, "global.g1", "NO_KEY"))) - } + @Test def testTableGetGlobalsMemo(): Unit = + checkMemo( + TableGetGlobals(tab), + TStruct("g1" -> TInt32), + Array(subsetTable(tab.typ, "global.g1", "NO_KEY")), + ) - @Test def testTableCollectMemo() { + @Test def testTableCollectMemo(): Unit = checkMemo( TableCollect(tab), TStruct("rows" -> TArray(TStruct("3" -> TString)), "global" -> TStruct("g2" -> TInt32)), - Array(subsetTable(tab.typ, "row.3", "global.g2"))) - } + Array(subsetTable(tab.typ, "row.3", "global.g2")), + ) - @Test def testTableHeadMemo() { + @Test def testTableHeadMemo(): Unit = checkMemo( TableHead(tab, 10L), subsetTable(tab.typ.copy(key = FastSeq()), "global.g1"), - Array(subsetTable(tab.typ, "row.3", "global.g1"))) - } + Array(subsetTable(tab.typ, "row.3", "global.g1")), + ) - @Test def testTableTailMemo() { + @Test def testTableTailMemo(): Unit = checkMemo( TableTail(tab, 10L), subsetTable(tab.typ.copy(key = FastSeq()), "global.g1"), - Array(subsetTable(tab.typ, "row.3", "global.g1"))) - } + Array(subsetTable(tab.typ, "row.3", "global.g1")), + ) - @Test def testTableToValueApplyMemo() { + @Test def testTableToValueApplyMemo(): Unit = checkMemo( TableToValueApply(tab, ForceCountTable()), TInt64, - Array(tab.typ) + Array(tab.typ), ) - } - @Test def testMatrixToValueApplyMemo() { + @Test def testMatrixToValueApplyMemo(): Unit = checkMemo( MatrixToValueApply(mat, ForceCountMatrixTable()), TInt64, - Array(mat.typ) + Array(mat.typ), ) - } - @Test def testTableAggregateMemo() { - checkMemo(TableAggregate(tab, tableRefBoolean(tab.typ, "global.g1")), + @Test def testTableAggregateMemo(): Unit = + checkMemo( + TableAggregate(tab, tableRefBoolean(tab.typ, "global.g1")), TBoolean, - Array(subsetTable(tab.typ, "global.g1"), null)) - } + Array(subsetTable(tab.typ, "global.g1"), null), + ) - @Test def testMatrixAggregateMemo() { - checkMemo(MatrixAggregate(mat, matrixRefBoolean(mat.typ, "global.g1")), + @Test def testMatrixAggregateMemo(): Unit = + checkMemo( + MatrixAggregate(mat, matrixRefBoolean(mat.typ, "global.g1")), TBoolean, - Array(subsetMatrixTable(mat.typ, "global.g1", "NO_COL_KEY"), null)) - } + Array(subsetMatrixTable(mat.typ, "global.g1", "NO_COL_KEY"), null), + ) - @Test def testPipelineLetMemo() { + @Test def testPipelineLetMemo(): Unit = { val t = TStruct("a" -> TInt32) - checkMemo(RelationalLet("foo", NA(t), RelationalRef("foo", t)), TStruct.empty, Array(TStruct.empty, TStruct.empty)) + checkMemo( + RelationalLet("foo", NA(t), RelationalRef("foo", t)), + TStruct.empty, + Array(TStruct.empty, TStruct.empty), + ) } - @Test def testTableFilterRebuild() { - checkRebuild(TableFilter(tr, tableRefBoolean(tr.typ, "row.2")), subsetTable(tr.typ, "row.3"), + @Test def testTableFilterRebuild(): Unit = { + checkRebuild( + TableFilter(tr, tableRefBoolean(tr.typ, "row.2")), + subsetTable(tr.typ, "row.3"), (_: BaseIR, r: BaseIR) => { val tf = r.asInstanceOf[TableFilter] TypeCheck(ctx, tf.pred, PruneDeadFields.relationalTypeToEnv(tf.typ)) tf.child.typ == subsetTable(tr.typ, "row.3", "row.2") - }) + }, + ) } - @Test def testTableMapRowsRebuild() { + @Test def testTableMapRowsRebuild(): Unit = { val tmr = TableMapRows(tr, tableRefStruct(tr.typ, "row.2", "global.g1")) - checkRebuild(tmr, subsetTable(tmr.typ, "row.foo"), + checkRebuild( + tmr, + subsetTable(tmr.typ, "row.foo"), (_: BaseIR, r: BaseIR) => { val tmr = 
r.asInstanceOf[TableMapRows] TypeCheck(ctx, tmr.newRow, PruneDeadFields.relationalTypeToEnv(tmr.child.typ)) tmr.child.typ == subsetTable(tr.typ, "row.2", "global.g1", "row.3") - }) + }, + ) val tmr2 = TableMapRows(tr, tableRefStruct(tr.typ, "row.2", "global.g1")) - checkRebuild(tmr2, subsetTable(tmr2.typ, "row.foo", "NO_KEY"), + checkRebuild( + tmr2, + subsetTable(tmr2.typ, "row.foo", "NO_KEY"), (_: BaseIR, r: BaseIR) => { val tmr = r.asInstanceOf[TableMapRows] TypeCheck(ctx, tmr.newRow, PruneDeadFields.relationalTypeToEnv(tmr.child.typ)) - tmr.child.typ == subsetTable(tr.typ, "row.2", "global.g1", "row.3", "NO_KEY") // FIXME: remove row.3 when TableRead is fixed - }) + tmr.child.typ == subsetTable( + tr.typ, + "row.2", + "global.g1", + "row.3", + "NO_KEY", + ) // FIXME: remove row.3 when TableRead is fixed + }, + ) } - @Test def testTableMapGlobalsRebuild() { + @Test def testTableMapGlobalsRebuild(): Unit = { val tmg = TableMapGlobals(tr, tableRefStruct(tr.typ, "global.g1")) - checkRebuild(tmg, subsetTable(tmg.typ, "global.foo"), + checkRebuild( + tmg, + subsetTable(tmg.typ, "global.foo"), (_: BaseIR, r: BaseIR) => { val tmg = r.asInstanceOf[TableMapGlobals] TypeCheck(ctx, tmg.newGlobals, PruneDeadFields.relationalTypeToEnv(tmg.child.typ)) tmg.child.typ == subsetTable(tr.typ, "global.g1") - }) + }, + ) } - @Test def testTableLeftJoinRightDistinctRebuild() { + @Test def testTableLeftJoinRightDistinctRebuild(): Unit = { val tk1 = TableKeyBy(tab, Array("1")) val tk2 = TableKeyBy(tab, Array("3")) val tj = TableLeftJoinRightDistinct(tk1, tk2, "foo") - checkRebuild(tj, subsetTable(tj.typ, "row.1", "row.4"), - (_: BaseIR, r: BaseIR) => { - r.isInstanceOf[TableKeyBy] // no dependence on row.foo elides the join - }) + checkRebuild( + tj, + subsetTable(tj.typ, "row.1", "row.4"), + (_: BaseIR, r: BaseIR) => + r.isInstanceOf[TableKeyBy], // no dependence on row.foo elides the join + ) } - @Test def testTableIntervalJoinRebuild() { + @Test def testTableIntervalJoinRebuild(): Unit = { val tk1 = TableKeyBy(tab, Array("1")) val tk2 = TableKeyBy(tab, Array("3")) - val tj = TableIntervalJoin(tk1, tk2, "foo", product=false) + val tj = TableIntervalJoin(tk1, tk2, "foo", product = false) - checkRebuild(tj, subsetTable(tj.typ, "row.1", "row.4"), - (_: BaseIR, r: BaseIR) => { - r.isInstanceOf[TableKeyBy] // no dependence on row.foo elides the join - }) + checkRebuild( + tj, + subsetTable(tj.typ, "row.1", "row.4"), + (_: BaseIR, r: BaseIR) => + r.isInstanceOf[TableKeyBy], // no dependence on row.foo elides the join + ) } - @Test def testTableUnionRebuildUnifiesRowTypes() { - val mapExpr = InsertFields(Ref("row", tr.typ.rowType), - FastSeq("foo" -> tableRefBoolean(tr.typ, "row.3", "global.g1"))) + @Test def testTableUnionRebuildUnifiesRowTypes(): Unit = { + val mapExpr = InsertFields( + Ref("row", tr.typ.rowType), + FastSeq("foo" -> tableRefBoolean(tr.typ, "row.3", "global.g1")), + ) val tfilter = TableFilter( TableMapRows(tr, mapExpr), - tableRefBoolean(tr.typ, "row.2")) + tableRefBoolean(tr.typ, "row.2"), + ) val tmap = TableMapRows(tr, mapExpr) val tunion = TableUnion(FastSeq(tfilter, tmap)) - checkRebuild(tunion, subsetTable(tunion.typ, "row.foo"), + checkRebuild( + tunion, + subsetTable(tunion.typ, "row.foo"), (_: BaseIR, rebuilt: BaseIR) => { val tu = rebuilt.asInstanceOf[TableUnion] val tf = tu.childrenSeq(0) val tm = tu.childrenSeq(1) tf.typ.rowType == tm.typ.rowType && - tu.typ == subsetTable(tunion.typ, "row.foo", "global.g1") - }) + tu.typ == subsetTable(tunion.typ, "row.foo", "global.g1") + }, + ) } - 
@Test def testTableMultiWayZipJoinRebuildUnifiesRowTypes() { + @Test def testTableMultiWayZipJoinRebuildUnifiesRowTypes(): Unit = { val t1 = TableKeyBy(tab, Array("1")) val t2 = TableFilter(t1, tableRefBoolean(t1.typ, "row.2")) val t3 = TableFilter(t1, tableRefBoolean(t1.typ, "row.3")) val ts = Array(t1, t2, t3) val tmwzj = TableMultiWayZipJoin(ts, "data", "gbls") val childRType = subsetTable(t1.typ, "row.2", "global.g1") - checkRebuild(tmwzj, subsetTable(tmwzj.typ, "row.data.2", "global.gbls.g1"), + checkRebuild( + tmwzj, + subsetTable(tmwzj.typ, "row.data.2", "global.gbls.g1"), (_: BaseIR, rebuilt: BaseIR) => { val t = rebuilt.asInstanceOf[TableMultiWayZipJoin] - t.childrenSeq.forall { c => c.typ == childRType } - }) + t.childrenSeq.forall(c => c.typ == childRType) + }, + ) } - - @Test def testMatrixFilterColsRebuild() { + @Test def testMatrixFilterColsRebuild(): Unit = { val mfc = MatrixFilterCols(mr, matrixRefBoolean(mr.typ, "sa.c2")) - checkRebuild(mfc, subsetMatrixTable(mfc.typ, "global.g1"), + checkRebuild( + mfc, + subsetMatrixTable(mfc.typ, "global.g1"), (_: BaseIR, r: BaseIR) => { val mfc = r.asInstanceOf[MatrixFilterCols] TypeCheck(ctx, mfc.pred, PruneDeadFields.relationalTypeToEnv(mfc.child.typ)) mfc.child.asInstanceOf[MatrixRead].typ == subsetMatrixTable(mr.typ, "global.g1", "sa.c2") - } + }, ) } - @Test def testMatrixFilterEntriesRebuild() { + @Test def testMatrixFilterEntriesRebuild(): Unit = { val mfe = MatrixFilterEntries(mr, matrixRefBoolean(mr.typ, "sa.c2", "va.r2", "g.e1")) - checkRebuild(mfe, subsetMatrixTable(mfe.typ, "global.g1"), + checkRebuild( + mfe, + subsetMatrixTable(mfe.typ, "global.g1"), (_: BaseIR, r: BaseIR) => { val mfe = r.asInstanceOf[MatrixFilterEntries] TypeCheck(ctx, mfe.pred, PruneDeadFields.relationalTypeToEnv(mfe.child.typ)) - mfe.child.asInstanceOf[MatrixRead].typ == subsetMatrixTable(mr.typ, "global.g1", "sa.c2", "va.r2", "g.e1") - } + mfe.child.asInstanceOf[MatrixRead].typ == subsetMatrixTable( + mr.typ, + "global.g1", + "sa.c2", + "va.r2", + "g.e1", + ) + }, ) } - @Test def testMatrixMapRowsRebuild() { + @Test def testMatrixMapRowsRebuild(): Unit = { val mmr = MatrixMapRows( MatrixKeyRowsBy(mr, IndexedSeq.empty), - matrixRefStruct(mr.typ, "va.r2")) - checkRebuild(mmr, subsetMatrixTable(mmr.typ, "global.g1", "g.e1", "va.foo"), + matrixRefStruct(mr.typ, "va.r2"), + ) + checkRebuild( + mmr, + subsetMatrixTable(mmr.typ, "global.g1", "g.e1", "va.foo"), (_: BaseIR, r: BaseIR) => { val mmr = r.asInstanceOf[MatrixMapRows] TypeCheck(ctx, mmr.newRow, PruneDeadFields.relationalTypeToEnv(mmr.child.typ)) - mmr.child.asInstanceOf[MatrixKeyRowsBy].child.asInstanceOf[MatrixRead].typ == subsetMatrixTable(mr.typ, "global.g1", "va.r2", "g.e1") - } + mmr.child.asInstanceOf[MatrixKeyRowsBy].child.asInstanceOf[ + MatrixRead + ].typ == subsetMatrixTable(mr.typ, "global.g1", "va.r2", "g.e1") + }, ) } - @Test def testMatrixMapColsRebuild() { - val mmc = MatrixMapCols(mr, matrixRefStruct(mr.typ, "sa.c2"), - Some(FastSeq("foo"))) - checkRebuild(mmc, subsetMatrixTable(mmc.typ, "global.g1", "g.e1", "sa.foo"), + @Test def testMatrixMapColsRebuild(): Unit = { + val mmc = MatrixMapCols(mr, matrixRefStruct(mr.typ, "sa.c2"), Some(FastSeq("foo"))) + checkRebuild( + mmc, + subsetMatrixTable(mmc.typ, "global.g1", "g.e1", "sa.foo"), (_: BaseIR, r: BaseIR) => { val mmc = r.asInstanceOf[MatrixMapCols] TypeCheck(ctx, mmc.newCol, PruneDeadFields.relationalTypeToEnv(mmc.child.typ)) - mmc.child.asInstanceOf[MatrixRead].typ == subsetMatrixTable(mr.typ, "global.g1", "sa.c2", "g.e1") - } + 
mmc.child.asInstanceOf[MatrixRead].typ == subsetMatrixTable( + mr.typ, + "global.g1", + "sa.c2", + "g.e1", + ) + }, ) } - @Test def testMatrixMapEntriesRebuild() { + @Test def testMatrixMapEntriesRebuild(): Unit = { val mme = MatrixMapEntries(mr, matrixRefStruct(mr.typ, "sa.c2", "va.r2")) - checkRebuild(mme, subsetMatrixTable(mme.typ, "global.g1", "g.foo"), + checkRebuild( + mme, + subsetMatrixTable(mme.typ, "global.g1", "g.foo"), (_: BaseIR, r: BaseIR) => { val mme = r.asInstanceOf[MatrixMapEntries] TypeCheck(ctx, mme.newEntries, PruneDeadFields.relationalTypeToEnv(mme.child.typ)) - mme.child.asInstanceOf[MatrixRead].typ == subsetMatrixTable(mr.typ, "global.g1", "sa.c2", "va.r2") - } + mme.child.asInstanceOf[MatrixRead].typ == subsetMatrixTable( + mr.typ, + "global.g1", + "sa.c2", + "va.r2", + ) + }, ) } - @Test def testMatrixMapGlobalsRebuild() { + @Test def testMatrixMapGlobalsRebuild(): Unit = { val mmg = MatrixMapGlobals(mr, matrixRefStruct(mr.typ, "global.g1")) - checkRebuild(mmg, subsetMatrixTable(mmg.typ, "global.foo", "g.e1", "va.r2"), + checkRebuild( + mmg, + subsetMatrixTable(mmg.typ, "global.foo", "g.e1", "va.r2"), (_: BaseIR, r: BaseIR) => { val mmg = r.asInstanceOf[MatrixMapGlobals] TypeCheck(ctx, mmg.newGlobals, PruneDeadFields.relationalTypeToEnv(mmg.child.typ)) - mmg.child.asInstanceOf[MatrixRead].typ == subsetMatrixTable(mr.typ, "global.g1", "va.r2", "g.e1") - } + mmg.child.asInstanceOf[MatrixRead].typ == subsetMatrixTable( + mr.typ, + "global.g1", + "va.r2", + "g.e1", + ) + }, ) } - @Test def testMatrixAggregateRowsByKeyRebuild() { - val ma = MatrixAggregateRowsByKey(mr, matrixRefStruct(mr.typ, "sa.c2"), matrixRefStruct(mr.typ, "global.g1")) - checkRebuild(ma, subsetMatrixTable(ma.typ, "va.foo", "g.foo"), + @Test def testMatrixAggregateRowsByKeyRebuild(): Unit = { + val ma = MatrixAggregateRowsByKey( + mr, + matrixRefStruct(mr.typ, "sa.c2"), + matrixRefStruct(mr.typ, "global.g1"), + ) + checkRebuild( + ma, + subsetMatrixTable(ma.typ, "va.foo", "g.foo"), (_: BaseIR, r: BaseIR) => { val ma = r.asInstanceOf[MatrixAggregateRowsByKey] TypeCheck(ctx, ma.entryExpr, PruneDeadFields.relationalTypeToEnv(ma.child.typ)) ma.child.asInstanceOf[MatrixRead].typ == subsetMatrixTable(mr.typ, "global.g1", "sa.c2") - } + }, ) } - @Test def testMatrixAggregateColsByKeyRebuild() { - val ma = MatrixAggregateColsByKey(mr, matrixRefStruct(mr.typ, "va.r2"), matrixRefStruct(mr.typ, "global.g1")) - checkRebuild(ma, subsetMatrixTable(ma.typ, "g.foo", "sa.foo"), + @Test def testMatrixAggregateColsByKeyRebuild(): Unit = { + val ma = MatrixAggregateColsByKey( + mr, + matrixRefStruct(mr.typ, "va.r2"), + matrixRefStruct(mr.typ, "global.g1"), + ) + checkRebuild( + ma, + subsetMatrixTable(ma.typ, "g.foo", "sa.foo"), (_: BaseIR, r: BaseIR) => { val ma = r.asInstanceOf[MatrixAggregateColsByKey] TypeCheck(ctx, ma.entryExpr, PruneDeadFields.relationalTypeToEnv(ma.child.typ)) ma.child.asInstanceOf[MatrixRead].typ == subsetMatrixTable(mr.typ, "global.g1", "va.r2") - } + }, ) } - @Test def testMatrixUnionRowsRebuild() { + @Test def testMatrixUnionRowsRebuild(): Unit = { val mat2 = MatrixLiteral(mType.copy(colKey = FastSeq()), mat.tl) checkRebuild( - MatrixUnionRows(FastSeq(mat, MatrixMapCols(mat2, Ref("sa", mat2.typ.colType), Some(FastSeq("ck"))))), + MatrixUnionRows(FastSeq( + mat, + MatrixMapCols(mat2, Ref("sa", mat2.typ.colType), Some(FastSeq("ck"))), + )), mat.typ.copy(colKey = FastSeq()), - (_: BaseIR, r: BaseIR) => { + (_: BaseIR, r: BaseIR) => r.asInstanceOf[MatrixUnionRows].childrenSeq.forall { 
_.typ.colKey.isEmpty - } - }) + }, + ) } @Test def testMatrixUnionColsRebuild(): Unit = { - def getColField(name: String) = { + def getColField(name: String) = GetField(Ref("sa", mat.typ.colType), name) - } - def childrenMatch(matrixUnionCols: MatrixUnionCols): Boolean = { + def childrenMatch(matrixUnionCols: MatrixUnionCols): Boolean = matrixUnionCols.left.typ.colType == matrixUnionCols.right.typ.colType && matrixUnionCols.left.typ.entryType == matrixUnionCols.right.typ.entryType - } - val wrappedMat = MatrixMapCols(mat, - MakeStruct(IndexedSeq(("ck", getColField("ck")), ("c2", getColField("c2")), ("c3", getColField("c3")))), Some(FastSeq("ck")) + val wrappedMat = MatrixMapCols( + mat, + MakeStruct(IndexedSeq( + ("ck", getColField("ck")), + ("c2", getColField("c2")), + ("c3", getColField("c3")), + )), + Some(FastSeq("ck")), ) val wrappedMat2 = MatrixRename( - wrappedMat, - Map.empty, - Map.empty, - wrappedMat.typ.rowType.fieldNames.map(x => x -> (x + "_")).toMap, - Map.empty) + wrappedMat, + Map.empty, + Map.empty, + wrappedMat.typ.rowType.fieldNames.map(x => x -> (x + "_")).toMap, + Map.empty, + ) val mucBothSame = MatrixUnionCols(wrappedMat, wrappedMat2, "inner") checkRebuild(mucBothSame, mucBothSame.typ) - checkRebuild[MatrixUnionCols](mucBothSame, mucBothSame.typ.copy(colType = TStruct(("ck", TString), ("c2", TInt32))), (old, rebuilt) => - (old.typ.rowType == rebuilt.typ.rowType) && - (old.typ.globalType == rebuilt.typ.globalType) && - (rebuilt.typ.colType.fieldNames.toIndexedSeq == IndexedSeq("ck", "c2")) && - childrenMatch(rebuilt) + checkRebuild[MatrixUnionCols]( + mucBothSame, + mucBothSame.typ.copy(colType = TStruct(("ck", TString), ("c2", TInt32))), + (old, rebuilt) => + (old.typ.rowType == rebuilt.typ.rowType) && + (old.typ.globalType == rebuilt.typ.globalType) && + (rebuilt.typ.colType.fieldNames.toIndexedSeq == IndexedSeq("ck", "c2")) && + childrenMatch(rebuilt), ) - // Since `mat` is a MatrixLiteral, it won't be rebuilt, will keep all fields. But wrappedMat is a MatrixMapCols, so it will drop - // unrequested fields. This test would fail without upcasting in the MatrixUnionCols rebuild rule. + /* Since `mat` is a MatrixLiteral, it won't be rebuilt, will keep all fields. But wrappedMat is + * a MatrixMapCols, so it will drop */ + /* unrequested fields. This test would fail without upcasting in the MatrixUnionCols rebuild + * rule. 
*/ val muc2 = MatrixUnionCols(mat, wrappedMat2, "inner") - checkRebuild[MatrixUnionCols](muc2, muc2.typ.copy(colType = TStruct(("ck", TString))), (old, rebuilt) => - childrenMatch(rebuilt) + checkRebuild[MatrixUnionCols]( + muc2, + muc2.typ.copy(colType = TStruct(("ck", TString))), + (old, rebuilt) => + childrenMatch(rebuilt), ) } - @Test def testMatrixAnnotateRowsTableRebuild() { + @Test def testMatrixAnnotateRowsTableRebuild(): Unit = { val tl = TableLiteral(Interpret(MatrixRowsTable(mat), ctx), theHailClassLoader) - val mart = MatrixAnnotateRowsTable(mat, tl, "foo", product=false) - checkRebuild(mart, subsetMatrixTable(mart.typ), - (_: BaseIR, r: BaseIR) => { - r.isInstanceOf[MatrixLiteral] - }) + val mart = MatrixAnnotateRowsTable(mat, tl, "foo", product = false) + checkRebuild( + mart, + subsetMatrixTable(mart.typ), + (_: BaseIR, r: BaseIR) => + r.isInstanceOf[MatrixLiteral], + ) } val ts = TStruct( "a" -> TInt32, "b" -> TInt64, - "c" -> TString + "c" -> TString, ) def subsetTS(fields: String*): TStruct = ts.filterSet(fields.toSet)._1 - @Test def testNARebuild() { - checkRebuild(NA(ts), subsetTS("b"), + @Test def testNARebuild(): Unit = { + checkRebuild( + NA(ts), + subsetTS("b"), (_: BaseIR, r: BaseIR) => { val na = r.asInstanceOf[NA] na.typ == subsetTS("b") - }) + }, + ) } - @Test def testIfRebuild() { - checkRebuild(If(True(), NA(ts), NA(ts)), subsetTS("b"), + @Test def testIfRebuild(): Unit = { + checkRebuild( + If(True(), NA(ts), NA(ts)), + subsetTS("b"), (_: BaseIR, r: BaseIR) => { val ir = r.asInstanceOf[If] ir.cnsq.typ == subsetTS("b") && ir.altr.typ == subsetTS("b") - }) + }, + ) } @Test def testSwitchRebuild(): Unit = - checkRebuild[IR](Switch(I32(0), NA(ts), FastSeq(NA(ts))), subsetTS("b"), { - case (_, Switch(_, default, cases)) => - default.typ == subsetTS("b") && + checkRebuild[IR]( + Switch(I32(0), NA(ts), FastSeq(NA(ts))), + subsetTS("b"), + { + case (_, Switch(_, default, cases)) => + default.typ == subsetTS("b") && cases(0).typ == subsetTS("b") - } + }, ) - @Test def testCoalesceRebuild() { - checkRebuild(Coalesce(FastSeq(NA(ts), NA(ts))), subsetTS("b"), - (_: BaseIR, r: BaseIR) => { - r.children.forall(_.typ == subsetTS("b")) - }) + @Test def testCoalesceRebuild(): Unit = { + checkRebuild( + Coalesce(FastSeq(NA(ts), NA(ts))), + subsetTS("b"), + (_: BaseIR, r: BaseIR) => + r.children.forall(_.typ == subsetTS("b")), + ) } - @Test def testLetRebuild() { - checkRebuild(Let(FastSeq("x" -> NA(ts)), Ref("x", ts)), subsetTS("b"), + @Test def testLetRebuild(): Unit = { + checkRebuild( + Let(FastSeq("x" -> NA(ts)), Ref("x", ts)), + subsetTS("b"), (_: BaseIR, r: BaseIR) => { - val ir = r.asInstanceOf[Let] - ir.bindings.head._2.typ == subsetTS("b") - }) + val ir = r.asInstanceOf[Block] + ir.bindings.head.value.typ == subsetTS("b") + }, + ) } - @Test def testAggLetRebuild() { - checkRebuild(AggLet("foo", NA(ref.typ), - ApplyAggOp(FastSeq(), FastSeq( - SelectFields(Ref("foo", ref.typ), IndexedSeq("a"))), - AggSignature(Collect(), FastSeq(), FastSeq(ref.typ))), false), TArray(subsetTS("a")), - (_: BaseIR, r: BaseIR) => { - val ir = r.asInstanceOf[AggLet] - ir.value.typ == subsetTS("a") - }) + @Test def testAggLetRebuild(): Unit = { + checkRebuild( + AggLet( + "foo", + NA(ref.typ), + ApplyAggOp( + FastSeq(), + FastSeq( + SelectFields(Ref("foo", ref.typ), IndexedSeq("a")) + ), + AggSignature(Collect(), FastSeq(), FastSeq(ref.typ)), + ), + false, + ), + TArray(subsetTS("a")), + (_: BaseIR, r: BaseIR) => + r match { + case Block(Seq(Binding(_, value, Scope.AGG)), _) => + value.typ == 
subsetTS("a") + }, + ) } - @Test def testMakeArrayRebuild() { - checkRebuild(MakeArray(IndexedSeq(NA(ts)), TArray(ts)), TArray(subsetTS("b")), + @Test def testMakeArrayRebuild(): Unit = { + checkRebuild( + MakeArray(IndexedSeq(NA(ts)), TArray(ts)), + TArray(subsetTS("b")), (_: BaseIR, r: BaseIR) => { val ir = r.asInstanceOf[MakeArray] ir.args.head.typ == subsetTS("b") - }) + }, + ) } - @Test def testStreamTakeRebuild() { - checkRebuild(StreamTake(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), I32(2)), TStream(subsetTS("b")), - (_: BaseIR, r: BaseIR) => { - val ir = r.asInstanceOf[StreamTake] - ir.a.typ == TStream(subsetTS("b")) - }) + @Test def testStreamTakeRebuild(): Unit = { + checkRebuild( + StreamTake(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), I32(2)), + TStream(subsetTS("b")), + (_: BaseIR, r: BaseIR) => { + val ir = r.asInstanceOf[StreamTake] + ir.a.typ == TStream(subsetTS("b")) + }, + ) } - @Test def testStreamDropRebuild() { - checkRebuild(StreamDrop(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), I32(2)), TStream(subsetTS("b")), - (_: BaseIR, r: BaseIR) => { - val ir = r.asInstanceOf[StreamDrop] - ir.a.typ == TStream(subsetTS("b")) - }) + @Test def testStreamDropRebuild(): Unit = { + checkRebuild( + StreamDrop(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), I32(2)), + TStream(subsetTS("b")), + (_: BaseIR, r: BaseIR) => { + val ir = r.asInstanceOf[StreamDrop] + ir.a.typ == TStream(subsetTS("b")) + }, + ) } - @Test def testStreamMapRebuild() { - checkRebuild(StreamMap(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), "x", Ref("x", ts)), TStream(subsetTS("b")), + @Test def testStreamMapRebuild(): Unit = { + checkRebuild( + StreamMap(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), "x", Ref("x", ts)), + TStream(subsetTS("b")), (_: BaseIR, r: BaseIR) => { val ir = r.asInstanceOf[StreamMap] ir.a.typ == TStream(subsetTS("b")) - }) + }, + ) } - @Test def testStreamGroupedRebuild() { - checkRebuild(StreamGrouped(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), I32(2)), TStream(TStream(subsetTS("b"))), + @Test def testStreamGroupedRebuild(): Unit = { + checkRebuild( + StreamGrouped(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), I32(2)), + TStream(TStream(subsetTS("b"))), (_: BaseIR, r: BaseIR) => { val ir = r.asInstanceOf[StreamGrouped] ir.a.typ == TStream(subsetTS("b")) - }) + }, + ) } - @Test def testStreamGroupByKeyRebuild() { - checkRebuild(StreamGroupByKey(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), FastSeq("a"), false), TStream(TStream(subsetTS("b"))), - (_: BaseIR, r: BaseIR) => { - val ir = r.asInstanceOf[StreamGroupByKey] - ir.a.typ == TStream(subsetTS("a", "b")) - }) + @Test def testStreamGroupByKeyRebuild(): Unit = { + checkRebuild( + StreamGroupByKey(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), FastSeq("a"), false), + TStream(TStream(subsetTS("b"))), + (_: BaseIR, r: BaseIR) => { + val ir = r.asInstanceOf[StreamGroupByKey] + ir.a.typ == TStream(subsetTS("a", "b")) + }, + ) } - @Test def testStreamMergeRebuild() { + @Test def testStreamMergeRebuild(): Unit = { checkRebuild( - StreamMultiMerge(IndexedSeq(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), MakeStream(IndexedSeq(NA(ts)), TStream(ts))), FastSeq("a")), + StreamMultiMerge( + IndexedSeq( + MakeStream(IndexedSeq(NA(ts)), TStream(ts)), + MakeStream(IndexedSeq(NA(ts)), TStream(ts)), + ), + FastSeq("a"), + ), TStream(subsetTS("b")), - (_: BaseIR, r: BaseIR) => r.typ == TStream(subsetTS("a", "b"))) + (_: BaseIR, r: BaseIR) => r.typ == TStream(subsetTS("a", "b")), + ) } - @Test def testStreamZipRebuild() { + @Test def testStreamZipRebuild(): Unit = { 
val a2 = st.deepCopy() val a3 = st.deepCopy() - for (b <- Array(ArrayZipBehavior.ExtendNA, ArrayZipBehavior.TakeMinLength, ArrayZipBehavior.AssertSameLength)) { - - checkRebuild(StreamZip( + for ( + b <- Array( + ArrayZipBehavior.ExtendNA, + ArrayZipBehavior.TakeMinLength, + ArrayZipBehavior.AssertSameLength, + ) + ) { + + checkRebuild( + StreamZip( + FastSeq(st, a2, a3), + FastSeq("foo", "bar", "baz"), + Let( + FastSeq( + "foo1" -> GetField(Ref("foo", ref.typ), "b"), + "bar2" -> GetField(Ref("bar", ref.typ), "a"), + ), + False(), + ), + b, + ), + TStream(TBoolean), + (_: BaseIR, r: BaseIR) => r.asInstanceOf[StreamZip].as.length == 3, + ) + } + checkRebuild( + StreamZip( FastSeq(st, a2, a3), FastSeq("foo", "bar", "baz"), Let( FastSeq( "foo1" -> GetField(Ref("foo", ref.typ), "b"), - "bar2" -> GetField(Ref("bar", ref.typ), "a") + "bar2" -> GetField(Ref("bar", ref.typ), "a"), ), - False() - ), b), - TStream(TBoolean), - (_: BaseIR, r: BaseIR) => r.asInstanceOf[StreamZip].as.length == 3) - } - checkRebuild(StreamZip( - FastSeq(st, a2, a3), - FastSeq("foo", "bar", "baz"), - Let( - FastSeq( - "foo1" -> GetField(Ref("foo", ref.typ), "b"), - "bar2" -> GetField(Ref("bar", ref.typ), "a") + False(), ), - False() + ArrayZipBehavior.AssumeSameLength, ), - ArrayZipBehavior.AssumeSameLength), TStream(TBoolean), - (_: BaseIR, r: BaseIR) => r.asInstanceOf[StreamZip].as.length == 2) + (_: BaseIR, r: BaseIR) => r.asInstanceOf[StreamZip].as.length == 2, + ) } - @Test def testStreamFlatmapRebuild() { - checkRebuild(StreamFlatMap(MakeStream(IndexedSeq(NA(ts)), TStream(ts)), "x", MakeStream(IndexedSeq(Ref("x", ts)), TStream(ts))), + @Test def testStreamFlatmapRebuild(): Unit = { + checkRebuild( + StreamFlatMap( + MakeStream(IndexedSeq(NA(ts)), TStream(ts)), + "x", + MakeStream(IndexedSeq(Ref("x", ts)), TStream(ts)), + ), TStream(subsetTS("b")), (_: BaseIR, r: BaseIR) => { val ir = r.asInstanceOf[StreamFlatMap] ir.a.typ == TStream(subsetTS("b")) - }) + }, + ) } - @Test def testMakeStructRebuild() { - checkRebuild(MakeStruct(IndexedSeq("a" -> NA(TInt32), "b" -> NA(TInt64), "c" -> NA(TString))), subsetTS("b"), - (_: BaseIR, r: BaseIR) => { - r == MakeStruct(IndexedSeq("b" -> NA(TInt64))) - }) + @Test def testMakeStructRebuild(): Unit = { + checkRebuild( + MakeStruct(IndexedSeq("a" -> NA(TInt32), "b" -> NA(TInt64), "c" -> NA(TString))), + subsetTS("b"), + (_: BaseIR, r: BaseIR) => + r == MakeStruct(IndexedSeq("b" -> NA(TInt64))), + ) } - @Test def testInsertFieldsRebuild() { - checkRebuild(InsertFields(NA(TStruct("a" -> TInt32)), IndexedSeq("b" -> NA(TInt64), "c" -> NA(TString))), + @Test def testInsertFieldsRebuild(): Unit = { + checkRebuild( + InsertFields(NA(TStruct("a" -> TInt32)), IndexedSeq("b" -> NA(TInt64), "c" -> NA(TString))), subsetTS("b"), (_: BaseIR, r: BaseIR) => { val ir = r.asInstanceOf[InsertFields] ir.fields == IndexedSeq( "b" -> NA(TInt64) ) - }) + }, + ) - // Example needs to have field insertion that overwrites an unrequested field with a different type. - val insertF = InsertFields(Ref("foo", TStruct(("a", TInt32), ("b", TInt32))), - IndexedSeq(("a", I64(8))) + /* Example needs to have field insertion that overwrites an unrequested field with a different + * type. 
*/ + val insertF = + InsertFields(Ref("foo", TStruct(("a", TInt32), ("b", TInt32))), IndexedSeq(("a", I64(8)))) + checkRebuild[InsertFields]( + insertF, + TStruct(("b", TInt32)), + (old, rebuilt) => + PruneDeadFields.isSupertype(rebuilt.typ, old.typ), ) - checkRebuild[InsertFields](insertF, TStruct(("b", TInt32)), (old, rebuilt) => { - PruneDeadFields.isSupertype(rebuilt.typ, old.typ) - }) } - @Test def testMakeTupleRebuild() { - checkRebuild(MakeTuple(IndexedSeq(0 -> I32(1), 1 -> F64(1.0), 2 -> NA(TString))), + @Test def testMakeTupleRebuild(): Unit = { + checkRebuild( + MakeTuple(IndexedSeq(0 -> I32(1), 1 -> F64(1.0), 2 -> NA(TString))), TTuple(FastSeq(TupleField(2, TString))), - (_: BaseIR, r: BaseIR) => { - r == MakeTuple(IndexedSeq(2 -> NA(TString))) - }) + (_: BaseIR, r: BaseIR) => + r == MakeTuple(IndexedSeq(2 -> NA(TString))), + ) } - @Test def testSelectFieldsRebuild() { - checkRebuild(SelectFields(NA(ts), IndexedSeq("a", "b")), + @Test def testSelectFieldsRebuild(): Unit = { + checkRebuild( + SelectFields(NA(ts), IndexedSeq("a", "b")), subsetTS("b"), (_: BaseIR, r: BaseIR) => { val ir = r.asInstanceOf[SelectFields] ir.fields == IndexedSeq("b") - }) + }, + ) } - @Test def testCastRenameRebuild() { + @Test def testCastRenameRebuild(): Unit = { checkRebuild( CastRename( NA(TArray(TStruct("x" -> TInt32, "y" -> TString))), - TArray(TStruct("y" -> TInt32, "z" -> TString))), + TArray(TStruct("y" -> TInt32, "z" -> TString)), + ), TArray(TStruct("z" -> TString)), (_: BaseIR, r: BaseIR) => { val ir = r.asInstanceOf[CastRename] ir._typ == TArray(TStruct("z" -> TString)) - }) + }, + ) } - val ndArrayTS = MakeNDArray(MakeArray(ArrayBuffer(NA(ts)), TArray(ts)), MakeTuple(IndexedSeq((0, I64(1l)))), True(), ErrorIDs.NO_ERROR) + val ndArrayTS = MakeNDArray( + MakeArray(ArrayBuffer(NA(ts)), TArray(ts)), + MakeTuple(IndexedSeq((0, I64(1L)))), + True(), + ErrorIDs.NO_ERROR, + ) - @Test def testNDArrayMapRebuild() { - checkRebuild(NDArrayMap(ndArrayTS, "x", Ref("x", ts)), TNDArray(subsetTS("b"), Nat(1)), + @Test def testNDArrayMapRebuild(): Unit = { + checkRebuild( + NDArrayMap(ndArrayTS, "x", Ref("x", ts)), + TNDArray(subsetTS("b"), Nat(1)), (_: BaseIR, r: BaseIR) => { val ir = r.asInstanceOf[NDArrayMap] - // Even though the type I requested wasn't required, NDArrays always have a required element type. + /* Even though the type I requested wasn't required, NDArrays always have a required element + * type. 
*/ ir.nd.typ == TNDArray(TStruct(("b", TInt64)), Nat(1)) - }) + }, + ) } @Test def testNDArrayMap2Rebuild(): Unit = { - checkRebuild(NDArrayMap2(ndArrayTS, ndArrayTS, "left", "right", Ref("left", ts), ErrorIDs.NO_ERROR), TNDArray(subsetTS("b"), Nat(1)), + checkRebuild( + NDArrayMap2(ndArrayTS, ndArrayTS, "left", "right", Ref("left", ts), ErrorIDs.NO_ERROR), + TNDArray(subsetTS("b"), Nat(1)), (_: BaseIR, r: BaseIR) => { val ir = r.asInstanceOf[NDArrayMap2] ir.l.typ == TNDArray(TStruct(("b", TInt64)), Nat(1)) ir.r.typ == TNDArray(TStruct.empty, Nat(1)) - }) - checkRebuild(NDArrayMap2(ndArrayTS, ndArrayTS, "left", "right", Ref("right", ts), ErrorIDs.NO_ERROR), TNDArray(subsetTS("b"), Nat(1)), + }, + ) + checkRebuild( + NDArrayMap2(ndArrayTS, ndArrayTS, "left", "right", Ref("right", ts), ErrorIDs.NO_ERROR), + TNDArray(subsetTS("b"), Nat(1)), (_: BaseIR, r: BaseIR) => { val ir = r.asInstanceOf[NDArrayMap2] ir.l.typ == TNDArray(TStruct.empty, Nat(1)) ir.r.typ == TNDArray(TStruct(("b", TInt64)), Nat(1)) - }) + }, + ) } - @Test def testCDARebuild() { + @Test def testCDARebuild(): Unit = { val ctxT = TStruct("a" -> TInt32, "b" -> TString) val globT = TStruct("c" -> TInt64, "d" -> TFloat64) val x = CollectDistributedArray( @@ -1399,79 +1921,106 @@ class PruneSuite extends HailSuite { NA(globT), "ctx", "glob", - MakeTuple.ordered(FastSeq(Ref("ctx", ctxT), Ref("glob", globT))), NA(TString), "test") + MakeTuple.ordered(FastSeq(Ref("ctx", ctxT), Ref("glob", globT))), + NA(TString), + "test", + ) val selectedCtxT = ctxT.typeAfterSelectNames(Array("a")) val selectedGlobT = globT.typeAfterSelectNames(Array("c")) - checkRebuild(x, TArray(TTuple(selectedCtxT, selectedGlobT)), (_: BaseIR, r: BaseIR) => { - r == CollectDistributedArray( - NA(TStream(selectedCtxT)), - NA(selectedGlobT), - "ctx", - "glob", - MakeTuple.ordered(FastSeq(Ref("ctx", selectedCtxT), Ref("glob", selectedGlobT))), NA(TString), "test") - }) + checkRebuild( + x, + TArray(TTuple(selectedCtxT, selectedGlobT)), + (_: BaseIR, r: BaseIR) => { + r == CollectDistributedArray( + NA(TStream(selectedCtxT)), + NA(selectedGlobT), + "ctx", + "glob", + MakeTuple.ordered(FastSeq(Ref("ctx", selectedCtxT), Ref("glob", selectedGlobT))), + NA(TString), + "test", + ) + }, + ) } - - @Test def testTableAggregateRebuild() { + @Test def testTableAggregateRebuild(): Unit = { val ta = TableAggregate(tr, tableRefBoolean(tr.typ, "row.2")) - checkRebuild(ta, TBoolean, + checkRebuild( + ta, + TBoolean, (_: BaseIR, r: BaseIR) => { val ir = r.asInstanceOf[TableAggregate] ir.child.typ == subsetTable(tr.typ, "row.2") - }) + }, + ) } - @Test def testTableCollectRebuild() { + @Test def testTableCollectRebuild(): Unit = { val tc = TableCollect(TableKeyBy(tab, FastSeq())) - checkRebuild(tc, TStruct("global" -> TStruct("g1" -> TInt32)), - (_: BaseIR, r: BaseIR) => { - r.asInstanceOf[MakeStruct].fields.head._2.isInstanceOf[TableGetGlobals] - }) + checkRebuild( + tc, + TStruct("global" -> TStruct("g1" -> TInt32)), + (_: BaseIR, r: BaseIR) => + r.asInstanceOf[MakeStruct].fields.head._2.isInstanceOf[TableGetGlobals], + ) - checkRebuild(tc, TStruct.empty, - (_: BaseIR, r: BaseIR) => { - r == MakeStruct(IndexedSeq()) - }) + checkRebuild( + tc, + TStruct.empty, + (_: BaseIR, r: BaseIR) => + r == MakeStruct(IndexedSeq()), + ) } - @Test def testMatrixAggregateRebuild() { + @Test def testMatrixAggregateRebuild(): Unit = { val ma = MatrixAggregate(mr, matrixRefBoolean(mr.typ, "va.r2")) - checkRebuild(ma, TBoolean, + checkRebuild( + ma, + TBoolean, (_: BaseIR, r: BaseIR) => { val ir = 
r.asInstanceOf[MatrixAggregate] ir.child.typ == subsetMatrixTable(mr.typ, "va.r2") - }) + }, + ) } - @Test def testPipelineLetRebuild() { + @Test def testPipelineLetRebuild(): Unit = { val t = TStruct("a" -> TInt32) - checkRebuild(RelationalLet("foo", NA(t), RelationalRef("foo", t)), TStruct.empty, - (_: BaseIR, r: BaseIR) => { - r.asInstanceOf[RelationalLet].body == RelationalRef("foo", TStruct.empty) - }) + checkRebuild( + RelationalLet("foo", NA(t), RelationalRef("foo", t)), + TStruct.empty, + (_: BaseIR, r: BaseIR) => + r.asInstanceOf[RelationalLet].body == RelationalRef("foo", TStruct.empty), + ) } - @Test def testPipelineLetTableRebuild() { + @Test def testPipelineLetTableRebuild(): Unit = { val t = TStruct("a" -> TInt32) - checkRebuild(RelationalLetTable("foo", NA(t), TableMapGlobals(tab, RelationalRef("foo", t))), + checkRebuild( + RelationalLetTable("foo", NA(t), TableMapGlobals(tab, RelationalRef("foo", t))), tab.typ.copy(globalType = TStruct.empty), - (_: BaseIR, r: BaseIR) => { - r.asInstanceOf[RelationalLetTable].body.asInstanceOf[TableMapGlobals].newGlobals == RelationalRef("foo", TStruct.empty) - }) + (_: BaseIR, r: BaseIR) => + r.asInstanceOf[RelationalLetTable].body.asInstanceOf[ + TableMapGlobals + ].newGlobals == RelationalRef("foo", TStruct.empty), + ) } - @Test def testPipelineLetMatrixTableRebuild() { + @Test def testPipelineLetMatrixTableRebuild(): Unit = { val t = TStruct("a" -> TInt32) - checkRebuild(RelationalLetMatrixTable("foo", NA(t), MatrixMapGlobals(mat, RelationalRef("foo", t))), + checkRebuild( + RelationalLetMatrixTable("foo", NA(t), MatrixMapGlobals(mat, RelationalRef("foo", t))), mat.typ.copy(globalType = TStruct.empty), - (_: BaseIR, r: BaseIR) => { - r.asInstanceOf[RelationalLetMatrixTable].body.asInstanceOf[MatrixMapGlobals].newGlobals == RelationalRef("foo", TStruct.empty) - }) + (_: BaseIR, r: BaseIR) => + r.asInstanceOf[RelationalLetMatrixTable].body.asInstanceOf[ + MatrixMapGlobals + ].newGlobals == RelationalRef("foo", TStruct.empty), + ) } - @Test def testIfUnification() { + @Test def testIfUnification(): Unit = { val pred = False() val t = TStruct("a" -> TInt32, "b" -> TInt32) val pruneT = TStruct("a" -> TInt32) @@ -1485,8 +2034,12 @@ class PruneSuite extends HailSuite { .bind(ifIR, pruneT) // should run without error! 
- PruneDeadFields.rebuildIR(ctx, ifIR, BindingEnv.empty[Type].bindEval("a", t), - PruneDeadFields.RebuildMutableState(memo, mutable.HashMap.empty)) + PruneDeadFields.rebuildIR( + ctx, + ifIR, + BindingEnv.empty[Type].bindEval("a", t), + PruneDeadFields.RebuildMutableState(memo, mutable.HashMap.empty), + ) } @DataProvider(name = "supertypePairs") @@ -1495,82 +2048,142 @@ class PruneSuite extends HailSuite { Array( TStruct( "a" -> TInt32, - "b" -> TArray(TInt64)), + "b" -> TArray(TInt64), + ), TStruct( "a" -> TInt32, - "b" -> TArray(TInt64))), - Array(TSet(TString), TSet(TString)) + "b" -> TArray(TInt64), + ), + ), + Array(TSet(TString), TSet(TString)), ) @Test(dataProvider = "supertypePairs") - def testIsSupertypeRequiredness(t1: Type, t2: Type) = { - assert(PruneDeadFields.isSupertype(t1, t2), + def testIsSupertypeRequiredness(t1: Type, t2: Type) = + assert( + PruneDeadFields.isSupertype(t1, t2), s"""Failure, supertype relationship not met - | supertype: ${ t1.toPrettyString(true) } - | subtype: ${ t2.toPrettyString(true) }""".stripMargin) - } + | supertype: ${t1.toPrettyString(true)} + | subtype: ${t2.toPrettyString(true)}""".stripMargin, + ) - @Test def testApplyScanOp() { + @Test def testApplyScanOp(): Unit = { val x = Ref("x", TInt32) val y = Ref("y", TInt32) val collectScan = ApplyScanOp( FastSeq(), FastSeq(MakeStruct(FastSeq(("x", x), ("y", y)))), - AggSignature(Collect(), FastSeq(), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32)))) - checkRebuild(collectScan, TArray(TStruct("y" -> TInt32)), { (_: BaseIR, reb: BaseIR) => reb.typ == TArray(TStruct("y" -> TInt32))}) + AggSignature(Collect(), FastSeq(), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32))), + ) + checkRebuild( + collectScan, + TArray(TStruct("y" -> TInt32)), + (_: BaseIR, reb: BaseIR) => reb.typ == TArray(TStruct("y" -> TInt32)), + ) val takeScan = ApplyScanOp( FastSeq(I32(1)), FastSeq(MakeStruct(FastSeq(("x", x), ("y", y)))), - AggSignature(Take(), FastSeq(TInt32), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32)))) - checkRebuild(takeScan, TArray(TStruct("y" -> TInt32)), { (_: BaseIR, reb: BaseIR) => reb.typ == TArray(TStruct("y" -> TInt32))}) + AggSignature(Take(), FastSeq(TInt32), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32))), + ) + checkRebuild( + takeScan, + TArray(TStruct("y" -> TInt32)), + (_: BaseIR, reb: BaseIR) => reb.typ == TArray(TStruct("y" -> TInt32)), + ) val prevnn = ApplyScanOp( FastSeq(), FastSeq(MakeStruct(FastSeq(("x", x), ("y", y)))), - AggSignature(PrevNonnull(), FastSeq(), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32)))) - checkRebuild(prevnn, TStruct("y" -> TInt32), { (_: BaseIR, reb: BaseIR) => reb.typ == TStruct("y" -> TInt32)}) + AggSignature(PrevNonnull(), FastSeq(), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32))), + ) + checkRebuild( + prevnn, + TStruct("y" -> TInt32), + (_: BaseIR, reb: BaseIR) => reb.typ == TStruct("y" -> TInt32), + ) val takeByScan = ApplyScanOp( FastSeq(I32(1)), FastSeq(MakeStruct(FastSeq(("x", x), ("y", y))), MakeStruct(FastSeq(("x", x), ("y", y)))), - AggSignature(TakeBy(), FastSeq(TInt32), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32), TStruct("x" -> TInt32, "y" -> TInt32)))) - checkRebuild(takeByScan, TArray(TStruct("y" -> TInt32)), { (_: BaseIR, reb: BaseIR) => - val s = reb.asInstanceOf[ApplyScanOp] - s.seqOpArgs == FastSeq(MakeStruct(FastSeq(("y", y))), MakeStruct(FastSeq(("x", x), ("y", y))))}) + AggSignature( + TakeBy(), + FastSeq(TInt32), + FastSeq(TStruct("x" -> TInt32, "y" -> TInt32), TStruct("x" -> TInt32, "y" -> TInt32)), + ), + ) + checkRebuild( + takeByScan, + 
TArray(TStruct("y" -> TInt32)), + { (_: BaseIR, reb: BaseIR) => + val s = reb.asInstanceOf[ApplyScanOp] + s.seqOpArgs == FastSeq( + MakeStruct(FastSeq(("y", y))), + MakeStruct(FastSeq(("x", x), ("y", y))), + ) + }, + ) } - @Test def testApplyAggOp() { + @Test def testApplyAggOp(): Unit = { val x = Ref("x", TInt32) val y = Ref("y", TInt32) val collectAgg = ApplyAggOp( FastSeq(), FastSeq(MakeStruct(FastSeq(("x", x), ("y", y)))), - AggSignature(Collect(), FastSeq(), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32)))) - checkRebuild(collectAgg, TArray(TStruct("y" -> TInt32)), { (_: BaseIR, reb: BaseIR) => reb.typ == TArray(TStruct("y" -> TInt32))}) + AggSignature(Collect(), FastSeq(), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32))), + ) + checkRebuild( + collectAgg, + TArray(TStruct("y" -> TInt32)), + (_: BaseIR, reb: BaseIR) => reb.typ == TArray(TStruct("y" -> TInt32)), + ) val takeAgg = ApplyAggOp( FastSeq(I32(1)), FastSeq(MakeStruct(FastSeq(("x", x), ("y", y)))), - AggSignature(Take(), FastSeq(TInt32), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32)))) - checkRebuild(takeAgg, TArray(TStruct("y" -> TInt32)), { (_: BaseIR, reb: BaseIR) => reb.typ == TArray(TStruct("y" -> TInt32))}) + AggSignature(Take(), FastSeq(TInt32), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32))), + ) + checkRebuild( + takeAgg, + TArray(TStruct("y" -> TInt32)), + (_: BaseIR, reb: BaseIR) => reb.typ == TArray(TStruct("y" -> TInt32)), + ) val prevnn = ApplyAggOp( FastSeq(), FastSeq(MakeStruct(FastSeq(("x", x), ("y", y)))), - AggSignature(PrevNonnull(), FastSeq(), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32)))) - checkRebuild(prevnn, TStruct("y" -> TInt32), { (_: BaseIR, reb: BaseIR) => reb.typ == TStruct("y" -> TInt32)}) + AggSignature(PrevNonnull(), FastSeq(), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32))), + ) + checkRebuild( + prevnn, + TStruct("y" -> TInt32), + (_: BaseIR, reb: BaseIR) => reb.typ == TStruct("y" -> TInt32), + ) val takeByAgg = ApplyAggOp( FastSeq(I32(1)), FastSeq(MakeStruct(FastSeq(("x", x), ("y", y))), MakeStruct(FastSeq(("x", x), ("y", y)))), - AggSignature(TakeBy(), FastSeq(TInt32), FastSeq(TStruct("x" -> TInt32, "y" -> TInt32), TStruct("x" -> TInt32, "y" -> TInt32)))) - checkRebuild(takeByAgg, TArray(TStruct("y" -> TInt32)), { (_: BaseIR, reb: BaseIR) => - val a = reb.asInstanceOf[ApplyAggOp] - a.seqOpArgs == FastSeq(MakeStruct(FastSeq(("y", y))), MakeStruct(FastSeq(("x", x), ("y", y))))}) + AggSignature( + TakeBy(), + FastSeq(TInt32), + FastSeq(TStruct("x" -> TInt32, "y" -> TInt32), TStruct("x" -> TInt32, "y" -> TInt32)), + ), + ) + checkRebuild( + takeByAgg, + TArray(TStruct("y" -> TInt32)), + { (_: BaseIR, reb: BaseIR) => + val a = reb.asInstanceOf[ApplyAggOp] + a.seqOpArgs == FastSeq( + MakeStruct(FastSeq(("y", y))), + MakeStruct(FastSeq(("x", x), ("y", y))), + ) + }, + ) } - @Test def testStreamFold2() { + @Test def testStreamFold2(): Unit = { val eltType = TStruct("a" -> TInt32, "b" -> TInt32) val accum1Type = TStruct("c" -> TInt32, "d" -> TInt32) @@ -1581,8 +2194,11 @@ class PruneSuite extends HailSuite { FastSeq( MakeStruct(FastSeq( "c" -> GetField(Ref("elt", eltType), "a"), - "d" -> GetField(Ref("1", accum1Type), "c")))), - Ref("1", TStruct("c" -> TInt32, "d" -> TInt32))) + "d" -> GetField(Ref("1", accum1Type), "c"), + )) + ), + Ref("1", TStruct("c" -> TInt32, "d" -> TInt32)), + ) def checker(original: IR, rebuilt: IR): Boolean = { val r = rebuilt.asInstanceOf[StreamFold2] diff --git a/hail/src/test/scala/is/hail/expr/ir/RandomSuite.scala b/hail/src/test/scala/is/hail/expr/ir/RandomSuite.scala 
index 844f52011e8..b8c055351de 100644 --- a/hail/src/test/scala/is/hail/expr/ir/RandomSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/RandomSuite.scala @@ -2,8 +2,11 @@ package is.hail.expr.ir import is.hail.HailSuite import is.hail.asm4s._ -import is.hail.types.physical.stypes.concrete.{SCanonicalRNGStateSettable, SCanonicalRNGStateValue, SRNGState, SRNGStateStaticSizeValue} +import is.hail.types.physical.stypes.concrete.{ + SCanonicalRNGStateSettable, SCanonicalRNGStateValue, SRNGState, SRNGStateStaticSizeValue, +} import is.hail.utils.FastSeq + import org.apache.commons.math3.distribution.ChiSquaredDistribution import org.testng.annotations.Test @@ -14,15 +17,17 @@ class RandomSuite extends HailSuite { Array(0x0L, 0x0L, 0x0L, 0x0L), Array(0x0L, 0x0L), Array(0x0L, 0x0L, 0x0L, 0x0L), - Array(0x09218EBDE6C85537L, 0x55941F5266D86105L, 0x4BD25E16282434DCL, 0xEE29EC846BD2E40BL) - ), ( - Array(0x1716151413121110L, 0x1F1E1D1C1B1A1918L, 0x2726252423222120L, 0x2F2E2D2C2B2A2928L), - Array(0x0706050403020100L, 0x0F0E0D0C0B0A0908L), - Array(0xF8F9FAFBFCFDFEFFL, 0xF0F1F2F3F4F5F6F7L, 0xE8E9EAEBECEDEEEFL, 0xE0E1E2E3E4E5E6E7L), - Array(0x008CF75D18C19DA0L, 0x1D7D14BE2266E7D8L, 0x5D09E0E985FE673BL, 0xB4A5480C6039B172L) - )) - - @Test def testThreefry() { + Array(0x09218ebde6c85537L, 0x55941f5266d86105L, 0x4bd25e16282434dcL, 0xee29ec846bd2e40bL), + ), + ( + Array(0x1716151413121110L, 0x1f1e1d1c1b1a1918L, 0x2726252423222120L, 0x2f2e2d2c2b2a2928L), + Array(0x0706050403020100L, 0x0f0e0d0c0b0a0908L), + Array(0xf8f9fafbfcfdfeffL, 0xf0f1f2f3f4f5f6f7L, 0xe8e9eaebecedeeefL, 0xe0e1e2e3e4e5e6e7L), + Array(0x008cf75d18c19da0L, 0x1d7d14be2266e7d8L, 0x5d09e0e985fe673bL, 0xb4a5480c6039b172L), + ), + ) + + @Test def testThreefry(): Unit = { for { (key, tweak, input, expected) <- threefryTestCases } { @@ -35,8 +40,15 @@ class RandomSuite extends HailSuite { x = input.clone() Threefry.encryptUnrolled( - expandedKey(0), expandedKey(1), expandedKey(2), expandedKey(3), expandedKey(4), - tweak(0), tweak(1), x) + expandedKey(0), + expandedKey(1), + expandedKey(2), + expandedKey(3), + expandedKey(4), + tweak(0), + tweak(1), + x, + ) assert(x sameElements expected) x = input.clone() @@ -50,9 +62,8 @@ class RandomSuite extends HailSuite { f.emb.emitWithBuilder { cb => val message = f.mb.getArg[Array[Long]](1) var state = SRNGStateStaticSizeValue(cb) - for (i <- 0 until size) { + for (i <- 0 until size) state = state.splitDyn(cb, cb.memoize(message(i))) - } state = state.splitStatic(cb, staticID) val result = state.rand(cb) @@ -67,18 +78,20 @@ class RandomSuite extends HailSuite { f.result()(new HailClassLoader(getClass.getClassLoader)) } - def pmacEngineStagedStaticSize(staticID: Long, size: Int): AsmFunction1[Array[Long], ThreefryRandomEngine] = { + def pmacEngineStagedStaticSize(staticID: Long, size: Int) + : AsmFunction1[Array[Long], ThreefryRandomEngine] = { val f = EmitFunctionBuilder[Array[Long], ThreefryRandomEngine](ctx, "pmacStaticSize") f.emb.emitWithBuilder { cb => val message = f.mb.getArg[Array[Long]](1) var state = SRNGStateStaticSizeValue(cb) - for (i <- 0 until size) { + for (i <- 0 until size) state = state.splitDyn(cb, cb.memoize(message(i))) - } state = state.splitStatic(cb, staticID) val engine = cb.memoize(Code.invokeScalaObject0[ThreefryRandomEngine]( - ThreefryRandomEngine.getClass, "apply")) + ThreefryRandomEngine.getClass, + "apply", + )) state.copyIntoEngine(cb, engine) engine } @@ -93,9 +106,12 @@ class RandomSuite extends HailSuite { cb.assign(state, SCanonicalRNGStateValue(cb)) val i = 
cb.newLocal[Int]("i", 0) val len = cb.memoize(message.length()) - cb.for_({}, i < len, cb.assign(i, i + 1), { - cb.assign(state, state.splitDyn(cb, cb.memoize(message(i)))) - }) + cb.for_( + {}, + i < len, + cb.assign(i, i + 1), + cb.assign(state, state.splitDyn(cb, cb.memoize(message(i)))), + ) cb.assign(state, state.splitStatic(cb, staticID)) val result = state.rand(cb) @@ -118,13 +134,18 @@ class RandomSuite extends HailSuite { cb.assign(state, SCanonicalRNGStateValue(cb)) val i = cb.newLocal[Int]("i", 0) val len = cb.memoize(message.length()) - cb.for_({}, i < len, cb.assign(i, i + 1), { - cb.assign(state, state.splitDyn(cb, cb.memoize(message(i)))) - }) + cb.for_( + {}, + i < len, + cb.assign(i, i + 1), + cb.assign(state, state.splitDyn(cb, cb.memoize(message(i)))), + ) cb.assign(state, state.splitStatic(cb, staticID)) val engine = cb.memoize(Code.invokeScalaObject0[ThreefryRandomEngine]( - ThreefryRandomEngine.getClass, "apply")) + ThreefryRandomEngine.getClass, + "apply", + )) state.copyIntoEngine(cb, engine) engine } @@ -135,10 +156,10 @@ class RandomSuite extends HailSuite { (Array[Long](), 0L), (Array[Long](100, 101), 10L), (Array[Long](100, 101, 102, 103), 20L), - (Array[Long](100, 101, 102, 103, 104), 30L) + (Array[Long](100, 101, 102, 103, 104), 30L), ) - @Test def testPMAC() { + @Test def testPMAC(): Unit = { for { (message, staticID) <- pmacTestCases } { @@ -150,7 +171,7 @@ class RandomSuite extends HailSuite { } } - @Test def testPMACHash() { + @Test def testPMACHash(): Unit = { for { (message, _) <- pmacTestCases } { @@ -163,7 +184,7 @@ class RandomSuite extends HailSuite { } } - @Test def testRandomEngine() { + @Test def testRandomEngine(): Unit = { for { (message, staticID) <- pmacTestCases } { @@ -188,7 +209,7 @@ class RandomSuite extends HailSuite { } } - def runChiSquareTest(samples: Int, buckets: Int)(sample: => Int) { + def runChiSquareTest(samples: Int, buckets: Int)(sample: => Int): Unit = { val chiSquareDist = new ChiSquaredDistribution(buckets - 1) val expected = samples.toDouble / buckets var numRuns = 0 @@ -202,13 +223,14 @@ class RandomSuite extends HailSuite { val chisquare = counts.map(observed => math.pow(observed - expected, 2) / expected).sum val pvalue = 1 - chiSquareDist.cumulativeProbability(chisquare) numRuns += 1 - geometricMean = math.pow(geometricMean, (numRuns - 1).toDouble / numRuns) * math.pow(pvalue, 1.0 / numRuns) + geometricMean = + math.pow(geometricMean, (numRuns - 1).toDouble / numRuns) * math.pow(pvalue, 1.0 / numRuns) } assert(geometricMean >= passThreshold, s"failed after $numRuns runs with pvalue $geometricMean") println(s"passed after $numRuns runs with pvalue $geometricMean") } - @Test def testRandomInt() { + @Test def testRandomInt(): Unit = { val n = 1 << 25 val k = 1 << 15 val rand = ThreefryRandomEngine.randState() @@ -217,7 +239,7 @@ class RandomSuite extends HailSuite { } } - @Test def testBoundedUniformInt() { + @Test def testBoundedUniformInt(): Unit = { var n = 1 << 25 var k = 1 << 15 val rand = ThreefryRandomEngine.randState() @@ -226,13 +248,13 @@ class RandomSuite extends HailSuite { } n = 30000000 - k = math.pow(n, 3.0/5).toInt + k = math.pow(n, 3.0 / 5).toInt runChiSquareTest(n, k) { rand.nextInt(k) } } - @Test def testBoundedUniformLong() { + @Test def testBoundedUniformLong(): Unit = { var n = 1 << 25 var k = 1 << 15 val rand = ThreefryRandomEngine.randState() @@ -241,13 +263,13 @@ class RandomSuite extends HailSuite { } n = 30000000 - k = math.pow(n, 3.0/5).toInt + k = math.pow(n, 3.0 / 5).toInt runChiSquareTest(n, 
k) { rand.nextLong(k).toInt } } - @Test def testUniformDouble() { + @Test def testUniformDouble(): Unit = { val n = 1 << 25 val k = 1 << 15 val rand = ThreefryRandomEngine.randState() @@ -258,7 +280,7 @@ class RandomSuite extends HailSuite { } } - @Test def testUniformFloat() { + @Test def testUniformFloat(): Unit = { val n = 1 << 25 val k = 1 << 15 val rand = ThreefryRandomEngine.randState() diff --git a/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala b/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala index ce7ade253f7..dea6905fc2c 100644 --- a/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/RequirednessSuite.scala @@ -5,13 +5,14 @@ import is.hail.expr.Nat import is.hail.expr.ir.agg.CallStatsState import is.hail.io.{BufferSpec, TypedCodecSpec} import is.hail.stats.fetStruct +import is.hail.types._ import is.hail.types.physical._ import is.hail.types.physical.stypes.EmitType import is.hail.types.physical.stypes.interfaces.SStream import is.hail.types.physical.stypes.primitives.SInt32 import is.hail.types.virtual._ -import is.hail.types._ import is.hail.utils.{BoxedArrayBuilder, FastSeq} + import org.apache.spark.sql.Row import org.testng.annotations.{DataProvider, Test} @@ -32,12 +33,11 @@ class RequirednessSuite extends HailSuite { def call(r: Boolean): IR = if (r) invoke("Call", TCall, Str("0/0")) else NA(TCall) - def stream(r: Boolean, elt: Boolean): IR = { + def stream(r: Boolean, elt: Boolean): IR = if (r) MakeStream(FastSeq(int(elt), int(required)), tstream) else mapIR(NA(tstream))(x => x + int(elt)) - } def array(r: Boolean, elt: Boolean): IR = ToArray(stream(r, elt)) @@ -62,12 +62,12 @@ class RequirednessSuite extends HailSuite { // NDArrayMap(NA(tnestednd), genUID(), array(optional, aelt)) // } - def nestedstream(r: Boolean, a: Boolean, aelt: Boolean): IR = { + def nestedstream(r: Boolean, a: Boolean, aelt: Boolean): IR = if (r) MakeStream(FastSeq(array(a, required), array(required, aelt)), tnestedstream) else mapIR(NA(tnestedstream))(x => array(a, aelt)) - } + def nestedarray(r: Boolean, a: Boolean, aelt: Boolean): IR = ToArray(nestedstream(r, a, aelt)) def pint(r: Boolean): PInt32 = PInt32(r) @@ -83,31 +83,50 @@ class RequirednessSuite extends HailSuite { PCanonicalTuple(r, pint(i), parray(a, elt)) def pnd(r: Boolean): PNDArray = PCanonicalNDArray(pint(required), 2, r) + def pnestednd(r: Boolean, aelt: Boolean): PNDArray = PCanonicalNDArray(parray(required, aelt), 2, r) + def pnestedarray(r: Boolean, a: Boolean, aelt: Boolean): PArray = PCanonicalArray(parray(a, aelt), r) - def interval(point: IR, r: Boolean): IR = invoke("Interval", TInterval(point.typ), point, point.deepCopy(), True(), if (r) True() else NA(TBoolean)) + def interval(point: IR, r: Boolean): IR = invoke( + "Interval", + TInterval(point.typ), + point, + point.deepCopy(), + True(), + if (r) True() else NA(TBoolean), + ) def pinterval(point: PType, r: Boolean): PInterval = PCanonicalInterval(point, r) - @DataProvider(name="valueIR") + @DataProvider(name = "valueIR") def valueIR(): Array[Array[Any]] = { val nodes = new BoxedArrayBuilder[Array[Any]](50) val allRequired = Array( - I32(5), I64(5), F32(3.14f), F64(3.14), Str("foo"), True(), False(), + I32(5), + I64(5), + F32(3.14f), + F64(3.14), + Str("foo"), + True(), + False(), IsNA(I32(5)), Cast(I32(5), TFloat64), Die("mumblefoo", TFloat64), Literal(TStruct("x" -> TInt32), Row(1)), MakeArray(FastSeq(I32(4)), TArray(TInt32)), MakeStruct(FastSeq("x" -> I32(4), "y" -> Str("foo"))), - 
MakeTuple.ordered(FastSeq(I32(5), Str("bar")))) + MakeTuple.ordered(FastSeq(I32(5), Str("bar"))), + ) allRequired.foreach { n => - nodes += Array(n, RequirednessSuite.deepInnerRequired(PType.canonical(n.typ, required), required)) + nodes += Array( + n, + RequirednessSuite.deepInnerRequired(PType.canonical(n.typ, required), required), + ) } val bools = Array(true, false) @@ -127,21 +146,28 @@ class RequirednessSuite extends HailSuite { } // test coalesce - nodes += Array(Coalesce(FastSeq( - array(required, optional), - array(optional, required))), - parray(required, optional)) + nodes += Array( + Coalesce(FastSeq( + array(required, optional), + array(optional, required), + )), + parray(required, optional), + ) - nodes += Array(Coalesce(FastSeq( - array(optional, optional), - array(optional, required))), - parray(optional, optional)) + nodes += Array( + Coalesce(FastSeq( + array(optional, optional), + array(optional, required), + )), + parray(optional, optional), + ) // test read/write - val pDisc = PCanonicalStruct(required, + val pDisc = PCanonicalStruct( + required, "a" -> pint(optional), "b" -> parray(required, required), - "c" -> PCanonicalArray(pstruct(required, required, optional, required), required) + "c" -> PCanonicalArray(pstruct(required, required, optional, required), required), ) val spec = TypedCodecSpec(pDisc, BufferSpec.default) @@ -150,10 +176,11 @@ class RequirednessSuite extends HailSuite { val contextType = pr.contextType val rt1 = TStruct("a" -> TInt32, "b" -> TArray(TInt32)) val rt2 = TStruct("a" -> TInt32, "c" -> TArray(TStruct("x" -> TInt32))) - Array(Str("foo") -> pDisc, + Array( + Str("foo") -> pDisc, NA(TString) -> pDisc, Str("foo") -> pDisc.subsetTo(rt1), - Str("foo") -> pDisc.subsetTo(rt2) + Str("foo") -> pDisc.subsetTo(rt2), ).foreach { case (path, pt: PStruct) => nodes += Array( ReadPartition( @@ -162,19 +189,27 @@ class RequirednessSuite extends HailSuite { else NA(contextType), pt.virtualType, - pr), - EmitType(SStream(EmitType(pt.sType, pt.required)), path.isInstanceOf[Str])) - nodes += Array(ReadValue(path, ETypeValueReader(spec), pt.virtualType), pt.setRequired(path.isInstanceOf[Str])) + pr, + ), + EmitType(SStream(EmitType(pt.sType, pt.required)), path.isInstanceOf[Str]), + ) + nodes += Array( + ReadValue(path, ETypeValueReader(spec), pt.virtualType), + pt.setRequired(path.isInstanceOf[Str]), + ) } - val value = Literal(pDisc.virtualType, Row(null, IndexedSeq(1), IndexedSeq(Row(1, IndexedSeq(1))))) + val value = + Literal(pDisc.virtualType, Row(null, IndexedSeq(1), IndexedSeq(Row(1, IndexedSeq(1))))) nodes += Array(WriteValue(value, Str("foo"), vr), PCanonicalString(required)) nodes += Array(WriteValue(NA(pDisc.virtualType), Str("foo"), vr), PCanonicalString(optional)) nodes += Array(WriteValue(value, NA(TString), vr), PCanonicalString(optional)) // test bindings - nodes += Array(bindIR(nestedarray(required, optional, optional)) { v => ArrayRef(v, I32(0)) }, - PCanonicalArray(PInt32(optional), optional)) + nodes += Array( + bindIR(nestedarray(required, optional, optional))(v => ArrayRef(v, I32(0))), + PCanonicalArray(PInt32(optional), optional), + ) nodes += { val arr = array(required, required) val elemType = TIterable.elementType(arr.typ) @@ -184,39 +219,57 @@ class RequirednessSuite extends HailSuite { iruid(0) -> int(optional), iruid(1) -> arr, iruid(2) -> ToStream(Ref(iruid(1), arr.typ)), - iruid(4) -> StreamMap(Ref(iruid(2), TStream(elemType)), iruid(3), - ApplyBinaryPrimOp(Multiply(), Ref(iruid(3), elemType), Ref(iruid(0), elemType)) + iruid(4) -> 
StreamMap( + Ref(iruid(2), TStream(elemType)), + iruid(3), + ApplyBinaryPrimOp(Multiply(), Ref(iruid(3), elemType), Ref(iruid(0), elemType)), ), iruid(5) -> ToArray(Ref(iruid(4), TStream(elemType))), - iruid(6) -> int(required) + iruid(6) -> int(required), ), - ArrayRef(Ref(iruid(5), arr.typ), Ref(iruid(6), TInt32)) + ArrayRef(Ref(iruid(5), arr.typ), Ref(iruid(6), TInt32)), ), - pint(optional) + pint(optional), ) } // filter - nodes += Array(StreamFilter(stream(optional, optional), "x", Ref("x", TInt32).ceq(0)), - EmitType(SStream(EmitType(SInt32, optional)), optional)) + nodes += Array( + StreamFilter(stream(optional, optional), "x", Ref("x", TInt32).ceq(0)), + EmitType(SStream(EmitType(SInt32, optional)), optional), + ) // StreamFold - nodes += Array(StreamFold( - nestedstream(optional, optional, optional), - I32(0), "a", "b", - ArrayRef(Ref("b", tarray), Ref("a", TInt32)) - ), PInt32(optional)) + nodes += Array( + StreamFold( + nestedstream(optional, optional, optional), + I32(0), + "a", + "b", + ArrayRef(Ref("b", tarray), Ref("a", TInt32)), + ), + PInt32(optional), + ) // StreamFold2 - nodes += Array(StreamFold2( - nestedstream(optional, optional, optional), - FastSeq("b" -> I32(0)), "a", - FastSeq(ArrayRef(Ref("a", tarray), Ref("b", TInt32))), - Ref("b", TInt32) - ), PInt32(optional)) + nodes += Array( + StreamFold2( + nestedstream(optional, optional, optional), + FastSeq("b" -> I32(0)), + "a", + FastSeq(ArrayRef(Ref("a", tarray), Ref("b", TInt32))), + Ref("b", TInt32), + ), + PInt32(optional), + ) // StreamScan - nodes += Array(StreamScan( + nodes += Array( + StreamScan( nestedstream(optional, optional, optional), - I32(0), "a", "b", - ArrayRef(Ref("b", tarray), Ref("a", TInt32)) - ), EmitType(SStream(EmitType(SInt32, optional)), optional)) + I32(0), + "a", + "b", + ArrayRef(Ref("b", tarray), Ref("a", TInt32)), + ), + EmitType(SStream(EmitType(SInt32, optional)), optional), + ) // TailLoop val param1 = Ref(genUID(), tarray) val param2 = Ref(genUID(), TInt32) @@ -224,100 +277,168 @@ class RequirednessSuite extends HailSuite { "loop", FastSeq( param1.name -> array(required, required), - param2.name -> int(required)), + param2.name -> int(required), + ), tnestedarray, - If(False(), // required + If( + False(), // required MakeArray(FastSeq(param1), tnestedarray), // required - If(param2 <= I32(1), // possibly missing - Recur("loop", FastSeq( - array(required, optional), - int(required)), tnestedarray), - Recur("loop", FastSeq( - array(optional, required), - int(optional)), tnestedarray)))) + If( + param2 <= I32(1), // possibly missing + Recur( + "loop", + FastSeq( + array(required, optional), + int(required), + ), + tnestedarray, + ), + Recur( + "loop", + FastSeq( + array(optional, required), + int(optional), + ), + tnestedarray, + ), + ), + ), + ) nodes += Array(loop, PCanonicalArray(PCanonicalArray(PInt32(optional), optional), optional)) // Switch - for ((x, d, cs, r) <- Array( - (required, required, FastSeq(required), required), - (optional, required, FastSeq(required), optional), - (required, optional, FastSeq(required), optional), - (required, required, FastSeq(optional), optional), - )) { + for ( + (x, d, cs, r) <- Array( + (required, required, FastSeq(required), required), + (optional, required, FastSeq(required), optional), + (required, optional, FastSeq(required), optional), + (required, required, FastSeq(optional), optional), + ) + ) nodes += Array(Switch(int(x), int(d), cs.map(int)), PInt32(r)) - } // ArrayZip val s1 = Ref(genUID(), TInt32) val s2 = Ref(genUID(), TInt32) val 
notExtendNA = StreamZip( FastSeq(stream(required, optional), stream(required, required)), FastSeq(s1.name, s2.name), - s1 + s2, ArrayZipBehavior.TakeMinLength) + s1 + s2, + ArrayZipBehavior.TakeMinLength, + ) val extendNA = StreamZip( FastSeq(stream(required, required), stream(required, required)), FastSeq(s1.name, s2.name), - s1 + s2, ArrayZipBehavior.ExtendNA) + s1 + s2, + ArrayZipBehavior.ExtendNA, + ) nodes += Array(notExtendNA, pstream(required, optional)) nodes += Array(extendNA, pstream(required, optional)) // ArraySort - nodes += Array(ArraySort(stream(optional, required), s1.name, s2.name, True()), parray(optional, required)) + nodes += Array( + ArraySort(stream(optional, required), s1.name, s2.name, True()), + parray(optional, required), + ) // CollectDistributedArray - nodes += Array(CollectDistributedArray( - stream(optional, required), - int(optional), - s1.name, s2.name, - s1 + s2, NA(TString), "test"), parray(optional, optional)) + nodes += Array( + CollectDistributedArray( + stream(optional, required), + int(optional), + s1.name, + s2.name, + s1 + s2, + NA(TString), + "test", + ), + parray(optional, optional), + ) // ApplyIR nodes += Array( invoke("argmin", TInt32, array(required, required)), - pint(optional)) + pint(optional), + ) nodes += Array( invoke("argmin", TInt32, array(required, optional)), - pint(optional)) + pint(optional), + ) nodes += Array( invoke("argmin", TInt32, array(optional, required)), - pint(optional)) + pint(optional), + ) // Apply nodes += Array( - invoke("fisher_exact_test", fetStruct.virtualType, - int(required), int(required), int(required), int(required)), - fetStruct.setRequired(required)) + invoke( + "fisher_exact_test", + fetStruct.virtualType, + int(required), + int(required), + int(required), + int(required), + ), + fetStruct.setRequired(required), + ) nodes += Array( - invoke("fisher_exact_test", fetStruct.virtualType, - int(optional), int(required), int(required), int(required)), - fetStruct.setRequired(optional)) + invoke( + "fisher_exact_test", + fetStruct.virtualType, + int(optional), + int(required), + int(required), + int(required), + ), + fetStruct.setRequired(optional), + ) nodes += Array( - invoke("Interval", TInterval(TArray(TInt32)), - array(required, optional), array(required, required), True(), NA(TBoolean)), - PCanonicalInterval(parray(required, optional), optional)) + invoke( + "Interval", + TInterval(TArray(TInt32)), + array(required, optional), + array(required, required), + True(), + NA(TBoolean), + ), + PCanonicalInterval(parray(required, optional), optional), + ) nodes.result() } - @DataProvider(name="tableIR") + @DataProvider(name = "tableIR") def tableIR(): Array[Array[Any]] = { val nodes = new BoxedArrayBuilder[Array[Any]](50) - nodes += Array[Any](TableRange(1, 1), PCanonicalStruct(required, "idx" -> PInt32(required)), PCanonicalStruct.empty(required)) + nodes += Array[Any]( + TableRange(1, 1), + PCanonicalStruct(required, "idx" -> PInt32(required)), + PCanonicalStruct.empty(required), + ) - val table = TableParallelize(makestruct( - "rows" -> MakeArray(makestruct( - "a" -> nestedarray(optional, required, optional), - "b" -> struct(required, required, required, optional), - "c" -> nd(required))), - "global" -> makestruct( - "x" -> array(required, optional), - "y" -> int(optional), - "z" -> struct(required, required, required, optional)) - ), None) + val table = TableParallelize( + makestruct( + "rows" -> MakeArray(makestruct( + "a" -> nestedarray(optional, required, optional), + "b" -> struct(required, required, 
required, optional), + "c" -> nd(required), + )), + "global" -> makestruct( + "x" -> array(required, optional), + "y" -> int(optional), + "z" -> struct(required, required, required, optional), + ), + ), + None, + ) - val rowType = PCanonicalStruct(required, + val rowType = PCanonicalStruct( + required, "a" -> pnestedarray(optional, required, optional), "b" -> pstruct(required, required, required, optional), - "c" -> pnd(required)) - val globalType = PCanonicalStruct(required, + "c" -> pnd(required), + ) + val globalType = PCanonicalStruct( + required, "x" -> parray(required, optional), "y" -> pint(optional), - "z" -> pstruct(required, required, required, optional)) + "z" -> pstruct(required, required, required, optional), + ) def row = Ref("row", table.typ.rowType) def global = Ref("global", table.typ.globalType) @@ -338,31 +459,51 @@ class RequirednessSuite extends HailSuite { nodes += Array(TableRename(table, rMap, gMap), rowType.rename(rMap), globalType.rename(gMap)) nodes += Array( - TableMapRows(table, insertIR(row, - "a2" -> ApplyScanOp(Collect())(GetField(row, "a")), - "x2" -> GetField(global, "x"))), + TableMapRows( + table, + insertIR( + row, + "a2" -> ApplyScanOp(Collect())(GetField(row, "a")), + "x2" -> GetField(global, "x"), + ), + ), rowType.insertFields(FastSeq( "a2" -> PCanonicalArray(rowType.fieldType("a"), required), - "x2" -> globalType.fieldType("x"))), - globalType) + "x2" -> globalType.fieldType("x"), + )), + globalType, + ) nodes += Array( TableMapGlobals(table, insertIR(global, "x2" -> GetField(global, "x"))), rowType, - globalType.insertFields(FastSeq("x2" -> globalType.fieldType("x")))) + globalType.insertFields(FastSeq("x2" -> globalType.fieldType("x"))), + ) - nodes += Array(TableExplode( - TableMapRows(table, insertIR(row, "e1" -> struct(r = optional, i = required, a = optional, elt = required))), FastSeq("e1", "y")), - rowType.insertFields(FastSeq("e1" -> PCanonicalStruct(required, "x" -> pint(required), "y" -> pint(required)))), - globalType) + nodes += Array( + TableExplode( + TableMapRows( + table, + insertIR(row, "e1" -> struct(r = optional, i = required, a = optional, elt = required)), + ), + FastSeq("e1", "y"), + ), + rowType.insertFields(FastSeq("e1" -> PCanonicalStruct( + required, + "x" -> pint(required), + "y" -> pint(required), + ))), + globalType, + ) nodes += Array( TableUnion(FastSeq( table.deepCopy(), - TableMapRows(table, insertIR(row, "a" -> nestedarray(optional, optional, required))))), + TableMapRows(table, insertIR(row, "a" -> nestedarray(optional, optional, required))), + )), rowType.insertFields(FastSeq("a" -> pnestedarray(optional, optional, optional))), - globalType) - + globalType, + ) val collect = ApplyAggOp(Collect())(GetField(row, "b")) val callstats = ApplyAggOp(CallStats(), int(optional))(call(required)) @@ -370,106 +511,174 @@ class RequirednessSuite extends HailSuite { nodes += Array( TableKeyByAndAggregate(table, expr, makestruct("a" -> GetField(row, "a")), None, 5), - PCanonicalStruct(required, + PCanonicalStruct( + required, "a" -> rowType.fieldType("a"), "collect" -> PCanonicalArray(rowType.fieldType("b"), required), - "callstats" -> CallStatsState.resultPType.setRequired(true)), - globalType) + "callstats" -> CallStatsState.resultPType.setRequired(true), + ), + globalType, + ) nodes += Array( TableAggregateByKey(TableKeyBy(table, FastSeq("a")), expr), - PCanonicalStruct(required, + PCanonicalStruct( + required, "a" -> rowType.fieldType("a"), "collect" -> PCanonicalArray(rowType.fieldType("b"), required), - "callstats" -> 
CallStatsState.resultPType.setRequired(true)), - globalType) + "callstats" -> CallStatsState.resultPType.setRequired(true), + ), + globalType, + ) val left = TableMapGlobals( - TableKeyBy(TableMapRows(table.deepCopy(), makestruct( - "a" -> nestedarray(required, optional, required), - "b" -> GetField(row, "b"))), FastSeq("a")), - selectIR(global, "x")) + TableKeyBy( + TableMapRows( + table.deepCopy(), + makestruct( + "a" -> nestedarray(required, optional, required), + "b" -> GetField(row, "b"), + ), + ), + FastSeq("a"), + ), + selectIR(global, "x"), + ) val right = TableMapGlobals( - TableKeyBy(TableMapRows(table.deepCopy(), makestruct( - "a" -> nestedarray(required, required, optional), - "c" -> GetField(row, "c"))), FastSeq("a")), - selectIR(global, "y", "z")) + TableKeyBy( + TableMapRows( + table.deepCopy(), + makestruct( + "a" -> nestedarray(required, required, optional), + "c" -> GetField(row, "c"), + ), + ), + FastSeq("a"), + ), + selectIR(global, "y", "z"), + ) nodes += Array( TableJoin(left, right, "left", 1), - PCanonicalStruct(required, + PCanonicalStruct( + required, "a" -> pnestedarray(required, optional, required), "b" -> rowType.fieldType("b"), - "c" -> rowType.fieldType("c").setRequired(optional)), - globalType) + "c" -> rowType.fieldType("c").setRequired(optional), + ), + globalType, + ) nodes += Array( TableJoin(left, right, "right", 1), - PCanonicalStruct(required, + PCanonicalStruct( + required, "a" -> pnestedarray(required, required, optional), "b" -> rowType.fieldType("b").setRequired(optional), - "c" -> rowType.fieldType("c")), - globalType) + "c" -> rowType.fieldType("c"), + ), + globalType, + ) nodes += Array( TableJoin(left, right, "inner", 1), - PCanonicalStruct(required, + PCanonicalStruct( + required, "a" -> pnestedarray(required, required, required), "b" -> rowType.fieldType("b"), - "c" -> rowType.fieldType("c")), - globalType) + "c" -> rowType.fieldType("c"), + ), + globalType, + ) nodes += Array( TableJoin(left, right, "outer", 1), - PCanonicalStruct(required, + PCanonicalStruct( + required, "a" -> pnestedarray(required, optional, optional), "b" -> rowType.fieldType("b").setRequired(optional), - "c" -> rowType.fieldType("c").setRequired(optional)), - globalType) + "c" -> rowType.fieldType("c").setRequired(optional), + ), + globalType, + ) val intervalTable = TableKeyBy( - TableMapRows(table.deepCopy(), makestruct( - "a" -> interval(nestedarray(required, required, optional), required), - "c" -> GetField(row, "c"))), - FastSeq("a")) + TableMapRows( + table.deepCopy(), + makestruct( + "a" -> interval(nestedarray(required, required, optional), required), + "c" -> GetField(row, "c"), + ), + ), + FastSeq("a"), + ) nodes += Array( TableIntervalJoin(left, intervalTable, "root", product = false), - PCanonicalStruct(required, + PCanonicalStruct( + required, "a" -> pnestedarray(required, optional, required), "b" -> rowType.fieldType("b"), - "root" -> PCanonicalStruct(optional, - "c" -> rowType.fieldType("c"))), - globalType.selectFields(FastSeq("x"))) + "root" -> PCanonicalStruct(optional, "c" -> rowType.fieldType("c")), + ), + globalType.selectFields(FastSeq("x")), + ) nodes += Array( TableIntervalJoin(left, intervalTable, "root", product = true), - PCanonicalStruct(required, + PCanonicalStruct( + required, "a" -> pnestedarray(required, optional, required), "b" -> rowType.fieldType("b"), "root" -> PCanonicalArray( - PCanonicalStruct(required, - "c" -> rowType.fieldType("c")), optional)), - globalType.selectFields(FastSeq("x"))) - - nodes += 
Array(TableMultiWayZipJoin(FastSeq( - TableKeyBy(TableMapRows(table.deepCopy(), insertIR(row, - "a" -> nestedarray(required, optional, required))), FastSeq("a")), - TableKeyBy(TableMapRows(table.deepCopy(), insertIR(row, - "a" -> nestedarray(required, required, optional))), FastSeq("a"))), - "value", "global"), - PCanonicalStruct(required, + PCanonicalStruct(required, "c" -> rowType.fieldType("c")), + optional, + ), + ), + globalType.selectFields(FastSeq("x")), + ) + + nodes += Array( + TableMultiWayZipJoin( + FastSeq( + TableKeyBy( + TableMapRows( + table.deepCopy(), + insertIR(row, "a" -> nestedarray(required, optional, required)), + ), + FastSeq("a"), + ), + TableKeyBy( + TableMapRows( + table.deepCopy(), + insertIR(row, "a" -> nestedarray(required, required, optional)), + ), + FastSeq("a"), + ), + ), + "value", + "global", + ), + PCanonicalStruct( + required, "a" -> pnestedarray(required, optional, optional), - "value" -> PCanonicalArray(PCanonicalStruct(optional, - "b" -> rowType.fieldType("b"), - "c" -> rowType.fieldType("c")), required)), - PCanonicalStruct(required, "global" -> PCanonicalArray(globalType, required))) - - nodes += Array(TableLeftJoinRightDistinct(left, right, "root"), - PCanonicalStruct(required, - "a" -> pnestedarray(required, optional, required), - "b" -> rowType.fieldType("b"), - "root" -> PCanonicalStruct(optional, "c" -> rowType.fieldType("c"))), - globalType.selectFields(FastSeq("x"))) + "value" -> PCanonicalArray( + PCanonicalStruct(optional, "b" -> rowType.fieldType("b"), "c" -> rowType.fieldType("c")), + required, + ), + ), + PCanonicalStruct(required, "global" -> PCanonicalArray(globalType, required)), + ) + + nodes += Array( + TableLeftJoinRightDistinct(left, right, "root"), + PCanonicalStruct( + required, + "a" -> pnestedarray(required, optional, required), + "b" -> rowType.fieldType("b"), + "root" -> PCanonicalStruct(optional, "c" -> rowType.fieldType("c")), + ), + globalType.selectFields(FastSeq("x")), + ) nodes.result() } @@ -478,28 +687,27 @@ class RequirednessSuite extends HailSuite { val s = new BoxedArrayBuilder[String]() valueIR().map(v => v(0) -> v(1)).foreach { case (n: IR, t: PType) => - if (n.typ != t.virtualType) - s += s"${ n.typ } != ${ t.virtualType }: \n${ Pretty(ctx, n) }" + if (n.typ != t.virtualType) + s += s"${n.typ} != ${t.virtualType}: \n${Pretty(ctx, n)}" case (n: IR, et: EmitType) => if (n.typ != et.virtualType) - s += s"${ n.typ } != ${ et.virtualType }: \n${ Pretty(ctx, n) }" + s += s"${n.typ} != ${et.virtualType}: \n${Pretty(ctx, n)}" } tableIR().map(v => (v(0), v(1), v(2))).foreach { case (n: TableIR, row: PType, global: PType) => - if (n.typ.rowType != row.virtualType || n.typ.globalType != global.virtualType ) + if (n.typ.rowType != row.virtualType || n.typ.globalType != global.virtualType) s += - s"""row: ${ n.typ.rowType } vs ${ row.virtualType } - |global: ${ n.typ.globalType } vs ${ global.virtualType }: - |${ Pretty(ctx, n) }" + s"""row: ${n.typ.rowType} vs ${row.virtualType} + |global: ${n.typ.globalType} vs ${global.virtualType}: + |${Pretty(ctx, n)}" |""".stripMargin } assert(s.size == 0, s.result().mkString("\n\n")) } - def /**/dump(m: Memo[BaseTypeWithRequiredness]): String = { + def /**/ dump(m: Memo[BaseTypeWithRequiredness]): String = m.m.map { case (node, t) => - s"${Pretty(ctx, node.t)}: \n$t" + s"${Pretty(ctx, node.t)}: \n$t" }.mkString("\n\n") - } @Test(dataProvider = "valueIR") def testRequiredness(node: IR, expected: Any): Unit = { @@ -510,7 +718,10 @@ class RequirednessSuite extends HailSuite { } val res 
= Requiredness.apply(node, ctx) val actual = res.r.lookup(node).asInstanceOf[TypeWithRequiredness] - assert(actual.canonicalEmitType(node.typ) == et, s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${ dump(res.r) }") + assert( + actual.canonicalEmitType(node.typ) == et, + s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${dump(res.r)}", + ) } @Test def sharedNodesWorkCorrectly(): Unit = { @@ -519,54 +730,86 @@ class RequirednessSuite extends HailSuite { val node = InsertFields(n2, FastSeq("c" -> GetField(n2, "a"), "d" -> GetField(n2, "b"))) val res = Requiredness.apply(node, ctx) val actual = tcoerce[TypeWithRequiredness](res.r.lookup(node)).canonicalPType(node.typ) - assert(actual == PCanonicalStruct(required, - "a" -> PInt32(required), "b" -> PInt32(required), - "c" -> PInt32(required), "d" -> PInt32(required))) + assert(actual == PCanonicalStruct( + required, + "a" -> PInt32(required), + "b" -> PInt32(required), + "c" -> PInt32(required), + "d" -> PInt32(required), + )) } @Test(dataProvider = "tableIR") def testTableRequiredness(node: TableIR, row: PType, global: PType): Unit = { val res = Requiredness.apply(node, ctx) val actual = res.r.lookup(node).asInstanceOf[RTable] - assert(actual.rowType.canonicalPType(node.typ.rowType) == row, s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${ dump(res.r) }") - assert(actual.globalType.canonicalPType(node.typ.globalType) == global, s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${ dump(res.r) }") + assert( + actual.rowType.canonicalPType(node.typ.rowType) == row, + s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${dump(res.r)}", + ) + assert( + actual.globalType.canonicalPType(node.typ.globalType) == global, + s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${dump(res.r)}", + ) } - @Test def testTableReader() { - val table = TableParallelize(makestruct( - "rows" -> MakeArray(makestruct( - "a" -> nestedarray(optional, required, optional), - "b" -> struct(required, required, required, optional), - "c" -> array(optional, required))), - "global" -> makestruct( - "x" -> array(required, optional), - "y" -> int(optional), - "z" -> struct(required, required, required, optional)) - ), None) + @Test def testTableReader(): Unit = { + val table = TableParallelize( + makestruct( + "rows" -> MakeArray(makestruct( + "a" -> nestedarray(optional, required, optional), + "b" -> struct(required, required, required, optional), + "c" -> array(optional, required), + )), + "global" -> makestruct( + "x" -> array(required, optional), + "y" -> int(optional), + "z" -> struct(required, required, required, optional), + ), + ), + None, + ) val path = ctx.createTmpPath("test-table-requiredness", "ht") - CompileAndEvaluate[Unit](ctx, TableWrite(table, TableNativeWriter(path, overwrite = true)), false) + CompileAndEvaluate[Unit]( + ctx, + TableWrite(table, TableNativeWriter(path, overwrite = true)), + false, + ) val reader = TableNativeReader(fs, TableNativeReaderParameters(path, None)) - for (rType <- Array(table.typ, - TableType(TStruct("a" -> tnestedarray), FastSeq(), TStruct("z" -> tstruct)) - )) { + for ( + rType <- Array( + table.typ, + TableType(TStruct("a" -> tnestedarray), FastSeq(), TStruct("z" -> tstruct)), + ) + ) { val row = reader.rowRequiredness(ctx, rType) val global = reader.globalRequiredness(ctx, rType) val node = TableRead(rType, dropRows = false, reader) val res = Requiredness.apply(node, ctx) val actual = res.r.lookup(node).asInstanceOf[RTable] - assert(VirtualTypeWithReq(rType.rowType, actual.rowType) == row, s"\n\n${ Pretty(ctx, node) }: \n$actual\n\n${ dump(res.r) }") - 
assert(VirtualTypeWithReq(rType.globalType, actual.globalType) == global, s"\n\n${ Pretty(ctx, node) }: \n$actual\n\n${ dump(res.r) }") + assert( + VirtualTypeWithReq(rType.rowType, actual.rowType) == row, + s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${dump(res.r)}", + ) + assert( + VirtualTypeWithReq(rType.globalType, actual.globalType) == global, + s"\n\n${Pretty(ctx, node)}: \n$actual\n\n${dump(res.r)}", + ) } } @Test def testSubsettedTuple(): Unit = { val node = MakeTuple(FastSeq(0 -> I32(0), 4 -> NA(TInt32), 2 -> NA(TArray(TInt32)))) - val expected = PCanonicalTuple(FastSeq( - PTupleField(0, PInt32(required)), - PTupleField(4, PInt32(optional)), - PTupleField(2, PCanonicalArray(PInt32(required), optional))), required) + val expected = PCanonicalTuple( + FastSeq( + PTupleField(0, PInt32(required)), + PTupleField(4, PInt32(optional)), + PTupleField(2, PCanonicalArray(PInt32(required), optional)), + ), + required, + ) val res = Requiredness.apply(node, ctx) val actual = tcoerce[TypeWithRequiredness](res.r.lookup(node)).canonicalPType(node.typ) assert(actual == expected) @@ -578,11 +821,21 @@ object RequirednessSuite { t match { case t: PCanonicalArray => PCanonicalArray(deepInnerRequired(t.elementType, true), required) case t: PCanonicalSet => PCanonicalSet(deepInnerRequired(t.elementType, true), required) - case t: PCanonicalDict => PCanonicalDict(deepInnerRequired(t.keyType, true), deepInnerRequired(t.valueType, true), required) + case t: PCanonicalDict => PCanonicalDict( + deepInnerRequired(t.keyType, true), + deepInnerRequired(t.valueType, true), + required, + ) case t: PCanonicalStruct => - PCanonicalStruct(t.fields.map(f => PField(f.name, deepInnerRequired(f.typ, true), f.index)), required) + PCanonicalStruct( + t.fields.map(f => PField(f.name, deepInnerRequired(f.typ, true), f.index)), + required, + ) case t: PCanonicalTuple => - PCanonicalTuple(t._types.map { f => f.copy(typ = deepInnerRequired(f.typ, true)) }, required) + PCanonicalTuple( + t._types.map(f => f.copy(typ = deepInnerRequired(f.typ, true))), + required, + ) case t: PCanonicalInterval => PCanonicalInterval(deepInnerRequired(t.pointType, true), required) case t => diff --git a/hail/src/test/scala/is/hail/expr/ir/SetFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/SetFunctionsSuite.scala index c11eaa4a877..10c74088d8d 100644 --- a/hail/src/test/scala/is/hail/expr/ir/SetFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/SetFunctionsSuite.scala @@ -1,13 +1,11 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.types._ import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual._ -import is.hail.utils.FastSeq + import org.testng.annotations.Test -import org.scalatest.testng.TestNGSuite class SetFunctionsSuite extends HailSuite { val naa = NA(TArray(TInt32)) @@ -15,7 +13,7 @@ class SetFunctionsSuite extends HailSuite { implicit val execStrats = ExecStrategy.javaOnly - @Test def toSet() { + @Test def toSet(): Unit = { assertEvalsTo(IRSet(3, 7), Set(3, 7)) assertEvalsTo(IRSet(3, null, 7), Set(null, 3, 7)) assertEvalsTo(nas, null) @@ -25,7 +23,7 @@ class SetFunctionsSuite extends HailSuite { assertEvalsTo(invoke("toSet", TSet(TInt32), naa), null) } - @Test def isEmpty() { + @Test def isEmpty(): Unit = { assertEvalsTo(invoke("isEmpty", TBoolean, IRSet(3, 7)), false) assertEvalsTo(invoke("isEmpty", TBoolean, IRSet(3, null, 7)), false) assertEvalsTo(invoke("isEmpty", TBoolean, IRSet()), true) @@ -33,7 +31,7 @@ class SetFunctionsSuite extends 
HailSuite { assertEvalsTo(invoke("isEmpty", TBoolean, nas), null) } - @Test def contains() { + @Test def contains(): Unit = { val s = IRSet(3, null, 7) val swoutna = IRSet(3, 7) @@ -47,7 +45,7 @@ class SetFunctionsSuite extends HailSuite { assert(eval(invoke("contains", TBoolean, IRSet(), 3)) == false) } - @Test def remove() { + @Test def remove(): Unit = { val s = IRSet(3, null, 7) assertEvalsTo(invoke("remove", TSet(TInt32), s, I32(3)), Set(null, 7)) assertEvalsTo(invoke("remove", TSet(TInt32), s, I32(4)), Set(null, 3, 7)) @@ -55,7 +53,7 @@ class SetFunctionsSuite extends HailSuite { assertEvalsTo(invoke("remove", TSet(TInt32), IRSet(3, 7), NA(TInt32)), Set(3, 7)) } - @Test def add() { + @Test def add(): Unit = { val s = IRSet(3, null, 7) assertEvalsTo(invoke("add", TSet(TInt32), s, I32(3)), Set(null, 3, 7)) assertEvalsTo(invoke("add", TSet(TInt32), s, I32(4)), Set(null, 3, 4, 7)) @@ -64,30 +62,44 @@ class SetFunctionsSuite extends HailSuite { assertEvalsTo(invoke("add", TSet(TInt32), IRSet(3, 7), NA(TInt32)), Set(null, 3, 7)) } - @Test def isSubset() { + @Test def isSubset(): Unit = { val s = IRSet(3, null, 7) assertEvalsTo(invoke("isSubset", TBoolean, s, invoke("add", TSet(TInt32), s, I32(4))), true) - assertEvalsTo(invoke("isSubset", TBoolean, IRSet(3, 7), invoke("add", TSet(TInt32), IRSet(3, 7), NA(TInt32))), true) + assertEvalsTo( + invoke( + "isSubset", + TBoolean, + IRSet(3, 7), + invoke("add", TSet(TInt32), IRSet(3, 7), NA(TInt32)), + ), + true, + ) assertEvalsTo(invoke("isSubset", TBoolean, s, invoke("remove", TSet(TInt32), s, I32(3))), false) - assertEvalsTo(invoke("isSubset", TBoolean, s, invoke("remove", TSet(TInt32), s, NA(TInt32))), false) + assertEvalsTo( + invoke("isSubset", TBoolean, s, invoke("remove", TSet(TInt32), s, NA(TInt32))), + false, + ) } - @Test def union() { + @Test def union(): Unit = { assertEvalsTo(invoke("union", TSet(TInt32), IRSet(3, null, 7), IRSet(3, 8)), Set(null, 3, 7, 8)) assertEvalsTo(invoke("union", TSet(TInt32), IRSet(3, 7), IRSet(3, 8, null)), Set(null, 3, 7, 8)) } - @Test def intersection() { + @Test def intersection(): Unit = { assertEvalsTo(invoke("intersection", TSet(TInt32), IRSet(3, null, 7), IRSet(3, 8)), Set(3)) - assertEvalsTo(invoke("intersection", TSet(TInt32), IRSet(3, null, 7), IRSet(3, 8, null)), Set(null, 3)) + assertEvalsTo( + invoke("intersection", TSet(TInt32), IRSet(3, null, 7), IRSet(3, 8, null)), + Set(null, 3), + ) } - @Test def difference() { + @Test def difference(): Unit = { assertEvalsTo(invoke("difference", TSet(TInt32), IRSet(3, null, 7), IRSet(3, 8)), Set(null, 7)) assertEvalsTo(invoke("difference", TSet(TInt32), IRSet(3, null, 7), IRSet(3, 8, null)), Set(7)) } - @Test def median() { + @Test def median(): Unit = { assertEvalsTo(invoke("median", TInt32, IRSet(5)), 5) assertEvalsTo(invoke("median", TInt32, IRSet(5, null)), 5) assertEvalsTo(invoke("median", TInt32, IRSet(3, 7)), 5) diff --git a/hail/src/test/scala/is/hail/expr/ir/SimplifySuite.scala b/hail/src/test/scala/is/hail/expr/ir/SimplifySuite.scala index e0bbad38c06..3f48b09bd8d 100644 --- a/hail/src/test/scala/is/hail/expr/ir/SimplifySuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/SimplifySuite.scala @@ -1,10 +1,11 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.expr.ir.TestUtils.IRAggCount import is.hail.types.virtual._ import is.hail.utils.{FastSeq, Interval} import is.hail.variant.Locus -import is.hail.{ExecStrategy, HailSuite} + import org.apache.spark.sql.Row import org.scalatest.Matchers.{be, 
convertToAnyShouldWrapper} import org.testng.annotations.{BeforeMethod, DataProvider, Test} @@ -13,90 +14,122 @@ class SimplifySuite extends HailSuite { implicit val execStrats = ExecStrategy.interpretOnly @BeforeMethod - def resetUidCounter(): Unit = { + def resetUidCounter(): Unit = is.hail.expr.ir.uidCounter = 0 - } - @Test def testTableMultiWayZipJoinGlobalsRewrite() { + @Test def testTableMultiWayZipJoinGlobalsRewrite(): Unit = { hc val tmwzj = TableGetGlobals(TableMultiWayZipJoin( - Array(TableRange(10, 10), - TableRange(10, 10), - TableRange(10, 10)), + Array(TableRange(10, 10), TableRange(10, 10), TableRange(10, 10)), "rowField", - "globalField")) + "globalField", + )) assertEvalsTo(tmwzj, Row(FastSeq(Row(), Row(), Row()))) } - @Test def testRepartitionableMapUpdatesForUpstreamOptimizations() { + @Test def testRepartitionableMapUpdatesForUpstreamOptimizations(): Unit = { hc val range = TableKeyBy(TableRange(10, 3), FastSeq()) val simplifiableIR = - If(True(), - GetField(Ref("row", range.typ.rowType), "idx").ceq(0), - False()) + If(True(), GetField(Ref("row", range.typ.rowType), "idx").ceq(0), False()) val checksRepartitioningIR = TableFilter( TableOrderBy(range, FastSeq(SortField("idx", Ascending))), - simplifiableIR) + simplifiableIR, + ) assertEvalsTo(TableAggregate(checksRepartitioningIR, IRAggCount), 1L) } - lazy val base = Literal(TStruct("1" -> TInt32, "2" -> TInt32), Row(1,2)) + lazy val base = Literal(TStruct("1" -> TInt32, "2" -> TInt32), Row(1, 2)) - @Test def testInsertFieldsRewriteRules() { + @Test def testInsertFieldsRewriteRules(): Unit = { val ir1 = InsertFields(InsertFields(base, Seq("1" -> I32(2)), None), Seq("1" -> I32(3)), None) assert(Simplify(ctx, ir1) == InsertFields(base, Seq("1" -> I32(3)), Some(FastSeq("1", "2")))) - val ir2 = InsertFields(InsertFields(base, Seq("3" -> I32(2)), Some(FastSeq("3", "1", "2"))), Seq("3" -> I32(3)), None) - assert(Simplify(ctx, ir2) == InsertFields(base, Seq("3" -> I32(3)), Some(FastSeq("3", "1", "2")))) - - val ir3 = InsertFields(InsertFields(base, Seq("3" -> I32(2)), Some(FastSeq("3", "1", "2"))), Seq("4" -> I32(3)), Some(FastSeq("3", "1", "2", "4"))) - assert(Simplify(ctx, ir3) == InsertFields(base, Seq("3" -> I32(2), "4" -> I32(3)), Some(FastSeq("3", "1", "2", "4")))) - - val ir4 = InsertFields(InsertFields(base, Seq("3" -> I32(0), "4" -> I32(1))), Seq("3" -> I32(5))) - assert(Simplify(ctx, ir4) == InsertFields(base, Seq("4" -> I32(1), "3" -> I32(5)), Some(FastSeq("1", "2", "3", "4")))) + val ir2 = InsertFields( + InsertFields(base, Seq("3" -> I32(2)), Some(FastSeq("3", "1", "2"))), + Seq("3" -> I32(3)), + None, + ) + assert(Simplify(ctx, ir2) == InsertFields( + base, + Seq("3" -> I32(3)), + Some(FastSeq("3", "1", "2")), + )) + + val ir3 = InsertFields( + InsertFields(base, Seq("3" -> I32(2)), Some(FastSeq("3", "1", "2"))), + Seq("4" -> I32(3)), + Some(FastSeq("3", "1", "2", "4")), + ) + assert(Simplify(ctx, ir3) == InsertFields( + base, + Seq("3" -> I32(2), "4" -> I32(3)), + Some(FastSeq("3", "1", "2", "4")), + )) + + val ir4 = + InsertFields(InsertFields(base, Seq("3" -> I32(0), "4" -> I32(1))), Seq("3" -> I32(5))) + assert(Simplify(ctx, ir4) == InsertFields( + base, + Seq("4" -> I32(1), "3" -> I32(5)), + Some(FastSeq("1", "2", "3", "4")), + )) } - lazy val base2 = Literal(TStruct("A" -> TInt32, "B" -> TInt32, "C" -> TInt32, "D" -> TInt32), Row(1, 2, 3, 4)) + lazy val base2 = + Literal(TStruct("A" -> TInt32, "B" -> TInt32, "C" -> TInt32, "D" -> TInt32), Row(1, 2, 3, 4)) + @Test def 
testInsertFieldsWhereFieldBeingInsertedCouldBeSelected(): Unit = { val ir1 = - InsertFields( - SelectFields(base2, IndexedSeq("A", "B", "C")), - IndexedSeq("B" -> GetField(base2, "B")), - None - ) + InsertFields( + SelectFields(base2, IndexedSeq("A", "B", "C")), + IndexedSeq("B" -> GetField(base2, "B")), + None, + ) val simplify1 = Simplify(ctx, ir1) assert(simplify1.typ == ir1.typ) } - @Test def testInsertSelectRewriteRules() { + @Test def testInsertSelectRewriteRules(): Unit = { val ir1 = SelectFields(InsertFields(base, FastSeq("3" -> I32(1)), None), FastSeq("1")) assert(Simplify(ctx, ir1) == SelectFields(base, FastSeq("1"))) val ir2 = SelectFields(InsertFields(base, FastSeq("3" -> I32(1)), None), FastSeq("3", "1")) - assert(Simplify(ctx, ir2) == InsertFields(SelectFields(base, FastSeq("1")), FastSeq("3" -> I32(1)), Some(FastSeq("3", "1")))) + assert(Simplify(ctx, ir2) == InsertFields( + SelectFields(base, FastSeq("1")), + FastSeq("3" -> I32(1)), + Some(FastSeq("3", "1")), + )) } - @Test def testContainsRewrites() { - assertEvalsTo(invoke("contains", TBoolean, Literal(TArray(TString), FastSeq("a")), In(0, TString)), + @Test def testContainsRewrites(): Unit = { + assertEvalsTo( + invoke("contains", TBoolean, Literal(TArray(TString), FastSeq("a")), In(0, TString)), FastSeq("a" -> TString), - true) + true, + ) - assertEvalsTo(invoke("contains", TBoolean, ToSet(ToStream(In(0, TArray(TString)))), Str("a")), + assertEvalsTo( + invoke("contains", TBoolean, ToSet(ToStream(In(0, TArray(TString)))), Str("a")), FastSeq(FastSeq("a") -> TArray(TString)), - true) - + true, + ) - assertEvalsTo(invoke("contains", TBoolean, ToArray(ToStream(In(0, TSet(TString)))), Str("a")), + assertEvalsTo( + invoke("contains", TBoolean, ToArray(ToStream(In(0, TSet(TString)))), Str("a")), FastSeq(Set("a") -> TSet(TString)), - true) + true, + ) } - @Test def testTableCountExplodeSetRewrite() { + @Test def testTableCountExplodeSetRewrite(): Unit = { var ir: TableIR = TableRange(1, 1) - ir = TableMapRows(ir, InsertFields(Ref("row", ir.typ.rowType), Seq("foo" -> Literal(TSet(TInt32), Set(1))))) + ir = TableMapRows( + ir, + InsertFields(Ref("row", ir.typ.rowType), Seq("foo" -> Literal(TSet(TInt32), Set(1)))), + ) ir = TableExplode(ir, FastSeq("foo")) assertEvalsTo(TableCount(ir), 1L) } @@ -109,44 +142,56 @@ class SimplifySuite extends HailSuite { Array( Array( - Let(FastSeq(r2.name -> InsertFields(r, FastSeq("y" -> F64(0)))), - InsertFields(r2, FastSeq("z" -> GetField(r2, "x").toD)) + Let( + FastSeq(r2.name -> InsertFields(r, FastSeq("y" -> F64(0)))), + InsertFields(r2, FastSeq("z" -> GetField(r2, "x").toD)), ), - Let(FastSeq(iruid(0L) -> F64(0), r2.name -> r), - InsertFields(Ref(r2.name, r.typ), + Let( + FastSeq(iruid(0L) -> F64(0), r2.name -> r), + InsertFields( + Ref(r2.name, r.typ), FastSeq( "y" -> Ref(iruid(0L), TFloat64), - "z" -> GetField(Ref(r2.name, r.typ), "x").toD + "z" -> GetField(Ref(r2.name, r.typ), "x").toD, ), - Some(FastSeq("x", "y", "z")) - ) - ) + Some(FastSeq("x", "y", "z")), + ), + ), ), Array( - Let(FastSeq(r2.name -> InsertFields(r, FastSeq("y" -> F64(0)))), - InsertFields(r2, FastSeq("z" -> (GetField(r2, "x").toD + GetField(r2, "y")))) + Let( + FastSeq(r2.name -> InsertFields(r, FastSeq("y" -> F64(0)))), + InsertFields(r2, FastSeq("z" -> (GetField(r2, "x").toD + GetField(r2, "y")))), ), - Let(FastSeq(iruid(0) -> F64(0), r2.name -> r), - InsertFields(Ref(r2.name, r.typ), + Let( + FastSeq(iruid(0) -> F64(0), r2.name -> r), + InsertFields( + Ref(r2.name, r.typ), FastSeq( "y" -> Ref(iruid(0), TFloat64), - 
"z" -> (GetField(Ref(r2.name, r.typ), "x").toD + Ref(iruid(0), TFloat64)) + "z" -> (GetField(Ref(r2.name, r.typ), "x").toD + Ref(iruid(0), TFloat64)), ), - Some(FastSeq("x", "y", "z")) - ) - ) + Some(FastSeq("x", "y", "z")), + ), + ), ), Array( - Let(FastSeq(r2.name -> InsertFields(r, FastSeq("y" -> F64(0)))), - InsertFields(Ref("something_else", TStruct.empty), FastSeq("z" -> GetField(r2, "y").toI)) + Let( + FastSeq(r2.name -> InsertFields(r, FastSeq("y" -> F64(0)))), + InsertFields(Ref("something_else", TStruct.empty), FastSeq("z" -> GetField(r2, "y").toI)), + ), + Let( + FastSeq(iruid(0) -> F64(0), r2.name -> r), + InsertFields( + Ref("something_else", TStruct.empty), + FastSeq("z" -> Ref(iruid(0), TFloat64).toI), + ), ), - Let(FastSeq(iruid(0) -> F64(0), r2.name -> r), - InsertFields(Ref("something_else", TStruct.empty), FastSeq("z" -> Ref(iruid(0), TFloat64).toI)) - ) ), Array.fill(2) { // unrewriteable - Let(FastSeq(r2.name -> InsertFields(r, FastSeq("y" -> Ref("other", TFloat64)))), - InsertFields(r2, FastSeq(("z", invoke("str", TString, r2)))) + Let( + FastSeq(r2.name -> InsertFields(r, FastSeq("y" -> Ref("other", TFloat64)))), + InsertFields(r2, FastSeq(("z", invoke("str", TString, r2)))), ) }, Array( @@ -154,9 +199,9 @@ class SimplifySuite extends HailSuite { FastSeq( "a" -> I32(32), r2.name -> InsertFields(r, FastSeq("y" -> F64(0))), - r3.name -> InsertFields(r2, FastSeq("w" -> Ref("a", TInt32))) + r3.name -> InsertFields(r2, FastSeq("w" -> Ref("a", TInt32))), ), - InsertFields(r3, FastSeq("z" -> (GetField(r3, "x").toD + GetField(r3, "y")))) + InsertFields(r3, FastSeq("z" -> (GetField(r3, "x").toD + GetField(r3, "y")))), ), Let( FastSeq( @@ -165,30 +210,31 @@ class SimplifySuite extends HailSuite { r2.name -> r, iruid(1) -> Ref(iruid(0), TFloat64), iruid(2) -> Ref("a", TInt32), - r3.name -> Ref(r2.name, r.typ) + r3.name -> Ref(r2.name, r.typ), ), InsertFields( Ref(r3.name, r.typ), FastSeq( "y" -> Ref(iruid(1), TFloat64), "w" -> Ref(iruid(2), TInt32), - "z" -> (GetField(Ref(r3.name, r.typ), "x").toD + Ref(iruid(1), TFloat64)) + "z" -> (GetField(Ref(r3.name, r.typ), "x").toD + Ref(iruid(1), TFloat64)), ), - Some(FastSeq("x", "y", "w", "z")) - ) - ) + Some(FastSeq("x", "y", "w", "z")), + ), + ), ), Array( Let( FastSeq( - "a" -> Let(FastSeq("b" -> (I32(1) + Ref("OTHER_1", TInt32))), + "a" -> Let( + FastSeq("b" -> (I32(1) + Ref("OTHER_1", TInt32))), InsertFields( Ref("TOP", TStruct("foo" -> TInt32)), FastSeq( "field0" -> Ref("b", TInt32), - "field1" -> (I32(1) + Ref("b", TInt32)) - ) - ) + "field1" -> (I32(1) + Ref("b", TInt32)), + ), + ), ) ), InsertFields( @@ -197,34 +243,34 @@ class SimplifySuite extends HailSuite { "field2" -> (I32(1) + GetField( Ref("a", TStruct("foo" -> TInt32, "field0" -> TInt32, "field1" -> TInt32)), - "field1" + "field1", )) - ) - ) + ), + ), ), Let( FastSeq( "b" -> (I32(1) + Ref("OTHER_1", TInt32)), iruid(0) -> Ref("b", TInt32), iruid(1) -> (I32(1) + Ref("b", TInt32)), - "a" -> Ref("TOP", TStruct("foo" -> TInt32)) + "a" -> Ref("TOP", TStruct("foo" -> TInt32)), ), InsertFields( Ref("a", TStruct("foo" -> TInt32)), FastSeq( "field0" -> Ref(iruid(0), TInt32), "field1" -> Ref(iruid(1), TInt32), - "field2" -> (I32(1) + Ref(iruid(1), TInt32)) + "field2" -> (I32(1) + Ref(iruid(1), TInt32)), ), - Some(FastSeq("foo", "field0", "field1", "field2")) - ) - ) - ) + Some(FastSeq("foo", "field0", "field1", "field2")), + ), + ), + ), ) } @Test(dataProvider = "NestedInserts") - def testNestedInsertsSimplify(input: IR, expected: IR): Unit = { + def 
testNestedInsertsSimplify(input: IR, expected: IR): Unit = { val actual = Simplify(ctx, input) actual should be(expected) } @@ -232,47 +278,65 @@ class SimplifySuite extends HailSuite { @Test def testArrayAggNoAggRewrites(): Unit = { val doesRewrite: Array[StreamAgg] = Array( StreamAgg(In(0, TArray(TInt32)), "foo", Ref("x", TInt32)), - StreamAgg(In(0, TArray(TInt32)), "foo", - AggLet("bar", In(1, TInt32) * In(1, TInt32), Ref("x", TInt32), true))) + StreamAgg( + In(0, TArray(TInt32)), + "foo", + AggLet("bar", In(1, TInt32) * In(1, TInt32), Ref("x", TInt32), true), + ), + ) - doesRewrite.foreach { a => - assert(Simplify(ctx, a) == a.query) - } + doesRewrite.foreach(a => assert(Simplify(ctx, a) == a.query)) val doesNotRewrite: Array[StreamAgg] = Array( - StreamAgg(In(0, TArray(TInt32)), "foo", - ApplyAggOp(FastSeq(), FastSeq(Ref("foo", TInt32)), - AggSignature(Sum(), FastSeq(), FastSeq(TInt32)))), - StreamAgg(In(0, TArray(TInt32)), "foo", - AggLet("bar", In(1, TInt32) * In(1, TInt32), Ref("x", TInt32), false)) + StreamAgg( + In(0, TArray(TInt32)), + "foo", + ApplyAggOp( + FastSeq(), + FastSeq(Ref("foo", TInt32)), + AggSignature(Sum(), FastSeq(), FastSeq(TInt32)), + ), + ), + StreamAgg( + In(0, TArray(TInt32)), + "foo", + AggLet("bar", In(1, TInt32) * In(1, TInt32), Ref("x", TInt32), false), + ), ) - doesNotRewrite.foreach { a => - assert(Simplify(ctx, a) == a) - } + doesNotRewrite.foreach(a => assert(Simplify(ctx, a) == a)) } @Test def testArrayAggScanNoAggRewrites(): Unit = { val doesRewrite: Array[StreamAggScan] = Array( StreamAggScan(In(0, TArray(TInt32)), "foo", Ref("x", TInt32)), - StreamAggScan(In(0, TArray(TInt32)), "foo", - AggLet("bar", In(1, TInt32) * In(1, TInt32), Ref("x", TInt32), false))) + StreamAggScan( + In(0, TArray(TInt32)), + "foo", + AggLet("bar", In(1, TInt32) * In(1, TInt32), Ref("x", TInt32), false), + ), + ) - doesRewrite.foreach { a => - assert(Simplify(ctx, a) == a.query) - } + doesRewrite.foreach(a => assert(Simplify(ctx, a) == a.query)) val doesNotRewrite: Array[StreamAggScan] = Array( - StreamAggScan(In(0, TArray(TInt32)), "foo", - ApplyScanOp(FastSeq(), FastSeq(Ref("foo", TInt32)), - AggSignature(Sum(), FastSeq(), FastSeq(TInt64)))), - StreamAggScan(In(0, TArray(TInt32)), "foo", - AggLet("bar", In(1, TInt32) * In(1, TInt32), Ref("x", TInt32), true)) + StreamAggScan( + In(0, TArray(TInt32)), + "foo", + ApplyScanOp( + FastSeq(), + FastSeq(Ref("foo", TInt32)), + AggSignature(Sum(), FastSeq(), FastSeq(TInt64)), + ), + ), + StreamAggScan( + In(0, TArray(TInt32)), + "foo", + AggLet("bar", In(1, TInt32) * In(1, TInt32), Ref("x", TInt32), true), + ), ) - doesNotRewrite.foreach { a => - assert(Simplify(ctx, a) == a) - } + doesNotRewrite.foreach(a => assert(Simplify(ctx, a) == a)) } @Test def testArrayLenCollectToTableCount(): Unit = { @@ -288,18 +352,25 @@ class SimplifySuite extends HailSuite { val reader = MatrixRangeReader(1, 1, None) var mir: MatrixIR = MatrixRead(reader.fullMatrixType, false, false, reader) val colType = reader.fullMatrixType.colType - mir = MatrixMapCols(mir, AggLet("foo", I32(1), InsertFields(Ref("sa", colType), FastSeq(("bar", I32(2)))), false), None) + mir = MatrixMapCols( + mir, + AggLet("foo", I32(1), InsertFields(Ref("sa", colType), FastSeq(("bar", I32(2)))), false), + None, + ) val tir = MatrixColsTable(mir) assert(Simplify(ctx, tir) == tir) } - @Test def testFilterParallelize() { - for (rowsAndGlobals <- Array( - MakeStruct(FastSeq( - ("rows", In(0, TArray(TStruct("x" -> TInt32)))), - ("global", In(1, TStruct.empty)))), - In(0, 
TStruct("rows" -> TArray(TStruct("x" -> TInt32)), "global" -> TStruct.empty))) + @Test def testFilterParallelize(): Unit = { + for ( + rowsAndGlobals <- Array( + MakeStruct(FastSeq( + ("rows", In(0, TArray(TStruct("x" -> TInt32)))), + ("global", In(1, TStruct.empty)), + )), + In(0, TStruct("rows" -> TArray(TStruct("x" -> TInt32)), "global" -> TStruct.empty)), + ) ) { val tp = TableParallelize(rowsAndGlobals, None) val tf = TableFilter(tp, GetField(Ref("row", tp.typ.rowType), "x") < 100) @@ -313,25 +384,29 @@ class SimplifySuite extends HailSuite { @Test def testStreamLenSimplifications(): Unit = { val rangeIR = StreamRange(I32(0), I32(10), I32(1)) val mapOfRange = mapIR(rangeIR)(range_element => range_element + 5) - val mapBlockedByLet = bindIR(I32(5))(ref => mapIR(rangeIR)(range_element => range_element + ref)) + val mapBlockedByLet = + bindIR(I32(5))(ref => mapIR(rangeIR)(range_element => range_element + ref)) assert(Simplify(ctx, StreamLen(rangeIR)) == Simplify(ctx, StreamLen(mapOfRange))) assert(Simplify(ctx, StreamLen(mapBlockedByLet)) match { - case Let(_, body) => body == Simplify(ctx, StreamLen(mapOfRange)) + case Block(_, body) => body == Simplify(ctx, StreamLen(mapOfRange)) }) } - @Test def testNestedFilterIntervals() { + @Test def testNestedFilterIntervals(): Unit = { var tir: TableIR = TableRange(10, 5) def r = Ref("row", tir.typ.rowType) - tir = TableMapRows(tir, InsertFields(r, FastSeq("idx2" -> GetField(r, "idx")))) + tir = TableMapRows(tir, InsertFields(r, FastSeq("idx2" -> GetField(r, "idx")))) tir = TableKeyBy(tir, FastSeq("idx", "idx2")) tir = TableFilterIntervals(tir, FastSeq(Interval(Row(0), Row(1), true, false)), false) tir = TableFilterIntervals(tir, FastSeq(Interval(Row(8), Row(10), true, false)), false) - assert(Simplify(ctx, tir).asInstanceOf[TableFilterIntervals].intervals == FastSeq(Interval(Row(0), Row(1), true, false), Interval(Row(8), Row(10), true, false))) + assert(Simplify(ctx, tir).asInstanceOf[TableFilterIntervals].intervals == FastSeq( + Interval(Row(0), Row(1), true, false), + Interval(Row(8), Row(10), true, false), + )) } - @Test def testSimplifyReadFilterIntervals() { + @Test def testSimplifyReadFilterIntervals(): Unit = { val src = "src/test/resources/sample-indexed-0.2.52.mt" val mnr = MatrixNativeReader(fs, src, None) @@ -343,38 +418,86 @@ class SimplifySuite extends HailSuite { val tzr = mr.lower().asInstanceOf[TableMapGlobals].child.asInstanceOf[TableRead] val tzrr = tzr.tr.asInstanceOf[TableNativeZippedReader] - val intervals1 = FastSeq(Interval(Row(Locus("1", 100000)), Row(Locus("1", 200000)), true, false), Interval(Row(Locus("2", 100000)), Row(Locus("2", 200000)), true, false)) - val intervals2 = FastSeq(Interval(Row(Locus("1", 150000)), Row(Locus("1", 250000)), true, false), Interval(Row(Locus("2", 150000)), Row(Locus("2", 250000)), true, false)) - val intersection = FastSeq(Interval(Row(Locus("1", 150000)), Row(Locus("1", 200000)), true, false), Interval(Row(Locus("2", 150000)), Row(Locus("2", 200000)), true, false)) + val intervals1 = FastSeq( + Interval(Row(Locus("1", 100000)), Row(Locus("1", 200000)), true, false), + Interval(Row(Locus("2", 100000)), Row(Locus("2", 200000)), true, false), + ) + val intervals2 = FastSeq( + Interval(Row(Locus("1", 150000)), Row(Locus("1", 250000)), true, false), + Interval(Row(Locus("2", 150000)), Row(Locus("2", 250000)), true, false), + ) + val intersection = FastSeq( + Interval(Row(Locus("1", 150000)), Row(Locus("1", 200000)), true, false), + Interval(Row(Locus("2", 150000)), Row(Locus("2", 200000)), 
true, false), + ) val tfi1 = TableFilterIntervals(tr, intervals1, true) - val exp1 = TableRead(tnr.fullType, false, TableNativeReader(fs, TableNativeReaderParameters(src + "/rows", Some(NativeReaderOptions(intervals1, tnr.fullType.keyType, true))))) + val exp1 = TableRead( + tnr.fullType, + false, + TableNativeReader( + fs, + TableNativeReaderParameters( + src + "/rows", + Some(NativeReaderOptions(intervals1, tnr.fullType.keyType, true)), + ), + ), + ) assert(Simplify(ctx, tfi1) == exp1) val tfi2 = TableFilterIntervals(exp1, intervals2, true) - val exp2 = TableRead(tnr.fullType, false, TableNativeReader(fs, TableNativeReaderParameters(src + "/rows", Some(NativeReaderOptions(intersection, tnr.fullType.keyType, true))))) + val exp2 = TableRead( + tnr.fullType, + false, + TableNativeReader( + fs, + TableNativeReaderParameters( + src + "/rows", + Some(NativeReaderOptions(intersection, tnr.fullType.keyType, true)), + ), + ), + ) assert(Simplify(ctx, tfi2) == exp2) val ztfi1 = TableFilterIntervals(tzr, intervals1, true) - val zexp1 = TableRead(tzr.typ, false, tzrr.copy(options = Some(NativeReaderOptions(intervals1, tnr.fullType.keyType, true)))) + val zexp1 = TableRead( + tzr.typ, + false, + tzrr.copy(options = Some(NativeReaderOptions(intervals1, tnr.fullType.keyType, true))), + ) assert(Simplify(ctx, ztfi1) == zexp1) val ztfi2 = TableFilterIntervals(ztfi1, intervals2, true) - val zexp2 = TableRead(tzr.typ, false, tzrr.copy(options = Some(NativeReaderOptions(intersection, tnr.fullType.keyType, true)))) + val zexp2 = TableRead( + tzr.typ, + false, + tzrr.copy(options = Some(NativeReaderOptions(intersection, tnr.fullType.keyType, true))), + ) assert(Simplify(ctx, ztfi2) == zexp2) } - @Test(enabled = false) def testFilterIntervalsKeyByToFilter() { + @Test(enabled = false) def testFilterIntervalsKeyByToFilter(): Unit = { var t: TableIR = TableRange(100, 10) - t = TableMapRows(t, InsertFields(Ref("row", t.typ.rowType), FastSeq(("x", I32(1) - GetField(Ref("row", t.typ.rowType), "idx"))))) + t = TableMapRows( + t, + InsertFields( + Ref("row", t.typ.rowType), + FastSeq(("x", I32(1) - GetField(Ref("row", t.typ.rowType), "idx"))), + ), + ) t = TableKeyBy(t, FastSeq("x")) - t = TableFilterIntervals(t, FastSeq(Interval(Row(-10), Row(10), includesStart = true, includesEnd = false)), keep = true) + t = TableFilterIntervals( + t, + FastSeq(Interval(Row(-10), Row(10), includesStart = true, includesEnd = false)), + keep = true, + ) val t2 = Simplify(ctx, t) assert(t2 match { - case TableKeyBy(TableFilter(child, _), _, _) => !Exists(child, _.isInstanceOf[TableFilterIntervals]) + case TableKeyBy(TableFilter(child, _), _, _) => + !Exists(child, _.isInstanceOf[TableFilterIntervals]) case _ => false }) } @@ -383,30 +506,30 @@ class SimplifySuite extends HailSuite { val stream = StreamRange(I32(0), I32(10), I32(1)) val streamSlice1 = Simplify(ctx, ArraySlice(ToArray(stream), I32(0), Some(I32(7)))) assert(streamSlice1 match { - case ToArray(StreamTake(_,_)) => true + case ToArray(StreamTake(_, _)) => true case _ => false - } ) + }) assertEvalsTo(streamSlice1.asInstanceOf[IR], FastSeq(0, 1, 2, 3, 4, 5, 6)) val streamSlice2 = Simplify(ctx, ArraySlice(ToArray(stream), I32(3), Some(I32(5)))) assert(streamSlice2 match { - case ToArray(StreamTake(StreamDrop(_,_), _)) => true + case ToArray(StreamTake(StreamDrop(_, _), _)) => true case _ => false - } ) + }) assertEvalsTo(streamSlice2.asInstanceOf[IR], FastSeq(3, 4)) val streamSlice3 = Simplify(ctx, ArraySlice(ToArray(stream), I32(6), Some(I32(2)))) assert(streamSlice3 match 
{ case MakeArray(_, _) => true case _ => false - } ) + }) assertEvalsTo(streamSlice3.asInstanceOf[IR], FastSeq()) val streamSlice4 = Simplify(ctx, ArraySlice(ToArray(stream), I32(0), None)) assert(streamSlice4 match { case ToArray(StreamDrop(_, _)) => true case _ => false - } ) + }) assertEvalsTo(streamSlice4.asInstanceOf[IR], FastSeq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)) } @@ -428,8 +551,14 @@ class SimplifySuite extends HailSuite { Array( Array(ApplyUnaryPrimOp(Negate, ApplyUnaryPrimOp(Negate, ref(typ))), ref(typ)), Array(ApplyUnaryPrimOp(BitNot, ApplyUnaryPrimOp(BitNot, ref(typ))), ref(typ)), - Array(ApplyUnaryPrimOp(Negate, ApplyUnaryPrimOp(BitNot, ref(typ))), ApplyUnaryPrimOp(Negate, ApplyUnaryPrimOp(BitNot, ref(typ)))), - Array(ApplyUnaryPrimOp(BitNot, ApplyUnaryPrimOp(Negate, ref(typ))), ApplyUnaryPrimOp(BitNot, ApplyUnaryPrimOp(Negate, ref(typ)))) + Array( + ApplyUnaryPrimOp(Negate, ApplyUnaryPrimOp(BitNot, ref(typ))), + ApplyUnaryPrimOp(Negate, ApplyUnaryPrimOp(BitNot, ref(typ))), + ), + Array( + ApplyUnaryPrimOp(BitNot, ApplyUnaryPrimOp(Negate, ref(typ))), + ApplyUnaryPrimOp(BitNot, ApplyUnaryPrimOp(Negate, ref(typ))), + ), ).asInstanceOf[Array[Array[Any]]] } @@ -437,61 +566,80 @@ class SimplifySuite extends HailSuite { def testUnaryIntegralSimplification(input: IR, expected: IR): Unit = assert(Simplify(ctx, input) == expected) - @DataProvider(name="binaryIntegralArithmetic") + @DataProvider(name = "binaryIntegralArithmetic") def binaryIntegralArithmetic: Array[Array[Any]] = - Array((Literal.coerce(TInt32, _)) -> TInt32, (Literal.coerce(TInt64, _)) -> TInt64).flatMap { case (pure, typ) => - Array.concat( - Array( - // Addition - Array(ApplyBinaryPrimOp(Add(), ref(typ), ref(typ)), ApplyBinaryPrimOp(Multiply(), pure(2), ref(typ))), - Array(ApplyBinaryPrimOp(Add(), pure(0), ref(typ)), ref(typ)), - Array(ApplyBinaryPrimOp(Add(), ref(typ), pure(0)), ref(typ)), - - // Subtraction - Array(ApplyBinaryPrimOp(Subtract(), ref(typ), ref(typ)), pure(0)), - Array(ApplyBinaryPrimOp(Subtract(), pure(0), ref(typ)), ApplyUnaryPrimOp(Negate, ref(typ))), - Array(ApplyBinaryPrimOp(Subtract(), ref(typ), pure(0)), ref(typ)), - - // Multiplication - Array(ApplyBinaryPrimOp(Multiply(), pure(0), ref(typ)), pure(0)), - Array(ApplyBinaryPrimOp(Multiply(), ref(typ), pure(0)), pure(0)), - Array(ApplyBinaryPrimOp(Multiply(), pure(1), ref(typ)), ref(typ)), - Array(ApplyBinaryPrimOp(Multiply(), ref(typ), pure(1)), ref(typ)), - Array(ApplyBinaryPrimOp(Multiply(), pure(-1), ref(typ)), ApplyUnaryPrimOp(Negate, ref(typ))), - Array(ApplyBinaryPrimOp(Multiply(), ref(typ), pure(-1)), ApplyUnaryPrimOp(Negate, ref(typ))), - - // Div (truncated to -Inf) - Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), ref(typ)), pure(1)), - Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), pure(0), ref(typ)), pure(0)), - Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(0)), Die("division by zero", typ)), - Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(1)), ref(typ)), - Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(-1)), ApplyUnaryPrimOp(Negate, ref(typ))), - - // Bitwise And - Array(ApplyBinaryPrimOp(BitAnd(), pure(0), ref(typ)), pure(0)), - Array(ApplyBinaryPrimOp(BitAnd(), ref(typ), pure(0)), pure(0)), - Array(ApplyBinaryPrimOp(BitAnd(), pure(-1), ref(typ)), ref(typ)), - Array(ApplyBinaryPrimOp(BitAnd(), ref(typ), pure(-1)), ref(typ)), - - // Bitwise Or - Array(ApplyBinaryPrimOp(BitOr(), pure(0), ref(typ)), ref(typ)), - Array(ApplyBinaryPrimOp(BitOr(), ref(typ), pure(0)), ref(typ)), - 
Array(ApplyBinaryPrimOp(BitOr(), pure(-1), ref(typ)), pure(-1)), - Array(ApplyBinaryPrimOp(BitOr(), ref(typ), pure(-1)), pure(-1)), - - // Bitwise Xor - Array(ApplyBinaryPrimOp(BitXOr(), ref(typ), ref(typ)), pure(0)), - Array(ApplyBinaryPrimOp(BitXOr(), ref(typ), pure(0)), ref(typ)), - Array(ApplyBinaryPrimOp(BitXOr(), pure(0), ref(typ)), ref(typ)), - ).asInstanceOf[Array[Array[Any]]], - // Shifts - Array(LeftShift(), RightShift(), LogicalRightShift()).flatMap { shift => + Array((Literal.coerce(TInt32, _)) -> TInt32, (Literal.coerce(TInt64, _)) -> TInt64).flatMap { + case (pure, typ) => + Array.concat( Array( - Array(ApplyBinaryPrimOp(shift, pure(0), ref(TInt32)), pure(0)), - Array(ApplyBinaryPrimOp(shift, ref(typ), I32(0)), ref(typ)) - ) - }.asInstanceOf[Array[Array[Any]]] - ) + // Addition + Array( + ApplyBinaryPrimOp(Add(), ref(typ), ref(typ)), + ApplyBinaryPrimOp(Multiply(), pure(2), ref(typ)), + ), + Array(ApplyBinaryPrimOp(Add(), pure(0), ref(typ)), ref(typ)), + Array(ApplyBinaryPrimOp(Add(), ref(typ), pure(0)), ref(typ)), + + // Subtraction + Array(ApplyBinaryPrimOp(Subtract(), ref(typ), ref(typ)), pure(0)), + Array( + ApplyBinaryPrimOp(Subtract(), pure(0), ref(typ)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ), + Array(ApplyBinaryPrimOp(Subtract(), ref(typ), pure(0)), ref(typ)), + + // Multiplication + Array(ApplyBinaryPrimOp(Multiply(), pure(0), ref(typ)), pure(0)), + Array(ApplyBinaryPrimOp(Multiply(), ref(typ), pure(0)), pure(0)), + Array(ApplyBinaryPrimOp(Multiply(), pure(1), ref(typ)), ref(typ)), + Array(ApplyBinaryPrimOp(Multiply(), ref(typ), pure(1)), ref(typ)), + Array( + ApplyBinaryPrimOp(Multiply(), pure(-1), ref(typ)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ), + Array( + ApplyBinaryPrimOp(Multiply(), ref(typ), pure(-1)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ), + + // Div (truncated to -Inf) + Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), ref(typ)), pure(1)), + Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), pure(0), ref(typ)), pure(0)), + Array( + ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(0)), + Die("division by zero", typ), + ), + Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(1)), ref(typ)), + Array( + ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(-1)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ), + + // Bitwise And + Array(ApplyBinaryPrimOp(BitAnd(), pure(0), ref(typ)), pure(0)), + Array(ApplyBinaryPrimOp(BitAnd(), ref(typ), pure(0)), pure(0)), + Array(ApplyBinaryPrimOp(BitAnd(), pure(-1), ref(typ)), ref(typ)), + Array(ApplyBinaryPrimOp(BitAnd(), ref(typ), pure(-1)), ref(typ)), + + // Bitwise Or + Array(ApplyBinaryPrimOp(BitOr(), pure(0), ref(typ)), ref(typ)), + Array(ApplyBinaryPrimOp(BitOr(), ref(typ), pure(0)), ref(typ)), + Array(ApplyBinaryPrimOp(BitOr(), pure(-1), ref(typ)), pure(-1)), + Array(ApplyBinaryPrimOp(BitOr(), ref(typ), pure(-1)), pure(-1)), + + // Bitwise Xor + Array(ApplyBinaryPrimOp(BitXOr(), ref(typ), ref(typ)), pure(0)), + Array(ApplyBinaryPrimOp(BitXOr(), ref(typ), pure(0)), ref(typ)), + Array(ApplyBinaryPrimOp(BitXOr(), pure(0), ref(typ)), ref(typ)), + ).asInstanceOf[Array[Array[Any]]], + // Shifts + Array(LeftShift(), RightShift(), LogicalRightShift()).flatMap { shift => + Array( + Array(ApplyBinaryPrimOp(shift, pure(0), ref(TInt32)), pure(0)), + Array(ApplyBinaryPrimOp(shift, ref(typ), I32(0)), ref(typ)), + ) + }.asInstanceOf[Array[Array[Any]]], + ) } @Test(dataProvider = "binaryIntegralArithmetic") @@ -500,7 +648,10 @@ class SimplifySuite extends HailSuite { @DataProvider(name = 
"floatingIntegralArithmetic") def binaryFloatingArithmetic: Array[Array[Any]] = - Array((Literal.coerce(TFloat32, _)) -> TFloat32, (Literal.coerce(TFloat64, _)) -> TFloat64).flatMap { case (pure, typ) => + Array( + (Literal.coerce(TFloat32, _)) -> TFloat32, + (Literal.coerce(TFloat64, _)) -> TFloat64, + ).flatMap { case (pure, typ) => Array( // Addition Array(ApplyBinaryPrimOp(Add(), pure(0), ref(typ)), ref(typ)), @@ -513,12 +664,21 @@ class SimplifySuite extends HailSuite { // Multiplication Array(ApplyBinaryPrimOp(Multiply(), pure(1), ref(typ)), ref(typ)), Array(ApplyBinaryPrimOp(Multiply(), ref(typ), pure(1)), ref(typ)), - Array(ApplyBinaryPrimOp(Multiply(), pure(-1), ref(typ)), ApplyUnaryPrimOp(Negate, ref(typ))), - Array(ApplyBinaryPrimOp(Multiply(), ref(typ), pure(-1)), ApplyUnaryPrimOp(Negate, ref(typ))), + Array( + ApplyBinaryPrimOp(Multiply(), pure(-1), ref(typ)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ), + Array( + ApplyBinaryPrimOp(Multiply(), ref(typ), pure(-1)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ), // Div (truncated to -Inf) Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(1)), ref(typ)), - Array(ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(-1)), ApplyUnaryPrimOp(Negate, ref(typ))), + Array( + ApplyBinaryPrimOp(RoundToNegInfDivide(), ref(typ), pure(-1)), + ApplyUnaryPrimOp(Negate, ref(typ)), + ), ).asInstanceOf[Array[Array[Any]]] } @@ -532,24 +692,30 @@ class SimplifySuite extends HailSuite { ValueToBlockMatrix( MakeArray((1 to 4).map(F64(_)), TArray(TFloat64)), FastSeq(2, 2), - 10 + 10, ) Array( Array(BlockMatrixBroadcast(matrix, 0 to 1, matrix.shape, matrix.blockSize), matrix), Array(BlockMatrixMap(matrix, "x", Ref("x", TFloat64), true), matrix), - Array(BlockMatrixMap(matrix, "x", ref(TFloat64), true), BlockMatrixBroadcast( - ValueToBlockMatrix(ref(TFloat64), FastSeq(1, 1), matrix.blockSize), - FastSeq(), - matrix.shape, - matrix.blockSize - )), - Array(BlockMatrixMap(matrix, "x", F64(2356), true), BlockMatrixBroadcast( - ValueToBlockMatrix(F64(2356), FastSeq(1, 1), matrix.blockSize), - FastSeq(), - matrix.shape, - matrix.blockSize - )), + Array( + BlockMatrixMap(matrix, "x", ref(TFloat64), true), + BlockMatrixBroadcast( + ValueToBlockMatrix(ref(TFloat64), FastSeq(1, 1), matrix.blockSize), + FastSeq(), + matrix.shape, + matrix.blockSize, + ), + ), + Array( + BlockMatrixMap(matrix, "x", F64(2356), true), + BlockMatrixBroadcast( + ValueToBlockMatrix(F64(2356), FastSeq(1, 1), matrix.blockSize), + FastSeq(), + matrix.shape, + matrix.blockSize, + ), + ), ).asInstanceOf[Array[Array[Any]]] } @@ -562,15 +728,80 @@ class SimplifySuite extends HailSuite { Array( Array(I32(-1), I32(-1), IndexedSeq.tabulate(5)(I32), I32(-1)), Array(I32(1), I32(-1), IndexedSeq.tabulate(5)(I32), I32(1)), - Array(ref(TInt32), I32(-1), IndexedSeq.tabulate(5)(I32), Switch(ref(TInt32), I32(-1), IndexedSeq.tabulate(5)(I32))), + Array( + ref(TInt32), + I32(-1), + IndexedSeq.tabulate(5)(I32), + Switch(ref(TInt32), I32(-1), IndexedSeq.tabulate(5)(I32)), + ), Array(I32(256), I32(-1), IndexedSeq.empty[IR], I32(-1)), - Array(ref(TInt32), I32(-1), IndexedSeq.empty[IR], Switch(ref(TInt32), I32(-1), IndexedSeq.empty[IR])), // missingness + Array( + ref(TInt32), + I32(-1), + IndexedSeq.empty[IR], + Switch(ref(TInt32), I32(-1), IndexedSeq.empty[IR]), + ), // missingness ) @Test(dataProvider = "SwitchRules") def testTestSwitchSimplification(x: IR, default: IR, cases: IndexedSeq[IR], expected: Any): Unit = assert(Simplify(ctx, Switch(x, default, cases)) == expected) -} + @DataProvider(name = 
"IfRules") + def ifRules: Array[Array[Any]] = { + val x = Ref(genUID(), TInt32) + val y = Ref(genUID(), TInt32) + val c = Ref(genUID(), TBoolean) + + Array( + Array(True(), x, Die("Failure", x.typ), x), + Array(False(), Die("Failure", x.typ), x, x), + Array(IsNA(x), NA(x.typ), x, x), + Array(ApplyUnaryPrimOp(Bang, c), x, y, If(c, y, x)), + Array(c, If(c, x, y), y, If(c, x, y)), + Array(c, x, If(c, x, y), If(c, x, y)), + Array(c, x, x, If(IsNA(c), NA(x.typ), x)), + ) + } + + @Test(dataProvider = "IfRules") + def testIfSimplification(pred: IR, cnsq: IR, altr: IR, expected: Any): Unit = + assert(Simplify(ctx, If(pred, cnsq, altr)) == expected) + @DataProvider(name = "MakeStructRules") + def makeStructRules: Array[Array[Any]] = { + val s = ref(TStruct( + "a" -> TInt32, + "b" -> TInt64, + "c" -> TFloat32, + )) + def get(name: String) = GetField(s, name) + + Array( + Array( + FastSeq("x" -> get("a")), + CastRename(SelectFields(s, FastSeq("a")), TStruct("x" -> TInt32)), + ), + Array( + FastSeq("x" -> get("a"), "y" -> get("b")), + CastRename(SelectFields(s, FastSeq("a", "b")), TStruct("x" -> TInt32, "y" -> TInt64)), + ), + Array( + FastSeq("a" -> get("a"), "b" -> get("b")), + SelectFields(s, FastSeq("a", "b")), + ), + Array( + FastSeq("a" -> get("a"), "b" -> get("b"), "c" -> get("c")), + s, + ), + ) + } + + @Test(dataProvider = "MakeStructRules") + def testMakeStruct(fields: IndexedSeq[(String, IR)], expected: IR): Unit = { + val x = Simplify(ctx, MakeStruct(fields)) + assert(x == expected) + } + +} diff --git a/hail/src/test/scala/is/hail/expr/ir/StagedBTreeSuite.scala b/hail/src/test/scala/is/hail/expr/ir/StagedBTreeSuite.scala index 82642328b63..8c923f5bc59 100644 --- a/hail/src/test/scala/is/hail/expr/ir/StagedBTreeSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/StagedBTreeSuite.scala @@ -1,6 +1,5 @@ package is.hail.expr.ir -import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import is.hail.HailSuite import is.hail.annotations.Region import is.hail.asm4s._ @@ -8,36 +7,48 @@ import is.hail.backend.ExecuteContext import is.hail.check.{Gen, Prop} import is.hail.expr.ir.agg._ import is.hail.expr.ir.orderings.CodeOrdering -import is.hail.types.physical._ import is.hail.io.{InputBuffer, OutputBuffer, StreamBufferSpec} +import is.hail.types.physical._ import is.hail.types.physical.stypes.Int64SingleCodeType import is.hail.types.physical.stypes.interfaces.primitive import is.hail.types.physical.stypes.primitives.SInt64 import is.hail.utils._ -import org.testng.annotations.Test import scala.collection.mutable + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import org.testng.annotations.Test + class TestBTreeKey(mb: EmitMethodBuilder[_]) extends BTreeKey { private val comp = mb.ecb.getOrderingFunction(SInt64, SInt64, CodeOrdering.Compare()) - override def storageType: PTuple = PCanonicalTuple(required = true, PInt64(), PCanonicalTuple(false)) + + override def storageType: PTuple = + PCanonicalTuple(required = true, PInt64(), PCanonicalTuple(false)) + override def compType: PType = PInt64() + override def isEmpty(cb: EmitCodeBuilder, off: Code[Long]): Value[Boolean] = storageType.isFieldMissing(cb, off, 1) + override def initializeEmpty(cb: EmitCodeBuilder, off: Code[Long]): Unit = storageType.setFieldMissing(cb, off, 1) def storeKey(cb: EmitCodeBuilder, _off: Code[Long], m: Code[Boolean], v: Code[Long]): Unit = { val off = cb.memoize[Long](_off) storageType.stagedInitialize(cb, off) - cb.if_(m, + cb.if_( + m, storageType.setFieldMissing(cb, off, 0), - cb += 
Region.storeLong(storageType.fieldOffset(off, 0), v) + cb += Region.storeLong(storageType.fieldOffset(off, 0), v), ) } override def copy(cb: EmitCodeBuilder, src: Code[Long], dest: Code[Long]): Unit = cb += Region.copyFrom(src, dest, storageType.byteSize) - override def deepCopy(cb: EmitCodeBuilder, er: EmitRegion, src: Code[Long], dest: Code[Long]): Unit = + + override def deepCopy(cb: EmitCodeBuilder, er: EmitRegion, src: Code[Long], dest: Code[Long]) + : Unit = copy(cb, src, dest) override def compKeys(cb: EmitCodeBuilder, k1: EmitValue, k2: EmitValue): Value[Int] = @@ -46,17 +57,18 @@ class TestBTreeKey(mb: EmitMethodBuilder[_]) extends BTreeKey { override def loadCompKey(cb: EmitCodeBuilder, off: Value[Long]): EmitValue = EmitValue( Some(storageType.isFieldMissing(cb, off, 0)), - primitive(cb.memoize(Region.loadLong(storageType.fieldOffset(off, 0))))) + primitive(cb.memoize(Region.loadLong(storageType.fieldOffset(off, 0)))), + ) } object BTreeBackedSet { - def bulkLoad(ctx: ExecuteContext, region: Region, serialized: Array[Byte], n: Int): BTreeBackedSet = { + def bulkLoad(ctx: ExecuteContext, region: Region, serialized: Array[Byte], n: Int) + : BTreeBackedSet = { val fb = EmitFunctionBuilder[Region, InputBuffer, Long](ctx, "btree_bulk_load") val cb = fb.ecb val root = fb.genFieldThisRef[Long]() val r = fb.genFieldThisRef[Region]() val ib = fb.getCodeParam[InputBuffer](2) - val ib2 = fb.genFieldThisRef[InputBuffer]() val km = fb.genFieldThisRef[Boolean]() val kv = fb.genFieldThisRef[Long]() @@ -76,7 +88,10 @@ object BTreeBackedSet { val inputBuffer = new StreamBufferSpec().buildInputBuffer(new ByteArrayInputStream(serialized)) val set = new BTreeBackedSet(ctx, region, n) - set.root = fb.resultWithIndex()(HailSuite.theHailClassLoader, ctx.fs, ctx.taskContext, region)(region, inputBuffer) + set.root = fb.resultWithIndex()(HailSuite.theHailClassLoader, ctx.fs, ctx.taskContext, region)( + region, + inputBuffer, + ) set } } @@ -119,9 +134,7 @@ class BTreeBackedSet(ctx: ExecuteContext, region: Region, n: Int) { cb.assign(r, fb.getCodeParam[Region](1)) cb.assign(root, fb.getCodeParam[Long](2)) cb.assign(elt, btree.getOrElseInitialize(cb, ec)) - cb.if_(key.isEmpty(cb, elt), { - key.storeKey(cb, elt, m, v) - }) + cb.if_(key.isEmpty(cb, elt), key.storeKey(cb, elt, m, v)) root } fb.resultWithIndex()(HailSuite.theHailClassLoader, ctx.fs, ctx.taskContext, region) @@ -146,19 +159,20 @@ class BTreeBackedSet(ctx: ExecuteContext, region: Region, n: Int) { btree.foreach(cb) { (cb, _koff) => val koff = cb.memoize(_koff) val ec = key.loadCompKey(cb, koff) - cb.if_(ec.m, - sab.addMissing(cb), - sab.add(cb, ec.pv.asInt64.value)) + cb.if_(ec.m, sab.addMissing(cb), sab.add(cb, ec.pv.asInt64.value)) } cb += (returnArray := Code.newArray[java.lang.Long](sab.size)) cb.for_( cb.assign(idx, 0), idx < sab.size, cb.assign(idx, idx + 1), - cb += returnArray.update(idx, sab.isMissing(idx).mux( - Code._null[java.lang.Long], - Code.boxLong(coerce[Long](sab(idx))) - )) + cb += returnArray.update( + idx, + sab.isMissing(idx).mux( + Code._null[java.lang.Long], + Code.boxLong(coerce[Long](sab(idx))), + ), + ), ) returnArray } @@ -184,9 +198,7 @@ class BTreeBackedSet(ctx: ExecuteContext, region: Region, n: Int) { val off = cb.newLocal("off", offc) val ev = cb.memoize(key.loadCompKey(cb, off), "ev") cb += ob.writeBoolean(ev.m) - cb.if_(!ev.m, { - cb += ob.writeLong(ev.pv.asInt64.value) - }) + cb.if_(!ev.m, cb += ob.writeLong(ev.pv.asInt64.value)) } ob2.flush() } @@ -233,7 +245,8 @@ class StagedBTreeSuite extends HailSuite { 3 
-> Gen.choose(-10, 10), 5 -> Gen.choose(-30, 30), 6 -> Gen.choose(-30, 30), - 22 -> Gen.choose(-3, 3)) + 22 -> Gen.choose(-3, 3), + ) for ((n, values) <- nodeSizeParams) { val testSet = new BTreeBackedSet(ctx, region, n) diff --git a/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala b/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala new file mode 100644 index 00000000000..6c4d4e89607 --- /dev/null +++ b/hail/src/test/scala/is/hail/expr/ir/StagedMinHeapSuite.scala @@ -0,0 +1,211 @@ +package is.hail.expr.ir + +import is.hail.HailSuite +import is.hail.annotations.{Region, SafeIndexedSeq} +import is.hail.asm4s._ +import is.hail.check.Gen +import is.hail.check.Prop.forAll +import is.hail.expr.ir.functions.LocusFunctions +import is.hail.expr.ir.streams.StagedMinHeap +import is.hail.types.physical.{PCanonicalArray, PCanonicalLocus, PInt32Required} +import is.hail.types.physical.stypes.{SType, SValue} +import is.hail.types.physical.stypes.concrete.SIndexablePointerValue +import is.hail.types.physical.stypes.primitives.{SInt32, SInt32Value} +import is.hail.utils.{using, FastSeq} +import is.hail.variant.{Locus, ReferenceGenome} + +import org.scalatest.Matchers.{be, convertToAnyShouldWrapper} +import org.testng.annotations.Test + +sealed trait StagedCoercions[A] { + def ti: TypeInfo[A] + def sType: SType + def fromType(cb: EmitCodeBuilder, region: Value[Region], a: Value[A]): SValue + def toType(cb: EmitCodeBuilder, sa: SValue): Value[A] +} + +sealed trait StagedCoercionInstances { + implicit object StagedIntCoercions extends StagedCoercions[Int] { + override def ti: TypeInfo[Int] = + implicitly + + override def sType: SType = + SInt32 + + override def fromType(cb: EmitCodeBuilder, region: Value[Region], a: Value[Int]): SValue = + new SInt32Value(a) + + override def toType(cb: EmitCodeBuilder, sa: SValue): Value[Int] = + sa.asInt.value + } + + def stagedLocusCoercions(rg: ReferenceGenome): StagedCoercions[Locus] = + new StagedCoercions[Locus] { + override def ti: TypeInfo[Locus] = + implicitly + + override def sType: SType = + PCanonicalLocus(rg.name, required = true).sType + + override def fromType(cb: EmitCodeBuilder, region: Value[Region], a: Value[Locus]): SValue = + LocusFunctions.emitLocus(cb, region, a, sType.storageType().asInstanceOf[PCanonicalLocus]) + + override def toType(cb: EmitCodeBuilder, sa: SValue): Value[Locus] = + sa.asLocus.getLocusObj(cb) + } +} + +class StagedMinHeapSuite extends HailSuite with StagedCoercionInstances { + + @Test def testSorting(): Unit = + forAll((xs: IndexedSeq[Int]) => sort(xs) == xs.sorted).check() + + @Test def testHeapProperty(): Unit = + forAll { (xs: IndexedSeq[Int]) => + val heap = heapify(xs) + (0 until heap.size / 2).forall { i => + ((2 * i + 1) >= heap.size || heap(i) <= heap(2 * i + 1)) && + ((2 * i + 2) >= heap.size || heap(i) <= heap(2 * i + 2)) + } + }.check() + + @Test def testNonEmpty(): Unit = + gen("NonEmpty") { (heap: IntHeap) => + heap.nonEmpty should be(false) + for (i <- 0 to 10) heap.push(i) + heap.nonEmpty should be(true) + for (_ <- 0 to 10) heap.pop() + heap.nonEmpty should be(false) + } + + val loci: Gen[(ReferenceGenome, IndexedSeq[Locus])] = + for { + genome <- ReferenceGenome.gen + loci <- Gen.buildableOf(Locus.gen(genome)) + } yield (genome, loci) + + @Test def testLocus(): Unit = + forAll(loci) { case (rg: ReferenceGenome, loci: IndexedSeq[Locus]) => + withReferenceGenome(rg) { + + val sortedLoci = + gen("Locus", stagedLocusCoercions(rg)) { (heap: LocusHeap) => + loci.foreach(heap.push) + 
IndexedSeq.fill(loci.size)(heap.pop()) + } + + sortedLoci == loci.sorted(rg.locusOrdering) + } + }.check() + + def withReferenceGenome[A](rg: ReferenceGenome)(f: => A): A = { + ctx.backend.addReference(rg) + try f + finally ctx.backend.removeReference(rg.name) + } + + def sort(xs: IndexedSeq[Int]): IndexedSeq[Int] = + gen("Sort") { (heap: IntHeap) => + xs.foreach(heap.push) + IndexedSeq.fill(xs.size)(heap.pop()) + } + + def heapify(xs: IndexedSeq[Int]): IndexedSeq[Int] = + gen("Heapify") { (heap: IntHeap) => + pool.scopedRegion { r => + xs.foreach(heap.push) + val ptr = heap.toArray(r) + SafeIndexedSeq(PCanonicalArray(PInt32Required), ptr).asInstanceOf[IndexedSeq[Int]] + } + } + + def gen[H <: Heap[A], A, B]( + name: String, + A: StagedCoercions[A], + )( + f: H => B + )(implicit H: TypeInfo[H] + ): B = + gen[H, A, B](name)(f)(H, A) + + def gen[H <: Heap[A], A, B]( + name: String + )( + f: H => B + )(implicit + H: TypeInfo[H], + A: StagedCoercions[A], + ): B = { + val emodb = new EmitModuleBuilder(ctx, new ModuleBuilder()) + val Main = emodb.genEmitClass[H](name)(H) + + val MinHeap = StagedMinHeap(Main.emodb, A.sType) { + (cb: EmitCodeBuilder, x: SValue, y: SValue) => + cb.emb.ecb.getOrdering(A.sType, A.sType).compareNonnull(cb, x, y) + }(Main) + + Main.defineEmitMethod("push", FastSeq(A.ti), UnitInfo) { mb => + mb.voidWithBuilder { cb => + MinHeap.push(cb, A.fromType(cb, Main.partitionRegion, mb.getCodeParam[A](1)(A.ti))) + } + } + + Main.defineEmitMethod("pop", FastSeq(), A.ti) { mb => + mb.emitWithBuilder[A] { cb => + val res = A.toType(cb, MinHeap.peek(cb)) + MinHeap.pop(cb) + MinHeap.realloc(cb) + res + } + } + + Main.defineEmitMethod("nonEmpty", FastSeq(), BooleanInfo) { mb => + mb.emitWithBuilder[Boolean](MinHeap.nonEmpty) + } + + Main.defineEmitMethod("toArray", FastSeq(typeInfo[Region]), LongInfo) { mb => + mb.emitWithBuilder { cb => + val region = mb.getCodeParam[Region](1) + val arr = MinHeap.toArray(cb, region) + arr.asInstanceOf[SIndexablePointerValue].a + } + } + + trait Resource extends AutoCloseable { def init(): Unit } + + Main.cb.addInterface(implicitly[TypeInfo[Resource]].iname) + Main.defineEmitMethod("init", FastSeq(), UnitInfo) { mb => + mb.voidWithBuilder { cb => + // Properties like pool and reference genomes are set after `Main`'s + // default constructor is called, thus we need a separate method to + // initialise the heap with them. 
+ MinHeap.init(cb, Main.pool()) + } + } + Main.defineEmitMethod("close", FastSeq(), UnitInfo)(mb => mb.voidWithBuilder(MinHeap.close)) + + pool.scopedRegion { r => + val heap = Main + .resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r) + .asInstanceOf[H with Resource] + + heap.init() + using(heap)(f) + } + } + + trait LocusHeap extends Heap[Locus] { + def push(locus: Locus): Unit + def pop(): Locus + } + + trait IntHeap extends Heap[Int] { + def push(x: Int): Unit + def pop(): Int + } + + trait Heap[A] { + def nonEmpty: Boolean + def toArray(r: Region): Long + } +} diff --git a/hail/src/test/scala/is/hail/expr/ir/StringFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/StringFunctionsSuite.scala index c36512b2284..d60f3ab2761 100644 --- a/hail/src/test/scala/is/hail/expr/ir/StringFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/StringFunctionsSuite.scala @@ -1,17 +1,18 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.TestUtils._ import is.hail.expr.ir.TestUtils._ import is.hail.types.virtual._ import is.hail.utils.FastSeq -import is.hail.{ExecStrategy, HailSuite} + import org.json4s.jackson.JsonMethods import org.testng.annotations.{DataProvider, Test} class StringFunctionsSuite extends HailSuite { implicit val execStrats = ExecStrategy.javaOnly - @Test def testRegexMatch() { + @Test def testRegexMatch(): Unit = { assertEvalsTo(invoke("regexMatch", TBoolean, Str("a"), NA(TString)), null) assertEvalsTo(invoke("regexMatch", TBoolean, NA(TString), Str("b")), null) @@ -22,27 +23,27 @@ class StringFunctionsSuite extends HailSuite { assertEvalsTo(invoke("regexMatch", TBoolean, Str("[a-z][0-9]"), Str("3x")), false) } - @Test def testLength() { + @Test def testLength(): Unit = { assertEvalsTo(invoke("length", TInt32, Str("ab")), 2) assertEvalsTo(invoke("length", TInt32, Str("")), 0) assertEvalsTo(invoke("length", TInt32, NA(TString)), null) } - @Test def testSubstring() { + @Test def testSubstring(): Unit = { assertEvalsTo(invoke("substring", TString, Str("ab"), 0, 1), "a") assertEvalsTo(invoke("substring", TString, Str("ab"), NA(TInt32), 1), null) assertEvalsTo(invoke("substring", TString, Str("ab"), 0, NA(TInt32)), null) assertEvalsTo(invoke("substring", TString, NA(TString), 0, 1), null) } - @Test def testConcat() { + @Test def testConcat(): Unit = { assertEvalsTo(invoke("concat", TString, Str("a"), NA(TString)), null) assertEvalsTo(invoke("concat", TString, NA(TString), Str("b")), null) assertEvalsTo(invoke("concat", TString, Str("a"), Str("b")), "ab") } - @Test def testSplit() { + @Test def testSplit(): Unit = { assertEvalsTo(invoke("split", TArray(TString), NA(TString), Str(",")), null) assertEvalsTo(invoke("split", TArray(TString), Str("a,b,c"), NA(TString)), null) @@ -53,10 +54,13 @@ class StringFunctionsSuite extends HailSuite { assertEvalsTo(invoke("split", TArray(TString), Str("a,b,c"), NA(TString), I32(2)), null) assertEvalsTo(invoke("split", TArray(TString), Str("a,b,c"), Str(","), NA(TInt32)), null) - assertEvalsTo(invoke("split", TArray(TString), Str("a,b,c"), Str(","), I32(2)), FastSeq("a", "b,c")) + assertEvalsTo( + invoke("split", TArray(TString), Str("a,b,c"), Str(","), I32(2)), + FastSeq("a", "b,c"), + ) } - @Test def testReplace() { + @Test def testReplace(): Unit = { assertEvalsTo(invoke("replace", TString, NA(TString), Str(","), Str(".")), null) assertEvalsTo(invoke("replace", TString, Str("a,b,c"), NA(TString), Str(".")), null) assertEvalsTo(invoke("replace", TString, Str("a,b,c"), Str(","), 
NA(TString)), null) @@ -64,7 +68,7 @@ class StringFunctionsSuite extends HailSuite { assertEvalsTo(invoke("replace", TString, Str("a,b,c"), Str(","), Str(".")), "a.b.c") } - @Test def testArrayMkString() { + @Test def testArrayMkString(): Unit = { assertEvalsTo(invoke("mkString", TString, IRStringArray("a", "b", "c"), NA(TString)), null) assertEvalsTo(invoke("mkString", TString, NA(TArray(TString)), Str(",")), null) assertEvalsTo(invoke("mkString", TString, IRStringArray("a", "b", "c"), Str(",")), "a,b,c") @@ -73,7 +77,7 @@ class StringFunctionsSuite extends HailSuite { assertEvalsTo(invoke("mkString", TString, IRStringArray("a", null, "c"), Str(",")), "a,null,c") } - @Test def testSetMkString() { + @Test def testSetMkString(): Unit = { assertEvalsTo(invoke("mkString", TString, IRStringSet("a", "b", "c"), NA(TString)), null) assertEvalsTo(invoke("mkString", TString, NA(TSet(TString)), Str(",")), null) assertEvalsTo(invoke("mkString", TString, IRStringSet("a", "b", "c"), Str(",")), "a,b,c") @@ -82,13 +86,19 @@ class StringFunctionsSuite extends HailSuite { assertEvalsTo(invoke("mkString", TString, IRStringSet("a", null, "c"), Str(",")), "a,c,null") } - @Test def testFirstMatchIn() { + @Test def testFirstMatchIn(): Unit = { assertEvalsTo(invoke("firstMatchIn", TArray(TString), Str("""([a-zA-Z]+)"""), Str("1")), null) - assertEvalsTo(invoke("firstMatchIn", TArray(TString), Str("Hello world!"), Str("""([a-zA-Z]+)""")), FastSeq("Hello")) - assertEvalsTo(invoke("firstMatchIn", TArray(TString), Str("Hello world!"), Str("""[a-zA-Z]+""")), FastSeq()) + assertEvalsTo( + invoke("firstMatchIn", TArray(TString), Str("Hello world!"), Str("""([a-zA-Z]+)""")), + FastSeq("Hello"), + ) + assertEvalsTo( + invoke("firstMatchIn", TArray(TString), Str("Hello world!"), Str("""[a-zA-Z]+""")), + FastSeq(), + ) } - @Test def testHammingDistance() { + @Test def testHammingDistance(): Unit = { assertEvalsTo(invoke("hamming", TInt32, Str("foo"), NA(TString)), null) assertEvalsTo(invoke("hamming", TInt32, Str("foo"), Str("fool")), null) assertEvalsTo(invoke("hamming", TInt32, Str("foo"), Str("fol")), 1) @@ -101,20 +111,23 @@ class StringFunctionsSuite extends HailSuite { Array(F32(3.14f), TFloat32), Array(I64(7), TInt64), Array(IRArray(1, null, 5), TArray(TInt32)), - Array(MakeTuple.ordered(FastSeq(1, NA(TInt32), 5.7)), TTuple(TInt32, TInt32, TFloat64)) + Array(MakeTuple.ordered(FastSeq(1, NA(TInt32), 5.7)), TTuple(TInt32, TInt32, TFloat64)), ) @Test(dataProvider = "str") - def str(annotation: IR, typ: Type) { - assertEvalsTo(invoke("str", TString, annotation), { - val a = eval(annotation); if (a == null) null else typ.str(a) - }) - } + def str(annotation: IR, typ: Type): Unit = + assertEvalsTo( + invoke("str", TString, annotation), { + val a = eval(annotation); if (a == null) null else typ.str(a) + }, + ) @Test(dataProvider = "str") - def json(annotation: IR, typ: Type) { - assertEvalsTo(invoke("json", TString, annotation), JsonMethods.compact(typ.toJSON(eval(annotation)))) - } + def json(annotation: IR, typ: Type): Unit = + assertEvalsTo( + invoke("json", TString, annotation), + JsonMethods.compact(typ.toJSON(eval(annotation))), + ) @DataProvider(name = "time") def timeData(): Array[Array[Any]] = Array( @@ -125,39 +138,53 @@ class StringFunctionsSuite extends HailSuite { // % A a B b C c D d e F G g H I j k l M m n p R r S s T t U u V v W w X x Y y Z z // ■ ■ ■ ■ ■ ⊗ ⊗ ■ ■ ■ ■ ⊗ ⊗ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ■ ⊗ ⊗ ⊗ ■ ■ ⊗ ■ - Array("%t%%%n%s", "\t%\n123456789", 123456789), - Array("%m/%d/%y %I:%M:%S %p", 
"10/10/97 11:45:23 PM", 876541523), - Array("%m/%d/%y %I:%M:%S %p", "07/08/19 03:00:01 AM", 1562569201), - Array("%Y.%m.%d %H:%M:%S %z", "1997.10.10 23:45:23 -04:00", 876541523), - Array("%Y.%m.%d %H:%M:%S %Z", "2019.07.08 03:00:01 America/New_York", 1562569201), - Array("day %j of %Y. %R:%S", "day 283 of 1997. 23:45:23", 876541523), - Array("day %j of %Y. %R:%S", "day 189 of 2019. 03:00:01", 1562569201), - Array("day %j of %Y. %R:%S", "day 001 of 1970. 22:46:40", 100000), - Array("%v %T", "10-Oct-1997 23:45:23", 876541523), - Array("%v %T", " 8-Jul-2019 03:00:01", 1562569201), - Array("%A, %B %e, %Y. %r", "Friday, October 10, 1997. 11:45:23 PM", 876541523), - Array("%A, %B %e, %Y. %r", "Monday, July 8, 2019. 03:00:01 AM", 1562569201), - Array("%a, %b %e, '%y. %I:%M:%S %p", "Fri, Oct 10, '97. 11:45:23 PM", 876541523), - Array("%a, %b %e, '%y. %I:%M:%S %p", "Mon, Jul 8, '19. 03:00:01 AM", 1562569201), - Array("%D %l:%M:%S %p", "10/10/97 11:45:23 PM", 876541523), - Array("%D %l:%M:%S %p", "07/08/19 3:00:01 AM", 1562569201), - Array("%F %k:%M:%S", "1997-10-10 23:45:23", 876541523), - Array("%F %k:%M:%S", "2019-07-08 3:00:01", 1562569201), - Array("ISO 8601 week day %u. %Y.%m.%d %H:%M:%S", "ISO 8601 week day 4. 1970.01.01 22:46:40", 100000), - Array("Week number %U of %Y. %Y.%m.%d %H:%M:%S", "Week number 00 of 1973. 1973.01.01 10:33:20", 94750400), - Array("ISO 8601 week #%V. %Y.%m.%d %H:%M:%S", "ISO 8601 week #53. 2005.01.02 00:00:00", 1104642000), - Array("ISO 8601 week #%V. %Y.%m.%d %H:%M:%S", "ISO 8601 week #01. 2005.01.03 00:00:00", 1104728400), - Array("Monday week #%W. %Y.%m.%d %H:%M:%S", "Monday week #00. 2005.01.02 00:00:00", 1104642000), - Array("Monday week #%W. %Y.%m.%d %H:%M:%S", "Monday week #01. 2005.01.03 00:00:00", 1104728400) + Array("%t%%%n%s", "\t%\n123456789", 123456789), + Array("%m/%d/%y %I:%M:%S %p", "10/10/97 11:45:23 PM", 876541523), + Array("%m/%d/%y %I:%M:%S %p", "07/08/19 03:00:01 AM", 1562569201), + Array("%Y.%m.%d %H:%M:%S %z", "1997.10.10 23:45:23 -04:00", 876541523), + Array("%Y.%m.%d %H:%M:%S %Z", "2019.07.08 03:00:01 America/New_York", 1562569201), + Array("day %j of %Y. %R:%S", "day 283 of 1997. 23:45:23", 876541523), + Array("day %j of %Y. %R:%S", "day 189 of 2019. 03:00:01", 1562569201), + Array("day %j of %Y. %R:%S", "day 001 of 1970. 22:46:40", 100000), + Array("%v %T", "10-Oct-1997 23:45:23", 876541523), + Array("%v %T", " 8-Jul-2019 03:00:01", 1562569201), + Array("%A, %B %e, %Y. %r", "Friday, October 10, 1997. 11:45:23 PM", 876541523), + Array("%A, %B %e, %Y. %r", "Monday, July 8, 2019. 03:00:01 AM", 1562569201), + Array("%a, %b %e, '%y. %I:%M:%S %p", "Fri, Oct 10, '97. 11:45:23 PM", 876541523), + Array("%a, %b %e, '%y. %I:%M:%S %p", "Mon, Jul 8, '19. 03:00:01 AM", 1562569201), + Array("%D %l:%M:%S %p", "10/10/97 11:45:23 PM", 876541523), + Array("%D %l:%M:%S %p", "07/08/19 3:00:01 AM", 1562569201), + Array("%F %k:%M:%S", "1997-10-10 23:45:23", 876541523), + Array("%F %k:%M:%S", "2019-07-08 3:00:01", 1562569201), + Array( + "ISO 8601 week day %u. %Y.%m.%d %H:%M:%S", + "ISO 8601 week day 4. 1970.01.01 22:46:40", + 100000, + ), + Array( + "Week number %U of %Y. %Y.%m.%d %H:%M:%S", + "Week number 00 of 1973. 1973.01.01 10:33:20", + 94750400, + ), + Array( + "ISO 8601 week #%V. %Y.%m.%d %H:%M:%S", + "ISO 8601 week #53. 2005.01.02 00:00:00", + 1104642000, + ), + Array( + "ISO 8601 week #%V. %Y.%m.%d %H:%M:%S", + "ISO 8601 week #01. 2005.01.03 00:00:00", + 1104728400, + ), + Array("Monday week #%W. %Y.%m.%d %H:%M:%S", "Monday week #00. 
2005.01.02 00:00:00", 1104642000), + Array("Monday week #%W. %Y.%m.%d %H:%M:%S", "Monday week #01. 2005.01.03 00:00:00", 1104728400), ) @Test(dataProvider = "time") - def strftime(fmt: String, s: String, t: Long) { + def strftime(fmt: String, s: String, t: Long): Unit = assertEvalsTo(invoke("strftime", TString, Str(fmt), I64(t), Str("America/New_York")), s) - } @Test(dataProvider = "time") - def strptime(fmt: String, s: String, t: Long) { + def strptime(fmt: String, s: String, t: Long): Unit = assertEvalsTo(invoke("strptime", TInt64, Str(s), Str(fmt), Str("America/New_York")), t) - } } diff --git a/hail/src/test/scala/is/hail/expr/ir/StringLengthSuite.scala b/hail/src/test/scala/is/hail/expr/ir/StringLengthSuite.scala index 4f8aeaceed0..1a01864492d 100644 --- a/hail/src/test/scala/is/hail/expr/ir/StringLengthSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/StringLengthSuite.scala @@ -1,17 +1,16 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} -import is.hail.TestUtils._ import is.hail.types.virtual.TInt32 + import org.testng.annotations.Test class StringLengthSuite extends HailSuite { implicit val execStrats = ExecStrategy.javaOnly - @Test def sameAsJavaStringLength() { + @Test def sameAsJavaStringLength(): Unit = { val strings = Array("abc", "", "\uD83D\uDCA9") - for (s <- strings) { + for (s <- strings) assertEvalsTo(invoke("length", TInt32, Str(s)), s.length) - } } } diff --git a/hail/src/test/scala/is/hail/expr/ir/StringSliceSuite.scala b/hail/src/test/scala/is/hail/expr/ir/StringSliceSuite.scala index 505dc3c8316..ac869b23455 100644 --- a/hail/src/test/scala/is/hail/expr/ir/StringSliceSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/StringSliceSuite.scala @@ -1,17 +1,18 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.TestUtils._ import is.hail.types.virtual.TString import is.hail.utils._ -import is.hail.{ExecStrategy, HailSuite} + import org.testng.annotations.Test class StringSliceSuite extends HailSuite { implicit val execStrats = ExecStrategy.javaOnly - @Test def unicodeSlicingSlicesCodePoints() { + @Test def unicodeSlicingSlicesCodePoints(): Unit = { val poopEmoji = "\uD83D\uDCA9" - val s = s"abc${ poopEmoji }def" + val s = s"abc${poopEmoji}def" // FIXME: The replacement character for slicing halfway into a // 2-codepoint-wide character differs between UTF8 and UTF16. @@ -20,67 +21,67 @@ class StringSliceSuite extends HailSuite { val replacementCharacter = "?" 
assertEvalsTo(invoke("slice", TString, Str(s), I32(0), I32(4)), s"abc$replacementCharacter") - assertEvalsTo(invoke("slice", TString, Str(s), I32(4), I32(8)), s"${ replacementCharacter }def") + assertEvalsTo(invoke("slice", TString, Str(s), I32(4), I32(8)), s"${replacementCharacter}def") assertEvalsTo(invoke("slice", TString, Str(s), I32(0), I32(5)), s"abc$poopEmoji") } - @Test def zeroToLengthIsIdentity() { + @Test def zeroToLengthIsIdentity(): Unit = assertEvalsTo(invoke("slice", TString, Str("abc"), I32(0), I32(3)), "abc") - } - @Test def simpleSlicesMatchIntuition() { + @Test def simpleSlicesMatchIntuition(): Unit = { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(3), I32(3)), "") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(1), I32(3)), "bc") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(2), I32(3)), "c") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(0), I32(2)), "ab") } - @Test def sizeZeroSliceIsEmptyString() { + @Test def sizeZeroSliceIsEmptyString(): Unit = { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(2), I32(2)), "") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(1), I32(1)), "") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(0), I32(0)), "") } - @Test def substringMatchesJavaStringSubstring() { + @Test def substringMatchesJavaStringSubstring(): Unit = { assertEvalsTo( invoke("substring", TString, Str("abc"), I32(0), I32(2)), - "abc".substring(0, 2)) + "abc".substring(0, 2), + ) assertEvalsTo( invoke("substring", TString, Str("foobarbaz"), I32(3), I32(5)), - "foobarbaz".substring(3, 5)) + "foobarbaz".substring(3, 5), + ) } - @Test def isStrict() { + @Test def isStrict(): Unit = { assertEvalsTo(invoke("slice", TString, NA(TString), I32(0), I32(2)), null) assertEvalsTo(invoke("slice", TString, NA(TString), I32(-5), I32(-10)), null) } - @Test def leftSliceMatchesIntuition() { + @Test def leftSliceMatchesIntuition(): Unit = { assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(2)), "c") assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(1)), "bc") } - @Test def rightSliceMatchesIntuition() { + @Test def rightSliceMatchesIntuition(): Unit = { assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(2)), "ab") assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(1)), "a") } - @Test def bothSideSliceMatchesIntuition() { + @Test def bothSideSliceMatchesIntuition(): Unit = assertEvalsTo(invoke("slice", TString, Str("abc"), I32(0), I32(2)), "ab") - // assertEvalsTo(invoke("slice", TString, Str("abc"), I32(1), I32(3)), "bc") - } + // assertEvalsTo(invoke("slice", TString, Str("abc"), I32(1), I32(3)), "bc") - @Test def leftSliceIsPythony() { + @Test def leftSliceIsPythony(): Unit = { assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(-1)), "c") assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(-2)), "bc") } - @Test def rightSliceIsPythony() { + @Test def rightSliceIsPythony(): Unit = { assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(-1)), "ab") assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(-2)), "a") } - @Test def sliceIsPythony() { + @Test def sliceIsPythony(): Unit = { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(-3), I32(-1)), "ab") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(-3), I32(-2)), "a") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(-2), I32(-1)), "b") @@ -89,7 +90,7 @@ class StringSliceSuite extends HailSuite { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(1), I32(-1)), "b") } - @Test def 
bothSidesSliceFunctionOutOfBoundsNotFatal() { + @Test def bothSidesSliceFunctionOutOfBoundsNotFatal(): Unit = { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(4), I32(4)), "") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(3), I32(2)), "") assertEvalsTo(invoke("slice", TString, Str("abc"), I32(-1), I32(2)), "") @@ -100,7 +101,7 @@ class StringSliceSuite extends HailSuite { assertEvalsTo(invoke("slice", TString, Str("abc"), I32(-10), I32(-1)), "ab") } - @Test def leftSliceFunctionOutOfBoundsNotFatal() { + @Test def leftSliceFunctionOutOfBoundsNotFatal(): Unit = { assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(15)), "") assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(4)), "") assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(3)), "") @@ -109,7 +110,7 @@ class StringSliceSuite extends HailSuite { assertEvalsTo(invoke("sliceRight", TString, Str("abc"), I32(-100)), "abc") } - @Test def rightSliceFunctionOutOfBoundsNotFatal() { + @Test def rightSliceFunctionOutOfBoundsNotFatal(): Unit = { assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(15)), "abc") assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(4)), "abc") assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(3)), "abc") @@ -118,7 +119,7 @@ class StringSliceSuite extends HailSuite { assertEvalsTo(invoke("sliceLeft", TString, Str("abc"), I32(-100)), "") } - @Test def testStringIndex() { + @Test def testStringIndex(): Unit = { assertEvalsTo(invoke("index", TString, In(0, TString), I32(0)), FastSeq("Baz" -> TString), "B") assertEvalsTo(invoke("index", TString, In(0, TString), I32(1)), FastSeq("Baz" -> TString), "a") assertEvalsTo(invoke("index", TString, In(0, TString), I32(2)), FastSeq("Baz" -> TString), "z") @@ -127,10 +128,18 @@ class StringSliceSuite extends HailSuite { assertEvalsTo(invoke("index", TString, In(0, TString), I32(-3)), FastSeq("Baz" -> TString), "B") interceptFatal("string index out of bounds") { - assertEvalsTo(invoke("index", TString, In(0, TString), I32(3)), FastSeq("Baz" -> TString), "B") + assertEvalsTo( + invoke("index", TString, In(0, TString), I32(3)), + FastSeq("Baz" -> TString), + "B", + ) } interceptFatal("string index out of bounds") { - assertEvalsTo(invoke("index", TString, In(0, TString), I32(-4)), FastSeq("Baz" -> TString), "B") + assertEvalsTo( + invoke("index", TString, In(0, TString), I32(-4)), + FastSeq("Baz" -> TString), + "B", + ) } } } diff --git a/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala b/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala index 738aad97bf2..9a187a40590 100644 --- a/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/TableIRSuite.scala @@ -1,27 +1,29 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy import is.hail.TestUtils._ import is.hail.annotations.SafeNDArray import is.hail.expr.Nat import is.hail.expr.ir.TestUtils._ import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR} -import is.hail.methods.ForceCountTable +import is.hail.methods.{ForceCountTable, NPartitionsTable} import is.hail.rvd.RVDPartitioner import is.hail.types._ import is.hail.types.virtual._ import is.hail.utils._ -import is.hail.{ExecStrategy, HailSuite} +import is.hail.variant.Locus import org.apache.spark.sql.Row -import org.scalatest.Inspectors.forAll import org.scalatest.{Failed, Succeeded} +import org.scalatest.Inspectors.forAll import org.testng.annotations.{DataProvider, Test} class TableIRSuite extends 
HailSuite { - implicit val execStrats: Set[ExecStrategy] = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized, ExecStrategy.LoweredJVMCompile) + implicit val execStrats: Set[ExecStrategy] = + Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized, ExecStrategy.LoweredJVMCompile) - @Test def testRangeCount() { + @Test def testRangeCount(): Unit = { val node1 = TableCount(TableRange(10, 2)) val node2 = TableCount(TableRange(15, 5)) val node = ApplyBinaryPrimOp(Add(), node1, node2) @@ -37,9 +39,12 @@ class TableIRSuite extends HailSuite { assertEvalsTo(forceCountRange, tableRangeSize.toLong) } - @Test def testRangeRead() { + @Test def testRangeRead(): Unit = { implicit val execStrats = ExecStrategy.lowering - val original = TableKeyBy(TableMapGlobals(TableRange(10, 3), MakeStruct(FastSeq("foo" -> I32(57)))), FastSeq()) + val original = TableKeyBy( + TableMapGlobals(TableRange(10, 3), MakeStruct(FastSeq("foo" -> I32(57)))), + FastSeq(), + ) val path = ctx.createTmpPath("test-range-read", "ht") val write = TableWrite(original, TableNativeWriter(path, overwrite = true)) @@ -51,7 +56,7 @@ class TableIRSuite extends HailSuite { (partSize, partIndex) <- partition(10, 3).zipWithIndex i <- 0 until partSize } yield Row(partIndex.toLong, i.toLong) - val expectedRows = (0 until 10, uids).zipped.map { (i, uid) => Row(i, uid) } + val expectedRows = (0 until 10, uids).zipped.map((i, uid) => Row(i, uid)) val expectedGlobals = Row(57) assertEvalsTo(TableCollect(read), Row(expectedRows, expectedGlobals)) @@ -64,7 +69,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(TableCount(tir), 120L) } - @Test def testRangeCollect() { + @Test def testRangeCollect(): Unit = { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) val t = TableRange(10, 2) val row = Ref("row", t.typ.rowType) @@ -73,30 +78,47 @@ class TableIRSuite extends HailSuite { assertEvalsTo(node, Row(Array.tabulate(10)(i => Row(i, i)).toFastSeq, Row())) } - @Test def testNestedRangeCollect() { + @Test def testNestedRangeCollect(): Unit = { implicit val execStrats = ExecStrategy.allRelational val r = TableRange(2, 2) val tc = GetField(collect(TableRange(2, 2)), "rows") - val m = TableMapRows(TableRange(2, 2), InsertFields(Ref("row", r.typ.rowType), FastSeq("collected" -> tc))) - assertEvalsTo(collect(m), - Row(FastSeq( - Row(0, FastSeq(Row(0), Row(1))), - Row(1, FastSeq(Row(0), Row(1))) - ), Row())) + val m = TableMapRows( + TableRange(2, 2), + InsertFields(Ref("row", r.typ.rowType), FastSeq("collected" -> tc)), + ) + assertEvalsTo( + collect(m), + Row( + FastSeq( + Row(0, FastSeq(Row(0), Row(1))), + Row(1, FastSeq(Row(0), Row(1))), + ), + Row(), + ), + ) } - @Test def testRangeSum() { + @Test def testRangeSum(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly val t = TableRange(10, 2) val row = Ref("row", t.typ.rowType) val sum = AggSignature(Sum(), FastSeq(), FastSeq(TInt64)) - val node = collect(TableMapRows(t, InsertFields(row, FastSeq("sum" -> ApplyScanOp(FastSeq(), FastSeq(Cast(GetField(row, "idx"), TInt64)), sum))))) - assertEvalsTo(node, Row(Array.tabulate(10)(i => Row(i, Array.range(0, i).sum.toLong)).toFastSeq, Row())) + val node = collect(TableMapRows( + t, + InsertFields( + row, + FastSeq("sum" -> ApplyScanOp(FastSeq(), FastSeq(Cast(GetField(row, "idx"), TInt64)), sum)), + ), + )) + assertEvalsTo( + node, + Row(Array.tabulate(10)(i => Row(i, Array.range(0, i).sum.toLong)).toFastSeq, Row()), + ) } - @Test def testGetGlobals() { + @Test def testGetGlobals(): Unit = { 
implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) val t = TableRange(10, 2) val newGlobals = InsertFields(Ref("global", t.typ.globalType), FastSeq("x" -> collect(t))) @@ -104,13 +126,17 @@ class TableIRSuite extends HailSuite { assertEvalsTo(node, Row(Row(Array.tabulate(10)(i => Row(i)).toFastSeq, Row()))) } - @Test def testCollectGlobals() { + @Test def testCollectGlobals(): Unit = { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) val t = TableRange(10, 2) val newGlobals = InsertFields(Ref("global", t.typ.globalType), FastSeq("x" -> collect(t))) val node = TableMapRows( TableMapGlobals(t, newGlobals), - InsertFields(Ref("row", t.typ.rowType), FastSeq("x" -> GetField(Ref("global", newGlobals.typ), "x")))) + InsertFields( + Ref("row", t.typ.rowType), + FastSeq("x" -> GetField(Ref("global", newGlobals.typ), "x")), + ), + ) val collectedT = Row(Array.tabulate(10)(i => Row(i)).toFastSeq, Row()) val expected = Array.tabulate(10)(i => Row(i, collectedT)).toFastSeq @@ -118,96 +144,140 @@ class TableIRSuite extends HailSuite { assertEvalsTo(collect(node), Row(expected, Row(collectedT))) } - @Test def testRangeExplode() { + @Test def testRangeExplode(): Unit = { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) val t = TableRange(10, 2) val row = Ref("row", t.typ.rowType) - val t2 = TableMapRows(t, InsertFields(row, FastSeq("x" -> ToArray(StreamRange(0, GetField(row, "idx"), 1))))) + val t2 = TableMapRows( + t, + InsertFields(row, FastSeq("x" -> ToArray(StreamRange(0, GetField(row, "idx"), 1)))), + ) val node = TableExplode(t2, FastSeq("x")) val expected = Array.range(0, 10).flatMap(i => Array.range(0, i).map(Row(i, _))).toFastSeq assertEvalsTo(collect(node), Row(expected, Row())) - val t3 = TableMapRows(t, InsertFields(row, - FastSeq("x" -> - MakeStruct(FastSeq("y" -> ToArray(StreamRange(0, GetField(row, "idx"), 1))))))) + val t3 = TableMapRows( + t, + InsertFields( + row, + FastSeq("x" -> + MakeStruct(FastSeq("y" -> ToArray(StreamRange(0, GetField(row, "idx"), 1))))), + ), + ) val node2 = TableExplode(t3, FastSeq("x", "y")) - val expected2 = Array.range(0, 10).flatMap(i => Array.range(0, i).map(j => Row(i, Row(j)))).toFastSeq + val expected2 = + Array.range(0, 10).flatMap(i => Array.range(0, i).map(j => Row(i, Row(j)))).toFastSeq assertEvalsTo(collect(node2), Row(expected2, Row())) } - @Test def testFilter() { + @Test def testFilter(): Unit = { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) val t = TableRange(10, 2) val node = TableFilter( - TableMapGlobals(t, MakeStruct(FastSeq("x" -> GetField(ArrayRef(GetField(collect(t), "rows"), 4), "idx")))), - ApplyComparisonOp(EQ(TInt32), GetField(Ref("row", t.typ.rowType), "idx"), GetField(Ref("global", TStruct("x" -> TInt32)), "x"))) + TableMapGlobals( + t, + MakeStruct(FastSeq("x" -> GetField(ArrayRef(GetField(collect(t), "rows"), 4), "idx"))), + ), + ApplyComparisonOp( + EQ(TInt32), + GetField(Ref("row", t.typ.rowType), "idx"), + GetField(Ref("global", TStruct("x" -> TInt32)), "x"), + ), + ) val expected = Array.tabulate(10)(Row(_)).filter(_.get(0) == 4).toFastSeq assertEvalsTo(collect(node), Row(expected, Row(4))) } - @Test def testFilterIntervals() { + @Test def testFilterIntervals(): Unit = { implicit val execStrats = ExecStrategy.allRelational - - def assertFilterIntervals(intervals: IndexedSeq[Interval], keep: Boolean, expected: IndexedSeq[Int]): Unit = { + def assertFilterIntervals( + 
intervals: IndexedSeq[Interval], + keep: Boolean, + expected: IndexedSeq[Int], + ): Unit = { var t: TableIR = TableRange(10, 5) - t = TableFilterIntervals(t, intervals.map(i => Interval(Row(i.start), Row(i.end), i.includesStart, i.includesEnd)), keep) + t = TableFilterIntervals( + t, + intervals.map(i => Interval(Row(i.start), Row(i.end), i.includesStart, i.includesEnd)), + keep, + ) assertEvalsTo(GetField(collect(t), "rows"), expected.map(Row(_))) } assertFilterIntervals( FastSeq(Interval(0, 5, true, false)), true, - FastSeq(0, 1, 2, 3, 4)) + FastSeq(0, 1, 2, 3, 4), + ) assertFilterIntervals( FastSeq(Interval(0, 5, true, false)), false, - FastSeq(5, 6, 7, 8, 9)) + FastSeq(5, 6, 7, 8, 9), + ) assertFilterIntervals( FastSeq(), true, - FastSeq()) + FastSeq(), + ) assertFilterIntervals( FastSeq(), false, - FastSeq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)) + FastSeq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + ) assertFilterIntervals( FastSeq(), true, - FastSeq()) + FastSeq(), + ) assertFilterIntervals( - FastSeq(Interval(0, 5, true, false), Interval(1, 6, false, true), Interval(8, 9, true, false)), + FastSeq( + Interval(0, 5, true, false), + Interval(1, 6, false, true), + Interval(8, 9, true, false), + ), false, - FastSeq(7, 9)) + FastSeq(7, 9), + ) assertFilterIntervals( - FastSeq(Interval(0, 5, true, false), Interval(1, 6, false, true), Interval(8, 9, true, false)), + FastSeq( + Interval(0, 5, true, false), + Interval(1, 6, false, true), + Interval(8, 9, true, false), + ), true, - FastSeq(0, 1, 2, 3, 4, 5, 6, 8)) + FastSeq(0, 1, 2, 3, 4, 5, 6, 8), + ) } - @Test def testTableMapWithLiterals() { + @Test def testTableMapWithLiterals(): Unit = { implicit val execStrats = Set(ExecStrategy.Interpret, ExecStrategy.InterpretUnoptimized) val t = TableRange(10, 2) - val node = TableMapRows(t, - InsertFields(Ref("row", t.typ.rowType), + val node = TableMapRows( + t, + InsertFields( + Ref("row", t.typ.rowType), FastSeq( "a" -> Str("foo"), - "b" -> Literal(TTuple(TInt32, TString), Row(1, "hello"))))) + "b" -> Literal(TTuple(TInt32, TString), Row(1, "hello")), + ), + ), + ) val expected = Array.tabulate(10)(Row(_, "foo", Row(1, "hello"))).toFastSeq assertEvalsTo(collect(node), Row(expected, Row())) } - @Test def testScanCountBehavesLikeIndex() { + @Test def testScanCountBehavesLikeIndex(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly val t = rangeKT val oldRow = Ref("row", t.typ.rowType) @@ -215,10 +285,13 @@ class TableIRSuite extends HailSuite { val newRow = InsertFields(oldRow, Seq("idx2" -> IRScanCount)) val newTable = TableMapRows(t, newRow) val expected = Array.tabulate(20)(i => Row(i, i.toLong)).toFastSeq - assertEvalsTo(ArraySort(ToStream(TableAggregate(newTable, IRAggCollect(Ref("row", newRow.typ)))), True()), expected) + assertEvalsTo( + ArraySort(ToStream(TableAggregate(newTable, IRAggCollect(Ref("row", newRow.typ)))), True()), + expected, + ) } - @Test def testScanCollectBehavesLikeRange() { + @Test def testScanCollectBehavesLikeRange(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly val t = rangeKT val oldRow = Ref("row", t.typ.rowType) @@ -227,11 +300,17 @@ class TableIRSuite extends HailSuite { val newTable = TableMapRows(t, newRow) val expected = Array.tabulate(20)(i => Row(i, Array.range(0, i).toFastSeq)).toFastSeq - assertEvalsTo(ArraySort(ToStream(TableAggregate(newTable, IRAggCollect(Ref("row", newRow.typ)))), True()), expected) + assertEvalsTo( + ArraySort(ToStream(TableAggregate(newTable, IRAggCollect(Ref("row", newRow.typ)))), True()), + expected, + ) } val rowType = 
TStruct(("A", TInt32), ("B", TInt32), ("C", TInt32)) - val joinedType = TStruct(("A", TInt32), ("B", TInt32), ("C", TInt32), ("B_1", TInt32), ("C_1", TInt32)) + + val joinedType = + TStruct(("A", TInt32), ("B", TInt32), ("C", TInt32), ("B_1", TInt32), ("C_1", TInt32)) + val kType = TStruct(("A", TInt32), ("B", TInt32)) val leftData = FastSeq( @@ -258,7 +337,7 @@ class TableIRSuite extends HailSuite { (36, 1, -1), (36, 2, -1), (37, 1, -1), - (37, 2, -1) + (37, 2, -1), ).map(Row.fromTuple) val rightData = FastSeq( @@ -285,7 +364,7 @@ class TableIRSuite extends HailSuite { (38, 1, 1), (38, 2, 1), (41, 1, 1), - (41, 2, 1) + (41, 2, 1), ).map(Row.fromTuple) val expectedUnion = Array( @@ -336,7 +415,7 @@ class TableIRSuite extends HailSuite { (38, 1, 1), (38, 2, 1), (41, 1, 1), - (41, 2, 1) + (41, 2, 1), ).map(Row.fromTuple) val expectedZipJoin = Array( @@ -377,8 +456,8 @@ class TableIRSuite extends HailSuite { (38, 1, FastSeq(null, Row(1))), (38, 2, FastSeq(null, Row(1))), (41, 1, FastSeq(null, Row(1))), - (41, 2, FastSeq(null, Row(1))) - ).map(Row.fromTuple) + (41, 2, FastSeq(null, Row(1))), + ).map(Row.fromTuple) val expectedOuterJoin = Array( (3, 1, -1, null, null), @@ -428,14 +507,14 @@ class TableIRSuite extends HailSuite { (38, null, null, 1, 1), (38, null, null, 2, 1), (41, null, null, 1, 1), - (41, null, null, 2, 1) + (41, null, null, 2, 1), ).map(Row.fromTuple) val joinTypes = Array( ("outer", (row: Row) => true), ("left", (row: Row) => !row.isNullAt(1)), ("right", (row: Row) => !row.isNullAt(3)), - ("inner", (row: Row) => !row.isNullAt(1) && !row.isNullAt(3)) + ("inner", (row: Row) => !row.isNullAt(1) && !row.isNullAt(3)), ) @DataProvider(name = "join") @@ -448,19 +527,17 @@ class TableIRSuite extends HailSuite { val ab = new BoxedArrayBuilder[Array[Any]]() for ((j, p) <- joinTypes) { for { - lParts <- Array[Integer](1, 2, 3); rParts <- Array[Integer](1, 2, 3) - } { - - ab += Array[Any](lParts, rParts, j, p, defaultLeftProject, defaultRightProject) + lParts <- Array[Integer](1, 2, 3) + rParts <- Array[Integer](1, 2, 3) } + ab += Array[Any](lParts, rParts, j, p, defaultLeftProject, defaultRightProject) for { leftProject <- Seq[Set[Int]](Set(), Set(1), Set(2), Set(1, 2)) - rightProject <- Seq[Set[Int]](Set(), Set(1), Set(2), Set(1, 2)) - if !leftProject.contains(1) || rightProject.contains(1) - } { - ab += Array[Any](defaultLParts, defaultRParts, j, p, leftProject, rightProject) + rightProject <- Seq[Set[Int]](Set(), Set(1), Set(2), Set(1, 2)) + if !leftProject.contains(1) || rightProject.contains(1) } + ab += Array[Any](defaultLParts, defaultRParts, j, p, leftProject, rightProject) } ab.result() } @@ -472,27 +549,34 @@ class TableIRSuite extends HailSuite { joinType: String, pred: Row => Boolean, leftProject: Set[Int], - rightProject: Set[Int] - ) { + rightProject: Set[Int], + ): Unit = { val (leftType, leftProjectF) = rowType.filter(f => !leftProject.contains(f.index)) val left = TableKeyBy( TableParallelize( Literal( TStruct("rows" -> TArray(leftType), "global" -> TStruct.empty), - Row(leftData.map(leftProjectF.asInstanceOf[Row => Row]), Row())), - Some(lParts)), - if (!leftProject.contains(1)) FastSeq("A", "B") else FastSeq("A")) + Row(leftData.map(leftProjectF.asInstanceOf[Row => Row]), Row()), + ), + Some(lParts), + ), + if (!leftProject.contains(1)) FastSeq("A", "B") else FastSeq("A"), + ) val (rightType, rightProjectF) = rowType.filter(f => !rightProject.contains(f.index)) val right = TableKeyBy( TableParallelize( Literal( TStruct("rows" -> TArray(rightType), "global" -> 
TStruct.empty), - Row(rightData.map(rightProjectF.asInstanceOf[Row => Row]), Row())), - Some(rParts)), - if (!rightProject.contains(1)) FastSeq("A", "B") else FastSeq("A")) + Row(rightData.map(rightProjectF.asInstanceOf[Row => Row]), Row()), + ), + Some(rParts), + ), + if (!rightProject.contains(1)) FastSeq("A", "B") else FastSeq("A"), + ) - val (_, joinProjectF) = joinedType.filter(f => !leftProject.contains(f.index) && !rightProject.contains(f.index - 2)) + val (_, joinProjectF) = + joinedType.filter(f => !leftProject.contains(f.index) && !rightProject.contains(f.index - 2)) val joined = collect( TableJoin( left, @@ -502,8 +586,12 @@ class TableIRSuite extends HailSuite { .filter(right.typ.rowType.hasField) .map(a => a -> (a + "_")) .toMap, - Map.empty), - joinType, 1)) + Map.empty, + ), + joinType, + 1, + ) + ) assertEvalsTo(joined, Row(expectedOuterJoin.filter(pred).map(joinProjectF).toFastSeq, Row())) } @@ -516,22 +604,28 @@ class TableIRSuite extends HailSuite { } yield Array[Any](lParts, rParts) @Test(dataProvider = "union") - def testTableUnion(lParts: Int, rParts: Int) { + def testTableUnion(lParts: Int, rParts: Int): Unit = { val left = TableKeyBy( TableParallelize( Literal( TStruct("rows" -> TArray(rowType), "global" -> TStruct.empty), - Row(leftData, Row())), - Some(lParts)), - FastSeq("A", "B")) + Row(leftData, Row()), + ), + Some(lParts), + ), + FastSeq("A", "B"), + ) val right = TableKeyBy( TableParallelize( Literal( TStruct("rows" -> TArray(rowType), "global" -> TStruct.empty), - Row(rightData, Row())), - Some(rParts)), - FastSeq("A", "B")) + Row(rightData, Row()), + ), + Some(rParts), + ), + FastSeq("A", "B"), + ) val merged = collect(TableUnion(FastSeq(left, right))) @@ -539,23 +633,29 @@ class TableIRSuite extends HailSuite { } @Test(dataProvider = "union") - def testTableMultiWayZipJoin(lParts: Int, rParts: Int) { + def testTableMultiWayZipJoin(lParts: Int, rParts: Int): Unit = { implicit val execStrats = Set(ExecStrategy.LoweredJVMCompile) val left = TableKeyBy( TableParallelize( Literal( TStruct("rows" -> TArray(rowType), "global" -> TStruct.empty), - Row(leftData, Row())), - Some(lParts)), - FastSeq("A", "B")) + Row(leftData, Row()), + ), + Some(lParts), + ), + FastSeq("A", "B"), + ) val right = TableKeyBy( TableParallelize( Literal( TStruct("rows" -> TArray(rowType), "global" -> TStruct.empty), - Row(rightData, Row())), - Some(rParts)), - FastSeq("A", "B")) + Row(rightData, Row()), + ), + Some(rParts), + ), + FastSeq("A", "B"), + ) val merged = collect(TableMultiWayZipJoin(FastSeq(left, right), "row", "global")) @@ -563,7 +663,7 @@ class TableIRSuite extends HailSuite { } // Catches a bug in the partitioner created by the importer. 
- @Test def testTableJoinOfImport() { + @Test def testTableJoinOfImport(): Unit = { val mnr = MatrixNativeReader(fs, "src/test/resources/sample.vcf.mt") val mt2 = MatrixRead(mnr.fullMatrixType, false, false, mnr) val t2 = MatrixRowsTable(mt2) @@ -574,7 +674,26 @@ class TableIRSuite extends HailSuite { assertEvalsTo(TableCount(join), 346L) } - @Test def testTableKeyBy() { + @Test def testNativeReaderWithOverlappingPartitions(): Unit = { + val path = "src/test/resources/sample.vcf-20-partitions-with-overlap.mt/rows" + // i1 overlaps the first two partitions + val i1 = Interval(Row(Locus("20", 10200000)), Row(Locus("20", 10500000)), true, true) + + def test(filterIntervals: Boolean, expectedNParts: Int): Unit = { + val opts = NativeReaderOptions(FastSeq(i1), TLocus("GRCh37"), filterIntervals) + val tr = TableNativeReader(fs, TableNativeReaderParameters(path, Some(opts))) + val tir = TableRead(tr.fullTypeWithoutUIDs, false, tr) + val nParts = TableToValueApply(tir, NPartitionsTable()) + val count = TableToValueApply(tir, ForceCountTable()) + assertEvalsTo(nParts, expectedNParts) + assertEvalsTo(count, 20L) + } + + test(false, 1) + test(true, 2) + } + + @Test def testTableKeyBy(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly val data = Array(Array("A", 1), Array("A", 2), Array("B", 1)) val rdd = sc.parallelize(data.map(Row.fromSeq(_))) @@ -582,16 +701,27 @@ class TableIRSuite extends HailSuite { val keyNames = FastSeq("field1", "field2") val tt = TableType(rowType = signature, key = keyNames, globalType = TStruct.empty) val base = TableLiteral( - TableValue(ctx, tt.rowType, tt.key, rdd), theHailClassLoader) + TableValue(ctx, tt.rowType, tt.key, rdd), + theHailClassLoader, + ) - // construct the table with a longer key, then copy the table to shorten the key in type, but not rvd - val distinctCount = TableCount(TableDistinct(TableLiteral(tt.copy(key = FastSeq("field1")), base.rvd, base.enc, base.encodedGlobals))) + /* construct the table with a longer key, then copy the table to shorten the key in type, but + * not rvd */ + val distinctCount = TableCount(TableDistinct(TableLiteral( + tt.copy(key = FastSeq("field1")), + base.rvd, + base.enc, + base.encodedGlobals, + ))) assertEvalsTo(distinctCount, 2L) } - @Test def testTableKeyByLowering() { + @Test def testTableKeyByLowering(): Unit = { implicit val execStrats = ExecStrategy.lowering - val t = TStruct("rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), "global" -> TStruct("x" -> TString)) + val t = TStruct( + "rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), + "global" -> TStruct("x" -> TString), + ) val length = 10 val value = Row(FastSeq(0 until length: _*).map(i => Row(0, "row" + i)), Row("global")) @@ -601,9 +731,12 @@ class TableIRSuite extends HailSuite { assertEvalsTo(TableCount(keyed), length.toLong) } - @Test def testTableParallelize() { + @Test def testTableParallelize(): Unit = { implicit val execStrats = ExecStrategy.allRelational - val t = TStruct("rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), "global" -> TStruct("x" -> TString)) + val t = TStruct( + "rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), + "global" -> TStruct("x" -> TString), + ) Array(1, 10, 17, 34, 103).foreach { length => val value = Row(FastSeq(0 until length: _*).map(i => Row(i, "row" + i)), Row("global")) assertEvalsTo( @@ -611,14 +744,21 @@ class TableIRSuite extends HailSuite { TableParallelize( Literal( t, - value - ))), value) + value, + ) + ) + ), + value, + ) } } - @Test def testTableParallelizeCount() { + 
@Test def testTableParallelizeCount(): Unit = { implicit val execStrats: Set[ExecStrategy] = ExecStrategy.allRelational - val t = TStruct("rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), "global" -> TStruct("x" -> TString)) + val t = TStruct( + "rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), + "global" -> TStruct("x" -> TString), + ) val value = Row(FastSeq(Row(0, "row1"), Row(1, "row2")), Row("glob")) assertEvalsTo( @@ -626,17 +766,21 @@ class TableIRSuite extends HailSuite { TableParallelize( Literal( t, - value - ))), - 2L + value, + ) + ) + ), + 2L, ) } @Test def testTableHead(): Unit = { - val t = TStruct("rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), "global" -> TStruct("x" -> TString)) - def makeData(length: Int): Row = { + val t = TStruct( + "rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), + "global" -> TStruct("x" -> TString), + ) + def makeData(length: Int): Row = Row(FastSeq(0 until length: _*).map(i => Row(i, "row" + i)), Row("global")) - } val numRowsToTakeArray = Array(0, 4, 7, 12) val numInitialPartitionsArray = Array(1, 2, 6, 10, 13) val initialDataLength = 10 @@ -650,27 +794,34 @@ class TableIRSuite extends HailSuite { TableHead( TableParallelize( Literal(t, initialData), - Some(howManyInitialPartitions) + Some(howManyInitialPartitions), ), - howManyRowsToTake + howManyRowsToTake, ) ), - headData) + headData, + ) } } } @Test def testTableTail(): Unit = { - val t = TStruct("rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), "global" -> TStruct("x" -> TString)) + val t = TStruct( + "rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), + "global" -> TStruct("x" -> TString), + ) val numRowsToTakeArray = Array(0, 2, 7, 10, 12) val numInitialPartitionsArray = Array(1, 3, 6, 10, 13) val initialDataLength = 10 - def makeData(length: Int): Row = { - Row(FastSeq((initialDataLength - length) until initialDataLength: _*).map(i => Row(i, "row" + i)), Row("global")) - } + def makeData(length: Int): Row = + Row( + FastSeq((initialDataLength - length) until initialDataLength: _*).map(i => + Row(i, "row" + i) + ), + Row("global"), + ) val initialData = makeData(initialDataLength) - numRowsToTakeArray.foreach { howManyRowsToTake => val headData = makeData(Math.min(howManyRowsToTake, initialDataLength)) numInitialPartitionsArray.foreach { howManyInitialPartitions => @@ -679,17 +830,18 @@ class TableIRSuite extends HailSuite { TableTail( TableParallelize( Literal(t, initialData), - Some(howManyInitialPartitions) + Some(howManyInitialPartitions), ), - howManyRowsToTake + howManyRowsToTake, ) ), - headData) + headData, + ) } } } - @Test def testShuffleAndJoinDoesntMemoryLeak() { + @Test def testShuffleAndJoinDoesntMemoryLeak(): Unit = { implicit val execStrats = Set(ExecStrategy.LoweredJVMCompile, ExecStrategy.Interpret) val row = Ref("row", TStruct("idx" -> TInt32)) val t1 = TableRename(TableRange(1, 1), Map("idx" -> "idx_"), Map.empty) @@ -697,49 +849,67 @@ class TableIRSuite extends HailSuite { TableKeyBy( TableMapRows( TableRange(50000, 1), - InsertFields(row, - FastSeq("k" -> (I32(49999) - GetField(row, "idx"))))), - FastSeq("k")) + InsertFields(row, FastSeq("k" -> (I32(49999) - GetField(row, "idx")))), + ), + FastSeq("k"), + ) assertEvalsTo(TableCount(TableJoin(t1, t2, "left")), 1L) } @Test def testTableRename(): Unit = { implicit val execStrats = ExecStrategy.lowering - val t = TStruct("rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), "global" -> TStruct(("x", TString), ("y", TInt32))) + val t = TStruct( + "rows" -> 
TArray(TStruct("a" -> TInt32, "b" -> TString)), + "global" -> TStruct(("x", TString), ("y", TInt32)), + ) val value = Row(FastSeq(0 until 10: _*).map(i => Row(i, "row" + i)), Row("globalVal", 3)) - val adjustedValue = Row(FastSeq(0 until 10: _*).map(i => Row(i + 3, "row" + i)), Row("globalVal", 3)) + val adjustedValue = + Row(FastSeq(0 until 10: _*).map(i => Row(i + 3, "row" + i)), Row("globalVal", 3)) val renameIR = TableRename( TableParallelize( Literal( t, - value - )), + value, + ) + ), Map[String, String]("a" -> "c"), - Map[String, String]("y" -> "z") + Map[String, String]("y" -> "z"), ) val newRow = MakeStruct(FastSeq( - ("foo", GetField(Ref("row", renameIR.typ.rowType), "c") + GetField(Ref("global", TStruct(("x", TString), ("z", TInt32))), "z")), - ("bar", GetField(Ref("row", renameIR.typ.rowType), "b"))) - ) + ( + "foo", + GetField(Ref("row", renameIR.typ.rowType), "c") + GetField( + Ref("global", TStruct(("x", TString), ("z", TInt32))), + "z", + ), + ), + ("bar", GetField(Ref("row", renameIR.typ.rowType), "b")), + )) val mapped = TableMapRows(renameIR, newRow) assertEvalsTo( collectNoKey( mapped - ), adjustedValue) + ), + adjustedValue, + ) } @Test def testTableMapGlobals(): Unit = { - val t = TStruct("rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), "global" -> TStruct("x" -> TString)) + val t = TStruct( + "rows" -> TArray(TStruct("a" -> TInt32, "b" -> TString)), + "global" -> TStruct("x" -> TString), + ) val innerRowRef = Ref("row", t.field("rows").typ.asInstanceOf[TArray].elementType) val innerGlobalRef = Ref("global", t.field("global").typ) val length = 10 val value = Row(FastSeq(0 until length: _*).map(i => Row(i, "row" + i)), Row("global")) - val modifedValue = Row(FastSeq(0 until length: _*).map(i => Row(i, "global")), Row("newGlobals")) + val modifedValue = + Row(FastSeq(0 until length: _*).map(i => Row(i, "global")), Row("newGlobals")) assertEvalsTo( collectNoKey( TableMapGlobals( @@ -747,18 +917,22 @@ class TableIRSuite extends HailSuite { TableParallelize( Literal( t, - value - )), - MakeStruct(FastSeq("a" -> GetField(innerRowRef, "a"), "b" -> GetField(innerGlobalRef, "x"))) + value, + ) + ), + MakeStruct(FastSeq( + "a" -> GetField(innerRowRef, "a"), + "b" -> GetField(innerGlobalRef, "x"), + )), ), - MakeStruct(FastSeq("x" -> Str("newGlobals"))) + MakeStruct(FastSeq("x" -> Str("newGlobals"))), ) ), - modifedValue) + modifedValue, + ) } - @Test def testTableWrite() { - implicit val execStrats = ExecStrategy.interpretOnly + @Test def testTableWrite(): Unit = { val table = TableRange(5, 4) val path = ctx.createTmpPath("test-table-write", "ht") Interpret[Unit](ctx, TableWrite(table, TableNativeWriter(path))) @@ -771,17 +945,19 @@ class TableIRSuite extends HailSuite { } @Test def testWriteKeyDistinctness(): Unit = { - implicit val execStrats = ExecStrategy.interpretOnly val rt = TableRange(40, 4) - val idxRef = GetField(Ref("row", rt.typ.rowType), "idx") - val at = TableMapRows(rt, MakeStruct(FastSeq( - "idx" -> idxRef, - "const" -> 5, - "half" -> idxRef.floorDiv(2), - "oneRepeat" -> If(idxRef ceq I32(10), I32(9), idxRef), - "oneMissing" -> If(idxRef ceq I32(4), NA(TInt32), idxRef), - "twoMissing" -> If((idxRef ceq 10) || (idxRef ceq 2), NA(TInt32), idxRef) - ))) + val idxRef = GetField(Ref("row", rt.typ.rowType), "idx") + val at = TableMapRows( + rt, + MakeStruct(FastSeq( + "idx" -> idxRef, + "const" -> 5, + "half" -> idxRef.floorDiv(2), + "oneRepeat" -> If(idxRef ceq I32(10), I32(9), idxRef), + "oneMissing" -> If(idxRef ceq I32(4), NA(TInt32), idxRef), + 
"twoMissing" -> If((idxRef ceq 10) || (idxRef ceq 2), NA(TInt32), idxRef), + )), + ) val keyedByConst = TableKeyBy(at, IndexedSeq("const")) val pathConst = ctx.createTmpPath("test-table-write-distinctness", "ht") Interpret[Unit](ctx, TableWrite(keyedByConst, TableNativeWriter(pathConst))) @@ -819,7 +995,7 @@ class TableIRSuite extends HailSuite { assert(!readTwoMissing.isDistinctlyKeyed) } - @Test def testPartitionCountsWithDropRows() { + @Test def testPartitionCountsWithDropRows(): Unit = { val tr = new FakeTableReader { override def pathsUsed: Seq[String] = Seq.empty override def partitionCounts: Option[IndexedSeq[Long]] = Some(FastSeq(1, 2, 3, 4)) @@ -829,62 +1005,106 @@ class TableIRSuite extends HailSuite { assert(tir.partitionCounts.forall(_.sum == 0)) } - @Test def testScanInAggInMapRows() { + @Test def testScanInAggInMapRows(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly val sumSig = AggSignature(Sum(), FastSeq(), FastSeq(TInt64)) var tr: TableIR = TableRange(10, 3) tr = TableKeyBy(tr, FastSeq(), false) - tr = TableMapRows(tr, InsertFields(Ref("row", tr.typ.rowType), - FastSeq(("result", - StreamAgg( - StreamAggScan( - StreamRange(0, GetField(Ref("row", tr.typ.rowType), "idx"), 1), - "streamx", - ApplyScanOp(FastSeq(), FastSeq(Ref("streamx", TInt32).toL), sumSig)), - "aggx", - ApplyAggOp(FastSeq(), FastSeq(Ref("aggx", TInt64)), sumSig)))))) - assertEvalsTo(TableCollect(tr), Row(IndexedSeq.tabulate(10) { i => - val r = (0 until i).map(_.toLong).scanLeft(0L)(_ + _).init.sum - Row(i, r) - }, Row() - )) + tr = TableMapRows( + tr, + InsertFields( + Ref("row", tr.typ.rowType), + FastSeq(( + "result", + StreamAgg( + StreamAggScan( + StreamRange(0, GetField(Ref("row", tr.typ.rowType), "idx"), 1), + "streamx", + ApplyScanOp(FastSeq(), FastSeq(Ref("streamx", TInt32).toL), sumSig), + ), + "aggx", + ApplyAggOp(FastSeq(), FastSeq(Ref("aggx", TInt64)), sumSig), + ), + )), + ), + ) + assertEvalsTo( + TableCollect(tr), + Row( + IndexedSeq.tabulate(10) { i => + val r = (0 until i).map(_.toLong).scanLeft(0L)(_ + _).init.sum + Row(i, r) + }, + Row(), + ), + ) } - @Test def testScanInAggInScanInMapRows() { + @Test def testScanInAggInScanInMapRows(): Unit = { implicit val execStrats = ExecStrategy.interpretOnly val sumSig = AggSignature(Sum(), FastSeq(), FastSeq(TInt64)) var tr: TableIR = TableRange(10, 3) tr = TableKeyBy(tr, FastSeq(), false) - tr = TableMapRows(tr, InsertFields(Ref("row", tr.typ.rowType), - FastSeq(("result", - ApplyScanOp(FastSeq(), - FastSeq(StreamAgg( - StreamAggScan( - StreamRange(0, GetField(Ref("row", tr.typ.rowType), "idx"), 1), - "streamx", - ApplyScanOp(FastSeq(), FastSeq(Ref("streamx", TInt32).toL), sumSig)), - "aggx", - ApplyAggOp(FastSeq(), FastSeq(Ref("aggx", TInt64)), sumSig))), - sumSig))))) - assertEvalsTo(TableCollect(tr), Row(Array.tabulate(10) { i => - (0 until i).map(_.toLong).scanLeft(0L)(_ + _).init.sum - }.scanLeft(0L)(_ + _) - .zipWithIndex - .map { case (x, idx) => Row(idx, x) }.init.toFastSeq, - Row() - )) + tr = TableMapRows( + tr, + InsertFields( + Ref("row", tr.typ.rowType), + FastSeq(( + "result", + ApplyScanOp( + FastSeq(), + FastSeq(StreamAgg( + StreamAggScan( + StreamRange(0, GetField(Ref("row", tr.typ.rowType), "idx"), 1), + "streamx", + ApplyScanOp(FastSeq(), FastSeq(Ref("streamx", TInt32).toL), sumSig), + ), + "aggx", + ApplyAggOp(FastSeq(), FastSeq(Ref("aggx", TInt64)), sumSig), + )), + sumSig, + ), + )), + ), + ) + assertEvalsTo( + TableCollect(tr), + Row( + Array.tabulate(10)(i => (0 until i).map(_.toLong).scanLeft(0L)(_ + 
_).init.sum).scanLeft( + 0L + )(_ + _) + .zipWithIndex + .map { case (x, idx) => Row(idx, x) }.init.toFastSeq, + Row(), + ), + ) } @Test def testTableAggregateByKey(): Unit = { implicit val execStrats = ExecStrategy.allRelational var tir: TableIR = TableRead.native(fs, "src/test/resources/three_key.ht") tir = TableKeyBy(tir, FastSeq("x", "y"), true) - tir = TableAggregateByKey(tir, MakeStruct(FastSeq( - ("sum", ApplyAggOp(FastSeq(), FastSeq(GetField(Ref("row", tir.typ.rowType), "z").toL), AggSignature(Sum(), FastSeq(), FastSeq(TInt64)))), - ("n", ApplyAggOp(FastSeq(), FastSeq(), AggSignature(Count(), FastSeq(), FastSeq()))) - ))) + tir = TableAggregateByKey( + tir, + MakeStruct(FastSeq( + ( + "sum", + ApplyAggOp( + FastSeq(), + FastSeq(GetField(Ref("row", tir.typ.rowType), "z").toL), + AggSignature(Sum(), FastSeq(), FastSeq(TInt64)), + ), + ), + ("n", ApplyAggOp(FastSeq(), FastSeq(), AggSignature(Count(), FastSeq(), FastSeq()))), + )), + ) val ir = GetField(TableCollect(TableKeyBy(tir, FastSeq())), "rows") - assertEvalsTo(ir, (0 until 10).flatMap(i => (0 until i).map(j => Row(i, j, (0 until j).sum.toLong, j.toLong))).filter(_.getAs[Long](3) > 0)) + assertEvalsTo( + ir, + (0 until 10).flatMap(i => + (0 until i).map(j => Row(i, j, (0 until j).sum.toLong, j.toLong)) + ).filter(_.getAs[Long](3) > 0), + ) } @Test def testTableDistinct(): Unit = { @@ -902,7 +1122,7 @@ class TableIRSuite extends HailSuite { assertEvalsTo(TableCount(distinctByAll), 120L) } - @Test def testRangeOrderByDescending() { + @Test def testRangeOrderByDescending(): Unit = { var tir: TableIR = TableRange(10, 3) tir = TableOrderBy(tir, FastSeq(SortField("idx", Descending))) val x = GetField(TableCollect(tir), "rows") @@ -911,17 +1131,21 @@ class TableIRSuite extends HailSuite { } @Test def testTableLeftJoinRightDistinctRangeTables(): Unit = { - IndexedSeq((1, 1), (3, 2), (10, 5), (5, 10)).foreach { case(nParts1, nParts2) => + IndexedSeq((1, 1), (3, 2), (10, 5), (5, 10)).foreach { case (nParts1, nParts2) => val rangeTable1 = TableRange(10, nParts1) var rangeTable2: TableIR = TableRange(5, nParts2) val row = Ref("row", rangeTable2.typ.rowType) - rangeTable2 = TableMapRows(rangeTable2, InsertFields(row, FastSeq("x" -> GetField(row, "idx")))) + rangeTable2 = + TableMapRows(rangeTable2, InsertFields(row, FastSeq("x" -> GetField(row, "idx")))) val joinedRanges = TableLeftJoinRightDistinct(rangeTable1, rangeTable2, "foo") assertEvalsTo(TableCount(joinedRanges), 10L) val expectedJoinCollectResult = Row( - (0 until 5).map(i => Row(FastSeq(i, Row(i)): _*)) ++ (5 until 10).map(i => Row(FastSeq(i, null): _*)), - Row()) + (0 until 5).map(i => Row(FastSeq(i, Row(i)): _*)) ++ (5 until 10).map(i => + Row(FastSeq(i, null): _*) + ), + Row(), + ) assertEvalsTo(collect(joinedRanges), expectedJoinCollectResult) } } @@ -933,17 +1157,36 @@ class TableIRSuite extends HailSuite { ir = ToArray(mapIR(ir)(ToArray)) ir = InsertFields(Ref("row", tir.typ.rowType), Seq("foo" -> ir)) tir = TableMapRows(tir, ir) - assertEvalsTo(collect(tir), Row(FastSeq(Row(0, FastSeq(FastSeq(0, 1), FastSeq(2, 3), FastSeq(4)))), Row())) + assertEvalsTo( + collect(tir), + Row(FastSeq(Row(0, FastSeq(FastSeq(0, 1), FastSeq(2, 3), FastSeq(4)))), Row()), + ) } val parTable1Length = 7 - val parTable1Type = TStruct("rows" -> TArray(TStruct("a1" -> TString, "b1" -> TInt32, "c1" -> TString)), "global" -> TStruct("x" -> TString)) - val value1 = Row(FastSeq(0 until parTable1Length: _*).map(i => Row("row" + i, i * i, s"t1_${i}")), Row("global")) + + val parTable1Type = TStruct( + "rows" 
-> TArray(TStruct("a1" -> TString, "b1" -> TInt32, "c1" -> TString)), + "global" -> TStruct("x" -> TString), + ) + + val value1 = Row( + FastSeq(0 until parTable1Length: _*).map(i => Row("row" + i, i * i, s"t1_$i")), + Row("global"), + ) + val table1 = TableParallelize(Literal(parTable1Type, value1), Some(2)) val parTable2Length = 9 - val parTable2Type = TStruct("rows" -> TArray(TStruct("a2" -> TString, "b2" -> TInt32, "c2" -> TString)), "global" -> TStruct("y"-> TInt32)) - val value2 = Row(FastSeq(0 until parTable2Length: _*).map(i => Row("row" + i, -2 * i, s"t2_${i}")), Row(15)) + + val parTable2Type = TStruct( + "rows" -> TArray(TStruct("a2" -> TString, "b2" -> TInt32, "c2" -> TString)), + "global" -> TStruct("y" -> TInt32), + ) + + val value2 = + Row(FastSeq(0 until parTable2Length: _*).map(i => Row("row" + i, -2 * i, s"t2_$i")), Row(15)) + val table2 = TableParallelize(Literal(parTable2Type, value2), Some(3)) val table1KeyedByA = TableKeyBy(table1, IndexedSeq("a1")) @@ -955,18 +1198,90 @@ class TableIRSuite extends HailSuite { assertEvalsTo(TableCount(table2KeyedByA), parTable2Length.toLong) assertEvalsTo(TableCount(joinedParKeyedByA), parTable1Length.toLong) - assertEvalsTo(collect(joinedParKeyedByA), Row(FastSeq(0 until parTable1Length: _*).map(i => - Row("row" + i, i * i, s"t1_${i}", Row(-2 * i, s"t2_${i}"))), Row("global")) + assertEvalsTo( + collect(joinedParKeyedByA), + Row( + FastSeq(0 until parTable1Length: _*).map(i => + Row("row" + i, i * i, s"t1_$i", Row(-2 * i, s"t2_$i")) + ), + Row("global"), + ), ) } @Test def testTableLeftJoinRightDistinctParallelizePrefixKey(): Unit = { val table1KeyedByAAndB = TableKeyBy(table1, IndexedSeq("a1", "b1")) - val joinedParKeyedByAAndB = TableLeftJoinRightDistinct(table1KeyedByAAndB, table2KeyedByA, "joinRoot") + val joinedParKeyedByAAndB = + TableLeftJoinRightDistinct(table1KeyedByAAndB, table2KeyedByA, "joinRoot") assertEvalsTo(TableCount(joinedParKeyedByAAndB), parTable1Length.toLong) - assertEvalsTo(collect(joinedParKeyedByAAndB), Row(FastSeq(0 until parTable1Length: _*).map(i => - Row("row" + i, i * i, s"t1_${i}", Row(-2 * i, s"t2_${i}"))), Row("global")) + assertEvalsTo( + collect(joinedParKeyedByAAndB), + Row( + FastSeq(0 until parTable1Length: _*).map(i => + Row("row" + i, i * i, s"t1_$i", Row(-2 * i, s"t2_$i")) + ), + Row("global"), + ), + ) + } + + @Test def testTableIntervalJoin(): Unit = { + val intervals: IndexedSeq[Interval] = + for { + (start, end, includesStart, includesEnd) <- FastSeq( + (1, 6, true, false), + (2, 2, false, false), + (3, 5, true, true), + (4, 6, true, true), + (6, 7, false, true), + ) + } yield Interval( + IntervalEndpoint(start, if (includesStart) -1 else 1), + IntervalEndpoint(end, if (includesEnd) 1 else -1), + ) + + val left = + TableKeyBy( + TableParallelize(MakeStruct(FastSeq( + "rows" -> Literal(TArray(TStruct("a" -> TInt32)), (0 until 9).map(Row(_))), + "global" -> MakeStruct(FastSeq("left" -> Str("globals"))), + ))), + FastSeq("a"), + isSorted = true, + ) + + val right = + TableKeyBy( + TableParallelize(MakeStruct(FastSeq( + "rows" -> Literal( + TArray(TStruct("interval" -> TInterval(TInt32), "b" -> TInt32)), + intervals.zipWithIndex.map { case (i, idx) => Row(i, idx) }, + ), + "global" -> MakeStruct(FastSeq("bye" -> I32(-1))), + ))), + FastSeq("interval"), + isSorted = true, + ) + + val join = TableIntervalJoin(left, right, "rights", product = true) + + assertEvalsTo( + collect(join), + Row( + FastSeq( + Row(0, FastSeq()), + Row(1, FastSeq(Row(0))), + Row(2, FastSeq(Row(0))), + Row(3, 
FastSeq(Row(2), Row(0))), + Row(4, FastSeq(Row(2), Row(0), Row(3))), + Row(5, FastSeq(Row(2), Row(0), Row(3))), + Row(6, FastSeq(Row(3))), + Row(7, FastSeq(Row(4))), + Row(8, FastSeq()), + ), + Row("globals"), + ), ) } @@ -975,117 +1290,220 @@ class TableIRSuite extends HailSuite { val unkeyed = TableKeyBy(tir, IndexedSeq[String]()) val rowRef = Ref("row", unkeyed.typ.rowType) val aggSignature = AggSignature(Sum(), FastSeq(), FastSeq(TInt64)) - val aggExpression = MakeStruct(FastSeq("y_sum" -> ApplyAggOp(FastSeq(), FastSeq(Cast(GetField(rowRef, "y"), TInt64)), aggSignature))) - val keyByXAndAggregateSum = TableKeyByAndAggregate(unkeyed, aggExpression, MakeStruct(FastSeq("x" -> GetField(rowRef, "x"))), bufferSize = 50) + val aggExpression = MakeStruct(FastSeq("y_sum" -> ApplyAggOp( + FastSeq(), + FastSeq(Cast(GetField(rowRef, "y"), TInt64)), + aggSignature, + ))) + val keyByXAndAggregateSum = TableKeyByAndAggregate( + unkeyed, + aggExpression, + MakeStruct(FastSeq("x" -> GetField(rowRef, "x"))), + bufferSize = 50, + ) assertEvalsTo( collect(keyByXAndAggregateSum), - Row(FastSeq(Row(2, 1L), Row(3,5L), Row(4, 14L), Row(5, 30L), Row(6, 55L), Row(7, 91L), Row(8, 140L), Row(9, 204L)), Row()) + Row( + FastSeq( + Row(2, 1L), + Row(3, 5L), + Row(4, 14L), + Row(5, 30L), + Row(6, 55L), + Row(7, 91L), + Row(8, 140L), + Row(9, 204L), + ), + Row(), + ), ) // Keying by a newly computed field. - val keyByXPlusTwoAndAggregateSum = TableKeyByAndAggregate(unkeyed, aggExpression, MakeStruct(FastSeq("xPlusTwo" -> (GetField(rowRef, "x") + 2))), bufferSize = 50) + val keyByXPlusTwoAndAggregateSum = TableKeyByAndAggregate( + unkeyed, + aggExpression, + MakeStruct(FastSeq("xPlusTwo" -> (GetField(rowRef, "x") + 2))), + bufferSize = 50, + ) assertEvalsTo( collect(keyByXPlusTwoAndAggregateSum), - Row(FastSeq(Row(4, 1L), Row(5,5L), Row(6, 14L), Row(7, 30L), Row(8, 55L), Row(9, 91L), Row(10, 140L), Row(11, 204L)), Row()) + Row( + FastSeq( + Row(4, 1L), + Row(5, 5L), + Row(6, 14L), + Row(7, 30L), + Row(8, 55L), + Row(9, 91L), + Row(10, 140L), + Row(11, 204L), + ), + Row(), + ), ) // Keying by just Z when original is keyed by x,y,z, naming it x anyway. 
- val keyByZAndAggregateSum = TableKeyByAndAggregate(tir, aggExpression, MakeStruct(FastSeq("x" -> GetField(rowRef, "z"))), bufferSize = 50) + val keyByZAndAggregateSum = TableKeyByAndAggregate( + tir, + aggExpression, + MakeStruct(FastSeq("x" -> GetField(rowRef, "z"))), + bufferSize = 50, + ) assertEvalsTo( collect(keyByZAndAggregateSum), - Row(FastSeq(Row(0, 120L), Row(1, 112L), Row(2, 98L), Row(3, 80L), Row(4, 60L), Row(5, 40L), Row(6, 22L), Row(7, 8L)), Row()) + Row( + FastSeq( + Row(0, 120L), + Row(1, 112L), + Row(2, 98L), + Row(3, 80L), + Row(4, 60L), + Row(5, 40L), + Row(6, 22L), + Row(7, 8L), + ), + Row(), + ), ) } @Test def testTableAggregateCollectAndTake(): Unit = { implicit val execStrats = ExecStrategy.allRelational var tir: TableIR = TableRange(10, 3) - tir = TableMapRows(tir, InsertFields(Ref("row", tir.typ.rowType), FastSeq("aStr" -> Str("foo")))) - val x = TableAggregate(tir, + tir = + TableMapRows(tir, InsertFields(Ref("row", tir.typ.rowType), FastSeq("aStr" -> Str("foo")))) + val x = TableAggregate( + tir, MakeTuple.ordered(FastSeq( ApplyAggOp(Collect())(Ref("row", tir.typ.rowType)), - ApplyAggOp(Take(), I32(5))(GetField(Ref("row", tir.typ.rowType), "idx")) - ))) + ApplyAggOp(Take(), I32(5))(GetField(Ref("row", tir.typ.rowType), "idx")), + )), + ) - assertEvalsTo(x, Row( - (0 until 10).map(i => Row(i, "foo")), - 0 until 5)) + assertEvalsTo( + x, + Row( + (0 until 10).map(i => Row(i, "foo")), + 0 until 5, + ), + ) } @Test def testNDArrayMultiplyAddAggregator(): Unit = { implicit val execStrats = ExecStrategy.allRelational var tir: TableIR = TableRange(6, 3) - val nDArray1 = Literal(TNDArray(TFloat64, Nat(2)), SafeNDArray(IndexedSeq(2L,2L), IndexedSeq(1.0,1.0,1.0,1.0))) - val nDArray2 = Literal(TNDArray(TFloat64, Nat(2)), SafeNDArray(IndexedSeq(2L,2L), IndexedSeq(2.0,2.0,2.0,2.0))) - tir = TableMapRows(tir, InsertFields(Ref("row", tir.typ.rowType), - FastSeq("nDArrayA" -> nDArray1, "nDArrayB" -> nDArray2))) - val x = TableAggregate(tir, ApplyAggOp(NDArrayMultiplyAdd())(GetField(Ref("row", tir.typ.rowType), "nDArrayA"), - GetField(Ref("row", tir.typ.rowType), "nDArrayB"))) + val nDArray1 = Literal( + TNDArray(TFloat64, Nat(2)), + SafeNDArray(IndexedSeq(2L, 2L), IndexedSeq(1.0, 1.0, 1.0, 1.0)), + ) + val nDArray2 = Literal( + TNDArray(TFloat64, Nat(2)), + SafeNDArray(IndexedSeq(2L, 2L), IndexedSeq(2.0, 2.0, 2.0, 2.0)), + ) + tir = TableMapRows( + tir, + InsertFields( + Ref("row", tir.typ.rowType), + FastSeq("nDArrayA" -> nDArray1, "nDArrayB" -> nDArray2), + ), + ) + val x = TableAggregate( + tir, + ApplyAggOp(NDArrayMultiplyAdd())( + GetField(Ref("row", tir.typ.rowType), "nDArrayA"), + GetField(Ref("row", tir.typ.rowType), "nDArrayB"), + ), + ) assertEvalsTo(x, SafeNDArray(Vector(2, 2), IndexedSeq(24.0, 24.0, 24.0, 24.0))) } @Test def testTableScanCollect(): Unit = { implicit val execStrats = ExecStrategy.allRelational var tir: TableIR = TableRange(5, 3) - tir = TableMapRows(tir, - InsertFields(Ref("row", tir.typ.rowType), - FastSeq("scans" -> MakeTuple.ordered(FastSeq(ApplyScanOp(Count())(), ApplyScanOp(Collect())(GetField(Ref("row", tir.typ.rowType), "idx"))))))) - val x = TableAggregate(tir, - ApplyAggOp(Collect())(Ref("row", tir.typ.rowType)) - ) + tir = TableMapRows( + tir, + InsertFields( + Ref("row", tir.typ.rowType), + FastSeq("scans" -> MakeTuple.ordered(FastSeq( + ApplyScanOp(Count())(), + ApplyScanOp(Collect())(GetField(Ref("row", tir.typ.rowType), "idx")), + ))), + ), + ) + val x = TableAggregate(tir, ApplyAggOp(Collect())(Ref("row", tir.typ.rowType))) - 
assertEvalsTo(x, FastSeq( - Row(0, Row(0L, FastSeq())), - Row(1, Row(1L, FastSeq(0))), - Row(2, Row(2L, FastSeq(0,1))), - Row(3, Row(3L, FastSeq(0,1,2))), - Row(4, Row(4L, FastSeq(0,1,2,3))) - )) + assertEvalsTo( + x, + FastSeq( + Row(0, Row(0L, FastSeq())), + Row(1, Row(1L, FastSeq(0))), + Row(2, Row(2L, FastSeq(0, 1))), + Row(3, Row(3L, FastSeq(0, 1, 2))), + Row(4, Row(4L, FastSeq(0, 1, 2, 3))), + ), + ) } - @Test def testIssue9016() { - val rows = mapIR(ToStream(MakeArray(makestruct("a" -> MakeTuple.ordered(FastSeq(I32(0), I32(1))))))) { row => - If(IsNA(row), - NA(TStruct("a" -> TTuple(FastSeq(TupleField(1, TInt32))))), - makestruct("a" -> bindIR(GetField(row, "a")) { a => - If(IsNA(a), NA(TTuple(FastSeq(TupleField(1, TInt32)))), MakeTuple(FastSeq(1 -> GetTupleElement(a, 1)))) - })) - } - val table = TableParallelize(makestruct("rows" -> ToArray(rows), "global" -> makestruct()), None) + @Test def testIssue9016(): Unit = { + val rows = + mapIR(ToStream(MakeArray(makestruct("a" -> MakeTuple.ordered(FastSeq(I32(0), I32(1))))))) { + row => + If( + IsNA(row), + NA(TStruct("a" -> TTuple(FastSeq(TupleField(1, TInt32))))), + makestruct("a" -> bindIR(GetField(row, "a")) { a => + If( + IsNA(a), + NA(TTuple(FastSeq(TupleField(1, TInt32)))), + MakeTuple(FastSeq(1 -> GetTupleElement(a, 1))), + ) + }), + ) + } + val table = + TableParallelize(makestruct("rows" -> ToArray(rows), "global" -> makestruct()), None) assertEvalsTo(TableCollect(table), Row(FastSeq(Row(Row(1))), Row())) } @Test def testTableNativeZippedReaderWithPrefixKey(): Unit = { - /* - This test is important because it tests that we can handle lowering with a TableNativeZippedReader - when elements of the original key get pruned away (so I copy key to only be "locus" instead of "locus", "alleles") - */ + /* This test is important because it tests that we can handle lowering with a + * TableNativeZippedReader when elements of the original key get pruned away (so I copy key to + * only be "locus" instead of "locus", "alleles") */ val rowsPath = "src/test/resources/sample.vcf.mt/rows" val entriesPath = "src/test/resources/sample.vcf.mt/entries" val mnr = MatrixNativeReader(fs, "src/test/resources/sample.vcf.mt") val mnrSpec = mnr.getSpec() - val reader = TableNativeZippedReader(rowsPath, entriesPath, None, mnrSpec.rowsSpec, mnrSpec.entriesSpec) - val tableType = mnr.matrixToTableType(mnr.fullMatrixType).copy(globalType = TStruct(), key=IndexedSeq("locus")) - val irToLower = TableAggregate(TableRead(tableType, false, reader), + val reader = + TableNativeZippedReader(rowsPath, entriesPath, None, mnrSpec.rowsSpec, mnrSpec.entriesSpec) + val tableType = mnr.matrixToTableType(mnr.fullMatrixType).copy( + globalType = TStruct(), + key = IndexedSeq("locus"), + ) + val irToLower = TableAggregate( + TableRead(tableType, false, reader), MakeTuple.ordered(FastSeq( ApplyAggOp(Collect())(GetField(Ref("row", tableType.rowType), "rsid")) - ))) + )), + ) val optimized = Optimize(irToLower, "foo", ctx) val analyses = LoweringAnalyses.apply(optimized, ctx) LowerTableIR(optimized, DArrayLowering.All, ctx, analyses) } - @Test def testTableMapPartitions() { + @Test def testTableMapPartitions(): Unit = { val table = TableKeyBy( TableMapGlobals( TableRange(20, nPartitions = 4), - MakeStruct(FastSeq("greeting" -> Str("Hello")))), - IndexedSeq(), false) + MakeStruct(FastSeq("greeting" -> Str("Hello"))), + ), + IndexedSeq(), + false, + ) val rowType = TStruct("idx" -> TInt32) val part = Ref("part", TStream(rowType)) @@ -1094,77 +1512,122 @@ class TableIRSuite 
extends HailSuite { assertEvalsTo( collect( - TableMapPartitions(table, "g", "part", + TableMapPartitions( + table, + "g", + "part", StreamMap( part, "row2", - InsertFields(Ref("row2", rowType), FastSeq("str" -> Str("foo"))))) + InsertFields(Ref("row2", rowType), FastSeq("str" -> Str("foo"))), + ), + ) ), - Row(IndexedSeq.tabulate(20) { i => - Row(i, "foo") - }, Row("Hello"))) + Row(IndexedSeq.tabulate(20)(i => Row(i, "foo")), Row("Hello")), + ) assertEvalsTo( collect( - TableMapPartitions(table, "g", "part", + TableMapPartitions( + table, + "g", + "part", StreamFilter( part, "row2", - GetField(Ref("row2", rowType), "idx") > 0)) + GetField(Ref("row2", rowType), "idx") > 0, + ), + ) ), - Row(IndexedSeq.tabulate(20) { i => - Row(i) - }.filter(_.getAs[Int](0) > 0), Row("Hello"))) + Row(IndexedSeq.tabulate(20)(i => Row(i)).filter(_.getAs[Int](0) > 0), Row("Hello")), + ) assertEvalsTo( collect( - TableMapPartitions(table, "g", "part", + TableMapPartitions( + table, + "g", + "part", StreamFlatMap( part, "row2", mapIR(StreamRange(0, 3, 1)) { i => MakeStruct(FastSeq("str" -> Str("Hello"), "i" -> i)) - } - ))), - Row((0 until 20).flatMap(i => (0 until 3).map(j => Row("Hello", j))), Row("Hello"))) + }, + ), + ) + ), + Row((0 until 20).flatMap(i => (0 until 3).map(j => Row("Hello", j))), Row("Hello")), + ) assertEvalsTo( collect( - TableMapPartitions(table, "g", "part", + TableMapPartitions( + table, + "g", + "part", // replace every row in partition with the first row StreamFilter( - StreamScan(part, - NA(rowType), - "acc", "row", - If(IsNA(acc), row, acc)), + StreamScan(part, NA(rowType), "acc", "row", If(IsNA(acc), row, acc)), "row", - !IsNA(row))) + !IsNA(row), + ), + ) + ), + Row( + IndexedSeq.tabulate(20) { i => + // 0,1,2,3,4,5,6,7,8,9,... ==> + // 0,0,0,0,0,5,5,5,5,5,... + Row((i / 5) * 5) + }, + Row("Hello"), ), - Row(IndexedSeq.tabulate(20) { i => - // 0,1,2,3,4,5,6,7,8,9,... ==> - // 0,0,0,0,0,5,5,5,5,5,... 
- Row((i / 5) * 5) - }, Row("Hello"))) + ) val e = intercept[HailException](TypeCheck( ctx, - collect(TableMapPartitions(table, "g", "part", StreamFlatMap(StreamRange(0, 2, 1), "_", part))))) - assert("must iterate over the partition exactly once".r.findFirstIn(e.getCause.getMessage).isDefined) + collect(TableMapPartitions( + table, + "g", + "part", + StreamFlatMap(StreamRange(0, 2, 1), "_", part), + )), + )) + assert( + "must iterate over the partition exactly once".r.findFirstIn(e.getCause.getMessage).isDefined + ) } - @Test def testRepartitionCostEstimate: Unit = { + @Test def testRepartitionCostEstimate(): Unit = { val empty = RVDPartitioner.empty(ctx.stateManager, TStruct(Array.empty[Field])) val some = RVDPartitioner.unkeyed(ctx.stateManager, _) val data = IndexedSeq( (empty, empty, Succeeded, Failed("Repartitioning from an empty partitioner should be free")), - (empty, some(1), Succeeded, Failed("Repartitioning from an empty partitioner should be free")), + ( + empty, + some(1), + Succeeded, + Failed("Repartitioning from an empty partitioner should be free"), + ), (some(1), empty, Succeeded, Failed("Repartitioning to an empty partitioner should be free")), - (some(5), some(1), Succeeded, Failed("Combining multiple partitions into one should not incur a reload")), - (some(1), some(60), Failed("Recomputing the same partition multiple times should be replaced with a reload"), Succeeded) + ( + some(5), + some(1), + Succeeded, + Failed("Combining multiple partitions into one should not incur a reload"), + ), + ( + some(1), + some(60), + Failed("Recomputing the same partition multiple times should be replaced with a reload"), + Succeeded, + ), ) - forAll(data) { case (a, b, t, f) => (if (LowerTableIR.isRepartitioningCheap(a, b)) t else f).toSucceeded } + forAll(data) { case (a, b, t, f) => + (if (LowerTableIR.isRepartitioningCheap(a, b)) t else f).toSucceeded + } } } diff --git a/hail/src/test/scala/is/hail/expr/ir/TakeByAggregatorSuite.scala b/hail/src/test/scala/is/hail/expr/ir/TakeByAggregatorSuite.scala index 2ba385f00e5..bca7915bff1 100644 --- a/hail/src/test/scala/is/hail/expr/ir/TakeByAggregatorSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/TakeByAggregatorSuite.scala @@ -8,15 +8,20 @@ import is.hail.types.VirtualTypeWithReq import is.hail.types.physical._ import is.hail.types.physical.stypes.primitives.SInt32Value import is.hail.utils._ + import org.testng.annotations.Test class TakeByAggregatorSuite extends HailSuite { - @Test def testPointers() { + @Test def testPointers(): Unit = { for ((size, n) <- Array((1000, 100), (1, 10), (100, 10000), (1000, 10000))) { val fb = EmitFunctionBuilder[Region, Long](ctx, "test_pointers") val cb = fb.ecb val stringPT = PCanonicalString(true) - val tba = new TakeByRVAS(VirtualTypeWithReq(PCanonicalString(true)), VirtualTypeWithReq(PInt64Optional), cb) + val tba = new TakeByRVAS( + VirtualTypeWithReq(PCanonicalString(true)), + VirtualTypeWithReq(PInt64Optional), + cb, + ) pool.scopedRegion { r => val argR = fb.getCodeParam[Region](1) val i = fb.genFieldThisRef[Long]() @@ -28,30 +33,36 @@ class TakeByAggregatorSuite extends HailSuite { tba.newState(cb, 0L) tba.initialize(cb, size) cb += (i := 0L) - cb.while_(i < n.toLong, { - cb += argR.invoke[Unit]("clear") - cb.assign(off, stringPT.allocateAndStoreString(cb, argR, const("str").concat(i.toS))) - tba.seqOp(cb, false, off, false, cb.memoize(-i)) - cb += (i := i + 1L) - }) + cb.while_( + i < n.toLong, { + cb += argR.invoke[Unit]("clear") + cb.assign(off, stringPT.allocateAndStoreString(cb, 
argR, const("str").concat(i.toS))) + tba.seqOp(cb, false, off, false, cb.memoize(-i)) + cb += (i := i + 1L) + }, + ) tba.result(cb, argR, rt).a } val o = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, r)(r) val result = SafeRow.read(rt, o) - assert(result == ((n - 1) to 0 by -1) - .iterator - .map(i => s"str$i") - .take(size) - .toFastSeq, s"size=$size, n=$n") + assert( + result == ((n - 1) to 0 by -1) + .iterator + .map(i => s"str$i") + .take(size) + .toFastSeq, + s"size=$size, n=$n", + ) } } } - @Test def testMissing() { + @Test def testMissing(): Unit = { val fb = EmitFunctionBuilder[Region, Long](ctx, "take_by_test_missing") val cb = fb.ecb - val tba = new TakeByRVAS(VirtualTypeWithReq(PInt32Optional), VirtualTypeWithReq(PInt32Optional), cb) + val tba = + new TakeByRVAS(VirtualTypeWithReq(PInt32Optional), VirtualTypeWithReq(PInt32Optional), cb) pool.scopedRegion { r => val argR = fb.getCodeParam[Region](1) val rt = PCanonicalArray(tba.valueType) @@ -77,7 +88,7 @@ class TakeByAggregatorSuite extends HailSuite { } } - @Test def testRandom() { + @Test def testRandom(): Unit = { for (n <- Array(1, 2, 10, 100, 1000, 10000, 100000, 1000000)) { val nToTake = 1025 val fb = EmitFunctionBuilder[Region, Long](ctx, "take_by_test_random") @@ -89,7 +100,8 @@ class TakeByAggregatorSuite extends HailSuite { val random = fb.genFieldThisRef[Int]() val resultOff = fb.genFieldThisRef[Long]() - val tba = new TakeByRVAS(VirtualTypeWithReq(PInt32Required), VirtualTypeWithReq(PInt32Required), kb) + val tba = + new TakeByRVAS(VirtualTypeWithReq(PInt32Required), VirtualTypeWithReq(PInt32Required), kb) val ab = new agg.StagedArrayBuilder(PInt32Required, kb, argR) val rt = PCanonicalArray(tba.valueType) val er = new EmitRegion(fb.apply_method, argR) @@ -101,12 +113,14 @@ class TakeByAggregatorSuite extends HailSuite { tba.initialize(cb, nToTake) ab.initialize(cb) cb += (i := 0) - cb.while_(i < n, { - cb += (random := rng.invoke[Double, Double, Double]("runif", -10000d, 10000d).toI) - tba.seqOp(cb, false, random, false, random) - ab.append(cb, new SInt32Value(random)) - cb += (i := i + 1) - }) + cb.while_( + i < n, { + cb += (random := rng.invoke[Double, Double, Double]("runif", -10000d, 10000d).toI) + tba.seqOp(cb, false, random, false, random) + ab.append(cb, new SInt32Value(random)) + cb += (i := i + 1) + }, + ) cb.if_(ab.size cne n, cb._fatal("bad size!")) cb += (resultOff := argR.allocate(8L, 16L)) cb += Region.storeAddress(resultOff, tba.result(cb, argR, rt).a) diff --git a/hail/src/test/scala/is/hail/expr/ir/TestUtils.scala b/hail/src/test/scala/is/hail/expr/ir/TestUtils.scala index 028ecefefef..f7b9bda9b67 100644 --- a/hail/src/test/scala/is/hail/expr/ir/TestUtils.scala +++ b/hail/src/test/scala/is/hail/expr/ir/TestUtils.scala @@ -8,9 +8,14 @@ object TestUtils { def rangeKT: TableIR = TableKeyBy(TableRange(20, 4), FastSeq()) def collect(tir: TableIR): IR = - TableAggregate(tir, MakeStruct(FastSeq( - "rows" -> IRAggCollect(Ref("row", tir.typ.rowType)), - "global" -> Ref("global", tir.typ.globalType)))) + TableAggregate( + tir, + MakeStruct(FastSeq( + "rows" -> IRAggCollect(Ref("row", tir.typ.rowType)), + "global" -> Ref("global", tir.typ.globalType), + )), + ) + def collectNoKey(tir: TableIR): IR = TableCollect(tir) def toIRInt(i: Integer): IR = @@ -53,9 +58,9 @@ object TestUtils { else MakeArray(a.map(s => Literal.coerce(TString, s)), TArray(TString)) - def IRStringArray(a: String*): IR = toIRStringArray(FastSeq(a:_*)) + def IRStringArray(a: String*): IR = toIRStringArray(FastSeq(a: _*)) 
- def IRStringSet(a: String*): IR = ToSet(ToStream(toIRStringArray(FastSeq(a:_*)))) + def IRStringSet(a: String*): IR = ToSet(ToStream(toIRStringArray(FastSeq(a: _*)))) def toIRDoubleArray(a: IndexedSeq[java.lang.Double]): IR = if (a == null) @@ -82,14 +87,14 @@ object TestUtils { def toIRSet(a: IndexedSeq[Integer]): IR = if (a == null) NA(TSet(TInt32)) - else + else ToSet(ToStream(toIRArray(a))) def IRSet(a: Integer*): IR = toIRSet(a.toArray[Integer]) def IRCall(c: Call): IR = invoke("callFromRepr", TCall, I32(c)) - def IRAggCount: IR = { + def IRAggCount: IR = { val aggSig = AggSignature(Count(), FastSeq.empty, FastSeq.empty) ApplyAggOp(FastSeq.empty, FastSeq.empty, aggSig) } diff --git a/hail/src/test/scala/is/hail/expr/ir/TrapNodeSuite.scala b/hail/src/test/scala/is/hail/expr/ir/TrapNodeSuite.scala index 5909ca6748e..bab77019356 100644 --- a/hail/src/test/scala/is/hail/expr/ir/TrapNodeSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/TrapNodeSuite.scala @@ -1,16 +1,17 @@ package is.hail.expr.ir +import is.hail.{ExecStrategy, HailSuite} import is.hail.TestUtils._ import is.hail.types.virtual._ import is.hail.utils._ -import is.hail.{ExecStrategy, HailSuite} + import org.apache.spark.sql.Row import org.testng.annotations.Test class TrapNodeSuite extends HailSuite { implicit val execStrats = ExecStrategy.javaOnly - @Test def testTrapNode() { + @Test def testTrapNode(): Unit = { assertEvalsTo(Trap(ArrayRef(Literal(TArray(TInt32), FastSeq(0, 1, 2)), I32(1))), Row(null, 1)) val res = eval(Trap(ArrayRef(Literal(TArray(TInt32), FastSeq(0, 1, 2)), I32(-1)))) res match { @@ -22,13 +23,11 @@ class TrapNodeSuite extends HailSuite { assertEvalsTo(Trap(Die(Str("foo bar"), TInt32, 5)), Row(Row("foo bar", 5), null)) } - @Test def testTrapNodeInLargerContext() { - def resultByIdx(idx: Int): IR = bindIR(Trap(ArrayRef(Literal(TArray(TInt32), FastSeq(100, 200, 300)), I32(idx)))) { value => - If(IsNA(GetTupleElement(value, 0)), - GetTupleElement(value, 1), - I32(-1) - ) - } + @Test def testTrapNodeInLargerContext(): Unit = { + def resultByIdx(idx: Int): IR = + bindIR(Trap(ArrayRef(Literal(TArray(TInt32), FastSeq(100, 200, 300)), I32(idx)))) { value => + If(IsNA(GetTupleElement(value, 0)), GetTupleElement(value, 1), I32(-1)) + } assertEvalsTo(resultByIdx(-100), -1) assertEvalsTo(resultByIdx(2), 300) diff --git a/hail/src/test/scala/is/hail/expr/ir/UtilFunctionsSuite.scala b/hail/src/test/scala/is/hail/expr/ir/UtilFunctionsSuite.scala index 3805b191dc3..0e19d72a1db 100644 --- a/hail/src/test/scala/is/hail/expr/ir/UtilFunctionsSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/UtilFunctionsSuite.scala @@ -3,7 +3,7 @@ package is.hail.expr.ir import is.hail.{ExecStrategy, HailSuite} import is.hail.TestUtils._ import is.hail.types.virtual.{TBoolean, TInt32, TStream} -import org.scalatest.testng.TestNGSuite + import org.testng.annotations.Test class UtilFunctionsSuite extends HailSuite { @@ -11,19 +11,23 @@ class UtilFunctionsSuite extends HailSuite { val na = NA(TBoolean) val die = Die("it ded", TBoolean) + val folded = StreamFold( MakeStream(IndexedSeq(true), TStream(TBoolean)), - die, "a", "e", - Ref("a", TBoolean) || Ref("e", TBoolean)) + die, + "a", + "e", + Ref("a", TBoolean) || Ref("e", TBoolean), + ) - @Test def shortCircuitOr() { + @Test def shortCircuitOr(): Unit = { assertEvalsTo(True() || True(), true) assertEvalsTo(True() || False(), true) assertEvalsTo(False() || True(), true) assertEvalsTo(False() || False(), false) } - @Test def shortCircuitOrHandlesMissingness() { + @Test def 
shortCircuitOrHandlesMissingness(): Unit = { assertEvalsTo(na || na, null) assertEvalsTo(na || True(), true) assertEvalsTo(True() || na, true) @@ -32,8 +36,8 @@ class UtilFunctionsSuite extends HailSuite { } - @Test def shortCircuitOrHandlesErrors() { - //FIXME: interpreter evaluates args for ApplySpecial before invoking the function :-| + @Test def shortCircuitOrHandlesErrors(): Unit = { + // FIXME: interpreter evaluates args for ApplySpecial before invoking the function :-| assertCompiledFatal(na || die, "it ded") assertCompiledFatal(False() || die, "it ded") // FIXME: This needs to be fixed with an interpreter function registry @@ -49,14 +53,14 @@ class UtilFunctionsSuite extends HailSuite { assert(eval(True() || folded) == true) } - @Test def shortCircuitAnd() { + @Test def shortCircuitAnd(): Unit = { assertEvalsTo(True() && True(), true) assertEvalsTo(True() && False(), false) assertEvalsTo(False() && True(), false) assertEvalsTo(False() && False(), false) } - @Test def shortCircuitAndHandlesMissingness() { + @Test def shortCircuitAndHandlesMissingness(): Unit = { assertEvalsTo(na && na, null) assertEvalsTo(True() && na, null) assertEvalsTo(na && True(), null) @@ -64,7 +68,7 @@ class UtilFunctionsSuite extends HailSuite { assertEvalsTo(na && False(), false) } - @Test def shortCircuitAndHandlesErroes() { + @Test def shortCircuitAndHandlesErroes(): Unit = { // FIXME: interpreter evaluates args for ApplySpecial before invoking the function :-| assertCompiledFatal(na && die, "it ded") assertCompiledFatal(True() && die, "it ded") @@ -80,7 +84,7 @@ class UtilFunctionsSuite extends HailSuite { assert(eval(False() && folded) == false) } - @Test def testParseFunctionRequiredness() { + @Test def testParseFunctionRequiredness(): Unit = { assertEvalsTo(invoke("toInt32OrMissing", TInt32, Str("123")), 123) assertEvalsTo(invoke("toInt32OrMissing", TInt32, Str("foo")), null) } diff --git a/hail/src/test/scala/is/hail/expr/ir/agg/DownsampleSuite.scala b/hail/src/test/scala/is/hail/expr/ir/agg/DownsampleSuite.scala index af1e33ec5be..4b258e8ff09 100644 --- a/hail/src/test/scala/is/hail/expr/ir/agg/DownsampleSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/agg/DownsampleSuite.scala @@ -5,8 +5,9 @@ import is.hail.annotations.{Region, RegionPool} import is.hail.asm4s._ import is.hail.expr.ir.{EmitCode, EmitFunctionBuilder} import is.hail.types.VirtualTypeWithReq -import is.hail.types.physical.stypes.primitives.SFloat64Value import is.hail.types.physical.{PCanonicalArray, PCanonicalString} +import is.hail.types.physical.stypes.primitives.SFloat64Value + import org.testng.annotations.Test class DownsampleSuite extends HailSuite { @@ -34,16 +35,20 @@ class DownsampleSuite extends HailSuite { cb.assign(i, 0) ds1.init(cb, 100) ds2.init(cb, 100) - cb.while_(i < 10000000, { + cb.while_( + i < 10000000, { cb.assign(x, rng.invoke[Double, Double, Double]("runif", 0d, 1d)) cb.assign(y, rng.invoke[Double, Double, Double]("runif", 0d, 1d)) - ds1.insert(cb, + ds1.insert( + cb, EmitCode.present(cb.emb, new SFloat64Value(x)), EmitCode.present(cb.emb, new SFloat64Value(y)), - EmitCode.missing(cb.emb, PCanonicalArray(PCanonicalString()).sType)) + EmitCode.missing(cb.emb, PCanonicalArray(PCanonicalString()).sType), + ) cb.assign(i, i + const(1)) - }) + }, + ) ds1.merge(cb, ds2) ds3.init(cb, 100) ds1.merge(cb, ds3) diff --git a/hail/src/test/scala/is/hail/expr/ir/agg/StagedBlockLinkedListSuite.scala b/hail/src/test/scala/is/hail/expr/ir/agg/StagedBlockLinkedListSuite.scala index f812c464b25..6a94cdd282c 100644 --- 
a/hail/src/test/scala/is/hail/expr/ir/agg/StagedBlockLinkedListSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/agg/StagedBlockLinkedListSuite.scala @@ -6,11 +6,12 @@ import is.hail.asm4s.Code import is.hail.expr.ir.{EmitCode, EmitFunctionBuilder} import is.hail.types.physical._ import is.hail.utils._ -import org.testng.Assert._ -import org.testng.annotations.Test import scala.collection.generic.Growable +import org.testng.Assert._ +import org.testng.annotations.Test + class StagedBlockLinkedListSuite extends HailSuite { class BlockLinkedList[E](region: Region, val elemPType: PType, initImmediately: Boolean = true) @@ -43,18 +44,19 @@ class StagedBlockLinkedListSuite extends HailSuite { val ptr = fb.getCodeParam[Long](2) val eltOff = fb.getCodeParam[Long](3) fb.emitWithBuilder[Unit] { cb => - sbll.load(cb, ptr) - sbll.push(cb, r, EmitCode(Code._empty, - eltOff.get.ceq(0L), - elemPType.loadCheapSCode(cb, eltOff))) + sbll.push( + cb, + r, + EmitCode(Code._empty, eltOff.get.ceq(0L), elemPType.loadCheapSCode(cb, eltOff)), + ) sbll.store(cb, ptr) Code._empty } val f = fb.result()(theHailClassLoader) ({ (r, ptr, elt) => - f(r, ptr, if(elt == null) 0L else ScalaToRegionValue(ctx.stateManager, r, elemPType, elt)) + f(r, ptr, if (elt == null) 0L else ScalaToRegionValue(ctx.stateManager, r, elemPType, elt)) }) } @@ -121,13 +123,13 @@ class StagedBlockLinkedListSuite extends HailSuite { val f = fb.result()(theHailClassLoader) ({ (r, other) => f(r, other.ptr) }) - } + } private var ptr = 0L - def clear(): Unit = { ptr = initF(region) } - def +=(e: E): this.type = { pushF(region, ptr, e) ; this } - def ++=(other: BlockLinkedList[E]): this.type = { appendF(region, ptr, other) ; this } + def clear(): Unit = ptr = initF(region) + def +=(e: E): this.type = { pushF(region, ptr, e); this } + def ++=(other: BlockLinkedList[E]): this.type = { appendF(region, ptr, other); this } def toIndexedSeq: IndexedSeq[E] = materializeF(region, ptr) if (initImmediately) clear() @@ -139,20 +141,19 @@ class StagedBlockLinkedListSuite extends HailSuite { } } - @Test def testPushIntsRequired() { + @Test def testPushIntsRequired(): Unit = pool.scopedRegion { region => val b = new BlockLinkedList[Int](region, PInt32Required) for (i <- 1 to 100) b += i assertEquals(b.toIndexedSeq, IndexedSeq.tabulate(100)(_ + 1)) } - } - @Test def testPushStrsMissing() { + @Test def testPushStrsMissing(): Unit = { pool.scopedRegion { region => val a = new BoxedArrayBuilder[String]() val b = new BlockLinkedList[String](region, PCanonicalString()) for (i <- 1 to 100) { - val elt = if(i%3 == 0) null else i.toString() + val elt = if (i % 3 == 0) null else i.toString() a += elt b += elt } @@ -160,7 +161,7 @@ class StagedBlockLinkedListSuite extends HailSuite { } } - @Test def testAppendAnother() { + @Test def testAppendAnother(): Unit = { pool.scopedRegion { region => val b1 = new BlockLinkedList[String](region, PCanonicalString()) val b2 = new BlockLinkedList[String](region, PCanonicalString()) @@ -173,7 +174,7 @@ class StagedBlockLinkedListSuite extends HailSuite { } } - @Test def testDeepCopy() { + @Test def testDeepCopy(): Unit = { pool.scopedRegion { region => val b1 = new BlockLinkedList[Double](region, PFloat64()) b1 ++= Seq(1.0, 2.0, 3.0) diff --git a/hail/src/test/scala/is/hail/expr/ir/analyses/SemanticHashSuite.scala b/hail/src/test/scala/is/hail/expr/ir/analyses/SemanticHashSuite.scala index 31e05af6add..866cb000cfe 100644 --- a/hail/src/test/scala/is/hail/expr/ir/analyses/SemanticHashSuite.scala +++ 
b/hail/src/test/scala/is/hail/expr/ir/analyses/SemanticHashSuite.scala @@ -1,39 +1,41 @@ package is.hail.expr.ir.analyses +import is.hail.{HAIL_PRETTY_VERSION, HailSuite} import is.hail.backend.ExecuteContext import is.hail.expr.ir._ import is.hail.io.fs.{FS, FakeFS, FakeURL, FileListEntry} import is.hail.linalg.BlockMatrixMetadata import is.hail.rvd.AbstractRVDSpec -import is.hail.types.virtual._ import is.hail.types.{MatrixType, TableType} -import is.hail.utils.{FastSeq, using} -import is.hail.{HAIL_PRETTY_VERSION, HailSuite} -import org.json4s.JValue -import org.testng.annotations.{DataProvider, Test} +import is.hail.types.virtual._ +import is.hail.utils.{using, FastSeq} + +import scala.util.control.NonFatal import java.io.FileNotFoundException import java.lang -import scala.util.control.NonFatal + +import org.json4s.JValue +import org.testng.annotations.{DataProvider, Test} class SemanticHashSuite extends HailSuite { def isTriviallySemanticallyEquivalent: Array[Array[Any]] = Array( - Array( True(), True(), true, "Refl"), - Array( False(), False(), true, "Refl"), - Array( True(), False(), false, "Refl"), - Array( I32(0), I32(0), true, "Refl"), - Array( I32(0), I32(1), false, "Refl"), - Array( I64(0), I64(0), true, "Refl"), - Array( I64(0), I64(1), false, "Refl"), - Array( F32(0), F32(0), true, "Refl"), - Array( F32(0), F32(1), false, "Refl"), - Array( Void(), Void(), true, "Refl"), - Array( Str("a"), Str("a"), true, "Refl"), - Array( Str("a"), Str("b"), false, "Refl"), - Array(NA(TInt32), NA(TInt32), true, "Refl"), - Array(NA(TInt32), NA(TFloat64), false, "Refl") + Array(True(), True(), true, "Refl"), + Array(False(), False(), true, "Refl"), + Array(True(), False(), false, "Refl"), + Array(I32(0), I32(0), true, "Refl"), + Array(I32(0), I32(1), false, "Refl"), + Array(I64(0), I64(0), true, "Refl"), + Array(I64(0), I64(1), false, "Refl"), + Array(F32(0), F32(0), true, "Refl"), + Array(F32(0), F32(1), false, "Refl"), + Array(Void(), Void(), true, "Refl"), + Array(Str("a"), Str("a"), true, "Refl"), + Array(Str("a"), Str("b"), false, "Refl"), + Array(NA(TInt32), NA(TInt32), true, "Refl"), + Array(NA(TInt32), NA(TFloat64), false, "Refl"), ) def mkRelationalLet(bindings: IndexedSeq[(String, IR)], body: IR): IR = @@ -48,42 +50,40 @@ class SemanticHashSuite extends HailSuite { let(FastSeq("x" -> Void()), ref("x", TVoid)), let(FastSeq("y" -> Void()), ref("y", TVoid)), true, - "names used in let-bindings do not change semantics" + "names used in let-bindings do not change semantics", ), - Array( let(FastSeq("x" -> Void(), "y" -> Void()), ref("x", TVoid)), let(FastSeq("y" -> Void(), "x" -> Void()), ref("y", TVoid)), true, - "names of let-bindings do not change semantics" + "names of let-bindings do not change semantics", ), Array( let(FastSeq("a" -> I32(0)), ref("a", TInt32)), let(FastSeq("a" -> Void()), ref("a", TVoid)), false, - "different IRs" + "different IRs", ), Array( let(FastSeq("x" -> Void(), "y" -> Void()), ref("x", TVoid)), let(FastSeq("y" -> Void(), "x" -> Void()), ref("x", TVoid)), false, - "Different binding being referenced" + "Different binding being referenced", ), /* `SemanticHash` does not perform or recognise opportunities for simplification. - * The following examples demonstrate some of its limitations as a consequence. - */ + * The following examples demonstrate some of its limitations as a consequence. 
*/ Array( let(FastSeq("A" -> Void()), ref("A", TVoid)), let(FastSeq("A" -> let(FastSeq(genUID() -> I32(0)), Void())), ref("A", TVoid)), false, - "SemanticHash does not simplify" + "SemanticHash does not simplify", ), Array( let(FastSeq("A" -> Void()), ref("A", TVoid)), let(FastSeq("A" -> Void(), "B" -> I32(0)), ref("A", TVoid)), false, - "SemanticHash does not simplify" - ) + "SemanticHash does not simplify", + ), ) } @@ -94,58 +94,58 @@ class SemanticHashSuite extends HailSuite { MakeStruct(Array.empty[(String, IR)]), MakeStruct(Array.empty[(String, IR)]), true, - "empty structs" + "empty structs", ), Array( MakeStruct(Array(genUID() -> Void())), MakeStruct(Array(genUID() -> Void())), true, - "field names do not affect MakeStruct semantics" + "field names do not affect MakeStruct semantics", ), Array( MakeTuple(Array.empty[(Int, IR)]), MakeTuple(Array.empty[(Int, IR)]), true, - "empty tuples" + "empty tuples", ), Array( MakeTuple(Array(0 -> Void())), MakeTuple(Array(0 -> Void())), true, - "identical tuples" + "identical tuples", ), Array( MakeTuple(Array(0 -> Void())), MakeTuple(Array(1 -> Void())), false, - "tuple indices affect MakeTuple semantics" - ) + "tuple indices affect MakeTuple semantics", + ), ), { def f(mkType: Int => Type, get: (IR, Int) => IR, isSame: Boolean, reason: String) = - Array.tabulate(2) { idx => bindIR(NA(mkType(idx)))(get(_, idx)) } ++ Array(isSame, reason) + Array.tabulate(2)(idx => bindIR(NA(mkType(idx)))(get(_, idx))) ++ Array(isSame, reason) Array( f( mkType = i => TStruct(i.toString -> TVoid), get = (ir, i) => GetField(ir, i.toString), isSame = true, - "field names do not affect GetField semantics" + "field names do not affect GetField semantics", ), f( mkType = _ => TTuple(TVoid), get = (ir, _) => GetTupleElement(ir, 0), isSame = true, - "GetTupleElement of same index" + "GetTupleElement of same index", ), f( mkType = i => TTuple(Array(TupleField(i, TVoid))), get = (ir, i) => GetTupleElement(ir, i), isSame = false, - "GetTupleElement on different index" - ) + "GetTupleElement on different index", + ), ) - } + }, ) def isTreeStructureSemanticallyEquivalent: Array[Array[Any]] = @@ -153,7 +153,7 @@ class SemanticHashSuite extends HailSuite { Array( MakeArray( MakeArray(I32(0)), - MakeArray(I32(0)) + MakeArray(I32(0)), ), MakeArray( MakeArray( @@ -161,7 +161,7 @@ class SemanticHashSuite extends HailSuite { ) ), false, - "Tree structure contributes to semantics" + "Tree structure contributes to semantics", ) ) @@ -170,7 +170,7 @@ class SemanticHashSuite extends HailSuite { isTriviallySemanticallyEquivalent, isLetSemanticallyEquivalent, isBaseStructSemanticallyEquivalent, - isTreeStructureSemanticallyEquivalent + isTreeStructureSemanticallyEquivalent, ) def isTableIRSemanticallyEquivalent: Array[Array[Any]] = { @@ -183,7 +183,7 @@ class SemanticHashSuite extends HailSuite { def mkTableIR(ttype: TableType, path: String): TableIR = mkTableRead(new TableNativeReader( TableNativeReaderParameters(path, None), - mkFakeTableSpec(ttype) + mkFakeTableSpec(ttype), )) val tir = mkTableIR(ttype, "/fake/table") @@ -192,26 +192,57 @@ class SemanticHashSuite extends HailSuite { Array( Array(tir, tir, true, "TableRead same table"), Array(tir, mkTableIR(ttype, "/another/fake/table"), false, "TableRead different table"), - Array(TableKeyBy(tir, IndexedSeq("a")), TableKeyBy(tir, IndexedSeq("a")), true, "TableKeyBy same key"), - Array(TableKeyBy(tir, IndexedSeq("a")), TableKeyBy(tir, IndexedSeq("b")), false, "TableKeyBy different key") + Array( + TableKeyBy(tir, IndexedSeq("a")), + 
TableKeyBy(tir, IndexedSeq("a")), + true, + "TableKeyBy same key", + ), + Array( + TableKeyBy(tir, IndexedSeq("a")), + TableKeyBy(tir, IndexedSeq("b")), + false, + "TableKeyBy different key", + ), ), Array[String => TableReader]( - path => new StringTableReader(StringTableReaderParameters(Array(path), None, false, false, false), fakeFs.glob(path)), - path => TableNativeZippedReader(path + ".left", path + ".right", None, mkFakeTableSpec(ttype), mkFakeTableSpec(ttypeb)), - + path => + new StringTableReader( + StringTableReaderParameters(Array(path), None, false, false, false), + fakeFs.glob(path), + ), + path => + TableNativeZippedReader( + path + ".left", + path + ".right", + None, + mkFakeTableSpec(ttype), + mkFakeTableSpec(ttypeb), + ), ) .map(mkTableRead _ compose _) .flatMap { reader => Array( Array(reader("/fake/table"), reader("/fake/table"), true, "read same table"), - Array(reader("/fake/table"), reader("/another/fake/table"), false, "read different table") + Array( + reader("/fake/table"), + reader("/another/fake/table"), + false, + "read different table", + ), ) }, Array( TableGetGlobals, TableAggregate(_, Void()), TableAggregateByKey(_, MakeStruct(FastSeq())), - TableKeyByAndAggregate(_, MakeStruct(FastSeq()), MakeStruct(FastSeq("idx" -> I32(0))), None, 256), + TableKeyByAndAggregate( + _, + MakeStruct(FastSeq()), + MakeStruct(FastSeq("idx" -> I32(0))), + None, + 256, + ), TableCollect, TableCount, TableDistinct, @@ -219,34 +250,46 @@ class SemanticHashSuite extends HailSuite { TableMapGlobals(_, MakeStruct(IndexedSeq.empty)), TableMapRows(_, MakeStruct(FastSeq("a" -> I32(0)))), TableRename(_, Map.empty, Map.empty), - ).map { wrap => - Array(wrap(tir), wrap(tir), true, "") - } + ).map(wrap => Array(wrap(tir), wrap(tir), true, "")), ) } def isBlockMatrixIRSemanticallyEquivalent: Array[Array[Any]] = Array[String => BlockMatrixReader]( path => BlockMatrixBinaryReader(path, Array(1L, 1L), 1), - path => new BlockMatrixNativeReader(BlockMatrixNativeReaderParameters(path), BlockMatrixMetadata(1, 1, 1, None, IndexedSeq.empty)) + path => + new BlockMatrixNativeReader( + BlockMatrixNativeReaderParameters(path), + BlockMatrixMetadata(1, 1, 1, None, IndexedSeq.empty), + ), ) .map(BlockMatrixRead compose _) .flatMap { reader => Array( - Array(reader("/fake/block-matrix"), reader("/fake/block-matrix"), true, "Read same block matrix"), - Array(reader("/fake/block-matrix"), reader("/another/fake/block-matrix"), false, "Read different block matrix"), + Array( + reader("/fake/block-matrix"), + reader("/fake/block-matrix"), + true, + "Read same block matrix", + ), + Array( + reader("/fake/block-matrix"), + reader("/another/fake/block-matrix"), + false, + "Read different block matrix", + ), ) } @DataProvider(name = "isBaseIRSemanticallyEquivalent") def isBaseIRSemanticallyEquivalent: Array[Array[Any]] = - try { + try Array.concat( isValueIRSemanticallyEquivalent, isTableIRSemanticallyEquivalent, - isBlockMatrixIRSemanticallyEquivalent + isBlockMatrixIRSemanticallyEquivalent, ) - } catch { + catch { case NonFatal(t) => t.printStackTrace() throw t @@ -254,7 +297,10 @@ class SemanticHashSuite extends HailSuite { @Test(dataProvider = "isBaseIRSemanticallyEquivalent") def testSemanticEquivalence(a: BaseIR, b: BaseIR, isEqual: Boolean, comment: String): Unit = - assertResult(isEqual, s"expected semhash($a) ${if (isEqual) "==" else "!="} semhash($b), $comment")( + assertResult( + isEqual, + s"expected semhash($a) ${if (isEqual) "==" else "!="} semhash($b), $comment", + )( semhash(fakeFs)(a) == 
semhash(fakeFs)(b) ) @@ -263,7 +309,7 @@ class SemanticHashSuite extends HailSuite { val fs = new FakeFS { override def eTag(url: FakeURL): Option[String] = - throw new FileNotFoundException(url.getPath()) + throw new FileNotFoundException(url.getPath) } val ir = @@ -285,14 +331,12 @@ class SemanticHashSuite extends HailSuite { ctx.timer, ctx.tempFileManager, ctx.theHailClassLoader, - ctx.referenceGenomes, ctx.flags, ctx.backendContext, - ctx.irMetadata + ctx.irMetadata, ))(SemanticHash(_)(ir)) } - val fakeFs: FS = new FakeFS { override def eTag(url: FakeURL): Option[String] = @@ -301,6 +345,7 @@ class SemanticHashSuite extends HailSuite { override def glob(url: FakeURL): Array[FileListEntry] = Array(new FileListEntry { override def getPath: String = url.getPath + override def getActualUrl: String = url.getPath override def getModificationTime: lang.Long = ??? override def getLen: Long = ??? override def isDirectory: Boolean = ??? @@ -314,8 +359,10 @@ class SemanticHashSuite extends HailSuite { val ty = MatrixType( TStruct.empty, - FastSeq("col_idx"), TStruct("col_idx" -> TInt32), - FastSeq("row_idx"), TStruct("row_idx" -> TInt32), + FastSeq("col_idx"), + TStruct("col_idx" -> TInt32), + FastSeq("row_idx"), + TStruct("row_idx" -> TInt32), TStruct.empty, ) @@ -342,7 +389,7 @@ class SemanticHashSuite extends HailSuite { override def components: Map[String, ComponentSpec] = Map.empty override def toJValue: JValue = ??? - } + }, ) MatrixRead(ty, false, false, reader) diff --git a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala index 82ef4e1255f..b7fbb039e24 100644 --- a/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/lowering/LowerDistributedSortSuite.scala @@ -1,36 +1,64 @@ package is.hail.expr.ir.lowering +import is.hail.{ExecStrategy, HailSuite, TestUtils} +import is.hail.expr.ir.{ + mapIR, Apply, Ascending, Descending, ErrorIDs, GetField, I32, Literal, LoweringAnalyses, + MakeStruct, Ref, SelectFields, SortField, TableIR, TableMapRows, TableRange, ToArray, ToStream, +} import is.hail.expr.ir.lowering.LowerDistributedSort.samplePartition -import is.hail.expr.ir.{Apply, Ascending, Descending, ErrorIDs, GetField, I32, Literal, LoweringAnalyses, MakeStruct, Ref, SelectFields, SortField, TableIR, TableMapRows, TableRange, ToArray, ToStream, mapIR} import is.hail.types.RTable import is.hail.types.virtual.{TArray, TInt32, TStruct} import is.hail.utils.FastSeq -import is.hail.{ExecStrategy, HailSuite, TestUtils} + import org.apache.spark.sql.Row import org.testng.annotations.Test class LowerDistributedSortSuite extends HailSuite { implicit val execStrats = ExecStrategy.compileOnly - @Test def testSamplePartition() { - val dataKeys = IndexedSeq((0, 0), (0, -1), (1, 4), (2, 8), (3, 4), (4, 5), (5, 3), (6, 9), (7, 7), (8, -3), (9, 1)) + @Test def testSamplePartition(): Unit = { + val dataKeys = IndexedSeq( + (0, 0), + (0, -1), + (1, 4), + (2, 8), + (3, 4), + (4, 5), + (5, 3), + (6, 9), + (7, 7), + (8, -3), + (9, 1), + ) val elementType = TStruct(("key1", TInt32), ("key2", TInt32), ("value", TInt32)) - val data1 = ToStream(Literal(TArray(elementType), dataKeys.map{ case (k1, k2) => Row(k1, k2, k1 * k1)})) + val data1 = + ToStream(Literal(TArray(elementType), dataKeys.map { case (k1, k2) => Row(k1, k2, k1 * k1) })) val sampleSeq = ToStream(Literal(TArray(TInt32), IndexedSeq(0, 2, 3, 7))) - val sampled = 
samplePartition(mapIR(data1)(s => SelectFields(s, IndexedSeq("key1", "key2"))), sampleSeq, IndexedSeq(SortField("key1", Ascending), SortField("key2", Ascending))) + val sampled = samplePartition( + mapIR(data1)(s => SelectFields(s, IndexedSeq("key1", "key2"))), + sampleSeq, + IndexedSeq(SortField("key1", Ascending), SortField("key2", Ascending)), + ) - assertEvalsTo(sampled, Row(Row(0, -1), Row(9, 1), IndexedSeq( Row(0, 0), Row(1, 4), Row(2, 8), Row(6, 9)), false)) + assertEvalsTo( + sampled, + Row(Row(0, -1), Row(9, 1), IndexedSeq(Row(0, 0), Row(1, 4), Row(2, 8), Row(6, 9)), false), + ) val dataKeys2 = IndexedSeq((0, 0), (0, 1), (1, 0), (3, 3)) val elementType2 = TStruct(("key1", TInt32), ("key2", TInt32)) - val data2 = ToStream(Literal(TArray(elementType2), dataKeys2.map{ case (k1, k2) => Row(k1, k2)})) + val data2 = + ToStream(Literal(TArray(elementType2), dataKeys2.map { case (k1, k2) => Row(k1, k2) })) val sampleSeq2 = ToStream(Literal(TArray(TInt32), IndexedSeq(0))) - val sampled2 = samplePartition(mapIR(data2)(s => SelectFields(s, IndexedSeq("key2", "key1"))), sampleSeq2, IndexedSeq(SortField("key2", Ascending), SortField("key1", Ascending))) - assertEvalsTo(sampled2, Row(Row(0, 0), Row(3, 3), IndexedSeq( Row(0, 0)), false)) + val sampled2 = samplePartition( + mapIR(data2)(s => SelectFields(s, IndexedSeq("key2", "key1"))), + sampleSeq2, + IndexedSeq(SortField("key2", Ascending), SortField("key1", Ascending)), + ) + assertEvalsTo(sampled2, Row(Row(0, 0), Row(3, 3), IndexedSeq(Row(0, 0)), false)) } - // Only does ascending for now def testDistributedSortHelper(myTable: TableIR, sortFields: IndexedSeq[SortField]): Unit = { val originalShuffleCutoff = backend.getFlag("shuffle_cutoff_to_local_sort") @@ -42,13 +70,21 @@ class LowerDistributedSortSuite extends HailSuite { val sortedTs = LowerDistributedSort.distributedSort(ctx, stage, sortFields, rt) .lower(ctx, myTable.typ.copy(key = FastSeq())) - val res = TestUtils.eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[IndexedSeq[Row]]].flatten + val res = + TestUtils.eval(sortedTs.mapCollect("test")(x => ToArray(x))).asInstanceOf[IndexedSeq[ + IndexedSeq[Row] + ]].flatten val rowFunc = myTable.typ.rowType.select(sortFields.map(_.field))._2 val unsortedCollect = is.hail.expr.ir.TestUtils.collect(myTable) val unsortedAnalyses = LoweringAnalyses.apply(unsortedCollect, ctx) - val unsorted = TestUtils.eval(LowerTableIR.apply(unsortedCollect, DArrayLowering.All, ctx, unsortedAnalyses)).asInstanceOf[Row](0).asInstanceOf[IndexedSeq[Row]] - val scalaSorted = unsorted.sortWith{ case (l, r) => + val unsorted = TestUtils.eval(LowerTableIR.apply( + unsortedCollect, + DArrayLowering.All, + ctx, + unsortedAnalyses, + )).asInstanceOf[Row](0).asInstanceOf[IndexedSeq[Row]] + val scalaSorted = unsorted.sortWith { case (l, r) => val leftKey = rowFunc(l) val rightKey = rowFunc(r) var ans = false @@ -67,9 +103,8 @@ class LowerDistributedSortSuite extends HailSuite { ans } assert(res == scalaSorted) - } finally { + } finally backend.setFlag("shuffle_cutoff_to_local_sort", originalShuffleCutoff) - } } @Test def testDistributedSort(): Unit = { @@ -79,18 +114,30 @@ class LowerDistributedSortSuite extends HailSuite { tableRange, MakeStruct(IndexedSeq( "idx" -> GetField(rangeRow, "idx"), - "foo" -> Apply("mod", IndexedSeq(), IndexedSeq(GetField(rangeRow, "idx"), I32(2)), TInt32, ErrorIDs.NO_ERROR), + "foo" -> Apply( + "mod", + IndexedSeq(), + IndexedSeq(GetField(rangeRow, "idx"), I32(2)), + TInt32, + ErrorIDs.NO_ERROR, + ), "backwards" -> 
-GetField(rangeRow, "idx"), - "const" -> I32(4) - )) + "const" -> I32(4), + )), ) - testDistributedSortHelper(tableWithExtraField, IndexedSeq(SortField("foo", Ascending), SortField("idx", Ascending))) + testDistributedSortHelper( + tableWithExtraField, + IndexedSeq(SortField("foo", Ascending), SortField("idx", Ascending)), + ) testDistributedSortHelper(tableWithExtraField, IndexedSeq(SortField("idx", Ascending))) testDistributedSortHelper(tableWithExtraField, IndexedSeq(SortField("backwards", Ascending))) testDistributedSortHelper(tableWithExtraField, IndexedSeq(SortField("const", Ascending))) testDistributedSortHelper(tableWithExtraField, IndexedSeq(SortField("idx", Descending))) - testDistributedSortHelper(tableWithExtraField, IndexedSeq(SortField("foo", Descending), SortField("idx", Ascending))) + testDistributedSortHelper( + tableWithExtraField, + IndexedSeq(SortField("foo", Descending), SortField("idx", Ascending)), + ) } @Test def testDistributedSortEmpty(): Unit = { diff --git a/hail/src/test/scala/is/hail/expr/ir/table/TableGenSuite.scala b/hail/src/test/scala/is/hail/expr/ir/table/TableGenSuite.scala index 567494007c5..8cc3d8b6cdf 100644 --- a/hail/src/test/scala/is/hail/expr/ir/table/TableGenSuite.scala +++ b/hail/src/test/scala/is/hail/expr/ir/table/TableGenSuite.scala @@ -1,14 +1,15 @@ package is.hail.expr.ir.table +import is.hail.{ExecStrategy, HailSuite} import is.hail.TestUtils.loweredExecute import is.hail.backend.ExecuteContext -import is.hail.expr.ir.TestUtils.IRAggCollect import is.hail.expr.ir._ +import is.hail.expr.ir.TestUtils.IRAggCollect import is.hail.expr.ir.lowering.{DArrayLowering, LowerTableIR} import is.hail.rvd.RVDPartitioner import is.hail.types.virtual._ import is.hail.utils.{FastSeq, HailException, Interval} -import is.hail.{ExecStrategy, HailSuite} + import org.apache.spark.SparkException import org.apache.spark.sql.Row import org.scalatest.Matchers._ @@ -19,7 +20,7 @@ class TableGenSuite extends HailSuite { implicit val execStrategy = ExecStrategy.lowering @Test(groups = Array("construction", "typecheck")) - def testWithInvalidContextsType: Unit = { + def testWithInvalidContextsType(): Unit = { val ex = intercept[IllegalArgumentException] { mkTableGen(contexts = Some(Str("oh noes :'("))).typecheck() } @@ -30,9 +31,12 @@ class TableGenSuite extends HailSuite { } @Test(groups = Array("construction", "typecheck")) - def testWithInvalidGlobalsType: Unit = { + def testWithInvalidGlobalsType(): Unit = { val ex = intercept[IllegalArgumentException] { - mkTableGen(globals = Some(Str("oh noes :'(")), body = Some(MakeStream(IndexedSeq(), TStream(TStruct())))).typecheck() + mkTableGen( + globals = Some(Str("oh noes :'(")), + body = Some(MakeStream(IndexedSeq(), TStream(TStruct()))), + ).typecheck() } ex.getMessage should include("globals") ex.getMessage should include(s"Expected: ${classOf[TStruct].getName}") @@ -40,7 +44,7 @@ class TableGenSuite extends HailSuite { } @Test(groups = Array("construction", "typecheck")) - def testWithInvalidBodyType: Unit = { + def testWithInvalidBodyType(): Unit = { val ex = intercept[IllegalArgumentException] { mkTableGen(body = Some(Str("oh noes :'("))).typecheck() } @@ -50,9 +54,11 @@ class TableGenSuite extends HailSuite { } @Test(groups = Array("construction", "typecheck")) - def testWithInvalidBodyElementType: Unit = { + def testWithInvalidBodyElementType(): Unit = { val ex = intercept[IllegalArgumentException] { - mkTableGen(body = Some(MakeStream(IndexedSeq(Str("oh noes :'(")), TStream(TString)))).typecheck() + 
mkTableGen(body = + Some(MakeStream(IndexedSeq(Str("oh noes :'(")), TStream(TString))) + ).typecheck() } ex.getMessage should include("body.elementType") ex.getMessage should include(s"Expected: ${classOf[TStruct].getName}") @@ -60,23 +66,27 @@ class TableGenSuite extends HailSuite { } @Test(groups = Array("construction", "typecheck")) - def testWithInvalidPartitionerKeyType: Unit = { + def testWithInvalidPartitionerKeyType(): Unit = { val ex = intercept[IllegalArgumentException] { - mkTableGen(partitioner = Some(RVDPartitioner.empty(ctx.stateManager, TStruct("does-not-exist" -> TInt32)))).typecheck() + mkTableGen(partitioner = + Some(RVDPartitioner.empty(ctx.stateManager, TStruct("does-not-exist" -> TInt32))) + ).typecheck() } ex.getMessage should include("partitioner") } @Test(groups = Array("construction", "typecheck")) - def testWithTooLongPartitionerKeyType: Unit = { + def testWithTooLongPartitionerKeyType(): Unit = { val ex = intercept[IllegalArgumentException] { - mkTableGen(partitioner = Some(RVDPartitioner.empty(ctx.stateManager, TStruct("does-not-exist" -> TInt32)))).typecheck() + mkTableGen(partitioner = + Some(RVDPartitioner.empty(ctx.stateManager, TStruct("does-not-exist" -> TInt32))) + ).typecheck() } ex.getMessage should include("partitioner") } @Test(groups = Array("requiredness")) - def testRequiredness: Unit = { + def testRequiredness(): Unit = { val table = mkTableGen() val analysis = Requiredness(table, ctx) analysis.lookup(table).required shouldBe true @@ -84,43 +94,44 @@ class TableGenSuite extends HailSuite { } @Test(groups = Array("lowering")) - def testLowering: Unit = { + def testLowering(): Unit = { val table = TestUtils.collect(mkTableGen()) val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) assertEvalsTo(lowered, Row(FastSeq(0, 0).map(Row(_)), Row(0))) } @Test(groups = Array("lowering")) - def testNumberOfContextsMatchesPartitions: Unit = { + def testNumberOfContextsMatchesPartitions(): Unit = { val errorId = 42 val table = TestUtils.collect(mkTableGen( partitioner = Some(RVDPartitioner.unkeyed(ctx.stateManager, 0)), - errorId = Some(errorId) + errorId = Some(errorId), )) val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) val ex = intercept[HailException] { - ExecuteContext.scoped() { ctx => - loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) - } + ExecuteContext.scoped()(ctx => loweredExecute(ctx, lowered, Env.empty, FastSeq(), None)) } ex.errorId shouldBe errorId ex.getMessage should include("partitioner contains 0 partitions, got 2 contexts.") } @Test(groups = Array("lowering")) - def testRowsAreCorrectlyKeyed: Unit = { + def testRowsAreCorrectlyKeyed(): Unit = { val errorId = 56 val table = TestUtils.collect(mkTableGen( - partitioner = Some(new RVDPartitioner(ctx.stateManager, TStruct("a" -> TInt32), FastSeq( - Interval(Row(0), Row(0), true, false), Interval(Row(1), Row(1), true, false) - ))), - errorId = Some(errorId) + partitioner = Some(new RVDPartitioner( + ctx.stateManager, + TStruct("a" -> TInt32), + FastSeq( + Interval(Row(0), Row(0), true, false), + Interval(Row(1), Row(1), true, false), + ), + )), + errorId = Some(errorId), )) val lowered = LowerTableIR(table, DArrayLowering.All, ctx, LoweringAnalyses(table, ctx)) val ex = intercept[SparkException] { - ExecuteContext.scoped() { ctx => - loweredExecute(ctx, lowered, Env.empty, FastSeq(), None) - } + ExecuteContext.scoped()(ctx => loweredExecute(ctx, lowered, Env.empty, FastSeq(), None)) 
}.getCause.asInstanceOf[HailException] ex.errorId shouldBe errorId @@ -128,45 +139,48 @@ class TableGenSuite extends HailSuite { } @Test(groups = Array("optimization", "prune")) - def testPruneNoUnusedFields: Unit = { + def testPruneNoUnusedFields(): Unit = { val start = mkTableGen() val pruned = PruneDeadFields(ctx, start) pruned.typ shouldBe start.typ } @Test(groups = Array("optimization", "prune")) - def testPruneGlobals: Unit = { + def testPruneGlobals(): Unit = { val cname = "contexts" - val start = mkTableGen(cname = Some(cname), body = Some { - val elem = MakeStruct(IndexedSeq("a" -> Ref(cname, TInt32))) - MakeStream(IndexedSeq(elem), TStream(elem.typ)) - }) - - val TableAggregate(pruned, _) = PruneDeadFields(ctx, - TableAggregate(start, IRAggCollect(Ref("row", start.typ.rowType))) + val start = mkTableGen( + cname = Some(cname), + body = Some { + val elem = MakeStruct(IndexedSeq("a" -> Ref(cname, TInt32))) + MakeStream(IndexedSeq(elem), TStream(elem.typ)) + }, ) + val TableAggregate(pruned, _) = + PruneDeadFields(ctx, TableAggregate(start, IRAggCollect(Ref("row", start.typ.rowType)))) + pruned.typ should not be start.typ pruned.typ.globalType shouldBe TStruct() pruned.asInstanceOf[TableGen].globals shouldBe MakeStruct(IndexedSeq()) } @Test(groups = Array("optimization", "prune")) - def testPruneContexts: Unit = { + def testPruneContexts(): Unit = { val start = mkTableGen() val TableGetGlobals(pruned) = PruneDeadFields(ctx, TableGetGlobals(start)) pruned.typ should not be start.typ pruned.typ.rowType shouldBe TStruct() } - def mkTableGen(contexts: Option[IR] = None, - globals: Option[IR] = None, - cname: Option[String] = None, - gname: Option[String] = None, - body: Option[IR] = None, - partitioner: Option[RVDPartitioner] = None, - errorId: Option[Int] = None - ): TableGen = { + def mkTableGen( + contexts: Option[IR] = None, + globals: Option[IR] = None, + cname: Option[String] = None, + gname: Option[String] = None, + body: Option[IR] = None, + partitioner: Option[RVDPartitioner] = None, + errorId: Option[Int] = None, + ): TableGen = { val theGlobals = globals.getOrElse(MakeStruct(IndexedSeq("g" -> 0))) val contextName = cname.getOrElse(genUID()) val globalsName = gname.getOrElse(genUID()) @@ -183,7 +197,7 @@ class TableGenSuite extends HailSuite { MakeStream(IndexedSeq(elem), TStream(elem.typ)) }, partitioner.getOrElse(RVDPartitioner.unkeyed(ctx.stateManager, 2)), - errorId.getOrElse(ErrorIDs.NO_ERROR) + errorId.getOrElse(ErrorIDs.NO_ERROR), ) } } diff --git a/hail/src/test/scala/is/hail/io/AvroReaderSuite.scala b/hail/src/test/scala/is/hail/io/AvroReaderSuite.scala index 18aad7af9f9..02c54db69f3 100644 --- a/hail/src/test/scala/is/hail/io/AvroReaderSuite.scala +++ b/hail/src/test/scala/is/hail/io/AvroReaderSuite.scala @@ -1,10 +1,11 @@ package is.hail.io +import is.hail.{ExecStrategy, HailSuite} import is.hail.ExecStrategy.ExecStrategy import is.hail.expr.ir.{I64, MakeStruct, ReadPartition, Str, ToArray} import is.hail.io.avro.AvroPartitionReader -import is.hail.utils.{FastSeq, fatal, using} -import is.hail.{ExecStrategy, HailSuite} +import is.hail.utils.{fatal, using, FastSeq} + import org.apache.avro.SchemaBuilder import org.apache.avro.file.DataFileWriter import org.apache.avro.generic.{GenericDatumWriter, GenericRecord, GenericRecordBuilder} @@ -28,19 +29,19 @@ class AvroReaderSuite extends HailSuite { Row(1, 1L, 1.0f, 1.0d, ""), Row(-1, -1L, -1.0f, -1.0d, "minus one"), Row(Int.MaxValue, Long.MaxValue, Float.MaxValue, Double.MaxValue, null), - Row(Int.MinValue, null, 
Float.MinPositiveValue, Double.MinPositiveValue, "MINIMUM STRING") + Row(Int.MinValue, null, Float.MinPositiveValue, Double.MinPositiveValue, "MINIMUM STRING"), ) private val partitionReader = AvroPartitionReader(testSchema, "rowUID") def makeRecord(row: Row): GenericRecord = row match { case Row(int, long, float, double, string) => new GenericRecordBuilder(testSchema) - .set("an_int", int) - .set("an_optional_long", long) - .set("a_float", float) - .set("a_double", double) - .set("an_optional_string", string) - .build() + .set("an_int", int) + .set("an_optional_long", long) + .set("a_float", float) + .set("a_double", double) + .set("an_optional_string", string) + .build() case _ => fatal("invalid row") } @@ -48,10 +49,12 @@ class AvroReaderSuite extends HailSuite { val avroFile = ctx.createTmpPath("avro_test", "avro") using(fs.create(avroFile)) { os => - using(new DataFileWriter[GenericRecord](new GenericDatumWriter(testSchema)).create(testSchema, os)) { dw => - for (row <- testValue) { + using(new DataFileWriter[GenericRecord](new GenericDatumWriter(testSchema)).create( + testSchema, + os, + )) { dw => + for (row <- testValue) dw.append(makeRecord(row)) - } } } @@ -63,8 +66,9 @@ class AvroReaderSuite extends HailSuite { val ir = ToArray(ReadPartition( MakeStruct(Array("partitionPath" -> Str(avroFile), "partitionIndex" -> I64(0))), partitionReader.fullRowType, - partitionReader)) - val testValueWithUIDs = testValue.zipWithIndex.map { case(x, i) => + partitionReader, + )) + val testValueWithUIDs = testValue.zipWithIndex.map { case (x, i) => Row(x(0), x(1), x(2), x(3), x(4), Row(0L, i.toLong)) } assertEvalsTo(ir, testValueWithUIDs) @@ -75,7 +79,8 @@ class AvroReaderSuite extends HailSuite { val ir = ToArray(ReadPartition( MakeStruct(Array("partitionPath" -> Str(avroFile), "partitionIndex" -> I64(0))), partitionReader.fullRowType.typeAfterSelect(FastSeq(0, 2, 4)), - partitionReader)) + partitionReader, + )) val expected = testValue.map { case Row(int, _, float, _, string) => Row(int, float, string) } assertEvalsTo(ir, expected) } diff --git a/hail/src/test/scala/is/hail/io/ByteArrayReaderSuite.scala b/hail/src/test/scala/is/hail/io/ByteArrayReaderSuite.scala index b784358177a..11bc29c4ec5 100644 --- a/hail/src/test/scala/is/hail/io/ByteArrayReaderSuite.scala +++ b/hail/src/test/scala/is/hail/io/ByteArrayReaderSuite.scala @@ -5,21 +5,21 @@ import org.testng.annotations.Test class ByteArrayReaderSuite extends TestNGSuite { @Test - def readLongReadsALong() { + def readLongReadsALong(): Unit = { val a = Array(0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff) .map(_.toByte) assert(new ByteArrayReader(a).readLong() == -1L) } @Test - def readLongReadsALong2() { + def readLongReadsALong2(): Unit = { val a = Array(0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8) .map(_.toByte) assert(new ByteArrayReader(a).readLong() == 0xf8f7f6f5f4f3f2f1L) } @Test - def readLongReadsALong3() { + def readLongReadsALong3(): Unit = { val a = Array(0xf8, 0xf7, 0xf6, 0xf5, 0xf4, 0xf3, 0xf2, 0xf1) .map(_.toByte) assert(new ByteArrayReader(a).readLong() == 0xf1f2f3f4f5f6f7f8L) diff --git a/hail/src/test/scala/is/hail/io/IndexBTreeSuite.scala b/hail/src/test/scala/is/hail/io/IndexBTreeSuite.scala index 29ce784c35a..4682309ba0b 100644 --- a/hail/src/test/scala/is/hail/io/IndexBTreeSuite.scala +++ b/hail/src/test/scala/is/hail/io/IndexBTreeSuite.scala @@ -4,9 +4,8 @@ import is.hail.HailSuite import is.hail.check.Gen._ import is.hail.check.Prop._ import is.hail.check.Properties -import org.testng.annotations.Test -import 
scala.language.implicitConversions +import org.testng.annotations.Test class IndexBTreeSuite extends HailSuite { @@ -16,7 +15,8 @@ class IndexBTreeSuite extends HailSuite { depth <- frequency((4, const(1)), (5, const(2)), (1, const(3))) arraySize <- choose( math.max(1, math.pow(10, (depth - 1) * math.log10(1024)).toInt), - math.min(1100000, math.pow(10, depth * math.log10(1024)).toInt)) + math.min(1100000, math.pow(10, depth * math.log10(1024)).toInt), + ) } yield (depth, arraySize) def fillRandomArray(arraySize: Int): Array[Long] = { @@ -34,7 +34,6 @@ class IndexBTreeSuite extends HailSuite { property("query gives same answer as array") = forAll(arraySizeGenerator) { case (depth: Int, arraySize: Int) => val arrayRandomStarts = fillRandomArray(arraySize) - val maxLong = arrayRandomStarts.takeRight(1)(0) val index = ctx.createTmpPath("testBtree", "idx") fs.delete(index, true) @@ -59,7 +58,9 @@ class IndexBTreeSuite extends HailSuite { arrayRandomStarts.forall { case (l) => btree.queryIndex(l - 1).contains(l) } else { val randomIndices = Array(0) ++ Array.fill(100)(choose(0, arraySize - 1).sample()) - randomIndices.map(arrayRandomStarts).forall { case (l) => btree.queryIndex(l - 1).contains(l) } + randomIndices.map(arrayRandomStarts).forall { case (l) => + btree.queryIndex(l - 1).contains(l) + } } if (!depthCorrect || !indexCorrectSize || !queryCorrect) @@ -70,20 +71,18 @@ class IndexBTreeSuite extends HailSuite { } } - @Test def test() { + @Test def test(): Unit = Spec.check() - } - @Test def oneVariant() { + @Test def oneVariant(): Unit = { val index = Array(24.toLong) - val fileSize = 30 //made-up value greater than index + val fileSize = 30 // made-up value greater than index val idxFile = ctx.createTmpPath("testBtree_1variant", "idx") fs.delete(idxFile, recursive = true) IndexBTree.write(index, idxFile, fs) val btree = new IndexBTree(idxFile, fs) - intercept[IllegalArgumentException] { btree.queryIndex(-5) } @@ -96,7 +95,7 @@ class IndexBTreeSuite extends HailSuite { assert(btree.queryIndex(fileSize - 1).isEmpty) } - @Test def zeroVariants() { + @Test def zeroVariants(): Unit = { intercept[IllegalArgumentException] { val index = Array[Long]() val idxFile = ctx.createTmpPath("testBtree_0variant", "idx") @@ -105,7 +104,7 @@ class IndexBTreeSuite extends HailSuite { } } - @Test def testMultipleOfBranchingFactorDoesNotAddUnnecessaryElements() { + @Test def testMultipleOfBranchingFactorDoesNotAddUnnecessaryElements(): Unit = { val in = Array[Long](10, 9, 8, 7, 6, 5, 4, 3) val bigEndianBytes = Array[Byte]( 0, 0, 0, 0, 0, 0, 0, 10, @@ -120,17 +119,18 @@ class IndexBTreeSuite extends HailSuite { sameElements bigEndianBytes) } - @Test def writeReadMultipleOfBranchingFactorDoesNotError() { + @Test def writeReadMultipleOfBranchingFactorDoesNotError(): Unit = { val idxFile = ctx.createTmpPath("btree") IndexBTree.write( Array.tabulate(1024)(i => i), idxFile, - fs) + fs, + ) val index = new IndexBTree(idxFile, fs) assert(index.queryIndex(33).contains(33L)) } - @Test def queryArrayPositionAndFileOffsetIsCorrectSmallArray() { + @Test def queryArrayPositionAndFileOffsetIsCorrectSmallArray(): Unit = { val f = ctx.createTmpPath("btree") val v = Array[Long](1, 2, 3, 40, 50, 60, 70) val branchingFactor = 1024 @@ -148,7 +148,7 @@ class IndexBTreeSuite extends HailSuite { assert(bt.queryArrayPositionAndFileOffset(71).isEmpty) } - @Test def queryArrayPositionAndFileOffsetIsCorrectTwoLevelsArray() { + @Test def queryArrayPositionAndFileOffsetIsCorrectTwoLevelsArray(): Unit = { def sqr(x: Long) = x * x val f = 
ctx.createTmpPath("btree") val v = Array.tabulate(1025)(x => sqr(x)) @@ -175,7 +175,7 @@ class IndexBTreeSuite extends HailSuite { assert(bt.queryArrayPositionAndFileOffset(5).contains((3, sqr(3)))) } - @Test def queryArrayPositionAndFileOffsetIsCorrectThreeLevelsArray() { + @Test def queryArrayPositionAndFileOffsetIsCorrectThreeLevelsArray(): Unit = { def sqr(x: Long) = x * x val f = ctx.createTmpPath("btree") val v = Array.tabulate(1024 * 1024 + 1)(x => sqr(x)) @@ -201,16 +201,28 @@ class IndexBTreeSuite extends HailSuite { assert(bt.queryArrayPositionAndFileOffset(4).contains((2, sqr(2)))) assert(bt.queryArrayPositionAndFileOffset(5).contains((3, sqr(3)))) - assert(bt.queryArrayPositionAndFileOffset(sqr(1024 * 1024 - 1)).contains((1024 * 1024 - 1, sqr(1024 * 1024 - 1)))) - assert(bt.queryArrayPositionAndFileOffset(sqr(1024 * 1024 - 1) + 1).contains((1024 * 1024, sqr(1024 * 1024)))) - - assert(bt.queryArrayPositionAndFileOffset(sqr(1024 * 1024)).contains((1024 * 1024, sqr(1024 * 1024)))) - assert(bt.queryArrayPositionAndFileOffset(sqr(1024 * 1024) - 1).contains((1024 * 1024, sqr(1024 * 1024)))) + assert(bt.queryArrayPositionAndFileOffset(sqr(1024 * 1024 - 1)).contains(( + 1024 * 1024 - 1, + sqr(1024 * 1024 - 1), + ))) + assert(bt.queryArrayPositionAndFileOffset(sqr(1024 * 1024 - 1) + 1).contains(( + 1024 * 1024, + sqr(1024 * 1024), + ))) + + assert(bt.queryArrayPositionAndFileOffset(sqr(1024 * 1024)).contains(( + 1024 * 1024, + sqr(1024 * 1024), + ))) + assert(bt.queryArrayPositionAndFileOffset(sqr(1024 * 1024) - 1).contains(( + 1024 * 1024, + sqr(1024 * 1024), + ))) assert(bt.queryArrayPositionAndFileOffset(sqr(1024 * 1024) + 1).isEmpty) } - @Test def onDiskBTreeIndexToValueSmallCorrect() { + @Test def onDiskBTreeIndexToValueSmallCorrect(): Unit = { val f = ctx.createTmpPath("btree") val v = Array[Long](1, 2, 3, 4, 5, 6, 7) val branchingFactor = 3 @@ -223,17 +235,20 @@ class IndexBTreeSuite extends HailSuite { val indices = Seq(0, 5, 1, 6) val actual = bt.positionOfVariants(indices.toArray) val expected = indices.sorted.map(v) - assert(actual sameElements expected, - s"${ actual.toSeq } not same elements as expected ${ expected.toSeq }") + assert( + actual sameElements expected, + s"${actual.toSeq} not same elements as expected ${expected.toSeq}", + ) } catch { case t: Throwable => throw new RuntimeException( "exception while checking BTree: " + IndexBTree.toString(v, branchingFactor), - t) + t, + ) } } - @Test def onDiskBTreeIndexToValueRandomized() { + @Test def onDiskBTreeIndexToValueRandomized(): Unit = { val g = for { longs <- buildableOf[Array](choose(0L, Long.MaxValue)) indices <- buildableOf[Array](choose(0, longs.length - 1)) @@ -246,19 +261,22 @@ class IndexBTreeSuite extends HailSuite { val bt = new OnDiskBTreeIndexToValue(f, fs, branchingFactor) val actual = bt.positionOfVariants(indices.toArray) val expected = indices.sorted.map(longs) - assert(actual sameElements expected, - s"${ actual.toSeq } not same elements as expected ${ expected.toSeq }") + assert( + actual sameElements expected, + s"${actual.toSeq} not same elements as expected ${expected.toSeq}", + ) } catch { case t: Throwable => throw new RuntimeException( "exception while checking BTree: " + IndexBTree.toString(longs, branchingFactor), - t) + t, + ) } true }.check() } - @Test def onDiskBTreeIndexToValueFourLayers() { + @Test def onDiskBTreeIndexToValueFourLayers(): Unit = { val longs = Array.tabulate(3 * 3 * 3 * 3)(x => x.toLong) val indices = Array(0, 3, 10, 20, 26, 27, 34, 55, 79, 80) val f = 
ctx.createTmpPath("btree") @@ -268,17 +286,20 @@ class IndexBTreeSuite extends HailSuite { val bt = new OnDiskBTreeIndexToValue(f, fs, branchingFactor) val expected = indices.sorted.map(longs) val actual = bt.positionOfVariants(indices.toArray) - assert(actual sameElements expected, - s"${ actual.toSeq } not same elements as expected ${ expected.toSeq }") + assert( + actual sameElements expected, + s"${actual.toSeq} not same elements as expected ${expected.toSeq}", + ) } catch { case t: Throwable => throw new RuntimeException( "exception while checking BTree: " + IndexBTree.toString(longs, branchingFactor), - t) + t, + ) } } - @Test def calcDepthIsCorrect() { + @Test def calcDepthIsCorrect(): Unit = { def sqr(x: Long) = x * x def cube(x: Long) = x * x * x diff --git a/hail/src/test/scala/is/hail/io/IndexSuite.scala b/hail/src/test/scala/is/hail/io/IndexSuite.scala index 9222f5e1ce6..18e4bbee6be 100644 --- a/hail/src/test/scala/is/hail/io/IndexSuite.scala +++ b/hail/src/test/scala/is/hail/io/IndexSuite.scala @@ -6,6 +6,7 @@ import is.hail.io.index._ import is.hail.types.physical.{PCanonicalString, PCanonicalStruct, PInt32, PType} import is.hail.types.virtual._ import is.hail.utils._ + import org.apache.spark.sql.Row import org.testng.annotations.{DataProvider, Test} @@ -27,29 +28,33 @@ class IndexSuite extends HailSuite { val leafsWithDups = stringsWithDups.zipWithIndex.map { case (s, i) => LeafChild(s, i, Row()) } @DataProvider(name = "elements") - def data(): Array[Array[Array[String]]] = { + def data(): Array[Array[Array[String]]] = (1 to strings.length).map(i => Array(strings.take(i))).toArray - } - def writeIndex(file: String, + def writeIndex( + file: String, data: Array[Any], annotations: Array[Annotation], keyType: PType, annotationType: PType, branchingFactor: Int, - attributes: Map[String, Any]) { - val bufferSpec = BufferSpec.default - - val iw = IndexWriter.builder(ctx, keyType, annotationType, branchingFactor, attributes)(file, theHailClassLoader, ctx.taskContext, pool) + attributes: Map[String, Any], + ): Unit = { + val iw = IndexWriter.builder(ctx, keyType, annotationType, branchingFactor, attributes)( + file, + theHailClassLoader, + ctx.taskContext, + pool, + ) data.zip(annotations).zipWithIndex.foreach { case ((s, a), offset) => iw.appendRow(s, offset, a) } iw.close() } - def indexReader(file: String, annotationType: Type, keyPType: PType = PCanonicalString()): IndexReader = { + def indexReader(file: String, annotationType: Type, keyPType: PType = PCanonicalString()) + : IndexReader = indexReader(file, PType.canonical(annotationType), keyPType) - } def indexReader(file: String, annotationPType: PType, keyPType: PType): IndexReader = { val leafPType = LeafNodeBuilder.typ(keyPType, annotationPType) @@ -61,34 +66,53 @@ class IndexSuite extends HailSuite { assert(lrt == leafPType) val (irt, intDec) = intSpec.buildDecoder(ctx, intPType.virtualType) assert(irt == intPType) - IndexReaderBuilder.withDecoders(ctx, leafDec, intDec, - keyPType.virtualType, annotationPType.virtualType, - leafPType, intPType).apply(theHailClassLoader, fs, file, 8, pool) + IndexReaderBuilder.withDecoders( + ctx, + leafDec, + intDec, + keyPType.virtualType, + annotationPType.virtualType, + leafPType, + intPType, + ).apply(theHailClassLoader, fs, file, 8, pool) } - def writeIndex(file: String, + def writeIndex( + file: String, data: Array[String], annotations: Array[Annotation], annotationType: Type, branchingFactor: Int = 2, - attributes: Map[String, Any] = Map.empty[String, Any]): Unit = - 
writeIndex(file, data.map(_.asInstanceOf[Any]), annotations, PCanonicalString(), PType.canonical(annotationType), branchingFactor, attributes) + attributes: Map[String, Any] = Map.empty[String, Any], + ): Unit = + writeIndex( + file, + data.map(_.asInstanceOf[Any]), + annotations, + PCanonicalString(), + PType.canonical(annotationType), + branchingFactor, + attributes, + ) @Test(dataProvider = "elements") - def writeReadGivesSameAsInput(data: Array[String]) { + def writeReadGivesSameAsInput(data: Array[String]): Unit = { val file = ctx.createTmpPath("test", "idx") - val attributes = Map("foo" -> true, "bar" -> 5) + val attributes: Map[String, Any] = Map("foo" -> true, "bar" -> 5) val a: (Int) => Annotation = (i: Int) => Row(i % 2 == 0) for (branchingFactor <- 2 to 5) { - writeIndex(file, + writeIndex( + file, data, data.indices.map(i => a(i)).toArray, TStruct("a" -> TBoolean), branchingFactor, - attributes) - assert(fs.getFileSize(file) != 0) + attributes, + ) + assert(fs.getFileSize(file + "/index") != 0) + assert(fs.getFileSize(file + "/metadata.json.gz") != 0) val index = indexReader(file, TStruct("a" -> TBoolean)) @@ -105,10 +129,11 @@ class IndexSuite extends HailSuite { } } - @Test def testEmptyKeys() { + @Test def testEmptyKeys(): Unit = { val file = ctx.createTmpPath("empty", "idx") writeIndex(file, Array.empty[String], Array.empty[Annotation], TStruct("a" -> TBoolean), 2) - assert(fs.getFileSize(file) != 0) + assert(fs.getFileSize(file + "/index") != 0) + assert(fs.getFileSize(file + "/metadata.json.gz") != 0) val index = indexReader(file, TStruct("a" -> TBoolean)) intercept[IllegalArgumentException](index.queryByIndex(0L)) assert(index.queryByKey("moo").isEmpty) @@ -116,20 +141,31 @@ class IndexSuite extends HailSuite { index.close() } - @Test def testLowerBound() { + @Test def testLowerBound(): Unit = { for (branchingFactor <- 2 to 5) { val file = ctx.createTmpPath("lowerBound", "idx") - writeIndex(file, stringsWithDups, stringsWithDups.indices.map(i => Row()).toArray, TStruct.empty, branchingFactor) + writeIndex( + file, + stringsWithDups, + stringsWithDups.indices.map(i => Row()).toArray, + TStruct.empty, + branchingFactor, + ) val index = indexReader(file, TStruct.empty) - val n = stringsWithDups.length - val f = { i: Int => stringsWithDups(i) } - val expectedResult = Array( - "aardvark" -> 0, "bear" -> 0, "cat" -> 2, - "dog" -> 9, "elk" -> 10, "mouse" -> 10, - "opossum" -> 12, "skunk" -> 12, "snail" -> 15, - "whale" -> 15, "zebra" -> 17, "zoo" -> stringsWithDups.length + "aardvark" -> 0, + "bear" -> 0, + "cat" -> 2, + "dog" -> 9, + "elk" -> 10, + "mouse" -> 10, + "opossum" -> 12, + "skunk" -> 12, + "snail" -> 15, + "whale" -> 15, + "zebra" -> 17, + "zoo" -> stringsWithDups.length, ) expectedResult.foreach { case (s, expectedIdx) => @@ -138,21 +174,31 @@ class IndexSuite extends HailSuite { } } - @Test def testUpperBound() { + @Test def testUpperBound(): Unit = { for (branchingFactor <- 2 to 5) { val file = ctx.createTmpPath("upperBound", "idx") - writeIndex(file, stringsWithDups, stringsWithDups.indices.map(i => Row()).toArray, TStruct.empty, branchingFactor = 2) + writeIndex( + file, + stringsWithDups, + stringsWithDups.indices.map(i => Row()).toArray, + TStruct.empty, + branchingFactor = 2, + ) val index = indexReader(file, TStruct.empty) - val n = stringsWithDups.length - val f = { i: Int => stringsWithDups(i) } - val expectedResult = Array( - "aardvark" -> 0, "bear" -> 2, "cat" -> 9, - "dog" -> 10, "elk" -> 10, "mouse" -> 12, - "opossum" -> 12, "skunk" -> 15, "snail" -> 
15, - "whale" -> 17, "zebra" -> stringsWithDups.length, - "zoo" -> stringsWithDups.length + "aardvark" -> 0, + "bear" -> 2, + "cat" -> 9, + "dog" -> 10, + "elk" -> 10, + "mouse" -> 12, + "opossum" -> 12, + "skunk" -> 15, + "snail" -> 15, + "whale" -> 17, + "zebra" -> stringsWithDups.length, + "zoo" -> stringsWithDups.length, ) expectedResult.foreach { case (s, expectedIdx) => @@ -161,60 +207,150 @@ class IndexSuite extends HailSuite { } } - @Test def testRangeIterator() { + @Test def testRangeIterator(): Unit = { for (branchingFactor <- 2 to 5) { val file = ctx.createTmpPath("range", "idx") val a = { (i: Int) => Row() } - writeIndex(file, stringsWithDups, stringsWithDups.indices.map(a).toArray, TStruct.empty, branchingFactor) + writeIndex( + file, + stringsWithDups, + stringsWithDups.indices.map(a).toArray, + TStruct.empty, + branchingFactor, + ) val index = indexReader(file, TStruct.empty) val bounds = stringsWithDups.indices.toArray.combinations(2).toArray - bounds.foreach(b => index.iterator(b(0), b(1)).toArray sameElements leafsWithDups.slice(b(0), b(1))) + bounds.foreach(b => + index.iterator(b(0), b(1)).toArray sameElements leafsWithDups.slice(b(0), b(1)) + ) - assert(index.iterator.toArray sameElements stringsWithDups.zipWithIndex.map { case (s, i) => LeafChild(s, i, a(i)) }) + assert(index.iterator.toArray sameElements stringsWithDups.zipWithIndex.map { case (s, i) => + LeafChild(s, i, a(i)) + }) } } - @Test def testQueryByKey() { + @Test def testQueryByKey(): Unit = { for (branchingFactor <- 2 to 5) { val file = ctx.createTmpPath("key", "idx") - writeIndex(file, stringsWithDups, stringsWithDups.indices.map(i => Row()).toArray, TStruct.empty, branchingFactor) + writeIndex( + file, + stringsWithDups, + stringsWithDups.indices.map(i => Row()).toArray, + TStruct.empty, + branchingFactor, + ) val index = indexReader(file, TStruct.empty) val stringsNotInList = Array("aardvark", "crow", "elk", "otter", "zoo") assert(stringsNotInList.forall(s => index.queryByKey(s).isEmpty)) val stringsInList = stringsWithDups.distinct - assert(stringsInList.forall(s => index.queryByKey(s) sameElements leafsWithDups.filter(_.key == s))) + assert(stringsInList.forall(s => + index.queryByKey(s) sameElements leafsWithDups.filter(_.key == s) + )) } } - @Test def testIntervalIterator() { + @Test def testIntervalIterator(): Unit = { for (branchingFactor <- 2 to 5) { val file = ctx.createTmpPath("interval", "idx") - writeIndex(file, stringsWithDups, stringsWithDups.indices.map(i => Row()).toArray, TStruct.empty, branchingFactor) + writeIndex( + file, + stringsWithDups, + stringsWithDups.indices.map(i => Row()).toArray, + TStruct.empty, + branchingFactor, + ) val index = indexReader(file, TStruct.empty) // intervals with endpoint in list - assert(index.queryByInterval("bear", "bear", includesStart = true, includesEnd = true).toFastSeq == index.iterator(0, 2).toFastSeq) - - assert(index.queryByInterval("bear", "cat", includesStart = true, includesEnd = false).toFastSeq == index.iterator(0, 2).toFastSeq) - assert(index.queryByInterval("bear", "cat", includesStart = true, includesEnd = true).toFastSeq == index.iterator(0, 9).toFastSeq) - assert(index.queryByInterval("bear", "cat", includesStart = false, includesEnd = true).toFastSeq == index.iterator(2, 9).toFastSeq) - assert(index.queryByInterval("bear", "cat", includesStart = false, includesEnd = false).toFastSeq == index.iterator(2, 2).toFastSeq) + assert(index.queryByInterval( + "bear", + "bear", + includesStart = true, + includesEnd = true, + ).toFastSeq == 
index.iterator(0, 2).toFastSeq) + + assert(index.queryByInterval( + "bear", + "cat", + includesStart = true, + includesEnd = false, + ).toFastSeq == index.iterator(0, 2).toFastSeq) + assert(index.queryByInterval( + "bear", + "cat", + includesStart = true, + includesEnd = true, + ).toFastSeq == index.iterator(0, 9).toFastSeq) + assert(index.queryByInterval( + "bear", + "cat", + includesStart = false, + includesEnd = true, + ).toFastSeq == index.iterator(2, 9).toFastSeq) + assert(index.queryByInterval( + "bear", + "cat", + includesStart = false, + includesEnd = false, + ).toFastSeq == index.iterator(2, 2).toFastSeq) // intervals with endpoint(s) not in list - assert(index.queryByInterval("cat", "snail", includesStart = true, includesEnd = false).toFastSeq == index.iterator(2, 15).toFastSeq) - assert(index.queryByInterval("cat", "snail", includesStart = true, includesEnd = true).toFastSeq == index.iterator(2, 15).toFastSeq) - assert(index.queryByInterval("aardvark", "cat", includesStart = true, includesEnd = true).toFastSeq == index.iterator(0, 9).toFastSeq) - assert(index.queryByInterval("aardvark", "cat", includesStart = false, includesEnd = true).toFastSeq == index.iterator(0, 9).toFastSeq) + assert(index.queryByInterval( + "cat", + "snail", + includesStart = true, + includesEnd = false, + ).toFastSeq == index.iterator(2, 15).toFastSeq) + assert(index.queryByInterval( + "cat", + "snail", + includesStart = true, + includesEnd = true, + ).toFastSeq == index.iterator(2, 15).toFastSeq) + assert(index.queryByInterval( + "aardvark", + "cat", + includesStart = true, + includesEnd = true, + ).toFastSeq == index.iterator(0, 9).toFastSeq) + assert(index.queryByInterval( + "aardvark", + "cat", + includesStart = false, + includesEnd = true, + ).toFastSeq == index.iterator(0, 9).toFastSeq) // illegal interval queries - intercept[IllegalArgumentException](index.queryByInterval("bear", "bear", includesStart = false, includesEnd = false).toFastSeq) - intercept[IllegalArgumentException](index.queryByInterval("bear", "bear", includesStart = false, includesEnd = true).toFastSeq) - intercept[IllegalArgumentException](index.queryByInterval("bear", "bear", includesStart = true, includesEnd = false).toFastSeq) - intercept[IllegalArgumentException](index.queryByInterval("cat", "bear", includesStart = true, includesEnd = true).toFastSeq) - - val endPoints = (stringsWithDups.distinct ++ Array("aardvark", "boar", "elk", "oppossum", "snail", "zoo")).combinations(2) + intercept[IllegalArgumentException](index.queryByInterval( + "bear", + "bear", + includesStart = false, + includesEnd = false, + ).toFastSeq) + intercept[IllegalArgumentException](index.queryByInterval( + "bear", + "bear", + includesStart = false, + includesEnd = true, + ).toFastSeq) + intercept[IllegalArgumentException](index.queryByInterval( + "bear", + "bear", + includesStart = true, + includesEnd = false, + ).toFastSeq) + intercept[IllegalArgumentException](index.queryByInterval( + "cat", + "bear", + includesStart = true, + includesEnd = true, + ).toFastSeq) + + val endPoints = (stringsWithDups.distinct ++ Array("aardvark", "boar", "elk", "oppossum", + "snail", "zoo")).combinations(2) val ordering = TString.ordering(ctx.stateManager) for (bounds <- endPoints) { @@ -227,8 +363,7 @@ class IndexSuite extends HailSuite { case -1 => stringsWithDups.length case x: Int => x } - } - else + } else stringsWithDups.lastIndexWhere(bounds(0) >= _) + 1 // want last index at transition point where key is gteq and then want to exclude that point (add 1) val 
upperBoundIdx = @@ -237,7 +372,12 @@ class IndexSuite extends HailSuite { else stringsWithDups.lastIndexWhere(bounds(1) > _) + 1 // want last index before transition point and then want to include that value so add 1 - assert(index.queryByInterval(bounds(0), bounds(1), includesStart, includesEnd).toFastSeq == + assert(index.queryByInterval( + bounds(0), + bounds(1), + includesStart, + includesEnd, + ).toFastSeq == index.iterator(lowerBoundIdx, upperBoundIdx).toFastSeq) if (includesStart) @@ -247,53 +387,101 @@ class IndexSuite extends HailSuite { assert(index.iterateUntil(bounds(1)).toFastSeq == leafsWithDups.slice(0, upperBoundIdx).toFastSeq) } else - intercept[IllegalArgumentException](index.queryByInterval(bounds(0), bounds(1), includesStart, includesEnd)) + intercept[IllegalArgumentException](index.queryByInterval( + bounds(0), + bounds(1), + includesStart, + includesEnd, + )) } } } } } - @Test def testIntervalIteratorWorksWithGeneralEndpoints() { + @Test def testIntervalIteratorWorksWithGeneralEndpoints(): Unit = { for (branchingFactor <- 2 to 5) { val keyType = PCanonicalStruct("a" -> PCanonicalString(), "b" -> PInt32()) val file = ctx.createTmpPath("from", "idx") - writeIndex(file, + writeIndex( + file, stringsWithDups.zipWithIndex.map { case (s, i) => Row(s, i) }, stringsWithDups.indices.map(i => Row()).toArray, keyType, +PCanonicalStruct(), branchingFactor, - Map.empty) + Map.empty, + ) - val leafChildren = stringsWithDups.zipWithIndex.map { case (s, i) => LeafChild(Row(s, i), i, Row()) }.toFastSeq + val leafChildren = stringsWithDups.zipWithIndex.map { case (s, i) => + LeafChild(Row(s, i), i, Row()) + }.toFastSeq - val index = indexReader(file, +PCanonicalStruct(), keyPType = PCanonicalStruct("a" -> PCanonicalString(), "b" -> PInt32())) - assert(index.queryByInterval(Row("cat", 3), Row("cat", 5), includesStart = true, includesEnd = false).toFastSeq == + val index = indexReader( + file, + +PCanonicalStruct(), + keyPType = PCanonicalStruct("a" -> PCanonicalString(), "b" -> PInt32()), + ) + assert(index.queryByInterval( + Row("cat", 3), + Row("cat", 5), + includesStart = true, + includesEnd = false, + ).toFastSeq == leafChildren.slice(3, 5)) - assert(index.queryByInterval(Row("cat"), Row("cat", 5), includesStart = true, includesEnd = false).toFastSeq == + assert(index.queryByInterval( + Row("cat"), + Row("cat", 5), + includesStart = true, + includesEnd = false, + ).toFastSeq == leafChildren.slice(2, 5)) - assert(index.queryByInterval(Row(), Row(), includesStart = true, includesEnd = true).toFastSeq == + assert(index.queryByInterval( + Row(), + Row(), + includesStart = true, + includesEnd = true, + ).toFastSeq == leafChildren) - assert(index.queryByInterval(Row(), Row("cat"), includesStart = true, includesEnd = false).toFastSeq == + assert(index.queryByInterval( + Row(), + Row("cat"), + includesStart = true, + includesEnd = false, + ).toFastSeq == leafChildren.take(2)) - assert(index.queryByInterval(Row("zebra"), Row(), includesStart = true, includesEnd = true).toFastSeq == + assert(index.queryByInterval( + Row("zebra"), + Row(), + includesStart = true, + includesEnd = true, + ).toFastSeq == leafChildren.takeRight(3)) } } - @Test def testIterateFromUntil() { + @Test def testIterateFromUntil(): Unit = { for (branchingFactor <- 2 to 5) { val file = ctx.createTmpPath("from", "idx") - writeIndex(file, stringsWithDups, stringsWithDups.indices.map(i => Row()).toArray, TStruct.empty, branchingFactor) + writeIndex( + file, + stringsWithDups, + stringsWithDups.indices.map(i => Row()).toArray, 
+ TStruct.empty, + branchingFactor, + ) val index = indexReader(file, TStruct.empty) - val uniqueStrings = stringsWithDups.distinct ++ Array("aardvark", "crow", "elk", "otter", "zoo") + val uniqueStrings = + stringsWithDups.distinct ++ Array("aardvark", "crow", "elk", "otter", "zoo") uniqueStrings.foreach { s => var start = stringsWithDups.indexWhere(s <= _) if (start == -1) start = stringsWithDups.length - assert(index.iterateFrom(s).toFastSeq == leafsWithDups.slice(start, stringsWithDups.length).toFastSeq) + assert(index.iterateFrom(s).toFastSeq == leafsWithDups.slice( + start, + stringsWithDups.length, + ).toFastSeq) val end = stringsWithDups.lastIndexWhere(s > _) + 1 assert(index.iterateUntil(s).toFastSeq == leafsWithDups.slice(0, end).toFastSeq) diff --git a/hail/src/test/scala/is/hail/io/TabixSuite.scala b/hail/src/test/scala/is/hail/io/TabixSuite.scala index e1ff6d2cb46..7901a406546 100644 --- a/hail/src/test/scala/is/hail/io/TabixSuite.scala +++ b/hail/src/test/scala/is/hail/io/TabixSuite.scala @@ -1,9 +1,10 @@ package is.hail.io -import htsjdk.tribble.readers.{TabixReader => HtsjdkTabixReader} import is.hail.HailSuite import is.hail.io.tabix._ import is.hail.io.vcf.TabixVCF + +import htsjdk.tribble.readers.{TabixReader => HtsjdkTabixReader} import org.testng.annotations.{BeforeTest, Test} import org.testng.asserts.SoftAssert @@ -16,21 +17,19 @@ class TabixSuite extends HailSuite { lazy val reader = new TabixReader(vcfGzFile, fs) - @BeforeTest def initialize() { + @BeforeTest def initialize(): Unit = hc // reference to initialize - } - @Test def testLargeNumberOfSequences() { + @Test def testLargeNumberOfSequences(): Unit = { val tbx = new TabixReader(null, fs, Some("src/test/resources/large-tabix.tbi")) // known length of sequences assert(tbx.index.seqs.length == 3366) } - @Test def testSequenceNames() { + @Test def testSequenceNames(): Unit = { val expectedSeqNames = new Array[String](24); - for (i <- 1 to 22) { - expectedSeqNames(i-1) = i.toString - } + for (i <- 1 to 22) + expectedSeqNames(i - 1) = i.toString expectedSeqNames(22) = "X"; expectedSeqNames(23) = "Y"; @@ -38,20 +37,19 @@ class TabixSuite extends HailSuite { assert(expectedSeqNames.length == sequenceNames.size) val asserts = new SoftAssert() - for (s <- expectedSeqNames) { - asserts.assertTrue(sequenceNames.contains(s), s"sequenceNames does not contain ${ s }") - } + for (s <- expectedSeqNames) + asserts.assertTrue(sequenceNames.contains(s), s"sequenceNames does not contain $s") asserts.assertAll() } - @Test def testSequenceSet() { + @Test def testSequenceSet(): Unit = { val chrs = reader.index.chr2tid.keySet assert(!chrs.isEmpty) assert(chrs.contains("1")) assert(!chrs.contains("MT")) } - @Test def testLineIterator() { + @Test def testLineIterator(): Unit = { val htsjdkrdr = new HtsjdkTabixReader(vcfGzFile) // In range access for (chr <- Seq("1", "19", "X")) { @@ -94,22 +92,25 @@ class TabixSuite extends HailSuite { } } - def _testLineIterator2(vcfFile: String) { + def _testLineIterator2(vcfFile: String): Unit = { val chr = "20" val htsjdkrdr = new HtsjdkTabixReader(vcfFile) val hailrdr = new TabixReader(vcfFile, fs) val tid = hailrdr.chr2tid(chr) - for ((start, end) <- - Seq( - (12990058, 12990059), // Small interval, containing just one locus at end - (10570000, 13000000), // Open interval - (10019093, 16360860), // Closed interval - (11000000, 13029764), // Half open (beg, end] - (17434340, 18000000), // Half open [beg, end) - (13943975, 14733634), // Some random intervals - (11578765, 15291865), - (12703588, 
16751726))) { + for ( + (start, end) <- + Seq( + (12990058, 12990059), // Small interval, containing just one locus at end + (10570000, 13000000), // Open interval + (10019093, 16360860), // Closed interval + (11000000, 13029764), // Half open (beg, end] + (17434340, 18000000), // Half open [beg, end) + (13943975, 14733634), // Some random intervals + (11578765, 15291865), + (12703588, 16751726), + ) + ) { val pairs = hailrdr.queryPairs(tid, start, end) val htsIter = htsjdkrdr.query(chr, start, end) val hailIter = new TabixLineIterator(fs, hailrdr.filePath, pairs) @@ -130,11 +131,10 @@ class TabixSuite extends HailSuite { } } - @Test def testLineIterator2() { + @Test def testLineIterator2(): Unit = _testLineIterator2("src/test/resources/sample.vcf.bgz") - } - @Test def testWriter() { + @Test def testWriter(): Unit = { val vcfFile = "src/test/resources/sample.vcf.bgz" val path = ctx.createTmpPath("test-tabix-write", "bgz") fs.copy(vcfFile, path) diff --git a/hail/src/test/scala/is/hail/io/compress/BGzipCodecSuite.scala b/hail/src/test/scala/is/hail/io/compress/BGzipCodecSuite.scala index 53fc70e053b..9c3a98ae9bc 100644 --- a/hail/src/test/scala/is/hail/io/compress/BGzipCodecSuite.scala +++ b/hail/src/test/scala/is/hail/io/compress/BGzipCodecSuite.scala @@ -1,21 +1,22 @@ package is.hail.io.compress -import htsjdk.samtools.util.BlockCompressedFilePointerUtil import is.hail.HailSuite import is.hail.TestUtils._ import is.hail.check.Gen import is.hail.check.Prop.forAll import is.hail.expr.ir.GenericLines import is.hail.utils._ -import org.apache.commons.io.IOUtils -import org.apache.spark.sql.Row -import org.apache.{hadoop => hd} -import org.testng.annotations.Test import scala.collection.JavaConverters._ import scala.collection.mutable import scala.io.Source +import htsjdk.samtools.util.BlockCompressedFilePointerUtil +import org.apache.{hadoop => hd} +import org.apache.commons.io.IOUtils +import org.apache.spark.sql.Row +import org.testng.annotations.Test + class TestFileInputFormat extends hd.mapreduce.lib.input.TextInputFormat { override def getSplits(job: hd.mapreduce.JobContext): java.util.List[hd.mapreduce.InputSplit] = { val hConf = job.getConfiguration @@ -38,7 +39,13 @@ class TestFileInputFormat extends hd.mapreduce.lib.input.TextInputFormat { val splitSize = e - s val blkIndex = getBlockIndex(blkLocations, s) - splits += makeSplit(path, s, splitSize, blkLocations(blkIndex).getHosts, blkLocations(blkIndex).getCachedHosts) + splits += makeSplit( + path, + s, + splitSize, + blkLocations(blkIndex).getHosts, + blkLocations(blkIndex).getCachedHosts, + ) } splits.asJava @@ -51,14 +58,12 @@ class BGzipCodecSuite extends HailSuite { // is actually a bgz file val gzPath = "src/test/resources/sample.vcf.gz" - /* - * bgz.test.sample.vcf.bgz was created as follows: - * - split sample.vcf into 60-line chunks: `split -l 60 sample.vcf sample.vcf.` - * - move the line boundary on two chunks by 1 character in different directions - * - bgzip compressed the chunks - * - stripped the empty terminate block in chunks except ad and ag (the last) - * - concatenated the chunks - */ + /* bgz.test.sample.vcf.bgz was created as follows: + * - split sample.vcf into 60-line chunks: `split -l 60 sample.vcf sample.vcf.` + * - move the line boundary on two chunks by 1 character in different directions + * - bgzip compressed the chunks + * - stripped the empty terminate block in chunks except ad and ag (the last) + * - concatenated the chunks */ val compPath = "src/test/resources/bgz.test.sample.vcf.bgz" def 
compareLines(lines2: IndexedSeq[String], lines: IndexedSeq[String]): Unit = { @@ -72,71 +77,77 @@ class BGzipCodecSuite extends HailSuite { } } - @Test def testGenericLinesSimpleUncompressed() { + @Test def testGenericLinesSimpleUncompressed(): Unit = { val lines = Source.fromFile(uncompPath).getLines().toFastSeq - val uncompStatus = fs.fileListEntry(uncompPath) + val uncompStatus = fs.fileStatus(uncompPath) var i = 0 while (i < 16) { val lines2 = GenericLines.collect( fs, - GenericLines.read(fs, Array(uncompStatus), Some(i), None, None, false, false)) + GenericLines.read(fs, Array(uncompStatus), Some(i), None, None, false, false), + ) compareLines(lines2, lines) i += 1 } } - @Test def testGenericLinesSimpleBGZ() { + @Test def testGenericLinesSimpleBGZ(): Unit = { val lines = Source.fromFile(uncompPath).getLines().toFastSeq - val compStatus = fs.fileListEntry(compPath) + val compStatus = fs.fileStatus(compPath) var i = 0 while (i < 16) { val lines2 = GenericLines.collect( fs, - GenericLines.read(fs, Array(compStatus), Some(i), None, None, false, false)) + GenericLines.read(fs, Array(compStatus), Some(i), None, None, false, false), + ) compareLines(lines2, lines) i += 1 } } - @Test def testGenericLinesSimpleGZ() { + @Test def testGenericLinesSimpleGZ(): Unit = { val lines = Source.fromFile(uncompPath).getLines().toFastSeq // won't split, just run once - val gzStatus = fs.fileListEntry(gzPath) + val gzStatus = fs.fileStatus(gzPath) val lines2 = GenericLines.collect( fs, - GenericLines.read(fs, Array(gzStatus), Some(7), None, None, false, true)) + GenericLines.read(fs, Array(gzStatus), Some(7), None, None, false, true), + ) compareLines(lines2, lines) } - @Test def testGenericLinesRefuseGZ() { + @Test def testGenericLinesRefuseGZ(): Unit = interceptFatal("Cowardly refusing") { - val gzStatus = fs.fileListEntry(gzPath) + val gzStatus = fs.fileStatus(gzPath) GenericLines.read(fs, Array(gzStatus), Some(7), None, None, false, false) } - } - @Test def testGenericLinesRandom() { + @Test def testGenericLinesRandom(): Unit = { val lines = Source.fromFile(uncompPath).getLines().toFastSeq val compLength = 195353 - val compSplits = Array[Long](6566, 20290, 33438, 41165, 56691, 70278, 77419, 92522, 106310, 112477, 112505, 124593, + val compSplits = Array[Long](6566, 20290, 33438, 41165, 56691, 70278, 77419, 92522, 106310, + 112477, 112505, 124593, 136405, 144293, 157375, 169172, 175174, 186973, 195325) - val g = for (n <- Gen.oneOfGen( - Gen.choose(0, 10), - Gen.choose(0, 100)); - rawSplits <- Gen.buildableOfN[Array](n, - Gen.oneOfGen(Gen.choose(0L, compLength), - Gen.applyGen(Gen.oneOf[(Long) => Long](identity, _ - 1, _ + 1), - Gen.oneOfSeq(compSplits))))) - yield - (Array(0L, compLength) ++ rawSplits).distinct.sorted + val g = for { + n <- Gen.oneOfGen( + Gen.choose(0, 10), + Gen.choose(0, 100), + ) + rawSplits <- Gen.buildableOfN[Array]( + n, + Gen.oneOfGen( + Gen.choose(0L, compLength), + Gen.applyGen(Gen.oneOf[(Long) => Long](identity, _ - 1, _ + 1), Gen.oneOfSeq(compSplits)), + ), + ) + } yield (Array(0L, compLength) ++ rawSplits).distinct.sorted val p = forAll(g) { splits => - val contexts = (0 until (splits.length - 1)) .map { i => val end = makeVirtualOffset(splits(i + 1), 0) @@ -149,7 +160,7 @@ class BGzipCodecSuite extends HailSuite { p.check() } - @Test def test() { + @Test def test(): Unit = { sc.hadoopConfiguration.setLong("mapreduce.input.fileinputformat.split.minsize", 1L) val uncompIS = fs.open(uncompPath) @@ -174,21 +185,25 @@ class BGzipCodecSuite extends HailSuite { } val compLength = 
195353 - val compSplits = Array[Long](6566, 20290, 33438, 41165, 56691, 70278, 77419, 92522, 106310, 112477, 112505, 124593, + val compSplits = Array[Long](6566, 20290, 33438, 41165, 56691, 70278, 77419, 92522, 106310, + 112477, 112505, 124593, 136405, 144293, 157375, 169172, 175174, 186973, 195325) - val g = for (n <- Gen.oneOfGen( - Gen.choose(0, 10), - Gen.choose(0, 100)); - rawSplits <- Gen.buildableOfN[Array](n, - Gen.oneOfGen(Gen.choose(0L, compLength), - Gen.applyGen(Gen.oneOf[(Long) => Long](identity, _ - 1, _ + 1), - Gen.oneOfSeq(compSplits))))) - yield - (Array(0L, compLength) ++ rawSplits).distinct.sorted + val g = for { + n <- Gen.oneOfGen( + Gen.choose(0, 10), + Gen.choose(0, 100), + ) + rawSplits <- Gen.buildableOfN[Array]( + n, + Gen.oneOfGen( + Gen.choose(0L, compLength), + Gen.applyGen(Gen.oneOf[(Long) => Long](identity, _ - 1, _ + 1), Gen.oneOfSeq(compSplits)), + ), + ) + } yield (Array(0L, compLength) ++ rawSplits).distinct.sorted val p = forAll(g) { splits => - val jobConf = new hd.conf.Configuration(sc.hadoopConfiguration) jobConf.set("bgz.test.splits", splits.mkString(",")) val rdd = sc.newAPIHadoopFile[hd.io.LongWritable, hd.io.Text, TestFileInputFormat]( @@ -196,7 +211,8 @@ class BGzipCodecSuite extends HailSuite { classOf[TestFileInputFormat], classOf[hd.io.LongWritable], classOf[hd.io.Text], - jobConf) + jobConf, + ) val rddLines = rdd.map(_._2.toString).collectOrdered() rddLines.sameElements(lines) @@ -204,12 +220,18 @@ class BGzipCodecSuite extends HailSuite { p.check() } - @Test def testVirtualSeek() { + @Test def testVirtualSeek(): Unit = { // real offsets of the start of some blocks, paired with the offset to the next block - val blockStarts = Array[(Long, Long)]((0, 14653), (69140, 82949), (133703, 146664), (181362, 192983 /* end of file */)) + val blockStarts = Array[(Long, Long)]( + (0, 14653), + (69140, 82949), + (133703, 146664), + (181362, 192983 /* end of file */ ), + ) // NOTE: maxBlockSize is the length of all blocks other than the last val maxBlockSize = 65280 - // number determined by counting bytes from sample.vcf from uncompBlockStarts.last to the end of the file + /* number determined by counting bytes from sample.vcf from uncompBlockStarts.last to the end of + * the file */ val lastBlockLen = 55936 // offsets into the uncompressed file val uncompBlockStarts = Array[Int](0, 326400, 652800, 913920) @@ -219,77 +241,87 @@ class BGzipCodecSuite extends HailSuite { using(fs.openNoCompression(uncompPath)) { uncompIS => using(new BGzipInputStream(fs.openNoCompression(compPath))) { decompIS => - - val fromEnd = 48 // arbitrary number of bytes from the end of block to attempt to seek to - for (((cOff, nOff), uOff) <- blockStarts.zip(uncompBlockStarts); - e <- Seq(0, 1024, maxBlockSize - fromEnd)) { - val decompData = new Array[Byte](100) - val uncompData = new Array[Byte](100) - val extra = if (cOff == blockStarts.last._1 && e == maxBlockSize - fromEnd) + val fromEnd = 48 // arbitrary number of bytes from the end of block to attempt to seek to + for { + ((cOff, nOff), uOff) <- blockStarts.zip(uncompBlockStarts) + e <- Seq(0, 1024, maxBlockSize - fromEnd) + } { + val decompData = new Array[Byte](100) + val uncompData = new Array[Byte](100) + val extra = if (cOff == blockStarts.last._1 && e == maxBlockSize - fromEnd) lastBlockLen - fromEnd else e - val vOff = BlockCompressedFilePointerUtil.makeFilePointer(cOff, extra) - - decompIS.virtualSeek(vOff) - assert(decompIS.getVirtualOffset() == vOff); - uncompIS.seek(uOff + extra) - - val decompRead = 
decompIS.readRepeatedly(decompData) - val uncompRead = uncompIS.readRepeatedly(uncompData) - - assert(decompRead == uncompRead, s"""compressed offset: ${ cOff } - |decomp bytes read: ${ decompRead } - |uncomp bytes read: ${ uncompRead }\n""".stripMargin) - assert(decompData sameElements uncompData, s"data differs for compressed offset: ${ cOff }") - val expectedVirtualOffset = if (extra == lastBlockLen - fromEnd) + val vOff = BlockCompressedFilePointerUtil.makeFilePointer(cOff, extra) + + decompIS.virtualSeek(vOff) + assert(decompIS.getVirtualOffset() == vOff); + uncompIS.seek(uOff + extra) + + val decompRead = decompIS.readRepeatedly(decompData) + val uncompRead = uncompIS.readRepeatedly(uncompData) + + assert( + decompRead == uncompRead, + s"""compressed offset: $cOff + |decomp bytes read: $decompRead + |uncomp bytes read: $uncompRead\n""".stripMargin, + ) + assert(decompData sameElements uncompData, s"data differs for compressed offset: $cOff") + val expectedVirtualOffset = if (extra == lastBlockLen - fromEnd) BlockCompressedFilePointerUtil.makeFilePointer(nOff, 0) else if (extra == maxBlockSize - fromEnd) BlockCompressedFilePointerUtil.makeFilePointer(nOff, decompRead - fromEnd) else BlockCompressedFilePointerUtil.makeFilePointer(cOff, extra + decompRead) - assert(expectedVirtualOffset == decompIS.getVirtualOffset()) - } - - // here we test reading from the middle of a block to it's end - val decompData = new Array[Byte](maxBlockSize) - val toSkip = 20000 - val vOff = BlockCompressedFilePointerUtil.makeFilePointer(blockStarts(2)._1, toSkip) - decompIS.virtualSeek(vOff) - assert(decompIS.getVirtualOffset() == vOff) - assert(decompIS.read(decompData) == maxBlockSize - 20000) - assert(decompIS.getVirtualOffset() == BlockCompressedFilePointerUtil.makeFilePointer(blockStarts(2)._2, 0)) - - // Trying to seek to the end of a block should fail - intercept[java.io.IOException] { - val vOff = BlockCompressedFilePointerUtil.makeFilePointer(blockStarts(1)._1, maxBlockSize) - decompIS.virtualSeek(vOff) - } + assert(expectedVirtualOffset == decompIS.getVirtualOffset()) + } - // Trying to seek past the end of a block should fail - intercept[java.io.IOException] { - val vOff = BlockCompressedFilePointerUtil.makeFilePointer(blockStarts(0)._1, maxBlockSize + 1) + // here we test reading from the middle of a block to it's end + val decompData = new Array[Byte](maxBlockSize) + val toSkip = 20000 + val vOff = BlockCompressedFilePointerUtil.makeFilePointer(blockStarts(2)._1, toSkip) decompIS.virtualSeek(vOff) - } + assert(decompIS.getVirtualOffset() == vOff) + assert(decompIS.read(decompData) == maxBlockSize - 20000) + assert(decompIS.getVirtualOffset() == BlockCompressedFilePointerUtil.makeFilePointer( + blockStarts(2)._2, + 0, + )) + + // Trying to seek to the end of a block should fail + intercept[java.io.IOException] { + val vOff = BlockCompressedFilePointerUtil.makeFilePointer(blockStarts(1)._1, maxBlockSize) + decompIS.virtualSeek(vOff) + } - // Trying to seek to the end of the last block should fail - intercept[java.io.IOException] { - val vOff = BlockCompressedFilePointerUtil.makeFilePointer(blockStarts.last._1, lastBlockLen) - decompIS.virtualSeek(vOff) - } + // Trying to seek past the end of a block should fail + intercept[java.io.IOException] { + val vOff = + BlockCompressedFilePointerUtil.makeFilePointer(blockStarts(0)._1, maxBlockSize + 1) + decompIS.virtualSeek(vOff) + } - // trying to seek to the end of file should succeed - decompIS.virtualSeek(0) - val eofOffset = 
BlockCompressedFilePointerUtil.makeFilePointer(blockStarts.last._2, 0) - decompIS.virtualSeek(eofOffset) - assert(-1 == decompIS.read()) + // Trying to seek to the end of the last block should fail + intercept[java.io.IOException] { + val vOff = + BlockCompressedFilePointerUtil.makeFilePointer(blockStarts.last._1, lastBlockLen) + decompIS.virtualSeek(vOff) + } - // seeking past end of file directly should fail - decompIS.virtualSeek(0) - intercept[java.io.IOException] { - val vOff = BlockCompressedFilePointerUtil.makeFilePointer(blockStarts.last._2, 1) - decompIS.virtualSeek(vOff) + // trying to seek to the end of file should succeed + decompIS.virtualSeek(0) + val eofOffset = BlockCompressedFilePointerUtil.makeFilePointer(blockStarts.last._2, 0) + decompIS.virtualSeek(eofOffset) + assert(-1 == decompIS.read()) + + // seeking past end of file directly should fail + decompIS.virtualSeek(0) + intercept[java.io.IOException] { + val vOff = BlockCompressedFilePointerUtil.makeFilePointer(blockStarts.last._2, 1) + decompIS.virtualSeek(vOff) + } } - }} + } } } diff --git a/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala b/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala index f7aa3f1f01c..d6f3b8cbc4f 100644 --- a/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala +++ b/hail/src/test/scala/is/hail/io/fs/AzureStorageFSSuite.scala @@ -1,14 +1,12 @@ package is.hail.io.fs +import java.io.FileInputStream + import org.apache.commons.io.IOUtils -import org.scalatest.testng.TestNGSuite import org.testng.SkipException import org.testng.annotations.{BeforeClass, Test} -import java.io.FileInputStream - - -class AzureStorageFSSuite extends TestNGSuite with FSSuite { +class AzureStorageFSSuite extends FSSuite { @BeforeClass def beforeclass(): Unit = { if (System.getenv("HAIL_CLOUD") != "azure") { @@ -22,11 +20,11 @@ class AzureStorageFSSuite extends TestNGSuite with FSSuite { lazy val fs = { val aac = System.getenv("AZURE_APPLICATION_CREDENTIALS") if (aac == null) { - new AzureStorageFS() - } - else { + new AzureStorageFS() + } else { new AzureStorageFS( - Some(new String(IOUtils.toByteArray(new FileInputStream(aac))))) + Some(new String(IOUtils.toByteArray(new FileInputStream(aac)))) + ) } } @@ -35,18 +33,12 @@ class AzureStorageFSSuite extends TestNGSuite with FSSuite { assert(fs.makeQualified(qualifiedFileName) == qualifiedFileName) val unqualifiedFileName = "https://account/container/path" - try { + try fs.makeQualified(unqualifiedFileName) - } catch { case _: IllegalArgumentException => return } assert(false) } - - @Test def testETag(): Unit = { - val etag = fs.eTag(s"$fsResourcesRoot/a") - assert(etag.nonEmpty) - } } diff --git a/hail/src/test/scala/is/hail/io/fs/FSSuite.scala b/hail/src/test/scala/is/hail/io/fs/FSSuite.scala index 23f2622efb9..0901c6a2f57 100644 --- a/hail/src/test/scala/is/hail/io/fs/FSSuite.scala +++ b/hail/src/test/scala/is/hail/io/fs/FSSuite.scala @@ -1,15 +1,17 @@ package is.hail.io.fs -import is.hail.HailSuite +import is.hail.{HailSuite, TestUtils} import is.hail.backend.ExecuteContext import is.hail.io.fs.FSUtil.dropTrailingSlash import is.hail.utils._ + +import java.io.FileNotFoundException + import org.apache.commons.io.IOUtils +import org.apache.hadoop.fs.FileAlreadyExistsException import org.scalatest.testng.TestNGSuite import org.testng.annotations.Test -import java.io.FileNotFoundException - trait FSSuite extends TestNGSuite { val root: String = System.getenv("HAIL_TEST_STORAGE_URI") @@ -20,34 +22,23 @@ trait FSSuite extends TestNGSuite { def 
fs: FS /* Structure of src/test/resources/fs: - /a - /adir - /adir/x - /az - /dir - /dir/x - /zzz - */ + * /a /adir /adir/x /az /dir /dir/x /zzz */ def r(s: String): String = s"$fsResourcesRoot$s" - def t(extension: String = null): String = ExecuteContext.createTmpPathNoCleanup(tmpdir, "fs-suite-tmp", extension) + def t(extension: String = null): String = + ExecuteContext.createTmpPathNoCleanup(tmpdir, "fs-suite-tmp", extension) - def pathsRelRoot(root: String, statuses: Array[FileListEntry]): Set[String] = { + def pathsRelRoot(root: String, statuses: Array[FileListEntry]): Set[String] = statuses.map { status => - var p = status.getPath + val p = status.getPath assert(p.startsWith(root), s"$p $root") p.drop(root.length) }.toSet - } - - def pathsRelResourcesRoot(statuses: Array[FileListEntry]): Set[String] = pathsRelRoot(fsResourcesRoot, statuses) - @Test def testExists(): Unit = { - assert(fs.exists(r("/a"))) - - assert(fs.exists(r("/zzz"))) - assert(!fs.exists(r("/z"))) // prefix + def pathsRelResourcesRoot(statuses: Array[FileListEntry]): Set[String] = + pathsRelRoot(fsResourcesRoot, statuses) + @Test def testExistsOnDirectory(): Unit = { assert(fs.exists(r("/dir"))) assert(fs.exists(r("/dir/"))) @@ -55,6 +46,21 @@ trait FSSuite extends TestNGSuite { assert(!fs.exists(r("/does_not_exist_dir/"))) } + @Test def testExistsOnFile(): Unit = { + assert(fs.exists(r("/a"))) + + assert(fs.exists(r("/zzz"))) + assert(!fs.exists(r("/z"))) // prefix + } + + @Test def testFileStatusOnFile(): Unit = { + // file + val f = r("/a") + val s = fs.fileStatus(f) + assert(s.getPath == f) + assert(s.getLen == 12) + } + @Test def testFileListEntryOnFile(): Unit = { // file val f = r("/a") @@ -65,6 +71,13 @@ trait FSSuite extends TestNGSuite { assert(s.getLen == 12) } + @Test def testFileStatusOnDirIsFailure(): Unit = { + val f = r("/dir") + TestUtils.interceptException[FileNotFoundException](f)( + fs.fileStatus(f) + ) + } + @Test def testFileListEntryOnDir(): Unit = { // file val f = r("/dir") @@ -84,9 +97,9 @@ trait FSSuite extends TestNGSuite { } @Test def testFileListEntryOnMissingFile(): Unit = { - try { + try fs.fileListEntry(r("/does_not_exist")) - } catch { + catch { case _: FileNotFoundException => return } @@ -156,32 +169,42 @@ trait FSSuite extends TestNGSuite { @Test def testGlobFilename(): Unit = { val statuses = fs.glob(r("/a*")) - assert(pathsRelResourcesRoot(statuses) == Set("/a", "/adir", "/az"), - s"${statuses} ${pathsRelResourcesRoot(statuses)} ${Set("/a", "/adir", "/az")}") + assert( + pathsRelResourcesRoot(statuses) == Set("/a", "/adir", "/az"), + s"$statuses ${pathsRelResourcesRoot(statuses)} ${Set("/a", "/adir", "/az")}", + ) } @Test def testGlobFilenameMatchSingleCharacter(): Unit = { val statuses = fs.glob(r("/a?")) - assert(pathsRelResourcesRoot(statuses) == Set("/az"), - s"${statuses} ${pathsRelResourcesRoot(statuses)} ${Set("/az")}") + assert( + pathsRelResourcesRoot(statuses) == Set("/az"), + s"$statuses ${pathsRelResourcesRoot(statuses)} ${Set("/az")}", + ) } @Test def testGlobFilenameMatchSingleCharacterInMiddleOfName(): Unit = { val statuses = fs.glob(r("/a?ir")) - assert(pathsRelResourcesRoot(statuses) == Set("/adir"), - s"${statuses} ${pathsRelResourcesRoot(statuses)} ${Set("/adir")}") + assert( + pathsRelResourcesRoot(statuses) == Set("/adir"), + s"$statuses ${pathsRelResourcesRoot(statuses)} ${Set("/adir")}", + ) } @Test def testGlobDirnameMatchSingleCharacterInMiddleOfName(): Unit = { val statuses = fs.glob(r("/a?ir/x")) - assert(pathsRelResourcesRoot(statuses) == 
Set("/adir/x"), - s"${statuses} ${pathsRelResourcesRoot(statuses)} ${Set("/adir/x")}") + assert( + pathsRelResourcesRoot(statuses) == Set("/adir/x"), + s"$statuses ${pathsRelResourcesRoot(statuses)} ${Set("/adir/x")}", + ) } @Test def testGlobMatchDir(): Unit = { val statuses = fs.glob(r("/*dir/x")) - assert(pathsRelResourcesRoot(statuses) == Set("/adir/x", "/dir/x"), - s"${statuses} ${pathsRelResourcesRoot(statuses)} ${Set("/adir/x", "/dir/x")}") + assert( + pathsRelResourcesRoot(statuses) == Set("/adir/x", "/dir/x"), + s"$statuses ${pathsRelResourcesRoot(statuses)} ${Set("/adir/x", "/dir/x")}", + ) } @Test def testGlobRoot(): Unit = { @@ -190,6 +213,20 @@ trait FSSuite extends TestNGSuite { assert(pathsRelRoot(root, statuses) == Set("")) } + @Test def testFileEndingWithPeriod(): Unit = { + val f = fs.makeQualified(t()) + fs.touch(f + "/foo.") + val statuses = fs.listDirectory(f) + assert(statuses.length == 1, statuses) + val status = statuses(0) + if (this.isInstanceOf[AzureStorageFSSuite]) { + // https://github.com/Azure/azure-sdk-for-java/issues/36674 + assert(status.getPath == f + "/foo") + } else { + assert(status.getPath == f + "/foo.") + } + } + @Test def testGlobRootWithSlash(): Unit = { if (root.endsWith("/")) return @@ -245,14 +282,14 @@ trait FSSuite extends TestNGSuite { val s2 = "second" val f = t() - using(fs.create(f)) { _.write(s1.getBytes) } + using(fs.create(f))(_.write(s1.getBytes)) assert(fs.exists(f)) using(fs.open(f)) { is => val read = new String(IOUtils.toByteArray(is)) assert(read == s1) } - using(fs.create(f)) { _.write(s2.getBytes) } + using(fs.create(f))(_.write(s2.getBytes)) assert(fs.exists(f)) using(fs.open(f)) { is => val read = new String(IOUtils.toByteArray(is)) @@ -260,13 +297,11 @@ trait FSSuite extends TestNGSuite { } } - @Test def testGetCodecExtension(): Unit = { + @Test def testGetCodecExtension(): Unit = assert(fs.getCodecExtension("foo.vcf.bgz") == ".bgz") - } - @Test def testStripCodecExtension(): Unit = { + @Test def testStripCodecExtension(): Unit = assert(fs.stripCodecExtension("foo.vcf.bgz") == "foo.vcf") - } @Test def testReadWriteBytes(): Unit = { val f = t() @@ -316,7 +351,7 @@ trait FSSuite extends TestNGSuite { var i = 0 while (i < numWrites) { val readFromIs = is.read() - assert(readFromIs == (i & 0xff), s"${i} ${i & 0xff} ${readFromIs}") + assert(readFromIs == (i & 0xff), s"$i ${i & 0xff} $readFromIs") i = i + 1 } } @@ -336,9 +371,9 @@ trait FSSuite extends TestNGSuite { @Test def testSeekMoreThanMaxInt(): Unit = { val f = t() - using (fs.create(f)) { os => + using(fs.create(f)) { os => val eight_mib = 8 * 1024 * 1024 - val arr = Array.fill(eight_mib){0.toByte} + val arr = Array.fill(eight_mib)(0.toByte) var i = 0 // 256 * 8MiB = 2GiB while (i < 256) { @@ -378,7 +413,6 @@ trait FSSuite extends TestNGSuite { } using(fs.openNoCompression(f)) { is => - is.seek(251) assert(is.read() == 0) assert(is.read() == 1) @@ -389,17 +423,14 @@ trait FSSuite extends TestNGSuite { val toRead = new Array[Byte](512) is.readFully(toRead) - (0 until toRead.length).foreach { i => - assert(toRead(i) == ((seekPos + i) % 251).toByte) - } + (0 until toRead.length).foreach(i => assert(toRead(i) == ((seekPos + i) % 251).toByte)) } } @Test def largeDirectoryOperations(): Unit = { - val prefix = s"$tmpdir/fs-suite/delete-many-files/${ java.util.UUID.randomUUID() }" - for (i <- 0 until 2000) { + val prefix = s"$tmpdir/fs-suite/delete-many-files/${java.util.UUID.randomUUID()}" + for (i <- 0 until 2000) fs.touch(s"$prefix/$i.suffix") - } 
assert(fs.listDirectory(prefix).size == 2000) assert(fs.glob(prefix + "/" + "*.suffix").size == 2000) @@ -407,15 +438,19 @@ trait FSSuite extends TestNGSuite { assert(fs.exists(prefix)) fs.delete(prefix, recursive = true) if (fs.exists(prefix)) { - // NB: TestNGSuite.assert does not have a lazy message argument so we must use an if to protect this list + /* NB: TestNGSuite.assert does not have a lazy message argument so we must use an if to + * protect this list */ // // see: https://www.scalatest.org/scaladoc/1.7.2/org/scalatest/testng/TestNGSuite.html - assert(false, s"files not deleted:\n${ fs.listDirectory(prefix).map(_.getPath).mkString("\n") }") + assert( + false, + s"files not deleted:\n${fs.listDirectory(prefix).map(_.getPath).mkString("\n")}", + ) } } @Test def testSeekAfterEOF(): Unit = { - val prefix = s"$tmpdir/fs-suite/delete-many-files/${ java.util.UUID.randomUUID() }" + val prefix = s"$tmpdir/fs-suite/delete-many-files/${java.util.UUID.randomUUID()}" val p = s"$prefix/seek_file" using(fs.createCachedNoCompression(p)) { os => os.write(1.toByte) @@ -433,17 +468,176 @@ trait FSSuite extends TestNGSuite { assert(is.read() == 1.toByte) } } + + @Test def fileAndDirectoryIsError(): Unit = { + val d = t() + fs.mkDir(d) + fs.touch(s"$d/x/file") + try { + fs.touch(s"$d/x") + fs.fileListEntry(s"$d/x") + assert(false) + } catch { + /* Hadoop, in particular, errors when you touch an object whose name is a prefix of another + * object. */ + case exc: FileAndDirectoryException + if exc.getMessage() == s"$d/x appears as both file $d/x and directory $d/x/." => + case exc: FileNotFoundException if exc.getMessage() == s"$d/x (Is a directory)" => + } + } + + @Test def testETag(): Unit = { + val etag = fs.eTag(s"$fsResourcesRoot/a") + if (fs.parseUrl(fsResourcesRoot).toString.startsWith("file:")) { + // only the local file system should lack etags. 
+ assert(etag.isEmpty) + } else { + assert(etag.nonEmpty) + } + } + + @Test def fileAndDirectoryIsErrorEvenIfPrefixedFileIsNotLexicographicallyFirst(): Unit = { + val d = t() + fs.mkDir(d) + fs.touch(s"$d/x") + // fs.touch(s"$d/x ") // Hail does not support spaces in path names + fs.touch(s"$d/x!") + fs.touch(s"$d/x${'"'}") + fs.touch(s"$d/x#") + fs.touch(s"$d/x$$") + // fs.touch(s"$d/x%") // Azure dislikes %'s + // java.lang.IllegalArgumentException: URLDecoder: Incomplete trailing escape (%) pattern + // at java.net.URLDecoder.decode(URLDecoder.java:187) + // at is.hail.shadedazure.com.azure.storage.common.Utility.decode(Utility.java:88) + // at is.hail.shadedazure.com.azure.storage.common.Utility.urlDecode(Utility.java:55) + /* at + * is.hail.shadedazure.com.azure.storage.blob.specialized.BlobAsyncClientBase.(BlobAsyncClientBase.java:238) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.specialized.BlobAsyncClientBase.(BlobAsyncClientBase.java:202) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.BlobAsyncClient.(BlobAsyncClient.java:154) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.BlobContainerAsyncClient.getBlobAsyncClient(BlobContainerAsyncClient.java:194) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.BlobContainerAsyncClient.getBlobAsyncClient(BlobContainerAsyncClient.java:172) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.BlobContainerClient.getBlobClient(BlobContainerClient.java:98) */ + // at is.hail.io.fs.AzureStorageFS.$anonfun$getBlobClient$1(AzureStorageFS.scala:255) + fs.touch(s"$d/x&") + fs.touch(s"$d/x'") + fs.touch(s"$d/x)") + fs.touch(s"$d/x(") + fs.touch(s"$d/x*") + fs.touch(s"$d/x+") + fs.touch(s"$d/x,") + fs.touch(s"$d/x-") + // fs.touch(s"$d/x.") // https://github.com/Azure/azure-sdk-for-java/issues/36674 + try { + fs.touch(s"$d/x/file") + fs.fileListEntry(s"$d/x") + assert(false) + } catch { + /* Hadoop, in particular, errors when you touch an object whose name is a prefix of another + * object. */ + case exc: FileAndDirectoryException + if exc.getMessage() == s"$d/x appears as both file $d/x and directory $d/x/." 
=> + case exc: FileAlreadyExistsException + if exc.getMessage() == s"Destination exists and is not a directory: $d/x" => + } + } + + @Test def fileListEntrySeesDirectoryEvenIfPrefixedFileIsNotLexicographicallyFirst(): Unit = { + val d = t() + fs.mkDir(d) + // fs.touch(s"$d/x ") // Hail does not support spaces in path names + fs.touch(s"$d/x!") + fs.touch(s"$d/x${'"'}") + fs.touch(s"$d/x#") + fs.touch(s"$d/x$$") + // fs.touch(s"$d/x%") // Azure dislikes %'s + // java.lang.IllegalArgumentException: URLDecoder: Incomplete trailing escape (%) pattern + // at java.net.URLDecoder.decode(URLDecoder.java:187) + // at is.hail.shadedazure.com.azure.storage.common.Utility.decode(Utility.java:88) + // at is.hail.shadedazure.com.azure.storage.common.Utility.urlDecode(Utility.java:55) + /* at + * is.hail.shadedazure.com.azure.storage.blob.specialized.BlobAsyncClientBase.(BlobAsyncClientBase.java:238) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.specialized.BlobAsyncClientBase.(BlobAsyncClientBase.java:202) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.BlobAsyncClient.(BlobAsyncClient.java:154) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.BlobContainerAsyncClient.getBlobAsyncClient(BlobContainerAsyncClient.java:194) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.BlobContainerAsyncClient.getBlobAsyncClient(BlobContainerAsyncClient.java:172) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.BlobContainerClient.getBlobClient(BlobContainerClient.java:98) */ + // at is.hail.io.fs.AzureStorageFS.$anonfun$getBlobClient$1(AzureStorageFS.scala:255) + fs.touch(s"$d/x&") + fs.touch(s"$d/x'") + fs.touch(s"$d/x)") + fs.touch(s"$d/x(") + fs.touch(s"$d/x*") + fs.touch(s"$d/x+") + fs.touch(s"$d/x,") + fs.touch(s"$d/x-") + // fs.touch(s"$d/x.") // https://github.com/Azure/azure-sdk-for-java/issues/36674 + fs.touch(s"$d/x/file") + + val fle = fs.fileListEntry(s"$d/x") + assert(fle.isDirectory) + assert(!fle.isFile) + } + + @Test def fileListEntrySeesFileEvenWithPeersPreceedingThePositionOfANonPresentDirectoryEntry() + : Unit = { + val d = t() + fs.mkDir(d) + fs.touch(s"$d/x") + // fs.touch(s"$d/x ") // Hail does not support spaces in path names + fs.touch(s"$d/x!") + fs.touch(s"$d/x${'"'}") + fs.touch(s"$d/x#") + fs.touch(s"$d/x$$") + // fs.touch(s"$d/x%") // Azure dislikes %'s + // java.lang.IllegalArgumentException: URLDecoder: Incomplete trailing escape (%) pattern + // at java.net.URLDecoder.decode(URLDecoder.java:187) + // at is.hail.shadedazure.com.azure.storage.common.Utility.decode(Utility.java:88) + // at is.hail.shadedazure.com.azure.storage.common.Utility.urlDecode(Utility.java:55) + /* at + * is.hail.shadedazure.com.azure.storage.blob.specialized.BlobAsyncClientBase.(BlobAsyncClientBase.java:238) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.specialized.BlobAsyncClientBase.(BlobAsyncClientBase.java:202) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.BlobAsyncClient.(BlobAsyncClient.java:154) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.BlobContainerAsyncClient.getBlobAsyncClient(BlobContainerAsyncClient.java:194) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.BlobContainerAsyncClient.getBlobAsyncClient(BlobContainerAsyncClient.java:172) */ + /* at + * is.hail.shadedazure.com.azure.storage.blob.BlobContainerClient.getBlobClient(BlobContainerClient.java:98) */ + // at is.hail.io.fs.AzureStorageFS.$anonfun$getBlobClient$1(AzureStorageFS.scala:255) + fs.touch(s"$d/x&") + fs.touch(s"$d/x'") + fs.touch(s"$d/x)") + 
fs.touch(s"$d/x(") + fs.touch(s"$d/x*") + fs.touch(s"$d/x+") + fs.touch(s"$d/x,") + fs.touch(s"$d/x-") + // fs.touch(s"$d/x.") // https://github.com/Azure/azure-sdk-for-java/issues/36674 + + val fle = fs.fileListEntry(s"$d/x") + assert(!fle.isDirectory) + assert(fle.isFile) + assert(fle.getPath == fs.parseUrl(s"$d/x").toString) + } } class HadoopFSSuite extends HailSuite with FSSuite { override val root: String = "file:/" - override lazy val fsResourcesRoot: String = "file:" + new java.io.File("./src/test/resources/fs").getCanonicalPath + override lazy val fsResourcesRoot: String = + "file:" + new java.io.File("./src/test/resources/fs").getCanonicalPath override lazy val tmpdir: String = ctx.tmpdir - - @Test def testETag(): Unit = { - val etag = fs.eTag(s"$fsResourcesRoot/a") - assert(etag.isEmpty) - } } diff --git a/hail/src/test/scala/is/hail/io/fs/FakeFS.scala b/hail/src/test/scala/is/hail/io/fs/FakeFS.scala index 9770097e8fc..d91a6e57339 100644 --- a/hail/src/test/scala/is/hail/io/fs/FakeFS.scala +++ b/hail/src/test/scala/is/hail/io/fs/FakeFS.scala @@ -1,19 +1,20 @@ package is.hail.io.fs - case class FakeURL(path: String) extends FSURL { - def getPath(): String = path + def getPath: String = path + def getActualUrl: String = path } abstract class FakeFS extends FS { override type URL = FakeURL override def validUrl(filename: String): Boolean = ??? override def parseUrl(filename: String): FakeURL = FakeURL(filename) - override def urlAddPathComponent(url: FakeURL,component: String): FakeURL = ??? + override def urlAddPathComponent(url: FakeURL, component: String): FakeURL = ??? override def openNoCompression(url: FakeURL): SeekableDataInputStream = ??? override def createNoCompression(url: FakeURL): PositionedDataOutputStream = ??? - override def delete(url: FakeURL,recursive: Boolean): Unit = ??? + override def delete(url: FakeURL, recursive: Boolean): Unit = ??? override def eTag(url: FakeURL): Option[String] = ??? + override def fileStatus(url: FakeURL): FileStatus = ??? override def fileListEntry(url: FakeURL): FileListEntry = ??? override def glob(url: FakeURL): Array[FileListEntry] = ??? override def listDirectory(url: FakeURL): Array[FileListEntry] = ??? 
diff --git a/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala b/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala index 44db57abd83..ff7c234d56a 100644 --- a/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala +++ b/hail/src/test/scala/is/hail/io/fs/GoogleStorageFSSuite.scala @@ -1,12 +1,12 @@ package is.hail.io.fs +import java.io.FileInputStream + import org.apache.commons.io.IOUtils import org.scalatest.testng.TestNGSuite import org.testng.SkipException import org.testng.annotations.{BeforeClass, Test} -import java.io.FileInputStream - class GoogleStorageFSSuite extends TestNGSuite with FSSuite { @BeforeClass def beforeclass(): Unit = { @@ -24,7 +24,8 @@ class GoogleStorageFSSuite extends TestNGSuite with FSSuite { new GoogleStorageFS() } else { new GoogleStorageFS( - Some(new String(IOUtils.toByteArray(new FileInputStream(gac))))) + Some(new String(IOUtils.toByteArray(new FileInputStream(gac)))) + ) } } @@ -33,19 +34,12 @@ class GoogleStorageFSSuite extends TestNGSuite with FSSuite { assert(fs.makeQualified(qualifiedFileName) == qualifiedFileName) val unqualifiedFileName = "not-gs://bucket/path" - try { + try fs.makeQualified(unqualifiedFileName) - } catch { case _: IllegalArgumentException => return } assert(false) } - - @Test def testETag(): Unit = { - val etag = fs.eTag(s"$fsResourcesRoot/a") - assert(etag.nonEmpty) - } - } diff --git a/hail/src/test/scala/is/hail/linalg/BlockMatrixSuite.scala b/hail/src/test/scala/is/hail/linalg/BlockMatrixSuite.scala index fbb183c8c9c..cb2a9008a5c 100644 --- a/hail/src/test/scala/is/hail/linalg/BlockMatrixSuite.scala +++ b/hail/src/test/scala/is/hail/linalg/BlockMatrixSuite.scala @@ -1,21 +1,19 @@ package is.hail.linalg - -import breeze.linalg.{*, diag, DenseMatrix => BDM, DenseVector => BDV} +import is.hail.{HailSuite, TestUtils} +import is.hail.check._ import is.hail.check.Arbitrary._ import is.hail.check.Gen._ import is.hail.check.Prop._ -import is.hail.check._ import is.hail.expr.ir.{CompileAndEvaluate, GetField, TableCollect, TableLiteral} import is.hail.linalg.BlockMatrix.ops._ import is.hail.types.virtual.{TFloat64, TInt64, TStruct} import is.hail.utils._ -import is.hail.{HailSuite, TestUtils} + +import breeze.linalg.{*, diag, DenseMatrix => BDM, DenseVector => BDV} import org.apache.spark.sql.Row import org.testng.annotations.Test -import scala.language.implicitConversions - class BlockMatrixSuite extends HailSuite { // row major @@ -48,7 +46,7 @@ class BlockMatrixSuite extends HailSuite { def blockMatrixGen( blockSize: Gen[Int] = defaultBlockSize, dims: Gen[(Int, Int)] = defaultDims, - element: Gen[Double] = defaultElement + element: Gen[Double] = defaultElement, ): Gen[BlockMatrix] = for { blockSize <- blockSize (nRows, nCols) <- dims @@ -65,12 +63,16 @@ class BlockMatrixSuite extends HailSuite { l <- interestingPosInt s = math.sqrt(math.min(l, size)).toInt } yield (s, s), - element = element + element = element, ) - def twoMultipliableBlockMatrices(element: Gen[Double] = defaultElement): Gen[(BlockMatrix, BlockMatrix)] = for { + def twoMultipliableBlockMatrices(element: Gen[Double] = defaultElement) + : Gen[(BlockMatrix, BlockMatrix)] = for { Array(nRows, innerDim, nCols) <- nonEmptyNCubeOfVolumeAtMostSize(3) - blockSize <- interestingPosInt.filter(_ > 3) // 1 or 2 cause large numbers of partitions, leading to slow tests + blockSize <- + interestingPosInt.filter( + _ > 3 + ) // 1 or 2 cause large numbers of partitions, leading to slow tests l <- blockMatrixGen(const(blockSize), const(nRows -> innerDim), 
element) r <- blockMatrixGen(const(blockSize), const(innerDim -> nCols), element) } yield (l, r) @@ -80,15 +82,25 @@ class BlockMatrixSuite extends HailSuite { private val defaultRelTolerance = 1e-14 - private def sameDoubleMatrixNaNEqualsNaN(x: BDM[Double], y: BDM[Double], relTolerance: Double = defaultRelTolerance): Boolean = + private def sameDoubleMatrixNaNEqualsNaN( + x: BDM[Double], + y: BDM[Double], + relTolerance: Double = defaultRelTolerance, + ): Boolean = findDoubleMatrixMismatchNaNEqualsNaN(x, y, relTolerance) match { case Some(_) => false case None => true } - private def findDoubleMatrixMismatchNaNEqualsNaN(x: BDM[Double], y: BDM[Double], relTolerance: Double = defaultRelTolerance): Option[(Int, Int)] = { - assert(x.rows == y.rows && x.cols == y.cols, - s"dimension mismatch: ${ x.rows } x ${ x.cols } vs ${ y.rows } x ${ y.cols }") + private def findDoubleMatrixMismatchNaNEqualsNaN( + x: BDM[Double], + y: BDM[Double], + relTolerance: Double = defaultRelTolerance, + ): Option[(Int, Int)] = { + assert( + x.rows == y.rows && x.cols == y.cols, + s"dimension mismatch: ${x.rows} x ${x.cols} vs ${y.rows} x ${y.cols}", + ) var j = 0 while (j < x.cols) { var i = 0 @@ -96,7 +108,7 @@ class BlockMatrixSuite extends HailSuite { if (D_==(x(i, j) - y(i, j), relTolerance) && !(x(i, j).isNaN && y(i, j).isNaN)) { println(x.toString(1000, 1000)) println(y.toString(1000, 1000)) - println(s"inequality found at ($i, $j): ${ x(i, j) } and ${ y(i, j) }") + println(s"inequality found at ($i, $j): ${x(i, j)} and ${y(i, j)}") return Some((i, j)) } i += 1 @@ -107,51 +119,67 @@ class BlockMatrixSuite extends HailSuite { } @Test - def pointwiseSubtractCorrect() { - val m = toBM(4, 4, Array[Double]( - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16)) - - val expected = toLM(4, 4, Array[Double]( - 0, -3, -6, -9, - 3, 0, -3, -6, - 6, 3, 0, -3, - 9, 6, 3, 0)) + def pointwiseSubtractCorrect(): Unit = { + val m = toBM( + 4, + 4, + Array[Double]( + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16), + ) + + val expected = toLM( + 4, + 4, + Array[Double]( + 0, -3, -6, -9, + 3, 0, -3, -6, + 6, 3, 0, -3, + 9, 6, 3, 0), + ) val actual = (m - m.T).toBreezeMatrix() assert(actual == expected) } @Test - def multiplyByLocalMatrix() { - val ll = toLM(4, 4, Array[Double]( - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16)) + def multiplyByLocalMatrix(): Unit = { + val ll = toLM( + 4, + 4, + Array[Double]( + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16), + ) val l = toBM(ll) - val lr = toLM(4, 1, Array[Double]( + val lr = toLM( + 4, 1, - 2, - 3, - 4)) + Array[Double]( + 1, + 2, + 3, + 4, + ), + ) assert(ll * lr === l.dot(lr).toBreezeMatrix()) } @Test - def randomMultiplyByLocalMatrix() { + def randomMultiplyByLocalMatrix(): Unit = forAll(twoMultipliableDenseMatrices[Double]()) { case (ll, lr) => val l = toBM(ll) sameDoubleMatrixNaNEqualsNaN(ll * lr, l.dot(lr).toBreezeMatrix()) }.check() - } @Test - def multiplySameAsBreeze() { + def multiplySameAsBreeze(): Unit = { def randomLm(n: Int, m: Int) = denseMatrix[Double](n, m) forAll(randomLm(4, 4), randomLm(4, 4)) { (ll, lr) => @@ -182,55 +210,65 @@ class BlockMatrixSuite extends HailSuite { sameDoubleMatrixNaNEqualsNaN(l.dot(r).toBreezeMatrix(), ll * lr) }.check() - forAll(twoMultipliableDenseMatrices[Double](), interestingPosInt) { case ((ll, lr), blockSize) => - val l = toBM(ll, blockSize) - val r = toBM(lr, blockSize) + forAll(twoMultipliableDenseMatrices[Double](), interestingPosInt) { + case ((ll, lr), blockSize) => + val l 
= toBM(ll, blockSize) + val r = toBM(lr, blockSize) - sameDoubleMatrixNaNEqualsNaN(l.dot(r).toBreezeMatrix(), ll * lr) + sameDoubleMatrixNaNEqualsNaN(l.dot(r).toBreezeMatrix(), ll * lr) }.check() } @Test - def multiplySameAsBreezeRandomized() { - forAll(twoMultipliableBlockMatrices(nonExtremeDouble)) { case (l: BlockMatrix, r: BlockMatrix) => - val actual = l.dot(r).toBreezeMatrix() - val expected = l.toBreezeMatrix() * r.toBreezeMatrix() - - findDoubleMatrixMismatchNaNEqualsNaN(actual, expected) match { - case Some((i, j)) => - println(s"blockSize: ${ l.blockSize }") - println(s"${ l.toBreezeMatrix() }") - println(s"${ r.toBreezeMatrix() }") - println(s"row: ${ l.toBreezeMatrix()(i, ::) }") - println(s"col: ${ r.toBreezeMatrix()(::, j) }") - false - case None => - true - } + def multiplySameAsBreezeRandomized(): Unit = { + forAll(twoMultipliableBlockMatrices(nonExtremeDouble)) { + case (l: BlockMatrix, r: BlockMatrix) => + val actual = l.dot(r).toBreezeMatrix() + val expected = l.toBreezeMatrix() * r.toBreezeMatrix() + + findDoubleMatrixMismatchNaNEqualsNaN(actual, expected) match { + case Some((i, j)) => + println(s"blockSize: ${l.blockSize}") + println(s"${l.toBreezeMatrix()}") + println(s"${r.toBreezeMatrix()}") + println(s"row: ${l.toBreezeMatrix()(i, ::)}") + println(s"col: ${r.toBreezeMatrix()(::, j)}") + false + case None => + true + } }.check() } @Test - def rowwiseMultiplication() { - val l = toBM(4, 4, Array[Double]( - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16)) + def rowwiseMultiplication(): Unit = { + val l = toBM( + 4, + 4, + Array[Double]( + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16), + ) val v = Array[Double](1, 2, 3, 4) - val result = toLM(4, 4, Array[Double]( - 1, 4, 9, 16, - 5, 12, 21, 32, - 9, 20, 33, 48, - 13, 28, 45, 64)) + val result = toLM( + 4, + 4, + Array[Double]( + 1, 4, 9, 16, + 5, 12, 21, 32, + 9, 20, 33, 48, + 13, 28, 45, 64), + ) assert(l.rowVectorMul(v).toBreezeMatrix() == result) } @Test - def rowwiseMultiplicationRandom() { + def rowwiseMultiplicationRandom(): Unit = { val g = for { l <- blockMatrixGen() v <- buildableOfN[Array](l.nCols.toInt, arbitrary[Double]) @@ -247,26 +285,34 @@ class BlockMatrixSuite extends HailSuite { } @Test - def colwiseMultiplication() { - val l = toBM(4, 4, Array[Double]( - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16)) + def colwiseMultiplication(): Unit = { + val l = toBM( + 4, + 4, + Array[Double]( + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16), + ) val v = Array[Double](1, 2, 3, 4) - val result = toLM(4, 4, Array[Double]( - 1, 2, 3, 4, - 10, 12, 14, 16, - 27, 30, 33, 36, - 52, 56, 60, 64)) + val result = toLM( + 4, + 4, + Array[Double]( + 1, 2, 3, 4, + 10, 12, 14, 16, + 27, 30, 33, 36, + 52, 56, 60, 64), + ) assert(l.colVectorMul(v).toBreezeMatrix() == result) } @Test - def colwiseMultiplicationRandom() { + def colwiseMultiplicationRandom(): Unit = { val g = for { l <- blockMatrixGen() v <- buildableOfN[Array](l.nRows.toInt, arbitrary[Double]) @@ -281,57 +327,77 @@ class BlockMatrixSuite extends HailSuite { if (sameDoubleMatrixNaNEqualsNaN(actual, expected)) true else { - println(s"${ l.toBreezeMatrix().toArray.toSeq }\n*\n${ v.toSeq }") + println(s"${l.toBreezeMatrix().toArray.toSeq}\n*\n${v.toSeq}") false } }.check() } @Test - def colwiseAddition() { - val l = toBM(4, 4, Array[Double]( - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16)) + def colwiseAddition(): Unit = { + val l = toBM( + 4, + 4, + Array[Double]( + 1, 2, 3, 4, + 5, 6, 7, 8, + 
9, 10, 11, 12, + 13, 14, 15, 16), + ) val v = Array[Double](1, 2, 3, 4) - val result = toLM(4, 4, Array[Double]( - 2, 3, 4, 5, - 7, 8, 9, 10, - 12, 13, 14, 15, - 17, 18, 19, 20)) + val result = toLM( + 4, + 4, + Array[Double]( + 2, 3, 4, 5, + 7, 8, 9, 10, + 12, 13, 14, 15, + 17, 18, 19, 20), + ) assert(l.colVectorAdd(v).toBreezeMatrix() == result) } @Test - def rowwiseAddition() { - val l = toBM(4, 4, Array[Double]( - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16)) + def rowwiseAddition(): Unit = { + val l = toBM( + 4, + 4, + Array[Double]( + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16), + ) val v = Array[Double](1, 2, 3, 4) - val result = toLM(4, 4, Array[Double]( - 2, 4, 6, 8, - 6, 8, 10, 12, - 10, 12, 14, 16, - 14, 16, 18, 20)) + val result = toLM( + 4, + 4, + Array[Double]( + 2, 4, 6, 8, + 6, 8, 10, 12, + 10, 12, 14, 16, + 14, 16, 18, 20), + ) assert(l.rowVectorAdd(v).toBreezeMatrix() == result) } @Test - def diagonalTestTiny() { - val lm = toLM(3, 4, Array[Double]( - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12)) - + def diagonalTestTiny(): Unit = { + val lm = toLM( + 3, + 4, + Array[Double]( + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12), + ) + val m = toBM(lm, blockSize = 2) assert(m.diagonal().toSeq == Seq(1, 6, 11)) @@ -340,7 +406,7 @@ class BlockMatrixSuite extends HailSuite { } @Test - def diagonalTestRandomized() { + def diagonalTestRandomized(): Unit = { forAll(squareBlockMatrixGen()) { (m: BlockMatrix) => val lm = m.toBreezeMatrix() val diagonalLength = math.min(lm.rows, lm.cols) @@ -350,14 +416,14 @@ class BlockMatrixSuite extends HailSuite { true else { println(s"lm: $lm") - println(s"${ m.diagonal().toSeq } != ${ diagonal.toSeq }") + println(s"${m.diagonal().toSeq} != ${diagonal.toSeq}") false } }.check() } @Test - def fromLocalTest() { + def fromLocalTest(): Unit = { forAll(denseMatrix[Double]().flatMap { m => Gen.zip(Gen.const(m), Gen.choose(math.sqrt(m.rows).toInt, m.rows + 16)) }) { case (lm, blockSize) => @@ -367,12 +433,16 @@ class BlockMatrixSuite extends HailSuite { } @Test - def readWriteIdentityTrivial() { - val m = toBM(4, 4, Array[Double]( - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16)) + def readWriteIdentityTrivial(): Unit = { + val m = toBM( + 4, + 4, + Array[Double]( + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16), + ) val fname = ctx.createTmpPath("test") m.write(ctx, fname) @@ -384,12 +454,16 @@ class BlockMatrixSuite extends HailSuite { } @Test - def readWriteIdentityTrivialTransposed() { - val m = toBM(4, 4, Array[Double]( - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16)) + def readWriteIdentityTrivialTransposed(): Unit = { + val m = toBM( + 4, + 4, + Array[Double]( + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16), + ) val fname = ctx.createTmpPath("test") m.T.write(ctx, fname) @@ -401,17 +475,20 @@ class BlockMatrixSuite extends HailSuite { } @Test - def readWriteIdentityRandom() { + def readWriteIdentityRandom(): Unit = { forAll(blockMatrixGen()) { (m: BlockMatrix) => val fname = ctx.createTmpPath("test") m.write(ctx, fname) - assert(sameDoubleMatrixNaNEqualsNaN(m.toBreezeMatrix(), BlockMatrix.read(fs, fname).toBreezeMatrix())) + assert(sameDoubleMatrixNaNEqualsNaN( + m.toBreezeMatrix(), + BlockMatrix.read(fs, fname).toBreezeMatrix(), + )) true }.check() } @Test - def transpose() { + def transpose(): Unit = { forAll(blockMatrixGen()) { (m: BlockMatrix) => val transposed = m.toBreezeMatrix().t assert(transposed.rows == m.nCols) @@ -422,7 +499,7 @@ class BlockMatrixSuite 
extends HailSuite { } @Test - def doubleTransposeIsIdentity() { + def doubleTransposeIsIdentity(): Unit = { forAll(blockMatrixGen(element = nonExtremeDouble)) { (m: BlockMatrix) => val mt = m.T.cache() val mtt = m.T.T.cache() @@ -435,32 +512,33 @@ class BlockMatrixSuite extends HailSuite { } @Test - def cachedOpsOK() { - forAll(twoMultipliableBlockMatrices(nonExtremeDouble)) { case (l: BlockMatrix, r: BlockMatrix) => - l.cache() - r.cache() - - val actual = l.dot(r).toBreezeMatrix() - val expected = l.toBreezeMatrix() * r.toBreezeMatrix() - - if (!sameDoubleMatrixNaNEqualsNaN(actual, expected)) { - println(s"${ l.toBreezeMatrix() }") - println(s"${ r.toBreezeMatrix() }") - assert(false) - } + def cachedOpsOK(): Unit = { + forAll(twoMultipliableBlockMatrices(nonExtremeDouble)) { + case (l: BlockMatrix, r: BlockMatrix) => + l.cache() + r.cache() + + val actual = l.dot(r).toBreezeMatrix() + val expected = l.toBreezeMatrix() * r.toBreezeMatrix() + + if (!sameDoubleMatrixNaNEqualsNaN(actual, expected)) { + println(s"${l.toBreezeMatrix()}") + println(s"${r.toBreezeMatrix()}") + assert(false) + } - if (!sameDoubleMatrixNaNEqualsNaN(l.T.cache().T.toBreezeMatrix(), l.toBreezeMatrix())) { - println(s"${ l.T.cache().T.toBreezeMatrix() }") - println(s"${ l.toBreezeMatrix() }") - assert(false) - } + if (!sameDoubleMatrixNaNEqualsNaN(l.T.cache().T.toBreezeMatrix(), l.toBreezeMatrix())) { + println(s"${l.T.cache().T.toBreezeMatrix()}") + println(s"${l.toBreezeMatrix()}") + assert(false) + } - true + true }.check() } @Test - def toIRMToHBMIdentity() { + def toIRMToHBMIdentity(): Unit = { forAll(blockMatrixGen()) { (m: BlockMatrix) => val roundtrip = m.toIndexedRowMatrix().toHailBlockMatrix(m.blockSize) @@ -478,33 +556,52 @@ class BlockMatrixSuite extends HailSuite { } @Test - def map2RespectsTransposition() { - val lm = toLM(4, 2, Array[Double]( - 1, 2, - 3, 4, - 5, 6, - 7, 8)) - val lmt = toLM(2, 4, Array[Double]( - 1, 3, 5, 7, - 2, 4, 6, 8)) + def map2RespectsTransposition(): Unit = { + val lm = toLM( + 4, + 2, + Array[Double]( + 1, 2, + 3, 4, + 5, 6, + 7, 8), + ) + val lmt = toLM( + 2, + 4, + Array[Double]( + 1, 3, 5, 7, + 2, 4, 6, 8), + ) val m = toBM(lm) val mt = toBM(lmt) assert(m.map2(mt.T, _ + _).toBreezeMatrix() === lm + lm) - assert(mt.T.map2(m, _ + _).toBreezeMatrix() === lm + lm, s"${ mt.toBreezeMatrix() }\n${ mt.T.toBreezeMatrix() }\n${ m.toBreezeMatrix() }") + assert( + mt.T.map2(m, _ + _).toBreezeMatrix() === lm + lm, + s"${mt.toBreezeMatrix()}\n${mt.T.toBreezeMatrix()}\n${m.toBreezeMatrix()}", + ) } @Test - def map4RespectsTransposition() { - val lm = toLM(4, 2, Array[Double]( - 1, 2, - 3, 4, - 5, 6, - 7, 8)) - val lmt = toLM(2, 4, Array[Double]( - 1, 3, 5, 7, - 2, 4, 6, 8)) + def map4RespectsTransposition(): Unit = { + val lm = toLM( + 4, + 2, + Array[Double]( + 1, 2, + 3, 4, + 5, 6, + 7, 8), + ) + val lmt = toLM( + 2, + 4, + Array[Double]( + 1, 3, 5, 7, + 2, 4, 6, 8), + ) val m = toBM(lm) val mt = toBM(lmt) @@ -514,15 +611,23 @@ class BlockMatrixSuite extends HailSuite { } @Test - def mapRespectsTransposition() { - val lm = toLM(4, 2, Array[Double]( - 1, 2, - 3, 4, - 5, 6, - 7, 8)) - val lmt = toLM(2, 4, Array[Double]( - 1, 3, 5, 7, - 2, 4, 6, 8)) + def mapRespectsTransposition(): Unit = { + val lm = toLM( + 4, + 2, + Array[Double]( + 1, 2, + 3, 4, + 5, 6, + 7, 8), + ) + val lmt = toLM( + 2, + 4, + Array[Double]( + 1, 3, 5, 7, + 2, 4, 6, 8), + ) val m = toBM(lm) val mt = toBM(lmt) @@ -533,15 +638,23 @@ class BlockMatrixSuite extends HailSuite { } @Test - def 
mapWithIndexRespectsTransposition() { - val lm = toLM(4, 2, Array[Double]( - 1, 2, - 3, 4, - 5, 6, - 7, 8)) - val lmt = toLM(2, 4, Array[Double]( - 1, 3, 5, 7, - 2, 4, 6, 8)) + def mapWithIndexRespectsTransposition(): Unit = { + val lm = toLM( + 4, + 2, + Array[Double]( + 1, 2, + 3, 4, + 5, 6, + 7, 8), + ) + val lmt = toLM( + 2, + 4, + Array[Double]( + 1, 3, 5, 7, + 2, 4, 6, 8), + ) val m = toBM(lm) val mt = toBM(lmt) @@ -561,15 +674,23 @@ class BlockMatrixSuite extends HailSuite { } @Test - def map2WithIndexRespectsTransposition() { - val lm = toLM(4, 2, Array[Double]( - 1, 2, - 3, 4, - 5, 6, - 7, 8)) - val lmt = toLM(2, 4, Array[Double]( - 1, 3, 5, 7, - 2, 4, 6, 8)) + def map2WithIndexRespectsTransposition(): Unit = { + val lm = toLM( + 4, + 2, + Array[Double]( + 1, 2, + 3, 4, + 5, 6, + 7, 8), + ) + val lmt = toLM( + 2, + 4, + Array[Double]( + 1, 3, 5, 7, + 2, 4, 6, 8), + ) val m = toBM(lm) val mt = toBM(lmt) @@ -587,24 +708,28 @@ class BlockMatrixSuite extends HailSuite { 4.0 * lm.t) assert(mt.map2WithIndex(m.T, (i, j, x, y) => x + 2 * y + j * 2 + i + 1).toBreezeMatrix() === 4.0 * lm.t) - assert(mt.T.map2WithIndex(m.T.T, (i, j, x, y) => 3 * x + 5 * y + i * 2 + j + 1).toBreezeMatrix() === + assert(mt.T.map2WithIndex( + m.T.T, + (i, j, x, y) => 3 * x + 5 * y + i * 2 + j + 1, + ).toBreezeMatrix() === 9.0 * lm) } @Test - def filterCols() { + def filterCols(): Unit = { val lm = new BDM[Double](9, 10, (0 until 90).map(_.toDouble).toArray) - for {blockSize <- Seq(1, 2, 3, 5, 10, 11) - } { + for { blockSize <- Seq(1, 2, 3, 5, 10, 11) } { val bm = BlockMatrix.fromBreezeMatrix(lm, blockSize) - for {keep <- Seq( - Array(0), - Array(1), - Array(9), - Array(0, 3, 4, 5, 7), - Array(1, 4, 5, 7, 8, 9), - Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)) + for { + keep <- Seq( + Array(0), + Array(1), + Array(9), + Array(0, 3, 4, 5, 7), + Array(1, 4, 5, 7, 8, 9), + Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + ) } { val filteredViaBlock = bm.filterCols(keep.map(_.toLong)).toBreezeMatrix() val filteredViaBreeze = lm(::, keep.toFastSeq).copy @@ -615,17 +740,18 @@ class BlockMatrixSuite extends HailSuite { } @Test - def filterColsTranspose() { + def filterColsTranspose(): Unit = { val lm = new BDM[Double](9, 10, (0 until 90).map(_.toDouble).toArray) val lmt = lm.t - for {blockSize <- Seq(2, 3) - } { + for { blockSize <- Seq(2, 3) } { val bm = BlockMatrix.fromBreezeMatrix(lm, blockSize).transpose() - for {keep <- Seq( - Array(0), - Array(1, 4, 5, 7, 8), - Array(0, 1, 2, 3, 4, 5, 6, 7, 8)) + for { + keep <- Seq( + Array(0), + Array(1, 4, 5, 7, 8), + Array(0, 1, 2, 3, 4, 5, 6, 7, 8), + ) } { val filteredViaBlock = bm.filterCols(keep.map(_.toLong)).toBreezeMatrix() val filteredViaBreeze = lmt(::, keep.toFastSeq).copy @@ -636,16 +762,17 @@ class BlockMatrixSuite extends HailSuite { } @Test - def filterRows() { + def filterRows(): Unit = { val lm = new BDM[Double](9, 10, (0 until 90).map(_.toDouble).toArray) - for {blockSize <- Seq(2, 3) - } { + for { blockSize <- Seq(2, 3) } { val bm = BlockMatrix.fromBreezeMatrix(lm, blockSize) - for {keep <- Seq( - Array(0), - Array(1, 4, 5, 7, 8), - Array(0, 1, 2, 3, 4, 5, 6, 7, 8)) + for { + keep <- Seq( + Array(0), + Array(1, 4, 5, 7, 8), + Array(0, 1, 2, 3, 4, 5, 6, 7, 8), + ) } { val filteredViaBlock = bm.filterRows(keep.map(_.toLong)).toBreezeMatrix() val filteredViaBreeze = lm(keep.toFastSeq, ::).copy @@ -656,19 +783,20 @@ class BlockMatrixSuite extends HailSuite { } @Test - def filterSymmetric() { + def filterSymmetric(): Unit = { val lm = new BDM[Double](10, 10, (0 until 
100).map(_.toDouble).toArray) - for {blockSize <- Seq(1, 2, 3, 5, 10, 11) - } { + for { blockSize <- Seq(1, 2, 3, 5, 10, 11) } { val bm = BlockMatrix.fromBreezeMatrix(lm, blockSize) - for {keep <- Seq( - Array(0), - Array(1), - Array(9), - Array(0, 3, 4, 5, 7), - Array(1, 4, 5, 7, 8, 9), - Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)) + for { + keep <- Seq( + Array(0), + Array(1), + Array(9), + Array(0, 3, 4, 5, 7), + Array(1, 4, 5, 7, 8, 9), + Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + ) } { val filteredViaBlock = bm.filter(keep.map(_.toLong), keep.map(_.toLong)).toBreezeMatrix() val filteredViaBreeze = lm(keep.toFastSeq, keep.toFastSeq).copy @@ -679,23 +807,25 @@ class BlockMatrixSuite extends HailSuite { } @Test - def filter() { + def filter(): Unit = { val lm = new BDM[Double](9, 10, (0 until 90).map(_.toDouble).toArray) - for {blockSize <- Seq(1, 2, 3, 5, 10, 11) - } { + for { blockSize <- Seq(1, 2, 3, 5, 10, 11) } { val bm = BlockMatrix.fromBreezeMatrix(lm, blockSize) for { keepRows <- Seq( Array(1), Array(0, 3, 4, 5, 7), - Array(0, 1, 2, 3, 4, 5, 6, 7, 8)) + Array(0, 1, 2, 3, 4, 5, 6, 7, 8), + ) keepCols <- Seq( Array(2), Array(1, 4, 5, 7, 8, 9), - Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)) + Array(0, 1, 2, 3, 4, 5, 6, 7, 8, 9), + ) } { - val filteredViaBlock = bm.filter(keepRows.map(_.toLong), keepCols.map(_.toLong)).toBreezeMatrix() + val filteredViaBlock = + bm.filter(keepRows.map(_.toLong), keepCols.map(_.toLong)).toBreezeMatrix() val filteredViaBreeze = lm(keepRows.toFastSeq, keepCols.toFastSeq).copy assert(filteredViaBlock === filteredViaBreeze) @@ -704,10 +834,10 @@ class BlockMatrixSuite extends HailSuite { } @Test - def writeLocalAsBlockTest() { + def writeLocalAsBlockTest(): Unit = { val lm = new BDM[Double](10, 10, (0 until 100).map(_.toDouble).toArray) - for {blockSize <- Seq(1, 2, 3, 5, 10, 11)} { + for { blockSize <- Seq(1, 2, 3, 5, 10, 11) } { val fname = ctx.createTmpPath("test") lm.writeBlockMatrix(fs, fname, blockSize) assert(lm === BlockMatrix.read(fs, fname).toBreezeMatrix()) @@ -715,11 +845,15 @@ class BlockMatrixSuite extends HailSuite { } @Test - def randomTest() { - var lm1 = BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 1, gaussian = false).toBreezeMatrix() - var lm2 = BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 1, gaussian = false).toBreezeMatrix() - var lm3 = BlockMatrix.random(5, 10, 2, staticUID = 2, nonce = 1, gaussian = false).toBreezeMatrix() - var lm4 = BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 2, gaussian = false).toBreezeMatrix() + def randomTest(): Unit = { + var lm1 = + BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 1, gaussian = false).toBreezeMatrix() + var lm2 = + BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 1, gaussian = false).toBreezeMatrix() + var lm3 = + BlockMatrix.random(5, 10, 2, staticUID = 2, nonce = 1, gaussian = false).toBreezeMatrix() + var lm4 = + BlockMatrix.random(5, 10, 2, staticUID = 1, nonce = 2, gaussian = false).toBreezeMatrix() println(lm1) assert(lm1 === lm2) @@ -746,11 +880,11 @@ class BlockMatrixSuite extends HailSuite { val expectedEntries = data.map(x => ((x % 9).toLong, (x / 9).toLong, x)).toSet val expectedSignature = TStruct("i" -> TInt64, "j" -> TInt64, "entry" -> TFloat64) - for {blockSize <- Seq(1, 4, 10)} { + for { blockSize <- Seq(1, 4, 10) } { val entriesLiteral = TableLiteral(toBM(lm, blockSize).entriesTable(ctx), theHailClassLoader) assert(entriesLiteral.typ.rowType == expectedSignature) - val rows = CompileAndEvaluate[IndexedSeq[Row]](ctx, - GetField(TableCollect(entriesLiteral), "rows")) 
+ val rows = + CompileAndEvaluate[IndexedSeq[Row]](ctx, GetField(TableCollect(entriesLiteral), "rows")) val entries = rows.map(row => (row.get(0), row.get(1), row.get(2))).toSet // block size affects order of rows in table, but sets will be the same assert(entries === expectedEntries) @@ -769,11 +903,11 @@ class BlockMatrixSuite extends HailSuite { TableCollect( TableLiteral( bm.filterBlocks(Array(0, 1, 6)).entriesTable(ctx), - theHailClassLoader + theHailClassLoader, ) ), - "rows" - ) + "rows", + ), ) val expected = rows .sortBy(r => (r.get(0).asInstanceOf[Long], r.get(1).asInstanceOf[Long])) @@ -787,12 +921,12 @@ class BlockMatrixSuite extends HailSuite { val lm = new BDM[Double](2, 3, Array(0.0, 1.0, 4.0, 9.0, 16.0, 25.0)) val bm = BlockMatrix.fromBreezeMatrix(lm, blockSize = 2) val expected = new BDM[Double](2, 3, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) - + TestUtils.assertMatrixEqualityDouble(bm.pow(0.0).toBreezeMatrix(), BDM.fill(2, 3)(1.0)) TestUtils.assertMatrixEqualityDouble(bm.pow(0.5).toBreezeMatrix(), expected) TestUtils.assertMatrixEqualityDouble(bm.sqrt().toBreezeMatrix(), expected) } - + def filteredEquals(bm1: BlockMatrix, bm2: BlockMatrix): Boolean = bm1.blocks.collect() sameElements bm2.blocks.collect() @@ -806,8 +940,10 @@ class BlockMatrixSuite extends HailSuite { val onlyEightColEleven = onlyEight.filterCols(Array(11)).toBreezeMatrix() val onlyEightCornerFour = onlyEight.filter(Array(10, 11), Array(10, 11)).toBreezeMatrix() - assert(onlyEightRowEleven.toArray sameElements Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 131, 143).map(_.toDouble)) - assert(onlyEightColEleven.toArray sameElements Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 142, 143).map(_.toDouble)) + assert(onlyEightRowEleven.toArray sameElements Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 131, + 143).map(_.toDouble)) + assert(onlyEightColEleven.toArray sameElements Array(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 142, + 143).map(_.toDouble)) assert(onlyEightCornerFour == new BDM[Double](2, 2, Array(130.0, 131.0, 142.0, 143.0))) } @@ -816,7 +952,11 @@ class BlockMatrixSuite extends HailSuite { val lm = new BDM[Double](9, 12, (0 to 107).map(_.toDouble).toArray) val bm = toBM(lm, blockSize = 3) val sparse = bm.filterBand(0, 0, true) - assert(sparse.transpose().gp.partitionIndexToBlockIndex.get.toIndexedSeq == IndexedSeq(0, 5, 10)) + assert(sparse.transpose().gp.partitionIndexToBlockIndex.get.toIndexedSeq == IndexedSeq( + 0, + 5, + 10, + )) } @Test @@ -827,18 +967,25 @@ class BlockMatrixSuite extends HailSuite { val banded = bm.filterBand(0, 0, false) val rowFilt = banded.filterRows((0L until nRows.toLong by 2L).toArray) val summed = rowFilt.rowSum().toBreezeMatrix().toArray - val expected = Array.tabulate(nRows)(x => if (x % 2 == 0) 2.0 else 0) ++ Array.tabulate(nCols - nRows)(x => 0.0) + val expected = + Array.tabulate(nRows)(x => if (x % 2 == 0) 2.0 else 0) ++ Array.tabulate(nCols - nRows)(x => + 0.0 + ) assert(summed sameElements expected) } @Test - def testFilterBlocks() { - val lm = toLM(4, 4, Array( - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16)) - + def testFilterBlocks(): Unit = { + val lm = toLM( + 4, + 4, + Array( + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16), + ) + val bm = toBM(lm, blockSize = 2) val keepArray = Array( @@ -847,36 +994,45 @@ class BlockMatrixSuite extends HailSuite { Array(1, 3), Array(2, 3), Array(1, 2, 3), - Array(0, 1, 2, 3)) + Array(0, 1, 2, 3), + ) + + val localBlocks = + Array(lm(0 to 1, 0 to 1), lm(2 to 3, 0 to 1), lm(0 to 1, 2 to 3), lm(2 to 3, 2 to 3)) - val localBlocks = 
Array(lm(0 to 1, 0 to 1), lm(2 to 3, 0 to 1), lm(0 to 1, 2 to 3), lm(2 to 3, 2 to 3)) - for { keep <- keepArray } { val fbm = bm.filterBlocks(keep) - + assert(fbm.blocks.count() == keep.length) assert(fbm.blocks.collect().forall { case ((i, j), block) => - block == localBlocks(fbm.gp.coordinatesBlock(i, j)) } ) + block == localBlocks(fbm.gp.coordinatesBlock(i, j)) + }) } - + // test multiple block filters val bm13 = bm.filterBlocks(Array(1, 3)).cache() assert(filteredEquals(bm13, bm13.filterBlocks(Array(1, 3)))) assert(filteredEquals(bm13, bm.filterBlocks(Array(1, 2, 3)).filterBlocks(Array(0, 1, 3)))) assert(filteredEquals(bm13, bm13.filterBlocks(Array(0, 1, 2, 3)))) - assert(filteredEquals(bm.filterBlocks(Array(1)), - bm.filterBlocks(Array(1, 2, 3)).filterBlocks(Array(0, 1, 2)).filterBlocks(Array(0, 1, 3)))) + assert(filteredEquals( + bm.filterBlocks(Array(1)), + bm.filterBlocks(Array(1, 2, 3)).filterBlocks(Array(0, 1, 2)).filterBlocks(Array(0, 1, 3)), + )) } - + @Test - def testSparseBlockMatrixIO() { - val lm = toLM(4, 4, Array( - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16)) - + def testSparseBlockMatrixIO(): Unit = { + val lm = toLM( + 4, + 4, + Array( + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16), + ) + val bm = toBM(lm, blockSize = 2) val keepArray = Array( @@ -885,8 +1041,9 @@ class BlockMatrixSuite extends HailSuite { Array(1, 3), Array(2, 3), Array(1, 2, 3), - Array(0, 1, 2, 3)) - + Array(0, 1, 2, 3), + ) + val lm_zero = BDM.zeros[Double](2, 2) def filterBlocks(keep: Array[Int]): BDM[Double] = { @@ -903,14 +1060,14 @@ class BlockMatrixSuite extends HailSuite { for { keep <- keepArray } { val fbm = bm.filterBlocks(keep) val flm = filterBlocks(keep) - + assert(fbm.toBreezeMatrix() === flm) assert(flm === fbm.toIndexedRowMatrix().toHailBlockMatrix().toBreezeMatrix()) - + val fname = ctx.createTmpPath("test") fbm.write(ctx, fname, forceRowMajor = true) - + assert(RowMatrix.readBlockMatrix(fs, fname, 3).toBreezeMatrix() === flm) assert(filteredEquals(fbm, BlockMatrix.read(fs, fname))) @@ -918,22 +1075,26 @@ class BlockMatrixSuite extends HailSuite { } @Test - def testSparseBlockMatrixMathAndFilter() { - val lm = toLM(4, 4, Array( - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16)) + def testSparseBlockMatrixMathAndFilter(): Unit = { + val lm = toLM( + 4, + 4, + Array( + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16), + ) val bm = toBM(lm, blockSize = 2) - + val keepArray = Array( Array.empty[Int], Array(0), Array(1, 3), Array(2, 3), Array(1, 2, 3), - Array(0, 1, 2, 3) + Array(0, 1, 2, 3), ) val lm_zero = BDM.zeros[Double](2, 2) @@ -947,11 +1108,11 @@ class BlockMatrixSuite extends HailSuite { } flm } - + val transposeBI = Array(0 -> 0, 1 -> 2, 2 -> 1, 3 -> 3).toMap val v = Array(1.0, 2.0, 3.0, 4.0) - + // test transpose, diagonal, math ops, filter ops for { keep <- keepArray } { println(s"Test says keep block: ${keep.toIndexedSeq}") @@ -961,7 +1122,9 @@ class BlockMatrixSuite extends HailSuite { assert(filteredEquals(fbm.transpose().transpose(), fbm)) assert(filteredEquals( - fbm.transpose(), bm.transpose().filterBlocks(keep.map(transposeBI).sorted))) + fbm.transpose(), + bm.transpose().filterBlocks(keep.map(transposeBI).sorted), + )) assert(fbm.diagonal() sameElements diag(fbm.toBreezeMatrix()).toArray) @@ -999,86 +1162,94 @@ class BlockMatrixSuite extends HailSuite { assert(fbm.colVectorAdd(v).toBreezeMatrix() === flm(::, *) + BDV(v)) assert(fbm.colVectorSub(v).toBreezeMatrix() === flm(::, *) - BDV(v)) 
assert(fbm.reverseColVectorSub(v).toBreezeMatrix() === -(flm(::, *) - BDV(v))) - + // filter ops assert(fbm.filterRows(Array(1, 2)).toBreezeMatrix() === flm(1 to 2, ::)) assert(fbm.filterCols(Array(1, 2)).toBreezeMatrix() === flm(::, 1 to 2)) assert(fbm.filter(Array(1, 2), Array(1, 2)).toBreezeMatrix() === flm(1 to 2, 1 to 2)) } - + val bm0 = bm.filterBlocks(Array(0)) val bm13 = bm.filterBlocks(Array(1, 3)) val bm23 = bm.filterBlocks(Array(2, 3)) val bm123 = bm.filterBlocks(Array(1, 2, 3)) - + val lm0 = filterBlocks(Array(0)) val lm13 = filterBlocks(Array(1, 3)) val lm23 = filterBlocks(Array(2, 3)) val lm123 = filterBlocks(Array(1, 2, 3)) - + // test +/- with mismatched blocks assert(filteredEquals(bm0 + bm13, bm.filterBlocks(Array(0, 1, 3)))) - + assert((bm0 + bm).toBreezeMatrix() === lm0 + lm) assert((bm + bm0).toBreezeMatrix() === lm + lm0) assert( - (bm0 + 2.0 * bm13 + 3.0 * bm23 + 5.0 * bm123).toBreezeMatrix() === - lm0 + 2.0 * lm13 + 3.0 * lm23 + 5.0 * lm123) + (bm0 + 2.0 * bm13 + 3.0 * bm23 + 5.0 * bm123).toBreezeMatrix() === + lm0 + 2.0 * lm13 + 3.0 * lm23 + 5.0 * lm123 + ) assert( - (bm123 + 2.0 * bm13 + 3.0 * bm23 + 5.0 * bm0).toBreezeMatrix() === - lm123 + 2.0 * lm13 + 3.0 * lm23 + 5.0 * lm0) + (bm123 + 2.0 * bm13 + 3.0 * bm23 + 5.0 * bm0).toBreezeMatrix() === + lm123 + 2.0 * lm13 + 3.0 * lm23 + 5.0 * lm0 + ) - assert((bm0 - bm).toBreezeMatrix() === lm0 - lm) assert((bm - bm0).toBreezeMatrix() === lm - lm0) assert( - (bm0 - 2.0 * bm13 - 3.0 * bm23 - 5.0 * bm123).toBreezeMatrix() === - lm0 - 2.0 * lm13 - 3.0 * lm23 - 5.0 * lm123) + (bm0 - 2.0 * bm13 - 3.0 * bm23 - 5.0 * bm123).toBreezeMatrix() === + lm0 - 2.0 * lm13 - 3.0 * lm23 - 5.0 * lm123 + ) assert( - (bm123 - 2.0 * bm13 - 3.0 * bm23 - 5.0 * bm0).toBreezeMatrix() === - lm123 - 2.0 * lm13 - 3.0 * lm23 - 5.0 * lm0) - + (bm123 - 2.0 * bm13 - 3.0 * bm23 - 5.0 * bm0).toBreezeMatrix() === + lm123 - 2.0 * lm13 - 3.0 * lm23 - 5.0 * lm0 + ) + // test * with mismatched blocks - assert(filteredEquals(bm0 * bm13, bm.filterBlocks(Array.empty[Int]))) + assert(filteredEquals(bm0 * bm13, bm.filterBlocks(Array.empty[Int]))) assert(filteredEquals(bm13 * bm23, (bm * bm).filterBlocks(Array(3)))) assert(filteredEquals(bm13 * bm, (bm * bm).filterBlocks(Array(1, 3)))) assert(filteredEquals(bm * bm13, (bm * bm).filterBlocks(Array(1, 3)))) // test unsupported ops val notSupported: String = "not supported for block-sparse matrices" - + val v0 = Array(0.0, Double.NaN, Double.PositiveInfinity, Double.NegativeInfinity) - - TestUtils.interceptFatal(notSupported) { bm0 / bm0 } - TestUtils.interceptFatal(notSupported) { bm0.reverseRowVectorDiv(v) } - TestUtils.interceptFatal(notSupported) { bm0.reverseColVectorDiv(v) } - TestUtils.interceptFatal(notSupported) { 1 / bm0 } - - TestUtils.interceptFatal(notSupported) { bm0.rowVectorDiv(v0) } - TestUtils.interceptFatal(notSupported) { bm0.colVectorDiv(v0) } - TestUtils.interceptFatal("multiplication by scalar NaN") { bm0 * Double.NaN } - TestUtils.interceptFatal("division by scalar 0.0") { bm0 / 0 } - - TestUtils.interceptFatal(notSupported) { bm0.pow(-1)} + + TestUtils.interceptFatal(notSupported)(bm0 / bm0) + TestUtils.interceptFatal(notSupported)(bm0.reverseRowVectorDiv(v)) + TestUtils.interceptFatal(notSupported)(bm0.reverseColVectorDiv(v)) + TestUtils.interceptFatal(notSupported)(1 / bm0) + + TestUtils.interceptFatal(notSupported)(bm0.rowVectorDiv(v0)) + TestUtils.interceptFatal(notSupported)(bm0.colVectorDiv(v0)) + TestUtils.interceptFatal("multiplication by scalar NaN")(bm0 * Double.NaN) + 
TestUtils.interceptFatal("division by scalar 0.0")(bm0 / 0) + + TestUtils.interceptFatal(notSupported)(bm0.pow(-1)) } - + @Test - def testRealizeBlocks() { - val lm = toLM(4, 4, Array( - 1, 2, 3, 4, - 5, 6, 7, 8, - 9, 10, 11, 12, - 13, 14, 15, 16)) + def testRealizeBlocks(): Unit = { + val lm = toLM( + 4, + 4, + Array( + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16), + ) val bm = toBM(lm, blockSize = 2) - + val keepArray = Array( Array.empty[Int], Array(0), Array(1, 3), Array(2, 3), Array(1, 2, 3), - Array(0, 1, 2, 3)) + Array(0, 1, 2, 3), + ) val lm_zero = BDM.zeros[Double](2, 2) @@ -1091,20 +1262,22 @@ class BlockMatrixSuite extends HailSuite { } flm } - + assert(filteredEquals(bm.densify(), bm)) assert(filteredEquals(bm.realizeBlocks(None), bm)) for { keep <- keepArray } { val fbm = bm.filterBlocks(keep) val flm = filterBlocks(keep) - + assert(filteredEquals(fbm.densify(), toBM(flm, blockSize = 2))) assert(filteredEquals(fbm.realizeBlocks(Some(keep)), fbm)) - + val bis = (keep ++ Array(0, 2)).distinct.sorted - assert(filteredEquals(fbm.realizeBlocks(Some(Array(0, 2))), - toBM(flm, blockSize = 2).filterBlocks(bis))) + assert(filteredEquals( + fbm.realizeBlocks(Some(Array(0, 2))), + toBM(flm, blockSize = 2).filterBlocks(bis), + )) } } } diff --git a/hail/src/test/scala/is/hail/linalg/GridPartitionerSuite.scala b/hail/src/test/scala/is/hail/linalg/GridPartitionerSuite.scala index 79809b9579f..ec1dd69e38d 100644 --- a/hail/src/test/scala/is/hail/linalg/GridPartitionerSuite.scala +++ b/hail/src/test/scala/is/hail/linalg/GridPartitionerSuite.scala @@ -5,9 +5,9 @@ import org.testng.annotations.Test class GridPartitionerSuite extends TestNGSuite { - private def assertLayout(hg: GridPartitioner, layout: ((Int, Int), Int)*) { + private def assertLayout(hg: GridPartitioner, layout: ((Int, Int), Int)*): Unit = { layout.foreach { case ((i, j), p) => - assert(hg.coordinatesBlock(i, j) === p, s"at coordinates ${ (i, j) }") + assert(hg.coordinatesBlock(i, j) === p, s"at coordinates ${(i, j)}") } layout.foreach { case ((i, j), p) => assert(hg.blockCoordinates(p) === ((i, j)), s"at pid $p") @@ -15,41 +15,37 @@ class GridPartitionerSuite extends TestNGSuite { } @Test - def squareIsColumnMajor() { - assertLayout(GridPartitioner(2, 4, 4), - (0, 0) -> 0, - (1, 0) -> 1, - (0, 1) -> 2, - (1, 1) -> 3 - ) - } + def squareIsColumnMajor(): Unit = + assertLayout(GridPartitioner(2, 4, 4), (0, 0) -> 0, (1, 0) -> 1, (0, 1) -> 2, (1, 1) -> 3) @Test - def rectangleMoreRowsIsColumnMajor() { - assertLayout(GridPartitioner(2, 6, 4), + def rectangleMoreRowsIsColumnMajor(): Unit = { + assertLayout( + GridPartitioner(2, 6, 4), (0, 0) -> 0, (1, 0) -> 1, (2, 0) -> 2, (0, 1) -> 3, (1, 1) -> 4, - (2, 1) -> 5 + (2, 1) -> 5, ) } @Test - def rectangleMoreColsIsColumnMajor() { - assertLayout(GridPartitioner(2, 4, 6), + def rectangleMoreColsIsColumnMajor(): Unit = { + assertLayout( + GridPartitioner(2, 4, 6), (0, 0) -> 0, (1, 0) -> 1, (0, 1) -> 2, (1, 1) -> 3, (0, 2) -> 4, - (1, 2) -> 5 + (1, 2) -> 5, ) } @Test - def bandedBlocksTest() { + def bandedBlocksTest(): Unit = { // 0 3 6 9 // 1 4 7 10 // 2 5 8 11 @@ -77,9 +73,9 @@ class GridPartitionerSuite extends TestNGSuite { assert(gp.bandBlocks(-1000, 1000) sameElements (0 until 12)) } } - + @Test - def rectangularBlocksTest() { + def rectangularBlocksTest(): Unit = { // 0 3 6 9 // 1 4 7 10 // 2 5 8 11 @@ -94,11 +90,15 @@ class GridPartitionerSuite extends TestNGSuite { assert(gp.rectangleBlocks(Array(9, 11, 9, 11)) sameElements Array(0, 1, 3, 4)) 
assert(gp.rectanglesBlocks(Array(Array(9, 11, 9, 11))) sameElements Array(0, 1, 3, 4)) - + assert(gp.rectangleBlocks(Array(10, 20, 10, 30)) sameElements Array(4, 7)) assert(gp.rectanglesBlocks(Array( - Array(9, 11, 9, 11), Array(10, 20, 10, 30), Array(0, 1, 20, 21), Array(20, 21, 20, 31))) + Array(9, 11, 9, 11), + Array(10, 20, 10, 30), + Array(0, 1, 20, 21), + Array(20, 21, 20, 31), + )) sameElements Array(0, 1, 3, 4, 6, 7, 8, 11)) assert(gp.rectangleBlocks(Array(0, 21, 0, 31)) sameElements (0 until 12)) diff --git a/hail/src/test/scala/is/hail/linalg/RowMatrixSuite.scala b/hail/src/test/scala/is/hail/linalg/RowMatrixSuite.scala index 817b5b904e6..4e0d2f04ce1 100644 --- a/hail/src/test/scala/is/hail/linalg/RowMatrixSuite.scala +++ b/hail/src/test/scala/is/hail/linalg/RowMatrixSuite.scala @@ -1,61 +1,69 @@ package is.hail.linalg - -import breeze.linalg.DenseMatrix + import is.hail.HailSuite import is.hail.check.Gen import is.hail.utils._ + +import breeze.linalg.DenseMatrix import org.testng.annotations.Test class RowMatrixSuite extends HailSuite { - private def rowArrayToRowMatrix(a: Array[Array[Double]], nPartitions: Int = sc.defaultParallelism): RowMatrix = { + private def rowArrayToRowMatrix(a: Array[Array[Double]], nPartitions: Int = sc.defaultParallelism) + : RowMatrix = { require(a.length > 0) val nRows = a.length val nCols = a(0).length - - RowMatrix(sc.parallelize(a.zipWithIndex.map { case (row, i) => (i.toLong, row) }, nPartitions), nCols, nRows) + + RowMatrix( + sc.parallelize(a.zipWithIndex.map { case (row, i) => (i.toLong, row) }, nPartitions), + nCols, + nRows, + ) } - + private def rowArrayToLocalMatrix(a: Array[Array[Double]]): DenseMatrix[Double] = { require(a.length > 0) val nRows = a.length val nCols = a(0).length - + new DenseMatrix[Double](nRows, nCols, a.flatten, 0, nCols, isTranspose = true) } - + @Test - def localizeRowMatrix() { + def localizeRowMatrix(): Unit = { val fname = ctx.createTmpPath("test") - + val rowArrays = Array( Array(1.0, 2.0, 3.0), - Array(4.0, 5.0, 6.0)) + Array(4.0, 5.0, 6.0), + ) val rowMatrix = rowArrayToRowMatrix(rowArrays) val localMatrix = rowArrayToLocalMatrix(rowArrays) - + BlockMatrix.fromBreezeMatrix(localMatrix).write(ctx, fname) - + assert(rowMatrix.toBreezeMatrix() === localMatrix) } @Test - def readBlockSmall() { + def readBlockSmall(): Unit = { val fname = ctx.createTmpPath("test") - + val localMatrix = DenseMatrix( Array(1.0, 2.0, 3.0), - Array(4.0, 5.0, 6.0)) - + Array(4.0, 5.0, 6.0), + ) + BlockMatrix.fromBreezeMatrix(localMatrix).write(ctx, fname, forceRowMajor = true) - + val rowMatrixFromBlock = RowMatrix.readBlockMatrix(fs, fname, 1) - + assert(rowMatrixFromBlock.toBreezeMatrix() == localMatrix) } - + @Test - def readBlock() { + def readBlock(): Unit = { val fname = ctx.createTmpPath("test") val lm = Gen.denseMatrix[Double](9, 10).sample() @@ -63,159 +71,363 @@ class RowMatrixSuite extends HailSuite { blockSize <- Seq(1, 2, 3, 4, 6, 7, 9, 10) partSize <- Seq(1, 2, 4, 9, 11) } { - BlockMatrix.fromBreezeMatrix(lm, blockSize).write(ctx, fname, overwrite = true, forceRowMajor = true) + BlockMatrix.fromBreezeMatrix(lm, blockSize).write( + ctx, + fname, + overwrite = true, + forceRowMajor = true, + ) val rowMatrix = RowMatrix.readBlockMatrix(fs, fname, partSize) - + assert(rowMatrix.toBreezeMatrix() === lm) } } - + private def readCSV(fname: String): Array[Array[Double]] = - fs.readLines(fname)( it => + fs.readLines(fname)(it => it.map(_.value) .map(_.split(",").map(_.toDouble)) .toArray[Array[Double]] ) - private def 
exportImportAssert(export: (String) => Unit, expected: Array[Double]*) { + private def exportImportAssert(export: (String) => Unit, expected: Array[Double]*): Unit = { val fname = ctx.createTmpPath("test") export(fname) assert(readCSV(fname) === expected.toArray[Array[Double]]) } @Test - def exportWithIndex() { + def exportWithIndex(): Unit = { val rowArrays = Array( Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0), - Array(7.0, 8.0, 9.0)) + Array(7.0, 8.0, 9.0), + ) val rowMatrix = rowArrayToRowMatrix(rowArrays, nPartitions = 2) val rowArraysWithIndex = Array( Array(0.0, 1.0, 2.0, 3.0), Array(1.0, 4.0, 5.0, 6.0), - Array(2.0, 7.0, 8.0, 9.0)) + Array(2.0, 7.0, 8.0, 9.0), + ) - exportImportAssert(rowMatrix.export(ctx, _, ",", header = None, addIndex = true, exportType = ExportType.CONCATENATED), - rowArraysWithIndex: _*) + exportImportAssert( + rowMatrix.export( + ctx, + _, + ",", + header = None, + addIndex = true, + exportType = ExportType.CONCATENATED, + ), + rowArraysWithIndex: _* + ) } @Test - def exportSquare() { + def exportSquare(): Unit = { val rowArrays = Array( Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0), - Array(7.0, 8.0, 9.0)) + Array(7.0, 8.0, 9.0), + ) val rowMatrix = rowArrayToRowMatrix(rowArrays) - exportImportAssert(rowMatrix.export(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), - rowArrays: _*) + exportImportAssert( + rowMatrix.export( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), + rowArrays: _* + ) - exportImportAssert(rowMatrix.exportLowerTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), + exportImportAssert( + rowMatrix.exportLowerTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), Array(1.0), Array(4.0, 5.0), - Array(7.0, 8.0, 9.0)) + Array(7.0, 8.0, 9.0), + ) - exportImportAssert(rowMatrix.exportStrictLowerTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), + exportImportAssert( + rowMatrix.exportStrictLowerTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), Array(4.0), - Array(7.0, 8.0)) + Array(7.0, 8.0), + ) - exportImportAssert(rowMatrix.exportUpperTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), + exportImportAssert( + rowMatrix.exportUpperTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), Array(1.0, 2.0, 3.0), Array(5.0, 6.0), - Array(9.0)) + Array(9.0), + ) - exportImportAssert(rowMatrix.exportStrictUpperTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), + exportImportAssert( + rowMatrix.exportStrictUpperTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), Array(2.0, 3.0), - Array(6.0)) + Array(6.0), + ) } - + @Test - def exportWide() { + def exportWide(): Unit = { val rowArrays = Array( Array(1.0, 2.0, 3.0), - Array(4.0, 5.0, 6.0)) + Array(4.0, 5.0, 6.0), + ) val rowMatrix = rowArrayToRowMatrix(rowArrays) - exportImportAssert(rowMatrix.export(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), - rowArrays: _*) + exportImportAssert( + rowMatrix.export( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), + rowArrays: _* + ) - exportImportAssert(rowMatrix.exportLowerTriangle(ctx, _, ",", header=None, 
addIndex = false, exportType = ExportType.CONCATENATED), + exportImportAssert( + rowMatrix.exportLowerTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), Array(1.0), - Array(4.0, 5.0)) + Array(4.0, 5.0), + ) - exportImportAssert(rowMatrix.exportStrictLowerTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), - Array(4.0)) - - exportImportAssert(rowMatrix.exportUpperTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), + exportImportAssert( + rowMatrix.exportStrictLowerTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), + Array(4.0), + ) + + exportImportAssert( + rowMatrix.exportUpperTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), Array(1.0, 2.0, 3.0), - Array(5.0, 6.0)) - - exportImportAssert(rowMatrix.exportStrictUpperTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), + Array(5.0, 6.0), + ) + + exportImportAssert( + rowMatrix.exportStrictUpperTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), Array(2.0, 3.0), - Array(6.0)) + Array(6.0), + ) } - + @Test - def exportTall() { + def exportTall(): Unit = { val rowArrays = Array( Array(1.0, 2.0), Array(4.0, 5.0), - Array(7.0, 8.0)) + Array(7.0, 8.0), + ) val rowMatrix = rowArrayToRowMatrix(rowArrays) - exportImportAssert(rowMatrix.export(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), - rowArrays: _*) + exportImportAssert( + rowMatrix.export( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), + rowArrays: _* + ) - exportImportAssert(rowMatrix.exportLowerTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), + exportImportAssert( + rowMatrix.exportLowerTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), Array(1.0), Array(4.0, 5.0), - Array(7.0, 8.0)) + Array(7.0, 8.0), + ) - exportImportAssert(rowMatrix.exportStrictLowerTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), + exportImportAssert( + rowMatrix.exportStrictLowerTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), Array(4.0), - Array(7.0, 8.0)) + Array(7.0, 8.0), + ) - exportImportAssert(rowMatrix.exportUpperTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), + exportImportAssert( + rowMatrix.exportUpperTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), Array(1.0, 2.0), - Array(5.0)) - - exportImportAssert(rowMatrix.exportStrictUpperTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), - Array(2.0)) - } + Array(5.0), + ) + + exportImportAssert( + rowMatrix.exportStrictUpperTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), + Array(2.0), + ) + } @Test - def exportBig() { - val rowArrays: Array[Array[Double]] = Array.tabulate(20)( r => Array.tabulate(30)(c => 30 * c + r)) + def exportBig(): Unit = { + val rowArrays: Array[Array[Double]] = + Array.tabulate(20)(r => Array.tabulate(30)(c => 30 * c + r)) val rowMatrix = rowArrayToRowMatrix(rowArrays) - - 
exportImportAssert(rowMatrix.export(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), - rowArrays: _*) - exportImportAssert(rowMatrix.exportLowerTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), + exportImportAssert( + rowMatrix.export( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), + rowArrays: _* + ) + + exportImportAssert( + rowMatrix.exportLowerTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), rowArrays.zipWithIndex .map { case (a, i) => - a.zipWithIndex.filter { case (_, j) => j <= i }.map(_._1) } - .toArray[Array[Double]]:_*) - - exportImportAssert(rowMatrix.exportStrictLowerTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), + a.zipWithIndex.filter { case (_, j) => j <= i }.map(_._1) + } + .toArray[Array[Double]]: _* + ) + + exportImportAssert( + rowMatrix.exportStrictLowerTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), rowArrays.zipWithIndex .map { case (a, i) => - a.zipWithIndex.filter { case (_, j) => j < i }.map(_._1) } + a.zipWithIndex.filter { case (_, j) => j < i }.map(_._1) + } .filter(_.nonEmpty) - .toArray[Array[Double]]:_*) + .toArray[Array[Double]]: _* + ) - exportImportAssert(rowMatrix.exportUpperTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), + exportImportAssert( + rowMatrix.exportUpperTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), rowArrays.zipWithIndex .map { case (a, i) => - a.zipWithIndex.filter { case (_, j) => j >= i }.map(_._1) } - .toArray[Array[Double]]:_*) + a.zipWithIndex.filter { case (_, j) => j >= i }.map(_._1) + } + .toArray[Array[Double]]: _* + ) - exportImportAssert(rowMatrix.exportStrictUpperTriangle(ctx, _, ",", header=None, addIndex = false, exportType = ExportType.CONCATENATED), + exportImportAssert( + rowMatrix.exportStrictUpperTriangle( + ctx, + _, + ",", + header = None, + addIndex = false, + exportType = ExportType.CONCATENATED, + ), rowArrays.zipWithIndex .map { case (a, i) => - a.zipWithIndex.filter { case (_, j) => j > i }.map(_._1) } + a.zipWithIndex.filter { case (_, j) => j > i }.map(_._1) + } .filter(_.nonEmpty) - .toArray[Array[Double]]:_*) + .toArray[Array[Double]]: _* + ) } -} \ No newline at end of file +} diff --git a/hail/src/test/scala/is/hail/linalg/RowPartitionerSuite.scala b/hail/src/test/scala/is/hail/linalg/RowPartitionerSuite.scala index 7d44b7f5567..a564d92daf4 100644 --- a/hail/src/test/scala/is/hail/linalg/RowPartitionerSuite.scala +++ b/hail/src/test/scala/is/hail/linalg/RowPartitionerSuite.scala @@ -1,13 +1,14 @@ package is.hail.linalg -import is.hail.check.Arbitrary.arbitrary import is.hail.check.{Gen, Prop} +import is.hail.check.Arbitrary.arbitrary + import org.scalatest.testng.TestNGSuite import org.testng.annotations.Test class RowPartitionerSuite extends TestNGSuite { @Test - def testGetPartition() { + def testGetPartition(): Unit = { val partitionStarts = Array[Long](0, 0, 0, 4, 5, 5, 8, 10, 10) val partitionCounts = Array(0, 0, 4, 1, 0, 3, 2, 0) val keyPart = partitionCounts.zipWithIndex.flatMap { case (count, pi) => Array.fill(count)(pi) } @@ -16,9 +17,9 @@ class RowPartitionerSuite extends TestNGSuite { assert(rp.numPartitions == 8) assert((0 until 10).forall(i => keyPart(i) == rp.getPartition(i.toLong))) } - - @Test 
def testFindInterval() { - def naiveFindInterval(a: Array[Long], key: Long): Int = { + + @Test def testFindInterval(): Unit = { + def naiveFindInterval(a: Array[Long], key: Long): Int = { if (a.length == 0 || key < a(0)) -1 else if (key >= a(a.length - 1)) @@ -32,16 +33,15 @@ class RowPartitionerSuite extends TestNGSuite { } val moreKeys = Array(Long.MinValue, -1000L, -1L, 0L, 1L, 1000L, Long.MaxValue) - + val g = for { a0 <- Gen.buildableOf[Array](arbitrary[Long]) a = a0.sorted } yield { val len = a.length - for {key <- a ++ moreKeys} { + for { key <- a ++ moreKeys } if (key > a(0) && key < a(len - 1)) assert(RowPartitioner.findInterval(a, key) == naiveFindInterval(a, key)) - } true } Prop.forAll(g).check() diff --git a/hail/src/test/scala/is/hail/lir/CompileTimeRequirednessSuite.scala b/hail/src/test/scala/is/hail/lir/CompileTimeRequirednessSuite.scala index f79c26ce7c9..4faa96ff20d 100644 --- a/hail/src/test/scala/is/hail/lir/CompileTimeRequirednessSuite.scala +++ b/hail/src/test/scala/is/hail/lir/CompileTimeRequirednessSuite.scala @@ -2,7 +2,7 @@ package is.hail.lir import is.hail.HailSuite import is.hail.asm4s._ -import is.hail.utils._ + import org.testng.annotations.Test class CompileTimeRequirednessSuite extends HailSuite { diff --git a/hail/src/test/scala/is/hail/lir/LIRSplitSuite.scala b/hail/src/test/scala/is/hail/lir/LIRSplitSuite.scala index 61d2d54efa6..060ceb1157c 100644 --- a/hail/src/test/scala/is/hail/lir/LIRSplitSuite.scala +++ b/hail/src/test/scala/is/hail/lir/LIRSplitSuite.scala @@ -1,13 +1,14 @@ package is.hail.lir import is.hail.HailSuite -import is.hail.expr.ir.{EmitFunctionBuilder, ParamType} import is.hail.asm4s._ +import is.hail.expr.ir.{EmitFunctionBuilder, ParamType} + import org.testng.annotations.Test class LIRSplitSuite extends HailSuite { - @Test def testSplitPreservesParameterMutation() { + @Test def testSplitPreservesParameterMutation(): Unit = { val f = EmitFunctionBuilder[Unit](ctx, "F") f.emitWithBuilder { cb => val mb = f.newEmitMethod("m", IndexedSeq[ParamType](typeInfo[Long]), typeInfo[Unit]) @@ -15,11 +16,9 @@ class LIRSplitSuite extends HailSuite { val arg = mb.getCodeParam[Long](1) cb.assign(arg, 1000L) - (0 until 1000).foreach { i => - cb.if_(arg.cne(1000L), cb._fatal(s"bad split at $i!")) - } + (0 until 1000).foreach(i => cb.if_(arg.cne(1000L), cb._fatal(s"bad split at $i!"))) } - cb.invokeVoid(mb, const(1L)) + cb.invokeVoid(mb, cb.this_, const(1L)) Code._empty } f.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r)() diff --git a/hail/src/test/scala/is/hail/methods/ExprSuite.scala b/hail/src/test/scala/is/hail/methods/ExprSuite.scala index 44d67162f97..bc16f8b169b 100644 --- a/hail/src/test/scala/is/hail/methods/ExprSuite.scala +++ b/hail/src/test/scala/is/hail/methods/ExprSuite.scala @@ -5,9 +5,11 @@ import is.hail.backend.HailStateManager import is.hail.check.Prop._ import is.hail.check.Properties import is.hail.expr._ -import is.hail.types.virtual.{TFloat64, TInt32, Type} import is.hail.expr.ir.IRParser +import is.hail.types.virtual._ import is.hail.utils.StringEscapeUtils._ + +import org.apache.spark.sql.Row import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.testng.annotations.Test @@ -16,7 +18,7 @@ class ExprSuite extends HailSuite { def sm: HailStateManager = ctx.stateManager - @Test def testTypePretty() { + @Test def testTypePretty(): Unit = { // for arbType val sb = new StringBuilder @@ -41,15 +43,13 @@ class ExprSuite extends HailSuite { }) } - @Test def testEscaping() { - val p = forAll { (s: 
String) => - s == unescapeString(escapeString(s)) - } + @Test def testEscaping(): Unit = { + val p = forAll((s: String) => s == unescapeString(escapeString(s))) p.check() } - @Test def testEscapingSimple() { + @Test def testEscapingSimple(): Unit = { // a == 0x61, _ = 0x5f assert(escapeStringSimple("abc", '_', _ => false) == "abc") assert(escapeStringSimple("abc", '_', _ == 'a') == "_61bc") @@ -62,16 +62,43 @@ class ExprSuite extends HailSuite { assert(unescapeStringSimple("my name is _u540d_u8c26", '_') == "my name is 名谦") val p = forAll { (s: String) => - s == unescapeStringSimple(escapeStringSimple(s, '_', _.isLetterOrDigit, _.isLetterOrDigit), '_') + s == unescapeStringSimple( + escapeStringSimple(s, '_', _.isLetterOrDigit, _.isLetterOrDigit), + '_', + ) } p.check() } - @Test def testImpexes() { + @Test def testImportEmptyJSONObjectAsStruct(): Unit = + assert(JSONAnnotationImpex.importAnnotation(parse("{}"), TStruct()) == Row()) + + @Test def testExportEmptyJSONObjectAsStruct(): Unit = + assert(compact(render(JSONAnnotationImpex.exportAnnotation(Row(), TStruct()))) == "{}") + + @Test def testRoundTripEmptyJSONObject(): Unit = { + val actual = JSONAnnotationImpex.exportAnnotation( + JSONAnnotationImpex.importAnnotation(parse("{}"), TStruct()), + TStruct(), + ) + assert(compact(render(actual)) == "{}") + } + + @Test def testRoundTripEmptyStruct(): Unit = { + val actual = JSONAnnotationImpex.importAnnotation( + JSONAnnotationImpex.exportAnnotation(Row(), TStruct()), + TStruct(), + ) + assert(actual == Row()) + } + + @Test def testImpexes(): Unit = { - val g = for {t <- Type.genArb - a <- t.genValue(sm)} yield (t, a) + val g = for { + t <- Type.genArb + a <- t.genValue(sm) + } yield (t, a) object Spec extends Properties("ImpEx") { property("json") = forAll(g) { case (t, a) => @@ -87,7 +114,7 @@ class ExprSuite extends HailSuite { Spec.check() } - @Test def testOrdering() { + @Test def testOrdering(): Unit = { val intOrd = TInt32.ordering(ctx.stateManager) assert(intOrd.compare(-2, -2) == 0) @@ -96,9 +123,11 @@ class ExprSuite extends HailSuite { assert(intOrd.compare(5, null) < 0) assert(intOrd.compare(null, -2) > 0) - val g = for (t <- Type.genArb; - a <- t.genValue(sm); - b <- t.genValue(sm)) yield (t, a, b) + val g = for { + t <- Type.genArb + a <- t.genValue(sm) + b <- t.genValue(sm) + } yield (t, a, b) val p = forAll(g) { case (t, a, b) => val ord = t.ordering(ctx.stateManager) diff --git a/hail/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala b/hail/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala index f6d6add8ea5..ebee4aa797d 100644 --- a/hail/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala +++ b/hail/src/test/scala/is/hail/methods/LocalLDPruneSuite.scala @@ -1,18 +1,14 @@ package is.hail.methods -import breeze.linalg.{Vector => BVector} -import is.hail.annotations.{Annotation, Region, RegionPool, RegionValue, RegionValueBuilder} -import is.hail.backend.HailTaskContext -import is.hail.backend.spark.SparkTaskContext -import is.hail.check.Prop._ +import is.hail.{HailSuite, TestUtils} +import is.hail.annotations.Annotation import is.hail.check.{Gen, Properties} -import is.hail.expr.ir.{Interpret, LongArrayBuilder, MatrixValue, TableValue} -import is.hail.types._ -import is.hail.types.physical.{PStruct, PType} -import is.hail.types.virtual.{TArray, TString, TStruct} +import is.hail.check.Prop._ +import is.hail.expr.ir.{Interpret, MatrixValue, TableValue} import is.hail.utils._ import is.hail.variant._ -import is.hail.{HailSuite, TestUtils} + +import 
breeze.linalg.{Vector => BVector} import org.apache.spark.rdd.RDD import org.testng.annotations.Test @@ -25,35 +21,38 @@ object LocalLDPruneSuite { val alleles = Array("A", "T") val nSamples = calls.length val builder = new BitPackedVectorBuilder(nSamples) - for (call <- calls) { + for (call <- calls) if (call == null) builder.addMissing() else builder.addGT(call) - } Option(builder.finish(locus, alleles)) } def correlationMatrixGT(gts: Array[Iterable[Annotation]]) = correlationMatrix(gts.map { gts => - gts.map { gt => Genotype.call(gt).map(c => c: BoxedCall).orNull }.toArray + gts.map(gt => Genotype.call(gt).map(c => c: BoxedCall).orNull).toArray }) def correlationMatrix(gts: Array[Array[BoxedCall]]) = { - val bvi = gts.map { gs => LocalLDPruneSuite.fromCalls(gs.toIndexedSeq) } - val r2 = for (i <- bvi.indices; j <- bvi.indices) yield { - (bvi(i), bvi(j)) match { + val bvi = gts.map(gs => LocalLDPruneSuite.fromCalls(gs.toIndexedSeq)) + val r2 = + for { + i <- bvi.indices + j <- bvi.indices + } yield (bvi(i), bvi(j)) match { case (Some(x), Some(y)) => Some(LocalLDPrune.computeR2(x, y)) case _ => None } - } val nVariants = bvi.length new MultiArray2(nVariants, nVariants, r2.toArray) } def estimateMemoryRequirements(nVariants: Long, nSamples: Int, memoryPerCore: Long): Int = { - val bytesPerVariant = math.ceil(8 * nSamples.toDouble / BitPackedVector.GENOTYPES_PER_PACK).toLong + variantByteOverhead + val bytesPerVariant = math.ceil( + 8 * nSamples.toDouble / BitPackedVector.GENOTYPES_PER_PACK + ).toLong + variantByteOverhead val memoryAvailPerCore = memoryPerCore * fractionMemoryToUse val maxQueueSize = math.max(1, math.ceil(memoryAvailPerCore / bytesPerVariant).toInt) @@ -110,18 +109,24 @@ object LocalLDPruneSuite { class LocalLDPruneSuite extends HailSuite { val memoryPerCoreBytes = 256 * 1024 * 1024 val nCores = 4 - lazy val mt = Interpret(TestUtils.importVCF(ctx, "src/test/resources/sample.vcf.bgz", nPartitions = Option(10)), - ctx, false).toMatrixValue(Array("s")) + + lazy val mt = Interpret( + TestUtils.importVCF(ctx, "src/test/resources/sample.vcf.bgz", nPartitions = Option(10)), + ctx, + false, + ).toMatrixValue(Array("s")) lazy val maxQueueSize = LocalLDPruneSuite.estimateMemoryRequirements( mt.rvd.count(), mt.nCols, - memoryPerCoreBytes) + memoryPerCoreBytes, + ) def toC2(i: Int): BoxedCall = if (i == -1) null else Call2.fromUnphasedDiploidGtIndex(i) def getLocallyPrunedRDDWithGT( - unprunedMatrixTable: MatrixValue, locallyPrunedTable: TableValue + unprunedMatrixTable: MatrixValue, + locallyPrunedTable: TableValue, ): RDD[(Locus, Any, Iterable[Annotation])] = { val mtLocusIndex = unprunedMatrixTable.rvRowPType.index("locus").get val mtAllelesIndex = unprunedMatrixTable.rvRowPType.index("alleles").get @@ -131,21 +136,34 @@ class LocalLDPruneSuite extends HailSuite { val allelesIndex = locallyPrunedTable.rvd.rowType.fieldIdx("alleles") val locallyPrunedVariants = locallyPrunedTable.rdd.mapPartitions( - it => it.map(row => (row.get(locusIndex), row.get(allelesIndex))), preservesPartitioning = true).collectAsSet() + it => it.map(row => (row.get(locusIndex), row.get(allelesIndex))), + preservesPartitioning = true, + ).collectAsSet() unprunedMatrixTable.rvd.toRows.map { r => - (r.getAs[Locus](mtLocusIndex), r.getAs[Any](mtAllelesIndex), r.getAs[Iterable[Annotation]](mtEntriesIndex)) - }.filter { case (locus, alleles, gs) => locallyPrunedVariants.contains((locus, alleles)) } + ( + r.getAs[Locus](mtLocusIndex), + r.getAs[Any](mtAllelesIndex), + r.getAs[Iterable[Annotation]](mtEntriesIndex), 
+ ) + }.filter { case (locus, alleles, _) => locallyPrunedVariants.contains((locus, alleles)) } } - def isGloballyUncorrelated(unprunedMatrixTable: MatrixValue, locallyPrunedTable: TableValue, r2Threshold: Double, - windowSize: Int): Boolean = { + def isGloballyUncorrelated( + unprunedMatrixTable: MatrixValue, + locallyPrunedTable: TableValue, + r2Threshold: Double, + windowSize: Int, + ): Boolean = { val locallyPrunedRDD = getLocallyPrunedRDDWithGT(unprunedMatrixTable, locallyPrunedTable) - val nSamples = unprunedMatrixTable.nCols - val r2Matrix = LocalLDPruneSuite.correlationMatrixGT(locallyPrunedRDD.map { case (locus, alleles, gs) => gs }.collect()) - val variantMap = locallyPrunedRDD.zipWithIndex.map { case ((locus, alleles, gs), i) => (i.toInt, locus) }.collectAsMap() + val r2Matrix = LocalLDPruneSuite.correlationMatrixGT(locallyPrunedRDD.map { + case (_, _, gs) => gs + }.collect()) + val variantMap = locallyPrunedRDD.zipWithIndex.map { case ((locus, _, _), i) => + (i.toInt, locus) + }.collectAsMap() r2Matrix.indices.forall { case (i, j) => val locus1 = variantMap(i) @@ -153,59 +171,72 @@ class LocalLDPruneSuite extends HailSuite { val r2 = r2Matrix(i, j) locus1 == locus2 || - locus1.contig != locus2.contig || - (locus1.contig == locus2.contig && math.abs(locus1.position - locus2.position) > windowSize) || - r2.exists(_ < r2Threshold) + locus1.contig != locus2.contig || + (locus1.contig == locus2.contig && math.abs( + locus1.position - locus2.position + ) > windowSize) || + r2.exists(_ < r2Threshold) } } - def isLocallyUncorrelated(unprunedMatrixTable: MatrixValue, locallyPrunedTable: TableValue, r2Threshold: Double, - windowSize: Int): Boolean = { + def isLocallyUncorrelated( + unprunedMatrixTable: MatrixValue, + locallyPrunedTable: TableValue, + r2Threshold: Double, + windowSize: Int, + ): Boolean = { val locallyPrunedRDD = getLocallyPrunedRDDWithGT(unprunedMatrixTable, locallyPrunedTable) - val nSamples = unprunedMatrixTable.nCols val locallyUncorrelated = { - locallyPrunedRDD.mapPartitions(it => { - // bind function for serialization - val computeCorrelationMatrix = (gts: Array[Iterable[Annotation]]) => - LocalLDPruneSuite.correlationMatrixGT(gts) - - val (it1, it2) = it.duplicate - val localR2Matrix = computeCorrelationMatrix(it1.map { case (locus, alleles, gs) => gs }.toArray) - val localVariantMap = it2.zipWithIndex.map { case ((locus, alleles, gs), i) => (i, locus) }.toMap - - val uncorrelated = localR2Matrix.indices.forall { case (i, j) => - val locus1 = localVariantMap(i) - val locus2 = localVariantMap(j) - val r2 = localR2Matrix(i, j) - - locus1 == locus2 || + locallyPrunedRDD.mapPartitions( + it => { + // bind function for serialization + val computeCorrelationMatrix = (gts: Array[Iterable[Annotation]]) => + LocalLDPruneSuite.correlationMatrixGT(gts) + + val (it1, it2) = it.duplicate + val localR2Matrix = computeCorrelationMatrix(it1.map { case (_, _, gs) => + gs + }.toArray) + val localVariantMap = it2.zipWithIndex.map { case ((locus, _, _), i) => + (i, locus) + }.toMap + + val uncorrelated = localR2Matrix.indices.forall { case (i, j) => + val locus1 = localVariantMap(i) + val locus2 = localVariantMap(j) + val r2 = localR2Matrix(i, j) + + locus1 == locus2 || locus1.contig != locus2.contig || - (locus1.contig == locus2.contig && math.abs(locus1.position - locus2.position) > windowSize) || + (locus1.contig == locus2.contig && math.abs( + locus1.position - locus2.position + ) > windowSize) || r2.exists(_ < r2Threshold) - } + } - Iterator(uncorrelated) - }, 
preservesPartitioning = true) + Iterator(uncorrelated) + }, + preservesPartitioning = true, + ) } locallyUncorrelated.fold(true)((bool1, bool2) => bool1 && bool2) } - @Test def testBitPackUnpack() { + @Test def testBitPackUnpack(): Unit = { val calls1 = Array(-1, 0, 1, 2, 1, 1, 0, 0, 0, 0, 2, 2, -1, -1, -1, -1).map(toC2) val calls2 = Array(0, 1, 2, 2, 2, 0, -1, -1).map(toC2) val calls3 = calls1 ++ Array.ofDim[Int](32 - calls1.length).map(toC2) ++ calls2 - for (calls <- Array(calls1, calls2, calls3)) { + for (calls <- Array(calls1, calls2, calls3)) assert(LocalLDPruneSuite.fromCalls(calls).forall { bpv => bpv.unpack().map(toC2(_)) sameElements calls }) - } } - @Test def testR2() { + @Test def testR2(): Unit = { val calls = Array( Array(1, 0, 0, 0, 0, 0, 0, 0).map(toC2), Array(1, 1, 1, 1, 1, 1, 1, 1).map(toC2), @@ -213,12 +244,16 @@ class LocalLDPruneSuite extends HailSuite { Array(1, 0, 0, 0, 1, 1, 1, 1).map(toC2), Array(1, 0, 0, 0, 1, 1, 2, 2).map(toC2), Array(1, 0, 1, 1, 2, 2, 0, 1).map(toC2), - Array(1, 0, 1, 0, 2, 2, 1, 1).map(toC2) + Array(1, 0, 1, 0, 2, 2, 1, 1).map(toC2), ) - val actualR2 = new MultiArray2(7, 7, fs.readLines("src/test/resources/ldprune_corrtest.txt")(_.flatMap(_.map { line => - line.trim.split("\t").map(r2 => if (r2 == "NA") None else Some(r2.toDouble)) - }.value).toArray)) + val actualR2 = new MultiArray2( + 7, + 7, + fs.readLines("src/test/resources/ldprune_corrtest.txt")(_.flatMap(_.map { line => + line.trim.split("\t").map(r2 => if (r2 == "NA") None else Some(r2.toDouble)) + }.value).toArray), + ) val computedR2 = LocalLDPruneSuite.correlationMatrix(calls) @@ -249,13 +284,14 @@ class LocalLDPruneSuite extends HailSuite { } object Spec extends Properties("LDPrune") { - val vectorGen = for (nSamples: Int <- Gen.choose(1, 1000); - v1: Array[BoxedCall] <- Gen.buildableOfN[Array](nSamples, Gen.choose(-1, 2).map(toC2)); - v2: Array[BoxedCall] <- Gen.buildableOfN[Array](nSamples, Gen.choose(-1, 2).map(toC2)) - ) yield (nSamples, v1, v2) + val vectorGen = for { + nSamples: Int <- Gen.choose(1, 1000) + v1: Array[BoxedCall] <- Gen.buildableOfN[Array](nSamples, Gen.choose(-1, 2).map(toC2)) + v2: Array[BoxedCall] <- Gen.buildableOfN[Array](nSamples, Gen.choose(-1, 2).map(toC2)) + } yield (nSamples, v1, v2) property("bitPacked pack and unpack give same as orig") = - forAll(vectorGen) { case (nSamples: Int, v1: Array[BoxedCall], _) => + forAll(vectorGen) { case (_: Int, v1: Array[BoxedCall], _) => val bpv = LocalLDPruneSuite.fromCalls(v1) bpv match { @@ -269,8 +305,10 @@ class LocalLDPruneSuite extends HailSuite { val bv1 = LocalLDPruneSuite.fromCalls(v1) val bv2 = LocalLDPruneSuite.fromCalls(v2) - val sgs1 = LocalLDPruneSuite.normalizedHardCalls(v1).map(math.sqrt(1d / nSamples) * BVector(_)) - val sgs2 = LocalLDPruneSuite.normalizedHardCalls(v2).map(math.sqrt(1d / nSamples) * BVector(_)) + val sgs1 = + LocalLDPruneSuite.normalizedHardCalls(v1).map(math.sqrt(1d / nSamples) * BVector(_)) + val sgs2 = + LocalLDPruneSuite.normalizedHardCalls(v2).map(math.sqrt(1d / nSamples) * BVector(_)) (bv1, bv2, sgs1, sgs2) match { case (Some(a), Some(b), Some(c: BVector[Double]), Some(d: BVector[Double])) => @@ -278,7 +316,8 @@ class LocalLDPruneSuite extends HailSuite { val r2Breeze = rBreeze * rBreeze val r2BitPacked = LocalLDPrune.computeR2(a, b) - val isSame = D_==(r2BitPacked, r2Breeze) && D_>=(r2BitPacked, 0d) && D_<=(r2BitPacked, 1d) + val isSame = + D_==(r2BitPacked, r2Breeze) && D_>=(r2BitPacked, 0d) && D_<=(r2BitPacked, 1d) if (!isSame) { println(s"breeze=$r2Breeze 
bitPacked=$r2BitPacked nSamples=$nSamples") } @@ -288,12 +327,12 @@ class LocalLDPruneSuite extends HailSuite { } } - @Test def testRandom() { + @Test def testRandom(): Unit = Spec.check() - } - @Test def testIsLocallyUncorrelated() { - val locallyPrunedVariantsTable = LocalLDPrune(ctx, mt, r2Threshold = 0.2, windowSize = 1000000, maxQueueSize = maxQueueSize) + @Test def testIsLocallyUncorrelated(): Unit = { + val locallyPrunedVariantsTable = + LocalLDPrune(ctx, mt, r2Threshold = 0.2, windowSize = 1000000, maxQueueSize = maxQueueSize) assert(isLocallyUncorrelated(mt, locallyPrunedVariantsTable, 0.2, 1000000)) assert(!isGloballyUncorrelated(mt, locallyPrunedVariantsTable, 0.2, 1000000)) } diff --git a/hail/src/test/scala/is/hail/methods/MultiArray2Suite.scala b/hail/src/test/scala/is/hail/methods/MultiArray2Suite.scala index 243d73787ff..0026478fb9c 100644 --- a/hail/src/test/scala/is/hail/methods/MultiArray2Suite.scala +++ b/hail/src/test/scala/is/hail/methods/MultiArray2Suite.scala @@ -2,18 +2,19 @@ package is.hail.methods import is.hail.HailSuite import is.hail.utils.MultiArray2 + import org.testng.annotations.Test -class MultiArray2Suite extends HailSuite{ +class MultiArray2Suite extends HailSuite { @Test def test() = { // test multiarray of size 0 will be created - val ma0 = MultiArray2.fill[Int](0, 0)(0) + MultiArray2.fill[Int](0, 0)(0) // test multiarray of size 0 that apply nothing out intercept[IllegalArgumentException] { val ma0 = MultiArray2.fill[Int](0, 0)(0) - ma0(0,0) + ma0(0, 0) } // test array index out of bounds on row slice @@ -24,81 +25,82 @@ class MultiArray2Suite extends HailSuite{ // bad multiarray initiation -- negative number intercept[IllegalArgumentException] { - val a = MultiArray2.fill[Int](-5,5)(0) + MultiArray2.fill[Int](-5, 5)(0) } // bad multiarray initiation -- negative number intercept[IllegalArgumentException] { - val a = MultiArray2.fill[Int](5,-5)(0) + MultiArray2.fill[Int](5, -5)(0) } - val ma1 = MultiArray2.fill[Int](10,3)(0) - for ((i,j) <- ma1.indices) { - ma1.update(i,j,i*j) - } - assert(ma1(2,2) == 4) - assert(ma1(6,1) == 6) + val ma1 = MultiArray2.fill[Int](10, 3)(0) + for ((i, j) <- ma1.indices) + ma1.update(i, j, i * j) + assert(ma1(2, 2) == 4) + assert(ma1(6, 1) == 6) // Catch exception if try to apply value that is not in indices of multiarray intercept[IllegalArgumentException] { - val foo = ma1(100,100) - } - - val ma2 = MultiArray2.fill[Int](10,3)(0) - for ((i,j) <- ma2.indices) { - ma2.update(i,j,i+j) + ma1(100, 100) } + val ma2 = MultiArray2.fill[Int](10, 3)(0) + for ((i, j) <- ma2.indices) + ma2.update(i, j, i + j) - assert(ma2(2,2) == 4) - assert(ma2(6,1) == 7) + assert(ma2(2, 2) == 4) + assert(ma2(6, 1) == 7) // Test zip with two ints val ma3 = ma1.zip(ma2) - assert(ma3(2,2) == ((4, 4))) - assert(ma3(6,1) == ((6, 7))) + assert(ma3(2, 2) == ((4, 4))) + assert(ma3(6, 1) == ((6, 7))) // Test zip with multi-arrays of different types - val ma4 = MultiArray2.fill[String](10,3)("foo") + val ma4 = MultiArray2.fill[String](10, 3)("foo") val ma5 = ma1.zip(ma4) - assert(ma5(2,2) == ((4,"foo"))) - assert(ma5(0,0) == ((0,"foo"))) + assert(ma5(2, 2) == ((4, "foo"))) + assert(ma5(0, 0) == ((0, "foo"))) // Test row slice - for (row <- ma5.rows; idx <- 0 until row.length) { - assert(row(idx) == ((row.i*idx, "foo"))) + for { + row <- ma5.rows + idx <- 0 until row.length } + assert(row(idx) == ((row.i * idx, "foo"))) intercept[IllegalArgumentException] { - val x = ma5.row(100) + ma5.row(100) } intercept[ArrayIndexOutOfBoundsException] { val x = 
ma5.row(0) - val y = x(100) + x(100) } intercept[IllegalArgumentException] { - val x = ma5.row(-5) + ma5.row(-5) } intercept[IllegalArgumentException] { - val x = ma5.column(100) + ma5.column(100) } intercept[IllegalArgumentException] { - val x = ma5.column(-5) + ma5.column(-5) } intercept[ArrayIndexOutOfBoundsException] { val x = ma5.column(0) - val y = x(100) + x(100) } // Test column slice - for (column <- ma5.columns; idx <- 0 until column.length) { - assert(column(idx) == ((column.j*idx, "foo"))) + for { + column <- ma5.columns + idx <- 0 until column.length } + assert(column(idx) == ((column.j * idx, "foo"))) } } diff --git a/hail/src/test/scala/is/hail/methods/SkatSuite.scala b/hail/src/test/scala/is/hail/methods/SkatSuite.scala index 067b295060b..a9b5e723f21 100644 --- a/hail/src/test/scala/is/hail/methods/SkatSuite.scala +++ b/hail/src/test/scala/is/hail/methods/SkatSuite.scala @@ -1,33 +1,35 @@ package is.hail.methods import is.hail.{HailSuite, TestUtils} +import is.hail.expr.ir.DoubleArrayBuilder import is.hail.utils._ + import breeze.linalg._ -import is.hail.expr.ir.DoubleArrayBuilder import org.testng.annotations.Test case class SkatAggForR(xs: BoxedArrayBuilder[DenseVector[Double]], weights: DoubleArrayBuilder) class SkatSuite extends HailSuite { - @Test def smallNLargeNEqualityTest() { + @Test def smallNLargeNEqualityTest(): Unit = { val rand = scala.util.Random rand.setSeed(0) - + val n = 10 // samples val m = 5 // variants val k = 3 // covariates - - val st = Array.tabulate(m){ _ => - SkatTuple(rand.nextDouble(), + + val st = Array.tabulate(m) { _ => + SkatTuple( + rand.nextDouble(), DenseVector(Array.fill(n)(rand.nextDouble())), - DenseVector(Array.fill(k)(rand.nextDouble()))) + DenseVector(Array.fill(k)(rand.nextDouble())), + ) } - val (qSmall, gramianSmall) = Skat.computeGramianSmallN(st) val (qLarge, gramianLarge) = Skat.computeGramianLargeN(st) - + assert(D_==(qSmall, qLarge)) TestUtils.assertMatrixEqualityDouble(gramianSmall, gramianLarge) } diff --git a/hail/src/test/scala/is/hail/rvd/RVDPartitionerSuite.scala b/hail/src/test/scala/is/hail/rvd/RVDPartitionerSuite.scala index 1364f16769d..4e6713c76be 100644 --- a/hail/src/test/scala/is/hail/rvd/RVDPartitionerSuite.scala +++ b/hail/src/test/scala/is/hail/rvd/RVDPartitionerSuite.scala @@ -3,6 +3,7 @@ package is.hail.rvd import is.hail.HailSuite import is.hail.types.virtual.{TInt32, TStruct} import is.hail.utils.{FastSeq, Interval} + import org.apache.spark.sql.Row import org.testng.ITestContext import org.testng.annotations.{BeforeMethod, Test} @@ -14,31 +15,37 @@ class RVDPartitionerSuite extends HailSuite { @BeforeMethod def setupPartitioner(context: ITestContext): Unit = { - partitioner = new RVDPartitioner(ctx.stateManager, kType, + partitioner = new RVDPartitioner( + ctx.stateManager, + kType, Array( Interval(Row(1, 0), Row(4, 3), true, false), Interval(Row(4, 3), Row(7, 9), true, false), - Interval(Row(7, 11), Row(10, 0), true, true)) + Interval(Row(7, 11), Row(10, 0), true, true), + ), ) } - @Test def testExtendKey() { - val p = new RVDPartitioner(ctx.stateManager, TStruct(("A", TInt32), ("B", TInt32)), + @Test def testExtendKey(): Unit = { + val p = new RVDPartitioner( + ctx.stateManager, + TStruct(("A", TInt32), ("B", TInt32)), Array( Interval(Row(1, 0), Row(4, 3), true, true), Interval(Row(4, 3), Row(4, 3), true, true), Interval(Row(4, 3), Row(7, 9), true, false), - Interval(Row(7, 11), Row(10, 0), true, true)) - ) + Interval(Row(7, 11), Row(10, 0), true, true), + ), + ) val extended = p.extendKey(kType) 
assert(extended.rangeBounds sameElements Array( Interval(Row(1, 0), Row(4, 3), true, true), Interval(Row(4, 3), Row(7, 9), false, false), - Interval(Row(7, 11), Row(10, 0), true, true)) - ) + Interval(Row(7, 11), Row(10, 0), true, true), + )) } - @Test def testGetPartitionWithPartitionKeys() { + @Test def testGetPartitionWithPartitionKeys(): Unit = { assert(partitioner.lowerBound(Row(-1, 7)) == 0) assert(partitioner.upperBound(Row(-1, 7)) == 0) @@ -58,7 +65,7 @@ class RVDPartitionerSuite extends HailSuite { assert(partitioner.upperBound(Row(12, 19)) == 3) } - @Test def testGetPartitionWithLargerKeys() { + @Test def testGetPartitionWithLargerKeys(): Unit = { assert(partitioner.lowerBound(Row(0, 1, 3)) == 0) assert(partitioner.upperBound(Row(0, 1, 3)) == 0) @@ -73,18 +80,18 @@ class RVDPartitionerSuite extends HailSuite { assert(partitioner.lowerBound(Row(11, 1, 42)) == 3) } - @Test def testGetPartitionPKWithSmallerKeys() { - assert(partitioner.lowerBound(Row(2)) == 0) - assert(partitioner.upperBound(Row(2)) == 1) + @Test def testGetPartitionPKWithSmallerKeys(): Unit = { + assert(partitioner.lowerBound(Row(2)) == 0) + assert(partitioner.upperBound(Row(2)) == 1) - assert(partitioner.lowerBound(Row(4)) == 0) - assert(partitioner.upperBound(Row(4)) == 2) + assert(partitioner.lowerBound(Row(4)) == 0) + assert(partitioner.upperBound(Row(4)) == 2) - assert(partitioner.lowerBound(Row(11)) == 3) - assert(partitioner.upperBound(Row(11)) == 3) - } + assert(partitioner.lowerBound(Row(11)) == 3) + assert(partitioner.upperBound(Row(11)) == 3) + } - @Test def testGetPartitionRange() { + @Test def testGetPartitionRange(): Unit = { assert(partitioner.queryInterval(Interval(Row(3, 4), Row(7, 11), true, true)) == Seq(0, 1, 2)) assert(partitioner.queryInterval(Interval(Row(3, 4), Row(7, 9), true, false)) == Seq(0, 1)) assert(partitioner.queryInterval(Interval(Row(4), Row(5), true, true)) == Seq(0, 1)) @@ -92,19 +99,20 @@ class RVDPartitionerSuite extends HailSuite { assert(partitioner.queryInterval(Interval(Row(-1, 7), Row(0, 9), true, false)) == Seq()) } - @Test def testGetSafePartitionKeyRange() { + @Test def testGetSafePartitionKeyRange(): Unit = { assert(partitioner.queryKey(Row(0, 0)).isEmpty) assert(partitioner.queryKey(Row(7, 10)).isEmpty) assert(partitioner.queryKey(Row(7, 11)) == Range.inclusive(2, 2)) } - @Test def testGenerateDisjoint() { + @Test def testGenerateDisjoint(): Unit = { val intervals = Array( - Interval(Row(1, 0, 4), Row(4, 3, 2), true, false), - Interval(Row(4, 3, 5), Row(7, 9, 1), true, false), - Interval(Row(7, 11, 3), Row(10, 0, 1), true, true), - Interval(Row(11, 0, 2), Row(11, 0, 15), false, true), - Interval(Row(11, 0, 15), Row(11, 0, 20), true, false)) + Interval(Row(1, 0, 4), Row(4, 3, 2), true, false), + Interval(Row(4, 3, 5), Row(7, 9, 1), true, false), + Interval(Row(7, 11, 3), Row(10, 0, 1), true, true), + Interval(Row(11, 0, 2), Row(11, 0, 15), false, true), + Interval(Row(11, 0, 15), Row(11, 0, 20), true, false), + ) val p3 = RVDPartitioner.generate(ctx.stateManager, Array("A", "B", "C"), kType, intervals) assert(p3.satisfiesAllowedOverlap(2)) @@ -114,8 +122,8 @@ class RVDPartitionerSuite extends HailSuite { Interval(Row(4, 3, 5), Row(7, 9, 1), true, false), Interval(Row(7, 11, 3), Row(10, 0, 1), true, true), Interval(Row(11, 0, 2), Row(11, 0, 15), false, true), - Interval(Row(11, 0, 15), Row(11, 0, 20), false, false)) - ) + Interval(Row(11, 0, 15), Row(11, 0, 20), false, false), + )) val p2 = RVDPartitioner.generate(ctx.stateManager, Array("A", "B"), kType, intervals) 
assert(p2.satisfiesAllowedOverlap(1)) @@ -124,8 +132,8 @@ class RVDPartitionerSuite extends HailSuite { Interval(Row(1, 0, 4), Row(4, 3), true, true), Interval(Row(4, 3), Row(7, 9, 1), false, false), Interval(Row(7, 11, 3), Row(10, 0, 1), true, true), - Interval(Row(11, 0, 2), Row(11, 0, 20), false, false)) - ) + Interval(Row(11, 0, 2), Row(11, 0, 20), false, false), + )) val p1 = RVDPartitioner.generate(ctx.stateManager, Array("A"), kType, intervals) assert(p1.satisfiesAllowedOverlap(0)) @@ -134,11 +142,11 @@ class RVDPartitionerSuite extends HailSuite { Interval(Row(1, 0, 4), Row(4), true, true), Interval(Row(4), Row(7), false, true), Interval(Row(7), Row(10, 0, 1), false, true), - Interval(Row(11, 0, 2), Row(11, 0, 20), false, false)) - ) + Interval(Row(11, 0, 2), Row(11, 0, 20), false, false), + )) } - @Test def testGenerateEmptyKey() { + @Test def testGenerateEmptyKey(): Unit = { val intervals1 = Array(Interval(Row(), Row(), true, true)) val intervals5 = Array.fill(5)(Interval(Row(), Row(), true, true)) @@ -152,22 +160,28 @@ class RVDPartitionerSuite extends HailSuite { assert(p0.rangeBounds.isEmpty) } - @Test def testIntersect() { + @Test def testIntersect(): Unit = { val kType = TStruct(("key", TInt32)) val left = - new RVDPartitioner(ctx.stateManager, kType, + new RVDPartitioner( + ctx.stateManager, + kType, Array( Interval(Row(1), Row(10), true, false), Interval(Row(12), Row(13), true, false), - Interval(Row(14), Row(19), true, false)) + Interval(Row(14), Row(19), true, false), + ), ) val right = - new RVDPartitioner(ctx.stateManager, kType, + new RVDPartitioner( + ctx.stateManager, + kType, Array( Interval(Row(1), Row(4), true, false), Interval(Row(4), Row(5), true, false), Interval(Row(7), Row(16), true, true), - Interval(Row(19), Row(20), true, true)) + Interval(Row(19), Row(20), true, true), + ), ) assert(left.intersect(right).rangeBounds sameElements Array( @@ -175,8 +189,7 @@ class RVDPartitionerSuite extends HailSuite { Interval(Row(4), Row(5), true, false), Interval(Row(7), Row(10), true, false), Interval(Row(12), Row(13), true, false), - Interval(Row(14), Row(16), true, true) - ) - ) + Interval(Row(14), Row(16), true, true), + )) } } diff --git a/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala b/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala index 3585777b438..a084729ae48 100644 --- a/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala +++ b/hail/src/test/scala/is/hail/services/batch_client/BatchClientSuite.scala @@ -1,8 +1,9 @@ package is.hail.services.batch_client import is.hail.utils._ -import org.json4s.JsonAST._ + import org.json4s.{DefaultFormats, Formats} +import org.json4s.JsonAST._ import org.scalatest.testng.TestNGSuite import org.testng.annotations.Test @@ -14,7 +15,8 @@ class BatchClientSuite extends TestNGSuite { JObject( "billing_project" -> JString("test"), "n_jobs" -> JInt(1), - "token" -> JString(token)), + "token" -> JString(token), + ), FastSeq( JObject( "always_run" -> JBool(false), @@ -25,9 +27,13 @@ class BatchClientSuite extends TestNGSuite { "command" -> JArray(List( JString("/bin/bash"), JString("-c"), - JString("echo 'Hello, world!'"))), - "type" -> JString("docker")) - ))) + JString("echo 'Hello, world!'"), + )), + "type" -> JString("docker"), + ), + ) + ), + ) implicit val formats: Formats = DefaultFormats assert((batch \ "state").extract[String] == "success") } diff --git a/hail/src/test/scala/is/hail/stats/FisherExactTestSuite.scala 
b/hail/src/test/scala/is/hail/stats/FisherExactTestSuite.scala index 93260334a3a..bf5d68aa584 100644 --- a/hail/src/test/scala/is/hail/stats/FisherExactTestSuite.scala +++ b/hail/src/test/scala/is/hail/stats/FisherExactTestSuite.scala @@ -1,17 +1,12 @@ package is.hail.stats import is.hail.HailSuite -import org.testng.annotations.Test -import scala.language.postfixOps +import org.testng.annotations.Test class FisherExactTestSuite extends HailSuite { - @Test def testPvalue() { - val N = 200 - val K = 100 - val k = 10 - val n = 15 + @Test def testPvalue(): Unit = { val a = 5 val b = 10 val c = 95 diff --git a/hail/src/test/scala/is/hail/stats/GeneralizedChiSquaredDistributionSuite.scala b/hail/src/test/scala/is/hail/stats/GeneralizedChiSquaredDistributionSuite.scala index 64c8b19528d..710cdadefe7 100644 --- a/hail/src/test/scala/is/hail/stats/GeneralizedChiSquaredDistributionSuite.scala +++ b/hail/src/test/scala/is/hail/stats/GeneralizedChiSquaredDistributionSuite.scala @@ -1,8 +1,8 @@ package is.hail.stats import is.hail.HailSuite -import org.testng.annotations.Test +import org.testng.annotations.Test class GeneralizedChiSquaredDistributionSuite extends HailSuite { private[this] def pgenchisq( @@ -12,26 +12,24 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { nc: Array[Double], sigma: Double, lim: Int, - acc: Double - ) = { + acc: Double, + ) = new DaviesAlgorithm(c, n, lb, nc, lim, sigma).cdf(acc) - } - private[this] def nearEqual(a: Double, b: Double): Boolean = { + private[this] def nearEqual(a: Double, b: Double): Boolean = /* Davies only reports 6 significant figures */ Math.abs(a - b) < 0.0000005 - } private[this] def nearEqualDAT(x: DaviesAlgorithmTrace, y: DaviesAlgorithmTrace): Boolean = { val DaviesAlgorithmTrace(a, b, c, d, e, f, g) = x val DaviesAlgorithmTrace(a2, b2, c2, d2, e2, f2, g2) = x (nearEqual(a, a2) && - nearEqual(b, b2) && - nearEqual(c, c2) && - nearEqual(d, d2) && - nearEqual(e, e2) && - nearEqual(f, f2) && - nearEqual(g, g2)) + nearEqual(b, b2) && + nearEqual(c, c2) && + nearEqual(d, d2) && + nearEqual(e, e2) && + nearEqual(f, f2) && + nearEqual(g, g2)) } @Test def test0(): Unit = { @@ -42,11 +40,14 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.054213)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(0.76235, 744, 2, 0.03819, 53.37969, 0.0, 51))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(0.76235, 744, 2, 0.03819, 53.37969, 0.0, 51), + )) assert(actualFault == 0) } @@ -58,10 +59,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.493555)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.57018, 625, 2, 0.03964, 34.66214, 0.04784, 51))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(1.57018, 625, 2, 0.03964, 34.66214, 0.04784, 51), + )) assert(actualFault == 0) } @@ -73,10 +77,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.876027)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.16244, 346, 1, 0.04602, 15.88681, 0.14159, 32))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(3.16244, 346, 1, 0.04602, 15.88681, 0.14159, 32), + )) assert(actualFault == 0) } @@ -88,10 +95,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0), 
0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.006435)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(0.84764, 74, 1, 0.03514, 2.55311, 0.0, 22))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(0.84764, 74, 1, 0.03514, 2.55311, 0.0, 22), + )) assert(actualFault == 0) } @@ -103,10 +113,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.600208)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.74138, 66, 1, 0.03907, 2.55311, 0.0, 22))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(1.74138, 66, 1, 0.03907, 2.55311, 0.0, 22), + )) assert(actualFault == 0) } @@ -118,7 +131,7 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.983897)) assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.72757, 50, 1, 0.052, 2.55311, 0.0, 22))) @@ -133,10 +146,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.002697)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.20122, 18, 1, 0.02706, 0.46096, 0.0, 20))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(1.20122, 18, 1, 0.02706, 0.46096, 0.0, 20), + )) assert(actualFault == 0) } @@ -148,10 +164,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.564753)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(2.06868, 15, 1, 0.03269, 0.46096, 0.0, 20))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(2.06868, 15, 1, 0.03269, 0.46096, 0.0, 20), + )) assert(actualFault == 0) } @@ -163,10 +182,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.991229)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.58496, 10, 1, 0.05141, 0.46096, 0.0, 20))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(3.58496, 10, 1, 0.05141, 0.46096, 0.0, 20), + )) assert(actualFault == 0) } @@ -178,10 +200,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.033357)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.29976, 27, 1, 0.03459, 0.88302, 0.0, 19))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(1.29976, 27, 1, 0.03459, 0.88302, 0.0, 19), + )) assert(actualFault == 0) } @@ -193,10 +218,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.580446)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(2.01747, 24, 1, 0.03887, 0.88302, 0.0, 19))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(2.01747, 24, 1, 0.03887, 0.88302, 0.0, 19), + )) assert(actualFault == 0) } @@ -208,10 +236,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.991283)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.81157, 17, 1, 0.05628, 0.88302, 0.0, 19))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(3.81157, 17, 1, 0.05628, 0.88302, 0.0, 19), + )) assert(actualFault == 0) } @@ -223,10 +254,13 @@ 
class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.006125)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.16271, 16, 1, 0.01561, 0.24013, 0.0, 19))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(1.16271, 16, 1, 0.01561, 0.24013, 0.0, 19), + )) assert(actualFault == 0) } @@ -238,10 +272,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.591339)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(2.02277, 13, 1, 0.01949, 0.24013, 0.0, 19))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(2.02277, 13, 1, 0.01949, 0.24013, 0.0, 19), + )) assert(actualFault == 0) } @@ -253,10 +290,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.977914)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.09687, 10, 1, 0.02825, 0.24013, 0.0, 19))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(3.09687, 10, 1, 0.02825, 0.24013, 0.0, 19), + )) assert(actualFault == 0) } @@ -268,10 +308,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.045126)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(0.8712, 603, 2, 0.01628, 13.86318, 0.0, 49))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(0.8712, 603, 2, 0.01628, 13.86318, 0.0, 49), + )) assert(actualFault == 0) } @@ -283,10 +326,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.592431)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.69157, 340, 1, 0.02043, 6.93159, 0.24644, 31))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(1.69157, 340, 1, 0.02043, 6.93159, 0.24644, 31), + )) assert(actualFault == 0) } @@ -298,10 +344,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.977648)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.06625, 87, 1, 0.02888, 2.47557, 0.81533, 29))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(3.06625, 87, 1, 0.02888, 2.47557, 0.81533, 29), + )) assert(actualFault == 0) } @@ -313,7 +362,7 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.01095)) assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.82147, 13, 1, 0.01582, 0.193, 0.0, 18))) @@ -328,7 +377,7 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.654735)) assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(2.73768, 11, 1, 0.0195, 0.193, 0.0, 18))) @@ -343,7 +392,7 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.984606)) assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.83651, 8, 1, 0.02707, 0.193, 0.0, 18))) @@ -358,10 +407,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(6.0, 2.0, 6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.043679)) - 
assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.65876, 10, 1, 0.01346, 0.12785, 0.0, 18))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(1.65876, 10, 1, 0.01346, 0.12785, 0.0, 18), + )) assert(actualFault == 0) } @@ -373,10 +425,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(6.0, 2.0, 6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.584765)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(2.34799, 9, 1, 0.01668, 0.12785, 0.0, 18))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(2.34799, 9, 1, 0.01668, 0.12785, 0.0, 18), + )) assert(actualFault == 0) } @@ -388,10 +443,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(6.0, 2.0, 6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.953774)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.11236, 7, 1, 0.02271, 0.12785, 0.0, 18))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(3.11236, 7, 1, 0.02271, 0.12785, 0.0, 18), + )) assert(actualFault == 0) } @@ -403,10 +461,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(6.0, 2.0, 6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.078208)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.42913, 10, 1, 0.01483, 0.12785, 0.0, 19))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(1.42913, 10, 1, 0.01483, 0.12785, 0.0, 19), + )) assert(actualFault == 0) } @@ -418,10 +479,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(6.0, 2.0, 6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.522108)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.42909, 8, 1, 0.01771, 0.12785, 0.0, 19))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(1.42909, 8, 1, 0.01771, 0.12785, 0.0, 19), + )) assert(actualFault == 0) } @@ -433,10 +497,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(6.0, 2.0, 6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.96037)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(2.19476, 10, 1, 0.01381, 0.12785, 0.0, 19))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(2.19476, 10, 1, 0.01381, 0.12785, 0.0, 19), + )) assert(actualFault == 0) } @@ -448,10 +515,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 2.0, 6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.015844)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(2.33438, 9, 1, 0.01202, 0.09616, 0.0, 18))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(2.33438, 9, 1, 0.01202, 0.09616, 0.0, 18), + )) assert(actualFault == 0) } @@ -463,7 +533,7 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 2.0, 6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.573625)) assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.1401, 7, 1, 0.01561, 0.09616, 0.0, 18))) @@ -478,7 +548,7 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 6.0, 2.0, 6.0, 2.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.988332)) assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(4.2142, 6, 1, 0.01812, 0.09616, 0.0, 18))) @@ -493,10 +563,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0), 0.0, 1000, - 0.0001 
+ 0.0001, ) assert(nearEqual(actualValue, 0.015392)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(0.95892, 163, 1, 0.00841, 1.3638, 0.0, 22))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(0.95892, 163, 1, 0.00841, 1.3638, 0.0, 22), + )) assert(actualFault == 0) } @@ -508,10 +581,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.510819)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.72922, 159, 1, 0.00864, 1.3638, 0.0, 22))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(1.72922, 159, 1, 0.00864, 1.3638, 0.0, 22), + )) assert(actualFault == 0) } @@ -523,10 +599,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.91634)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(4.61788, 143, 1, 0.00963, 1.3638, 0.0, 22))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(4.61788, 143, 1, 0.00963, 1.3638, 0.0, 22), + )) assert(actualFault == 0) } @@ -538,10 +617,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.004925)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.26245, 97, 1, 0.00839, 0.80736, 0.0, 21))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(1.26245, 97, 1, 0.00839, 0.80736, 0.0, 21), + )) assert(actualFault == 0) } @@ -553,10 +635,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.573251)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(2.16513, 93, 1, 0.00874, 0.80736, 0.0, 21))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(2.16513, 93, 1, 0.00874, 0.80736, 0.0, 21), + )) assert(actualFault == 0) } @@ -568,10 +653,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.896501)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.97055, 86, 1, 0.00954, 0.80736, 0.0, 21))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(3.97055, 86, 1, 0.00954, 0.80736, 0.0, 21), + )) assert(actualFault == 0) } @@ -583,10 +671,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.017101)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(1.65684, 81, 1, 0.00843, 0.67453, 0.0, 20))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(1.65684, 81, 1, 0.00843, 0.67453, 0.0, 20), + )) assert(actualFault == 0) } @@ -598,10 +689,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.566488)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(2.44382, 78, 1, 0.00878, 0.67453, 0.0, 20))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(2.44382, 78, 1, 0.00878, 0.67453, 0.0, 20), + )) assert(actualFault == 0) } @@ -613,10 +707,13 @@ class GeneralizedChiSquaredDistributionSuite extends HailSuite { Array(0.0, 0.0), 0.0, 1000, - 0.0001 + 0.0001, ) assert(nearEqual(actualValue, 0.871323)) - assert(nearEqualDAT(actualTrace, DaviesAlgorithmTrace(3.75545, 72, 1, 0.00944, 0.67453, 0.0, 20))) + assert(nearEqualDAT( + actualTrace, + DaviesAlgorithmTrace(3.75545, 72, 1, 
0.00944, 0.67453, 0.0, 20), + )) assert(actualFault == 0) } } diff --git a/hail/src/test/scala/is/hail/stats/LeveneHaldaneSuite.scala b/hail/src/test/scala/is/hail/stats/LeveneHaldaneSuite.scala index 9a90464c016..44c645cc40b 100644 --- a/hail/src/test/scala/is/hail/stats/LeveneHaldaneSuite.scala +++ b/hail/src/test/scala/is/hail/stats/LeveneHaldaneSuite.scala @@ -1,6 +1,7 @@ package is.hail.stats import is.hail.utils._ + import org.apache.commons.math3.util.CombinatoricsUtils.factorialLog import org.scalatest.testng.TestNGSuite import org.testng.annotations.Test @@ -14,15 +15,19 @@ class LeveneHaldaneSuite extends TestNGSuite { val nB = 2 * n - nA val nAA = (nA - nAB) / 2 val nBB = (nB - nAB) / 2 - math.exp(nAB * math.log(2) + factorialLog(n) - (factorialLog(nAA) + factorialLog(nAB) + factorialLog(nBB)) - (factorialLog(2 * n) - (factorialLog(nA) + factorialLog(nB)))) + math.exp(nAB * math.log(2) + factorialLog(n) - (factorialLog(nAA) + factorialLog( + nAB + ) + factorialLog(nBB)) - (factorialLog(2 * n) - (factorialLog(nA) + factorialLog(nB)))) } } - // val examples = List((15, 10), (15, 9), (15, 0), (15, 15), (1, 0), (1, 1), (0, 0), (1526, 431), (1526, 430), (10000,1500)) + /* val examples = List((15, 10), (15, 9), (15, 0), (15, 15), (1, 0), (1, 1), (0, 0), (1526, 431), + * (1526, 430), (10000,1500)) */ // The above is commented out because ain't nobody got time for that. - val examples = List((15, 10), (15, 9), (15, 0), (15, 15), (1, 0), (1, 1), (0, 0), (1526, 431), (1526, 430)) + val examples = + List((15, 10), (15, 9), (15, 0), (15, 15), (1, 0), (1, 1), (0, 0), (1526, 431), (1526, 430)) - @Test def pmfTest() { + @Test def pmfTest(): Unit = { def test(e: (Int, Int)): Boolean = { val (n, nA) = e @@ -33,52 +38,68 @@ class LeveneHaldaneSuite extends TestNGSuite { examples foreach { e => assert(test(e)) } } - @Test def modeTest() { + @Test def modeTest(): Unit = { def test(e: (Int, Int)): Boolean = { val (n, nA) = e val LH = LeveneHaldane(n, nA) D_==(LH.probability(LH.mode), (nA % 2 to nA by 2).map(LH.probability).max) } - examples foreach {e => assert(test(e))} + examples foreach { e => assert(test(e)) } } - @Test def meanTest() { + @Test def meanTest(): Unit = { def test(e: (Int, Int)): Boolean = { val (n, nA) = e val LH = LeveneHaldane(n, nA) - D_==(LH.getNumericalMean, (LH.getSupportLowerBound to LH.getSupportUpperBound).map(i => i * LH.probability(i)).sum) + D_==( + LH.getNumericalMean, + (LH.getSupportLowerBound to LH.getSupportUpperBound).map(i => i * LH.probability(i)).sum, + ) } - examples foreach {e => assert(test(e))} + examples foreach { e => assert(test(e)) } } - @Test def varianceTest() { + @Test def varianceTest(): Unit = { def test(e: (Int, Int)): Boolean = { val (n, nA) = e val LH = LeveneHaldane(n, nA) - D_==(LH.getNumericalVariance + LH.getNumericalMean * LH.getNumericalMean, (LH.getSupportLowerBound to LH.getSupportUpperBound).map(i => i * i * LH.probability(i)).sum) + D_==( + LH.getNumericalVariance + LH.getNumericalMean * LH.getNumericalMean, + (LH.getSupportLowerBound to LH.getSupportUpperBound).map(i => i * i * LH.probability(i)).sum, + ) } - examples foreach {e => assert(test(e))} + examples foreach { e => assert(test(e)) } } - @Test def exactTestsTest() { + @Test def exactTestsTest(): Unit = { def test(e: (Int, Int)): Boolean = { val (n, nA) = e val LH = LeveneHaldane(n, nA) - (-2 to nA + 2).forall(nAB => ( - D_==(LH.leftMidP(nAB) + LH.rightMidP(nAB), 1.0) - && D_==(LH.leftMidP(nAB), - 0.5 * LH.probability(nAB) + (0 to nAB - 1).map(LH.probability).sum) - && 
D_==(LH.exactMidP(nAB), - {val p0 = LH.probability(nAB) - (0 to nA).map(LH.probability).filter(D_<(_, p0, tolerance = 1.0E-12)).sum + 0.5 * (0 to nA).map(LH.probability).filter(D_==(_, p0, tolerance = 1.0E-12)).sum - }) - )) + (-2 to nA + 2).forall(nAB => + ( + D_==(LH.leftMidP(nAB) + LH.rightMidP(nAB), 1.0) + && D_==( + LH.leftMidP(nAB), + 0.5 * LH.probability(nAB) + (0 to nAB - 1).map(LH.probability).sum, + ) + && D_==( + LH.exactMidP(nAB), { + val p0 = LH.probability(nAB) + (0 to nA).map(LH.probability).filter( + D_<(_, p0, tolerance = 1.0e-12) + ).sum + 0.5 * (0 to nA).map(LH.probability).filter( + D_==(_, p0, tolerance = 1.0e-12) + ).sum + }, + ) + ) + ) } - examples foreach {e => assert(test(e))} + examples foreach { e => assert(test(e)) } } } diff --git a/hail/src/test/scala/is/hail/stats/LogisticRegressionModelSuite.scala b/hail/src/test/scala/is/hail/stats/LogisticRegressionModelSuite.scala index 3cad1c45ca1..91f8e22391c 100644 --- a/hail/src/test/scala/is/hail/stats/LogisticRegressionModelSuite.scala +++ b/hail/src/test/scala/is/hail/stats/LogisticRegressionModelSuite.scala @@ -1,25 +1,18 @@ package is.hail.stats -import breeze.linalg._ import is.hail.HailSuite import is.hail.utils._ -import org.testng.SkipException -import org.testng.annotations.Test -import scala.language.postfixOps -import scala.sys.process._ +import breeze.linalg._ +import org.testng.annotations.Test class LogisticRegressionModelSuite extends HailSuite { - @Test def covariatesVsInterceptOnlyTest() { + @Test def covariatesVsInterceptOnlyTest(): Unit = { /* R code: - y0 = c(0, 0, 1, 1, 1, 1) - c1 = c(0, 2, 1, -2, -2, 4) - c2 = c(-1, 3, 5, 0, -4, 3) - logfit <- glm(y0 ~ c1 + c2, family=binomial(link="logit")) - summary(logfit)["coefficients"] - */ + * y0 = c(0, 0, 1, 1, 1, 1) c1 = c(0, 2, 1, -2, -2, 4) c2 = c(-1, 3, 5, 0, -4, 3) logfit <- + * glm(y0 ~ c1 + c2, family=binomial(link="logit")) summary(logfit)["coefficients"] */ val y = DenseVector(0d, 0d, 1d, 1d, 1d, 1d) @@ -29,7 +22,8 @@ class LogisticRegressionModelSuite extends HailSuite { (1.0, 5.0), (-2.0, 0.0), (-2.0, -4.0), - (4.0, 3.0)) + (4.0, 3.0), + ) val ones = DenseMatrix.fill[Double](6, 1)(1d) @@ -42,41 +36,36 @@ class LogisticRegressionModelSuite extends HailSuite { val lrStats = LikelihoodRatioTest.test(X, y, nullFit, "logistic", 25, 1e-5).stats.get val scoreStats = LogisticScoreTest.test(X, y, nullFit, "logistic", 25, 1e-5).stats.get - assert(D_==(waldStats.b(0), 0.7245034, tolerance = 1.0E-6)) - assert(D_==(waldStats.b(1), -0.3585773, tolerance = 1.0E-6)) - assert(D_==(waldStats.b(2), 0.1922622, tolerance = 1.0E-6)) + assert(D_==(waldStats.b(0), 0.7245034, tolerance = 1.0e-6)) + assert(D_==(waldStats.b(1), -0.3585773, tolerance = 1.0e-6)) + assert(D_==(waldStats.b(2), 0.1922622, tolerance = 1.0e-6)) - assert(D_==(waldStats.se(0), 0.9396654, tolerance = 1.0E-6)) - assert(D_==(waldStats.se(1), 0.6246568, tolerance = 1.0E-6)) - assert(D_==(waldStats.se(2), 0.4559844, tolerance = 1.0E-6)) + assert(D_==(waldStats.se(0), 0.9396654, tolerance = 1.0e-6)) + assert(D_==(waldStats.se(1), 0.6246568, tolerance = 1.0e-6)) + assert(D_==(waldStats.se(2), 0.4559844, tolerance = 1.0e-6)) - assert(D_==(waldStats.z(0), 0.7710228, tolerance = 1.0E-6)) - assert(D_==(waldStats.z(1), -0.5740389, tolerance = 1.0E-6)) - assert(D_==(waldStats.z(2), 0.4216421, tolerance = 1.0E-5)) + assert(D_==(waldStats.z(0), 0.7710228, tolerance = 1.0e-6)) + assert(D_==(waldStats.z(1), -0.5740389, tolerance = 1.0e-6)) + assert(D_==(waldStats.z(2), 0.4216421, tolerance = 1.0e-5)) - 
assert(D_==(waldStats.p(0), 0.4406934, tolerance = 1.0E-6)) - assert(D_==(waldStats.p(1), 0.5659415, tolerance = 1.0E-6)) - assert(D_==(waldStats.p(2), 0.6732863, tolerance = 1.0E-6)) + assert(D_==(waldStats.p(0), 0.4406934, tolerance = 1.0e-6)) + assert(D_==(waldStats.p(1), 0.5659415, tolerance = 1.0e-6)) + assert(D_==(waldStats.p(2), 0.6732863, tolerance = 1.0e-6)) - assert(D_==(nullFit.logLkhd, -3.81908501, tolerance = 1.0E-6)) - assert(D_==(lrStats.chi2, 0.351536062, tolerance = 1.0E-6)) - assert(D_==(lrStats.p, 0.8388125392, tolerance = 1.0E-5)) + assert(D_==(nullFit.logLkhd, -3.81908501, tolerance = 1.0e-6)) + assert(D_==(lrStats.chi2, 0.351536062, tolerance = 1.0e-6)) + assert(D_==(lrStats.p, 0.8388125392, tolerance = 1.0e-5)) - assert(D_==(scoreStats.chi2, 0.346648, tolerance = 1.0E-5)) - assert(D_==(scoreStats.p, 0.8408652791, tolerance = 1.0E-5)) + assert(D_==(scoreStats.chi2, 0.346648, tolerance = 1.0e-5)) + assert(D_==(scoreStats.p, 0.8408652791, tolerance = 1.0e-5)) } - - @Test def gtsAndCovariatesVsCovariatesOnlyTest() { + @Test def gtsAndCovariatesVsCovariatesOnlyTest(): Unit = { /* R code: - y0 <- c(0, 0, 1, 1, 1, 1) - c1 <- c(0, 2, 1, -2, -2, 4) - c2 <- c(-1, 3, 5, 0, -4, 3) - gts <- c(0, 1, 2, 0, 0, 1) - logfit <- glm(y0 ~ c1 + c2 + gts, family=binomial(link="logit")) - coef(summary(logfit)) - */ + * y0 <- c(0, 0, 1, 1, 1, 1) c1 <- c(0, 2, 1, -2, -2, 4) c2 <- c(-1, 3, 5, 0, -4, 3) gts <- c(0, + * 1, 2, 0, 0, 1) logfit <- glm(y0 ~ c1 + c2 + gts, family=binomial(link="logit")) + * coef(summary(logfit)) */ val y = DenseVector(0d, 0d, 1d, 1d, 1d, 1d) @@ -86,7 +75,8 @@ class LogisticRegressionModelSuite extends HailSuite { (1.0, 1.0, 5.0), (1.0, -2.0, 0.0), (1.0, -2.0, -4.0), - (1.0, 4.0, 3.0)) + (1.0, 4.0, 3.0), + ) val gts = DenseVector(0d, 1d, 2d, 0d, 0d, 1d) @@ -99,48 +89,43 @@ class LogisticRegressionModelSuite extends HailSuite { val lrStats = LikelihoodRatioTest.test(X, y, nullFit, "logistic", 25, 1e-5).stats.get val scoreStats = LogisticScoreTest.test(X, y, nullFit, "logistic", 25, 1e-5).stats.get - assert(D_==(waldStats.b(0), -0.4811418, tolerance = 1.0E-6)) - assert(D_==(waldStats.b(1), -0.4293547, tolerance = 1.0E-6)) - assert(D_==(waldStats.b(2), -0.4214875, tolerance = 1.0E-6)) - assert(D_==(waldStats.b(3), 3.1775729, tolerance = 1.0E-6)) + assert(D_==(waldStats.b(0), -0.4811418, tolerance = 1.0e-6)) + assert(D_==(waldStats.b(1), -0.4293547, tolerance = 1.0e-6)) + assert(D_==(waldStats.b(2), -0.4214875, tolerance = 1.0e-6)) + assert(D_==(waldStats.b(3), 3.1775729, tolerance = 1.0e-6)) - assert(D_==(waldStats.se(0), 1.6203668, tolerance = 1.0E-6)) - assert(D_==(waldStats.se(1), 0.7256665, tolerance = 1.0E-6)) - assert(D_==(waldStats.se(2), 0.9223569, tolerance = 1.0E-6)) - assert(D_==(waldStats.se(3), 4.1094207, tolerance = 1.0E-6)) + assert(D_==(waldStats.se(0), 1.6203668, tolerance = 1.0e-6)) + assert(D_==(waldStats.se(1), 0.7256665, tolerance = 1.0e-6)) + assert(D_==(waldStats.se(2), 0.9223569, tolerance = 1.0e-6)) + assert(D_==(waldStats.se(3), 4.1094207, tolerance = 1.0e-6)) - assert(D_==(waldStats.z(0), -0.2969339, tolerance = 1.0E-6)) - assert(D_==(waldStats.z(1), -0.5916695, tolerance = 1.0E-6)) - assert(D_==(waldStats.z(2), -0.4569679, tolerance = 1.0E-6)) - assert(D_==(waldStats.z(3), 0.7732411, tolerance = 1.0E-6)) + assert(D_==(waldStats.z(0), -0.2969339, tolerance = 1.0e-6)) + assert(D_==(waldStats.z(1), -0.5916695, tolerance = 1.0e-6)) + assert(D_==(waldStats.z(2), -0.4569679, tolerance = 1.0e-6)) + assert(D_==(waldStats.z(3), 0.7732411, tolerance = 
1.0e-6)) - assert(D_==(waldStats.p(0), 0.7665170, tolerance = 1.0E-6)) - assert(D_==(waldStats.p(1), 0.5540720, tolerance = 1.0E-6)) - assert(D_==(waldStats.p(2), 0.6476941, tolerance = 1.0E-6)) - assert(D_==(waldStats.p(3), 0.4393797, tolerance = 1.0E-6)) + assert(D_==(waldStats.p(0), 0.7665170, tolerance = 1.0e-6)) + assert(D_==(waldStats.p(1), 0.5540720, tolerance = 1.0e-6)) + assert(D_==(waldStats.p(2), 0.6476941, tolerance = 1.0e-6)) + assert(D_==(waldStats.p(3), 0.4393797, tolerance = 1.0e-6)) /* R code: - logfitnull <- glm(y0 ~ c1 + c2, family=binomial(link="logit")) - lrtest <- anova(logfitnull, logfit, test="LRT") - chi2 <- lrtest[["Deviance"]][2] - pval <- lrtest[["Pr(>Chi)"]][2] - */ + * logfitnull <- glm(y0 ~ c1 + c2, family=binomial(link="logit")) lrtest <- anova(logfitnull, + * logfit, test="LRT") chi2 <- lrtest[["Deviance"]][2] pval <- lrtest[["Pr(>Chi)"]][2] */ - assert(D_==(nullFit.logLkhd, -3.643316979, tolerance = 1.0E-6)) - assert(D_==(lrStats.chi2, 0.8733246, tolerance = 1.0E-6)) - assert(D_==(lrStats.p, 0.3500365602, tolerance = 1.0E-6)) + assert(D_==(nullFit.logLkhd, -3.643316979, tolerance = 1.0e-6)) + assert(D_==(lrStats.chi2, 0.8733246, tolerance = 1.0e-6)) + assert(D_==(lrStats.p, 0.3500365602, tolerance = 1.0e-6)) /* R code: - scoretest <- anova(logfitnull, logfit, test="Rao") - chi2 <- scoretest[["Rao"]][2] - pval <- scoretest[["Pr(>Chi)"]][2] - */ + * scoretest <- anova(logfitnull, logfit, test="Rao") chi2 <- scoretest[["Rao"]][2] pval <- + * scoretest[["Pr(>Chi)"]][2] */ - assert(D_==(scoreStats.chi2, 0.7955342582, tolerance = 1.0E-5)) - assert(D_==(scoreStats.p, 0.3724319159, tolerance = 1.0E-5)) + assert(D_==(scoreStats.chi2, 0.7955342582, tolerance = 1.0e-5)) + assert(D_==(scoreStats.p, 0.3724319159, tolerance = 1.0e-5)) } - @Test def firthSeparationTest() { + @Test def firthSeparationTest(): Unit = { val y = DenseVector(0d, 0d, 0d, 1d, 1d, 1d) val X = y.asDenseMatrix.t @@ -148,8 +133,16 @@ class LogisticRegressionModelSuite extends HailSuite { var nullFit = nullModel.fit(None, 25, 1e-6) assert(!nullFit.converged) - nullFit = GLMFit(nullModel.bInterceptOnly(), - None, None, 0, nullFit.nIter, exploded = nullFit.exploded, converged = false) - assert(LogisticFirthTest.test(X, y, nullFit, "logistic", 25, 1e-5).stats.isDefined) // firth penalized MLE still exists - } + nullFit = GLMFit( + nullModel.bInterceptOnly(), + None, + None, + 0, + nullFit.nIter, + exploded = nullFit.exploded, + converged = false, + ) + assert(LogisticFirthTest.test(X, y, nullFit, "logistic", 25, + 1e-5).stats.isDefined) // firth penalized MLE still exists + } } diff --git a/hail/src/test/scala/is/hail/stats/StatsSuite.scala b/hail/src/test/scala/is/hail/stats/StatsSuite.scala index 6a19c53cffc..f8916ca718e 100644 --- a/hail/src/test/scala/is/hail/stats/StatsSuite.scala +++ b/hail/src/test/scala/is/hail/stats/StatsSuite.scala @@ -1,20 +1,17 @@ package is.hail.stats -import breeze.linalg.DenseMatrix -import is.hail.TestUtils._ -import is.hail.testUtils._ +import is.hail.HailSuite import is.hail.utils._ -import is.hail.variant._ -import is.hail.{HailSuite, TestUtils} + import org.apache.commons.math3.distribution.{ChiSquaredDistribution, NormalDistribution} import org.testng.annotations.Test class StatsSuite extends HailSuite { - @Test def chiSquaredTailTest() { + @Test def chiSquaredTailTest(): Unit = { val chiSq1 = new ChiSquaredDistribution(1) - assert(D_==(pchisqtail(1d,1), 1 - chiSq1.cumulativeProbability(1d))) - assert(D_==(pchisqtail(5.52341d,1), 1 - 
chiSq1.cumulativeProbability(5.52341d))) + assert(D_==(pchisqtail(1d, 1), 1 - chiSq1.cumulativeProbability(1d))) + assert(D_==(pchisqtail(5.52341d, 1), 1 - chiSq1.cumulativeProbability(5.52341d))) val chiSq2 = new ChiSquaredDistribution(2) assert(D_==(pchisqtail(1, 2), 1 - chiSq2.cumulativeProbability(1))) @@ -35,7 +32,7 @@ class StatsSuite extends HailSuite { assert(D_==(qchisqtail(5.507248e-89, 1), 400)) } - @Test def normalTest() { + @Test def normalTest(): Unit = { val normalDist = new NormalDistribution() assert(D_==(pnorm(1), normalDist.cumulativeProbability(1))) assert(math.abs(pnorm(-10) - normalDist.cumulativeProbability(-10)) < 1e-10) @@ -50,7 +47,7 @@ class StatsSuite extends HailSuite { assert(D_==(qnorm(2.753624e-89), -20)) } - @Test def poissonTest() { + @Test def poissonTest(): Unit = { // compare with R assert(D_==(dpois(5, 10), 0.03783327)) assert(qpois(0.3, 10) == 8) @@ -59,15 +56,20 @@ class StatsSuite extends HailSuite { assert(D_==(ppois(5, 10, lowerTail = false, logP = false), 0.932914)) assert(qpois(ppois(5, 10), 10) == 5) - assert(qpois(ppois(5, 10, lowerTail = false, logP = false), 10, lowerTail = false, logP = false) == 5) + assert(qpois( + ppois(5, 10, lowerTail = false, logP = false), + 10, + lowerTail = false, + logP = false, + ) == 5) assert(ppois(30, 1, lowerTail = false, logP = false) > 0) } - @Test def betaTest() { + @Test def betaTest(): Unit = { val tol = 1e-5 - assert(D_==(dbeta(.2 , 1, 3), 1.92, tol)) + assert(D_==(dbeta(.2, 1, 3), 1.92, tol)) assert(D_==(dbeta(0.70, 2, 10), 0.001515591, tol)) assert(D_==(dbeta(.4, 5, 3), 0.96768, tol)) assert(D_==(dbeta(.3, 7, 2), 0.0285768, tol)) @@ -80,7 +82,7 @@ class StatsSuite extends HailSuite { } - @Test def entropyTest() { + @Test def entropyTest(): Unit = { assert(D_==(entropy("accctg"), 1.79248, tolerance = 1e-5)) assert(D_==(entropy(Array(2, 3, 4, 5, 6, 6, 4)), 2.23593, tolerance = 1e-5)) diff --git a/hail/src/test/scala/is/hail/stats/eigSymDSuite.scala b/hail/src/test/scala/is/hail/stats/eigSymDSuite.scala index 038a97ef49b..73bcb32e9e4 100644 --- a/hail/src/test/scala/is/hail/stats/eigSymDSuite.scala +++ b/hail/src/test/scala/is/hail/stats/eigSymDSuite.scala @@ -1,13 +1,14 @@ package is.hail.stats -import breeze.linalg.{DenseMatrix, DenseVector, eigSym, svd} -import is.hail.utils._ import is.hail.{HailSuite, TestUtils} +import is.hail.utils._ + +import breeze.linalg.{eigSym, svd, DenseMatrix, DenseVector} import org.apache.commons.math3.random.JDKRandomGenerator import org.testng.annotations.Test class eigSymDSuite extends HailSuite { - @Test def eigSymTest() { + @Test def eigSymTest(): Unit = { val seed = 0 val rand = new JDKRandomGenerator() @@ -23,60 +24,56 @@ class eigSymDSuite extends HailSuite { val svdK = svd(K) val eigSymK = eigSym(K) val eigSymDK = eigSymD(K) - val eigSymRK = eigSymR(K) // eigSymD = svdW for (j <- 0 until n) { assert(D_==(svdW.S(j) * svdW.S(j), eigSymDK.eigenvalues(n - j - 1))) - for (i <- 0 until n) { + for (i <- 0 until n) assert(D_==(math.abs(svdW.U(i, j)), math.abs(eigSymDK.eigenvectors(i, n - j - 1)))) - } } // eigSymR = svdK for (j <- 0 until n) { assert(D_==(svdK.S(j), eigSymDK.eigenvalues(n - j - 1))) - for (i <- 0 until n) { + for (i <- 0 until n) assert(D_==(math.abs(svdK.U(i, j)), math.abs(eigSymDK.eigenvectors(i, n - j - 1)))) - } } // eigSymD = eigSym for (j <- 0 until n) { assert(D_==(eigSymK.eigenvalues(j), eigSymDK.eigenvalues(j))) - for (i <- 0 until n) { + for (i <- 0 until n) assert(D_==(math.abs(eigSymK.eigenvectors(i, j)), math.abs(eigSymDK.eigenvectors(i, 
j)))) - } } // small example - val K2 = DenseMatrix((2.0, 1.0),(1.0, 2.0)) + val K2 = DenseMatrix((2.0, 1.0), (1.0, 2.0)) val c = 1 / math.sqrt(2) val eigSymDK2 = eigSymD(K2) val eigSymRK2 = eigSymR(K2) assert(D_==(eigSymDK2.eigenvalues(0), 1.0)) assert(D_==(eigSymDK2.eigenvalues(1), 3.0)) - assert(D_==(math.abs(eigSymDK2.eigenvectors(0,0)), c)) - assert(D_==(math.abs(eigSymDK2.eigenvectors(1,0)), c)) - assert(D_==(math.abs(eigSymDK2.eigenvectors(0,1)), c)) - assert(D_==(math.abs(eigSymDK2.eigenvectors(1,1)), c)) + assert(D_==(math.abs(eigSymDK2.eigenvectors(0, 0)), c)) + assert(D_==(math.abs(eigSymDK2.eigenvectors(1, 0)), c)) + assert(D_==(math.abs(eigSymDK2.eigenvectors(0, 1)), c)) + assert(D_==(math.abs(eigSymDK2.eigenvectors(1, 1)), c)) assert(D_==(eigSymRK2.eigenvalues(0), 1.0)) assert(D_==(eigSymRK2.eigenvalues(1), 3.0)) - assert(D_==(math.abs(eigSymRK2.eigenvectors(0,0)), c)) - assert(D_==(math.abs(eigSymRK2.eigenvectors(1,0)), c)) - assert(D_==(math.abs(eigSymRK2.eigenvectors(0,1)), c)) - assert(D_==(math.abs(eigSymRK2.eigenvectors(1,1)), c)) + assert(D_==(math.abs(eigSymRK2.eigenvectors(0, 0)), c)) + assert(D_==(math.abs(eigSymRK2.eigenvectors(1, 0)), c)) + assert(D_==(math.abs(eigSymRK2.eigenvectors(0, 1)), c)) + assert(D_==(math.abs(eigSymRK2.eigenvectors(1, 1)), c)) } - def symEigSpeedTest() { + def symEigSpeedTest(): Unit = { val seed = 0 val rand = new JDKRandomGenerator() rand.setSeed(seed) - def timeSymEig() { + def timeSymEig(): Unit = { for (n <- 500 to 5500 by 500) { val W = DenseMatrix.fill[Double](n, n)(rand.nextGaussian()) val K = W * W.t @@ -100,7 +97,7 @@ class eigSymDSuite extends HailSuite { timeSymEig() } - @Test def triSolveTest() { + @Test def triSolveTest(): Unit = { val seed = 0 val rand = new JDKRandomGenerator() @@ -108,7 +105,7 @@ class eigSymDSuite extends HailSuite { (1 to 5).foreach { n => val A = DenseMatrix.zeros[Double](n, n) - (0 until n).foreach(i => (i until n).foreach(j => A(i,j) = rand.nextGaussian())) + (0 until n).foreach(i => (i until n).foreach(j => A(i, j) = rand.nextGaussian())) val x = DenseVector.fill[Double](n)(rand.nextGaussian()) diff --git a/hail/src/test/scala/is/hail/testUtils/AltAllele.scala b/hail/src/test/scala/is/hail/testUtils/AltAllele.scala index 7b81d2ae9e1..dc24905886f 100644 --- a/hail/src/test/scala/is/hail/testUtils/AltAllele.scala +++ b/hail/src/test/scala/is/hail/testUtils/AltAllele.scala @@ -1,8 +1,5 @@ package is.hail.testUtils -import org.apache.spark.sql.Row -import org.json4s._ - case class AltAllele(ref: String, alt: String) { require(ref != alt, "ref was equal to alt") require(!ref.isEmpty, "ref was an empty string") diff --git a/hail/src/test/scala/is/hail/testUtils/Variant.scala b/hail/src/test/scala/is/hail/testUtils/Variant.scala index a1fcbbb43b4..67907d8ba50 100644 --- a/hail/src/test/scala/is/hail/testUtils/Variant.scala +++ b/hail/src/test/scala/is/hail/testUtils/Variant.scala @@ -1,40 +1,32 @@ package is.hail.testUtils import is.hail.annotations.Annotation -import is.hail.types.virtual.{TArray, TLocus, TString, TStruct} import is.hail.variant._ -import org.apache.spark.sql.Row -import org.json4s._ import scala.collection.JavaConverters._ +import org.apache.spark.sql.Row + object Variant { - def apply(contig: String, - start: Int, - ref: String, - alt: String): Variant = { + def apply(contig: String, start: Int, ref: String, alt: String): Variant = Variant(contig, start, ref, Array(AltAllele(ref, alt))) - } - def apply(contig: String, - start: Int, - ref: String, - alts: Array[String]): Variant = 
Variant(contig, start, ref, alts.map(alt => AltAllele(ref, alt))) + def apply(contig: String, start: Int, ref: String, alts: Array[String]): Variant = + Variant(contig, start, ref, alts.map(alt => AltAllele(ref, alt))) - def apply(contig: String, - start: Int, - ref: String, - alts: Array[String], - rg: ReferenceGenome): Variant = { + def apply(contig: String, start: Int, ref: String, alts: Array[String], rg: ReferenceGenome) + : Variant = { rg.checkLocus(contig, start) Variant(contig, start, ref, alts) } - def apply(contig: String, + def apply( + contig: String, start: Int, ref: String, alts: java.util.ArrayList[String], - rg: ReferenceGenome): Variant = Variant(contig, start, ref, alts.asScala.toArray, rg) + rg: ReferenceGenome, + ): Variant = Variant(contig, start, ref, alts.asScala.toArray, rg) def fromLocusAlleles(a: Annotation): Variant = { val r = a.asInstanceOf[Row] @@ -47,16 +39,13 @@ object Variant { } } -case class Variant(contig: String, - start: Int, - ref: String, - altAlleles: IndexedSeq[AltAllele]) { +case class Variant(contig: String, start: Int, ref: String, altAlleles: IndexedSeq[AltAllele]) { require(altAlleles.forall(_.ref == ref)) - /* The position is 1-based. Telomeres are indicated by using positions 0 or N+1, where N is the length of the - corresponding chromosome or contig. See the VCF spec, v4.2, section 1.4.1. */ - require(start >= 0, s"invalid variant: negative position: '${ this.toString }'") - require(!ref.isEmpty, s"invalid variant: empty contig: '${ this.toString }'") + /* The position is 1-based. Telomeres are indicated by using positions 0 or N+1, where N is the + * length of the corresponding chromosome or contig. See the VCF spec, v4.2, section 1.4.1. */ + require(start >= 0, s"invalid variant: negative position: '${this.toString}'") + require(!ref.isEmpty, s"invalid variant: empty contig: '${this.toString}'") def nAltAlleles: Int = altAlleles.length @@ -80,5 +69,5 @@ case class Variant(contig: String, def locus: Locus = Locus(contig, start) override def toString: String = - s"$contig:$start:$ref:${ altAlleles.map(_.alt).mkString(",") }" + s"$contig:$start:$ref:${altAlleles.map(_.alt).mkString(",")}" } diff --git a/hail/src/test/scala/is/hail/expr/ir/ETypeSuite.scala b/hail/src/test/scala/is/hail/types/encoded/ETypeSuite.scala similarity index 76% rename from hail/src/test/scala/is/hail/expr/ir/ETypeSuite.scala rename to hail/src/test/scala/is/hail/types/encoded/ETypeSuite.scala index dec6236c145..226186c9e50 100644 --- a/hail/src/test/scala/is/hail/expr/ir/ETypeSuite.scala +++ b/hail/src/test/scala/is/hail/types/encoded/ETypeSuite.scala @@ -1,20 +1,21 @@ -package is.hail.expr.ir +package is.hail.types.encoded import is.hail.HailSuite import is.hail.annotations.{Annotation, Region, SafeNDArray, SafeRow} import is.hail.asm4s.Code +import is.hail.expr.ir.EmitFunctionBuilder import is.hail.io._ import is.hail.rvd.AbstractRVDSpec -import is.hail.types.encoded._ import is.hail.types.physical._ import is.hail.utils._ + import org.apache.spark.sql.Row import org.json4s.jackson.Serialization import org.testng.annotations.{DataProvider, Test} class ETypeSuite extends HailSuite { - @DataProvider(name="etypes") + @DataProvider(name = "etypes") def etypes(): Array[Array[Any]] = { Array[EType]( EInt32Required, @@ -31,16 +32,18 @@ class ETypeSuite extends HailSuite { EArray(EInt32Required, required = false), EArray(EArray(EInt32Optional, required = true), required = true), EBaseStruct(FastSeq(), required = true), - EBaseStruct(FastSeq(EField("x", EBinaryRequired, 0), 
EField("y", EFloat64Optional, 1)), required = true), - ENDArrayColumnMajor(EFloat64Required , 3) + EBaseStruct( + FastSeq(EField("x", EBinaryRequired, 0), EField("y", EFloat64Optional, 1)), + required = true, + ), + ENDArrayColumnMajor(EFloat64Required, 3), ).map(t => Array(t: Any)) } - @Test def testDataProvider(): Unit = { + @Test def testDataProvider(): Unit = etypes() - } - @Test(dataProvider="etypes") + @Test(dataProvider = "etypes") def testSerialization(etype: EType): Unit = { implicit val formats = AbstractRVDSpec.formats val s = Serialization.write(etype) @@ -53,7 +56,7 @@ class ETypeSuite extends HailSuite { fb.emitWithBuilder { cb => val arg1 = inPType.loadCheapSCode(cb, fb.apply_method.getCodeParam[Long](1)) val arg2 = fb.apply_method.getCodeParam[OutputBuffer](2) - cb.invokeVoid(enc, arg1, arg2) + cb.invokeVoid(enc, cb.this_, arg1, arg2) Code._empty } @@ -71,43 +74,57 @@ class ETypeSuite extends HailSuite { val ibArg = fb2.apply_method.getCodeParam[InputBuffer](2) val dec = eType.buildDecoderMethod(outPType.virtualType, fb2.apply_method.ecb) fb2.emitWithBuilder[Long] { cb => - val decoded = cb.invokeSCode(dec, regArg, ibArg) + val decoded = cb.invokeSCode(dec, cb.this_, regArg, ibArg) outPType.store(cb, regArg, decoded, deepCopy = false) } - val result = fb2.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r).apply(ctx.r, new MemoryInputBuffer(buffer)) + val result = fb2.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r).apply( + ctx.r, + new MemoryInputBuffer(buffer), + ) SafeRow.read(outPType, result) } - def assertEqualEncodeDecode(inPType: PType, eType: EType, outPType: PType, data: Annotation): Unit = { + def assertEqualEncodeDecode(inPType: PType, eType: EType, outPType: PType, data: Annotation) + : Unit = { val encodeDecodeResult = encodeDecode(inPType, eType, outPType, data) assert(encodeDecodeResult == data) } - @Test def testDifferentRequirednessEncodeDecode() { + @Test def testDifferentRequirednessEncodeDecode(): Unit = { val inPType = PCanonicalArray( - PCanonicalStruct(true, + PCanonicalStruct( + true, "a" -> PInt32Required, "b" -> PInt32Optional, "c" -> PCanonicalStringRequired, - "d" -> PCanonicalArray(PCanonicalStruct(true, "x" -> PInt64Required), true)), - false) + "d" -> PCanonicalArray(PCanonicalStruct(true, "x" -> PInt64Required), true), + ), + false, + ) val etype = EArray( - EBaseStruct(FastSeq( - EField("a", EInt32Required, 0), - EField("b", EInt32Optional, 1), - EField("c", EBinaryOptional, 2), - EField("d", EArray(EBaseStruct(FastSeq(EField("x", EInt64Optional, 0)), false), true), 3)), - true), - false) + EBaseStruct( + FastSeq( + EField("a", EInt32Required, 0), + EField("b", EInt32Optional, 1), + EField("c", EBinaryOptional, 2), + EField("d", EArray(EBaseStruct(FastSeq(EField("x", EInt64Optional, 0)), false), true), 3), + ), + true, + ), + false, + ) val outPType = PCanonicalArray( - PCanonicalStruct(false, + PCanonicalStruct( + false, "a" -> PInt32Optional, "b" -> PInt32Optional, "c" -> PCanonicalStringOptional, - "d" -> PCanonicalArray(PCanonicalStruct(false, "x" -> PInt64Optional), false)), - false) + "d" -> PCanonicalArray(PCanonicalStruct(false, "x" -> PInt64Optional), false), + ), + false, + ) val data = FastSeq(Row(1, null, "abc", FastSeq(Row(7L), Row(8L)))) @@ -141,23 +158,24 @@ class ETypeSuite extends HailSuite { new SafeNDArray(IndexedSeq(3L, 2L, 1L), FastSeq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0))) // Test for skipping - val pStructContainingNDArray = PCanonicalStruct(true, - "a" -> pTypeInt2, - "b" -> 
PInt32Optional - ) - val pOnlyReadB = PCanonicalStruct(true, - "b" -> PInt32Optional - ) + val pStructContainingNDArray = PCanonicalStruct(true, "a" -> pTypeInt2, "b" -> PInt32Optional) + val pOnlyReadB = PCanonicalStruct(true, "b" -> PInt32Optional) val eStructContainingNDArray = EBaseStruct( FastSeq( EField("a", ENDArrayColumnMajor(EInt32Required, 2, true), 0), - EField("b", EInt32Required, 1) + EField("b", EInt32Required, 1), ), - true) + true, + ) val dataStruct = Row(dataInt2, 3) - assert(encodeDecode(pStructContainingNDArray, eStructContainingNDArray, pOnlyReadB, dataStruct) == + assert(encodeDecode( + pStructContainingNDArray, + eStructContainingNDArray, + pOnlyReadB, + dataStruct, + ) == Row(3)) } @@ -165,7 +183,7 @@ class ETypeSuite extends HailSuite { val etype = EArray(EBinary(false), false) val toEncode = PCanonicalArray(PCanonicalStringRequired, false) val toDecode = PCanonicalArray(PCanonicalStringOptional, false) - val longListOfStrings = (0 until 36).map(idx => s"foo_name_sample_${idx}") + val longListOfStrings = (0 until 36).map(idx => s"foo_name_sample_$idx") val data = longListOfStrings assert(encodeDecode(toEncode, etype, toDecode, data) == data) diff --git a/hail/src/test/scala/is/hail/types/physical/PBaseStructSuite.scala b/hail/src/test/scala/is/hail/types/physical/PBaseStructSuite.scala index a121430c9b4..42a24869dce 100644 --- a/hail/src/test/scala/is/hail/types/physical/PBaseStructSuite.scala +++ b/hail/src/test/scala/is/hail/types/physical/PBaseStructSuite.scala @@ -1,65 +1,141 @@ package is.hail.types.physical -import is.hail.HailSuite import is.hail.annotations.Annotation + import org.testng.annotations.Test class PBaseStructSuite extends PhysicalTestUtils { - @Test def testStructCopy() { - def runTests(deepCopy: Boolean, interpret: Boolean = false) { - copyTestExecutor(PCanonicalStruct(), PCanonicalStruct(), Annotation(), - deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalStruct("a" -> PInt64()), PCanonicalStruct("a" -> PInt64()), Annotation(12L), - deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalStruct("a" -> PInt64()), PCanonicalStruct("a" -> PInt64()), Annotation(null), - deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalStruct("a" -> PInt64(true)), PCanonicalStruct("a" -> PInt64(true)), Annotation(11L), - deepCopy = deepCopy, interpret = interpret) - - copyTestExecutor(PCanonicalStruct("a" -> PInt64(false)), PCanonicalStruct("a" -> PInt64(true)), Annotation(3L), - deepCopy = deepCopy, interpret = interpret) - - var srcType = PCanonicalStruct("a" -> PInt64(true), "b" -> PInt32(true), "c" -> PFloat64(true), "d" -> PFloat32(true), "e" -> PBoolean(true)) - var destType = PCanonicalStruct("a" -> PInt64(true), "b" -> PInt32(true), "c" -> PFloat64(true), "d" -> PFloat32(true), "e" -> PBoolean(true)) - var expectedVal = Annotation(13L, 12, 13.0, 10.0F, true) + @Test def testStructCopy(): Unit = { + def runTests(deepCopy: Boolean, interpret: Boolean = false): Unit = { + copyTestExecutor( + PCanonicalStruct(), + PCanonicalStruct(), + Annotation(), + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalStruct("a" -> PInt64()), + PCanonicalStruct("a" -> PInt64()), + Annotation(12L), + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalStruct("a" -> PInt64()), + PCanonicalStruct("a" -> PInt64()), + Annotation(null), + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalStruct("a" -> PInt64(true)), + 
PCanonicalStruct("a" -> PInt64(true)), + Annotation(11L), + deepCopy = deepCopy, + interpret = interpret, + ) + + copyTestExecutor( + PCanonicalStruct("a" -> PInt64(false)), + PCanonicalStruct("a" -> PInt64(true)), + Annotation(3L), + deepCopy = deepCopy, + interpret = interpret, + ) + + var srcType = PCanonicalStruct( + "a" -> PInt64(true), + "b" -> PInt32(true), + "c" -> PFloat64(true), + "d" -> PFloat32(true), + "e" -> PBoolean(true), + ) + var destType = PCanonicalStruct( + "a" -> PInt64(true), + "b" -> PInt32(true), + "c" -> PFloat64(true), + "d" -> PFloat32(true), + "e" -> PBoolean(true), + ) + var expectedVal = Annotation(13L, 12, 13.0, 10.0f, true) copyTestExecutor(srcType, destType, expectedVal, deepCopy = deepCopy, interpret = interpret) srcType = PCanonicalStruct("a" -> srcType, "c" -> PFloat32()) destType = PCanonicalStruct("a" -> destType, "c" -> PFloat32()) - var nestedExpectedVal = Annotation(expectedVal, 13.0F) - copyTestExecutor(srcType, destType, nestedExpectedVal, deepCopy = deepCopy, interpret = interpret) - - srcType = PCanonicalStruct("a" -> PInt64(), "b" -> PInt32(), "c" -> PFloat64(), "d" -> PFloat32(), "e" -> PBoolean()) - destType = PCanonicalStruct("a" -> PInt64(), "b" -> PInt32(), "c" -> PFloat64(), "d" -> PFloat32(), "e" -> PBoolean()) + var nestedExpectedVal = Annotation(expectedVal, 13.0f) + copyTestExecutor(srcType, destType, nestedExpectedVal, deepCopy = deepCopy, + interpret = interpret) + + srcType = PCanonicalStruct( + "a" -> PInt64(), + "b" -> PInt32(), + "c" -> PFloat64(), + "d" -> PFloat32(), + "e" -> PBoolean(), + ) + destType = PCanonicalStruct( + "a" -> PInt64(), + "b" -> PInt32(), + "c" -> PFloat64(), + "d" -> PFloat32(), + "e" -> PBoolean(), + ) copyTestExecutor(srcType, destType, expectedVal, deepCopy = deepCopy, interpret = interpret) srcType = PCanonicalStruct("a" -> srcType, "b" -> PFloat32()) destType = PCanonicalStruct("a" -> destType, "b" -> PFloat32()) - nestedExpectedVal = Annotation(expectedVal, 14.0F) - copyTestExecutor(srcType, destType, nestedExpectedVal, deepCopy = deepCopy, interpret = interpret) - - srcType = PCanonicalStruct("a" -> PInt64(), "b" -> PInt32(true), "c" -> PFloat64(), "d" -> PFloat32(), "e" -> PBoolean()) - destType = PCanonicalStruct("a" -> PInt64(), "b" -> PInt32(), "c" -> PFloat64(), "d" -> PFloat32(), "e" -> PBoolean()) + nestedExpectedVal = Annotation(expectedVal, 14.0f) + copyTestExecutor(srcType, destType, nestedExpectedVal, deepCopy = deepCopy, + interpret = interpret) + + srcType = PCanonicalStruct( + "a" -> PInt64(), + "b" -> PInt32(true), + "c" -> PFloat64(), + "d" -> PFloat32(), + "e" -> PBoolean(), + ) + destType = PCanonicalStruct( + "a" -> PInt64(), + "b" -> PInt32(), + "c" -> PFloat64(), + "d" -> PFloat32(), + "e" -> PBoolean(), + ) copyTestExecutor(srcType, destType, expectedVal, deepCopy = deepCopy, interpret = interpret) srcType = PCanonicalStruct("a" -> srcType, "b" -> PFloat32()) destType = PCanonicalStruct("a" -> destType, "b" -> PFloat32()) - nestedExpectedVal = Annotation(Annotation(13L, 12, 13.0, 10.0F, true), 15.0F) - copyTestExecutor(srcType, destType, nestedExpectedVal, deepCopy = deepCopy, interpret = interpret) - - srcType = PCanonicalStruct("a" -> PInt64(), "b" -> PInt32(true), "c" -> PFloat64(), "d" -> PFloat32(), "e" -> PBoolean()) - destType = PCanonicalStruct("a" -> PInt64(), "b" -> PInt32(), "c" -> PFloat64(true), "d" -> PFloat32(), "e" -> PBoolean()) + nestedExpectedVal = Annotation(Annotation(13L, 12, 13.0, 10.0f, true), 15.0f) + copyTestExecutor(srcType, destType, 
nestedExpectedVal, deepCopy = deepCopy, + interpret = interpret) + + srcType = PCanonicalStruct( + "a" -> PInt64(), + "b" -> PInt32(true), + "c" -> PFloat64(), + "d" -> PFloat32(), + "e" -> PBoolean(), + ) + destType = PCanonicalStruct( + "a" -> PInt64(), + "b" -> PInt32(), + "c" -> PFloat64(true), + "d" -> PFloat32(), + "e" -> PBoolean(), + ) copyTestExecutor(srcType, destType, expectedVal, deepCopy = deepCopy) srcType = PCanonicalStruct("a" -> srcType, "b" -> PFloat32()) destType = PCanonicalStruct("a" -> destType, "b" -> PFloat32()) - nestedExpectedVal = Annotation(expectedVal, 13F) + nestedExpectedVal = Annotation(expectedVal, 13f) copyTestExecutor(srcType, destType, nestedExpectedVal, deepCopy = deepCopy) srcType = PCanonicalStruct("a" -> PCanonicalArray(PInt32(true)), "b" -> PInt64()) destType = PCanonicalStruct("a" -> PCanonicalArray(PInt32()), "b" -> PInt64()) - expectedVal = Annotation(IndexedSeq(1,5,7,2,31415926), 31415926535897L) + expectedVal = Annotation(IndexedSeq(1, 5, 7, 2, 31415926), 31415926535897L) copyTestExecutor(srcType, destType, expectedVal, deepCopy = deepCopy, interpret = interpret) expectedVal = Annotation(null, 31415926535897L) @@ -69,15 +145,28 @@ class PBaseStructSuite extends PhysicalTestUtils { srcType = PCanonicalStruct("a" -> PCanonicalArray(PInt32(true)), "b" -> PInt64()) destType = PCanonicalStruct("a" -> PCanonicalArray(PInt32(), true), "b" -> PInt64()) - copyTestExecutor(srcType, destType, expectedVal, deepCopy = deepCopy, interpret = interpret, expectRuntimeError = true) - - srcType = PCanonicalStruct("a" -> PCanonicalArray(PCanonicalArray(PCanonicalStruct("a" -> PInt32(true)))), "b" -> PInt64()) - destType = PCanonicalStruct("a" -> PCanonicalArray(PCanonicalArray(PCanonicalStruct("a" -> PInt32(true)))), "b" -> PInt64()) + copyTestExecutor(srcType, destType, expectedVal, deepCopy = deepCopy, interpret = interpret, + expectRuntimeError = true) + + srcType = PCanonicalStruct( + "a" -> PCanonicalArray(PCanonicalArray(PCanonicalStruct("a" -> PInt32(true)))), + "b" -> PInt64(), + ) + destType = PCanonicalStruct( + "a" -> PCanonicalArray(PCanonicalArray(PCanonicalStruct("a" -> PInt32(true)))), + "b" -> PInt64(), + ) expectedVal = Annotation(IndexedSeq(null, IndexedSeq(null, Annotation(1))), 31415926535897L) copyTestExecutor(srcType, destType, expectedVal, deepCopy = deepCopy, interpret = interpret) - srcType = PCanonicalStruct(true, "foo" -> PCanonicalStruct("bar" -> PCanonicalArray(PInt32(true), true))) - destType = PCanonicalStruct(false, "foo" -> PCanonicalStruct("bar" -> PCanonicalArray(PInt32(false), false))) + srcType = PCanonicalStruct( + true, + "foo" -> PCanonicalStruct("bar" -> PCanonicalArray(PInt32(true), true)), + ) + destType = PCanonicalStruct( + false, + "foo" -> PCanonicalStruct("bar" -> PCanonicalArray(PInt32(false), false)), + ) expectedVal = Annotation(Annotation(IndexedSeq(1, 2, 3))) copyTestExecutor(srcType, destType, expectedVal, deepCopy = deepCopy, interpret = interpret) } @@ -89,10 +178,15 @@ class PBaseStructSuite extends PhysicalTestUtils { runTests(false, true) } - @Test def tupleCopyTests() { - def runTests(deepCopy: Boolean, interpret: Boolean = false) { - copyTestExecutor(PCanonicalTuple(false, PCanonicalString(true), PCanonicalString(true)), PCanonicalTuple(false, PCanonicalString(), PCanonicalString()), Annotation("1", "2"), - deepCopy = deepCopy, interpret = interpret) + @Test def tupleCopyTests(): Unit = { + def runTests(deepCopy: Boolean, interpret: Boolean = false): Unit = { + copyTestExecutor( + PCanonicalTuple(false, 
PCanonicalString(true), PCanonicalString(true)), + PCanonicalTuple(false, PCanonicalString(), PCanonicalString()), + Annotation("1", "2"), + deepCopy = deepCopy, + interpret = interpret, + ) } runTests(true) diff --git a/hail/src/test/scala/is/hail/types/physical/PBinarySuite.scala b/hail/src/test/scala/is/hail/types/physical/PBinarySuite.scala index 483da14bcd8..81e68c81aaf 100644 --- a/hail/src/test/scala/is/hail/types/physical/PBinarySuite.scala +++ b/hail/src/test/scala/is/hail/types/physical/PBinarySuite.scala @@ -1,20 +1,33 @@ package is.hail.types.physical -import is.hail.HailSuite -import is.hail.annotations.Annotation import org.testng.annotations.Test class PBinarySuite extends PhysicalTestUtils { - @Test def testCopy() { - def runTests(deepCopy: Boolean, interpret: Boolean = false) { - copyTestExecutor(PCanonicalString(), PCanonicalString(), "", - deepCopy = deepCopy, interpret = interpret) + @Test def testCopy(): Unit = { + def runTests(deepCopy: Boolean, interpret: Boolean = false): Unit = { + copyTestExecutor( + PCanonicalString(), + PCanonicalString(), + "", + deepCopy = deepCopy, + interpret = interpret, + ) - copyTestExecutor(PCanonicalString(), PCanonicalString(true), "TopLevelDowncastAllowed", - deepCopy = deepCopy, interpret = interpret) + copyTestExecutor( + PCanonicalString(), + PCanonicalString(true), + "TopLevelDowncastAllowed", + deepCopy = deepCopy, + interpret = interpret, + ) - copyTestExecutor(PCanonicalString(true), PCanonicalString(), "UpcastAllowed", - deepCopy = deepCopy, interpret = interpret) + copyTestExecutor( + PCanonicalString(true), + PCanonicalString(), + "UpcastAllowed", + deepCopy = deepCopy, + interpret = interpret, + ) } runTests(true) diff --git a/hail/src/test/scala/is/hail/types/physical/PCallSuite.scala b/hail/src/test/scala/is/hail/types/physical/PCallSuite.scala index 039cfed46b6..d2250d851b4 100644 --- a/hail/src/test/scala/is/hail/types/physical/PCallSuite.scala +++ b/hail/src/test/scala/is/hail/types/physical/PCallSuite.scala @@ -1,25 +1,42 @@ package is.hail.types.physical -import is.hail.HailSuite import org.testng.annotations.Test class PCallSuite extends PhysicalTestUtils { - @Test def copyTests() { - def runTests(deepCopy: Boolean, interpret: Boolean = false) { - copyTestExecutor(PCanonicalCall(), PCanonicalCall(), + @Test def copyTests(): Unit = { + def runTests(deepCopy: Boolean, interpret: Boolean = false): Unit = { + copyTestExecutor( + PCanonicalCall(), + PCanonicalCall(), 2, - deepCopy = deepCopy, interpret = interpret) + deepCopy = deepCopy, + interpret = interpret, + ) // downcast at top level allowed, since PCanonicalCall wraps a primitive - copyTestExecutor(PCanonicalCall(), PCanonicalCall(true), + copyTestExecutor( + PCanonicalCall(), + PCanonicalCall(true), 2, - deepCopy = deepCopy, interpret = interpret) + deepCopy = deepCopy, + interpret = interpret, + ) - copyTestExecutor(PCanonicalArray(PCanonicalCall(true), true), PCanonicalArray(PCanonicalCall()), - IndexedSeq(2, 3), deepCopy = deepCopy, interpret = interpret) + copyTestExecutor( + PCanonicalArray(PCanonicalCall(true), true), + PCanonicalArray(PCanonicalCall()), + IndexedSeq(2, 3), + deepCopy = deepCopy, + interpret = interpret, + ) - copyTestExecutor(PCanonicalArray(PCanonicalCall(), true), PCanonicalArray(PCanonicalCall(true)), - IndexedSeq(2, 3), deepCopy = deepCopy, interpret = interpret) + copyTestExecutor( + PCanonicalArray(PCanonicalCall(), true), + PCanonicalArray(PCanonicalCall(true)), + IndexedSeq(2, 3), + deepCopy = deepCopy, + interpret = interpret, + 
) } runTests(true) diff --git a/hail/src/test/scala/is/hail/types/physical/PContainerTest.scala b/hail/src/test/scala/is/hail/types/physical/PContainerTest.scala index 94aa3d4104c..037f19a2aba 100644 --- a/hail/src/test/scala/is/hail/types/physical/PContainerTest.scala +++ b/hail/src/test/scala/is/hail/types/physical/PContainerTest.scala @@ -4,30 +4,32 @@ import is.hail.annotations.{Annotation, Region, ScalaToRegionValue} import is.hail.asm4s._ import is.hail.expr.ir.EmitFunctionBuilder import is.hail.utils._ + import org.testng.annotations.Test class PContainerTest extends PhysicalTestUtils { def nullInByte(nElements: Int, missingElement: Int) = { - IndexedSeq.tabulate(nElements)(i => { + IndexedSeq.tabulate(nElements) { i => if (i == missingElement - 1) null else i + 1L - }) + } } def testContainsNonZeroBits(sourceType: PCanonicalArray, data: IndexedSeq[Any]) = { - val srcRegion = Region(pool=pool) + val srcRegion = Region(pool = pool) val src = ScalaToRegionValue(ctx.stateManager, srcRegion, sourceType, data) log.info(s"Testing $data") - val res = Region.containsNonZeroBits(src + sourceType.missingBytesOffset, sourceType.loadLength(src)) + val res = + Region.containsNonZeroBits(src + sourceType.missingBytesOffset, sourceType.loadLength(src)) res } def testContainsNonZeroBitsStaged(sourceType: PCanonicalArray, data: IndexedSeq[Any]) = { - val srcRegion = Region(pool=pool) + val srcRegion = Region(pool = pool) val src = ScalaToRegionValue(ctx.stateManager, srcRegion, sourceType, data) log.info(s"Testing $data") @@ -35,14 +37,17 @@ class PContainerTest extends PhysicalTestUtils { val fb = EmitFunctionBuilder[Long, Boolean](ctx, "not_empty") val value = fb.getCodeParam[Long](1) - fb.emit(Region.containsNonZeroBits(value + sourceType.missingBytesOffset, sourceType.loadLength(value).toL)) + fb.emit(Region.containsNonZeroBits( + value + sourceType.missingBytesOffset, + sourceType.loadLength(value).toL, + )) val res = fb.result()(theHailClassLoader)(src) res } def testHasMissingValues(sourceType: PArray, data: IndexedSeq[Any]) = { - val srcRegion = Region(pool=pool) + val srcRegion = Region(pool = pool) val src = ScalaToRegionValue(ctx.stateManager, srcRegion, sourceType, data) log.info(s"\nTesting $data") @@ -56,7 +61,7 @@ class PContainerTest extends PhysicalTestUtils { res } - @Test def checkFirstNonZeroByte() { + @Test def checkFirstNonZeroByte(): Unit = { val sourceType = PCanonicalArray(PInt64(false)) assert(testContainsNonZeroBits(sourceType, nullInByte(0, 0)) == false) @@ -98,14 +103,14 @@ class PContainerTest extends PhysicalTestUtils { assert(testContainsNonZeroBits(sourceType, nullInByte(73, 64)) == true) } - @Test def checkFirstNonZeroByteStaged() { + @Test def checkFirstNonZeroByteStaged(): Unit = { val sourceType = PCanonicalArray(PInt64(false)) assert(testContainsNonZeroBitsStaged(sourceType, nullInByte(32, 0)) == false) assert(testContainsNonZeroBitsStaged(sourceType, nullInByte(73, 64)) == true) } - @Test def checkHasMissingValues() { + @Test def checkHasMissingValues(): Unit = { val sourceType = PCanonicalArray(PInt64(false)) assert(testHasMissingValues(sourceType, nullInByte(1, 0)) == false) @@ -114,80 +119,181 @@ class PContainerTest extends PhysicalTestUtils { for { num <- Seq(2, 16, 31, 32, 33, 50, 63, 64, 65, 90, 127, 128, 129) - missing <- 1 to num + missing <- 1 to num } assert(testHasMissingValues(sourceType, nullInByte(num, missing)) == true) } - @Test def arrayCopyTest() { - // Note: can't test where data is null due to ArrayStack.top semantics (ScalaToRegionValue: 
assert(size_ > 0)) - def runTests(deepCopy: Boolean, interpret: Boolean) { - copyTestExecutor(PCanonicalArray(PInt32()), PCanonicalArray(PInt64()), IndexedSeq(1, 2, 3, 4, 5, 6, 7, 8, 9), - expectCompileError = true, deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalArray(PInt32()), PCanonicalArray(PInt32()), IndexedSeq(1, 2, 3, 4), - deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalArray(PInt32()), PCanonicalArray(PInt32()), IndexedSeq(1, 2, 3, 4), - deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalArray(PInt32()), PCanonicalArray(PInt32()), IndexedSeq(1, null, 3, 4), - deepCopy = deepCopy, interpret = interpret) + @Test def arrayCopyTest(): Unit = { + /* Note: can't test where data is null due to ArrayStack.top semantics (ScalaToRegionValue: + * assert(size_ > 0)) */ + def runTests(deepCopy: Boolean, interpret: Boolean): Unit = { + copyTestExecutor( + PCanonicalArray(PInt32()), + PCanonicalArray(PInt64()), + IndexedSeq(1, 2, 3, 4, 5, 6, 7, 8, 9), + expectCompileError = true, + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalArray(PInt32()), + PCanonicalArray(PInt32()), + IndexedSeq(1, 2, 3, 4), + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalArray(PInt32()), + PCanonicalArray(PInt32()), + IndexedSeq(1, 2, 3, 4), + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalArray(PInt32()), + PCanonicalArray(PInt32()), + IndexedSeq(1, null, 3, 4), + deepCopy = deepCopy, + interpret = interpret, + ) // test upcast - copyTestExecutor(PCanonicalArray(PInt32(true)), PCanonicalArray(PInt32()), IndexedSeq(1, 2, 3, 4), - deepCopy = deepCopy, interpret = interpret) + copyTestExecutor( + PCanonicalArray(PInt32(true)), + PCanonicalArray(PInt32()), + IndexedSeq(1, 2, 3, 4), + deepCopy = deepCopy, + interpret = interpret, + ) - // test mismatched top-level requiredeness, allowed because by source value address must be present and therefore non-null - copyTestExecutor(PCanonicalArray(PInt32()), PCanonicalArray(PInt32(), true), IndexedSeq(1, 2, 3, 4), - deepCopy = deepCopy, interpret = interpret) + /* test mismatched top-level requiredeness, allowed because by source value address must be + * present and therefore non-null */ + copyTestExecutor( + PCanonicalArray(PInt32()), + PCanonicalArray(PInt32(), true), + IndexedSeq(1, 2, 3, 4), + deepCopy = deepCopy, + interpret = interpret, + ) // downcast disallowed - copyTestExecutor(PCanonicalArray(PInt32()), PCanonicalArray(PInt32(true)), IndexedSeq(1, 2, 3, 4), - deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalArray(PCanonicalArray(PInt64())), PCanonicalArray(PCanonicalArray(PInt64(), true)), - FastSeq(FastSeq(20L), FastSeq(1L), FastSeq(20L,5L,31L,41L), FastSeq(1L,2L,3L)), - deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalArray(PCanonicalArray(PInt64())), PCanonicalArray(PCanonicalArray(PInt64(), true)), - FastSeq(FastSeq(20L), FastSeq(1L), FastSeq(20L,5L,31L,41L), FastSeq(1L,2L,3L)), - deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalArray(PCanonicalArray(PInt64())), PCanonicalArray(PCanonicalArray(PInt64(true))), - FastSeq(FastSeq(20L), FastSeq(1L), FastSeq(20L,5L,31L,41L), FastSeq(1L,2L,3L)), - deepCopy = deepCopy, interpret = interpret) + copyTestExecutor( + PCanonicalArray(PInt32()), + PCanonicalArray(PInt32(true)), + IndexedSeq(1, 2, 3, 4), + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + 
PCanonicalArray(PCanonicalArray(PInt64())), + PCanonicalArray(PCanonicalArray(PInt64(), true)), + FastSeq(FastSeq(20L), FastSeq(1L), FastSeq(20L, 5L, 31L, 41L), FastSeq(1L, 2L, 3L)), + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalArray(PCanonicalArray(PInt64())), + PCanonicalArray(PCanonicalArray(PInt64(), true)), + FastSeq(FastSeq(20L), FastSeq(1L), FastSeq(20L, 5L, 31L, 41L), FastSeq(1L, 2L, 3L)), + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalArray(PCanonicalArray(PInt64())), + PCanonicalArray(PCanonicalArray(PInt64(true))), + FastSeq(FastSeq(20L), FastSeq(1L), FastSeq(20L, 5L, 31L, 41L), FastSeq(1L, 2L, 3L)), + deepCopy = deepCopy, + interpret = interpret, + ) // test empty arrays - copyTestExecutor(PCanonicalArray(PInt32()), PCanonicalArray(PInt32()), FastSeq(), - deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalArray(PInt32(true)), PCanonicalArray(PInt32(true)), FastSeq(), - deepCopy = deepCopy, interpret = interpret) + copyTestExecutor( + PCanonicalArray(PInt32()), + PCanonicalArray(PInt32()), + FastSeq(), + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalArray(PInt32(true)), + PCanonicalArray(PInt32(true)), + FastSeq(), + deepCopy = deepCopy, + interpret = interpret, + ) // test missing-only array - copyTestExecutor(PCanonicalArray(PInt64()), PCanonicalArray(PInt64()), - FastSeq(null), deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalArray(PCanonicalArray(PInt64())), PCanonicalArray(PCanonicalArray(PInt64())), - FastSeq(FastSeq(null)), deepCopy = deepCopy, interpret = interpret) + copyTestExecutor( + PCanonicalArray(PInt64()), + PCanonicalArray(PInt64()), + FastSeq(null), + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalArray(PCanonicalArray(PInt64())), + PCanonicalArray(PCanonicalArray(PInt64())), + FastSeq(FastSeq(null)), + deepCopy = deepCopy, + interpret = interpret, + ) // test 2D arrays - copyTestExecutor(PCanonicalArray(PCanonicalArray(PInt64())), PCanonicalArray(PCanonicalArray(PInt64())), - FastSeq(null, FastSeq(null), FastSeq(20L,5L,31L,41L), FastSeq(1L,2L,3L)), - deepCopy = deepCopy, interpret = interpret) + copyTestExecutor( + PCanonicalArray(PCanonicalArray(PInt64())), + PCanonicalArray(PCanonicalArray(PInt64())), + FastSeq(null, FastSeq(null), FastSeq(20L, 5L, 31L, 41L), FastSeq(1L, 2L, 3L)), + deepCopy = deepCopy, + interpret = interpret, + ) // test complex nesting val complexNesting = FastSeq( - FastSeq( FastSeq(20L,30L,31L,41L), FastSeq(20L,22L,31L,43L) ), - FastSeq( FastSeq(1L,3L,31L,41L), FastSeq(0L,30L,17L,41L) ) - ) - - copyTestExecutor(PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64(true), true), true), true), PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64()))), - complexNesting, deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64(true), true), true)), PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64()))), - complexNesting, deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64(true), true))), PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64()))), - complexNesting, deepCopy = deepCopy, interpret = interpret) - copyTestExecutor(PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64(true)))), PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64()))), - complexNesting, deepCopy = deepCopy, interpret = interpret) - 
copyTestExecutor(PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64()))), PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64()))), - complexNesting, deepCopy = deepCopy, interpret = interpret) - - val srcType = PCanonicalArray(PCanonicalStruct("a" -> PCanonicalArray(PInt32(true)), "b" -> PInt64())) - val destType = PCanonicalArray(PCanonicalStruct("a" -> PCanonicalArray(PInt32()), "b" -> PInt64())) - val expectedVal = IndexedSeq(Annotation(IndexedSeq(1,5,7,2,31415926), 31415926535897L)) + FastSeq(FastSeq(20L, 30L, 31L, 41L), FastSeq(20L, 22L, 31L, 43L)), + FastSeq(FastSeq(1L, 3L, 31L, 41L), FastSeq(0L, 30L, 17L, 41L)), + ) + + copyTestExecutor( + PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64(true), true), true), true), + PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64()))), + complexNesting, + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64(true), true), true)), + PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64()))), + complexNesting, + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64(true), true))), + PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64()))), + complexNesting, + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64(true)))), + PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64()))), + complexNesting, + deepCopy = deepCopy, + interpret = interpret, + ) + copyTestExecutor( + PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64()))), + PCanonicalArray(PCanonicalArray(PCanonicalArray(PInt64()))), + complexNesting, + deepCopy = deepCopy, + interpret = interpret, + ) + + val srcType = + PCanonicalArray(PCanonicalStruct("a" -> PCanonicalArray(PInt32(true)), "b" -> PInt64())) + val destType = + PCanonicalArray(PCanonicalStruct("a" -> PCanonicalArray(PInt32()), "b" -> PInt64())) + val expectedVal = IndexedSeq(Annotation(IndexedSeq(1, 5, 7, 2, 31415926), 31415926535897L)) copyTestExecutor(srcType, destType, expectedVal, deepCopy = deepCopy, interpret = interpret) } @@ -198,16 +304,31 @@ class PContainerTest extends PhysicalTestUtils { runTests(false, interpret = true) } - @Test def dictCopyTests() { - def runTests(deepCopy: Boolean, interpret: Boolean) { - copyTestExecutor(PCanonicalDict(PCanonicalString(), PInt32()), PCanonicalDict(PCanonicalString(), PInt32()), Map("test" -> 1), - deepCopy = deepCopy, interpret = interpret) + @Test def dictCopyTests(): Unit = { + def runTests(deepCopy: Boolean, interpret: Boolean): Unit = { + copyTestExecutor( + PCanonicalDict(PCanonicalString(), PInt32()), + PCanonicalDict(PCanonicalString(), PInt32()), + Map("test" -> 1), + deepCopy = deepCopy, + interpret = interpret, + ) - copyTestExecutor(PCanonicalDict(PCanonicalString(true), PInt32(true)), PCanonicalDict(PCanonicalString(), PInt32()), Map("test2" -> 2), - deepCopy = deepCopy, interpret = interpret) + copyTestExecutor( + PCanonicalDict(PCanonicalString(true), PInt32(true)), + PCanonicalDict(PCanonicalString(), PInt32()), + Map("test2" -> 2), + deepCopy = deepCopy, + interpret = interpret, + ) - copyTestExecutor(PCanonicalDict(PCanonicalString(), PInt32()), PCanonicalDict(PCanonicalString(true), PInt32()), Map("test3" -> 3), - deepCopy = deepCopy, interpret = interpret) + copyTestExecutor( + PCanonicalDict(PCanonicalString(), PInt32()), + PCanonicalDict(PCanonicalString(true), PInt32()), + Map("test3" -> 3), + deepCopy = 
deepCopy, + interpret = interpret, + ) } runTests(true, false) @@ -216,10 +337,15 @@ class PContainerTest extends PhysicalTestUtils { runTests(false, interpret = true) } - @Test def setCopyTests() { - def runTests(deepCopy: Boolean, interpret: Boolean) { - copyTestExecutor(PCanonicalSet(PCanonicalString(true)), PCanonicalSet(PCanonicalString()), Set("1", "2"), - deepCopy = deepCopy, interpret = interpret) + @Test def setCopyTests(): Unit = { + def runTests(deepCopy: Boolean, interpret: Boolean): Unit = { + copyTestExecutor( + PCanonicalSet(PCanonicalString(true)), + PCanonicalSet(PCanonicalString()), + Set("1", "2"), + deepCopy = deepCopy, + interpret = interpret, + ) } runTests(true, false) diff --git a/hail/src/test/scala/is/hail/types/physical/PIntervalSuite.scala b/hail/src/test/scala/is/hail/types/physical/PIntervalSuite.scala index cf067d36f97..2ca046e14f8 100644 --- a/hail/src/test/scala/is/hail/types/physical/PIntervalSuite.scala +++ b/hail/src/test/scala/is/hail/types/physical/PIntervalSuite.scala @@ -6,29 +6,51 @@ import is.hail.expr.ir.EmitFunctionBuilder import is.hail.types.physical.stypes.concrete.{SUnreachableInterval, SUnreachableIntervalValue} import is.hail.types.virtual.{TInt32, TInterval} import is.hail.utils._ + import org.testng.annotations.Test class PIntervalSuite extends PhysicalTestUtils { - @Test def copyTests() { - def runTests(deepCopy: Boolean, interpret: Boolean = false) { - copyTestExecutor(PCanonicalInterval(PInt64()), PCanonicalInterval(PInt64()), + @Test def copyTests(): Unit = { + def runTests(deepCopy: Boolean, interpret: Boolean = false): Unit = { + copyTestExecutor( + PCanonicalInterval(PInt64()), + PCanonicalInterval(PInt64()), Interval(IntervalEndpoint(1000L, 1), IntervalEndpoint(1000L, 1)), - deepCopy = deepCopy, interpret = interpret) + deepCopy = deepCopy, + interpret = interpret, + ) - copyTestExecutor(PCanonicalInterval(PInt64(true)), PCanonicalInterval(PInt64()), + copyTestExecutor( + PCanonicalInterval(PInt64(true)), + PCanonicalInterval(PInt64()), Interval(IntervalEndpoint(1000L, 1), IntervalEndpoint(1000L, 1)), - deepCopy = deepCopy, interpret = interpret) + deepCopy = deepCopy, + interpret = interpret, + ) - copyTestExecutor(PCanonicalInterval(PInt64(true)), PCanonicalInterval(PInt64(true)), + copyTestExecutor( + PCanonicalInterval(PInt64(true)), + PCanonicalInterval(PInt64(true)), Interval(IntervalEndpoint(1000L, 1), IntervalEndpoint(1000L, 1)), - deepCopy = deepCopy, interpret = interpret) + deepCopy = deepCopy, + interpret = interpret, + ) - copyTestExecutor(PCanonicalInterval(PInt64()), PCanonicalInterval(PInt64(true)), + copyTestExecutor( + PCanonicalInterval(PInt64()), + PCanonicalInterval(PInt64(true)), Interval(IntervalEndpoint(1000L, 1), IntervalEndpoint(1000L, 1)), - deepCopy = deepCopy, interpret = interpret) + deepCopy = deepCopy, + interpret = interpret, + ) - copyTestExecutor(PCanonicalInterval(PInt64(true)), PCanonicalInterval(PInt64(true)), - Interval(IntervalEndpoint(1000L, 1), IntervalEndpoint(1000L, 1)), deepCopy = deepCopy, interpret = interpret) + copyTestExecutor( + PCanonicalInterval(PInt64(true)), + PCanonicalInterval(PInt64(true)), + Interval(IntervalEndpoint(1000L, 1), IntervalEndpoint(1000L, 1)), + deepCopy = deepCopy, + interpret = interpret, + ) } runTests(true) diff --git a/hail/src/test/scala/is/hail/types/physical/PNDArraySuite.scala b/hail/src/test/scala/is/hail/types/physical/PNDArraySuite.scala index 0168de8fc4d..c91333fe690 100644 --- a/hail/src/test/scala/is/hail/types/physical/PNDArraySuite.scala +++ 
b/hail/src/test/scala/is/hail/types/physical/PNDArraySuite.scala @@ -6,14 +6,20 @@ import is.hail.expr.ir.{EmitCodeBuilder, EmitFunctionBuilder} import is.hail.methods.LocalWhitening import is.hail.types.physical.stypes.interfaces.{ColonIndex => Colon, _} import is.hail.utils._ + import org.apache.spark.sql.Row import org.testng.annotations.Test class PNDArraySuite extends PhysicalTestUtils { - @Test def copyTests() { - def runTests(deepCopy: Boolean, interpret: Boolean = false) { - copyTestExecutor(PCanonicalNDArray(PInt64(true), 1), PCanonicalNDArray(PInt64(true), 1), new SafeNDArray(IndexedSeq(3L), IndexedSeq(4L,5L,6L)), - deepCopy = deepCopy, interpret = interpret) + @Test def copyTests(): Unit = { + def runTests(deepCopy: Boolean, interpret: Boolean = false): Unit = { + copyTestExecutor( + PCanonicalNDArray(PInt64(true), 1), + PCanonicalNDArray(PInt64(true), 1), + new SafeNDArray(IndexedSeq(3L), IndexedSeq(4L, 5L, 6L)), + deepCopy = deepCopy, + interpret = interpret, + ) } runTests(true) @@ -48,14 +54,24 @@ class PNDArraySuite extends PhysicalTestUtils { val work3 = vecType.constructUninitialized(FastSeq(btwpn), cb, region) val T = vecType.constructUninitialized(FastSeq(btwpn), cb, region) - A.coiterateMutate(cb, region) { case Seq(a) => + A.coiterateMutate(cb, region) { case Seq(_) => primitive(cb.memoize(cb.emb.newRNG(0L).invoke[Double]("rnorm"))) } - Acopy.coiterateMutate(cb, region, (A, "A")) { case Seq(acopy, a) => a } + Acopy.coiterateMutate(cb, region, (A, "A")) { case Seq(_, a) => a } SNDArray.geqrt_full(cb, Acopy, Q, R, T, work3, blocksize) - new LocalWhitening(cb, m, w, n, blocksize, region, false).whitenBlockPreOrthogonalized(cb, Q.slice(cb, Colon, (null, w)), Q.slice(cb, Colon, (w, null)), Qout, R, W, work1, work2, blocksize) + new LocalWhitening(cb, m, w, n, blocksize, region, false).whitenBlockPreOrthogonalized( + cb, + Q.slice(cb, Colon, (null, w)), + Q.slice(cb, Colon, (w, null)), + Qout, + R, + W, + work1, + work2, + blocksize, + ) SNDArray.trmm(cb, "R", "U", "N", "N", 1.0, R.slice(cb, (n, null), (n, null)), Qout) @@ -67,8 +83,8 @@ class PNDArraySuite extends PhysicalTestUtils { case Seq(l, r) => def lCode = l.asDouble.value cb.assign(diff, lCode - r.asDouble.value) - cb.assign(normDiff, normDiff + diff*diff) - cb.assign(normA, normA + lCode*lCode) + cb.assign(normDiff, normDiff + diff * diff) + cb.assign(normA, normA + lCode * lCode) } Code.invokeStatic1[java.lang.Math, Double, Double]("sqrt", normDiff / normA) } @@ -94,24 +110,51 @@ class PNDArraySuite extends PhysicalTestUtils { this.pool.scopedRegion { region => fb.emitWithBuilder { cb => val region = fb.getCodeParam[Region](1) - val A = matType.constructUninitialized(FastSeq(m, n), FastSeq(8, 8*m.v), cb, region) - val Acopy = matType.constructUninitialized(FastSeq(m, n), FastSeq(8, 8*m.v), cb, region) - val Q = matType.constructUninitialized(FastSeq(m, n), FastSeq(8, 8*m.v), cb, region) - val R = matType.constructUninitialized(FastSeq(n, n), FastSeq(8, 8*n.v), cb, region) + val A = matType.constructUninitialized(FastSeq(m, n), FastSeq(8, 8 * m.v), cb, region) + val Acopy = matType.constructUninitialized(FastSeq(m, n), FastSeq(8, 8 * m.v), cb, region) + val Q = matType.constructUninitialized(FastSeq(m, n), FastSeq(8, 8 * m.v), cb, region) + val R = matType.constructUninitialized(FastSeq(n, n), FastSeq(8, 8 * n.v), cb, region) val work = vecType.constructUninitialized(FastSeq(btm), FastSeq(8), cb, region) val T = vecType.constructUninitialized(FastSeq(btn), FastSeq(8), cb, region) - A.coiterateMutate(cb, region) { 
case Seq(a) => + A.coiterateMutate(cb, region) { case Seq(_) => primitive(cb.memoize(cb.emb.newRNG(0L).invoke[Double]("rnorm"))) } - Acopy.coiterateMutate(cb, region, (A, "A")) { case Seq(acopy, a) => a } + Acopy.coiterateMutate(cb, region, (A, "A")) { case Seq(_, a) => a } SNDArray.geqrt_full(cb, Acopy, Q, R, T, work, blocksize) new LocalWhitening(cb, m, w, n, blocksize, region, false).qrPivot(cb, Q, R, 0, p) - SNDArray.trmm(cb, "R", "U", "N", "N", 1.0, R.slice(cb, (null, p), (null, p)), Q.slice(cb, Colon, (null, p))) - SNDArray.gemm(cb, "N", "N", 1.0, Q.slice(cb, Colon, (p, null)), R.slice(cb, (p, null), (null, p)), 1.0, Q.slice(cb, Colon, (null, p))) - SNDArray.trmm(cb, "R", "U", "N", "N", 1.0, R.slice(cb, (p, null), (p, null)), Q.slice(cb, Colon, (p, null))) + SNDArray.trmm( + cb, + "R", + "U", + "N", + "N", + 1.0, + R.slice(cb, (null, p), (null, p)), + Q.slice(cb, Colon, (null, p)), + ) + SNDArray.gemm( + cb, + "N", + "N", + 1.0, + Q.slice(cb, Colon, (p, null)), + R.slice(cb, (p, null), (null, p)), + 1.0, + Q.slice(cb, Colon, (null, p)), + ) + SNDArray.trmm( + cb, + "R", + "U", + "N", + "N", + 1.0, + R.slice(cb, (p, null), (p, null)), + Q.slice(cb, Colon, (p, null)), + ) val normDiff = cb.newLocal[Double]("normDiff", 0.0) val normA = cb.newLocal[Double]("normA", 0.0) @@ -121,8 +164,8 @@ class PNDArraySuite extends PhysicalTestUtils { case Seq(l, r) => def lCode = l.asDouble.value cb.assign(diff, lCode - r.asDouble.value) - cb.assign(normDiff, normDiff + diff*diff) - cb.assign(normA, normA + lCode*lCode) + cb.assign(normDiff, normDiff + diff * diff) + cb.assign(normA, normA + lCode * lCode) } Code.invokeStatic1[java.lang.Math, Double, Double]("sqrt", normDiff / normA) } @@ -133,7 +176,13 @@ class PNDArraySuite extends PhysicalTestUtils { } } - def whitenNaive(cb: EmitCodeBuilder, X: SNDArrayValue, w: Int, blocksize: Int, region: Value[Region]): SNDArrayValue = { + def whitenNaive( + cb: EmitCodeBuilder, + X: SNDArrayValue, + w: Int, + blocksize: Int, + region: Value[Region], + ): SNDArrayValue = { val Seq(m, n) = X.shapes val vecType = PCanonicalNDArray(PFloat64Required, 1) val matType = PCanonicalNDArray(PFloat64Required, 2) @@ -146,18 +195,24 @@ class PNDArraySuite extends PhysicalTestUtils { val T = vecType.constructUninitialized(FastSeq(btn), cb, region) val j = cb.newLocal[Long]("j") - cb.for_(cb.assign(j, 0L), j < n, cb.assign(j, j+1), { - val windowStart = cb.memoize((j-w).max(0)) - val windowSize = cb.memoize(j - windowStart) - val window = curWindow.slice(cb, Colon, (null, windowSize)) - window.coiterateMutate(cb, region, (X.slice(cb, Colon, (windowStart, j)), "X")) { case Seq(_, v) => v } - val bs = cb.memoize(windowSize.min(blocksize).max(1)) - SNDArray.geqrt(window, T, work, bs, cb) - val curCol = Xw.slice(cb, Colon, j) - SNDArray.gemqrt("L", "T", window, T, curCol, work, bs, cb) - curCol.slice(cb, (null, windowSize)).setToZero(cb) - SNDArray.gemqrt("L", "N", window, T, curCol, work, bs, cb) - }) + cb.for_( + cb.assign(j, 0L), + j < n, + cb.assign(j, j + 1), { + val windowStart = cb.memoize((j - w).max(0)) + val windowSize = cb.memoize(j - windowStart) + val window = curWindow.slice(cb, Colon, (null, windowSize)) + window.coiterateMutate(cb, region, (X.slice(cb, Colon, (windowStart, j)), "X")) { + case Seq(_, v) => v + } + val bs = cb.memoize(windowSize.min(blocksize).max(1)) + SNDArray.geqrt(window, T, work, bs, cb) + val curCol = Xw.slice(cb, Colon, j) + SNDArray.gemqrt("L", "T", window, T, curCol, work, bs, cb) + curCol.slice(cb, (null, windowSize)).setToZero(cb) + 
SNDArray.gemqrt("L", "N", window, T, curCol, work, bs, cb) + }, + ) Xw } @@ -181,12 +236,27 @@ class PNDArraySuite extends PhysicalTestUtils { Aorig.coiterateMutate(cb, region) { case Seq(_) => primitive(cb.memoize(cb.emb.newRNG(0L).invoke[Double]("rnorm"))) } - A.coiterateMutate(cb, region, (Aorig.slice(cb, Colon, (w, null)), "Aorig")) { case Seq(_, a) => a } - state.Qtemp2.coiterateMutate(cb, region, (Aorig.slice(cb, Colon, (null, w)), "Aorig")) { case Seq(_, a) => a } + A.coiterateMutate(cb, region, (Aorig.slice(cb, Colon, (w, null)), "Aorig")) { + case Seq(_, a) => a + } + state.Qtemp2.coiterateMutate(cb, region, (Aorig.slice(cb, Colon, (null, w)), "Aorig")) { + case Seq(_, a) => a + } SNDArray.geqrt_full(cb, state.Qtemp2, state.Q, state.R, state.T, state.work3, blocksize) - state.whitenBlockSmallWindow(cb, state.Q, state.R, A, state.Qtemp, state.Qtemp2, state.Rtemp, state.work1, state.work2, blocksize) + state.whitenBlockSmallWindow( + cb, + state.Q, + state.R, + A, + state.Qtemp, + state.Qtemp2, + state.Rtemp, + state.work1, + state.work2, + blocksize, + ) // Q = Q*R SNDArray.trmm(cb, "R", "U", "N", "N", 1.0, state.R, state.Q) @@ -199,14 +269,13 @@ class PNDArraySuite extends PhysicalTestUtils { case Seq(l, r) => def lCode = l.asDouble.value cb.assign(diff, lCode - r.asDouble.value) - cb.assign(normDiff, normDiff + diff*diff) - cb.assign(normA, normA + lCode*lCode) + cb.assign(normDiff, normDiff + diff * diff) + cb.assign(normA, normA + lCode * lCode) } - var relError: Value[Double] = cb.memoize(Code.invokeStatic1[java.lang.Math, Double, Double]("sqrt", normDiff / normA)) - cb.if_(relError > 1e-14, { - cb._fatal("backwards error too large: ", relError.toS) - }) + var relError: Value[Double] = + cb.memoize(Code.invokeStatic1[java.lang.Math, Double, Double]("sqrt", normDiff / normA)) + cb.if_(relError > 1e-14, cb._fatal("backwards error too large: ", relError.toS)) val W2 = whitenNaive(cb, Aorig, w.v.toInt, blocksize.v.toInt, region) .slice(cb, Colon, (w, null)) @@ -218,15 +287,14 @@ class PNDArraySuite extends PhysicalTestUtils { case Seq(l, r) => def lCode = l.asDouble.value cb.assign(diff, lCode - r.asDouble.value) - cb.assign(normDiff, normDiff + diff*diff) - cb.assign(normW2, normW2 + lCode*lCode) + cb.assign(normDiff, normDiff + diff * diff) + cb.assign(normW2, normW2 + lCode * lCode) } - relError = cb.memoize(Code.invokeStatic1[java.lang.Math, Double, Double]("sqrt", normDiff / normW2)) + relError = + cb.memoize(Code.invokeStatic1[java.lang.Math, Double, Double]("sqrt", normDiff / normW2)) cb.println(relError.toS) - cb.if_(!(relError < 1e-14), { - cb._fatal("relative error vs naive too large: ", relError.toS) - }) + cb.if_(!(relError < 1e-14), cb._fatal("relative error vs naive too large: ", relError.toS)) Code._empty } @@ -259,10 +327,12 @@ class PNDArraySuite extends PhysicalTestUtils { val state = new LocalWhitening(cb, m, w, b, blocksize, region, false) val i = cb.newLocal[Long]("i", 0) - cb.while_(i < n, { - state.whitenBlock(cb, A.slice(cb, Colon, (i, (i+b).min(n)))) - cb.assign(i, i+b) - }) + cb.while_( + i < n, { + state.whitenBlock(cb, A.slice(cb, Colon, (i, (i + b).min(n)))) + cb.assign(i, i + b) + }, + ) val W2 = whitenNaive(cb, Aorig, w.v.toInt, blocksize.v.toInt, region) @@ -275,14 +345,13 @@ class PNDArraySuite extends PhysicalTestUtils { case Seq(l, r) => def lCode = l.asDouble.value cb.assign(diff, lCode - r.asDouble.value) - cb.assign(normDiff, normDiff + diff*diff) - cb.assign(normW2, normW2 + lCode*lCode) + cb.assign(normDiff, normDiff + diff * diff) + 
cb.assign(normW2, normW2 + lCode * lCode) } - val relError = cb.memoize(Code.invokeStatic1[java.lang.Math, Double, Double]("sqrt", normDiff / normW2)) - cb.if_(!(relError < 1e-14), { - cb._fatal("relative error vs naive too large: ", relError.toS) - }) + val relError = + cb.memoize(Code.invokeStatic1[java.lang.Math, Double, Double]("sqrt", normDiff / normW2)) + cb.if_(!(relError < 1e-14), cb._fatal("relative error vs naive too large: ", relError.toS)) Code._empty } @@ -296,16 +365,15 @@ class PNDArraySuite extends PhysicalTestUtils { @Test def testRefCounted(): Unit = { val nd = PCanonicalNDArray(PInt32Required, 1) - val region1 = Region(pool=this.pool) - val region2 = Region(pool=this.pool) - val region3 = Region(pool=this.pool) + val region1 = Region(pool = this.pool) + val region2 = Region(pool = this.pool) + val region3 = Region(pool = this.pool) val fb = EmitFunctionBuilder[Region, Region, Region, Long](ctx, "ref_count_test") val codeRegion1 = fb.getCodeParam[Region](1) val codeRegion2 = fb.getCodeParam[Region](2) - val codeRegion3 = fb.getCodeParam[Region](3) try { - fb.emitWithBuilder{ cb => + fb.emitWithBuilder { cb => val r2PointerToNDAddress1 = cb.newLocal[Long]("r2_ptr_to_nd_addr1") val shapeSeq = IndexedSeq(const(3L)) @@ -317,7 +385,7 @@ class PNDArraySuite extends PhysicalTestUtils { // Region 2 gets an ndarray at ndaddress2, plus a reference to the one at ndarray 1. val (_, snd2Finisher) = nd.constructDataFunction(shapeSeq, shapeSeq, cb, codeRegion2) - val snd2 = snd2Finisher(cb) + snd2Finisher(cb) cb.assign(r2PointerToNDAddress1, nd.store(cb, codeRegion2, snd1, true)) // Return the 1st ndarray @@ -353,7 +421,6 @@ class PNDArraySuite extends PhysicalTestUtils { val rc1B = Region.loadLong(result1Data - Region.sharedChunkHeaderBytes) assert(rc1B == 1) - assert(region3.memory.listNDArrayRefs().size == 0) // Do an unstaged copy into region3 nd.copyFromAddress(ctx.stateManager, region3, nd, result1, true) @@ -365,11 +432,11 @@ class PNDArraySuite extends PhysicalTestUtils { } @Test def testUnstagedCopy(): Unit = { - val region1 = Region(pool=this.pool) - val region2 = Region(pool=this.pool) + val region1 = Region(pool = this.pool) + val region2 = Region(pool = this.pool) val x = SafeNDArray(IndexedSeq(3L, 2L), (0 until 6).map(_.toDouble)) val pNd = PCanonicalNDArray(PFloat64Required, 2, true) - val ndAddr1 = pNd.unstagedStoreJavaObject(ctx.stateManager, x, region=region1) + val ndAddr1 = pNd.unstagedStoreJavaObject(ctx.stateManager, x, region = region1) val ndAddr2 = pNd.copyFromAddress(ctx.stateManager, region2, pNd, ndAddr1, true) val unsafe1 = UnsafeRow.read(pNd, region1, ndAddr1) val unsafe2 = UnsafeRow.read(pNd, region2, ndAddr2) @@ -396,11 +463,18 @@ class PNDArraySuite extends PhysicalTestUtils { // assert(PNDArray.getReferenceCount(addr4) == 1L) // Deep copy with PTypes with different requirements - val pNDOfStructs1 = PCanonicalNDArray(PCanonicalStruct(true, ("x", PInt32Required), ("y", PInt32())), 1) - val pNDOfStructs2 = PCanonicalNDArray(PCanonicalStruct(true, ("x", PInt32()), ("y", PInt32Required)), 1) - val annotationNDOfStructs = new SafeNDArray(IndexedSeq(5L), (0 until 5).map(idx => Row(idx, idx + 100))) - - val addr5 = pNDOfStructs1.unstagedStoreJavaObject(ctx.stateManager, annotationNDOfStructs, region=region1) + val pNDOfStructs1 = + PCanonicalNDArray(PCanonicalStruct(true, ("x", PInt32Required), ("y", PInt32())), 1) + val pNDOfStructs2 = + PCanonicalNDArray(PCanonicalStruct(true, ("x", PInt32()), ("y", PInt32Required)), 1) + val annotationNDOfStructs = + new 
SafeNDArray(IndexedSeq(5L), (0 until 5).map(idx => Row(idx, idx + 100))) + + val addr5 = pNDOfStructs1.unstagedStoreJavaObject( + ctx.stateManager, + annotationNDOfStructs, + region = region1, + ) val unsafe5 = UnsafeRow.read(pNDOfStructs1, region1, addr5) val addr6 = pNDOfStructs2.copyFromAddress(ctx.stateManager, region2, pNDOfStructs1, addr5, true) val unsafe6 = UnsafeRow.read(pNDOfStructs2, region2, addr6) diff --git a/hail/src/test/scala/is/hail/expr/ir/PTypeSuite.scala b/hail/src/test/scala/is/hail/types/physical/PTypeSuite.scala similarity index 76% rename from hail/src/test/scala/is/hail/expr/ir/PTypeSuite.scala rename to hail/src/test/scala/is/hail/types/physical/PTypeSuite.scala index ccddbbbfba3..d33b5798d3d 100644 --- a/hail/src/test/scala/is/hail/expr/ir/PTypeSuite.scala +++ b/hail/src/test/scala/is/hail/types/physical/PTypeSuite.scala @@ -1,18 +1,18 @@ -package is.hail.expr.ir +package is.hail.types.physical import is.hail.HailSuite import is.hail.rvd.AbstractRVDSpec -import is.hail.types.physical._ import is.hail.types.virtual._ import is.hail.utils._ import is.hail.variant.ReferenceGenome + import org.apache.spark.sql.Row import org.json4s.jackson.Serialization import org.testng.annotations.{DataProvider, Test} class PTypeSuite extends HailSuite { - @DataProvider(name="ptypes") + @DataProvider(name = "ptypes") def ptypes(): Array[Array[Any]] = { Array[PType]( PInt32(true), @@ -30,16 +30,21 @@ class PTypeSuite extends HailSuite { PCanonicalSet(PInt32Required, false), PCanonicalDict(PInt32Required, PCanonicalString(true), true), PCanonicalInterval(PInt32Optional, false), - PCanonicalTuple(FastSeq(PTupleField(1, PInt32Required), PTupleField(3, PCanonicalString(false))), true), - PCanonicalStruct(FastSeq(PField("foo", PInt32Required, 0), PField("bar", PCanonicalString(false), 1)), true) + PCanonicalTuple( + FastSeq(PTupleField(1, PInt32Required), PTupleField(3, PCanonicalString(false))), + true, + ), + PCanonicalStruct( + FastSeq(PField("foo", PInt32Required, 0), PField("bar", PCanonicalString(false), 1)), + true, + ), ).map(t => Array(t: Any)) } - @Test def testPTypesDataProvider(): Unit = { + @Test def testPTypesDataProvider(): Unit = ptypes() - } - @Test(dataProvider="ptypes") + @Test(dataProvider = "ptypes") def testSerialization(ptype: PType): Unit = { implicit val formats = AbstractRVDSpec.formats val s = Serialization.write(ptype) @@ -54,7 +59,10 @@ class PTypeSuite extends HailSuite { assert(PType.literalPType(TArray(TInt32), FastSeq(1, null)) == PCanonicalArray(PInt32(), true)) assert(PType.literalPType(TArray(TInt32), FastSeq(1, 5)) == PCanonicalArray(PInt32(true), true)) - assert(PType.literalPType(TInterval(TInt32), Interval(5, null, false, true)) == PCanonicalInterval(PInt32(), true)) + assert(PType.literalPType( + TInterval(TInt32), + Interval(5, null, false, true), + ) == PCanonicalInterval(PInt32(), true)) val p = TStruct("a" -> TInt32, "b" -> TInt32) val d = TDict(p, p) @@ -62,6 +70,7 @@ class PTypeSuite extends HailSuite { PCanonicalDict( PCanonicalStruct(true, "a" -> PInt32(true), "b" -> PInt32()), PCanonicalStruct(true, "a" -> PInt32(), "b" -> PInt32(true)), - true)) + true, + )) } } diff --git a/hail/src/test/scala/is/hail/types/physical/PhysicalTestUtils.scala b/hail/src/test/scala/is/hail/types/physical/PhysicalTestUtils.scala index 2f9f60371e7..6fcf7a8f986 100644 --- a/hail/src/test/scala/is/hail/types/physical/PhysicalTestUtils.scala +++ b/hail/src/test/scala/is/hail/types/physical/PhysicalTestUtils.scala @@ -1,17 +1,24 @@ package is.hail.types.physical 
import is.hail.HailSuite -import is.hail.utils.{HailException, log} import is.hail.annotations.{Region, ScalaToRegionValue, UnsafeRow} import is.hail.expr.ir.EmitFunctionBuilder +import is.hail.utils.{log, HailException} abstract class PhysicalTestUtils extends HailSuite { - def copyTestExecutor(sourceType: PType, destType: PType, sourceValue: Any, - expectCompileError: Boolean = false, expectRuntimeError: Boolean = false, - deepCopy: Boolean = false, interpret: Boolean = false, expectedValue: Any = null) { - - val srcRegion = Region(pool=pool) - val region = Region(pool=pool) + def copyTestExecutor( + sourceType: PType, + destType: PType, + sourceValue: Any, + expectCompileError: Boolean = false, + expectRuntimeError: Boolean = false, + deepCopy: Boolean = false, + interpret: Boolean = false, + expectedValue: Any = null, + ): Unit = { + + val srcRegion = Region(pool = pool) + val region = Region(pool = pool) val srcAddress = sourceType match { case x: PSubsetStruct => ScalaToRegionValue(ctx.stateManager, srcRegion, x.ps, sourceValue) @@ -20,12 +27,18 @@ abstract class PhysicalTestUtils extends HailSuite { if (interpret) { try { - val copyOff = destType.copyFromAddress(ctx.stateManager, region, sourceType, srcAddress, deepCopy = deepCopy) + val copyOff = destType.copyFromAddress( + ctx.stateManager, + region, + sourceType, + srcAddress, + deepCopy = deepCopy, + ) val copy = UnsafeRow.read(destType, region, copyOff) - log.info(s"Copied value: ${copy}, Source value: ${sourceValue}") + log.info(s"Copied value: $copy, Source value: $sourceValue") - if(expectedValue != null) { + if (expectedValue != null) { assert(copy == expectedValue) } else { assert(copy == sourceValue) @@ -54,7 +67,9 @@ abstract class PhysicalTestUtils extends HailSuite { val value = fb.getCodeParam[Long](2) try { - fb.emitWithBuilder(cb => destType.store(cb, codeRegion, sourceType.loadCheapSCode(cb, value), deepCopy = deepCopy)) + fb.emitWithBuilder(cb => + destType.store(cb, codeRegion, sourceType.loadCheapSCode(cb, value), deepCopy = deepCopy) + ) compileSuccess = true } catch { case e: Throwable => @@ -75,23 +90,24 @@ abstract class PhysicalTestUtils extends HailSuite { throw new Error("Did not receive expected compile time error") } - val copy = try { - val f = fb.result()(theHailClassLoader) - val copyOff = f(region, srcAddress) - UnsafeRow.read(destType, region, copyOff) - } catch { - case e: HailException => - if (expectRuntimeError) { - log.info("OK: Caught expected compile-time error") - return - } + val copy = + try { + val f = fb.result()(theHailClassLoader) + val copyOff = f(region, srcAddress) + UnsafeRow.read(destType, region, copyOff) + } catch { + case e: HailException => + if (expectRuntimeError) { + log.info("OK: Caught expected compile-time error") + return + } - throw e - } + throw e + } - log.info(s"Copied value: ${ copy }, Source value: ${ sourceValue }") + log.info(s"Copied value: $copy, Source value: $sourceValue") - if(expectedValue != null) { + if (expectedValue != null) { assert(copy == expectedValue) } else { assert(copy == sourceValue) diff --git a/hail/src/test/scala/is/hail/types/physical/stypes/concrete/SStructViewSuite.scala b/hail/src/test/scala/is/hail/types/physical/stypes/concrete/SStructViewSuite.scala new file mode 100644 index 00000000000..6566940f69f --- /dev/null +++ b/hail/src/test/scala/is/hail/types/physical/stypes/concrete/SStructViewSuite.scala @@ -0,0 +1,61 @@ +package is.hail.types.physical.stypes.concrete + +import is.hail.HailSuite +import 
is.hail.types.physical.stypes.SType +import is.hail.types.physical.stypes.interfaces.SBaseStruct +import is.hail.types.tcoerce +import is.hail.types.virtual.{TInt32, TInt64, TStruct} +import is.hail.utils.FastSeq +import org.testng.annotations.Test + +class SStructViewSuite extends HailSuite { + + val xyz: SBaseStruct = + tcoerce[SBaseStruct]( + SType.canonical( + TStruct( + "x" -> TInt32, + "y" -> TInt64, + "z" -> TStruct("a" -> TInt32), + ) + ) + ) + + @Test def testCastRename(): Unit = { + val newtype = TStruct("x" -> TStruct("b" -> TInt32)) + + val expected = + new SStructView( + parent = xyz, + restrict = FastSeq("z"), + rename = newtype, + ) + + assert(SStructView.subset(FastSeq("z"), xyz).castRename(newtype) == expected) + } + + @Test def testSubsetRenameSubset(): Unit = { + val subset = + SStructView.subset( + FastSeq("x"), + SStructView.subset(FastSeq("x", "z"), xyz) + .castRename(TStruct("y" -> TInt32, "x" -> TStruct("b" -> TInt32))) + .asInstanceOf[SBaseStruct], + ) + + val expected = + new SStructView( + parent = xyz, + restrict = FastSeq("z"), + rename = TStruct("x" -> TStruct("b" -> TInt32)), + ) + + assert(subset == expected) + } + + @Test def testAssertIsomorphism(): Unit = + intercept[AssertionError] { + SStructView.subset(FastSeq("x", "y"), xyz) + .castRename(TStruct("x" -> TInt64, "x" -> TInt32)) + } +} diff --git a/hail/src/test/scala/is/hail/types/virtual/TStructSuite.scala b/hail/src/test/scala/is/hail/types/virtual/TStructSuite.scala index 776fb25a3f6..1ea0329e39f 100644 --- a/hail/src/test/scala/is/hail/types/virtual/TStructSuite.scala +++ b/hail/src/test/scala/is/hail/types/virtual/TStructSuite.scala @@ -3,8 +3,8 @@ package is.hail.types.virtual import is.hail.HailSuite import is.hail.annotations.{Annotation, Inserter} import is.hail.utils.FastSeq + import org.apache.spark.sql.Row -import org.testng.Assert.{assertFalse, assertTrue} import org.testng.annotations.{DataProvider, Test} class TStructSuite extends HailSuite { @@ -21,7 +21,7 @@ class TStructSuite extends HailSuite { // Consider joins for example - we only care that the key fields have the same types // so we compare the key types (which are structs) for equality ignoring field names. // isPrefixOf is used in similar cases involving key types where we don't care about names. 
- Array(TStruct("b" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), true) + Array(TStruct("b" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), true), ) @Test(dataProvider = "isPrefixOf") @@ -37,27 +37,32 @@ class TStructSuite extends HailSuite { Array(TStruct("a" -> TVoid), TStruct("a" -> TVoid), true), Array(TStruct("a" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), true), Array(TStruct("a" -> TVoid, "b" -> TVoid), TStruct("a" -> TVoid), false), - Array(TStruct("b" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), true) + Array(TStruct("b" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), true), ) @Test(dataProvider = "isSubsetOf") def testIsSubsetOf(a: TStruct, b: TStruct, isSubset: Boolean): Unit = assert(a.isSubsetOf(b) == isSubset, s"expected $a `isSubsetOf` $b == $isSubset") - @DataProvider(name = "structInsert") def structInsertData: Array[Array[Any]] = Array( Array(TStruct("a" -> TInt32), FastSeq("a"), TInt32, TStruct("a" -> TInt32)), Array(TStruct("a" -> TInt32), FastSeq("b"), TInt32, TStruct("a" -> TInt32, "b" -> TInt32)), Array(TStruct("a" -> TInt32), FastSeq("a"), TVoid, TStruct("a" -> TVoid)), - Array(TStruct("a" -> TInt32), FastSeq("a", "b"), TInt32, TStruct("a" -> TStruct("b" -> TInt32))), + Array( + TStruct("a" -> TInt32), + FastSeq("a", "b"), + TInt32, + TStruct("a" -> TStruct("b" -> TInt32)), + ), Array(TStruct.empty, FastSeq("a"), TInt32, TStruct("a" -> TInt32)), Array(TStruct.empty, FastSeq("a", "b"), TInt32, TStruct("a" -> TStruct("b" -> TInt32))), ) @Test(dataProvider = "structInsert") - def testStructInsert(base: TStruct, path: IndexedSeq[String], signature: Type, expected: TStruct): Unit = + def testStructInsert(base: TStruct, path: IndexedSeq[String], signature: Type, expected: TStruct) + : Unit = assert(base.structInsert(signature, path) == expected) @Test def testInsertEmptyPath(): Unit = @@ -65,16 +70,13 @@ class TStructSuite extends HailSuite { TStruct.empty.insert(TInt32, FastSeq()) } - @DataProvider(name = "inserter") def inserterData: Array[Array[Any]] = Array( Array(TStruct("a" -> TInt32).insert(TInt32, FastSeq("a"))._2, null, 0, Row(0)), Array(TStruct("a" -> TInt32).insert(TInt32, FastSeq("a"))._2, Row(0), 1, Row(1)), - Array(TStruct("a" -> TInt32).insert(TInt32, FastSeq("b"))._2, null, 0, Row(null, 0)), Array(TStruct("a" -> TInt32).insert(TInt32, FastSeq("b"))._2, Row(0), 1, Row(0, 1)), - Array(TStruct.empty.insert(TInt32, FastSeq("a", "b"))._2, null, 0, Row(Row(0))), Array(TStruct("a" -> TInt32).insert(TInt32, FastSeq("a", "b"))._2, Row(0), 1, Row(Row(1))), ) @@ -83,4 +85,36 @@ class TStructSuite extends HailSuite { def testInsert(inserter: Inserter, base: Annotation, value: Any, expected: Annotation): Unit = assert(inserter(base, value) == expected) + @DataProvider(name = "isIsomorphicTo") + def isIsomorphicToData: Array[Array[Any]] = + Array( + Array(TStruct.empty, TStruct.empty, true), + Array(TStruct.empty, TStruct("a" -> TVoid), false), + Array(TStruct("a" -> TVoid), TStruct.empty, false), + Array(TStruct("a" -> TVoid), TStruct("b" -> TVoid), true), + Array(TStruct("a" -> TStruct("b" -> TVoid)), TStruct("b" -> TStruct("a" -> TVoid)), true), + Array(TStruct("a" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), false), + Array(TStruct("a" -> TVoid, "b" -> TVoid), TStruct("a" -> TVoid), false), + ) + + @Test(dataProvider = "isIsomorphicTo") + def testIsIsomorphicTo(a: TStruct, b: TStruct, isIsomorphic: Boolean): Unit = + assert((a isIsomorphicTo b) == isIsomorphic, s"expected $a isIsomorphicTo $b == $isIsomorphic") + + @DataProvider(name = "isJoinableWith") 
+ def isJoinableWithData: Array[Array[Any]] = + Array( + Array(TStruct.empty, TStruct.empty, true), + Array(TStruct.empty, TStruct("a" -> TVoid), false), + Array(TStruct("a" -> TVoid), TStruct.empty, false), + Array(TStruct("a" -> TVoid), TStruct("b" -> TVoid), true), + Array(TStruct("a" -> TStruct("a" -> TVoid)), TStruct("b" -> TStruct("a" -> TVoid)), true), + Array(TStruct("a" -> TStruct("a" -> TVoid)), TStruct("b" -> TStruct("b" -> TVoid)), false), + Array(TStruct("a" -> TVoid, "b" -> TVoid), TStruct("a" -> TVoid), false), + Array(TStruct("b" -> TVoid), TStruct("a" -> TVoid, "b" -> TVoid), false), + ) + + @Test(dataProvider = "isJoinableWith") + def testIsJoinableWith(a: TStruct, b: TStruct, isJoinable: Boolean): Unit = + assert((a isJoinableWith b) == isJoinable, s"expected $a isJoinableWith $b == $isJoinable") } diff --git a/hail/src/test/scala/is/hail/utils/ArrayBuilderSuite.scala b/hail/src/test/scala/is/hail/utils/ArrayBuilderSuite.scala index f33c4ce89e9..5b4470cdbac 100644 --- a/hail/src/test/scala/is/hail/utils/ArrayBuilderSuite.scala +++ b/hail/src/test/scala/is/hail/utils/ArrayBuilderSuite.scala @@ -1,11 +1,12 @@ package is.hail.utils import is.hail.expr.ir.IntArrayBuilder + import org.scalatest.testng.TestNGSuite import org.testng.annotations.Test class ArrayBuilderSuite extends TestNGSuite { - @Test def addOneElement() { + @Test def addOneElement(): Unit = { val ab = new IntArrayBuilder(0) ab += 3 val a = ab.result() @@ -13,7 +14,7 @@ class ArrayBuilderSuite extends TestNGSuite { assert(a(0) == 3) } - @Test def addArray() { + @Test def addArray(): Unit = { val ab = new IntArrayBuilder(0) ab ++= Array.fill[Int](5)(2) val a = ab.result() diff --git a/hail/src/test/scala/is/hail/utils/ArrayStackSuite.scala b/hail/src/test/scala/is/hail/utils/ArrayStackSuite.scala index a772f13d1e3..55bcc0eeba5 100644 --- a/hail/src/test/scala/is/hail/utils/ArrayStackSuite.scala +++ b/hail/src/test/scala/is/hail/utils/ArrayStackSuite.scala @@ -4,7 +4,7 @@ import org.scalatest.testng.TestNGSuite import org.testng.annotations.Test class ArrayStackSuite extends TestNGSuite { - @Test def test() { + @Test def test(): Unit = { val s = new IntArrayStack(4) assert(s.isEmpty) assert(!s.nonEmpty) diff --git a/hail/src/test/scala/is/hail/utils/BinaryHeapSuite.scala b/hail/src/test/scala/is/hail/utils/BinaryHeapSuite.scala index 3e02cc500c3..de929cb902a 100644 --- a/hail/src/test/scala/is/hail/utils/BinaryHeapSuite.scala +++ b/hail/src/test/scala/is/hail/utils/BinaryHeapSuite.scala @@ -1,16 +1,17 @@ package is.hail.utils +import is.hail.check.Arbitrary._ import is.hail.check.Gen import is.hail.check.Prop._ -import is.hail.check.Arbitrary._ + import scala.collection.mutable -import org.scalatest._ -import Matchers._ + +import org.scalatest.Matchers._ import org.testng.annotations.Test class BinaryHeapSuite { @Test - def insertOneIsMax() { + def insertOneIsMax(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, 10) assert(bh.max() === 1) @@ -26,7 +27,7 @@ class BinaryHeapSuite { } @Test - def twoElements() { + def twoElements(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, 5) assert(bh.contains(1) == true) @@ -47,7 +48,7 @@ class BinaryHeapSuite { } @Test - def threeElements() { + def threeElements(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, -10) assert(bh.contains(1) == true) @@ -79,7 +80,7 @@ class BinaryHeapSuite { } @Test - def decreaseToKey1() { + def decreaseToKey1(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -94,7 +95,7 @@ class BinaryHeapSuite { } @Test - def 
decreaseToKey2() { + def decreaseToKey2(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -109,7 +110,7 @@ class BinaryHeapSuite { } @Test - def decreaseToKeyButNoOrderingChange() { + def decreaseToKeyButNoOrderingChange(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -124,7 +125,7 @@ class BinaryHeapSuite { } @Test - def decreaseKey1() { + def decreaseKey1(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -139,7 +140,7 @@ class BinaryHeapSuite { } @Test - def decreaseKey2() { + def decreaseKey2(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -154,7 +155,7 @@ class BinaryHeapSuite { } @Test - def decreaseKeyButNoOrderingChange() { + def decreaseKeyButNoOrderingChange(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -169,7 +170,7 @@ class BinaryHeapSuite { } @Test - def increaseToKey1() { + def increaseToKey1(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -184,7 +185,7 @@ class BinaryHeapSuite { } @Test - def increaseToKeys() { + def increaseToKeys(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -202,7 +203,7 @@ class BinaryHeapSuite { } @Test - def increaseKey1() { + def increaseKey1(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -217,7 +218,7 @@ class BinaryHeapSuite { } @Test - def increaseKeys() { + def increaseKeys(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -235,7 +236,7 @@ class BinaryHeapSuite { } @Test - def samePriority() { + def samePriority(): Unit = { val bh = new BinaryHeap[Int]() bh.insert(1, 0) @@ -249,7 +250,7 @@ class BinaryHeapSuite { } @Test - def successivelyMoreInserts() { + def successivelyMoreInserts(): Unit = { for (count <- Seq(2, 4, 8, 16, 32)) { val bh = new BinaryHeap[Int](8) val trace = new BoxedArrayBuilder[String]() @@ -267,14 +268,17 @@ class BinaryHeapSuite { trace += bh.toString() bh.checkHeapProperty() val expected = count - i - 1 - assert(actual === expected, s"[$count] $actual did not equal $expected, heap: $bh; trace ${ trace.result().mkString("\n") }") + assert( + actual === expected, + s"[$count] $actual did not equal $expected, heap: $bh; trace ${trace.result().mkString("\n")}", + ) } assert(bh.isEmpty) } } @Test - def growPastCapacity4() { + def growPastCapacity4(): Unit = { val bh = new BinaryHeap[Int](4) bh.insert(1, 0) bh.insert(2, 0) @@ -285,16 +289,15 @@ class BinaryHeapSuite { } @Test - def growPastCapacity32() { + def growPastCapacity32(): Unit = { val bh = new BinaryHeap[Int](32) - for (i <- 0 to 32) { + for (i <- 0 to 32) bh.insert(i, 0) - } assert(true) } @Test - def shrinkCapacity() { + def shrinkCapacity(): Unit = { val bh = new BinaryHeap[Int](8) val trace = new BoxedArrayBuilder[String]() trace += bh.toString() @@ -304,27 +307,30 @@ class BinaryHeapSuite { trace += bh.toString() bh.checkHeapProperty() } - assert(bh.size === 64, s"trace: ${ trace.result().mkString("\n") }") - assert(bh.max() === 63, s"trace: ${ trace.result().mkString("\n") }") + assert(bh.size === 64, s"trace: ${trace.result().mkString("\n")}") + assert(bh.max() === 63, s"trace: ${trace.result().mkString("\n")}") // shrinking happens when size is <1/4 of capacity for (i <- 0 until (32 + 16 + 1)) { val actual = bh.extractMax() val expected = 64 - i - 1 trace += bh.toString() bh.checkHeapProperty() - assert(actual === expected, s"$actual did not equal $expected, trace: ${ trace.result().mkString("\n") }") + assert( + actual === expected, + s"$actual did not equal $expected, trace: ${trace.result().mkString("\n")}", + ) } - assert(bh.size === 15, 
s"trace: ${ trace.result().mkString("\n") }") - assert(bh.max() === 14, s"trace: ${ trace.result().mkString("\n") }") + assert(bh.size === 15, s"trace: ${trace.result().mkString("\n")}") + assert(bh.max() === 14, s"trace: ${trace.result().mkString("\n")}") } - private sealed trait HeapOp + sealed private trait HeapOp - private sealed case class Max() extends HeapOp + sealed private case class Max() extends HeapOp - private sealed case class ExtractMax() extends HeapOp + sealed private case class ExtractMax() extends HeapOp - private sealed case class Insert(t: Long, rank: Long) extends HeapOp + sealed private case class Insert(t: Long, rank: Long) extends HeapOp private class LongPriorityQueueReference { @@ -347,13 +353,12 @@ class BinaryHeapSuite { max } - def insert(t: Long, rank: Long) { + def insert(t: Long, rank: Long): Unit = m += (t -> rank) - } } @Test - def sameAsReferenceImplementation() { + def sameAsReferenceImplementation(): Unit = { import Gen._ val ops = for { @@ -371,16 +376,16 @@ class BinaryHeapSuite { opList.foreach { case Max() => if (bh.isEmpty && ref.isEmpty) - assert(true, s"trace; ${ trace.result().mkString("\n") }") + assert(true, s"trace; ${trace.result().mkString("\n")}") else - assert(bh.max() === ref.max(), s"trace; ${ trace.result().mkString("\n") }") + assert(bh.max() === ref.max(), s"trace; ${trace.result().mkString("\n")}") trace += bh.toString() bh.checkHeapProperty() case ExtractMax() => if (bh.isEmpty && ref.isEmpty) - assert(true, s"trace; ${ trace.result().mkString("\n") }") + assert(true, s"trace; ${trace.result().mkString("\n")}") else - assert(bh.max() === ref.max(), s"trace; ${ trace.result().mkString("\n") }") + assert(bh.max() === ref.max(), s"trace; ${trace.result().mkString("\n")}") trace += bh.toString() bh.checkHeapProperty() case Insert(t, rank) => @@ -388,7 +393,7 @@ class BinaryHeapSuite { ref.insert(t, rank) trace += bh.toString() bh.checkHeapProperty() - assert(bh.size === ref.size, s"trace; ${ trace.result().mkString("\n") }") + assert(bh.size === ref.size, s"trace; ${trace.result().mkString("\n")}") } true }.check() @@ -404,7 +409,7 @@ class BinaryHeapSuite { } @Test - def tieBreakingDoesntChangeExistingFunctionality() { + def tieBreakingDoesntChangeExistingFunctionality(): Unit = { val bh = new BinaryHeap[Int](maybeTieBreaker = evensFirst) bh.insert(1, -10) assert(bh.contains(1) == true) @@ -436,7 +441,7 @@ class BinaryHeapSuite { } @Test - def tieBreakingHappens() { + def tieBreakingHappens(): Unit = { val bh = new BinaryHeap[Int](maybeTieBreaker = evensFirst) bh.insert(1, -10) assert(bh.contains(1) == true) @@ -468,7 +473,7 @@ class BinaryHeapSuite { } @Test - def tieBreakingThreeWayDeterministic() { + def tieBreakingThreeWayDeterministic(): Unit = { val bh = new BinaryHeap[Int](maybeTieBreaker = evensFirst) bh.insert(1, -5) assert(bh.contains(1) == true) @@ -502,7 +507,7 @@ class BinaryHeapSuite { } @Test - def tieBreakingThreeWayNonDeterministic() { + def tieBreakingThreeWayNonDeterministic(): Unit = { val bh = new BinaryHeap[Int](maybeTieBreaker = evensFirst) bh.insert(0, -5) assert(bh.contains(0) == true) @@ -536,7 +541,7 @@ class BinaryHeapSuite { } @Test - def tieBreakingAfterPriorityChange() { + def tieBreakingAfterPriorityChange(): Unit = { val bh = new BinaryHeap[Int](maybeTieBreaker = evensFirst) bh.insert(1, 15) bh.insert(2, 10) diff --git a/hail/src/test/scala/is/hail/utils/BitVectorSuite.scala b/hail/src/test/scala/is/hail/utils/BitVectorSuite.scala index e8ffe9609f8..b1b6b6a8ab7 100644 --- 
a/hail/src/test/scala/is/hail/utils/BitVectorSuite.scala +++ b/hail/src/test/scala/is/hail/utils/BitVectorSuite.scala @@ -1,13 +1,14 @@ package is.hail.utils -import is.hail.check.Prop._ import is.hail.check._ +import is.hail.check.Prop._ + import org.scalatest.testng.TestNGSuite import org.testng.annotations.Test class BitVectorSuite extends TestNGSuite { - @Test def test() { + @Test def test(): Unit = { val bv0 = new BitVector(0) assert(bv0.length == 0) @@ -20,10 +21,10 @@ class BitVectorSuite extends TestNGSuite { } val g = - for ( - n <- Gen.choose(1, 1000); + for { + n <- Gen.choose(1, 1000) s <- Gen.buildableOf[Set](Gen.choose(0, n - 1)) - ) yield { + } yield { val bv = new BitVector(n) assert(bv.length == n) @@ -62,7 +63,7 @@ class BitVectorSuite extends TestNGSuite { () } - val p = forAll(g) { _ => true } + val p = forAll(g)(_ => true) p.check() } } diff --git a/hail/src/test/scala/is/hail/utils/BufferedAggregatorIteratorSuite.scala b/hail/src/test/scala/is/hail/utils/BufferedAggregatorIteratorSuite.scala index 9ded0248d34..45ca9196ba3 100644 --- a/hail/src/test/scala/is/hail/utils/BufferedAggregatorIteratorSuite.scala +++ b/hail/src/test/scala/is/hail/utils/BufferedAggregatorIteratorSuite.scala @@ -1,15 +1,15 @@ package is.hail.utils import is.hail.check.{Gen, Prop} + import org.scalatest.testng.TestNGSuite import org.testng.annotations.Test class SumAgg() { var x = 0L - def add(element: Int): Unit = { + def add(element: Int): Unit = x += element - } def comb(other: SumAgg): SumAgg = { x += other.x @@ -20,24 +20,26 @@ class SumAgg() { } class BufferedAggregatorIteratorSuite extends TestNGSuite { - @Test def test() { + @Test def test(): Unit = { Prop.forAll( Gen.zip( Gen.buildableOf[IndexedSeq](Gen.zip(Gen.choose(1, 5), Gen.choose(1, 10))), - Gen.choose(1, 5)) + Gen.choose(1, 5), + ) ) { case (arr, bufferSize) => val simple = arr.groupBy(_._1).map { case (k, a) => k -> a.map(_._2.toLong).sum } val buffAgg = { new BufferedAggregatorIterator[(Int, Int), SumAgg, SumAgg, Int]( arr.iterator, () => new SumAgg(), - { case (k, v) => k }, + { case (k, _) => k }, { case (t, agg) => agg.add(t._2) }, a => a, - bufferSize) + bufferSize, + ) .toArray .groupBy(_._1) - .mapValues(sums => sums.map(_._2).fold(new SumAgg()){ case (s1, s2) => s1.comb(s2)}.x) + .mapValues(sums => sums.map(_._2).fold(new SumAgg()) { case (s1, s2) => s1.comb(s2) }.x) } simple == buffAgg }.check() diff --git a/hail/src/test/scala/is/hail/utils/FlipbookIteratorSuite.scala b/hail/src/test/scala/is/hail/utils/FlipbookIteratorSuite.scala index 09ae0cba669..85a33e681db 100644 --- a/hail/src/test/scala/is/hail/utils/FlipbookIteratorSuite.scala +++ b/hail/src/test/scala/is/hail/utils/FlipbookIteratorSuite.scala @@ -1,17 +1,19 @@ package is.hail.utils import is.hail.HailSuite -import org.testng.annotations.Test import scala.collection.generic.Growable import scala.collection.mutable.ArrayBuffer +import org.testng.annotations.Test + class FlipbookIteratorSuite extends HailSuite { class Box[A] extends AnyRef { var value: A = _ def canEqual(a: Any): Boolean = a.isInstanceOf[Box[A]] + override def equals(that: Any): Boolean = that match { case that: Box[A] => value == that.value @@ -22,6 +24,7 @@ class FlipbookIteratorSuite extends HailSuite { object Box { def apply[A](): Box[A] = new Box + def apply[A](a: A): Box[A] = { val box = Box[A]() box.value = a @@ -31,26 +34,29 @@ class FlipbookIteratorSuite extends HailSuite { def boxOrdView[A](implicit ord: Ordering[A]): OrderingView[Box[A]] = new OrderingView[Box[A]] { var value: A = _ - 
def setFiniteValue(a: Box[A]) { + + def setFiniteValue(a: Box[A]): Unit = value = a.value - } - def compareFinite(a: Box[A]): Int = { + + def compareFinite(a: Box[A]): Int = ord.compare(value, a.value) - } } def boxBuffer[A]: Growable[Box[A]] with Iterable[Box[A]] = new Growable[Box[A]] with Iterable[Box[A]] { val buf = ArrayBuffer[A]() val box = Box[A]() - def clear() { buf.clear() } + def clear(): Unit = buf.clear() + def +=(x: Box[A]) = { buf += x.value this } + def iterator: Iterator[Box[A]] = new Iterator[Box[A]] { var i = 0 def hasNext = i < buf.size + def next() = { box.value = buf(i) i += 1 @@ -66,18 +72,16 @@ class FlipbookIteratorSuite extends HailSuite { else l.value - r.value } - def makeTestIterator[A](elems: A*): StagingIterator[Box[A]] = { val it = elems.iterator val sm = new StateMachine[Box[A]] { val value: Box[A] = Box() var isValid = true - def advance() { + def advance(): Unit = if (it.hasNext) value.value = it.next() else isValid = false - } } sm.advance() StagingIterator(sm) @@ -89,89 +93,95 @@ class FlipbookIteratorSuite extends HailSuite { } implicit class RichTestIteratorIterator( - it: FlipbookIterator[FlipbookIterator[Box[Int]]]) { + it: FlipbookIterator[FlipbookIterator[Box[Int]]] + ) { def shouldBe(that: Iterator[Iterator[Int]]): Boolean = it.sameElementsUsing( that, (flipIt: FlipbookIterator[Box[Int]], it: Iterator[Int]) => - flipIt shouldBe it) + flipIt shouldBe it, + ) } implicit class RichTestIteratorMuple( - it: FlipbookIterator[Muple[Box[Int], Box[Int]]]) { + it: FlipbookIterator[Muple[Box[Int], Box[Int]]] + ) { def shouldBe(that: Iterator[(Int, Int)]): Boolean = it.sameElementsUsing( that, (muple: Muple[Box[Int], Box[Int]], pair: (Int, Int)) => - (muple._1.value == pair._1) && (muple._2.value == pair._2) + (muple._1.value == pair._1) && (muple._2.value == pair._2), ) } implicit class RichTestIteratorMupleIterator( - it: FlipbookIterator[Muple[FlipbookIterator[Box[Int]], - FlipbookIterator[Box[Int]]]]) { + it: FlipbookIterator[Muple[FlipbookIterator[Box[Int]], FlipbookIterator[Box[Int]]]] + ) { def shouldBe(that: Iterator[(Iterator[Int], Iterator[Int])]): Boolean = it.sameElementsUsing( that, - (muple: Muple[FlipbookIterator[Box[Int]], FlipbookIterator[Box[Int]]], - pair: (Iterator[Int], Iterator[Int])) => - muple._1.shouldBe(pair._1) && muple._2.shouldBe(pair._2) + ( + muple: Muple[FlipbookIterator[Box[Int]], FlipbookIterator[Box[Int]]], + pair: (Iterator[Int], Iterator[Int]), + ) => + muple._1.shouldBe(pair._1) && muple._2.shouldBe(pair._2), ) } implicit class RichTestIteratorArrayIterator(it: FlipbookIterator[Array[Box[Int]]]) { - def shouldBe(that: Iterator[Array[Int]]): Boolean = { + def shouldBe(that: Iterator[Array[Int]]): Boolean = it.sameElementsUsing( that, (arrBox: Array[Box[Int]], arr: Array[Int]) => - arrBox.length == arr.length && arrBox.zip(arr).forall({ case (a, b) => a.value == b }) + arrBox.length == arr.length && arrBox.zip(arr).forall { case (a, b) => a.value == b }, ) - } } - @Test def flipbookIteratorStartsWithRightValue() { + @Test def flipbookIteratorStartsWithRightValue(): Unit = { val it: FlipbookIterator[Box[Int]] = makeTestIterator(1, 2, 3, 4, 5) assert(it.value.value == 1) } - @Test def makeTestIteratorWorks() { + @Test def makeTestIteratorWorks(): Unit = { assert(makeTestIterator(1, 2, 3, 4, 5) shouldBe Iterator.range(1, 6)) assert(makeTestIterator[Int]() shouldBe Iterator.empty) } - @Test def toFlipbookIteratorOnFlipbookIteratorIsIdentity() { + @Test def toFlipbookIteratorOnFlipbookIteratorIsIdentity(): Unit = { val it1 = 
makeTestIterator(1, 2, 3) val it2 = Iterator(1, 2, 3) assert(it1.toFlipbookIterator shouldBe it2) assert(makeTestIterator[Int]().toFlipbookIterator shouldBe Iterator.empty) } - @Test def toStaircaseWorks() { + @Test def toStaircaseWorks(): Unit = { val testIt = makeTestIterator(1, 1, 2, 3, 3, 3) val it = Iterator( Iterator(1, 1), Iterator(2), - Iterator(3, 3, 3)) + Iterator(3, 3, 3), + ) assert(testIt.staircased(boxOrdView) shouldBe it) } - @Test def orderedZipJoinWorks() { + @Test def orderedZipJoinWorks(): Unit = { val left = makeTestIterator(1, 2, 4, 1000, 1000) val right = makeTestIterator(2, 3, 4, 1000, 1000) val zipped = left.orderedZipJoin( right, Box(0), Box(0), - boxIntOrd(missingValue = 1000)) + boxIntOrd(missingValue = 1000), + ) val it = Iterator((1, 0), (2, 2), (0, 3), (4, 4), (1000, 0), (1000, 0), (0, 1000), (0, 1000)) assert(zipped shouldBe it) } - @Test def innerJoinDistinctWorks() { + @Test def innerJoinDistinctWorks(): Unit = { val left = makeTestIterator(1, 2, 2, 4, 1000, 1000) val right = makeTestIterator(2, 4, 4, 5, 1000, 1000) val joined = left.innerJoinDistinct( @@ -180,28 +190,28 @@ class FlipbookIteratorSuite extends HailSuite { boxOrdView[Int], Box(0), Box(0), - boxIntOrd(missingValue = 1000) + boxIntOrd(missingValue = 1000), ) val it = Iterator((2, 2), (2, 2), (4, 4)) assert(joined shouldBe it) } - @Test def leftJoinDistinctWorks() { + @Test def leftJoinDistinctWorks(): Unit = { val left = makeTestIterator(1, 2, 2, 4, 1000, 1000) val right = makeTestIterator(2, 4, 4, 5, 1000, 1000) val joined = left.leftJoinDistinct( right, Box(0), Box(0), - boxIntOrd(missingValue = 1000) + boxIntOrd(missingValue = 1000), ) val it = Iterator((1, 0), (2, 2), (2, 2), (4, 4), (1000, 0), (1000, 0)) assert(joined shouldBe it) } - @Test def innerJoinWorks() { + @Test def innerJoinWorks(): Unit = { val left = makeTestIterator(1, 2, 2, 4, 5, 5, 1000, 1000) val right = makeTestIterator(2, 2, 4, 4, 5, 6, 1000, 1000) val joined = left.innerJoin( @@ -211,14 +221,14 @@ class FlipbookIteratorSuite extends HailSuite { Box(0), Box(0), boxBuffer[Int], - boxIntOrd(missingValue = 1000) + boxIntOrd(missingValue = 1000), ) val it = Iterator((2, 2), (2, 2), (2, 2), (2, 2), (4, 4), (4, 4), (5, 5), (5, 5)) assert(joined shouldBe it) } - @Test def leftJoinWorks() { + @Test def leftJoinWorks(): Unit = { val left = makeTestIterator(1, 2, 2, 4, 5, 5, 1000, 1000) val right = makeTestIterator(2, 2, 4, 4, 5, 6, 1000, 1000) val joined = left.leftJoin( @@ -228,14 +238,26 @@ class FlipbookIteratorSuite extends HailSuite { Box(0), Box(0), boxBuffer[Int], - boxIntOrd(missingValue = 1000) + boxIntOrd(missingValue = 1000), ) - val it = Iterator((1, 0), (2, 2), (2, 2), (2, 2), (2, 2), (4, 4), (4, 4), (5, 5), (5, 5), (1000, 0), (1000, 0)) + val it = Iterator( + (1, 0), + (2, 2), + (2, 2), + (2, 2), + (2, 2), + (4, 4), + (4, 4), + (5, 5), + (5, 5), + (1000, 0), + (1000, 0), + ) assert(joined shouldBe it) } - @Test def rightJoinWorks() { + @Test def rightJoinWorks(): Unit = { val left = makeTestIterator(1, 2, 2, 4, 5, 5, 1000, 1000) val right = makeTestIterator(2, 2, 4, 4, 5, 6, 1000, 1000) val joined = left.rightJoin( @@ -245,14 +267,26 @@ class FlipbookIteratorSuite extends HailSuite { Box(0), Box(0), boxBuffer[Int], - boxIntOrd(missingValue = 1000) + boxIntOrd(missingValue = 1000), ) - val it = Iterator((2, 2), (2, 2), (2, 2), (2, 2), (4, 4), (4, 4), (5, 5), (5, 5), (0, 6), (0, 1000), (0, 1000)) + val it = Iterator( + (2, 2), + (2, 2), + (2, 2), + (2, 2), + (4, 4), + (4, 4), + (5, 5), + (5, 5), + (0, 6), + (0, 1000), 
+ (0, 1000), + ) assert(joined shouldBe it) } - @Test def outerJoinWorks() { + @Test def outerJoinWorks(): Unit = { val left = makeTestIterator(1, 2, 2, 4, 5, 5, 1000, 1000) val right = makeTestIterator(2, 2, 4, 4, 5, 6, 1000, 1000) val joined = left.outerJoin( @@ -262,15 +296,30 @@ class FlipbookIteratorSuite extends HailSuite { Box(0), Box(0), boxBuffer[Int], - boxIntOrd(missingValue = 1000) + boxIntOrd(missingValue = 1000), ) - val it = Iterator((1, 0), (2, 2), (2, 2), (2, 2), (2, 2), (4, 4), (4, 4), (5, 5), (5, 5), (0, 6), (1000, 0), (1000, 0), (0, 1000), (0, 1000)) + val it = Iterator( + (1, 0), + (2, 2), + (2, 2), + (2, 2), + (2, 2), + (4, 4), + (4, 4), + (5, 5), + (5, 5), + (0, 6), + (1000, 0), + (1000, 0), + (0, 1000), + (0, 1000), + ) assert(joined shouldBe it) } - @Test def multiZipJoinWorks() { + @Test def multiZipJoinWorks(): Unit = { val one = makeTestIterator(1, 2, 2, 4, 5, 5, 1000, 1000) val two = makeTestIterator(2, 3, 4, 5, 5, 6, 1000, 1000) val three = makeTestIterator(2, 3, 4, 4, 5, 6, 1000, 1000) @@ -278,8 +327,9 @@ class FlipbookIteratorSuite extends HailSuite { val zipped = FlipbookIterator.multiZipJoin(its, boxIntOrd(missingValue = 1000)) def fillOut(ar: BoxedArrayBuilder[(Box[Int], Int)], default: Box[Int]): Array[Box[Int]] = { val a: Array[Box[Int]] = Array.fill(3)(default) - var i = 0; while (i < ar.size) { - var v = ar(i) + var i = 0; + while (i < ar.size) { + val v = ar(i) a(v._2) = v._1 i += 1 } @@ -303,7 +353,7 @@ class FlipbookIteratorSuite extends HailSuite { Array(0, 0, 1000), Array(0, 0, 1000), Array(1000, 0, 0), - Array(1000, 0, 0) + Array(1000, 0, 0), ) assert(comp shouldBe it) diff --git a/hail/src/test/scala/is/hail/utils/GraphSuite.scala b/hail/src/test/scala/is/hail/utils/GraphSuite.scala index e22aa805c8b..0da93e67070 100644 --- a/hail/src/test/scala/is/hail/utils/GraphSuite.scala +++ b/hail/src/test/scala/is/hail/utils/GraphSuite.scala @@ -1,11 +1,9 @@ package is.hail.utils -import org.scalatest.Matchers._ -import org.scalatest._ -import org.testng.annotations.Test - import scala.collection.mutable +import org.scalatest.Matchers._ +import org.testng.annotations.Test class GraphSuite { @@ -16,7 +14,7 @@ class GraphSuite { x.forall(x => g(x).intersect(s).isEmpty) } - @Test def simple() { + @Test def simple(): Unit = { { val actual = maximalIndependentSet(Array((0 -> 1))) actual should ((contain theSameElementsAs Array(0)) or (contain theSameElementsAs Array(1))) @@ -29,7 +27,10 @@ class GraphSuite { { val actual = maximalIndependentSet(Array(0 -> 1, 0 -> 2, 3 -> 1, 3 -> 2)) - actual should ((contain theSameElementsAs Array(1, 2)) or (contain theSameElementsAs Array(0, 3))) + actual should ((contain theSameElementsAs Array(1, 2)) or (contain theSameElementsAs Array( + 0, + 3, + ))) } { @@ -39,7 +40,7 @@ class GraphSuite { } } - @Test def longCycle() { + @Test def longCycle(): Unit = { val g = mkGraph(0 -> 1, 1 -> 2, 2 -> 3, 3 -> 4, 4 -> 5, 5 -> 6, 6 -> 0) val actual = maximalIndependentSet(g) @@ -47,7 +48,7 @@ class GraphSuite { assert(actual.length == 3) } - @Test def twoPopularNodes() { + @Test def twoPopularNodes(): Unit = { val g = mkGraph(0 -> 1, 0 -> 2, 0 -> 3, 4 -> 5, 4 -> 6, 4 -> 0) val actual = maximalIndependentSet(g) @@ -55,20 +56,19 @@ class GraphSuite { assert(actual.length == 5) } - @Test def totallyDisconnected() { + @Test def totallyDisconnected(): Unit = { val expected = 0 until 10 val m = new mutable.HashMap[Int, mutable.Set[Int]]() with mutable.MultiMap[Int, Int] - for (i <- expected) { + for (i <- expected) m.put(i, mutable.Set()) 
- } val actual = maximalIndependentSet(m) actual should contain theSameElementsAs expected } - @Test def disconnected() { + @Test def disconnected(): Unit = { val g = mkGraph(for (i <- 0 until 10) yield (i, i + 10)) val actual = maximalIndependentSet(g) @@ -77,7 +77,7 @@ class GraphSuite { assert(actual.length == 10) } - @Test def selfEdge() { + @Test def selfEdge(): Unit = { val g = mkGraph(0 -> 0, 1 -> 2, 1 -> 3) val actual = maximalIndependentSet(g) @@ -86,7 +86,7 @@ class GraphSuite { actual should contain theSameElementsAs Array(2, 3) } - @Test def emptyGraph() { + @Test def emptyGraph(): Unit = { val g = mkGraph[Int]() val actual = maximalIndependentSet(g) @@ -94,7 +94,7 @@ class GraphSuite { assert(actual === Array[Int]()) } - @Test def tieBreakingOfBipartiteGraphWorks() { + @Test def tieBreakingOfBipartiteGraphWorks(): Unit = { val g = mkGraph(for (i <- 0 until 10) yield (i, i + 10)) // prefer to remove big numbers val actual = maximalIndependentSet(g, Some((l: Int, r: Int) => (l - r).toDouble)) @@ -104,8 +104,7 @@ class GraphSuite { assert(actual.forall(_ < 10)) } - - @Test def tieBreakingInLongCycleWorks() { + @Test def tieBreakingInLongCycleWorks(): Unit = { val g = mkGraph(0 -> 1, 1 -> 2, 2 -> 3, 3 -> 4, 4 -> 5, 5 -> 6, 6 -> 0) // prefers to remove small numbers val actual = maximalIndependentSet(g, Some((l: Int, r: Int) => (r - l).toDouble)) diff --git a/hail/src/test/scala/is/hail/utils/HashMethodsSuite.scala b/hail/src/test/scala/is/hail/utils/HashMethodsSuite.scala index e9995652de3..0ac2c14c8af 100644 --- a/hail/src/test/scala/is/hail/utils/HashMethodsSuite.scala +++ b/hail/src/test/scala/is/hail/utils/HashMethodsSuite.scala @@ -1,16 +1,19 @@ package is.hail.utils import is.hail.HailSuite + import org.testng.annotations.Test class HashMethodsSuite extends HailSuite { - @Test def testMultGF() { + @Test def testMultGF(): Unit = { import PolyHash._ // multGF should agree with the Mathematica function - // f[a_, b_] := PolynomialRemainder[g[a, b], x^32 + x^7 + x^3 + x^2 + 1, x, Modulus -> 2] /. x -> 2 + /* f[a_, b_] := PolynomialRemainder[g[a, b], x^32 + x^7 + x^3 + x^2 + 1, x, Modulus -> 2] /. 
x + * -> 2 */ // where - // g[a_, b_] := Expand[FromDigits[IntegerDigits[a, 2], x] * FromDigits[IntegerDigits[b, 2], x], Modulus -> 2] + /* g[a_, b_] := Expand[FromDigits[IntegerDigits[a, 2], x] * FromDigits[IntegerDigits[b, 2], x], + * Modulus -> 2] */ val x1 = 3705006673L.toInt val y1 = 2551778209L.toInt assert(multGF(x1, y1) == 2272553040L.toInt) diff --git a/hail/src/test/scala/is/hail/utils/IntervalSuite.scala b/hail/src/test/scala/is/hail/utils/IntervalSuite.scala index c19dce9cb12..179bb94c2ec 100644 --- a/hail/src/test/scala/is/hail/utils/IntervalSuite.scala +++ b/hail/src/test/scala/is/hail/utils/IntervalSuite.scala @@ -2,14 +2,14 @@ package is.hail.utils import is.hail.HailSuite import is.hail.annotations.ExtendedOrdering -import is.hail.backend.HailStateManager -import is.hail.types.virtual.{TInt32, TStruct} +import is.hail.backend.{ExecuteContext, HailStateManager} import is.hail.rvd.RVDPartitioner +import is.hail.types.virtual.{TInt32, TStruct} + import org.apache.spark.sql.Row import org.testng.Assert._ -import org.testng.annotations.{BeforeMethod, Test} -import is.hail.backend.ExecuteContext import org.testng.ITestContext +import org.testng.annotations.{BeforeMethod, Test} class IntervalSuite extends HailSuite { @@ -44,97 +44,143 @@ class IntervalSuite extends HailSuite { } } - - @Test def interval_agrees_with_set_interval_greater_than_point() { - for (set_interval <- test_intervals; p <- points) { + @Test def interval_agrees_with_set_interval_greater_than_point(): Unit = { + for { + set_interval <- test_intervals + p <- points + } { val interval = set_interval.interval - assertEquals(interval.isAbovePosition(pord, p), set_interval.doubledPointSet.forall(dp => dp > 2 * p)) + assertEquals( + interval.isAbovePosition(pord, p), + set_interval.doubledPointSet.forall(dp => dp > 2 * p), + ) } } - @Test def interval_agrees_with_set_interval_less_than_point() { - for (set_interval <- test_intervals; p <- points) { + @Test def interval_agrees_with_set_interval_less_than_point(): Unit = { + for { + set_interval <- test_intervals + p <- points + } { val interval = set_interval.interval - assertEquals(interval.isBelowPosition(pord, p), set_interval.doubledPointSet.forall(dp => dp < 2 * p)) + assertEquals( + interval.isBelowPosition(pord, p), + set_interval.doubledPointSet.forall(dp => dp < 2 * p), + ) } } - @Test def interval_agrees_with_set_interval_contains() { - for (set_interval <- test_intervals; p <- points) { + @Test def interval_agrees_with_set_interval_contains(): Unit = { + for { + set_interval <- test_intervals + p <- points + } { val interval = set_interval.interval assertEquals(interval.contains(pord, p), set_interval.contains(p)) } } - @Test def interval_agrees_with_set_interval_includes() { - for (set_interval1 <- test_intervals; set_interval2 <- test_intervals) { + @Test def interval_agrees_with_set_interval_includes(): Unit = { + for { + set_interval1 <- test_intervals + set_interval2 <- test_intervals + } { val interval1 = set_interval1.interval val interval2 = set_interval2.interval assertEquals(interval1.includes(pord, interval2), set_interval1.includes(set_interval2)) } } - @Test def interval_agrees_with_set_interval_probably_overlaps() { - for (set_interval1 <- test_intervals; set_interval2 <- test_intervals) { + @Test def interval_agrees_with_set_interval_probably_overlaps(): Unit = { + for { + set_interval1 <- test_intervals + set_interval2 <- test_intervals + } { val interval1 = set_interval1.interval val interval2 = set_interval2.interval - 
assertEquals(interval1.overlaps(pord, interval2), set_interval1.probablyOverlaps(set_interval2)) + assertEquals( + interval1.overlaps(pord, interval2), + set_interval1.probablyOverlaps(set_interval2), + ) } } - @Test def interval_agrees_with_set_interval_definitely_disjoint() { - for (set_interval1 <- test_intervals; set_interval2 <- test_intervals) { + @Test def interval_agrees_with_set_interval_definitely_disjoint(): Unit = { + for { + set_interval1 <- test_intervals + set_interval2 <- test_intervals + } { val interval1 = set_interval1.interval val interval2 = set_interval2.interval - assertEquals(interval1.isDisjointFrom(pord, interval2), set_interval1.definitelyDisjoint(set_interval2)) + assertEquals( + interval1.isDisjointFrom(pord, interval2), + set_interval1.definitelyDisjoint(set_interval2), + ) } } - @Test def interval_agrees_with_set_interval_disjoint_greater_than() { - for {set_interval1 <- test_intervals - set_interval2 <- test_intervals} { + @Test def interval_agrees_with_set_interval_disjoint_greater_than(): Unit = { + for { + set_interval1 <- test_intervals + set_interval2 <- test_intervals + } { val interval1 = set_interval1.interval val interval2 = set_interval2.interval assertEquals(interval1.isAbove(pord, interval2), set_interval1.isAboveInterval(set_interval2)) } } - @Test def interval_agrees_with_set_interval_disjoint_less_than() { - for {set_interval1 <- test_intervals - set_interval2 <- test_intervals} { + @Test def interval_agrees_with_set_interval_disjoint_less_than(): Unit = { + for { + set_interval1 <- test_intervals + set_interval2 <- test_intervals + } { val interval1 = set_interval1.interval val interval2 = set_interval2.interval assertEquals(interval1.isBelow(pord, interval2), set_interval1.isBelowInterval(set_interval2)) } } - @Test def interval_agrees_with_set_interval_mergeable() { - for {set_interval1 <- test_intervals - set_interval2 <- test_intervals} { + @Test def interval_agrees_with_set_interval_mergeable(): Unit = { + for { + set_interval1 <- test_intervals + set_interval2 <- test_intervals + } { val interval1 = set_interval1.interval val interval2 = set_interval2.interval assertEquals(interval1.canMergeWith(pord, interval2), set_interval1.mergeable(set_interval2)) } } - @Test def interval_agrees_with_set_interval_merge() { - for {set_interval1 <- test_intervals - set_interval2 <- test_intervals} { + @Test def interval_agrees_with_set_interval_merge(): Unit = { + for { + set_interval1 <- test_intervals + set_interval2 <- test_intervals + } { val interval1 = set_interval1.interval val interval2 = set_interval2.interval - assertEquals(interval1.merge(pord, interval2), set_interval1.union(set_interval2).map(_.interval)) + assertEquals( + interval1.merge(pord, interval2), + set_interval1.union(set_interval2).map(_.interval), + ) } } - @Test def interval_agrees_with_set_interval_intersect() { - for (set_interval1 <- test_intervals; set_interval2 <- test_intervals) { + @Test def interval_agrees_with_set_interval_intersect(): Unit = { + for { + set_interval1 <- test_intervals + set_interval2 <- test_intervals + } { val interval1 = set_interval1.interval val interval2 = set_interval2.interval - assertEquals(interval1.intersect(pord, interval2), set_interval1.intersect(set_interval2).map(_.interval)) + assertEquals( + interval1.intersect(pord, interval2), + set_interval1.intersect(set_interval2).map(_.interval), + ) } } - @Test def interval_tree_agrees_with_set_interval_tree_contains() { + @Test def interval_tree_agrees_with_set_interval_tree_contains(): 
Unit = { for { set_itree <- test_itrees p <- points @@ -144,7 +190,7 @@ class IntervalSuite extends HailSuite { } } - @Test def interval_tree_agrees_with_set_interval_tree_probably_overlaps() { + @Test def interval_tree_agrees_with_set_interval_tree_probably_overlaps(): Unit = { for { set_itree <- test_itrees set_interval <- test_intervals @@ -155,7 +201,7 @@ class IntervalSuite extends HailSuite { } } - @Test def interval_tree_agrees_with_set_interval_tree_definitely_disjoint() { + @Test def interval_tree_agrees_with_set_interval_tree_definitely_disjoint(): Unit = { for { set_itree <- test_itrees set_interval <- test_intervals @@ -166,7 +212,7 @@ class IntervalSuite extends HailSuite { } } - @Test def interval_tree_agrees_with_set_interval_tree_query_values() { + @Test def interval_tree_agrees_with_set_interval_tree_query_values(): Unit = { for { set_itree <- test_itrees point <- points @@ -178,7 +224,7 @@ class IntervalSuite extends HailSuite { } } - @Test def interval_tree_agrees_with_set_interval_tree_query_overlapping_values() { + @Test def interval_tree_agrees_with_set_interval_tree_query_overlapping_values(): Unit = { for { set_itree <- test_itrees set_interval <- test_intervals @@ -216,11 +262,13 @@ case class SetInterval(start: Int, end: Int, includesStart: Boolean, includesEnd def includes(other: SetInterval): Boolean = (other.doubledPointSet -- this.doubledPointSet).isEmpty - def probablyOverlaps(other: SetInterval): Boolean = doubledPointSet.intersect(other.doubledPointSet).nonEmpty + def probablyOverlaps(other: SetInterval): Boolean = + doubledPointSet.intersect(other.doubledPointSet).nonEmpty def definitelyEmpty(): Boolean = doubledPointSet.isEmpty - def definitelyDisjoint(other: SetInterval): Boolean = doubledPointSet.intersect(other.doubledPointSet).isEmpty + def definitelyDisjoint(other: SetInterval): Boolean = + doubledPointSet.intersect(other.doubledPointSet).isEmpty def isAboveInterval(other: SetInterval): Boolean = doubledPointSet.forall(p1 => other.doubledPointSet.forall(p2 => p1 > p2)) @@ -249,8 +297,7 @@ case class SetInterval(start: Int, end: Int, includesStart: Boolean, includesEnd val start = combined.min(pord.toOrdering) val end = combined.max(pord.toOrdering) Some(SetInterval(start / 2, (end + 1) / 2, start % 2 == 0, end % 2 == 0)) - } - else None + } else None } def intersect(other: SetInterval): Option[SetInterval] = { @@ -275,21 +322,26 @@ case class SetIntervalTree(ctx: ExecuteContext, annotations: Array[(SetInterval, val (intervals, values) = annotations.unzip - val intervalTree: RVDPartitioner = new RVDPartitioner(ctx.stateManager, TStruct(("i", TInt32)), intervals.map(_.rowInterval)) + val intervalTree: RVDPartitioner = + new RVDPartitioner(ctx.stateManager, TStruct(("i", TInt32)), intervals.map(_.rowInterval)) def contains(point: Int): Boolean = doubledPointSet.contains(2 * point) - def probablyOverlaps(other: SetInterval): Boolean = doubledPointSet.intersect(other.doubledPointSet).nonEmpty + def probablyOverlaps(other: SetInterval): Boolean = + doubledPointSet.intersect(other.doubledPointSet).nonEmpty def definitelyEmpty(): Boolean = doubledPointSet.isEmpty - def definitelyDisjoint(other: SetInterval): Boolean = doubledPointSet.intersect(other.doubledPointSet).isEmpty + def definitelyDisjoint(other: SetInterval): Boolean = + doubledPointSet.intersect(other.doubledPointSet).isEmpty - def queryIntervals(point: Int): Set[Interval] = intervals.filter(_.contains(point)).map(_.interval).toSet + def queryIntervals(point: Int): Set[Interval] = + 
intervals.filter(_.contains(point)).map(_.interval).toSet def queryValues(point: Int): Set[Int] = annotations.filter(_._1.contains(point)).map(_._2).toSet - def queryProbablyOverlappingValues(interval: SetInterval): Set[Int] = annotations.filter(_._1.probablyOverlaps(interval)).map(_._2).toSet + def queryProbablyOverlappingValues(interval: SetInterval): Set[Int] = + annotations.filter(_._1.probablyOverlaps(interval)).map(_._2).toSet override val toString: String = intervals.map(_.interval).mkString(", ") } diff --git a/hail/src/test/scala/is/hail/utils/PartitionCountsSuite.scala b/hail/src/test/scala/is/hail/utils/PartitionCountsSuite.scala index 056a5aaa3ad..b64466e3f07 100644 --- a/hail/src/test/scala/is/hail/utils/PartitionCountsSuite.scala +++ b/hail/src/test/scala/is/hail/utils/PartitionCountsSuite.scala @@ -1,50 +1,56 @@ package is.hail.utils +import is.hail.utils.PartitionCounts._ + import org.scalatest.testng.TestNGSuite import org.testng.annotations.Test -import is.hail.utils.PartitionCounts._ class PartitionCountsSuite extends TestNGSuite { @Test def testHeadPCs() = { - for (((a, n), b) <- Seq( - (IndexedSeq(0L), 5L) -> IndexedSeq(0L), - (IndexedSeq(4L, 5L, 6L), 1L) -> IndexedSeq(1L), - (IndexedSeq(4L, 5L, 6L), 6L) -> IndexedSeq(4L, 2L), - (IndexedSeq(4L, 5L, 6L), 9L) -> IndexedSeq(4L, 5L), - (IndexedSeq(4L, 5L, 6L), 10L) -> IndexedSeq(4L, 5L, 1L), - (IndexedSeq(4L, 5L, 6L), 15L) -> IndexedSeq(4L, 5L, 6L), - (IndexedSeq(4L, 5L, 6L), 20L) -> IndexedSeq(4L, 5L, 6L) - )) { + for ( + ((a, n), b) <- Seq( + (IndexedSeq(0L), 5L) -> IndexedSeq(0L), + (IndexedSeq(4L, 5L, 6L), 1L) -> IndexedSeq(1L), + (IndexedSeq(4L, 5L, 6L), 6L) -> IndexedSeq(4L, 2L), + (IndexedSeq(4L, 5L, 6L), 9L) -> IndexedSeq(4L, 5L), + (IndexedSeq(4L, 5L, 6L), 10L) -> IndexedSeq(4L, 5L, 1L), + (IndexedSeq(4L, 5L, 6L), 15L) -> IndexedSeq(4L, 5L, 6L), + (IndexedSeq(4L, 5L, 6L), 20L) -> IndexedSeq(4L, 5L, 6L), + ) + ) assert(getHeadPCs(a, n) == b, s"getHeadPartitionCounts($a, $n)") - } } @Test def testTailPCs() = { - for (((a, n), b) <- Seq( - (IndexedSeq(0L), 5L) -> IndexedSeq(0L), - (IndexedSeq(4L, 5L, 6L), 1L) -> IndexedSeq(1L), - (IndexedSeq(4L, 5L, 6L), 6L) -> IndexedSeq(6L), - (IndexedSeq(4L, 5L, 6L), 9L) -> IndexedSeq(3L, 6L), - (IndexedSeq(4L, 5L, 6L), 10L) -> IndexedSeq(4L, 6L), - (IndexedSeq(4L, 5L, 6L), 15L) -> IndexedSeq(4L, 5L, 6L), - (IndexedSeq(4L, 5L, 6L), 20L) -> IndexedSeq(4L, 5L, 6L) - )) { + for ( + ((a, n), b) <- Seq( + (IndexedSeq(0L), 5L) -> IndexedSeq(0L), + (IndexedSeq(4L, 5L, 6L), 1L) -> IndexedSeq(1L), + (IndexedSeq(4L, 5L, 6L), 6L) -> IndexedSeq(6L), + (IndexedSeq(4L, 5L, 6L), 9L) -> IndexedSeq(3L, 6L), + (IndexedSeq(4L, 5L, 6L), 10L) -> IndexedSeq(4L, 6L), + (IndexedSeq(4L, 5L, 6L), 15L) -> IndexedSeq(4L, 5L, 6L), + (IndexedSeq(4L, 5L, 6L), 20L) -> IndexedSeq(4L, 5L, 6L), + ) + ) { assert(getTailPCs(a, n) == b, s"getTailPartitionCounts($a, $n)") - assert(getTailPCs(a, n) == getHeadPCs(a.reverse, n).reverse, - s"getTailPartitionCounts($a, $n) via head") + assert( + getTailPCs(a, n) == getHeadPCs(a.reverse, n).reverse, + s"getTailPartitionCounts($a, $n) via head", + ) } } @Test def testIncrementalPCSubset() = { - var pcs = Array(0L, 0L, 5L, 6L, 4L, 3L, 3L, 3L, 2L, 1L) + val pcs = Array(0L, 0L, 5L, 6L, 4L, 3L, 3L, 3L, 2L, 1L) def headOffset(n: Long) = incrementalPCSubsetOffset(n, 0 until pcs.length)(_.map(pcs)) for (n <- 0L until pcs.sum) { val PCSubsetOffset(i, nKeep, nDrop) = headOffset(n) - val total = (0 to i).map { j => if (j == i) nKeep else pcs(j) }.sum + val total = (0 to i).map(j => if 
(j == i) nKeep else pcs(j)).sum assert(nKeep + nDrop == pcs(i)) assert(total == n) } @@ -54,7 +60,7 @@ class PartitionCountsSuite extends TestNGSuite { for (n <- 0L until pcs.sum) { val PCSubsetOffset(i, nKeep, nDrop) = tailOffset(n) - val total = (i to (pcs.length - 1)).map { j => if (j == i) nKeep else pcs(j) }.sum + val total = (i to (pcs.length - 1)).map(j => if (j == i) nKeep else pcs(j)).sum assert(nKeep + nDrop == pcs(i)) assert(total == n) } diff --git a/hail/src/test/scala/is/hail/utils/RichArraySuite.scala b/hail/src/test/scala/is/hail/utils/RichArraySuite.scala index 763d66f06ae..df1a5809f51 100644 --- a/hail/src/test/scala/is/hail/utils/RichArraySuite.scala +++ b/hail/src/test/scala/is/hail/utils/RichArraySuite.scala @@ -1,19 +1,20 @@ package is.hail.utils -import is.hail.utils.richUtils.RichArray import is.hail.{HailSuite, TestUtils} +import is.hail.utils.richUtils.RichArray + import org.testng.annotations.Test class RichArraySuite extends HailSuite { - @Test def testArrayImpex() { + @Test def testArrayImpex(): Unit = { val file = ctx.createTmpPath("test") val a = Array.fill[Double](100)(util.Random.nextDouble()) val a2 = new Array[Double](100) - + RichArray.exportToDoubles(fs, file, a, bufSize = 32) RichArray.importFromDoubles(fs, file, a2, bufSize = 16) assert(a === a2) - + TestUtils.interceptFatal("Premature") { RichArray.importFromDoubles(fs, file, new Array[Double](101), bufSize = 64) } diff --git a/hail/src/test/scala/is/hail/utils/RichDenseMatrixDoubleSuite.scala b/hail/src/test/scala/is/hail/utils/RichDenseMatrixDoubleSuite.scala index 8f17050eebc..a8d05321bf4 100644 --- a/hail/src/test/scala/is/hail/utils/RichDenseMatrixDoubleSuite.scala +++ b/hail/src/test/scala/is/hail/utils/RichDenseMatrixDoubleSuite.scala @@ -1,14 +1,15 @@ package is.hail.utils import is.hail.{HailSuite, TestUtils} +import is.hail.linalg.BlockMatrix import is.hail.utils.richUtils.RichDenseMatrixDouble + import breeze.linalg.{DenseMatrix => BDM} -import is.hail.linalg.BlockMatrix import org.testng.annotations.Test class RichDenseMatrixDoubleSuite extends HailSuite { @Test - def readWriteBDM() { + def readWriteBDM(): Unit = { val m = BDM.rand[Double](256, 129) // 33024 doubles val fname = ctx.createTmpPath("test") @@ -17,7 +18,7 @@ class RichDenseMatrixDoubleSuite extends HailSuite { assert(m === m2) } - + @Test def testReadWriteDoubles(): Unit = { val file = ctx.createTmpPath("test") @@ -25,13 +26,13 @@ class RichDenseMatrixDoubleSuite extends HailSuite { RichDenseMatrixDouble.exportToDoubles(fs, file, m, forceRowMajor = false) val m2 = RichDenseMatrixDouble.importFromDoubles(fs, file, 50, 100, rowMajor = false) assert(m === m2) - + val fileT = ctx.createTmpPath("test2") val mT = m.t RichDenseMatrixDouble.exportToDoubles(fs, fileT, mT, forceRowMajor = true) val lmT2 = RichDenseMatrixDouble.importFromDoubles(fs, fileT, 100, 50, rowMajor = true) - assert(mT === mT) - + assert(mT === lmT2) + TestUtils.interceptFatal("Premature") { RichDenseMatrixDouble.importFromDoubles(fs, fileT, 100, 100, rowMajor = true) } diff --git a/hail/src/test/scala/is/hail/utils/RichIndexedRowMatrixSuite.scala b/hail/src/test/scala/is/hail/utils/RichIndexedRowMatrixSuite.scala index 540500b5999..f13a47a1f29 100644 --- a/hail/src/test/scala/is/hail/utils/RichIndexedRowMatrixSuite.scala +++ b/hail/src/test/scala/is/hail/utils/RichIndexedRowMatrixSuite.scala @@ -1,19 +1,18 @@ package is.hail.utils -import breeze.linalg.{DenseMatrix => BDM, _} -import is.hail.{HailSuite} +import is.hail.HailSuite import 
is.hail.linalg.BlockMatrix.ops._ + +import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.linalg.distributed.{DistributedMatrix, IndexedRow, IndexedRowMatrix} +import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix} import org.apache.spark.rdd.RDD import org.testng.annotations.Test -/** - * Testing RichIndexedRowMatrix. - */ +/** Testing RichIndexedRowMatrix. */ class RichIndexedRowMatrixSuite extends HailSuite { - @Test def testToBlockMatrixDense() { + @Test def testToBlockMatrixDense(): Unit = { val nRows = 9L val nCols = 6L val data = Seq( @@ -24,7 +23,7 @@ class RichIndexedRowMatrixSuite extends HailSuite { (5L, Vectors.dense(9.0, 0.0, 1.0, 1.0, 1.0, 1.0)), (6L, Vectors.dense(1.0, 2.0, 3.0, 1.0, 1.0, 1.0)), (7L, Vectors.dense(4.0, 5.0, 6.0, 1.0, 1.0, 1.0)), - (8L, Vectors.dense(7.0, 8.0, 9.0, 1.0, 1.0, 1.0)) + (8L, Vectors.dense(7.0, 8.0, 9.0, 1.0, 1.0, 1.0)), ).map(IndexedRow.tupled) val indexedRows: RDD[IndexedRow] = sc.parallelize(data) @@ -51,14 +50,14 @@ class RichIndexedRowMatrixSuite extends HailSuite { } } - @Test def emptyBlocks() { + @Test def emptyBlocks(): Unit = { val nRows = 9 val nCols = 2 val data = Seq( (3L, Vectors.dense(1.0, 2.0)), (4L, Vectors.dense(1.0, 2.0)), (5L, Vectors.dense(1.0, 2.0)), - (8L, Vectors.dense(1.0, 2.0)) + (8L, Vectors.dense(1.0, 2.0)), ).map(IndexedRow.tupled) val irm = new IndexedRowMatrix(sc.parallelize(data)) @@ -74,9 +73,13 @@ class RichIndexedRowMatrixSuite extends HailSuite { (m.dot(m.T)).toBreezeMatrix() // assert no exception assert(m.mapWithIndex { case (i, j, v) => i + 10 * j + v }.toBreezeMatrix() === - new BDM[Double](nRows, nCols, Array[Double]( - 0.0, 1.0, 2.0, 4.0, 5.0, 6.0, 6.0, 7.0, 9.0, - 10.0, 11.0, 12.0, 15.0, 16.0, 17.0, 16.0, 17.0, 20.0 - ))) + new BDM[Double]( + nRows, + nCols, + Array[Double]( + 0.0, 1.0, 2.0, 4.0, 5.0, 6.0, 6.0, 7.0, 9.0, + 10.0, 11.0, 12.0, 15.0, 16.0, 17.0, 16.0, 17.0, 20.0, + ), + )) } } diff --git a/hail/src/test/scala/is/hail/utils/RichRDDSuite.scala b/hail/src/test/scala/is/hail/utils/RichRDDSuite.scala index 6849a6f0bae..bf793008c15 100644 --- a/hail/src/test/scala/is/hail/utils/RichRDDSuite.scala +++ b/hail/src/test/scala/is/hail/utils/RichRDDSuite.scala @@ -1,10 +1,11 @@ package is.hail.utils -import is.hail.{HailSuite, TestUtils} +import is.hail.HailSuite + import org.testng.annotations.Test class RichRDDSuite extends HailSuite { - @Test def parallelWrite() { + @Test def parallelWrite(): Unit = { def read(file: String): Array[String] = fs.readLines(file)(_.map(_.value).toArray) val header = "my header is awesome!" 
@@ -24,17 +25,23 @@ class RichRDDSuite extends HailSuite { assert(read(shardHeaders + "/part-00001") sameElements header +: Array(data(1))) val separateHeader = ctx.createTmpPath("separateHeader", "gz") - r.writeTable(ctx, separateHeader, Some(header), exportType = ExportType.PARALLEL_SEPARATE_HEADER) + r.writeTable( + ctx, + separateHeader, + Some(header), + exportType = ExportType.PARALLEL_SEPARATE_HEADER, + ) assert(read(separateHeader + "/header.gz") sameElements Array(header)) assert(read(separateHeader + "/part-00000.gz") sameElements Array(data(0))) assert(read(separateHeader + "/part-00001.gz") sameElements Array(data(1))) - val merged = ctx.createTmpPath("merged", ".gz") - val mergeList = Array(separateHeader + "/header.gz", + val mergeList = Array( + separateHeader + "/header.gz", separateHeader + "/part-00000.gz", - separateHeader + "/part-00001.gz").flatMap((x: String) => fs.glob(x)) + separateHeader + "/part-00001.gz", + ).map(x => fs.fileStatus(x)) fs.copyMergeList(mergeList, merged, deleteSource = false) assert(read(merged) sameElements read(concatenated)) diff --git a/hail/src/test/scala/is/hail/utils/RowIntervalSuite.scala b/hail/src/test/scala/is/hail/utils/RowIntervalSuite.scala index 75b75ea958d..92afa643570 100644 --- a/hail/src/test/scala/is/hail/utils/RowIntervalSuite.scala +++ b/hail/src/test/scala/is/hail/utils/RowIntervalSuite.scala @@ -1,10 +1,11 @@ package is.hail.utils +import is.hail.{ExecStrategy, HailSuite} import is.hail.expr.ir import is.hail.expr.ir.In import is.hail.rvd.{PartitionBoundOrdering, RVDPartitioner} import is.hail.types.virtual.{TBoolean, TInt32, TStruct} -import is.hail.{ExecStrategy, HailSuite} + import org.apache.spark.sql.Row import org.testng.annotations.Test @@ -29,24 +30,49 @@ class RowIntervalSuite extends HailSuite { assertEvalsTo( ir.invoke("partitionIntervalContains", TBoolean, in1, in2), args = FastSeq((intervalIRRep, irRepIntervalType), (point, tt)), - shouldContain)(ExecStrategy.compileOnly) + shouldContain, + )(ExecStrategy.compileOnly) } - @Test def testContains() { + @Test def testContains(): Unit = { assertContains(Interval(Row(0, 1, 5), Row(1, 2, 4), true, true), Row(1, 1, 3)) assertContains(Interval(Row(0, 1, 5), Row(1, 2, 4), true, true), Row(0, 1, 5)) - assertContains(Interval(Row(0, 1, 5), Row(1, 2, 4), false, true), Row(0, 1, 5), shouldContain = false) - assertContains(Interval(Row(0, 1, 5), Row(1, 2, 4), true, false), Row(1, 2, 4), shouldContain = false) + assertContains( + Interval(Row(0, 1, 5), Row(1, 2, 4), false, true), + Row(0, 1, 5), + shouldContain = false, + ) + assertContains( + Interval(Row(0, 1, 5), Row(1, 2, 4), true, false), + Row(1, 2, 4), + shouldContain = false, + ) assertContains(Interval(Row(0, 1), Row(1, 2, 4), true, true), Row(0, 1, 5)) - assertContains(Interval(Row(0, 1), Row(1, 2, 4), false, true), Row(0, 1, 5), shouldContain = false) + assertContains( + Interval(Row(0, 1), Row(1, 2, 4), false, true), + Row(0, 1, 5), + shouldContain = false, + ) assertContains(Interval(Row(0, 1), Row(0, 1, 4), true, true), Row(0, 1, 4)) - assertContains(Interval(Row(0, 1), Row(0, 1, 4), true, false), Row(0, 1, 4), shouldContain = false) + assertContains( + Interval(Row(0, 1), Row(0, 1, 4), true, false), + Row(0, 1, 4), + shouldContain = false, + ) assertContains(Interval(Row(0, 1), Row(1, 2, 4), true, true), Row(0, 1, 5)) - assertContains(Interval(Row(0, 1), Row(1, 2, 4), false, true), Row(0, 1, 5), shouldContain = false) + assertContains( + Interval(Row(0, 1), Row(1, 2, 4), false, true), + Row(0, 1, 5), + 
shouldContain = false, + ) assertContains(Interval(Row(0, 1), Row(0, 1, 4), true, true), Row(0, 1, 4)) - assertContains(Interval(Row(0, 1), Row(0, 1, 4), true, false), Row(0, 1, 4), shouldContain = false) + assertContains( + Interval(Row(0, 1), Row(0, 1, 4), true, false), + Row(0, 1, 4), + shouldContain = false, + ) assertContains(Interval(Row(), Row(1, 2, 4), true, true), Row(1, 2, 4)) assertContains(Interval(Row(), Row(1, 2, 4), true, false), Row(1, 2, 4), shouldContain = false) @@ -58,7 +84,7 @@ class RowIntervalSuite extends HailSuite { assert(!Interval(Row(0, 1, 5, 7), Row(2, 1, 4, 5), false, false).contains(pord, Row(0, 1, 5))) } - @Test def testAbovePosition() { + @Test def testAbovePosition(): Unit = { assert(Interval(Row(0, 1, 5), Row(1, 2, 4), true, true).isAbovePosition(pord, Row(0, 1, 4))) assert(Interval(Row(0, 1, 5), Row(1, 2, 4), false, true).isAbovePosition(pord, Row(0, 1, 5))) assert(!Interval(Row(0, 1, 5), Row(1, 2, 4), true, true).isAbovePosition(pord, Row(0, 1, 5))) @@ -68,11 +94,17 @@ class RowIntervalSuite extends HailSuite { assert(Interval(Row(0, 1), Row(1, 2, 4), false, true).isAbovePosition(pord, Row(0, 1, 5))) assert(!Interval(Row(0, 1), Row(0, 1, 4), true, true).isAbovePosition(pord, Row(0, 1, 4))) - assert(Interval(Row(0, 1, 2, 3), Row(1, 2, 3, 4), true, true).isAbovePosition(pord, Row(0, 1, 1, 4))) - assert(!Interval(Row(0, 1, 2, 3), Row(1, 2, 3, 4), true, true).isAbovePosition(pord, Row(0, 1, 2, 2))) + assert(Interval(Row(0, 1, 2, 3), Row(1, 2, 3, 4), true, true).isAbovePosition( + pord, + Row(0, 1, 1, 4), + )) + assert(!Interval(Row(0, 1, 2, 3), Row(1, 2, 3, 4), true, true).isAbovePosition( + pord, + Row(0, 1, 2, 2), + )) } - @Test def testBelowPosition() { + @Test def testBelowPosition(): Unit = { assert(Interval(Row(0, 1, 5), Row(1, 2, 4), true, true).isBelowPosition(pord, Row(1, 2, 5))) assert(Interval(Row(0, 1, 5), Row(1, 2, 4), true, false).isBelowPosition(pord, Row(1, 2, 4))) assert(!Interval(Row(0, 1, 5), Row(1, 2, 4), true, true).isBelowPosition(pord, Row(1, 2, 4))) @@ -83,77 +115,101 @@ class RowIntervalSuite extends HailSuite { assert(!Interval(Row(1, 1, 8), Row(1, 2), true, true).isBelowPosition(pord, Row(1, 2, 5))) } - @Test def testAbutts() { - assert(Interval(Row(0, 1, 5), Row(1, 2, 4), true, true).abutts(pord, - Interval(Row(1, 2, 4), Row(1, 3, 4), false, true))) - assert(!Interval(Row(0, 1, 5), Row(1, 2, 4), true, true).abutts(pord, - Interval(Row(1, 2, 4), Row(1, 3, 4), true, true))) + @Test def testAbutts(): Unit = { + assert(Interval(Row(0, 1, 5), Row(1, 2, 4), true, true).abutts( + pord, + Interval(Row(1, 2, 4), Row(1, 3, 4), false, true), + )) + assert(!Interval(Row(0, 1, 5), Row(1, 2, 4), true, true).abutts( + pord, + Interval(Row(1, 2, 4), Row(1, 3, 4), true, true), + )) - assert(Interval(Row(0, 1), Row(1, 2), true, true).abutts(pord, - Interval(Row(1, 2), Row(1, 3), false, true))) - assert(!Interval(Row(0, 1), Row(1, 2), true, true).abutts(pord, - Interval(Row(1, 2), Row(1, 3), true, true))) + assert(Interval(Row(0, 1), Row(1, 2), true, true).abutts( + pord, + Interval(Row(1, 2), Row(1, 3), false, true), + )) + assert(!Interval(Row(0, 1), Row(1, 2), true, true).abutts( + pord, + Interval(Row(1, 2), Row(1, 3), true, true), + )) } - @Test def testLteqWithOverlap() { + @Test def testLteqWithOverlap(): Unit = { val eord = pord.intervalEndpointOrdering assert(!eord.lteqWithOverlap(3)( - IntervalEndpoint(Row(0, 1, 6), -1), IntervalEndpoint(Row(0, 1, 5), 1) + IntervalEndpoint(Row(0, 1, 6), -1), + IntervalEndpoint(Row(0, 1, 5), 1), )) 
assert(eord.lteqWithOverlap(3)( - IntervalEndpoint(Row(0, 1, 5), 1), IntervalEndpoint(Row(0, 1, 5), -1) + IntervalEndpoint(Row(0, 1, 5), 1), + IntervalEndpoint(Row(0, 1, 5), -1), )) assert(!eord.lteqWithOverlap(2)( - IntervalEndpoint(Row(0, 1, 5), 1), IntervalEndpoint(Row(0, 1, 5), -1) + IntervalEndpoint(Row(0, 1, 5), 1), + IntervalEndpoint(Row(0, 1, 5), -1), )) assert(eord.lteqWithOverlap(2)( - IntervalEndpoint(Row(0, 1, 5), -1), IntervalEndpoint(Row(0, 1, 5), -1) + IntervalEndpoint(Row(0, 1, 5), -1), + IntervalEndpoint(Row(0, 1, 5), -1), )) assert(!eord.lteqWithOverlap(1)( - IntervalEndpoint(Row(0, 1, 5), -1), IntervalEndpoint(Row(0, 1, 5), -1) + IntervalEndpoint(Row(0, 1, 5), -1), + IntervalEndpoint(Row(0, 1, 5), -1), )) assert(eord.lteqWithOverlap(2)( - IntervalEndpoint(Row(0, 1, 2), -1), IntervalEndpoint(Row(0, 1, 5), -1) + IntervalEndpoint(Row(0, 1, 2), -1), + IntervalEndpoint(Row(0, 1, 5), -1), )) assert(!eord.lteqWithOverlap(1)( - IntervalEndpoint(Row(0, 1, 2), -1), IntervalEndpoint(Row(0, 1, 5), -1) + IntervalEndpoint(Row(0, 1, 2), -1), + IntervalEndpoint(Row(0, 1, 5), -1), )) assert(eord.lteqWithOverlap(1)( - IntervalEndpoint(Row(0, 1), -1), IntervalEndpoint(Row(0, 1), -1) + IntervalEndpoint(Row(0, 1), -1), + IntervalEndpoint(Row(0, 1), -1), )) assert(!eord.lteqWithOverlap(0)( - IntervalEndpoint(Row(0, 1), -1), IntervalEndpoint(Row(0, 1), -1) + IntervalEndpoint(Row(0, 1), -1), + IntervalEndpoint(Row(0, 1), -1), )) assert(eord.lteqWithOverlap(1)( - IntervalEndpoint(Row(0, 1, 5), -1), IntervalEndpoint(Row(0, 2), -1) + IntervalEndpoint(Row(0, 1, 5), -1), + IntervalEndpoint(Row(0, 2), -1), )) assert(!eord.lteqWithOverlap(0)( - IntervalEndpoint(Row(0, 1, 5), -1), IntervalEndpoint(Row(0, 2), -1) + IntervalEndpoint(Row(0, 1, 5), -1), + IntervalEndpoint(Row(0, 2), -1), )) assert(eord.lteqWithOverlap(0)( - IntervalEndpoint(Row(0), -1), IntervalEndpoint(Row(0), -1) + IntervalEndpoint(Row(0), -1), + IntervalEndpoint(Row(0), -1), )) assert(eord.lteqWithOverlap(0)( - IntervalEndpoint(Row(0), -1), IntervalEndpoint(Row(0, 1, 2), 1) + IntervalEndpoint(Row(0), -1), + IntervalEndpoint(Row(0, 1, 2), 1), )) assert(eord.lteqWithOverlap(0)( - IntervalEndpoint(Row(0, 3), -1), IntervalEndpoint(Row(1, 2), -1) + IntervalEndpoint(Row(0, 3), -1), + IntervalEndpoint(Row(1, 2), -1), )) assert(!eord.lteqWithOverlap(-1)( - IntervalEndpoint(Row(0, 3), -1), IntervalEndpoint(Row(1, 2), -1) + IntervalEndpoint(Row(0, 3), -1), + IntervalEndpoint(Row(1, 2), -1), )) assert(!eord.lteqWithOverlap(-1)( - IntervalEndpoint(Row(), 1), IntervalEndpoint(Row(), -1) + IntervalEndpoint(Row(), 1), + IntervalEndpoint(Row(), -1), )) } - @Test def testIsValid() { + @Test def testIsValid(): Unit = { assert(Interval.isValid(pord, Row(0, 1, 5), Row(0, 2), false, false)) assert(!Interval.isValid(pord, Row(0, 1, 5), Row(0, 0), false, false)) assert(Interval.isValid(pord, Row(0, 1, 5), Row(0, 1), false, true)) diff --git a/hail/src/test/scala/is/hail/utils/SemanticVersionSuite.scala b/hail/src/test/scala/is/hail/utils/SemanticVersionSuite.scala index 4577d611fea..fc55c7a8d06 100644 --- a/hail/src/test/scala/is/hail/utils/SemanticVersionSuite.scala +++ b/hail/src/test/scala/is/hail/utils/SemanticVersionSuite.scala @@ -4,7 +4,7 @@ import org.scalatest.testng.TestNGSuite import org.testng.annotations.Test class SemanticVersionSuite extends TestNGSuite { - @Test def testOrdering() { + @Test def testOrdering(): Unit = { val versions = Array( SemanticVersion(1, 1, 0), SemanticVersion(1, 1, 1), @@ -13,16 +13,13 @@ class SemanticVersionSuite 
extends TestNGSuite { SemanticVersion(1, 3, 0), SemanticVersion(1, 3, 1), SemanticVersion(2, 0, 0), - SemanticVersion(2, 0, 1)) + SemanticVersion(2, 0, 1), + ) versions.zipWithIndex.foreach { case (v, i) => - (0 until i).foreach { j => - assert(v > versions(j)) - } + (0 until i).foreach(j => assert(v > versions(j))) - (i + 1 until versions.length).foreach { j => - assert(v < versions(j)) - } + (i + 1 until versions.length).foreach(j => assert(v < versions(j))) } } } diff --git a/hail/src/test/scala/is/hail/utils/SpillingCollectIteratorSuite.scala b/hail/src/test/scala/is/hail/utils/SpillingCollectIteratorSuite.scala index 2e202916c23..5159ad52291 100644 --- a/hail/src/test/scala/is/hail/utils/SpillingCollectIteratorSuite.scala +++ b/hail/src/test/scala/is/hail/utils/SpillingCollectIteratorSuite.scala @@ -1,10 +1,11 @@ package is.hail.utils import is.hail.HailSuite + import org.testng.annotations.Test class SpillingCollectIteratorSuite extends HailSuite { - @Test def addOneElement() { + @Test def addOneElement(): Unit = { val array = (0 to 1234).toArray val sci = SpillingCollectIterator(ctx.localTmpdir, fs, sc.parallelize(array, 99), 100) assert(sci.hasNext) diff --git a/hail/src/test/scala/is/hail/utils/TreeTraversalSuite.scala b/hail/src/test/scala/is/hail/utils/TreeTraversalSuite.scala index 3393f4d4653..0106d66c23a 100644 --- a/hail/src/test/scala/is/hail/utils/TreeTraversalSuite.scala +++ b/hail/src/test/scala/is/hail/utils/TreeTraversalSuite.scala @@ -2,30 +2,31 @@ package is.hail.utils import org.testng.Assert import org.testng.annotations.Test + class TreeTraversalSuite { def binaryTree(i: Int): Iterator[Int] = (1 to 2).map(2 * i + _).iterator.filter(_ < 7) - @Test def testPostOrder = + @Test def testPostOrder() = Assert.assertEquals( TreeTraversal.postOrder(binaryTree)(0).toArray, Array(3, 4, 1, 5, 6, 2, 0), - "" + "", ) - @Test def testPreOrder = + @Test def testPreOrder() = Assert.assertEquals( TreeTraversal.preOrder(binaryTree)(0).toArray, Array(0, 1, 3, 4, 2, 5, 6), - "" + "", ) - @Test def levelOrder = + @Test def levelOrder() = Assert.assertEquals( TreeTraversal.levelOrder(binaryTree)(0).toArray, (0 to 6).toArray, - "" + "", ) } diff --git a/hail/src/test/scala/is/hail/utils/UnionFindSuite.scala b/hail/src/test/scala/is/hail/utils/UnionFindSuite.scala index 0cfa32ed79b..8ae7f5d532c 100644 --- a/hail/src/test/scala/is/hail/utils/UnionFindSuite.scala +++ b/hail/src/test/scala/is/hail/utils/UnionFindSuite.scala @@ -5,12 +5,11 @@ import org.testng.annotations.Test class UnionFindSuite extends TestNGSuite { @Test - def emptyUnionFindHasNoSets() { + def emptyUnionFindHasNoSets(): Unit = assert(new UnionFind().size == 0) - } @Test - def growingPastInitialCapacityOK() { + def growingPastInitialCapacityOK(): Unit = { val uf = new UnionFind(4) uf.makeSet(0) uf.makeSet(1) @@ -26,7 +25,7 @@ class UnionFindSuite extends TestNGSuite { } @Test - def simpleUnions() { + def simpleUnions(): Unit = { val uf = new UnionFind() uf.makeSet(0) @@ -40,7 +39,7 @@ class UnionFindSuite extends TestNGSuite { } @Test - def nonMonotonicMakeSet() { + def nonMonotonicMakeSet(): Unit = { val uf = new UnionFind() uf.makeSet(1000) @@ -62,7 +61,7 @@ class UnionFindSuite extends TestNGSuite { } @Test - def multipleUnions() { + def multipleUnions(): Unit = { val uf = new UnionFind() uf.makeSet(1) @@ -96,7 +95,7 @@ class UnionFindSuite extends TestNGSuite { } @Test - def unionsNoInterveningFinds() { + def unionsNoInterveningFinds(): Unit = { val uf = new UnionFind() uf.makeSet(1) @@ -122,7 +121,7 @@ class 
UnionFindSuite extends TestNGSuite { } @Test - def sameSetWorks() { + def sameSetWorks(): Unit = { val uf = new UnionFind() uf.makeSet(1) diff --git a/hail/src/test/scala/is/hail/utils/UtilsSuite.scala b/hail/src/test/scala/is/hail/utils/UtilsSuite.scala index b7eab628d15..24a5423ed58 100644 --- a/hail/src/test/scala/is/hail/utils/UtilsSuite.scala +++ b/hail/src/test/scala/is/hail/utils/UtilsSuite.scala @@ -3,35 +3,35 @@ package is.hail.utils import is.hail.HailSuite import is.hail.check.{Gen, Prop} import is.hail.io.fs.HadoopFS -import org.apache.spark.storage.StorageLevel -import org.testng.annotations.Test -import org.apache.hadoop -import org.sparkproject.guava.util.concurrent.MoreExecutors import scala.collection.mutable.ArrayBuffer +import org.apache.spark.storage.StorageLevel +import org.sparkproject.guava.util.concurrent.MoreExecutors +import org.testng.annotations.Test + class UtilsSuite extends HailSuite { - @Test def testD_==() { + @Test def testD_==(): Unit = { assert(D_==(1, 1)) - assert(D_==(1, 1 + 1E-7)) - assert(!D_==(1, 1 + 1E-5)) - assert(D_==(1E10, 1E10 + 1)) - assert(!D_==(1E-10, 2E-10)) + assert(D_==(1, 1 + 1e-7)) + assert(!D_==(1, 1 + 1e-5)) + assert(D_==(1e10, 1e10 + 1)) + assert(!D_==(1e-10, 2e-10)) assert(D_==(0.0, 0.0)) - assert(D_!=(1E-307, 0.0)) - assert(D_==(1E-308, 0.0)) - assert(D_==(1E-320, -1E-320)) + assert(D_!=(1e-307, 0.0)) + assert(D_==(1e-308, 0.0)) + assert(D_==(1e-320, -1e-320)) } - @Test def testFlushDouble() { - assert(flushDouble(8.0E-308) == 8.0E-308) - assert(flushDouble(-8.0E-308) == -8.0E-308) - assert(flushDouble(8.0E-309) == 0.0) - assert(flushDouble(-8.0E-309) == 0.0) + @Test def testFlushDouble(): Unit = { + assert(flushDouble(8.0e-308) == 8.0e-308) + assert(flushDouble(-8.0e-308) == -8.0e-308) + assert(flushDouble(8.0e-309) == 0.0) + assert(flushDouble(-8.0e-309) == 0.0) assert(flushDouble(0.0) == 0.0) } - @Test def testAreDistinct() { + @Test def testAreDistinct(): Unit = { assert(Array().areDistinct()) assert(Array(1).areDistinct()) assert(Array(1, 2).areDistinct()) @@ -39,7 +39,7 @@ class UtilsSuite extends HailSuite { assert(!Array(1, 2, 1).areDistinct()) } - @Test def testIsIncreasing() { + @Test def testIsIncreasing(): Unit = { assert(Seq[Int]().isIncreasing) assert(Seq(1).isIncreasing) assert(Seq(1, 2).isIncreasing) @@ -49,7 +49,7 @@ class UtilsSuite extends HailSuite { assert(Array(1, 2).isIncreasing) } - @Test def testIsSorted() { + @Test def testIsSorted(): Unit = { assert(Seq[Int]().isSorted) assert(Seq(1).isSorted) assert(Seq(1, 2).isSorted) @@ -59,24 +59,25 @@ class UtilsSuite extends HailSuite { assert(Array(1, 1).isSorted) } - @Test def testHadoopStripCodec() { + @Test def testHadoopStripCodec(): Unit = { assert(fs.stripCodecExtension("file.tsv") == "file.tsv") assert(fs.stripCodecExtension("file.tsv.gz") == "file.tsv") assert(fs.stripCodecExtension("file.tsv.bgz") == "file.tsv") assert(fs.stripCodecExtension("file") == "file") } - @Test def testPairRDDNoDup() { - val answer1 = Array((1, (1, Option(1))), (2, (4, Option(2))), (3, (9, Option(3))), (4, (16, Option(4)))) - val pairRDD1 = sc.parallelize(Array(1, 2, 3, 4)).map { i => (i, i * i) } - val pairRDD2 = sc.parallelize(Array(1, 2, 3, 4, 1, 2, 3, 4)).map { i => (i, i) } + @Test def testPairRDDNoDup(): Unit = { + val answer1 = + Array((1, (1, Option(1))), (2, (4, Option(2))), (3, (9, Option(3))), (4, (16, Option(4)))) + val pairRDD1 = sc.parallelize(Array(1, 2, 3, 4)).map(i => (i, i * i)) + val pairRDD2 = sc.parallelize(Array(1, 2, 3, 4, 1, 2, 3, 4)).map(i => (i, i)) val 
join = pairRDD1.leftOuterJoin(pairRDD2.distinct) assert(join.collect().sortBy(t => t._1) sameElements answer1) assert(join.count() == 4) } - @Test def testForallExists() { + @Test def testForallExists(): Unit = { val rdd1 = sc.parallelize(Array(1, 2, 3, 4, 5)) assert(rdd1.forall(_ > 0)) @@ -86,28 +87,28 @@ class UtilsSuite extends HailSuite { assert(!rdd1.exists(_ < 0)) } - @Test def testSortFileListEntry() { + @Test def testSortFileListEntry(): Unit = { val fs = new HadoopFS(new SerializableHadoopConfiguration(sc.hadoopConfiguration)) val partFileNames = fs.glob("src/test/resources/part-*") - .map { fileListEntry => - (fileListEntry, new hadoop.fs.Path(fileListEntry.getPath)) - }.sortBy { case (fileListEntry, path) => - getPartNumber(path.getName) - }.map(_._2.getName) + .sortBy(fileListEntry => getPartNumber(fileListEntry.getPath)).map(_.getPath.split( + "/" + ).last) assert(partFileNames(0) == "part-40001" && partFileNames(1) == "part-100001") } @Test def storageLevelStringTest() = { val sls = List( - "NONE", "DISK_ONLY", "DISK_ONLY_2", "MEMORY_ONLY", "MEMORY_ONLY_2", "MEMORY_ONLY_SER", "MEMORY_ONLY_SER_2", - "MEMORY_AND_DISK", "MEMORY_AND_DISK_2", "MEMORY_AND_DISK_SER", "MEMORY_AND_DISK_SER_2", "OFF_HEAP") + "NONE", "DISK_ONLY", "DISK_ONLY_2", "MEMORY_ONLY", "MEMORY_ONLY_2", "MEMORY_ONLY_SER", + "MEMORY_ONLY_SER_2", + "MEMORY_AND_DISK", "MEMORY_AND_DISK_2", "MEMORY_AND_DISK_SER", "MEMORY_AND_DISK_SER_2", + "OFF_HEAP") - sls.foreach { sl => StorageLevel.fromString(sl) } + sls.foreach(sl => StorageLevel.fromString(sl)) } - @Test def testDictionaryOrdering() { + @Test def testDictionaryOrdering(): Unit = { val stringList = Seq("Cats", "Crayon", "Dog") val longestToShortestLength = Ordering.by[String, Int](-_.length) @@ -120,14 +121,14 @@ class UtilsSuite extends HailSuite { assert(stringList.sorted(ord2) == Seq("Crayon", "Cats", "Dog")) } - @Test def testCollectAsSet() { - Prop.forAll(Gen.buildableOf[Array](Gen.choose(-1000, 1000)), Gen.choose(1, 10)) { case (values, parts) => - val rdd = sc.parallelize(values, numSlices = parts) - rdd.collectAsSet() == rdd.collect().toSet + @Test def testCollectAsSet(): Unit = + Prop.forAll(Gen.buildableOf[Array](Gen.choose(-1000, 1000)), Gen.choose(1, 10)) { + case (values, parts) => + val rdd = sc.parallelize(values, numSlices = parts) + rdd.collectAsSet() == rdd.collect().toSet }.check() - } - @Test def testDigitsNeeded() { + @Test def testDigitsNeeded(): Unit = { assert(digitsNeeded(0) == 1) assert(digitsNeeded(1) == 1) assert(digitsNeeded(7) == 1) @@ -136,38 +137,33 @@ class UtilsSuite extends HailSuite { assert(digitsNeeded(30173) == 5) } - @Test def testMangle() { + @Test def testMangle(): Unit = { val c1 = Array("a", "b", "c", "a", "a", "c", "a") val (c2, diff) = mangle(c1) assert(c2.toSeq == Seq("a", "b", "c", "a_1", "a_2", "c_1", "a_3")) assert(diff.toSeq == Seq("a" -> "a_1", "a" -> "a_2", "c" -> "c_1", "a" -> "a_3")) - val c3 = Array("a", "b", "c", "a", "a", "c", "a") val (c4, diff2) = mangle(c1, "D" * _) assert(c4.toSeq == Seq("a", "b", "c", "aD", "aDD", "cD", "aDDD")) assert(diff2.toSeq == Seq("a" -> "aD", "a" -> "aDD", "c" -> "cD", "a" -> "aDDD")) } - @Test def toMapUniqueEmpty() { + @Test def toMapUniqueEmpty(): Unit = assert(toMapIfUnique(Seq[(Int, Int)]())(x => x % 2) == Right(Map())) - } - @Test def toMapUniqueSingleton() { + @Test def toMapUniqueSingleton(): Unit = assert(toMapIfUnique(Seq(1 -> 2))(x => x % 2) == Right(Map(1 -> 2))) - } - @Test def toMapUniqueSmallNoDupe() { + @Test def toMapUniqueSmallNoDupe(): Unit = 
assert(toMapIfUnique(Seq(1 -> 2, 3 -> 6, 10 -> 2))(x => x % 5) == Right(Map(1 -> 2, 3 -> 6, 0 -> 2))) - } - @Test def toMapUniqueSmallDupes() { + @Test def toMapUniqueSmallDupes(): Unit = assert(toMapIfUnique(Seq(1 -> 2, 6 -> 6, 10 -> 2))(x => x % 5) == Left(Map(1 -> Seq(1, 6)))) - } @Test def testItemPartition(): Unit = { - def test(n: Int, k: Int) { + def test(n: Int, k: Int): Unit = { val a = new Array[Int](k) var prevj = 0 for (i <- 0 until n) { @@ -192,7 +188,7 @@ class UtilsSuite extends HailSuite { test(12, 5) } - @Test def testTreeAggDepth() { + @Test def testTreeAggDepth(): Unit = { assert(treeAggDepth(20, 20) == 1) assert(treeAggDepth(20, 19) == 2) assert(treeAggDepth(399, 20) == 2) @@ -207,13 +203,10 @@ class UtilsSuite extends HailSuite { val (failures, successes) = runAll[F, Int]( MoreExecutors.sameThreadExecutor() - )( - { case (acc, (_, index)) => acc :+ index } - )( + ) { case (acc, (_, index)) => acc :+ index }( new ArrayBuffer[Int](2) )( - for {k <- 0 until 4} - yield (() => if (k % 2 == 0) k else throw new Exception(), k) + for { k <- 0 until 4 } yield (() => if (k % 2 == 0) k else throw new Exception(), k) ) assert(failures == Seq(1, 3)) diff --git a/hail/src/test/scala/is/hail/utils/prettyPrint/PrettyPrintWriterSuite.scala b/hail/src/test/scala/is/hail/utils/prettyPrint/PrettyPrintWriterSuite.scala index 9d5a24fa4ba..3a1b792a40e 100644 --- a/hail/src/test/scala/is/hail/utils/prettyPrint/PrettyPrintWriterSuite.scala +++ b/hail/src/test/scala/is/hail/utils/prettyPrint/PrettyPrintWriterSuite.scala @@ -3,77 +3,131 @@ package is.hail.utils.prettyPrint import is.hail.utils.toRichIterator import scala.collection.JavaConverters._ + import org.scalatest.testng.TestNGSuite import org.testng.annotations.{DataProvider, Test} class PrettyPrintWriterSuite extends TestNGSuite { def data: Array[(Doc, Array[(Int, Int, Int, String)])] = Array( - ( nest(2, hsep("prefix", sep("text", "to", "lay", "out"))), + ( + nest(2, hsep("prefix", sep("text", "to", "lay", "out"))), Array( - (25, 25, 5, + ( + 25, + 25, + 5, """========================= |prefix text to lay out - |=========================""".stripMargin), - (20, 20, 5, + |=========================""".stripMargin, + ), + ( + 20, + 20, + 5, """==================== |prefix text | to | lay | out - |====================""".stripMargin), - (20, 20, 2, + |====================""".stripMargin, + ), + ( + 20, + 20, + 2, """==================== |prefix text | to - |====================""".stripMargin))), - ( nest(2, hsep("prefix", fillSep("text", "to", "lay", "out"))), + |====================""".stripMargin, + ), + ), + ), + ( + nest(2, hsep("prefix", fillSep("text", "to", "lay", "out"))), Array( - (25, 25, 5, + ( + 25, + 25, + 5, """========================= |prefix text to lay out - |=========================""".stripMargin), - (20, 20, 5, + |=========================""".stripMargin, + ), + ( + 20, + 20, + 5, """==================== |prefix text to lay | out - |====================""".stripMargin), - (15, 15, 5, + |====================""".stripMargin, + ), + ( + 15, + 15, + 5, """=============== |prefix text to | lay out - |===============""".stripMargin), - (10, 10, 5, + |===============""".stripMargin, + ), + ( + 10, + 10, + 5, """========== |prefix text | to lay | out - |==========""".stripMargin), - (10, 10, 2, + |==========""".stripMargin, + ), + ( + 10, + 10, + 2, """========== |prefix text | to lay - |==========""".stripMargin) - )), - ( nest(2, concat("prefix", list("A", "B", "C", "D"))), + |==========""".stripMargin, + ), + ), + ), + ( 
+ nest(2, concat("prefix", list("A", "B", "C", "D"))), Array( - (15, 15, 5, + ( + 15, + 15, + 5, """=============== |prefix(A B C D) - |===============""".stripMargin), - (10, 10, 5, + |===============""".stripMargin, + ), + ( + 10, + 10, + 5, """========== |prefix(A | B | C | D) - |==========""".stripMargin), - (10, 10, 3, + |==========""".stripMargin, + ), + ( + 10, + 10, + 3, """========== |prefix(A | B | C - |==========""".stripMargin)))) + |==========""".stripMargin, + ), + ), + ), + ) @DataProvider(name = "data") def flatData: java.util.Iterator[Array[Object]] = @@ -83,12 +137,13 @@ class PrettyPrintWriterSuite extends TestNGSuite { } yield Array(doc, Int.box(width), Int.box(ribbonWidth), Int.box(maxLines), expected)).asJava @Test(dataProvider = "data") - def testPP(doc: Doc, width: Integer, ribbonWidth: Integer, maxLines: Integer, expected: String): Unit = { + def testPP(doc: Doc, width: Integer, ribbonWidth: Integer, maxLines: Integer, expected: String) + : Unit = { val ruler = "=" * width - assert(expected == s"$ruler\n${ doc.render(width, ribbonWidth, maxLines) }\n$ruler") + assert(expected == s"$ruler\n${doc.render(width, ribbonWidth, maxLines)}\n$ruler") } - @Test def testIntersperse() { + @Test def testIntersperse(): Unit = { val it = Array("A", "B", "C").iterator.intersperse("(", ",", ")") assert(it.mkString == "(A,B,C)") } diff --git a/hail/src/test/scala/is/hail/variant/GenotypeSuite.scala b/hail/src/test/scala/is/hail/variant/GenotypeSuite.scala index 2ae06fadbc5..d18b4fc49c9 100644 --- a/hail/src/test/scala/is/hail/variant/GenotypeSuite.scala +++ b/hail/src/test/scala/is/hail/variant/GenotypeSuite.scala @@ -1,10 +1,11 @@ package is.hail.variant import is.hail.TestUtils -import is.hail.check.Prop._ import is.hail.check.Gen +import is.hail.check.Prop._ import is.hail.testUtils.Variant import is.hail.utils._ + import org.scalatest.testng.TestNGSuite import org.testng.annotations.Test @@ -12,32 +13,31 @@ class GenotypeSuite extends TestNGSuite { val v = Variant("1", 1, "A", "T") - @Test def gtPairGtIndexIsId() { + @Test def gtPairGtIndexIsId(): Unit = forAll(Gen.choose(0, 32768), Gen.choose(0, 32768)) { (x, y) => val (j, k) = if (x < y) (x, y) else (y, x) val gt = AllelePair(j, k) Genotype.allelePair(Genotype.diploidGtIndex(gt)) == gt }.check() - } def triangleNumberOf(i: Int) = (i * i + i) / 2 - @Test def gtIndexGtPairIsId() { + @Test def gtIndexGtPairIsId(): Unit = forAll(Gen.choose(0, 10000)) { (idx) => Genotype.diploidGtIndex(Genotype.allelePair(idx)) == idx }.check() - } - @Test def gtPairAndGtPairSqrtEqual() { + @Test def gtPairAndGtPairSqrtEqual(): Unit = forAll(Gen.choose(0, 10000)) { (idx) => Genotype.allelePair(idx) == Genotype.allelePairSqrt(idx) }.check() - } - @Test def testGtFromLinear() { - val gen = for (nGenotype <- Gen.choose(2, 5).map(triangleNumberOf); - dosageGen = Gen.partition(nGenotype, 32768); - result <- dosageGen) yield result + @Test def testGtFromLinear(): Unit = { + val gen = for { + nGenotype <- Gen.choose(2, 5).map(triangleNumberOf) + dosageGen = Gen.partition(nGenotype, 32768) + result <- dosageGen + } yield result val p = forAll(gen) { gp => val gt = Option(uniqueMaxIndex(gp)) @@ -56,7 +56,7 @@ class GenotypeSuite extends TestNGSuite { p.check() } - @Test def testPlToDosage() { + @Test def testPlToDosage(): Unit = { val gt0 = Genotype.plToDosage(0, 20, 100) val gt1 = Genotype.plToDosage(20, 0, 100) val gt2 = Genotype.plToDosage(20, 100, 0) @@ -66,15 +66,15 @@ class GenotypeSuite extends TestNGSuite { assert(D_==(gt2, 1.980198019704931)) } - 
@Test def testCall() { + @Test def testCall(): Unit = { assert((0 until 9).forall { gt => val c = Call2.fromUnphasedDiploidGtIndex(gt) !Call.isPhased(c) && - Call.ploidy(c) == 2 && - Call.isDiploid(c) && - Call.isUnphasedDiploid(c) && - Call.unphasedDiploidGtIndex(c) == gt && - Call.alleleRepr(c) == gt + Call.ploidy(c) == 2 && + Call.isDiploid(c) && + Call.isUnphasedDiploid(c) && + Call.unphasedDiploidGtIndex(c) == gt && + Call.alleleRepr(c) == gt }) val c0 = Call2(0, 0, phased = true) @@ -88,15 +88,15 @@ class GenotypeSuite extends TestNGSuite { assert(x.forall { case (c, unphasedGt, alleleRepr) => val alleles = Call.alleles(c) c != Call2.fromUnphasedDiploidGtIndex(unphasedGt) && - Call.isPhased(c) && - Call.ploidy(c) == 2 + Call.isPhased(c) && + Call.ploidy(c) == 2 Call.isDiploid(c) && - !Call.isUnphasedDiploid(c) && - Call.unphasedDiploidGtIndex(Call2(alleles(0), alleles(1))) == unphasedGt && - Call.alleleRepr(c) == alleleRepr + !Call.isUnphasedDiploid(c) && + Call.unphasedDiploidGtIndex(Call2(alleles(0), alleles(1))) == unphasedGt && + Call.alleleRepr(c) == alleleRepr }) - - assert(Call.isHomRef(c0) && !Call.isHet(c0) && !Call.isHomVar(c0) && + + assert(Call.isHomRef(c0) && !Call.isHet(c0) && !Call.isHomVar(c0) && !Call.isHetNonRef(c0) && !Call.isHetRef(c0) && !Call.isNonRef(c0)) assert(!Call.isHomRef(c1a) && Call.isHet(c1a) && !Call.isHomVar(c1a) && diff --git a/hail/src/test/scala/is/hail/variant/LocusIntervalSuite.scala b/hail/src/test/scala/is/hail/variant/LocusIntervalSuite.scala index 18beedfbcc9..56027cbb1ac 100644 --- a/hail/src/test/scala/is/hail/variant/LocusIntervalSuite.scala +++ b/hail/src/test/scala/is/hail/variant/LocusIntervalSuite.scala @@ -2,54 +2,211 @@ package is.hail.variant import is.hail.{HailSuite, TestUtils} import is.hail.utils._ + import org.testng.annotations.Test class LocusIntervalSuite extends HailSuite { def rg = ctx.getReference(ReferenceGenome.GRCh37) - @Test def testParser() { + @Test def testParser(): Unit = { val xMax = rg.contigLength("X") val yMax = rg.contigLength("Y") val chr22Max = rg.contigLength("22") - assert(Locus.parseInterval("[1:100-1:101)", rg) == Interval(Locus("1", 100), Locus("1", 101), true, false)) - assert(Locus.parseInterval("[1:100-101)", rg) == Interval(Locus("1", 100), Locus("1", 101), true, false)) - assert(Locus.parseInterval("[X:100-101)", rg) == Interval(Locus("X", 100), Locus("X", 101), true, false)) - assert(Locus.parseInterval("[X:100-end)", rg) == Interval(Locus("X", 100), Locus("X", xMax), true, false)) - assert(Locus.parseInterval("[X:100-End)", rg) == Interval(Locus("X", 100), Locus("X", xMax), true, false)) - assert(Locus.parseInterval("[X:100-END)", rg) == Interval(Locus("X", 100), Locus("X", xMax), true, false)) - assert(Locus.parseInterval("[X:start-101)", rg) == Interval(Locus("X", 1), Locus("X", 101), true, false)) - assert(Locus.parseInterval("[X:Start-101)", rg) == Interval(Locus("X", 1), Locus("X", 101), true, false)) - assert(Locus.parseInterval("[X:START-101)", rg) == Interval(Locus("X", 1), Locus("X", 101), true, false)) - assert(Locus.parseInterval("[X:START-Y:END)", rg) == Interval(Locus("X", 1), Locus("Y", yMax), true, false)) - assert(Locus.parseInterval("[X-Y)", rg) == Interval(Locus("X", 1), Locus("Y", yMax), true, false)) - assert(Locus.parseInterval("[1-22)", rg) == Interval(Locus("1", 1), Locus("22", chr22Max), true, false)) - - assert(Locus.parseInterval("1:100-1:101", rg) == Interval(Locus("1", 100), Locus("1", 101), true, false)) - assert(Locus.parseInterval("1:100-101", rg) == 
Interval(Locus("1", 100), Locus("1", 101), true, false)) - assert(Locus.parseInterval("X:100-end", rg) == Interval(Locus("X", 100), Locus("X", xMax), true, true)) - assert(Locus.parseInterval("(X:100-End]", rg) == Interval(Locus("X", 100), Locus("X", xMax), false, true)) - assert(Locus.parseInterval("(X:100-END)", rg) == Interval(Locus("X", 100), Locus("X", xMax), false, false)) - assert(Locus.parseInterval("[X:start-101)", rg) == Interval(Locus("X", 1), Locus("X", 101), true, false)) - assert(Locus.parseInterval("(X:Start-101]", rg) == Interval(Locus("X", 1), Locus("X", 101), false, true)) - assert(Locus.parseInterval("X:START-101", rg) == Interval(Locus("X", 1), Locus("X", 101), true, false)) - assert(Locus.parseInterval("X:START-Y:END", rg) == Interval(Locus("X", 1), Locus("Y", yMax), true, true)) + assert(Locus.parseInterval("[1:100-1:101)", rg) == Interval( + Locus("1", 100), + Locus("1", 101), + true, + false, + )) + assert(Locus.parseInterval("[1:100-101)", rg) == Interval( + Locus("1", 100), + Locus("1", 101), + true, + false, + )) + assert(Locus.parseInterval("[X:100-101)", rg) == Interval( + Locus("X", 100), + Locus("X", 101), + true, + false, + )) + assert(Locus.parseInterval("[X:100-end)", rg) == Interval( + Locus("X", 100), + Locus("X", xMax), + true, + false, + )) + assert(Locus.parseInterval("[X:100-End)", rg) == Interval( + Locus("X", 100), + Locus("X", xMax), + true, + false, + )) + assert(Locus.parseInterval("[X:100-END)", rg) == Interval( + Locus("X", 100), + Locus("X", xMax), + true, + false, + )) + assert(Locus.parseInterval("[X:start-101)", rg) == Interval( + Locus("X", 1), + Locus("X", 101), + true, + false, + )) + assert(Locus.parseInterval("[X:Start-101)", rg) == Interval( + Locus("X", 1), + Locus("X", 101), + true, + false, + )) + assert(Locus.parseInterval("[X:START-101)", rg) == Interval( + Locus("X", 1), + Locus("X", 101), + true, + false, + )) + assert(Locus.parseInterval("[X:START-Y:END)", rg) == Interval( + Locus("X", 1), + Locus("Y", yMax), + true, + false, + )) + assert(Locus.parseInterval("[X-Y)", rg) == Interval( + Locus("X", 1), + Locus("Y", yMax), + true, + false, + )) + assert(Locus.parseInterval("[1-22)", rg) == Interval( + Locus("1", 1), + Locus("22", chr22Max), + true, + false, + )) + + assert(Locus.parseInterval("1:100-1:101", rg) == Interval( + Locus("1", 100), + Locus("1", 101), + true, + false, + )) + assert(Locus.parseInterval("1:100-101", rg) == Interval( + Locus("1", 100), + Locus("1", 101), + true, + false, + )) + assert(Locus.parseInterval("X:100-end", rg) == Interval( + Locus("X", 100), + Locus("X", xMax), + true, + true, + )) + assert(Locus.parseInterval("(X:100-End]", rg) == Interval( + Locus("X", 100), + Locus("X", xMax), + false, + true, + )) + assert(Locus.parseInterval("(X:100-END)", rg) == Interval( + Locus("X", 100), + Locus("X", xMax), + false, + false, + )) + assert(Locus.parseInterval("[X:start-101)", rg) == Interval( + Locus("X", 1), + Locus("X", 101), + true, + false, + )) + assert(Locus.parseInterval("(X:Start-101]", rg) == Interval( + Locus("X", 1), + Locus("X", 101), + false, + true, + )) + assert(Locus.parseInterval("X:START-101", rg) == Interval( + Locus("X", 1), + Locus("X", 101), + true, + false, + )) + assert(Locus.parseInterval("X:START-Y:END", rg) == Interval( + Locus("X", 1), + Locus("Y", yMax), + true, + true, + )) assert(Locus.parseInterval("X-Y", rg) == Interval(Locus("X", 1), Locus("Y", yMax), true, true)) - assert(Locus.parseInterval("1-22", rg) == Interval(Locus("1", 1), Locus("22", chr22Max), true, true)) + 
assert(Locus.parseInterval("1-22", rg) == Interval( + Locus("1", 1), + Locus("22", chr22Max), + true, + true, + )) // test normalizing end points - assert(Locus.parseInterval(s"(X:100-${ xMax + 1 })", rg) == Interval(Locus("X", 100), Locus("X", xMax), false, true)) - assert(Locus.parseInterval(s"(X:0-$xMax]", rg) == Interval(Locus("X", 1), Locus("X", xMax), true, true)) - TestUtils.interceptFatal("Start 'X:0' is not within the range")(Locus.parseInterval("[X:0-5)", rg)) - TestUtils.interceptFatal(s"End 'X:${ xMax + 1 }' is not within the range")(Locus.parseInterval(s"[X:1-${ xMax + 1 }]", rg)) - - assert(Locus.parseInterval("[16:29500000-30200000)", rg) == Interval(Locus("16", 29500000), Locus("16", 30200000), true, false)) - assert(Locus.parseInterval("[16:29.5M-30.2M)", rg) == Interval(Locus("16", 29500000), Locus("16", 30200000), true, false)) - assert(Locus.parseInterval("[16:29500K-30200K)", rg) == Interval(Locus("16", 29500000), Locus("16", 30200000), true, false)) - assert(Locus.parseInterval("[1:100K-2:200K)", rg) == Interval(Locus("1", 100000), Locus("2", 200000), true, false)) - - assert(Locus.parseInterval("[1:1.111K-2000)", rg) == Interval(Locus("1", 1111), Locus("1", 2000), true, false)) - assert(Locus.parseInterval("[1:1.111111M-2000000)", rg) == Interval(Locus("1", 1111111), Locus("1", 2000000), true, false)) + assert(Locus.parseInterval(s"(X:100-${xMax + 1})", rg) == Interval( + Locus("X", 100), + Locus("X", xMax), + false, + true, + )) + assert(Locus.parseInterval(s"(X:0-$xMax]", rg) == Interval( + Locus("X", 1), + Locus("X", xMax), + true, + true, + )) + TestUtils.interceptFatal("Start 'X:0' is not within the range")(Locus.parseInterval( + "[X:0-5)", + rg, + )) + TestUtils.interceptFatal(s"End 'X:${xMax + 1}' is not within the range")(Locus.parseInterval( + s"[X:1-${xMax + 1}]", + rg, + )) + + assert(Locus.parseInterval("[16:29500000-30200000)", rg) == Interval( + Locus("16", 29500000), + Locus("16", 30200000), + true, + false, + )) + assert(Locus.parseInterval("[16:29.5M-30.2M)", rg) == Interval( + Locus("16", 29500000), + Locus("16", 30200000), + true, + false, + )) + assert(Locus.parseInterval("[16:29500K-30200K)", rg) == Interval( + Locus("16", 29500000), + Locus("16", 30200000), + true, + false, + )) + assert(Locus.parseInterval("[1:100K-2:200K)", rg) == Interval( + Locus("1", 100000), + Locus("2", 200000), + true, + false, + )) + + assert(Locus.parseInterval("[1:1.111K-2000)", rg) == Interval( + Locus("1", 1111), + Locus("1", 2000), + true, + false, + )) + assert(Locus.parseInterval("[1:1.111111M-2000000)", rg) == Interval( + Locus("1", 1111111), + Locus("1", 2000000), + true, + false, + )) TestUtils.interceptFatal("invalid interval expression") { Locus.parseInterval("4::start-5:end", rg) diff --git a/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala b/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala index 3571681510c..1f7c361f914 100644 --- a/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala +++ b/hail/src/test/scala/is/hail/variant/ReferenceGenomeSuite.scala @@ -1,14 +1,14 @@ package is.hail.variant -import is.hail.annotations.Region +import is.hail.{HailSuite, TestUtils} import is.hail.backend.HailStateManager import is.hail.check.Prop._ import is.hail.check.Properties import is.hail.expr.ir.EmitFunctionBuilder -import is.hail.types.virtual.TLocus import is.hail.io.reference.{FASTAReader, FASTAReaderConfig} +import is.hail.types.virtual.TLocus import is.hail.utils._ -import is.hail.{HailSuite, TestUtils} + import 
htsjdk.samtools.reference.ReferenceSequenceFileFactory import org.testng.annotations.Test @@ -18,7 +18,7 @@ class ReferenceGenomeSuite extends HailSuite { def getReference(name: String) = ctx.getReference(name) - @Test def testGRCh37() { + @Test def testGRCh37(): Unit = { assert(hasReference(ReferenceGenome.GRCh37)) val grch37 = getReference(ReferenceGenome.GRCh37) @@ -34,7 +34,7 @@ class ReferenceGenomeSuite extends HailSuite { assert(!nonParXLocus.forall(grch37.inXPar) && !nonParYLocus.forall(grch37.inYPar)) } - @Test def testGRCh38() { + @Test def testGRCh38(): Unit = { assert(hasReference(ReferenceGenome.GRCh38)) val grch38 = getReference(ReferenceGenome.GRCh38) @@ -50,23 +50,65 @@ class ReferenceGenomeSuite extends HailSuite { assert(!nonParXLocus38.forall(grch38.inXPar) && !nonParYLocus38.forall(grch38.inYPar)) } - @Test def testAssertions() { - TestUtils.interceptFatal("Must have at least one contig in the reference genome.")(ReferenceGenome("test", Array.empty[String], Map.empty[String, Int])) - TestUtils.interceptFatal("No lengths given for the following contigs:")(ReferenceGenome("test", Array("1", "2", "3"), Map("1" -> 5))) - TestUtils.interceptFatal("Contigs found in 'lengths' that are not present in 'contigs'")(ReferenceGenome("test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5, "4" -> 100))) - TestUtils.interceptFatal("The following X contig names are absent from the reference:")(ReferenceGenome("test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5), xContigs = Set("X"))) - TestUtils.interceptFatal("The following Y contig names are absent from the reference:")(ReferenceGenome("test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5), yContigs = Set("Y"))) - TestUtils.interceptFatal("The following mitochondrial contig names are absent from the reference:")(ReferenceGenome("test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5), mtContigs = Set("MT"))) - TestUtils.interceptFatal("The contig name for PAR interval")(ReferenceGenome("test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5), parInput = Array((Locus("X", 1), Locus("X", 5))))) - TestUtils.interceptFatal("in both X and Y contigs.")(ReferenceGenome("test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5), xContigs = Set("1"), yContigs = Set("1"))) + @Test def testAssertions(): Unit = { + TestUtils.interceptFatal("Must have at least one contig in the reference genome.")( + ReferenceGenome("test", Array.empty[String], Map.empty[String, Int]) + ) + TestUtils.interceptFatal("No lengths given for the following contigs:")(ReferenceGenome( + "test", + Array("1", "2", "3"), + Map("1" -> 5), + )) + TestUtils.interceptFatal("Contigs found in 'lengths' that are not present in 'contigs'")( + ReferenceGenome("test", Array("1", "2", "3"), Map("1" -> 5, "2" -> 5, "3" -> 5, "4" -> 100)) + ) + TestUtils.interceptFatal("The following X contig names are absent from the reference:")( + ReferenceGenome( + "test", + Array("1", "2", "3"), + Map("1" -> 5, "2" -> 5, "3" -> 5), + xContigs = Set("X"), + ) + ) + TestUtils.interceptFatal("The following Y contig names are absent from the reference:")( + ReferenceGenome( + "test", + Array("1", "2", "3"), + Map("1" -> 5, "2" -> 5, "3" -> 5), + yContigs = Set("Y"), + ) + ) + TestUtils.interceptFatal( + "The following mitochondrial contig names are absent from the reference:" + )(ReferenceGenome( + "test", + Array("1", "2", "3"), + Map("1" -> 5, "2" -> 5, "3" -> 5), + mtContigs = Set("MT"), + )) + TestUtils.interceptFatal("The contig name for PAR 
interval")(ReferenceGenome( + "test", + Array("1", "2", "3"), + Map("1" -> 5, "2" -> 5, "3" -> 5), + parInput = Array((Locus("X", 1), Locus("X", 5))), + )) + TestUtils.interceptFatal("in both X and Y contigs.")(ReferenceGenome( + "test", + Array("1", "2", "3"), + Map("1" -> 5, "2" -> 5, "3" -> 5), + xContigs = Set("1"), + yContigs = Set("1"), + )) } - @Test def testContigRemap() { + @Test def testContigRemap(): Unit = { val mapping = Map("23" -> "foo") - TestUtils.interceptFatal("have remapped contigs in reference genome")(getReference(ReferenceGenome.GRCh37).validateContigRemap(mapping)) + TestUtils.interceptFatal("have remapped contigs in reference genome")( + getReference(ReferenceGenome.GRCh37).validateContigRemap(mapping) + ) } - @Test def testComparisonOps() { + @Test def testComparisonOps(): Unit = { val rg = getReference(ReferenceGenome.GRCh37) // Test contigs @@ -81,14 +123,9 @@ class ReferenceGenomeSuite extends HailSuite { assert(rg.compare("X", "Y") < 0) assert(rg.compare("Y", "X") > 0) assert(rg.compare("Y", "MT") < 0) - - // Test loci - val l1 = Locus("1", 25) - val l2 = Locus("1", 13000) - val l3 = Locus("2", 26) } - @Test def testWriteToFile() { + @Test def testWriteToFile(): Unit = { val tmpFile = ctx.createTmpPath("grWrite", "json") val rg = getReference(ReferenceGenome.GRCh37) @@ -103,7 +140,7 @@ class ReferenceGenomeSuite extends HailSuite { (rg.parInput sameElements gr2.parInput)) } - @Test def testFasta() { + @Test def testFasta(): Unit = { val fastaFile = "src/test/resources/fake_reference.fasta" val fastaFileGzip = "src/test/resources/fake_reference.fasta.gz" val indexFile = "src/test/resources/fake_reference.fasta.fai" @@ -111,11 +148,17 @@ class ReferenceGenomeSuite extends HailSuite { val rg = ReferenceGenome("test", Array("a", "b", "c"), Map("a" -> 25, "b" -> 15, "c" -> 10)) val fr = FASTAReaderConfig(ctx.localTmpdir, ctx.fs, rg, fastaFile, indexFile, 3, 5).reader - val frGzip = FASTAReaderConfig(ctx.localTmpdir, ctx.fs, rg, fastaFileGzip, indexFile, 3, 5).reader + val frGzip = + FASTAReaderConfig(ctx.localTmpdir, ctx.fs, rg, fastaFileGzip, indexFile, 3, 5).reader val refReaderPath = FASTAReader.getLocalFastaFile(ctx.localTmpdir, ctx.fs, fastaFile, indexFile) - val refReaderPathGz = FASTAReader.getLocalFastaFile(ctx.localTmpdir, ctx.fs, fastaFileGzip, indexFile) - val refReader = ReferenceSequenceFileFactory.getReferenceSequenceFile(new java.io.File(uriPath(refReaderPath))) - val refReaderGz = ReferenceSequenceFileFactory.getReferenceSequenceFile(new java.io.File(uriPath(refReaderPathGz))) + val refReaderPathGz = + FASTAReader.getLocalFastaFile(ctx.localTmpdir, ctx.fs, fastaFileGzip, indexFile) + val refReader = ReferenceSequenceFileFactory.getReferenceSequenceFile( + new java.io.File(uriPath(refReaderPath)) + ) + val refReaderGz = ReferenceSequenceFileFactory.getReferenceSequenceFile( + new java.io.File(uriPath(refReaderPathGz)) + ) object Spec extends Properties("Fasta Random") { property("cache gives same base as from file") = forAll(Locus.gen(rg)) { l => @@ -147,7 +190,9 @@ class ReferenceGenomeSuite extends HailSuite { sb.result() } - fr.lookup(Interval(start, end, includesStart = true, includesEnd = true)) == getHtsjdkIntervalSequence + fr.lookup( + Interval(start, end, includesStart = true, includesEnd = true) + ) == getHtsjdkIntervalSequence } } @@ -156,16 +201,30 @@ class ReferenceGenomeSuite extends HailSuite { assert(fr.lookup("a", 25, 0, 5) == "A") assert(fr.lookup("b", 1, 5, 0) == "T") assert(fr.lookup("c", 5, 10, 10) == "GGATCCGTGC") - 
assert(fr.lookup(Interval(Locus("a", 1), Locus("a", 5), includesStart = true, includesEnd = false)) == "AGGT") - assert(fr.lookup(Interval(Locus("a", 20), Locus("b", 5), includesStart = false, includesEnd = false)) == "ACGTATAAT") - assert(fr.lookup(Interval(Locus("a", 20), Locus("c", 5), includesStart = false, includesEnd = false)) == "ACGTATAATTAAATTAGCCAGGAT") + assert(fr.lookup(Interval( + Locus("a", 1), + Locus("a", 5), + includesStart = true, + includesEnd = false, + )) == "AGGT") + assert(fr.lookup(Interval( + Locus("a", 20), + Locus("b", 5), + includesStart = false, + includesEnd = false, + )) == "ACGTATAAT") + assert(fr.lookup(Interval( + Locus("a", 20), + Locus("c", 5), + includesStart = false, + includesEnd = false, + )) == "ACGTATAATTAAATTAGCCAGGAT") } - @Test def testSerializeOnFB() { + @Test def testSerializeOnFB(): Unit = { withExecuteContext() { ctx => val grch38 = ctx.getReference(ReferenceGenome.GRCh38) val fb = EmitFunctionBuilder[String, Boolean](ctx, "serialize_rg") - val cb = fb.ecb val rgfield = fb.getReferenceGenome(grch38.name) fb.emit(rgfield.invoke[String, Boolean]("isValidContig", fb.getCodeParam[String](1))) @@ -174,20 +233,29 @@ class ReferenceGenomeSuite extends HailSuite { } } - @Test def testSerializeWithLiftoverOnFB() { + @Test def testSerializeWithLiftoverOnFB(): Unit = { withExecuteContext() { ctx => val grch37 = ctx.getReference(ReferenceGenome.GRCh37) val liftoverFile = "src/test/resources/grch37_to_grch38_chr20.over.chain.gz" grch37.addLiftover(ctx, liftoverFile, "GRCh38") - val fb = EmitFunctionBuilder[String, Locus, Double, (Locus, Boolean)](ctx, "serialize_with_liftover") - val cb = fb.ecb + val fb = + EmitFunctionBuilder[String, Locus, Double, (Locus, Boolean)](ctx, "serialize_with_liftover") val rgfield = fb.getReferenceGenome(grch37.name) - fb.emit(rgfield.invoke[String, Locus, Double, (Locus, Boolean)]("liftoverLocus", fb.getCodeParam[String](1), fb.getCodeParam[Locus](2), fb.getCodeParam[Double](3))) + fb.emit(rgfield.invoke[String, Locus, Double, (Locus, Boolean)]( + "liftoverLocus", + fb.getCodeParam[String](1), + fb.getCodeParam[Locus](2), + fb.getCodeParam[Double](3), + )) val f = fb.resultWithIndex()(theHailClassLoader, ctx.fs, ctx.taskContext, ctx.r) - assert(f("GRCh38", Locus("20", 60001), 0.95) == grch37.liftoverLocus("GRCh38", Locus("20", 60001), 0.95)) + assert(f("GRCh38", Locus("20", 60001), 0.95) == grch37.liftoverLocus( + "GRCh38", + Locus("20", 60001), + 0.95, + )) grch37.removeLiftover("GRCh38") } } diff --git a/hail/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala b/hail/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala index 4670f90ea51..c9b499cf382 100644 --- a/hail/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala +++ b/hail/src/test/scala/is/hail/variant/vsm/PartitioningSuite.scala @@ -8,18 +8,20 @@ import is.hail.rvd.RVD import is.hail.types._ import is.hail.types.virtual.{TInt32, TStruct} import is.hail.utils.FastSeq + import org.testng.annotations.Test class PartitioningSuite extends HailSuite { - @Test def testShuffleOnEmptyRDD() { + @Test def testShuffleOnEmptyRDD(): Unit = { val typ = TableType(TStruct("tidx" -> TInt32), FastSeq("tidx"), TStruct.empty) val t = TableLiteral( TableValue( ctx, typ, BroadcastRow.empty(ctx), - RVD.empty(ctx, typ.canonicalRVDType)), - theHailClassLoader + RVD.empty(ctx, typ.canonicalRVDType), + ), + theHailClassLoader, ) val rangeReader = ir.MatrixRangeReader(100, 10, Some(10)) Interpret( @@ -27,8 +29,11 @@ class PartitioningSuite extends HailSuite { 
ir.MatrixRead(rangeReader.fullMatrixType, false, false, rangeReader), t, "foo", - product = false), - ctx, optimize = false) + product = false, + ), + ctx, + optimize = false, + ) .rvd.count() } } diff --git a/hail/style-guide.md b/hail/style-guide.md deleted file mode 100644 index 4848234c814..00000000000 --- a/hail/style-guide.md +++ /dev/null @@ -1,50 +0,0 @@ -# Scala Style Guide -## IntelliJ settings -### Automatic - -The easiest way to comply with the style guide is to load the IntelliJ -`code_style.xml`. You can load the xml file using the `Manage ...` button on the -`Editor > Code Style` page. After loading it, select it from the drop-down list. - -### Manual - -Our style differs slightly from the IntelliJ Scala plugin defaults. -Make the following changes: - - - Turn off Preferences > Editor > Code Style > Syntax > Other > Enforce procedure syntax for methods with Unit return type. - - - Turn on Preferences > Editor > Code Style > Scala > Spaces > Other > Insert whitespaces in simple one line blocks. - - - Turn off Preferences > Editor > Code Style > Scala > Wrapping and Braces > Align when multiline in all categories. - -## Guide - - - Prefer - ```scala - def foo() { ... } - ``` - to - ```scala - def foo(): Unit = { ... } - ``` - In IntelliJ, turn off Preferences > Editor > Code Style > Syntax > Other > Enforce procedure syntax for methods with Unit return type. - - - Prefix mutable data structures with mutable. That is, prefer - ```scala - import scala.collection.mutable - ... mutable.ArrayBuilder[Byte] ... - ``` - to - ```scala - import scala.collection.mutable.ArrayBuilder - ... ArrayBuilder[Byte] ... - ``` - - - Prefer `None` to `Option.empty`. Prefer `Some(_)` to `Option(_)` when the argument will never be `null`. - - - Use require, assert and ensure liberally to check preconditions, conditions and post-conditions. Define a validate member to check object invariants and call where suitable. - - - In IntelliJ, turn on Preferences > Editor > Code Style > Scala > Spaces > Other > Insert whitespaces in simple one line blocks. - - - In IntelliJ, turn off Preferences > Editor > Code Style > Scala > Wrapping and Braces > Align when multiline in all categories. This helps the code from migrating too far to the right. 
- diff --git a/hail/version.mk b/hail/version.mk index c077bc41785..0e7b962fee1 100644 --- a/hail/version.mk +++ b/hail/version.mk @@ -13,10 +13,10 @@ ifndef BRANCH $(error "git rev-parse --abbrev-ref HEAD" failed to produce output) endif -SCALA_VERSION ?= 2.12.15 -SPARK_VERSION ?= 3.3.0 +SCALA_VERSION ?= 2.12.18 +SPARK_VERSION ?= 3.3.2 HAIL_MAJOR_MINOR_VERSION := 0.2 -HAIL_PATCH_VERSION := 126 +HAIL_PATCH_VERSION := 130 HAIL_PIP_VERSION := $(HAIL_MAJOR_MINOR_VERSION).$(HAIL_PATCH_VERSION) HAIL_VERSION := $(HAIL_PIP_VERSION)-$(SHORT_REVISION) ELASTIC_MAJOR_VERSION ?= 7 diff --git a/infra/azure/main.tf b/infra/azure/main.tf index ec32ff7db28..f76d76f921b 100644 --- a/infra/azure/main.tf +++ b/infra/azure/main.tf @@ -2,7 +2,7 @@ terraform { required_providers { azurerm = { source = "hashicorp/azurerm" - version = "=2.99.0" + version = "=3.93.0" } azuread = { source = "hashicorp/azuread" diff --git a/infra/azure/modules/batch/main.tf b/infra/azure/modules/batch/main.tf index 9d0a807da32..987480408f9 100644 --- a/infra/azure/modules/batch/main.tf +++ b/infra/azure/modules/batch/main.tf @@ -72,6 +72,9 @@ resource "azurerm_storage_account" "batch" { location = var.resource_group.location account_tier = "Standard" account_replication_type = "LRS" + min_tls_version = "TLS1_0" + + allow_nested_items_to_be_public = false blob_properties { last_access_time_enabled = true @@ -96,6 +99,9 @@ resource "azurerm_storage_account" "test" { location = var.resource_group.location account_tier = "Standard" account_replication_type = "LRS" + min_tls_version = "TLS1_0" + + allow_nested_items_to_be_public = false blob_properties { last_access_time_enabled = true @@ -319,16 +325,6 @@ resource "azurerm_role_assignment" "ci_acr_role" { principal_id = module.ci_sp.principal_id } -resource "kubernetes_secret" "registry_push_credentials" { - metadata { - name = "registry-push-credentials" - } - - data = { - "credentials.json" = jsonencode(module.ci_sp.credentials) - } -} - resource "azuread_application" "testns_ci" { display_name = "${var.resource_group.name}-testns-ci" } diff --git a/infra/azure/modules/batch/outputs.tf b/infra/azure/modules/batch/outputs.tf index bcf82a77a2b..120fd83b7aa 100644 --- a/infra/azure/modules/batch/outputs.tf +++ b/infra/azure/modules/batch/outputs.tf @@ -1,9 +1,9 @@ output batch_logs_storage_uri { - value = "hail-az://${azurerm_storage_account.batch.name}/${azurerm_storage_container.batch_logs.name}" + value = "https://${azurerm_storage_account.batch.name}.blob.core.windows.net/${azurerm_storage_container.batch_logs.name}" } output query_storage_uri { - value = "hail-az://${azurerm_storage_account.batch.name}/${azurerm_storage_container.query.name}" + value = "https://${azurerm_storage_account.batch.name}.blob.core.windows.net/${azurerm_storage_container.query.name}" } output test_storage_container { @@ -11,7 +11,7 @@ output test_storage_container { } output test_storage_uri { - value = "hail-az://${azurerm_storage_account.test.name}/${azurerm_storage_container.test.name}" + value = "https://${azurerm_storage_account.test.name}.blob.core.windows.net/${azurerm_storage_container.test.name}" } output ci_principal_id { diff --git a/infra/azure/modules/ci/main.tf b/infra/azure/modules/ci/main.tf index f3b1eaff1ef..99eff220a41 100644 --- a/infra/azure/modules/ci/main.tf +++ b/infra/azure/modules/ci/main.tf @@ -4,6 +4,9 @@ resource "azurerm_storage_account" "ci" { location = var.resource_group.location account_tier = "Standard" account_replication_type = "LRS" + min_tls_version = "TLS1_0" + + 
allow_nested_items_to_be_public = false blob_properties { last_access_time_enabled = true @@ -49,7 +52,7 @@ resource "azurerm_role_assignment" "ci_test_container_contributor" { module "k8s_resources" { source = "../../../k8s/ci" - storage_uri = "hail-az://${azurerm_storage_account.ci.name}/${azurerm_storage_container.ci_artifacts.name}" + storage_uri = "https://${azurerm_storage_account.ci.name}.blob.core.windows.net/${azurerm_storage_container.ci_artifacts.name}" deploy_steps = var.deploy_steps watched_branches = var.watched_branches github_context = var.github_context diff --git a/infra/azure/modules/vdc/main.tf b/infra/azure/modules/vdc/main.tf index 3b3537b0de4..90d9d9f62ed 100644 --- a/infra/azure/modules/vdc/main.tf +++ b/infra/azure/modules/vdc/main.tf @@ -51,6 +51,7 @@ resource "azurerm_kubernetes_cluster" "vdc" { dns_prefix = "example" automatic_channel_upgrade = "stable" + role_based_access_control_enabled = false default_node_pool { name = "nonpreempt" @@ -71,16 +72,13 @@ resource "azurerm_kubernetes_cluster" "vdc" { type = "SystemAssigned" } - addon_profile { - oms_agent { - enabled = true - log_analytics_workspace_id = azurerm_log_analytics_workspace.logs.id - } + oms_agent { + log_analytics_workspace_id = azurerm_log_analytics_workspace.logs.id } # https://github.com/hashicorp/terraform-provider-azurerm/issues/7396 lifecycle { - ignore_changes = [addon_profile.0, default_node_pool.0.node_count] + ignore_changes = [default_node_pool.0.node_count] } } @@ -158,4 +156,5 @@ resource "azurerm_public_ip" "gateway_ip" { location = var.resource_group.location sku = "Standard" allocation_method = "Static" + zones = ["1", "2", "3"] } diff --git a/infra/gcp-broad/ci/main.tf b/infra/gcp-broad/ci/main.tf index dc7638de6cc..03efbd81424 100644 --- a/infra/gcp-broad/ci/main.tf +++ b/infra/gcp-broad/ci/main.tf @@ -16,6 +16,7 @@ resource "google_storage_bucket" "bucket" { location = var.bucket_location force_destroy = false storage_class = var.bucket_storage_class + uniform_bucket_level_access = true labels = { "name" = "hail-ci-${random_string.hail_ci_bucket_suffix.result}" } @@ -40,7 +41,7 @@ resource "google_storage_bucket" "bucket" { resource "google_storage_bucket_iam_member" "ci_bucket_admin" { bucket = google_storage_bucket.bucket.name - role = "roles/storage.legacyBucketWriter" + role = "roles/storage.objectAdmin" member = "serviceAccount:${var.ci_email}" } diff --git a/infra/gcp-broad/gcp-ar-cleanup-policy.txt b/infra/gcp-broad/gcp-ar-cleanup-policy.txt index 235c1cbd8c9..c909bb68347 100644 --- a/infra/gcp-broad/gcp-ar-cleanup-policy.txt +++ b/infra/gcp-broad/gcp-ar-cleanup-policy.txt @@ -130,7 +130,6 @@ "linting", "monitoring", "netcat", - "query-build", "test-ci-utils", "test_hello_create_certs_image", "volume", diff --git a/infra/gcp-broad/gcs_bucket/main.tf b/infra/gcp-broad/gcs_bucket/main.tf index e56ee3430b6..44761e2e179 100644 --- a/infra/gcp-broad/gcs_bucket/main.tf +++ b/infra/gcp-broad/gcs_bucket/main.tf @@ -7,4 +7,5 @@ resource "google_storage_bucket" "bucket" { location = var.location force_destroy = true storage_class = var.storage_class + uniform_bucket_level_access = true } diff --git a/infra/gcp-broad/hail-is/ci_config.enc.json b/infra/gcp-broad/hail-is/ci_config.enc.json index 1e8c191f7de..33a228803e4 100644 --- a/infra/gcp-broad/hail-is/ci_config.enc.json +++ b/infra/gcp-broad/hail-is/ci_config.enc.json @@ -1,33 +1,33 @@ { "deploy_steps": [], - "github_context": 
"ENC[AES256_GCM,data:JN5usSTHOg==,iv:CIA4ELD6Q/v06KbISOHeDCsnSjwWb51VpzrhN+3WKvs=,tag:EwSRVoerxqgxbx5+Q0G/Fg==,type:str]", - "storage_uri": "ENC[AES256_GCM,data:S9R83B6rVVTSozyBnfFpbeqJ,iv:85Y90Bg2jMGRt7lq3gaCQLxzQ+4pbRA+Hedw8/fGrak=,tag:ouZIcQAWBkgWll7GugW8ag==,type:str]", - "bucket_location": "ENC[AES256_GCM,data:AHTjUUbf6JamtYE=,iv:g9L8ceIQx0xUmE5xGHJJkwCrFilf165bYtjF2alVztU=,tag:QawH9d1WLm6SPLt7gBZnhQ==,type:str]", - "bucket_storage_class": "ENC[AES256_GCM,data:NX6xM/R9Bp8=,iv:WWskNj+PWD4UxWMsRyIqZroxIijNaadGJvFMJcH7NU8=,tag:ncV4dCcWEx+iF6RbsICKeQ==,type:str]", - "test_oauth2_callback_urls": "ENC[AES256_GCM,data:jObH1oS9+3OBoGu4VlQqj2OWlE2Fus3XuWYZ2pYWUq8pm3JYU7CK1KhWJfjioUh82f/67N2Y62JO+sI0+YBrpHk7cSWNCW4qrv/YSMRYr9x8VmcGg2kGIMXpcik9xdSidmpma4nWchXCARTuYATr0MHicDndxYkkQva0ImgWd3QrYRjjy/ZI1EGDyXIHSVTcSSqybErh/tM5vg8p0SnClrU7GSwFrMUV+80l7m1Qar+CIMd4gMYsBVTmdxEjeK5S+Z/tKQ434LduV8u21LrsQe4Q5DjgNRdgRNrn4llWfen9ie8+6onLjDfyfLr6/sh+50oGV6SFyEYPytKBToY2t1QnYrJaH/xcuGP9oPSDfvPdgdn7GzBUahcAEA72juIGN/RotnITws2iduAMNFjIFUEGKYVgFQWeHswmhzNhztIhKwlncq/RQzeuGVgB4prPTFqbT/7eT/57eHqFx84kedD7XEC5eCeqSdbCIcXMIzLAtwSFCHh/gL++luQ3jlriklw/bZ/OJVAISLWAgINKrmqbli7RW8gmDJ3Rlq2k20KgI2TLbAyUYCJm00jWTAGBy/jdrIQImrY2EIjdoBTnUFErcKIshkc7751R28lin22ybQCaXvmYfvy7Tuf0gp5pnn+U/znXLRBo/lW3kJadDilpkAcusdqlEtB6s/JcaBEhchVkfvpufJUXdyo5kMBjr2XSiWvu3xvMZbaMq59jdSwFhQu/tWCUdLL7bB8cXHshODihDvQGaEixx7NLffmzK3NO0aizBvOiPTyuBNRD+Q+RSSKKyi5AOqDSH3Umbhqe8RtMfUgYLZTsqozWBMopfsiECB+I8teapvpuJtI5+dvOmkIjrvcaUnZAVmRpcE7xF+wDUIZGETiscEYTv/7Y+G4ylLy6E5z4zk7ugiQjY79/H0xSm8mIYXyduDcFa4ZgWxYIZbghLgbi2FJj7J0+6SDm00eD1NfHkaozke7KvFlILW1VUOcrqqlxGzXlxQLoT27LxWQXSrwp+5SYvJHQfeSxT03ksHLHqoSa3lEMiYXHFh2mzwmtjK4wkkZpTUt0RJcfOnIy5SI1bkubnL0FU3ARbftYXL+aFWBdl86y7Hu+397srrZ4SZ2LNr3I33XzrHowUMxfvDLDAgSKxUCWuBm1EggaIQit7WMg9hDHxHc6QCWbupO0oFeGrvUbQoZZEqrLFc/cGbDGFZTURPrvTcE9Lrr2R7kzn3nKhHSAzCZrndhY5mCCOfR60cE04H857x1HhAJKxIsS3D1izDyQpoI//mbz+fKEEwDjWR5bnE1K8umy1CUMJYoPibzcVFpuBdTzvqwO4WeD8fiwdL6kFL788n/8+K4qERkVEip0PyhleAtvjxER4Xt/JSjNvikukfF07/GouPulnh8S7JeRGEgn0Lx/NeBLFbCQvdibsr/knCzDXbO5EoBd78AIkqBSD/2i1M4XEFsJNZOGpP9ZYNUUWdRrDyLKgqJkaEyTIyJ3mPiZcOo7htKyHol3,iv:BwWT2b31zTn1ELC9fZofIptqRDkJj2ybtAhwFQQtBKs=,tag:zkbT2TolS1XG7Ixa8lLwQw==,type:str]", + "github_context": "ENC[AES256_GCM,data:uuAIsXg52A==,iv:pHqi8cTPKC7dubMiPhYxy7Gc9GX8pTrjxlmCcmdnPiQ=,tag:6BvGUwKMN+j/ejpyR8UpzA==,type:str]", + "storage_uri": "ENC[AES256_GCM,data:tujcZ3WmeDYWtCZFDsRL/BbB,iv:p81TCKElG/3L78oA4J5gF8m99+6YIAZHmscr85GEvUw=,tag:qWTxVh8SgVf+eF1MSA+fAg==,type:str]", + "bucket_location": "ENC[AES256_GCM,data:KwYNP+er99qin6Y=,iv:7aD4RpevQosLkTcD0aqJ556Wx58PhUryKd96WtD7gFo=,tag:cr/6qfX8WoZt4COnemNoZw==,type:str]", + "bucket_storage_class": "ENC[AES256_GCM,data:U5uAnvELyjI=,iv:N63oG6h4ELTy9WEuZcT7cYOQ96MUcrCy69NJgR5MuSc=,tag:hA22zXOByZXEiRy6nYG9CA==,type:str]", + "test_oauth2_callback_urls": 
"ENC[AES256_GCM,data:7w8rvlLneiXP0js4p4cqR31hWczs1alNKLdCDebeOr8b3LRN2cV770YmtwBtmxW0ei8CqO7nNtp9UZg/LaqyRMyogm2feWaIhd+j0flVytDVXGvVMDf8cB0oY4/MCoFANFQnJ1xP3CB0ESBRKk695pdFp3ZPmmdy6qF5+r9JbIq3ELcxkwuzKATaDgljhWu00H3s3Y6pdwF8ueX8cFh51+a3+BnJlONDLNV8PG1rzsz3ReP2Kjv82EcjMwk4BrU2cMMqka2ILAuLcTCvX7a2FOSq11OLAGqJ2enFdymSUQJM+one/ZkwAWgAW9rnvLFgSu1PkbjaQ9PbeotopypUVyhLrMqJUAxB4qeckEudkKuGM1dfZWqOPAXT7iuXB/X7EN5uRKrIcoJ746PJgnskAJDRDpK1lNbWeuv+dm3DSowaH+cEAHU5XGYI0JWWXa6TU7nYNsFh0G30SVYVWvQEmnTdT4Hv3Yi1ptr8FhYcR8oHtC+oOfqxXYiGcbTdvmI6Iwwm6reTqTWN904/oVdvOcjF2wAuge/agxfWA8TjT8zBNgLmyr2cNwbmhz4jz/COTZd/fOanXyUH3i4DcYIp9zBOQQdweptYVdAIbRDBl/SG72wHqjI0981KThQ2Kt3Blt1omshwhtgrf2tXouk3+sgYnZZm1OuDQDKF8lRlPqSNhEiK6BI0Q8GhncAlvaRxG6DiXRCiu+zdh+0PgsvsiL553Ki9oI0OVgHTwx/GRLuwEGkQCLYDmaTbvyLznNmUjxWtXBl8ETIUhOVTiWE+Lnq+dPl0dhA0jgPRSBKAa9+IRmrRlUHI0iW9Y4UTJXX3+W81qCSI1nq540V9vr/jYGIKeTwncdroZF0MzF9zMOp6C9ELQoZikUcxPpgq6J04ULdrDNO8RYJxowL/ZtsegoUBZlJKFvHwqm9h1fRlZ1kIYopjxY/LM6f1KxEGrUwspFvP7fMwc0IURQhhRgn0GQPh0FKzkVkE3Um9yZS3xcM0OJOfZGcNWZTwsIg7c2UHuazsB73o2g4fIK5RJTlkvH7X2WzbqE8abXhRjWPN3BP/XEozvZhoZook2yhNM8qpzn3eGD0fOqyQEkA2O2qT9C3IJVP98FL5y1R7Julb7FDYOk8F8xGo8I7Zmd8nAqHeP8QvNJFIO4fMQLrh3dqciV1ofwP+qLtUAAgaM8NghvX86Mp4jOql0HQxDTCU6ZHFTQIV9+b3ZBmMwyf+EFC/35S01bNY3sSANGMKU5PTaghrjLuvXnlBy/oCH1Ev6Jyx8V8AVEQdm6THnMoLhofOmw3SMcLQMDsUDUkDLUv1P/Y6SNLrX6ZLz2djZnarBaPxKuzks101kMRaoSUO2jxON9X9JGS/Xxsz+wMCq3NmUjIm6hk0H9ciYg1Iavu26HZ1pqreNpoVNtTnUj1ggPwke90dwfl9y61orhTIaioR7j2MxKNbS/SuL05pIroUAXZsEFkiuN5LRWPqdwwwq47TCq1QnsAVAnG7lyaqNslD,iv:tO64hGXtVNioCGYcGUHNcimNNWCZe83eM/XVhqTyVtk=,tag:yBAkDCCnYiWJXBDEsfpzeg==,type:str]", "watched_branches": [ [ - "ENC[AES256_GCM,data:89ey6FPaU863AVhw/cXAJQk=,iv:mea/Zkiso71LCd+WMOyUzLxRhdTEvo+iCZCJEkOY9Hg=,tag:ay8YrryybQFeFGSwwRwcjA==,type:str]", - "ENC[AES256_GCM,data:UTQKiA==,iv:wRgansKSpx+LxP0qTXQtGjCaTjE6cjEDVFXFx7PYlZY=,tag:LSCQKiuXxRx8S/O04c57oQ==,type:bool]", - "ENC[AES256_GCM,data:4ER1XA==,iv:kBXFBG+9FV28yAPlrYHGrmVcoZKTryI6wt07fqrBr4k=,tag:6rdzwJ7e+14yScGqQGmmiA==,type:bool]" + "ENC[AES256_GCM,data:UL1yY1feFvUsudXLlPSjCcI=,iv:5b5vJrT07kXuj9XOafKZSPfL6KG+trVSv9dRjFyR1vo=,tag:K1hyoeM0HS7GO7b6JLIyhg==,type:str]", + "ENC[AES256_GCM,data:QQ3Osg==,iv:2Yb8NPJrEkG2xFHSEHn9wHctRgVldSk0QvjHd5LeCv4=,tag:IpGeM4ZjNArwXzVgCPWKug==,type:bool]", + "ENC[AES256_GCM,data:RidTJQ==,iv:Co9iJ5DNZ/XJCWWGi9wclFH9dp7dZVD4G+p3Rl5VbXc=,tag:280mUXJTTAykTeUo9pqRSA==,type:bool]" ] ], - "github_oauth_token": "ENC[AES256_GCM,data:vrcdC7SFoC92Gp9hzCzAsS19UsxEP1vfwY7kxHs4ngpjTvAmjDAYwg==,iv:uodttuaHtdOiw3jiwQPPhj6vkGOnEQsmw6XmrKnoCmg=,tag:gOeJGgOLj0gc6/C9X4et2w==,type:str]", - "github_user1_oauth_token": "ENC[AES256_GCM,data:4UpueIA+DJsqowuqT/CM6LWO23+AOCnla7F/En9uLm/KeTghKpxx3A==,iv:w70hLt2b1zTbyFZfuX9cjFtN2BJvvpM2/9qg1pGDlfc=,tag:uEprSZGsAyaEz5hZwKyB2w==,type:str]", + "github_oauth_token": "ENC[AES256_GCM,data:fU8dHJvwbaj985PFSglwk+XdUtE9gdOkdwi62lsploetzj0J/mLC7o7443xKOCd9ASJu3Esc6Nhlmi9gYbukR12wDd07sUJF5zxWZ7Im7jkWfdXx7kMO2ee3SFV5,iv:/9XzhBaUlxMhC+UMgE3VgbIdcc6XKkvRcgmQL82Ykt4=,tag:SqSEJR51krrToR1zOU3YHQ==,type:str]", + "github_user1_oauth_token": "ENC[AES256_GCM,data:VKKbJpwK4FdqNVahB8ulKnnk9H99AHqjFJ8MF/Cu/wCKG/pAAodf9A==,iv:4lCiAZZUakA1XsSAS65ttRY9kwDOxYCFU/tD5Q8wimY=,tag:meq8BpYJ0LG5nhgJ061Vbw==,type:str]", "sops": { "kms": null, "gcp_kms": [ { "resource_id": "projects/hail-vdc/locations/global/keyRings/sops/cryptoKeys/sops-key", - "created_at": "2023-05-09T19:53:50Z", - "enc": 
"CiQA0ec+W0aC8bSdWP22np+CuMKieVOLkuzgXwQ9L3fYe8KuG9MSSQCe82Ad9tL4LzDEnBpALOvttYTDL8vIrBeS1IeFIimyNgfitkWSXcKWsB7hc08iI/8S/1pJRyUFpioMLI7aV0r0JZjeXajiofM=" + "created_at": "2024-02-29T18:01:13Z", + "enc": "CiQA0ec+W/cAt2qYQY/buQOQqKg+NcYOqdKOJvGnI+ZXfitg344SSQCe82AdO3iHE/Wqgvb9yxLyJGOpyH8sOgMb208I6S4l+CtCXS8sv2Icuk8GLrN/eBQenJR8/AXptdszsrvND2ngz5M/+8q2kN0=" } ], "azure_kv": null, "hc_vault": null, "age": null, - "lastmodified": "2023-05-09T19:53:50Z", - "mac": "ENC[AES256_GCM,data:ow71s2AziKKv6bmNPpddL7T87p8irYRzg+sLfYa2sHJ6OxSLpahT+Usjh70RE2Xmidycc8tvl+VFgFYUOFPocwEDReIbhM4B6gHbTThizgaV6tnIcSypzk+YUj/C3sl00668uOOtd5HAdnv1Ne1j4zOfuJGuXeVVZOB4zD+S25o=,iv:74+qhgXteCUeZKxh7GvTl7NPXgnoHPltwa8W0vBwG38=,tag:Yknyjubbr8ASIdZQ1RzPvA==,type:str]", + "lastmodified": "2024-02-29T18:01:14Z", + "mac": "ENC[AES256_GCM,data:8XiJ+3b7fMsJkGcraWfBh703a3Av9F1yshDbmz2J5NXpZsaHxZ8pIR6hap+ST6wp9CvBxYMmcP5XOVp17p9O/WIa1CboRIctXqkFsUYt41f6GFjP6hwlsdeFW/yP55Mv681sq1bBmeMrfU2LFaTG+z8lzeAM9VaI5E0xf2YILpU=,iv:x1Zy0k5UmXMJ1rzoys9XBNCMpyqYoivRLbhy8JebxYo=,tag:NSreDd2t1FgJgjO5Dyq+Xg==,type:str]", "pgp": null, "unencrypted_suffix": "_unencrypted", "version": "3.7.3" diff --git a/infra/gcp-broad/main.tf b/infra/gcp-broad/main.tf index 30152037024..92b948fafb2 100644 --- a/infra/gcp-broad/main.tf +++ b/infra/gcp-broad/main.tf @@ -2,7 +2,7 @@ terraform { required_providers { google = { source = "hashicorp/google" - version = "4.32.0" + version = "5.15.0" } kubernetes = { source = "hashicorp/kubernetes" @@ -438,12 +438,12 @@ resource "google_artifact_registry_repository_iam_member" "artifact_registry_bat member = "serviceAccount:${google_service_account.batch_agent.email}" } -resource "google_artifact_registry_repository_iam_member" "artifact_registry_ci_viewer" { +resource "google_artifact_registry_repository_iam_member" "artifact_registry_ci_admin" { provider = google-beta project = var.gcp_project repository = google_artifact_registry_repository.repository.name location = var.artifact_registry_location - role = "roles/artifactregistry.reader" + role = "roles/artifactregistry.admin" member = "serviceAccount:${module.ci_gsa_secret.email}" } @@ -527,15 +527,6 @@ module "ci_gsa_secret" { ] } -resource "google_artifact_registry_repository_iam_member" "artifact_registry_viewer" { - provider = google-beta - project = var.gcp_project - repository = google_artifact_registry_repository.repository.name - location = var.artifact_registry_location - role = "roles/artifactregistry.reader" - member = "serviceAccount:${module.ci_gsa_secret.email}" -} - module "testns_ci_gsa_secret" { source = "./gsa" name = "testns-ci" @@ -548,12 +539,12 @@ resource "google_storage_bucket_iam_member" "testns_ci_bucket_admin" { member = "serviceAccount:${module.testns_ci_gsa_secret.email}" } -resource "google_artifact_registry_repository_iam_member" "artifact_registry_testns_ci_viewer" { +resource "google_artifact_registry_repository_iam_member" "artifact_registry_testns_ci_repo_admin" { provider = google-beta project = var.gcp_project repository = google_artifact_registry_repository.repository.name location = var.artifact_registry_location - role = "roles/artifactregistry.reader" + role = "roles/artifactregistry.repoAdmin" member = "serviceAccount:${module.testns_ci_gsa_secret.email}" } @@ -811,7 +802,7 @@ resource "google_storage_bucket" "hail_test_requester_pays_bucket" { } resource "google_dns_managed_zone" "dns_zone" { - description = "" + description = "hail managed dns zone" name = "hail" dns_name = "hail." 
visibility = "private" @@ -862,7 +853,7 @@ resource "kubernetes_cluster_role_binding" "batch" { } } -resource "kubernetes_pod_disruption_budget" "kube_dns_pdb" { +resource "kubernetes_pod_disruption_budget_v1" "kube_dns_pdb" { metadata { name = "kube-dns" namespace = "kube-system" @@ -877,7 +868,7 @@ resource "kubernetes_pod_disruption_budget" "kube_dns_pdb" { } } -resource "kubernetes_pod_disruption_budget" "kube_dns_autoscaler_pdb" { +resource "kubernetes_pod_disruption_budget_v1" "kube_dns_autoscaler_pdb" { metadata { name = "kube-dns-autoscaler" namespace = "kube-system" @@ -892,7 +883,7 @@ resource "kubernetes_pod_disruption_budget" "kube_dns_autoscaler_pdb" { } } -resource "kubernetes_pod_disruption_budget" "event_exporter_pdb" { +resource "kubernetes_pod_disruption_budget_v1" "event_exporter_pdb" { metadata { name = "event-exporter" namespace = "kube-system" diff --git a/infra/gcp/README.md b/infra/gcp/README.md index 4511596ec3a..a0846226e68 100644 --- a/infra/gcp/README.md +++ b/infra/gcp/README.md @@ -89,7 +89,15 @@ Instructions: use_artifact_registry = false ``` -- You can optionally create a `/tmp/ci_config.json` file to enable CI triggered by GitHub events: +- You can optionally create a `/tmp/ci_config.json` file to enable CI triggered by GitHub + events. Note that `github_oauth_token` is not necessarily an OAuth2 access token. In fact, it + should be a fine-grained personal access token. The currently public documentation on fine-grained + access tokens is not very good. Check this [page in + `github/docs`](https://github.com/github/docs/blob/main/content/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens.md) + for information on how to create a personal access token that is privileged to access the + `hail-is` organization. Note in particular that personal access tokens have a "resource owner" + field which is fixed at creation time. The token can only read or write to repositories owned by + the "resource owner". 
```json { diff --git a/infra/gcp/main.tf b/infra/gcp/main.tf index eeeb4135c6e..5f8ca7f8965 100644 --- a/infra/gcp/main.tf +++ b/infra/gcp/main.tf @@ -391,12 +391,12 @@ resource "google_artifact_registry_repository_iam_member" "artifact_registry_bat member = "serviceAccount:${google_service_account.batch_agent.email}" } -resource "google_artifact_registry_repository_iam_member" "artifact_registry_ci_viewer" { +resource "google_artifact_registry_repository_iam_member" "artifact_registry_ci_admin" { provider = google-beta project = var.gcp_project repository = google_artifact_registry_repository.repository.name location = var.gcp_location - role = "roles/artifactregistry.reader" + role = "roles/artifactregistry.admin" member = "serviceAccount:${module.ci_gsa_secret.email}" } @@ -415,19 +415,6 @@ resource "google_artifact_registry_repository_iam_member" "artifact_registry_pus member = "serviceAccount:${google_service_account.gcr_push.email}" } -# This is intended to match the secret name also used for azure credentials -# This should ultimately be replaced by using CI's own batch-managed credentials -# in BuildImage jobs -resource "kubernetes_secret" "registry_push_credentials" { - metadata { - name = "registry-push-credentials" - } - - data = { - "credentials.json" = base64decode(google_service_account_key.gcr_push_key.private_key) - } -} - module "auth_gsa_secret" { source = "./gsa_k8s_secret" name = "auth" @@ -510,24 +497,15 @@ resource "google_storage_bucket_iam_member" "testns_ci_bucket_admin" { member = "serviceAccount:${module.testns_ci_gsa_secret.email}" } -resource "google_artifact_registry_repository_iam_member" "artifact_registry_testns_ci_viewer" { +resource "google_artifact_registry_repository_iam_member" "artifact_registry_testns_ci_repo_admin" { provider = google-beta project = var.gcp_project repository = google_artifact_registry_repository.repository.name location = var.gcp_location - role = "roles/artifactregistry.reader" + role = "roles/artifactregistry.repoAdmin" member = "serviceAccount:${module.testns_ci_gsa_secret.email}" } -resource "google_artifact_registry_repository_iam_member" "artifact_registry_viewer" { - provider = google-beta - project = var.gcp_project - repository = google_artifact_registry_repository.repository.name - location = var.gcp_location - role = "roles/artifactregistry.reader" - member = "serviceAccount:${module.ci_gsa_secret.email}" -} - module "monitoring_gsa_secret" { source = "./gsa_k8s_secret" name = "monitoring" diff --git a/monitoring/deployment.yaml b/monitoring/deployment.yaml index c3ed1caefda..44cbcb898b1 100644 --- a/monitoring/deployment.yaml +++ b/monitoring/deployment.yaml @@ -39,11 +39,6 @@ spec: value: /deploy-config/deploy-config.json - name: HAIL_SHA value: "{{ code.sha }}" - - name: HAIL_DOMAIN - valueFrom: - secretKeyRef: - name: global-config - key: domain - name: PROJECT valueFrom: secretKeyRef: diff --git a/monitoring/monitoring/monitoring.py b/monitoring/monitoring/monitoring.py index 6695bfa5114..eb80de5d278 100644 --- a/monitoring/monitoring/monitoring.py +++ b/monitoring/monitoring/monitoring.py @@ -13,12 +13,18 @@ from aiohttp import web from prometheus_async.aio.web import server_stats # type: ignore -from gear import AuthServiceAuthenticator, Database, json_response, setup_aiohttp_session, transaction +from gear import ( + AuthServiceAuthenticator, + CommonAiohttpAppKeys, + Database, + json_response, + setup_aiohttp_session, + transaction, +) from hailtop import aiotools, httpx from hailtop.aiocloud import aiogoogle 
from hailtop.config import get_deploy_config from hailtop.hail_logging import AccessLogger -from hailtop.tls import internal_server_ssl_context from hailtop.utils import ( cost_str, parse_timestamp_msecs, @@ -124,7 +130,7 @@ async def _billing(request: web.Request): set_message(session, msg, 'error') return ([], [], [], time_period_query) - db = app['db'] + db = app[AppKeys.DB] records = db.execute_and_fetchall( 'SELECT * FROM monitoring_billing_data WHERE year = %s AND month = %s;', (time_period.year, time_period.month) ) @@ -162,8 +168,8 @@ async def billing(request: web.Request, userdata) -> web.Response: # pylint: di async def query_billing_body(app): - db = app['db'] - bigquery_client = app['bigquery_client'] + db = app[AppKeys.DB] + bigquery_client = app[AppKeys.BQ_CLIENT] async def _query(dt): month = dt.month @@ -184,7 +190,7 @@ async def _query(dt): invoice_month = datetime.date.strftime(start, '%Y%m') # service.id: service.description -- "6F81-5844-456A": "Compute Engine" - cmd = f''' + cmd = f""" SELECT service.id as service_id, service.description as service_description, sku.id as sku_id, sku.description as sku_description, SUM(cost) as cost, CASE WHEN service.id = "6F81-5844-456A" AND EXISTS(SELECT 1 FROM UNNEST(labels) WHERE key = "namespace" and value = "default") THEN "batch-production" @@ -197,7 +203,7 @@ async def _query(dt): FROM `broad-ctsa.hail_billing.gcp_billing_export_v1_0055E5_9CA197_B9B894` WHERE DATE(_PARTITIONTIME) >= "{start_str}" AND DATE(_PARTITIONTIME) <= "{end_str}" AND project.name = "{PROJECT}" AND invoice.month = "{invoice_month}" GROUP BY service_id, service_description, sku_id, sku_description, source; -''' +""" log.info(f'querying BigQuery with command: {cmd}') @@ -218,17 +224,17 @@ async def _query(dt): @transaction(db) async def insert(tx): await tx.just_execute( - ''' + """ DELETE FROM monitoring_billing_data WHERE year = %s AND month = %s; -''', +""", (year, month), ) await tx.execute_many( - ''' + """ INSERT INTO monitoring_billing_data (year, month, service_id, service_description, sku_id, sku_description, source, cost) VALUES (%s, %s, %s, %s, %s, %s, %s, %s); -''', +""", records, ) @@ -250,22 +256,22 @@ async def insert(tx): async def polling_loop(app): - db = app['db'] + db = app[AppKeys.DB] while True: now = datetime.datetime.now() row = await db.select_and_fetchone('SELECT mark FROM monitoring_billing_mark;') if not row['mark'] or now > datetime.datetime.fromtimestamp(row['mark'] / 1000) + datetime.timedelta(days=1): - app['query_billing_event'].set() + app[AppKeys.QUERY_BILLING_EVENT].set() await asyncio.sleep(60) async def monitor_disks(app): log.info('monitoring disks') - compute_client: aiogoogle.GoogleComputeClient = app['compute_client'] + compute_client = app[AppKeys.COMPUTE_CLIENT] disk_counts = defaultdict(list) - for zone in app['zones']: + for zone in app[AppKeys.ZONES]: async for disk in await compute_client.list(f'/zones/{zone}/disks', params={'filter': '(labels.batch = 1)'}): namespace = disk['labels']['namespace'] size_gb = int(disk['sizeGb']) @@ -297,11 +303,11 @@ async def monitor_disks(app): async def monitor_instances(app): log.info('monitoring instances') - compute_client: aiogoogle.GoogleComputeClient = app['compute_client'] + compute_client = app[AppKeys.COMPUTE_CLIENT] instance_counts: Dict[InstanceLabels, int] = defaultdict(int) - for zone in app['zones']: + for zone in app[AppKeys.ZONES]: async for instance in await compute_client.list( f'/zones/{zone}/instances', params={'filter': '(labels.role = batch2-agent)'} ): 
@@ -319,46 +325,60 @@ async def monitor_instances(app): INSTANCES.labels(**labels._asdict()).set(count) +class AppKeys(CommonAiohttpAppKeys): + DB = web.AppKey('db', Database) + BQ_CLIENT = web.AppKey('bigquery_client', aiogoogle.GoogleBigQueryClient) + COMPUTE_CLIENT = web.AppKey('compute_client', aiogoogle.GoogleComputeClient) + QUERY_BILLING_EVENT = web.AppKey('query_billing_event', asyncio.Event) + ZONES = web.AppKey('zones', List[str]) + TASK_MANAGER = web.AppKey('task_manager', aiotools.BackgroundTaskManager) + EXIT_STACK = web.AppKey('exit_stack', AsyncExitStack) + + async def on_startup(app): + exit_stack = AsyncExitStack() + app[AppKeys.EXIT_STACK] = exit_stack + db = Database() await db.async_init() - app['db'] = db - app['client_session'] = httpx.client_session() + app[AppKeys.DB] = db + exit_stack.push_async_callback(db.async_close) + + app[AppKeys.CLIENT_SESSION] = httpx.client_session() + exit_stack.push_async_callback(app[AppKeys.CLIENT_SESSION].close) aiogoogle_credentials = aiogoogle.GoogleCredentials.from_file('/billing-monitoring-gsa-key/key.json') - bigquery_client = aiogoogle.GoogleBigQueryClient('broad-ctsa', credentials=aiogoogle_credentials) - app['bigquery_client'] = bigquery_client + app[AppKeys.BQ_CLIENT] = aiogoogle.GoogleBigQueryClient('broad-ctsa', credentials=aiogoogle_credentials) compute_client = aiogoogle.GoogleComputeClient(PROJECT, credentials=aiogoogle_credentials) - app['compute_client'] = compute_client + app[AppKeys.COMPUTE_CLIENT] = compute_client query_billing_event = asyncio.Event() - app['query_billing_event'] = query_billing_event + app[AppKeys.QUERY_BILLING_EVENT] = query_billing_event region_info = {name: await compute_client.get(f'/regions/{name}') for name in BATCH_GCP_REGIONS} zones = [url_basename(z) for r in region_info.values() for z in r['zones']] - app['zones'] = zones + app[AppKeys.ZONES] = zones - app['task_manager'] = aiotools.BackgroundTaskManager() + task_manager = aiotools.BackgroundTaskManager() + exit_stack.callback(task_manager.shutdown) + app[AppKeys.TASK_MANAGER] = task_manager - app['task_manager'].ensure_future(retry_long_running('polling_loop', polling_loop, app)) + task_manager.ensure_future(retry_long_running('polling_loop', polling_loop, app)) - app['task_manager'].ensure_future( + task_manager.ensure_future( retry_long_running( 'query_billing_loop', run_if_changed_idempotent, query_billing_event, query_billing_body, app ) ) - app['task_manager'].ensure_future(periodically_call(60, monitor_disks, app)) - app['task_manager'].ensure_future(periodically_call(60, monitor_instances, app)) + task_manager.ensure_future(periodically_call(60, monitor_disks, app)) + task_manager.ensure_future(periodically_call(60, monitor_instances, app)) async def on_cleanup(app): - async with AsyncExitStack() as cleanup: - cleanup.push_async_callback(app['db'].async_close) - cleanup.push_async_callback(app['client_session'].close) - cleanup.callback(app['task_manager'].shutdown) + await app[AppKeys.EXIT_STACK].aclose() def run(): @@ -378,5 +398,5 @@ def run(): host='0.0.0.0', port=5000, access_log_class=AccessLogger, - ssl_context=internal_server_ssl_context(), + ssl_context=deploy_config.server_ssl_context(), ) diff --git a/monitoring/test/pytest.ini b/monitoring/test/pytest.ini new file mode 100644 index 00000000000..2f4c80e3075 --- /dev/null +++ b/monitoring/test/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +asyncio_mode = auto diff --git a/monitoring/test/test_monitoring.py b/monitoring/test/test_monitoring.py index bfa22a70bce..e916e61dbee 100644 
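[Editorial aside, not part of the patch.] The monitoring changes above replace string-indexed `app['...']` lookups with typed `aiohttp` app keys and move cleanup registration into an `AsyncExitStack` that is created in `on_startup` and unwound with a single `aclose()` in `on_cleanup`. Below is a minimal, self-contained sketch of that pattern, assuming aiohttp ≥ 3.9 (which provides `web.AppKey`); `FakeDB` and the key names are illustrative stand-ins, not Hail code.

```python
from contextlib import AsyncExitStack

from aiohttp import web


class FakeDB:
    """Stand-in for a real database client."""

    async def close(self) -> None:
        print('db closed')


# Typed keys: app[DB] is known to be a FakeDB, app[EXIT_STACK] an AsyncExitStack.
DB = web.AppKey('db', FakeDB)
EXIT_STACK = web.AppKey('exit_stack', AsyncExitStack)


async def on_startup(app: web.Application) -> None:
    exit_stack = AsyncExitStack()
    app[EXIT_STACK] = exit_stack

    app[DB] = FakeDB()
    # Register cleanup at the point of acquisition rather than in on_cleanup.
    exit_stack.push_async_callback(app[DB].close)


async def on_cleanup(app: web.Application) -> None:
    # One aclose() runs every registered callback in reverse order.
    await app[EXIT_STACK].aclose()


app = web.Application()
app.on_startup.append(on_startup)
app.on_cleanup.append(on_cleanup)
```

Relatedly, the new `monitoring/test/pytest.ini` sets `asyncio_mode = auto`, which is why the `@pytest.mark.asyncio` decorator (and the `pytest` import) can be dropped from the async test in the next hunk.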
--- a/monitoring/test/test_monitoring.py +++ b/monitoring/test/test_monitoring.py @@ -1,8 +1,6 @@ import asyncio import logging -import pytest - from hailtop.auth import hail_credentials from hailtop.config import get_deploy_config from hailtop.httpx import client_session @@ -12,7 +10,6 @@ log = logging.getLogger(__name__) -@pytest.mark.asyncio async def test_billing_monitoring(): deploy_config = get_deploy_config() monitoring_deploy_config_url = deploy_config.url('monitoring', '/api/v1alpha/billing') diff --git a/prometheus/prometheus.yaml b/prometheus/prometheus.yaml index ede6863d7d7..a8150f07c01 100644 --- a/prometheus/prometheus.yaml +++ b/prometheus/prometheus.yaml @@ -296,7 +296,7 @@ spec: - "/bin/prometheus" - "--config.file=/etc/prometheus/prometheus.yml" - "--storage.tsdb.path=/prometheus" - - "--storage.tsdb.retention.time=15d" + - "--storage.tsdb.retention.time=90d" - "--web.console.libraries=/usr/share/prometheus/console_libraries" - "--web.console.templates=/usr/share/prometheus/consoles" - "--web.enable-lifecycle" diff --git a/pyproject.toml b/pyproject.toml index 4fa299f1a71..ff30c18de32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,8 +1,3 @@ -[tool.black] -line-length = 120 -skip-string-normalization = true -force-exclude = 'hail|datasets|sql|\.mypy' - [tool.ruff] line-length = 120 select = ["F", "E", "W", "I", "PL", "RUF"] @@ -19,20 +14,13 @@ ignore = [ "PLC1901", # ` != ''` can be simplified to `` as an empty string is falsey "PLR2004", # Magic value used in comparison ] -extend-exclude = ['sql'] +extend-exclude = ['sql', 'datasets'] force-exclude = true [tool.ruff.isort] known-first-party = ["auth", "batch", "ci", "gear", "hailtop", "monitoring", "website", "web_common"] [tool.ruff.per-file-ignores] -"hail/python/hail/**/*" = ["I", "PL", "RUF"] -"hail/python/hailtop/**/*" = ["I"] -"hail/python/test/**/*" = ["ALL"] -"hail/python/cluster-tests/**/*" = ["ALL"] -"hail/python/setup-hailtop.py" = ["ALL"] -"hail/python/setup.py" = ["ALL"] -"hail/scripts/update-terra-image.py" = ["ALL"] "hail/src/**/*" = ["ALL"] "benchmark/**/*" = ["ALL"] "datasets/**/*" = ["ALL"] @@ -40,6 +28,10 @@ known-first-party = ["auth", "batch", "ci", "gear", "hailtop", "monitoring", "we "docker/**/*" = ["ALL"] "query/**/*" = ["ALL"] +[tool.ruff.format] +preview = true +quote-style = "preserve" + [pytest] timeout = 120 diff --git a/query/.clang-format b/query/.clang-format deleted file mode 100644 index 6c5f77d0c8b..00000000000 --- a/query/.clang-format +++ /dev/null @@ -1,7 +0,0 @@ -BasedOnStyle: LLVM -ColumnLimit: 100 - -AlignOperands: AlignAfterOperator -AlwaysBreakTemplateDeclarations: Yes -BreakBeforeBinaryOperators: NonAssignment -QualifierAlignment: Right diff --git a/query/.clang-tidy b/query/.clang-tidy deleted file mode 100644 index 37d46431ada..00000000000 --- a/query/.clang-tidy +++ /dev/null @@ -1,73 +0,0 @@ -Checks: > - -*, - clang-diagnostic-*, - llvm-*, - bugprone-*, - -bugprone-bad-signal-to-kill-thread, - -bugprone-easily-swappable-parameters, - -bugprone-exception-escape, - -bugprone-no-escape, - -bugprone-not-null-terminated-result, - -bugprone-throw-keyword-missing, - -bugprone-unchecked-optional-access, - -bugprone-unhandled-exception-at-new, - cppcoreguidelines-*, - -cppcoreguidelines-avoid-magic-numbers, - -cppcoreguidelines-pro-bounds-array-to-pointer-decay, - misc-*, - -misc-confusable-identifiers, - -misc-unused-parameters, - -misc-non-private-member-variables-in-classes, - -misc-no-recursion, - modernize-*, - -modernize-avoid-bind, - -modernize-macro-to-enum, - 
-modernize-redundant-void-arg, - -modernize-replace-auto-ptr, - -modernize-replace-disallow-copy-and-assign-macro, - -modernize-replace-random-shuffle, - -modernize-use-nodiscard, - -modernize-use-noexcept, - -modernize-use-uncaught-exceptions, - performance-*, - readability-*, - -readability-braces-around-statements, - -readability-function-*, - -readability-identifier-length, - -readability-magic-numbers - -CheckOptions: - - key: readability-identifier-naming.ClassCase - value: CamelCase - - key: readability-identifier-naming.EnumCase - value: CamelCase - - key: readability-identifier-naming.FunctionCase - value: camelBack - # Exclude from scanning as this is an exported symbol used for fuzzing - # throughout the code base. - - key: readability-identifier-naming.FunctionIgnoredRegexp - value: "LLVMFuzzerTestOneInput" - - key: readability-identifier-naming.MemberCase - value: camelBack - - key: readability-identifier-naming.ParameterCase - value: camelBack - - key: readability-identifier-naming.UnionCase - value: CamelCase - - key: readability-identifier-naming.VariableCase - value: camelBack - - key: readability-identifier-naming.IgnoreMainLikeFunctions - value: 1 - - key: readability-redundant-member-init.IgnoreBaseInCopyConstructors - value: 1 - - key: modernize-use-default-member-init.UseAssignment - value: 1 - - key: bugprone-argument-comment.StrictMode - value: true - - key: bugprone-argument-comment.CommentBoolLiterals - value: true - - key: bugprone-argument-comment.CommentNullPtrs - value: true - - key: bugprone-argument-comment.IgnoreSingleArgument - value: true - -HeaderFilterRegex: include/hail.*\.h$ diff --git a/query/.clangd b/query/.clangd deleted file mode 100644 index efdf67c928b..00000000000 --- a/query/.clangd +++ /dev/null @@ -1,4 +0,0 @@ -Diagnostics: - UnusedIncludes: Strict -Hover: - ShowAKA: No diff --git a/query/.dir-locals.el b/query/.dir-locals.el deleted file mode 100644 index 10c00c35414..00000000000 --- a/query/.dir-locals.el +++ /dev/null @@ -1,4 +0,0 @@ -;;; Directory Local Variables -;;; For more information see (info "(emacs) Directory Variables") - -((c-mode . ((mode . c++)))) diff --git a/query/CMakeLists.txt b/query/CMakeLists.txt deleted file mode 100644 index 4f872a5a975..00000000000 --- a/query/CMakeLists.txt +++ /dev/null @@ -1,65 +0,0 @@ -cmake_minimum_required(VERSION 3.16) -project(hail-mlir-dialect LANGUAGES CXX C) - -# CMP0116: Ninja generators transform `DEPFILE`s from `add_custom_command()` -# New in CMake 3.20. 
https://cmake.org/cmake/help/latest/policy/CMP0116.html -if(POLICY CMP0116) - cmake_policy(SET CMP0116 OLD) -endif() - -set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON) - -set(CMAKE_CXX_STANDARD 20 CACHE STRING "C++ standard to conform to") - -find_package(MLIR REQUIRED CONFIG) - -message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") -message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") - -set(LLVM_RUNTIME_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/bin) -set(LLVM_LIBRARY_OUTPUT_INTDIR ${CMAKE_BINARY_DIR}/lib) -set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) - -set(HAIL_MAIN_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR} ) -set(HAIL_MAIN_INCLUDE_DIR ${MLIR_MAIN_SRC_DIR}/include ) - -list(INSERT CMAKE_MODULE_PATH 0 - "${HAIL_MAIN_SRC_DIR}/cmake/modules" -) - -include(AddHail) - -list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") -list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") -include(TableGen) -include(AddLLVM) -include(AddMLIR) -include(HandleLLVMOptions) - -if(MLIR_ENABLE_BINDINGS_PYTHON) - include(MLIRDetectPythonEnv) - mlir_configure_python_dev_packages() -endif() - -include_directories(${LLVM_INCLUDE_DIRS}) -include_directories(${MLIR_INCLUDE_DIRS}) -include_directories(${PROJECT_SOURCE_DIR}/include) -include_directories(${PROJECT_BINARY_DIR}/include) -link_directories(${LLVM_BUILD_LIBRARY_DIR}) -add_definitions(${LLVM_DEFINITIONS}) - -# supress a warning on clang -add_compile_options($<$,$>:-Wno-ambiguous-reversed-operator>) - -if(HAIL_USE_CLANG_TIDY AND NOT CMAKE_CXX_CLANG_TIDY) - find_program(CLANG_TIDY "clang-tidy" HINTS ${LLVM_TOOLS_BINARY_DIR}) - if(CLANG_TIDY) - message(STATUS "Found clang-tidy: ${CLANG_TIDY}") - set(CMAKE_CXX_CLANG_TIDY ${CLANG_TIDY} --use-color --warnings-as-errors=*) - endif() -endif() - -add_subdirectory(include/hail) -add_subdirectory(lib) -add_subdirectory(tools) -add_subdirectory(test) diff --git a/query/Dockerfile.query-build b/query/Dockerfile.query-build deleted file mode 100644 index c3a30b41315..00000000000 --- a/query/Dockerfile.query-build +++ /dev/null @@ -1,15 +0,0 @@ -ARG BASE_IMAGE={{ hail_ubuntu_image.image }} -FROM $BASE_IMAGE - -RUN curl 'https://apt.llvm.org/llvm-snapshot.gpg.key' \ - | gpg --dearmor >> /usr/share/keyrings/llvm-snapshot-keyring.gpg && \ - echo 'deb [signed-by=/usr/share/keyrings/llvm-snapshot-keyring.gpg] http://apt.llvm.org/focal/ llvm-toolchain-focal-15 main' \ - >> /etc/apt/sources.list && \ - hail-apt-get-install \ - build-essential \ - cmake ninja-build \ - llvm-15-dev \ - clang-15 clang-tools-15 clang-format-15 clang-tidy-15 \ - libmlir-15-dev mlir-15-tools \ - python3-pip && \ - hail-pip-install lit diff --git a/query/README.md b/query/README.md deleted file mode 100644 index 03c02314522..00000000000 --- a/query/README.md +++ /dev/null @@ -1,244 +0,0 @@ -# [MLIR](https://mlir.llvm.org) + [hail](https://hail.is) = 🚀🧬? - -## Dependencies - -You will need `cmake` and `ninja`. If you're on OS X, use https://brew.sh: - -```sh -brew install cmake ninja -``` - -## Building/Installing LLVM and MLIR -Obviously, update the paths for your environment. `$HAIL_DIR` is the root of the -hail repository. - -```sh -git clone https://github.com/llvm/llvm-project.git -mkdir llvm-project/build -cd llvm-project/build -git checkout llvmorg-15.0.3 # latest stable LLVM/MLIR release - -# Some notes: -# 1. -G Ninja generates a build.ninja file rather than makefiles it's not -# required but is recommended by LLVM -# 2. The CMAKE_INSTALL_PREFIX I put here is a subdirectory of the mlir-hail -# (this repo's) root. 
If you do this, add that directory to -# .git/info/exclude and it will be like adding it to a gitignore -# 3. On linux, using lld via -DLLVM_ENABLE_LLD=ON can speed up the build due -# to faster linking. -# -# The -DLLVM_BUILD_EXAMPLES=ON flag is optional. -cmake ../llvm -G Ninja \ - -DLLVM_ENABLE_PROJECTS="clang;clang-tools-extra;mlir" \ - -DLLVM_BUILD_EXAMPLES=ON \ - -DLLVM_TARGETS_TO_BUILD="AArch64;X86;NVPTX;AMDGPU" \ - -DCMAKE_BUILD_TYPE=Release \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ - -DCMAKE_INSTALL_PREFIX=$HAIL_DIR/query/.dist/llvm \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=1 -ninja # this will take a while -ninja install -``` - - -## Building Hail's native compiler - -To set up the build, from this directory: - -```sh -mkdir build -cd build -# You can set CC and CXX or use -DCMAKE_C_COMPILER and -DCMAKE_CXX_COMPILER to -# change the C and C++ compilers used to build this project. -cmake .. -G Ninja \ - -DCMAKE_BUILD_TYPE=RelWithDebInfo \ - -DMLIR_DIR=$HAIL_DIR/.dist/llvm/lib/cmake/mlir \ - -DCMAKE_EXPORT_COMPILE_COMMANDS=1 \ - -DLLVM_BUILD_BINARY_DIR=~/src/llvm-project/build/bin - # ^ this argument is necessary to find llvm-lit and FileCheck for the tests - # they are skipped and a warning printed otherwise -``` - -To build: -```sh -cd build -ninja -``` - -To also run `clang-tidy` when building, pass `-DCMAKE_CXX_CLANG_TIDY=clang-tidy` or -`-DHAIL_USE_CLANG_TIDY=ON` to cmake. This can drastically increase build times, -so it may not be suitable for local development. - -## Running the tests - -To run all tests (from the `hail/query` directory): - -```sh -ninja check-hail -C ./build -``` - -To run a single test (substitute `int.mlir` for the name of the test file to run): - -```sh -cd ./build/test && llvm-lit ./int.mlir -v || true && cd - -``` - - -## Setting up your editor - -There are several [language servers](https://microsoft.github.io/language-server-protocol) -available that can vastly improve your editor experience when working with MLIR, -namely [`clangd` for C++ files](https://clangd.llvm.org) and [the MLIR-specific -language servers for `.mlir`, `.pdll`, and `.td` files](https://mlir.llvm.org/docs/Tools/MLIRLSP). -Below are instructions for setting them up with a few different editors. - - -### Visual Studio Code - -If you're on macOS, you can install [VS Code](https://code.visualstudio.com/) -via Homebrew: - -```sh -brew install --cask visual-studio-code -``` - -Open VS Code and install the -[clangd](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-clangd) -and [MLIR](https://marketplace.visualstudio.com/items?itemName=llvm-vs-code-extensions.vscode-mlir) -extensions, either by clicking the `Install` button on their webpages, or by -navigating to the `Extensions` icon in the bar on the left-hand side of the editor -window, searching for them by name, and clicking the `Install` button. - -You will need [`jq`](https://stedolan.github.io/jq/) installed, if you don't -already have it. If you're on macOS, you can install it via Homebrew: - -```sh -brew install jq -``` - -The location of the `settings.json` file on your platform can be found -[here](https://code.visualstudio.com/docs/getstarted/settings#_settings-file-locations). 
- -```sh -# If you are not on macOS, replace this with the location of the `settings.json` -# file on your platform -VS_CODE_SETTINGS="$HOME/Library/Application Support/Code/User" -cp "${VS_CODE_SETTINGS}/settings.json" "${VS_CODE_SETTINGS}/settings.json.bak" -jq " - .[\"mlir.pdll_server_path\"] = \"${LLVM_BUILD_BIN}/mlir-pdll-lsp-server\" -| .[\"mlir.server_path\"] = \"${LLVM_BUILD_BIN}/mlir-lsp-server\" -| .[\"mlir.tablegen_server_path\"] = \"${LLVM_BUILD_BIN}/tblgen-lsp-server\" -| .[\"mlir.tablegen_compilation_databases\"] = [ - \"${HAIL_NATIVE_COMPILER_BUILD}/tablegen_compile_commands.yml\" - ] -" "${VS_CODE_SETTINGS}/settings.json.bak" > "${VS_CODE_SETTINGS}/settings.json" -``` - -Close and reopen VS Code, and the language servers should work; you can test this -by opening a `.td` file, right clicking a name `include`d from the MLIR project -(such as `Dialect`), and choosing `Go to Definition`. - - -### Neovim - -If you're on macOS, you can install [Neovim](https://neovim.io/) via Homebrew: - -```sh -brew install nvim -``` - -[Syntax highlighting](https://github.com/sheerun/vim-polyglot) and -[language server configuration](https://github.com/neovim/nvim-lspconfig) support -are provided by Neovim plugins. First, we'll install [packer](https://github.com/wbthomason/packer.nvim), -which will manage the installation of those plugins for us. - -```sh -git clone --depth 1 https://github.com/wbthomason/packer.nvim ~/.local/share/nvim/site/pack/packer/start/packer.nvim -``` - -Next, place the following configuration in the `~/.config/nvim/init.lua` file. If you -already have a Neovim or Vim configuration written in Vimscript, you can migrate -it to Lua by copy-pasting it into the invocation of `vim.cmd`. - -```lua -vim.cmd([[ -" Vimscript configuration can be placed on multiple lines between these brackets -]]) - -local config = { - globals = { - mapleader = ' ', - }, - options = { - overrides = { - number = true, - signcolumn = 'number', - }, - }, - keymap = { - -- Add or replace keybinds here - insert = { - [''] = '', - }, - normal = { - ['do'] = vim.diagnostic.open_float, - ['[d'] = vim.diagnostic.goto_prev, - [']d'] = vim.diagnostic.goto_next, - ['da'] = vim.lsp.buf.code_action, - ['dg'] = vim.lsp.buf.definition, - ['dk'] = vim.lsp.buf.hover, - ['ds'] = vim.lsp.buf.signature_help, - ['dr'] = vim.lsp.buf.rename, - ['df'] = vim.lsp.buf.references, - }, - }, -} - -for k, v in pairs(config.globals) do vim.g[k] = v end -for k, v in pairs(config.options.overrides) do vim.opt[k] = v end -for k, v in pairs(config.keymap.insert) do vim.keymap.set('i', k, v, { noremap = true }) end -for k, v in pairs(config.keymap.normal) do vim.keymap.set('n', k, v, { noremap = true }) end - -require('packer').startup(function(use) - use 'neovim/nvim-lspconfig' - use 'sheerun/vim-polyglot' -end) - -local lspconfig = require('lspconfig') -lspconfig.clangd.setup({}) -lspconfig.mlir_lsp_server.setup({}) -lspconfig.mlir_pdll_lsp_server.setup({}) -lspconfig.tblgen_lsp_server.setup({ - cmd = { - "tblgen-lsp-server", - "--tablegen-compilation-database=HAIL_NATIVE_COMPILER_BUILD/tablegen_compile_commands.yml", - }, -}) -``` - -After saving the configuration file, run the following command to replace -`HAIL_NATIVE_COMPILER_BUILD` with its value: - -```sh -sed -i .bak "s;HAIL_NATIVE_COMPILER_BUILD;${HAIL_NATIVE_COMPILER_BUILD};g" ~/.config/nvim/init.lua -``` - -This configuration provides keybinds for many of the major -[LSP features](https://neovim.io/doc/user/lsp.html#lsp-quickstart) that Neovim -supports. 
Keybinds can be added or replaced in the `config.keymap` table. - -The absolute path to the `tablegen_compile_commands.yml` file will need to be -manually specified in order for the `tblgen-lsp-server` to function correctly. -Replace the `$HAIL_NATIVE_COMPILER_BUILD` variable in the configuration -with its value, which can be determined by running `echo $HAIL_NATIVE_COMPILER_BUILD`. - -Close and reopen Neovim, and install the plugins: - -```vim -:PackerInstall -``` - -The language servers should work; you can test this by opening a `.td` file, -placing the cursor on a name `include`d from the MLIR project (such as `Dialect`), -and calling the `vim.lsp.buf.definition` function (`gd` in the configuration provided). diff --git a/query/cmake/modules/AddHail.cmake b/query/cmake/modules/AddHail.cmake deleted file mode 100644 index 77e4ed0c45d..00000000000 --- a/query/cmake/modules/AddHail.cmake +++ /dev/null @@ -1,5 +0,0 @@ -# Declare the library associated with a dialect. -function(add_hail_dialect_library name) - set_property(GLOBAL APPEND PROPERTY HAIL_DIALECT_LIBS ${name}) - add_mlir_library(${ARGV} DEPENDS mlir-headers) -endfunction() diff --git a/query/include/hail/Analysis/MissingnessAnalysis.h b/query/include/hail/Analysis/MissingnessAnalysis.h deleted file mode 100644 index 59b5ba28adf..00000000000 --- a/query/include/hail/Analysis/MissingnessAnalysis.h +++ /dev/null @@ -1,62 +0,0 @@ -#ifndef HAIL_ANALYSIS_MISSINGNESSANALYSIS_H -#define HAIL_ANALYSIS_MISSINGNESSANALYSIS_H - -#include "mlir/Analysis/DataFlow/SparseAnalysis.h" - -namespace hail::ir { - -//===----------------------------------------------------------------------===// -// MissingnessValue -//===----------------------------------------------------------------------===// - -class MissingnessValue { -public: - enum State { - Present, // value is always present - Unknown, // value might be present or missing - Missing // value is always missing - }; - - MissingnessValue(State state) : state(state) {} - auto isMissing() const -> bool { return state == Missing; } - auto isPresent() const -> bool { return state == Present; } - auto getState() const -> State { return state; } - - void setMissing() const { join(*this, Missing); } - void setPresent() const { join(*this, Present); } - - auto operator==(MissingnessValue const &rhs) const -> bool { return state == rhs.state; } - - void print(llvm::raw_ostream &os) const; - - static auto getPessimisticValueState(mlir::Value value) -> MissingnessValue { return {Unknown}; } - - static auto join(MissingnessValue const &lhs, MissingnessValue const &rhs) -> MissingnessValue { - return lhs == rhs ? lhs : MissingnessValue(Unknown); - } - -private: - State state; -}; - -//===----------------------------------------------------------------------===// -// SparseConstantPropagation -//===----------------------------------------------------------------------===// - -/// This analysis implements sparse constant propagation, which attempts to -/// determine constant-valued results for operations using constant-valued -/// operands, by speculatively folding operations. When combined with dead-code -/// analysis, this becomes sparse conditional constant propagation (SCCP). 
-class MissingnessAnalysis - : public mlir::dataflow::SparseDataFlowAnalysis> { -public: - using SparseDataFlowAnalysis::SparseDataFlowAnalysis; - - void visitOperation(mlir::Operation *op, - llvm::ArrayRef const *> operands, - llvm::ArrayRef *> results) override; -}; - -} // namespace hail::ir - -#endif // HAIL_ANALYSIS_MISSINGNESSANALYSIS_H diff --git a/query/include/hail/Analysis/MissingnessAwareConstantPropagationAnalysis.h b/query/include/hail/Analysis/MissingnessAwareConstantPropagationAnalysis.h deleted file mode 100644 index 98e82be9754..00000000000 --- a/query/include/hail/Analysis/MissingnessAwareConstantPropagationAnalysis.h +++ /dev/null @@ -1,50 +0,0 @@ -//===- ConstantPropagationAnalysis.h - Constant propagation analysis ------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements constant propagation analysis. In this file are defined -// the lattice value class that represents constant values in the program and -// a sparse constant propagation analysis that uses operation folders to -// speculate about constant values in the program. -// -//===----------------------------------------------------------------------===// - -#ifndef HAIL_ANALYSIS_MISSINGNESSAWARECONSTANTPROPAGATIONANALYSIS_H -#define HAIL_ANALYSIS_MISSINGNESSAWARECONSTANTPROPAGATIONANALYSIS_H - -#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" -#include "mlir/Analysis/DataFlow/SparseAnalysis.h" - -namespace hail::ir { - -//===----------------------------------------------------------------------===// -// SparseConstantPropagation -//===----------------------------------------------------------------------===// - -/// This analysis implements sparse constant propagation, which attempts to -/// determine constant-valued results for operations using constant-valued -/// operands, by speculatively folding operations. When combined with dead-code -/// analysis, this becomes sparse conditional constant propagation (SCCP). -/// -/// It is intended to also be combined with requiredness analysis. For that to -/// be sound, it must satisfy the invariant that whenever a value is inferred to -/// me missing, its inferred constant value is uninitialized. 
-class MissingnessAwareConstantPropagation - : public mlir::dataflow::SparseDataFlowAnalysis< - mlir::dataflow::Lattice> { -public: - using SparseDataFlowAnalysis::SparseDataFlowAnalysis; - - void visitOperation( - mlir::Operation *op, - llvm::ArrayRef const *> operands, - llvm::ArrayRef *> results) override; -}; - -} // namespace hail::ir - -#endif // HAIL_ANALYSIS_MISSINGNESSAWARECONSTANTPROPAGATIONANALYSIS_H diff --git a/query/include/hail/CMakeLists.txt b/query/include/hail/CMakeLists.txt deleted file mode 100644 index 495310c6626..00000000000 --- a/query/include/hail/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -add_subdirectory(Conversion) -add_subdirectory(Dialect) -add_subdirectory(Transforms) diff --git a/query/include/hail/Conversion/CMakeLists.txt b/query/include/hail/Conversion/CMakeLists.txt deleted file mode 100644 index bf095af835b..00000000000 --- a/query/include/hail/Conversion/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name Conversion) -add_public_tablegen_target(HailConversionPassIncGen) diff --git a/query/include/hail/Conversion/CPSToCF/CPSToCF.h b/query/include/hail/Conversion/CPSToCF/CPSToCF.h deleted file mode 100644 index 592d7285641..00000000000 --- a/query/include/hail/Conversion/CPSToCF/CPSToCF.h +++ /dev/null @@ -1,22 +0,0 @@ -#ifndef HAIL_CONVERSION_CPSTOCF_CPSTOCF_H -#define HAIL_CONVERSION_CPSTOCF_CPSTOCF_H - -#include "mlir/Pass/Pass.h" - -#include - -namespace mlir { - -class Pass; - -} // namespace mlir - -namespace hail::ir { - -/// Creates a pass to convert continuation-based control flow to CFG -/// branch-based operation in the ControlFlow dialect. -auto createCPSToCFPass() -> std::unique_ptr; - -} // namespace hail::ir - -#endif // HAIL_CONVERSION_CPSTOCF_CPSTOCF_H diff --git a/query/include/hail/Conversion/LowerOption/LowerOption.h b/query/include/hail/Conversion/LowerOption/LowerOption.h deleted file mode 100644 index ab7c7c4fb30..00000000000 --- a/query/include/hail/Conversion/LowerOption/LowerOption.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef HAIL_CONVERSION_LOWEROPTION_LOWEROPTION_H -#define HAIL_CONVERSION_LOWEROPTION_LOWEROPTION_H - -#include - -namespace mlir { - -class Pass; -class RewritePatternSet; - -} // namespace mlir - -namespace hail::ir { - -void populateLowerOptionConversionPatterns(mlir::RewritePatternSet &patterns); - -/// Creates a pass to lower the Option type to bool and values ssa-values, using CPS for control -/// flow -auto createLowerOptionPass() -> std::unique_ptr; - -} // namespace hail::ir - -#endif // HAIL_CONVERSION_LOWEROPTION_LOWEROPTION_H diff --git a/query/include/hail/Conversion/LowerSandbox/LowerSandbox.h b/query/include/hail/Conversion/LowerSandbox/LowerSandbox.h deleted file mode 100644 index 9265dbcd85e..00000000000 --- a/query/include/hail/Conversion/LowerSandbox/LowerSandbox.h +++ /dev/null @@ -1,25 +0,0 @@ -#ifndef HAIL_CONVERSION_LOWERSANDBOX_LOWERSANDBOX_H -#define HAIL_CONVERSION_LOWERSANDBOX_LOWERSANDBOX_H - -#include - -namespace mlir { - -class Pass; -class RewritePatternSet; - -} // namespace mlir - -namespace hail::ir { - -/// Collect a set of patterns to convert SCF operations to CFG branch-based -/// operations within the ControlFlow dialect. -void populateLowerSandboxConversionPatterns(mlir::RewritePatternSet &patterns); - -/// Creates a pass to convert SCF operations to CFG branch-based operation in -/// the ControlFlow dialect. 
-auto createLowerSandboxPass() -> std::unique_ptr; - -} // namespace hail::ir - -#endif // HAIL_CONVERSION_LOWERSANDBOX_LOWERSANDBOX_H diff --git a/query/include/hail/Conversion/LowerToLLVM/LowerToLLVM.h b/query/include/hail/Conversion/LowerToLLVM/LowerToLLVM.h deleted file mode 100644 index 54036183261..00000000000 --- a/query/include/hail/Conversion/LowerToLLVM/LowerToLLVM.h +++ /dev/null @@ -1,23 +0,0 @@ -#ifndef HAIL_CONVERSION_LOWERTOLLVM_LOWERTOLLVM_H -#define HAIL_CONVERSION_LOWERTOLLVM_LOWERTOLLVM_H - -#include - -namespace mlir { -class Pass; -class RewritePatternSet; -} // namespace mlir - -namespace hail::ir { - -/// Collect a set of patterns to convert SCF operations to CFG branch-based -/// operations within the ControlFlow dialect. -void populateLowerToLLVMConversionPatterns(mlir::RewritePatternSet &patterns); - -/// Creates a pass to convert SCF operations to CFG branch-based operation in -/// the ControlFlow dialect. -auto createLowerToLLVMPass() -> std::unique_ptr; - -} // namespace hail::ir - -#endif // HAIL_CONVERSION_LOWERTOLLVM_LOWERTOLLVM_H diff --git a/query/include/hail/Conversion/OptionToGenericOption/OptionToGenericOption.h b/query/include/hail/Conversion/OptionToGenericOption/OptionToGenericOption.h deleted file mode 100644 index 9f1f6838084..00000000000 --- a/query/include/hail/Conversion/OptionToGenericOption/OptionToGenericOption.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef HAIL_CONVERSION_OPTIONTOGENERICOPTION_OPTIONTOGENERICOPTION_H -#define HAIL_CONVERSION_OPTIONTOGENERICOPTION_OPTIONTOGENERICOPTION_H - -#include - -namespace mlir { - -class Pass; -class RewritePatternSet; - -} // namespace mlir - -namespace hail::ir { - -void populateOptionToGenericOptionConversionPatterns(mlir::RewritePatternSet &patterns); - -auto createOptionToGenericOptionPass() -> std::unique_ptr; - -} // namespace hail::ir - -#endif // HAIL_CONVERSION_OPTIONTOGENERICOPTION_OPTIONTOGENERICOPTION_H diff --git a/query/include/hail/Conversion/Passes.h b/query/include/hail/Conversion/Passes.h deleted file mode 100644 index 5f367db0360..00000000000 --- a/query/include/hail/Conversion/Passes.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef HAIL_CONVERSION_PASSES_H -#define HAIL_CONVERSION_PASSES_H - -#include "hail/Conversion/CPSToCF/CPSToCF.h" -#include "hail/Conversion/LowerOption/LowerOption.h" -#include "hail/Conversion/LowerSandbox/LowerSandbox.h" -#include "hail/Conversion/LowerToLLVM/LowerToLLVM.h" -#include "hail/Conversion/OptionToGenericOption/OptionToGenericOption.h" -#include "mlir/Pass/PassRegistry.h" - -namespace hail { - -/// Generate the code for registering conversion passes. 
-#define GEN_PASS_REGISTRATION -#include "hail/Conversion/Passes.h.inc" - -} // namespace hail - -#endif // HAIL_CONVERSION_PASSES_H diff --git a/query/include/hail/Conversion/Passes.td b/query/include/hail/Conversion/Passes.td deleted file mode 100644 index 2154ee09151..00000000000 --- a/query/include/hail/Conversion/Passes.td +++ /dev/null @@ -1,63 +0,0 @@ -#ifndef HAIL_CONVERSION_PASSES -#define HAIL_CONVERSION_PASSES - -include "mlir/Pass/PassBase.td" - - -def CPSToCF : Pass<"cps-to-cf"> { - let summary = "Lower CPS-style control flow to the CFG-based cf dialect"; - let description = [{ - }]; - let constructor = "hail::ir::createCPSToCFPass()"; - let dependentDialects = [ - "mlir::cf::ControlFlowDialect", - ]; -} - -def LowerOption : Pass<"lower-option"> { - let summary = "Lower option type"; - let description = [{ - Lower the option type to bool and values ssa-values, using the CPS dialect for control flow - }]; - let constructor = "hail::ir::createLowerOptionPass()"; - let dependentDialects = [ - "mlir::arith::ArithmeticDialect", - ]; -} - -def LowerSandbox : Pass<"lower-sandbox"> { - let summary = "Lower sandbox operations to arith"; - let description = [{ - Convert operations and types from the sandbox dialect into operations and - types from the arithmetic dialect. - }]; - let constructor = "hail::ir::createLowerSandboxPass()"; - let dependentDialects = [ - "mlir::arith::ArithmeticDialect", - "mlir::tensor::TensorDialect" - ]; -} - -def LowerToLLVM : Pass<"lower-to-llvm", "mlir::ModuleOp"> { - let summary = "Lower all remaining operations to llvm"; - let description = [{ - Convert operations and types from the SCF, Arithmetic, and Func dialects, - plus the sandbox print operation, to the LLVM dialect. - }]; - let constructor = "hail::ir::createLowerToLLVMPass()"; - let dependentDialects = [ - "mlir::LLVM::LLVMDialect", - "mlir::scf::SCFDialect" - ]; -} - -def OptionToGenericOption : Pass<"option-to-generic-option", "mlir::ModuleOp"> { - let summary = "Convert all option ops to generic 'construct'/'destruct' ops"; - let description = [{ - Convert operations and types from the SCF, Arithmetic, and Func dialects, - plus the sandbox print operation, to the LLVM dialect. 
- }]; - let constructor = "hail::ir::createOptionToGenericOptionPass()"; -} - -#endif // HAIL_CONVERSION_PASSES diff --git a/query/include/hail/Dialect/CMakeLists.txt b/query/include/hail/Dialect/CMakeLists.txt deleted file mode 100644 index df05150a143..00000000000 --- a/query/include/hail/Dialect/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -add_subdirectory(CPS) -add_subdirectory(Missing) -add_subdirectory(Option) -add_subdirectory(Sandbox) diff --git a/query/include/hail/Dialect/CPS/CMakeLists.txt b/query/include/hail/Dialect/CPS/CMakeLists.txt deleted file mode 100644 index f33061b2d87..00000000000 --- a/query/include/hail/Dialect/CPS/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(IR) diff --git a/query/include/hail/Dialect/CPS/IR/CMakeLists.txt b/query/include/hail/Dialect/CPS/IR/CMakeLists.txt deleted file mode 100644 index 2e73166dcdf..00000000000 --- a/query/include/hail/Dialect/CPS/IR/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS CPSOps.td) -add_mlir_dialect(CPSOps cps) diff --git a/query/include/hail/Dialect/CPS/IR/CPS.h b/query/include/hail/Dialect/CPS/IR/CPS.h deleted file mode 100644 index 96455abc75b..00000000000 --- a/query/include/hail/Dialect/CPS/IR/CPS.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef HAIL_DIALECT_CPS_IR_CPS_H -#define HAIL_DIALECT_CPS_IR_CPS_H - -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" - -#include "mlir/IR/BuiltinTypes.h" - -#include "hail/Dialect/CPS/IR/CPSOpsDialect.h.inc" - -#define GET_TYPEDEF_CLASSES -#include "hail/Dialect/CPS/IR/CPSOpsTypes.h.inc" - -#define GET_OP_CLASSES -#include "hail/Dialect/CPS/IR/CPSOps.h.inc" - -#endif // HAIL_DIALECT_CPS_IR_CPS_H diff --git a/query/include/hail/Dialect/CPS/IR/CPSBase.td b/query/include/hail/Dialect/CPS/IR/CPSBase.td deleted file mode 100644 index 390ae6e56ff..00000000000 --- a/query/include/hail/Dialect/CPS/IR/CPSBase.td +++ /dev/null @@ -1,36 +0,0 @@ -#ifndef DIALECT_CPS_IR_CPSBASE -#define DIALECT_CPS_IR_CPSBASE - -include "mlir/IR/AttrTypeBase.td" - -def CPS_Dialect : Dialect { - let name = "cps"; - let summary = "Allows for blocks of code in continuation passing style"; - let cppNamespace = "::hail::ir"; - let useDefaultTypePrinterParser = 1; -} - -class CPS_Type traits = []> - : TypeDef { - let mnemonic = typeMnemonic; -} - -def CPS_ContType : CPS_Type<"Continuation", "cont"> { - let summary = "Continuation type"; - let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs); - // The declarative format "`<` $inputs `>`" doesn't support the empty list of types - let hasCustomAssemblyFormat = 1; - let builders = [ - TypeBuilder<(ins CArg<"mlir::TypeRange", "mlir::TypeRange()">:$inputTypes), [{ - return $_get($_ctxt, inputTypes); - }]> - ]; - let skipDefaultBuilders = 1; - let genStorageClass = 0; -} - -def CPS_EmptyContType : Type()">> { - let builderCall = "$_builder.getType()"; -} - -#endif // DIALECT_CPS_IR_CPSBASE diff --git a/query/include/hail/Dialect/CPS/IR/CPSOps.td b/query/include/hail/Dialect/CPS/IR/CPSOps.td deleted file mode 100644 index 168d6e331a1..00000000000 --- a/query/include/hail/Dialect/CPS/IR/CPSOps.td +++ /dev/null @@ -1,92 +0,0 @@ -#ifndef DIALECT_CPS_IR_CPSOPS -#define DIALECT_CPS_IR_CPSOPS - -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/BuiltinTypes.td" -include "mlir/IR/OpBase.td" - -include "hail/Dialect/CPS/IR/CPSBase.td" - -class CPS_Op traits = []> : - Op; - - -def CPS_CallCCOp : CPS_Op<"callcc", [RecursiveSideEffects, SingleBlock]> { - 
let summary = "Call with current continuation"; - let description = [{ - }]; - - let results = (outs Variadic:$results); - let regions = (region SizedRegion<1>:$body); - - let hasCustomAssemblyFormat = 1; - let hasCanonicalizer = 1; - - let skipDefaultBuilders = 1; - let builders = [ OpBuilder<(ins "mlir::TypeRange":$resultTypes)> ]; -} - -def CPS_DefContOp : CPS_Op<"cont", [NoSideEffect, SingleBlock]> { - let summary = "Define a continuation"; - let description = [{ - }]; - - let results = (outs CPS_ContType:$result); - let regions = (region SizedRegion<1>:$bodyRegion); - - let hasCustomAssemblyFormat = 1; - let hasRegionVerifier = 1; - - let skipDefaultBuilders = 1; - let builders = [ OpBuilder<(ins CArg<"mlir::TypeRange", "mlir::TypeRange()">:$argTypes)> ]; -} - -def CPS_ApplyContOp : CPS_Op<"apply", [ - Terminator, NoSideEffect, - TypesMatchWith<"argument types match continuation type", "args", "cont", "::hail::ir::ContinuationType::get($_ctxt, $_self)"> - ]> { - let summary = "Apply a continuation"; - let description = [{ - }]; - - let arguments = (ins CPS_ContType:$cont, Variadic:$args); - let assemblyFormat = "$cont (`(` $args^ `)` `:` type($args))? attr-dict"; - - let hasCanonicalizer = 1; -} - -def CPS_IfOp : CPS_Op<"if", [ - Terminator, NoSideEffect, AttrSizedOperandSegments, - TypesMatchWith<"", "trueArgs", "trueCont", "::hail::ir::ContinuationType::get($_ctxt, $_self)">, - TypesMatchWith<"", "falseArgs", "falseCont", "::hail::ir::ContinuationType::get($_ctxt, $_self)"> - ]> { - let summary = "if-then-else operation"; - let description = [{ - }]; - let arguments = (ins I1:$condition); - let arguments = (ins I1:$condition, - CPS_ContType:$trueCont, Variadic:$trueArgs, - CPS_ContType:$falseCont, Variadic:$falseArgs); - - let assemblyFormat = [{ - $condition `,` - $trueCont (`(` $trueArgs^ `:` type($trueArgs) `)`)? `,` - $falseCont (`(` $falseArgs^ `:` type($falseArgs) `)`)? 
- attr-dict - }]; -} - -// FIXME: find the right dialect for this to live -def CPS_UndefinedOp : CPS_Op<"undefined", [NoSideEffect]> { - let summary = "undefined"; - let description = [{ - }]; - - let results = (outs AnyType:$result); - - let assemblyFormat = [{ - attr-dict `:` type($result) - }]; -} - -#endif // DIALECT_CPS_IR_CPSOPS diff --git a/query/include/hail/Dialect/Missing/CMakeLists.txt b/query/include/hail/Dialect/Missing/CMakeLists.txt deleted file mode 100644 index f33061b2d87..00000000000 --- a/query/include/hail/Dialect/Missing/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(IR) diff --git a/query/include/hail/Dialect/Missing/IR/CMakeLists.txt b/query/include/hail/Dialect/Missing/IR/CMakeLists.txt deleted file mode 100644 index 9a484fd14ab..00000000000 --- a/query/include/hail/Dialect/Missing/IR/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS MissingOps.td) -add_mlir_dialect(MissingOps missing) diff --git a/query/include/hail/Dialect/Missing/IR/Missing.h b/query/include/hail/Dialect/Missing/IR/Missing.h deleted file mode 100644 index 4fe48512a8c..00000000000 --- a/query/include/hail/Dialect/Missing/IR/Missing.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef HAIL_DIALECT_MISSING_IR_MISSING_H -#define HAIL_DIALECT_MISSING_IR_MISSING_H - -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" - -#include "mlir/IR/BuiltinTypes.h" - -#include "hail/Dialect/Missing/IR/MissingOpsDialect.h.inc" - -#define GET_OP_CLASSES -#include "hail/Dialect/Missing/IR/MissingOps.h.inc" - -#endif // HAIL_DIALECT_MISSING_IR_MISSING_H diff --git a/query/include/hail/Dialect/Missing/IR/MissingOps.td b/query/include/hail/Dialect/Missing/IR/MissingOps.td deleted file mode 100644 index 6089befc825..00000000000 --- a/query/include/hail/Dialect/Missing/IR/MissingOps.td +++ /dev/null @@ -1,32 +0,0 @@ -#ifndef DIALECT_MISSING_MISSINGOPS -#define DIALECT_MISSING_MISSINGOPS - -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/OpBase.td" - -def Missing_Dialect : Dialect { - let name = "missing"; - let summary = "Allows inserting a new 'missing' value into all types"; - let cppNamespace = "::hail::ir"; -} - -class Missing_Op traits = []> : - Op; - -def Missing_MissingOp : Missing_Op<"missing", [NoSideEffect]> { - let results = (outs AnyType:$result); - let assemblyFormat = "attr-dict `:` type($result)"; - - // let builders = [ - // OpBuilder<(ins "::mlir::Attribute":$value, "::mlir::Type":$type), - // [{ build($_builder, $_state, type, value); }]>, - // ]; -} - -def Missing_IsMissingOp : Missing_Op<"is_missing", [NoSideEffect]> { - let arguments = (ins AnyType:$operand); - let results = (outs I1); - let assemblyFormat = "$operand attr-dict `:` type($operand)"; -} - -#endif // DIALECT_MISSING_MISSINGOPS diff --git a/query/include/hail/Dialect/Option/CMakeLists.txt b/query/include/hail/Dialect/Option/CMakeLists.txt deleted file mode 100644 index f33061b2d87..00000000000 --- a/query/include/hail/Dialect/Option/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(IR) diff --git a/query/include/hail/Dialect/Option/IR/CMakeLists.txt b/query/include/hail/Dialect/Option/IR/CMakeLists.txt deleted file mode 100644 index f349d053752..00000000000 --- a/query/include/hail/Dialect/Option/IR/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS OptionOps.td) -add_mlir_dialect(OptionOps option) diff --git a/query/include/hail/Dialect/Option/IR/Option.h 
b/query/include/hail/Dialect/Option/IR/Option.h deleted file mode 100644 index a827a89a80e..00000000000 --- a/query/include/hail/Dialect/Option/IR/Option.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef HAIL_DIALECT_OPTION_IR_OPTION_H -#define HAIL_DIALECT_OPTION_IR_OPTION_H - -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" - -#include "hail/Dialect/CPS/IR/CPS.h" - -#include "hail/Dialect/Option/IR/OptionOpsDialect.h.inc" - -#define GET_TYPEDEF_CLASSES -#include "hail/Dialect/Option/IR/OptionOpsTypes.h.inc" - -#define GET_OP_CLASSES -#include "hail/Dialect/Option/IR/OptionOps.h.inc" - -#endif // HAIL_DIALECT_OPTION_IR_OPTION_H diff --git a/query/include/hail/Dialect/Option/IR/OptionOps.td b/query/include/hail/Dialect/Option/IR/OptionOps.td deleted file mode 100644 index 99c4ad38e07..00000000000 --- a/query/include/hail/Dialect/Option/IR/OptionOps.td +++ /dev/null @@ -1,83 +0,0 @@ -#ifndef DIALECT_OPTION_IR_OPTIONOPS -#define DIALECT_OPTION_IR_OPTIONOPS - -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/OpBase.td" -include "mlir/IR/BuiltinTypes.td" - -include "hail/Dialect/CPS/IR/CPSBase.td" - - -def Option_Dialect : Dialect { - let name = "option"; - let summary = "Provides the Option type"; - let cppNamespace = "::hail::ir"; - let useDefaultTypePrinterParser = 1; - let dependentDialects = ["CPSDialect"]; -} - -class Option_Type traits = []> - : TypeDef { - let mnemonic = typeMnemonic; -} - -def Option_OptionType : Option_Type<"Option", "option"> { - let summary = "Option type"; - let parameters = (ins ArrayRefParameter<"mlir::Type">:$valueTypes); - // The declarative format "`<` $inputs `>`" doesn't support the empty list of types - let hasCustomAssemblyFormat = 1; - - let builders = [ - TypeBuilder<(ins CArg<"mlir::TypeRange", "mlir::TypeRange()">:$inputTypes), [{ - return $_get($_ctxt, inputTypes); - }]> - ]; - let skipDefaultBuilders = 1; - let genStorageClass = 0; -} - -class Option_Op traits = []> : - Op; - -def Option_DestructOp : Option_Op<"destruct", [ NoSideEffect, Terminator ]> { - let arguments = (ins Variadic:$inputs, CPS_EmptyContType:$missingCont, CPS_ContType:$presentCont); - - let hasCustomAssemblyFormat = 1; - let hasCanonicalizer = 1; -} - -def Option_ConstructOp : Option_Op<"construct", [NoSideEffect, SingleBlock]> { - let results = (outs Option_OptionType:$result); - let regions = (region SizedRegion<1>:$bodyRegion); - - let assemblyFormat = "qualified(type($result)) attr-dict-with-keyword $bodyRegion"; - - let skipDefaultBuilders = 1; - let builders = [ OpBuilder<(ins CArg<"mlir::TypeRange", "mlir::TypeRange()">:$valueTypes)> ]; - - let extraClassDeclaration = [{ - mlir::Value getMissingCont() { return getBody()->getArgument(0); } - mlir::Value getPresentCont() { return getBody()->getArgument(0); } - }]; -} - -def Option_MapOp : Option_Op<"map", [NoSideEffect, SingleBlock]> { - let arguments = (ins Variadic:$inputs); - let results = (outs Option_OptionType:$result); - let regions = (region SizedRegion<1>:$bodyRegion); - - let assemblyFormat = "`(` $inputs `)` `:` functional-type($inputs, $result) attr-dict-with-keyword $bodyRegion"; - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins CArg<"mlir::TypeRange", "mlir::TypeRange()">:$resultValueTypes, - CArg<"mlir::ValueRange", "mlir::ValueRange()">:$inputs)> ]; -} - -def Option_YieldOp : Option_Op<"yield", [NoSideEffect, Terminator]> { - let arguments = 
(ins Variadic:$inputs); - let assemblyFormat = "$inputs `:` type($inputs) attr-dict"; -} - -#endif // DIALECT_OPTION_IR_OPTIONOPS diff --git a/query/include/hail/Dialect/Sandbox/CMakeLists.txt b/query/include/hail/Dialect/Sandbox/CMakeLists.txt deleted file mode 100644 index f33061b2d87..00000000000 --- a/query/include/hail/Dialect/Sandbox/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(IR) diff --git a/query/include/hail/Dialect/Sandbox/IR/CMakeLists.txt b/query/include/hail/Dialect/Sandbox/IR/CMakeLists.txt deleted file mode 100644 index 79b2d2a521d..00000000000 --- a/query/include/hail/Dialect/Sandbox/IR/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS SandboxOps.td) -mlir_tablegen(SandboxOpsEnums.h.inc -gen-enum-decls) -mlir_tablegen(SandboxOpsEnums.cpp.inc -gen-enum-defs) -add_mlir_dialect(SandboxOps sb) - -mlir_tablegen(SandboxOpsAttributes.h.inc -gen-attrdef-decls) -mlir_tablegen(SandboxOpsAttributes.cpp.inc -gen-attrdef-defs) -add_public_tablegen_target(MLIRSandboxOpsAttributesIncGen) diff --git a/query/include/hail/Dialect/Sandbox/IR/Sandbox.h b/query/include/hail/Dialect/Sandbox/IR/Sandbox.h deleted file mode 100644 index a095d4e060b..00000000000 --- a/query/include/hail/Dialect/Sandbox/IR/Sandbox.h +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef HAIL_DIALECT_SANDBOX_IR_SANDBOX_H -#define HAIL_DIALECT_SANDBOX_IR_SANDBOX_H - -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" - -#include "mlir/IR/BuiltinTypes.h" - -#include "hail/Dialect/Sandbox/IR/SandboxOpsDialect.h.inc" -#include "hail/Dialect/Sandbox/IR/SandboxOpsEnums.h.inc" - -#define GET_TYPEDEF_CLASSES -#include "hail/Dialect/Sandbox/IR/SandboxOpsTypes.h.inc" - -#define GET_OP_CLASSES -#include "hail/Dialect/Sandbox/IR/SandboxOps.h.inc" - -#endif // HAIL_DIALECT_SANDBOX_IR_SANDBOX_H diff --git a/query/include/hail/Dialect/Sandbox/IR/SandboxBase.td b/query/include/hail/Dialect/Sandbox/IR/SandboxBase.td deleted file mode 100644 index 77922ad945c..00000000000 --- a/query/include/hail/Dialect/Sandbox/IR/SandboxBase.td +++ /dev/null @@ -1,132 +0,0 @@ -#ifndef DIALECT_SANDBOX_SANDBOXBASE -#define DIALECT_SANDBOX_SANDBOXBASE - -include "mlir/IR/AttrTypeBase.td" -include "mlir/IR/OpBase.td" -include "mlir/IR/BuiltinTypes.td" - - -include "mlir/IR/EnumAttr.td" - - -def Sandbox_Dialect : Dialect { - let name = "sb"; - let summary = "Dialect for experimenting with MLIR"; - let cppNamespace = "::hail::ir"; - let useDefaultTypePrinterParser = 1; - let hasConstantMaterializer = 1; -} - -class Sandbox_Type traits = []> - : TypeDef { - let mnemonic = typeMnemonic; -} - -class Sandbox_Attr traits = []> - : AttrDef { - let mnemonic = attrMnemonic; -} - -def SandboxType : - DialectType()">, - "Sandbox type", "::hail::ir::SandboxType">; - -def Sandbox_OptionalAttr : Sandbox_Attr<"Optional", "opt"> { - let summary = "Sandbox attribute representing a possibly missing value"; - let parameters = (ins AttributeSelfTypeParameter<"">:$type, OptionalParameter<"mlir::Attribute">:$value); - // let builders = [ - // AttrBuilderWithInferredContext<(ins "Type":$type, - // "const Attribute &":$value), [{ - // return $_get(type.getContext(), type, value); - // }]> - // ]; - let assemblyFormat = "`<` $value `>`"; -} - -// Here is a simple definition of an "integer" type, with a width parameter. 
-def Sandbox_Int : Sandbox_Type<"Int", "int"> { - let summary = "Sandbox 32 bit integer type"; - let description = [{ - Sandbox 32 bit integer type - }]; - - /// Indicate that our type will add additional verification to the parameters. - let genVerifyDecl = 0; -} - -def Sandbox_Bool : Sandbox_Type<"Boolean", "bool"> { - let summary = "Sandbox 32 bit integer type"; - let description = [{ - Sandbox 32 bit integer type - }]; - - /// Indicate that our type will add additional verification to the parameters. - let genVerifyDecl = 0; -} - -def Sandbox_CmpPredicateAttr : I64EnumAttr< - "CmpPredicate", "", - [ - I64EnumAttrCase<"LT", 0, "lt">, - I64EnumAttrCase<"LTEQ", 1, "lteq">, - I64EnumAttrCase<"GT", 2, "gt">, - I64EnumAttrCase<"GTEQ", 3, "gteq">, - I64EnumAttrCase<"EQ", 4, "eq">, - I64EnumAttrCase<"NEQ", 5, "eq">, - ]> { -} - -// def Sandbox_Array : Sandbox_Type<"Array", "array", [ -// DeclareTypeInterfaceMethods -// ]> { -// let summary = "Sandbox array"; -// let parameters = (ins "mlir::Type":$elementType); - -// // let builders = [ -// // TypeBuilderWithInferredContext<(ins "Type":$elementType), [{ -// // // Drop default memory space value and replace it with empty attribute. -// // Attribute nonDefaultMemorySpace = skipDefaultMemorySpace(memorySpace); -// // return $_get(elementType.getContext(), elementType, nonDefaultMemorySpace); -// // }]>, -// // /// [deprecated] `Attribute`-based form should be used instead. -// // TypeBuilderWithInferredContext<(ins "Type":$elementType, -// // "unsigned":$memorySpace), [{ -// // // Convert deprecated integer-like memory space to Attribute. -// // Attribute memorySpaceAttr = -// // wrapIntegerMemorySpace(memorySpace, elementType.getContext()); -// // return UnrankedMemRefType::get(elementType, memorySpaceAttr); -// // }]> -// // ]; -// let extraClassDeclaration = [{ -// // using ShapedType::Trait::clone; -// // using ShapedType::Trait::getElementTypeBitWidth; -// // using ShapedType::Trait::getRank; -// // using ShapedType::Trait::getNumElements; -// // using ShapedType::Trait::isDynamicDim; -// // using ShapedType::Trait::hasStaticShape; -// // using ShapedType::Trait::getNumDynamicDims; -// // using ShapedType::Trait::getDimSize; -// // using ShapedType::Trait::getDynamicDimIndex; - -// // ArrayRef getShape() const { return llvm::None; } - -// // /// [deprecated] Returns the memory space in old raw integer representation. -// // /// New `Attribute getMemorySpace()` method should be used instead. 
-// // unsigned getMemorySpaceAsInt() const; -// }]; -// let skipDefaultBuilders = 1; -// let mnemonic = ?; -// let genVerifyDecl = 1; -// } - -def Sandbox_Array : Sandbox_Type<"Array", "array", [ - DeclareTypeInterfaceMethods - ]> { - let summary = "Sandbox array"; - let parameters = (ins "mlir::Type":$elementType); - let hasCustomAssemblyFormat = 1; - let genVerifyDecl = 1; -} - - -#endif // DIALECT_SANDBOX_SANDBOXBASE diff --git a/query/include/hail/Dialect/Sandbox/IR/SandboxOps.td b/query/include/hail/Dialect/Sandbox/IR/SandboxOps.td deleted file mode 100644 index 1e6f32aa8aa..00000000000 --- a/query/include/hail/Dialect/Sandbox/IR/SandboxOps.td +++ /dev/null @@ -1,79 +0,0 @@ -#ifndef DIALECT_SANDBOX_SANDBOXOPS -#define DIALECT_SANDBOX_SANDBOXOPS - -include "hail/Dialect/Sandbox/IR/SandboxBase.td" -include "mlir/Interfaces/SideEffectInterfaces.td" -include "mlir/IR/OpBase.td" - -class Sandbox_Op traits = []> : - Op; - -def Sandbox_ConstantOp : Sandbox_Op<"constant", [ConstantLike, NoSideEffect]> { - let arguments = (ins AnyAttr:$value); - let results = (outs AnyType:$output); - let assemblyFormat = "`(`$value`)` attr-dict `:` type($output)"; - - let builders = [ - OpBuilder<(ins "::mlir::Attribute":$value, "::mlir::Type":$type), - [{ build($_builder, $_state, type, value); }]>, - ]; - - let hasFolder = true; - let hasVerifier = true; -} - - -def Sandbox_AddIOp : Sandbox_Op<"addi", [NoSideEffect, Commutative]> { - let arguments = (ins Sandbox_Int:$lhs, Sandbox_Int:$rhs); - let results = (outs Sandbox_Int:$result); - let assemblyFormat = "$lhs $rhs attr-dict"; - - let hasFolder = true; - let hasCanonicalizer = true; -} - -def Sandbox_ComparisonOp : Sandbox_Op<"compare", [NoSideEffect]> { - let arguments = (ins Sandbox_CmpPredicateAttr:$predicate, Sandbox_Int:$lhs, Sandbox_Int:$rhs); - let results = (outs Sandbox_Bool:$output); - let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($output)"; - - let extraClassDeclaration = [{ - static CmpPredicate getPredicateByName(llvm::StringRef name); - }]; - - let hasFolder = 1; - let hasCanonicalizer = 0; -} - -def Sandbox_MakeArrayOp : Sandbox_Op<"make_array", [NoSideEffect]> { - let arguments = (ins Variadic:$elems); - let results = (outs AnyType:$result); - - let assemblyFormat = "`(` $elems `)` `(` type($elems) `)` attr-dict `:` type($result)"; - - let hasVerifier = 1; - let hasFolder = 0; - let hasCanonicalizer = 0; -} - -def Sandbox_ArrayRefOp : Sandbox_Op<"array_ref", [NoSideEffect]> { - let arguments = (ins AnyType:$array, Sandbox_Int:$index); - let results = (outs AnyType:$result); - - let assemblyFormat = "$array type($array) $index attr-dict `:` type($result)"; - let hasVerifier = 1; - - let hasFolder = 1; - let hasCanonicalizer = 0; -} - -// def Sandbox_ArrayMap : Sandbox_Op<"array_map", [NoSideEffect]> { - -// } - -def Sandbox_PrintOp : Sandbox_Op<"print"> { - let arguments = (ins AnyType:$value); - let assemblyFormat = "$value attr-dict `:` type($value)"; -} - -#endif // DIALECT_SANDBOX_SANDBOXOPS diff --git a/query/include/hail/InitAllDialects.h b/query/include/hail/InitAllDialects.h deleted file mode 100644 index 1fe3baf3b34..00000000000 --- a/query/include/hail/InitAllDialects.h +++ /dev/null @@ -1,40 +0,0 @@ -#ifndef HAIL_INITALLDIALECTS_H_ -#define HAIL_INITALLDIALECTS_H_ - -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include 
"mlir/IR/Dialect.h" - -#include "Dialect/CPS/IR/CPS.h" -#include "Dialect/Missing/IR/Missing.h" -#include "Dialect/Option/IR/Option.h" -#include "Dialect/Sandbox/IR/Sandbox.h" - -namespace hail::ir { - -inline void registerAllDialects(mlir::DialectRegistry ®istry) { - // clang-format off - registry.insert(); - // clang-format on -} - -inline void registerAllDialects(mlir::MLIRContext &context) { - mlir::DialectRegistry registry; - registerAllDialects(registry); - context.appendDialectRegistry(registry); -} - -} // namespace hail::ir - -#endif // HAIL_INITALLDIALECTS_H_ diff --git a/query/include/hail/InitAllPasses.h b/query/include/hail/InitAllPasses.h deleted file mode 100644 index 1bc71101dea..00000000000 --- a/query/include/hail/InitAllPasses.h +++ /dev/null @@ -1,38 +0,0 @@ -//===- LinkAllPassesAndDialects.h - MLIR Registration -----------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file defines a helper to trigger the registration of all dialects and -// passes to the system. -// -//===----------------------------------------------------------------------===// - -#ifndef HAIL_INITALLPASSES_H_ -#define HAIL_INITALLPASSES_H_ - -#include "Conversion/Passes.h" -#include "Transforms/Passes.h" - -#include - -namespace hail::ir { - -// This function may be called to register the MLIR passes with the -// global registry. -// If you're building a compiler, you likely don't need this: you would build a -// pipeline programmatically without the need to register with the global -// registry, since it would already be calling the creation routine of the -// individual passes. -// The global registry is interesting to interact with the command-line tools. -inline void registerAllPasses() { - registerConversionPasses(); - registerTransformsPasses(); -} - -} // namespace hail::ir - -#endif // HAIL_INITALLPASSES_H_ diff --git a/query/include/hail/Support/MLIR.h b/query/include/hail/Support/MLIR.h deleted file mode 100644 index dc21e8baa60..00000000000 --- a/query/include/hail/Support/MLIR.h +++ /dev/null @@ -1,235 +0,0 @@ -//===------------------------------------------------------------------------------------------===// -// -// Modification of mlir/Support/LLVM.h -// This file forward declares and imports various common LLVM and MLIR datatypes that Hail wants to -// use unqualified. -// -// Note that most of these are forward declared and then imported into the MLIR namespace with using -// decls, rather than being #included. This is because we want clients to explicitly #include the -// files they need. -// -//===------------------------------------------------------------------------------------------===// - -#ifndef HAIL_SUPPORT_MLIR_H -#define HAIL_SUPPORT_MLIR_H - -// We include these two headers because they cannot be practically forward -// declared, and are effectively language features. -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/OperationSupport.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/None.h" -#include "llvm/Support/Casting.h" -#include - -// Forward declarations. 
-namespace llvm { - -// String types -template -class SmallString; -class StringRef; -class StringLiteral; -class Twine; - -// Containers. -template -class ArrayRef; -class BitVector; -namespace detail { -template -struct DenseMapPair; -} // namespace detail -template -class DenseMap; -template -struct DenseMapInfo; -template -class DenseSet; -class MallocAllocator; -template -class MutableArrayRef; -template -class Optional; -template -class PointerUnion; -template -class SetVector; -template -class SmallPtrSet; -template -class SmallPtrSetImpl; -template -class SmallVector; -template -class SmallVectorImpl; -template -class StringSet; -template -class StringSwitch; -template -class TinyPtrVector; -template -class TypeSwitch; - -// Other common classes. -class APInt; -class APSInt; -class APFloat; -template -class function_ref; -template -class iterator_range; -class raw_ostream; -class SMLoc; -class SMRange; - -} // namespace llvm - -namespace mlir { - -// Core IR types -class Attribute; -class Block; -class Location; -class Operation; -class Type; -class TypeRange; -class Value; -class ValueRange; - -// Types and Attributes -class ArrayAttr; -class BoolAttr; -class DictionaryAttr; -class FlatSymbolRefAttr; -class IntegerAttr; -class IntegerType; -class NamedAttrList; -class RankedTensorType; -class SymbolRefAttr; - -// Common classes -class Builder; -class ConversionPatternRewriter; -class ConversionTarget; -class IRRewriter; -class InFlightDiagnostic; -struct LogicalResult; -class MLIRContext; -class OpBuilder; -template -class OpConversionPattern; -struct OperationState; -class OpFoldResult; -template -struct OpRewritePattern; -class Pass; -class PatternRewriter; -class RewritePatternSet; - -// Common functions -// NOLINTBEGIN(readability-redundant-declaration) -auto success(bool isSuccess) -> LogicalResult; -auto failure(bool isFailure) -> LogicalResult; -// NOLINTEND(readability-redundant-declaration) - -} // namespace mlir - -namespace hail::ir { -// NOLINTBEGIN(misc-unused-using-decls) - -// Casting operators. -using llvm::cast; -using llvm::cast_or_null; -using llvm::dyn_cast; -using llvm::dyn_cast_or_null; -using llvm::isa; -using llvm::isa_and_nonnull; - -// String types -using llvm::SmallString; -using llvm::StringLiteral; -using llvm::StringRef; -using llvm::Twine; - -// Container Related types -// -// Containers. -using llvm::ArrayRef; -using llvm::BitVector; -using llvm::DenseMap; -using llvm::DenseMapInfo; -using llvm::DenseSet; -using llvm::MutableArrayRef; -using llvm::None; -using llvm::Optional; -using llvm::PointerUnion; -using llvm::SetVector; -using llvm::SmallPtrSet; -using llvm::SmallPtrSetImpl; -using llvm::SmallVector; -using llvm::SmallVectorImpl; -using llvm::StringSet; -using llvm::StringSwitch; -using llvm::TinyPtrVector; -using llvm::TypeSwitch; - -// Other common classes. 
-using llvm::APFloat; -using llvm::APInt; -using llvm::APSInt; -using llvm::function_ref; -using llvm::iterator_range; -using llvm::raw_ostream; -using llvm::SMLoc; -using llvm::SMRange; - -// Core MLIR IR classes -using mlir::Attribute; -using mlir::Block; -using mlir::Location; -using mlir::Operation; -using mlir::Type; -using mlir::TypeRange; -using mlir::Value; -using mlir::ValueRange; - -// MLIR Types and Attributes -using mlir::ArrayAttr; -using mlir::BoolAttr; -using mlir::DictionaryAttr; -using mlir::FlatSymbolRefAttr; -using mlir::IntegerAttr; -using mlir::IntegerType; -using mlir::NamedAttrList; -using mlir::RankedTensorType; -using mlir::SymbolRefAttr; - -// Common MLIR classes -using mlir::Builder; -using mlir::ConversionPatternRewriter; -using mlir::ConversionTarget; -using mlir::InFlightDiagnostic; -using mlir::IRRewriter; -using mlir::LogicalResult; -using mlir::MLIRContext; -using mlir::OpBuilder; -using mlir::OpConversionPattern; -using mlir::OperationState; -using mlir::OpFoldResult; -using mlir::OpRewritePattern; -using mlir::Pass; -using mlir::PatternRewriter; -using mlir::RewritePatternSet; - -// Common functions -using mlir::failure; -using mlir::success; - -// NOLINTEND(misc-unused-using-decls) -} // namespace hail::ir - -#endif // HAIL_SUPPORT_MLIR_H diff --git a/query/include/hail/Transforms/CMakeLists.txt b/query/include/hail/Transforms/CMakeLists.txt deleted file mode 100644 index c98bc122676..00000000000 --- a/query/include/hail/Transforms/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -set(LLVM_TARGET_DEFINITIONS Passes.td) -mlir_tablegen(Passes.h.inc -gen-pass-decls -name Transforms) -add_public_tablegen_target(HailTransformsPassIncGen) diff --git a/query/include/hail/Transforms/Passes.h b/query/include/hail/Transforms/Passes.h deleted file mode 100644 index 8134e3a5296..00000000000 --- a/query/include/hail/Transforms/Passes.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef HAIL_TRANSFORMS_PASSES_H -#define HAIL_TRANSFORMS_PASSES_H - -#include "mlir/Pass/Pass.h" -#include "llvm/Support/Debug.h" -#include - -namespace hail::ir { - -auto createTestMissingnessAnalysisPass() -> std::unique_ptr; - -/// Generate the code for registering passes. -#define GEN_PASS_REGISTRATION -#include "hail/Transforms/Passes.h.inc" - -} // namespace hail::ir - -#endif // HAIL_TRANSFORMS_PASSES_H diff --git a/query/include/hail/Transforms/Passes.td b/query/include/hail/Transforms/Passes.td deleted file mode 100644 index 3e5dcaac9d6..00000000000 --- a/query/include/hail/Transforms/Passes.td +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef HAIL_TRANSFORMS_PASSES -#define HAIL_TRANSFORMS_PASSES - -include "mlir/Pass/PassBase.td" -include "mlir/Rewrite/PassUtil.td" - -def TestMissingnessAnalysis : Pass<"test-missingness-analysis"> { - let summary = "annotate with results of missingness analysis"; - let description = [{ - This pass adds anotations recording the inferences of the missingness aware - Sparse Conditional Constant Propagation. 
- }]; - let constructor = "hail::ir::createTestMissingnessAnalysisPass()"; -} - -#endif // HAIL_TRANSFORMS_PASSES diff --git a/query/lib/Analysis/CMakeLists.txt b/query/lib/Analysis/CMakeLists.txt deleted file mode 100644 index 2fecea971dc..00000000000 --- a/query/lib/Analysis/CMakeLists.txt +++ /dev/null @@ -1,27 +0,0 @@ -set(LLVM_OPTIONAL_SOURCES - MissingnessAnalysis.cpp - MissingnessAwareConstantPropagationAnalysis.cpp - ) - -add_mlir_library(HailAnalysis - MissingnessAnalysis.cpp - MissingnessAwareConstantPropagationAnalysis.cpp - - ADDITIONAL_HEADER_DIRS - ${HAIL_MAIN_INCLUDE_DIR}/Analysis - - DEPENDS - mlir-headers - - LINK_LIBS PUBLIC - HailMissingDialect - MLIRAnalysis - # MLIRCallInterfaces - MLIRControlFlowInterfaces - # MLIRDataLayoutInterfaces - # MLIRInferIntRangeInterface - # MLIRInferTypeOpInterface - # MLIRLoopLikeInterface - MLIRSideEffectInterfaces - # MLIRViewLikeInterface - ) diff --git a/query/lib/Analysis/MissingnessAnalysis.cpp b/query/lib/Analysis/MissingnessAnalysis.cpp deleted file mode 100644 index 26f43188fb2..00000000000 --- a/query/lib/Analysis/MissingnessAnalysis.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include "hail/Analysis/MissingnessAnalysis.h" -#include "hail/Dialect/Missing/IR/Missing.h" -#include "hail/Support/MLIR.h" -#include "mlir/IR/OpDefinition.h" -#include "llvm/Support/Debug.h" - -#include - -#define DEBUG_TYPE "missingness-analysis" - -using namespace hail::ir; - -//===----------------------------------------------------------------------===// -// MissingnessValue -//===----------------------------------------------------------------------===// - -void MissingnessValue::print(llvm::raw_ostream &os) const { - if (state == State::Missing) - os << ""; - else if (state == State::Present) - os << ""; - else - os << ""; -} - -//===----------------------------------------------------------------------===// -// MissingnessAnalysis -//===----------------------------------------------------------------------===// - -void MissingnessAnalysis::visitOperation( - Operation *op, ArrayRef const *> operands, - ArrayRef *> results) { - LLVM_DEBUG(llvm::dbgs() << "Missingness: Visiting operation: " << *op << "\n"); - - // FIXME: move missingness op semantics to an interface - if (auto missingOp = dyn_cast(op)) { - propagateIfChanged(results.front(), results.front()->join({MissingnessValue::Missing})); - return; - }; - - // By default, operations are strict: if any operand is missing, all results are missing - MissingnessValue::State operandsState{}; - for (auto const *lattice : operands) { - operandsState = std::max(operandsState, lattice->getValue().getState()); - } - for (auto *result : results) { - auto changed = result->join({operandsState}); - LLVM_DEBUG(llvm::dbgs() << " result: "; result->print(llvm::dbgs()); llvm::dbgs() << "\n"); - propagateIfChanged(result, changed); - }; -} diff --git a/query/lib/Analysis/MissingnessAwareConstantPropagationAnalysis.cpp b/query/lib/Analysis/MissingnessAwareConstantPropagationAnalysis.cpp deleted file mode 100644 index 493b504e028..00000000000 --- a/query/lib/Analysis/MissingnessAwareConstantPropagationAnalysis.cpp +++ /dev/null @@ -1,101 +0,0 @@ -#include "hail/Analysis/MissingnessAwareConstantPropagationAnalysis.h" -#include "hail/Analysis/MissingnessAnalysis.h" -#include "hail/Dialect/Missing/IR/Missing.h" -#include "hail/Support/MLIR.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/OpDefinition.h" -#include "llvm/Support/Debug.h" - -#define DEBUG_TYPE "missingness-aware-constant-propagation" - -using namespace 
hail::ir; - -using mlir::dataflow::ConstantValue; -using mlir::dataflow::Lattice; - -void MissingnessAwareConstantPropagation::visitOperation( - Operation *op, ArrayRef const *> operands, - ArrayRef *> results) { - LLVM_DEBUG(llvm::dbgs() << "MACP: Visiting operation: " << *op << "\n"); - - auto builder = Builder(op->getContext()); - - // FIXME: move missingness op semantics to an interface - if (auto missingOp = dyn_cast(op)) { - auto const *missingness = - getOrCreateFor>(missingOp, missingOp.getOperand()); - if (missingness->isUninitialized()) - return; - if (missingness->getValue().isMissing()) { - propagateIfChanged(results.front(), results.front()->join(builder.getBoolAttr(true))); - } else if (missingness->getValue().isPresent()) { - propagateIfChanged(results.front(), results.front()->join(builder.getBoolAttr(false))); - } else { - propagateIfChanged(results.front(), results.front()->join(ConstantValue())); - } - return; - }; - - // Don't try to simulate the results of a region operation as we can't - // guarantee that folding will be out-of-place. We don't allow in-place - // folds as the desire here is for simulated execution, and not general - // folding. - if (op->getNumRegions() != 0U) - return; - - // By default, only propagate constants if there are no missing operands. - bool const anyMissing = - std::any_of(op->operand_begin(), op->operand_end(), [this, op](auto operand) { - auto missingness = getOrCreateFor>(op, operand); - return missingness->isUninitialized() || missingness->getValue().isMissing(); - }); - - if (anyMissing) - return; - - SmallVector constantOperands; - constantOperands.reserve(op->getNumOperands()); - for (auto const *operandLattice : operands) - constantOperands.push_back(operandLattice->getValue().getConstantValue()); - - // Save the original operands and attributes just in case the operation - // folds in-place. The constant passed in may not correspond to the real - // runtime value, so in-place updates are not allowed. - SmallVector const originalOperands(op->getOperands()); - DictionaryAttr const originalAttrs = op->getAttrDictionary(); - - // Simulate the result of folding this operation to a constant. If folding - // fails or was an in-place fold, mark the results as overdefined. - SmallVector foldResults; - foldResults.reserve(op->getNumResults()); - if (failed(op->fold(constantOperands, foldResults))) { - markAllPessimisticFixpoint(results); - return; - } - - // If the folding was in-place, mark the results as overdefined and reset - // the operation. We don't allow in-place folds as the desire here is for - // simulated execution, and not general folding. - if (foldResults.empty()) { - op->setOperands(originalOperands); - op->setAttrs(originalAttrs); - markAllPessimisticFixpoint(results); - return; - } - - // Merge the fold results into the lattice for this operation. - assert(foldResults.size() == op->getNumResults() && "invalid result size"); - for (auto const it : llvm::zip(results, foldResults)) { - Lattice *lattice = std::get<0>(it); - - // Merge in the result of the fold, either a constant or a value. 
- OpFoldResult const foldResult = std::get<1>(it); - if (auto const attr = foldResult.dyn_cast()) { - LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n"); - propagateIfChanged(lattice, lattice->join(ConstantValue(attr, op->getDialect()))); - } else { - LLVM_DEBUG(llvm::dbgs() << "Folded to value: " << foldResult.get() << "\n"); - AbstractSparseDataFlowAnalysis::join(lattice, *getLatticeElement(foldResult.get())); - } - } -} diff --git a/query/lib/CMakeLists.txt b/query/lib/CMakeLists.txt deleted file mode 100644 index f26d13d05ce..00000000000 --- a/query/lib/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -add_subdirectory(Analysis) -add_subdirectory(Conversion) -add_subdirectory(Dialect) -add_subdirectory(Transforms) diff --git a/query/lib/Conversion/CMakeLists.txt b/query/lib/Conversion/CMakeLists.txt deleted file mode 100644 index 82d395f24cd..00000000000 --- a/query/lib/Conversion/CMakeLists.txt +++ /dev/null @@ -1,5 +0,0 @@ -add_subdirectory(CPSToCF) -add_subdirectory(LowerOption) -add_subdirectory(LowerSandbox) -add_subdirectory(LowerToLLVM) -add_subdirectory(OptionToGenericOption) diff --git a/query/lib/Conversion/CPSToCF/CMakeLists.txt b/query/lib/Conversion/CPSToCF/CMakeLists.txt deleted file mode 100644 index dda5afc0d03..00000000000 --- a/query/lib/Conversion/CPSToCF/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -add_mlir_conversion_library(HailCPSToCF - CPSToCF.cpp - - ADDITIONAL_HEADER_DIRS - ${HAIL_MAIN_INCLUDE_DIR}/Conversion/CPSToCF - - DEPENDS - HailConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - HailCPSDialect - MLIRControlFlowDialect - MLIRPass - MLIRTransforms - ) diff --git a/query/lib/Conversion/CPSToCF/CPSToCF.cpp b/query/lib/Conversion/CPSToCF/CPSToCF.cpp deleted file mode 100644 index 29b5df832f8..00000000000 --- a/query/lib/Conversion/CPSToCF/CPSToCF.cpp +++ /dev/null @@ -1,113 +0,0 @@ -#include "hail/Conversion/CPSToCF/CPSToCF.h" - -#include "../PassDetail.h" - -#include "hail/Dialect/CPS/IR/CPS.h" -#include "hail/Support/MLIR.h" - -#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Value.h" -#include "mlir/Pass/Pass.h" - -#include - -namespace hail::ir { - -struct CPSToCFPass : public CPSToCFBase { - void runOnOperation() override; -}; - -} // namespace hail::ir - -using namespace hail::ir; - -namespace { - -void lowerCallCCOp(IRRewriter &rewriter, CallCCOp callcc, std::vector &defsWorklist) { - auto loc = callcc->getLoc(); - assert(callcc->getParentRegion()->hasOneBlock()); - // Split the current block before callcc to create the continuation point. 
- Block *parentBlock = callcc->getBlock(); - Block *continuation = rewriter.splitBlock(parentBlock, callcc->getIterator()); - - // create a DefContOp holding the continuation of callcc - rewriter.setInsertionPointToEnd(parentBlock); - auto defcont = rewriter.create(loc, callcc->getResultTypes()); - rewriter.mergeBlocks(continuation, defcont.getBody(), {}); - - defsWorklist.push_back(defcont); - - // inline the body of callcc, replacing the return continuation with the defcont, - // and replacing uses of the results with the args of the defcont - rewriter.mergeBlocks(callcc.getBody(), parentBlock, defcont.getResult()); - rewriter.replaceOp(callcc, defcont.getBody()->getArguments()); -} - -auto getDefBlock(Value cont) -> Block * { - auto def = cont.getDefiningOp(); - assert(def && "Continuation def is not visable"); - return &def.bodyRegion().front(); -} - -void lowerApplyContOp(IRRewriter &rewriter, ApplyContOp op) { - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp(op, op.args(), getDefBlock(op.cont())); -} - -void lowerIfOp(IRRewriter &rewriter, IfOp op) { - rewriter.setInsertionPoint(op); - rewriter.replaceOpWithNewOp(op, op.condition(), - getDefBlock(op.trueCont()), op.trueArgs(), - getDefBlock(op.falseCont()), op.falseArgs()); -} - -void lowerDefContOp(IRRewriter &rewriter, DefContOp op) { - rewriter.inlineRegionBefore(op.bodyRegion(), *op->getParentRegion(), - std::next(op->getBlock()->getIterator())); - rewriter.eraseOp(op); -} - -} // namespace - -namespace hail::ir { - -void CPSToCFPass::runOnOperation() { - std::vector callccsWorklist; - std::vector defsWorklist; - std::vector usesWorklist; - constexpr int initCapacity = 64; - defsWorklist.reserve(initCapacity); - usesWorklist.reserve(initCapacity); - - // add nested ops to worklist in postorder - auto *root = getOperation(); - for (auto ®ion : getOperation()->getRegions()) - region.walk([&](Operation *op) { - if (auto callcc = dyn_cast(op)) - callccsWorklist.push_back(callcc); - else if (auto defcont = dyn_cast(op)) - defsWorklist.push_back(defcont); - else if (isa(op) || isa(op)) - usesWorklist.push_back(op); - }); - - IRRewriter rewriter(root->getContext()); - for (auto callcc : callccsWorklist) { - lowerCallCCOp(rewriter, callcc, defsWorklist); - } - for (auto *op : usesWorklist) { - if (auto apply = dyn_cast(op)) { - lowerApplyContOp(rewriter, apply); - } else if (auto ifOp = dyn_cast(op)) { - lowerIfOp(rewriter, ifOp); - } - } - for (auto defcont : defsWorklist) { - lowerDefContOp(rewriter, defcont); - } -} - -auto createCPSToCFPass() -> std::unique_ptr { return std::make_unique(); } - -} // namespace hail::ir diff --git a/query/lib/Conversion/LowerOption/CMakeLists.txt b/query/lib/Conversion/LowerOption/CMakeLists.txt deleted file mode 100644 index a7e924408ba..00000000000 --- a/query/lib/Conversion/LowerOption/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -add_mlir_conversion_library(HailLowerOption - LowerOption.cpp - - ADDITIONAL_HEADER_DIRS - ${HAIL_MAIN_INCLUDE_DIR}/Conversion/LowerOption - - DEPENDS - HailConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - HailCPSDialect - HailOptionDialect - MLIRArithmeticDialect - MLIRPass - MLIRTransforms - ) diff --git a/query/lib/Conversion/LowerOption/LowerOption.cpp b/query/lib/Conversion/LowerOption/LowerOption.cpp deleted file mode 100644 index 3fcf78df529..00000000000 --- a/query/lib/Conversion/LowerOption/LowerOption.cpp +++ /dev/null @@ -1,143 +0,0 @@ -#include "hail/Conversion/LowerOption/LowerOption.h" - -#include "../PassDetail.h" - 
-#include "hail/Dialect/CPS/IR/CPS.h" -#include "hail/Dialect/Option/IR/Option.h" -#include "hail/Support/MLIR.h" - -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/SmallVector.h" - -#include - -namespace hail::ir { - -namespace { - -struct LowerOptionPass : public LowerOptionBase { - void runOnOperation() override; -}; - -struct LoweredOption { - explicit LoweredOption(OptionType type) { operands.reserve(type.getValueTypes().size() + 1); }; - auto isDefined() -> Value { return operands[0]; }; - auto values() -> ValueRange { return ValueRange(operands).drop_front(1); }; - - SmallVector operands{}; -}; - -auto unpackOptional(ConversionPatternRewriter &rewriter, mlir::Location loc, Value optional) - -> LoweredOption { - auto type = optional.getType().cast(); - SmallVector resultTypes; - resultTypes.reserve(type.getValueTypes().size() + 1); - LoweredOption result{type}; - resultTypes.push_back(rewriter.getI1Type()); - resultTypes.append(type.getValueTypes().begin(), type.getValueTypes().end()); - rewriter.createOrFold(result.operands, loc, resultTypes, - optional); - return result; -} - -struct ConstructOpConversion : public OpConversionPattern { - ConstructOpConversion(MLIRContext *context) - : OpConversionPattern(context, /*benefit=*/1) {} - - auto matchAndRewrite(ConstructOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const - -> LogicalResult override { - auto loc = op.getLoc(); - SmallVector resultTypes; - auto valueTypes = op.getType().getValueTypes(); - resultTypes.reserve(valueTypes.size() + 1); - resultTypes.push_back(rewriter.getI1Type()); - resultTypes.append(valueTypes.begin(), valueTypes.end()); - - auto callcc = rewriter.create(loc, resultTypes); - Value const retCont = callcc.body().getArgument(0); - - SmallVector results; - results.reserve(valueTypes.size() + 1); - auto &body = op.bodyRegion(); - - // Define the new missing continuation - rewriter.setInsertionPointToStart(&callcc.body().front()); - auto missingCont = rewriter.create(loc); - rewriter.setInsertionPointToStart(&missingCont.bodyRegion().front()); - auto constFalse = rewriter.create(loc, rewriter.getBoolAttr(false)); - results.push_back(constFalse); - llvm::transform(valueTypes, std::back_inserter(results), - [&](mlir::Type type) { return rewriter.create(loc, type); }); - rewriter.create(loc, retCont, results); - - // Define the new present continuation - results.clear(); - rewriter.setInsertionPointAfter(missingCont); - auto presentCont = rewriter.create(loc, valueTypes); - rewriter.setInsertionPointToStart(&presentCont.bodyRegion().front()); - auto constTrue = rewriter.create(loc, rewriter.getBoolAttr(true)); - results.push_back(constTrue); - results.append(presentCont.bodyRegion().args_begin(), presentCont.bodyRegion().args_end()); - rewriter.create(loc, retCont, results); - - rewriter.mergeBlocks(&body.front(), &callcc.body().front(), {missingCont, presentCont}); - - // Cast results back to Option type and replace - rewriter.setInsertionPointAfter(callcc); - rewriter.replaceOpWithNewOp(op, op.getType(), - callcc.getResults()); - - return mlir::success(); - } -}; - -struct DestructOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - auto matchAndRewrite(DestructOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const - -> LogicalResult override { - SmallVector values; - Value isDefined; - 
for (auto option : adaptor.inputs()) { - LoweredOption unpack = unpackOptional(rewriter, op.getLoc(), option); - values.append(unpack.values().begin(), unpack.values().end()); - if (!isDefined) { - isDefined = unpack.isDefined(); - } else { - isDefined = - rewriter.create(op.getLoc(), isDefined, unpack.isDefined()); - } - } - - rewriter.replaceOpWithNewOp(op, isDefined, op.presentCont(), values, op.missingCont(), - ValueRange{}); - return mlir::success(); - } -}; - -} // end namespace - -void LowerOptionPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - populateLowerOptionConversionPatterns(patterns); - - ConversionTarget target(getContext()); - target.addIllegalDialect(); - target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) - signalPassFailure(); -} - -void populateLowerOptionConversionPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); -} - -auto createLowerOptionPass() -> std::unique_ptr { - return std::make_unique(); -} - -} // namespace hail::ir diff --git a/query/lib/Conversion/LowerSandbox/CMakeLists.txt b/query/lib/Conversion/LowerSandbox/CMakeLists.txt deleted file mode 100644 index 86d6f1541b9..00000000000 --- a/query/lib/Conversion/LowerSandbox/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -add_mlir_conversion_library(HailLowerSandbox - LowerSandbox.cpp - - ADDITIONAL_HEADER_DIRS - ${HAIL_MAIN_INCLUDE_DIR}/Conversion/LowerSandbox - - DEPENDS - HailConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - HailSandboxDialect - MLIRArithmeticDialect - MLIRPass - MLIRTensorDialect - MLIRTransforms - ) diff --git a/query/lib/Conversion/LowerSandbox/LowerSandbox.cpp b/query/lib/Conversion/LowerSandbox/LowerSandbox.cpp deleted file mode 100644 index 6bc43711ffc..00000000000 --- a/query/lib/Conversion/LowerSandbox/LowerSandbox.cpp +++ /dev/null @@ -1,192 +0,0 @@ -#include "hail/Conversion/LowerSandbox/LowerSandbox.h" - -#include "../PassDetail.h" - -#include "hail/Dialect/Sandbox/IR/Sandbox.h" -#include "hail/Support/MLIR.h" - -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/SmallVector.h" - -namespace hail::ir { - -struct LowerSandboxPass : public LowerSandboxBase { - void runOnOperation() override; -}; - -namespace { -struct AddIOpConversion : public OpConversionPattern { - AddIOpConversion(MLIRContext *context) : OpConversionPattern(context, /*benefit=*/1) {} - - auto matchAndRewrite(AddIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const - -> LogicalResult override { - rewriter.replaceOpWithNewOp(op, adaptor.lhs(), adaptor.rhs()); - return success(); - } -}; - -struct ConstantOpConversion : public OpConversionPattern { - ConstantOpConversion(MLIRContext *context) - : OpConversionPattern(context, /*benefit=*/1) {} - - auto matchAndRewrite(ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const - -> LogicalResult override { - auto value = adaptor.valueAttr().cast().getValue(); - Attribute newAttr; - Type newType; - if (op.output().getType().isa()) { - newType = rewriter.getI1Type(); - newAttr = rewriter.getBoolAttr(value == 0); - } else { - newType = rewriter.getI32Type(); - newAttr = rewriter.getIntegerAttr(newType, value); - } - rewriter.replaceOpWithNewOp(op, 
newAttr, newType); - return success(); - } -}; - -struct ComparisonOpConversion : public OpConversionPattern { - ComparisonOpConversion(MLIRContext *context) - : OpConversionPattern(context, /*benefit=*/1) {} - - auto matchAndRewrite(ComparisonOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const -> LogicalResult override { - mlir::arith::CmpIPredicate pred{}; - - switch (adaptor.predicate()) { - case CmpPredicate::LT: - pred = mlir::arith::CmpIPredicate::slt; - break; - case CmpPredicate::LTEQ: - pred = mlir::arith::CmpIPredicate::sle; - break; - case CmpPredicate::GT: - pred = mlir::arith::CmpIPredicate::sgt; - break; - case CmpPredicate::GTEQ: - pred = mlir::arith::CmpIPredicate::sge; - break; - case CmpPredicate::EQ: - pred = mlir::arith::CmpIPredicate::eq; - break; - case CmpPredicate::NEQ: - pred = mlir::arith::CmpIPredicate::ne; - break; - } - - rewriter.replaceOpWithNewOp(op, pred, adaptor.lhs(), adaptor.rhs()); - return success(); - } -}; - -auto getLoweredType(Builder &b, Type t) -> Type { - if (t.isa()) - return b.getI32Type(); - - if (t.isa()) - return b.getI1Type(); - - if (t.isa()) { - auto loweredElem = getLoweredType(b, t.cast().getElementType()); - SmallVector const v = {-1}; - - return RankedTensorType::get(v, loweredElem); - } - - return {}; -} - -struct MakeArrayOpConversion : public OpConversionPattern { - MakeArrayOpConversion(MLIRContext *context) - : OpConversionPattern(context, /*benefit=*/1) {} - - auto matchAndRewrite(MakeArrayOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const - -> LogicalResult override { - - auto elems = adaptor.elems(); - SmallVector const v = {op->getNumOperands()}; - Type const loweredElem = - op.getNumOperands() > 0 - ? adaptor.elems()[0].getType() - : getLoweredType(rewriter, op.result().getType().cast().getElementType()); - auto tensorType = RankedTensorType::get(v, loweredElem); - rewriter.replaceOpWithNewOp(op, tensorType, elems); - - return success(); - } -}; - -struct ArrayRefOpConversion : public OpConversionPattern { - ArrayRefOpConversion(MLIRContext *context) - : OpConversionPattern(context, /*benefit=*/1) {} - - auto matchAndRewrite(ArrayRefOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const - -> LogicalResult override { - - auto a = adaptor.array(); - auto idx = adaptor.index(); - rewriter.replaceOpWithNewOp(op, a, idx); - - return success(); - } -}; - -struct PrintOpConversion : public OpConversionPattern { - PrintOpConversion(MLIRContext *context) : OpConversionPattern(context, /*benefit=*/1) {} - - auto matchAndRewrite(PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const - -> LogicalResult override { - rewriter.replaceOpWithNewOp(op, adaptor.value()); - return success(); - } -}; - -struct UnrealizedCastConversion : public OpConversionPattern { - UnrealizedCastConversion(MLIRContext *context) - : OpConversionPattern(context, - /*benefit=*/1) {} - - auto matchAndRewrite(mlir::UnrealizedConversionCastOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const -> LogicalResult override { - rewriter.replaceOp(op, adaptor.getInputs()); - return success(); - } -}; - -} // end namespace - -void LowerSandboxPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - populateLowerSandboxConversionPatterns(patterns); - - // Configure conversion to lower out SCF operations. 
- ConversionTarget target(getContext()); - target.addIllegalDialect(); - target.addDynamicallyLegalOp([](Operation *op) { - auto cond = [](Type type) { - return type.isa() || type.isa() || type.isa(); - }; - return llvm::none_of(op->getOperandTypes(), cond) && llvm::none_of(op->getResultTypes(), cond); - }); - target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) - signalPassFailure(); -} - -void populateLowerSandboxConversionPatterns(RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); -} - -auto createLowerSandboxPass() -> std::unique_ptr { - return std::make_unique(); -} - -} // namespace hail::ir diff --git a/query/lib/Conversion/LowerToLLVM/CMakeLists.txt b/query/lib/Conversion/LowerToLLVM/CMakeLists.txt deleted file mode 100644 index 93f460a0d93..00000000000 --- a/query/lib/Conversion/LowerToLLVM/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -add_mlir_conversion_library(HailLowerToLLVM - LowerToLLVM.cpp - - ADDITIONAL_HEADER_DIRS - ${HAIL_MAIN_INCLUDE_DIR}/Conversion/LowerToLLVM - - DEPENDS - HailConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - HailSandboxDialect - MLIRControlFlowToLLVM - MLIRLLVMDialect - MLIRLLVMCommonConversion - MLIRFuncToLLVM - MLIRMemRefToLLVM - MLIRPass - MLIRSCFDialect - MLIRTransforms - ) diff --git a/query/lib/Conversion/LowerToLLVM/LowerToLLVM.cpp b/query/lib/Conversion/LowerToLLVM/LowerToLLVM.cpp deleted file mode 100644 index 3b10ce12fc8..00000000000 --- a/query/lib/Conversion/LowerToLLVM/LowerToLLVM.cpp +++ /dev/null @@ -1,141 +0,0 @@ -#include "hail/Conversion/LowerToLLVM/LowerToLLVM.h" - -#include "../PassDetail.h" - -#include "hail/Dialect/Sandbox/IR/Sandbox.h" -#include "hail/Support/MLIR.h" - -#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" -#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" - -namespace hail::ir { - -struct LowerToLLVMPass : public LowerToLLVMBase { - void runOnOperation() override; -}; - -namespace { - -class PrintOpLowering : public OpConversionPattern { -public: - explicit PrintOpLowering(MLIRContext *context) - : OpConversionPattern(context, /*benefit=*/1) {} - - auto matchAndRewrite(PrintOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const - -> LogicalResult override { - - if (!(op.getOperand().getType() == rewriter.getI32Type() - || op.getOperand().getType() == rewriter.getI1Type())) - return failure(); - - auto loc = op->getLoc(); - - auto parentModule = op->getParentOfType(); - - // Get a symbol reference to the printf function, inserting it if necessary. - auto printfRef = getOrInsertPrintf(rewriter, parentModule); - Value const formatSpecifierCst = - getOrCreateGlobalString(loc, rewriter, "frmt_spec", StringRef("%i \0", 4), parentModule); - Value const newLineCst = - getOrCreateGlobalString(loc, rewriter, "nl", StringRef("\n\0", 2), parentModule); - - rewriter.create(loc, printfRef, rewriter.getIntegerType(32), - ArrayRef({formatSpecifierCst, op.getOperand()})); - rewriter.create(loc, printfRef, rewriter.getIntegerType(32), newLineCst); - - // Notify the rewriter that this operation has been removed. 
- rewriter.eraseOp(op); - return success(); - } - -private: - /// Return a symbol reference to the printf function, inserting it into the - /// module if necessary. - static auto getOrInsertPrintf(PatternRewriter &rewriter, mlir::ModuleOp module) - -> FlatSymbolRefAttr { - auto *context = module.getContext(); - if (module.lookupSymbol("printf")) - return SymbolRefAttr::get(context, "printf"); - - // Create a function declaration for printf, the signature is: - // * `i32 (i8*, ...)` - auto llvmI32Ty = IntegerType::get(context, 32); - auto llvmI8PtrTy = mlir::LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); - auto llvmFnType = mlir::LLVM::LLVMFunctionType::get(llvmI32Ty, llvmI8PtrTy, - /*isVarArg=*/true); - - // Insert the printf function into the body of the parent module. - PatternRewriter::InsertionGuard const insertGuard(rewriter); - rewriter.setInsertionPointToStart(module.getBody()); - rewriter.create(module.getLoc(), "printf", llvmFnType); - return SymbolRefAttr::get(context, "printf"); - } - - /// Return a value representing an access into a global string with the given - /// name, creating the string if necessary. - static auto getOrCreateGlobalString(Location loc, OpBuilder &builder, StringRef name, - StringRef value, mlir::ModuleOp module) -> Value { - // Create the global at the entry of the module. - mlir::LLVM::GlobalOp global; - if (global = module.lookupSymbol(name); !global) { - OpBuilder::InsertionGuard const insertGuard(builder); - builder.setInsertionPointToStart(module.getBody()); - auto type = - mlir::LLVM::LLVMArrayType::get(IntegerType::get(builder.getContext(), 8), value.size()); - global = builder.create(loc, type, /*isConstant=*/true, - mlir::LLVM::Linkage::Internal, name, - builder.getStringAttr(value), - /*alignment=*/0); - } - - // Get the pointer to the first character in the global string. - Value const globalPtr = builder.create(loc, global); - Value const cst0 = - builder.create(loc, IntegerType::get(builder.getContext(), 64), - builder.getIntegerAttr(builder.getIndexType(), 0)); - return builder.create( - loc, mlir::LLVM::LLVMPointerType::get(IntegerType::get(builder.getContext(), 8)), globalPtr, - ArrayRef({cst0, cst0})); - } -}; - -} // end namespace - -void populateLowerLLVMConversionPatterns(mlir::LLVMTypeConverter &typeConverter, - RewritePatternSet &patterns) { - // populateSCFToControlFlowConversionPatterns(patterns); - mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter, patterns); - populateMemRefToLLVMConversionPatterns(typeConverter, patterns); - mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); - populateFuncToLLVMConversionPatterns(typeConverter, patterns); - - // The only remaining operation to lower from the `toy` dialect, is the - // PrintOp. - patterns.add(patterns.getContext()); -} - -void LowerToLLVMPass::runOnOperation() { - mlir::LLVMConversionTarget target(getContext()); - target.addLegalOp(); - - RewritePatternSet patterns(&getContext()); - mlir::LLVMTypeConverter typeConverter(patterns.getContext()); - populateLowerLLVMConversionPatterns(typeConverter, patterns); - - // We want to completely lower to LLVM, so we use a `FullConversion`. This - // ensures that only legal operations will remain after the conversion. 
- auto module = getOperation(); - if (failed(applyFullConversion(module, target, std::move(patterns)))) - signalPassFailure(); -} - -auto createLowerToLLVMPass() -> std::unique_ptr { - return std::make_unique(); -} - -} // namespace hail::ir diff --git a/query/lib/Conversion/OptionToGenericOption/CMakeLists.txt b/query/lib/Conversion/OptionToGenericOption/CMakeLists.txt deleted file mode 100644 index d82610ab9c0..00000000000 --- a/query/lib/Conversion/OptionToGenericOption/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -add_mlir_conversion_library(HailOptionToGenericOption - OptionToGenericOption.cpp - - ADDITIONAL_HEADER_DIRS - ${HAIL_MAIN_INCLUDE_DIR}/Conversion/OptionToGenericOption - - DEPENDS - HailConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - HailCPSDialect - HailOptionDialect - MLIRPass - MLIRTransforms - ) diff --git a/query/lib/Conversion/OptionToGenericOption/OptionToGenericOption.cpp b/query/lib/Conversion/OptionToGenericOption/OptionToGenericOption.cpp deleted file mode 100644 index dffb747aaec..00000000000 --- a/query/lib/Conversion/OptionToGenericOption/OptionToGenericOption.cpp +++ /dev/null @@ -1,77 +0,0 @@ -#include "hail/Conversion/OptionToGenericOption/OptionToGenericOption.h" - -#include "../PassDetail.h" - -#include "hail/Dialect/CPS/IR/CPS.h" -#include "hail/Dialect/Option/IR/Option.h" -#include "hail/Support/MLIR.h" - -#include "mlir/Transforms/DialectConversion.h" - -namespace hail::ir { - -struct OptionToGenericOptionPass : public OptionToGenericOptionBase { - void runOnOperation() override; -}; - -namespace { - -class ConvertMapOp : public OpConversionPattern { -public: - explicit ConvertMapOp(MLIRContext *context) - : OpConversionPattern(context, /*benefit=*/1) {} - - auto matchAndRewrite(MapOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const - -> LogicalResult override { - auto loc = op.getLoc(); - - SmallVector valueTypes; - for (auto option : op.getOperandTypes()) { - auto types = option.cast().getValueTypes(); - valueTypes.append(types.begin(), types.end()); - } - - auto construct = rewriter.create(loc, op.getType().getValueTypes()); - - rewriter.setInsertionPointToStart(construct.getBody()); - auto bodyCont = rewriter.create(loc, valueTypes); - rewriter.mergeBlocks(op.getBody(), bodyCont.getBody(), bodyCont.getBody()->getArguments()); - auto yield = llvm::cast(bodyCont.getBody()->getTerminator()); - rewriter.setInsertionPointAfter(yield); - rewriter.replaceOpWithNewOp(yield, construct.getPresentCont(), - yield.getOperands()); - - rewriter.setInsertionPointAfter(bodyCont); - rewriter.create(loc, op.getOperands(), construct.getMissingCont(), bodyCont); - - rewriter.replaceOp(op, construct->getResults()); - return success(); - } -}; - -} // end namespace - -void populateOptionToGenericOptionConversionPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); -} - -void OptionToGenericOptionPass::runOnOperation() { - RewritePatternSet patterns(&getContext()); - populateOptionToGenericOptionConversionPatterns(patterns); - - ConversionTarget target(getContext()); - target.addIllegalDialect(); - target.addLegalOp(); - - // We want to completely lower to LLVM, so we use a `FullConversion`. This - // ensures that only legal operations will remain after the conversion. 
- target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); - if (failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) - signalPassFailure(); -} - -auto createOptionToGenericOptionPass() -> std::unique_ptr { - return std::make_unique(); -} - -} // namespace hail::ir diff --git a/query/lib/Conversion/PassDetail.h b/query/lib/Conversion/PassDetail.h deleted file mode 100644 index 132b45607e0..00000000000 --- a/query/lib/Conversion/PassDetail.h +++ /dev/null @@ -1,23 +0,0 @@ -// NOLINTNEXTLINE(llvm-header-guard) -#ifndef HAIL_CONVERSION_PASSDETAIL_H_ -#define HAIL_CONVERSION_PASSDETAIL_H_ - -#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" - -namespace hail { -namespace ir { -class SandboxDialect; -} // namespace ir - -#define GEN_PASS_CLASSES -#include "hail/Conversion/Passes.h.inc" - -} // namespace hail - -#endif // HAIL_CONVERSION_PASSDETAIL_H_ diff --git a/query/lib/Dialect/CMakeLists.txt b/query/lib/Dialect/CMakeLists.txt deleted file mode 100644 index df05150a143..00000000000 --- a/query/lib/Dialect/CMakeLists.txt +++ /dev/null @@ -1,4 +0,0 @@ -add_subdirectory(CPS) -add_subdirectory(Missing) -add_subdirectory(Option) -add_subdirectory(Sandbox) diff --git a/query/lib/Dialect/CPS/CMakeLists.txt b/query/lib/Dialect/CPS/CMakeLists.txt deleted file mode 100644 index f33061b2d87..00000000000 --- a/query/lib/Dialect/CPS/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(IR) diff --git a/query/lib/Dialect/CPS/IR/CMakeLists.txt b/query/lib/Dialect/CPS/IR/CMakeLists.txt deleted file mode 100644 index 18015e85d4c..00000000000 --- a/query/lib/Dialect/CPS/IR/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -add_hail_dialect_library(HailCPSDialect - CPSDialect.cpp - CPSOps.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/Dialect/CPS/IR - - DEPENDS - MLIRCPSOpsIncGen - - LINK_LIBS PUBLIC - MLIRIR -) diff --git a/query/lib/Dialect/CPS/IR/CPSDialect.cpp b/query/lib/Dialect/CPS/IR/CPSDialect.cpp deleted file mode 100644 index 3d922047f4b..00000000000 --- a/query/lib/Dialect/CPS/IR/CPSDialect.cpp +++ /dev/null @@ -1,194 +0,0 @@ -#include "hail/Dialect/CPS/IR/CPS.h" - -#include "hail/Support/MLIR.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Operation.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/TypeSwitch.h" - -using namespace hail::ir; - -#define GET_TYPEDEF_CLASSES -#include "hail/Dialect/CPS/IR/CPSOpsDialect.cpp.inc" -#include "hail/Dialect/CPS/IR/CPSOpsTypes.cpp.inc" - -void CPSDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "hail/Dialect/CPS/IR/CPSOps.cpp.inc" - >(); - addTypes< -#define GET_TYPEDEF_LIST -#include "hail/Dialect/CPS/IR/CPSOpsTypes.cpp.inc" - >(); -} - -//===----------------------------------------------------------------------===// -// ContinuationType -//===----------------------------------------------------------------------===// - -namespace hail::ir::detail { - -struct ContinuationTypeStorage final : public mlir::TypeStorage, - public llvm::TrailingObjects { - using KeyTy = TypeRange; - - ContinuationTypeStorage(unsigned numTypes) : numElements(numTypes) {} - - /// Construction. 
- static auto construct(mlir::TypeStorageAllocator &allocator, TypeRange key) - -> ContinuationTypeStorage * { - // Allocate a new storage instance. - auto byteSize = ContinuationTypeStorage::totalSizeToAlloc(key.size()); - auto *rawMem = allocator.allocate(byteSize, alignof(ContinuationTypeStorage)); - // NOLINTNEXTLINE(*-owning-memory) - auto *result = ::new (rawMem) ContinuationTypeStorage(key.size()); - - // Copy in the element types into the trailing storage. - std::uninitialized_copy(key.begin(), key.end(), result->getTrailingObjects()); - return result; - } - - auto operator==(KeyTy const &key) const -> bool { return key == getTypes(); } - - /// Return the number of held types. - auto size() const -> unsigned { return numElements; } - - /// Return the held types. - auto getTypes() const -> ArrayRef { return {getTrailingObjects(), size()}; } - - /// The number of tuple elements. - unsigned numElements; -}; - -} // namespace hail::ir::detail - -auto ContinuationType::getInputs() const -> ArrayRef { return getImpl()->getTypes(); } - -auto ContinuationType::parse(mlir::AsmParser &parser) -> Type { - Builder const builder(parser.getContext()); - SmallVector inputs; - - auto parseType = [&]() { - auto element = mlir::FieldParser::parse(parser); - if (failed(element)) - return failure(); - inputs.push_back(*element); - return success(); - }; - - if (parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater, parseType)) { - parser.emitError(parser.getCurrentLocation(), - "failed to parse CPS_ContType parameter 'inputs' which is to be a " - "`ArrayRef`"); - return {}; - } - return ContinuationType::get(parser.getContext(), inputs); -} - -void ContinuationType::print(mlir::AsmPrinter &odsPrinter) const { - Builder const odsBuilder(getContext()); - odsPrinter << "<"; - odsPrinter.printStrippedAttrOrType(getInputs()); - odsPrinter << ">"; -} - -//===----------------------------------------------------------------------===// -// CallCCOp -//===----------------------------------------------------------------------===// - -auto CallCCOp::parse(mlir::OpAsmParser &parser, OperationState &result) -> mlir::ParseResult { - auto &builder = parser.getBuilder(); - - // Parse the return continuation argument and return types - mlir::OpAsmParser::UnresolvedOperand retContName; - SmallVector retTypes; - if (parser.parseOperand(retContName, /*allowResultNumber=*/false) - || parser.parseOptionalColonTypeList(retTypes)) - return failure(); - result.addTypes(retTypes); - - Type const retContType = builder.getType(retTypes); - mlir::OpAsmParser::Argument const retContArg{retContName, retContType, {}, {}}; - - // If attributes are present, parse them. 
- NamedAttrList parsedAttributes; - if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) - return failure(); - result.attributes.append(parsedAttributes); - - auto *body = result.addRegion(); - mlir::SMLoc const loc = parser.getCurrentLocation(); - if (parser.parseRegion(*body, retContArg)) - return failure(); - if (body->empty()) - return parser.emitError(loc, "expected non-empty function body"); - - return success(); -} - -void CallCCOp::print(mlir::OpAsmPrinter &p) { - p << ' ' << getRegion().getArgument(0); - if (getResultTypes().size() > 0) { - p << " : " << getResultTypes(); - } - p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs()); - p << ' '; - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); -} - -//===----------------------------------------------------------------------===// -// DefContOp -//===----------------------------------------------------------------------===// - -auto DefContOp::parse(mlir::OpAsmParser &parser, OperationState &result) -> mlir::ParseResult { - auto &builder = parser.getBuilder(); - - // Parse the arguments list - SmallVector arguments; - if (parser.parseArgumentList(arguments, mlir::OpAsmParser::Delimiter::Paren, - /*allowType=*/true)) - return failure(); - - SmallVector argTypes; - argTypes.reserve(arguments.size()); - for (auto &arg : arguments) - argTypes.push_back(arg.type); - auto type = builder.getType(argTypes); - result.addTypes(type); - - // If attributes are present, parse them. - NamedAttrList parsedAttributes; - if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) - return failure(); - result.attributes.append(parsedAttributes); - - // Parse the body - auto *body = result.addRegion(); - mlir::SMLoc const loc = parser.getCurrentLocation(); - if (parser.parseRegion(*body, arguments)) - return failure(); - if (body->empty()) - return parser.emitError(loc, "expected non-empty function body"); - - return success(); -} - -void DefContOp::print(mlir::OpAsmPrinter &p) { - // Print the arguments list - p << '('; - llvm::interleaveComma(getRegion().getArguments(), p, - [&](auto arg) { p.printRegionArgument(arg); }); - p << ')'; - - // If attributes are present, print them. 
- p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs()); - - // Print the body - p << ' '; - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); -} diff --git a/query/lib/Dialect/CPS/IR/CPSOps.cpp b/query/lib/Dialect/CPS/IR/CPSOps.cpp deleted file mode 100644 index 882d0c6c15a..00000000000 --- a/query/lib/Dialect/CPS/IR/CPSOps.cpp +++ /dev/null @@ -1,109 +0,0 @@ -#include "hail/Dialect/CPS/IR/CPS.h" - -#include "hail/Support/MLIR.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Support/LogicalResult.h" - -#define GET_OP_CLASSES -#include "hail/Dialect/CPS/IR/CPSOps.cpp.inc" - -using namespace hail::ir; - -//===----------------------------------------------------------------------===// -// ApplyContOp -//===----------------------------------------------------------------------===// - -namespace { -struct InlineCont : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - auto matchAndRewrite(ApplyContOp apply, PatternRewriter &rewriter) const - -> LogicalResult override { - auto defcont = apply.cont().getDefiningOp(); - if (!defcont || !defcont->hasOneUse()) - return failure(); - - rewriter.mergeBlocks(defcont.getBody(), apply->getBlock(), apply.args()); - rewriter.eraseOp(apply); - rewriter.eraseOp(defcont); - - return success(); - } -}; - -} // namespace - -void ApplyContOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); -} - -//===----------------------------------------------------------------------===// -// CallCCOp -//===----------------------------------------------------------------------===// - -void CallCCOp::build(OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes) { - odsState.addTypes(resultTypes); - auto *region = odsState.addRegion(); - region->emplaceBlock(); - region->addArgument(odsBuilder.getType(resultTypes), odsState.location); -} - -namespace { - -struct TrivialCallCC : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - auto matchAndRewrite(CallCCOp callcc, PatternRewriter &rewriter) const -> LogicalResult override { - auto terminator = dyn_cast(callcc.getBody()->getTerminator()); - if (!terminator) - return failure(); - if (terminator.cont() != callcc.getBody()->getArgument(0)) - return failure(); - SmallVector const values(terminator.args().begin(), terminator.args().end()); - rewriter.eraseOp(terminator); - callcc.getBody()->eraseArgument(0); - rewriter.mergeBlockBefore(callcc.getBody(), callcc); - rewriter.replaceOp(callcc, values); - return success(); - } -}; - -} // namespace - -void CallCCOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); -} - -//===----------------------------------------------------------------------===// -// DefContOp -//===----------------------------------------------------------------------===// - -void DefContOp::build(OpBuilder &odsBuilder, OperationState &odsState, TypeRange argTypes) { - odsState.addTypes(odsBuilder.getType(argTypes)); - auto *region = odsState.addRegion(); - region->emplaceBlock(); - SmallVector const locs(argTypes.size(), odsState.location); - region->addArguments(argTypes, locs); -} - -auto DefContOp::verifyRegions() -> LogicalResult { - auto *body = getBody(); - ContinuationType const type = getType(); - auto numArgs = type.getInputs().size(); - - if (body->getNumArguments() != numArgs) - return emitOpError("mismatch in 
number of basic block args and continuation type"); - - unsigned i = 0; - for (auto e : llvm::zip(body->getArguments(), type.getInputs())) { - if (std::get<0>(e).getType() != std::get<1>(e)) - return emitOpError() << "type mismatch between " << i << "th block arg and continuation type"; - - i++; - } - return success(); -} diff --git a/query/lib/Dialect/Missing/CMakeLists.txt b/query/lib/Dialect/Missing/CMakeLists.txt deleted file mode 100644 index f33061b2d87..00000000000 --- a/query/lib/Dialect/Missing/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(IR) diff --git a/query/lib/Dialect/Missing/IR/CMakeLists.txt b/query/lib/Dialect/Missing/IR/CMakeLists.txt deleted file mode 100644 index 0a981face9a..00000000000 --- a/query/lib/Dialect/Missing/IR/CMakeLists.txt +++ /dev/null @@ -1,13 +0,0 @@ -add_hail_dialect_library(HailMissingDialect - MissingDialect.cpp - MissingOps.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/Dialect/Missing/IR - - DEPENDS - MLIRMissingOpsIncGen - - LINK_LIBS PUBLIC - MLIRIR -) diff --git a/query/lib/Dialect/Missing/IR/MissingDialect.cpp b/query/lib/Dialect/Missing/IR/MissingDialect.cpp deleted file mode 100644 index 5c28c007b14..00000000000 --- a/query/lib/Dialect/Missing/IR/MissingDialect.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include "hail/Dialect/Missing/IR/Missing.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Operation.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/TypeSwitch.h" - -using namespace hail::ir; - -#define GET_TYPEDEF_CLASSES -#include "hail/Dialect/Missing/IR/MissingOpsDialect.cpp.inc" - -void MissingDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "hail/Dialect/Missing/IR/MissingOps.cpp.inc" - >(); -} diff --git a/query/lib/Dialect/Missing/IR/MissingOps.cpp b/query/lib/Dialect/Missing/IR/MissingOps.cpp deleted file mode 100644 index 52c5063b358..00000000000 --- a/query/lib/Dialect/Missing/IR/MissingOps.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "hail/Dialect/Missing/IR/Missing.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Support/LogicalResult.h" - -#include "llvm/ADT/APSInt.h" - -#include - -#define GET_OP_CLASSES -#include "hail/Dialect/Missing/IR/MissingOps.cpp.inc" - -namespace hail::ir {} diff --git a/query/lib/Dialect/Option/CMakeLists.txt b/query/lib/Dialect/Option/CMakeLists.txt deleted file mode 100644 index f33061b2d87..00000000000 --- a/query/lib/Dialect/Option/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(IR) diff --git a/query/lib/Dialect/Option/IR/CMakeLists.txt b/query/lib/Dialect/Option/IR/CMakeLists.txt deleted file mode 100644 index 933b7c352d2..00000000000 --- a/query/lib/Dialect/Option/IR/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -add_hail_dialect_library(HailOptionDialect - OptionDialect.cpp - OptionOps.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/Dialect/Option/IR - - DEPENDS - MLIROptionOpsIncGen - - LINK_LIBS PUBLIC - HailCPSDialect - MLIRIR -) diff --git a/query/lib/Dialect/Option/IR/OptionDialect.cpp b/query/lib/Dialect/Option/IR/OptionDialect.cpp deleted file mode 100644 index d25d8de26e0..00000000000 --- a/query/lib/Dialect/Option/IR/OptionDialect.cpp +++ /dev/null @@ -1,93 +0,0 @@ -#include "hail/Dialect/Option/IR/Option.h" - -#include 
"hail/Support/MLIR.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/Support/LLVM.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/TypeSwitch.h" - -using namespace hail::ir; - -#define GET_TYPEDEF_CLASSES -#include "hail/Dialect/Option/IR/OptionOpsDialect.cpp.inc" -#include "hail/Dialect/Option/IR/OptionOpsTypes.cpp.inc" - -void OptionDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "hail/Dialect/Option/IR/OptionOps.cpp.inc" - >(); - addTypes< -#define GET_TYPEDEF_LIST -#include "hail/Dialect/Option/IR/OptionOpsTypes.cpp.inc" - >(); -} - -namespace hail::ir::detail { - -struct OptionTypeStorage final : public mlir::TypeStorage, - public llvm::TrailingObjects { - using KeyTy = TypeRange; - - OptionTypeStorage(unsigned numTypes) : numElements(numTypes) {} - - /// Construction. - static auto construct(mlir::TypeStorageAllocator &allocator, TypeRange key) - -> OptionTypeStorage * { - // Allocate a new storage instance. - auto byteSize = OptionTypeStorage::totalSizeToAlloc(key.size()); - auto *rawMem = allocator.allocate(byteSize, alignof(OptionTypeStorage)); - // NOLINTNEXTLINE(*-owning-memory) - auto *result = ::new (rawMem) OptionTypeStorage(key.size()); - - // Copy in the element types into the trailing storage. - std::uninitialized_copy(key.begin(), key.end(), result->getTrailingObjects()); - return result; - } - - auto operator==(KeyTy const &key) const -> bool { return key == getTypes(); } - - /// Return the number of held types. - auto size() const -> unsigned { return numElements; } - - /// Return the held types. - auto getTypes() const -> ArrayRef { return {getTrailingObjects(), size()}; } - - /// The number of tuple elements. - unsigned numElements; -}; - -} // namespace hail::ir::detail - -auto OptionType::getValueTypes() const -> ArrayRef { return getImpl()->getTypes(); } - -auto OptionType::parse(mlir::AsmParser &parser) -> Type { - Builder const builder(parser.getContext()); - SmallVector inputs; - - auto parseType = [&]() { - auto element = mlir::FieldParser::parse(parser); - if (failed(element)) - return failure(); - inputs.push_back(*element); - return success(); - }; - - if (parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater, parseType)) { - parser.emitError(parser.getCurrentLocation(), - "failed to parse OptionType parameter 'valueTypes' which is to be a " - "`ArrayRef`"); - return {}; - } - return OptionType::get(parser.getContext(), inputs); -} - -void OptionType::print(mlir::AsmPrinter &odsPrinter) const { - Builder const odsBuilder(getContext()); - odsPrinter << "<"; - odsPrinter.printStrippedAttrOrType(getValueTypes()); - odsPrinter << ">"; -} diff --git a/query/lib/Dialect/Option/IR/OptionOps.cpp b/query/lib/Dialect/Option/IR/OptionOps.cpp deleted file mode 100644 index 81215653e43..00000000000 --- a/query/lib/Dialect/Option/IR/OptionOps.cpp +++ /dev/null @@ -1,190 +0,0 @@ -#include "hail/Dialect/Option/IR/Option.h" - -#include "hail/Support/MLIR.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Support/LogicalResult.h" - -#include "llvm/ADT/SmallVector.h" - -#define GET_OP_CLASSES -#include "hail/Dialect/Option/IR/OptionOps.cpp.inc" - -using namespace hail::ir; - -//===----------------------------------------------------------------------===// -// MapOp 
-//===----------------------------------------------------------------------===// - -void MapOp::build(OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultValueTypes, - ValueRange inputs) { - odsState.addTypes(odsBuilder.getType(resultValueTypes)); - auto *region = odsState.addRegion(); - region->emplaceBlock(); - llvm::SmallVector argTypes; - for (auto input : inputs) { - auto valueTypes = input.getType().cast().getValueTypes(); - argTypes.append(valueTypes.begin(), valueTypes.end()); - } - llvm::SmallVector const locs(argTypes.size(), odsState.location); - region->addArguments(argTypes, locs); -} - -//===----------------------------------------------------------------------===// -// ConstructOp -//===----------------------------------------------------------------------===// - -void ConstructOp::build(OpBuilder &odsBuilder, OperationState &odsState, TypeRange valueTypes) { - odsState.addTypes(odsBuilder.getType(valueTypes)); - auto *region = odsState.addRegion(); - region->emplaceBlock(); - region->addArgument(odsBuilder.getType(), odsState.location); - region->addArgument(odsBuilder.getType(valueTypes), odsState.location); -} - -//===----------------------------------------------------------------------===// -// DestructOp -//===----------------------------------------------------------------------===// - -namespace { - -struct DestructOfConstruct : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - auto matchAndRewrite(DestructOp destruct, PatternRewriter &rewriter) const - -> LogicalResult override { - ConstructOp source; - size_t sourceValuesStart = 0; - size_t curValueIdx = 0; - llvm::SmallVector sourceValueTypes; - llvm::SmallVector remainingValueTypes; - llvm::SmallVector remainingOptions; - for (auto const &input : llvm::enumerate(destruct.inputs())) { - if (auto construct = input.value().getDefiningOp()) { - if (!source && construct->hasOneUse()) { - source = construct; - auto valueTypes = construct.getType().getValueTypes(); - sourceValuesStart = curValueIdx; - curValueIdx += valueTypes.size(); - sourceValueTypes.append(valueTypes.begin(), valueTypes.end()); - } else { - auto valueTypes = input.value().getType().cast().getValueTypes(); - curValueIdx += valueTypes.size(); - remainingValueTypes.append(valueTypes.begin(), valueTypes.end()); - remainingOptions.push_back(input.value()); - } - } - } - - if (!source) - return failure(); - - // Define present continuation to pass to source. It will destruct all remaining options, with a - // present continuation that merges its values with those of this continuation, and pases them - // on to the original present continuation. - auto cont = rewriter.create(destruct.getLoc(), sourceValueTypes); - - rewriter.mergeBlockBefore(&source.bodyRegion().front(), destruct, - {destruct.missingCont(), cont}); - - // Define present continuation for the new destruct op, taking all remaining values, after - // removing those of 'source'. - rewriter.setInsertionPointToStart(cont.getBody()); - auto cont2 = rewriter.create(destruct.getLoc(), remainingValueTypes); - - // Create new destruct op taking all remaining options. - rewriter.create(destruct.getLoc(), remainingOptions, destruct.missingCont(), cont2); - - // Define body of 'cont2' to forward all values on to original present continuation. 
- rewriter.setInsertionPointToStart(cont2.getBody()); - llvm::SmallVector newValues(cont2.bodyRegion().args_begin(), - cont2.bodyRegion().args_end()); - newValues.insert(newValues.begin() + sourceValuesStart, cont.bodyRegion().args_begin(), - cont.bodyRegion().args_end()); - rewriter.create(destruct.getLoc(), destruct.presentCont(), newValues); - - rewriter.eraseOp(destruct); - rewriter.eraseOp(source); - return success(); - } -}; - -struct EmptyDestruct : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - auto matchAndRewrite(DestructOp destruct, PatternRewriter &rewriter) const - -> LogicalResult override { - if (!destruct.inputs().empty()) - return failure(); - - rewriter.replaceOpWithNewOp(destruct, destruct.presentCont(), - llvm::ArrayRef{}); - return success(); - } -}; - -} // end namespace - -void DestructOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); -} - -auto DestructOp::parse(mlir::OpAsmParser &parser, OperationState &result) -> mlir::ParseResult { - auto &builder = parser.getBuilder(); - llvm::SmallVector inputNames; - llvm::SmallVector contNames; - llvm::SmallVector optTypes; - NamedAttrList parsedAttributes; - if (parser.parseOperandList(inputNames, mlir::OpAsmParser::Delimiter::Paren) - || parser.parseOperandList(contNames, 2, mlir::OpAsmParser::Delimiter::Square) - || parser.parseOptionalAttrDict(parsedAttributes) - || parser.parseOptionalColonTypeList(optTypes)) - return failure(); - - llvm::SmallVector inputs; - inputs.reserve(inputNames.size()); - if (parser.resolveOperands(inputNames, optTypes, parser.getCurrentLocation(), inputs)) - return failure(); - - llvm::SmallVector valueTypes; - for (auto t : optTypes) { - auto optType = t.dyn_cast(); - if (!optType) - return failure(); - valueTypes.append(optType.getValueTypes().begin(), optType.getValueTypes().end()); - } - llvm::SmallVector conts; - if (parser.resolveOperands( - contNames, - {builder.getType(), builder.getType(valueTypes)}, - parser.getCurrentLocation(), conts)) - return failure(); - - result.addOperands(inputs); - result.addOperands(conts); - result.addAttributes(parsedAttributes); - - return success(); -} - -void DestructOp::print(mlir::OpAsmPrinter &p) { - p << '('; - p.printOperands(inputs()); - p << ')'; - p << '['; - p.printOperand(missingCont()); - p << ", "; - p.printOperand(presentCont()); - p << ']'; - - p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs()); - - p << ' '; - if (!inputs().empty()) { - p << " : " << inputs().getTypes(); - } -} diff --git a/query/lib/Dialect/Sandbox/CMakeLists.txt b/query/lib/Dialect/Sandbox/CMakeLists.txt deleted file mode 100644 index f33061b2d87..00000000000 --- a/query/lib/Dialect/Sandbox/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(IR) diff --git a/query/lib/Dialect/Sandbox/IR/CMakeLists.txt b/query/lib/Dialect/Sandbox/IR/CMakeLists.txt deleted file mode 100644 index db7eeac1f2d..00000000000 --- a/query/lib/Dialect/Sandbox/IR/CMakeLists.txt +++ /dev/null @@ -1,14 +0,0 @@ -add_hail_dialect_library(HailSandboxDialect - SandboxDialect.cpp - SandboxOps.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/Dialect/Sandbox/IR - - DEPENDS - MLIRSandboxOpsIncGen - MLIRSandboxOpsAttributesIncGen - - LINK_LIBS PUBLIC - MLIRIR -) diff --git a/query/lib/Dialect/Sandbox/IR/SandboxDialect.cpp b/query/lib/Dialect/Sandbox/IR/SandboxDialect.cpp deleted file mode 100644 index e38113028e1..00000000000 --- a/query/lib/Dialect/Sandbox/IR/SandboxDialect.cpp +++ 
/dev/null @@ -1,70 +0,0 @@ -#include "hail/Dialect/Sandbox/IR/Sandbox.h" - -#include "hail/Support/MLIR.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Operation.h" -#include "llvm/ADT/TypeSwitch.h" - -using namespace hail::ir; - -auto SandboxDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, - Location loc) -> Operation * { - assert(type.isa() || type.isa()); - auto intAttr = value.cast(); - assert(intAttr); - auto constOp = builder.create(loc, type, intAttr); - return constOp; -} - -#define GET_TYPEDEF_CLASSES -#include "hail/Dialect/Sandbox/IR/SandboxOpsDialect.cpp.inc" -#include "hail/Dialect/Sandbox/IR/SandboxOpsTypes.cpp.inc" - -void SandboxDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "hail/Dialect/Sandbox/IR/SandboxOps.cpp.inc" - >(); - addTypes< -#define GET_TYPEDEF_LIST -#include "hail/Dialect/Sandbox/IR/SandboxOpsTypes.cpp.inc" - >(); -} - -void ArrayType::walkImmediateSubElements(llvm::function_ref walkAttrsFn, - llvm::function_ref walkTypesFn) const { - walkTypesFn(getElementType()); -} - -auto ArrayType::replaceImmediateSubElements(ArrayRef replAttrs, - ArrayRef replTypes) const -> Type { - return get(getContext(), replTypes.front()); -} - -auto ArrayType::verify(llvm::function_ref emitError, Type elementType) - -> LogicalResult { - if (elementType.isa() || elementType.isa() - || elementType.isa()) { - return success(); - } - - return failure(); -} - -auto ArrayType::parse(mlir::AsmParser &parser) -> Type { - Type elementType; - if (parser.parseLess()) - return {}; - if (parser.parseType(elementType)) - return nullptr; - if (parser.parseGreater()) - return {}; - return ArrayType::get(parser.getContext(), elementType); -} - -void ArrayType::print(mlir::AsmPrinter &odsPrinter) const { - odsPrinter << '<' << getElementType() << '>'; -} diff --git a/query/lib/Dialect/Sandbox/IR/SandboxOps.cpp b/query/lib/Dialect/Sandbox/IR/SandboxOps.cpp deleted file mode 100644 index e9a8620102a..00000000000 --- a/query/lib/Dialect/Sandbox/IR/SandboxOps.cpp +++ /dev/null @@ -1,173 +0,0 @@ -#include "hail/Dialect/Sandbox/IR/Sandbox.h" - -#include "hail/Support/MLIR.h" - -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/Location.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Support/LogicalResult.h" - -#include - -#include "hail/Dialect/Sandbox/IR/SandboxOpsEnums.cpp.inc" - -#define GET_OP_CLASSES - -#include "hail/Dialect/Sandbox/IR/SandboxOps.cpp.inc" - -namespace hail::ir { - -auto ConstantOp::fold(ArrayRef operands) -> OpFoldResult { return valueAttr(); } - -auto ConstantOp::verify() -> LogicalResult { - auto type = getType(); - // The value's type must match the return type. 
- auto valueType = value().getType(); - - if (valueType.isa() && type.isa()) { - return success(); - } - - if (valueType == IntegerType::get(getContext(), 1) && type.isa()) { - return success(); - } - - return emitOpError() << "bad constant: type=" << type << ", valueType=" << valueType; -} - -// NOLINTNEXTLINE(*-member-functions-to-static) -auto AddIOp::fold(ArrayRef operands) -> OpFoldResult { - assert(operands.size() == 2 && "binary op takes two operands"); - if (!operands[0] || !operands[1]) - return {}; - - if (operands[0].isa() && operands[1].isa()) { - auto lhs = operands[0].cast(); - auto rhs = operands[1].cast(); - - auto result = IntegerAttr::get(lhs.getType(), lhs.getValue() + rhs.getValue()); - assert(result); - return result; - } - - return {}; -} - -auto ComparisonOp::fold(ArrayRef operands) -> OpFoldResult { - assert(operands.size() == 2 && "comparison op takes two operands"); - if (!operands[0] || !operands[1]) - return {}; - - if (operands[0].isa() && operands[1].isa()) { - auto pred = predicate(); - auto lhs = operands[0].cast(); - auto rhs = operands[1].cast(); - - bool x = false; - auto l = lhs.getValue(); - auto r = rhs.getValue(); - switch (pred) { - case CmpPredicate::LT: - x = l.slt(r); - break; - case CmpPredicate::LTEQ: - x = l.sle(r); - break; - case CmpPredicate::GT: - x = l.sgt(r); - break; - case CmpPredicate::GTEQ: - x = l.sge(r); - break; - case CmpPredicate::EQ: - x = l == r; - break; - case CmpPredicate::NEQ: - x = l != r; - break; - } - - auto result = BoolAttr::get(getContext(), x); - return result; - } - - return {}; -} - -struct SimplifyAddConstAddConst : public OpRewritePattern { - SimplifyAddConstAddConst(MLIRContext *context) - : OpRewritePattern(context, /*benefit=*/1) {} - - auto matchAndRewrite(AddIOp op, PatternRewriter &rewriter) const -> LogicalResult override { - auto lhs = op.lhs().getDefiningOp(); - if (!lhs) - return failure(); - - auto lConst = lhs.rhs().getDefiningOp(); - auto rConst = op.rhs().getDefiningOp(); - if (!lConst || !rConst) - return failure(); - - auto sumConst = rewriter.create( - mlir::FusedLoc::get(op->getContext(), {lConst->getLoc(), rConst.getLoc()}, nullptr), - lConst.getType(), - IntegerAttr::get(lConst.getType(), lConst.value().cast().getValue() - + rConst.value().cast().getValue())); - rewriter.replaceOpWithNewOp(op, lhs.result().getType(), lhs.lhs(), sumConst); - return success(); - } -}; - -void AddIOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); -} - -auto ArrayRefOp::verify() -> LogicalResult { - auto type = getType(); - // The value's type must match the return type. 
- auto arrayType = array().getType(); - - if (!arrayType.isa()) { - return emitOpError() << "ArrayRef requires an array as input: " << arrayType; - } - - if (arrayType.cast().getElementType() != type) { - return emitOpError() << "ArrayRef return type is not the array element type: array=" - << arrayType << ", result=" << type; - } - - return success(); -} - -auto ArrayRefOp::fold(ArrayRef operands) -> OpFoldResult { - auto a = array().getDefiningOp(); - if (operands[1].isa() && a) { - auto idx = operands[1].cast().getInt(); - if (idx < 0 || idx >= a->getNumOperands()) - return {}; - return a->getOperands()[idx]; - } - return {}; -} - -auto MakeArrayOp::verify() -> LogicalResult { - auto assignedResultType = result().getType(); - - if (!assignedResultType.isa()) { - return emitOpError() << "MakeArray expects an ArrayType as return type, found " - << assignedResultType; - } - - auto elemType = assignedResultType.cast().getElementType(); - for (auto elem : elems()) { - if (elemType != elem.getType()) { - return emitOpError() << "MakeArray with return element type " << elemType - << " had element with type " << elem.getType(); - } - } - return success(); -} - -} // namespace hail::ir diff --git a/query/lib/Transforms/CMakeLists.txt b/query/lib/Transforms/CMakeLists.txt deleted file mode 100644 index 4f0a3fc5108..00000000000 --- a/query/lib/Transforms/CMakeLists.txt +++ /dev/null @@ -1,18 +0,0 @@ -add_mlir_library(HailTransforms - TestMissingnessAnalysis.cpp - - ADDITIONAL_HEADER_DIRS - ${HAIL_MAIN_INCLUDE_DIR}/Transforms - - DEPENDS - HailTransformsPassIncGen - - LINK_LIBS PUBLIC - HailAnalysis - MLIRAnalysis - MLIRCopyOpInterface - MLIRLoopLikeInterface - MLIRPass - MLIRSupport - MLIRTransformUtils - ) diff --git a/query/lib/Transforms/PassDetail.h b/query/lib/Transforms/PassDetail.h deleted file mode 100644 index 8208137cbe6..00000000000 --- a/query/lib/Transforms/PassDetail.h +++ /dev/null @@ -1,12 +0,0 @@ -// NOLINTNEXTLINE(llvm-header-guard) -#ifndef HAIL_TRANSFORMS_PASSDETAIL_H_ -#define HAIL_TRANSFORMS_PASSDETAIL_H_ - -#include "mlir/Pass/Pass.h" - -namespace hail::ir { -#define GEN_PASS_CLASSES -#include "hail/Transforms/Passes.h.inc" -} // namespace hail::ir - -#endif // HAIL_TRANSFORMS_PASSDETAIL_H_ diff --git a/query/lib/Transforms/TestMissingnessAnalysis.cpp b/query/lib/Transforms/TestMissingnessAnalysis.cpp deleted file mode 100644 index e6e03f32000..00000000000 --- a/query/lib/Transforms/TestMissingnessAnalysis.cpp +++ /dev/null @@ -1,95 +0,0 @@ -#include "./PassDetail.h" - -#include "hail/Analysis/MissingnessAnalysis.h" -#include "hail/Analysis/MissingnessAwareConstantPropagationAnalysis.h" -#include "hail/Support/MLIR.h" -#include "hail/Transforms/Passes.h" - -#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/Pass/Pass.h" - -using namespace hail::ir; - -static void rewrite(mlir::DataFlowSolver &solver, Operation *root) { - root->walk([&solver, root](Operation *op) { - auto *context = root->getContext(); - auto builder = Builder(root->getContext()); - SmallVector annotations; - - annotations.reserve(op->getNumOperands()); - for (Value const result : op->getResults()) { - auto *missingnessState = - solver.getOrCreateState>(result); - auto *constState = - solver.getOrCreateState>(result); - - if (missingnessState->isUninitialized()) - continue; - - if (missingnessState->getValue().isMissing()) { - annotations.push_back(builder.getStringAttr("Missing")); - } else if (auto constVal = 
constState->getValue().getConstantValue()) { - annotations.push_back(constVal); - } else if (missingnessState->getValue().isPresent()) { - annotations.push_back(builder.getStringAttr("Present")); - } else { - annotations.push_back(builder.getStringAttr("?")); - } - } - - if (op->getNumResults() > 0) - op->setAttr("missing.result_states", ArrayAttr::get(context, annotations)); - - annotations.clear(); - if (op->getNumRegions() > 0 && op->getRegion(0).getNumArguments() > 0) { - auto ®ion = op->getRegion(0); - annotations.reserve(region.getNumArguments()); - for (Value const arg : region.getArguments()) { - auto *missingnessState = - solver.getOrCreateState>(arg); - auto *constState = - solver.getOrCreateState>(arg); - - if (missingnessState->isUninitialized()) - continue; - - if (missingnessState->getValue().isMissing()) { - annotations.push_back(builder.getStringAttr("Missing")); - } else if (auto constVal = constState->getValue().getConstantValue()) { - annotations.push_back(constVal); - } else if (missingnessState->getValue().isPresent()) { - annotations.push_back(builder.getStringAttr("Present")); - } else { - annotations.push_back(builder.getStringAttr("?")); - } - } - - op->setAttr("missing.region_arg_states", ArrayAttr::get(context, annotations)); - } - }); -} - -namespace { -struct TestMissingnessAnalysisPass - : public TestMissingnessAnalysisBase { - void runOnOperation() override; -}; -} // namespace - -void TestMissingnessAnalysisPass::runOnOperation() { - Operation *op = getOperation(); - - mlir::DataFlowSolver solver; - solver.load(); - solver.load(); - solver.load(); - if (failed(solver.initializeAndRun(op))) - return signalPassFailure(); - rewrite(solver, op); -} - -auto hail::ir::createTestMissingnessAnalysisPass() -> std::unique_ptr { - return std::make_unique(); -} diff --git a/query/test/.gitignore b/query/test/.gitignore deleted file mode 100644 index e775587cc29..00000000000 --- a/query/test/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -Output -.lit_test_times.txt diff --git a/query/test/CMakeLists.txt b/query/test/CMakeLists.txt deleted file mode 100644 index 747876a8c86..00000000000 --- a/query/test/CMakeLists.txt +++ /dev/null @@ -1,24 +0,0 @@ -find_program(LLVM_LIT NAMES "llvm-lit" "lit" - HINTS ${LLVM_TOOLS_BINARY_DIR} ${LLVM_BUILD_BINARY_DIR}) - -if(LLVM_LIT) - message(STATUS "Found llvm-lit: ${LLVM_LIT}") - get_filename_component(LLVM_TEST_BINARY_DIR "${LLVM_LIT}" DIRECTORY) - - find_program(FILE_CHECK "FileCheck" - HINTS ${LLVM_TOOLS_BINARY_DIR} - ${LLVM_TEST_BINARY_DIR} - ${LLVM_BUILD_BINARY_DIR}) - if(FILE_CHECK) - message(STATUS "Found FileCheck: ${FILE_CHECK}") - - configure_file(lit.site.cfg.py.in lit.site.cfg.py @ONLY) - - add_custom_target(check-hail COMMAND "${LLVM_LIT}" "${CMAKE_CURRENT_BINARY_DIR}" -v - DEPENDS hail-opt) - else() - message(WARNING "FileCheck not found, disabling lit tests") - endif() -else() - message(WARNING "llvm-lit not found, disabling lit tests") -endif() diff --git a/query/test/Dialect/CPS/canonicalize.mlir b/query/test/Dialect/CPS/canonicalize.mlir deleted file mode 100644 index 485c60704f6..00000000000 --- a/query/test/Dialect/CPS/canonicalize.mlir +++ /dev/null @@ -1,32 +0,0 @@ -// RUN: hail-opt %s -pass-pipeline='func.func(canonicalize)' -split-input-file -allow-unregistered-dialect | FileCheck %s - - -// ----- - -// CHECK-LABEL: func @inline_def( -// CHECK: %[[val:[[:alnum:]]+]] = "init_val"() -// CHECK-NEXT: %[[x:[[:alnum:]]+]] = "cont_body"(%[[val]]) -func.func @inline_def() -> i32 { - %result = cps.callcc %ret : i32 { - %cont 
= cps.cont(%arg: i32) { - %x = "cont_body"(%arg) : (i32) -> i32 - cps.apply %ret(%x) : i32 - } - %val = "init_val"() : () -> i32 - cps.apply %cont(%val) : i32 - } - return %result : i32 -} - -// ----- - -// CHECK-LABEL: func @trivial_callcc( -// CHECK-NEXT: %[[val:[[:alnum:]]+]] = "compute_val"() -// CHECK-NEXT: return %[[val]] -func.func @trivial_callcc() -> i32 { - %result = cps.callcc %ret : i32 { - %val = "compute_val"() : () -> i32 - cps.apply %ret(%val) : i32 - } - return %result : i32 -} diff --git a/query/test/Dialect/CPS/invalid.mlir b/query/test/Dialect/CPS/invalid.mlir deleted file mode 100644 index 7e425850eee..00000000000 --- a/query/test/Dialect/CPS/invalid.mlir +++ /dev/null @@ -1,56 +0,0 @@ -// RUN: hail-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics - -func.func @apply_type_mismatch(%x: i64) -> i64 { - // expected-note @below {{prior use here}} - %res = cps.callcc %ret : i32 { - // expected-error @below {{use of value '%ret' expects different type than prior uses: '!cps.cont' vs '!cps.cont'}} - cps.apply %ret(%x) : i64 - } - return %res -} - -// ----- - -func.func @apply_type_mismatch2(%x: i64) -> i64 { - %res = cps.callcc %ret : i32 { - // expected-error @below {{argument types match continuation type}} - "cps.apply"(%ret, %x) : (!cps.cont, i64) -> () - } - return %res : i32 -} - -// ----- - -func.func @apply_arg_number(%x: i32) -> i32 { - %res = cps.callcc %ret : i32 { - // expected-error @below {{argument types match continuation type}} - "cps.apply"(%ret, %x, %x) : (!cps.cont, i32, i32) -> () - } - return %res : i32 -} - -// ----- - -func.func @defcont_no_blocks(%x: i32) { - cps.callcc %ret { - // expected-error @below {{failed to verify constraint: region with 1 blocks}} - %cont = "cps.cont"() ({ - }) : () -> (!cps.cont) - cps.apply %cont(%x) : i32 - } - return -} - -// ----- - -func.func @defcont_arg_types(%x: i32) { - cps.callcc %ret { - // expected-error @below {{type mismatch between 0th block arg and continuation type}} - %cont = "cps.cont"() ({ - ^bb0(%arg: i64): - cps.apply %ret - }) : () -> (!cps.cont) - cps.apply %cont(%x) : i32 - } - return -} diff --git a/query/test/Dialect/CPS/ops.mlir b/query/test/Dialect/CPS/ops.mlir deleted file mode 100644 index 075e364a276..00000000000 --- a/query/test/Dialect/CPS/ops.mlir +++ /dev/null @@ -1,31 +0,0 @@ -// RUN: hail-opt %s | FileCheck %s -// Verify the printed output can be parsed. -// RUN: hail-opt %s | hail-opt | FileCheck %s -// Verify the generic form can be parsed. 
-// RUN: hail-opt -mlir-print-op-generic %s | hail-opt | FileCheck %s - -func.func @callcc(%arg0: i32) -> i32 { - %result = cps.callcc %ret : i32 { - %cont = cps.cont(%arg1: i32, %arg2: i64) { - cps.apply %ret(%arg1) : i32 - } - %cont2 = cps.cont() { - %c = arith.constant 0 : i64 - cps.apply %cont(%arg0, %c) : i32, i64 - } - cps.apply %cont2 - } - return %result : i32 -} -// CHECK-LABEL: func @callcc( -// CHECK-NEXT: %{{.+}} = cps.callcc %{{.+}} : i32 { -// CHECK-NEXT: %{{.+}} = cps.cont(%{{.+}}: i32, %{{.+}}: i64) { -// CHECK-NEXT: cps.apply %{{.+}}(%{{.+}}) : i32 -// CHECK-NEXT: } -// CHECK-NEXT: %{{.+}} = cps.cont() { -// CHECK-NEXT: %{{.+}} = arith.constant 0 : i64 -// CHECK-NEXT: cps.apply %{{.+}}(%{{.+}}, %{{.+}}) : i32, i64 -// CHECK-NEXT: } -// CHECK-NEXT: cps.apply %{{.+}} -// CHECK-NEXT: } -// CHECK-NEXT: return %{{.+}} : i32 diff --git a/query/test/cont.mlir b/query/test/cont.mlir deleted file mode 100644 index 632ecb9eb9b..00000000000 --- a/query/test/cont.mlir +++ /dev/null @@ -1,16 +0,0 @@ -// RUN: hail-opt %s | FileCheck %s - -// CHECK-LABEL: main -func.func @main() { - %foo = cps.callcc %ret : i32 { - %cont = cps.cont(%arg1: i32) { - cps.apply %ret(%arg1) : i32 - } - %cont2 = cps.cont() { - %x = arith.constant 0 : i32 - cps.apply %cont(%x) : i32 - } - cps.apply %cont2 - } - func.return -} diff --git a/query/test/current.mlir b/query/test/current.mlir deleted file mode 100644 index 8ee02292bfb..00000000000 --- a/query/test/current.mlir +++ /dev/null @@ -1,18 +0,0 @@ -// RUN: hail-opt %s -func.func @main() -> () { - %in = sb.constant(1: i32) : !sb.int - %i1 = sb.constant(5: i32) : !sb.int - %i2 = sb.constant(7: i32) : !sb.int - %i3 = sb.addi %in %i1 - %i4 = sb.addi %i3 %i2 - %i5 = sb.constant(6: i32) : !sb.int - %i6 = sb.compare eq, %i4, %i5 : !sb.bool - sb.print %i6 : !sb.bool - func.return -} - -func.func @bar() -> () { - %i1 = sb.constant(true) : !sb.bool - sb.print %i1 : !sb.bool - func.return -} diff --git a/query/test/int.mlir b/query/test/int.mlir deleted file mode 100644 index 5fb90b18047..00000000000 --- a/query/test/int.mlir +++ /dev/null @@ -1,15 +0,0 @@ -// RUN: hail-opt %s | hail-opt | FileCheck %s -// RUN: hail-opt %s --mlir-print-op-generic | hail-opt | FileCheck %s - -// CHECK-LABEL: test_int -func.func @test_int() { - %b = arith.constant 1 : i1 - %0 = scf.if %b -> !sb.int { - %1 = sb.constant(5) : !sb.int - scf.yield %1 : !sb.int - } else { - %2 = sb.constant(-2) : !sb.int - scf.yield %2 : !sb.int - } - return -} diff --git a/query/test/lit.cfg.py b/query/test/lit.cfg.py deleted file mode 100644 index 9da0046710f..00000000000 --- a/query/test/lit.cfg.py +++ /dev/null @@ -1,15 +0,0 @@ -import lit.formats - -config.name = "Hail MLIR Tests" -config.test_format = lit.formats.ShTest(True) -config.suffixes = ['.mlir'] - -config.test_source_root = os.path.dirname(__file__) -config.test_exec_root = os.path.join(config.hail_bin_root, 'test') - -SUBSTITUTIONS = ( - ('hail-opt', os.path.join(config.hail_bin_root, 'bin', 'hail-opt')), - ('FileCheck', config.file_check_path), -) - -config.substitutions.extend(SUBSTITUTIONS) diff --git a/query/test/lit.site.cfg.py.in b/query/test/lit.site.cfg.py.in deleted file mode 100644 index cd173d62a53..00000000000 --- a/query/test/lit.site.cfg.py.in +++ /dev/null @@ -1,7 +0,0 @@ -import os - -config.hail_src_root = r'@CMAKE_SOURCE_DIR@' -config.hail_bin_root = r'@CMAKE_BINARY_DIR@' -config.file_check_path = r'@FILE_CHECK@' - -lit_config.load_config(config, os.path.join(config.hail_src_root, 'test/lit.cfg.py')) diff --git 
a/query/test/missingness.mlir b/query/test/missingness.mlir deleted file mode 100644 index beb4a474638..00000000000 --- a/query/test/missingness.mlir +++ /dev/null @@ -1,16 +0,0 @@ -// RUN: hail-opt %s -func.func @loop_inner_control_flow(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 { - %cst_1 = arith.constant 1 : i32 - %result = scf.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %cst_1) -> (i32) { - %cond = missing.is_missing %si : i32 - %inner_res = scf.if %cond -> (i32) { - %1 = missing.missing : i32 - scf.yield %1 : i32 - } else { - %si_inc = arith.addi %si, %cst_1 : i32 - scf.yield %si_inc : i32 - } - scf.yield %inner_res : i32 - } - return %result : i32 -} diff --git a/query/test/option.mlir b/query/test/option.mlir deleted file mode 100644 index 20943c0a2e6..00000000000 --- a/query/test/option.mlir +++ /dev/null @@ -1,34 +0,0 @@ -// RUN: hail-opt -allow-unregistered-dialect %s | FileCheck %s - -// CHECK-LABEL: option -func.func @option() -> i32 { - %0 = option.construct !option.option { - ^bb0(%missing: !cps.cont<>, %present: !cps.cont): - %y:2 = "init_value1"() : () -> (i32, i32) - %cond = "cond"() : () -> i1 - cps.if %cond, %present(%y#0, %y#1 : i32, i32), %missing - } - %z = "init_value2"() : () -> (i32) - %1 = option.construct !option.option { - ^bb0(%missing: !cps.cont<>, %present: !cps.cont): - cps.apply %present(%z) : i32 - } - %2 = option.map(%0, %1) : (!option.option, !option.option) -> (!option.option) { - ^bb0(%x1: i32, %x2: i32, %x3: i32): - %y:2 = "map_body"(%x1, %x2, %x3) : (i32, i32, i32) -> (i32, i32) - option.yield %y#0, %y#1 : i32, i32 - } - // "use"(%0) : (!option.option) -> () - %3 = cps.callcc %ret : i32 { - %missing = cps.cont() { - %c1 = arith.constant 1 : i32 - cps.apply %ret(%c1) : i32 - } - %present = cps.cont(%v1: i32, %v2: i32) { - %r = "consume"(%v1, %v2) : (i32, i32) -> (i32) - cps.apply %ret(%r) : i32 - } - option.destruct(%2)[%missing, %present] : !option.option - } - return %3 : i32 -} diff --git a/query/tools/CMakeLists.txt b/query/tools/CMakeLists.txt deleted file mode 100644 index 68855abea88..00000000000 --- a/query/tools/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -add_subdirectory(hail-mlir-lsp-server) -add_subdirectory(hail-opt) -add_subdirectory(hailc) diff --git a/query/tools/hail-mlir-lsp-server/CMakeLists.txt b/query/tools/hail-mlir-lsp-server/CMakeLists.txt deleted file mode 100644 index c82bcf9e13d..00000000000 --- a/query/tools/hail-mlir-lsp-server/CMakeLists.txt +++ /dev/null @@ -1,45 +0,0 @@ -set(LLVM_OPTIONAL_SOURCES - null.cpp -) - -get_property(mlir_dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(hail_dialect_libs GLOBAL PROPERTY HAIL_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -set(LLVM_LINK_COMPONENTS - Core - Support - AsmParser - ) - -set(LIBS - ${mlir_dialect_libs} - ${conversion_libs} - MLIRAffineAnalysis - MLIRAnalysis - MLIRDialect - MLIRLspServerLib - MLIRParser - MLIRPass - MLIRTransforms - MLIRTransformUtils - MLIRSupport - MLIRIR - # Hail Libraries Go Here - ${hail_dialect_libs} - HailLowerToLLVM - ) - -add_mlir_tool(hail-mlir-lsp-server - hail-mlir-lsp-server.cpp - - DEPENDS - ${LIBS} - ) -target_link_libraries(hail-mlir-lsp-server PRIVATE ${LIBS}) -llvm_update_compile_flags(hail-mlir-lsp-server) - -install(TARGETS hail-mlir-lsp-server - RUNTIME DESTINATION ${LLVM_TOOLS_INSTALL_DIR} - COMPONENT hail-mlir-lsp-server) - -mlir_check_all_link_libraries(hail-mlir-lsp-server) diff --git a/query/tools/hail-mlir-lsp-server/hail-mlir-lsp-server.cpp 
b/query/tools/hail-mlir-lsp-server/hail-mlir-lsp-server.cpp deleted file mode 100644 index 6fd90360c9a..00000000000 --- a/query/tools/hail-mlir-lsp-server/hail-mlir-lsp-server.cpp +++ /dev/null @@ -1,20 +0,0 @@ -//===- mlir-lsp-server.cpp - MLIR Language Server -------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" - -#include "hail/InitAllDialects.h" - -using namespace hail::ir; - -auto main(int argc, char **argv) -> int { - mlir::DialectRegistry registry; - registerAllDialects(registry); - - return static_cast(failed(MlirLspServerMain(argc, argv, registry))); -} diff --git a/query/tools/hail-opt/CMakeLists.txt b/query/tools/hail-opt/CMakeLists.txt deleted file mode 100644 index 09d80820a40..00000000000 --- a/query/tools/hail-opt/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -get_property(mlir_dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(hail_dialect_libs GLOBAL PROPERTY HAIL_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -set(LIBS - ${mlir_dialect_libs} - ${conversion_libs} - MLIROptLib - # Hail Libraries Go Here - ${hail_dialect_libs} - HailTransforms - ) -add_llvm_executable(hail-opt hail-opt.cpp) -install(TARGETS hail-opt - RUNTIME DESTINATION ${LLVM_TOOLS_INSTALL_DIR} - COMPONENT hail-opt) - -llvm_update_compile_flags(hail-opt) -target_link_libraries(hail-opt PRIVATE ${LIBS}) - -mlir_check_all_link_libraries(hail-opt) diff --git a/query/tools/hail-opt/hail-opt.cpp b/query/tools/hail-opt/hail-opt.cpp deleted file mode 100644 index 92af03c1227..00000000000 --- a/query/tools/hail-opt/hail-opt.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "mlir/InitAllPasses.h" -#include "mlir/Tools/mlir-opt/MlirOptMain.h" - -#include "hail/InitAllDialects.h" -#include "hail/InitAllPasses.h" - -auto main(int argc, char **argv) -> int { - // mlir::registerAllPasses(); - - // General passes - mlir::registerTransformsPasses(); - - // Conversion passes - mlir::registerConvertAffineToStandardPass(); - mlir::registerConvertLinalgToStandardPass(); - mlir::registerConvertTensorToLinalgPass(); - mlir::registerConvertVectorToSCFPass(); - mlir::registerReconcileUnrealizedCastsPass(); - mlir::registerSCFToControlFlowPass(); - - // Dialect passes - mlir::bufferization::registerBufferizationPasses(); - mlir::registerLinalgLowerToAffineLoopsPass(); - mlir::registerLinalgLowerToLoopsPass(); - - // Hail passes - hail::ir::registerAllPasses(); - - // Dialects - mlir::DialectRegistry registry; - hail::ir::registerAllDialects(registry); - - return mlir::asMainReturnCode(mlir::MlirOptMain(argc, argv, "Hail optimizer driver\n", registry)); -} diff --git a/query/tools/hailc/CMakeLists.txt b/query/tools/hailc/CMakeLists.txt deleted file mode 100644 index 34d7fa20bd3..00000000000 --- a/query/tools/hailc/CMakeLists.txt +++ /dev/null @@ -1,26 +0,0 @@ -get_property(mlir_dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -get_property(hail_dialect_libs GLOBAL PROPERTY HAIL_DIALECT_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -set(LIBS - ${mlir_dialect_libs} - ${conversion_libs} - MLIRExecutionEngine - MLIRLLVMCommonConversion - MLIRLLVMDialect - MLIROptLib - MLIRParser - MLIRPass - MLIRTargetLLVMIRExport - # Hail Libraries Go Here - 
${hail_dialect_libs} - HailLowerToLLVM - ) -add_llvm_executable(hailc hailc.cpp) -install(TARGETS hailc - RUNTIME DESTINATION ${LLVM_TOOLS_INSTALL_DIR} - COMPONENT hailc) - -llvm_update_compile_flags(hailc) -target_link_libraries(hailc PRIVATE ${LIBS}) - -mlir_check_all_link_libraries(hailc) diff --git a/query/tools/hailc/hailc.cpp b/query/tools/hailc/hailc.cpp deleted file mode 100644 index c8616c6e551..00000000000 --- a/query/tools/hailc/hailc.cpp +++ /dev/null @@ -1,203 +0,0 @@ -#include "hail/Conversion/LowerSandbox/LowerSandbox.h" -#include "hail/Conversion/CPSToCF/CPSToCF.h" -#include "hail/Conversion/LowerToLLVM/LowerToLLVM.h" -#include "hail/InitAllDialects.h" - -#include "mlir/ExecutionEngine/ExecutionEngine.h" -#include "mlir/ExecutionEngine/OptUtils.h" -#include "mlir/IR/AsmState.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/OwningOpRef.h" -#include "mlir/Parser/Parser.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -#include "mlir/Target/LLVMIR/Export.h" -#include "mlir/Transforms/Passes.h" - -#include "llvm/IR/Module.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/ErrorOr.h" -#include "llvm/Support/MemoryBuffer.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/TargetSelect.h" - -using namespace hail::ir; -namespace cl = llvm::cl; - -static const cl::opt inputFilename(cl::Positional, cl::desc(""), - cl::init("-"), cl::value_desc("filename")); - -namespace { -enum Action { None, DumpSandbox, DumpMLIR, DumpMLIRLLVM, DumpLLVMIR, RunJIT }; -} // namespace - -static const cl::opt emitAction( - "emit", cl::desc("Select the kind of output desired"), - cl::values(clEnumValN(DumpSandbox, "sandbox", "output the sandbox dialect dump")), - cl::values(clEnumValN(DumpMLIR, "mlir", "output the MLIR dump after lowering sandbox")), - cl::values(clEnumValN(DumpMLIRLLVM, "mlir-llvm", "output the MLIR dump after llvm lowering")), - cl::values(clEnumValN(DumpLLVMIR, "llvm", "output the LLVM IR dump")), - cl::values(clEnumValN(RunJIT, "jit", "JIT the code and run it by invoking the main function"))); - -static const cl::opt enableOpt("opt", cl::desc("Enable optimizations")); - -auto loadMLIR(mlir::MLIRContext &context, mlir::OwningOpRef &module) -> int { - llvm::ErrorOr> fileOrErr = - llvm::MemoryBuffer::getFileOrSTDIN(inputFilename); - if (std::error_code const ec = fileOrErr.getError()) { - llvm::errs() << "Could not open input file: " << ec.message() << "\n"; - return -1; - } - - // Parse the input mlir. - llvm::SourceMgr sourceMgr; - sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc()); - module = mlir::parseSourceFile(sourceMgr, &context); - if (!module) { - llvm::errs() << "Error can't load file " << inputFilename << "\n"; - return 3; - } - return 0; -} - -auto loadAndProcessMLIR(mlir::MLIRContext &context, mlir::OwningOpRef &module) - -> int { - if (int const error = loadMLIR(context, module)) - return error; - - mlir::PassManager pm(&context); - // Apply any generic pass manager command line options and run the pipeline. - applyPassManagerCLOptions(pm); - - // Check to see what granularity of MLIR we are compiling to. - bool const isLoweringToMLIR = emitAction >= Action::DumpMLIR; - bool const isLoweringToLLVM = emitAction >= Action::DumpMLIRLLVM; - - if (enableOpt || isLoweringToMLIR) { - // Inline all functions into main and then delete them. 
- pm.addPass(mlir::createInlinerPass()); - - mlir::OpPassManager &optPM = pm.nest(); - optPM.addPass(mlir::createCanonicalizerPass()); - optPM.addPass(mlir::createCSEPass()); - } - - if (isLoweringToMLIR) { - // Partially lower the sandbox dialect. - pm.addPass(createLowerSandboxPass()); - pm.addPass(createCPSToCFPass()); - - // Add a few cleanups post lowering. - mlir::OpPassManager &optPM = pm.nest(); - optPM.addPass(mlir::createCanonicalizerPass()); - optPM.addPass(mlir::createCSEPass()); - - // Add optimizations if enabled. - if (enableOpt) { - } - } - - if (isLoweringToLLVM) { - // Finish lowering to the LLVM dialect. - pm.addPass(hail::ir::createLowerToLLVMPass()); - } - - if (mlir::failed(pm.run(*module))) - return 4; - return 0; -} - -auto dumpLLVMIR(mlir::ModuleOp module) -> int { - // Register the translation to LLVM IR with the MLIR context. - mlir::registerLLVMDialectTranslation(*module->getContext()); - - // Convert the module to LLVM IR in a new LLVM IR context. - llvm::LLVMContext llvmContext; - auto llvmModule = mlir::translateModuleToLLVMIR(module, llvmContext); - if (!llvmModule) { - llvm::errs() << "Failed to emit LLVM IR\n"; - return -1; - } - - // Initialize LLVM targets. - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - mlir::ExecutionEngine::setupTargetTriple(llvmModule.get()); - - /// Optionally run an optimization pipeline over the llvm module. - auto optPipeline = mlir::makeOptimizingTransformer( - /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, - /*targetMachine=*/nullptr); - if (auto err = optPipeline(llvmModule.get())) { - llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; - return -1; - } - llvm::errs() << *llvmModule << "\n"; - return 0; -} - -auto runJit(mlir::ModuleOp module) -> int { - // Initialize LLVM targets. - llvm::InitializeNativeTarget(); - llvm::InitializeNativeTargetAsmPrinter(); - - // Register the translation from MLIR to LLVM IR, which must happen before we - // can JIT-compile. - mlir::registerLLVMDialectTranslation(*module->getContext()); - - // An optimization pipeline to use within the execution engine. - auto optPipeline = mlir::makeOptimizingTransformer( - /*optLevel=*/enableOpt ? 3 : 0, /*sizeLevel=*/0, - /*targetMachine=*/nullptr); - - // Create an MLIR execution engine. The execution engine eagerly JIT-compiles - // the module. - mlir::ExecutionEngineOptions engineOptions; - engineOptions.transformer = optPipeline; - auto maybeEngine = mlir::ExecutionEngine::create(module, engineOptions); - assert(maybeEngine && "failed to construct an execution engine"); - auto &engine = maybeEngine.get(); - - // Invoke the JIT-compiled function. - auto invocationResult = engine->invokePacked("main"); - if (invocationResult) { - llvm::errs() << "JIT invocation failed\n"; - return -1; - } - - return 0; -} - -auto main(int argc, char **argv) -> int { - // Register any command line options. - mlir::registerAsmPrinterCLOptions(); - mlir::registerMLIRContextCLOptions(); - mlir::registerPassManagerCLOptions(); - - cl::ParseCommandLineOptions(argc, argv, "toy compiler\n"); - - mlir::MLIRContext context; - hail::ir::registerAllDialects(context); - - mlir::OwningOpRef module; - if (int const error = loadAndProcessMLIR(context, module)) - return error; - - // If we aren't exporting to non-mlir, then we are done. - bool const isOutputingMLIR = emitAction <= Action::DumpMLIRLLVM; - if (isOutputingMLIR) { - module->dump(); - return 0; - } - - // Check to see if we are compiling to LLVM IR. 
- if (emitAction == Action::DumpLLVMIR) - return dumpLLVMIR(*module); - - // Otherwise, we must be running the jit. - if (emitAction == Action::RunJIT) - return runJit(*module); - - llvm::errs() << "No action specified (parsing only?), use -emit=\n"; - return -1; -} diff --git a/tls/create_certs.py b/tls/create_certs.py index 8eb0c78609e..c8faacb991a 100644 --- a/tls/create_certs.py +++ b/tls/create_certs.py @@ -53,44 +53,40 @@ def create_key_and_cert(p): extfile.write(f'subjectAltName = {",".join("DNS:" + n for n in names)}\n') extfile.close() echo_check_call(['cat', extfile.name]) - echo_check_call( - [ - 'openssl', - 'x509', - '-req', - '-in', - csr_file, - '-CA', - root_cert_file, - '-CAkey', - root_key_file, - '-extfile', - extfile.name, - '-CAcreateserial', - '-out', - cert_file, - '-days', - '365', - '-sha256', - ] - ) - echo_check_call( - [ - 'openssl', - 'pkcs12', - '-export', - '-inkey', - key_file, - '-in', - cert_file, - '-name', - f'{name}-key-store', - '-out', - key_store_file, - '-passout', - 'pass:dummypw', - ] - ) + echo_check_call([ + 'openssl', + 'x509', + '-req', + '-in', + csr_file, + '-CA', + root_cert_file, + '-CAkey', + root_key_file, + '-extfile', + extfile.name, + '-CAcreateserial', + '-out', + cert_file, + '-days', + '365', + '-sha256', + ]) + echo_check_call([ + 'openssl', + 'pkcs12', + '-export', + '-inkey', + key_file, + '-in', + cert_file, + '-name', + f'{name}-key-store', + '-out', + key_store_file, + '-passout', + 'pass:dummypw', + ]) return {'key': key_file, 'cert': cert_file, 'key_store': key_store_file} @@ -101,21 +97,19 @@ def create_trust(principal, trust_type): # pylint: disable=unused-argument # FIXME: mTLS, only trust certain principals with open(root_cert_file, 'r') as root_cert: shutil.copyfileobj(root_cert, out) - echo_check_call( - [ - 'keytool', - '-noprompt', - '-import', - '-alias', - f'{trust_type}-cert', - '-file', - trust_file, - '-keystore', - trust_store_file, - '-storepass', - 'dummypw', - ] - ) + echo_check_call([ + 'keytool', + '-noprompt', + '-import', + '-alias', + f'{trust_type}-cert', + '-file', + trust_file, + '-keystore', + trust_store_file, + '-storepass', + 'dummypw', + ]) return {'trust': trust_file, 'trust_store': trust_store_file} diff --git a/web_common/pinned-requirements.txt b/web_common/pinned-requirements.txt index d72a068dba2..16b66198dc9 100644 --- a/web_common/pinned-requirements.txt +++ b/web_common/pinned-requirements.txt @@ -4,7 +4,7 @@ # # pip-compile --output-file=hail/web_common/pinned-requirements.txt hail/web_common/requirements.txt # -aiohttp==3.9.1 +aiohttp==3.9.3 # via # -c hail/web_common/../gear/pinned-requirements.txt # -c hail/web_common/../hail/python/dev/pinned-requirements.txt @@ -24,13 +24,13 @@ async-timeout==4.0.3 # -c hail/web_common/../hail/python/dev/pinned-requirements.txt # -c hail/web_common/../hail/python/pinned-requirements.txt # aiohttp -attrs==23.1.0 +attrs==23.2.0 # via # -c hail/web_common/../gear/pinned-requirements.txt # -c hail/web_common/../hail/python/dev/pinned-requirements.txt # -c hail/web_common/../hail/python/pinned-requirements.txt # aiohttp -frozenlist==1.4.0 +frozenlist==1.4.1 # via # -c hail/web_common/../gear/pinned-requirements.txt # -c hail/web_common/../hail/python/dev/pinned-requirements.txt @@ -43,27 +43,27 @@ idna==3.6 # -c hail/web_common/../hail/python/dev/pinned-requirements.txt # -c hail/web_common/../hail/python/pinned-requirements.txt # yarl -jinja2==3.1.2 +jinja2==3.1.3 # via # -c hail/web_common/../hail/python/dev/pinned-requirements.txt # -c 
hail/web_common/../hail/python/pinned-requirements.txt # -r hail/web_common/requirements.txt # aiohttp-jinja2 -libsass==0.22.0 +libsass==0.23.0 # via -r hail/web_common/requirements.txt -markupsafe==2.1.3 +markupsafe==2.1.5 # via # -c hail/web_common/../hail/python/dev/pinned-requirements.txt # -c hail/web_common/../hail/python/pinned-requirements.txt # jinja2 -multidict==6.0.4 +multidict==6.0.5 # via # -c hail/web_common/../gear/pinned-requirements.txt # -c hail/web_common/../hail/python/dev/pinned-requirements.txt # -c hail/web_common/../hail/python/pinned-requirements.txt # aiohttp # yarl -yarl==1.9.3 +yarl==1.9.4 # via # -c hail/web_common/../gear/pinned-requirements.txt # -c hail/web_common/../hail/python/dev/pinned-requirements.txt diff --git a/web_common/web_common/styles/main.scss b/web_common/web_common/styles/main.scss index bea6357bb54..c11d482bd85 100644 --- a/web_common/web_common/styles/main.scss +++ b/web_common/web_common/styles/main.scss @@ -458,6 +458,6 @@ td a.fill-td { height: 100%; } -.fa-download { +.material-download-icon { color: #07c; } diff --git a/web_common/web_common/templates/layout.html b/web_common/web_common/templates/layout.html index d1579ea7bc0..001d02f4c37 100644 --- a/web_common/web_common/templates/layout.html +++ b/web_common/web_common/templates/layout.html @@ -7,7 +7,7 @@ - + {% block head %}{% endblock %} diff --git a/web_common/web_common/web_common.py b/web_common/web_common/web_common.py index 5423543eb3e..3aca0472740 100644 --- a/web_common/web_common/web_common.py +++ b/web_common/web_common/web_common.py @@ -79,10 +79,7 @@ async def render_template( userdata: Optional[UserData], file: str, page_context: Dict[str, Any], - *, - cookie_domain: Optional[str] = None, ) -> web.Response: - if request.headers.get('x-hail-return-jinja-context'): if userdata and userdata['is_developer']: return web.json_response({'file': file, 'page_context': page_context, 'userdata': userdata}) @@ -99,6 +96,5 @@ async def render_template( context['csrf_token'] = csrf_token response = aiohttp_jinja2.render_template(file, request, context) - domain = cookie_domain or os.environ['HAIL_DOMAIN'] - response.set_cookie('_csrf', csrf_token, domain=domain, secure=True, httponly=True) + response.set_cookie('_csrf', csrf_token, secure=True, httponly=True, samesite='strict') return response diff --git a/website/website/pages/references.html b/website/website/pages/references.html index 96b65231154..ec314ae6211 100644 --- a/website/website/pages/references.html +++ b/website/website/pages/references.html @@ -22,7 +22,228 @@

      Or you could include the following line in your bibliography:

      Hail Team. Hail 0.2. https://github.com/hail-is/hail

      Otherwise, we welcome you to add additional examples by editing this page directly, after which we will review the pull request to confirm the addition is valid. Please adhere to the existing formatting conventions.

      -

      Last updated on March 29th, 2021

      +

      Last updated on February 22, 2024

      +

      2024

      + +

      2023

      + +

      2022

      +

      2021

diff --git a/website/website/pages/search.html b/website/website/pages/search.html index f9c6febf03c..74ebba1831e 100644 --- a/website/website/pages/search.html +++ b/website/website/pages/search.html @@ -150,7 +150,7 @@